diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
new file mode 100644
index 0000000..07fd857
--- /dev/null
+++ b/CONTRIBUTING.md
@@ -0,0 +1,11 @@
+## Contributing to Kafka
+
+*Before opening a pull request*, review the [Contributing](https://kafka.apache.org/contributing.html) and [Contributing Code Changes](https://cwiki.apache.org/confluence/display/KAFKA/Contributing+Code+Changes) pages.
+
+They list the steps that are required before creating a PR.
+
+When you contribute code, you affirm that the contribution is your original work and that you
+license the work to the project under the project's open source license. Whether or not you
+state this explicitly, by submitting any copyrighted material via pull request, email, or
+other means you agree to license the material under the project's open source license and
+warrant that you have the legal authority to do so.
diff --git a/HEADER b/HEADER
new file mode 100644
index 0000000..8853bce
--- /dev/null
+++ b/HEADER
@@ -0,0 +1,14 @@
+Licensed to the Apache Software Foundation (ASF) under one or more
+contributor license agreements. See the NOTICE file distributed with
+this work for additional information regarding copyright ownership.
+The ASF licenses this file to You under the Apache License, Version 2.0
+(the "License"); you may not use this file except in compliance with
+the License. You may obtain a copy of the License at
+
+   http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
diff --git a/Jenkinsfile b/Jenkinsfile
new file mode 100644
index 0000000..6b75cfd
--- /dev/null
+++ b/Jenkinsfile
@@ -0,0 +1,274 @@
+/*
+ *
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * + */ + +def doValidation() { + sh """ + ./gradlew -PscalaVersion=$SCALA_VERSION clean compileJava compileScala compileTestJava compileTestScala \ + spotlessScalaCheck checkstyleMain checkstyleTest spotbugsMain rat \ + --profile --no-daemon --continue -PxmlSpotBugsReport=true + """ +} + +def isChangeRequest(env) { + env.CHANGE_ID != null && !env.CHANGE_ID.isEmpty() +} + +def retryFlagsString(env) { + if (isChangeRequest(env)) " -PmaxTestRetries=1 -PmaxTestRetryFailures=5" + else "" +} + +def doTest(env, target = "unitTest integrationTest") { + sh """./gradlew -PscalaVersion=$SCALA_VERSION ${target} \ + --profile --no-daemon --continue -PtestLoggingEvents=started,passed,skipped,failed \ + -PignoreFailures=true -PmaxParallelForks=2""" + retryFlagsString(env) + junit '**/build/test-results/**/TEST-*.xml' +} + +def doStreamsArchetype() { + echo 'Verify that Kafka Streams archetype compiles' + + sh ''' + ./gradlew streams:publishToMavenLocal clients:publishToMavenLocal connect:json:publishToMavenLocal connect:api:publishToMavenLocal \ + || { echo 'Could not publish kafka-streams.jar (and dependencies) locally to Maven'; exit 1; } + ''' + + VERSION = sh(script: 'grep "^version=" gradle.properties | cut -d= -f 2', returnStdout: true).trim() + + dir('streams/quickstart') { + sh ''' + mvn clean install -Dgpg.skip \ + || { echo 'Could not `mvn install` streams quickstart archetype'; exit 1; } + ''' + + dir('test-streams-archetype') { + // Note the double quotes for variable interpolation + sh """ + echo "Y" | mvn archetype:generate \ + -DarchetypeCatalog=local \ + -DarchetypeGroupId=org.apache.kafka \ + -DarchetypeArtifactId=streams-quickstart-java \ + -DarchetypeVersion=${VERSION} \ + -DgroupId=streams.examples \ + -DartifactId=streams.examples \ + -Dversion=0.1 \ + -Dpackage=myapps \ + || { echo 'Could not create new project using streams quickstart archetype'; exit 1; } + """ + + dir('streams.examples') { + sh ''' + mvn compile \ + || { echo 'Could not compile streams quickstart archetype project'; exit 1; } + ''' + } + } + } +} + +def tryStreamsArchetype() { + try { + doStreamsArchetype() + } catch(err) { + echo 'Failed to build Kafka Streams archetype, marking this build UNSTABLE' + currentBuild.result = 'UNSTABLE' + } +} + + +pipeline { + agent none + + options { + disableConcurrentBuilds() + } + + stages { + stage('Build') { + parallel { + + stage('JDK 8 and Scala 2.12') { + agent { label 'ubuntu' } + tools { + jdk 'jdk_1.8_latest' + maven 'maven_3_latest' + } + options { + timeout(time: 8, unit: 'HOURS') + timestamps() + } + environment { + SCALA_VERSION=2.12 + } + steps { + doValidation() + doTest(env) + tryStreamsArchetype() + } + } + + stage('JDK 11 and Scala 2.13') { + agent { label 'ubuntu' } + tools { + jdk 'jdk_11_latest' + } + options { + timeout(time: 8, unit: 'HOURS') + timestamps() + } + environment { + SCALA_VERSION=2.13 + } + steps { + doValidation() + doTest(env) + echo 'Skipping Kafka Streams archetype test for Java 11' + } + } + + stage('JDK 17 and Scala 2.13') { + agent { label 'ubuntu' } + tools { + jdk 'jdk_17_latest' + } + options { + timeout(time: 8, unit: 'HOURS') + timestamps() + } + environment { + SCALA_VERSION=2.13 + } + steps { + doValidation() + doTest(env) + echo 'Skipping Kafka Streams archetype test for Java 17' + } + } + + stage('ARM') { + agent { label 'arm4' } + options { + timeout(time: 2, unit: 'HOURS') + timestamps() + } + environment { + SCALA_VERSION=2.12 + } + steps { + doValidation() + catchError(buildResult: 'SUCCESS', stageResult: 'FAILURE') { + 
doTest(env, 'unitTest') + } + echo 'Skipping Kafka Streams archetype test for ARM build' + } + } + + // To avoid excessive Jenkins resource usage, we only run the stages + // above at the PR stage. The ones below are executed after changes + // are pushed to trunk and/or release branches. We achieve this via + // the `when` clause. + + stage('JDK 8 and Scala 2.13') { + when { + not { changeRequest() } + beforeAgent true + } + agent { label 'ubuntu' } + tools { + jdk 'jdk_1.8_latest' + maven 'maven_3_latest' + } + options { + timeout(time: 8, unit: 'HOURS') + timestamps() + } + environment { + SCALA_VERSION=2.13 + } + steps { + doValidation() + doTest(env) + tryStreamsArchetype() + } + } + + stage('JDK 11 and Scala 2.12') { + when { + not { changeRequest() } + beforeAgent true + } + agent { label 'ubuntu' } + tools { + jdk 'jdk_11_latest' + } + options { + timeout(time: 8, unit: 'HOURS') + timestamps() + } + environment { + SCALA_VERSION=2.12 + } + steps { + doValidation() + doTest(env) + echo 'Skipping Kafka Streams archetype test for Java 11' + } + } + + stage('JDK 17 and Scala 2.12') { + when { + not { changeRequest() } + beforeAgent true + } + agent { label 'ubuntu' } + tools { + jdk 'jdk_17_latest' + } + options { + timeout(time: 8, unit: 'HOURS') + timestamps() + } + environment { + SCALA_VERSION=2.12 + } + steps { + doValidation() + doTest(env) + echo 'Skipping Kafka Streams archetype test for Java 17' + } + } + } + } + } + + post { + always { + node('ubuntu') { + script { + if (!isChangeRequest(env)) { + step([$class: 'Mailer', + notifyEveryUnstableBuild: true, + recipients: "dev@kafka.apache.org", + sendToIndividuals: false]) + } + } + } + } + } +} diff --git a/LICENSE-binary b/LICENSE-binary new file mode 100644 index 0000000..42a8d79 --- /dev/null +++ b/LICENSE-binary @@ -0,0 +1,321 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). 
+ + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ +------------------------------------------------------------------------------- +This project bundles some components that are also licensed under the Apache +License Version 2.0: + +audience-annotations-0.5.0 +commons-cli-1.4 +commons-lang3-3.8.1 +jackson-annotations-2.12.3 +jackson-core-2.12.3 +jackson-databind-2.12.3 +jackson-dataformat-csv-2.12.3 +jackson-datatype-jdk8-2.12.3 +jackson-jaxrs-base-2.12.3 +jackson-jaxrs-json-provider-2.12.3 +jackson-module-jaxb-annotations-2.12.3 +jackson-module-paranamer-2.10.5 +jackson-module-scala_2.13-2.12.3 +jakarta.validation-api-2.0.2 +javassist-3.27.0-GA +jetty-client-9.4.43.v20210629 +jetty-continuation-9.4.43.v20210629 +jetty-http-9.4.43.v20210629 +jetty-io-9.4.43.v20210629 +jetty-security-9.4.43.v20210629 +jetty-server-9.4.43.v20210629 +jetty-servlet-9.4.43.v20210629 +jetty-servlets-9.4.43.v20210629 +jetty-util-9.4.43.v20210629 +jetty-util-ajax-9.4.43.v20210629 +jersey-common-2.34 +jersey-server-2.34 +jose4j-0.7.8 +log4j-1.2.17 +lz4-java-1.8.0 +maven-artifact-3.8.1 +metrics-core-4.1.12.1 +netty-buffer-4.1.68.Final +netty-codec-4.1.68.Final +netty-common-4.1.68.Final +netty-handler-4.1.68.Final +netty-resolver-4.1.68.Final +netty-transport-4.1.68.Final +netty-transport-native-epoll-4.1.68.Final +netty-transport-native-unix-common-4.1.68.Final +plexus-utils-3.2.1 +rocksdbjni-6.22.1.1 +scala-collection-compat_2.13-2.4.4 +scala-library-2.13.6 +scala-logging_2.13-3.9.3 +scala-reflect-2.13.6 +scala-java8-compat_2.13-1.0.0 +snappy-java-1.1.8.4 +zookeeper-3.6.3 +zookeeper-jute-3.6.3 + +=============================================================================== +This product bundles various third-party components under other open source +licenses. This section summarizes those components and their licenses. +See licenses/ for text of these licenses. 
+ +--------------------------------------- +Eclipse Distribution License - v 1.0 +see: licenses/eclipse-distribution-license-1.0 + +jakarta.activation-api-1.2.1 +jakarta.xml.bind-api-2.3.2 + +--------------------------------------- +Eclipse Public License - v 2.0 +see: licenses/eclipse-public-license-2.0 + +jakarta.annotation-api-1.3.5 +jakarta.ws.rs-api-2.1.6 +javax.ws.rs-api-2.1.1 +hk2-api-2.6.1 +hk2-locator-2.6.1 +hk2-utils-2.6.1 +osgi-resource-locator-1.0.3 +aopalliance-repackaged-2.6.1 +jakarta.inject-2.6.1 +jersey-container-servlet-2.34 +jersey-container-servlet-core-2.34 +jersey-client-2.34 +jersey-hk2-2.34 +jersey-media-jaxb-2.31 + +--------------------------------------- +CDDL 1.1 + GPLv2 with classpath exception +see: licenses/CDDL+GPL-1.1 + +javax.servlet-api-3.1.0 +jaxb-api-2.3.0 +activation-1.1.1 + +--------------------------------------- +MIT License + +argparse4j-0.7.0, see: licenses/argparse-MIT +jopt-simple-5.0.4, see: licenses/jopt-simple-MIT +slf4j-api-1.7.30, see: licenses/slf4j-MIT +slf4j-log4j12-1.7.30, see: licenses/slf4j-MIT + +--------------------------------------- +BSD 2-Clause + +zstd-jni-1.5.0-4 see: licenses/zstd-jni-BSD-2-clause + +--------------------------------------- +BSD 3-Clause + +jline-3.12.1, see: licenses/jline-BSD-3-clause +paranamer-2.8, see: licenses/paranamer-BSD-3-clause + +--------------------------------------- +Do What The F*ck You Want To Public License +see: licenses/DWTFYWTPL + +reflections-0.9.12 diff --git a/LICENSE的副本 b/LICENSE的副本 new file mode 100644 index 0000000..d645695 --- /dev/null +++ b/LICENSE的副本 @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. 
For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/NOTICE b/NOTICE new file mode 100644 index 0000000..fc89275 --- /dev/null +++ b/NOTICE @@ -0,0 +1,23 @@ +Apache Kafka +Copyright 2022 The Apache Software Foundation. + +This product includes software developed at +The Apache Software Foundation (https://www.apache.org/). + +This distribution has a binary dependency on jersey, which is available under the CDDL +License. The source code of jersey can be found at https://github.com/jersey/jersey/. + +This distribution has a binary test dependency on jqwik, which is available under +the Eclipse Public License 2.0. The source code can be found at +https://github.com/jlink/jqwik. + +The streams-scala (streams/streams-scala) module was donated by Lightbend and the original code was copyrighted by them: +Copyright (C) 2018 Lightbend Inc. +Copyright (C) 2017-2018 Alexis Seigneurin. 
+ +This project contains the following code copied from Apache Hadoop: +clients/src/main/java/org/apache/kafka/common/utils/PureJavaCrc32C.java +Some portions of this file Copyright (c) 2004-2006 Intel Corporation and licensed under the BSD license. + +This project contains the following code copied from Apache Hive: +streams/src/main/java/org/apache/kafka/streams/state/internals/Murmur3.java \ No newline at end of file diff --git a/NOTICE-binary b/NOTICE-binary new file mode 100644 index 0000000..a50c86d --- /dev/null +++ b/NOTICE-binary @@ -0,0 +1,856 @@ +Apache Kafka +Copyright 2021 The Apache Software Foundation. + +This product includes software developed at +The Apache Software Foundation (https://www.apache.org/). + +This distribution has a binary dependency on jersey, which is available under the CDDL +License. The source code of jersey can be found at https://github.com/jersey/jersey/. + +This distribution has a binary test dependency on jqwik, which is available under +the Eclipse Public License 2.0. The source code can be found at +https://github.com/jlink/jqwik. + +The streams-scala (streams/streams-scala) module was donated by Lightbend and the original code was copyrighted by them: +Copyright (C) 2018 Lightbend Inc. +Copyright (C) 2017-2018 Alexis Seigneurin. + +This project contains the following code copied from Apache Hadoop: +clients/src/main/java/org/apache/kafka/common/utils/PureJavaCrc32C.java +Some portions of this file Copyright (c) 2004-2006 Intel Corporation and licensed under the BSD license. + +This project contains the following code copied from Apache Hive: +streams/src/main/java/org/apache/kafka/streams/state/internals/Murmur3.java + +// ------------------------------------------------------------------ +// NOTICE file corresponding to the section 4d of The Apache License, +// Version 2.0, in this case for +// ------------------------------------------------------------------ + +# Notices for Eclipse GlassFish + +This content is produced and maintained by the Eclipse GlassFish project. + +* Project home: https://projects.eclipse.org/projects/ee4j.glassfish + +## Trademarks + +Eclipse GlassFish, and GlassFish are trademarks of the Eclipse Foundation. + +## Copyright + +All content is the property of the respective authors or their employers. For +more information regarding authorship of content, please consult the listed +source code repository logs. + +## Declared Project Licenses + +This program and the accompanying materials are made available under the terms +of the Eclipse Public License v. 2.0 which is available at +http://www.eclipse.org/legal/epl-2.0. This Source Code may also be made +available under the following Secondary Licenses when the conditions for such +availability set forth in the Eclipse Public License v. 2.0 are satisfied: GNU +General Public License, version 2 with the GNU Classpath Exception which is +available at https://www.gnu.org/software/classpath/license.html. 
+ +SPDX-License-Identifier: EPL-2.0 OR GPL-2.0 WITH Classpath-exception-2.0 + +## Source Code + +The project maintains the following source code repositories: + +* https://github.com/eclipse-ee4j/glassfish-ha-api +* https://github.com/eclipse-ee4j/glassfish-logging-annotation-processor +* https://github.com/eclipse-ee4j/glassfish-shoal +* https://github.com/eclipse-ee4j/glassfish-cdi-porting-tck +* https://github.com/eclipse-ee4j/glassfish-jsftemplating +* https://github.com/eclipse-ee4j/glassfish-hk2-extra +* https://github.com/eclipse-ee4j/glassfish-hk2 +* https://github.com/eclipse-ee4j/glassfish-fighterfish + +## Third-party Content + +This project leverages the following third party content. + +None + +## Cryptography + +Content may contain encryption software. The country in which you are currently +may have restrictions on the import, possession, and use, and/or re-export to +another country, of encryption software. BEFORE using any encryption software, +please check the country's laws, regulations and policies concerning the import, +possession, or use, and re-export of encryption software, to see if this is +permitted. + + +Apache Yetus - Audience Annotations +Copyright 2015-2017 The Apache Software Foundation + +This product includes software developed at +The Apache Software Foundation (http://www.apache.org/). + + +Apache Commons CLI +Copyright 2001-2017 The Apache Software Foundation + +This product includes software developed at +The Apache Software Foundation (http://www.apache.org/). + + +Apache Commons Lang +Copyright 2001-2018 The Apache Software Foundation + +This product includes software developed at +The Apache Software Foundation (http://www.apache.org/). + + +# Jackson JSON processor + +Jackson is a high-performance, Free/Open Source JSON processing library. +It was originally written by Tatu Saloranta (tatu.saloranta@iki.fi), and has +been in development since 2007. +It is currently developed by a community of developers, as well as supported +commercially by FasterXML.com. + +## Licensing + +Jackson core and extension components may licensed under different licenses. +To find the details that apply to this artifact see the accompanying LICENSE file. +For more information, including possible other licensing options, contact +FasterXML.com (http://fasterxml.com). + +## Credits + +A list of contributors may be found from CREDITS file, which is included +in some artifacts (usually source distributions); but is always available +from the source code management (SCM) system project uses. + + +# Notices for Eclipse Project for JAF + +This content is produced and maintained by the Eclipse Project for JAF project. + +* Project home: https://projects.eclipse.org/projects/ee4j.jaf + +## Copyright + +All content is the property of the respective authors or their employers. For +more information regarding authorship of content, please consult the listed +source code repository logs. + +## Declared Project Licenses + +This program and the accompanying materials are made available under the terms +of the Eclipse Distribution License v. 1.0, +which is available at http://www.eclipse.org/org/documents/edl-v10.php. + +SPDX-License-Identifier: BSD-3-Clause + +## Source Code + +The project maintains the following source code repositories: + +* https://github.com/eclipse-ee4j/jaf + +## Third-party Content + +This project leverages the following third party content. 
+ +JUnit (4.12) + +* License: Eclipse Public License + + +# Notices for Jakarta Annotations + +This content is produced and maintained by the Jakarta Annotations project. + + * Project home: https://projects.eclipse.org/projects/ee4j.ca + +## Trademarks + +Jakarta Annotations is a trademark of the Eclipse Foundation. + +## Declared Project Licenses + +This program and the accompanying materials are made available under the terms +of the Eclipse Public License v. 2.0 which is available at +http://www.eclipse.org/legal/epl-2.0. This Source Code may also be made +available under the following Secondary Licenses when the conditions for such +availability set forth in the Eclipse Public License v. 2.0 are satisfied: GNU +General Public License, version 2 with the GNU Classpath Exception which is +available at https://www.gnu.org/software/classpath/license.html. + +SPDX-License-Identifier: EPL-2.0 OR GPL-2.0 WITH Classpath-exception-2.0 + +## Source Code + +The project maintains the following source code repositories: + + * https://github.com/eclipse-ee4j/common-annotations-api + +## Third-party Content + +## Cryptography + +Content may contain encryption software. The country in which you are currently +may have restrictions on the import, possession, and use, and/or re-export to +another country, of encryption software. BEFORE using any encryption software, +please check the country's laws, regulations and policies concerning the import, +possession, or use, and re-export of encryption software, to see if this is +permitted. + + +# Notices for the Jakarta RESTful Web Services Project + +This content is produced and maintained by the **Jakarta RESTful Web Services** +project. + +* Project home: https://projects.eclipse.org/projects/ee4j.jaxrs + +## Trademarks + +**Jakarta RESTful Web Services** is a trademark of the Eclipse Foundation. + +## Copyright + +All content is the property of the respective authors or their employers. For +more information regarding authorship of content, please consult the listed +source code repository logs. + +## Declared Project Licenses + +This program and the accompanying materials are made available under the terms +of the Eclipse Public License v. 2.0 which is available at +http://www.eclipse.org/legal/epl-2.0. This Source Code may also be made +available under the following Secondary Licenses when the conditions for such +availability set forth in the Eclipse Public License v. 2.0 are satisfied: GNU +General Public License, version 2 with the GNU Classpath Exception which is +available at https://www.gnu.org/software/classpath/license.html. + +SPDX-License-Identifier: EPL-2.0 OR GPL-2.0 WITH Classpath-exception-2.0 + +## Source Code + +The project maintains the following source code repositories: + +* https://github.com/eclipse-ee4j/jaxrs-api + +## Third-party Content + +This project leverages the following third party content. + +javaee-api (7.0) + +* License: Apache-2.0 AND W3C + +JUnit (4.11) + +* License: Common Public License 1.0 + +Mockito (2.16.0) + +* Project: http://site.mockito.org +* Source: https://github.com/mockito/mockito/releases/tag/v2.16.0 + +## Cryptography + +Content may contain encryption software. The country in which you are currently +may have restrictions on the import, possession, and use, and/or re-export to +another country, of encryption software. 
BEFORE using any encryption software, +please check the country's laws, regulations and policies concerning the import, +possession, or use, and re-export of encryption software, to see if this is +permitted. + + +# Notices for Eclipse Project for JAXB + +This content is produced and maintained by the Eclipse Project for JAXB project. + +* Project home: https://projects.eclipse.org/projects/ee4j.jaxb + +## Trademarks + +Eclipse Project for JAXB is a trademark of the Eclipse Foundation. + +## Copyright + +All content is the property of the respective authors or their employers. For +more information regarding authorship of content, please consult the listed +source code repository logs. + +## Declared Project Licenses + +This program and the accompanying materials are made available under the terms +of the Eclipse Distribution License v. 1.0 which is available +at http://www.eclipse.org/org/documents/edl-v10.php. + +SPDX-License-Identifier: BSD-3-Clause + +## Source Code + +The project maintains the following source code repositories: + +* https://github.com/eclipse-ee4j/jaxb-api + +## Third-party Content + +This project leverages the following third party content. + +None + +## Cryptography + +Content may contain encryption software. The country in which you are currently +may have restrictions on the import, possession, and use, and/or re-export to +another country, of encryption software. BEFORE using any encryption software, +please check the country's laws, regulations and policies concerning the import, +possession, or use, and re-export of encryption software, to see if this is +permitted. + + +# Notice for Jersey +This content is produced and maintained by the Eclipse Jersey project. + +* Project home: https://projects.eclipse.org/projects/ee4j.jersey + +## Trademarks +Eclipse Jersey is a trademark of the Eclipse Foundation. + +## Copyright + +All content is the property of the respective authors or their employers. For +more information regarding authorship of content, please consult the listed +source code repository logs. + +## Declared Project Licenses + +This program and the accompanying materials are made available under the terms +of the Eclipse Public License v. 2.0 which is available at +http://www.eclipse.org/legal/epl-2.0. This Source Code may also be made +available under the following Secondary Licenses when the conditions for such +availability set forth in the Eclipse Public License v. 2.0 are satisfied: GNU +General Public License, version 2 with the GNU Classpath Exception which is +available at https://www.gnu.org/software/classpath/license.html. + +SPDX-License-Identifier: EPL-2.0 OR GPL-2.0 WITH Classpath-exception-2.0 + +## Source Code +The project maintains the following source code repositories: + +* https://github.com/eclipse-ee4j/jersey + +## Third-party Content + +Angular JS, v1.6.6 +* License MIT (http://www.opensource.org/licenses/mit-license.php) +* Project: http://angularjs.org +* Coyright: (c) 2010-2017 Google, Inc. + +aopalliance Version 1 +* License: all the source code provided by AOP Alliance is Public Domain. +* Project: http://aopalliance.sourceforge.net +* Copyright: Material in the public domain is not protected by copyright + +Bean Validation API 2.0.2 +* License: Apache License, 2.0 +* Project: http://beanvalidation.org/1.1/ +* Copyright: 2009, Red Hat, Inc. and/or its affiliates, and individual contributors +* by the @authors tag. 
+ +Hibernate Validator CDI, 6.1.2.Final +* License: Apache License, 2.0 +* Project: https://beanvalidation.org/ +* Repackaged in org.glassfish.jersey.server.validation.internal.hibernate + +Bootstrap v3.3.7 +* License: MIT license (https://github.com/twbs/bootstrap/blob/master/LICENSE) +* Project: http://getbootstrap.com +* Copyright: 2011-2016 Twitter, Inc + +Google Guava Version 18.0 +* License: Apache License, 2.0 +* Copyright (C) 2009 The Guava Authors + +javax.inject Version: 1 +* License: Apache License, 2.0 +* Copyright (C) 2009 The JSR-330 Expert Group + +Javassist Version 3.25.0-GA +* License: Apache License, 2.0 +* Project: http://www.javassist.org/ +* Copyright (C) 1999- Shigeru Chiba. All Rights Reserved. + +Jackson JAX-RS Providers Version 2.10.1 +* License: Apache License, 2.0 +* Project: https://github.com/FasterXML/jackson-jaxrs-providers +* Copyright: (c) 2009-2011 FasterXML, LLC. All rights reserved unless otherwise indicated. + +jQuery v1.12.4 +* License: jquery.org/license +* Project: jquery.org +* Copyright: (c) jQuery Foundation + +jQuery Barcode plugin 0.3 +* License: MIT & GPL (http://www.opensource.org/licenses/mit-license.php & http://www.gnu.org/licenses/gpl.html) +* Project: http://www.pasella.it/projects/jQuery/barcode +* Copyright: (c) 2009 Antonello Pasella antonello.pasella@gmail.com + +JSR-166 Extension - JEP 266 +* License: CC0 +* No copyright +* Written by Doug Lea with assistance from members of JCP JSR-166 Expert Group and released to the public domain, as explained at http://creativecommons.org/publicdomain/zero/1.0/ + +KineticJS, v4.7.1 +* License: MIT license (http://www.opensource.org/licenses/mit-license.php) +* Project: http://www.kineticjs.com, https://github.com/ericdrowell/KineticJS +* Copyright: Eric Rowell + +org.objectweb.asm Version 8.0 +* License: Modified BSD (http://asm.objectweb.org/license.html) +* Copyright (c) 2000-2011 INRIA, France Telecom. All rights reserved. + +org.osgi.core version 6.0.0 +* License: Apache License, 2.0 +* Copyright (c) OSGi Alliance (2005, 2008). All Rights Reserved. + +org.glassfish.jersey.server.internal.monitoring.core +* License: Apache License, 2.0 +* Copyright (c) 2015-2018 Oracle and/or its affiliates. All rights reserved. +* Copyright 2010-2013 Coda Hale and Yammer, Inc. + +W3.org documents +* License: W3C License +* Copyright: Copyright (c) 1994-2001 World Wide Web Consortium, (Massachusetts Institute of Technology, Institut National de Recherche en Informatique et en Automatique, Keio University). All Rights Reserved. http://www.w3.org/Consortium/Legal/ + + +============================================================== + Jetty Web Container + Copyright 1995-2018 Mort Bay Consulting Pty Ltd. +============================================================== + +The Jetty Web Container is Copyright Mort Bay Consulting Pty Ltd +unless otherwise noted. + +Jetty is dual licensed under both + + * The Apache 2.0 License + http://www.apache.org/licenses/LICENSE-2.0.html + + and + + * The Eclipse Public 1.0 License + http://www.eclipse.org/legal/epl-v10.html + +Jetty may be distributed under either license. + +------ +Eclipse + +The following artifacts are EPL. + * org.eclipse.jetty.orbit:org.eclipse.jdt.core + +The following artifacts are EPL and ASL2. + * org.eclipse.jetty.orbit:javax.security.auth.message + + +The following artifacts are EPL and CDDL 1.0. + * org.eclipse.jetty.orbit:javax.mail.glassfish + + +------ +Oracle + +The following artifacts are CDDL + GPLv2 with classpath exception. 
+https://glassfish.dev.java.net/nonav/public/CDDL+GPL.html + + * javax.servlet:javax.servlet-api + * javax.annotation:javax.annotation-api + * javax.transaction:javax.transaction-api + * javax.websocket:javax.websocket-api + +------ +Oracle OpenJDK + +If ALPN is used to negotiate HTTP/2 connections, then the following +artifacts may be included in the distribution or downloaded when ALPN +module is selected. + + * java.sun.security.ssl + +These artifacts replace/modify OpenJDK classes. The modififications +are hosted at github and both modified and original are under GPL v2 with +classpath exceptions. +http://openjdk.java.net/legal/gplv2+ce.html + + +------ +OW2 + +The following artifacts are licensed by the OW2 Foundation according to the +terms of http://asm.ow2.org/license.html + +org.ow2.asm:asm-commons +org.ow2.asm:asm + + +------ +Apache + +The following artifacts are ASL2 licensed. + +org.apache.taglibs:taglibs-standard-spec +org.apache.taglibs:taglibs-standard-impl + + +------ +MortBay + +The following artifacts are ASL2 licensed. Based on selected classes from +following Apache Tomcat jars, all ASL2 licensed. + +org.mortbay.jasper:apache-jsp + org.apache.tomcat:tomcat-jasper + org.apache.tomcat:tomcat-juli + org.apache.tomcat:tomcat-jsp-api + org.apache.tomcat:tomcat-el-api + org.apache.tomcat:tomcat-jasper-el + org.apache.tomcat:tomcat-api + org.apache.tomcat:tomcat-util-scan + org.apache.tomcat:tomcat-util + +org.mortbay.jasper:apache-el + org.apache.tomcat:tomcat-jasper-el + org.apache.tomcat:tomcat-el-api + + +------ +Mortbay + +The following artifacts are CDDL + GPLv2 with classpath exception. + +https://glassfish.dev.java.net/nonav/public/CDDL+GPL.html + +org.eclipse.jetty.toolchain:jetty-schemas + +------ +Assorted + +The UnixCrypt.java code implements the one way cryptography used by +Unix systems for simple password protection. Copyright 1996 Aki Yoshida, +modified April 2001 by Iris Van den Broeke, Daniel Deville. +Permission to use, copy, modify and distribute UnixCrypt +for non-commercial or commercial purposes and without fee is +granted provided that the copyright notice appears in all copies. + + +Apache log4j +Copyright 2007 The Apache Software Foundation + +This product includes software developed at +The Apache Software Foundation (http://www.apache.org/). + + +Maven Artifact +Copyright 2001-2019 The Apache Software Foundation + +This product includes software developed at +The Apache Software Foundation (http://www.apache.org/). + + +This product includes software developed by the Indiana University + Extreme! Lab (http://www.extreme.indiana.edu/). + +This product includes software developed by +The Apache Software Foundation (http://www.apache.org/). + +This product includes software developed by +ThoughtWorks (http://www.thoughtworks.com). + +This product includes software developed by +javolution (http://javolution.org/). + +This product includes software developed by +Rome (https://rome.dev.java.net/). + + +Scala +Copyright (c) 2002-2020 EPFL +Copyright (c) 2011-2020 Lightbend, Inc. + +Scala includes software developed at +LAMP/EPFL (https://lamp.epfl.ch/) and +Lightbend, Inc. (https://www.lightbend.com/). + +Licensed under the Apache License, Version 2.0 (the "License"). +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. + +This software includes projects with other licenses -- see `doc/LICENSE.md`. + + +Apache ZooKeeper - Server +Copyright 2008-2021 The Apache Software Foundation + +This product includes software developed at +The Apache Software Foundation (http://www.apache.org/). + + +Apache ZooKeeper - Jute +Copyright 2008-2021 The Apache Software Foundation + +This product includes software developed at +The Apache Software Foundation (http://www.apache.org/). + + +The Netty Project + ================= + +Please visit the Netty web site for more information: + + * https://netty.io/ + +Copyright 2014 The Netty Project + +The Netty Project licenses this file to you under the Apache License, +version 2.0 (the "License"); you may not use this file except in compliance +with the License. You may obtain a copy of the License at: + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +License for the specific language governing permissions and limitations +under the License. + +Also, please refer to each LICENSE..txt file, which is located in +the 'license' directory of the distribution file, for the license terms of the +components that this product depends on. + +------------------------------------------------------------------------------- +This product contains the extensions to Java Collections Framework which has +been derived from the works by JSR-166 EG, Doug Lea, and Jason T. Greene: + + * LICENSE: + * license/LICENSE.jsr166y.txt (Public Domain) + * HOMEPAGE: + * http://gee.cs.oswego.edu/cgi-bin/viewcvs.cgi/jsr166/ + * http://viewvc.jboss.org/cgi-bin/viewvc.cgi/jbosscache/experimental/jsr166/ + +This product contains a modified version of Robert Harder's Public Domain +Base64 Encoder and Decoder, which can be obtained at: + + * LICENSE: + * license/LICENSE.base64.txt (Public Domain) + * HOMEPAGE: + * http://iharder.sourceforge.net/current/java/base64/ + +This product contains a modified portion of 'Webbit', an event based +WebSocket and HTTP server, which can be obtained at: + + * LICENSE: + * license/LICENSE.webbit.txt (BSD License) + * HOMEPAGE: + * https://github.com/joewalnes/webbit + +This product contains a modified portion of 'SLF4J', a simple logging +facade for Java, which can be obtained at: + + * LICENSE: + * license/LICENSE.slf4j.txt (MIT License) + * HOMEPAGE: + * https://www.slf4j.org/ + +This product contains a modified portion of 'Apache Harmony', an open source +Java SE, which can be obtained at: + + * NOTICE: + * license/NOTICE.harmony.txt + * LICENSE: + * license/LICENSE.harmony.txt (Apache License 2.0) + * HOMEPAGE: + * https://archive.apache.org/dist/harmony/ + +This product contains a modified portion of 'jbzip2', a Java bzip2 compression +and decompression library written by Matthew J. Francis. It can be obtained at: + + * LICENSE: + * license/LICENSE.jbzip2.txt (MIT License) + * HOMEPAGE: + * https://code.google.com/p/jbzip2/ + +This product contains a modified portion of 'libdivsufsort', a C API library to construct +the suffix array and the Burrows-Wheeler transformed string for any input string of +a constant-size alphabet written by Yuta Mori. 
It can be obtained at: + + * LICENSE: + * license/LICENSE.libdivsufsort.txt (MIT License) + * HOMEPAGE: + * https://github.com/y-256/libdivsufsort + +This product contains a modified portion of Nitsan Wakart's 'JCTools', Java Concurrency Tools for the JVM, + which can be obtained at: + + * LICENSE: + * license/LICENSE.jctools.txt (ASL2 License) + * HOMEPAGE: + * https://github.com/JCTools/JCTools + +This product optionally depends on 'JZlib', a re-implementation of zlib in +pure Java, which can be obtained at: + + * LICENSE: + * license/LICENSE.jzlib.txt (BSD style License) + * HOMEPAGE: + * http://www.jcraft.com/jzlib/ + +This product optionally depends on 'Compress-LZF', a Java library for encoding and +decoding data in LZF format, written by Tatu Saloranta. It can be obtained at: + + * LICENSE: + * license/LICENSE.compress-lzf.txt (Apache License 2.0) + * HOMEPAGE: + * https://github.com/ning/compress + +This product optionally depends on 'lz4', a LZ4 Java compression +and decompression library written by Adrien Grand. It can be obtained at: + + * LICENSE: + * license/LICENSE.lz4.txt (Apache License 2.0) + * HOMEPAGE: + * https://github.com/jpountz/lz4-java + +This product optionally depends on 'lzma-java', a LZMA Java compression +and decompression library, which can be obtained at: + + * LICENSE: + * license/LICENSE.lzma-java.txt (Apache License 2.0) + * HOMEPAGE: + * https://github.com/jponge/lzma-java + +This product contains a modified portion of 'jfastlz', a Java port of FastLZ compression +and decompression library written by William Kinney. It can be obtained at: + + * LICENSE: + * license/LICENSE.jfastlz.txt (MIT License) + * HOMEPAGE: + * https://code.google.com/p/jfastlz/ + +This product contains a modified portion of and optionally depends on 'Protocol Buffers', Google's data +interchange format, which can be obtained at: + + * LICENSE: + * license/LICENSE.protobuf.txt (New BSD License) + * HOMEPAGE: + * https://github.com/google/protobuf + +This product optionally depends on 'Bouncy Castle Crypto APIs' to generate +a temporary self-signed X.509 certificate when the JVM does not provide the +equivalent functionality. 
It can be obtained at: + + * LICENSE: + * license/LICENSE.bouncycastle.txt (MIT License) + * HOMEPAGE: + * https://www.bouncycastle.org/ + +This product optionally depends on 'Snappy', a compression library produced +by Google Inc, which can be obtained at: + + * LICENSE: + * license/LICENSE.snappy.txt (New BSD License) + * HOMEPAGE: + * https://github.com/google/snappy + +This product optionally depends on 'JBoss Marshalling', an alternative Java +serialization API, which can be obtained at: + + * LICENSE: + * license/LICENSE.jboss-marshalling.txt (Apache License 2.0) + * HOMEPAGE: + * https://github.com/jboss-remoting/jboss-marshalling + +This product optionally depends on 'Caliper', Google's micro- +benchmarking framework, which can be obtained at: + + * LICENSE: + * license/LICENSE.caliper.txt (Apache License 2.0) + * HOMEPAGE: + * https://github.com/google/caliper + +This product optionally depends on 'Apache Commons Logging', a logging +framework, which can be obtained at: + + * LICENSE: + * license/LICENSE.commons-logging.txt (Apache License 2.0) + * HOMEPAGE: + * https://commons.apache.org/logging/ + +This product optionally depends on 'Apache Log4J', a logging framework, which +can be obtained at: + + * LICENSE: + * license/LICENSE.log4j.txt (Apache License 2.0) + * HOMEPAGE: + * https://logging.apache.org/log4j/ + +This product optionally depends on 'Aalto XML', an ultra-high performance +non-blocking XML processor, which can be obtained at: + + * LICENSE: + * license/LICENSE.aalto-xml.txt (Apache License 2.0) + * HOMEPAGE: + * http://wiki.fasterxml.com/AaltoHome + +This product contains a modified version of 'HPACK', a Java implementation of +the HTTP/2 HPACK algorithm written by Twitter. It can be obtained at: + + * LICENSE: + * license/LICENSE.hpack.txt (Apache License 2.0) + * HOMEPAGE: + * https://github.com/twitter/hpack + +This product contains a modified version of 'HPACK', a Java implementation of +the HTTP/2 HPACK algorithm written by Cory Benfield. It can be obtained at: + + * LICENSE: + * license/LICENSE.hyper-hpack.txt (MIT License) + * HOMEPAGE: + * https://github.com/python-hyper/hpack/ + +This product contains a modified version of 'HPACK', a Java implementation of +the HTTP/2 HPACK algorithm written by Tatsuhiro Tsujikawa. It can be obtained at: + + * LICENSE: + * license/LICENSE.nghttp2-hpack.txt (MIT License) + * HOMEPAGE: + * https://github.com/nghttp2/nghttp2/ + +This product contains a modified portion of 'Apache Commons Lang', a Java library +provides utilities for the java.lang API, which can be obtained at: + + * LICENSE: + * license/LICENSE.commons-lang.txt (Apache License 2.0) + * HOMEPAGE: + * https://commons.apache.org/proper/commons-lang/ + + +This product contains the Maven wrapper scripts from 'Maven Wrapper', that provides an easy way to ensure a user has everything necessary to run the Maven build. + + * LICENSE: + * license/LICENSE.mvn-wrapper.txt (Apache License 2.0) + * HOMEPAGE: + * https://github.com/takari/maven-wrapper + +This product contains the dnsinfo.h header file, that provides a way to retrieve the system DNS configuration on MacOS. +This private header is also used by Apple's open source + mDNSResponder (https://opensource.apple.com/tarballs/mDNSResponder/). 
+ + * LICENSE: + * license/LICENSE.dnsinfo.txt (Apple Public Source License 2.0) + * HOMEPAGE: + * https://www.opensource.apple.com/source/configd/configd-453.19/dnsinfo/dnsinfo.h \ No newline at end of file diff --git a/PULL_REQUEST_TEMPLATE.md b/PULL_REQUEST_TEMPLATE.md new file mode 100644 index 0000000..552a4d0 --- /dev/null +++ b/PULL_REQUEST_TEMPLATE.md @@ -0,0 +1,14 @@ +*More detailed description of your change, +if necessary. The PR title and PR message become +the squashed commit message, so use a separate +comment to ping reviewers.* + +*Summary of testing strategy (including rationale) +for the feature or bug fix. Unit and/or integration +tests are expected for any behaviour change and +system tests should be considered for larger changes.* + +### Committer Checklist (excluded from commit message) +- [ ] Verify design and implementation +- [ ] Verify test coverage and CI build status +- [ ] Verify documentation (including upgrade notes) diff --git a/README的副本.md b/README的副本.md new file mode 100644 index 0000000..f011c61 --- /dev/null +++ b/README的副本.md @@ -0,0 +1,277 @@ +Apache Kafka +================= +See our [web site](https://kafka.apache.org) for details on the project. + +You need to have [Java](http://www.oracle.com/technetwork/java/javase/downloads/index.html) installed. + +We build and test Apache Kafka with Java 8, 11 and 17. We set the `release` parameter in javac and scalac +to `8` to ensure the generated binaries are compatible with Java 8 or higher (independently of the Java version +used for compilation). Java 8 support has been deprecated since Apache Kafka 3.0 and will be removed in Apache +Kafka 4.0 (see [KIP-750](https://cwiki.apache.org/confluence/pages/viewpage.action?pageId=181308223) for more details). + +Scala 2.12 and 2.13 are supported and 2.13 is used by default. Scala 2.12 support has been deprecated since +Apache Kafka 3.0 and will be removed in Apache Kafka 4.0 (see [KIP-751](https://cwiki.apache.org/confluence/pages/viewpage.action?pageId=181308218) +for more details). See below for how to use a specific Scala version or all of the supported Scala versions. 
+ +### Build a jar and run it ### + ./gradlew jar + +Follow instructions in https://kafka.apache.org/quickstart + +### Build source jar ### + ./gradlew srcJar + +### Build aggregated javadoc ### + ./gradlew aggregatedJavadoc + +### Build javadoc and scaladoc ### + ./gradlew javadoc + ./gradlew javadocJar # builds a javadoc jar for each module + ./gradlew scaladoc + ./gradlew scaladocJar # builds a scaladoc jar for each module + ./gradlew docsJar # builds both (if applicable) javadoc and scaladoc jars for each module + +### Run unit/integration tests ### + ./gradlew test # runs both unit and integration tests + ./gradlew unitTest + ./gradlew integrationTest + +### Force re-running tests without code change ### + ./gradlew cleanTest test + ./gradlew cleanTest unitTest + ./gradlew cleanTest integrationTest + +### Running a particular unit/integration test ### + ./gradlew clients:test --tests RequestResponseTest + +### Running a particular test method within a unit/integration test ### + ./gradlew core:test --tests kafka.api.ProducerFailureHandlingTest.testCannotSendToInternalTopic + ./gradlew clients:test --tests org.apache.kafka.clients.MetadataTest.testMetadataUpdateWaitTime + +### Running a particular unit/integration test with log4j output ### +Change the log4j setting in either `clients/src/test/resources/log4j.properties` or `core/src/test/resources/log4j.properties` + + ./gradlew clients:test --tests RequestResponseTest + +### Specifying test retries ### +By default, each failed test is retried once up to a maximum of five retries per test run. Tests are retried at the end of the test task. Adjust these parameters in the following way: + + ./gradlew test -PmaxTestRetries=1 -PmaxTestRetryFailures=5 + +See [Test Retry Gradle Plugin](https://github.com/gradle/test-retry-gradle-plugin) for more details. + +### Generating test coverage reports ### +Generate coverage reports for the whole project: + + ./gradlew reportCoverage -PenableTestCoverage=true -Dorg.gradle.parallel=false + +Generate coverage for a single module, i.e.: + + ./gradlew clients:reportCoverage -PenableTestCoverage=true -Dorg.gradle.parallel=false + +### Building a binary release gzipped tar ball ### + ./gradlew clean releaseTarGz + +The release file can be found inside `./core/build/distributions/`. + +### Building auto generated messages ### +Sometimes it is only necessary to rebuild the RPC auto-generated message data when switching between branches, as they could +fail due to code changes. You can just run: + + ./gradlew processMessages processTestMessages + +### Running a Kafka broker in ZooKeeper mode + + ./bin/zookeeper-server-start.sh config/zookeeper.properties + ./bin/kafka-server-start.sh config/server.properties + +### Running a Kafka broker in KRaft (Kafka Raft metadata) mode + +See [config/kraft/README.md](https://github.com/apache/kafka/blob/trunk/config/kraft/README.md). 
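+
+For reference, the KRaft quickstart described in that README boils down to roughly the following steps (a sketch only; the linked README is authoritative and the exact flags may differ between releases):
+
+    # Generate a cluster ID, format the storage directories, then start the combined broker/controller
+    KAFKA_CLUSTER_ID="$(./bin/kafka-storage.sh random-uuid)"
+    ./bin/kafka-storage.sh format -t "$KAFKA_CLUSTER_ID" -c config/kraft/server.properties
+    ./bin/kafka-server-start.sh config/kraft/server.properties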
+ +### Cleaning the build ### + ./gradlew clean + +### Running a task with one of the Scala versions available (2.12.x or 2.13.x) ### +*Note that if building the jars with a version other than 2.13.x, you need to set the `SCALA_VERSION` variable or change it in `bin/kafka-run-class.sh` to run the quick start.* + +You can pass either the major version (eg 2.12) or the full version (eg 2.12.7): + + ./gradlew -PscalaVersion=2.12 jar + ./gradlew -PscalaVersion=2.12 test + ./gradlew -PscalaVersion=2.12 releaseTarGz + +### Running a task with all the scala versions enabled by default ### + +Invoke the `gradlewAll` script followed by the task(s): + + ./gradlewAll test + ./gradlewAll jar + ./gradlewAll releaseTarGz + +### Running a task for a specific project ### +This is for `core`, `examples` and `clients` + + ./gradlew core:jar + ./gradlew core:test + +Streams has multiple sub-projects, but you can run all the tests: + + ./gradlew :streams:testAll + +### Listing all gradle tasks ### + ./gradlew tasks + +### Building IDE project #### +*Note that this is not strictly necessary (IntelliJ IDEA has good built-in support for Gradle projects, for example).* + + ./gradlew eclipse + ./gradlew idea + +The `eclipse` task has been configured to use `${project_dir}/build_eclipse` as Eclipse's build directory. Eclipse's default +build directory (`${project_dir}/bin`) clashes with Kafka's scripts directory and we don't use Gradle's build directory +to avoid known issues with this configuration. + +### Publishing the jar for all versions of Scala and for all projects to maven ### +The recommended command is: + + ./gradlewAll publish + +For backwards compatibility, the following also works: + + ./gradlewAll uploadArchives + +Please note for this to work you should create/update `${GRADLE_USER_HOME}/gradle.properties` (typically, `~/.gradle/gradle.properties`) and assign the following variables + + mavenUrl= + mavenUsername= + mavenPassword= + signing.keyId= + signing.password= + signing.secretKeyRingFile= + +### Publishing the streams quickstart archetype artifact to maven ### +For the Streams archetype project, one cannot use gradle to upload to maven; instead the `mvn deploy` command needs to be called at the quickstart folder: + + cd streams/quickstart + mvn deploy + +Please note for this to work you should create/update user maven settings (typically, `${USER_HOME}/.m2/settings.xml`) to assign the following variables + + + ... + + ... + + apache.snapshots.https + ${maven_username} + ${maven_password} + + + apache.releases.https + ${maven_username} + ${maven_password} + + ... + + ... + + +### Installing the jars to the local Maven repository ### +The recommended command is: + + ./gradlewAll publishToMavenLocal + +For backwards compatibility, the following also works: + + ./gradlewAll install + +### Building the test jar ### + ./gradlew testJar + +### Determining how transitive dependencies are added ### + ./gradlew core:dependencies --configuration runtime + +### Determining if any dependencies could be updated ### + ./gradlew dependencyUpdates + +### Running code quality checks ### +There are two code quality analysis tools that we regularly run, spotbugs and checkstyle. + +#### Checkstyle #### +Checkstyle enforces a consistent coding style in Kafka. +You can run checkstyle using: + + ./gradlew checkstyleMain checkstyleTest + +The checkstyle warnings will be found in `reports/checkstyle/reports/main.html` and `reports/checkstyle/reports/test.html` files in the +subproject build directories. 
They are also printed to the console. The build will fail if Checkstyle fails. + +#### Spotbugs #### +Spotbugs uses static analysis to look for bugs in the code. +You can run spotbugs using: + + ./gradlew spotbugsMain spotbugsTest -x test + +The spotbugs warnings will be found in `reports/spotbugs/main.html` and `reports/spotbugs/test.html` files in the subproject build +directories. Use -PxmlSpotBugsReport=true to generate an XML report instead of an HTML one. + +### JMH microbenchmarks ### +We use [JMH](https://openjdk.java.net/projects/code-tools/jmh/) to write microbenchmarks that produce reliable results in the JVM. + +See [jmh-benchmarks/README.md](https://github.com/apache/kafka/blob/trunk/jmh-benchmarks/README.md) for details on how to run the microbenchmarks. + +### Common build options ### + +The following options should be set with a `-P` switch, for example `./gradlew -PmaxParallelForks=1 test`. + +* `commitId`: sets the build commit ID as .git/HEAD might not be correct if there are local commits added for build purposes. +* `mavenUrl`: sets the URL of the maven deployment repository (`file://path/to/repo` can be used to point to a local repository). +* `maxParallelForks`: limits the maximum number of processes for each task. +* `ignoreFailures`: ignore test failures from junit +* `showStandardStreams`: shows standard out and standard error of the test JVM(s) on the console. +* `skipSigning`: skips signing of artifacts. +* `testLoggingEvents`: unit test events to be logged, separated by comma. For example `./gradlew -PtestLoggingEvents=started,passed,skipped,failed test`. +* `xmlSpotBugsReport`: enable XML reports for spotBugs. This also disables HTML reports as only one can be enabled at a time. +* `maxTestRetries`: the maximum number of retries for a failing test case. +* `maxTestRetryFailures`: maximum number of test failures before retrying is disabled for subsequent tests. +* `enableTestCoverage`: enables test coverage plugins and tasks, including bytecode enhancement of classes required to track said +coverage. Note that this introduces some overhead when running tests and hence why it's disabled by default (the overhead +varies, but 15-20% is a reasonable estimate). +* `scalaOptimizerMode`: configures the optimizing behavior of the scala compiler, the value should be one of `none`, `method`, `inline-kafka` or +`inline-scala` (the default is `inline-kafka`). `none` is the scala compiler default, which only eliminates unreachable code. `method` also +includes method-local optimizations. `inline-kafka` adds inlining of methods within the kafka packages. Finally, `inline-scala` also +includes inlining of methods within the scala library (which avoids lambda allocations for methods like `Option.exists`). `inline-scala` is +only safe if the Scala library version is the same at compile time and runtime. Since we cannot guarantee this for all cases (for example, users +may depend on the kafka jar for integration tests where they may include a scala library with a different version), we don't enable it by +default. See https://www.lightbend.com/blog/scala-inliner-optimizer for more details. + +### Dependency Analysis ### + +The gradle [dependency debugging documentation](https://docs.gradle.org/current/userguide/viewing_debugging_dependencies.html) mentions using the `dependencies` or `dependencyInsight` tasks to debug dependencies for the root project or individual subprojects. 
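+
+For example, to see how a particular library ends up on one subproject's runtime classpath, the standard Gradle tasks can be invoked directly (an illustrative invocation; adjust the module and dependency name as needed):
+
+    # Print the runtime dependency tree for the clients module
+    ./gradlew clients:dependencies --configuration runtimeClasspath
+
+    # Explain where a single dependency comes from and how version conflicts were resolved
+    ./gradlew clients:dependencyInsight --configuration runtimeClasspath --dependency jackson-databind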
+ +Alternatively, use the `allDeps` or `allDepInsight` tasks for recursively iterating through all subprojects: + + ./gradlew allDeps + + ./gradlew allDepInsight --configuration runtimeClasspath --dependency com.fasterxml.jackson.core:jackson-databind + +These take the same arguments as the builtin variants. + +### Running system tests ### + +See [tests/README.md](tests/README.md). + +### Running in Vagrant ### + +See [vagrant/README.md](vagrant/README.md). + +### Contribution ### + +Apache Kafka is interested in building the community; we would welcome any thoughts or [patches](https://issues.apache.org/jira/browse/KAFKA). You can reach us [on the Apache mailing lists](http://kafka.apache.org/contact.html). + +To contribute follow the instructions here: + * https://kafka.apache.org/contributing.html diff --git a/TROGDOR.md b/TROGDOR.md new file mode 100644 index 0000000..3891857 --- /dev/null +++ b/TROGDOR.md @@ -0,0 +1,189 @@ +Trogdor +======================================== +Trogdor is a test framework for Apache Kafka. + +Trogdor can run benchmarks and other workloads. Trogdor can also inject faults in order to stress test the system. + +Quickstart +========================================================= +First, we want to start a single-node Kafka cluster with a ZooKeeper and a broker. + +Running ZooKeeper: + + > ./bin/zookeeper-server-start.sh ./config/zookeeper.properties &> /tmp/zookeeper.log & + +Running Kafka: + + > ./bin/kafka-server-start.sh ./config/server.properties &> /tmp/kafka.log & + +Then, we want to run a Trogdor Agent, plus a Trogdor Coordinator. + +To run the Trogdor Agent: + + > ./bin/trogdor.sh agent -c ./config/trogdor.conf -n node0 &> /tmp/trogdor-agent.log & + +To run the Trogdor Coordinator: + + > ./bin/trogdor.sh coordinator -c ./config/trogdor.conf -n node0 &> /tmp/trogdor-coordinator.log & + +Let's confirm that all of the daemons are running: + + > jps + 116212 Coordinator + 115188 QuorumPeerMain + 116571 Jps + 115420 Kafka + 115694 Agent + +Now, we can submit a test job to Trogdor. + + > ./bin/trogdor.sh client createTask -t localhost:8889 -i produce0 --spec ./tests/spec/simple_produce_bench.json + Sent CreateTaskRequest for task produce0. + +We can run showTask to see what the task's status is: + + > ./bin/trogdor.sh client showTask -t localhost:8889 -i produce0 + Task bar of type org.apache.kafka.trogdor.workload.ProduceBenchSpec is DONE. FINISHED at 2019-01-09T20:38:22.039-08:00 after 6s + +To see the results, we use showTask with --show-status: + + > ./bin/trogdor.sh client showTask -t localhost:8889 -i produce0 --show-status + Task bar of type org.apache.kafka.trogdor.workload.ProduceBenchSpec is DONE. FINISHED at 2019-01-09T20:38:22.039-08:00 after 6s + Status: { + "totalSent" : 50000, + "averageLatencyMs" : 17.83388, + "p50LatencyMs" : 12, + "p95LatencyMs" : 75, + "p99LatencyMs" : 96, + "transactionsCommitted" : 0 + } + +Trogdor Architecture +======================================== +Trogdor has a single coordinator process which manages multiple agent processes. Each agent process is responsible for a single cluster node. + +The Trogdor coordinator manages tasks. A task is anything we might want to do on a cluster, such as running a benchmark, injecting a fault, or running a workload. In order to implement each task, the coordinator creates workers on one or more agent nodes. + +The Trogdor agent process implements the tasks. For example, when running a workload, the agent process is the process which produces and consumes messages. 
+ +Both the coordinator and the agent expose a REST interface that accepts objects serialized via JSON. There is also a command-line program which makes it easy to send messages to either one without manually crafting the JSON message body. + +All Trogdor RPCs are idempotent except the shutdown requests. Sending an idempotent RPC twice in a row has the same effect as sending the RPC once. + +Tasks +======================================== +Tasks are described by specifications containing: + +* A "class" field describing the task type. This contains a full Java class name. +* A "startMs" field describing when the task should start. This is given in terms of milliseconds since the UNIX epoch. +* A "durationMs" field describing how long the task should last. This is given in terms of milliseconds. +* Other fields which are task-specific. + +The task specification is usually written as JSON. For example, this task specification describes a network partition between nodes 1 and 2, and 3: + + { + "class": "org.apache.kafka.trogdor.fault.NetworkPartitionFaultSpec", + "startMs": 1000, + "durationMs": 30000, + "partitions": [["node1", "node2"], ["node3"]] + } + +This task runs a simple ProduceBench test on a cluster with one producer node, 5 topics, and 10,000 messages per second. +The keys are generated sequentially and the configured partitioner (DefaultPartitioner) is used. + + { + "class": "org.apache.kafka.trogdor.workload.ProduceBenchSpec", + "durationMs": 10000000, + "producerNode": "node0", + "bootstrapServers": "localhost:9092", + "targetMessagesPerSec": 10000, + "maxMessages": 50000, + "activeTopics": { + "foo[1-3]": { + "numPartitions": 10, + "replicationFactor": 1 + } + }, + "inactiveTopics": { + "foo[4-5]": { + "numPartitions": 10, + "replicationFactor": 1 + } + }, + "keyGenerator": { + "type": "sequential", + "size": 8, + "offset": 1 + }, + "useConfiguredPartitioner": true + } + +Tasks are submitted to the coordinator. Once the coordinator determines that it is time for the task to start, it creates workers on agent processes. The workers run until the task is done. + +Task specifications are immutable; they do not change after the task has been created. + +Tasks can be in several states: +* PENDING, when task is waiting to execute, +* RUNNING, when the task is running, +* STOPPING, when the task is in the process of stopping, +* DONE, when the task is done. + +Tasks that are DONE also have an error field which will be set if the task failed. + +Workloads +======================================== +Trogdor can run several workloads. Workloads perform operations on the cluster and measure their performance. Workloads fail when the operations cannot be performed. + +### ProduceBench +ProduceBench starts a Kafka producer on a single agent node, producing to several partitions. The workload measures the average produce latency, as well as the median, 95th percentile, and 99th percentile latency. +It can be configured to use a transactional producer which can commit transactions based on a set time interval or number of messages. + +### RoundTripWorkload +RoundTripWorkload tests both production and consumption. The workload starts a Kafka producer and consumer on a single node. The consumer will read back the messages that were produced by the producer. + +### ConsumeBench +ConsumeBench starts one or more Kafka consumers on a single agent node. 
Depending on the passed in configuration (see ConsumeBenchSpec), the consumers either subscribe to a set of topics (leveraging consumer group functionality and dynamic partition assignment) or manually assign partitions to themselves. +The workload measures the average produce latency, as well as the median, 95th percentile, and 99th percentile latency. + +Faults +======================================== +Trogdor can run several faults which deliberately break something in the cluster. + +### ProcessStopFault +ProcessStopFault stops a process by sending it a SIGSTOP signal. When the fault ends, the process is resumed with SIGCONT. + +### NetworkPartitionFault +NetworkPartitionFault sets up an artificial network partition between one or more sets of nodes. Currently, this is implemented using iptables. The iptables rules are set up on the outbound traffic from the affected nodes. Therefore, the affected nodes should still be reachable from outside the cluster. + +External Processes +======================================== +Trogdor supports running arbitrary commands in external processes. This is a generic way to run any configurable command in the Trogdor framework - be it a Python program, bash script, docker image, etc. + +### ExternalCommandWorker +ExternalCommandWorker starts an external command defined by the ExternalCommandSpec. It essentially allows you to run any command on any Trogdor agent node. +The worker communicates with the external process via its stdin, stdout and stderr in a JSON protocol. It uses stdout for any actionable communication and only logs what it sees in stderr. +On startup the worker will first send a message describing the workload to the external process in this format: +``` +{"id":, "workload":} +``` +and will then listen for messages from the external process, again in a JSON format. +Said JSON can contain the following fields: +- status: If the object contains this field, the status of the worker will be set to the given value. +- error: If the object contains this field, the error of the worker will be set to the given value. Once an error occurs, the external process will be terminated. +- log: If the object contains this field, a log message will be issued with this text. +An example: +```json +{"log": "Finished successfully.", "status": {"p99ProduceLatency": "100ms", "messagesSent": 10000}} +``` + +Exec Mode +======================================== +Sometimes, you just want to run a test quickly on a single node. In this case, you can use "exec mode." This mode allows you to run a single Trogdor Agent without a Coordinator. + +When using exec mode, you must pass in a Task specification to use. The Agent will try to start this task. + +For example: + + > ./bin/trogdor.sh agent -n node0 -c ./config/trogdor.conf --exec ./tests/spec/simple_produce_bench.json + +When using exec mode, the Agent will exit once the task is complete. diff --git a/Vagrantfile b/Vagrantfile new file mode 100644 index 0000000..ee08487 --- /dev/null +++ b/Vagrantfile @@ -0,0 +1,199 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# -*- mode: ruby -*- +# vi: set ft=ruby : + +require 'socket' + +# Vagrantfile API/syntax version. Don't touch unless you know what you're doing! +VAGRANTFILE_API_VERSION = "2" + +# General config +enable_dns = false +# Override to false when bringing up a cluster on AWS +enable_hostmanager = true +enable_jmx = false +num_zookeepers = 1 +num_brokers = 3 +num_workers = 0 # Generic workers that get the code, but don't start any services +ram_megabytes = 1280 +base_box = "ubuntu/trusty64" + +# EC2 +ec2_access_key = ENV['AWS_ACCESS_KEY'] +ec2_secret_key = ENV['AWS_SECRET_KEY'] +ec2_session_token = ENV['AWS_SESSION_TOKEN'] +ec2_keypair_name = nil +ec2_keypair_file = nil + +ec2_region = "us-east-1" +ec2_az = nil # Uses set by AWS +ec2_ami = "ami-29ebb519" +ec2_instance_type = "m3.medium" +ec2_spot_instance = ENV['SPOT_INSTANCE'] ? ENV['SPOT_INSTANCE'] == 'true' : true +ec2_spot_max_price = "0.113" # On-demand price for instance type +ec2_user = "ubuntu" +ec2_instance_name_prefix = "kafka-vagrant" +ec2_security_groups = nil +ec2_subnet_id = nil +# Only override this by setting it to false if you're running in a VPC and you +# are running Vagrant from within that VPC as well. +ec2_associate_public_ip = nil + +jdk_major = '8' +jdk_full = '8u202-linux-x64' + +local_config_file = File.join(File.dirname(__FILE__), "Vagrantfile.local") +if File.exists?(local_config_file) then + eval(File.read(local_config_file), binding, "Vagrantfile.local") +end + +# TODO(ksweeney): RAM requirements are not empirical and can probably be significantly lowered. +Vagrant.configure(VAGRANTFILE_API_VERSION) do |config| + config.hostmanager.enabled = enable_hostmanager + config.hostmanager.manage_host = enable_dns + config.hostmanager.include_offline = false + + ## Provider-specific global configs + config.vm.provider :virtualbox do |vb,override| + override.vm.box = base_box + + override.hostmanager.ignore_private_ip = false + + # Brokers started with the standard script currently set Xms and Xmx to 1G, + # plus we need some extra head room. + vb.customize ["modifyvm", :id, "--memory", ram_megabytes.to_s] + + if Vagrant.has_plugin?("vagrant-cachier") + override.cache.scope = :box + end + end + + config.vm.provider :aws do |aws,override| + # The "box" is specified as an AMI + override.vm.box = "dummy" + override.vm.box_url = "https://github.com/mitchellh/vagrant-aws/raw/master/dummy.box" + + cached_addresses = {} + # Use a custom resolver that SSH's into the machine and finds the IP address + # directly. This lets us get at the private IP address directly, avoiding + # some issues with using the default IP resolver, which uses the public IP + # address. + override.hostmanager.ip_resolver = proc do |vm, resolving_vm| + if !cached_addresses.has_key?(vm.name) + state_id = vm.state.id + if state_id != :not_created && state_id != :stopped && vm.communicate.ready? 
+ contents = '' + vm.communicate.execute("/sbin/ifconfig eth0 | grep 'inet addr' | tail -n 1 | egrep -o '[0-9\.]+' | head -n 1 2>&1") do |type, data| + contents << data + end + cached_addresses[vm.name] = contents.split("\n").first[/(\d+\.\d+\.\d+\.\d+)/, 1] + else + cached_addresses[vm.name] = nil + end + end + cached_addresses[vm.name] + end + + override.ssh.username = ec2_user + override.ssh.private_key_path = ec2_keypair_file + + aws.access_key_id = ec2_access_key + aws.secret_access_key = ec2_secret_key + aws.session_token = ec2_session_token + aws.keypair_name = ec2_keypair_name + + aws.region = ec2_region + aws.availability_zone = ec2_az + aws.instance_type = ec2_instance_type + aws.ami = ec2_ami + aws.security_groups = ec2_security_groups + aws.subnet_id = ec2_subnet_id + # If a subnet is specified, default to turning on a public IP unless the + # user explicitly specifies the option. Without a public IP, Vagrant won't + # be able to SSH into the hosts unless Vagrant is also running in the VPC. + if ec2_associate_public_ip.nil? + aws.associate_public_ip = true unless ec2_subnet_id.nil? + else + aws.associate_public_ip = ec2_associate_public_ip + end + aws.region_config ec2_region do |region| + region.spot_instance = ec2_spot_instance + region.spot_max_price = ec2_spot_max_price + end + + # Exclude some directories that can grow very large from syncing + override.vm.synced_folder ".", "/vagrant", type: "rsync", rsync__exclude: ['.git', 'core/data/', 'logs/', 'tests/results/', 'results/'] + end + + def name_node(node, name, ec2_instance_name_prefix) + node.vm.hostname = name + node.vm.provider :aws do |aws| + aws.tags = { + 'Name' => ec2_instance_name_prefix + "-" + Socket.gethostname + "-" + name, + 'JenkinsBuildUrl' => ENV['BUILD_URL'] + } + end + end + + def assign_local_ip(node, ip_address) + node.vm.provider :virtualbox do |vb,override| + override.vm.network :private_network, ip: ip_address + end + end + + ## Cluster definition + zookeepers = [] + (1..num_zookeepers).each { |i| + name = "zk" + i.to_s + zookeepers.push(name) + config.vm.define name do |zookeeper| + name_node(zookeeper, name, ec2_instance_name_prefix) + ip_address = "192.168.50." + (10 + i).to_s + assign_local_ip(zookeeper, ip_address) + zookeeper.vm.provision "shell", path: "vagrant/base.sh", env: {"JDK_MAJOR" => jdk_major, "JDK_FULL" => jdk_full} + zk_jmx_port = enable_jmx ? (8000 + i).to_s : "" + zookeeper.vm.provision "shell", path: "vagrant/zk.sh", :args => [i.to_s, num_zookeepers, zk_jmx_port] + end + } + + (1..num_brokers).each { |i| + name = "broker" + i.to_s + config.vm.define name do |broker| + name_node(broker, name, ec2_instance_name_prefix) + ip_address = "192.168.50." + (50 + i).to_s + assign_local_ip(broker, ip_address) + # We need to be careful about what we list as the publicly routable + # address since this is registered in ZK and handed out to clients. If + # host DNS isn't setup, we shouldn't use hostnames -- IP addresses must be + # used to support clients running on the host. + zookeeper_connect = zookeepers.map{ |zk_addr| zk_addr + ":2181"}.join(",") + broker.vm.provision "shell", path: "vagrant/base.sh", env: {"JDK_MAJOR" => jdk_major, "JDK_FULL" => jdk_full} + kafka_jmx_port = enable_jmx ? (9000 + i).to_s : "" + broker.vm.provision "shell", path: "vagrant/broker.sh", :args => [i.to_s, enable_dns ? 
name : ip_address, zookeeper_connect, kafka_jmx_port] + end + } + + (1..num_workers).each { |i| + name = "worker" + i.to_s + config.vm.define name do |worker| + name_node(worker, name, ec2_instance_name_prefix) + ip_address = "192.168.50." + (100 + i).to_s + assign_local_ip(worker, ip_address) + worker.vm.provision "shell", path: "vagrant/base.sh", env: {"JDK_MAJOR" => jdk_major, "JDK_FULL" => jdk_full} + end + } + +end diff --git a/bin/connect-distributed.sh b/bin/connect-distributed.sh new file mode 100755 index 0000000..b8088ad --- /dev/null +++ b/bin/connect-distributed.sh @@ -0,0 +1,45 @@ +#!/bin/bash +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +if [ $# -lt 1 ]; +then + echo "USAGE: $0 [-daemon] connect-distributed.properties" + exit 1 +fi + +base_dir=$(dirname $0) + +if [ "x$KAFKA_LOG4J_OPTS" = "x" ]; then + export KAFKA_LOG4J_OPTS="-Dlog4j.configuration=file:$base_dir/../config/connect-log4j.properties" +fi + +if [ "x$KAFKA_HEAP_OPTS" = "x" ]; then + export KAFKA_HEAP_OPTS="-Xms256M -Xmx2G" +fi + +EXTRA_ARGS=${EXTRA_ARGS-'-name connectDistributed'} + +COMMAND=$1 +case $COMMAND in + -daemon) + EXTRA_ARGS="-daemon "$EXTRA_ARGS + shift + ;; + *) + ;; +esac + +exec $(dirname $0)/kafka-run-class.sh $EXTRA_ARGS org.apache.kafka.connect.cli.ConnectDistributed "$@" diff --git a/bin/connect-mirror-maker.sh b/bin/connect-mirror-maker.sh new file mode 100755 index 0000000..8e2b2e1 --- /dev/null +++ b/bin/connect-mirror-maker.sh @@ -0,0 +1,45 @@ +#!/bin/bash +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +if [ $# -lt 1 ]; +then + echo "USAGE: $0 [-daemon] mm2.properties" + exit 1 +fi + +base_dir=$(dirname $0) + +if [ "x$KAFKA_LOG4J_OPTS" = "x" ]; then + export KAFKA_LOG4J_OPTS="-Dlog4j.configuration=file:$base_dir/../config/connect-log4j.properties" +fi + +if [ "x$KAFKA_HEAP_OPTS" = "x" ]; then + export KAFKA_HEAP_OPTS="-Xms256M -Xmx2G" +fi + +EXTRA_ARGS=${EXTRA_ARGS-'-name mirrorMaker'} + +COMMAND=$1 +case $COMMAND in + -daemon) + EXTRA_ARGS="-daemon "$EXTRA_ARGS + shift + ;; + *) + ;; +esac + +exec $(dirname $0)/kafka-run-class.sh $EXTRA_ARGS org.apache.kafka.connect.mirror.MirrorMaker "$@" diff --git a/bin/connect-standalone.sh b/bin/connect-standalone.sh new file mode 100755 index 0000000..441069f --- /dev/null +++ b/bin/connect-standalone.sh @@ -0,0 +1,45 @@ +#!/bin/bash +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +if [ $# -lt 1 ]; +then + echo "USAGE: $0 [-daemon] connect-standalone.properties" + exit 1 +fi + +base_dir=$(dirname $0) + +if [ "x$KAFKA_LOG4J_OPTS" = "x" ]; then + export KAFKA_LOG4J_OPTS="-Dlog4j.configuration=file:$base_dir/../config/connect-log4j.properties" +fi + +if [ "x$KAFKA_HEAP_OPTS" = "x" ]; then + export KAFKA_HEAP_OPTS="-Xms256M -Xmx2G" +fi + +EXTRA_ARGS=${EXTRA_ARGS-'-name connectStandalone'} + +COMMAND=$1 +case $COMMAND in + -daemon) + EXTRA_ARGS="-daemon "$EXTRA_ARGS + shift + ;; + *) + ;; +esac + +exec $(dirname $0)/kafka-run-class.sh $EXTRA_ARGS org.apache.kafka.connect.cli.ConnectStandalone "$@" diff --git a/bin/kafka-acls.sh b/bin/kafka-acls.sh new file mode 100755 index 0000000..8fa6554 --- /dev/null +++ b/bin/kafka-acls.sh @@ -0,0 +1,17 @@ +#!/bin/bash +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +exec $(dirname $0)/kafka-run-class.sh kafka.admin.AclCommand "$@" diff --git a/bin/kafka-broker-api-versions.sh b/bin/kafka-broker-api-versions.sh new file mode 100755 index 0000000..4f560a0 --- /dev/null +++ b/bin/kafka-broker-api-versions.sh @@ -0,0 +1,17 @@ +#!/bin/bash +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. 
See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +exec $(dirname $0)/kafka-run-class.sh kafka.admin.BrokerApiVersionsCommand "$@" diff --git a/bin/kafka-cluster.sh b/bin/kafka-cluster.sh new file mode 100755 index 0000000..574007e --- /dev/null +++ b/bin/kafka-cluster.sh @@ -0,0 +1,17 @@ +#!/bin/bash +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +exec $(dirname $0)/kafka-run-class.sh kafka.tools.ClusterTool "$@" diff --git a/bin/kafka-configs.sh b/bin/kafka-configs.sh new file mode 100755 index 0000000..2f9eb8c --- /dev/null +++ b/bin/kafka-configs.sh @@ -0,0 +1,17 @@ +#!/bin/bash +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +exec $(dirname $0)/kafka-run-class.sh kafka.admin.ConfigCommand "$@" diff --git a/bin/kafka-console-consumer.sh b/bin/kafka-console-consumer.sh new file mode 100755 index 0000000..dbaac2b --- /dev/null +++ b/bin/kafka-console-consumer.sh @@ -0,0 +1,21 @@ +#!/bin/bash +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +if [ "x$KAFKA_HEAP_OPTS" = "x" ]; then + export KAFKA_HEAP_OPTS="-Xmx512M" +fi + +exec $(dirname $0)/kafka-run-class.sh kafka.tools.ConsoleConsumer "$@" diff --git a/bin/kafka-console-producer.sh b/bin/kafka-console-producer.sh new file mode 100755 index 0000000..e5187b8 --- /dev/null +++ b/bin/kafka-console-producer.sh @@ -0,0 +1,20 @@ +#!/bin/bash +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +if [ "x$KAFKA_HEAP_OPTS" = "x" ]; then + export KAFKA_HEAP_OPTS="-Xmx512M" +fi +exec $(dirname $0)/kafka-run-class.sh kafka.tools.ConsoleProducer "$@" diff --git a/bin/kafka-consumer-groups.sh b/bin/kafka-consumer-groups.sh new file mode 100755 index 0000000..feb063d --- /dev/null +++ b/bin/kafka-consumer-groups.sh @@ -0,0 +1,17 @@ +#!/bin/bash +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +exec $(dirname $0)/kafka-run-class.sh kafka.admin.ConsumerGroupCommand "$@" diff --git a/bin/kafka-consumer-perf-test.sh b/bin/kafka-consumer-perf-test.sh new file mode 100755 index 0000000..77cda72 --- /dev/null +++ b/bin/kafka-consumer-perf-test.sh @@ -0,0 +1,20 @@ +#!/bin/bash +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +if [ "x$KAFKA_HEAP_OPTS" = "x" ]; then + export KAFKA_HEAP_OPTS="-Xmx512M" +fi +exec $(dirname $0)/kafka-run-class.sh kafka.tools.ConsumerPerformance "$@" diff --git a/bin/kafka-delegation-tokens.sh b/bin/kafka-delegation-tokens.sh new file mode 100755 index 0000000..49cb276 --- /dev/null +++ b/bin/kafka-delegation-tokens.sh @@ -0,0 +1,17 @@ +#!/bin/bash +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +exec $(dirname $0)/kafka-run-class.sh kafka.admin.DelegationTokenCommand "$@" diff --git a/bin/kafka-delete-records.sh b/bin/kafka-delete-records.sh new file mode 100755 index 0000000..8726f91 --- /dev/null +++ b/bin/kafka-delete-records.sh @@ -0,0 +1,17 @@ +#!/bin/bash +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +exec $(dirname $0)/kafka-run-class.sh kafka.admin.DeleteRecordsCommand "$@" diff --git a/bin/kafka-dump-log.sh b/bin/kafka-dump-log.sh new file mode 100755 index 0000000..a97ea7d --- /dev/null +++ b/bin/kafka-dump-log.sh @@ -0,0 +1,17 @@ +#!/bin/bash +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +exec $(dirname $0)/kafka-run-class.sh kafka.tools.DumpLogSegments "$@" diff --git a/bin/kafka-features.sh b/bin/kafka-features.sh new file mode 100755 index 0000000..9dd9f16 --- /dev/null +++ b/bin/kafka-features.sh @@ -0,0 +1,17 @@ +#!/bin/bash +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +exec $(dirname $0)/kafka-run-class.sh kafka.admin.FeatureCommand "$@" diff --git a/bin/kafka-get-offsets.sh b/bin/kafka-get-offsets.sh new file mode 100755 index 0000000..993a202 --- /dev/null +++ b/bin/kafka-get-offsets.sh @@ -0,0 +1,17 @@ +#!/bin/bash +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +exec $(dirname $0)/kafka-run-class.sh kafka.tools.GetOffsetShell "$@" diff --git a/bin/kafka-leader-election.sh b/bin/kafka-leader-election.sh new file mode 100755 index 0000000..88baef3 --- /dev/null +++ b/bin/kafka-leader-election.sh @@ -0,0 +1,17 @@ +#!/bin/bash +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +exec $(dirname $0)/kafka-run-class.sh kafka.admin.LeaderElectionCommand "$@" diff --git a/bin/kafka-log-dirs.sh b/bin/kafka-log-dirs.sh new file mode 100755 index 0000000..dc16edc --- /dev/null +++ b/bin/kafka-log-dirs.sh @@ -0,0 +1,17 @@ +#!/bin/bash +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +exec $(dirname $0)/kafka-run-class.sh kafka.admin.LogDirsCommand "$@" diff --git a/bin/kafka-metadata-shell.sh b/bin/kafka-metadata-shell.sh new file mode 100755 index 0000000..289f0c1 --- /dev/null +++ b/bin/kafka-metadata-shell.sh @@ -0,0 +1,17 @@ +#!/bin/bash +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +exec $(dirname $0)/kafka-run-class.sh org.apache.kafka.shell.MetadataShell "$@" diff --git a/bin/kafka-mirror-maker.sh b/bin/kafka-mirror-maker.sh new file mode 100755 index 0000000..981f271 --- /dev/null +++ b/bin/kafka-mirror-maker.sh @@ -0,0 +1,17 @@ +#!/bin/bash +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +exec $(dirname $0)/kafka-run-class.sh kafka.tools.MirrorMaker "$@" diff --git a/bin/kafka-producer-perf-test.sh b/bin/kafka-producer-perf-test.sh new file mode 100755 index 0000000..73a6288 --- /dev/null +++ b/bin/kafka-producer-perf-test.sh @@ -0,0 +1,20 @@ +#!/bin/bash +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. 
See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +if [ "x$KAFKA_HEAP_OPTS" = "x" ]; then + export KAFKA_HEAP_OPTS="-Xmx512M" +fi +exec $(dirname $0)/kafka-run-class.sh org.apache.kafka.tools.ProducerPerformance "$@" diff --git a/bin/kafka-reassign-partitions.sh b/bin/kafka-reassign-partitions.sh new file mode 100755 index 0000000..4c7f1bc --- /dev/null +++ b/bin/kafka-reassign-partitions.sh @@ -0,0 +1,17 @@ +#!/bin/bash +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +exec $(dirname $0)/kafka-run-class.sh kafka.admin.ReassignPartitionsCommand "$@" diff --git a/bin/kafka-replica-verification.sh b/bin/kafka-replica-verification.sh new file mode 100755 index 0000000..4960836 --- /dev/null +++ b/bin/kafka-replica-verification.sh @@ -0,0 +1,17 @@ +#!/bin/bash +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +exec $(dirname $0)/kafka-run-class.sh kafka.tools.ReplicaVerificationTool "$@" diff --git a/bin/kafka-run-class.sh b/bin/kafka-run-class.sh new file mode 100755 index 0000000..6167583 --- /dev/null +++ b/bin/kafka-run-class.sh @@ -0,0 +1,343 @@ +#!/bin/bash +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. 
+# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +if [ $# -lt 1 ]; +then + echo "USAGE: $0 [-daemon] [-name servicename] [-loggc] classname [opts]" + exit 1 +fi + +# CYGWIN == 1 if Cygwin is detected, else 0. +if [[ $(uname -a) =~ "CYGWIN" ]]; then + CYGWIN=1 +else + CYGWIN=0 +fi + +if [ -z "$INCLUDE_TEST_JARS" ]; then + INCLUDE_TEST_JARS=false +fi + +# Exclude jars not necessary for running commands. +regex="(-(test|test-sources|src|scaladoc|javadoc)\.jar|jar.asc)$" +should_include_file() { + if [ "$INCLUDE_TEST_JARS" = true ]; then + return 0 + fi + file=$1 + if [ -z "$(echo "$file" | egrep "$regex")" ] ; then + return 0 + else + return 1 + fi +} + +base_dir=$(dirname $0)/.. + +if [ -z "$SCALA_VERSION" ]; then + SCALA_VERSION=2.13.6 + if [[ -f "$base_dir/gradle.properties" ]]; then + SCALA_VERSION=`grep "^scalaVersion=" "$base_dir/gradle.properties" | cut -d= -f 2` + fi +fi + +if [ -z "$SCALA_BINARY_VERSION" ]; then + SCALA_BINARY_VERSION=$(echo $SCALA_VERSION | cut -f 1-2 -d '.') +fi + +# run ./gradlew copyDependantLibs to get all dependant jars in a local dir +shopt -s nullglob +if [ -z "$UPGRADE_KAFKA_STREAMS_TEST_VERSION" ]; then + for dir in "$base_dir"/core/build/dependant-libs-${SCALA_VERSION}*; + do + CLASSPATH="$CLASSPATH:$dir/*" + done +fi + +for file in "$base_dir"/examples/build/libs/kafka-examples*.jar; +do + if should_include_file "$file"; then + CLASSPATH="$CLASSPATH":"$file" + fi +done + +if [ -z "$UPGRADE_KAFKA_STREAMS_TEST_VERSION" ]; then + clients_lib_dir=$(dirname $0)/../clients/build/libs + streams_lib_dir=$(dirname $0)/../streams/build/libs + streams_dependant_clients_lib_dir=$(dirname $0)/../streams/build/dependant-libs-${SCALA_VERSION} +else + clients_lib_dir=/opt/kafka-$UPGRADE_KAFKA_STREAMS_TEST_VERSION/libs + streams_lib_dir=$clients_lib_dir + streams_dependant_clients_lib_dir=$streams_lib_dir +fi + + +for file in "$clients_lib_dir"/kafka-clients*.jar; +do + if should_include_file "$file"; then + CLASSPATH="$CLASSPATH":"$file" + fi +done + +for file in "$streams_lib_dir"/kafka-streams*.jar; +do + if should_include_file "$file"; then + CLASSPATH="$CLASSPATH":"$file" + fi +done + +if [ -z "$UPGRADE_KAFKA_STREAMS_TEST_VERSION" ]; then + for file in "$base_dir"/streams/examples/build/libs/kafka-streams-examples*.jar; + do + if should_include_file "$file"; then + CLASSPATH="$CLASSPATH":"$file" + fi + done +else + VERSION_NO_DOTS=`echo $UPGRADE_KAFKA_STREAMS_TEST_VERSION | sed 's/\.//g'` + SHORT_VERSION_NO_DOTS=${VERSION_NO_DOTS:0:((${#VERSION_NO_DOTS} - 1))} # remove last char, ie, bug-fix number + for file in "$base_dir"/streams/upgrade-system-tests-$SHORT_VERSION_NO_DOTS/build/libs/kafka-streams-upgrade-system-tests*.jar; + do + if should_include_file "$file"; then + CLASSPATH="$file":"$CLASSPATH" + fi + done + if [ "$SHORT_VERSION_NO_DOTS" = "0100" ]; then + CLASSPATH="/opt/kafka-$UPGRADE_KAFKA_STREAMS_TEST_VERSION/libs/zkclient-0.8.jar":"$CLASSPATH" + 
CLASSPATH="/opt/kafka-$UPGRADE_KAFKA_STREAMS_TEST_VERSION/libs/zookeeper-3.4.6.jar":"$CLASSPATH" + fi + if [ "$SHORT_VERSION_NO_DOTS" = "0101" ]; then + CLASSPATH="/opt/kafka-$UPGRADE_KAFKA_STREAMS_TEST_VERSION/libs/zkclient-0.9.jar":"$CLASSPATH" + CLASSPATH="/opt/kafka-$UPGRADE_KAFKA_STREAMS_TEST_VERSION/libs/zookeeper-3.4.8.jar":"$CLASSPATH" + fi +fi + +for file in "$streams_dependant_clients_lib_dir"/rocksdb*.jar; +do + CLASSPATH="$CLASSPATH":"$file" +done + +for file in "$streams_dependant_clients_lib_dir"/*hamcrest*.jar; +do + CLASSPATH="$CLASSPATH":"$file" +done + +for file in "$base_dir"/shell/build/libs/kafka-shell*.jar; +do + if should_include_file "$file"; then + CLASSPATH="$CLASSPATH":"$file" + fi +done + +for dir in "$base_dir"/shell/build/dependant-libs-${SCALA_VERSION}*; +do + CLASSPATH="$CLASSPATH:$dir/*" +done + +for file in "$base_dir"/tools/build/libs/kafka-tools*.jar; +do + if should_include_file "$file"; then + CLASSPATH="$CLASSPATH":"$file" + fi +done + +for dir in "$base_dir"/tools/build/dependant-libs-${SCALA_VERSION}*; +do + CLASSPATH="$CLASSPATH:$dir/*" +done + +for file in "$base_dir"/trogdor/build/libs/trogdor-*.jar; +do + if should_include_file "$file"; then + CLASSPATH="$CLASSPATH":"$file" + fi +done + +for dir in "$base_dir"/trogdor/build/dependant-libs-${SCALA_VERSION}*; +do + CLASSPATH="$CLASSPATH:$dir/*" +done + +for cc_pkg in "api" "transforms" "runtime" "file" "mirror" "mirror-client" "json" "tools" "basic-auth-extension" +do + for file in "$base_dir"/connect/${cc_pkg}/build/libs/connect-${cc_pkg}*.jar; + do + if should_include_file "$file"; then + CLASSPATH="$CLASSPATH":"$file" + fi + done + if [ -d "$base_dir/connect/${cc_pkg}/build/dependant-libs" ] ; then + CLASSPATH="$CLASSPATH:$base_dir/connect/${cc_pkg}/build/dependant-libs/*" + fi +done + +# classpath addition for release +for file in "$base_dir"/libs/*; +do + if should_include_file "$file"; then + CLASSPATH="$CLASSPATH":"$file" + fi +done + +for file in "$base_dir"/core/build/libs/kafka_${SCALA_BINARY_VERSION}*.jar; +do + if should_include_file "$file"; then + CLASSPATH="$CLASSPATH":"$file" + fi +done +shopt -u nullglob + +if [ -z "$CLASSPATH" ] ; then + echo "Classpath is empty. Please build the project first e.g. by running './gradlew jar -PscalaVersion=$SCALA_VERSION'" + exit 1 +fi + +# JMX settings +if [ -z "$KAFKA_JMX_OPTS" ]; then + KAFKA_JMX_OPTS="-Dcom.sun.management.jmxremote -Dcom.sun.management.jmxremote.authenticate=false -Dcom.sun.management.jmxremote.ssl=false " +fi + +# JMX port to use +if [ $JMX_PORT ]; then + KAFKA_JMX_OPTS="$KAFKA_JMX_OPTS -Dcom.sun.management.jmxremote.port=$JMX_PORT " +fi + +# Log directory to use +if [ "x$LOG_DIR" = "x" ]; then + LOG_DIR="$base_dir/logs" +fi + +# Log4j settings +if [ -z "$KAFKA_LOG4J_OPTS" ]; then + # Log to console. This is a tool. + LOG4J_DIR="$base_dir/config/tools-log4j.properties" + # If Cygwin is detected, LOG4J_DIR is converted to Windows format. + (( CYGWIN )) && LOG4J_DIR=$(cygpath --path --mixed "${LOG4J_DIR}") + KAFKA_LOG4J_OPTS="-Dlog4j.configuration=file:${LOG4J_DIR}" +else + # create logs directory + if [ ! -d "$LOG_DIR" ]; then + mkdir -p "$LOG_DIR" + fi +fi + +# If Cygwin is detected, LOG_DIR is converted to Windows format. 
+(( CYGWIN )) && LOG_DIR=$(cygpath --path --mixed "${LOG_DIR}") +KAFKA_LOG4J_OPTS="-Dkafka.logs.dir=$LOG_DIR $KAFKA_LOG4J_OPTS" + +# Generic jvm settings you want to add +if [ -z "$KAFKA_OPTS" ]; then + KAFKA_OPTS="" +fi + +# Set Debug options if enabled +if [ "x$KAFKA_DEBUG" != "x" ]; then + + # Use default ports + DEFAULT_JAVA_DEBUG_PORT="5005" + + if [ -z "$JAVA_DEBUG_PORT" ]; then + JAVA_DEBUG_PORT="$DEFAULT_JAVA_DEBUG_PORT" + fi + + # Use the defaults if JAVA_DEBUG_OPTS was not set + DEFAULT_JAVA_DEBUG_OPTS="-agentlib:jdwp=transport=dt_socket,server=y,suspend=${DEBUG_SUSPEND_FLAG:-n},address=$JAVA_DEBUG_PORT" + if [ -z "$JAVA_DEBUG_OPTS" ]; then + JAVA_DEBUG_OPTS="$DEFAULT_JAVA_DEBUG_OPTS" + fi + + echo "Enabling Java debug options: $JAVA_DEBUG_OPTS" + KAFKA_OPTS="$JAVA_DEBUG_OPTS $KAFKA_OPTS" +fi + +# Which java to use +if [ -z "$JAVA_HOME" ]; then + JAVA="java" +else + JAVA="$JAVA_HOME/bin/java" +fi + +# Memory options +if [ -z "$KAFKA_HEAP_OPTS" ]; then + KAFKA_HEAP_OPTS="-Xmx256M" +fi + +# JVM performance options +# MaxInlineLevel=15 is the default since JDK 14 and can be removed once older JDKs are no longer supported +if [ -z "$KAFKA_JVM_PERFORMANCE_OPTS" ]; then + KAFKA_JVM_PERFORMANCE_OPTS="-server -XX:+UseG1GC -XX:MaxGCPauseMillis=20 -XX:InitiatingHeapOccupancyPercent=35 -XX:+ExplicitGCInvokesConcurrent -XX:MaxInlineLevel=15 -Djava.awt.headless=true" +fi + +while [ $# -gt 0 ]; do + COMMAND=$1 + case $COMMAND in + -name) + DAEMON_NAME=$2 + CONSOLE_OUTPUT_FILE=$LOG_DIR/$DAEMON_NAME.out + shift 2 + ;; + -loggc) + if [ -z "$KAFKA_GC_LOG_OPTS" ]; then + GC_LOG_ENABLED="true" + fi + shift + ;; + -daemon) + DAEMON_MODE="true" + shift + ;; + *) + break + ;; + esac +done + +# GC options +GC_FILE_SUFFIX='-gc.log' +GC_LOG_FILE_NAME='' +if [ "x$GC_LOG_ENABLED" = "xtrue" ]; then + GC_LOG_FILE_NAME=$DAEMON_NAME$GC_FILE_SUFFIX + + # The first segment of the version number, which is '1' for releases before Java 9 + # it then becomes '9', '10', ... + # Some examples of the first line of `java --version`: + # 8 -> java version "1.8.0_152" + # 9.0.4 -> java version "9.0.4" + # 10 -> java version "10" 2018-03-20 + # 10.0.1 -> java version "10.0.1" 2018-04-17 + # We need to match to the end of the line to prevent sed from printing the characters that do not match + JAVA_MAJOR_VERSION=$("$JAVA" -version 2>&1 | sed -E -n 's/.* version "([0-9]*).*$/\1/p') + if [[ "$JAVA_MAJOR_VERSION" -ge "9" ]] ; then + KAFKA_GC_LOG_OPTS="-Xlog:gc*:file=$LOG_DIR/$GC_LOG_FILE_NAME:time,tags:filecount=10,filesize=100M" + else + KAFKA_GC_LOG_OPTS="-Xloggc:$LOG_DIR/$GC_LOG_FILE_NAME -verbose:gc -XX:+PrintGCDetails -XX:+PrintGCDateStamps -XX:+PrintGCTimeStamps -XX:+UseGCLogFileRotation -XX:NumberOfGCLogFiles=10 -XX:GCLogFileSize=100M" + fi +fi + +# Remove a possible colon prefix from the classpath (happens at lines like `CLASSPATH="$CLASSPATH:$file"` when CLASSPATH is blank) +# Syntax used on the right side is native Bash string manipulation; for more details see +# http://tldp.org/LDP/abs/html/string-manipulation.html, specifically the section titled "Substring Removal" +CLASSPATH=${CLASSPATH#:} + +# If Cygwin is detected, classpath is converted to Windows format. 
+(( CYGWIN )) && CLASSPATH=$(cygpath --path --mixed "${CLASSPATH}") + +# Launch mode +if [ "x$DAEMON_MODE" = "xtrue" ]; then + nohup "$JAVA" $KAFKA_HEAP_OPTS $KAFKA_JVM_PERFORMANCE_OPTS $KAFKA_GC_LOG_OPTS $KAFKA_JMX_OPTS $KAFKA_LOG4J_OPTS -cp "$CLASSPATH" $KAFKA_OPTS "$@" > "$CONSOLE_OUTPUT_FILE" 2>&1 < /dev/null & +else + exec "$JAVA" $KAFKA_HEAP_OPTS $KAFKA_JVM_PERFORMANCE_OPTS $KAFKA_GC_LOG_OPTS $KAFKA_JMX_OPTS $KAFKA_LOG4J_OPTS -cp "$CLASSPATH" $KAFKA_OPTS "$@" +fi diff --git a/bin/kafka-server-start.sh b/bin/kafka-server-start.sh new file mode 100755 index 0000000..5a53126 --- /dev/null +++ b/bin/kafka-server-start.sh @@ -0,0 +1,44 @@ +#!/bin/bash +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +if [ $# -lt 1 ]; +then + echo "USAGE: $0 [-daemon] server.properties [--override property=value]*" + exit 1 +fi +base_dir=$(dirname $0) + +if [ "x$KAFKA_LOG4J_OPTS" = "x" ]; then + export KAFKA_LOG4J_OPTS="-Dlog4j.configuration=file:$base_dir/../config/log4j.properties" +fi + +if [ "x$KAFKA_HEAP_OPTS" = "x" ]; then + export KAFKA_HEAP_OPTS="-Xmx1G -Xms1G" +fi + +EXTRA_ARGS=${EXTRA_ARGS-'-name kafkaServer -loggc'} + +COMMAND=$1 +case $COMMAND in + -daemon) + EXTRA_ARGS="-daemon "$EXTRA_ARGS + shift + ;; + *) + ;; +esac + +exec $base_dir/kafka-run-class.sh $EXTRA_ARGS kafka.Kafka "$@" diff --git a/bin/kafka-server-stop.sh b/bin/kafka-server-stop.sh new file mode 100755 index 0000000..437189f --- /dev/null +++ b/bin/kafka-server-stop.sh @@ -0,0 +1,35 @@ +#!/bin/bash +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
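+# The shutdown signal defaults to TERM (graceful shutdown) and can be overridden through the SIGNAL environment variable, e.g. (illustrative): +#   SIGNAL=KILL ./kafka-server-stop.sh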
+SIGNAL=${SIGNAL:-TERM} + +OSNAME=$(uname -s) +if [[ "$OSNAME" == "OS/390" ]]; then + if [ -z $JOBNAME ]; then + JOBNAME="KAFKSTRT" + fi + PIDS=$(ps -A -o pid,jobname,comm | grep -i $JOBNAME | grep java | grep -v grep | awk '{print $1}') +elif [[ "$OSNAME" == "OS400" ]]; then + PIDS=$(ps -Af | grep -i 'kafka\.Kafka' | grep java | grep -v grep | awk '{print $2}') +else + PIDS=$(ps ax | grep ' kafka\.Kafka ' | grep java | grep -v grep | awk '{print $1}') +fi + +if [ -z "$PIDS" ]; then + echo "No kafka server to stop" + exit 1 +else + kill -s $SIGNAL $PIDS +fi diff --git a/bin/kafka-storage.sh b/bin/kafka-storage.sh new file mode 100755 index 0000000..eef9342 --- /dev/null +++ b/bin/kafka-storage.sh @@ -0,0 +1,17 @@ +#!/bin/bash +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +exec $(dirname $0)/kafka-run-class.sh kafka.tools.StorageTool "$@" diff --git a/bin/kafka-streams-application-reset.sh b/bin/kafka-streams-application-reset.sh new file mode 100755 index 0000000..3363732 --- /dev/null +++ b/bin/kafka-streams-application-reset.sh @@ -0,0 +1,21 @@ +#!/bin/bash +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +if [ "x$KAFKA_HEAP_OPTS" = "x" ]; then + export KAFKA_HEAP_OPTS="-Xmx512M" +fi + +exec $(dirname $0)/kafka-run-class.sh kafka.tools.StreamsResetter "$@" diff --git a/bin/kafka-topics.sh b/bin/kafka-topics.sh new file mode 100755 index 0000000..ad6a2d4 --- /dev/null +++ b/bin/kafka-topics.sh @@ -0,0 +1,17 @@ +#!/bin/bash +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +exec $(dirname $0)/kafka-run-class.sh kafka.admin.TopicCommand "$@" diff --git a/bin/kafka-transactions.sh b/bin/kafka-transactions.sh new file mode 100755 index 0000000..6fb5233 --- /dev/null +++ b/bin/kafka-transactions.sh @@ -0,0 +1,17 @@ +#!/bin/bash +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +exec $(dirname $0)/kafka-run-class.sh org.apache.kafka.tools.TransactionsCommand "$@" diff --git a/bin/kafka-verifiable-consumer.sh b/bin/kafka-verifiable-consumer.sh new file mode 100755 index 0000000..852847d --- /dev/null +++ b/bin/kafka-verifiable-consumer.sh @@ -0,0 +1,20 @@ +#!/bin/bash +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +if [ "x$KAFKA_HEAP_OPTS" = "x" ]; then + export KAFKA_HEAP_OPTS="-Xmx512M" +fi +exec $(dirname $0)/kafka-run-class.sh org.apache.kafka.tools.VerifiableConsumer "$@" diff --git a/bin/kafka-verifiable-producer.sh b/bin/kafka-verifiable-producer.sh new file mode 100755 index 0000000..b59bae7 --- /dev/null +++ b/bin/kafka-verifiable-producer.sh @@ -0,0 +1,20 @@ +#!/bin/bash +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +if [ "x$KAFKA_HEAP_OPTS" = "x" ]; then + export KAFKA_HEAP_OPTS="-Xmx512M" +fi +exec $(dirname $0)/kafka-run-class.sh org.apache.kafka.tools.VerifiableProducer "$@" diff --git a/bin/trogdor.sh b/bin/trogdor.sh new file mode 100755 index 0000000..3324c4e --- /dev/null +++ b/bin/trogdor.sh @@ -0,0 +1,50 @@ +#!/bin/bash +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +usage() { + cat <nul 2>&1 + IF NOT ERRORLEVEL 1 ( + rem 32-bit OS + set KAFKA_HEAP_OPTS=-Xmx512M -Xms512M + ) ELSE ( + rem 64-bit OS + set KAFKA_HEAP_OPTS=-Xmx1G -Xms1G + ) +) +"%~dp0kafka-run-class.bat" kafka.Kafka %* +EndLocal diff --git a/bin/windows/kafka-server-stop.bat b/bin/windows/kafka-server-stop.bat new file mode 100644 index 0000000..676577c --- /dev/null +++ b/bin/windows/kafka-server-stop.bat @@ -0,0 +1,18 @@ +@echo off +rem Licensed to the Apache Software Foundation (ASF) under one or more +rem contributor license agreements. See the NOTICE file distributed with +rem this work for additional information regarding copyright ownership. +rem The ASF licenses this file to You under the Apache License, Version 2.0 +rem (the "License"); you may not use this file except in compliance with +rem the License. You may obtain a copy of the License at +rem +rem http://www.apache.org/licenses/LICENSE-2.0 +rem +rem Unless required by applicable law or agreed to in writing, software +rem distributed under the License is distributed on an "AS IS" BASIS, +rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +rem See the License for the specific language governing permissions and +rem limitations under the License. + +wmic process where (commandline like "%%kafka.Kafka%%" and not name="wmic.exe") delete +rem ps ax | grep -i 'kafka.Kafka' | grep -v grep | awk '{print $1}' | xargs kill -SIGTERM diff --git a/bin/windows/kafka-streams-application-reset.bat b/bin/windows/kafka-streams-application-reset.bat new file mode 100644 index 0000000..1cfb6f5 --- /dev/null +++ b/bin/windows/kafka-streams-application-reset.bat @@ -0,0 +1,23 @@ +@echo off +rem Licensed to the Apache Software Foundation (ASF) under one or more +rem contributor license agreements. See the NOTICE file distributed with +rem this work for additional information regarding copyright ownership. 
+rem The ASF licenses this file to You under the Apache License, Version 2.0 +rem (the "License"); you may not use this file except in compliance with +rem the License. You may obtain a copy of the License at +rem +rem http://www.apache.org/licenses/LICENSE-2.0 +rem +rem Unless required by applicable law or agreed to in writing, software +rem distributed under the License is distributed on an "AS IS" BASIS, +rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +rem See the License for the specific language governing permissions and +rem limitations under the License. + +SetLocal +IF ["%KAFKA_HEAP_OPTS%"] EQU [""] ( + set KAFKA_HEAP_OPTS=-Xmx512M +) + +"%~dp0kafka-run-class.bat" kafka.tools.StreamsResetter %* +EndLocal diff --git a/bin/windows/kafka-topics.bat b/bin/windows/kafka-topics.bat new file mode 100644 index 0000000..677b09d --- /dev/null +++ b/bin/windows/kafka-topics.bat @@ -0,0 +1,17 @@ +@echo off +rem Licensed to the Apache Software Foundation (ASF) under one or more +rem contributor license agreements. See the NOTICE file distributed with +rem this work for additional information regarding copyright ownership. +rem The ASF licenses this file to You under the Apache License, Version 2.0 +rem (the "License"); you may not use this file except in compliance with +rem the License. You may obtain a copy of the License at +rem +rem http://www.apache.org/licenses/LICENSE-2.0 +rem +rem Unless required by applicable law or agreed to in writing, software +rem distributed under the License is distributed on an "AS IS" BASIS, +rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +rem See the License for the specific language governing permissions and +rem limitations under the License. + +"%~dp0kafka-run-class.bat" kafka.admin.TopicCommand %* diff --git a/bin/windows/kafka-transactions.bat b/bin/windows/kafka-transactions.bat new file mode 100644 index 0000000..9bb7585 --- /dev/null +++ b/bin/windows/kafka-transactions.bat @@ -0,0 +1,17 @@ +@echo off +rem Licensed to the Apache Software Foundation (ASF) under one or more +rem contributor license agreements. See the NOTICE file distributed with +rem this work for additional information regarding copyright ownership. +rem The ASF licenses this file to You under the Apache License, Version 2.0 +rem (the "License"); you may not use this file except in compliance with +rem the License. You may obtain a copy of the License at +rem +rem http://www.apache.org/licenses/LICENSE-2.0 +rem +rem Unless required by applicable law or agreed to in writing, software +rem distributed under the License is distributed on an "AS IS" BASIS, +rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +rem See the License for the specific language governing permissions and +rem limitations under the License. + +"%~dp0kafka-run-class.bat" org.apache.kafka.tools.TransactionsCommand %* diff --git a/bin/windows/zookeeper-server-start.bat b/bin/windows/zookeeper-server-start.bat new file mode 100644 index 0000000..f201a58 --- /dev/null +++ b/bin/windows/zookeeper-server-start.bat @@ -0,0 +1,30 @@ +@echo off +rem Licensed to the Apache Software Foundation (ASF) under one or more +rem contributor license agreements. See the NOTICE file distributed with +rem this work for additional information regarding copyright ownership. +rem The ASF licenses this file to You under the Apache License, Version 2.0 +rem (the "License"); you may not use this file except in compliance with +rem the License. 
You may obtain a copy of the License at +rem +rem http://www.apache.org/licenses/LICENSE-2.0 +rem +rem Unless required by applicable law or agreed to in writing, software +rem distributed under the License is distributed on an "AS IS" BASIS, +rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +rem See the License for the specific language governing permissions and +rem limitations under the License. + +IF [%1] EQU [] ( + echo USAGE: %0 zookeeper.properties + EXIT /B 1 +) + +SetLocal +IF ["%KAFKA_LOG4J_OPTS%"] EQU [""] ( + set KAFKA_LOG4J_OPTS=-Dlog4j.configuration=file:%~dp0../../config/log4j.properties +) +IF ["%KAFKA_HEAP_OPTS%"] EQU [""] ( + set KAFKA_HEAP_OPTS=-Xmx512M -Xms512M +) +"%~dp0kafka-run-class.bat" org.apache.zookeeper.server.quorum.QuorumPeerMain %* +EndLocal diff --git a/bin/windows/zookeeper-server-stop.bat b/bin/windows/zookeeper-server-stop.bat new file mode 100644 index 0000000..8b57dd8 --- /dev/null +++ b/bin/windows/zookeeper-server-stop.bat @@ -0,0 +1,17 @@ +@echo off +rem Licensed to the Apache Software Foundation (ASF) under one or more +rem contributor license agreements. See the NOTICE file distributed with +rem this work for additional information regarding copyright ownership. +rem The ASF licenses this file to You under the Apache License, Version 2.0 +rem (the "License"); you may not use this file except in compliance with +rem the License. You may obtain a copy of the License at +rem +rem http://www.apache.org/licenses/LICENSE-2.0 +rem +rem Unless required by applicable law or agreed to in writing, software +rem distributed under the License is distributed on an "AS IS" BASIS, +rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +rem See the License for the specific language governing permissions and +rem limitations under the License. + +wmic process where (commandline like "%%zookeeper%%" and not name="wmic.exe") delete diff --git a/bin/windows/zookeeper-shell.bat b/bin/windows/zookeeper-shell.bat new file mode 100644 index 0000000..f1c86c4 --- /dev/null +++ b/bin/windows/zookeeper-shell.bat @@ -0,0 +1,22 @@ +@echo off +rem Licensed to the Apache Software Foundation (ASF) under one or more +rem contributor license agreements. See the NOTICE file distributed with +rem this work for additional information regarding copyright ownership. +rem The ASF licenses this file to You under the Apache License, Version 2.0 +rem (the "License"); you may not use this file except in compliance with +rem the License. You may obtain a copy of the License at +rem +rem http://www.apache.org/licenses/LICENSE-2.0 +rem +rem Unless required by applicable law or agreed to in writing, software +rem distributed under the License is distributed on an "AS IS" BASIS, +rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +rem See the License for the specific language governing permissions and +rem limitations under the License. + +IF [%1] EQU [] ( + echo USAGE: %0 zookeeper_host:port[/path] [-zk-tls-config-file file] [args...] + EXIT /B 1 +) + +"%~dp0kafka-run-class.bat" org.apache.zookeeper.ZooKeeperMainWithTlsSupportForKafka -server %* diff --git a/bin/zookeeper-security-migration.sh b/bin/zookeeper-security-migration.sh new file mode 100755 index 0000000..722bde7 --- /dev/null +++ b/bin/zookeeper-security-migration.sh @@ -0,0 +1,17 @@ +#!/bin/bash +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. 
See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +exec $(dirname $0)/kafka-run-class.sh kafka.admin.ZkSecurityMigrator "$@" diff --git a/bin/zookeeper-server-start.sh b/bin/zookeeper-server-start.sh new file mode 100755 index 0000000..bd9c114 --- /dev/null +++ b/bin/zookeeper-server-start.sh @@ -0,0 +1,44 @@ +#!/bin/bash +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +if [ $# -lt 1 ]; +then + echo "USAGE: $0 [-daemon] zookeeper.properties" + exit 1 +fi +base_dir=$(dirname $0) + +if [ "x$KAFKA_LOG4J_OPTS" = "x" ]; then + export KAFKA_LOG4J_OPTS="-Dlog4j.configuration=file:$base_dir/../config/log4j.properties" +fi + +if [ "x$KAFKA_HEAP_OPTS" = "x" ]; then + export KAFKA_HEAP_OPTS="-Xmx512M -Xms512M" +fi + +EXTRA_ARGS=${EXTRA_ARGS-'-name zookeeper -loggc'} + +COMMAND=$1 +case $COMMAND in + -daemon) + EXTRA_ARGS="-daemon "$EXTRA_ARGS + shift + ;; + *) + ;; +esac + +exec $base_dir/kafka-run-class.sh $EXTRA_ARGS org.apache.zookeeper.server.quorum.QuorumPeerMain "$@" diff --git a/bin/zookeeper-server-stop.sh b/bin/zookeeper-server-stop.sh new file mode 100755 index 0000000..11665f3 --- /dev/null +++ b/bin/zookeeper-server-stop.sh @@ -0,0 +1,35 @@ +#!/bin/bash +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
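+# Same contract as kafka-server-stop.sh: SIGNAL picks the signal (default TERM), and on OS/390 the process is located by job name, overridable via JOBNAME (default ZKEESTRT). Illustrative example with an assumed job name: +#   SIGNAL=KILL JOBNAME=MYZK ./zookeeper-server-stop.sh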
+SIGNAL=${SIGNAL:-TERM} + +OSNAME=$(uname -s) +if [[ "$OSNAME" == "OS/390" ]]; then + if [ -z $JOBNAME ]; then + JOBNAME="ZKEESTRT" + fi + PIDS=$(ps -A -o pid,jobname,comm | grep -i $JOBNAME | grep java | grep -v grep | awk '{print $1}') +elif [[ "$OSNAME" == "OS400" ]]; then + PIDS=$(ps -Af | grep java | grep -i QuorumPeerMain | grep -v grep | awk '{print $2}') +else + PIDS=$(ps ax | grep java | grep -i QuorumPeerMain | grep -v grep | awk '{print $1}') +fi + +if [ -z "$PIDS" ]; then + echo "No zookeeper server to stop" + exit 1 +else + kill -s $SIGNAL $PIDS +fi diff --git a/bin/zookeeper-shell.sh b/bin/zookeeper-shell.sh new file mode 100755 index 0000000..2f1d0f2 --- /dev/null +++ b/bin/zookeeper-shell.sh @@ -0,0 +1,23 @@ +#!/bin/bash +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +if [ $# -lt 1 ]; +then + echo "USAGE: $0 zookeeper_host:port[/path] [-zk-tls-config-file file] [args...]" + exit 1 +fi + +exec $(dirname $0)/kafka-run-class.sh org.apache.zookeeper.ZooKeeperMainWithTlsSupportForKafka -server "$@" diff --git a/build.gradle b/build.gradle new file mode 100644 index 0000000..86f9d19 --- /dev/null +++ b/build.gradle @@ -0,0 +1,2664 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
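+// Illustrative local invocations exercising properties recognized by this build (values are examples only): +//   ./gradlew jar -PscalaVersion=2.13.6 +//   ./gradlew unitTest -PmaxParallelForks=2 -PmaxTestRetries=1 -PmaxTestRetryFailures=5 +//   ./gradlew integrationTest -PtestLoggingEvents=started,passed,skipped,failed -PignoreFailures=true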
+ +import org.ajoberstar.grgit.Grgit + +import java.nio.charset.StandardCharsets + +buildscript { + repositories { + maven { url 'https://plugins.gradle.org/m2/' } + maven { url 'https://maven.aliyun.com/nexus/content/repositories/google' } + maven { url 'https://maven.aliyun.com/nexus/content/groups/public' } + maven { url 'https://maven.aliyun.com/nexus/content/repositories/jcenter'} +// mavenCentral() + } + apply from: "$rootDir/gradle/dependencies.gradle" + + dependencies { + // For Apache Rat plugin to ignore non-Git files + classpath "org.ajoberstar.grgit:grgit-core:$versions.grgit" + } +} + +plugins { + id 'com.diffplug.spotless' version '5.12.5' + id 'com.github.ben-manes.versions' version '0.38.0' + id 'idea' + id 'java-library' + id 'org.owasp.dependencycheck' version '6.1.6' + id 'org.nosphere.apache.rat' version "0.7.0" + + id "com.github.spotbugs" version '4.7.1' apply false + id 'org.gradle.test-retry' version '1.3.1' apply false + id 'org.scoverage' version '5.0.0' apply false + id 'com.github.johnrengelman.shadow' version '7.0.0' apply false +} + +spotless { + scala { + target 'streams/**/*.scala' + scalafmt("$versions.scalafmt").configFile('checkstyle/.scalafmt.conf') + licenseHeaderFile 'checkstyle/java.header', 'package' + } +} + +allprojects { + + repositories { + maven { url 'https://plugins.gradle.org/m2/' } + maven { url 'https://maven.aliyun.com/nexus/content/repositories/google' } + maven { url 'https://maven.aliyun.com/nexus/content/groups/public' } + maven { url 'https://maven.aliyun.com/nexus/content/repositories/jcenter'} +// mavenCentral() + } + + dependencyUpdates { + revision="release" + resolutionStrategy { + componentSelection { rules -> + rules.all { ComponentSelection selection -> + boolean rejected = ['snap', 'alpha', 'beta', 'rc', 'cr', 'm'].any { qualifier -> + selection.candidate.version ==~ /(?i).*[.-]${qualifier}[.\d-]*/ + } + if (rejected) { + selection.reject('Release candidate') + } + } + } + } + } + + configurations.all { + // zinc is the Scala incremental compiler, it has a configuration for its own dependencies + // that are unrelated to the project dependencies, we should not change them + if (name != "zinc") { + resolutionStrategy { + force( + // be explicit about the javassist dependency version instead of relying on the transitive version + libs.javassist, + // ensure we have a single version in the classpath despite transitive dependencies + libs.scalaLibrary, + libs.scalaReflect, + libs.jacksonAnnotations, + // be explicit about the Netty dependency version instead of relying on the version set by + // ZooKeeper (potentially older and containing CVEs) + libs.nettyHandler, + libs.nettyTransportNativeEpoll + ) + } + } + } +} + +ext { + gradleVersion = versions.gradle + minJavaVersion = "8" + buildVersionFileName = "kafka-version.properties" + + defaultMaxHeapSize = "2g" + defaultJvmArgs = ["-Xss4m", "-XX:+UseParallelGC"] + + // "JEP 403: Strongly Encapsulate JDK Internals" causes some tests to fail when they try + // to access internals (often via mocking libraries). We use `--add-opens` as a workaround + // for now and we'll fix it properly (where possible) via KAFKA-13275. 
+ if (JavaVersion.current().isCompatibleWith(JavaVersion.VERSION_16)) + defaultJvmArgs.addAll( + "--add-opens=java.base/java.io=ALL-UNNAMED", + "--add-opens=java.base/java.nio=ALL-UNNAMED", + "--add-opens=java.base/java.nio.file=ALL-UNNAMED", + "--add-opens=java.base/java.util.concurrent=ALL-UNNAMED", + "--add-opens=java.base/java.util.regex=ALL-UNNAMED", + "--add-opens=java.base/java.util.stream=ALL-UNNAMED", + "--add-opens=java.base/java.text=ALL-UNNAMED", + "--add-opens=java.base/java.time=ALL-UNNAMED", + "--add-opens=java.security.jgss/sun.security.krb5=ALL-UNNAMED" + ) + + userMaxForks = project.hasProperty('maxParallelForks') ? maxParallelForks.toInteger() : null + userIgnoreFailures = project.hasProperty('ignoreFailures') ? ignoreFailures : false + + userMaxTestRetries = project.hasProperty('maxTestRetries') ? maxTestRetries.toInteger() : 0 + userMaxTestRetryFailures = project.hasProperty('maxTestRetryFailures') ? maxTestRetryFailures.toInteger() : 0 + + skipSigning = project.hasProperty('skipSigning') && skipSigning.toBoolean() + shouldSign = !skipSigning && !version.endsWith("SNAPSHOT") + + mavenUrl = project.hasProperty('mavenUrl') ? project.mavenUrl : '' + mavenUsername = project.hasProperty('mavenUsername') ? project.mavenUsername : '' + mavenPassword = project.hasProperty('mavenPassword') ? project.mavenPassword : '' + + userShowStandardStreams = project.hasProperty("showStandardStreams") ? showStandardStreams : null + + userTestLoggingEvents = project.hasProperty("testLoggingEvents") ? Arrays.asList(testLoggingEvents.split(",")) : null + + userEnableTestCoverage = project.hasProperty("enableTestCoverage") ? enableTestCoverage : false + + // See README.md for details on this option and the reasoning for the default + userScalaOptimizerMode = project.hasProperty("scalaOptimizerMode") ? scalaOptimizerMode : "inline-kafka" + def scalaOptimizerValues = ["none", "method", "inline-kafka", "inline-scala"] + if (!scalaOptimizerValues.contains(userScalaOptimizerMode)) + throw new GradleException("Unexpected value for scalaOptimizerMode property. Expected one of $scalaOptimizerValues), but received: $userScalaOptimizerMode") + + generatedDocsDir = new File("${project.rootDir}/docs/generated") + + commitId = project.hasProperty('commitId') ? commitId : null +} + +apply from: file('wrapper.gradle') + +if (file('.git').exists()) { + rat { + verbose = true + reportDir.set(project.file('build/rat')) + stylesheet.set(file('gradle/resources/rat-output-to-html.xsl')) + + // Exclude everything under the directory that git should be ignoring via .gitignore or that isn't checked in. These + // restrict us only to files that are checked in or are staged. 
+ def repo = Grgit.open(currentDir: project.getRootDir()) + excludes = new ArrayList(repo.clean(ignore: false, directories: true, dryRun: true)) + // And some of the files that we have checked in should also be excluded from this check + excludes.addAll([ + '**/.git/**', + '**/build/**', + 'CONTRIBUTING.md', + 'PULL_REQUEST_TEMPLATE.md', + 'gradlew', + 'gradlew.bat', + 'gradle/wrapper/gradle-wrapper.properties', + 'config/kraft/README.md', + 'TROGDOR.md', + '**/README.md', + '**/id_rsa', + '**/id_rsa.pub', + 'checkstyle/suppressions.xml', + 'streams/quickstart/java/src/test/resources/projects/basic/goal.txt', + 'streams/streams-scala/logs/*', + 'licenses/*', + '**/generated/**', + 'clients/src/test/resources/serializedData/*' + ]) + } +} else { + rat.enabled = false +} +println("Starting build with version $version using Gradle $gradleVersion, Java ${JavaVersion.current()} and Scala ${versions.scala}") + +subprojects { + + // enable running :dependencies task recursively on all subprojects + // eg: ./gradlew allDeps + task allDeps(type: DependencyReportTask) {} + // enable running :dependencyInsight task recursively on all subprojects + // eg: ./gradlew allDepInsight --configuration runtime --dependency com.fasterxml.jackson.core:jackson-databind + task allDepInsight(type: DependencyInsightReportTask) doLast {} + + apply plugin: 'java-library' + apply plugin: 'checkstyle' + apply plugin: "com.github.spotbugs" + apply plugin: 'org.gradle.test-retry' + + // We use the shadow plugin for the jmh-benchmarks module and the `-all` jar can get pretty large, so + // don't publish it + def shouldPublish = !project.name.equals('jmh-benchmarks') + + if (shouldPublish) { + apply plugin: 'maven-publish' + apply plugin: 'signing' + + // Add aliases for the task names used by the maven plugin for backwards compatibility + // The maven plugin was replaced by the maven-publish plugin in Gradle 7.0 + tasks.register('install').configure { dependsOn(publishToMavenLocal) } + tasks.register('uploadArchives').configure { dependsOn(publish) } + } + + // apply the eclipse plugin only to subprojects that hold code. 'connect' is just a folder. + if (!project.name.equals('connect')) { + apply plugin: 'eclipse' + fineTuneEclipseClasspathFile(eclipse, project) + } + + sourceCompatibility = minJavaVersion + targetCompatibility = minJavaVersion + + java { + consistentResolution { + // resolve the compileClasspath and then "inject" the result of resolution as strict constraints into the runtimeClasspath + useCompileClasspathVersions() + } + } + + tasks.withType(JavaCompile) { + options.encoding = 'UTF-8' + options.compilerArgs << "-Xlint:all" + // temporary exclusions until all the warnings are fixed + if (!project.path.startsWith(":connect")) + options.compilerArgs << "-Xlint:-rawtypes" + options.compilerArgs << "-Xlint:-serial" + options.compilerArgs << "-Xlint:-try" + options.compilerArgs << "-Werror" + // --release is the recommended way to select the target release, but it's only supported in Java 9 so we also + // set --source and --target via `sourceCompatibility` and `targetCompatibility`. If/when Gradle supports `--release` + // natively (https://github.com/gradle/gradle/issues/2510), we should switch to that. 
+ if (JavaVersion.current().isJava9Compatible()) + options.compilerArgs << "--release" << minJavaVersion + } + + if (shouldPublish) { + + publishing { + repositories { + // To test locally, invoke gradlew with `-PmavenUrl=file:///some/local/path` + maven { + url = mavenUrl + credentials { + username = mavenUsername + password = mavenPassword + } + } + } + publications { + mavenJava(MavenPublication) { + from components.java + + afterEvaluate { + ["srcJar", "javadocJar", "scaladocJar", "testJar", "testSrcJar"].forEach { taskName -> + def task = tasks.findByName(taskName) + if (task != null) + artifact task + } + + artifactId = archivesBaseName + pom { + name = 'Apache Kafka' + url = 'https://kafka.apache.org' + licenses { + license { + name = 'The Apache License, Version 2.0' + url = 'http://www.apache.org/licenses/LICENSE-2.0.txt' + distribution = 'repo' + } + } + } + } + } + } + } + + if (shouldSign) { + signing { + sign publishing.publications.mavenJava + } + } + } + + // Remove the relevant project name once it's converted to JUnit 5 + def shouldUseJUnit5 = !(["runtime", "streams"].contains(it.project.name)) + + def testLoggingEvents = ["passed", "skipped", "failed"] + def testShowStandardStreams = false + def testExceptionFormat = 'full' + // Gradle built-in logging only supports sending test output to stdout, which generates a lot + // of noise, especially for passing tests. We really only want output for failed tests. This + // hooks into the output and logs it (so we don't have to buffer it all in memory) and only + // saves the output for failing tests. Directory and filenames are such that you can, e.g., + // create a Jenkins rule to collect failed test output. + def logTestStdout = { + def testId = { TestDescriptor descriptor -> + "${descriptor.className}.${descriptor.name}".toString() + } + + def logFiles = new HashMap() + def logStreams = new HashMap() + beforeTest { TestDescriptor td -> + def tid = testId(td) + // truncate the file name if it's too long + def logFile = new File( + "${projectDir}/build/reports/testOutput/${tid.substring(0, Math.min(tid.size(),240))}.test.stdout" + ) + logFile.parentFile.mkdirs() + logFiles.put(tid, logFile) + logStreams.put(tid, new FileOutputStream(logFile)) + } + onOutput { TestDescriptor td, TestOutputEvent toe -> + def tid = testId(td) + // Some output can happen outside the context of a specific test (e.g. at the class level) + // and beforeTest/afterTest seems to not be invoked for these cases (and similarly, there's + // a TestDescriptor hierarchy that includes the thread executing the test, Gradle tasks, + // etc). We see some of these in practice and it seems like something buggy in the Gradle + // test runner since we see it *before* any tests and it is frequently not related to any + // code in the test (best guess is that it is tail output from last test). We won't have + // an output file for these, so simply ignore them. If they become critical for debugging, + // they can be seen with showStandardStreams. + if (td.name == td.className || td.className == null) { + // silently ignore output unrelated to specific test methods + return + } else if (logStreams.get(tid) == null) { + println "WARNING: unexpectedly got output for a test [${tid}]" + + " that we didn't previously see in the beforeTest hook." + + " Message for debugging: [" + toe.message + "]." 
+ return + } + try { + logStreams.get(tid).write(toe.message.getBytes(StandardCharsets.UTF_8)) + } catch (Exception e) { + println "ERROR: Failed to write output for test ${tid}" + e.printStackTrace() + } + } + afterTest { TestDescriptor td, TestResult tr -> + def tid = testId(td) + try { + logStreams.get(tid).close() + if (tr.resultType != TestResult.ResultType.FAILURE) { + logFiles.get(tid).delete() + } else { + def file = logFiles.get(tid) + println "${tid} failed, log available in ${file}" + } + } catch (Exception e) { + println "ERROR: Failed to close stdout file for ${tid}" + e.printStackTrace() + } finally { + logFiles.remove(tid) + logStreams.remove(tid) + } + } + } + + // The suites are for running sets of tests in IDEs. + // Gradle will run each test class, so we exclude the suites to avoid redundantly running the tests twice. + def testsToExclude = ['**/*Suite.class'] + // Exclude PowerMock tests when running with Java 16 or newer until a version of PowerMock that supports the relevant versions is released + // The relevant issues are https://github.com/powermock/powermock/issues/1094 and https://github.com/powermock/powermock/issues/1099 + if (JavaVersion.current().isCompatibleWith(JavaVersion.VERSION_16)) { + testsToExclude.addAll([ + // connect tests + "**/AbstractHerderTest.*", "**/ConnectClusterStateImplTest.*", "**/ConnectorPluginsResourceTest.*", + "**/ConnectorsResourceTest.*", "**/DistributedHerderTest.*", "**/FileOffsetBakingStoreTest.*", + "**/ErrorHandlingTaskTest.*", "**/KafkaConfigBackingStoreTest.*", "**/KafkaOffsetBackingStoreTest.*", + "**/KafkaBasedLogTest.*", "**/OffsetStorageWriterTest.*", "**/StandaloneHerderTest.*", + "**/SourceTaskOffsetCommitterTest.*", "**/WorkerConfigTransformerTest.*", "**/WorkerGroupMemberTest.*", + "**/WorkerSinkTaskTest.*", "**/WorkerSinkTaskThreadedTest.*", "**/WorkerSourceTaskTest.*", + "**/WorkerTaskTest.*", "**/WorkerTest.*", + // streams tests + "**/KafkaStreamsTest.*" + ]) + } + + test { + maxParallelForks = userMaxForks ?: Runtime.runtime.availableProcessors() + ignoreFailures = userIgnoreFailures + + maxHeapSize = defaultMaxHeapSize + jvmArgs = defaultJvmArgs + + testLogging { + events = userTestLoggingEvents ?: testLoggingEvents + showStandardStreams = userShowStandardStreams ?: testShowStandardStreams + exceptionFormat = testExceptionFormat + } + logTestStdout.rehydrate(delegate, owner, this)() + + exclude testsToExclude + + if (shouldUseJUnit5) + useJUnitPlatform() + + retry { + maxRetries = userMaxTestRetries + maxFailures = userMaxTestRetryFailures + } + } + + task integrationTest(type: Test, dependsOn: compileJava) { + maxParallelForks = userMaxForks ?: Runtime.runtime.availableProcessors() + ignoreFailures = userIgnoreFailures + + maxHeapSize = defaultMaxHeapSize + jvmArgs = defaultJvmArgs + + + testLogging { + events = userTestLoggingEvents ?: testLoggingEvents + showStandardStreams = userShowStandardStreams ?: testShowStandardStreams + exceptionFormat = testExceptionFormat + } + logTestStdout.rehydrate(delegate, owner, this)() + + exclude testsToExclude + + if (shouldUseJUnit5) { + useJUnitPlatform { + includeTags "integration" + } + } else { + useJUnit { + includeCategories 'org.apache.kafka.test.IntegrationTest' + } + } + + retry { + maxRetries = userMaxTestRetries + maxFailures = userMaxTestRetryFailures + } + } + + task unitTest(type: Test, dependsOn: compileJava) { + maxParallelForks = userMaxForks ?: Runtime.runtime.availableProcessors() + ignoreFailures = userIgnoreFailures + + maxHeapSize = 
defaultMaxHeapSize + jvmArgs = defaultJvmArgs + + testLogging { + events = userTestLoggingEvents ?: testLoggingEvents + showStandardStreams = userShowStandardStreams ?: testShowStandardStreams + exceptionFormat = testExceptionFormat + } + logTestStdout.rehydrate(delegate, owner, this)() + + exclude testsToExclude + + if (shouldUseJUnit5) { + useJUnitPlatform { + excludeTags "integration" + } + } else { + useJUnit { + excludeCategories 'org.apache.kafka.test.IntegrationTest' + } + } + + retry { + maxRetries = userMaxTestRetries + maxFailures = userMaxTestRetryFailures + } + } + + // remove test output from all test types + tasks.withType(Test).all { t -> + cleanTest { + delete t.reports.junitXml.destination + delete t.reports.html.destination + } + } + + jar { + from "$rootDir/LICENSE" + from "$rootDir/NOTICE" + } + + task srcJar(type: Jar) { + archiveClassifier = 'sources' + from "$rootDir/LICENSE" + from "$rootDir/NOTICE" + from sourceSets.main.allSource + } + + task javadocJar(type: Jar, dependsOn: javadoc) { + archiveClassifier = 'javadoc' + from "$rootDir/LICENSE" + from "$rootDir/NOTICE" + from javadoc.destinationDir + } + + task docsJar(dependsOn: javadocJar) + + javadoc { + options.charSet = 'UTF-8' + options.docEncoding = 'UTF-8' + options.encoding = 'UTF-8' + // Turn off doclint for now, see https://blog.joda.org/2014/02/turning-off-doclint-in-jdk-8-javadoc.html for rationale + options.addStringOption('Xdoclint:none', '-quiet') + + // The URL structure was changed to include the locale after Java 8 + if (JavaVersion.current().isJava11Compatible()) + options.links "https://docs.oracle.com/en/java/javase/${JavaVersion.current().majorVersion}/docs/api/" + else + options.links "https://docs.oracle.com/javase/8/docs/api/" + } + + task systemTestLibs(dependsOn: jar) + + if (!sourceSets.test.allSource.isEmpty()) { + task testJar(type: Jar) { + archiveClassifier = 'test' + from "$rootDir/LICENSE" + from "$rootDir/NOTICE" + from sourceSets.test.output + } + + task testSrcJar(type: Jar, dependsOn: testJar) { + archiveClassifier = 'test-sources' + from "$rootDir/LICENSE" + from "$rootDir/NOTICE" + from sourceSets.test.allSource + } + + } + + plugins.withType(ScalaPlugin) { + + scala { + zincVersion = versions.zinc + } + + task scaladocJar(type:Jar, dependsOn: scaladoc) { + archiveClassifier = 'scaladoc' + from "$rootDir/LICENSE" + from "$rootDir/NOTICE" + from scaladoc.destinationDir + } + + //documentation task should also trigger building scala doc jar + docsJar.dependsOn scaladocJar + + } + + tasks.withType(ScalaCompile) { + scalaCompileOptions.additionalParameters = [ + "-deprecation", + "-unchecked", + "-encoding", "utf8", + "-Xlog-reflective-calls", + "-feature", + "-language:postfixOps", + "-language:implicitConversions", + "-language:existentials", + "-Xlint:constant", + "-Xlint:delayedinit-select", + "-Xlint:doc-detached", + "-Xlint:missing-interpolator", + "-Xlint:nullary-unit", + "-Xlint:option-implicit", + "-Xlint:package-object-classes", + "-Xlint:poly-implicit-overload", + "-Xlint:private-shadow", + "-Xlint:stars-align", + "-Xlint:type-parameter-shadow", + "-Xlint:unused" + ] + + // See README.md for details on this option and the meaning of each value + if (userScalaOptimizerMode.equals("method")) + scalaCompileOptions.additionalParameters += ["-opt:l:method"] + else if (userScalaOptimizerMode.startsWith("inline-")) { + List inlineFrom = ["-opt-inline-from:org.apache.kafka.**"] + if (project.name.equals('core')) + inlineFrom.add("-opt-inline-from:kafka.**") + if 
(userScalaOptimizerMode.equals("inline-scala")) + inlineFrom.add("-opt-inline-from:scala.**") + + scalaCompileOptions.additionalParameters += ["-opt:l:inline"] + scalaCompileOptions.additionalParameters += inlineFrom + } + + if (versions.baseScala != '2.12') { + scalaCompileOptions.additionalParameters += ["-opt-warnings", "-Xlint:strict-unsealed-patmat"] + // Scala 2.13.2 introduces compiler warnings suppression, which is a pre-requisite for -Xfatal-warnings + scalaCompileOptions.additionalParameters += ["-Xfatal-warnings"] + } + + // these options are valid for Scala versions < 2.13 only + // Scala 2.13 removes them, see https://github.com/scala/scala/pull/6502 and https://github.com/scala/scala/pull/5969 + if (versions.baseScala == '2.12') { + scalaCompileOptions.additionalParameters += [ + "-Xlint:by-name-right-associative", + "-Xlint:nullary-override", + "-Xlint:unsound-match" + ] + } + + // Scalac's `-release` requires Java 9 or higher + if (JavaVersion.current().isJava9Compatible()) + scalaCompileOptions.additionalParameters += ["-release", minJavaVersion] + + configure(scalaCompileOptions.forkOptions) { + memoryMaximumSize = defaultMaxHeapSize + jvmArgs = defaultJvmArgs + } + } + + checkstyle { + configFile = new File(rootDir, "checkstyle/checkstyle.xml") + configProperties = checkstyleConfigProperties("import-control.xml") + toolVersion = versions.checkstyle + } + + configure(checkstyleMain) { + group = 'Verification' + description = 'Run checkstyle on all main Java sources' + } + + configure(checkstyleTest) { + group = 'Verification' + description = 'Run checkstyle on all test Java sources' + } + + test.dependsOn('checkstyleMain', 'checkstyleTest') + + spotbugs { + toolVersion = versions.spotbugs + excludeFilter = file("$rootDir/gradle/spotbugs-exclude.xml") + ignoreFailures = false + } + test.dependsOn('spotbugsMain') + + tasks.withType(com.github.spotbugs.snom.SpotBugsTask) { + reports { + // Continue supporting `xmlFindBugsReport` for compatibility + xml.enabled(project.hasProperty('xmlSpotBugsReport') || project.hasProperty('xmlFindBugsReport')) + html.enabled(!project.hasProperty('xmlSpotBugsReport') && !project.hasProperty('xmlFindBugsReport')) + } + maxHeapSize = defaultMaxHeapSize + jvmArgs = defaultJvmArgs + } + + // Ignore core since its a scala project + if (it.path != ':core') { + if (userEnableTestCoverage) { + apply plugin: "jacoco" + + jacoco { + toolVersion = versions.jacoco + } + + // NOTE: Jacoco Gradle plugin does not support "offline instrumentation" this means that classes mocked by PowerMock + // may report 0 coverage, since the source was modified after initial instrumentation. + // See https://github.com/jacoco/jacoco/issues/51 + jacocoTestReport { + dependsOn tasks.test + sourceSets sourceSets.main + reports { + html.enabled = true + xml.enabled = true + csv.enabled = false + } + } + + } + } + + if (userEnableTestCoverage) { + def coverageGen = it.path == ':core' ? 
'reportScoverage' : 'jacocoTestReport' + task reportCoverage(dependsOn: [coverageGen]) + } + + task determineCommitId { + def takeFromHash = 16 + if (commitId) { + commitId = commitId.take(takeFromHash) + } else if (file("$rootDir/.git/HEAD").exists()) { + def headRef = file("$rootDir/.git/HEAD").text + if (headRef.contains('ref: ')) { + headRef = headRef.replaceAll('ref: ', '').trim() + if (file("$rootDir/.git/$headRef").exists()) { + commitId = file("$rootDir/.git/$headRef").text.trim().take(takeFromHash) + } + } else { + commitId = headRef.trim().take(takeFromHash) + } + } else { + commitId = "unknown" + } + } + +} + +gradle.taskGraph.whenReady { taskGraph -> + taskGraph.getAllTasks().findAll { it.name.contains('spotbugsScoverage') || it.name.contains('spotbugsTest') }.each { task -> + task.enabled = false + } +} + +def fineTuneEclipseClasspathFile(eclipse, project) { + eclipse.classpath.file { + beforeMerged { cp -> + cp.entries.clear() + // for the core project add the directories defined under test/scala as separate source directories + if (project.name.equals('core')) { + cp.entries.add(new org.gradle.plugins.ide.eclipse.model.SourceFolder("src/test/scala/integration", null)) + cp.entries.add(new org.gradle.plugins.ide.eclipse.model.SourceFolder("src/test/scala/other", null)) + cp.entries.add(new org.gradle.plugins.ide.eclipse.model.SourceFolder("src/test/scala/unit", null)) + } + } + whenMerged { cp -> + // for the core project exclude the separate sub-directories defined under test/scala. These are added as source dirs above + if (project.name.equals('core')) { + cp.entries.findAll { it.kind == "src" && it.path.equals("src/test/scala") }*.excludes = ["integration/", "other/", "unit/"] + } + /* + * Set all eclipse build output to go to 'build_eclipse' directory. This is to ensure that gradle and eclipse use different + * build output directories, and also avoid using the eclipse default of 'bin' which clashes with some of our script directories. + * https://discuss.gradle.org/t/eclipse-generated-files-should-be-put-in-the-same-place-as-the-gradle-generated-files/6986/2 + */ + cp.entries.findAll { it.kind == "output" }*.path = "build_eclipse" + /* + * Some projects have explicitly added test output dependencies. These are required for the gradle build but not required + * in Eclipse since the dependent projects are added as dependencies. So clean up these from the generated classpath. 
+ */ + cp.entries.removeAll { it.kind == "lib" && it.path.matches(".*/build/(classes|resources)/test") } + } + } +} + +def checkstyleConfigProperties(configFileName) { + [importControlFile: "$rootDir/checkstyle/$configFileName", + suppressionsFile: "$rootDir/checkstyle/suppressions.xml", + headerFile: "$rootDir/checkstyle/java.header"] +} + +// Aggregates all jacoco results into the root project directory +if (userEnableTestCoverage) { + task jacocoRootReport(type: org.gradle.testing.jacoco.tasks.JacocoReport) { + def javaProjects = subprojects.findAll { it.path != ':core' } + + description = 'Generates an aggregate report from all subprojects' + dependsOn(javaProjects.test) + + additionalSourceDirs.from = javaProjects.sourceSets.main.allSource.srcDirs + sourceDirectories.from = javaProjects.sourceSets.main.allSource.srcDirs + classDirectories.from = javaProjects.sourceSets.main.output + executionData.from = javaProjects.jacocoTestReport.executionData + + reports { + html.enabled = true + xml.enabled = true + } + + // workaround to ignore projects that don't have any tests at all + onlyIf = { true } + doFirst { + executionData = files(executionData.findAll { it.exists() }) + } + } +} + +if (userEnableTestCoverage) { + task reportCoverage(dependsOn: ['jacocoRootReport', 'core:reportCoverage']) +} + +def connectPkgs = [ + 'connect:api', + 'connect:basic-auth-extension', + 'connect:file', + 'connect:json', + 'connect:runtime', + 'connect:transforms', + 'connect:mirror', + 'connect:mirror-client' +] + +tasks.create(name: "jarConnect", dependsOn: connectPkgs.collect { it + ":jar" }) {} + +tasks.create(name: "testConnect", dependsOn: connectPkgs.collect { it + ":test" }) {} + +project(':core') { + apply plugin: 'scala' + + // scaladoc generation is configured at the sub-module level with an artifacts + // block (cf. see streams-scala). If scaladoc generation is invoked explicitly + // for the `core` module, this ensures the generated jar doesn't include scaladoc + // files since the `core` module doesn't include public APIs. 
+ scaladoc { + enabled = false + } + if (userEnableTestCoverage) + apply plugin: "org.scoverage" + archivesBaseName = "kafka_${versions.baseScala}" + + dependencies { + // `core` is often used in users' tests, define the following dependencies as `api` for backwards compatibility + // even though the `core` module doesn't expose any public API + api project(':clients') + api libs.scalaLibrary + + implementation project(':server-common') + implementation project(':metadata') + implementation project(':raft') + implementation project(':storage') + + implementation libs.argparse4j + implementation libs.jacksonDatabind + implementation libs.jacksonModuleScala + implementation libs.jacksonDataformatCsv + implementation libs.jacksonJDK8Datatypes + implementation libs.joptSimple + implementation libs.jose4j + implementation libs.metrics + implementation libs.scalaCollectionCompat + implementation libs.scalaJava8Compat + // only needed transitively, but set it explicitly to ensure it has the same version as scala-library + implementation libs.scalaReflect + implementation libs.scalaLogging + implementation libs.slf4jApi + implementation(libs.zookeeper) { + // Dropwizard Metrics are required by ZooKeeper as of v3.6.0, + // but the library should *not* be used in Kafka code + implementation libs.dropwizardMetrics + exclude module: 'slf4j-log4j12' + exclude module: 'log4j' + } + // ZooKeeperMain depends on commons-cli but declares the dependency as `provided` + implementation libs.commonsCli + + compileOnly libs.log4j + + testImplementation project(':clients').sourceSets.test.output + testImplementation project(':metadata').sourceSets.test.output + testImplementation project(':raft').sourceSets.test.output + testImplementation libs.bcpkix + testImplementation libs.mockitoCore + testImplementation libs.easymock + testImplementation(libs.apacheda) { + exclude group: 'xml-apis', module: 'xml-apis' + // `mina-core` is a transitive dependency for `apacheds` and `apacheda`. + // It is safer to use from `apacheds` since that is the implementation. 
+ exclude module: 'mina-core' + } + testImplementation libs.apachedsCoreApi + testImplementation libs.apachedsInterceptorKerberos + testImplementation libs.apachedsProtocolShared + testImplementation libs.apachedsProtocolKerberos + testImplementation libs.apachedsProtocolLdap + testImplementation libs.apachedsLdifPartition + testImplementation libs.apachedsMavibotPartition + testImplementation libs.apachedsJdbmPartition + testImplementation libs.junitJupiter + testImplementation libs.slf4jlog4j + testImplementation(libs.jfreechart) { + exclude group: 'junit', module: 'junit' + } + } + + if (userEnableTestCoverage) { + scoverage { + scoverageVersion = versions.scoverage + reportDir = file("${rootProject.buildDir}/scoverage") + highlighting = false + minimumRate = 0.0 + } + } + + configurations { + // manually excludes some unnecessary dependencies + implementation.exclude module: 'javax' + implementation.exclude module: 'jline' + implementation.exclude module: 'jms' + implementation.exclude module: 'jmxri' + implementation.exclude module: 'jmxtools' + implementation.exclude module: 'mail' + // To prevent a UniqueResourceException due the same resource existing in both + // org.apache.directory.api/api-all and org.apache.directory.api/api-ldap-schema-data + testImplementation.exclude module: 'api-ldap-schema-data' + } + + tasks.create(name: "copyDependantLibs", type: Copy) { + from (configurations.testRuntimeClasspath) { + include('slf4j-log4j12*') + include('log4j*jar') + } + from (configurations.runtimeClasspath) { + exclude('kafka-clients*') + } + into "$buildDir/dependant-libs-${versions.scala}" + duplicatesStrategy 'exclude' + } + + task processMessages(type:JavaExec) { + main = "org.apache.kafka.message.MessageGenerator" + classpath = project(':generator').sourceSets.main.runtimeClasspath + args = [ "-p", "kafka.internals.generated", + "-o", "src/generated/java/kafka/internals/generated", + "-i", "src/main/resources/common/message", + "-m", "MessageDataGenerator" + ] + inputs.dir("src/main/resources/common/message") + outputs.dir("src/generated/java/kafka/internals/generated") + } + + compileJava.dependsOn 'processMessages' + + task genProtocolErrorDocs(type: JavaExec) { + classpath = sourceSets.main.runtimeClasspath + main = 'org.apache.kafka.common.protocol.Errors' + if( !generatedDocsDir.exists() ) { generatedDocsDir.mkdirs() } + standardOutput = new File(generatedDocsDir, "protocol_errors.html").newOutputStream() + } + + task genProtocolTypesDocs(type: JavaExec) { + classpath = sourceSets.main.runtimeClasspath + main = 'org.apache.kafka.common.protocol.types.Type' + if( !generatedDocsDir.exists() ) { generatedDocsDir.mkdirs() } + standardOutput = new File(generatedDocsDir, "protocol_types.html").newOutputStream() + } + + task genProtocolApiKeyDocs(type: JavaExec) { + classpath = sourceSets.main.runtimeClasspath + main = 'org.apache.kafka.common.protocol.ApiKeys' + if( !generatedDocsDir.exists() ) { generatedDocsDir.mkdirs() } + standardOutput = new File(generatedDocsDir, "protocol_api_keys.html").newOutputStream() + } + + task genProtocolMessageDocs(type: JavaExec) { + classpath = sourceSets.main.runtimeClasspath + main = 'org.apache.kafka.common.protocol.Protocol' + if( !generatedDocsDir.exists() ) { generatedDocsDir.mkdirs() } + standardOutput = new File(generatedDocsDir, "protocol_messages.html").newOutputStream() + } + + task genAdminClientConfigDocs(type: JavaExec) { + classpath = sourceSets.main.runtimeClasspath + main = 'org.apache.kafka.clients.admin.AdminClientConfig' + 
if( !generatedDocsDir.exists() ) { generatedDocsDir.mkdirs() } + standardOutput = new File(generatedDocsDir, "admin_client_config.html").newOutputStream() + } + + task genProducerConfigDocs(type: JavaExec) { + classpath = sourceSets.main.runtimeClasspath + main = 'org.apache.kafka.clients.producer.ProducerConfig' + if( !generatedDocsDir.exists() ) { generatedDocsDir.mkdirs() } + standardOutput = new File(generatedDocsDir, "producer_config.html").newOutputStream() + } + + task genConsumerConfigDocs(type: JavaExec) { + classpath = sourceSets.main.runtimeClasspath + main = 'org.apache.kafka.clients.consumer.ConsumerConfig' + if( !generatedDocsDir.exists() ) { generatedDocsDir.mkdirs() } + standardOutput = new File(generatedDocsDir, "consumer_config.html").newOutputStream() + } + + task genKafkaConfigDocs(type: JavaExec) { + classpath = sourceSets.main.runtimeClasspath + main = 'kafka.server.KafkaConfig' + if( !generatedDocsDir.exists() ) { generatedDocsDir.mkdirs() } + standardOutput = new File(generatedDocsDir, "kafka_config.html").newOutputStream() + } + + task genTopicConfigDocs(type: JavaExec) { + classpath = sourceSets.main.runtimeClasspath + main = 'kafka.log.LogConfig' + if( !generatedDocsDir.exists() ) { generatedDocsDir.mkdirs() } + standardOutput = new File(generatedDocsDir, "topic_config.html").newOutputStream() + } + + task genConsumerMetricsDocs(type: JavaExec) { + classpath = sourceSets.test.runtimeClasspath + main = 'org.apache.kafka.clients.consumer.internals.ConsumerMetrics' + if( !generatedDocsDir.exists() ) { generatedDocsDir.mkdirs() } + standardOutput = new File(generatedDocsDir, "consumer_metrics.html").newOutputStream() + } + + task genProducerMetricsDocs(type: JavaExec) { + classpath = sourceSets.test.runtimeClasspath + main = 'org.apache.kafka.clients.producer.internals.ProducerMetrics' + if( !generatedDocsDir.exists() ) { generatedDocsDir.mkdirs() } + standardOutput = new File(generatedDocsDir, "producer_metrics.html").newOutputStream() + } + + task siteDocsTar(dependsOn: ['genProtocolErrorDocs', 'genProtocolTypesDocs', 'genProtocolApiKeyDocs', 'genProtocolMessageDocs', + 'genAdminClientConfigDocs', 'genProducerConfigDocs', 'genConsumerConfigDocs', + 'genKafkaConfigDocs', 'genTopicConfigDocs', + ':connect:runtime:genConnectConfigDocs', ':connect:runtime:genConnectTransformationDocs', + ':connect:runtime:genConnectPredicateDocs', + ':connect:runtime:genSinkConnectorConfigDocs', ':connect:runtime:genSourceConnectorConfigDocs', + ':streams:genStreamsConfigDocs', 'genConsumerMetricsDocs', 'genProducerMetricsDocs', + ':connect:runtime:genConnectMetricsDocs'], type: Tar) { + archiveClassifier = 'site-docs' + compression = Compression.GZIP + from project.file("$rootDir/docs") + into 'site-docs' + duplicatesStrategy 'exclude' + } + + tasks.create(name: "releaseTarGz", dependsOn: configurations.archives.artifacts, type: Tar) { + into "kafka_${versions.baseScala}-${archiveVersion.get()}" + compression = Compression.GZIP + from(project.file("$rootDir/bin")) { into "bin/" } + from(project.file("$rootDir/config")) { into "config/" } + from(project.file("$rootDir/licenses")) { into "licenses/" } + from "$rootDir/LICENSE-binary" rename {String filename -> filename.replace("-binary", "")} + from "$rootDir/NOTICE-binary" rename {String filename -> filename.replace("-binary", "")} + from(configurations.runtimeClasspath) { into("libs/") } + from(configurations.archives.artifacts.files) { into("libs/") } + from(project.siteDocsTar) { into("site-docs/") } + from(project(':tools').jar) { 
into("libs/") } + from(project(':tools').configurations.runtimeClasspath) { into("libs/") } + from(project(':trogdor').jar) { into("libs/") } + from(project(':trogdor').configurations.runtimeClasspath) { into("libs/") } + from(project(':shell').jar) { into("libs/") } + from(project(':shell').configurations.runtimeClasspath) { into("libs/") } + from(project(':connect:api').jar) { into("libs/") } + from(project(':connect:api').configurations.runtimeClasspath) { into("libs/") } + from(project(':connect:runtime').jar) { into("libs/") } + from(project(':connect:runtime').configurations.runtimeClasspath) { into("libs/") } + from(project(':connect:transforms').jar) { into("libs/") } + from(project(':connect:transforms').configurations.runtimeClasspath) { into("libs/") } + from(project(':connect:json').jar) { into("libs/") } + from(project(':connect:json').configurations.runtimeClasspath) { into("libs/") } + from(project(':connect:file').jar) { into("libs/") } + from(project(':connect:file').configurations.runtimeClasspath) { into("libs/") } + from(project(':connect:basic-auth-extension').jar) { into("libs/") } + from(project(':connect:basic-auth-extension').configurations.runtimeClasspath) { into("libs/") } + from(project(':connect:mirror').jar) { into("libs/") } + from(project(':connect:mirror').configurations.runtimeClasspath) { into("libs/") } + from(project(':connect:mirror-client').jar) { into("libs/") } + from(project(':connect:mirror-client').configurations.runtimeClasspath) { into("libs/") } + from(project(':streams').jar) { into("libs/") } + from(project(':streams').configurations.runtimeClasspath) { into("libs/") } + from(project(':streams:streams-scala').jar) { into("libs/") } + from(project(':streams:streams-scala').configurations.runtimeClasspath) { into("libs/") } + from(project(':streams:test-utils').jar) { into("libs/") } + from(project(':streams:test-utils').configurations.runtimeClasspath) { into("libs/") } + from(project(':streams:examples').jar) { into("libs/") } + from(project(':streams:examples').configurations.runtimeClasspath) { into("libs/") } + duplicatesStrategy 'exclude' + } + + jar { + dependsOn('copyDependantLibs') + } + + jar.manifest { + attributes( + 'Version': "${version}" + ) + } + + tasks.create(name: "copyDependantTestLibs", type: Copy) { + from (configurations.testRuntimeClasspath) { + include('*.jar') + } + into "$buildDir/dependant-testlibs" + //By default gradle does not handle test dependencies between the sub-projects + //This line is to include clients project test jar to dependant-testlibs + from (project(':clients').testJar ) { "$buildDir/dependant-testlibs" } + duplicatesStrategy 'exclude' + } + + systemTestLibs.dependsOn('jar', 'testJar', 'copyDependantTestLibs') + + checkstyle { + configProperties = checkstyleConfigProperties("import-control-core.xml") + } + + sourceSets { + // Set java/scala source folders in the `scala` block to enable joint compilation + main { + java { + srcDirs = [] + } + scala { + srcDirs = ["src/generated/java", "src/main/java", "src/main/scala"] + } + } + test { + java { + srcDirs = [] + } + scala { + srcDirs = ["src/test/java", "src/test/scala"] + } + } + } +} + +project(':metadata') { + archivesBaseName = "kafka-metadata" + + dependencies { + implementation project(':server-common') + implementation project(':clients') + implementation project(':raft') + implementation libs.jacksonDatabind + implementation libs.jacksonJDK8Datatypes + implementation libs.metrics + compileOnly libs.log4j + testImplementation 
libs.junitJupiter + testImplementation libs.hamcrest + testImplementation libs.slf4jlog4j + testImplementation project(':clients').sourceSets.test.output + testImplementation project(':raft').sourceSets.test.output + } + + task processMessages(type:JavaExec) { + main = "org.apache.kafka.message.MessageGenerator" + classpath = project(':generator').sourceSets.main.runtimeClasspath + args = [ "-p", "org.apache.kafka.common.metadata", + "-o", "src/generated/java/org/apache/kafka/common/metadata", + "-i", "src/main/resources/common/metadata", + "-m", "MessageDataGenerator", "JsonConverterGenerator", + "-t", "MetadataRecordTypeGenerator", "MetadataJsonConvertersGenerator" + ] + inputs.dir("src/main/resources/common/metadata") + outputs.dir("src/generated/java/org/apache/kafka/common/metadata") + } + + compileJava.dependsOn 'processMessages' + + sourceSets { + main { + java { + srcDirs = ["src/generated/java", "src/main/java"] + } + } + test { + java { + srcDirs = ["src/generated/java", "src/test/java"] + } + } + } + + javadoc { + enabled = false + } +} + +project(':examples') { + archivesBaseName = "kafka-examples" + + dependencies { + implementation project(':clients') + implementation project(':core') + } + + javadoc { + enabled = false + } + + checkstyle { + configProperties = checkstyleConfigProperties("import-control-core.xml") + } +} + +project(':generator') { + dependencies { + implementation libs.argparse4j + implementation libs.jacksonDatabind + implementation libs.jacksonJDK8Datatypes + implementation libs.jacksonJaxrsJsonProvider + testImplementation libs.junitJupiter + } + + javadoc { + enabled = false + } +} + +project(':clients') { + archivesBaseName = "kafka-clients" + + dependencies { + implementation libs.zstd + implementation libs.lz4 + implementation libs.snappy + implementation libs.slf4jApi + + compileOnly libs.jacksonDatabind // for SASL/OAUTHBEARER bearer token parsing + compileOnly libs.jacksonJDK8Datatypes + compileOnly libs.jose4j // for SASL/OAUTHBEARER JWT validation; only used by broker + + testImplementation libs.bcpkix + testImplementation libs.junitJupiter + testImplementation libs.mockitoCore + + testRuntimeOnly libs.slf4jlog4j + testRuntimeOnly libs.jacksonDatabind + testRuntimeOnly libs.jacksonJDK8Datatypes + testImplementation libs.jose4j + testImplementation libs.jacksonJaxrsJsonProvider + } + + task createVersionFile(dependsOn: determineCommitId) { + ext.receiptFile = file("$buildDir/kafka/$buildVersionFileName") + outputs.file receiptFile + outputs.upToDateWhen { false } + doLast { + def data = [ + commitId: commitId, + version: version, + ] + + receiptFile.parentFile.mkdirs() + def content = data.entrySet().collect { "$it.key=$it.value" }.sort().join("\n") + receiptFile.setText(content, "ISO-8859-1") + } + } + + jar { + dependsOn createVersionFile + from("$buildDir") { + include "kafka/$buildVersionFileName" + } + } + + clean.doFirst { + delete "$buildDir/kafka/" + } + + task processMessages(type:JavaExec) { + main = "org.apache.kafka.message.MessageGenerator" + classpath = project(':generator').sourceSets.main.runtimeClasspath + args = [ "-p", "org.apache.kafka.common.message", + "-o", "src/generated/java/org/apache/kafka/common/message", + "-i", "src/main/resources/common/message", + "-t", "ApiMessageTypeGenerator", + "-m", "MessageDataGenerator", "JsonConverterGenerator" + ] + inputs.dir("src/main/resources/common/message") + outputs.dir("src/generated/java/org/apache/kafka/common/message") + } + + task processTestMessages(type:JavaExec) { + main = 
"org.apache.kafka.message.MessageGenerator" + classpath = project(':generator').sourceSets.main.runtimeClasspath + args = [ "-p", "org.apache.kafka.common.message", + "-o", "src/generated-test/java/org/apache/kafka/common/message", + "-i", "src/test/resources/common/message", + "-m", "MessageDataGenerator", "JsonConverterGenerator" + ] + inputs.dir("src/test/resources/common/message") + outputs.dir("src/generated-test/java/org/apache/kafka/common/message") + } + + sourceSets { + main { + java { + srcDirs = ["src/generated/java", "src/main/java"] + } + } + test { + java { + srcDirs = ["src/generated/java", "src/generated-test/java", "src/test/java"] + } + } + } + + compileJava.dependsOn 'processMessages' + + compileTestJava.dependsOn 'processTestMessages' + + javadoc { + include "**/org/apache/kafka/clients/admin/*" + include "**/org/apache/kafka/clients/consumer/*" + include "**/org/apache/kafka/clients/producer/*" + include "**/org/apache/kafka/common/*" + include "**/org/apache/kafka/common/acl/*" + include "**/org/apache/kafka/common/annotation/*" + include "**/org/apache/kafka/common/errors/*" + include "**/org/apache/kafka/common/header/*" + include "**/org/apache/kafka/common/metrics/*" + include "**/org/apache/kafka/common/metrics/stats/*" + include "**/org/apache/kafka/common/quota/*" + include "**/org/apache/kafka/common/resource/*" + include "**/org/apache/kafka/common/serialization/*" + include "**/org/apache/kafka/common/config/*" + include "**/org/apache/kafka/common/config/provider/*" + include "**/org/apache/kafka/common/security/auth/*" + include "**/org/apache/kafka/common/security/plain/*" + include "**/org/apache/kafka/common/security/scram/*" + include "**/org/apache/kafka/common/security/token/delegation/*" + include "**/org/apache/kafka/common/security/oauthbearer/*" + include "**/org/apache/kafka/server/authorizer/*" + include "**/org/apache/kafka/server/policy/*" + include "**/org/apache/kafka/server/quota/*" + } +} + +project(':raft') { + archivesBaseName = "kafka-raft" + + dependencies { + implementation project(':server-common') + implementation project(':clients') + implementation libs.slf4jApi + implementation libs.jacksonDatabind + + testImplementation project(':server-common') + testImplementation project(':clients') + testImplementation project(':clients').sourceSets.test.output + testImplementation libs.junitJupiter + testImplementation libs.mockitoCore + testImplementation libs.jqwik + + testRuntimeOnly libs.slf4jlog4j + } + + task createVersionFile(dependsOn: determineCommitId) { + ext.receiptFile = file("$buildDir/kafka/$buildVersionFileName") + outputs.file receiptFile + outputs.upToDateWhen { false } + doLast { + def data = [ + commitId: commitId, + version: version, + ] + + receiptFile.parentFile.mkdirs() + def content = data.entrySet().collect { "$it.key=$it.value" }.sort().join("\n") + receiptFile.setText(content, "ISO-8859-1") + } + } + + task processMessages(type:JavaExec) { + main = "org.apache.kafka.message.MessageGenerator" + classpath = project(':generator').sourceSets.main.runtimeClasspath + args = [ "-p", "org.apache.kafka.raft.generated", + "-o", "src/generated/java/org/apache/kafka/raft/generated", + "-i", "src/main/resources/common/message", + "-m", "MessageDataGenerator", "JsonConverterGenerator"] + inputs.dir("src/main/resources/common/message") + outputs.dir("src/generated/java/org/apache/kafka/raft/generated") + } + + sourceSets { + main { + java { + srcDirs = ["src/generated/java", "src/main/java"] + } + } + test { + java { + srcDirs 
= ["src/generated/java", "src/test/java"] + } + } + } + + compileJava.dependsOn 'processMessages' + + jar { + dependsOn createVersionFile + from("$buildDir") { + include "kafka/$buildVersionFileName" + } + } + + test { + useJUnitPlatform { + includeEngines 'jqwik', 'junit-jupiter' + } + } + + clean.doFirst { + delete "$buildDir/kafka/" + } + + javadoc { + enabled = false + } +} + +project(':server-common') { + archivesBaseName = "kafka-server-common" + + dependencies { + api project(':clients') + implementation libs.slf4jApi + + testImplementation project(':clients') + testImplementation project(':clients').sourceSets.test.output + testImplementation libs.junitJupiter + testImplementation libs.mockitoCore + + testRuntimeOnly libs.slf4jlog4j + } + + task createVersionFile(dependsOn: determineCommitId) { + ext.receiptFile = file("$buildDir/kafka/$buildVersionFileName") + outputs.file receiptFile + outputs.upToDateWhen { false } + doLast { + def data = [ + commitId: commitId, + version: version, + ] + + receiptFile.parentFile.mkdirs() + def content = data.entrySet().collect { "$it.key=$it.value" }.sort().join("\n") + receiptFile.setText(content, "ISO-8859-1") + } + } + + sourceSets { + main { + java { + srcDirs = ["src/main/java"] + } + } + test { + java { + srcDirs = ["src/test/java"] + } + } + } + + jar { + dependsOn createVersionFile + from("$buildDir") { + include "kafka/$buildVersionFileName" + } + } + + clean.doFirst { + delete "$buildDir/kafka/" + } +} + +project(':storage:api') { + archivesBaseName = "kafka-storage-api" + + dependencies { + implementation project(':clients') + implementation libs.slf4jApi + + testImplementation project(':clients') + testImplementation project(':clients').sourceSets.test.output + testImplementation libs.junitJupiter + testImplementation libs.mockitoCore + + testRuntimeOnly libs.slf4jlog4j + } + + task createVersionFile(dependsOn: determineCommitId) { + ext.receiptFile = file("$buildDir/kafka/$buildVersionFileName") + outputs.file receiptFile + outputs.upToDateWhen { false } + doLast { + def data = [ + commitId: commitId, + version: version, + ] + + receiptFile.parentFile.mkdirs() + def content = data.entrySet().collect { "$it.key=$it.value" }.sort().join("\n") + receiptFile.setText(content, "ISO-8859-1") + } + } + + sourceSets { + main { + java { + srcDirs = ["src/main/java"] + } + } + test { + java { + srcDirs = ["src/test/java"] + } + } + } + + jar { + dependsOn createVersionFile + from("$buildDir") { + include "kafka/$buildVersionFileName" + } + } + + clean.doFirst { + delete "$buildDir/kafka/" + } + + javadoc { + include "**/org/apache/kafka/server/log/remote/storage/*" + } +} + +project(':storage') { + archivesBaseName = "kafka-storage" + + dependencies { + implementation project(':storage:api') + implementation project(':server-common') + implementation project(':clients') + implementation libs.slf4jApi + implementation libs.jacksonDatabind + + testImplementation project(':clients') + testImplementation project(':clients').sourceSets.test.output + testImplementation project(':core') + testImplementation project(':core').sourceSets.test.output + testImplementation libs.junitJupiter + testImplementation libs.mockitoCore + testImplementation libs.bcpkix + + testRuntimeOnly libs.slf4jlog4j + } + + task createVersionFile(dependsOn: determineCommitId) { + ext.receiptFile = file("$buildDir/kafka/$buildVersionFileName") + outputs.file receiptFile + outputs.upToDateWhen { false } + doLast { + def data = [ + commitId: commitId, + version: version, + ] + + 
receiptFile.parentFile.mkdirs() + def content = data.entrySet().collect { "$it.key=$it.value" }.sort().join("\n") + receiptFile.setText(content, "ISO-8859-1") + } + } + + task processMessages(type:JavaExec) { + main = "org.apache.kafka.message.MessageGenerator" + classpath = project(':generator').sourceSets.main.runtimeClasspath + args = [ "-p", " org.apache.kafka.server.log.remote.metadata.storage.generated", + "-o", "src/generated/java/org/apache/kafka/server/log/remote/metadata/storage/generated", + "-i", "src/main/resources/message", + "-m", "MessageDataGenerator", "JsonConverterGenerator", + "-t", "MetadataRecordTypeGenerator", "MetadataJsonConvertersGenerator" ] + inputs.dir("src/main/resources/message") + outputs.dir("src/generated/java/org/apache/kafka/server/log/remote/metadata/storage/generated") + } + + sourceSets { + main { + java { + srcDirs = ["src/generated/java", "src/main/java"] + } + } + test { + java { + srcDirs = ["src/generated/java", "src/test/java"] + } + } + } + + compileJava.dependsOn 'processMessages' + + jar { + dependsOn createVersionFile + from("$buildDir") { + include "kafka/$buildVersionFileName" + } + } + + test { + useJUnitPlatform { + includeEngines 'junit-jupiter' + } + } + + clean.doFirst { + delete "$buildDir/kafka/" + } + + javadoc { + enabled = false + } +} + +project(':tools') { + archivesBaseName = "kafka-tools" + + dependencies { + implementation project(':clients') + implementation project(':log4j-appender') + implementation libs.argparse4j + implementation libs.jacksonDatabind + implementation libs.jacksonJDK8Datatypes + implementation libs.slf4jApi + implementation libs.log4j + + implementation libs.jose4j // for SASL/OAUTHBEARER JWT validation + implementation libs.jacksonJaxrsJsonProvider + + testImplementation project(':clients') + testImplementation libs.junitJupiter + testImplementation project(':clients').sourceSets.test.output + testImplementation libs.mockitoInline // supports mocking static methods, final classes, etc. + testImplementation libs.mockitoJunitJupiter // supports MockitoExtension + testRuntimeOnly libs.slf4jlog4j + } + + javadoc { + enabled = false + } + + tasks.create(name: "copyDependantLibs", type: Copy) { + from (configurations.testRuntimeClasspath) { + include('slf4j-log4j12*') + include('log4j*jar') + } + from (configurations.runtimeClasspath) { + exclude('kafka-clients*') + } + into "$buildDir/dependant-libs-${versions.scala}" + duplicatesStrategy 'exclude' + } + + jar { + dependsOn 'copyDependantLibs' + } +} + +project(':trogdor') { + archivesBaseName = "trogdor" + + dependencies { + implementation project(':clients') + implementation project(':log4j-appender') + implementation libs.argparse4j + implementation libs.jacksonDatabind + implementation libs.jacksonJDK8Datatypes + implementation libs.slf4jApi + implementation libs.log4j + + implementation libs.jacksonJaxrsJsonProvider + implementation libs.jerseyContainerServlet + implementation libs.jerseyHk2 + implementation libs.jaxbApi // Jersey dependency that was available in the JDK before Java 9 + implementation libs.activation // Jersey dependency that was available in the JDK before Java 9 + implementation libs.jettyServer + implementation libs.jettyServlet + implementation libs.jettyServlets + + testImplementation project(':clients') + testImplementation libs.junitJupiter + testImplementation project(':clients').sourceSets.test.output + testImplementation libs.mockitoInline // supports mocking static methods, final classes, etc. 
+ + testRuntimeOnly libs.slf4jlog4j + } + + javadoc { + enabled = false + } + + tasks.create(name: "copyDependantLibs", type: Copy) { + from (configurations.testRuntimeClasspath) { + include('slf4j-log4j12*') + include('log4j*jar') + } + from (configurations.runtimeClasspath) { + exclude('kafka-clients*') + } + into "$buildDir/dependant-libs-${versions.scala}" + duplicatesStrategy 'exclude' + } + + jar { + dependsOn 'copyDependantLibs' + } +} + +project(':shell') { + archivesBaseName = "kafka-shell" + + dependencies { + implementation libs.argparse4j + implementation libs.jacksonDatabind + implementation libs.jacksonJDK8Datatypes + implementation libs.jline + implementation libs.slf4jApi + implementation project(':server-common') + implementation project(':clients') + implementation project(':core') + implementation project(':log4j-appender') + implementation project(':metadata') + implementation project(':raft') + + implementation libs.jose4j // for SASL/OAUTHBEARER JWT validation + implementation libs.jacksonJaxrsJsonProvider + + testImplementation project(':clients') + testImplementation libs.junitJupiter + + testRuntimeOnly libs.slf4jlog4j + } + + javadoc { + enabled = false + } + + tasks.create(name: "copyDependantLibs", type: Copy) { + from (configurations.testRuntimeClasspath) { + include('jline-*jar') + } + from (configurations.runtimeClasspath) { + include('jline-*jar') + } + into "$buildDir/dependant-libs-${versions.scala}" + duplicatesStrategy 'exclude' + } + + jar { + dependsOn 'copyDependantLibs' + } +} + +project(':streams') { + archivesBaseName = "kafka-streams" + ext.buildStreamsVersionFileName = "kafka-streams-version.properties" + + dependencies { + api project(':clients') + // `org.rocksdb.Options` is part of Kafka Streams public api via `RocksDBConfigSetter` + api libs.rocksDBJni + + implementation libs.slf4jApi + implementation libs.jacksonAnnotations + implementation libs.jacksonDatabind + + // testCompileOnly prevents streams from exporting a dependency on test-utils, which would cause a dependency cycle + testCompileOnly project(':streams:test-utils') + testImplementation project(':clients').sourceSets.test.output + testImplementation project(':core') + testImplementation project(':core').sourceSets.test.output + testImplementation libs.log4j + testImplementation libs.junitJupiterApi + testImplementation libs.junitVintageEngine + testImplementation libs.easymock + testImplementation libs.powermockJunit4 + testImplementation libs.powermockEasymock + testImplementation libs.bcpkix + testImplementation libs.hamcrest + testImplementation libs.mockitoInline // supports mocking static methods, final classes, etc. 
+ + testRuntimeOnly project(':streams:test-utils') + testRuntimeOnly libs.slf4jlog4j + } + + task processMessages(type:JavaExec) { + main = "org.apache.kafka.message.MessageGenerator" + classpath = project(':generator').sourceSets.main.runtimeClasspath + args = [ "-p", "org.apache.kafka.streams.internals.generated", + "-o", "src/generated/java/org/apache/kafka/streams/internals/generated", + "-i", "src/main/resources/common/message", + "-m", "MessageDataGenerator" + ] + inputs.dir("src/main/resources/common/message") + outputs.dir("src/generated/java/org/apache/kafka/streams/internals/generated") + } + + sourceSets { + main { + java { + srcDirs = ["src/generated/java", "src/main/java"] + } + } + test { + java { + srcDirs = ["src/generated/java", "src/test/java"] + } + } + } + + compileJava.dependsOn 'processMessages' + + javadoc { + include "**/org/apache/kafka/streams/**" + exclude "**/org/apache/kafka/streams/internals/**", "**/org/apache/kafka/streams/**/internals/**" + } + + tasks.create(name: "copyDependantLibs", type: Copy) { + from (configurations.runtimeClasspath) { + exclude('kafka-clients*') + } + into "$buildDir/dependant-libs-${versions.scala}" + duplicatesStrategy 'exclude' + } + + task createStreamsVersionFile(dependsOn: determineCommitId) { + ext.receiptFile = file("$buildDir/kafka/$buildStreamsVersionFileName") + outputs.file receiptFile + outputs.upToDateWhen { false } + doLast { + def data = [ + commitId: commitId, + version: version, + ] + + receiptFile.parentFile.mkdirs() + def content = data.entrySet().collect { "$it.key=$it.value" }.sort().join("\n") + receiptFile.setText(content, "ISO-8859-1") + } + } + + jar { + dependsOn 'createStreamsVersionFile' + from("$buildDir") { + include "kafka/$buildStreamsVersionFileName" + } + dependsOn 'copyDependantLibs' + } + + systemTestLibs { + dependsOn testJar + } + + task genStreamsConfigDocs(type: JavaExec) { + classpath = sourceSets.main.runtimeClasspath + main = 'org.apache.kafka.streams.StreamsConfig' + if( !generatedDocsDir.exists() ) { generatedDocsDir.mkdirs() } + standardOutput = new File(generatedDocsDir, "streams_config.html").newOutputStream() + } + + task testAll( + dependsOn: [ + ':streams:test', + ':streams:test-utils:test', + ':streams:streams-scala:test', + ':streams:upgrade-system-tests-0100:test', + ':streams:upgrade-system-tests-0101:test', + ':streams:upgrade-system-tests-0102:test', + ':streams:upgrade-system-tests-0110:test', + ':streams:upgrade-system-tests-10:test', + ':streams:upgrade-system-tests-11:test', + ':streams:upgrade-system-tests-20:test', + ':streams:upgrade-system-tests-21:test', + ':streams:upgrade-system-tests-22:test', + ':streams:upgrade-system-tests-23:test', + ':streams:upgrade-system-tests-24:test', + ':streams:upgrade-system-tests-25:test', + ':streams:upgrade-system-tests-26:test', + ':streams:upgrade-system-tests-27:test', + ':streams:upgrade-system-tests-28:test', + ':streams:examples:test' + ] + ) +} + +project(':streams:streams-scala') { + apply plugin: 'scala' + archivesBaseName = "kafka-streams-scala_${versions.baseScala}" + dependencies { + api project(':streams') + + api libs.scalaLibrary + api libs.scalaCollectionCompat + + testImplementation project(':core') + testImplementation project(':core').sourceSets.test.output + testImplementation project(':streams').sourceSets.test.output + testImplementation project(':clients').sourceSets.test.output + testImplementation project(':streams:test-utils') + + testImplementation libs.junitJupiter + testImplementation libs.easymock + 
testImplementation libs.hamcrest + testRuntimeOnly libs.slf4jlog4j + } + + javadoc { + include "**/org/apache/kafka/streams/scala/**" + } + + tasks.create(name: "copyDependantLibs", type: Copy) { + from (configurations.runtimeClasspath) { + exclude('kafka-streams*') + } + into "$buildDir/dependant-libs-${versions.scala}" + duplicatesStrategy 'exclude' + } + + jar { + dependsOn 'copyDependantLibs' + } + + test.dependsOn(':spotlessScalaCheck') +} + +project(':streams:test-utils') { + archivesBaseName = "kafka-streams-test-utils" + + dependencies { + api project(':streams') + api project(':clients') + + implementation libs.slf4jApi + + testImplementation project(':clients').sourceSets.test.output + testImplementation libs.junitJupiter + testImplementation libs.mockitoCore + testImplementation libs.hamcrest + + testRuntimeOnly libs.slf4jlog4j + } + + javadoc { + include "**/org/apache/kafka/streams/test/**" + } + + tasks.create(name: "copyDependantLibs", type: Copy) { + from (configurations.runtimeClasspath) { + exclude('kafka-streams*') + } + into "$buildDir/dependant-libs-${versions.scala}" + duplicatesStrategy 'exclude' + } + + jar { + dependsOn 'copyDependantLibs' + } + +} + +project(':streams:examples') { + archivesBaseName = "kafka-streams-examples" + + dependencies { + // this dependency should be removed after we unify data API + implementation(project(':connect:json')) { + // this transitive dependency is not used in Streams, and it breaks SBT builds + exclude module: 'javax.ws.rs-api' + } + + implementation project(':streams') + + implementation libs.slf4jlog4j + + testImplementation project(':streams:test-utils') + testImplementation project(':clients').sourceSets.test.output // for org.apache.kafka.test.IntegrationTest + testImplementation libs.junitJupiter + testImplementation libs.hamcrest + } + + javadoc { + enabled = false + } + + tasks.create(name: "copyDependantLibs", type: Copy) { + from (configurations.runtimeClasspath) { + exclude('kafka-streams*') + } + into "$buildDir/dependant-libs-${versions.scala}" + duplicatesStrategy 'exclude' + } + + jar { + dependsOn 'copyDependantLibs' + } +} + +project(':streams:upgrade-system-tests-0100') { + archivesBaseName = "kafka-streams-upgrade-system-tests-0100" + + dependencies { + testImplementation libs.kafkaStreams_0100 + testRuntimeOnly libs.junitJupiter + } + + systemTestLibs { + dependsOn testJar + } +} + +project(':streams:upgrade-system-tests-0101') { + archivesBaseName = "kafka-streams-upgrade-system-tests-0101" + + dependencies { + testImplementation libs.kafkaStreams_0101 + testRuntimeOnly libs.junitJupiter + } + + systemTestLibs { + dependsOn testJar + } +} + +project(':streams:upgrade-system-tests-0102') { + archivesBaseName = "kafka-streams-upgrade-system-tests-0102" + + dependencies { + testImplementation libs.kafkaStreams_0102 + testRuntimeOnly libs.junitJupiter + } + + systemTestLibs { + dependsOn testJar + } +} + +project(':streams:upgrade-system-tests-0110') { + archivesBaseName = "kafka-streams-upgrade-system-tests-0110" + + dependencies { + testImplementation libs.kafkaStreams_0110 + testRuntimeOnly libs.junitJupiter + } + + systemTestLibs { + dependsOn testJar + } +} + +project(':streams:upgrade-system-tests-10') { + archivesBaseName = "kafka-streams-upgrade-system-tests-10" + + dependencies { + testImplementation libs.kafkaStreams_10 + testRuntimeOnly libs.junitJupiter + } + + systemTestLibs { + dependsOn testJar + } +} + +project(':streams:upgrade-system-tests-11') { + archivesBaseName = 
"kafka-streams-upgrade-system-tests-11" + + dependencies { + testImplementation libs.kafkaStreams_11 + testRuntimeOnly libs.junitJupiter + } + + systemTestLibs { + dependsOn testJar + } +} + +project(':streams:upgrade-system-tests-20') { + archivesBaseName = "kafka-streams-upgrade-system-tests-20" + + dependencies { + testImplementation libs.kafkaStreams_20 + testRuntimeOnly libs.junitJupiter + } + + systemTestLibs { + dependsOn testJar + } +} + +project(':streams:upgrade-system-tests-21') { + archivesBaseName = "kafka-streams-upgrade-system-tests-21" + + dependencies { + testImplementation libs.kafkaStreams_21 + testRuntimeOnly libs.junitJupiter + } + + systemTestLibs { + dependsOn testJar + } +} + +project(':streams:upgrade-system-tests-22') { + archivesBaseName = "kafka-streams-upgrade-system-tests-22" + + dependencies { + testImplementation libs.kafkaStreams_22 + testRuntimeOnly libs.junitJupiter + } + + systemTestLibs { + dependsOn testJar + } +} + +project(':streams:upgrade-system-tests-23') { + archivesBaseName = "kafka-streams-upgrade-system-tests-23" + + dependencies { + testImplementation libs.kafkaStreams_23 + testRuntimeOnly libs.junitJupiter + } + + systemTestLibs { + dependsOn testJar + } +} + +project(':streams:upgrade-system-tests-24') { + archivesBaseName = "kafka-streams-upgrade-system-tests-24" + + dependencies { + testImplementation libs.kafkaStreams_24 + testRuntimeOnly libs.junitJupiter + } + + systemTestLibs { + dependsOn testJar + } +} + +project(':streams:upgrade-system-tests-25') { + archivesBaseName = "kafka-streams-upgrade-system-tests-25" + + dependencies { + testImplementation libs.kafkaStreams_25 + testRuntimeOnly libs.junitJupiter + } + + systemTestLibs { + dependsOn testJar + } +} + +project(':streams:upgrade-system-tests-26') { + archivesBaseName = "kafka-streams-upgrade-system-tests-26" + + dependencies { + testImplementation libs.kafkaStreams_26 + testRuntimeOnly libs.junitJupiter + } + + systemTestLibs { + dependsOn testJar + } +} + +project(':streams:upgrade-system-tests-27') { + archivesBaseName = "kafka-streams-upgrade-system-tests-27" + + dependencies { + testImplementation libs.kafkaStreams_27 + testRuntimeOnly libs.junitJupiter + } + + systemTestLibs { + dependsOn testJar + } +} + +project(':streams:upgrade-system-tests-28') { + archivesBaseName = "kafka-streams-upgrade-system-tests-28" + + dependencies { + testImplementation libs.kafkaStreams_28 + testRuntimeOnly libs.junitJupiter + } + + systemTestLibs { + dependsOn testJar + } +} + +project(':jmh-benchmarks') { + + apply plugin: 'com.github.johnrengelman.shadow' + + shadowJar { + archiveBaseName = 'kafka-jmh-benchmarks' + } + + dependencies { + implementation(project(':core')) { + // jmh requires jopt 4.x while `core` depends on 5.0, they are not binary compatible + exclude group: 'net.sf.jopt-simple', module: 'jopt-simple' + } + implementation project(':clients') + implementation project(':metadata') + implementation project(':streams') + implementation project(':core') + implementation project(':clients').sourceSets.test.output + implementation project(':core').sourceSets.test.output + + implementation libs.jmhCore + annotationProcessor libs.jmhGeneratorAnnProcess + implementation libs.jmhCoreBenchmarks + implementation libs.jacksonDatabind + implementation libs.metrics + implementation libs.mockitoCore + implementation libs.slf4jlog4j + implementation libs.scalaLibrary + implementation libs.scalaJava8Compat + } + + tasks.withType(JavaCompile) { + // Suppress warning caused by code generated 
by jmh: `warning: [cast] redundant cast to long` + options.compilerArgs << "-Xlint:-cast" + } + + jar { + manifest { + attributes "Main-Class": "org.openjdk.jmh.Main" + } + } + + checkstyle { + configProperties = checkstyleConfigProperties("import-control-jmh-benchmarks.xml") + } + + task jmh(type: JavaExec, dependsOn: [':jmh-benchmarks:clean', ':jmh-benchmarks:shadowJar']) { + + main="-jar" + + doFirst { + if (System.getProperty("jmhArgs")) { + args System.getProperty("jmhArgs").split(' ') + } + args = [shadowJar.archivePath, *args] + } + } + + javadoc { + enabled = false + } +} + +project(':log4j-appender') { + archivesBaseName = "kafka-log4j-appender" + + dependencies { + implementation project(':clients') + implementation libs.slf4jlog4j + + testImplementation project(':clients').sourceSets.test.output + testImplementation libs.junitJupiter + testImplementation libs.hamcrest + testImplementation libs.mockitoCore + } + + javadoc { + enabled = false + } + +} + +project(':connect:api') { + archivesBaseName = "connect-api" + + dependencies { + api project(':clients') + implementation libs.slf4jApi + implementation libs.jaxrsApi + + testImplementation libs.junitJupiter + testRuntimeOnly libs.slf4jlog4j + testImplementation project(':clients').sourceSets.test.output + } + + javadoc { + include "**/org/apache/kafka/connect/**" // needed for the `aggregatedJavadoc` task + } + + tasks.create(name: "copyDependantLibs", type: Copy) { + from (configurations.testRuntimeClasspath) { + include('slf4j-log4j12*') + include('log4j*jar') + } + from (configurations.runtimeClasspath) { + exclude('kafka-clients*') + exclude('connect-*') + } + into "$buildDir/dependant-libs" + duplicatesStrategy 'exclude' + } + + jar { + dependsOn copyDependantLibs + } +} + +project(':connect:transforms') { + archivesBaseName = "connect-transforms" + + dependencies { + api project(':connect:api') + + implementation libs.slf4jApi + + testImplementation libs.easymock + testImplementation libs.junitJupiter + + testRuntimeOnly libs.slf4jlog4j + testImplementation project(':clients').sourceSets.test.output + } + + javadoc { + enabled = false + } + + tasks.create(name: "copyDependantLibs", type: Copy) { + from (configurations.testRuntimeClasspath) { + include('slf4j-log4j12*') + include('log4j*jar') + } + from (configurations.runtimeClasspath) { + exclude('kafka-clients*') + exclude('connect-*') + } + into "$buildDir/dependant-libs" + duplicatesStrategy 'exclude' + } + + jar { + dependsOn copyDependantLibs + } +} + +project(':connect:json') { + archivesBaseName = "connect-json" + + dependencies { + api project(':connect:api') + + api libs.jacksonDatabind + api libs.jacksonJDK8Datatypes + + implementation libs.slf4jApi + + testImplementation libs.easymock + testImplementation libs.junitJupiter + + testRuntimeOnly libs.slf4jlog4j + testImplementation project(':clients').sourceSets.test.output + } + + javadoc { + enabled = false + } + + tasks.create(name: "copyDependantLibs", type: Copy) { + from (configurations.testRuntimeClasspath) { + include('slf4j-log4j12*') + include('log4j*jar') + } + from (configurations.runtimeClasspath) { + exclude('kafka-clients*') + exclude('connect-*') + } + into "$buildDir/dependant-libs" + duplicatesStrategy 'exclude' + } + + jar { + dependsOn copyDependantLibs + } +} + +project(':connect:runtime') { + archivesBaseName = "connect-runtime" + + dependencies { + // connect-runtime is used in tests, use `api` for modules below for backwards compatibility even though + // applications should generally not 
depend on `connect-runtime` + api project(':connect:api') + api project(':clients') + api project(':connect:json') + api project(':connect:transforms') + + implementation project(':tools') + + implementation libs.slf4jApi + implementation libs.log4j + implementation libs.jose4j // for SASL/OAUTHBEARER JWT validation + implementation libs.jacksonAnnotations + implementation libs.jacksonJaxrsJsonProvider + implementation libs.jerseyContainerServlet + implementation libs.jerseyHk2 + implementation libs.jaxbApi // Jersey dependency that was available in the JDK before Java 9 + implementation libs.activation // Jersey dependency that was available in the JDK before Java 9 + implementation libs.jettyServer + implementation libs.jettyServlet + implementation libs.jettyServlets + implementation libs.jettyClient + implementation libs.reflections + implementation libs.mavenArtifact + + testImplementation project(':clients').sourceSets.test.output + testImplementation project(':core') + testImplementation project(':metadata') + testImplementation project(':core').sourceSets.test.output + + testImplementation libs.easymock + testImplementation libs.junitJupiterApi + testImplementation libs.junitVintageEngine + testImplementation libs.powermockJunit4 + testImplementation libs.powermockEasymock + testImplementation libs.mockitoCore + testImplementation libs.httpclient + + testRuntimeOnly libs.slf4jlog4j + } + + javadoc { + enabled = false + } + + tasks.create(name: "copyDependantLibs", type: Copy) { + from (configurations.testRuntimeClasspath) { + include('slf4j-log4j12*') + include('log4j*jar') + } + from (configurations.runtimeClasspath) { + exclude('kafka-clients*') + exclude('connect-*') + } + into "$buildDir/dependant-libs" + duplicatesStrategy 'exclude' + } + + jar { + dependsOn copyDependantLibs + } + + task genConnectConfigDocs(type: JavaExec) { + classpath = sourceSets.main.runtimeClasspath + main = 'org.apache.kafka.connect.runtime.distributed.DistributedConfig' + if( !generatedDocsDir.exists() ) { generatedDocsDir.mkdirs() } + standardOutput = new File(generatedDocsDir, "connect_config.html").newOutputStream() + } + + task genSinkConnectorConfigDocs(type: JavaExec) { + classpath = sourceSets.main.runtimeClasspath + main = 'org.apache.kafka.connect.runtime.SinkConnectorConfig' + if( !generatedDocsDir.exists() ) { generatedDocsDir.mkdirs() } + standardOutput = new File(generatedDocsDir, "sink_connector_config.html").newOutputStream() + } + + task genSourceConnectorConfigDocs(type: JavaExec) { + classpath = sourceSets.main.runtimeClasspath + main = 'org.apache.kafka.connect.runtime.SourceConnectorConfig' + if( !generatedDocsDir.exists() ) { generatedDocsDir.mkdirs() } + standardOutput = new File(generatedDocsDir, "source_connector_config.html").newOutputStream() + } + + task genConnectTransformationDocs(type: JavaExec) { + classpath = sourceSets.main.runtimeClasspath + main = 'org.apache.kafka.connect.tools.TransformationDoc' + if( !generatedDocsDir.exists() ) { generatedDocsDir.mkdirs() } + standardOutput = new File(generatedDocsDir, "connect_transforms.html").newOutputStream() + } + + task genConnectPredicateDocs(type: JavaExec) { + classpath = sourceSets.main.runtimeClasspath + main = 'org.apache.kafka.connect.tools.PredicateDoc' + if( !generatedDocsDir.exists() ) { generatedDocsDir.mkdirs() } + standardOutput = new File(generatedDocsDir, "connect_predicates.html").newOutputStream() + } + + task genConnectMetricsDocs(type: JavaExec) { + classpath = sourceSets.test.runtimeClasspath + main = 
'org.apache.kafka.connect.runtime.ConnectMetrics' + if( !generatedDocsDir.exists() ) { generatedDocsDir.mkdirs() } + standardOutput = new File(generatedDocsDir, "connect_metrics.html").newOutputStream() + } + +} + +project(':connect:file') { + archivesBaseName = "connect-file" + + dependencies { + implementation project(':connect:api') + implementation libs.slf4jApi + + testImplementation libs.easymock + testImplementation libs.junitJupiter + testImplementation libs.mockitoInline // supports mocking static methods, final classes, etc. + + testRuntimeOnly libs.slf4jlog4j + testImplementation project(':clients').sourceSets.test.output + } + + javadoc { + enabled = false + } + + tasks.create(name: "copyDependantLibs", type: Copy) { + from (configurations.testRuntimeClasspath) { + include('slf4j-log4j12*') + include('log4j*jar') + } + from (configurations.runtimeClasspath) { + exclude('kafka-clients*') + exclude('connect-*') + } + into "$buildDir/dependant-libs" + duplicatesStrategy 'exclude' + } + + jar { + dependsOn copyDependantLibs + } +} + +project(':connect:basic-auth-extension') { + archivesBaseName = "connect-basic-auth-extension" + + dependencies { + implementation project(':connect:api') + implementation libs.slf4jApi + implementation libs.jaxrsApi + + testImplementation libs.bcpkix + testImplementation libs.mockitoCore + testImplementation libs.junitJupiter + testImplementation project(':clients').sourceSets.test.output + + testRuntimeOnly libs.slf4jlog4j + testRuntimeOnly libs.jerseyContainerServlet + } + + javadoc { + enabled = false + } + + tasks.create(name: "copyDependantLibs", type: Copy) { + from (configurations.testRuntimeClasspath) { + include('slf4j-log4j12*') + include('log4j*jar') + } + from (configurations.runtimeClasspath) { + exclude('kafka-clients*') + exclude('connect-*') + } + into "$buildDir/dependant-libs" + duplicatesStrategy 'exclude' + } + + jar { + dependsOn copyDependantLibs + } +} + +project(':connect:mirror') { + archivesBaseName = "connect-mirror" + + dependencies { + implementation project(':connect:api') + implementation project(':connect:runtime') + implementation project(':connect:mirror-client') + implementation project(':clients') + + implementation libs.argparse4j + implementation libs.jacksonAnnotations + implementation libs.slf4jApi + + testImplementation libs.junitJupiter + testImplementation libs.mockitoCore + testImplementation project(':clients').sourceSets.test.output + testImplementation project(':connect:runtime').sourceSets.test.output + testImplementation project(':core') + testImplementation project(':core').sourceSets.test.output + + testRuntimeOnly project(':connect:runtime') + testRuntimeOnly libs.slf4jlog4j + testRuntimeOnly libs.bcpkix + } + + javadoc { + enabled = false + } + + tasks.create(name: "copyDependantLibs", type: Copy) { + from (configurations.testRuntimeClasspath) { + include('slf4j-log4j12*') + include('log4j*jar') + } + from (configurations.runtimeClasspath) { + exclude('kafka-clients*') + exclude('connect-*') + } + into "$buildDir/dependant-libs" + duplicatesStrategy 'exclude' + } + + jar { + dependsOn copyDependantLibs + } +} + +project(':connect:mirror-client') { + archivesBaseName = "connect-mirror-client" + + dependencies { + implementation project(':clients') + implementation libs.slf4jApi + + testImplementation libs.junitJupiter + testImplementation project(':clients').sourceSets.test.output + + testRuntimeOnly libs.slf4jlog4j + } + + javadoc { + enabled = true + } + + tasks.create(name: "copyDependantLibs", 
type: Copy) { + from (configurations.testRuntimeClasspath) { + include('slf4j-log4j12*') + include('log4j*jar') + } + from (configurations.runtimeClasspath) { + exclude('kafka-clients*') + exclude('connect-*') + } + into "$buildDir/dependant-libs" + duplicatesStrategy 'exclude' + } + + jar { + dependsOn copyDependantLibs + } +} + +task aggregatedJavadoc(type: Javadoc, dependsOn: compileJava) { + def projectsWithJavadoc = subprojects.findAll { it.javadoc.enabled } + source = projectsWithJavadoc.collect { it.sourceSets.main.allJava } + classpath = files(projectsWithJavadoc.collect { it.sourceSets.main.compileClasspath }) + includes = projectsWithJavadoc.collectMany { it.javadoc.getIncludes() } + excludes = projectsWithJavadoc.collectMany { it.javadoc.getExcludes() } + + options.charSet = 'UTF-8' + options.docEncoding = 'UTF-8' + options.encoding = 'UTF-8' + // Turn off doclint for now, see https://blog.joda.org/2014/02/turning-off-doclint-in-jdk-8-javadoc.html for rationale + options.addStringOption('Xdoclint:none', '-quiet') + + // The URL structure was changed to include the locale after Java 8 + if (JavaVersion.current().isJava11Compatible()) + options.links "https://docs.oracle.com/en/java/javase/${JavaVersion.current().majorVersion}/docs/api/" + else + options.links "https://docs.oracle.com/javase/8/docs/api/" +} diff --git a/checkstyle/.scalafmt.conf b/checkstyle/.scalafmt.conf new file mode 100644 index 0000000..4c4fcf3 --- /dev/null +++ b/checkstyle/.scalafmt.conf @@ -0,0 +1,19 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
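+// Scala formatting conventions: JavaDoc-style scaladoc, 120-column lines, and automatic rewrites such as import sorting.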
+docstrings = JavaDoc
+maxColumn = 120
+continuationIndent.defnSite = 2
+assumeStandardLibraryStripMargin = true
+rewrite.rules = [SortImports, RedundantBraces, RedundantParens, SortModifiers]
\ No newline at end of file
diff --git a/checkstyle/checkstyle.xml b/checkstyle/checkstyle.xml
new file mode 100644
index 0000000..7f912dc
--- /dev/null
+++ b/checkstyle/checkstyle.xml
@@ -0,0 +1,151 @@
diff --git a/checkstyle/import-control-core.xml b/checkstyle/import-control-core.xml
new file mode 100644
index 0000000..6ec3ae9
--- /dev/null
+++ b/checkstyle/import-control-core.xml
@@ -0,0 +1,102 @@
diff --git a/checkstyle/import-control-jmh-benchmarks.xml b/checkstyle/import-control-jmh-benchmarks.xml
new file mode 100644
index 0000000..5b9b418
--- /dev/null
+++ b/checkstyle/import-control-jmh-benchmarks.xml
@@ -0,0 +1,57 @@
diff --git a/checkstyle/import-control.xml b/checkstyle/import-control.xml
new file mode 100644
index 0000000..8b7a8e9
--- /dev/null
+++ b/checkstyle/import-control.xml
@@ -0,0 +1,626 @@
diff --git a/checkstyle/java.header b/checkstyle/java.header
new file mode 100644
index 0000000..45fd2d5
--- /dev/null
+++ b/checkstyle/java.header
@@ -0,0 +1,16 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
diff --git a/checkstyle/suppressions.xml b/checkstyle/suppressions.xml
new file mode 100644
index 0000000..c1f9b34
--- /dev/null
+++ b/checkstyle/suppressions.xml
@@ -0,0 +1,293 @@
diff --git a/clients/.gitignore b/clients/.gitignore
new file mode 100644
index 0000000..ae3c172
--- /dev/null
+++ b/clients/.gitignore
@@ -0,0 +1 @@
+/bin/
diff --git a/clients/src/main/java/org/apache/kafka/clients/ApiVersions.java b/clients/src/main/java/org/apache/kafka/clients/ApiVersions.java
new file mode 100644
index 0000000..a09d581
--- /dev/null
+++ b/clients/src/main/java/org/apache/kafka/clients/ApiVersions.java
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.clients;
+
+import org.apache.kafka.common.protocol.ApiKeys;
+import org.apache.kafka.common.record.RecordBatch;
+import org.apache.kafka.common.requests.ProduceRequest;
+
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Optional;
+
+/**
+ * Maintains node api versions for access outside of NetworkClient (which is where the information is derived).
+ * The pattern is akin to the use of {@link Metadata} for topic metadata.
+ *
+ * NOTE: This class is intended for INTERNAL usage only within Kafka.
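+ * <p>
+ * Illustrative usage (sketch only; in practice NetworkClient updates this class as ApiVersions responses arrive):
+ * <pre>{@code
+ * ApiVersions apiVersions = new ApiVersions();
+ * apiVersions.update("0", nodeApiVersions);          // versions reported by broker "0"
+ * byte magic = apiVersions.maxUsableProduceMagic();  // record format supported by every known broker
+ * }</pre>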
+ */
+public class ApiVersions {
+
+    private final Map<String, NodeApiVersions> nodeApiVersions = new HashMap<>();
+    private byte maxUsableProduceMagic = RecordBatch.CURRENT_MAGIC_VALUE;
+
+    public synchronized void update(String nodeId, NodeApiVersions nodeApiVersions) {
+        this.nodeApiVersions.put(nodeId, nodeApiVersions);
+        this.maxUsableProduceMagic = computeMaxUsableProduceMagic();
+    }
+
+    public synchronized void remove(String nodeId) {
+        this.nodeApiVersions.remove(nodeId);
+        this.maxUsableProduceMagic = computeMaxUsableProduceMagic();
+    }
+
+    public synchronized NodeApiVersions get(String nodeId) {
+        return this.nodeApiVersions.get(nodeId);
+    }
+
+    private byte computeMaxUsableProduceMagic() {
+        // use a magic version which is supported by all brokers to reduce the chance that
+        // we will need to convert the messages when they are ready to be sent.
+        Optional<Byte> knownBrokerNodesMinRequiredMagicForProduce = this.nodeApiVersions.values().stream()
+            .filter(versions -> versions.apiVersion(ApiKeys.PRODUCE) != null) // filter out Raft controller nodes
+            .map(versions -> ProduceRequest.requiredMagicForVersion(versions.latestUsableVersion(ApiKeys.PRODUCE)))
+            .min(Byte::compare);
+        return (byte) Math.min(RecordBatch.CURRENT_MAGIC_VALUE,
+            knownBrokerNodesMinRequiredMagicForProduce.orElse(RecordBatch.CURRENT_MAGIC_VALUE));
+    }
+
+    public synchronized byte maxUsableProduceMagic() {
+        return maxUsableProduceMagic;
+    }
+
+}
diff --git a/clients/src/main/java/org/apache/kafka/clients/ClientDnsLookup.java b/clients/src/main/java/org/apache/kafka/clients/ClientDnsLookup.java
new file mode 100644
index 0000000..e097c7e
--- /dev/null
+++ b/clients/src/main/java/org/apache/kafka/clients/ClientDnsLookup.java
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.clients;
+
+import java.util.Locale;
+
+public enum ClientDnsLookup {
+    USE_ALL_DNS_IPS("use_all_dns_ips"),
+    RESOLVE_CANONICAL_BOOTSTRAP_SERVERS_ONLY("resolve_canonical_bootstrap_servers_only");
+
+    private final String clientDnsLookup;
+
+    ClientDnsLookup(String clientDnsLookup) {
+        this.clientDnsLookup = clientDnsLookup;
+    }
+
+    @Override
+    public String toString() {
+        return clientDnsLookup;
+    }
+
+    public static ClientDnsLookup forConfig(String config) {
+        return ClientDnsLookup.valueOf(config.toUpperCase(Locale.ROOT));
+    }
+}
diff --git a/clients/src/main/java/org/apache/kafka/clients/ClientRequest.java b/clients/src/main/java/org/apache/kafka/clients/ClientRequest.java
new file mode 100644
index 0000000..abba795
--- /dev/null
+++ b/clients/src/main/java/org/apache/kafka/clients/ClientRequest.java
@@ -0,0 +1,119 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients; + +import org.apache.kafka.common.message.RequestHeaderData; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.requests.AbstractRequest; +import org.apache.kafka.common.requests.RequestHeader; + +/** + * A request being sent to the server. This holds both the network send as well as the client-level metadata. + */ +public final class ClientRequest { + + private final String destination; + private final AbstractRequest.Builder requestBuilder; + private final int correlationId; + private final String clientId; + private final long createdTimeMs; + private final boolean expectResponse; + private final int requestTimeoutMs; + private final RequestCompletionHandler callback; + + /** + * @param destination The brokerId to send the request to + * @param requestBuilder The builder for the request to make + * @param correlationId The correlation id for this client request + * @param clientId The client ID to use for the header + * @param createdTimeMs The unix timestamp in milliseconds for the time at which this request was created. + * @param expectResponse Should we expect a response message or is this request complete once it is sent? 
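+ * @param requestTimeoutMs The maximum time in milliseconds to wait for a response before the request is considered failed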
+ * @param callback A callback to execute when the response has been received (or null if no callback is necessary) + */ + public ClientRequest(String destination, + AbstractRequest.Builder requestBuilder, + int correlationId, + String clientId, + long createdTimeMs, + boolean expectResponse, + int requestTimeoutMs, + RequestCompletionHandler callback) { + this.destination = destination; + this.requestBuilder = requestBuilder; + this.correlationId = correlationId; + this.clientId = clientId; + this.createdTimeMs = createdTimeMs; + this.expectResponse = expectResponse; + this.requestTimeoutMs = requestTimeoutMs; + this.callback = callback; + } + + @Override + public String toString() { + return "ClientRequest(expectResponse=" + expectResponse + + ", callback=" + callback + + ", destination=" + destination + + ", correlationId=" + correlationId + + ", clientId=" + clientId + + ", createdTimeMs=" + createdTimeMs + + ", requestBuilder=" + requestBuilder + + ")"; + } + + public boolean expectResponse() { + return expectResponse; + } + + public ApiKeys apiKey() { + return requestBuilder.apiKey(); + } + + public RequestHeader makeHeader(short version) { + ApiKeys requestApiKey = apiKey(); + return new RequestHeader( + new RequestHeaderData() + .setRequestApiKey(requestApiKey.id) + .setRequestApiVersion(version) + .setClientId(clientId) + .setCorrelationId(correlationId), + requestApiKey.requestHeaderVersion(version)); + } + + public AbstractRequest.Builder requestBuilder() { + return requestBuilder; + } + + public String destination() { + return destination; + } + + public RequestCompletionHandler callback() { + return callback; + } + + public long createdTimeMs() { + return createdTimeMs; + } + + public int correlationId() { + return correlationId; + } + + public int requestTimeoutMs() { + return requestTimeoutMs; + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/ClientResponse.java b/clients/src/main/java/org/apache/kafka/clients/ClientResponse.java new file mode 100644 index 0000000..446bf44 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/ClientResponse.java @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients; + +import org.apache.kafka.common.errors.AuthenticationException; +import org.apache.kafka.common.errors.UnsupportedVersionException; +import org.apache.kafka.common.requests.AbstractResponse; +import org.apache.kafka.common.requests.RequestHeader; + +/** + * A response from the server. Contains both the body of the response as well as the correlated request + * metadata that was originally sent. 
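+ * <p>
+ * The latency reported by {@link #requestLatencyMs()} is computed as {@code receivedTimeMs - createdTimeMs},
+ * and {@link #onComplete()} invokes the request's completion handler, if one was supplied.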
+ */ +public class ClientResponse { + + private final RequestHeader requestHeader; + private final RequestCompletionHandler callback; + private final String destination; + private final long receivedTimeMs; + private final long latencyMs; + private final boolean disconnected; + private final UnsupportedVersionException versionMismatch; + private final AuthenticationException authenticationException; + private final AbstractResponse responseBody; + + /** + * @param requestHeader The header of the corresponding request + * @param callback The callback to be invoked + * @param createdTimeMs The unix timestamp when the corresponding request was created + * @param destination The node the corresponding request was sent to + * @param receivedTimeMs The unix timestamp when this response was received + * @param disconnected Whether the client disconnected before fully reading a response + * @param versionMismatch Whether there was a version mismatch that prevented sending the request. + * @param responseBody The response contents (or null) if we disconnected, no response was expected, + * or if there was a version mismatch. + */ + public ClientResponse(RequestHeader requestHeader, + RequestCompletionHandler callback, + String destination, + long createdTimeMs, + long receivedTimeMs, + boolean disconnected, + UnsupportedVersionException versionMismatch, + AuthenticationException authenticationException, + AbstractResponse responseBody) { + this.requestHeader = requestHeader; + this.callback = callback; + this.destination = destination; + this.receivedTimeMs = receivedTimeMs; + this.latencyMs = receivedTimeMs - createdTimeMs; + this.disconnected = disconnected; + this.versionMismatch = versionMismatch; + this.authenticationException = authenticationException; + this.responseBody = responseBody; + } + + public long receivedTimeMs() { + return receivedTimeMs; + } + + public boolean wasDisconnected() { + return disconnected; + } + + public UnsupportedVersionException versionMismatch() { + return versionMismatch; + } + + public AuthenticationException authenticationException() { + return authenticationException; + } + + public RequestHeader requestHeader() { + return requestHeader; + } + + public String destination() { + return destination; + } + + public AbstractResponse responseBody() { + return responseBody; + } + + public boolean hasResponse() { + return responseBody != null; + } + + public long requestLatencyMs() { + return latencyMs; + } + + public void onComplete() { + if (callback != null) + callback.onComplete(this); + } + + @Override + public String toString() { + return "ClientResponse(receivedTimeMs=" + receivedTimeMs + + ", latencyMs=" + + latencyMs + + ", disconnected=" + + disconnected + + ", requestHeader=" + + requestHeader + + ", responseBody=" + + responseBody + + ")"; + } + +} diff --git a/clients/src/main/java/org/apache/kafka/clients/ClientUtils.java b/clients/src/main/java/org/apache/kafka/clients/ClientUtils.java new file mode 100644 index 0000000..86d4678 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/ClientUtils.java @@ -0,0 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients; + +import org.apache.kafka.common.config.AbstractConfig; +import org.apache.kafka.common.config.ConfigException; +import org.apache.kafka.common.config.SaslConfigs; +import org.apache.kafka.common.network.ChannelBuilder; +import org.apache.kafka.common.network.ChannelBuilders; +import org.apache.kafka.common.security.JaasContext; +import org.apache.kafka.common.security.auth.SecurityProtocol; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.common.utils.Time; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.net.UnknownHostException; +import java.util.ArrayList; +import java.util.List; +import java.util.stream.Collectors; + +import static org.apache.kafka.common.utils.Utils.getHost; +import static org.apache.kafka.common.utils.Utils.getPort; + +public final class ClientUtils { + private static final Logger log = LoggerFactory.getLogger(ClientUtils.class); + + private ClientUtils() { + } + + public static List parseAndValidateAddresses(List urls, String clientDnsLookupConfig) { + return parseAndValidateAddresses(urls, ClientDnsLookup.forConfig(clientDnsLookupConfig)); + } + + public static List parseAndValidateAddresses(List urls, ClientDnsLookup clientDnsLookup) { + List addresses = new ArrayList<>(); + for (String url : urls) { + if (url != null && !url.isEmpty()) { + try { + String host = getHost(url); + Integer port = getPort(url); + if (host == null || port == null) + throw new ConfigException("Invalid url in " + CommonClientConfigs.BOOTSTRAP_SERVERS_CONFIG + ": " + url); + + if (clientDnsLookup == ClientDnsLookup.RESOLVE_CANONICAL_BOOTSTRAP_SERVERS_ONLY) { + InetAddress[] inetAddresses = InetAddress.getAllByName(host); + for (InetAddress inetAddress : inetAddresses) { + String resolvedCanonicalName = inetAddress.getCanonicalHostName(); + InetSocketAddress address = new InetSocketAddress(resolvedCanonicalName, port); + if (address.isUnresolved()) { + log.warn("Couldn't resolve server {} from {} as DNS resolution of the canonical hostname {} failed for {}", url, CommonClientConfigs.BOOTSTRAP_SERVERS_CONFIG, resolvedCanonicalName, host); + } else { + addresses.add(address); + } + } + } else { + InetSocketAddress address = new InetSocketAddress(host, port); + if (address.isUnresolved()) { + log.warn("Couldn't resolve server {} from {} as DNS resolution failed for {}", url, CommonClientConfigs.BOOTSTRAP_SERVERS_CONFIG, host); + } else { + addresses.add(address); + } + } + + } catch (IllegalArgumentException e) { + throw new ConfigException("Invalid port in " + CommonClientConfigs.BOOTSTRAP_SERVERS_CONFIG + ": " + url); + } catch (UnknownHostException e) { + throw new ConfigException("Unknown host in " + CommonClientConfigs.BOOTSTRAP_SERVERS_CONFIG + ": " + url); + } + } + } + if (addresses.isEmpty()) + throw new ConfigException("No resolvable bootstrap urls given in " + CommonClientConfigs.BOOTSTRAP_SERVERS_CONFIG); + return addresses; + } + + /** + * Create a new channel builder from the provided configuration. 
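+ * The builder type is derived from the configured {@code security.protocol}, and the {@code sasl.mechanism}
+ * setting is passed along for SASL-based protocols.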
+ * + * @param config client configs + * @param time the time implementation + * @param logContext the logging context + * + * @return configured ChannelBuilder based on the configs. + */ + public static ChannelBuilder createChannelBuilder(AbstractConfig config, Time time, LogContext logContext) { + SecurityProtocol securityProtocol = SecurityProtocol.forName(config.getString(CommonClientConfigs.SECURITY_PROTOCOL_CONFIG)); + String clientSaslMechanism = config.getString(SaslConfigs.SASL_MECHANISM); + return ChannelBuilders.clientChannelBuilder(securityProtocol, JaasContext.Type.CLIENT, config, null, + clientSaslMechanism, time, true, logContext); + } + + static List resolve(String host, HostResolver hostResolver) throws UnknownHostException { + InetAddress[] addresses = hostResolver.resolve(host); + List result = filterPreferredAddresses(addresses); + if (log.isDebugEnabled()) + log.debug("Resolved host {} as {}", host, result.stream().map(i -> i.getHostAddress()).collect(Collectors.joining(","))); + return result; + } + + /** + * Return a list containing the first address in `allAddresses` and subsequent addresses + * that are a subtype of the first address. + * + * The outcome is that all returned addresses are either IPv4 or IPv6 (InetAddress has two + * subclasses: Inet4Address and Inet6Address). + */ + static List filterPreferredAddresses(InetAddress[] allAddresses) { + List preferredAddresses = new ArrayList<>(); + Class clazz = null; + for (InetAddress address : allAddresses) { + if (clazz == null) { + clazz = address.getClass(); + } + if (clazz.isInstance(address)) { + preferredAddresses.add(address); + } + } + return preferredAddresses; + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/ClusterConnectionStates.java b/clients/src/main/java/org/apache/kafka/clients/ClusterConnectionStates.java new file mode 100644 index 0000000..95efdbe --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/ClusterConnectionStates.java @@ -0,0 +1,542 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients; + +import java.util.HashSet; +import java.util.Set; + +import java.util.stream.Collectors; +import org.apache.kafka.common.errors.AuthenticationException; +import org.apache.kafka.common.utils.ExponentialBackoff; +import org.apache.kafka.common.utils.LogContext; +import org.slf4j.Logger; + +import java.net.InetAddress; +import java.net.UnknownHostException; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** + * The state of our connection to each node in the cluster. 
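+ * Transitions are driven by the network layer, roughly: {@code connecting()} when a socket connect is initiated,
+ * then either {@code ready()} (possibly via {@code checkingApiVersions()}) or {@code disconnected()}.
+ * Reconnect backoff and connection setup timeouts grow exponentially with jitter. Sketch only:
+ * <pre>{@code
+ * ClusterConnectionStates states = ...;        // typically constructed by NetworkClient
+ * if (states.canConnect(nodeId, now))
+ *     states.connecting(nodeId, now, host);    // later: states.ready(nodeId) or states.disconnected(nodeId, now)
+ * }</pre>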
+ * + */ +final class ClusterConnectionStates { + final static int RECONNECT_BACKOFF_EXP_BASE = 2; + final static double RECONNECT_BACKOFF_JITTER = 0.2; + final static int CONNECTION_SETUP_TIMEOUT_EXP_BASE = 2; + final static double CONNECTION_SETUP_TIMEOUT_JITTER = 0.2; + private final Map nodeState; + private final Logger log; + private final HostResolver hostResolver; + private Set connectingNodes; + private ExponentialBackoff reconnectBackoff; + private ExponentialBackoff connectionSetupTimeout; + + public ClusterConnectionStates(long reconnectBackoffMs, long reconnectBackoffMaxMs, + long connectionSetupTimeoutMs, long connectionSetupTimeoutMaxMs, + LogContext logContext, HostResolver hostResolver) { + this.log = logContext.logger(ClusterConnectionStates.class); + this.reconnectBackoff = new ExponentialBackoff( + reconnectBackoffMs, + RECONNECT_BACKOFF_EXP_BASE, + reconnectBackoffMaxMs, + RECONNECT_BACKOFF_JITTER); + this.connectionSetupTimeout = new ExponentialBackoff( + connectionSetupTimeoutMs, + CONNECTION_SETUP_TIMEOUT_EXP_BASE, + connectionSetupTimeoutMaxMs, + CONNECTION_SETUP_TIMEOUT_JITTER); + this.nodeState = new HashMap<>(); + this.connectingNodes = new HashSet<>(); + this.hostResolver = hostResolver; + } + + /** + * Return true iff we can currently initiate a new connection. This will be the case if we are not + * connected and haven't been connected for at least the minimum reconnection backoff period. + * @param id the connection id to check + * @param now the current time in ms + * @return true if we can initiate a new connection + */ + public boolean canConnect(String id, long now) { + NodeConnectionState state = nodeState.get(id); + if (state == null) + return true; + else + return state.state.isDisconnected() && + now - state.lastConnectAttemptMs >= state.reconnectBackoffMs; + } + + /** + * Return true if we are disconnected from the given node and can't re-establish a connection yet. + * @param id the connection to check + * @param now the current time in ms + */ + public boolean isBlackedOut(String id, long now) { + NodeConnectionState state = nodeState.get(id); + return state != null + && state.state.isDisconnected() + && now - state.lastConnectAttemptMs < state.reconnectBackoffMs; + } + + /** + * Returns the number of milliseconds to wait, based on the connection state, before attempting to send data. When + * disconnected, this respects the reconnect backoff time. When connecting, return a delay based on the connection timeout. + * When connected, wait indefinitely (i.e. until a wakeup). + * @param id the connection to check + * @param now the current time in ms + */ + public long connectionDelay(String id, long now) { + NodeConnectionState state = nodeState.get(id); + if (state == null) return 0; + + if (state.state == ConnectionState.CONNECTING) { + return connectionSetupTimeoutMs(id); + } else if (state.state.isDisconnected()) { + long timeWaited = now - state.lastConnectAttemptMs; + return Math.max(state.reconnectBackoffMs - timeWaited, 0); + } else { + // When connected, we should be able to delay indefinitely since other events (connection or + // data acked) will cause a wakeup once data can be sent. 
+ return Long.MAX_VALUE; + } + } + + /** + * Return true if a specific connection establishment is currently underway + * @param id The id of the node to check + */ + public boolean isConnecting(String id) { + NodeConnectionState state = nodeState.get(id); + return state != null && state.state == ConnectionState.CONNECTING; + } + + /** + * Check whether a connection is either being established or awaiting API version information. + * @param id The id of the node to check + * @return true if the node is either connecting or has connected and is awaiting API versions, false otherwise + */ + public boolean isPreparingConnection(String id) { + NodeConnectionState state = nodeState.get(id); + return state != null && + (state.state == ConnectionState.CONNECTING || state.state == ConnectionState.CHECKING_API_VERSIONS); + } + + /** + * Enter the connecting state for the given connection, moving to a new resolved address if necessary. + * @param id the id of the connection + * @param now the current time in ms + * @param host the host of the connection, to be resolved internally if needed + */ + public void connecting(String id, long now, String host) { + NodeConnectionState connectionState = nodeState.get(id); + if (connectionState != null && connectionState.host().equals(host)) { + connectionState.lastConnectAttemptMs = now; + connectionState.state = ConnectionState.CONNECTING; + // Move to next resolved address, or if addresses are exhausted, mark node to be re-resolved + connectionState.moveToNextAddress(); + connectingNodes.add(id); + return; + } else if (connectionState != null) { + log.info("Hostname for node {} changed from {} to {}.", id, connectionState.host(), host); + } + + // Create a new NodeConnectionState if nodeState does not already contain one + // for the specified id or if the hostname associated with the node id changed. + nodeState.put(id, new NodeConnectionState(ConnectionState.CONNECTING, now, + reconnectBackoff.backoff(0), connectionSetupTimeout.backoff(0), host, hostResolver)); + connectingNodes.add(id); + } + + /** + * Returns a resolved address for the given connection, resolving it if necessary. + * @param id the id of the connection + * @throws UnknownHostException if the address was not resolvable + */ + public InetAddress currentAddress(String id) throws UnknownHostException { + return nodeState(id).currentAddress(); + } + + /** + * Enter the disconnected state for the given node. + * @param id the connection we have disconnected + * @param now the current time in ms + */ + public void disconnected(String id, long now) { + NodeConnectionState nodeState = nodeState(id); + nodeState.lastConnectAttemptMs = now; + updateReconnectBackoff(nodeState); + if (nodeState.state == ConnectionState.CONNECTING) { + updateConnectionSetupTimeout(nodeState); + connectingNodes.remove(id); + } else { + resetConnectionSetupTimeout(nodeState); + if (nodeState.state.isConnected()) { + // If a connection had previously been established, clear the addresses to trigger a new DNS resolution + // because the node IPs may have changed + nodeState.clearAddresses(); + } + } + nodeState.state = ConnectionState.DISCONNECTED; + } + + /** + * Indicate that the connection is throttled until the specified deadline. + * @param id the connection to be throttled + * @param throttleUntilTimeMs the throttle deadline in milliseconds + */ + public void throttle(String id, long throttleUntilTimeMs) { + NodeConnectionState state = nodeState.get(id); + // The throttle deadline should never regress. 
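+        // (i.e. the deadline is only ever extended; an earlier deadline than the one already recorded is ignored)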
+ if (state != null && state.throttleUntilTimeMs < throttleUntilTimeMs) { + state.throttleUntilTimeMs = throttleUntilTimeMs; + } + } + + /** + * Return the remaining throttling delay in milliseconds if throttling is in progress. Return 0, otherwise. + * @param id the connection to check + * @param now the current time in ms + */ + public long throttleDelayMs(String id, long now) { + NodeConnectionState state = nodeState.get(id); + if (state != null && state.throttleUntilTimeMs > now) { + return state.throttleUntilTimeMs - now; + } else { + return 0; + } + } + + /** + * Return the number of milliseconds to wait, based on the connection state and the throttle time, before + * attempting to send data. If the connection has been established but being throttled, return throttle delay. + * Otherwise, return connection delay. + * @param id the connection to check + * @param now the current time in ms + */ + public long pollDelayMs(String id, long now) { + long throttleDelayMs = throttleDelayMs(id, now); + if (isConnected(id) && throttleDelayMs > 0) { + return throttleDelayMs; + } else { + return connectionDelay(id, now); + } + } + + /** + * Enter the checking_api_versions state for the given node. + * @param id the connection identifier + */ + public void checkingApiVersions(String id) { + NodeConnectionState nodeState = nodeState(id); + nodeState.state = ConnectionState.CHECKING_API_VERSIONS; + resetReconnectBackoff(nodeState); + resetConnectionSetupTimeout(nodeState); + connectingNodes.remove(id); + } + + /** + * Enter the ready state for the given node. + * @param id the connection identifier + */ + public void ready(String id) { + NodeConnectionState nodeState = nodeState(id); + nodeState.state = ConnectionState.READY; + nodeState.authenticationException = null; + resetReconnectBackoff(nodeState); + resetConnectionSetupTimeout(nodeState); + connectingNodes.remove(id); + } + + /** + * Enter the authentication failed state for the given node. + * @param id the connection identifier + * @param now the current time in ms + * @param exception the authentication exception + */ + public void authenticationFailed(String id, long now, AuthenticationException exception) { + NodeConnectionState nodeState = nodeState(id); + nodeState.authenticationException = exception; + nodeState.state = ConnectionState.AUTHENTICATION_FAILED; + nodeState.lastConnectAttemptMs = now; + updateReconnectBackoff(nodeState); + } + + /** + * Return true if the connection is in the READY state and currently not throttled. + * + * @param id the connection identifier + * @param now the current time in ms + */ + public boolean isReady(String id, long now) { + return isReady(nodeState.get(id), now); + } + + private boolean isReady(NodeConnectionState state, long now) { + return state != null && state.state == ConnectionState.READY && state.throttleUntilTimeMs <= now; + } + + /** + * Return true if there is at least one node with connection in the READY state and not throttled. Returns false + * otherwise. 
+ * + * @param now the current time in ms + */ + public boolean hasReadyNodes(long now) { + for (Map.Entry entry : nodeState.entrySet()) { + if (isReady(entry.getValue(), now)) { + return true; + } + } + return false; + } + + /** + * Return true if the connection has been established + * @param id The id of the node to check + */ + public boolean isConnected(String id) { + NodeConnectionState state = nodeState.get(id); + return state != null && state.state.isConnected(); + } + + /** + * Return true if the connection has been disconnected + * @param id The id of the node to check + */ + public boolean isDisconnected(String id) { + NodeConnectionState state = nodeState.get(id); + return state != null && state.state.isDisconnected(); + } + + /** + * Return authentication exception if an authentication error occurred + * @param id The id of the node to check + */ + public AuthenticationException authenticationException(String id) { + NodeConnectionState state = nodeState.get(id); + return state != null ? state.authenticationException : null; + } + + /** + * Resets the failure count for a node and sets the reconnect backoff to the base + * value configured via reconnect.backoff.ms + * + * @param nodeState The node state object to update + */ + private void resetReconnectBackoff(NodeConnectionState nodeState) { + nodeState.failedAttempts = 0; + nodeState.reconnectBackoffMs = reconnectBackoff.backoff(0); + } + + /** + * Resets the failure count for a node and sets the connection setup timeout to the base + * value configured via socket.connection.setup.timeout.ms + * + * @param nodeState The node state object to update + */ + private void resetConnectionSetupTimeout(NodeConnectionState nodeState) { + nodeState.failedConnectAttempts = 0; + nodeState.connectionSetupTimeoutMs = connectionSetupTimeout.backoff(0); + } + + /** + * Increment the failure counter, update the node reconnect backoff exponentially, + * and record the current timestamp. + * The delay is reconnect.backoff.ms * 2**(failures - 1) * (+/- 20% random jitter) + * Up to a (pre-jitter) maximum of reconnect.backoff.max.ms + * + * @param nodeState The node state object to update + */ + private void updateReconnectBackoff(NodeConnectionState nodeState) { + nodeState.reconnectBackoffMs = reconnectBackoff.backoff(nodeState.failedAttempts); + nodeState.failedAttempts++; + } + + /** + * Increment the failure counter and update the node connection setup timeout exponentially. + * The delay is socket.connection.setup.timeout.ms * 2**(failures) * (+/- 20% random jitter) + * Up to a (pre-jitter) maximum of reconnect.backoff.max.ms + * + * @param nodeState The node state object to update + */ + private void updateConnectionSetupTimeout(NodeConnectionState nodeState) { + nodeState.failedConnectAttempts++; + nodeState.connectionSetupTimeoutMs = connectionSetupTimeout.backoff(nodeState.failedConnectAttempts); + } + + /** + * Remove the given node from the tracked connection states. The main difference between this and `disconnected` + * is the impact on `connectionDelay`: it will be 0 after this call whereas `reconnectBackoffMs` will be taken + * into account after `disconnected` is called. + * + * @param id the connection to remove + */ + public void remove(String id) { + nodeState.remove(id); + connectingNodes.remove(id); + } + + /** + * Get the state of a given connection. 
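+     * An entry must already exist for the connection; otherwise an {@link IllegalStateException} is thrown.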
+ * @param id the id of the connection + * @return the state of our connection + */ + public ConnectionState connectionState(String id) { + return nodeState(id).state; + } + + /** + * Get the state of a given node. + * @param id the connection to fetch the state for + */ + private NodeConnectionState nodeState(String id) { + NodeConnectionState state = this.nodeState.get(id); + if (state == null) + throw new IllegalStateException("No entry found for connection " + id); + return state; + } + + /** + * Get the id set of nodes which are in CONNECTING state + */ + // package private for testing only + Set connectingNodes() { + return this.connectingNodes; + } + + /** + * Get the timestamp of the latest connection attempt of a given node + * @param id the connection to fetch the state for + */ + public long lastConnectAttemptMs(String id) { + NodeConnectionState nodeState = this.nodeState.get(id); + return nodeState == null ? 0 : nodeState.lastConnectAttemptMs; + } + + /** + * Get the current socket connection setup timeout of the given node. + * The base value is defined via socket.connection.setup.timeout. + * @param id the connection to fetch the state for + */ + public long connectionSetupTimeoutMs(String id) { + NodeConnectionState nodeState = this.nodeState(id); + return nodeState.connectionSetupTimeoutMs; + } + + /** + * Test if the connection to the given node has reached its timeout + * @param id the connection to fetch the state for + * @param now the current time in ms + */ + public boolean isConnectionSetupTimeout(String id, long now) { + NodeConnectionState nodeState = this.nodeState(id); + if (nodeState.state != ConnectionState.CONNECTING) + throw new IllegalStateException("Node " + id + " is not in connecting state"); + return now - lastConnectAttemptMs(id) > connectionSetupTimeoutMs(id); + } + + /** + * Return the List of nodes whose connection setup has timed out. + * @param now the current time in ms + */ + public List nodesWithConnectionSetupTimeout(long now) { + return connectingNodes.stream() + .filter(id -> isConnectionSetupTimeout(id, now)) + .collect(Collectors.toList()); + } + + /** + * The state of our connection to a node. + */ + private static class NodeConnectionState { + + ConnectionState state; + AuthenticationException authenticationException; + long lastConnectAttemptMs; + long failedAttempts; + long failedConnectAttempts; + long reconnectBackoffMs; + long connectionSetupTimeoutMs; + // Connection is being throttled if current time < throttleUntilTimeMs. + long throttleUntilTimeMs; + private List addresses; + private int addressIndex; + private final String host; + private final HostResolver hostResolver; + + private NodeConnectionState(ConnectionState state, long lastConnectAttempt, long reconnectBackoffMs, + long connectionSetupTimeoutMs, String host, HostResolver hostResolver) { + this.state = state; + this.addresses = Collections.emptyList(); + this.addressIndex = -1; + this.authenticationException = null; + this.lastConnectAttemptMs = lastConnectAttempt; + this.failedAttempts = 0; + this.reconnectBackoffMs = reconnectBackoffMs; + this.connectionSetupTimeoutMs = connectionSetupTimeoutMs; + this.throttleUntilTimeMs = 0; + this.host = host; + this.hostResolver = hostResolver; + } + + public String host() { + return host; + } + + /** + * Fetches the current selected IP address for this node, resolving {@link #host()} if necessary. 
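+         * Addresses are cycled through via {@link #moveToNextAddress()}; once the resolved list is exhausted,
+         * the next call re-resolves {@link #host()}.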
+ * @return the selected address + * @throws UnknownHostException if resolving {@link #host()} fails + */ + private InetAddress currentAddress() throws UnknownHostException { + if (addresses.isEmpty()) { + // (Re-)initialize list + addresses = ClientUtils.resolve(host, hostResolver); + addressIndex = 0; + } + + return addresses.get(addressIndex); + } + + /** + * Jumps to the next available resolved address for this node. If no other addresses are available, marks the + * list to be refreshed on the next {@link #currentAddress()} call. + */ + private void moveToNextAddress() { + if (addresses.isEmpty()) + return; // Avoid div0. List will initialize on next currentAddress() call + + addressIndex = (addressIndex + 1) % addresses.size(); + if (addressIndex == 0) + addresses = Collections.emptyList(); // Exhausted list. Re-resolve on next currentAddress() call + } + + /** + * Clears the resolved addresses in order to trigger re-resolving on the next {@link #currentAddress()} call. + */ + private void clearAddresses() { + addresses = Collections.emptyList(); + } + + public String toString() { + return "NodeState(" + state + ", " + lastConnectAttemptMs + ", " + failedAttempts + ", " + throttleUntilTimeMs + ")"; + } + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/CommonClientConfigs.java b/clients/src/main/java/org/apache/kafka/clients/CommonClientConfigs.java new file mode 100644 index 0000000..58075d6 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/CommonClientConfigs.java @@ -0,0 +1,205 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients; + +import org.apache.kafka.common.config.AbstractConfig; +import org.apache.kafka.common.security.auth.SecurityProtocol; +import org.apache.kafka.common.utils.Utils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.HashMap; +import java.util.Map; + +/** + * Configurations shared by Kafka client applications: producer, consumer, connect, etc. + */ +public class CommonClientConfigs { + private static final Logger log = LoggerFactory.getLogger(CommonClientConfigs.class); + + /* + * NOTE: DO NOT CHANGE EITHER CONFIG NAMES AS THESE ARE PART OF THE PUBLIC API AND CHANGE WILL BREAK USER CODE. + */ + + public static final String BOOTSTRAP_SERVERS_CONFIG = "bootstrap.servers"; + public static final String BOOTSTRAP_SERVERS_DOC = "A list of host/port pairs to use for establishing the initial connection to the Kafka cluster. The client will make use of all servers irrespective of which servers are specified here for bootstrapping—this list only impacts the initial hosts used to discover the full set of servers. This list should be in the form " + + "host1:port1,host2:port2,.... 
Since these servers are just used for the initial connection to " + + "discover the full cluster membership (which may change dynamically), this list need not contain the full set of " + + "servers (you may want more than one, though, in case a server is down)."; + + public static final String CLIENT_DNS_LOOKUP_CONFIG = "client.dns.lookup"; + public static final String CLIENT_DNS_LOOKUP_DOC = "Controls how the client uses DNS lookups. " + + "If set to use_all_dns_ips, connect to each returned IP " + + "address in sequence until a successful connection is established. " + + "After a disconnection, the next IP is used. Once all IPs have been " + + "used once, the client resolves the IP(s) from the hostname again " + + "(both the JVM and the OS cache DNS name lookups, however). " + + "If set to resolve_canonical_bootstrap_servers_only, " + + "resolve each bootstrap address into a list of canonical names. After " + + "the bootstrap phase, this behaves the same as use_all_dns_ips."; + + public static final String METADATA_MAX_AGE_CONFIG = "metadata.max.age.ms"; + public static final String METADATA_MAX_AGE_DOC = "The period of time in milliseconds after which we force a refresh of metadata even if we haven't seen any partition leadership changes to proactively discover any new brokers or partitions."; + + public static final String SEND_BUFFER_CONFIG = "send.buffer.bytes"; + public static final String SEND_BUFFER_DOC = "The size of the TCP send buffer (SO_SNDBUF) to use when sending data. If the value is -1, the OS default will be used."; + public static final int SEND_BUFFER_LOWER_BOUND = -1; + + public static final String RECEIVE_BUFFER_CONFIG = "receive.buffer.bytes"; + public static final String RECEIVE_BUFFER_DOC = "The size of the TCP receive buffer (SO_RCVBUF) to use when reading data. If the value is -1, the OS default will be used."; + public static final int RECEIVE_BUFFER_LOWER_BOUND = -1; + + public static final String CLIENT_ID_CONFIG = "client.id"; + public static final String CLIENT_ID_DOC = "An id string to pass to the server when making requests. The purpose of this is to be able to track the source of requests beyond just ip/port by allowing a logical application name to be included in server-side request logging."; + + public static final String CLIENT_RACK_CONFIG = "client.rack"; + public static final String CLIENT_RACK_DOC = "A rack identifier for this client. This can be any string value which indicates where this client is physically located. It corresponds with the broker config 'broker.rack'"; + + public static final String RECONNECT_BACKOFF_MS_CONFIG = "reconnect.backoff.ms"; + public static final String RECONNECT_BACKOFF_MS_DOC = "The base amount of time to wait before attempting to reconnect to a given host. This avoids repeatedly connecting to a host in a tight loop. This backoff applies to all connection attempts by the client to a broker."; + + public static final String RECONNECT_BACKOFF_MAX_MS_CONFIG = "reconnect.backoff.max.ms"; + public static final String RECONNECT_BACKOFF_MAX_MS_DOC = "The maximum amount of time in milliseconds to wait when reconnecting to a broker that has repeatedly failed to connect. If provided, the backoff per host will increase exponentially for each consecutive connection failure, up to this maximum. 
After calculating the backoff increase, 20% random jitter is added to avoid connection storms."; + + public static final String RETRIES_CONFIG = "retries"; + public static final String RETRIES_DOC = "Setting a value greater than zero will cause the client to resend any request that fails with a potentially transient error." + + " It is recommended to set the value to either zero or `MAX_VALUE` and use corresponding timeout parameters to control how long a client should retry a request."; + + public static final String RETRY_BACKOFF_MS_CONFIG = "retry.backoff.ms"; + public static final String RETRY_BACKOFF_MS_DOC = "The amount of time to wait before attempting to retry a failed request to a given topic partition. This avoids repeatedly sending requests in a tight loop under some failure scenarios."; + + public static final String METRICS_SAMPLE_WINDOW_MS_CONFIG = "metrics.sample.window.ms"; + public static final String METRICS_SAMPLE_WINDOW_MS_DOC = "The window of time a metrics sample is computed over."; + + public static final String METRICS_NUM_SAMPLES_CONFIG = "metrics.num.samples"; + public static final String METRICS_NUM_SAMPLES_DOC = "The number of samples maintained to compute metrics."; + + public static final String METRICS_RECORDING_LEVEL_CONFIG = "metrics.recording.level"; + public static final String METRICS_RECORDING_LEVEL_DOC = "The highest recording level for metrics."; + + public static final String METRIC_REPORTER_CLASSES_CONFIG = "metric.reporters"; + public static final String METRIC_REPORTER_CLASSES_DOC = "A list of classes to use as metrics reporters. Implementing the org.apache.kafka.common.metrics.MetricsReporter interface allows plugging in classes that will be notified of new metric creation. The JmxReporter is always included to register JMX statistics."; + + public static final String METRICS_CONTEXT_PREFIX = "metrics.context."; + + public static final String SECURITY_PROTOCOL_CONFIG = "security.protocol"; + public static final String SECURITY_PROTOCOL_DOC = "Protocol used to communicate with brokers. Valid values are: " + + Utils.join(SecurityProtocol.names(), ", ") + "."; + public static final String DEFAULT_SECURITY_PROTOCOL = "PLAINTEXT"; + + public static final String SOCKET_CONNECTION_SETUP_TIMEOUT_MS_CONFIG = "socket.connection.setup.timeout.ms"; + public static final String SOCKET_CONNECTION_SETUP_TIMEOUT_MS_DOC = "The amount of time the client will wait for the socket connection to be established. If the connection is not built before the timeout elapses, clients will close the socket channel."; + public static final Long DEFAULT_SOCKET_CONNECTION_SETUP_TIMEOUT_MS = 10 * 1000L; + + public static final String SOCKET_CONNECTION_SETUP_TIMEOUT_MAX_MS_CONFIG = "socket.connection.setup.timeout.max.ms"; + public static final String SOCKET_CONNECTION_SETUP_TIMEOUT_MAX_MS_DOC = "The maximum amount of time the client will wait for the socket connection to be established. The connection setup timeout will increase exponentially for each consecutive connection failure up to this maximum. 
To avoid connection storms, a randomization factor of 0.2 will be applied to the timeout resulting in a random range between 20% below and 20% above the computed value."; + public static final Long DEFAULT_SOCKET_CONNECTION_SETUP_TIMEOUT_MAX_MS = 30 * 1000L; + + public static final String CONNECTIONS_MAX_IDLE_MS_CONFIG = "connections.max.idle.ms"; + public static final String CONNECTIONS_MAX_IDLE_MS_DOC = "Close idle connections after the number of milliseconds specified by this config."; + + public static final String REQUEST_TIMEOUT_MS_CONFIG = "request.timeout.ms"; + public static final String REQUEST_TIMEOUT_MS_DOC = "The configuration controls the maximum amount of time the client will wait " + + "for the response of a request. If the response is not received before the timeout " + + "elapses the client will resend the request if necessary or fail the request if " + + "retries are exhausted."; + + public static final String DEFAULT_LIST_KEY_SERDE_INNER_CLASS = "default.list.key.serde.inner"; + public static final String DEFAULT_LIST_KEY_SERDE_INNER_CLASS_DOC = "Default inner class of list serde for key that implements the org.apache.kafka.common.serialization.Serde interface. " + + "This configuration will be read if and only if default.key.serde configuration is set to org.apache.kafka.common.serialization.Serdes.ListSerde"; + + public static final String DEFAULT_LIST_VALUE_SERDE_INNER_CLASS = "default.list.value.serde.inner"; + public static final String DEFAULT_LIST_VALUE_SERDE_INNER_CLASS_DOC = "Default inner class of list serde for value that implements the org.apache.kafka.common.serialization.Serde interface. " + + "This configuration will be read if and only if default.value.serde configuration is set to org.apache.kafka.common.serialization.Serdes.ListSerde"; + + public static final String DEFAULT_LIST_KEY_SERDE_TYPE_CLASS = "default.list.key.serde.type"; + public static final String DEFAULT_LIST_KEY_SERDE_TYPE_CLASS_DOC = "Default class for key that implements the java.util.List interface. " + + "This configuration will be read if and only if default.key.serde configuration is set to org.apache.kafka.common.serialization.Serdes.ListSerde " + + "Note when list serde class is used, one needs to set the inner serde class that implements the org.apache.kafka.common.serialization.Serde interface via '" + + DEFAULT_LIST_KEY_SERDE_INNER_CLASS + "'"; + + public static final String DEFAULT_LIST_VALUE_SERDE_TYPE_CLASS = "default.list.value.serde.type"; + public static final String DEFAULT_LIST_VALUE_SERDE_TYPE_CLASS_DOC = "Default class for value that implements the java.util.List interface. " + + "This configuration will be read if and only if default.value.serde configuration is set to org.apache.kafka.common.serialization.Serdes.ListSerde " + + "Note when list serde class is used, one needs to set the inner serde class that implements the org.apache.kafka.common.serialization.Serde interface via '" + + DEFAULT_LIST_VALUE_SERDE_INNER_CLASS + "'"; + + public static final String GROUP_ID_CONFIG = "group.id"; + public static final String GROUP_ID_DOC = "A unique string that identifies the consumer group this consumer belongs to. 
This property is required if the consumer uses either the group management functionality by using subscribe(topic) or the Kafka-based offset management strategy."; + + public static final String GROUP_INSTANCE_ID_CONFIG = "group.instance.id"; + public static final String GROUP_INSTANCE_ID_DOC = "A unique identifier of the consumer instance provided by the end user. " + + "Only non-empty strings are permitted. If set, the consumer is treated as a static member, " + + "which means that only one instance with this ID is allowed in the consumer group at any time. " + + "This can be used in combination with a larger session timeout to avoid group rebalances caused by transient unavailability " + + "(e.g. process restarts). If not set, the consumer will join the group as a dynamic member, which is the traditional behavior."; + + public static final String MAX_POLL_INTERVAL_MS_CONFIG = "max.poll.interval.ms"; + public static final String MAX_POLL_INTERVAL_MS_DOC = "The maximum delay between invocations of poll() when using " + + "consumer group management. This places an upper bound on the amount of time that the consumer can be idle " + + "before fetching more records. If poll() is not called before expiration of this timeout, then the consumer " + + "is considered failed and the group will rebalance in order to reassign the partitions to another member. " + + "For consumers using a non-null group.instance.id which reach this timeout, partitions will not be immediately reassigned. " + + "Instead, the consumer will stop sending heartbeats and partitions will be reassigned " + + "after expiration of session.timeout.ms. This mirrors the behavior of a static consumer which has shutdown."; + + public static final String REBALANCE_TIMEOUT_MS_CONFIG = "rebalance.timeout.ms"; + public static final String REBALANCE_TIMEOUT_MS_DOC = "The maximum allowed time for each worker to join the group " + + "once a rebalance has begun. This is basically a limit on the amount of time needed for all tasks to " + + "flush any pending data and commit offsets. If the timeout is exceeded, then the worker will be removed " + + "from the group, which will cause offset commit failures."; + + public static final String SESSION_TIMEOUT_MS_CONFIG = "session.timeout.ms"; + public static final String SESSION_TIMEOUT_MS_DOC = "The timeout used to detect client failures when using " + + "Kafka's group management facility. The client sends periodic heartbeats to indicate its liveness " + + "to the broker. If no heartbeats are received by the broker before the expiration of this session timeout, " + + "then the broker will remove this client from the group and initiate a rebalance. Note that the value " + + "must be in the allowable range as configured in the broker configuration by group.min.session.timeout.ms " + + "and group.max.session.timeout.ms."; + + public static final String HEARTBEAT_INTERVAL_MS_CONFIG = "heartbeat.interval.ms"; + public static final String HEARTBEAT_INTERVAL_MS_DOC = "The expected time between heartbeats to the consumer " + + "coordinator when using Kafka's group management facilities. Heartbeats are used to ensure that the " + + "consumer's session stays active and to facilitate rebalancing when new consumers join or leave the group. " + + "The value must be set lower than session.timeout.ms, but typically should be set no higher " + + "than 1/3 of that value. 
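As a rough illustration of how the group timing settings documented above fit together, the snippet below builds a plain Properties object with a session timeout, a heartbeat interval of about one third of it, an upper bound on the poll interval, and an optional static membership id. The values are illustrative, not recommendations.

import java.util.Properties;

public class GroupTimingConfigSketch {
    public static Properties consumerGroupTimings() {
        Properties props = new Properties();
        props.put("group.id", "example-group");        // identifies the consumer group
        props.put("session.timeout.ms", "30000");      // broker evicts the member after this much silence
        props.put("heartbeat.interval.ms", "10000");   // about 1/3 of session.timeout.ms, as documented
        props.put("max.poll.interval.ms", "300000");   // upper bound between poll() invocations
        // Optional static membership: a stable, non-empty id avoids rebalances on restart.
        props.put("group.instance.id", "example-instance-1");
        return props;
    }
}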
It can be adjusted even lower to control the expected time for normal rebalances."; + + public static final String DEFAULT_API_TIMEOUT_MS_CONFIG = "default.api.timeout.ms"; + public static final String DEFAULT_API_TIMEOUT_MS_DOC = "Specifies the timeout (in milliseconds) for client APIs. " + + "This configuration is used as the default timeout for all client operations that do not specify a timeout parameter."; + + /** + * Postprocess the configuration so that exponential backoff is disabled when reconnect backoff + * is explicitly configured but the maximum reconnect backoff is not explicitly configured. + * + * @param config The config object. + * @param parsedValues The parsedValues as provided to postProcessParsedConfig. + * + * @return The new values which have been set as described in postProcessParsedConfig. + */ + public static Map postProcessReconnectBackoffConfigs(AbstractConfig config, + Map parsedValues) { + HashMap rval = new HashMap<>(); + if ((!config.originals().containsKey(RECONNECT_BACKOFF_MAX_MS_CONFIG)) && + config.originals().containsKey(RECONNECT_BACKOFF_MS_CONFIG)) { + log.debug("Disabling exponential reconnect backoff because {} is set, but {} is not.", + RECONNECT_BACKOFF_MS_CONFIG, RECONNECT_BACKOFF_MAX_MS_CONFIG); + rval.put(RECONNECT_BACKOFF_MAX_MS_CONFIG, parsedValues.get(RECONNECT_BACKOFF_MS_CONFIG)); + } + return rval; + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/ConnectionState.java b/clients/src/main/java/org/apache/kafka/clients/ConnectionState.java new file mode 100644 index 0000000..f92c7fa --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/ConnectionState.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients; + +/** + * The states of a node connection + * + * DISCONNECTED: connection has not been successfully established yet + * CONNECTING: connection is under progress + * CHECKING_API_VERSIONS: connection has been established and api versions check is in progress. 
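The reconnect-backoff post-processing shown a few lines above has a simple effect: if reconnect.backoff.ms is supplied without reconnect.backoff.max.ms, the maximum is pinned to the base value, so the backoff never grows. The standalone sketch below mirrors that rule on plain maps, purely to illustrate the behaviour.

import java.util.HashMap;
import java.util.Map;

public class ReconnectBackoffRuleSketch {
    // Mirrors the documented rule: when only the base backoff is set explicitly,
    // copy it into the max so exponential growth is effectively disabled.
    static Map<String, Object> pinMaxToBase(Map<String, Object> originals, Map<String, Object> parsed) {
        Map<String, Object> overrides = new HashMap<>();
        if (originals.containsKey("reconnect.backoff.ms") && !originals.containsKey("reconnect.backoff.max.ms"))
            overrides.put("reconnect.backoff.max.ms", parsed.get("reconnect.backoff.ms"));
        return overrides;
    }

    public static void main(String[] args) {
        Map<String, Object> originals = new HashMap<>();
        originals.put("reconnect.backoff.ms", 500L);
        System.out.println(pinMaxToBase(originals, originals)); // {reconnect.backoff.max.ms=500}
    }
}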
Failure of this check will cause connection to close + * READY: connection is ready to send requests + * AUTHENTICATION_FAILED: connection failed due to an authentication error + */ +public enum ConnectionState { + DISCONNECTED, CONNECTING, CHECKING_API_VERSIONS, READY, AUTHENTICATION_FAILED; + + public boolean isDisconnected() { + return this == AUTHENTICATION_FAILED || this == DISCONNECTED; + } + + public boolean isConnected() { + return this == CHECKING_API_VERSIONS || this == READY; + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/DefaultHostResolver.java b/clients/src/main/java/org/apache/kafka/clients/DefaultHostResolver.java new file mode 100644 index 0000000..786173e --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/DefaultHostResolver.java @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.clients; + +import java.net.InetAddress; +import java.net.UnknownHostException; + +public class DefaultHostResolver implements HostResolver { + + @Override + public InetAddress[] resolve(String host) throws UnknownHostException { + return InetAddress.getAllByName(host); + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/FetchSessionHandler.java b/clients/src/main/java/org/apache/kafka/clients/FetchSessionHandler.java new file mode 100644 index 0000000..aca847c --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/FetchSessionHandler.java @@ -0,0 +1,605 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
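DefaultHostResolver above simply delegates to InetAddress.getAllByName; the HostResolver seam it implements (the interface appears later in this diff) exists so that alternative resolution strategies can be plugged in. Below is a hedged sketch of a resolver that prefers IPv4 addresses; the filtering policy is purely illustrative.

import java.net.Inet4Address;
import java.net.InetAddress;
import java.net.UnknownHostException;
import java.util.Arrays;

import org.apache.kafka.clients.HostResolver;

public class Ipv4OnlyHostResolver implements HostResolver {
    // Illustrative policy: resolve the host, keep IPv4 addresses only,
    // and fall back to the full list if nothing matches.
    @Override
    public InetAddress[] resolve(String host) throws UnknownHostException {
        InetAddress[] all = InetAddress.getAllByName(host);
        InetAddress[] v4 = Arrays.stream(all)
                .filter(a -> a instanceof Inet4Address)
                .toArray(InetAddress[]::new);
        return v4.length > 0 ? v4 : all;
    }
}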
+ */ + +package org.apache.kafka.clients; + +import org.apache.kafka.common.TopicIdPartition; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.Uuid; +import org.apache.kafka.common.protocol.Errors; +import org.apache.kafka.common.requests.FetchMetadata; +import org.apache.kafka.common.requests.FetchRequest.PartitionData; +import org.apache.kafka.common.requests.FetchResponse; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.common.utils.Utils; +import org.slf4j.Logger; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Iterator; +import java.util.LinkedHashMap; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; +import java.util.Set; + +import static org.apache.kafka.common.requests.FetchMetadata.INVALID_SESSION_ID; + +/** + * FetchSessionHandler maintains the fetch session state for connecting to a broker. + * + * Using the protocol outlined by KIP-227, clients can create incremental fetch sessions. + * These sessions allow the client to fetch information about a set of partition over + * and over, without explicitly enumerating all the partitions in the request and the + * response. + * + * FetchSessionHandler tracks the partitions which are in the session. It also + * determines which partitions need to be included in each fetch request, and what + * the attached fetch session metadata should be for each request. The corresponding + * class on the receiving broker side is FetchManager. + */ +public class FetchSessionHandler { + private final Logger log; + + private final int node; + + /** + * The metadata for the next fetch request. + */ + private FetchMetadata nextMetadata = FetchMetadata.INITIAL; + + public FetchSessionHandler(LogContext logContext, int node) { + this.log = logContext.logger(FetchSessionHandler.class); + this.node = node; + } + + /** + * All of the partitions which exist in the fetch request session. + */ + private LinkedHashMap sessionPartitions = + new LinkedHashMap<>(0); + + /** + * All of the topic names mapped to topic ids for topics which exist in the fetch request session. + */ + private Map sessionTopicNames = new HashMap<>(0); + + public Map sessionTopicNames() { + return sessionTopicNames; + } + + public static class FetchRequestData { + /** + * The partitions to send in the fetch request. + */ + private final Map toSend; + + /** + * The partitions to send in the request's "forget" list. + */ + private final List toForget; + + /** + * The partitions to send in the request's "forget" list if + * the version is >= 13. + */ + private final List toReplace; + + /** + * All of the partitions which exist in the fetch request session. + */ + private final Map sessionPartitions; + + /** + * The metadata to use in this fetch request. 
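To make the FetchSessionHandler flow described above concrete, here is a hedged sketch of one fetch round: build the request data for a broker, hand it to whatever sends the FetchRequest, then feed the response back so the session metadata advances. The PartitionData constructor arguments are assumptions based on this branch, and the network send itself is elided.

import java.util.Optional;

import org.apache.kafka.clients.FetchSessionHandler;
import org.apache.kafka.common.TopicPartition;
import org.apache.kafka.common.Uuid;
import org.apache.kafka.common.requests.FetchRequest;
import org.apache.kafka.common.requests.FetchResponse;
import org.apache.kafka.common.utils.LogContext;

public class FetchSessionSketch {
    // One fetch round against broker 1; constructor arguments for PartitionData are assumed.
    static void fetchOnce(FetchResponse response, short version, Uuid topicId) {
        FetchSessionHandler handler = new FetchSessionHandler(new LogContext(), 1);

        FetchSessionHandler.Builder builder = handler.newBuilder();
        builder.add(new TopicPartition("events", 0),
                new FetchRequest.PartitionData(topicId, 0L, -1L, 1_048_576, Optional.empty()));
        FetchSessionHandler.FetchRequestData data = builder.build();

        // toSend(), toForget(), toReplace() and metadata() would be copied into the
        // FetchRequest that actually goes over the wire (not shown here).
        System.out.println("sending " + data.toSend().size() + " partition(s) with " + data.metadata());

        handler.handleResponse(response, version);
        // If the send itself had failed, handler.handleError(t) would instead force a
        // new session on the next request.
    }
}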
+ */ + private final FetchMetadata metadata; + + /** + * A boolean indicating whether we have a topic ID for every topic in the request so that we can send a request that + * uses topic IDs + */ + private final boolean canUseTopicIds; + + FetchRequestData(Map toSend, + List toForget, + List toReplace, + Map sessionPartitions, + FetchMetadata metadata, + boolean canUseTopicIds) { + this.toSend = toSend; + this.toForget = toForget; + this.toReplace = toReplace; + this.sessionPartitions = sessionPartitions; + this.metadata = metadata; + this.canUseTopicIds = canUseTopicIds; + } + + /** + * Get the set of partitions to send in this fetch request. + */ + public Map toSend() { + return toSend; + } + + /** + * Get a list of partitions to forget in this fetch request. + */ + public List toForget() { + return toForget; + } + + /** + * Get a list of partitions to forget in this fetch request. + */ + public List toReplace() { + return toReplace; + } + + /** + * Get the full set of partitions involved in this fetch request. + */ + public Map sessionPartitions() { + return sessionPartitions; + } + + public FetchMetadata metadata() { + return metadata; + } + + public boolean canUseTopicIds() { + return canUseTopicIds; + } + + @Override + public String toString() { + StringBuilder bld; + if (metadata.isFull()) { + bld = new StringBuilder("FullFetchRequest(toSend=("); + String prefix = ""; + for (TopicPartition partition : toSend.keySet()) { + bld.append(prefix); + bld.append(partition); + prefix = ", "; + } + } else { + bld = new StringBuilder("IncrementalFetchRequest(toSend=("); + String prefix = ""; + for (TopicPartition partition : toSend.keySet()) { + bld.append(prefix); + bld.append(partition); + prefix = ", "; + } + bld.append("), toForget=("); + prefix = ""; + for (TopicIdPartition partition : toForget) { + bld.append(prefix); + bld.append(partition); + prefix = ", "; + } + bld.append("), toReplace=("); + prefix = ""; + for (TopicIdPartition partition : toReplace) { + bld.append(prefix); + bld.append(partition); + prefix = ", "; + } + bld.append("), implied=("); + prefix = ""; + for (TopicPartition partition : sessionPartitions.keySet()) { + if (!toSend.containsKey(partition)) { + bld.append(prefix); + bld.append(partition); + prefix = ", "; + } + } + } + if (canUseTopicIds) { + bld.append("), canUseTopicIds=True"); + } else { + bld.append("), canUseTopicIds=False"); + } + bld.append(")"); + return bld.toString(); + } + } + + public class Builder { + /** + * The next partitions which we want to fetch. + * + * It is important to maintain the insertion order of this list by using a LinkedHashMap rather + * than a regular Map. + * + * One reason is that when dealing with FULL fetch requests, if there is not enough response + * space to return data from all partitions, the server will only return data from partitions + * early in this list. + * + * Another reason is because we make use of the list ordering to optimize the preparation of + * incremental fetch requests (see below). 
+ */ + private LinkedHashMap next; + private Map topicNames; + private final boolean copySessionPartitions; + private int partitionsWithoutTopicIds = 0; + + Builder() { + this.next = new LinkedHashMap<>(); + this.topicNames = new HashMap<>(); + this.copySessionPartitions = true; + } + + Builder(int initialSize, boolean copySessionPartitions) { + this.next = new LinkedHashMap<>(initialSize); + this.topicNames = new HashMap<>(); + this.copySessionPartitions = copySessionPartitions; + } + + /** + * Mark that we want data from this partition in the upcoming fetch. + */ + public void add(TopicPartition topicPartition, PartitionData data) { + next.put(topicPartition, data); + // topicIds should not change between adding partitions and building, so we can use putIfAbsent + if (data.topicId.equals(Uuid.ZERO_UUID)) { + partitionsWithoutTopicIds++; + } else { + topicNames.putIfAbsent(data.topicId, topicPartition.topic()); + } + } + + public FetchRequestData build() { + boolean canUseTopicIds = partitionsWithoutTopicIds == 0; + + if (nextMetadata.isFull()) { + if (log.isDebugEnabled()) { + log.debug("Built full fetch {} for node {} with {}.", + nextMetadata, node, topicPartitionsToLogString(next.keySet())); + } + sessionPartitions = next; + next = null; + // Only add topic IDs to the session if we are using topic IDs. + if (canUseTopicIds) { + sessionTopicNames = topicNames; + } else { + sessionTopicNames = Collections.emptyMap(); + } + Map toSend = + Collections.unmodifiableMap(new LinkedHashMap<>(sessionPartitions)); + return new FetchRequestData(toSend, Collections.emptyList(), Collections.emptyList(), toSend, nextMetadata, canUseTopicIds); + } + + List added = new ArrayList<>(); + List removed = new ArrayList<>(); + List altered = new ArrayList<>(); + List replaced = new ArrayList<>(); + for (Iterator> iter = + sessionPartitions.entrySet().iterator(); iter.hasNext(); ) { + Entry entry = iter.next(); + TopicPartition topicPartition = entry.getKey(); + PartitionData prevData = entry.getValue(); + PartitionData nextData = next.remove(topicPartition); + if (nextData != null) { + // We basically check if the new partition had the same topic ID. If not, + // we add it to the "replaced" set. If the request is version 13 or higher, the replaced + // partition will be forgotten. In any case, we will send the new partition in the request. + if (!prevData.topicId.equals(nextData.topicId) + && !prevData.topicId.equals(Uuid.ZERO_UUID) + && !nextData.topicId.equals(Uuid.ZERO_UUID)) { + // Re-add the replaced partition to the end of 'next' + next.put(topicPartition, nextData); + entry.setValue(nextData); + replaced.add(new TopicIdPartition(prevData.topicId, topicPartition)); + } else if (!prevData.equals(nextData)) { + // Re-add the altered partition to the end of 'next' + next.put(topicPartition, nextData); + entry.setValue(nextData); + altered.add(new TopicIdPartition(nextData.topicId, topicPartition)); + } + } else { + // Remove this partition from the session. + iter.remove(); + // Indicate that we no longer want to listen to this partition. + removed.add(new TopicIdPartition(prevData.topicId, topicPartition)); + // If we do not have this topic ID in the builder or the session, we can not use topic IDs. + if (canUseTopicIds && prevData.topicId.equals(Uuid.ZERO_UUID)) + canUseTopicIds = false; + } + } + // Add any new partitions to the session. 
+ for (Entry entry : next.entrySet()) { + TopicPartition topicPartition = entry.getKey(); + PartitionData nextData = entry.getValue(); + if (sessionPartitions.containsKey(topicPartition)) { + // In the previous loop, all the partitions which existed in both sessionPartitions + // and next were moved to the end of next, or removed from next. Therefore, + // once we hit one of them, we know there are no more unseen entries to look + // at in next. + break; + } + sessionPartitions.put(topicPartition, nextData); + added.add(new TopicIdPartition(nextData.topicId, topicPartition)); + } + + // Add topic IDs to session if we can use them. If an ID is inconsistent, we will handle in the receiving broker. + // If we switched from using topic IDs to not using them (or vice versa), that error will also be handled in the receiving broker. + if (canUseTopicIds) { + sessionTopicNames = topicNames; + } else { + sessionTopicNames = Collections.emptyMap(); + } + + if (log.isDebugEnabled()) { + log.debug("Built incremental fetch {} for node {}. Added {}, altered {}, removed {}, " + + "replaced {} out of {}", nextMetadata, node, topicIdPartitionsToLogString(added), + topicIdPartitionsToLogString(altered), topicIdPartitionsToLogString(removed), + topicIdPartitionsToLogString(replaced), topicPartitionsToLogString(sessionPartitions.keySet())); + } + Map toSend = Collections.unmodifiableMap(next); + Map curSessionPartitions = copySessionPartitions + ? Collections.unmodifiableMap(new LinkedHashMap<>(sessionPartitions)) + : Collections.unmodifiableMap(sessionPartitions); + next = null; + return new FetchRequestData(toSend, + Collections.unmodifiableList(removed), + Collections.unmodifiableList(replaced), + curSessionPartitions, + nextMetadata, + canUseTopicIds); + } + } + + public Builder newBuilder() { + return new Builder(); + } + + + /** A builder that allows for presizing the PartitionData hashmap, and avoiding making a + * secondary copy of the sessionPartitions, in cases where this is not necessarily. + * This builder is primarily for use by the Replica Fetcher + * @param size the initial size of the PartitionData hashmap + * @param copySessionPartitions boolean denoting whether the builder should make a deep copy of + * session partitions + */ + public Builder newBuilder(int size, boolean copySessionPartitions) { + return new Builder(size, copySessionPartitions); + } + + private String topicPartitionsToLogString(Collection partitions) { + if (!log.isTraceEnabled()) { + return String.format("%d partition(s)", partitions.size()); + } + return "(" + Utils.join(partitions, ", ") + ")"; + } + + private String topicIdPartitionsToLogString(Collection partitions) { + if (!log.isTraceEnabled()) { + return String.format("%d partition(s)", partitions.size()); + } + return "(" + Utils.join(partitions, ", ") + ")"; + } + + /** + * Return missing items which are expected to be in a particular set, but which are not. + * + * @param toFind The items to look for. + * @param toSearch The set of items to search. + * @return null if all items were found; some of the missing ones in a set, if not. + */ + static Set findMissing(Set toFind, Set toSearch) { + Set ret = new LinkedHashSet<>(); + for (T toFindItem: toFind) { + if (!toSearch.contains(toFindItem)) { + ret.add(toFindItem); + } + } + return ret; + } + + /** + * Verify that a full fetch response contains all the partitions in the fetch session. + * + * @param topicPartitions The topicPartitions from the FetchResponse. + * @param ids The topic IDs from the FetchResponse. 
+ * @param version The version of the FetchResponse. + * @return True if the full fetch response partitions are valid. + */ + String verifyFullFetchResponsePartitions(Set topicPartitions, Set ids, short version) { + StringBuilder bld = new StringBuilder(); + Set extra = + findMissing(topicPartitions, sessionPartitions.keySet()); + Set omitted = + findMissing(sessionPartitions.keySet(), topicPartitions); + Set extraIds = new HashSet<>(); + if (version >= 13) { + extraIds = findMissing(ids, sessionTopicNames.keySet()); + } + if (!omitted.isEmpty()) { + bld.append("omittedPartitions=(").append(Utils.join(omitted, ", ")).append(", "); + } + if (!extra.isEmpty()) { + bld.append("extraPartitions=(").append(Utils.join(extra, ", ")).append(", "); + } + if (!extraIds.isEmpty()) { + bld.append("extraIds=(").append(Utils.join(extraIds, ", ")).append(", "); + } + if ((!omitted.isEmpty()) || (!extra.isEmpty()) || (!extraIds.isEmpty())) { + bld.append("response=(").append(Utils.join(topicPartitions, ", ")).append(")"); + return bld.toString(); + } + return null; + } + + /** + * Verify that the partitions in an incremental fetch response are contained in the session. + * + * @param topicPartitions The topicPartitions from the FetchResponse. + * @param ids The topic IDs from the FetchResponse. + * @param version The version of the FetchResponse. + * @return True if the incremental fetch response partitions are valid. + */ + String verifyIncrementalFetchResponsePartitions(Set topicPartitions, Set ids, short version) { + Set extraIds = new HashSet<>(); + if (version >= 13) { + extraIds = findMissing(ids, sessionTopicNames.keySet()); + } + Set extra = + findMissing(topicPartitions, sessionPartitions.keySet()); + StringBuilder bld = new StringBuilder(); + if (extra.isEmpty()) + bld.append("extraPartitions=(").append(Utils.join(extra, ", ")).append("), "); + if (extraIds.isEmpty()) + bld.append("extraIds=(").append(Utils.join(extraIds, ", ")).append("), "); + if ((!extra.isEmpty()) || (!extraIds.isEmpty())) { + bld.append("response=(").append(Utils.join(topicPartitions, ", ")).append(")"); + return bld.toString(); + } + return null; + } + + /** + * Create a string describing the partitions in a FetchResponse. + * + * @param topicPartitions The topicPartitions from the FetchResponse. + * @return The string to log. + */ + private String responseDataToLogString(Set topicPartitions) { + if (!log.isTraceEnabled()) { + int implied = sessionPartitions.size() - topicPartitions.size(); + if (implied > 0) { + return String.format(" with %d response partition(s), %d implied partition(s)", + topicPartitions.size(), implied); + } else { + return String.format(" with %d response partition(s)", + topicPartitions.size()); + } + } + StringBuilder bld = new StringBuilder(); + bld.append(" with response=("). + append(Utils.join(topicPartitions, ", ")). + append(")"); + String prefix = ", implied=("; + String suffix = ""; + for (TopicPartition partition : sessionPartitions.keySet()) { + if (!topicPartitions.contains(partition)) { + bld.append(prefix); + bld.append(partition); + prefix = ", "; + suffix = ")"; + } + } + bld.append(suffix); + return bld.toString(); + } + + /** + * Handle the fetch response. + * + * @param response The response. + * @param version The version of the request. + * @return True if the response is well-formed; false if it can't be processed + * because of missing or unexpected partitions. 
+ */ + public boolean handleResponse(FetchResponse response, short version) { + if (response.error() != Errors.NONE) { + log.info("Node {} was unable to process the fetch request with {}: {}.", + node, nextMetadata, response.error()); + if (response.error() == Errors.FETCH_SESSION_ID_NOT_FOUND) { + nextMetadata = FetchMetadata.INITIAL; + } else { + nextMetadata = nextMetadata.nextCloseExisting(); + } + return false; + } + Set topicPartitions = response.responseData(sessionTopicNames, version).keySet(); + if (nextMetadata.isFull()) { + if (topicPartitions.isEmpty() && response.throttleTimeMs() > 0) { + // Normally, an empty full fetch response would be invalid. However, KIP-219 + // specifies that if the broker wants to throttle the client, it will respond + // to a full fetch request with an empty response and a throttleTimeMs + // value set. We don't want to log this with a warning, since it's not an error. + // However, the empty full fetch response can't be processed, so it's still appropriate + // to return false here. + if (log.isDebugEnabled()) { + log.debug("Node {} sent a empty full fetch response to indicate that this " + + "client should be throttled for {} ms.", node, response.throttleTimeMs()); + } + nextMetadata = FetchMetadata.INITIAL; + return false; + } + String problem = verifyFullFetchResponsePartitions(topicPartitions, response.topicIds(), version); + if (problem != null) { + log.info("Node {} sent an invalid full fetch response with {}", node, problem); + nextMetadata = FetchMetadata.INITIAL; + return false; + } else if (response.sessionId() == INVALID_SESSION_ID) { + if (log.isDebugEnabled()) + log.debug("Node {} sent a full fetch response{}", node, responseDataToLogString(topicPartitions)); + nextMetadata = FetchMetadata.INITIAL; + return true; + } else { + // The server created a new incremental fetch session. + if (log.isDebugEnabled()) + log.debug("Node {} sent a full fetch response that created a new incremental " + + "fetch session {}{}", node, response.sessionId(), responseDataToLogString(topicPartitions)); + nextMetadata = FetchMetadata.newIncremental(response.sessionId()); + return true; + } + } else { + String problem = verifyIncrementalFetchResponsePartitions(topicPartitions, response.topicIds(), version); + if (problem != null) { + log.info("Node {} sent an invalid incremental fetch response with {}", node, problem); + nextMetadata = nextMetadata.nextCloseExisting(); + return false; + } else if (response.sessionId() == INVALID_SESSION_ID) { + // The incremental fetch session was closed by the server. + if (log.isDebugEnabled()) + log.debug("Node {} sent an incremental fetch response closing session {}{}", + node, nextMetadata.sessionId(), responseDataToLogString(topicPartitions)); + nextMetadata = FetchMetadata.INITIAL; + return true; + } else { + // The incremental fetch session was continued by the server. + // We don't have to do anything special here to support KIP-219, since an empty incremental + // fetch request is perfectly valid. + if (log.isDebugEnabled()) + log.debug("Node {} sent an incremental fetch response with throttleTimeMs = {} " + + "for session {}{}", node, response.throttleTimeMs(), response.sessionId(), + responseDataToLogString(topicPartitions)); + nextMetadata = nextMetadata.nextIncremental(); + return true; + } + } + } + + /** + * Handle an error sending the prepared request. + * + * When a network error occurs, we close any existing fetch session on our next request, + * and try to create a new session. 
+ * + * @param t The exception. + */ + public void handleError(Throwable t) { + log.info("Error sending fetch request {} to node {}:", nextMetadata, node, t); + nextMetadata = nextMetadata.nextCloseExisting(); + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/GroupRebalanceConfig.java b/clients/src/main/java/org/apache/kafka/clients/GroupRebalanceConfig.java new file mode 100644 index 0000000..006800a --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/GroupRebalanceConfig.java @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients; + +import org.apache.kafka.common.config.AbstractConfig; +import org.apache.kafka.common.requests.JoinGroupRequest; + +import java.util.Locale; +import java.util.Optional; + +/** + * Class to extract group rebalance related configs. + */ +public class GroupRebalanceConfig { + + public enum ProtocolType { + CONSUMER, + CONNECT; + + @Override + public String toString() { + return super.toString().toLowerCase(Locale.ROOT); + } + } + + public final int sessionTimeoutMs; + public final int rebalanceTimeoutMs; + public final int heartbeatIntervalMs; + public final String groupId; + public final Optional groupInstanceId; + public final long retryBackoffMs; + public final boolean leaveGroupOnClose; + + public GroupRebalanceConfig(AbstractConfig config, ProtocolType protocolType) { + this.sessionTimeoutMs = config.getInt(CommonClientConfigs.SESSION_TIMEOUT_MS_CONFIG); + + // Consumer and Connect use different config names for defining rebalance timeout + if (protocolType == ProtocolType.CONSUMER) { + this.rebalanceTimeoutMs = config.getInt(CommonClientConfigs.MAX_POLL_INTERVAL_MS_CONFIG); + } else { + this.rebalanceTimeoutMs = config.getInt(CommonClientConfigs.REBALANCE_TIMEOUT_MS_CONFIG); + } + + this.heartbeatIntervalMs = config.getInt(CommonClientConfigs.HEARTBEAT_INTERVAL_MS_CONFIG); + this.groupId = config.getString(CommonClientConfigs.GROUP_ID_CONFIG); + + // Static membership is only introduced in consumer API. + if (protocolType == ProtocolType.CONSUMER) { + String groupInstanceId = config.getString(CommonClientConfigs.GROUP_INSTANCE_ID_CONFIG); + if (groupInstanceId != null) { + JoinGroupRequest.validateGroupInstanceId(groupInstanceId); + this.groupInstanceId = Optional.of(groupInstanceId); + } else { + this.groupInstanceId = Optional.empty(); + } + } else { + this.groupInstanceId = Optional.empty(); + } + + this.retryBackoffMs = config.getLong(CommonClientConfigs.RETRY_BACKOFF_MS_CONFIG); + + // Internal leave group config is only defined in Consumer. 
+ if (protocolType == ProtocolType.CONSUMER) { + this.leaveGroupOnClose = config.getBoolean("internal.leave.group.on.close"); + } else { + this.leaveGroupOnClose = true; + } + } + + // For testing purpose. + public GroupRebalanceConfig(final int sessionTimeoutMs, + final int rebalanceTimeoutMs, + final int heartbeatIntervalMs, + String groupId, + Optional groupInstanceId, + long retryBackoffMs, + boolean leaveGroupOnClose) { + this.sessionTimeoutMs = sessionTimeoutMs; + this.rebalanceTimeoutMs = rebalanceTimeoutMs; + this.heartbeatIntervalMs = heartbeatIntervalMs; + this.groupId = groupId; + this.groupInstanceId = groupInstanceId; + this.retryBackoffMs = retryBackoffMs; + this.leaveGroupOnClose = leaveGroupOnClose; + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/HostResolver.java b/clients/src/main/java/org/apache/kafka/clients/HostResolver.java new file mode 100644 index 0000000..80209ca --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/HostResolver.java @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.clients; + +import java.net.InetAddress; +import java.net.UnknownHostException; + +public interface HostResolver { + + InetAddress[] resolve(String host) throws UnknownHostException; +} diff --git a/clients/src/main/java/org/apache/kafka/clients/InFlightRequests.java b/clients/src/main/java/org/apache/kafka/clients/InFlightRequests.java new file mode 100644 index 0000000..6f5477e --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/InFlightRequests.java @@ -0,0 +1,184 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
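GroupRebalanceConfig above is a plain value object; besides being derived from an AbstractConfig, it can also be built directly through the test-oriented constructor shown in this file. A sketch with illustrative values:

import java.util.Optional;

import org.apache.kafka.clients.GroupRebalanceConfig;

public class RebalanceConfigSketch {
    // Illustrative values only, using the constructor the class exposes for testing.
    static GroupRebalanceConfig example() {
        return new GroupRebalanceConfig(
                30000,                      // sessionTimeoutMs
                300000,                     // rebalanceTimeoutMs (max.poll.interval.ms for consumers)
                10000,                      // heartbeatIntervalMs
                "example-group",            // groupId
                Optional.of("instance-1"),  // groupInstanceId -> static membership
                100L,                       // retryBackoffMs
                true);                      // leaveGroupOnClose
    }
}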
+ */ +package org.apache.kafka.clients; + +import java.util.ArrayDeque; +import java.util.ArrayList; +import java.util.Collections; +import java.util.Deque; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.atomic.AtomicInteger; + +/** + * The set of requests which have been sent or are being sent but haven't yet received a response + */ +final class InFlightRequests { + + private final int maxInFlightRequestsPerConnection; + private final Map> requests = new HashMap<>(); + /** Thread safe total number of in flight requests. */ + private final AtomicInteger inFlightRequestCount = new AtomicInteger(0); + + public InFlightRequests(int maxInFlightRequestsPerConnection) { + this.maxInFlightRequestsPerConnection = maxInFlightRequestsPerConnection; + } + + /** + * Add the given request to the queue for the connection it was directed to + */ + public void add(NetworkClient.InFlightRequest request) { + String destination = request.destination; + Deque reqs = this.requests.get(destination); + if (reqs == null) { + reqs = new ArrayDeque<>(); + this.requests.put(destination, reqs); + } + reqs.addFirst(request); + inFlightRequestCount.incrementAndGet(); + } + + /** + * Get the request queue for the given node + */ + private Deque requestQueue(String node) { + Deque reqs = requests.get(node); + if (reqs == null || reqs.isEmpty()) + throw new IllegalStateException("There are no in-flight requests for node " + node); + return reqs; + } + + /** + * Get the oldest request (the one that will be completed next) for the given node + */ + public NetworkClient.InFlightRequest completeNext(String node) { + NetworkClient.InFlightRequest inFlightRequest = requestQueue(node).pollLast(); + inFlightRequestCount.decrementAndGet(); + return inFlightRequest; + } + + /** + * Get the last request we sent to the given node (but don't remove it from the queue) + * @param node The node id + */ + public NetworkClient.InFlightRequest lastSent(String node) { + return requestQueue(node).peekFirst(); + } + + /** + * Complete the last request that was sent to a particular node. + * @param node The node the request was sent to + * @return The request + */ + public NetworkClient.InFlightRequest completeLastSent(String node) { + NetworkClient.InFlightRequest inFlightRequest = requestQueue(node).pollFirst(); + inFlightRequestCount.decrementAndGet(); + return inFlightRequest; + } + + /** + * Can we send more requests to this node? + * + * @param node Node in question + * @return true iff we have no requests still being sent to the given node + */ + public boolean canSendMore(String node) { + Deque queue = requests.get(node); + return queue == null || queue.isEmpty() || + (queue.peekFirst().send.completed() && queue.size() < this.maxInFlightRequestsPerConnection); + } + + /** + * Return the number of in-flight requests directed at the given node + * @param node The node + * @return The request count. + */ + public int count(String node) { + Deque queue = requests.get(node); + return queue == null ? 0 : queue.size(); + } + + /** + * Return true if there is no in-flight request directed at the given node and false otherwise + */ + public boolean isEmpty(String node) { + Deque queue = requests.get(node); + return queue == null || queue.isEmpty(); + } + + /** + * Count all in-flight requests for all nodes. This method is thread safe, but may lag the actual count. 
+ */ + public int count() { + return inFlightRequestCount.get(); + } + + /** + * Return true if there is no in-flight request and false otherwise + */ + public boolean isEmpty() { + for (Deque deque : this.requests.values()) { + if (!deque.isEmpty()) + return false; + } + return true; + } + + /** + * Clear out all the in-flight requests for the given node and return them + * + * @param node The node + * @return All the in-flight requests for that node that have been removed + */ + public Iterable clearAll(String node) { + Deque reqs = requests.get(node); + if (reqs == null) { + return Collections.emptyList(); + } else { + final Deque clearedRequests = requests.remove(node); + inFlightRequestCount.getAndAdd(-clearedRequests.size()); + return () -> clearedRequests.descendingIterator(); + } + } + + private Boolean hasExpiredRequest(long now, Deque deque) { + for (NetworkClient.InFlightRequest request : deque) { + if (request.timeElapsedSinceSendMs(now) > request.requestTimeoutMs) + return true; + } + return false; + } + + /** + * Returns a list of nodes with pending in-flight request, that need to be timed out + * + * @param now current time in milliseconds + * @return list of nodes + */ + public List nodesWithTimedOutRequests(long now) { + List nodeIds = new ArrayList<>(); + for (Map.Entry> requestEntry : requests.entrySet()) { + String nodeId = requestEntry.getKey(); + Deque deque = requestEntry.getValue(); + if (hasExpiredRequest(now, deque)) + nodeIds.add(nodeId); + } + return nodeIds; + } + +} diff --git a/clients/src/main/java/org/apache/kafka/clients/KafkaClient.java b/clients/src/main/java/org/apache/kafka/clients/KafkaClient.java new file mode 100644 index 0000000..18a7eef --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/KafkaClient.java @@ -0,0 +1,216 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients; + +import org.apache.kafka.common.Node; +import org.apache.kafka.common.errors.AuthenticationException; +import org.apache.kafka.common.requests.AbstractRequest; + +import java.io.Closeable; +import java.util.List; + +/** + * The interface for {@link NetworkClient} + */ +public interface KafkaClient extends Closeable { + + /** + * Check if we are currently ready to send another request to the given node but don't attempt to connect if we + * aren't. + * + * @param node The node to check + * @param now The current timestamp + */ + boolean isReady(Node node, long now); + + /** + * Initiate a connection to the given node (if necessary), and return true if already connected. The readiness of a + * node will change only when poll is invoked. + * + * @param node The node to connect to. 
+ * @param now The current time + * @return true iff we are ready to immediately initiate the sending of another request to the given node. + */ + boolean ready(Node node, long now); + + /** + * Return the number of milliseconds to wait, based on the connection state, before attempting to send data. When + * disconnected, this respects the reconnect backoff time. When connecting or connected, this handles slow/stalled + * connections. + * + * @param node The node to check + * @param now The current timestamp + * @return The number of milliseconds to wait. + */ + long connectionDelay(Node node, long now); + + /** + * Return the number of milliseconds to wait, based on the connection state and the throttle time, before + * attempting to send data. If the connection has been established but being throttled, return throttle delay. + * Otherwise, return connection delay. + * + * @param node the connection to check + * @param now the current time in ms + */ + long pollDelayMs(Node node, long now); + + /** + * Check if the connection of the node has failed, based on the connection state. Such connection failure are + * usually transient and can be resumed in the next {@link #ready(org.apache.kafka.common.Node, long)} } + * call, but there are cases where transient failures needs to be caught and re-acted upon. + * + * @param node the node to check + * @return true iff the connection has failed and the node is disconnected + */ + boolean connectionFailed(Node node); + + /** + * Check if authentication to this node has failed, based on the connection state. Authentication failures are + * propagated without any retries. + * + * @param node the node to check + * @return an AuthenticationException iff authentication has failed, null otherwise + */ + AuthenticationException authenticationException(Node node); + + /** + * Queue up the given request for sending. Requests can only be sent on ready connections. + * @param request The request + * @param now The current timestamp + */ + void send(ClientRequest request, long now); + + /** + * Do actual reads and writes from sockets. + * + * @param timeout The maximum amount of time to wait for responses in ms, must be non-negative. The implementation + * is free to use a lower value if appropriate (common reasons for this are a lower request or + * metadata update timeout) + * @param now The current time in ms + * @throws IllegalStateException If a request is sent to an unready node + */ + List poll(long timeout, long now); + + /** + * Disconnects the connection to a particular node, if there is one. + * Any pending ClientRequests for this connection will receive disconnections. + * + * @param nodeId The id of the node + */ + void disconnect(String nodeId); + + /** + * Closes the connection to a particular node (if there is one). + * All requests on the connection will be cleared. ClientRequest callbacks will not be invoked + * for the cleared requests, nor will they be returned from poll(). + * + * @param nodeId The id of the node + */ + void close(String nodeId); + + /** + * Choose the node with the fewest outstanding requests. This method will prefer a node with an existing connection, + * but will potentially choose a node for which we don't yet have a connection if all existing connections are in + * use. + * + * @param now The current time in ms + * @return The node with the fewest in-flight requests. 
+ */ + Node leastLoadedNode(long now); + + /** + * The number of currently in-flight requests for which we have not yet returned a response + */ + int inFlightRequestCount(); + + /** + * Return true if there is at least one in-flight request and false otherwise. + */ + boolean hasInFlightRequests(); + + /** + * Get the total in-flight requests for a particular node + * + * @param nodeId The id of the node + */ + int inFlightRequestCount(String nodeId); + + /** + * Return true if there is at least one in-flight request for a particular node and false otherwise. + */ + boolean hasInFlightRequests(String nodeId); + + /** + * Return true if there is at least one node with connection in the READY state and not throttled. Returns false + * otherwise. + * + * @param now the current time + */ + boolean hasReadyNodes(long now); + + /** + * Wake up the client if it is currently blocked waiting for I/O + */ + void wakeup(); + + /** + * Create a new ClientRequest. + * + * @param nodeId the node to send to + * @param requestBuilder the request builder to use + * @param createdTimeMs the time in milliseconds to use as the creation time of the request + * @param expectResponse true iff we expect a response + */ + ClientRequest newClientRequest(String nodeId, AbstractRequest.Builder requestBuilder, + long createdTimeMs, boolean expectResponse); + + /** + * Create a new ClientRequest. + * + * @param nodeId the node to send to + * @param requestBuilder the request builder to use + * @param createdTimeMs the time in milliseconds to use as the creation time of the request + * @param expectResponse true iff we expect a response + * @param requestTimeoutMs Upper bound time in milliseconds to await a response before disconnecting the socket and + * cancelling the request. The request may get cancelled sooner if the socket disconnects + * for any reason including if another pending request to the same node timed out first. + * @param callback the callback to invoke when we get a response + */ + ClientRequest newClientRequest(String nodeId, + AbstractRequest.Builder requestBuilder, + long createdTimeMs, + boolean expectResponse, + int requestTimeoutMs, + RequestCompletionHandler callback); + + + + /** + * Initiates shutdown of this client. This method may be invoked from another thread while this + * client is being polled. No further requests may be sent using the client. The current poll() + * will be terminated using wakeup(). The client should be explicitly shutdown using {@link #close()} + * after poll returns. Note that {@link #close()} should not be invoked concurrently while polling. + */ + void initiateClose(); + + /** + * Returns true if the client is still active. Returns false if {@link #initiateClose()} or {@link #close()} + * was invoked for this client. + */ + boolean active(); + +} diff --git a/clients/src/main/java/org/apache/kafka/clients/ManualMetadataUpdater.java b/clients/src/main/java/org/apache/kafka/clients/ManualMetadataUpdater.java new file mode 100644 index 0000000..3d51549 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/ManualMetadataUpdater.java @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
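A caller typically drives the KafkaClient interface above in a ready/send/poll loop. The sketch below shows only the shape of that loop, assuming a concrete KafkaClient, a target Node and a request builder are already available; the timeout and logging are illustrative.

import java.util.List;

import org.apache.kafka.clients.ClientRequest;
import org.apache.kafka.clients.ClientResponse;
import org.apache.kafka.clients.KafkaClient;
import org.apache.kafka.common.Node;
import org.apache.kafka.common.requests.AbstractRequest;

public class ClientLoopSketch {
    // ready() may initiate a connection, send() queues the request on a ready connection,
    // and poll() performs the actual I/O and returns completed responses.
    static void sendOnce(KafkaClient client, Node node, AbstractRequest.Builder<?> requestBuilder, long now) {
        if (client.ready(node, now)) {
            ClientRequest request = client.newClientRequest(node.idString(), requestBuilder, now, true);
            client.send(request, now);
        }
        List<ClientResponse> responses = client.poll(100, now);
        responses.forEach(r -> System.out.println("completed: " + r));
    }
}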
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients; + +import org.apache.kafka.common.KafkaException; +import org.apache.kafka.common.Node; +import org.apache.kafka.common.errors.AuthenticationException; +import org.apache.kafka.common.requests.MetadataResponse; +import org.apache.kafka.common.requests.RequestHeader; + +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; + +/** + * A simple implementation of `MetadataUpdater` that returns the cluster nodes set via the constructor or via + * `setNodes`. + * + * This is useful in cases where automatic metadata updates are not required. An example is controller/broker + * communication. + * + * This class is not thread-safe! + */ +public class ManualMetadataUpdater implements MetadataUpdater { + private List nodes; + + public ManualMetadataUpdater() { + this(new ArrayList<>(0)); + } + + public ManualMetadataUpdater(List nodes) { + this.nodes = nodes; + } + + public void setNodes(List nodes) { + this.nodes = nodes; + } + + @Override + public List fetchNodes() { + return new ArrayList<>(nodes); + } + + @Override + public boolean isUpdateDue(long now) { + return false; + } + + @Override + public long maybeUpdate(long now) { + return Long.MAX_VALUE; + } + + @Override + public void handleServerDisconnect(long now, String nodeId, Optional maybeAuthException) { + // We don't fail the broker on failures. There should be sufficient information from + // the NetworkClient logs to indicate the reason for the failure. + } + + @Override + public void handleFailedRequest(long now, Optional maybeFatalException) { + // Do nothing + } + + @Override + public void handleSuccessfulResponse(RequestHeader requestHeader, long now, MetadataResponse response) { + // Do nothing + } + + @Override + public void close() { + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/Metadata.java b/clients/src/main/java/org/apache/kafka/clients/Metadata.java new file mode 100644 index 0000000..60d2c05 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/Metadata.java @@ -0,0 +1,636 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
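ManualMetadataUpdater above is a deliberately static view of the cluster: it never refreshes itself, so whoever owns the broker list must replace it explicitly. A short sketch with illustrative node values:

import java.util.Arrays;

import org.apache.kafka.clients.ManualMetadataUpdater;
import org.apache.kafka.common.Node;

public class ManualMetadataSketch {
    public static void main(String[] args) {
        // Seed the updater with a fixed broker list; no automatic metadata refresh happens.
        ManualMetadataUpdater updater = new ManualMetadataUpdater(
                Arrays.asList(new Node(0, "broker-0.example", 9092)));

        // The owner swaps in a new list when its view of the cluster changes.
        updater.setNodes(Arrays.asList(
                new Node(0, "broker-0.example", 9092),
                new Node(1, "broker-1.example", 9092)));

        System.out.println("known nodes: " + updater.fetchNodes());
    }
}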
+ */ +package org.apache.kafka.clients; + +import org.apache.kafka.common.Cluster; +import org.apache.kafka.common.KafkaException; +import org.apache.kafka.common.Node; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.Uuid; +import org.apache.kafka.common.errors.InvalidMetadataException; +import org.apache.kafka.common.errors.InvalidTopicException; +import org.apache.kafka.common.errors.TopicAuthorizationException; +import org.apache.kafka.common.internals.ClusterResourceListeners; +import org.apache.kafka.common.protocol.Errors; +import org.apache.kafka.common.requests.MetadataRequest; +import org.apache.kafka.common.requests.MetadataResponse; +import org.apache.kafka.common.utils.LogContext; +import org.slf4j.Logger; + +import java.io.Closeable; +import java.net.InetSocketAddress; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.Set; +import java.util.function.Supplier; + +import static org.apache.kafka.common.record.RecordBatch.NO_PARTITION_LEADER_EPOCH; + +/** + * A class encapsulating some of the logic around metadata. + *

<p>
+ * This class is shared by the client thread (for partitioning) and the background sender thread. + * + * Metadata is maintained for only a subset of topics, which can be added to over time. When we request metadata for a + * topic we don't have any metadata for it will trigger a metadata update. + *

<p>
+ * If topic expiry is enabled for the metadata, any topic that has not been used within the expiry interval + * is removed from the metadata refresh set after an update. Consumers disable topic expiry since they explicitly + * manage topics while producers rely on topic expiry to limit the refresh set. + */ +public class Metadata implements Closeable { + private final Logger log; + private final long refreshBackoffMs; + private final long metadataExpireMs; + private int updateVersion; // bumped on every metadata response + private int requestVersion; // bumped on every new topic addition + private long lastRefreshMs; + private long lastSuccessfulRefreshMs; + private KafkaException fatalException; + private Set invalidTopics; + private Set unauthorizedTopics; + private MetadataCache cache = MetadataCache.empty(); + private boolean needFullUpdate; + private boolean needPartialUpdate; + private final ClusterResourceListeners clusterResourceListeners; + private boolean isClosed; + private final Map lastSeenLeaderEpochs; + + /** + * Create a new Metadata instance + * + * @param refreshBackoffMs The minimum amount of time that must expire between metadata refreshes to avoid busy + * polling + * @param metadataExpireMs The maximum amount of time that metadata can be retained without refresh + * @param logContext Log context corresponding to the containing client + * @param clusterResourceListeners List of ClusterResourceListeners which will receive metadata updates. + */ + public Metadata(long refreshBackoffMs, + long metadataExpireMs, + LogContext logContext, + ClusterResourceListeners clusterResourceListeners) { + this.log = logContext.logger(Metadata.class); + this.refreshBackoffMs = refreshBackoffMs; + this.metadataExpireMs = metadataExpireMs; + this.lastRefreshMs = 0L; + this.lastSuccessfulRefreshMs = 0L; + this.requestVersion = 0; + this.updateVersion = 0; + this.needFullUpdate = false; + this.needPartialUpdate = false; + this.clusterResourceListeners = clusterResourceListeners; + this.isClosed = false; + this.lastSeenLeaderEpochs = new HashMap<>(); + this.invalidTopics = Collections.emptySet(); + this.unauthorizedTopics = Collections.emptySet(); + } + + /** + * Get the current cluster info without blocking + */ + public synchronized Cluster fetch() { + return cache.cluster(); + } + + /** + * Return the next time when the current cluster info can be updated (i.e., backoff time has elapsed). + * + * @param nowMs current time in ms + * @return remaining time in ms till the cluster info can be updated again + */ + public synchronized long timeToAllowUpdate(long nowMs) { + return Math.max(this.lastRefreshMs + this.refreshBackoffMs - nowMs, 0); + } + + /** + * The next time to update the cluster info is the maximum of the time the current info will expire and the time the + * current info can be updated (i.e. backoff time has elapsed); If an update has been request then the expiry time + * is now + * + * @param nowMs current time in ms + * @return remaining time in ms till updating the cluster info + */ + public synchronized long timeToNextUpdate(long nowMs) { + long timeToExpire = updateRequested() ? 
0 : Math.max(this.lastSuccessfulRefreshMs + this.metadataExpireMs - nowMs, 0); + return Math.max(timeToExpire, timeToAllowUpdate(nowMs)); + } + + public long metadataExpireMs() { + return this.metadataExpireMs; + } + + /** + * Request an update of the current cluster metadata info, return the current updateVersion before the update + */ + public synchronized int requestUpdate() { + this.needFullUpdate = true; + return this.updateVersion; + } + + public synchronized int requestUpdateForNewTopics() { + // Override the timestamp of last refresh to let immediate update. + this.lastRefreshMs = 0; + this.needPartialUpdate = true; + this.requestVersion++; + return this.updateVersion; + } + + /** + * Request an update for the partition metadata iff we have seen a newer leader epoch. This is called by the client + * any time it handles a response from the broker that includes leader epoch, except for UpdateMetadata which + * follows a different code path ({@link #update}). + * + * @param topicPartition + * @param leaderEpoch + * @return true if we updated the last seen epoch, false otherwise + */ + public synchronized boolean updateLastSeenEpochIfNewer(TopicPartition topicPartition, int leaderEpoch) { + Objects.requireNonNull(topicPartition, "TopicPartition cannot be null"); + if (leaderEpoch < 0) + throw new IllegalArgumentException("Invalid leader epoch " + leaderEpoch + " (must be non-negative)"); + + Integer oldEpoch = lastSeenLeaderEpochs.get(topicPartition); + log.trace("Determining if we should replace existing epoch {} with new epoch {} for partition {}", oldEpoch, leaderEpoch, topicPartition); + + final boolean updated; + if (oldEpoch == null) { + log.debug("Not replacing null epoch with new epoch {} for partition {}", leaderEpoch, topicPartition); + updated = false; + } else if (leaderEpoch > oldEpoch) { + log.debug("Updating last seen epoch from {} to {} for partition {}", oldEpoch, leaderEpoch, topicPartition); + lastSeenLeaderEpochs.put(topicPartition, leaderEpoch); + updated = true; + } else { + log.debug("Not replacing existing epoch {} with new epoch {} for partition {}", oldEpoch, leaderEpoch, topicPartition); + updated = false; + } + + this.needFullUpdate = this.needFullUpdate || updated; + return updated; + } + + public Optional lastSeenLeaderEpoch(TopicPartition topicPartition) { + return Optional.ofNullable(lastSeenLeaderEpochs.get(topicPartition)); + } + + /** + * Check whether an update has been explicitly requested. + * + * @return true if an update was requested, false otherwise + */ + public synchronized boolean updateRequested() { + return this.needFullUpdate || this.needPartialUpdate; + } + + /** + * Return the cached partition info if it exists and a newer leader epoch isn't known about. 
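
// Editorial aside, not part of the original patch: a minimal, self-contained sketch of the
// epoch-acceptance rule implemented by updateLastSeenEpochIfNewer above. A new epoch is only
// recorded when one is already known and the new value is strictly larger; the map and class
// below are hypothetical stand-ins, not Kafka API code.
import java.util.HashMap;
import java.util.Map;

class LastSeenEpochSketch {
    private final Map<String, Integer> lastSeenByPartition = new HashMap<>();

    boolean updateIfNewer(String partition, int leaderEpoch) {
        if (leaderEpoch < 0)
            throw new IllegalArgumentException("Invalid leader epoch " + leaderEpoch);
        Integer oldEpoch = lastSeenByPartition.get(partition);
        if (oldEpoch == null || leaderEpoch <= oldEpoch)
            return false;                      // nothing known yet, or not strictly newer
        lastSeenByPartition.put(partition, leaderEpoch);
        return true;                           // the real class also flags a full metadata update
    }

    public static void main(String[] args) {
        LastSeenEpochSketch sketch = new LastSeenEpochSketch();
        System.out.println(sketch.updateIfNewer("t-0", 5));   // false: no prior epoch recorded
        sketch.lastSeenByPartition.put("t-0", 5);             // seed a known epoch for the demo
        System.out.println(sketch.updateIfNewer("t-0", 6));   // true: strictly newer
        System.out.println(sketch.updateIfNewer("t-0", 6));   // false: not newer
    }
}
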
+ */ + synchronized Optional partitionMetadataIfCurrent(TopicPartition topicPartition) { + Integer epoch = lastSeenLeaderEpochs.get(topicPartition); + Optional partitionMetadata = cache.partitionMetadata(topicPartition); + if (epoch == null) { + // old cluster format (no epochs) + return partitionMetadata; + } else { + return partitionMetadata.filter(metadata -> + metadata.leaderEpoch.orElse(NO_PARTITION_LEADER_EPOCH).equals(epoch)); + } + } + + /** + * @return a mapping from topic names to topic IDs for all topics with valid IDs in the cache + */ + public synchronized Map topicIds() { + return cache.topicIds(); + } + + public synchronized LeaderAndEpoch currentLeader(TopicPartition topicPartition) { + Optional maybeMetadata = partitionMetadataIfCurrent(topicPartition); + if (!maybeMetadata.isPresent()) + return new LeaderAndEpoch(Optional.empty(), Optional.ofNullable(lastSeenLeaderEpochs.get(topicPartition))); + + MetadataResponse.PartitionMetadata partitionMetadata = maybeMetadata.get(); + Optional leaderEpochOpt = partitionMetadata.leaderEpoch; + Optional leaderNodeOpt = partitionMetadata.leaderId.flatMap(cache::nodeById); + return new LeaderAndEpoch(leaderNodeOpt, leaderEpochOpt); + } + + public synchronized void bootstrap(List addresses) { + this.needFullUpdate = true; + this.updateVersion += 1; + this.cache = MetadataCache.bootstrap(addresses); + } + + /** + * Update metadata assuming the current request version. + * + * For testing only. + */ + public synchronized void updateWithCurrentRequestVersion(MetadataResponse response, boolean isPartialUpdate, long nowMs) { + this.update(this.requestVersion, response, isPartialUpdate, nowMs); + } + + /** + * Updates the cluster metadata. If topic expiry is enabled, expiry time + * is set for topics if required and expired topics are removed from the metadata. + * + * @param requestVersion The request version corresponding to the update response, as provided by + * {@link #newMetadataRequestAndVersion(long)}. 
+ * @param response metadata response received from the broker + * @param isPartialUpdate whether the metadata request was for a subset of the active topics + * @param nowMs current time in milliseconds + */ + public synchronized void update(int requestVersion, MetadataResponse response, boolean isPartialUpdate, long nowMs) { + Objects.requireNonNull(response, "Metadata response cannot be null"); + if (isClosed()) + throw new IllegalStateException("Update requested after metadata close"); + + this.needPartialUpdate = requestVersion < this.requestVersion; + this.lastRefreshMs = nowMs; + this.updateVersion += 1; + if (!isPartialUpdate) { + this.needFullUpdate = false; + this.lastSuccessfulRefreshMs = nowMs; + } + + String previousClusterId = cache.clusterResource().clusterId(); + + this.cache = handleMetadataResponse(response, isPartialUpdate, nowMs); + + Cluster cluster = cache.cluster(); + maybeSetMetadataError(cluster); + + this.lastSeenLeaderEpochs.keySet().removeIf(tp -> !retainTopic(tp.topic(), false, nowMs)); + + String newClusterId = cache.clusterResource().clusterId(); + if (!Objects.equals(previousClusterId, newClusterId)) { + log.info("Cluster ID: {}", newClusterId); + } + clusterResourceListeners.onUpdate(cache.clusterResource()); + + log.debug("Updated cluster metadata updateVersion {} to {}", this.updateVersion, this.cache); + } + + private void maybeSetMetadataError(Cluster cluster) { + clearRecoverableErrors(); + checkInvalidTopics(cluster); + checkUnauthorizedTopics(cluster); + } + + private void checkInvalidTopics(Cluster cluster) { + if (!cluster.invalidTopics().isEmpty()) { + log.error("Metadata response reported invalid topics {}", cluster.invalidTopics()); + invalidTopics = new HashSet<>(cluster.invalidTopics()); + } + } + + private void checkUnauthorizedTopics(Cluster cluster) { + if (!cluster.unauthorizedTopics().isEmpty()) { + log.error("Topic authorization failed for topics {}", cluster.unauthorizedTopics()); + unauthorizedTopics = new HashSet<>(cluster.unauthorizedTopics()); + } + } + + /** + * Transform a MetadataResponse into a new MetadataCache instance. + */ + private MetadataCache handleMetadataResponse(MetadataResponse metadataResponse, boolean isPartialUpdate, long nowMs) { + // All encountered topics. + Set topics = new HashSet<>(); + + // Retained topics to be passed to the metadata cache. 
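
// Editorial aside, not part of the original patch: a simplified sketch of the bookkeeping that
// update(...) above applies to its flags, with hypothetical field names. Only a full update
// clears needFullUpdate and advances the expiry clock; a partial response taken against an older
// requestVersion leaves needPartialUpdate set so newly added topics are fetched next time.
class UpdateFlagsSketch {
    int requestVersion;              // bumped whenever new topics are registered
    int updateVersion;               // bumped on every metadata response
    boolean needFullUpdate;
    boolean needPartialUpdate;
    long lastRefreshMs;
    long lastSuccessfulRefreshMs;

    void onMetadataResponse(int responseRequestVersion, boolean isPartialUpdate, long nowMs) {
        needPartialUpdate = responseRequestVersion < requestVersion;
        lastRefreshMs = nowMs;
        updateVersion++;
        if (!isPartialUpdate) {
            needFullUpdate = false;
            lastSuccessfulRefreshMs = nowMs;   // only a full update resets the expiry clock
        }
    }
}
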
+ Set internalTopics = new HashSet<>(); + Set unauthorizedTopics = new HashSet<>(); + Set invalidTopics = new HashSet<>(); + + List partitions = new ArrayList<>(); + Map topicIds = new HashMap<>(); + Map oldTopicIds = cache.topicIds(); + for (MetadataResponse.TopicMetadata metadata : metadataResponse.topicMetadata()) { + String topicName = metadata.topic(); + Uuid topicId = metadata.topicId(); + topics.add(topicName); + // We can only reason about topic ID changes when both IDs are valid, so keep oldId null unless the new metadata contains a topic ID + Uuid oldTopicId = null; + if (!Uuid.ZERO_UUID.equals(topicId)) { + topicIds.put(topicName, topicId); + oldTopicId = oldTopicIds.get(topicName); + } else { + topicId = null; + } + + if (!retainTopic(topicName, metadata.isInternal(), nowMs)) + continue; + + if (metadata.isInternal()) + internalTopics.add(topicName); + + if (metadata.error() == Errors.NONE) { + for (MetadataResponse.PartitionMetadata partitionMetadata : metadata.partitionMetadata()) { + // Even if the partition's metadata includes an error, we need to handle + // the update to catch new epochs + updateLatestMetadata(partitionMetadata, metadataResponse.hasReliableLeaderEpochs(), topicId, oldTopicId) + .ifPresent(partitions::add); + + if (partitionMetadata.error.exception() instanceof InvalidMetadataException) { + log.debug("Requesting metadata update for partition {} due to error {}", + partitionMetadata.topicPartition, partitionMetadata.error); + requestUpdate(); + } + } + } else { + if (metadata.error().exception() instanceof InvalidMetadataException) { + log.debug("Requesting metadata update for topic {} due to error {}", topicName, metadata.error()); + requestUpdate(); + } + + if (metadata.error() == Errors.INVALID_TOPIC_EXCEPTION) + invalidTopics.add(topicName); + else if (metadata.error() == Errors.TOPIC_AUTHORIZATION_FAILED) + unauthorizedTopics.add(topicName); + } + } + + Map nodes = metadataResponse.brokersById(); + if (isPartialUpdate) + return this.cache.mergeWith(metadataResponse.clusterId(), nodes, partitions, + unauthorizedTopics, invalidTopics, internalTopics, metadataResponse.controller(), topicIds, + (topic, isInternal) -> !topics.contains(topic) && retainTopic(topic, isInternal, nowMs)); + else + return new MetadataCache(metadataResponse.clusterId(), nodes, partitions, + unauthorizedTopics, invalidTopics, internalTopics, metadataResponse.controller(), topicIds); + } + + /** + * Compute the latest partition metadata to cache given ordering by leader epochs (if both + * available and reliable) and whether the topic ID changed. + */ + private Optional updateLatestMetadata( + MetadataResponse.PartitionMetadata partitionMetadata, + boolean hasReliableLeaderEpoch, + Uuid topicId, + Uuid oldTopicId) { + TopicPartition tp = partitionMetadata.topicPartition; + if (hasReliableLeaderEpoch && partitionMetadata.leaderEpoch.isPresent()) { + int newEpoch = partitionMetadata.leaderEpoch.get(); + Integer currentEpoch = lastSeenLeaderEpochs.get(tp); + if (topicId != null && !topicId.equals(oldTopicId)) { + // If the new topic ID is valid and different from the last seen topic ID, update the metadata. + // Between the time that a topic is deleted and re-created, the client may lose track of the + // corresponding topicId (i.e. `oldTopicId` will be null). In this case, when we discover the new + // topicId, we allow the corresponding leader epoch to override the last seen value. 
+ log.info("Resetting the last seen epoch of partition {} to {} since the associated topicId changed from {} to {}", + tp, newEpoch, oldTopicId, topicId); + lastSeenLeaderEpochs.put(tp, newEpoch); + return Optional.of(partitionMetadata); + } else if (currentEpoch == null || newEpoch >= currentEpoch) { + // If the received leader epoch is at least the same as the previous one, update the metadata + log.debug("Updating last seen epoch for partition {} from {} to epoch {} from new metadata", tp, currentEpoch, newEpoch); + lastSeenLeaderEpochs.put(tp, newEpoch); + return Optional.of(partitionMetadata); + } else { + // Otherwise ignore the new metadata and use the previously cached info + log.debug("Got metadata for an older epoch {} (current is {}) for partition {}, not updating", newEpoch, currentEpoch, tp); + return cache.partitionMetadata(tp); + } + } else { + // Handle old cluster formats as well as error responses where leader and epoch are missing + lastSeenLeaderEpochs.remove(tp); + return Optional.of(partitionMetadata.withoutLeaderEpoch()); + } + } + + /** + * If any non-retriable exceptions were encountered during metadata update, clear and throw the exception. + * This is used by the consumer to propagate any fatal exceptions or topic exceptions for any of the topics + * in the consumer's Metadata. + */ + public synchronized void maybeThrowAnyException() { + clearErrorsAndMaybeThrowException(this::recoverableException); + } + + /** + * If any fatal exceptions were encountered during metadata update, throw the exception. This is used by + * the producer to abort waiting for metadata if there were fatal exceptions (e.g. authentication failures) + * in the last metadata update. + */ + protected synchronized void maybeThrowFatalException() { + KafkaException metadataException = this.fatalException; + if (metadataException != null) { + fatalException = null; + throw metadataException; + } + } + + /** + * If any non-retriable exceptions were encountered during metadata update, throw exception if the exception + * is fatal or related to the specified topic. All exceptions from the last metadata update are cleared. + * This is used by the producer to propagate topic metadata errors for send requests. 
+ */ + public synchronized void maybeThrowExceptionForTopic(String topic) { + clearErrorsAndMaybeThrowException(() -> recoverableExceptionForTopic(topic)); + } + + private void clearErrorsAndMaybeThrowException(Supplier recoverableExceptionSupplier) { + KafkaException metadataException = Optional.ofNullable(fatalException).orElseGet(recoverableExceptionSupplier); + fatalException = null; + clearRecoverableErrors(); + if (metadataException != null) + throw metadataException; + } + + // We may be able to recover from this exception if metadata for this topic is no longer needed + private KafkaException recoverableException() { + if (!unauthorizedTopics.isEmpty()) + return new TopicAuthorizationException(unauthorizedTopics); + else if (!invalidTopics.isEmpty()) + return new InvalidTopicException(invalidTopics); + else + return null; + } + + private KafkaException recoverableExceptionForTopic(String topic) { + if (unauthorizedTopics.contains(topic)) + return new TopicAuthorizationException(Collections.singleton(topic)); + else if (invalidTopics.contains(topic)) + return new InvalidTopicException(Collections.singleton(topic)); + else + return null; + } + + private void clearRecoverableErrors() { + invalidTopics = Collections.emptySet(); + unauthorizedTopics = Collections.emptySet(); + } + + /** + * Record an attempt to update the metadata that failed. We need to keep track of this + * to avoid retrying immediately. + */ + public synchronized void failedUpdate(long now) { + this.lastRefreshMs = now; + } + + /** + * Propagate a fatal error which affects the ability to fetch metadata for the cluster. + * Two examples are authentication and unsupported version exceptions. + * + * @param exception The fatal exception + */ + public synchronized void fatalError(KafkaException exception) { + this.fatalException = exception; + } + + /** + * @return The current metadata updateVersion + */ + public synchronized int updateVersion() { + return this.updateVersion; + } + + /** + * The last time metadata was successfully updated. + */ + public synchronized long lastSuccessfulUpdate() { + return this.lastSuccessfulRefreshMs; + } + + /** + * Close this metadata instance to indicate that metadata updates are no longer possible. + */ + @Override + public synchronized void close() { + this.isClosed = true; + } + + /** + * Check if this metadata instance has been closed. See {@link #close()} for more information. + * + * @return True if this instance has been closed; false otherwise + */ + public synchronized boolean isClosed() { + return this.isClosed; + } + + public synchronized MetadataRequestAndVersion newMetadataRequestAndVersion(long nowMs) { + MetadataRequest.Builder request = null; + boolean isPartialUpdate = false; + + // Perform a partial update only if a full update hasn't been requested, and the last successful + // hasn't exceeded the metadata refresh time. + if (!this.needFullUpdate && this.lastSuccessfulRefreshMs + this.metadataExpireMs > nowMs) { + request = newMetadataRequestBuilderForNewTopics(); + isPartialUpdate = true; + } + if (request == null) { + request = newMetadataRequestBuilder(); + isPartialUpdate = false; + } + return new MetadataRequestAndVersion(request, requestVersion, isPartialUpdate); + } + + /** + * Constructs and returns a metadata request builder for fetching cluster data and all active topics. 
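
// Editorial aside, not part of the original patch: the decision made by newMetadataRequestAndVersion
// above, expressed as a standalone sketch. A partial (new-topics-only) request is chosen only when
// no full update is pending, the last successful full refresh has not expired, and the subclass
// actually supports partial requests; otherwise a full request is built.
class RequestChoiceSketch {
    static String chooseRequest(boolean needFullUpdate, long lastSuccessfulRefreshMs,
                                long metadataExpireMs, long nowMs, boolean partialSupported) {
        if (!needFullUpdate && lastSuccessfulRefreshMs + metadataExpireMs > nowMs && partialSupported)
            return "partial";   // newMetadataRequestBuilderForNewTopics()
        return "full";          // newMetadataRequestBuilder()
    }

    public static void main(String[] args) {
        System.out.println(chooseRequest(false, 1_000, 300_000, 2_000, true));   // partial
        System.out.println(chooseRequest(true, 1_000, 300_000, 2_000, true));    // full: full update pending
        System.out.println(chooseRequest(false, 1_000, 300_000, 400_000, true)); // full: metadata expired
    }
}
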
+ * + * @return the constructed non-null metadata builder + */ + protected MetadataRequest.Builder newMetadataRequestBuilder() { + return MetadataRequest.Builder.allTopics(); + } + + /** + * Constructs and returns a metadata request builder for fetching cluster data and any uncached topics, + * otherwise null if the functionality is not supported. + * + * @return the constructed metadata builder, or null if not supported + */ + protected MetadataRequest.Builder newMetadataRequestBuilderForNewTopics() { + return null; + } + + protected boolean retainTopic(String topic, boolean isInternal, long nowMs) { + return true; + } + + public static class MetadataRequestAndVersion { + public final MetadataRequest.Builder requestBuilder; + public final int requestVersion; + public final boolean isPartialUpdate; + + private MetadataRequestAndVersion(MetadataRequest.Builder requestBuilder, + int requestVersion, + boolean isPartialUpdate) { + this.requestBuilder = requestBuilder; + this.requestVersion = requestVersion; + this.isPartialUpdate = isPartialUpdate; + } + } + + /** + * Represents current leader state known in metadata. It is possible that we know the leader, but not the + * epoch if the metadata is received from a broker which does not support a sufficient Metadata API version. + * It is also possible that we know of the leader epoch, but not the leader when it is derived + * from an external source (e.g. a committed offset). + */ + public static class LeaderAndEpoch { + private static final LeaderAndEpoch NO_LEADER_OR_EPOCH = new LeaderAndEpoch(Optional.empty(), Optional.empty()); + + public final Optional leader; + public final Optional epoch; + + public LeaderAndEpoch(Optional leader, Optional epoch) { + this.leader = Objects.requireNonNull(leader); + this.epoch = Objects.requireNonNull(epoch); + } + + public static LeaderAndEpoch noLeaderOrEpoch() { + return NO_LEADER_OR_EPOCH; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + LeaderAndEpoch that = (LeaderAndEpoch) o; + + if (!leader.equals(that.leader)) return false; + return epoch.equals(that.epoch); + } + + @Override + public int hashCode() { + int result = leader.hashCode(); + result = 31 * result + epoch.hashCode(); + return result; + } + + @Override + public String toString() { + return "LeaderAndEpoch{" + + "leader=" + leader + + ", epoch=" + epoch.map(Number::toString).orElse("absent") + + '}'; + } + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/MetadataCache.java b/clients/src/main/java/org/apache/kafka/clients/MetadataCache.java new file mode 100644 index 0000000..d7b6bfd --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/MetadataCache.java @@ -0,0 +1,235 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients; + +import org.apache.kafka.common.Cluster; +import org.apache.kafka.common.ClusterResource; +import org.apache.kafka.common.Node; +import org.apache.kafka.common.PartitionInfo; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.Uuid; +import org.apache.kafka.common.requests.MetadataResponse; +import org.apache.kafka.common.requests.MetadataResponse.PartitionMetadata; + +import java.net.InetSocketAddress; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.function.BiPredicate; +import java.util.function.Predicate; +import java.util.stream.Collectors; + +/** + * An internal mutable cache of nodes, topics, and partitions in the Kafka cluster. This keeps an up-to-date Cluster + * instance which is optimized for read access. + */ +public class MetadataCache { + private final String clusterId; + private final Map nodes; + private final Set unauthorizedTopics; + private final Set invalidTopics; + private final Set internalTopics; + private final Node controller; + private final Map metadataByPartition; + private final Map topicIds; + + private Cluster clusterInstance; + + MetadataCache(String clusterId, + Map nodes, + Collection partitions, + Set unauthorizedTopics, + Set invalidTopics, + Set internalTopics, + Node controller, + Map topicIds) { + this(clusterId, nodes, partitions, unauthorizedTopics, invalidTopics, internalTopics, controller, topicIds, null); + } + + private MetadataCache(String clusterId, + Map nodes, + Collection partitions, + Set unauthorizedTopics, + Set invalidTopics, + Set internalTopics, + Node controller, + Map topicIds, + Cluster clusterInstance) { + this.clusterId = clusterId; + this.nodes = nodes; + this.unauthorizedTopics = unauthorizedTopics; + this.invalidTopics = invalidTopics; + this.internalTopics = internalTopics; + this.controller = controller; + this.topicIds = topicIds; + + this.metadataByPartition = new HashMap<>(partitions.size()); + for (PartitionMetadata p : partitions) { + this.metadataByPartition.put(p.topicPartition, p); + } + + if (clusterInstance == null) { + computeClusterView(); + } else { + this.clusterInstance = clusterInstance; + } + } + + Optional partitionMetadata(TopicPartition topicPartition) { + return Optional.ofNullable(metadataByPartition.get(topicPartition)); + } + + Map topicIds() { + return topicIds; + } + + Optional nodeById(int id) { + return Optional.ofNullable(nodes.get(id)); + } + + Cluster cluster() { + if (clusterInstance == null) { + throw new IllegalStateException("Cached Cluster instance should not be null, but was."); + } else { + return clusterInstance; + } + } + + ClusterResource clusterResource() { + return new ClusterResource(clusterId); + } + + /** + * Merges the metadata cache's contents with the provided metadata, returning a new metadata cache. The provided + * metadata is presumed to be more recent than the cache's metadata, and therefore all overlapping metadata will + * be overridden. 
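
// Editorial aside, not part of the original patch: a simplified sketch of how topic IDs are
// carried across the mergeWith merge described above (retention filtering omitted). Retained
// topics keep their previous ID, topics in the new response take the new ID, and a topic reported
// without an ID has its stale ID dropped. Plain String IDs stand in for Uuid here.
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;

class TopicIdMergeSketch {
    static Map<String, String> mergeTopicIds(Map<String, String> oldIds,
                                             Map<String, String> newIds,
                                             Iterable<String> topicsInResponse) {
        Map<String, String> merged = new HashMap<>(oldIds);   // start from retained topics
        for (String topic : topicsInResponse) {
            String newId = newIds.get(topic);
            if (newId != null)
                merged.put(topic, newId);                      // newest ID wins
            else
                merged.remove(topic);                          // response had no ID for this topic
        }
        return merged;
    }

    public static void main(String[] args) {
        Map<String, String> oldIds = new HashMap<>();
        oldIds.put("topic-a", "id-1");
        oldIds.put("topic-b", "id-2");
        Map<String, String> newIds = new HashMap<>();
        newIds.put("topic-a", "id-9");                         // topic-a re-created with a new ID
        System.out.println(mergeTopicIds(oldIds, newIds, Arrays.asList("topic-a", "topic-b")));
        // prints {topic-a=id-9}: topic-b appeared in the response without an ID, so its ID is dropped
    }
}
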
+ * + * @param newClusterId the new cluster Id + * @param newNodes the new set of nodes + * @param addPartitions partitions to add + * @param addUnauthorizedTopics unauthorized topics to add + * @param addInternalTopics internal topics to add + * @param newController the new controller node + * @param topicIds the mapping from topic name to topic ID from the MetadataResponse + * @param retainTopic returns whether a topic's metadata should be retained + * @return the merged metadata cache + */ + MetadataCache mergeWith(String newClusterId, + Map newNodes, + Collection addPartitions, + Set addUnauthorizedTopics, + Set addInvalidTopics, + Set addInternalTopics, + Node newController, + Map topicIds, + BiPredicate retainTopic) { + + Predicate shouldRetainTopic = topic -> retainTopic.test(topic, internalTopics.contains(topic)); + + Map newMetadataByPartition = new HashMap<>(addPartitions.size()); + + // We want the most recent topic ID. We start with the previous ID stored for retained topics and then + // update with newest information from the MetadataResponse. We always take the latest state, removing existing + // topic IDs if the latest state contains the topic name but not a topic ID. + Map newTopicIds = topicIds.entrySet().stream() + .filter(entry -> shouldRetainTopic.test(entry.getKey())) + .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); + + for (PartitionMetadata partition : addPartitions) { + newMetadataByPartition.put(partition.topicPartition, partition); + Uuid id = topicIds.get(partition.topic()); + if (id != null) + newTopicIds.put(partition.topic(), id); + else + // Remove if the latest metadata does not have a topic ID + newTopicIds.remove(partition.topic()); + } + for (Map.Entry entry : metadataByPartition.entrySet()) { + if (shouldRetainTopic.test(entry.getKey().topic())) { + newMetadataByPartition.putIfAbsent(entry.getKey(), entry.getValue()); + } + } + + Set newUnauthorizedTopics = fillSet(addUnauthorizedTopics, unauthorizedTopics, shouldRetainTopic); + Set newInvalidTopics = fillSet(addInvalidTopics, invalidTopics, shouldRetainTopic); + Set newInternalTopics = fillSet(addInternalTopics, internalTopics, shouldRetainTopic); + + return new MetadataCache(newClusterId, newNodes, newMetadataByPartition.values(), newUnauthorizedTopics, + newInvalidTopics, newInternalTopics, newController, newTopicIds); + } + + /** + * Copies {@code baseSet} and adds all non-existent elements in {@code fillSet} such that {@code predicate} is true. + * In other words, all elements of {@code baseSet} will be contained in the result, with additional non-overlapping + * elements in {@code fillSet} where the predicate is true. 
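
// Editorial aside, not part of the original patch: a generically-typed, self-contained rendering
// of the fillSet helper implemented just below (its type parameters appear to have been stripped
// from this diff), together with a small usage example.
import java.util.Arrays;
import java.util.HashSet;
import java.util.Set;
import java.util.function.Predicate;

class FillSetSketch {
    static <T> Set<T> fillSet(Set<T> baseSet, Set<T> fillSet, Predicate<T> predicate) {
        Set<T> result = new HashSet<>(baseSet);   // every base element is always kept
        for (T element : fillSet) {
            if (predicate.test(element))
                result.add(element);              // fill elements are added only when retained
        }
        return result;
    }

    public static void main(String[] args) {
        Set<String> merged = fillSet(
                new HashSet<>(Arrays.asList("new-topic")),
                new HashSet<>(Arrays.asList("old-topic", "dropped-topic")),
                topic -> !topic.startsWith("dropped"));
        System.out.println(merged);               // [new-topic, old-topic] in some order
    }
}
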
+ * + * @param baseSet the base elements for the resulting set + * @param fillSet elements to be filled into the resulting set + * @param predicate tested against the fill set to determine whether elements should be added to the base set + */ + private static Set fillSet(Set baseSet, Set fillSet, Predicate predicate) { + Set result = new HashSet<>(baseSet); + for (T element : fillSet) { + if (predicate.test(element)) { + result.add(element); + } + } + return result; + } + + private void computeClusterView() { + List partitionInfos = metadataByPartition.values() + .stream() + .map(metadata -> MetadataResponse.toPartitionInfo(metadata, nodes)) + .collect(Collectors.toList()); + this.clusterInstance = new Cluster(clusterId, nodes.values(), partitionInfos, unauthorizedTopics, + invalidTopics, internalTopics, controller, topicIds); + } + + static MetadataCache bootstrap(List addresses) { + Map nodes = new HashMap<>(); + int nodeId = -1; + for (InetSocketAddress address : addresses) { + nodes.put(nodeId, new Node(nodeId, address.getHostString(), address.getPort())); + nodeId--; + } + return new MetadataCache(null, nodes, Collections.emptyList(), + Collections.emptySet(), Collections.emptySet(), Collections.emptySet(), + null, Collections.emptyMap(), Cluster.bootstrap(addresses)); + } + + static MetadataCache empty() { + return new MetadataCache(null, Collections.emptyMap(), Collections.emptyList(), + Collections.emptySet(), Collections.emptySet(), Collections.emptySet(), null, Collections.emptyMap(), Cluster.empty()); + } + + @Override + public String toString() { + return "MetadataCache{" + + "clusterId='" + clusterId + '\'' + + ", nodes=" + nodes + + ", partitions=" + metadataByPartition.values() + + ", controller=" + controller + + '}'; + } + +} diff --git a/clients/src/main/java/org/apache/kafka/clients/MetadataUpdater.java b/clients/src/main/java/org/apache/kafka/clients/MetadataUpdater.java new file mode 100644 index 0000000..77f3efa --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/MetadataUpdater.java @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients; + +import org.apache.kafka.common.KafkaException; +import org.apache.kafka.common.Node; +import org.apache.kafka.common.errors.AuthenticationException; +import org.apache.kafka.common.errors.UnsupportedVersionException; +import org.apache.kafka.common.requests.MetadataResponse; +import org.apache.kafka.common.requests.RequestHeader; + +import java.io.Closeable; +import java.util.List; +import java.util.Optional; + +/** + * The interface used by `NetworkClient` to request cluster metadata info to be updated and to retrieve the cluster nodes + * from such metadata. This is an internal class. + *
<p>
+ * This class is not thread-safe! + */ +public interface MetadataUpdater extends Closeable { + + /** + * Gets the current cluster info without blocking. + */ + List fetchNodes(); + + /** + * Returns true if an update to the cluster metadata info is due. + */ + boolean isUpdateDue(long now); + + /** + * Starts a cluster metadata update if needed and possible. Returns the time until the metadata update (which would + * be 0 if an update has been started as a result of this call). + * + * If the implementation relies on `NetworkClient` to send requests, `handleSuccessfulResponse` will be + * invoked after the metadata response is received. + * + * The semantics of `needed` and `possible` are implementation-dependent and may take into account a number of + * factors like node availability, how long since the last metadata update, etc. + */ + long maybeUpdate(long now); + + /** + * Handle a server disconnect. + * + * This provides a mechanism for the `MetadataUpdater` implementation to use the NetworkClient instance for its own + * requests with special handling for disconnections of such requests. + * + * @param now Current time in milliseconds + * @param nodeId The id of the node that disconnected + * @param maybeAuthException Optional authentication error + */ + void handleServerDisconnect(long now, String nodeId, Optional maybeAuthException); + + /** + * Handle a metadata request failure. + * + * @param now Current time in milliseconds + * @param maybeFatalException Optional fatal error (e.g. {@link UnsupportedVersionException}) + */ + void handleFailedRequest(long now, Optional maybeFatalException); + + /** + * Handle responses for metadata requests. + * + * This provides a mechanism for the `MetadataUpdater` implementation to use the NetworkClient instance for its own + * requests with special handling for completed receives of such requests. + */ + void handleSuccessfulResponse(RequestHeader requestHeader, long now, MetadataResponse metadataResponse); + + /** + * Close this updater. + */ + @Override + void close(); +} diff --git a/clients/src/main/java/org/apache/kafka/clients/NetworkClient.java b/clients/src/main/java/org/apache/kafka/clients/NetworkClient.java new file mode 100644 index 0000000..e7a0ac7 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/NetworkClient.java @@ -0,0 +1,1295 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.clients; + +import org.apache.kafka.common.Cluster; +import org.apache.kafka.common.KafkaException; +import org.apache.kafka.common.Node; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.errors.AuthenticationException; +import org.apache.kafka.common.errors.DisconnectException; +import org.apache.kafka.common.errors.UnsupportedVersionException; +import org.apache.kafka.common.message.ApiVersionsResponseData.ApiVersion; +import org.apache.kafka.common.metrics.Sensor; +import org.apache.kafka.common.network.ChannelState; +import org.apache.kafka.common.network.NetworkSend; +import org.apache.kafka.common.network.NetworkReceive; +import org.apache.kafka.common.network.Selectable; +import org.apache.kafka.common.network.Send; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.Errors; +import org.apache.kafka.common.protocol.types.SchemaException; +import org.apache.kafka.common.requests.AbstractRequest; +import org.apache.kafka.common.requests.AbstractResponse; +import org.apache.kafka.common.requests.ApiVersionsRequest; +import org.apache.kafka.common.requests.ApiVersionsResponse; +import org.apache.kafka.common.requests.CorrelationIdMismatchException; +import org.apache.kafka.common.requests.MetadataRequest; +import org.apache.kafka.common.requests.MetadataResponse; +import org.apache.kafka.common.requests.RequestHeader; +import org.apache.kafka.common.security.authenticator.SaslClientAuthenticator; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.common.utils.Utils; +import org.slf4j.Logger; + +import java.io.IOException; +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.nio.BufferUnderflowException; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashMap; +import java.util.Iterator; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Random; +import java.util.concurrent.atomic.AtomicReference; +import java.util.stream.Collectors; + +/** + * A network client for asynchronous request/response network i/o. This is an internal class used to implement the + * user-facing producer and consumer clients. + *
<p>
+ * This class is not thread-safe! + */ +public class NetworkClient implements KafkaClient { + + private enum State { + ACTIVE, + CLOSING, + CLOSED + } + + private final Logger log; + + /* the selector used to perform network i/o */ + private final Selectable selector; + + private final MetadataUpdater metadataUpdater; + + private final Random randOffset; + + /* the state of each node's connection */ + private final ClusterConnectionStates connectionStates; + + /* the set of requests currently being sent or awaiting a response */ + private final InFlightRequests inFlightRequests; + + /* the socket send buffer size in bytes */ + private final int socketSendBuffer; + + /* the socket receive size buffer in bytes */ + private final int socketReceiveBuffer; + + /* the client id used to identify this client in requests to the server */ + private final String clientId; + + /* the current correlation id to use when sending requests to servers */ + private int correlation; + + /* default timeout for individual requests to await acknowledgement from servers */ + private final int defaultRequestTimeoutMs; + + /* time in ms to wait before retrying to create connection to a server */ + private final long reconnectBackoffMs; + + private final Time time; + + /** + * True if we should send an ApiVersionRequest when first connecting to a broker. + */ + private final boolean discoverBrokerVersions; + + private final ApiVersions apiVersions; + + private final Map nodesNeedingApiVersionsFetch = new HashMap<>(); + + private final List abortedSends = new LinkedList<>(); + + private final Sensor throttleTimeSensor; + + private final AtomicReference state; + + public NetworkClient(Selectable selector, + Metadata metadata, + String clientId, + int maxInFlightRequestsPerConnection, + long reconnectBackoffMs, + long reconnectBackoffMax, + int socketSendBuffer, + int socketReceiveBuffer, + int defaultRequestTimeoutMs, + long connectionSetupTimeoutMs, + long connectionSetupTimeoutMaxMs, + Time time, + boolean discoverBrokerVersions, + ApiVersions apiVersions, + LogContext logContext) { + this(selector, + metadata, + clientId, + maxInFlightRequestsPerConnection, + reconnectBackoffMs, + reconnectBackoffMax, + socketSendBuffer, + socketReceiveBuffer, + defaultRequestTimeoutMs, + connectionSetupTimeoutMs, + connectionSetupTimeoutMaxMs, + time, + discoverBrokerVersions, + apiVersions, + null, + logContext); + } + + public NetworkClient(Selectable selector, + Metadata metadata, + String clientId, + int maxInFlightRequestsPerConnection, + long reconnectBackoffMs, + long reconnectBackoffMax, + int socketSendBuffer, + int socketReceiveBuffer, + int defaultRequestTimeoutMs, + long connectionSetupTimeoutMs, + long connectionSetupTimeoutMaxMs, + Time time, + boolean discoverBrokerVersions, + ApiVersions apiVersions, + Sensor throttleTimeSensor, + LogContext logContext) { + this(null, + metadata, + selector, + clientId, + maxInFlightRequestsPerConnection, + reconnectBackoffMs, + reconnectBackoffMax, + socketSendBuffer, + socketReceiveBuffer, + defaultRequestTimeoutMs, + connectionSetupTimeoutMs, + connectionSetupTimeoutMaxMs, + time, + discoverBrokerVersions, + apiVersions, + throttleTimeSensor, + logContext, + new DefaultHostResolver()); + } + + public NetworkClient(Selectable selector, + MetadataUpdater metadataUpdater, + String clientId, + int maxInFlightRequestsPerConnection, + long reconnectBackoffMs, + long reconnectBackoffMax, + int socketSendBuffer, + int socketReceiveBuffer, + int defaultRequestTimeoutMs, + long 
connectionSetupTimeoutMs, + long connectionSetupTimeoutMaxMs, + Time time, + boolean discoverBrokerVersions, + ApiVersions apiVersions, + LogContext logContext) { + this(metadataUpdater, + null, + selector, + clientId, + maxInFlightRequestsPerConnection, + reconnectBackoffMs, + reconnectBackoffMax, + socketSendBuffer, + socketReceiveBuffer, + defaultRequestTimeoutMs, + connectionSetupTimeoutMs, + connectionSetupTimeoutMaxMs, + time, + discoverBrokerVersions, + apiVersions, + null, + logContext, + new DefaultHostResolver()); + } + + public NetworkClient(MetadataUpdater metadataUpdater, + Metadata metadata, + Selectable selector, + String clientId, + int maxInFlightRequestsPerConnection, + long reconnectBackoffMs, + long reconnectBackoffMax, + int socketSendBuffer, + int socketReceiveBuffer, + int defaultRequestTimeoutMs, + long connectionSetupTimeoutMs, + long connectionSetupTimeoutMaxMs, + Time time, + boolean discoverBrokerVersions, + ApiVersions apiVersions, + Sensor throttleTimeSensor, + LogContext logContext, + HostResolver hostResolver) { + /* It would be better if we could pass `DefaultMetadataUpdater` from the public constructor, but it's not + * possible because `DefaultMetadataUpdater` is an inner class and it can only be instantiated after the + * super constructor is invoked. + */ + if (metadataUpdater == null) { + if (metadata == null) + throw new IllegalArgumentException("`metadata` must not be null"); + this.metadataUpdater = new DefaultMetadataUpdater(metadata); + } else { + this.metadataUpdater = metadataUpdater; + } + this.selector = selector; + this.clientId = clientId; + this.inFlightRequests = new InFlightRequests(maxInFlightRequestsPerConnection); + this.connectionStates = new ClusterConnectionStates( + reconnectBackoffMs, reconnectBackoffMax, + connectionSetupTimeoutMs, connectionSetupTimeoutMaxMs, logContext, hostResolver); + this.socketSendBuffer = socketSendBuffer; + this.socketReceiveBuffer = socketReceiveBuffer; + this.correlation = 0; + this.randOffset = new Random(); + this.defaultRequestTimeoutMs = defaultRequestTimeoutMs; + this.reconnectBackoffMs = reconnectBackoffMs; + this.time = time; + this.discoverBrokerVersions = discoverBrokerVersions; + this.apiVersions = apiVersions; + this.throttleTimeSensor = throttleTimeSensor; + this.log = logContext.logger(NetworkClient.class); + this.state = new AtomicReference<>(State.ACTIVE); + } + + /** + * Begin connecting to the given node, return true if we are already connected and ready to send to that node. + * + * @param node The node to check + * @param now The current timestamp + * @return True if we are ready to send to the given node + */ + @Override + public boolean ready(Node node, long now) { + if (node.isEmpty()) + throw new IllegalArgumentException("Cannot connect to empty node " + node); + + if (isReady(node, now)) + return true; + + if (connectionStates.canConnect(node.idString(), now)) + // if we are interested in sending to a node and we don't have a connection to it, initiate one + initiateConnect(node, now); + + return false; + } + + // Visible for testing + boolean canConnect(Node node, long now) { + return connectionStates.canConnect(node.idString(), now); + } + + /** + * Disconnects the connection to a particular node, if there is one. + * Any pending ClientRequests for this connection will receive disconnections. 
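
// Editorial usage sketch, not part of the original patch: ready(...) above only initiates a
// connection and keeps returning false until the node is actually connected and sendable, so
// callers typically loop over ready()/poll() until the node becomes ready or a deadline passes.
// This helper is hypothetical and simplified (it does not check for connection failures).
import org.apache.kafka.clients.KafkaClient;
import org.apache.kafka.common.Node;
import org.apache.kafka.common.utils.Time;

final class AwaitReadySketch {
    static boolean awaitReady(KafkaClient client, Node node, Time time, long timeoutMs) {
        long deadlineMs = time.milliseconds() + timeoutMs;
        while (time.milliseconds() < deadlineMs) {
            if (client.ready(node, time.milliseconds()))
                return true;                                   // connected and able to send
            client.poll(client.pollDelayMs(node, time.milliseconds()), time.milliseconds());
        }
        return false;                                          // deadline reached without readiness
    }
}
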
+ * + * @param nodeId The id of the node + */ + @Override + public void disconnect(String nodeId) { + if (connectionStates.isDisconnected(nodeId)) { + log.debug("Client requested disconnect from node {}, which is already disconnected", nodeId); + return; + } + + log.info("Client requested disconnect from node {}", nodeId); + selector.close(nodeId); + long now = time.milliseconds(); + cancelInFlightRequests(nodeId, now, abortedSends); + connectionStates.disconnected(nodeId, now); + } + + private void cancelInFlightRequests(String nodeId, long now, Collection responses) { + Iterable inFlightRequests = this.inFlightRequests.clearAll(nodeId); + for (InFlightRequest request : inFlightRequests) { + if (log.isDebugEnabled()) { + log.debug("Cancelled in-flight {} request with correlation id {} due to node {} being disconnected " + + "(elapsed time since creation: {}ms, elapsed time since send: {}ms, request timeout: {}ms): {}", + request.header.apiKey(), request.header.correlationId(), nodeId, + request.timeElapsedSinceCreateMs(now), request.timeElapsedSinceSendMs(now), + request.requestTimeoutMs, request.request); + } else { + log.info("Cancelled in-flight {} request with correlation id {} due to node {} being disconnected " + + "(elapsed time since creation: {}ms, elapsed time since send: {}ms, request timeout: {}ms)", + request.header.apiKey(), request.header.correlationId(), nodeId, + request.timeElapsedSinceCreateMs(now), request.timeElapsedSinceSendMs(now), + request.requestTimeoutMs); + } + + if (!request.isInternalRequest) { + if (responses != null) + responses.add(request.disconnected(now, null)); + } else if (request.header.apiKey() == ApiKeys.METADATA) { + metadataUpdater.handleFailedRequest(now, Optional.empty()); + } + } + } + + /** + * Closes the connection to a particular node (if there is one). + * All requests on the connection will be cleared. ClientRequest callbacks will not be invoked + * for the cleared requests, nor will they be returned from poll(). + * + * @param nodeId The id of the node + */ + @Override + public void close(String nodeId) { + log.info("Client requested connection close from node {}", nodeId); + selector.close(nodeId); + long now = time.milliseconds(); + cancelInFlightRequests(nodeId, now, null); + connectionStates.remove(nodeId); + } + + /** + * Returns the number of milliseconds to wait, based on the connection state, before attempting to send data. When + * disconnected, this respects the reconnect backoff time. When connecting or connected, this handles slow/stalled + * connections. + * + * @param node The node to check + * @param now The current timestamp + * @return The number of milliseconds to wait. + */ + @Override + public long connectionDelay(Node node, long now) { + return connectionStates.connectionDelay(node.idString(), now); + } + + // Return the remaining throttling delay in milliseconds if throttling is in progress. Return 0, otherwise. + // This is for testing. + public long throttleDelayMs(Node node, long now) { + return connectionStates.throttleDelayMs(node.idString(), now); + } + + /** + * Return the poll delay in milliseconds based on both connection and throttle delay. + * @param node the connection to check + * @param now the current time in ms + */ + @Override + public long pollDelayMs(Node node, long now) { + return connectionStates.pollDelayMs(node.idString(), now); + } + + /** + * Check if the connection of the node has failed, based on the connection state. 
Such connection failure are + * usually transient and can be resumed in the next {@link #ready(org.apache.kafka.common.Node, long)} } + * call, but there are cases where transient failures needs to be caught and re-acted upon. + * + * @param node the node to check + * @return true iff the connection has failed and the node is disconnected + */ + @Override + public boolean connectionFailed(Node node) { + return connectionStates.isDisconnected(node.idString()); + } + + /** + * Check if authentication to this node has failed, based on the connection state. Authentication failures are + * propagated without any retries. + * + * @param node the node to check + * @return an AuthenticationException iff authentication has failed, null otherwise + */ + @Override + public AuthenticationException authenticationException(Node node) { + return connectionStates.authenticationException(node.idString()); + } + + /** + * Check if the node with the given id is ready to send more requests. + * + * @param node The node + * @param now The current time in ms + * @return true if the node is ready + */ + @Override + public boolean isReady(Node node, long now) { + // if we need to update our metadata now declare all requests unready to make metadata requests first + // priority + return !metadataUpdater.isUpdateDue(now) && canSendRequest(node.idString(), now); + } + + /** + * Are we connected and ready and able to send more requests to the given connection? + * + * @param node The node + * @param now the current timestamp + */ + private boolean canSendRequest(String node, long now) { + return connectionStates.isReady(node, now) && selector.isChannelReady(node) && + inFlightRequests.canSendMore(node); + } + + /** + * Queue up the given request for sending. Requests can only be sent out to ready nodes. + * @param request The request + * @param now The current timestamp + */ + @Override + public void send(ClientRequest request, long now) { + doSend(request, false, now); + } + + // package-private for testing + void sendInternalMetadataRequest(MetadataRequest.Builder builder, String nodeConnectionId, long now) { + ClientRequest clientRequest = newClientRequest(nodeConnectionId, builder, now, true); + doSend(clientRequest, true, now); + } + + private void doSend(ClientRequest clientRequest, boolean isInternalRequest, long now) { + ensureActive(); + String nodeId = clientRequest.destination(); + if (!isInternalRequest) { + // If this request came from outside the NetworkClient, validate + // that we can send data. If the request is internal, we trust + // that internal code has done this validation. Validation + // will be slightly different for some internal requests (for + // example, ApiVersionsRequests can be sent prior to being in + // READY state.) + if (!canSendRequest(nodeId, now)) + throw new IllegalStateException("Attempt to send a request to node " + nodeId + " which is not ready."); + } + AbstractRequest.Builder builder = clientRequest.requestBuilder(); + try { + NodeApiVersions versionInfo = apiVersions.get(nodeId); + short version; + // Note: if versionInfo is null, we have no server version information. This would be + // the case when sending the initial ApiVersionRequest which fetches the version + // information itself. It is also the case when discoverBrokerVersions is set to false. + if (versionInfo == null) { + version = builder.latestAllowedVersion(); + if (discoverBrokerVersions && log.isTraceEnabled()) + log.trace("No version information found when sending {} with correlation id {} to node {}. 
" + + "Assuming version {}.", clientRequest.apiKey(), clientRequest.correlationId(), nodeId, version); + } else { + version = versionInfo.latestUsableVersion(clientRequest.apiKey(), builder.oldestAllowedVersion(), + builder.latestAllowedVersion()); + } + // The call to build may also throw UnsupportedVersionException, if there are essential + // fields that cannot be represented in the chosen version. + doSend(clientRequest, isInternalRequest, now, builder.build(version)); + } catch (UnsupportedVersionException unsupportedVersionException) { + // If the version is not supported, skip sending the request over the wire. + // Instead, simply add it to the local queue of aborted requests. + log.debug("Version mismatch when attempting to send {} with correlation id {} to {}", builder, + clientRequest.correlationId(), clientRequest.destination(), unsupportedVersionException); + ClientResponse clientResponse = new ClientResponse(clientRequest.makeHeader(builder.latestAllowedVersion()), + clientRequest.callback(), clientRequest.destination(), now, now, + false, unsupportedVersionException, null, null); + + if (!isInternalRequest) + abortedSends.add(clientResponse); + else if (clientRequest.apiKey() == ApiKeys.METADATA) + metadataUpdater.handleFailedRequest(now, Optional.of(unsupportedVersionException)); + } + } + + private void doSend(ClientRequest clientRequest, boolean isInternalRequest, long now, AbstractRequest request) { + String destination = clientRequest.destination(); + RequestHeader header = clientRequest.makeHeader(request.version()); + if (log.isDebugEnabled()) { + log.debug("Sending {} request with header {} and timeout {} to node {}: {}", + clientRequest.apiKey(), header, clientRequest.requestTimeoutMs(), destination, request); + } + Send send = request.toSend(header); + InFlightRequest inFlightRequest = new InFlightRequest( + clientRequest, + header, + isInternalRequest, + request, + send, + now); + this.inFlightRequests.add(inFlightRequest); + selector.send(new NetworkSend(clientRequest.destination(), send)); + } + + /** + * Do actual reads and writes to sockets. + * + * @param timeout The maximum amount of time to wait (in ms) for responses if there are none immediately, + * must be non-negative. The actual timeout will be the minimum of timeout, request timeout and + * metadata timeout + * @param now The current time in milliseconds + * @return The list of responses received + */ + @Override + public List poll(long timeout, long now) { + ensureActive(); + + if (!abortedSends.isEmpty()) { + // If there are aborted sends because of unsupported version exceptions or disconnects, + // handle them immediately without waiting for Selector#poll. 
+ List responses = new ArrayList<>(); + handleAbortedSends(responses); + completeResponses(responses); + return responses; + } + + long metadataTimeout = metadataUpdater.maybeUpdate(now); + try { + this.selector.poll(Utils.min(timeout, metadataTimeout, defaultRequestTimeoutMs)); + } catch (IOException e) { + log.error("Unexpected error during I/O", e); + } + + // process completed actions + long updatedNow = this.time.milliseconds(); + List responses = new ArrayList<>(); + handleCompletedSends(responses, updatedNow); + handleCompletedReceives(responses, updatedNow); + handleDisconnections(responses, updatedNow); + handleConnections(); + handleInitiateApiVersionRequests(updatedNow); + handleTimedOutConnections(responses, updatedNow); + handleTimedOutRequests(responses, updatedNow); + completeResponses(responses); + + return responses; + } + + private void completeResponses(List responses) { + for (ClientResponse response : responses) { + try { + response.onComplete(); + } catch (Exception e) { + log.error("Uncaught error in request completion:", e); + } + } + } + + /** + * Get the number of in-flight requests + */ + @Override + public int inFlightRequestCount() { + return this.inFlightRequests.count(); + } + + @Override + public boolean hasInFlightRequests() { + return !this.inFlightRequests.isEmpty(); + } + + /** + * Get the number of in-flight requests for a given node + */ + @Override + public int inFlightRequestCount(String node) { + return this.inFlightRequests.count(node); + } + + @Override + public boolean hasInFlightRequests(String node) { + return !this.inFlightRequests.isEmpty(node); + } + + @Override + public boolean hasReadyNodes(long now) { + return connectionStates.hasReadyNodes(now); + } + + /** + * Interrupt the client if it is blocked waiting on I/O. + */ + @Override + public void wakeup() { + this.selector.wakeup(); + } + + @Override + public void initiateClose() { + if (state.compareAndSet(State.ACTIVE, State.CLOSING)) { + wakeup(); + } + } + + @Override + public boolean active() { + return state.get() == State.ACTIVE; + } + + private void ensureActive() { + if (!active()) + throw new DisconnectException("NetworkClient is no longer active, state is " + state); + } + + /** + * Close the network client + */ + @Override + public void close() { + state.compareAndSet(State.ACTIVE, State.CLOSING); + if (state.compareAndSet(State.CLOSING, State.CLOSED)) { + this.selector.close(); + this.metadataUpdater.close(); + } else { + log.warn("Attempting to close NetworkClient that has already been closed."); + } + } + + /** + * Choose the node with the fewest outstanding requests which is at least eligible for connection. This method will + * prefer a node with an existing connection, but will potentially choose a node for which we don't yet have a + * connection if all existing connections are in use. If no connection exists, this method will prefer a node + * with least recent connection attempts. This method will never choose a node for which there is no + * existing connection and from which we have disconnected within the reconnect backoff period, or an active + * connection which is being throttled. + * + * @return The node with the fewest in-flight requests. 
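
// Editorial aside, not part of the original patch: a simplified rendering of the preference
// order documented above for leastLoadedNode(), over a hypothetical per-node summary (the real
// implementation also randomizes the starting offset and returns early on an idle connection).
import java.util.List;

class LeastLoadedSketch {
    enum NodeState { READY, CONNECTING, CAN_CONNECT, BACKED_OFF }

    static final class Candidate {
        final String id; final NodeState state; final int inFlight; final long lastAttemptMs;
        Candidate(String id, NodeState state, int inFlight, long lastAttemptMs) {
            this.id = id; this.state = state; this.inFlight = inFlight; this.lastAttemptMs = lastAttemptMs;
        }
    }

    static Candidate choose(List<Candidate> nodes) {
        Candidate ready = null, connecting = null, canConnect = null;
        for (Candidate n : nodes) {
            if (n.state == NodeState.READY && (ready == null || n.inFlight < ready.inFlight))
                ready = n;                     // fewest in-flight requests wins among ready nodes
            else if (n.state == NodeState.CONNECTING)
                connecting = n;
            else if (n.state == NodeState.CAN_CONNECT
                    && (canConnect == null || n.lastAttemptMs < canConnect.lastAttemptMs))
                canConnect = n;                // least recent connection attempt wins
        }
        if (ready != null) return ready;
        if (connecting != null) return connecting;
        return canConnect;                     // may be null if nothing is currently usable
    }
}
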
+ */ + @Override + public Node leastLoadedNode(long now) { + List nodes = this.metadataUpdater.fetchNodes(); + if (nodes.isEmpty()) + throw new IllegalStateException("There are no nodes in the Kafka cluster"); + int inflight = Integer.MAX_VALUE; + + Node foundConnecting = null; + Node foundCanConnect = null; + Node foundReady = null; + + int offset = this.randOffset.nextInt(nodes.size()); + for (int i = 0; i < nodes.size(); i++) { + int idx = (offset + i) % nodes.size(); + Node node = nodes.get(idx); + if (canSendRequest(node.idString(), now)) { + int currInflight = this.inFlightRequests.count(node.idString()); + if (currInflight == 0) { + // if we find an established connection with no in-flight requests we can stop right away + log.trace("Found least loaded node {} connected with no in-flight requests", node); + return node; + } else if (currInflight < inflight) { + // otherwise if this is the best we have found so far, record that + inflight = currInflight; + foundReady = node; + } + } else if (connectionStates.isPreparingConnection(node.idString())) { + foundConnecting = node; + } else if (canConnect(node, now)) { + if (foundCanConnect == null || + this.connectionStates.lastConnectAttemptMs(foundCanConnect.idString()) > + this.connectionStates.lastConnectAttemptMs(node.idString())) { + foundCanConnect = node; + } + } else { + log.trace("Removing node {} from least loaded node selection since it is neither ready " + + "for sending or connecting", node); + } + } + + // We prefer established connections if possible. Otherwise, we will wait for connections + // which are being established before connecting to new nodes. + if (foundReady != null) { + log.trace("Found least loaded node {} with {} inflight requests", foundReady, inflight); + return foundReady; + } else if (foundConnecting != null) { + log.trace("Found least loaded connecting node {}", foundConnecting); + return foundConnecting; + } else if (foundCanConnect != null) { + log.trace("Found least loaded node {} with no active connection", foundCanConnect); + return foundCanConnect; + } else { + log.trace("Least loaded node selection failed to find an available node"); + return null; + } + } + + public static AbstractResponse parseResponse(ByteBuffer responseBuffer, RequestHeader requestHeader) { + try { + return AbstractResponse.parseResponse(responseBuffer, requestHeader); + } catch (BufferUnderflowException e) { + throw new SchemaException("Buffer underflow while parsing response for request with header " + requestHeader, e); + } catch (CorrelationIdMismatchException e) { + if (SaslClientAuthenticator.isReserved(requestHeader.correlationId()) + && !SaslClientAuthenticator.isReserved(e.responseCorrelationId())) + throw new SchemaException("The response is unrelated to Sasl request since its correlation id is " + + e.responseCorrelationId() + " and the reserved range for Sasl request is [ " + + SaslClientAuthenticator.MIN_RESERVED_CORRELATION_ID + "," + + SaslClientAuthenticator.MAX_RESERVED_CORRELATION_ID + "]"); + else { + throw e; + } + } + } + + /** + * Post process disconnection of a node + * + * @param responses The list of responses to update + * @param nodeId Id of the node to be disconnected + * @param now The current time + * @param disconnectState The state of the disconnected channel + */ + private void processDisconnection(List responses, + String nodeId, + long now, + ChannelState disconnectState) { + connectionStates.disconnected(nodeId, now); + apiVersions.remove(nodeId); + 
nodesNeedingApiVersionsFetch.remove(nodeId); + switch (disconnectState.state()) { + case AUTHENTICATION_FAILED: + AuthenticationException exception = disconnectState.exception(); + connectionStates.authenticationFailed(nodeId, now, exception); + log.error("Connection to node {} ({}) failed authentication due to: {}", nodeId, + disconnectState.remoteAddress(), exception.getMessage()); + break; + case AUTHENTICATE: + log.warn("Connection to node {} ({}) terminated during authentication. This may happen " + + "due to any of the following reasons: (1) Authentication failed due to invalid " + + "credentials with brokers older than 1.0.0, (2) Firewall blocking Kafka TLS " + + "traffic (eg it may only allow HTTPS traffic), (3) Transient network issue.", + nodeId, disconnectState.remoteAddress()); + break; + case NOT_CONNECTED: + log.warn("Connection to node {} ({}) could not be established. Broker may not be available.", nodeId, disconnectState.remoteAddress()); + break; + default: + break; // Disconnections in other states are logged at debug level in Selector + } + + cancelInFlightRequests(nodeId, now, responses); + metadataUpdater.handleServerDisconnect(now, nodeId, Optional.ofNullable(disconnectState.exception())); + } + + /** + * Iterate over all the inflight requests and expire any requests that have exceeded the configured requestTimeout. + * The connection to the node associated with the request will be terminated and will be treated as a disconnection. + * + * @param responses The list of responses to update + * @param now The current time + */ + private void handleTimedOutRequests(List responses, long now) { + List nodeIds = this.inFlightRequests.nodesWithTimedOutRequests(now); + for (String nodeId : nodeIds) { + // close connection to the node + this.selector.close(nodeId); + log.info("Disconnecting from node {} due to request timeout.", nodeId); + processDisconnection(responses, nodeId, now, ChannelState.LOCAL_CLOSE); + } + } + + private void handleAbortedSends(List responses) { + responses.addAll(abortedSends); + abortedSends.clear(); + } + + /** + * Handle socket channel connection timeout. The timeout will hit iff a connection + * stays at the ConnectionState.CONNECTING state longer than the timeout value, + * as indicated by ClusterConnectionStates.NodeConnectionState. + * + * @param responses The list of responses to update + * @param now The current time + */ + private void handleTimedOutConnections(List responses, long now) { + List nodes = connectionStates.nodesWithConnectionSetupTimeout(now); + for (String nodeId : nodes) { + this.selector.close(nodeId); + log.info( + "Disconnecting from node {} due to socket connection setup timeout. " + + "The timeout value is {} ms.", + nodeId, + connectionStates.connectionSetupTimeoutMs(nodeId)); + processDisconnection(responses, nodeId, now, ChannelState.LOCAL_CLOSE); + } + } + + /** + * Handle any completed request send. In particular if no response is expected consider the request complete. 
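+ * (For example, a produce request sent with {@code acks=0} does not expect a response, so it is
+ * considered complete as soon as the send itself completes.)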
+ * + * @param responses The list of responses to update + * @param now The current time + */ + private void handleCompletedSends(List responses, long now) { + // if no response is expected then when the send is completed, return it + for (NetworkSend send : this.selector.completedSends()) { + InFlightRequest request = this.inFlightRequests.lastSent(send.destinationId()); + if (!request.expectResponse) { + this.inFlightRequests.completeLastSent(send.destinationId()); + responses.add(request.completed(null, now)); + } + } + } + + /** + * If a response from a node includes a non-zero throttle delay and client-side throttling has been enabled for + * the connection to the node, throttle the connection for the specified delay. + * + * @param response the response + * @param apiVersion the API version of the response + * @param nodeId the id of the node + * @param now The current time + */ + private void maybeThrottle(AbstractResponse response, short apiVersion, String nodeId, long now) { + int throttleTimeMs = response.throttleTimeMs(); + if (throttleTimeMs > 0 && response.shouldClientThrottle(apiVersion)) { + connectionStates.throttle(nodeId, now + throttleTimeMs); + log.trace("Connection to node {} is throttled for {} ms until timestamp {}", nodeId, throttleTimeMs, + now + throttleTimeMs); + } + } + + /** + * Handle any completed receives and update the response list with the responses received. + * + * @param responses The list of responses to update + * @param now The current time + */ + private void handleCompletedReceives(List responses, long now) { + for (NetworkReceive receive : this.selector.completedReceives()) { + String source = receive.source(); + InFlightRequest req = inFlightRequests.completeNext(source); + + AbstractResponse response = parseResponse(receive.payload(), req.header); + if (throttleTimeSensor != null) + throttleTimeSensor.record(response.throttleTimeMs(), now); + + if (log.isDebugEnabled()) { + log.debug("Received {} response from node {} for request with header {}: {}", + req.header.apiKey(), req.destination, req.header, response); + } + + // If the received response includes a throttle delay, throttle the connection. + maybeThrottle(response, req.header.apiVersion(), req.destination, now); + if (req.isInternalRequest && response instanceof MetadataResponse) + metadataUpdater.handleSuccessfulResponse(req.header, now, (MetadataResponse) response); + else if (req.isInternalRequest && response instanceof ApiVersionsResponse) + handleApiVersionsResponse(responses, req, now, (ApiVersionsResponse) response); + else + responses.add(req.completed(response, now)); + } + } + + private void handleApiVersionsResponse(List responses, + InFlightRequest req, long now, ApiVersionsResponse apiVersionsResponse) { + final String node = req.destination; + if (apiVersionsResponse.data().errorCode() != Errors.NONE.code()) { + if (req.request.version() == 0 || apiVersionsResponse.data().errorCode() != Errors.UNSUPPORTED_VERSION.code()) { + log.warn("Received error {} from node {} when making an ApiVersionsRequest with correlation id {}. Disconnecting.", + Errors.forCode(apiVersionsResponse.data().errorCode()), node, req.header.correlationId()); + this.selector.close(node); + processDisconnection(responses, node, now, ChannelState.LOCAL_CLOSE); + } else { + // Starting from Apache Kafka 2.4, ApiKeys field is populated with the supported versions of + // the ApiVersionsRequest when an UNSUPPORTED_VERSION error is returned. + // If not provided, the client falls back to version 0. 
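+ // Retry the ApiVersionsRequest at the highest version of the ApiVersions API that the broker
+ // advertises in its error response (or version 0 if it advertises nothing).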
+ short maxApiVersion = 0; + if (apiVersionsResponse.data().apiKeys().size() > 0) { + ApiVersion apiVersion = apiVersionsResponse.data().apiKeys().find(ApiKeys.API_VERSIONS.id); + if (apiVersion != null) { + maxApiVersion = apiVersion.maxVersion(); + } + } + nodesNeedingApiVersionsFetch.put(node, new ApiVersionsRequest.Builder(maxApiVersion)); + } + return; + } + NodeApiVersions nodeVersionInfo = new NodeApiVersions(apiVersionsResponse.data().apiKeys()); + apiVersions.update(node, nodeVersionInfo); + this.connectionStates.ready(node); + log.debug("Node {} has finalized features epoch: {}, finalized features: {}, supported features: {}, API versions: {}.", + node, apiVersionsResponse.data().finalizedFeaturesEpoch(), apiVersionsResponse.data().finalizedFeatures(), + apiVersionsResponse.data().supportedFeatures(), nodeVersionInfo); + } + + /** + * Handle any disconnected connections + * + * @param responses The list of responses that completed with the disconnection + * @param now The current time + */ + private void handleDisconnections(List responses, long now) { + for (Map.Entry entry : this.selector.disconnected().entrySet()) { + String node = entry.getKey(); + log.info("Node {} disconnected.", node); + processDisconnection(responses, node, now, entry.getValue()); + } + } + + /** + * Record any newly completed connections + */ + private void handleConnections() { + for (String node : this.selector.connected()) { + // We are now connected. Note that we might not still be able to send requests. For instance, + // if SSL is enabled, the SSL handshake happens after the connection is established. + // Therefore, it is still necessary to check isChannelReady before attempting to send on this + // connection. + if (discoverBrokerVersions) { + this.connectionStates.checkingApiVersions(node); + nodesNeedingApiVersionsFetch.put(node, new ApiVersionsRequest.Builder()); + log.debug("Completed connection to node {}. Fetching API versions.", node); + } else { + this.connectionStates.ready(node); + log.debug("Completed connection to node {}. 
Ready.", node); + } + } + } + + private void handleInitiateApiVersionRequests(long now) { + Iterator> iter = nodesNeedingApiVersionsFetch.entrySet().iterator(); + while (iter.hasNext()) { + Map.Entry entry = iter.next(); + String node = entry.getKey(); + if (selector.isChannelReady(node) && inFlightRequests.canSendMore(node)) { + log.debug("Initiating API versions fetch from node {}.", node); + ApiVersionsRequest.Builder apiVersionRequestBuilder = entry.getValue(); + ClientRequest clientRequest = newClientRequest(node, apiVersionRequestBuilder, now, true); + doSend(clientRequest, true, now); + iter.remove(); + } + } + } + + /** + * Initiate a connection to the given node + * @param node the node to connect to + * @param now current time in epoch milliseconds + */ + private void initiateConnect(Node node, long now) { + String nodeConnectionId = node.idString(); + try { + connectionStates.connecting(nodeConnectionId, now, node.host()); + InetAddress address = connectionStates.currentAddress(nodeConnectionId); + log.debug("Initiating connection to node {} using address {}", node, address); + selector.connect(nodeConnectionId, + new InetSocketAddress(address, node.port()), + this.socketSendBuffer, + this.socketReceiveBuffer); + } catch (IOException e) { + log.warn("Error connecting to node {}", node, e); + // Attempt failed, we'll try again after the backoff + connectionStates.disconnected(nodeConnectionId, now); + // Notify metadata updater of the connection failure + metadataUpdater.handleServerDisconnect(now, nodeConnectionId, Optional.empty()); + } + } + + class DefaultMetadataUpdater implements MetadataUpdater { + + /* the current cluster metadata */ + private final Metadata metadata; + + // Defined if there is a request in progress, null otherwise + private InProgressData inProgress; + + DefaultMetadataUpdater(Metadata metadata) { + this.metadata = metadata; + this.inProgress = null; + } + + @Override + public List fetchNodes() { + return metadata.fetch().nodes(); + } + + @Override + public boolean isUpdateDue(long now) { + return !hasFetchInProgress() && this.metadata.timeToNextUpdate(now) == 0; + } + + private boolean hasFetchInProgress() { + return inProgress != null; + } + + @Override + public long maybeUpdate(long now) { + // should we update our metadata? + long timeToNextMetadataUpdate = metadata.timeToNextUpdate(now); + long waitForMetadataFetch = hasFetchInProgress() ? defaultRequestTimeoutMs : 0; + + long metadataTimeout = Math.max(timeToNextMetadataUpdate, waitForMetadataFetch); + if (metadataTimeout > 0) { + return metadataTimeout; + } + + // Beware that the behavior of this method and the computation of timeouts for poll() are + // highly dependent on the behavior of leastLoadedNode. + Node node = leastLoadedNode(now); + if (node == null) { + log.debug("Give up sending metadata request since no node is available"); + return reconnectBackoffMs; + } + + return maybeUpdate(now, node); + } + + @Override + public void handleServerDisconnect(long now, String destinationId, Optional maybeFatalException) { + Cluster cluster = metadata.fetch(); + // 'processDisconnection' generates warnings for misconfigured bootstrap server configuration + // resulting in 'Connection Refused' and misconfigured security resulting in authentication failures. + // The warning below handles the case where a connection to a broker was established, but was disconnected + // before metadata could be obtained. 
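+ // (cluster.isBootstrapConfigured() is only true before the first metadata response has been
+ // processed, so the node id here normally refers to one of the configured bootstrap brokers.)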
+ if (cluster.isBootstrapConfigured()) { + int nodeId = Integer.parseInt(destinationId); + Node node = cluster.nodeById(nodeId); + if (node != null) + log.warn("Bootstrap broker {} disconnected", node); + } + + // If we have a disconnect while an update is due, we treat it as a failed update + // so that we can backoff properly + if (isUpdateDue(now)) + handleFailedRequest(now, Optional.empty()); + + maybeFatalException.ifPresent(metadata::fatalError); + + // The disconnect may be the result of stale metadata, so request an update + metadata.requestUpdate(); + } + + @Override + public void handleFailedRequest(long now, Optional maybeFatalException) { + maybeFatalException.ifPresent(metadata::fatalError); + metadata.failedUpdate(now); + inProgress = null; + } + + @Override + public void handleSuccessfulResponse(RequestHeader requestHeader, long now, MetadataResponse response) { + // If any partition has leader with missing listeners, log up to ten of these partitions + // for diagnosing broker configuration issues. + // This could be a transient issue if listeners were added dynamically to brokers. + List missingListenerPartitions = response.topicMetadata().stream().flatMap(topicMetadata -> + topicMetadata.partitionMetadata().stream() + .filter(partitionMetadata -> partitionMetadata.error == Errors.LISTENER_NOT_FOUND) + .map(partitionMetadata -> new TopicPartition(topicMetadata.topic(), partitionMetadata.partition()))) + .collect(Collectors.toList()); + if (!missingListenerPartitions.isEmpty()) { + int count = missingListenerPartitions.size(); + log.warn("{} partitions have leader brokers without a matching listener, including {}", + count, missingListenerPartitions.subList(0, Math.min(10, count))); + } + + // Check if any topic's metadata failed to get updated + Map errors = response.errors(); + if (!errors.isEmpty()) + log.warn("Error while fetching metadata with correlation id {} : {}", requestHeader.correlationId(), errors); + + // When talking to the startup phase of a broker, it is possible to receive an empty metadata set, which + // we should retry later. + if (response.brokers().isEmpty()) { + log.trace("Ignoring empty metadata response with correlation id {}.", requestHeader.correlationId()); + this.metadata.failedUpdate(now); + } else { + this.metadata.update(inProgress.requestVersion, response, inProgress.isPartialUpdate, now); + } + + inProgress = null; + } + + @Override + public void close() { + this.metadata.close(); + } + + /** + * Return true if there's at least one connection establishment is currently underway + */ + private boolean isAnyNodeConnecting() { + for (Node node : fetchNodes()) { + if (connectionStates.isConnecting(node.idString())) { + return true; + } + } + return false; + } + + /** + * Add a metadata request to the list of sends if we can make one + */ + private long maybeUpdate(long now, Node node) { + String nodeConnectionId = node.idString(); + + if (canSendRequest(nodeConnectionId, now)) { + Metadata.MetadataRequestAndVersion requestAndVersion = metadata.newMetadataRequestAndVersion(now); + MetadataRequest.Builder metadataRequest = requestAndVersion.requestBuilder; + log.debug("Sending metadata request {} to node {}", metadataRequest, node); + sendInternalMetadataRequest(metadataRequest, nodeConnectionId, now); + inProgress = new InProgressData(requestAndVersion.requestVersion, requestAndVersion.isPartialUpdate); + return defaultRequestTimeoutMs; + } + + // If there's any connection establishment underway, wait until it completes. 
This prevents + // the client from unnecessarily connecting to additional nodes while a previous connection + // attempt has not been completed. + if (isAnyNodeConnecting()) { + // Strictly the timeout we should return here is "connect timeout", but as we don't + // have such application level configuration, using reconnect backoff instead. + return reconnectBackoffMs; + } + + if (connectionStates.canConnect(nodeConnectionId, now)) { + // We don't have a connection to this node right now, make one + log.debug("Initialize connection to node {} for sending metadata request", node); + initiateConnect(node, now); + return reconnectBackoffMs; + } + + // connected, but can't send more OR connecting + // In either case, we just need to wait for a network event to let us know the selected + // connection might be usable again. + return Long.MAX_VALUE; + } + + public class InProgressData { + public final int requestVersion; + public final boolean isPartialUpdate; + + private InProgressData(int requestVersion, boolean isPartialUpdate) { + this.requestVersion = requestVersion; + this.isPartialUpdate = isPartialUpdate; + } + } + + } + + @Override + public ClientRequest newClientRequest(String nodeId, + AbstractRequest.Builder requestBuilder, + long createdTimeMs, + boolean expectResponse) { + return newClientRequest(nodeId, requestBuilder, createdTimeMs, expectResponse, defaultRequestTimeoutMs, null); + } + + // visible for testing + int nextCorrelationId() { + if (SaslClientAuthenticator.isReserved(correlation)) { + // the numeric overflow is fine as negative values is acceptable + correlation = SaslClientAuthenticator.MAX_RESERVED_CORRELATION_ID + 1; + } + return correlation++; + } + + @Override + public ClientRequest newClientRequest(String nodeId, + AbstractRequest.Builder requestBuilder, + long createdTimeMs, + boolean expectResponse, + int requestTimeoutMs, + RequestCompletionHandler callback) { + return new ClientRequest(nodeId, requestBuilder, nextCorrelationId(), clientId, createdTimeMs, expectResponse, + requestTimeoutMs, callback); + } + + public boolean discoverBrokerVersions() { + return discoverBrokerVersions; + } + + static class InFlightRequest { + final RequestHeader header; + final String destination; + final RequestCompletionHandler callback; + final boolean expectResponse; + final AbstractRequest request; + final boolean isInternalRequest; // used to flag requests which are initiated internally by NetworkClient + final Send send; + final long sendTimeMs; + final long createdTimeMs; + final long requestTimeoutMs; + + public InFlightRequest(ClientRequest clientRequest, + RequestHeader header, + boolean isInternalRequest, + AbstractRequest request, + Send send, + long sendTimeMs) { + this(header, + clientRequest.requestTimeoutMs(), + clientRequest.createdTimeMs(), + clientRequest.destination(), + clientRequest.callback(), + clientRequest.expectResponse(), + isInternalRequest, + request, + send, + sendTimeMs); + } + + public InFlightRequest(RequestHeader header, + int requestTimeoutMs, + long createdTimeMs, + String destination, + RequestCompletionHandler callback, + boolean expectResponse, + boolean isInternalRequest, + AbstractRequest request, + Send send, + long sendTimeMs) { + this.header = header; + this.requestTimeoutMs = requestTimeoutMs; + this.createdTimeMs = createdTimeMs; + this.destination = destination; + this.callback = callback; + this.expectResponse = expectResponse; + this.isInternalRequest = isInternalRequest; + this.request = request; + this.send = send; + 
this.sendTimeMs = sendTimeMs; + } + + public long timeElapsedSinceSendMs(long currentTimeMs) { + return Math.max(0, currentTimeMs - sendTimeMs); + } + + public long timeElapsedSinceCreateMs(long currentTimeMs) { + return Math.max(0, currentTimeMs - createdTimeMs); + } + + public ClientResponse completed(AbstractResponse response, long timeMs) { + return new ClientResponse(header, callback, destination, createdTimeMs, timeMs, + false, null, null, response); + } + + public ClientResponse disconnected(long timeMs, AuthenticationException authenticationException) { + return new ClientResponse(header, callback, destination, createdTimeMs, timeMs, + true, null, authenticationException, null); + } + + @Override + public String toString() { + return "InFlightRequest(header=" + header + + ", destination=" + destination + + ", expectResponse=" + expectResponse + + ", createdTimeMs=" + createdTimeMs + + ", sendTimeMs=" + sendTimeMs + + ", isInternalRequest=" + isInternalRequest + + ", request=" + request + + ", callback=" + callback + + ", send=" + send + ")"; + } + } + +} diff --git a/clients/src/main/java/org/apache/kafka/clients/NetworkClientUtils.java b/clients/src/main/java/org/apache/kafka/clients/NetworkClientUtils.java new file mode 100644 index 0000000..4c4d635 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/NetworkClientUtils.java @@ -0,0 +1,117 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.clients; + +import org.apache.kafka.common.Node; +import org.apache.kafka.common.errors.DisconnectException; +import org.apache.kafka.common.utils.Time; + +import java.io.IOException; +import java.util.List; + +/** + * Provides additional utilities for {@link NetworkClient} (e.g. to implement blocking behaviour). + */ +public final class NetworkClientUtils { + + private NetworkClientUtils() {} + + /** + * Checks whether the node is currently connected, first calling `client.poll` to ensure that any pending + * disconnects have been processed. + * + * This method can be used to check the status of a connection prior to calling the blocking version to be able + * to tell whether the latter completed a new connection. + */ + public static boolean isReady(KafkaClient client, Node node, long currentTime) { + client.poll(0, currentTime); + return client.isReady(node, currentTime); + } + + /** + * Invokes `client.poll` to discard pending disconnects, followed by `client.ready` and 0 or more `client.poll` + * invocations until the connection to `node` is ready, the timeoutMs expires or the connection fails. + * + * It returns `true` if the call completes normally or `false` if the timeoutMs expires. If the connection fails, + * an `IOException` is thrown instead. 
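+ * A minimal blocking usage sketch (illustrative only; `client`, `node`, `time` and `requestTimeoutMs`
+ * are assumed to already exist):
+ *
+ * <pre>
+ *     if (!NetworkClientUtils.awaitReady(client, node, time, requestTimeoutMs))
+ *         throw new SocketTimeoutException("Failed to connect to " + node + " within " + requestTimeoutMs + " ms");
+ * </pre>
+ *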
Note that if the `NetworkClient` has been configured with a positive + * connection timeoutMs, it is possible for this method to raise an `IOException` for a previous connection which + * has recently disconnected. If authentication to the node fails, an `AuthenticationException` is thrown. + * + * This method is useful for implementing blocking behaviour on top of the non-blocking `NetworkClient`, use it with + * care. + */ + public static boolean awaitReady(KafkaClient client, Node node, Time time, long timeoutMs) throws IOException { + if (timeoutMs < 0) { + throw new IllegalArgumentException("Timeout needs to be greater than 0"); + } + long startTime = time.milliseconds(); + + if (isReady(client, node, startTime) || client.ready(node, startTime)) + return true; + + long attemptStartTime = time.milliseconds(); + while (!client.isReady(node, attemptStartTime) && attemptStartTime - startTime < timeoutMs) { + if (client.connectionFailed(node)) { + throw new IOException("Connection to " + node + " failed."); + } + long pollTimeout = timeoutMs - (attemptStartTime - startTime); // initialize in this order to avoid overflow + client.poll(pollTimeout, attemptStartTime); + if (client.authenticationException(node) != null) + throw client.authenticationException(node); + attemptStartTime = time.milliseconds(); + } + return client.isReady(node, attemptStartTime); + } + + /** + * Invokes `client.send` followed by 1 or more `client.poll` invocations until a response is received or a + * disconnection happens (which can happen for a number of reasons including a request timeout). + * + * In case of a disconnection, an `IOException` is thrown. + * If shutdown is initiated on the client during this method, an IOException is thrown. + * + * This method is useful for implementing blocking behaviour on top of the non-blocking `NetworkClient`, use it with + * care. + */ + public static ClientResponse sendAndReceive(KafkaClient client, ClientRequest request, Time time) throws IOException { + try { + client.send(request, time.milliseconds()); + while (client.active()) { + List responses = client.poll(Long.MAX_VALUE, time.milliseconds()); + for (ClientResponse response : responses) { + if (response.requestHeader().correlationId() == request.correlationId()) { + if (response.wasDisconnected()) { + throw new IOException("Connection to " + response.destination() + " was disconnected before the response was read"); + } + if (response.versionMismatch() != null) { + throw response.versionMismatch(); + } + return response; + } + } + } + throw new IOException("Client was shutdown before response was read"); + } catch (DisconnectException e) { + if (client.active()) + throw e; + else + throw new IOException("Client was shutdown before response was read"); + + } + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/NodeApiVersions.java b/clients/src/main/java/org/apache/kafka/clients/NodeApiVersions.java new file mode 100644 index 0000000..3c09f0e --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/NodeApiVersions.java @@ -0,0 +1,236 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients; + +import org.apache.kafka.common.errors.UnsupportedVersionException; +import org.apache.kafka.common.message.ApiVersionsResponseData.ApiVersion; +import org.apache.kafka.common.message.ApiVersionsResponseData.ApiVersionCollection; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.requests.ApiVersionsResponse; +import org.apache.kafka.common.utils.Utils; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.EnumMap; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.TreeMap; + +/** + * An internal class which represents the API versions supported by a particular node. + */ +public class NodeApiVersions { + + // A map of the usable versions of each API, keyed by the ApiKeys instance + private final Map supportedVersions = new EnumMap<>(ApiKeys.class); + + // List of APIs which the broker supports, but which are unknown to the client + private final List unknownApis = new ArrayList<>(); + + /** + * Create a NodeApiVersions object with the current ApiVersions. + * + * @return A new NodeApiVersions object. + */ + public static NodeApiVersions create() { + return create(Collections.emptyList()); + } + + /** + * Create a NodeApiVersions object. + * + * @param overrides API versions to override. Any ApiVersion not specified here will be set to the current client + * value. + * @return A new NodeApiVersions object. + */ + public static NodeApiVersions create(Collection overrides) { + List apiVersions = new LinkedList<>(overrides); + for (ApiKeys apiKey : ApiKeys.zkBrokerApis()) { + boolean exists = false; + for (ApiVersion apiVersion : apiVersions) { + if (apiVersion.apiKey() == apiKey.id) { + exists = true; + break; + } + } + if (!exists) apiVersions.add(ApiVersionsResponse.toApiVersion(apiKey)); + } + return new NodeApiVersions(apiVersions); + } + + + /** + * Create a NodeApiVersions object with a single ApiKey. It is mainly used in tests. + * + * @param apiKey ApiKey's id. + * @param minVersion ApiKey's minimum version. + * @param maxVersion ApiKey's maximum version. + * @return A new NodeApiVersions object. 
+ */ + public static NodeApiVersions create(short apiKey, short minVersion, short maxVersion) { + return create(Collections.singleton(new ApiVersion() + .setApiKey(apiKey) + .setMinVersion(minVersion) + .setMaxVersion(maxVersion))); + } + + public NodeApiVersions(ApiVersionCollection nodeApiVersions) { + for (ApiVersion nodeApiVersion : nodeApiVersions) { + if (ApiKeys.hasId(nodeApiVersion.apiKey())) { + ApiKeys nodeApiKey = ApiKeys.forId(nodeApiVersion.apiKey()); + supportedVersions.put(nodeApiKey, nodeApiVersion); + } else { + // Newer brokers may support ApiKeys we don't know about + unknownApis.add(nodeApiVersion); + } + } + } + + public NodeApiVersions(Collection nodeApiVersions) { + for (ApiVersion nodeApiVersion : nodeApiVersions) { + if (ApiKeys.hasId(nodeApiVersion.apiKey())) { + ApiKeys nodeApiKey = ApiKeys.forId(nodeApiVersion.apiKey()); + supportedVersions.put(nodeApiKey, nodeApiVersion); + } else { + // Newer brokers may support ApiKeys we don't know about + unknownApis.add(nodeApiVersion); + } + } + } + + /** + * Return the most recent version supported by both the node and the local software. + */ + public short latestUsableVersion(ApiKeys apiKey) { + return latestUsableVersion(apiKey, apiKey.oldestVersion(), apiKey.latestVersion()); + } + + /** + * Get the latest version supported by the broker within an allowed range of versions + */ + public short latestUsableVersion(ApiKeys apiKey, short oldestAllowedVersion, short latestAllowedVersion) { + if (!supportedVersions.containsKey(apiKey)) + throw new UnsupportedVersionException("The broker does not support " + apiKey); + ApiVersion supportedVersion = supportedVersions.get(apiKey); + Optional intersectVersion = ApiVersionsResponse.intersect(supportedVersion, + new ApiVersion() + .setApiKey(apiKey.id) + .setMinVersion(oldestAllowedVersion) + .setMaxVersion(latestAllowedVersion)); + + if (intersectVersion.isPresent()) + return intersectVersion.get().maxVersion(); + else + throw new UnsupportedVersionException("The broker does not support " + apiKey + + " with version in range [" + oldestAllowedVersion + "," + latestAllowedVersion + "]. The supported" + + " range is [" + supportedVersion.minVersion() + "," + supportedVersion.maxVersion() + "]."); + } + + /** + * Convert the object to a string with no linebreaks.

+ *

+ * This toString method is relatively expensive, so avoid calling it unless debug logging is turned on. + */ + @Override + public String toString() { + return toString(false); + } + + /** + * Convert the object to a string. + * + * @param lineBreaks True if we should add a linebreak after each api. + */ + public String toString(boolean lineBreaks) { + // The apiVersion collection may not be in sorted order. We put it into + // a TreeMap before printing it out to ensure that we always print in + // ascending order. + TreeMap apiKeysText = new TreeMap<>(); + for (ApiVersion supportedVersion : this.supportedVersions.values()) + apiKeysText.put(supportedVersion.apiKey(), apiVersionToText(supportedVersion)); + for (ApiVersion apiVersion : unknownApis) + apiKeysText.put(apiVersion.apiKey(), apiVersionToText(apiVersion)); + + // Also handle the case where some apiKey types are not specified at all in the given ApiVersions, + // which may happen when the remote is too old. + for (ApiKeys apiKey : ApiKeys.zkBrokerApis()) { + if (!apiKeysText.containsKey(apiKey.id)) { + StringBuilder bld = new StringBuilder(); + bld.append(apiKey.name).append("("). + append(apiKey.id).append("): ").append("UNSUPPORTED"); + apiKeysText.put(apiKey.id, bld.toString()); + } + } + String separator = lineBreaks ? ",\n\t" : ", "; + StringBuilder bld = new StringBuilder(); + bld.append("("); + if (lineBreaks) + bld.append("\n\t"); + bld.append(Utils.join(apiKeysText.values(), separator)); + if (lineBreaks) + bld.append("\n"); + bld.append(")"); + return bld.toString(); + } + + private String apiVersionToText(ApiVersion apiVersion) { + StringBuilder bld = new StringBuilder(); + ApiKeys apiKey = null; + if (ApiKeys.hasId(apiVersion.apiKey())) { + apiKey = ApiKeys.forId(apiVersion.apiKey()); + bld.append(apiKey.name).append("(").append(apiKey.id).append("): "); + } else { + bld.append("UNKNOWN(").append(apiVersion.apiKey()).append("): "); + } + + if (apiVersion.minVersion() == apiVersion.maxVersion()) { + bld.append(apiVersion.minVersion()); + } else { + bld.append(apiVersion.minVersion()).append(" to ").append(apiVersion.maxVersion()); + } + + if (apiKey != null) { + ApiVersion supportedVersion = supportedVersions.get(apiKey); + if (apiKey.latestVersion() < supportedVersion.minVersion()) { + bld.append(" [unusable: node too new]"); + } else if (supportedVersion.maxVersion() < apiKey.oldestVersion()) { + bld.append(" [unusable: node too old]"); + } else { + short latestUsableVersion = Utils.min(apiKey.latestVersion(), supportedVersion.maxVersion()); + bld.append(" [usable: ").append(latestUsableVersion).append("]"); + } + } + return bld.toString(); + } + + /** + * Get the version information for a given API. + * + * @param apiKey The api key to lookup + * @return The api version information from the broker or null if it is unsupported + */ + public ApiVersion apiVersion(ApiKeys apiKey) { + return supportedVersions.get(apiKey); + } + + public Map allSupportedApiVersions() { + return supportedVersions; + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/RequestCompletionHandler.java b/clients/src/main/java/org/apache/kafka/clients/RequestCompletionHandler.java new file mode 100644 index 0000000..add623f --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/RequestCompletionHandler.java @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients; + +/** + * A callback interface for attaching an action to be executed when a request is complete and the corresponding response + * has been received. This handler will also be invoked if there is a disconnection while handling the request. + */ +public interface RequestCompletionHandler { + + void onComplete(ClientResponse response); +} diff --git a/clients/src/main/java/org/apache/kafka/clients/StaleMetadataException.java b/clients/src/main/java/org/apache/kafka/clients/StaleMetadataException.java new file mode 100644 index 0000000..dafc2d5 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/StaleMetadataException.java @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients; + +import org.apache.kafka.common.errors.InvalidMetadataException; + +/** + * Thrown when current metadata cannot be used. This is often used as a way to trigger a metadata + * update before retrying another operation. + * + * Note: this is not a public API. + */ +public class StaleMetadataException extends InvalidMetadataException { + private static final long serialVersionUID = 1L; + + public StaleMetadataException() {} + + public StaleMetadataException(String message) { + super(message); + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/AbortTransactionOptions.java b/clients/src/main/java/org/apache/kafka/clients/admin/AbortTransactionOptions.java new file mode 100644 index 0000000..52dc6b1 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/AbortTransactionOptions.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients.admin; + +import org.apache.kafka.common.annotation.InterfaceStability; + +@InterfaceStability.Evolving +public class AbortTransactionOptions extends AbstractOptions { + + @Override + public String toString() { + return "AbortTransactionOptions(" + + "timeoutMs=" + timeoutMs + + ')'; + } + +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/AbortTransactionResult.java b/clients/src/main/java/org/apache/kafka/clients/admin/AbortTransactionResult.java new file mode 100644 index 0000000..30cbcee --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/AbortTransactionResult.java @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients.admin; + +import org.apache.kafka.common.KafkaFuture; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.annotation.InterfaceStability; + +import java.util.Map; + +/** + * The result of {@link Admin#abortTransaction(AbortTransactionSpec, AbortTransactionOptions)}. + * + * The API of this class is evolving, see {@link Admin} for details. + */ +@InterfaceStability.Evolving +public class AbortTransactionResult { + private final Map> futures; + + AbortTransactionResult(Map> futures) { + this.futures = futures; + } + + /** + * Get a future which completes when the transaction specified by {@link AbortTransactionSpec} + * in the respective call to {@link Admin#abortTransaction(AbortTransactionSpec, AbortTransactionOptions)} + * returns successfully or fails due to an error or timeout. + * + * @return the future + */ + public KafkaFuture all() { + return KafkaFuture.allOf(futures.values().toArray(new KafkaFuture[0])); + } + +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/AbortTransactionSpec.java b/clients/src/main/java/org/apache/kafka/clients/admin/AbortTransactionSpec.java new file mode 100644 index 0000000..9eb7057 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/AbortTransactionSpec.java @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients.admin; + +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.annotation.InterfaceStability; + +import java.util.Objects; + +@InterfaceStability.Evolving +public class AbortTransactionSpec { + private final TopicPartition topicPartition; + private final long producerId; + private final short producerEpoch; + private final int coordinatorEpoch; + + public AbortTransactionSpec( + TopicPartition topicPartition, + long producerId, + short producerEpoch, + int coordinatorEpoch + ) { + this.topicPartition = topicPartition; + this.producerId = producerId; + this.producerEpoch = producerEpoch; + this.coordinatorEpoch = coordinatorEpoch; + } + + public TopicPartition topicPartition() { + return topicPartition; + } + + public long producerId() { + return producerId; + } + + public short producerEpoch() { + return producerEpoch; + } + + public int coordinatorEpoch() { + return coordinatorEpoch; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + AbortTransactionSpec that = (AbortTransactionSpec) o; + return producerId == that.producerId && + producerEpoch == that.producerEpoch && + coordinatorEpoch == that.coordinatorEpoch && + Objects.equals(topicPartition, that.topicPartition); + } + + @Override + public int hashCode() { + return Objects.hash(topicPartition, producerId, producerEpoch, coordinatorEpoch); + } + + @Override + public String toString() { + return "AbortTransactionSpec(" + + "topicPartition=" + topicPartition + + ", producerId=" + producerId + + ", producerEpoch=" + producerEpoch + + ", coordinatorEpoch=" + coordinatorEpoch + + ')'; + } + +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/AbstractOptions.java b/clients/src/main/java/org/apache/kafka/clients/admin/AbstractOptions.java new file mode 100644 index 0000000..2312fe4 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/AbstractOptions.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.clients.admin; + + +/* + * This class implements the common APIs that are shared by Options classes for various AdminClient commands + */ +public abstract class AbstractOptions { + + protected Integer timeoutMs = null; + + /** + * Set the timeout in milliseconds for this operation or {@code null} if the default api timeout for the + * AdminClient should be used. + */ + @SuppressWarnings("unchecked") + public T timeoutMs(Integer timeoutMs) { + this.timeoutMs = timeoutMs; + return (T) this; + } + + /** + * The timeout in milliseconds for this operation or {@code null} if the default api timeout for the + * AdminClient should be used. + */ + public Integer timeoutMs() { + return timeoutMs; + } + + +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/Admin.java b/clients/src/main/java/org/apache/kafka/clients/admin/Admin.java new file mode 100644 index 0000000..377f009 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/Admin.java @@ -0,0 +1,1582 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.clients.admin; + +import org.apache.kafka.clients.consumer.OffsetAndMetadata; +import org.apache.kafka.common.ElectionType; +import org.apache.kafka.common.KafkaFuture; +import org.apache.kafka.common.Metric; +import org.apache.kafka.common.MetricName; +import org.apache.kafka.common.TopicCollection; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.TopicPartitionReplica; +import org.apache.kafka.common.acl.AclBinding; +import org.apache.kafka.common.acl.AclBindingFilter; +import org.apache.kafka.common.annotation.InterfaceStability; +import org.apache.kafka.common.config.ConfigResource; +import org.apache.kafka.common.errors.FeatureUpdateFailedException; +import org.apache.kafka.common.quota.ClientQuotaAlteration; +import org.apache.kafka.common.quota.ClientQuotaFilter; +import org.apache.kafka.common.requests.LeaveGroupResponse; + +import java.time.Duration; +import java.util.Collection; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Properties; +import java.util.Set; + +/** + * The administrative client for Kafka, which supports managing and inspecting topics, brokers, configurations and ACLs. + *

+ * <p>
+ * Instances returned from the {@code create} methods of this interface are guaranteed to be thread safe.
+ * However, the {@link KafkaFuture KafkaFutures} returned from request methods are executed
+ * by a single thread so it is important that any code which executes on that thread when they complete
+ * (using {@link KafkaFuture#thenApply(KafkaFuture.Function)}, for example) doesn't block
+ * for too long. If necessary, processing of results should be passed to another thread.
+ * <p>
+ * The operations exposed by Admin follow a consistent pattern:
+ * <ul>
+ *     <li>Admin instances should be created using {@link Admin#create(Properties)} or {@link Admin#create(Map)}</li>
+ *     <li>Each operation typically has two overloaded methods, one which uses a default set of options and an
+ *     overloaded method where the last parameter is an explicit options object.</li>
+ *     <li>The operation method's first parameter is a {@code Collection} of items to perform
+ *     the operation on. Batching multiple requests into a single call is more efficient and should be
+ *     preferred over multiple calls to the same method.</li>
+ *     <li>The operation methods execute asynchronously.</li>
+ *     <li>Each {@code xxx} operation method returns an {@code XxxResult} class with methods which expose
+ *     {@link KafkaFuture} for accessing the result(s) of the operation.</li>
+ *     <li>Typically an {@code all()} method is provided for getting the overall success/failure of the batch and a
+ *     {@code values()} method providing access to each item in a request batch.
+ *     Other methods may also be provided.</li>
+ *     <li>For synchronous behaviour use {@link KafkaFuture#get()}</li>
+ * </ul>
+ * <p>
+ * Here is a simple example of using an Admin client instance to create a new topic:
+ * <pre>
+ * {@code
+ * Properties props = new Properties();
+ * props.put(AdminClientConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9092");
+ *
+ * try (Admin admin = Admin.create(props)) {
+ *   String topicName = "my-topic";
+ *   int partitions = 12;
+ *   short replicationFactor = 3;
+ *   // Create a compacted topic
+ *   CreateTopicsResult result = admin.createTopics(Collections.singleton(
+ *     new NewTopic(topicName, partitions, replicationFactor)
+ *       .configs(Collections.singletonMap(TopicConfig.CLEANUP_POLICY_CONFIG, TopicConfig.CLEANUP_POLICY_COMPACT))));
+ *
+ *   // Call values() to get the result for a specific topic
+ *   KafkaFuture<Void> future = result.values().get(topicName);
+ *
+ *   // Call get() to block until the topic creation is complete or has failed;
+ *   // if creation failed, the ExecutionException wraps the underlying cause.
+ *   future.get();
+ * }
+ * }
+ * </pre>
+ *
+ * <h3>Bootstrap and balancing</h3>
+ * <p>
+ * The {@code bootstrap.servers} config in the {@code Map} or {@code Properties} passed + * to {@link Admin#create(Properties)} is only used for discovering the brokers in the cluster, + * which the client will then connect to as needed. + * As such, it is sufficient to include only two or three broker addresses to cope with the possibility of brokers + * being unavailable. + *

+ * Different operations necessitate requests being sent to different nodes in the cluster. For example + * {@link #createTopics(Collection)} communicates with the controller, but {@link #describeTopics(Collection)} + * can talk to any broker. When the recipient does not matter the instance will try to use the broker with the + * fewest outstanding requests. + *

+ * The client will transparently retry certain errors which are usually transient.
+ * For example, if a {@code createTopics()} request is sent to a node which was not the controller,
+ * the metadata will be refreshed and the request re-sent to the controller.
+ *
+ * <h3>Broker Compatibility</h3>
+ * <p>

+ * The minimum broker version required is 0.10.0.0. Methods with stricter requirements will specify the minimum broker + * version required. + *

+ * This client was introduced in 0.11.0.0 and the API is still evolving. We will try to evolve the API in a compatible + * manner, but we reserve the right to make breaking changes in minor releases, if necessary. We will update the + * {@code InterfaceStability} annotation and this notice once the API is considered stable. + *

+ */ +@InterfaceStability.Evolving +public interface Admin extends AutoCloseable { + + /** + * Create a new Admin with the given configuration. + * + * @param props The configuration. + * @return The new KafkaAdminClient. + */ + static Admin create(Properties props) { + return KafkaAdminClient.createInternal(new AdminClientConfig(props, true), null); + } + + /** + * Create a new Admin with the given configuration. + * + * @param conf The configuration. + * @return The new KafkaAdminClient. + */ + static Admin create(Map conf) { + return KafkaAdminClient.createInternal(new AdminClientConfig(conf, true), null, null); + } + + /** + * Close the Admin and release all associated resources. + *

+ * See {@link #close(Duration)} + */ + @Override + default void close() { + close(Duration.ofMillis(Long.MAX_VALUE)); + } + + /** + * Close the Admin client and release all associated resources. + *

+ * The close operation has a grace period during which current operations will be allowed to
+ * complete, specified by the given duration.
+ * New operations will not be accepted during the grace period. Once the grace period is over,
+ * all operations that have not yet been completed will be aborted with a {@link org.apache.kafka.common.errors.TimeoutException}.
+ *
+ * @param timeout The grace period during which in-flight operations are allowed to complete before the client is closed.
+ */
+ void close(Duration timeout);
+
+ /**
+ * Create a batch of new topics with the default options.
+ *

+ * This is a convenience method for {@link #createTopics(Collection, CreateTopicsOptions)} with default options. + * See the overload for more details. + *

+ * This operation is supported by brokers with version 0.10.1.0 or higher. + * + * @param newTopics The new topics to create. + * @return The CreateTopicsResult. + */ + default CreateTopicsResult createTopics(Collection newTopics) { + return createTopics(newTopics, new CreateTopicsOptions()); + } + + /** + * Create a batch of new topics. + *

+ * This operation is not transactional so it may succeed for some topics while fail for others. + *

+ * It may take several seconds after {@link CreateTopicsResult} returns + * success for all the brokers to become aware that the topics have been created. + * During this time, {@link #listTopics()} and {@link #describeTopics(Collection)} + * may not return information about the new topics. + *
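+ * <p>
+ * For example, to validate a topic creation without actually creating the topic (illustrative sketch;
+ * assumes an existing {@code admin} instance):
+ * <pre>{@code
+ * admin.createTopics(
+ *     Collections.singleton(new NewTopic("events", 6, (short) 3)),
+ *     new CreateTopicsOptions().validateOnly(true)
+ * ).all().get();
+ * }</pre>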

+ * This operation is supported by brokers with version 0.10.1.0 or higher. The validateOnly option is supported + * from version 0.10.2.0. + * + * @param newTopics The new topics to create. + * @param options The options to use when creating the new topics. + * @return The CreateTopicsResult. + */ + CreateTopicsResult createTopics(Collection newTopics, CreateTopicsOptions options); + + /** + * This is a convenience method for {@link #deleteTopics(TopicCollection, DeleteTopicsOptions)} + * with default options. See the overload for more details. + *

+ * This operation is supported by brokers with version 0.10.1.0 or higher.
+ *
+ * @param topics The topic names to delete.
+ * @return The DeleteTopicsResult.
+ */
+ default DeleteTopicsResult deleteTopics(Collection<String> topics) {
+     return deleteTopics(TopicCollection.ofTopicNames(topics), new DeleteTopicsOptions());
+ }
+
+ /**
+ * This is a convenience method for {@link #deleteTopics(TopicCollection, DeleteTopicsOptions)},
+ * taking topic names rather than a {@link TopicCollection}. See that overload for more details.
+ *

+ * This operation is supported by brokers with version 0.10.1.0 or higher. + * + * @param topics The topic names to delete. + * @param options The options to use when deleting the topics. + * @return The DeleteTopicsResult. + */ + default DeleteTopicsResult deleteTopics(Collection topics, DeleteTopicsOptions options) { + return deleteTopics(TopicCollection.ofTopicNames(topics), options); + } + + /** + * This is a convenience method for {@link #deleteTopics(TopicCollection, DeleteTopicsOptions)} + * with default options. See the overload for more details. + *
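+ * <p>
+ * For example, deleting a topic by its id (illustrative sketch; assumes an existing {@code admin}
+ * instance and a {@code Uuid topicId} obtained elsewhere, e.g. from a describe call):
+ * <pre>{@code
+ * admin.deleteTopics(TopicCollection.ofTopicIds(Collections.singleton(topicId))).all().get();
+ * }</pre>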

+ * When using topic IDs, this operation is supported by brokers with inter-broker protocol 2.8 or higher. + * When using topic names, this operation is supported by brokers with version 0.10.1.0 or higher. + * + * @param topics The topics to delete. + * @return The DeleteTopicsResult. + */ + default DeleteTopicsResult deleteTopics(TopicCollection topics) { + return deleteTopics(topics, new DeleteTopicsOptions()); + } + + /** + * Delete a batch of topics. + *

+ * This operation is not transactional so it may succeed for some topics while fail for others. + *

+ * It may take several seconds after the {@link DeleteTopicsResult} returns + * success for all the brokers to become aware that the topics are gone. + * During this time, {@link #listTopics()} and {@link #describeTopics(Collection)} + * may continue to return information about the deleted topics. + *

+ * If delete.topic.enable is false on the brokers, deleteTopics will mark + * the topics for deletion, but not actually delete them. The futures will + * return successfully in this case. + *

+ * When using topic IDs, this operation is supported by brokers with inter-broker protocol 2.8 or higher. + * When using topic names, this operation is supported by brokers with version 0.10.1.0 or higher. + * + * @param topics The topics to delete. + * @param options The options to use when deleting the topics. + * @return The DeleteTopicsResult. + */ + DeleteTopicsResult deleteTopics(TopicCollection topics, DeleteTopicsOptions options); + + /** + * List the topics available in the cluster with the default options. + *

+ * This is a convenience method for {@link #listTopics(ListTopicsOptions)} with default options. + * See the overload for more details. + * + * @return The ListTopicsResult. + */ + default ListTopicsResult listTopics() { + return listTopics(new ListTopicsOptions()); + } + + /** + * List the topics available in the cluster. + * + * @param options The options to use when listing the topics. + * @return The ListTopicsResult. + */ + ListTopicsResult listTopics(ListTopicsOptions options); + + /** + * Describe some topics in the cluster, with the default options. + *

+ * This is a convenience method for {@link #describeTopics(Collection, DescribeTopicsOptions)} with + * default options. See the overload for more details. + * + * @param topicNames The names of the topics to describe. + * @return The DescribeTopicsResult. + */ + default DescribeTopicsResult describeTopics(Collection topicNames) { + return describeTopics(topicNames, new DescribeTopicsOptions()); + } + + /** + * Describe some topics in the cluster. + * + * @param topicNames The names of the topics to describe. + * @param options The options to use when describing the topic. + * @return The DescribeTopicsResult. + */ + default DescribeTopicsResult describeTopics(Collection topicNames, DescribeTopicsOptions options) { + return describeTopics(TopicCollection.ofTopicNames(topicNames), options); + } + + /** + * This is a convenience method for {@link #describeTopics(TopicCollection, DescribeTopicsOptions)} + * with default options. See the overload for more details. + *

+ * When using topic IDs, this operation is supported by brokers with version 3.1.0 or higher. + * + * @param topics The topics to describe. + * @return The DescribeTopicsResult. + */ + default DescribeTopicsResult describeTopics(TopicCollection topics) { + return describeTopics(topics, new DescribeTopicsOptions()); + } + + /** + * Describe some topics in the cluster. + * + * When using topic IDs, this operation is supported by brokers with version 3.1.0 or higher. + * + * @param topics The topics to describe. + * @param options The options to use when describing the topics. + * @return The DescribeTopicsResult. + */ + DescribeTopicsResult describeTopics(TopicCollection topics, DescribeTopicsOptions options); + + /** + * Get information about the nodes in the cluster, using the default options. + *
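+ * <p>A minimal usage sketch (illustrative only; {@code admin} is an assumed existing {@code Admin} instance):
+ * <pre>{@code
+ * DescribeClusterResult cluster = admin.describeCluster();
+ * String clusterId = cluster.clusterId().get();
+ * Collection<Node> nodes = cluster.nodes().get();
+ * }</pre>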

+ * This is a convenience method for {@link #describeCluster(DescribeClusterOptions)} with default options. + * See the overload for more details. + * + * @return The DescribeClusterResult. + */ + default DescribeClusterResult describeCluster() { + return describeCluster(new DescribeClusterOptions()); + } + + /** + * Get information about the nodes in the cluster. + * + * @param options The options to use when getting information about the cluster. + * @return The DescribeClusterResult. + */ + DescribeClusterResult describeCluster(DescribeClusterOptions options); + + /** + * This is a convenience method for {@link #describeAcls(AclBindingFilter, DescribeAclsOptions)} with + * default options. See the overload for more details. + *

+ * This operation is supported by brokers with version 0.11.0.0 or higher. + * + * @param filter The filter to use. + * @return The DescribeAclsResult. + */ + default DescribeAclsResult describeAcls(AclBindingFilter filter) { + return describeAcls(filter, new DescribeAclsOptions()); + } + + /** + * Lists access control lists (ACLs) according to the supplied filter. + *

+ * Note: it may take some time for changes made by {@code createAcls} or {@code deleteAcls} to be reflected + * in the output of {@code describeAcls}. + *

+ * This operation is supported by brokers with version 0.11.0.0 or higher. + * + * @param filter The filter to use. + * @param options The options to use when listing the ACLs. + * @return The DescribeAclsResult. + */ + DescribeAclsResult describeAcls(AclBindingFilter filter, DescribeAclsOptions options); + + /** + * This is a convenience method for {@link #createAcls(Collection, CreateAclsOptions)} with + * default options. See the overload for more details. + *

+ * This operation is supported by brokers with version 0.11.0.0 or higher. + * + * @param acls The ACLs to create + * @return The CreateAclsResult. + */ + default CreateAclsResult createAcls(Collection acls) { + return createAcls(acls, new CreateAclsOptions()); + } + + /** + * Creates access control lists (ACLs) which are bound to specific resources. + *

+ * This operation is not transactional so it may succeed for some ACLs while fail for others. + *

+ * If you attempt to add an ACL that duplicates an existing ACL, no error will be raised, but + * no changes will be made. + *
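+ * <p>A minimal usage sketch (illustrative only; {@code admin} is an assumed existing {@code Admin}
+ * instance, and the principal, host and topic name are example values):
+ * <pre>{@code
+ * AclBinding binding = new AclBinding(
+ *     new ResourcePattern(ResourceType.TOPIC, "my-topic", PatternType.LITERAL),
+ *     new AccessControlEntry("User:alice", "*", AclOperation.READ, AclPermissionType.ALLOW));
+ * admin.createAcls(Collections.singleton(binding), new CreateAclsOptions()).all().get();
+ * }</pre>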

+ * This operation is supported by brokers with version 0.11.0.0 or higher. + * + * @param acls The ACLs to create + * @param options The options to use when creating the ACLs. + * @return The CreateAclsResult. + */ + CreateAclsResult createAcls(Collection acls, CreateAclsOptions options); + + /** + * This is a convenience method for {@link #deleteAcls(Collection, DeleteAclsOptions)} with default options. + * See the overload for more details. + *

+ * This operation is supported by brokers with version 0.11.0.0 or higher. + * + * @param filters The filters to use. + * @return The DeleteAclsResult. + */ + default DeleteAclsResult deleteAcls(Collection filters) { + return deleteAcls(filters, new DeleteAclsOptions()); + } + + /** + * Deletes access control lists (ACLs) according to the supplied filters. + *

+ * This operation is not transactional so it may succeed for some ACLs while fail for others. + *

+ * This operation is supported by brokers with version 0.11.0.0 or higher. + * + * @param filters The filters to use. + * @param options The options to use when deleting the ACLs. + * @return The DeleteAclsResult. + */ + DeleteAclsResult deleteAcls(Collection filters, DeleteAclsOptions options); + + + /** + * Get the configuration for the specified resources with the default options. + *

+ * This is a convenience method for {@link #describeConfigs(Collection, DescribeConfigsOptions)} with default options. + * See the overload for more details. + *

+ * This operation is supported by brokers with version 0.11.0.0 or higher. + * + * @param resources The resources (topic and broker resource types are currently supported) + * @return The DescribeConfigsResult + */ + default DescribeConfigsResult describeConfigs(Collection resources) { + return describeConfigs(resources, new DescribeConfigsOptions()); + } + + /** + * Get the configuration for the specified resources. + *

+ * The returned configuration includes default values and the isDefault() method can be used to distinguish them + * from user supplied values. + *

+ * The value of config entries where isSensitive() is true is always {@code null} so that sensitive information + * is not disclosed. + *

+ * Config entries where isReadOnly() is true cannot be updated. + *
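+ * <p>A minimal usage sketch (illustrative only; {@code admin} is an assumed existing {@code Admin}
+ * instance and the topic name is an example value):
+ * <pre>{@code
+ * ConfigResource resource = new ConfigResource(ConfigResource.Type.TOPIC, "my-topic");
+ * Config config = admin.describeConfigs(Collections.singleton(resource), new DescribeConfigsOptions())
+ *     .all().get().get(resource);
+ * String retention = config.get("retention.ms").value();
+ * }</pre>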

+ * This operation is supported by brokers with version 0.11.0.0 or higher. + * + * @param resources The resources (topic and broker resource types are currently supported) + * @param options The options to use when describing configs + * @return The DescribeConfigsResult + */ + DescribeConfigsResult describeConfigs(Collection resources, DescribeConfigsOptions options); + + /** + * Update the configuration for the specified resources with the default options. + *

+ * This is a convenience method for {@link #alterConfigs(Map, AlterConfigsOptions)} with default options. + * See the overload for more details. + *

+ * This operation is supported by brokers with version 0.11.0.0 or higher. + * + * @param configs The resources with their configs (topic is the only resource type with configs that can + * be updated currently) + * @return The AlterConfigsResult + * @deprecated Since 2.3. Use {@link #incrementalAlterConfigs(Map)}. + */ + @Deprecated + default AlterConfigsResult alterConfigs(Map configs) { + return alterConfigs(configs, new AlterConfigsOptions()); + } + + /** + * Update the configuration for the specified resources with the default options. + *

+ * Updates are not transactional so they may succeed for some resources while fail for others. The configs for + * a particular resource are updated atomically. + *

+ * This operation is supported by brokers with version 0.11.0.0 or higher. + * + * @param configs The resources with their configs (topic is the only resource type with configs that can + * be updated currently) + * @param options The options to use when describing configs + * @return The AlterConfigsResult + * @deprecated Since 2.3. Use {@link #incrementalAlterConfigs(Map, AlterConfigsOptions)}. + */ + @Deprecated + AlterConfigsResult alterConfigs(Map configs, AlterConfigsOptions options); + + /** + * Incrementally updates the configuration for the specified resources with default options. + *

+ * This is a convenience method for {@link #incrementalAlterConfigs(Map, AlterConfigsOptions)} with default options. + * See the overload for more details. + *

+ * This operation is supported by brokers with version 2.3.0 or higher. + * + * @param configs The resources with their configs + * @return The AlterConfigsResult + */ + default AlterConfigsResult incrementalAlterConfigs(Map> configs) { + return incrementalAlterConfigs(configs, new AlterConfigsOptions()); + } + + /** + * Incrementally update the configuration for the specified resources. + *

+ * Updates are not transactional so they may succeed for some resources while fail for others. The configs for + * a particular resource are updated atomically. + *
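+ * <p>A minimal usage sketch (illustrative only; {@code admin} is an assumed existing {@code Admin}
+ * instance, and the topic name and config value are example values):
+ * <pre>{@code
+ * ConfigResource resource = new ConfigResource(ConfigResource.Type.TOPIC, "my-topic");
+ * AlterConfigOp op = new AlterConfigOp(new ConfigEntry("cleanup.policy", "compact"), AlterConfigOp.OpType.SET);
+ * Map<ConfigResource, Collection<AlterConfigOp>> updates =
+ *     Collections.singletonMap(resource, Collections.singletonList(op));
+ * admin.incrementalAlterConfigs(updates, new AlterConfigsOptions()).all().get();
+ * }</pre>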

+ * The following exceptions can be anticipated when calling {@code get()} on the futures obtained from + * the returned {@link AlterConfigsResult}: + *

+ * <ul>
+ * <li>{@link org.apache.kafka.common.errors.ClusterAuthorizationException}
+ * if the authenticated user didn't have alter access to the cluster.</li>
+ * <li>{@link org.apache.kafka.common.errors.TopicAuthorizationException}
+ * if the authenticated user didn't have alter access to the Topic.</li>
+ * <li>{@link org.apache.kafka.common.errors.UnknownTopicOrPartitionException}
+ * if the Topic doesn't exist.</li>
+ * <li>{@link org.apache.kafka.common.errors.InvalidRequestException}
+ * if the request details are invalid. e.g., a configuration key was specified more than once for a resource.</li>
+ * </ul>

+ * This operation is supported by brokers with version 2.3.0 or higher. + * + * @param configs The resources with their configs + * @param options The options to use when altering configs + * @return The AlterConfigsResult + */ + AlterConfigsResult incrementalAlterConfigs(Map> configs, AlterConfigsOptions options); + + /** + * Change the log directory for the specified replicas. If the replica does not exist on the broker, the result + * shows REPLICA_NOT_AVAILABLE for the given replica and the replica will be created in the given log directory on the + * broker when it is created later. If the replica already exists on the broker, the replica will be moved to the given + * log directory if it is not already there. For detailed result, inspect the returned {@link AlterReplicaLogDirsResult} instance. + *

+ * This operation is not transactional so it may succeed for some replicas while fail for others. + *

+ * This is a convenience method for {@link #alterReplicaLogDirs(Map, AlterReplicaLogDirsOptions)} with default options. + * See the overload for more details. + *

+ * This operation is supported by brokers with version 1.1.0 or higher. + * + * @param replicaAssignment The replicas with their log directory absolute path + * @return The AlterReplicaLogDirsResult + */ + default AlterReplicaLogDirsResult alterReplicaLogDirs(Map replicaAssignment) { + return alterReplicaLogDirs(replicaAssignment, new AlterReplicaLogDirsOptions()); + } + + /** + * Change the log directory for the specified replicas. If the replica does not exist on the broker, the result + * shows REPLICA_NOT_AVAILABLE for the given replica and the replica will be created in the given log directory on the + * broker when it is created later. If the replica already exists on the broker, the replica will be moved to the given + * log directory if it is not already there. For detailed result, inspect the returned {@link AlterReplicaLogDirsResult} instance. + *

+ * This operation is not transactional so it may succeed for some replicas while fail for others. + *
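+ * <p>A minimal usage sketch (illustrative only; {@code admin} is an assumed existing {@code Admin}
+ * instance, and the topic name, broker id and log directory path are example values):
+ * <pre>{@code
+ * // move the replica of partition 0 on broker 1 to an example log directory
+ * TopicPartitionReplica replica = new TopicPartitionReplica("my-topic", 0, 1);
+ * admin.alterReplicaLogDirs(Collections.singletonMap(replica, "/data/kafka-logs-1"),
+ *     new AlterReplicaLogDirsOptions()).all().get();
+ * }</pre>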

+ * This operation is supported by brokers with version 1.1.0 or higher. + * + * @param replicaAssignment The replicas with their log directory absolute path + * @param options The options to use when changing replica dir + * @return The AlterReplicaLogDirsResult + */ + AlterReplicaLogDirsResult alterReplicaLogDirs(Map replicaAssignment, + AlterReplicaLogDirsOptions options); + + /** + * Query the information of all log directories on the given set of brokers + *

+ * This is a convenience method for {@link #describeLogDirs(Collection, DescribeLogDirsOptions)} with default options. + * See the overload for more details. + *

+ * This operation is supported by brokers with version 1.0.0 or higher. + * + * @param brokers A list of brokers + * @return The DescribeLogDirsResult + */ + default DescribeLogDirsResult describeLogDirs(Collection brokers) { + return describeLogDirs(brokers, new DescribeLogDirsOptions()); + } + + /** + * Query the information of all log directories on the given set of brokers + *

+ * This operation is supported by brokers with version 1.0.0 or higher. + * + * @param brokers A list of brokers + * @param options The options to use when querying log dir info + * @return The DescribeLogDirsResult + */ + DescribeLogDirsResult describeLogDirs(Collection brokers, DescribeLogDirsOptions options); + + /** + * Query the replica log directory information for the specified replicas. + *

+ * This is a convenience method for {@link #describeReplicaLogDirs(Collection, DescribeReplicaLogDirsOptions)} + * with default options. See the overload for more details. + *

+ * This operation is supported by brokers with version 1.0.0 or higher. + * + * @param replicas The replicas to query + * @return The DescribeReplicaLogDirsResult + */ + default DescribeReplicaLogDirsResult describeReplicaLogDirs(Collection replicas) { + return describeReplicaLogDirs(replicas, new DescribeReplicaLogDirsOptions()); + } + + /** + * Query the replica log directory information for the specified replicas. + *

+ * This operation is supported by brokers with version 1.0.0 or higher. + * + * @param replicas The replicas to query + * @param options The options to use when querying replica log dir info + * @return The DescribeReplicaLogDirsResult + */ + DescribeReplicaLogDirsResult describeReplicaLogDirs(Collection replicas, DescribeReplicaLogDirsOptions options); + + /** + * Increase the number of partitions of the topics given as the keys of {@code newPartitions} + * according to the corresponding values. If partitions are increased for a topic that has a key, + * the partition logic or ordering of the messages will be affected. + *

+ * This is a convenience method for {@link #createPartitions(Map, CreatePartitionsOptions)} with default options. + * See the overload for more details. + * + * @param newPartitions The topics which should have new partitions created, and corresponding parameters + * for the created partitions. + * @return The CreatePartitionsResult. + */ + default CreatePartitionsResult createPartitions(Map newPartitions) { + return createPartitions(newPartitions, new CreatePartitionsOptions()); + } + + /** + * Increase the number of partitions of the topics given as the keys of {@code newPartitions} + * according to the corresponding values. If partitions are increased for a topic that has a key, + * the partition logic or ordering of the messages will be affected. + *

+ * This operation is not transactional so it may succeed for some topics while fail for others. + *

+ * It may take several seconds after this method returns + * success for all the brokers to become aware that the partitions have been created. + * During this time, {@link #describeTopics(Collection)} + * may not return information about the new partitions. + *

+ * This operation is supported by brokers with version 1.0.0 or higher. + *
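+ * <p>A minimal usage sketch (illustrative only; {@code admin} is an assumed existing {@code Admin}
+ * instance, and the topic name and partition count are example values):
+ * <pre>{@code
+ * admin.createPartitions(Collections.singletonMap("my-topic", NewPartitions.increaseTo(6)),
+ *     new CreatePartitionsOptions()).all().get();
+ * }</pre>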

+ * The following exceptions can be anticipated when calling {@code get()} on the futures obtained from the + * {@link CreatePartitionsResult#values() values()} method of the returned {@link CreatePartitionsResult} + *

+ * <ul>
+ * <li>{@link org.apache.kafka.common.errors.AuthorizationException}
+ * if the authenticated user is not authorized to alter the topic</li>
+ * <li>{@link org.apache.kafka.common.errors.TimeoutException}
+ * if the request was not completed within the given {@link CreatePartitionsOptions#timeoutMs()}.</li>
+ * <li>{@link org.apache.kafka.common.errors.ReassignmentInProgressException}
+ * if a partition reassignment is currently in progress</li>
+ * <li>{@link org.apache.kafka.common.errors.BrokerNotAvailableException}
+ * if the requested {@link NewPartitions#assignments()} contain a broker that is currently unavailable.</li>
+ * <li>{@link org.apache.kafka.common.errors.InvalidReplicationFactorException}
+ * if no {@link NewPartitions#assignments()} are given and it is impossible for the broker to assign
+ * replicas with the topic's replication factor.</li>
+ * <li>Subclasses of {@link org.apache.kafka.common.KafkaException}
+ * if the request is invalid in some way.</li>
+ * </ul>
+ * + * @param newPartitions The topics which should have new partitions created, and corresponding parameters + * for the created partitions. + * @param options The options to use when creating the new partitions. + * @return The CreatePartitionsResult. + */ + CreatePartitionsResult createPartitions(Map newPartitions, + CreatePartitionsOptions options); + + /** + * Delete records whose offset is smaller than the given offset of the corresponding partition. + *

+ * This is a convenience method for {@link #deleteRecords(Map, DeleteRecordsOptions)} with default options. + * See the overload for more details. + *

+ * This operation is supported by brokers with version 0.11.0.0 or higher. + * + * @param recordsToDelete The topic partitions and related offsets from which records deletion starts. + * @return The DeleteRecordsResult. + */ + default DeleteRecordsResult deleteRecords(Map recordsToDelete) { + return deleteRecords(recordsToDelete, new DeleteRecordsOptions()); + } + + /** + * Delete records whose offset is smaller than the given offset of the corresponding partition. + *
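+ * <p>A minimal usage sketch (illustrative only; {@code admin} is an assumed existing {@code Admin}
+ * instance, and the topic name and offset are example values):
+ * <pre>{@code
+ * TopicPartition partition = new TopicPartition("my-topic", 0);
+ * admin.deleteRecords(Collections.singletonMap(partition, RecordsToDelete.beforeOffset(100L)),
+ *     new DeleteRecordsOptions()).all().get();
+ * }</pre>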

+ * This operation is supported by brokers with version 0.11.0.0 or higher. + * + * @param recordsToDelete The topic partitions and related offsets from which records deletion starts. + * @param options The options to use when deleting records. + * @return The DeleteRecordsResult. + */ + DeleteRecordsResult deleteRecords(Map recordsToDelete, + DeleteRecordsOptions options); + + /** + * Create a Delegation Token. + *

+ * This is a convenience method for {@link #createDelegationToken(CreateDelegationTokenOptions)} with default options. + * See the overload for more details. + * + * @return The CreateDelegationTokenResult. + */ + default CreateDelegationTokenResult createDelegationToken() { + return createDelegationToken(new CreateDelegationTokenOptions()); + } + + + /** + * Create a Delegation Token. + *

+ * This operation is supported by brokers with version 1.1.0 or higher. + *
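+ * <p>A minimal usage sketch (illustrative only; {@code admin} is an assumed existing {@code Admin}
+ * instance and the renewer principal is an example value):
+ * <pre>{@code
+ * CreateDelegationTokenOptions options = new CreateDelegationTokenOptions()
+ *     .renewers(Collections.singletonList(new KafkaPrincipal(KafkaPrincipal.USER_TYPE, "alice")));
+ * DelegationToken token = admin.createDelegationToken(options).delegationToken().get();
+ * }</pre>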

+ * The following exceptions can be anticipated when calling {@code get()} on the futures obtained from the + * {@link CreateDelegationTokenResult#delegationToken() delegationToken()} method of the returned {@link CreateDelegationTokenResult} + *

+ * <ul>
+ * <li>{@link org.apache.kafka.common.errors.UnsupportedByAuthenticationException}
+ * If the request is sent on PLAINTEXT/1-way SSL channels or delegation token authenticated channels.</li>
+ * <li>{@link org.apache.kafka.common.errors.InvalidPrincipalTypeException}
+ * if the renewer's principal type is not supported.</li>
+ * <li>{@link org.apache.kafka.common.errors.DelegationTokenDisabledException}
+ * if the delegation token feature is disabled.</li>
+ * <li>{@link org.apache.kafka.common.errors.TimeoutException}
+ * if the request was not completed within the given {@link CreateDelegationTokenOptions#timeoutMs()}.</li>
+ * </ul>
+ *
+ * @param options The options to use when creating delegation token.
+ * @return The CreateDelegationTokenResult.
+ */
+ CreateDelegationTokenResult createDelegationToken(CreateDelegationTokenOptions options);
+
+
+ /**
+ * Renew a Delegation Token.
+ *

+ * This is a convenience method for {@link #renewDelegationToken(byte[], RenewDelegationTokenOptions)} with default options. + * See the overload for more details. + * + * @param hmac HMAC of the Delegation token + * @return The RenewDelegationTokenResult. + */ + default RenewDelegationTokenResult renewDelegationToken(byte[] hmac) { + return renewDelegationToken(hmac, new RenewDelegationTokenOptions()); + } + + /** + * Renew a Delegation Token. + *

+ * This operation is supported by brokers with version 1.1.0 or higher. + *

+ * The following exceptions can be anticipated when calling {@code get()} on the futures obtained from the + * {@link RenewDelegationTokenResult#expiryTimestamp() expiryTimestamp()} method of the returned {@link RenewDelegationTokenResult} + *

+ * <ul>
+ * <li>{@link org.apache.kafka.common.errors.UnsupportedByAuthenticationException}
+ * If the request is sent on PLAINTEXT/1-way SSL channels or delegation token authenticated channels.</li>
+ * <li>{@link org.apache.kafka.common.errors.DelegationTokenDisabledException}
+ * if the delegation token feature is disabled.</li>
+ * <li>{@link org.apache.kafka.common.errors.DelegationTokenNotFoundException}
+ * if the delegation token is not found on server.</li>
+ * <li>{@link org.apache.kafka.common.errors.DelegationTokenOwnerMismatchException}
+ * if the authenticated user is not owner/renewer of the token.</li>
+ * <li>{@link org.apache.kafka.common.errors.DelegationTokenExpiredException}
+ * if the delegation token is expired.</li>
+ * <li>{@link org.apache.kafka.common.errors.TimeoutException}
+ * if the request was not completed within the given {@link RenewDelegationTokenOptions#timeoutMs()}.</li>
+ * </ul>
+ * + * @param hmac HMAC of the Delegation token + * @param options The options to use when renewing delegation token. + * @return The RenewDelegationTokenResult. + */ + RenewDelegationTokenResult renewDelegationToken(byte[] hmac, RenewDelegationTokenOptions options); + + /** + * Expire a Delegation Token. + *

+ * This is a convenience method for {@link #expireDelegationToken(byte[], ExpireDelegationTokenOptions)} with default options. + * This will expire the token immediately. See the overload for more details. + * + * @param hmac HMAC of the Delegation token + * @return The ExpireDelegationTokenResult. + */ + default ExpireDelegationTokenResult expireDelegationToken(byte[] hmac) { + return expireDelegationToken(hmac, new ExpireDelegationTokenOptions()); + } + + /** + * Expire a Delegation Token. + *

+ * This operation is supported by brokers with version 1.1.0 or higher. + *

+ * The following exceptions can be anticipated when calling {@code get()} on the futures obtained from the + * {@link ExpireDelegationTokenResult#expiryTimestamp() expiryTimestamp()} method of the returned {@link ExpireDelegationTokenResult} + *

+ * <ul>
+ * <li>{@link org.apache.kafka.common.errors.UnsupportedByAuthenticationException}
+ * If the request is sent on PLAINTEXT/1-way SSL channels or delegation token authenticated channels.</li>
+ * <li>{@link org.apache.kafka.common.errors.DelegationTokenDisabledException}
+ * if the delegation token feature is disabled.</li>
+ * <li>{@link org.apache.kafka.common.errors.DelegationTokenNotFoundException}
+ * if the delegation token is not found on server.</li>
+ * <li>{@link org.apache.kafka.common.errors.DelegationTokenOwnerMismatchException}
+ * if the authenticated user is not owner/renewer of the requested token.</li>
+ * <li>{@link org.apache.kafka.common.errors.DelegationTokenExpiredException}
+ * if the delegation token is expired.</li>
+ * <li>{@link org.apache.kafka.common.errors.TimeoutException}
+ * if the request was not completed within the given {@link ExpireDelegationTokenOptions#timeoutMs()}.</li>
+ * </ul>
+ * + * @param hmac HMAC of the Delegation token + * @param options The options to use when expiring delegation token. + * @return The ExpireDelegationTokenResult. + */ + ExpireDelegationTokenResult expireDelegationToken(byte[] hmac, ExpireDelegationTokenOptions options); + + /** + * Describe the Delegation Tokens. + *

+ * This is a convenience method for {@link #describeDelegationToken(DescribeDelegationTokenOptions)} with default options. + * This will return all the user owned tokens and tokens where user have Describe permission. See the overload for more details. + * + * @return The DescribeDelegationTokenResult. + */ + default DescribeDelegationTokenResult describeDelegationToken() { + return describeDelegationToken(new DescribeDelegationTokenOptions()); + } + + /** + * Describe the Delegation Tokens. + *

+ * This operation is supported by brokers with version 1.1.0 or higher. + *

+ * The following exceptions can be anticipated when calling {@code get()} on the futures obtained from the + * {@link DescribeDelegationTokenResult#delegationTokens() delegationTokens()} method of the returned {@link DescribeDelegationTokenResult} + *

+ * <ul>
+ * <li>{@link org.apache.kafka.common.errors.UnsupportedByAuthenticationException}
+ * If the request is sent on PLAINTEXT/1-way SSL channels or delegation token authenticated channels.</li>
+ * <li>{@link org.apache.kafka.common.errors.DelegationTokenDisabledException}
+ * if the delegation token feature is disabled.</li>
+ * <li>{@link org.apache.kafka.common.errors.TimeoutException}
+ * if the request was not completed within the given {@link DescribeDelegationTokenOptions#timeoutMs()}.</li>
+ * </ul>
+ * + * @param options The options to use when describing delegation tokens. + * @return The DescribeDelegationTokenResult. + */ + DescribeDelegationTokenResult describeDelegationToken(DescribeDelegationTokenOptions options); + + /** + * Describe some group IDs in the cluster. + * + * @param groupIds The IDs of the groups to describe. + * @param options The options to use when describing the groups. + * @return The DescribeConsumerGroupResult. + */ + DescribeConsumerGroupsResult describeConsumerGroups(Collection groupIds, + DescribeConsumerGroupsOptions options); + + /** + * Describe some group IDs in the cluster, with the default options. + *
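+ * <p>A minimal usage sketch (illustrative only; {@code admin} is an assumed existing {@code Admin}
+ * instance and the group id is an example value):
+ * <pre>{@code
+ * ConsumerGroupDescription description =
+ *     admin.describeConsumerGroups(Collections.singleton("my-group")).all().get().get("my-group");
+ * }</pre>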

+ * This is a convenience method for {@link #describeConsumerGroups(Collection, DescribeConsumerGroupsOptions)} + * with default options. See the overload for more details. + * + * @param groupIds The IDs of the groups to describe. + * @return The DescribeConsumerGroupResult. + */ + default DescribeConsumerGroupsResult describeConsumerGroups(Collection groupIds) { + return describeConsumerGroups(groupIds, new DescribeConsumerGroupsOptions()); + } + + /** + * List the consumer groups available in the cluster. + * + * @param options The options to use when listing the consumer groups. + * @return The ListGroupsResult. + */ + ListConsumerGroupsResult listConsumerGroups(ListConsumerGroupsOptions options); + + /** + * List the consumer groups available in the cluster with the default options. + *

+ * This is a convenience method for {@link #listConsumerGroups(ListConsumerGroupsOptions)} with default options. + * See the overload for more details. + * + * @return The ListGroupsResult. + */ + default ListConsumerGroupsResult listConsumerGroups() { + return listConsumerGroups(new ListConsumerGroupsOptions()); + } + + /** + * List the consumer group offsets available in the cluster. + * + * @param options The options to use when listing the consumer group offsets. + * @return The ListGroupOffsetsResult + */ + ListConsumerGroupOffsetsResult listConsumerGroupOffsets(String groupId, ListConsumerGroupOffsetsOptions options); + + /** + * List the consumer group offsets available in the cluster with the default options. + *
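+ * <p>A minimal usage sketch (illustrative only; {@code admin} is an assumed existing {@code Admin}
+ * instance and the group id is an example value):
+ * <pre>{@code
+ * Map<TopicPartition, OffsetAndMetadata> committed =
+ *     admin.listConsumerGroupOffsets("my-group").partitionsToOffsetAndMetadata().get();
+ * }</pre>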

+ * This is a convenience method for {@link #listConsumerGroupOffsets(String, ListConsumerGroupOffsetsOptions)} with default options. + * + * @return The ListGroupOffsetsResult. + */ + default ListConsumerGroupOffsetsResult listConsumerGroupOffsets(String groupId) { + return listConsumerGroupOffsets(groupId, new ListConsumerGroupOffsetsOptions()); + } + + /** + * Delete consumer groups from the cluster. + * + * @param options The options to use when deleting a consumer group. + * @return The DeletConsumerGroupResult. + */ + DeleteConsumerGroupsResult deleteConsumerGroups(Collection groupIds, DeleteConsumerGroupsOptions options); + + /** + * Delete consumer groups from the cluster with the default options. + * + * @return The DeleteConsumerGroupResult. + */ + default DeleteConsumerGroupsResult deleteConsumerGroups(Collection groupIds) { + return deleteConsumerGroups(groupIds, new DeleteConsumerGroupsOptions()); + } + + /** + * Delete committed offsets for a set of partitions in a consumer group. This will + * succeed at the partition level only if the group is not actively subscribed + * to the corresponding topic. + * + * @param options The options to use when deleting offsets in a consumer group. + * @return The DeleteConsumerGroupOffsetsResult. + */ + DeleteConsumerGroupOffsetsResult deleteConsumerGroupOffsets(String groupId, + Set partitions, + DeleteConsumerGroupOffsetsOptions options); + + /** + * Delete committed offsets for a set of partitions in a consumer group with the default + * options. This will succeed at the partition level only if the group is not actively + * subscribed to the corresponding topic. + * + * @return The DeleteConsumerGroupOffsetsResult. + */ + default DeleteConsumerGroupOffsetsResult deleteConsumerGroupOffsets(String groupId, Set partitions) { + return deleteConsumerGroupOffsets(groupId, partitions, new DeleteConsumerGroupOffsetsOptions()); + } + + /** + * Elect a replica as leader for topic partitions. + *

+ * This is a convenience method for {@link #electLeaders(ElectionType, Set, ElectLeadersOptions)} + * with default options. + * + * @param electionType The type of election to conduct. + * @param partitions The topics and partitions for which to conduct elections. + * @return The ElectLeadersResult. + */ + default ElectLeadersResult electLeaders(ElectionType electionType, Set partitions) { + return electLeaders(electionType, partitions, new ElectLeadersOptions()); + } + + /** + * Elect a replica as leader for the given {@code partitions}, or for all partitions if the argument + * to {@code partitions} is null. + *

+ * This operation is not transactional so it may succeed for some partitions while fail for others. + *

+ * It may take several seconds after this method returns success for all the brokers in the cluster + * to become aware that the partitions have new leaders. During this time, + * {@link #describeTopics(Collection)} may not return information about the partitions' + * new leaders. + *
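+ * <p>A minimal usage sketch (illustrative only; {@code admin} is an assumed existing {@code Admin}
+ * instance and the topic partition is an example value):
+ * <pre>{@code
+ * admin.electLeaders(ElectionType.PREFERRED,
+ *     Collections.singleton(new TopicPartition("my-topic", 0)),
+ *     new ElectLeadersOptions()).partitions().get();
+ * }</pre>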

+ * This operation is supported by brokers with version 2.2.0 or later if preferred election is used;
+ * otherwise the brokers must be 2.4.0 or higher.
+ *

+ * The following exceptions can be anticipated when calling {@code get()} on the future obtained + * from the returned {@link ElectLeadersResult}: + *

+ * <ul>
+ * <li>{@link org.apache.kafka.common.errors.ClusterAuthorizationException}
+ * if the authenticated user didn't have alter access to the cluster.</li>
+ * <li>{@link org.apache.kafka.common.errors.UnknownTopicOrPartitionException}
+ * if the topic or partition did not exist within the cluster.</li>
+ * <li>{@link org.apache.kafka.common.errors.InvalidTopicException}
+ * if the topic was already queued for deletion.</li>
+ * <li>{@link org.apache.kafka.common.errors.NotControllerException}
+ * if the request was sent to a broker that was not the controller for the cluster.</li>
+ * <li>{@link org.apache.kafka.common.errors.TimeoutException}
+ * if the request timed out before the election was complete.</li>
+ * <li>{@link org.apache.kafka.common.errors.LeaderNotAvailableException}
+ * if the preferred leader was not alive or not in the ISR.</li>
+ * </ul>
+ * + * @param electionType The type of election to conduct. + * @param partitions The topics and partitions for which to conduct elections. + * @param options The options to use when electing the leaders. + * @return The ElectLeadersResult. + */ + ElectLeadersResult electLeaders( + ElectionType electionType, + Set partitions, + ElectLeadersOptions options); + + + /** + * Change the reassignments for one or more partitions. + * Providing an empty Optional (e.g via {@link Optional#empty()}) will revert the reassignment for the associated partition. + * + * This is a convenience method for {@link #alterPartitionReassignments(Map, AlterPartitionReassignmentsOptions)} + * with default options. See the overload for more details. + */ + default AlterPartitionReassignmentsResult alterPartitionReassignments( + Map> reassignments) { + return alterPartitionReassignments(reassignments, new AlterPartitionReassignmentsOptions()); + } + + /** + * Change the reassignments for one or more partitions. + * Providing an empty Optional (e.g via {@link Optional#empty()}) will revert the reassignment for the associated partition. + * + *
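+ * <p>A minimal usage sketch (illustrative only; {@code admin} is an assumed existing {@code Admin}
+ * instance, and the topic partition and target replica list are example values):
+ * <pre>{@code
+ * TopicPartition partition = new TopicPartition("my-topic", 0);
+ * NewPartitionReassignment reassignment = new NewPartitionReassignment(Arrays.asList(1, 2, 3));
+ * admin.alterPartitionReassignments(Collections.singletonMap(partition, Optional.of(reassignment)),
+ *     new AlterPartitionReassignmentsOptions()).all().get();
+ * }</pre>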

+ * <p>The following exceptions can be anticipated when calling {@code get()} on the futures obtained from
+ * the returned {@code AlterPartitionReassignmentsResult}:
+ * <ul>
+ * <li>{@link org.apache.kafka.common.errors.ClusterAuthorizationException}
+ * If the authenticated user didn't have alter access to the cluster.</li>
+ * <li>{@link org.apache.kafka.common.errors.UnknownTopicOrPartitionException}
+ * If the topic or partition does not exist within the cluster.</li>
+ * <li>{@link org.apache.kafka.common.errors.TimeoutException}
+ * if the request timed out before the controller could record the new assignments.</li>
+ * <li>{@link org.apache.kafka.common.errors.InvalidReplicaAssignmentException}
+ * If the specified assignment was not valid.</li>
+ * <li>{@link org.apache.kafka.common.errors.NoReassignmentInProgressException}
+ * If there was an attempt to cancel a reassignment for a partition which was not being reassigned.</li>
+ * </ul>
+ * + * @param reassignments The reassignments to add, modify, or remove. See {@link NewPartitionReassignment}. + * @param options The options to use. + * @return The result. + */ + AlterPartitionReassignmentsResult alterPartitionReassignments( + Map> reassignments, + AlterPartitionReassignmentsOptions options); + + + /** + * List all of the current partition reassignments + * + * This is a convenience method for {@link #listPartitionReassignments(ListPartitionReassignmentsOptions)} + * with default options. See the overload for more details. + */ + default ListPartitionReassignmentsResult listPartitionReassignments() { + return listPartitionReassignments(new ListPartitionReassignmentsOptions()); + } + + /** + * List the current reassignments for the given partitions + * + * This is a convenience method for {@link #listPartitionReassignments(Set, ListPartitionReassignmentsOptions)} + * with default options. See the overload for more details. + */ + default ListPartitionReassignmentsResult listPartitionReassignments(Set partitions) { + return listPartitionReassignments(partitions, new ListPartitionReassignmentsOptions()); + } + + /** + * List the current reassignments for the given partitions + * + *

+ * <p>The following exceptions can be anticipated when calling {@code get()} on the futures obtained from
+ * the returned {@code ListPartitionReassignmentsResult}:
+ * <ul>
+ * <li>{@link org.apache.kafka.common.errors.ClusterAuthorizationException}
+ * If the authenticated user doesn't have alter access to the cluster.</li>
+ * <li>{@link org.apache.kafka.common.errors.UnknownTopicOrPartitionException}
+ * If a given topic or partition does not exist.</li>
+ * <li>{@link org.apache.kafka.common.errors.TimeoutException}
+ * If the request timed out before the controller could list the current reassignments.</li>
+ * </ul>
+ * + * @param partitions The topic partitions to list reassignments for. + * @param options The options to use. + * @return The result. + */ + default ListPartitionReassignmentsResult listPartitionReassignments( + Set partitions, + ListPartitionReassignmentsOptions options) { + return listPartitionReassignments(Optional.of(partitions), options); + } + + /** + * List all of the current partition reassignments + * + *

+ * <p>The following exceptions can be anticipated when calling {@code get()} on the futures obtained from
+ * the returned {@code ListPartitionReassignmentsResult}:
+ * <ul>
+ * <li>{@link org.apache.kafka.common.errors.ClusterAuthorizationException}
+ * If the authenticated user doesn't have alter access to the cluster.</li>
+ * <li>{@link org.apache.kafka.common.errors.UnknownTopicOrPartitionException}
+ * If a given topic or partition does not exist.</li>
+ * <li>{@link org.apache.kafka.common.errors.TimeoutException}
+ * If the request timed out before the controller could list the current reassignments.</li>
+ * </ul>
+ * + * @param options The options to use. + * @return The result. + */ + default ListPartitionReassignmentsResult listPartitionReassignments(ListPartitionReassignmentsOptions options) { + return listPartitionReassignments(Optional.empty(), options); + } + + /** + * @param partitions the partitions we want to get reassignment for, or an empty optional if we want to get the reassignments for all partitions in the cluster + * @param options The options to use. + * @return The result. + */ + ListPartitionReassignmentsResult listPartitionReassignments(Optional> partitions, + ListPartitionReassignmentsOptions options); + + /** + * Remove members from the consumer group by given member identities. + *

+ * For possible error codes, refer to {@link LeaveGroupResponse}. + * + * @param groupId The ID of the group to remove member from. + * @param options The options to carry removing members' information. + * @return The MembershipChangeResult. + */ + RemoveMembersFromConsumerGroupResult removeMembersFromConsumerGroup(String groupId, RemoveMembersFromConsumerGroupOptions options); + + /** + *

Alters offsets for the specified group. In order to succeed, the group must be empty. + * + *

This is a convenience method for {@link #alterConsumerGroupOffsets(String, Map, AlterConsumerGroupOffsetsOptions)} with default options. + * See the overload for more details. + * + * @param groupId The group for which to alter offsets. + * @param offsets A map of offsets by partition with associated metadata. + * @return The AlterOffsetsResult. + */ + default AlterConsumerGroupOffsetsResult alterConsumerGroupOffsets(String groupId, Map offsets) { + return alterConsumerGroupOffsets(groupId, offsets, new AlterConsumerGroupOffsetsOptions()); + } + + /** + *

Alters offsets for the specified group. In order to succeed, the group must be empty. + * + *
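+ * <p>A minimal usage sketch (illustrative only; {@code admin} is an assumed existing {@code Admin}
+ * instance, and the group id, topic partition and offset are example values):
+ * <pre>{@code
+ * TopicPartition partition = new TopicPartition("my-topic", 0);
+ * admin.alterConsumerGroupOffsets("my-group",
+ *     Collections.singletonMap(partition, new OffsetAndMetadata(0L)),
+ *     new AlterConsumerGroupOffsetsOptions()).all().get();
+ * }</pre>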

This operation is not transactional so it may succeed for some partitions while fail for others. + * + * @param groupId The group for which to alter offsets. + * @param offsets A map of offsets by partition with associated metadata. Partitions not specified in the map are ignored. + * @param options The options to use when altering the offsets. + * @return The AlterOffsetsResult. + */ + AlterConsumerGroupOffsetsResult alterConsumerGroupOffsets(String groupId, Map offsets, AlterConsumerGroupOffsetsOptions options); + + /** + *

+ * <p>List offsets for the specified partitions and OffsetSpec. This operation enables finding
+ * the beginning offset, end offset, and the offset matching a timestamp in partitions.
+ *
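+ * <p>A minimal usage sketch (illustrative only; {@code admin} is an assumed existing {@code Admin}
+ * instance and the topic partition is an example value):
+ * <pre>{@code
+ * TopicPartition partition = new TopicPartition("my-topic", 0);
+ * long endOffset = admin.listOffsets(Collections.singletonMap(partition, OffsetSpec.latest()))
+ *     .partitionResult(partition).get().offset();
+ * }</pre>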

This is a convenience method for {@link #listOffsets(Map, ListOffsetsOptions)} + * + * @param topicPartitionOffsets The mapping from partition to the OffsetSpec to look up. + * @return The ListOffsetsResult. + */ + default ListOffsetsResult listOffsets(Map topicPartitionOffsets) { + return listOffsets(topicPartitionOffsets, new ListOffsetsOptions()); + } + + /** + *

+ * <p>List offsets for the specified partitions. This operation enables finding
+ * the beginning offset, end offset, and the offset matching a timestamp in partitions.
+ *
+ * @param topicPartitionOffsets The mapping from partition to the OffsetSpec to look up.
+ * @param options The options to use when retrieving the offsets
+ * @return The ListOffsetsResult.
+ */
+ ListOffsetsResult listOffsets(Map<TopicPartition, OffsetSpec> topicPartitionOffsets, ListOffsetsOptions options);
+
+ /**
+ * Describes all entities matching the provided filter that have at least one client quota configuration
+ * value defined.
+ *

+ * This is a convenience method for {@link #describeClientQuotas(ClientQuotaFilter, DescribeClientQuotasOptions)} + * with default options. See the overload for more details. + *

+ * This operation is supported by brokers with version 2.6.0 or higher. + * + * @param filter the filter to apply to match entities + * @return the DescribeClientQuotasResult containing the result + */ + default DescribeClientQuotasResult describeClientQuotas(ClientQuotaFilter filter) { + return describeClientQuotas(filter, new DescribeClientQuotasOptions()); + } + + /** + * Describes all entities matching the provided filter that have at least one client quota configuration + * value defined. + *

+ * The following exceptions can be anticipated when calling {@code get()} on the future from the + * returned {@link DescribeClientQuotasResult}: + *

+ * <ul>
+ * <li>{@link org.apache.kafka.common.errors.ClusterAuthorizationException}
+ * If the authenticated user didn't have describe access to the cluster.</li>
+ * <li>{@link org.apache.kafka.common.errors.InvalidRequestException}
+ * If the request details are invalid. e.g., an invalid entity type was specified.</li>
+ * <li>{@link org.apache.kafka.common.errors.TimeoutException}
+ * If the request timed out before the describe could finish.</li>
+ * </ul>

+ * This operation is supported by brokers with version 2.6.0 or higher. + * + * @param filter the filter to apply to match entities + * @param options the options to use + * @return the DescribeClientQuotasResult containing the result + */ + DescribeClientQuotasResult describeClientQuotas(ClientQuotaFilter filter, DescribeClientQuotasOptions options); + + /** + * Alters client quota configurations with the specified alterations. + *

+ * This is a convenience method for {@link #alterClientQuotas(Collection, AlterClientQuotasOptions)} + * with default options. See the overload for more details. + *

+ * This operation is supported by brokers with version 2.6.0 or higher. + * + * @param entries the alterations to perform + * @return the AlterClientQuotasResult containing the result + */ + default AlterClientQuotasResult alterClientQuotas(Collection entries) { + return alterClientQuotas(entries, new AlterClientQuotasOptions()); + } + + /** + * Alters client quota configurations with the specified alterations. + *

+ * Alterations for a single entity are atomic, but atomicity across entities is not guaranteed. The resulting
+ * per-entity error code should be evaluated to resolve the success or failure of all updates.
+ *
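+ * <p>A minimal usage sketch (illustrative only; {@code admin} is an assumed existing {@code Admin}
+ * instance, and the user name, quota key and quota value are example values):
+ * <pre>{@code
+ * ClientQuotaEntity entity = new ClientQuotaEntity(Collections.singletonMap(ClientQuotaEntity.USER, "alice"));
+ * ClientQuotaAlteration.Op op = new ClientQuotaAlteration.Op("consumer_byte_rate", 1048576.0);
+ * admin.alterClientQuotas(Collections.singleton(new ClientQuotaAlteration(entity, Collections.singleton(op))),
+ *     new AlterClientQuotasOptions()).all().get();
+ * }</pre>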

+ * The following exceptions can be anticipated when calling {@code get()} on the futures obtained from + * the returned {@link AlterClientQuotasResult}: + *

+ * <ul>
+ * <li>{@link org.apache.kafka.common.errors.ClusterAuthorizationException}
+ * If the authenticated user didn't have alter access to the cluster.</li>
+ * <li>{@link org.apache.kafka.common.errors.InvalidRequestException}
+ * If the request details are invalid. e.g., a configuration key was specified more than once for an entity.</li>
+ * <li>{@link org.apache.kafka.common.errors.TimeoutException}
+ * If the request timed out before the alterations could finish. It cannot be guaranteed whether the update
+ * succeeded or not.</li>
+ * </ul>

+ * This operation is supported by brokers with version 2.6.0 or higher. + * + * @param entries the alterations to perform + * @return the AlterClientQuotasResult containing the result + */ + AlterClientQuotasResult alterClientQuotas(Collection entries, AlterClientQuotasOptions options); + + /** + * Describe all SASL/SCRAM credentials. + * + *

This is a convenience method for {@link #describeUserScramCredentials(List, DescribeUserScramCredentialsOptions)} + * + * @return The DescribeUserScramCredentialsResult. + */ + default DescribeUserScramCredentialsResult describeUserScramCredentials() { + return describeUserScramCredentials(null, new DescribeUserScramCredentialsOptions()); + } + + /** + * Describe SASL/SCRAM credentials for the given users. + * + *

This is a convenience method for {@link #describeUserScramCredentials(List, DescribeUserScramCredentialsOptions)} + * + * @param users the users for which credentials are to be described; all users' credentials are described if null + * or empty. + * @return The DescribeUserScramCredentialsResult. + */ + default DescribeUserScramCredentialsResult describeUserScramCredentials(List users) { + return describeUserScramCredentials(users, new DescribeUserScramCredentialsOptions()); + } + + /** + * Describe SASL/SCRAM credentials. + *

+ * The following exceptions can be anticipated when calling {@code get()} on the futures from the + * returned {@link DescribeUserScramCredentialsResult}: + *

+ * <ul>
+ * <li>{@link org.apache.kafka.common.errors.ClusterAuthorizationException}
+ * If the authenticated user didn't have describe access to the cluster.</li>
+ * <li>{@link org.apache.kafka.common.errors.ResourceNotFoundException}
+ * If the user did not exist/had no SCRAM credentials.</li>
+ * <li>{@link org.apache.kafka.common.errors.DuplicateResourceException}
+ * If the user was requested to be described more than once in the original request.</li>
+ * <li>{@link org.apache.kafka.common.errors.TimeoutException}
+ * If the request timed out before the describe operation could finish.</li>
+ * </ul>

+ * This operation is supported by brokers with version 2.7.0 or higher. + * + * @param users the users for which credentials are to be described; all users' credentials are described if null + * or empty. + * @param options The options to use when describing the credentials + * @return The DescribeUserScramCredentialsResult. + */ + DescribeUserScramCredentialsResult describeUserScramCredentials(List users, DescribeUserScramCredentialsOptions options); + + /** + * Alter SASL/SCRAM credentials for the given users. + * + *

This is a convenience method for {@link #alterUserScramCredentials(List, AlterUserScramCredentialsOptions)} + * + * @param alterations the alterations to be applied + * @return The AlterUserScramCredentialsResult. + */ + default AlterUserScramCredentialsResult alterUserScramCredentials(List alterations) { + return alterUserScramCredentials(alterations, new AlterUserScramCredentialsOptions()); + } + + /** + * Alter SASL/SCRAM credentials. + * + *
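+ * <p>A minimal usage sketch (illustrative only; {@code admin} is an assumed existing {@code Admin}
+ * instance, and the user name, iteration count and password are example values):
+ * <pre>{@code
+ * UserScramCredentialAlteration upsertion = new UserScramCredentialUpsertion(
+ *     "alice", new ScramCredentialInfo(ScramMechanism.SCRAM_SHA_256, 8192), "alice-secret");
+ * admin.alterUserScramCredentials(Collections.singletonList(upsertion),
+ *     new AlterUserScramCredentialsOptions()).all().get();
+ * }</pre>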

+ * The following exceptions can be anticipated when calling {@code get()} any of the futures from the + * returned {@link AlterUserScramCredentialsResult}: + *

+ * <ul>
+ * <li>{@link org.apache.kafka.common.errors.NotControllerException}
+ * If the request is not sent to the Controller broker.</li>
+ * <li>{@link org.apache.kafka.common.errors.ClusterAuthorizationException}
+ * If the authenticated user didn't have alter access to the cluster.</li>
+ * <li>{@link org.apache.kafka.common.errors.UnsupportedByAuthenticationException}
+ * If the user authenticated with a delegation token.</li>
+ * <li>{@link org.apache.kafka.common.errors.UnsupportedSaslMechanismException}
+ * If the requested SCRAM mechanism is unrecognized or otherwise unsupported.</li>
+ * <li>{@link org.apache.kafka.common.errors.UnacceptableCredentialException}
+ * If the username is empty or the requested number of iterations is too small or too large.</li>
+ * <li>{@link org.apache.kafka.common.errors.TimeoutException}
+ * If the request timed out before the describe could finish.</li>
+ * </ul>

+ * This operation is supported by brokers with version 2.7.0 or higher. + * + * @param alterations the alterations to be applied + * @param options The options to use when altering the credentials + * @return The AlterUserScramCredentialsResult. + */ + AlterUserScramCredentialsResult alterUserScramCredentials(List alterations, + AlterUserScramCredentialsOptions options); + /** + * Describes finalized as well as supported features. + *

+ * This is a convenience method for {@link #describeFeatures(DescribeFeaturesOptions)} with default options. + * See the overload for more details. + * + * @return the {@link DescribeFeaturesResult} containing the result + */ + default DescribeFeaturesResult describeFeatures() { + return describeFeatures(new DescribeFeaturesOptions()); + } + + /** + * Describes finalized as well as supported features. The request is issued to any random + * broker. + *

+ * The following exceptions can be anticipated when calling {@code get()} on the future from the + * returned {@link DescribeFeaturesResult}: + *

+ * <ul>
+ * <li>{@link org.apache.kafka.common.errors.TimeoutException}
+ * If the request timed out before the describe operation could finish.</li>
+ * </ul>

+ * + * @param options the options to use + * @return the {@link DescribeFeaturesResult} containing the result + */ + DescribeFeaturesResult describeFeatures(DescribeFeaturesOptions options); + + /** + * Applies specified updates to finalized features. This operation is not transactional so some + * updates may succeed while the rest may fail. + *

+ * The API takes in a map of finalized feature names to {@link FeatureUpdate} that needs to be + * applied. Each entry in the map specifies the finalized feature to be added or updated or + * deleted, along with the new max feature version level value. This request is issued only to + * the controller since the API is only served by the controller. The return value contains an + * error code for each supplied {@link FeatureUpdate}, and the code indicates if the update + * succeeded or failed in the controller. + *

+ * <ul>
+ * <li>Downgrade of feature version level is not a regular operation/intent. It is only allowed
+ * in the controller if the {@link FeatureUpdate} has the allowDowngrade flag set. Setting this
+ * flag conveys user intent to attempt downgrade of a feature max version level. Note that
+ * despite the allowDowngrade flag being set, certain downgrades may be rejected by the
+ * controller if it is deemed impossible.</li>
+ * <li>Deletion of a finalized feature version is not a regular operation/intent. It could be
+ * done by setting the allowDowngrade flag to true in the {@link FeatureUpdate} and setting
+ * the max version level to a value less than 1.</li>
+ * </ul>
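+ * <p>A minimal usage sketch (illustrative only; {@code admin} is an assumed existing {@code Admin}
+ * instance, the feature name is an example value, and the two-argument
+ * {@code FeatureUpdate(short, boolean)} constructor form is assumed here):
+ * <pre>{@code
+ * FeatureUpdate update = new FeatureUpdate((short) 2, false); // false: do not allow downgrade
+ * UpdateFeaturesResult result =
+ *     admin.updateFeatures(Collections.singletonMap("example.feature", update), new UpdateFeaturesOptions());
+ * result.all().get();
+ * }</pre>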

+ * The following exceptions can be anticipated when calling {@code get()} on the futures + * obtained from the returned {@link UpdateFeaturesResult}: + *

+ * <ul>
+ * <li>{@link org.apache.kafka.common.errors.ClusterAuthorizationException}
+ * If the authenticated user didn't have alter access to the cluster.</li>
+ * <li>{@link org.apache.kafka.common.errors.InvalidRequestException}
+ * If the request details are invalid. e.g., a non-existing finalized feature is attempted
+ * to be deleted or downgraded.</li>
+ * <li>{@link org.apache.kafka.common.errors.TimeoutException}
+ * If the request timed out before the updates could finish. It cannot be guaranteed whether
+ * the updates succeeded or not.</li>
+ * <li>{@link FeatureUpdateFailedException}
+ * This means there was an unexpected error encountered when the update was applied on
+ * the controller. There is no guarantee on whether the update succeeded or failed. The best
+ * way to find out is to issue a {@link Admin#describeFeatures(DescribeFeaturesOptions)}
+ * request.</li>
+ * </ul>

+ * This operation is supported by brokers with version 2.7.0 or higher. + + * @param featureUpdates the map of finalized feature name to {@link FeatureUpdate} + * @param options the options to use + * @return the {@link UpdateFeaturesResult} containing the result + */ + UpdateFeaturesResult updateFeatures(Map featureUpdates, UpdateFeaturesOptions options); + + /** + * Unregister a broker. + *

+ * This operation does not have any effect on partition assignments. It is supported + * only on Kafka clusters which use Raft to store metadata, rather than ZooKeeper. + * + * This is a convenience method for {@link #unregisterBroker(int, UnregisterBrokerOptions)} + * + * @param brokerId the broker id to unregister. + * + * @return the {@link UnregisterBrokerResult} containing the result + */ + @InterfaceStability.Unstable + default UnregisterBrokerResult unregisterBroker(int brokerId) { + return unregisterBroker(brokerId, new UnregisterBrokerOptions()); + } + + /** + * Unregister a broker. + *

+ * This operation does not have any effect on partition assignments. It is supported + * only on Kafka clusters which use Raft to store metadata, rather than ZooKeeper. + * + * The following exceptions can be anticipated when calling {@code get()} on the future from the + * returned {@link UnregisterBrokerResult}: + *

+ * <ul>
+ * <li>{@link org.apache.kafka.common.errors.TimeoutException}
+ * If the request timed out before the describe operation could finish.</li>
+ * <li>{@link org.apache.kafka.common.errors.UnsupportedVersionException}
+ * If the software is too old to support the unregistration API, or if the
+ * cluster is not using Raft to store metadata.</li>
+ * </ul>

+ * + * @param brokerId the broker id to unregister. + * @param options the options to use. + * + * @return the {@link UnregisterBrokerResult} containing the result + */ + @InterfaceStability.Unstable + UnregisterBrokerResult unregisterBroker(int brokerId, UnregisterBrokerOptions options); + + /** + * Describe producer state on a set of topic partitions. See + * {@link #describeProducers(Collection, DescribeProducersOptions)} for more details. + * + * @param partitions The set of partitions to query + * @return The result + */ + default DescribeProducersResult describeProducers(Collection partitions) { + return describeProducers(partitions, new DescribeProducersOptions()); + } + + /** + * Describe active producer state on a set of topic partitions. Unless a specific broker + * is requested through {@link DescribeProducersOptions#brokerId(int)}, this will + * query the partition leader to find the producer state. + * + * @param partitions The set of partitions to query + * @param options Options to control the method behavior + * @return The result + */ + DescribeProducersResult describeProducers(Collection partitions, DescribeProducersOptions options); + + /** + * Describe the state of a set of transactional IDs. See + * {@link #describeTransactions(Collection, DescribeTransactionsOptions)} for more details. + * + * @param transactionalIds The set of transactional IDs to query + * @return The result + */ + default DescribeTransactionsResult describeTransactions(Collection transactionalIds) { + return describeTransactions(transactionalIds, new DescribeTransactionsOptions()); + } + + /** + * Describe the state of a set of transactional IDs from the respective transaction coordinators, + * which are dynamically discovered. + * + * @param transactionalIds The set of transactional IDs to query + * @param options Options to control the method behavior + * @return The result + */ + DescribeTransactionsResult describeTransactions(Collection transactionalIds, DescribeTransactionsOptions options); + + /** + * Forcefully abort a transaction which is open on a topic partition. See + * {@link #abortTransaction(AbortTransactionSpec, AbortTransactionOptions)} for more details. + * + * @param spec The transaction specification including topic partition and producer details + * @return The result + */ + default AbortTransactionResult abortTransaction(AbortTransactionSpec spec) { + return abortTransaction(spec, new AbortTransactionOptions()); + } + + /** + * Forcefully abort a transaction which is open on a topic partition. This will + * send a `WriteTxnMarkers` request to the partition leader in order to abort the + * transaction. This requires administrative privileges. + * + * @param spec The transaction specification including topic partition and producer details + * @param options Options to control the method behavior (including filters) + * @return The result + */ + AbortTransactionResult abortTransaction(AbortTransactionSpec spec, AbortTransactionOptions options); + + /** + * List active transactions in the cluster. See + * {@link #listTransactions(ListTransactionsOptions)} for more details. + * + * @return The result + */ + default ListTransactionsResult listTransactions() { + return listTransactions(new ListTransactionsOptions()); + } + + /** + * List active transactions in the cluster. This will query all potential transaction + * coordinators in the cluster and collect the state of all transactions. 
Users + * should typically attempt to reduce the size of the result set using + * {@link ListTransactionsOptions#filterProducerIds(Collection)} or + * {@link ListTransactionsOptions#filterStates(Collection)} + * + * @param options Options to control the method behavior (including filters) + * @return The result + */ + ListTransactionsResult listTransactions(ListTransactionsOptions options); + + /** + * Get the metrics kept by the adminClient + */ + Map metrics(); +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/AdminClient.java b/clients/src/main/java/org/apache/kafka/clients/admin/AdminClient.java new file mode 100644 index 0000000..75f1c5f --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/AdminClient.java @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.clients.admin; + +import java.util.Map; +import java.util.Properties; + +/** + * The base class for in-built admin clients. + * + * Client code should use the newer {@link Admin} interface in preference to this class. + * + * This class may be removed in a later release, but has not be marked as deprecated to avoid unnecessary noise. + */ +public abstract class AdminClient implements Admin { + + /** + * Create a new Admin with the given configuration. + * + * @param props The configuration. + * @return The new KafkaAdminClient. + */ + public static AdminClient create(Properties props) { + return (AdminClient) Admin.create(props); + } + + /** + * Create a new Admin with the given configuration. + * + * @param conf The configuration. + * @return The new KafkaAdminClient. + */ + public static AdminClient create(Map conf) { + return (AdminClient) Admin.create(conf); + } +} + diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/AdminClientConfig.java b/clients/src/main/java/org/apache/kafka/clients/admin/AdminClientConfig.java new file mode 100644 index 0000000..16feef6 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/AdminClientConfig.java @@ -0,0 +1,246 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
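As a usage sketch for the transaction-related Admin methods documented above (describeProducers, describeTransactions, abortTransaction, listTransactions): the snippet below lists only ongoing transactions, which keeps the result set small as the listTransactions javadoc recommends. It assumes a reachable broker at localhost:9092 and a cluster recent enough to expose these APIs; the broker address and class name are illustrative, and TransactionState/TransactionListing are the companion types from the same clients package.

    import java.util.EnumSet;
    import java.util.Properties;
    import org.apache.kafka.clients.admin.Admin;
    import org.apache.kafka.clients.admin.AdminClientConfig;
    import org.apache.kafka.clients.admin.ListTransactionsOptions;
    import org.apache.kafka.clients.admin.TransactionListing;
    import org.apache.kafka.clients.admin.TransactionState;

    public class ListOngoingTransactions {
        public static void main(String[] args) throws Exception {
            Properties props = new Properties();
            props.put(AdminClientConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9092"); // illustrative broker address
            try (Admin admin = Admin.create(props)) {
                // Restrict the result set by state rather than collecting every transaction in the cluster.
                ListTransactionsOptions options = new ListTransactionsOptions()
                    .filterStates(EnumSet.of(TransactionState.ONGOING));
                for (TransactionListing listing : admin.listTransactions(options).all().get()) {
                    System.out.println(listing.transactionalId() + " -> " + listing.state());
                }
            }
        }
    }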
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.clients.admin; + +import org.apache.kafka.clients.ClientDnsLookup; +import org.apache.kafka.clients.CommonClientConfigs; +import org.apache.kafka.common.config.AbstractConfig; +import org.apache.kafka.common.config.ConfigDef; +import org.apache.kafka.common.config.ConfigDef.Importance; +import org.apache.kafka.common.config.ConfigDef.Type; +import org.apache.kafka.common.config.SecurityConfig; +import org.apache.kafka.common.metrics.Sensor; + +import java.util.Map; +import java.util.Set; + +import static org.apache.kafka.common.config.ConfigDef.Range.atLeast; +import static org.apache.kafka.common.config.ConfigDef.Range.between; +import static org.apache.kafka.common.config.ConfigDef.ValidString.in; + +/** + * The AdminClient configuration class, which also contains constants for configuration entry names. + */ +public class AdminClientConfig extends AbstractConfig { + private static final ConfigDef CONFIG; + + /** + * bootstrap.servers + */ + public static final String BOOTSTRAP_SERVERS_CONFIG = CommonClientConfigs.BOOTSTRAP_SERVERS_CONFIG; + private static final String BOOTSTRAP_SERVERS_DOC = CommonClientConfigs.BOOTSTRAP_SERVERS_DOC; + + /** + * client.dns.lookup + */ + public static final String CLIENT_DNS_LOOKUP_CONFIG = CommonClientConfigs.CLIENT_DNS_LOOKUP_CONFIG; + private static final String CLIENT_DNS_LOOKUP_DOC = CommonClientConfigs.CLIENT_DNS_LOOKUP_DOC; + + /** + * reconnect.backoff.ms + */ + public static final String RECONNECT_BACKOFF_MS_CONFIG = CommonClientConfigs.RECONNECT_BACKOFF_MS_CONFIG; + private static final String RECONNECT_BACKOFF_MS_DOC = CommonClientConfigs.RECONNECT_BACKOFF_MS_DOC; + + /** + * reconnect.backoff.max.ms + */ + public static final String RECONNECT_BACKOFF_MAX_MS_CONFIG = CommonClientConfigs.RECONNECT_BACKOFF_MAX_MS_CONFIG; + private static final String RECONNECT_BACKOFF_MAX_MS_DOC = CommonClientConfigs.RECONNECT_BACKOFF_MAX_MS_DOC; + + /** + * retry.backoff.ms + */ + public static final String RETRY_BACKOFF_MS_CONFIG = CommonClientConfigs.RETRY_BACKOFF_MS_CONFIG; + private static final String RETRY_BACKOFF_MS_DOC = "The amount of time to wait before attempting to " + + "retry a failed request. 
This avoids repeatedly sending requests in a tight loop under " + + "some failure scenarios."; + + /** socket.connection.setup.timeout.ms */ + public static final String SOCKET_CONNECTION_SETUP_TIMEOUT_MS_CONFIG = CommonClientConfigs.SOCKET_CONNECTION_SETUP_TIMEOUT_MS_CONFIG; + + /** socket.connection.setup.timeout.max.ms */ + public static final String SOCKET_CONNECTION_SETUP_TIMEOUT_MAX_MS_CONFIG = CommonClientConfigs.SOCKET_CONNECTION_SETUP_TIMEOUT_MAX_MS_CONFIG; + + /** connections.max.idle.ms */ + public static final String CONNECTIONS_MAX_IDLE_MS_CONFIG = CommonClientConfigs.CONNECTIONS_MAX_IDLE_MS_CONFIG; + private static final String CONNECTIONS_MAX_IDLE_MS_DOC = CommonClientConfigs.CONNECTIONS_MAX_IDLE_MS_DOC; + + /** request.timeout.ms */ + public static final String REQUEST_TIMEOUT_MS_CONFIG = CommonClientConfigs.REQUEST_TIMEOUT_MS_CONFIG; + private static final String REQUEST_TIMEOUT_MS_DOC = CommonClientConfigs.REQUEST_TIMEOUT_MS_DOC; + + public static final String CLIENT_ID_CONFIG = CommonClientConfigs.CLIENT_ID_CONFIG; + private static final String CLIENT_ID_DOC = CommonClientConfigs.CLIENT_ID_DOC; + + public static final String METADATA_MAX_AGE_CONFIG = CommonClientConfigs.METADATA_MAX_AGE_CONFIG; + private static final String METADATA_MAX_AGE_DOC = CommonClientConfigs.METADATA_MAX_AGE_DOC; + + public static final String SEND_BUFFER_CONFIG = CommonClientConfigs.SEND_BUFFER_CONFIG; + private static final String SEND_BUFFER_DOC = CommonClientConfigs.SEND_BUFFER_DOC; + + public static final String RECEIVE_BUFFER_CONFIG = CommonClientConfigs.RECEIVE_BUFFER_CONFIG; + private static final String RECEIVE_BUFFER_DOC = CommonClientConfigs.RECEIVE_BUFFER_DOC; + + public static final String METRIC_REPORTER_CLASSES_CONFIG = CommonClientConfigs.METRIC_REPORTER_CLASSES_CONFIG; + private static final String METRIC_REPORTER_CLASSES_DOC = CommonClientConfigs.METRIC_REPORTER_CLASSES_DOC; + + public static final String METRICS_NUM_SAMPLES_CONFIG = CommonClientConfigs.METRICS_NUM_SAMPLES_CONFIG; + private static final String METRICS_NUM_SAMPLES_DOC = CommonClientConfigs.METRICS_NUM_SAMPLES_DOC; + + public static final String METRICS_SAMPLE_WINDOW_MS_CONFIG = CommonClientConfigs.METRICS_SAMPLE_WINDOW_MS_CONFIG; + private static final String METRICS_SAMPLE_WINDOW_MS_DOC = CommonClientConfigs.METRICS_SAMPLE_WINDOW_MS_DOC; + + public static final String METRICS_RECORDING_LEVEL_CONFIG = CommonClientConfigs.METRICS_RECORDING_LEVEL_CONFIG; + + public static final String SECURITY_PROTOCOL_CONFIG = CommonClientConfigs.SECURITY_PROTOCOL_CONFIG; + public static final String DEFAULT_SECURITY_PROTOCOL = CommonClientConfigs.DEFAULT_SECURITY_PROTOCOL; + private static final String SECURITY_PROTOCOL_DOC = CommonClientConfigs.SECURITY_PROTOCOL_DOC; + private static final String METRICS_RECORDING_LEVEL_DOC = CommonClientConfigs.METRICS_RECORDING_LEVEL_DOC; + + public static final String RETRIES_CONFIG = CommonClientConfigs.RETRIES_CONFIG; + public static final String DEFAULT_API_TIMEOUT_MS_CONFIG = CommonClientConfigs.DEFAULT_API_TIMEOUT_MS_CONFIG; + + /** + * security.providers + */ + public static final String SECURITY_PROVIDERS_CONFIG = SecurityConfig.SECURITY_PROVIDERS_CONFIG; + private static final String SECURITY_PROVIDERS_DOC = SecurityConfig.SECURITY_PROVIDERS_DOC; + + static { + CONFIG = new ConfigDef().define(BOOTSTRAP_SERVERS_CONFIG, + Type.LIST, + Importance.HIGH, + BOOTSTRAP_SERVERS_DOC) + .define(CLIENT_ID_CONFIG, Type.STRING, "", Importance.MEDIUM, CLIENT_ID_DOC) + 
.define(METADATA_MAX_AGE_CONFIG, Type.LONG, 5 * 60 * 1000, atLeast(0), Importance.LOW, METADATA_MAX_AGE_DOC) + .define(SEND_BUFFER_CONFIG, Type.INT, 128 * 1024, atLeast(-1), Importance.MEDIUM, SEND_BUFFER_DOC) + .define(RECEIVE_BUFFER_CONFIG, Type.INT, 64 * 1024, atLeast(-1), Importance.MEDIUM, RECEIVE_BUFFER_DOC) + .define(RECONNECT_BACKOFF_MS_CONFIG, + Type.LONG, + 50L, + atLeast(0L), + Importance.LOW, + RECONNECT_BACKOFF_MS_DOC) + .define(RECONNECT_BACKOFF_MAX_MS_CONFIG, + Type.LONG, + 1000L, + atLeast(0L), + Importance.LOW, + RECONNECT_BACKOFF_MAX_MS_DOC) + .define(RETRY_BACKOFF_MS_CONFIG, + Type.LONG, + 100L, + atLeast(0L), + Importance.LOW, + RETRY_BACKOFF_MS_DOC) + .define(REQUEST_TIMEOUT_MS_CONFIG, + Type.INT, + 30000, + atLeast(0), + Importance.MEDIUM, + REQUEST_TIMEOUT_MS_DOC) + .define(SOCKET_CONNECTION_SETUP_TIMEOUT_MS_CONFIG, + Type.LONG, + CommonClientConfigs.DEFAULT_SOCKET_CONNECTION_SETUP_TIMEOUT_MS, + Importance.MEDIUM, + CommonClientConfigs.SOCKET_CONNECTION_SETUP_TIMEOUT_MS_DOC) + .define(SOCKET_CONNECTION_SETUP_TIMEOUT_MAX_MS_CONFIG, + Type.LONG, + CommonClientConfigs.DEFAULT_SOCKET_CONNECTION_SETUP_TIMEOUT_MAX_MS, + Importance.MEDIUM, + CommonClientConfigs.SOCKET_CONNECTION_SETUP_TIMEOUT_MAX_MS_DOC) + .define(CONNECTIONS_MAX_IDLE_MS_CONFIG, + Type.LONG, + 5 * 60 * 1000, + Importance.MEDIUM, + CONNECTIONS_MAX_IDLE_MS_DOC) + .define(RETRIES_CONFIG, + Type.INT, + Integer.MAX_VALUE, + between(0, Integer.MAX_VALUE), + Importance.LOW, + CommonClientConfigs.RETRIES_DOC) + .define(DEFAULT_API_TIMEOUT_MS_CONFIG, + Type.INT, + 60000, + atLeast(0), + Importance.MEDIUM, + CommonClientConfigs.DEFAULT_API_TIMEOUT_MS_DOC) + .define(METRICS_SAMPLE_WINDOW_MS_CONFIG, + Type.LONG, + 30000, + atLeast(0), + Importance.LOW, + METRICS_SAMPLE_WINDOW_MS_DOC) + .define(METRICS_NUM_SAMPLES_CONFIG, Type.INT, 2, atLeast(1), Importance.LOW, METRICS_NUM_SAMPLES_DOC) + .define(METRIC_REPORTER_CLASSES_CONFIG, Type.LIST, "", Importance.LOW, METRIC_REPORTER_CLASSES_DOC) + .define(METRICS_RECORDING_LEVEL_CONFIG, + Type.STRING, + Sensor.RecordingLevel.INFO.toString(), + in(Sensor.RecordingLevel.INFO.toString(), Sensor.RecordingLevel.DEBUG.toString(), Sensor.RecordingLevel.TRACE.toString()), + Importance.LOW, + METRICS_RECORDING_LEVEL_DOC) + .define(CLIENT_DNS_LOOKUP_CONFIG, + Type.STRING, + ClientDnsLookup.USE_ALL_DNS_IPS.toString(), + in(ClientDnsLookup.USE_ALL_DNS_IPS.toString(), + ClientDnsLookup.RESOLVE_CANONICAL_BOOTSTRAP_SERVERS_ONLY.toString()), + Importance.MEDIUM, + CLIENT_DNS_LOOKUP_DOC) + // security support + .define(SECURITY_PROVIDERS_CONFIG, + Type.STRING, + null, + Importance.LOW, + SECURITY_PROVIDERS_DOC) + .define(SECURITY_PROTOCOL_CONFIG, + Type.STRING, + DEFAULT_SECURITY_PROTOCOL, + Importance.MEDIUM, + SECURITY_PROTOCOL_DOC) + .withClientSslSupport() + .withClientSaslSupport(); + } + + @Override + protected Map postProcessParsedConfig(final Map parsedValues) { + return CommonClientConfigs.postProcessReconnectBackoffConfigs(this, parsedValues); + } + + public AdminClientConfig(Map props) { + this(props, false); + } + + protected AdminClientConfig(Map props, boolean doLog) { + super(CONFIG, props, doLog); + } + + public static Set configNames() { + return CONFIG.names(); + } + + public static ConfigDef configDef() { + return new ConfigDef(CONFIG); + } + + public static void main(String[] args) { + System.out.println(CONFIG.toHtml(4, config -> "adminclientconfigs_" + config)); + } + +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/AlterClientQuotasOptions.java 
b/clients/src/main/java/org/apache/kafka/clients/admin/AlterClientQuotasOptions.java new file mode 100644 index 0000000..3cdaa97 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/AlterClientQuotasOptions.java @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.clients.admin; + +import org.apache.kafka.common.annotation.InterfaceStability; + +/** + * Options for {@link Admin#alterClientQuotas(Collection, AlterClientQuotasOptions)}. + * + * The API of this class is evolving, see {@link Admin} for details. + */ +@InterfaceStability.Evolving +public class AlterClientQuotasOptions extends AbstractOptions { + + private boolean validateOnly = false; + + /** + * Returns whether the request should be validated without altering the configs. + */ + public boolean validateOnly() { + return this.validateOnly; + } + + /** + * Sets whether the request should be validated without altering the configs. + */ + public AlterClientQuotasOptions validateOnly(boolean validateOnly) { + this.validateOnly = validateOnly; + return this; + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/AlterClientQuotasResult.java b/clients/src/main/java/org/apache/kafka/clients/admin/AlterClientQuotasResult.java new file mode 100644 index 0000000..63c6b3e --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/AlterClientQuotasResult.java @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.clients.admin; + +import org.apache.kafka.common.KafkaFuture; +import org.apache.kafka.common.annotation.InterfaceStability; +import org.apache.kafka.common.quota.ClientQuotaEntity; + +import java.util.Map; + +/** + * The result of the {@link Admin#alterClientQuotas(Collection, AlterClientQuotasOptions)} call. + * + * The API of this class is evolving, see {@link Admin} for details. 
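A sketch of how alterClientQuotas, its options and its result fit together, assuming a broker at localhost:9092 and a hypothetical user "alice"; consumer_byte_rate is a standard broker quota key, but the numbers here are purely illustrative.

    import java.util.Collections;
    import java.util.Properties;
    import org.apache.kafka.clients.admin.Admin;
    import org.apache.kafka.clients.admin.AdminClientConfig;
    import org.apache.kafka.common.quota.ClientQuotaAlteration;
    import org.apache.kafka.common.quota.ClientQuotaEntity;

    public class SetUserQuota {
        public static void main(String[] args) throws Exception {
            Properties props = new Properties();
            props.put(AdminClientConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9092"); // illustrative broker address
            try (Admin admin = Admin.create(props)) {
                // Quota entity: the hypothetical user "alice".
                ClientQuotaEntity entity = new ClientQuotaEntity(
                    Collections.singletonMap(ClientQuotaEntity.USER, "alice"));
                // Cap the user's consumption at 1 MiB/s; a null value would remove the quota instead.
                ClientQuotaAlteration alteration = new ClientQuotaAlteration(entity,
                    Collections.singleton(new ClientQuotaAlteration.Op("consumer_byte_rate", 1048576.0)));
                admin.alterClientQuotas(Collections.singleton(alteration)).all().get();
            }
        }
    }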
+ */ +@InterfaceStability.Evolving +public class AlterClientQuotasResult { + + private final Map> futures; + + /** + * Maps an entity to its alteration result. + * + * @param futures maps entity to its alteration result + */ + public AlterClientQuotasResult(Map> futures) { + this.futures = futures; + } + + /** + * Returns a map from quota entity to a future which can be used to check the status of the operation. + */ + public Map> values() { + return futures; + } + + /** + * Returns a future which succeeds only if all quota alterations succeed. + */ + public KafkaFuture all() { + return KafkaFuture.allOf(futures.values().toArray(new KafkaFuture[0])); + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/AlterConfigOp.java b/clients/src/main/java/org/apache/kafka/clients/admin/AlterConfigOp.java new file mode 100644 index 0000000..1131e12 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/AlterConfigOp.java @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.clients.admin; + +import org.apache.kafka.common.annotation.InterfaceStability; + +import java.util.Arrays; +import java.util.Collections; +import java.util.Map; +import java.util.Objects; +import java.util.function.Function; +import java.util.stream.Collectors; + +/** + * A class representing a alter configuration entry containing name, value and operation type. + * + * The API of this class is evolving, see {@link Admin} for details. + */ +@InterfaceStability.Evolving +public class AlterConfigOp { + + public enum OpType { + /** + * Set the value of the configuration entry. + */ + SET((byte) 0), + /** + * Revert the configuration entry to the default value (possibly null). + */ + DELETE((byte) 1), + /** + * (For list-type configuration entries only.) Add the specified values to the + * current value of the configuration entry. If the configuration value has not been set, + * adds to the default value. + */ + APPEND((byte) 2), + /** + * (For list-type configuration entries only.) Removes the specified values from the current + * value of the configuration entry. It is legal to remove values that are not currently in the + * configuration entry. Removing all entries from the current configuration value leaves an empty + * list and does NOT revert to the default value of the entry. 
+ */ + SUBTRACT((byte) 3); + + private static final Map OP_TYPES = Collections.unmodifiableMap( + Arrays.stream(values()).collect(Collectors.toMap(OpType::id, Function.identity())) + ); + + private final byte id; + + OpType(final byte id) { + this.id = id; + } + + public byte id() { + return id; + } + + public static OpType forId(final byte id) { + return OP_TYPES.get(id); + } + } + + private final ConfigEntry configEntry; + private final OpType opType; + + public AlterConfigOp(ConfigEntry configEntry, OpType operationType) { + this.configEntry = configEntry; + this.opType = operationType; + } + + public ConfigEntry configEntry() { + return configEntry; + }; + + public OpType opType() { + return opType; + }; + + @Override + public boolean equals(final Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + final AlterConfigOp that = (AlterConfigOp) o; + return opType == that.opType && + Objects.equals(configEntry, that.configEntry); + } + + @Override + public int hashCode() { + return Objects.hash(opType, configEntry); + } + + @Override + public String toString() { + return "AlterConfigOp{" + + "opType=" + opType + + ", configEntry=" + configEntry + + '}'; + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/AlterConfigsOptions.java b/clients/src/main/java/org/apache/kafka/clients/admin/AlterConfigsOptions.java new file mode 100644 index 0000000..198a4ea --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/AlterConfigsOptions.java @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.clients.admin; + +import org.apache.kafka.common.annotation.InterfaceStability; + +import java.util.Map; + +/** + * Options for {@link Admin#incrementalAlterConfigs(Map)} and {@link Admin#alterConfigs(Map)}. + * + * The API of this class is evolving, see {@link Admin} for details. + */ +@InterfaceStability.Evolving +public class AlterConfigsOptions extends AbstractOptions { + + private boolean validateOnly = false; + + /** + * Set the timeout in milliseconds for this operation or {@code null} if the default api timeout for the + * AdminClient should be used. + * + */ + // This method is retained to keep binary compatibility with 0.11 + public AlterConfigsOptions timeoutMs(Integer timeoutMs) { + this.timeoutMs = timeoutMs; + return this; + } + + /** + * Return true if the request should be validated without altering the configs. + */ + public boolean shouldValidateOnly() { + return validateOnly; + } + + /** + * Set to true if the request should be validated without altering the configs. 
+ */ + public AlterConfigsOptions validateOnly(boolean validateOnly) { + this.validateOnly = validateOnly; + return this; + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/AlterConfigsResult.java b/clients/src/main/java/org/apache/kafka/clients/admin/AlterConfigsResult.java new file mode 100644 index 0000000..29056ce --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/AlterConfigsResult.java @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.clients.admin; + +import org.apache.kafka.common.KafkaFuture; +import org.apache.kafka.common.annotation.InterfaceStability; +import org.apache.kafka.common.config.ConfigResource; + +import java.util.Map; + +/** + * The result of the {@link Admin#alterConfigs(Map)} call. + * + * The API of this class is evolving, see {@link Admin} for details. + */ +@InterfaceStability.Evolving +public class AlterConfigsResult { + + private final Map> futures; + + AlterConfigsResult(Map> futures) { + this.futures = futures; + } + + /** + * Return a map from resources to futures which can be used to check the status of the operation on each resource. + */ + public Map> values() { + return futures; + } + + /** + * Return a future which succeeds only if all the alter configs operations succeed. + */ + public KafkaFuture all() { + return KafkaFuture.allOf(futures.values().toArray(new KafkaFuture[0])); + } + +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/AlterConsumerGroupOffsetsOptions.java b/clients/src/main/java/org/apache/kafka/clients/admin/AlterConsumerGroupOffsetsOptions.java new file mode 100644 index 0000000..f630b48 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/AlterConsumerGroupOffsetsOptions.java @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
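A sketch of incrementalAlterConfigs built from the AlterConfigOp operation types and the AlterConfigsOptions/AlterConfigsResult classes above, assuming a broker at localhost:9092 and a hypothetical topic "my-topic".

    import java.util.Collection;
    import java.util.Collections;
    import java.util.Map;
    import java.util.Properties;
    import org.apache.kafka.clients.admin.Admin;
    import org.apache.kafka.clients.admin.AdminClientConfig;
    import org.apache.kafka.clients.admin.AlterConfigOp;
    import org.apache.kafka.clients.admin.ConfigEntry;
    import org.apache.kafka.common.config.ConfigResource;

    public class AppendTopicConfig {
        public static void main(String[] args) throws Exception {
            Properties props = new Properties();
            props.put(AdminClientConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9092"); // illustrative broker address
            try (Admin admin = Admin.create(props)) {
                ConfigResource topic = new ConfigResource(ConfigResource.Type.TOPIC, "my-topic"); // hypothetical topic
                // APPEND adds "compact" to the list-type cleanup.policy config rather than replacing it.
                AlterConfigOp append = new AlterConfigOp(
                    new ConfigEntry("cleanup.policy", "compact"), AlterConfigOp.OpType.APPEND);
                Map<ConfigResource, Collection<AlterConfigOp>> updates =
                    Collections.singletonMap(topic, Collections.singletonList(append));
                admin.incrementalAlterConfigs(updates).all().get();
            }
        }
    }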
+ */ +package org.apache.kafka.clients.admin; + +import org.apache.kafka.common.annotation.InterfaceStability; + +import java.util.Map; + +/** + * Options for the {@link AdminClient#alterConsumerGroupOffsets(String, Map, AlterConsumerGroupOffsetsOptions)} call. + * + * The API of this class is evolving, see {@link AdminClient} for details. + */ +@InterfaceStability.Evolving +public class AlterConsumerGroupOffsetsOptions extends AbstractOptions { +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/AlterConsumerGroupOffsetsResult.java b/clients/src/main/java/org/apache/kafka/clients/admin/AlterConsumerGroupOffsetsResult.java new file mode 100644 index 0000000..c4cb2e9 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/AlterConsumerGroupOffsetsResult.java @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients.admin; + +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +import org.apache.kafka.common.KafkaFuture; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.annotation.InterfaceStability; +import org.apache.kafka.common.internals.KafkaFutureImpl; +import org.apache.kafka.common.protocol.Errors; + +/** + * The result of the {@link AdminClient#alterConsumerGroupOffsets(String, Map)} call. + * + * The API of this class is evolving, see {@link AdminClient} for details. + */ +@InterfaceStability.Evolving +public class AlterConsumerGroupOffsetsResult { + + private final KafkaFuture> future; + + AlterConsumerGroupOffsetsResult(KafkaFuture> future) { + this.future = future; + } + + /** + * Return a future which can be used to check the result for a given partition. + */ + public KafkaFuture partitionResult(final TopicPartition partition) { + final KafkaFutureImpl result = new KafkaFutureImpl<>(); + + this.future.whenComplete((topicPartitions, throwable) -> { + if (throwable != null) { + result.completeExceptionally(throwable); + } else if (!topicPartitions.containsKey(partition)) { + result.completeExceptionally(new IllegalArgumentException( + "Alter offset for partition \"" + partition + "\" was not attempted")); + } else { + final Errors error = topicPartitions.get(partition); + if (error == Errors.NONE) { + result.complete(null); + } else { + result.completeExceptionally(error.exception()); + } + } + }); + + return result; + } + + /** + * Return a future which succeeds if all the alter offsets succeed. 
+ */ + public KafkaFuture all() { + return this.future.thenApply(topicPartitionErrorsMap -> { + List partitionsFailed = topicPartitionErrorsMap.entrySet() + .stream() + .filter(e -> e.getValue() != Errors.NONE) + .map(Map.Entry::getKey) + .collect(Collectors.toList()); + for (Errors error : topicPartitionErrorsMap.values()) { + if (error != Errors.NONE) { + throw error.exception( + "Failed altering consumer group offsets for the following partitions: " + partitionsFailed); + } + } + return null; + }); + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/AlterPartitionReassignmentsOptions.java b/clients/src/main/java/org/apache/kafka/clients/admin/AlterPartitionReassignmentsOptions.java new file mode 100644 index 0000000..bee9c70 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/AlterPartitionReassignmentsOptions.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.clients.admin; + +import org.apache.kafka.common.annotation.InterfaceStability; + +import java.util.Map; + +/** + * Options for {@link AdminClient#alterPartitionReassignments(Map, AlterPartitionReassignmentsOptions)} + * + * The API of this class is evolving. See {@link AdminClient} for details. + */ +@InterfaceStability.Evolving +public class AlterPartitionReassignmentsOptions extends AbstractOptions { +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/AlterPartitionReassignmentsResult.java b/clients/src/main/java/org/apache/kafka/clients/admin/AlterPartitionReassignmentsResult.java new file mode 100644 index 0000000..2009ab5 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/AlterPartitionReassignmentsResult.java @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
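A sketch of alterConsumerGroupOffsets and its per-partition result, assuming a broker at localhost:9092 plus a hypothetical group "my-group" and topic "my-topic"; note that committed offsets can generally only be rewritten while the group has no active members.

    import java.util.Collections;
    import java.util.Properties;
    import org.apache.kafka.clients.admin.Admin;
    import org.apache.kafka.clients.admin.AdminClientConfig;
    import org.apache.kafka.clients.consumer.OffsetAndMetadata;
    import org.apache.kafka.common.TopicPartition;

    public class ResetGroupOffset {
        public static void main(String[] args) throws Exception {
            Properties props = new Properties();
            props.put(AdminClientConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9092"); // illustrative broker address
            try (Admin admin = Admin.create(props)) {
                TopicPartition tp = new TopicPartition("my-topic", 0); // hypothetical partition
                // Rewind the committed offset of the hypothetical group "my-group" to the beginning.
                admin.alterConsumerGroupOffsets("my-group",
                        Collections.singletonMap(tp, new OffsetAndMetadata(0L)))
                    .partitionResult(tp)
                    .get(); // completes exceptionally if this partition's alteration failed
            }
        }
    }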
+ */ + +package org.apache.kafka.clients.admin; + +import org.apache.kafka.common.KafkaFuture; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.annotation.InterfaceStability; + +import java.util.Map; + +/** + * The result of {@link AdminClient#alterPartitionReassignments(Map, AlterPartitionReassignmentsOptions)}. + * + * The API of this class is evolving. See {@link AdminClient} for details. + */ +@InterfaceStability.Evolving +public class AlterPartitionReassignmentsResult { + private final Map> futures; + + AlterPartitionReassignmentsResult(Map> futures) { + this.futures = futures; + } + + /** + * Return a map from partitions to futures which can be used to check the status of the reassignment. + * + * Possible error codes: + * + * INVALID_REPLICA_ASSIGNMENT (39) - if the specified replica assignment was not valid -- for example, if it included negative numbers, repeated numbers, or specified a broker ID that the controller was not aware of. + * NO_REASSIGNMENT_IN_PROGRESS (85) - if the request wants to cancel reassignments but none exist + * UNKNOWN (-1) + * + */ + public Map> values() { + return futures; + } + + /** + * Return a future which succeeds only if all the reassignments were successfully initiated. + */ + public KafkaFuture all() { + return KafkaFuture.allOf(futures.values().toArray(new KafkaFuture[0])); + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/AlterReplicaLogDirsOptions.java b/clients/src/main/java/org/apache/kafka/clients/admin/AlterReplicaLogDirsOptions.java new file mode 100644 index 0000000..76037fb --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/AlterReplicaLogDirsOptions.java @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.clients.admin; + +import org.apache.kafka.common.annotation.InterfaceStability; +import java.util.Map; + +/** + * Options for {@link Admin#alterReplicaLogDirs(Map, AlterReplicaLogDirsOptions)}. + */ +@InterfaceStability.Evolving +public class AlterReplicaLogDirsOptions extends AbstractOptions { + +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/AlterReplicaLogDirsResult.java b/clients/src/main/java/org/apache/kafka/clients/admin/AlterReplicaLogDirsResult.java new file mode 100644 index 0000000..81eb0ea --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/AlterReplicaLogDirsResult.java @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
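A sketch of alterPartitionReassignments, assuming a broker at localhost:9092, a hypothetical topic "my-topic", and hypothetical broker ids 1, 2 and 3 as the target replica set; passing Optional.empty() instead would cancel an in-flight reassignment, matching the NO_REASSIGNMENT_IN_PROGRESS case described above.

    import java.util.Arrays;
    import java.util.Collections;
    import java.util.Optional;
    import java.util.Properties;
    import org.apache.kafka.clients.admin.Admin;
    import org.apache.kafka.clients.admin.AdminClientConfig;
    import org.apache.kafka.clients.admin.NewPartitionReassignment;
    import org.apache.kafka.common.TopicPartition;

    public class MovePartition {
        public static void main(String[] args) throws Exception {
            Properties props = new Properties();
            props.put(AdminClientConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9092"); // illustrative broker address
            try (Admin admin = Admin.create(props)) {
                TopicPartition tp = new TopicPartition("my-topic", 0); // hypothetical partition
                // Request that the partition's replicas move to brokers 1, 2 and 3 (hypothetical ids).
                admin.alterPartitionReassignments(Collections.singletonMap(tp,
                        Optional.of(new NewPartitionReassignment(Arrays.asList(1, 2, 3)))))
                    .all()
                    .get();
            }
        }
    }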
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients.admin; + +import java.util.Map; +import java.util.concurrent.CancellationException; +import java.util.concurrent.ExecutionException; +import org.apache.kafka.common.KafkaFuture; +import org.apache.kafka.common.TopicPartitionReplica; +import org.apache.kafka.common.annotation.InterfaceStability; +import org.apache.kafka.common.errors.ClusterAuthorizationException; +import org.apache.kafka.common.errors.InvalidTopicException; +import org.apache.kafka.common.errors.KafkaStorageException; +import org.apache.kafka.common.errors.LogDirNotFoundException; +import org.apache.kafka.common.errors.ReplicaNotAvailableException; +import org.apache.kafka.common.errors.UnknownServerException; + +/** + * The result of {@link Admin#alterReplicaLogDirs(Map, AlterReplicaLogDirsOptions)}. + * + * To retrieve the detailed result per specified {@link TopicPartitionReplica}, use {@link #values()}. To retrieve the + * overall result only, use {@link #all()}. + */ +@InterfaceStability.Evolving +public class AlterReplicaLogDirsResult { + private final Map> futures; + + AlterReplicaLogDirsResult(Map> futures) { + this.futures = futures; + } + + /** + * Return a map from {@link TopicPartitionReplica} to {@link KafkaFuture} which holds the status of individual + * replica movement. + * + * To check the result of individual replica movement, call {@link KafkaFuture#get()} from the value contained + * in the returned map. If there is no error, it will return silently; if not, an {@link Exception} will be thrown + * like the following: + * + *

+ * <ul> + *   <li>{@link CancellationException}: The task was canceled.</li> + *   <li>{@link InterruptedException}: Interrupted while joining I/O thread.</li> + *   <li>{@link ExecutionException}: Execution failed with the following causes:</li> + *   <ul> + *     <li>{@link ClusterAuthorizationException}: Authorization failed. (CLUSTER_AUTHORIZATION_FAILED, 31)</li> + *     <li>{@link InvalidTopicException}: The specified topic name is too long. (INVALID_TOPIC_EXCEPTION, 17)</li> + *     <li>{@link LogDirNotFoundException}: The specified log directory is not found in the broker. (LOG_DIR_NOT_FOUND, 57)</li> + *     <li>{@link ReplicaNotAvailableException}: The replica does not exist on the broker. (REPLICA_NOT_AVAILABLE, 9)</li> + *     <li>{@link KafkaStorageException}: Disk error occurred. (KAFKA_STORAGE_ERROR, 56)</li> + *     <li>{@link UnknownServerException}: Unknown. (UNKNOWN_SERVER_ERROR, -1)</li> + *   </ul> + * </ul>
+ */ + public Map> values() { + return futures; + } + + /** + * Return a {@link KafkaFuture} which succeeds on {@link KafkaFuture#get()} if all the replica movement have succeeded. + * if not, it throws an {@link Exception} described in {@link #values()} method. + */ + public KafkaFuture all() { + return KafkaFuture.allOf(futures.values().toArray(new KafkaFuture[0])); + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/AlterUserScramCredentialsOptions.java b/clients/src/main/java/org/apache/kafka/clients/admin/AlterUserScramCredentialsOptions.java new file mode 100644 index 0000000..23a0b0a --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/AlterUserScramCredentialsOptions.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.clients.admin; + +import org.apache.kafka.common.annotation.InterfaceStability; + +import java.util.List; + +/** + * Options for {@link AdminClient#alterUserScramCredentials(List, AlterUserScramCredentialsOptions)} + * + * The API of this class is evolving. See {@link AdminClient} for details. + */ +@InterfaceStability.Evolving +public class AlterUserScramCredentialsOptions extends AbstractOptions { +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/AlterUserScramCredentialsResult.java b/clients/src/main/java/org/apache/kafka/clients/admin/AlterUserScramCredentialsResult.java new file mode 100644 index 0000000..a0ce013 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/AlterUserScramCredentialsResult.java @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.clients.admin; + +import org.apache.kafka.common.KafkaFuture; +import org.apache.kafka.common.annotation.InterfaceStability; + +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Objects; + +/** + * The result of the {@link Admin#alterUserScramCredentials(List)} call. 
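Tying the AlterReplicaLogDirsResult documentation above to a call site: a sketch assuming a broker at localhost:9092, broker id 1, a hypothetical topic "my-topic", and a hypothetical log directory /data/kafka-logs-2 configured on that broker.

    import java.util.Collections;
    import java.util.Properties;
    import org.apache.kafka.clients.admin.Admin;
    import org.apache.kafka.clients.admin.AdminClientConfig;
    import org.apache.kafka.common.TopicPartitionReplica;

    public class MoveReplicaLogDir {
        public static void main(String[] args) throws Exception {
            Properties props = new Properties();
            props.put(AdminClientConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9092"); // illustrative broker address
            try (Admin admin = Admin.create(props)) {
                // Replica of partition 0 of the hypothetical topic "my-topic" hosted on broker 1.
                TopicPartitionReplica replica = new TopicPartitionReplica("my-topic", 0, 1);
                admin.alterReplicaLogDirs(
                        Collections.singletonMap(replica, "/data/kafka-logs-2")) // hypothetical log dir
                    .all()
                    .get(); // surfaces the exceptions listed in the values() javadoc on failure
            }
        }
    }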
+ * + * The API of this class is evolving, see {@link Admin} for details. + */ +@InterfaceStability.Evolving +public class AlterUserScramCredentialsResult { + private final Map> futures; + + /** + * + * @param futures the required map from user names to futures representing the results of the alteration(s) + * for each user + */ + public AlterUserScramCredentialsResult(Map> futures) { + this.futures = Collections.unmodifiableMap(Objects.requireNonNull(futures)); + } + + /** + * Return a map from user names to futures, which can be used to check the status of the alteration(s) + * for each user. + */ + public Map> values() { + return this.futures; + } + + /** + * Return a future which succeeds only if all the user SCRAM credential alterations succeed. + */ + public KafkaFuture all() { + return KafkaFuture.allOf(futures.values().toArray(new KafkaFuture[0])); + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/Config.java b/clients/src/main/java/org/apache/kafka/clients/admin/Config.java new file mode 100644 index 0000000..ae7c03a --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/Config.java @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.clients.admin; + +import org.apache.kafka.common.annotation.InterfaceStability; + +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +/** + * A configuration object containing the configuration entries for a resource. + *

+ * The API of this class is evolving, see {@link Admin} for details. + */ +@InterfaceStability.Evolving +public class Config { + + private final Map entries = new HashMap<>(); + + /** + * Create a configuration instance with the provided entries. + */ + public Config(Collection entries) { + for (ConfigEntry entry : entries) { + this.entries.put(entry.name(), entry); + } + } + + /** + * Configuration entries for a resource. + */ + public Collection entries() { + return Collections.unmodifiableCollection(entries.values()); + } + + /** + * Get the configuration entry with the provided name or null if there isn't one. + */ + public ConfigEntry get(String name) { + return entries.get(name); + } + + @Override + public boolean equals(Object o) { + if (this == o) + return true; + if (o == null || getClass() != o.getClass()) + return false; + + Config config = (Config) o; + + return entries.equals(config.entries); + } + + @Override + public int hashCode() { + return entries.hashCode(); + } + + @Override + public String toString() { + return "Config(entries=" + entries.values() + ")"; + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/ConfigEntry.java b/clients/src/main/java/org/apache/kafka/clients/admin/ConfigEntry.java new file mode 100644 index 0000000..30686c9 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/ConfigEntry.java @@ -0,0 +1,290 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.clients.admin; + +import org.apache.kafka.common.annotation.InterfaceStability; + +import java.util.Collections; +import java.util.List; +import java.util.Objects; + +/** + * A class representing a configuration entry containing name, value and additional metadata. + * + * The API of this class is evolving, see {@link Admin} for details. + */ +@InterfaceStability.Evolving +public class ConfigEntry { + + private final String name; + private final String value; + private final ConfigSource source; + private final boolean isSensitive; + private final boolean isReadOnly; + private final List synonyms; + private final ConfigType type; + private final String documentation; + + /** + * Create a configuration entry with the provided values. + * + * @param name the non-null config name + * @param value the config value or null + */ + public ConfigEntry(String name, String value) { + this(name, value, ConfigSource.UNKNOWN, false, false, + Collections.emptyList(), ConfigType.UNKNOWN, null); + } + + /** + * Create a configuration with the provided values. 
+ * + * @param name the non-null config name + * @param value the config value or null + * @param source the source of this config entry + * @param isSensitive whether the config value is sensitive, the broker never returns the value if it is sensitive + * @param isReadOnly whether the config is read-only and cannot be updated + * @param synonyms Synonym configs in order of precedence + */ + ConfigEntry(String name, String value, ConfigSource source, boolean isSensitive, boolean isReadOnly, + List synonyms, ConfigType type, String documentation) { + Objects.requireNonNull(name, "name should not be null"); + this.name = name; + this.value = value; + this.source = source; + this.isSensitive = isSensitive; + this.isReadOnly = isReadOnly; + this.synonyms = synonyms; + this.type = type; + this.documentation = documentation; + } + + /** + * Return the config name. + */ + public String name() { + return name; + } + + /** + * Return the value or null. Null is returned if the config is unset or if isSensitive is true. + */ + public String value() { + return value; + } + + /** + * Return the source of this configuration entry. + */ + public ConfigSource source() { + return source; + } + + /** + * Return whether the config value is the default or if it's been explicitly set. + */ + public boolean isDefault() { + return source == ConfigSource.DEFAULT_CONFIG; + } + + /** + * Return whether the config value is sensitive. The value is always set to null by the broker if the config value + * is sensitive. + */ + public boolean isSensitive() { + return isSensitive; + } + + /** + * Return whether the config is read-only and cannot be updated. + */ + public boolean isReadOnly() { + return isReadOnly; + } + + /** + * Returns all config values that may be used as the value of this config along with their source, + * in the order of precedence. The list starts with the value returned in this ConfigEntry. + * The list is empty if synonyms were not requested using {@link DescribeConfigsOptions#includeSynonyms(boolean)} + */ + public List synonyms() { + return synonyms; + } + + /** + * Return the config data type. + */ + public ConfigType type() { + return type; + } + + /** + * Return the config documentation. + */ + public String documentation() { + return documentation; + } + + @Override + public boolean equals(Object o) { + if (this == o) + return true; + if (o == null || getClass() != o.getClass()) + return false; + + ConfigEntry that = (ConfigEntry) o; + + return this.name.equals(that.name) && + Objects.equals(this.value, that.value) && + this.isSensitive == that.isSensitive && + this.isReadOnly == that.isReadOnly && + Objects.equals(this.source, that.source) && + Objects.equals(this.synonyms, that.synonyms) && + Objects.equals(this.type, that.type) && + Objects.equals(this.documentation, that.documentation); + } + + @Override + public int hashCode() { + final int prime = 31; + int result = 1; + result = prime * result + name.hashCode(); + result = prime * result + Objects.hashCode(value); + result = prime * result + (isSensitive ? 1 : 0); + result = prime * result + (isReadOnly ? 1 : 0); + result = prime * result + Objects.hashCode(source); + result = prime * result + Objects.hashCode(synonyms); + result = prime * result + Objects.hashCode(type); + result = prime * result + Objects.hashCode(documentation); + return result; + } + + /** + * Override toString to redact sensitive value. + * WARNING, user should be responsible to set the correct "isSensitive" field for each config entry. 
+ */ + @Override + public String toString() { + return "ConfigEntry(" + + "name=" + name + + ", value=" + (isSensitive ? "Redacted" : value) + + ", source=" + source + + ", isSensitive=" + isSensitive + + ", isReadOnly=" + isReadOnly + + ", synonyms=" + synonyms + + ", type=" + type + + ", documentation=" + documentation + + ")"; + } + + /** + * Data type of configuration entry. + */ + public enum ConfigType { + UNKNOWN, + BOOLEAN, + STRING, + INT, + SHORT, + LONG, + DOUBLE, + LIST, + CLASS, + PASSWORD + } + + /** + * Source of configuration entries. + */ + public enum ConfigSource { + DYNAMIC_TOPIC_CONFIG, // dynamic topic config that is configured for a specific topic + DYNAMIC_BROKER_LOGGER_CONFIG, // dynamic broker logger config that is configured for a specific broker + DYNAMIC_BROKER_CONFIG, // dynamic broker config that is configured for a specific broker + DYNAMIC_DEFAULT_BROKER_CONFIG, // dynamic broker config that is configured as default for all brokers in the cluster + STATIC_BROKER_CONFIG, // static broker config provided as broker properties at start up (e.g. server.properties file) + DEFAULT_CONFIG, // built-in default configuration for configs that have a default value + UNKNOWN // source unknown e.g. in the ConfigEntry used for alter requests where source is not set + } + + /** + * Class representing a configuration synonym of a {@link ConfigEntry}. + */ + public static class ConfigSynonym { + + private final String name; + private final String value; + private final ConfigSource source; + + /** + * Create a configuration synonym with the provided values. + * + * @param name Configuration name (this may be different from the name of the associated {@link ConfigEntry} + * @param value Configuration value + * @param source {@link ConfigSource} of this configuraton + */ + ConfigSynonym(String name, String value, ConfigSource source) { + this.name = name; + this.value = value; + this.source = source; + } + + /** + * Returns the name of this configuration. + */ + public String name() { + return name; + } + + /** + * Returns the value of this configuration, which may be null if the configuration is sensitive. + */ + public String value() { + return value; + } + + /** + * Returns the source of this configuration. + */ + public ConfigSource source() { + return source; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + ConfigSynonym that = (ConfigSynonym) o; + return Objects.equals(name, that.name) && Objects.equals(value, that.value) && source == that.source; + } + + @Override + public int hashCode() { + return Objects.hash(name, value, source); + } + + @Override + public String toString() { + return "ConfigSynonym(" + + "name=" + name + + ", value=" + value + + ", source=" + source + + ")"; + } + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/ConsumerGroupDescription.java b/clients/src/main/java/org/apache/kafka/clients/admin/ConsumerGroupDescription.java new file mode 100644 index 0000000..3ae9754 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/ConsumerGroupDescription.java @@ -0,0 +1,148 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
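A sketch showing how describeConfigs surfaces the Config and ConfigEntry types above, assuming a broker at localhost:9092 and a hypothetical topic "my-topic".

    import java.util.Collections;
    import java.util.Properties;
    import org.apache.kafka.clients.admin.Admin;
    import org.apache.kafka.clients.admin.AdminClientConfig;
    import org.apache.kafka.clients.admin.Config;
    import org.apache.kafka.clients.admin.ConfigEntry;
    import org.apache.kafka.common.config.ConfigResource;

    public class DumpTopicConfig {
        public static void main(String[] args) throws Exception {
            Properties props = new Properties();
            props.put(AdminClientConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9092"); // illustrative broker address
            try (Admin admin = Admin.create(props)) {
                ConfigResource topic = new ConfigResource(ConfigResource.Type.TOPIC, "my-topic"); // hypothetical topic
                Config config = admin.describeConfigs(Collections.singleton(topic)).all().get().get(topic);
                for (ConfigEntry entry : config.entries()) {
                    // Sensitive values come back as null; isDefault() shows whether the value was overridden.
                    System.out.println(entry.name() + "=" + entry.value()
                        + " (source=" + entry.source() + ", default=" + entry.isDefault() + ")");
                }
            }
        }
    }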
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.clients.admin; + +import org.apache.kafka.common.ConsumerGroupState; +import org.apache.kafka.common.Node; +import org.apache.kafka.common.acl.AclOperation; +import org.apache.kafka.common.utils.Utils; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.Objects; +import java.util.Set; + +/** + * A detailed description of a single consumer group in the cluster. + */ +public class ConsumerGroupDescription { + private final String groupId; + private final boolean isSimpleConsumerGroup; + private final Collection members; + private final String partitionAssignor; + private final ConsumerGroupState state; + private final Node coordinator; + private final Set authorizedOperations; + + public ConsumerGroupDescription(String groupId, + boolean isSimpleConsumerGroup, + Collection members, + String partitionAssignor, + ConsumerGroupState state, + Node coordinator) { + this(groupId, isSimpleConsumerGroup, members, partitionAssignor, state, coordinator, Collections.emptySet()); + } + + public ConsumerGroupDescription(String groupId, + boolean isSimpleConsumerGroup, + Collection members, + String partitionAssignor, + ConsumerGroupState state, + Node coordinator, + Set authorizedOperations) { + this.groupId = groupId == null ? "" : groupId; + this.isSimpleConsumerGroup = isSimpleConsumerGroup; + this.members = members == null ? Collections.emptyList() : + Collections.unmodifiableList(new ArrayList<>(members)); + this.partitionAssignor = partitionAssignor == null ? "" : partitionAssignor; + this.state = state; + this.coordinator = coordinator; + this.authorizedOperations = authorizedOperations; + } + + @Override + public boolean equals(final Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + final ConsumerGroupDescription that = (ConsumerGroupDescription) o; + return isSimpleConsumerGroup == that.isSimpleConsumerGroup && + Objects.equals(groupId, that.groupId) && + Objects.equals(members, that.members) && + Objects.equals(partitionAssignor, that.partitionAssignor) && + state == that.state && + Objects.equals(coordinator, that.coordinator) && + Objects.equals(authorizedOperations, that.authorizedOperations); + } + + @Override + public int hashCode() { + return Objects.hash(groupId, isSimpleConsumerGroup, members, partitionAssignor, state, coordinator, authorizedOperations); + } + + /** + * The id of the consumer group. + */ + public String groupId() { + return groupId; + } + + /** + * If consumer group is simple or not. + */ + public boolean isSimpleConsumerGroup() { + return isSimpleConsumerGroup; + } + + /** + * A list of the members of the consumer group. + */ + public Collection members() { + return members; + } + + /** + * The consumer group partition assignor. 
+ */ + public String partitionAssignor() { + return partitionAssignor; + } + + /** + * The consumer group state, or UNKNOWN if the state is too new for us to parse. + */ + public ConsumerGroupState state() { + return state; + } + + /** + * The consumer group coordinator, or null if the coordinator is not known. + */ + public Node coordinator() { + return coordinator; + } + + /** + * authorizedOperations for this group, or null if that information is not known. + */ + public Set authorizedOperations() { + return authorizedOperations; + } + + @Override + public String toString() { + return "(groupId=" + groupId + + ", isSimpleConsumerGroup=" + isSimpleConsumerGroup + + ", members=" + Utils.join(members, ",") + + ", partitionAssignor=" + partitionAssignor + + ", state=" + state + + ", coordinator=" + coordinator + + ", authorizedOperations=" + authorizedOperations + + ")"; + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/ConsumerGroupListing.java b/clients/src/main/java/org/apache/kafka/clients/admin/ConsumerGroupListing.java new file mode 100644 index 0000000..0abc3e0 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/ConsumerGroupListing.java @@ -0,0 +1,115 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.clients.admin; + +import java.util.Objects; +import java.util.Optional; + +import org.apache.kafka.common.ConsumerGroupState; + +/** + * A listing of a consumer group in the cluster. + */ +public class ConsumerGroupListing { + private final String groupId; + private final boolean isSimpleConsumerGroup; + private final Optional state; + + /** + * Create an instance with the specified parameters. + * + * @param groupId Group Id + * @param isSimpleConsumerGroup If consumer group is simple or not. + */ + public ConsumerGroupListing(String groupId, boolean isSimpleConsumerGroup) { + this(groupId, isSimpleConsumerGroup, Optional.empty()); + } + + /** + * Create an instance with the specified parameters. + * + * @param groupId Group Id + * @param isSimpleConsumerGroup If consumer group is simple or not. + * @param state The state of the consumer group + */ + public ConsumerGroupListing(String groupId, boolean isSimpleConsumerGroup, Optional state) { + this.groupId = groupId; + this.isSimpleConsumerGroup = isSimpleConsumerGroup; + this.state = Objects.requireNonNull(state); + } + + /** + * Consumer Group Id + */ + public String groupId() { + return groupId; + } + + /** + * If Consumer Group is simple or not. 
+ */ + public boolean isSimpleConsumerGroup() { + return isSimpleConsumerGroup; + } + + /** + * Consumer Group state + */ + public Optional state() { + return state; + } + + @Override + public String toString() { + return "(" + + "groupId='" + groupId + '\'' + + ", isSimpleConsumerGroup=" + isSimpleConsumerGroup + + ", state=" + state + + ')'; + } + + @Override + public int hashCode() { + return Objects.hash(groupId, isSimpleConsumerGroup, state); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) + return true; + if (obj == null) + return false; + if (getClass() != obj.getClass()) + return false; + ConsumerGroupListing other = (ConsumerGroupListing) obj; + if (groupId == null) { + if (other.groupId != null) + return false; + } else if (!groupId.equals(other.groupId)) + return false; + if (isSimpleConsumerGroup != other.isSimpleConsumerGroup) + return false; + if (state == null) { + if (other.state != null) + return false; + } else if (!state.equals(other.state)) + return false; + return true; + } + +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/CreateAclsOptions.java b/clients/src/main/java/org/apache/kafka/clients/admin/CreateAclsOptions.java new file mode 100644 index 0000000..ad4ae74 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/CreateAclsOptions.java @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.clients.admin; + +import org.apache.kafka.common.annotation.InterfaceStability; + +import java.util.Collection; + +/** + * Options for {@link Admin#createAcls(Collection)}. + * + * The API of this class is evolving, see {@link Admin} for details. + */ +@InterfaceStability.Evolving +public class CreateAclsOptions extends AbstractOptions { + + /** + * Set the timeout in milliseconds for this operation or {@code null} if the default api timeout for the + * AdminClient should be used. + * + */ + // This method is retained to keep binary compatibility with 0.11 + public CreateAclsOptions timeoutMs(Integer timeoutMs) { + this.timeoutMs = timeoutMs; + return this; + } + +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/CreateAclsResult.java b/clients/src/main/java/org/apache/kafka/clients/admin/CreateAclsResult.java new file mode 100644 index 0000000..6e69554 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/CreateAclsResult.java @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
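// Hedged sketch combining the ConsumerGroupListing and ConsumerGroupDescription classes above:
// list the groups in the cluster, then describe one of them. ListConsumerGroupsResult and
// DescribeConsumerGroupsResult are assumed from the wider Admin API (not shown in this hunk).
import org.apache.kafka.clients.admin.Admin;
import org.apache.kafka.clients.admin.ConsumerGroupDescription;
import org.apache.kafka.clients.admin.ConsumerGroupListing;

import java.util.Collection;
import java.util.Collections;

class ConsumerGroupSketch {
    static void inspectGroups(Admin admin) throws Exception {
        Collection<ConsumerGroupListing> listings = admin.listConsumerGroups().all().get();
        for (ConsumerGroupListing listing : listings) {
            // state() is Optional and empty when the broker did not return a group state.
            System.out.println(listing.groupId() + " state=" + listing.state());
        }
        if (!listings.isEmpty()) {
            String groupId = listings.iterator().next().groupId();
            ConsumerGroupDescription description = admin.describeConsumerGroups(
                    Collections.singleton(groupId)).all().get().get(groupId);
            System.out.println(description.partitionAssignor() + " " + description.state());
        }
    }
}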
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.clients.admin; + +import org.apache.kafka.common.KafkaFuture; +import org.apache.kafka.common.acl.AclBinding; +import org.apache.kafka.common.annotation.InterfaceStability; + +import java.util.Collection; +import java.util.Map; + +/** + * The result of the {@link Admin#createAcls(Collection)} call. + * + * The API of this class is evolving, see {@link Admin} for details. + */ +@InterfaceStability.Evolving +public class CreateAclsResult { + private final Map> futures; + + CreateAclsResult(Map> futures) { + this.futures = futures; + } + + /** + * Return a map from ACL bindings to futures which can be used to check the status of the creation of each ACL + * binding. + */ + public Map> values() { + return futures; + } + + /** + * Return a future which succeeds only if all the ACL creations succeed. + */ + public KafkaFuture all() { + return KafkaFuture.allOf(futures.values().toArray(new KafkaFuture[0])); + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/CreateDelegationTokenOptions.java b/clients/src/main/java/org/apache/kafka/clients/admin/CreateDelegationTokenOptions.java new file mode 100644 index 0000000..6a082d4 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/CreateDelegationTokenOptions.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.clients.admin; + +import java.util.LinkedList; +import java.util.List; + +import org.apache.kafka.common.annotation.InterfaceStability; +import org.apache.kafka.common.security.auth.KafkaPrincipal; + +/** + * Options for {@link Admin#createDelegationToken(CreateDelegationTokenOptions)}. + * + * The API of this class is evolving, see {@link Admin} for details. 
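// Hedged sketch of Admin#createAcls(Collection) using the CreateAclsResult above. The ACL value
// classes (AclBinding, AccessControlEntry, ResourcePattern) come from org.apache.kafka.common;
// the principal and topic name are illustrative only.
import org.apache.kafka.clients.admin.Admin;
import org.apache.kafka.common.acl.AccessControlEntry;
import org.apache.kafka.common.acl.AclBinding;
import org.apache.kafka.common.acl.AclOperation;
import org.apache.kafka.common.acl.AclPermissionType;
import org.apache.kafka.common.resource.PatternType;
import org.apache.kafka.common.resource.ResourcePattern;
import org.apache.kafka.common.resource.ResourceType;

import java.util.Collections;

class CreateAclsSketch {
    static void allowRead(Admin admin) throws Exception {
        AclBinding binding = new AclBinding(
                new ResourcePattern(ResourceType.TOPIC, "orders", PatternType.LITERAL),
                new AccessControlEntry("User:alice", "*", AclOperation.READ, AclPermissionType.ALLOW));
        // all() only completes successfully once every requested binding has been created.
        admin.createAcls(Collections.singleton(binding)).all().get();
    }
}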
+ */ +@InterfaceStability.Evolving +public class CreateDelegationTokenOptions extends AbstractOptions { + private long maxLifeTimeMs = -1; + private List renewers = new LinkedList<>(); + + public CreateDelegationTokenOptions renewers(List renewers) { + this.renewers = renewers; + return this; + } + + public List renewers() { + return renewers; + } + + public CreateDelegationTokenOptions maxlifeTimeMs(long maxLifeTimeMs) { + this.maxLifeTimeMs = maxLifeTimeMs; + return this; + } + + public long maxlifeTimeMs() { + return maxLifeTimeMs; + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/CreateDelegationTokenResult.java b/clients/src/main/java/org/apache/kafka/clients/admin/CreateDelegationTokenResult.java new file mode 100644 index 0000000..7aa4804 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/CreateDelegationTokenResult.java @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.clients.admin; + +import org.apache.kafka.common.KafkaFuture; +import org.apache.kafka.common.annotation.InterfaceStability; +import org.apache.kafka.common.security.token.delegation.DelegationToken; + +/** + * The result of the {@link KafkaAdminClient#createDelegationToken(CreateDelegationTokenOptions)} call. + * + * The API of this class is evolving, see {@link Admin} for details. + */ +@InterfaceStability.Evolving +public class CreateDelegationTokenResult { + private final KafkaFuture delegationToken; + + CreateDelegationTokenResult(KafkaFuture delegationToken) { + this.delegationToken = delegationToken; + } + + /** + * Returns a future which yields a delegation token + */ + public KafkaFuture delegationToken() { + return delegationToken; + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/CreatePartitionsOptions.java b/clients/src/main/java/org/apache/kafka/clients/admin/CreatePartitionsOptions.java new file mode 100644 index 0000000..5a183bd --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/CreatePartitionsOptions.java @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
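// Hedged sketch pairing CreateDelegationTokenOptions with CreateDelegationTokenResult as shown
// above. The renewer principal and the 24-hour lifetime are illustrative values.
import org.apache.kafka.clients.admin.Admin;
import org.apache.kafka.clients.admin.CreateDelegationTokenOptions;
import org.apache.kafka.common.security.auth.KafkaPrincipal;
import org.apache.kafka.common.security.token.delegation.DelegationToken;

import java.util.Collections;
import java.util.concurrent.TimeUnit;

class CreateDelegationTokenSketch {
    static DelegationToken createToken(Admin admin) throws Exception {
        CreateDelegationTokenOptions options = new CreateDelegationTokenOptions()
                .renewers(Collections.singletonList(new KafkaPrincipal(KafkaPrincipal.USER_TYPE, "bob")))
                .maxlifeTimeMs(TimeUnit.HOURS.toMillis(24)); // note the lower-case "l" in the setter name
        return admin.createDelegationToken(options).delegationToken().get();
    }
}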
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.clients.admin; + +import org.apache.kafka.common.annotation.InterfaceStability; + +import java.util.Map; + +/** + * Options for {@link Admin#createPartitions(Map)}. + * + * The API of this class is evolving, see {@link Admin} for details. + */ +@InterfaceStability.Evolving +public class CreatePartitionsOptions extends AbstractOptions { + + private boolean validateOnly = false; + private boolean retryOnQuotaViolation = true; + + public CreatePartitionsOptions() { + } + + /** + * Return true if the request should be validated without creating new partitions. + */ + public boolean validateOnly() { + return validateOnly; + } + + /** + * Set to true if the request should be validated without creating new partitions. + */ + public CreatePartitionsOptions validateOnly(boolean validateOnly) { + this.validateOnly = validateOnly; + return this; + } + + /** + * Set to true if quota violation should be automatically retried. + */ + public CreatePartitionsOptions retryOnQuotaViolation(boolean retryOnQuotaViolation) { + this.retryOnQuotaViolation = retryOnQuotaViolation; + return this; + } + + /** + * Returns true if quota violation should be automatically retried. + */ + public boolean shouldRetryOnQuotaViolation() { + return retryOnQuotaViolation; + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/CreatePartitionsResult.java b/clients/src/main/java/org/apache/kafka/clients/admin/CreatePartitionsResult.java new file mode 100644 index 0000000..8b864b6 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/CreatePartitionsResult.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.clients.admin; + +import org.apache.kafka.common.KafkaFuture; +import org.apache.kafka.common.annotation.InterfaceStability; + +import java.util.Map; + +/** + * The result of the {@link Admin#createPartitions(Map)} call. + * + * The API of this class is evolving, see {@link Admin} for details. + */ +@InterfaceStability.Evolving +public class CreatePartitionsResult { + + private final Map> values; + + CreatePartitionsResult(Map> values) { + this.values = values; + } + + /** + * Return a map from topic names to futures, which can be used to check the status of individual + * partition creations. + */ + public Map> values() { + return values; + } + + /** + * Return a future which succeeds if all the partition creations succeed. 
+ */ + public KafkaFuture all() { + return KafkaFuture.allOf(values.values().toArray(new KafkaFuture[0])); + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/CreateTopicsOptions.java b/clients/src/main/java/org/apache/kafka/clients/admin/CreateTopicsOptions.java new file mode 100644 index 0000000..c897f03 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/CreateTopicsOptions.java @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.clients.admin; + +import org.apache.kafka.common.annotation.InterfaceStability; + +import java.util.Collection; + +/** + * Options for {@link Admin#createTopics(Collection)}. + * + * The API of this class is evolving, see {@link Admin} for details. + */ +@InterfaceStability.Evolving +public class CreateTopicsOptions extends AbstractOptions { + + private boolean validateOnly = false; + private boolean retryOnQuotaViolation = true; + + /** + * Set the timeout in milliseconds for this operation or {@code null} if the default api timeout for the + * AdminClient should be used. + * + */ + // This method is retained to keep binary compatibility with 0.11 + public CreateTopicsOptions timeoutMs(Integer timeoutMs) { + this.timeoutMs = timeoutMs; + return this; + } + + /** + * Set to true if the request should be validated without creating the topic. + */ + public CreateTopicsOptions validateOnly(boolean validateOnly) { + this.validateOnly = validateOnly; + return this; + } + + /** + * Return true if the request should be validated without creating the topic. + */ + public boolean shouldValidateOnly() { + return validateOnly; + } + + + /** + * Set to true if quota violation should be automatically retried. + */ + public CreateTopicsOptions retryOnQuotaViolation(boolean retryOnQuotaViolation) { + this.retryOnQuotaViolation = retryOnQuotaViolation; + return this; + } + + /** + * Returns true if quota violation should be automatically retried. + */ + public boolean shouldRetryOnQuotaViolation() { + return retryOnQuotaViolation; + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/CreateTopicsResult.java b/clients/src/main/java/org/apache/kafka/clients/admin/CreateTopicsResult.java new file mode 100644 index 0000000..100e996 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/CreateTopicsResult.java @@ -0,0 +1,159 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
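// Hedged sketch of Admin#createPartitions(Map) using the CreatePartitionsOptions and
// CreatePartitionsResult classes above. NewPartitions is assumed from the wider Admin API;
// the topic name and target partition count are illustrative.
import org.apache.kafka.clients.admin.Admin;
import org.apache.kafka.clients.admin.CreatePartitionsOptions;
import org.apache.kafka.clients.admin.NewPartitions;

import java.util.Collections;

class CreatePartitionsSketch {
    static void growTopic(Admin admin) throws Exception {
        // Dry run first: validateOnly(true) checks the request without changing the topic.
        CreatePartitionsOptions dryRun = new CreatePartitionsOptions().validateOnly(true);
        admin.createPartitions(Collections.singletonMap("orders", NewPartitions.increaseTo(6)), dryRun)
                .all().get();
        // Then apply the change for real.
        admin.createPartitions(Collections.singletonMap("orders", NewPartitions.increaseTo(6)))
                .all().get();
    }
}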
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients.admin; + +import org.apache.kafka.common.KafkaFuture; +import org.apache.kafka.common.Uuid; +import org.apache.kafka.common.annotation.InterfaceStability; +import org.apache.kafka.common.errors.ApiException; + +import java.util.Collection; +import java.util.Map; +import java.util.stream.Collectors; + +/** + * The result of {@link Admin#createTopics(Collection)}. + * + * The API of this class is evolving, see {@link Admin} for details. + */ +@InterfaceStability.Evolving +public class CreateTopicsResult { + final static int UNKNOWN = -1; + + private final Map> futures; + + protected CreateTopicsResult(Map> futures) { + this.futures = futures; + } + + /** + * Return a map from topic names to futures, which can be used to check the status of individual + * topic creations. + */ + public Map> values() { + return futures.entrySet().stream() + .collect(Collectors.toMap(Map.Entry::getKey, e -> e.getValue().thenApply(v -> (Void) null))); + } + + /** + * Return a future which succeeds if all the topic creations succeed. + */ + public KafkaFuture all() { + return KafkaFuture.allOf(futures.values().toArray(new KafkaFuture[0])); + } + + /** + * Returns a future that provides topic configs for the topic when the request completes. + *
<p>
+ * If broker version doesn't support replication factor in the response, throw + * {@link org.apache.kafka.common.errors.UnsupportedVersionException}. + * If broker returned an error for topic configs, throw appropriate exception. For example, + * {@link org.apache.kafka.common.errors.TopicAuthorizationException} is thrown if user does not + * have permission to describe topic configs. + */ + public KafkaFuture config(String topic) { + return futures.get(topic).thenApply(TopicMetadataAndConfig::config); + } + + /** + * Returns a future that provides topic ID for the topic when the request completes. + *
<p>
+ * If broker version doesn't support replication factor in the response, throw + * {@link org.apache.kafka.common.errors.UnsupportedVersionException}. + * If broker returned an error for topic configs, throw appropriate exception. For example, + * {@link org.apache.kafka.common.errors.TopicAuthorizationException} is thrown if user does not + * have permission to describe topic configs. + */ + public KafkaFuture topicId(String topic) { + return futures.get(topic).thenApply(TopicMetadataAndConfig::topicId); + } + + /** + * Returns a future that provides number of partitions in the topic when the request completes. + *
<p>
+ * If broker version doesn't support replication factor in the response, throw + * {@link org.apache.kafka.common.errors.UnsupportedVersionException}. + * If broker returned an error for topic configs, throw appropriate exception. For example, + * {@link org.apache.kafka.common.errors.TopicAuthorizationException} is thrown if user does not + * have permission to describe topic configs. + */ + public KafkaFuture numPartitions(String topic) { + return futures.get(topic).thenApply(TopicMetadataAndConfig::numPartitions); + } + + /** + * Returns a future that provides replication factor for the topic when the request completes. + *
<p>
+ * If broker version doesn't support replication factor in the response, throw + * {@link org.apache.kafka.common.errors.UnsupportedVersionException}. + * If broker returned an error for topic configs, throw appropriate exception. For example, + * {@link org.apache.kafka.common.errors.TopicAuthorizationException} is thrown if user does not + * have permission to describe topic configs. + */ + public KafkaFuture replicationFactor(String topic) { + return futures.get(topic).thenApply(TopicMetadataAndConfig::replicationFactor); + } + + public static class TopicMetadataAndConfig { + private final ApiException exception; + private final Uuid topicId; + private final int numPartitions; + private final int replicationFactor; + private final Config config; + + public TopicMetadataAndConfig(Uuid topicId, int numPartitions, int replicationFactor, Config config) { + this.exception = null; + this.topicId = topicId; + this.numPartitions = numPartitions; + this.replicationFactor = replicationFactor; + this.config = config; + } + + public TopicMetadataAndConfig(ApiException exception) { + this.exception = exception; + this.topicId = Uuid.ZERO_UUID; + this.numPartitions = UNKNOWN; + this.replicationFactor = UNKNOWN; + this.config = null; + } + + public Uuid topicId() { + ensureSuccess(); + return topicId; + } + + public int numPartitions() { + ensureSuccess(); + return numPartitions; + } + + public int replicationFactor() { + ensureSuccess(); + return replicationFactor; + } + + public Config config() { + ensureSuccess(); + return config; + } + + private void ensureSuccess() { + if (exception != null) + throw exception; + } + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/DeleteAclsOptions.java b/clients/src/main/java/org/apache/kafka/clients/admin/DeleteAclsOptions.java new file mode 100644 index 0000000..7c250e1 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/DeleteAclsOptions.java @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.clients.admin; + +import org.apache.kafka.common.annotation.InterfaceStability; + +import java.util.Collection; + +/** + * Options for the {@link Admin#deleteAcls(Collection)} call. + * + * The API of this class is evolving, see {@link Admin} for details. + */ +@InterfaceStability.Evolving +public class DeleteAclsOptions extends AbstractOptions { + + /** + * Set the timeout in milliseconds for this operation or {@code null} if the default api timeout for the + * AdminClient should be used. 
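// Hedged sketch of Admin#createTopics(Collection) reading back the metadata exposed by the
// CreateTopicsResult above (topicId, numPartitions, replicationFactor). NewTopic is assumed from
// the wider Admin API; the topic name and sizing are illustrative.
import org.apache.kafka.clients.admin.Admin;
import org.apache.kafka.clients.admin.CreateTopicsResult;
import org.apache.kafka.clients.admin.NewTopic;

import java.util.Collections;

class CreateTopicsSketch {
    static void createAndInspect(Admin admin) throws Exception {
        CreateTopicsResult result = admin.createTopics(
                Collections.singleton(new NewTopic("events", 3, (short) 3)));
        result.all().get(); // wait for the creation itself
        // These futures fail with UnsupportedVersionException on brokers that do not return the data.
        System.out.println("id=" + result.topicId("events").get()
                + " partitions=" + result.numPartitions("events").get()
                + " rf=" + result.replicationFactor("events").get());
    }
}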
+ * + */ + // This method is retained to keep binary compatibility with 0.11 + public DeleteAclsOptions timeoutMs(Integer timeoutMs) { + this.timeoutMs = timeoutMs; + return this; + } + +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/DeleteAclsResult.java b/clients/src/main/java/org/apache/kafka/clients/admin/DeleteAclsResult.java new file mode 100644 index 0000000..391b9d1 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/DeleteAclsResult.java @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.clients.admin; + +import org.apache.kafka.common.KafkaException; +import org.apache.kafka.common.KafkaFuture; +import org.apache.kafka.common.acl.AclBinding; +import org.apache.kafka.common.acl.AclBindingFilter; +import org.apache.kafka.common.annotation.InterfaceStability; +import org.apache.kafka.common.errors.ApiException; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; +import java.util.Map; + +/** + * The result of the {@link Admin#deleteAcls(Collection)} call. + * + * The API of this class is evolving, see {@link Admin} for details. + */ +@InterfaceStability.Evolving +public class DeleteAclsResult { + + /** + * A class containing either the deleted ACL binding or an exception if the delete failed. + */ + public static class FilterResult { + private final AclBinding binding; + private final ApiException exception; + + FilterResult(AclBinding binding, ApiException exception) { + this.binding = binding; + this.exception = exception; + } + + /** + * Return the deleted ACL binding or null if there was an error. + */ + public AclBinding binding() { + return binding; + } + + /** + * Return an exception if the ACL delete was not successful or null if it was. + */ + public ApiException exception() { + return exception; + } + } + + /** + * A class containing the results of the delete ACLs operation. + */ + public static class FilterResults { + private final List values; + + FilterResults(List values) { + this.values = values; + } + + /** + * Return a list of delete ACLs results for a given filter. + */ + public List values() { + return values; + } + } + + private final Map> futures; + + DeleteAclsResult(Map> futures) { + this.futures = futures; + } + + /** + * Return a map from acl filters to futures which can be used to check the status of the deletions by each + * filter. + */ + public Map> values() { + return futures; + } + + /** + * Return a future which succeeds only if all the ACLs deletions succeed, and which contains all the deleted ACLs. + * Note that it if the filters don't match any ACLs, this is not considered an error. 
+ */ + public KafkaFuture> all() { + return KafkaFuture.allOf(futures.values().toArray(new KafkaFuture[0])).thenApply(v -> getAclBindings(futures)); + } + + private List getAclBindings(Map> futures) { + List acls = new ArrayList<>(); + for (KafkaFuture value: futures.values()) { + FilterResults results; + try { + results = value.get(); + } catch (Throwable e) { + // This should be unreachable, since the future returned by KafkaFuture#allOf should + // have failed if any Future failed. + throw new KafkaException("DeleteAclsResult#all: internal error", e); + } + for (FilterResult result : results.values()) { + if (result.exception() != null) + throw result.exception(); + acls.add(result.binding()); + } + } + return acls; + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/DeleteConsumerGroupOffsetsOptions.java b/clients/src/main/java/org/apache/kafka/clients/admin/DeleteConsumerGroupOffsetsOptions.java new file mode 100644 index 0000000..63e6b4b --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/DeleteConsumerGroupOffsetsOptions.java @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients.admin; + +import java.util.Set; +import org.apache.kafka.common.annotation.InterfaceStability; + +/** + * Options for the {@link Admin#deleteConsumerGroupOffsets(String, Set)} call. + * + * The API of this class is evolving, see {@link Admin} for details. + */ +@InterfaceStability.Evolving +public class DeleteConsumerGroupOffsetsOptions extends AbstractOptions { + +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/DeleteConsumerGroupOffsetsResult.java b/clients/src/main/java/org/apache/kafka/clients/admin/DeleteConsumerGroupOffsetsResult.java new file mode 100644 index 0000000..336e9c0 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/DeleteConsumerGroupOffsetsResult.java @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
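// Hedged sketch of Admin#deleteAcls(Collection) using the DeleteAclsResult above. Deleting with
// AclBindingFilter.ANY removes every ACL the caller is authorized to delete, so treat this as an
// illustration rather than something to run against a shared cluster.
import org.apache.kafka.clients.admin.Admin;
import org.apache.kafka.common.acl.AclBinding;
import org.apache.kafka.common.acl.AclBindingFilter;

import java.util.Collection;
import java.util.Collections;

class DeleteAclsSketch {
    static Collection<AclBinding> deleteAll(Admin admin) throws Exception {
        // all() yields the bindings that were actually deleted; a filter matching nothing is not an error.
        return admin.deleteAcls(Collections.singleton(AclBindingFilter.ANY)).all().get();
    }
}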
+ */ +package org.apache.kafka.clients.admin; + +import java.util.Set; +import org.apache.kafka.common.KafkaFuture; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.annotation.InterfaceStability; + +import java.util.Map; +import org.apache.kafka.common.internals.KafkaFutureImpl; +import org.apache.kafka.common.protocol.Errors; + +/** + * The result of the {@link Admin#deleteConsumerGroupOffsets(String, Set)} call. + * + * The API of this class is evolving, see {@link Admin} for details. + */ +@InterfaceStability.Evolving +public class DeleteConsumerGroupOffsetsResult { + private final KafkaFuture> future; + private final Set partitions; + + + DeleteConsumerGroupOffsetsResult(KafkaFuture> future, Set partitions) { + this.future = future; + this.partitions = partitions; + } + + /** + * Return a future which can be used to check the result for a given partition. + */ + public KafkaFuture partitionResult(final TopicPartition partition) { + if (!partitions.contains(partition)) { + throw new IllegalArgumentException("Partition " + partition + " was not included in the original request"); + } + final KafkaFutureImpl result = new KafkaFutureImpl<>(); + + this.future.whenComplete((topicPartitions, throwable) -> { + if (throwable != null) { + result.completeExceptionally(throwable); + } else if (!maybeCompleteExceptionally(topicPartitions, partition, result)) { + result.complete(null); + } + }); + return result; + } + + /** + * Return a future which succeeds only if all the deletions succeed. + * If not, the first partition error shall be returned. + */ + public KafkaFuture all() { + final KafkaFutureImpl result = new KafkaFutureImpl<>(); + + this.future.whenComplete((topicPartitions, throwable) -> { + if (throwable != null) { + result.completeExceptionally(throwable); + } else { + for (TopicPartition partition : partitions) { + if (maybeCompleteExceptionally(topicPartitions, partition, result)) { + return; + } + } + result.complete(null); + } + }); + return result; + } + + private boolean maybeCompleteExceptionally(Map partitionLevelErrors, + TopicPartition partition, + KafkaFutureImpl result) { + Throwable exception = KafkaAdminClient.getSubLevelError(partitionLevelErrors, partition, + "Offset deletion result for partition \"" + partition + "\" was not included in the response"); + if (exception != null) { + result.completeExceptionally(exception); + return true; + } else { + return false; + } + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/DeleteConsumerGroupsOptions.java b/clients/src/main/java/org/apache/kafka/clients/admin/DeleteConsumerGroupsOptions.java new file mode 100644 index 0000000..081aeab --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/DeleteConsumerGroupsOptions.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
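// Hedged sketch of Admin#deleteConsumerGroupOffsets(String, Set) using the
// DeleteConsumerGroupOffsetsResult above. The group and topic names are illustrative.
import org.apache.kafka.clients.admin.Admin;
import org.apache.kafka.clients.admin.DeleteConsumerGroupOffsetsResult;
import org.apache.kafka.common.TopicPartition;

import java.util.Collections;

class DeleteGroupOffsetsSketch {
    static void deleteOffsets(Admin admin) throws Exception {
        TopicPartition partition = new TopicPartition("orders", 0);
        DeleteConsumerGroupOffsetsResult result =
                admin.deleteConsumerGroupOffsets("my-group", Collections.singleton(partition));
        result.partitionResult(partition).get(); // surfaces the per-partition error, if any
    }
}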
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients.admin; + +import org.apache.kafka.common.annotation.InterfaceStability; + +import java.util.Collection; + +/** + * Options for the {@link Admin#deleteConsumerGroups(Collection)} call. + * + * The API of this class is evolving, see {@link Admin} for details. + */ +@InterfaceStability.Evolving +public class DeleteConsumerGroupsOptions extends AbstractOptions { + +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/DeleteConsumerGroupsResult.java b/clients/src/main/java/org/apache/kafka/clients/admin/DeleteConsumerGroupsResult.java new file mode 100644 index 0000000..90ddbd0 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/DeleteConsumerGroupsResult.java @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients.admin; + +import org.apache.kafka.common.KafkaFuture; +import org.apache.kafka.common.annotation.InterfaceStability; + +import java.util.Collection; +import java.util.HashMap; +import java.util.Map; + +/** + * The result of the {@link Admin#deleteConsumerGroups(Collection)} call. + * + * The API of this class is evolving, see {@link Admin} for details. + */ +@InterfaceStability.Evolving +public class DeleteConsumerGroupsResult { + private final Map> futures; + + DeleteConsumerGroupsResult(final Map> futures) { + this.futures = futures; + } + + /** + * Return a map from group id to futures which can be used to check the status of + * individual deletions. + */ + public Map> deletedGroups() { + Map> deletedGroups = new HashMap<>(futures.size()); + futures.forEach((key, future) -> deletedGroups.put(key, future)); + return deletedGroups; + } + + /** + * Return a future which succeeds only if all the consumer group deletions succeed. + */ + public KafkaFuture all() { + return KafkaFuture.allOf(futures.values().toArray(new KafkaFuture[0])); + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/DeleteRecordsOptions.java b/clients/src/main/java/org/apache/kafka/clients/admin/DeleteRecordsOptions.java new file mode 100644 index 0000000..34af759 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/DeleteRecordsOptions.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
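// Hedged sketch of Admin#deleteConsumerGroups(Collection) using the DeleteConsumerGroupsResult
// above. The group id is illustrative; deletion fails for groups that still have active members.
import org.apache.kafka.clients.admin.Admin;

import java.util.Collections;

class DeleteConsumerGroupsSketch {
    static void deleteGroup(Admin admin) throws Exception {
        admin.deleteConsumerGroups(Collections.singleton("my-group"))
                .deletedGroups()
                .get("my-group")
                .get(); // completes exceptionally if this particular deletion failed
    }
}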
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.clients.admin; + +import org.apache.kafka.common.annotation.InterfaceStability; + +import java.util.Map; + +/** + * Options for {@link Admin#deleteRecords(Map, DeleteRecordsOptions)}. + * + * The API of this class is evolving, see {@link Admin} for details. + */ +@InterfaceStability.Evolving +public class DeleteRecordsOptions extends AbstractOptions { + +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/DeleteRecordsResult.java b/clients/src/main/java/org/apache/kafka/clients/admin/DeleteRecordsResult.java new file mode 100644 index 0000000..0196632 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/DeleteRecordsResult.java @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.clients.admin; + +import org.apache.kafka.common.KafkaFuture; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.annotation.InterfaceStability; + +import java.util.Map; + +/** + * The result of the {@link Admin#deleteRecords(Map)} call. + * + * The API of this class is evolving, see {@link Admin} for details. + */ +@InterfaceStability.Evolving +public class DeleteRecordsResult { + + private final Map> futures; + + public DeleteRecordsResult(Map> futures) { + this.futures = futures; + } + + /** + * Return a map from topic partition to futures which can be used to check the status of + * individual deletions. + */ + public Map> lowWatermarks() { + return futures; + } + + /** + * Return a future which succeeds only if all the records deletions succeed. + */ + public KafkaFuture all() { + return KafkaFuture.allOf(futures.values().toArray(new KafkaFuture[0])); + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/DeleteTopicsOptions.java b/clients/src/main/java/org/apache/kafka/clients/admin/DeleteTopicsOptions.java new file mode 100644 index 0000000..2711aff --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/DeleteTopicsOptions.java @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
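// Hedged sketch of Admin#deleteRecords(Map) using the DeleteRecordsResult above. RecordsToDelete
// is assumed from the wider Admin API, and the DeletedRecords value type appears a little further
// below; the topic name and offset are illustrative.
import org.apache.kafka.clients.admin.Admin;
import org.apache.kafka.clients.admin.DeletedRecords;
import org.apache.kafka.clients.admin.RecordsToDelete;
import org.apache.kafka.common.TopicPartition;

import java.util.Collections;

class DeleteRecordsSketch {
    static long truncateTo(Admin admin, long offset) throws Exception {
        TopicPartition partition = new TopicPartition("orders", 0);
        DeletedRecords deleted = admin.deleteRecords(
                        Collections.singletonMap(partition, RecordsToDelete.beforeOffset(offset)))
                .lowWatermarks()
                .get(partition)
                .get();
        return deleted.lowWatermark(); // new start offset of the partition after deletion
    }
}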
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.clients.admin; + +import org.apache.kafka.common.annotation.InterfaceStability; + +import java.util.Collection; + +/** + * Options for {@link Admin#deleteTopics(Collection)}. + * + * The API of this class is evolving, see {@link Admin} for details. + */ +@InterfaceStability.Evolving +public class DeleteTopicsOptions extends AbstractOptions { + + private boolean retryOnQuotaViolation = true; + + /** + * Set the timeout in milliseconds for this operation or {@code null} if the default api timeout for the + * AdminClient should be used. + * + */ + // This method is retained to keep binary compatibility with 0.11 + public DeleteTopicsOptions timeoutMs(Integer timeoutMs) { + this.timeoutMs = timeoutMs; + return this; + } + + /** + * Set to true if quota violation should be automatically retried. + */ + public DeleteTopicsOptions retryOnQuotaViolation(boolean retryOnQuotaViolation) { + this.retryOnQuotaViolation = retryOnQuotaViolation; + return this; + } + + /** + * Returns true if quota violation should be automatically retried. + */ + public boolean shouldRetryOnQuotaViolation() { + return retryOnQuotaViolation; + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/DeleteTopicsResult.java b/clients/src/main/java/org/apache/kafka/clients/admin/DeleteTopicsResult.java new file mode 100644 index 0000000..725b82a --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/DeleteTopicsResult.java @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.clients.admin; + +import org.apache.kafka.common.KafkaFuture; +import org.apache.kafka.common.TopicCollection; +import org.apache.kafka.common.Uuid; +import org.apache.kafka.common.annotation.InterfaceStability; + +import java.util.Collection; +import java.util.Map; + +/** + * The result of the {@link Admin#deleteTopics(Collection)} call. + * + * The API of this class is evolving, see {@link Admin} for details. 
+ */ +@InterfaceStability.Evolving +public class DeleteTopicsResult { + private final Map> topicIdFutures; + private final Map> nameFutures; + + protected DeleteTopicsResult(Map> topicIdFutures, Map> nameFutures) { + if (topicIdFutures != null && nameFutures != null) + throw new IllegalArgumentException("topicIdFutures and nameFutures cannot both be specified."); + if (topicIdFutures == null && nameFutures == null) + throw new IllegalArgumentException("topicIdFutures and nameFutures cannot both be null."); + this.topicIdFutures = topicIdFutures; + this.nameFutures = nameFutures; + } + + static DeleteTopicsResult ofTopicIds(Map> topicIdFutures) { + return new DeleteTopicsResult(topicIdFutures, null); + } + + static DeleteTopicsResult ofTopicNames(Map> nameFutures) { + return new DeleteTopicsResult(null, nameFutures); + } + + /** + * Use when {@link Admin#deleteTopics(TopicCollection, DeleteTopicsOptions)} used a TopicIdCollection + * @return a map from topic IDs to futures which can be used to check the status of + * individual deletions if the deleteTopics request used topic IDs. Otherwise return null. + */ + public Map> topicIdValues() { + return topicIdFutures; + } + + /** + * Use when {@link Admin#deleteTopics(TopicCollection, DeleteTopicsOptions)} used a TopicNameCollection + * @return a map from topic names to futures which can be used to check the status of + * individual deletions if the deleteTopics request used topic names. Otherwise return null. + */ + public Map> topicNameValues() { + return nameFutures; + } + + /** + * @return a map from topic names to futures which can be used to check the status of + * individual deletions if the deleteTopics request used topic names. Otherwise return null. + * @deprecated Since 3.0 use {@link #topicNameValues} instead + */ + @Deprecated + public Map> values() { + return nameFutures; + } + + /** + * @return a future which succeeds only if all the topic deletions succeed. + */ + public KafkaFuture all() { + return (topicIdFutures == null) ? KafkaFuture.allOf(nameFutures.values().toArray(new KafkaFuture[0])) : + KafkaFuture.allOf(topicIdFutures.values().toArray(new KafkaFuture[0])); + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/DeletedRecords.java b/clients/src/main/java/org/apache/kafka/clients/admin/DeletedRecords.java new file mode 100644 index 0000000..983ae28 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/DeletedRecords.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
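// Hedged sketch of Admin#deleteTopics(TopicCollection, DeleteTopicsOptions) using the
// DeleteTopicsResult above, once with topic names and once with topic IDs. The Uuid parameter
// is illustrative.
import org.apache.kafka.clients.admin.Admin;
import org.apache.kafka.clients.admin.DeleteTopicsOptions;
import org.apache.kafka.clients.admin.DeleteTopicsResult;
import org.apache.kafka.common.TopicCollection;
import org.apache.kafka.common.Uuid;

import java.util.Collections;

class DeleteTopicsSketch {
    static void deleteByNameAndId(Admin admin, Uuid topicId) throws Exception {
        // Name-based deletion: per-topic futures come from topicNameValues().
        DeleteTopicsResult byName = admin.deleteTopics(
                TopicCollection.ofTopicNames(Collections.singleton("events")), new DeleteTopicsOptions());
        byName.topicNameValues().get("events").get();

        // Id-based deletion: per-topic futures come from topicIdValues() instead.
        DeleteTopicsResult byId = admin.deleteTopics(
                TopicCollection.ofTopicIds(Collections.singleton(topicId)), new DeleteTopicsOptions());
        byId.topicIdValues().get(topicId).get();
    }
}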
+ */ + +package org.apache.kafka.clients.admin; + +import org.apache.kafka.common.annotation.InterfaceStability; + +/** + * Represents information about deleted records + * + * The API for this class is still evolving and we may break compatibility in minor releases, if necessary. + */ +@InterfaceStability.Evolving +public class DeletedRecords { + + private final long lowWatermark; + + /** + * Create an instance of this class with the provided parameters. + * + * @param lowWatermark "low watermark" for the topic partition on which the deletion was executed + */ + public DeletedRecords(long lowWatermark) { + this.lowWatermark = lowWatermark; + } + + /** + * Return the "low watermark" for the topic partition on which the deletion was executed + */ + public long lowWatermark() { + return lowWatermark; + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/DescribeAclsOptions.java b/clients/src/main/java/org/apache/kafka/clients/admin/DescribeAclsOptions.java new file mode 100644 index 0000000..e44d584 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/DescribeAclsOptions.java @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.clients.admin; + +import org.apache.kafka.common.acl.AclBindingFilter; +import org.apache.kafka.common.annotation.InterfaceStability; + +/** + * Options for {@link Admin#describeAcls(AclBindingFilter)}. + * + * The API of this class is evolving, see {@link Admin} for details. + */ +@InterfaceStability.Evolving +public class DescribeAclsOptions extends AbstractOptions { + + /** + * Set the timeout in milliseconds for this operation or {@code null} if the default api timeout for the + * AdminClient should be used. + * + */ + // This method is retained to keep binary compatibility with 0.11 + public DescribeAclsOptions timeoutMs(Integer timeoutMs) { + this.timeoutMs = timeoutMs; + return this; + } + +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/DescribeAclsResult.java b/clients/src/main/java/org/apache/kafka/clients/admin/DescribeAclsResult.java new file mode 100644 index 0000000..fb16222 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/DescribeAclsResult.java @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.clients.admin; + +import org.apache.kafka.common.KafkaFuture; +import org.apache.kafka.common.acl.AclBinding; +import org.apache.kafka.common.acl.AclBindingFilter; +import org.apache.kafka.common.annotation.InterfaceStability; + +import java.util.Collection; + +/** + * The result of the {@link KafkaAdminClient#describeAcls(AclBindingFilter)} call. + * + * The API of this class is evolving, see {@link Admin} for details. + */ +@InterfaceStability.Evolving +public class DescribeAclsResult { + private final KafkaFuture> future; + + DescribeAclsResult(KafkaFuture> future) { + this.future = future; + } + + /** + * Return a future containing the ACLs requested. + */ + public KafkaFuture> values() { + return future; + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/DescribeClientQuotasOptions.java b/clients/src/main/java/org/apache/kafka/clients/admin/DescribeClientQuotasOptions.java new file mode 100644 index 0000000..c3bdc7b --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/DescribeClientQuotasOptions.java @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.clients.admin; + +import org.apache.kafka.common.annotation.InterfaceStability; +import org.apache.kafka.common.quota.ClientQuotaFilter; + +/** + * Options for {@link Admin#describeClientQuotas(ClientQuotaFilter, DescribeClientQuotasOptions)}. + * + * The API of this class is evolving, see {@link Admin} for details. + */ +@InterfaceStability.Evolving +public class DescribeClientQuotasOptions extends AbstractOptions { +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/DescribeClientQuotasResult.java b/clients/src/main/java/org/apache/kafka/clients/admin/DescribeClientQuotasResult.java new file mode 100644 index 0000000..0e41bc7 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/DescribeClientQuotasResult.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
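// Hedged sketch of Admin#describeAcls(AclBindingFilter) using the DescribeAclsResult above.
import org.apache.kafka.clients.admin.Admin;
import org.apache.kafka.common.acl.AclBinding;
import org.apache.kafka.common.acl.AclBindingFilter;

import java.util.Collection;

class DescribeAclsSketch {
    static Collection<AclBinding> listAcls(Admin admin) throws Exception {
        // ANY matches every ACL the caller is authorized to describe.
        return admin.describeAcls(AclBindingFilter.ANY).values().get();
    }
}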
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.clients.admin; + +import org.apache.kafka.common.KafkaFuture; +import org.apache.kafka.common.annotation.InterfaceStability; +import org.apache.kafka.common.quota.ClientQuotaEntity; +import org.apache.kafka.common.quota.ClientQuotaFilter; + +import java.util.Map; + +/** + * The result of the {@link Admin#describeClientQuotas(ClientQuotaFilter, DescribeClientQuotasOptions)} call. + * + * The API of this class is evolving, see {@link Admin} for details. + */ +@InterfaceStability.Evolving +public class DescribeClientQuotasResult { + + private final KafkaFuture>> entities; + + /** + * Maps an entity to its configured quota value(s). Note if no value is defined for a quota + * type for that entity's config, then it is not included in the resulting value map. + * + * @param entities future for the collection of entities that matched the filter + */ + public DescribeClientQuotasResult(KafkaFuture>> entities) { + this.entities = entities; + } + + /** + * Returns a map from quota entity to a future which can be used to check the status of the operation. + */ + public KafkaFuture>> entities() { + return entities; + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/DescribeClusterOptions.java b/clients/src/main/java/org/apache/kafka/clients/admin/DescribeClusterOptions.java new file mode 100644 index 0000000..2eac1f0 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/DescribeClusterOptions.java @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.clients.admin; + +import org.apache.kafka.common.annotation.InterfaceStability; + +/** + * Options for {@link Admin#describeCluster()}. + * + * The API of this class is evolving, see {@link Admin} for details. + */ +@InterfaceStability.Evolving +public class DescribeClusterOptions extends AbstractOptions { + + private boolean includeAuthorizedOperations; + + /** + * Set the timeout in milliseconds for this operation or {@code null} if the default api timeout for the + * AdminClient should be used. 
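A hedged sketch of how DescribeClientQuotasResult#entities might be consumed, assuming a configured Admin client `admin`; ClientQuotaFilter.all() matches every quota entity.

    import org.apache.kafka.clients.admin.Admin;
    import org.apache.kafka.common.quota.ClientQuotaEntity;
    import org.apache.kafka.common.quota.ClientQuotaFilter;
    import java.util.Map;

    public class DescribeClientQuotasExample {
        // Hypothetical helper: prints every configured quota, keyed by quota entity.
        public static void printQuotas(Admin admin) throws Exception {
            Map<ClientQuotaEntity, Map<String, Double>> quotas =
                admin.describeClientQuotas(ClientQuotaFilter.all()).entities().get();
            quotas.forEach((entity, values) -> System.out.println(entity + " -> " + values));
        }
    }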
+ * + */ + // This method is retained to keep binary compatibility with 0.11 + public DescribeClusterOptions timeoutMs(Integer timeoutMs) { + this.timeoutMs = timeoutMs; + return this; + } + + public DescribeClusterOptions includeAuthorizedOperations(boolean includeAuthorizedOperations) { + this.includeAuthorizedOperations = includeAuthorizedOperations; + return this; + } + + /** + * Specify if authorized operations should be included in the response. Note that some + * older brokers cannot not supply this information even if it is requested. + */ + public boolean includeAuthorizedOperations() { + return includeAuthorizedOperations; + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/DescribeClusterResult.java b/clients/src/main/java/org/apache/kafka/clients/admin/DescribeClusterResult.java new file mode 100644 index 0000000..d307d25 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/DescribeClusterResult.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.clients.admin; + +import org.apache.kafka.common.KafkaFuture; +import org.apache.kafka.common.Node; +import org.apache.kafka.common.acl.AclOperation; +import org.apache.kafka.common.annotation.InterfaceStability; + +import java.util.Collection; +import java.util.Set; + +/** + * The result of the {@link KafkaAdminClient#describeCluster()} call. + * + * The API of this class is evolving, see {@link Admin} for details. + */ +@InterfaceStability.Evolving +public class DescribeClusterResult { + private final KafkaFuture> nodes; + private final KafkaFuture controller; + private final KafkaFuture clusterId; + private final KafkaFuture> authorizedOperations; + + DescribeClusterResult(KafkaFuture> nodes, + KafkaFuture controller, + KafkaFuture clusterId, + KafkaFuture> authorizedOperations) { + this.nodes = nodes; + this.controller = controller; + this.clusterId = clusterId; + this.authorizedOperations = authorizedOperations; + } + + /** + * Returns a future which yields a collection of nodes. + */ + public KafkaFuture> nodes() { + return nodes; + } + + /** + * Returns a future which yields the current controller id. + * Note that this may yield null, if the controller ID is not yet known. + */ + public KafkaFuture controller() { + return controller; + } + + /** + * Returns a future which yields the current cluster id. The future value will be non-null if the + * broker version is 0.10.1.0 or higher and null otherwise. + */ + public KafkaFuture clusterId() { + return clusterId; + } + + /** + * Returns a future which yields authorized operations. The future value will be non-null if the + * broker supplied this information, and null otherwise. 
+ */ + public KafkaFuture> authorizedOperations() { + return authorizedOperations; + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/DescribeConfigsOptions.java b/clients/src/main/java/org/apache/kafka/clients/admin/DescribeConfigsOptions.java new file mode 100644 index 0000000..bfb9c18 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/DescribeConfigsOptions.java @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.clients.admin; + +import org.apache.kafka.common.annotation.InterfaceStability; + +import java.util.Collection; + +/** + * Options for {@link Admin#describeConfigs(Collection)}. + * + * The API of this class is evolving, see {@link Admin} for details. + */ +@InterfaceStability.Evolving +public class DescribeConfigsOptions extends AbstractOptions { + + private boolean includeSynonyms = false; + private boolean includeDocumentation = false; + + /** + * Set the timeout in milliseconds for this operation or {@code null} if the default api timeout for the + * AdminClient should be used. + * + */ + // This method is retained to keep binary compatibility with 0.11 + public DescribeConfigsOptions timeoutMs(Integer timeoutMs) { + this.timeoutMs = timeoutMs; + return this; + } + + /** + * Return true if synonym configs should be returned in the response. + */ + public boolean includeSynonyms() { + return includeSynonyms; + } + + /** + * Return true if config documentation should be returned in the response. + */ + public boolean includeDocumentation() { + return includeDocumentation; + } + + /** + * Set to true if synonym configs should be returned in the response. + */ + public DescribeConfigsOptions includeSynonyms(boolean includeSynonyms) { + this.includeSynonyms = includeSynonyms; + return this; + } + + /** + * Set to true if config documentation should be returned in the response. + */ + public DescribeConfigsOptions includeDocumentation(boolean includeDocumentation) { + this.includeDocumentation = includeDocumentation; + return this; + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/DescribeConfigsResult.java b/clients/src/main/java/org/apache/kafka/clients/admin/DescribeConfigsResult.java new file mode 100644 index 0000000..653c97d --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/DescribeConfigsResult.java @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
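The cluster metadata exposed by DescribeClusterResult can be read as in the sketch below. `admin` is an assumed, pre-configured Admin client; note that authorizedOperations() may yield null on brokers that do not supply that information.

    import org.apache.kafka.clients.admin.Admin;
    import org.apache.kafka.clients.admin.DescribeClusterOptions;
    import org.apache.kafka.clients.admin.DescribeClusterResult;

    public class DescribeClusterExample {
        // Hypothetical helper: prints basic cluster metadata, requesting authorized operations as well.
        public static void printClusterInfo(Admin admin) throws Exception {
            DescribeClusterResult cluster =
                admin.describeCluster(new DescribeClusterOptions().includeAuthorizedOperations(true));
            System.out.println("Cluster id: " + cluster.clusterId().get());
            System.out.println("Controller: " + cluster.controller().get());
            System.out.println("Nodes:      " + cluster.nodes().get());
            System.out.println("Operations: " + cluster.authorizedOperations().get()); // may be null on older brokers
        }
    }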
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.clients.admin; + +import org.apache.kafka.common.KafkaFuture; +import org.apache.kafka.common.annotation.InterfaceStability; +import org.apache.kafka.common.config.ConfigResource; + +import java.util.Collection; +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.ExecutionException; + +/** + * The result of the {@link KafkaAdminClient#describeConfigs(Collection)} call. + * + * The API of this class is evolving, see {@link Admin} for details. + */ +@InterfaceStability.Evolving +public class DescribeConfigsResult { + + private final Map> futures; + + protected DescribeConfigsResult(Map> futures) { + this.futures = futures; + } + + /** + * Return a map from resources to futures which can be used to check the status of the configuration for each + * resource. + */ + public Map> values() { + return futures; + } + + /** + * Return a future which succeeds only if all the config descriptions succeed. + */ + public KafkaFuture> all() { + return KafkaFuture.allOf(futures.values().toArray(new KafkaFuture[0])). + thenApply(new KafkaFuture.BaseFunction>() { + @Override + public Map apply(Void v) { + Map configs = new HashMap<>(futures.size()); + for (Map.Entry> entry : futures.entrySet()) { + try { + configs.put(entry.getKey(), entry.getValue().get()); + } catch (InterruptedException | ExecutionException e) { + // This should be unreachable, because allOf ensured that all the futures + // completed successfully. + throw new RuntimeException(e); + } + } + return configs; + } + }); + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/DescribeConsumerGroupsOptions.java b/clients/src/main/java/org/apache/kafka/clients/admin/DescribeConsumerGroupsOptions.java new file mode 100644 index 0000000..70238a8 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/DescribeConsumerGroupsOptions.java @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.clients.admin; + +import org.apache.kafka.common.annotation.InterfaceStability; + +import java.util.Collection; + +/** + * Options for {@link Admin#describeConsumerGroups(Collection, DescribeConsumerGroupsOptions)}. + *
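A possible way to consume DescribeConfigsResult#all (shown just above) for a single topic, given as a sketch; the Admin client `admin` and the topic name are placeholders.

    import org.apache.kafka.clients.admin.Admin;
    import org.apache.kafka.clients.admin.Config;
    import org.apache.kafka.common.config.ConfigResource;
    import java.util.Collections;
    import java.util.Map;

    public class DescribeConfigsExample {
        // Hypothetical helper: dumps the effective configuration of one topic.
        public static void printTopicConfig(Admin admin) throws Exception {
            ConfigResource topic = new ConfigResource(ConfigResource.Type.TOPIC, "my-topic"); // placeholder name
            Map<ConfigResource, Config> configs =
                admin.describeConfigs(Collections.singleton(topic)).all().get();
            configs.get(topic).entries().forEach(entry ->
                System.out.println(entry.name() + " = " + entry.value()));
        }
    }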
+ * The API of this class is evolving, see {@link Admin} for details. + */ +@InterfaceStability.Evolving +public class DescribeConsumerGroupsOptions extends AbstractOptions { + private boolean includeAuthorizedOperations; + + public DescribeConsumerGroupsOptions includeAuthorizedOperations(boolean includeAuthorizedOperations) { + this.includeAuthorizedOperations = includeAuthorizedOperations; + return this; + } + + public boolean includeAuthorizedOperations() { + return includeAuthorizedOperations; + } +} \ No newline at end of file diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/DescribeConsumerGroupsResult.java b/clients/src/main/java/org/apache/kafka/clients/admin/DescribeConsumerGroupsResult.java new file mode 100644 index 0000000..8940060 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/DescribeConsumerGroupsResult.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.clients.admin; + +import org.apache.kafka.common.KafkaFuture; +import org.apache.kafka.common.annotation.InterfaceStability; + +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.ExecutionException; + + +/** + * The result of the {@link KafkaAdminClient#describeConsumerGroups(Collection, DescribeConsumerGroupsOptions)}} call. + * + * The API of this class is evolving, see {@link Admin} for details. + */ +@InterfaceStability.Evolving +public class DescribeConsumerGroupsResult { + + private final Map> futures; + + public DescribeConsumerGroupsResult(final Map> futures) { + this.futures = futures; + } + + /** + * Return a map from group id to futures which yield group descriptions. + */ + public Map> describedGroups() { + Map> describedGroups = new HashMap<>(); + futures.forEach((key, future) -> describedGroups.put(key, future)); + return describedGroups; + } + + /** + * Return a future which yields all ConsumerGroupDescription objects, if all the describes succeed. + */ + public KafkaFuture> all() { + return KafkaFuture.allOf(futures.values().toArray(new KafkaFuture[0])).thenApply( + nil -> { + Map descriptions = new HashMap<>(futures.size()); + futures.forEach((key, future) -> { + try { + descriptions.put(key, future.get()); + } catch (InterruptedException | ExecutionException e) { + // This should be unreachable, since the KafkaFuture#allOf already ensured + // that all of the futures completed successfully. 
+ throw new RuntimeException(e); + } + }); + return descriptions; + }); + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/DescribeDelegationTokenOptions.java b/clients/src/main/java/org/apache/kafka/clients/admin/DescribeDelegationTokenOptions.java new file mode 100644 index 0000000..ef9f105 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/DescribeDelegationTokenOptions.java @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.clients.admin; + +import java.util.List; + +import org.apache.kafka.common.annotation.InterfaceStability; +import org.apache.kafka.common.security.auth.KafkaPrincipal; + +/** + * Options for {@link Admin#describeDelegationToken(DescribeDelegationTokenOptions)}. + * + * The API of this class is evolving, see {@link Admin} for details. + */ +@InterfaceStability.Evolving +public class DescribeDelegationTokenOptions extends AbstractOptions { + private List owners; + + /** + * if owners is null, all the user owned tokens and tokens where user have Describe permission + * will be returned. + * @param owners + * @return this instance + */ + public DescribeDelegationTokenOptions owners(List owners) { + this.owners = owners; + return this; + } + + public List owners() { + return owners; + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/DescribeDelegationTokenResult.java b/clients/src/main/java/org/apache/kafka/clients/admin/DescribeDelegationTokenResult.java new file mode 100644 index 0000000..47b2530 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/DescribeDelegationTokenResult.java @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
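DescribeConsumerGroupsResult#all collects every requested group description; a minimal sketch follows, assuming a configured Admin client `admin` and a placeholder group id.

    import org.apache.kafka.clients.admin.Admin;
    import org.apache.kafka.clients.admin.ConsumerGroupDescription;
    import java.util.Collections;
    import java.util.Map;

    public class DescribeConsumerGroupsExample {
        // Hypothetical helper: prints the state and member count of one consumer group.
        public static void printGroup(Admin admin) throws Exception {
            Map<String, ConsumerGroupDescription> groups =
                admin.describeConsumerGroups(Collections.singleton("my-group")).all().get(); // placeholder group id
            ConsumerGroupDescription group = groups.get("my-group");
            System.out.println(group.groupId() + ": state=" + group.state()
                + ", members=" + group.members().size());
        }
    }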
+ */ + +package org.apache.kafka.clients.admin; + +import java.util.List; + +import org.apache.kafka.common.KafkaFuture; +import org.apache.kafka.common.annotation.InterfaceStability; +import org.apache.kafka.common.security.token.delegation.DelegationToken; + +/** + * The result of the {@link KafkaAdminClient#describeDelegationToken(DescribeDelegationTokenOptions)} call. + * + * The API of this class is evolving, see {@link Admin} for details. + */ +@InterfaceStability.Evolving +public class DescribeDelegationTokenResult { + private final KafkaFuture> delegationTokens; + + DescribeDelegationTokenResult(KafkaFuture> delegationTokens) { + this.delegationTokens = delegationTokens; + } + + /** + * Returns a future which yields list of delegation tokens + */ + public KafkaFuture> delegationTokens() { + return delegationTokens; + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/DescribeFeaturesOptions.java b/clients/src/main/java/org/apache/kafka/clients/admin/DescribeFeaturesOptions.java new file mode 100644 index 0000000..a51ca74 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/DescribeFeaturesOptions.java @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients.admin; + +import org.apache.kafka.common.annotation.InterfaceStability; + +/** + * Options for {@link AdminClient#describeFeatures(DescribeFeaturesOptions)}. + * + * The API of this class is evolving. See {@link Admin} for details. + */ +@InterfaceStability.Evolving +public class DescribeFeaturesOptions extends AbstractOptions { +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/DescribeFeaturesResult.java b/clients/src/main/java/org/apache/kafka/clients/admin/DescribeFeaturesResult.java new file mode 100644 index 0000000..c48dc19 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/DescribeFeaturesResult.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
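A sketch of describing delegation tokens for one owner via DescribeDelegationTokenOptions#owners; the principal name and the Admin client `admin` are placeholders.

    import org.apache.kafka.clients.admin.Admin;
    import org.apache.kafka.clients.admin.DescribeDelegationTokenOptions;
    import org.apache.kafka.common.security.auth.KafkaPrincipal;
    import org.apache.kafka.common.security.token.delegation.DelegationToken;
    import java.util.Collections;
    import java.util.List;

    public class DescribeDelegationTokenExample {
        // Hypothetical helper: lists delegation tokens owned by a single (placeholder) principal.
        public static void printTokens(Admin admin) throws Exception {
            KafkaPrincipal owner = new KafkaPrincipal(KafkaPrincipal.USER_TYPE, "alice"); // placeholder user
            List<DelegationToken> tokens = admin
                .describeDelegationToken(new DescribeDelegationTokenOptions()
                    .owners(Collections.singletonList(owner)))
                .delegationTokens().get();
            tokens.forEach(token -> System.out.println(token.tokenInfo()));
        }
    }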
+ */ +package org.apache.kafka.clients.admin; + +import org.apache.kafka.common.KafkaFuture; + +/** + * The result of the {@link Admin#describeFeatures(DescribeFeaturesOptions)} call. + * + * The API of this class is evolving, see {@link Admin} for details. + */ +public class DescribeFeaturesResult { + + private final KafkaFuture future; + + DescribeFeaturesResult(KafkaFuture future) { + this.future = future; + } + + public KafkaFuture featureMetadata() { + return future; + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/DescribeLogDirsOptions.java b/clients/src/main/java/org/apache/kafka/clients/admin/DescribeLogDirsOptions.java new file mode 100644 index 0000000..17890ca --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/DescribeLogDirsOptions.java @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.clients.admin; + +import org.apache.kafka.common.annotation.InterfaceStability; + +import java.util.Collection; + + +/** + * Options for {@link Admin#describeLogDirs(Collection)} + * + * The API of this class is evolving, see {@link Admin} for details. + */ +@InterfaceStability.Evolving +public class DescribeLogDirsOptions extends AbstractOptions { + +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/DescribeLogDirsResult.java b/clients/src/main/java/org/apache/kafka/clients/admin/DescribeLogDirsResult.java new file mode 100644 index 0000000..96a81f0 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/DescribeLogDirsResult.java @@ -0,0 +1,118 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
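DescribeFeaturesResult#featureMetadata can be read as below; this is only a sketch, with `admin` standing in for a configured Admin client.

    import org.apache.kafka.clients.admin.Admin;
    import org.apache.kafka.clients.admin.FeatureMetadata;

    public class DescribeFeaturesExample {
        // Hypothetical helper: prints the finalized and supported feature versions reported by the cluster.
        public static void printFeatures(Admin admin) throws Exception {
            FeatureMetadata features = admin.describeFeatures().featureMetadata().get();
            System.out.println("Finalized: " + features.finalizedFeatures());
            System.out.println("Supported: " + features.supportedFeatures());
        }
    }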
+ */ + +package org.apache.kafka.clients.admin; + +import org.apache.kafka.common.KafkaFuture; +import org.apache.kafka.common.annotation.InterfaceStability; +import java.util.HashMap; +import java.util.Collection; +import java.util.Map; +import java.util.concurrent.ExecutionException; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import org.apache.kafka.common.protocol.Errors; +import org.apache.kafka.common.requests.DescribeLogDirsResponse; + + +/** + * The result of the {@link Admin#describeLogDirs(Collection)} call. + * + * The API of this class is evolving, see {@link Admin} for details. + */ +@InterfaceStability.Evolving +public class DescribeLogDirsResult { + private final Map>> futures; + + DescribeLogDirsResult(Map>> futures) { + this.futures = futures; + } + + /** + * Return a map from brokerId to future which can be used to check the information of partitions on each individual broker. + * @deprecated Deprecated Since Kafka 2.7. Use {@link #descriptions()}. + */ + @Deprecated + @SuppressWarnings("deprecation") + public Map>> values() { + return descriptions().entrySet().stream() + .collect(Collectors.toMap( + Map.Entry::getKey, + entry -> entry.getValue().thenApply(map -> convertMapValues(map)))); + } + + @SuppressWarnings("deprecation") + private Map convertMapValues(Map map) { + Stream> stream = map.entrySet().stream(); + return stream.collect(Collectors.toMap( + Map.Entry::getKey, + infoEntry -> { + LogDirDescription logDir = infoEntry.getValue(); + return new DescribeLogDirsResponse.LogDirInfo(logDir.error() == null ? Errors.NONE : Errors.forException(logDir.error()), + logDir.replicaInfos().entrySet().stream().collect(Collectors.toMap( + Map.Entry::getKey, + replicaEntry -> new DescribeLogDirsResponse.ReplicaInfo( + replicaEntry.getValue().size(), + replicaEntry.getValue().offsetLag(), + replicaEntry.getValue().isFuture()) + ))); + })); + } + + /** + * Return a map from brokerId to future which can be used to check the information of partitions on each individual broker. + * The result of the future is a map from broker log directory path to a description of that log directory. + */ + public Map>> descriptions() { + return futures; + } + + /** + * Return a future which succeeds only if all the brokers have responded without error + * @deprecated Deprecated Since Kafka 2.7. Use {@link #allDescriptions()}. + */ + @Deprecated + @SuppressWarnings("deprecation") + public KafkaFuture>> all() { + return allDescriptions().thenApply(map -> map.entrySet().stream().collect(Collectors.toMap( + entry -> entry.getKey(), + entry -> convertMapValues(entry.getValue()) + ))); + } + + /** + * Return a future which succeeds only if all the brokers have responded without error. + * The result of the future is a map from brokerId to a map from broker log directory path + * to a description of that log directory. + */ + public KafkaFuture>> allDescriptions() { + return KafkaFuture.allOf(futures.values().toArray(new KafkaFuture[0])). + thenApply(v -> { + Map> descriptions = new HashMap<>(futures.size()); + for (Map.Entry>> entry : futures.entrySet()) { + try { + descriptions.put(entry.getKey(), entry.getValue().get()); + } catch (InterruptedException | ExecutionException e) { + // This should be unreachable, because allOf ensured that all the futures completed successfully. 
+ throw new RuntimeException(e); + } + } + return descriptions; + }); + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/DescribeProducersOptions.java b/clients/src/main/java/org/apache/kafka/clients/admin/DescribeProducersOptions.java new file mode 100644 index 0000000..0776e40 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/DescribeProducersOptions.java @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients.admin; + +import org.apache.kafka.common.annotation.InterfaceStability; + +import java.util.Collection; +import java.util.Objects; +import java.util.OptionalInt; + +/** + * Options for {@link Admin#describeProducers(Collection)}. + * + * The API of this class is evolving, see {@link Admin} for details. + */ +@InterfaceStability.Evolving +public class DescribeProducersOptions extends AbstractOptions { + private OptionalInt brokerId = OptionalInt.empty(); + + public DescribeProducersOptions brokerId(int brokerId) { + this.brokerId = OptionalInt.of(brokerId); + return this; + } + + public OptionalInt brokerId() { + return brokerId; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + DescribeProducersOptions that = (DescribeProducersOptions) o; + return Objects.equals(brokerId, that.brokerId) && + Objects.equals(timeoutMs, that.timeoutMs); + } + + @Override + public int hashCode() { + return Objects.hash(brokerId, timeoutMs); + } + + @Override + public String toString() { + return "DescribeProducersOptions(" + + "brokerId=" + brokerId + + ", timeoutMs=" + timeoutMs + + ')'; + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/DescribeProducersResult.java b/clients/src/main/java/org/apache/kafka/clients/admin/DescribeProducersResult.java new file mode 100644 index 0000000..13977c6 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/DescribeProducersResult.java @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
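A sketch of reading DescribeLogDirsResult#allDescriptions; the broker ids and the Admin client `admin` are placeholders.

    import org.apache.kafka.clients.admin.Admin;
    import org.apache.kafka.clients.admin.LogDirDescription;
    import java.util.Arrays;
    import java.util.Map;

    public class DescribeLogDirsExample {
        // Hypothetical helper: prints, per broker, each log directory and the replicas it hosts.
        public static void printLogDirs(Admin admin) throws Exception {
            Map<Integer, Map<String, LogDirDescription>> logDirs =
                admin.describeLogDirs(Arrays.asList(0, 1, 2)).allDescriptions().get(); // placeholder broker ids
            logDirs.forEach((broker, dirs) ->
                dirs.forEach((path, description) ->
                    System.out.println("broker " + broker + " dir " + path + ": "
                        + description.replicaInfos().keySet())));
        }
    }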
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients.admin; + +import org.apache.kafka.common.KafkaException; +import org.apache.kafka.common.KafkaFuture; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.annotation.InterfaceStability; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ExecutionException; + +@InterfaceStability.Evolving +public class DescribeProducersResult { + + private final Map> futures; + + DescribeProducersResult(Map> futures) { + this.futures = futures; + } + + public KafkaFuture partitionResult(final TopicPartition partition) { + KafkaFuture future = futures.get(partition); + if (future == null) { + throw new IllegalArgumentException("Topic partition " + partition + + " was not included in the request"); + } + return future; + } + + public KafkaFuture> all() { + return KafkaFuture.allOf(futures.values().toArray(new KafkaFuture[0])) + .thenApply(nil -> { + Map results = new HashMap<>(futures.size()); + for (Map.Entry> entry : futures.entrySet()) { + try { + results.put(entry.getKey(), entry.getValue().get()); + } catch (InterruptedException | ExecutionException e) { + // This should be unreachable, because allOf ensured that all the futures completed successfully. + throw new KafkaException(e); + } + } + return results; + }); + } + + public static class PartitionProducerState { + private final List activeProducers; + + public PartitionProducerState(List activeProducers) { + this.activeProducers = activeProducers; + } + + public List activeProducers() { + return activeProducers; + } + + @Override + public String toString() { + return "PartitionProducerState(" + + "activeProducers=" + activeProducers + + ')'; + } + } + +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/DescribeReplicaLogDirsOptions.java b/clients/src/main/java/org/apache/kafka/clients/admin/DescribeReplicaLogDirsOptions.java new file mode 100644 index 0000000..589de50 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/DescribeReplicaLogDirsOptions.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.clients.admin; + +import org.apache.kafka.common.annotation.InterfaceStability; +import java.util.Collection; + +/** + * Options for {@link Admin#describeReplicaLogDirs(Collection)}. + * + * The API of this class is evolving, see {@link Admin} for details. 
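DescribeProducersResult#partitionResult yields the active producer state for one partition; below is a minimal sketch with a placeholder topic/partition and an assumed Admin client `admin`.

    import org.apache.kafka.clients.admin.Admin;
    import org.apache.kafka.clients.admin.DescribeProducersResult.PartitionProducerState;
    import org.apache.kafka.common.TopicPartition;
    import java.util.Collections;

    public class DescribeProducersExample {
        // Hypothetical helper: prints the active (possibly transactional) producers writing to one partition.
        public static void printActiveProducers(Admin admin) throws Exception {
            TopicPartition tp = new TopicPartition("my-topic", 0); // placeholder partition
            PartitionProducerState state =
                admin.describeProducers(Collections.singleton(tp)).partitionResult(tp).get();
            state.activeProducers().forEach(System.out::println);
        }
    }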
+ */ +@InterfaceStability.Evolving +public class DescribeReplicaLogDirsOptions extends AbstractOptions { + +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/DescribeReplicaLogDirsResult.java b/clients/src/main/java/org/apache/kafka/clients/admin/DescribeReplicaLogDirsResult.java new file mode 100644 index 0000000..54bd9c1 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/DescribeReplicaLogDirsResult.java @@ -0,0 +1,132 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients.admin; + +import org.apache.kafka.common.KafkaFuture; +import org.apache.kafka.common.TopicPartitionReplica; +import org.apache.kafka.common.annotation.InterfaceStability; +import org.apache.kafka.common.requests.DescribeLogDirsResponse; + +import java.util.HashMap; +import java.util.Map; +import java.util.Collection; +import java.util.concurrent.ExecutionException; + + +/** + * The result of {@link Admin#describeReplicaLogDirs(Collection)}. + * + * The API of this class is evolving, see {@link Admin} for details. + */ +@InterfaceStability.Evolving +public class DescribeReplicaLogDirsResult { + private final Map> futures; + + DescribeReplicaLogDirsResult(Map> futures) { + this.futures = futures; + } + + /** + * Return a map from replica to future which can be used to check the log directory information of individual replicas + */ + public Map> values() { + return futures; + } + + /** + * Return a future which succeeds if log directory information of all replicas are available + */ + public KafkaFuture> all() { + return KafkaFuture.allOf(futures.values().toArray(new KafkaFuture[0])). + thenApply(new KafkaFuture.BaseFunction>() { + @Override + public Map apply(Void v) { + Map replicaLogDirInfos = new HashMap<>(); + for (Map.Entry> entry : futures.entrySet()) { + try { + replicaLogDirInfos.put(entry.getKey(), entry.getValue().get()); + } catch (InterruptedException | ExecutionException e) { + // This should be unreachable, because allOf ensured that all the futures completed successfully. + throw new RuntimeException(e); + } + } + return replicaLogDirInfos; + } + }); + } + + static public class ReplicaLogDirInfo { + // The current log directory of the replica of this partition on the given broker. + // Null if no replica is not found for this partition on the given broker. + private final String currentReplicaLogDir; + // Defined as max(HW of partition - LEO of the replica, 0). + private final long currentReplicaOffsetLag; + // The future log directory of the replica of this partition on the given broker. + // Null if the replica of this partition is not being moved to another log directory on the given broker. 
+ private final String futureReplicaLogDir; + // The LEO of the replica - LEO of the future log of this replica in the destination log directory. + // -1 if either there is not replica for this partition or the replica of this partition is not being moved to another log directory on the given broker. + private final long futureReplicaOffsetLag; + + ReplicaLogDirInfo() { + this(null, DescribeLogDirsResponse.INVALID_OFFSET_LAG, null, DescribeLogDirsResponse.INVALID_OFFSET_LAG); + } + + ReplicaLogDirInfo(String currentReplicaLogDir, + long currentReplicaOffsetLag, + String futureReplicaLogDir, + long futureReplicaOffsetLag) { + this.currentReplicaLogDir = currentReplicaLogDir; + this.currentReplicaOffsetLag = currentReplicaOffsetLag; + this.futureReplicaLogDir = futureReplicaLogDir; + this.futureReplicaOffsetLag = futureReplicaOffsetLag; + } + + public String getCurrentReplicaLogDir() { + return currentReplicaLogDir; + } + + public long getCurrentReplicaOffsetLag() { + return currentReplicaOffsetLag; + } + + public String getFutureReplicaLogDir() { + return futureReplicaLogDir; + } + + public long getFutureReplicaOffsetLag() { + return futureReplicaOffsetLag; + } + + @Override + public String toString() { + StringBuilder builder = new StringBuilder(); + if (futureReplicaLogDir != null) { + builder.append("(currentReplicaLogDir=") + .append(currentReplicaLogDir) + .append(", futureReplicaLogDir=") + .append(futureReplicaLogDir) + .append(", futureReplicaOffsetLag=") + .append(futureReplicaOffsetLag) + .append(")"); + } else { + builder.append("ReplicaLogDirInfo(currentReplicaLogDir=").append(currentReplicaLogDir).append(")"); + } + return builder.toString(); + } + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/DescribeTopicsOptions.java b/clients/src/main/java/org/apache/kafka/clients/admin/DescribeTopicsOptions.java new file mode 100644 index 0000000..299aaea --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/DescribeTopicsOptions.java @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.clients.admin; + +import org.apache.kafka.common.annotation.InterfaceStability; + +import java.util.Collection; + +/** + * Options for {@link Admin#describeTopics(Collection)}. + * + * The API of this class is evolving, see {@link Admin} for details. + */ +@InterfaceStability.Evolving +public class DescribeTopicsOptions extends AbstractOptions { + + private boolean includeAuthorizedOperations; + + /** + * Set the timeout in milliseconds for this operation or {@code null} if the default api timeout for the + * AdminClient should be used. 
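A sketch of reading ReplicaLogDirInfo for a single replica via DescribeReplicaLogDirsResult#all; the topic, partition and broker id are placeholders and `admin` is an assumed Admin client.

    import org.apache.kafka.clients.admin.Admin;
    import org.apache.kafka.clients.admin.DescribeReplicaLogDirsResult.ReplicaLogDirInfo;
    import org.apache.kafka.common.TopicPartitionReplica;
    import java.util.Collections;
    import java.util.Map;

    public class DescribeReplicaLogDirsExample {
        // Hypothetical helper: shows where a single replica currently lives and whether it is moving.
        public static void printReplicaDir(Admin admin) throws Exception {
            TopicPartitionReplica replica = new TopicPartitionReplica("my-topic", 0, 1); // placeholder replica
            Map<TopicPartitionReplica, ReplicaLogDirInfo> infos =
                admin.describeReplicaLogDirs(Collections.singleton(replica)).all().get();
            ReplicaLogDirInfo info = infos.get(replica);
            System.out.println("current dir: " + info.getCurrentReplicaLogDir()
                + ", future dir: " + info.getFutureReplicaLogDir());
        }
    }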
+ * + */ + // This method is retained to keep binary compatibility with 0.11 + public DescribeTopicsOptions timeoutMs(Integer timeoutMs) { + this.timeoutMs = timeoutMs; + return this; + } + + public DescribeTopicsOptions includeAuthorizedOperations(boolean includeAuthorizedOperations) { + this.includeAuthorizedOperations = includeAuthorizedOperations; + return this; + } + + public boolean includeAuthorizedOperations() { + return includeAuthorizedOperations; + } + +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/DescribeTopicsResult.java b/clients/src/main/java/org/apache/kafka/clients/admin/DescribeTopicsResult.java new file mode 100644 index 0000000..41593c5 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/DescribeTopicsResult.java @@ -0,0 +1,147 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.clients.admin; + +import org.apache.kafka.common.KafkaFuture; +import org.apache.kafka.common.TopicCollection; +import org.apache.kafka.common.Uuid; +import org.apache.kafka.common.annotation.InterfaceStability; + +import java.util.Collection; +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.ExecutionException; + +/** + * The result of the {@link KafkaAdminClient#describeTopics(Collection)} call. + * + * The API of this class is evolving, see {@link Admin} for details. + */ +@InterfaceStability.Evolving +public class DescribeTopicsResult { + private final Map> topicIdFutures; + private final Map> nameFutures; + + @Deprecated + protected DescribeTopicsResult(Map> futures) { + this(null, futures); + } + + // VisibleForTesting + protected DescribeTopicsResult(Map> topicIdFutures, Map> nameFutures) { + if (topicIdFutures != null && nameFutures != null) + throw new IllegalArgumentException("topicIdFutures and nameFutures cannot both be specified."); + if (topicIdFutures == null && nameFutures == null) + throw new IllegalArgumentException("topicIdFutures and nameFutures cannot both be null."); + this.topicIdFutures = topicIdFutures; + this.nameFutures = nameFutures; + } + + static DescribeTopicsResult ofTopicIds(Map> topicIdFutures) { + return new DescribeTopicsResult(topicIdFutures, null); + } + + static DescribeTopicsResult ofTopicNames(Map> nameFutures) { + return new DescribeTopicsResult(null, nameFutures); + } + + /** + * Use when {@link Admin#describeTopics(TopicCollection, DescribeTopicsOptions)} used a TopicIdCollection + * + * @return a map from topic IDs to futures which can be used to check the status of + * individual topics if the request used topic IDs, otherwise return null. 
+ */ + public Map> topicIdValues() { + return topicIdFutures; + } + + /** + * Use when {@link Admin#describeTopics(TopicCollection, DescribeTopicsOptions)} used a TopicNameCollection + * + * @return a map from topic names to futures which can be used to check the status of + * individual topics if the request used topic names, otherwise return null. + */ + public Map> topicNameValues() { + return nameFutures; + } + + /** + * @return a map from topic names to futures which can be used to check the status of + * individual topics if the request used topic names, otherwise return null. + * + * @deprecated Since 3.1.0 use {@link #topicNameValues} instead + */ + @Deprecated + public Map> values() { + return nameFutures; + } + + /** + * @return A future map from topic names to descriptions which can be used to check + * the status of individual description if the describe topic request used + * topic names, otherwise return null, this request succeeds only if all the + * topic descriptions succeed + * + * @deprecated Since 3.1.0 use {@link #allTopicNames()} instead + */ + @Deprecated + public KafkaFuture> all() { + return all(nameFutures); + } + + /** + * @return A future map from topic names to descriptions which can be used to check + * the status of individual description if the describe topic request used + * topic names, otherwise return null, this request succeeds only if all the + * topic descriptions succeed + */ + public KafkaFuture> allTopicNames() { + return all(nameFutures); + } + + /** + * @return A future map from topic ids to descriptions which can be used to check the + * status of individual description if the describe topic request used topic + * ids, otherwise return null, this request succeeds only if all the topic + * descriptions succeed + */ + public KafkaFuture> allTopicIds() { + return all(topicIdFutures); + } + + /** + * Return a future which succeeds only if all the topic descriptions succeed. + */ + private static KafkaFuture> all(Map> futures) { + KafkaFuture future = KafkaFuture.allOf(futures.values().toArray(new KafkaFuture[0])); + return future. + thenApply(v -> { + Map descriptions = new HashMap<>(futures.size()); + for (Map.Entry> entry : futures.entrySet()) { + try { + descriptions.put(entry.getKey(), entry.getValue().get()); + } catch (InterruptedException | ExecutionException e) { + // This should be unreachable, because allOf ensured that all the futures + // completed successfully. + throw new RuntimeException(e); + } + } + return descriptions; + }); + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/DescribeTransactionsOptions.java b/clients/src/main/java/org/apache/kafka/clients/admin/DescribeTransactionsOptions.java new file mode 100644 index 0000000..47beb7f --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/DescribeTransactionsOptions.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
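DescribeTopicsResult#allTopicNames is the name-keyed aggregate that sits alongside the topic-id variants above; a minimal usage sketch follows, with a placeholder topic name and an assumed Admin client `admin`.

    import org.apache.kafka.clients.admin.Admin;
    import org.apache.kafka.clients.admin.TopicDescription;
    import java.util.Collections;
    import java.util.Map;

    public class DescribeTopicsExample {
        // Hypothetical helper: prints the partition layout of one topic, looked up by name.
        public static void printTopic(Admin admin) throws Exception {
            Map<String, TopicDescription> topics =
                admin.describeTopics(Collections.singleton("my-topic")).allTopicNames().get(); // placeholder name
            TopicDescription description = topics.get("my-topic");
            System.out.println(description.name() + " has " + description.partitions().size() + " partitions");
            description.partitions().forEach(System.out::println);
        }
    }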
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients.admin; + +import org.apache.kafka.common.annotation.InterfaceStability; + +import java.util.Collection; + +/** + * Options for {@link Admin#describeTransactions(Collection)}. + * + * The API of this class is evolving, see {@link Admin} for details. + */ +@InterfaceStability.Evolving +public class DescribeTransactionsOptions extends AbstractOptions { + + @Override + public String toString() { + return "DescribeTransactionsOptions(" + + "timeoutMs=" + timeoutMs + + ')'; + } + +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/DescribeTransactionsResult.java b/clients/src/main/java/org/apache/kafka/clients/admin/DescribeTransactionsResult.java new file mode 100644 index 0000000..278a254 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/DescribeTransactionsResult.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients.admin; + +import org.apache.kafka.clients.admin.internals.CoordinatorKey; +import org.apache.kafka.common.KafkaFuture; +import org.apache.kafka.common.annotation.InterfaceStability; + +import java.util.Collection; +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.ExecutionException; + +@InterfaceStability.Evolving +public class DescribeTransactionsResult { + private final Map> futures; + + DescribeTransactionsResult(Map> futures) { + this.futures = futures; + } + + /** + * Get the description of a specific transactional ID. + * + * @param transactionalId the transactional ID to describe + * @return a future which completes when the transaction description of a particular + * transactional ID is available. + * @throws IllegalArgumentException if the `transactionalId` was not included in the + * respective call to {@link Admin#describeTransactions(Collection, DescribeTransactionsOptions)}. 
+ */ + public KafkaFuture description(String transactionalId) { + CoordinatorKey key = CoordinatorKey.byTransactionalId(transactionalId); + KafkaFuture future = futures.get(key); + if (future == null) { + throw new IllegalArgumentException("TransactionalId " + + "`" + transactionalId + "` was not included in the request"); + } + return future; + } + /** + * Get a future which returns a map of the transaction descriptions requested in the respective + * call to {@link Admin#describeTransactions(Collection, DescribeTransactionsOptions)}. + * + * If the description fails on any of the transactional IDs in the request, then this future + * will also fail. + * + * @return a future which either completes when all transaction descriptions complete or fails + * if any of the descriptions cannot be obtained + */ + public KafkaFuture> all() { + return KafkaFuture.allOf(futures.values().toArray(new KafkaFuture[0])) + .thenApply(nil -> { + Map results = new HashMap<>(futures.size()); + for (Map.Entry> entry : futures.entrySet()) { + try { + results.put(entry.getKey().idValue, entry.getValue().get()); + } catch (InterruptedException | ExecutionException e) { + // This should be unreachable, because allOf ensured that all the futures completed successfully. + throw new RuntimeException(e); + } + } + return results; + }); + } + +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/DescribeUserScramCredentialsOptions.java b/clients/src/main/java/org/apache/kafka/clients/admin/DescribeUserScramCredentialsOptions.java new file mode 100644 index 0000000..1d1af47 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/DescribeUserScramCredentialsOptions.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.clients.admin; + +import org.apache.kafka.common.annotation.InterfaceStability; + +import java.util.List; + +/** + * Options for {@link AdminClient#describeUserScramCredentials(List, DescribeUserScramCredentialsOptions)} + * + * The API of this class is evolving. See {@link AdminClient} for details. + */ +@InterfaceStability.Evolving +public class DescribeUserScramCredentialsOptions extends AbstractOptions { +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/DescribeUserScramCredentialsResult.java b/clients/src/main/java/org/apache/kafka/clients/admin/DescribeUserScramCredentialsResult.java new file mode 100644 index 0000000..2eddd7e --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/DescribeUserScramCredentialsResult.java @@ -0,0 +1,150 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
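A sketch of reading a single transaction description via DescribeTransactionsResult#description; the transactional id and the Admin client `admin` are placeholders.

    import org.apache.kafka.clients.admin.Admin;
    import org.apache.kafka.clients.admin.TransactionDescription;
    import java.util.Collections;

    public class DescribeTransactionsExample {
        // Hypothetical helper: prints the coordinator-side state of one transactional id.
        public static void printTransaction(Admin admin) throws Exception {
            TransactionDescription description = admin
                .describeTransactions(Collections.singleton("my-transactional-id")) // placeholder id
                .description("my-transactional-id").get();
            System.out.println("state=" + description.state()
                + ", producerId=" + description.producerId()
                + ", partitions=" + description.topicPartitions());
        }
    }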
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.clients.admin; + +import org.apache.kafka.common.KafkaFuture; +import org.apache.kafka.common.annotation.InterfaceStability; +import org.apache.kafka.common.errors.ResourceNotFoundException; +import org.apache.kafka.common.internals.KafkaFutureImpl; +import org.apache.kafka.common.message.DescribeUserScramCredentialsResponseData; +import org.apache.kafka.common.protocol.Errors; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.stream.Collectors; + +/** + * The result of the {@link Admin#describeUserScramCredentials()} call. + * + * The API of this class is evolving, see {@link Admin} for details. + */ +@InterfaceStability.Evolving +public class DescribeUserScramCredentialsResult { + private final KafkaFuture dataFuture; + + /** + * Package-private constructor + * + * @param dataFuture the future indicating response data from the call + */ + DescribeUserScramCredentialsResult(KafkaFuture dataFuture) { + this.dataFuture = Objects.requireNonNull(dataFuture); + } + + /** + * + * @return a future for the results of all described users with map keys (one per user) being consistent with the + * contents of the list returned by {@link #users()}. The future will complete successfully only if all such user + * descriptions complete successfully. + */ + public KafkaFuture> all() { + final KafkaFutureImpl> retval = new KafkaFutureImpl<>(); + dataFuture.whenComplete((data, throwable) -> { + if (throwable != null) { + retval.completeExceptionally(throwable); + } else { + /* Check to make sure every individual described user succeeded. Note that a successfully described user + * is one that appears with *either* a NONE error code or a RESOURCE_NOT_FOUND error code. The + * RESOURCE_NOT_FOUND means the client explicitly requested a describe of that particular user but it could + * not be described because it does not exist; such a user will not appear as a key in the returned map. 
+ */ + Optional optionalFirstFailedDescribe = + data.results().stream().filter(result -> + result.errorCode() != Errors.NONE.code() && result.errorCode() != Errors.RESOURCE_NOT_FOUND.code()).findFirst(); + if (optionalFirstFailedDescribe.isPresent()) { + retval.completeExceptionally(Errors.forCode(optionalFirstFailedDescribe.get().errorCode()).exception(optionalFirstFailedDescribe.get().errorMessage())); + } else { + Map retvalMap = new HashMap<>(); + data.results().stream().forEach(userResult -> + retvalMap.put(userResult.user(), new UserScramCredentialsDescription(userResult.user(), + getScramCredentialInfosFor(userResult)))); + retval.complete(retvalMap); + } + } + }); + return retval; + } + + /** + * + * @return a future indicating the distinct users that meet the request criteria and that have at least one + * credential. The future will not complete successfully if the user is not authorized to perform the describe + * operation; otherwise, it will complete successfully as long as the list of users with credentials can be + * successfully determined within some hard-coded timeout period. Note that the returned list will not include users + * that do not exist/have no credentials: a request to describe an explicit list of users, none of which existed/had + * a credential, will result in a future that returns an empty list being returned here. A returned list will + * include users that have a credential but that could not be described. + */ + public KafkaFuture> users() { + final KafkaFutureImpl> retval = new KafkaFutureImpl<>(); + dataFuture.whenComplete((data, throwable) -> { + if (throwable != null) { + retval.completeExceptionally(throwable); + } else { + retval.complete(data.results().stream() + .filter(result -> result.errorCode() != Errors.RESOURCE_NOT_FOUND.code()) + .map(result -> result.user()).collect(Collectors.toList())); + } + }); + return retval; + } + + /** + * + * @param userName the name of the user description being requested + * @return a future indicating the description results for the given user. The future will complete exceptionally if + * the future returned by {@link #users()} completes exceptionally. Note that if the given user does not exist in + * the list of described users then the returned future will complete exceptionally with + * {@link org.apache.kafka.common.errors.ResourceNotFoundException}. 
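For illustration, a hedged sketch of the methods documented above, assuming an already-created `Admin` client and a placeholder user `alice`:

```java
// Assumes: an existing Admin client, java.util.Collections and java.util.concurrent.ExecutionException imports.
static void printScramCredentials(Admin admin) throws Exception {
    DescribeUserScramCredentialsResult result =
        admin.describeUserScramCredentials(Collections.singletonList("alice"));
    try {
        UserScramCredentialsDescription description = result.description("alice").get();
        System.out.println(description);
    } catch (ExecutionException e) {
        if (e.getCause() instanceof ResourceNotFoundException) {
            System.out.println("alice has no SCRAM credentials"); // behavior described in the javadoc above
        } else {
            throw e;
        }
    }
}
```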
+ */ + public KafkaFuture description(String userName) { + final KafkaFutureImpl retval = new KafkaFutureImpl<>(); + dataFuture.whenComplete((data, throwable) -> { + if (throwable != null) { + retval.completeExceptionally(throwable); + } else { + // it is possible that there is no future for this user (for example, the original describe request was + // for users 1, 2, and 3 but this is looking for user 4), so explicitly take care of that case + Optional optionalUserResult = + data.results().stream().filter(result -> result.user().equals(userName)).findFirst(); + if (!optionalUserResult.isPresent()) { + retval.completeExceptionally(new ResourceNotFoundException("No such user: " + userName)); + } else { + DescribeUserScramCredentialsResponseData.DescribeUserScramCredentialsResult userResult = optionalUserResult.get(); + if (userResult.errorCode() != Errors.NONE.code()) { + // RESOURCE_NOT_FOUND is included here + retval.completeExceptionally(Errors.forCode(userResult.errorCode()).exception(userResult.errorMessage())); + } else { + retval.complete(new UserScramCredentialsDescription(userResult.user(), getScramCredentialInfosFor(userResult))); + } + } + } + }); + return retval; + } + + private static List getScramCredentialInfosFor( + DescribeUserScramCredentialsResponseData.DescribeUserScramCredentialsResult userResult) { + return userResult.credentialInfos().stream().map(c -> + new ScramCredentialInfo(ScramMechanism.fromType(c.mechanism()), c.iterations())) + .collect(Collectors.toList()); + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/ElectLeadersOptions.java b/clients/src/main/java/org/apache/kafka/clients/admin/ElectLeadersOptions.java new file mode 100644 index 0000000..ae03ebe --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/ElectLeadersOptions.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.clients.admin; + +import org.apache.kafka.common.ElectionType; +import org.apache.kafka.common.annotation.InterfaceStability; + +import java.util.Set; + +/** + * Options for {@link Admin#electLeaders(ElectionType, Set, ElectLeadersOptions)}. + * + * The API of this class is evolving, see {@link Admin} for details. + */ +@InterfaceStability.Evolving +final public class ElectLeadersOptions extends AbstractOptions { +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/ElectLeadersResult.java b/clients/src/main/java/org/apache/kafka/clients/admin/ElectLeadersResult.java new file mode 100644 index 0000000..548c94c --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/ElectLeadersResult.java @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.clients.admin; + + +import java.util.Map; +import java.util.Optional; +import java.util.Set; + +import org.apache.kafka.common.ElectionType; +import org.apache.kafka.common.KafkaFuture; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.annotation.InterfaceStability; +import org.apache.kafka.common.internals.KafkaFutureImpl; + +/** + * The result of {@link Admin#electLeaders(ElectionType, Set, ElectLeadersOptions)} + * + * The API of this class is evolving, see {@link Admin} for details. + */ +@InterfaceStability.Evolving +final public class ElectLeadersResult { + private final KafkaFuture>> electionFuture; + + ElectLeadersResult(KafkaFuture>> electionFuture) { + this.electionFuture = electionFuture; + } + + /** + *
Get a future for the topic partitions for which a leader election was attempted. + * If the election succeeded then the value for a topic partition will be the empty Optional. + * Otherwise the election failed and the Optional will be set with the error.
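For illustration, a minimal sketch of driving a preferred-leader election and inspecting the per-partition outcome described above, assuming an existing `Admin` client and a placeholder topic:

```java
// Assumes: an existing Admin client; "my-topic" and partition 0 are placeholders.
static void electPreferredLeader(Admin admin) throws Exception {
    Set<TopicPartition> partitions = Collections.singleton(new TopicPartition("my-topic", 0));
    ElectLeadersResult result = admin.electLeaders(ElectionType.PREFERRED, partitions);
    Map<TopicPartition, Optional<Throwable>> outcome = result.partitions().get();
    outcome.forEach((tp, error) ->
        System.out.println(tp + (error.isPresent() ? " failed: " + error.get() : " elected")));
}
```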
+ */ + public KafkaFuture>> partitions() { + return electionFuture; + } + + /** + * Return a future which succeeds if all the topic elections succeed. + */ + public KafkaFuture all() { + final KafkaFutureImpl result = new KafkaFutureImpl<>(); + + partitions().whenComplete( + new KafkaFuture.BiConsumer>, Throwable>() { + @Override + public void accept(Map> topicPartitions, Throwable throwable) { + if (throwable != null) { + result.completeExceptionally(throwable); + } else { + for (Optional exception : topicPartitions.values()) { + if (exception.isPresent()) { + result.completeExceptionally(exception.get()); + return; + } + } + result.complete(null); + } + } + }); + + return result; + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/ExpireDelegationTokenOptions.java b/clients/src/main/java/org/apache/kafka/clients/admin/ExpireDelegationTokenOptions.java new file mode 100644 index 0000000..3bf9489 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/ExpireDelegationTokenOptions.java @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.clients.admin; + +import org.apache.kafka.common.annotation.InterfaceStability; + +/** + * Options for {@link Admin#expireDelegationToken(byte[], ExpireDelegationTokenOptions)}. + * + * The API of this class is evolving, see {@link Admin} for details. + */ +@InterfaceStability.Evolving +public class ExpireDelegationTokenOptions extends AbstractOptions { + private long expiryTimePeriodMs = -1L; + + public ExpireDelegationTokenOptions expiryTimePeriodMs(long expiryTimePeriodMs) { + this.expiryTimePeriodMs = expiryTimePeriodMs; + return this; + } + + public long expiryTimePeriodMs() { + return expiryTimePeriodMs; + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/ExpireDelegationTokenResult.java b/clients/src/main/java/org/apache/kafka/clients/admin/ExpireDelegationTokenResult.java new file mode 100644 index 0000000..59b1714 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/ExpireDelegationTokenResult.java @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.clients.admin; + +import org.apache.kafka.common.KafkaFuture; +import org.apache.kafka.common.annotation.InterfaceStability; + +/** + * The result of the {@link KafkaAdminClient#expireDelegationToken(byte[], ExpireDelegationTokenOptions)} call. + * + * The API of this class is evolving, see {@link Admin} for details. + */ +@InterfaceStability.Evolving +public class ExpireDelegationTokenResult { + private final KafkaFuture expiryTimestamp; + + ExpireDelegationTokenResult(KafkaFuture expiryTimestamp) { + this.expiryTimestamp = expiryTimestamp; + } + + /** + * Returns a future which yields expiry timestamp + */ + public KafkaFuture expiryTimestamp() { + return expiryTimestamp; + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/FeatureMetadata.java b/clients/src/main/java/org/apache/kafka/clients/admin/FeatureMetadata.java new file mode 100644 index 0000000..815f9e3 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/FeatureMetadata.java @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients.admin; + +import static java.util.stream.Collectors.joining; + +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; + +/** + * Encapsulates details about finalized as well as supported features. This is particularly useful + * to hold the result returned by the {@link Admin#describeFeatures(DescribeFeaturesOptions)} API. + */ +public class FeatureMetadata { + + private final Map finalizedFeatures; + + private final Optional finalizedFeaturesEpoch; + + private final Map supportedFeatures; + + FeatureMetadata(final Map finalizedFeatures, + final Optional finalizedFeaturesEpoch, + final Map supportedFeatures) { + this.finalizedFeatures = new HashMap<>(finalizedFeatures); + this.finalizedFeaturesEpoch = finalizedFeaturesEpoch; + this.supportedFeatures = new HashMap<>(supportedFeatures); + } + + /** + * Returns a map of finalized feature versions. Each entry in the map contains a key being a + * feature name and the value being a range of version levels supported by every broker in the + * cluster. + */ + public Map finalizedFeatures() { + return new HashMap<>(finalizedFeatures); + } + + /** + * The epoch for the finalized features. 
+ * If the returned value is empty, it means the finalized features are absent/unavailable. + */ + public Optional finalizedFeaturesEpoch() { + return finalizedFeaturesEpoch; + } + + /** + * Returns a map of supported feature versions. Each entry in the map contains a key being a + * feature name and the value being a range of versions supported by a particular broker in the + * cluster. + */ + public Map supportedFeatures() { + return new HashMap<>(supportedFeatures); + } + + @Override + public boolean equals(Object other) { + if (this == other) { + return true; + } + if (!(other instanceof FeatureMetadata)) { + return false; + } + + final FeatureMetadata that = (FeatureMetadata) other; + return Objects.equals(this.finalizedFeatures, that.finalizedFeatures) && + Objects.equals(this.finalizedFeaturesEpoch, that.finalizedFeaturesEpoch) && + Objects.equals(this.supportedFeatures, that.supportedFeatures); + } + + @Override + public int hashCode() { + return Objects.hash(finalizedFeatures, finalizedFeaturesEpoch, supportedFeatures); + } + + private static String mapToString(final Map featureVersionsMap) { + return String.format( + "{%s}", + featureVersionsMap + .entrySet() + .stream() + .map(entry -> String.format("(%s -> %s)", entry.getKey(), entry.getValue())) + .collect(joining(", ")) + ); + } + + @Override + public String toString() { + return String.format( + "FeatureMetadata{finalizedFeatures:%s, finalizedFeaturesEpoch:%s, supportedFeatures:%s}", + mapToString(finalizedFeatures), + finalizedFeaturesEpoch.map(Object::toString).orElse(""), + mapToString(supportedFeatures)); + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/FeatureUpdate.java b/clients/src/main/java/org/apache/kafka/clients/admin/FeatureUpdate.java new file mode 100644 index 0000000..38753af --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/FeatureUpdate.java @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients.admin; + +import java.util.Objects; + +/** + * Encapsulates details about an update to a finalized feature. + */ +public class FeatureUpdate { + private final short maxVersionLevel; + private final boolean allowDowngrade; + + /** + * @param maxVersionLevel the new maximum version level for the finalized feature. + * a value < 1 is special and indicates that the update is intended to + * delete the finalized feature, and should be accompanied by setting + * the allowDowngrade flag to true. + * @param allowDowngrade - true, if this feature update was meant to downgrade the existing + * maximum version level of the finalized feature. + * - false, otherwise. 
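For illustration, a short sketch of the constructor contract described above (values are placeholders):

```java
// Illustrative values only.
static void featureUpdateExamples() {
    FeatureUpdate raise  = new FeatureUpdate((short) 2, false); // raise the finalized max version level to 2
    FeatureUpdate delete = new FeatureUpdate((short) 0, true);  // a level < 1 deletes the feature and requires allowDowngrade
    // new FeatureUpdate((short) 0, false) would throw IllegalArgumentException
}
```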
+ */ + public FeatureUpdate(final short maxVersionLevel, final boolean allowDowngrade) { + if (maxVersionLevel < 1 && !allowDowngrade) { + throw new IllegalArgumentException(String.format( + "The allowDowngrade flag should be set when the provided maxVersionLevel:%d is < 1.", + maxVersionLevel)); + } + this.maxVersionLevel = maxVersionLevel; + this.allowDowngrade = allowDowngrade; + } + + public short maxVersionLevel() { + return maxVersionLevel; + } + + public boolean allowDowngrade() { + return allowDowngrade; + } + + @Override + public boolean equals(Object other) { + if (this == other) { + return true; + } + + if (!(other instanceof FeatureUpdate)) { + return false; + } + + final FeatureUpdate that = (FeatureUpdate) other; + return this.maxVersionLevel == that.maxVersionLevel && this.allowDowngrade == that.allowDowngrade; + } + + @Override + public int hashCode() { + return Objects.hash(maxVersionLevel, allowDowngrade); + } + + @Override + public String toString() { + return String.format("FeatureUpdate{maxVersionLevel:%d, allowDowngrade:%s}", maxVersionLevel, allowDowngrade); + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/FinalizedVersionRange.java b/clients/src/main/java/org/apache/kafka/clients/admin/FinalizedVersionRange.java new file mode 100644 index 0000000..aa0401a --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/FinalizedVersionRange.java @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients.admin; + +import java.util.Objects; + +/** + * Represents a range of version levels supported by every broker in a cluster for some feature. + */ +public class FinalizedVersionRange { + private final short minVersionLevel; + + private final short maxVersionLevel; + + /** + * Raises an exception unless the following condition is met: + * minVersionLevel >= 1 and maxVersionLevel >= 1 and maxVersionLevel >= minVersionLevel. + * + * @param minVersionLevel The minimum version level value. + * @param maxVersionLevel The maximum version level value. + * + * @throws IllegalArgumentException Raised when the condition described above is not met. 
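For illustration, a hedged sketch of reading these ranges through Admin#describeFeatures, assuming an existing `Admin` client:

```java
// Assumes: an existing Admin client and a java.util.Map import.
static void printFinalizedFeatures(Admin admin) throws Exception {
    FeatureMetadata metadata = admin.describeFeatures().featureMetadata().get();
    for (Map.Entry<String, FinalizedVersionRange> entry : metadata.finalizedFeatures().entrySet()) {
        FinalizedVersionRange range = entry.getValue();
        System.out.println(entry.getKey() + ": " + range.minVersionLevel() + ".." + range.maxVersionLevel());
    }
}
```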
+ */ + FinalizedVersionRange(final short minVersionLevel, final short maxVersionLevel) { + if (minVersionLevel < 1 || maxVersionLevel < 1 || maxVersionLevel < minVersionLevel) { + throw new IllegalArgumentException( + String.format( + "Expected minVersionLevel >= 1, maxVersionLevel >= 1 and" + + " maxVersionLevel >= minVersionLevel, but received" + + " minVersionLevel: %d, maxVersionLevel: %d", minVersionLevel, maxVersionLevel)); + } + this.minVersionLevel = minVersionLevel; + this.maxVersionLevel = maxVersionLevel; + } + + public short minVersionLevel() { + return minVersionLevel; + } + + public short maxVersionLevel() { + return maxVersionLevel; + } + + @Override + public boolean equals(Object other) { + if (this == other) { + return true; + } + if (!(other instanceof FinalizedVersionRange)) { + return false; + } + + final FinalizedVersionRange that = (FinalizedVersionRange) other; + return this.minVersionLevel == that.minVersionLevel && + this.maxVersionLevel == that.maxVersionLevel; + } + + @Override + public int hashCode() { + return Objects.hash(minVersionLevel, maxVersionLevel); + } + + @Override + public String toString() { + return String.format( + "FinalizedVersionRange[min_version_level:%d, max_version_level:%d]", + minVersionLevel, + maxVersionLevel); + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/KafkaAdminClient.java b/clients/src/main/java/org/apache/kafka/clients/admin/KafkaAdminClient.java new file mode 100644 index 0000000..c34c748 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/KafkaAdminClient.java @@ -0,0 +1,4469 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.clients.admin; + +import org.apache.kafka.clients.ApiVersions; +import org.apache.kafka.clients.ClientRequest; +import org.apache.kafka.clients.ClientResponse; +import org.apache.kafka.clients.ClientUtils; +import org.apache.kafka.clients.CommonClientConfigs; +import org.apache.kafka.clients.DefaultHostResolver; +import org.apache.kafka.clients.HostResolver; +import org.apache.kafka.clients.KafkaClient; +import org.apache.kafka.clients.NetworkClient; +import org.apache.kafka.clients.StaleMetadataException; +import org.apache.kafka.clients.admin.CreateTopicsResult.TopicMetadataAndConfig; +import org.apache.kafka.clients.admin.DeleteAclsResult.FilterResult; +import org.apache.kafka.clients.admin.DeleteAclsResult.FilterResults; +import org.apache.kafka.clients.admin.DescribeReplicaLogDirsResult.ReplicaLogDirInfo; +import org.apache.kafka.clients.admin.ListOffsetsResult.ListOffsetsResultInfo; +import org.apache.kafka.clients.admin.OffsetSpec.TimestampSpec; +import org.apache.kafka.clients.admin.internals.AbortTransactionHandler; +import org.apache.kafka.clients.admin.internals.AdminApiDriver; +import org.apache.kafka.clients.admin.internals.AdminApiHandler; +import org.apache.kafka.clients.admin.internals.AdminApiFuture; +import org.apache.kafka.clients.admin.internals.AdminApiFuture.SimpleAdminApiFuture; +import org.apache.kafka.clients.admin.internals.AdminMetadataManager; +import org.apache.kafka.clients.admin.internals.AllBrokersStrategy; +import org.apache.kafka.clients.admin.internals.AlterConsumerGroupOffsetsHandler; +import org.apache.kafka.clients.admin.internals.CoordinatorKey; +import org.apache.kafka.clients.admin.internals.DeleteConsumerGroupOffsetsHandler; +import org.apache.kafka.clients.admin.internals.DeleteConsumerGroupsHandler; +import org.apache.kafka.clients.admin.internals.DescribeConsumerGroupsHandler; +import org.apache.kafka.clients.admin.internals.DescribeProducersHandler; +import org.apache.kafka.clients.admin.internals.DescribeTransactionsHandler; +import org.apache.kafka.clients.admin.internals.ListConsumerGroupOffsetsHandler; +import org.apache.kafka.clients.admin.internals.ListTransactionsHandler; +import org.apache.kafka.clients.admin.internals.MetadataOperationContext; +import org.apache.kafka.clients.admin.internals.RemoveMembersFromConsumerGroupHandler; +import org.apache.kafka.clients.consumer.OffsetAndMetadata; +import org.apache.kafka.clients.consumer.internals.ConsumerProtocol; +import org.apache.kafka.common.Cluster; +import org.apache.kafka.common.ConsumerGroupState; +import org.apache.kafka.common.ElectionType; +import org.apache.kafka.common.KafkaException; +import org.apache.kafka.common.KafkaFuture; +import org.apache.kafka.common.Metric; +import org.apache.kafka.common.MetricName; +import org.apache.kafka.common.Node; +import org.apache.kafka.common.PartitionInfo; +import org.apache.kafka.common.TopicCollection; +import org.apache.kafka.common.TopicCollection.TopicIdCollection; +import org.apache.kafka.common.TopicCollection.TopicNameCollection; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.TopicPartitionInfo; +import org.apache.kafka.common.TopicPartitionReplica; +import org.apache.kafka.common.Uuid; +import org.apache.kafka.common.acl.AclBinding; +import org.apache.kafka.common.acl.AclBindingFilter; +import org.apache.kafka.common.acl.AclOperation; +import org.apache.kafka.common.annotation.InterfaceStability; +import org.apache.kafka.common.config.ConfigException; +import 
org.apache.kafka.common.config.ConfigResource; +import org.apache.kafka.common.errors.ApiException; +import org.apache.kafka.common.errors.AuthenticationException; +import org.apache.kafka.common.errors.DisconnectException; +import org.apache.kafka.common.errors.InvalidRequestException; +import org.apache.kafka.common.errors.InvalidTopicException; +import org.apache.kafka.common.errors.KafkaStorageException; +import org.apache.kafka.common.errors.RetriableException; +import org.apache.kafka.common.errors.ThrottlingQuotaExceededException; +import org.apache.kafka.common.errors.TimeoutException; +import org.apache.kafka.common.errors.UnacceptableCredentialException; +import org.apache.kafka.common.errors.UnknownServerException; +import org.apache.kafka.common.errors.UnknownTopicOrPartitionException; +import org.apache.kafka.common.errors.UnsupportedSaslMechanismException; +import org.apache.kafka.common.errors.UnsupportedVersionException; +import org.apache.kafka.common.internals.KafkaFutureImpl; +import org.apache.kafka.common.message.AlterPartitionReassignmentsRequestData; +import org.apache.kafka.common.message.AlterPartitionReassignmentsRequestData.ReassignableTopic; +import org.apache.kafka.common.message.AlterReplicaLogDirsRequestData; +import org.apache.kafka.common.message.AlterReplicaLogDirsRequestData.AlterReplicaLogDir; +import org.apache.kafka.common.message.AlterReplicaLogDirsRequestData.AlterReplicaLogDirTopic; +import org.apache.kafka.common.message.AlterReplicaLogDirsResponseData.AlterReplicaLogDirPartitionResult; +import org.apache.kafka.common.message.AlterReplicaLogDirsResponseData.AlterReplicaLogDirTopicResult; +import org.apache.kafka.common.message.AlterUserScramCredentialsRequestData; +import org.apache.kafka.common.message.ApiVersionsResponseData.FinalizedFeatureKey; +import org.apache.kafka.common.message.ApiVersionsResponseData.SupportedFeatureKey; +import org.apache.kafka.common.message.CreateAclsRequestData; +import org.apache.kafka.common.message.CreateAclsRequestData.AclCreation; +import org.apache.kafka.common.message.CreateAclsResponseData.AclCreationResult; +import org.apache.kafka.common.message.CreateDelegationTokenRequestData; +import org.apache.kafka.common.message.CreateDelegationTokenRequestData.CreatableRenewers; +import org.apache.kafka.common.message.CreateDelegationTokenResponseData; +import org.apache.kafka.common.message.CreatePartitionsRequestData; +import org.apache.kafka.common.message.CreatePartitionsRequestData.CreatePartitionsAssignment; +import org.apache.kafka.common.message.CreatePartitionsRequestData.CreatePartitionsTopic; +import org.apache.kafka.common.message.CreatePartitionsRequestData.CreatePartitionsTopicCollection; +import org.apache.kafka.common.message.CreatePartitionsResponseData.CreatePartitionsTopicResult; +import org.apache.kafka.common.message.CreateTopicsRequestData; +import org.apache.kafka.common.message.CreateTopicsRequestData.CreatableTopicCollection; +import org.apache.kafka.common.message.CreateTopicsResponseData.CreatableTopicConfigs; +import org.apache.kafka.common.message.CreateTopicsResponseData.CreatableTopicResult; +import org.apache.kafka.common.message.DeleteAclsRequestData; +import org.apache.kafka.common.message.DeleteAclsRequestData.DeleteAclsFilter; +import org.apache.kafka.common.message.DeleteAclsResponseData; +import org.apache.kafka.common.message.DeleteAclsResponseData.DeleteAclsFilterResult; +import org.apache.kafka.common.message.DeleteAclsResponseData.DeleteAclsMatchingAcl; +import 
org.apache.kafka.common.message.DeleteRecordsRequestData; +import org.apache.kafka.common.message.DeleteRecordsRequestData.DeleteRecordsPartition; +import org.apache.kafka.common.message.DeleteRecordsRequestData.DeleteRecordsTopic; +import org.apache.kafka.common.message.DeleteRecordsResponseData; +import org.apache.kafka.common.message.DeleteRecordsResponseData.DeleteRecordsTopicResult; +import org.apache.kafka.common.message.DeleteTopicsRequestData; +import org.apache.kafka.common.message.DeleteTopicsRequestData.DeleteTopicState; +import org.apache.kafka.common.message.DeleteTopicsResponseData.DeletableTopicResult; +import org.apache.kafka.common.message.DescribeClusterRequestData; +import org.apache.kafka.common.message.DescribeConfigsRequestData; +import org.apache.kafka.common.message.DescribeConfigsResponseData; +import org.apache.kafka.common.message.DescribeLogDirsRequestData; +import org.apache.kafka.common.message.DescribeLogDirsRequestData.DescribableLogDirTopic; +import org.apache.kafka.common.message.DescribeLogDirsResponseData; +import org.apache.kafka.common.message.DescribeUserScramCredentialsRequestData; +import org.apache.kafka.common.message.DescribeUserScramCredentialsRequestData.UserName; +import org.apache.kafka.common.message.DescribeUserScramCredentialsResponseData; +import org.apache.kafka.common.message.ExpireDelegationTokenRequestData; +import org.apache.kafka.common.message.LeaveGroupRequestData.MemberIdentity; +import org.apache.kafka.common.message.ListGroupsRequestData; +import org.apache.kafka.common.message.ListGroupsResponseData; +import org.apache.kafka.common.message.ListOffsetsRequestData.ListOffsetsPartition; +import org.apache.kafka.common.message.ListOffsetsRequestData.ListOffsetsTopic; +import org.apache.kafka.common.message.ListOffsetsResponseData.ListOffsetsPartitionResponse; +import org.apache.kafka.common.message.ListOffsetsResponseData.ListOffsetsTopicResponse; +import org.apache.kafka.common.message.ListPartitionReassignmentsRequestData; +import org.apache.kafka.common.message.MetadataRequestData; +import org.apache.kafka.common.message.RenewDelegationTokenRequestData; +import org.apache.kafka.common.message.UnregisterBrokerRequestData; +import org.apache.kafka.common.message.UpdateFeaturesRequestData; +import org.apache.kafka.common.message.UpdateFeaturesResponseData.UpdatableFeatureResult; +import org.apache.kafka.common.metrics.JmxReporter; +import org.apache.kafka.common.metrics.KafkaMetricsContext; +import org.apache.kafka.common.metrics.MetricConfig; +import org.apache.kafka.common.metrics.Metrics; +import org.apache.kafka.common.metrics.MetricsContext; +import org.apache.kafka.common.metrics.MetricsReporter; +import org.apache.kafka.common.metrics.Sensor; +import org.apache.kafka.common.network.ChannelBuilder; +import org.apache.kafka.common.network.Selector; +import org.apache.kafka.common.protocol.Errors; +import org.apache.kafka.common.quota.ClientQuotaAlteration; +import org.apache.kafka.common.quota.ClientQuotaEntity; +import org.apache.kafka.common.quota.ClientQuotaFilter; +import org.apache.kafka.common.requests.AbstractRequest; +import org.apache.kafka.common.requests.AbstractResponse; +import org.apache.kafka.common.requests.AlterClientQuotasRequest; +import org.apache.kafka.common.requests.AlterClientQuotasResponse; +import org.apache.kafka.common.requests.AlterConfigsRequest; +import org.apache.kafka.common.requests.AlterConfigsResponse; +import org.apache.kafka.common.requests.AlterPartitionReassignmentsRequest; +import 
org.apache.kafka.common.requests.AlterPartitionReassignmentsResponse; +import org.apache.kafka.common.requests.AlterReplicaLogDirsRequest; +import org.apache.kafka.common.requests.AlterReplicaLogDirsResponse; +import org.apache.kafka.common.requests.AlterUserScramCredentialsRequest; +import org.apache.kafka.common.requests.AlterUserScramCredentialsResponse; +import org.apache.kafka.common.requests.ApiError; +import org.apache.kafka.common.requests.ApiVersionsRequest; +import org.apache.kafka.common.requests.ApiVersionsResponse; +import org.apache.kafka.common.requests.CreateAclsRequest; +import org.apache.kafka.common.requests.CreateAclsResponse; +import org.apache.kafka.common.requests.CreateDelegationTokenRequest; +import org.apache.kafka.common.requests.CreateDelegationTokenResponse; +import org.apache.kafka.common.requests.CreatePartitionsRequest; +import org.apache.kafka.common.requests.CreatePartitionsResponse; +import org.apache.kafka.common.requests.CreateTopicsRequest; +import org.apache.kafka.common.requests.CreateTopicsResponse; +import org.apache.kafka.common.requests.DeleteAclsRequest; +import org.apache.kafka.common.requests.DeleteAclsResponse; +import org.apache.kafka.common.requests.DeleteRecordsRequest; +import org.apache.kafka.common.requests.DeleteRecordsResponse; +import org.apache.kafka.common.requests.DeleteTopicsRequest; +import org.apache.kafka.common.requests.DeleteTopicsResponse; +import org.apache.kafka.common.requests.DescribeAclsRequest; +import org.apache.kafka.common.requests.DescribeAclsResponse; +import org.apache.kafka.common.requests.DescribeClientQuotasRequest; +import org.apache.kafka.common.requests.DescribeClientQuotasResponse; +import org.apache.kafka.common.requests.DescribeClusterRequest; +import org.apache.kafka.common.requests.DescribeClusterResponse; +import org.apache.kafka.common.requests.DescribeConfigsRequest; +import org.apache.kafka.common.requests.DescribeConfigsResponse; +import org.apache.kafka.common.requests.DescribeDelegationTokenRequest; +import org.apache.kafka.common.requests.DescribeDelegationTokenResponse; +import org.apache.kafka.common.requests.DescribeLogDirsRequest; +import org.apache.kafka.common.requests.DescribeLogDirsResponse; +import org.apache.kafka.common.requests.DescribeUserScramCredentialsRequest; +import org.apache.kafka.common.requests.DescribeUserScramCredentialsResponse; +import org.apache.kafka.common.requests.ElectLeadersRequest; +import org.apache.kafka.common.requests.ElectLeadersResponse; +import org.apache.kafka.common.requests.ExpireDelegationTokenRequest; +import org.apache.kafka.common.requests.ExpireDelegationTokenResponse; +import org.apache.kafka.common.requests.IncrementalAlterConfigsRequest; +import org.apache.kafka.common.requests.IncrementalAlterConfigsResponse; +import org.apache.kafka.common.requests.ListGroupsRequest; +import org.apache.kafka.common.requests.ListGroupsResponse; +import org.apache.kafka.common.requests.ListOffsetsRequest; +import org.apache.kafka.common.requests.ListOffsetsResponse; +import org.apache.kafka.common.requests.ListPartitionReassignmentsRequest; +import org.apache.kafka.common.requests.ListPartitionReassignmentsResponse; +import org.apache.kafka.common.requests.MetadataRequest; +import org.apache.kafka.common.requests.MetadataResponse; +import org.apache.kafka.common.requests.RenewDelegationTokenRequest; +import org.apache.kafka.common.requests.RenewDelegationTokenResponse; +import org.apache.kafka.common.requests.UnregisterBrokerRequest; +import 
org.apache.kafka.common.requests.UnregisterBrokerResponse; +import org.apache.kafka.common.requests.UpdateFeaturesRequest; +import org.apache.kafka.common.requests.UpdateFeaturesResponse; +import org.apache.kafka.common.security.auth.KafkaPrincipal; +import org.apache.kafka.common.security.scram.internals.ScramFormatter; +import org.apache.kafka.common.security.token.delegation.DelegationToken; +import org.apache.kafka.common.security.token.delegation.TokenInformation; +import org.apache.kafka.common.utils.AppInfoParser; +import org.apache.kafka.common.utils.KafkaThread; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.common.utils.Utils; +import org.slf4j.Logger; + +import java.net.InetSocketAddress; +import java.security.InvalidKeyException; +import java.security.NoSuchAlgorithmException; +import java.time.Duration; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.Comparator; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Iterator; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.Set; +import java.util.TreeMap; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicLong; +import java.util.function.Function; +import java.util.function.Predicate; +import java.util.function.Supplier; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import static org.apache.kafka.common.message.AlterPartitionReassignmentsRequestData.ReassignablePartition; +import static org.apache.kafka.common.message.AlterPartitionReassignmentsResponseData.ReassignablePartitionResponse; +import static org.apache.kafka.common.message.AlterPartitionReassignmentsResponseData.ReassignableTopicResponse; +import static org.apache.kafka.common.message.ListPartitionReassignmentsRequestData.ListPartitionReassignmentsTopics; +import static org.apache.kafka.common.message.ListPartitionReassignmentsResponseData.OngoingPartitionReassignment; +import static org.apache.kafka.common.message.ListPartitionReassignmentsResponseData.OngoingTopicReassignment; +import static org.apache.kafka.common.requests.MetadataRequest.convertToMetadataRequestTopic; +import static org.apache.kafka.common.requests.MetadataRequest.convertTopicIdsToMetadataRequestTopic; +import static org.apache.kafka.common.utils.Utils.closeQuietly; + +/** + * The default implementation of {@link Admin}. An instance of this class is created by invoking one of the + * {@code create()} methods in {@code AdminClient}. Users should not refer to this class directly. + * + *
+ * This class is thread-safe. + *
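As a usage note, a hedged sketch of sharing one thread-safe instance across threads (the executor and the two calls are placeholders):

```java
// Assumes: an existing Admin client and java.util.concurrent imports.
static void shareAcrossThreads(Admin admin) throws Exception {
    ExecutorService pool = Executors.newFixedThreadPool(2);
    Future<?> topics  = pool.submit(() -> admin.listTopics().names().get());
    Future<?> cluster = pool.submit(() -> admin.describeCluster().nodes().get());
    topics.get();
    cluster.get();
    pool.shutdown();
}
```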
+ * The API of this class is evolving, see {@link Admin} for details. + */ +@InterfaceStability.Evolving +public class KafkaAdminClient extends AdminClient { + + /** + * The next integer to use to name a KafkaAdminClient which the user hasn't specified an explicit name for. + */ + private static final AtomicInteger ADMIN_CLIENT_ID_SEQUENCE = new AtomicInteger(1); + + /** + * The prefix to use for the JMX metrics for this class + */ + private static final String JMX_PREFIX = "kafka.admin.client"; + + /** + * An invalid shutdown time which indicates that a shutdown has not yet been performed. + */ + private static final long INVALID_SHUTDOWN_TIME = -1; + + /** + * Thread name prefix for admin client network thread + */ + static final String NETWORK_THREAD_PREFIX = "kafka-admin-client-thread"; + + private final Logger log; + private final LogContext logContext; + + /** + * The default timeout to use for an operation. + */ + private final int defaultApiTimeoutMs; + + /** + * The timeout to use for a single request. + */ + private final int requestTimeoutMs; + + /** + * The name of this AdminClient instance. + */ + private final String clientId; + + /** + * Provides the time. + */ + private final Time time; + + /** + * The cluster metadata manager used by the KafkaClient. + */ + private final AdminMetadataManager metadataManager; + + /** + * The metrics for this KafkaAdminClient. + */ + private final Metrics metrics; + + /** + * The network client to use. + */ + private final KafkaClient client; + + /** + * The runnable used in the service thread for this admin client. + */ + private final AdminClientRunnable runnable; + + /** + * The network service thread for this admin client. + */ + private final Thread thread; + + /** + * During a close operation, this is the time at which we will time out all pending operations + * and force the RPC thread to exit. If the admin client is not closing, this will be 0. + */ + private final AtomicLong hardShutdownTimeMs = new AtomicLong(INVALID_SHUTDOWN_TIME); + + /** + * A factory which creates TimeoutProcessors for the RPC thread. + */ + private final TimeoutProcessorFactory timeoutProcessorFactory; + + private final int maxRetries; + + private final long retryBackoffMs; + + /** + * Get or create a list value from a map. + * + * @param map The map to get or create the element from. + * @param key The key. + * @param The key type. + * @param The value type. + * @return The list value. + */ + static List getOrCreateListValue(Map> map, K key) { + return map.computeIfAbsent(key, k -> new LinkedList<>()); + } + + /** + * Send an exception to every element in a collection of KafkaFutureImpls. + * + * @param futures The collection of KafkaFutureImpl objects. + * @param exc The exception + * @param The KafkaFutureImpl result type. + */ + private static void completeAllExceptionally(Collection> futures, Throwable exc) { + completeAllExceptionally(futures.stream(), exc); + } + + /** + * Send an exception to all futures in the provided stream + * + * @param futures The stream of KafkaFutureImpl objects. + * @param exc The exception + * @param The KafkaFutureImpl result type. + */ + private static void completeAllExceptionally(Stream> futures, Throwable exc) { + futures.forEach(future -> future.completeExceptionally(exc)); + } + + /** + * Get the current time remaining before a deadline as an integer. + * + * @param now The current time in milliseconds. + * @param deadlineMs The deadline time in milliseconds. + * @return The time delta in milliseconds. 
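For illustration, a minimal configuration sketch matching the client-id and retry handling described above (all values and the broker address are placeholders):

```java
// Assumes: a reachable broker at the placeholder address.
static void createConfiguredClient() throws Exception {
    Properties props = new Properties();
    props.put(AdminClientConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9092");
    props.put(AdminClientConfig.CLIENT_ID_CONFIG, "my-admin");      // otherwise an "adminclient-<N>" id is generated
    props.put(AdminClientConfig.RETRIES_CONFIG, 5);                 // bounds the per-call retry loop
    props.put(AdminClientConfig.RETRY_BACKOFF_MS_CONFIG, 100L);     // pause between retries
    try (Admin admin = Admin.create(props)) {
        System.out.println(admin.describeCluster().clusterId().get());
    }
}
```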
+ */ + static int calcTimeoutMsRemainingAsInt(long now, long deadlineMs) { + long deltaMs = deadlineMs - now; + if (deltaMs > Integer.MAX_VALUE) + deltaMs = Integer.MAX_VALUE; + else if (deltaMs < Integer.MIN_VALUE) + deltaMs = Integer.MIN_VALUE; + return (int) deltaMs; + } + + /** + * Generate the client id based on the configuration. + * + * @param config The configuration + * + * @return The client id + */ + static String generateClientId(AdminClientConfig config) { + String clientId = config.getString(AdminClientConfig.CLIENT_ID_CONFIG); + if (!clientId.isEmpty()) + return clientId; + return "adminclient-" + ADMIN_CLIENT_ID_SEQUENCE.getAndIncrement(); + } + + /** + * Get the deadline for a particular call. + * + * @param now The current time in milliseconds. + * @param optionTimeoutMs The timeout option given by the user. + * + * @return The deadline in milliseconds. + */ + private long calcDeadlineMs(long now, Integer optionTimeoutMs) { + if (optionTimeoutMs != null) + return now + Math.max(0, optionTimeoutMs); + return now + defaultApiTimeoutMs; + } + + /** + * Pretty-print an exception. + * + * @param throwable The exception. + * + * @return A compact human-readable string. + */ + static String prettyPrintException(Throwable throwable) { + if (throwable == null) + return "Null exception."; + if (throwable.getMessage() != null) { + return throwable.getClass().getSimpleName() + ": " + throwable.getMessage(); + } + return throwable.getClass().getSimpleName(); + } + + static KafkaAdminClient createInternal(AdminClientConfig config, TimeoutProcessorFactory timeoutProcessorFactory) { + return createInternal(config, timeoutProcessorFactory, null); + } + + static KafkaAdminClient createInternal(AdminClientConfig config, TimeoutProcessorFactory timeoutProcessorFactory, + HostResolver hostResolver) { + Metrics metrics = null; + NetworkClient networkClient = null; + Time time = Time.SYSTEM; + String clientId = generateClientId(config); + ChannelBuilder channelBuilder = null; + Selector selector = null; + ApiVersions apiVersions = new ApiVersions(); + LogContext logContext = createLogContext(clientId); + + try { + // Since we only request node information, it's safe to pass true for allowAutoTopicCreation (and it + // simplifies communication with older brokers) + AdminMetadataManager metadataManager = new AdminMetadataManager(logContext, + config.getLong(AdminClientConfig.RETRY_BACKOFF_MS_CONFIG), + config.getLong(AdminClientConfig.METADATA_MAX_AGE_CONFIG)); + List addresses = ClientUtils.parseAndValidateAddresses( + config.getList(AdminClientConfig.BOOTSTRAP_SERVERS_CONFIG), + config.getString(AdminClientConfig.CLIENT_DNS_LOOKUP_CONFIG)); + metadataManager.update(Cluster.bootstrap(addresses), time.milliseconds()); + List reporters = config.getConfiguredInstances(AdminClientConfig.METRIC_REPORTER_CLASSES_CONFIG, + MetricsReporter.class, + Collections.singletonMap(AdminClientConfig.CLIENT_ID_CONFIG, clientId)); + Map metricTags = Collections.singletonMap("client-id", clientId); + MetricConfig metricConfig = new MetricConfig().samples(config.getInt(AdminClientConfig.METRICS_NUM_SAMPLES_CONFIG)) + .timeWindow(config.getLong(AdminClientConfig.METRICS_SAMPLE_WINDOW_MS_CONFIG), TimeUnit.MILLISECONDS) + .recordLevel(Sensor.RecordingLevel.forName(config.getString(AdminClientConfig.METRICS_RECORDING_LEVEL_CONFIG))) + .tags(metricTags); + JmxReporter jmxReporter = new JmxReporter(); + jmxReporter.configure(config.originals()); + reporters.add(jmxReporter); + MetricsContext metricsContext = new 
KafkaMetricsContext(JMX_PREFIX, + config.originalsWithPrefix(CommonClientConfigs.METRICS_CONTEXT_PREFIX)); + metrics = new Metrics(metricConfig, reporters, time, metricsContext); + String metricGrpPrefix = "admin-client"; + channelBuilder = ClientUtils.createChannelBuilder(config, time, logContext); + selector = new Selector(config.getLong(AdminClientConfig.CONNECTIONS_MAX_IDLE_MS_CONFIG), + metrics, time, metricGrpPrefix, channelBuilder, logContext); + networkClient = new NetworkClient( + metadataManager.updater(), + null, + selector, + clientId, + 1, + config.getLong(AdminClientConfig.RECONNECT_BACKOFF_MS_CONFIG), + config.getLong(AdminClientConfig.RECONNECT_BACKOFF_MAX_MS_CONFIG), + config.getInt(AdminClientConfig.SEND_BUFFER_CONFIG), + config.getInt(AdminClientConfig.RECEIVE_BUFFER_CONFIG), + (int) TimeUnit.HOURS.toMillis(1), + config.getLong(AdminClientConfig.SOCKET_CONNECTION_SETUP_TIMEOUT_MS_CONFIG), + config.getLong(AdminClientConfig.SOCKET_CONNECTION_SETUP_TIMEOUT_MAX_MS_CONFIG), + time, + true, + apiVersions, + null, + logContext, + (hostResolver == null) ? new DefaultHostResolver() : hostResolver); + return new KafkaAdminClient(config, clientId, time, metadataManager, metrics, networkClient, + timeoutProcessorFactory, logContext); + } catch (Throwable exc) { + closeQuietly(metrics, "Metrics"); + closeQuietly(networkClient, "NetworkClient"); + closeQuietly(selector, "Selector"); + closeQuietly(channelBuilder, "ChannelBuilder"); + throw new KafkaException("Failed to create new KafkaAdminClient", exc); + } + } + + static KafkaAdminClient createInternal(AdminClientConfig config, + AdminMetadataManager metadataManager, + KafkaClient client, + Time time) { + Metrics metrics = null; + String clientId = generateClientId(config); + + try { + metrics = new Metrics(new MetricConfig(), new LinkedList<>(), time); + LogContext logContext = createLogContext(clientId); + return new KafkaAdminClient(config, clientId, time, metadataManager, metrics, + client, null, logContext); + } catch (Throwable exc) { + closeQuietly(metrics, "Metrics"); + throw new KafkaException("Failed to create new KafkaAdminClient", exc); + } + } + + static LogContext createLogContext(String clientId) { + return new LogContext("[AdminClient clientId=" + clientId + "] "); + } + + private KafkaAdminClient(AdminClientConfig config, + String clientId, + Time time, + AdminMetadataManager metadataManager, + Metrics metrics, + KafkaClient client, + TimeoutProcessorFactory timeoutProcessorFactory, + LogContext logContext) { + this.clientId = clientId; + this.log = logContext.logger(KafkaAdminClient.class); + this.logContext = logContext; + this.requestTimeoutMs = config.getInt(AdminClientConfig.REQUEST_TIMEOUT_MS_CONFIG); + this.defaultApiTimeoutMs = configureDefaultApiTimeoutMs(config); + this.time = time; + this.metadataManager = metadataManager; + this.metrics = metrics; + this.client = client; + this.runnable = new AdminClientRunnable(); + String threadName = NETWORK_THREAD_PREFIX + " | " + clientId; + this.thread = new KafkaThread(threadName, runnable, true); + this.timeoutProcessorFactory = (timeoutProcessorFactory == null) ? 
+ new TimeoutProcessorFactory() : timeoutProcessorFactory; + this.maxRetries = config.getInt(AdminClientConfig.RETRIES_CONFIG); + this.retryBackoffMs = config.getLong(AdminClientConfig.RETRY_BACKOFF_MS_CONFIG); + config.logUnused(); + AppInfoParser.registerAppInfo(JMX_PREFIX, clientId, metrics, time.milliseconds()); + log.debug("Kafka admin client initialized"); + thread.start(); + } + + /** + * If a default.api.timeout.ms has been explicitly specified, raise an error if it conflicts with request.timeout.ms. + * If no default.api.timeout.ms has been configured, then set its value as the max of the default and request.timeout.ms. Also we should probably log a warning. + * Otherwise, use the provided values for both configurations. + * + * @param config The configuration + */ + private int configureDefaultApiTimeoutMs(AdminClientConfig config) { + int requestTimeoutMs = config.getInt(AdminClientConfig.REQUEST_TIMEOUT_MS_CONFIG); + int defaultApiTimeoutMs = config.getInt(AdminClientConfig.DEFAULT_API_TIMEOUT_MS_CONFIG); + + if (defaultApiTimeoutMs < requestTimeoutMs) { + if (config.originals().containsKey(AdminClientConfig.DEFAULT_API_TIMEOUT_MS_CONFIG)) { + throw new ConfigException("The specified value of " + AdminClientConfig.DEFAULT_API_TIMEOUT_MS_CONFIG + + " must be no smaller than the value of " + AdminClientConfig.REQUEST_TIMEOUT_MS_CONFIG + "."); + } else { + log.warn("Overriding the default value for {} ({}) with the explicitly configured request timeout {}", + AdminClientConfig.DEFAULT_API_TIMEOUT_MS_CONFIG, this.defaultApiTimeoutMs, + requestTimeoutMs); + return requestTimeoutMs; + } + } + return defaultApiTimeoutMs; + } + + @Override + public void close(Duration timeout) { + long waitTimeMs = timeout.toMillis(); + if (waitTimeMs < 0) + throw new IllegalArgumentException("The timeout cannot be negative."); + waitTimeMs = Math.min(TimeUnit.DAYS.toMillis(365), waitTimeMs); // Limit the timeout to a year. + long now = time.milliseconds(); + long newHardShutdownTimeMs = now + waitTimeMs; + long prev = INVALID_SHUTDOWN_TIME; + while (true) { + if (hardShutdownTimeMs.compareAndSet(prev, newHardShutdownTimeMs)) { + if (prev == INVALID_SHUTDOWN_TIME) { + log.debug("Initiating close operation."); + } else { + log.debug("Moving hard shutdown time forward."); + } + client.wakeup(); // Wake the thread, if it is blocked inside poll(). + break; + } + prev = hardShutdownTimeMs.get(); + if (prev < newHardShutdownTimeMs) { + log.debug("Hard shutdown time is already earlier than requested."); + newHardShutdownTimeMs = prev; + break; + } + } + if (log.isDebugEnabled()) { + long deltaMs = Math.max(0, newHardShutdownTimeMs - time.milliseconds()); + log.debug("Waiting for the I/O thread to exit. Hard shutdown in {} ms.", deltaMs); + } + try { + // close() can be called by AdminClient thread when it invokes callback. That will + // cause deadlock, so check for that condition. + if (Thread.currentThread() != thread) { + // Wait for the thread to be joined. + thread.join(waitTimeMs); + } + log.debug("Kafka admin client closed."); + } catch (InterruptedException e) { + log.debug("Interrupted while joining I/O thread", e); + Thread.currentThread().interrupt(); + } + } + + /** + * An interface for providing a node for a call. 
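For illustration, a hedged sketch of the timeout interplay and bounded close described above (placeholder values):

```java
// Assumes: a reachable broker at the placeholder address and a java.time.Duration import.
static void timeoutsAndBoundedClose() {
    Properties props = new Properties();
    props.put(AdminClientConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9092");
    props.put(AdminClientConfig.REQUEST_TIMEOUT_MS_CONFIG, 30_000);     // per-request timeout
    props.put(AdminClientConfig.DEFAULT_API_TIMEOUT_MS_CONFIG, 60_000); // per-operation timeout; an explicit value
                                                                        // below request.timeout.ms fails with ConfigException
    Admin admin = Admin.create(props);
    admin.close(Duration.ofSeconds(30)); // wait up to 30s for in-flight calls before forcing the I/O thread to exit
}
```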
+ */ + private interface NodeProvider { + Node provide(); + } + + private class MetadataUpdateNodeIdProvider implements NodeProvider { + @Override + public Node provide() { + return client.leastLoadedNode(time.milliseconds()); + } + } + + private class ConstantNodeIdProvider implements NodeProvider { + private final int nodeId; + + ConstantNodeIdProvider(int nodeId) { + this.nodeId = nodeId; + } + + @Override + public Node provide() { + if (metadataManager.isReady() && + (metadataManager.nodeById(nodeId) != null)) { + return metadataManager.nodeById(nodeId); + } + // If we can't find the node with the given constant ID, we schedule a + // metadata update and hope it appears. This behavior is useful for avoiding + // flaky behavior in tests when the cluster is starting up and not all nodes + // have appeared. + metadataManager.requestUpdate(); + return null; + } + } + + /** + * Provides the controller node. + */ + private class ControllerNodeProvider implements NodeProvider { + @Override + public Node provide() { + if (metadataManager.isReady() && + (metadataManager.controller() != null)) { + return metadataManager.controller(); + } + metadataManager.requestUpdate(); + return null; + } + } + + /** + * Provides the least loaded node. + */ + private class LeastLoadedNodeProvider implements NodeProvider { + @Override + public Node provide() { + if (metadataManager.isReady()) { + // This may return null if all nodes are busy. + // In that case, we will postpone node assignment. + return client.leastLoadedNode(time.milliseconds()); + } + metadataManager.requestUpdate(); + return null; + } + } + + abstract class Call { + private final boolean internal; + private final String callName; + private final long deadlineMs; + private final NodeProvider nodeProvider; + protected int tries; + private Node curNode = null; + private long nextAllowedTryMs; + + Call(boolean internal, + String callName, + long nextAllowedTryMs, + int tries, + long deadlineMs, + NodeProvider nodeProvider + ) { + this.internal = internal; + this.callName = callName; + this.nextAllowedTryMs = nextAllowedTryMs; + this.tries = tries; + this.deadlineMs = deadlineMs; + this.nodeProvider = nodeProvider; + } + + Call(boolean internal, String callName, long deadlineMs, NodeProvider nodeProvider) { + this(internal, callName, 0, 0, deadlineMs, nodeProvider); + } + + Call(String callName, long deadlineMs, NodeProvider nodeProvider) { + this(false, callName, 0, 0, deadlineMs, nodeProvider); + } + + Call(String callName, long nextAllowedTryMs, int tries, long deadlineMs, NodeProvider nodeProvider) { + this(false, callName, nextAllowedTryMs, tries, deadlineMs, nodeProvider); + } + + protected Node curNode() { + return curNode; + } + + /** + * Handle a failure. + * + * Depending on what the exception is and how many times we have already tried, we may choose to + * fail the Call, or retry it. It is important to print the stack traces here in some cases, + * since they are not necessarily preserved in ApiVersionException objects. + * + * @param now The current time in milliseconds. + * @param throwable The failure exception. + */ + final void fail(long now, Throwable throwable) { + if (curNode != null) { + runnable.nodeReadyDeadlines.remove(curNode); + curNode = null; + } + // If the admin client is closing, we can't retry. + if (runnable.closing) { + handleFailure(throwable); + return; + } + // If this is an UnsupportedVersionException that we can retry, do so. 
Note that a + // protocol downgrade will not count against the total number of retries we get for + // this RPC. That is why 'tries' is not incremented. + if ((throwable instanceof UnsupportedVersionException) && + handleUnsupportedVersionException((UnsupportedVersionException) throwable)) { + log.debug("{} attempting protocol downgrade and then retry.", this); + runnable.pendingCalls.add(this); + return; + } + tries++; + nextAllowedTryMs = now + retryBackoffMs; + + // If the call has timed out, fail. + if (calcTimeoutMsRemainingAsInt(now, deadlineMs) <= 0) { + handleTimeoutFailure(now, throwable); + return; + } + // If the exception is not retriable, fail. + if (!(throwable instanceof RetriableException)) { + if (log.isDebugEnabled()) { + log.debug("{} failed with non-retriable exception after {} attempt(s)", this, tries, + new Exception(prettyPrintException(throwable))); + } + handleFailure(throwable); + return; + } + // If we are out of retries, fail. + if (tries > maxRetries) { + handleTimeoutFailure(now, throwable); + return; + } + if (log.isDebugEnabled()) { + log.debug("{} failed: {}. Beginning retry #{}", + this, prettyPrintException(throwable), tries); + } + maybeRetry(now, throwable); + } + + void maybeRetry(long now, Throwable throwable) { + runnable.pendingCalls.add(this); + } + + private void handleTimeoutFailure(long now, Throwable cause) { + if (log.isDebugEnabled()) { + log.debug("{} timed out at {} after {} attempt(s)", this, now, tries, + new Exception(prettyPrintException(cause))); + } + if (cause instanceof TimeoutException) { + handleFailure(cause); + } else { + handleFailure(new TimeoutException(this + " timed out at " + now + + " after " + tries + " attempt(s)", cause)); + } + } + + /** + * Create an AbstractRequest.Builder for this Call. + * + * @param timeoutMs The timeout in milliseconds. + * + * @return The AbstractRequest builder. + */ + abstract AbstractRequest.Builder createRequest(int timeoutMs); + + /** + * Process the call response. + * + * @param abstractResponse The AbstractResponse. + * + */ + abstract void handleResponse(AbstractResponse abstractResponse); + + /** + * Handle a failure. This will only be called if the failure exception was not + * retriable, or if we hit a timeout. + * + * @param throwable The exception. + */ + abstract void handleFailure(Throwable throwable); + + /** + * Handle an UnsupportedVersionException. + * + * @param exception The exception. + * + * @return True if the exception can be handled; false otherwise. + */ + boolean handleUnsupportedVersionException(UnsupportedVersionException exception) { + return false; + } + + @Override + public String toString() { + return "Call(callName=" + callName + ", deadlineMs=" + deadlineMs + + ", tries=" + tries + ", nextAllowedTryMs=" + nextAllowedTryMs + ")"; + } + + public boolean isInternal() { + return internal; + } + } + + static class TimeoutProcessorFactory { + TimeoutProcessor create(long now) { + return new TimeoutProcessor(now); + } + } + + static class TimeoutProcessor { + /** + * The current time in milliseconds. + */ + private final long now; + + /** + * The number of milliseconds until the next timeout. + */ + private int nextTimeoutMs; + + /** + * Create a new timeout processor. + * + * @param now The current time in milliseconds since the epoch. + */ + TimeoutProcessor(long now) { + this.now = now; + this.nextTimeoutMs = Integer.MAX_VALUE; + } + + /** + * Check for calls which have timed out. + * Timed out calls will be removed and failed. 
+ * The remaining milliseconds until the next timeout will be updated. + * + * @param calls The collection of calls. + * + * @return The number of calls which were timed out. + */ + int handleTimeouts(Collection calls, String msg) { + int numTimedOut = 0; + for (Iterator iter = calls.iterator(); iter.hasNext(); ) { + Call call = iter.next(); + int remainingMs = calcTimeoutMsRemainingAsInt(now, call.deadlineMs); + if (remainingMs < 0) { + call.fail(now, new TimeoutException(msg + " Call: " + call.callName)); + iter.remove(); + numTimedOut++; + } else { + nextTimeoutMs = Math.min(nextTimeoutMs, remainingMs); + } + } + return numTimedOut; + } + + /** + * Check whether a call should be timed out. + * The remaining milliseconds until the next timeout will be updated. + * + * @param call The call. + * + * @return True if the call should be timed out. + */ + boolean callHasExpired(Call call) { + int remainingMs = calcTimeoutMsRemainingAsInt(now, call.deadlineMs); + if (remainingMs < 0) + return true; + nextTimeoutMs = Math.min(nextTimeoutMs, remainingMs); + return false; + } + + int nextTimeoutMs() { + return nextTimeoutMs; + } + } + + private final class AdminClientRunnable implements Runnable { + /** + * Calls which have not yet been assigned to a node. + * Only accessed from this thread. + */ + private final ArrayList pendingCalls = new ArrayList<>(); + + /** + * Maps nodes to calls that we want to send. + * Only accessed from this thread. + */ + private final Map> callsToSend = new HashMap<>(); + + /** + * Maps node ID strings to calls that have been sent. + * Only accessed from this thread. + */ + private final Map callsInFlight = new HashMap<>(); + + /** + * Maps correlation IDs to calls that have been sent. + * Only accessed from this thread. + */ + private final Map correlationIdToCalls = new HashMap<>(); + + /** + * Pending calls. Protected by the object monitor. + */ + private final List newCalls = new LinkedList<>(); + + /** + * Maps node ID strings to their readiness deadlines. A node will appear in this + * map if there are callsToSend which are waiting for it to be ready, and there + * are no calls in flight using the node. + */ + private final Map nodeReadyDeadlines = new HashMap<>(); + + /** + * Whether the admin client is closing. + */ + private volatile boolean closing = false; + + /** + * Time out the elements in the pendingCalls list which are expired. + * + * @param processor The timeout processor. + */ + private void timeoutPendingCalls(TimeoutProcessor processor) { + int numTimedOut = processor.handleTimeouts(pendingCalls, "Timed out waiting for a node assignment."); + if (numTimedOut > 0) + log.debug("Timed out {} pending calls.", numTimedOut); + } + + /** + * Time out calls which have been assigned to nodes. + * + * @param processor The timeout processor. + */ + private int timeoutCallsToSend(TimeoutProcessor processor) { + int numTimedOut = 0; + for (List callList : callsToSend.values()) { + numTimedOut += processor.handleTimeouts(callList, + "Timed out waiting to send the call."); + } + if (numTimedOut > 0) + log.debug("Timed out {} call(s) with assigned nodes.", numTimedOut); + return numTimedOut; + } + + /** + * Drain all the calls from newCalls into pendingCalls. + * + * This function holds the lock for the minimum amount of time, to avoid blocking + * users of AdminClient who will also take the lock to add new calls. 
+ */ + private synchronized void drainNewCalls() { + transitionToPendingAndClearList(newCalls); + } + + /** + * Add some calls to pendingCalls, and then clear the input list. + * Also clears Call#curNode. + * + * @param calls The calls to add. + */ + private void transitionToPendingAndClearList(List calls) { + for (Call call : calls) { + call.curNode = null; + pendingCalls.add(call); + } + calls.clear(); + } + + /** + * Choose nodes for the calls in the pendingCalls list. + * + * @param now The current time in milliseconds. + * @return The minimum time until a call is ready to be retried if any of the pending + * calls are backing off after a failure + */ + private long maybeDrainPendingCalls(long now) { + long pollTimeout = Long.MAX_VALUE; + log.trace("Trying to choose nodes for {} at {}", pendingCalls, now); + + Iterator pendingIter = pendingCalls.iterator(); + while (pendingIter.hasNext()) { + Call call = pendingIter.next(); + // If the call is being retried, await the proper backoff before finding the node + if (now < call.nextAllowedTryMs) { + pollTimeout = Math.min(pollTimeout, call.nextAllowedTryMs - now); + } else if (maybeDrainPendingCall(call, now)) { + pendingIter.remove(); + } + } + return pollTimeout; + } + + /** + * Check whether a pending call can be assigned a node. Return true if the pending call was either + * transferred to the callsToSend collection or if the call was failed. Return false if it + * should remain pending. + */ + private boolean maybeDrainPendingCall(Call call, long now) { + try { + Node node = call.nodeProvider.provide(); + if (node != null) { + log.trace("Assigned {} to node {}", call, node); + call.curNode = node; + getOrCreateListValue(callsToSend, node).add(call); + return true; + } else { + log.trace("Unable to assign {} to a node.", call); + return false; + } + } catch (Throwable t) { + // Handle authentication errors while choosing nodes. + log.debug("Unable to choose node for {}", call, t); + call.fail(now, t); + return true; + } + } + + /** + * Send the calls which are ready. + * + * @param now The current time in milliseconds. + * @return The minimum timeout we need for poll(). + */ + private long sendEligibleCalls(long now) { + long pollTimeout = Long.MAX_VALUE; + for (Iterator>> iter = callsToSend.entrySet().iterator(); iter.hasNext(); ) { + Map.Entry> entry = iter.next(); + List calls = entry.getValue(); + if (calls.isEmpty()) { + iter.remove(); + continue; + } + Node node = entry.getKey(); + if (callsInFlight.containsKey(node.idString())) { + log.trace("Still waiting for other calls to finish on node {}.", node); + nodeReadyDeadlines.remove(node); + continue; + } + if (!client.ready(node, now)) { + Long deadline = nodeReadyDeadlines.get(node); + if (deadline != null) { + if (now >= deadline) { + log.info("Disconnecting from {} and revoking {} node assignment(s) " + + "because the node is taking too long to become ready.", + node.idString(), calls.size()); + transitionToPendingAndClearList(calls); + client.disconnect(node.idString()); + nodeReadyDeadlines.remove(node); + iter.remove(); + continue; + } + pollTimeout = Math.min(pollTimeout, deadline - now); + } else { + nodeReadyDeadlines.put(node, now + requestTimeoutMs); + } + long nodeTimeout = client.pollDelayMs(node, now); + pollTimeout = Math.min(pollTimeout, nodeTimeout); + log.trace("Client is not ready to send to {}. Must delay {} ms", node, nodeTimeout); + continue; + } + // Subtract the time we spent waiting for the node to become ready from + // the total request time. 
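// [Editor's illustration, not part of this patch] A minimal, self-contained sketch of the timeout
// arithmetic applied just below: the time already spent waiting for the node to become ready
// (tracked through its readiness deadline) is charged against the request timeout, and the result
// is further capped by the call's own deadline. All names here are hypothetical stand-ins, not the
// helpers used in this file.
final class RequestTimeoutSketch {
    /** Milliseconds remaining until deadlineMs, clamped to the int range and never negative. */
    static int remainingMs(long nowMs, long deadlineMs) {
        return (int) Math.min(Math.max(0, deadlineMs - nowMs), Integer.MAX_VALUE);
    }

    /**
     * @param readyDeadlineMs the readiness deadline recorded when we first had to wait for the
     *                        node (now + requestTimeoutMs at that moment), or null if the node
     *                        was ready immediately
     */
    static int effectiveRequestTimeoutMs(long nowMs, long callDeadlineMs, int requestTimeoutMs,
                                         Long readyDeadlineMs) {
        int budget = (readyDeadlineMs == null)
            ? requestTimeoutMs                       // no waiting: the full request timeout is available
            : remainingMs(nowMs, readyDeadlineMs);   // request timeout minus the time spent waiting
        return Math.min(budget, remainingMs(nowMs, callDeadlineMs));
    }
}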
+ int remainingRequestTime; + Long deadlineMs = nodeReadyDeadlines.remove(node); + if (deadlineMs == null) { + remainingRequestTime = requestTimeoutMs; + } else { + remainingRequestTime = calcTimeoutMsRemainingAsInt(now, deadlineMs); + } + while (!calls.isEmpty()) { + Call call = calls.remove(0); + int timeoutMs = Math.min(remainingRequestTime, + calcTimeoutMsRemainingAsInt(now, call.deadlineMs)); + AbstractRequest.Builder requestBuilder; + try { + requestBuilder = call.createRequest(timeoutMs); + } catch (Throwable t) { + call.fail(now, new KafkaException(String.format( + "Internal error sending %s to %s.", call.callName, node), t)); + continue; + } + ClientRequest clientRequest = client.newClientRequest(node.idString(), + requestBuilder, now, true, timeoutMs, null); + log.debug("Sending {} to {}. correlationId={}, timeoutMs={}", + requestBuilder, node, clientRequest.correlationId(), timeoutMs); + client.send(clientRequest, now); + callsInFlight.put(node.idString(), call); + correlationIdToCalls.put(clientRequest.correlationId(), call); + break; + } + } + return pollTimeout; + } + + /** + * Time out expired calls that are in flight. + * + * Calls that are in flight may have been partially or completely sent over the wire. They may + * even be in the process of being processed by the remote server. At the moment, our only option + * to time them out is to close the entire connection. + * + * @param processor The timeout processor. + */ + private void timeoutCallsInFlight(TimeoutProcessor processor) { + int numTimedOut = 0; + for (Map.Entry entry : callsInFlight.entrySet()) { + Call call = entry.getValue(); + String nodeId = entry.getKey(); + if (processor.callHasExpired(call)) { + log.info("Disconnecting from {} due to timeout while awaiting {}", nodeId, call); + client.disconnect(nodeId); + numTimedOut++; + // We don't remove anything from the callsInFlight data structure. Because the connection + // has been closed, the calls should be returned by the next client#poll(), + // and handled at that point. + } + } + if (numTimedOut > 0) + log.debug("Timed out {} call(s) in flight.", numTimedOut); + } + + /** + * Handle responses from the server. + * + * @param now The current time in milliseconds. + * @param responses The latest responses from KafkaClient. + **/ + private void handleResponses(long now, List responses) { + for (ClientResponse response : responses) { + int correlationId = response.requestHeader().correlationId(); + + Call call = correlationIdToCalls.get(correlationId); + if (call == null) { + // If the server returns information about a correlation ID we didn't use yet, + // an internal server error has occurred. Close the connection and log an error message. + log.error("Internal server error on {}: server returned information about unknown " + + "correlation ID {}, requestHeader = {}", response.destination(), correlationId, + response.requestHeader()); + client.disconnect(response.destination()); + continue; + } + + // Stop tracking this call. + correlationIdToCalls.remove(correlationId); + if (!callsInFlight.remove(response.destination(), call)) { + log.error("Internal server error on {}: ignoring call {} in correlationIdToCall " + + "that did not exist in callsInFlight", response.destination(), call); + continue; + } + + // Handle the result of the call. This may involve retrying the call, if we got a + // retriable exception. 
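// [Editor's illustration, not part of this patch] A stripped-down sketch of the bookkeeping the
// dispatch below depends on: each sent request is tracked both by correlation id and by node id,
// and a response is paired with its call by removing the correlation-id entry first. The
// PendingCall type and method names are hypothetical stand-ins.
import java.util.HashMap;
import java.util.Map;

final class CorrelationTrackingSketch {
    static final class PendingCall {
        final String name;
        PendingCall(String name) { this.name = name; }
    }

    private final Map<Integer, PendingCall> byCorrelationId = new HashMap<>();
    private final Map<String, PendingCall> byNodeId = new HashMap<>();

    /** Record a request that has been written to a node. */
    void sent(int correlationId, String nodeId, PendingCall call) {
        byCorrelationId.put(correlationId, call);
        byNodeId.put(nodeId, call);
    }

    /** Pair a response with its call; returns null for an unknown correlation id. */
    PendingCall completed(int correlationId, String nodeId) {
        PendingCall call = byCorrelationId.remove(correlationId);
        if (call == null)
            return null;                   // unknown correlation id: treated below as a server error
        byNodeId.remove(nodeId, call);     // clears the in-flight slot only if it still points at this call
        return call;
    }
}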
+ if (response.versionMismatch() != null) { + call.fail(now, response.versionMismatch()); + } else if (response.wasDisconnected()) { + AuthenticationException authException = client.authenticationException(call.curNode()); + if (authException != null) { + call.fail(now, authException); + } else { + call.fail(now, new DisconnectException(String.format( + "Cancelled %s request with correlation id %s due to node %s being disconnected", + call.callName, correlationId, response.destination()))); + } + } else { + try { + call.handleResponse(response.responseBody()); + if (log.isTraceEnabled()) + log.trace("{} got response {}", call, response.responseBody()); + } catch (Throwable t) { + if (log.isTraceEnabled()) + log.trace("{} handleResponse failed with {}", call, prettyPrintException(t)); + call.fail(now, t); + } + } + } + } + + /** + * Unassign calls that have not yet been sent based on some predicate. For example, this + * is used to reassign the calls that have been assigned to a disconnected node. + * + * @param shouldUnassign Condition for reassignment. If the predicate is true, then the calls will + * be put back in the pendingCalls collection and they will be reassigned + */ + private void unassignUnsentCalls(Predicate shouldUnassign) { + for (Iterator>> iter = callsToSend.entrySet().iterator(); iter.hasNext(); ) { + Map.Entry> entry = iter.next(); + Node node = entry.getKey(); + List awaitingCalls = entry.getValue(); + + if (awaitingCalls.isEmpty()) { + iter.remove(); + } else if (shouldUnassign.test(node)) { + nodeReadyDeadlines.remove(node); + transitionToPendingAndClearList(awaitingCalls); + iter.remove(); + } + } + } + + private boolean hasActiveExternalCalls(Collection calls) { + for (Call call : calls) { + if (!call.isInternal()) { + return true; + } + } + return false; + } + + /** + * Return true if there are currently active external calls. + */ + private boolean hasActiveExternalCalls() { + if (hasActiveExternalCalls(pendingCalls)) { + return true; + } + for (List callList : callsToSend.values()) { + if (hasActiveExternalCalls(callList)) { + return true; + } + } + return hasActiveExternalCalls(correlationIdToCalls.values()); + } + + private boolean threadShouldExit(long now, long curHardShutdownTimeMs) { + if (!hasActiveExternalCalls()) { + log.trace("All work has been completed, and the I/O thread is now exiting."); + return true; + } + if (now >= curHardShutdownTimeMs) { + log.info("Forcing a hard I/O thread shutdown. 
Requests in progress will be aborted."); + return true; + } + log.debug("Hard shutdown in {} ms.", curHardShutdownTimeMs - now); + return false; + } + + @Override + public void run() { + log.debug("Thread starting"); + try { + processRequests(); + } finally { + closing = true; + AppInfoParser.unregisterAppInfo(JMX_PREFIX, clientId, metrics); + + int numTimedOut = 0; + TimeoutProcessor timeoutProcessor = new TimeoutProcessor(Long.MAX_VALUE); + synchronized (this) { + numTimedOut += timeoutProcessor.handleTimeouts(newCalls, "The AdminClient thread has exited."); + } + numTimedOut += timeoutProcessor.handleTimeouts(pendingCalls, "The AdminClient thread has exited."); + numTimedOut += timeoutCallsToSend(timeoutProcessor); + numTimedOut += timeoutProcessor.handleTimeouts(correlationIdToCalls.values(), + "The AdminClient thread has exited."); + if (numTimedOut > 0) { + log.info("Timed out {} remaining operation(s) during close.", numTimedOut); + } + closeQuietly(client, "KafkaClient"); + closeQuietly(metrics, "Metrics"); + log.debug("Exiting AdminClientRunnable thread."); + } + } + + private void processRequests() { + long now = time.milliseconds(); + while (true) { + // Copy newCalls into pendingCalls. + drainNewCalls(); + + // Check if the AdminClient thread should shut down. + long curHardShutdownTimeMs = hardShutdownTimeMs.get(); + if ((curHardShutdownTimeMs != INVALID_SHUTDOWN_TIME) && threadShouldExit(now, curHardShutdownTimeMs)) + break; + + // Handle timeouts. + TimeoutProcessor timeoutProcessor = timeoutProcessorFactory.create(now); + timeoutPendingCalls(timeoutProcessor); + timeoutCallsToSend(timeoutProcessor); + timeoutCallsInFlight(timeoutProcessor); + + long pollTimeout = Math.min(1200000, timeoutProcessor.nextTimeoutMs()); + if (curHardShutdownTimeMs != INVALID_SHUTDOWN_TIME) { + pollTimeout = Math.min(pollTimeout, curHardShutdownTimeMs - now); + } + + // Choose nodes for our pending calls. + pollTimeout = Math.min(pollTimeout, maybeDrainPendingCalls(now)); + long metadataFetchDelayMs = metadataManager.metadataFetchDelayMs(now); + if (metadataFetchDelayMs == 0) { + metadataManager.transitionToUpdatePending(now); + Call metadataCall = makeMetadataCall(now); + // Create a new metadata fetch call and add it to the end of pendingCalls. + // Assign a node for just the new call (we handled the other pending nodes above). + + if (!maybeDrainPendingCall(metadataCall, now)) + pendingCalls.add(metadataCall); + } + pollTimeout = Math.min(pollTimeout, sendEligibleCalls(now)); + + if (metadataFetchDelayMs > 0) { + pollTimeout = Math.min(pollTimeout, metadataFetchDelayMs); + } + + // Ensure that we use a small poll timeout if there are pending calls which need to be sent + if (!pendingCalls.isEmpty()) + pollTimeout = Math.min(pollTimeout, retryBackoffMs); + + // Wait for network responses. + log.trace("Entering KafkaClient#poll(timeout={})", pollTimeout); + List responses = client.poll(Math.max(0L, pollTimeout), now); + log.trace("KafkaClient#poll retrieved {} response(s)", responses.size()); + + // unassign calls to disconnected nodes + unassignUnsentCalls(client::connectionFailed); + + // Update the current time and handle the latest responses. + now = time.milliseconds(); + handleResponses(now, responses); + } + } + + /** + * Queue a call for sending. + * + * If the AdminClient thread has exited, this will fail. Otherwise, it will succeed (even + * if the AdminClient is shutting down). This function should called when retrying an + * existing call. + * + * @param call The new call object. 
+ * @param now The current time in milliseconds. + */ + void enqueue(Call call, long now) { + if (call.tries > maxRetries) { + log.debug("Max retries {} for {} reached", maxRetries, call); + call.handleTimeoutFailure(time.milliseconds(), new TimeoutException( + "Exceeded maxRetries after " + call.tries + " tries.")); + return; + } + if (log.isDebugEnabled()) { + log.debug("Queueing {} with a timeout {} ms from now.", call, + Math.min(requestTimeoutMs, call.deadlineMs - now)); + } + boolean accepted = false; + synchronized (this) { + if (!closing) { + newCalls.add(call); + accepted = true; + } + } + if (accepted) { + client.wakeup(); // wake the thread if it is in poll() + } else { + log.debug("The AdminClient thread has exited. Timing out {}.", call); + call.handleTimeoutFailure(time.milliseconds(), + new TimeoutException("The AdminClient thread has exited.")); + } + } + + /** + * Initiate a new call. + * + * This will fail if the AdminClient is scheduled to shut down. + * + * @param call The new call object. + * @param now The current time in milliseconds. + */ + void call(Call call, long now) { + if (hardShutdownTimeMs.get() != INVALID_SHUTDOWN_TIME) { + log.debug("The AdminClient is not accepting new calls. Timing out {}.", call); + call.handleTimeoutFailure(time.milliseconds(), + new TimeoutException("The AdminClient thread is not accepting new calls.")); + } else { + enqueue(call, now); + } + } + + /** + * Create a new metadata call. + */ + private Call makeMetadataCall(long now) { + return new Call(true, "fetchMetadata", calcDeadlineMs(now, requestTimeoutMs), + new MetadataUpdateNodeIdProvider()) { + @Override + public MetadataRequest.Builder createRequest(int timeoutMs) { + // Since this only requests node information, it's safe to pass true + // for allowAutoTopicCreation (and it simplifies communication with + // older brokers) + return new MetadataRequest.Builder(new MetadataRequestData() + .setTopics(Collections.emptyList()) + .setAllowAutoTopicCreation(true)); + } + + @Override + public void handleResponse(AbstractResponse abstractResponse) { + MetadataResponse response = (MetadataResponse) abstractResponse; + long now = time.milliseconds(); + metadataManager.update(response.buildCluster(), now); + + // Unassign all unsent requests after a metadata refresh to allow for a new + // destination to be selected from the new metadata + unassignUnsentCalls(node -> true); + } + + @Override + public void handleFailure(Throwable e) { + metadataManager.updateFailed(e); + } + }; + } + } + + /** + * Returns true if a topic name cannot be represented in an RPC. This function does NOT check + * whether the name is too long, contains invalid characters, etc. It is better to enforce + * those policies on the server, so that they can be changed in the future if needed. + */ + private static boolean topicNameIsUnrepresentable(String topicName) { + return topicName == null || topicName.isEmpty(); + } + + private static boolean topicIdIsUnrepresentable(Uuid topicId) { + return topicId == null || topicId == Uuid.ZERO_UUID; + } + + // for testing + int numPendingCalls() { + return runnable.pendingCalls.size(); + } + + /** + * Fail futures in the given stream which are not done. + * Used when a response handler expected a result for some entity but no result was present. 
+ */ + private static void completeUnrealizedFutures( + Stream>> futures, + Function messageFormatter) { + futures.filter(entry -> !entry.getValue().isDone()).forEach(entry -> + entry.getValue().completeExceptionally(new ApiException(messageFormatter.apply(entry.getKey())))); + } + + /** + * Fail futures in the given Map which were retried due to exceeding quota. We propagate + * the initial error back to the caller if the request timed out. + */ + private static void maybeCompleteQuotaExceededException( + boolean shouldRetryOnQuotaViolation, + Throwable throwable, + Map> futures, + Map quotaExceededExceptions, + int throttleTimeDelta) { + if (shouldRetryOnQuotaViolation && throwable instanceof TimeoutException) { + quotaExceededExceptions.forEach((key, value) -> futures.get(key).completeExceptionally( + new ThrottlingQuotaExceededException( + Math.max(0, value.throttleTimeMs() - throttleTimeDelta), + value.getMessage()))); + } + } + + @Override + public CreateTopicsResult createTopics(final Collection newTopics, + final CreateTopicsOptions options) { + final Map> topicFutures = new HashMap<>(newTopics.size()); + final CreatableTopicCollection topics = new CreatableTopicCollection(); + for (NewTopic newTopic : newTopics) { + if (topicNameIsUnrepresentable(newTopic.name())) { + KafkaFutureImpl future = new KafkaFutureImpl<>(); + future.completeExceptionally(new InvalidTopicException("The given topic name '" + + newTopic.name() + "' cannot be represented in a request.")); + topicFutures.put(newTopic.name(), future); + } else if (!topicFutures.containsKey(newTopic.name())) { + topicFutures.put(newTopic.name(), new KafkaFutureImpl<>()); + topics.add(newTopic.convertToCreatableTopic()); + } + } + if (!topics.isEmpty()) { + final long now = time.milliseconds(); + final long deadline = calcDeadlineMs(now, options.timeoutMs()); + final Call call = getCreateTopicsCall(options, topicFutures, topics, + Collections.emptyMap(), now, deadline); + runnable.call(call, now); + } + return new CreateTopicsResult(new HashMap<>(topicFutures)); + } + + private Call getCreateTopicsCall(final CreateTopicsOptions options, + final Map> futures, + final CreatableTopicCollection topics, + final Map quotaExceededExceptions, + final long now, + final long deadline) { + return new Call("createTopics", deadline, new ControllerNodeProvider()) { + @Override + public CreateTopicsRequest.Builder createRequest(int timeoutMs) { + return new CreateTopicsRequest.Builder( + new CreateTopicsRequestData() + .setTopics(topics) + .setTimeoutMs(timeoutMs) + .setValidateOnly(options.shouldValidateOnly())); + } + + @Override + public void handleResponse(AbstractResponse abstractResponse) { + // Check for controller change + handleNotControllerError(abstractResponse); + // Handle server responses for particular topics. 
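// [Editor's illustration, not part of this patch] A minimal caller-side view of the request whose
// response is unpacked below. The bootstrap address, topic name, partition count and replication
// factor are placeholders; numPartitions() resolves only when the broker is new enough to return
// topic metadata in its CreateTopics response.
import java.util.Collections;
import java.util.Properties;
import java.util.concurrent.ExecutionException;

import org.apache.kafka.clients.admin.Admin;
import org.apache.kafka.clients.admin.AdminClientConfig;
import org.apache.kafka.clients.admin.CreateTopicsResult;
import org.apache.kafka.clients.admin.NewTopic;
import org.apache.kafka.common.errors.TopicExistsException;

public class CreateTopicExample {
    public static void main(String[] args) throws Exception {
        Properties props = new Properties();
        props.put(AdminClientConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9092");
        try (Admin admin = Admin.create(props)) {
            CreateTopicsResult result =
                admin.createTopics(Collections.singleton(new NewTopic("example-topic", 3, (short) 1)));
            try {
                result.all().get();
                System.out.println("Created example-topic with "
                    + result.numPartitions("example-topic").get() + " partitions");
            } catch (ExecutionException e) {
                if (e.getCause() instanceof TopicExistsException)
                    System.out.println("example-topic already exists");
                else
                    throw e;
            }
        }
    }
}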
+ final CreateTopicsResponse response = (CreateTopicsResponse) abstractResponse; + final CreatableTopicCollection retryTopics = new CreatableTopicCollection(); + final Map retryTopicQuotaExceededExceptions = new HashMap<>(); + for (CreatableTopicResult result : response.data().topics()) { + KafkaFutureImpl future = futures.get(result.name()); + if (future == null) { + log.warn("Server response mentioned unknown topic {}", result.name()); + } else { + ApiError error = new ApiError(result.errorCode(), result.errorMessage()); + if (error.isFailure()) { + if (error.is(Errors.THROTTLING_QUOTA_EXCEEDED)) { + ThrottlingQuotaExceededException quotaExceededException = new ThrottlingQuotaExceededException( + response.throttleTimeMs(), error.messageWithFallback()); + if (options.shouldRetryOnQuotaViolation()) { + retryTopics.add(topics.find(result.name()).duplicate()); + retryTopicQuotaExceededExceptions.put(result.name(), quotaExceededException); + } else { + future.completeExceptionally(quotaExceededException); + } + } else { + future.completeExceptionally(error.exception()); + } + } else { + TopicMetadataAndConfig topicMetadataAndConfig; + if (result.topicConfigErrorCode() != Errors.NONE.code()) { + topicMetadataAndConfig = new TopicMetadataAndConfig( + Errors.forCode(result.topicConfigErrorCode()).exception()); + } else if (result.numPartitions() == CreateTopicsResult.UNKNOWN) { + topicMetadataAndConfig = new TopicMetadataAndConfig(new UnsupportedVersionException( + "Topic metadata and configs in CreateTopics response not supported")); + } else { + List configs = result.configs(); + Config topicConfig = new Config(configs.stream() + .map(this::configEntry) + .collect(Collectors.toSet())); + topicMetadataAndConfig = new TopicMetadataAndConfig(result.topicId(), result.numPartitions(), + result.replicationFactor(), + topicConfig); + } + future.complete(topicMetadataAndConfig); + } + } + } + // If there are topics to retry, retry them; complete unrealized futures otherwise. + if (retryTopics.isEmpty()) { + // The server should send back a response for every topic. But do a sanity check anyway. + completeUnrealizedFutures(futures.entrySet().stream(), + topic -> "The controller response did not contain a result for topic " + topic); + } else { + final long now = time.milliseconds(); + final Call call = getCreateTopicsCall(options, futures, retryTopics, + retryTopicQuotaExceededExceptions, now, deadline); + runnable.call(call, now); + } + } + + private ConfigEntry configEntry(CreatableTopicConfigs config) { + return new ConfigEntry( + config.name(), + config.value(), + configSource(DescribeConfigsResponse.ConfigSource.forId(config.configSource())), + config.isSensitive(), + config.readOnly(), + Collections.emptyList(), + null, + null); + } + + @Override + void handleFailure(Throwable throwable) { + // If there were any topics retries due to a quota exceeded exception, we propagate + // the initial error back to the caller if the request timed out. 
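// [Editor's illustration, not part of this patch] What the propagation below amounts to when the
// request times out while some topics were already throttled: each such future is failed with its
// original quota error, with the advertised throttle time reduced by the time that has already
// passed. CompletableFuture stands in for the KafkaFutureImpl used in this file, and the class and
// method names are hypothetical.
import java.util.Map;
import java.util.concurrent.CompletableFuture;

import org.apache.kafka.common.errors.ThrottlingQuotaExceededException;

final class QuotaErrorPropagationSketch {
    static void propagate(Map<String, CompletableFuture<Void>> futures,
                          Map<String, ThrottlingQuotaExceededException> quotaErrors,
                          int elapsedMs) {
        quotaErrors.forEach((topic, error) -> futures.get(topic).completeExceptionally(
            new ThrottlingQuotaExceededException(
                Math.max(0, error.throttleTimeMs() - elapsedMs),   // never advertise a negative wait
                error.getMessage())));
    }
}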
+ maybeCompleteQuotaExceededException(options.shouldRetryOnQuotaViolation(), + throwable, futures, quotaExceededExceptions, (int) (time.milliseconds() - now)); + // Fail all the other remaining futures + completeAllExceptionally(futures.values(), throwable); + } + }; + } + + @Override + public DeleteTopicsResult deleteTopics(final TopicCollection topics, + final DeleteTopicsOptions options) { + if (topics instanceof TopicIdCollection) + return DeleteTopicsResult.ofTopicIds(handleDeleteTopicsUsingIds(((TopicIdCollection) topics).topicIds(), options)); + else if (topics instanceof TopicNameCollection) + return DeleteTopicsResult.ofTopicNames(handleDeleteTopicsUsingNames(((TopicNameCollection) topics).topicNames(), options)); + else + throw new IllegalArgumentException("The TopicCollection: " + topics + " provided did not match any supported classes for deleteTopics."); + } + + private Map> handleDeleteTopicsUsingNames(final Collection topicNames, + final DeleteTopicsOptions options) { + final Map> topicFutures = new HashMap<>(topicNames.size()); + final List validTopicNames = new ArrayList<>(topicNames.size()); + for (String topicName : topicNames) { + if (topicNameIsUnrepresentable(topicName)) { + KafkaFutureImpl future = new KafkaFutureImpl<>(); + future.completeExceptionally(new InvalidTopicException("The given topic name '" + + topicName + "' cannot be represented in a request.")); + topicFutures.put(topicName, future); + } else if (!topicFutures.containsKey(topicName)) { + topicFutures.put(topicName, new KafkaFutureImpl<>()); + validTopicNames.add(topicName); + } + } + if (!validTopicNames.isEmpty()) { + final long now = time.milliseconds(); + final long deadline = calcDeadlineMs(now, options.timeoutMs()); + final Call call = getDeleteTopicsCall(options, topicFutures, validTopicNames, + Collections.emptyMap(), now, deadline); + runnable.call(call, now); + } + return new HashMap<>(topicFutures); + } + + private Map> handleDeleteTopicsUsingIds(final Collection topicIds, + final DeleteTopicsOptions options) { + final Map> topicFutures = new HashMap<>(topicIds.size()); + final List validTopicIds = new ArrayList<>(topicIds.size()); + for (Uuid topicId : topicIds) { + if (topicId.equals(Uuid.ZERO_UUID)) { + KafkaFutureImpl future = new KafkaFutureImpl<>(); + future.completeExceptionally(new InvalidTopicException("The given topic ID '" + + topicId + "' cannot be represented in a request.")); + topicFutures.put(topicId, future); + } else if (!topicFutures.containsKey(topicId)) { + topicFutures.put(topicId, new KafkaFutureImpl<>()); + validTopicIds.add(topicId); + } + } + if (!validTopicIds.isEmpty()) { + final long now = time.milliseconds(); + final long deadline = calcDeadlineMs(now, options.timeoutMs()); + final Call call = getDeleteTopicsWithIdsCall(options, topicFutures, validTopicIds, + Collections.emptyMap(), now, deadline); + runnable.call(call, now); + } + return new HashMap<>(topicFutures); + } + + private Call getDeleteTopicsCall(final DeleteTopicsOptions options, + final Map> futures, + final List topics, + final Map quotaExceededExceptions, + final long now, + final long deadline) { + return new Call("deleteTopics", deadline, new ControllerNodeProvider()) { + @Override + DeleteTopicsRequest.Builder createRequest(int timeoutMs) { + return new DeleteTopicsRequest.Builder( + new DeleteTopicsRequestData() + .setTopicNames(topics) + .setTimeoutMs(timeoutMs)); + } + + @Override + void handleResponse(AbstractResponse abstractResponse) { + // Check for controller change + 
handleNotControllerError(abstractResponse); + // Handle server responses for particular topics. + final DeleteTopicsResponse response = (DeleteTopicsResponse) abstractResponse; + final List retryTopics = new ArrayList<>(); + final Map retryTopicQuotaExceededExceptions = new HashMap<>(); + for (DeletableTopicResult result : response.data().responses()) { + KafkaFutureImpl future = futures.get(result.name()); + if (future == null) { + log.warn("Server response mentioned unknown topic {}", result.name()); + } else { + ApiError error = new ApiError(result.errorCode(), result.errorMessage()); + if (error.isFailure()) { + if (error.is(Errors.THROTTLING_QUOTA_EXCEEDED)) { + ThrottlingQuotaExceededException quotaExceededException = new ThrottlingQuotaExceededException( + response.throttleTimeMs(), error.messageWithFallback()); + if (options.shouldRetryOnQuotaViolation()) { + retryTopics.add(result.name()); + retryTopicQuotaExceededExceptions.put(result.name(), quotaExceededException); + } else { + future.completeExceptionally(quotaExceededException); + } + } else { + future.completeExceptionally(error.exception()); + } + } else { + future.complete(null); + } + } + } + // If there are topics to retry, retry them; complete unrealized futures otherwise. + if (retryTopics.isEmpty()) { + // The server should send back a response for every topic. But do a sanity check anyway. + completeUnrealizedFutures(futures.entrySet().stream(), + topic -> "The controller response did not contain a result for topic " + topic); + } else { + final long now = time.milliseconds(); + final Call call = getDeleteTopicsCall(options, futures, retryTopics, + retryTopicQuotaExceededExceptions, now, deadline); + runnable.call(call, now); + } + } + + @Override + void handleFailure(Throwable throwable) { + // If there were any topics retries due to a quota exceeded exception, we propagate + // the initial error back to the caller if the request timed out. + maybeCompleteQuotaExceededException(options.shouldRetryOnQuotaViolation(), + throwable, futures, quotaExceededExceptions, (int) (time.milliseconds() - now)); + // Fail all the other remaining futures + completeAllExceptionally(futures.values(), throwable); + } + }; + } + + private Call getDeleteTopicsWithIdsCall(final DeleteTopicsOptions options, + final Map> futures, + final List topicIds, + final Map quotaExceededExceptions, + final long now, + final long deadline) { + return new Call("deleteTopics", deadline, new ControllerNodeProvider()) { + @Override + DeleteTopicsRequest.Builder createRequest(int timeoutMs) { + return new DeleteTopicsRequest.Builder( + new DeleteTopicsRequestData() + .setTopics(topicIds.stream().map( + topic -> new DeleteTopicState().setTopicId(topic)).collect(Collectors.toList())) + .setTimeoutMs(timeoutMs)); + } + + @Override + void handleResponse(AbstractResponse abstractResponse) { + // Check for controller change + handleNotControllerError(abstractResponse); + // Handle server responses for particular topics. 
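// [Editor's illustration, not part of this patch] A caller-side sketch of deleting a topic by id,
// the path whose response is unpacked below. It assumes the topic-id overloads referenced in this
// file (TopicCollection.ofTopicIds, TopicDescription.topicId, DescribeTopicsResult.allTopicNames);
// the bootstrap address and topic name are placeholders.
import java.util.Collections;
import java.util.Properties;

import org.apache.kafka.clients.admin.Admin;
import org.apache.kafka.clients.admin.AdminClientConfig;
import org.apache.kafka.common.TopicCollection;
import org.apache.kafka.common.Uuid;

public class DeleteTopicByIdExample {
    public static void main(String[] args) throws Exception {
        Properties props = new Properties();
        props.put(AdminClientConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9092");
        try (Admin admin = Admin.create(props)) {
            // Resolve the topic's id, then delete by id rather than by name.
            Uuid topicId = admin.describeTopics(Collections.singleton("example-topic"))
                .allTopicNames().get().get("example-topic").topicId();
            admin.deleteTopics(TopicCollection.ofTopicIds(Collections.singleton(topicId)))
                .all().get();
            System.out.println("Deleted topic with id " + topicId);
        }
    }
}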
+ final DeleteTopicsResponse response = (DeleteTopicsResponse) abstractResponse; + final List retryTopics = new ArrayList<>(); + final Map retryTopicQuotaExceededExceptions = new HashMap<>(); + for (DeletableTopicResult result : response.data().responses()) { + KafkaFutureImpl future = futures.get(result.topicId()); + if (future == null) { + log.warn("Server response mentioned unknown topic ID {}", result.topicId()); + } else { + ApiError error = new ApiError(result.errorCode(), result.errorMessage()); + if (error.isFailure()) { + if (error.is(Errors.THROTTLING_QUOTA_EXCEEDED)) { + ThrottlingQuotaExceededException quotaExceededException = new ThrottlingQuotaExceededException( + response.throttleTimeMs(), error.messageWithFallback()); + if (options.shouldRetryOnQuotaViolation()) { + retryTopics.add(result.topicId()); + retryTopicQuotaExceededExceptions.put(result.topicId(), quotaExceededException); + } else { + future.completeExceptionally(quotaExceededException); + } + } else { + future.completeExceptionally(error.exception()); + } + } else { + future.complete(null); + } + } + } + // If there are topics to retry, retry them; complete unrealized futures otherwise. + if (retryTopics.isEmpty()) { + // The server should send back a response for every topic. But do a sanity check anyway. + completeUnrealizedFutures(futures.entrySet().stream(), + topic -> "The controller response did not contain a result for topic " + topic); + } else { + final long now = time.milliseconds(); + final Call call = getDeleteTopicsWithIdsCall(options, futures, retryTopics, + retryTopicQuotaExceededExceptions, now, deadline); + runnable.call(call, now); + } + } + + @Override + void handleFailure(Throwable throwable) { + // If there were any topics retries due to a quota exceeded exception, we propagate + // the initial error back to the caller if the request timed out. 
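// [Editor's illustration, not part of this patch] How a caller sees the quota handling below when
// automatic retries are disabled: the delete fails with ThrottlingQuotaExceededException, whose
// throttleTimeMs() suggests how long to wait before retrying. Assumes the retryOnQuotaViolation
// option introduced by KIP-599; the bootstrap address and topic name are placeholders.
import java.util.Collections;
import java.util.Properties;
import java.util.concurrent.ExecutionException;

import org.apache.kafka.clients.admin.Admin;
import org.apache.kafka.clients.admin.AdminClientConfig;
import org.apache.kafka.clients.admin.DeleteTopicsOptions;
import org.apache.kafka.common.errors.ThrottlingQuotaExceededException;

public class DeleteTopicQuotaExample {
    public static void main(String[] args) throws Exception {
        Properties props = new Properties();
        props.put(AdminClientConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9092");
        try (Admin admin = Admin.create(props)) {
            DeleteTopicsOptions options = new DeleteTopicsOptions().retryOnQuotaViolation(false);
            try {
                admin.deleteTopics(Collections.singleton("example-topic"), options).all().get();
            } catch (ExecutionException e) {
                if (e.getCause() instanceof ThrottlingQuotaExceededException) {
                    ThrottlingQuotaExceededException quotaError = (ThrottlingQuotaExceededException) e.getCause();
                    System.out.println("Throttled; retry after " + quotaError.throttleTimeMs() + " ms");
                } else {
                    throw e;
                }
            }
        }
    }
}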
+ maybeCompleteQuotaExceededException(options.shouldRetryOnQuotaViolation(), + throwable, futures, quotaExceededExceptions, (int) (time.milliseconds() - now)); + // Fail all the other remaining futures + completeAllExceptionally(futures.values(), throwable); + } + }; + } + + @Override + public ListTopicsResult listTopics(final ListTopicsOptions options) { + final KafkaFutureImpl> topicListingFuture = new KafkaFutureImpl<>(); + final long now = time.milliseconds(); + runnable.call(new Call("listTopics", calcDeadlineMs(now, options.timeoutMs()), + new LeastLoadedNodeProvider()) { + + @Override + MetadataRequest.Builder createRequest(int timeoutMs) { + return MetadataRequest.Builder.allTopics(); + } + + @Override + void handleResponse(AbstractResponse abstractResponse) { + MetadataResponse response = (MetadataResponse) abstractResponse; + Map topicListing = new HashMap<>(); + for (MetadataResponse.TopicMetadata topicMetadata : response.topicMetadata()) { + String topicName = topicMetadata.topic(); + boolean isInternal = topicMetadata.isInternal(); + if (!topicMetadata.isInternal() || options.shouldListInternal()) + topicListing.put(topicName, new TopicListing(topicName, topicMetadata.topicId(), isInternal)); + } + topicListingFuture.complete(topicListing); + } + + @Override + void handleFailure(Throwable throwable) { + topicListingFuture.completeExceptionally(throwable); + } + }, now); + return new ListTopicsResult(topicListingFuture); + } + + @Override + public DescribeTopicsResult describeTopics(final TopicCollection topics, DescribeTopicsOptions options) { + if (topics instanceof TopicIdCollection) + return DescribeTopicsResult.ofTopicIds(handleDescribeTopicsByIds(((TopicIdCollection) topics).topicIds(), options)); + else if (topics instanceof TopicNameCollection) + return DescribeTopicsResult.ofTopicNames(handleDescribeTopicsByNames(((TopicNameCollection) topics).topicNames(), options)); + else + throw new IllegalArgumentException("The TopicCollection: " + topics + " provided did not match any supported classes for describeTopics."); + } + + private Map> handleDescribeTopicsByNames(final Collection topicNames, DescribeTopicsOptions options) { + final Map> topicFutures = new HashMap<>(topicNames.size()); + final ArrayList topicNamesList = new ArrayList<>(); + for (String topicName : topicNames) { + if (topicNameIsUnrepresentable(topicName)) { + KafkaFutureImpl future = new KafkaFutureImpl<>(); + future.completeExceptionally(new InvalidTopicException("The given topic name '" + + topicName + "' cannot be represented in a request.")); + topicFutures.put(topicName, future); + } else if (!topicFutures.containsKey(topicName)) { + topicFutures.put(topicName, new KafkaFutureImpl<>()); + topicNamesList.add(topicName); + } + } + final long now = time.milliseconds(); + Call call = new Call("describeTopics", calcDeadlineMs(now, options.timeoutMs()), + new LeastLoadedNodeProvider()) { + + private boolean supportsDisablingTopicCreation = true; + + @Override + MetadataRequest.Builder createRequest(int timeoutMs) { + if (supportsDisablingTopicCreation) + return new MetadataRequest.Builder(new MetadataRequestData() + .setTopics(convertToMetadataRequestTopic(topicNamesList)) + .setAllowAutoTopicCreation(false) + .setIncludeTopicAuthorizedOperations(options.includeAuthorizedOperations())); + else + return MetadataRequest.Builder.allTopics(); + } + + @Override + void handleResponse(AbstractResponse abstractResponse) { + MetadataResponse response = (MetadataResponse) abstractResponse; + // Handle server 
responses for particular topics. + Cluster cluster = response.buildCluster(); + Map errors = response.errors(); + for (Map.Entry> entry : topicFutures.entrySet()) { + String topicName = entry.getKey(); + KafkaFutureImpl future = entry.getValue(); + Errors topicError = errors.get(topicName); + if (topicError != null) { + future.completeExceptionally(topicError.exception()); + continue; + } + if (!cluster.topics().contains(topicName)) { + future.completeExceptionally(new UnknownTopicOrPartitionException("Topic " + topicName + " not found.")); + continue; + } + Uuid topicId = cluster.topicId(topicName); + Integer authorizedOperations = response.topicAuthorizedOperations(topicName).get(); + TopicDescription topicDescription = getTopicDescriptionFromCluster(cluster, topicName, topicId, authorizedOperations); + future.complete(topicDescription); + } + } + + @Override + boolean handleUnsupportedVersionException(UnsupportedVersionException exception) { + if (supportsDisablingTopicCreation) { + supportsDisablingTopicCreation = false; + return true; + } + return false; + } + + @Override + void handleFailure(Throwable throwable) { + completeAllExceptionally(topicFutures.values(), throwable); + } + }; + if (!topicNamesList.isEmpty()) { + runnable.call(call, now); + } + return new HashMap<>(topicFutures); + } + + private Map> handleDescribeTopicsByIds(Collection topicIds, DescribeTopicsOptions options) { + + final Map> topicFutures = new HashMap<>(topicIds.size()); + final List topicIdsList = new ArrayList<>(); + for (Uuid topicId : topicIds) { + if (topicIdIsUnrepresentable(topicId)) { + KafkaFutureImpl future = new KafkaFutureImpl<>(); + future.completeExceptionally(new InvalidTopicException("The given topic id '" + + topicId + "' cannot be represented in a request.")); + topicFutures.put(topicId, future); + } else if (!topicFutures.containsKey(topicId)) { + topicFutures.put(topicId, new KafkaFutureImpl<>()); + topicIdsList.add(topicId); + } + } + final long now = time.milliseconds(); + Call call = new Call("describeTopicsWithIds", calcDeadlineMs(now, options.timeoutMs()), + new LeastLoadedNodeProvider()) { + + @Override + MetadataRequest.Builder createRequest(int timeoutMs) { + return new MetadataRequest.Builder(new MetadataRequestData() + .setTopics(convertTopicIdsToMetadataRequestTopic(topicIdsList)) + .setAllowAutoTopicCreation(false) + .setIncludeTopicAuthorizedOperations(options.includeAuthorizedOperations())); + } + + @Override + void handleResponse(AbstractResponse abstractResponse) { + MetadataResponse response = (MetadataResponse) abstractResponse; + // Handle server responses for particular topics. 
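// [Editor's illustration, not part of this patch] A caller-side sketch of describeTopics, whose
// per-topic metadata is assembled below from the MetadataResponse. It assumes the allTopicNames()
// accessor used alongside the topic-id variants in this file; the bootstrap address and topic
// name are placeholders.
import java.util.Collections;
import java.util.Properties;

import org.apache.kafka.clients.admin.Admin;
import org.apache.kafka.clients.admin.AdminClientConfig;
import org.apache.kafka.clients.admin.TopicDescription;
import org.apache.kafka.common.TopicPartitionInfo;

public class DescribeTopicExample {
    public static void main(String[] args) throws Exception {
        Properties props = new Properties();
        props.put(AdminClientConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9092");
        try (Admin admin = Admin.create(props)) {
            TopicDescription description = admin.describeTopics(Collections.singleton("example-topic"))
                .allTopicNames().get().get("example-topic");
            System.out.println("Topic id: " + description.topicId()
                + ", internal: " + description.isInternal());
            for (TopicPartitionInfo partition : description.partitions()) {
                System.out.println("Partition " + partition.partition()
                    + " leader=" + partition.leader()
                    + " isr=" + partition.isr().size());
            }
        }
    }
}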
+ Cluster cluster = response.buildCluster(); + Map errors = response.errorsByTopicId(); + for (Map.Entry> entry : topicFutures.entrySet()) { + Uuid topicId = entry.getKey(); + KafkaFutureImpl future = entry.getValue(); + + String topicName = cluster.topicName(topicId); + if (topicName == null) { + future.completeExceptionally(new InvalidTopicException("TopicId " + topicId + " not found.")); + continue; + } + Errors topicError = errors.get(topicId); + if (topicError != null) { + future.completeExceptionally(topicError.exception()); + continue; + } + + Integer authorizedOperations = response.topicAuthorizedOperations(topicName).get(); + TopicDescription topicDescription = getTopicDescriptionFromCluster(cluster, topicName, topicId, authorizedOperations); + future.complete(topicDescription); + } + } + + @Override + void handleFailure(Throwable throwable) { + completeAllExceptionally(topicFutures.values(), throwable); + } + }; + if (!topicIdsList.isEmpty()) { + runnable.call(call, now); + } + return new HashMap<>(topicFutures); + } + + private TopicDescription getTopicDescriptionFromCluster(Cluster cluster, String topicName, Uuid topicId, + Integer authorizedOperations) { + boolean isInternal = cluster.internalTopics().contains(topicName); + List partitionInfos = cluster.partitionsForTopic(topicName); + List partitions = new ArrayList<>(partitionInfos.size()); + for (PartitionInfo partitionInfo : partitionInfos) { + TopicPartitionInfo topicPartitionInfo = new TopicPartitionInfo( + partitionInfo.partition(), leader(partitionInfo), Arrays.asList(partitionInfo.replicas()), + Arrays.asList(partitionInfo.inSyncReplicas())); + partitions.add(topicPartitionInfo); + } + partitions.sort(Comparator.comparingInt(TopicPartitionInfo::partition)); + return new TopicDescription(topicName, isInternal, partitions, validAclOperations(authorizedOperations), topicId); + } + + private Node leader(PartitionInfo partitionInfo) { + if (partitionInfo.leader() == null || partitionInfo.leader().id() == Node.noNode().id()) + return null; + return partitionInfo.leader(); + } + + @Override + public DescribeClusterResult describeCluster(DescribeClusterOptions options) { + final KafkaFutureImpl> describeClusterFuture = new KafkaFutureImpl<>(); + final KafkaFutureImpl controllerFuture = new KafkaFutureImpl<>(); + final KafkaFutureImpl clusterIdFuture = new KafkaFutureImpl<>(); + final KafkaFutureImpl> authorizedOperationsFuture = new KafkaFutureImpl<>(); + + final long now = time.milliseconds(); + runnable.call(new Call("listNodes", calcDeadlineMs(now, options.timeoutMs()), + new LeastLoadedNodeProvider()) { + + private boolean useMetadataRequest = false; + + @Override + AbstractRequest.Builder createRequest(int timeoutMs) { + if (!useMetadataRequest) { + return new DescribeClusterRequest.Builder(new DescribeClusterRequestData() + .setIncludeClusterAuthorizedOperations( + options.includeAuthorizedOperations())); + } else { + // Since this only requests node information, it's safe to pass true for allowAutoTopicCreation (and it + // simplifies communication with older brokers) + return new MetadataRequest.Builder(new MetadataRequestData() + .setTopics(Collections.emptyList()) + .setAllowAutoTopicCreation(true) + .setIncludeClusterAuthorizedOperations( + options.includeAuthorizedOperations())); + } + } + + @Override + void handleResponse(AbstractResponse abstractResponse) { + if (!useMetadataRequest) { + DescribeClusterResponse response = (DescribeClusterResponse) abstractResponse; + + Errors error = 
Errors.forCode(response.data().errorCode()); + if (error != Errors.NONE) { + ApiError apiError = new ApiError(error, response.data().errorMessage()); + handleFailure(apiError.exception()); + return; + } + + Map nodes = response.nodes(); + describeClusterFuture.complete(nodes.values()); + // Controller is null if controller id is equal to NO_CONTROLLER_ID + controllerFuture.complete(nodes.get(response.data().controllerId())); + clusterIdFuture.complete(response.data().clusterId()); + authorizedOperationsFuture.complete( + validAclOperations(response.data().clusterAuthorizedOperations())); + } else { + MetadataResponse response = (MetadataResponse) abstractResponse; + describeClusterFuture.complete(response.brokers()); + controllerFuture.complete(controller(response)); + clusterIdFuture.complete(response.clusterId()); + authorizedOperationsFuture.complete( + validAclOperations(response.clusterAuthorizedOperations())); + } + } + + private Node controller(MetadataResponse response) { + if (response.controller() == null || response.controller().id() == MetadataResponse.NO_CONTROLLER_ID) + return null; + return response.controller(); + } + + @Override + void handleFailure(Throwable throwable) { + describeClusterFuture.completeExceptionally(throwable); + controllerFuture.completeExceptionally(throwable); + clusterIdFuture.completeExceptionally(throwable); + authorizedOperationsFuture.completeExceptionally(throwable); + } + + @Override + boolean handleUnsupportedVersionException(final UnsupportedVersionException exception) { + if (useMetadataRequest) { + return false; + } + + useMetadataRequest = true; + return true; + } + }, now); + + return new DescribeClusterResult(describeClusterFuture, controllerFuture, clusterIdFuture, + authorizedOperationsFuture); + } + + @Override + public DescribeAclsResult describeAcls(final AclBindingFilter filter, DescribeAclsOptions options) { + if (filter.isUnknown()) { + KafkaFutureImpl> future = new KafkaFutureImpl<>(); + future.completeExceptionally(new InvalidRequestException("The AclBindingFilter " + + "must not contain UNKNOWN elements.")); + return new DescribeAclsResult(future); + } + final long now = time.milliseconds(); + final KafkaFutureImpl> future = new KafkaFutureImpl<>(); + runnable.call(new Call("describeAcls", calcDeadlineMs(now, options.timeoutMs()), + new LeastLoadedNodeProvider()) { + + @Override + DescribeAclsRequest.Builder createRequest(int timeoutMs) { + return new DescribeAclsRequest.Builder(filter); + } + + @Override + void handleResponse(AbstractResponse abstractResponse) { + DescribeAclsResponse response = (DescribeAclsResponse) abstractResponse; + if (response.error().isFailure()) { + future.completeExceptionally(response.error().exception()); + } else { + future.complete(DescribeAclsResponse.aclBindings(response.acls())); + } + } + + @Override + void handleFailure(Throwable throwable) { + future.completeExceptionally(throwable); + } + }, now); + return new DescribeAclsResult(future); + } + + @Override + public CreateAclsResult createAcls(Collection acls, CreateAclsOptions options) { + final long now = time.milliseconds(); + final Map> futures = new HashMap<>(); + final List aclCreations = new ArrayList<>(); + final List aclBindingsSent = new ArrayList<>(); + for (AclBinding acl : acls) { + if (futures.get(acl) == null) { + KafkaFutureImpl future = new KafkaFutureImpl<>(); + futures.put(acl, future); + String indefinite = acl.toFilter().findIndefiniteField(); + if (indefinite == null) { + 
aclCreations.add(CreateAclsRequest.aclCreation(acl)); + aclBindingsSent.add(acl); + } else { + future.completeExceptionally(new InvalidRequestException("Invalid ACL creation: " + + indefinite)); + } + } + } + final CreateAclsRequestData data = new CreateAclsRequestData().setCreations(aclCreations); + runnable.call(new Call("createAcls", calcDeadlineMs(now, options.timeoutMs()), + new LeastLoadedNodeProvider()) { + + @Override + CreateAclsRequest.Builder createRequest(int timeoutMs) { + return new CreateAclsRequest.Builder(data); + } + + @Override + void handleResponse(AbstractResponse abstractResponse) { + CreateAclsResponse response = (CreateAclsResponse) abstractResponse; + List responses = response.results(); + Iterator iter = responses.iterator(); + for (AclBinding aclBinding : aclBindingsSent) { + KafkaFutureImpl future = futures.get(aclBinding); + if (!iter.hasNext()) { + future.completeExceptionally(new UnknownServerException( + "The broker reported no creation result for the given ACL: " + aclBinding)); + } else { + AclCreationResult creation = iter.next(); + Errors error = Errors.forCode(creation.errorCode()); + ApiError apiError = new ApiError(error, creation.errorMessage()); + if (apiError.isFailure()) + future.completeExceptionally(apiError.exception()); + else + future.complete(null); + } + } + } + + @Override + void handleFailure(Throwable throwable) { + completeAllExceptionally(futures.values(), throwable); + } + }, now); + return new CreateAclsResult(new HashMap<>(futures)); + } + + @Override + public DeleteAclsResult deleteAcls(Collection filters, DeleteAclsOptions options) { + final long now = time.milliseconds(); + final Map> futures = new HashMap<>(); + final List aclBindingFiltersSent = new ArrayList<>(); + final List deleteAclsFilters = new ArrayList<>(); + for (AclBindingFilter filter : filters) { + if (futures.get(filter) == null) { + aclBindingFiltersSent.add(filter); + deleteAclsFilters.add(DeleteAclsRequest.deleteAclsFilter(filter)); + futures.put(filter, new KafkaFutureImpl<>()); + } + } + final DeleteAclsRequestData data = new DeleteAclsRequestData().setFilters(deleteAclsFilters); + runnable.call(new Call("deleteAcls", calcDeadlineMs(now, options.timeoutMs()), + new LeastLoadedNodeProvider()) { + + @Override + DeleteAclsRequest.Builder createRequest(int timeoutMs) { + return new DeleteAclsRequest.Builder(data); + } + + @Override + void handleResponse(AbstractResponse abstractResponse) { + DeleteAclsResponse response = (DeleteAclsResponse) abstractResponse; + List results = response.filterResults(); + Iterator iter = results.iterator(); + for (AclBindingFilter bindingFilter : aclBindingFiltersSent) { + KafkaFutureImpl future = futures.get(bindingFilter); + if (!iter.hasNext()) { + future.completeExceptionally(new UnknownServerException( + "The broker reported no deletion result for the given filter.")); + } else { + DeleteAclsFilterResult filterResult = iter.next(); + ApiError error = new ApiError(Errors.forCode(filterResult.errorCode()), filterResult.errorMessage()); + if (error.isFailure()) { + future.completeExceptionally(error.exception()); + } else { + List filterResults = new ArrayList<>(); + for (DeleteAclsMatchingAcl matchingAcl : filterResult.matchingAcls()) { + ApiError aclError = new ApiError(Errors.forCode(matchingAcl.errorCode()), + matchingAcl.errorMessage()); + AclBinding aclBinding = DeleteAclsResponse.aclBinding(matchingAcl); + filterResults.add(new FilterResult(aclBinding, aclError.exception())); + } + future.complete(new 
FilterResults(filterResults)); + } + } + } + } + + @Override + void handleFailure(Throwable throwable) { + completeAllExceptionally(futures.values(), throwable); + } + }, now); + return new DeleteAclsResult(new HashMap<>(futures)); + } + + @Override + public DescribeConfigsResult describeConfigs(Collection configResources, final DescribeConfigsOptions options) { + // Partition the requested config resources based on which broker they must be sent to with the + // null broker being used for config resources which can be obtained from any broker + final Map>> brokerFutures = new HashMap<>(configResources.size()); + + for (ConfigResource resource : configResources) { + Integer broker = nodeFor(resource); + brokerFutures.compute(broker, (key, value) -> { + if (value == null) { + value = new HashMap<>(); + } + value.put(resource, new KafkaFutureImpl<>()); + return value; + }); + } + + final long now = time.milliseconds(); + for (Map.Entry>> entry : brokerFutures.entrySet()) { + Integer broker = entry.getKey(); + Map> unified = entry.getValue(); + + runnable.call(new Call("describeConfigs", calcDeadlineMs(now, options.timeoutMs()), + broker != null ? new ConstantNodeIdProvider(broker) : new LeastLoadedNodeProvider()) { + + @Override + DescribeConfigsRequest.Builder createRequest(int timeoutMs) { + return new DescribeConfigsRequest.Builder(new DescribeConfigsRequestData() + .setResources(unified.keySet().stream() + .map(config -> + new DescribeConfigsRequestData.DescribeConfigsResource() + .setResourceName(config.name()) + .setResourceType(config.type().id()) + .setConfigurationKeys(null)) + .collect(Collectors.toList())) + .setIncludeSynonyms(options.includeSynonyms()) + .setIncludeDocumentation(options.includeDocumentation())); + } + + @Override + void handleResponse(AbstractResponse abstractResponse) { + DescribeConfigsResponse response = (DescribeConfigsResponse) abstractResponse; + for (Map.Entry entry : response.resultMap().entrySet()) { + ConfigResource configResource = entry.getKey(); + DescribeConfigsResponseData.DescribeConfigsResult describeConfigsResult = entry.getValue(); + KafkaFutureImpl future = unified.get(configResource); + if (future == null) { + if (broker != null) { + log.warn("The config {} in the response from broker {} is not in the request", + configResource, broker); + } else { + log.warn("The config {} in the response from the least loaded broker is not in the request", + configResource); + } + } else { + if (describeConfigsResult.errorCode() != Errors.NONE.code()) { + future.completeExceptionally(Errors.forCode(describeConfigsResult.errorCode()) + .exception(describeConfigsResult.errorMessage())); + } else { + future.complete(describeConfigResult(describeConfigsResult)); + } + } + } + completeUnrealizedFutures( + unified.entrySet().stream(), + configResource -> "The broker response did not contain a result for config resource " + configResource); + } + + @Override + void handleFailure(Throwable throwable) { + completeAllExceptionally(unified.values(), throwable); + } + }, now); + } + + return new DescribeConfigsResult(new HashMap<>(brokerFutures.entrySet().stream() + .flatMap(x -> x.getValue().entrySet().stream()) + .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)))); + } + + private Config describeConfigResult(DescribeConfigsResponseData.DescribeConfigsResult describeConfigsResult) { + return new Config(describeConfigsResult.configs().stream().map(config -> new ConfigEntry( + config.name(), + config.value(), + 
DescribeConfigsResponse.ConfigSource.forId(config.configSource()).source(), + config.isSensitive(), + config.readOnly(), + (config.synonyms().stream().map(synonym -> new ConfigEntry.ConfigSynonym(synonym.name(), synonym.value(), + DescribeConfigsResponse.ConfigSource.forId(synonym.source()).source()))).collect(Collectors.toList()), + DescribeConfigsResponse.ConfigType.forId(config.configType()).type(), + config.documentation() + )).collect(Collectors.toList())); + } + + private ConfigEntry.ConfigSource configSource(DescribeConfigsResponse.ConfigSource source) { + ConfigEntry.ConfigSource configSource; + switch (source) { + case TOPIC_CONFIG: + configSource = ConfigEntry.ConfigSource.DYNAMIC_TOPIC_CONFIG; + break; + case DYNAMIC_BROKER_CONFIG: + configSource = ConfigEntry.ConfigSource.DYNAMIC_BROKER_CONFIG; + break; + case DYNAMIC_DEFAULT_BROKER_CONFIG: + configSource = ConfigEntry.ConfigSource.DYNAMIC_DEFAULT_BROKER_CONFIG; + break; + case STATIC_BROKER_CONFIG: + configSource = ConfigEntry.ConfigSource.STATIC_BROKER_CONFIG; + break; + case DYNAMIC_BROKER_LOGGER_CONFIG: + configSource = ConfigEntry.ConfigSource.DYNAMIC_BROKER_LOGGER_CONFIG; + break; + case DEFAULT_CONFIG: + configSource = ConfigEntry.ConfigSource.DEFAULT_CONFIG; + break; + default: + throw new IllegalArgumentException("Unexpected config source " + source); + } + return configSource; + } + + @Override + @Deprecated + public AlterConfigsResult alterConfigs(Map configs, final AlterConfigsOptions options) { + final Map> allFutures = new HashMap<>(); + // We must make a separate AlterConfigs request for every BROKER resource we want to alter + // and send the request to that specific broker. Other resources are grouped together into + // a single request that may be sent to any broker. + final Collection unifiedRequestResources = new ArrayList<>(); + + for (ConfigResource resource : configs.keySet()) { + Integer node = nodeFor(resource); + if (node != null) { + NodeProvider nodeProvider = new ConstantNodeIdProvider(node); + allFutures.putAll(alterConfigs(configs, options, Collections.singleton(resource), nodeProvider)); + } else + unifiedRequestResources.add(resource); + } + if (!unifiedRequestResources.isEmpty()) + allFutures.putAll(alterConfigs(configs, options, unifiedRequestResources, new LeastLoadedNodeProvider())); + return new AlterConfigsResult(new HashMap<>(allFutures)); + } + + private Map> alterConfigs(Map configs, + final AlterConfigsOptions options, + Collection resources, + NodeProvider nodeProvider) { + final Map> futures = new HashMap<>(); + final Map requestMap = new HashMap<>(resources.size()); + for (ConfigResource resource : resources) { + List configEntries = new ArrayList<>(); + for (ConfigEntry configEntry: configs.get(resource).entries()) + configEntries.add(new AlterConfigsRequest.ConfigEntry(configEntry.name(), configEntry.value())); + requestMap.put(resource, new AlterConfigsRequest.Config(configEntries)); + futures.put(resource, new KafkaFutureImpl<>()); + } + + final long now = time.milliseconds(); + runnable.call(new Call("alterConfigs", calcDeadlineMs(now, options.timeoutMs()), nodeProvider) { + + @Override + public AlterConfigsRequest.Builder createRequest(int timeoutMs) { + return new AlterConfigsRequest.Builder(requestMap, options.shouldValidateOnly()); + } + + @Override + public void handleResponse(AbstractResponse abstractResponse) { + AlterConfigsResponse response = (AlterConfigsResponse) abstractResponse; + for (Map.Entry> entry : futures.entrySet()) { + KafkaFutureImpl future = 
entry.getValue(); + ApiException exception = response.errors().get(entry.getKey()).exception(); + if (exception != null) { + future.completeExceptionally(exception); + } else { + future.complete(null); + } + } + } + + @Override + void handleFailure(Throwable throwable) { + completeAllExceptionally(futures.values(), throwable); + } + }, now); + return futures; + } + + @Override + public AlterConfigsResult incrementalAlterConfigs(Map> configs, + final AlterConfigsOptions options) { + final Map> allFutures = new HashMap<>(); + // We must make a separate AlterConfigs request for every BROKER resource we want to alter + // and send the request to that specific broker. Other resources are grouped together into + // a single request that may be sent to any broker. + final Collection unifiedRequestResources = new ArrayList<>(); + + for (ConfigResource resource : configs.keySet()) { + Integer node = nodeFor(resource); + if (node != null) { + NodeProvider nodeProvider = new ConstantNodeIdProvider(node); + allFutures.putAll(incrementalAlterConfigs(configs, options, Collections.singleton(resource), nodeProvider)); + } else + unifiedRequestResources.add(resource); + } + if (!unifiedRequestResources.isEmpty()) + allFutures.putAll(incrementalAlterConfigs(configs, options, unifiedRequestResources, new LeastLoadedNodeProvider())); + + return new AlterConfigsResult(new HashMap<>(allFutures)); + } + + private Map> incrementalAlterConfigs(Map> configs, + final AlterConfigsOptions options, + Collection resources, + NodeProvider nodeProvider) { + final Map> futures = new HashMap<>(); + for (ConfigResource resource : resources) + futures.put(resource, new KafkaFutureImpl<>()); + + final long now = time.milliseconds(); + runnable.call(new Call("incrementalAlterConfigs", calcDeadlineMs(now, options.timeoutMs()), nodeProvider) { + + @Override + public IncrementalAlterConfigsRequest.Builder createRequest(int timeoutMs) { + return new IncrementalAlterConfigsRequest.Builder(resources, configs, options.shouldValidateOnly()); + } + + @Override + public void handleResponse(AbstractResponse abstractResponse) { + IncrementalAlterConfigsResponse response = (IncrementalAlterConfigsResponse) abstractResponse; + Map errors = IncrementalAlterConfigsResponse.fromResponseData(response.data()); + for (Map.Entry> entry : futures.entrySet()) { + KafkaFutureImpl future = entry.getValue(); + ApiException exception = errors.get(entry.getKey()).exception(); + if (exception != null) { + future.completeExceptionally(exception); + } else { + future.complete(null); + } + } + } + + @Override + void handleFailure(Throwable throwable) { + completeAllExceptionally(futures.values(), throwable); + } + }, now); + return futures; + } + + @Override + public AlterReplicaLogDirsResult alterReplicaLogDirs(Map replicaAssignment, final AlterReplicaLogDirsOptions options) { + final Map> futures = new HashMap<>(replicaAssignment.size()); + + for (TopicPartitionReplica replica : replicaAssignment.keySet()) + futures.put(replica, new KafkaFutureImpl<>()); + + Map replicaAssignmentByBroker = new HashMap<>(); + for (Map.Entry entry: replicaAssignment.entrySet()) { + TopicPartitionReplica replica = entry.getKey(); + String logDir = entry.getValue(); + int brokerId = replica.brokerId(); + AlterReplicaLogDirsRequestData value = replicaAssignmentByBroker.computeIfAbsent(brokerId, + key -> new AlterReplicaLogDirsRequestData()); + AlterReplicaLogDir alterReplicaLogDir = value.dirs().find(logDir); + if (alterReplicaLogDir == null) { + alterReplicaLogDir = new 
AlterReplicaLogDir();
+                alterReplicaLogDir.setPath(logDir);
+                value.dirs().add(alterReplicaLogDir);
+            }
+            AlterReplicaLogDirTopic alterReplicaLogDirTopic = alterReplicaLogDir.topics().find(replica.topic());
+            if (alterReplicaLogDirTopic == null) {
+                alterReplicaLogDirTopic = new AlterReplicaLogDirTopic().setName(replica.topic());
+                alterReplicaLogDir.topics().add(alterReplicaLogDirTopic);
+            }
+            alterReplicaLogDirTopic.partitions().add(replica.partition());
+        }
+
+        final long now = time.milliseconds();
+        for (Map.Entry<Integer, AlterReplicaLogDirsRequestData> entry : replicaAssignmentByBroker.entrySet()) {
+            final int brokerId = entry.getKey();
+            final AlterReplicaLogDirsRequestData assignment = entry.getValue();
+
+            runnable.call(new Call("alterReplicaLogDirs", calcDeadlineMs(now, options.timeoutMs()),
+                new ConstantNodeIdProvider(brokerId)) {
+
+                @Override
+                public AlterReplicaLogDirsRequest.Builder createRequest(int timeoutMs) {
+                    return new AlterReplicaLogDirsRequest.Builder(assignment);
+                }
+
+                @Override
+                public void handleResponse(AbstractResponse abstractResponse) {
+                    AlterReplicaLogDirsResponse response = (AlterReplicaLogDirsResponse) abstractResponse;
+                    for (AlterReplicaLogDirTopicResult topicResult : response.data().results()) {
+                        for (AlterReplicaLogDirPartitionResult partitionResult : topicResult.partitions()) {
+                            TopicPartitionReplica replica = new TopicPartitionReplica(
+                                    topicResult.topicName(), partitionResult.partitionIndex(), brokerId);
+                            KafkaFutureImpl<Void> future = futures.get(replica);
+                            if (future == null) {
+                                log.warn("The partition {} in the response from broker {} is not in the request",
+                                        new TopicPartition(topicResult.topicName(), partitionResult.partitionIndex()),
+                                        brokerId);
+                            } else if (partitionResult.errorCode() == Errors.NONE.code()) {
+                                future.complete(null);
+                            } else {
+                                future.completeExceptionally(Errors.forCode(partitionResult.errorCode()).exception());
+                            }
+                        }
+                    }
+                    // The server should send back a response for every replica. But do a sanity check anyway.
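+                    // Note: completeUnrealizedFutures below only scans this broker's replicas (mirroring handleFailure),
+                    // so futures for replicas assigned to other brokers are left untouched.
+                    // Caller-side sketch (illustrative only, not part of this patch):
+                    //   admin.alterReplicaLogDirs(Collections.singletonMap(
+                    //       new TopicPartitionReplica("my-topic", 0, 1), "/path/to/new/log/dir")).all().get();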
+                    completeUnrealizedFutures(
+                        futures.entrySet().stream().filter(entry -> entry.getKey().brokerId() == brokerId),
+                        replica -> "The response from broker " + brokerId +
+                            " did not contain a result for replica " + replica);
+                }
+                @Override
+                void handleFailure(Throwable throwable) {
+                    // Only completes the futures of brokerId
+                    completeAllExceptionally(
+                        futures.entrySet().stream()
+                            .filter(entry -> entry.getKey().brokerId() == brokerId)
+                            .map(Map.Entry::getValue),
+                        throwable);
+                }
+            }, now);
+        }
+
+        return new AlterReplicaLogDirsResult(new HashMap<>(futures));
+    }
+
+    @Override
+    public DescribeLogDirsResult describeLogDirs(Collection<Integer> brokers, DescribeLogDirsOptions options) {
+        final Map<Integer, KafkaFutureImpl<Map<String, LogDirDescription>>> futures = new HashMap<>(brokers.size());
+
+        final long now = time.milliseconds();
+        for (final Integer brokerId : brokers) {
+            KafkaFutureImpl<Map<String, LogDirDescription>> future = new KafkaFutureImpl<>();
+            futures.put(brokerId, future);
+
+            runnable.call(new Call("describeLogDirs", calcDeadlineMs(now, options.timeoutMs()),
+                new ConstantNodeIdProvider(brokerId)) {
+
+                @Override
+                public DescribeLogDirsRequest.Builder createRequest(int timeoutMs) {
+                    // Query selected partitions in all log directories
+                    return new DescribeLogDirsRequest.Builder(new DescribeLogDirsRequestData().setTopics(null));
+                }
+
+                @Override
+                public void handleResponse(AbstractResponse abstractResponse) {
+                    DescribeLogDirsResponse response = (DescribeLogDirsResponse) abstractResponse;
+                    Map<String, LogDirDescription> descriptions = logDirDescriptions(response);
+                    if (descriptions.size() > 0) {
+                        future.complete(descriptions);
+                    } else {
+                        // descriptions will be empty if and only if the user is not authorized to describe cluster resource.
+                        future.completeExceptionally(Errors.CLUSTER_AUTHORIZATION_FAILED.exception());
+                    }
+                }
+                @Override
+                void handleFailure(Throwable throwable) {
+                    future.completeExceptionally(throwable);
+                }
+            }, now);
+        }
+
+        return new DescribeLogDirsResult(new HashMap<>(futures));
+    }
+
+    private static Map<String, LogDirDescription> logDirDescriptions(DescribeLogDirsResponse response) {
+        Map<String, LogDirDescription> result = new HashMap<>(response.data().results().size());
+        for (DescribeLogDirsResponseData.DescribeLogDirsResult logDirResult : response.data().results()) {
+            Map<TopicPartition, ReplicaInfo> replicaInfoMap = new HashMap<>();
+            for (DescribeLogDirsResponseData.DescribeLogDirsTopic t : logDirResult.topics()) {
+                for (DescribeLogDirsResponseData.DescribeLogDirsPartition p : t.partitions()) {
+                    replicaInfoMap.put(
+                        new TopicPartition(t.name(), p.partitionIndex()),
+                        new ReplicaInfo(p.partitionSize(), p.offsetLag(), p.isFutureKey()));
+                }
+            }
+            result.put(logDirResult.logDir(), new LogDirDescription(Errors.forCode(logDirResult.errorCode()).exception(), replicaInfoMap));
+        }
+        return result;
+    }
+
+    @Override
+    public DescribeReplicaLogDirsResult describeReplicaLogDirs(Collection<TopicPartitionReplica> replicas, DescribeReplicaLogDirsOptions options) {
+        final Map<TopicPartitionReplica, KafkaFutureImpl<ReplicaLogDirInfo>> futures = new HashMap<>(replicas.size());
+
+        for (TopicPartitionReplica replica : replicas) {
+            futures.put(replica, new KafkaFutureImpl<>());
+        }
+
+        Map<Integer, DescribeLogDirsRequestData> partitionsByBroker = new HashMap<>();
+
+        for (TopicPartitionReplica replica : replicas) {
+            DescribeLogDirsRequestData requestData = partitionsByBroker.computeIfAbsent(replica.brokerId(),
+                brokerId -> new DescribeLogDirsRequestData());
+            DescribableLogDirTopic describableLogDirTopic = requestData.topics().find(replica.topic());
+            if (describableLogDirTopic == null) {
+                List<Integer> partitions = new ArrayList<>();
+                partitions.add(replica.partition());
+                describableLogDirTopic = new DescribableLogDirTopic().setTopic(replica.topic())
.setPartitions(partitions); + requestData.topics().add(describableLogDirTopic); + } else { + describableLogDirTopic.partitions().add(replica.partition()); + } + } + + final long now = time.milliseconds(); + for (Map.Entry entry: partitionsByBroker.entrySet()) { + final int brokerId = entry.getKey(); + final DescribeLogDirsRequestData topicPartitions = entry.getValue(); + final Map replicaDirInfoByPartition = new HashMap<>(); + for (DescribableLogDirTopic topicPartition: topicPartitions.topics()) { + for (Integer partitionId : topicPartition.partitions()) { + replicaDirInfoByPartition.put(new TopicPartition(topicPartition.topic(), partitionId), new ReplicaLogDirInfo()); + } + } + + runnable.call(new Call("describeReplicaLogDirs", calcDeadlineMs(now, options.timeoutMs()), + new ConstantNodeIdProvider(brokerId)) { + + @Override + public DescribeLogDirsRequest.Builder createRequest(int timeoutMs) { + // Query selected partitions in all log directories + return new DescribeLogDirsRequest.Builder(topicPartitions); + } + + @Override + public void handleResponse(AbstractResponse abstractResponse) { + DescribeLogDirsResponse response = (DescribeLogDirsResponse) abstractResponse; + for (Map.Entry responseEntry: logDirDescriptions(response).entrySet()) { + String logDir = responseEntry.getKey(); + LogDirDescription logDirInfo = responseEntry.getValue(); + + // No replica info will be provided if the log directory is offline + if (logDirInfo.error() instanceof KafkaStorageException) + continue; + if (logDirInfo.error() != null) + handleFailure(new IllegalStateException( + "The error " + logDirInfo.error().getClass().getName() + " for log directory " + logDir + " in the response from broker " + brokerId + " is illegal")); + + for (Map.Entry replicaInfoEntry: logDirInfo.replicaInfos().entrySet()) { + TopicPartition tp = replicaInfoEntry.getKey(); + ReplicaInfo replicaInfo = replicaInfoEntry.getValue(); + ReplicaLogDirInfo replicaLogDirInfo = replicaDirInfoByPartition.get(tp); + if (replicaLogDirInfo == null) { + log.warn("Server response from broker {} mentioned unknown partition {}", brokerId, tp); + } else if (replicaInfo.isFuture()) { + replicaDirInfoByPartition.put(tp, new ReplicaLogDirInfo(replicaLogDirInfo.getCurrentReplicaLogDir(), + replicaLogDirInfo.getCurrentReplicaOffsetLag(), + logDir, + replicaInfo.offsetLag())); + } else { + replicaDirInfoByPartition.put(tp, new ReplicaLogDirInfo(logDir, + replicaInfo.offsetLag(), + replicaLogDirInfo.getFutureReplicaLogDir(), + replicaLogDirInfo.getFutureReplicaOffsetLag())); + } + } + } + + for (Map.Entry entry: replicaDirInfoByPartition.entrySet()) { + TopicPartition tp = entry.getKey(); + KafkaFutureImpl future = futures.get(new TopicPartitionReplica(tp.topic(), tp.partition(), brokerId)); + future.complete(entry.getValue()); + } + } + @Override + void handleFailure(Throwable throwable) { + completeAllExceptionally(futures.values(), throwable); + } + }, now); + } + + return new DescribeReplicaLogDirsResult(new HashMap<>(futures)); + } + + @Override + public CreatePartitionsResult createPartitions(final Map newPartitions, + final CreatePartitionsOptions options) { + final Map> futures = new HashMap<>(newPartitions.size()); + final CreatePartitionsTopicCollection topics = new CreatePartitionsTopicCollection(newPartitions.size()); + for (Map.Entry entry : newPartitions.entrySet()) { + final String topic = entry.getKey(); + final NewPartitions newPartition = entry.getValue(); + List> newAssignments = newPartition.assignments(); + List assignments = 
newAssignments == null ? null : + newAssignments.stream() + .map(brokerIds -> new CreatePartitionsAssignment().setBrokerIds(brokerIds)) + .collect(Collectors.toList()); + topics.add(new CreatePartitionsTopic() + .setName(topic) + .setCount(newPartition.totalCount()) + .setAssignments(assignments)); + futures.put(topic, new KafkaFutureImpl<>()); + } + if (!topics.isEmpty()) { + final long now = time.milliseconds(); + final long deadline = calcDeadlineMs(now, options.timeoutMs()); + final Call call = getCreatePartitionsCall(options, futures, topics, + Collections.emptyMap(), now, deadline); + runnable.call(call, now); + } + return new CreatePartitionsResult(new HashMap<>(futures)); + } + + private Call getCreatePartitionsCall(final CreatePartitionsOptions options, + final Map> futures, + final CreatePartitionsTopicCollection topics, + final Map quotaExceededExceptions, + final long now, + final long deadline) { + return new Call("createPartitions", deadline, new ControllerNodeProvider()) { + @Override + public CreatePartitionsRequest.Builder createRequest(int timeoutMs) { + return new CreatePartitionsRequest.Builder( + new CreatePartitionsRequestData() + .setTopics(topics) + .setValidateOnly(options.validateOnly()) + .setTimeoutMs(timeoutMs)); + } + + @Override + public void handleResponse(AbstractResponse abstractResponse) { + // Check for controller change + handleNotControllerError(abstractResponse); + // Handle server responses for particular topics. + final CreatePartitionsResponse response = (CreatePartitionsResponse) abstractResponse; + final CreatePartitionsTopicCollection retryTopics = new CreatePartitionsTopicCollection(); + final Map retryTopicQuotaExceededExceptions = new HashMap<>(); + for (CreatePartitionsTopicResult result : response.data().results()) { + KafkaFutureImpl future = futures.get(result.name()); + if (future == null) { + log.warn("Server response mentioned unknown topic {}", result.name()); + } else { + ApiError error = new ApiError(result.errorCode(), result.errorMessage()); + if (error.isFailure()) { + if (error.is(Errors.THROTTLING_QUOTA_EXCEEDED)) { + ThrottlingQuotaExceededException quotaExceededException = new ThrottlingQuotaExceededException( + response.throttleTimeMs(), error.messageWithFallback()); + if (options.shouldRetryOnQuotaViolation()) { + retryTopics.add(topics.find(result.name()).duplicate()); + retryTopicQuotaExceededExceptions.put(result.name(), quotaExceededException); + } else { + future.completeExceptionally(quotaExceededException); + } + } else { + future.completeExceptionally(error.exception()); + } + } else { + future.complete(null); + } + } + } + // If there are topics to retry, retry them; complete unrealized futures otherwise. + if (retryTopics.isEmpty()) { + // The server should send back a response for every topic. But do a sanity check anyway. + completeUnrealizedFutures(futures.entrySet().stream(), + topic -> "The controller response did not contain a result for topic " + topic); + } else { + final long now = time.milliseconds(); + final Call call = getCreatePartitionsCall(options, futures, retryTopics, + retryTopicQuotaExceededExceptions, now, deadline); + runnable.call(call, now); + } + } + + @Override + void handleFailure(Throwable throwable) { + // If there were any topics retries due to a quota exceeded exception, we propagate + // the initial error back to the caller if the request timed out. 
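+                // maybeCompleteQuotaExceededException only applies to topics that were awaiting a retry after
+                // THROTTLING_QUOTA_EXCEEDED: if the overall request timed out, their futures are completed with the
+                // saved per-topic quota exception; every future still incomplete afterwards is failed with `throwable` below.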
+ maybeCompleteQuotaExceededException(options.shouldRetryOnQuotaViolation(), + throwable, futures, quotaExceededExceptions, (int) (time.milliseconds() - now)); + // Fail all the other remaining futures + completeAllExceptionally(futures.values(), throwable); + } + }; + } + + @Override + public DeleteRecordsResult deleteRecords(final Map recordsToDelete, + final DeleteRecordsOptions options) { + + // requests need to be sent to partitions leader nodes so ... + // ... from the provided map it's needed to create more maps grouping topic/partition per leader + + final Map> futures = new HashMap<>(recordsToDelete.size()); + for (TopicPartition topicPartition: recordsToDelete.keySet()) { + futures.put(topicPartition, new KafkaFutureImpl<>()); + } + + // preparing topics list for asking metadata about them + final Set topics = new HashSet<>(); + for (TopicPartition topicPartition: recordsToDelete.keySet()) { + topics.add(topicPartition.topic()); + } + + final long nowMetadata = time.milliseconds(); + final long deadline = calcDeadlineMs(nowMetadata, options.timeoutMs()); + // asking for topics metadata for getting partitions leaders + runnable.call(new Call("topicsMetadata", deadline, + new LeastLoadedNodeProvider()) { + + @Override + MetadataRequest.Builder createRequest(int timeoutMs) { + return new MetadataRequest.Builder(new MetadataRequestData() + .setTopics(convertToMetadataRequestTopic(topics)) + .setAllowAutoTopicCreation(false)); + } + + @Override + void handleResponse(AbstractResponse abstractResponse) { + MetadataResponse response = (MetadataResponse) abstractResponse; + + Map errors = response.errors(); + Cluster cluster = response.buildCluster(); + + // Group topic partitions by leader + Map> leaders = new HashMap<>(); + for (Map.Entry entry: recordsToDelete.entrySet()) { + TopicPartition topicPartition = entry.getKey(); + KafkaFutureImpl future = futures.get(topicPartition); + + // Fail partitions with topic errors + Errors topicError = errors.get(topicPartition.topic()); + if (errors.containsKey(topicPartition.topic())) { + future.completeExceptionally(topicError.exception()); + } else { + Node node = cluster.leaderFor(topicPartition); + if (node != null) { + Map deletionsForLeader = leaders.computeIfAbsent( + node, key -> new HashMap<>()); + DeleteRecordsTopic deleteRecords = deletionsForLeader.get(topicPartition.topic()); + if (deleteRecords == null) { + deleteRecords = new DeleteRecordsTopic() + .setName(topicPartition.topic()); + deletionsForLeader.put(topicPartition.topic(), deleteRecords); + } + deleteRecords.partitions().add(new DeleteRecordsPartition() + .setPartitionIndex(topicPartition.partition()) + .setOffset(entry.getValue().beforeOffset())); + } else { + future.completeExceptionally(Errors.LEADER_NOT_AVAILABLE.exception()); + } + } + } + + final long deleteRecordsCallTimeMs = time.milliseconds(); + + for (final Map.Entry> entry : leaders.entrySet()) { + final Map partitionDeleteOffsets = entry.getValue(); + final int brokerId = entry.getKey().id(); + + runnable.call(new Call("deleteRecords", deadline, + new ConstantNodeIdProvider(brokerId)) { + + @Override + DeleteRecordsRequest.Builder createRequest(int timeoutMs) { + return new DeleteRecordsRequest.Builder(new DeleteRecordsRequestData() + .setTimeoutMs(timeoutMs) + .setTopics(new ArrayList<>(partitionDeleteOffsets.values()))); + } + + @Override + void handleResponse(AbstractResponse abstractResponse) { + DeleteRecordsResponse response = (DeleteRecordsResponse) abstractResponse; + for (DeleteRecordsTopicResult 
topicResult: response.data().topics()) { + for (DeleteRecordsResponseData.DeleteRecordsPartitionResult partitionResult : topicResult.partitions()) { + KafkaFutureImpl future = futures.get(new TopicPartition(topicResult.name(), partitionResult.partitionIndex())); + if (partitionResult.errorCode() == Errors.NONE.code()) { + future.complete(new DeletedRecords(partitionResult.lowWatermark())); + } else { + future.completeExceptionally(Errors.forCode(partitionResult.errorCode()).exception()); + } + } + } + } + + @Override + void handleFailure(Throwable throwable) { + Stream> callFutures = + partitionDeleteOffsets.values().stream().flatMap( + recordsToDelete -> + recordsToDelete.partitions().stream().map(partitionsToDelete -> + new TopicPartition(recordsToDelete.name(), partitionsToDelete.partitionIndex())) + ).map(futures::get); + completeAllExceptionally(callFutures, throwable); + } + }, deleteRecordsCallTimeMs); + } + } + + @Override + void handleFailure(Throwable throwable) { + completeAllExceptionally(futures.values(), throwable); + } + }, nowMetadata); + + return new DeleteRecordsResult(new HashMap<>(futures)); + } + + @Override + public CreateDelegationTokenResult createDelegationToken(final CreateDelegationTokenOptions options) { + final KafkaFutureImpl delegationTokenFuture = new KafkaFutureImpl<>(); + final long now = time.milliseconds(); + List renewers = new ArrayList<>(); + for (KafkaPrincipal principal : options.renewers()) { + renewers.add(new CreatableRenewers() + .setPrincipalName(principal.getName()) + .setPrincipalType(principal.getPrincipalType())); + } + runnable.call(new Call("createDelegationToken", calcDeadlineMs(now, options.timeoutMs()), + new LeastLoadedNodeProvider()) { + + @Override + CreateDelegationTokenRequest.Builder createRequest(int timeoutMs) { + return new CreateDelegationTokenRequest.Builder( + new CreateDelegationTokenRequestData() + .setRenewers(renewers) + .setMaxLifetimeMs(options.maxlifeTimeMs())); + } + + @Override + void handleResponse(AbstractResponse abstractResponse) { + CreateDelegationTokenResponse response = (CreateDelegationTokenResponse) abstractResponse; + if (response.hasError()) { + delegationTokenFuture.completeExceptionally(response.error().exception()); + } else { + CreateDelegationTokenResponseData data = response.data(); + TokenInformation tokenInfo = new TokenInformation(data.tokenId(), new KafkaPrincipal(data.principalType(), data.principalName()), + options.renewers(), data.issueTimestampMs(), data.maxTimestampMs(), data.expiryTimestampMs()); + DelegationToken token = new DelegationToken(tokenInfo, data.hmac()); + delegationTokenFuture.complete(token); + } + } + + @Override + void handleFailure(Throwable throwable) { + delegationTokenFuture.completeExceptionally(throwable); + } + }, now); + + return new CreateDelegationTokenResult(delegationTokenFuture); + } + + @Override + public RenewDelegationTokenResult renewDelegationToken(final byte[] hmac, final RenewDelegationTokenOptions options) { + final KafkaFutureImpl expiryTimeFuture = new KafkaFutureImpl<>(); + final long now = time.milliseconds(); + runnable.call(new Call("renewDelegationToken", calcDeadlineMs(now, options.timeoutMs()), + new LeastLoadedNodeProvider()) { + + @Override + RenewDelegationTokenRequest.Builder createRequest(int timeoutMs) { + return new RenewDelegationTokenRequest.Builder( + new RenewDelegationTokenRequestData() + .setHmac(hmac) + .setRenewPeriodMs(options.renewTimePeriodMs())); + } + + @Override + void handleResponse(AbstractResponse abstractResponse) { 
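+                // On success the broker returns the token's new expiry timestamp, which is surfaced through expiryTimeFuture.
+                // Caller-side sketch (illustrative only, not part of this patch):
+                //   long newExpiryMs = admin.renewDelegationToken(token.hmac()).expiryTimestamp().get();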
+                RenewDelegationTokenResponse response = (RenewDelegationTokenResponse) abstractResponse;
+                if (response.hasError()) {
+                    expiryTimeFuture.completeExceptionally(response.error().exception());
+                } else {
+                    expiryTimeFuture.complete(response.expiryTimestamp());
+                }
+            }
+
+            @Override
+            void handleFailure(Throwable throwable) {
+                expiryTimeFuture.completeExceptionally(throwable);
+            }
+        }, now);
+
+        return new RenewDelegationTokenResult(expiryTimeFuture);
+    }
+
+    @Override
+    public ExpireDelegationTokenResult expireDelegationToken(final byte[] hmac, final ExpireDelegationTokenOptions options) {
+        final KafkaFutureImpl<Long> expiryTimeFuture = new KafkaFutureImpl<>();
+        final long now = time.milliseconds();
+        runnable.call(new Call("expireDelegationToken", calcDeadlineMs(now, options.timeoutMs()),
+            new LeastLoadedNodeProvider()) {
+
+            @Override
+            ExpireDelegationTokenRequest.Builder createRequest(int timeoutMs) {
+                return new ExpireDelegationTokenRequest.Builder(
+                        new ExpireDelegationTokenRequestData()
+                            .setHmac(hmac)
+                            .setExpiryTimePeriodMs(options.expiryTimePeriodMs()));
+            }
+
+            @Override
+            void handleResponse(AbstractResponse abstractResponse) {
+                ExpireDelegationTokenResponse response = (ExpireDelegationTokenResponse) abstractResponse;
+                if (response.hasError()) {
+                    expiryTimeFuture.completeExceptionally(response.error().exception());
+                } else {
+                    expiryTimeFuture.complete(response.expiryTimestamp());
+                }
+            }
+
+            @Override
+            void handleFailure(Throwable throwable) {
+                expiryTimeFuture.completeExceptionally(throwable);
+            }
+        }, now);
+
+        return new ExpireDelegationTokenResult(expiryTimeFuture);
+    }
+
+    @Override
+    public DescribeDelegationTokenResult describeDelegationToken(final DescribeDelegationTokenOptions options) {
+        final KafkaFutureImpl<List<DelegationToken>> tokensFuture = new KafkaFutureImpl<>();
+        final long now = time.milliseconds();
+        runnable.call(new Call("describeDelegationToken", calcDeadlineMs(now, options.timeoutMs()),
+            new LeastLoadedNodeProvider()) {
+
+            @Override
+            DescribeDelegationTokenRequest.Builder createRequest(int timeoutMs) {
+                return new DescribeDelegationTokenRequest.Builder(options.owners());
+            }
+
+            @Override
+            void handleResponse(AbstractResponse abstractResponse) {
+                DescribeDelegationTokenResponse response = (DescribeDelegationTokenResponse) abstractResponse;
+                if (response.hasError()) {
+                    tokensFuture.completeExceptionally(response.error().exception());
+                } else {
+                    tokensFuture.complete(response.tokens());
+                }
+            }
+
+            @Override
+            void handleFailure(Throwable throwable) {
+                tokensFuture.completeExceptionally(throwable);
+            }
+        }, now);
+
+        return new DescribeDelegationTokenResult(tokensFuture);
+    }
+
+    private <T, O extends AbstractOptions<O>> void rescheduleMetadataTask(MetadataOperationContext<T, O> context, Supplier<List<Call>> nextCalls) {
+        log.info("Retrying to fetch metadata.");
+        // Requeue the task so that we can re-attempt fetching metadata
+        context.setResponse(Optional.empty());
+        Call metadataCall = getMetadataCall(context, nextCalls);
+        runnable.call(metadataCall, time.milliseconds());
+    }
+
+    @Override
+    public DescribeConsumerGroupsResult describeConsumerGroups(final Collection<String> groupIds,
+                                                               final DescribeConsumerGroupsOptions options) {
+        SimpleAdminApiFuture<CoordinatorKey, ConsumerGroupDescription> future =
+                DescribeConsumerGroupsHandler.newFuture(groupIds);
+        DescribeConsumerGroupsHandler handler = new DescribeConsumerGroupsHandler(options.includeAuthorizedOperations(), logContext);
+        invokeDriver(handler, future, options.timeoutMs);
+        return new DescribeConsumerGroupsResult(future.all().entrySet().stream()
+                .collect(Collectors.toMap(entry ->
entry.getKey().idValue, Map.Entry::getValue))); + } + + /** + * Returns a {@code Call} object to fetch the cluster metadata. Takes a List of Calls + * parameter to schedule actions that need to be taken using the metadata. The param is a Supplier + * so that it can be lazily created, so that it can use the results of the metadata call in its + * construction. + * + * @param The type of return value of the KafkaFuture, like ListOffsetsResultInfo, etc. + * @param The type of configuration option, like ListOffsetsOptions, etc + */ + private > Call getMetadataCall(MetadataOperationContext context, + Supplier> nextCalls) { + return new Call("metadata", context.deadline(), new LeastLoadedNodeProvider()) { + @Override + MetadataRequest.Builder createRequest(int timeoutMs) { + return new MetadataRequest.Builder(new MetadataRequestData() + .setTopics(convertToMetadataRequestTopic(context.topics())) + .setAllowAutoTopicCreation(false)); + } + + @Override + void handleResponse(AbstractResponse abstractResponse) { + MetadataResponse response = (MetadataResponse) abstractResponse; + MetadataOperationContext.handleMetadataErrors(response); + + context.setResponse(Optional.of(response)); + + for (Call call : nextCalls.get()) { + runnable.call(call, time.milliseconds()); + } + } + + @Override + void handleFailure(Throwable throwable) { + for (KafkaFutureImpl future : context.futures().values()) { + future.completeExceptionally(throwable); + } + } + }; + } + + private Set validAclOperations(final int authorizedOperations) { + if (authorizedOperations == MetadataResponse.AUTHORIZED_OPERATIONS_OMITTED) { + return null; + } + return Utils.from32BitField(authorizedOperations) + .stream() + .map(AclOperation::fromCode) + .filter(operation -> operation != AclOperation.UNKNOWN + && operation != AclOperation.ALL + && operation != AclOperation.ANY) + .collect(Collectors.toSet()); + } + + private final static class ListConsumerGroupsResults { + private final List errors; + private final HashMap listings; + private final HashSet remaining; + private final KafkaFutureImpl> future; + + ListConsumerGroupsResults(Collection leaders, + KafkaFutureImpl> future) { + this.errors = new ArrayList<>(); + this.listings = new HashMap<>(); + this.remaining = new HashSet<>(leaders); + this.future = future; + tryComplete(); + } + + synchronized void addError(Throwable throwable, Node node) { + ApiError error = ApiError.fromThrowable(throwable); + if (error.message() == null || error.message().isEmpty()) { + errors.add(error.error().exception("Error listing groups on " + node)); + } else { + errors.add(error.error().exception("Error listing groups on " + node + ": " + error.message())); + } + } + + synchronized void addListing(ConsumerGroupListing listing) { + listings.put(listing.groupId(), listing); + } + + synchronized void tryComplete(Node leader) { + remaining.remove(leader); + tryComplete(); + } + + private synchronized void tryComplete() { + if (remaining.isEmpty()) { + ArrayList results = new ArrayList<>(listings.values()); + results.addAll(errors); + future.complete(results); + } + } + } + + @Override + public ListConsumerGroupsResult listConsumerGroups(ListConsumerGroupsOptions options) { + final KafkaFutureImpl> all = new KafkaFutureImpl<>(); + final long nowMetadata = time.milliseconds(); + final long deadline = calcDeadlineMs(nowMetadata, options.timeoutMs()); + runnable.call(new Call("findAllBrokers", deadline, new LeastLoadedNodeProvider()) { + @Override + MetadataRequest.Builder createRequest(int timeoutMs) { + return 
new MetadataRequest.Builder(new MetadataRequestData() + .setTopics(Collections.emptyList()) + .setAllowAutoTopicCreation(true)); + } + + @Override + void handleResponse(AbstractResponse abstractResponse) { + MetadataResponse metadataResponse = (MetadataResponse) abstractResponse; + Collection nodes = metadataResponse.brokers(); + if (nodes.isEmpty()) + throw new StaleMetadataException("Metadata fetch failed due to missing broker list"); + + HashSet allNodes = new HashSet<>(nodes); + final ListConsumerGroupsResults results = new ListConsumerGroupsResults(allNodes, all); + + for (final Node node : allNodes) { + final long nowList = time.milliseconds(); + runnable.call(new Call("listConsumerGroups", deadline, new ConstantNodeIdProvider(node.id())) { + @Override + ListGroupsRequest.Builder createRequest(int timeoutMs) { + List states = options.states() + .stream() + .map(s -> s.toString()) + .collect(Collectors.toList()); + return new ListGroupsRequest.Builder(new ListGroupsRequestData().setStatesFilter(states)); + } + + private void maybeAddConsumerGroup(ListGroupsResponseData.ListedGroup group) { + String protocolType = group.protocolType(); + if (protocolType.equals(ConsumerProtocol.PROTOCOL_TYPE) || protocolType.isEmpty()) { + final String groupId = group.groupId(); + final Optional state = group.groupState().equals("") + ? Optional.empty() + : Optional.of(ConsumerGroupState.parse(group.groupState())); + final ConsumerGroupListing groupListing = new ConsumerGroupListing(groupId, protocolType.isEmpty(), state); + results.addListing(groupListing); + } + } + + @Override + void handleResponse(AbstractResponse abstractResponse) { + final ListGroupsResponse response = (ListGroupsResponse) abstractResponse; + synchronized (results) { + Errors error = Errors.forCode(response.data().errorCode()); + if (error == Errors.COORDINATOR_LOAD_IN_PROGRESS || error == Errors.COORDINATOR_NOT_AVAILABLE) { + throw error.exception(); + } else if (error != Errors.NONE) { + results.addError(error.exception(), node); + } else { + for (ListGroupsResponseData.ListedGroup group : response.data().groups()) { + maybeAddConsumerGroup(group); + } + } + results.tryComplete(node); + } + } + + @Override + void handleFailure(Throwable throwable) { + synchronized (results) { + results.addError(throwable, node); + results.tryComplete(node); + } + } + }, nowList); + } + } + + @Override + void handleFailure(Throwable throwable) { + KafkaException exception = new KafkaException("Failed to find brokers to send ListGroups", throwable); + all.complete(Collections.singletonList(exception)); + } + }, nowMetadata); + + return new ListConsumerGroupsResult(all); + } + + @Override + public ListConsumerGroupOffsetsResult listConsumerGroupOffsets(final String groupId, + final ListConsumerGroupOffsetsOptions options) { + SimpleAdminApiFuture> future = + ListConsumerGroupOffsetsHandler.newFuture(groupId); + ListConsumerGroupOffsetsHandler handler = new ListConsumerGroupOffsetsHandler(groupId, options.topicPartitions(), logContext); + invokeDriver(handler, future, options.timeoutMs); + return new ListConsumerGroupOffsetsResult(future.get(CoordinatorKey.byGroupId(groupId))); + } + + @Override + public DeleteConsumerGroupsResult deleteConsumerGroups(Collection groupIds, DeleteConsumerGroupsOptions options) { + SimpleAdminApiFuture future = + DeleteConsumerGroupsHandler.newFuture(groupIds); + DeleteConsumerGroupsHandler handler = new DeleteConsumerGroupsHandler(logContext); + invokeDriver(handler, future, options.timeoutMs); + return new 
DeleteConsumerGroupsResult(future.all().entrySet().stream() + .collect(Collectors.toMap(entry -> entry.getKey().idValue, Map.Entry::getValue))); + } + + @Override + public DeleteConsumerGroupOffsetsResult deleteConsumerGroupOffsets( + String groupId, + Set partitions, + DeleteConsumerGroupOffsetsOptions options) { + SimpleAdminApiFuture> future = + DeleteConsumerGroupOffsetsHandler.newFuture(groupId); + DeleteConsumerGroupOffsetsHandler handler = new DeleteConsumerGroupOffsetsHandler(groupId, partitions, logContext); + invokeDriver(handler, future, options.timeoutMs); + return new DeleteConsumerGroupOffsetsResult(future.get(CoordinatorKey.byGroupId(groupId)), partitions); + } + + @Override + public Map metrics() { + return Collections.unmodifiableMap(this.metrics.metrics()); + } + + @Override + public ElectLeadersResult electLeaders( + final ElectionType electionType, + final Set topicPartitions, + ElectLeadersOptions options) { + final KafkaFutureImpl>> electionFuture = new KafkaFutureImpl<>(); + final long now = time.milliseconds(); + runnable.call(new Call("electLeaders", calcDeadlineMs(now, options.timeoutMs()), + new ControllerNodeProvider()) { + + @Override + public ElectLeadersRequest.Builder createRequest(int timeoutMs) { + return new ElectLeadersRequest.Builder(electionType, topicPartitions, timeoutMs); + } + + @Override + public void handleResponse(AbstractResponse abstractResponse) { + ElectLeadersResponse response = (ElectLeadersResponse) abstractResponse; + Map> result = ElectLeadersResponse.electLeadersResult(response.data()); + + // For version == 0 then errorCode would be 0 which maps to Errors.NONE + Errors error = Errors.forCode(response.data().errorCode()); + if (error != Errors.NONE) { + electionFuture.completeExceptionally(error.exception()); + return; + } + + electionFuture.complete(result); + } + + @Override + void handleFailure(Throwable throwable) { + electionFuture.completeExceptionally(throwable); + } + }, now); + + return new ElectLeadersResult(electionFuture); + } + + @Override + public AlterPartitionReassignmentsResult alterPartitionReassignments( + Map> reassignments, + AlterPartitionReassignmentsOptions options) { + final Map> futures = new HashMap<>(); + final Map>> topicsToReassignments = new TreeMap<>(); + for (Map.Entry> entry : reassignments.entrySet()) { + String topic = entry.getKey().topic(); + int partition = entry.getKey().partition(); + TopicPartition topicPartition = new TopicPartition(topic, partition); + Optional reassignment = entry.getValue(); + KafkaFutureImpl future = new KafkaFutureImpl<>(); + futures.put(topicPartition, future); + + if (topicNameIsUnrepresentable(topic)) { + future.completeExceptionally(new InvalidTopicException("The given topic name '" + + topic + "' cannot be represented in a request.")); + } else if (topicPartition.partition() < 0) { + future.completeExceptionally(new InvalidTopicException("The given partition index " + + topicPartition.partition() + " is not valid.")); + } else { + Map> partitionReassignments = + topicsToReassignments.get(topicPartition.topic()); + if (partitionReassignments == null) { + partitionReassignments = new TreeMap<>(); + topicsToReassignments.put(topic, partitionReassignments); + } + + partitionReassignments.put(partition, reassignment); + } + } + + final long now = time.milliseconds(); + Call call = new Call("alterPartitionReassignments", calcDeadlineMs(now, options.timeoutMs()), + new ControllerNodeProvider()) { + + @Override + public AlterPartitionReassignmentsRequest.Builder 
createRequest(int timeoutMs) { + AlterPartitionReassignmentsRequestData data = + new AlterPartitionReassignmentsRequestData(); + for (Map.Entry>> entry : + topicsToReassignments.entrySet()) { + String topicName = entry.getKey(); + Map> partitionsToReassignments = entry.getValue(); + + List reassignablePartitions = new ArrayList<>(); + for (Map.Entry> partitionEntry : + partitionsToReassignments.entrySet()) { + int partitionIndex = partitionEntry.getKey(); + Optional reassignment = partitionEntry.getValue(); + + ReassignablePartition reassignablePartition = new ReassignablePartition() + .setPartitionIndex(partitionIndex) + .setReplicas(reassignment.map(NewPartitionReassignment::targetReplicas).orElse(null)); + reassignablePartitions.add(reassignablePartition); + } + + ReassignableTopic reassignableTopic = new ReassignableTopic() + .setName(topicName) + .setPartitions(reassignablePartitions); + data.topics().add(reassignableTopic); + } + data.setTimeoutMs(timeoutMs); + return new AlterPartitionReassignmentsRequest.Builder(data); + } + + @Override + public void handleResponse(AbstractResponse abstractResponse) { + AlterPartitionReassignmentsResponse response = (AlterPartitionReassignmentsResponse) abstractResponse; + Map errors = new HashMap<>(); + int receivedResponsesCount = 0; + + Errors topLevelError = Errors.forCode(response.data().errorCode()); + switch (topLevelError) { + case NONE: + receivedResponsesCount += validateTopicResponses(response.data().responses(), errors); + break; + case NOT_CONTROLLER: + handleNotControllerError(topLevelError); + break; + default: + for (ReassignableTopicResponse topicResponse : response.data().responses()) { + String topicName = topicResponse.name(); + for (ReassignablePartitionResponse partition : topicResponse.partitions()) { + errors.put( + new TopicPartition(topicName, partition.partitionIndex()), + new ApiError(topLevelError, response.data().errorMessage()).exception() + ); + receivedResponsesCount += 1; + } + } + break; + } + + assertResponseCountMatch(errors, receivedResponsesCount); + for (Map.Entry entry : errors.entrySet()) { + ApiException exception = entry.getValue(); + if (exception == null) + futures.get(entry.getKey()).complete(null); + else + futures.get(entry.getKey()).completeExceptionally(exception); + } + } + + private void assertResponseCountMatch(Map errors, int receivedResponsesCount) { + int expectedResponsesCount = topicsToReassignments.values().stream().mapToInt(Map::size).sum(); + if (errors.values().stream().noneMatch(Objects::nonNull) && receivedResponsesCount != expectedResponsesCount) { + String quantifier = receivedResponsesCount > expectedResponsesCount ? "many" : "less"; + throw new UnknownServerException("The server returned too " + quantifier + " results." 
+ + "Expected " + expectedResponsesCount + " but received " + receivedResponsesCount); + } + } + + private int validateTopicResponses(List topicResponses, + Map errors) { + int receivedResponsesCount = 0; + + for (ReassignableTopicResponse topicResponse : topicResponses) { + String topicName = topicResponse.name(); + for (ReassignablePartitionResponse partResponse : topicResponse.partitions()) { + Errors partitionError = Errors.forCode(partResponse.errorCode()); + + TopicPartition tp = new TopicPartition(topicName, partResponse.partitionIndex()); + if (partitionError == Errors.NONE) { + errors.put(tp, null); + } else { + errors.put(tp, new ApiError(partitionError, partResponse.errorMessage()).exception()); + } + receivedResponsesCount += 1; + } + } + + return receivedResponsesCount; + } + + @Override + void handleFailure(Throwable throwable) { + for (KafkaFutureImpl future : futures.values()) { + future.completeExceptionally(throwable); + } + } + }; + if (!topicsToReassignments.isEmpty()) { + runnable.call(call, now); + } + return new AlterPartitionReassignmentsResult(new HashMap<>(futures)); + } + + @Override + public ListPartitionReassignmentsResult listPartitionReassignments(Optional> partitions, + ListPartitionReassignmentsOptions options) { + final KafkaFutureImpl> partitionReassignmentsFuture = new KafkaFutureImpl<>(); + if (partitions.isPresent()) { + for (TopicPartition tp : partitions.get()) { + String topic = tp.topic(); + int partition = tp.partition(); + if (topicNameIsUnrepresentable(topic)) { + partitionReassignmentsFuture.completeExceptionally(new InvalidTopicException("The given topic name '" + + topic + "' cannot be represented in a request.")); + } else if (partition < 0) { + partitionReassignmentsFuture.completeExceptionally(new InvalidTopicException("The given partition index " + + partition + " is not valid.")); + } + if (partitionReassignmentsFuture.isCompletedExceptionally()) + return new ListPartitionReassignmentsResult(partitionReassignmentsFuture); + } + } + final long now = time.milliseconds(); + runnable.call(new Call("listPartitionReassignments", calcDeadlineMs(now, options.timeoutMs()), + new ControllerNodeProvider()) { + + @Override + ListPartitionReassignmentsRequest.Builder createRequest(int timeoutMs) { + ListPartitionReassignmentsRequestData listData = new ListPartitionReassignmentsRequestData(); + listData.setTimeoutMs(timeoutMs); + + if (partitions.isPresent()) { + Map reassignmentTopicByTopicName = new HashMap<>(); + + for (TopicPartition tp : partitions.get()) { + if (!reassignmentTopicByTopicName.containsKey(tp.topic())) + reassignmentTopicByTopicName.put(tp.topic(), new ListPartitionReassignmentsTopics().setName(tp.topic())); + + reassignmentTopicByTopicName.get(tp.topic()).partitionIndexes().add(tp.partition()); + } + + listData.setTopics(new ArrayList<>(reassignmentTopicByTopicName.values())); + } + return new ListPartitionReassignmentsRequest.Builder(listData); + } + + @Override + void handleResponse(AbstractResponse abstractResponse) { + ListPartitionReassignmentsResponse response = (ListPartitionReassignmentsResponse) abstractResponse; + Errors error = Errors.forCode(response.data().errorCode()); + switch (error) { + case NONE: + break; + case NOT_CONTROLLER: + handleNotControllerError(error); + break; + default: + partitionReassignmentsFuture.completeExceptionally(new ApiError(error, response.data().errorMessage()).exception()); + break; + } + Map reassignmentMap = new HashMap<>(); + + for (OngoingTopicReassignment topicReassignment : 
response.data().topics()) { + String topicName = topicReassignment.name(); + for (OngoingPartitionReassignment partitionReassignment : topicReassignment.partitions()) { + reassignmentMap.put( + new TopicPartition(topicName, partitionReassignment.partitionIndex()), + new PartitionReassignment(partitionReassignment.replicas(), partitionReassignment.addingReplicas(), partitionReassignment.removingReplicas()) + ); + } + } + + partitionReassignmentsFuture.complete(reassignmentMap); + } + + @Override + void handleFailure(Throwable throwable) { + partitionReassignmentsFuture.completeExceptionally(throwable); + } + }, now); + + return new ListPartitionReassignmentsResult(partitionReassignmentsFuture); + } + + private void handleNotControllerError(AbstractResponse response) throws ApiException { + if (response.errorCounts().containsKey(Errors.NOT_CONTROLLER)) { + handleNotControllerError(Errors.NOT_CONTROLLER); + } + } + + private void handleNotControllerError(Errors error) throws ApiException { + metadataManager.clearController(); + metadataManager.requestUpdate(); + throw error.exception(); + } + + /** + * Returns the broker id pertaining to the given resource, or null if the resource is not associated + * with a particular broker. + */ + private Integer nodeFor(ConfigResource resource) { + if ((resource.type() == ConfigResource.Type.BROKER && !resource.isDefault()) + || resource.type() == ConfigResource.Type.BROKER_LOGGER) { + return Integer.valueOf(resource.name()); + } else { + return null; + } + } + + private List getMembersFromGroup(String groupId) { + Collection members; + try { + members = describeConsumerGroups(Collections.singleton(groupId)).describedGroups().get(groupId).get().members(); + } catch (Exception ex) { + throw new KafkaException("Encounter exception when trying to get members from group: " + groupId, ex); + } + + List membersToRemove = new ArrayList<>(); + for (final MemberDescription member : members) { + if (member.groupInstanceId().isPresent()) { + membersToRemove.add(new MemberIdentity().setGroupInstanceId(member.groupInstanceId().get())); + } else { + membersToRemove.add(new MemberIdentity().setMemberId(member.consumerId())); + } + } + return membersToRemove; + } + + @Override + public RemoveMembersFromConsumerGroupResult removeMembersFromConsumerGroup(String groupId, + RemoveMembersFromConsumerGroupOptions options) { + List members; + if (options.removeAll()) { + members = getMembersFromGroup(groupId); + } else { + members = options.members().stream().map(MemberToRemove::toMemberIdentity).collect(Collectors.toList()); + } + SimpleAdminApiFuture> future = + RemoveMembersFromConsumerGroupHandler.newFuture(groupId); + RemoveMembersFromConsumerGroupHandler handler = new RemoveMembersFromConsumerGroupHandler(groupId, members, logContext); + invokeDriver(handler, future, options.timeoutMs); + return new RemoveMembersFromConsumerGroupResult(future.get(CoordinatorKey.byGroupId(groupId)), options.members()); + } + + @Override + public AlterConsumerGroupOffsetsResult alterConsumerGroupOffsets( + String groupId, + Map offsets, + AlterConsumerGroupOffsetsOptions options + ) { + SimpleAdminApiFuture> future = + AlterConsumerGroupOffsetsHandler.newFuture(groupId); + AlterConsumerGroupOffsetsHandler handler = new AlterConsumerGroupOffsetsHandler(groupId, offsets, logContext); + invokeDriver(handler, future, options.timeoutMs); + return new AlterConsumerGroupOffsetsResult(future.get(CoordinatorKey.byGroupId(groupId))); + } + + @Override + public ListOffsetsResult listOffsets(Map 
topicPartitionOffsets, + ListOffsetsOptions options) { + + // preparing topics list for asking metadata about them + final Map> futures = new HashMap<>(topicPartitionOffsets.size()); + final Set topics = new HashSet<>(); + for (TopicPartition topicPartition : topicPartitionOffsets.keySet()) { + topics.add(topicPartition.topic()); + futures.put(topicPartition, new KafkaFutureImpl<>()); + } + + final long nowMetadata = time.milliseconds(); + final long deadline = calcDeadlineMs(nowMetadata, options.timeoutMs()); + + MetadataOperationContext context = + new MetadataOperationContext<>(topics, options, deadline, futures); + + Call metadataCall = getMetadataCall(context, + () -> KafkaAdminClient.this.getListOffsetsCalls(context, topicPartitionOffsets, futures)); + runnable.call(metadataCall, nowMetadata); + + return new ListOffsetsResult(new HashMap<>(futures)); + } + + // visible for benchmark + List getListOffsetsCalls(MetadataOperationContext context, + Map topicPartitionOffsets, + Map> futures) { + + MetadataResponse mr = context.response().orElseThrow(() -> new IllegalStateException("No Metadata response")); + Cluster clusterSnapshot = mr.buildCluster(); + List calls = new ArrayList<>(); + // grouping topic partitions per leader + Map> leaders = new HashMap<>(); + + for (Map.Entry entry: topicPartitionOffsets.entrySet()) { + + OffsetSpec offsetSpec = entry.getValue(); + TopicPartition tp = entry.getKey(); + KafkaFutureImpl future = futures.get(tp); + long offsetQuery = getOffsetFromOffsetSpec(offsetSpec); + // avoid sending listOffsets request for topics with errors + if (!mr.errors().containsKey(tp.topic())) { + Node node = clusterSnapshot.leaderFor(tp); + if (node != null) { + Map leadersOnNode = leaders.computeIfAbsent(node, k -> new HashMap<>()); + ListOffsetsTopic topic = leadersOnNode.computeIfAbsent(tp.topic(), k -> new ListOffsetsTopic().setName(tp.topic())); + topic.partitions().add(new ListOffsetsPartition().setPartitionIndex(tp.partition()).setTimestamp(offsetQuery)); + } else { + future.completeExceptionally(Errors.LEADER_NOT_AVAILABLE.exception()); + } + } else { + future.completeExceptionally(mr.errors().get(tp.topic()).exception()); + } + } + + for (final Map.Entry> entry : leaders.entrySet()) { + final int brokerId = entry.getKey().id(); + + calls.add(new Call("listOffsets on broker " + brokerId, context.deadline(), new ConstantNodeIdProvider(brokerId)) { + + final List partitionsToQuery = new ArrayList<>(entry.getValue().values()); + + private boolean supportsMaxTimestamp = partitionsToQuery.stream() + .flatMap(t -> t.partitions().stream()) + .anyMatch(p -> p.timestamp() == ListOffsetsRequest.MAX_TIMESTAMP); + + @Override + ListOffsetsRequest.Builder createRequest(int timeoutMs) { + return ListOffsetsRequest.Builder + .forConsumer(true, context.options().isolationLevel(), supportsMaxTimestamp) + .setTargetTimes(partitionsToQuery); + } + + @Override + void handleResponse(AbstractResponse abstractResponse) { + ListOffsetsResponse response = (ListOffsetsResponse) abstractResponse; + Map retryTopicPartitionOffsets = new HashMap<>(); + + for (ListOffsetsTopicResponse topic : response.topics()) { + for (ListOffsetsPartitionResponse partition : topic.partitions()) { + TopicPartition tp = new TopicPartition(topic.name(), partition.partitionIndex()); + KafkaFutureImpl future = futures.get(tp); + Errors error = Errors.forCode(partition.errorCode()); + OffsetSpec offsetRequestSpec = topicPartitionOffsets.get(tp); + if (offsetRequestSpec == null) { + log.warn("Server response mentioned 
unknown topic partition {}", tp); + } else if (MetadataOperationContext.shouldRefreshMetadata(error)) { + retryTopicPartitionOffsets.put(tp, offsetRequestSpec); + } else if (error == Errors.NONE) { + Optional leaderEpoch = (partition.leaderEpoch() == ListOffsetsResponse.UNKNOWN_EPOCH) + ? Optional.empty() + : Optional.of(partition.leaderEpoch()); + future.complete(new ListOffsetsResultInfo(partition.offset(), partition.timestamp(), leaderEpoch)); + } else { + future.completeExceptionally(error.exception()); + } + } + } + + if (retryTopicPartitionOffsets.isEmpty()) { + // The server should send back a response for every topic partition. But do a sanity check anyway. + for (ListOffsetsTopic topic : partitionsToQuery) { + for (ListOffsetsPartition partition : topic.partitions()) { + TopicPartition tp = new TopicPartition(topic.name(), partition.partitionIndex()); + ApiException error = new ApiException("The response from broker " + brokerId + + " did not contain a result for topic partition " + tp); + futures.get(tp).completeExceptionally(error); + } + } + } else { + Set retryTopics = retryTopicPartitionOffsets.keySet().stream().map( + TopicPartition::topic).collect(Collectors.toSet()); + MetadataOperationContext retryContext = + new MetadataOperationContext<>(retryTopics, context.options(), context.deadline(), futures); + rescheduleMetadataTask(retryContext, () -> getListOffsetsCalls(retryContext, retryTopicPartitionOffsets, futures)); + } + } + + @Override + void handleFailure(Throwable throwable) { + for (ListOffsetsTopic topic : entry.getValue().values()) { + for (ListOffsetsPartition partition : topic.partitions()) { + TopicPartition tp = new TopicPartition(topic.name(), partition.partitionIndex()); + KafkaFutureImpl future = futures.get(tp); + future.completeExceptionally(throwable); + } + } + } + + @Override + boolean handleUnsupportedVersionException(UnsupportedVersionException exception) { + if (supportsMaxTimestamp) { + supportsMaxTimestamp = false; + + // fail any unsupported futures and remove partitions from the downgraded retry + Iterator topicIterator = partitionsToQuery.iterator(); + while (topicIterator.hasNext()) { + ListOffsetsTopic topic = topicIterator.next(); + Iterator partitionIterator = topic.partitions().iterator(); + while (partitionIterator.hasNext()) { + ListOffsetsPartition partition = partitionIterator.next(); + if (partition.timestamp() == ListOffsetsRequest.MAX_TIMESTAMP) { + futures.get(new TopicPartition(topic.name(), partition.partitionIndex())) + .completeExceptionally(new UnsupportedVersionException( + "Broker " + brokerId + " does not support MAX_TIMESTAMP offset spec")); + partitionIterator.remove(); + } + } + if (topic.partitions().isEmpty()) { + topicIterator.remove(); + } + } + return !partitionsToQuery.isEmpty(); + } + return false; + } + }); + } + return calls; + } + + @Override + public DescribeClientQuotasResult describeClientQuotas(ClientQuotaFilter filter, DescribeClientQuotasOptions options) { + KafkaFutureImpl>> future = new KafkaFutureImpl<>(); + + final long now = time.milliseconds(); + runnable.call(new Call("describeClientQuotas", calcDeadlineMs(now, options.timeoutMs()), + new LeastLoadedNodeProvider()) { + + @Override + DescribeClientQuotasRequest.Builder createRequest(int timeoutMs) { + return new DescribeClientQuotasRequest.Builder(filter); + } + + @Override + void handleResponse(AbstractResponse abstractResponse) { + DescribeClientQuotasResponse response = (DescribeClientQuotasResponse) abstractResponse; + response.complete(future); 
+ } + + @Override + void handleFailure(Throwable throwable) { + future.completeExceptionally(throwable); + } + }, now); + + return new DescribeClientQuotasResult(future); + } + + @Override + public AlterClientQuotasResult alterClientQuotas(Collection entries, AlterClientQuotasOptions options) { + Map> futures = new HashMap<>(entries.size()); + for (ClientQuotaAlteration entry : entries) { + futures.put(entry.entity(), new KafkaFutureImpl<>()); + } + + final long now = time.milliseconds(); + runnable.call(new Call("alterClientQuotas", calcDeadlineMs(now, options.timeoutMs()), + new LeastLoadedNodeProvider()) { + + @Override + AlterClientQuotasRequest.Builder createRequest(int timeoutMs) { + return new AlterClientQuotasRequest.Builder(entries, options.validateOnly()); + } + + @Override + void handleResponse(AbstractResponse abstractResponse) { + AlterClientQuotasResponse response = (AlterClientQuotasResponse) abstractResponse; + response.complete(futures); + } + + @Override + void handleFailure(Throwable throwable) { + completeAllExceptionally(futures.values(), throwable); + } + }, now); + + return new AlterClientQuotasResult(Collections.unmodifiableMap(futures)); + } + + @Override + public DescribeUserScramCredentialsResult describeUserScramCredentials(List users, DescribeUserScramCredentialsOptions options) { + final KafkaFutureImpl dataFuture = new KafkaFutureImpl<>(); + final long now = time.milliseconds(); + Call call = new Call("describeUserScramCredentials", calcDeadlineMs(now, options.timeoutMs()), + new LeastLoadedNodeProvider()) { + @Override + public DescribeUserScramCredentialsRequest.Builder createRequest(final int timeoutMs) { + final DescribeUserScramCredentialsRequestData requestData = new DescribeUserScramCredentialsRequestData(); + + if (users != null && !users.isEmpty()) { + final List userNames = new ArrayList<>(users.size()); + + for (final String user : users) { + if (user != null) { + userNames.add(new UserName().setName(user)); + } + } + + requestData.setUsers(userNames); + } + + return new DescribeUserScramCredentialsRequest.Builder(requestData); + } + + @Override + public void handleResponse(AbstractResponse abstractResponse) { + DescribeUserScramCredentialsResponse response = (DescribeUserScramCredentialsResponse) abstractResponse; + DescribeUserScramCredentialsResponseData data = response.data(); + short messageLevelErrorCode = data.errorCode(); + if (messageLevelErrorCode != Errors.NONE.code()) { + dataFuture.completeExceptionally(Errors.forCode(messageLevelErrorCode).exception(data.errorMessage())); + } else { + dataFuture.complete(data); + } + } + + @Override + void handleFailure(Throwable throwable) { + dataFuture.completeExceptionally(throwable); + } + }; + runnable.call(call, now); + return new DescribeUserScramCredentialsResult(dataFuture); + } + + @Override + public AlterUserScramCredentialsResult alterUserScramCredentials(List alterations, + AlterUserScramCredentialsOptions options) { + final long now = time.milliseconds(); + final Map> futures = new HashMap<>(); + for (UserScramCredentialAlteration alteration: alterations) { + futures.put(alteration.user(), new KafkaFutureImpl<>()); + } + final Map userIllegalAlterationExceptions = new HashMap<>(); + // We need to keep track of users with deletions of an unknown SCRAM mechanism + final String usernameMustNotBeEmptyMsg = "Username must not be empty"; + String passwordMustNotBeEmptyMsg = "Password must not be empty"; + final String unknownScramMechanismMsg = "Unknown SCRAM mechanism"; + 
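
As a usage sketch for the describeClientQuotas and alterClientQuotas calls completed just above (not something this patch adds): the bootstrap address, the user name "alice" and the byte-rate value are placeholders, and the filter and alteration types come from the org.apache.kafka.common.quota package referenced in those method signatures.

import org.apache.kafka.clients.admin.Admin;
import org.apache.kafka.clients.admin.AdminClientConfig;
import org.apache.kafka.common.quota.ClientQuotaAlteration;
import org.apache.kafka.common.quota.ClientQuotaEntity;
import org.apache.kafka.common.quota.ClientQuotaFilter;
import org.apache.kafka.common.quota.ClientQuotaFilterComponent;

import java.util.Collections;
import java.util.Map;
import java.util.Properties;

public class ClientQuotaExample {
    public static void main(String[] args) throws Exception {
        Properties props = new Properties();
        props.put(AdminClientConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9092"); // placeholder address
        try (Admin admin = Admin.create(props)) {
            // Describe the quotas currently applied to user "alice" (placeholder name).
            ClientQuotaFilter filter = ClientQuotaFilter.containsOnly(
                    Collections.singletonList(ClientQuotaFilterComponent.ofEntity(ClientQuotaEntity.USER, "alice")));
            Map<ClientQuotaEntity, Map<String, Double>> current =
                    admin.describeClientQuotas(filter).entities().get();
            System.out.println("current quotas: " + current);

            // Give the same user a 1 MB/s producer byte-rate quota (placeholder value).
            ClientQuotaEntity entity = new ClientQuotaEntity(
                    Collections.singletonMap(ClientQuotaEntity.USER, "alice"));
            ClientQuotaAlteration alteration = new ClientQuotaAlteration(entity,
                    Collections.singletonList(new ClientQuotaAlteration.Op("producer_byte_rate", 1048576.0)));
            admin.alterClientQuotas(Collections.singletonList(alteration)).all().get();
        }
    }
}
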
alterations.stream().filter(a -> a instanceof UserScramCredentialDeletion).forEach(alteration -> { + final String user = alteration.user(); + if (user == null || user.isEmpty()) { + userIllegalAlterationExceptions.put(alteration.user(), new UnacceptableCredentialException(usernameMustNotBeEmptyMsg)); + } else { + UserScramCredentialDeletion deletion = (UserScramCredentialDeletion) alteration; + ScramMechanism mechanism = deletion.mechanism(); + if (mechanism == null || mechanism == ScramMechanism.UNKNOWN) { + userIllegalAlterationExceptions.put(user, new UnsupportedSaslMechanismException(unknownScramMechanismMsg)); + } + } + }); + // Creating an upsertion may throw InvalidKeyException or NoSuchAlgorithmException, + // so keep track of which users are affected by such a failure so we can fail all their alterations later + final Map> userInsertions = new HashMap<>(); + alterations.stream().filter(a -> a instanceof UserScramCredentialUpsertion) + .filter(alteration -> !userIllegalAlterationExceptions.containsKey(alteration.user())) + .forEach(alteration -> { + final String user = alteration.user(); + if (user == null || user.isEmpty()) { + userIllegalAlterationExceptions.put(alteration.user(), new UnacceptableCredentialException(usernameMustNotBeEmptyMsg)); + } else { + UserScramCredentialUpsertion upsertion = (UserScramCredentialUpsertion) alteration; + try { + byte[] password = upsertion.password(); + if (password == null || password.length == 0) { + userIllegalAlterationExceptions.put(user, new UnacceptableCredentialException(passwordMustNotBeEmptyMsg)); + } else { + ScramMechanism mechanism = upsertion.credentialInfo().mechanism(); + if (mechanism == null || mechanism == ScramMechanism.UNKNOWN) { + userIllegalAlterationExceptions.put(user, new UnsupportedSaslMechanismException(unknownScramMechanismMsg)); + } else { + userInsertions.putIfAbsent(user, new HashMap<>()); + userInsertions.get(user).put(mechanism, getScramCredentialUpsertion(upsertion)); + } + } + } catch (NoSuchAlgorithmException e) { + // we might overwrite an exception from a previous alteration, but we don't really care + // since we just need to mark this user as having at least one illegal alteration + // and make an exception instance available for completing the corresponding future exceptionally + userIllegalAlterationExceptions.put(user, new UnsupportedSaslMechanismException(unknownScramMechanismMsg)); + } catch (InvalidKeyException e) { + // generally shouldn't happen since we deal with the empty password case above, + // but we still need to catch/handle it + userIllegalAlterationExceptions.put(user, new UnacceptableCredentialException(e.getMessage(), e)); + } + } + }); + + // submit alterations only for users that do not have an illegal alteration as identified above + Call call = new Call("alterUserScramCredentials", calcDeadlineMs(now, options.timeoutMs()), + new ControllerNodeProvider()) { + @Override + public AlterUserScramCredentialsRequest.Builder createRequest(int timeoutMs) { + return new AlterUserScramCredentialsRequest.Builder( + new AlterUserScramCredentialsRequestData().setUpsertions(alterations.stream() + .filter(a -> a instanceof UserScramCredentialUpsertion) + .filter(a -> !userIllegalAlterationExceptions.containsKey(a.user())) + .map(a -> userInsertions.get(a.user()).get(((UserScramCredentialUpsertion) a).credentialInfo().mechanism())) + .collect(Collectors.toList())) + .setDeletions(alterations.stream() + .filter(a -> a instanceof UserScramCredentialDeletion) + .filter(a -> 
!userIllegalAlterationExceptions.containsKey(a.user())) + .map(d -> getScramCredentialDeletion((UserScramCredentialDeletion) d)) + .collect(Collectors.toList()))); + } + + @Override + public void handleResponse(AbstractResponse abstractResponse) { + AlterUserScramCredentialsResponse response = (AlterUserScramCredentialsResponse) abstractResponse; + // Check for controller change + for (Errors error : response.errorCounts().keySet()) { + if (error == Errors.NOT_CONTROLLER) { + handleNotControllerError(error); + } + } + /* Now that we have the results for the ones we sent, + * fail any users that have an illegal alteration as identified above. + * Be sure to do this after the NOT_CONTROLLER error check above + * so that all errors are consistent in that case. + */ + userIllegalAlterationExceptions.entrySet().stream().forEach(entry -> { + futures.get(entry.getKey()).completeExceptionally(entry.getValue()); + }); + response.data().results().forEach(result -> { + KafkaFutureImpl future = futures.get(result.user()); + if (future == null) { + log.warn("Server response mentioned unknown user {}", result.user()); + } else { + Errors error = Errors.forCode(result.errorCode()); + if (error != Errors.NONE) { + future.completeExceptionally(error.exception(result.errorMessage())); + } else { + future.complete(null); + } + } + }); + completeUnrealizedFutures( + futures.entrySet().stream(), + user -> "The broker response did not contain a result for user " + user); + } + + @Override + void handleFailure(Throwable throwable) { + completeAllExceptionally(futures.values(), throwable); + } + }; + runnable.call(call, now); + return new AlterUserScramCredentialsResult(new HashMap<>(futures)); + } + + private static AlterUserScramCredentialsRequestData.ScramCredentialUpsertion getScramCredentialUpsertion(UserScramCredentialUpsertion u) throws InvalidKeyException, NoSuchAlgorithmException { + AlterUserScramCredentialsRequestData.ScramCredentialUpsertion retval = new AlterUserScramCredentialsRequestData.ScramCredentialUpsertion(); + return retval.setName(u.user()) + .setMechanism(u.credentialInfo().mechanism().type()) + .setIterations(u.credentialInfo().iterations()) + .setSalt(u.salt()) + .setSaltedPassword(getSaltedPasword(u.credentialInfo().mechanism(), u.password(), u.salt(), u.credentialInfo().iterations())); + } + + private static AlterUserScramCredentialsRequestData.ScramCredentialDeletion getScramCredentialDeletion(UserScramCredentialDeletion d) { + return new AlterUserScramCredentialsRequestData.ScramCredentialDeletion().setName(d.user()).setMechanism(d.mechanism().type()); + } + + private static byte[] getSaltedPasword(ScramMechanism publicScramMechanism, byte[] password, byte[] salt, int iterations) throws NoSuchAlgorithmException, InvalidKeyException { + return new ScramFormatter(org.apache.kafka.common.security.scram.internals.ScramMechanism.forMechanismName(publicScramMechanism.mechanismName())) + .hi(password, salt, iterations); + } + + @Override + public DescribeFeaturesResult describeFeatures(final DescribeFeaturesOptions options) { + final KafkaFutureImpl future = new KafkaFutureImpl<>(); + final long now = time.milliseconds(); + final Call call = new Call( + "describeFeatures", calcDeadlineMs(now, options.timeoutMs()), new LeastLoadedNodeProvider()) { + + private FeatureMetadata createFeatureMetadata(final ApiVersionsResponse response) { + final Map finalizedFeatures = new HashMap<>(); + for (final FinalizedFeatureKey key : response.data().finalizedFeatures().valuesSet()) { + 
finalizedFeatures.put(key.name(), new FinalizedVersionRange(key.minVersionLevel(), key.maxVersionLevel())); + } + + Optional finalizedFeaturesEpoch; + if (response.data().finalizedFeaturesEpoch() >= 0L) { + finalizedFeaturesEpoch = Optional.of(response.data().finalizedFeaturesEpoch()); + } else { + finalizedFeaturesEpoch = Optional.empty(); + } + + final Map supportedFeatures = new HashMap<>(); + for (final SupportedFeatureKey key : response.data().supportedFeatures().valuesSet()) { + supportedFeatures.put(key.name(), new SupportedVersionRange(key.minVersion(), key.maxVersion())); + } + + return new FeatureMetadata(finalizedFeatures, finalizedFeaturesEpoch, supportedFeatures); + } + + @Override + ApiVersionsRequest.Builder createRequest(int timeoutMs) { + return new ApiVersionsRequest.Builder(); + } + + @Override + void handleResponse(AbstractResponse response) { + final ApiVersionsResponse apiVersionsResponse = (ApiVersionsResponse) response; + if (apiVersionsResponse.data().errorCode() == Errors.NONE.code()) { + future.complete(createFeatureMetadata(apiVersionsResponse)); + } else { + future.completeExceptionally(Errors.forCode(apiVersionsResponse.data().errorCode()).exception()); + } + } + + @Override + void handleFailure(Throwable throwable) { + completeAllExceptionally(Collections.singletonList(future), throwable); + } + }; + + runnable.call(call, now); + return new DescribeFeaturesResult(future); + } + + @Override + public UpdateFeaturesResult updateFeatures(final Map featureUpdates, + final UpdateFeaturesOptions options) { + if (featureUpdates.isEmpty()) { + throw new IllegalArgumentException("Feature updates can not be null or empty."); + } + + final Map> updateFutures = new HashMap<>(); + for (final Map.Entry entry : featureUpdates.entrySet()) { + final String feature = entry.getKey(); + if (Utils.isBlank(feature)) { + throw new IllegalArgumentException("Provided feature can not be empty."); + } + updateFutures.put(entry.getKey(), new KafkaFutureImpl<>()); + } + + final long now = time.milliseconds(); + final Call call = new Call("updateFeatures", calcDeadlineMs(now, options.timeoutMs()), + new ControllerNodeProvider()) { + + @Override + UpdateFeaturesRequest.Builder createRequest(int timeoutMs) { + final UpdateFeaturesRequestData.FeatureUpdateKeyCollection featureUpdatesRequestData + = new UpdateFeaturesRequestData.FeatureUpdateKeyCollection(); + for (Map.Entry entry : featureUpdates.entrySet()) { + final String feature = entry.getKey(); + final FeatureUpdate update = entry.getValue(); + final UpdateFeaturesRequestData.FeatureUpdateKey requestItem = + new UpdateFeaturesRequestData.FeatureUpdateKey(); + requestItem.setFeature(feature); + requestItem.setMaxVersionLevel(update.maxVersionLevel()); + requestItem.setAllowDowngrade(update.allowDowngrade()); + featureUpdatesRequestData.add(requestItem); + } + return new UpdateFeaturesRequest.Builder( + new UpdateFeaturesRequestData() + .setTimeoutMs(timeoutMs) + .setFeatureUpdates(featureUpdatesRequestData)); + } + + @Override + void handleResponse(AbstractResponse abstractResponse) { + final UpdateFeaturesResponse response = + (UpdateFeaturesResponse) abstractResponse; + + ApiError topLevelError = response.topLevelError(); + switch (topLevelError.error()) { + case NONE: + for (final UpdatableFeatureResult result : response.data().results()) { + final KafkaFutureImpl future = updateFutures.get(result.feature()); + if (future == null) { + log.warn("Server response mentioned unknown feature {}", result.feature()); + } else { + final Errors 
error = Errors.forCode(result.errorCode()); + if (error == Errors.NONE) { + future.complete(null); + } else { + future.completeExceptionally(error.exception(result.errorMessage())); + } + } + } + // The server should send back a response for every feature, but we do a sanity check anyway. + completeUnrealizedFutures(updateFutures.entrySet().stream(), + feature -> "The controller response did not contain a result for feature " + feature); + break; + case NOT_CONTROLLER: + handleNotControllerError(topLevelError.error()); + break; + default: + for (final Map.Entry> entry : updateFutures.entrySet()) { + entry.getValue().completeExceptionally(topLevelError.exception()); + } + break; + } + } + + @Override + void handleFailure(Throwable throwable) { + completeAllExceptionally(updateFutures.values(), throwable); + } + }; + + runnable.call(call, now); + return new UpdateFeaturesResult(new HashMap<>(updateFutures)); + } + + @Override + public UnregisterBrokerResult unregisterBroker(int brokerId, UnregisterBrokerOptions options) { + final KafkaFutureImpl future = new KafkaFutureImpl<>(); + final long now = time.milliseconds(); + final Call call = new Call("unregisterBroker", calcDeadlineMs(now, options.timeoutMs()), + new LeastLoadedNodeProvider()) { + + @Override + UnregisterBrokerRequest.Builder createRequest(int timeoutMs) { + UnregisterBrokerRequestData data = + new UnregisterBrokerRequestData().setBrokerId(brokerId); + return new UnregisterBrokerRequest.Builder(data); + } + + @Override + void handleResponse(AbstractResponse abstractResponse) { + final UnregisterBrokerResponse response = + (UnregisterBrokerResponse) abstractResponse; + Errors error = Errors.forCode(response.data().errorCode()); + switch (error) { + case NONE: + future.complete(null); + break; + case REQUEST_TIMED_OUT: + throw error.exception(); + default: + log.error("Unregister broker request for broker ID {} failed: {}", + brokerId, error.message()); + future.completeExceptionally(error.exception()); + break; + } + } + + @Override + void handleFailure(Throwable throwable) { + future.completeExceptionally(throwable); + } + }; + runnable.call(call, now); + return new UnregisterBrokerResult(future); + } + + @Override + public DescribeProducersResult describeProducers(Collection topicPartitions, DescribeProducersOptions options) { + AdminApiFuture.SimpleAdminApiFuture future = + DescribeProducersHandler.newFuture(topicPartitions); + DescribeProducersHandler handler = new DescribeProducersHandler(options, logContext); + invokeDriver(handler, future, options.timeoutMs); + return new DescribeProducersResult(future.all()); + } + + @Override + public DescribeTransactionsResult describeTransactions(Collection transactionalIds, DescribeTransactionsOptions options) { + AdminApiFuture.SimpleAdminApiFuture future = + DescribeTransactionsHandler.newFuture(transactionalIds); + DescribeTransactionsHandler handler = new DescribeTransactionsHandler(logContext); + invokeDriver(handler, future, options.timeoutMs); + return new DescribeTransactionsResult(future.all()); + } + + @Override + public AbortTransactionResult abortTransaction(AbortTransactionSpec spec, AbortTransactionOptions options) { + AdminApiFuture.SimpleAdminApiFuture future = + AbortTransactionHandler.newFuture(Collections.singleton(spec.topicPartition())); + AbortTransactionHandler handler = new AbortTransactionHandler(spec, logContext); + invokeDriver(handler, future, options.timeoutMs); + return new AbortTransactionResult(future.all()); + } + + @Override + public 
ListTransactionsResult listTransactions(ListTransactionsOptions options) { + AllBrokersStrategy.AllBrokersFuture> future = + ListTransactionsHandler.newFuture(); + ListTransactionsHandler handler = new ListTransactionsHandler(options, logContext); + invokeDriver(handler, future, options.timeoutMs); + return new ListTransactionsResult(future.all()); + } + + private void invokeDriver( + AdminApiHandler handler, + AdminApiFuture future, + Integer timeoutMs + ) { + long currentTimeMs = time.milliseconds(); + long deadlineMs = calcDeadlineMs(currentTimeMs, timeoutMs); + + AdminApiDriver driver = new AdminApiDriver<>( + handler, + future, + deadlineMs, + retryBackoffMs, + logContext + ); + + maybeSendRequests(driver, currentTimeMs); + } + + private void maybeSendRequests(AdminApiDriver driver, long currentTimeMs) { + for (AdminApiDriver.RequestSpec spec : driver.poll()) { + runnable.call(newCall(driver, spec), currentTimeMs); + } + } + + private Call newCall(AdminApiDriver driver, AdminApiDriver.RequestSpec spec) { + NodeProvider nodeProvider = spec.scope.destinationBrokerId().isPresent() ? + new ConstantNodeIdProvider(spec.scope.destinationBrokerId().getAsInt()) : + new LeastLoadedNodeProvider(); + return new Call(spec.name, spec.nextAllowedTryMs, spec.tries, spec.deadlineMs, nodeProvider) { + @Override + AbstractRequest.Builder createRequest(int timeoutMs) { + return spec.request; + } + + @Override + void handleResponse(AbstractResponse response) { + long currentTimeMs = time.milliseconds(); + driver.onResponse(currentTimeMs, spec, response, this.curNode()); + maybeSendRequests(driver, currentTimeMs); + } + + @Override + void handleFailure(Throwable throwable) { + long currentTimeMs = time.milliseconds(); + driver.onFailure(currentTimeMs, spec, throwable); + maybeSendRequests(driver, currentTimeMs); + } + + @Override + void maybeRetry(long currentTimeMs, Throwable throwable) { + if (throwable instanceof DisconnectException) { + // Disconnects are a special case. We want to give the driver a chance + // to retry lookup rather than getting stuck on a node which is down. + // For example, if a partition leader shuts down after our metadata query, + // then we might get a disconnect. We want to try to find the new partition + // leader rather than retrying on the same node. + driver.onFailure(currentTimeMs, spec, throwable); + maybeSendRequests(driver, currentTimeMs); + } else { + super.maybeRetry(currentTimeMs, throwable); + } + } + }; + } + + private long getOffsetFromOffsetSpec(OffsetSpec offsetSpec) { + if (offsetSpec instanceof TimestampSpec) { + return ((TimestampSpec) offsetSpec).timestamp(); + } else if (offsetSpec instanceof OffsetSpec.EarliestSpec) { + return ListOffsetsRequest.EARLIEST_TIMESTAMP; + } else if (offsetSpec instanceof OffsetSpec.MaxTimestampSpec) { + return ListOffsetsRequest.MAX_TIMESTAMP; + } + return ListOffsetsRequest.LATEST_TIMESTAMP; + } + + /** + * Get a sub level error when the request is in batch. If given key was not found, + * return an {@link IllegalArgumentException}. 
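
The SCRAM credential methods earlier in this class (describeUserScramCredentials and alterUserScramCredentials) would typically be driven along these lines; a rough sketch that assumes an already-created Admin client named admin, with the user name, password and iteration count as placeholders.

import org.apache.kafka.clients.admin.Admin;
import org.apache.kafka.clients.admin.ScramCredentialInfo;
import org.apache.kafka.clients.admin.ScramMechanism;
import org.apache.kafka.clients.admin.UserScramCredentialDeletion;
import org.apache.kafka.clients.admin.UserScramCredentialUpsertion;

import java.util.Arrays;

public class ScramCredentialExample {
    static void rotateCredentials(Admin admin) throws Exception {
        // Register a SCRAM-SHA-256 credential for "alice" and drop her old SCRAM-SHA-512 one.
        ScramCredentialInfo info = new ScramCredentialInfo(ScramMechanism.SCRAM_SHA_256, 8192);
        admin.alterUserScramCredentials(Arrays.asList(
                new UserScramCredentialUpsertion("alice", info, "alice-secret"),
                new UserScramCredentialDeletion("alice", ScramMechanism.SCRAM_SHA_512)))
            .all()
            .get();

        // Confirm which mechanisms are now configured for each user.
        admin.describeUserScramCredentials().all().get().forEach((user, description) ->
                System.out.println(user + " -> " + description.credentialInfos()));
    }
}
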
+ */ + static Throwable getSubLevelError(Map subLevelErrors, K subKey, String keyNotFoundMsg) { + if (!subLevelErrors.containsKey(subKey)) { + return new IllegalArgumentException(keyNotFoundMsg); + } else { + return subLevelErrors.get(subKey).exception(); + } + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/ListConsumerGroupOffsetsOptions.java b/clients/src/main/java/org/apache/kafka/clients/admin/ListConsumerGroupOffsetsOptions.java new file mode 100644 index 0000000..af738ca --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/ListConsumerGroupOffsetsOptions.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.clients.admin; + +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.annotation.InterfaceStability; + +import java.util.List; + +/** + * Options for {@link Admin#listConsumerGroupOffsets(String)}. + *
+ * The API of this class is evolving, see {@link Admin} for details. + */ +@InterfaceStability.Evolving +public class ListConsumerGroupOffsetsOptions extends AbstractOptions { + + private List topicPartitions = null; + + /** + * Set the topic partitions to list as part of the result. + * {@code null} includes all topic partitions. + * + * @param topicPartitions List of topic partitions to include + * @return This ListGroupOffsetsOptions + */ + public ListConsumerGroupOffsetsOptions topicPartitions(List topicPartitions) { + this.topicPartitions = topicPartitions; + return this; + } + + /** + * Returns a list of topic partitions to add as part of the result. + */ + public List topicPartitions() { + return topicPartitions; + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/ListConsumerGroupOffsetsResult.java b/clients/src/main/java/org/apache/kafka/clients/admin/ListConsumerGroupOffsetsResult.java new file mode 100644 index 0000000..48f4531 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/ListConsumerGroupOffsetsResult.java @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.clients.admin; + +import org.apache.kafka.clients.consumer.OffsetAndMetadata; +import org.apache.kafka.common.KafkaFuture; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.annotation.InterfaceStability; + +import java.util.Map; + +/** + * The result of the {@link Admin#listConsumerGroupOffsets(String)} call. + *
+ * The API of this class is evolving, see {@link Admin} for details. + */ +@InterfaceStability.Evolving +public class ListConsumerGroupOffsetsResult { + + final KafkaFuture> future; + + ListConsumerGroupOffsetsResult(KafkaFuture> future) { + this.future = future; + } + + /** + * Return a future which yields a map of topic partitions to OffsetAndMetadata objects. + * If the group does not have a committed offset for this partition, the corresponding value in the returned map will be null. + */ + public KafkaFuture> partitionsToOffsetAndMetadata() { + return future; + } + +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/ListConsumerGroupsOptions.java b/clients/src/main/java/org/apache/kafka/clients/admin/ListConsumerGroupsOptions.java new file mode 100644 index 0000000..9f1f38d --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/ListConsumerGroupsOptions.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.clients.admin; + +import java.util.Collections; +import java.util.HashSet; +import java.util.Set; + +import org.apache.kafka.common.ConsumerGroupState; +import org.apache.kafka.common.annotation.InterfaceStability; + +/** + * Options for {@link Admin#listConsumerGroups()}. + * + * The API of this class is evolving, see {@link Admin} for details. + */ +@InterfaceStability.Evolving +public class ListConsumerGroupsOptions extends AbstractOptions { + + private Set states = Collections.emptySet(); + + /** + * If states is set, only groups in these states will be returned by listConsumerGroups() + * Otherwise, all groups are returned. + * This operation is supported by brokers with version 2.6.0 or later. + */ + public ListConsumerGroupsOptions inStates(Set states) { + this.states = (states == null) ? Collections.emptySet() : new HashSet<>(states); + return this; + } + + /** + * Returns the list of States that are requested or empty if no states have been specified + */ + public Set states() { + return states; + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/ListConsumerGroupsResult.java b/clients/src/main/java/org/apache/kafka/clients/admin/ListConsumerGroupsResult.java new file mode 100644 index 0000000..2d1c612 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/ListConsumerGroupsResult.java @@ -0,0 +1,102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
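
A small sketch of how the ListConsumerGroupOffsetsOptions and ListConsumerGroupOffsetsResult classes above are meant to be combined, assuming an existing Admin client; the group and topic names are invented.

import org.apache.kafka.clients.admin.Admin;
import org.apache.kafka.clients.admin.ListConsumerGroupOffsetsOptions;
import org.apache.kafka.clients.consumer.OffsetAndMetadata;
import org.apache.kafka.common.TopicPartition;

import java.util.Collections;
import java.util.Map;

public class GroupOffsetsExample {
    static void printCommittedOffsets(Admin admin) throws Exception {
        // Restrict the lookup to one partition; a null topicPartitions list would return all of them.
        ListConsumerGroupOffsetsOptions options = new ListConsumerGroupOffsetsOptions()
                .topicPartitions(Collections.singletonList(new TopicPartition("orders", 0)));
        Map<TopicPartition, OffsetAndMetadata> offsets =
                admin.listConsumerGroupOffsets("order-processor", options)
                     .partitionsToOffsetAndMetadata()
                     .get();
        // The value is null when the group has no committed offset for a partition.
        offsets.forEach((tp, committed) -> System.out.println(
                tp + " -> " + (committed == null ? "no committed offset" : committed.offset())));
    }
}
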
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.clients.admin; + +import org.apache.kafka.common.KafkaFuture; +import org.apache.kafka.common.annotation.InterfaceStability; +import org.apache.kafka.common.internals.KafkaFutureImpl; + +import java.util.ArrayList; +import java.util.Collection; + +/** + * The result of the {@link Admin#listConsumerGroups()} call. + *
+ * The API of this class is evolving, see {@link Admin} for details. + */ +@InterfaceStability.Evolving +public class ListConsumerGroupsResult { + private final KafkaFutureImpl> all; + private final KafkaFutureImpl> valid; + private final KafkaFutureImpl> errors; + + ListConsumerGroupsResult(KafkaFuture> future) { + this.all = new KafkaFutureImpl<>(); + this.valid = new KafkaFutureImpl<>(); + this.errors = new KafkaFutureImpl<>(); + future.thenApply(new KafkaFuture.BaseFunction, Void>() { + @Override + public Void apply(Collection results) { + ArrayList curErrors = new ArrayList<>(); + ArrayList curValid = new ArrayList<>(); + for (Object resultObject : results) { + if (resultObject instanceof Throwable) { + curErrors.add((Throwable) resultObject); + } else { + curValid.add((ConsumerGroupListing) resultObject); + } + } + if (!curErrors.isEmpty()) { + all.completeExceptionally(curErrors.get(0)); + } else { + all.complete(curValid); + } + valid.complete(curValid); + errors.complete(curErrors); + return null; + } + }); + } + + /** + * Returns a future that yields either an exception, or the full set of consumer group + * listings. + * + * In the event of a failure, the future yields nothing but the first exception which + * occurred. + */ + public KafkaFuture> all() { + return all; + } + + /** + * Returns a future which yields just the valid listings. + * + * This future never fails with an error, no matter what happens. Errors are completely + * ignored. If nothing can be fetched, an empty collection is yielded. + * If there is an error, but some results can be returned, this future will yield + * those partial results. When using this future, it is a good idea to also check + * the errors future so that errors can be displayed and handled. + */ + public KafkaFuture> valid() { + return valid; + } + + /** + * Returns a future which yields just the errors which occurred. + * + * If this future yields a non-empty collection, it is very likely that elements are + * missing from the valid() set. + * + * This future itself never fails with an error. In the event of an error, this future + * will successfully yield a collection containing at least one exception. + */ + public KafkaFuture> errors() { + return errors; + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/ListOffsetsOptions.java b/clients/src/main/java/org/apache/kafka/clients/admin/ListOffsetsOptions.java new file mode 100644 index 0000000..684ad26 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/ListOffsetsOptions.java @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
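
The ListConsumerGroupsOptions state filter and the valid()/errors() split of ListConsumerGroupsResult shown above could be used roughly as follows; the Admin client and the choice of the STABLE state are assumptions for illustration.

import org.apache.kafka.clients.admin.Admin;
import org.apache.kafka.clients.admin.ConsumerGroupListing;
import org.apache.kafka.clients.admin.ListConsumerGroupsOptions;
import org.apache.kafka.clients.admin.ListConsumerGroupsResult;
import org.apache.kafka.common.ConsumerGroupState;

import java.util.Collections;

public class ListGroupsExample {
    static void printStableGroups(Admin admin) throws Exception {
        // Only groups in the STABLE state; the state filter requires brokers on 2.6.0 or later.
        ListConsumerGroupsOptions options = new ListConsumerGroupsOptions()
                .inStates(Collections.singleton(ConsumerGroupState.STABLE));
        ListConsumerGroupsResult result = admin.listConsumerGroups(options);

        // valid() never fails, so check errors() separately to surface partial failures.
        for (ConsumerGroupListing listing : result.valid().get()) {
            System.out.println(listing.groupId() + " state=" + listing.state());
        }
        result.errors().get().forEach(error -> System.err.println("listing failed: " + error));
    }
}
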
+ */ +package org.apache.kafka.clients.admin; + +import org.apache.kafka.common.IsolationLevel; +import org.apache.kafka.common.annotation.InterfaceStability; + +import java.util.Map; + +/** + * Options for {@link AdminClient#listOffsets(Map)}. + * + * The API of this class is evolving, see {@link AdminClient} for details. + */ +@InterfaceStability.Evolving +public class ListOffsetsOptions extends AbstractOptions { + + private final IsolationLevel isolationLevel; + + public ListOffsetsOptions() { + this(IsolationLevel.READ_UNCOMMITTED); + } + + public ListOffsetsOptions(IsolationLevel isolationLevel) { + this.isolationLevel = isolationLevel; + } + + public IsolationLevel isolationLevel() { + return isolationLevel; + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/ListOffsetsResult.java b/clients/src/main/java/org/apache/kafka/clients/admin/ListOffsetsResult.java new file mode 100644 index 0000000..5eb00de --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/ListOffsetsResult.java @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients.admin; + +import java.util.HashMap; +import java.util.Map; +import java.util.Optional; +import java.util.concurrent.ExecutionException; + +import org.apache.kafka.common.KafkaFuture; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.annotation.InterfaceStability; + +/** + * The result of the {@link AdminClient#listOffsets(Map)} call. + * + * The API of this class is evolving, see {@link AdminClient} for details. + */ +@InterfaceStability.Evolving +public class ListOffsetsResult { + + private final Map> futures; + + public ListOffsetsResult(Map> futures) { + this.futures = futures; + } + + /** + * Return a future which can be used to check the result for a given partition. + */ + public KafkaFuture partitionResult(final TopicPartition partition) { + KafkaFuture future = futures.get(partition); + if (future == null) { + throw new IllegalArgumentException( + "List Offsets for partition \"" + partition + "\" was not attempted"); + } + return future; + } + + /** + * Return a future which succeeds only if offsets for all specified partitions have been successfully + * retrieved. + */ + public KafkaFuture> all() { + return KafkaFuture.allOf(futures.values().toArray(new KafkaFuture[0])) + .thenApply(new KafkaFuture.BaseFunction>() { + @Override + public Map apply(Void v) { + Map offsets = new HashMap<>(futures.size()); + for (Map.Entry> entry : futures.entrySet()) { + try { + offsets.put(entry.getKey(), entry.getValue().get()); + } catch (InterruptedException | ExecutionException e) { + // This should be unreachable, because allOf ensured that all the futures completed successfully. 
+ throw new RuntimeException(e); + } + } + return offsets; + } + }); + } + + public static class ListOffsetsResultInfo { + + private final long offset; + private final long timestamp; + private final Optional leaderEpoch; + + public ListOffsetsResultInfo(long offset, long timestamp, Optional leaderEpoch) { + this.offset = offset; + this.timestamp = timestamp; + this.leaderEpoch = leaderEpoch; + } + + public long offset() { + return offset; + } + + public long timestamp() { + return timestamp; + } + + public Optional leaderEpoch() { + return leaderEpoch; + } + + @Override + public String toString() { + return "ListOffsetsResultInfo(offset=" + offset + ", timestamp=" + timestamp + ", leaderEpoch=" + + leaderEpoch + ")"; + } + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/ListPartitionReassignmentsOptions.java b/clients/src/main/java/org/apache/kafka/clients/admin/ListPartitionReassignmentsOptions.java new file mode 100644 index 0000000..7dcc7a6 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/ListPartitionReassignmentsOptions.java @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.clients.admin; + +import org.apache.kafka.common.annotation.InterfaceStability; + +/** + * Options for {@link AdminClient#listPartitionReassignments(ListPartitionReassignmentsOptions)} + * + * The API of this class is evolving. See {@link AdminClient} for details. + */ +@InterfaceStability.Evolving +public class ListPartitionReassignmentsOptions extends AbstractOptions { +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/ListPartitionReassignmentsResult.java b/clients/src/main/java/org/apache/kafka/clients/admin/ListPartitionReassignmentsResult.java new file mode 100644 index 0000000..bc72c06 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/ListPartitionReassignmentsResult.java @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
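
A minimal sketch combining ListOffsetsOptions and ListOffsetsResult from above, assuming an existing Admin client; OffsetSpec.latest() and the topic name are placeholders taken from the wider admin API rather than this hunk.

import org.apache.kafka.clients.admin.Admin;
import org.apache.kafka.clients.admin.ListOffsetsOptions;
import org.apache.kafka.clients.admin.ListOffsetsResult.ListOffsetsResultInfo;
import org.apache.kafka.clients.admin.OffsetSpec;
import org.apache.kafka.common.IsolationLevel;
import org.apache.kafka.common.TopicPartition;

import java.util.Collections;

public class EndOffsetExample {
    static void printEndOffset(Admin admin) throws Exception {
        TopicPartition tp = new TopicPartition("orders", 0); // placeholder topic
        // READ_COMMITTED reports the last stable offset instead of the high watermark.
        ListOffsetsOptions options = new ListOffsetsOptions(IsolationLevel.READ_COMMITTED);
        ListOffsetsResultInfo info = admin.listOffsets(
                Collections.singletonMap(tp, OffsetSpec.latest()), options)
                .partitionResult(tp)
                .get();
        System.out.println(tp + " end offset=" + info.offset() + " timestamp=" + info.timestamp());
    }
}
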
+ */ + +package org.apache.kafka.clients.admin; + +import org.apache.kafka.common.KafkaFuture; +import org.apache.kafka.common.TopicPartition; + +import java.util.Map; + +/** + * The result of {@link AdminClient#listPartitionReassignments(ListPartitionReassignmentsOptions)}. + * + * The API of this class is evolving. See {@link AdminClient} for details. + */ +public class ListPartitionReassignmentsResult { + private final KafkaFuture> future; + + ListPartitionReassignmentsResult(KafkaFuture> reassignments) { + this.future = reassignments; + } + + /** + * Return a future which yields a map containing each partition's reassignments + */ + public KafkaFuture> reassignments() { + return future; + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/ListTopicsOptions.java b/clients/src/main/java/org/apache/kafka/clients/admin/ListTopicsOptions.java new file mode 100644 index 0000000..4ffa66d --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/ListTopicsOptions.java @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.clients.admin; + +import org.apache.kafka.common.annotation.InterfaceStability; + +import java.util.Objects; + +/** + * Options for {@link Admin#listTopics()}. + * + * The API of this class is evolving, see {@link Admin} for details. + */ +@InterfaceStability.Evolving +public class ListTopicsOptions extends AbstractOptions { + + private boolean listInternal = false; + + /** + * Set the timeout in milliseconds for this operation or {@code null} if the default api timeout for the + * AdminClient should be used. + * + */ + // This method is retained to keep binary compatibility with 0.11 + public ListTopicsOptions timeoutMs(Integer timeoutMs) { + this.timeoutMs = timeoutMs; + return this; + } + + /** + * Set whether we should list internal topics. + * + * @param listInternal Whether we should list internal topics. null means to use + * the default. + * @return This ListTopicsOptions object. + */ + public ListTopicsOptions listInternal(boolean listInternal) { + this.listInternal = listInternal; + return this; + } + + /** + * Return true if we should list internal topics. 
+ */ + public boolean shouldListInternal() { + return listInternal; + } + + @Override + public String toString() { + return "ListTopicsOptions(" + + "listInternal=" + listInternal + + ')'; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + ListTopicsOptions that = (ListTopicsOptions) o; + return listInternal == that.listInternal; + } + + @Override + public int hashCode() { + return Objects.hash(listInternal); + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/ListTopicsResult.java b/clients/src/main/java/org/apache/kafka/clients/admin/ListTopicsResult.java new file mode 100644 index 0000000..2154073 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/ListTopicsResult.java @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.clients.admin; + +import org.apache.kafka.common.KafkaFuture; +import org.apache.kafka.common.annotation.InterfaceStability; + +import java.util.Collection; +import java.util.Map; +import java.util.Set; + +/** + * The result of the {@link Admin#listTopics()} call. + * + * The API of this class is evolving, see {@link Admin} for details. + */ +@InterfaceStability.Evolving +public class ListTopicsResult { + final KafkaFuture> future; + + ListTopicsResult(KafkaFuture> future) { + this.future = future; + } + + /** + * Return a future which yields a map of topic names to TopicListing objects. + */ + public KafkaFuture> namesToListings() { + return future; + } + + /** + * Return a future which yields a collection of TopicListing objects. + */ + public KafkaFuture> listings() { + return future.thenApply(namesToDescriptions -> namesToDescriptions.values()); + } + + /** + * Return a future which yields a collection of topic names. + */ + public KafkaFuture> names() { + return future.thenApply(namesToListings -> namesToListings.keySet()); + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/ListTransactionsOptions.java b/clients/src/main/java/org/apache/kafka/clients/admin/ListTransactionsOptions.java new file mode 100644 index 0000000..c23d444 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/ListTransactionsOptions.java @@ -0,0 +1,106 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
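
ListTopicsOptions and ListTopicsResult above might be exercised like this; a sketch that assumes an existing Admin client and uses the TopicListing accessors from the same package.

import org.apache.kafka.clients.admin.Admin;
import org.apache.kafka.clients.admin.ListTopicsOptions;
import org.apache.kafka.clients.admin.TopicListing;

public class ListTopicsExample {
    static void printTopics(Admin admin) throws Exception {
        // Include internal topics such as __consumer_offsets in the listing.
        ListTopicsOptions options = new ListTopicsOptions().listInternal(true);
        for (TopicListing listing : admin.listTopics(options).listings().get()) {
            System.out.println(listing.name() + (listing.isInternal() ? " (internal)" : ""));
        }
    }
}
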
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.clients.admin; + +import org.apache.kafka.common.annotation.InterfaceStability; + +import java.util.Collection; +import java.util.Collections; +import java.util.HashSet; +import java.util.Objects; +import java.util.Set; + +/** + * Options for {@link Admin#listTransactions()}. + * + * The API of this class is evolving, see {@link Admin} for details. + */ +@InterfaceStability.Evolving +public class ListTransactionsOptions extends AbstractOptions { + private Set filteredStates = Collections.emptySet(); + private Set filteredProducerIds = Collections.emptySet(); + + /** + * Filter only the transactions that are in a specific set of states. If no filter + * is specified or if the passed set of states is empty, then transactions in all + * states will be returned. + * + * @param states the set of states to filter by + * @return this object + */ + public ListTransactionsOptions filterStates(Collection states) { + this.filteredStates = new HashSet<>(states); + return this; + } + + /** + * Filter only the transactions from producers in a specific set of producerIds. + * If no filter is specified or if the passed collection of producerIds is empty, + * then the transactions of all producerIds will be returned. + * + * @param producerIdFilters the set of producerIds to filter by + * @return this object + */ + public ListTransactionsOptions filterProducerIds(Collection producerIdFilters) { + this.filteredProducerIds = new HashSet<>(producerIdFilters); + return this; + } + + /** + * Returns the set of states to be filtered or empty if no states have been specified. + * + * @return the current set of filtered states (empty means that no states are filtered and all + * all transactions will be returned) + */ + public Set filteredStates() { + return filteredStates; + } + + /** + * Returns the set of producerIds that are being filtered or empty if none have been specified. 
+ * + * @return the current set of filtered states (empty means that no producerIds are filtered and + * all transactions will be returned) + */ + public Set filteredProducerIds() { + return filteredProducerIds; + } + + @Override + public String toString() { + return "ListTransactionsOptions(" + + "filteredStates=" + filteredStates + + ", filteredProducerIds=" + filteredProducerIds + + ", timeoutMs=" + timeoutMs + + ')'; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + ListTransactionsOptions that = (ListTransactionsOptions) o; + return Objects.equals(filteredStates, that.filteredStates) && + Objects.equals(filteredProducerIds, that.filteredProducerIds); + } + + @Override + public int hashCode() { + return Objects.hash(filteredStates, filteredProducerIds); + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/ListTransactionsResult.java b/clients/src/main/java/org/apache/kafka/clients/admin/ListTransactionsResult.java new file mode 100644 index 0000000..c9670db --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/ListTransactionsResult.java @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients.admin; + +import org.apache.kafka.common.KafkaFuture; +import org.apache.kafka.common.annotation.InterfaceStability; +import org.apache.kafka.common.internals.KafkaFutureImpl; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; + +/** + * The result of the {@link Admin#listTransactions()} call. + *
+ * The API of this class is evolving, see {@link Admin} for details. + */ +@InterfaceStability.Evolving +public class ListTransactionsResult { + private final KafkaFuture>>> future; + + ListTransactionsResult(KafkaFuture>>> future) { + this.future = future; + } + + /** + * Get all transaction listings. If any of the underlying requests fail, then the future + * returned from this method will also fail with the first encountered error. + * + * @return A future containing the collection of transaction listings. The future completes + * when all transaction listings are available and fails after any non-retriable error. + */ + public KafkaFuture> all() { + return allByBrokerId().thenApply(map -> { + List allListings = new ArrayList<>(); + for (Collection listings : map.values()) { + allListings.addAll(listings); + } + return allListings; + }); + } + + /** + * Get a future which returns a map containing the underlying listing future for each broker + * in the cluster. This is useful, for example, if a partial listing of transactions is + * sufficient, or if you want more granular error details. + * + * @return A future containing a map of futures by broker which complete individually when + * their respective transaction listings are available. The top-level future returned + * from this method may fail if the admin client is unable to lookup the available + * brokers in the cluster. + */ + public KafkaFuture>>> byBrokerId() { + KafkaFutureImpl>>> result = new KafkaFutureImpl<>(); + future.whenComplete((brokerFutures, exception) -> { + if (brokerFutures != null) { + Map>> brokerFuturesCopy = + new HashMap<>(brokerFutures.size()); + brokerFuturesCopy.putAll(brokerFutures); + result.complete(brokerFuturesCopy); + } else { + result.completeExceptionally(exception); + } + }); + return result; + } + + /** + * Get all transaction listings in a map which is keyed by the ID of respective broker + * that is currently managing them. If any of the underlying requests fail, then the future + * returned from this method will also fail with the first encountered error. + * + * @return A future containing a map from the broker ID to the transactions hosted by that + * broker respectively. This future completes when all transaction listings are + * available and fails after any non-retriable error. + */ + public KafkaFuture>> allByBrokerId() { + KafkaFutureImpl>> allFuture = new KafkaFutureImpl<>(); + Map> allListingsMap = new HashMap<>(); + + future.whenComplete((map, topLevelException) -> { + if (topLevelException != null) { + allFuture.completeExceptionally(topLevelException); + return; + } + + Set remainingResponses = new HashSet<>(map.keySet()); + map.forEach((brokerId, future) -> { + future.whenComplete((listings, brokerException) -> { + if (brokerException != null) { + allFuture.completeExceptionally(brokerException); + } else if (!allFuture.isDone()) { + allListingsMap.put(brokerId, listings); + remainingResponses.remove(brokerId); + + if (remainingResponses.isEmpty()) { + allFuture.complete(allListingsMap); + } + } + }); + }); + }); + + return allFuture; + } + +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/LogDirDescription.java b/clients/src/main/java/org/apache/kafka/clients/admin/LogDirDescription.java new file mode 100644 index 0000000..1c326ec --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/LogDirDescription.java @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
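
A sketch of the ListTransactionsOptions filters together with ListTransactionsResult.allByBrokerId() from above; the Admin client and the ONGOING state filter are assumptions for illustration.

import org.apache.kafka.clients.admin.Admin;
import org.apache.kafka.clients.admin.ListTransactionsOptions;
import org.apache.kafka.clients.admin.TransactionListing;
import org.apache.kafka.clients.admin.TransactionState;

import java.util.Collection;
import java.util.Collections;
import java.util.Map;

public class ListTransactionsExample {
    static void printOngoingTransactions(Admin admin) throws Exception {
        // Only transactions that are currently ongoing.
        ListTransactionsOptions options = new ListTransactionsOptions()
                .filterStates(Collections.singleton(TransactionState.ONGOING));
        // Group the listings by the broker that is coordinating each transaction.
        Map<Integer, Collection<TransactionListing>> byBroker =
                admin.listTransactions(options).allByBrokerId().get();
        byBroker.forEach((brokerId, listings) -> listings.forEach(listing ->
                System.out.println("broker " + brokerId + ": " + listing.transactionalId()
                        + " (producerId=" + listing.producerId() + ")")));
    }
}
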
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients.admin; + +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.errors.ApiException; + +import java.util.Map; + +import static java.util.Collections.unmodifiableMap; + +/** + * A description of a log directory on a particular broker. + */ +public class LogDirDescription { + private final Map replicaInfos; + private final ApiException error; + + public LogDirDescription(ApiException error, Map replicaInfos) { + this.error = error; + this.replicaInfos = replicaInfos; + } + + /** + * Returns `ApiException` if the log directory is offline or an error occurred, otherwise returns null. + *
+ * <ul>
+ * <li> KafkaStorageException - The log directory is offline.
+ * <li> UnknownServerException - The server experienced an unexpected error when processing the request.
+ * </ul>
+ */ + public ApiException error() { + return error; + } + + /** + * A map from topic partition to replica information for that partition + * in this log directory. + */ + public Map replicaInfos() { + return unmodifiableMap(replicaInfos); + } + + @Override + public String toString() { + return "LogDirDescription(" + + "replicaInfos=" + replicaInfos + + ", error=" + error + + ')'; + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/MemberAssignment.java b/clients/src/main/java/org/apache/kafka/clients/admin/MemberAssignment.java new file mode 100644 index 0000000..3305de0 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/MemberAssignment.java @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients.admin; + +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.utils.Utils; + +import java.util.Collections; +import java.util.HashSet; +import java.util.Objects; +import java.util.Set; + +/** + * A description of the assignments of a specific group member. + */ +public class MemberAssignment { + private final Set topicPartitions; + + /** + * Creates an instance with the specified parameters. + * + * @param topicPartitions List of topic partitions + */ + public MemberAssignment(Set topicPartitions) { + this.topicPartitions = topicPartitions == null ? Collections.emptySet() : + Collections.unmodifiableSet(new HashSet<>(topicPartitions)); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + MemberAssignment that = (MemberAssignment) o; + + return Objects.equals(topicPartitions, that.topicPartitions); + } + + @Override + public int hashCode() { + return topicPartitions != null ? topicPartitions.hashCode() : 0; + } + + /** + * The topic partitions assigned to a group member. + */ + public Set topicPartitions() { + return topicPartitions; + } + + @Override + public String toString() { + return "(topicPartitions=" + Utils.join(topicPartitions, ",") + ")"; + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/MemberDescription.java b/clients/src/main/java/org/apache/kafka/clients/admin/MemberDescription.java new file mode 100644 index 0000000..7bc6b14 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/MemberDescription.java @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
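
LogDirDescription above is surfaced through Admin#describeLogDirs; a sketch assuming an existing Admin client and that the companion DescribeLogDirsResult exposes these objects via allDescriptions(), as in the released client (not shown in this hunk). The broker id is a placeholder.

import org.apache.kafka.clients.admin.Admin;
import org.apache.kafka.clients.admin.LogDirDescription;

import java.util.Collections;
import java.util.Map;

public class LogDirsExample {
    static void printLogDirs(Admin admin) throws Exception {
        // Query broker 0 only; pass more broker ids to inspect the whole cluster.
        Map<Integer, Map<String, LogDirDescription>> byBroker =
                admin.describeLogDirs(Collections.singletonList(0)).allDescriptions().get();
        byBroker.forEach((broker, dirs) -> dirs.forEach((path, description) -> {
            if (description.error() != null) {
                System.out.println(path + " on broker " + broker + " is unusable: " + description.error());
            } else {
                System.out.println(path + " on broker " + broker + " hosts "
                        + description.replicaInfos().size() + " replicas");
            }
        }));
    }
}
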
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients.admin; + +import java.util.Collections; +import java.util.Objects; +import java.util.Optional; + +/** + * A detailed description of a single group instance in the cluster. + */ +public class MemberDescription { + private final String memberId; + private final Optional groupInstanceId; + private final String clientId; + private final String host; + private final MemberAssignment assignment; + + public MemberDescription(String memberId, + Optional groupInstanceId, + String clientId, + String host, + MemberAssignment assignment) { + this.memberId = memberId == null ? "" : memberId; + this.groupInstanceId = groupInstanceId; + this.clientId = clientId == null ? "" : clientId; + this.host = host == null ? "" : host; + this.assignment = assignment == null ? + new MemberAssignment(Collections.emptySet()) : assignment; + } + + public MemberDescription(String memberId, + String clientId, + String host, + MemberAssignment assignment) { + this(memberId, Optional.empty(), clientId, host, assignment); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + MemberDescription that = (MemberDescription) o; + return memberId.equals(that.memberId) && + groupInstanceId.equals(that.groupInstanceId) && + clientId.equals(that.clientId) && + host.equals(that.host) && + assignment.equals(that.assignment); + } + + @Override + public int hashCode() { + return Objects.hash(memberId, groupInstanceId, clientId, host, assignment); + } + + /** + * The consumer id of the group member. + */ + public String consumerId() { + return memberId; + } + + /** + * The instance id of the group member. + */ + public Optional groupInstanceId() { + return groupInstanceId; + } + + /** + * The client id of the group member. + */ + public String clientId() { + return clientId; + } + + /** + * The host where the group member is running. + */ + public String host() { + return host; + } + + /** + * The assignment of the group member. + */ + public MemberAssignment assignment() { + return assignment; + } + + @Override + public String toString() { + return "(memberId=" + memberId + + ", groupInstanceId=" + groupInstanceId.orElse("null") + + ", clientId=" + clientId + + ", host=" + host + + ", assignment=" + assignment + ")"; + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/MemberToRemove.java b/clients/src/main/java/org/apache/kafka/clients/admin/MemberToRemove.java new file mode 100644 index 0000000..4c7b16b --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/MemberToRemove.java @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients.admin; + +import org.apache.kafka.common.message.LeaveGroupRequestData.MemberIdentity; +import org.apache.kafka.common.requests.JoinGroupRequest; + +import java.util.Objects; + +/** + * A struct containing information about the member to be removed. + */ +public class MemberToRemove { + private final String groupInstanceId; + + public MemberToRemove(String groupInstanceId) { + this.groupInstanceId = groupInstanceId; + } + + @Override + public boolean equals(Object o) { + if (o instanceof MemberToRemove) { + MemberToRemove otherMember = (MemberToRemove) o; + return this.groupInstanceId.equals(otherMember.groupInstanceId); + } else { + return false; + } + } + + @Override + public int hashCode() { + return Objects.hash(groupInstanceId); + } + + MemberIdentity toMemberIdentity() { + return new MemberIdentity() + .setGroupInstanceId(groupInstanceId) + .setMemberId(JoinGroupRequest.UNKNOWN_MEMBER_ID); + } + + public String groupInstanceId() { + return groupInstanceId; + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/NewPartitionReassignment.java b/clients/src/main/java/org/apache/kafka/clients/admin/NewPartitionReassignment.java new file mode 100644 index 0000000..f9a7008 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/NewPartitionReassignment.java @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.clients.admin; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; + +/** + * A new partition reassignment, which can be applied via {@link AdminClient#alterPartitionReassignments(Map, AlterPartitionReassignmentsOptions)}. 
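// Illustrative usage sketch: how a NewPartitionReassignment might be applied through the Admin API.
// Assumes an already-configured Admin client and the usual java.util / org.apache.kafka imports;
// the topic, partition and broker ids are placeholders.
static void moveReplicas(Admin admin) throws Exception {
    Map<TopicPartition, Optional<NewPartitionReassignment>> reassignments = new HashMap<>();
    // Move partition 0 of "my-topic" onto brokers 1, 2 and 3; the first id becomes the preferred leader.
    reassignments.put(new TopicPartition("my-topic", 0),
        Optional.of(new NewPartitionReassignment(Arrays.asList(1, 2, 3))));
    admin.alterPartitionReassignments(reassignments).all().get();
    // Mapping a partition to Optional.empty() instead cancels an in-progress reassignment.
}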
+ */ +public class NewPartitionReassignment { + private final List targetReplicas; + + /** + * @throws IllegalArgumentException if no replicas are supplied + */ + public NewPartitionReassignment(List targetReplicas) { + if (targetReplicas == null || targetReplicas.size() == 0) + throw new IllegalArgumentException("Cannot create a new partition reassignment without any replicas"); + this.targetReplicas = Collections.unmodifiableList(new ArrayList<>(targetReplicas)); + } + + public List targetReplicas() { + return targetReplicas; + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/NewPartitions.java b/clients/src/main/java/org/apache/kafka/clients/admin/NewPartitions.java new file mode 100644 index 0000000..06da256 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/NewPartitions.java @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.clients.admin; + +import org.apache.kafka.common.annotation.InterfaceStability; + +import java.util.List; +import java.util.Map; + +/** + * Describes new partitions for a particular topic in a call to {@link Admin#createPartitions(Map)}. + * + * The API of this class is evolving, see {@link Admin} for details. + */ +@InterfaceStability.Evolving +public class NewPartitions { + + private int totalCount; + + private List> newAssignments; + + private NewPartitions(int totalCount, List> newAssignments) { + this.totalCount = totalCount; + this.newAssignments = newAssignments; + } + + /** + * Increase the partition count for a topic to the given {@code totalCount}. + * The assignment of new replicas to brokers will be decided by the broker. + * + * @param totalCount The total number of partitions after the operation succeeds. + */ + public static NewPartitions increaseTo(int totalCount) { + return new NewPartitions(totalCount, null); + } + + /** + *

Increase the partition count for a topic to the given {@code totalCount},
+ * assigning the new partitions according to the given {@code newAssignments}.
+ * The length of the given {@code newAssignments} should equal {@code totalCount - oldCount}, since
+ * the assignment of existing partitions is not changed.
+ * Each inner list of {@code newAssignments} should have a length equal to
+ * the topic's replication factor.
+ * The first broker id in each inner list is the "preferred replica".
+ *
+ * For example, suppose a topic currently has a replication factor of 2 and
+ * has 3 partitions. The number of partitions can be increased to 6 using a
+ * {@code NewPartitions} constructed like this:
+ *
+     * NewPartitions.increaseTo(6, asList(asList(1, 2),
+     *                                    asList(2, 3),
+     *                                    asList(3, 1)))
+ *
+ * In this example, partition 3's preferred leader will be broker 1, partition 4's preferred leader will be
+ * broker 2, and partition 5's preferred leader will be broker 3.
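// Illustrative usage sketch: passing the NewPartitions example above to Admin#createPartitions.
// Assumes an already-configured Admin client and the usual java.util imports; "my-topic" and the
// broker ids are placeholders.
static void growTopic(Admin admin) throws Exception {
    NewPartitions growth = NewPartitions.increaseTo(6,
        Arrays.asList(Arrays.asList(1, 2), Arrays.asList(2, 3), Arrays.asList(3, 1)));
    admin.createPartitions(Collections.singletonMap("my-topic", growth)).all().get();
}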

+ * + * @param totalCount The total number of partitions after the operation succeeds. + * @param newAssignments The replica assignments for the new partitions. + */ + public static NewPartitions increaseTo(int totalCount, List> newAssignments) { + return new NewPartitions(totalCount, newAssignments); + } + + /** + * The total number of partitions after the operation succeeds. + */ + public int totalCount() { + return totalCount; + } + + /** + * The replica assignments for the new partitions, or null if the assignment will be done by the controller. + */ + public List> assignments() { + return newAssignments; + } + + @Override + public String toString() { + return "(totalCount=" + totalCount() + ", newAssignments=" + assignments() + ")"; + } + +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/NewTopic.java b/clients/src/main/java/org/apache/kafka/clients/admin/NewTopic.java new file mode 100644 index 0000000..2f335d0 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/NewTopic.java @@ -0,0 +1,176 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.clients.admin; + +import java.util.Optional; +import org.apache.kafka.common.message.CreateTopicsRequestData.CreatableReplicaAssignment; +import org.apache.kafka.common.message.CreateTopicsRequestData.CreatableTopic; +import org.apache.kafka.common.message.CreateTopicsRequestData.CreateableTopicConfig; +import org.apache.kafka.common.requests.CreateTopicsRequest; + +import java.util.Collection; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Map.Entry; + +/** + * A new topic to be created via {@link Admin#createTopics(Collection)}. + */ +public class NewTopic { + + private final String name; + private final Optional numPartitions; + private final Optional replicationFactor; + private final Map> replicasAssignments; + private Map configs = null; + + /** + * A new topic with the specified replication factor and number of partitions. + */ + public NewTopic(String name, int numPartitions, short replicationFactor) { + this(name, Optional.of(numPartitions), Optional.of(replicationFactor)); + } + + /** + * A new topic that optionally defaults {@code numPartitions} and {@code replicationFactor} to + * the broker configurations for {@code num.partitions} and {@code default.replication.factor} + * respectively. + */ + public NewTopic(String name, Optional numPartitions, Optional replicationFactor) { + this.name = name; + this.numPartitions = numPartitions; + this.replicationFactor = replicationFactor; + this.replicasAssignments = null; + } + + /** + * A new topic with the specified replica assignment configuration. + * + * @param name the topic name. 
+ * @param replicasAssignments a map from partition id to replica ids (i.e. broker ids). Although not enforced, it is + * generally a good idea for all partitions to have the same number of replicas. + */ + public NewTopic(String name, Map> replicasAssignments) { + this.name = name; + this.numPartitions = Optional.empty(); + this.replicationFactor = Optional.empty(); + this.replicasAssignments = Collections.unmodifiableMap(replicasAssignments); + } + + /** + * The name of the topic to be created. + */ + public String name() { + return name; + } + + /** + * The number of partitions for the new topic or -1 if a replica assignment has been specified. + */ + public int numPartitions() { + return numPartitions.orElse(CreateTopicsRequest.NO_NUM_PARTITIONS); + } + + /** + * The replication factor for the new topic or -1 if a replica assignment has been specified. + */ + public short replicationFactor() { + return replicationFactor.orElse(CreateTopicsRequest.NO_REPLICATION_FACTOR); + } + + /** + * A map from partition id to replica ids (i.e. broker ids) or null if the number of partitions and replication + * factor have been specified instead. + */ + public Map> replicasAssignments() { + return replicasAssignments; + } + + /** + * Set the configuration to use on the new topic. + * + * @param configs The configuration map. + * @return This NewTopic object. + */ + public NewTopic configs(Map configs) { + this.configs = configs; + return this; + } + + /** + * The configuration for the new topic or null if no configs ever specified. + */ + public Map configs() { + return configs; + } + + CreatableTopic convertToCreatableTopic() { + CreatableTopic creatableTopic = new CreatableTopic(). + setName(name). + setNumPartitions(numPartitions.orElse(CreateTopicsRequest.NO_NUM_PARTITIONS)). + setReplicationFactor(replicationFactor.orElse(CreateTopicsRequest.NO_REPLICATION_FACTOR)); + if (replicasAssignments != null) { + for (Entry> entry : replicasAssignments.entrySet()) { + creatableTopic.assignments().add( + new CreatableReplicaAssignment(). + setPartitionIndex(entry.getKey()). + setBrokerIds(entry.getValue())); + } + } + if (configs != null) { + for (Entry entry : configs.entrySet()) { + creatableTopic.configs().add( + new CreateableTopicConfig(). + setName(entry.getKey()). + setValue(entry.getValue())); + } + } + return creatableTopic; + } + + @Override + public String toString() { + StringBuilder bld = new StringBuilder(); + bld.append("(name=").append(name). + append(", numPartitions=").append(numPartitions.map(String::valueOf).orElse("default")). + append(", replicationFactor=").append(replicationFactor.map(String::valueOf).orElse("default")). + append(", replicasAssignments=").append(replicasAssignments). + append(", configs=").append(configs). 
+ append(")"); + return bld.toString(); + } + + @Override + public boolean equals(final Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + final NewTopic that = (NewTopic) o; + return Objects.equals(name, that.name) && + Objects.equals(numPartitions, that.numPartitions) && + Objects.equals(replicationFactor, that.replicationFactor) && + Objects.equals(replicasAssignments, that.replicasAssignments) && + Objects.equals(configs, that.configs); + } + + @Override + public int hashCode() { + return Objects.hash(name, numPartitions, replicationFactor, replicasAssignments, configs); + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/OffsetSpec.java b/clients/src/main/java/org/apache/kafka/clients/admin/OffsetSpec.java new file mode 100644 index 0000000..dcf9045 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/OffsetSpec.java @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients.admin; + +import java.util.Map; + +/** + * This class allows to specify the desired offsets when using {@link KafkaAdminClient#listOffsets(Map, ListOffsetsOptions)} + */ +public class OffsetSpec { + + public static class EarliestSpec extends OffsetSpec { } + public static class LatestSpec extends OffsetSpec { } + public static class MaxTimestampSpec extends OffsetSpec { } + public static class TimestampSpec extends OffsetSpec { + private final long timestamp; + + TimestampSpec(long timestamp) { + this.timestamp = timestamp; + } + + long timestamp() { + return timestamp; + } + } + + /** + * Used to retrieve the latest offset of a partition + */ + public static OffsetSpec latest() { + return new LatestSpec(); + } + + /** + * Used to retrieve the earliest offset of a partition + */ + public static OffsetSpec earliest() { + return new EarliestSpec(); + } + + /** + * Used to retrieve the earliest offset whose timestamp is greater than + * or equal to the given timestamp in the corresponding partition + * @param timestamp in milliseconds + */ + public static OffsetSpec forTimestamp(long timestamp) { + return new TimestampSpec(timestamp); + } + + /** + * Used to retrieve the offset with the largest timestamp of a partition + * as message timestamps can be specified client side this may not match + * the log end offset returned by LatestSpec + */ + public static OffsetSpec maxTimestamp() { + return new MaxTimestampSpec(); + } + +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/PartitionReassignment.java b/clients/src/main/java/org/apache/kafka/clients/admin/PartitionReassignment.java new file mode 100644 index 0000000..4a9d151 --- /dev/null +++ 
b/clients/src/main/java/org/apache/kafka/clients/admin/PartitionReassignment.java @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.clients.admin; + +import java.util.Collections; +import java.util.List; + +/** + * A partition reassignment, which has been listed via {@link AdminClient#listPartitionReassignments()}. + */ +public class PartitionReassignment { + + private final List replicas; + private final List addingReplicas; + private final List removingReplicas; + + public PartitionReassignment(List replicas, List addingReplicas, List removingReplicas) { + this.replicas = Collections.unmodifiableList(replicas); + this.addingReplicas = Collections.unmodifiableList(addingReplicas); + this.removingReplicas = Collections.unmodifiableList(removingReplicas); + } + + /** + * The brokers which this partition currently resides on. + */ + public List replicas() { + return replicas; + } + + /** + * The brokers that we are adding this partition to as part of a reassignment. + * A subset of replicas. + */ + public List addingReplicas() { + return addingReplicas; + } + + /** + * The brokers that we are removing this partition from as part of a reassignment. + * A subset of replicas. + */ + public List removingReplicas() { + return removingReplicas; + } + + @Override + public String toString() { + return "PartitionReassignment(" + + "replicas=" + replicas + + ", addingReplicas=" + addingReplicas + + ", removingReplicas=" + removingReplicas + + ')'; + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/ProducerState.java b/clients/src/main/java/org/apache/kafka/clients/admin/ProducerState.java new file mode 100644 index 0000000..243edde --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/ProducerState.java @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.clients.admin; + +import java.util.Objects; +import java.util.OptionalInt; +import java.util.OptionalLong; + +public class ProducerState { + private final long producerId; + private final int producerEpoch; + private final int lastSequence; + private final long lastTimestamp; + private final OptionalInt coordinatorEpoch; + private final OptionalLong currentTransactionStartOffset; + + public ProducerState( + long producerId, + int producerEpoch, + int lastSequence, + long lastTimestamp, + OptionalInt coordinatorEpoch, + OptionalLong currentTransactionStartOffset + ) { + this.producerId = producerId; + this.producerEpoch = producerEpoch; + this.lastSequence = lastSequence; + this.lastTimestamp = lastTimestamp; + this.coordinatorEpoch = coordinatorEpoch; + this.currentTransactionStartOffset = currentTransactionStartOffset; + } + + public long producerId() { + return producerId; + } + + public int producerEpoch() { + return producerEpoch; + } + + public int lastSequence() { + return lastSequence; + } + + public long lastTimestamp() { + return lastTimestamp; + } + + public OptionalLong currentTransactionStartOffset() { + return currentTransactionStartOffset; + } + + public OptionalInt coordinatorEpoch() { + return coordinatorEpoch; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + ProducerState that = (ProducerState) o; + return producerId == that.producerId && + producerEpoch == that.producerEpoch && + lastSequence == that.lastSequence && + lastTimestamp == that.lastTimestamp && + Objects.equals(coordinatorEpoch, that.coordinatorEpoch) && + Objects.equals(currentTransactionStartOffset, that.currentTransactionStartOffset); + } + + @Override + public int hashCode() { + return Objects.hash(producerId, producerEpoch, lastSequence, lastTimestamp, + coordinatorEpoch, currentTransactionStartOffset); + } + + @Override + public String toString() { + return "ProducerState(" + + "producerId=" + producerId + + ", producerEpoch=" + producerEpoch + + ", lastSequence=" + lastSequence + + ", lastTimestamp=" + lastTimestamp + + ", coordinatorEpoch=" + coordinatorEpoch + + ", currentTransactionStartOffset=" + currentTransactionStartOffset + + ')'; + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/RecordsToDelete.java b/clients/src/main/java/org/apache/kafka/clients/admin/RecordsToDelete.java new file mode 100644 index 0000000..af835c8 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/RecordsToDelete.java @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.clients.admin; + +import org.apache.kafka.common.annotation.InterfaceStability; + +import java.util.Map; + +/** + * Describe records to delete in a call to {@link Admin#deleteRecords(Map)} + * + * The API of this class is evolving, see {@link Admin} for details. + */ +@InterfaceStability.Evolving +public class RecordsToDelete { + + private long offset; + + private RecordsToDelete(long offset) { + this.offset = offset; + } + + /** + * Delete all the records before the given {@code offset} + * + * @param offset the offset before which all records will be deleted + */ + public static RecordsToDelete beforeOffset(long offset) { + return new RecordsToDelete(offset); + } + + /** + * The offset before which all records will be deleted + */ + public long beforeOffset() { + return offset; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + RecordsToDelete that = (RecordsToDelete) o; + + return this.offset == that.offset; + } + + @Override + public int hashCode() { + return (int) offset; + } + + @Override + public String toString() { + return "(beforeOffset = " + offset + ")"; + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/RemoveMembersFromConsumerGroupOptions.java b/clients/src/main/java/org/apache/kafka/clients/admin/RemoveMembersFromConsumerGroupOptions.java new file mode 100644 index 0000000..322beec --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/RemoveMembersFromConsumerGroupOptions.java @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients.admin; + +import org.apache.kafka.common.annotation.InterfaceStability; + +import java.util.Collection; +import java.util.Collections; +import java.util.HashSet; +import java.util.Set; + +/** + * Options for {@link AdminClient#removeMembersFromConsumerGroup(String, RemoveMembersFromConsumerGroupOptions)}. + * It carries the members to be removed from the consumer group. + * + * The API of this class is evolving, see {@link AdminClient} for details. 
+ */ +@InterfaceStability.Evolving +public class RemoveMembersFromConsumerGroupOptions extends AbstractOptions { + + private Set members; + + public RemoveMembersFromConsumerGroupOptions(Collection members) { + if (members.isEmpty()) { + throw new IllegalArgumentException("Invalid empty members has been provided"); + } + this.members = new HashSet<>(members); + } + + public RemoveMembersFromConsumerGroupOptions() { + this.members = Collections.emptySet(); + } + + public Set members() { + return members; + } + + public boolean removeAll() { + return members.isEmpty(); + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/RemoveMembersFromConsumerGroupResult.java b/clients/src/main/java/org/apache/kafka/clients/admin/RemoveMembersFromConsumerGroupResult.java new file mode 100644 index 0000000..3845e2f --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/RemoveMembersFromConsumerGroupResult.java @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients.admin; + +import org.apache.kafka.common.KafkaException; +import org.apache.kafka.common.KafkaFuture; +import org.apache.kafka.common.internals.KafkaFutureImpl; +import org.apache.kafka.common.message.LeaveGroupRequestData.MemberIdentity; +import org.apache.kafka.common.protocol.Errors; + +import java.util.Map; +import java.util.Set; + +/** + * The result of the {@link Admin#removeMembersFromConsumerGroup(String, RemoveMembersFromConsumerGroupOptions)} call. + * + * The API of this class is evolving, see {@link Admin} for details. + */ +public class RemoveMembersFromConsumerGroupResult { + + private final KafkaFuture> future; + private final Set memberInfos; + + RemoveMembersFromConsumerGroupResult(KafkaFuture> future, + Set memberInfos) { + this.future = future; + this.memberInfos = memberInfos; + } + + /** + * Returns a future which indicates whether the request was 100% success, i.e. no + * either top level or member level error. + * If not, the first member error shall be returned. 
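// Illustrative usage sketch: removing one static member and checking the result, as described above.
// Assumes an already-configured Admin client; "my-group" and "instance-1" are placeholders.
static void evictStaticMember(Admin admin) throws Exception {
    MemberToRemove member = new MemberToRemove("instance-1");   // group.instance.id of the member
    RemoveMembersFromConsumerGroupResult result = admin.removeMembersFromConsumerGroup(
        "my-group", new RemoveMembersFromConsumerGroupOptions(Collections.singleton(member)));
    result.all().get();                 // completes exceptionally on the first member-level error
    result.memberResult(member).get();  // or inspect this member's outcome individually
}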
+ */ + public KafkaFuture all() { + final KafkaFutureImpl result = new KafkaFutureImpl<>(); + this.future.whenComplete((memberErrors, throwable) -> { + if (throwable != null) { + result.completeExceptionally(throwable); + } else { + if (removeAll()) { + for (Map.Entry entry: memberErrors.entrySet()) { + Exception exception = entry.getValue().exception(); + if (exception != null) { + Throwable ex = new KafkaException("Encounter exception when trying to remove: " + + entry.getKey(), exception); + result.completeExceptionally(ex); + return; + } + } + } else { + for (MemberToRemove memberToRemove : memberInfos) { + if (maybeCompleteExceptionally(memberErrors, memberToRemove.toMemberIdentity(), result)) { + return; + } + } + } + result.complete(null); + } + }); + return result; + } + + /** + * Returns the selected member future. + */ + public KafkaFuture memberResult(MemberToRemove member) { + if (removeAll()) { + throw new IllegalArgumentException("The method: memberResult is not applicable in 'removeAll' mode"); + } + if (!memberInfos.contains(member)) { + throw new IllegalArgumentException("Member " + member + " was not included in the original request"); + } + + final KafkaFutureImpl result = new KafkaFutureImpl<>(); + this.future.whenComplete((memberErrors, throwable) -> { + if (throwable != null) { + result.completeExceptionally(throwable); + } else if (!maybeCompleteExceptionally(memberErrors, member.toMemberIdentity(), result)) { + result.complete(null); + } + }); + return result; + } + + private boolean maybeCompleteExceptionally(Map memberErrors, + MemberIdentity member, + KafkaFutureImpl result) { + Throwable exception = KafkaAdminClient.getSubLevelError(memberErrors, member, + "Member \"" + member + "\" was not included in the removal response"); + if (exception != null) { + result.completeExceptionally(exception); + return true; + } else { + return false; + } + } + + private boolean removeAll() { + return memberInfos.isEmpty(); + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/RenewDelegationTokenOptions.java b/clients/src/main/java/org/apache/kafka/clients/admin/RenewDelegationTokenOptions.java new file mode 100644 index 0000000..5c2b0d1 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/RenewDelegationTokenOptions.java @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.clients.admin; + +import org.apache.kafka.common.annotation.InterfaceStability; + +/** + * Options for {@link Admin#renewDelegationToken(byte[], RenewDelegationTokenOptions)}. + * + * The API of this class is evolving, see {@link Admin} for details. 
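// Illustrative usage sketch: renewing an existing delegation token. Assumes an already-configured
// Admin client and a DelegationToken previously obtained from Admin#createDelegationToken.
static void renewToken(Admin admin, DelegationToken token) throws Exception {
    RenewDelegationTokenOptions options = new RenewDelegationTokenOptions()
        .renewTimePeriodMs(24 * 60 * 60 * 1000L);              // ask for a 24-hour renew period
    long expiryTimestampMs = admin.renewDelegationToken(token.hmac(), options)
        .expiryTimestamp().get();                              // new expiry time in epoch milliseconds
    System.out.println("Token now expires at " + expiryTimestampMs);
}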
+ */ +@InterfaceStability.Evolving +public class RenewDelegationTokenOptions extends AbstractOptions { + private long renewTimePeriodMs = -1; + + public RenewDelegationTokenOptions renewTimePeriodMs(long renewTimePeriodMs) { + this.renewTimePeriodMs = renewTimePeriodMs; + return this; + } + + public long renewTimePeriodMs() { + return renewTimePeriodMs; + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/RenewDelegationTokenResult.java b/clients/src/main/java/org/apache/kafka/clients/admin/RenewDelegationTokenResult.java new file mode 100644 index 0000000..74725d4 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/RenewDelegationTokenResult.java @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.clients.admin; + +import org.apache.kafka.common.KafkaFuture; +import org.apache.kafka.common.annotation.InterfaceStability; + +/** + * The result of the {@link KafkaAdminClient#expireDelegationToken(byte[], ExpireDelegationTokenOptions)} call. + * + * The API of this class is evolving, see {@link Admin} for details. + */ +@InterfaceStability.Evolving +public class RenewDelegationTokenResult { + private final KafkaFuture expiryTimestamp; + + RenewDelegationTokenResult(KafkaFuture expiryTimestamp) { + this.expiryTimestamp = expiryTimestamp; + } + + /** + * Returns a future which yields expiry timestamp + */ + public KafkaFuture expiryTimestamp() { + return expiryTimestamp; + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/ReplicaInfo.java b/clients/src/main/java/org/apache/kafka/clients/admin/ReplicaInfo.java new file mode 100644 index 0000000..b77375d --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/ReplicaInfo.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients.admin; + +/** + * A description of a replica on a particular broker. 
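// Illustrative usage sketch: ReplicaInfo (and the LogDirDescription shown earlier) are reached through
// Admin#describeLogDirs. Assumes an already-configured Admin client; broker id 0 is a placeholder.
static void printLogDirSizes(Admin admin) throws Exception {
    Map<Integer, Map<String, LogDirDescription>> descriptions =
        admin.describeLogDirs(Collections.singleton(0)).allDescriptions().get();
    descriptions.forEach((brokerId, dirsByPath) -> dirsByPath.forEach((path, dir) ->
        dir.replicaInfos().forEach((topicPartition, replica) ->
            System.out.println(brokerId + " " + path + " " + topicPartition + " size=" + replica.size()))));
}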
+ */ +public class ReplicaInfo { + + private final long size; + private final long offsetLag; + private final boolean isFuture; + + public ReplicaInfo(long size, long offsetLag, boolean isFuture) { + this.size = size; + this.offsetLag = offsetLag; + this.isFuture = isFuture; + } + + /** + * The total size of the log segments in this replica in bytes. + */ + public long size() { + return size; + } + + /** + * The lag of the log's LEO with respect to the partition's + * high watermark (if it is the current log for the partition) + * or the current replica's LEO (if it is the {@linkplain #isFuture() future log} + * for the partition). + */ + public long offsetLag() { + return offsetLag; + } + + /** + * Whether this replica has been created by a AlterReplicaLogDirsRequest + * but not yet replaced the current replica on the broker. + * + * @return true if this log is created by AlterReplicaLogDirsRequest and will replace the current log + * of the replica at some time in the future. + */ + public boolean isFuture() { + return isFuture; + } + + @Override + public String toString() { + return "ReplicaInfo(" + + "size=" + size + + ", offsetLag=" + offsetLag + + ", isFuture=" + isFuture + + ')'; + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/ScramCredentialInfo.java b/clients/src/main/java/org/apache/kafka/clients/admin/ScramCredentialInfo.java new file mode 100644 index 0000000..e8403b6 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/ScramCredentialInfo.java @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.clients.admin; + +import java.util.Objects; + +/** + * Mechanism and iterations for a SASL/SCRAM credential associated with a user. 
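// Illustrative usage sketch: a ScramCredentialInfo is combined with a user name and password in a
// UserScramCredentialUpsertion and applied via Admin#alterUserScramCredentials. Assumes an
// already-configured Admin client; the user name and password are placeholders.
static void setScramPassword(Admin admin) throws Exception {
    ScramCredentialInfo info = new ScramCredentialInfo(ScramMechanism.SCRAM_SHA_256, 8192);
    admin.alterUserScramCredentials(Collections.singletonList(
        new UserScramCredentialUpsertion("alice", info, "alice-secret"))).all().get();
}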
+ * + * @see KIP-554: Add Broker-side SCRAM Config API + */ +public class ScramCredentialInfo { + private final ScramMechanism mechanism; + private final int iterations; + + /** + * + * @param mechanism the required mechanism + * @param iterations the number of iterations used when creating the credential + */ + public ScramCredentialInfo(ScramMechanism mechanism, int iterations) { + this.mechanism = Objects.requireNonNull(mechanism); + this.iterations = iterations; + } + + /** + * + * @return the mechanism + */ + public ScramMechanism mechanism() { + return mechanism; + } + + /** + * + * @return the number of iterations used when creating the credential + */ + public int iterations() { + return iterations; + } + + @Override + public String toString() { + return "ScramCredentialInfo{" + + "mechanism=" + mechanism + + ", iterations=" + iterations + + '}'; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + ScramCredentialInfo that = (ScramCredentialInfo) o; + return iterations == that.iterations && + mechanism == that.mechanism; + } + + @Override + public int hashCode() { + return Objects.hash(mechanism, iterations); + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/ScramMechanism.java b/clients/src/main/java/org/apache/kafka/clients/admin/ScramMechanism.java new file mode 100644 index 0000000..95ad18c --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/ScramMechanism.java @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.clients.admin; + +import java.util.Arrays; + +/** + * Representation of a SASL/SCRAM Mechanism. 
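// Illustrative sketch of the name mapping implemented below: the enum constant name with '_' replaced
// by '-' is the SASL mechanism name, and unrecognised names map to UNKNOWN.
static void demonstrateNameMapping() {
    ScramMechanism mechanism = ScramMechanism.fromMechanismName("SCRAM-SHA-256");
    assert mechanism == ScramMechanism.SCRAM_SHA_256;
    assert "SCRAM-SHA-256".equals(mechanism.mechanismName());
    assert ScramMechanism.fromMechanismName("PLAIN") == ScramMechanism.UNKNOWN;
}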
+ * + * @see KIP-554: Add Broker-side SCRAM Config API + */ +public enum ScramMechanism { + UNKNOWN((byte) 0), + SCRAM_SHA_256((byte) 1), + SCRAM_SHA_512((byte) 2); + + private static final ScramMechanism[] VALUES = values(); + + /** + * + * @param type the type indicator + * @return the instance corresponding to the given type indicator, otherwise {@link #UNKNOWN} + */ + public static ScramMechanism fromType(byte type) { + for (ScramMechanism scramMechanism : VALUES) { + if (scramMechanism.type == type) { + return scramMechanism; + } + } + return UNKNOWN; + } + + /** + * + * @param mechanismName the SASL SCRAM mechanism name + * @return the corresponding SASL SCRAM mechanism enum, otherwise {@link #UNKNOWN} + * @see + * Salted Challenge Response Authentication Mechanism (SCRAM) SASL and GSS-API Mechanisms, Section 4 + */ + public static ScramMechanism fromMechanismName(String mechanismName) { + return Arrays.stream(VALUES) + .filter(mechanism -> mechanism.mechanismName.equals(mechanismName)) + .findFirst() + .orElse(UNKNOWN); + } + + /** + * + * @return the corresponding SASL SCRAM mechanism name + * @see + * Salted Challenge Response Authentication Mechanism (SCRAM) SASL and GSS-API Mechanisms, Section 4 + */ + public String mechanismName() { + return this.mechanismName; + } + + /** + * + * @return the type indicator for this SASL SCRAM mechanism + */ + public byte type() { + return this.type; + } + + private final byte type; + private final String mechanismName; + + private ScramMechanism(byte type) { + this.type = type; + this.mechanismName = toString().replace('_', '-'); + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/SupportedVersionRange.java b/clients/src/main/java/org/apache/kafka/clients/admin/SupportedVersionRange.java new file mode 100644 index 0000000..d71da31 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/SupportedVersionRange.java @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients.admin; + +import java.util.Objects; + +/** + * Represents a range of versions that a particular broker supports for some feature. + */ +public class SupportedVersionRange { + private final short minVersion; + + private final short maxVersion; + + /** + * Raises an exception unless the following conditions are met: + * 1 <= minVersion <= maxVersion. + * + * @param minVersion The minimum version value. + * @param maxVersion The maximum version value. + * + * @throws IllegalArgumentException Raised when the condition described above is not met. 
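// Illustrative usage sketch: SupportedVersionRange instances are normally obtained via
// Admin#describeFeatures rather than constructed directly (the constructor below is package-private).
// Assumes an already-configured Admin client.
static void printSupportedFeatures(Admin admin) throws Exception {
    admin.describeFeatures().featureMetadata().get().supportedFeatures()
        .forEach((feature, range) ->
            System.out.println(feature + ": versions " + range.minVersion() + " to " + range.maxVersion()));
}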
+ */ + SupportedVersionRange(final short minVersion, final short maxVersion) { + if (minVersion < 1 || maxVersion < 1 || maxVersion < minVersion) { + throw new IllegalArgumentException( + String.format( + "Expected 1 <= minVersion <= maxVersion but received minVersion:%d, maxVersion:%d.", + minVersion, + maxVersion)); + } + this.minVersion = minVersion; + this.maxVersion = maxVersion; + } + + public short minVersion() { + return minVersion; + } + + public short maxVersion() { + return maxVersion; + } + + @Override + public boolean equals(Object other) { + if (this == other) { + return true; + } + + if (other == null || getClass() != other.getClass()) { + return false; + } + + final SupportedVersionRange that = (SupportedVersionRange) other; + return this.minVersion == that.minVersion && this.maxVersion == that.maxVersion; + } + + @Override + public int hashCode() { + return Objects.hash(minVersion, maxVersion); + } + + @Override + public String toString() { + return String.format("SupportedVersionRange[min_version:%d, max_version:%d]", minVersion, maxVersion); + } +} + diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/TopicDescription.java b/clients/src/main/java/org/apache/kafka/clients/admin/TopicDescription.java new file mode 100644 index 0000000..e8700d4 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/TopicDescription.java @@ -0,0 +1,140 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.clients.admin; + +import org.apache.kafka.common.TopicPartitionInfo; +import org.apache.kafka.common.Uuid; +import org.apache.kafka.common.acl.AclOperation; +import org.apache.kafka.common.utils.Utils; + +import java.util.Collections; +import java.util.List; +import java.util.Objects; +import java.util.Set; + +/** + * A detailed description of a single topic in the cluster. + */ +public class TopicDescription { + private final String name; + private final boolean internal; + private final List partitions; + private final Set authorizedOperations; + private final Uuid topicId; + + @Override + public boolean equals(final Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + final TopicDescription that = (TopicDescription) o; + return internal == that.internal && + Objects.equals(name, that.name) && + Objects.equals(partitions, that.partitions) && + Objects.equals(authorizedOperations, that.authorizedOperations); + } + + @Override + public int hashCode() { + return Objects.hash(name, internal, partitions, authorizedOperations); + } + + /** + * Create an instance with the specified parameters. 
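// Illustrative usage sketch: TopicDescription instances are normally obtained from Admin#describeTopics
// rather than constructed by hand. Assumes an already-configured Admin client; "my-topic" is a placeholder.
static void printTopic(Admin admin) throws Exception {
    TopicDescription topic = admin.describeTopics(Collections.singleton("my-topic"))
        .all().get().get("my-topic");
    System.out.println(topic.name() + " internal=" + topic.isInternal()
        + " partitions=" + topic.partitions().size());
}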
+ * + * @param name The topic name + * @param internal Whether the topic is internal to Kafka + * @param partitions A list of partitions where the index represents the partition id and the element contains + * leadership and replica information for that partition. + */ + public TopicDescription(String name, boolean internal, List partitions) { + this(name, internal, partitions, Collections.emptySet()); + } + + /** + * Create an instance with the specified parameters. + * + * @param name The topic name + * @param internal Whether the topic is internal to Kafka + * @param partitions A list of partitions where the index represents the partition id and the element contains + * leadership and replica information for that partition. + * @param authorizedOperations authorized operations for this topic, or empty set if this is not known. + */ + public TopicDescription(String name, boolean internal, List partitions, + Set authorizedOperations) { + this(name, internal, partitions, authorizedOperations, Uuid.ZERO_UUID); + } + + /** + * Create an instance with the specified parameters. + * + * @param name The topic name + * @param internal Whether the topic is internal to Kafka + * @param partitions A list of partitions where the index represents the partition id and the element contains + * leadership and replica information for that partition. + * @param authorizedOperations authorized operations for this topic, or empty set if this is not known. + * @param topicId the topic id + */ + public TopicDescription(String name, boolean internal, List partitions, + Set authorizedOperations, Uuid topicId) { + this.name = name; + this.internal = internal; + this.partitions = partitions; + this.authorizedOperations = authorizedOperations; + this.topicId = topicId; + } + + /** + * The name of the topic. + */ + public String name() { + return name; + } + + /** + * Whether the topic is internal to Kafka. An example of an internal topic is the offsets and group management topic: + * __consumer_offsets. + */ + public boolean isInternal() { + return internal; + } + + public Uuid topicId() { + return topicId; + } + + /** + * A list of partitions where the index represents the partition id and the element contains leadership and replica + * information for that partition. + */ + public List partitions() { + return partitions; + } + + /** + * authorized operations for this topic, or null if this is not known. + */ + public Set authorizedOperations() { + return authorizedOperations; + } + + @Override + public String toString() { + return "(name=" + name + ", internal=" + internal + ", partitions=" + + Utils.join(partitions, ",") + ", authorizedOperations=" + authorizedOperations + ")"; + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/TopicListing.java b/clients/src/main/java/org/apache/kafka/clients/admin/TopicListing.java new file mode 100644 index 0000000..42ceeff --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/TopicListing.java @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.clients.admin; + +import org.apache.kafka.common.Uuid; + +/** + * A listing of a topic in the cluster. + */ +public class TopicListing { + private final String name; + private final Uuid topicId; + private final boolean internal; + + /** + * Create an instance with the specified parameters. + * + * @param name The topic name + * @param internal Whether the topic is internal to Kafka + * @deprecated Since 3.0 use {@link #TopicListing(String, Uuid, boolean)} instead + */ + @Deprecated + public TopicListing(String name, boolean internal) { + this.name = name; + this.internal = internal; + this.topicId = Uuid.ZERO_UUID; + } + + /** + * Create an instance with the specified parameters. + * + * @param name The topic name + * @param topicId The topic id. + * @param internal Whether the topic is internal to Kafka + */ + public TopicListing(String name, Uuid topicId, boolean internal) { + this.topicId = topicId; + this.name = name; + this.internal = internal; + } + + /** + * The id of the topic. + */ + public Uuid topicId() { + return topicId; + } + + /** + * The name of the topic. + */ + public String name() { + return name; + } + + /** + * Whether the topic is internal to Kafka. An example of an internal topic is the offsets and group management topic: + * __consumer_offsets. + */ + public boolean isInternal() { + return internal; + } + + @Override + public String toString() { + return "(name=" + name + ", topicId=" + topicId + ", internal=" + internal + ")"; + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/TransactionDescription.java b/clients/src/main/java/org/apache/kafka/clients/admin/TransactionDescription.java new file mode 100644 index 0000000..5a16919 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/TransactionDescription.java @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.clients.admin; + +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.annotation.InterfaceStability; + +import java.util.Objects; +import java.util.OptionalLong; +import java.util.Set; + +@InterfaceStability.Evolving +public class TransactionDescription { + private final int coordinatorId; + private final TransactionState state; + private final long producerId; + private final int producerEpoch; + private final long transactionTimeoutMs; + private final OptionalLong transactionStartTimeMs; + private final Set topicPartitions; + + public TransactionDescription( + int coordinatorId, + TransactionState state, + long producerId, + int producerEpoch, + long transactionTimeoutMs, + OptionalLong transactionStartTimeMs, + Set topicPartitions + ) { + this.coordinatorId = coordinatorId; + this.state = state; + this.producerId = producerId; + this.producerEpoch = producerEpoch; + this.transactionTimeoutMs = transactionTimeoutMs; + this.transactionStartTimeMs = transactionStartTimeMs; + this.topicPartitions = topicPartitions; + } + + public int coordinatorId() { + return coordinatorId; + } + + public TransactionState state() { + return state; + } + + public long producerId() { + return producerId; + } + + public int producerEpoch() { + return producerEpoch; + } + + public long transactionTimeoutMs() { + return transactionTimeoutMs; + } + + public OptionalLong transactionStartTimeMs() { + return transactionStartTimeMs; + } + + public Set topicPartitions() { + return topicPartitions; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + TransactionDescription that = (TransactionDescription) o; + return coordinatorId == that.coordinatorId && + producerId == that.producerId && + producerEpoch == that.producerEpoch && + transactionTimeoutMs == that.transactionTimeoutMs && + state == that.state && + Objects.equals(transactionStartTimeMs, that.transactionStartTimeMs) && + Objects.equals(topicPartitions, that.topicPartitions); + } + + @Override + public int hashCode() { + return Objects.hash(coordinatorId, state, producerId, producerEpoch, transactionTimeoutMs, transactionStartTimeMs, topicPartitions); + } + + @Override + public String toString() { + return "TransactionDescription(" + + "coordinatorId=" + coordinatorId + + ", state=" + state + + ", producerId=" + producerId + + ", producerEpoch=" + producerEpoch + + ", transactionTimeoutMs=" + transactionTimeoutMs + + ", transactionStartTimeMs=" + transactionStartTimeMs + + ", topicPartitions=" + topicPartitions + + ')'; + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/TransactionListing.java b/clients/src/main/java/org/apache/kafka/clients/admin/TransactionListing.java new file mode 100644 index 0000000..e8e7eb6 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/TransactionListing.java @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
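[Usage illustration, not part of the patch] A hedged sketch of how a TransactionDescription is typically retrieved via Admin#describeTransactions (the KIP-664 API this class belongs to); the transactional id is supplied by the caller and error handling is elided.

import org.apache.kafka.clients.admin.Admin;
import org.apache.kafka.clients.admin.TransactionDescription;

import java.util.Collections;
import java.util.concurrent.ExecutionException;

public class DescribeTransactionExample {
    // 'admin' is assumed to be an already-configured Admin client
    static void printTransaction(Admin admin, String transactionalId)
            throws ExecutionException, InterruptedException {
        TransactionDescription description =
            admin.describeTransactions(Collections.singleton(transactionalId))
                .description(transactionalId)
                .get();
        System.out.printf("coordinator=%d state=%s producerId=%d epoch=%d partitions=%s%n",
            description.coordinatorId(), description.state(), description.producerId(),
            description.producerEpoch(), description.topicPartitions());
    }
}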
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients.admin; + +import org.apache.kafka.common.annotation.InterfaceStability; + +import java.util.Objects; + +@InterfaceStability.Evolving +public class TransactionListing { + private final String transactionalId; + private final long producerId; + private final TransactionState transactionState; + + public TransactionListing( + String transactionalId, + long producerId, + TransactionState transactionState + ) { + this.transactionalId = transactionalId; + this.producerId = producerId; + this.transactionState = transactionState; + } + + public String transactionalId() { + return transactionalId; + } + + public long producerId() { + return producerId; + } + + public TransactionState state() { + return transactionState; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + TransactionListing that = (TransactionListing) o; + return producerId == that.producerId && + Objects.equals(transactionalId, that.transactionalId) && + transactionState == that.transactionState; + } + + @Override + public int hashCode() { + return Objects.hash(transactionalId, producerId, transactionState); + } + + @Override + public String toString() { + return "TransactionListing(" + + "transactionalId='" + transactionalId + '\'' + + ", producerId=" + producerId + + ", transactionState=" + transactionState + + ')'; + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/TransactionState.java b/clients/src/main/java/org/apache/kafka/clients/admin/TransactionState.java new file mode 100644 index 0000000..9ff0966 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/TransactionState.java @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
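[Usage illustration, not part of the patch] A sketch of listing TransactionListing entries with Admin#listTransactions, filtered to ongoing transactions; the filterStates option is assumed from the companion KIP-664 API surface.

import org.apache.kafka.clients.admin.Admin;
import org.apache.kafka.clients.admin.ListTransactionsOptions;
import org.apache.kafka.clients.admin.TransactionListing;
import org.apache.kafka.clients.admin.TransactionState;

import java.util.Collection;
import java.util.Collections;
import java.util.concurrent.ExecutionException;

public class ListTransactionsExample {
    static void printOngoing(Admin admin) throws ExecutionException, InterruptedException {
        // Restrict the listing to transactions that are currently ongoing
        ListTransactionsOptions options =
            new ListTransactionsOptions().filterStates(Collections.singleton(TransactionState.ONGOING));
        Collection<TransactionListing> listings = admin.listTransactions(options).all().get();
        for (TransactionListing listing : listings) {
            System.out.printf("%s producerId=%d state=%s%n",
                listing.transactionalId(), listing.producerId(), listing.state());
        }
    }
}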
+ */ +package org.apache.kafka.clients.admin; + +import org.apache.kafka.common.annotation.InterfaceStability; + +import java.util.Arrays; +import java.util.Map; +import java.util.function.Function; +import java.util.stream.Collectors; + +@InterfaceStability.Evolving +public enum TransactionState { + ONGOING("Ongoing"), + PREPARE_ABORT("PrepareAbort"), + PREPARE_COMMIT("PrepareCommit"), + COMPLETE_ABORT("CompleteAbort"), + COMPLETE_COMMIT("CompleteCommit"), + EMPTY("Empty"), + PREPARE_EPOCH_FENCE("PrepareEpochFence"), + UNKNOWN("Unknown"); + + private final static Map NAME_TO_ENUM = Arrays.stream(values()) + .collect(Collectors.toMap(state -> state.name, Function.identity())); + + private final String name; + + TransactionState(String name) { + this.name = name; + } + + @Override + public String toString() { + return name; + } + + public static TransactionState parse(String name) { + TransactionState state = NAME_TO_ENUM.get(name); + return state == null ? UNKNOWN : state; + } + +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/UnregisterBrokerOptions.java b/clients/src/main/java/org/apache/kafka/clients/admin/UnregisterBrokerOptions.java new file mode 100644 index 0000000..1935b79 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/UnregisterBrokerOptions.java @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.clients.admin; + +import org.apache.kafka.common.annotation.InterfaceStability; + +/** + * Options for {@link Admin#unregisterBroker(int, UnregisterBrokerOptions)}. + * + * The API of this class is evolving. See {@link Admin} for details. + */ +@InterfaceStability.Evolving +public class UnregisterBrokerOptions extends AbstractOptions { +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/UnregisterBrokerResult.java b/clients/src/main/java/org/apache/kafka/clients/admin/UnregisterBrokerResult.java new file mode 100644 index 0000000..b44c7e0 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/UnregisterBrokerResult.java @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
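[Usage illustration, not part of the patch] A tiny sketch of the TransactionState round trip defined above: toString() yields the wire-format name, and parse() maps an unrecognized name to UNKNOWN.

import org.apache.kafka.clients.admin.TransactionState;

public class TransactionStateExample {
    public static void main(String[] args) {
        System.out.println(TransactionState.parse("CompleteCommit")); // prints "CompleteCommit"
        System.out.println(TransactionState.parse("NotARealState"));  // falls back to "Unknown"
    }
}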
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.clients.admin; + +import org.apache.kafka.common.KafkaFuture; + +/** + * The result of the {@link Admin#unregisterBroker(int, UnregisterBrokerOptions)} call. + * + * The API of this class is evolving, see {@link Admin} for details. + */ +public class UnregisterBrokerResult { + private final KafkaFuture future; + + UnregisterBrokerResult(final KafkaFuture future) { + this.future = future; + } + + /** + * Return a future which succeeds if the operation is successful. + */ + public KafkaFuture all() { + return future; + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/UpdateFeaturesOptions.java b/clients/src/main/java/org/apache/kafka/clients/admin/UpdateFeaturesOptions.java new file mode 100644 index 0000000..7a9f214 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/UpdateFeaturesOptions.java @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients.admin; + +import java.util.Map; +import org.apache.kafka.common.annotation.InterfaceStability; + +/** + * Options for {@link AdminClient#updateFeatures(Map, UpdateFeaturesOptions)}. + * + * The API of this class is evolving. See {@link Admin} for details. + */ +@InterfaceStability.Evolving +public class UpdateFeaturesOptions extends AbstractOptions { +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/UpdateFeaturesResult.java b/clients/src/main/java/org/apache/kafka/clients/admin/UpdateFeaturesResult.java new file mode 100644 index 0000000..6c484dc --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/UpdateFeaturesResult.java @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
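[Usage illustration, not part of the patch] A minimal sketch of driving UnregisterBrokerResult through Admin#unregisterBroker; the broker id is a placeholder, and the call is assumed to target a cluster mode that supports broker unregistration.

import org.apache.kafka.clients.admin.Admin;

import java.util.concurrent.ExecutionException;

public class UnregisterBrokerExample {
    // Removes the registration of an already-shut-down broker; brokerId is a placeholder value
    static void unregister(Admin admin, int brokerId) throws ExecutionException, InterruptedException {
        admin.unregisterBroker(brokerId).all().get();
    }
}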
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients.admin; + +import java.util.Map; +import org.apache.kafka.common.KafkaFuture; + +/** + * The result of the {@link Admin#updateFeatures(Map, UpdateFeaturesOptions)} call. + * + * The API of this class is evolving, see {@link Admin} for details. + */ +public class UpdateFeaturesResult { + private final Map> futures; + + /** + * @param futures a map from feature name to future, which can be used to check the status of + * individual feature updates. + */ + UpdateFeaturesResult(final Map> futures) { + this.futures = futures; + } + + public Map> values() { + return futures; + } + + /** + * Return a future which succeeds if all the feature updates succeed. + */ + public KafkaFuture all() { + return KafkaFuture.allOf(futures.values().toArray(new KafkaFuture[0])); + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/UserScramCredentialAlteration.java b/clients/src/main/java/org/apache/kafka/clients/admin/UserScramCredentialAlteration.java new file mode 100644 index 0000000..8293fe5 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/UserScramCredentialAlteration.java @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.clients.admin; + +import java.util.Objects; + +/** + * A request to alter a user's SASL/SCRAM credentials. + * + * @see KIP-554: Add Broker-side SCRAM Config API + */ +public abstract class UserScramCredentialAlteration { + protected final String user; + + /** + * + * @param user the mandatory user + */ + protected UserScramCredentialAlteration(String user) { + this.user = Objects.requireNonNull(user); + } + + /** + * + * @return the always non-null user + */ + public String user() { + return this.user; + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/UserScramCredentialDeletion.java b/clients/src/main/java/org/apache/kafka/clients/admin/UserScramCredentialDeletion.java new file mode 100644 index 0000000..633075a --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/UserScramCredentialDeletion.java @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
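[Usage illustration, not part of the patch] A sketch of consuming an UpdateFeaturesResult per feature; how the FeatureUpdate map is built is out of scope here, so the updates are taken as a parameter rather than constructed.

import org.apache.kafka.clients.admin.Admin;
import org.apache.kafka.clients.admin.FeatureUpdate;
import org.apache.kafka.clients.admin.UpdateFeaturesOptions;
import org.apache.kafka.clients.admin.UpdateFeaturesResult;
import org.apache.kafka.common.KafkaFuture;

import java.util.Map;

public class UpdateFeaturesExample {
    // 'updates' maps feature names to the desired FeatureUpdate; constructing them is out of scope here
    static void apply(Admin admin, Map<String, FeatureUpdate> updates) {
        UpdateFeaturesResult result = admin.updateFeatures(updates, new UpdateFeaturesOptions());
        for (Map.Entry<String, KafkaFuture<Void>> entry : result.values().entrySet()) {
            try {
                entry.getValue().get();
                System.out.println("Updated feature " + entry.getKey());
            } catch (Exception e) {
                System.err.println("Failed to update feature " + entry.getKey() + ": " + e.getMessage());
            }
        }
    }
}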
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.clients.admin; + +import java.util.Objects; + +/** + * A request to delete a SASL/SCRAM credential for a user. + * + * @see KIP-554: Add Broker-side SCRAM Config API + */ +public class UserScramCredentialDeletion extends UserScramCredentialAlteration { + private final ScramMechanism mechanism; + + /** + * @param user the mandatory user + * @param mechanism the mandatory mechanism + */ + public UserScramCredentialDeletion(String user, ScramMechanism mechanism) { + super(user); + this.mechanism = Objects.requireNonNull(mechanism); + } + + /** + * + * @return the always non-null mechanism + */ + public ScramMechanism mechanism() { + return mechanism; + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/UserScramCredentialUpsertion.java b/clients/src/main/java/org/apache/kafka/clients/admin/UserScramCredentialUpsertion.java new file mode 100644 index 0000000..5d5cf9c --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/UserScramCredentialUpsertion.java @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.clients.admin; + +import org.apache.kafka.common.security.scram.internals.ScramFormatter; + +import java.nio.charset.StandardCharsets; +import java.security.SecureRandom; +import java.util.Objects; + +/** + * A request to update/insert a SASL/SCRAM credential for a user. 
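[Usage illustration, not part of the patch] A sketch of deleting a user's SCRAM credential with the UserScramCredentialDeletion shown above, submitted through Admin#alterUserScramCredentials; the user name is a placeholder.

import org.apache.kafka.clients.admin.Admin;
import org.apache.kafka.clients.admin.ScramMechanism;
import org.apache.kafka.clients.admin.UserScramCredentialDeletion;

import java.util.Collections;
import java.util.concurrent.ExecutionException;

public class ScramDeletionExample {
    // Deletes the SCRAM-SHA-256 credential of the given user
    static void deleteCredential(Admin admin, String user) throws ExecutionException, InterruptedException {
        UserScramCredentialDeletion deletion =
            new UserScramCredentialDeletion(user, ScramMechanism.SCRAM_SHA_256);
        admin.alterUserScramCredentials(Collections.singletonList(deletion)).all().get();
    }
}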
+ * + * @see KIP-554: Add Broker-side SCRAM Config API + */ +public class UserScramCredentialUpsertion extends UserScramCredentialAlteration { + private final ScramCredentialInfo info; + private final byte[] salt; + private final byte[] password; + + /** + * Constructor that generates a random salt + * + * @param user the user for which the credential is to be updated/inserted + * @param credentialInfo the mechanism and iterations to be used + * @param password the password + */ + public UserScramCredentialUpsertion(String user, ScramCredentialInfo credentialInfo, String password) { + this(user, credentialInfo, password.getBytes(StandardCharsets.UTF_8)); + } + + /** + * Constructor that generates a random salt + * + * @param user the user for which the credential is to be updated/inserted + * @param credentialInfo the mechanism and iterations to be used + * @param password the password + */ + public UserScramCredentialUpsertion(String user, ScramCredentialInfo credentialInfo, byte[] password) { + this(user, credentialInfo, password, generateRandomSalt()); + } + + /** + * Constructor that accepts an explicit salt + * + * @param user the user for which the credential is to be updated/inserted + * @param credentialInfo the mechanism and iterations to be used + * @param password the password + * @param salt the salt to be used + */ + public UserScramCredentialUpsertion(String user, ScramCredentialInfo credentialInfo, byte[] password, byte[] salt) { + super(Objects.requireNonNull(user)); + this.info = Objects.requireNonNull(credentialInfo); + this.password = Objects.requireNonNull(password); + this.salt = Objects.requireNonNull(salt); + } + + /** + * + * @return the mechanism and iterations + */ + public ScramCredentialInfo credentialInfo() { + return info; + } + + /** + * + * @return the salt + */ + public byte[] salt() { + return salt; + } + + /** + * + * @return the password + */ + public byte[] password() { + return password; + } + + private static byte[] generateRandomSalt() { + return ScramFormatter.secureRandomBytes(new SecureRandom()); + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/UserScramCredentialsDescription.java b/clients/src/main/java/org/apache/kafka/clients/admin/UserScramCredentialsDescription.java new file mode 100644 index 0000000..97bc358 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/UserScramCredentialsDescription.java @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
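[Usage illustration, not part of the patch] A companion sketch for UserScramCredentialUpsertion: creating or updating a credential with an explicit mechanism and iteration count while letting the class generate the random salt; the mechanism, iteration count, user, and password are placeholder choices.

import org.apache.kafka.clients.admin.Admin;
import org.apache.kafka.clients.admin.ScramCredentialInfo;
import org.apache.kafka.clients.admin.ScramMechanism;
import org.apache.kafka.clients.admin.UserScramCredentialUpsertion;

import java.util.Collections;
import java.util.concurrent.ExecutionException;

public class ScramUpsertionExample {
    // Creates or updates a SCRAM-SHA-512 credential with 8192 iterations; a random salt is generated internally
    static void upsertCredential(Admin admin, String user, String password)
            throws ExecutionException, InterruptedException {
        ScramCredentialInfo info = new ScramCredentialInfo(ScramMechanism.SCRAM_SHA_512, 8192);
        UserScramCredentialUpsertion upsertion = new UserScramCredentialUpsertion(user, info, password);
        admin.alterUserScramCredentials(Collections.singletonList(upsertion)).all().get();
    }
}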
+ */ + +package org.apache.kafka.clients.admin; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Objects; + +/** + * Representation of all SASL/SCRAM credentials associated with a user that can be retrieved, or an exception indicating + * why credentials could not be retrieved. + * + * @see KIP-554: Add Broker-side SCRAM Config API + */ +public class UserScramCredentialsDescription { + private final String name; + private final List credentialInfos; + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + UserScramCredentialsDescription that = (UserScramCredentialsDescription) o; + return name.equals(that.name) && + credentialInfos.equals(that.credentialInfos); + } + + @Override + public int hashCode() { + return Objects.hash(name, credentialInfos); + } + + @Override + public String toString() { + return "UserScramCredentialsDescription{" + + "name='" + name + '\'' + + ", credentialInfos=" + credentialInfos + + '}'; + } + + /** + * + * @param name the required user name + * @param credentialInfos the required SASL/SCRAM credential representations for the user + */ + public UserScramCredentialsDescription(String name, List credentialInfos) { + this.name = Objects.requireNonNull(name); + this.credentialInfos = Collections.unmodifiableList(new ArrayList<>(credentialInfos)); + } + + /** + * + * @return the user name + */ + public String name() { + return name; + } + + /** + * + * @return the always non-null/unmodifiable list of SASL/SCRAM credential representations for the user + */ + public List credentialInfos() { + return credentialInfos; + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/internals/AbortTransactionHandler.java b/clients/src/main/java/org/apache/kafka/clients/admin/internals/AbortTransactionHandler.java new file mode 100644 index 0000000..c25e4d8 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/internals/AbortTransactionHandler.java @@ -0,0 +1,187 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
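[Usage illustration, not part of the patch] A sketch of reading back UserScramCredentialsDescription entries via Admin#describeUserScramCredentials; only the mechanism and iteration count are exposed, never the password or salt.

import org.apache.kafka.clients.admin.Admin;
import org.apache.kafka.clients.admin.ScramCredentialInfo;
import org.apache.kafka.clients.admin.UserScramCredentialsDescription;

import java.util.Map;
import java.util.concurrent.ExecutionException;

public class DescribeScramExample {
    static void printCredentials(Admin admin) throws ExecutionException, InterruptedException {
        Map<String, UserScramCredentialsDescription> descriptions =
            admin.describeUserScramCredentials().all().get();
        for (UserScramCredentialsDescription description : descriptions.values()) {
            for (ScramCredentialInfo info : description.credentialInfos()) {
                System.out.printf("user=%s mechanism=%s iterations=%d%n",
                    description.name(), info.mechanism(), info.iterations());
            }
        }
    }
}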
+ */ +package org.apache.kafka.clients.admin.internals; + +import org.apache.kafka.clients.admin.AbortTransactionSpec; +import org.apache.kafka.common.KafkaException; +import org.apache.kafka.common.Node; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.errors.ClusterAuthorizationException; +import org.apache.kafka.common.errors.InvalidProducerEpochException; +import org.apache.kafka.common.errors.TransactionCoordinatorFencedException; +import org.apache.kafka.common.message.WriteTxnMarkersRequestData; +import org.apache.kafka.common.message.WriteTxnMarkersResponseData; +import org.apache.kafka.common.protocol.Errors; +import org.apache.kafka.common.requests.AbstractResponse; +import org.apache.kafka.common.requests.WriteTxnMarkersRequest; +import org.apache.kafka.common.requests.WriteTxnMarkersResponse; +import org.apache.kafka.common.utils.LogContext; +import org.slf4j.Logger; + +import java.util.List; +import java.util.Set; + +import static java.util.Collections.singleton; +import static java.util.Collections.singletonList; + +public class AbortTransactionHandler implements AdminApiHandler { + private final Logger log; + private final AbortTransactionSpec abortSpec; + private final PartitionLeaderStrategy lookupStrategy; + + public AbortTransactionHandler( + AbortTransactionSpec abortSpec, + LogContext logContext + ) { + this.abortSpec = abortSpec; + this.log = logContext.logger(AbortTransactionHandler.class); + this.lookupStrategy = new PartitionLeaderStrategy(logContext); + } + + public static AdminApiFuture.SimpleAdminApiFuture newFuture( + Set topicPartitions + ) { + return AdminApiFuture.forKeys(topicPartitions); + } + + @Override + public String apiName() { + return "abortTransaction"; + } + + @Override + public AdminApiLookupStrategy lookupStrategy() { + return lookupStrategy; + } + + @Override + public WriteTxnMarkersRequest.Builder buildRequest( + int brokerId, + Set topicPartitions + ) { + validateTopicPartitions(topicPartitions); + + WriteTxnMarkersRequestData.WritableTxnMarker marker = new WriteTxnMarkersRequestData.WritableTxnMarker() + .setCoordinatorEpoch(abortSpec.coordinatorEpoch()) + .setProducerEpoch(abortSpec.producerEpoch()) + .setProducerId(abortSpec.producerId()) + .setTransactionResult(false); + + marker.topics().add(new WriteTxnMarkersRequestData.WritableTxnMarkerTopic() + .setName(abortSpec.topicPartition().topic()) + .setPartitionIndexes(singletonList(abortSpec.topicPartition().partition())) + ); + + WriteTxnMarkersRequestData request = new WriteTxnMarkersRequestData(); + request.markers().add(marker); + + return new WriteTxnMarkersRequest.Builder(request); + } + + @Override + public ApiResult handleResponse( + Node broker, + Set topicPartitions, + AbstractResponse abstractResponse + ) { + validateTopicPartitions(topicPartitions); + + WriteTxnMarkersResponse response = (WriteTxnMarkersResponse) abstractResponse; + List markerResponses = response.data().markers(); + + if (markerResponses.size() != 1 || markerResponses.get(0).producerId() != abortSpec.producerId()) { + return ApiResult.failed(abortSpec.topicPartition(), new KafkaException("WriteTxnMarkers response " + + "included unexpected marker entries: " + markerResponses + "(expected to find exactly one " + + "entry with producerId " + abortSpec.producerId() + ")")); + } + + WriteTxnMarkersResponseData.WritableTxnMarkerResult markerResponse = markerResponses.get(0); + List topicResponses = markerResponse.topics(); + + if (topicResponses.size() != 1 || 
!topicResponses.get(0).name().equals(abortSpec.topicPartition().topic())) { + return ApiResult.failed(abortSpec.topicPartition(), new KafkaException("WriteTxnMarkers response " + + "included unexpected topic entries: " + markerResponses + "(expected to find exactly one " + + "entry with topic partition " + abortSpec.topicPartition() + ")")); + } + + WriteTxnMarkersResponseData.WritableTxnMarkerTopicResult topicResponse = topicResponses.get(0); + List partitionResponses = + topicResponse.partitions(); + + if (partitionResponses.size() != 1 || partitionResponses.get(0).partitionIndex() != abortSpec.topicPartition().partition()) { + return ApiResult.failed(abortSpec.topicPartition(), new KafkaException("WriteTxnMarkers response " + + "included unexpected partition entries for topic " + abortSpec.topicPartition().topic() + + ": " + markerResponses + "(expected to find exactly one entry with partition " + + abortSpec.topicPartition().partition() + ")")); + } + + WriteTxnMarkersResponseData.WritableTxnMarkerPartitionResult partitionResponse = partitionResponses.get(0); + Errors error = Errors.forCode(partitionResponse.errorCode()); + + if (error != Errors.NONE) { + return handleError(error); + } else { + return ApiResult.completed(abortSpec.topicPartition(), null); + } + } + + private ApiResult handleError(Errors error) { + switch (error) { + case CLUSTER_AUTHORIZATION_FAILED: + log.error("WriteTxnMarkers request for abort spec {} failed cluster authorization", abortSpec); + return ApiResult.failed(abortSpec.topicPartition(), new ClusterAuthorizationException( + "WriteTxnMarkers request with " + abortSpec + " failed due to cluster " + + "authorization error")); + + case INVALID_PRODUCER_EPOCH: + log.error("WriteTxnMarkers request for abort spec {} failed due to an invalid producer epoch", + abortSpec); + return ApiResult.failed(abortSpec.topicPartition(), new InvalidProducerEpochException( + "WriteTxnMarkers request with " + abortSpec + " failed due an invalid producer epoch")); + + case TRANSACTION_COORDINATOR_FENCED: + log.error("WriteTxnMarkers request for abort spec {} failed because the coordinator epoch is fenced", + abortSpec); + return ApiResult.failed(abortSpec.topicPartition(), new TransactionCoordinatorFencedException( + "WriteTxnMarkers request with " + abortSpec + " failed since the provided " + + "coordinator epoch " + abortSpec.coordinatorEpoch() + " has been fenced " + + "by the active coordinator")); + + case NOT_LEADER_OR_FOLLOWER: + case REPLICA_NOT_AVAILABLE: + case BROKER_NOT_AVAILABLE: + case UNKNOWN_TOPIC_OR_PARTITION: + log.debug("WriteTxnMarkers request for abort spec {} failed due to {}. 
Will retry after attempting to " + + "find the leader again", abortSpec, error); + return ApiResult.unmapped(singletonList(abortSpec.topicPartition())); + + default: + log.error("WriteTxnMarkers request for abort spec {} failed due to an unexpected error {}", + abortSpec, error); + return ApiResult.failed(abortSpec.topicPartition(), error.exception( + "WriteTxnMarkers request with " + abortSpec + " failed due to unexpected error: " + error.message())); + } + } + + private void validateTopicPartitions(Set topicPartitions) { + if (!topicPartitions.equals(singleton(abortSpec.topicPartition()))) { + throw new IllegalArgumentException("Received unexpected topic partitions " + topicPartitions + + " (expected only " + singleton(abortSpec.topicPartition()) + ")"); + } + } + +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/internals/AdminApiDriver.java b/clients/src/main/java/org/apache/kafka/clients/admin/internals/AdminApiDriver.java new file mode 100644 index 0000000..b5c9ff3 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/internals/AdminApiDriver.java @@ -0,0 +1,483 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients.admin.internals; + +import org.apache.kafka.common.Node; +import org.apache.kafka.common.errors.DisconnectException; +import org.apache.kafka.common.requests.AbstractRequest; +import org.apache.kafka.common.requests.AbstractResponse; +import org.apache.kafka.common.requests.FindCoordinatorRequest.NoBatchedFindCoordinatorsException; +import org.apache.kafka.common.utils.LogContext; +import org.slf4j.Logger; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.OptionalInt; +import java.util.Set; +import java.util.function.BiFunction; +import java.util.function.Function; +import java.util.stream.Collectors; + +/** + * The `KafkaAdminClient`'s internal `Call` primitive is not a good fit for multi-stage + * request workflows such as we see with the group coordinator APIs or any request which + * needs to be sent to a partition leader. Typically these APIs have two concrete stages: + * + * 1. Lookup: Find the broker that can fulfill the request (e.g. partition leader or group + * coordinator) + * 2. Fulfillment: Send the request to the broker found in the first step + * + * This is complicated by the fact that `Admin` APIs are typically batched, which + * means the Lookup stage may result in a set of brokers. For example, take a `ListOffsets` + * request for a set of topic partitions. 
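[Usage illustration, not part of the patch] The handler above backs the admin-level abort path; a hedged sketch of invoking it through Admin#abortTransaction with an AbortTransactionSpec. The topic partition, producer id/epoch, and coordinator epoch are placeholder values that would normally come from earlier describeProducers/describeTransactions output.

import org.apache.kafka.clients.admin.AbortTransactionSpec;
import org.apache.kafka.clients.admin.Admin;
import org.apache.kafka.common.TopicPartition;

import java.util.concurrent.ExecutionException;

public class AbortTransactionExample {
    static void abortHangingTransaction(Admin admin) throws ExecutionException, InterruptedException {
        AbortTransactionSpec spec = new AbortTransactionSpec(
            new TopicPartition("my-topic", 0), // partition with the hanging transaction (placeholder)
            4000L,                             // producerId (placeholder)
            (short) 7,                         // producerEpoch (placeholder)
            10);                               // coordinatorEpoch (placeholder)
        admin.abortTransaction(spec).all().get();
    }
}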
In the Lookup stage, we will find the partition + * leaders for this set of partitions; in the Fulfillment stage, we will group together + * partition according to the IDs of the discovered leaders. + * + * Additionally, the flow between these two stages is bi-directional. We may find after + * sending a `ListOffsets` request to an expected leader that there was a leader change. + * This would result in a topic partition being sent back to the Lookup stage. + * + * Managing this complexity by chaining together `Call` implementations is challenging + * and messy, so instead we use this class to do the bookkeeping. It handles both the + * batching aspect as well as the transitions between the Lookup and Fulfillment stages. + * + * Note that the interpretation of the `retries` configuration becomes ambiguous + * for this kind of pipeline. We could treat it as an overall limit on the number + * of requests that can be sent, but that is not very useful because each pipeline + * has a minimum number of requests that need to be sent in order to satisfy the request. + * Instead, we treat this number of retries independently at each stage so that each + * stage has at least one opportunity to complete. So if a user sets `retries=1`, then + * the full pipeline can still complete as long as there are no request failures. + * + * @param The key type, which is also the granularity of the request routing (e.g. + * this could be `TopicPartition` in the case of requests intended for a partition + * leader or the `GroupId` in the case of consumer group requests intended for + * the group coordinator) + * @param The fulfillment type for each key (e.g. this could be consumer group state + * when the key type is a consumer `GroupId`) + */ +public class AdminApiDriver { + private final Logger log; + private final long retryBackoffMs; + private final long deadlineMs; + private final AdminApiHandler handler; + private final AdminApiFuture future; + + private final BiMultimap lookupMap = new BiMultimap<>(); + private final BiMultimap fulfillmentMap = new BiMultimap<>(); + private final Map requestStates = new HashMap<>(); + + public AdminApiDriver( + AdminApiHandler handler, + AdminApiFuture future, + long deadlineMs, + long retryBackoffMs, + LogContext logContext + ) { + this.handler = handler; + this.future = future; + this.deadlineMs = deadlineMs; + this.retryBackoffMs = retryBackoffMs; + this.log = logContext.logger(AdminApiDriver.class); + retryLookup(future.lookupKeys()); + } + + /** + * Associate a key with a brokerId. This is called after a response in the Lookup + * stage reveals the mapping (e.g. when the `FindCoordinator` tells us the group + * coordinator for a specific consumer group). + */ + private void map(K key, Integer brokerId) { + lookupMap.remove(key); + fulfillmentMap.put(new FulfillmentScope(brokerId), key); + } + + /** + * Disassociate a key from the currently mapped brokerId. This will send the key + * back to the Lookup stage, which will allow us to attempt lookup again. 
+ */ + private void unmap(K key) { + fulfillmentMap.remove(key); + + ApiRequestScope lookupScope = handler.lookupStrategy().lookupScope(key); + OptionalInt destinationBrokerId = lookupScope.destinationBrokerId(); + + if (destinationBrokerId.isPresent()) { + fulfillmentMap.put(new FulfillmentScope(destinationBrokerId.getAsInt()), key); + } else { + lookupMap.put(handler.lookupStrategy().lookupScope(key), key); + } + } + + private void clear(Collection keys) { + keys.forEach(key -> { + lookupMap.remove(key); + fulfillmentMap.remove(key); + }); + } + + OptionalInt keyToBrokerId(K key) { + Optional scope = fulfillmentMap.getKey(key); + return scope + .map(fulfillmentScope -> OptionalInt.of(fulfillmentScope.destinationBrokerId)) + .orElseGet(OptionalInt::empty); + } + + /** + * Complete the future associated with the given key exceptionally. After is called, + * the key will be taken out of both the Lookup and Fulfillment stages so that request + * are not retried. + */ + private void completeExceptionally(Map errors) { + if (!errors.isEmpty()) { + future.completeExceptionally(errors); + clear(errors.keySet()); + } + } + + private void completeLookupExceptionally(Map errors) { + if (!errors.isEmpty()) { + future.completeLookupExceptionally(errors); + clear(errors.keySet()); + } + } + + private void retryLookup(Collection keys) { + keys.forEach(this::unmap); + } + + /** + * Complete the future associated with the given key. After this is called, all keys will + * be taken out of both the Lookup and Fulfillment stages so that request are not retried. + */ + private void complete(Map values) { + if (!values.isEmpty()) { + future.complete(values); + clear(values.keySet()); + } + } + + private void completeLookup(Map brokerIdMapping) { + if (!brokerIdMapping.isEmpty()) { + future.completeLookup(brokerIdMapping); + brokerIdMapping.forEach(this::map); + } + } + + /** + * Check whether any requests need to be sent. This should be called immediately + * after the driver is constructed and then again after each request returns + * (i.e. after {@link #onFailure(long, RequestSpec, Throwable)} or + * {@link #onResponse(long, RequestSpec, AbstractResponse, Node)}). + * + * @return A list of requests that need to be sent + */ + public List> poll() { + List> requests = new ArrayList<>(); + collectLookupRequests(requests); + collectFulfillmentRequests(requests); + return requests; + } + + /** + * Callback that is invoked when a `Call` returns a response successfully. + */ + public void onResponse( + long currentTimeMs, + RequestSpec spec, + AbstractResponse response, + Node node + ) { + clearInflightRequest(currentTimeMs, spec); + + if (spec.scope instanceof FulfillmentScope) { + AdminApiHandler.ApiResult result = handler.handleResponse( + node, + spec.keys, + response + ); + complete(result.completedKeys); + completeExceptionally(result.failedKeys); + retryLookup(result.unmappedKeys); + } else { + AdminApiLookupStrategy.LookupResult result = handler.lookupStrategy().handleResponse( + spec.keys, + response + ); + + result.completedKeys.forEach(lookupMap::remove); + completeLookup(result.mappedKeys); + completeLookupExceptionally(result.failedKeys); + } + } + + /** + * Callback that is invoked when a `Call` is failed. + */ + public void onFailure( + long currentTimeMs, + RequestSpec spec, + Throwable t + ) { + clearInflightRequest(currentTimeMs, spec); + if (t instanceof DisconnectException) { + log.debug("Node disconnected before response could be received for request {}. 
" + + "Will attempt retry", spec.request); + + // After a disconnect, we want the driver to attempt to lookup the key + // again. This gives us a chance to find a new coordinator or partition + // leader for example. + Set keysToUnmap = spec.keys.stream() + .filter(future.lookupKeys()::contains) + .collect(Collectors.toSet()); + retryLookup(keysToUnmap); + + } else if (t instanceof NoBatchedFindCoordinatorsException) { + ((CoordinatorStrategy) handler.lookupStrategy()).disableBatch(); + Set keysToUnmap = spec.keys.stream() + .filter(future.lookupKeys()::contains) + .collect(Collectors.toSet()); + retryLookup(keysToUnmap); + } else { + Map errors = spec.keys.stream().collect(Collectors.toMap( + Function.identity(), + key -> t + )); + + if (spec.scope instanceof FulfillmentScope) { + completeExceptionally(errors); + } else { + completeLookupExceptionally(errors); + } + } + } + + private void clearInflightRequest(long currentTimeMs, RequestSpec spec) { + RequestState requestState = requestStates.get(spec.scope); + if (requestState != null) { + // Only apply backoff if it's not a retry of a lookup request + if (spec.scope instanceof FulfillmentScope) { + requestState.clearInflight(currentTimeMs + retryBackoffMs); + } else { + requestState.clearInflight(currentTimeMs); + } + } + } + + private void collectRequests( + List> requests, + BiMultimap multimap, + BiFunction, T, AbstractRequest.Builder> buildRequest + ) { + for (Map.Entry> entry : multimap.entrySet()) { + T scope = entry.getKey(); + + Set keys = entry.getValue(); + if (keys.isEmpty()) { + continue; + } + + RequestState requestState = requestStates.computeIfAbsent(scope, c -> new RequestState()); + if (requestState.hasInflight()) { + continue; + } + + // Copy the keys to avoid exposing the underlying mutable set + Set copyKeys = Collections.unmodifiableSet(new HashSet<>(keys)); + + AbstractRequest.Builder request = buildRequest.apply(copyKeys, scope); + RequestSpec spec = new RequestSpec<>( + handler.apiName() + "(api=" + request.apiKey() + ")", + scope, + copyKeys, + request, + requestState.nextAllowedRetryMs, + deadlineMs, + requestState.tries + ); + + requestState.setInflight(spec); + requests.add(spec); + } + } + + private void collectLookupRequests(List> requests) { + collectRequests( + requests, + lookupMap, + (keys, scope) -> handler.lookupStrategy().buildRequest(keys) + ); + } + + private void collectFulfillmentRequests(List> requests) { + collectRequests( + requests, + fulfillmentMap, + (keys, scope) -> handler.buildRequest(scope.destinationBrokerId, keys) + ); + } + + /** + * This is a helper class which helps us to map requests that need to be sent + * to the internal `Call` implementation that is used internally in + * {@link org.apache.kafka.clients.admin.KafkaAdminClient}. 
+ */ + public static class RequestSpec { + public final String name; + public final ApiRequestScope scope; + public final Set keys; + public final AbstractRequest.Builder request; + public final long nextAllowedTryMs; + public final long deadlineMs; + public final int tries; + + public RequestSpec( + String name, + ApiRequestScope scope, + Set keys, + AbstractRequest.Builder request, + long nextAllowedTryMs, + long deadlineMs, + int tries + ) { + this.name = name; + this.scope = scope; + this.keys = keys; + this.request = request; + this.nextAllowedTryMs = nextAllowedTryMs; + this.deadlineMs = deadlineMs; + this.tries = tries; + } + + @Override + public String toString() { + return "RequestSpec(" + + "name=" + name + + ", scope=" + scope + + ", keys=" + keys + + ", request=" + request + + ", nextAllowedTryMs=" + nextAllowedTryMs + + ", deadlineMs=" + deadlineMs + + ", tries=" + tries + + ')'; + } + } + + /** + * Helper class used to track the request state within each request scope. + * This class enforces a maximum number of inflight request and keeps track + * of backoff/retry state. + */ + private class RequestState { + private Optional> inflightRequest = Optional.empty(); + private int tries = 0; + private long nextAllowedRetryMs = 0; + + boolean hasInflight() { + return inflightRequest.isPresent(); + } + + public void clearInflight(long nextAllowedRetryMs) { + this.inflightRequest = Optional.empty(); + this.nextAllowedRetryMs = nextAllowedRetryMs; + } + + public void setInflight(RequestSpec spec) { + this.inflightRequest = Optional.of(spec); + this.tries++; + } + } + + /** + * Completion of the Lookup stage results in a destination broker to send the + * fulfillment request to. Each destination broker in the Fulfillment stage + * gets its own request scope. + */ + private static class FulfillmentScope implements ApiRequestScope { + public final int destinationBrokerId; + + private FulfillmentScope(int destinationBrokerId) { + this.destinationBrokerId = destinationBrokerId; + } + + @Override + public OptionalInt destinationBrokerId() { + return OptionalInt.of(destinationBrokerId); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + FulfillmentScope that = (FulfillmentScope) o; + return destinationBrokerId == that.destinationBrokerId; + } + + @Override + public int hashCode() { + return Objects.hash(destinationBrokerId); + } + } + + /** + * Helper class which maintains a bi-directional mapping from a key to a set of values. + * Each value can map to one and only one key, but many values can be associated with + * a single key. 
+ * + * @param The key type + * @param The value type + */ + private static class BiMultimap { + private final Map reverseMap = new HashMap<>(); + private final Map> map = new HashMap<>(); + + void put(K key, V value) { + remove(value); + reverseMap.put(value, key); + map.computeIfAbsent(key, k -> new HashSet<>()).add(value); + } + + void remove(V value) { + K key = reverseMap.remove(value); + if (key != null) { + Set set = map.get(key); + if (set != null) { + set.remove(value); + if (set.isEmpty()) { + map.remove(key); + } + } + } + } + + Optional getKey(V value) { + return Optional.ofNullable(reverseMap.get(value)); + } + + Set>> entrySet() { + return map.entrySet(); + } + } + +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/internals/AdminApiFuture.java b/clients/src/main/java/org/apache/kafka/clients/admin/internals/AdminApiFuture.java new file mode 100644 index 0000000..b0294d8 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/internals/AdminApiFuture.java @@ -0,0 +1,131 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients.admin.internals; + +import org.apache.kafka.common.KafkaFuture; +import org.apache.kafka.common.internals.KafkaFutureImpl; + +import java.util.Map; +import java.util.Set; +import java.util.function.Function; +import java.util.stream.Collectors; + +public interface AdminApiFuture { + + /** + * The initial set of lookup keys. Although this will usually match the fulfillment + * keys, it does not necessarily have to. For example, in the case of + * {@link AllBrokersStrategy.AllBrokersFuture}, + * we use the lookup phase in order to discover the set of keys that will be searched + * during the fulfillment phase. + * + * @return non-empty set of initial lookup keys + */ + Set lookupKeys(); + + /** + * Complete the futures associated with the given keys. + * + * @param values the completed keys with their respective values + */ + void complete(Map values); + + /** + * Invoked when lookup of a set of keys succeeds. + * + * @param brokerIdMapping the discovered mapping from key to the respective brokerId that will + * handle the fulfillment request + */ + default void completeLookup(Map brokerIdMapping) { + } + + /** + * Invoked when lookup fails with a fatal error on a set of keys. + * + * @param lookupErrors the set of keys that failed lookup with their respective errors + */ + default void completeLookupExceptionally(Map lookupErrors) { + completeExceptionally(lookupErrors); + } + + /** + * Complete the futures associated with the given keys exceptionally. 
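[Illustration, not part of the patch] A stand-alone toy version of the driver's key-to-scope bookkeeping, using plain maps instead of the private BiMultimap: each key maps to exactly one scope (a broker id here), and re-mapping a key moves it out of its previous scope, which is how a key can bounce between the Lookup and Fulfillment stages.

import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;

public class ScopeBookkeepingSketch {
    private final Map<String, Integer> keyToBroker = new HashMap<>();
    private final Map<Integer, Set<String>> brokerToKeys = new HashMap<>();

    void map(String key, int brokerId) {
        unmap(key); // a key can live in at most one scope at a time
        keyToBroker.put(key, brokerId);
        brokerToKeys.computeIfAbsent(brokerId, id -> new HashSet<>()).add(key);
    }

    void unmap(String key) {
        Integer previous = keyToBroker.remove(key);
        if (previous != null) {
            Set<String> keys = brokerToKeys.get(previous);
            keys.remove(key);
            if (keys.isEmpty())
                brokerToKeys.remove(previous);
        }
    }

    public static void main(String[] args) {
        ScopeBookkeepingSketch sketch = new ScopeBookkeepingSketch();
        sketch.map("txn-1", 0);
        sketch.map("txn-2", 0);
        sketch.map("txn-1", 1); // leader changed: txn-1 now belongs to broker 1 only
        System.out.println(sketch.brokerToKeys); // e.g. {0=[txn-2], 1=[txn-1]}
    }
}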
+ * + * @param errors the failed keys with their respective errors + */ + void completeExceptionally(Map errors); + + static SimpleAdminApiFuture forKeys(Set keys) { + return new SimpleAdminApiFuture<>(keys); + } + + /** + * This class can be used when the set of keys is known ahead of time. + */ + class SimpleAdminApiFuture implements AdminApiFuture { + private final Map> futures; + + public SimpleAdminApiFuture(Set keys) { + this.futures = keys.stream().collect(Collectors.toMap( + Function.identity(), + k -> new KafkaFutureImpl<>() + )); + } + + @Override + public Set lookupKeys() { + return futures.keySet(); + } + + @Override + public void complete(Map values) { + values.forEach(this::complete); + } + + private void complete(K key, V value) { + futureOrThrow(key).complete(value); + } + + @Override + public void completeExceptionally(Map errors) { + errors.forEach(this::completeExceptionally); + } + + private void completeExceptionally(K key, Throwable t) { + futureOrThrow(key).completeExceptionally(t); + } + + private KafkaFutureImpl futureOrThrow(K key) { + // The below typecast is safe because we initialise futures using only KafkaFutureImpl. + KafkaFutureImpl future = (KafkaFutureImpl) futures.get(key); + if (future == null) { + throw new IllegalArgumentException("Attempt to complete future for " + key + + ", which was not requested"); + } else { + return future; + } + } + + public Map> all() { + return futures; + } + + public KafkaFuture get(K key) { + return futures.get(key); + } + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/internals/AdminApiHandler.java b/clients/src/main/java/org/apache/kafka/clients/admin/internals/AdminApiHandler.java new file mode 100644 index 0000000..9f8d0ac --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/internals/AdminApiHandler.java @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients.admin.internals; + +import org.apache.kafka.common.Node; +import org.apache.kafka.common.requests.AbstractRequest; +import org.apache.kafka.common.requests.AbstractResponse; + +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Set; + +public interface AdminApiHandler { + + /** + * Get a user-friendly name for the API this handler is implementing. + */ + String apiName(); + + /** + * Build the request. The set of keys are derived by {@link AdminApiDriver} + * during the lookup stage as the set of keys which all map to the same + * destination broker. 
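[Illustration, not part of the patch] A minimal sketch of the SimpleAdminApiFuture lifecycle in isolation: keys are registered up front, and the driver later completes them with values. This touches an internal helper class, so it is only meant to show the bookkeeping, not a supported public API.

import org.apache.kafka.clients.admin.internals.AdminApiFuture;
import org.apache.kafka.common.KafkaFuture;

import java.util.Collections;
import java.util.concurrent.ExecutionException;

public class SimpleAdminApiFutureSketch {
    public static void main(String[] args) throws ExecutionException, InterruptedException {
        AdminApiFuture.SimpleAdminApiFuture<String, Long> future =
            AdminApiFuture.forKeys(Collections.singleton("group-1"));
        // The driver would normally call complete() once the fulfillment response has been parsed
        future.complete(Collections.singletonMap("group-1", 42L));
        KafkaFuture<Long> result = future.get("group-1");
        System.out.println(result.get()); // 42
    }
}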
+ * + * @param brokerId the target brokerId for the request + * @param keys the set of keys that should be handled by this request + * + * @return a builder for the request containing the given keys + */ + AbstractRequest.Builder buildRequest(int brokerId, Set keys); + + /** + * Callback that is invoked when a request returns successfully. + * The handler should parse the response, check for errors, and return a + * result which indicates which keys (if any) have either been completed or + * failed with an unrecoverable error. + * + * It is also possible that the response indicates an incorrect target brokerId + * (e.g. in the case of a NotLeader error when the request is bound for a partition + * leader). In this case the key will be "unmapped" from the target brokerId + * and lookup will be retried. + * + * Note that keys which received a retriable error should be left out of the + * result. They will be retried automatically. + * + * @param broker the broker that the associated request was sent to + * @param keys the set of keys from the associated request + * @param response the response received from the broker + * + * @return result indicating key completion, failure, and unmapping + */ + ApiResult handleResponse(Node broker, Set keys, AbstractResponse response); + + /** + * Get the lookup strategy that is responsible for finding the brokerId + * which will handle each respective key. + * + * @return non-null lookup strategy + */ + AdminApiLookupStrategy lookupStrategy(); + + class ApiResult { + public final Map completedKeys; + public final Map failedKeys; + public final List unmappedKeys; + + public ApiResult( + Map completedKeys, + Map failedKeys, + List unmappedKeys + ) { + this.completedKeys = Collections.unmodifiableMap(completedKeys); + this.failedKeys = Collections.unmodifiableMap(failedKeys); + this.unmappedKeys = Collections.unmodifiableList(unmappedKeys); + } + + public static ApiResult completed(K key, V value) { + return new ApiResult<>( + Collections.singletonMap(key, value), + Collections.emptyMap(), + Collections.emptyList() + ); + } + + public static ApiResult failed(K key, Throwable t) { + return new ApiResult<>( + Collections.emptyMap(), + Collections.singletonMap(key, t), + Collections.emptyList() + ); + } + + public static ApiResult unmapped(List keys) { + return new ApiResult<>( + Collections.emptyMap(), + Collections.emptyMap(), + keys + ); + } + + public static ApiResult empty() { + return new ApiResult<>( + Collections.emptyMap(), + Collections.emptyMap(), + Collections.emptyList() + ); + } + } + +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/internals/AdminApiLookupStrategy.java b/clients/src/main/java/org/apache/kafka/clients/admin/internals/AdminApiLookupStrategy.java new file mode 100644 index 0000000..56c0837 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/internals/AdminApiLookupStrategy.java @@ -0,0 +1,128 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
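[Illustration, not part of the patch] A small sketch of the three outcomes a handler can report from handleResponse() using the ApiResult factories defined above: completed keys, fatally failed keys, and keys sent back to the Lookup stage. Again, this exercises an internal helper purely for illustration.

import org.apache.kafka.clients.admin.internals.AdminApiHandler.ApiResult;

import java.util.Collections;

public class ApiResultSketch {
    public static void main(String[] args) {
        ApiResult<String, Long> done = ApiResult.completed("key-a", 1L);
        ApiResult<String, Long> failed = ApiResult.failed("key-b", new IllegalStateException("fatal"));
        ApiResult<String, Long> retryLookup = ApiResult.unmapped(Collections.singletonList("key-c"));
        System.out.println(done.completedKeys);        // {key-a=1}
        System.out.println(failed.failedKeys.keySet()); // [key-b]
        System.out.println(retryLookup.unmappedKeys);   // [key-c]
    }
}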
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients.admin.internals; + +import org.apache.kafka.common.requests.AbstractRequest; +import org.apache.kafka.common.requests.AbstractResponse; + +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import static java.util.Collections.emptyMap; +import static java.util.Collections.singletonMap; + +public interface AdminApiLookupStrategy { + + /** + * Define the scope of a given key for lookup. Key lookups are complicated + * by the need to accommodate different batching mechanics. For example, + * a `Metadata` request supports arbitrary batching of topic partitions in + * order to discover partitions leaders. This can be supported by returning + * a single scope object for all keys. + * + * On the other hand, `FindCoordinator` requests only support lookup of a + * single key. This can be supported by returning a different scope object + * for each lookup key. + * + * Note that if the {@link ApiRequestScope#destinationBrokerId()} maps to + * a specific brokerId, then lookup will be skipped. See the use of + * {@link StaticBrokerStrategy} in {@link DescribeProducersHandler} for + * an example of this usage. + * + * @param key the lookup key + * + * @return request scope indicating how lookup requests can be batched together + */ + ApiRequestScope lookupScope(T key); + + /** + * Build the lookup request for a set of keys. The grouping of the keys is controlled + * through {@link #lookupScope(Object)}. In other words, each set of keys that map + * to the same request scope object will be sent to this method. + * + * @param keys the set of keys that require lookup + * + * @return a builder for the lookup request + */ + AbstractRequest.Builder buildRequest(Set keys); + + /** + * Callback that is invoked when a lookup request returns successfully. The handler + * should parse the response, check for errors, and return a result indicating + * which keys were mapped to a brokerId successfully and which keys received + * a fatal error (e.g. a topic authorization failure). + * + * Note that keys which receive a retriable error should be left out of the + * result. They will be retried automatically. For example, if the response of + * `FindCoordinator` request indicates an unavailable coordinator, then the key + * should be left out of the result so that the request will be retried. + * + * @param keys the set of keys from the associated request + * @param response the response received from the broker + * + * @return a result indicating which keys mapped successfully to a brokerId and + * which encountered a fatal error + */ + LookupResult handleResponse(Set keys, AbstractResponse response); + + class LookupResult { + // This is the set of keys that have been completed by the lookup phase itself. + // The driver will not attempt lookup or fulfillment for completed keys. + public final List completedKeys; + + // This is the set of keys that have been mapped to a specific broker for + // fulfillment of the API request. 
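The lookup stage has an analogous three-way outcome, expressed by `LookupResult`: keys mapped to a broker id, keys failed with a fatal error, and (rarely) keys completed by the lookup itself; keys with retriable errors are left out so the lookup request is retried. A sketch using the public constructors, with the elided generic parameter assumed to be the key type:

```java
import java.util.Collections;

import org.apache.kafka.clients.admin.internals.AdminApiLookupStrategy.LookupResult;
import org.apache.kafka.common.errors.GroupAuthorizationException;

public class LookupResultSketch {
    public static void main(String[] args) {
        // A key whose coordinator was found is mapped to that broker id (here, broker 3).
        LookupResult<String> found =
            new LookupResult<>(Collections.emptyMap(), Collections.singletonMap("group-a", 3));

        // A key that failed authorization is failed permanently; its future
        // completes exceptionally and the driver does not retry it.
        LookupResult<String> denied = new LookupResult<>(
            Collections.singletonMap("group-b", new GroupAuthorizationException("group-b")),
            Collections.emptyMap());

        // Keys with retriable errors appear in neither map, so the driver
        // retries the lookup for them after a backoff.
        System.out.println(found.mappedKeys);  // {group-a=3}
        System.out.println(denied.failedKeys.keySet()); // [group-b]
    }
}
```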
+ public final Map mappedKeys; + + // This is the set of keys that have encountered a fatal error during the lookup + // phase. The driver will not attempt lookup or fulfillment for failed keys. + public final Map failedKeys; + + public LookupResult( + Map failedKeys, + Map mappedKeys + ) { + this(Collections.emptyList(), failedKeys, mappedKeys); + } + + public LookupResult( + List completedKeys, + Map failedKeys, + Map mappedKeys + ) { + this.completedKeys = Collections.unmodifiableList(completedKeys); + this.failedKeys = Collections.unmodifiableMap(failedKeys); + this.mappedKeys = Collections.unmodifiableMap(mappedKeys); + } + + static LookupResult empty() { + return new LookupResult<>(emptyMap(), emptyMap()); + } + + static LookupResult failed(K key, Throwable exception) { + return new LookupResult<>(singletonMap(key, exception), emptyMap()); + } + + static LookupResult mapped(K key, Integer brokerId) { + return new LookupResult<>(emptyMap(), singletonMap(key, brokerId)); + } + + } + +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/internals/AdminMetadataManager.java b/clients/src/main/java/org/apache/kafka/clients/admin/internals/AdminMetadataManager.java new file mode 100644 index 0000000..f1ccd9a --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/internals/AdminMetadataManager.java @@ -0,0 +1,258 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.clients.admin.internals; + +import org.apache.kafka.clients.MetadataUpdater; +import org.apache.kafka.common.Cluster; +import org.apache.kafka.common.KafkaException; +import org.apache.kafka.common.Node; +import org.apache.kafka.common.errors.AuthenticationException; +import org.apache.kafka.common.requests.MetadataResponse; +import org.apache.kafka.common.requests.RequestHeader; +import org.apache.kafka.common.utils.LogContext; +import org.slf4j.Logger; + +import java.util.Collections; +import java.util.List; +import java.util.Optional; + +/** + * Manages the metadata for KafkaAdminClient. + * + * This class is not thread-safe. It is only accessed from the AdminClient + * service thread (which also uses the NetworkClient). + */ +public class AdminMetadataManager { + private final Logger log; + + /** + * The minimum amount of time that we should wait between subsequent + * retries, when fetching metadata. + */ + private final long refreshBackoffMs; + + /** + * The minimum amount of time that we should wait before triggering an + * automatic metadata refresh. + */ + private final long metadataExpireMs; + + /** + * Used to update the NetworkClient metadata. + */ + private final AdminMetadataUpdater updater; + + /** + * The current metadata state. 
+ */ + private State state = State.QUIESCENT; + + /** + * The time in wall-clock milliseconds when we last updated the metadata. + */ + private long lastMetadataUpdateMs = 0; + + /** + * The time in wall-clock milliseconds when we last attempted to fetch new + * metadata. + */ + private long lastMetadataFetchAttemptMs = 0; + + /** + * The current cluster information. + */ + private Cluster cluster = Cluster.empty(); + + /** + * If we got an authorization exception when we last attempted to fetch + * metadata, this is it; null, otherwise. + */ + private AuthenticationException authException = null; + + public class AdminMetadataUpdater implements MetadataUpdater { + @Override + public List fetchNodes() { + return cluster.nodes(); + } + + @Override + public boolean isUpdateDue(long now) { + return false; + } + + @Override + public long maybeUpdate(long now) { + return Long.MAX_VALUE; + } + + @Override + public void handleServerDisconnect(long now, String destinationId, Optional maybeFatalException) { + maybeFatalException.ifPresent(AdminMetadataManager.this::updateFailed); + AdminMetadataManager.this.requestUpdate(); + } + + @Override + public void handleFailedRequest(long now, Optional maybeFatalException) { + // Do nothing + } + + @Override + public void handleSuccessfulResponse(RequestHeader requestHeader, long now, MetadataResponse metadataResponse) { + // Do nothing + } + + @Override + public void close() { + } + } + + /** + * The current AdminMetadataManager state. + */ + enum State { + QUIESCENT, + UPDATE_REQUESTED, + UPDATE_PENDING + } + + public AdminMetadataManager(LogContext logContext, long refreshBackoffMs, long metadataExpireMs) { + this.log = logContext.logger(AdminMetadataManager.class); + this.refreshBackoffMs = refreshBackoffMs; + this.metadataExpireMs = metadataExpireMs; + this.updater = new AdminMetadataUpdater(); + } + + public AdminMetadataUpdater updater() { + return updater; + } + + public boolean isReady() { + if (authException != null) { + log.debug("Metadata is not usable: failed to get metadata.", authException); + throw authException; + } + if (cluster.nodes().isEmpty()) { + log.trace("Metadata is not ready: bootstrap nodes have not been " + + "initialized yet."); + return false; + } + if (cluster.isBootstrapConfigured()) { + log.trace("Metadata is not ready: we have not fetched metadata from " + + "the bootstrap nodes yet."); + return false; + } + log.trace("Metadata is ready to use."); + return true; + } + + public Node controller() { + return cluster.controller(); + } + + public Node nodeById(int nodeId) { + return cluster.nodeById(nodeId); + } + + public void requestUpdate() { + if (state == State.QUIESCENT) { + state = State.UPDATE_REQUESTED; + log.debug("Requesting metadata update."); + } + } + + public void clearController() { + if (cluster.controller() != null) { + log.trace("Clearing cached controller node {}.", cluster.controller()); + this.cluster = new Cluster(cluster.clusterResource().clusterId(), + cluster.nodes(), + Collections.emptySet(), + Collections.emptySet(), + Collections.emptySet(), + null); + } + } + + /** + * Determine if the AdminClient should fetch new metadata. + */ + public long metadataFetchDelayMs(long now) { + switch (state) { + case QUIESCENT: + // Calculate the time remaining until the next periodic update. + // We want to avoid making many metadata requests in a short amount of time, + // so there is a metadata refresh backoff period. 
+ return Math.max(delayBeforeNextAttemptMs(now), delayBeforeNextExpireMs(now)); + case UPDATE_REQUESTED: + // Respect the backoff, even if an update has been requested + return delayBeforeNextAttemptMs(now); + default: + // An update is already pending, so we don't need to initiate another one. + return Long.MAX_VALUE; + } + } + + private long delayBeforeNextExpireMs(long now) { + long timeSinceUpdate = now - lastMetadataUpdateMs; + return Math.max(0, metadataExpireMs - timeSinceUpdate); + } + + private long delayBeforeNextAttemptMs(long now) { + long timeSinceAttempt = now - lastMetadataFetchAttemptMs; + return Math.max(0, refreshBackoffMs - timeSinceAttempt); + } + + /** + * Transition into the UPDATE_PENDING state. Updates lastMetadataFetchAttemptMs. + */ + public void transitionToUpdatePending(long now) { + this.state = State.UPDATE_PENDING; + this.lastMetadataFetchAttemptMs = now; + } + + public void updateFailed(Throwable exception) { + // We depend on pending calls to request another metadata update + this.state = State.QUIESCENT; + + if (exception instanceof AuthenticationException) { + log.warn("Metadata update failed due to authentication error", exception); + this.authException = (AuthenticationException) exception; + } else { + log.info("Metadata update failed", exception); + } + } + + /** + * Receive new metadata, and transition into the QUIESCENT state. + * Updates lastMetadataUpdateMs, cluster, and authException. + */ + public void update(Cluster cluster, long now) { + if (cluster.isBootstrapConfigured()) { + log.debug("Setting bootstrap cluster metadata {}.", cluster); + } else { + log.debug("Updating cluster metadata to {}", cluster); + this.lastMetadataUpdateMs = now; + } + + this.state = State.QUIESCENT; + this.authException = null; + + if (!cluster.nodes().isEmpty()) { + this.cluster = cluster; + } + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/internals/AllBrokersStrategy.java b/clients/src/main/java/org/apache/kafka/clients/admin/internals/AllBrokersStrategy.java new file mode 100644 index 0000000..56305db --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/internals/AllBrokersStrategy.java @@ -0,0 +1,212 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
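`AdminMetadataManager` is a small state machine (QUIESCENT, UPDATE_REQUESTED, UPDATE_PENDING) whose fetch delay is bounded below by the refresh backoff and, when quiescent, stretched up to the expiry interval. A sketch with hypothetical timings (100 ms backoff, 60 s expiry):

```java
import org.apache.kafka.clients.admin.internals.AdminMetadataManager;
import org.apache.kafka.common.utils.LogContext;

public class AdminMetadataManagerSketch {
    public static void main(String[] args) {
        // Hypothetical timings: 100 ms refresh backoff, 60 s metadata expiry.
        AdminMetadataManager manager = new AdminMetadataManager(new LogContext(), 100, 60_000);
        long now = System.currentTimeMillis();

        // QUIESCENT: delay is the larger of the backoff and the expiry countdown
        // (0 here, since no fetch has been attempted or completed yet).
        System.out.println(manager.metadataFetchDelayMs(now));

        // UPDATE_REQUESTED: only the refresh backoff is respected.
        manager.requestUpdate();
        System.out.println(manager.metadataFetchDelayMs(now));

        // UPDATE_PENDING: a fetch is in flight, so no further fetch is scheduled.
        manager.transitionToUpdatePending(now);
        System.out.println(manager.metadataFetchDelayMs(now)); // Long.MAX_VALUE
    }
}
```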
+ */ +package org.apache.kafka.clients.admin.internals; + +import org.apache.kafka.common.internals.KafkaFutureImpl; +import org.apache.kafka.common.message.MetadataRequestData; +import org.apache.kafka.common.message.MetadataResponseData; +import org.apache.kafka.common.requests.AbstractResponse; +import org.apache.kafka.common.requests.MetadataRequest; +import org.apache.kafka.common.requests.MetadataResponse; +import org.apache.kafka.common.utils.LogContext; +import org.slf4j.Logger; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; +import java.util.OptionalInt; +import java.util.Set; +import java.util.stream.Collectors; + +/** + * This class is used for use cases which require requests to be sent to all + * brokers in the cluster. + * + * This is a slightly degenerate case of a lookup strategy in the sense that + * the broker IDs are used as both the keys and values. Also, unlike + * {@link CoordinatorStrategy} and {@link PartitionLeaderStrategy}, we do not + * know the set of keys ahead of time: we require the initial lookup in order + * to discover what the broker IDs are. This is represented with a more complex + * type {@code Future>} in the admin API result type. + * For example, see {@link org.apache.kafka.clients.admin.ListTransactionsResult}. + */ +public class AllBrokersStrategy implements AdminApiLookupStrategy { + public static final BrokerKey ANY_BROKER = new BrokerKey(OptionalInt.empty()); + public static final Set LOOKUP_KEYS = Collections.singleton(ANY_BROKER); + private static final ApiRequestScope SINGLE_REQUEST_SCOPE = new ApiRequestScope() { + }; + + private final Logger log; + + public AllBrokersStrategy( + LogContext logContext + ) { + this.log = logContext.logger(AllBrokersStrategy.class); + } + + @Override + public ApiRequestScope lookupScope(BrokerKey key) { + return SINGLE_REQUEST_SCOPE; + } + + @Override + public MetadataRequest.Builder buildRequest(Set keys) { + validateLookupKeys(keys); + // Send empty `Metadata` request. We are only interested in the brokers from the response + return new MetadataRequest.Builder(new MetadataRequestData()); + } + + @Override + public LookupResult handleResponse(Set keys, AbstractResponse abstractResponse) { + validateLookupKeys(keys); + + MetadataResponse response = (MetadataResponse) abstractResponse; + MetadataResponseData.MetadataResponseBrokerCollection brokers = response.data().brokers(); + + if (brokers.isEmpty()) { + log.debug("Metadata response contained no brokers. 
Will backoff and retry"); + return LookupResult.empty(); + } else { + log.debug("Discovered all brokers {} to send requests to", brokers); + } + + Map brokerKeys = brokers.stream().collect(Collectors.toMap( + broker -> new BrokerKey(OptionalInt.of(broker.nodeId())), + MetadataResponseData.MetadataResponseBroker::nodeId + )); + + return new LookupResult<>( + Collections.singletonList(ANY_BROKER), + Collections.emptyMap(), + brokerKeys + ); + } + + private void validateLookupKeys(Set keys) { + if (keys.size() != 1) { + throw new IllegalArgumentException("Unexpected key set: " + keys); + } + BrokerKey key = keys.iterator().next(); + if (key != ANY_BROKER) { + throw new IllegalArgumentException("Unexpected key set: " + keys); + } + } + + public static class BrokerKey { + public final OptionalInt brokerId; + + public BrokerKey(OptionalInt brokerId) { + this.brokerId = brokerId; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + BrokerKey that = (BrokerKey) o; + return Objects.equals(brokerId, that.brokerId); + } + + @Override + public int hashCode() { + return Objects.hash(brokerId); + } + + @Override + public String toString() { + return "BrokerKey(" + + "brokerId=" + brokerId + + ')'; + } + } + + public static class AllBrokersFuture implements AdminApiFuture { + private final KafkaFutureImpl>> future = new KafkaFutureImpl<>(); + private final Map> brokerFutures = new HashMap<>(); + + @Override + public Set lookupKeys() { + return LOOKUP_KEYS; + } + + @Override + public void completeLookup(Map brokerMapping) { + brokerMapping.forEach((brokerKey, brokerId) -> { + if (brokerId != brokerKey.brokerId.orElse(-1)) { + throw new IllegalArgumentException("Invalid lookup mapping " + brokerKey + " -> " + brokerId); + } + brokerFutures.put(brokerId, new KafkaFutureImpl<>()); + }); + future.complete(brokerFutures); + } + + @Override + public void completeLookupExceptionally(Map lookupErrors) { + if (!LOOKUP_KEYS.equals(lookupErrors.keySet())) { + throw new IllegalArgumentException("Unexpected keys among lookup errors: " + lookupErrors); + } + future.completeExceptionally(lookupErrors.get(ANY_BROKER)); + } + + @Override + public void complete(Map values) { + values.forEach(this::complete); + } + + private void complete(AllBrokersStrategy.BrokerKey key, V value) { + if (key == ANY_BROKER) { + throw new IllegalArgumentException("Invalid attempt to complete with lookup key sentinel"); + } else { + futureOrThrow(key).complete(value); + } + } + + @Override + public void completeExceptionally(Map errors) { + errors.forEach(this::completeExceptionally); + } + + private void completeExceptionally(AllBrokersStrategy.BrokerKey key, Throwable t) { + if (key == ANY_BROKER) { + future.completeExceptionally(t); + } else { + futureOrThrow(key).completeExceptionally(t); + } + } + + public KafkaFutureImpl>> all() { + return future; + } + + private KafkaFutureImpl futureOrThrow(BrokerKey key) { + if (!key.brokerId.isPresent()) { + throw new IllegalArgumentException("Attempt to complete with invalid key: " + key); + } else { + int brokerId = key.brokerId.getAsInt(); + KafkaFutureImpl future = brokerFutures.get(brokerId); + if (future == null) { + throw new IllegalArgumentException("Attempt to complete with unknown broker id: " + brokerId); + } else { + return future; + } + } + } + + } + +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/internals/AlterConsumerGroupOffsetsHandler.java 
b/clients/src/main/java/org/apache/kafka/clients/admin/internals/AlterConsumerGroupOffsetsHandler.java new file mode 100644 index 0000000..cb7551e --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/internals/AlterConsumerGroupOffsetsHandler.java @@ -0,0 +1,203 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients.admin.internals; + +import static java.util.Collections.singleton; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; + +import org.apache.kafka.clients.consumer.OffsetAndMetadata; +import org.apache.kafka.common.Node; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.message.OffsetCommitRequestData; +import org.apache.kafka.common.message.OffsetCommitRequestData.OffsetCommitRequestPartition; +import org.apache.kafka.common.message.OffsetCommitRequestData.OffsetCommitRequestTopic; +import org.apache.kafka.common.message.OffsetCommitResponseData.OffsetCommitResponsePartition; +import org.apache.kafka.common.message.OffsetCommitResponseData.OffsetCommitResponseTopic; +import org.apache.kafka.common.protocol.Errors; +import org.apache.kafka.common.requests.AbstractResponse; +import org.apache.kafka.common.requests.OffsetCommitRequest; +import org.apache.kafka.common.requests.OffsetCommitResponse; +import org.apache.kafka.common.requests.FindCoordinatorRequest.CoordinatorType; +import org.apache.kafka.common.utils.LogContext; +import org.slf4j.Logger; + +public class AlterConsumerGroupOffsetsHandler implements AdminApiHandler> { + + private final CoordinatorKey groupId; + private final Map offsets; + private final Logger log; + private final AdminApiLookupStrategy lookupStrategy; + + public AlterConsumerGroupOffsetsHandler( + String groupId, + Map offsets, + LogContext logContext + ) { + this.groupId = CoordinatorKey.byGroupId(groupId); + this.offsets = offsets; + this.log = logContext.logger(AlterConsumerGroupOffsetsHandler.class); + this.lookupStrategy = new CoordinatorStrategy(CoordinatorType.GROUP, logContext); + } + + @Override + public String apiName() { + return "offsetCommit"; + } + + @Override + public AdminApiLookupStrategy lookupStrategy() { + return lookupStrategy; + } + + public static AdminApiFuture.SimpleAdminApiFuture> newFuture( + String groupId + ) { + return AdminApiFuture.forKeys(Collections.singleton(CoordinatorKey.byGroupId(groupId))); + } + + private void validateKeys(Set groupIds) { + if (!groupIds.equals(singleton(groupId))) { + throw new IllegalArgumentException("Received unexpected group ids " + groupIds + + " (expected only " + singleton(groupId) + ")"); + } + } + + @Override + public OffsetCommitRequest.Builder 
buildRequest( + int coordinatorId, + Set groupIds + ) { + validateKeys(groupIds); + + Map offsetData = new HashMap<>(); + offsets.forEach((topicPartition, offsetAndMetadata) -> { + OffsetCommitRequestTopic topic = offsetData.computeIfAbsent( + topicPartition.topic(), + key -> new OffsetCommitRequestTopic().setName(topicPartition.topic()) + ); + + topic.partitions().add(new OffsetCommitRequestPartition() + .setCommittedOffset(offsetAndMetadata.offset()) + .setCommittedLeaderEpoch(offsetAndMetadata.leaderEpoch().orElse(-1)) + .setCommittedMetadata(offsetAndMetadata.metadata()) + .setPartitionIndex(topicPartition.partition())); + }); + + OffsetCommitRequestData data = new OffsetCommitRequestData() + .setGroupId(groupId.idValue) + .setTopics(new ArrayList<>(offsetData.values())); + + return new OffsetCommitRequest.Builder(data); + } + + @Override + public ApiResult> handleResponse( + Node coordinator, + Set groupIds, + AbstractResponse abstractResponse + ) { + validateKeys(groupIds); + + final OffsetCommitResponse response = (OffsetCommitResponse) abstractResponse; + final Set groupsToUnmap = new HashSet<>(); + final Set groupsToRetry = new HashSet<>(); + final Map partitionResults = new HashMap<>(); + + for (OffsetCommitResponseTopic topic : response.data().topics()) { + for (OffsetCommitResponsePartition partition : topic.partitions()) { + TopicPartition topicPartition = new TopicPartition(topic.name(), partition.partitionIndex()); + Errors error = Errors.forCode(partition.errorCode()); + + if (error != Errors.NONE) { + handleError( + groupId, + topicPartition, + error, + partitionResults, + groupsToUnmap, + groupsToRetry + ); + } else { + partitionResults.put(topicPartition, error); + } + } + } + + if (groupsToUnmap.isEmpty() && groupsToRetry.isEmpty()) { + return ApiResult.completed(groupId, partitionResults); + } else { + return ApiResult.unmapped(new ArrayList<>(groupsToUnmap)); + } + } + + private void handleError( + CoordinatorKey groupId, + TopicPartition topicPartition, + Errors error, + Map partitionResults, + Set groupsToUnmap, + Set groupsToRetry + ) { + switch (error) { + // If the coordinator is in the middle of loading, or rebalance is in progress, then we just need to retry. + case COORDINATOR_LOAD_IN_PROGRESS: + case REBALANCE_IN_PROGRESS: + log.debug("OffsetCommit request for group id {} returned error {}. Will retry.", + groupId.idValue, error); + groupsToRetry.add(groupId); + break; + + // If the coordinator is not available, then we unmap and retry. + case COORDINATOR_NOT_AVAILABLE: + case NOT_COORDINATOR: + log.debug("OffsetCommit request for group id {} returned error {}. Will rediscover the coordinator and retry.", + groupId.idValue, error); + groupsToUnmap.add(groupId); + break; + + // Group level errors. + case INVALID_GROUP_ID: + case INVALID_COMMIT_OFFSET_SIZE: + case GROUP_AUTHORIZATION_FAILED: + log.debug("OffsetCommit request for group id {} failed due to error {}.", + groupId.idValue, error); + partitionResults.put(topicPartition, error); + break; + + // TopicPartition level errors. + case UNKNOWN_TOPIC_OR_PARTITION: + case OFFSET_METADATA_TOO_LARGE: + case TOPIC_AUTHORIZATION_FAILED: + log.debug("OffsetCommit request for group id {} and partition {} failed due" + + " to error {}.", groupId.idValue, topicPartition, error); + partitionResults.put(topicPartition, error); + break; + + // Unexpected errors. 
+ default: + log.error("OffsetCommit request for group id {} and partition {} failed due" + + " to unexpected error {}.", groupId.idValue, topicPartition, error); + partitionResults.put(topicPartition, error); + } + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/internals/ApiRequestScope.java b/clients/src/main/java/org/apache/kafka/clients/admin/internals/ApiRequestScope.java new file mode 100644 index 0000000..0d27f38 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/internals/ApiRequestScope.java @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients.admin.internals; + +import java.util.OptionalInt; + +/** + * This interface is used by {@link AdminApiDriver} to bridge the gap + * to the internal `NodeProvider` defined in + * {@link org.apache.kafka.clients.admin.KafkaAdminClient}. However, a + * request scope is more than just a target broker specification. It also + * provides a way to group key lookups according to different batching + * mechanics. See {@link AdminApiLookupStrategy#lookupScope(Object)} for + * more detail. + */ +public interface ApiRequestScope { + + /** + * Get the target broker ID that a request is intended for or + * empty if the request can be sent to any broker. + * + * Note that if the destination broker ID is present in the + * {@link ApiRequestScope} returned by {@link AdminApiLookupStrategy#lookupScope(Object)}, + * then no lookup will be attempted. + * + * @return optional broker ID + */ + default OptionalInt destinationBrokerId() { + return OptionalInt.empty(); + } + +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/internals/CoordinatorKey.java b/clients/src/main/java/org/apache/kafka/clients/admin/internals/CoordinatorKey.java new file mode 100644 index 0000000..f61221b --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/internals/CoordinatorKey.java @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
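`ApiRequestScope` controls both batching and routing: scopes that compare equal are batched into one request, and a scope with a present `destinationBrokerId()` skips the lookup stage entirely. A sketch of a scope pinned to one broker, in the spirit of the `StaticBrokerStrategy` mentioned above (the class names below are hypothetical):

```java
import java.util.OptionalInt;

import org.apache.kafka.clients.admin.internals.ApiRequestScope;

public class FixedBrokerScopeSketch {

    // A scope pinned to a single broker; with a non-empty destinationBrokerId()
    // the driver sends the request straight to that broker, without any lookup.
    static class FixedBrokerScope implements ApiRequestScope {
        private final int brokerId;

        FixedBrokerScope(int brokerId) {
            this.brokerId = brokerId;
        }

        @Override
        public OptionalInt destinationBrokerId() {
            return OptionalInt.of(brokerId);
        }
    }

    public static void main(String[] args) {
        ApiRequestScope anyBroker = new ApiRequestScope() { }; // default: empty, so lookup is required
        ApiRequestScope broker1 = new FixedBrokerScope(1);

        System.out.println(anyBroker.destinationBrokerId()); // OptionalInt.empty
        System.out.println(broker1.destinationBrokerId());   // OptionalInt[1]
    }
}
```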
+ */ +package org.apache.kafka.clients.admin.internals; + +import org.apache.kafka.common.requests.FindCoordinatorRequest; + +import java.util.Objects; + +public class CoordinatorKey { + public final String idValue; + public final FindCoordinatorRequest.CoordinatorType type; + + private CoordinatorKey(FindCoordinatorRequest.CoordinatorType type, String idValue) { + this.idValue = idValue; + this.type = type; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + CoordinatorKey that = (CoordinatorKey) o; + return Objects.equals(idValue, that.idValue) && + type == that.type; + } + + @Override + public int hashCode() { + return Objects.hash(idValue, type); + } + + @Override + public String toString() { + return "CoordinatorKey(" + + "idValue='" + idValue + '\'' + + ", type=" + type + + ')'; + } + + public static CoordinatorKey byGroupId(String groupId) { + return new CoordinatorKey(FindCoordinatorRequest.CoordinatorType.GROUP, groupId); + } + + public static CoordinatorKey byTransactionalId(String transactionalId) { + return new CoordinatorKey(FindCoordinatorRequest.CoordinatorType.TRANSACTION, transactionalId); + } + +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/internals/CoordinatorStrategy.java b/clients/src/main/java/org/apache/kafka/clients/admin/internals/CoordinatorStrategy.java new file mode 100644 index 0000000..e6fc0d6 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/internals/CoordinatorStrategy.java @@ -0,0 +1,191 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
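`CoordinatorKey` is a small value type pairing an id with a coordinator type; equality includes the type, so a group id and a transactional id with the same string never collide in lookup maps. For example:

```java
import org.apache.kafka.clients.admin.internals.CoordinatorKey;

public class CoordinatorKeySketch {
    public static void main(String[] args) {
        CoordinatorKey group = CoordinatorKey.byGroupId("my-id");
        CoordinatorKey txn = CoordinatorKey.byTransactionalId("my-id");

        // Same id string, but different coordinator types, so the keys are distinct.
        System.out.println(group.equals(txn));                               // false
        System.out.println(group.equals(CoordinatorKey.byGroupId("my-id"))); // true
        System.out.println(group); // CoordinatorKey(idValue='my-id', type=GROUP)
    }
}
```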
+ */ +package org.apache.kafka.clients.admin.internals; + +import org.apache.kafka.common.errors.GroupAuthorizationException; +import org.apache.kafka.common.errors.InvalidGroupIdException; +import org.apache.kafka.common.errors.TransactionalIdAuthorizationException; +import org.apache.kafka.common.message.FindCoordinatorRequestData; +import org.apache.kafka.common.message.FindCoordinatorResponseData.Coordinator; +import org.apache.kafka.common.protocol.Errors; +import org.apache.kafka.common.requests.AbstractResponse; +import org.apache.kafka.common.requests.FindCoordinatorRequest; +import org.apache.kafka.common.requests.FindCoordinatorRequest.CoordinatorType; +import org.apache.kafka.common.requests.FindCoordinatorResponse; +import org.apache.kafka.common.utils.LogContext; +import org.slf4j.Logger; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; +import java.util.Set; +import java.util.stream.Collectors; + +public class CoordinatorStrategy implements AdminApiLookupStrategy { + + private static final ApiRequestScope BATCH_REQUEST_SCOPE = new ApiRequestScope() { }; + + private final Logger log; + private final FindCoordinatorRequest.CoordinatorType type; + private Set unrepresentableKeys = Collections.emptySet(); + + boolean batch = true; + + public CoordinatorStrategy( + FindCoordinatorRequest.CoordinatorType type, + LogContext logContext + ) { + this.type = type; + this.log = logContext.logger(CoordinatorStrategy.class); + } + + @Override + public ApiRequestScope lookupScope(CoordinatorKey key) { + if (batch) { + return BATCH_REQUEST_SCOPE; + } else { + // If the `FindCoordinator` API does not support batched lookups, we use a + // separate lookup context for each coordinator key we need to lookup + return new LookupRequestScope(key); + } + } + + @Override + public FindCoordinatorRequest.Builder buildRequest(Set keys) { + unrepresentableKeys = keys.stream().filter(k -> k == null || !isRepresentableKey(k.idValue)).collect(Collectors.toSet()); + Set representableKeys = keys.stream().filter(k -> k != null && isRepresentableKey(k.idValue)).collect(Collectors.toSet()); + if (batch) { + ensureSameType(representableKeys); + FindCoordinatorRequestData data = new FindCoordinatorRequestData() + .setKeyType(type.id()) + .setCoordinatorKeys(representableKeys.stream().map(k -> k.idValue).collect(Collectors.toList())); + return new FindCoordinatorRequest.Builder(data); + } else { + CoordinatorKey key = requireSingletonAndType(representableKeys); + return new FindCoordinatorRequest.Builder( + new FindCoordinatorRequestData() + .setKey(key.idValue) + .setKeyType(key.type.id()) + ); + } + } + + @Override + public LookupResult handleResponse( + Set keys, + AbstractResponse abstractResponse + ) { + Map mappedKeys = new HashMap<>(); + Map failedKeys = new HashMap<>(); + + for (CoordinatorKey key : unrepresentableKeys) { + failedKeys.put(key, new InvalidGroupIdException("The given group id '" + + key.idValue + "' cannot be represented in a request.")); + } + + for (Coordinator coordinator : ((FindCoordinatorResponse) abstractResponse).coordinators()) { + CoordinatorKey key; + if (coordinator.key() == null) // old version without batching + key = requireSingletonAndType(keys); + else { + key = (type == CoordinatorType.GROUP) + ? 
CoordinatorKey.byGroupId(coordinator.key()) + : CoordinatorKey.byTransactionalId(coordinator.key()); + } + handleError(Errors.forCode(coordinator.errorCode()), + key, + coordinator.nodeId(), + mappedKeys, + failedKeys); + } + return new LookupResult<>(failedKeys, mappedKeys); + } + + public void disableBatch() { + batch = false; + } + + private CoordinatorKey requireSingletonAndType(Set keys) { + if (keys.size() != 1) { + throw new IllegalArgumentException("Unexpected size of key set: expected 1, but got " + keys.size()); + } + CoordinatorKey key = keys.iterator().next(); + if (key.type != type) { + throw new IllegalArgumentException("Unexpected key type: expected key to be of type " + type + ", but got " + key.type); + } + return key; + } + + private void ensureSameType(Set keys) { + if (keys.size() < 1) { + throw new IllegalArgumentException("Unexpected size of key set: expected >= 1, but got " + keys.size()); + } + if (keys.stream().filter(k -> k.type == type).collect(Collectors.toSet()).size() != keys.size()) { + throw new IllegalArgumentException("Unexpected key set: expected all key to be of type " + type + ", but some key were not"); + } + } + + private static boolean isRepresentableKey(String groupId) { + return groupId != null; + } + + private void handleError(Errors error, CoordinatorKey key, int nodeId, Map mappedKeys, Map failedKeys) { + switch (error) { + case NONE: + mappedKeys.put(key, nodeId); + break; + case COORDINATOR_NOT_AVAILABLE: + case COORDINATOR_LOAD_IN_PROGRESS: + log.debug("FindCoordinator request for key {} returned topic-level error {}. Will retry", + key, error); + break; + case GROUP_AUTHORIZATION_FAILED: + failedKeys.put(key, new GroupAuthorizationException("FindCoordinator request for groupId " + + "`" + key + "` failed due to authorization failure", key.idValue)); + break; + case TRANSACTIONAL_ID_AUTHORIZATION_FAILED: + failedKeys.put(key, new TransactionalIdAuthorizationException("FindCoordinator request for " + + "transactionalId `" + key + "` failed due to authorization failure")); + break; + default: + failedKeys.put(key, error.exception("FindCoordinator request for key " + + "`" + key + "` failed due to an unexpected error")); + } + } + + private static class LookupRequestScope implements ApiRequestScope { + final CoordinatorKey key; + + private LookupRequestScope(CoordinatorKey key) { + this.key = key; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + LookupRequestScope that = (LookupRequestScope) o; + return Objects.equals(key, that.key); + } + + @Override + public int hashCode() { + return Objects.hash(key); + } + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/internals/DeleteConsumerGroupOffsetsHandler.java b/clients/src/main/java/org/apache/kafka/clients/admin/internals/DeleteConsumerGroupOffsetsHandler.java new file mode 100644 index 0000000..a853edd --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/internals/DeleteConsumerGroupOffsetsHandler.java @@ -0,0 +1,173 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
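The batching behaviour of `CoordinatorStrategy` is driven entirely by the request scope: while batching is enabled every key shares one scope, and `disableBatch()` (used when the broker only supports the older single-key `FindCoordinator` version) gives each key its own scope. A sketch with hypothetical group ids:

```java
import org.apache.kafka.clients.admin.internals.CoordinatorKey;
import org.apache.kafka.clients.admin.internals.CoordinatorStrategy;
import org.apache.kafka.common.requests.FindCoordinatorRequest.CoordinatorType;
import org.apache.kafka.common.utils.LogContext;

public class CoordinatorStrategySketch {
    public static void main(String[] args) {
        CoordinatorStrategy strategy =
            new CoordinatorStrategy(CoordinatorType.GROUP, new LogContext());

        CoordinatorKey a = CoordinatorKey.byGroupId("group-a");
        CoordinatorKey b = CoordinatorKey.byGroupId("group-b");

        // With batching (the default), every key shares the same request scope,
        // so one FindCoordinator request can cover all group ids.
        System.out.println(strategy.lookupScope(a) == strategy.lookupScope(b)); // true

        // Without batching, each key gets its own scope and therefore its own request.
        strategy.disableBatch();
        System.out.println(strategy.lookupScope(a).equals(strategy.lookupScope(b))); // false
    }
}
```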
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients.admin.internals; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; + +import org.apache.kafka.common.Node; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.message.OffsetDeleteRequestData; +import org.apache.kafka.common.message.OffsetDeleteRequestData.OffsetDeleteRequestPartition; +import org.apache.kafka.common.message.OffsetDeleteRequestData.OffsetDeleteRequestTopic; +import org.apache.kafka.common.message.OffsetDeleteRequestData.OffsetDeleteRequestTopicCollection; +import org.apache.kafka.common.protocol.Errors; +import org.apache.kafka.common.requests.AbstractResponse; +import org.apache.kafka.common.requests.OffsetDeleteRequest; +import org.apache.kafka.common.requests.OffsetDeleteResponse; +import org.apache.kafka.common.requests.FindCoordinatorRequest.CoordinatorType; +import org.apache.kafka.common.utils.LogContext; +import org.slf4j.Logger; + +public class DeleteConsumerGroupOffsetsHandler implements AdminApiHandler> { + + private final CoordinatorKey groupId; + private final Set partitions; + private final Logger log; + private final AdminApiLookupStrategy lookupStrategy; + + public DeleteConsumerGroupOffsetsHandler( + String groupId, + Set partitions, + LogContext logContext + ) { + this.groupId = CoordinatorKey.byGroupId(groupId); + this.partitions = partitions; + this.log = logContext.logger(DeleteConsumerGroupOffsetsHandler.class); + this.lookupStrategy = new CoordinatorStrategy(CoordinatorType.GROUP, logContext); + } + + @Override + public String apiName() { + return "offsetDelete"; + } + + @Override + public AdminApiLookupStrategy lookupStrategy() { + return lookupStrategy; + } + + public static AdminApiFuture.SimpleAdminApiFuture> newFuture( + String groupId + ) { + return AdminApiFuture.forKeys(Collections.singleton(CoordinatorKey.byGroupId(groupId))); + } + + private void validateKeys(Set groupIds) { + if (!groupIds.equals(Collections.singleton(groupId))) { + throw new IllegalArgumentException("Received unexpected group ids " + groupIds + + " (expected only " + Collections.singleton(groupId) + ")"); + } + } + + @Override + public OffsetDeleteRequest.Builder buildRequest(int coordinatorId, Set groupIds) { + validateKeys(groupIds); + + final OffsetDeleteRequestTopicCollection topics = new OffsetDeleteRequestTopicCollection(); + partitions.stream().collect(Collectors.groupingBy(TopicPartition::topic)).forEach((topic, topicPartitions) -> topics.add( + new OffsetDeleteRequestTopic() + .setName(topic) + .setPartitions(topicPartitions.stream() + .map(tp -> new OffsetDeleteRequestPartition().setPartitionIndex(tp.partition())) + .collect(Collectors.toList()) + ) + )); + + return new OffsetDeleteRequest.Builder( + new OffsetDeleteRequestData() + .setGroupId(groupId.idValue) + .setTopics(topics) + ); + } + + @Override + public ApiResult> handleResponse( + Node coordinator, + Set groupIds, + AbstractResponse abstractResponse + ) { + 
validateKeys(groupIds); + + final OffsetDeleteResponse response = (OffsetDeleteResponse) abstractResponse; + final Errors error = Errors.forCode(response.data().errorCode()); + + if (error != Errors.NONE) { + final Map failed = new HashMap<>(); + final Set groupsToUnmap = new HashSet<>(); + + handleGroupError(groupId, error, failed, groupsToUnmap); + + return new ApiResult<>(Collections.emptyMap(), failed, new ArrayList<>(groupsToUnmap)); + } else { + final Map partitionResults = new HashMap<>(); + response.data().topics().forEach(topic -> + topic.partitions().forEach(partition -> + partitionResults.put( + new TopicPartition(topic.name(), partition.partitionIndex()), + Errors.forCode(partition.errorCode()) + ) + ) + ); + + return ApiResult.completed(groupId, partitionResults); + } + } + + private void handleGroupError( + CoordinatorKey groupId, + Errors error, + Map failed, + Set groupsToUnmap + ) { + switch (error) { + case GROUP_AUTHORIZATION_FAILED: + case GROUP_ID_NOT_FOUND: + case INVALID_GROUP_ID: + case NON_EMPTY_GROUP: + log.debug("`OffsetDelete` request for group id {} failed due to error {}.", groupId.idValue, error); + failed.put(groupId, error.exception()); + break; + + case COORDINATOR_LOAD_IN_PROGRESS: + // If the coordinator is in the middle of loading, then we just need to retry + log.debug("`OffsetDelete` request for group id {} failed because the coordinator" + + " is still in the process of loading state. Will retry.", groupId.idValue); + break; + + case COORDINATOR_NOT_AVAILABLE: + case NOT_COORDINATOR: + // If the coordinator is unavailable or there was a coordinator change, then we unmap + // the key so that we retry the `FindCoordinator` request + log.debug("`OffsetDelete` request for group id {} returned error {}. " + + "Will attempt to find the coordinator again and retry.", groupId.idValue, error); + groupsToUnmap.add(groupId); + break; + + default: + log.error("`OffsetDelete` request for group id {} failed due to unexpected error {}.", groupId.idValue, error); + failed.put(groupId, error.exception()); + break; + } + } + +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/internals/DeleteConsumerGroupsHandler.java b/clients/src/main/java/org/apache/kafka/clients/admin/internals/DeleteConsumerGroupsHandler.java new file mode 100644 index 0000000..693d236 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/internals/DeleteConsumerGroupsHandler.java @@ -0,0 +1,145 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
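These coordinator-based handlers follow the same wiring pattern: a per-key future created by `newFuture`, a handler holding the request parameters, and a driver connecting the two via `lookupStrategy()`, `buildRequest()` and `handleResponse()`. A sketch for the offset-delete handler, assuming the elided generics are `<CoordinatorKey, Map<TopicPartition, Errors>>`; the driver invocation itself is not shown:

```java
import java.util.Collections;
import java.util.Map;
import java.util.Set;

import org.apache.kafka.clients.admin.internals.AdminApiFuture.SimpleAdminApiFuture;
import org.apache.kafka.clients.admin.internals.CoordinatorKey;
import org.apache.kafka.clients.admin.internals.DeleteConsumerGroupOffsetsHandler;
import org.apache.kafka.common.KafkaFuture;
import org.apache.kafka.common.TopicPartition;
import org.apache.kafka.common.protocol.Errors;
import org.apache.kafka.common.utils.LogContext;

public class DeleteOffsetsWiringSketch {
    public static void main(String[] args) {
        Set<TopicPartition> partitions = Collections.singleton(new TopicPartition("foo", 0));

        // The future is keyed by the group's CoordinatorKey; the handler fills it in.
        SimpleAdminApiFuture<CoordinatorKey, Map<TopicPartition, Errors>> future =
            DeleteConsumerGroupOffsetsHandler.newFuture("my-group");
        DeleteConsumerGroupOffsetsHandler handler =
            new DeleteConsumerGroupOffsetsHandler("my-group", partitions, new LogContext());
        System.out.println(handler.apiName()); // offsetDelete

        // The driver uses handler.lookupStrategy() to locate the group coordinator,
        // then handler.buildRequest(...) / handler.handleResponse(...) to complete
        // the single CoordinatorKey entry of the future.
        KafkaFuture<Map<TopicPartition, Errors>> result =
            future.get(CoordinatorKey.byGroupId("my-group"));
        System.out.println(result.isDone()); // false until the driver runs
    }
}
```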
+ */ +package org.apache.kafka.clients.admin.internals; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; + +import org.apache.kafka.common.Node; +import org.apache.kafka.common.message.DeleteGroupsRequestData; +import org.apache.kafka.common.message.DeleteGroupsResponseData.DeletableGroupResult; +import org.apache.kafka.common.protocol.Errors; +import org.apache.kafka.common.requests.AbstractResponse; +import org.apache.kafka.common.requests.DeleteGroupsRequest; +import org.apache.kafka.common.requests.DeleteGroupsResponse; +import org.apache.kafka.common.requests.FindCoordinatorRequest.CoordinatorType; +import org.apache.kafka.common.utils.LogContext; +import org.slf4j.Logger; + +public class DeleteConsumerGroupsHandler implements AdminApiHandler { + + private final Logger log; + private final AdminApiLookupStrategy lookupStrategy; + + public DeleteConsumerGroupsHandler( + LogContext logContext + ) { + this.log = logContext.logger(DeleteConsumerGroupsHandler.class); + this.lookupStrategy = new CoordinatorStrategy(CoordinatorType.GROUP, logContext); + } + + @Override + public String apiName() { + return "deleteConsumerGroups"; + } + + @Override + public AdminApiLookupStrategy lookupStrategy() { + return lookupStrategy; + } + + public static AdminApiFuture.SimpleAdminApiFuture newFuture( + Collection groupIds + ) { + return AdminApiFuture.forKeys(buildKeySet(groupIds)); + } + + private static Set buildKeySet(Collection groupIds) { + return groupIds.stream() + .map(CoordinatorKey::byGroupId) + .collect(Collectors.toSet()); + } + + @Override + public DeleteGroupsRequest.Builder buildRequest( + int coordinatorId, + Set keys + ) { + List groupIds = keys.stream().map(key -> key.idValue).collect(Collectors.toList()); + DeleteGroupsRequestData data = new DeleteGroupsRequestData() + .setGroupsNames(groupIds); + return new DeleteGroupsRequest.Builder(data); + } + + @Override + public ApiResult handleResponse( + Node coordinator, + Set groupIds, + AbstractResponse abstractResponse + ) { + final DeleteGroupsResponse response = (DeleteGroupsResponse) abstractResponse; + final Map completed = new HashMap<>(); + final Map failed = new HashMap<>(); + final Set groupsToUnmap = new HashSet<>(); + + for (DeletableGroupResult deletedGroup : response.data().results()) { + CoordinatorKey groupIdKey = CoordinatorKey.byGroupId(deletedGroup.groupId()); + Errors error = Errors.forCode(deletedGroup.errorCode()); + if (error != Errors.NONE) { + handleError(groupIdKey, error, failed, groupsToUnmap); + continue; + } + + completed.put(groupIdKey, null); + } + + return new ApiResult<>(completed, failed, new ArrayList<>(groupsToUnmap)); + } + + private void handleError( + CoordinatorKey groupId, + Errors error, + Map failed, + Set groupsToUnmap + ) { + switch (error) { + case GROUP_AUTHORIZATION_FAILED: + case INVALID_GROUP_ID: + case NON_EMPTY_GROUP: + case GROUP_ID_NOT_FOUND: + log.debug("`DeleteConsumerGroups` request for group id {} failed due to error {}", groupId.idValue, error); + failed.put(groupId, error.exception()); + break; + + case COORDINATOR_LOAD_IN_PROGRESS: + // If the coordinator is in the middle of loading, then we just need to retry + log.debug("`DeleteConsumerGroups` request for group id {} failed because the coordinator " + + "is still in the process of loading state. 
Will retry", groupId.idValue); + break; + + case COORDINATOR_NOT_AVAILABLE: + case NOT_COORDINATOR: + // If the coordinator is unavailable or there was a coordinator change, then we unmap + // the key so that we retry the `FindCoordinator` request + log.debug("`DeleteConsumerGroups` request for group id {} returned error {}. " + + "Will attempt to find the coordinator again and retry", groupId.idValue, error); + groupsToUnmap.add(groupId); + break; + + default: + log.error("`DeleteConsumerGroups` request for group id {} failed due to unexpected error {}", groupId.idValue, error); + failed.put(groupId, error.exception()); + } + } + +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/internals/DescribeConsumerGroupsHandler.java b/clients/src/main/java/org/apache/kafka/clients/admin/internals/DescribeConsumerGroupsHandler.java new file mode 100644 index 0000000..5c5022a --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/internals/DescribeConsumerGroupsHandler.java @@ -0,0 +1,207 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.clients.admin.internals; + +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.stream.Collectors; + +import org.apache.kafka.clients.admin.ConsumerGroupDescription; +import org.apache.kafka.clients.admin.MemberAssignment; +import org.apache.kafka.clients.admin.MemberDescription; +import org.apache.kafka.clients.consumer.ConsumerPartitionAssignor.Assignment; +import org.apache.kafka.clients.consumer.internals.ConsumerProtocol; +import org.apache.kafka.common.ConsumerGroupState; +import org.apache.kafka.common.Node; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.acl.AclOperation; +import org.apache.kafka.common.message.DescribeGroupsRequestData; +import org.apache.kafka.common.message.DescribeGroupsResponseData.DescribedGroup; +import org.apache.kafka.common.message.DescribeGroupsResponseData.DescribedGroupMember; +import org.apache.kafka.common.protocol.Errors; +import org.apache.kafka.common.requests.AbstractResponse; +import org.apache.kafka.common.requests.DescribeGroupsRequest; +import org.apache.kafka.common.requests.DescribeGroupsResponse; +import org.apache.kafka.common.requests.FindCoordinatorRequest; +import org.apache.kafka.common.requests.MetadataResponse; +import org.apache.kafka.common.requests.FindCoordinatorRequest.CoordinatorType; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.common.utils.Utils; +import org.slf4j.Logger; + +public class DescribeConsumerGroupsHandler implements AdminApiHandler { + + private final boolean includeAuthorizedOperations; + private final Logger log; + private final AdminApiLookupStrategy lookupStrategy; + + public DescribeConsumerGroupsHandler( + boolean includeAuthorizedOperations, + LogContext logContext + ) { + this.includeAuthorizedOperations = includeAuthorizedOperations; + this.log = logContext.logger(DescribeConsumerGroupsHandler.class); + this.lookupStrategy = new CoordinatorStrategy(CoordinatorType.GROUP, logContext); + } + + private static Set buildKeySet(Collection groupIds) { + return groupIds.stream() + .map(CoordinatorKey::byGroupId) + .collect(Collectors.toSet()); + } + + public static AdminApiFuture.SimpleAdminApiFuture newFuture( + Collection groupIds + ) { + return AdminApiFuture.forKeys(buildKeySet(groupIds)); + } + + @Override + public String apiName() { + return "describeGroups"; + } + + @Override + public AdminApiLookupStrategy lookupStrategy() { + return lookupStrategy; + } + + @Override + public DescribeGroupsRequest.Builder buildRequest(int coordinatorId, Set keys) { + List groupIds = keys.stream().map(key -> { + if (key.type != FindCoordinatorRequest.CoordinatorType.GROUP) { + throw new IllegalArgumentException("Invalid transaction coordinator key " + key + + " when building `DescribeGroups` request"); + } + return key.idValue; + }).collect(Collectors.toList()); + DescribeGroupsRequestData data = new DescribeGroupsRequestData() + .setGroups(groupIds) + .setIncludeAuthorizedOperations(includeAuthorizedOperations); + return new DescribeGroupsRequest.Builder(data); + } + + @Override + public ApiResult handleResponse( + Node coordinator, + Set groupIds, + AbstractResponse abstractResponse + ) { + final DescribeGroupsResponse response = (DescribeGroupsResponse) abstractResponse; + final Map completed = 
new HashMap<>(); + final Map failed = new HashMap<>(); + final Set groupsToUnmap = new HashSet<>(); + + for (DescribedGroup describedGroup : response.data().groups()) { + CoordinatorKey groupIdKey = CoordinatorKey.byGroupId(describedGroup.groupId()); + Errors error = Errors.forCode(describedGroup.errorCode()); + if (error != Errors.NONE) { + handleError(groupIdKey, error, failed, groupsToUnmap); + continue; + } + final String protocolType = describedGroup.protocolType(); + if (protocolType.equals(ConsumerProtocol.PROTOCOL_TYPE) || protocolType.isEmpty()) { + final List members = describedGroup.members(); + final List memberDescriptions = new ArrayList<>(members.size()); + final Set authorizedOperations = validAclOperations(describedGroup.authorizedOperations()); + for (DescribedGroupMember groupMember : members) { + Set partitions = Collections.emptySet(); + if (groupMember.memberAssignment().length > 0) { + final Assignment assignment = ConsumerProtocol. + deserializeAssignment(ByteBuffer.wrap(groupMember.memberAssignment())); + partitions = new HashSet<>(assignment.partitions()); + } + memberDescriptions.add(new MemberDescription( + groupMember.memberId(), + Optional.ofNullable(groupMember.groupInstanceId()), + groupMember.clientId(), + groupMember.clientHost(), + new MemberAssignment(partitions))); + } + final ConsumerGroupDescription consumerGroupDescription = + new ConsumerGroupDescription(groupIdKey.idValue, protocolType.isEmpty(), + memberDescriptions, + describedGroup.protocolData(), + ConsumerGroupState.parse(describedGroup.groupState()), + coordinator, + authorizedOperations); + completed.put(groupIdKey, consumerGroupDescription); + } else { + failed.put(groupIdKey, new IllegalArgumentException( + String.format("GroupId %s is not a consumer group (%s).", + groupIdKey.idValue, protocolType))); + } + } + + return new ApiResult<>(completed, failed, new ArrayList<>(groupsToUnmap)); + } + + private void handleError( + CoordinatorKey groupId, + Errors error, + Map failed, + Set groupsToUnmap + ) { + switch (error) { + case GROUP_AUTHORIZATION_FAILED: + log.debug("`DescribeGroups` request for group id {} failed due to error {}", groupId.idValue, error); + failed.put(groupId, error.exception()); + break; + + case COORDINATOR_LOAD_IN_PROGRESS: + // If the coordinator is in the middle of loading, then we just need to retry + log.debug("`DescribeGroups` request for group id {} failed because the coordinator " + + "is still in the process of loading state. Will retry", groupId.idValue); + break; + + case COORDINATOR_NOT_AVAILABLE: + case NOT_COORDINATOR: + // If the coordinator is unavailable or there was a coordinator change, then we unmap + // the key so that we retry the `FindCoordinator` request + log.debug("`DescribeGroups` request for group id {} returned error {}. 
" + + "Will attempt to find the coordinator again and retry", groupId.idValue, error); + groupsToUnmap.add(groupId); + break; + + default: + log.error("`DescribeGroups` request for group id {} failed due to unexpected error {}", groupId.idValue, error); + failed.put(groupId, error.exception()); + } + } + + private Set validAclOperations(final int authorizedOperations) { + if (authorizedOperations == MetadataResponse.AUTHORIZED_OPERATIONS_OMITTED) { + return null; + } + return Utils.from32BitField(authorizedOperations) + .stream() + .map(AclOperation::fromCode) + .filter(operation -> operation != AclOperation.UNKNOWN + && operation != AclOperation.ALL + && operation != AclOperation.ANY) + .collect(Collectors.toSet()); + } + +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/internals/DescribeProducersHandler.java b/clients/src/main/java/org/apache/kafka/clients/admin/internals/DescribeProducersHandler.java new file mode 100644 index 0000000..4b279d5 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/internals/DescribeProducersHandler.java @@ -0,0 +1,206 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.clients.admin.internals; + +import org.apache.kafka.clients.admin.DescribeProducersOptions; +import org.apache.kafka.clients.admin.DescribeProducersResult.PartitionProducerState; +import org.apache.kafka.clients.admin.ProducerState; +import org.apache.kafka.common.Node; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.errors.InvalidTopicException; +import org.apache.kafka.common.errors.TopicAuthorizationException; +import org.apache.kafka.common.message.DescribeProducersRequestData; +import org.apache.kafka.common.message.DescribeProducersResponseData; +import org.apache.kafka.common.protocol.Errors; +import org.apache.kafka.common.requests.AbstractResponse; +import org.apache.kafka.common.requests.ApiError; +import org.apache.kafka.common.requests.DescribeProducersRequest; +import org.apache.kafka.common.requests.DescribeProducersResponse; +import org.apache.kafka.common.utils.CollectionUtils; +import org.apache.kafka.common.utils.LogContext; +import org.slf4j.Logger; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.OptionalInt; +import java.util.OptionalLong; +import java.util.Set; +import java.util.stream.Collectors; + +public class DescribeProducersHandler implements AdminApiHandler { + private final Logger log; + private final DescribeProducersOptions options; + private final AdminApiLookupStrategy lookupStrategy; + + public DescribeProducersHandler( + DescribeProducersOptions options, + LogContext logContext + ) { + this.options = options; + this.log = logContext.logger(DescribeProducersHandler.class); + + if (options.brokerId().isPresent()) { + this.lookupStrategy = new StaticBrokerStrategy<>(options.brokerId().getAsInt()); + } else { + this.lookupStrategy = new PartitionLeaderStrategy(logContext); + } + } + + public static AdminApiFuture.SimpleAdminApiFuture newFuture( + Collection topicPartitions + ) { + return AdminApiFuture.forKeys(new HashSet<>(topicPartitions)); + } + + @Override + public String apiName() { + return "describeProducers"; + } + + @Override + public AdminApiLookupStrategy lookupStrategy() { + return lookupStrategy; + } + + @Override + public DescribeProducersRequest.Builder buildRequest( + int brokerId, + Set topicPartitions + ) { + DescribeProducersRequestData request = new DescribeProducersRequestData(); + DescribeProducersRequest.Builder builder = new DescribeProducersRequest.Builder(request); + + CollectionUtils.groupPartitionsByTopic( + topicPartitions, + builder::addTopic, + (topicRequest, partitionId) -> topicRequest.partitionIndexes().add(partitionId) + ); + + return builder; + } + + private void handlePartitionError( + TopicPartition topicPartition, + ApiError apiError, + Map failed, + List unmapped + ) { + switch (apiError.error()) { + case NOT_LEADER_OR_FOLLOWER: + if (options.brokerId().isPresent()) { + // Typically these errors are retriable, but if the user specified the brokerId + // explicitly, then they are fatal. 
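+                    // Because the caller pinned the lookup to a specific broker, there is no leader
+                    // to re-discover; surface the error rather than unmapping the partition for retry.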
+ int brokerId = options.brokerId().getAsInt(); + log.error("Not leader error in `DescribeProducers` response for partition {} " + + "for brokerId {} set in options", topicPartition, brokerId, apiError.exception()); + failed.put(topicPartition, apiError.error().exception("Failed to describe active producers " + + "for partition " + topicPartition + " on brokerId " + brokerId)); + } else { + // Otherwise, we unmap the partition so that we can find the new leader + log.debug("Not leader error in `DescribeProducers` response for partition {}. " + + "Will retry later.", topicPartition); + unmapped.add(topicPartition); + } + break; + + case UNKNOWN_TOPIC_OR_PARTITION: + log.debug("Unknown topic/partition error in `DescribeProducers` response for partition {}. " + + "Will retry later.", topicPartition); + break; + + case INVALID_TOPIC_EXCEPTION: + log.error("Invalid topic in `DescribeProducers` response for partition {}", + topicPartition, apiError.exception()); + failed.put(topicPartition, new InvalidTopicException( + "Failed to fetch metadata for partition " + topicPartition + + " due to invalid topic error: " + apiError.messageWithFallback(), + Collections.singleton(topicPartition.topic()))); + break; + + case TOPIC_AUTHORIZATION_FAILED: + log.error("Authorization failed in `DescribeProducers` response for partition {}", + topicPartition, apiError.exception()); + failed.put(topicPartition, new TopicAuthorizationException("Failed to describe " + + "active producers for partition " + topicPartition + " due to authorization failure on topic" + + " `" + topicPartition.topic() + "`", Collections.singleton(topicPartition.topic()))); + break; + + default: + log.error("Unexpected error in `DescribeProducers` response for partition {}", + topicPartition, apiError.exception()); + failed.put(topicPartition, apiError.error().exception("Failed to describe active " + + "producers for partition " + topicPartition + " due to unexpected error")); + break; + } + } + + @Override + public ApiResult handleResponse( + Node broker, + Set keys, + AbstractResponse abstractResponse + ) { + DescribeProducersResponse response = (DescribeProducersResponse) abstractResponse; + Map completed = new HashMap<>(); + Map failed = new HashMap<>(); + List unmapped = new ArrayList<>(); + + for (DescribeProducersResponseData.TopicResponse topicResponse : response.data().topics()) { + for (DescribeProducersResponseData.PartitionResponse partitionResponse : topicResponse.partitions()) { + TopicPartition topicPartition = new TopicPartition( + topicResponse.name(), partitionResponse.partitionIndex()); + + Errors error = Errors.forCode(partitionResponse.errorCode()); + if (error != Errors.NONE) { + ApiError apiError = new ApiError(error, partitionResponse.errorMessage()); + handlePartitionError(topicPartition, apiError, failed, unmapped); + continue; + } + + List activeProducers = partitionResponse.activeProducers().stream() + .map(activeProducer -> { + OptionalLong currentTransactionFirstOffset = + activeProducer.currentTxnStartOffset() < 0 ? + OptionalLong.empty() : + OptionalLong.of(activeProducer.currentTxnStartOffset()); + OptionalInt coordinatorEpoch = + activeProducer.coordinatorEpoch() < 0 ? 
+ OptionalInt.empty() : + OptionalInt.of(activeProducer.coordinatorEpoch()); + + return new ProducerState( + activeProducer.producerId(), + activeProducer.producerEpoch(), + activeProducer.lastSequence(), + activeProducer.lastTimestamp(), + coordinatorEpoch, + currentTransactionFirstOffset + ); + }).collect(Collectors.toList()); + + completed.put(topicPartition, new PartitionProducerState(activeProducers)); + } + } + return new ApiResult<>(completed, failed, unmapped); + } + +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/internals/DescribeTransactionsHandler.java b/clients/src/main/java/org/apache/kafka/clients/admin/internals/DescribeTransactionsHandler.java new file mode 100644 index 0000000..d270145 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/internals/DescribeTransactionsHandler.java @@ -0,0 +1,194 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients.admin.internals; + +import org.apache.kafka.clients.admin.TransactionDescription; +import org.apache.kafka.clients.admin.TransactionState; +import org.apache.kafka.common.Node; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.errors.TransactionalIdAuthorizationException; +import org.apache.kafka.common.errors.TransactionalIdNotFoundException; +import org.apache.kafka.common.message.DescribeTransactionsRequestData; +import org.apache.kafka.common.message.DescribeTransactionsResponseData; +import org.apache.kafka.common.protocol.Errors; +import org.apache.kafka.common.requests.AbstractResponse; +import org.apache.kafka.common.requests.DescribeTransactionsRequest; +import org.apache.kafka.common.requests.DescribeTransactionsResponse; +import org.apache.kafka.common.requests.FindCoordinatorRequest; +import org.apache.kafka.common.requests.FindCoordinatorRequest.CoordinatorType; +import org.apache.kafka.common.utils.LogContext; +import org.slf4j.Logger; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.OptionalLong; +import java.util.Set; +import java.util.stream.Collectors; + +public class DescribeTransactionsHandler implements AdminApiHandler { + private final Logger log; + private final AdminApiLookupStrategy lookupStrategy; + + public DescribeTransactionsHandler( + LogContext logContext + ) { + this.log = logContext.logger(DescribeTransactionsHandler.class); + this.lookupStrategy = new CoordinatorStrategy(CoordinatorType.TRANSACTION, logContext); + } + + public static AdminApiFuture.SimpleAdminApiFuture newFuture( + Collection transactionalIds + ) { + return AdminApiFuture.forKeys(buildKeySet(transactionalIds)); + } + + private static Set 
buildKeySet(Collection transactionalIds) { + return transactionalIds.stream() + .map(CoordinatorKey::byTransactionalId) + .collect(Collectors.toSet()); + } + + @Override + public String apiName() { + return "describeTransactions"; + } + + @Override + public AdminApiLookupStrategy lookupStrategy() { + return lookupStrategy; + } + + @Override + public DescribeTransactionsRequest.Builder buildRequest( + int brokerId, + Set keys + ) { + DescribeTransactionsRequestData request = new DescribeTransactionsRequestData(); + List transactionalIds = keys.stream().map(key -> { + if (key.type != FindCoordinatorRequest.CoordinatorType.TRANSACTION) { + throw new IllegalArgumentException("Invalid group coordinator key " + key + + " when building `DescribeTransaction` request"); + } + return key.idValue; + }).collect(Collectors.toList()); + request.setTransactionalIds(transactionalIds); + return new DescribeTransactionsRequest.Builder(request); + } + + @Override + public ApiResult handleResponse( + Node broker, + Set keys, + AbstractResponse abstractResponse + ) { + DescribeTransactionsResponse response = (DescribeTransactionsResponse) abstractResponse; + Map completed = new HashMap<>(); + Map failed = new HashMap<>(); + List unmapped = new ArrayList<>(); + + for (DescribeTransactionsResponseData.TransactionState transactionState : response.data().transactionStates()) { + CoordinatorKey transactionalIdKey = CoordinatorKey.byTransactionalId( + transactionState.transactionalId()); + if (!keys.contains(transactionalIdKey)) { + log.warn("Response included transactionalId `{}`, which was not requested", + transactionState.transactionalId()); + continue; + } + + Errors error = Errors.forCode(transactionState.errorCode()); + if (error != Errors.NONE) { + handleError(transactionalIdKey, error, failed, unmapped); + continue; + } + + OptionalLong transactionStartTimeMs = transactionState.transactionStartTimeMs() < 0 ? 
+ OptionalLong.empty() : + OptionalLong.of(transactionState.transactionStartTimeMs()); + + completed.put(transactionalIdKey, new TransactionDescription( + broker.id(), + TransactionState.parse(transactionState.transactionState()), + transactionState.producerId(), + transactionState.producerEpoch(), + transactionState.transactionTimeoutMs(), + transactionStartTimeMs, + collectTopicPartitions(transactionState) + )); + } + + return new ApiResult<>(completed, failed, unmapped); + } + + private Set collectTopicPartitions( + DescribeTransactionsResponseData.TransactionState transactionState + ) { + Set res = new HashSet<>(); + for (DescribeTransactionsResponseData.TopicData topicData : transactionState.topics()) { + String topic = topicData.topic(); + for (Integer partitionId : topicData.partitions()) { + res.add(new TopicPartition(topic, partitionId)); + } + } + return res; + } + + private void handleError( + CoordinatorKey transactionalIdKey, + Errors error, + Map failed, + List unmapped + ) { + switch (error) { + case TRANSACTIONAL_ID_AUTHORIZATION_FAILED: + failed.put(transactionalIdKey, new TransactionalIdAuthorizationException( + "DescribeTransactions request for transactionalId `" + transactionalIdKey.idValue + "` " + + "failed due to authorization failure")); + break; + + case TRANSACTIONAL_ID_NOT_FOUND: + failed.put(transactionalIdKey, new TransactionalIdNotFoundException( + "DescribeTransactions request for transactionalId `" + transactionalIdKey.idValue + "` " + + "failed because the ID could not be found")); + break; + + case COORDINATOR_LOAD_IN_PROGRESS: + // If the coordinator is in the middle of loading, then we just need to retry + log.debug("DescribeTransactions request for transactionalId `{}` failed because the " + + "coordinator is still in the process of loading state. Will retry", + transactionalIdKey.idValue); + break; + + case NOT_COORDINATOR: + case COORDINATOR_NOT_AVAILABLE: + // If the coordinator is unavailable or there was a coordinator change, then we unmap + // the key so that we retry the `FindCoordinator` request + unmapped.add(transactionalIdKey); + log.debug("DescribeTransactions request for transactionalId `{}` returned error {}. Will attempt " + + "to find the coordinator again and retry", transactionalIdKey.idValue, error); + break; + + default: + failed.put(transactionalIdKey, error.exception("DescribeTransactions request for " + + "transactionalId `" + transactionalIdKey.idValue + "` failed due to unexpected error")); + } + } + +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/internals/ListConsumerGroupOffsetsHandler.java b/clients/src/main/java/org/apache/kafka/clients/admin/internals/ListConsumerGroupOffsetsHandler.java new file mode 100644 index 0000000..b1d2e9d --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/internals/ListConsumerGroupOffsetsHandler.java @@ -0,0 +1,162 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients.admin.internals; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; + +import org.apache.kafka.clients.consumer.OffsetAndMetadata; +import org.apache.kafka.common.Node; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.protocol.Errors; +import org.apache.kafka.common.requests.AbstractResponse; +import org.apache.kafka.common.requests.OffsetFetchRequest; +import org.apache.kafka.common.requests.OffsetFetchResponse; +import org.apache.kafka.common.requests.FindCoordinatorRequest.CoordinatorType; +import org.apache.kafka.common.utils.LogContext; +import org.slf4j.Logger; + +public class ListConsumerGroupOffsetsHandler implements AdminApiHandler> { + + private final CoordinatorKey groupId; + private final List partitions; + private final Logger log; + private final AdminApiLookupStrategy lookupStrategy; + + public ListConsumerGroupOffsetsHandler( + String groupId, + List partitions, + LogContext logContext + ) { + this.groupId = CoordinatorKey.byGroupId(groupId); + this.partitions = partitions; + this.log = logContext.logger(ListConsumerGroupOffsetsHandler.class); + this.lookupStrategy = new CoordinatorStrategy(CoordinatorType.GROUP, logContext); + } + + public static AdminApiFuture.SimpleAdminApiFuture> newFuture( + String groupId + ) { + return AdminApiFuture.forKeys(Collections.singleton(CoordinatorKey.byGroupId(groupId))); + } + + @Override + public String apiName() { + return "offsetFetch"; + } + + @Override + public AdminApiLookupStrategy lookupStrategy() { + return lookupStrategy; + } + + private void validateKeys(Set groupIds) { + if (!groupIds.equals(Collections.singleton(groupId))) { + throw new IllegalArgumentException("Received unexpected group ids " + groupIds + + " (expected only " + Collections.singleton(groupId) + ")"); + } + } + + @Override + public OffsetFetchRequest.Builder buildRequest(int coordinatorId, Set groupIds) { + validateKeys(groupIds); + // Set the flag to false as for admin client request, + // we don't need to wait for any pending offset state to clear. 
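+        // (Assumption: the two boolean arguments in the builder call below are `requireStable`
+        // and `throwOnFetchStableOffsetsUnsupported`; both are left disabled for this lookup.)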
+ return new OffsetFetchRequest.Builder(groupId.idValue, false, partitions, false); + } + + @Override + public ApiResult> handleResponse( + Node coordinator, + Set groupIds, + AbstractResponse abstractResponse + ) { + validateKeys(groupIds); + + final OffsetFetchResponse response = (OffsetFetchResponse) abstractResponse; + + // the groupError will contain the group level error for v0-v8 OffsetFetchResponse + Errors groupError = response.groupLevelError(groupId.idValue); + if (groupError != Errors.NONE) { + final Map failed = new HashMap<>(); + final Set groupsToUnmap = new HashSet<>(); + + handleGroupError(groupId, groupError, failed, groupsToUnmap); + + return new ApiResult<>(Collections.emptyMap(), failed, new ArrayList<>(groupsToUnmap)); + } else { + final Map groupOffsetsListing = new HashMap<>(); + + response.partitionDataMap(groupId.idValue).forEach((topicPartition, partitionData) -> { + final Errors error = partitionData.error; + if (error == Errors.NONE) { + final long offset = partitionData.offset; + final String metadata = partitionData.metadata; + final Optional leaderEpoch = partitionData.leaderEpoch; + // Negative offset indicates that the group has no committed offset for this partition + if (offset < 0) { + groupOffsetsListing.put(topicPartition, null); + } else { + groupOffsetsListing.put(topicPartition, new OffsetAndMetadata(offset, leaderEpoch, metadata)); + } + } else { + log.warn("Skipping return offset for {} due to error {}.", topicPartition, error); + } + }); + + return ApiResult.completed(groupId, groupOffsetsListing); + } + } + + private void handleGroupError( + CoordinatorKey groupId, + Errors error, + Map failed, + Set groupsToUnmap + ) { + switch (error) { + case GROUP_AUTHORIZATION_FAILED: + log.debug("`OffsetFetch` request for group id {} failed due to error {}", groupId.idValue, error); + failed.put(groupId, error.exception()); + break; + case COORDINATOR_LOAD_IN_PROGRESS: + // If the coordinator is in the middle of loading, then we just need to retry + log.debug("`OffsetFetch` request for group id {} failed because the coordinator " + + "is still in the process of loading state. Will retry", groupId.idValue); + break; + + case COORDINATOR_NOT_AVAILABLE: + case NOT_COORDINATOR: + // If the coordinator is unavailable or there was a coordinator change, then we unmap + // the key so that we retry the `FindCoordinator` request + log.debug("`OffsetFetch` request for group id {} returned error {}. " + + "Will attempt to find the coordinator again and retry", groupId.idValue, error); + groupsToUnmap.add(groupId); + break; + + default: + log.error("`OffsetFetch` request for group id {} failed due to unexpected error {}", groupId.idValue, error); + failed.put(groupId, error.exception()); + } + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/internals/ListTransactionsHandler.java b/clients/src/main/java/org/apache/kafka/clients/admin/internals/ListTransactionsHandler.java new file mode 100644 index 0000000..d60580c --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/internals/ListTransactionsHandler.java @@ -0,0 +1,132 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients.admin.internals; + +import org.apache.kafka.clients.admin.ListTransactionsOptions; +import org.apache.kafka.clients.admin.TransactionListing; +import org.apache.kafka.clients.admin.TransactionState; +import org.apache.kafka.common.Node; +import org.apache.kafka.common.errors.CoordinatorNotAvailableException; +import org.apache.kafka.common.message.ListTransactionsRequestData; +import org.apache.kafka.common.protocol.Errors; +import org.apache.kafka.common.requests.AbstractResponse; +import org.apache.kafka.common.requests.ListTransactionsRequest; +import org.apache.kafka.common.requests.ListTransactionsResponse; +import org.apache.kafka.common.utils.LogContext; +import org.slf4j.Logger; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; +import java.util.Set; +import java.util.stream.Collectors; + +public class ListTransactionsHandler implements AdminApiHandler> { + private final Logger log; + private final ListTransactionsOptions options; + private final AllBrokersStrategy lookupStrategy; + + public ListTransactionsHandler( + ListTransactionsOptions options, + LogContext logContext + ) { + this.options = options; + this.log = logContext.logger(ListTransactionsHandler.class); + this.lookupStrategy = new AllBrokersStrategy(logContext); + } + + public static AllBrokersStrategy.AllBrokersFuture> newFuture() { + return new AllBrokersStrategy.AllBrokersFuture<>(); + } + + @Override + public String apiName() { + return "listTransactions"; + } + + @Override + public AdminApiLookupStrategy lookupStrategy() { + return lookupStrategy; + } + + @Override + public ListTransactionsRequest.Builder buildRequest( + int brokerId, + Set keys + ) { + ListTransactionsRequestData request = new ListTransactionsRequestData(); + request.setProducerIdFilters(new ArrayList<>(options.filteredProducerIds())); + request.setStateFilters(options.filteredStates().stream() + .map(TransactionState::toString) + .collect(Collectors.toList())); + return new ListTransactionsRequest.Builder(request); + } + + @Override + public ApiResult> handleResponse( + Node broker, + Set keys, + AbstractResponse abstractResponse + ) { + int brokerId = broker.id(); + AllBrokersStrategy.BrokerKey key = requireSingleton(keys, brokerId); + + ListTransactionsResponse response = (ListTransactionsResponse) abstractResponse; + Errors error = Errors.forCode(response.data().errorCode()); + + if (error == Errors.COORDINATOR_LOAD_IN_PROGRESS) { + log.debug("The `ListTransactions` request sent to broker {} failed because the " + + "coordinator is still loading state. 
Will try again after backing off", brokerId); + return ApiResult.empty(); + } else if (error == Errors.COORDINATOR_NOT_AVAILABLE) { + log.debug("The `ListTransactions` request sent to broker {} failed because the " + + "coordinator is shutting down", brokerId); + return ApiResult.failed(key, new CoordinatorNotAvailableException("ListTransactions " + + "request sent to broker " + brokerId + " failed because the coordinator is shutting down")); + } else if (error != Errors.NONE) { + log.error("The `ListTransactions` request sent to broker {} failed because of an " + + "unexpected error {}", brokerId, error); + return ApiResult.failed(key, error.exception("ListTransactions request " + + "sent to broker " + brokerId + " failed with an unexpected exception")); + } else { + List listings = response.data().transactionStates().stream() + .map(transactionState -> new TransactionListing( + transactionState.transactionalId(), + transactionState.producerId(), + TransactionState.parse(transactionState.transactionState()))) + .collect(Collectors.toList()); + return ApiResult.completed(key, listings); + } + } + + private AllBrokersStrategy.BrokerKey requireSingleton( + Set keys, + int brokerId + ) { + if (keys.size() != 1) { + throw new IllegalArgumentException("Unexpected key set: " + keys); + } + + AllBrokersStrategy.BrokerKey key = keys.iterator().next(); + if (!key.brokerId.isPresent() || key.brokerId.getAsInt() != brokerId) { + throw new IllegalArgumentException("Unexpected broker key: " + key); + } + + return key; + } + +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/internals/MetadataOperationContext.java b/clients/src/main/java/org/apache/kafka/clients/admin/internals/MetadataOperationContext.java new file mode 100644 index 0000000..e7f2c07 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/internals/MetadataOperationContext.java @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.clients.admin.internals; + +import java.util.Collection; +import java.util.Map; +import java.util.Optional; + +import org.apache.kafka.clients.admin.AbstractOptions; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.errors.InvalidMetadataException; +import org.apache.kafka.common.internals.KafkaFutureImpl; +import org.apache.kafka.common.protocol.Errors; +import org.apache.kafka.common.requests.MetadataResponse; +import org.apache.kafka.common.requests.MetadataResponse.PartitionMetadata; +import org.apache.kafka.common.requests.MetadataResponse.TopicMetadata; + +/** + * Context class to encapsulate parameters of a call to fetch and use cluster metadata. 
+ * Some of the parameters are provided at construction and are immutable whereas others are provided + * as "Call" are completed and values are available. + * + * @param The type of return value of the KafkaFuture + * @param The type of configuration option. + */ +public final class MetadataOperationContext> { + final private Collection topics; + final private O options; + final private long deadline; + final private Map> futures; + private Optional response; + + public MetadataOperationContext(Collection topics, + O options, + long deadline, + Map> futures) { + this.topics = topics; + this.options = options; + this.deadline = deadline; + this.futures = futures; + this.response = Optional.empty(); + } + + public void setResponse(Optional response) { + this.response = response; + } + + public Optional response() { + return response; + } + + public O options() { + return options; + } + + public long deadline() { + return deadline; + } + + public Map> futures() { + return futures; + } + + public Collection topics() { + return topics; + } + + public static void handleMetadataErrors(MetadataResponse response) { + for (TopicMetadata tm : response.topicMetadata()) { + if (shouldRefreshMetadata(tm.error())) throw tm.error().exception(); + for (PartitionMetadata pm : tm.partitionMetadata()) { + if (shouldRefreshMetadata(pm.error)) { + throw pm.error.exception(); + } + } + } + } + + public static boolean shouldRefreshMetadata(Errors error) { + return error.exception() instanceof InvalidMetadataException; + } +} \ No newline at end of file diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/internals/PartitionLeaderStrategy.java b/clients/src/main/java/org/apache/kafka/clients/admin/internals/PartitionLeaderStrategy.java new file mode 100644 index 0000000..18ae79a --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/internals/PartitionLeaderStrategy.java @@ -0,0 +1,183 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.clients.admin.internals; + +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.errors.InvalidTopicException; +import org.apache.kafka.common.errors.TopicAuthorizationException; +import org.apache.kafka.common.message.MetadataRequestData; +import org.apache.kafka.common.message.MetadataResponseData; +import org.apache.kafka.common.protocol.Errors; +import org.apache.kafka.common.requests.AbstractResponse; +import org.apache.kafka.common.requests.MetadataRequest; +import org.apache.kafka.common.requests.MetadataResponse; +import org.apache.kafka.common.utils.LogContext; +import org.slf4j.Logger; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.Set; +import java.util.function.Function; + +/** + * Base driver implementation for APIs which target partition leaders. + */ +public class PartitionLeaderStrategy implements AdminApiLookupStrategy { + private static final ApiRequestScope SINGLE_REQUEST_SCOPE = new ApiRequestScope() { + }; + + private final Logger log; + + public PartitionLeaderStrategy(LogContext logContext) { + this.log = logContext.logger(PartitionLeaderStrategy.class); + } + + @Override + public ApiRequestScope lookupScope(TopicPartition key) { + // Metadata requests can group topic partitions arbitrarily, so they can all share + // the same request context + return SINGLE_REQUEST_SCOPE; + } + + @Override + public MetadataRequest.Builder buildRequest(Set partitions) { + MetadataRequestData request = new MetadataRequestData(); + request.setAllowAutoTopicCreation(false); + partitions.stream().map(TopicPartition::topic).distinct().forEach(topic -> + request.topics().add(new MetadataRequestData.MetadataRequestTopic().setName(topic)) + ); + return new MetadataRequest.Builder(request); + } + + private void handleTopicError( + String topic, + Errors topicError, + Set requestPartitions, + Map failed + ) { + switch (topicError) { + case UNKNOWN_TOPIC_OR_PARTITION: + case LEADER_NOT_AVAILABLE: + case BROKER_NOT_AVAILABLE: + log.debug("Metadata request for topic {} returned topic-level error {}. 
Will retry", + topic, topicError); + break; + + case TOPIC_AUTHORIZATION_FAILED: + log.error("Received authorization failure for topic {} in `Metadata` response", topic, + topicError.exception()); + failAllPartitionsForTopic(topic, requestPartitions, failed, tp -> new TopicAuthorizationException( + "Failed to fetch metadata for partition " + tp + " due to topic authorization failure", + Collections.singleton(topic))); + break; + + case INVALID_TOPIC_EXCEPTION: + log.error("Received invalid topic error for topic {} in `Metadata` response", topic, + topicError.exception()); + failAllPartitionsForTopic(topic, requestPartitions, failed, tp -> new InvalidTopicException( + "Failed to fetch metadata for partition " + tp + " due to invalid topic `" + topic + "`", + Collections.singleton(topic))); + break; + + default: + log.error("Received unexpected error for topic {} in `Metadata` response", topic, + topicError.exception()); + failAllPartitionsForTopic(topic, requestPartitions, failed, tp -> topicError.exception( + "Failed to fetch metadata for partition " + tp + " due to unexpected error for topic `" + topic + "`")); + } + } + + private void failAllPartitionsForTopic( + String topic, + Set partitions, + Map failed, + Function exceptionGenerator + ) { + partitions.stream().filter(tp -> tp.topic().equals(topic)).forEach(tp -> { + failed.put(tp, exceptionGenerator.apply(tp)); + }); + } + + private void handlePartitionError( + TopicPartition topicPartition, + Errors partitionError, + Map failed + ) { + switch (partitionError) { + case NOT_LEADER_OR_FOLLOWER: + case REPLICA_NOT_AVAILABLE: + case LEADER_NOT_AVAILABLE: + case BROKER_NOT_AVAILABLE: + case KAFKA_STORAGE_ERROR: + log.debug("Metadata request for partition {} returned partition-level error {}. Will retry", + topicPartition, partitionError); + break; + + default: + log.error("Received unexpected error for partition {} in `Metadata` response", + topicPartition, partitionError.exception()); + failed.put(topicPartition, partitionError.exception( + "Unexpected error during metadata lookup for " + topicPartition)); + } + } + + @Override + public LookupResult handleResponse( + Set requestPartitions, + AbstractResponse abstractResponse + ) { + MetadataResponse response = (MetadataResponse) abstractResponse; + Map failed = new HashMap<>(); + Map mapped = new HashMap<>(); + + for (MetadataResponseData.MetadataResponseTopic topicMetadata : response.data().topics()) { + String topic = topicMetadata.name(); + Errors topicError = Errors.forCode(topicMetadata.errorCode()); + if (topicError != Errors.NONE) { + handleTopicError(topic, topicError, requestPartitions, failed); + continue; + } + + for (MetadataResponseData.MetadataResponsePartition partitionMetadata : topicMetadata.partitions()) { + TopicPartition topicPartition = new TopicPartition(topic, partitionMetadata.partitionIndex()); + Errors partitionError = Errors.forCode(partitionMetadata.errorCode()); + + if (!requestPartitions.contains(topicPartition)) { + // The `Metadata` response always returns all partitions for requested + // topics, so we have to filter any that we are not interested in. + continue; + } + + if (partitionError != Errors.NONE) { + handlePartitionError(topicPartition, partitionError, failed); + continue; + } + + int leaderId = partitionMetadata.leaderId(); + if (leaderId >= 0) { + mapped.put(topicPartition, leaderId); + } else { + log.debug("Metadata request for {} returned no error, but the leader is unknown. 
Will retry", + topicPartition); + } + } + } + return new LookupResult<>(failed, mapped); + } + +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/internals/RemoveMembersFromConsumerGroupHandler.java b/clients/src/main/java/org/apache/kafka/clients/admin/internals/RemoveMembersFromConsumerGroupHandler.java new file mode 100644 index 0000000..90b3865 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/internals/RemoveMembersFromConsumerGroupHandler.java @@ -0,0 +1,148 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients.admin.internals; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import org.apache.kafka.common.Node; +import org.apache.kafka.common.message.LeaveGroupRequestData.MemberIdentity; +import org.apache.kafka.common.message.LeaveGroupResponseData.MemberResponse; +import org.apache.kafka.common.protocol.Errors; +import org.apache.kafka.common.requests.AbstractResponse; +import org.apache.kafka.common.requests.LeaveGroupRequest; +import org.apache.kafka.common.requests.LeaveGroupResponse; +import org.apache.kafka.common.requests.FindCoordinatorRequest.CoordinatorType; +import org.apache.kafka.common.utils.LogContext; +import org.slf4j.Logger; + +public class RemoveMembersFromConsumerGroupHandler implements AdminApiHandler> { + + private final CoordinatorKey groupId; + private final List members; + private final Logger log; + private final AdminApiLookupStrategy lookupStrategy; + + public RemoveMembersFromConsumerGroupHandler( + String groupId, + List members, + LogContext logContext + ) { + this.groupId = CoordinatorKey.byGroupId(groupId); + this.members = members; + this.log = logContext.logger(RemoveMembersFromConsumerGroupHandler.class); + this.lookupStrategy = new CoordinatorStrategy(CoordinatorType.GROUP, logContext); + } + + @Override + public String apiName() { + return "leaveGroup"; + } + + @Override + public AdminApiLookupStrategy lookupStrategy() { + return lookupStrategy; + } + + public static AdminApiFuture.SimpleAdminApiFuture> newFuture( + String groupId + ) { + return AdminApiFuture.forKeys(Collections.singleton(CoordinatorKey.byGroupId(groupId))); + } + + private void validateKeys( + Set groupIds + ) { + if (!groupIds.equals(Collections.singleton(groupId))) { + throw new IllegalArgumentException("Received unexpected group ids " + groupIds + + " (expected only " + Collections.singleton(groupId) + ")"); + } + } + + @Override + public LeaveGroupRequest.Builder buildRequest(int coordinatorId, Set groupIds) { + validateKeys(groupIds); + return new LeaveGroupRequest.Builder(groupId.idValue, members); + } + + 
@Override + public ApiResult> handleResponse( + Node coordinator, + Set groupIds, + AbstractResponse abstractResponse + ) { + validateKeys(groupIds); + final LeaveGroupResponse response = (LeaveGroupResponse) abstractResponse; + + final Errors error = response.topLevelError(); + if (error != Errors.NONE) { + final Map failed = new HashMap<>(); + final Set groupsToUnmap = new HashSet<>(); + + handleGroupError(groupId, error, failed, groupsToUnmap); + + return new ApiResult<>(Collections.emptyMap(), failed, new ArrayList<>(groupsToUnmap)); + } else { + final Map memberErrors = new HashMap<>(); + for (MemberResponse memberResponse : response.memberResponses()) { + memberErrors.put(new MemberIdentity() + .setMemberId(memberResponse.memberId()) + .setGroupInstanceId(memberResponse.groupInstanceId()), + Errors.forCode(memberResponse.errorCode())); + } + + return ApiResult.completed(groupId, memberErrors); + } + } + + private void handleGroupError( + CoordinatorKey groupId, + Errors error, + Map failed, + Set groupsToUnmap + ) { + switch (error) { + case GROUP_AUTHORIZATION_FAILED: + log.debug("`LeaveGroup` request for group id {} failed due to error {}", groupId.idValue, error); + failed.put(groupId, error.exception()); + break; + case COORDINATOR_LOAD_IN_PROGRESS: + // If the coordinator is in the middle of loading, then we just need to retry + log.debug("`LeaveGroup` request for group id {} failed because the coordinator " + + "is still in the process of loading state. Will retry", groupId.idValue); + break; + case COORDINATOR_NOT_AVAILABLE: + case NOT_COORDINATOR: + // If the coordinator is unavailable or there was a coordinator change, then we unmap + // the key so that we retry the `FindCoordinator` request + log.debug("`LeaveGroup` request for group id {} returned error {}. " + + "Will attempt to find the coordinator again and retry", groupId.idValue, error); + groupsToUnmap.add(groupId); + break; + + default: + log.error("`LeaveGroup` request for group id {} failed due to unexpected error {}", groupId.idValue, error); + failed.put(groupId, error.exception()); + } + } + +} diff --git a/clients/src/main/java/org/apache/kafka/clients/admin/internals/StaticBrokerStrategy.java b/clients/src/main/java/org/apache/kafka/clients/admin/internals/StaticBrokerStrategy.java new file mode 100644 index 0000000..7b66537 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/admin/internals/StaticBrokerStrategy.java @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.clients.admin.internals; + +import org.apache.kafka.common.requests.AbstractRequest; +import org.apache.kafka.common.requests.AbstractResponse; + +import java.util.OptionalInt; +import java.util.Set; + +/** + * This lookup strategy is used when we already know the destination broker ID + * and we have no need for an explicit lookup. By setting {@link ApiRequestScope#destinationBrokerId()} + * in the returned value for {@link #lookupScope(Object)}, the driver will + * skip the lookup. + */ +public class StaticBrokerStrategy implements AdminApiLookupStrategy { + private final SingleBrokerScope scope; + + public StaticBrokerStrategy(int brokerId) { + this.scope = new SingleBrokerScope(brokerId); + } + + @Override + public ApiRequestScope lookupScope(K key) { + return scope; + } + + @Override + public AbstractRequest.Builder buildRequest(Set keys) { + throw new UnsupportedOperationException(); + } + + @Override + public LookupResult handleResponse(Set keys, AbstractResponse response) { + throw new UnsupportedOperationException(); + } + + private static class SingleBrokerScope implements ApiRequestScope { + private final int brokerId; + + private SingleBrokerScope(int brokerId) { + this.brokerId = brokerId; + } + + @Override + public OptionalInt destinationBrokerId() { + return OptionalInt.of(brokerId); + } + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/CommitFailedException.java b/clients/src/main/java/org/apache/kafka/clients/consumer/CommitFailedException.java new file mode 100644 index 0000000..2040216 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/consumer/CommitFailedException.java @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients.consumer; + +import org.apache.kafka.common.KafkaException; + +/** + * This exception is raised when an offset commit with {@link KafkaConsumer#commitSync()} fails + * with an unrecoverable error. This can happen when a group rebalance completes before the commit + * could be successfully applied. In this case, the commit cannot generally be retried because some + * of the partitions may have already been assigned to another member in the group. + */ +public class CommitFailedException extends KafkaException { + + private static final long serialVersionUID = 1L; + + public CommitFailedException(final String message) { + super(message); + } + + public CommitFailedException() { + super("Commit cannot be completed since the group has already " + + "rebalanced and assigned the partitions to another member. 
This means that the time " + + "between subsequent calls to poll() was longer than the configured max.poll.interval.ms, " + + "which typically implies that the poll loop is spending too much time message processing. " + + "You can address this either by increasing max.poll.interval.ms or by reducing the maximum " + + "size of batches returned in poll() with max.poll.records."); + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/Consumer.java b/clients/src/main/java/org/apache/kafka/clients/consumer/Consumer.java new file mode 100644 index 0000000..0dd7dd8 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/consumer/Consumer.java @@ -0,0 +1,276 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients.consumer; + +import org.apache.kafka.common.Metric; +import org.apache.kafka.common.MetricName; +import org.apache.kafka.common.PartitionInfo; +import org.apache.kafka.common.TopicPartition; + +import java.io.Closeable; +import java.time.Duration; +import java.util.Collection; +import java.util.List; +import java.util.Map; +import java.util.OptionalLong; +import java.util.Set; +import java.util.regex.Pattern; + +/** + * @see KafkaConsumer + * @see MockConsumer + */ +public interface Consumer extends Closeable { + + /** + * @see KafkaConsumer#assignment() + */ + Set assignment(); + + /** + * @see KafkaConsumer#subscription() + */ + Set subscription(); + + /** + * @see KafkaConsumer#subscribe(Collection) + */ + void subscribe(Collection topics); + + /** + * @see KafkaConsumer#subscribe(Collection, ConsumerRebalanceListener) + */ + void subscribe(Collection topics, ConsumerRebalanceListener callback); + + /** + * @see KafkaConsumer#assign(Collection) + */ + void assign(Collection partitions); + + /** + * @see KafkaConsumer#subscribe(Pattern, ConsumerRebalanceListener) + */ + void subscribe(Pattern pattern, ConsumerRebalanceListener callback); + + /** + * @see KafkaConsumer#subscribe(Pattern) + */ + void subscribe(Pattern pattern); + + /** + * @see KafkaConsumer#unsubscribe() + */ + void unsubscribe(); + + /** + * @see KafkaConsumer#poll(long) + */ + @Deprecated + ConsumerRecords poll(long timeout); + + /** + * @see KafkaConsumer#poll(Duration) + */ + ConsumerRecords poll(Duration timeout); + + /** + * @see KafkaConsumer#commitSync() + */ + void commitSync(); + + /** + * @see KafkaConsumer#commitSync(Duration) + */ + void commitSync(Duration timeout); + + /** + * @see KafkaConsumer#commitSync(Map) + */ + void commitSync(Map offsets); + + /** + * @see KafkaConsumer#commitSync(Map, Duration) + */ + void commitSync(final Map offsets, final Duration timeout); + /** + * @see KafkaConsumer#commitAsync() + */ + void commitAsync(); + + /** + * @see 
KafkaConsumer#commitAsync(OffsetCommitCallback) + */ + void commitAsync(OffsetCommitCallback callback); + + /** + * @see KafkaConsumer#commitAsync(Map, OffsetCommitCallback) + */ + void commitAsync(Map offsets, OffsetCommitCallback callback); + + /** + * @see KafkaConsumer#seek(TopicPartition, long) + */ + void seek(TopicPartition partition, long offset); + + /** + * @see KafkaConsumer#seek(TopicPartition, OffsetAndMetadata) + */ + void seek(TopicPartition partition, OffsetAndMetadata offsetAndMetadata); + + /** + * @see KafkaConsumer#seekToBeginning(Collection) + */ + void seekToBeginning(Collection partitions); + + /** + * @see KafkaConsumer#seekToEnd(Collection) + */ + void seekToEnd(Collection partitions); + + /** + * @see KafkaConsumer#position(TopicPartition) + */ + long position(TopicPartition partition); + + /** + * @see KafkaConsumer#position(TopicPartition, Duration) + */ + long position(TopicPartition partition, final Duration timeout); + + /** + * @see KafkaConsumer#committed(TopicPartition) + */ + @Deprecated + OffsetAndMetadata committed(TopicPartition partition); + + /** + * @see KafkaConsumer#committed(TopicPartition, Duration) + */ + @Deprecated + OffsetAndMetadata committed(TopicPartition partition, final Duration timeout); + + /** + * @see KafkaConsumer#committed(Set) + */ + Map committed(Set partitions); + + /** + * @see KafkaConsumer#committed(Set, Duration) + */ + Map committed(Set partitions, final Duration timeout); + + /** + * @see KafkaConsumer#metrics() + */ + Map metrics(); + + /** + * @see KafkaConsumer#partitionsFor(String) + */ + List partitionsFor(String topic); + + /** + * @see KafkaConsumer#partitionsFor(String, Duration) + */ + List partitionsFor(String topic, Duration timeout); + + /** + * @see KafkaConsumer#listTopics() + */ + Map> listTopics(); + + /** + * @see KafkaConsumer#listTopics(Duration) + */ + Map> listTopics(Duration timeout); + + /** + * @see KafkaConsumer#paused() + */ + Set paused(); + + /** + * @see KafkaConsumer#pause(Collection) + */ + void pause(Collection partitions); + + /** + * @see KafkaConsumer#resume(Collection) + */ + void resume(Collection partitions); + + /** + * @see KafkaConsumer#offsetsForTimes(Map) + */ + Map offsetsForTimes(Map timestampsToSearch); + + /** + * @see KafkaConsumer#offsetsForTimes(Map, Duration) + */ + Map offsetsForTimes(Map timestampsToSearch, Duration timeout); + + /** + * @see KafkaConsumer#beginningOffsets(Collection) + */ + Map beginningOffsets(Collection partitions); + + /** + * @see KafkaConsumer#beginningOffsets(Collection, Duration) + */ + Map beginningOffsets(Collection partitions, Duration timeout); + + /** + * @see KafkaConsumer#endOffsets(Collection) + */ + Map endOffsets(Collection partitions); + + /** + * @see KafkaConsumer#endOffsets(Collection, Duration) + */ + Map endOffsets(Collection partitions, Duration timeout); + + /** + * @see KafkaConsumer#currentLag(TopicPartition) + */ + OptionalLong currentLag(TopicPartition topicPartition); + + /** + * @see KafkaConsumer#groupMetadata() + */ + ConsumerGroupMetadata groupMetadata(); + + /** + * @see KafkaConsumer#enforceRebalance() + */ + void enforceRebalance(); + + /** + * @see KafkaConsumer#close() + */ + void close(); + + /** + * @see KafkaConsumer#close(Duration) + */ + void close(Duration timeout); + + /** + * @see KafkaConsumer#wakeup() + */ + void wakeup(); + +} diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/ConsumerConfig.java b/clients/src/main/java/org/apache/kafka/clients/consumer/ConsumerConfig.java new file 
mode 100644 index 0000000..ca24c28 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/consumer/ConsumerConfig.java @@ -0,0 +1,649 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients.consumer; + +import org.apache.kafka.clients.ClientDnsLookup; +import org.apache.kafka.clients.CommonClientConfigs; +import org.apache.kafka.common.IsolationLevel; +import org.apache.kafka.common.config.AbstractConfig; +import org.apache.kafka.common.config.ConfigDef; +import org.apache.kafka.common.config.ConfigDef.Importance; +import org.apache.kafka.common.config.ConfigDef.Type; +import org.apache.kafka.common.config.SecurityConfig; +import org.apache.kafka.common.errors.InvalidConfigurationException; +import org.apache.kafka.common.metrics.Sensor; +import org.apache.kafka.common.requests.JoinGroupRequest; +import org.apache.kafka.common.serialization.Deserializer; + +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Optional; +import java.util.Properties; +import java.util.Set; +import java.util.concurrent.atomic.AtomicInteger; + +import static org.apache.kafka.clients.consumer.CooperativeStickyAssignor.COOPERATIVE_STICKY_ASSIGNOR_NAME; +import static org.apache.kafka.clients.consumer.RangeAssignor.RANGE_ASSIGNOR_NAME; +import static org.apache.kafka.clients.consumer.RoundRobinAssignor.ROUNDROBIN_ASSIGNOR_NAME; +import static org.apache.kafka.clients.consumer.StickyAssignor.STICKY_ASSIGNOR_NAME; +import static org.apache.kafka.common.config.ConfigDef.Range.atLeast; +import static org.apache.kafka.common.config.ConfigDef.ValidString.in; + +/** + * The consumer configuration keys + */ +public class ConsumerConfig extends AbstractConfig { + private static final ConfigDef CONFIG; + + // a list contains all the assignor names that only assign subscribed topics to consumer. Should be updated when new assignor added. + // This is to help optimize ConsumerCoordinator#performAssignment method + public static final List ASSIGN_FROM_SUBSCRIBED_ASSIGNORS = + Collections.unmodifiableList(Arrays.asList( + RANGE_ASSIGNOR_NAME, + ROUNDROBIN_ASSIGNOR_NAME, + STICKY_ASSIGNOR_NAME, + COOPERATIVE_STICKY_ASSIGNOR_NAME + )); + + /* + * NOTE: DO NOT CHANGE EITHER CONFIG STRINGS OR THEIR JAVA VARIABLE NAMES AS + * THESE ARE PART OF THE PUBLIC API AND CHANGE WILL BREAK USER CODE. 
+ */ + + /** + * group.id + */ + public static final String GROUP_ID_CONFIG = CommonClientConfigs.GROUP_ID_CONFIG; + private static final String GROUP_ID_DOC = CommonClientConfigs.GROUP_ID_DOC; + + /** + * group.instance.id + */ + public static final String GROUP_INSTANCE_ID_CONFIG = CommonClientConfigs.GROUP_INSTANCE_ID_CONFIG; + private static final String GROUP_INSTANCE_ID_DOC = CommonClientConfigs.GROUP_INSTANCE_ID_DOC; + + /** max.poll.records */ + public static final String MAX_POLL_RECORDS_CONFIG = "max.poll.records"; + private static final String MAX_POLL_RECORDS_DOC = "The maximum number of records returned in a single call to poll()." + + " Note, that " + MAX_POLL_RECORDS_CONFIG + " does not impact the underlying fetching behavior." + + " The consumer will cache the records from each fetch request and returns them incrementally from each poll."; + + /** max.poll.interval.ms */ + public static final String MAX_POLL_INTERVAL_MS_CONFIG = CommonClientConfigs.MAX_POLL_INTERVAL_MS_CONFIG; + private static final String MAX_POLL_INTERVAL_MS_DOC = CommonClientConfigs.MAX_POLL_INTERVAL_MS_DOC; + /** + * session.timeout.ms + */ + public static final String SESSION_TIMEOUT_MS_CONFIG = CommonClientConfigs.SESSION_TIMEOUT_MS_CONFIG; + private static final String SESSION_TIMEOUT_MS_DOC = CommonClientConfigs.SESSION_TIMEOUT_MS_DOC; + + /** + * heartbeat.interval.ms + */ + public static final String HEARTBEAT_INTERVAL_MS_CONFIG = CommonClientConfigs.HEARTBEAT_INTERVAL_MS_CONFIG; + private static final String HEARTBEAT_INTERVAL_MS_DOC = CommonClientConfigs.HEARTBEAT_INTERVAL_MS_DOC; + + /** + * bootstrap.servers + */ + public static final String BOOTSTRAP_SERVERS_CONFIG = CommonClientConfigs.BOOTSTRAP_SERVERS_CONFIG; + + /** client.dns.lookup */ + public static final String CLIENT_DNS_LOOKUP_CONFIG = CommonClientConfigs.CLIENT_DNS_LOOKUP_CONFIG; + + /** + * enable.auto.commit + */ + public static final String ENABLE_AUTO_COMMIT_CONFIG = "enable.auto.commit"; + private static final String ENABLE_AUTO_COMMIT_DOC = "If true the consumer's offset will be periodically committed in the background."; + + /** + * auto.commit.interval.ms + */ + public static final String AUTO_COMMIT_INTERVAL_MS_CONFIG = "auto.commit.interval.ms"; + private static final String AUTO_COMMIT_INTERVAL_MS_DOC = "The frequency in milliseconds that the consumer offsets are auto-committed to Kafka if enable.auto.commit is set to true."; + + /** + * partition.assignment.strategy + */ + public static final String PARTITION_ASSIGNMENT_STRATEGY_CONFIG = "partition.assignment.strategy"; + private static final String PARTITION_ASSIGNMENT_STRATEGY_DOC = "A list of class names or class types, " + + "ordered by preference, of supported partition assignment strategies that the client will use to distribute " + + "partition ownership amongst consumer instances when group management is used. Available options are:" + + "
    " + + "
  • org.apache.kafka.clients.consumer.RangeAssignor: Assigns partitions on a per-topic basis.
  • org.apache.kafka.clients.consumer.RoundRobinAssignor: Assigns partitions to consumers in a round-robin fashion.
  • org.apache.kafka.clients.consumer.StickyAssignor: Guarantees an assignment that is maximally balanced while preserving as many existing partition assignments as possible.
  • org.apache.kafka.clients.consumer.CooperativeStickyAssignor: Follows the same StickyAssignor logic, but allows for cooperative rebalancing.
The default assignor is [RangeAssignor, CooperativeStickyAssignor], which will use the RangeAssignor by default, but allows upgrading to the CooperativeStickyAssignor with just a single rolling bounce that removes the RangeAssignor from the list.
Implementing the org.apache.kafka.clients.consumer.ConsumerPartitionAssignor interface allows you to plug in a custom assignment strategy.
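// Illustrative sketch (not from the Kafka sources): the rough shape of a custom assignment
// strategy that could be listed in partition.assignment.strategy. The class name and the
// deliberately naive "give everything to the first member" logic are made up for illustration;
// a real assignor would balance partitions across members.
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.kafka.clients.consumer.ConsumerPartitionAssignor;
import org.apache.kafka.common.Cluster;
import org.apache.kafka.common.PartitionInfo;
import org.apache.kafka.common.TopicPartition;

public class FirstMemberAssignor implements ConsumerPartitionAssignor {

    @Override
    public GroupAssignment assign(Cluster metadata, GroupSubscription groupSubscription) {
        Map<String, Subscription> subscriptions = groupSubscription.groupSubscription();

        // collect every partition of every subscribed topic exactly once
        Set<String> topics = new LinkedHashSet<>();
        for (Subscription subscription : subscriptions.values())
            topics.addAll(subscription.topics());
        List<TopicPartition> allPartitions = new ArrayList<>();
        for (String topic : topics)
            for (PartitionInfo info : metadata.partitionsForTopic(topic))
                allPartitions.add(new TopicPartition(info.topic(), info.partition()));

        // hand everything to the first member; everyone else gets an empty assignment
        String first = subscriptions.keySet().iterator().next();
        Map<String, Assignment> assignments = new HashMap<>();
        for (String memberId : subscriptions.keySet())
            assignments.put(memberId, new Assignment(
                    memberId.equals(first) ? allPartitions : Collections.<TopicPartition>emptyList()));
        return new GroupAssignment(assignments);
    }

    @Override
    public String name() {
        return "first-member"; // must be unique among the configured assignors
    }

    // supportedProtocols() is not overridden: a free-for-all reassignment like this one
    // only respects the default EAGER rebalance protocol.
}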
"; + + /** + * auto.offset.reset + */ + public static final String AUTO_OFFSET_RESET_CONFIG = "auto.offset.reset"; + public static final String AUTO_OFFSET_RESET_DOC = "What to do when there is no initial offset in Kafka or if the current offset does not exist any more on the server (e.g. because that data has been deleted):
  • earliest: automatically reset the offset to the earliest offset
  • latest: automatically reset the offset to the latest offset
  • none: throw exception to the consumer if no previous offset is found for the consumer's group
  • anything else: throw exception to the consumer.
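// Illustrative sketch (not from the Kafka sources): pulling several of the keys defined above
// into a minimal consumer. The bootstrap address, group id and topic name are placeholders, and
// the usual client imports (ConsumerConfig, KafkaConsumer, ConsumerRecords,
// CooperativeStickyAssignor, StringDeserializer, Duration, Collections, Properties) are assumed.
Properties props = new Properties();
props.put(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9092");
props.put(ConsumerConfig.GROUP_ID_CONFIG, "example-group");
props.put(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "earliest");   // no committed offset -> start from the beginning
props.put(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG, "false");     // commit explicitly below
props.put(ConsumerConfig.PARTITION_ASSIGNMENT_STRATEGY_CONFIG, CooperativeStickyAssignor.class.getName());
props.put(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, StringDeserializer.class.getName());
props.put(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, StringDeserializer.class.getName());

try (KafkaConsumer<String, String> consumer = new KafkaConsumer<>(props)) {
    consumer.subscribe(Collections.singletonList("example-topic"));
    ConsumerRecords<String, String> records = consumer.poll(Duration.ofMillis(500));
    records.forEach(record -> System.out.println(record.key() + " -> " + record.value()));
    consumer.commitSync();
}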
"; + + /** + * fetch.min.bytes + */ + public static final String FETCH_MIN_BYTES_CONFIG = "fetch.min.bytes"; + private static final String FETCH_MIN_BYTES_DOC = "The minimum amount of data the server should return for a fetch request. If insufficient data is available the request will wait for that much data to accumulate before answering the request. The default setting of 1 byte means that fetch requests are answered as soon as a single byte of data is available or the fetch request times out waiting for data to arrive. Setting this to something greater than 1 will cause the server to wait for larger amounts of data to accumulate which can improve server throughput a bit at the cost of some additional latency."; + + /** + * fetch.max.bytes + */ + public static final String FETCH_MAX_BYTES_CONFIG = "fetch.max.bytes"; + private static final String FETCH_MAX_BYTES_DOC = "The maximum amount of data the server should return for a fetch request. " + + "Records are fetched in batches by the consumer, and if the first record batch in the first non-empty partition of the fetch is larger than " + + "this value, the record batch will still be returned to ensure that the consumer can make progress. As such, this is not a absolute maximum. " + + "The maximum record batch size accepted by the broker is defined via message.max.bytes (broker config) or " + + "max.message.bytes (topic config). Note that the consumer performs multiple fetches in parallel."; + public static final int DEFAULT_FETCH_MAX_BYTES = 50 * 1024 * 1024; + + /** + * fetch.max.wait.ms + */ + public static final String FETCH_MAX_WAIT_MS_CONFIG = "fetch.max.wait.ms"; + private static final String FETCH_MAX_WAIT_MS_DOC = "The maximum amount of time the server will block before answering the fetch request if there isn't sufficient data to immediately satisfy the requirement given by fetch.min.bytes."; + + /** metadata.max.age.ms */ + public static final String METADATA_MAX_AGE_CONFIG = CommonClientConfigs.METADATA_MAX_AGE_CONFIG; + + /** + * max.partition.fetch.bytes + */ + public static final String MAX_PARTITION_FETCH_BYTES_CONFIG = "max.partition.fetch.bytes"; + private static final String MAX_PARTITION_FETCH_BYTES_DOC = "The maximum amount of data per-partition the server " + + "will return. Records are fetched in batches by the consumer. If the first record batch in the first non-empty " + + "partition of the fetch is larger than this limit, the " + + "batch will still be returned to ensure that the consumer can make progress. The maximum record batch size " + + "accepted by the broker is defined via message.max.bytes (broker config) or " + + "max.message.bytes (topic config). 
See " + FETCH_MAX_BYTES_CONFIG + " for limiting the consumer request size."; + public static final int DEFAULT_MAX_PARTITION_FETCH_BYTES = 1 * 1024 * 1024; + + /** send.buffer.bytes */ + public static final String SEND_BUFFER_CONFIG = CommonClientConfigs.SEND_BUFFER_CONFIG; + + /** receive.buffer.bytes */ + public static final String RECEIVE_BUFFER_CONFIG = CommonClientConfigs.RECEIVE_BUFFER_CONFIG; + + /** + * client.id + */ + public static final String CLIENT_ID_CONFIG = CommonClientConfigs.CLIENT_ID_CONFIG; + + /** + * client.rack + */ + public static final String CLIENT_RACK_CONFIG = CommonClientConfigs.CLIENT_RACK_CONFIG; + + /** + * reconnect.backoff.ms + */ + public static final String RECONNECT_BACKOFF_MS_CONFIG = CommonClientConfigs.RECONNECT_BACKOFF_MS_CONFIG; + + /** + * reconnect.backoff.max.ms + */ + public static final String RECONNECT_BACKOFF_MAX_MS_CONFIG = CommonClientConfigs.RECONNECT_BACKOFF_MAX_MS_CONFIG; + + /** + * retry.backoff.ms + */ + public static final String RETRY_BACKOFF_MS_CONFIG = CommonClientConfigs.RETRY_BACKOFF_MS_CONFIG; + + /** + * metrics.sample.window.ms + */ + public static final String METRICS_SAMPLE_WINDOW_MS_CONFIG = CommonClientConfigs.METRICS_SAMPLE_WINDOW_MS_CONFIG; + + /** + * metrics.num.samples + */ + public static final String METRICS_NUM_SAMPLES_CONFIG = CommonClientConfigs.METRICS_NUM_SAMPLES_CONFIG; + + /** + * metrics.log.level + */ + public static final String METRICS_RECORDING_LEVEL_CONFIG = CommonClientConfigs.METRICS_RECORDING_LEVEL_CONFIG; + + /** + * metric.reporters + */ + public static final String METRIC_REPORTER_CLASSES_CONFIG = CommonClientConfigs.METRIC_REPORTER_CLASSES_CONFIG; + + /** + * check.crcs + */ + public static final String CHECK_CRCS_CONFIG = "check.crcs"; + private static final String CHECK_CRCS_DOC = "Automatically check the CRC32 of the records consumed. This ensures no on-the-wire or on-disk corruption to the messages occurred. 
This check adds some overhead, so it may be disabled in cases seeking extreme performance."; + + /** key.deserializer */ + public static final String KEY_DESERIALIZER_CLASS_CONFIG = "key.deserializer"; + public static final String KEY_DESERIALIZER_CLASS_DOC = "Deserializer class for key that implements the org.apache.kafka.common.serialization.Deserializer interface."; + + /** value.deserializer */ + public static final String VALUE_DESERIALIZER_CLASS_CONFIG = "value.deserializer"; + public static final String VALUE_DESERIALIZER_CLASS_DOC = "Deserializer class for value that implements the org.apache.kafka.common.serialization.Deserializer interface."; + + /** socket.connection.setup.timeout.ms */ + public static final String SOCKET_CONNECTION_SETUP_TIMEOUT_MS_CONFIG = CommonClientConfigs.SOCKET_CONNECTION_SETUP_TIMEOUT_MS_CONFIG; + + /** socket.connection.setup.timeout.max.ms */ + public static final String SOCKET_CONNECTION_SETUP_TIMEOUT_MAX_MS_CONFIG = CommonClientConfigs.SOCKET_CONNECTION_SETUP_TIMEOUT_MAX_MS_CONFIG; + + /** connections.max.idle.ms */ + public static final String CONNECTIONS_MAX_IDLE_MS_CONFIG = CommonClientConfigs.CONNECTIONS_MAX_IDLE_MS_CONFIG; + + /** request.timeout.ms */ + public static final String REQUEST_TIMEOUT_MS_CONFIG = CommonClientConfigs.REQUEST_TIMEOUT_MS_CONFIG; + private static final String REQUEST_TIMEOUT_MS_DOC = CommonClientConfigs.REQUEST_TIMEOUT_MS_DOC; + + /** default.api.timeout.ms */ + public static final String DEFAULT_API_TIMEOUT_MS_CONFIG = CommonClientConfigs.DEFAULT_API_TIMEOUT_MS_CONFIG; + + /** interceptor.classes */ + public static final String INTERCEPTOR_CLASSES_CONFIG = "interceptor.classes"; + public static final String INTERCEPTOR_CLASSES_DOC = "A list of classes to use as interceptors. " + + "Implementing the org.apache.kafka.clients.consumer.ConsumerInterceptor interface allows you to intercept (and possibly mutate) records " + + "received by the consumer. By default, there are no interceptors."; + + + /** exclude.internal.topics */ + public static final String EXCLUDE_INTERNAL_TOPICS_CONFIG = "exclude.internal.topics"; + private static final String EXCLUDE_INTERNAL_TOPICS_DOC = "Whether internal topics matching a subscribed pattern should " + + "be excluded from the subscription. It is always possible to explicitly subscribe to an internal topic."; + public static final boolean DEFAULT_EXCLUDE_INTERNAL_TOPICS = true; + + /** + * internal.leave.group.on.close + * Whether or not the consumer should leave the group on close. If set to false then a rebalance + * won't occur until session.timeout.ms expires. + * + *

+ * Note: this is an internal configuration and could be changed in the future in a backward incompatible way + * + */ + static final String LEAVE_GROUP_ON_CLOSE_CONFIG = "internal.leave.group.on.close"; + + /** + * internal.throw.on.fetch.stable.offset.unsupported + * Whether or not the consumer should throw when the new stable offset feature is supported. + * If set to true then the client shall crash upon hitting it. + * The purpose of this flag is to prevent unexpected broker downgrade which makes + * the offset fetch protection against pending commit invalid. The safest approach + * is to fail fast to avoid introducing correctness issue. + * + *

+ * Note: this is an internal configuration and could be changed in the future in a backward incompatible way + * + */ + static final String THROW_ON_FETCH_STABLE_OFFSET_UNSUPPORTED = "internal.throw.on.fetch.stable.offset.unsupported"; + + /** isolation.level */ + public static final String ISOLATION_LEVEL_CONFIG = "isolation.level"; + public static final String ISOLATION_LEVEL_DOC = "Controls how to read messages written transactionally. If set to read_committed, consumer.poll() will only return" + + " transactional messages which have been committed. If set to read_uncommitted (the default), consumer.poll() will return all messages, even transactional messages" + + " which have been aborted. Non-transactional messages will be returned unconditionally in either mode.

Messages will always be returned in offset order. Hence, in " + + " read_committed mode, consumer.poll() will only return messages up to the last stable offset (LSO), which is the one less than the offset of the first open transaction." + + " In particular any messages appearing after messages belonging to ongoing transactions will be withheld until the relevant transaction has been completed. As a result, read_committed" + + " consumers will not be able to read up to the high watermark when there are in flight transactions.
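// Illustrative sketch (not from the Kafka sources): opting into read_committed so that poll()
// only returns records from committed transactions; "props" is a consumer configuration like
// the one sketched earlier.
props.put(ConsumerConfig.ISOLATION_LEVEL_CONFIG, "read_committed");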

Further, when in read_committed the seekToEnd method will" + + " return the LSO"; + + public static final String DEFAULT_ISOLATION_LEVEL = IsolationLevel.READ_UNCOMMITTED.toString().toLowerCase(Locale.ROOT); + + /** allow.auto.create.topics */ + public static final String ALLOW_AUTO_CREATE_TOPICS_CONFIG = "allow.auto.create.topics"; + private static final String ALLOW_AUTO_CREATE_TOPICS_DOC = "Allow automatic topic creation on the broker when" + + " subscribing to or assigning a topic. A topic being subscribed to will be automatically created only if the" + + " broker allows for it using `auto.create.topics.enable` broker configuration. This configuration must" + + " be set to `false` when using brokers older than 0.11.0"; + public static final boolean DEFAULT_ALLOW_AUTO_CREATE_TOPICS = true; + + /** + * security.providers + */ + public static final String SECURITY_PROVIDERS_CONFIG = SecurityConfig.SECURITY_PROVIDERS_CONFIG; + private static final String SECURITY_PROVIDERS_DOC = SecurityConfig.SECURITY_PROVIDERS_DOC; + + private static final AtomicInteger CONSUMER_CLIENT_ID_SEQUENCE = new AtomicInteger(1); + + static { + CONFIG = new ConfigDef().define(BOOTSTRAP_SERVERS_CONFIG, + Type.LIST, + Collections.emptyList(), + new ConfigDef.NonNullValidator(), + Importance.HIGH, + CommonClientConfigs.BOOTSTRAP_SERVERS_DOC) + .define(CLIENT_DNS_LOOKUP_CONFIG, + Type.STRING, + ClientDnsLookup.USE_ALL_DNS_IPS.toString(), + in(ClientDnsLookup.USE_ALL_DNS_IPS.toString(), + ClientDnsLookup.RESOLVE_CANONICAL_BOOTSTRAP_SERVERS_ONLY.toString()), + Importance.MEDIUM, + CommonClientConfigs.CLIENT_DNS_LOOKUP_DOC) + .define(GROUP_ID_CONFIG, Type.STRING, null, Importance.HIGH, GROUP_ID_DOC) + .define(GROUP_INSTANCE_ID_CONFIG, + Type.STRING, + null, + Importance.MEDIUM, + GROUP_INSTANCE_ID_DOC) + .define(SESSION_TIMEOUT_MS_CONFIG, + Type.INT, + 45000, + Importance.HIGH, + SESSION_TIMEOUT_MS_DOC) + .define(HEARTBEAT_INTERVAL_MS_CONFIG, + Type.INT, + 3000, + Importance.HIGH, + HEARTBEAT_INTERVAL_MS_DOC) + .define(PARTITION_ASSIGNMENT_STRATEGY_CONFIG, + Type.LIST, + Arrays.asList(RangeAssignor.class, CooperativeStickyAssignor.class), + new ConfigDef.NonNullValidator(), + Importance.MEDIUM, + PARTITION_ASSIGNMENT_STRATEGY_DOC) + .define(METADATA_MAX_AGE_CONFIG, + Type.LONG, + 5 * 60 * 1000, + atLeast(0), + Importance.LOW, + CommonClientConfigs.METADATA_MAX_AGE_DOC) + .define(ENABLE_AUTO_COMMIT_CONFIG, + Type.BOOLEAN, + true, + Importance.MEDIUM, + ENABLE_AUTO_COMMIT_DOC) + .define(AUTO_COMMIT_INTERVAL_MS_CONFIG, + Type.INT, + 5000, + atLeast(0), + Importance.LOW, + AUTO_COMMIT_INTERVAL_MS_DOC) + .define(CLIENT_ID_CONFIG, + Type.STRING, + "", + Importance.LOW, + CommonClientConfigs.CLIENT_ID_DOC) + .define(CLIENT_RACK_CONFIG, + Type.STRING, + "", + Importance.LOW, + CommonClientConfigs.CLIENT_RACK_DOC) + .define(MAX_PARTITION_FETCH_BYTES_CONFIG, + Type.INT, + DEFAULT_MAX_PARTITION_FETCH_BYTES, + atLeast(0), + Importance.HIGH, + MAX_PARTITION_FETCH_BYTES_DOC) + .define(SEND_BUFFER_CONFIG, + Type.INT, + 128 * 1024, + atLeast(CommonClientConfigs.SEND_BUFFER_LOWER_BOUND), + Importance.MEDIUM, + CommonClientConfigs.SEND_BUFFER_DOC) + .define(RECEIVE_BUFFER_CONFIG, + Type.INT, + 64 * 1024, + atLeast(CommonClientConfigs.RECEIVE_BUFFER_LOWER_BOUND), + Importance.MEDIUM, + CommonClientConfigs.RECEIVE_BUFFER_DOC) + .define(FETCH_MIN_BYTES_CONFIG, + Type.INT, + 1, + atLeast(0), + Importance.HIGH, + FETCH_MIN_BYTES_DOC) + .define(FETCH_MAX_BYTES_CONFIG, + Type.INT, + DEFAULT_FETCH_MAX_BYTES, + atLeast(0), + Importance.MEDIUM, + 
FETCH_MAX_BYTES_DOC) + .define(FETCH_MAX_WAIT_MS_CONFIG, + Type.INT, + 500, + atLeast(0), + Importance.LOW, + FETCH_MAX_WAIT_MS_DOC) + .define(RECONNECT_BACKOFF_MS_CONFIG, + Type.LONG, + 50L, + atLeast(0L), + Importance.LOW, + CommonClientConfigs.RECONNECT_BACKOFF_MS_DOC) + .define(RECONNECT_BACKOFF_MAX_MS_CONFIG, + Type.LONG, + 1000L, + atLeast(0L), + Importance.LOW, + CommonClientConfigs.RECONNECT_BACKOFF_MAX_MS_DOC) + .define(RETRY_BACKOFF_MS_CONFIG, + Type.LONG, + 100L, + atLeast(0L), + Importance.LOW, + CommonClientConfigs.RETRY_BACKOFF_MS_DOC) + .define(AUTO_OFFSET_RESET_CONFIG, + Type.STRING, + "latest", + in("latest", "earliest", "none"), + Importance.MEDIUM, + AUTO_OFFSET_RESET_DOC) + .define(CHECK_CRCS_CONFIG, + Type.BOOLEAN, + true, + Importance.LOW, + CHECK_CRCS_DOC) + .define(METRICS_SAMPLE_WINDOW_MS_CONFIG, + Type.LONG, + 30000, + atLeast(0), + Importance.LOW, + CommonClientConfigs.METRICS_SAMPLE_WINDOW_MS_DOC) + .define(METRICS_NUM_SAMPLES_CONFIG, + Type.INT, + 2, + atLeast(1), + Importance.LOW, + CommonClientConfigs.METRICS_NUM_SAMPLES_DOC) + .define(METRICS_RECORDING_LEVEL_CONFIG, + Type.STRING, + Sensor.RecordingLevel.INFO.toString(), + in(Sensor.RecordingLevel.INFO.toString(), Sensor.RecordingLevel.DEBUG.toString(), Sensor.RecordingLevel.TRACE.toString()), + Importance.LOW, + CommonClientConfigs.METRICS_RECORDING_LEVEL_DOC) + .define(METRIC_REPORTER_CLASSES_CONFIG, + Type.LIST, + Collections.emptyList(), + new ConfigDef.NonNullValidator(), + Importance.LOW, + CommonClientConfigs.METRIC_REPORTER_CLASSES_DOC) + .define(KEY_DESERIALIZER_CLASS_CONFIG, + Type.CLASS, + Importance.HIGH, + KEY_DESERIALIZER_CLASS_DOC) + .define(VALUE_DESERIALIZER_CLASS_CONFIG, + Type.CLASS, + Importance.HIGH, + VALUE_DESERIALIZER_CLASS_DOC) + .define(REQUEST_TIMEOUT_MS_CONFIG, + Type.INT, + 30000, + atLeast(0), + Importance.MEDIUM, + REQUEST_TIMEOUT_MS_DOC) + .define(DEFAULT_API_TIMEOUT_MS_CONFIG, + Type.INT, + 60 * 1000, + atLeast(0), + Importance.MEDIUM, + CommonClientConfigs.DEFAULT_API_TIMEOUT_MS_DOC) + .define(SOCKET_CONNECTION_SETUP_TIMEOUT_MS_CONFIG, + Type.LONG, + CommonClientConfigs.DEFAULT_SOCKET_CONNECTION_SETUP_TIMEOUT_MS, + Importance.MEDIUM, + CommonClientConfigs.SOCKET_CONNECTION_SETUP_TIMEOUT_MS_DOC) + .define(SOCKET_CONNECTION_SETUP_TIMEOUT_MAX_MS_CONFIG, + Type.LONG, + CommonClientConfigs.DEFAULT_SOCKET_CONNECTION_SETUP_TIMEOUT_MAX_MS, + Importance.MEDIUM, + CommonClientConfigs.SOCKET_CONNECTION_SETUP_TIMEOUT_MAX_MS_DOC) + /* default is set to be a bit lower than the server default (10 min), to avoid both client and server closing connection at same time */ + .define(CONNECTIONS_MAX_IDLE_MS_CONFIG, + Type.LONG, + 9 * 60 * 1000, + Importance.MEDIUM, + CommonClientConfigs.CONNECTIONS_MAX_IDLE_MS_DOC) + .define(INTERCEPTOR_CLASSES_CONFIG, + Type.LIST, + Collections.emptyList(), + new ConfigDef.NonNullValidator(), + Importance.LOW, + INTERCEPTOR_CLASSES_DOC) + .define(MAX_POLL_RECORDS_CONFIG, + Type.INT, + 500, + atLeast(1), + Importance.MEDIUM, + MAX_POLL_RECORDS_DOC) + .define(MAX_POLL_INTERVAL_MS_CONFIG, + Type.INT, + 300000, + atLeast(1), + Importance.MEDIUM, + MAX_POLL_INTERVAL_MS_DOC) + .define(EXCLUDE_INTERNAL_TOPICS_CONFIG, + Type.BOOLEAN, + DEFAULT_EXCLUDE_INTERNAL_TOPICS, + Importance.MEDIUM, + EXCLUDE_INTERNAL_TOPICS_DOC) + .defineInternal(LEAVE_GROUP_ON_CLOSE_CONFIG, + Type.BOOLEAN, + true, + Importance.LOW) + .defineInternal(THROW_ON_FETCH_STABLE_OFFSET_UNSUPPORTED, + Type.BOOLEAN, + false, + Importance.LOW) + .define(ISOLATION_LEVEL_CONFIG, + Type.STRING, + 
DEFAULT_ISOLATION_LEVEL, + in(IsolationLevel.READ_COMMITTED.toString().toLowerCase(Locale.ROOT), IsolationLevel.READ_UNCOMMITTED.toString().toLowerCase(Locale.ROOT)), + Importance.MEDIUM, + ISOLATION_LEVEL_DOC) + .define(ALLOW_AUTO_CREATE_TOPICS_CONFIG, + Type.BOOLEAN, + DEFAULT_ALLOW_AUTO_CREATE_TOPICS, + Importance.MEDIUM, + ALLOW_AUTO_CREATE_TOPICS_DOC) + // security support + .define(SECURITY_PROVIDERS_CONFIG, + Type.STRING, + null, + Importance.LOW, + SECURITY_PROVIDERS_DOC) + .define(CommonClientConfigs.SECURITY_PROTOCOL_CONFIG, + Type.STRING, + CommonClientConfigs.DEFAULT_SECURITY_PROTOCOL, + Importance.MEDIUM, + CommonClientConfigs.SECURITY_PROTOCOL_DOC) + .withClientSslSupport() + .withClientSaslSupport(); + } + + @Override + protected Map postProcessParsedConfig(final Map parsedValues) { + Map refinedConfigs = CommonClientConfigs.postProcessReconnectBackoffConfigs(this, parsedValues); + maybeOverrideClientId(refinedConfigs); + return refinedConfigs; + } + + private void maybeOverrideClientId(Map configs) { + final String clientId = this.getString(CLIENT_ID_CONFIG); + if (clientId == null || clientId.isEmpty()) { + final String groupId = this.getString(GROUP_ID_CONFIG); + String groupInstanceId = this.getString(GROUP_INSTANCE_ID_CONFIG); + if (groupInstanceId != null) + JoinGroupRequest.validateGroupInstanceId(groupInstanceId); + + String groupInstanceIdPart = groupInstanceId != null ? groupInstanceId : CONSUMER_CLIENT_ID_SEQUENCE.getAndIncrement() + ""; + String generatedClientId = String.format("consumer-%s-%s", groupId, groupInstanceIdPart); + configs.put(CLIENT_ID_CONFIG, generatedClientId); + } + } + + protected static Map appendDeserializerToConfig(Map configs, + Deserializer keyDeserializer, + Deserializer valueDeserializer) { + Map newConfigs = new HashMap<>(configs); + if (keyDeserializer != null) + newConfigs.put(KEY_DESERIALIZER_CLASS_CONFIG, keyDeserializer.getClass()); + if (valueDeserializer != null) + newConfigs.put(VALUE_DESERIALIZER_CLASS_CONFIG, valueDeserializer.getClass()); + return newConfigs; + } + + boolean maybeOverrideEnableAutoCommit() { + Optional groupId = Optional.ofNullable(getString(CommonClientConfigs.GROUP_ID_CONFIG)); + boolean enableAutoCommit = getBoolean(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG); + if (!groupId.isPresent()) { // overwrite in case of default group id where the config is not explicitly provided + if (!originals().containsKey(ENABLE_AUTO_COMMIT_CONFIG)) { + enableAutoCommit = false; + } else if (enableAutoCommit) { + throw new InvalidConfigurationException(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG + " cannot be set to true when default group id (null) is used."); + } + } + return enableAutoCommit; + } + + public ConsumerConfig(Properties props) { + super(CONFIG, props); + } + + public ConsumerConfig(Map props) { + super(CONFIG, props); + } + + protected ConsumerConfig(Map props, boolean doLog) { + super(CONFIG, props, doLog); + } + + public static Set configNames() { + return CONFIG.names(); + } + + public static ConfigDef configDef() { + return new ConfigDef(CONFIG); + } + + public static void main(String[] args) { + System.out.println(CONFIG.toHtml(4, config -> "consumerconfigs_" + config)); + } + +} diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/ConsumerGroupMetadata.java b/clients/src/main/java/org/apache/kafka/clients/consumer/ConsumerGroupMetadata.java new file mode 100644 index 0000000..f9b7b28 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/consumer/ConsumerGroupMetadata.java @@ -0,0 
+1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients.consumer; + +import org.apache.kafka.common.requests.JoinGroupRequest; + +import java.util.Objects; +import java.util.Optional; + +/** + * A metadata struct containing the consumer group information. + * Note: Any change to this class is considered public and requires a KIP. + */ +public class ConsumerGroupMetadata { + final private String groupId; + final private int generationId; + final private String memberId; + final private Optional groupInstanceId; + + public ConsumerGroupMetadata(String groupId, + int generationId, + String memberId, + Optional groupInstanceId) { + this.groupId = Objects.requireNonNull(groupId, "group.id can't be null"); + this.generationId = generationId; + this.memberId = Objects.requireNonNull(memberId, "member.id can't be null"); + this.groupInstanceId = Objects.requireNonNull(groupInstanceId, "group.instance.id can't be null"); + } + + public ConsumerGroupMetadata(String groupId) { + this(groupId, + JoinGroupRequest.UNKNOWN_GENERATION_ID, + JoinGroupRequest.UNKNOWN_MEMBER_ID, + Optional.empty()); + } + + public String groupId() { + return groupId; + } + + public int generationId() { + return generationId; + } + + public String memberId() { + return memberId; + } + + public Optional groupInstanceId() { + return groupInstanceId; + } + + @Override + public String toString() { + return String.format("GroupMetadata(groupId = %s, generationId = %d, memberId = %s, groupInstanceId = %s)", + groupId, + generationId, + memberId, + groupInstanceId.orElse("")); + } + + @Override + public boolean equals(final Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + final ConsumerGroupMetadata that = (ConsumerGroupMetadata) o; + return generationId == that.generationId && + Objects.equals(groupId, that.groupId) && + Objects.equals(memberId, that.memberId) && + Objects.equals(groupInstanceId, that.groupInstanceId); + } + + @Override + public int hashCode() { + return Objects.hash(groupId, generationId, memberId, groupInstanceId); + } + +} diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/ConsumerInterceptor.java b/clients/src/main/java/org/apache/kafka/clients/consumer/ConsumerInterceptor.java new file mode 100644 index 0000000..c04afcc --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/consumer/ConsumerInterceptor.java @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients.consumer; + + +import org.apache.kafka.common.Configurable; +import org.apache.kafka.common.TopicPartition; + +import java.util.Map; + +/** + * A plugin interface that allows you to intercept (and possibly mutate) records received by the consumer. A primary use-case + * is for third-party components to hook into the consumer applications for custom monitoring, logging, etc. + * + *

+ * This class will get consumer config properties via configure() method, including clientId assigned + * by KafkaConsumer if not specified in the consumer config. The interceptor implementation needs to be aware that it will be + * sharing consumer config namespace with other interceptors and serializers, and ensure that there are no conflicts. + *

+ * Exceptions thrown by ConsumerInterceptor methods will be caught, logged, but not propagated further. As a result, if + * the user configures the interceptor with the wrong key and value type parameters, the consumer will not throw an exception, + * just log the errors. + *

+ * ConsumerInterceptor callbacks are called from the same thread that invokes + * {@link org.apache.kafka.clients.consumer.KafkaConsumer#poll(java.time.Duration)}. + *
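// Illustrative sketch (not from the Kafka sources): a pass-through interceptor that only counts
// records, following the rules described above. The class name is made up for illustration.
import java.util.Map;
import java.util.concurrent.atomic.AtomicLong;
import org.apache.kafka.clients.consumer.ConsumerInterceptor;
import org.apache.kafka.clients.consumer.ConsumerRecords;
import org.apache.kafka.clients.consumer.OffsetAndMetadata;
import org.apache.kafka.common.TopicPartition;

public class CountingInterceptor<K, V> implements ConsumerInterceptor<K, V> {
    private final AtomicLong consumed = new AtomicLong();

    @Override
    public ConsumerRecords<K, V> onConsume(ConsumerRecords<K, V> records) {
        consumed.addAndGet(records.count());
        return records; // hand the records on unchanged
    }

    @Override
    public void onCommit(Map<TopicPartition, OffsetAndMetadata> offsets) {
        // no-op; any exception thrown here would be caught and logged by the consumer
    }

    @Override
    public void configure(Map<String, ?> configs) {
        // receives the full consumer config, including the (possibly generated) client.id
    }

    @Override
    public void close() {
    }
}
// It would be registered through the interceptor.classes config, e.g.
// props.put(ConsumerConfig.INTERCEPTOR_CLASSES_CONFIG, CountingInterceptor.class.getName());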

+ * Implement {@link org.apache.kafka.common.ClusterResourceListener} to receive cluster metadata once it's available. Please see the class documentation for ClusterResourceListener for more information. + */ +public interface ConsumerInterceptor extends Configurable, AutoCloseable { + + /** + * This is called just before the records are returned by + * {@link org.apache.kafka.clients.consumer.KafkaConsumer#poll(java.time.Duration)} + *

+ * This method is allowed to modify consumer records, in which case the new records will be + * returned. There is no limitation on number of records that could be returned from this + * method. I.e., the interceptor can filter the records or generate new records. + *

+ * Any exception thrown by this method will be caught by the caller, logged, but not propagated to the client. + *

+ * Since the consumer may run multiple interceptors, a particular interceptor's onConsume() callback will be called + * in the order specified by {@link org.apache.kafka.clients.consumer.ConsumerConfig#INTERCEPTOR_CLASSES_CONFIG}. + * The first interceptor in the list gets the consumed records, the following interceptor will be passed the records returned + * by the previous interceptor, and so on. Since interceptors are allowed to modify records, interceptors may potentially get + * the records already modified by other interceptors. However, building a pipeline of mutable interceptors that depend on the output + * of the previous interceptor is discouraged, because of potential side-effects caused by interceptors potentially failing + * to modify the record and throwing an exception. If one of the interceptors in the list throws an exception from onConsume(), + * the exception is caught, logged, and the next interceptor is called with the records returned by the last successful interceptor + * in the list, or otherwise the original consumed records. + * + * @param records records to be consumed by the client or records returned by the previous interceptors in the list. + * @return records that are either modified by the interceptor or same as records passed to this method. + */ + ConsumerRecords onConsume(ConsumerRecords records); + + /** + * This is called when offsets get committed. + *

+ * Any exception thrown by this method will be ignored by the caller. + * + * @param offsets A map of offsets by partition with associated metadata + */ + void onCommit(Map offsets); + + /** + * This is called when interceptor is closed + */ + void close(); +} diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/ConsumerPartitionAssignor.java b/clients/src/main/java/org/apache/kafka/clients/consumer/ConsumerPartitionAssignor.java new file mode 100644 index 0000000..a541b8a --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/consumer/ConsumerPartitionAssignor.java @@ -0,0 +1,307 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients.consumer; + +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.Optional; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Set; +import org.apache.kafka.common.Cluster; +import org.apache.kafka.common.Configurable; +import org.apache.kafka.common.KafkaException; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.utils.Utils; + +/** + * This interface is used to define custom partition assignment for use in + * {@link org.apache.kafka.clients.consumer.KafkaConsumer}. Members of the consumer group subscribe + * to the topics they are interested in and forward their subscriptions to a Kafka broker serving + * as the group coordinator. The coordinator selects one member to perform the group assignment and + * propagates the subscriptions of all members to it. Then {@link #assign(Cluster, GroupSubscription)} is called + * to perform the assignment and the results are forwarded back to each respective members + * + * In some cases, it is useful to forward additional metadata to the assignor in order to make + * assignment decisions. For this, you can override {@link #subscriptionUserData(Set)} and provide custom + * userData in the returned Subscription. For example, to have a rack-aware assignor, an implementation + * can use this user data to forward the rackId belonging to each member. + */ +public interface ConsumerPartitionAssignor { + + /** + * Return serialized data that will be included in the {@link Subscription} sent to the leader + * and can be leveraged in {@link #assign(Cluster, GroupSubscription)} ((e.g. local host/rack information) + * + * @param topics Topics subscribed to through {@link org.apache.kafka.clients.consumer.KafkaConsumer#subscribe(java.util.Collection)} + * and variants + * @return nullable subscription user data + */ + default ByteBuffer subscriptionUserData(Set topics) { + return null; + } + + /** + * Perform the group assignment given the member subscriptions and current cluster metadata. 
+ * @param metadata Current topic/broker metadata known by consumer + * @param groupSubscription Subscriptions from all members including metadata provided through {@link #subscriptionUserData(Set)} + * @return A map from the members to their respective assignments. This should have one entry + * for each member in the input subscription map. + */ + GroupAssignment assign(Cluster metadata, GroupSubscription groupSubscription); + + /** + * Callback which is invoked when a group member receives its assignment from the leader. + * @param assignment The local member's assignment as provided by the leader in {@link #assign(Cluster, GroupSubscription)} + * @param metadata Additional metadata on the consumer (optional) + */ + default void onAssignment(Assignment assignment, ConsumerGroupMetadata metadata) { + } + + /** + * Indicate which rebalance protocol this assignor works with; + * By default it should always work with {@link RebalanceProtocol#EAGER}. + */ + default List supportedProtocols() { + return Collections.singletonList(RebalanceProtocol.EAGER); + } + + /** + * Return the version of the assignor which indicates how the user metadata encodings + * and the assignment algorithm gets evolved. + */ + default short version() { + return (short) 0; + } + + /** + * Unique name for this assignor (e.g. "range" or "roundrobin" or "sticky"). Note, this is not required + * to be the same as the class name specified in {@link ConsumerConfig#PARTITION_ASSIGNMENT_STRATEGY_CONFIG} + * @return non-null unique name + */ + String name(); + + final class Subscription { + private final List topics; + private final ByteBuffer userData; + private final List ownedPartitions; + private Optional groupInstanceId; + + public Subscription(List topics, ByteBuffer userData, List ownedPartitions) { + this.topics = topics; + this.userData = userData; + this.ownedPartitions = ownedPartitions; + this.groupInstanceId = Optional.empty(); + } + + public Subscription(List topics, ByteBuffer userData) { + this(topics, userData, Collections.emptyList()); + } + + public Subscription(List topics) { + this(topics, null, Collections.emptyList()); + } + + public List topics() { + return topics; + } + + public ByteBuffer userData() { + return userData; + } + + public List ownedPartitions() { + return ownedPartitions; + } + + public void setGroupInstanceId(Optional groupInstanceId) { + this.groupInstanceId = groupInstanceId; + } + + public Optional groupInstanceId() { + return groupInstanceId; + } + + @Override + public String toString() { + return "Subscription(" + + "topics=" + topics + + (userData == null ? "" : ", userDataSize=" + userData.remaining()) + + ", ownedPartitions=" + ownedPartitions + + ", groupInstanceId=" + (groupInstanceId.map(String::toString).orElse("null")) + + ")"; + } + } + + final class Assignment { + private List partitions; + private ByteBuffer userData; + + public Assignment(List partitions, ByteBuffer userData) { + this.partitions = partitions; + this.userData = userData; + } + + public Assignment(List partitions) { + this(partitions, null); + } + + public List partitions() { + return partitions; + } + + public ByteBuffer userData() { + return userData; + } + + @Override + public String toString() { + return "Assignment(" + + "partitions=" + partitions + + (userData == null ? 
"" : ", userDataSize=" + userData.remaining()) + + ')'; + } + } + + final class GroupSubscription { + private final Map subscriptions; + + public GroupSubscription(Map subscriptions) { + this.subscriptions = subscriptions; + } + + public Map groupSubscription() { + return subscriptions; + } + + @Override + public String toString() { + return "GroupSubscription(" + + "subscriptions=" + subscriptions + + ")"; + } + } + + final class GroupAssignment { + private final Map assignments; + + public GroupAssignment(Map assignments) { + this.assignments = assignments; + } + + public Map groupAssignment() { + return assignments; + } + + @Override + public String toString() { + return "GroupAssignment(" + + "assignments=" + assignments + + ")"; + } + } + + /** + * The rebalance protocol defines partition assignment and revocation semantics. The purpose is to establish a + * consistent set of rules that all consumers in a group follow in order to transfer ownership of a partition. + * {@link ConsumerPartitionAssignor} implementors can claim supporting one or more rebalance protocols via the + * {@link ConsumerPartitionAssignor#supportedProtocols()}, and it is their responsibility to respect the rules + * of those protocols in their {@link ConsumerPartitionAssignor#assign(Cluster, GroupSubscription)} implementations. + * Failures to follow the rules of the supported protocols would lead to runtime error or undefined behavior. + * + * The {@link RebalanceProtocol#EAGER} rebalance protocol requires a consumer to always revoke all its owned + * partitions before participating in a rebalance event. It therefore allows a complete reshuffling of the assignment. + * + * {@link RebalanceProtocol#COOPERATIVE} rebalance protocol allows a consumer to retain its currently owned + * partitions before participating in a rebalance event. The assignor should not reassign any owned partitions + * immediately, but instead may indicate consumers the need for partition revocation so that the revoked + * partitions can be reassigned to other consumers in the next rebalance event. This is designed for sticky assignment + * logic which attempts to minimize partition reassignment with cooperative adjustments. 
+ */ + enum RebalanceProtocol { + EAGER((byte) 0), COOPERATIVE((byte) 1); + + private final byte id; + + RebalanceProtocol(byte id) { + this.id = id; + } + + public byte id() { + return id; + } + + public static RebalanceProtocol forId(byte id) { + switch (id) { + case 0: + return EAGER; + case 1: + return COOPERATIVE; + default: + throw new IllegalArgumentException("Unknown rebalance protocol id: " + id); + } + } + } + + /** + * Get a list of configured instances of {@link org.apache.kafka.clients.consumer.ConsumerPartitionAssignor} + * based on the class names/types specified by {@link org.apache.kafka.clients.consumer.ConsumerConfig#PARTITION_ASSIGNMENT_STRATEGY_CONFIG} + */ + static List getAssignorInstances(List assignorClasses, Map configs) { + List assignors = new ArrayList<>(); + // a map to store assignor name -> assignor class name + Map assignorNameMap = new HashMap<>(); + + if (assignorClasses == null) + return assignors; + + for (Object klass : assignorClasses) { + // first try to get the class if passed in as a string + if (klass instanceof String) { + try { + klass = Class.forName((String) klass, true, Utils.getContextOrKafkaClassLoader()); + } catch (ClassNotFoundException classNotFound) { + throw new KafkaException(klass + " ClassNotFoundException exception occurred", classNotFound); + } + } + + if (klass instanceof Class) { + Object assignor = Utils.newInstance((Class) klass); + if (assignor instanceof Configurable) + ((Configurable) assignor).configure(configs); + + if (assignor instanceof ConsumerPartitionAssignor) { + String assignorName = ((ConsumerPartitionAssignor) assignor).name(); + if (assignorNameMap.containsKey(assignorName)) { + throw new KafkaException("The assignor name: '" + assignorName + "' is used in more than one assignor: " + + assignorNameMap.get(assignorName) + ", " + assignor.getClass().getName()); + } + assignorNameMap.put(assignorName, assignor.getClass().getName()); + assignors.add((ConsumerPartitionAssignor) assignor); + } else { + throw new KafkaException(klass + " is not an instance of " + ConsumerPartitionAssignor.class.getName()); + } + } else { + throw new KafkaException("List contains element of type " + klass.getClass().getName() + ", expected String or Class"); + } + } + return assignors; + } + +} diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/ConsumerRebalanceListener.java b/clients/src/main/java/org/apache/kafka/clients/consumer/ConsumerRebalanceListener.java new file mode 100644 index 0000000..2f43b60 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/consumer/ConsumerRebalanceListener.java @@ -0,0 +1,200 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.clients.consumer; + +import java.time.Duration; +import java.util.Collection; + +import org.apache.kafka.common.TopicPartition; + +/** + * A callback interface that the user can implement to trigger custom actions when the set of partitions assigned to the + * consumer changes. + *

+ * This is applicable when the consumer is having Kafka auto-manage group membership. If the consumer directly assigns partitions, + * those partitions will never be reassigned and this callback is not applicable. + *

+ * When Kafka is managing the group membership, a partition re-assignment will be triggered any time the members of the group change or the subscription + * of the members changes. This can occur when processes die, new process instances are added or old instances come back to life after failure. + * Partition re-assignments can also be triggered by changes affecting the subscribed topics (e.g. when the number of partitions is + * administratively adjusted). + *

+ * There are many uses for this functionality. One common use is saving offsets in a custom store. By saving offsets in + * the {@link #onPartitionsRevoked(Collection)} call we can ensure that any time partition assignment changes + * the offset gets saved. + *

+ * Another use is flushing out any kind of cache of intermediate results the consumer may be keeping. For example, + * consider a case where the consumer is subscribed to a topic containing user page views, and the goal is to count the + * number of page views per user for each five minute window. Let's say the topic is partitioned by the user id so that + * all events for a particular user go to a single consumer instance. The consumer can keep in memory a running + * tally of actions per user and only flush these out to a remote data store when its cache gets too big. However if a + * partition is reassigned it may want to automatically trigger a flush of this cache, before the new owner takes over + * consumption. + *

+ * This callback will only execute in the user thread as part of the {@link Consumer#poll(java.time.Duration) poll(long)} call + * whenever partition assignment changes. + *

+ * Under normal conditions, if a partition is reassigned from one consumer to another, then the old consumer will + * always invoke {@link #onPartitionsRevoked(Collection) onPartitionsRevoked} for that partition prior to the new consumer + * invoking {@link #onPartitionsAssigned(Collection) onPartitionsAssigned} for the same partition. So if offsets or other state is saved in the + * {@link #onPartitionsRevoked(Collection) onPartitionsRevoked} call by one consumer member, it will always be accessible by the time the + * other consumer member takes over that partition and triggers its {@link #onPartitionsAssigned(Collection) onPartitionsAssigned} callback to load the state. + *

+ * You can think of revocation as a graceful way to give up ownership of a partition. In some cases, the consumer may not have an opportunity to do so. + * For example, if the session times out, then the partitions may be reassigned before we have a chance to revoke them gracefully. + * For this case, we have a third callback {@link #onPartitionsLost(Collection)}. The difference between this function and + * {@link #onPartitionsRevoked(Collection)} is that upon invocation of {@link #onPartitionsLost(Collection)}, the partitions + * may already be owned by some other members in the group and therefore users would not be able to commit its consumed offsets for example. + * Users could implement these two functions differently (by default, + * {@link #onPartitionsLost(Collection)} will be calling {@link #onPartitionsRevoked(Collection)} directly); for example, in the + * {@link #onPartitionsLost(Collection)} we should not need to store the offsets since we know these partitions are no longer owned by the consumer + * at that time. + *

+ * During a rebalance event, the {@link #onPartitionsAssigned(Collection) onPartitionsAssigned} function will always be triggered exactly once when + * the rebalance completes. That is, even if there is no newly assigned partitions for a consumer member, its {@link #onPartitionsAssigned(Collection) onPartitionsAssigned} + * will still be triggered with an empty collection of partitions. As a result this function can be used also to notify when a rebalance event has happened. + * With eager rebalancing, {@link #onPartitionsRevoked(Collection)} will always be called at the start of a rebalance. On the other hand, {@link #onPartitionsLost(Collection)} + * will only be called when there were non-empty partitions that were lost. + * With cooperative rebalancing, {@link #onPartitionsRevoked(Collection)} and {@link #onPartitionsLost(Collection)} + * will only be triggered when there are non-empty partitions revoked or lost from this consumer member during a rebalance event. + *

+ * It is possible + * for a {@link org.apache.kafka.common.errors.WakeupException} or {@link org.apache.kafka.common.errors.InterruptException} + * to be raised from one of these nested invocations. In this case, the exception will be propagated to the current + * invocation of {@link KafkaConsumer#poll(java.time.Duration)} in which this callback is being executed. This means it is not + * necessary to catch these exceptions and re-attempt to wakeup or interrupt the consumer thread. + * Also if the callback function implementation itself throws an exception, this exception will be propagated to the current + * invocation of {@link KafkaConsumer#poll(java.time.Duration)} as well. + *

+ * Note that callbacks only serve as notification of an assignment change. + * They cannot be used to express acceptance of the change. + * Hence throwing an exception from a callback does not affect the assignment in any way, + * as it will be propagated all the way up to the {@link KafkaConsumer#poll(java.time.Duration)} call. + * If user captures the exception in the caller, the callback is still assumed successful and no further retries will be attempted. + *

+ * + * Here is pseudo-code for a callback implementation for saving offsets: + *

+ * {@code
+ *   public class SaveOffsetsOnRebalance implements ConsumerRebalanceListener {
+ *       private Consumer<?, ?> consumer;
+ *
+ *       public SaveOffsetsOnRebalance(Consumer<?, ?> consumer) {
+ *           this.consumer = consumer;
+ *       }
+ *
+ *       public void onPartitionsRevoked(Collection<TopicPartition> partitions) {
+ *           // save the offsets in an external store using some custom code not described here
+ *           for(TopicPartition partition: partitions)
+ *              saveOffsetInExternalStore(consumer.position(partition));
+ *       }
+ *
+ *       public void onPartitionsLost(Collection<TopicPartition> partitions) {
+ *           // do not need to save the offsets since these partitions are probably owned by other consumers already
+ *       }
+ *
+ *       public void onPartitionsAssigned(Collection<TopicPartition> partitions) {
+ *           // read the offsets from an external store using some custom code not described here
+ *           for(TopicPartition partition: partitions)
+ *              consumer.seek(partition, readOffsetFromExternalStore(partition));
+ *       }
+ *   }
+ * }
+ * 
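// Illustrative usage note (not from the Kafka sources): a listener like the one above is
// attached when subscribing; the topic name is a placeholder and java.util.Arrays is assumed.
consumer.subscribe(Arrays.asList("example-topic"), new SaveOffsetsOnRebalance(consumer));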
+ */ +public interface ConsumerRebalanceListener { + + /** + * A callback method the user can implement to provide handling of offset commits to a customized store. + * This method will be called during a rebalance operation when the consumer has to give up some partitions. + * It can also be called when consumer is being closed ({@link KafkaConsumer#close(Duration)}) + * or is unsubscribing ({@link KafkaConsumer#unsubscribe()}). + * It is recommended that offsets should be committed in this callback to either Kafka or a + * custom offset store to prevent duplicate data. + *

+ * In eager rebalancing, it will always be called at the start of a rebalance and after the consumer stops fetching data. + * In cooperative rebalancing, it will be called at the end of a rebalance on the set of partitions being revoked iff the set is non-empty. + * For examples on usage of this API, see Usage Examples section of {@link KafkaConsumer KafkaConsumer}. + *
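// Illustrative sketch (not from the Kafka sources): committing the current positions of the
// revoked partitions back to Kafka before giving them up. "consumer" is assumed to be the
// KafkaConsumer driving this listener; imports for Collection, Map, HashMap, TopicPartition
// and OffsetAndMetadata are assumed.
@Override
public void onPartitionsRevoked(Collection<TopicPartition> partitions) {
    Map<TopicPartition, OffsetAndMetadata> toCommit = new HashMap<>();
    for (TopicPartition partition : partitions)
        toCommit.put(partition, new OffsetAndMetadata(consumer.position(partition)));
    consumer.commitSync(toCommit);
}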

+ * It is common for the revocation callback to use the consumer instance in order to commit offsets. It is possible + * for a {@link org.apache.kafka.common.errors.WakeupException} or {@link org.apache.kafka.common.errors.InterruptException} + * to be raised from one of these nested invocations. In this case, the exception will be propagated to the current + * invocation of {@link KafkaConsumer#poll(java.time.Duration)} in which this callback is being executed. This means it is not + * necessary to catch these exceptions and re-attempt to wakeup or interrupt the consumer thread. + * + * @param partitions The list of partitions that were assigned to the consumer and now need to be revoked (may not + * include all currently assigned partitions, i.e. there may still be some partitions left) + * @throws org.apache.kafka.common.errors.WakeupException If raised from a nested call to {@link KafkaConsumer} + * @throws org.apache.kafka.common.errors.InterruptException If raised from a nested call to {@link KafkaConsumer} + */ + void onPartitionsRevoked(Collection partitions); + + /** + * A callback method the user can implement to provide handling of customized offsets on completion of a successful + * partition re-assignment. This method will be called after the partition re-assignment completes and before the + * consumer starts fetching data, and only as the result of a {@link Consumer#poll(java.time.Duration) poll(long)} call. + *

+ * It is guaranteed that under normal conditions all the processes in a consumer group will execute their + * {@link #onPartitionsRevoked(Collection)} callback before any instance executes its + * {@link #onPartitionsAssigned(Collection)} callback. During exceptional scenarios, partitions may be migrated + * without the old owner being notified (i.e. their {@link #onPartitionsRevoked(Collection)} callback not triggered), + * and later, when the old owner realizes this, the {@link #onPartitionsLost(Collection)} callback + * will be triggered by the consumer. + *

+ * It is common for the assignment callback to use the consumer instance in order to query offsets. It is possible + * for a {@link org.apache.kafka.common.errors.WakeupException} or {@link org.apache.kafka.common.errors.InterruptException} + * to be raised from one of these nested invocations. In this case, the exception will be propagated to the current + * invocation of {@link KafkaConsumer#poll(java.time.Duration)} in which this callback is being executed. This means it is not + * necessary to catch these exceptions and re-attempt to wakeup or interrupt the consumer thread. + * + * @param partitions The list of partitions that are now assigned to the consumer (previously owned partitions will + * NOT be included, i.e. this list will only include newly added partitions) + * @throws org.apache.kafka.common.errors.WakeupException If raised from a nested call to {@link KafkaConsumer} + * @throws org.apache.kafka.common.errors.InterruptException If raised from a nested call to {@link KafkaConsumer} + */ + void onPartitionsAssigned(Collection partitions); + + /** + * A callback method you can implement to provide handling of cleaning up resources for partitions that have already + * been reassigned to other consumers. This method will not be called during normal execution as the owned partitions would + * first be revoked by calling the {@link ConsumerRebalanceListener#onPartitionsRevoked}, before being reassigned + * to other consumers during a rebalance event. However, during exceptional scenarios when the consumer realized that it + * does not own this partition any longer, i.e. not revoked via a normal rebalance event, then this method would be invoked. + *

+ * For example, this function is called if a consumer's session timeout has expired, or if a fatal error has been + * received indicating the consumer is no longer part of the group. + *

+ * By default it will just trigger {@link ConsumerRebalanceListener#onPartitionsRevoked}; users who want to distinguish + * the handling of revoked partitions from lost partitions can override the default implementation. + *
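// Illustrative sketch (not from the Kafka sources): an override that only drops local state and
// deliberately skips committing, since the partitions may already be owned by other members.
// "localStateByPartition" is a hypothetical field of the enclosing listener.
@Override
public void onPartitionsLost(Collection<TopicPartition> partitions) {
    for (TopicPartition partition : partitions)
        localStateByPartition.remove(partition);
}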

+ * It is possible + * for a {@link org.apache.kafka.common.errors.WakeupException} or {@link org.apache.kafka.common.errors.InterruptException} + * to be raised from one of these nested invocations. In this case, the exception will be propagated to the current + * invocation of {@link KafkaConsumer#poll(java.time.Duration)} in which this callback is being executed. This means it is not + * necessary to catch these exceptions and re-attempt to wakeup or interrupt the consumer thread. + * + * @param partitions The list of partitions that were assigned to the consumer and now have been reassigned + * to other consumers. With the current protocol this will always include all of the consumer's + * previously assigned partitions, but this may change in future protocols (ie there would still + * be some partitions left) + * @throws org.apache.kafka.common.errors.WakeupException If raised from a nested call to {@link KafkaConsumer} + * @throws org.apache.kafka.common.errors.InterruptException If raised from a nested call to {@link KafkaConsumer} + */ + default void onPartitionsLost(Collection partitions) { + onPartitionsRevoked(partitions); + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/ConsumerRecord.java b/clients/src/main/java/org/apache/kafka/clients/consumer/ConsumerRecord.java new file mode 100644 index 0000000..2ae93e8 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/consumer/ConsumerRecord.java @@ -0,0 +1,312 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients.consumer; + +import org.apache.kafka.common.header.Headers; +import org.apache.kafka.common.header.internals.RecordHeaders; +import org.apache.kafka.common.record.RecordBatch; +import org.apache.kafka.common.record.TimestampType; + +import java.util.Optional; + +/** + * A key/value pair to be received from Kafka. This also consists of a topic name and + * a partition number from which the record is being received, an offset that points + * to the record in a Kafka partition, and a timestamp as marked by the corresponding ProducerRecord. + */ +public class ConsumerRecord { + public static final long NO_TIMESTAMP = RecordBatch.NO_TIMESTAMP; + public static final int NULL_SIZE = -1; + + /** + * @deprecated checksums are no longer exposed by this class, this constant will be removed in Apache Kafka 4.0 + * (deprecated since 3.0). 
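// Illustrative sketch (not from the Kafka sources): reading the metadata carried by each record;
// "records" is assumed to be a ConsumerRecords<String, String> returned from poll().
for (ConsumerRecord<String, String> record : records) {
    System.out.printf("topic=%s partition=%d offset=%d timestamp=%d (%s) key=%s value=%s%n",
            record.topic(), record.partition(), record.offset(),
            record.timestamp(), record.timestampType(), record.key(), record.value());
    record.headers().forEach(header ->
            System.out.println("  header " + header.key() + " = " + new String(header.value())));
}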
+ */ + @Deprecated + public static final int NULL_CHECKSUM = -1; + + private final String topic; + private final int partition; + private final long offset; + private final long timestamp; + private final TimestampType timestampType; + private final int serializedKeySize; + private final int serializedValueSize; + private final Headers headers; + private final K key; + private final V value; + private final Optional leaderEpoch; + + /** + * Creates a record to be received from a specified topic and partition (provided for + * compatibility with Kafka 0.9 before the message format supported timestamps and before + * serialized metadata were exposed). + * + * @param topic The topic this record is received from + * @param partition The partition of the topic this record is received from + * @param offset The offset of this record in the corresponding Kafka partition + * @param key The key of the record, if one exists (null is allowed) + * @param value The record contents + */ + public ConsumerRecord(String topic, + int partition, + long offset, + K key, + V value) { + this(topic, partition, offset, NO_TIMESTAMP, TimestampType.NO_TIMESTAMP_TYPE, NULL_SIZE, NULL_SIZE, key, value, + new RecordHeaders(), Optional.empty()); + } + + /** + * Creates a record to be received from a specified topic and partition + * + * @param topic The topic this record is received from + * @param partition The partition of the topic this record is received from + * @param offset The offset of this record in the corresponding Kafka partition + * @param timestamp The timestamp of the record. + * @param timestampType The timestamp type + * @param serializedKeySize The length of the serialized key + * @param serializedValueSize The length of the serialized value + * @param key The key of the record, if one exists (null is allowed) + * @param value The record contents + * @param headers The headers of the record + * @param leaderEpoch Optional leader epoch of the record (may be empty for legacy record formats) + */ + public ConsumerRecord(String topic, + int partition, + long offset, + long timestamp, + TimestampType timestampType, + int serializedKeySize, + int serializedValueSize, + K key, + V value, + Headers headers, + Optional leaderEpoch) { + if (topic == null) + throw new IllegalArgumentException("Topic cannot be null"); + if (headers == null) + throw new IllegalArgumentException("Headers cannot be null"); + + this.topic = topic; + this.partition = partition; + this.offset = offset; + this.timestamp = timestamp; + this.timestampType = timestampType; + this.serializedKeySize = serializedKeySize; + this.serializedValueSize = serializedValueSize; + this.key = key; + this.value = value; + this.headers = headers; + this.leaderEpoch = leaderEpoch; + } + + /** + * Creates a record to be received from a specified topic and partition (provided for + * compatibility with Kafka 0.10 before the message format supported headers). + * + * @param topic The topic this record is received from + * @param partition The partition of the topic this record is received from + * @param offset The offset of this record in the corresponding Kafka partition + * @param timestamp The timestamp of the record. 
+ * @param timestampType The timestamp type + * @param serializedKeySize The length of the serialized key + * @param serializedValueSize The length of the serialized value + * @param key The key of the record, if one exists (null is allowed) + * @param value The record contents + * + * @deprecated use one of the constructors without a `checksum` parameter. This constructor will be removed in + * Apache Kafka 4.0 (deprecated since 3.0). + */ + @Deprecated + public ConsumerRecord(String topic, + int partition, + long offset, + long timestamp, + TimestampType timestampType, + long checksum, + int serializedKeySize, + int serializedValueSize, + K key, + V value) { + this(topic, partition, offset, timestamp, timestampType, serializedKeySize, serializedValueSize, + key, value, new RecordHeaders(), Optional.empty()); + } + + /** + * Creates a record to be received from a specified topic and partition + * + * @param topic The topic this record is received from + * @param partition The partition of the topic this record is received from + * @param offset The offset of this record in the corresponding Kafka partition + * @param timestamp The timestamp of the record. + * @param timestampType The timestamp type + * @param serializedKeySize The length of the serialized key + * @param serializedValueSize The length of the serialized value + * @param key The key of the record, if one exists (null is allowed) + * @param value The record contents + * @param headers The headers of the record. + * + * @deprecated use one of the constructors without a `checksum` parameter. This constructor will be removed in + * Apache Kafka 4.0 (deprecated since 3.0). + */ + @Deprecated + public ConsumerRecord(String topic, + int partition, + long offset, + long timestamp, + TimestampType timestampType, + Long checksum, + int serializedKeySize, + int serializedValueSize, + K key, + V value, + Headers headers) { + this(topic, partition, offset, timestamp, timestampType, serializedKeySize, serializedValueSize, + key, value, headers, Optional.empty()); + } + + /** + * Creates a record to be received from a specified topic and partition + * + * @param topic The topic this record is received from + * @param partition The partition of the topic this record is received from + * @param offset The offset of this record in the corresponding Kafka partition + * @param timestamp The timestamp of the record. + * @param timestampType The timestamp type + * @param serializedKeySize The length of the serialized key + * @param serializedValueSize The length of the serialized value + * @param key The key of the record, if one exists (null is allowed) + * @param value The record contents + * @param headers The headers of the record + * @param leaderEpoch Optional leader epoch of the record (may be empty for legacy record formats) + * + * @deprecated use one of the constructors without a `checksum` parameter. This constructor will be removed in + * Apache Kafka 4.0 (deprecated since 3.0). 
+ */ + @Deprecated + public ConsumerRecord(String topic, + int partition, + long offset, + long timestamp, + TimestampType timestampType, + Long checksum, + int serializedKeySize, + int serializedValueSize, + K key, + V value, + Headers headers, + Optional leaderEpoch) { + this(topic, partition, offset, timestamp, timestampType, serializedKeySize, serializedValueSize, key, value, headers, + leaderEpoch); + } + + /** + * The topic this record is received from (never null) + */ + public String topic() { + return this.topic; + } + + /** + * The partition from which this record is received + */ + public int partition() { + return this.partition; + } + + /** + * The headers (never null) + */ + public Headers headers() { + return headers; + } + + /** + * The key (or null if no key is specified) + */ + public K key() { + return key; + } + + /** + * The value + */ + public V value() { + return value; + } + + /** + * The position of this record in the corresponding Kafka partition. + */ + public long offset() { + return offset; + } + + /** + * The timestamp of this record + */ + public long timestamp() { + return timestamp; + } + + /** + * The timestamp type of this record + */ + public TimestampType timestampType() { + return timestampType; + } + + /** + * The size of the serialized, uncompressed key in bytes. If key is null, the returned size + * is -1. + */ + public int serializedKeySize() { + return this.serializedKeySize; + } + + /** + * The size of the serialized, uncompressed value in bytes. If value is null, the + * returned size is -1. + */ + public int serializedValueSize() { + return this.serializedValueSize; + } + + /** + * Get the leader epoch for the record if available + * + * @return the leader epoch or empty for legacy record formats + */ + public Optional leaderEpoch() { + return leaderEpoch; + } + + @Override + public String toString() { + return "ConsumerRecord(topic = " + topic + + ", partition = " + partition + + ", leaderEpoch = " + leaderEpoch.orElse(null) + + ", offset = " + offset + + ", " + timestampType + " = " + timestamp + + ", serialized key size = " + serializedKeySize + + ", serialized value size = " + serializedValueSize + + ", headers = " + headers + + ", key = " + key + + ", value = " + value + ")"; + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/ConsumerRecords.java b/clients/src/main/java/org/apache/kafka/clients/consumer/ConsumerRecords.java new file mode 100644 index 0000000..92390e9 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/consumer/ConsumerRecords.java @@ -0,0 +1,129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.clients.consumer; + +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.utils.AbstractIterator; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Set; + +/** + * A container that holds the list {@link ConsumerRecord} per partition for a + * particular topic. There is one {@link ConsumerRecord} list for every topic + * partition returned by a {@link Consumer#poll(java.time.Duration)} operation. + */ +public class ConsumerRecords implements Iterable> { + public static final ConsumerRecords EMPTY = new ConsumerRecords<>(Collections.emptyMap()); + + private final Map>> records; + + public ConsumerRecords(Map>> records) { + this.records = records; + } + + /** + * Get just the records for the given partition + * + * @param partition The partition to get records for + */ + public List> records(TopicPartition partition) { + List> recs = this.records.get(partition); + if (recs == null) + return Collections.emptyList(); + else + return Collections.unmodifiableList(recs); + } + + /** + * Get just the records for the given topic + */ + public Iterable> records(String topic) { + if (topic == null) + throw new IllegalArgumentException("Topic must be non-null."); + List>> recs = new ArrayList<>(); + for (Map.Entry>> entry : records.entrySet()) { + if (entry.getKey().topic().equals(topic)) + recs.add(entry.getValue()); + } + return new ConcatenatedIterable<>(recs); + } + + /** + * Get the partitions which have records contained in this record set. + * @return the set of partitions with data in this record set (may be empty if no data was returned) + */ + public Set partitions() { + return Collections.unmodifiableSet(records.keySet()); + } + + @Override + public Iterator> iterator() { + return new ConcatenatedIterable<>(records.values()).iterator(); + } + + /** + * The number of records for all topics + */ + public int count() { + int count = 0; + for (List> recs: this.records.values()) + count += recs.size(); + return count; + } + + private static class ConcatenatedIterable implements Iterable> { + + private final Iterable>> iterables; + + public ConcatenatedIterable(Iterable>> iterables) { + this.iterables = iterables; + } + + @Override + public Iterator> iterator() { + return new AbstractIterator>() { + Iterator>> iters = iterables.iterator(); + Iterator> current; + + public ConsumerRecord makeNext() { + while (current == null || !current.hasNext()) { + if (iters.hasNext()) + current = iters.next().iterator(); + else + return allDone(); + } + return current.next(); + } + }; + } + } + + public boolean isEmpty() { + return records.isEmpty(); + } + + @SuppressWarnings("unchecked") + public static ConsumerRecords empty() { + return (ConsumerRecords) EMPTY; + } + +} diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/CooperativeStickyAssignor.java b/clients/src/main/java/org/apache/kafka/clients/consumer/CooperativeStickyAssignor.java new file mode 100644 index 0000000..b2cca87 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/consumer/CooperativeStickyAssignor.java @@ -0,0 +1,150 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients.consumer; + +import java.nio.ByteBuffer; +import java.util.Arrays; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import org.apache.kafka.clients.consumer.internals.AbstractStickyAssignor; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.protocol.types.Field; +import org.apache.kafka.common.protocol.types.Schema; +import org.apache.kafka.common.protocol.types.Struct; +import org.apache.kafka.common.protocol.types.Type; + +/** + * A cooperative version of the {@link AbstractStickyAssignor AbstractStickyAssignor}. This follows the same (sticky) + * assignment logic as {@link StickyAssignor StickyAssignor} but allows for cooperative rebalancing while the + * {@link StickyAssignor StickyAssignor} follows the eager rebalancing protocol. See + * {@link ConsumerPartitionAssignor.RebalanceProtocol} for an explanation of the rebalancing protocols. + *

+ * Users should prefer this assignor for newer clusters. + *

+ * To turn on cooperative rebalancing you must set all your consumers to use this {@code PartitionAssignor}, + * or implement a custom one that returns {@code RebalanceProtocol.COOPERATIVE} in + * {@link CooperativeStickyAssignor#supportedProtocols supportedProtocols()}. + *
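+ * For example, this assignor can be selected through the consumer configuration (a minimal sketch, assuming a
+ * {@code props} object holding the consumer properties):
+ *
+ *     props.setProperty(ConsumerConfig.PARTITION_ASSIGNMENT_STRATEGY_CONFIG, CooperativeStickyAssignor.class.getName());
+ *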

+ * IMPORTANT: if upgrading from 2.3 or earlier, you must follow a specific upgrade path in order to safely turn on + * cooperative rebalancing. See the upgrade guide for details. + */ +public class CooperativeStickyAssignor extends AbstractStickyAssignor { + public static final String COOPERATIVE_STICKY_ASSIGNOR_NAME = "cooperative-sticky"; + + // these schemas are used for preserving useful metadata for the assignment, such as the last stable generation + private static final String GENERATION_KEY_NAME = "generation"; + private static final Schema COOPERATIVE_STICKY_ASSIGNOR_USER_DATA_V0 = new Schema( + new Field(GENERATION_KEY_NAME, Type.INT32)); + + private int generation = DEFAULT_GENERATION; // consumer group generation + + @Override + public String name() { + return COOPERATIVE_STICKY_ASSIGNOR_NAME; + } + + @Override + public List supportedProtocols() { + return Arrays.asList(RebalanceProtocol.COOPERATIVE, RebalanceProtocol.EAGER); + } + + @Override + public void onAssignment(Assignment assignment, ConsumerGroupMetadata metadata) { + this.generation = metadata.generationId(); + } + + @Override + public ByteBuffer subscriptionUserData(Set topics) { + Struct struct = new Struct(COOPERATIVE_STICKY_ASSIGNOR_USER_DATA_V0); + + struct.set(GENERATION_KEY_NAME, generation); + ByteBuffer buffer = ByteBuffer.allocate(COOPERATIVE_STICKY_ASSIGNOR_USER_DATA_V0.sizeOf(struct)); + COOPERATIVE_STICKY_ASSIGNOR_USER_DATA_V0.write(buffer, struct); + buffer.flip(); + return buffer; + } + + @Override + protected MemberData memberData(Subscription subscription) { + ByteBuffer buffer = subscription.userData(); + Optional encodedGeneration; + if (buffer == null) { + encodedGeneration = Optional.empty(); + } else { + try { + Struct struct = COOPERATIVE_STICKY_ASSIGNOR_USER_DATA_V0.read(buffer); + encodedGeneration = Optional.of(struct.getInt(GENERATION_KEY_NAME)); + } catch (Exception e) { + encodedGeneration = Optional.of(DEFAULT_GENERATION); + } + } + return new MemberData(subscription.ownedPartitions(), encodedGeneration); + } + + @Override + public Map> assign(Map partitionsPerTopic, + Map subscriptions) { + Map> assignments = super.assign(partitionsPerTopic, subscriptions); + + Map partitionsTransferringOwnership = super.partitionsTransferringOwnership == null ? 
+ computePartitionsTransferringOwnership(subscriptions, assignments) : + super.partitionsTransferringOwnership; + + adjustAssignment(assignments, partitionsTransferringOwnership); + return assignments; + } + + // Following the cooperative rebalancing protocol requires removing partitions that must first be revoked from the assignment + private void adjustAssignment(Map> assignments, + Map partitionsTransferringOwnership) { + for (Map.Entry partitionEntry : partitionsTransferringOwnership.entrySet()) { + assignments.get(partitionEntry.getValue()).remove(partitionEntry.getKey()); + } + } + + private Map computePartitionsTransferringOwnership(Map subscriptions, + Map> assignments) { + Map allAddedPartitions = new HashMap<>(); + Set allRevokedPartitions = new HashSet<>(); + + for (final Map.Entry> entry : assignments.entrySet()) { + String consumer = entry.getKey(); + + List ownedPartitions = subscriptions.get(consumer).ownedPartitions(); + List assignedPartitions = entry.getValue(); + + Set ownedPartitionsSet = new HashSet<>(ownedPartitions); + for (TopicPartition tp : assignedPartitions) { + if (!ownedPartitionsSet.contains(tp)) + allAddedPartitions.put(tp, consumer); + } + + Set assignedPartitionsSet = new HashSet<>(assignedPartitions); + for (TopicPartition tp : ownedPartitions) { + if (!assignedPartitionsSet.contains(tp)) + allRevokedPartitions.add(tp); + } + } + + allAddedPartitions.keySet().retainAll(allRevokedPartitions); + return allAddedPartitions; + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/InvalidOffsetException.java b/clients/src/main/java/org/apache/kafka/clients/consumer/InvalidOffsetException.java new file mode 100644 index 0000000..b23ca86 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/consumer/InvalidOffsetException.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients.consumer; + +import org.apache.kafka.common.KafkaException; +import org.apache.kafka.common.TopicPartition; + +import java.util.Set; + +/** + * Thrown when the offset for a set of partitions is invalid (either undefined or out of range), + * and no reset policy has been configured. 
+ * @see NoOffsetForPartitionException + * @see OffsetOutOfRangeException + */ +public abstract class InvalidOffsetException extends KafkaException { + + public InvalidOffsetException(String message) { + super(message); + } + + public abstract Set partitions(); + +} diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/KafkaConsumer.java b/clients/src/main/java/org/apache/kafka/clients/consumer/KafkaConsumer.java new file mode 100644 index 0000000..286f84b --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/consumer/KafkaConsumer.java @@ -0,0 +1,2506 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients.consumer; + +import org.apache.kafka.clients.ApiVersions; +import org.apache.kafka.clients.ClientUtils; +import org.apache.kafka.clients.CommonClientConfigs; +import org.apache.kafka.clients.GroupRebalanceConfig; +import org.apache.kafka.clients.Metadata; +import org.apache.kafka.clients.NetworkClient; +import org.apache.kafka.clients.consumer.internals.ConsumerCoordinator; +import org.apache.kafka.clients.consumer.internals.ConsumerInterceptors; +import org.apache.kafka.clients.consumer.internals.ConsumerMetadata; +import org.apache.kafka.clients.consumer.internals.ConsumerNetworkClient; +import org.apache.kafka.clients.consumer.internals.Fetcher; +import org.apache.kafka.clients.consumer.internals.FetcherMetricsRegistry; +import org.apache.kafka.clients.consumer.internals.KafkaConsumerMetrics; +import org.apache.kafka.clients.consumer.internals.NoOpConsumerRebalanceListener; +import org.apache.kafka.clients.consumer.internals.SubscriptionState; +import org.apache.kafka.common.Cluster; +import org.apache.kafka.common.IsolationLevel; +import org.apache.kafka.common.KafkaException; +import org.apache.kafka.common.Metric; +import org.apache.kafka.common.MetricName; +import org.apache.kafka.common.PartitionInfo; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.errors.InterruptException; +import org.apache.kafka.common.errors.InvalidGroupIdException; +import org.apache.kafka.common.errors.TimeoutException; +import org.apache.kafka.common.internals.ClusterResourceListeners; +import org.apache.kafka.common.metrics.JmxReporter; +import org.apache.kafka.common.metrics.KafkaMetricsContext; +import org.apache.kafka.common.metrics.MetricConfig; +import org.apache.kafka.common.metrics.Metrics; +import org.apache.kafka.common.metrics.MetricsContext; +import org.apache.kafka.common.metrics.MetricsReporter; +import org.apache.kafka.common.metrics.Sensor; +import org.apache.kafka.common.network.ChannelBuilder; +import org.apache.kafka.common.network.Selector; +import org.apache.kafka.common.requests.MetadataRequest; +import 
org.apache.kafka.common.serialization.Deserializer; +import org.apache.kafka.common.utils.AppInfoParser; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.common.utils.Timer; +import org.apache.kafka.common.utils.Utils; +import org.slf4j.Logger; + +import java.net.InetSocketAddress; +import java.time.Duration; +import java.util.Collection; +import java.util.Collections; +import java.util.ConcurrentModificationException; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.OptionalLong; +import java.util.Properties; +import java.util.Set; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.atomic.AtomicReference; +import java.util.regex.Pattern; + +/** + * A client that consumes records from a Kafka cluster. + *

+ * This client transparently handles the failure of Kafka brokers, and transparently adapts as topic partitions + * it fetches migrate within the cluster. This client also interacts with the broker to allow groups of + * consumers to load balance consumption using consumer groups. + *

+ * The consumer maintains TCP connections to the necessary brokers to fetch data. + * Failure to close the consumer after use will leak these connections. + * The consumer is not thread-safe. See Multi-threaded Processing for more details. + * + *

Cross-Version Compatibility

+ * This client can communicate with brokers that are version 0.10.0 or newer. Older or newer brokers may not support + * certain features. For example, 0.10.0 brokers do not support offsetsForTimes, because this feature was added + * in version 0.10.1. You will receive an {@link org.apache.kafka.common.errors.UnsupportedVersionException} + * when invoking an API that is not available on the running broker version. + *

+ * + *

Offsets and Consumer Position

+ * Kafka maintains a numerical offset for each record in a partition. This offset acts as a unique identifier of + * a record within that partition, and also denotes the position of the consumer in the partition. For example, a consumer + * which is at position 5 has consumed records with offsets 0 through 4 and will next receive the record with offset 5. There + * are actually two notions of position relevant to the user of the consumer: + *

+ * The {@link #position(TopicPartition) position} of the consumer gives the offset of the next record that will be given + * out. It will be one larger than the highest offset the consumer has seen in that partition. It automatically advances + * every time the consumer receives messages in a call to {@link #poll(Duration)}. + *

+ * The {@link #commitSync() committed position} is the last offset that has been stored securely. Should the + * process fail and restart, this is the offset that the consumer will recover to. The consumer can either automatically commit + * offsets periodically; or it can choose to control this committed position manually by calling one of the commit APIs + * (e.g. {@link #commitSync() commitSync} and {@link #commitAsync(OffsetCommitCallback) commitAsync}). + *

+ * This distinction gives the consumer control over when a record is considered consumed. It is discussed in further + * detail below. + * + *
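+ * For illustration, both positions can be inspected for a partition that is currently assigned to the consumer
+ * (a minimal sketch; the topic name and {@code consumer} instance are assumed from the surrounding examples):
+ *
+ *     TopicPartition partition = new TopicPartition("foo", 0);
+ *     long position = consumer.position(partition); // offset of the next record that will be given out
+ *     OffsetAndMetadata committed = consumer.committed(Collections.singleton(partition)).get(partition);
+ *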

Consumer Groups and Topic Subscriptions

+ * + * Kafka uses the concept of consumer groups to allow a pool of processes to divide the work of consuming and + * processing records. These processes can either be running on the same machine or they can be + * distributed over many machines to provide scalability and fault tolerance for processing. All consumer instances + * sharing the same {@code group.id} will be part of the same consumer group. + *

+ * Each consumer in a group can dynamically set the list of topics it wants to subscribe to through one of the + * {@link #subscribe(Collection, ConsumerRebalanceListener) subscribe} APIs. Kafka will deliver each message in the + * subscribed topics to one process in each consumer group. This is achieved by balancing the partitions between all + * members in the consumer group so that each partition is assigned to exactly one consumer in the group. So if there + * is a topic with four partitions, and a consumer group with two processes, each process would consume from two partitions. + *

+ * Membership in a consumer group is maintained dynamically: if a process fails, the partitions assigned to it will + * be reassigned to other consumers in the same group. Similarly, if a new consumer joins the group, partitions will be moved + * from existing consumers to the new one. This is known as rebalancing the group and is discussed in more + * detail below. Group rebalancing is also used when new partitions are added + * to one of the subscribed topics or when a new topic matching a {@link #subscribe(Pattern, ConsumerRebalanceListener) subscribed regex} + * is created. The group will automatically detect the new partitions through periodic metadata refreshes and + * assign them to members of the group. + *

+ * Conceptually you can think of a consumer group as being a single logical subscriber that happens to be made up of + * multiple processes. As a multi-subscriber system, Kafka naturally supports having any number of consumer groups for a + * given topic without duplicating data (additional consumers are actually quite cheap). + *

+ * This is a slight generalization of the functionality that is common in messaging systems. To get semantics similar to + * a queue in a traditional messaging system all processes would be part of a single consumer group and hence record + * delivery would be balanced over the group like with a queue. Unlike a traditional messaging system, though, you can + * have multiple such groups. To get semantics similar to pub-sub in a traditional messaging system each process would + * have its own consumer group, so each process would subscribe to all the records published to the topic. + *

+ * In addition, when group reassignment happens automatically, consumers can be notified through a {@link ConsumerRebalanceListener}, + * which allows them to finish necessary application-level logic such as state cleanup, manual offset + * commits, etc. See Storing Offsets Outside Kafka for more details. + *

+ * It is also possible for the consumer to manually assign specific partitions + * (similar to the older "simple" consumer) using {@link #assign(Collection)}. In this case, dynamic partition + * assignment and consumer group coordination will be disabled. + * + *

Detecting Consumer Failures

+ * + * After subscribing to a set of topics, the consumer will automatically join the group when {@link #poll(Duration)} is + * invoked. The poll API is designed to ensure consumer liveness. As long as you continue to call poll, the consumer + * will stay in the group and continue to receive messages from the partitions it was assigned. Underneath the covers, + * the consumer sends periodic heartbeats to the server. If the consumer crashes or is unable to send heartbeats for + * a duration of {@code session.timeout.ms}, then the consumer will be considered dead and its partitions will + * be reassigned. + *

+ * It is also possible that the consumer could encounter a "livelock" situation where it is continuing + * to send heartbeats, but no progress is being made. To prevent the consumer from holding onto its partitions + * indefinitely in this case, we provide a liveness detection mechanism using the {@code max.poll.interval.ms} + * setting. Basically if you don't call poll at least as frequently as the configured max interval, + * then the client will proactively leave the group so that another consumer can take over its partitions. When this happens, + * you may see an offset commit failure (as indicated by a {@link CommitFailedException} thrown from a call to {@link #commitSync()}). + * This is a safety mechanism which guarantees that only active members of the group are able to commit offsets. + * So to stay in the group, you must continue to call poll. + *

+ * The consumer provides two configuration settings to control the behavior of the poll loop: + *

    + *
  1. {@code max.poll.interval.ms}: By increasing the interval between expected polls, you can give
+ * the consumer more time to handle a batch of records returned from {@link #poll(Duration)}. The drawback
+ * is that increasing this value may delay a group rebalance since the consumer will only join the rebalance
+ * inside the call to poll. You can use this setting to bound the time to finish a rebalance, but
+ * you risk slower progress if the consumer cannot actually call {@link #poll(Duration) poll} often enough.
  2. {@code max.poll.records}: Use this setting to limit the total records returned from a single
+ * call to poll. This can make it easier to predict the maximum that must be handled within each poll
+ * interval. By tuning this value, you may be able to reduce the poll interval, which will reduce the
+ * impact of group rebalancing. Illustrative values for both settings are shown in the snippet below.
+ *
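+ * For illustration, both settings can be adjusted through the consumer properties; the values below are purely
+ * illustrative and reuse the {@code props} object from the examples further down:
+ *
+ *     props.setProperty("max.poll.interval.ms", "300000"); // allow up to 5 minutes between polls
+ *     props.setProperty("max.poll.records", "500");        // cap the number of records returned per poll
+ *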

+ * For use cases where message processing time varies unpredictably, neither of these options may be sufficient.
+ * The recommended way to handle these cases is to move message processing to another thread, which allows
+ * the consumer to continue calling {@link #poll(Duration) poll} while the processor is still working.
+ * Some care must be taken to ensure that committed offsets do not get ahead of the actual position.
+ * Typically, you must disable automatic commits and manually commit processed offsets for records only after the
+ * thread has finished handling them (depending on the delivery semantics you need).
+ * Note also that you will need to {@link #pause(Collection) pause} the partition so that no new records are received
+ * from poll until after the thread has finished handling those previously returned.
+ *
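+ * A minimal sketch of this pattern is shown below, assuming a hypothetical {@code executor} worker pool, a
+ * hypothetical {@code process()} method that handles a batch, and a hypothetical {@code doneProcessing()} check
+ * signalled by the worker once the batch is finished:
+ *
+ *     Set<TopicPartition> paused = new HashSet<>();
+ *     while (running) {
+ *         ConsumerRecords<String, String> records = consumer.poll(Duration.ofMillis(100));
+ *         if (!records.isEmpty()) {
+ *             paused.addAll(records.partitions());
+ *             consumer.pause(paused);                  // stop fetching while the batch is handled elsewhere
+ *             executor.submit(() -> process(records)); // hand the batch to the hypothetical worker pool
+ *         }
+ *         if (doneProcessing()) {                      // hypothetical check: the worker finished the batch
+ *             consumer.resume(paused);
+ *             paused.clear();
+ *         }
+ *     }
+ *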

Usage Examples

+ * The consumer APIs offer flexibility to cover a variety of consumption use cases. Here are some examples to + * demonstrate how to use them. + * + *

Automatic Offset Committing

+ * This example demonstrates a simple usage of Kafka's consumer api that relies on automatic offset committing. + *

+ *

+ *     Properties props = new Properties();
+ *     props.setProperty("bootstrap.servers", "localhost:9092");
+ *     props.setProperty("group.id", "test");
+ *     props.setProperty("enable.auto.commit", "true");
+ *     props.setProperty("auto.commit.interval.ms", "1000");
+ *     props.setProperty("key.deserializer", "org.apache.kafka.common.serialization.StringDeserializer");
+ *     props.setProperty("value.deserializer", "org.apache.kafka.common.serialization.StringDeserializer");
+ *     KafkaConsumer<String, String> consumer = new KafkaConsumer<>(props);
+ *     consumer.subscribe(Arrays.asList("foo", "bar"));
+ *     while (true) {
+ *         ConsumerRecords<String, String> records = consumer.poll(Duration.ofMillis(100));
+ *         for (ConsumerRecord<String, String> record : records)
+ *             System.out.printf("offset = %d, key = %s, value = %s%n", record.offset(), record.key(), record.value());
+ *     }
+ * 
+ *
+ * The connection to the cluster is bootstrapped by specifying a list of one or more brokers to contact using the
+ * configuration {@code bootstrap.servers}. This list is just used to discover the rest of the brokers in the
+ * cluster and need not be an exhaustive list of servers in the cluster (though you may want to specify more than one in
+ * case there are servers down when the client is connecting).
+ *

+ * Setting {@code enable.auto.commit} means that offsets are committed automatically with a frequency controlled by + * the config {@code auto.commit.interval.ms}. + *

+ * In this example the consumer is subscribing to the topics foo and bar as part of a group of consumers + * called test as configured with {@code group.id}. + *

+ * The deserializer settings specify how to turn bytes into objects. For example, by specifying string deserializers, we + * are saying that our record's key and value will just be simple strings. + * + *
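+ * The deserializers can also be passed directly to one of the {@code KafkaConsumer} constructors instead of being
+ * set in the configuration, in which case the consumer will not call their {@code configure()} methods. For example:
+ *
+ *     KafkaConsumer<String, String> consumer =
+ *         new KafkaConsumer<>(props, new StringDeserializer(), new StringDeserializer());
+ *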

Manual Offset Control

+ *
+ * Instead of relying on the consumer to periodically commit consumed offsets, users can also control when records
+ * should be considered as consumed and hence commit their offsets. This is useful when the consumption of the messages
+ * is coupled with some processing logic, and hence a message should not be considered as consumed until its processing
+ * is complete.
+

+ *

+ *     Properties props = new Properties();
+ *     props.setProperty("bootstrap.servers", "localhost:9092");
+ *     props.setProperty("group.id", "test");
+ *     props.setProperty("enable.auto.commit", "false");
+ *     props.setProperty("key.deserializer", "org.apache.kafka.common.serialization.StringDeserializer");
+ *     props.setProperty("value.deserializer", "org.apache.kafka.common.serialization.StringDeserializer");
+ *     KafkaConsumer<String, String> consumer = new KafkaConsumer<>(props);
+ *     consumer.subscribe(Arrays.asList("foo", "bar"));
+ *     final int minBatchSize = 200;
+ *     List<ConsumerRecord<String, String>> buffer = new ArrayList<>();
+ *     while (true) {
+ *         ConsumerRecords<String, String> records = consumer.poll(Duration.ofMillis(100));
+ *         for (ConsumerRecord<String, String> record : records) {
+ *             buffer.add(record);
+ *         }
+ *         if (buffer.size() >= minBatchSize) {
+ *             insertIntoDb(buffer);
+ *             consumer.commitSync();
+ *             buffer.clear();
+ *         }
+ *     }
+ * 
+ * + * In this example we will consume a batch of records and batch them up in memory. When we have enough records + * batched, we will insert them into a database. If we allowed offsets to auto commit as in the previous example, records + * would be considered consumed after they were returned to the user in {@link #poll(Duration) poll}. It would then be + * possible + * for our process to fail after batching the records, but before they had been inserted into the database. + *

+ * To avoid this, we will manually commit the offsets only after the corresponding records have been inserted into the
+ * database. This gives us exact control of when a record is considered consumed. This raises the opposite possibility:
+ * the process could fail in the interval after the insert into the database but before the commit (even though this
+ * would likely just be a few milliseconds, it is a possibility). In this case the process that took over consumption
+ * would consume from the last committed offset and would repeat the insert of the last batch of data. Used in this way,
+ * Kafka provides what is often called "at-least-once" delivery guarantees, as each record will likely be delivered one
+ * time but in failure cases could be duplicated.
+ *

+ * Note: Using automatic offset commits can also give you "at-least-once" delivery, but the requirement is that + * you must consume all data returned from each call to {@link #poll(Duration)} before any subsequent calls, or before + * {@link #close() closing} the consumer. If you fail to do either of these, it is possible for the committed offset + * to get ahead of the consumed position, which results in missing records. The advantage of using manual offset + * control is that you have direct control over when a record is considered "consumed." + *

+ * The above example uses {@link #commitSync() commitSync} to mark all received records as committed. In some cases + * you may wish to have even finer control over which records have been committed by specifying an offset explicitly. + * In the example below we commit offset after we finish handling the records in each partition. + *

+ *

+ *     try {
+ *         while(running) {
+ *             ConsumerRecords<String, String> records = consumer.poll(Duration.ofMillis(Long.MAX_VALUE));
+ *             for (TopicPartition partition : records.partitions()) {
+ *                 List<ConsumerRecord<String, String>> partitionRecords = records.records(partition);
+ *                 for (ConsumerRecord<String, String> record : partitionRecords) {
+ *                     System.out.println(record.offset() + ": " + record.value());
+ *                 }
+ *                 long lastOffset = partitionRecords.get(partitionRecords.size() - 1).offset();
+ *                 consumer.commitSync(Collections.singletonMap(partition, new OffsetAndMetadata(lastOffset + 1)));
+ *             }
+ *         }
+ *     } finally {
+ *       consumer.close();
+ *     }
+ * 
+ * + * Note: The committed offset should always be the offset of the next message that your application will read. + * Thus, when calling {@link #commitSync(Map) commitSync(offsets)} you should add one to the offset of the last message processed. + * + *

Manual Partition Assignment

+ * + * In the previous examples, we subscribed to the topics we were interested in and let Kafka dynamically assign a + * fair share of the partitions for those topics based on the active consumers in the group. However, in + * some cases you may need finer control over the specific partitions that are assigned. For example: + *

+ *

    + *
  • If the process is maintaining some kind of local state associated with that partition (like a + * local on-disk key-value store), then it should only get records for the partition it is maintaining on disk. + *
  • If the process itself is highly available and will be restarted if it fails (perhaps using a + * cluster management framework like YARN, Mesos, or AWS facilities, or as part of a stream processing framework). In + * this case there is no need for Kafka to detect the failure and reassign the partition since the consuming process + * will be restarted on another machine. + *
+ *

+ * To use this mode, instead of subscribing to the topic using {@link #subscribe(Collection) subscribe}, you just call + * {@link #assign(Collection)} with the full list of partitions that you want to consume. + * + *

+ *     String topic = "foo";
+ *     TopicPartition partition0 = new TopicPartition(topic, 0);
+ *     TopicPartition partition1 = new TopicPartition(topic, 1);
+ *     consumer.assign(Arrays.asList(partition0, partition1));
+ * 
+ * + * Once assigned, you can call {@link #poll(Duration) poll} in a loop, just as in the preceding examples to consume + * records. The group that the consumer specifies is still used for committing offsets, but now the set of partitions + * will only change with another call to {@link #assign(Collection) assign}. Manual partition assignment does + * not use group coordination, so consumer failures will not cause assigned partitions to be rebalanced. Each consumer + * acts independently even if it shares a groupId with another consumer. To avoid offset commit conflicts, you should + * usually ensure that the groupId is unique for each consumer instance. + *

+ * Note that it isn't possible to mix manual partition assignment (i.e. using {@link #assign(Collection) assign}) + * with dynamic partition assignment through topic subscription (i.e. using {@link #subscribe(Collection) subscribe}). + * + *

Storing Offsets Outside Kafka

+ *
+ * The consumer application need not use Kafka's built-in offset storage; it can store offsets in a store of its own
+ * choosing. The primary use case for this is allowing the application to store both the offset and the results of the
+ * consumption in the same system, in a way that both the results and offsets are stored atomically. This is not always
+ * possible, but when it is it will make the consumption fully atomic and give "exactly-once" semantics that are
+ * stronger than the default "at-least-once" semantics you get with Kafka's offset commit functionality.
+ *

+ * Here are a couple of examples of this type of usage: + *

    + *
  • If the results of the consumption are being stored in a relational database, storing the offset in the database + * as well can allow committing both the results and offset in a single transaction. Thus either the transaction will + * succeed and the offset will be updated based on what was consumed or the result will not be stored and the offset + * won't be updated. + *
  • If the results are being stored in a local store it may be possible to store the offset there as well. For
+ * example, a search index could be built by subscribing to a particular partition and storing both the offset and the
+ * indexed data together. If this is done atomically, then even if a crash causes unsynced data to be lost, whatever
+ * data remains will have the corresponding offset stored as well. This means that the indexing process that comes back
+ * having lost recent updates simply resumes indexing from what it has, ensuring that no updates are lost.
+ *

+ * Each record comes with its own offset, so to manage your own offset you just need to do the following: + * + *

    + *
  • Configure enable.auto.commit=false + *
  • Use the offset provided with each {@link ConsumerRecord} to save your position. + *
  • On restart restore the position of the consumer using {@link #seek(TopicPartition, long)}. + *
+ * + *

+ * This type of usage is simplest when the partition assignment is also done manually (this would be likely in the + * search index use case described above). If the partition assignment is done automatically special care is + * needed to handle the case where partition assignments change. This can be done by providing a + * {@link ConsumerRebalanceListener} instance in the call to {@link #subscribe(Collection, ConsumerRebalanceListener)} + * and {@link #subscribe(Pattern, ConsumerRebalanceListener)}. + * For example, when partitions are taken from a consumer the consumer will want to commit its offset for those partitions by + * implementing {@link ConsumerRebalanceListener#onPartitionsRevoked(Collection)}. When partitions are assigned to a + * consumer, the consumer will want to look up the offset for those new partitions and correctly initialize the consumer + * to that position by implementing {@link ConsumerRebalanceListener#onPartitionsAssigned(Collection)}. + *

+ * Another common use for {@link ConsumerRebalanceListener} is to flush any caches the application maintains for + * partitions that are moved elsewhere. + * + *
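+ * A minimal sketch of the save-and-restore pattern described above, where {@code saveOffsetInExternalStore} and
+ * {@code readOffsetFromExternalStore} are hypothetical application methods backed by the external store:
+ *
+ *     consumer.subscribe(Arrays.asList("foo"), new ConsumerRebalanceListener() {
+ *         public void onPartitionsRevoked(Collection<TopicPartition> partitions) {
+ *             for (TopicPartition partition : partitions)
+ *                 saveOffsetInExternalStore(partition, consumer.position(partition));
+ *         }
+ *         public void onPartitionsAssigned(Collection<TopicPartition> partitions) {
+ *             for (TopicPartition partition : partitions)
+ *                 consumer.seek(partition, readOffsetFromExternalStore(partition));
+ *         }
+ *     });
+ *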

Controlling The Consumer's Position

+ * + * In most use cases the consumer will simply consume records from beginning to end, periodically committing its + * position (either automatically or manually). However Kafka allows the consumer to manually control its position, + * moving forward or backwards in a partition at will. This means a consumer can re-consume older records, or skip to + * the most recent records without actually consuming the intermediate records. + *

+ * There are several instances where manually controlling the consumer's position can be useful. + *

+ * One case is time-sensitive record processing: for a consumer that falls far enough behind, it may make sense not to
+ * attempt to catch up by processing all records, but rather to just skip to the most recent records.
+ *

+ * Another use case is for a system that maintains local state as described in the previous section. In such a system + * the consumer will want to initialize its position on start-up to whatever is contained in the local store. Likewise + * if the local state is destroyed (say because the disk is lost) the state may be recreated on a new machine by + * re-consuming all the data and recreating the state (assuming that Kafka is retaining sufficient history). + *

+ * Kafka allows specifying the new position using {@link #seek(TopicPartition, long)}. Special methods for seeking to
+ * the earliest and latest offset the server maintains are also available
+ * ({@link #seekToBeginning(Collection)} and {@link #seekToEnd(Collection)} respectively).
+ *
+ *
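+ * For example (the {@code partition} variable and the target offset are illustrative):
+ *
+ *     consumer.seek(partition, 42);                               // jump to a specific offset
+ *     consumer.seekToBeginning(Collections.singleton(partition)); // or to the earliest available offset
+ *     consumer.seekToEnd(Collections.singleton(partition));       // or to the latest offset
+ *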

Consumption Flow Control

+ *
+ * If a consumer is assigned multiple partitions to fetch data from, it will try to consume from all of them at the same time,
+ * effectively giving these partitions the same priority for consumption. However, in some cases consumers may want to
+ * first focus on fetching from some subset of the assigned partitions at full speed, and only start fetching other partitions
+ * when these partitions have little or no data to consume.
+ *

+ * One such case is stream processing, where the processor fetches from two topics and performs a join on the two streams.
+ * When one of the topics lags far behind the other, the processor may want to pause fetching from the topic that is ahead
+ * in order to let the lagging stream catch up. Another example is bootstrapping on consumer start-up when there is a lot of
+ * historical data to catch up on: the application usually wants to get up to date on some of the topics before it considers
+ * fetching others.
+ *
+ *

+ *
+ * Kafka supports dynamic control of consumption flow: {@link #pause(Collection)} pauses consumption of the specified
+ * assigned partitions and {@link #resume(Collection)} resumes consumption of the specified paused partitions in
+ * future {@link #poll(Duration)} calls.
+ *
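+ * A minimal sketch for the join example above, where {@code aheadPartition} is a hypothetical partition that has
+ * run ahead of the other input:
+ *
+ *     consumer.pause(Collections.singleton(aheadPartition));
+ *     // ... keep polling; once the lagging input has caught up:
+ *     consumer.resume(Collections.singleton(aheadPartition));
+ *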

Reading Transactional Messages

+ * + *

+ *
+ * Transactions were introduced in Kafka 0.11.0 wherein applications can write to multiple topics and partitions atomically.
+ * In order for this to work, consumers reading from these partitions should be configured to only read committed data.
+ * This can be achieved by setting {@code isolation.level=read_committed} in the consumer's configuration.
+ *
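+ * For example (reusing the {@code props} object from the earlier examples):
+ *
+ *     props.setProperty("isolation.level", "read_committed");
+ *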

+ * In read_committed mode, the consumer will read only those transactional messages which have been
+ * successfully committed. It will continue to read non-transactional messages as before. There is no client-side
+ * buffering in read_committed mode. Instead, the end offset of a partition for a read_committed
+ * consumer would be the offset of the first message in the partition belonging to an open transaction. This offset
+ * is known as the 'Last Stable Offset' (LSO).

+ * + *

+ * A {@code read_committed} consumer will only read up to the LSO and filter out any transactional + * messages which have been aborted. The LSO also affects the behavior of {@link #seekToEnd(Collection)} and + * {@link #endOffsets(Collection)} for {@code read_committed} consumers, details of which are in each method's documentation. + * Finally, the fetch lag metrics are also adjusted to be relative to the LSO for {@code read_committed} consumers. + * + *

+ * Partitions with transactional messages will include commit or abort markers which indicate the result of a transaction.
+ * These markers are not returned to applications, yet have an offset in the log. As a result, applications reading from
+ * topics with transactional messages will see gaps in the consumed offsets. These missing messages would be the transaction
+ * markers, and they are filtered out for consumers in both isolation levels. Additionally, applications using
+ * {@code read_committed} consumers may also see gaps due to aborted transactions, since those messages would not
+ * be returned by the consumer and yet would have valid offsets.
+ *

Multi-threaded Processing

+ * + * The Kafka consumer is NOT thread-safe. All network I/O happens in the thread of the application + * making the call. It is the responsibility of the user to ensure that multi-threaded access + * is properly synchronized. Un-synchronized access will result in {@link ConcurrentModificationException}. + * + *

+ * The only exception to this rule is {@link #wakeup()}, which can safely be used from an external thread to + * interrupt an active operation. In this case, a {@link org.apache.kafka.common.errors.WakeupException} will be + * thrown from the thread blocking on the operation. This can be used to shutdown the consumer from another thread. + * The following snippet shows the typical pattern: + * + *

+ * public class KafkaConsumerRunner implements Runnable {
+ *     private final AtomicBoolean closed = new AtomicBoolean(false);
+ *     private final KafkaConsumer consumer;
+ *
+ *     public KafkaConsumerRunner(KafkaConsumer consumer) {
+ *       this.consumer = consumer;
+ *     }
+ *
+ *     {@literal @}Override
+ *     public void run() {
+ *         try {
+ *             consumer.subscribe(Arrays.asList("topic"));
+ *             while (!closed.get()) {
+ *                 ConsumerRecords records = consumer.poll(Duration.ofMillis(10000));
+ *                 // Handle new records
+ *             }
+ *         } catch (WakeupException e) {
+ *             // Ignore exception if closing
+ *             if (!closed.get()) throw e;
+ *         } finally {
+ *             consumer.close();
+ *         }
+ *     }
+ *
+ *     // Shutdown hook which can be called from a separate thread
+ *     public void shutdown() {
+ *         closed.set(true);
+ *         consumer.wakeup();
+ *     }
+ * }
+ * 
+ * + * Then in a separate thread, the consumer can be shutdown by setting the closed flag and waking up the consumer. + * + *

+ *

+ *     closed.set(true);
+ *     consumer.wakeup();
+ * 
+ * + *

+ * Note that while it is possible to use thread interrupts instead of {@link #wakeup()} to abort a blocking operation + * (in which case, {@link InterruptException} will be raised), we discourage their use since they may cause a clean + * shutdown of the consumer to be aborted. Interrupts are mainly supported for those cases where using {@link #wakeup()} + * is impossible, e.g. when a consumer thread is managed by code that is unaware of the Kafka client. + * + *

+ * We have intentionally avoided implementing a particular threading model for processing. This leaves several + * options for implementing multi-threaded processing of records. + * + *

1. One Consumer Per Thread

+ * + * A simple option is to give each thread its own consumer instance. Here are the pros and cons of this approach: + *
    + *
  • PRO: It is the easiest to implement + *
  • PRO: It is often the fastest as no inter-thread co-ordination is needed + *
  • PRO: It makes in-order processing on a per-partition basis very easy to implement (each thread just + * processes messages in the order it receives them). + *
  • CON: More consumers means more TCP connections to the cluster (one per thread). In general Kafka handles + * connections very efficiently so this is generally a small cost. + *
  • CON: Multiple consumers means more requests being sent to the server and slightly less batching of data + * which can cause some drop in I/O throughput. + *
  • CON: The number of total threads across all processes will be limited by the total number of partitions. + *
+ * + *

2. Decouple Consumption and Processing

+ *
+ * Another alternative is to have one or more consumer threads that do all data consumption and hand off
+ * {@link ConsumerRecords} instances to a blocking queue consumed by a pool of processor threads that actually handle
+ * the record processing.
+ *
+ * This option likewise has pros and cons:
+ *
    + *
  • PRO: This option allows independently scaling the number of consumers and processors. This makes it + * possible to have a single consumer that feeds many processor threads, avoiding any limitation on partitions. + *
  • CON: Guaranteeing order across the processors requires particular care: because the threads execute
+ * independently, an earlier chunk of data may actually be processed after a later chunk of data simply due to
+ * thread scheduling. For processing that has no ordering requirements this is not a problem.
  • CON: Manually committing the position becomes harder as it requires that all threads co-ordinate to ensure + * that processing is complete for that partition. + *
+ * + * There are many possible variations on this approach. For example each processor thread can have its own queue, and + * the consumer threads can hash into these queues using the TopicPartition to ensure in-order consumption and simplify + * commit. + */ +public class KafkaConsumer implements Consumer { + + private static final String CLIENT_ID_METRIC_TAG = "client-id"; + private static final long NO_CURRENT_THREAD = -1L; + private static final String JMX_PREFIX = "kafka.consumer"; + static final long DEFAULT_CLOSE_TIMEOUT_MS = 30 * 1000; + + // Visible for testing + final Metrics metrics; + final KafkaConsumerMetrics kafkaConsumerMetrics; + + private Logger log; + private final String clientId; + private final Optional groupId; + private final ConsumerCoordinator coordinator; + private final Deserializer keyDeserializer; + private final Deserializer valueDeserializer; + private final Fetcher fetcher; + private final ConsumerInterceptors interceptors; + private final IsolationLevel isolationLevel; + + private final Time time; + private final ConsumerNetworkClient client; + private final SubscriptionState subscriptions; + private final ConsumerMetadata metadata; + private final long retryBackoffMs; + private final long requestTimeoutMs; + private final int defaultApiTimeoutMs; + private volatile boolean closed = false; + private List assignors; + + // currentThread holds the threadId of the current thread accessing KafkaConsumer + // and is used to prevent multi-threaded access + private final AtomicLong currentThread = new AtomicLong(NO_CURRENT_THREAD); + // refcount is used to allow reentrant access by the thread who has acquired currentThread + private final AtomicInteger refcount = new AtomicInteger(0); + + // to keep from repeatedly scanning subscriptions in poll(), cache the result during metadata updates + private boolean cachedSubscriptionHashAllFetchPositions; + + /** + * A consumer is instantiated by providing a set of key-value pairs as configuration. Valid configuration strings + * are documented here. Values can be + * either strings or objects of the appropriate type (for example a numeric configuration would accept either the + * string "42" or the integer 42). + *

+ * Valid configuration strings are documented at {@link ConsumerConfig}. + *
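+ * <p>
+ * For example (a minimal sketch; the broker address and group name are placeholders):
+ * <pre>{@code
+ * Map<String, Object> configs = new HashMap<>();
+ * configs.put(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9092");
+ * configs.put(ConsumerConfig.GROUP_ID_CONFIG, "my-group");
+ * configs.put(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, StringDeserializer.class);
+ * configs.put(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, StringDeserializer.class);
+ * KafkaConsumer<String, String> consumer = new KafkaConsumer<>(configs);
+ * }</pre>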

+ * Note: after creating a {@code KafkaConsumer} you must always {@link #close()} it to avoid resource leaks. + * + * @param configs The consumer configs + */ + public KafkaConsumer(Map configs) { + this(configs, null, null); + } + + /** + * A consumer is instantiated by providing a {@link java.util.Properties} object as configuration. + *

+ * Valid configuration strings are documented at {@link ConsumerConfig}. + *

+ * Note: after creating a {@code KafkaConsumer} you must always {@link #close()} it to avoid resource leaks. + * + * @param properties The consumer configuration properties + */ + public KafkaConsumer(Properties properties) { + this(properties, null, null); + } + + /** + * A consumer is instantiated by providing a {@link java.util.Properties} object as configuration, and a + * key and a value {@link Deserializer}. + *

+ * Valid configuration strings are documented at {@link ConsumerConfig}. + *
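+ * <p>
+ * For example (a minimal sketch; the broker address and group name are placeholders). Because the deserializers
+ * are passed in directly, {@code key.deserializer} and {@code value.deserializer} do not need to be set:
+ * <pre>{@code
+ * Properties props = new Properties();
+ * props.setProperty("bootstrap.servers", "localhost:9092");
+ * props.setProperty("group.id", "my-group");
+ * KafkaConsumer<String, String> consumer =
+ *     new KafkaConsumer<>(props, new StringDeserializer(), new StringDeserializer());
+ * }</pre>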

+ * Note: after creating a {@code KafkaConsumer} you must always {@link #close()} it to avoid resource leaks. + * + * @param properties The consumer configuration properties + * @param keyDeserializer The deserializer for key that implements {@link Deserializer}. The configure() method + * won't be called in the consumer when the deserializer is passed in directly. + * @param valueDeserializer The deserializer for value that implements {@link Deserializer}. The configure() method + * won't be called in the consumer when the deserializer is passed in directly. + */ + public KafkaConsumer(Properties properties, + Deserializer keyDeserializer, + Deserializer valueDeserializer) { + this(Utils.propsToMap(properties), keyDeserializer, valueDeserializer); + } + + /** + * A consumer is instantiated by providing a set of key-value pairs as configuration, and a key and a value {@link Deserializer}. + *

+ * Valid configuration strings are documented at {@link ConsumerConfig}. + *

+ * Note: after creating a {@code KafkaConsumer} you must always {@link #close()} it to avoid resource leaks. + * + * @param configs The consumer configs + * @param keyDeserializer The deserializer for key that implements {@link Deserializer}. The configure() method + * won't be called in the consumer when the deserializer is passed in directly. + * @param valueDeserializer The deserializer for value that implements {@link Deserializer}. The configure() method + * won't be called in the consumer when the deserializer is passed in directly. + */ + public KafkaConsumer(Map configs, + Deserializer keyDeserializer, + Deserializer valueDeserializer) { + this(new ConsumerConfig(ConsumerConfig.appendDeserializerToConfig(configs, keyDeserializer, valueDeserializer)), + keyDeserializer, valueDeserializer); + } + + @SuppressWarnings("unchecked") + KafkaConsumer(ConsumerConfig config, Deserializer keyDeserializer, Deserializer valueDeserializer) { + try { + GroupRebalanceConfig groupRebalanceConfig = new GroupRebalanceConfig(config, + GroupRebalanceConfig.ProtocolType.CONSUMER); + + this.groupId = Optional.ofNullable(groupRebalanceConfig.groupId); + this.clientId = config.getString(CommonClientConfigs.CLIENT_ID_CONFIG); + + LogContext logContext; + + // If group.instance.id is set, we will append it to the log context. + if (groupRebalanceConfig.groupInstanceId.isPresent()) { + logContext = new LogContext("[Consumer instanceId=" + groupRebalanceConfig.groupInstanceId.get() + + ", clientId=" + clientId + ", groupId=" + groupId.orElse("null") + "] "); + } else { + logContext = new LogContext("[Consumer clientId=" + clientId + ", groupId=" + groupId.orElse("null") + "] "); + } + + this.log = logContext.logger(getClass()); + boolean enableAutoCommit = config.maybeOverrideEnableAutoCommit(); + groupId.ifPresent(groupIdStr -> { + if (groupIdStr.isEmpty()) { + log.warn("Support for using the empty group id by consumers is deprecated and will be removed in the next major release."); + } + }); + + log.debug("Initializing the Kafka consumer"); + this.requestTimeoutMs = config.getInt(ConsumerConfig.REQUEST_TIMEOUT_MS_CONFIG); + this.defaultApiTimeoutMs = config.getInt(ConsumerConfig.DEFAULT_API_TIMEOUT_MS_CONFIG); + this.time = Time.SYSTEM; + this.metrics = buildMetrics(config, time, clientId); + this.retryBackoffMs = config.getLong(ConsumerConfig.RETRY_BACKOFF_MS_CONFIG); + + List> interceptorList = (List) config.getConfiguredInstances( + ConsumerConfig.INTERCEPTOR_CLASSES_CONFIG, + ConsumerInterceptor.class, + Collections.singletonMap(ConsumerConfig.CLIENT_ID_CONFIG, clientId)); + this.interceptors = new ConsumerInterceptors<>(interceptorList); + if (keyDeserializer == null) { + this.keyDeserializer = config.getConfiguredInstance(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, Deserializer.class); + this.keyDeserializer.configure(config.originals(Collections.singletonMap(ConsumerConfig.CLIENT_ID_CONFIG, clientId)), true); + } else { + config.ignore(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG); + this.keyDeserializer = keyDeserializer; + } + if (valueDeserializer == null) { + this.valueDeserializer = config.getConfiguredInstance(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, Deserializer.class); + this.valueDeserializer.configure(config.originals(Collections.singletonMap(ConsumerConfig.CLIENT_ID_CONFIG, clientId)), false); + } else { + config.ignore(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG); + this.valueDeserializer = valueDeserializer; + } + OffsetResetStrategy offsetResetStrategy = 
OffsetResetStrategy.valueOf(config.getString(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG).toUpperCase(Locale.ROOT)); + this.subscriptions = new SubscriptionState(logContext, offsetResetStrategy); + ClusterResourceListeners clusterResourceListeners = configureClusterResourceListeners(keyDeserializer, + valueDeserializer, metrics.reporters(), interceptorList); + this.metadata = new ConsumerMetadata(retryBackoffMs, + config.getLong(ConsumerConfig.METADATA_MAX_AGE_CONFIG), + !config.getBoolean(ConsumerConfig.EXCLUDE_INTERNAL_TOPICS_CONFIG), + config.getBoolean(ConsumerConfig.ALLOW_AUTO_CREATE_TOPICS_CONFIG), + subscriptions, logContext, clusterResourceListeners); + List addresses = ClientUtils.parseAndValidateAddresses( + config.getList(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG), config.getString(ConsumerConfig.CLIENT_DNS_LOOKUP_CONFIG)); + this.metadata.bootstrap(addresses); + String metricGrpPrefix = "consumer"; + + FetcherMetricsRegistry metricsRegistry = new FetcherMetricsRegistry(Collections.singleton(CLIENT_ID_METRIC_TAG), metricGrpPrefix); + ChannelBuilder channelBuilder = ClientUtils.createChannelBuilder(config, time, logContext); + this.isolationLevel = IsolationLevel.valueOf( + config.getString(ConsumerConfig.ISOLATION_LEVEL_CONFIG).toUpperCase(Locale.ROOT)); + Sensor throttleTimeSensor = Fetcher.throttleTimeSensor(metrics, metricsRegistry); + int heartbeatIntervalMs = config.getInt(ConsumerConfig.HEARTBEAT_INTERVAL_MS_CONFIG); + + ApiVersions apiVersions = new ApiVersions(); + NetworkClient netClient = new NetworkClient( + new Selector(config.getLong(ConsumerConfig.CONNECTIONS_MAX_IDLE_MS_CONFIG), metrics, time, metricGrpPrefix, channelBuilder, logContext), + this.metadata, + clientId, + 100, // a fixed large enough value will suffice for max in-flight requests + config.getLong(ConsumerConfig.RECONNECT_BACKOFF_MS_CONFIG), + config.getLong(ConsumerConfig.RECONNECT_BACKOFF_MAX_MS_CONFIG), + config.getInt(ConsumerConfig.SEND_BUFFER_CONFIG), + config.getInt(ConsumerConfig.RECEIVE_BUFFER_CONFIG), + config.getInt(ConsumerConfig.REQUEST_TIMEOUT_MS_CONFIG), + config.getLong(ConsumerConfig.SOCKET_CONNECTION_SETUP_TIMEOUT_MS_CONFIG), + config.getLong(ConsumerConfig.SOCKET_CONNECTION_SETUP_TIMEOUT_MAX_MS_CONFIG), + time, + true, + apiVersions, + throttleTimeSensor, + logContext); + this.client = new ConsumerNetworkClient( + logContext, + netClient, + metadata, + time, + retryBackoffMs, + config.getInt(ConsumerConfig.REQUEST_TIMEOUT_MS_CONFIG), + heartbeatIntervalMs); //Will avoid blocking an extended period of time to prevent heartbeat thread starvation + + this.assignors = ConsumerPartitionAssignor.getAssignorInstances( + config.getList(ConsumerConfig.PARTITION_ASSIGNMENT_STRATEGY_CONFIG), + config.originals(Collections.singletonMap(ConsumerConfig.CLIENT_ID_CONFIG, clientId)) + ); + + // no coordinator will be constructed for the default (null) group id + this.coordinator = !groupId.isPresent() ? 
null : + new ConsumerCoordinator(groupRebalanceConfig, + logContext, + this.client, + assignors, + this.metadata, + this.subscriptions, + metrics, + metricGrpPrefix, + this.time, + enableAutoCommit, + config.getInt(ConsumerConfig.AUTO_COMMIT_INTERVAL_MS_CONFIG), + this.interceptors, + config.getBoolean(ConsumerConfig.THROW_ON_FETCH_STABLE_OFFSET_UNSUPPORTED)); + this.fetcher = new Fetcher<>( + logContext, + this.client, + config.getInt(ConsumerConfig.FETCH_MIN_BYTES_CONFIG), + config.getInt(ConsumerConfig.FETCH_MAX_BYTES_CONFIG), + config.getInt(ConsumerConfig.FETCH_MAX_WAIT_MS_CONFIG), + config.getInt(ConsumerConfig.MAX_PARTITION_FETCH_BYTES_CONFIG), + config.getInt(ConsumerConfig.MAX_POLL_RECORDS_CONFIG), + config.getBoolean(ConsumerConfig.CHECK_CRCS_CONFIG), + config.getString(ConsumerConfig.CLIENT_RACK_CONFIG), + this.keyDeserializer, + this.valueDeserializer, + this.metadata, + this.subscriptions, + metrics, + metricsRegistry, + this.time, + this.retryBackoffMs, + this.requestTimeoutMs, + isolationLevel, + apiVersions); + + this.kafkaConsumerMetrics = new KafkaConsumerMetrics(metrics, metricGrpPrefix); + + config.logUnused(); + AppInfoParser.registerAppInfo(JMX_PREFIX, clientId, metrics, time.milliseconds()); + log.debug("Kafka consumer initialized"); + } catch (Throwable t) { + // call close methods if internal objects are already constructed; this is to prevent resource leak. see KAFKA-2121 + // we do not need to call `close` at all when `log` is null, which means no internal objects were initialized. + if (this.log != null) { + close(0, true); + } + // now propagate the exception + throw new KafkaException("Failed to construct kafka consumer", t); + } + } + + // visible for testing + KafkaConsumer(LogContext logContext, + String clientId, + ConsumerCoordinator coordinator, + Deserializer keyDeserializer, + Deserializer valueDeserializer, + Fetcher fetcher, + ConsumerInterceptors interceptors, + Time time, + ConsumerNetworkClient client, + Metrics metrics, + SubscriptionState subscriptions, + ConsumerMetadata metadata, + long retryBackoffMs, + long requestTimeoutMs, + int defaultApiTimeoutMs, + List assignors, + String groupId) { + this.log = logContext.logger(getClass()); + this.clientId = clientId; + this.coordinator = coordinator; + this.keyDeserializer = keyDeserializer; + this.valueDeserializer = valueDeserializer; + this.fetcher = fetcher; + this.isolationLevel = IsolationLevel.READ_UNCOMMITTED; + this.interceptors = Objects.requireNonNull(interceptors); + this.time = time; + this.client = client; + this.metrics = metrics; + this.subscriptions = subscriptions; + this.metadata = metadata; + this.retryBackoffMs = retryBackoffMs; + this.requestTimeoutMs = requestTimeoutMs; + this.defaultApiTimeoutMs = defaultApiTimeoutMs; + this.assignors = assignors; + this.groupId = Optional.ofNullable(groupId); + this.kafkaConsumerMetrics = new KafkaConsumerMetrics(metrics, "consumer"); + } + + private static Metrics buildMetrics(ConsumerConfig config, Time time, String clientId) { + Map metricsTags = Collections.singletonMap(CLIENT_ID_METRIC_TAG, clientId); + MetricConfig metricConfig = new MetricConfig().samples(config.getInt(ConsumerConfig.METRICS_NUM_SAMPLES_CONFIG)) + .timeWindow(config.getLong(ConsumerConfig.METRICS_SAMPLE_WINDOW_MS_CONFIG), TimeUnit.MILLISECONDS) + .recordLevel(Sensor.RecordingLevel.forName(config.getString(ConsumerConfig.METRICS_RECORDING_LEVEL_CONFIG))) + .tags(metricsTags); + List reporters = 
config.getConfiguredInstances(ConsumerConfig.METRIC_REPORTER_CLASSES_CONFIG, + MetricsReporter.class, Collections.singletonMap(ConsumerConfig.CLIENT_ID_CONFIG, clientId)); + JmxReporter jmxReporter = new JmxReporter(); + jmxReporter.configure(config.originals(Collections.singletonMap(ConsumerConfig.CLIENT_ID_CONFIG, clientId))); + reporters.add(jmxReporter); + MetricsContext metricsContext = new KafkaMetricsContext(JMX_PREFIX, + config.originalsWithPrefix(CommonClientConfigs.METRICS_CONTEXT_PREFIX)); + return new Metrics(metricConfig, reporters, time, metricsContext); + } + + /** + * Get the set of partitions currently assigned to this consumer. If subscription happened by directly assigning + * partitions using {@link #assign(Collection)} then this will simply return the same partitions that + * were assigned. If topic subscription was used, then this will give the set of topic partitions currently assigned + * to the consumer (which may be none if the assignment hasn't happened yet, or the partitions are in the + * process of getting reassigned). + * @return The set of partitions currently assigned to this consumer + */ + public Set assignment() { + acquireAndEnsureOpen(); + try { + return Collections.unmodifiableSet(this.subscriptions.assignedPartitions()); + } finally { + release(); + } + } + + /** + * Get the current subscription. Will return the same topics used in the most recent call to + * {@link #subscribe(Collection, ConsumerRebalanceListener)}, or an empty set if no such call has been made. + * @return The set of topics currently subscribed to + */ + public Set subscription() { + acquireAndEnsureOpen(); + try { + return Collections.unmodifiableSet(new HashSet<>(this.subscriptions.subscription())); + } finally { + release(); + } + } + + /** + * Subscribe to the given list of topics to get dynamically + * assigned partitions. Topic subscriptions are not incremental. This list will replace the current + * assignment (if there is one). Note that it is not possible to combine topic subscription with group management + * with manual partition assignment through {@link #assign(Collection)}. + * + * If the given list of topics is empty, it is treated the same as {@link #unsubscribe()}. + * + *

+ * As part of group management, the consumer will keep track of the list of consumers that belong to a particular
+ * group and will trigger a rebalance operation if any one of the following events occurs:
+ * <ul>
+ *   <li>The number of partitions changes for any of the subscribed topics
+ *   <li>A subscribed topic is created or deleted
+ *   <li>An existing member of the consumer group is shut down or fails
+ *   <li>A new member is added to the consumer group
+ * </ul>
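+ * <p>
+ * For illustration only, a sketch of a listener that commits application-tracked offsets before partitions are
+ * revoked ({@code currentOffsets} is assumed to be a {@code Map<TopicPartition, OffsetAndMetadata>} maintained by
+ * the application as it processes records):
+ * <pre>{@code
+ * consumer.subscribe(Collections.singletonList("my-topic"), new ConsumerRebalanceListener() {
+ *     public void onPartitionsRevoked(Collection<TopicPartition> partitions) {
+ *         // commit what has been processed so far before giving up ownership of the partitions
+ *         consumer.commitSync(currentOffsets);
+ *     }
+ *     public void onPartitionsAssigned(Collection<TopicPartition> partitions) {
+ *         // positions for the newly assigned partitions are fetched on the next poll()
+ *     }
+ * });
+ * }</pre>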

+ * When any of these events are triggered, the provided listener will be invoked first to indicate that + * the consumer's assignment has been revoked, and then again when the new assignment has been received. + * Note that rebalances will only occur during an active call to {@link #poll(Duration)}, so callbacks will + * also only be invoked during that time. + * + * The provided listener will immediately override any listener set in a previous call to subscribe. + * It is guaranteed, however, that the partitions revoked/assigned through this interface are from topics + * subscribed in this call. See {@link ConsumerRebalanceListener} for more details. + * + * @param topics The list of topics to subscribe to + * @param listener Non-null listener instance to get notifications on partition assignment/revocation for the + * subscribed topics + * @throws IllegalArgumentException If topics is null or contains null or empty elements, or if listener is null + * @throws IllegalStateException If {@code subscribe()} is called previously with pattern, or assign is called + * previously (without a subsequent call to {@link #unsubscribe()}), or if not + * configured at-least one partition assignment strategy + */ + @Override + public void subscribe(Collection topics, ConsumerRebalanceListener listener) { + acquireAndEnsureOpen(); + try { + maybeThrowInvalidGroupIdException(); + if (topics == null) + throw new IllegalArgumentException("Topic collection to subscribe to cannot be null"); + if (topics.isEmpty()) { + // treat subscribing to empty topic list as the same as unsubscribing + this.unsubscribe(); + } else { + for (String topic : topics) { + if (Utils.isBlank(topic)) + throw new IllegalArgumentException("Topic collection to subscribe to cannot contain null or empty topic"); + } + + throwIfNoAssignorsConfigured(); + fetcher.clearBufferedDataForUnassignedTopics(topics); + log.info("Subscribed to topic(s): {}", Utils.join(topics, ", ")); + if (this.subscriptions.subscribe(new HashSet<>(topics), listener)) + metadata.requestUpdateForNewTopics(); + } + } finally { + release(); + } + } + + /** + * Subscribe to the given list of topics to get dynamically assigned partitions. + * Topic subscriptions are not incremental. This list will replace the current + * assignment (if there is one). It is not possible to combine topic subscription with group management + * with manual partition assignment through {@link #assign(Collection)}. + * + * If the given list of topics is empty, it is treated the same as {@link #unsubscribe()}. + * + *

+ * This is a short-hand for {@link #subscribe(Collection, ConsumerRebalanceListener)}, which + * uses a no-op listener. If you need the ability to seek to particular offsets, you should prefer + * {@link #subscribe(Collection, ConsumerRebalanceListener)}, since group rebalances will cause partition offsets + * to be reset. You should also provide your own listener if you are doing your own offset + * management since the listener gives you an opportunity to commit offsets before a rebalance finishes. + * + * @param topics The list of topics to subscribe to + * @throws IllegalArgumentException If topics is null or contains null or empty elements + * @throws IllegalStateException If {@code subscribe()} is called previously with pattern, or assign is called + * previously (without a subsequent call to {@link #unsubscribe()}), or if not + * configured at-least one partition assignment strategy + */ + @Override + public void subscribe(Collection topics) { + subscribe(topics, new NoOpConsumerRebalanceListener()); + } + + /** + * Subscribe to all topics matching specified pattern to get dynamically assigned partitions. + * The pattern matching will be done periodically against all topics existing at the time of check. + * This can be controlled through the {@code metadata.max.age.ms} configuration: by lowering + * the max metadata age, the consumer will refresh metadata more often and check for matching topics. + *
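+ * <p>
+ * For example (the pattern is illustrative, and {@code rebalanceListener} is assumed to be an application-provided
+ * {@link ConsumerRebalanceListener}):
+ * <pre>{@code
+ * // subscribe to every topic whose name starts with "orders-"
+ * consumer.subscribe(Pattern.compile("orders-.*"), rebalanceListener);
+ * }</pre>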

+ * See {@link #subscribe(Collection, ConsumerRebalanceListener)} for details on the + * use of the {@link ConsumerRebalanceListener}. Generally rebalances are triggered when there + * is a change to the topics matching the provided pattern and when consumer group membership changes. + * Group rebalances only take place during an active call to {@link #poll(Duration)}. + * + * @param pattern Pattern to subscribe to + * @param listener Non-null listener instance to get notifications on partition assignment/revocation for the + * subscribed topics + * @throws IllegalArgumentException If pattern or listener is null + * @throws IllegalStateException If {@code subscribe()} is called previously with topics, or assign is called + * previously (without a subsequent call to {@link #unsubscribe()}), or if not + * configured at-least one partition assignment strategy + */ + @Override + public void subscribe(Pattern pattern, ConsumerRebalanceListener listener) { + maybeThrowInvalidGroupIdException(); + if (pattern == null || pattern.toString().equals("")) + throw new IllegalArgumentException("Topic pattern to subscribe to cannot be " + (pattern == null ? + "null" : "empty")); + + acquireAndEnsureOpen(); + try { + throwIfNoAssignorsConfigured(); + log.info("Subscribed to pattern: '{}'", pattern); + this.subscriptions.subscribe(pattern, listener); + this.coordinator.updatePatternSubscription(metadata.fetch()); + this.metadata.requestUpdateForNewTopics(); + } finally { + release(); + } + } + + /** + * Subscribe to all topics matching specified pattern to get dynamically assigned partitions. + * The pattern matching will be done periodically against topics existing at the time of check. + *

+ * This is a short-hand for {@link #subscribe(Pattern, ConsumerRebalanceListener)}, which + * uses a no-op listener. If you need the ability to seek to particular offsets, you should prefer + * {@link #subscribe(Pattern, ConsumerRebalanceListener)}, since group rebalances will cause partition offsets + * to be reset. You should also provide your own listener if you are doing your own offset + * management since the listener gives you an opportunity to commit offsets before a rebalance finishes. + * + * @param pattern Pattern to subscribe to + * @throws IllegalArgumentException If pattern is null + * @throws IllegalStateException If {@code subscribe()} is called previously with topics, or assign is called + * previously (without a subsequent call to {@link #unsubscribe()}), or if not + * configured at-least one partition assignment strategy + */ + @Override + public void subscribe(Pattern pattern) { + subscribe(pattern, new NoOpConsumerRebalanceListener()); + } + + /** + * Unsubscribe from topics currently subscribed with {@link #subscribe(Collection)} or {@link #subscribe(Pattern)}. + * This also clears any partitions directly assigned through {@link #assign(Collection)}. + * + * @throws org.apache.kafka.common.KafkaException for any other unrecoverable errors (e.g. rebalance callback errors) + */ + public void unsubscribe() { + acquireAndEnsureOpen(); + try { + fetcher.clearBufferedDataForUnassignedPartitions(Collections.emptySet()); + if (this.coordinator != null) { + this.coordinator.onLeavePrepare(); + this.coordinator.maybeLeaveGroup("the consumer unsubscribed from all topics"); + } + this.subscriptions.unsubscribe(); + log.info("Unsubscribed all topics or patterns and assigned partitions"); + } finally { + release(); + } + } + + /** + * Manually assign a list of partitions to this consumer. This interface does not allow for incremental assignment + * and will replace the previous assignment (if there is one). + *

+ * If the given list of topic partitions is empty, it is treated the same as {@link #unsubscribe()}. + *

+ * Manual topic assignment through this method does not use the consumer's group management + * functionality. As such, there will be no rebalance operation triggered when group membership or cluster and topic + * metadata change. Note that it is not possible to use both manual partition assignment with {@link #assign(Collection)} + * and group assignment with {@link #subscribe(Collection, ConsumerRebalanceListener)}. + *
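+ * <p>
+ * For example, to consume partitions 0 and 1 of a topic directly (the topic name is a placeholder):
+ * <pre>{@code
+ * TopicPartition partition0 = new TopicPartition("my-topic", 0);
+ * TopicPartition partition1 = new TopicPartition("my-topic", 1);
+ * consumer.assign(Arrays.asList(partition0, partition1));
+ * }</pre>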

+ * If auto-commit is enabled, an async commit (based on the old assignment) will be triggered before the new + * assignment replaces the old one. + * + * @param partitions The list of partitions to assign this consumer + * @throws IllegalArgumentException If partitions is null or contains null or empty topics + * @throws IllegalStateException If {@code subscribe()} is called previously with topics or pattern + * (without a subsequent call to {@link #unsubscribe()}) + */ + @Override + public void assign(Collection partitions) { + acquireAndEnsureOpen(); + try { + if (partitions == null) { + throw new IllegalArgumentException("Topic partition collection to assign to cannot be null"); + } else if (partitions.isEmpty()) { + this.unsubscribe(); + } else { + for (TopicPartition tp : partitions) { + String topic = (tp != null) ? tp.topic() : null; + if (Utils.isBlank(topic)) + throw new IllegalArgumentException("Topic partitions to assign to cannot have null or empty topic"); + } + fetcher.clearBufferedDataForUnassignedPartitions(partitions); + + // make sure the offsets of topic partitions the consumer is unsubscribing from + // are committed since there will be no following rebalance + if (coordinator != null) + this.coordinator.maybeAutoCommitOffsetsAsync(time.milliseconds()); + + log.info("Subscribed to partition(s): {}", Utils.join(partitions, ", ")); + if (this.subscriptions.assignFromUser(new HashSet<>(partitions))) + metadata.requestUpdateForNewTopics(); + } + } finally { + release(); + } + } + + /** + * Fetch data for the topics or partitions specified using one of the subscribe/assign APIs. It is an error to not have + * subscribed to any topics or partitions before polling for data. + *

+ * On each poll, consumer will try to use the last consumed offset as the starting offset and fetch sequentially. The last + * consumed offset can be manually set through {@link #seek(TopicPartition, long)} or automatically set as the last committed + * offset for the subscribed list of partitions + * + * + * @param timeoutMs The time, in milliseconds, spent waiting in poll if data is not available in the buffer. + * If 0, returns immediately with any records that are available currently in the buffer, else returns empty. + * Must not be negative. + * @return map of topic to records since the last fetch for the subscribed list of topics and partitions + * + * @throws org.apache.kafka.clients.consumer.InvalidOffsetException if the offset for a partition or set of + * partitions is undefined or out of range and no offset reset policy has been configured + * @throws org.apache.kafka.common.errors.WakeupException if {@link #wakeup()} is called before or while this + * function is called + * @throws org.apache.kafka.common.errors.InterruptException if the calling thread is interrupted before or while + * this function is called + * @throws org.apache.kafka.common.errors.AuthenticationException if authentication fails. See the exception for more details + * @throws org.apache.kafka.common.errors.AuthorizationException if caller lacks Read access to any of the subscribed + * topics or to the configured groupId. See the exception for more details + * @throws org.apache.kafka.common.KafkaException for any other unrecoverable errors (e.g. invalid groupId or + * session timeout, errors deserializing key/value pairs, or any new error cases in future versions) + * @throws java.lang.IllegalArgumentException if the timeout value is negative + * @throws java.lang.IllegalStateException if the consumer is not subscribed to any topics or manually assigned any + * partitions to consume from + * @throws org.apache.kafka.common.errors.FencedInstanceIdException if this consumer instance gets fenced by broker. + * + * @deprecated Since 2.0. Use {@link #poll(Duration)}, which does not block beyond the timeout awaiting partition + * assignment. See KIP-266 for more information. + */ + @Deprecated + @Override + public ConsumerRecords poll(final long timeoutMs) { + return poll(time.timer(timeoutMs), false); + } + + /** + * Fetch data for the topics or partitions specified using one of the subscribe/assign APIs. It is an error to not have + * subscribed to any topics or partitions before polling for data. + *

+ * On each poll, consumer will try to use the last consumed offset as the starting offset and fetch sequentially. The last + * consumed offset can be manually set through {@link #seek(TopicPartition, long)} or automatically set as the last committed + * offset for the subscribed list of partitions + * + *
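+ * <p>
+ * For illustration only, a typical poll loop might look like this ({@code running} and {@code process(record)}
+ * are placeholders for application logic):
+ * <pre>{@code
+ * while (running) {
+ *     ConsumerRecords<String, String> records = consumer.poll(Duration.ofMillis(100));
+ *     for (ConsumerRecord<String, String> record : records)
+ *         process(record);
+ * }
+ * }</pre>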

+ * This method returns immediately if there are records available. Otherwise, it will await the passed timeout. + * If the timeout expires, an empty record set will be returned. Note that this method may block beyond the + * timeout in order to execute custom {@link ConsumerRebalanceListener} callbacks. + * + * + * @param timeout The maximum time to block (must not be greater than {@link Long#MAX_VALUE} milliseconds) + * + * @return map of topic to records since the last fetch for the subscribed list of topics and partitions + * + * @throws org.apache.kafka.clients.consumer.InvalidOffsetException if the offset for a partition or set of + * partitions is undefined or out of range and no offset reset policy has been configured + * @throws org.apache.kafka.common.errors.WakeupException if {@link #wakeup()} is called before or while this + * function is called + * @throws org.apache.kafka.common.errors.InterruptException if the calling thread is interrupted before or while + * this function is called + * @throws org.apache.kafka.common.errors.AuthenticationException if authentication fails. See the exception for more details + * @throws org.apache.kafka.common.errors.AuthorizationException if caller lacks Read access to any of the subscribed + * topics or to the configured groupId. See the exception for more details + * @throws org.apache.kafka.common.KafkaException for any other unrecoverable errors (e.g. invalid groupId or + * session timeout, errors deserializing key/value pairs, your rebalance callback thrown exceptions, + * or any new error cases in future versions) + * @throws java.lang.IllegalArgumentException if the timeout value is negative + * @throws java.lang.IllegalStateException if the consumer is not subscribed to any topics or manually assigned any + * partitions to consume from + * @throws java.lang.ArithmeticException if the timeout is greater than {@link Long#MAX_VALUE} milliseconds. + * @throws org.apache.kafka.common.errors.InvalidTopicException if the current subscription contains any invalid + * topic (per {@link org.apache.kafka.common.internals.Topic#validate(String)}) + * @throws org.apache.kafka.common.errors.UnsupportedVersionException if the consumer attempts to fetch stable offsets + * when the broker doesn't support this feature + * @throws org.apache.kafka.common.errors.FencedInstanceIdException if this consumer instance gets fenced by broker. 
+ */ + @Override + public ConsumerRecords poll(final Duration timeout) { + return poll(time.timer(timeout), true); + } + + /** + * @throws KafkaException if the rebalance callback throws exception + */ + private ConsumerRecords poll(final Timer timer, final boolean includeMetadataInTimeout) { + acquireAndEnsureOpen(); + try { + this.kafkaConsumerMetrics.recordPollStart(timer.currentTimeMs()); + + if (this.subscriptions.hasNoSubscriptionOrUserAssignment()) { + throw new IllegalStateException("Consumer is not subscribed to any topics or assigned any partitions"); + } + + do { + client.maybeTriggerWakeup(); + + if (includeMetadataInTimeout) { + // try to update assignment metadata BUT do not need to block on the timer for join group + updateAssignmentMetadataIfNeeded(timer, false); + } else { + while (!updateAssignmentMetadataIfNeeded(time.timer(Long.MAX_VALUE), true)) { + log.warn("Still waiting for metadata"); + } + } + + final Map>> records = pollForFetches(timer); + if (!records.isEmpty()) { + // before returning the fetched records, we can send off the next round of fetches + // and avoid block waiting for their responses to enable pipelining while the user + // is handling the fetched records. + // + // NOTE: since the consumed position has already been updated, we must not allow + // wakeups or any other errors to be triggered prior to returning the fetched records. + if (fetcher.sendFetches() > 0 || client.hasPendingRequests()) { + client.transmitSends(); + } + + return this.interceptors.onConsume(new ConsumerRecords<>(records)); + } + } while (timer.notExpired()); + + return ConsumerRecords.empty(); + } finally { + release(); + this.kafkaConsumerMetrics.recordPollEnd(timer.currentTimeMs()); + } + } + + boolean updateAssignmentMetadataIfNeeded(final Timer timer, final boolean waitForJoinGroup) { + if (coordinator != null && !coordinator.poll(timer, waitForJoinGroup)) { + return false; + } + + return updateFetchPositions(timer); + } + + /** + * @throws KafkaException if the rebalance callback throws exception + */ + private Map>> pollForFetches(Timer timer) { + long pollTimeout = coordinator == null ? timer.remainingMs() : + Math.min(coordinator.timeToNextPoll(timer.currentTimeMs()), timer.remainingMs()); + + // if data is available already, return it immediately + final Map>> records = fetcher.fetchedRecords(); + if (!records.isEmpty()) { + return records; + } + + // send any new fetches (won't resend pending fetches) + fetcher.sendFetches(); + + // We do not want to be stuck blocking in poll if we are missing some positions + // since the offset lookup may be backing off after a failure + + // NOTE: the use of cachedSubscriptionHashAllFetchPositions means we MUST call + // updateAssignmentMetadataIfNeeded before this method. + if (!cachedSubscriptionHashAllFetchPositions && pollTimeout > retryBackoffMs) { + pollTimeout = retryBackoffMs; + } + + log.trace("Polling for fetches with timeout {}", pollTimeout); + + Timer pollTimer = time.timer(pollTimeout); + client.poll(pollTimer, () -> { + // since a fetch might be completed by the background thread, we need this poll condition + // to ensure that we do not block unnecessarily in poll() + return !fetcher.hasAvailableFetches(); + }); + timer.update(pollTimer.currentTimeMs()); + + return fetcher.fetchedRecords(); + } + + /** + * Commit offsets returned on the last {@link #poll(Duration) poll()} for all the subscribed list of topics and + * partitions. + *

+ * This commits offsets only to Kafka. The offsets committed using this API will be used on the first fetch after + * every rebalance and also on startup. As such, if you need to store offsets in anything other than Kafka, this API + * should not be used. + *

+ * This is a synchronous commit and will block until either the commit succeeds, an unrecoverable error is + * encountered (in which case it is thrown to the caller), or the timeout specified by {@code default.api.timeout.ms} expires + * (in which case a {@link org.apache.kafka.common.errors.TimeoutException} is thrown to the caller). + *
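+ * <p>
+ * For illustration only ({@code process(record)} stands in for application logic), committing once a batch has
+ * been fully processed might look like:
+ * <pre>{@code
+ * ConsumerRecords<String, String> records = consumer.poll(Duration.ofSeconds(1));
+ * for (ConsumerRecord<String, String> record : records)
+ *     process(record);
+ * consumer.commitSync(); // commits the offsets returned by the last poll()
+ * }</pre>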

+ * Note that asynchronous offset commits sent previously with the {@link #commitAsync(OffsetCommitCallback)} + * (or similar) are guaranteed to have their callbacks invoked prior to completion of this method. + * + * @throws org.apache.kafka.clients.consumer.CommitFailedException if the commit failed and cannot be retried. + * This fatal error can only occur if you are using automatic group management with {@link #subscribe(Collection)}, + * or if there is an active group with the same group.id which is using group management. In such cases, + * when you are trying to commit to partitions that are no longer assigned to this consumer because the + * consumer is for example no longer part of the group this exception would be thrown. + * @throws org.apache.kafka.common.errors.RebalanceInProgressException if the consumer instance is in the middle of a rebalance + * so it is not yet determined which partitions would be assigned to the consumer. In such cases you can first + * complete the rebalance by calling {@link #poll(Duration)} and commit can be reconsidered afterwards. + * NOTE when you reconsider committing after the rebalance, the assigned partitions may have changed, + * and also for those partitions that are still assigned their fetch positions may have changed too + * if more records are returned from the {@link #poll(Duration)} call. + * @throws org.apache.kafka.common.errors.WakeupException if {@link #wakeup()} is called before or while this + * function is called + * @throws org.apache.kafka.common.errors.InterruptException if the calling thread is interrupted before or while + * this function is called + * @throws org.apache.kafka.common.errors.AuthenticationException if authentication fails. See the exception for more details + * @throws org.apache.kafka.common.errors.AuthorizationException if not authorized to the topic or to the + * configured groupId. See the exception for more details + * @throws org.apache.kafka.common.KafkaException for any other unrecoverable errors (e.g. if offset metadata + * is too large or if the topic does not exist). + * @throws org.apache.kafka.common.errors.TimeoutException if the timeout specified by {@code default.api.timeout.ms} expires + * before successful completion of the offset commit + * @throws org.apache.kafka.common.errors.FencedInstanceIdException if this consumer instance gets fenced by broker. + */ + @Override + public void commitSync() { + commitSync(Duration.ofMillis(defaultApiTimeoutMs)); + } + + /** + * Commit offsets returned on the last {@link #poll(Duration) poll()} for all the subscribed list of topics and + * partitions. + *

+ * This commits offsets only to Kafka. The offsets committed using this API will be used on the first fetch after + * every rebalance and also on startup. As such, if you need to store offsets in anything other than Kafka, this API + * should not be used. + *

+ * This is a synchronous commit and will block until either the commit succeeds, an unrecoverable error is + * encountered (in which case it is thrown to the caller), or the passed timeout expires. + *

+ * Note that asynchronous offset commits sent previously with the {@link #commitAsync(OffsetCommitCallback)} + * (or similar) are guaranteed to have their callbacks invoked prior to completion of this method. + * + * @throws org.apache.kafka.clients.consumer.CommitFailedException if the commit failed and cannot be retried. + * This can only occur if you are using automatic group management with {@link #subscribe(Collection)}, + * or if there is an active group with the same group.id which is using group management. In such cases, + * when you are trying to commit to partitions that are no longer assigned to this consumer because the + * consumer is for example no longer part of the group this exception would be thrown. + * @throws org.apache.kafka.common.errors.RebalanceInProgressException if the consumer instance is in the middle of a rebalance + * so it is not yet determined which partitions would be assigned to the consumer. In such cases you can first + * complete the rebalance by calling {@link #poll(Duration)} and commit can be reconsidered afterwards. + * NOTE when you reconsider committing after the rebalance, the assigned partitions may have changed, + * and also for those partitions that are still assigned their fetch positions may have changed too + * if more records are returned from the {@link #poll(Duration)} call. + * @throws org.apache.kafka.common.errors.WakeupException if {@link #wakeup()} is called before or while this + * function is called + * @throws org.apache.kafka.common.errors.InterruptException if the calling thread is interrupted before or while + * this function is called + * @throws org.apache.kafka.common.errors.AuthenticationException if authentication fails. See the exception for more details + * @throws org.apache.kafka.common.errors.AuthorizationException if not authorized to the topic or to the + * configured groupId. See the exception for more details + * @throws org.apache.kafka.common.KafkaException for any other unrecoverable errors (e.g. if offset metadata + * is too large or if the topic does not exist). + * @throws org.apache.kafka.common.errors.TimeoutException if the timeout expires before successful completion + * of the offset commit + * @throws org.apache.kafka.common.errors.FencedInstanceIdException if this consumer instance gets fenced by broker. + */ + @Override + public void commitSync(Duration timeout) { + commitSync(subscriptions.allConsumed(), timeout); + } + + /** + * Commit the specified offsets for the specified list of topics and partitions. + *

+ * This commits offsets to Kafka. The offsets committed using this API will be used on the first fetch after every + * rebalance and also on startup. As such, if you need to store offsets in anything other than Kafka, this API + * should not be used. The committed offset should be the next message your application will consume, + * i.e. lastProcessedMessageOffset + 1. If automatic group management with {@link #subscribe(Collection)} is used, + * then the committed offsets must belong to the currently auto-assigned partitions. + *

+ * This is a synchronous commit and will block until either the commit succeeds or an unrecoverable error is + * encountered (in which case it is thrown to the caller), or the timeout specified by {@code default.api.timeout.ms} expires + * (in which case a {@link org.apache.kafka.common.errors.TimeoutException} is thrown to the caller). + *
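+ * <p>
+ * For illustration only, committing the position after a single record has been processed ({@code record} is
+ * assumed to be the last fully processed {@code ConsumerRecord}); note the {@code + 1}, since the committed
+ * offset should be the next message the application will consume:
+ * <pre>{@code
+ * consumer.commitSync(Collections.singletonMap(
+ *     new TopicPartition(record.topic(), record.partition()),
+ *     new OffsetAndMetadata(record.offset() + 1)));
+ * }</pre>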

+ * Note that asynchronous offset commits sent previously with the {@link #commitAsync(OffsetCommitCallback)} + * (or similar) are guaranteed to have their callbacks invoked prior to completion of this method. + * + * @param offsets A map of offsets by partition with associated metadata + * @throws org.apache.kafka.clients.consumer.CommitFailedException if the commit failed and cannot be retried. + * This can only occur if you are using automatic group management with {@link #subscribe(Collection)}, + * or if there is an active group with the same group.id which is using group management. In such cases, + * when you are trying to commit to partitions that are no longer assigned to this consumer because the + * consumer is for example no longer part of the group this exception would be thrown. + * @throws org.apache.kafka.common.errors.RebalanceInProgressException if the consumer instance is in the middle of a rebalance + * so it is not yet determined which partitions would be assigned to the consumer. In such cases you can first + * complete the rebalance by calling {@link #poll(Duration)} and commit can be reconsidered afterwards. + * NOTE when you reconsider committing after the rebalance, the assigned partitions may have changed, + * and also for those partitions that are still assigned their fetch positions may have changed too + * if more records are returned from the {@link #poll(Duration)} call, so when you retry committing + * you should consider updating the passed in {@code offset} parameter. + * @throws org.apache.kafka.common.errors.WakeupException if {@link #wakeup()} is called before or while this + * function is called + * @throws org.apache.kafka.common.errors.InterruptException if the calling thread is interrupted before or while + * this function is called + * @throws org.apache.kafka.common.errors.AuthenticationException if authentication fails. See the exception for more details + * @throws org.apache.kafka.common.errors.AuthorizationException if not authorized to the topic or to the + * configured groupId. See the exception for more details + * @throws java.lang.IllegalArgumentException if the committed offset is negative + * @throws org.apache.kafka.common.KafkaException for any other unrecoverable errors (e.g. if offset metadata + * is too large or if the topic does not exist). + * @throws org.apache.kafka.common.errors.TimeoutException if the timeout expires before successful completion + * of the offset commit + * @throws org.apache.kafka.common.errors.FencedInstanceIdException if this consumer instance gets fenced by broker. + */ + @Override + public void commitSync(final Map offsets) { + commitSync(offsets, Duration.ofMillis(defaultApiTimeoutMs)); + } + + /** + * Commit the specified offsets for the specified list of topics and partitions. + *

+ * This commits offsets to Kafka. The offsets committed using this API will be used on the first fetch after every + * rebalance and also on startup. As such, if you need to store offsets in anything other than Kafka, this API + * should not be used. The committed offset should be the next message your application will consume, + * i.e. lastProcessedMessageOffset + 1. If automatic group management with {@link #subscribe(Collection)} is used, + * then the committed offsets must belong to the currently auto-assigned partitions. + *

+ * This is a synchronous commit and will block until either the commit succeeds, an unrecoverable error is + * encountered (in which case it is thrown to the caller), or the timeout expires. + *

+ * Note that asynchronous offset commits sent previously with the {@link #commitAsync(OffsetCommitCallback)} + * (or similar) are guaranteed to have their callbacks invoked prior to completion of this method. + * + * @param offsets A map of offsets by partition with associated metadata + * @param timeout The maximum amount of time to await completion of the offset commit + * @throws org.apache.kafka.clients.consumer.CommitFailedException if the commit failed and cannot be retried. + * This can only occur if you are using automatic group management with {@link #subscribe(Collection)}, + * or if there is an active group with the same group.id which is using group management. In such cases, + * when you are trying to commit to partitions that are no longer assigned to this consumer because the + * consumer is for example no longer part of the group this exception would be thrown. + * @throws org.apache.kafka.common.errors.RebalanceInProgressException if the consumer instance is in the middle of a rebalance + * so it is not yet determined which partitions would be assigned to the consumer. In such cases you can first + * complete the rebalance by calling {@link #poll(Duration)} and commit can be reconsidered afterwards. + * NOTE when you reconsider committing after the rebalance, the assigned partitions may have changed, + * and also for those partitions that are still assigned their fetch positions may have changed too + * if more records are returned from the {@link #poll(Duration)} call, so when you retry committing + * you should consider updating the passed in {@code offset} parameter. + * @throws org.apache.kafka.common.errors.WakeupException if {@link #wakeup()} is called before or while this + * function is called + * @throws org.apache.kafka.common.errors.InterruptException if the calling thread is interrupted before or while + * this function is called + * @throws org.apache.kafka.common.errors.AuthenticationException if authentication fails. See the exception for more details + * @throws org.apache.kafka.common.errors.AuthorizationException if not authorized to the topic or to the + * configured groupId. See the exception for more details + * @throws java.lang.IllegalArgumentException if the committed offset is negative + * @throws org.apache.kafka.common.KafkaException for any other unrecoverable errors (e.g. if offset metadata + * is too large or if the topic does not exist). + * @throws org.apache.kafka.common.errors.TimeoutException if the timeout expires before successful completion + * of the offset commit + * @throws org.apache.kafka.common.errors.FencedInstanceIdException if this consumer instance gets fenced by broker. + */ + @Override + public void commitSync(final Map offsets, final Duration timeout) { + acquireAndEnsureOpen(); + long commitStart = time.nanoseconds(); + try { + maybeThrowInvalidGroupIdException(); + offsets.forEach(this::updateLastSeenEpochIfNewer); + if (!coordinator.commitOffsetsSync(new HashMap<>(offsets), time.timer(timeout))) { + throw new TimeoutException("Timeout of " + timeout.toMillis() + "ms expired before successfully " + + "committing offsets " + offsets); + } + } finally { + kafkaConsumerMetrics.recordCommitSync(time.nanoseconds() - commitStart); + release(); + } + } + + /** + * Commit offsets returned on the last {@link #poll(Duration)} for all the subscribed list of topics and partition. 
+ * Same as {@link #commitAsync(OffsetCommitCallback) commitAsync(null)} + * @throws org.apache.kafka.common.errors.FencedInstanceIdException if this consumer instance gets fenced by broker. + */ + @Override + public void commitAsync() { + commitAsync(null); + } + + /** + * Commit offsets returned on the last {@link #poll(Duration) poll()} for the subscribed list of topics and partitions. + *

+ * This commits offsets only to Kafka. The offsets committed using this API will be used on the first fetch after + * every rebalance and also on startup. As such, if you need to store offsets in anything other than Kafka, this API + * should not be used. + *

+ * This is an asynchronous call and will not block. Any errors encountered are either passed to the callback + * (if provided) or discarded. + *
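+ * <p>
+ * For illustration only, a callback that simply logs failed commits ({@code log} is an assumed application
+ * logger):
+ * <pre>{@code
+ * consumer.commitAsync((offsets, exception) -> {
+ *     if (exception != null)
+ *         log.error("Commit of offsets {} failed", offsets, exception);
+ * });
+ * }</pre>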

+ * Offsets committed through multiple calls to this API are guaranteed to be sent in the same order as + * the invocations. Corresponding commit callbacks are also invoked in the same order. Additionally note that + * offsets committed through this API are guaranteed to complete before a subsequent call to {@link #commitSync()} + * (and variants) returns. + * + * @param callback Callback to invoke when the commit completes + * @throws org.apache.kafka.common.errors.FencedInstanceIdException if this consumer instance gets fenced by broker. + */ + @Override + public void commitAsync(OffsetCommitCallback callback) { + commitAsync(subscriptions.allConsumed(), callback); + } + + /** + * Commit the specified offsets for the specified list of topics and partitions to Kafka. + *

+ * This commits offsets to Kafka. The offsets committed using this API will be used on the first fetch after every + * rebalance and also on startup. As such, if you need to store offsets in anything other than Kafka, this API + * should not be used. The committed offset should be the next message your application will consume, + * i.e. lastProcessedMessageOffset + 1. If automatic group management with {@link #subscribe(Collection)} is used, + * then the committed offsets must belong to the currently auto-assigned partitions. + *

+ * This is an asynchronous call and will not block. Any errors encountered are either passed to the callback + * (if provided) or discarded. + *

+ * Offsets committed through multiple calls to this API are guaranteed to be sent in the same order as + * the invocations. Corresponding commit callbacks are also invoked in the same order. Additionally note that + * offsets committed through this API are guaranteed to complete before a subsequent call to {@link #commitSync()} + * (and variants) returns. + * + * @param offsets A map of offsets by partition with associate metadata. This map will be copied internally, so it + * is safe to mutate the map after returning. + * @param callback Callback to invoke when the commit completes + * @throws org.apache.kafka.common.errors.FencedInstanceIdException if this consumer instance gets fenced by broker. + */ + @Override + public void commitAsync(final Map offsets, OffsetCommitCallback callback) { + acquireAndEnsureOpen(); + try { + maybeThrowInvalidGroupIdException(); + log.debug("Committing offsets: {}", offsets); + offsets.forEach(this::updateLastSeenEpochIfNewer); + coordinator.commitOffsetsAsync(new HashMap<>(offsets), callback); + } finally { + release(); + } + } + + /** + * Overrides the fetch offsets that the consumer will use on the next {@link #poll(Duration) poll(timeout)}. If this API + * is invoked for the same partition more than once, the latest offset will be used on the next poll(). Note that + * you may lose data if this API is arbitrarily used in the middle of consumption, to reset the fetch offsets + * + * @throws IllegalArgumentException if the provided offset is negative + * @throws IllegalStateException if the provided TopicPartition is not assigned to this consumer + */ + @Override + public void seek(TopicPartition partition, long offset) { + if (offset < 0) + throw new IllegalArgumentException("seek offset must not be a negative number"); + + acquireAndEnsureOpen(); + try { + log.info("Seeking to offset {} for partition {}", offset, partition); + SubscriptionState.FetchPosition newPosition = new SubscriptionState.FetchPosition( + offset, + Optional.empty(), // This will ensure we skip validation + this.metadata.currentLeader(partition)); + this.subscriptions.seekUnvalidated(partition, newPosition); + } finally { + release(); + } + } + + /** + * Overrides the fetch offsets that the consumer will use on the next {@link #poll(Duration) poll(timeout)}. If this API + * is invoked for the same partition more than once, the latest offset will be used on the next poll(). Note that + * you may lose data if this API is arbitrarily used in the middle of consumption, to reset the fetch offsets. This + * method allows for setting the leaderEpoch along with the desired offset. 
+ * + * @throws IllegalArgumentException if the provided offset is negative + * @throws IllegalStateException if the provided TopicPartition is not assigned to this consumer + */ + @Override + public void seek(TopicPartition partition, OffsetAndMetadata offsetAndMetadata) { + long offset = offsetAndMetadata.offset(); + if (offset < 0) { + throw new IllegalArgumentException("seek offset must not be a negative number"); + } + + acquireAndEnsureOpen(); + try { + if (offsetAndMetadata.leaderEpoch().isPresent()) { + log.info("Seeking to offset {} for partition {} with epoch {}", + offset, partition, offsetAndMetadata.leaderEpoch().get()); + } else { + log.info("Seeking to offset {} for partition {}", offset, partition); + } + Metadata.LeaderAndEpoch currentLeaderAndEpoch = this.metadata.currentLeader(partition); + SubscriptionState.FetchPosition newPosition = new SubscriptionState.FetchPosition( + offsetAndMetadata.offset(), + offsetAndMetadata.leaderEpoch(), + currentLeaderAndEpoch); + this.updateLastSeenEpochIfNewer(partition, offsetAndMetadata); + this.subscriptions.seekUnvalidated(partition, newPosition); + } finally { + release(); + } + } + + /** + * Seek to the first offset for each of the given partitions. This function evaluates lazily, seeking to the + * first offset in all partitions only when {@link #poll(Duration)} or {@link #position(TopicPartition)} are called. + * If no partitions are provided, seek to the first offset for all of the currently assigned partitions. + * + * @throws IllegalArgumentException if {@code partitions} is {@code null} + * @throws IllegalStateException if any of the provided partitions are not currently assigned to this consumer + */ + @Override + public void seekToBeginning(Collection partitions) { + if (partitions == null) + throw new IllegalArgumentException("Partitions collection cannot be null"); + + acquireAndEnsureOpen(); + try { + Collection parts = partitions.size() == 0 ? this.subscriptions.assignedPartitions() : partitions; + subscriptions.requestOffsetReset(parts, OffsetResetStrategy.EARLIEST); + } finally { + release(); + } + } + + /** + * Seek to the last offset for each of the given partitions. This function evaluates lazily, seeking to the + * final offset in all partitions only when {@link #poll(Duration)} or {@link #position(TopicPartition)} are called. + * If no partitions are provided, seek to the final offset for all of the currently assigned partitions. + *

+ * If {@code isolation.level=read_committed}, the end offset will be the Last Stable Offset, i.e., the offset + * of the first message with an open transaction. + * + * @throws IllegalArgumentException if {@code partitions} is {@code null} + * @throws IllegalStateException if any of the provided partitions are not currently assigned to this consumer + */ + @Override + public void seekToEnd(Collection partitions) { + if (partitions == null) + throw new IllegalArgumentException("Partitions collection cannot be null"); + + acquireAndEnsureOpen(); + try { + Collection parts = partitions.size() == 0 ? this.subscriptions.assignedPartitions() : partitions; + subscriptions.requestOffsetReset(parts, OffsetResetStrategy.LATEST); + } finally { + release(); + } + } + + /** + * Get the offset of the next record that will be fetched (if a record with that offset exists). + * This method may issue a remote call to the server if there is no current position for the given partition. + *

+ * This call will block until either the position could be determined or an unrecoverable error is + * encountered (in which case it is thrown to the caller), or the timeout specified by {@code default.api.timeout.ms} expires + * (in which case a {@link org.apache.kafka.common.errors.TimeoutException} is thrown to the caller). + * + * @param partition The partition to get the position for + * @return The current position of the consumer (that is, the offset of the next record to be fetched) + * @throws IllegalStateException if the provided TopicPartition is not assigned to this consumer + * @throws org.apache.kafka.clients.consumer.InvalidOffsetException if no offset is currently defined for + * the partition + * @throws org.apache.kafka.common.errors.WakeupException if {@link #wakeup()} is called before or while this + * function is called + * @throws org.apache.kafka.common.errors.InterruptException if the calling thread is interrupted before or while + * this function is called + * @throws org.apache.kafka.common.errors.AuthenticationException if authentication fails. See the exception for more details + * @throws org.apache.kafka.common.errors.AuthorizationException if not authorized to the topic or to the + * configured groupId. See the exception for more details + * @throws org.apache.kafka.common.errors.UnsupportedVersionException if the consumer attempts to fetch stable offsets + * when the broker doesn't support this feature + * @throws org.apache.kafka.common.KafkaException for any other unrecoverable errors + * @throws org.apache.kafka.common.errors.TimeoutException if the position cannot be determined before the + * timeout specified by {@code default.api.timeout.ms} expires + */ + @Override + public long position(TopicPartition partition) { + return position(partition, Duration.ofMillis(defaultApiTimeoutMs)); + } + + /** + * Get the offset of the next record that will be fetched (if a record with that offset exists). + * This method may issue a remote call to the server if there is no current position + * for the given partition. + *
+ * <p>
+ * This call will block until the position can be determined, an unrecoverable error is + * encountered (in which case it is thrown to the caller), or the timeout expires. + * + * @param partition The partition to get the position for + * @param timeout The maximum amount of time to await determination of the current position + * @return The current position of the consumer (that is, the offset of the next record to be fetched) + * @throws IllegalStateException if the provided TopicPartition is not assigned to this consumer + * @throws org.apache.kafka.clients.consumer.InvalidOffsetException if no offset is currently defined for + * the partition + * @throws org.apache.kafka.common.errors.WakeupException if {@link #wakeup()} is called before or while this + * function is called + * @throws org.apache.kafka.common.errors.InterruptException if the calling thread is interrupted before or while + * this function is called + * @throws org.apache.kafka.common.errors.TimeoutException if the position cannot be determined before the + * passed timeout expires + * @throws org.apache.kafka.common.errors.AuthenticationException if authentication fails. See the exception for more details + * @throws org.apache.kafka.common.errors.AuthorizationException if not authorized to the topic or to the + * configured groupId. See the exception for more details + * @throws org.apache.kafka.common.KafkaException for any other unrecoverable errors + */ + @Override + public long position(TopicPartition partition, final Duration timeout) { + acquireAndEnsureOpen(); + try { + if (!this.subscriptions.isAssigned(partition)) + throw new IllegalStateException("You can only check the position for partitions assigned to this consumer."); + + Timer timer = time.timer(timeout); + do { + SubscriptionState.FetchPosition position = this.subscriptions.validPosition(partition); + if (position != null) + return position.offset; + + updateFetchPositions(timer); + client.poll(timer); + } while (timer.notExpired()); + + throw new TimeoutException("Timeout of " + timeout.toMillis() + "ms expired before the position " + + "for partition " + partition + " could be determined"); + } finally { + release(); + } + } + + /** + * Get the last committed offset for the given partition (whether the commit happened by this process or + * another). This offset will be used as the position for the consumer in the event of a failure. + *
+ * <p>
+ * This call will do a remote call to get the latest committed offset from the server, and will block until the + * committed offset is gotten successfully, an unrecoverable error is encountered (in which case it is thrown to + * the caller), or the timeout specified by {@code default.api.timeout.ms} expires (in which case a + * {@link org.apache.kafka.common.errors.TimeoutException} is thrown to the caller). + * + * @param partition The partition to check + * @return The last committed offset and metadata or null if there was no prior commit + * @throws org.apache.kafka.common.errors.WakeupException if {@link #wakeup()} is called before or while this + * function is called + * @throws org.apache.kafka.common.errors.InterruptException if the calling thread is interrupted before or while + * this function is called + * @throws org.apache.kafka.common.errors.AuthenticationException if authentication fails. See the exception for more details + * @throws org.apache.kafka.common.errors.AuthorizationException if not authorized to the topic or to the + * configured groupId. See the exception for more details + * @throws org.apache.kafka.common.KafkaException for any other unrecoverable errors + * @throws org.apache.kafka.common.errors.TimeoutException if the committed offset cannot be found before + * the timeout specified by {@code default.api.timeout.ms} expires. + * + * @deprecated since 2.4 Use {@link #committed(Set)} instead + */ + @Deprecated + @Override + public OffsetAndMetadata committed(TopicPartition partition) { + return committed(partition, Duration.ofMillis(defaultApiTimeoutMs)); + } + + /** + * Get the last committed offset for the given partition (whether the commit happened by this process or + * another). This offset will be used as the position for the consumer in the event of a failure. + *
+ * <p>
+ * This call will block until the position can be determined, an unrecoverable error is + * encountered (in which case it is thrown to the caller), or the timeout expires. + * + * @param partition The partition to check + * @param timeout The maximum amount of time to await the current committed offset + * @return The last committed offset and metadata or null if there was no prior commit + * @throws org.apache.kafka.common.errors.WakeupException if {@link #wakeup()} is called before or while this + * function is called + * @throws org.apache.kafka.common.errors.InterruptException if the calling thread is interrupted before or while + * this function is called + * @throws org.apache.kafka.common.errors.AuthenticationException if authentication fails. See the exception for more details + * @throws org.apache.kafka.common.errors.AuthorizationException if not authorized to the topic or to the + * configured groupId. See the exception for more details + * @throws org.apache.kafka.common.KafkaException for any other unrecoverable errors + * @throws org.apache.kafka.common.errors.TimeoutException if the committed offset cannot be found before + * expiration of the timeout + * + * @deprecated since 2.4 Use {@link #committed(Set, Duration)} instead + */ + @Deprecated + @Override + public OffsetAndMetadata committed(TopicPartition partition, final Duration timeout) { + return committed(Collections.singleton(partition), timeout).get(partition); + } + + /** + * Get the last committed offsets for the given partitions (whether the commit happened by this process or + * another). The returned offsets will be used as the position for the consumer in the event of a failure. + *
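+ * <p>
+ * For example, a hedged sketch (assuming an assigned consumer named {@code consumer}) that rewinds each
+ * partition to its last committed offset, falling back to the beginning when nothing was committed:
+ * <pre> {@code
+ * Set<TopicPartition> assignment = consumer.assignment();
+ * Map<TopicPartition, OffsetAndMetadata> committed = consumer.committed(assignment);
+ * for (TopicPartition tp : assignment) {
+ *     OffsetAndMetadata offset = committed.get(tp);
+ *     if (offset != null)
+ *         consumer.seek(tp, offset.offset());
+ *     else
+ *         consumer.seekToBeginning(Collections.singleton(tp));
+ * }
+ * }</pre>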
+ * <p>
+ * If any of the partitions requested do not exist, an exception will be thrown. + *
+ * <p>
+ * This call will do a remote call to get the latest committed offsets from the server, and will block until the + * committed offsets are gotten successfully, an unrecoverable error is encountered (in which case it is thrown to + * the caller), or the timeout specified by {@code default.api.timeout.ms} expires (in which case a + * {@link org.apache.kafka.common.errors.TimeoutException} is thrown to the caller). + * + * @param partitions The partitions to check + * @return The latest committed offsets for the given partitions; {@code null} will be returned for the + * partition if there is no such message. + * @throws org.apache.kafka.common.errors.WakeupException if {@link #wakeup()} is called before or while this + * function is called + * @throws org.apache.kafka.common.errors.InterruptException if the calling thread is interrupted before or while + * this function is called + * @throws org.apache.kafka.common.errors.AuthenticationException if authentication fails. See the exception for more details + * @throws org.apache.kafka.common.errors.AuthorizationException if not authorized to the topic or to the + * configured groupId. See the exception for more details + * @throws org.apache.kafka.common.errors.UnsupportedVersionException if the consumer attempts to fetch stable offsets + * when the broker doesn't support this feature + * @throws org.apache.kafka.common.KafkaException for any other unrecoverable errors + * @throws org.apache.kafka.common.errors.TimeoutException if the committed offset cannot be found before + * the timeout specified by {@code default.api.timeout.ms} expires. + */ + @Override + public Map committed(final Set partitions) { + return committed(partitions, Duration.ofMillis(defaultApiTimeoutMs)); + } + + /** + * Get the last committed offsets for the given partitions (whether the commit happened by this process or + * another). The returned offsets will be used as the position for the consumer in the event of a failure. + *
+ * <p>
+ * If any of the partitions requested do not exist, an exception will be thrown. + *
+ * <p>
+ * This call will block to do a remote call to get the latest committed offsets from the server. + * + * @param partitions The partitions to check + * @param timeout The maximum amount of time to await the latest committed offsets + * @return The latest committed offsets for the given partitions; {@code null} will be returned for the + * partition if there is no such message. + * @throws org.apache.kafka.common.errors.WakeupException if {@link #wakeup()} is called before or while this + * function is called + * @throws org.apache.kafka.common.errors.InterruptException if the calling thread is interrupted before or while + * this function is called + * @throws org.apache.kafka.common.errors.AuthenticationException if authentication fails. See the exception for more details + * @throws org.apache.kafka.common.errors.AuthorizationException if not authorized to the topic or to the + * configured groupId. See the exception for more details + * @throws org.apache.kafka.common.KafkaException for any other unrecoverable errors + * @throws org.apache.kafka.common.errors.TimeoutException if the committed offset cannot be found before + * expiration of the timeout + */ + @Override + public Map committed(final Set partitions, final Duration timeout) { + acquireAndEnsureOpen(); + long start = time.nanoseconds(); + try { + maybeThrowInvalidGroupIdException(); + final Map offsets; + offsets = coordinator.fetchCommittedOffsets(partitions, time.timer(timeout)); + if (offsets == null) { + throw new TimeoutException("Timeout of " + timeout.toMillis() + "ms expired before the last " + + "committed offset for partitions " + partitions + " could be determined. Try tuning default.api.timeout.ms " + + "larger to relax the threshold."); + } else { + offsets.forEach(this::updateLastSeenEpochIfNewer); + return offsets; + } + } finally { + kafkaConsumerMetrics.recordCommitted(time.nanoseconds() - start); + release(); + } + } + + /** + * Get the metrics kept by the consumer + */ + @Override + public Map metrics() { + return Collections.unmodifiableMap(this.metrics.metrics()); + } + + /** + * Get metadata about the partitions for a given topic. This method will issue a remote call to the server if it + * does not already have any metadata about the given topic. + * + * @param topic The topic to get partition metadata for + * + * @return The list of partitions, which will be empty when the given topic is not found + * @throws org.apache.kafka.common.errors.WakeupException if {@link #wakeup()} is called before or while this + * function is called + * @throws org.apache.kafka.common.errors.InterruptException if the calling thread is interrupted before or while + * this function is called + * @throws org.apache.kafka.common.errors.AuthenticationException if authentication fails. See the exception for more details + * @throws org.apache.kafka.common.errors.AuthorizationException if not authorized to the specified topic. See the exception for more details + * @throws org.apache.kafka.common.KafkaException for any other unrecoverable errors + * @throws org.apache.kafka.common.errors.TimeoutException if the offset metadata could not be fetched before + * the amount of time allocated by {@code default.api.timeout.ms} expires. + */ + @Override + public List partitionsFor(String topic) { + return partitionsFor(topic, Duration.ofMillis(defaultApiTimeoutMs)); + } + + /** + * Get metadata about the partitions for a given topic. 
This method will issue a remote call to the server if it + * does not already have any metadata about the given topic. + * + * @param topic The topic to get partition metadata for + * @param timeout The maximum of time to await topic metadata + * + * @return The list of partitions, which will be empty when the given topic is not found + * @throws org.apache.kafka.common.errors.WakeupException if {@link #wakeup()} is called before or while this + * function is called + * @throws org.apache.kafka.common.errors.InterruptException if the calling thread is interrupted before or while + * this function is called + * @throws org.apache.kafka.common.errors.AuthenticationException if authentication fails. See the exception for more details + * @throws org.apache.kafka.common.errors.AuthorizationException if not authorized to the specified topic. See + * the exception for more details + * @throws org.apache.kafka.common.errors.TimeoutException if topic metadata cannot be fetched before expiration + * of the passed timeout + * @throws org.apache.kafka.common.KafkaException for any other unrecoverable errors + */ + @Override + public List partitionsFor(String topic, Duration timeout) { + acquireAndEnsureOpen(); + try { + Cluster cluster = this.metadata.fetch(); + List parts = cluster.partitionsForTopic(topic); + if (!parts.isEmpty()) + return parts; + + Timer timer = time.timer(timeout); + Map> topicMetadata = fetcher.getTopicMetadata( + new MetadataRequest.Builder(Collections.singletonList(topic), metadata.allowAutoTopicCreation()), timer); + return topicMetadata.getOrDefault(topic, Collections.emptyList()); + } finally { + release(); + } + } + + /** + * Get metadata about partitions for all topics that the user is authorized to view. This method will issue a + * remote call to the server. + + * @return The map of topics and its partitions + * + * @throws org.apache.kafka.common.errors.WakeupException if {@link #wakeup()} is called before or while this + * function is called + * @throws org.apache.kafka.common.errors.InterruptException if the calling thread is interrupted before or while + * this function is called + * @throws org.apache.kafka.common.KafkaException for any other unrecoverable errors + * @throws org.apache.kafka.common.errors.TimeoutException if the offset metadata could not be fetched before + * the amount of time allocated by {@code default.api.timeout.ms} expires. + */ + @Override + public Map> listTopics() { + return listTopics(Duration.ofMillis(defaultApiTimeoutMs)); + } + + /** + * Get metadata about partitions for all topics that the user is authorized to view. This method will issue a + * remote call to the server. 
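+ * <p>
+ * For instance (a sketch only; the ten-second bound is an arbitrary choice):
+ * <pre> {@code
+ * Map<String, List<PartitionInfo>> topics = consumer.listTopics(Duration.ofSeconds(10));
+ * topics.forEach((topic, partitions) ->
+ *     System.out.println(topic + " has " + partitions.size() + " partitions"));
+ * }</pre>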
+ * + * @param timeout The maximum time this operation will block to fetch topic metadata + * + * @return The map of topics and its partitions + * @throws org.apache.kafka.common.errors.WakeupException if {@link #wakeup()} is called before or while this + * function is called + * @throws org.apache.kafka.common.errors.InterruptException if the calling thread is interrupted before or while + * this function is called + * @throws org.apache.kafka.common.errors.TimeoutException if the topic metadata could not be fetched before + * expiration of the passed timeout + * @throws org.apache.kafka.common.KafkaException for any other unrecoverable errors + */ + @Override + public Map> listTopics(Duration timeout) { + acquireAndEnsureOpen(); + try { + return fetcher.getAllTopicMetadata(time.timer(timeout)); + } finally { + release(); + } + } + + /** + * Suspend fetching from the requested partitions. Future calls to {@link #poll(Duration)} will not return + * any records from these partitions until they have been resumed using {@link #resume(Collection)}. + * Note that this method does not affect partition subscription. In particular, it does not cause a group + * rebalance when automatic assignment is used. + * @param partitions The partitions which should be paused + * @throws IllegalStateException if any of the provided partitions are not currently assigned to this consumer + */ + @Override + public void pause(Collection partitions) { + acquireAndEnsureOpen(); + try { + log.debug("Pausing partitions {}", partitions); + for (TopicPartition partition: partitions) { + subscriptions.pause(partition); + } + } finally { + release(); + } + } + + /** + * Resume specified partitions which have been paused with {@link #pause(Collection)}. New calls to + * {@link #poll(Duration)} will return records from these partitions if there are any to be fetched. + * If the partitions were not previously paused, this method is a no-op. + * @param partitions The partitions which should be resumed + * @throws IllegalStateException if any of the provided partitions are not currently assigned to this consumer + */ + @Override + public void resume(Collection partitions) { + acquireAndEnsureOpen(); + try { + log.debug("Resuming partitions {}", partitions); + for (TopicPartition partition: partitions) { + subscriptions.resume(partition); + } + } finally { + release(); + } + } + + /** + * Get the set of partitions that were previously paused by a call to {@link #pause(Collection)}. + * + * @return The set of paused partitions + */ + @Override + public Set paused() { + acquireAndEnsureOpen(); + try { + return Collections.unmodifiableSet(subscriptions.pausedPartitions()); + } finally { + release(); + } + } + + /** + * Look up the offsets for the given partitions by timestamp. The returned offset for each partition is the + * earliest offset whose timestamp is greater than or equal to the given timestamp in the corresponding partition. + * + * This is a blocking call. The consumer does not have to be assigned the partitions. + * If the message format version in a partition is before 0.10.0, i.e. the messages do not have timestamps, null + * will be returned for that partition. + * + * @param timestampsToSearch the mapping from partition to the timestamp to look up. + * + * @return a mapping from partition to the timestamp and offset of the first message with timestamp greater + * than or equal to the target timestamp. {@code null} will be returned for the partition if there is no + * such message. 
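+ * <p>
+ * For illustration, a rough sketch (assuming an assigned consumer named {@code consumer}; the one-hour
+ * cutoff is only an example) that seeks each partition to the first offset at or after a timestamp:
+ * <pre> {@code
+ * long oneHourAgo = System.currentTimeMillis() - Duration.ofHours(1).toMillis();
+ * Map<TopicPartition, Long> query = new HashMap<>();
+ * for (TopicPartition tp : consumer.assignment())
+ *     query.put(tp, oneHourAgo);
+ * for (Map.Entry<TopicPartition, OffsetAndTimestamp> entry : consumer.offsetsForTimes(query).entrySet()) {
+ *     if (entry.getValue() != null)
+ *         consumer.seek(entry.getKey(), entry.getValue().offset());
+ * }
+ * }</pre>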
+ * @throws org.apache.kafka.common.errors.AuthenticationException if authentication fails. See the exception for more details + * @throws org.apache.kafka.common.errors.AuthorizationException if not authorized to the topic(s). See the exception for more details + * @throws IllegalArgumentException if the target timestamp is negative + * @throws org.apache.kafka.common.errors.TimeoutException if the offset metadata could not be fetched before + * the amount of time allocated by {@code default.api.timeout.ms} expires. + * @throws org.apache.kafka.common.errors.UnsupportedVersionException if the broker does not support looking up + * the offsets by timestamp + */ + @Override + public Map offsetsForTimes(Map timestampsToSearch) { + return offsetsForTimes(timestampsToSearch, Duration.ofMillis(defaultApiTimeoutMs)); + } + + /** + * Look up the offsets for the given partitions by timestamp. The returned offset for each partition is the + * earliest offset whose timestamp is greater than or equal to the given timestamp in the corresponding partition. + * + * This is a blocking call. The consumer does not have to be assigned the partitions. + * If the message format version in a partition is before 0.10.0, i.e. the messages do not have timestamps, null + * will be returned for that partition. + * + * @param timestampsToSearch the mapping from partition to the timestamp to look up. + * @param timeout The maximum amount of time to await retrieval of the offsets + * + * @return a mapping from partition to the timestamp and offset of the first message with timestamp greater + * than or equal to the target timestamp. {@code null} will be returned for the partition if there is no + * such message. + * @throws org.apache.kafka.common.errors.AuthenticationException if authentication fails. See the exception for more details + * @throws org.apache.kafka.common.errors.AuthorizationException if not authorized to the topic(s). See the exception for more details + * @throws IllegalArgumentException if the target timestamp is negative + * @throws org.apache.kafka.common.errors.TimeoutException if the offset metadata could not be fetched before + * expiration of the passed timeout + * @throws org.apache.kafka.common.errors.UnsupportedVersionException if the broker does not support looking up + * the offsets by timestamp + */ + @Override + public Map offsetsForTimes(Map timestampsToSearch, Duration timeout) { + acquireAndEnsureOpen(); + try { + for (Map.Entry entry : timestampsToSearch.entrySet()) { + // we explicitly exclude the earliest and latest offset here so the timestamp in the returned + // OffsetAndTimestamp is always positive. + if (entry.getValue() < 0) + throw new IllegalArgumentException("The target time for partition " + entry.getKey() + " is " + + entry.getValue() + ". The target time cannot be negative."); + } + return fetcher.offsetsForTimes(timestampsToSearch, time.timer(timeout)); + } finally { + release(); + } + } + + /** + * Get the first offset for the given partitions. + *
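+ * <p>
+ * A minimal sketch (assuming an assigned consumer named {@code consumer}):
+ * <pre> {@code
+ * Map<TopicPartition, Long> earliest = consumer.beginningOffsets(consumer.assignment());
+ * earliest.forEach((tp, offset) -> System.out.println(tp + " starts at offset " + offset));
+ * }</pre>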
+ * <p>
+ * This method does not change the current consumer position of the partitions. + * + * @see #seekToBeginning(Collection) + * + * @param partitions the partitions to get the earliest offsets. + * @return The earliest available offsets for the given partitions + * @throws org.apache.kafka.common.errors.AuthenticationException if authentication fails. See the exception for more details + * @throws org.apache.kafka.common.errors.AuthorizationException if not authorized to the topic(s). See the exception for more details + * @throws org.apache.kafka.common.errors.TimeoutException if the offset metadata could not be fetched before + * expiration of the configured {@code default.api.timeout.ms} + */ + @Override + public Map beginningOffsets(Collection partitions) { + return beginningOffsets(partitions, Duration.ofMillis(defaultApiTimeoutMs)); + } + + /** + * Get the first offset for the given partitions. + *
+ * <p>
+ * This method does not change the current consumer position of the partitions. + * + * @see #seekToBeginning(Collection) + * + * @param partitions the partitions to get the earliest offsets + * @param timeout The maximum amount of time to await retrieval of the beginning offsets + * + * @return The earliest available offsets for the given partitions + * @throws org.apache.kafka.common.errors.AuthenticationException if authentication fails. See the exception for more details + * @throws org.apache.kafka.common.errors.AuthorizationException if not authorized to the topic(s). See the exception for more details + * @throws org.apache.kafka.common.errors.TimeoutException if the offset metadata could not be fetched before + * expiration of the passed timeout + */ + @Override + public Map beginningOffsets(Collection partitions, Duration timeout) { + acquireAndEnsureOpen(); + try { + return fetcher.beginningOffsets(partitions, time.timer(timeout)); + } finally { + release(); + } + } + + /** + * Get the end offsets for the given partitions. In the default {@code read_uncommitted} isolation level, the end + * offset is the high watermark (that is, the offset of the last successfully replicated message plus one). For + * {@code read_committed} consumers, the end offset is the last stable offset (LSO), which is the minimum of + * the high watermark and the smallest offset of any open transaction. Finally, if the partition has never been + * written to, the end offset is 0. + * + *
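+ * <p>
+ * For example, a hedged sketch (assuming an assigned consumer named {@code consumer}) that estimates how
+ * many records remain to be read on each partition:
+ * <pre> {@code
+ * Map<TopicPartition, Long> end = consumer.endOffsets(consumer.assignment());
+ * for (Map.Entry<TopicPartition, Long> entry : end.entrySet()) {
+ *     long remaining = entry.getValue() - consumer.position(entry.getKey());
+ *     System.out.println(entry.getKey() + " has roughly " + remaining + " records left");
+ * }
+ * }</pre>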
+ * <p>
+ * This method does not change the current consumer position of the partitions. + * + * @see #seekToEnd(Collection) + * + * @param partitions the partitions to get the end offsets. + * @return The end offsets for the given partitions. + * @throws org.apache.kafka.common.errors.AuthenticationException if authentication fails. See the exception for more details + * @throws org.apache.kafka.common.errors.AuthorizationException if not authorized to the topic(s). See the exception for more details + * @throws org.apache.kafka.common.errors.TimeoutException if the offset metadata could not be fetched before + * the amount of time allocated by {@code request.timeout.ms} expires + */ + @Override + public Map endOffsets(Collection partitions) { + return endOffsets(partitions, Duration.ofMillis(requestTimeoutMs)); + } + + /** + * Get the end offsets for the given partitions. In the default {@code read_uncommitted} isolation level, the end + * offset is the high watermark (that is, the offset of the last successfully replicated message plus one). For + * {@code read_committed} consumers, the end offset is the last stable offset (LSO), which is the minimum of + * the high watermark and the smallest offset of any open transaction. Finally, if the partition has never been + * written to, the end offset is 0. + * + *
+ * <p>
+ * This method does not change the current consumer position of the partitions. + * + * @see #seekToEnd(Collection) + * + * @param partitions the partitions to get the end offsets. + * @param timeout The maximum amount of time to await retrieval of the end offsets + * + * @return The end offsets for the given partitions. + * @throws org.apache.kafka.common.errors.AuthenticationException if authentication fails. See the exception for more details + * @throws org.apache.kafka.common.errors.AuthorizationException if not authorized to the topic(s). See the exception for more details + * @throws org.apache.kafka.common.errors.TimeoutException if the offsets could not be fetched before + * expiration of the passed timeout + */ + @Override + public Map endOffsets(Collection partitions, Duration timeout) { + acquireAndEnsureOpen(); + try { + return fetcher.endOffsets(partitions, time.timer(timeout)); + } finally { + release(); + } + } + + /** + * Get the consumer's current lag on the partition. Returns an "empty" {@link OptionalLong} if the lag is not known, + * for example if there is no position yet, or if the end offset is not known yet. + * + *
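+ * <p>
+ * A minimal sketch (assuming {@code topicPartition} is assigned to a consumer named {@code consumer}):
+ * <pre> {@code
+ * OptionalLong lag = consumer.currentLag(topicPartition);
+ * if (lag.isPresent())
+ *     System.out.println("current lag for " + topicPartition + " is " + lag.getAsLong());
+ * else
+ *     System.out.println("lag not known yet; it may become available after the next poll");
+ * }</pre>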
+ * <p>
+ * This method uses locally cached metadata and never makes a remote call. + * + * @param topicPartition The partition to get the lag for. + * + * @return This {@code Consumer} instance's current lag for the given partition. + * + * @throws IllegalStateException if the {@code topicPartition} is not assigned + **/ + @Override + public OptionalLong currentLag(TopicPartition topicPartition) { + acquireAndEnsureOpen(); + try { + final Long lag = subscriptions.partitionLag(topicPartition, isolationLevel); + + // if the log end offset is not known and hence cannot return lag and there is + // no in-flight list offset requested yet, + // issue a list offset request for that partition so that next time + // we may get the answer; we do not need to wait for the return value + // since we would not try to poll the network client synchronously + if (lag == null) { + if (subscriptions.partitionEndOffset(topicPartition, isolationLevel) == null && + !subscriptions.partitionEndOffsetRequested(topicPartition)) { + log.info("Requesting the log end offset for {} in order to compute lag", topicPartition); + subscriptions.requestPartitionEndOffset(topicPartition); + fetcher.endOffsets(Collections.singleton(topicPartition), time.timer(0L)); + } + + return OptionalLong.empty(); + } + + return OptionalLong.of(lag); + } finally { + release(); + } + } + + /** + * Return the current group metadata associated with this consumer. + * + * @return consumer group metadata + * @throws org.apache.kafka.common.errors.InvalidGroupIdException if consumer does not have a group + */ + @Override + public ConsumerGroupMetadata groupMetadata() { + acquireAndEnsureOpen(); + try { + maybeThrowInvalidGroupIdException(); + return coordinator.groupMetadata(); + } finally { + release(); + } + } + + /** + * Alert the consumer to trigger a new rebalance by rejoining the group. This is a nonblocking call that forces + * the consumer to trigger a new rebalance on the next {@link #poll(Duration)} call. Note that this API does not + * itself initiate the rebalance, so you must still call {@link #poll(Duration)}. If a rebalance is already in + * progress this call will be a no-op. If you wish to force an additional rebalance you must complete the current + * one by calling poll before retrying this API. + *
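+ * <p>
+ * For illustration, a sketch in which a hypothetical application-level check,
+ * {@code subscriptionUserdataChanged()}, decides when to request a rebalance:
+ * <pre> {@code
+ * if (subscriptionUserdataChanged())      // hypothetical check, not part of this API
+ *     consumer.enforceRebalance();
+ * consumer.poll(Duration.ofMillis(100));  // the rebalance itself happens inside poll()
+ * }</pre>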
+ * <p>
+ * You do not need to call this during normal processing, as the consumer group will manage itself + * automatically and rebalance when necessary. However there may be situations where the application wishes to + * trigger a rebalance that would otherwise not occur. For example, if some condition external and invisible to + * the Consumer and its group changes in a way that would affect the userdata encoded in the + * {@link org.apache.kafka.clients.consumer.ConsumerPartitionAssignor.Subscription Subscription}, the Consumer + * will not be notified and no rebalance will occur. This API can be used to force the group to rebalance so that + * the assignor can perform a partition reassignment based on the latest userdata. If your assignor does not use + * this userdata, or you do not use a custom + * {@link org.apache.kafka.clients.consumer.ConsumerPartitionAssignor ConsumerPartitionAssignor}, you should not + * use this API. + * + * @throws java.lang.IllegalStateException if the consumer does not use group subscription + */ + @Override + public void enforceRebalance() { + acquireAndEnsureOpen(); + try { + if (coordinator == null) { + throw new IllegalStateException("Tried to force a rebalance but consumer does not have a group."); + } + coordinator.requestRejoin("rebalance enforced by user"); + } finally { + release(); + } + } + + /** + * Close the consumer, waiting for up to the default timeout of 30 seconds for any needed cleanup. + * If auto-commit is enabled, this will commit the current offsets if possible within the default + * timeout. See {@link #close(Duration)} for details. Note that {@link #wakeup()} + * cannot be used to interrupt close. + * + * @throws org.apache.kafka.common.errors.InterruptException if the calling thread is interrupted + * before or while this function is called + * @throws org.apache.kafka.common.KafkaException for any other error during close + */ + @Override + public void close() { + close(Duration.ofMillis(DEFAULT_CLOSE_TIMEOUT_MS)); + } + + /** + * Tries to close the consumer cleanly within the specified timeout. This method waits up to + * {@code timeout} for the consumer to complete pending commits and leave the group. + * If auto-commit is enabled, this will commit the current offsets if possible within the + * timeout. If the consumer is unable to complete offset commits and gracefully leave the group + * before the timeout expires, the consumer is force closed. Note that {@link #wakeup()} cannot be + * used to interrupt close. + * + * @param timeout The maximum time to wait for consumer to close gracefully. The value must be + * non-negative. Specifying a timeout of zero means do not wait for pending requests to complete. + * + * @throws IllegalArgumentException If the {@code timeout} is negative. + * @throws InterruptException If the thread is interrupted before or while this function is called + * @throws org.apache.kafka.common.KafkaException for any other error during close + */ + @Override + public void close(Duration timeout) { + if (timeout.toMillis() < 0) + throw new IllegalArgumentException("The timeout cannot be negative."); + acquire(); + try { + if (!closed) { + // need to close before setting the flag since the close function + // itself may trigger rebalance callback that needs the consumer to be open still + close(timeout.toMillis(), false); + } + } finally { + closed = true; + release(); + } + } + + /** + * Wakeup the consumer. This method is thread-safe and is useful in particular to abort a long poll. 
+ * The thread which is blocking in an operation will throw {@link org.apache.kafka.common.errors.WakeupException}. + * If no thread is blocking in a method which can throw {@link org.apache.kafka.common.errors.WakeupException}, the next call to such a method will raise it instead. + */ + @Override + public void wakeup() { + this.client.wakeup(); + } + + private ClusterResourceListeners configureClusterResourceListeners(Deserializer keyDeserializer, Deserializer valueDeserializer, List... candidateLists) { + ClusterResourceListeners clusterResourceListeners = new ClusterResourceListeners(); + for (List candidateList: candidateLists) + clusterResourceListeners.maybeAddAll(candidateList); + + clusterResourceListeners.maybeAdd(keyDeserializer); + clusterResourceListeners.maybeAdd(valueDeserializer); + return clusterResourceListeners; + } + + private void close(long timeoutMs, boolean swallowException) { + log.trace("Closing the Kafka consumer"); + AtomicReference firstException = new AtomicReference<>(); + try { + if (coordinator != null) + coordinator.close(time.timer(Math.min(timeoutMs, requestTimeoutMs))); + } catch (Throwable t) { + firstException.compareAndSet(null, t); + log.error("Failed to close coordinator", t); + } + Utils.closeQuietly(fetcher, "fetcher", firstException); + Utils.closeQuietly(interceptors, "consumer interceptors", firstException); + Utils.closeQuietly(kafkaConsumerMetrics, "kafka consumer metrics", firstException); + Utils.closeQuietly(metrics, "consumer metrics", firstException); + Utils.closeQuietly(client, "consumer network client", firstException); + Utils.closeQuietly(keyDeserializer, "consumer key deserializer", firstException); + Utils.closeQuietly(valueDeserializer, "consumer value deserializer", firstException); + AppInfoParser.unregisterAppInfo(JMX_PREFIX, clientId, metrics); + log.debug("Kafka consumer has been closed"); + Throwable exception = firstException.get(); + if (exception != null && !swallowException) { + if (exception instanceof InterruptException) { + throw (InterruptException) exception; + } + throw new KafkaException("Failed to close kafka consumer", exception); + } + } + + /** + * Set the fetch position to the committed position (if there is one) + * or reset it using the offset reset policy the user has configured. + * + * @throws org.apache.kafka.common.errors.AuthenticationException if authentication fails. See the exception for more details + * @throws NoOffsetForPartitionException If no offset is stored for a given partition and no offset reset policy is + * defined + * @return true iff the operation completed without timing out + */ + private boolean updateFetchPositions(final Timer timer) { + // If any partitions have been truncated due to a leader change, we need to validate the offsets + fetcher.validateOffsetsIfNeeded(); + + cachedSubscriptionHashAllFetchPositions = subscriptions.hasAllFetchPositions(); + if (cachedSubscriptionHashAllFetchPositions) return true; + + // If there are any partitions which do not have a valid position and are not + // awaiting reset, then we need to fetch committed offsets. We will only do a + // coordinator lookup if there are partitions which have missing positions, so + // a consumer with manually assigned partitions can avoid a coordinator dependence + // by always ensuring that assigned partitions have an initial position. 
+ if (coordinator != null && !coordinator.refreshCommittedOffsetsIfNeeded(timer)) return false; + + // If there are partitions still needing a position and a reset policy is defined, + // request reset using the default policy. If no reset strategy is defined and there + // are partitions with a missing position, then we will raise an exception. + subscriptions.resetInitializingPositions(); + + // Finally send an asynchronous request to lookup and update the positions of any + // partitions which are awaiting reset. + fetcher.resetOffsetsIfNeeded(); + + return true; + } + + /** + * Acquire the light lock and ensure that the consumer hasn't been closed. + * @throws IllegalStateException If the consumer has been closed + */ + private void acquireAndEnsureOpen() { + acquire(); + if (this.closed) { + release(); + throw new IllegalStateException("This consumer has already been closed."); + } + } + + /** + * Acquire the light lock protecting this consumer from multi-threaded access. Instead of blocking + * when the lock is not available, however, we just throw an exception (since multi-threaded usage is not + * supported). + * @throws ConcurrentModificationException if another thread already has the lock + */ + private void acquire() { + long threadId = Thread.currentThread().getId(); + if (threadId != currentThread.get() && !currentThread.compareAndSet(NO_CURRENT_THREAD, threadId)) + throw new ConcurrentModificationException("KafkaConsumer is not safe for multi-threaded access"); + refcount.incrementAndGet(); + } + + /** + * Release the light lock protecting the consumer from multi-threaded access. + */ + private void release() { + if (refcount.decrementAndGet() == 0) + currentThread.set(NO_CURRENT_THREAD); + } + + private void throwIfNoAssignorsConfigured() { + if (assignors.isEmpty()) + throw new IllegalStateException("Must configure at least one partition assigner class name to " + + ConsumerConfig.PARTITION_ASSIGNMENT_STRATEGY_CONFIG + " configuration property"); + } + + private void maybeThrowInvalidGroupIdException() { + if (!groupId.isPresent()) + throw new InvalidGroupIdException("To use the group management or offset commit APIs, you must " + + "provide a valid " + ConsumerConfig.GROUP_ID_CONFIG + " in the consumer configuration."); + } + + private void updateLastSeenEpochIfNewer(TopicPartition topicPartition, OffsetAndMetadata offsetAndMetadata) { + if (offsetAndMetadata != null) + offsetAndMetadata.leaderEpoch().ifPresent(epoch -> metadata.updateLastSeenEpochIfNewer(topicPartition, epoch)); + } + + // Functions below are for testing only + String getClientId() { + return clientId; + } + + boolean updateAssignmentMetadataIfNeeded(final Timer timer) { + return updateAssignmentMetadataIfNeeded(timer, true); + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/LogTruncationException.java b/clients/src/main/java/org/apache/kafka/clients/consumer/LogTruncationException.java new file mode 100644 index 0000000..336eed4 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/consumer/LogTruncationException.java @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients.consumer; + +import org.apache.kafka.common.TopicPartition; + +import java.util.Collections; +import java.util.Map; + +/** + * In the event of an unclean leader election, the log will be truncated, + * previously committed data will be lost, and new data will be written + * over these offsets. When this happens, the consumer will detect the + * truncation and raise this exception (if no automatic reset policy + * has been defined) with the first offset known to diverge from what the + * consumer previously read. + */ +public class LogTruncationException extends OffsetOutOfRangeException { + + private final Map divergentOffsets; + + public LogTruncationException(String message, + Map fetchOffsets, + Map divergentOffsets) { + super(message, fetchOffsets); + this.divergentOffsets = Collections.unmodifiableMap(divergentOffsets); + } + + /** + * Get the divergent offsets for the partitions which were truncated. For each + * partition, this is the first offset which is known to diverge from what the + * consumer read. + * + * Note that there is no guarantee that this offset will be known. It is necessary + * to use {@link #partitions()} to see the set of partitions that were truncated + * and then check for the presence of a divergent offset in this map. + */ + public Map divergentOffsets() { + return divergentOffsets; + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/MockConsumer.java b/clients/src/main/java/org/apache/kafka/clients/consumer/MockConsumer.java new file mode 100644 index 0000000..b1fc7ec --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/consumer/MockConsumer.java @@ -0,0 +1,577 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.clients.consumer; + +import org.apache.kafka.clients.Metadata; +import org.apache.kafka.clients.consumer.internals.NoOpConsumerRebalanceListener; +import org.apache.kafka.clients.consumer.internals.SubscriptionState; +import org.apache.kafka.common.KafkaException; +import org.apache.kafka.common.Metric; +import org.apache.kafka.common.MetricName; +import org.apache.kafka.common.PartitionInfo; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.errors.WakeupException; +import org.apache.kafka.common.utils.LogContext; + +import java.time.Duration; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.OptionalLong; +import java.util.Queue; +import java.util.Set; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.regex.Pattern; +import java.util.stream.Collectors; + +import static java.util.Collections.singleton; +import static org.apache.kafka.clients.consumer.KafkaConsumer.DEFAULT_CLOSE_TIMEOUT_MS; + + +/** + * A mock of the {@link Consumer} interface you can use for testing code that uses Kafka. This class is not + * threadsafe . However, you can use the {@link #schedulePollTask(Runnable)} method to write multithreaded tests + * where a driver thread waits for {@link #poll(Duration)} to be called by a background thread and then can safely perform + * operations during a callback. + */ +public class MockConsumer implements Consumer { + + private final Map> partitions; + private final SubscriptionState subscriptions; + private final Map beginningOffsets; + private final Map endOffsets; + private final Map committed; + private final Queue pollTasks; + private final Set paused; + + private Map>> records; + private KafkaException pollException; + private KafkaException offsetsException; + private AtomicBoolean wakeup; + private Duration lastPollTimeout; + private boolean closed; + private boolean shouldRebalance; + + public MockConsumer(OffsetResetStrategy offsetResetStrategy) { + this.subscriptions = new SubscriptionState(new LogContext(), offsetResetStrategy); + this.partitions = new HashMap<>(); + this.records = new HashMap<>(); + this.paused = new HashSet<>(); + this.closed = false; + this.beginningOffsets = new HashMap<>(); + this.endOffsets = new HashMap<>(); + this.pollTasks = new LinkedList<>(); + this.pollException = null; + this.wakeup = new AtomicBoolean(false); + this.committed = new HashMap<>(); + this.shouldRebalance = false; + } + + @Override + public synchronized Set assignment() { + return this.subscriptions.assignedPartitions(); + } + + /** Simulate a rebalance event. 
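+ * <p>
+ * A rough sketch of how a test might drive this (topic and partition names are illustrative):
+ * <pre> {@code
+ * MockConsumer<String, String> consumer = new MockConsumer<>(OffsetResetStrategy.EARLIEST);
+ * consumer.subscribe(Collections.singleton("my-topic"));
+ * consumer.rebalance(Collections.singleton(new TopicPartition("my-topic", 0)));
+ * consumer.updateBeginningOffsets(Collections.singletonMap(new TopicPartition("my-topic", 0), 0L));
+ * consumer.addRecord(new ConsumerRecord<>("my-topic", 0, 0L, "key", "value"));
+ * }</pre>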
*/ + public synchronized void rebalance(Collection newAssignment) { + // TODO: Rebalance callbacks + this.records.clear(); + this.subscriptions.assignFromSubscribed(newAssignment); + } + + @Override + public synchronized Set subscription() { + return this.subscriptions.subscription(); + } + + @Override + public synchronized void subscribe(Collection topics) { + subscribe(topics, new NoOpConsumerRebalanceListener()); + } + + @Override + public synchronized void subscribe(Pattern pattern, final ConsumerRebalanceListener listener) { + ensureNotClosed(); + committed.clear(); + this.subscriptions.subscribe(pattern, listener); + Set topicsToSubscribe = new HashSet<>(); + for (String topic: partitions.keySet()) { + if (pattern.matcher(topic).matches() && + !subscriptions.subscription().contains(topic)) + topicsToSubscribe.add(topic); + } + ensureNotClosed(); + this.subscriptions.subscribeFromPattern(topicsToSubscribe); + final Set assignedPartitions = new HashSet<>(); + for (final String topic : topicsToSubscribe) { + for (final PartitionInfo info : this.partitions.get(topic)) { + assignedPartitions.add(new TopicPartition(topic, info.partition())); + } + + } + subscriptions.assignFromSubscribed(assignedPartitions); + } + + @Override + public synchronized void subscribe(Pattern pattern) { + subscribe(pattern, new NoOpConsumerRebalanceListener()); + } + + @Override + public synchronized void subscribe(Collection topics, final ConsumerRebalanceListener listener) { + ensureNotClosed(); + committed.clear(); + this.subscriptions.subscribe(new HashSet<>(topics), listener); + } + + @Override + public synchronized void assign(Collection partitions) { + ensureNotClosed(); + committed.clear(); + this.subscriptions.assignFromUser(new HashSet<>(partitions)); + } + + @Override + public synchronized void unsubscribe() { + ensureNotClosed(); + committed.clear(); + subscriptions.unsubscribe(); + } + + @Deprecated + @Override + public synchronized ConsumerRecords poll(long timeout) { + return poll(Duration.ofMillis(timeout)); + } + + @Override + public synchronized ConsumerRecords poll(final Duration timeout) { + ensureNotClosed(); + + lastPollTimeout = timeout; + + // Synchronize around the entire execution so new tasks to be triggered on subsequent poll calls can be added in + // the callback + synchronized (pollTasks) { + Runnable task = pollTasks.poll(); + if (task != null) + task.run(); + } + + if (wakeup.get()) { + wakeup.set(false); + throw new WakeupException(); + } + + if (pollException != null) { + RuntimeException exception = this.pollException; + this.pollException = null; + throw exception; + } + + // Handle seeks that need to wait for a poll() call to be processed + for (TopicPartition tp : subscriptions.assignedPartitions()) + if (!subscriptions.hasValidPosition(tp)) + updateFetchPosition(tp); + + // update the consumed offset + final Map>> results = new HashMap<>(); + final List toClear = new ArrayList<>(); + + for (Map.Entry>> entry : this.records.entrySet()) { + if (!subscriptions.isPaused(entry.getKey())) { + final List> recs = entry.getValue(); + for (final ConsumerRecord rec : recs) { + long position = subscriptions.position(entry.getKey()).offset; + + if (beginningOffsets.get(entry.getKey()) != null && beginningOffsets.get(entry.getKey()) > position) { + throw new OffsetOutOfRangeException(Collections.singletonMap(entry.getKey(), position)); + } + + if (assignment().contains(entry.getKey()) && rec.offset() >= position) { + results.computeIfAbsent(entry.getKey(), partition -> new 
ArrayList<>()).add(rec); + Metadata.LeaderAndEpoch leaderAndEpoch = new Metadata.LeaderAndEpoch(Optional.empty(), rec.leaderEpoch()); + SubscriptionState.FetchPosition newPosition = new SubscriptionState.FetchPosition( + rec.offset() + 1, rec.leaderEpoch(), leaderAndEpoch); + subscriptions.position(entry.getKey(), newPosition); + } + } + toClear.add(entry.getKey()); + } + } + + toClear.forEach(p -> this.records.remove(p)); + return new ConsumerRecords<>(results); + } + + public synchronized void addRecord(ConsumerRecord record) { + ensureNotClosed(); + TopicPartition tp = new TopicPartition(record.topic(), record.partition()); + Set currentAssigned = this.subscriptions.assignedPartitions(); + if (!currentAssigned.contains(tp)) + throw new IllegalStateException("Cannot add records for a partition that is not assigned to the consumer"); + List> recs = this.records.computeIfAbsent(tp, k -> new ArrayList<>()); + recs.add(record); + } + + /** + * @deprecated Use {@link #setPollException(KafkaException)} instead + */ + @Deprecated + public synchronized void setException(KafkaException exception) { + setPollException(exception); + } + + public synchronized void setPollException(KafkaException exception) { + this.pollException = exception; + } + + public synchronized void setOffsetsException(KafkaException exception) { + this.offsetsException = exception; + } + + @Override + public synchronized void commitAsync(Map offsets, OffsetCommitCallback callback) { + ensureNotClosed(); + for (Map.Entry entry : offsets.entrySet()) + committed.put(entry.getKey(), entry.getValue()); + if (callback != null) { + callback.onComplete(offsets, null); + } + } + + @Override + public synchronized void commitSync(Map offsets) { + commitAsync(offsets, null); + } + + @Override + public synchronized void commitAsync() { + commitAsync(null); + } + + @Override + public synchronized void commitAsync(OffsetCommitCallback callback) { + ensureNotClosed(); + commitAsync(this.subscriptions.allConsumed(), callback); + } + + @Override + public synchronized void commitSync() { + commitSync(this.subscriptions.allConsumed()); + } + + @Override + public synchronized void commitSync(Duration timeout) { + commitSync(this.subscriptions.allConsumed()); + } + + @Override + public void commitSync(Map offsets, final Duration timeout) { + commitSync(offsets); + } + + @Override + public synchronized void seek(TopicPartition partition, long offset) { + ensureNotClosed(); + subscriptions.seek(partition, offset); + } + + @Override + public void seek(TopicPartition partition, OffsetAndMetadata offsetAndMetadata) { + ensureNotClosed(); + subscriptions.seek(partition, offsetAndMetadata.offset()); + } + + @Deprecated + @Override + public synchronized OffsetAndMetadata committed(final TopicPartition partition) { + return committed(singleton(partition)).get(partition); + } + + @Deprecated + @Override + public OffsetAndMetadata committed(final TopicPartition partition, final Duration timeout) { + return committed(partition); + } + + @Override + public synchronized Map committed(final Set partitions) { + ensureNotClosed(); + + return partitions.stream() + .filter(committed::containsKey) + .collect(Collectors.toMap(tp -> tp, tp -> subscriptions.isAssigned(tp) ? 
+ committed.get(tp) : new OffsetAndMetadata(0))); + } + + @Override + public synchronized Map committed(final Set partitions, final Duration timeout) { + return committed(partitions); + } + + @Override + public synchronized long position(TopicPartition partition) { + ensureNotClosed(); + if (!this.subscriptions.isAssigned(partition)) + throw new IllegalArgumentException("You can only check the position for partitions assigned to this consumer."); + SubscriptionState.FetchPosition position = this.subscriptions.position(partition); + if (position == null) { + updateFetchPosition(partition); + position = this.subscriptions.position(partition); + } + return position.offset; + } + + @Override + public synchronized long position(TopicPartition partition, final Duration timeout) { + return position(partition); + } + + @Override + public synchronized void seekToBeginning(Collection partitions) { + ensureNotClosed(); + subscriptions.requestOffsetReset(partitions, OffsetResetStrategy.EARLIEST); + } + + public synchronized void updateBeginningOffsets(Map newOffsets) { + beginningOffsets.putAll(newOffsets); + } + + @Override + public synchronized void seekToEnd(Collection partitions) { + ensureNotClosed(); + subscriptions.requestOffsetReset(partitions, OffsetResetStrategy.LATEST); + } + + public synchronized void updateEndOffsets(final Map newOffsets) { + endOffsets.putAll(newOffsets); + } + + @Override + public synchronized Map metrics() { + ensureNotClosed(); + return Collections.emptyMap(); + } + + @Override + public synchronized List partitionsFor(String topic) { + ensureNotClosed(); + return this.partitions.getOrDefault(topic, Collections.emptyList()); + } + + @Override + public synchronized Map> listTopics() { + ensureNotClosed(); + return partitions; + } + + public synchronized void updatePartitions(String topic, List partitions) { + ensureNotClosed(); + this.partitions.put(topic, partitions); + } + + @Override + public synchronized void pause(Collection partitions) { + for (TopicPartition partition : partitions) { + subscriptions.pause(partition); + paused.add(partition); + } + } + + @Override + public synchronized void resume(Collection partitions) { + for (TopicPartition partition : partitions) { + subscriptions.resume(partition); + paused.remove(partition); + } + } + + @Override + public synchronized Map offsetsForTimes(Map timestampsToSearch) { + throw new UnsupportedOperationException("Not implemented yet."); + } + + @Override + public synchronized Map beginningOffsets(Collection partitions) { + if (offsetsException != null) { + RuntimeException exception = this.offsetsException; + this.offsetsException = null; + throw exception; + } + Map result = new HashMap<>(); + for (TopicPartition tp : partitions) { + Long beginningOffset = beginningOffsets.get(tp); + if (beginningOffset == null) + throw new IllegalStateException("The partition " + tp + " does not have a beginning offset."); + result.put(tp, beginningOffset); + } + return result; + } + + @Override + public synchronized Map endOffsets(Collection partitions) { + if (offsetsException != null) { + RuntimeException exception = this.offsetsException; + this.offsetsException = null; + throw exception; + } + Map result = new HashMap<>(); + for (TopicPartition tp : partitions) { + Long endOffset = endOffsets.get(tp); + if (endOffset == null) + throw new IllegalStateException("The partition " + tp + " does not have an end offset."); + result.put(tp, endOffset); + } + return result; + } + + @Override + public void close() { + 
close(Duration.ofMillis(DEFAULT_CLOSE_TIMEOUT_MS)); + } + + @Override + public synchronized void close(Duration timeout) { + this.closed = true; + } + + public synchronized boolean closed() { + return this.closed; + } + + @Override + public synchronized void wakeup() { + wakeup.set(true); + } + + /** + * Schedule a task to be executed during a poll(). One enqueued task will be executed per {@link #poll(Duration)} + * invocation. You can use this repeatedly to mock out multiple responses to poll invocations. + * @param task the task to be executed + */ + public synchronized void schedulePollTask(Runnable task) { + synchronized (pollTasks) { + pollTasks.add(task); + } + } + + public synchronized void scheduleNopPollTask() { + schedulePollTask(() -> { }); + } + + public synchronized Set paused() { + return Collections.unmodifiableSet(new HashSet<>(paused)); + } + + private void ensureNotClosed() { + if (this.closed) + throw new IllegalStateException("This consumer has already been closed."); + } + + private void updateFetchPosition(TopicPartition tp) { + if (subscriptions.isOffsetResetNeeded(tp)) { + resetOffsetPosition(tp); + } else if (!committed.containsKey(tp)) { + subscriptions.requestOffsetReset(tp); + resetOffsetPosition(tp); + } else { + subscriptions.seek(tp, committed.get(tp).offset()); + } + } + + private void resetOffsetPosition(TopicPartition tp) { + OffsetResetStrategy strategy = subscriptions.resetStrategy(tp); + Long offset; + if (strategy == OffsetResetStrategy.EARLIEST) { + offset = beginningOffsets.get(tp); + if (offset == null) + throw new IllegalStateException("MockConsumer didn't have beginning offset specified, but tried to seek to beginning"); + } else if (strategy == OffsetResetStrategy.LATEST) { + offset = endOffsets.get(tp); + if (offset == null) + throw new IllegalStateException("MockConsumer didn't have end offset specified, but tried to seek to end"); + } else { + throw new NoOffsetForPartitionException(tp); + } + seek(tp, offset); + } + + @Override + public List partitionsFor(String topic, Duration timeout) { + return partitionsFor(topic); + } + + @Override + public Map> listTopics(Duration timeout) { + return listTopics(); + } + + @Override + public Map offsetsForTimes(Map timestampsToSearch, + Duration timeout) { + return offsetsForTimes(timestampsToSearch); + } + + @Override + public Map beginningOffsets(Collection partitions, Duration timeout) { + return beginningOffsets(partitions); + } + + @Override + public Map endOffsets(Collection partitions, Duration timeout) { + return endOffsets(partitions); + } + + @Override + public OptionalLong currentLag(TopicPartition topicPartition) { + if (endOffsets.containsKey(topicPartition)) { + return OptionalLong.of(endOffsets.get(topicPartition) - position(topicPartition)); + } else { + // if the test doesn't bother to set an end offset, we assume it wants to model being caught up. 
+ return OptionalLong.of(0L); + } + } + + @Override + public ConsumerGroupMetadata groupMetadata() { + return new ConsumerGroupMetadata("dummy.group.id", 1, "1", Optional.empty()); + } + + @Override + public void enforceRebalance() { + shouldRebalance = true; + } + + public boolean shouldRebalance() { + return shouldRebalance; + } + + public void resetShouldRebalance() { + shouldRebalance = false; + } + + public Duration lastPollTimeout() { + return lastPollTimeout; + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/NoOffsetForPartitionException.java b/clients/src/main/java/org/apache/kafka/clients/consumer/NoOffsetForPartitionException.java new file mode 100644 index 0000000..1770d60 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/consumer/NoOffsetForPartitionException.java @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients.consumer; + +import org.apache.kafka.common.TopicPartition; + +import java.util.Collection; +import java.util.Collections; +import java.util.HashSet; +import java.util.Set; + +/** + * Indicates that there is no stored offset for a partition and no defined offset + * reset policy. + */ +public class NoOffsetForPartitionException extends InvalidOffsetException { + + private static final long serialVersionUID = 1L; + + private final Set partitions; + + public NoOffsetForPartitionException(TopicPartition partition) { + super("Undefined offset with no reset policy for partition: " + partition); + this.partitions = Collections.singleton(partition); + } + + public NoOffsetForPartitionException(Collection partitions) { + super("Undefined offset with no reset policy for partitions: " + partitions); + this.partitions = Collections.unmodifiableSet(new HashSet<>(partitions)); + } + + /** + * returns all partitions for which no offests are defined. + * @return all partitions without offsets + */ + public Set partitions() { + return partitions; + } + +} diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/OffsetAndMetadata.java b/clients/src/main/java/org/apache/kafka/clients/consumer/OffsetAndMetadata.java new file mode 100644 index 0000000..d6b3b94 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/consumer/OffsetAndMetadata.java @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients.consumer; + +import org.apache.kafka.common.requests.OffsetFetchResponse; + +import java.io.Serializable; +import java.util.Objects; +import java.util.Optional; + +/** + * The Kafka offset commit API allows users to provide additional metadata (in the form of a string) + * when an offset is committed. This can be useful (for example) to store information about which + * node made the commit, what time the commit was made, etc. + */ +public class OffsetAndMetadata implements Serializable { + private static final long serialVersionUID = 2019555404968089681L; + + private final long offset; + private final String metadata; + + // We use null to represent the absence of a leader epoch to simplify serialization. + // I.e., older serializations of this class which do not have this field will automatically + // initialize its value to null. + private final Integer leaderEpoch; + + /** + * Construct a new OffsetAndMetadata object for committing through {@link KafkaConsumer}. + * + * @param offset The offset to be committed + * @param leaderEpoch Optional leader epoch of the last consumed record + * @param metadata Non-null metadata + */ + public OffsetAndMetadata(long offset, Optional leaderEpoch, String metadata) { + if (offset < 0) + throw new IllegalArgumentException("Invalid negative offset"); + + this.offset = offset; + this.leaderEpoch = leaderEpoch.orElse(null); + + // The server converts null metadata to an empty string. So we store it as an empty string as well on the client + // to be consistent. + if (metadata == null) + this.metadata = OffsetFetchResponse.NO_METADATA; + else + this.metadata = metadata; + } + + /** + * Construct a new OffsetAndMetadata object for committing through {@link KafkaConsumer}. + * @param offset The offset to be committed + * @param metadata Non-null metadata + */ + public OffsetAndMetadata(long offset, String metadata) { + this(offset, Optional.empty(), metadata); + } + + /** + * Construct a new OffsetAndMetadata object for committing through {@link KafkaConsumer}. The metadata + * associated with the commit will be empty. + * @param offset The offset to be committed + */ + public OffsetAndMetadata(long offset) { + this(offset, ""); + } + + public long offset() { + return offset; + } + + public String metadata() { + return metadata; + } + + /** + * Get the leader epoch of the previously consumed record (if one is known). Log truncation is detected + * if there exists a leader epoch which is larger than this epoch and begins at an offset earlier than + * the committed offset. 
+ * + * @return the leader epoch or empty if not known + */ + public Optional leaderEpoch() { + if (leaderEpoch == null || leaderEpoch < 0) + return Optional.empty(); + return Optional.of(leaderEpoch); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + OffsetAndMetadata that = (OffsetAndMetadata) o; + return offset == that.offset && + Objects.equals(metadata, that.metadata) && + Objects.equals(leaderEpoch, that.leaderEpoch); + } + + @Override + public int hashCode() { + return Objects.hash(offset, metadata, leaderEpoch); + } + + @Override + public String toString() { + return "OffsetAndMetadata{" + + "offset=" + offset + + ", leaderEpoch=" + leaderEpoch + + ", metadata='" + metadata + '\'' + + '}'; + } + +} diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/OffsetAndTimestamp.java b/clients/src/main/java/org/apache/kafka/clients/consumer/OffsetAndTimestamp.java new file mode 100644 index 0000000..40d9930 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/consumer/OffsetAndTimestamp.java @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients.consumer; + +import java.util.Objects; +import java.util.Optional; + +/** + * A container class for offset and timestamp. + */ +public final class OffsetAndTimestamp { + private final long timestamp; + private final long offset; + private final Optional leaderEpoch; + + public OffsetAndTimestamp(long offset, long timestamp) { + this(offset, timestamp, Optional.empty()); + } + + public OffsetAndTimestamp(long offset, long timestamp, Optional leaderEpoch) { + if (offset < 0) + throw new IllegalArgumentException("Invalid negative offset"); + + if (timestamp < 0) + throw new IllegalArgumentException("Invalid negative timestamp"); + + this.offset = offset; + this.timestamp = timestamp; + this.leaderEpoch = leaderEpoch; + } + + public long timestamp() { + return timestamp; + } + + public long offset() { + return offset; + } + + /** + * Get the leader epoch corresponding to the offset that was found (if one exists). + * This can be provided to seek() to ensure that the log hasn't been truncated prior to fetching. 
+ * + * @return The leader epoch or empty if it is not known + */ + public Optional leaderEpoch() { + return leaderEpoch; + } + + @Override + public String toString() { + return "(timestamp=" + timestamp + + ", leaderEpoch=" + leaderEpoch.orElse(null) + + ", offset=" + offset + ")"; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + OffsetAndTimestamp that = (OffsetAndTimestamp) o; + return timestamp == that.timestamp && + offset == that.offset && + Objects.equals(leaderEpoch, that.leaderEpoch); + } + + @Override + public int hashCode() { + return Objects.hash(timestamp, offset, leaderEpoch); + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/OffsetCommitCallback.java b/clients/src/main/java/org/apache/kafka/clients/consumer/OffsetCommitCallback.java new file mode 100644 index 0000000..53e8ae7 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/consumer/OffsetCommitCallback.java @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients.consumer; + +import org.apache.kafka.common.TopicPartition; + +import java.time.Duration; +import java.util.Collection; +import java.util.Map; + +/** + * A callback interface that the user can implement to trigger custom actions when a commit request completes. The callback + * may be executed in any thread calling {@link Consumer#poll(java.time.Duration) poll()}. + */ +public interface OffsetCommitCallback { + + /** + * A callback method the user can implement to provide asynchronous handling of commit request completion. + * This method will be called when the commit request sent to the server has been acknowledged. + * + * @param offsets A map of the offsets and associated metadata that this callback applies to + * @param exception The exception thrown during processing of the request, or null if the commit completed successfully + * + * @throws org.apache.kafka.clients.consumer.CommitFailedException if the commit failed and cannot be retried. + * This can only occur if you are using automatic group management with {@link KafkaConsumer#subscribe(Collection)}, + * or if there is an active group with the same groupId which is using group management. + * @throws org.apache.kafka.common.errors.RebalanceInProgressException if the commit failed because + * it is in the middle of a rebalance. In such cases + * commit could be retried after the rebalance is completed with the {@link KafkaConsumer#poll(Duration)} call. 
+ * @throws org.apache.kafka.common.errors.WakeupException if {@link KafkaConsumer#wakeup()} is called before or while this + * function is called + * @throws org.apache.kafka.common.errors.InterruptException if the calling thread is interrupted before or while + * this function is called + * @throws org.apache.kafka.common.errors.AuthorizationException if not authorized to the topic or to the + * configured groupId. See the exception for more details + * @throws org.apache.kafka.common.KafkaException for any other unrecoverable errors (e.g. if offset metadata + * is too large or if the committed offset is invalid). + */ + void onComplete(Map offsets, Exception exception); +} diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/OffsetOutOfRangeException.java b/clients/src/main/java/org/apache/kafka/clients/consumer/OffsetOutOfRangeException.java new file mode 100644 index 0000000..c98e22f --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/consumer/OffsetOutOfRangeException.java @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients.consumer; + +import org.apache.kafka.common.TopicPartition; + +import java.util.Map; +import java.util.Set; + +/** + * No reset policy has been defined, and the offsets for these partitions are either larger or smaller + * than the range of offsets the server has for the given partition. + */ +public class OffsetOutOfRangeException extends InvalidOffsetException { + + private static final long serialVersionUID = 1L; + private final Map offsetOutOfRangePartitions; + + public OffsetOutOfRangeException(Map offsetOutOfRangePartitions) { + this("Offsets out of range with no configured reset policy for partitions: " + + offsetOutOfRangePartitions, offsetOutOfRangePartitions); + } + + public OffsetOutOfRangeException(String message, Map offsetOutOfRangePartitions) { + super(message); + this.offsetOutOfRangePartitions = offsetOutOfRangePartitions; + } + + /** + * Get a map of the topic partitions and the respective out-of-range fetch offsets. + */ + public Map offsetOutOfRangePartitions() { + return offsetOutOfRangePartitions; + } + + @Override + public Set partitions() { + return offsetOutOfRangePartitions.keySet(); + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/OffsetResetStrategy.java b/clients/src/main/java/org/apache/kafka/clients/consumer/OffsetResetStrategy.java new file mode 100644 index 0000000..6d742b8 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/consumer/OffsetResetStrategy.java @@ -0,0 +1,21 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients.consumer; + +public enum OffsetResetStrategy { + LATEST, EARLIEST, NONE +} diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/RangeAssignor.java b/clients/src/main/java/org/apache/kafka/clients/consumer/RangeAssignor.java new file mode 100644 index 0000000..aec0d39 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/consumer/RangeAssignor.java @@ -0,0 +1,118 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients.consumer; + +import org.apache.kafka.clients.consumer.internals.AbstractPartitionAssignor; +import org.apache.kafka.common.TopicPartition; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** + *

+ * The range assignor works on a per-topic basis. For each topic, we lay out the available partitions in numeric order
+ * and the consumers in lexicographic order. We then divide the number of partitions by the total number of
+ * consumers to determine the number of partitions to assign to each consumer. If it does not evenly
+ * divide, then the first few consumers will have one extra partition.
+ *
+ * For example, suppose there are two consumers C0 and C1, two topics t0 and
+ * t1, and each topic has 3 partitions, resulting in partitions t0p0, t0p1,
+ * t0p2, t1p0, t1p1, and t1p2.
+ *
+ * The assignment will be:
+ *   • C0: [t0p0, t0p1, t1p0, t1p1]
+ *   • C1: [t0p2, t1p2]
+ *
+ * Since the introduction of static membership, we can leverage group.instance.id to make the assignment behavior more sticky.
+ * For the above example, after one rolling bounce, the group coordinator will attempt to assign new member.id values to the
+ * consumers, for example C0 -> C3 and C1 -> C2.
+ *
+ * The assignment could be completely shuffled to:
+ *   • C3 (was C0): [t0p2, t1p2] (before was [t0p0, t0p1, t1p0, t1p1])
+ *   • C2 (was C1): [t0p0, t0p1, t1p0, t1p1] (before was [t0p2, t1p2])
+ *
+ * The assignment change was caused by the change in member.id relative order, and
+ * can be avoided by setting group.instance.id.
+ * Consumers will then have individual instance ids I0 and I1. As long as
+ *   1. the number of members remains the same across generations,
+ *   2. static members' identities persist across generations, and
+ *   3. the subscription pattern doesn't change for any member,
+ *
+ * the assignment will always be:
+ *   • I0: [t0p0, t0p1, t1p0, t1p1]
+ *   • I1: [t0p2, t1p2]
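+ *
+ * As an illustrative sketch only (not part of the public API), the per-topic arithmetic described above can be
+ * written as follows; numPartitionsPerConsumer, consumersWithExtraPartition, start and length mirror the
+ * implementation below, while numConsumersForTopic and the 0-based consumer index i are introduced here purely
+ * for illustration:
+ * {@code
+ * int numPartitionsPerConsumer = numPartitionsForTopic / numConsumersForTopic;
+ * int consumersWithExtraPartition = numPartitionsForTopic % numConsumersForTopic;
+ * int start = numPartitionsPerConsumer * i + Math.min(i, consumersWithExtraPartition);
+ * int length = numPartitionsPerConsumer + (i + 1 > consumersWithExtraPartition ? 0 : 1);
+ * // consumer i is assigned the partitions [start, start + length) of the topic
+ * }
+ * For the example above (3 partitions, 2 consumers per topic) this gives C0 partitions 0 and 1, and C1 partition 2.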
+ */ +public class RangeAssignor extends AbstractPartitionAssignor { + public static final String RANGE_ASSIGNOR_NAME = "range"; + + @Override + public String name() { + return RANGE_ASSIGNOR_NAME; + } + + private Map> consumersPerTopic(Map consumerMetadata) { + Map> topicToConsumers = new HashMap<>(); + for (Map.Entry subscriptionEntry : consumerMetadata.entrySet()) { + String consumerId = subscriptionEntry.getKey(); + MemberInfo memberInfo = new MemberInfo(consumerId, subscriptionEntry.getValue().groupInstanceId()); + for (String topic : subscriptionEntry.getValue().topics()) { + put(topicToConsumers, topic, memberInfo); + } + } + return topicToConsumers; + } + + @Override + public Map> assign(Map partitionsPerTopic, + Map subscriptions) { + Map> consumersPerTopic = consumersPerTopic(subscriptions); + + Map> assignment = new HashMap<>(); + for (String memberId : subscriptions.keySet()) + assignment.put(memberId, new ArrayList<>()); + + for (Map.Entry> topicEntry : consumersPerTopic.entrySet()) { + String topic = topicEntry.getKey(); + List consumersForTopic = topicEntry.getValue(); + + Integer numPartitionsForTopic = partitionsPerTopic.get(topic); + if (numPartitionsForTopic == null) + continue; + + Collections.sort(consumersForTopic); + + int numPartitionsPerConsumer = numPartitionsForTopic / consumersForTopic.size(); + int consumersWithExtraPartition = numPartitionsForTopic % consumersForTopic.size(); + + List partitions = AbstractPartitionAssignor.partitions(topic, numPartitionsForTopic); + for (int i = 0, n = consumersForTopic.size(); i < n; i++) { + int start = numPartitionsPerConsumer * i + Math.min(i, consumersWithExtraPartition); + int length = numPartitionsPerConsumer + (i + 1 > consumersWithExtraPartition ? 0 : 1); + assignment.get(consumersForTopic.get(i).memberId).addAll(partitions.subList(start, start + length)); + } + } + return assignment; + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/RetriableCommitFailedException.java b/clients/src/main/java/org/apache/kafka/clients/consumer/RetriableCommitFailedException.java new file mode 100644 index 0000000..f44dce6 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/consumer/RetriableCommitFailedException.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients.consumer; + +import org.apache.kafka.common.errors.RetriableException; + +public class RetriableCommitFailedException extends RetriableException { + + private static final long serialVersionUID = 1L; + + public RetriableCommitFailedException(Throwable t) { + super("Offset commit failed with a retriable exception. 
You should retry committing " + + "the latest consumed offsets.", t); + } + + public RetriableCommitFailedException(String message) { + super(message); + } + + public RetriableCommitFailedException(String message, Throwable t) { + super(message, t); + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/RoundRobinAssignor.java b/clients/src/main/java/org/apache/kafka/clients/consumer/RoundRobinAssignor.java new file mode 100644 index 0000000..2d6edea --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/consumer/RoundRobinAssignor.java @@ -0,0 +1,145 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients.consumer; + +import org.apache.kafka.clients.consumer.internals.AbstractPartitionAssignor; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.utils.CircularIterator; +import org.apache.kafka.common.utils.Utils; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.SortedSet; +import java.util.TreeSet; + +/** + *

+ * The round robin assignor lays out all the available partitions and all the available consumers. It
+ * then proceeds to do a round robin assignment from partition to consumer. If the subscriptions of all consumer
+ * instances are identical, then the partitions will be uniformly distributed (i.e., the partition ownership counts
+ * will be within a delta of exactly one across all consumers).
+ *
+ * For example, suppose there are two consumers C0 and C1, two topics t0 and t1,
+ * and each topic has 3 partitions, resulting in partitions t0p0, t0p1, t0p2,
+ * t1p0, t1p1, and t1p2.
+ *
+ * The assignment will be:
+ *   • C0: [t0p0, t0p2, t1p1]
+ *   • C1: [t0p1, t1p0, t1p2]
+ *
+ * When subscriptions differ across consumer instances, the assignment process still considers each
+ * consumer instance in round robin fashion but skips over an instance if it is not subscribed to
+ * the topic. Unlike the case when subscriptions are identical, this can result in imbalanced
+ * assignments. For example, suppose we have three consumers C0, C1, and C2,
+ * and three topics t0, t1, and t2, with 1, 2, and 3 partitions, respectively.
+ * Therefore, the partitions are t0p0, t1p0, t1p1, t2p0, t2p1, and t2p2.
+ * C0 is subscribed to t0;
+ * C1 is subscribed to t0, t1;
+ * and C2 is subscribed to t0, t1, t2.
+ *
+ * That assignment will be:
+ *   • C0: [t0p0]
+ *   • C1: [t1p0]
+ *   • C2: [t1p1, t2p0, t2p1, t2p2]
+ *
+ * Since the introduction of static membership, we can leverage group.instance.id to make the assignment behavior more sticky.
+ * For example, suppose we have three consumers with assigned member.id values C0, C1, and C2,
+ * two topics t0 and t1, and each topic has 3 partitions, resulting in partitions t0p0,
+ * t0p1, t0p2, t1p0, t1p1, and t1p2. We choose to honor
+ * the sorted order based on the ephemeral member.id.
+ *
+ * The assignment will be:
+ *   • C0: [t0p0, t1p0]
+ *   • C1: [t0p1, t1p1]
+ *   • C2: [t0p2, t1p2]
+ *
+ * After one rolling bounce, the group coordinator will attempt to assign new member.id values to the consumers,
+ * for example C0 -> C5, C1 -> C3, and C2 -> C4.
+ *
+ * The assignment could be completely shuffled to:
+ *   • C3 (was C1): [t0p0, t1p0] (before was [t0p1, t1p1])
+ *   • C4 (was C2): [t0p1, t1p1] (before was [t0p2, t1p2])
+ *   • C5 (was C0): [t0p2, t1p2] (before was [t0p0, t1p0])
+ *
+ * This issue can be mitigated by static membership. Consumers will then have individual instance ids
+ * I0, I1, and I2. As long as
+ *   1. the number of members remains the same across generations,
+ *   2. static members' identities persist across generations, and
+ *   3. the subscription pattern doesn't change for any member,
+ *
+ * the assignment will always be:
+ *   • I0: [t0p0, t1p0]
+ *   • I1: [t0p1, t1p1]
+ *   • I2: [t0p2, t1p2]
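+ *
+ * As a usage sketch (assuming ConsumerPartitionAssignor.Subscription and the per-topic partition counts that this
+ * assignor's assign method takes; imports are omitted), the first example above can be reproduced directly:
+ * {@code
+ * Map<String, Integer> partitionsPerTopic = new HashMap<>();
+ * partitionsPerTopic.put("t0", 3);
+ * partitionsPerTopic.put("t1", 3);
+ *
+ * Map<String, ConsumerPartitionAssignor.Subscription> subscriptions = new HashMap<>();
+ * subscriptions.put("C0", new ConsumerPartitionAssignor.Subscription(Arrays.asList("t0", "t1")));
+ * subscriptions.put("C1", new ConsumerPartitionAssignor.Subscription(Arrays.asList("t0", "t1")));
+ *
+ * Map<String, List<TopicPartition>> assignment =
+ *     new RoundRobinAssignor().assign(partitionsPerTopic, subscriptions);
+ * // expected: C0 -> [t0p0, t0p2, t1p1], C1 -> [t0p1, t1p0, t1p2]
+ * }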
+ */ +public class RoundRobinAssignor extends AbstractPartitionAssignor { + public static final String ROUNDROBIN_ASSIGNOR_NAME = "roundrobin"; + + @Override + public Map> assign(Map partitionsPerTopic, + Map subscriptions) { + Map> assignment = new HashMap<>(); + List memberInfoList = new ArrayList<>(); + for (Map.Entry memberSubscription : subscriptions.entrySet()) { + assignment.put(memberSubscription.getKey(), new ArrayList<>()); + memberInfoList.add(new MemberInfo(memberSubscription.getKey(), + memberSubscription.getValue().groupInstanceId())); + } + + CircularIterator assigner = new CircularIterator<>(Utils.sorted(memberInfoList)); + + for (TopicPartition partition : allPartitionsSorted(partitionsPerTopic, subscriptions)) { + final String topic = partition.topic(); + while (!subscriptions.get(assigner.peek().memberId).topics().contains(topic)) + assigner.next(); + assignment.get(assigner.next().memberId).add(partition); + } + return assignment; + } + + private List allPartitionsSorted(Map partitionsPerTopic, + Map subscriptions) { + SortedSet topics = new TreeSet<>(); + for (Subscription subscription : subscriptions.values()) + topics.addAll(subscription.topics()); + + List allPartitions = new ArrayList<>(); + for (String topic : topics) { + Integer numPartitionsForTopic = partitionsPerTopic.get(topic); + if (numPartitionsForTopic != null) + allPartitions.addAll(AbstractPartitionAssignor.partitions(topic, numPartitionsForTopic)); + } + return allPartitions; + } + + @Override + public String name() { + return ROUNDROBIN_ASSIGNOR_NAME; + } + +} diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/StickyAssignor.java b/clients/src/main/java/org/apache/kafka/clients/consumer/StickyAssignor.java new file mode 100644 index 0000000..787c432 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/consumer/StickyAssignor.java @@ -0,0 +1,273 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients.consumer; + +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import org.apache.kafka.clients.consumer.internals.AbstractStickyAssignor; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.protocol.types.ArrayOf; +import org.apache.kafka.common.protocol.types.Field; +import org.apache.kafka.common.protocol.types.Schema; +import org.apache.kafka.common.protocol.types.Struct; +import org.apache.kafka.common.protocol.types.Type; +import org.apache.kafka.common.utils.CollectionUtils; + +/** + *

+ * The sticky assignor serves two purposes. First, it guarantees an assignment that is as balanced as possible, meaning either:
+ *   • the numbers of topic partitions assigned to consumers differ by at most one; or
+ *   • each consumer that has 2+ fewer topic partitions than some other consumer cannot get any of those topic partitions transferred to it.
+ *
+ * Second, it preserves as many existing assignments as possible when a reassignment occurs. This helps save some of the
+ * overhead processing when topic partitions move from one consumer to another.
+ *
+ * Starting fresh, it works by distributing the partitions over consumers as evenly as possible. Even though this may sound similar to
+ * how the round robin assignor works, the second example below shows that it is not.
+ * During a reassignment, it performs the reassignment in such a way that in the new assignment
+ *   1. topic partitions are still distributed as evenly as possible, and
+ *   2. topic partitions stay with their previously assigned consumers as much as possible.
+ * Of course, the first goal above takes precedence over the second one.

+ * + *

+ * Example 1. Suppose there are three consumers C0, C1, C2,
+ * four topics t0, t1, t2, t3, and each topic has 2 partitions,
+ * resulting in partitions t0p0, t0p1, t1p0, t1p1, t2p0,
+ * t2p1, t3p0, t3p1. Each consumer is subscribed to all four topics.
+ *
+ * The assignment with both sticky and round robin assignors will be:
+ *   • C0: [t0p0, t1p1, t3p0]
+ *   • C1: [t0p1, t2p0, t3p1]
+ *   • C2: [t1p0, t2p1]
+ *
+ * Now, let's assume C1 is removed and a reassignment is about to happen. The round robin assignor would produce:
+ *   • C0: [t0p0, t1p0, t2p0, t3p0]
+ *   • C2: [t0p1, t1p1, t2p1, t3p1]
+ *
+ * while the sticky assignor would result in:
+ *   • C0: [t0p0, t1p1, t3p0, t2p0]
+ *   • C2: [t1p0, t2p1, t0p1, t3p1]
+ *
+ * preserving all the previous assignments (unlike the round robin assignor).

+ *

+ * Example 2. There are three consumers C0, C1, C2,
+ * and three topics t0, t1, t2, with 1, 2, and 3 partitions respectively.
+ * Therefore, the partitions are t0p0, t1p0, t1p1, t2p0,
+ * t2p1, t2p2. C0 is subscribed to t0; C1 is subscribed to
+ * t0, t1; and C2 is subscribed to t0, t1, t2.
+ *
+ * The round robin assignor would come up with the following assignment:
+ *   • C0: [t0p0]
+ *   • C1: [t1p0]
+ *   • C2: [t1p1, t2p0, t2p1, t2p2]
+ *
+ * which is not as balanced as the assignment suggested by the sticky assignor:
+ *   • C0: [t0p0]
+ *   • C1: [t1p0, t1p1]
+ *   • C2: [t2p0, t2p1, t2p2]
+ *
+ * Now, if consumer C0 is removed, these two assignors would produce the following assignments.
+ * Round Robin (preserves 3 partition assignments):
+ *   • C1: [t0p0, t1p1]
+ *   • C2: [t1p0, t2p0, t2p1, t2p2]
+ *
+ * Sticky (preserves 5 partition assignments):
+ *   • C1: [t1p0, t1p1, t0p0]
+ *   • C2: [t2p0, t2p1, t2p2]
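+ *
+ * To opt into this behavior, the assignor is configured on the consumer via partition.assignment.strategy.
+ * A minimal configuration sketch (the bootstrap address and group id below are placeholders):
+ * {@code
+ * Properties props = new Properties();
+ * props.put(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9092");
+ * props.put(ConsumerConfig.GROUP_ID_CONFIG, "my-group");
+ * props.put(ConsumerConfig.PARTITION_ASSIGNMENT_STRATEGY_CONFIG, StickyAssignor.class.getName());
+ * KafkaConsumer<String, String> consumer =
+ *     new KafkaConsumer<>(props, new StringDeserializer(), new StringDeserializer());
+ * }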
+ *

+ *

+ * Impact on ConsumerRebalanceListener

+ * The sticky assignment strategy can provide some optimization to those consumers that have some partition cleanup code + * in their onPartitionsRevoked() callback listeners. The cleanup code is placed in that callback listener + * because the consumer has no assumption or hope of preserving any of its assigned partitions after a rebalance when it + * is using range or round robin assignor. The listener code would look like this: + *
+ * {@code
+ * class TheOldRebalanceListener implements ConsumerRebalanceListener {
+ *
+ *   void onPartitionsRevoked(Collection<TopicPartition> partitions) {
+ *     for (TopicPartition partition: partitions) {
+ *       commitOffsets(partition);
+ *       cleanupState(partition);
+ *     }
+ *   }
+ *
+ *   void onPartitionsAssigned(Collection<TopicPartition> partitions) {
+ *     for (TopicPartition partition: partitions) {
+ *       initializeState(partition);
+ *       initializeOffset(partition);
+ *     }
+ *   }
+ * }
+ * }
+ * 
+ * + * As mentioned above, one advantage of the sticky assignor is that, in general, it reduces the number of partitions that + * actually move from one consumer to another during a reassignment. Therefore, it allows consumers to do their cleanup + * more efficiently. Of course, they still can perform the partition cleanup in the onPartitionsRevoked() + * listener, but they can be more efficient and make a note of their partitions before and after the rebalance, and do the + * cleanup after the rebalance only on the partitions they have lost (which is normally not a lot). The code snippet below + * clarifies this point: + *
+ * {@code
+ * class TheNewRebalanceListener implements ConsumerRebalanceListener {
+ *   Collection<TopicPartition> lastAssignment = Collections.emptyList();
+ *
+ *   void onPartitionsRevoked(Collection<TopicPartition> partitions) {
+ *     for (TopicPartition partition: partitions)
+ *       commitOffsets(partition);
+ *   }
+ *
+ *   void onPartitionsAssigned(Collection<TopicPartition> assignment) {
+ *     for (TopicPartition partition: difference(lastAssignment, assignment))
+ *       cleanupState(partition);
+ *
+ *     for (TopicPartition partition: difference(assignment, lastAssignment))
+ *       initializeState(partition);
+ *
+ *     for (TopicPartition partition: assignment)
+ *       initializeOffset(partition);
+ *
+ *     this.lastAssignment = assignment;
+ *   }
+ * }
+ * }
+ * 
+ * + * Any consumer that uses sticky assignment can leverage this listener like this: + * consumer.subscribe(topics, new TheNewRebalanceListener()); + * + * Note that you can leverage the {@link CooperativeStickyAssignor} so that only partitions which are being + * reassigned to another consumer will be revoked. That is the preferred assignor for newer cluster. See + * {@link ConsumerPartitionAssignor.RebalanceProtocol} for a detailed explanation of cooperative rebalancing. + */ +public class StickyAssignor extends AbstractStickyAssignor { + public static final String STICKY_ASSIGNOR_NAME = "sticky"; + + // these schemas are used for preserving consumer's previously assigned partitions + // list and sending it as user data to the leader during a rebalance + static final String TOPIC_PARTITIONS_KEY_NAME = "previous_assignment"; + static final String TOPIC_KEY_NAME = "topic"; + static final String PARTITIONS_KEY_NAME = "partitions"; + private static final String GENERATION_KEY_NAME = "generation"; + + static final Schema TOPIC_ASSIGNMENT = new Schema( + new Field(TOPIC_KEY_NAME, Type.STRING), + new Field(PARTITIONS_KEY_NAME, new ArrayOf(Type.INT32))); + static final Schema STICKY_ASSIGNOR_USER_DATA_V0 = new Schema( + new Field(TOPIC_PARTITIONS_KEY_NAME, new ArrayOf(TOPIC_ASSIGNMENT))); + private static final Schema STICKY_ASSIGNOR_USER_DATA_V1 = new Schema( + new Field(TOPIC_PARTITIONS_KEY_NAME, new ArrayOf(TOPIC_ASSIGNMENT)), + new Field(GENERATION_KEY_NAME, Type.INT32)); + + private List memberAssignment = null; + private int generation = DEFAULT_GENERATION; // consumer group generation + + @Override + public String name() { + return STICKY_ASSIGNOR_NAME; + } + + @Override + public void onAssignment(Assignment assignment, ConsumerGroupMetadata metadata) { + memberAssignment = assignment.partitions(); + this.generation = metadata.generationId(); + } + + @Override + public ByteBuffer subscriptionUserData(Set topics) { + if (memberAssignment == null) + return null; + + return serializeTopicPartitionAssignment(new MemberData(memberAssignment, Optional.of(generation))); + } + + @Override + protected MemberData memberData(Subscription subscription) { + ByteBuffer userData = subscription.userData(); + if (userData == null || !userData.hasRemaining()) { + return new MemberData(Collections.emptyList(), Optional.empty()); + } + return deserializeTopicPartitionAssignment(userData); + } + + // visible for testing + static ByteBuffer serializeTopicPartitionAssignment(MemberData memberData) { + Struct struct = new Struct(STICKY_ASSIGNOR_USER_DATA_V1); + List topicAssignments = new ArrayList<>(); + for (Map.Entry> topicEntry : CollectionUtils.groupPartitionsByTopic(memberData.partitions).entrySet()) { + Struct topicAssignment = new Struct(TOPIC_ASSIGNMENT); + topicAssignment.set(TOPIC_KEY_NAME, topicEntry.getKey()); + topicAssignment.set(PARTITIONS_KEY_NAME, topicEntry.getValue().toArray()); + topicAssignments.add(topicAssignment); + } + struct.set(TOPIC_PARTITIONS_KEY_NAME, topicAssignments.toArray()); + if (memberData.generation.isPresent()) + struct.set(GENERATION_KEY_NAME, memberData.generation.get()); + ByteBuffer buffer = ByteBuffer.allocate(STICKY_ASSIGNOR_USER_DATA_V1.sizeOf(struct)); + STICKY_ASSIGNOR_USER_DATA_V1.write(buffer, struct); + buffer.flip(); + return buffer; + } + + private static MemberData deserializeTopicPartitionAssignment(ByteBuffer buffer) { + Struct struct; + ByteBuffer copy = buffer.duplicate(); + try { + struct = STICKY_ASSIGNOR_USER_DATA_V1.read(buffer); + } catch 
(Exception e1) { + try { + // fall back to older schema + struct = STICKY_ASSIGNOR_USER_DATA_V0.read(copy); + } catch (Exception e2) { + // ignore the consumer's previous assignment if it cannot be parsed + return new MemberData(Collections.emptyList(), Optional.of(DEFAULT_GENERATION)); + } + } + + List partitions = new ArrayList<>(); + for (Object structObj : struct.getArray(TOPIC_PARTITIONS_KEY_NAME)) { + Struct assignment = (Struct) structObj; + String topic = assignment.getString(TOPIC_KEY_NAME); + for (Object partitionObj : assignment.getArray(PARTITIONS_KEY_NAME)) { + Integer partition = (Integer) partitionObj; + partitions.add(new TopicPartition(topic, partition)); + } + } + // make sure this is backward compatible + Optional generation = struct.hasField(GENERATION_KEY_NAME) ? Optional.of(struct.getInt(GENERATION_KEY_NAME)) : Optional.empty(); + return new MemberData(partitions, generation); + } +} \ No newline at end of file diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/AbstractCoordinator.java b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/AbstractCoordinator.java new file mode 100644 index 0000000..6f16b34 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/AbstractCoordinator.java @@ -0,0 +1,1567 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.clients.consumer.internals; + +import org.apache.kafka.clients.ClientResponse; +import org.apache.kafka.clients.GroupRebalanceConfig; +import org.apache.kafka.common.KafkaException; +import org.apache.kafka.common.Node; +import org.apache.kafka.common.errors.AuthenticationException; +import org.apache.kafka.common.errors.DisconnectException; +import org.apache.kafka.common.errors.FencedInstanceIdException; +import org.apache.kafka.common.errors.GroupAuthorizationException; +import org.apache.kafka.common.errors.GroupMaxSizeReachedException; +import org.apache.kafka.common.errors.IllegalGenerationException; +import org.apache.kafka.common.errors.InterruptException; +import org.apache.kafka.common.errors.MemberIdRequiredException; +import org.apache.kafka.common.errors.RebalanceInProgressException; +import org.apache.kafka.common.errors.RetriableException; +import org.apache.kafka.common.errors.UnknownMemberIdException; +import org.apache.kafka.common.message.FindCoordinatorRequestData; +import org.apache.kafka.common.message.FindCoordinatorResponseData.Coordinator; +import org.apache.kafka.common.message.HeartbeatRequestData; +import org.apache.kafka.common.message.JoinGroupRequestData; +import org.apache.kafka.common.message.JoinGroupResponseData; +import org.apache.kafka.common.message.LeaveGroupRequestData.MemberIdentity; +import org.apache.kafka.common.message.LeaveGroupResponseData.MemberResponse; +import org.apache.kafka.common.message.SyncGroupRequestData; +import org.apache.kafka.common.metrics.Measurable; +import org.apache.kafka.common.metrics.Metrics; +import org.apache.kafka.common.metrics.Sensor; +import org.apache.kafka.common.metrics.stats.Avg; +import org.apache.kafka.common.metrics.stats.CumulativeCount; +import org.apache.kafka.common.metrics.stats.CumulativeSum; +import org.apache.kafka.common.metrics.stats.Max; +import org.apache.kafka.common.metrics.stats.Meter; +import org.apache.kafka.common.metrics.stats.Rate; +import org.apache.kafka.common.metrics.stats.WindowedCount; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.Errors; +import org.apache.kafka.common.requests.FindCoordinatorRequest; +import org.apache.kafka.common.requests.FindCoordinatorRequest.CoordinatorType; +import org.apache.kafka.common.requests.FindCoordinatorResponse; +import org.apache.kafka.common.requests.HeartbeatRequest; +import org.apache.kafka.common.requests.HeartbeatResponse; +import org.apache.kafka.common.requests.JoinGroupRequest; +import org.apache.kafka.common.requests.JoinGroupResponse; +import org.apache.kafka.common.requests.LeaveGroupRequest; +import org.apache.kafka.common.requests.LeaveGroupResponse; +import org.apache.kafka.common.requests.OffsetCommitRequest; +import org.apache.kafka.common.requests.SyncGroupRequest; +import org.apache.kafka.common.requests.SyncGroupResponse; +import org.apache.kafka.common.utils.KafkaThread; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.common.utils.Timer; +import org.apache.kafka.common.utils.Utils; +import org.slf4j.Logger; + +import java.io.Closeable; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; + +/** + * AbstractCoordinator implements group management for a single group member by interacting 
with a designated Kafka broker (the coordinator). Group semantics are provided by extending this class.
+ * See {@link ConsumerCoordinator} for example usage.
+ *
+ * From a high level, Kafka's group management protocol consists of the following sequence of actions:
+ *   1. Group Registration: Group members register with the coordinator, providing their own metadata
+ *      (such as the set of topics they are interested in).
+ *   2. Group/Leader Selection: The coordinator selects the members of the group and chooses one member
+ *      as the leader.
+ *   3. State Assignment: The leader collects the metadata from all the members of the group and
+ *      assigns state.
+ *   4. Group Stabilization: Each member receives the state assigned by the leader and begins
+ *      processing.
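+ *
+ * For illustration only, a minimal hypothetical subclass wiring these steps together might look like the sketch
+ * below; the class and protocol names are invented, imports are omitted, and only the abstract hooks declared by
+ * this class are implemented (the leader simply echoes each member's metadata back as its assignment):
+ * {@code
+ * public class EchoCoordinator extends AbstractCoordinator {
+ *     public EchoCoordinator(GroupRebalanceConfig config, LogContext logContext, ConsumerNetworkClient client,
+ *                            Metrics metrics, String metricGrpPrefix, Time time) {
+ *         super(config, logContext, client, metrics, metricGrpPrefix, time);
+ *     }
+ *
+ *     protected String protocolType() { return "echo"; }
+ *
+ *     protected JoinGroupRequestData.JoinGroupRequestProtocolCollection metadata() {
+ *         JoinGroupRequestData.JoinGroupRequestProtocolCollection protocols =
+ *             new JoinGroupRequestData.JoinGroupRequestProtocolCollection();
+ *         // register a single protocol version with empty member metadata
+ *         protocols.add(new JoinGroupRequestData.JoinGroupRequestProtocol().setName("v0").setMetadata(new byte[0]));
+ *         return protocols;
+ *     }
+ *
+ *     protected void onJoinPrepare(int generation, String memberId) {
+ *         // nothing to clean up before (re)joining in this sketch
+ *     }
+ *
+ *     protected Map<String, ByteBuffer> performAssignment(String leaderId, String protocol,
+ *                                                         List<JoinGroupResponseData.JoinGroupResponseMember> allMemberMetadata) {
+ *         Map<String, ByteBuffer> assignment = new HashMap<>();
+ *         for (JoinGroupResponseData.JoinGroupResponseMember member : allMemberMetadata)
+ *             assignment.put(member.memberId(), ByteBuffer.wrap(member.metadata()));
+ *         return assignment;
+ *     }
+ *
+ *     protected void onJoinComplete(int generation, String memberId, String protocol, ByteBuffer memberAssignment) {
+ *         // the member receives its assignment here and can begin processing
+ *     }
+ * }
+ * }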
+ * + * To leverage this protocol, an implementation must define the format of metadata provided by each + * member for group registration in {@link #metadata()} and the format of the state assignment provided + * by the leader in {@link #performAssignment(String, String, List)} and becomes available to members in + * {@link #onJoinComplete(int, String, String, ByteBuffer)}. + * + * Note on locking: this class shares state between the caller and a background thread which is + * used for sending heartbeats after the client has joined the group. All mutable state as well as + * state transitions are protected with the class's monitor. Generally this means acquiring the lock + * before reading or writing the state of the group (e.g. generation, memberId) and holding the lock + * when sending a request that affects the state of the group (e.g. JoinGroup, LeaveGroup). + */ +public abstract class AbstractCoordinator implements Closeable { + public static final String HEARTBEAT_THREAD_PREFIX = "kafka-coordinator-heartbeat-thread"; + public static final int JOIN_GROUP_TIMEOUT_LAPSE = 5000; + + protected enum MemberState { + UNJOINED, // the client is not part of a group + PREPARING_REBALANCE, // the client has sent the join group request, but have not received response + COMPLETING_REBALANCE, // the client has received join group response, but have not received assignment + STABLE; // the client has joined and is sending heartbeats + + public boolean hasNotJoinedGroup() { + return equals(UNJOINED) || equals(PREPARING_REBALANCE); + } + } + + private final Logger log; + private final Heartbeat heartbeat; + private final GroupCoordinatorMetrics sensors; + private final GroupRebalanceConfig rebalanceConfig; + + protected final Time time; + protected final ConsumerNetworkClient client; + + private Node coordinator = null; + private boolean rejoinNeeded = true; + private boolean needsJoinPrepare = true; + private HeartbeatThread heartbeatThread = null; + private RequestFuture joinFuture = null; + private RequestFuture findCoordinatorFuture = null; + private volatile RuntimeException fatalFindCoordinatorException = null; + private Generation generation = Generation.NO_GENERATION; + private long lastRebalanceStartMs = -1L; + private long lastRebalanceEndMs = -1L; + private long lastTimeOfConnectionMs = -1L; // starting logging a warning only after unable to connect for a while + + protected MemberState state = MemberState.UNJOINED; + + + /** + * Initialize the coordination manager. + */ + public AbstractCoordinator(GroupRebalanceConfig rebalanceConfig, + LogContext logContext, + ConsumerNetworkClient client, + Metrics metrics, + String metricGrpPrefix, + Time time) { + Objects.requireNonNull(rebalanceConfig.groupId, + "Expected a non-null group id for coordinator construction"); + this.rebalanceConfig = rebalanceConfig; + this.log = logContext.logger(this.getClass()); + this.client = client; + this.time = time; + this.heartbeat = new Heartbeat(rebalanceConfig, time); + this.sensors = new GroupCoordinatorMetrics(metrics, metricGrpPrefix); + } + + /** + * Unique identifier for the class of supported protocols (e.g. "consumer" or "connect"). + * @return Non-null protocol type name + */ + protected abstract String protocolType(); + + /** + * Get the current list of protocols and their associated metadata supported + * by the local member. The order of the protocols in the list indicates the preference + * of the protocol (the first entry is the most preferred). 
The coordinator takes this + * preference into account when selecting the generation protocol (generally more preferred + * protocols will be selected as long as all members support them and there is no disagreement + * on the preference). + * @return Non-empty map of supported protocols and metadata + */ + protected abstract JoinGroupRequestData.JoinGroupRequestProtocolCollection metadata(); + + /** + * Invoked prior to each group join or rejoin. This is typically used to perform any + * cleanup from the previous generation (such as committing offsets for the consumer) + * @param generation The previous generation or -1 if there was none + * @param memberId The identifier of this member in the previous group or "" if there was none + */ + protected abstract void onJoinPrepare(int generation, String memberId); + + /** + * Perform assignment for the group. This is used by the leader to push state to all the members + * of the group (e.g. to push partition assignments in the case of the new consumer) + * @param leaderId The id of the leader (which is this member) + * @param protocol The protocol selected by the coordinator + * @param allMemberMetadata Metadata from all members of the group + * @return A map from each member to their state assignment + */ + protected abstract Map performAssignment(String leaderId, + String protocol, + List allMemberMetadata); + + /** + * Invoked when a group member has successfully joined a group. If this call fails with an exception, + * then it will be retried using the same assignment state on the next call to {@link #ensureActiveGroup()}. + * + * @param generation The generation that was joined + * @param memberId The identifier for the local member in the group + * @param protocol The protocol selected by the coordinator + * @param memberAssignment The assignment propagated from the group leader + */ + protected abstract void onJoinComplete(int generation, + String memberId, + String protocol, + ByteBuffer memberAssignment); + + /** + * Invoked prior to each leave group event. This is typically used to cleanup assigned partitions; + * note it is triggered by the consumer's API caller thread (i.e. background heartbeat thread would + * not trigger it even if it tries to force leaving group upon heartbeat session expiration) + */ + protected void onLeavePrepare() {} + + /** + * Visible for testing. + * + * Ensure that the coordinator is ready to receive requests. 
+ * + * @param timer Timer bounding how long this method can block + * @return true If coordinator discovery and initial connection succeeded, false otherwise + */ + protected synchronized boolean ensureCoordinatorReady(final Timer timer) { + if (!coordinatorUnknown()) + return true; + + do { + if (fatalFindCoordinatorException != null) { + final RuntimeException fatalException = fatalFindCoordinatorException; + fatalFindCoordinatorException = null; + throw fatalException; + } + final RequestFuture future = lookupCoordinator(); + client.poll(future, timer); + + if (!future.isDone()) { + // ran out of time + break; + } + + RuntimeException fatalException = null; + + if (future.failed()) { + if (future.isRetriable()) { + log.debug("Coordinator discovery failed, refreshing metadata", future.exception()); + client.awaitMetadataUpdate(timer); + } else { + fatalException = future.exception(); + log.info("FindCoordinator request hit fatal exception", fatalException); + } + } else if (coordinator != null && client.isUnavailable(coordinator)) { + // we found the coordinator, but the connection has failed, so mark + // it dead and backoff before retrying discovery + markCoordinatorUnknown("coordinator unavailable"); + timer.sleep(rebalanceConfig.retryBackoffMs); + } + + clearFindCoordinatorFuture(); + if (fatalException != null) + throw fatalException; + } while (coordinatorUnknown() && timer.notExpired()); + + return !coordinatorUnknown(); + } + + protected synchronized RequestFuture lookupCoordinator() { + if (findCoordinatorFuture == null) { + // find a node to ask about the coordinator + Node node = this.client.leastLoadedNode(); + if (node == null) { + log.debug("No broker available to send FindCoordinator request"); + return RequestFuture.noBrokersAvailable(); + } else { + findCoordinatorFuture = sendFindCoordinatorRequest(node); + } + } + return findCoordinatorFuture; + } + + private synchronized void clearFindCoordinatorFuture() { + findCoordinatorFuture = null; + } + + /** + * Check whether the group should be rejoined (e.g. if metadata changes) or whether a + * rejoin request is already in flight and needs to be completed. + * + * @return true if it should, false otherwise + */ + protected synchronized boolean rejoinNeededOrPending() { + // if there's a pending joinFuture, we should try to complete handling it. + return rejoinNeeded || joinFuture != null; + } + + /** + * Check the status of the heartbeat thread (if it is active) and indicate the liveness + * of the client. This must be called periodically after joining with {@link #ensureActiveGroup()} + * to ensure that the member stays in the group. If an interval of time longer than the + * provided rebalance timeout expires without calling this method, then the client will proactively + * leave the group. + * + * @param now current time in milliseconds + * @throws RuntimeException for unexpected errors raised from the heartbeat thread + */ + protected synchronized void pollHeartbeat(long now) { + if (heartbeatThread != null) { + if (heartbeatThread.hasFailed()) { + // set the heartbeat thread to null and raise an exception. If the user catches it, + // the next call to ensureActiveGroup() will spawn a new heartbeat thread. 
+ RuntimeException cause = heartbeatThread.failureCause(); + heartbeatThread = null; + throw cause; + } + // Awake the heartbeat thread if needed + if (heartbeat.shouldHeartbeat(now)) { + notify(); + } + heartbeat.poll(now); + } + } + + protected synchronized long timeToNextHeartbeat(long now) { + // if we have not joined the group or we are preparing rebalance, + // we don't need to send heartbeats + if (state.hasNotJoinedGroup()) + return Long.MAX_VALUE; + return heartbeat.timeToNextHeartbeat(now); + } + + /** + * Ensure that the group is active (i.e. joined and synced) + */ + public void ensureActiveGroup() { + while (!ensureActiveGroup(time.timer(Long.MAX_VALUE))) { + log.warn("still waiting to ensure active group"); + } + } + + /** + * Ensure the group is active (i.e., joined and synced) + * + * @param timer Timer bounding how long this method can block + * @throws KafkaException if the callback throws exception + * @return true iff the group is active + */ + boolean ensureActiveGroup(final Timer timer) { + // always ensure that the coordinator is ready because we may have been disconnected + // when sending heartbeats and does not necessarily require us to rejoin the group. + if (!ensureCoordinatorReady(timer)) { + return false; + } + + startHeartbeatThreadIfNeeded(); + return joinGroupIfNeeded(timer); + } + + private synchronized void startHeartbeatThreadIfNeeded() { + if (heartbeatThread == null) { + heartbeatThread = new HeartbeatThread(); + heartbeatThread.start(); + } + } + + private void closeHeartbeatThread() { + HeartbeatThread thread; + synchronized (this) { + if (heartbeatThread == null) + return; + heartbeatThread.close(); + thread = heartbeatThread; + heartbeatThread = null; + } + try { + thread.join(); + } catch (InterruptedException e) { + log.warn("Interrupted while waiting for consumer heartbeat thread to close"); + throw new InterruptException(e); + } + } + + /** + * Joins the group without starting the heartbeat thread. + * + * If this function returns true, the state must always be in STABLE and heartbeat enabled. + * If this function returns false, the state can be in one of the following: + * * UNJOINED: got error response but times out before being able to re-join, heartbeat disabled + * * PREPARING_REBALANCE: not yet received join-group response before timeout, heartbeat disabled + * * COMPLETING_REBALANCE: not yet received sync-group response before timeout, heartbeat enabled + * + * Visible for testing. + * + * @param timer Timer bounding how long this method can block + * @throws KafkaException if the callback throws exception + * @return true iff the operation succeeded + */ + boolean joinGroupIfNeeded(final Timer timer) { + while (rejoinNeededOrPending()) { + if (!ensureCoordinatorReady(timer)) { + return false; + } + + // call onJoinPrepare if needed. We set a flag to make sure that we do not call it a second + // time if the client is woken up before a pending rebalance completes. This must be called + // on each iteration of the loop because an event requiring a rebalance (such as a metadata + // refresh which changes the matched subscription set) can occur while another rebalance is + // still in progress. + if (needsJoinPrepare) { + // need to set the flag before calling onJoinPrepare since the user callback may throw + // exception, in which case upon retry we should not retry onJoinPrepare either. 
+ needsJoinPrepare = false; + onJoinPrepare(generation.generationId, generation.memberId); + } + + final RequestFuture future = initiateJoinGroup(); + client.poll(future, timer); + if (!future.isDone()) { + // we ran out of time + return false; + } + + if (future.succeeded()) { + Generation generationSnapshot; + MemberState stateSnapshot; + + // Generation data maybe concurrently cleared by Heartbeat thread. + // Can't use synchronized for {@code onJoinComplete}, because it can be long enough + // and shouldn't block heartbeat thread. + // See {@link PlaintextConsumerTest#testMaxPollIntervalMsDelayInAssignment} + synchronized (AbstractCoordinator.this) { + generationSnapshot = this.generation; + stateSnapshot = this.state; + } + + if (!generationSnapshot.equals(Generation.NO_GENERATION) && stateSnapshot == MemberState.STABLE) { + // Duplicate the buffer in case `onJoinComplete` does not complete and needs to be retried. + ByteBuffer memberAssignment = future.value().duplicate(); + + onJoinComplete(generationSnapshot.generationId, generationSnapshot.memberId, generationSnapshot.protocolName, memberAssignment); + + // Generally speaking we should always resetJoinGroupFuture once the future is done, but here + // we can only reset the join group future after the completion callback returns. This ensures + // that if the callback is woken up, we will retry it on the next joinGroupIfNeeded. + // And because of that we should explicitly trigger resetJoinGroupFuture in other conditions below. + resetJoinGroupFuture(); + needsJoinPrepare = true; + } else { + final String reason = String.format("rebalance failed since the generation/state was " + + "modified by heartbeat thread to %s/%s before the rebalance callback triggered", + generationSnapshot, stateSnapshot); + + resetStateAndRejoin(reason); + resetJoinGroupFuture(); + } + } else { + final RuntimeException exception = future.exception(); + + resetJoinGroupFuture(); + rejoinNeeded = true; + + if (exception instanceof UnknownMemberIdException || + exception instanceof IllegalGenerationException || + exception instanceof RebalanceInProgressException || + exception instanceof MemberIdRequiredException) + continue; + else if (!future.isRetriable()) + throw exception; + + timer.sleep(rebalanceConfig.retryBackoffMs); + } + } + return true; + } + + private synchronized void resetJoinGroupFuture() { + this.joinFuture = null; + } + + private synchronized RequestFuture initiateJoinGroup() { + // we store the join future in case we are woken up by the user after beginning the + // rebalance in the call to poll below. This ensures that we do not mistakenly attempt + // to rejoin before the pending rebalance has completed. + if (joinFuture == null) { + state = MemberState.PREPARING_REBALANCE; + // a rebalance can be triggered consecutively if the previous one failed, + // in this case we would not update the start time. + if (lastRebalanceStartMs == -1L) + lastRebalanceStartMs = time.milliseconds(); + joinFuture = sendJoinGroupRequest(); + joinFuture.addListener(new RequestFutureListener() { + @Override + public void onSuccess(ByteBuffer value) { + // do nothing since all the handler logic are in SyncGroupResponseHandler already + } + + @Override + public void onFailure(RuntimeException e) { + // we handle failures below after the request finishes. 
if the join completes + // after having been woken up, the exception is ignored and we will rejoin; + // this can be triggered when either join or sync request failed + synchronized (AbstractCoordinator.this) { + sensors.failedRebalanceSensor.record(); + } + } + }); + } + return joinFuture; + } + + /** + * Join the group and return the assignment for the next generation. This function handles both + * JoinGroup and SyncGroup, delegating to {@link #performAssignment(String, String, List)} if + * elected leader by the coordinator. + * + * NOTE: This is visible only for testing + * + * @return A request future which wraps the assignment returned from the group leader + */ + RequestFuture sendJoinGroupRequest() { + if (coordinatorUnknown()) + return RequestFuture.coordinatorNotAvailable(); + + // send a join group request to the coordinator + log.info("(Re-)joining group"); + JoinGroupRequest.Builder requestBuilder = new JoinGroupRequest.Builder( + new JoinGroupRequestData() + .setGroupId(rebalanceConfig.groupId) + .setSessionTimeoutMs(this.rebalanceConfig.sessionTimeoutMs) + .setMemberId(this.generation.memberId) + .setGroupInstanceId(this.rebalanceConfig.groupInstanceId.orElse(null)) + .setProtocolType(protocolType()) + .setProtocols(metadata()) + .setRebalanceTimeoutMs(this.rebalanceConfig.rebalanceTimeoutMs) + ); + + log.debug("Sending JoinGroup ({}) to coordinator {}", requestBuilder, this.coordinator); + + // Note that we override the request timeout using the rebalance timeout since that is the + // maximum time that it may block on the coordinator. We add an extra 5 seconds for small delays. + int joinGroupTimeoutMs = Math.max( + client.defaultRequestTimeoutMs(), + Math.max( + rebalanceConfig.rebalanceTimeoutMs + JOIN_GROUP_TIMEOUT_LAPSE, + rebalanceConfig.rebalanceTimeoutMs) // guard against overflow since rebalance timeout can be MAX_VALUE + ); + return client.send(coordinator, requestBuilder, joinGroupTimeoutMs) + .compose(new JoinGroupResponseHandler(generation)); + } + + private class JoinGroupResponseHandler extends CoordinatorResponseHandler { + private JoinGroupResponseHandler(final Generation generation) { + super(generation); + } + + @Override + public void handle(JoinGroupResponse joinResponse, RequestFuture future) { + Errors error = joinResponse.error(); + if (error == Errors.NONE) { + if (isProtocolTypeInconsistent(joinResponse.data().protocolType())) { + log.error("JoinGroup failed: Inconsistent Protocol Type, received {} but expected {}", + joinResponse.data().protocolType(), protocolType()); + future.raise(Errors.INCONSISTENT_GROUP_PROTOCOL); + } else { + log.debug("Received successful JoinGroup response: {}", joinResponse); + sensors.joinSensor.record(response.requestLatencyMs()); + + synchronized (AbstractCoordinator.this) { + if (state != MemberState.PREPARING_REBALANCE) { + // if the consumer was woken up before a rebalance completes, we may have already left + // the group. In this case, we do not want to continue with the sync group. 
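+ // UnjoinedGroupException is a RetriableException (defined near the bottom of this class),
+ // so joinGroupIfNeeded() treats this as a recoverable failure and retries after the backoff.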
+ future.raise(new UnjoinedGroupException()); + } else { + state = MemberState.COMPLETING_REBALANCE; + + // we only need to enable heartbeat thread whenever we transit to + // COMPLETING_REBALANCE state since we always transit from this state to STABLE + if (heartbeatThread != null) + heartbeatThread.enable(); + + AbstractCoordinator.this.generation = new Generation( + joinResponse.data().generationId(), + joinResponse.data().memberId(), joinResponse.data().protocolName()); + + log.info("Successfully joined group with generation {}", AbstractCoordinator.this.generation); + + if (joinResponse.isLeader()) { + onJoinLeader(joinResponse).chain(future); + } else { + onJoinFollower().chain(future); + } + } + } + } + } else if (error == Errors.COORDINATOR_LOAD_IN_PROGRESS) { + log.info("JoinGroup failed: Coordinator {} is loading the group.", coordinator()); + // backoff and retry + future.raise(error); + } else if (error == Errors.UNKNOWN_MEMBER_ID) { + log.info("JoinGroup failed: {} Need to re-join the group. Sent generation was {}", + error.message(), sentGeneration); + // only need to reset the member id if generation has not been changed, + // then retry immediately + if (generationUnchanged()) + resetGenerationOnResponseError(ApiKeys.JOIN_GROUP, error); + + future.raise(error); + } else if (error == Errors.COORDINATOR_NOT_AVAILABLE + || error == Errors.NOT_COORDINATOR) { + // re-discover the coordinator and retry with backoff + markCoordinatorUnknown(error); + log.info("JoinGroup failed: {} Marking coordinator unknown. Sent generation was {}", + error.message(), sentGeneration); + future.raise(error); + } else if (error == Errors.FENCED_INSTANCE_ID) { + // for join-group request, even if the generation has changed we would not expect the instance id + // gets fenced, and hence we always treat this as a fatal error + log.error("JoinGroup failed: The group instance id {} has been fenced by another instance. " + + "Sent generation was {}", rebalanceConfig.groupInstanceId, sentGeneration); + future.raise(error); + } else if (error == Errors.INCONSISTENT_GROUP_PROTOCOL + || error == Errors.INVALID_SESSION_TIMEOUT + || error == Errors.INVALID_GROUP_ID + || error == Errors.GROUP_AUTHORIZATION_FAILED + || error == Errors.GROUP_MAX_SIZE_REACHED) { + // log the error and re-throw the exception + log.error("JoinGroup failed due to fatal error: {}", error.message()); + if (error == Errors.GROUP_MAX_SIZE_REACHED) { + future.raise(new GroupMaxSizeReachedException("Consumer group " + rebalanceConfig.groupId + + " already has the configured maximum number of members.")); + } else if (error == Errors.GROUP_AUTHORIZATION_FAILED) { + future.raise(GroupAuthorizationException.forGroupId(rebalanceConfig.groupId)); + } else { + future.raise(error); + } + } else if (error == Errors.UNSUPPORTED_VERSION) { + log.error("JoinGroup failed due to unsupported version error. Please unset field group.instance.id " + + "and retry to see if the problem resolves"); + future.raise(error); + } else if (error == Errors.MEMBER_ID_REQUIRED) { + // Broker requires a concrete member id to be allowed to join the group. Update member id + // and send another join group request in next cycle. + String memberId = joinResponse.data().memberId(); + log.debug("JoinGroup failed due to non-fatal error: {} Will set the member id as {} and then rejoin. 
" + + "Sent generation was {}", error, memberId, sentGeneration); + synchronized (AbstractCoordinator.this) { + AbstractCoordinator.this.generation = new Generation(OffsetCommitRequest.DEFAULT_GENERATION_ID, memberId, null); + } + requestRejoin("need to re-join with the given member-id"); + + future.raise(error); + } else if (error == Errors.REBALANCE_IN_PROGRESS) { + log.info("JoinGroup failed due to non-fatal error: REBALANCE_IN_PROGRESS, " + + "which could indicate a replication timeout on the broker. Will retry."); + future.raise(error); + } else { + // unexpected error, throw the exception + log.error("JoinGroup failed due to unexpected error: {}", error.message()); + future.raise(new KafkaException("Unexpected error in join group response: " + error.message())); + } + } + } + + private RequestFuture onJoinFollower() { + // send follower's sync group with an empty assignment + SyncGroupRequest.Builder requestBuilder = + new SyncGroupRequest.Builder( + new SyncGroupRequestData() + .setGroupId(rebalanceConfig.groupId) + .setMemberId(generation.memberId) + .setProtocolType(protocolType()) + .setProtocolName(generation.protocolName) + .setGroupInstanceId(this.rebalanceConfig.groupInstanceId.orElse(null)) + .setGenerationId(generation.generationId) + .setAssignments(Collections.emptyList()) + ); + log.debug("Sending follower SyncGroup to coordinator {} at generation {}: {}", this.coordinator, this.generation, requestBuilder); + return sendSyncGroupRequest(requestBuilder); + } + + private RequestFuture onJoinLeader(JoinGroupResponse joinResponse) { + try { + // perform the leader synchronization and send back the assignment for the group + Map groupAssignment = performAssignment(joinResponse.data().leader(), joinResponse.data().protocolName(), + joinResponse.data().members()); + + List groupAssignmentList = new ArrayList<>(); + for (Map.Entry assignment : groupAssignment.entrySet()) { + groupAssignmentList.add(new SyncGroupRequestData.SyncGroupRequestAssignment() + .setMemberId(assignment.getKey()) + .setAssignment(Utils.toArray(assignment.getValue())) + ); + } + + SyncGroupRequest.Builder requestBuilder = + new SyncGroupRequest.Builder( + new SyncGroupRequestData() + .setGroupId(rebalanceConfig.groupId) + .setMemberId(generation.memberId) + .setProtocolType(protocolType()) + .setProtocolName(generation.protocolName) + .setGroupInstanceId(this.rebalanceConfig.groupInstanceId.orElse(null)) + .setGenerationId(generation.generationId) + .setAssignments(groupAssignmentList) + ); + log.debug("Sending leader SyncGroup to coordinator {} at generation {}: {}", this.coordinator, this.generation, requestBuilder); + return sendSyncGroupRequest(requestBuilder); + } catch (RuntimeException e) { + return RequestFuture.failure(e); + } + } + + private RequestFuture sendSyncGroupRequest(SyncGroupRequest.Builder requestBuilder) { + if (coordinatorUnknown()) + return RequestFuture.coordinatorNotAvailable(); + return client.send(coordinator, requestBuilder) + .compose(new SyncGroupResponseHandler(generation)); + } + + private class SyncGroupResponseHandler extends CoordinatorResponseHandler { + private SyncGroupResponseHandler(final Generation generation) { + super(generation); + } + + @Override + public void handle(SyncGroupResponse syncResponse, + RequestFuture future) { + Errors error = syncResponse.error(); + if (error == Errors.NONE) { + if (isProtocolTypeInconsistent(syncResponse.data().protocolType())) { + log.error("SyncGroup failed due to inconsistent Protocol Type, received {} but expected {}", + 
syncResponse.data().protocolType(), protocolType()); + future.raise(Errors.INCONSISTENT_GROUP_PROTOCOL); + } else { + log.debug("Received successful SyncGroup response: {}", syncResponse); + sensors.syncSensor.record(response.requestLatencyMs()); + + synchronized (AbstractCoordinator.this) { + if (!generation.equals(Generation.NO_GENERATION) && state == MemberState.COMPLETING_REBALANCE) { + // check protocol name only if the generation is not reset + final String protocolName = syncResponse.data().protocolName(); + final boolean protocolNameInconsistent = protocolName != null && + !protocolName.equals(generation.protocolName); + + if (protocolNameInconsistent) { + log.error("SyncGroup failed due to inconsistent Protocol Name, received {} but expected {}", + protocolName, generation.protocolName); + + future.raise(Errors.INCONSISTENT_GROUP_PROTOCOL); + } else { + log.info("Successfully synced group in generation {}", generation); + state = MemberState.STABLE; + rejoinNeeded = false; + // record rebalance latency + lastRebalanceEndMs = time.milliseconds(); + sensors.successfulRebalanceSensor.record(lastRebalanceEndMs - lastRebalanceStartMs); + lastRebalanceStartMs = -1L; + + future.complete(ByteBuffer.wrap(syncResponse.data().assignment())); + } + } else { + log.info("Generation data was cleared by heartbeat thread to {} and state is now {} before " + + "receiving SyncGroup response, marking this rebalance as failed and retry", + generation, state); + // use ILLEGAL_GENERATION error code to let it retry immediately + future.raise(Errors.ILLEGAL_GENERATION); + } + } + } + } else { + if (error == Errors.GROUP_AUTHORIZATION_FAILED) { + future.raise(GroupAuthorizationException.forGroupId(rebalanceConfig.groupId)); + } else if (error == Errors.REBALANCE_IN_PROGRESS) { + log.info("SyncGroup failed: The group began another rebalance. Need to re-join the group. " + + "Sent generation was {}", sentGeneration); + future.raise(error); + } else if (error == Errors.FENCED_INSTANCE_ID) { + // for sync-group request, even if the generation has changed we would not expect the instance id + // gets fenced, and hence we always treat this as a fatal error + log.error("SyncGroup failed: The group instance id {} has been fenced by another instance. " + + "Sent generation was {}", rebalanceConfig.groupInstanceId, sentGeneration); + future.raise(error); + } else if (error == Errors.UNKNOWN_MEMBER_ID + || error == Errors.ILLEGAL_GENERATION) { + log.info("SyncGroup failed: {} Need to re-join the group. Sent generation was {}", + error.message(), sentGeneration); + if (generationUnchanged()) + resetGenerationOnResponseError(ApiKeys.SYNC_GROUP, error); + + future.raise(error); + } else if (error == Errors.COORDINATOR_NOT_AVAILABLE + || error == Errors.NOT_COORDINATOR) { + log.info("SyncGroup failed: {} Marking coordinator unknown. Sent generation was {}", + error.message(), sentGeneration); + markCoordinatorUnknown(error); + future.raise(error); + } else { + future.raise(new KafkaException("Unexpected error from SyncGroup: " + error.message())); + } + } + } + } + + /** + * Discover the current coordinator for the group. Sends a GroupMetadata request to + * one of the brokers. The returned future should be polled to get the result of the request. 
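+ * Non-retriable failures are recorded in {@code fatalFindCoordinatorException} so that
+ * {@link #ensureCoordinatorReady(Timer)} can rethrow them on the caller's thread.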
+ * @return A request future which indicates the completion of the metadata request + */ + private RequestFuture sendFindCoordinatorRequest(Node node) { + // initiate the group metadata request + log.debug("Sending FindCoordinator request to broker {}", node); + FindCoordinatorRequestData data = new FindCoordinatorRequestData() + .setKeyType(CoordinatorType.GROUP.id()) + .setKey(this.rebalanceConfig.groupId); + FindCoordinatorRequest.Builder requestBuilder = new FindCoordinatorRequest.Builder(data); + return client.send(node, requestBuilder) + .compose(new FindCoordinatorResponseHandler()); + } + + private class FindCoordinatorResponseHandler extends RequestFutureAdapter { + + @Override + public void onSuccess(ClientResponse resp, RequestFuture future) { + log.debug("Received FindCoordinator response {}", resp); + + List coordinators = ((FindCoordinatorResponse) resp.responseBody()).coordinators(); + if (coordinators.size() != 1) { + log.error("Group coordinator lookup failed: Invalid response containing more than a single coordinator"); + future.raise(new IllegalStateException("Group coordinator lookup failed: Invalid response containing more than a single coordinator")); + } + Coordinator coordinatorData = coordinators.get(0); + Errors error = Errors.forCode(coordinatorData.errorCode()); + if (error == Errors.NONE) { + synchronized (AbstractCoordinator.this) { + // use MAX_VALUE - node.id as the coordinator id to allow separate connections + // for the coordinator in the underlying network client layer + int coordinatorConnectionId = Integer.MAX_VALUE - coordinatorData.nodeId(); + + AbstractCoordinator.this.coordinator = new Node( + coordinatorConnectionId, + coordinatorData.host(), + coordinatorData.port()); + log.info("Discovered group coordinator {}", coordinator); + client.tryConnect(coordinator); + heartbeat.resetSessionTimeout(); + } + future.complete(null); + } else if (error == Errors.GROUP_AUTHORIZATION_FAILED) { + future.raise(GroupAuthorizationException.forGroupId(rebalanceConfig.groupId)); + } else { + log.debug("Group coordinator lookup failed: {}", coordinatorData.errorMessage()); + future.raise(error); + } + } + + @Override + public void onFailure(RuntimeException e, RequestFuture future) { + log.debug("FindCoordinator request failed due to {}", e.toString()); + + if (!(e instanceof RetriableException)) { + // Remember the exception if fatal so we can ensure it gets thrown by the main thread + fatalFindCoordinatorException = e; + } + + super.onFailure(e, future); + } + } + + /** + * Check if we know who the coordinator is and we have an active connection + * @return true if the coordinator is unknown + */ + public boolean coordinatorUnknown() { + return checkAndGetCoordinator() == null; + } + + /** + * Get the coordinator if its connection is still active. Otherwise mark it unknown and + * return null. 
+ * + * @return the current coordinator or null if it is unknown + */ + protected synchronized Node checkAndGetCoordinator() { + if (coordinator != null && client.isUnavailable(coordinator)) { + markCoordinatorUnknown(true, "coordinator unavailable"); + return null; + } + return this.coordinator; + } + + private synchronized Node coordinator() { + return this.coordinator; + } + + + protected synchronized void markCoordinatorUnknown(Errors error) { + markCoordinatorUnknown(false, "error response " + error.name()); + } + + protected synchronized void markCoordinatorUnknown(String cause) { + markCoordinatorUnknown(false, cause); + } + + protected synchronized void markCoordinatorUnknown(boolean isDisconnected, String cause) { + if (this.coordinator != null) { + log.info("Group coordinator {} is unavailable or invalid due to cause: {}." + + "isDisconnected: {}. Rediscovery will be attempted.", this.coordinator, + cause, isDisconnected); + Node oldCoordinator = this.coordinator; + + // Mark the coordinator dead before disconnecting requests since the callbacks for any pending + // requests may attempt to do likewise. This also prevents new requests from being sent to the + // coordinator while the disconnect is in progress. + this.coordinator = null; + + // Disconnect from the coordinator to ensure that there are no in-flight requests remaining. + // Pending callbacks will be invoked with a DisconnectException on the next call to poll. + if (!isDisconnected) { + log.info("Requesting disconnect from last known coordinator {}", oldCoordinator); + client.disconnectAsync(oldCoordinator); + } + + lastTimeOfConnectionMs = time.milliseconds(); + } else { + long durationOfOngoingDisconnect = time.milliseconds() - lastTimeOfConnectionMs; + if (durationOfOngoingDisconnect > rebalanceConfig.rebalanceTimeoutMs) + log.warn("Consumer has been disconnected from the group coordinator for {}ms", durationOfOngoingDisconnect); + } + } + + /** + * Get the current generation state, regardless of whether it is currently stable. + * Note that the generation information can be updated while we are still in the middle + * of a rebalance, after the join-group response is received. 
+ * + * @return the current generation + */ + protected synchronized Generation generation() { + return generation; + } + + /** + * Get the current generation state if the group is stable, otherwise return null + * + * @return the current generation or null + */ + protected synchronized Generation generationIfStable() { + if (this.state != MemberState.STABLE) + return null; + return generation; + } + + protected synchronized boolean rebalanceInProgress() { + return this.state == MemberState.PREPARING_REBALANCE || this.state == MemberState.COMPLETING_REBALANCE; + } + + protected synchronized String memberId() { + return generation.memberId; + } + + private synchronized void resetStateAndGeneration(final String reason) { + log.info("Resetting generation due to: {}", reason); + + state = MemberState.UNJOINED; + generation = Generation.NO_GENERATION; + } + + private synchronized void resetStateAndRejoin(final String reason) { + resetStateAndGeneration(reason); + requestRejoin(reason); + needsJoinPrepare = true; + } + + synchronized void resetGenerationOnResponseError(ApiKeys api, Errors error) { + final String reason = String.format("encountered %s from %s response", error, api); + resetStateAndRejoin(reason); + } + + synchronized void resetGenerationOnLeaveGroup() { + resetStateAndRejoin("consumer pro-actively leaving the group"); + } + + public synchronized void requestRejoinIfNecessary(final String reason) { + if (!this.rejoinNeeded) { + requestRejoin(reason); + } + } + + public synchronized void requestRejoin(final String reason) { + log.info("Request joining group due to: {}", reason); + this.rejoinNeeded = true; + } + + private boolean isProtocolTypeInconsistent(String protocolType) { + return protocolType != null && !protocolType.equals(protocolType()); + } + + /** + * Close the coordinator, waiting if needed to send LeaveGroup. + */ + @Override + public final void close() { + close(time.timer(0)); + } + + /** + * @throws KafkaException if the rebalance callback throws exception + */ + protected void close(Timer timer) { + try { + closeHeartbeatThread(); + } finally { + // Synchronize after closing the heartbeat thread since heartbeat thread + // needs this lock to complete and terminate after close flag is set. + synchronized (this) { + if (rebalanceConfig.leaveGroupOnClose) { + onLeavePrepare(); + maybeLeaveGroup("the consumer is being closed"); + } + + // At this point, there may be pending commits (async commits or sync commits that were + // interrupted using wakeup) and the leave group request which have been queued, but not + // yet sent to the broker. Wait up to close timeout for these pending requests to be processed. + // If coordinator is not known, requests are aborted. + Node coordinator = checkAndGetCoordinator(); + if (coordinator != null && !client.awaitPendingRequests(coordinator, timer)) + log.warn("Close timed out with {} pending requests to coordinator, terminating client connections", + client.pendingRequestCount(coordinator)); + } + } + } + + /** + * Sends LeaveGroupRequest and logs the {@code leaveReason}, unless this member is using static membership or is already + * not part of the group (ie does not have a valid member id, is in the UNJOINED state, or the coordinator is unknown). 
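+ * Static members (those configured with a group.instance.id) never send LeaveGroup; their
+ * membership is only expired by the broker via the session timeout.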
+ * + * @param leaveReason the reason to leave the group for logging + * @throws KafkaException if the rebalance callback throws exception + */ + public synchronized RequestFuture maybeLeaveGroup(String leaveReason) { + RequestFuture future = null; + + // Starting from 2.3, only dynamic members will send LeaveGroupRequest to the broker, + // consumer with valid group.instance.id is viewed as static member that never sends LeaveGroup, + // and the membership expiration is only controlled by session timeout. + if (isDynamicMember() && !coordinatorUnknown() && + state != MemberState.UNJOINED && generation.hasMemberId()) { + // this is a minimal effort attempt to leave the group. we do not + // attempt any resending if the request fails or times out. + log.info("Member {} sending LeaveGroup request to coordinator {} due to {}", + generation.memberId, coordinator, leaveReason); + LeaveGroupRequest.Builder request = new LeaveGroupRequest.Builder( + rebalanceConfig.groupId, + Collections.singletonList(new MemberIdentity().setMemberId(generation.memberId)) + ); + + future = client.send(coordinator, request).compose(new LeaveGroupResponseHandler(generation)); + client.pollNoWakeup(); + } + + resetGenerationOnLeaveGroup(); + + return future; + } + + protected boolean isDynamicMember() { + return !rebalanceConfig.groupInstanceId.isPresent(); + } + + private class LeaveGroupResponseHandler extends CoordinatorResponseHandler { + private LeaveGroupResponseHandler(final Generation generation) { + super(generation); + } + + @Override + public void handle(LeaveGroupResponse leaveResponse, RequestFuture future) { + final List members = leaveResponse.memberResponses(); + if (members.size() > 1) { + future.raise(new IllegalStateException("The expected leave group response " + + "should only contain no more than one member info, however get " + members)); + } + + final Errors error = leaveResponse.error(); + if (error == Errors.NONE) { + log.debug("LeaveGroup response with {} returned successfully: {}", sentGeneration, response); + future.complete(null); + } else { + log.error("LeaveGroup request with {} failed with error: {}", sentGeneration, error.message()); + future.raise(error); + } + } + } + + // visible for testing + synchronized RequestFuture sendHeartbeatRequest() { + log.debug("Sending Heartbeat request with generation {} and member id {} to coordinator {}", + generation.generationId, generation.memberId, coordinator); + HeartbeatRequest.Builder requestBuilder = + new HeartbeatRequest.Builder(new HeartbeatRequestData() + .setGroupId(rebalanceConfig.groupId) + .setMemberId(this.generation.memberId) + .setGroupInstanceId(this.rebalanceConfig.groupInstanceId.orElse(null)) + .setGenerationId(this.generation.generationId)); + return client.send(coordinator, requestBuilder) + .compose(new HeartbeatResponseHandler(generation)); + } + + private class HeartbeatResponseHandler extends CoordinatorResponseHandler { + private HeartbeatResponseHandler(final Generation generation) { + super(generation); + } + + @Override + public void handle(HeartbeatResponse heartbeatResponse, RequestFuture future) { + sensors.heartbeatSensor.record(response.requestLatencyMs()); + Errors error = heartbeatResponse.error(); + + if (error == Errors.NONE) { + log.debug("Received successful Heartbeat response"); + future.complete(null); + } else if (error == Errors.COORDINATOR_NOT_AVAILABLE + || error == Errors.NOT_COORDINATOR) { + log.info("Attempt to heartbeat failed since coordinator {} is either not started or not valid", + 
coordinator()); + markCoordinatorUnknown(error); + future.raise(error); + } else if (error == Errors.REBALANCE_IN_PROGRESS) { + // since we may be sending the request during rebalance, we should check + // this case and ignore the REBALANCE_IN_PROGRESS error + synchronized (AbstractCoordinator.this) { + if (state == MemberState.STABLE) { + requestRejoin("group is already rebalancing"); + future.raise(error); + } else { + log.debug("Ignoring heartbeat response with error {} during {} state", error, state); + future.complete(null); + } + } + } else if (error == Errors.ILLEGAL_GENERATION || + error == Errors.UNKNOWN_MEMBER_ID || + error == Errors.FENCED_INSTANCE_ID) { + if (generationUnchanged()) { + log.info("Attempt to heartbeat with {} and group instance id {} failed due to {}, resetting generation", + sentGeneration, rebalanceConfig.groupInstanceId, error); + resetGenerationOnResponseError(ApiKeys.HEARTBEAT, error); + future.raise(error); + } else { + // if the generation has changed, then ignore this error + log.info("Attempt to heartbeat with stale {} and group instance id {} failed due to {}, ignoring the error", + sentGeneration, rebalanceConfig.groupInstanceId, error); + future.complete(null); + } + } else if (error == Errors.GROUP_AUTHORIZATION_FAILED) { + future.raise(GroupAuthorizationException.forGroupId(rebalanceConfig.groupId)); + } else { + future.raise(new KafkaException("Unexpected error in heartbeat response: " + error.message())); + } + } + } + + protected abstract class CoordinatorResponseHandler extends RequestFutureAdapter { + CoordinatorResponseHandler(final Generation generation) { + this.sentGeneration = generation; + } + + final Generation sentGeneration; + ClientResponse response; + + public abstract void handle(R response, RequestFuture future); + + @Override + public void onFailure(RuntimeException e, RequestFuture future) { + // mark the coordinator as dead + if (e instanceof DisconnectException) { + markCoordinatorUnknown(true, e.getMessage()); + } + future.raise(e); + } + + @Override + @SuppressWarnings("unchecked") + public void onSuccess(ClientResponse clientResponse, RequestFuture future) { + try { + this.response = clientResponse; + R responseObj = (R) clientResponse.responseBody(); + handle(responseObj, future); + } catch (RuntimeException e) { + if (!future.isDone()) + future.raise(e); + } + } + + boolean generationUnchanged() { + synchronized (AbstractCoordinator.this) { + return generation.equals(sentGeneration); + } + } + } + + protected Meter createMeter(Metrics metrics, String groupName, String baseName, String descriptiveName) { + return new Meter(new WindowedCount(), + metrics.metricName(baseName + "-rate", groupName, + String.format("The number of %s per second", descriptiveName)), + metrics.metricName(baseName + "-total", groupName, + String.format("The total number of %s", descriptiveName))); + } + + private class GroupCoordinatorMetrics { + public final String metricGrpName; + + public final Sensor heartbeatSensor; + public final Sensor joinSensor; + public final Sensor syncSensor; + public final Sensor successfulRebalanceSensor; + public final Sensor failedRebalanceSensor; + + public GroupCoordinatorMetrics(Metrics metrics, String metricGrpPrefix) { + this.metricGrpName = metricGrpPrefix + "-coordinator-metrics"; + + this.heartbeatSensor = metrics.sensor("heartbeat-latency"); + this.heartbeatSensor.add(metrics.metricName("heartbeat-response-time-max", + this.metricGrpName, + "The max time taken to receive a response to a heartbeat request"), 
new Max()); + this.heartbeatSensor.add(createMeter(metrics, metricGrpName, "heartbeat", "heartbeats")); + + this.joinSensor = metrics.sensor("join-latency"); + this.joinSensor.add(metrics.metricName("join-time-avg", + this.metricGrpName, + "The average time taken for a group rejoin"), new Avg()); + this.joinSensor.add(metrics.metricName("join-time-max", + this.metricGrpName, + "The max time taken for a group rejoin"), new Max()); + this.joinSensor.add(createMeter(metrics, metricGrpName, "join", "group joins")); + + this.syncSensor = metrics.sensor("sync-latency"); + this.syncSensor.add(metrics.metricName("sync-time-avg", + this.metricGrpName, + "The average time taken for a group sync"), new Avg()); + this.syncSensor.add(metrics.metricName("sync-time-max", + this.metricGrpName, + "The max time taken for a group sync"), new Max()); + this.syncSensor.add(createMeter(metrics, metricGrpName, "sync", "group syncs")); + + this.successfulRebalanceSensor = metrics.sensor("rebalance-latency"); + this.successfulRebalanceSensor.add(metrics.metricName("rebalance-latency-avg", + this.metricGrpName, + "The average time taken for a group to complete a successful rebalance, which may be composed of " + + "several failed re-trials until it succeeded"), new Avg()); + this.successfulRebalanceSensor.add(metrics.metricName("rebalance-latency-max", + this.metricGrpName, + "The max time taken for a group to complete a successful rebalance, which may be composed of " + + "several failed re-trials until it succeeded"), new Max()); + this.successfulRebalanceSensor.add(metrics.metricName("rebalance-latency-total", + this.metricGrpName, + "The total number of milliseconds this consumer has spent in successful rebalances since creation"), + new CumulativeSum()); + this.successfulRebalanceSensor.add( + metrics.metricName("rebalance-total", + this.metricGrpName, + "The total number of successful rebalance events, each event is composed of " + + "several failed re-trials until it succeeded"), + new CumulativeCount() + ); + this.successfulRebalanceSensor.add( + metrics.metricName( + "rebalance-rate-per-hour", + this.metricGrpName, + "The number of successful rebalance events per hour, each event is composed of " + + "several failed re-trials until it succeeded"), + new Rate(TimeUnit.HOURS, new WindowedCount()) + ); + + this.failedRebalanceSensor = metrics.sensor("failed-rebalance"); + this.failedRebalanceSensor.add( + metrics.metricName("failed-rebalance-total", + this.metricGrpName, + "The total number of failed rebalance events"), + new CumulativeCount() + ); + this.failedRebalanceSensor.add( + metrics.metricName( + "failed-rebalance-rate-per-hour", + this.metricGrpName, + "The number of failed rebalance events per hour"), + new Rate(TimeUnit.HOURS, new WindowedCount()) + ); + + Measurable lastRebalance = (config, now) -> { + if (lastRebalanceEndMs == -1L) + // if no rebalance is ever triggered, we just return -1. + return -1d; + else + return TimeUnit.SECONDS.convert(now - lastRebalanceEndMs, TimeUnit.MILLISECONDS); + }; + metrics.addMetric(metrics.metricName("last-rebalance-seconds-ago", + this.metricGrpName, + "The number of seconds since the last successful rebalance event"), + lastRebalance); + + Measurable lastHeartbeat = (config, now) -> { + if (heartbeat.lastHeartbeatSend() == 0L) + // if no heartbeat is ever triggered, just return -1. 
+ return -1d; + else + return TimeUnit.SECONDS.convert(now - heartbeat.lastHeartbeatSend(), TimeUnit.MILLISECONDS); + }; + metrics.addMetric(metrics.metricName("last-heartbeat-seconds-ago", + this.metricGrpName, + "The number of seconds since the last coordinator heartbeat was sent"), + lastHeartbeat); + } + } + + private class HeartbeatThread extends KafkaThread implements AutoCloseable { + private boolean enabled = false; + private boolean closed = false; + private final AtomicReference failed = new AtomicReference<>(null); + + private HeartbeatThread() { + super(HEARTBEAT_THREAD_PREFIX + (rebalanceConfig.groupId.isEmpty() ? "" : " | " + rebalanceConfig.groupId), true); + } + + public void enable() { + synchronized (AbstractCoordinator.this) { + log.debug("Enabling heartbeat thread"); + this.enabled = true; + heartbeat.resetTimeouts(); + AbstractCoordinator.this.notify(); + } + } + + public void disable() { + synchronized (AbstractCoordinator.this) { + log.debug("Disabling heartbeat thread"); + this.enabled = false; + } + } + + public void close() { + synchronized (AbstractCoordinator.this) { + this.closed = true; + AbstractCoordinator.this.notify(); + } + } + + private boolean hasFailed() { + return failed.get() != null; + } + + private RuntimeException failureCause() { + return failed.get(); + } + + @Override + public void run() { + try { + log.debug("Heartbeat thread started"); + while (true) { + synchronized (AbstractCoordinator.this) { + if (closed) + return; + + if (!enabled) { + AbstractCoordinator.this.wait(); + continue; + } + + // we do not need to heartbeat we are not part of a group yet; + // also if we already have fatal error, the client will be + // crashed soon, hence we do not need to continue heartbeating either + if (state.hasNotJoinedGroup() || hasFailed()) { + disable(); + continue; + } + + client.pollNoWakeup(); + long now = time.milliseconds(); + + if (coordinatorUnknown()) { + if (findCoordinatorFuture != null) { + // clear the future so that after the backoff, if the hb still sees coordinator unknown in + // the next iteration it will try to re-discover the coordinator in case the main thread cannot + clearFindCoordinatorFuture(); + + // backoff properly + AbstractCoordinator.this.wait(rebalanceConfig.retryBackoffMs); + } else { + lookupCoordinator(); + } + } else if (heartbeat.sessionTimeoutExpired(now)) { + // the session timeout has expired without seeing a successful heartbeat, so we should + // probably make sure the coordinator is still healthy. + markCoordinatorUnknown("session timed out without receiving a " + + "heartbeat response"); + } else if (heartbeat.pollTimeoutExpired(now)) { + // the poll timeout has expired, which means that the foreground thread has stalled + // in between calls to poll(). + log.warn("consumer poll timeout has expired. This means the time between subsequent calls to poll() " + + "was longer than the configured max.poll.interval.ms, which typically implies that " + + "the poll loop is spending too much time processing messages. 
You can address this " + + "either by increasing max.poll.interval.ms or by reducing the maximum size of batches " + + "returned in poll() with max.poll.records."); + + maybeLeaveGroup("consumer poll timeout has expired."); + } else if (!heartbeat.shouldHeartbeat(now)) { + // poll again after waiting for the retry backoff in case the heartbeat failed or the + // coordinator disconnected + AbstractCoordinator.this.wait(rebalanceConfig.retryBackoffMs); + } else { + heartbeat.sentHeartbeat(now); + final RequestFuture heartbeatFuture = sendHeartbeatRequest(); + heartbeatFuture.addListener(new RequestFutureListener() { + @Override + public void onSuccess(Void value) { + synchronized (AbstractCoordinator.this) { + heartbeat.receiveHeartbeat(); + } + } + + @Override + public void onFailure(RuntimeException e) { + synchronized (AbstractCoordinator.this) { + if (e instanceof RebalanceInProgressException) { + // it is valid to continue heartbeating while the group is rebalancing. This + // ensures that the coordinator keeps the member in the group for as long + // as the duration of the rebalance timeout. If we stop sending heartbeats, + // however, then the session timeout may expire before we can rejoin. + heartbeat.receiveHeartbeat(); + } else if (e instanceof FencedInstanceIdException) { + log.error("Caught fenced group.instance.id {} error in heartbeat thread", rebalanceConfig.groupInstanceId); + heartbeatThread.failed.set(e); + } else { + heartbeat.failHeartbeat(); + // wake up the thread if it's sleeping to reschedule the heartbeat + AbstractCoordinator.this.notify(); + } + } + } + }); + } + } + } + } catch (AuthenticationException e) { + log.error("An authentication error occurred in the heartbeat thread", e); + this.failed.set(e); + } catch (GroupAuthorizationException e) { + log.error("A group authorization error occurred in the heartbeat thread", e); + this.failed.set(e); + } catch (InterruptedException | InterruptException e) { + Thread.interrupted(); + log.error("Unexpected interrupt received in heartbeat thread", e); + this.failed.set(new RuntimeException(e)); + } catch (Throwable e) { + log.error("Heartbeat thread failed due to unexpected error", e); + if (e instanceof RuntimeException) + this.failed.set((RuntimeException) e); + else + this.failed.set(new RuntimeException(e)); + } finally { + log.debug("Heartbeat thread has closed"); + } + } + + } + + protected static class Generation { + public static final Generation NO_GENERATION = new Generation( + OffsetCommitRequest.DEFAULT_GENERATION_ID, + JoinGroupRequest.UNKNOWN_MEMBER_ID, + null); + + public final int generationId; + public final String memberId; + public final String protocolName; + + public Generation(int generationId, String memberId, String protocolName) { + this.generationId = generationId; + this.memberId = memberId; + this.protocolName = protocolName; + } + + /** + * @return true if this generation has a valid member id, false otherwise. A member might have an id before + * it becomes part of a group generation. 
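+ * For example, after a MEMBER_ID_REQUIRED join error the broker-assigned member id is stored
+ * while the generation id is still unknown.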
+ */ + public boolean hasMemberId() { + return !memberId.isEmpty(); + } + + @Override + public boolean equals(final Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + final Generation that = (Generation) o; + return generationId == that.generationId && + Objects.equals(memberId, that.memberId) && + Objects.equals(protocolName, that.protocolName); + } + + @Override + public int hashCode() { + return Objects.hash(generationId, memberId, protocolName); + } + + @Override + public String toString() { + return "Generation{" + + "generationId=" + generationId + + ", memberId='" + memberId + '\'' + + ", protocol='" + protocolName + '\'' + + '}'; + } + } + + @SuppressWarnings("serial") + private static class UnjoinedGroupException extends RetriableException { + + } + + // For testing only below + final Heartbeat heartbeat() { + return heartbeat; + } + + final synchronized void setLastRebalanceTime(final long timestamp) { + lastRebalanceEndMs = timestamp; + } + + /** + * Check whether given generation id is matching the record within current generation. + * + * @param generationId generation id + * @return true if the two ids are matching. + */ + final boolean hasMatchingGenerationId(int generationId) { + return !generation.equals(Generation.NO_GENERATION) && generation.generationId == generationId; + } + + final boolean hasUnknownGeneration() { + return generation.equals(Generation.NO_GENERATION); + } + + /** + * @return true if the current generation's member ID is valid, false otherwise + */ + final boolean hasValidMemberId() { + return !hasUnknownGeneration() && generation.hasMemberId(); + } + + final synchronized void setNewGeneration(final Generation generation) { + this.generation = generation; + } + + final synchronized void setNewState(final MemberState state) { + this.state = state; + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/AbstractPartitionAssignor.java b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/AbstractPartitionAssignor.java new file mode 100644 index 0000000..ed0282b --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/AbstractPartitionAssignor.java @@ -0,0 +1,132 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.clients.consumer.internals; + +import org.apache.kafka.clients.consumer.ConsumerPartitionAssignor; +import org.apache.kafka.common.Cluster; +import org.apache.kafka.common.TopicPartition; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; + +/** + * Abstract assignor implementation which does some common grunt work (in particular collecting + * partition counts which are always needed in assignors). + */ +public abstract class AbstractPartitionAssignor implements ConsumerPartitionAssignor { + private static final Logger log = LoggerFactory.getLogger(AbstractPartitionAssignor.class); + + /** + * Perform the group assignment given the partition counts and member subscriptions + * @param partitionsPerTopic The number of partitions for each subscribed topic. Topics not in metadata will be excluded + * from this map. + * @param subscriptions Map from the member id to their respective topic subscription + * @return Map from each member to the list of partitions assigned to them. + */ + public abstract Map> assign(Map partitionsPerTopic, + Map subscriptions); + + @Override + public GroupAssignment assign(Cluster metadata, GroupSubscription groupSubscription) { + Map subscriptions = groupSubscription.groupSubscription(); + Set allSubscribedTopics = new HashSet<>(); + for (Map.Entry subscriptionEntry : subscriptions.entrySet()) + allSubscribedTopics.addAll(subscriptionEntry.getValue().topics()); + + Map partitionsPerTopic = new HashMap<>(); + for (String topic : allSubscribedTopics) { + Integer numPartitions = metadata.partitionCountForTopic(topic); + if (numPartitions != null && numPartitions > 0) + partitionsPerTopic.put(topic, numPartitions); + else + log.debug("Skipping assignment for topic {} since no metadata is available", topic); + } + + Map> rawAssignments = assign(partitionsPerTopic, subscriptions); + + // this class maintains no user data, so just wrap the results + Map assignments = new HashMap<>(); + for (Map.Entry> assignmentEntry : rawAssignments.entrySet()) + assignments.put(assignmentEntry.getKey(), new Assignment(assignmentEntry.getValue())); + return new GroupAssignment(assignments); + } + + protected static void put(Map> map, K key, V value) { + List list = map.computeIfAbsent(key, k -> new ArrayList<>()); + list.add(value); + } + + protected static List partitions(String topic, int numPartitions) { + List partitions = new ArrayList<>(numPartitions); + for (int i = 0; i < numPartitions; i++) + partitions.add(new TopicPartition(topic, i)); + return partitions; + } + + public static class MemberInfo implements Comparable { + public final String memberId; + public final Optional groupInstanceId; + + public MemberInfo(String memberId, Optional groupInstanceId) { + this.memberId = memberId; + this.groupInstanceId = groupInstanceId; + } + + @Override + public int compareTo(MemberInfo otherMemberInfo) { + if (this.groupInstanceId.isPresent() && + otherMemberInfo.groupInstanceId.isPresent()) { + return this.groupInstanceId.get() + .compareTo(otherMemberInfo.groupInstanceId.get()); + } else if (this.groupInstanceId.isPresent()) { + return -1; + } else if (otherMemberInfo.groupInstanceId.isPresent()) { + return 1; + } else { + return this.memberId.compareTo(otherMemberInfo.memberId); + } + } + + @Override + public boolean equals(Object o) { + return o instanceof 
MemberInfo && this.memberId.equals(((MemberInfo) o).memberId); + } + + /** + * We could just use member.id to be the hashcode, since it's unique + * across the group. + */ + @Override + public int hashCode() { + return memberId.hashCode(); + } + + @Override + public String toString() { + return "MemberInfo [member.id: " + memberId + + ", group.instance.id: " + groupInstanceId.orElse("{}") + + "]"; + } + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/AbstractStickyAssignor.java b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/AbstractStickyAssignor.java new file mode 100644 index 0000000..145c6ee --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/AbstractStickyAssignor.java @@ -0,0 +1,1251 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients.consumer.internals; + +import java.io.Serializable; +import java.util.ArrayList; +import java.util.Collections; +import java.util.Comparator; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Iterator; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; +import java.util.Optional; +import java.util.Set; +import java.util.TreeSet; +import java.util.stream.Collectors; +import org.apache.kafka.common.TopicPartition; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public abstract class AbstractStickyAssignor extends AbstractPartitionAssignor { + private static final Logger log = LoggerFactory.getLogger(AbstractStickyAssignor.class); + + public static final int DEFAULT_GENERATION = -1; + public int maxGeneration = DEFAULT_GENERATION; + + private PartitionMovements partitionMovements; + + // Keep track of the partitions being migrated from one consumer to another during assignment + // so the cooperative assignor can adjust the assignment + protected Map partitionsTransferringOwnership = new HashMap<>(); + + static final class ConsumerGenerationPair { + final String consumer; + final int generation; + ConsumerGenerationPair(String consumer, int generation) { + this.consumer = consumer; + this.generation = generation; + } + } + + public static final class MemberData { + public final List partitions; + public final Optional generation; + public MemberData(List partitions, Optional generation) { + this.partitions = partitions; + this.generation = generation; + } + } + + abstract protected MemberData memberData(Subscription subscription); + + @Override + public Map> assign(Map partitionsPerTopic, + Map subscriptions) { + Map> consumerToOwnedPartitions = new HashMap<>(); + Set partitionsWithMultiplePreviousOwners = new HashSet<>(); + if (allSubscriptionsEqual(partitionsPerTopic.keySet(), subscriptions, 
consumerToOwnedPartitions, partitionsWithMultiplePreviousOwners)) { + log.debug("Detected that all consumers were subscribed to same set of topics, invoking the " + + "optimized assignment algorithm"); + partitionsTransferringOwnership = new HashMap<>(); + return constrainedAssign(partitionsPerTopic, consumerToOwnedPartitions, partitionsWithMultiplePreviousOwners); + } else { + log.debug("Detected that not all consumers were subscribed to same set of topics, falling back to the " + + "general case assignment algorithm"); + // we must set this to null for the general case so the cooperative assignor knows to compute it from scratch + partitionsTransferringOwnership = null; + return generalAssign(partitionsPerTopic, subscriptions, consumerToOwnedPartitions); + } + } + + /** + * Returns true iff all consumers have an identical subscription. Also fills out the passed in + * {@code consumerToOwnedPartitions} with each consumer's previously owned and still-subscribed partitions, + * and the {@code partitionsWithMultiplePreviousOwners} with any partitions claimed by multiple previous owners + */ + private boolean allSubscriptionsEqual(Set allTopics, + Map subscriptions, + Map> consumerToOwnedPartitions, + Set partitionsWithMultiplePreviousOwners) { + Set membersOfCurrentHighestGeneration = new HashSet<>(); + boolean isAllSubscriptionsEqual = true; + + Set subscribedTopics = new HashSet<>(); + + // keep track of all previously owned partitions so we can invalidate them if invalid input is + // detected, eg two consumers somehow claiming the same partition in the same/current generation + Map allPreviousPartitionsToOwner = new HashMap<>(); + + for (Map.Entry subscriptionEntry : subscriptions.entrySet()) { + String consumer = subscriptionEntry.getKey(); + Subscription subscription = subscriptionEntry.getValue(); + + // initialize the subscribed topics set if this is the first subscription + if (subscribedTopics.isEmpty()) { + subscribedTopics.addAll(subscription.topics()); + } else if (isAllSubscriptionsEqual && !(subscription.topics().size() == subscribedTopics.size() + && subscribedTopics.containsAll(subscription.topics()))) { + isAllSubscriptionsEqual = false; + } + + MemberData memberData = memberData(subscription); + + List ownedPartitions = new ArrayList<>(); + consumerToOwnedPartitions.put(consumer, ownedPartitions); + + // Only consider this consumer's owned partitions as valid if it is a member of the current highest + // generation, or it's generation is not present but we have not seen any known generation so far + if (memberData.generation.isPresent() && memberData.generation.get() >= maxGeneration + || !memberData.generation.isPresent() && maxGeneration == DEFAULT_GENERATION) { + + // If the current member's generation is higher, all the previously owned partitions are invalid + if (memberData.generation.isPresent() && memberData.generation.get() > maxGeneration) { + allPreviousPartitionsToOwner.clear(); + partitionsWithMultiplePreviousOwners.clear(); + for (String droppedOutConsumer : membersOfCurrentHighestGeneration) { + consumerToOwnedPartitions.get(droppedOutConsumer).clear(); + } + + membersOfCurrentHighestGeneration.clear(); + maxGeneration = memberData.generation.get(); + } + + membersOfCurrentHighestGeneration.add(consumer); + for (final TopicPartition tp : memberData.partitions) { + // filter out any topics that no longer exist or aren't part of the current subscription + if (allTopics.contains(tp.topic())) { + String otherConsumer = allPreviousPartitionsToOwner.put(tp, 
consumer); + if (otherConsumer == null) { + // this partition is not owned by other consumer in the same generation + ownedPartitions.add(tp); + } else { + log.error("Found multiple consumers {} and {} claiming the same TopicPartition {} in the " + + "same generation {}, this will be invalidated and removed from their previous assignment.", + consumer, otherConsumer, tp, maxGeneration); + consumerToOwnedPartitions.get(otherConsumer).remove(tp); + partitionsWithMultiplePreviousOwners.add(tp); + } + } + } + } + } + + return isAllSubscriptionsEqual; + } + + /** + * This constrainedAssign optimizes the assignment algorithm when all consumers were subscribed to same set of topics. + * The method includes the following steps: + * + * 1. Reassign previously owned partitions: + * a. if owned less than minQuota partitions, just assign all owned partitions, and put the member into unfilled member list + * b. if owned maxQuota or more, and we're still under the number of expected max capacity members, assign maxQuota partitions + * c. if owned at least "minQuota" of partitions, assign minQuota partitions, and put the member into unfilled member list if + * we're still under the number of expected max capacity members + * 2. Fill remaining members up to the expected numbers of maxQuota partitions, otherwise, to minQuota partitions + * + * @param partitionsPerTopic The number of partitions for each subscribed topic + * @param consumerToOwnedPartitions Each consumer's previously owned and still-subscribed partitions + * @param partitionsWithMultiplePreviousOwners The partitions being claimed in the previous assignment of multiple consumers + * + * @return Map from each member to the list of partitions assigned to them. + */ + private Map> constrainedAssign(Map partitionsPerTopic, + Map> consumerToOwnedPartitions, + Set partitionsWithMultiplePreviousOwners) { + if (log.isDebugEnabled()) { + log.debug("Performing constrained assign with partitionsPerTopic: {}, consumerToOwnedPartitions: {}.", + partitionsPerTopic, consumerToOwnedPartitions); + } + + Set allRevokedPartitions = new HashSet<>(); + + // the consumers which may still be assigned one or more partitions to reach expected capacity + List unfilledMembersWithUnderMinQuotaPartitions = new LinkedList<>(); + LinkedList unfilledMembersWithExactlyMinQuotaPartitions = new LinkedList<>(); + + int numberOfConsumers = consumerToOwnedPartitions.size(); + int totalPartitionsCount = partitionsPerTopic.values().stream().reduce(0, Integer::sum); + + int minQuota = (int) Math.floor(((double) totalPartitionsCount) / numberOfConsumers); + int maxQuota = (int) Math.ceil(((double) totalPartitionsCount) / numberOfConsumers); + // the expected number of members receiving more than minQuota partitions (zero when minQuota == maxQuota) + int expectedNumMembersWithOverMinQuotaPartitions = totalPartitionsCount % numberOfConsumers; + // the current number of members receiving more than minQuota partitions (zero when minQuota == maxQuota) + int currentNumMembersWithOverMinQuotaPartitions = 0; + + // initialize the assignment map with an empty array of size maxQuota for all members + Map> assignment = new HashMap<>( + consumerToOwnedPartitions.keySet().stream().collect(Collectors.toMap(c -> c, c -> new ArrayList<>(maxQuota)))); + + List assignedPartitions = new ArrayList<>(); + // Reassign previously owned partitions, up to the expected number of partitions per consumer + for (Map.Entry> consumerEntry : consumerToOwnedPartitions.entrySet()) { + String consumer = 
consumerEntry.getKey(); + List ownedPartitions = consumerEntry.getValue(); + + List consumerAssignment = assignment.get(consumer); + + for (TopicPartition doublyClaimedPartition : partitionsWithMultiplePreviousOwners) { + if (ownedPartitions.contains(doublyClaimedPartition)) { + log.error("Found partition {} still claimed as owned by consumer {}, despite being claimed by multiple " + + "consumers already in the same generation. Removing it from the ownedPartitions", + doublyClaimedPartition, consumer); + ownedPartitions.remove(doublyClaimedPartition); + } + } + + if (ownedPartitions.size() < minQuota) { + // the expected assignment size is more than this consumer has now, so keep all the owned partitions + // and put this member into the unfilled member list + if (ownedPartitions.size() > 0) { + consumerAssignment.addAll(ownedPartitions); + assignedPartitions.addAll(ownedPartitions); + } + unfilledMembersWithUnderMinQuotaPartitions.add(consumer); + } else if (ownedPartitions.size() >= maxQuota && currentNumMembersWithOverMinQuotaPartitions < expectedNumMembersWithOverMinQuotaPartitions) { + // consumer owned the "maxQuota" of partitions or more, and we're still under the number of expected members + // with more than the minQuota partitions, so keep "maxQuota" of the owned partitions, and revoke the rest of the partitions + currentNumMembersWithOverMinQuotaPartitions++; + if (currentNumMembersWithOverMinQuotaPartitions == expectedNumMembersWithOverMinQuotaPartitions) { + unfilledMembersWithExactlyMinQuotaPartitions.clear(); + } + List maxQuotaPartitions = ownedPartitions.subList(0, maxQuota); + consumerAssignment.addAll(maxQuotaPartitions); + assignedPartitions.addAll(maxQuotaPartitions); + allRevokedPartitions.addAll(ownedPartitions.subList(maxQuota, ownedPartitions.size())); + } else { + // consumer owned at least "minQuota" of partitions + // so keep "minQuota" of the owned partitions, and revoke the rest of the partitions + List minQuotaPartitions = ownedPartitions.subList(0, minQuota); + consumerAssignment.addAll(minQuotaPartitions); + assignedPartitions.addAll(minQuotaPartitions); + allRevokedPartitions.addAll(ownedPartitions.subList(minQuota, ownedPartitions.size())); + // this consumer is potential maxQuota candidate since we're still under the number of expected members + // with more than the minQuota partitions. 
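+ // (for illustration of the quota math above: with 7 partitions and 3 consumers, minQuota = 2, maxQuota = 3, and 7 % 3 = 1 member is expected to end up with maxQuota partitions)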
Note, if the number of expected members with more than + // the minQuota partitions is 0, it means minQuota == maxQuota, and there are no potentially unfilled + if (currentNumMembersWithOverMinQuotaPartitions < expectedNumMembersWithOverMinQuotaPartitions) { + unfilledMembersWithExactlyMinQuotaPartitions.add(consumer); + } + } + } + + List unassignedPartitions = getUnassignedPartitions(totalPartitionsCount, partitionsPerTopic, assignedPartitions); + + if (log.isDebugEnabled()) { + log.debug("After reassigning previously owned partitions, unfilled members: {}, unassigned partitions: {}, " + + "current assignment: {}", unfilledMembersWithUnderMinQuotaPartitions, unassignedPartitions, assignment); + } + + Collections.sort(unfilledMembersWithUnderMinQuotaPartitions); + Collections.sort(unfilledMembersWithExactlyMinQuotaPartitions); + + Iterator unfilledConsumerIter = unfilledMembersWithUnderMinQuotaPartitions.iterator(); + // Round-Robin filling remaining members up to the expected numbers of maxQuota, otherwise, to minQuota + for (TopicPartition unassignedPartition : unassignedPartitions) { + String consumer; + if (unfilledConsumerIter.hasNext()) { + consumer = unfilledConsumerIter.next(); + } else { + if (unfilledMembersWithUnderMinQuotaPartitions.isEmpty() && unfilledMembersWithExactlyMinQuotaPartitions.isEmpty()) { + // Should not enter here since we have calculated the exact number to assign to each consumer. + // This indicates issues in the assignment algorithm + int currentPartitionIndex = unassignedPartitions.indexOf(unassignedPartition); + log.error("No more unfilled consumers to be assigned. The remaining unassigned partitions are: {}", + unassignedPartitions.subList(currentPartitionIndex, unassignedPartitions.size())); + throw new IllegalStateException("No more unfilled consumers to be assigned."); + } else if (unfilledMembersWithUnderMinQuotaPartitions.isEmpty()) { + consumer = unfilledMembersWithExactlyMinQuotaPartitions.poll(); + } else { + unfilledConsumerIter = unfilledMembersWithUnderMinQuotaPartitions.iterator(); + consumer = unfilledConsumerIter.next(); + } + } + + List consumerAssignment = assignment.get(consumer); + consumerAssignment.add(unassignedPartition); + + // We already assigned all possible ownedPartitions, so we know this must be newly assigned to this consumer + // or else the partition was actually claimed by multiple previous owners and had to be invalidated from all + // members claimed ownedPartitions + if (allRevokedPartitions.contains(unassignedPartition) || partitionsWithMultiplePreviousOwners.contains(unassignedPartition)) + partitionsTransferringOwnership.put(unassignedPartition, consumer); + + int currentAssignedCount = consumerAssignment.size(); + if (currentAssignedCount == minQuota) { + unfilledConsumerIter.remove(); + unfilledMembersWithExactlyMinQuotaPartitions.add(consumer); + } else if (currentAssignedCount == maxQuota) { + currentNumMembersWithOverMinQuotaPartitions++; + if (currentNumMembersWithOverMinQuotaPartitions == expectedNumMembersWithOverMinQuotaPartitions) { + // We only start to iterate over the "potentially unfilled" members at minQuota after we've filled + // all members up to at least minQuota, so once the last minQuota member reaches maxQuota, we + // should be done. 
But in case of some algorithmic error, just log a warning and continue to + // assign any remaining partitions within the assignment constraints + if (unassignedPartitions.indexOf(unassignedPartition) != unassignedPartitions.size() - 1) { + log.error("Filled the last member up to maxQuota but still had partitions remaining to assign, " + + "will continue but this indicates a bug in the assignment."); + } + } + } + } + + if (!unfilledMembersWithUnderMinQuotaPartitions.isEmpty()) { + // we expected all the remaining unfilled members have minQuota partitions and we're already at the expected number + // of members with more than the minQuota partitions. Otherwise, there must be error here. + if (currentNumMembersWithOverMinQuotaPartitions != expectedNumMembersWithOverMinQuotaPartitions) { + log.error("Current number of members with more than the minQuota partitions: {}, is less than the expected number " + + "of members with more than the minQuota partitions: {}, and no more partitions to be assigned to the remaining unfilled consumers: {}", + currentNumMembersWithOverMinQuotaPartitions, expectedNumMembersWithOverMinQuotaPartitions, unfilledMembersWithUnderMinQuotaPartitions); + throw new IllegalStateException("We haven't reached the expected number of members with " + + "more than the minQuota partitions, but no more partitions to be assigned"); + } else { + for (String unfilledMember : unfilledMembersWithUnderMinQuotaPartitions) { + int assignedPartitionsCount = assignment.get(unfilledMember).size(); + if (assignedPartitionsCount != minQuota) { + log.error("Consumer: [{}] should have {} partitions, but got {} partitions, and no more partitions " + + "to be assigned. The remaining unfilled consumers are: {}", unfilledMember, minQuota, assignedPartitionsCount, unfilledMembersWithUnderMinQuotaPartitions); + throw new IllegalStateException(String.format("Consumer: [%s] doesn't reach minQuota partitions, " + + "and no more partitions to be assigned", unfilledMember)); + } else { + log.trace("skip over this unfilled member: [{}] because we've reached the expected number of " + + "members with more than the minQuota partitions, and this member already have minQuota partitions", unfilledMember); + } + } + } + } + + log.info("Final assignment of partitions to consumers: \n{}", assignment); + + return assignment; + } + + + private List getAllTopicPartitions(Map partitionsPerTopic, + List sortedAllTopics, + int totalPartitionsCount) { + List allPartitions = new ArrayList<>(totalPartitionsCount); + + for (String topic : sortedAllTopics) { + int partitionCount = partitionsPerTopic.get(topic); + for (int i = 0; i < partitionCount; ++i) { + allPartitions.add(new TopicPartition(topic, i)); + } + } + return allPartitions; + } + + /** + * This generalAssign algorithm guarantees the assignment that is as balanced as possible. + * This method includes the following steps: + * + * 1. Preserving all the existing partition assignments + * 2. Removing all the partition assignments that have become invalid due to the change that triggers the reassignment + * 3. Assigning the unassigned partitions in a way that balances out the overall assignments of partitions to consumers + * 4. Further balancing out the resulting assignment by finding the partitions that can be reassigned + * to another consumer towards an overall more balanced assignment. + * + * @param partitionsPerTopic The number of partitions for each subscribed topic. 
+ * @param subscriptions Map from the member id to their respective topic subscription + * @param currentAssignment Each consumer's previously owned and still-subscribed partitions + * + * @return Map from each member to the list of partitions assigned to them. + */ + private Map> generalAssign(Map partitionsPerTopic, + Map subscriptions, + Map> currentAssignment) { + if (log.isDebugEnabled()) { + log.debug("performing general assign. partitionsPerTopic: {}, subscriptions: {}, currentAssignment: {}", + partitionsPerTopic, subscriptions, currentAssignment); + } + + Map prevAssignment = new HashMap<>(); + partitionMovements = new PartitionMovements(); + + prepopulateCurrentAssignments(subscriptions, prevAssignment); + + // a mapping of all topics to all consumers that can be assigned to them + final Map> topic2AllPotentialConsumers = new HashMap<>(partitionsPerTopic.keySet().size()); + // a mapping of all consumers to all potential topics that can be assigned to them + final Map> consumer2AllPotentialTopics = new HashMap<>(subscriptions.keySet().size()); + + // initialize topic2AllPotentialConsumers and consumer2AllPotentialTopics + partitionsPerTopic.keySet().stream().forEach( + topicName -> topic2AllPotentialConsumers.put(topicName, new ArrayList<>())); + + for (Entry entry: subscriptions.entrySet()) { + String consumerId = entry.getKey(); + List subscribedTopics = new ArrayList<>(entry.getValue().topics().size()); + consumer2AllPotentialTopics.put(consumerId, subscribedTopics); + entry.getValue().topics().stream().filter(topic -> partitionsPerTopic.get(topic) != null).forEach(topic -> { + subscribedTopics.add(topic); + topic2AllPotentialConsumers.get(topic).add(consumerId); + }); + + // add this consumer to currentAssignment (with an empty topic partition assignment) if it does not already exist + if (!currentAssignment.containsKey(consumerId)) + currentAssignment.put(consumerId, new ArrayList<>()); + } + + // a mapping of partition to current consumer + Map currentPartitionConsumer = new HashMap<>(); + for (Map.Entry> entry: currentAssignment.entrySet()) + for (TopicPartition topicPartition: entry.getValue()) + currentPartitionConsumer.put(topicPartition, entry.getKey()); + + int totalPartitionsCount = partitionsPerTopic.values().stream().reduce(0, Integer::sum); + List sortedAllTopics = new ArrayList<>(topic2AllPotentialConsumers.keySet()); + Collections.sort(sortedAllTopics, new TopicComparator(topic2AllPotentialConsumers)); + List sortedAllPartitions = getAllTopicPartitions(partitionsPerTopic, sortedAllTopics, totalPartitionsCount); + + // the partitions already assigned in current assignment + List assignedPartitions = new ArrayList<>(); + boolean revocationRequired = false; + for (Iterator>> it = currentAssignment.entrySet().iterator(); it.hasNext();) { + Map.Entry> entry = it.next(); + Subscription consumerSubscription = subscriptions.get(entry.getKey()); + if (consumerSubscription == null) { + // if a consumer that existed before (and had some partition assignments) is now removed, remove it from currentAssignment + for (TopicPartition topicPartition: entry.getValue()) + currentPartitionConsumer.remove(topicPartition); + it.remove(); + } else { + // otherwise (the consumer still exists) + for (Iterator partitionIter = entry.getValue().iterator(); partitionIter.hasNext();) { + TopicPartition partition = partitionIter.next(); + if (!topic2AllPotentialConsumers.containsKey(partition.topic())) { + // if this topic partition of this consumer no longer exists, remove it from 
currentAssignment of the consumer + partitionIter.remove(); + currentPartitionConsumer.remove(partition); + } else if (!consumerSubscription.topics().contains(partition.topic())) { + // because the consumer is no longer subscribed to its topic, remove it from currentAssignment of the consumer + partitionIter.remove(); + revocationRequired = true; + } else { + // otherwise, remove the topic partition from those that need to be assigned only if + // its current consumer is still subscribed to its topic (because it is already assigned + // and we would want to preserve that assignment as much as possible) + assignedPartitions.add(partition); + } + } + } + } + + // all partitions that needed to be assigned + List unassignedPartitions = getUnassignedPartitions(sortedAllPartitions, assignedPartitions, topic2AllPotentialConsumers); + + if (log.isDebugEnabled()) { + log.debug("unassigned Partitions: {}", unassignedPartitions); + } + + // at this point we have preserved all valid topic partition to consumer assignments and removed + // all invalid topic partitions and invalid consumers. Now we need to assign unassignedPartitions + // to consumers so that the topic partition assignments are as balanced as possible. + + // an ascending sorted set of consumers based on how many topic partitions are already assigned to them + TreeSet sortedCurrentSubscriptions = new TreeSet<>(new SubscriptionComparator(currentAssignment)); + sortedCurrentSubscriptions.addAll(currentAssignment.keySet()); + + balance(currentAssignment, prevAssignment, sortedAllPartitions, unassignedPartitions, sortedCurrentSubscriptions, + consumer2AllPotentialTopics, topic2AllPotentialConsumers, currentPartitionConsumer, revocationRequired, + partitionsPerTopic, totalPartitionsCount); + + log.info("Final assignment of partitions to consumers: \n{}", currentAssignment); + + return currentAssignment; + } + + /** + * get the unassigned partition list by computing the difference set of the sortedPartitions(all partitions) + * and sortedAssignedPartitions. If no assigned partitions, we'll just return all sorted topic partitions. 
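+ * (Illustrative example: with sortedAllPartitions = [t-0, t-1, t-2] and sortedAssignedPartitions = [t-1], the walk below yields the difference [t-0, t-2].)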
+ * This is used in the generalAssign method. + * + * We loop over sortedAllPartitions and compare each element with the ith element of sortedAssignedPartitions (i starts from 0): + * - if not equal to the ith element, add it to unassignedPartitions + * - if equal to the ith element, advance to the next element of sortedAssignedPartitions + * + * @param sortedAllPartitions: all partitions, sorted + * @param sortedAssignedPartitions: sorted assigned partitions, all of which are included in sortedAllPartitions + * @param topic2AllPotentialConsumers: topics mapped to all consumers subscribed to them + * @return partitions that aren't assigned to any current consumer + */ + private List<TopicPartition> getUnassignedPartitions(List<TopicPartition> sortedAllPartitions, + List<TopicPartition> sortedAssignedPartitions, + Map<String, List<String>> topic2AllPotentialConsumers) { + if (sortedAssignedPartitions.isEmpty()) { + return sortedAllPartitions; + } + + List<TopicPartition> unassignedPartitions = new ArrayList<>(); + + Collections.sort(sortedAssignedPartitions, new PartitionComparator(topic2AllPotentialConsumers)); + + boolean shouldAddDirectly = false; + Iterator<TopicPartition> sortedAssignedPartitionsIter = sortedAssignedPartitions.iterator(); + TopicPartition nextAssignedPartition = sortedAssignedPartitionsIter.next(); + + for (TopicPartition topicPartition : sortedAllPartitions) { + if (shouldAddDirectly || !nextAssignedPartition.equals(topicPartition)) { + unassignedPartitions.add(topicPartition); + } else { + // this partition is in assignedPartitions, don't add it to unassignedPartitions, just advance to the next assigned partition + if (sortedAssignedPartitionsIter.hasNext()) { + nextAssignedPartition = sortedAssignedPartitionsIter.next(); + } else { + // add the remaining partitions directly since there are no more sortedAssignedPartitions + shouldAddDirectly = true; + } + } + } + return unassignedPartitions; + } + + /** + * get the unassigned partition list by computing the difference set of all sorted partitions + * and sortedAssignedPartitions. If there are no assigned partitions, we just return all sorted topic partitions. + * This is used in the constrainedAssign method. + * + * To compute the difference set, we use a two-pointer technique: + * + * We loop through all the sorted topics, then iterate over every partition of each topic, + * comparing it with the ith element of sortedAssignedPartitions (i starts from 0): + * - if not equal to the ith element, add it to unassignedPartitions + * - if equal to the ith element, advance to the next element of sortedAssignedPartitions + * + * @param totalPartitionsCount the total partition count in this assignment + * @param partitionsPerTopic the number of partitions for each subscribed topic.
+ * @param sortedAssignedPartitions sorted partitions, all are included in the sortedPartitions + * @return the partitions not yet assigned to any consumers + */ + private List getUnassignedPartitions(int totalPartitionsCount, + Map partitionsPerTopic, + List sortedAssignedPartitions) { + List sortedAllTopics = new ArrayList<>(partitionsPerTopic.keySet()); + // sort all topics first, then we can have sorted all topic partitions by adding partitions starting from 0 + Collections.sort(sortedAllTopics); + + if (sortedAssignedPartitions.isEmpty()) { + // no assigned partitions means all partitions are unassigned partitions + return getAllTopicPartitions(partitionsPerTopic, sortedAllTopics, totalPartitionsCount); + } + + List unassignedPartitions = new ArrayList<>(totalPartitionsCount - sortedAssignedPartitions.size()); + + Collections.sort(sortedAssignedPartitions, Comparator.comparing(TopicPartition::topic).thenComparing(TopicPartition::partition)); + + boolean shouldAddDirectly = false; + Iterator sortedAssignedPartitionsIter = sortedAssignedPartitions.iterator(); + TopicPartition nextAssignedPartition = sortedAssignedPartitionsIter.next(); + + for (String topic : sortedAllTopics) { + int partitionCount = partitionsPerTopic.get(topic); + for (int i = 0; i < partitionCount; i++) { + if (shouldAddDirectly || !(nextAssignedPartition.topic().equals(topic) && nextAssignedPartition.partition() == i)) { + unassignedPartitions.add(new TopicPartition(topic, i)); + } else { + // this partition is in assignedPartitions, don't add to unassignedPartitions, just get next assigned partition + if (sortedAssignedPartitionsIter.hasNext()) { + nextAssignedPartition = sortedAssignedPartitionsIter.next(); + } else { + // add the remaining directly since there is no more sortedAssignedPartitions + shouldAddDirectly = true; + } + } + } + } + + return unassignedPartitions; + } + + /** + * update the prevAssignment with the partitions, consumer and generation in parameters + * + * @param partitions: The partitions to be updated the prevAssignement + * @param consumer: The consumer Id + * @param prevAssignment: The assignment contains the assignment with the 2nd largest generation + * @param generation: The generation of this assignment (partitions) + */ + private void updatePrevAssignment(Map prevAssignment, + List partitions, + String consumer, + int generation) { + for (TopicPartition partition: partitions) { + if (prevAssignment.containsKey(partition)) { + // only keep the latest previous assignment + if (generation > prevAssignment.get(partition).generation) { + prevAssignment.put(partition, new ConsumerGenerationPair(consumer, generation)); + } + } else { + prevAssignment.put(partition, new ConsumerGenerationPair(consumer, generation)); + } + } + } + + /** + * filling in the prevAssignment from the subscriptions. 
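+ * (prevAssignment ends up holding, per partition, the most recent owner from a generation lower than maxGeneration; performReassignments later uses it to prefer moving a partition back to that previous owner.)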
+ * + * @param subscriptions: Map from the member id to their respective topic subscription + * @param prevAssignment: Holds the assignment from the 2nd largest generation + */ + private void prepopulateCurrentAssignments(Map<String, Subscription> subscriptions, + Map<TopicPartition, ConsumerGenerationPair> prevAssignment) { + // we need to process subscriptions' user data with each consumer's reported generation in mind + // higher generations overwrite lower generations in case of a conflict + // note that a conflict could exist only if the user data is for different generations + + for (Map.Entry<String, Subscription> subscriptionEntry: subscriptions.entrySet()) { + String consumer = subscriptionEntry.getKey(); + Subscription subscription = subscriptionEntry.getValue(); + if (subscription.userData() != null) { + // since this is our 2nd time deserializing memberData, rewinding userData is necessary + subscription.userData().rewind(); + } + MemberData memberData = memberData(subscriptionEntry.getValue()); + + // we already have the maxGeneration info, so just compare the current generation of memberData, and put it into prevAssignment + if (memberData.generation.isPresent() && memberData.generation.get() < maxGeneration) { + // if the current member's generation is lower than maxGeneration, put it into prevAssignment if needed + updatePrevAssignment(prevAssignment, memberData.partitions, consumer, memberData.generation.get()); + } else if (!memberData.generation.isPresent() && maxGeneration > DEFAULT_GENERATION) { + // if maxGeneration is larger than DEFAULT_GENERATION, + // put all (no generation) partitions as DEFAULT_GENERATION into prevAssignment if needed + updatePrevAssignment(prevAssignment, memberData.partitions, consumer, DEFAULT_GENERATION); + } + } + } + + /** + * determine whether the current assignment is a balanced one + * + * @param currentAssignment: the assignment whose balance needs to be checked + * @param sortedCurrentSubscriptions: an ascending sorted set of consumers based on how many topic partitions are already assigned to them + * @param allSubscriptions: a mapping of all consumers to all potential topics that can be assigned to them + * @param partitionsPerTopic: The number of partitions for each subscribed topic + * @param totalPartitionCount total partition count to be assigned + * @return true if the given assignment is balanced; false otherwise + */ + private boolean isBalanced(Map<String, List<TopicPartition>> currentAssignment, + TreeSet<String> sortedCurrentSubscriptions, + Map<String, List<String>> allSubscriptions, + Map<String, Integer> partitionsPerTopic, + int totalPartitionCount) { + int min = currentAssignment.get(sortedCurrentSubscriptions.first()).size(); + int max = currentAssignment.get(sortedCurrentSubscriptions.last()).size(); + if (min >= max - 1) + // if the minimum and maximum numbers of partitions assigned to consumers differ by at most one, return true + return true; + + // create a mapping from partitions to the consumer assigned to them + final Map<TopicPartition, String> allPartitions = new HashMap<>(); + Set<Map.Entry<String, List<TopicPartition>>> assignments = currentAssignment.entrySet(); + for (Map.Entry<String, List<TopicPartition>> entry: assignments) { + List<TopicPartition> topicPartitions = entry.getValue(); + for (TopicPartition topicPartition: topicPartitions) { + if (allPartitions.containsKey(topicPartition)) + log.error("{} is assigned to more than one consumer.", topicPartition); + allPartitions.put(topicPartition, entry.getKey()); + } + } + + // for each consumer that does not yet have all the topic partitions it could get, make sure none of the topic partitions it + // could have but did not get can be moved to it (because that would break the balance) + for (String consumer:
sortedCurrentSubscriptions) { + List consumerPartitions = currentAssignment.get(consumer); + int consumerPartitionCount = consumerPartitions.size(); + + // skip if this consumer already has all the topic partitions it can get + List allSubscribedTopics = allSubscriptions.get(consumer); + int maxAssignmentSize = getMaxAssignmentSize(totalPartitionCount, allSubscribedTopics, partitionsPerTopic); + + if (consumerPartitionCount == maxAssignmentSize) + continue; + + // otherwise make sure it cannot get any more + for (String topic: allSubscribedTopics) { + int partitionCount = partitionsPerTopic.get(topic); + for (int i = 0; i < partitionCount; i++) { + TopicPartition topicPartition = new TopicPartition(topic, i); + if (!currentAssignment.get(consumer).contains(topicPartition)) { + String otherConsumer = allPartitions.get(topicPartition); + int otherConsumerPartitionCount = currentAssignment.get(otherConsumer).size(); + if (consumerPartitionCount < otherConsumerPartitionCount) { + log.debug("{} can be moved from consumer {} to consumer {} for a more balanced assignment.", + topicPartition, otherConsumer, consumer); + return false; + } + } + } + } + } + return true; + } + + /** + * get the maximum assigned partition size of the {@code allSubscribedTopics} + * + * @param totalPartitionCount total partition count to be assigned + * @param allSubscribedTopics the subscribed topics of a consumer + * @param partitionsPerTopic The number of partitions for each subscribed topic + * @return maximum assigned partition size + */ + private int getMaxAssignmentSize(int totalPartitionCount, + List allSubscribedTopics, + Map partitionsPerTopic) { + int maxAssignmentSize; + if (allSubscribedTopics.size() == partitionsPerTopic.size()) { + maxAssignmentSize = totalPartitionCount; + } else { + maxAssignmentSize = allSubscribedTopics.stream().map(topic -> partitionsPerTopic.get(topic)).reduce(0, Integer::sum); + } + return maxAssignmentSize; + } + + /** + * @return the balance score of the given assignment, as the sum of assigned partitions size difference of all consumer pairs. + * A perfectly balanced assignment (with all consumers getting the same number of partitions) has a balance score of 0. + * Lower balance score indicates a more balanced assignment. + */ + private int getBalanceScore(Map> assignment) { + int score = 0; + + Map consumer2AssignmentSize = new HashMap<>(); + for (Entry> entry: assignment.entrySet()) + consumer2AssignmentSize.put(entry.getKey(), entry.getValue().size()); + + Iterator> it = consumer2AssignmentSize.entrySet().iterator(); + while (it.hasNext()) { + Entry entry = it.next(); + int consumerAssignmentSize = entry.getValue(); + it.remove(); + for (Entry otherEntry: consumer2AssignmentSize.entrySet()) + score += Math.abs(consumerAssignmentSize - otherEntry.getValue()); + } + + return score; + } + + /** + * The assignment should improve the overall balance of the partition assignments to consumers. 
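+ * (As implemented below, the partition goes to the first consumer in sortedCurrentSubscriptions that is subscribed to its topic, i.e. an eligible consumer with the fewest currently assigned partitions, ties broken by member id; the consumer is removed and re-added so the sorted set stays ordered.)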
+ */ + private void assignPartition(TopicPartition partition, + TreeSet sortedCurrentSubscriptions, + Map> currentAssignment, + Map> consumer2AllPotentialTopics, + Map currentPartitionConsumer) { + for (String consumer: sortedCurrentSubscriptions) { + if (consumer2AllPotentialTopics.get(consumer).contains(partition.topic())) { + sortedCurrentSubscriptions.remove(consumer); + currentAssignment.get(consumer).add(partition); + currentPartitionConsumer.put(partition, consumer); + sortedCurrentSubscriptions.add(consumer); + break; + } + } + } + + private boolean canParticipateInReassignment(String topic, + Map> topic2AllPotentialConsumers) { + // if a topic has two or more potential consumers it is subject to reassignment. + return topic2AllPotentialConsumers.get(topic).size() >= 2; + } + + private boolean canParticipateInReassignment(String consumer, + Map> currentAssignment, + Map> consumer2AllPotentialTopics, + Map> topic2AllPotentialConsumers, + Map partitionsPerTopic, + int totalPartitionCount) { + List currentPartitions = currentAssignment.get(consumer); + int currentAssignmentSize = currentPartitions.size(); + List allSubscribedTopics = consumer2AllPotentialTopics.get(consumer); + int maxAssignmentSize = getMaxAssignmentSize(totalPartitionCount, allSubscribedTopics, partitionsPerTopic); + + if (currentAssignmentSize > maxAssignmentSize) + log.error("The consumer {} is assigned more partitions than the maximum possible.", consumer); + + if (currentAssignmentSize < maxAssignmentSize) + // if a consumer is not assigned all its potential partitions it is subject to reassignment + return true; + + for (TopicPartition partition: currentPartitions) + // if any of the partitions assigned to a consumer is subject to reassignment the consumer itself + // is subject to reassignment + if (canParticipateInReassignment(partition.topic(), topic2AllPotentialConsumers)) + return true; + + return false; + } + + /** + * Balance the current assignment using the data structures created in the assign(...) method above. 
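+ * (Roughly: first assign all still-unassigned partitions, then exclude partitions and consumers that cannot take part in reassignment, try reassignments while preferring a partition's previous owner where one is known, and finally revert to the pre-balance state if the balance score did not improve.)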
+ */ + private void balance(Map> currentAssignment, + Map prevAssignment, + List sortedPartitions, + List unassignedPartitions, + TreeSet sortedCurrentSubscriptions, + Map> consumer2AllPotentialTopics, + Map> topic2AllPotentialConsumers, + Map currentPartitionConsumer, + boolean revocationRequired, + Map partitionsPerTopic, + int totalPartitionCount) { + boolean initializing = currentAssignment.get(sortedCurrentSubscriptions.last()).isEmpty(); + + // assign all unassigned partitions + for (TopicPartition partition: unassignedPartitions) { + // skip if there is no potential consumer for the topic + if (topic2AllPotentialConsumers.get(partition.topic()).isEmpty()) + continue; + + assignPartition(partition, sortedCurrentSubscriptions, currentAssignment, + consumer2AllPotentialTopics, currentPartitionConsumer); + } + + // narrow down the reassignment scope to only those partitions that can actually be reassigned + Set fixedPartitions = new HashSet<>(); + for (String topic: topic2AllPotentialConsumers.keySet()) + if (!canParticipateInReassignment(topic, topic2AllPotentialConsumers)) { + for (int i = 0; i < partitionsPerTopic.get(topic); i++) { + fixedPartitions.add(new TopicPartition(topic, i)); + } + } + sortedPartitions.removeAll(fixedPartitions); + unassignedPartitions.removeAll(fixedPartitions); + + // narrow down the reassignment scope to only those consumers that are subject to reassignment + Map> fixedAssignments = new HashMap<>(); + for (String consumer: consumer2AllPotentialTopics.keySet()) + if (!canParticipateInReassignment(consumer, currentAssignment, + consumer2AllPotentialTopics, topic2AllPotentialConsumers, partitionsPerTopic, totalPartitionCount)) { + sortedCurrentSubscriptions.remove(consumer); + fixedAssignments.put(consumer, currentAssignment.remove(consumer)); + } + + // create a deep copy of the current assignment so we can revert to it if we do not get a more balanced assignment later + Map> preBalanceAssignment = deepCopy(currentAssignment); + Map preBalancePartitionConsumers = new HashMap<>(currentPartitionConsumer); + + // if we don't already need to revoke something due to subscription changes, first try to balance by only moving newly added partitions + if (!revocationRequired) { + performReassignments(unassignedPartitions, currentAssignment, prevAssignment, sortedCurrentSubscriptions, + consumer2AllPotentialTopics, topic2AllPotentialConsumers, currentPartitionConsumer, partitionsPerTopic, totalPartitionCount); + } + + boolean reassignmentPerformed = performReassignments(sortedPartitions, currentAssignment, prevAssignment, sortedCurrentSubscriptions, + consumer2AllPotentialTopics, topic2AllPotentialConsumers, currentPartitionConsumer, partitionsPerTopic, totalPartitionCount); + + // if we are not preserving existing assignments and we have made changes to the current assignment + // make sure we are getting a more balanced assignment; otherwise, revert to previous assignment + if (!initializing && reassignmentPerformed && getBalanceScore(currentAssignment) >= getBalanceScore(preBalanceAssignment)) { + deepCopy(preBalanceAssignment, currentAssignment); + currentPartitionConsumer.clear(); + currentPartitionConsumer.putAll(preBalancePartitionConsumers); + } + + // add the fixed assignments (those that could not change) back + for (Entry> entry: fixedAssignments.entrySet()) { + String consumer = entry.getKey(); + currentAssignment.put(consumer, entry.getValue()); + sortedCurrentSubscriptions.add(consumer); + } + + fixedAssignments.clear(); + } + + private boolean 
performReassignments(List reassignablePartitions, + Map> currentAssignment, + Map prevAssignment, + TreeSet sortedCurrentSubscriptions, + Map> consumer2AllPotentialTopics, + Map> topic2AllPotentialConsumers, + Map currentPartitionConsumer, + Map partitionsPerTopic, + int totalPartitionCount) { + boolean reassignmentPerformed = false; + boolean modified; + + // repeat reassignment until no partition can be moved to improve the balance + do { + modified = false; + // reassign all reassignable partitions (starting from the partition with least potential consumers and if needed) + // until the full list is processed or a balance is achieved + Iterator partitionIterator = reassignablePartitions.iterator(); + while (partitionIterator.hasNext() && !isBalanced(currentAssignment, sortedCurrentSubscriptions, + consumer2AllPotentialTopics, partitionsPerTopic, totalPartitionCount)) { + TopicPartition partition = partitionIterator.next(); + + // the partition must have at least two consumers + if (topic2AllPotentialConsumers.get(partition.topic()).size() <= 1) + log.error("Expected more than one potential consumer for partition '{}'", partition); + + // the partition must have a current consumer + String consumer = currentPartitionConsumer.get(partition); + if (consumer == null) + log.error("Expected partition '{}' to be assigned to a consumer", partition); + + if (prevAssignment.containsKey(partition) && + currentAssignment.get(consumer).size() > currentAssignment.get(prevAssignment.get(partition).consumer).size() + 1) { + reassignPartition(partition, currentAssignment, sortedCurrentSubscriptions, currentPartitionConsumer, prevAssignment.get(partition).consumer); + reassignmentPerformed = true; + modified = true; + continue; + } + + // check if a better-suited consumer exist for the partition; if so, reassign it + for (String otherConsumer: topic2AllPotentialConsumers.get(partition.topic())) { + if (currentAssignment.get(consumer).size() > currentAssignment.get(otherConsumer).size() + 1) { + reassignPartition(partition, currentAssignment, sortedCurrentSubscriptions, currentPartitionConsumer, consumer2AllPotentialTopics); + reassignmentPerformed = true; + modified = true; + break; + } + } + } + } while (modified); + + return reassignmentPerformed; + } + + private void reassignPartition(TopicPartition partition, + Map> currentAssignment, + TreeSet sortedCurrentSubscriptions, + Map currentPartitionConsumer, + Map> consumer2AllPotentialTopics) { + // find the new consumer + String newConsumer = null; + for (String anotherConsumer: sortedCurrentSubscriptions) { + if (consumer2AllPotentialTopics.get(anotherConsumer).contains(partition.topic())) { + newConsumer = anotherConsumer; + break; + } + } + + assert newConsumer != null; + + reassignPartition(partition, currentAssignment, sortedCurrentSubscriptions, currentPartitionConsumer, newConsumer); + } + + private void reassignPartition(TopicPartition partition, + Map> currentAssignment, + TreeSet sortedCurrentSubscriptions, + Map currentPartitionConsumer, + String newConsumer) { + String consumer = currentPartitionConsumer.get(partition); + // find the correct partition movement considering the stickiness requirement + TopicPartition partitionToBeMoved = partitionMovements.getTheActualPartitionToBeMoved(partition, consumer, newConsumer); + processPartitionMovement(partitionToBeMoved, newConsumer, currentAssignment, sortedCurrentSubscriptions, currentPartitionConsumer); + } + + private void processPartitionMovement(TopicPartition partition, + String 
newConsumer, + Map> currentAssignment, + TreeSet sortedCurrentSubscriptions, + Map currentPartitionConsumer) { + String oldConsumer = currentPartitionConsumer.get(partition); + + sortedCurrentSubscriptions.remove(oldConsumer); + sortedCurrentSubscriptions.remove(newConsumer); + + partitionMovements.movePartition(partition, oldConsumer, newConsumer); + + currentAssignment.get(oldConsumer).remove(partition); + currentAssignment.get(newConsumer).add(partition); + currentPartitionConsumer.put(partition, newConsumer); + sortedCurrentSubscriptions.add(newConsumer); + sortedCurrentSubscriptions.add(oldConsumer); + } + + public boolean isSticky() { + return partitionMovements.isSticky(); + } + + private void deepCopy(Map> source, Map> dest) { + dest.clear(); + for (Entry> entry: source.entrySet()) + dest.put(entry.getKey(), new ArrayList<>(entry.getValue())); + } + + private Map> deepCopy(Map> assignment) { + Map> copy = new HashMap<>(); + deepCopy(assignment, copy); + return copy; + } + + private static class TopicComparator implements Comparator, Serializable { + private static final long serialVersionUID = 1L; + private Map> map; + + TopicComparator(Map> map) { + this.map = map; + } + + @Override + public int compare(String o1, String o2) { + int ret = map.get(o1).size() - map.get(o2).size(); + if (ret == 0) { + ret = o1.compareTo(o2); + } + return ret; + } + } + + private static class PartitionComparator implements Comparator, Serializable { + private static final long serialVersionUID = 1L; + private Map> map; + + PartitionComparator(Map> map) { + this.map = map; + } + + @Override + public int compare(TopicPartition o1, TopicPartition o2) { + int ret = map.get(o1.topic()).size() - map.get(o2.topic()).size(); + if (ret == 0) { + ret = o1.topic().compareTo(o2.topic()); + if (ret == 0) + ret = o1.partition() - o2.partition(); + } + return ret; + } + } + + private static class SubscriptionComparator implements Comparator, Serializable { + private static final long serialVersionUID = 1L; + private Map> map; + + SubscriptionComparator(Map> map) { + this.map = map; + } + + @Override + public int compare(String o1, String o2) { + int ret = map.get(o1).size() - map.get(o2).size(); + if (ret == 0) + ret = o1.compareTo(o2); + return ret; + } + } + + /** + * This class maintains some data structures to simplify lookup of partition movements among consumers. At each point of + * time during a partition rebalance it keeps track of partition movements corresponding to each topic, and also possible + * movement (in form a ConsumerPair object) for each partition. 
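+ * For example, as movePartition below records it: if a partition moves A -> B and later B -> C, the record is collapsed to A -> C; if it instead moves back from B to A, the record is dropped entirely.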
+ */ + private static class PartitionMovements { + private Map>> partitionMovementsByTopic = new HashMap<>(); + private Map partitionMovements = new HashMap<>(); + + private ConsumerPair removeMovementRecordOfPartition(TopicPartition partition) { + ConsumerPair pair = partitionMovements.remove(partition); + + String topic = partition.topic(); + Map> partitionMovementsForThisTopic = partitionMovementsByTopic.get(topic); + partitionMovementsForThisTopic.get(pair).remove(partition); + if (partitionMovementsForThisTopic.get(pair).isEmpty()) + partitionMovementsForThisTopic.remove(pair); + if (partitionMovementsByTopic.get(topic).isEmpty()) + partitionMovementsByTopic.remove(topic); + + return pair; + } + + private void addPartitionMovementRecord(TopicPartition partition, ConsumerPair pair) { + partitionMovements.put(partition, pair); + + String topic = partition.topic(); + if (!partitionMovementsByTopic.containsKey(topic)) + partitionMovementsByTopic.put(topic, new HashMap<>()); + + Map> partitionMovementsForThisTopic = partitionMovementsByTopic.get(topic); + if (!partitionMovementsForThisTopic.containsKey(pair)) + partitionMovementsForThisTopic.put(pair, new HashSet<>()); + + partitionMovementsForThisTopic.get(pair).add(partition); + } + + private void movePartition(TopicPartition partition, String oldConsumer, String newConsumer) { + ConsumerPair pair = new ConsumerPair(oldConsumer, newConsumer); + + if (partitionMovements.containsKey(partition)) { + // this partition has previously moved + ConsumerPair existingPair = removeMovementRecordOfPartition(partition); + assert existingPair.dstMemberId.equals(oldConsumer); + if (!existingPair.srcMemberId.equals(newConsumer)) { + // the partition is not moving back to its previous consumer + // return new ConsumerPair2(existingPair.src, newConsumer); + addPartitionMovementRecord(partition, new ConsumerPair(existingPair.srcMemberId, newConsumer)); + } + } else + addPartitionMovementRecord(partition, pair); + } + + private TopicPartition getTheActualPartitionToBeMoved(TopicPartition partition, String oldConsumer, String newConsumer) { + String topic = partition.topic(); + + if (!partitionMovementsByTopic.containsKey(topic)) + return partition; + + if (partitionMovements.containsKey(partition)) { + // this partition has previously moved + assert oldConsumer.equals(partitionMovements.get(partition).dstMemberId); + oldConsumer = partitionMovements.get(partition).srcMemberId; + } + + Map> partitionMovementsForThisTopic = partitionMovementsByTopic.get(topic); + ConsumerPair reversePair = new ConsumerPair(newConsumer, oldConsumer); + if (!partitionMovementsForThisTopic.containsKey(reversePair)) + return partition; + + return partitionMovementsForThisTopic.get(reversePair).iterator().next(); + } + + private boolean isLinked(String src, String dst, Set pairs, List currentPath) { + if (src.equals(dst)) + return false; + + if (pairs.isEmpty()) + return false; + + if (new ConsumerPair(src, dst).in(pairs)) { + currentPath.add(src); + currentPath.add(dst); + return true; + } + + for (ConsumerPair pair: pairs) + if (pair.srcMemberId.equals(src)) { + Set reducedSet = new HashSet<>(pairs); + reducedSet.remove(pair); + currentPath.add(pair.srcMemberId); + return isLinked(pair.dstMemberId, dst, reducedSet, currentPath); + } + + return false; + } + + private boolean in(List cycle, Set> cycles) { + List superCycle = new ArrayList<>(cycle); + superCycle.remove(superCycle.size() - 1); + superCycle.addAll(cycle); + for (List foundCycle: cycles) { + if (foundCycle.size() == 
cycle.size() && Collections.indexOfSubList(superCycle, foundCycle) != -1) + return true; + } + return false; + } + + private boolean hasCycles(Set<ConsumerPair> pairs) { + Set<List<String>> cycles = new HashSet<>(); + for (ConsumerPair pair: pairs) { + Set<ConsumerPair> reducedPairs = new HashSet<>(pairs); + reducedPairs.remove(pair); + List<String> path = new ArrayList<>(Collections.singleton(pair.srcMemberId)); + if (isLinked(pair.dstMemberId, pair.srcMemberId, reducedPairs, path) && !in(path, cycles)) { + cycles.add(new ArrayList<>(path)); + log.error("A cycle of length {} was found: {}", path.size() - 1, path.toString()); + } + } + + // for now we want to make sure there are no partition movements of the same topic between a pair of consumers. + // the odds of finding a cycle among more than two consumers seem to be so low (according to various randomized + // tests with the given sticky algorithm) that it is not worth the added complexity of handling those cases. + for (List<String> cycle: cycles) + if (cycle.size() == 3) // indicates a cycle of length 2 + return true; + return false; + } + + private boolean isSticky() { + for (Map.Entry<String, Map<ConsumerPair, Set<TopicPartition>>> topicMovements: this.partitionMovementsByTopic.entrySet()) { + Set<ConsumerPair> topicMovementPairs = topicMovements.getValue().keySet(); + if (hasCycles(topicMovementPairs)) { + log.error("Stickiness is violated for topic {}" + + "\nPartition movements for this topic occurred among the following consumer pairs:" + + "\n{}", topicMovements.getKey(), topicMovements.getValue().toString()); + return false; + } + } + + return true; + } + } + + /** + * ConsumerPair represents a pair of Kafka consumer ids involved in a partition reassignment. Each + * ConsumerPair object, which contains a source (src) and a destination (dst) + * element, normally corresponds to a particular partition or topic, and indicates that the particular partition or some + * partition of the particular topic was moved from the source consumer to the destination consumer during the rebalance. + * This class is used, through the PartitionMovements class, by the sticky assignor and helps in determining + * whether a partition reassignment results in cycles among the generated graph of consumer pairs. + */ + private static class ConsumerPair { + private final String srcMemberId; + private final String dstMemberId; + + ConsumerPair(String srcMemberId, String dstMemberId) { + this.srcMemberId = srcMemberId; + this.dstMemberId = dstMemberId; + } + + public String toString() { + return this.srcMemberId + "->" + this.dstMemberId; + } + + @Override + public int hashCode() { + final int prime = 31; + int result = 1; + result = prime * result + ((this.srcMemberId == null) ? 0 : this.srcMemberId.hashCode()); + result = prime * result + ((this.dstMemberId == null) ?
0 : this.dstMemberId.hashCode()); + return result; + } + + @Override + public boolean equals(Object obj) { + if (obj == null) + return false; + + if (!getClass().isInstance(obj)) + return false; + + ConsumerPair otherPair = (ConsumerPair) obj; + return this.srcMemberId.equals(otherPair.srcMemberId) && this.dstMemberId.equals(otherPair.dstMemberId); + } + + private boolean in(Set pairs) { + for (ConsumerPair pair: pairs) + if (this.equals(pair)) + return true; + return false; + } + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/AsyncClient.java b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/AsyncClient.java new file mode 100644 index 0000000..8b35499 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/AsyncClient.java @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients.consumer.internals; + +import org.apache.kafka.clients.ClientResponse; +import org.apache.kafka.common.Node; +import org.apache.kafka.common.requests.AbstractRequest; +import org.apache.kafka.common.requests.AbstractResponse; +import org.apache.kafka.common.utils.LogContext; +import org.slf4j.Logger; + +public abstract class AsyncClient { + + private final Logger log; + private final ConsumerNetworkClient client; + + AsyncClient(ConsumerNetworkClient client, LogContext logContext) { + this.client = client; + this.log = logContext.logger(getClass()); + } + + public RequestFuture sendAsyncRequest(Node node, T1 requestData) { + AbstractRequest.Builder requestBuilder = prepareRequest(node, requestData); + + return client.send(node, requestBuilder).compose(new RequestFutureAdapter() { + @Override + @SuppressWarnings("unchecked") + public void onSuccess(ClientResponse value, RequestFuture future) { + Resp resp; + try { + resp = (Resp) value.responseBody(); + } catch (ClassCastException cce) { + log.error("Could not cast response body", cce); + future.raise(cce); + return; + } + log.trace("Received {} {} from broker {}", resp.getClass().getSimpleName(), resp, node); + try { + future.complete(handleResponse(node, requestData, resp)); + } catch (RuntimeException e) { + if (!future.isDone()) { + future.raise(e); + } + } + } + + @Override + public void onFailure(RuntimeException e, RequestFuture future1) { + future1.raise(e); + } + }); + } + + protected Logger logger() { + return log; + } + + protected abstract AbstractRequest.Builder prepareRequest(Node node, T1 requestData); + + protected abstract T2 handleResponse(Node node, T1 requestData, Resp response); +} diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/ConsumerCoordinator.java 
b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/ConsumerCoordinator.java new file mode 100644 index 0000000..fad7f92 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/ConsumerCoordinator.java @@ -0,0 +1,1509 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients.consumer.internals; + +import org.apache.kafka.clients.GroupRebalanceConfig; +import org.apache.kafka.clients.consumer.CommitFailedException; +import org.apache.kafka.clients.consumer.ConsumerConfig; +import org.apache.kafka.clients.consumer.ConsumerGroupMetadata; +import org.apache.kafka.clients.consumer.ConsumerPartitionAssignor; +import org.apache.kafka.clients.consumer.ConsumerPartitionAssignor.Assignment; +import org.apache.kafka.clients.consumer.ConsumerPartitionAssignor.GroupSubscription; +import org.apache.kafka.clients.consumer.ConsumerPartitionAssignor.RebalanceProtocol; +import org.apache.kafka.clients.consumer.ConsumerPartitionAssignor.Subscription; +import org.apache.kafka.clients.consumer.ConsumerRebalanceListener; +import org.apache.kafka.clients.consumer.OffsetAndMetadata; +import org.apache.kafka.clients.consumer.OffsetCommitCallback; +import org.apache.kafka.clients.consumer.RetriableCommitFailedException; +import org.apache.kafka.common.Cluster; +import org.apache.kafka.common.KafkaException; +import org.apache.kafka.common.Node; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.errors.FencedInstanceIdException; +import org.apache.kafka.common.errors.GroupAuthorizationException; +import org.apache.kafka.common.errors.InterruptException; +import org.apache.kafka.common.errors.UnstableOffsetCommitException; +import org.apache.kafka.common.errors.RebalanceInProgressException; +import org.apache.kafka.common.errors.RetriableException; +import org.apache.kafka.common.errors.TimeoutException; +import org.apache.kafka.common.errors.TopicAuthorizationException; +import org.apache.kafka.common.errors.WakeupException; +import org.apache.kafka.common.message.JoinGroupRequestData; +import org.apache.kafka.common.message.JoinGroupResponseData; +import org.apache.kafka.common.message.OffsetCommitRequestData; +import org.apache.kafka.common.message.OffsetCommitResponseData; +import org.apache.kafka.common.metrics.Measurable; +import org.apache.kafka.common.metrics.Metrics; +import org.apache.kafka.common.metrics.Sensor; +import org.apache.kafka.common.metrics.stats.Avg; +import org.apache.kafka.common.metrics.stats.Max; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.Errors; +import org.apache.kafka.common.record.RecordBatch; +import org.apache.kafka.common.requests.JoinGroupRequest; +import 
org.apache.kafka.common.requests.OffsetCommitRequest; +import org.apache.kafka.common.requests.OffsetCommitResponse; +import org.apache.kafka.common.requests.OffsetFetchRequest; +import org.apache.kafka.common.requests.OffsetFetchResponse; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.common.utils.Timer; +import org.apache.kafka.common.utils.Utils; +import org.slf4j.Logger; + +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.Set; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; +import java.util.stream.Collectors; + +import static org.apache.kafka.clients.consumer.ConsumerConfig.ASSIGN_FROM_SUBSCRIBED_ASSIGNORS; +import static org.apache.kafka.clients.consumer.CooperativeStickyAssignor.COOPERATIVE_STICKY_ASSIGNOR_NAME; + +/** + * This class manages the coordination process with the consumer coordinator. + */ +public final class ConsumerCoordinator extends AbstractCoordinator { + private final GroupRebalanceConfig rebalanceConfig; + private final Logger log; + private final List assignors; + private final ConsumerMetadata metadata; + private final ConsumerCoordinatorMetrics sensors; + private final SubscriptionState subscriptions; + private final OffsetCommitCallback defaultOffsetCommitCallback; + private final boolean autoCommitEnabled; + private final int autoCommitIntervalMs; + private final ConsumerInterceptors interceptors; + private final AtomicInteger pendingAsyncCommits; + + // this collection must be thread-safe because it is modified from the response handler + // of offset commit requests, which may be invoked from the heartbeat thread + private final ConcurrentLinkedQueue completedOffsetCommits; + + private boolean isLeader = false; + private Set joinedSubscription; + private MetadataSnapshot metadataSnapshot; + private MetadataSnapshot assignmentSnapshot; + private Timer nextAutoCommitTimer; + private AtomicBoolean asyncCommitFenced; + private ConsumerGroupMetadata groupMetadata; + private final boolean throwOnFetchStableOffsetsUnsupported; + + // hold onto request&future for committed offset requests to enable async calls. + private PendingCommittedOffsetRequest pendingCommittedOffsetRequest = null; + + private static class PendingCommittedOffsetRequest { + private final Set requestedPartitions; + private final Generation requestedGeneration; + private final RequestFuture> response; + + private PendingCommittedOffsetRequest(final Set requestedPartitions, + final Generation generationAtRequestTime, + final RequestFuture> response) { + this.requestedPartitions = Objects.requireNonNull(requestedPartitions); + this.response = Objects.requireNonNull(response); + this.requestedGeneration = generationAtRequestTime; + } + + private boolean sameRequest(final Set currentRequest, final Generation currentGeneration) { + return Objects.equals(requestedGeneration, currentGeneration) && requestedPartitions.equals(currentRequest); + } + } + + private final RebalanceProtocol protocol; + + /** + * Initialize the coordination manager. 
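+ * (Besides wiring up its collaborators, the constructor below also selects the rebalance protocol to use: the highest protocol id that is supported by every configured assignor.)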
+ */ + public ConsumerCoordinator(GroupRebalanceConfig rebalanceConfig, + LogContext logContext, + ConsumerNetworkClient client, + List assignors, + ConsumerMetadata metadata, + SubscriptionState subscriptions, + Metrics metrics, + String metricGrpPrefix, + Time time, + boolean autoCommitEnabled, + int autoCommitIntervalMs, + ConsumerInterceptors interceptors, + boolean throwOnFetchStableOffsetsUnsupported) { + super(rebalanceConfig, + logContext, + client, + metrics, + metricGrpPrefix, + time); + this.rebalanceConfig = rebalanceConfig; + this.log = logContext.logger(ConsumerCoordinator.class); + this.metadata = metadata; + this.metadataSnapshot = new MetadataSnapshot(subscriptions, metadata.fetch(), metadata.updateVersion()); + this.subscriptions = subscriptions; + this.defaultOffsetCommitCallback = new DefaultOffsetCommitCallback(); + this.autoCommitEnabled = autoCommitEnabled; + this.autoCommitIntervalMs = autoCommitIntervalMs; + this.assignors = assignors; + this.completedOffsetCommits = new ConcurrentLinkedQueue<>(); + this.sensors = new ConsumerCoordinatorMetrics(metrics, metricGrpPrefix); + this.interceptors = interceptors; + this.pendingAsyncCommits = new AtomicInteger(); + this.asyncCommitFenced = new AtomicBoolean(false); + this.groupMetadata = new ConsumerGroupMetadata(rebalanceConfig.groupId, + JoinGroupRequest.UNKNOWN_GENERATION_ID, JoinGroupRequest.UNKNOWN_MEMBER_ID, rebalanceConfig.groupInstanceId); + this.throwOnFetchStableOffsetsUnsupported = throwOnFetchStableOffsetsUnsupported; + + if (autoCommitEnabled) + this.nextAutoCommitTimer = time.timer(autoCommitIntervalMs); + + // select the rebalance protocol such that: + // 1. only consider protocols that are supported by all the assignors. If there is no common protocols supported + // across all the assignors, throw an exception. + // 2. if there are multiple protocols that are commonly supported, select the one with the highest id (i.e. the + // id number indicates how advanced the protocol is). 
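+ // (e.g. if one assignor supports only EAGER while another supports EAGER and COOPERATIVE, the common set is just EAGER; after sorting, the last entry, i.e. the highest id, is selected)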
+ // we know there are at least one assignor in the list, no need to double check for NPE + if (!assignors.isEmpty()) { + List supportedProtocols = new ArrayList<>(assignors.get(0).supportedProtocols()); + + for (ConsumerPartitionAssignor assignor : assignors) { + supportedProtocols.retainAll(assignor.supportedProtocols()); + } + + if (supportedProtocols.isEmpty()) { + throw new IllegalArgumentException("Specified assignors " + + assignors.stream().map(ConsumerPartitionAssignor::name).collect(Collectors.toSet()) + + " do not have commonly supported rebalance protocol"); + } + + Collections.sort(supportedProtocols); + + protocol = supportedProtocols.get(supportedProtocols.size() - 1); + } else { + protocol = null; + } + + this.metadata.requestUpdate(); + } + + @Override + public String protocolType() { + return ConsumerProtocol.PROTOCOL_TYPE; + } + + @Override + protected JoinGroupRequestData.JoinGroupRequestProtocolCollection metadata() { + log.debug("Joining group with current subscription: {}", subscriptions.subscription()); + this.joinedSubscription = subscriptions.subscription(); + JoinGroupRequestData.JoinGroupRequestProtocolCollection protocolSet = new JoinGroupRequestData.JoinGroupRequestProtocolCollection(); + + List topics = new ArrayList<>(joinedSubscription); + for (ConsumerPartitionAssignor assignor : assignors) { + Subscription subscription = new Subscription(topics, + assignor.subscriptionUserData(joinedSubscription), + subscriptions.assignedPartitionsList()); + ByteBuffer metadata = ConsumerProtocol.serializeSubscription(subscription); + + protocolSet.add(new JoinGroupRequestData.JoinGroupRequestProtocol() + .setName(assignor.name()) + .setMetadata(Utils.toArray(metadata))); + } + return protocolSet; + } + + public void updatePatternSubscription(Cluster cluster) { + final Set topicsToSubscribe = cluster.topics().stream() + .filter(subscriptions::matchesSubscribedPattern) + .collect(Collectors.toSet()); + if (subscriptions.subscribeFromPattern(topicsToSubscribe)) + metadata.requestUpdateForNewTopics(); + } + + private ConsumerPartitionAssignor lookupAssignor(String name) { + for (ConsumerPartitionAssignor assignor : this.assignors) { + if (assignor.name().equals(name)) + return assignor; + } + return null; + } + + private void maybeUpdateJoinedSubscription(Set assignedPartitions) { + if (subscriptions.hasPatternSubscription()) { + // Check if the assignment contains some topics that were not in the original + // subscription, if yes we will obey what leader has decided and add these topics + // into the subscriptions as long as they still match the subscribed pattern + + Set addedTopics = new HashSet<>(); + // this is a copy because its handed to listener below + for (TopicPartition tp : assignedPartitions) { + if (!joinedSubscription.contains(tp.topic())) + addedTopics.add(tp.topic()); + } + + if (!addedTopics.isEmpty()) { + Set newSubscription = new HashSet<>(subscriptions.subscription()); + Set newJoinedSubscription = new HashSet<>(joinedSubscription); + newSubscription.addAll(addedTopics); + newJoinedSubscription.addAll(addedTopics); + + if (this.subscriptions.subscribeFromPattern(newSubscription)) + metadata.requestUpdateForNewTopics(); + this.joinedSubscription = newJoinedSubscription; + } + } + } + + private Exception invokeOnAssignment(final ConsumerPartitionAssignor assignor, final Assignment assignment) { + log.info("Notifying assignor about the new {}", assignment); + + try { + assignor.onAssignment(assignment, groupMetadata); + } catch (Exception e) { + return e; 
+ } + + return null; + } + + private Exception invokePartitionsAssigned(final Set assignedPartitions) { + log.info("Adding newly assigned partitions: {}", Utils.join(assignedPartitions, ", ")); + + ConsumerRebalanceListener listener = subscriptions.rebalanceListener(); + try { + final long startMs = time.milliseconds(); + listener.onPartitionsAssigned(assignedPartitions); + sensors.assignCallbackSensor.record(time.milliseconds() - startMs); + } catch (WakeupException | InterruptException e) { + throw e; + } catch (Exception e) { + log.error("User provided listener {} failed on invocation of onPartitionsAssigned for partitions {}", + listener.getClass().getName(), assignedPartitions, e); + return e; + } + + return null; + } + + private Exception invokePartitionsRevoked(final Set revokedPartitions) { + log.info("Revoke previously assigned partitions {}", Utils.join(revokedPartitions, ", ")); + + ConsumerRebalanceListener listener = subscriptions.rebalanceListener(); + try { + final long startMs = time.milliseconds(); + listener.onPartitionsRevoked(revokedPartitions); + sensors.revokeCallbackSensor.record(time.milliseconds() - startMs); + } catch (WakeupException | InterruptException e) { + throw e; + } catch (Exception e) { + log.error("User provided listener {} failed on invocation of onPartitionsRevoked for partitions {}", + listener.getClass().getName(), revokedPartitions, e); + return e; + } + + return null; + } + + private Exception invokePartitionsLost(final Set lostPartitions) { + log.info("Lost previously assigned partitions {}", Utils.join(lostPartitions, ", ")); + + ConsumerRebalanceListener listener = subscriptions.rebalanceListener(); + try { + final long startMs = time.milliseconds(); + listener.onPartitionsLost(lostPartitions); + sensors.loseCallbackSensor.record(time.milliseconds() - startMs); + } catch (WakeupException | InterruptException e) { + throw e; + } catch (Exception e) { + log.error("User provided listener {} failed on invocation of onPartitionsLost for partitions {}", + listener.getClass().getName(), lostPartitions, e); + return e; + } + + return null; + } + + @Override + protected void onJoinComplete(int generation, + String memberId, + String assignmentStrategy, + ByteBuffer assignmentBuffer) { + log.debug("Executing onJoinComplete with generation {} and memberId {}", generation, memberId); + + // Only the leader is responsible for monitoring for metadata changes (i.e. 
partition changes) + if (!isLeader) + assignmentSnapshot = null; + + ConsumerPartitionAssignor assignor = lookupAssignor(assignmentStrategy); + if (assignor == null) + throw new IllegalStateException("Coordinator selected invalid assignment protocol: " + assignmentStrategy); + + // Give the assignor a chance to update internal state based on the received assignment + groupMetadata = new ConsumerGroupMetadata(rebalanceConfig.groupId, generation, memberId, rebalanceConfig.groupInstanceId); + + Set ownedPartitions = new HashSet<>(subscriptions.assignedPartitions()); + + // should at least encode the short version + if (assignmentBuffer.remaining() < 2) + throw new IllegalStateException("There are insufficient bytes available to read assignment from the sync-group response (" + + "actual byte size " + assignmentBuffer.remaining() + ") , this is not expected; " + + "it is possible that the leader's assign function is buggy and did not return any assignment for this member, " + + "or because static member is configured and the protocol is buggy hence did not get the assignment for this member"); + + Assignment assignment = ConsumerProtocol.deserializeAssignment(assignmentBuffer); + + Set assignedPartitions = new HashSet<>(assignment.partitions()); + + if (!subscriptions.checkAssignmentMatchedSubscription(assignedPartitions)) { + final String reason = String.format("received assignment %s does not match the current subscription %s; " + + "it is likely that the subscription has changed since we joined the group, will re-join with current subscription", + assignment.partitions(), subscriptions.prettyString()); + requestRejoin(reason); + + return; + } + + final AtomicReference firstException = new AtomicReference<>(null); + Set addedPartitions = new HashSet<>(assignedPartitions); + addedPartitions.removeAll(ownedPartitions); + + if (protocol == RebalanceProtocol.COOPERATIVE) { + Set revokedPartitions = new HashSet<>(ownedPartitions); + revokedPartitions.removeAll(assignedPartitions); + + log.info("Updating assignment with\n" + + "\tAssigned partitions: {}\n" + + "\tCurrent owned partitions: {}\n" + + "\tAdded partitions (assigned - owned): {}\n" + + "\tRevoked partitions (owned - assigned): {}\n", + assignedPartitions, + ownedPartitions, + addedPartitions, + revokedPartitions + ); + + if (!revokedPartitions.isEmpty()) { + // Revoke partitions that were previously owned but no longer assigned; + // note that we should only change the assignment (or update the assignor's state) + // AFTER we've triggered the revoke callback + firstException.compareAndSet(null, invokePartitionsRevoked(revokedPartitions)); + + // If revoked any partitions, need to re-join the group afterwards + final String reason = String.format("need to revoke partitions %s as indicated " + + "by the current assignment and re-join", revokedPartitions); + requestRejoin(reason); + } + } + + // The leader may have assigned partitions which match our subscription pattern, but which + // were not explicitly requested, so we update the joined subscription here. + maybeUpdateJoinedSubscription(assignedPartitions); + + // Catch any exception here to make sure we could complete the user callback. 
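The user callbacks referred to throughout onJoinComplete are the methods of a ConsumerRebalanceListener supplied at subscribe time; invokePartitionsAssigned/Revoked/Lost above dispatch into it. A rough sketch of such a listener (class name and bodies are illustrative, not from this patch):

import org.apache.kafka.clients.consumer.ConsumerRebalanceListener;
import org.apache.kafka.common.TopicPartition;

import java.util.Collection;

public class LoggingRebalanceListener implements ConsumerRebalanceListener {
    @Override
    public void onPartitionsRevoked(Collection<TopicPartition> partitions) {
        // Commit or flush any in-flight work for the partitions being taken away.
        System.out.println("Revoked: " + partitions);
    }

    @Override
    public void onPartitionsAssigned(Collection<TopicPartition> partitions) {
        // Initialize state for the newly added partitions
        // (under COOPERATIVE only the delta is passed, as computed above).
        System.out.println("Assigned: " + partitions);
    }

    @Override
    public void onPartitionsLost(Collection<TopicPartition> partitions) {
        // Partitions were lost without a clean revocation, e.g. after the generation was reset.
        System.out.println("Lost: " + partitions);
    }
}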
+ firstException.compareAndSet(null, invokeOnAssignment(assignor, assignment)); + + // Reschedule the auto commit starting from now + if (autoCommitEnabled) + this.nextAutoCommitTimer.updateAndReset(autoCommitIntervalMs); + + subscriptions.assignFromSubscribed(assignedPartitions); + + // Add partitions that were not previously owned but are now assigned + firstException.compareAndSet(null, invokePartitionsAssigned(addedPartitions)); + + if (firstException.get() != null) { + if (firstException.get() instanceof KafkaException) { + throw (KafkaException) firstException.get(); + } else { + throw new KafkaException("User rebalance callback throws an error", firstException.get()); + } + } + } + + void maybeUpdateSubscriptionMetadata() { + int version = metadata.updateVersion(); + if (version > metadataSnapshot.version) { + Cluster cluster = metadata.fetch(); + + if (subscriptions.hasPatternSubscription()) + updatePatternSubscription(cluster); + + // Update the current snapshot, which will be used to check for subscription + // changes that would require a rebalance (e.g. new partitions). + metadataSnapshot = new MetadataSnapshot(subscriptions, cluster, version); + } + } + + /** + * Poll for coordinator events. This ensures that the coordinator is known and that the consumer + * has joined the group (if it is using group management). This also handles periodic offset commits + * if they are enabled. + *

+ * Returns early if the timeout expires or if waiting on rejoin is not required + * + * @param timer Timer bounding how long this method can block + * @param waitForJoinGroup Boolean flag indicating if we should wait until re-join group completes + * @throws KafkaException if the rebalance callback throws an exception + * @return true iff the operation succeeded + */ + public boolean poll(Timer timer, boolean waitForJoinGroup) { + maybeUpdateSubscriptionMetadata(); + + invokeCompletedOffsetCommitCallbacks(); + + if (subscriptions.hasAutoAssignedPartitions()) { + if (protocol == null) { + throw new IllegalStateException("User configured " + ConsumerConfig.PARTITION_ASSIGNMENT_STRATEGY_CONFIG + + " to empty while trying to subscribe for group protocol to auto assign partitions"); + } + // Always update the heartbeat last poll time so that the heartbeat thread does not leave the + // group proactively due to application inactivity even if (say) the coordinator cannot be found. + pollHeartbeat(timer.currentTimeMs()); + if (coordinatorUnknown() && !ensureCoordinatorReady(timer)) { + return false; + } + + if (rejoinNeededOrPending()) { + // due to a race condition between the initial metadata fetch and the initial rebalance, + // we need to ensure that the metadata is fresh before joining initially. This ensures + // that we have matched the pattern against the cluster's topics at least once before joining. + if (subscriptions.hasPatternSubscription()) { + // For consumer group that uses pattern-based subscription, after a topic is created, + // any consumer that discovers the topic after metadata refresh can trigger rebalance + // across the entire consumer group. Multiple rebalances can be triggered after one topic + // creation if consumers refresh metadata at vastly different times. We can significantly + // reduce the number of rebalances caused by single topic creation by asking consumer to + // refresh metadata before re-joining the group as long as the refresh backoff time has + // passed. + if (this.metadata.timeToAllowUpdate(timer.currentTimeMs()) == 0) { + this.metadata.requestUpdate(); + } + + if (!client.ensureFreshMetadata(timer)) { + return false; + } + + maybeUpdateSubscriptionMetadata(); + } + + // if not wait for join group, we would just use a timer of 0 + if (!ensureActiveGroup(waitForJoinGroup ? timer : time.timer(0L))) { + // since we may use a different timer in the callee, we'd still need + // to update the original timer's current time after the call + timer.update(time.milliseconds()); + + return false; + } + } + } else { + // For manually assigned partitions, if there are no ready nodes, await metadata. + // If connections to all nodes fail, wakeups triggered while attempting to send fetch + // requests result in polls returning immediately, causing a tight loop of polls. Without + // the wakeup, poll() with no channels would block for the timeout, delaying re-connection. + // awaitMetadataUpdate() initiates new connections with configured backoff and avoids the busy loop. + // When group management is used, metadata wait is already performed for this scenario as + // coordinator is unknown, hence this check is not required. + if (metadata.updateRequested() && !client.hasReadyNodes(timer.currentTimeMs())) { + client.awaitMetadataUpdate(timer); + } + } + + maybeAutoCommitOffsetsAsync(timer.currentTimeMs()); + return true; + } + + /** + * Return the time to the next needed invocation of {@link ConsumerNetworkClient#poll(Timer)}. 
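Applications never call the coordinator's poll(Timer, boolean) directly; it runs inside KafkaConsumer.poll. A minimal consumer loop exercising that path, assuming a hypothetical broker address, group id and topic name:

import org.apache.kafka.clients.consumer.ConsumerRecord;
import org.apache.kafka.clients.consumer.ConsumerRecords;
import org.apache.kafka.clients.consumer.KafkaConsumer;

import java.time.Duration;
import java.util.Collections;
import java.util.Properties;

public class PollLoopSketch {
    public static void main(String[] args) {
        Properties props = new Properties();
        props.put("bootstrap.servers", "localhost:9092");   // assumed local broker
        props.put("group.id", "example-group");              // hypothetical group
        props.put("key.deserializer", "org.apache.kafka.common.serialization.StringDeserializer");
        props.put("value.deserializer", "org.apache.kafka.common.serialization.StringDeserializer");

        try (KafkaConsumer<String, String> consumer = new KafkaConsumer<>(props)) {
            consumer.subscribe(Collections.singletonList("orders"));
            for (int i = 0; i < 10; i++) {
                // Each poll() drives coordinator discovery, group join and auto-commit as described above.
                ConsumerRecords<String, String> records = consumer.poll(Duration.ofMillis(500));
                for (ConsumerRecord<String, String> record : records)
                    System.out.printf("%s-%d@%d: %s%n",
                            record.topic(), record.partition(), record.offset(), record.value());
            }
        }
    }
}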
+ * @param now current time in milliseconds + * @return the maximum time in milliseconds the caller should wait before the next invocation of poll() + */ + public long timeToNextPoll(long now) { + if (!autoCommitEnabled) + return timeToNextHeartbeat(now); + + return Math.min(nextAutoCommitTimer.remainingMs(), timeToNextHeartbeat(now)); + } + + private void updateGroupSubscription(Set topics) { + // the leader will begin watching for changes to any of the topics the group is interested in, + // which ensures that all metadata changes will eventually be seen + if (this.subscriptions.groupSubscribe(topics)) + metadata.requestUpdateForNewTopics(); + + // update metadata (if needed) and keep track of the metadata used for assignment so that + // we can check after rebalance completion whether anything has changed + if (!client.ensureFreshMetadata(time.timer(Long.MAX_VALUE))) + throw new TimeoutException(); + + maybeUpdateSubscriptionMetadata(); + } + + private boolean isAssignFromSubscribedTopicsAssignor(String name) { + return ASSIGN_FROM_SUBSCRIBED_ASSIGNORS.contains(name); + } + + /** + * user-customized assignor may have created some topics that are not in the subscription list + * and assign their partitions to the members; in this case we would like to update the leader's + * own metadata with the newly added topics so that it will not trigger a subsequent rebalance + * when these topics gets updated from metadata refresh. + * + * We skip the check for in-product assignors since this will not happen in in-product assignors. + * + * TODO: this is a hack and not something we want to support long-term unless we push regex into the protocol + * we may need to modify the ConsumerPartitionAssignor API to better support this case. + * + * @param assignorName the selected assignor name + * @param assignments the assignments after assignor assigned + * @param allSubscribedTopics all consumers' subscribed topics + */ + private void maybeUpdateGroupSubscription(String assignorName, + Map assignments, + Set allSubscribedTopics) { + if (!isAssignFromSubscribedTopicsAssignor(assignorName)) { + Set assignedTopics = new HashSet<>(); + for (Assignment assigned : assignments.values()) { + for (TopicPartition tp : assigned.partitions()) + assignedTopics.add(tp.topic()); + } + + if (!assignedTopics.containsAll(allSubscribedTopics)) { + Set notAssignedTopics = new HashSet<>(allSubscribedTopics); + notAssignedTopics.removeAll(assignedTopics); + log.warn("The following subscribed topics are not assigned to any members: {} ", notAssignedTopics); + } + + if (!allSubscribedTopics.containsAll(assignedTopics)) { + Set newlyAddedTopics = new HashSet<>(assignedTopics); + newlyAddedTopics.removeAll(allSubscribedTopics); + log.info("The following not-subscribed topics are assigned, and their metadata will be " + + "fetched from the brokers: {}", newlyAddedTopics); + + allSubscribedTopics.addAll(newlyAddedTopics); + updateGroupSubscription(allSubscribedTopics); + } + } + } + + @Override + protected Map performAssignment(String leaderId, + String assignmentStrategy, + List allSubscriptions) { + ConsumerPartitionAssignor assignor = lookupAssignor(assignmentStrategy); + if (assignor == null) + throw new IllegalStateException("Coordinator selected invalid assignment protocol: " + assignmentStrategy); + String assignorName = assignor.name(); + + Set allSubscribedTopics = new HashSet<>(); + Map subscriptions = new HashMap<>(); + + // collect all the owned partitions + Map> ownedPartitions = new HashMap<>(); + + for 
(JoinGroupResponseData.JoinGroupResponseMember memberSubscription : allSubscriptions) { + Subscription subscription = ConsumerProtocol.deserializeSubscription(ByteBuffer.wrap(memberSubscription.metadata())); + subscription.setGroupInstanceId(Optional.ofNullable(memberSubscription.groupInstanceId())); + subscriptions.put(memberSubscription.memberId(), subscription); + allSubscribedTopics.addAll(subscription.topics()); + ownedPartitions.put(memberSubscription.memberId(), subscription.ownedPartitions()); + } + + // the leader will begin watching for changes to any of the topics the group is interested in, + // which ensures that all metadata changes will eventually be seen + updateGroupSubscription(allSubscribedTopics); + + isLeader = true; + + log.debug("Performing assignment using strategy {} with subscriptions {}", assignorName, subscriptions); + + Map assignments = assignor.assign(metadata.fetch(), new GroupSubscription(subscriptions)).groupAssignment(); + + // skip the validation for built-in cooperative sticky assignor since we've considered + // the "generation" of ownedPartition inside the assignor + if (protocol == RebalanceProtocol.COOPERATIVE && !assignorName.equals(COOPERATIVE_STICKY_ASSIGNOR_NAME)) { + validateCooperativeAssignment(ownedPartitions, assignments); + } + + maybeUpdateGroupSubscription(assignorName, assignments, allSubscribedTopics); + + assignmentSnapshot = metadataSnapshot; + + log.info("Finished assignment for group at generation {}: {}", generation().generationId, assignments); + + Map groupAssignment = new HashMap<>(); + for (Map.Entry assignmentEntry : assignments.entrySet()) { + ByteBuffer buffer = ConsumerProtocol.serializeAssignment(assignmentEntry.getValue()); + groupAssignment.put(assignmentEntry.getKey(), buffer); + } + + return groupAssignment; + } + + /** + * Used by COOPERATIVE rebalance protocol only. + * + * Validate the assignments returned by the assignor such that no owned partitions are going to + * be reassigned to a different consumer directly: if the assignor wants to reassign an owned partition, + * it must first remove it from the new assignment of the current owner so that it is not assigned to any + * member, and then in the next rebalance it can finally reassign those partitions not owned by anyone to consumers. 
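A standalone illustration of that rule (topic name is hypothetical): if a partition appears both in the union of revoked partitions and in the union of added partitions within one rebalance, the assignment is rejected.

import org.apache.kafka.common.TopicPartition;

import java.util.Collections;
import java.util.HashSet;
import java.util.Set;

public class CooperativeValidationSketch {
    public static void main(String[] args) {
        TopicPartition p0 = new TopicPartition("orders", 0);

        // Consumer A currently owns p0, but the proposed assignment hands p0 straight to consumer B.
        Set<TopicPartition> totalRevoked = new HashSet<>(Collections.singleton(p0)); // owned - assigned, over all members
        Set<TopicPartition> totalAdded = new HashSet<>(Collections.singleton(p0));   // assigned - owned, over all members

        // The COOPERATIVE rule: the two sets must not overlap; an owned partition has to be left
        // unassigned for one rebalance before it can move to another member.
        Set<TopicPartition> overlap = new HashSet<>(totalAdded);
        overlap.retainAll(totalRevoked);
        System.out.println(overlap.isEmpty() ? "valid" : "invalid cooperative assignment: " + overlap);
    }
}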
+ */ + private void validateCooperativeAssignment(final Map> ownedPartitions, + final Map assignments) { + Set totalRevokedPartitions = new HashSet<>(); + Set totalAddedPartitions = new HashSet<>(); + for (final Map.Entry entry : assignments.entrySet()) { + final Assignment assignment = entry.getValue(); + final Set addedPartitions = new HashSet<>(assignment.partitions()); + addedPartitions.removeAll(ownedPartitions.get(entry.getKey())); + final Set revokedPartitions = new HashSet<>(ownedPartitions.get(entry.getKey())); + revokedPartitions.removeAll(assignment.partitions()); + + totalAddedPartitions.addAll(addedPartitions); + totalRevokedPartitions.addAll(revokedPartitions); + } + + // if there are overlap between revoked partitions and added partitions, it means some partitions + // immediately gets re-assigned to another member while it is still claimed by some member + totalAddedPartitions.retainAll(totalRevokedPartitions); + if (!totalAddedPartitions.isEmpty()) { + log.error("With the COOPERATIVE protocol, owned partitions cannot be " + + "reassigned to other members; however the assignor has reassigned partitions {} which are still owned " + + "by some members", totalAddedPartitions); + + throw new IllegalStateException("Assignor supporting the COOPERATIVE protocol violates its requirements"); + } + } + + @Override + protected void onJoinPrepare(int generation, String memberId) { + log.debug("Executing onJoinPrepare with generation {} and memberId {}", generation, memberId); + // commit offsets prior to rebalance if auto-commit enabled + maybeAutoCommitOffsetsSync(time.timer(rebalanceConfig.rebalanceTimeoutMs)); + + // the generation / member-id can possibly be reset by the heartbeat thread + // upon getting errors or heartbeat timeouts; in this case whatever is previously + // owned partitions would be lost, we should trigger the callback and cleanup the assignment; + // otherwise we can proceed normally and revoke the partitions depending on the protocol, + // and in that case we should only change the assignment AFTER the revoke callback is triggered + // so that users can still access the previously owned partitions to commit offsets etc. + Exception exception = null; + final Set revokedPartitions; + if (generation == Generation.NO_GENERATION.generationId && + memberId.equals(Generation.NO_GENERATION.memberId)) { + revokedPartitions = new HashSet<>(subscriptions.assignedPartitions()); + + if (!revokedPartitions.isEmpty()) { + log.info("Giving away all assigned partitions as lost since generation has been reset," + + "indicating that consumer is no longer part of the group"); + exception = invokePartitionsLost(revokedPartitions); + + subscriptions.assignFromSubscribed(Collections.emptySet()); + } + } else { + switch (protocol) { + case EAGER: + // revoke all partitions + revokedPartitions = new HashSet<>(subscriptions.assignedPartitions()); + exception = invokePartitionsRevoked(revokedPartitions); + + subscriptions.assignFromSubscribed(Collections.emptySet()); + + break; + + case COOPERATIVE: + // only revoke those partitions that are not in the subscription any more. 
+ Set ownedPartitions = new HashSet<>(subscriptions.assignedPartitions()); + revokedPartitions = ownedPartitions.stream() + .filter(tp -> !subscriptions.subscription().contains(tp.topic())) + .collect(Collectors.toSet()); + + if (!revokedPartitions.isEmpty()) { + exception = invokePartitionsRevoked(revokedPartitions); + + ownedPartitions.removeAll(revokedPartitions); + subscriptions.assignFromSubscribed(ownedPartitions); + } + + break; + } + } + + isLeader = false; + subscriptions.resetGroupSubscription(); + + if (exception != null) { + throw new KafkaException("User rebalance callback throws an error", exception); + } + } + + @Override + public void onLeavePrepare() { + // Save the current Generation and use that to get the memberId, as the hb thread can change it at any time + final Generation currentGeneration = generation(); + final String memberId = currentGeneration.memberId; + + log.debug("Executing onLeavePrepare with generation {} and memberId {}", currentGeneration, memberId); + + // we should reset assignment and trigger the callback before leaving group + Set droppedPartitions = new HashSet<>(subscriptions.assignedPartitions()); + + if (subscriptions.hasAutoAssignedPartitions() && !droppedPartitions.isEmpty()) { + final Exception e; + if (generation() == Generation.NO_GENERATION || rebalanceInProgress()) { + e = invokePartitionsLost(droppedPartitions); + } else { + e = invokePartitionsRevoked(droppedPartitions); + } + + subscriptions.assignFromSubscribed(Collections.emptySet()); + + if (e != null) { + throw new KafkaException("User rebalance callback throws an error", e); + } + } + } + + /** + * @throws KafkaException if the callback throws exception + */ + @Override + public boolean rejoinNeededOrPending() { + if (!subscriptions.hasAutoAssignedPartitions()) + return false; + + // we need to rejoin if we performed the assignment and metadata has changed; + // also for those owned-but-no-longer-existed partitions we should drop them as lost + if (assignmentSnapshot != null && !assignmentSnapshot.matches(metadataSnapshot)) { + final String reason = String.format("cached metadata has changed from %s at the beginning of the rebalance to %s", + assignmentSnapshot, metadataSnapshot); + requestRejoinIfNecessary(reason); + return true; + } + + // we need to join if our subscription has changed since the last join + if (joinedSubscription != null && !joinedSubscription.equals(subscriptions.subscription())) { + final String reason = String.format("subscription has changed from %s at the beginning of the rebalance to %s", + joinedSubscription, subscriptions.subscription()); + requestRejoinIfNecessary(reason); + return true; + } + + return super.rejoinNeededOrPending(); + } + + /** + * Refresh the committed offsets for provided partitions. 
+ * + * @param timer Timer bounding how long this method can block + * @return true iff the operation completed within the timeout + */ + public boolean refreshCommittedOffsetsIfNeeded(Timer timer) { + final Set initializingPartitions = subscriptions.initializingPartitions(); + + final Map offsets = fetchCommittedOffsets(initializingPartitions, timer); + if (offsets == null) return false; + + for (final Map.Entry entry : offsets.entrySet()) { + final TopicPartition tp = entry.getKey(); + final OffsetAndMetadata offsetAndMetadata = entry.getValue(); + if (offsetAndMetadata != null) { + // first update the epoch if necessary + entry.getValue().leaderEpoch().ifPresent(epoch -> this.metadata.updateLastSeenEpochIfNewer(entry.getKey(), epoch)); + + // it's possible that the partition is no longer assigned when the response is received, + // so we need to ignore seeking if that's the case + if (this.subscriptions.isAssigned(tp)) { + final ConsumerMetadata.LeaderAndEpoch leaderAndEpoch = metadata.currentLeader(tp); + final SubscriptionState.FetchPosition position = new SubscriptionState.FetchPosition( + offsetAndMetadata.offset(), offsetAndMetadata.leaderEpoch(), + leaderAndEpoch); + + this.subscriptions.seekUnvalidated(tp, position); + + log.info("Setting offset for partition {} to the committed offset {}", tp, position); + } else { + log.info("Ignoring the returned {} since its partition {} is no longer assigned", + offsetAndMetadata, tp); + } + } + } + return true; + } + + /** + * Fetch the current committed offsets from the coordinator for a set of partitions. + * + * @param partitions The partitions to fetch offsets for + * @return A map from partition to the committed offset or null if the operation timed out + */ + public Map fetchCommittedOffsets(final Set partitions, + final Timer timer) { + if (partitions.isEmpty()) return Collections.emptyMap(); + + final Generation generationForOffsetRequest = generationIfStable(); + if (pendingCommittedOffsetRequest != null && + !pendingCommittedOffsetRequest.sameRequest(partitions, generationForOffsetRequest)) { + // if we were waiting for a different request, then just clear it. + pendingCommittedOffsetRequest = null; + } + + do { + if (!ensureCoordinatorReady(timer)) return null; + + // contact coordinator to fetch committed offsets + final RequestFuture> future; + if (pendingCommittedOffsetRequest != null) { + future = pendingCommittedOffsetRequest.response; + } else { + future = sendOffsetFetchRequest(partitions); + pendingCommittedOffsetRequest = new PendingCommittedOffsetRequest(partitions, generationForOffsetRequest, future); + } + client.poll(future, timer); + + if (future.isDone()) { + pendingCommittedOffsetRequest = null; + + if (future.succeeded()) { + return future.value(); + } else if (!future.isRetriable()) { + throw future.exception(); + } else { + timer.sleep(rebalanceConfig.retryBackoffMs); + } + } else { + return null; + } + } while (timer.notExpired()); + return null; + } + + /** + * Return the consumer group metadata. 
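One common use of this metadata is the exactly-once consume-transform-produce pattern, where offsets are committed through a transactional producer rather than the consumer's own commit path. A sketch assuming the producer is transactional and initTransactions() has already been called; the topic and offset are hypothetical:

import org.apache.kafka.clients.consumer.KafkaConsumer;
import org.apache.kafka.clients.consumer.OffsetAndMetadata;
import org.apache.kafka.clients.producer.KafkaProducer;
import org.apache.kafka.common.TopicPartition;

import java.util.Collections;
import java.util.Map;

public class SendOffsetsInTransactionSketch {
    // Offsets consumed so far are committed as part of the producer's transaction,
    // using the consumer group metadata exposed by the coordinator.
    static void commitViaTransaction(KafkaProducer<String, String> producer,
                                     KafkaConsumer<String, String> consumer) {
        Map<TopicPartition, OffsetAndMetadata> offsets =
                Collections.singletonMap(new TopicPartition("orders", 0), new OffsetAndMetadata(42L));
        producer.beginTransaction();
        // ... produce the transformed records here ...
        producer.sendOffsetsToTransaction(offsets, consumer.groupMetadata());
        producer.commitTransaction();
    }
}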
+ * + * @return the current consumer group metadata + */ + public ConsumerGroupMetadata groupMetadata() { + return groupMetadata; + } + + /** + * @throws KafkaException if the rebalance callback throws exception + */ + public void close(final Timer timer) { + // we do not need to re-enable wakeups since we are closing already + client.disableWakeups(); + try { + maybeAutoCommitOffsetsSync(timer); + while (pendingAsyncCommits.get() > 0 && timer.notExpired()) { + ensureCoordinatorReady(timer); + client.poll(timer); + invokeCompletedOffsetCommitCallbacks(); + } + } finally { + super.close(timer); + } + } + + // visible for testing + void invokeCompletedOffsetCommitCallbacks() { + if (asyncCommitFenced.get()) { + throw new FencedInstanceIdException("Get fenced exception for group.instance.id " + + rebalanceConfig.groupInstanceId.orElse("unset_instance_id") + + ", current member.id is " + memberId()); + } + while (true) { + OffsetCommitCompletion completion = completedOffsetCommits.poll(); + if (completion == null) { + break; + } + completion.invoke(); + } + } + + public void commitOffsetsAsync(final Map offsets, final OffsetCommitCallback callback) { + invokeCompletedOffsetCommitCallbacks(); + + if (!coordinatorUnknown()) { + doCommitOffsetsAsync(offsets, callback); + } else { + // we don't know the current coordinator, so try to find it and then send the commit + // or fail (we don't want recursive retries which can cause offset commits to arrive + // out of order). Note that there may be multiple offset commits chained to the same + // coordinator lookup request. This is fine because the listeners will be invoked in + // the same order that they were added. Note also that AbstractCoordinator prevents + // multiple concurrent coordinator lookup requests. + pendingAsyncCommits.incrementAndGet(); + lookupCoordinator().addListener(new RequestFutureListener() { + @Override + public void onSuccess(Void value) { + pendingAsyncCommits.decrementAndGet(); + doCommitOffsetsAsync(offsets, callback); + client.pollNoWakeup(); + } + + @Override + public void onFailure(RuntimeException e) { + pendingAsyncCommits.decrementAndGet(); + completedOffsetCommits.add(new OffsetCommitCompletion(callback, offsets, + new RetriableCommitFailedException(e))); + } + }); + } + + // ensure the commit has a chance to be transmitted (without blocking on its completion). + // Note that commits are treated as heartbeats by the coordinator, so there is no need to + // explicitly allow heartbeats through delayed task execution. + client.pollNoWakeup(); + } + + private void doCommitOffsetsAsync(final Map offsets, final OffsetCommitCallback callback) { + RequestFuture future = sendOffsetCommitRequest(offsets); + final OffsetCommitCallback cb = callback == null ? defaultOffsetCommitCallback : callback; + future.addListener(new RequestFutureListener() { + @Override + public void onSuccess(Void value) { + if (interceptors != null) + interceptors.onCommit(offsets); + completedOffsetCommits.add(new OffsetCommitCompletion(cb, offsets, null)); + } + + @Override + public void onFailure(RuntimeException e) { + Exception commitException = e; + + if (e instanceof RetriableException) { + commitException = new RetriableCommitFailedException(e); + } + completedOffsetCommits.add(new OffsetCommitCompletion(cb, offsets, commitException)); + if (commitException instanceof FencedInstanceIdException) { + asyncCommitFenced.set(true); + } + } + }); + } + + /** + * Commit offsets synchronously. 
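The commitOffsetsAsync/doCommitOffsetsAsync pair above is what backs the public KafkaConsumer.commitAsync call; the completion is queued and the user callback fires on a later consumer-thread call, as the completedOffsetCommits handling shows. A sketch with a hypothetical partition and offset:

import org.apache.kafka.clients.consumer.KafkaConsumer;
import org.apache.kafka.clients.consumer.OffsetAndMetadata;
import org.apache.kafka.common.TopicPartition;

import java.util.Collections;
import java.util.Map;

public class AsyncCommitSketch {
    static void commitAsync(KafkaConsumer<String, String> consumer) {
        Map<TopicPartition, OffsetAndMetadata> offsets =
                Collections.singletonMap(new TopicPartition("orders", 0), new OffsetAndMetadata(100L));
        // The callback is invoked from the consumer thread once the commit response has been handled.
        consumer.commitAsync(offsets, (committed, exception) -> {
            if (exception != null)
                System.err.println("Commit failed for " + committed + ": " + exception);
            else
                System.out.println("Committed " + committed);
        });
    }
}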
This method will retry until the commit completes successfully + * or an unrecoverable error is encountered. + * @param offsets The offsets to be committed + * @throws org.apache.kafka.common.errors.AuthorizationException if the consumer is not authorized to the group + * or to any of the specified partitions. See the exception for more details + * @throws CommitFailedException if an unrecoverable error occurs before the commit can be completed + * @throws FencedInstanceIdException if a static member gets fenced + * @return If the offset commit was successfully sent and a successful response was received from + * the coordinator + */ + public boolean commitOffsetsSync(Map offsets, Timer timer) { + invokeCompletedOffsetCommitCallbacks(); + + if (offsets.isEmpty()) + return true; + + do { + if (coordinatorUnknown() && !ensureCoordinatorReady(timer)) { + return false; + } + + RequestFuture future = sendOffsetCommitRequest(offsets); + client.poll(future, timer); + + // We may have had in-flight offset commits when the synchronous commit began. If so, ensure that + // the corresponding callbacks are invoked prior to returning in order to preserve the order that + // the offset commits were applied. + invokeCompletedOffsetCommitCallbacks(); + + if (future.succeeded()) { + if (interceptors != null) + interceptors.onCommit(offsets); + return true; + } + + if (future.failed() && !future.isRetriable()) + throw future.exception(); + + timer.sleep(rebalanceConfig.retryBackoffMs); + } while (timer.notExpired()); + + return false; + } + + public void maybeAutoCommitOffsetsAsync(long now) { + if (autoCommitEnabled) { + nextAutoCommitTimer.update(now); + if (nextAutoCommitTimer.isExpired()) { + nextAutoCommitTimer.reset(autoCommitIntervalMs); + doAutoCommitOffsetsAsync(); + } + } + } + + private void doAutoCommitOffsetsAsync() { + Map allConsumedOffsets = subscriptions.allConsumed(); + log.debug("Sending asynchronous auto-commit of offsets {}", allConsumedOffsets); + + commitOffsetsAsync(allConsumedOffsets, (offsets, exception) -> { + if (exception != null) { + if (exception instanceof RetriableCommitFailedException) { + log.debug("Asynchronous auto-commit of offsets {} failed due to retriable error: {}", offsets, + exception); + nextAutoCommitTimer.updateAndReset(rebalanceConfig.retryBackoffMs); + } else { + log.warn("Asynchronous auto-commit of offsets {} failed: {}", offsets, exception.getMessage()); + } + } else { + log.debug("Completed asynchronous auto-commit of offsets {}", offsets); + } + }); + } + + private void maybeAutoCommitOffsetsSync(Timer timer) { + if (autoCommitEnabled) { + Map allConsumedOffsets = subscriptions.allConsumed(); + try { + log.debug("Sending synchronous auto-commit of offsets {}", allConsumedOffsets); + if (!commitOffsetsSync(allConsumedOffsets, timer)) + log.debug("Auto-commit of offsets {} timed out before completion", allConsumedOffsets); + } catch (WakeupException | InterruptException e) { + log.debug("Auto-commit of offsets {} was interrupted before completion", allConsumedOffsets); + // rethrow wakeups since they are triggered by the user + throw e; + } catch (Exception e) { + // consistent with async auto-commit failures, we do not propagate the exception + log.warn("Synchronous auto-commit of offsets {} failed: {}", allConsumedOffsets, e.getMessage()); + } + } + } + + private class DefaultOffsetCommitCallback implements OffsetCommitCallback { + @Override + public void onComplete(Map offsets, Exception exception) { + if (exception != null) + log.error("Offset commit 
with offsets {} failed", offsets, exception); + } + } + + /** + * Commit offsets for the specified list of topics and partitions. This is a non-blocking call + * which returns a request future that can be polled in the case of a synchronous commit or ignored in the + * asynchronous case. + * + * NOTE: This is visible only for testing + * + * @param offsets The list of offsets per partition that should be committed. + * @return A request future whose value indicates whether the commit was successful or not + */ + RequestFuture sendOffsetCommitRequest(final Map offsets) { + if (offsets.isEmpty()) + return RequestFuture.voidSuccess(); + + Node coordinator = checkAndGetCoordinator(); + if (coordinator == null) + return RequestFuture.coordinatorNotAvailable(); + + // create the offset commit request + Map requestTopicDataMap = new HashMap<>(); + for (Map.Entry entry : offsets.entrySet()) { + TopicPartition topicPartition = entry.getKey(); + OffsetAndMetadata offsetAndMetadata = entry.getValue(); + if (offsetAndMetadata.offset() < 0) { + return RequestFuture.failure(new IllegalArgumentException("Invalid offset: " + offsetAndMetadata.offset())); + } + + OffsetCommitRequestData.OffsetCommitRequestTopic topic = requestTopicDataMap + .getOrDefault(topicPartition.topic(), + new OffsetCommitRequestData.OffsetCommitRequestTopic() + .setName(topicPartition.topic()) + ); + + topic.partitions().add(new OffsetCommitRequestData.OffsetCommitRequestPartition() + .setPartitionIndex(topicPartition.partition()) + .setCommittedOffset(offsetAndMetadata.offset()) + .setCommittedLeaderEpoch(offsetAndMetadata.leaderEpoch().orElse(RecordBatch.NO_PARTITION_LEADER_EPOCH)) + .setCommittedMetadata(offsetAndMetadata.metadata()) + ); + requestTopicDataMap.put(topicPartition.topic(), topic); + } + + final Generation generation; + if (subscriptions.hasAutoAssignedPartitions()) { + generation = generationIfStable(); + // if the generation is null, we are not part of an active group (and we expect to be). + // the only thing we can do is fail the commit and let the user rejoin the group in poll(). + if (generation == null) { + log.info("Failing OffsetCommit request since the consumer is not part of an active group"); + + if (rebalanceInProgress()) { + // if the client knows it is already rebalancing, we can use RebalanceInProgressException instead of + // CommitFailedException to indicate this is not a fatal error + return RequestFuture.failure(new RebalanceInProgressException("Offset commit cannot be completed since the " + + "consumer is undergoing a rebalance for auto partition assignment. 
You can try completing the rebalance " + + "by calling poll() and then retry the operation.")); + } else { + return RequestFuture.failure(new CommitFailedException("Offset commit cannot be completed since the " + + "consumer is not part of an active group for auto partition assignment; it is likely that the consumer " + + "was kicked out of the group.")); + } + } + } else { + generation = Generation.NO_GENERATION; + } + + OffsetCommitRequest.Builder builder = new OffsetCommitRequest.Builder( + new OffsetCommitRequestData() + .setGroupId(this.rebalanceConfig.groupId) + .setGenerationId(generation.generationId) + .setMemberId(generation.memberId) + .setGroupInstanceId(rebalanceConfig.groupInstanceId.orElse(null)) + .setTopics(new ArrayList<>(requestTopicDataMap.values())) + ); + + log.trace("Sending OffsetCommit request with {} to coordinator {}", offsets, coordinator); + + return client.send(coordinator, builder) + .compose(new OffsetCommitResponseHandler(offsets, generation)); + } + + private class OffsetCommitResponseHandler extends CoordinatorResponseHandler { + private final Map offsets; + + private OffsetCommitResponseHandler(Map offsets, Generation generation) { + super(generation); + this.offsets = offsets; + } + + @Override + public void handle(OffsetCommitResponse commitResponse, RequestFuture future) { + sensors.commitSensor.record(response.requestLatencyMs()); + Set unauthorizedTopics = new HashSet<>(); + + for (OffsetCommitResponseData.OffsetCommitResponseTopic topic : commitResponse.data().topics()) { + for (OffsetCommitResponseData.OffsetCommitResponsePartition partition : topic.partitions()) { + TopicPartition tp = new TopicPartition(topic.name(), partition.partitionIndex()); + OffsetAndMetadata offsetAndMetadata = this.offsets.get(tp); + + long offset = offsetAndMetadata.offset(); + + Errors error = Errors.forCode(partition.errorCode()); + if (error == Errors.NONE) { + log.debug("Committed offset {} for partition {}", offset, tp); + } else { + if (error.exception() instanceof RetriableException) { + log.warn("Offset commit failed on partition {} at offset {}: {}", tp, offset, error.message()); + } else { + log.error("Offset commit failed on partition {} at offset {}: {}", tp, offset, error.message()); + } + + if (error == Errors.GROUP_AUTHORIZATION_FAILED) { + future.raise(GroupAuthorizationException.forGroupId(rebalanceConfig.groupId)); + return; + } else if (error == Errors.TOPIC_AUTHORIZATION_FAILED) { + unauthorizedTopics.add(tp.topic()); + } else if (error == Errors.OFFSET_METADATA_TOO_LARGE + || error == Errors.INVALID_COMMIT_OFFSET_SIZE) { + // raise the error to the user + future.raise(error); + return; + } else if (error == Errors.COORDINATOR_LOAD_IN_PROGRESS + || error == Errors.UNKNOWN_TOPIC_OR_PARTITION) { + // just retry + future.raise(error); + return; + } else if (error == Errors.COORDINATOR_NOT_AVAILABLE + || error == Errors.NOT_COORDINATOR + || error == Errors.REQUEST_TIMED_OUT) { + markCoordinatorUnknown(error); + future.raise(error); + return; + } else if (error == Errors.FENCED_INSTANCE_ID) { + log.info("OffsetCommit failed with {} due to group instance id {} fenced", sentGeneration, rebalanceConfig.groupInstanceId); + + // if the generation has changed or we are not in rebalancing, do not raise the fatal error but rebalance-in-progress + if (generationUnchanged()) { + future.raise(error); + } else { + KafkaException exception; + synchronized (ConsumerCoordinator.this) { + if (ConsumerCoordinator.this.state == MemberState.PREPARING_REBALANCE) { + 
exception = new RebalanceInProgressException("Offset commit cannot be completed since the " + + "consumer member's old generation is fenced by its group instance id, it is possible that " + + "this consumer has already participated another rebalance and got a new generation"); + } else { + exception = new CommitFailedException(); + } + } + future.raise(exception); + } + return; + } else if (error == Errors.REBALANCE_IN_PROGRESS) { + /* Consumer should not try to commit offset in between join-group and sync-group, + * and hence on broker-side it is not expected to see a commit offset request + * during CompletingRebalance phase; if it ever happens then broker would return + * this error to indicate that we are still in the middle of a rebalance. + * In this case we would throw a RebalanceInProgressException, + * request re-join but do not reset generations. If the callers decide to retry they + * can go ahead and call poll to finish up the rebalance first, and then try commit again. + */ + requestRejoin("offset commit failed since group is already rebalancing"); + future.raise(new RebalanceInProgressException("Offset commit cannot be completed since the " + + "consumer group is executing a rebalance at the moment. You can try completing the rebalance " + + "by calling poll() and then retry commit again")); + return; + } else if (error == Errors.UNKNOWN_MEMBER_ID + || error == Errors.ILLEGAL_GENERATION) { + log.info("OffsetCommit failed with {}: {}", sentGeneration, error.message()); + + // only need to reset generation and re-join group if generation has not changed or we are not in rebalancing; + // otherwise only raise rebalance-in-progress error + KafkaException exception; + synchronized (ConsumerCoordinator.this) { + if (!generationUnchanged() && ConsumerCoordinator.this.state == MemberState.PREPARING_REBALANCE) { + exception = new RebalanceInProgressException("Offset commit cannot be completed since the " + + "consumer member's generation is already stale, meaning it has already participated another rebalance and " + + "got a new generation. You can try completing the rebalance by calling poll() and then retry commit again"); + } else { + resetGenerationOnResponseError(ApiKeys.OFFSET_COMMIT, error); + exception = new CommitFailedException(); + } + } + future.raise(exception); + return; + } else { + future.raise(new KafkaException("Unexpected error in commit: " + error.message())); + return; + } + } + } + } + + if (!unauthorizedTopics.isEmpty()) { + log.error("Not authorized to commit to topics {}", unauthorizedTopics); + future.raise(new TopicAuthorizationException(unauthorizedTopics)); + } else { + future.complete(null); + } + } + } + + /** + * Fetch the committed offsets for a set of partitions. This is a non-blocking call. The + * returned future can be polled to get the actual offsets returned from the broker. + * + * @param partitions The set of partitions to get offsets for. + * @return A request future containing the committed offsets. 
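At the application level this path is reached through KafkaConsumer.committed; a sketch with hypothetical broker, group and topic names:

import org.apache.kafka.clients.consumer.KafkaConsumer;
import org.apache.kafka.clients.consumer.OffsetAndMetadata;
import org.apache.kafka.common.TopicPartition;

import java.time.Duration;
import java.util.Collections;
import java.util.Map;
import java.util.Properties;
import java.util.Set;

public class CommittedOffsetLookupSketch {
    public static void main(String[] args) {
        Properties props = new Properties();
        props.put("bootstrap.servers", "localhost:9092");   // assumed local broker
        props.put("group.id", "example-group");              // hypothetical group
        props.put("key.deserializer", "org.apache.kafka.common.serialization.StringDeserializer");
        props.put("value.deserializer", "org.apache.kafka.common.serialization.StringDeserializer");

        try (KafkaConsumer<String, String> consumer = new KafkaConsumer<>(props)) {
            Set<TopicPartition> partitions = Collections.singleton(new TopicPartition("orders", 0));
            // Internally this goes through the coordinator's fetchCommittedOffsets/sendOffsetFetchRequest path.
            Map<TopicPartition, OffsetAndMetadata> committed =
                    consumer.committed(partitions, Duration.ofSeconds(5));
            committed.forEach((tp, om) ->
                    System.out.println(tp + " -> " + (om == null ? "no committed offset" : om.offset())));
        }
    }
}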
+ */ + private RequestFuture> sendOffsetFetchRequest(Set partitions) { + Node coordinator = checkAndGetCoordinator(); + if (coordinator == null) + return RequestFuture.coordinatorNotAvailable(); + + log.debug("Fetching committed offsets for partitions: {}", partitions); + // construct the request + OffsetFetchRequest.Builder requestBuilder = + new OffsetFetchRequest.Builder(this.rebalanceConfig.groupId, true, new ArrayList<>(partitions), throwOnFetchStableOffsetsUnsupported); + + // send the request with a callback + return client.send(coordinator, requestBuilder) + .compose(new OffsetFetchResponseHandler()); + } + + private class OffsetFetchResponseHandler extends CoordinatorResponseHandler> { + private OffsetFetchResponseHandler() { + super(Generation.NO_GENERATION); + } + + @Override + public void handle(OffsetFetchResponse response, RequestFuture> future) { + Errors responseError = response.groupLevelError(rebalanceConfig.groupId); + if (responseError != Errors.NONE) { + log.debug("Offset fetch failed: {}", responseError.message()); + + if (responseError == Errors.COORDINATOR_LOAD_IN_PROGRESS) { + // just retry + future.raise(responseError); + } else if (responseError == Errors.NOT_COORDINATOR) { + // re-discover the coordinator and retry + markCoordinatorUnknown(responseError); + future.raise(responseError); + } else if (responseError == Errors.GROUP_AUTHORIZATION_FAILED) { + future.raise(GroupAuthorizationException.forGroupId(rebalanceConfig.groupId)); + } else { + future.raise(new KafkaException("Unexpected error in fetch offset response: " + responseError.message())); + } + return; + } + + Set unauthorizedTopics = null; + Map responseData = + response.partitionDataMap(rebalanceConfig.groupId); + Map offsets = new HashMap<>(responseData.size()); + Set unstableTxnOffsetTopicPartitions = new HashSet<>(); + for (Map.Entry entry : responseData.entrySet()) { + TopicPartition tp = entry.getKey(); + OffsetFetchResponse.PartitionData partitionData = entry.getValue(); + if (partitionData.hasError()) { + Errors error = partitionData.error; + log.debug("Failed to fetch offset for partition {}: {}", tp, error.message()); + + if (error == Errors.UNKNOWN_TOPIC_OR_PARTITION) { + future.raise(new KafkaException("Topic or Partition " + tp + " does not exist")); + return; + } else if (error == Errors.TOPIC_AUTHORIZATION_FAILED) { + if (unauthorizedTopics == null) { + unauthorizedTopics = new HashSet<>(); + } + unauthorizedTopics.add(tp.topic()); + } else if (error == Errors.UNSTABLE_OFFSET_COMMIT) { + unstableTxnOffsetTopicPartitions.add(tp); + } else { + future.raise(new KafkaException("Unexpected error in fetch offset response for partition " + + tp + ": " + error.message())); + return; + } + } else if (partitionData.offset >= 0) { + // record the position with the offset (-1 indicates no committed offset to fetch); + // if there's no committed offset, record as null + offsets.put(tp, new OffsetAndMetadata(partitionData.offset, partitionData.leaderEpoch, partitionData.metadata)); + } else { + log.info("Found no committed offset for partition {}", tp); + offsets.put(tp, null); + } + } + + if (unauthorizedTopics != null) { + future.raise(new TopicAuthorizationException(unauthorizedTopics)); + } else if (!unstableTxnOffsetTopicPartitions.isEmpty()) { + // just retry + log.info("The following partitions still have unstable offsets " + + "which are not cleared on the broker side: {}" + + ", this could be either " + + "transactional offsets waiting for completion, or " + + "normal offsets waiting for 
replication after appending to local log", unstableTxnOffsetTopicPartitions); + future.raise(new UnstableOffsetCommitException("There are unstable offsets for the requested topic partitions")); + } else { + future.complete(offsets); + } + } + } + + private class ConsumerCoordinatorMetrics { + private final String metricGrpName; + private final Sensor commitSensor; + private final Sensor revokeCallbackSensor; + private final Sensor assignCallbackSensor; + private final Sensor loseCallbackSensor; + + private ConsumerCoordinatorMetrics(Metrics metrics, String metricGrpPrefix) { + this.metricGrpName = metricGrpPrefix + "-coordinator-metrics"; + + this.commitSensor = metrics.sensor("commit-latency"); + this.commitSensor.add(metrics.metricName("commit-latency-avg", + this.metricGrpName, + "The average time taken for a commit request"), new Avg()); + this.commitSensor.add(metrics.metricName("commit-latency-max", + this.metricGrpName, + "The max time taken for a commit request"), new Max()); + this.commitSensor.add(createMeter(metrics, metricGrpName, "commit", "commit calls")); + + this.revokeCallbackSensor = metrics.sensor("partition-revoked-latency"); + this.revokeCallbackSensor.add(metrics.metricName("partition-revoked-latency-avg", + this.metricGrpName, + "The average time taken for a partition-revoked rebalance listener callback"), new Avg()); + this.revokeCallbackSensor.add(metrics.metricName("partition-revoked-latency-max", + this.metricGrpName, + "The max time taken for a partition-revoked rebalance listener callback"), new Max()); + + this.assignCallbackSensor = metrics.sensor("partition-assigned-latency"); + this.assignCallbackSensor.add(metrics.metricName("partition-assigned-latency-avg", + this.metricGrpName, + "The average time taken for a partition-assigned rebalance listener callback"), new Avg()); + this.assignCallbackSensor.add(metrics.metricName("partition-assigned-latency-max", + this.metricGrpName, + "The max time taken for a partition-assigned rebalance listener callback"), new Max()); + + this.loseCallbackSensor = metrics.sensor("partition-lost-latency"); + this.loseCallbackSensor.add(metrics.metricName("partition-lost-latency-avg", + this.metricGrpName, + "The average time taken for a partition-lost rebalance listener callback"), new Avg()); + this.loseCallbackSensor.add(metrics.metricName("partition-lost-latency-max", + this.metricGrpName, + "The max time taken for a partition-lost rebalance listener callback"), new Max()); + + Measurable numParts = (config, now) -> subscriptions.numAssignedPartitions(); + metrics.addMetric(metrics.metricName("assigned-partitions", + this.metricGrpName, + "The number of partitions currently assigned to this consumer"), numParts); + } + } + + private static class MetadataSnapshot { + private final int version; + private final Map partitionsPerTopic; + + private MetadataSnapshot(SubscriptionState subscription, Cluster cluster, int version) { + Map partitionsPerTopic = new HashMap<>(); + for (String topic : subscription.metadataTopics()) { + Integer numPartitions = cluster.partitionCountForTopic(topic); + if (numPartitions != null) + partitionsPerTopic.put(topic, numPartitions); + } + this.partitionsPerTopic = partitionsPerTopic; + this.version = version; + } + + boolean matches(MetadataSnapshot other) { + return version == other.version || partitionsPerTopic.equals(other.partitionsPerTopic); + } + + @Override + public String toString() { + return "(version" + version + ": " + partitionsPerTopic + ")"; + } + } + + private static class 
OffsetCommitCompletion { + private final OffsetCommitCallback callback; + private final Map offsets; + private final Exception exception; + + private OffsetCommitCompletion(OffsetCommitCallback callback, Map offsets, Exception exception) { + this.callback = callback; + this.offsets = offsets; + this.exception = exception; + } + + public void invoke() { + if (callback != null) + callback.onComplete(offsets, exception); + } + } + + /* test-only classes below */ + RebalanceProtocol getProtocol() { + return protocol; + } + + boolean poll(Timer timer) { + return poll(timer, true); + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/ConsumerInterceptors.java b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/ConsumerInterceptors.java new file mode 100644 index 0000000..d96d8ce --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/ConsumerInterceptors.java @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients.consumer.internals; + + +import org.apache.kafka.clients.consumer.ConsumerInterceptor; +import org.apache.kafka.clients.consumer.ConsumerRecords; +import org.apache.kafka.clients.consumer.OffsetAndMetadata; +import org.apache.kafka.common.TopicPartition; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.Closeable; +import java.util.List; +import java.util.Map; + +/** + * A container that holds the list {@link org.apache.kafka.clients.consumer.ConsumerInterceptor} + * and wraps calls to the chain of custom interceptors. + */ +public class ConsumerInterceptors implements Closeable { + private static final Logger log = LoggerFactory.getLogger(ConsumerInterceptors.class); + private final List> interceptors; + + public ConsumerInterceptors(List> interceptors) { + this.interceptors = interceptors; + } + + /** + * This is called when the records are about to be returned to the user. + *
+ * This method calls {@link ConsumerInterceptor#onConsume(ConsumerRecords)} for each + * interceptor. Records returned from each interceptor get passed to onConsume() of the next interceptor + * in the chain of interceptors. + *
+ * This method does not throw exceptions. If any of the interceptors in the chain throws an exception, + * it gets caught and logged, and next interceptor in the chain is called with 'records' returned by the + * previous successful interceptor onConsume call. + * + * @param records records to be consumed by the client. + * @return records that are either modified by interceptors or same as records passed to this method. + */ + public ConsumerRecords onConsume(ConsumerRecords records) { + ConsumerRecords interceptRecords = records; + for (ConsumerInterceptor interceptor : this.interceptors) { + try { + interceptRecords = interceptor.onConsume(interceptRecords); + } catch (Exception e) { + // do not propagate interceptor exception, log and continue calling other interceptors + log.warn("Error executing interceptor onConsume callback", e); + } + } + return interceptRecords; + } + + /** + * This is called when commit request returns successfully from the broker. + *
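A sketch of the kind of interceptor this container chains: onConsume passes records along to the next interceptor, onCommit observes committed offsets. The class name and counter are illustrative, not part of this patch.

import org.apache.kafka.clients.consumer.ConsumerInterceptor;
import org.apache.kafka.clients.consumer.ConsumerRecords;
import org.apache.kafka.clients.consumer.OffsetAndMetadata;
import org.apache.kafka.common.TopicPartition;

import java.util.Map;

public class CountingInterceptor<K, V> implements ConsumerInterceptor<K, V> {
    private long consumed = 0;

    @Override
    public ConsumerRecords<K, V> onConsume(ConsumerRecords<K, V> records) {
        consumed += records.count();
        return records; // pass the records through unchanged to the next interceptor in the chain
    }

    @Override
    public void onCommit(Map<TopicPartition, OffsetAndMetadata> offsets) {
        System.out.println("Committed offsets: " + offsets);
    }

    @Override
    public void close() {
        System.out.println("Total records seen: " + consumed);
    }

    @Override
    public void configure(Map<String, ?> configs) {
        // no configuration needed for this sketch
    }
}

An interceptor like this would be enabled through the consumer's interceptor.classes configuration, which is how the list wrapped by this container gets populated.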
+ * This method calls {@link ConsumerInterceptor#onCommit(Map)} method for each interceptor. + *
+ * This method does not throw exceptions. Exceptions thrown by any of the interceptors in the chain are logged, but not propagated. + * + * @param offsets A map of offsets by partition with associated metadata + */ + public void onCommit(Map offsets) { + for (ConsumerInterceptor interceptor : this.interceptors) { + try { + interceptor.onCommit(offsets); + } catch (Exception e) { + // do not propagate interceptor exception, just log + log.warn("Error executing interceptor onCommit callback", e); + } + } + } + + /** + * Closes every interceptor in a container. + */ + @Override + public void close() { + for (ConsumerInterceptor interceptor : this.interceptors) { + try { + interceptor.close(); + } catch (Exception e) { + log.error("Failed to close consumer interceptor ", e); + } + } + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/ConsumerMetadata.java b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/ConsumerMetadata.java new file mode 100644 index 0000000..ef7d924 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/ConsumerMetadata.java @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.clients.consumer.internals; + +import org.apache.kafka.clients.Metadata; +import org.apache.kafka.common.internals.ClusterResourceListeners; +import org.apache.kafka.common.requests.MetadataRequest; +import org.apache.kafka.common.utils.LogContext; + +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Set; + +public class ConsumerMetadata extends Metadata { + private final boolean includeInternalTopics; + private final boolean allowAutoTopicCreation; + private final SubscriptionState subscription; + private final Set transientTopics; + + public ConsumerMetadata(long refreshBackoffMs, + long metadataExpireMs, + boolean includeInternalTopics, + boolean allowAutoTopicCreation, + SubscriptionState subscription, + LogContext logContext, + ClusterResourceListeners clusterResourceListeners) { + super(refreshBackoffMs, metadataExpireMs, logContext, clusterResourceListeners); + this.includeInternalTopics = includeInternalTopics; + this.allowAutoTopicCreation = allowAutoTopicCreation; + this.subscription = subscription; + this.transientTopics = new HashSet<>(); + } + + public boolean allowAutoTopicCreation() { + return allowAutoTopicCreation; + } + + @Override + public synchronized MetadataRequest.Builder newMetadataRequestBuilder() { + if (subscription.hasPatternSubscription()) + return MetadataRequest.Builder.allTopics(); + List topics = new ArrayList<>(); + topics.addAll(subscription.metadataTopics()); + topics.addAll(transientTopics); + return new MetadataRequest.Builder(topics, allowAutoTopicCreation); + } + + synchronized void addTransientTopics(Set topics) { + this.transientTopics.addAll(topics); + if (!fetch().topics().containsAll(topics)) + requestUpdateForNewTopics(); + } + + synchronized void clearTransientTopics() { + this.transientTopics.clear(); + } + + @Override + protected synchronized boolean retainTopic(String topic, boolean isInternal, long nowMs) { + if (transientTopics.contains(topic) || subscription.needsMetadata(topic)) + return true; + + if (isInternal && !includeInternalTopics) + return false; + + return subscription.matchesSubscribedPattern(topic); + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/ConsumerMetrics.java b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/ConsumerMetrics.java new file mode 100644 index 0000000..e58db82 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/ConsumerMetrics.java @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.clients.consumer.internals; + +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Set; + +import org.apache.kafka.common.MetricNameTemplate; +import org.apache.kafka.common.metrics.Metrics; + +public class ConsumerMetrics { + + public FetcherMetricsRegistry fetcherMetrics; + + public ConsumerMetrics(Set metricsTags, String metricGrpPrefix) { + this.fetcherMetrics = new FetcherMetricsRegistry(metricsTags, metricGrpPrefix); + } + + public ConsumerMetrics(String metricGroupPrefix) { + this(new HashSet(), metricGroupPrefix); + } + + private List getAllTemplates() { + List l = new ArrayList<>(this.fetcherMetrics.getAllTemplates()); + return l; + } + + public static void main(String[] args) { + Set tags = new HashSet<>(); + tags.add("client-id"); + ConsumerMetrics metrics = new ConsumerMetrics(tags, "consumer"); + System.out.println(Metrics.toHtmlTable("kafka.consumer", metrics.getAllTemplates())); + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/ConsumerNetworkClient.java b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/ConsumerNetworkClient.java new file mode 100644 index 0000000..4b91120 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/ConsumerNetworkClient.java @@ -0,0 +1,725 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.clients.consumer.internals; + +import org.apache.kafka.clients.ClientRequest; +import org.apache.kafka.clients.ClientResponse; +import org.apache.kafka.clients.KafkaClient; +import org.apache.kafka.clients.Metadata; +import org.apache.kafka.clients.RequestCompletionHandler; +import org.apache.kafka.common.Node; +import org.apache.kafka.common.errors.AuthenticationException; +import org.apache.kafka.common.errors.DisconnectException; +import org.apache.kafka.common.errors.InterruptException; +import org.apache.kafka.common.errors.TimeoutException; +import org.apache.kafka.common.errors.WakeupException; +import org.apache.kafka.common.requests.AbstractRequest; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.common.utils.Timer; +import org.slf4j.Logger; + +import java.io.Closeable; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.Iterator; +import java.util.List; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.locks.ReentrantLock; + +/** + * Higher level consumer access to the network layer with basic support for request futures. This class + * is thread-safe, but provides no synchronization for response callbacks. This guarantees that no locks + * are held when they are invoked. + */ +public class ConsumerNetworkClient implements Closeable { + private static final int MAX_POLL_TIMEOUT_MS = 5000; + + // the mutable state of this class is protected by the object's monitor (excluding the wakeup + // flag and the request completion queue below). + private final Logger log; + private final KafkaClient client; + private final UnsentRequests unsent = new UnsentRequests(); + private final Metadata metadata; + private final Time time; + private final long retryBackoffMs; + private final int maxPollTimeoutMs; + private final int requestTimeoutMs; + private final AtomicBoolean wakeupDisabled = new AtomicBoolean(); + + // We do not need high throughput, so use a fair lock to try to avoid starvation + private final ReentrantLock lock = new ReentrantLock(true); + + // when requests complete, they are transferred to this queue prior to invocation. The purpose + // is to avoid invoking them while holding this object's monitor which can open the door for deadlocks. + private final ConcurrentLinkedQueue pendingCompletion = new ConcurrentLinkedQueue<>(); + + private final ConcurrentLinkedQueue pendingDisconnects = new ConcurrentLinkedQueue<>(); + + // this flag allows the client to be safely woken up without waiting on the lock above. It is + // atomic to avoid the need to acquire the lock above in order to enable it concurrently. 
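+ // wakeup() sets this flag and wakes the underlying Selector; maybeTriggerWakeup() clears it again before
+ // throwing WakeupException, so a single user wakeup raises at most one exception on the polling thread.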
+ private final AtomicBoolean wakeup = new AtomicBoolean(false); + + public ConsumerNetworkClient(LogContext logContext, + KafkaClient client, + Metadata metadata, + Time time, + long retryBackoffMs, + int requestTimeoutMs, + int maxPollTimeoutMs) { + this.log = logContext.logger(ConsumerNetworkClient.class); + this.client = client; + this.metadata = metadata; + this.time = time; + this.retryBackoffMs = retryBackoffMs; + this.maxPollTimeoutMs = Math.min(maxPollTimeoutMs, MAX_POLL_TIMEOUT_MS); + this.requestTimeoutMs = requestTimeoutMs; + } + + public int defaultRequestTimeoutMs() { + return requestTimeoutMs; + } + + /** + * Send a request with the default timeout. See {@link #send(Node, AbstractRequest.Builder, int)}. + */ + public RequestFuture send(Node node, AbstractRequest.Builder requestBuilder) { + return send(node, requestBuilder, requestTimeoutMs); + } + + /** + * Send a new request. Note that the request is not actually transmitted on the + * network until one of the {@link #poll(Timer)} variants is invoked. At this + * point the request will either be transmitted successfully or will fail. + * Use the returned future to obtain the result of the send. Note that there is no + * need to check for disconnects explicitly on the {@link ClientResponse} object; + * instead, the future will be failed with a {@link DisconnectException}. + * + * @param node The destination of the request + * @param requestBuilder A builder for the request payload + * @param requestTimeoutMs Maximum time in milliseconds to await a response before disconnecting the socket and + * cancelling the request. The request may be cancelled sooner if the socket disconnects + * for any reason. + * @return A future which indicates the result of the send. + */ + public RequestFuture send(Node node, + AbstractRequest.Builder requestBuilder, + int requestTimeoutMs) { + long now = time.milliseconds(); + RequestFutureCompletionHandler completionHandler = new RequestFutureCompletionHandler(); + ClientRequest clientRequest = client.newClientRequest(node.idString(), requestBuilder, now, true, + requestTimeoutMs, completionHandler); + unsent.put(node, clientRequest); + + // wakeup the client in case it is blocking in poll so that we can send the queued request + client.wakeup(); + return completionHandler.future; + } + + public Node leastLoadedNode() { + lock.lock(); + try { + return client.leastLoadedNode(time.milliseconds()); + } finally { + lock.unlock(); + } + } + + public boolean hasReadyNodes(long now) { + lock.lock(); + try { + return client.hasReadyNodes(now); + } finally { + lock.unlock(); + } + } + + /** + * Block waiting on the metadata refresh with a timeout. + * + * @return true if update succeeded, false otherwise. + */ + public boolean awaitMetadataUpdate(Timer timer) { + int version = this.metadata.requestUpdate(); + do { + poll(timer); + } while (this.metadata.updateVersion() == version && timer.notExpired()); + return this.metadata.updateVersion() > version; + } + + /** + * Ensure our metadata is fresh (if an update is expected, this will block + * until it has completed). + */ + boolean ensureFreshMetadata(Timer timer) { + if (this.metadata.updateRequested() || this.metadata.timeToNextUpdate(timer.currentTimeMs()) == 0) { + return awaitMetadataUpdate(timer); + } else { + // the metadata is already fresh + return true; + } + } + + /** + * Wakeup an active poll. This will cause the polling thread to throw an exception either + * on the current poll if one is active, or the next poll. 
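+ * <p>
+ * A minimal usage sketch (variable names are illustrative, not part of this class):
+ * <pre>{@code
+ * // polling thread
+ * try {
+ *     networkClient.poll(time.timer(Long.MAX_VALUE));
+ * } catch (WakeupException e) {
+ *     // another thread called networkClient.wakeup(); abort the blocking operation
+ * }
+ *
+ * // any other thread
+ * networkClient.wakeup();
+ * }</pre>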
+ */ + public void wakeup() { + // wakeup should be safe without holding the client lock since it simply delegates to + // Selector's wakeup, which is thread-safe + log.debug("Received user wakeup"); + this.wakeup.set(true); + this.client.wakeup(); + } + + /** + * Block indefinitely until the given request future has finished. + * @param future The request future to await. + * @throws WakeupException if {@link #wakeup()} is called from another thread + * @throws InterruptException if the calling thread is interrupted + */ + public void poll(RequestFuture future) { + while (!future.isDone()) + poll(time.timer(Long.MAX_VALUE), future); + } + + /** + * Block until the provided request future request has finished or the timeout has expired. + * @param future The request future to wait for + * @param timer Timer bounding how long this method can block + * @return true if the future is done, false otherwise + * @throws WakeupException if {@link #wakeup()} is called from another thread + * @throws InterruptException if the calling thread is interrupted + */ + public boolean poll(RequestFuture future, Timer timer) { + do { + poll(timer, future); + } while (!future.isDone() && timer.notExpired()); + return future.isDone(); + } + + /** + * Poll for any network IO. + * @param timer Timer bounding how long this method can block + * @throws WakeupException if {@link #wakeup()} is called from another thread + * @throws InterruptException if the calling thread is interrupted + */ + public void poll(Timer timer) { + poll(timer, null); + } + + /** + * Poll for any network IO. + * @param timer Timer bounding how long this method can block + * @param pollCondition Nullable blocking condition + */ + public void poll(Timer timer, PollCondition pollCondition) { + poll(timer, pollCondition, false); + } + + /** + * Poll for any network IO. + * @param timer Timer bounding how long this method can block + * @param pollCondition Nullable blocking condition + * @param disableWakeup If TRUE disable triggering wake-ups + */ + public void poll(Timer timer, PollCondition pollCondition, boolean disableWakeup) { + // there may be handlers which need to be invoked if we woke up the previous call to poll + firePendingCompletedRequests(); + + lock.lock(); + try { + // Handle async disconnects prior to attempting any sends + handlePendingDisconnects(); + + // send all the requests we can send now + long pollDelayMs = trySend(timer.currentTimeMs()); + + // check whether the poll is still needed by the caller. Note that if the expected completion + // condition becomes satisfied after the call to shouldBlock() (because of a fired completion + // handler), the client will be woken up. + if (pendingCompletion.isEmpty() && (pollCondition == null || pollCondition.shouldBlock())) { + // if there are no requests in flight, do not block longer than the retry backoff + long pollTimeout = Math.min(timer.remainingMs(), pollDelayMs); + if (client.inFlightRequestCount() == 0) + pollTimeout = Math.min(pollTimeout, retryBackoffMs); + client.poll(pollTimeout, timer.currentTimeMs()); + } else { + client.poll(0, timer.currentTimeMs()); + } + timer.update(); + + // handle any disconnects by failing the active requests. 
note that disconnects must + // be checked immediately following poll since any subsequent call to client.ready() + // will reset the disconnect status + checkDisconnects(timer.currentTimeMs()); + if (!disableWakeup) { + // trigger wakeups after checking for disconnects so that the callbacks will be ready + // to be fired on the next call to poll() + maybeTriggerWakeup(); + } + // throw InterruptException if this thread is interrupted + maybeThrowInterruptException(); + + // try again to send requests since buffer space may have been + // cleared or a connect finished in the poll + trySend(timer.currentTimeMs()); + + // fail requests that couldn't be sent if they have expired + failExpiredRequests(timer.currentTimeMs()); + + // clean unsent requests collection to keep the map from growing indefinitely + unsent.clean(); + } finally { + lock.unlock(); + } + + // called without the lock to avoid deadlock potential if handlers need to acquire locks + firePendingCompletedRequests(); + + metadata.maybeThrowAnyException(); + } + + /** + * Poll for network IO and return immediately. This will not trigger wakeups. + */ + public void pollNoWakeup() { + poll(time.timer(0), null, true); + } + + /** + * Poll for network IO in best-effort only trying to transmit the ready-to-send request + * Do not check any pending requests or metadata errors so that no exception should ever + * be thrown, also no wakeups be triggered and no interrupted exception either. + */ + public void transmitSends() { + Timer timer = time.timer(0); + + // do not try to handle any disconnects, prev request failures, metadata exception etc; + // just try once and return immediately + lock.lock(); + try { + // send all the requests we can send now + trySend(timer.currentTimeMs()); + + client.poll(0, timer.currentTimeMs()); + } finally { + lock.unlock(); + } + } + + /** + * Block until all pending requests from the given node have finished. + * @param node The node to await requests from + * @param timer Timer bounding how long this method can block + * @return true If all requests finished, false if the timeout expired first + */ + public boolean awaitPendingRequests(Node node, Timer timer) { + while (hasPendingRequests(node) && timer.notExpired()) { + poll(timer); + } + return !hasPendingRequests(node); + } + + /** + * Get the count of pending requests to the given node. This includes both request that + * have been transmitted (i.e. in-flight requests) and those which are awaiting transmission. + * @param node The node in question + * @return The number of pending requests + */ + public int pendingRequestCount(Node node) { + lock.lock(); + try { + return unsent.requestCount(node) + client.inFlightRequestCount(node.idString()); + } finally { + lock.unlock(); + } + } + + /** + * Check whether there is pending request to the given node. This includes both request that + * have been transmitted (i.e. in-flight requests) and those which are awaiting transmission. + * @param node The node in question + * @return A boolean indicating whether there is pending request + */ + public boolean hasPendingRequests(Node node) { + if (unsent.hasRequests(node)) + return true; + lock.lock(); + try { + return client.hasInFlightRequests(node.idString()); + } finally { + lock.unlock(); + } + } + + /** + * Get the total count of pending requests from all nodes. This includes both requests that + * have been transmitted (i.e. in-flight requests) and those which are awaiting transmission. 
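+ * The total combines the locally tracked unsent requests with the in-flight count reported by the underlying
+ * {@link KafkaClient}, and is computed while holding the client lock.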
+ * @return The total count of pending requests + */ + public int pendingRequestCount() { + lock.lock(); + try { + return unsent.requestCount() + client.inFlightRequestCount(); + } finally { + lock.unlock(); + } + } + + /** + * Check whether there is pending request. This includes both requests that + * have been transmitted (i.e. in-flight requests) and those which are awaiting transmission. + * @return A boolean indicating whether there is pending request + */ + public boolean hasPendingRequests() { + if (unsent.hasRequests()) + return true; + lock.lock(); + try { + return client.hasInFlightRequests(); + } finally { + lock.unlock(); + } + } + + private void firePendingCompletedRequests() { + boolean completedRequestsFired = false; + for (;;) { + RequestFutureCompletionHandler completionHandler = pendingCompletion.poll(); + if (completionHandler == null) + break; + + completionHandler.fireCompletion(); + completedRequestsFired = true; + } + + // wakeup the client in case it is blocking in poll for this future's completion + if (completedRequestsFired) + client.wakeup(); + } + + private void checkDisconnects(long now) { + // any disconnects affecting requests that have already been transmitted will be handled + // by NetworkClient, so we just need to check whether connections for any of the unsent + // requests have been disconnected; if they have, then we complete the corresponding future + // and set the disconnect flag in the ClientResponse + for (Node node : unsent.nodes()) { + if (client.connectionFailed(node)) { + // Remove entry before invoking request callback to avoid callbacks handling + // coordinator failures traversing the unsent list again. + Collection requests = unsent.remove(node); + for (ClientRequest request : requests) { + RequestFutureCompletionHandler handler = (RequestFutureCompletionHandler) request.callback(); + AuthenticationException authenticationException = client.authenticationException(node); + handler.onComplete(new ClientResponse(request.makeHeader(request.requestBuilder().latestAllowedVersion()), + request.callback(), request.destination(), request.createdTimeMs(), now, true, + null, authenticationException, null)); + } + } + } + } + + private void handlePendingDisconnects() { + lock.lock(); + try { + while (true) { + Node node = pendingDisconnects.poll(); + if (node == null) + break; + + failUnsentRequests(node, DisconnectException.INSTANCE); + client.disconnect(node.idString()); + } + } finally { + lock.unlock(); + } + } + + public void disconnectAsync(Node node) { + pendingDisconnects.offer(node); + client.wakeup(); + } + + private void failExpiredRequests(long now) { + // clear all expired unsent requests and fail their corresponding futures + Collection expiredRequests = unsent.removeExpiredRequests(now); + for (ClientRequest request : expiredRequests) { + RequestFutureCompletionHandler handler = (RequestFutureCompletionHandler) request.callback(); + handler.onFailure(new TimeoutException("Failed to send request after " + request.requestTimeoutMs() + " ms.")); + } + } + + private void failUnsentRequests(Node node, RuntimeException e) { + // clear unsent requests to node and fail their corresponding futures + lock.lock(); + try { + Collection unsentRequests = unsent.remove(node); + for (ClientRequest unsentRequest : unsentRequests) { + RequestFutureCompletionHandler handler = (RequestFutureCompletionHandler) unsentRequest.callback(); + handler.onFailure(e); + } + } finally { + lock.unlock(); + } + } + + // Visible for testing + long trySend(long now) { 
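+ // For each node with unsent requests, fold that node's connection delay into the returned poll delay, then
+ // send requests while client.ready(node, now) holds; once the node is not ready, its remaining requests stay
+ // queued (hence the break below) and will be retried on a later call.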
+ long pollDelayMs = maxPollTimeoutMs; + + // send any requests that can be sent now + for (Node node : unsent.nodes()) { + Iterator iterator = unsent.requestIterator(node); + if (iterator.hasNext()) + pollDelayMs = Math.min(pollDelayMs, client.pollDelayMs(node, now)); + + while (iterator.hasNext()) { + ClientRequest request = iterator.next(); + if (client.ready(node, now)) { + client.send(request, now); + iterator.remove(); + } else { + // try next node when current node is not ready + break; + } + } + } + return pollDelayMs; + } + + public void maybeTriggerWakeup() { + if (!wakeupDisabled.get() && wakeup.get()) { + log.debug("Raising WakeupException in response to user wakeup"); + wakeup.set(false); + throw new WakeupException(); + } + } + + private void maybeThrowInterruptException() { + if (Thread.interrupted()) { + throw new InterruptException(new InterruptedException()); + } + } + + public void disableWakeups() { + wakeupDisabled.set(true); + } + + @Override + public void close() throws IOException { + lock.lock(); + try { + client.close(); + } finally { + lock.unlock(); + } + } + + + /** + * Check if the code is disconnected and unavailable for immediate reconnection (i.e. if it is in + * reconnect backoff window following the disconnect). + */ + public boolean isUnavailable(Node node) { + lock.lock(); + try { + return client.connectionFailed(node) && client.connectionDelay(node, time.milliseconds()) > 0; + } finally { + lock.unlock(); + } + } + + /** + * Check for an authentication error on a given node and raise the exception if there is one. + */ + public void maybeThrowAuthFailure(Node node) { + lock.lock(); + try { + AuthenticationException exception = client.authenticationException(node); + if (exception != null) + throw exception; + } finally { + lock.unlock(); + } + } + + /** + * Initiate a connection if currently possible. This is only really useful for resetting the failed + * status of a socket. If there is an actual request to send, then {@link #send(Node, AbstractRequest.Builder)} + * should be used. + * @param node The node to connect to + */ + public void tryConnect(Node node) { + lock.lock(); + try { + client.ready(node, time.milliseconds()); + } finally { + lock.unlock(); + } + } + + private class RequestFutureCompletionHandler implements RequestCompletionHandler { + private final RequestFuture future; + private ClientResponse response; + private RuntimeException e; + + private RequestFutureCompletionHandler() { + this.future = new RequestFuture<>(); + } + + public void fireCompletion() { + if (e != null) { + future.raise(e); + } else if (response.authenticationException() != null) { + future.raise(response.authenticationException()); + } else if (response.wasDisconnected()) { + log.debug("Cancelled request with header {} due to node {} being disconnected", + response.requestHeader(), response.destination()); + future.raise(DisconnectException.INSTANCE); + } else if (response.versionMismatch() != null) { + future.raise(response.versionMismatch()); + } else { + future.complete(response); + } + } + + public void onFailure(RuntimeException e) { + this.e = e; + pendingCompletion.add(this); + } + + @Override + public void onComplete(ClientResponse response) { + this.response = response; + pendingCompletion.add(this); + } + } + + /** + * When invoking poll from a multi-threaded environment, it is possible that the condition that + * the caller is awaiting has already been satisfied prior to the invocation of poll. 
We therefore + * introduce this interface to push the condition checking as close as possible to the invocation + * of poll. In particular, the check will be done while holding the lock used to protect concurrent + * access to {@link org.apache.kafka.clients.NetworkClient}, which means implementations must be + * very careful about locking order if the callback must acquire additional locks. + */ + public interface PollCondition { + /** + * Return whether the caller is still awaiting an IO event. + * @return true if so, false otherwise. + */ + boolean shouldBlock(); + } + + /* + * A thread-safe helper class to hold requests per node that have not been sent yet + */ + private static final class UnsentRequests { + private final ConcurrentMap> unsent; + + private UnsentRequests() { + unsent = new ConcurrentHashMap<>(); + } + + public void put(Node node, ClientRequest request) { + // the lock protects the put from a concurrent removal of the queue for the node + synchronized (unsent) { + ConcurrentLinkedQueue requests = unsent.computeIfAbsent(node, key -> new ConcurrentLinkedQueue<>()); + requests.add(request); + } + } + + public int requestCount(Node node) { + ConcurrentLinkedQueue requests = unsent.get(node); + return requests == null ? 0 : requests.size(); + } + + public int requestCount() { + int total = 0; + for (ConcurrentLinkedQueue requests : unsent.values()) + total += requests.size(); + return total; + } + + public boolean hasRequests(Node node) { + ConcurrentLinkedQueue requests = unsent.get(node); + return requests != null && !requests.isEmpty(); + } + + public boolean hasRequests() { + for (ConcurrentLinkedQueue requests : unsent.values()) + if (!requests.isEmpty()) + return true; + return false; + } + + private Collection removeExpiredRequests(long now) { + List expiredRequests = new ArrayList<>(); + for (ConcurrentLinkedQueue requests : unsent.values()) { + Iterator requestIterator = requests.iterator(); + while (requestIterator.hasNext()) { + ClientRequest request = requestIterator.next(); + long elapsedMs = Math.max(0, now - request.createdTimeMs()); + if (elapsedMs > request.requestTimeoutMs()) { + expiredRequests.add(request); + requestIterator.remove(); + } else + break; + } + } + return expiredRequests; + } + + public void clean() { + // the lock protects removal from a concurrent put which could otherwise mutate the + // queue after it has been removed from the map + synchronized (unsent) { + Iterator> iterator = unsent.values().iterator(); + while (iterator.hasNext()) { + ConcurrentLinkedQueue requests = iterator.next(); + if (requests.isEmpty()) + iterator.remove(); + } + } + } + + public Collection remove(Node node) { + // the lock protects removal from a concurrent put which could otherwise mutate the + // queue after it has been removed from the map + synchronized (unsent) { + ConcurrentLinkedQueue requests = unsent.remove(node); + return requests == null ? Collections.emptyList() : requests; + } + } + + public Iterator requestIterator(Node node) { + ConcurrentLinkedQueue requests = unsent.get(node); + return requests == null ? 
Collections.emptyIterator() : requests.iterator(); + } + + public Collection nodes() { + return unsent.keySet(); + } + } + +} diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/ConsumerProtocol.java b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/ConsumerProtocol.java new file mode 100644 index 0000000..7df0fe8 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/ConsumerProtocol.java @@ -0,0 +1,186 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients.consumer.internals; + +import org.apache.kafka.clients.consumer.ConsumerPartitionAssignor.Assignment; +import org.apache.kafka.clients.consumer.ConsumerPartitionAssignor.Subscription; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.message.ConsumerProtocolAssignment; +import org.apache.kafka.common.message.ConsumerProtocolSubscription; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.protocol.MessageUtil; +import org.apache.kafka.common.protocol.types.SchemaException; + +import java.nio.ByteBuffer; +import java.nio.BufferUnderflowException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.Comparator; +import java.util.List; + +/** + * ConsumerProtocol contains the schemas for consumer subscriptions and assignments for use with + * Kafka's generalized group management protocol. + * + * The current implementation assumes that future versions will not break compatibility. When + * it encounters a newer version, it parses it using the current format. This basically means + * that new versions cannot remove or reorder any of the existing fields. + */ +public class ConsumerProtocol { + public static final String PROTOCOL_TYPE = "consumer"; + + static { + // Safety check to ensure that both parts of the consumer protocol remain in sync. 
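+ // The subscription and assignment messages evolve together as a single consumer protocol version, so their
+ // generated schemas must support the same version range; any mismatch fails fast when the class is loaded.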
+ if (ConsumerProtocolSubscription.LOWEST_SUPPORTED_VERSION + != ConsumerProtocolAssignment.LOWEST_SUPPORTED_VERSION) + throw new IllegalStateException("Subscription and Assignment schemas must have the " + + "same lowest version"); + + if (ConsumerProtocolSubscription.HIGHEST_SUPPORTED_VERSION + != ConsumerProtocolAssignment.HIGHEST_SUPPORTED_VERSION) + throw new IllegalStateException("Subscription and Assignment schemas must have the " + + "same highest version"); + } + + public static short deserializeVersion(final ByteBuffer buffer) { + try { + return buffer.getShort(); + } catch (BufferUnderflowException e) { + throw new SchemaException("Buffer underflow while parsing consumer protocol's header", e); + } + } + + public static ByteBuffer serializeSubscription(final Subscription subscription) { + return serializeSubscription(subscription, ConsumerProtocolSubscription.HIGHEST_SUPPORTED_VERSION); + } + + public static ByteBuffer serializeSubscription(final Subscription subscription, short version) { + version = checkSubscriptionVersion(version); + + ConsumerProtocolSubscription data = new ConsumerProtocolSubscription(); + + List topics = new ArrayList<>(subscription.topics()); + Collections.sort(topics); + data.setTopics(topics); + + data.setUserData(subscription.userData() != null ? subscription.userData().duplicate() : null); + + List ownedPartitions = new ArrayList<>(subscription.ownedPartitions()); + ownedPartitions.sort(Comparator.comparing(TopicPartition::topic).thenComparing(TopicPartition::partition)); + ConsumerProtocolSubscription.TopicPartition partition = null; + for (TopicPartition tp : ownedPartitions) { + if (partition == null || !partition.topic().equals(tp.topic())) { + partition = new ConsumerProtocolSubscription.TopicPartition().setTopic(tp.topic()); + data.ownedPartitions().add(partition); + } + partition.partitions().add(tp.partition()); + } + + return MessageUtil.toVersionPrefixedByteBuffer(version, data); + } + + public static Subscription deserializeSubscription(final ByteBuffer buffer, short version) { + version = checkSubscriptionVersion(version); + + try { + ConsumerProtocolSubscription data = + new ConsumerProtocolSubscription(new ByteBufferAccessor(buffer), version); + + List ownedPartitions = new ArrayList<>(); + for (ConsumerProtocolSubscription.TopicPartition tp : data.ownedPartitions()) { + for (Integer partition : tp.partitions()) { + ownedPartitions.add(new TopicPartition(tp.topic(), partition)); + } + } + + return new Subscription( + data.topics(), + data.userData() != null ? data.userData().duplicate() : null, + ownedPartitions); + } catch (BufferUnderflowException e) { + throw new SchemaException("Buffer underflow while parsing consumer protocol's subscription", e); + } + } + + public static Subscription deserializeSubscription(final ByteBuffer buffer) { + return deserializeSubscription(buffer, deserializeVersion(buffer)); + } + + public static ByteBuffer serializeAssignment(final Assignment assignment) { + return serializeAssignment(assignment, ConsumerProtocolAssignment.HIGHEST_SUPPORTED_VERSION); + } + + public static ByteBuffer serializeAssignment(final Assignment assignment, short version) { + version = checkAssignmentVersion(version); + + ConsumerProtocolAssignment data = new ConsumerProtocolAssignment(); + data.setUserData(assignment.userData() != null ? 
assignment.userData().duplicate() : null); + assignment.partitions().forEach(tp -> { + ConsumerProtocolAssignment.TopicPartition partition = data.assignedPartitions().find(tp.topic()); + if (partition == null) { + partition = new ConsumerProtocolAssignment.TopicPartition().setTopic(tp.topic()); + data.assignedPartitions().add(partition); + } + partition.partitions().add(tp.partition()); + }); + return MessageUtil.toVersionPrefixedByteBuffer(version, data); + } + + public static Assignment deserializeAssignment(final ByteBuffer buffer, short version) { + version = checkAssignmentVersion(version); + + try { + ConsumerProtocolAssignment data = + new ConsumerProtocolAssignment(new ByteBufferAccessor(buffer), version); + + List assignedPartitions = new ArrayList<>(); + for (ConsumerProtocolAssignment.TopicPartition tp : data.assignedPartitions()) { + for (Integer partition : tp.partitions()) { + assignedPartitions.add(new TopicPartition(tp.topic(), partition)); + } + } + + return new Assignment( + assignedPartitions, + data.userData() != null ? data.userData().duplicate() : null); + } catch (BufferUnderflowException e) { + throw new SchemaException("Buffer underflow while parsing consumer protocol's assignment", e); + } + } + + public static Assignment deserializeAssignment(final ByteBuffer buffer) { + return deserializeAssignment(buffer, deserializeVersion(buffer)); + } + + private static short checkSubscriptionVersion(final short version) { + if (version < ConsumerProtocolSubscription.LOWEST_SUPPORTED_VERSION) + throw new SchemaException("Unsupported subscription version: " + version); + else if (version > ConsumerProtocolSubscription.HIGHEST_SUPPORTED_VERSION) + return ConsumerProtocolSubscription.HIGHEST_SUPPORTED_VERSION; + else + return version; + } + + private static short checkAssignmentVersion(final short version) { + if (version < ConsumerProtocolAssignment.LOWEST_SUPPORTED_VERSION) + throw new SchemaException("Unsupported assignment version: " + version); + else if (version > ConsumerProtocolAssignment.HIGHEST_SUPPORTED_VERSION) + return ConsumerProtocolAssignment.HIGHEST_SUPPORTED_VERSION; + else + return version; + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/Fetcher.java b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/Fetcher.java new file mode 100644 index 0000000..d567f5b --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/Fetcher.java @@ -0,0 +1,1961 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.clients.consumer.internals; + +import org.apache.kafka.clients.ApiVersions; +import org.apache.kafka.clients.ClientResponse; +import org.apache.kafka.clients.FetchSessionHandler; +import org.apache.kafka.clients.Metadata; +import org.apache.kafka.clients.NodeApiVersions; +import org.apache.kafka.clients.StaleMetadataException; +import org.apache.kafka.clients.consumer.ConsumerConfig; +import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.clients.consumer.LogTruncationException; +import org.apache.kafka.clients.consumer.OffsetAndMetadata; +import org.apache.kafka.clients.consumer.OffsetAndTimestamp; +import org.apache.kafka.clients.consumer.OffsetOutOfRangeException; +import org.apache.kafka.clients.consumer.OffsetResetStrategy; +import org.apache.kafka.clients.consumer.internals.OffsetsForLeaderEpochClient.OffsetForEpochResult; +import org.apache.kafka.clients.consumer.internals.SubscriptionState.FetchPosition; +import org.apache.kafka.common.Cluster; +import org.apache.kafka.common.IsolationLevel; +import org.apache.kafka.common.KafkaException; +import org.apache.kafka.common.MetricName; +import org.apache.kafka.common.Node; +import org.apache.kafka.common.PartitionInfo; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.Uuid; +import org.apache.kafka.common.errors.CorruptRecordException; +import org.apache.kafka.common.errors.InvalidTopicException; +import org.apache.kafka.common.errors.RecordDeserializationException; +import org.apache.kafka.common.errors.RecordTooLargeException; +import org.apache.kafka.common.errors.RetriableException; +import org.apache.kafka.common.errors.SerializationException; +import org.apache.kafka.common.errors.TimeoutException; +import org.apache.kafka.common.errors.TopicAuthorizationException; +import org.apache.kafka.common.header.Headers; +import org.apache.kafka.common.header.internals.RecordHeaders; +import org.apache.kafka.common.message.FetchResponseData; +import org.apache.kafka.common.message.ApiVersionsResponseData.ApiVersion; +import org.apache.kafka.common.message.ListOffsetsRequestData.ListOffsetsPartition; +import org.apache.kafka.common.message.ListOffsetsResponseData.ListOffsetsPartitionResponse; +import org.apache.kafka.common.message.ListOffsetsResponseData.ListOffsetsTopicResponse; +import org.apache.kafka.common.metrics.Gauge; +import org.apache.kafka.common.metrics.Metrics; +import org.apache.kafka.common.metrics.Sensor; +import org.apache.kafka.common.metrics.stats.Avg; +import org.apache.kafka.common.metrics.stats.Max; +import org.apache.kafka.common.metrics.stats.Meter; +import org.apache.kafka.common.metrics.stats.Min; +import org.apache.kafka.common.metrics.stats.Value; +import org.apache.kafka.common.metrics.stats.WindowedCount; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.Errors; +import org.apache.kafka.common.utils.BufferSupplier; +import org.apache.kafka.common.record.ControlRecordType; +import org.apache.kafka.common.record.Record; +import org.apache.kafka.common.record.RecordBatch; +import org.apache.kafka.common.record.TimestampType; +import org.apache.kafka.common.requests.FetchRequest; +import org.apache.kafka.common.requests.FetchResponse; +import org.apache.kafka.common.requests.ListOffsetsRequest; +import org.apache.kafka.common.requests.ListOffsetsResponse; +import org.apache.kafka.common.requests.MetadataRequest; +import org.apache.kafka.common.requests.MetadataResponse; +import 
org.apache.kafka.common.requests.OffsetsForLeaderEpochRequest; +import org.apache.kafka.common.serialization.Deserializer; +import org.apache.kafka.common.utils.CloseableIterator; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.common.utils.Timer; +import org.apache.kafka.common.utils.Utils; +import org.slf4j.Logger; +import org.slf4j.helpers.MessageFormatter; + +import java.io.Closeable; +import java.nio.ByteBuffer; +import java.util.ArrayDeque; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.Comparator; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Iterator; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.PriorityQueue; +import java.util.Queue; +import java.util.Set; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Function; +import java.util.stream.Collectors; + +import static java.util.Collections.emptyList; + +/** + * This class manages the fetching process with the brokers. + *
<p>
+ * Thread-safety:
+ * Requests and responses of Fetcher may be processed by different threads since heartbeat
+ * thread may process responses. Other operations are single-threaded and invoked only from
+ * the thread polling the consumer.
+ * <ul>
+ *     <li>If a response handler accesses any shared state of the Fetcher (e.g. FetchSessionHandler),
+ *     all access to that state must be synchronized on the Fetcher instance.</li>
+ *     <li>If a response handler accesses any shared state of the coordinator (e.g. SubscriptionState),
+ *     it is assumed that all access to that state is synchronized on the coordinator instance by
+ *     the caller.</li>
+ *     <li>Responses that collate partial responses from multiple brokers (e.g. to list offsets) are
+ *     synchronized on the response future.</li>
+ *     <li>At most one request is pending for each node at any time. Nodes with pending requests are
+ *     tracked and updated after processing the response. This ensures that any state (e.g. epoch)
+ *     updated while processing responses on one thread are visible while creating the subsequent request
+ *     on a different thread.</li>
+ * </ul>
+ */ +public class Fetcher implements Closeable { + private final Logger log; + private final LogContext logContext; + private final ConsumerNetworkClient client; + private final Time time; + private final int minBytes; + private final int maxBytes; + private final int maxWaitMs; + private final int fetchSize; + private final long retryBackoffMs; + private final long requestTimeoutMs; + private final int maxPollRecords; + private final boolean checkCrcs; + private final String clientRackId; + private final ConsumerMetadata metadata; + private final FetchManagerMetrics sensors; + private final SubscriptionState subscriptions; + private final ConcurrentLinkedQueue completedFetches; + private final BufferSupplier decompressionBufferSupplier = BufferSupplier.create(); + private final Deserializer keyDeserializer; + private final Deserializer valueDeserializer; + private final IsolationLevel isolationLevel; + private final Map sessionHandlers; + private final AtomicReference cachedListOffsetsException = new AtomicReference<>(); + private final AtomicReference cachedOffsetForLeaderException = new AtomicReference<>(); + private final OffsetsForLeaderEpochClient offsetsForLeaderEpochClient; + private final Set nodesWithPendingFetchRequests; + private final ApiVersions apiVersions; + private final AtomicInteger metadataUpdateVersion = new AtomicInteger(-1); + + private CompletedFetch nextInLineFetch = null; + + public Fetcher(LogContext logContext, + ConsumerNetworkClient client, + int minBytes, + int maxBytes, + int maxWaitMs, + int fetchSize, + int maxPollRecords, + boolean checkCrcs, + String clientRackId, + Deserializer keyDeserializer, + Deserializer valueDeserializer, + ConsumerMetadata metadata, + SubscriptionState subscriptions, + Metrics metrics, + FetcherMetricsRegistry metricsRegistry, + Time time, + long retryBackoffMs, + long requestTimeoutMs, + IsolationLevel isolationLevel, + ApiVersions apiVersions) { + this.log = logContext.logger(Fetcher.class); + this.logContext = logContext; + this.time = time; + this.client = client; + this.metadata = metadata; + this.subscriptions = subscriptions; + this.minBytes = minBytes; + this.maxBytes = maxBytes; + this.maxWaitMs = maxWaitMs; + this.fetchSize = fetchSize; + this.maxPollRecords = maxPollRecords; + this.checkCrcs = checkCrcs; + this.clientRackId = clientRackId; + this.keyDeserializer = keyDeserializer; + this.valueDeserializer = valueDeserializer; + this.completedFetches = new ConcurrentLinkedQueue<>(); + this.sensors = new FetchManagerMetrics(metrics, metricsRegistry); + this.retryBackoffMs = retryBackoffMs; + this.requestTimeoutMs = requestTimeoutMs; + this.isolationLevel = isolationLevel; + this.apiVersions = apiVersions; + this.sessionHandlers = new HashMap<>(); + this.offsetsForLeaderEpochClient = new OffsetsForLeaderEpochClient(client, logContext); + this.nodesWithPendingFetchRequests = new HashSet<>(); + } + + /** + * Represents data about an offset returned by a broker. + */ + static class ListOffsetData { + final long offset; + final Long timestamp; // null if the broker does not support returning timestamps + final Optional leaderEpoch; // empty if the leader epoch is not known + + ListOffsetData(long offset, Long timestamp, Optional leaderEpoch) { + this.offset = offset; + this.timestamp = timestamp; + this.leaderEpoch = leaderEpoch; + } + } + + /** + * Return whether we have any completed fetches pending return to the user. This method is thread-safe. Has + * visibility for testing. 
+ * @return true if there are completed fetches, false otherwise + */ + protected boolean hasCompletedFetches() { + return !completedFetches.isEmpty(); + } + + /** + * Return whether we have any completed fetches that are fetchable. This method is thread-safe. + * @return true if there are completed fetches that can be returned, false otherwise + */ + public boolean hasAvailableFetches() { + return completedFetches.stream().anyMatch(fetch -> subscriptions.isFetchable(fetch.partition)); + } + + /** + * Set-up a fetch request for any node that we have assigned partitions for which doesn't already have + * an in-flight fetch or pending fetch data. + * @return number of fetches sent + */ + public synchronized int sendFetches() { + // Update metrics in case there was an assignment change + sensors.maybeUpdateAssignment(subscriptions); + + Map fetchRequestMap = prepareFetchRequests(); + for (Map.Entry entry : fetchRequestMap.entrySet()) { + final Node fetchTarget = entry.getKey(); + final FetchSessionHandler.FetchRequestData data = entry.getValue(); + final short maxVersion; + if (!data.canUseTopicIds()) { + maxVersion = (short) 12; + } else { + maxVersion = ApiKeys.FETCH.latestVersion(); + } + final FetchRequest.Builder request = FetchRequest.Builder + .forConsumer(maxVersion, this.maxWaitMs, this.minBytes, data.toSend()) + .isolationLevel(isolationLevel) + .setMaxBytes(this.maxBytes) + .metadata(data.metadata()) + .removed(data.toForget()) + .replaced(data.toReplace()) + .rackId(clientRackId); + + if (log.isDebugEnabled()) { + log.debug("Sending {} {} to broker {}", isolationLevel, data.toString(), fetchTarget); + } + RequestFuture future = client.send(fetchTarget, request); + // We add the node to the set of nodes with pending fetch requests before adding the + // listener because the future may have been fulfilled on another thread (e.g. during a + // disconnection being handled by the heartbeat thread) which will mean the listener + // will be invoked synchronously. + this.nodesWithPendingFetchRequests.add(entry.getKey().id()); + future.addListener(new RequestFutureListener() { + @Override + public void onSuccess(ClientResponse resp) { + synchronized (Fetcher.this) { + try { + FetchResponse response = (FetchResponse) resp.responseBody(); + FetchSessionHandler handler = sessionHandler(fetchTarget.id()); + if (handler == null) { + log.error("Unable to find FetchSessionHandler for node {}. 
Ignoring fetch response.", + fetchTarget.id()); + return; + } + if (!handler.handleResponse(response, resp.requestHeader().apiVersion())) { + if (response.error() == Errors.FETCH_SESSION_TOPIC_ID_ERROR) { + metadata.requestUpdate(); + } + return; + } + + Map responseData = response.responseData(handler.sessionTopicNames(), resp.requestHeader().apiVersion()); + Set partitions = new HashSet<>(responseData.keySet()); + FetchResponseMetricAggregator metricAggregator = new FetchResponseMetricAggregator(sensors, partitions); + + for (Map.Entry entry : responseData.entrySet()) { + TopicPartition partition = entry.getKey(); + FetchRequest.PartitionData requestData = data.sessionPartitions().get(partition); + if (requestData == null) { + String message; + if (data.metadata().isFull()) { + message = MessageFormatter.arrayFormat( + "Response for missing full request partition: partition={}; metadata={}", + new Object[]{partition, data.metadata()}).getMessage(); + } else { + message = MessageFormatter.arrayFormat( + "Response for missing session request partition: partition={}; metadata={}; toSend={}; toForget={}; toReplace={}", + new Object[]{partition, data.metadata(), data.toSend(), data.toForget(), data.toReplace()}).getMessage(); + } + + // Received fetch response for missing session partition + throw new IllegalStateException(message); + } else { + long fetchOffset = requestData.fetchOffset; + FetchResponseData.PartitionData partitionData = entry.getValue(); + + log.debug("Fetch {} at offset {} for partition {} returned fetch data {}", + isolationLevel, fetchOffset, partition, partitionData); + + Iterator batches = FetchResponse.recordsOrFail(partitionData).batches().iterator(); + short responseVersion = resp.requestHeader().apiVersion(); + + completedFetches.add(new CompletedFetch(partition, partitionData, + metricAggregator, batches, fetchOffset, responseVersion)); + } + } + + sensors.fetchLatency.record(resp.requestLatencyMs()); + } finally { + nodesWithPendingFetchRequests.remove(fetchTarget.id()); + } + } + } + + @Override + public void onFailure(RuntimeException e) { + synchronized (Fetcher.this) { + try { + FetchSessionHandler handler = sessionHandler(fetchTarget.id()); + if (handler != null) { + handler.handleError(e); + } + } finally { + nodesWithPendingFetchRequests.remove(fetchTarget.id()); + } + } + } + }); + + } + return fetchRequestMap.size(); + } + + /** + * Get topic metadata for all topics in the cluster + * @param timer Timer bounding how long this method can block + * @return The map of topics with their partition information + */ + public Map> getAllTopicMetadata(Timer timer) { + return getTopicMetadata(MetadataRequest.Builder.allTopics(), timer); + } + + /** + * Get metadata for all topics present in Kafka cluster + * + * @param request The MetadataRequest to send + * @param timer Timer bounding how long this method can block + * @return The map of topics with their partition information + */ + public Map> getTopicMetadata(MetadataRequest.Builder request, Timer timer) { + // Save the round trip if no topics are requested. 
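+ // (an explicit, empty topic list means the caller requested metadata for no topics, so the result is trivially empty)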
+ if (!request.isAllTopics() && request.emptyTopicList()) + return Collections.emptyMap(); + + do { + RequestFuture future = sendMetadataRequest(request); + client.poll(future, timer); + + if (future.failed() && !future.isRetriable()) + throw future.exception(); + + if (future.succeeded()) { + MetadataResponse response = (MetadataResponse) future.value().responseBody(); + Cluster cluster = response.buildCluster(); + + Set unauthorizedTopics = cluster.unauthorizedTopics(); + if (!unauthorizedTopics.isEmpty()) + throw new TopicAuthorizationException(unauthorizedTopics); + + boolean shouldRetry = false; + Map errors = response.errors(); + if (!errors.isEmpty()) { + // if there were errors, we need to check whether they were fatal or whether + // we should just retry + + log.debug("Topic metadata fetch included errors: {}", errors); + + for (Map.Entry errorEntry : errors.entrySet()) { + String topic = errorEntry.getKey(); + Errors error = errorEntry.getValue(); + + if (error == Errors.INVALID_TOPIC_EXCEPTION) + throw new InvalidTopicException("Topic '" + topic + "' is invalid"); + else if (error == Errors.UNKNOWN_TOPIC_OR_PARTITION) + // if a requested topic is unknown, we just continue and let it be absent + // in the returned map + continue; + else if (error.exception() instanceof RetriableException) + shouldRetry = true; + else + throw new KafkaException("Unexpected error fetching metadata for topic " + topic, + error.exception()); + } + } + + if (!shouldRetry) { + HashMap> topicsPartitionInfos = new HashMap<>(); + for (String topic : cluster.topics()) + topicsPartitionInfos.put(topic, cluster.partitionsForTopic(topic)); + return topicsPartitionInfos; + } + } + + timer.sleep(retryBackoffMs); + } while (timer.notExpired()); + + throw new TimeoutException("Timeout expired while fetching topic metadata"); + } + + /** + * Send Metadata Request to least loaded node in Kafka cluster asynchronously + * @return A future that indicates result of sent metadata request + */ + private RequestFuture sendMetadataRequest(MetadataRequest.Builder request) { + final Node node = client.leastLoadedNode(); + if (node == null) + return RequestFuture.noBrokersAvailable(); + else + return client.send(node, request); + } + + private Long offsetResetStrategyTimestamp(final TopicPartition partition) { + OffsetResetStrategy strategy = subscriptions.resetStrategy(partition); + if (strategy == OffsetResetStrategy.EARLIEST) + return ListOffsetsRequest.EARLIEST_TIMESTAMP; + else if (strategy == OffsetResetStrategy.LATEST) + return ListOffsetsRequest.LATEST_TIMESTAMP; + else + return null; + } + + private OffsetResetStrategy timestampToOffsetResetStrategy(long timestamp) { + if (timestamp == ListOffsetsRequest.EARLIEST_TIMESTAMP) + return OffsetResetStrategy.EARLIEST; + else if (timestamp == ListOffsetsRequest.LATEST_TIMESTAMP) + return OffsetResetStrategy.LATEST; + else + return null; + } + + /** + * Reset offsets for all assigned partitions that require it. + * + * @throws org.apache.kafka.clients.consumer.NoOffsetForPartitionException If no offset reset strategy is defined + * and one or more partitions aren't awaiting a seekToBeginning() or seekToEnd(). 
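+ * <p>
+ * Resets are issued asynchronously (see {@code resetOffsetsAsync}); a failure from a previous asynchronous
+ * attempt is surfaced as an exception on a subsequent call to this method.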
+ */ + public void resetOffsetsIfNeeded() { + // Raise exception from previous offset fetch if there is one + RuntimeException exception = cachedListOffsetsException.getAndSet(null); + if (exception != null) + throw exception; + + Set partitions = subscriptions.partitionsNeedingReset(time.milliseconds()); + if (partitions.isEmpty()) + return; + + final Map offsetResetTimestamps = new HashMap<>(); + for (final TopicPartition partition : partitions) { + Long timestamp = offsetResetStrategyTimestamp(partition); + if (timestamp != null) + offsetResetTimestamps.put(partition, timestamp); + } + + resetOffsetsAsync(offsetResetTimestamps); + } + + /** + * Validate offsets for all assigned partitions for which a leader change has been detected. + */ + public void validateOffsetsIfNeeded() { + RuntimeException exception = cachedOffsetForLeaderException.getAndSet(null); + if (exception != null) + throw exception; + + // Validate each partition against the current leader and epoch + // If we see a new metadata version, check all partitions + validatePositionsOnMetadataChange(); + + // Collect positions needing validation, with backoff + Map partitionsToValidate = subscriptions + .partitionsNeedingValidation(time.milliseconds()) + .stream() + .filter(tp -> subscriptions.position(tp) != null) + .collect(Collectors.toMap(Function.identity(), subscriptions::position)); + + validateOffsetsAsync(partitionsToValidate); + } + + public Map offsetsForTimes(Map timestampsToSearch, + Timer timer) { + metadata.addTransientTopics(topicsForPartitions(timestampsToSearch.keySet())); + + try { + Map fetchedOffsets = fetchOffsetsByTimes(timestampsToSearch, + timer, true).fetchedOffsets; + + HashMap offsetsByTimes = new HashMap<>(timestampsToSearch.size()); + for (Map.Entry entry : timestampsToSearch.entrySet()) + offsetsByTimes.put(entry.getKey(), null); + + for (Map.Entry entry : fetchedOffsets.entrySet()) { + // 'entry.getValue().timestamp' will not be null since we are guaranteed + // to work with a v1 (or later) ListOffset request + ListOffsetData offsetData = entry.getValue(); + offsetsByTimes.put(entry.getKey(), new OffsetAndTimestamp(offsetData.offset, offsetData.timestamp, + offsetData.leaderEpoch)); + } + + return offsetsByTimes; + } finally { + metadata.clearTransientTopics(); + } + } + + private ListOffsetResult fetchOffsetsByTimes(Map timestampsToSearch, + Timer timer, + boolean requireTimestamps) { + ListOffsetResult result = new ListOffsetResult(); + if (timestampsToSearch.isEmpty()) + return result; + + Map remainingToSearch = new HashMap<>(timestampsToSearch); + do { + RequestFuture future = sendListOffsetsRequests(remainingToSearch, requireTimestamps); + + future.addListener(new RequestFutureListener() { + @Override + public void onSuccess(ListOffsetResult value) { + synchronized (future) { + result.fetchedOffsets.putAll(value.fetchedOffsets); + remainingToSearch.keySet().retainAll(value.partitionsToRetry); + + for (final Map.Entry entry: value.fetchedOffsets.entrySet()) { + final TopicPartition partition = entry.getKey(); + + // if the interested partitions are part of the subscriptions, use the returned offset to update + // the subscription state as well: + // * with read-committed, the returned offset would be LSO; + // * with read-uncommitted, the returned offset would be HW; + if (subscriptions.isAssigned(partition)) { + final long offset = entry.getValue().offset; + if (isolationLevel == IsolationLevel.READ_COMMITTED) { + log.trace("Updating last stable offset for partition {} to {}", partition, 
offset); + subscriptions.updateLastStableOffset(partition, offset); + } else { + log.trace("Updating high watermark for partition {} to {}", partition, offset); + subscriptions.updateHighWatermark(partition, offset); + } + } + } + } + } + + @Override + public void onFailure(RuntimeException e) { + if (!(e instanceof RetriableException)) { + throw future.exception(); + } + } + }); + + // if timeout is set to zero, do not try to poll the network client at all + // and return empty immediately; otherwise try to get the results synchronously + // and throw timeout exception if cannot complete in time + if (timer.timeoutMs() == 0L) + return result; + + client.poll(future, timer); + + if (!future.isDone()) { + break; + } else if (remainingToSearch.isEmpty()) { + return result; + } else { + client.awaitMetadataUpdate(timer); + } + } while (timer.notExpired()); + + throw new TimeoutException("Failed to get offsets by times in " + timer.elapsedMs() + "ms"); + } + + public Map beginningOffsets(Collection partitions, Timer timer) { + return beginningOrEndOffset(partitions, ListOffsetsRequest.EARLIEST_TIMESTAMP, timer); + } + + public Map endOffsets(Collection partitions, Timer timer) { + return beginningOrEndOffset(partitions, ListOffsetsRequest.LATEST_TIMESTAMP, timer); + } + + private Map beginningOrEndOffset(Collection partitions, + long timestamp, + Timer timer) { + metadata.addTransientTopics(topicsForPartitions(partitions)); + try { + Map timestampsToSearch = partitions.stream() + .distinct() + .collect(Collectors.toMap(Function.identity(), tp -> timestamp)); + + ListOffsetResult result = fetchOffsetsByTimes(timestampsToSearch, timer, false); + + return result.fetchedOffsets.entrySet().stream() + .collect(Collectors.toMap(Map.Entry::getKey, entry -> entry.getValue().offset)); + } finally { + metadata.clearTransientTopics(); + } + } + + /** + * Return the fetched records, empty the record buffer and update the consumed position. + * + * NOTE: returning empty records guarantees the consumed position are NOT updated. + * + * @return The fetched records per partition + * @throws OffsetOutOfRangeException If there is OffsetOutOfRange error in fetchResponse and + * the defaultResetPolicy is NONE + * @throws TopicAuthorizationException If there is TopicAuthorization error in fetchResponse. + */ + public Map>> fetchedRecords() { + Map>> fetched = new HashMap<>(); + Queue pausedCompletedFetches = new ArrayDeque<>(); + int recordsRemaining = maxPollRecords; + + try { + while (recordsRemaining > 0) { + if (nextInLineFetch == null || nextInLineFetch.isConsumed) { + CompletedFetch records = completedFetches.peek(); + if (records == null) break; + + if (records.notInitialized()) { + try { + nextInLineFetch = initializeCompletedFetch(records); + } catch (Exception e) { + // Remove a completedFetch upon a parse with exception if (1) it contains no records, and + // (2) there are no fetched records with actual content preceding this exception. + // The first condition ensures that the completedFetches is not stuck with the same completedFetch + // in cases such as the TopicAuthorizationException, and the second condition ensures that no + // potential data loss due to an exception in a following record. 
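+                            // Note that the exception is always rethrown below; the check here only decides whether
+                            // the offending completedFetch should be removed from the queue before it is surfaced.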
+ FetchResponseData.PartitionData partition = records.partitionData; + if (fetched.isEmpty() && FetchResponse.recordsOrFail(partition).sizeInBytes() == 0) { + completedFetches.poll(); + } + throw e; + } + } else { + nextInLineFetch = records; + } + completedFetches.poll(); + } else if (subscriptions.isPaused(nextInLineFetch.partition)) { + // when the partition is paused we add the records back to the completedFetches queue instead of draining + // them so that they can be returned on a subsequent poll if the partition is resumed at that time + log.debug("Skipping fetching records for assigned partition {} because it is paused", nextInLineFetch.partition); + pausedCompletedFetches.add(nextInLineFetch); + nextInLineFetch = null; + } else { + List> records = fetchRecords(nextInLineFetch, recordsRemaining); + + if (!records.isEmpty()) { + TopicPartition partition = nextInLineFetch.partition; + List> currentRecords = fetched.get(partition); + if (currentRecords == null) { + fetched.put(partition, records); + } else { + // this case shouldn't usually happen because we only send one fetch at a time per partition, + // but it might conceivably happen in some rare cases (such as partition leader changes). + // we have to copy to a new list because the old one may be immutable + List> newRecords = new ArrayList<>(records.size() + currentRecords.size()); + newRecords.addAll(currentRecords); + newRecords.addAll(records); + fetched.put(partition, newRecords); + } + recordsRemaining -= records.size(); + } + } + } + } catch (KafkaException e) { + if (fetched.isEmpty()) + throw e; + } finally { + // add any polled completed fetches for paused partitions back to the completed fetches queue to be + // re-evaluated in the next poll + completedFetches.addAll(pausedCompletedFetches); + } + + return fetched; + } + + private List> fetchRecords(CompletedFetch completedFetch, int maxRecords) { + if (!subscriptions.isAssigned(completedFetch.partition)) { + // this can happen when a rebalance happened before fetched records are returned to the consumer's poll call + log.debug("Not returning fetched records for partition {} since it is no longer assigned", + completedFetch.partition); + } else if (!subscriptions.isFetchable(completedFetch.partition)) { + // this can happen when a partition is paused before fetched records are returned to the consumer's + // poll call or if the offset is being reset + log.debug("Not returning fetched records for assigned partition {} since it is no longer fetchable", + completedFetch.partition); + } else { + FetchPosition position = subscriptions.position(completedFetch.partition); + if (position == null) { + throw new IllegalStateException("Missing position for fetchable partition " + completedFetch.partition); + } + + if (completedFetch.nextFetchOffset == position.offset) { + List> partRecords = completedFetch.fetchRecords(maxRecords); + + log.trace("Returning {} fetched records at offset {} for assigned partition {}", + partRecords.size(), position, completedFetch.partition); + + if (completedFetch.nextFetchOffset > position.offset) { + FetchPosition nextPosition = new FetchPosition( + completedFetch.nextFetchOffset, + completedFetch.lastEpoch, + position.currentLeader); + log.trace("Update fetching position to {} for partition {}", nextPosition, completedFetch.partition); + subscriptions.position(completedFetch.partition, nextPosition); + } + + Long partitionLag = subscriptions.partitionLag(completedFetch.partition, isolationLevel); + if (partitionLag != null) + 
this.sensors.recordPartitionLag(completedFetch.partition, partitionLag); + + Long lead = subscriptions.partitionLead(completedFetch.partition); + if (lead != null) { + this.sensors.recordPartitionLead(completedFetch.partition, lead); + } + + return partRecords; + } else { + // these records aren't next in line based on the last consumed position, ignore them + // they must be from an obsolete request + log.debug("Ignoring fetched records for {} at offset {} since the current position is {}", + completedFetch.partition, completedFetch.nextFetchOffset, position); + } + } + + log.trace("Draining fetched records for partition {}", completedFetch.partition); + completedFetch.drain(); + + return emptyList(); + } + + // Visible for testing + void resetOffsetIfNeeded(TopicPartition partition, OffsetResetStrategy requestedResetStrategy, ListOffsetData offsetData) { + FetchPosition position = new FetchPosition( + offsetData.offset, + Optional.empty(), // This will ensure we skip validation + metadata.currentLeader(partition)); + offsetData.leaderEpoch.ifPresent(epoch -> metadata.updateLastSeenEpochIfNewer(partition, epoch)); + subscriptions.maybeSeekUnvalidated(partition, position, requestedResetStrategy); + } + + private void resetOffsetsAsync(Map partitionResetTimestamps) { + Map> timestampsToSearchByNode = + groupListOffsetRequests(partitionResetTimestamps, new HashSet<>()); + for (Map.Entry> entry : timestampsToSearchByNode.entrySet()) { + Node node = entry.getKey(); + final Map resetTimestamps = entry.getValue(); + subscriptions.setNextAllowedRetry(resetTimestamps.keySet(), time.milliseconds() + requestTimeoutMs); + + RequestFuture future = sendListOffsetRequest(node, resetTimestamps, false); + future.addListener(new RequestFutureListener() { + @Override + public void onSuccess(ListOffsetResult result) { + if (!result.partitionsToRetry.isEmpty()) { + subscriptions.requestFailed(result.partitionsToRetry, time.milliseconds() + retryBackoffMs); + metadata.requestUpdate(); + } + + for (Map.Entry fetchedOffset : result.fetchedOffsets.entrySet()) { + TopicPartition partition = fetchedOffset.getKey(); + ListOffsetData offsetData = fetchedOffset.getValue(); + ListOffsetsPartition requestedReset = resetTimestamps.get(partition); + resetOffsetIfNeeded(partition, timestampToOffsetResetStrategy(requestedReset.timestamp()), offsetData); + } + } + + @Override + public void onFailure(RuntimeException e) { + subscriptions.requestFailed(resetTimestamps.keySet(), time.milliseconds() + retryBackoffMs); + metadata.requestUpdate(); + + if (!(e instanceof RetriableException) && !cachedListOffsetsException.compareAndSet(null, e)) + log.error("Discarding error in ListOffsetResponse because another error is pending", e); + } + }); + } + } + + static boolean hasUsableOffsetForLeaderEpochVersion(NodeApiVersions nodeApiVersions) { + ApiVersion apiVersion = nodeApiVersions.apiVersion(ApiKeys.OFFSET_FOR_LEADER_EPOCH); + if (apiVersion == null) + return false; + + return OffsetsForLeaderEpochRequest.supportsTopicPermission(apiVersion.maxVersion()); + } + + /** + * For each partition which needs validation, make an asynchronous request to get the end-offsets for the partition + * with the epoch less than or equal to the epoch the partition last saw. + * + * Requests are grouped by Node for efficiency. 
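+     * Nodes that do not support a usable OffsetsForLeaderEpoch version (brokers older than 2.3) have validation
+     * for their partitions completed immediately, without a round trip. Any log truncation that is detected is
+     * cached and rethrown to the application from {@code validateOffsetsIfNeeded()} as a LogTruncationException.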
+ */ + private void validateOffsetsAsync(Map partitionsToValidate) { + final Map> regrouped = + regroupFetchPositionsByLeader(partitionsToValidate); + + long nextResetTimeMs = time.milliseconds() + requestTimeoutMs; + regrouped.forEach((node, fetchPositions) -> { + if (node.isEmpty()) { + metadata.requestUpdate(); + return; + } + + NodeApiVersions nodeApiVersions = apiVersions.get(node.idString()); + if (nodeApiVersions == null) { + client.tryConnect(node); + return; + } + + if (!hasUsableOffsetForLeaderEpochVersion(nodeApiVersions)) { + log.debug("Skipping validation of fetch offsets for partitions {} since the broker does not " + + "support the required protocol version (introduced in Kafka 2.3)", + fetchPositions.keySet()); + for (TopicPartition partition : fetchPositions.keySet()) { + subscriptions.completeValidation(partition); + } + return; + } + + subscriptions.setNextAllowedRetry(fetchPositions.keySet(), nextResetTimeMs); + + RequestFuture future = + offsetsForLeaderEpochClient.sendAsyncRequest(node, fetchPositions); + + future.addListener(new RequestFutureListener() { + @Override + public void onSuccess(OffsetForEpochResult offsetsResult) { + List truncations = new ArrayList<>(); + if (!offsetsResult.partitionsToRetry().isEmpty()) { + subscriptions.setNextAllowedRetry(offsetsResult.partitionsToRetry(), time.milliseconds() + retryBackoffMs); + metadata.requestUpdate(); + } + + // For each OffsetsForLeader response, check if the end-offset is lower than our current offset + // for the partition. If so, it means we have experienced log truncation and need to reposition + // that partition's offset. + // + // In addition, check whether the returned offset and epoch are valid. If not, then we should reset + // its offset if reset policy is configured, or throw out of range exception. + offsetsResult.endOffsets().forEach((topicPartition, respEndOffset) -> { + FetchPosition requestPosition = fetchPositions.get(topicPartition); + Optional truncationOpt = + subscriptions.maybeCompleteValidation(topicPartition, requestPosition, respEndOffset); + truncationOpt.ifPresent(truncations::add); + }); + + if (!truncations.isEmpty()) { + maybeSetOffsetForLeaderException(buildLogTruncationException(truncations)); + } + } + + @Override + public void onFailure(RuntimeException e) { + subscriptions.requestFailed(fetchPositions.keySet(), time.milliseconds() + retryBackoffMs); + metadata.requestUpdate(); + + if (!(e instanceof RetriableException)) { + maybeSetOffsetForLeaderException(e); + } + } + }); + }); + } + + private LogTruncationException buildLogTruncationException(List truncations) { + Map divergentOffsets = new HashMap<>(); + Map truncatedFetchOffsets = new HashMap<>(); + for (SubscriptionState.LogTruncation truncation : truncations) { + truncation.divergentOffsetOpt.ifPresent(divergentOffset -> + divergentOffsets.put(truncation.topicPartition, divergentOffset)); + truncatedFetchOffsets.put(truncation.topicPartition, truncation.fetchPosition.offset); + } + return new LogTruncationException("Detected truncated partitions: " + truncations, + truncatedFetchOffsets, divergentOffsets); + } + + private void maybeSetOffsetForLeaderException(RuntimeException e) { + if (!cachedOffsetForLeaderException.compareAndSet(null, e)) { + log.error("Discarding error in OffsetsForLeaderEpoch because another error is pending", e); + } + } + + /** + * Search the offsets by target times for the specified partitions. 
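+     *
+     * One ListOffsets request is sent to each node that currently leads at least one of the requested partitions.
+     * Partitions whose leader is unknown are collected into the result's {@code partitionsToRetry} set so that the
+     * caller can retry them after a metadata update.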
+ * + * @param timestampsToSearch the mapping between partitions and target time + * @param requireTimestamps true if we should fail with an UnsupportedVersionException if the broker does + * not support fetching precise timestamps for offsets + * @return A response which can be polled to obtain the corresponding timestamps and offsets. + */ + private RequestFuture sendListOffsetsRequests(final Map timestampsToSearch, + final boolean requireTimestamps) { + final Set partitionsToRetry = new HashSet<>(); + Map> timestampsToSearchByNode = + groupListOffsetRequests(timestampsToSearch, partitionsToRetry); + if (timestampsToSearchByNode.isEmpty()) + return RequestFuture.failure(new StaleMetadataException()); + + final RequestFuture listOffsetRequestsFuture = new RequestFuture<>(); + final Map fetchedTimestampOffsets = new HashMap<>(); + final AtomicInteger remainingResponses = new AtomicInteger(timestampsToSearchByNode.size()); + + for (Map.Entry> entry : timestampsToSearchByNode.entrySet()) { + RequestFuture future = sendListOffsetRequest(entry.getKey(), entry.getValue(), requireTimestamps); + future.addListener(new RequestFutureListener() { + @Override + public void onSuccess(ListOffsetResult partialResult) { + synchronized (listOffsetRequestsFuture) { + fetchedTimestampOffsets.putAll(partialResult.fetchedOffsets); + partitionsToRetry.addAll(partialResult.partitionsToRetry); + + if (remainingResponses.decrementAndGet() == 0 && !listOffsetRequestsFuture.isDone()) { + ListOffsetResult result = new ListOffsetResult(fetchedTimestampOffsets, partitionsToRetry); + listOffsetRequestsFuture.complete(result); + } + } + } + + @Override + public void onFailure(RuntimeException e) { + synchronized (listOffsetRequestsFuture) { + if (!listOffsetRequestsFuture.isDone()) + listOffsetRequestsFuture.raise(e); + } + } + }); + } + return listOffsetRequestsFuture; + } + + /** + * Groups timestamps to search by node for topic partitions in `timestampsToSearch` that have + * leaders available. Topic partitions from `timestampsToSearch` that do not have their leader + * available are added to `partitionsToRetry` + * @param timestampsToSearch The mapping from partitions ot the target timestamps + * @param partitionsToRetry A set of topic partitions that will be extended with partitions + * that need metadata update or re-connect to the leader. + */ + private Map> groupListOffsetRequests( + Map timestampsToSearch, + Set partitionsToRetry) { + final Map partitionDataMap = new HashMap<>(); + for (Map.Entry entry: timestampsToSearch.entrySet()) { + TopicPartition tp = entry.getKey(); + Long offset = entry.getValue(); + Metadata.LeaderAndEpoch leaderAndEpoch = metadata.currentLeader(tp); + + if (!leaderAndEpoch.leader.isPresent()) { + log.debug("Leader for partition {} is unknown for fetching offset {}", tp, offset); + metadata.requestUpdate(); + partitionsToRetry.add(tp); + } else { + Node leader = leaderAndEpoch.leader.get(); + if (client.isUnavailable(leader)) { + client.maybeThrowAuthFailure(leader); + + // The connection has failed and we need to await the backoff period before we can + // try again. No need to request a metadata update since the disconnect will have + // done so already. 
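+                    // The partition is still added to partitionsToRetry below, so the offset lookup will be
+                    // attempted again once the reconnect backoff has expired.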
+ log.debug("Leader {} for partition {} is unavailable for fetching offset until reconnect backoff expires", + leader, tp); + partitionsToRetry.add(tp); + } else { + int currentLeaderEpoch = leaderAndEpoch.epoch.orElse(ListOffsetsResponse.UNKNOWN_EPOCH); + partitionDataMap.put(tp, new ListOffsetsPartition() + .setPartitionIndex(tp.partition()) + .setTimestamp(offset) + .setCurrentLeaderEpoch(currentLeaderEpoch)); + } + } + } + return regroupPartitionMapByNode(partitionDataMap); + } + + /** + * Send the ListOffsetRequest to a specific broker for the partitions and target timestamps. + * + * @param node The node to send the ListOffsetRequest to. + * @param timestampsToSearch The mapping from partitions to the target timestamps. + * @param requireTimestamp True if we require a timestamp in the response. + * @return A response which can be polled to obtain the corresponding timestamps and offsets. + */ + private RequestFuture sendListOffsetRequest(final Node node, + final Map timestampsToSearch, + boolean requireTimestamp) { + ListOffsetsRequest.Builder builder = ListOffsetsRequest.Builder + .forConsumer(requireTimestamp, isolationLevel, false) + .setTargetTimes(ListOffsetsRequest.toListOffsetsTopics(timestampsToSearch)); + + log.debug("Sending ListOffsetRequest {} to broker {}", builder, node); + return client.send(node, builder) + .compose(new RequestFutureAdapter() { + @Override + public void onSuccess(ClientResponse response, RequestFuture future) { + ListOffsetsResponse lor = (ListOffsetsResponse) response.responseBody(); + log.trace("Received ListOffsetResponse {} from broker {}", lor, node); + handleListOffsetResponse(lor, future); + } + }); + } + + /** + * Callback for the response of the list offset call above. + * @param listOffsetsResponse The response from the server. + * @param future The future to be completed when the response returns. Note that any partition-level errors will + * generally fail the entire future result. The one exception is UNSUPPORTED_FOR_MESSAGE_FORMAT, + * which indicates that the broker does not support the v1 message format. Partitions with this + * particular error are simply left out of the future map. Note that the corresponding timestamp + * value of each partition may be null only for v0. In v1 and later the ListOffset API would not + * return a null timestamp (-1 is returned instead when necessary). + */ + private void handleListOffsetResponse(ListOffsetsResponse listOffsetsResponse, + RequestFuture future) { + Map fetchedOffsets = new HashMap<>(); + Set partitionsToRetry = new HashSet<>(); + Set unauthorizedTopics = new HashSet<>(); + + for (ListOffsetsTopicResponse topic : listOffsetsResponse.topics()) { + for (ListOffsetsPartitionResponse partition : topic.partitions()) { + TopicPartition topicPartition = new TopicPartition(topic.name(), partition.partitionIndex()); + Errors error = Errors.forCode(partition.errorCode()); + switch (error) { + case NONE: + if (!partition.oldStyleOffsets().isEmpty()) { + // Handle v0 response with offsets + long offset; + if (partition.oldStyleOffsets().size() > 1) { + future.raise(new IllegalStateException("Unexpected partitionData response of length " + + partition.oldStyleOffsets().size())); + return; + } else { + offset = partition.oldStyleOffsets().get(0); + } + log.debug("Handling v0 ListOffsetResponse response for {}. 
Fetched offset {}", + topicPartition, offset); + if (offset != ListOffsetsResponse.UNKNOWN_OFFSET) { + ListOffsetData offsetData = new ListOffsetData(offset, null, Optional.empty()); + fetchedOffsets.put(topicPartition, offsetData); + } + } else { + // Handle v1 and later response or v0 without offsets + log.debug("Handling ListOffsetResponse response for {}. Fetched offset {}, timestamp {}", + topicPartition, partition.offset(), partition.timestamp()); + if (partition.offset() != ListOffsetsResponse.UNKNOWN_OFFSET) { + Optional leaderEpoch = (partition.leaderEpoch() == ListOffsetsResponse.UNKNOWN_EPOCH) + ? Optional.empty() + : Optional.of(partition.leaderEpoch()); + ListOffsetData offsetData = new ListOffsetData(partition.offset(), partition.timestamp(), + leaderEpoch); + fetchedOffsets.put(topicPartition, offsetData); + } + } + break; + case UNSUPPORTED_FOR_MESSAGE_FORMAT: + // The message format on the broker side is before 0.10.0, which means it does not + // support timestamps. We treat this case the same as if we weren't able to find an + // offset corresponding to the requested timestamp and leave it out of the result. + log.debug("Cannot search by timestamp for partition {} because the message format version " + + "is before 0.10.0", topicPartition); + break; + case NOT_LEADER_OR_FOLLOWER: + case REPLICA_NOT_AVAILABLE: + case KAFKA_STORAGE_ERROR: + case OFFSET_NOT_AVAILABLE: + case LEADER_NOT_AVAILABLE: + case FENCED_LEADER_EPOCH: + case UNKNOWN_LEADER_EPOCH: + log.debug("Attempt to fetch offsets for partition {} failed due to {}, retrying.", + topicPartition, error); + partitionsToRetry.add(topicPartition); + break; + case UNKNOWN_TOPIC_OR_PARTITION: + log.warn("Received unknown topic or partition error in ListOffset request for partition {}", topicPartition); + partitionsToRetry.add(topicPartition); + break; + case TOPIC_AUTHORIZATION_FAILED: + unauthorizedTopics.add(topicPartition.topic()); + break; + default: + log.warn("Attempt to fetch offsets for partition {} failed due to unexpected exception: {}, retrying.", + topicPartition, error.message()); + partitionsToRetry.add(topicPartition); + } + } + } + + if (!unauthorizedTopics.isEmpty()) + future.raise(new TopicAuthorizationException(unauthorizedTopics)); + else + future.complete(new ListOffsetResult(fetchedOffsets, partitionsToRetry)); + } + + static class ListOffsetResult { + private final Map fetchedOffsets; + private final Set partitionsToRetry; + + ListOffsetResult(Map fetchedOffsets, Set partitionsNeedingRetry) { + this.fetchedOffsets = fetchedOffsets; + this.partitionsToRetry = partitionsNeedingRetry; + } + + ListOffsetResult() { + this.fetchedOffsets = new HashMap<>(); + this.partitionsToRetry = new HashSet<>(); + } + } + + private List fetchablePartitions() { + Set exclude = new HashSet<>(); + if (nextInLineFetch != null && !nextInLineFetch.isConsumed) { + exclude.add(nextInLineFetch.partition); + } + for (CompletedFetch completedFetch : completedFetches) { + exclude.add(completedFetch.partition); + } + return subscriptions.fetchablePartitions(tp -> !exclude.contains(tp)); + } + + /** + * Determine which replica to read from. 
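+     * If the broker has previously returned a preferred read replica for this partition (follower fetching,
+     * KIP-392) and that node is still online in the current metadata, it is used. If the preferred replica is
+     * offline or missing from the metadata, the hint is cleared and the leader is used instead.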
+ */ + Node selectReadReplica(TopicPartition partition, Node leaderReplica, long currentTimeMs) { + Optional nodeId = subscriptions.preferredReadReplica(partition, currentTimeMs); + if (nodeId.isPresent()) { + Optional node = nodeId.flatMap(id -> metadata.fetch().nodeIfOnline(partition, id)); + if (node.isPresent()) { + return node.get(); + } else { + log.trace("Not fetching from {} for partition {} since it is marked offline or is missing from our metadata," + + " using the leader instead.", nodeId, partition); + subscriptions.clearPreferredReadReplica(partition); + return leaderReplica; + } + } else { + return leaderReplica; + } + } + + /** + * If we have seen new metadata (as tracked by {@link org.apache.kafka.clients.Metadata#updateVersion()}), then + * we should check that all of the assignments have a valid position. + */ + private void validatePositionsOnMetadataChange() { + int newMetadataUpdateVersion = metadata.updateVersion(); + if (metadataUpdateVersion.getAndSet(newMetadataUpdateVersion) != newMetadataUpdateVersion) { + subscriptions.assignedPartitions().forEach(topicPartition -> { + ConsumerMetadata.LeaderAndEpoch leaderAndEpoch = metadata.currentLeader(topicPartition); + subscriptions.maybeValidatePositionForCurrentLeader(apiVersions, topicPartition, leaderAndEpoch); + }); + } + } + + /** + * Create fetch requests for all nodes for which we have assigned partitions + * that have no existing requests in flight. + */ + private Map prepareFetchRequests() { + Map fetchable = new LinkedHashMap<>(); + + validatePositionsOnMetadataChange(); + + long currentTimeMs = time.milliseconds(); + Map topicIds = metadata.topicIds(); + + for (TopicPartition partition : fetchablePartitions()) { + FetchPosition position = this.subscriptions.position(partition); + if (position == null) { + throw new IllegalStateException("Missing position for fetchable partition " + partition); + } + + Optional leaderOpt = position.currentLeader.leader; + if (!leaderOpt.isPresent()) { + log.debug("Requesting metadata update for partition {} since the position {} is missing the current leader node", partition, position); + metadata.requestUpdate(); + continue; + } + + // Use the preferred read replica if set, otherwise the position's leader + Node node = selectReadReplica(partition, leaderOpt.get(), currentTimeMs); + if (client.isUnavailable(node)) { + client.maybeThrowAuthFailure(node); + + // If we try to send during the reconnect backoff window, then the request is just + // going to be failed anyway before being sent, so skip the send for now + log.trace("Skipping fetch for partition {} because node {} is awaiting reconnect backoff", partition, node); + } else if (this.nodesWithPendingFetchRequests.contains(node.id())) { + log.trace("Skipping fetch for partition {} because previous request to {} has not been processed", partition, node); + } else { + // if there is a leader and no in-flight requests, issue a new fetch + FetchSessionHandler.Builder builder = fetchable.get(node); + if (builder == null) { + int id = node.id(); + FetchSessionHandler handler = sessionHandler(id); + if (handler == null) { + handler = new FetchSessionHandler(logContext, id); + sessionHandlers.put(id, handler); + } + builder = handler.newBuilder(); + fetchable.put(node, builder); + } + builder.add(partition, new FetchRequest.PartitionData( + topicIds.getOrDefault(partition.topic(), Uuid.ZERO_UUID), + position.offset, FetchRequest.INVALID_LOG_START_OFFSET, this.fetchSize, + position.currentLeader.epoch, Optional.empty())); + + 
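+                // Partitions are accumulated per node in a FetchSessionHandler.Builder so that each broker receives
+                // a single fetch request, which the session handler may issue as an incremental fetch (KIP-227).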
log.debug("Added {} fetch request for partition {} at position {} to node {}", isolationLevel, + partition, position, node); + } + } + + Map reqs = new LinkedHashMap<>(); + for (Map.Entry entry : fetchable.entrySet()) { + reqs.put(entry.getKey(), entry.getValue().build()); + } + return reqs; + } + + private Map> regroupFetchPositionsByLeader( + Map partitionMap) { + return partitionMap.entrySet() + .stream() + .filter(entry -> entry.getValue().currentLeader.leader.isPresent()) + .collect(Collectors.groupingBy(entry -> entry.getValue().currentLeader.leader.get(), + Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue))); + } + + private Map> regroupPartitionMapByNode(Map partitionMap) { + return partitionMap.entrySet() + .stream() + .collect(Collectors.groupingBy(entry -> metadata.fetch().leaderFor(entry.getKey()), + Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue))); + } + + /** + * Initialize a CompletedFetch object. + */ + private CompletedFetch initializeCompletedFetch(CompletedFetch nextCompletedFetch) { + TopicPartition tp = nextCompletedFetch.partition; + FetchResponseData.PartitionData partition = nextCompletedFetch.partitionData; + long fetchOffset = nextCompletedFetch.nextFetchOffset; + CompletedFetch completedFetch = null; + Errors error = Errors.forCode(partition.errorCode()); + + try { + if (!subscriptions.hasValidPosition(tp)) { + // this can happen when a rebalance happened while fetch is still in-flight + log.debug("Ignoring fetched records for partition {} since it no longer has valid position", tp); + } else if (error == Errors.NONE) { + // we are interested in this fetch only if the beginning offset matches the + // current consumed position + FetchPosition position = subscriptions.position(tp); + if (position == null || position.offset != fetchOffset) { + log.debug("Discarding stale fetch response for partition {} since its offset {} does not match " + + "the expected offset {}", tp, fetchOffset, position); + return null; + } + + log.trace("Preparing to read {} bytes of data for partition {} with offset {}", + FetchResponse.recordsSize(partition), tp, position); + Iterator batches = FetchResponse.recordsOrFail(partition).batches().iterator(); + completedFetch = nextCompletedFetch; + + if (!batches.hasNext() && FetchResponse.recordsSize(partition) > 0) { + if (completedFetch.responseVersion < 3) { + // Implement the pre KIP-74 behavior of throwing a RecordTooLargeException. + Map recordTooLargePartitions = Collections.singletonMap(tp, fetchOffset); + throw new RecordTooLargeException("There are some messages at [Partition=Offset]: " + + recordTooLargePartitions + " whose size is larger than the fetch size " + this.fetchSize + + " and hence cannot be returned. Please considering upgrading your broker to 0.10.1.0 or " + + "newer to avoid this issue. Alternately, increase the fetch size on the client (using " + + ConsumerConfig.MAX_PARTITION_FETCH_BYTES_CONFIG + ")", + recordTooLargePartitions); + } else { + // This should not happen with brokers that support FetchRequest/Response V3 or higher (i.e. KIP-74) + throw new KafkaException("Failed to make progress reading messages at " + tp + "=" + + fetchOffset + ". 
Received a non-empty fetch response from the server, but no " + + "complete records were found."); + } + } + + if (partition.highWatermark() >= 0) { + log.trace("Updating high watermark for partition {} to {}", tp, partition.highWatermark()); + subscriptions.updateHighWatermark(tp, partition.highWatermark()); + } + + if (partition.logStartOffset() >= 0) { + log.trace("Updating log start offset for partition {} to {}", tp, partition.logStartOffset()); + subscriptions.updateLogStartOffset(tp, partition.logStartOffset()); + } + + if (partition.lastStableOffset() >= 0) { + log.trace("Updating last stable offset for partition {} to {}", tp, partition.lastStableOffset()); + subscriptions.updateLastStableOffset(tp, partition.lastStableOffset()); + } + + if (FetchResponse.isPreferredReplica(partition)) { + subscriptions.updatePreferredReadReplica(completedFetch.partition, partition.preferredReadReplica(), () -> { + long expireTimeMs = time.milliseconds() + metadata.metadataExpireMs(); + log.debug("Updating preferred read replica for partition {} to {}, set to expire at {}", + tp, partition.preferredReadReplica(), expireTimeMs); + return expireTimeMs; + }); + } + + nextCompletedFetch.initialized = true; + } else if (error == Errors.NOT_LEADER_OR_FOLLOWER || + error == Errors.REPLICA_NOT_AVAILABLE || + error == Errors.KAFKA_STORAGE_ERROR || + error == Errors.FENCED_LEADER_EPOCH || + error == Errors.OFFSET_NOT_AVAILABLE) { + log.debug("Error in fetch for partition {}: {}", tp, error.exceptionName()); + this.metadata.requestUpdate(); + } else if (error == Errors.UNKNOWN_TOPIC_OR_PARTITION) { + log.warn("Received unknown topic or partition error in fetch for partition {}", tp); + this.metadata.requestUpdate(); + } else if (error == Errors.UNKNOWN_TOPIC_ID) { + log.warn("Received unknown topic ID error in fetch for partition {}", tp); + this.metadata.requestUpdate(); + } else if (error == Errors.INCONSISTENT_TOPIC_ID) { + log.warn("Received inconsistent topic ID error in fetch for partition {}", tp); + this.metadata.requestUpdate(); + } else if (error == Errors.OFFSET_OUT_OF_RANGE) { + Optional clearedReplicaId = subscriptions.clearPreferredReadReplica(tp); + if (!clearedReplicaId.isPresent()) { + // If there's no preferred replica to clear, we're fetching from the leader so handle this error normally + FetchPosition position = subscriptions.position(tp); + if (position == null || fetchOffset != position.offset) { + log.debug("Discarding stale fetch response for partition {} since the fetched offset {} " + + "does not match the current offset {}", tp, fetchOffset, position); + } else { + handleOffsetOutOfRange(position, tp); + } + } else { + log.debug("Unset the preferred read replica {} for partition {} since we got {} when fetching {}", + clearedReplicaId.get(), tp, error, fetchOffset); + } + } else if (error == Errors.TOPIC_AUTHORIZATION_FAILED) { + //we log the actual partition and not just the topic to help with ACL propagation issues in large clusters + log.warn("Not authorized to read from partition {}.", tp); + throw new TopicAuthorizationException(Collections.singleton(tp.topic())); + } else if (error == Errors.UNKNOWN_LEADER_EPOCH) { + log.debug("Received unknown leader epoch error in fetch for partition {}", tp); + } else if (error == Errors.UNKNOWN_SERVER_ERROR) { + log.warn("Unknown server error while fetching offset {} for topic-partition {}", + fetchOffset, tp); + } else if (error == Errors.CORRUPT_MESSAGE) { + throw new KafkaException("Encountered corrupt message when fetching offset " + 
+ fetchOffset + + " for topic-partition " + + tp); + } else { + throw new IllegalStateException("Unexpected error code " + + error.code() + + " while fetching at offset " + + fetchOffset + + " from topic-partition " + tp); + } + } finally { + if (completedFetch == null) + nextCompletedFetch.metricAggregator.record(tp, 0, 0); + + if (error != Errors.NONE) + // we move the partition to the end if there was an error. This way, it's more likely that partitions for + // the same topic can remain together (allowing for more efficient serialization). + subscriptions.movePartitionToEnd(tp); + } + + return completedFetch; + } + + private void handleOffsetOutOfRange(FetchPosition fetchPosition, TopicPartition topicPartition) { + String errorMessage = "Fetch position " + fetchPosition + " is out of range for partition " + topicPartition; + if (subscriptions.hasDefaultOffsetResetPolicy()) { + log.info("{}, resetting offset", errorMessage); + subscriptions.requestOffsetReset(topicPartition); + } else { + log.info("{}, raising error to the application since no reset policy is configured", errorMessage); + throw new OffsetOutOfRangeException(errorMessage, + Collections.singletonMap(topicPartition, fetchPosition.offset)); + } + } + + /** + * Parse the record entry, deserializing the key / value fields if necessary + */ + private ConsumerRecord parseRecord(TopicPartition partition, + RecordBatch batch, + Record record) { + try { + long offset = record.offset(); + long timestamp = record.timestamp(); + Optional leaderEpoch = maybeLeaderEpoch(batch.partitionLeaderEpoch()); + TimestampType timestampType = batch.timestampType(); + Headers headers = new RecordHeaders(record.headers()); + ByteBuffer keyBytes = record.key(); + byte[] keyByteArray = keyBytes == null ? null : Utils.toArray(keyBytes); + K key = keyBytes == null ? null : this.keyDeserializer.deserialize(partition.topic(), headers, keyByteArray); + ByteBuffer valueBytes = record.value(); + byte[] valueByteArray = valueBytes == null ? null : Utils.toArray(valueBytes); + V value = valueBytes == null ? null : this.valueDeserializer.deserialize(partition.topic(), headers, valueByteArray); + return new ConsumerRecord<>(partition.topic(), partition.partition(), offset, + timestamp, timestampType, + keyByteArray == null ? ConsumerRecord.NULL_SIZE : keyByteArray.length, + valueByteArray == null ? ConsumerRecord.NULL_SIZE : valueByteArray.length, + key, value, headers, leaderEpoch); + } catch (RuntimeException e) { + throw new RecordDeserializationException(partition, record.offset(), + "Error deserializing key/value for partition " + partition + + " at offset " + record.offset() + ". If needed, please seek past the record to continue consumption.", e); + } + } + + private Optional maybeLeaderEpoch(int leaderEpoch) { + return leaderEpoch == RecordBatch.NO_PARTITION_LEADER_EPOCH ? 
Optional.empty() : Optional.of(leaderEpoch); + } + + /** + * Clear the buffered data which are not a part of newly assigned partitions + * + * @param assignedPartitions newly assigned {@link TopicPartition} + */ + public void clearBufferedDataForUnassignedPartitions(Collection assignedPartitions) { + Iterator completedFetchesItr = completedFetches.iterator(); + while (completedFetchesItr.hasNext()) { + CompletedFetch records = completedFetchesItr.next(); + TopicPartition tp = records.partition; + if (!assignedPartitions.contains(tp)) { + records.drain(); + completedFetchesItr.remove(); + } + } + + if (nextInLineFetch != null && !assignedPartitions.contains(nextInLineFetch.partition)) { + nextInLineFetch.drain(); + nextInLineFetch = null; + } + } + + /** + * Clear the buffered data which are not a part of newly assigned topics + * + * @param assignedTopics newly assigned topics + */ + public void clearBufferedDataForUnassignedTopics(Collection assignedTopics) { + Set currentTopicPartitions = new HashSet<>(); + for (TopicPartition tp : subscriptions.assignedPartitions()) { + if (assignedTopics.contains(tp.topic())) { + currentTopicPartitions.add(tp); + } + } + clearBufferedDataForUnassignedPartitions(currentTopicPartitions); + } + + // Visible for testing + protected FetchSessionHandler sessionHandler(int node) { + return sessionHandlers.get(node); + } + + public static Sensor throttleTimeSensor(Metrics metrics, FetcherMetricsRegistry metricsRegistry) { + Sensor fetchThrottleTimeSensor = metrics.sensor("fetch-throttle-time"); + fetchThrottleTimeSensor.add(metrics.metricInstance(metricsRegistry.fetchThrottleTimeAvg), new Avg()); + + fetchThrottleTimeSensor.add(metrics.metricInstance(metricsRegistry.fetchThrottleTimeMax), new Max()); + + return fetchThrottleTimeSensor; + } + + private class CompletedFetch { + private final TopicPartition partition; + private final Iterator batches; + private final Set abortedProducerIds; + private final PriorityQueue abortedTransactions; + private final FetchResponseData.PartitionData partitionData; + private final FetchResponseMetricAggregator metricAggregator; + private final short responseVersion; + + private int recordsRead; + private int bytesRead; + private RecordBatch currentBatch; + private Record lastRecord; + private CloseableIterator records; + private long nextFetchOffset; + private Optional lastEpoch; + private boolean isConsumed = false; + private Exception cachedRecordException = null; + private boolean corruptLastRecord = false; + private boolean initialized = false; + + private CompletedFetch(TopicPartition partition, + FetchResponseData.PartitionData partitionData, + FetchResponseMetricAggregator metricAggregator, + Iterator batches, + Long fetchOffset, + short responseVersion) { + this.partition = partition; + this.partitionData = partitionData; + this.metricAggregator = metricAggregator; + this.batches = batches; + this.nextFetchOffset = fetchOffset; + this.responseVersion = responseVersion; + this.lastEpoch = Optional.empty(); + this.abortedProducerIds = new HashSet<>(); + this.abortedTransactions = abortedTransactions(partitionData); + } + + private void drain() { + if (!isConsumed) { + maybeCloseRecordStream(); + cachedRecordException = null; + this.isConsumed = true; + this.metricAggregator.record(partition, bytesRead, recordsRead); + + // we move the partition to the end if we received some bytes. This way, it's more likely that partitions + // for the same topic can remain together (allowing for more efficient serialization). 
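+                // (initializeCompletedFetch applies the same reordering when a fetch completes with an error.)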
+ if (bytesRead > 0) + subscriptions.movePartitionToEnd(partition); + } + } + + private void maybeEnsureValid(RecordBatch batch) { + if (checkCrcs && currentBatch.magic() >= RecordBatch.MAGIC_VALUE_V2) { + try { + batch.ensureValid(); + } catch (CorruptRecordException e) { + throw new KafkaException("Record batch for partition " + partition + " at offset " + + batch.baseOffset() + " is invalid, cause: " + e.getMessage()); + } + } + } + + private void maybeEnsureValid(Record record) { + if (checkCrcs) { + try { + record.ensureValid(); + } catch (CorruptRecordException e) { + throw new KafkaException("Record for partition " + partition + " at offset " + record.offset() + + " is invalid, cause: " + e.getMessage()); + } + } + } + + private void maybeCloseRecordStream() { + if (records != null) { + records.close(); + records = null; + } + } + + private Record nextFetchedRecord() { + while (true) { + if (records == null || !records.hasNext()) { + maybeCloseRecordStream(); + + if (!batches.hasNext()) { + // Message format v2 preserves the last offset in a batch even if the last record is removed + // through compaction. By using the next offset computed from the last offset in the batch, + // we ensure that the offset of the next fetch will point to the next batch, which avoids + // unnecessary re-fetching of the same batch (in the worst case, the consumer could get stuck + // fetching the same batch repeatedly). + if (currentBatch != null) + nextFetchOffset = currentBatch.nextOffset(); + drain(); + return null; + } + + currentBatch = batches.next(); + lastEpoch = currentBatch.partitionLeaderEpoch() == RecordBatch.NO_PARTITION_LEADER_EPOCH ? + Optional.empty() : Optional.of(currentBatch.partitionLeaderEpoch()); + + maybeEnsureValid(currentBatch); + + if (isolationLevel == IsolationLevel.READ_COMMITTED && currentBatch.hasProducerId()) { + // remove from the aborted transaction queue all aborted transactions which have begun + // before the current batch's last offset and add the associated producerIds to the + // aborted producer set + consumeAbortedTransactionsUpTo(currentBatch.lastOffset()); + + long producerId = currentBatch.producerId(); + if (containsAbortMarker(currentBatch)) { + abortedProducerIds.remove(producerId); + } else if (isBatchAborted(currentBatch)) { + log.debug("Skipping aborted record batch from partition {} with producerId {} and " + + "offsets {} to {}", + partition, producerId, currentBatch.baseOffset(), currentBatch.lastOffset()); + nextFetchOffset = currentBatch.nextOffset(); + continue; + } + } + + records = currentBatch.streamingIterator(decompressionBufferSupplier); + } else { + Record record = records.next(); + // skip any records out of range + if (record.offset() >= nextFetchOffset) { + // we only do validation when the message should not be skipped. + maybeEnsureValid(record); + + // control records are not returned to the user + if (!currentBatch.isControlBatch()) { + return record; + } else { + // Increment the next fetch offset when we skip a control batch. + nextFetchOffset = record.offset() + 1; + } + } + } + } + } + + private List> fetchRecords(int maxRecords) { + // Error when fetching the next record before deserialization. + if (corruptLastRecord) + throw new KafkaException("Received exception when fetching the next record from " + partition + + ". 
If needed, please seek past the record to " + + "continue consumption.", cachedRecordException); + + if (isConsumed) + return Collections.emptyList(); + + List> records = new ArrayList<>(); + try { + for (int i = 0; i < maxRecords; i++) { + // Only move to next record if there was no exception in the last fetch. Otherwise we should + // use the last record to do deserialization again. + if (cachedRecordException == null) { + corruptLastRecord = true; + lastRecord = nextFetchedRecord(); + corruptLastRecord = false; + } + if (lastRecord == null) + break; + records.add(parseRecord(partition, currentBatch, lastRecord)); + recordsRead++; + bytesRead += lastRecord.sizeInBytes(); + nextFetchOffset = lastRecord.offset() + 1; + // In some cases, the deserialization may have thrown an exception and the retry may succeed, + // we allow user to move forward in this case. + cachedRecordException = null; + } + } catch (SerializationException se) { + cachedRecordException = se; + if (records.isEmpty()) + throw se; + } catch (KafkaException e) { + cachedRecordException = e; + if (records.isEmpty()) + throw new KafkaException("Received exception when fetching the next record from " + partition + + ". If needed, please seek past the record to " + + "continue consumption.", e); + } + return records; + } + + private void consumeAbortedTransactionsUpTo(long offset) { + if (abortedTransactions == null) + return; + + while (!abortedTransactions.isEmpty() && abortedTransactions.peek().firstOffset() <= offset) { + FetchResponseData.AbortedTransaction abortedTransaction = abortedTransactions.poll(); + abortedProducerIds.add(abortedTransaction.producerId()); + } + } + + private boolean isBatchAborted(RecordBatch batch) { + return batch.isTransactional() && abortedProducerIds.contains(batch.producerId()); + } + + private PriorityQueue abortedTransactions(FetchResponseData.PartitionData partition) { + if (partition.abortedTransactions() == null || partition.abortedTransactions().isEmpty()) + return null; + + PriorityQueue abortedTransactions = new PriorityQueue<>( + partition.abortedTransactions().size(), Comparator.comparingLong(FetchResponseData.AbortedTransaction::firstOffset) + ); + abortedTransactions.addAll(partition.abortedTransactions()); + return abortedTransactions; + } + + private boolean containsAbortMarker(RecordBatch batch) { + if (!batch.isControlBatch()) + return false; + + Iterator batchIterator = batch.iterator(); + if (!batchIterator.hasNext()) + return false; + + Record firstRecord = batchIterator.next(); + return ControlRecordType.ABORT == ControlRecordType.parse(firstRecord.key()); + } + + private boolean notInitialized() { + return !this.initialized; + } + } + + /** + * Since we parse the message data for each partition from each fetch response lazily, fetch-level + * metrics need to be aggregated as the messages from each partition are parsed. This class is used + * to facilitate this incremental aggregation. + */ + private static class FetchResponseMetricAggregator { + private final FetchManagerMetrics sensors; + private final Set unrecordedPartitions; + + private final FetchMetrics fetchMetrics = new FetchMetrics(); + private final Map topicFetchMetrics = new HashMap<>(); + + private FetchResponseMetricAggregator(FetchManagerMetrics sensors, + Set partitions) { + this.sensors = sensors; + this.unrecordedPartitions = partitions; + } + + /** + * After each partition is parsed, we update the current metric totals with the total bytes + * and number of records parsed. 
After all partitions have reported, we write the metric. + */ + public void record(TopicPartition partition, int bytes, int records) { + this.unrecordedPartitions.remove(partition); + this.fetchMetrics.increment(bytes, records); + + // collect and aggregate per-topic metrics + String topic = partition.topic(); + FetchMetrics topicFetchMetric = this.topicFetchMetrics.get(topic); + if (topicFetchMetric == null) { + topicFetchMetric = new FetchMetrics(); + this.topicFetchMetrics.put(topic, topicFetchMetric); + } + topicFetchMetric.increment(bytes, records); + + if (this.unrecordedPartitions.isEmpty()) { + // once all expected partitions from the fetch have reported in, record the metrics + this.sensors.bytesFetched.record(this.fetchMetrics.fetchBytes); + this.sensors.recordsFetched.record(this.fetchMetrics.fetchRecords); + + // also record per-topic metrics + for (Map.Entry entry: this.topicFetchMetrics.entrySet()) { + FetchMetrics metric = entry.getValue(); + this.sensors.recordTopicFetchMetrics(entry.getKey(), metric.fetchBytes, metric.fetchRecords); + } + } + } + + private static class FetchMetrics { + private int fetchBytes; + private int fetchRecords; + + protected void increment(int bytes, int records) { + this.fetchBytes += bytes; + this.fetchRecords += records; + } + } + } + + private static class FetchManagerMetrics { + private final Metrics metrics; + private FetcherMetricsRegistry metricsRegistry; + private final Sensor bytesFetched; + private final Sensor recordsFetched; + private final Sensor fetchLatency; + private final Sensor recordsFetchLag; + private final Sensor recordsFetchLead; + + private int assignmentId = 0; + private Set assignedPartitions = Collections.emptySet(); + + private FetchManagerMetrics(Metrics metrics, FetcherMetricsRegistry metricsRegistry) { + this.metrics = metrics; + this.metricsRegistry = metricsRegistry; + + this.bytesFetched = metrics.sensor("bytes-fetched"); + this.bytesFetched.add(metrics.metricInstance(metricsRegistry.fetchSizeAvg), new Avg()); + this.bytesFetched.add(metrics.metricInstance(metricsRegistry.fetchSizeMax), new Max()); + this.bytesFetched.add(new Meter(metrics.metricInstance(metricsRegistry.bytesConsumedRate), + metrics.metricInstance(metricsRegistry.bytesConsumedTotal))); + + this.recordsFetched = metrics.sensor("records-fetched"); + this.recordsFetched.add(metrics.metricInstance(metricsRegistry.recordsPerRequestAvg), new Avg()); + this.recordsFetched.add(new Meter(metrics.metricInstance(metricsRegistry.recordsConsumedRate), + metrics.metricInstance(metricsRegistry.recordsConsumedTotal))); + + this.fetchLatency = metrics.sensor("fetch-latency"); + this.fetchLatency.add(metrics.metricInstance(metricsRegistry.fetchLatencyAvg), new Avg()); + this.fetchLatency.add(metrics.metricInstance(metricsRegistry.fetchLatencyMax), new Max()); + this.fetchLatency.add(new Meter(new WindowedCount(), metrics.metricInstance(metricsRegistry.fetchRequestRate), + metrics.metricInstance(metricsRegistry.fetchRequestTotal))); + + this.recordsFetchLag = metrics.sensor("records-lag"); + this.recordsFetchLag.add(metrics.metricInstance(metricsRegistry.recordsLagMax), new Max()); + + this.recordsFetchLead = metrics.sensor("records-lead"); + this.recordsFetchLead.add(metrics.metricInstance(metricsRegistry.recordsLeadMin), new Min()); + } + + private void recordTopicFetchMetrics(String topic, int bytes, int records) { + // record bytes fetched + String name = "topic." 
+ topic + ".bytes-fetched"; + Sensor bytesFetched = this.metrics.getSensor(name); + if (bytesFetched == null) { + Map metricTags = Collections.singletonMap("topic", topic.replace('.', '_')); + + bytesFetched = this.metrics.sensor(name); + bytesFetched.add(this.metrics.metricInstance(metricsRegistry.topicFetchSizeAvg, + metricTags), new Avg()); + bytesFetched.add(this.metrics.metricInstance(metricsRegistry.topicFetchSizeMax, + metricTags), new Max()); + bytesFetched.add(new Meter(this.metrics.metricInstance(metricsRegistry.topicBytesConsumedRate, metricTags), + this.metrics.metricInstance(metricsRegistry.topicBytesConsumedTotal, metricTags))); + } + bytesFetched.record(bytes); + + // record records fetched + name = "topic." + topic + ".records-fetched"; + Sensor recordsFetched = this.metrics.getSensor(name); + if (recordsFetched == null) { + Map metricTags = new HashMap<>(1); + metricTags.put("topic", topic.replace('.', '_')); + + recordsFetched = this.metrics.sensor(name); + recordsFetched.add(this.metrics.metricInstance(metricsRegistry.topicRecordsPerRequestAvg, + metricTags), new Avg()); + recordsFetched.add(new Meter(this.metrics.metricInstance(metricsRegistry.topicRecordsConsumedRate, metricTags), + this.metrics.metricInstance(metricsRegistry.topicRecordsConsumedTotal, metricTags))); + } + recordsFetched.record(records); + } + + private void maybeUpdateAssignment(SubscriptionState subscription) { + int newAssignmentId = subscription.assignmentId(); + if (this.assignmentId != newAssignmentId) { + Set newAssignedPartitions = subscription.assignedPartitions(); + for (TopicPartition tp : this.assignedPartitions) { + if (!newAssignedPartitions.contains(tp)) { + metrics.removeSensor(partitionLagMetricName(tp)); + metrics.removeSensor(partitionLeadMetricName(tp)); + metrics.removeMetric(partitionPreferredReadReplicaMetricName(tp)); + } + } + + for (TopicPartition tp : newAssignedPartitions) { + if (!this.assignedPartitions.contains(tp)) { + MetricName metricName = partitionPreferredReadReplicaMetricName(tp); + if (metrics.metric(metricName) == null) { + metrics.addMetric( + metricName, + (Gauge) (config, now) -> subscription.preferredReadReplica(tp, 0L).orElse(-1) + ); + } + } + } + + this.assignedPartitions = newAssignedPartitions; + this.assignmentId = newAssignmentId; + } + } + + private void recordPartitionLead(TopicPartition tp, long lead) { + this.recordsFetchLead.record(lead); + + String name = partitionLeadMetricName(tp); + Sensor recordsLead = this.metrics.getSensor(name); + if (recordsLead == null) { + Map metricTags = topicPartitionTags(tp); + + recordsLead = this.metrics.sensor(name); + + recordsLead.add(this.metrics.metricInstance(metricsRegistry.partitionRecordsLead, metricTags), new Value()); + recordsLead.add(this.metrics.metricInstance(metricsRegistry.partitionRecordsLeadMin, metricTags), new Min()); + recordsLead.add(this.metrics.metricInstance(metricsRegistry.partitionRecordsLeadAvg, metricTags), new Avg()); + } + recordsLead.record(lead); + } + + private void recordPartitionLag(TopicPartition tp, long lag) { + this.recordsFetchLag.record(lag); + + String name = partitionLagMetricName(tp); + Sensor recordsLag = this.metrics.getSensor(name); + if (recordsLag == null) { + Map metricTags = topicPartitionTags(tp); + recordsLag = this.metrics.sensor(name); + + recordsLag.add(this.metrics.metricInstance(metricsRegistry.partitionRecordsLag, metricTags), new Value()); + recordsLag.add(this.metrics.metricInstance(metricsRegistry.partitionRecordsLagMax, metricTags), new Max()); + 
recordsLag.add(this.metrics.metricInstance(metricsRegistry.partitionRecordsLagAvg, metricTags), new Avg()); + } + recordsLag.record(lag); + } + + private static String partitionLagMetricName(TopicPartition tp) { + return tp + ".records-lag"; + } + + private static String partitionLeadMetricName(TopicPartition tp) { + return tp + ".records-lead"; + } + + private MetricName partitionPreferredReadReplicaMetricName(TopicPartition tp) { + Map metricTags = topicPartitionTags(tp); + return this.metrics.metricInstance(metricsRegistry.partitionPreferredReadReplica, metricTags); + } + + private Map topicPartitionTags(TopicPartition tp) { + Map metricTags = new HashMap<>(2); + metricTags.put("topic", tp.topic().replace('.', '_')); + metricTags.put("partition", String.valueOf(tp.partition())); + return metricTags; + } + } + + @Override + public void close() { + if (nextInLineFetch != null) + nextInLineFetch.drain(); + decompressionBufferSupplier.close(); + } + + private Set topicsForPartitions(Collection partitions) { + return partitions.stream().map(TopicPartition::topic).collect(Collectors.toSet()); + } + +} diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/FetcherMetricsRegistry.java b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/FetcherMetricsRegistry.java new file mode 100644 index 0000000..501ffe9 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/FetcherMetricsRegistry.java @@ -0,0 +1,182 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.clients.consumer.internals; + +import java.util.Arrays; +import java.util.HashSet; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Set; + +import org.apache.kafka.common.MetricNameTemplate; + +public class FetcherMetricsRegistry { + + public MetricNameTemplate fetchSizeAvg; + public MetricNameTemplate fetchSizeMax; + public MetricNameTemplate bytesConsumedRate; + public MetricNameTemplate bytesConsumedTotal; + public MetricNameTemplate recordsPerRequestAvg; + public MetricNameTemplate recordsConsumedRate; + public MetricNameTemplate recordsConsumedTotal; + public MetricNameTemplate fetchLatencyAvg; + public MetricNameTemplate fetchLatencyMax; + public MetricNameTemplate fetchRequestRate; + public MetricNameTemplate fetchRequestTotal; + public MetricNameTemplate recordsLagMax; + public MetricNameTemplate recordsLeadMin; + public MetricNameTemplate fetchThrottleTimeAvg; + public MetricNameTemplate fetchThrottleTimeMax; + public MetricNameTemplate topicFetchSizeAvg; + public MetricNameTemplate topicFetchSizeMax; + public MetricNameTemplate topicBytesConsumedRate; + public MetricNameTemplate topicBytesConsumedTotal; + public MetricNameTemplate topicRecordsPerRequestAvg; + public MetricNameTemplate topicRecordsConsumedRate; + public MetricNameTemplate topicRecordsConsumedTotal; + public MetricNameTemplate partitionRecordsLag; + public MetricNameTemplate partitionRecordsLagMax; + public MetricNameTemplate partitionRecordsLagAvg; + public MetricNameTemplate partitionRecordsLead; + public MetricNameTemplate partitionRecordsLeadMin; + public MetricNameTemplate partitionRecordsLeadAvg; + public MetricNameTemplate partitionPreferredReadReplica; + + public FetcherMetricsRegistry() { + this(new HashSet(), ""); + } + + public FetcherMetricsRegistry(String metricGrpPrefix) { + this(new HashSet(), metricGrpPrefix); + } + + public FetcherMetricsRegistry(Set tags, String metricGrpPrefix) { + + /***** Client level *****/ + String groupName = metricGrpPrefix + "-fetch-manager-metrics"; + + this.fetchSizeAvg = new MetricNameTemplate("fetch-size-avg", groupName, + "The average number of bytes fetched per request", tags); + + this.fetchSizeMax = new MetricNameTemplate("fetch-size-max", groupName, + "The maximum number of bytes fetched per request", tags); + this.bytesConsumedRate = new MetricNameTemplate("bytes-consumed-rate", groupName, + "The average number of bytes consumed per second", tags); + this.bytesConsumedTotal = new MetricNameTemplate("bytes-consumed-total", groupName, + "The total number of bytes consumed", tags); + + this.recordsPerRequestAvg = new MetricNameTemplate("records-per-request-avg", groupName, + "The average number of records in each request", tags); + this.recordsConsumedRate = new MetricNameTemplate("records-consumed-rate", groupName, + "The average number of records consumed per second", tags); + this.recordsConsumedTotal = new MetricNameTemplate("records-consumed-total", groupName, + "The total number of records consumed", tags); + + this.fetchLatencyAvg = new MetricNameTemplate("fetch-latency-avg", groupName, + "The average time taken for a fetch request.", tags); + this.fetchLatencyMax = new MetricNameTemplate("fetch-latency-max", groupName, + "The max time taken for any fetch request.", tags); + this.fetchRequestRate = new MetricNameTemplate("fetch-rate", groupName, + "The number of fetch requests per second.", tags); + this.fetchRequestTotal = new MetricNameTemplate("fetch-total", groupName, + "The total number of fetch 
requests.", tags); + + this.recordsLagMax = new MetricNameTemplate("records-lag-max", groupName, + "The maximum lag in terms of number of records for any partition in this window", tags); + this.recordsLeadMin = new MetricNameTemplate("records-lead-min", groupName, + "The minimum lead in terms of number of records for any partition in this window", tags); + + this.fetchThrottleTimeAvg = new MetricNameTemplate("fetch-throttle-time-avg", groupName, + "The average throttle time in ms", tags); + this.fetchThrottleTimeMax = new MetricNameTemplate("fetch-throttle-time-max", groupName, + "The maximum throttle time in ms", tags); + + /***** Topic level *****/ + Set topicTags = new LinkedHashSet<>(tags); + topicTags.add("topic"); + + this.topicFetchSizeAvg = new MetricNameTemplate("fetch-size-avg", groupName, + "The average number of bytes fetched per request for a topic", topicTags); + this.topicFetchSizeMax = new MetricNameTemplate("fetch-size-max", groupName, + "The maximum number of bytes fetched per request for a topic", topicTags); + this.topicBytesConsumedRate = new MetricNameTemplate("bytes-consumed-rate", groupName, + "The average number of bytes consumed per second for a topic", topicTags); + this.topicBytesConsumedTotal = new MetricNameTemplate("bytes-consumed-total", groupName, + "The total number of bytes consumed for a topic", topicTags); + + this.topicRecordsPerRequestAvg = new MetricNameTemplate("records-per-request-avg", groupName, + "The average number of records in each request for a topic", topicTags); + this.topicRecordsConsumedRate = new MetricNameTemplate("records-consumed-rate", groupName, + "The average number of records consumed per second for a topic", topicTags); + this.topicRecordsConsumedTotal = new MetricNameTemplate("records-consumed-total", groupName, + "The total number of records consumed for a topic", topicTags); + + /***** Partition level *****/ + Set partitionTags = new HashSet<>(topicTags); + partitionTags.add("partition"); + this.partitionRecordsLag = new MetricNameTemplate("records-lag", groupName, + "The latest lag of the partition", partitionTags); + this.partitionRecordsLagMax = new MetricNameTemplate("records-lag-max", groupName, + "The max lag of the partition", partitionTags); + this.partitionRecordsLagAvg = new MetricNameTemplate("records-lag-avg", groupName, + "The average lag of the partition", partitionTags); + this.partitionRecordsLead = new MetricNameTemplate("records-lead", groupName, + "The latest lead of the partition", partitionTags); + this.partitionRecordsLeadMin = new MetricNameTemplate("records-lead-min", groupName, + "The min lead of the partition", partitionTags); + this.partitionRecordsLeadAvg = new MetricNameTemplate("records-lead-avg", groupName, + "The average lead of the partition", partitionTags); + this.partitionPreferredReadReplica = new MetricNameTemplate( + "preferred-read-replica", "consumer-fetch-manager-metrics", + "The current read replica for the partition, or -1 if reading from leader", partitionTags); + } + + public List getAllTemplates() { + return Arrays.asList( + fetchSizeAvg, + fetchSizeMax, + bytesConsumedRate, + bytesConsumedTotal, + recordsPerRequestAvg, + recordsConsumedRate, + recordsConsumedTotal, + fetchLatencyAvg, + fetchLatencyMax, + fetchRequestRate, + fetchRequestTotal, + recordsLagMax, + recordsLeadMin, + fetchThrottleTimeAvg, + fetchThrottleTimeMax, + topicFetchSizeAvg, + topicFetchSizeMax, + topicBytesConsumedRate, + topicBytesConsumedTotal, + topicRecordsPerRequestAvg, + topicRecordsConsumedRate, + 
topicRecordsConsumedTotal, + partitionRecordsLag, + partitionRecordsLagAvg, + partitionRecordsLagMax, + partitionRecordsLead, + partitionRecordsLeadMin, + partitionRecordsLeadAvg, + partitionPreferredReadReplica + ); + } + +} diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/Heartbeat.java b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/Heartbeat.java new file mode 100644 index 0000000..dfb9f85 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/Heartbeat.java @@ -0,0 +1,135 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients.consumer.internals; + +import org.apache.kafka.clients.GroupRebalanceConfig; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.common.utils.Timer; + +import org.slf4j.Logger; + +/** + * A helper class for managing the heartbeat to the coordinator + */ +public final class Heartbeat { + private final int maxPollIntervalMs; + private final GroupRebalanceConfig rebalanceConfig; + private final Time time; + private final Timer heartbeatTimer; + private final Timer sessionTimer; + private final Timer pollTimer; + private final Logger log; + + private volatile long lastHeartbeatSend = 0L; + private volatile boolean heartbeatInFlight = false; + + public Heartbeat(GroupRebalanceConfig config, + Time time) { + if (config.heartbeatIntervalMs >= config.sessionTimeoutMs) + throw new IllegalArgumentException("Heartbeat must be set lower than the session timeout"); + this.rebalanceConfig = config; + this.time = time; + this.heartbeatTimer = time.timer(config.heartbeatIntervalMs); + this.sessionTimer = time.timer(config.sessionTimeoutMs); + this.maxPollIntervalMs = config.rebalanceTimeoutMs; + this.pollTimer = time.timer(maxPollIntervalMs); + + final LogContext logContext = new LogContext("[Heartbeat groupID=" + config.groupId + "] "); + this.log = logContext.logger(getClass()); + } + + private void update(long now) { + heartbeatTimer.update(now); + sessionTimer.update(now); + pollTimer.update(now); + } + + public void poll(long now) { + update(now); + pollTimer.reset(maxPollIntervalMs); + } + + boolean hasInflight() { + return heartbeatInFlight; + } + + void sentHeartbeat(long now) { + lastHeartbeatSend = now; + heartbeatInFlight = true; + update(now); + heartbeatTimer.reset(rebalanceConfig.heartbeatIntervalMs); + + if (log.isTraceEnabled()) { + log.trace("Sending heartbeat request with {}ms remaining on timer", heartbeatTimer.remainingMs()); + } + } + + void failHeartbeat() { + update(time.milliseconds()); + heartbeatInFlight = false; + heartbeatTimer.reset(rebalanceConfig.retryBackoffMs); + + log.trace("Heartbeat failed, reset the timer to {}ms 
remaining", heartbeatTimer.remainingMs()); + } + + void receiveHeartbeat() { + update(time.milliseconds()); + heartbeatInFlight = false; + sessionTimer.reset(rebalanceConfig.sessionTimeoutMs); + } + + boolean shouldHeartbeat(long now) { + update(now); + return heartbeatTimer.isExpired(); + } + + long lastHeartbeatSend() { + return this.lastHeartbeatSend; + } + + long timeToNextHeartbeat(long now) { + update(now); + return heartbeatTimer.remainingMs(); + } + + boolean sessionTimeoutExpired(long now) { + update(now); + return sessionTimer.isExpired(); + } + + void resetTimeouts() { + update(time.milliseconds()); + sessionTimer.reset(rebalanceConfig.sessionTimeoutMs); + pollTimer.reset(maxPollIntervalMs); + heartbeatTimer.reset(rebalanceConfig.heartbeatIntervalMs); + } + + void resetSessionTimeout() { + update(time.milliseconds()); + sessionTimer.reset(rebalanceConfig.sessionTimeoutMs); + } + + boolean pollTimeoutExpired(long now) { + update(now); + return pollTimer.isExpired(); + } + + long lastPollTime() { + return pollTimer.currentTimeMs(); + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/KafkaConsumerMetrics.java b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/KafkaConsumerMetrics.java new file mode 100644 index 0000000..0dc8a33 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/KafkaConsumerMetrics.java @@ -0,0 +1,120 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients.consumer.internals; + +import org.apache.kafka.common.MetricName; +import org.apache.kafka.common.metrics.Measurable; +import org.apache.kafka.common.metrics.Metrics; +import org.apache.kafka.common.metrics.Sensor; +import org.apache.kafka.common.metrics.stats.Avg; +import org.apache.kafka.common.metrics.stats.CumulativeSum; +import org.apache.kafka.common.metrics.stats.Max; + +import java.util.concurrent.TimeUnit; + +public class KafkaConsumerMetrics implements AutoCloseable { + private final MetricName lastPollMetricName; + private final Sensor timeBetweenPollSensor; + private final Sensor pollIdleSensor; + private final Sensor committedSensor; + private final Sensor commitSyncSensor; + private final Metrics metrics; + private long lastPollMs; + private long pollStartMs; + private long timeSinceLastPollMs; + + public KafkaConsumerMetrics(Metrics metrics, String metricGrpPrefix) { + this.metrics = metrics; + String metricGroupName = metricGrpPrefix + "-metrics"; + Measurable lastPoll = (mConfig, now) -> { + if (lastPollMs == 0L) + // if no poll is ever triggered, just return -1. 
+ return -1d; + else + return TimeUnit.SECONDS.convert(now - lastPollMs, TimeUnit.MILLISECONDS); + }; + this.lastPollMetricName = metrics.metricName("last-poll-seconds-ago", + metricGroupName, "The number of seconds since the last poll() invocation."); + metrics.addMetric(lastPollMetricName, lastPoll); + + this.timeBetweenPollSensor = metrics.sensor("time-between-poll"); + this.timeBetweenPollSensor.add(metrics.metricName("time-between-poll-avg", + metricGroupName, + "The average delay between invocations of poll() in milliseconds."), + new Avg()); + this.timeBetweenPollSensor.add(metrics.metricName("time-between-poll-max", + metricGroupName, + "The max delay between invocations of poll() in milliseconds."), + new Max()); + + this.pollIdleSensor = metrics.sensor("poll-idle-ratio-avg"); + this.pollIdleSensor.add(metrics.metricName("poll-idle-ratio-avg", + metricGroupName, + "The average fraction of time the consumer's poll() is idle as opposed to waiting for the user code to process records."), + new Avg()); + + this.commitSyncSensor = metrics.sensor("commit-sync-time-ns-total"); + this.commitSyncSensor.add( + metrics.metricName( + "commit-sync-time-ns-total", + metricGroupName, + "The total time the consumer has spent in commitSync in nanoseconds" + ), + new CumulativeSum() + ); + + this.committedSensor = metrics.sensor("committed-time-ns-total"); + this.committedSensor.add( + metrics.metricName( + "committed-time-ns-total", + metricGroupName, + "The total time the consumer has spent in committed in nanoseconds" + ), + new CumulativeSum() + ); + } + + public void recordPollStart(long pollStartMs) { + this.pollStartMs = pollStartMs; + this.timeSinceLastPollMs = lastPollMs != 0L ? pollStartMs - lastPollMs : 0; + this.timeBetweenPollSensor.record(timeSinceLastPollMs); + this.lastPollMs = pollStartMs; + } + + public void recordPollEnd(long pollEndMs) { + long pollTimeMs = pollEndMs - pollStartMs; + double pollIdleRatio = pollTimeMs * 1.0 / (pollTimeMs + timeSinceLastPollMs); + this.pollIdleSensor.record(pollIdleRatio); + } + + public void recordCommitSync(long duration) { + this.commitSyncSensor.record(duration); + } + + public void recordCommitted(long duration) { + this.committedSensor.record(duration); + } + + @Override + public void close() { + metrics.removeMetric(lastPollMetricName); + metrics.removeSensor(timeBetweenPollSensor.name()); + metrics.removeSensor(pollIdleSensor.name()); + metrics.removeSensor(commitSyncSensor.name()); + metrics.removeSensor(committedSensor.name()); + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/NoAvailableBrokersException.java b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/NoAvailableBrokersException.java new file mode 100644 index 0000000..d1ad6a4 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/NoAvailableBrokersException.java @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
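The poll sensors above (time-between-poll and poll-idle-ratio-avg) reduce to a little timestamp arithmetic in recordPollStart()/recordPollEnd(): the idle ratio is the time spent inside poll() divided by the total time since the previous poll started. A standalone sketch of the same calculation, without Metrics or Sensor objects (PollIdleTracker is an invented name):

    final class PollIdleTracker {
        private long lastPollStartMs;
        private long pollStartMs;
        private long timeSinceLastPollMs;

        void recordPollStart(long nowMs) {
            this.timeSinceLastPollMs = lastPollStartMs != 0L ? nowMs - lastPollStartMs : 0L;
            this.pollStartMs = nowMs;
            this.lastPollStartMs = nowMs;
        }

        double recordPollEnd(long nowMs) {
            long pollTimeMs = nowMs - pollStartMs;
            long totalMs = pollTimeMs + timeSinceLastPollMs;
            // Zero guard added for this standalone sketch; the real code records the ratio on a sensor.
            return totalMs == 0L ? 1.0 : pollTimeMs * 1.0 / totalMs;
        }
    }
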
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients.consumer.internals; + +import org.apache.kafka.common.errors.InvalidMetadataException; + +/** + * No brokers were available to complete a request. + */ +public class NoAvailableBrokersException extends InvalidMetadataException { + private static final long serialVersionUID = 1L; + +} diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/NoOpConsumerRebalanceListener.java b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/NoOpConsumerRebalanceListener.java new file mode 100644 index 0000000..a3acc83 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/NoOpConsumerRebalanceListener.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients.consumer.internals; + +import org.apache.kafka.clients.consumer.ConsumerRebalanceListener; +import org.apache.kafka.common.TopicPartition; + +import java.util.Collection; + +public class NoOpConsumerRebalanceListener implements ConsumerRebalanceListener { + + @Override + public void onPartitionsAssigned(Collection partitions) {} + + @Override + public void onPartitionsRevoked(Collection partitions) {} + +} diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/OffsetsForLeaderEpochClient.java b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/OffsetsForLeaderEpochClient.java new file mode 100644 index 0000000..57650f5 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/OffsetsForLeaderEpochClient.java @@ -0,0 +1,149 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
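NoOpConsumerRebalanceListener above is a null-object default: code with no interest in rebalance callbacks gets a listener that does nothing, so nothing downstream has to branch on a missing listener. For contrast, a minimal listener that does react to the same callbacks (LoggingRebalanceListener is an invented name; the interface and its two callbacks are the public ConsumerRebalanceListener API):

    import org.apache.kafka.clients.consumer.ConsumerRebalanceListener;
    import org.apache.kafka.common.TopicPartition;

    import java.util.Collection;

    public class LoggingRebalanceListener implements ConsumerRebalanceListener {

        @Override
        public void onPartitionsAssigned(Collection<TopicPartition> partitions) {
            System.out.println("Assigned: " + partitions);
        }

        @Override
        public void onPartitionsRevoked(Collection<TopicPartition> partitions) {
            System.out.println("Revoked: " + partitions);
        }
    }
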
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients.consumer.internals; + +import org.apache.kafka.common.Node; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.errors.TopicAuthorizationException; +import org.apache.kafka.common.message.OffsetForLeaderEpochRequestData.OffsetForLeaderPartition; +import org.apache.kafka.common.message.OffsetForLeaderEpochRequestData.OffsetForLeaderTopic; +import org.apache.kafka.common.message.OffsetForLeaderEpochRequestData.OffsetForLeaderTopicCollection; +import org.apache.kafka.common.message.OffsetForLeaderEpochResponseData.EpochEndOffset; +import org.apache.kafka.common.message.OffsetForLeaderEpochResponseData.OffsetForLeaderTopicResult; +import org.apache.kafka.common.protocol.Errors; +import org.apache.kafka.common.record.RecordBatch; +import org.apache.kafka.common.requests.AbstractRequest; +import org.apache.kafka.common.requests.OffsetsForLeaderEpochRequest; +import org.apache.kafka.common.requests.OffsetsForLeaderEpochResponse; +import org.apache.kafka.common.utils.LogContext; + +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; + +/** + * Convenience class for making asynchronous requests to the OffsetsForLeaderEpoch API + */ +public class OffsetsForLeaderEpochClient extends AsyncClient< + Map, + OffsetsForLeaderEpochRequest, + OffsetsForLeaderEpochResponse, + OffsetsForLeaderEpochClient.OffsetForEpochResult> { + + OffsetsForLeaderEpochClient(ConsumerNetworkClient client, LogContext logContext) { + super(client, logContext); + } + + @Override + protected AbstractRequest.Builder prepareRequest( + Node node, Map requestData) { + OffsetForLeaderTopicCollection topics = new OffsetForLeaderTopicCollection(requestData.size()); + requestData.forEach((topicPartition, fetchPosition) -> + fetchPosition.offsetEpoch.ifPresent(fetchEpoch -> { + OffsetForLeaderTopic topic = topics.find(topicPartition.topic()); + if (topic == null) { + topic = new OffsetForLeaderTopic().setTopic(topicPartition.topic()); + topics.add(topic); + } + topic.partitions().add(new OffsetForLeaderPartition() + .setPartition(topicPartition.partition()) + .setLeaderEpoch(fetchEpoch) + .setCurrentLeaderEpoch(fetchPosition.currentLeader.epoch + .orElse(RecordBatch.NO_PARTITION_LEADER_EPOCH)) + ); + }) + ); + return OffsetsForLeaderEpochRequest.Builder.forConsumer(topics); + } + + @Override + protected OffsetForEpochResult handleResponse( + Node node, + Map requestData, + OffsetsForLeaderEpochResponse response) { + + Set partitionsToRetry = new HashSet<>(requestData.keySet()); + Set unauthorizedTopics = new HashSet<>(); + Map endOffsets = new HashMap<>(); + + for (OffsetForLeaderTopicResult topic : response.data().topics()) { + for (EpochEndOffset partition : topic.partitions()) { + TopicPartition topicPartition = new TopicPartition(topic.topic(), partition.partition()); + + if (!requestData.containsKey(topicPartition)) { + logger().warn("Received unrequested topic or partition {} from response, ignoring.", topicPartition); + continue; + } + + Errors error = Errors.forCode(partition.errorCode()); + switch (error) { + case NONE: + logger().debug("Handling OffsetsForLeaderEpoch response for {}. 
Got offset {} for epoch {}.", + topicPartition, partition.endOffset(), partition.leaderEpoch()); + endOffsets.put(topicPartition, partition); + partitionsToRetry.remove(topicPartition); + break; + case NOT_LEADER_OR_FOLLOWER: + case REPLICA_NOT_AVAILABLE: + case KAFKA_STORAGE_ERROR: + case OFFSET_NOT_AVAILABLE: + case LEADER_NOT_AVAILABLE: + case FENCED_LEADER_EPOCH: + case UNKNOWN_LEADER_EPOCH: + logger().debug("Attempt to fetch offsets for partition {} failed due to {}, retrying.", + topicPartition, error); + break; + case UNKNOWN_TOPIC_OR_PARTITION: + logger().warn("Received unknown topic or partition error in OffsetsForLeaderEpoch request for partition {}.", + topicPartition); + break; + case TOPIC_AUTHORIZATION_FAILED: + unauthorizedTopics.add(topicPartition.topic()); + partitionsToRetry.remove(topicPartition); + break; + default: + logger().warn("Attempt to fetch offsets for partition {} failed due to: {}, retrying.", + topicPartition, error.message()); + } + } + } + + if (!unauthorizedTopics.isEmpty()) + throw new TopicAuthorizationException(unauthorizedTopics); + else + return new OffsetForEpochResult(endOffsets, partitionsToRetry); + } + + public static class OffsetForEpochResult { + private final Map endOffsets; + private final Set partitionsToRetry; + + OffsetForEpochResult(Map endOffsets, Set partitionsNeedingRetry) { + this.endOffsets = endOffsets; + this.partitionsToRetry = partitionsNeedingRetry; + } + + public Map endOffsets() { + return endOffsets; + } + + public Set partitionsToRetry() { + return partitionsToRetry; + } + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/RequestFuture.java b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/RequestFuture.java new file mode 100644 index 0000000..8a9d970 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/RequestFuture.java @@ -0,0 +1,255 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients.consumer.internals; + +import org.apache.kafka.common.errors.RetriableException; +import org.apache.kafka.common.protocol.Errors; +import org.apache.kafka.common.utils.Timer; + +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.atomic.AtomicReference; + +/** + * Result of an asynchronous request from {@link ConsumerNetworkClient}. Use {@link ConsumerNetworkClient#poll(Timer)} + * (and variants) to finish a request future. Use {@link #isDone()} to check if the future is complete, and + * {@link #succeeded()} to check if the request completed successfully. Typical usage might look like this: + * + *
+ *     RequestFuture<ClientResponse> future = client.send(api, request);
+ *     client.poll(future);
+ *
+ *     if (future.succeeded()) {
+ *         ClientResponse response = future.value();
+ *         // Handle response
+ *     } else {
+ *         throw future.exception();
+ *     }
+ * 
+ * + * @param Return type of the result (Can be Void if there is no response) + */ +public class RequestFuture implements ConsumerNetworkClient.PollCondition { + + private static final Object INCOMPLETE_SENTINEL = new Object(); + private final AtomicReference result = new AtomicReference<>(INCOMPLETE_SENTINEL); + private final ConcurrentLinkedQueue> listeners = new ConcurrentLinkedQueue<>(); + private final CountDownLatch completedLatch = new CountDownLatch(1); + + /** + * Check whether the response is ready to be handled + * @return true if the response is ready, false otherwise + */ + public boolean isDone() { + return result.get() != INCOMPLETE_SENTINEL; + } + + public boolean awaitDone(long timeout, TimeUnit unit) throws InterruptedException { + return completedLatch.await(timeout, unit); + } + + /** + * Get the value corresponding to this request (only available if the request succeeded) + * @return the value set in {@link #complete(Object)} + * @throws IllegalStateException if the future is not complete or failed + */ + @SuppressWarnings("unchecked") + public T value() { + if (!succeeded()) + throw new IllegalStateException("Attempt to retrieve value from future which hasn't successfully completed"); + return (T) result.get(); + } + + /** + * Check if the request succeeded; + * @return true if the request completed and was successful + */ + public boolean succeeded() { + return isDone() && !failed(); + } + + /** + * Check if the request failed. + * @return true if the request completed with a failure + */ + public boolean failed() { + return result.get() instanceof RuntimeException; + } + + /** + * Check if the request is retriable (convenience method for checking if + * the exception is an instance of {@link RetriableException}. + * @return true if it is retriable, false otherwise + * @throws IllegalStateException if the future is not complete or completed successfully + */ + public boolean isRetriable() { + return exception() instanceof RetriableException; + } + + /** + * Get the exception from a failed result (only available if the request failed) + * @return the exception set in {@link #raise(RuntimeException)} + * @throws IllegalStateException if the future is not complete or completed successfully + */ + public RuntimeException exception() { + if (!failed()) + throw new IllegalStateException("Attempt to retrieve exception from future which hasn't failed"); + return (RuntimeException) result.get(); + } + + /** + * Complete the request successfully. After this call, {@link #succeeded()} will return true + * and the value can be obtained through {@link #value()}. + * @param value corresponding value (or null if there is none) + * @throws IllegalStateException if the future has already been completed + * @throws IllegalArgumentException if the argument is an instance of {@link RuntimeException} + */ + public void complete(T value) { + try { + if (value instanceof RuntimeException) + throw new IllegalArgumentException("The argument to complete can not be an instance of RuntimeException"); + + if (!result.compareAndSet(INCOMPLETE_SENTINEL, value)) + throw new IllegalStateException("Invalid attempt to complete a request future which is already complete"); + fireSuccess(); + } finally { + completedLatch.countDown(); + } + } + + /** + * Raise an exception. The request will be marked as failed, and the caller can either + * handle the exception or throw it. 
+ * @param e corresponding exception to be passed to caller + * @throws IllegalStateException if the future has already been completed + */ + public void raise(RuntimeException e) { + try { + if (e == null) + throw new IllegalArgumentException("The exception passed to raise must not be null"); + + if (!result.compareAndSet(INCOMPLETE_SENTINEL, e)) + throw new IllegalStateException("Invalid attempt to complete a request future which is already complete"); + + fireFailure(); + } finally { + completedLatch.countDown(); + } + } + + /** + * Raise an error. The request will be marked as failed. + * @param error corresponding error to be passed to caller + */ + public void raise(Errors error) { + raise(error.exception()); + } + + private void fireSuccess() { + T value = value(); + while (true) { + RequestFutureListener listener = listeners.poll(); + if (listener == null) + break; + listener.onSuccess(value); + } + } + + private void fireFailure() { + RuntimeException exception = exception(); + while (true) { + RequestFutureListener listener = listeners.poll(); + if (listener == null) + break; + listener.onFailure(exception); + } + } + + /** + * Add a listener which will be notified when the future completes + * @param listener non-null listener to add + */ + public void addListener(RequestFutureListener listener) { + this.listeners.add(listener); + if (failed()) + fireFailure(); + else if (succeeded()) + fireSuccess(); + } + + /** + * Convert from a request future of one type to another type + * @param adapter The adapter which does the conversion + * @param The type of the future adapted to + * @return The new future + */ + public RequestFuture compose(final RequestFutureAdapter adapter) { + final RequestFuture adapted = new RequestFuture<>(); + addListener(new RequestFutureListener() { + @Override + public void onSuccess(T value) { + adapter.onSuccess(value, adapted); + } + + @Override + public void onFailure(RuntimeException e) { + adapter.onFailure(e, adapted); + } + }); + return adapted; + } + + public void chain(final RequestFuture future) { + addListener(new RequestFutureListener() { + @Override + public void onSuccess(T value) { + future.complete(value); + } + + @Override + public void onFailure(RuntimeException e) { + future.raise(e); + } + }); + } + + public static RequestFuture failure(RuntimeException e) { + RequestFuture future = new RequestFuture<>(); + future.raise(e); + return future; + } + + public static RequestFuture voidSuccess() { + RequestFuture future = new RequestFuture<>(); + future.complete(null); + return future; + } + + public static RequestFuture coordinatorNotAvailable() { + return failure(Errors.COORDINATOR_NOT_AVAILABLE.exception()); + } + + public static RequestFuture noBrokersAvailable() { + return failure(new NoAvailableBrokersException()); + } + + @Override + public boolean shouldBlock() { + return !isDone(); + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/RequestFutureAdapter.java b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/RequestFutureAdapter.java new file mode 100644 index 0000000..7261c0f --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/RequestFutureAdapter.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients.consumer.internals; + +/** + * Adapt from a request future of one type to another. + * + * @param Type to adapt from + * @param Type to adapt to + */ +public abstract class RequestFutureAdapter { + + public abstract void onSuccess(F value, RequestFuture future); + + public void onFailure(RuntimeException e, RequestFuture future) { + future.raise(e); + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/RequestFutureListener.java b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/RequestFutureListener.java new file mode 100644 index 0000000..3a624eb --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/RequestFutureListener.java @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients.consumer.internals; + +/** + * Listener interface to hook into RequestFuture completion. + */ +public interface RequestFutureListener { + + void onSuccess(T value); + + void onFailure(RuntimeException e); +} diff --git a/clients/src/main/java/org/apache/kafka/clients/consumer/internals/SubscriptionState.java b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/SubscriptionState.java new file mode 100644 index 0000000..2dd587f --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/consumer/internals/SubscriptionState.java @@ -0,0 +1,1182 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
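compose() and addListener() above, together with the RequestFutureAdapter and RequestFutureListener types just defined, let one future be transformed into another and observed when it completes. A short illustrative use; these classes are public but live in an internals package, so this is for understanding the flow rather than for application code (ComposeExample is an invented name):

    import org.apache.kafka.clients.consumer.internals.RequestFuture;
    import org.apache.kafka.clients.consumer.internals.RequestFutureAdapter;
    import org.apache.kafka.clients.consumer.internals.RequestFutureListener;

    public class ComposeExample {
        public static void main(String[] args) {
            RequestFuture<Integer> raw = new RequestFuture<>();

            // Adapt the Integer result into a String result.
            RequestFuture<String> adapted = raw.compose(new RequestFutureAdapter<Integer, String>() {
                @Override
                public void onSuccess(Integer value, RequestFuture<String> future) {
                    future.complete("value=" + value);
                }
            });

            adapted.addListener(new RequestFutureListener<String>() {
                @Override
                public void onSuccess(String value) { System.out.println("done: " + value); }

                @Override
                public void onFailure(RuntimeException e) { e.printStackTrace(); }
            });

            raw.complete(42); // fires the adapter, which in turn completes 'adapted'
        }
    }
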
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients.consumer.internals; + +import org.apache.kafka.clients.ApiVersions; +import org.apache.kafka.clients.Metadata; +import org.apache.kafka.clients.NodeApiVersions; +import org.apache.kafka.clients.consumer.ConsumerRebalanceListener; +import org.apache.kafka.clients.consumer.NoOffsetForPartitionException; +import org.apache.kafka.clients.consumer.OffsetAndMetadata; +import org.apache.kafka.clients.consumer.OffsetResetStrategy; +import org.apache.kafka.common.IsolationLevel; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.internals.PartitionStates; +import org.apache.kafka.common.message.OffsetForLeaderEpochResponseData.EpochEndOffset; +import org.apache.kafka.common.utils.LogContext; +import org.slf4j.Logger; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.Set; +import java.util.function.LongSupplier; +import java.util.function.Predicate; +import java.util.regex.Pattern; + +import static org.apache.kafka.clients.consumer.internals.Fetcher.hasUsableOffsetForLeaderEpochVersion; +import static org.apache.kafka.common.requests.OffsetsForLeaderEpochResponse.UNDEFINED_EPOCH; +import static org.apache.kafka.common.requests.OffsetsForLeaderEpochResponse.UNDEFINED_EPOCH_OFFSET; + +/** + * A class for tracking the topics, partitions, and offsets for the consumer. A partition + * is "assigned" either directly with {@link #assignFromUser(Set)} (manual assignment) + * or with {@link #assignFromSubscribed(Collection)} (automatic assignment from subscription). + * + * Once assigned, the partition is not considered "fetchable" until its initial position has + * been set with {@link #seekValidated(TopicPartition, FetchPosition)}. Fetchable partitions track a fetch + * position which is used to set the offset of the next fetch, and a consumed position + * which is the last offset that has been returned to the user. You can suspend fetching + * from a partition through {@link #pause(TopicPartition)} without affecting the fetched/consumed + * offsets. The partition will remain unfetchable until the {@link #resume(TopicPartition)} is + * used. You can also query the pause state independently with {@link #isPaused(TopicPartition)}. + * + * Note that pause state as well as fetch/consumed positions are not preserved when partition + * assignment is changed whether directly by the user or through a group rebalance. + * + * Thread Safety: this class is thread-safe. + */ +public class SubscriptionState { + private static final String SUBSCRIPTION_EXCEPTION_MESSAGE = + "Subscription to topics, partitions and pattern are mutually exclusive"; + + private final Logger log; + + private enum SubscriptionType { + NONE, AUTO_TOPICS, AUTO_PATTERN, USER_ASSIGNED + } + + /* the type of subscription */ + private SubscriptionType subscriptionType; + + /* the pattern user has requested */ + private Pattern subscribedPattern; + + /* the list of topics the user has requested */ + private Set subscription; + + /* The list of topics the group has subscribed to. 
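The class javadoc above describes a small per-partition state machine: a partition becomes assigned, is not fetchable until a position has been set, and can be paused and resumed without losing that position. A sketch that walks those steps using methods defined later in this class; SubscriptionStateExample is an invented name and assumes kafka-clients on the classpath:

    import org.apache.kafka.clients.consumer.OffsetResetStrategy;
    import org.apache.kafka.clients.consumer.internals.SubscriptionState;
    import org.apache.kafka.common.TopicPartition;
    import org.apache.kafka.common.utils.LogContext;

    import java.util.Collections;

    public class SubscriptionStateExample {
        public static void main(String[] args) {
            SubscriptionState state =
                new SubscriptionState(new LogContext(), OffsetResetStrategy.EARLIEST);
            TopicPartition tp = new TopicPartition("my-topic", 0);

            state.assignFromUser(Collections.singleton(tp));
            // Newly assigned partitions start out INITIALIZING and have no valid position.
            System.out.println(state.hasValidPosition(tp)); // false

            state.seek(tp, 100L);                           // set the fetch position
            System.out.println(state.hasValidPosition(tp)); // true

            state.pause(tp);                                // suspend fetching, keep the position
            System.out.println(state.isPaused(tp));         // true
            state.resume(tp);
        }
    }
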
This may include some topics which are not part + * of `subscription` for the leader of a group since it is responsible for detecting metadata changes + * which require a group rebalance. */ + private Set groupSubscription; + + /* the partitions that are currently assigned, note that the order of partition matters (see FetchBuilder for more details) */ + private final PartitionStates assignment; + + /* Default offset reset strategy */ + private final OffsetResetStrategy defaultResetStrategy; + + /* User-provided listener to be invoked when assignment changes */ + private ConsumerRebalanceListener rebalanceListener; + + private int assignmentId = 0; + + @Override + public synchronized String toString() { + return "SubscriptionState{" + + "type=" + subscriptionType + + ", subscribedPattern=" + subscribedPattern + + ", subscription=" + String.join(",", subscription) + + ", groupSubscription=" + String.join(",", groupSubscription) + + ", defaultResetStrategy=" + defaultResetStrategy + + ", assignment=" + assignment.partitionStateValues() + " (id=" + assignmentId + ")}"; + } + + public synchronized String prettyString() { + switch (subscriptionType) { + case NONE: + return "None"; + case AUTO_TOPICS: + return "Subscribe(" + String.join(",", subscription) + ")"; + case AUTO_PATTERN: + return "Subscribe(" + subscribedPattern + ")"; + case USER_ASSIGNED: + return "Assign(" + assignedPartitions() + " , id=" + assignmentId + ")"; + default: + throw new IllegalStateException("Unrecognized subscription type: " + subscriptionType); + } + } + + public SubscriptionState(LogContext logContext, OffsetResetStrategy defaultResetStrategy) { + this.log = logContext.logger(this.getClass()); + this.defaultResetStrategy = defaultResetStrategy; + this.subscription = new HashSet<>(); + this.assignment = new PartitionStates<>(); + this.groupSubscription = new HashSet<>(); + this.subscribedPattern = null; + this.subscriptionType = SubscriptionType.NONE; + } + + /** + * Monotonically increasing id which is incremented after every assignment change. This can + * be used to check when an assignment has changed. + * + * @return The current assignment Id + */ + synchronized int assignmentId() { + return assignmentId; + } + + /** + * This method sets the subscription type if it is not already set (i.e. when it is NONE), + * or verifies that the subscription type is equal to the give type when it is set (i.e. 
+ * when it is not NONE) + * @param type The given subscription type + */ + private void setSubscriptionType(SubscriptionType type) { + if (this.subscriptionType == SubscriptionType.NONE) + this.subscriptionType = type; + else if (this.subscriptionType != type) + throw new IllegalStateException(SUBSCRIPTION_EXCEPTION_MESSAGE); + } + + public synchronized boolean subscribe(Set topics, ConsumerRebalanceListener listener) { + registerRebalanceListener(listener); + setSubscriptionType(SubscriptionType.AUTO_TOPICS); + return changeSubscription(topics); + } + + public synchronized void subscribe(Pattern pattern, ConsumerRebalanceListener listener) { + registerRebalanceListener(listener); + setSubscriptionType(SubscriptionType.AUTO_PATTERN); + this.subscribedPattern = pattern; + } + + public synchronized boolean subscribeFromPattern(Set topics) { + if (subscriptionType != SubscriptionType.AUTO_PATTERN) + throw new IllegalArgumentException("Attempt to subscribe from pattern while subscription type set to " + + subscriptionType); + + return changeSubscription(topics); + } + + private boolean changeSubscription(Set topicsToSubscribe) { + if (subscription.equals(topicsToSubscribe)) + return false; + + subscription = topicsToSubscribe; + return true; + } + + /** + * Set the current group subscription. This is used by the group leader to ensure + * that it receives metadata updates for all topics that the group is interested in. + * + * @param topics All topics from the group subscription + * @return true if the group subscription contains topics which are not part of the local subscription + */ + synchronized boolean groupSubscribe(Collection topics) { + if (!hasAutoAssignedPartitions()) + throw new IllegalStateException(SUBSCRIPTION_EXCEPTION_MESSAGE); + groupSubscription = new HashSet<>(topics); + return !subscription.containsAll(groupSubscription); + } + + /** + * Reset the group's subscription to only contain topics subscribed by this consumer. + */ + synchronized void resetGroupSubscription() { + groupSubscription = Collections.emptySet(); + } + + /** + * Change the assignment to the specified partitions provided by the user, + * note this is different from {@link #assignFromSubscribed(Collection)} + * whose input partitions are provided from the subscribed topics. 
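setSubscriptionType() above is what makes topic subscription, pattern subscription and manual assignment mutually exclusive: the first call pins the type, and any attempt to switch fails. A sketch of how that surfaces to a caller (MixedSubscriptionExample is an invented name):

    import org.apache.kafka.clients.consumer.OffsetResetStrategy;
    import org.apache.kafka.clients.consumer.internals.NoOpConsumerRebalanceListener;
    import org.apache.kafka.clients.consumer.internals.SubscriptionState;
    import org.apache.kafka.common.TopicPartition;
    import org.apache.kafka.common.utils.LogContext;

    import java.util.Collections;

    public class MixedSubscriptionExample {
        public static void main(String[] args) {
            SubscriptionState state =
                new SubscriptionState(new LogContext(), OffsetResetStrategy.NONE);

            state.subscribe(Collections.singleton("my-topic"), new NoOpConsumerRebalanceListener());
            try {
                state.assignFromUser(Collections.singleton(new TopicPartition("other", 0)));
            } catch (IllegalStateException e) {
                // "Subscription to topics, partitions and pattern are mutually exclusive"
                System.out.println(e.getMessage());
            }
        }
    }
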
+ */ + public synchronized boolean assignFromUser(Set partitions) { + setSubscriptionType(SubscriptionType.USER_ASSIGNED); + + if (this.assignment.partitionSet().equals(partitions)) + return false; + + assignmentId++; + + // update the subscribed topics + Set manualSubscribedTopics = new HashSet<>(); + Map partitionToState = new HashMap<>(); + for (TopicPartition partition : partitions) { + TopicPartitionState state = assignment.stateValue(partition); + if (state == null) + state = new TopicPartitionState(); + partitionToState.put(partition, state); + + manualSubscribedTopics.add(partition.topic()); + } + + this.assignment.set(partitionToState); + return changeSubscription(manualSubscribedTopics); + } + + /** + * @return true if assignments matches subscription, otherwise false + */ + public synchronized boolean checkAssignmentMatchedSubscription(Collection assignments) { + for (TopicPartition topicPartition : assignments) { + if (this.subscribedPattern != null) { + if (!this.subscribedPattern.matcher(topicPartition.topic()).matches()) { + log.info("Assigned partition {} for non-subscribed topic regex pattern; subscription pattern is {}", + topicPartition, + this.subscribedPattern); + + return false; + } + } else { + if (!this.subscription.contains(topicPartition.topic())) { + log.info("Assigned partition {} for non-subscribed topic; subscription is {}", topicPartition, this.subscription); + + return false; + } + } + } + + return true; + } + + /** + * Change the assignment to the specified partitions returned from the coordinator, note this is + * different from {@link #assignFromUser(Set)} which directly set the assignment from user inputs. + */ + public synchronized void assignFromSubscribed(Collection assignments) { + if (!this.hasAutoAssignedPartitions()) + throw new IllegalArgumentException("Attempt to dynamically assign partitions while manual assignment in use"); + + Map assignedPartitionStates = new HashMap<>(assignments.size()); + for (TopicPartition tp : assignments) { + TopicPartitionState state = this.assignment.stateValue(tp); + if (state == null) + state = new TopicPartitionState(); + assignedPartitionStates.put(tp, state); + } + + assignmentId++; + this.assignment.set(assignedPartitionStates); + } + + private void registerRebalanceListener(ConsumerRebalanceListener listener) { + if (listener == null) + throw new IllegalArgumentException("RebalanceListener cannot be null"); + this.rebalanceListener = listener; + } + + /** + * Check whether pattern subscription is in use. + * + */ + synchronized boolean hasPatternSubscription() { + return this.subscriptionType == SubscriptionType.AUTO_PATTERN; + } + + public synchronized boolean hasNoSubscriptionOrUserAssignment() { + return this.subscriptionType == SubscriptionType.NONE; + } + + public synchronized void unsubscribe() { + this.subscription = Collections.emptySet(); + this.groupSubscription = Collections.emptySet(); + this.assignment.clear(); + this.subscribedPattern = null; + this.subscriptionType = SubscriptionType.NONE; + this.assignmentId++; + } + + /** + * Check whether a topic matches a subscribed pattern. 
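The pattern checks in checkAssignmentMatchedSubscription() above (and matchesSubscribedPattern() just below) come down to Pattern.matcher(topic).matches(), which anchors the regex against the entire topic name rather than any substring. A tiny standalone example:

    import java.util.regex.Pattern;

    public class TopicPatternCheck {
        public static void main(String[] args) {
            Pattern subscribed = Pattern.compile("orders-.*");

            System.out.println(subscribed.matcher("orders-eu").matches());        // true
            System.out.println(subscribed.matcher("legacy-orders-eu").matches()); // false: whole name must match
        }
    }
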
+ * + * @return true if pattern subscription is in use and the topic matches the subscribed pattern, false otherwise + */ + synchronized boolean matchesSubscribedPattern(String topic) { + Pattern pattern = this.subscribedPattern; + if (hasPatternSubscription() && pattern != null) + return pattern.matcher(topic).matches(); + return false; + } + + public synchronized Set subscription() { + if (hasAutoAssignedPartitions()) + return this.subscription; + return Collections.emptySet(); + } + + public synchronized Set pausedPartitions() { + return collectPartitions(TopicPartitionState::isPaused); + } + + /** + * Get the subscription topics for which metadata is required. For the leader, this will include + * the union of the subscriptions of all group members. For followers, it is just that member's + * subscription. This is used when querying topic metadata to detect the metadata changes which would + * require rebalancing. The leader fetches metadata for all topics in the group so that it + * can do the partition assignment (which requires at least partition counts for all topics + * to be assigned). + * + * @return The union of all subscribed topics in the group if this member is the leader + * of the current generation; otherwise it returns the same set as {@link #subscription()} + */ + synchronized Set metadataTopics() { + if (groupSubscription.isEmpty()) + return subscription; + else if (groupSubscription.containsAll(subscription)) + return groupSubscription; + else { + // When subscription changes `groupSubscription` may be outdated, ensure that + // new subscription topics are returned. + Set topics = new HashSet<>(groupSubscription); + topics.addAll(subscription); + return topics; + } + } + + synchronized boolean needsMetadata(String topic) { + return subscription.contains(topic) || groupSubscription.contains(topic); + } + + private TopicPartitionState assignedState(TopicPartition tp) { + TopicPartitionState state = this.assignment.stateValue(tp); + if (state == null) + throw new IllegalStateException("No current assignment for partition " + tp); + return state; + } + + private TopicPartitionState assignedStateOrNull(TopicPartition tp) { + return this.assignment.stateValue(tp); + } + + public synchronized void seekValidated(TopicPartition tp, FetchPosition position) { + assignedState(tp).seekValidated(position); + } + + public void seek(TopicPartition tp, long offset) { + seekValidated(tp, new FetchPosition(offset)); + } + + public void seekUnvalidated(TopicPartition tp, FetchPosition position) { + assignedState(tp).seekUnvalidated(position); + } + + synchronized void maybeSeekUnvalidated(TopicPartition tp, FetchPosition position, OffsetResetStrategy requestedResetStrategy) { + TopicPartitionState state = assignedStateOrNull(tp); + if (state == null) { + log.debug("Skipping reset of partition {} since it is no longer assigned", tp); + } else if (!state.awaitingReset()) { + log.debug("Skipping reset of partition {} since reset is no longer needed", tp); + } else if (requestedResetStrategy != state.resetStrategy) { + log.debug("Skipping reset of partition {} since an alternative reset has been requested", tp); + } else { + log.info("Resetting offset for partition {} to position {}.", tp, position); + state.seekUnvalidated(position); + } + } + + /** + * @return a modifiable copy of the currently assigned partitions + */ + public synchronized Set assignedPartitions() { + return new HashSet<>(this.assignment.partitionSet()); + } + + /** + * @return a modifiable copy of the currently assigned 
partitions as a list + */ + public synchronized List assignedPartitionsList() { + return new ArrayList<>(this.assignment.partitionSet()); + } + + /** + * Provides the number of assigned partitions in a thread safe manner. + * @return the number of assigned partitions. + */ + synchronized int numAssignedPartitions() { + return this.assignment.size(); + } + + // Visible for testing + public synchronized List fetchablePartitions(Predicate isAvailable) { + // Since this is in the hot-path for fetching, we do this instead of using java.util.stream API + List result = new ArrayList<>(); + assignment.forEach((topicPartition, topicPartitionState) -> { + // Cheap check is first to avoid evaluating the predicate if possible + if (topicPartitionState.isFetchable() && isAvailable.test(topicPartition)) { + result.add(topicPartition); + } + }); + return result; + } + + public synchronized boolean hasAutoAssignedPartitions() { + return this.subscriptionType == SubscriptionType.AUTO_TOPICS || this.subscriptionType == SubscriptionType.AUTO_PATTERN; + } + + public synchronized void position(TopicPartition tp, FetchPosition position) { + assignedState(tp).position(position); + } + + /** + * Enter the offset validation state if the leader for this partition is known to support a usable version of the + * OffsetsForLeaderEpoch API. If the leader node does not support the API, simply complete the offset validation. + * + * @param apiVersions supported API versions + * @param tp topic partition to validate + * @param leaderAndEpoch leader epoch of the topic partition + * @return true if we enter the offset validation state + */ + public synchronized boolean maybeValidatePositionForCurrentLeader(ApiVersions apiVersions, + TopicPartition tp, + Metadata.LeaderAndEpoch leaderAndEpoch) { + if (leaderAndEpoch.leader.isPresent()) { + NodeApiVersions nodeApiVersions = apiVersions.get(leaderAndEpoch.leader.get().idString()); + if (nodeApiVersions == null || hasUsableOffsetForLeaderEpochVersion(nodeApiVersions)) { + return assignedState(tp).maybeValidatePosition(leaderAndEpoch); + } else { + // If the broker does not support a newer version of OffsetsForLeaderEpoch, we skip validation + assignedState(tp).updatePositionLeaderNoValidation(leaderAndEpoch); + return false; + } + } else { + return assignedState(tp).maybeValidatePosition(leaderAndEpoch); + } + } + + /** + * Attempt to complete validation with the end offset returned from the OffsetForLeaderEpoch request. + * @return Log truncation details if detected and no reset policy is defined. 
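The truncation check described by this javadoc compares the epoch end offset returned by the leader against the current fetch position: an undefined end offset, or one that lies before the position, means the log was truncated, and the consumer can only recover automatically if a default reset policy is configured. A simplified standalone sketch of that decision (TruncationCheckSketch and its return shape are invented for illustration; the real undefined-offset constant is defined in OffsetsForLeaderEpochResponse):

    import java.util.Optional;

    public class TruncationCheckSketch {

        static final long UNDEFINED_EPOCH_OFFSET = -1L; // stand-in for the real constant

        /** Returns the divergent offset if truncation was detected and cannot be auto-reset. */
        static Optional<Long> checkTruncation(long currentPositionOffset,
                                              long epochEndOffset,
                                              boolean hasDefaultResetPolicy) {
            boolean truncated = epochEndOffset == UNDEFINED_EPOCH_OFFSET
                    || epochEndOffset < currentPositionOffset;
            if (!truncated)
                return Optional.empty();        // validation completes normally
            if (hasDefaultResetPolicy)
                return Optional.empty();        // position is reset/rewound internally
            return Optional.of(epochEndOffset); // caller must handle the divergence
        }

        public static void main(String[] args) {
            System.out.println(checkTruncation(120L, 100L, false)); // Optional[100]
            System.out.println(checkTruncation(120L, 150L, false)); // Optional.empty
        }
    }
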
+ */ + public synchronized Optional maybeCompleteValidation(TopicPartition tp, + FetchPosition requestPosition, + EpochEndOffset epochEndOffset) { + TopicPartitionState state = assignedStateOrNull(tp); + if (state == null) { + log.debug("Skipping completed validation for partition {} which is not currently assigned.", tp); + } else if (!state.awaitingValidation()) { + log.debug("Skipping completed validation for partition {} which is no longer expecting validation.", tp); + } else { + SubscriptionState.FetchPosition currentPosition = state.position; + if (!currentPosition.equals(requestPosition)) { + log.debug("Skipping completed validation for partition {} since the current position {} " + + "no longer matches the position {} when the request was sent", + tp, currentPosition, requestPosition); + } else if (epochEndOffset.endOffset() == UNDEFINED_EPOCH_OFFSET || + epochEndOffset.leaderEpoch() == UNDEFINED_EPOCH) { + if (hasDefaultOffsetResetPolicy()) { + log.info("Truncation detected for partition {} at offset {}, resetting offset", + tp, currentPosition); + requestOffsetReset(tp); + } else { + log.warn("Truncation detected for partition {} at offset {}, but no reset policy is set", + tp, currentPosition); + return Optional.of(new LogTruncation(tp, requestPosition, Optional.empty())); + } + } else if (epochEndOffset.endOffset() < currentPosition.offset) { + if (hasDefaultOffsetResetPolicy()) { + SubscriptionState.FetchPosition newPosition = new SubscriptionState.FetchPosition( + epochEndOffset.endOffset(), Optional.of(epochEndOffset.leaderEpoch()), + currentPosition.currentLeader); + log.info("Truncation detected for partition {} at offset {}, resetting offset to " + + "the first offset known to diverge {}", tp, currentPosition, newPosition); + state.seekValidated(newPosition); + } else { + OffsetAndMetadata divergentOffset = new OffsetAndMetadata(epochEndOffset.endOffset(), + Optional.of(epochEndOffset.leaderEpoch()), null); + log.warn("Truncation detected for partition {} at offset {} (the end offset from the " + + "broker is {}), but no reset policy is set", tp, currentPosition, divergentOffset); + return Optional.of(new LogTruncation(tp, requestPosition, Optional.of(divergentOffset))); + } + } else { + state.completeValidation(); + } + } + + return Optional.empty(); + } + + public synchronized boolean awaitingValidation(TopicPartition tp) { + return assignedState(tp).awaitingValidation(); + } + + public synchronized void completeValidation(TopicPartition tp) { + assignedState(tp).completeValidation(); + } + + public synchronized FetchPosition validPosition(TopicPartition tp) { + return assignedState(tp).validPosition(); + } + + public synchronized FetchPosition position(TopicPartition tp) { + return assignedState(tp).position; + } + + public synchronized Long partitionLag(TopicPartition tp, IsolationLevel isolationLevel) { + TopicPartitionState topicPartitionState = assignedState(tp); + if (topicPartitionState.position == null) { + return null; + } else if (isolationLevel == IsolationLevel.READ_COMMITTED) { + return topicPartitionState.lastStableOffset == null ? null : topicPartitionState.lastStableOffset - topicPartitionState.position.offset; + } else { + return topicPartitionState.highWatermark == null ? 
null : topicPartitionState.highWatermark - topicPartitionState.position.offset; + } + } + + public synchronized Long partitionEndOffset(TopicPartition tp, IsolationLevel isolationLevel) { + TopicPartitionState topicPartitionState = assignedState(tp); + if (isolationLevel == IsolationLevel.READ_COMMITTED) { + return topicPartitionState.lastStableOffset; + } else { + return topicPartitionState.highWatermark; + } + } + + public synchronized void requestPartitionEndOffset(TopicPartition tp) { + TopicPartitionState topicPartitionState = assignedState(tp); + topicPartitionState.requestEndOffset(); + } + + public synchronized boolean partitionEndOffsetRequested(TopicPartition tp) { + TopicPartitionState topicPartitionState = assignedState(tp); + return topicPartitionState.endOffsetRequested(); + } + + synchronized Long partitionLead(TopicPartition tp) { + TopicPartitionState topicPartitionState = assignedState(tp); + return topicPartitionState.logStartOffset == null ? null : topicPartitionState.position.offset - topicPartitionState.logStartOffset; + } + + synchronized void updateHighWatermark(TopicPartition tp, long highWatermark) { + assignedState(tp).highWatermark(highWatermark); + } + + synchronized void updateLogStartOffset(TopicPartition tp, long logStartOffset) { + assignedState(tp).logStartOffset(logStartOffset); + } + + synchronized void updateLastStableOffset(TopicPartition tp, long lastStableOffset) { + assignedState(tp).lastStableOffset(lastStableOffset); + } + + /** + * Set the preferred read replica with a lease timeout. After this time, the replica will no longer be valid and + * {@link #preferredReadReplica(TopicPartition, long)} will return an empty result. + * + * @param tp The topic partition + * @param preferredReadReplicaId The preferred read replica + * @param timeMs The time at which this preferred replica is no longer valid + */ + public synchronized void updatePreferredReadReplica(TopicPartition tp, int preferredReadReplicaId, LongSupplier timeMs) { + assignedState(tp).updatePreferredReadReplica(preferredReadReplicaId, timeMs); + } + + /** + * Get the preferred read replica + * + * @param tp The topic partition + * @param timeMs The current time + * @return Returns the current preferred read replica, if it has been set and if it has not expired. + */ + public synchronized Optional preferredReadReplica(TopicPartition tp, long timeMs) { + final TopicPartitionState topicPartitionState = assignedStateOrNull(tp); + if (topicPartitionState == null) { + return Optional.empty(); + } else { + return topicPartitionState.preferredReadReplica(timeMs); + } + } + + /** + * Unset the preferred read replica. This causes the fetcher to go back to the leader for fetches. + * + * @param tp The topic partition + * @return true if the preferred read replica was set, false otherwise. 
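preferredReadReplica() above only hands back a replica while its lease is still valid; once the expiry time passes, the cached replica is dropped and the fetcher goes back to the leader. A minimal model of that lease (ReadReplicaLease is an invented name):

    import java.util.Optional;

    final class ReadReplicaLease {
        private Integer preferredReplicaId;
        private long expireTimeMs;

        void update(int replicaId, long expireTimeMs) {
            this.preferredReplicaId = replicaId;
            this.expireTimeMs = expireTimeMs;
        }

        Optional<Integer> preferredReplica(long nowMs) {
            if (preferredReplicaId != null && nowMs > expireTimeMs)
                preferredReplicaId = null;              // lease expired: forget the replica
            return Optional.ofNullable(preferredReplicaId);
        }

        Optional<Integer> clear() {
            Optional<Integer> previous = Optional.ofNullable(preferredReplicaId);
            preferredReplicaId = null;
            return previous;
        }
    }
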
+ */ + public synchronized Optional clearPreferredReadReplica(TopicPartition tp) { + return assignedState(tp).clearPreferredReadReplica(); + } + + public synchronized Map allConsumed() { + Map allConsumed = new HashMap<>(); + assignment.forEach((topicPartition, partitionState) -> { + if (partitionState.hasValidPosition()) + allConsumed.put(topicPartition, new OffsetAndMetadata(partitionState.position.offset, + partitionState.position.offsetEpoch, "")); + }); + return allConsumed; + } + + public synchronized void requestOffsetReset(TopicPartition partition, OffsetResetStrategy offsetResetStrategy) { + assignedState(partition).reset(offsetResetStrategy); + } + + public synchronized void requestOffsetReset(Collection partitions, OffsetResetStrategy offsetResetStrategy) { + partitions.forEach(tp -> { + log.info("Seeking to {} offset of partition {}", offsetResetStrategy, tp); + assignedState(tp).reset(offsetResetStrategy); + }); + } + + public void requestOffsetReset(TopicPartition partition) { + requestOffsetReset(partition, defaultResetStrategy); + } + + synchronized void setNextAllowedRetry(Set partitions, long nextAllowResetTimeMs) { + for (TopicPartition partition : partitions) { + assignedState(partition).setNextAllowedRetry(nextAllowResetTimeMs); + } + } + + boolean hasDefaultOffsetResetPolicy() { + return defaultResetStrategy != OffsetResetStrategy.NONE; + } + + public synchronized boolean isOffsetResetNeeded(TopicPartition partition) { + return assignedState(partition).awaitingReset(); + } + + public synchronized OffsetResetStrategy resetStrategy(TopicPartition partition) { + return assignedState(partition).resetStrategy(); + } + + public synchronized boolean hasAllFetchPositions() { + // Since this is in the hot-path for fetching, we do this instead of using java.util.stream API + Iterator it = assignment.stateIterator(); + while (it.hasNext()) { + if (!it.next().hasValidPosition()) { + return false; + } + } + return true; + } + + public synchronized Set initializingPartitions() { + return collectPartitions(state -> state.fetchState.equals(FetchStates.INITIALIZING)); + } + + private Set collectPartitions(Predicate filter) { + Set result = new HashSet<>(); + assignment.forEach((topicPartition, topicPartitionState) -> { + if (filter.test(topicPartitionState)) { + result.add(topicPartition); + } + }); + return result; + } + + + public synchronized void resetInitializingPositions() { + final Set partitionsWithNoOffsets = new HashSet<>(); + assignment.forEach((tp, partitionState) -> { + if (partitionState.fetchState.equals(FetchStates.INITIALIZING)) { + if (defaultResetStrategy == OffsetResetStrategy.NONE) + partitionsWithNoOffsets.add(tp); + else + requestOffsetReset(tp); + } + }); + + if (!partitionsWithNoOffsets.isEmpty()) + throw new NoOffsetForPartitionException(partitionsWithNoOffsets); + } + + public synchronized Set partitionsNeedingReset(long nowMs) { + return collectPartitions(state -> state.awaitingReset() && !state.awaitingRetryBackoff(nowMs)); + } + + public synchronized Set partitionsNeedingValidation(long nowMs) { + return collectPartitions(state -> state.awaitingValidation() && !state.awaitingRetryBackoff(nowMs)); + } + + public synchronized boolean isAssigned(TopicPartition tp) { + return assignment.contains(tp); + } + + public synchronized boolean isPaused(TopicPartition tp) { + TopicPartitionState assignedOrNull = assignedStateOrNull(tp); + return assignedOrNull != null && assignedOrNull.isPaused(); + } + + synchronized boolean isFetchable(TopicPartition tp) { + 
TopicPartitionState assignedOrNull = assignedStateOrNull(tp); + return assignedOrNull != null && assignedOrNull.isFetchable(); + } + + public synchronized boolean hasValidPosition(TopicPartition tp) { + TopicPartitionState assignedOrNull = assignedStateOrNull(tp); + return assignedOrNull != null && assignedOrNull.hasValidPosition(); + } + + public synchronized void pause(TopicPartition tp) { + assignedState(tp).pause(); + } + + public synchronized void resume(TopicPartition tp) { + assignedState(tp).resume(); + } + + synchronized void requestFailed(Set partitions, long nextRetryTimeMs) { + for (TopicPartition partition : partitions) { + // by the time the request failed, the assignment may no longer + // contain this partition any more, in which case we would just ignore. + final TopicPartitionState state = assignedStateOrNull(partition); + if (state != null) + state.requestFailed(nextRetryTimeMs); + } + } + + synchronized void movePartitionToEnd(TopicPartition tp) { + assignment.moveToEnd(tp); + } + + public synchronized ConsumerRebalanceListener rebalanceListener() { + return rebalanceListener; + } + + private static class TopicPartitionState { + + private FetchState fetchState; + private FetchPosition position; // last consumed position + + private Long highWatermark; // the high watermark from last fetch + private Long logStartOffset; // the log start offset + private Long lastStableOffset; + private boolean paused; // whether this partition has been paused by the user + private OffsetResetStrategy resetStrategy; // the strategy to use if the offset needs resetting + private Long nextRetryTimeMs; + private Integer preferredReadReplica; + private Long preferredReadReplicaExpireTimeMs; + private boolean endOffsetRequested; + + TopicPartitionState() { + this.paused = false; + this.endOffsetRequested = false; + this.fetchState = FetchStates.INITIALIZING; + this.position = null; + this.highWatermark = null; + this.logStartOffset = null; + this.lastStableOffset = null; + this.resetStrategy = null; + this.nextRetryTimeMs = null; + this.preferredReadReplica = null; + } + + public boolean endOffsetRequested() { + return endOffsetRequested; + } + + public void requestEndOffset() { + endOffsetRequested = true; + } + + private void transitionState(FetchState newState, Runnable runIfTransitioned) { + FetchState nextState = this.fetchState.transitionTo(newState); + if (nextState.equals(newState)) { + this.fetchState = nextState; + runIfTransitioned.run(); + if (this.position == null && nextState.requiresPosition()) { + throw new IllegalStateException("Transitioned subscription state to " + nextState + ", but position is null"); + } else if (!nextState.requiresPosition()) { + this.position = null; + } + } + } + + private Optional preferredReadReplica(long timeMs) { + if (preferredReadReplicaExpireTimeMs != null && timeMs > preferredReadReplicaExpireTimeMs) { + preferredReadReplica = null; + return Optional.empty(); + } else { + return Optional.ofNullable(preferredReadReplica); + } + } + + private void updatePreferredReadReplica(int preferredReadReplica, LongSupplier timeMs) { + if (this.preferredReadReplica == null || preferredReadReplica != this.preferredReadReplica) { + this.preferredReadReplica = preferredReadReplica; + this.preferredReadReplicaExpireTimeMs = timeMs.getAsLong(); + } + } + + private Optional clearPreferredReadReplica() { + if (preferredReadReplica != null) { + int removedReplicaId = this.preferredReadReplica; + this.preferredReadReplica = null; + 
this.preferredReadReplicaExpireTimeMs = null; + return Optional.of(removedReplicaId); + } else { + return Optional.empty(); + } + } + + private void reset(OffsetResetStrategy strategy) { + transitionState(FetchStates.AWAIT_RESET, () -> { + this.resetStrategy = strategy; + this.nextRetryTimeMs = null; + }); + } + + /** + * Check if the position exists and needs to be validated. If so, enter the AWAIT_VALIDATION state. This method + * also will update the position with the current leader and epoch. + * + * @param currentLeaderAndEpoch leader and epoch to compare the offset with + * @return true if the position is now awaiting validation + */ + private boolean maybeValidatePosition(Metadata.LeaderAndEpoch currentLeaderAndEpoch) { + if (this.fetchState.equals(FetchStates.AWAIT_RESET)) { + return false; + } + + if (!currentLeaderAndEpoch.leader.isPresent()) { + return false; + } + + if (position != null && !position.currentLeader.equals(currentLeaderAndEpoch)) { + FetchPosition newPosition = new FetchPosition(position.offset, position.offsetEpoch, currentLeaderAndEpoch); + validatePosition(newPosition); + preferredReadReplica = null; + } + return this.fetchState.equals(FetchStates.AWAIT_VALIDATION); + } + + /** + * For older versions of the API, we cannot perform offset validation so we simply transition directly to FETCHING + */ + private void updatePositionLeaderNoValidation(Metadata.LeaderAndEpoch currentLeaderAndEpoch) { + if (position != null) { + transitionState(FetchStates.FETCHING, () -> { + this.position = new FetchPosition(position.offset, position.offsetEpoch, currentLeaderAndEpoch); + this.nextRetryTimeMs = null; + }); + } + } + + private void validatePosition(FetchPosition position) { + if (position.offsetEpoch.isPresent() && position.currentLeader.epoch.isPresent()) { + transitionState(FetchStates.AWAIT_VALIDATION, () -> { + this.position = position; + this.nextRetryTimeMs = null; + }); + } else { + // If we have no epoch information for the current position, then we can skip validation + transitionState(FetchStates.FETCHING, () -> { + this.position = position; + this.nextRetryTimeMs = null; + }); + } + } + + /** + * Clear the awaiting validation state and enter fetching. 
+ */ + private void completeValidation() { + if (hasPosition()) { + transitionState(FetchStates.FETCHING, () -> this.nextRetryTimeMs = null); + } + } + + private boolean awaitingValidation() { + return fetchState.equals(FetchStates.AWAIT_VALIDATION); + } + + private boolean awaitingRetryBackoff(long nowMs) { + return nextRetryTimeMs != null && nowMs < nextRetryTimeMs; + } + + private boolean awaitingReset() { + return fetchState.equals(FetchStates.AWAIT_RESET); + } + + private void setNextAllowedRetry(long nextAllowedRetryTimeMs) { + this.nextRetryTimeMs = nextAllowedRetryTimeMs; + } + + private void requestFailed(long nextAllowedRetryTimeMs) { + this.nextRetryTimeMs = nextAllowedRetryTimeMs; + } + + private boolean hasValidPosition() { + return fetchState.hasValidPosition(); + } + + private boolean hasPosition() { + return position != null; + } + + private boolean isPaused() { + return paused; + } + + private void seekValidated(FetchPosition position) { + transitionState(FetchStates.FETCHING, () -> { + this.position = position; + this.resetStrategy = null; + this.nextRetryTimeMs = null; + }); + } + + private void seekUnvalidated(FetchPosition fetchPosition) { + seekValidated(fetchPosition); + validatePosition(fetchPosition); + } + + private void position(FetchPosition position) { + if (!hasValidPosition()) + throw new IllegalStateException("Cannot set a new position without a valid current position"); + this.position = position; + } + + private FetchPosition validPosition() { + if (hasValidPosition()) { + return position; + } else { + return null; + } + } + + private void pause() { + this.paused = true; + } + + private void resume() { + this.paused = false; + } + + private boolean isFetchable() { + return !paused && hasValidPosition(); + } + + private void highWatermark(Long highWatermark) { + this.highWatermark = highWatermark; + this.endOffsetRequested = false; + } + + private void logStartOffset(Long logStartOffset) { + this.logStartOffset = logStartOffset; + } + + private void lastStableOffset(Long lastStableOffset) { + this.lastStableOffset = lastStableOffset; + this.endOffsetRequested = false; + } + + private OffsetResetStrategy resetStrategy() { + return resetStrategy; + } + } + + /** + * The fetch state of a partition. This class is used to determine valid state transitions and expose the some of + * the behavior of the current fetch state. Actual state variables are stored in the {@link TopicPartitionState}. + */ + interface FetchState { + default FetchState transitionTo(FetchState newState) { + if (validTransitions().contains(newState)) { + return newState; + } else { + return this; + } + } + + /** + * Return the valid states which this state can transition to + */ + Collection validTransitions(); + + /** + * Test if this state requires a position to be set + */ + boolean requiresPosition(); + + /** + * Test if this state is considered to have a valid position which can be used for fetching + */ + boolean hasValidPosition(); + } + + /** + * An enumeration of all the possible fetch states. The state transitions are encoded in the values returned by + * {@link FetchState#validTransitions}. 
+ */ + enum FetchStates implements FetchState { + INITIALIZING() { + @Override + public Collection validTransitions() { + return Arrays.asList(FetchStates.FETCHING, FetchStates.AWAIT_RESET, FetchStates.AWAIT_VALIDATION); + } + + @Override + public boolean requiresPosition() { + return false; + } + + @Override + public boolean hasValidPosition() { + return false; + } + }, + + FETCHING() { + @Override + public Collection validTransitions() { + return Arrays.asList(FetchStates.FETCHING, FetchStates.AWAIT_RESET, FetchStates.AWAIT_VALIDATION); + } + + @Override + public boolean requiresPosition() { + return true; + } + + @Override + public boolean hasValidPosition() { + return true; + } + }, + + AWAIT_RESET() { + @Override + public Collection validTransitions() { + return Arrays.asList(FetchStates.FETCHING, FetchStates.AWAIT_RESET); + } + + @Override + public boolean requiresPosition() { + return false; + } + + @Override + public boolean hasValidPosition() { + return false; + } + }, + + AWAIT_VALIDATION() { + @Override + public Collection validTransitions() { + return Arrays.asList(FetchStates.FETCHING, FetchStates.AWAIT_RESET, FetchStates.AWAIT_VALIDATION); + } + + @Override + public boolean requiresPosition() { + return true; + } + + @Override + public boolean hasValidPosition() { + return false; + } + } + } + + /** + * Represents the position of a partition subscription. + * + * This includes the offset and epoch from the last record in + * the batch from a FetchResponse. It also includes the leader epoch at the time the batch was consumed. + */ + public static class FetchPosition { + public final long offset; + final Optional offsetEpoch; + final Metadata.LeaderAndEpoch currentLeader; + + FetchPosition(long offset) { + this(offset, Optional.empty(), Metadata.LeaderAndEpoch.noLeaderOrEpoch()); + } + + public FetchPosition(long offset, Optional offsetEpoch, Metadata.LeaderAndEpoch currentLeader) { + this.offset = offset; + this.offsetEpoch = Objects.requireNonNull(offsetEpoch); + this.currentLeader = Objects.requireNonNull(currentLeader); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + FetchPosition that = (FetchPosition) o; + return offset == that.offset && + offsetEpoch.equals(that.offsetEpoch) && + currentLeader.equals(that.currentLeader); + } + + @Override + public int hashCode() { + return Objects.hash(offset, offsetEpoch, currentLeader); + } + + @Override + public String toString() { + return "FetchPosition{" + + "offset=" + offset + + ", offsetEpoch=" + offsetEpoch + + ", currentLeader=" + currentLeader + + '}'; + } + } + + public static class LogTruncation { + public final TopicPartition topicPartition; + public final FetchPosition fetchPosition; + public final Optional divergentOffsetOpt; + + public LogTruncation(TopicPartition topicPartition, + FetchPosition fetchPosition, + Optional divergentOffsetOpt) { + this.topicPartition = topicPartition; + this.fetchPosition = fetchPosition; + this.divergentOffsetOpt = divergentOffsetOpt; + } + + @Override + public String toString() { + StringBuilder bldr = new StringBuilder() + .append("(partition=") + .append(topicPartition) + .append(", fetchOffset=") + .append(fetchPosition.offset) + .append(", fetchEpoch=") + .append(fetchPosition.offsetEpoch); + + if (divergentOffsetOpt.isPresent()) { + OffsetAndMetadata divergentOffset = divergentOffsetOpt.get(); + bldr.append(", divergentOffset=") + .append(divergentOffset.offset()) + .append(", 
divergentEpoch=") + .append(divergentOffset.leaderEpoch()); + } else { + bldr.append(", divergentOffset=unknown") + .append(", divergentEpoch=unknown"); + } + + return bldr.append(")").toString(); + + } + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/producer/BufferExhaustedException.java b/clients/src/main/java/org/apache/kafka/clients/producer/BufferExhaustedException.java new file mode 100644 index 0000000..292bb4e --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/producer/BufferExhaustedException.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients.producer; + +import org.apache.kafka.common.errors.TimeoutException; + +/** + * This exception is thrown if the producer cannot allocate memory for a record within max.block.ms due to the buffer + * being too full. + * + * In earlier versions a TimeoutException was thrown instead of this. To keep existing catch-clauses working + * this class extends TimeoutException. + * + */ +public class BufferExhaustedException extends TimeoutException { + + private static final long serialVersionUID = 1L; + + public BufferExhaustedException(String message) { + super(message); + } + +} diff --git a/clients/src/main/java/org/apache/kafka/clients/producer/Callback.java b/clients/src/main/java/org/apache/kafka/clients/producer/Callback.java new file mode 100644 index 0000000..ee0610e --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/producer/Callback.java @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients.producer; + +/** + * A callback interface that the user can implement to allow code to execute when the request is complete. This callback + * will generally execute in the background I/O thread so it should be fast. + */ +public interface Callback { + + /** + * A callback method the user can implement to provide asynchronous handling of request completion. 
This method will + * be called when the record sent to the server has been acknowledged. When exception is not null in the callback, + * metadata will contain the special -1 value for all fields except for topicPartition, which will be valid. + * + * @param metadata The metadata for the record that was sent (i.e. the partition and offset). An empty metadata + * with -1 value for all fields except for topicPartition will be returned if an error occurred. + * @param exception The exception thrown during processing of this record. Null if no error occurred. + * Possible thrown exceptions include: + * + * Non-Retriable exceptions (fatal, the message will never be sent): + * + * InvalidTopicException + * OffsetMetadataTooLargeException + * RecordBatchTooLargeException + * RecordTooLargeException + * UnknownServerException + * UnknownProducerIdException + * InvalidProducerEpochException + * + * Retriable exceptions (transient, may be covered by increasing #.retries): + * + * CorruptRecordException + * InvalidMetadataException + * NotEnoughReplicasAfterAppendException + * NotEnoughReplicasException + * OffsetOutOfRangeException + * TimeoutException + * UnknownTopicOrPartitionException + */ + void onCompletion(RecordMetadata metadata, Exception exception); +} diff --git a/clients/src/main/java/org/apache/kafka/clients/producer/KafkaProducer.java b/clients/src/main/java/org/apache/kafka/clients/producer/KafkaProducer.java new file mode 100644 index 0000000..dbb908d --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/producer/KafkaProducer.java @@ -0,0 +1,1393 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.clients.producer; + +import org.apache.kafka.clients.ApiVersions; +import org.apache.kafka.clients.ClientUtils; +import org.apache.kafka.clients.CommonClientConfigs; +import org.apache.kafka.clients.KafkaClient; +import org.apache.kafka.clients.NetworkClient; +import org.apache.kafka.clients.consumer.ConsumerGroupMetadata; +import org.apache.kafka.clients.consumer.KafkaConsumer; +import org.apache.kafka.clients.consumer.OffsetAndMetadata; +import org.apache.kafka.clients.consumer.OffsetCommitCallback; +import org.apache.kafka.clients.producer.internals.BufferPool; +import org.apache.kafka.clients.producer.internals.KafkaProducerMetrics; +import org.apache.kafka.clients.producer.internals.ProducerInterceptors; +import org.apache.kafka.clients.producer.internals.ProducerMetadata; +import org.apache.kafka.clients.producer.internals.ProducerMetrics; +import org.apache.kafka.clients.producer.internals.RecordAccumulator; +import org.apache.kafka.clients.producer.internals.Sender; +import org.apache.kafka.clients.producer.internals.TransactionManager; +import org.apache.kafka.clients.producer.internals.TransactionalRequestResult; +import org.apache.kafka.common.Cluster; +import org.apache.kafka.common.KafkaException; +import org.apache.kafka.common.Metric; +import org.apache.kafka.common.MetricName; +import org.apache.kafka.common.PartitionInfo; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.config.ConfigException; +import org.apache.kafka.common.errors.ApiException; +import org.apache.kafka.common.errors.AuthenticationException; +import org.apache.kafka.common.errors.AuthorizationException; +import org.apache.kafka.common.errors.InterruptException; +import org.apache.kafka.common.errors.InvalidTopicException; +import org.apache.kafka.common.errors.ProducerFencedException; +import org.apache.kafka.common.errors.RecordTooLargeException; +import org.apache.kafka.common.errors.SerializationException; +import org.apache.kafka.common.errors.TimeoutException; +import org.apache.kafka.common.header.Header; +import org.apache.kafka.common.header.Headers; +import org.apache.kafka.common.header.internals.RecordHeaders; +import org.apache.kafka.common.internals.ClusterResourceListeners; +import org.apache.kafka.common.metrics.JmxReporter; +import org.apache.kafka.common.metrics.KafkaMetricsContext; +import org.apache.kafka.common.metrics.MetricConfig; +import org.apache.kafka.common.metrics.Metrics; +import org.apache.kafka.common.metrics.MetricsContext; +import org.apache.kafka.common.metrics.MetricsReporter; +import org.apache.kafka.common.metrics.Sensor; +import org.apache.kafka.common.network.ChannelBuilder; +import org.apache.kafka.common.network.Selector; +import org.apache.kafka.common.record.AbstractRecords; +import org.apache.kafka.common.record.CompressionType; +import org.apache.kafka.common.record.RecordBatch; +import org.apache.kafka.common.requests.JoinGroupRequest; +import org.apache.kafka.common.serialization.Serializer; +import org.apache.kafka.common.utils.AppInfoParser; +import org.apache.kafka.common.utils.KafkaThread; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.common.utils.Utils; +import org.slf4j.Logger; + +import java.net.InetSocketAddress; +import java.time.Duration; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Properties; +import 
java.util.concurrent.ExecutionException; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; + + +/** + * A Kafka client that publishes records to the Kafka cluster. + *

+ * The producer is thread safe and sharing a single producer instance across threads will generally be faster than + * having multiple instances. + *

+ * Here is a simple example of using the producer to send records with strings containing sequential numbers as the key/value + * pairs. + *

+ * {@code
+ * Properties props = new Properties();
+ * props.put("bootstrap.servers", "localhost:9092");
+ * props.put("acks", "all");
+ * props.put("retries", 0);
+ * props.put("linger.ms", 1);
+ * props.put("key.serializer", "org.apache.kafka.common.serialization.StringSerializer");
+ * props.put("value.serializer", "org.apache.kafka.common.serialization.StringSerializer");
+ *
+ * Producer<String, String> producer = new KafkaProducer<>(props);
+ * for (int i = 0; i < 100; i++)
+ *     producer.send(new ProducerRecord<String, String>("my-topic", Integer.toString(i), Integer.toString(i)));
+ *
+ * producer.close();
+ * }
+ *

+ * The producer consists of a pool of buffer space that holds records that haven't yet been transmitted to the server + * as well as a background I/O thread that is responsible for turning these records into requests and transmitting them + * to the cluster. Failure to close the producer after use will leak these resources. + *
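+ * As a minimal sketch of the lifecycle described above (assuming {@code props} holds a valid configuration),
+ * try-with-resources guarantees the producer is closed so its buffers and background I/O thread are released:
+ * {@code
+ * try (Producer<String, String> producer = new KafkaProducer<>(props)) {
+ *     producer.send(new ProducerRecord<>("my-topic", "key", "value"));
+ * }
+ * }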

+ * The {@link #send(ProducerRecord) send()} method is asynchronous. When called it adds the record to a buffer of pending record sends + * and immediately returns. This allows the producer to batch together individual records for efficiency. + *
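+ * For example, a caller may either block on the returned future or supply a callback that runs once the send
+ * completes; {@code record} is assumed to be an already-built {@code ProducerRecord} (exception handling elided):
+ * {@code
+ * Future<RecordMetadata> future = producer.send(record);   // returns immediately
+ * RecordMetadata metadata = future.get();                  // blocks until the send completes
+ *
+ * producer.send(record, (meta, exception) -> {             // or complete asynchronously
+ *     if (exception != null) {
+ *         // handle the failure; otherwise meta carries the partition and offset
+ *     }
+ * });
+ * }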

+ * The acks config controls the criteria under which requests are considered complete. The "all" setting + * we have specified will result in blocking on the full commit of the record, the slowest but most durable setting. + *

+ * If the request fails, the producer can automatically retry, though since we have specified retries + * as 0 it won't. Enabling retries also opens up the possibility of duplicates (see the documentation on + * message delivery semantics for details). + *
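+ * For example, the retry behaviour and the overall delivery timeout can be bounded explicitly; the values below
+ * are illustrative only, not recommendations:
+ * {@code
+ * props.put(ProducerConfig.RETRIES_CONFIG, 3);
+ * props.put(ProducerConfig.DELIVERY_TIMEOUT_MS_CONFIG, 120000);
+ * }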

+ * The producer maintains buffers of unsent records for each partition. These buffers are of a size specified by + * the batch.size config. Making this larger can result in more batching, but requires more memory (since we will + * generally have one of these buffers for each active partition). + *

+ * By default a buffer is available to send immediately even if there is additional unused space in the buffer. However if you + * want to reduce the number of requests you can set linger.ms to something greater than 0. This will + * instruct the producer to wait up to that number of milliseconds before sending a request in hope that more records will + * arrive to fill up the same batch. This is analogous to Nagle's algorithm in TCP. For example, in the code snippet above, + * likely all 100 records would be sent in a single request since we set our linger time to 1 millisecond. However this setting + * would add 1 millisecond of latency to our request waiting for more records to arrive if we didn't fill up the buffer. Note that + * records that arrive close together in time will generally batch together even with linger.ms=0 so under heavy load + * batching will occur regardless of the linger configuration; however setting this to something larger than 0 can lead to fewer, more + * efficient requests when not under maximal load at the cost of a small amount of latency. + *
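+ * For example, a throughput-oriented configuration might raise both of these settings; the numbers below are
+ * illustrative only:
+ * {@code
+ * props.put(ProducerConfig.BATCH_SIZE_CONFIG, 32768);   // 32 KB batches
+ * props.put(ProducerConfig.LINGER_MS_CONFIG, 5);        // wait up to 5 ms for a batch to fill
+ * }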

+ * The buffer.memory controls the total amount of memory available to the producer for buffering. If records + * are sent faster than they can be transmitted to the server then this buffer space will be exhausted. When the buffer space is + * exhausted additional send calls will block. The threshold for time to block is determined by max.block.ms after which it throws + * a TimeoutException. + *
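+ * For example, the total buffer space and the maximum blocking time can be set together; the values below are
+ * illustrative only:
+ * {@code
+ * props.put(ProducerConfig.BUFFER_MEMORY_CONFIG, 67108864L);   // 64 MB of buffer space
+ * props.put(ProducerConfig.MAX_BLOCK_MS_CONFIG, 30000);        // block at most 30 s, then TimeoutException
+ * }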

+ * The key.serializer and value.serializer instruct how to turn the key and value objects the user provides with + * their ProducerRecord into bytes. You can use the included {@link org.apache.kafka.common.serialization.ByteArraySerializer} or + * {@link org.apache.kafka.common.serialization.StringSerializer} for simple string or byte types. + *
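+ * For example, assuming string keys and values, the serializers can be given either in the configuration or
+ * directly to the constructor:
+ * {@code
+ * props.put(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG, StringSerializer.class.getName());
+ * props.put(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG, StringSerializer.class.getName());
+ * // or, equivalently:
+ * Producer<String, String> producer = new KafkaProducer<>(props, new StringSerializer(), new StringSerializer());
+ * }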

+ * From Kafka 0.11, the KafkaProducer supports two additional modes: the idempotent producer and the transactional producer. + * The idempotent producer strengthens Kafka's delivery semantics from at least once to exactly once delivery. In particular + * producer retries will no longer introduce duplicates. The transactional producer allows an application to send messages + * to multiple partitions (and topics!) atomically. + *

+ *

+ * To enable idempotence, the enable.idempotence configuration must be set to true. If set, the + * retries config will default to Integer.MAX_VALUE and the acks config will + * default to all. There are no API changes for the idempotent producer, so existing applications will + * not need to be modified to take advantage of this feature. + *
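+ * For example, enabling it requires only the single flag; retries and acks then take the defaults
+ * described above:
+ * {@code
+ * props.put(ProducerConfig.ENABLE_IDEMPOTENCE_CONFIG, true);
+ * }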

+ *

+ * To take advantage of the idempotent producer, it is imperative to avoid application level re-sends since these cannot + * be de-duplicated. As such, if an application enables idempotence, it is recommended to leave the retries + * config unset, as it will be defaulted to Integer.MAX_VALUE. Additionally, if a {@link #send(ProducerRecord)} + * returns an error even with infinite retries (for instance if the message expires in the buffer before being sent), + * then it is recommended to shut down the producer and check the contents of the last produced message to ensure that + * it is not duplicated. Finally, the producer can only guarantee idempotence for messages sent within a single session. + *

+ *

+ * To use the transactional producer and the attendant APIs, you must set the transactional.id + * configuration property. If the transactional.id is set, idempotence is automatically enabled along with + * the producer configs which idempotence depends on. Further, topics which are included in transactions should be configured + * for durability. In particular, the replication.factor should be at least 3, and the + * min.insync.replicas for these topics should be set to 2. Finally, in order for transactional guarantees + * to be realized from end-to-end, the consumers must be configured to read only committed messages as well. + *
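+ * For example, a sketch of the producer-side setting together with the matching consumer-side isolation level
+ * ({@code consumerProps} is assumed to be the consumer's configuration):
+ * {@code
+ * props.put(ProducerConfig.TRANSACTIONAL_ID_CONFIG, "my-transactional-id");
+ * consumerProps.put(ConsumerConfig.ISOLATION_LEVEL_CONFIG, "read_committed");
+ * }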

+ *

+ * The purpose of the transactional.id is to enable transaction recovery across multiple sessions of a + * single producer instance. It would typically be derived from the shard identifier in a partitioned, stateful, application. + * As such, it should be unique to each producer instance running within a partitioned application. + *

+ *

+ * All the new transactional APIs are blocking and will throw exceptions on failure. The example + * below illustrates how the new APIs are meant to be used. It is similar to the example above, except that all + * 100 messages are part of a single transaction. + *

+ *

+ *

+ * {@code
+ * Properties props = new Properties();
+ * props.put("bootstrap.servers", "localhost:9092");
+ * props.put("transactional.id", "my-transactional-id");
+ * Producer<String, String> producer = new KafkaProducer<>(props, new StringSerializer(), new StringSerializer());
+ *
+ * producer.initTransactions();
+ *
+ * try {
+ *     producer.beginTransaction();
+ *     for (int i = 0; i < 100; i++)
+ *         producer.send(new ProducerRecord<>("my-topic", Integer.toString(i), Integer.toString(i)));
+ *     producer.commitTransaction();
+ * } catch (ProducerFencedException | OutOfOrderSequenceException | AuthorizationException e) {
+ *     // We can't recover from these exceptions, so our only option is to close the producer and exit.
+ *     producer.close();
+ * } catch (KafkaException e) {
+ *     // For all other exceptions, just abort the transaction and try again.
+ *     producer.abortTransaction();
+ * }
+ * producer.close();
+ * } 
+ *

+ *

+ * As is hinted at in the example, there can be only one open transaction per producer. All messages sent between the + * {@link #beginTransaction()} and {@link #commitTransaction()} calls will be part of a single transaction. When the + * transactional.id is specified, all messages sent by the producer must be part of a transaction. + *

+ *

+ * The transactional producer uses exceptions to communicate error states. In particular, it is not required + * to specify callbacks for producer.send() or to call .get() on the returned Future: a + * KafkaException would be thrown if any of the + * producer.send() or transactional calls hit an irrecoverable error during a transaction. See the {@link #send(ProducerRecord)} + * documentation for more details about detecting errors from a transactional send. + *

+ *

+ * By calling + * producer.abortTransaction() upon receiving a KafkaException we can ensure that any + * successful writes are marked as aborted, hence keeping the transactional guarantees. + *

+ *

+ * This client can communicate with brokers that are version 0.10.0 or newer. Older or newer brokers may not support + * certain client features. For instance, the transactional APIs need broker versions 0.11.0 or later. You will receive an + * UnsupportedVersionException when invoking an API that is not available in the running broker version. + *

+ */ +public class KafkaProducer implements Producer { + + private final Logger log; + private static final String JMX_PREFIX = "kafka.producer"; + public static final String NETWORK_THREAD_PREFIX = "kafka-producer-network-thread"; + public static final String PRODUCER_METRIC_GROUP_NAME = "producer-metrics"; + + private final String clientId; + // Visible for testing + final Metrics metrics; + private final KafkaProducerMetrics producerMetrics; + private final Partitioner partitioner; + private final int maxRequestSize; + private final long totalMemorySize; + private final ProducerMetadata metadata; + private final RecordAccumulator accumulator; + private final Sender sender; + private final Thread ioThread; + private final CompressionType compressionType; + private final Sensor errors; + private final Time time; + private final Serializer keySerializer; + private final Serializer valueSerializer; + private final ProducerConfig producerConfig; + private final long maxBlockTimeMs; + private final ProducerInterceptors interceptors; + private final ApiVersions apiVersions; + private final TransactionManager transactionManager; + + /** + * A producer is instantiated by providing a set of key-value pairs as configuration. Valid configuration strings + * are documented here. Values can be + * either strings or Objects of the appropriate type (for example a numeric configuration would accept either the + * string "42" or the integer 42). + *

+ * Note: after creating a {@code KafkaProducer} you must always {@link #close()} it to avoid resource leaks. + * @param configs The producer configs + * + */ + public KafkaProducer(final Map configs) { + this(configs, null, null); + } + + /** + * A producer is instantiated by providing a set of key-value pairs as configuration, a key and a value {@link Serializer}. + * Valid configuration strings are documented here. + * Values can be either strings or Objects of the appropriate type (for example a numeric configuration would accept + * either the string "42" or the integer 42). + *

+ * Note: after creating a {@code KafkaProducer} you must always {@link #close()} it to avoid resource leaks. + * @param configs The producer configs + * @param keySerializer The serializer for key that implements {@link Serializer}. The configure() method won't be + * called in the producer when the serializer is passed in directly. + * @param valueSerializer The serializer for value that implements {@link Serializer}. The configure() method won't + * be called in the producer when the serializer is passed in directly. + */ + public KafkaProducer(Map configs, Serializer keySerializer, Serializer valueSerializer) { + this(new ProducerConfig(ProducerConfig.appendSerializerToConfig(configs, keySerializer, valueSerializer)), + keySerializer, valueSerializer, null, null, null, Time.SYSTEM); + } + + /** + * A producer is instantiated by providing a set of key-value pairs as configuration. Valid configuration strings + * are documented here. + *

+ * Note: after creating a {@code KafkaProducer} you must always {@link #close()} it to avoid resource leaks. + * @param properties The producer configs + */ + public KafkaProducer(Properties properties) { + this(properties, null, null); + } + + /** + * A producer is instantiated by providing a set of key-value pairs as configuration, a key and a value {@link Serializer}. + * Valid configuration strings are documented here. + *

+ * Note: after creating a {@code KafkaProducer} you must always {@link #close()} it to avoid resource leaks. + * @param properties The producer configs + * @param keySerializer The serializer for key that implements {@link Serializer}. The configure() method won't be + * called in the producer when the serializer is passed in directly. + * @param valueSerializer The serializer for value that implements {@link Serializer}. The configure() method won't + * be called in the producer when the serializer is passed in directly. + */ + public KafkaProducer(Properties properties, Serializer keySerializer, Serializer valueSerializer) { + this(Utils.propsToMap(properties), keySerializer, valueSerializer); + } + + // visible for testing + @SuppressWarnings("unchecked") + KafkaProducer(ProducerConfig config, + Serializer keySerializer, + Serializer valueSerializer, + ProducerMetadata metadata, + KafkaClient kafkaClient, + ProducerInterceptors interceptors, + Time time) { + try { + this.producerConfig = config; + this.time = time; + + String transactionalId = config.getString(ProducerConfig.TRANSACTIONAL_ID_CONFIG); + + this.clientId = config.getString(ProducerConfig.CLIENT_ID_CONFIG); + + LogContext logContext; + if (transactionalId == null) + logContext = new LogContext(String.format("[Producer clientId=%s] ", clientId)); + else + logContext = new LogContext(String.format("[Producer clientId=%s, transactionalId=%s] ", clientId, transactionalId)); + log = logContext.logger(KafkaProducer.class); + log.trace("Starting the Kafka producer"); + + Map metricTags = Collections.singletonMap("client-id", clientId); + MetricConfig metricConfig = new MetricConfig().samples(config.getInt(ProducerConfig.METRICS_NUM_SAMPLES_CONFIG)) + .timeWindow(config.getLong(ProducerConfig.METRICS_SAMPLE_WINDOW_MS_CONFIG), TimeUnit.MILLISECONDS) + .recordLevel(Sensor.RecordingLevel.forName(config.getString(ProducerConfig.METRICS_RECORDING_LEVEL_CONFIG))) + .tags(metricTags); + List reporters = config.getConfiguredInstances(ProducerConfig.METRIC_REPORTER_CLASSES_CONFIG, + MetricsReporter.class, + Collections.singletonMap(ProducerConfig.CLIENT_ID_CONFIG, clientId)); + JmxReporter jmxReporter = new JmxReporter(); + jmxReporter.configure(config.originals(Collections.singletonMap(ProducerConfig.CLIENT_ID_CONFIG, clientId))); + reporters.add(jmxReporter); + MetricsContext metricsContext = new KafkaMetricsContext(JMX_PREFIX, + config.originalsWithPrefix(CommonClientConfigs.METRICS_CONTEXT_PREFIX)); + this.metrics = new Metrics(metricConfig, reporters, time, metricsContext); + this.producerMetrics = new KafkaProducerMetrics(metrics); + this.partitioner = config.getConfiguredInstance( + ProducerConfig.PARTITIONER_CLASS_CONFIG, + Partitioner.class, + Collections.singletonMap(ProducerConfig.CLIENT_ID_CONFIG, clientId)); + long retryBackoffMs = config.getLong(ProducerConfig.RETRY_BACKOFF_MS_CONFIG); + if (keySerializer == null) { + this.keySerializer = config.getConfiguredInstance(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG, + Serializer.class); + this.keySerializer.configure(config.originals(Collections.singletonMap(ProducerConfig.CLIENT_ID_CONFIG, clientId)), true); + } else { + config.ignore(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG); + this.keySerializer = keySerializer; + } + if (valueSerializer == null) { + this.valueSerializer = config.getConfiguredInstance(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG, + Serializer.class); + this.valueSerializer.configure(config.originals(Collections.singletonMap(ProducerConfig.CLIENT_ID_CONFIG, 
clientId)), false); + } else { + config.ignore(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG); + this.valueSerializer = valueSerializer; + } + + List> interceptorList = (List) config.getConfiguredInstances( + ProducerConfig.INTERCEPTOR_CLASSES_CONFIG, + ProducerInterceptor.class, + Collections.singletonMap(ProducerConfig.CLIENT_ID_CONFIG, clientId)); + if (interceptors != null) + this.interceptors = interceptors; + else + this.interceptors = new ProducerInterceptors<>(interceptorList); + ClusterResourceListeners clusterResourceListeners = configureClusterResourceListeners(keySerializer, + valueSerializer, interceptorList, reporters); + this.maxRequestSize = config.getInt(ProducerConfig.MAX_REQUEST_SIZE_CONFIG); + this.totalMemorySize = config.getLong(ProducerConfig.BUFFER_MEMORY_CONFIG); + this.compressionType = CompressionType.forName(config.getString(ProducerConfig.COMPRESSION_TYPE_CONFIG)); + + this.maxBlockTimeMs = config.getLong(ProducerConfig.MAX_BLOCK_MS_CONFIG); + int deliveryTimeoutMs = configureDeliveryTimeout(config, log); + + this.apiVersions = new ApiVersions(); + this.transactionManager = configureTransactionState(config, logContext); + this.accumulator = new RecordAccumulator(logContext, + config.getInt(ProducerConfig.BATCH_SIZE_CONFIG), + this.compressionType, + lingerMs(config), + retryBackoffMs, + deliveryTimeoutMs, + metrics, + PRODUCER_METRIC_GROUP_NAME, + time, + apiVersions, + transactionManager, + new BufferPool(this.totalMemorySize, config.getInt(ProducerConfig.BATCH_SIZE_CONFIG), metrics, time, PRODUCER_METRIC_GROUP_NAME)); + + List addresses = ClientUtils.parseAndValidateAddresses( + config.getList(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG), + config.getString(ProducerConfig.CLIENT_DNS_LOOKUP_CONFIG)); + if (metadata != null) { + this.metadata = metadata; + } else { + this.metadata = new ProducerMetadata(retryBackoffMs, + config.getLong(ProducerConfig.METADATA_MAX_AGE_CONFIG), + config.getLong(ProducerConfig.METADATA_MAX_IDLE_CONFIG), + logContext, + clusterResourceListeners, + Time.SYSTEM); + this.metadata.bootstrap(addresses); + } + this.errors = this.metrics.sensor("errors"); + this.sender = newSender(logContext, kafkaClient, this.metadata); + String ioThreadName = NETWORK_THREAD_PREFIX + " | " + clientId; + this.ioThread = new KafkaThread(ioThreadName, this.sender, true); + this.ioThread.start(); + config.logUnused(); + AppInfoParser.registerAppInfo(JMX_PREFIX, clientId, metrics, time.milliseconds()); + log.debug("Kafka producer started"); + } catch (Throwable t) { + // call close methods if internal objects are already constructed this is to prevent resource leak. see KAFKA-2121 + close(Duration.ofMillis(0), true); + // now propagate the exception + throw new KafkaException("Failed to construct kafka producer", t); + } + } + + // visible for testing + Sender newSender(LogContext logContext, KafkaClient kafkaClient, ProducerMetadata metadata) { + int maxInflightRequests = configureInflightRequests(producerConfig); + int requestTimeoutMs = producerConfig.getInt(ProducerConfig.REQUEST_TIMEOUT_MS_CONFIG); + ChannelBuilder channelBuilder = ClientUtils.createChannelBuilder(producerConfig, time, logContext); + ProducerMetrics metricsRegistry = new ProducerMetrics(this.metrics); + Sensor throttleTimeSensor = Sender.throttleTimeSensor(metricsRegistry.senderMetrics); + KafkaClient client = kafkaClient != null ? 
kafkaClient : new NetworkClient( + new Selector(producerConfig.getLong(ProducerConfig.CONNECTIONS_MAX_IDLE_MS_CONFIG), + this.metrics, time, "producer", channelBuilder, logContext), + metadata, + clientId, + maxInflightRequests, + producerConfig.getLong(ProducerConfig.RECONNECT_BACKOFF_MS_CONFIG), + producerConfig.getLong(ProducerConfig.RECONNECT_BACKOFF_MAX_MS_CONFIG), + producerConfig.getInt(ProducerConfig.SEND_BUFFER_CONFIG), + producerConfig.getInt(ProducerConfig.RECEIVE_BUFFER_CONFIG), + requestTimeoutMs, + producerConfig.getLong(ProducerConfig.SOCKET_CONNECTION_SETUP_TIMEOUT_MS_CONFIG), + producerConfig.getLong(ProducerConfig.SOCKET_CONNECTION_SETUP_TIMEOUT_MAX_MS_CONFIG), + time, + true, + apiVersions, + throttleTimeSensor, + logContext); + short acks = configureAcks(producerConfig, log); + return new Sender(logContext, + client, + metadata, + this.accumulator, + maxInflightRequests == 1, + producerConfig.getInt(ProducerConfig.MAX_REQUEST_SIZE_CONFIG), + acks, + producerConfig.getInt(ProducerConfig.RETRIES_CONFIG), + metricsRegistry.senderMetrics, + time, + requestTimeoutMs, + producerConfig.getLong(ProducerConfig.RETRY_BACKOFF_MS_CONFIG), + this.transactionManager, + apiVersions); + } + + private static int lingerMs(ProducerConfig config) { + return (int) Math.min(config.getLong(ProducerConfig.LINGER_MS_CONFIG), Integer.MAX_VALUE); + } + + private static int configureDeliveryTimeout(ProducerConfig config, Logger log) { + int deliveryTimeoutMs = config.getInt(ProducerConfig.DELIVERY_TIMEOUT_MS_CONFIG); + int lingerMs = lingerMs(config); + int requestTimeoutMs = config.getInt(ProducerConfig.REQUEST_TIMEOUT_MS_CONFIG); + int lingerAndRequestTimeoutMs = (int) Math.min((long) lingerMs + requestTimeoutMs, Integer.MAX_VALUE); + + if (deliveryTimeoutMs < lingerAndRequestTimeoutMs) { + if (config.originals().containsKey(ProducerConfig.DELIVERY_TIMEOUT_MS_CONFIG)) { + // throw an exception if the user explicitly set an inconsistent value + throw new ConfigException(ProducerConfig.DELIVERY_TIMEOUT_MS_CONFIG + + " should be equal to or larger than " + ProducerConfig.LINGER_MS_CONFIG + + " + " + ProducerConfig.REQUEST_TIMEOUT_MS_CONFIG); + } else { + // override deliveryTimeoutMs default value to lingerMs + requestTimeoutMs for backward compatibility + deliveryTimeoutMs = lingerAndRequestTimeoutMs; + log.warn("{} should be equal to or larger than {} + {}. 
Setting it to {}.", + ProducerConfig.DELIVERY_TIMEOUT_MS_CONFIG, ProducerConfig.LINGER_MS_CONFIG, + ProducerConfig.REQUEST_TIMEOUT_MS_CONFIG, deliveryTimeoutMs); + } + } + return deliveryTimeoutMs; + } + + private TransactionManager configureTransactionState(ProducerConfig config, + LogContext logContext) { + + TransactionManager transactionManager = null; + + final boolean userConfiguredIdempotence = config.originals().containsKey(ProducerConfig.ENABLE_IDEMPOTENCE_CONFIG); + final boolean userConfiguredTransactions = config.originals().containsKey(ProducerConfig.TRANSACTIONAL_ID_CONFIG); + if (userConfiguredTransactions && !userConfiguredIdempotence) + log.info("Overriding the default {} to true since {} is specified.", ProducerConfig.ENABLE_IDEMPOTENCE_CONFIG, + ProducerConfig.TRANSACTIONAL_ID_CONFIG); + + if (config.idempotenceEnabled()) { + final String transactionalId = config.getString(ProducerConfig.TRANSACTIONAL_ID_CONFIG); + final int transactionTimeoutMs = config.getInt(ProducerConfig.TRANSACTION_TIMEOUT_CONFIG); + final long retryBackoffMs = config.getLong(ProducerConfig.RETRY_BACKOFF_MS_CONFIG); + transactionManager = new TransactionManager( + logContext, + transactionalId, + transactionTimeoutMs, + retryBackoffMs, + apiVersions + ); + + if (transactionManager.isTransactional()) + log.info("Instantiated a transactional producer."); + else + log.info("Instantiated an idempotent producer."); + } + return transactionManager; + } + + private static int configureInflightRequests(ProducerConfig config) { + if (config.idempotenceEnabled() && 5 < config.getInt(ProducerConfig.MAX_IN_FLIGHT_REQUESTS_PER_CONNECTION)) { + throw new ConfigException("Must set " + ProducerConfig.MAX_IN_FLIGHT_REQUESTS_PER_CONNECTION + " to at most 5" + + " to use the idempotent producer."); + } + return config.getInt(ProducerConfig.MAX_IN_FLIGHT_REQUESTS_PER_CONNECTION); + } + + private static short configureAcks(ProducerConfig config, Logger log) { + boolean userConfiguredAcks = config.originals().containsKey(ProducerConfig.ACKS_CONFIG); + short acks = Short.parseShort(config.getString(ProducerConfig.ACKS_CONFIG)); + + if (config.idempotenceEnabled()) { + if (!userConfiguredAcks) + log.info("Overriding the default {} to all since idempotence is enabled.", ProducerConfig.ACKS_CONFIG); + else if (acks != -1) + throw new ConfigException("Must set " + ProducerConfig.ACKS_CONFIG + " to all in order to use the idempotent " + + "producer. Otherwise we cannot guarantee idempotence."); + } + return acks; + } + + /** + * Needs to be called before any other methods when the transactional.id is set in the configuration. + * + * This method does the following: + * 1. Ensures any transactions initiated by previous instances of the producer with the same + * transactional.id are completed. If the previous instance had failed with a transaction in + * progress, it will be aborted. If the last transaction had begun completion, + * but not yet finished, this method awaits its completion. + * 2. Gets the internal producer id and epoch, used in all future transactional + * messages issued by the producer. + * + * Note that this method will raise {@link TimeoutException} if the transactional state cannot + * be initialized before expiration of {@code max.block.ms}. Additionally, it will raise {@link InterruptException} + * if interrupted. It is safe to retry in either case, but once the transactional state has been successfully + * initialized, this method should no longer be used. 
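+ *
+ * For example, a minimal retry sketch, assuming {@code producer} was created with a {@code transactional.id}:
+ * {@code
+ * while (true) {
+ *     try {
+ *         producer.initTransactions();
+ *         break;
+ *     } catch (TimeoutException e) {
+ *         // safe to retry until the transactional state has been initialized
+ *     }
+ * }
+ * }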
+ * + * @throws IllegalStateException if no transactional.id has been configured + * @throws org.apache.kafka.common.errors.UnsupportedVersionException fatal error indicating the broker + * does not support transactions (i.e. if its version is lower than 0.11.0.0) + * @throws org.apache.kafka.common.errors.AuthorizationException fatal error indicating that the configured + * transactional.id is not authorized. See the exception for more details + * @throws KafkaException if the producer has encountered a previous fatal error or for any other unexpected error + * @throws TimeoutException if the time taken for initialize the transaction has surpassed max.block.ms. + * @throws InterruptException if the thread is interrupted while blocked + */ + public void initTransactions() { + throwIfNoTransactionManager(); + throwIfProducerClosed(); + long now = time.nanoseconds(); + TransactionalRequestResult result = transactionManager.initializeTransactions(); + sender.wakeup(); + result.await(maxBlockTimeMs, TimeUnit.MILLISECONDS); + producerMetrics.recordInit(time.nanoseconds() - now); + } + + /** + * Should be called before the start of each new transaction. Note that prior to the first invocation + * of this method, you must invoke {@link #initTransactions()} exactly one time. + * + * @throws IllegalStateException if no transactional.id has been configured or if {@link #initTransactions()} + * has not yet been invoked + * @throws ProducerFencedException if another producer with the same transactional.id is active + * @throws org.apache.kafka.common.errors.InvalidProducerEpochException if the producer has attempted to produce with an old epoch + * to the partition leader. See the exception for more details + * @throws org.apache.kafka.common.errors.UnsupportedVersionException fatal error indicating the broker + * does not support transactions (i.e. if its version is lower than 0.11.0.0) + * @throws org.apache.kafka.common.errors.AuthorizationException fatal error indicating that the configured + * transactional.id is not authorized. See the exception for more details + * @throws KafkaException if the producer has encountered a previous fatal error or for any other unexpected error + */ + public void beginTransaction() throws ProducerFencedException { + throwIfNoTransactionManager(); + throwIfProducerClosed(); + long now = time.nanoseconds(); + transactionManager.beginTransaction(); + producerMetrics.recordBeginTxn(time.nanoseconds() - now); + } + + /** + * Sends a list of specified offsets to the consumer group coordinator, and also marks + * those offsets as part of the current transaction. These offsets will be considered + * committed only if the transaction is committed successfully. The committed offset should + * be the next message your application will consume, i.e. lastProcessedMessageOffset + 1. + *

+ * This method should be used when you need to batch consumed and produced messages + * together, typically in a consume-transform-produce pattern. Thus, the specified + * {@code consumerGroupId} should be the same as config parameter {@code group.id} of the used + * {@link KafkaConsumer consumer}. Note, that the consumer should have {@code enable.auto.commit=false} + * and should also not commit offsets manually (via {@link KafkaConsumer#commitSync(Map) sync} or + * {@link KafkaConsumer#commitAsync(Map, OffsetCommitCallback) async} commits). + * + * @throws IllegalStateException if no transactional.id has been configured, no transaction has been started + * @throws ProducerFencedException fatal error indicating another producer with the same transactional.id is active + * @throws org.apache.kafka.common.errors.UnsupportedVersionException fatal error indicating the broker + * does not support transactions (i.e. if its version is lower than 0.11.0.0) + * @throws org.apache.kafka.common.errors.UnsupportedForMessageFormatException fatal error indicating the message + * format used for the offsets topic on the broker does not support transactions + * @throws org.apache.kafka.common.errors.AuthorizationException fatal error indicating that the configured + * transactional.id is not authorized, or the consumer group id is not authorized. + * @throws org.apache.kafka.common.errors.InvalidProducerEpochException if the producer has attempted to produce with an old epoch + * to the partition leader. See the exception for more details + * @throws KafkaException if the producer has encountered a previous fatal or abortable error, or for any + * other unexpected error + * + * @deprecated Since 3.0.0, please use {@link #sendOffsetsToTransaction(Map, ConsumerGroupMetadata)} instead. + */ + @Deprecated + public void sendOffsetsToTransaction(Map offsets, + String consumerGroupId) throws ProducerFencedException { + sendOffsetsToTransaction(offsets, new ConsumerGroupMetadata(consumerGroupId)); + } + + /** + * Sends a list of specified offsets to the consumer group coordinator, and also marks + * those offsets as part of the current transaction. These offsets will be considered + * committed only if the transaction is committed successfully. The committed offset should + * be the next message your application will consume, i.e. lastProcessedMessageOffset + 1. + *

+ * This method should be used when you need to batch consumed and produced messages + * together, typically in a consume-transform-produce pattern. Thus, the specified + * {@code groupMetadata} should be extracted from the used {@link KafkaConsumer consumer} via + * {@link KafkaConsumer#groupMetadata()} to leverage consumer group metadata. This will provide + * stronger fencing than just supplying the {@code consumerGroupId} and passing in {@code new ConsumerGroupMetadata(consumerGroupId)}, + * however note that the full set of consumer group metadata returned by {@link KafkaConsumer#groupMetadata()} + * requires the brokers to be on version 2.5 or newer to understand. + * + *
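+ * For example, a consume-transform-produce sketch; {@code transform()} is a placeholder for application logic
+ * and is not part of this API:
+ * {@code
+ * producer.beginTransaction();
+ * ConsumerRecords<String, String> records = consumer.poll(Duration.ofMillis(100));
+ * Map<TopicPartition, OffsetAndMetadata> offsets = new HashMap<>();
+ * for (ConsumerRecord<String, String> record : records) {
+ *     producer.send(transform(record));
+ *     offsets.put(new TopicPartition(record.topic(), record.partition()),
+ *                 new OffsetAndMetadata(record.offset() + 1));
+ * }
+ * producer.sendOffsetsToTransaction(offsets, consumer.groupMetadata());
+ * producer.commitTransaction();
+ * }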

+ * Note, that the consumer should have {@code enable.auto.commit=false} and should + * also not commit offsets manually (via {@link KafkaConsumer#commitSync(Map) sync} or + * {@link KafkaConsumer#commitAsync(Map, OffsetCommitCallback) async} commits). + * This method will raise {@link TimeoutException} if the producer cannot send offsets before expiration of {@code max.block.ms}. + * Additionally, it will raise {@link InterruptException} if interrupted. + * + * @throws IllegalStateException if no transactional.id has been configured or no transaction has been started. + * @throws ProducerFencedException fatal error indicating another producer with the same transactional.id is active + * @throws org.apache.kafka.common.errors.UnsupportedVersionException fatal error indicating the broker + * does not support transactions (i.e. if its version is lower than 0.11.0.0) or + * the broker doesn't support latest version of transactional API with all consumer group metadata + * (i.e. if its version is lower than 2.5.0). + * @throws org.apache.kafka.common.errors.UnsupportedForMessageFormatException fatal error indicating the message + * format used for the offsets topic on the broker does not support transactions + * @throws org.apache.kafka.common.errors.AuthorizationException fatal error indicating that the configured + * transactional.id is not authorized, or the consumer group id is not authorized. + * @throws org.apache.kafka.clients.consumer.CommitFailedException if the commit failed and cannot be retried + * (e.g. if the consumer has been kicked out of the group). Users should handle this by aborting the transaction. + * @throws org.apache.kafka.common.errors.FencedInstanceIdException if this producer instance gets fenced by broker due to a + * mis-configured consumer instance id within group metadata. + * @throws org.apache.kafka.common.errors.InvalidProducerEpochException if the producer has attempted to produce with an old epoch + * to the partition leader. See the exception for more details + * @throws KafkaException if the producer has encountered a previous fatal or abortable error, or for any + * other unexpected error + * @throws TimeoutException if the time taken for sending offsets has surpassed max.block.ms. + * @throws InterruptException if the thread is interrupted while blocked + */ + public void sendOffsetsToTransaction(Map offsets, + ConsumerGroupMetadata groupMetadata) throws ProducerFencedException { + throwIfInvalidGroupMetadata(groupMetadata); + throwIfNoTransactionManager(); + throwIfProducerClosed(); + + if (!offsets.isEmpty()) { + long start = time.nanoseconds(); + TransactionalRequestResult result = transactionManager.sendOffsetsToTransaction(offsets, groupMetadata); + sender.wakeup(); + result.await(maxBlockTimeMs, TimeUnit.MILLISECONDS); + producerMetrics.recordSendOffsets(time.nanoseconds() - start); + } + } + + /** + * Commits the ongoing transaction. This method will flush any unsent records before actually committing the transaction. + * + * Further, if any of the {@link #send(ProducerRecord)} calls which were part of the transaction hit irrecoverable + * errors, this method will throw the last received exception immediately and the transaction will not be committed. + * So all {@link #send(ProducerRecord)} calls in a transaction must succeed in order for this method to succeed. + * + * Note that this method will raise {@link TimeoutException} if the transaction cannot be committed before expiration + * of {@code max.block.ms}. 
Additionally, it will raise {@link InterruptException} if interrupted. + * It is safe to retry in either case, but it is not possible to attempt a different operation (such as abortTransaction) + * since the commit may already be in the progress of completing. If not retrying, the only option is to close the producer. + * + * @throws IllegalStateException if no transactional.id has been configured or no transaction has been started + * @throws ProducerFencedException fatal error indicating another producer with the same transactional.id is active + * @throws org.apache.kafka.common.errors.UnsupportedVersionException fatal error indicating the broker + * does not support transactions (i.e. if its version is lower than 0.11.0.0) + * @throws org.apache.kafka.common.errors.AuthorizationException fatal error indicating that the configured + * transactional.id is not authorized. See the exception for more details + * @throws org.apache.kafka.common.errors.InvalidProducerEpochException if the producer has attempted to produce with an old epoch + * to the partition leader. See the exception for more details + * @throws KafkaException if the producer has encountered a previous fatal or abortable error, or for any + * other unexpected error + * @throws TimeoutException if the time taken for committing the transaction has surpassed max.block.ms. + * @throws InterruptException if the thread is interrupted while blocked + */ + public void commitTransaction() throws ProducerFencedException { + throwIfNoTransactionManager(); + throwIfProducerClosed(); + long commitStart = time.nanoseconds(); + TransactionalRequestResult result = transactionManager.beginCommit(); + sender.wakeup(); + result.await(maxBlockTimeMs, TimeUnit.MILLISECONDS); + producerMetrics.recordCommitTxn(time.nanoseconds() - commitStart); + } + + /** + * Aborts the ongoing transaction. Any unflushed produce messages will be aborted when this call is made. + * This call will throw an exception immediately if any prior {@link #send(ProducerRecord)} calls failed with a + * {@link ProducerFencedException} or an instance of {@link org.apache.kafka.common.errors.AuthorizationException}. + * + * Note that this method will raise {@link TimeoutException} if the transaction cannot be aborted before expiration + * of {@code max.block.ms}. Additionally, it will raise {@link InterruptException} if interrupted. + * It is safe to retry in either case, but it is not possible to attempt a different operation (such as commitTransaction) + * since the abort may already be in the progress of completing. If not retrying, the only option is to close the producer. + * + * @throws IllegalStateException if no transactional.id has been configured or no transaction has been started + * @throws ProducerFencedException fatal error indicating another producer with the same transactional.id is active + * @throws org.apache.kafka.common.errors.InvalidProducerEpochException if the producer has attempted to produce with an old epoch + * to the partition leader. See the exception for more details + * @throws org.apache.kafka.common.errors.UnsupportedVersionException fatal error indicating the broker + * does not support transactions (i.e. if its version is lower than 0.11.0.0) + * @throws org.apache.kafka.common.errors.AuthorizationException fatal error indicating that the configured + * transactional.id is not authorized. 
See the exception for more details + * @throws KafkaException if the producer has encountered a previous fatal error or for any other unexpected error + * @throws TimeoutException if the time taken for aborting the transaction has surpassed max.block.ms. + * @throws InterruptException if the thread is interrupted while blocked + */ + public void abortTransaction() throws ProducerFencedException { + throwIfNoTransactionManager(); + throwIfProducerClosed(); + log.info("Aborting incomplete transaction"); + long abortStart = time.nanoseconds(); + TransactionalRequestResult result = transactionManager.beginAbort(); + sender.wakeup(); + result.await(maxBlockTimeMs, TimeUnit.MILLISECONDS); + producerMetrics.recordAbortTxn(time.nanoseconds() - abortStart); + } + + /** + * Asynchronously send a record to a topic. Equivalent to send(record, null). + * See {@link #send(ProducerRecord, Callback)} for details. + */ + @Override + public Future send(ProducerRecord record) { + return send(record, null); + } + + /** + * Asynchronously send a record to a topic and invoke the provided callback when the send has been acknowledged. + *

+ * The send is asynchronous and this method will return immediately once the record has been stored in the buffer of + * records waiting to be sent. This allows sending many records in parallel without blocking to wait for the + * response after each one. + *

+ * The result of the send is a {@link RecordMetadata} specifying the partition the record was sent to, the offset + * it was assigned and the timestamp of the record. If + * {@link org.apache.kafka.common.record.TimestampType#CREATE_TIME CreateTime} is used by the topic, the timestamp + * will be the user provided timestamp or the record send time if the user did not specify a timestamp for the + * record. If {@link org.apache.kafka.common.record.TimestampType#LOG_APPEND_TIME LogAppendTime} is used for the + * topic, the timestamp will be the Kafka broker local time when the message is appended. + *

+ * Since the send call is asynchronous it returns a {@link java.util.concurrent.Future Future} for the + * {@link RecordMetadata} that will be assigned to this record. Invoking {@link java.util.concurrent.Future#get() + * get()} on this future will block until the associated request completes and then return the metadata for the record + * or throw any exception that occurred while sending the record. + *

+ * If you want to simulate a simple blocking call you can call the get() method immediately: + * + *

+     * {@code
+     * byte[] key = "key".getBytes();
+     * byte[] value = "value".getBytes();
+     * ProducerRecord<byte[], byte[]> record = new ProducerRecord<>("my-topic", key, value);
+     * producer.send(record).get();
+     * }
+ *

+ * Fully non-blocking usage can make use of the {@link Callback} parameter to provide a callback that + * will be invoked when the request is complete. + * + *

+     * {@code
+     * ProducerRecord<byte[], byte[]> record = new ProducerRecord<>("the-topic", key, value);
+     * producer.send(record,
+     *               new Callback() {
+     *                   public void onCompletion(RecordMetadata metadata, Exception e) {
+     *                       if(e != null) {
+     *                          e.printStackTrace();
+     *                       } else {
+     *                          System.out.println("The offset of the record we just sent is: " + metadata.offset());
+     *                       }
+     *                   }
+     *               });
+     * }
+     * 
+ * + * Callbacks for records being sent to the same partition are guaranteed to execute in order. That is, in the + * following example callback1 is guaranteed to execute before callback2: + * + *
+     * {@code
+     * producer.send(new ProducerRecord<byte[], byte[]>(topic, partition, key1, value1), callback1);
+     * producer.send(new ProducerRecord<byte[], byte[]>(topic, partition, key2, value2), callback2);
+     * }
+     * 
+ *

+ * When used as part of a transaction, it is not necessary to define a callback or check the result of the future + * in order to detect errors from send. If any of the send calls failed with an irrecoverable error, + * the final {@link #commitTransaction()} call will fail and throw the exception from the last failed send. When + * this happens, your application should call {@link #abortTransaction()} to reset the state and continue to send + * data. + *

+ *

+     * Some transactional send errors cannot be resolved with a call to {@link #abortTransaction()}. In particular,
+     * if a transactional send finishes with a {@link ProducerFencedException}, a {@link org.apache.kafka.common.errors.OutOfOrderSequenceException},
+     * a {@link org.apache.kafka.common.errors.UnsupportedVersionException}, or an
+     * {@link org.apache.kafka.common.errors.AuthorizationException}, then the only option left is to call {@link #close()}.
+     * Fatal errors cause the producer to enter a defunct state in which future API calls will continue to raise
+     * the same underlying error wrapped in a new {@link KafkaException}.
+     *
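+     * A common shape for handling these cases, sketched with shortened exception names and a placeholder
+     * {@code record}, is:
+     * {@code
+     * try {
+     *     producer.beginTransaction();
+     *     producer.send(record);
+     *     producer.commitTransaction();
+     * } catch (ProducerFencedException | OutOfOrderSequenceException
+     *         | UnsupportedVersionException | AuthorizationException e) {
+     *     // fatal errors: the only option is to close the producer
+     *     producer.close();
+     * } catch (KafkaException e) {
+     *     // abortable error: abort the transaction and retry
+     *     producer.abortTransaction();
+     * }
+     * }
+     *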

+ *

+ * It is a similar picture when idempotence is enabled, but no transactional.id has been configured. + * In this case, {@link org.apache.kafka.common.errors.UnsupportedVersionException} and + * {@link org.apache.kafka.common.errors.AuthorizationException} are considered fatal errors. However, + * {@link ProducerFencedException} does not need to be handled. Additionally, it is possible to continue + * sending after receiving an {@link org.apache.kafka.common.errors.OutOfOrderSequenceException}, but doing so + * can result in out of order delivery of pending messages. To ensure proper ordering, you should close the + * producer and create a new instance. + *
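+     * In that situation the recovery sketch is simply to replace the instance (with {@code configs} standing
+     * in for your producer configuration):
+     * {@code
+     * producer.close();
+     * producer = new KafkaProducer<>(configs);
+     * }
+     *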

+ *

+ * If the message format of the destination topic is not upgraded to 0.11.0.0, idempotent and transactional + * produce requests will fail with an {@link org.apache.kafka.common.errors.UnsupportedForMessageFormatException} + * error. If this is encountered during a transaction, it is possible to abort and continue. But note that future + * sends to the same topic will continue receiving the same exception until the topic is upgraded. + *

+ *

+ * Note that callbacks will generally execute in the I/O thread of the producer and so should be reasonably fast or + * they will delay the sending of messages from other threads. If you want to execute blocking or computationally + * expensive callbacks it is recommended to use your own {@link java.util.concurrent.Executor} in the callback body + * to parallelize processing. + * + * @param record The record to send + * @param callback A user-supplied callback to execute when the record has been acknowledged by the server (null + * indicates no callback) + * + * @throws AuthenticationException if authentication fails. See the exception for more details + * @throws AuthorizationException fatal error indicating that the producer is not allowed to write + * @throws IllegalStateException if a transactional.id has been configured and no transaction has been started, or + * when send is invoked after producer has been closed. + * @throws InterruptException If the thread is interrupted while blocked + * @throws SerializationException If the key or value are not valid objects given the configured serializers + * @throws TimeoutException If the record could not be appended to the send buffer due to memory unavailable + * or missing metadata within {@code max.block.ms}. + * @throws KafkaException If a Kafka related error occurs that does not belong to the public API exceptions. + */ + @Override + public Future send(ProducerRecord record, Callback callback) { + // intercept the record, which can be potentially modified; this method does not throw exceptions + ProducerRecord interceptedRecord = this.interceptors.onSend(record); + return doSend(interceptedRecord, callback); + } + + // Verify that this producer instance has not been closed. This method throws IllegalStateException if the producer + // has already been closed. + private void throwIfProducerClosed() { + if (sender == null || !sender.isRunning()) + throw new IllegalStateException("Cannot perform operation after producer has been closed"); + } + + /** + * Implementation of asynchronously send a record to a topic. 
+ */ + private Future doSend(ProducerRecord record, Callback callback) { + TopicPartition tp = null; + try { + throwIfProducerClosed(); + // first make sure the metadata for the topic is available + long nowMs = time.milliseconds(); + ClusterAndWaitTime clusterAndWaitTime; + try { + clusterAndWaitTime = waitOnMetadata(record.topic(), record.partition(), nowMs, maxBlockTimeMs); + } catch (KafkaException e) { + if (metadata.isClosed()) + throw new KafkaException("Producer closed while send in progress", e); + throw e; + } + nowMs += clusterAndWaitTime.waitedOnMetadataMs; + long remainingWaitMs = Math.max(0, maxBlockTimeMs - clusterAndWaitTime.waitedOnMetadataMs); + Cluster cluster = clusterAndWaitTime.cluster; + byte[] serializedKey; + try { + serializedKey = keySerializer.serialize(record.topic(), record.headers(), record.key()); + } catch (ClassCastException cce) { + throw new SerializationException("Can't convert key of class " + record.key().getClass().getName() + + " to class " + producerConfig.getClass(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG).getName() + + " specified in key.serializer", cce); + } + byte[] serializedValue; + try { + serializedValue = valueSerializer.serialize(record.topic(), record.headers(), record.value()); + } catch (ClassCastException cce) { + throw new SerializationException("Can't convert value of class " + record.value().getClass().getName() + + " to class " + producerConfig.getClass(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG).getName() + + " specified in value.serializer", cce); + } + int partition = partition(record, serializedKey, serializedValue, cluster); + tp = new TopicPartition(record.topic(), partition); + + setReadOnly(record.headers()); + Header[] headers = record.headers().toArray(); + + int serializedSize = AbstractRecords.estimateSizeInBytesUpperBound(apiVersions.maxUsableProduceMagic(), + compressionType, serializedKey, serializedValue, headers); + ensureValidRecordSize(serializedSize); + long timestamp = record.timestamp() == null ? nowMs : record.timestamp(); + if (log.isTraceEnabled()) { + log.trace("Attempting to append record {} with callback {} to topic {} partition {}", record, callback, record.topic(), partition); + } + // producer callback will make sure to call both 'callback' and interceptor callback + Callback interceptCallback = new InterceptorCallback<>(callback, this.interceptors, tp); + + if (transactionManager != null && transactionManager.isTransactional()) { + transactionManager.failIfNotReadyForSend(); + } + RecordAccumulator.RecordAppendResult result = accumulator.append(tp, timestamp, serializedKey, + serializedValue, headers, interceptCallback, remainingWaitMs, true, nowMs); + + if (result.abortForNewBatch) { + int prevPartition = partition; + partitioner.onNewBatch(record.topic(), cluster, prevPartition); + partition = partition(record, serializedKey, serializedValue, cluster); + tp = new TopicPartition(record.topic(), partition); + if (log.isTraceEnabled()) { + log.trace("Retrying append due to new batch creation for topic {} partition {}. 
The old partition was {}", record.topic(), partition, prevPartition); + } + // producer callback will make sure to call both 'callback' and interceptor callback + interceptCallback = new InterceptorCallback<>(callback, this.interceptors, tp); + + result = accumulator.append(tp, timestamp, serializedKey, + serializedValue, headers, interceptCallback, remainingWaitMs, false, nowMs); + } + + if (transactionManager != null && transactionManager.isTransactional()) + transactionManager.maybeAddPartitionToTransaction(tp); + + if (result.batchIsFull || result.newBatchCreated) { + log.trace("Waking up the sender since topic {} partition {} is either full or getting a new batch", record.topic(), partition); + this.sender.wakeup(); + } + return result.future; + // handling exceptions and record the errors; + // for API exceptions return them in the future, + // for other exceptions throw directly + } catch (ApiException e) { + log.debug("Exception occurred during message send:", e); + if (callback != null) + callback.onCompletion(null, e); + this.errors.record(); + this.interceptors.onSendError(record, tp, e); + return new FutureFailure(e); + } catch (InterruptedException e) { + this.errors.record(); + this.interceptors.onSendError(record, tp, e); + throw new InterruptException(e); + } catch (KafkaException e) { + this.errors.record(); + this.interceptors.onSendError(record, tp, e); + throw e; + } catch (Exception e) { + // we notify interceptor about all exceptions, since onSend is called before anything else in this method + this.interceptors.onSendError(record, tp, e); + throw e; + } + } + + private void setReadOnly(Headers headers) { + if (headers instanceof RecordHeaders) { + ((RecordHeaders) headers).setReadOnly(); + } + } + + /** + * Wait for cluster metadata including partitions for the given topic to be available. + * @param topic The topic we want metadata for + * @param partition A specific partition expected to exist in metadata, or null if there's no preference + * @param nowMs The current time in ms + * @param maxWaitMs The maximum time in ms for waiting on the metadata + * @return The cluster containing topic metadata and the amount of time we waited in ms + * @throws TimeoutException if metadata could not be refreshed within {@code max.block.ms} + * @throws KafkaException for all Kafka-related exceptions, including the case where this method is called after producer close + */ + private ClusterAndWaitTime waitOnMetadata(String topic, Integer partition, long nowMs, long maxWaitMs) throws InterruptedException { + // add topic to metadata topic list if it is not there already and reset expiry + Cluster cluster = metadata.fetch(); + + if (cluster.invalidTopics().contains(topic)) + throw new InvalidTopicException(topic); + + metadata.add(topic, nowMs); + + Integer partitionsCount = cluster.partitionCountForTopic(topic); + // Return cached metadata if we have it, and if the record's partition is either undefined + // or within the known partition range + if (partitionsCount != null && (partition == null || partition < partitionsCount)) + return new ClusterAndWaitTime(cluster, 0); + + long remainingWaitMs = maxWaitMs; + long elapsed = 0; + // Issue metadata requests until we have metadata for the topic and the requested partition, + // or until maxWaitTimeMs is exceeded. This is necessary in case the metadata + // is stale and the number of partitions for this topic has increased in the meantime. 
+ do { + if (partition != null) { + log.trace("Requesting metadata update for partition {} of topic {}.", partition, topic); + } else { + log.trace("Requesting metadata update for topic {}.", topic); + } + metadata.add(topic, nowMs + elapsed); + int version = metadata.requestUpdateForTopic(topic); + sender.wakeup(); + try { + metadata.awaitUpdate(version, remainingWaitMs); + } catch (TimeoutException ex) { + // Rethrow with original maxWaitMs to prevent logging exception with remainingWaitMs + throw new TimeoutException( + String.format("Topic %s not present in metadata after %d ms.", + topic, maxWaitMs)); + } + cluster = metadata.fetch(); + elapsed = time.milliseconds() - nowMs; + if (elapsed >= maxWaitMs) { + throw new TimeoutException(partitionsCount == null ? + String.format("Topic %s not present in metadata after %d ms.", + topic, maxWaitMs) : + String.format("Partition %d of topic %s with partition count %d is not present in metadata after %d ms.", + partition, topic, partitionsCount, maxWaitMs)); + } + metadata.maybeThrowExceptionForTopic(topic); + remainingWaitMs = maxWaitMs - elapsed; + partitionsCount = cluster.partitionCountForTopic(topic); + } while (partitionsCount == null || (partition != null && partition >= partitionsCount)); + + return new ClusterAndWaitTime(cluster, elapsed); + } + + /** + * Validate that the record size isn't too large + */ + private void ensureValidRecordSize(int size) { + if (size > maxRequestSize) + throw new RecordTooLargeException("The message is " + size + + " bytes when serialized which is larger than " + maxRequestSize + ", which is the value of the " + + ProducerConfig.MAX_REQUEST_SIZE_CONFIG + " configuration."); + if (size > totalMemorySize) + throw new RecordTooLargeException("The message is " + size + + " bytes when serialized which is larger than the total memory buffer you have configured with the " + + ProducerConfig.BUFFER_MEMORY_CONFIG + + " configuration."); + } + + /** + * Invoking this method makes all buffered records immediately available to send (even if linger.ms is + * greater than 0) and blocks on the completion of the requests associated with these records. The post-condition + * of flush() is that any previously sent record will have completed (e.g. Future.isDone() == true). + * A request is considered completed when it is successfully acknowledged + * according to the acks configuration you have specified or else it results in an error. + *

+ * Other threads can continue sending records while one thread is blocked waiting for a flush call to complete, + * however no guarantee is made about the completion of records sent after the flush call begins. + *

+ * This method can be useful when consuming from some input system and producing into Kafka. The flush() call + * gives a convenient way to ensure all previously sent messages have actually completed. + *

+ * This example shows how to consume from one Kafka topic and produce to another Kafka topic: + *

+     * {@code
+     * for (ConsumerRecord<String, String> record : consumer.poll(100))
+     *     producer.send(new ProducerRecord<String, String>("my-topic", record.key(), record.value()));
+     * producer.flush();
+     * consumer.commitSync();
+     * }
+     * 
+ * + * Note that the above example may drop records if the produce request fails. If we want to ensure that this does not occur + * we need to set retries=<large_number> in our config. + *
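+     * For example, one illustrative (not prescriptive) way to do that, where {@code props} is the
+     * configuration map passed to the producer, is:
+     * {@code
+     * props.put(ProducerConfig.RETRIES_CONFIG, Integer.MAX_VALUE);
+     * }
+     *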

+ *

+ * Applications don't need to call this method for transactional producers, since the {@link #commitTransaction()} will + * flush all buffered records before performing the commit. This ensures that all the {@link #send(ProducerRecord)} + * calls made since the previous {@link #beginTransaction()} are completed before the commit. + *

+ * + * @throws InterruptException If the thread is interrupted while blocked + */ + @Override + public void flush() { + log.trace("Flushing accumulated records in producer."); + + long start = time.nanoseconds(); + this.accumulator.beginFlush(); + this.sender.wakeup(); + try { + this.accumulator.awaitFlushCompletion(); + } catch (InterruptedException e) { + throw new InterruptException("Flush interrupted.", e); + } finally { + producerMetrics.recordFlush(time.nanoseconds() - start); + } + } + + /** + * Get the partition metadata for the given topic. This can be used for custom partitioning. + * @throws AuthenticationException if authentication fails. See the exception for more details + * @throws AuthorizationException if not authorized to the specified topic. See the exception for more details + * @throws InterruptException if the thread is interrupted while blocked + * @throws TimeoutException if metadata could not be refreshed within {@code max.block.ms} + * @throws KafkaException for all Kafka-related exceptions, including the case where this method is called after producer close + */ + @Override + public List partitionsFor(String topic) { + Objects.requireNonNull(topic, "topic cannot be null"); + try { + return waitOnMetadata(topic, null, time.milliseconds(), maxBlockTimeMs).cluster.partitionsForTopic(topic); + } catch (InterruptedException e) { + throw new InterruptException(e); + } + } + + /** + * Get the full set of internal metrics maintained by the producer. + */ + @Override + public Map metrics() { + return Collections.unmodifiableMap(this.metrics.metrics()); + } + + /** + * Close this producer. This method blocks until all previously sent requests complete. + * This method is equivalent to close(Long.MAX_VALUE, TimeUnit.MILLISECONDS). + *

+ * If close() is called from {@link Callback}, a warning message will be logged and close(0, TimeUnit.MILLISECONDS) + * will be called instead. We do this because the sender thread would otherwise try to join itself and + * block forever. + *

+     *
+     * @throws InterruptException If the thread is interrupted while blocked.
+     * @throws KafkaException If an unexpected error occurs while trying to close the client, this error should be treated
+     *                        as fatal and indicate the client is no longer functional.
+     */
+    @Override
+    public void close() {
+        close(Duration.ofMillis(Long.MAX_VALUE));
+    }
+
+    /**
+     * This method waits up to timeout for the producer to complete the sending of all incomplete requests.
+     *

+ * If the producer is unable to complete all requests before the timeout expires, this method will fail + * any unsent and unacknowledged records immediately. It will also abort the ongoing transaction if it's not + * already completing. + *
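+     * For example, {@code producer.close(Duration.ofSeconds(30))} (an arbitrary timeout chosen purely for
+     * illustration) waits up to thirty seconds for outstanding requests before failing whatever remains.
+     *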

+ * If invoked from within a {@link Callback} this method will not block and will be equivalent to + * close(Duration.ofMillis(0)). This is done since no further sending will happen while + * blocking the I/O thread of the producer. + * + * @param timeout The maximum time to wait for producer to complete any pending requests. The value should be + * non-negative. Specifying a timeout of zero means do not wait for pending send requests to complete. + * @throws InterruptException If the thread is interrupted while blocked. + * @throws KafkaException If a unexpected error occurs while trying to close the client, this error should be treated + * as fatal and indicate the client is no longer functionable. + * @throws IllegalArgumentException If the timeout is negative. + * + */ + @Override + public void close(Duration timeout) { + close(timeout, false); + } + + private void close(Duration timeout, boolean swallowException) { + long timeoutMs = timeout.toMillis(); + if (timeoutMs < 0) + throw new IllegalArgumentException("The timeout cannot be negative."); + log.info("Closing the Kafka producer with timeoutMillis = {} ms.", timeoutMs); + + // this will keep track of the first encountered exception + AtomicReference firstException = new AtomicReference<>(); + boolean invokedFromCallback = Thread.currentThread() == this.ioThread; + if (timeoutMs > 0) { + if (invokedFromCallback) { + log.warn("Overriding close timeout {} ms to 0 ms in order to prevent useless blocking due to self-join. " + + "This means you have incorrectly invoked close with a non-zero timeout from the producer call-back.", + timeoutMs); + } else { + // Try to close gracefully. + if (this.sender != null) + this.sender.initiateClose(); + if (this.ioThread != null) { + try { + this.ioThread.join(timeoutMs); + } catch (InterruptedException t) { + firstException.compareAndSet(null, new InterruptException(t)); + log.error("Interrupted while joining ioThread", t); + } + } + } + } + + if (this.sender != null && this.ioThread != null && this.ioThread.isAlive()) { + log.info("Proceeding to force close the producer since pending requests could not be completed " + + "within timeout {} ms.", timeoutMs); + this.sender.forceClose(); + // Only join the sender thread when not calling from callback. + if (!invokedFromCallback) { + try { + this.ioThread.join(); + } catch (InterruptedException e) { + firstException.compareAndSet(null, new InterruptException(e)); + } + } + } + + Utils.closeQuietly(interceptors, "producer interceptors", firstException); + Utils.closeQuietly(producerMetrics, "producer metrics wrapper", firstException); + Utils.closeQuietly(metrics, "producer metrics", firstException); + Utils.closeQuietly(keySerializer, "producer keySerializer", firstException); + Utils.closeQuietly(valueSerializer, "producer valueSerializer", firstException); + Utils.closeQuietly(partitioner, "producer partitioner", firstException); + AppInfoParser.unregisterAppInfo(JMX_PREFIX, clientId, metrics); + Throwable exception = firstException.get(); + if (exception != null && !swallowException) { + if (exception instanceof InterruptException) { + throw (InterruptException) exception; + } + throw new KafkaException("Failed to close kafka producer", exception); + } + log.debug("Kafka producer has been closed"); + } + + private ClusterResourceListeners configureClusterResourceListeners(Serializer keySerializer, Serializer valueSerializer, List... 
candidateLists) { + ClusterResourceListeners clusterResourceListeners = new ClusterResourceListeners(); + for (List candidateList: candidateLists) + clusterResourceListeners.maybeAddAll(candidateList); + + clusterResourceListeners.maybeAdd(keySerializer); + clusterResourceListeners.maybeAdd(valueSerializer); + return clusterResourceListeners; + } + + /** + * computes partition for given record. + * if the record has partition returns the value otherwise + * calls configured partitioner class to compute the partition. + */ + private int partition(ProducerRecord record, byte[] serializedKey, byte[] serializedValue, Cluster cluster) { + Integer partition = record.partition(); + return partition != null ? + partition : + partitioner.partition( + record.topic(), record.key(), serializedKey, record.value(), serializedValue, cluster); + } + + private void throwIfInvalidGroupMetadata(ConsumerGroupMetadata groupMetadata) { + if (groupMetadata == null) { + throw new IllegalArgumentException("Consumer group metadata could not be null"); + } else if (groupMetadata.generationId() > 0 + && JoinGroupRequest.UNKNOWN_MEMBER_ID.equals(groupMetadata.memberId())) { + throw new IllegalArgumentException("Passed in group metadata " + groupMetadata + " has generationId > 0 but member.id "); + } + } + + private void throwIfNoTransactionManager() { + if (transactionManager == null) + throw new IllegalStateException("Cannot use transactional methods without enabling transactions " + + "by setting the " + ProducerConfig.TRANSACTIONAL_ID_CONFIG + " configuration property"); + } + + // Visible for testing + String getClientId() { + return clientId; + } + + private static class ClusterAndWaitTime { + final Cluster cluster; + final long waitedOnMetadataMs; + ClusterAndWaitTime(Cluster cluster, long waitedOnMetadataMs) { + this.cluster = cluster; + this.waitedOnMetadataMs = waitedOnMetadataMs; + } + } + + private static class FutureFailure implements Future { + + private final ExecutionException exception; + + public FutureFailure(Exception exception) { + this.exception = new ExecutionException(exception); + } + + @Override + public boolean cancel(boolean interrupt) { + return false; + } + + @Override + public RecordMetadata get() throws ExecutionException { + throw this.exception; + } + + @Override + public RecordMetadata get(long timeout, TimeUnit unit) throws ExecutionException { + throw this.exception; + } + + @Override + public boolean isCancelled() { + return false; + } + + @Override + public boolean isDone() { + return true; + } + + } + + /** + * A callback called when producer request is complete. It in turn calls user-supplied callback (if given) and + * notifies producer interceptors about the request completion. + */ + private static class InterceptorCallback implements Callback { + private final Callback userCallback; + private final ProducerInterceptors interceptors; + private final TopicPartition tp; + + private InterceptorCallback(Callback userCallback, ProducerInterceptors interceptors, TopicPartition tp) { + this.userCallback = userCallback; + this.interceptors = interceptors; + this.tp = tp; + } + + public void onCompletion(RecordMetadata metadata, Exception exception) { + metadata = metadata != null ? 
metadata : new RecordMetadata(tp, -1, -1, RecordBatch.NO_TIMESTAMP, -1, -1); + this.interceptors.onAcknowledgement(metadata, exception); + if (this.userCallback != null) + this.userCallback.onCompletion(metadata, exception); + } + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/producer/MockProducer.java b/clients/src/main/java/org/apache/kafka/clients/producer/MockProducer.java new file mode 100644 index 0000000..4fd540d --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/producer/MockProducer.java @@ -0,0 +1,549 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients.producer; + +import org.apache.kafka.clients.consumer.ConsumerGroupMetadata; +import org.apache.kafka.clients.consumer.OffsetAndMetadata; +import org.apache.kafka.clients.producer.internals.DefaultPartitioner; +import org.apache.kafka.clients.producer.internals.FutureRecordMetadata; +import org.apache.kafka.clients.producer.internals.ProduceRequestResult; +import org.apache.kafka.common.Cluster; +import org.apache.kafka.common.KafkaException; +import org.apache.kafka.common.Metric; +import org.apache.kafka.common.MetricName; +import org.apache.kafka.common.PartitionInfo; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.errors.ProducerFencedException; +import org.apache.kafka.common.record.RecordBatch; +import org.apache.kafka.common.serialization.Serializer; +import org.apache.kafka.common.utils.Time; + +import java.time.Duration; +import java.util.ArrayDeque; +import java.util.ArrayList; +import java.util.Deque; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.concurrent.Future; + +/** + * A mock of the producer interface you can use for testing code that uses Kafka. + *

+ * By default this mock will synchronously complete each send call successfully. However it can be configured to allow + * the user to control the completion of the call and supply an optional error for the producer to throw. + */ +public class MockProducer implements Producer { + + private final Cluster cluster; + private final Partitioner partitioner; + private final List> sent; + private final List> uncommittedSends; + private final Deque completions; + private final Map offsets; + private final List>> consumerGroupOffsets; + private Map> uncommittedConsumerGroupOffsets; + private final Serializer keySerializer; + private final Serializer valueSerializer; + private boolean autoComplete; + private boolean closed; + private boolean transactionInitialized; + private boolean transactionInFlight; + private boolean transactionCommitted; + private boolean transactionAborted; + private boolean producerFenced; + private boolean sentOffsets; + private long commitCount = 0L; + private final Map mockMetrics; + + public RuntimeException initTransactionException = null; + public RuntimeException beginTransactionException = null; + public RuntimeException sendOffsetsToTransactionException = null; + public RuntimeException commitTransactionException = null; + public RuntimeException abortTransactionException = null; + public RuntimeException sendException = null; + public RuntimeException flushException = null; + public RuntimeException partitionsForException = null; + public RuntimeException closeException = null; + + /** + * Create a mock producer + * + * @param cluster The cluster holding metadata for this producer + * @param autoComplete If true automatically complete all requests successfully and execute the callback. Otherwise + * the user must call {@link #completeNext()} or {@link #errorNext(RuntimeException)} after + * {@link #send(ProducerRecord) send()} to complete the call and unblock the {@link + * java.util.concurrent.Future Future<RecordMetadata>} that is returned. + * @param partitioner The partition strategy + * @param keySerializer The serializer for key that implements {@link Serializer}. + * @param valueSerializer The serializer for value that implements {@link Serializer}. + */ + public MockProducer(final Cluster cluster, + final boolean autoComplete, + final Partitioner partitioner, + final Serializer keySerializer, + final Serializer valueSerializer) { + this.cluster = cluster; + this.autoComplete = autoComplete; + this.partitioner = partitioner; + this.keySerializer = keySerializer; + this.valueSerializer = valueSerializer; + this.offsets = new HashMap<>(); + this.sent = new ArrayList<>(); + this.uncommittedSends = new ArrayList<>(); + this.consumerGroupOffsets = new ArrayList<>(); + this.uncommittedConsumerGroupOffsets = new HashMap<>(); + this.completions = new ArrayDeque<>(); + this.mockMetrics = new HashMap<>(); + } + + /** + * Create a new mock producer with invented metadata the given autoComplete setting and key\value serializers. 
+ * + * Equivalent to {@link #MockProducer(Cluster, boolean, Partitioner, Serializer, Serializer)} new MockProducer(Cluster.empty(), autoComplete, new DefaultPartitioner(), keySerializer, valueSerializer)} + */ + public MockProducer(final boolean autoComplete, + final Serializer keySerializer, + final Serializer valueSerializer) { + this(Cluster.empty(), autoComplete, new DefaultPartitioner(), keySerializer, valueSerializer); + } + + /** + * Create a new mock producer with invented metadata the given autoComplete setting, partitioner and key\value serializers. + * + * Equivalent to {@link #MockProducer(Cluster, boolean, Partitioner, Serializer, Serializer)} new MockProducer(Cluster.empty(), autoComplete, partitioner, keySerializer, valueSerializer)} + */ + public MockProducer(final boolean autoComplete, + final Partitioner partitioner, + final Serializer keySerializer, + final Serializer valueSerializer) { + this(Cluster.empty(), autoComplete, partitioner, keySerializer, valueSerializer); + } + + /** + * Create a new mock producer with invented metadata. + * + * Equivalent to {@link #MockProducer(Cluster, boolean, Partitioner, Serializer, Serializer)} new MockProducer(Cluster.empty(), false, null, null, null)} + */ + public MockProducer() { + this(Cluster.empty(), false, null, null, null); + } + + @Override + public void initTransactions() { + verifyProducerState(); + if (this.transactionInitialized) { + throw new IllegalStateException("MockProducer has already been initialized for transactions."); + } + if (this.initTransactionException != null) { + throw this.initTransactionException; + } + this.transactionInitialized = true; + this.transactionInFlight = false; + this.transactionCommitted = false; + this.transactionAborted = false; + this.sentOffsets = false; + } + + @Override + public void beginTransaction() throws ProducerFencedException { + verifyProducerState(); + verifyTransactionsInitialized(); + + if (this.beginTransactionException != null) { + throw this.beginTransactionException; + } + + if (transactionInFlight) { + throw new IllegalStateException("Transaction already started"); + } + + this.transactionInFlight = true; + this.transactionCommitted = false; + this.transactionAborted = false; + this.sentOffsets = false; + } + + @Deprecated + @Override + public void sendOffsetsToTransaction(Map offsets, + String consumerGroupId) throws ProducerFencedException { + Objects.requireNonNull(consumerGroupId); + sendOffsetsToTransaction(offsets, new ConsumerGroupMetadata(consumerGroupId)); + } + + @Override + public void sendOffsetsToTransaction(Map offsets, + ConsumerGroupMetadata groupMetadata) throws ProducerFencedException { + Objects.requireNonNull(groupMetadata); + verifyProducerState(); + verifyTransactionsInitialized(); + verifyTransactionInFlight(); + + if (this.sendOffsetsToTransactionException != null) { + throw this.sendOffsetsToTransactionException; + } + + if (offsets.size() == 0) { + return; + } + Map uncommittedOffsets = + this.uncommittedConsumerGroupOffsets.computeIfAbsent(groupMetadata.groupId(), k -> new HashMap<>()); + uncommittedOffsets.putAll(offsets); + this.sentOffsets = true; + } + + @Override + public void commitTransaction() throws ProducerFencedException { + verifyProducerState(); + verifyTransactionsInitialized(); + verifyTransactionInFlight(); + + if (this.commitTransactionException != null) { + throw this.commitTransactionException; + } + + flush(); + + this.sent.addAll(this.uncommittedSends); + if (!this.uncommittedConsumerGroupOffsets.isEmpty()) + 
this.consumerGroupOffsets.add(this.uncommittedConsumerGroupOffsets); + + this.uncommittedSends.clear(); + this.uncommittedConsumerGroupOffsets = new HashMap<>(); + this.transactionCommitted = true; + this.transactionAborted = false; + this.transactionInFlight = false; + + ++this.commitCount; + } + + @Override + public void abortTransaction() throws ProducerFencedException { + verifyProducerState(); + verifyTransactionsInitialized(); + verifyTransactionInFlight(); + + if (this.abortTransactionException != null) { + throw this.abortTransactionException; + } + + flush(); + this.uncommittedSends.clear(); + this.uncommittedConsumerGroupOffsets.clear(); + this.transactionCommitted = false; + this.transactionAborted = true; + this.transactionInFlight = false; + } + + private synchronized void verifyProducerState() { + if (this.closed) { + throw new IllegalStateException("MockProducer is already closed."); + } + if (this.producerFenced) { + throw new ProducerFencedException("MockProducer is fenced."); + } + } + + private void verifyTransactionsInitialized() { + if (!this.transactionInitialized) { + throw new IllegalStateException("MockProducer hasn't been initialized for transactions."); + } + } + + private void verifyTransactionInFlight() { + if (!this.transactionInFlight) { + throw new IllegalStateException("There is no open transaction."); + } + } + + /** + * Adds the record to the list of sent records. The {@link RecordMetadata} returned will be immediately satisfied. + * + * @see #history() + */ + @Override + public synchronized Future send(ProducerRecord record) { + return send(record, null); + } + + /** + * Adds the record to the list of sent records. + * + * @see #history() + */ + @Override + public synchronized Future send(ProducerRecord record, Callback callback) { + if (this.closed) { + throw new IllegalStateException("MockProducer is already closed."); + } + + if (this.producerFenced) { + throw new KafkaException("MockProducer is fenced.", new ProducerFencedException("Fenced")); + } + if (this.sendException != null) { + throw this.sendException; + } + + int partition = 0; + if (!this.cluster.partitionsForTopic(record.topic()).isEmpty()) + partition = partition(record, this.cluster); + else { + //just to throw ClassCastException if serializers are not the proper ones to serialize key/value + keySerializer.serialize(record.topic(), record.key()); + valueSerializer.serialize(record.topic(), record.value()); + } + + TopicPartition topicPartition = new TopicPartition(record.topic(), partition); + ProduceRequestResult result = new ProduceRequestResult(topicPartition); + FutureRecordMetadata future = new FutureRecordMetadata(result, 0, RecordBatch.NO_TIMESTAMP, + 0, 0, Time.SYSTEM); + long offset = nextOffset(topicPartition); + long baseOffset = Math.max(0, offset - Integer.MAX_VALUE); + int batchIndex = (int) Math.min(Integer.MAX_VALUE, offset); + Completion completion = new Completion(offset, new RecordMetadata(topicPartition, baseOffset, batchIndex, + RecordBatch.NO_TIMESTAMP, 0, 0), result, callback, topicPartition); + + if (!this.transactionInFlight) + this.sent.add(record); + else + this.uncommittedSends.add(record); + + if (autoComplete) + completion.complete(null); + else + this.completions.addLast(completion); + + return future; + } + + /** + * Get the next offset for this topic/partition + */ + private long nextOffset(TopicPartition tp) { + Long offset = this.offsets.get(tp); + if (offset == null) { + this.offsets.put(tp, 1L); + return 0L; + } else { + Long next = offset + 1; + 
this.offsets.put(tp, next); + return offset; + } + } + + public synchronized void flush() { + verifyProducerState(); + + if (this.flushException != null) { + throw this.flushException; + } + + while (!this.completions.isEmpty()) + completeNext(); + } + + public List partitionsFor(String topic) { + if (this.partitionsForException != null) { + throw this.partitionsForException; + } + + return this.cluster.partitionsForTopic(topic); + } + + public Map metrics() { + return mockMetrics; + } + + /** + * Set a mock metric for testing purpose + */ + public void setMockMetrics(MetricName name, Metric metric) { + mockMetrics.put(name, metric); + } + + @Override + public void close() { + close(Duration.ofMillis(0)); + } + + @Override + public void close(Duration timeout) { + if (this.closeException != null) { + throw this.closeException; + } + + this.closed = true; + } + + public boolean closed() { + return this.closed; + } + + public synchronized void fenceProducer() { + verifyProducerState(); + verifyTransactionsInitialized(); + this.producerFenced = true; + } + + public boolean transactionInitialized() { + return this.transactionInitialized; + } + + public boolean transactionInFlight() { + return this.transactionInFlight; + } + + public boolean transactionCommitted() { + return this.transactionCommitted; + } + + public boolean transactionAborted() { + return this.transactionAborted; + } + + public boolean flushed() { + return this.completions.isEmpty(); + } + + public boolean sentOffsets() { + return this.sentOffsets; + } + + public long commitCount() { + return this.commitCount; + } + + /** + * Get the list of sent records since the last call to {@link #clear()} + */ + public synchronized List> history() { + return new ArrayList<>(this.sent); + } + + public synchronized List> uncommittedRecords() { + return new ArrayList<>(this.uncommittedSends); + } + + /** + * + * Get the list of committed consumer group offsets since the last call to {@link #clear()} + */ + public synchronized List>> consumerGroupOffsetsHistory() { + return new ArrayList<>(this.consumerGroupOffsets); + } + + public synchronized Map> uncommittedOffsets() { + return this.uncommittedConsumerGroupOffsets; + } + + /** + * Clear the stored history of sent records, consumer group offsets + */ + public synchronized void clear() { + this.sent.clear(); + this.uncommittedSends.clear(); + this.sentOffsets = false; + this.completions.clear(); + this.consumerGroupOffsets.clear(); + this.uncommittedConsumerGroupOffsets.clear(); + } + + /** + * Complete the earliest uncompleted call successfully. + * + * @return true if there was an uncompleted call to complete + */ + public synchronized boolean completeNext() { + return errorNext(null); + } + + /** + * Complete the earliest uncompleted call with the given error. + * + * @return true if there was an uncompleted call to complete + */ + public synchronized boolean errorNext(RuntimeException e) { + Completion completion = this.completions.pollFirst(); + if (completion != null) { + completion.complete(e); + return true; + } else { + return false; + } + } + + /** + * computes partition for given record. 
+ */ + private int partition(ProducerRecord record, Cluster cluster) { + Integer partition = record.partition(); + String topic = record.topic(); + if (partition != null) { + List partitions = cluster.partitionsForTopic(topic); + int numPartitions = partitions.size(); + // they have given us a partition, use it + if (partition < 0 || partition >= numPartitions) + throw new IllegalArgumentException("Invalid partition given with record: " + partition + + " is not in the range [0..." + + numPartitions + + "]."); + return partition; + } + byte[] keyBytes = keySerializer.serialize(topic, record.headers(), record.key()); + byte[] valueBytes = valueSerializer.serialize(topic, record.headers(), record.value()); + return this.partitioner.partition(topic, record.key(), keyBytes, record.value(), valueBytes, cluster); + } + + private static class Completion { + private final long offset; + private final RecordMetadata metadata; + private final ProduceRequestResult result; + private final Callback callback; + private final TopicPartition tp; + + public Completion(long offset, + RecordMetadata metadata, + ProduceRequestResult result, + Callback callback, + TopicPartition tp) { + this.metadata = metadata; + this.offset = offset; + this.result = result; + this.callback = callback; + this.tp = tp; + } + + public void complete(RuntimeException e) { + if (e == null) { + result.set(offset, RecordBatch.NO_TIMESTAMP, null); + } else { + result.set(-1, RecordBatch.NO_TIMESTAMP, index -> e); + } + + if (callback != null) { + if (e == null) + callback.onCompletion(metadata, null); + else + callback.onCompletion(new RecordMetadata(tp, -1, -1, RecordBatch.NO_TIMESTAMP, -1, -1), e); + } + result.done(); + } + } + +} diff --git a/clients/src/main/java/org/apache/kafka/clients/producer/Partitioner.java b/clients/src/main/java/org/apache/kafka/clients/producer/Partitioner.java new file mode 100644 index 0000000..13eaa5a --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/producer/Partitioner.java @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients.producer; + +import org.apache.kafka.common.Configurable; +import org.apache.kafka.common.Cluster; + +import java.io.Closeable; + +/** + * Partitioner Interface + */ +public interface Partitioner extends Configurable, Closeable { + + /** + * Compute the partition for the given record. 
+ * + * @param topic The topic name + * @param key The key to partition on (or null if no key) + * @param keyBytes The serialized key to partition on( or null if no key) + * @param value The value to partition on or null + * @param valueBytes The serialized value to partition on or null + * @param cluster The current cluster metadata + */ + int partition(String topic, Object key, byte[] keyBytes, Object value, byte[] valueBytes, Cluster cluster); + + /** + * This is called when partitioner is closed. + */ + void close(); + + /** + * Notifies the partitioner a new batch is about to be created. When using the sticky partitioner, + * this method can change the chosen sticky partition for the new batch. + * @param topic The topic name + * @param cluster The current cluster metadata + * @param prevPartition The partition previously selected for the record that triggered a new batch + */ + default void onNewBatch(String topic, Cluster cluster, int prevPartition) { + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/producer/Producer.java b/clients/src/main/java/org/apache/kafka/clients/producer/Producer.java new file mode 100644 index 0000000..4f3e9ec --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/producer/Producer.java @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.clients.producer; + +import org.apache.kafka.clients.consumer.ConsumerGroupMetadata; +import org.apache.kafka.common.Metric; +import org.apache.kafka.common.MetricName; +import org.apache.kafka.common.PartitionInfo; +import org.apache.kafka.clients.consumer.OffsetAndMetadata; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.errors.ProducerFencedException; + +import java.io.Closeable; +import java.time.Duration; +import java.util.List; +import java.util.Map; +import java.util.concurrent.Future; + +/** + * The interface for the {@link KafkaProducer} + * @see KafkaProducer + * @see MockProducer + */ +public interface Producer extends Closeable { + + /** + * See {@link KafkaProducer#initTransactions()} + */ + void initTransactions(); + + /** + * See {@link KafkaProducer#beginTransaction()} + */ + void beginTransaction() throws ProducerFencedException; + + /** + * See {@link KafkaProducer#sendOffsetsToTransaction(Map, String)} + */ + @Deprecated + void sendOffsetsToTransaction(Map offsets, + String consumerGroupId) throws ProducerFencedException; + + /** + * See {@link KafkaProducer#sendOffsetsToTransaction(Map, ConsumerGroupMetadata)} + */ + void sendOffsetsToTransaction(Map offsets, + ConsumerGroupMetadata groupMetadata) throws ProducerFencedException; + + /** + * See {@link KafkaProducer#commitTransaction()} + */ + void commitTransaction() throws ProducerFencedException; + + /** + * See {@link KafkaProducer#abortTransaction()} + */ + void abortTransaction() throws ProducerFencedException; + + /** + * See {@link KafkaProducer#send(ProducerRecord)} + */ + Future send(ProducerRecord record); + + /** + * See {@link KafkaProducer#send(ProducerRecord, Callback)} + */ + Future send(ProducerRecord record, Callback callback); + + /** + * See {@link KafkaProducer#flush()} + */ + void flush(); + + /** + * See {@link KafkaProducer#partitionsFor(String)} + */ + List partitionsFor(String topic); + + /** + * See {@link KafkaProducer#metrics()} + */ + Map metrics(); + + /** + * See {@link KafkaProducer#close()} + */ + void close(); + + /** + * See {@link KafkaProducer#close(Duration)} + */ + void close(Duration timeout); +} diff --git a/clients/src/main/java/org/apache/kafka/clients/producer/ProducerConfig.java b/clients/src/main/java/org/apache/kafka/clients/producer/ProducerConfig.java new file mode 100644 index 0000000..fbd3449 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/producer/ProducerConfig.java @@ -0,0 +1,542 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.clients.producer; + +import org.apache.kafka.clients.ClientDnsLookup; +import org.apache.kafka.clients.CommonClientConfigs; +import org.apache.kafka.clients.producer.internals.DefaultPartitioner; +import org.apache.kafka.common.config.AbstractConfig; +import org.apache.kafka.common.config.ConfigDef; +import org.apache.kafka.common.config.ConfigDef.Importance; +import org.apache.kafka.common.config.ConfigDef.Type; +import org.apache.kafka.common.config.ConfigException; +import org.apache.kafka.common.config.SecurityConfig; +import org.apache.kafka.common.metrics.Sensor; +import org.apache.kafka.common.serialization.Serializer; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.Properties; +import java.util.Set; +import java.util.concurrent.atomic.AtomicInteger; + +import static org.apache.kafka.common.config.ConfigDef.Range.atLeast; +import static org.apache.kafka.common.config.ConfigDef.Range.between; +import static org.apache.kafka.common.config.ConfigDef.ValidString.in; + +/** + * Configuration for the Kafka Producer. Documentation for these configurations can be found in the Kafka documentation + */ +public class ProducerConfig extends AbstractConfig { + + /* + * NOTE: DO NOT CHANGE EITHER CONFIG STRINGS OR THEIR JAVA VARIABLE NAMES AS THESE ARE PART OF THE PUBLIC API AND + * CHANGE WILL BREAK USER CODE. + */ + + private static final ConfigDef CONFIG; + + /** bootstrap.servers */ + public static final String BOOTSTRAP_SERVERS_CONFIG = CommonClientConfigs.BOOTSTRAP_SERVERS_CONFIG; + + /** client.dns.lookup */ + public static final String CLIENT_DNS_LOOKUP_CONFIG = CommonClientConfigs.CLIENT_DNS_LOOKUP_CONFIG; + + /** metadata.max.age.ms */ + public static final String METADATA_MAX_AGE_CONFIG = CommonClientConfigs.METADATA_MAX_AGE_CONFIG; + private static final String METADATA_MAX_AGE_DOC = CommonClientConfigs.METADATA_MAX_AGE_DOC; + + /** metadata.max.idle.ms */ + public static final String METADATA_MAX_IDLE_CONFIG = "metadata.max.idle.ms"; + private static final String METADATA_MAX_IDLE_DOC = + "Controls how long the producer will cache metadata for a topic that's idle. If the elapsed " + + "time since a topic was last produced to exceeds the metadata idle duration, then the topic's " + + "metadata is forgotten and the next access to it will force a metadata fetch request."; + + /** batch.size */ + public static final String BATCH_SIZE_CONFIG = "batch.size"; + private static final String BATCH_SIZE_DOC = "The producer will attempt to batch records together into fewer requests whenever multiple records are being sent" + + " to the same partition. This helps performance on both the client and the server. This configuration controls the " + + "default batch size in bytes. " + + "

" + + "No attempt will be made to batch records larger than this size. " + + "

" + + "Requests sent to brokers will contain multiple batches, one for each partition with data available to be sent. " + + "

" + + "A small batch size will make batching less common and may reduce throughput (a batch size of zero will disable " + + "batching entirely). A very large batch size may use memory a bit more wastefully as we will always allocate a " + + "buffer of the specified batch size in anticipation of additional records." + + "

" + + "Note: This setting gives the upper bound of the batch size to be sent. If we have fewer than this many bytes accumulated " + + "for this partition, we will 'linger' for the linger.ms time waiting for more records to show up. " + + "This linger.ms setting defaults to 0, which means we'll immediately send out a record even the accumulated " + + "batch size is under this batch.size setting."; + + /** acks */ + public static final String ACKS_CONFIG = "acks"; + private static final String ACKS_DOC = "The number of acknowledgments the producer requires the leader to have received before considering a request complete. This controls the " + + " durability of records that are sent. The following settings are allowed: " + + "

    " + + "
  • acks=0 If set to zero then the producer will not wait for any acknowledgment from the server at all. The record will be immediately added to the socket buffer and considered sent. No guarantee can be made that the server has received the record in this case, and the retries configuration will not take effect (as the client won't generally know of any failures). The offset given back for each record will always be set to -1.
  • acks=1 This will mean the leader will write the record to its local log but will respond without awaiting full acknowledgement from all followers. In this case should the leader fail immediately after acknowledging the record but before the followers have replicated it then the record will be lost.
  • acks=all This means the leader will wait for the full set of in-sync replicas to acknowledge the record. This guarantees that the record will not be lost as long as at least one in-sync replica remains alive. This is the strongest available guarantee. This is equivalent to the acks=-1 setting.
"; + + /** linger.ms */ + public static final String LINGER_MS_CONFIG = "linger.ms"; + private static final String LINGER_MS_DOC = "The producer groups together any records that arrive in between request transmissions into a single batched request. " + + "Normally this occurs only under load when records arrive faster than they can be sent out. However in some circumstances the client may want to " + + "reduce the number of requests even under moderate load. This setting accomplishes this by adding a small amount " + + "of artificial delay—that is, rather than immediately sending out a record, the producer will wait for up to " + + "the given delay to allow other records to be sent so that the sends can be batched together. This can be thought " + + "of as analogous to Nagle's algorithm in TCP. This setting gives the upper bound on the delay for batching: once " + + "we get " + BATCH_SIZE_CONFIG + " worth of records for a partition it will be sent immediately regardless of this " + + "setting, however if we have fewer than this many bytes accumulated for this partition we will 'linger' for the " + + "specified time waiting for more records to show up. This setting defaults to 0 (i.e. no delay). Setting " + LINGER_MS_CONFIG + "=5, " + + "for example, would have the effect of reducing the number of requests sent but would add up to 5ms of latency to records sent in the absence of load."; + + /** request.timeout.ms */ + public static final String REQUEST_TIMEOUT_MS_CONFIG = CommonClientConfigs.REQUEST_TIMEOUT_MS_CONFIG; + private static final String REQUEST_TIMEOUT_MS_DOC = CommonClientConfigs.REQUEST_TIMEOUT_MS_DOC + + " This should be larger than replica.lag.time.max.ms (a broker configuration)" + + " to reduce the possibility of message duplication due to unnecessary producer retries."; + + /** delivery.timeout.ms */ + public static final String DELIVERY_TIMEOUT_MS_CONFIG = "delivery.timeout.ms"; + private static final String DELIVERY_TIMEOUT_MS_DOC = "An upper bound on the time to report success or failure " + + "after a call to send() returns. This limits the total time that a record will be delayed " + + "prior to sending, the time to await acknowledgement from the broker (if expected), and the time allowed " + + "for retriable send failures. The producer may report failure to send a record earlier than this config if " + + "either an unrecoverable error is encountered, the retries have been exhausted, " + + "or the record is added to a batch which reached an earlier delivery expiration deadline. " + + "The value of this config should be greater than or equal to the sum of " + REQUEST_TIMEOUT_MS_CONFIG + " " + + "and " + LINGER_MS_CONFIG + "."; + + /** client.id */ + public static final String CLIENT_ID_CONFIG = CommonClientConfigs.CLIENT_ID_CONFIG; + + /** send.buffer.bytes */ + public static final String SEND_BUFFER_CONFIG = CommonClientConfigs.SEND_BUFFER_CONFIG; + + /** receive.buffer.bytes */ + public static final String RECEIVE_BUFFER_CONFIG = CommonClientConfigs.RECEIVE_BUFFER_CONFIG; + + /** max.request.size */ + public static final String MAX_REQUEST_SIZE_CONFIG = "max.request.size"; + private static final String MAX_REQUEST_SIZE_DOC = + "The maximum size of a request in bytes. This setting will limit the number of record " + + "batches the producer will send in a single request to avoid sending huge requests. " + + "This is also effectively a cap on the maximum uncompressed record batch size. 
Note that the server " + + "has its own cap on the record batch size (after compression if compression is enabled) which may be different from this."; + + /** reconnect.backoff.ms */ + public static final String RECONNECT_BACKOFF_MS_CONFIG = CommonClientConfigs.RECONNECT_BACKOFF_MS_CONFIG; + + /** reconnect.backoff.max.ms */ + public static final String RECONNECT_BACKOFF_MAX_MS_CONFIG = CommonClientConfigs.RECONNECT_BACKOFF_MAX_MS_CONFIG; + + /** max.block.ms */ + public static final String MAX_BLOCK_MS_CONFIG = "max.block.ms"; + private static final String MAX_BLOCK_MS_DOC = "The configuration controls how long the KafkaProducer's send(), partitionsFor(), " + + "initTransactions(), sendOffsetsToTransaction(), commitTransaction() " + + "and abortTransaction() methods will block. " + + "For send() this timeout bounds the total time waiting for both metadata fetch and buffer allocation " + + "(blocking in the user-supplied serializers or partitioner is not counted against this timeout). " + + "For partitionsFor() this timeout bounds the time spent waiting for metadata if it is unavailable. " + + "The transaction-related methods always block, but may timeout if " + + "the transaction coordinator could not be discovered or did not respond within the timeout."; + + /** buffer.memory */ + public static final String BUFFER_MEMORY_CONFIG = "buffer.memory"; + private static final String BUFFER_MEMORY_DOC = "The total bytes of memory the producer can use to buffer records waiting to be sent to the server. If records are " + + "sent faster than they can be delivered to the server the producer will block for " + MAX_BLOCK_MS_CONFIG + " after which it will throw an exception." + + "
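Several of the settings documented above interact: batch.size and linger.ms control batching, acks controls durability, and buffer.memory / max.block.ms bound how much and how long send() may buffer. A hedged sketch of wiring them together; the broker address is a placeholder and the values are illustrative rather than recommendations.

```java
import java.util.Properties;

import org.apache.kafka.clients.producer.KafkaProducer;
import org.apache.kafka.clients.producer.ProducerConfig;
import org.apache.kafka.common.serialization.ByteArraySerializer;

public class BatchingConfigSketch {
    public static KafkaProducer<byte[], byte[]> newBatchingProducer() {
        Properties props = new Properties();
        props.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9092"); // placeholder
        props.put(ProducerConfig.BATCH_SIZE_CONFIG, 64 * 1024);               // 64 KB batches
        props.put(ProducerConfig.LINGER_MS_CONFIG, 10L);                      // wait up to 10 ms for a batch to fill
        props.put(ProducerConfig.ACKS_CONFIG, "all");                         // wait for the full ISR
        props.put(ProducerConfig.BUFFER_MEMORY_CONFIG, 64L * 1024 * 1024);    // 64 MB of record buffering
        props.put(ProducerConfig.MAX_BLOCK_MS_CONFIG, 30_000L);               // send() gives up after 30 s without buffer space
        props.put(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG, ByteArraySerializer.class);
        props.put(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG, ByteArraySerializer.class);
        return new KafkaProducer<>(props);
    }
}
```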

" + + "This setting should correspond roughly to the total memory the producer will use, but is not a hard bound since " + + "not all memory the producer uses is used for buffering. Some additional memory will be used for compression (if " + + "compression is enabled) as well as for maintaining in-flight requests."; + + /** retry.backoff.ms */ + public static final String RETRY_BACKOFF_MS_CONFIG = CommonClientConfigs.RETRY_BACKOFF_MS_CONFIG; + + /** compression.type */ + public static final String COMPRESSION_TYPE_CONFIG = "compression.type"; + private static final String COMPRESSION_TYPE_DOC = "The compression type for all data generated by the producer. The default is none (i.e. no compression). Valid " + + " values are none, gzip, snappy, lz4, or zstd. " + + "Compression is of full batches of data, so the efficacy of batching will also impact the compression ratio (more batching means better compression)."; + + /** metrics.sample.window.ms */ + public static final String METRICS_SAMPLE_WINDOW_MS_CONFIG = CommonClientConfigs.METRICS_SAMPLE_WINDOW_MS_CONFIG; + + /** metrics.num.samples */ + public static final String METRICS_NUM_SAMPLES_CONFIG = CommonClientConfigs.METRICS_NUM_SAMPLES_CONFIG; + + /** + * metrics.recording.level + */ + public static final String METRICS_RECORDING_LEVEL_CONFIG = CommonClientConfigs.METRICS_RECORDING_LEVEL_CONFIG; + + /** metric.reporters */ + public static final String METRIC_REPORTER_CLASSES_CONFIG = CommonClientConfigs.METRIC_REPORTER_CLASSES_CONFIG; + + /** max.in.flight.requests.per.connection */ + public static final String MAX_IN_FLIGHT_REQUESTS_PER_CONNECTION = "max.in.flight.requests.per.connection"; + private static final String MAX_IN_FLIGHT_REQUESTS_PER_CONNECTION_DOC = "The maximum number of unacknowledged requests the client will send on a single connection before blocking." + + " Note that if this config is set to be greater than 1 and enable.idempotence is set to false, there is a risk of" + + " message re-ordering after a failed send due to retries (i.e., if retries are enabled)."; + + /** retries */ + public static final String RETRIES_CONFIG = CommonClientConfigs.RETRIES_CONFIG; + private static final String RETRIES_DOC = "Setting a value greater than zero will cause the client to resend any record whose send fails with a potentially transient error." + + " Note that this retry is no different than if the client resent the record upon receiving the error." + + " Allowing retries without setting " + MAX_IN_FLIGHT_REQUESTS_PER_CONNECTION + " to 1 will potentially change the" + + " ordering of records because if two batches are sent to a single partition, and the first fails and is retried but the second" + + " succeeds, then the records in the second batch may appear first. Note additionally that produce requests will be" + + " failed before the number of retries has been exhausted if the timeout configured by" + + " " + DELIVERY_TIMEOUT_MS_CONFIG + " expires first before successful acknowledgement. 
Users should generally" + + " prefer to leave this config unset and instead use " + DELIVERY_TIMEOUT_MS_CONFIG + " to control" + + " retry behavior."; + + /** key.serializer */ + public static final String KEY_SERIALIZER_CLASS_CONFIG = "key.serializer"; + public static final String KEY_SERIALIZER_CLASS_DOC = "Serializer class for key that implements the org.apache.kafka.common.serialization.Serializer interface."; + + /** value.serializer */ + public static final String VALUE_SERIALIZER_CLASS_CONFIG = "value.serializer"; + public static final String VALUE_SERIALIZER_CLASS_DOC = "Serializer class for value that implements the org.apache.kafka.common.serialization.Serializer interface."; + + /** socket.connection.setup.timeout.ms */ + public static final String SOCKET_CONNECTION_SETUP_TIMEOUT_MS_CONFIG = CommonClientConfigs.SOCKET_CONNECTION_SETUP_TIMEOUT_MS_CONFIG; + + /** socket.connection.setup.timeout.max.ms */ + public static final String SOCKET_CONNECTION_SETUP_TIMEOUT_MAX_MS_CONFIG = CommonClientConfigs.SOCKET_CONNECTION_SETUP_TIMEOUT_MAX_MS_CONFIG; + + /** connections.max.idle.ms */ + public static final String CONNECTIONS_MAX_IDLE_MS_CONFIG = CommonClientConfigs.CONNECTIONS_MAX_IDLE_MS_CONFIG; + + /** partitioner.class */ + public static final String PARTITIONER_CLASS_CONFIG = "partitioner.class"; + private static final String PARTITIONER_CLASS_DOC = "A class to use to determine which partition to be send to when produce the records. Available options are:" + + "

    " + + "
  • org.apache.kafka.clients.producer.internals.DefaultPartitioner: The default partitioner. This strategy will try sticking to a partition until the batch is full, or linger.ms is up. It works with the strategy:
      • If no partition is specified but a key is present, choose a partition based on a hash of the key
      • If no partition or key is present, choose the sticky partition that changes when the batch is full, or linger.ms is up.
  • org.apache.kafka.clients.producer.RoundRobinPartitioner: With this partitioning strategy, each record in a series of consecutive records is sent to a different partition (whether or not a key is provided), until we run out of partitions and start over again. Note: there is a known issue that causes uneven distribution when a new batch is created; see KAFKA-9965 for more detail.
  • org.apache.kafka.clients.producer.UniformStickyPartitioner: This partitioning strategy will try sticking to a partition (whether or not a key is provided) until the batch is full, or linger.ms is up.

Implementing the org.apache.kafka.clients.producer.Partitioner interface allows you to plug in a custom partitioner."; + + /** interceptor.classes */ + public static final String INTERCEPTOR_CLASSES_CONFIG = "interceptor.classes"; + public static final String INTERCEPTOR_CLASSES_DOC = "A list of classes to use as interceptors. " + + "Implementing the org.apache.kafka.clients.producer.ProducerInterceptor interface allows you to intercept (and possibly mutate) the records " + + "received by the producer before they are published to the Kafka cluster. By default, there are no interceptors."; + + /** enable.idempotence */ + public static final String ENABLE_IDEMPOTENCE_CONFIG = "enable.idempotence"; + public static final String ENABLE_IDEMPOTENCE_DOC = "When set to 'true', the producer will ensure that exactly one copy of each message is written in the stream. If 'false', producer " + + "retries due to broker failures, etc., may write duplicates of the retried message in the stream. " + + "Note that enabling idempotence requires " + MAX_IN_FLIGHT_REQUESTS_PER_CONNECTION + " to be less than or equal to 5 " + + "(with message ordering preserved for any allowable value), " + RETRIES_CONFIG + " to be greater than 0, and " + + ACKS_CONFIG + " must be 'all'. If these values are not explicitly set by the user, suitable values will be chosen. If incompatible " + + "values are set, a ConfigException will be thrown."; + + /** transaction.timeout.ms */ + public static final String TRANSACTION_TIMEOUT_CONFIG = "transaction.timeout.ms"; + public static final String TRANSACTION_TIMEOUT_DOC = "The maximum amount of time in ms that the transaction coordinator will wait for a transaction status update from the producer before proactively aborting the ongoing transaction." + + "If this value is larger than the transaction.max.timeout.ms setting in the broker, the request will fail with a InvalidTxnTimeoutException error."; + + /** transactional.id */ + public static final String TRANSACTIONAL_ID_CONFIG = "transactional.id"; + public static final String TRANSACTIONAL_ID_DOC = "The TransactionalId to use for transactional delivery. This enables reliability semantics which span multiple producer sessions since it allows the client to guarantee that transactions using the same TransactionalId have been completed prior to starting any new transactions. If no TransactionalId is provided, then the producer is limited to idempotent delivery. " + + "If a TransactionalId is configured, enable.idempotence is implied. " + + "By default the TransactionId is not configured, which means transactions cannot be used. 
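partitioner.class, interceptor.classes and enable.idempotence, described above, are all plugged in through the same config map. A small sketch under stated assumptions: RoundRobinPartitioner is one of the partitioners added in this patch, while com.example.LoggingProducerInterceptor is a hypothetical user class that would have to exist on the classpath.

```java
import java.util.Properties;

import org.apache.kafka.clients.producer.ProducerConfig;
import org.apache.kafka.clients.producer.RoundRobinPartitioner;

public class PluginConfigSketch {
    public static Properties pluginProps() {
        Properties props = new Properties();
        props.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9092");            // placeholder
        props.put(ProducerConfig.ENABLE_IDEMPOTENCE_CONFIG, true);                       // implies acks=all and retries > 0
        props.put(ProducerConfig.PARTITIONER_CLASS_CONFIG, RoundRobinPartitioner.class); // spread records evenly across partitions
        // hypothetical interceptor class; exceptions it throws are logged, not propagated
        props.put(ProducerConfig.INTERCEPTOR_CLASSES_CONFIG, "com.example.LoggingProducerInterceptor");
        return props;
    }
}
```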
" + + "Note that, by default, transactions require a cluster of at least three brokers which is the recommended setting for production; for development you can change this, by adjusting broker setting transaction.state.log.replication.factor."; + + /** + * security.providers + */ + public static final String SECURITY_PROVIDERS_CONFIG = SecurityConfig.SECURITY_PROVIDERS_CONFIG; + private static final String SECURITY_PROVIDERS_DOC = SecurityConfig.SECURITY_PROVIDERS_DOC; + + private static final AtomicInteger PRODUCER_CLIENT_ID_SEQUENCE = new AtomicInteger(1); + + static { + CONFIG = new ConfigDef().define(BOOTSTRAP_SERVERS_CONFIG, Type.LIST, Collections.emptyList(), new ConfigDef.NonNullValidator(), Importance.HIGH, CommonClientConfigs.BOOTSTRAP_SERVERS_DOC) + .define(CLIENT_DNS_LOOKUP_CONFIG, + Type.STRING, + ClientDnsLookup.USE_ALL_DNS_IPS.toString(), + in(ClientDnsLookup.USE_ALL_DNS_IPS.toString(), + ClientDnsLookup.RESOLVE_CANONICAL_BOOTSTRAP_SERVERS_ONLY.toString()), + Importance.MEDIUM, + CommonClientConfigs.CLIENT_DNS_LOOKUP_DOC) + .define(BUFFER_MEMORY_CONFIG, Type.LONG, 32 * 1024 * 1024L, atLeast(0L), Importance.HIGH, BUFFER_MEMORY_DOC) + .define(RETRIES_CONFIG, Type.INT, Integer.MAX_VALUE, between(0, Integer.MAX_VALUE), Importance.HIGH, RETRIES_DOC) + .define(ACKS_CONFIG, + Type.STRING, + "all", + in("all", "-1", "0", "1"), + Importance.LOW, + ACKS_DOC) + .define(COMPRESSION_TYPE_CONFIG, Type.STRING, "none", Importance.HIGH, COMPRESSION_TYPE_DOC) + .define(BATCH_SIZE_CONFIG, Type.INT, 16384, atLeast(0), Importance.MEDIUM, BATCH_SIZE_DOC) + .define(LINGER_MS_CONFIG, Type.LONG, 0, atLeast(0), Importance.MEDIUM, LINGER_MS_DOC) + .define(DELIVERY_TIMEOUT_MS_CONFIG, Type.INT, 120 * 1000, atLeast(0), Importance.MEDIUM, DELIVERY_TIMEOUT_MS_DOC) + .define(CLIENT_ID_CONFIG, Type.STRING, "", Importance.MEDIUM, CommonClientConfigs.CLIENT_ID_DOC) + .define(SEND_BUFFER_CONFIG, Type.INT, 128 * 1024, atLeast(CommonClientConfigs.SEND_BUFFER_LOWER_BOUND), Importance.MEDIUM, CommonClientConfigs.SEND_BUFFER_DOC) + .define(RECEIVE_BUFFER_CONFIG, Type.INT, 32 * 1024, atLeast(CommonClientConfigs.RECEIVE_BUFFER_LOWER_BOUND), Importance.MEDIUM, CommonClientConfigs.RECEIVE_BUFFER_DOC) + .define(MAX_REQUEST_SIZE_CONFIG, + Type.INT, + 1024 * 1024, + atLeast(0), + Importance.MEDIUM, + MAX_REQUEST_SIZE_DOC) + .define(RECONNECT_BACKOFF_MS_CONFIG, Type.LONG, 50L, atLeast(0L), Importance.LOW, CommonClientConfigs.RECONNECT_BACKOFF_MS_DOC) + .define(RECONNECT_BACKOFF_MAX_MS_CONFIG, Type.LONG, 1000L, atLeast(0L), Importance.LOW, CommonClientConfigs.RECONNECT_BACKOFF_MAX_MS_DOC) + .define(RETRY_BACKOFF_MS_CONFIG, Type.LONG, 100L, atLeast(0L), Importance.LOW, CommonClientConfigs.RETRY_BACKOFF_MS_DOC) + .define(MAX_BLOCK_MS_CONFIG, + Type.LONG, + 60 * 1000, + atLeast(0), + Importance.MEDIUM, + MAX_BLOCK_MS_DOC) + .define(REQUEST_TIMEOUT_MS_CONFIG, + Type.INT, + 30 * 1000, + atLeast(0), + Importance.MEDIUM, + REQUEST_TIMEOUT_MS_DOC) + .define(METADATA_MAX_AGE_CONFIG, Type.LONG, 5 * 60 * 1000, atLeast(0), Importance.LOW, METADATA_MAX_AGE_DOC) + .define(METADATA_MAX_IDLE_CONFIG, + Type.LONG, + 5 * 60 * 1000, + atLeast(5000), + Importance.LOW, + METADATA_MAX_IDLE_DOC) + .define(METRICS_SAMPLE_WINDOW_MS_CONFIG, + Type.LONG, + 30000, + atLeast(0), + Importance.LOW, + CommonClientConfigs.METRICS_SAMPLE_WINDOW_MS_DOC) + .define(METRICS_NUM_SAMPLES_CONFIG, Type.INT, 2, atLeast(1), Importance.LOW, CommonClientConfigs.METRICS_NUM_SAMPLES_DOC) + .define(METRICS_RECORDING_LEVEL_CONFIG, + Type.STRING, + 
Sensor.RecordingLevel.INFO.toString(), + in(Sensor.RecordingLevel.INFO.toString(), Sensor.RecordingLevel.DEBUG.toString(), Sensor.RecordingLevel.TRACE.toString()), + Importance.LOW, + CommonClientConfigs.METRICS_RECORDING_LEVEL_DOC) + .define(METRIC_REPORTER_CLASSES_CONFIG, + Type.LIST, + Collections.emptyList(), + new ConfigDef.NonNullValidator(), + Importance.LOW, + CommonClientConfigs.METRIC_REPORTER_CLASSES_DOC) + .define(MAX_IN_FLIGHT_REQUESTS_PER_CONNECTION, + Type.INT, + 5, + atLeast(1), + Importance.LOW, + MAX_IN_FLIGHT_REQUESTS_PER_CONNECTION_DOC) + .define(KEY_SERIALIZER_CLASS_CONFIG, + Type.CLASS, + Importance.HIGH, + KEY_SERIALIZER_CLASS_DOC) + .define(VALUE_SERIALIZER_CLASS_CONFIG, + Type.CLASS, + Importance.HIGH, + VALUE_SERIALIZER_CLASS_DOC) + .define(SOCKET_CONNECTION_SETUP_TIMEOUT_MS_CONFIG, + Type.LONG, + CommonClientConfigs.DEFAULT_SOCKET_CONNECTION_SETUP_TIMEOUT_MS, + Importance.MEDIUM, + CommonClientConfigs.SOCKET_CONNECTION_SETUP_TIMEOUT_MS_DOC) + .define(SOCKET_CONNECTION_SETUP_TIMEOUT_MAX_MS_CONFIG, + Type.LONG, + CommonClientConfigs.DEFAULT_SOCKET_CONNECTION_SETUP_TIMEOUT_MAX_MS, + Importance.MEDIUM, + CommonClientConfigs.SOCKET_CONNECTION_SETUP_TIMEOUT_MAX_MS_DOC) + /* default is set to be a bit lower than the server default (10 min), to avoid both client and server closing connection at same time */ + .define(CONNECTIONS_MAX_IDLE_MS_CONFIG, + Type.LONG, + 9 * 60 * 1000, + Importance.MEDIUM, + CommonClientConfigs.CONNECTIONS_MAX_IDLE_MS_DOC) + .define(PARTITIONER_CLASS_CONFIG, + Type.CLASS, + DefaultPartitioner.class, + Importance.MEDIUM, PARTITIONER_CLASS_DOC) + .define(INTERCEPTOR_CLASSES_CONFIG, + Type.LIST, + Collections.emptyList(), + new ConfigDef.NonNullValidator(), + Importance.LOW, + INTERCEPTOR_CLASSES_DOC) + .define(CommonClientConfigs.SECURITY_PROTOCOL_CONFIG, + Type.STRING, + CommonClientConfigs.DEFAULT_SECURITY_PROTOCOL, + Importance.MEDIUM, + CommonClientConfigs.SECURITY_PROTOCOL_DOC) + .define(SECURITY_PROVIDERS_CONFIG, + Type.STRING, + null, + Importance.LOW, + SECURITY_PROVIDERS_DOC) + .withClientSslSupport() + .withClientSaslSupport() + .define(ENABLE_IDEMPOTENCE_CONFIG, + Type.BOOLEAN, + true, + Importance.LOW, + ENABLE_IDEMPOTENCE_DOC) + .define(TRANSACTION_TIMEOUT_CONFIG, + Type.INT, + 60000, + Importance.LOW, + TRANSACTION_TIMEOUT_DOC) + .define(TRANSACTIONAL_ID_CONFIG, + Type.STRING, + null, + new ConfigDef.NonEmptyString(), + Importance.LOW, + TRANSACTIONAL_ID_DOC); + } + + @Override + protected Map postProcessParsedConfig(final Map parsedValues) { + Map refinedConfigs = CommonClientConfigs.postProcessReconnectBackoffConfigs(this, parsedValues); + maybeOverrideEnableIdempotence(refinedConfigs); + maybeOverrideClientId(refinedConfigs); + maybeOverrideAcksAndRetries(refinedConfigs); + return refinedConfigs; + } + + private void maybeOverrideClientId(final Map configs) { + String refinedClientId; + boolean userConfiguredClientId = this.originals().containsKey(CLIENT_ID_CONFIG); + if (userConfiguredClientId) { + refinedClientId = this.getString(CLIENT_ID_CONFIG); + } else { + String transactionalId = this.getString(TRANSACTIONAL_ID_CONFIG); + refinedClientId = "producer-" + (transactionalId != null ? 
transactionalId : PRODUCER_CLIENT_ID_SEQUENCE.getAndIncrement()); + } + configs.put(CLIENT_ID_CONFIG, refinedClientId); + } + + private void maybeOverrideEnableIdempotence(final Map configs) { + boolean userConfiguredIdempotence = this.originals().containsKey(ENABLE_IDEMPOTENCE_CONFIG); + boolean userConfiguredTransactions = this.originals().containsKey(TRANSACTIONAL_ID_CONFIG); + + if (userConfiguredTransactions && !userConfiguredIdempotence) { + configs.put(ENABLE_IDEMPOTENCE_CONFIG, true); + } + } + + private void maybeOverrideAcksAndRetries(final Map configs) { + final String acksStr = parseAcks(this.getString(ACKS_CONFIG)); + configs.put(ACKS_CONFIG, acksStr); + // For idempotence producers, values for `RETRIES_CONFIG` and `ACKS_CONFIG` might need to be overridden. + if (idempotenceEnabled()) { + boolean userConfiguredRetries = this.originals().containsKey(RETRIES_CONFIG); + if (this.getInt(RETRIES_CONFIG) == 0) { + throw new ConfigException("Must set " + ProducerConfig.RETRIES_CONFIG + " to non-zero when using the idempotent producer."); + } + configs.put(RETRIES_CONFIG, userConfiguredRetries ? this.getInt(RETRIES_CONFIG) : Integer.MAX_VALUE); + + boolean userConfiguredAcks = this.originals().containsKey(ACKS_CONFIG); + final short acks = Short.valueOf(acksStr); + if (userConfiguredAcks && acks != (short) -1) { + throw new ConfigException("Must set " + ACKS_CONFIG + " to all in order to use the idempotent " + + "producer. Otherwise we cannot guarantee idempotence."); + } + configs.put(ACKS_CONFIG, "-1"); + } + } + + private static String parseAcks(String acksString) { + try { + return acksString.trim().equalsIgnoreCase("all") ? "-1" : Short.parseShort(acksString.trim()) + ""; + } catch (NumberFormatException e) { + throw new ConfigException("Invalid configuration value for 'acks': " + acksString); + } + } + + static Map appendSerializerToConfig(Map configs, + Serializer keySerializer, + Serializer valueSerializer) { + Map newConfigs = new HashMap<>(configs); + if (keySerializer != null) + newConfigs.put(KEY_SERIALIZER_CLASS_CONFIG, keySerializer.getClass()); + if (valueSerializer != null) + newConfigs.put(VALUE_SERIALIZER_CLASS_CONFIG, valueSerializer.getClass()); + return newConfigs; + } + + public ProducerConfig(Properties props) { + super(CONFIG, props); + } + + public ProducerConfig(Map props) { + super(CONFIG, props); + } + + boolean idempotenceEnabled() { + boolean userConfiguredIdempotence = this.originals().containsKey(ENABLE_IDEMPOTENCE_CONFIG); + boolean userConfiguredTransactions = this.originals().containsKey(TRANSACTIONAL_ID_CONFIG); + boolean idempotenceEnabled = userConfiguredIdempotence && this.getBoolean(ENABLE_IDEMPOTENCE_CONFIG); + + if (!idempotenceEnabled && userConfiguredIdempotence && userConfiguredTransactions) + throw new ConfigException("Cannot set a " + ProducerConfig.TRANSACTIONAL_ID_CONFIG + " without also enabling idempotence."); + return userConfiguredTransactions || idempotenceEnabled; + } + + ProducerConfig(Map props, boolean doLog) { + super(CONFIG, props, doLog); + } + + public static Set configNames() { + return CONFIG.names(); + } + + public static ConfigDef configDef() { + return new ConfigDef(CONFIG); + } + + public static void main(String[] args) { + System.out.println(CONFIG.toHtml(4, config -> "producerconfigs_" + config)); + } + +} diff --git a/clients/src/main/java/org/apache/kafka/clients/producer/ProducerInterceptor.java b/clients/src/main/java/org/apache/kafka/clients/producer/ProducerInterceptor.java new file mode 100644 index 
0000000..8f89d6f --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/producer/ProducerInterceptor.java @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients.producer; + +import org.apache.kafka.common.Configurable; + +/** + * A plugin interface that allows you to intercept (and possibly mutate) the records received by the producer before + * they are published to the Kafka cluster. + *

+ * This class will get producer config properties via configure() method, including clientId assigned + * by KafkaProducer if not specified in the producer config. The interceptor implementation needs to be aware that it will be + * sharing producer config namespace with other interceptors and serializers, and ensure that there are no conflicts. + *

+ * Exceptions thrown by ProducerInterceptor methods will be caught, logged, but not propagated further. As a result, if + * the user configures the interceptor with the wrong key and value type parameters, the producer will not throw an exception, + * just log the errors. + *

+ * ProducerInterceptor callbacks may be called from multiple threads. Interceptor implementation must ensure thread-safety, if needed. + *

+ * Implement {@link org.apache.kafka.common.ClusterResourceListener} to receive cluster metadata once it's available. Please see the class documentation for ClusterResourceListener for more information. + */ +public interface ProducerInterceptor extends Configurable { + /** + * This is called from {@link org.apache.kafka.clients.producer.KafkaProducer#send(ProducerRecord)} and + * {@link org.apache.kafka.clients.producer.KafkaProducer#send(ProducerRecord, Callback)} methods, before key and value + * get serialized and partition is assigned (if partition is not specified in ProducerRecord). + *

+ * This method is allowed to modify the record, in which case, the new record will be returned. The implication of modifying + * key/value is that partition assignment (if not specified in ProducerRecord) will be done based on modified key/value, + * not key/value from the client. Consequently, key and value transformation done in onSend() needs to be consistent: + * same key and value should mutate to the same (modified) key and value. Otherwise, log compaction would not work + * as expected. + *

+ * Similarly, it is up to interceptor implementation to ensure that correct topic/partition is returned in ProducerRecord. + * Most often, it should be the same topic/partition from 'record'. + *

+ * Any exception thrown by this method will be caught by the caller and logged, but not propagated further. + *

+ * Since the producer may run multiple interceptors, a particular interceptor's onSend() callback will be called in the order + * specified by {@link org.apache.kafka.clients.producer.ProducerConfig#INTERCEPTOR_CLASSES_CONFIG}. The first interceptor + * in the list gets the record passed from the client, the following interceptor will be passed the record returned by the + * previous interceptor, and so on. Since interceptors are allowed to modify records, interceptors may potentially get + * the record already modified by other interceptors. However, building a pipeline of mutable interceptors that depend on the output + * of the previous interceptor is discouraged, because of potential side-effects caused by interceptors potentially failing to + * modify the record and throwing an exception. If one of the interceptors in the list throws an exception from onSend(), the exception + * is caught, logged, and the next interceptor is called with the record returned by the last successful interceptor in the list, + * or otherwise the client. + * + * @param record the record from client or the record returned by the previous interceptor in the chain of interceptors. + * @return producer record to send to topic/partition + */ + ProducerRecord onSend(ProducerRecord record); + + /** + * This method is called when the record sent to the server has been acknowledged, or when sending the record fails before + * it gets sent to the server. + *

+ * This method is generally called just before the user callback is called, and in additional cases when KafkaProducer.send() + * throws an exception. + *

+ * Any exception thrown by this method will be ignored by the caller. + *

+ * This method will generally execute in the background I/O thread, so the implementation should be reasonably fast. + * Otherwise, sending of messages from other threads could be delayed. + * + * @param metadata The metadata for the record that was sent (i.e. the partition and offset). + * If an error occurred, metadata will contain only valid topic and maybe + * partition. If partition is not given in ProducerRecord and an error occurs + * before partition gets assigned, then partition will be set to RecordMetadata.NO_PARTITION. + * The metadata may be null if the client passed null record to + * {@link org.apache.kafka.clients.producer.KafkaProducer#send(ProducerRecord)}. + * @param exception The exception thrown during processing of this record. Null if no error occurred. + */ + void onAcknowledgement(RecordMetadata metadata, Exception exception); + + /** + * This is called when interceptor is closed + */ + void close(); +} diff --git a/clients/src/main/java/org/apache/kafka/clients/producer/ProducerRecord.java b/clients/src/main/java/org/apache/kafka/clients/producer/ProducerRecord.java new file mode 100644 index 0000000..0fa37dc --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/producer/ProducerRecord.java @@ -0,0 +1,225 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients.producer; + +import org.apache.kafka.common.header.Header; +import org.apache.kafka.common.header.Headers; +import org.apache.kafka.common.header.internals.RecordHeaders; + +import java.util.Objects; + +/** + * A key/value pair to be sent to Kafka. This consists of a topic name to which the record is being sent, an optional + * partition number, and an optional key and value. + *
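The ProducerInterceptor contract described above (onSend may rewrite the record, onAcknowledgement runs on the I/O thread, exceptions are swallowed) can be met by a small stateless class. A minimal sketch, not part of this patch; the class name and the logging behaviour are illustrative.

```java
import java.util.Map;

import org.apache.kafka.clients.producer.ProducerInterceptor;
import org.apache.kafka.clients.producer.ProducerRecord;
import org.apache.kafka.clients.producer.RecordMetadata;

// Illustrative only: observes acknowledgements and failures without mutating records.
public class LoggingProducerInterceptor<K, V> implements ProducerInterceptor<K, V> {

    @Override
    public void configure(Map<String, ?> configs) {
        // the producer configs (including the assigned client.id) are visible here
    }

    @Override
    public ProducerRecord<K, V> onSend(ProducerRecord<K, V> record) {
        // returning the record unchanged keeps partitioning based on the original key/value
        return record;
    }

    @Override
    public void onAcknowledgement(RecordMetadata metadata, Exception exception) {
        // runs on the background I/O thread: keep this cheap and thread-safe
        if (exception != null)
            System.err.println("send failed: " + exception.getMessage());
    }

    @Override
    public void close() { }
}
```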

+ * If a valid partition number is specified that partition will be used when sending the record. If no partition is + * specified but a key is present a partition will be chosen using a hash of the key. If neither key nor partition is + * present a partition will be assigned in a round-robin fashion. + *
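The partition-selection rules above follow directly from which ProducerRecord constructor is used. A brief sketch of the common variants; the topic, key and header names are placeholders.

```java
import java.nio.charset.StandardCharsets;

import org.apache.kafka.clients.producer.ProducerRecord;
import org.apache.kafka.common.header.internals.RecordHeader;
import org.apache.kafka.common.header.internals.RecordHeaders;

public class ProducerRecordSketch {
    public static void main(String[] args) {
        // value only: partition chosen by the sticky/round-robin logic, timestamp assigned by the producer
        ProducerRecord<String, String> valueOnly = new ProducerRecord<>("my-topic", "payload");

        // keyed: partition derived from the key hash by the default partitioner
        ProducerRecord<String, String> keyed = new ProducerRecord<>("my-topic", "order-42", "payload");

        // fully specified: explicit partition, timestamp and headers
        RecordHeaders headers = new RecordHeaders();
        headers.add(new RecordHeader("trace-id", "abc123".getBytes(StandardCharsets.UTF_8)));
        ProducerRecord<String, String> detailed =
            new ProducerRecord<>("my-topic", 0, System.currentTimeMillis(), "order-42", "payload", headers);

        System.out.println(valueOnly + "\n" + keyed + "\n" + detailed);
    }
}
```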

+ * The record also has an associated timestamp. If the user did not provide a timestamp, the producer will stamp the + * record with its current time. The timestamp eventually used by Kafka depends on the timestamp type configured for + * the topic. + *

  • + * If the topic is configured to use {@link org.apache.kafka.common.record.TimestampType#CREATE_TIME CreateTime}, + * the timestamp in the producer record will be used by the broker. + *
  • + * If the topic is configured to use {@link org.apache.kafka.common.record.TimestampType#LOG_APPEND_TIME LogAppendTime}, + * the timestamp in the producer record will be overwritten by the broker with the broker local time when it appends the + * message to its log. + *

    + * In either of the cases above, the timestamp that has actually been used will be returned to user in + * {@link RecordMetadata} + */ +public class ProducerRecord { + + private final String topic; + private final Integer partition; + private final Headers headers; + private final K key; + private final V value; + private final Long timestamp; + + /** + * Creates a record with a specified timestamp to be sent to a specified topic and partition + * + * @param topic The topic the record will be appended to + * @param partition The partition to which the record should be sent + * @param timestamp The timestamp of the record, in milliseconds since epoch. If null, the producer will assign + * the timestamp using System.currentTimeMillis(). + * @param key The key that will be included in the record + * @param value The record contents + * @param headers the headers that will be included in the record + */ + public ProducerRecord(String topic, Integer partition, Long timestamp, K key, V value, Iterable

    headers) { + if (topic == null) + throw new IllegalArgumentException("Topic cannot be null."); + if (timestamp != null && timestamp < 0) + throw new IllegalArgumentException( + String.format("Invalid timestamp: %d. Timestamp should always be non-negative or null.", timestamp)); + if (partition != null && partition < 0) + throw new IllegalArgumentException( + String.format("Invalid partition: %d. Partition number should always be non-negative or null.", partition)); + this.topic = topic; + this.partition = partition; + this.key = key; + this.value = value; + this.timestamp = timestamp; + this.headers = new RecordHeaders(headers); + } + + /** + * Creates a record with a specified timestamp to be sent to a specified topic and partition + * + * @param topic The topic the record will be appended to + * @param partition The partition to which the record should be sent + * @param timestamp The timestamp of the record, in milliseconds since epoch. If null, the producer will assign the + * timestamp using System.currentTimeMillis(). + * @param key The key that will be included in the record + * @param value The record contents + */ + public ProducerRecord(String topic, Integer partition, Long timestamp, K key, V value) { + this(topic, partition, timestamp, key, value, null); + } + + /** + * Creates a record to be sent to a specified topic and partition + * + * @param topic The topic the record will be appended to + * @param partition The partition to which the record should be sent + * @param key The key that will be included in the record + * @param value The record contents + * @param headers The headers that will be included in the record + */ + public ProducerRecord(String topic, Integer partition, K key, V value, Iterable
    headers) { + this(topic, partition, null, key, value, headers); + } + + /** + * Creates a record to be sent to a specified topic and partition + * + * @param topic The topic the record will be appended to + * @param partition The partition to which the record should be sent + * @param key The key that will be included in the record + * @param value The record contents + */ + public ProducerRecord(String topic, Integer partition, K key, V value) { + this(topic, partition, null, key, value, null); + } + + /** + * Create a record to be sent to Kafka + * + * @param topic The topic the record will be appended to + * @param key The key that will be included in the record + * @param value The record contents + */ + public ProducerRecord(String topic, K key, V value) { + this(topic, null, null, key, value, null); + } + + /** + * Create a record with no key + * + * @param topic The topic this record should be sent to + * @param value The record contents + */ + public ProducerRecord(String topic, V value) { + this(topic, null, null, null, value, null); + } + + /** + * @return The topic this record is being sent to + */ + public String topic() { + return topic; + } + + /** + * @return The headers + */ + public Headers headers() { + return headers; + } + + /** + * @return The key (or null if no key is specified) + */ + public K key() { + return key; + } + + /** + * @return The value + */ + public V value() { + return value; + } + + /** + * @return The timestamp, which is in milliseconds since epoch. + */ + public Long timestamp() { + return timestamp; + } + + /** + * @return The partition to which the record will be sent (or null if no partition was specified) + */ + public Integer partition() { + return partition; + } + + @Override + public String toString() { + String headers = this.headers == null ? "null" : this.headers.toString(); + String key = this.key == null ? "null" : this.key.toString(); + String value = this.value == null ? "null" : this.value.toString(); + String timestamp = this.timestamp == null ? "null" : this.timestamp.toString(); + return "ProducerRecord(topic=" + topic + ", partition=" + partition + ", headers=" + headers + ", key=" + key + ", value=" + value + + ", timestamp=" + timestamp + ")"; + } + + @Override + public boolean equals(Object o) { + if (this == o) + return true; + else if (!(o instanceof ProducerRecord)) + return false; + + ProducerRecord that = (ProducerRecord) o; + + return Objects.equals(key, that.key) && + Objects.equals(partition, that.partition) && + Objects.equals(topic, that.topic) && + Objects.equals(headers, that.headers) && + Objects.equals(value, that.value) && + Objects.equals(timestamp, that.timestamp); + } + + @Override + public int hashCode() { + int result = topic != null ? topic.hashCode() : 0; + result = 31 * result + (partition != null ? partition.hashCode() : 0); + result = 31 * result + (headers != null ? headers.hashCode() : 0); + result = 31 * result + (key != null ? key.hashCode() : 0); + result = 31 * result + (value != null ? value.hashCode() : 0); + result = 31 * result + (timestamp != null ? 
timestamp.hashCode() : 0); + return result; + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/producer/RecordMetadata.java b/clients/src/main/java/org/apache/kafka/clients/producer/RecordMetadata.java new file mode 100644 index 0000000..8efc4c4 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/producer/RecordMetadata.java @@ -0,0 +1,142 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients.producer; + +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.record.RecordBatch; +import org.apache.kafka.common.requests.ProduceResponse; + +/** + * The metadata for a record that has been acknowledged by the server + */ +public final class RecordMetadata { + + /** + * Partition value for record without partition assigned + */ + public static final int UNKNOWN_PARTITION = -1; + + private final long offset; + // The timestamp of the message. + // If LogAppendTime is used for the topic, the timestamp will be the timestamp returned by the broker. + // If CreateTime is used for the topic, the timestamp is the timestamp in the corresponding ProducerRecord if the + // user provided one. Otherwise, it will be the producer local time when the producer record was handed to the + // producer. + private final long timestamp; + private final int serializedKeySize; + private final int serializedValueSize; + private final TopicPartition topicPartition; + + /** + * Creates a new instance with the provided parameters. + */ + public RecordMetadata(TopicPartition topicPartition, long baseOffset, int batchIndex, long timestamp, + int serializedKeySize, int serializedValueSize) { + // ignore the batchIndex if the base offset is -1, since this indicates the offset is unknown + this.offset = baseOffset == -1 ? baseOffset : baseOffset + batchIndex; + this.timestamp = timestamp; + this.serializedKeySize = serializedKeySize; + this.serializedValueSize = serializedValueSize; + this.topicPartition = topicPartition; + } + + /** + * Creates a new instance with the provided parameters. + * + * @deprecated use constructor without `checksum` parameter. This constructor will be removed in + * Apache Kafka 4.0 (deprecated since 3.0). 
+ */ + @Deprecated + public RecordMetadata(TopicPartition topicPartition, long baseOffset, long batchIndex, long timestamp, + Long checksum, int serializedKeySize, int serializedValueSize) { + this(topicPartition, baseOffset, batchIndexToInt(batchIndex), timestamp, serializedKeySize, serializedValueSize); + } + + private static int batchIndexToInt(long batchIndex) { + if (batchIndex > Integer.MAX_VALUE) + throw new IllegalArgumentException("batchIndex is larger than Integer.MAX_VALUE: " + batchIndex); + return (int) batchIndex; + } + + /** + * Indicates whether the record metadata includes the offset. + * @return true if the offset is included in the metadata, false otherwise. + */ + public boolean hasOffset() { + return this.offset != ProduceResponse.INVALID_OFFSET; + } + + /** + * The offset of the record in the topic/partition. + * @return the offset of the record, or -1 if {{@link #hasOffset()}} returns false. + */ + public long offset() { + return this.offset; + } + + /** + * Indicates whether the record metadata includes the timestamp. + * @return true if a valid timestamp exists, false otherwise. + */ + public boolean hasTimestamp() { + return this.timestamp != RecordBatch.NO_TIMESTAMP; + } + + /** + * The timestamp of the record in the topic/partition. + * + * @return the timestamp of the record, or -1 if the {{@link #hasTimestamp()}} returns false. + */ + public long timestamp() { + return this.timestamp; + } + + /** + * The size of the serialized, uncompressed key in bytes. If key is null, the returned size + * is -1. + */ + public int serializedKeySize() { + return this.serializedKeySize; + } + + /** + * The size of the serialized, uncompressed value in bytes. If value is null, the returned + * size is -1. + */ + public int serializedValueSize() { + return this.serializedValueSize; + } + + /** + * The topic the record was appended to + */ + public String topic() { + return this.topicPartition.topic(); + } + + /** + * The partition the record was sent to + */ + public int partition() { + return this.topicPartition.partition(); + } + + @Override + public String toString() { + return topicPartition.toString() + "@" + offset; + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/producer/RoundRobinPartitioner.java b/clients/src/main/java/org/apache/kafka/clients/producer/RoundRobinPartitioner.java new file mode 100644 index 0000000..80c4725 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/producer/RoundRobinPartitioner.java @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
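The RecordMetadata accessors above are typically read inside a send callback. A hedged sketch, assuming a producer configured elsewhere and a placeholder topic.

```java
import org.apache.kafka.clients.producer.Callback;
import org.apache.kafka.clients.producer.Producer;
import org.apache.kafka.clients.producer.ProducerRecord;
import org.apache.kafka.clients.producer.RecordMetadata;

public class SendCallbackSketch {
    // 'producer' is assumed to be configured elsewhere
    static void sendWithCallback(Producer<String, String> producer) {
        ProducerRecord<String, String> record = new ProducerRecord<>("my-topic", "key", "value");
        producer.send(record, new Callback() {
            @Override
            public void onCompletion(RecordMetadata metadata, Exception exception) {
                if (exception != null) {
                    System.err.println("send failed: " + exception.getMessage());
                } else {
                    // offset/timestamp may be absent, e.g. when acks=0
                    long offset = metadata.hasOffset() ? metadata.offset() : -1L;
                    long ts = metadata.hasTimestamp() ? metadata.timestamp() : -1L;
                    System.out.printf("wrote to %s-%d at offset %d, timestamp %d%n",
                            metadata.topic(), metadata.partition(), offset, ts);
                }
            }
        });
    }
}
```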
+ */ +package org.apache.kafka.clients.producer; + +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.atomic.AtomicInteger; + +import org.apache.kafka.common.Cluster; +import org.apache.kafka.common.PartitionInfo; +import org.apache.kafka.common.utils.Utils; + +/** + * The "Round-Robin" partitioner + * + * This partitioning strategy can be used when user wants + * to distribute the writes to all partitions equally. This + * is the behaviour regardless of record key hash. + * + */ +public class RoundRobinPartitioner implements Partitioner { + private final ConcurrentMap topicCounterMap = new ConcurrentHashMap<>(); + + public void configure(Map configs) {} + + /** + * Compute the partition for the given record. + * + * @param topic The topic name + * @param key The key to partition on (or null if no key) + * @param keyBytes serialized key to partition on (or null if no key) + * @param value The value to partition on or null + * @param valueBytes serialized value to partition on or null + * @param cluster The current cluster metadata + */ + @Override + public int partition(String topic, Object key, byte[] keyBytes, Object value, byte[] valueBytes, Cluster cluster) { + List partitions = cluster.partitionsForTopic(topic); + int numPartitions = partitions.size(); + int nextValue = nextValue(topic); + List availablePartitions = cluster.availablePartitionsForTopic(topic); + if (!availablePartitions.isEmpty()) { + int part = Utils.toPositive(nextValue) % availablePartitions.size(); + return availablePartitions.get(part).partition(); + } else { + // no partitions are available, give a non-available partition + return Utils.toPositive(nextValue) % numPartitions; + } + } + + private int nextValue(String topic) { + AtomicInteger counter = topicCounterMap.computeIfAbsent(topic, k -> { + return new AtomicInteger(0); + }); + return counter.getAndIncrement(); + } + + public void close() {} + +} diff --git a/clients/src/main/java/org/apache/kafka/clients/producer/UniformStickyPartitioner.java b/clients/src/main/java/org/apache/kafka/clients/producer/UniformStickyPartitioner.java new file mode 100644 index 0000000..be11d0b --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/producer/UniformStickyPartitioner.java @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients.producer; + +import java.util.Map; + +import org.apache.kafka.clients.producer.internals.StickyPartitionCache; +import org.apache.kafka.common.Cluster; + + +/** + * The partitioning strategy: + *
      + *
    • If a partition is specified in the record, use it + *
    • Otherwise choose the sticky partition that changes when the batch is full. + * + * NOTE: In contrast to the DefaultPartitioner, the record key is NOT used as part of the partitioning strategy in this + * partitioner. Records with the same key are not guaranteed to be sent to the same partition. + * + * See KIP-480 for details about sticky partitioning. + */ +public class UniformStickyPartitioner implements Partitioner { + + private final StickyPartitionCache stickyPartitionCache = new StickyPartitionCache(); + + public void configure(Map configs) {} + + /** + * Compute the partition for the given record. + * + * @param topic The topic name + * @param key The key to partition on (or null if no key) + * @param keyBytes serialized key to partition on (or null if no key) + * @param value The value to partition on or null + * @param valueBytes serialized value to partition on or null + * @param cluster The current cluster metadata + */ + public int partition(String topic, Object key, byte[] keyBytes, Object value, byte[] valueBytes, Cluster cluster) { + return stickyPartitionCache.partition(topic, cluster); + } + + public void close() {} + + /** + * If a batch completed for the current sticky partition, change the sticky partition. + * Alternately, if no sticky partition has been determined, set one. + */ + public void onNewBatch(String topic, Cluster cluster, int prevPartition) { + stickyPartitionCache.nextPartition(topic, cluster, prevPartition); + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/producer/internals/BufferPool.java b/clients/src/main/java/org/apache/kafka/clients/producer/internals/BufferPool.java new file mode 100644 index 0000000..210911a --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/producer/internals/BufferPool.java @@ -0,0 +1,359 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients.producer.internals; + +import java.nio.ByteBuffer; +import java.util.ArrayDeque; +import java.util.Deque; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.locks.Condition; +import java.util.concurrent.locks.ReentrantLock; + +import org.apache.kafka.clients.producer.BufferExhaustedException; +import org.apache.kafka.common.KafkaException; +import org.apache.kafka.common.MetricName; +import org.apache.kafka.common.metrics.Metrics; +import org.apache.kafka.common.metrics.Sensor; +import org.apache.kafka.common.metrics.stats.Meter; +import org.apache.kafka.common.utils.Time; + + +/** + * A pool of ByteBuffers kept under a given memory limit. This class is fairly specific to the needs of the producer. In + * particular it has the following properties: + *
+ * 1. There is a special "poolable size" and buffers of this size are kept in a free list and recycled
+ * 2. It is fair. That is, all memory is given to the longest waiting thread until it has sufficient memory. This
+ *    prevents starvation or deadlock when a thread asks for a large chunk of memory and needs to block until multiple
+ *    buffers are deallocated.
      + */ +public class BufferPool { + + static final String WAIT_TIME_SENSOR_NAME = "bufferpool-wait-time"; + + private final long totalMemory; + private final int poolableSize; + private final ReentrantLock lock; + private final Deque free; + private final Deque waiters; + /** Total available memory is the sum of nonPooledAvailableMemory and the number of byte buffers in free * poolableSize. */ + private long nonPooledAvailableMemory; + private final Metrics metrics; + private final Time time; + private final Sensor waitTime; + private boolean closed; + + /** + * Create a new buffer pool + * + * @param memory The maximum amount of memory that this buffer pool can allocate + * @param poolableSize The buffer size to cache in the free list rather than deallocating + * @param metrics instance of Metrics + * @param time time instance + * @param metricGrpName logical group name for metrics + */ + public BufferPool(long memory, int poolableSize, Metrics metrics, Time time, String metricGrpName) { + this.poolableSize = poolableSize; + this.lock = new ReentrantLock(); + this.free = new ArrayDeque<>(); + this.waiters = new ArrayDeque<>(); + this.totalMemory = memory; + this.nonPooledAvailableMemory = memory; + this.metrics = metrics; + this.time = time; + this.waitTime = this.metrics.sensor(WAIT_TIME_SENSOR_NAME); + MetricName rateMetricName = metrics.metricName("bufferpool-wait-ratio", + metricGrpName, + "The fraction of time an appender waits for space allocation."); + MetricName totalMetricName = metrics.metricName("bufferpool-wait-time-total", + metricGrpName, + "*Deprecated* The total time an appender waits for space allocation."); + MetricName totalNsMetricName = metrics.metricName("bufferpool-wait-time-ns-total", + metricGrpName, + "The total time in nanoseconds an appender waits for space allocation."); + + Sensor bufferExhaustedRecordSensor = metrics.sensor("buffer-exhausted-records"); + MetricName bufferExhaustedRateMetricName = metrics.metricName("buffer-exhausted-rate", metricGrpName, "The average per-second number of record sends that are dropped due to buffer exhaustion"); + MetricName bufferExhaustedTotalMetricName = metrics.metricName("buffer-exhausted-total", metricGrpName, "The total number of record sends that are dropped due to buffer exhaustion"); + bufferExhaustedRecordSensor.add(new Meter(bufferExhaustedRateMetricName, bufferExhaustedTotalMetricName)); + + this.waitTime.add(new Meter(TimeUnit.NANOSECONDS, rateMetricName, totalMetricName)); + this.waitTime.add(new Meter(TimeUnit.NANOSECONDS, rateMetricName, totalNsMetricName)); + this.closed = false; + } + + /** + * Allocate a buffer of the given size. This method blocks if there is not enough memory and the buffer pool + * is configured with blocking mode. 
+ * + * @param size The buffer size to allocate in bytes + * @param maxTimeToBlockMs The maximum time in milliseconds to block for buffer memory to be available + * @return The buffer + * @throws InterruptedException If the thread is interrupted while blocked + * @throws IllegalArgumentException if size is larger than the total memory controlled by the pool (and hence we would block + * forever) + */ + public ByteBuffer allocate(int size, long maxTimeToBlockMs) throws InterruptedException { + if (size > this.totalMemory) + throw new IllegalArgumentException("Attempt to allocate " + size + + " bytes, but there is a hard limit of " + + this.totalMemory + + " on memory allocations."); + + ByteBuffer buffer = null; + this.lock.lock(); + + if (this.closed) { + this.lock.unlock(); + throw new KafkaException("Producer closed while allocating memory"); + } + + try { + // check if we have a free buffer of the right size pooled + if (size == poolableSize && !this.free.isEmpty()) + return this.free.pollFirst(); + + // now check if the request is immediately satisfiable with the + // memory on hand or if we need to block + int freeListSize = freeSize() * this.poolableSize; + if (this.nonPooledAvailableMemory + freeListSize >= size) { + // we have enough unallocated or pooled memory to immediately + // satisfy the request, but need to allocate the buffer + freeUp(size); + this.nonPooledAvailableMemory -= size; + } else { + // we are out of memory and will have to block + int accumulated = 0; + Condition moreMemory = this.lock.newCondition(); + try { + long remainingTimeToBlockNs = TimeUnit.MILLISECONDS.toNanos(maxTimeToBlockMs); + this.waiters.addLast(moreMemory); + // loop over and over until we have a buffer or have reserved + // enough memory to allocate one + while (accumulated < size) { + long startWaitNs = time.nanoseconds(); + long timeNs; + boolean waitingTimeElapsed; + try { + waitingTimeElapsed = !moreMemory.await(remainingTimeToBlockNs, TimeUnit.NANOSECONDS); + } finally { + long endWaitNs = time.nanoseconds(); + timeNs = Math.max(0L, endWaitNs - startWaitNs); + recordWaitTime(timeNs); + } + + if (this.closed) + throw new KafkaException("Producer closed while allocating memory"); + + if (waitingTimeElapsed) { + this.metrics.sensor("buffer-exhausted-records").record(); + throw new BufferExhaustedException("Failed to allocate " + size + " bytes within the configured max blocking time " + + maxTimeToBlockMs + " ms. Total memory: " + totalMemory() + " bytes. Available memory: " + availableMemory() + + " bytes. 
Poolable size: " + poolableSize() + " bytes"); + } + + remainingTimeToBlockNs -= timeNs; + + // check if we can satisfy this request from the free list, + // otherwise allocate memory + if (accumulated == 0 && size == this.poolableSize && !this.free.isEmpty()) { + // just grab a buffer from the free list + buffer = this.free.pollFirst(); + accumulated = size; + } else { + // we'll need to allocate memory, but we may only get + // part of what we need on this iteration + freeUp(size - accumulated); + int got = (int) Math.min(size - accumulated, this.nonPooledAvailableMemory); + this.nonPooledAvailableMemory -= got; + accumulated += got; + } + } + // Don't reclaim memory on throwable since nothing was thrown + accumulated = 0; + } finally { + // When this loop was not able to successfully terminate don't loose available memory + this.nonPooledAvailableMemory += accumulated; + this.waiters.remove(moreMemory); + } + } + } finally { + // signal any additional waiters if there is more memory left + // over for them + try { + if (!(this.nonPooledAvailableMemory == 0 && this.free.isEmpty()) && !this.waiters.isEmpty()) + this.waiters.peekFirst().signal(); + } finally { + // Another finally... otherwise find bugs complains + lock.unlock(); + } + } + + if (buffer == null) + return safeAllocateByteBuffer(size); + else + return buffer; + } + + // Protected for testing + protected void recordWaitTime(long timeNs) { + this.waitTime.record(timeNs, time.milliseconds()); + } + + /** + * Allocate a buffer. If buffer allocation fails (e.g. because of OOM) then return the size count back to + * available memory and signal the next waiter if it exists. + */ + private ByteBuffer safeAllocateByteBuffer(int size) { + boolean error = true; + try { + ByteBuffer buffer = allocateByteBuffer(size); + error = false; + return buffer; + } finally { + if (error) { + this.lock.lock(); + try { + this.nonPooledAvailableMemory += size; + if (!this.waiters.isEmpty()) + this.waiters.peekFirst().signal(); + } finally { + this.lock.unlock(); + } + } + } + } + + // Protected for testing. + protected ByteBuffer allocateByteBuffer(int size) { + return ByteBuffer.allocate(size); + } + + /** + * Attempt to ensure we have at least the requested number of bytes of memory for allocation by deallocating pooled + * buffers (if needed) + */ + private void freeUp(int size) { + while (!this.free.isEmpty() && this.nonPooledAvailableMemory < size) + this.nonPooledAvailableMemory += this.free.pollLast().capacity(); + } + + /** + * Return buffers to the pool. If they are of the poolable size add them to the free list, otherwise just mark the + * memory as free. 
+ * + * @param buffer The buffer to return + * @param size The size of the buffer to mark as deallocated, note that this may be smaller than buffer.capacity + * since the buffer may re-allocate itself during in-place compression + */ + public void deallocate(ByteBuffer buffer, int size) { + lock.lock(); + try { + if (size == this.poolableSize && size == buffer.capacity()) { + buffer.clear(); + this.free.add(buffer); + } else { + this.nonPooledAvailableMemory += size; + } + Condition moreMem = this.waiters.peekFirst(); + if (moreMem != null) + moreMem.signal(); + } finally { + lock.unlock(); + } + } + + public void deallocate(ByteBuffer buffer) { + deallocate(buffer, buffer.capacity()); + } + + /** + * the total free memory both unallocated and in the free list + */ + public long availableMemory() { + lock.lock(); + try { + return this.nonPooledAvailableMemory + freeSize() * (long) this.poolableSize; + } finally { + lock.unlock(); + } + } + + // Protected for testing. + protected int freeSize() { + return this.free.size(); + } + + /** + * Get the unallocated memory (not in the free list or in use) + */ + public long unallocatedMemory() { + lock.lock(); + try { + return this.nonPooledAvailableMemory; + } finally { + lock.unlock(); + } + } + + /** + * The number of threads blocked waiting on memory + */ + public int queued() { + lock.lock(); + try { + return this.waiters.size(); + } finally { + lock.unlock(); + } + } + + /** + * The buffer size that will be retained in the free list after use + */ + public int poolableSize() { + return this.poolableSize; + } + + /** + * The total memory managed by this pool + */ + public long totalMemory() { + return this.totalMemory; + } + + // package-private method used only for testing + Deque waiters() { + return this.waiters; + } + + /** + * Closes the buffer pool. Memory will be prevented from being allocated, but may be deallocated. All allocations + * awaiting available memory will be notified to abort. + */ + public void close() { + this.lock.lock(); + this.closed = true; + try { + for (Condition waiter : this.waiters) + waiter.signal(); + } finally { + this.lock.unlock(); + } + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/producer/internals/DefaultPartitioner.java b/clients/src/main/java/org/apache/kafka/clients/producer/internals/DefaultPartitioner.java new file mode 100644 index 0000000..cf765d1 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/producer/internals/DefaultPartitioner.java @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
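A quick aside on the BufferPool class that closes just above: the record accumulator drives it through allocate and deallocate. The sketch below is hedged, it only uses the constructor and methods visible in this patch, and the 32 MiB / 16 KiB sizes and the "producer-metrics" group name are illustrative stand-ins for the buffer.memory and batch.size settings.

```java
import java.nio.ByteBuffer;
import org.apache.kafka.clients.producer.internals.BufferPool;
import org.apache.kafka.common.metrics.Metrics;
import org.apache.kafka.common.utils.Time;

public class BufferPoolSketch {
    public static void main(String[] args) throws InterruptedException {
        // 32 MiB total memory, 16 KiB poolable buffers; illustrative values only.
        BufferPool pool = new BufferPool(32 * 1024 * 1024, 16 * 1024,
                new Metrics(), Time.SYSTEM, "producer-metrics");

        // Blocks for at most 60s if the pool is exhausted; poolable-size requests
        // are served from the free list when one is available.
        ByteBuffer buffer = pool.allocate(16 * 1024, 60_000);
        try {
            buffer.putInt(42); // the real producer writes a record batch here
        } finally {
            // Poolable-size buffers go back on the free list; other sizes only
            // return their byte count to nonPooledAvailableMemory.
            pool.deallocate(buffer);
        }
        pool.close();
    }
}
```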
+ */ +package org.apache.kafka.clients.producer.internals; + +import org.apache.kafka.clients.producer.Partitioner; +import org.apache.kafka.common.Cluster; +import org.apache.kafka.common.utils.Utils; + +import java.util.Map; + +/** + * The default partitioning strategy: + *
+ * • If a partition is specified in the record, use it
+ * • If no partition is specified but a key is present choose a partition based on a hash of the key
      • If no partition or key is present choose the sticky partition that changes when the batch is full. + * + * See KIP-480 for details about sticky partitioning. + */ +public class DefaultPartitioner implements Partitioner { + + private final StickyPartitionCache stickyPartitionCache = new StickyPartitionCache(); + + public void configure(Map configs) {} + + /** + * Compute the partition for the given record. + * + * @param topic The topic name + * @param key The key to partition on (or null if no key) + * @param keyBytes serialized key to partition on (or null if no key) + * @param value The value to partition on or null + * @param valueBytes serialized value to partition on or null + * @param cluster The current cluster metadata + */ + public int partition(String topic, Object key, byte[] keyBytes, Object value, byte[] valueBytes, Cluster cluster) { + return partition(topic, key, keyBytes, value, valueBytes, cluster, cluster.partitionsForTopic(topic).size()); + } + + /** + * Compute the partition for the given record. + * + * @param topic The topic name + * @param numPartitions The number of partitions of the given {@code topic} + * @param key The key to partition on (or null if no key) + * @param keyBytes serialized key to partition on (or null if no key) + * @param value The value to partition on or null + * @param valueBytes serialized value to partition on or null + * @param cluster The current cluster metadata + */ + public int partition(String topic, Object key, byte[] keyBytes, Object value, byte[] valueBytes, Cluster cluster, + int numPartitions) { + if (keyBytes == null) { + return stickyPartitionCache.partition(topic, cluster); + } + // hash the keyBytes to choose a partition + return Utils.toPositive(Utils.murmur2(keyBytes)) % numPartitions; + } + + public void close() {} + + /** + * If a batch completed for the current sticky partition, change the sticky partition. + * Alternately, if no sticky partition has been determined, set one. + */ + public void onNewBatch(String topic, Cluster cluster, int prevPartition) { + stickyPartitionCache.nextPartition(topic, cluster, prevPartition); + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/producer/internals/ErrorLoggingCallback.java b/clients/src/main/java/org/apache/kafka/clients/producer/internals/ErrorLoggingCallback.java new file mode 100644 index 0000000..07a2878 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/producer/internals/ErrorLoggingCallback.java @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
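Since DefaultPartitioner (above) and UniformStickyPartitioner share the same sticky machinery, here is a simplified, self-contained restatement of the selection rules. It is not the Kafka class itself: a plain map stands in for StickyPartitionCache, while the keyed path reuses the public Utils.murmur2/toPositive helpers that the real partitioner calls.

```java
import java.nio.charset.StandardCharsets;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ThreadLocalRandom;
import org.apache.kafka.common.utils.Utils;

public class PartitionChoiceSketch {
    // Stand-in for StickyPartitionCache: one "current" partition per topic.
    private final Map<String, Integer> sticky = new ConcurrentHashMap<>();

    public int choosePartition(String topic, byte[] keyBytes, Integer explicitPartition, int numPartitions) {
        if (explicitPartition != null)
            return explicitPartition;                      // rule 1: honour an explicit partition
        if (keyBytes != null)                              // rule 2: hash the serialized key
            return Utils.toPositive(Utils.murmur2(keyBytes)) % numPartitions;
        // rule 3: keyless records stick to one partition until the batch fills (KIP-480)
        return sticky.computeIfAbsent(topic, t -> ThreadLocalRandom.current().nextInt(numPartitions));
    }

    public static void main(String[] args) {
        PartitionChoiceSketch sketch = new PartitionChoiceSketch();
        byte[] key = "user-42".getBytes(StandardCharsets.UTF_8);
        System.out.println(sketch.choosePartition("orders", key, null, 6));  // stable, key-based
        System.out.println(sketch.choosePartition("orders", null, null, 6)); // sticky, key-less
    }
}
```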
+ */ +package org.apache.kafka.clients.producer.internals; + +import org.apache.kafka.clients.producer.Callback; +import org.apache.kafka.clients.producer.RecordMetadata; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.nio.charset.StandardCharsets; + +public class ErrorLoggingCallback implements Callback { + private static final Logger log = LoggerFactory.getLogger(ErrorLoggingCallback.class); + private String topic; + private byte[] key; + private byte[] value; + private int valueLength; + private boolean logAsString; + + public ErrorLoggingCallback(String topic, byte[] key, byte[] value, boolean logAsString) { + this.topic = topic; + this.key = key; + + if (logAsString) { + this.value = value; + } + + this.valueLength = value == null ? -1 : value.length; + this.logAsString = logAsString; + } + + public void onCompletion(RecordMetadata metadata, Exception e) { + if (e != null) { + String keyString = (key == null) ? "null" : + logAsString ? new String(key, StandardCharsets.UTF_8) : key.length + " bytes"; + String valueString = (valueLength == -1) ? "null" : + logAsString ? new String(value, StandardCharsets.UTF_8) : valueLength + " bytes"; + log.error("Error when sending message to topic {} with key: {}, value: {} with error:", + topic, keyString, valueString, e); + } + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/producer/internals/FutureRecordMetadata.java b/clients/src/main/java/org/apache/kafka/clients/producer/internals/FutureRecordMetadata.java new file mode 100644 index 0000000..0026237 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/producer/internals/FutureRecordMetadata.java @@ -0,0 +1,120 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
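The ErrorLoggingCallback defined just above is meant to be handed to KafkaProducer.send. A hedged usage sketch follows; the bootstrap address, topic name, and payloads are assumptions, not part of the patch.

```java
import java.nio.charset.StandardCharsets;
import java.util.Properties;
import org.apache.kafka.clients.producer.KafkaProducer;
import org.apache.kafka.clients.producer.ProducerRecord;
import org.apache.kafka.clients.producer.internals.ErrorLoggingCallback;
import org.apache.kafka.common.serialization.ByteArraySerializer;

public class ErrorLoggingSendSketch {
    public static void main(String[] args) {
        Properties props = new Properties();
        props.put("bootstrap.servers", "localhost:9092"); // assumption: a local broker
        props.put("key.serializer", ByteArraySerializer.class.getName());
        props.put("value.serializer", ByteArraySerializer.class.getName());

        byte[] key = "k1".getBytes(StandardCharsets.UTF_8);
        byte[] value = "v1".getBytes(StandardCharsets.UTF_8);

        try (KafkaProducer<byte[], byte[]> producer = new KafkaProducer<>(props)) {
            // logAsString=true logs key/value contents on failure; false logs only their
            // sizes, which is usually preferable for binary or sensitive payloads.
            producer.send(new ProducerRecord<>("demo-topic", key, value),
                    new ErrorLoggingCallback("demo-topic", key, value, false));
        }
    }
}
```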
+ */ +package org.apache.kafka.clients.producer.internals; + +import org.apache.kafka.clients.producer.RecordMetadata; +import org.apache.kafka.common.utils.Time; + +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + +/** + * The future result of a record send + */ +public final class FutureRecordMetadata implements Future { + + private final ProduceRequestResult result; + private final int batchIndex; + private final long createTimestamp; + private final int serializedKeySize; + private final int serializedValueSize; + private final Time time; + private volatile FutureRecordMetadata nextRecordMetadata = null; + + public FutureRecordMetadata(ProduceRequestResult result, int batchIndex, long createTimestamp, int serializedKeySize, + int serializedValueSize, Time time) { + this.result = result; + this.batchIndex = batchIndex; + this.createTimestamp = createTimestamp; + this.serializedKeySize = serializedKeySize; + this.serializedValueSize = serializedValueSize; + this.time = time; + } + + @Override + public boolean cancel(boolean interrupt) { + return false; + } + + @Override + public boolean isCancelled() { + return false; + } + + @Override + public RecordMetadata get() throws InterruptedException, ExecutionException { + this.result.await(); + if (nextRecordMetadata != null) + return nextRecordMetadata.get(); + return valueOrError(); + } + + @Override + public RecordMetadata get(long timeout, TimeUnit unit) throws InterruptedException, ExecutionException, TimeoutException { + // Handle overflow. + long now = time.milliseconds(); + long timeoutMillis = unit.toMillis(timeout); + long deadline = Long.MAX_VALUE - timeoutMillis < now ? Long.MAX_VALUE : now + timeoutMillis; + boolean occurred = this.result.await(timeout, unit); + if (!occurred) + throw new TimeoutException("Timeout after waiting for " + timeoutMillis + " ms."); + if (nextRecordMetadata != null) + return nextRecordMetadata.get(deadline - time.milliseconds(), TimeUnit.MILLISECONDS); + return valueOrError(); + } + + /** + * This method is used when we have to split a large batch in smaller ones. A chained metadata will allow the + * future that has already returned to the users to wait on the newly created split batches even after the + * old big batch has been deemed as done. + */ + void chain(FutureRecordMetadata futureRecordMetadata) { + if (nextRecordMetadata == null) + nextRecordMetadata = futureRecordMetadata; + else + nextRecordMetadata.chain(futureRecordMetadata); + } + + RecordMetadata valueOrError() throws ExecutionException { + RuntimeException exception = this.result.error(batchIndex); + if (exception != null) + throw new ExecutionException(exception); + else + return value(); + } + + RecordMetadata value() { + if (nextRecordMetadata != null) + return nextRecordMetadata.value(); + return new RecordMetadata(result.topicPartition(), this.result.baseOffset(), this.batchIndex, + timestamp(), this.serializedKeySize, this.serializedValueSize); + } + + private long timestamp() { + return result.hasLogAppendTime() ? 
result.logAppendTime() : createTimestamp; + } + + @Override + public boolean isDone() { + if (nextRecordMetadata != null) + return nextRecordMetadata.isDone(); + return this.result.completed(); + } + +} diff --git a/clients/src/main/java/org/apache/kafka/clients/producer/internals/IncompleteBatches.java b/clients/src/main/java/org/apache/kafka/clients/producer/internals/IncompleteBatches.java new file mode 100644 index 0000000..00f8197 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/producer/internals/IncompleteBatches.java @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients.producer.internals; + +import java.util.ArrayList; +import java.util.HashSet; +import java.util.Set; +import java.util.stream.Collectors; + +/* + * A thread-safe helper class to hold batches that haven't been acknowledged yet (including those + * which have and have not been sent). + */ +class IncompleteBatches { + private final Set incomplete; + + public IncompleteBatches() { + this.incomplete = new HashSet<>(); + } + + public void add(ProducerBatch batch) { + synchronized (incomplete) { + this.incomplete.add(batch); + } + } + + public void remove(ProducerBatch batch) { + synchronized (incomplete) { + boolean removed = this.incomplete.remove(batch); + if (!removed) + throw new IllegalStateException("Remove from the incomplete set failed. This should be impossible."); + } + } + + public Iterable copyAll() { + synchronized (incomplete) { + return new ArrayList<>(this.incomplete); + } + } + + public Iterable requestResults() { + synchronized (incomplete) { + return incomplete.stream().map(batch -> batch.produceFuture).collect(Collectors.toList()); + } + } + + public boolean isEmpty() { + synchronized (incomplete) { + return incomplete.isEmpty(); + } + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/producer/internals/KafkaProducerMetrics.java b/clients/src/main/java/org/apache/kafka/clients/producer/internals/KafkaProducerMetrics.java new file mode 100644 index 0000000..3c6fe26 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/producer/internals/KafkaProducerMetrics.java @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
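FutureRecordMetadata (completed above) is what backs the Future returned by KafkaProducer.send, and its timed get awaits the ProduceRequestResult latch and then follows any futures chained in by batch splits. A hedged sketch of that caller-side contract, again with an assumed local broker and demo topic:

```java
import java.util.Properties;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import org.apache.kafka.clients.producer.KafkaProducer;
import org.apache.kafka.clients.producer.ProducerRecord;
import org.apache.kafka.clients.producer.RecordMetadata;
import org.apache.kafka.common.serialization.StringSerializer;

public class SendFutureSketch {
    public static void main(String[] args) throws Exception {
        Properties props = new Properties();
        props.put("bootstrap.servers", "localhost:9092"); // assumption: local broker
        props.put("key.serializer", StringSerializer.class.getName());
        props.put("value.serializer", StringSerializer.class.getName());

        try (KafkaProducer<String, String> producer = new KafkaProducer<>(props)) {
            Future<RecordMetadata> future = producer.send(new ProducerRecord<>("demo-topic", "k", "v"));
            try {
                // Maps to FutureRecordMetadata.get(long, TimeUnit) under the hood.
                RecordMetadata metadata = future.get(30, TimeUnit.SECONDS);
                System.out.printf("appended to %s-%d at offset %d%n",
                        metadata.topic(), metadata.partition(), metadata.offset());
            } catch (TimeoutException e) {
                System.err.println("broker did not acknowledge within 30s: " + e.getMessage());
            }
        }
    }
}
```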
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.clients.producer.internals; + +import org.apache.kafka.common.MetricName; +import org.apache.kafka.common.metrics.Metrics; +import org.apache.kafka.common.metrics.Sensor; +import org.apache.kafka.common.metrics.stats.CumulativeSum; + +import java.util.Map; + +public class KafkaProducerMetrics implements AutoCloseable { + + public static final String GROUP = "producer-metrics"; + private static final String FLUSH = "flush"; + private static final String TXN_INIT = "txn-init"; + private static final String TXN_BEGIN = "txn-begin"; + private static final String TXN_SEND_OFFSETS = "txn-send-offsets"; + private static final String TXN_COMMIT = "txn-commit"; + private static final String TXN_ABORT = "txn-abort"; + private static final String TOTAL_TIME_SUFFIX = "-time-ns-total"; + + private final Map tags; + private final Metrics metrics; + private final Sensor initTimeSensor; + private final Sensor beginTxnTimeSensor; + private final Sensor flushTimeSensor; + private final Sensor sendOffsetsSensor; + private final Sensor commitTxnSensor; + private final Sensor abortTxnSensor; + + public KafkaProducerMetrics(Metrics metrics) { + this.metrics = metrics; + tags = this.metrics.config().tags(); + flushTimeSensor = newLatencySensor( + FLUSH, + "Total time producer has spent in flush in nanoseconds." + ); + initTimeSensor = newLatencySensor( + TXN_INIT, + "Total time producer has spent in initTransactions in nanoseconds." + ); + beginTxnTimeSensor = newLatencySensor( + TXN_BEGIN, + "Total time producer has spent in beginTransaction in nanoseconds." + ); + sendOffsetsSensor = newLatencySensor( + TXN_SEND_OFFSETS, + "Total time producer has spent in sendOffsetsToTransaction in nanoseconds." + ); + commitTxnSensor = newLatencySensor( + TXN_COMMIT, + "Total time producer has spent in commitTransaction in nanoseconds." + ); + abortTxnSensor = newLatencySensor( + TXN_ABORT, + "Total time producer has spent in abortTransaction in nanoseconds." 
+ ); + } + + @Override + public void close() { + removeMetric(FLUSH); + removeMetric(TXN_INIT); + removeMetric(TXN_BEGIN); + removeMetric(TXN_SEND_OFFSETS); + removeMetric(TXN_COMMIT); + removeMetric(TXN_ABORT); + } + + public void recordFlush(long duration) { + flushTimeSensor.record(duration); + } + + public void recordInit(long duration) { + initTimeSensor.record(duration); + } + + public void recordBeginTxn(long duration) { + beginTxnTimeSensor.record(duration); + } + + public void recordSendOffsets(long duration) { + sendOffsetsSensor.record(duration); + } + + public void recordCommitTxn(long duration) { + commitTxnSensor.record(duration); + } + + public void recordAbortTxn(long duration) { + abortTxnSensor.record(duration); + } + + private Sensor newLatencySensor(String name, String description) { + Sensor sensor = metrics.sensor(name + TOTAL_TIME_SUFFIX); + sensor.add(metricName(name, description), new CumulativeSum()); + return sensor; + } + + private MetricName metricName(final String name, final String description) { + return metrics.metricName(name + TOTAL_TIME_SUFFIX, GROUP, description, tags); + } + + private void removeMetric(final String name) { + metrics.removeSensor(name + TOTAL_TIME_SUFFIX); + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/producer/internals/ProduceRequestResult.java b/clients/src/main/java/org/apache/kafka/clients/producer/internals/ProduceRequestResult.java new file mode 100644 index 0000000..9077b10 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/producer/internals/ProduceRequestResult.java @@ -0,0 +1,134 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients.producer.internals; + +import org.apache.kafka.clients.producer.RecordMetadata; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.record.RecordBatch; + +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.function.Function; + +/** + * A class that models the future completion of a produce request for a single partition. There is one of these per + * partition in a produce request and it is shared by all the {@link RecordMetadata} instances that are batched together + * for the same partition in the request. + */ +public class ProduceRequestResult { + + private final CountDownLatch latch = new CountDownLatch(1); + private final TopicPartition topicPartition; + + private volatile Long baseOffset = null; + private volatile long logAppendTime = RecordBatch.NO_TIMESTAMP; + private volatile Function errorsByIndex; + + /** + * Create an instance of this class. 
+ * + * @param topicPartition The topic and partition to which this record set was sent was sent + */ + public ProduceRequestResult(TopicPartition topicPartition) { + this.topicPartition = topicPartition; + } + + /** + * Set the result of the produce request. + * + * @param baseOffset The base offset assigned to the record + * @param logAppendTime The log append time or -1 if CreateTime is being used + * @param errorsByIndex Function mapping the batch index to the exception, or null if the response was successful + */ + public void set(long baseOffset, long logAppendTime, Function errorsByIndex) { + this.baseOffset = baseOffset; + this.logAppendTime = logAppendTime; + this.errorsByIndex = errorsByIndex; + } + + /** + * Mark this request as complete and unblock any threads waiting on its completion. + */ + public void done() { + if (baseOffset == null) + throw new IllegalStateException("The method `set` must be invoked before this method."); + this.latch.countDown(); + } + + /** + * Await the completion of this request + */ + public void await() throws InterruptedException { + latch.await(); + } + + /** + * Await the completion of this request (up to the given time interval) + * @param timeout The maximum time to wait + * @param unit The unit for the max time + * @return true if the request completed, false if we timed out + */ + public boolean await(long timeout, TimeUnit unit) throws InterruptedException { + return latch.await(timeout, unit); + } + + /** + * The base offset for the request (the first offset in the record set) + */ + public long baseOffset() { + return baseOffset; + } + + /** + * Return true if log append time is being used for this topic + */ + public boolean hasLogAppendTime() { + return logAppendTime != RecordBatch.NO_TIMESTAMP; + } + + /** + * The log append time or -1 if CreateTime is being used + */ + public long logAppendTime() { + return logAppendTime; + } + + /** + * The error thrown (generally on the server) while processing this request + */ + public RuntimeException error(int batchIndex) { + if (errorsByIndex == null) { + return null; + } else { + return errorsByIndex.apply(batchIndex); + } + } + + /** + * The topic and partition to which the record was appended + */ + public TopicPartition topicPartition() { + return topicPartition; + } + + /** + * Has the request completed? + */ + public boolean completed() { + return this.latch.getCount() == 0L; + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/producer/internals/ProducerBatch.java b/clients/src/main/java/org/apache/kafka/clients/producer/internals/ProducerBatch.java new file mode 100644 index 0000000..4da0362 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/producer/internals/ProducerBatch.java @@ -0,0 +1,520 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
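ProduceRequestResult, closed out just above, is essentially a CountDownLatch plus the shared per-partition outcome. A hedged sketch of its set/done/await lifecycle, with the offset value and topic name invented for illustration:

```java
import org.apache.kafka.clients.producer.internals.ProduceRequestResult;
import org.apache.kafka.common.TopicPartition;
import org.apache.kafka.common.record.RecordBatch;

public class ProduceRequestResultSketch {
    public static void main(String[] args) throws InterruptedException {
        ProduceRequestResult result = new ProduceRequestResult(new TopicPartition("demo-topic", 0));

        // Simulate the sender thread completing the request for this partition.
        Thread sender = new Thread(() -> {
            // baseOffset assigned by the broker; NO_TIMESTAMP means CreateTime is in use;
            // a null error function marks the request successful.
            result.set(12345L, RecordBatch.NO_TIMESTAMP, null);
            result.done(); // unblocks every FutureRecordMetadata waiting on this result
        });
        sender.start();

        result.await(); // what FutureRecordMetadata.get() waits on
        System.out.println("base offset: " + result.baseOffset()
                + ", log-append time used: " + result.hasLogAppendTime());
    }
}
```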
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients.producer.internals; + +import org.apache.kafka.clients.producer.Callback; +import org.apache.kafka.clients.producer.RecordMetadata; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.errors.RecordBatchTooLargeException; +import org.apache.kafka.common.header.Header; +import org.apache.kafka.common.record.AbstractRecords; +import org.apache.kafka.common.record.CompressionRatioEstimator; +import org.apache.kafka.common.record.CompressionType; +import org.apache.kafka.common.record.MemoryRecords; +import org.apache.kafka.common.record.MemoryRecordsBuilder; +import org.apache.kafka.common.record.MutableRecordBatch; +import org.apache.kafka.common.record.Record; +import org.apache.kafka.common.record.RecordBatch; +import org.apache.kafka.common.record.TimestampType; +import org.apache.kafka.common.requests.ProduceResponse; +import org.apache.kafka.common.utils.ProducerIdAndEpoch; +import org.apache.kafka.common.utils.Time; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.nio.ByteBuffer; +import java.util.ArrayDeque; +import java.util.ArrayList; +import java.util.Deque; +import java.util.Iterator; +import java.util.List; +import java.util.Objects; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Function; + +import static org.apache.kafka.common.record.RecordBatch.MAGIC_VALUE_V2; +import static org.apache.kafka.common.record.RecordBatch.NO_TIMESTAMP; + +/** + * A batch of records that is or will be sent. + * + * This class is not thread safe and external synchronization must be used when modifying it + */ +public final class ProducerBatch { + + private static final Logger log = LoggerFactory.getLogger(ProducerBatch.class); + + private enum FinalState { ABORTED, FAILED, SUCCEEDED } + + final long createdMs; + final TopicPartition topicPartition; + final ProduceRequestResult produceFuture; + + private final List thunks = new ArrayList<>(); + private final MemoryRecordsBuilder recordsBuilder; + private final AtomicInteger attempts = new AtomicInteger(0); + private final boolean isSplitBatch; + private final AtomicReference finalState = new AtomicReference<>(null); + + int recordCount; + int maxRecordSize; + private long lastAttemptMs; + private long lastAppendTime; + private long drainedMs; + private boolean retry; + private boolean reopened; + + public ProducerBatch(TopicPartition tp, MemoryRecordsBuilder recordsBuilder, long createdMs) { + this(tp, recordsBuilder, createdMs, false); + } + + public ProducerBatch(TopicPartition tp, MemoryRecordsBuilder recordsBuilder, long createdMs, boolean isSplitBatch) { + this.createdMs = createdMs; + this.lastAttemptMs = createdMs; + this.recordsBuilder = recordsBuilder; + this.topicPartition = tp; + this.lastAppendTime = createdMs; + this.produceFuture = new ProduceRequestResult(topicPartition); + this.retry = false; + this.isSplitBatch = isSplitBatch; + float compressionRatioEstimation = CompressionRatioEstimator.estimation(topicPartition.topic(), + recordsBuilder.compressionType()); + recordsBuilder.setEstimatedCompressionRatio(compressionRatioEstimation); + } + + /** + * Append the record to the current record set and return the relative offset within that record set + * + * @return The RecordSend corresponding to this record or null if there isn't sufficient room. 
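The null return described in this tryAppend contract is the cue the accumulator uses to seal the current batch and start a new one. Below is a deliberately simplified, Kafka-free sketch of that caller-side rollover; every name in it is illustrative.

```java
import java.util.ArrayDeque;
import java.util.Deque;
import java.util.concurrent.CompletableFuture;

public class BatchRolloverSketch {
    /** Minimal stand-in for ProducerBatch: accepts appends until a byte budget is used up. */
    static final class Batch {
        private final int capacityBytes;
        private int usedBytes;
        Batch(int capacityBytes) { this.capacityBytes = capacityBytes; }

        /** Mirrors the tryAppend contract: a future on success, null when there is no room. */
        CompletableFuture<Long> tryAppend(byte[] value) {
            if (usedBytes + value.length > capacityBytes)
                return null;
            usedBytes += value.length;
            return new CompletableFuture<>(); // completed later, when the "broker" responds
        }
    }

    private final Deque<Batch> batches = new ArrayDeque<>();

    CompletableFuture<Long> append(byte[] value, int batchSize) {
        Batch last = batches.peekLast();
        CompletableFuture<Long> future = (last == null) ? null : last.tryAppend(value);
        if (future == null) {                // no batch yet, or the last one is full
            Batch fresh = new Batch(batchSize);
            batches.addLast(fresh);          // the real accumulator also allocates from BufferPool here
            future = fresh.tryAppend(value); // a fresh batch always fits the first record
        }
        return future;
    }

    public static void main(String[] args) {
        BatchRolloverSketch accumulator = new BatchRolloverSketch();
        for (int i = 0; i < 5; i++)
            accumulator.append(new byte[7_000], 16_384); // the third append rolls into a second batch
        System.out.println("batches created: " + accumulator.batches.size());
    }
}
```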
+ */ + public FutureRecordMetadata tryAppend(long timestamp, byte[] key, byte[] value, Header[] headers, Callback callback, long now) { + if (!recordsBuilder.hasRoomFor(timestamp, key, value, headers)) { + return null; + } else { + this.recordsBuilder.append(timestamp, key, value, headers); + this.maxRecordSize = Math.max(this.maxRecordSize, AbstractRecords.estimateSizeInBytesUpperBound(magic(), + recordsBuilder.compressionType(), key, value, headers)); + this.lastAppendTime = now; + FutureRecordMetadata future = new FutureRecordMetadata(this.produceFuture, this.recordCount, + timestamp, + key == null ? -1 : key.length, + value == null ? -1 : value.length, + Time.SYSTEM); + // we have to keep every future returned to the users in case the batch needs to be + // split to several new batches and resent. + thunks.add(new Thunk(callback, future)); + this.recordCount++; + return future; + } + } + + /** + * This method is only used by {@link #split(int)} when splitting a large batch to smaller ones. + * @return true if the record has been successfully appended, false otherwise. + */ + private boolean tryAppendForSplit(long timestamp, ByteBuffer key, ByteBuffer value, Header[] headers, Thunk thunk) { + if (!recordsBuilder.hasRoomFor(timestamp, key, value, headers)) { + return false; + } else { + // No need to get the CRC. + this.recordsBuilder.append(timestamp, key, value, headers); + this.maxRecordSize = Math.max(this.maxRecordSize, AbstractRecords.estimateSizeInBytesUpperBound(magic(), + recordsBuilder.compressionType(), key, value, headers)); + FutureRecordMetadata future = new FutureRecordMetadata(this.produceFuture, this.recordCount, + timestamp, + key == null ? -1 : key.remaining(), + value == null ? -1 : value.remaining(), + Time.SYSTEM); + // Chain the future to the original thunk. + thunk.future.chain(future); + this.thunks.add(thunk); + this.recordCount++; + return true; + } + } + + /** + * Abort the batch and complete the future and callbacks. + * + * @param exception The exception to use to complete the future and awaiting callbacks. + */ + public void abort(RuntimeException exception) { + if (!finalState.compareAndSet(null, FinalState.ABORTED)) + throw new IllegalStateException("Batch has already been completed in final state " + finalState.get()); + + log.trace("Aborting batch for partition {}", topicPartition, exception); + completeFutureAndFireCallbacks(ProduceResponse.INVALID_OFFSET, RecordBatch.NO_TIMESTAMP, index -> exception); + } + + /** + * Check if the batch has been completed (either successfully or exceptionally). + * @return `true` if the batch has been completed, `false` otherwise. + */ + public boolean isDone() { + return finalState() != null; + } + + /** + * Complete the batch successfully. + * @param baseOffset The base offset of the messages assigned by the server + * @param logAppendTime The log append time or -1 if CreateTime is being used + * @return true if the batch was completed as a result of this call, and false + * if it had been completed previously + */ + public boolean complete(long baseOffset, long logAppendTime) { + return done(baseOffset, logAppendTime, null, null); + } + + /** + * Complete the batch exceptionally. The provided top-level exception will be used + * for each record future contained in the batch. 
+ * + * @param topLevelException top-level partition error + * @param recordExceptions Record exception function mapping batchIndex to the respective record exception + * @return true if the batch was completed as a result of this call, and false + * if it had been completed previously + */ + public boolean completeExceptionally( + RuntimeException topLevelException, + Function recordExceptions + ) { + Objects.requireNonNull(topLevelException); + Objects.requireNonNull(recordExceptions); + return done(ProduceResponse.INVALID_OFFSET, RecordBatch.NO_TIMESTAMP, topLevelException, recordExceptions); + } + + /** + * Finalize the state of a batch. Final state, once set, is immutable. This function may be called + * once or twice on a batch. It may be called twice if + * 1. An inflight batch expires before a response from the broker is received. The batch's final + * state is set to FAILED. But it could succeed on the broker and second time around batch.done() may + * try to set SUCCEEDED final state. + * 2. If a transaction abortion happens or if the producer is closed forcefully, the final state is + * ABORTED but again it could succeed if broker responds with a success. + * + * Attempted transitions from [FAILED | ABORTED] --> SUCCEEDED are logged. + * Attempted transitions from one failure state to the same or a different failed state are ignored. + * Attempted transitions from SUCCEEDED to the same or a failed state throw an exception. + * + * @param baseOffset The base offset of the messages assigned by the server + * @param logAppendTime The log append time or -1 if CreateTime is being used + * @param topLevelException The exception that occurred (or null if the request was successful) + * @param recordExceptions Record exception function mapping batchIndex to the respective record exception + * @return true if the batch was completed successfully and false if the batch was previously aborted + */ + private boolean done( + long baseOffset, + long logAppendTime, + RuntimeException topLevelException, + Function recordExceptions + ) { + final FinalState tryFinalState = (topLevelException == null) ? FinalState.SUCCEEDED : FinalState.FAILED; + if (tryFinalState == FinalState.SUCCEEDED) { + log.trace("Successfully produced messages to {} with base offset {}.", topicPartition, baseOffset); + } else { + log.trace("Failed to produce messages to {} with base offset {}.", topicPartition, baseOffset, topLevelException); + } + + if (this.finalState.compareAndSet(null, tryFinalState)) { + completeFutureAndFireCallbacks(baseOffset, logAppendTime, recordExceptions); + return true; + } + + if (this.finalState.get() != FinalState.SUCCEEDED) { + if (tryFinalState == FinalState.SUCCEEDED) { + // Log if a previously unsuccessful batch succeeded later on. + log.debug("ProduceResponse returned {} for {} after batch with base offset {} had already been {}.", + tryFinalState, topicPartition, baseOffset, this.finalState.get()); + } else { + // FAILED --> FAILED and ABORTED --> FAILED transitions are ignored. + log.debug("Ignored state transition {} -> {} for {} batch with base offset {}", + this.finalState.get(), tryFinalState, topicPartition, baseOffset); + } + } else { + // A SUCCESSFUL batch must not attempt another state change. 
+ throw new IllegalStateException("A " + this.finalState.get() + " batch must not attempt another state change to " + tryFinalState); + } + return false; + } + + private void completeFutureAndFireCallbacks( + long baseOffset, + long logAppendTime, + Function recordExceptions + ) { + // Set the future before invoking the callbacks as we rely on its state for the `onCompletion` call + produceFuture.set(baseOffset, logAppendTime, recordExceptions); + + // execute callbacks + for (int i = 0; i < thunks.size(); i++) { + try { + Thunk thunk = thunks.get(i); + if (thunk.callback != null) { + if (recordExceptions == null) { + RecordMetadata metadata = thunk.future.value(); + thunk.callback.onCompletion(metadata, null); + } else { + RuntimeException exception = recordExceptions.apply(i); + thunk.callback.onCompletion(null, exception); + } + } + } catch (Exception e) { + log.error("Error executing user-provided callback on message for topic-partition '{}'", topicPartition, e); + } + } + + produceFuture.done(); + } + + public Deque split(int splitBatchSize) { + Deque batches = new ArrayDeque<>(); + MemoryRecords memoryRecords = recordsBuilder.build(); + + Iterator recordBatchIter = memoryRecords.batches().iterator(); + if (!recordBatchIter.hasNext()) + throw new IllegalStateException("Cannot split an empty producer batch."); + + RecordBatch recordBatch = recordBatchIter.next(); + if (recordBatch.magic() < MAGIC_VALUE_V2 && !recordBatch.isCompressed()) + throw new IllegalArgumentException("Batch splitting cannot be used with non-compressed messages " + + "with version v0 and v1"); + + if (recordBatchIter.hasNext()) + throw new IllegalArgumentException("A producer batch should only have one record batch."); + + Iterator thunkIter = thunks.iterator(); + // We always allocate batch size because we are already splitting a big batch. + // And we also Retain the create time of the original batch. + ProducerBatch batch = null; + + for (Record record : recordBatch) { + assert thunkIter.hasNext(); + Thunk thunk = thunkIter.next(); + if (batch == null) + batch = createBatchOffAccumulatorForRecord(record, splitBatchSize); + + // A newly created batch can always host the first message. + if (!batch.tryAppendForSplit(record.timestamp(), record.key(), record.value(), record.headers(), thunk)) { + batches.add(batch); + batch.closeForRecordAppends(); + batch = createBatchOffAccumulatorForRecord(record, splitBatchSize); + batch.tryAppendForSplit(record.timestamp(), record.key(), record.value(), record.headers(), thunk); + } + } + + // Close the last batch and add it to the batch list after split. 
+ if (batch != null) { + batches.add(batch); + batch.closeForRecordAppends(); + } + + produceFuture.set(ProduceResponse.INVALID_OFFSET, NO_TIMESTAMP, index -> new RecordBatchTooLargeException()); + produceFuture.done(); + + if (hasSequence()) { + int sequence = baseSequence(); + ProducerIdAndEpoch producerIdAndEpoch = new ProducerIdAndEpoch(producerId(), producerEpoch()); + for (ProducerBatch newBatch : batches) { + newBatch.setProducerState(producerIdAndEpoch, sequence, isTransactional()); + sequence += newBatch.recordCount; + } + } + return batches; + } + + private ProducerBatch createBatchOffAccumulatorForRecord(Record record, int batchSize) { + int initialSize = Math.max(AbstractRecords.estimateSizeInBytesUpperBound(magic(), + recordsBuilder.compressionType(), record.key(), record.value(), record.headers()), batchSize); + ByteBuffer buffer = ByteBuffer.allocate(initialSize); + + // Note that we intentionally do not set producer state (producerId, epoch, sequence, and isTransactional) + // for the newly created batch. This will be set when the batch is dequeued for sending (which is consistent + // with how normal batches are handled). + MemoryRecordsBuilder builder = MemoryRecords.builder(buffer, magic(), recordsBuilder.compressionType(), + TimestampType.CREATE_TIME, 0L); + return new ProducerBatch(topicPartition, builder, this.createdMs, true); + } + + public boolean isCompressed() { + return recordsBuilder.compressionType() != CompressionType.NONE; + } + + /** + * A callback and the associated FutureRecordMetadata argument to pass to it. + */ + final private static class Thunk { + final Callback callback; + final FutureRecordMetadata future; + + Thunk(Callback callback, FutureRecordMetadata future) { + this.callback = callback; + this.future = future; + } + } + + @Override + public String toString() { + return "ProducerBatch(topicPartition=" + topicPartition + ", recordCount=" + recordCount + ")"; + } + + boolean hasReachedDeliveryTimeout(long deliveryTimeoutMs, long now) { + return deliveryTimeoutMs <= now - this.createdMs; + } + + public FinalState finalState() { + return this.finalState.get(); + } + + int attempts() { + return attempts.get(); + } + + void reenqueued(long now) { + attempts.getAndIncrement(); + lastAttemptMs = Math.max(lastAppendTime, now); + lastAppendTime = Math.max(lastAppendTime, now); + retry = true; + } + + long queueTimeMs() { + return drainedMs - createdMs; + } + + long waitedTimeMs(long nowMs) { + return Math.max(0, nowMs - lastAttemptMs); + } + + void drained(long nowMs) { + this.drainedMs = Math.max(drainedMs, nowMs); + } + + boolean isSplitBatch() { + return isSplitBatch; + } + + /** + * Returns if the batch is been retried for sending to kafka + */ + public boolean inRetry() { + return this.retry; + } + + public MemoryRecords records() { + return recordsBuilder.build(); + } + + public int estimatedSizeInBytes() { + return recordsBuilder.estimatedSizeInBytes(); + } + + public double compressionRatio() { + return recordsBuilder.compressionRatio(); + } + + public boolean isFull() { + return recordsBuilder.isFull(); + } + + public void setProducerState(ProducerIdAndEpoch producerIdAndEpoch, int baseSequence, boolean isTransactional) { + recordsBuilder.setProducerState(producerIdAndEpoch.producerId, producerIdAndEpoch.epoch, baseSequence, isTransactional); + } + + public void resetProducerState(ProducerIdAndEpoch producerIdAndEpoch, int baseSequence, boolean isTransactional) { + log.info("Resetting sequence number of batch with current sequence {} for 
partition {} to {}", + this.baseSequence(), this.topicPartition, baseSequence); + reopened = true; + recordsBuilder.reopenAndRewriteProducerState(producerIdAndEpoch.producerId, producerIdAndEpoch.epoch, baseSequence, isTransactional); + } + + /** + * Release resources required for record appends (e.g. compression buffers). Once this method is called, it's only + * possible to update the RecordBatch header. + */ + public void closeForRecordAppends() { + recordsBuilder.closeForRecordAppends(); + } + + public void close() { + recordsBuilder.close(); + if (!recordsBuilder.isControlBatch()) { + CompressionRatioEstimator.updateEstimation(topicPartition.topic(), + recordsBuilder.compressionType(), + (float) recordsBuilder.compressionRatio()); + } + reopened = false; + } + + /** + * Abort the record builder and reset the state of the underlying buffer. This is used prior to aborting + * the batch with {@link #abort(RuntimeException)} and ensures that no record previously appended can be + * read. This is used in scenarios where we want to ensure a batch ultimately gets aborted, but in which + * it is not safe to invoke the completion callbacks (e.g. because we are holding a lock, such as + * when aborting batches in {@link RecordAccumulator}). + */ + public void abortRecordAppends() { + recordsBuilder.abort(); + } + + public boolean isClosed() { + return recordsBuilder.isClosed(); + } + + public ByteBuffer buffer() { + return recordsBuilder.buffer(); + } + + public int initialCapacity() { + return recordsBuilder.initialCapacity(); + } + + public boolean isWritable() { + return !recordsBuilder.isClosed(); + } + + public byte magic() { + return recordsBuilder.magic(); + } + + public long producerId() { + return recordsBuilder.producerId(); + } + + public short producerEpoch() { + return recordsBuilder.producerEpoch(); + } + + public int baseSequence() { + return recordsBuilder.baseSequence(); + } + + public int lastSequence() { + return recordsBuilder.baseSequence() + recordsBuilder.numRecords() - 1; + } + + public boolean hasSequence() { + return baseSequence() != RecordBatch.NO_SEQUENCE; + } + + public boolean isTransactional() { + return recordsBuilder.isTransactional(); + } + + public boolean sequenceHasBeenReset() { + return reopened; + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/producer/internals/ProducerInterceptors.java b/clients/src/main/java/org/apache/kafka/clients/producer/internals/ProducerInterceptors.java new file mode 100644 index 0000000..ceec552 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/producer/internals/ProducerInterceptors.java @@ -0,0 +1,139 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.clients.producer.internals; + + +import org.apache.kafka.clients.producer.ProducerInterceptor; +import org.apache.kafka.clients.producer.ProducerRecord; +import org.apache.kafka.clients.producer.RecordMetadata; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.record.RecordBatch; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.Closeable; +import java.util.List; + +/** + * A container that holds the list {@link org.apache.kafka.clients.producer.ProducerInterceptor} + * and wraps calls to the chain of custom interceptors. + */ +public class ProducerInterceptors implements Closeable { + private static final Logger log = LoggerFactory.getLogger(ProducerInterceptors.class); + private final List> interceptors; + + public ProducerInterceptors(List> interceptors) { + this.interceptors = interceptors; + } + + /** + * This is called when client sends the record to KafkaProducer, before key and value gets serialized. + * The method calls {@link ProducerInterceptor#onSend(ProducerRecord)} method. ProducerRecord + * returned from the first interceptor's onSend() is passed to the second interceptor onSend(), and so on in the + * interceptor chain. The record returned from the last interceptor is returned from this method. + * + * This method does not throw exceptions. Exceptions thrown by any of interceptor methods are caught and ignored. + * If an interceptor in the middle of the chain, that normally modifies the record, throws an exception, + * the next interceptor in the chain will be called with a record returned by the previous interceptor that did not + * throw an exception. + * + * @param record the record from client + * @return producer record to send to topic/partition + */ + public ProducerRecord onSend(ProducerRecord record) { + ProducerRecord interceptRecord = record; + for (ProducerInterceptor interceptor : this.interceptors) { + try { + interceptRecord = interceptor.onSend(interceptRecord); + } catch (Exception e) { + // do not propagate interceptor exception, log and continue calling other interceptors + // be careful not to throw exception from here + if (record != null) + log.warn("Error executing interceptor onSend callback for topic: {}, partition: {}", record.topic(), record.partition(), e); + else + log.warn("Error executing interceptor onSend callback", e); + } + } + return interceptRecord; + } + + /** + * This method is called when the record sent to the server has been acknowledged, or when sending the record fails before + * it gets sent to the server. This method calls {@link ProducerInterceptor#onAcknowledgement(RecordMetadata, Exception)} + * method for each interceptor. + * + * This method does not throw exceptions. Exceptions thrown by any of interceptor methods are caught and ignored. + * + * @param metadata The metadata for the record that was sent (i.e. the partition and offset). + * If an error occurred, metadata will only contain valid topic and maybe partition. + * @param exception The exception thrown during processing of this record. Null if no error occurred. 
+ */ + public void onAcknowledgement(RecordMetadata metadata, Exception exception) { + for (ProducerInterceptor interceptor : this.interceptors) { + try { + interceptor.onAcknowledgement(metadata, exception); + } catch (Exception e) { + // do not propagate interceptor exceptions, just log + log.warn("Error executing interceptor onAcknowledgement callback", e); + } + } + } + + /** + * This method is called when sending the record fails in {@link ProducerInterceptor#onSend + * (ProducerRecord)} method. This method calls {@link ProducerInterceptor#onAcknowledgement(RecordMetadata, Exception)} + * method for each interceptor + * + * @param record The record from client + * @param interceptTopicPartition The topic/partition for the record if an error occurred + * after partition gets assigned; the topic part of interceptTopicPartition is the same as in record. + * @param exception The exception thrown during processing of this record. + */ + public void onSendError(ProducerRecord record, TopicPartition interceptTopicPartition, Exception exception) { + for (ProducerInterceptor interceptor : this.interceptors) { + try { + if (record == null && interceptTopicPartition == null) { + interceptor.onAcknowledgement(null, exception); + } else { + if (interceptTopicPartition == null) { + interceptTopicPartition = new TopicPartition(record.topic(), + record.partition() == null ? RecordMetadata.UNKNOWN_PARTITION : record.partition()); + } + interceptor.onAcknowledgement(new RecordMetadata(interceptTopicPartition, -1, -1, + RecordBatch.NO_TIMESTAMP, -1, -1), exception); + } + } catch (Exception e) { + // do not propagate interceptor exceptions, just log + log.warn("Error executing interceptor onAcknowledgement callback", e); + } + } + } + + /** + * Closes every interceptor in a container. + */ + @Override + public void close() { + for (ProducerInterceptor interceptor : this.interceptors) { + try { + interceptor.close(); + } catch (Exception e) { + log.error("Failed to close producer interceptor ", e); + } + } + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/producer/internals/ProducerMetadata.java b/clients/src/main/java/org/apache/kafka/clients/producer/internals/ProducerMetadata.java new file mode 100644 index 0000000..d7a88a1 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/producer/internals/ProducerMetadata.java @@ -0,0 +1,159 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
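ProducerInterceptors, ending above, only walks a user-supplied chain and swallows interceptor exceptions; the interesting part is what a chain element looks like. The header-tagging interceptor below is a hedged example, and the class name, header key, and values are all invented.

```java
import java.nio.charset.StandardCharsets;
import java.util.Map;
import org.apache.kafka.clients.producer.ProducerInterceptor;
import org.apache.kafka.clients.producer.ProducerRecord;
import org.apache.kafka.clients.producer.RecordMetadata;

/** Registered through the interceptor.classes producer config. */
public class AuditInterceptor implements ProducerInterceptor<String, String> {

    @Override
    public void configure(Map<String, ?> configs) { }

    @Override
    public ProducerRecord<String, String> onSend(ProducerRecord<String, String> record) {
        // Runs on the sending thread before serialization; whatever is returned here is
        // what the next interceptor in the chain (and ultimately the partitioner) sees.
        record.headers().add("audit-source", "demo-app".getBytes(StandardCharsets.UTF_8));
        return record;
    }

    @Override
    public void onAcknowledgement(RecordMetadata metadata, Exception exception) {
        // Called for both success and failure on the producer I/O thread, so keep it cheap.
        if (exception != null)
            System.err.println("send failed for " + (metadata == null ? "unknown partition" : metadata));
    }

    @Override
    public void close() { }
}
```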
+ */ +package org.apache.kafka.clients.producer.internals; + +import org.apache.kafka.clients.Metadata; +import org.apache.kafka.common.KafkaException; +import org.apache.kafka.common.internals.ClusterResourceListeners; +import org.apache.kafka.common.requests.MetadataRequest; +import org.apache.kafka.common.requests.MetadataResponse; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.common.utils.Time; +import org.slf4j.Logger; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; +import java.util.Objects; +import java.util.Set; + +public class ProducerMetadata extends Metadata { + // If a topic hasn't been accessed for this many milliseconds, it is removed from the cache. + private final long metadataIdleMs; + + /* Topics with expiry time */ + private final Map topics = new HashMap<>(); + private final Set newTopics = new HashSet<>(); + private final Logger log; + private final Time time; + + public ProducerMetadata(long refreshBackoffMs, + long metadataExpireMs, + long metadataIdleMs, + LogContext logContext, + ClusterResourceListeners clusterResourceListeners, + Time time) { + super(refreshBackoffMs, metadataExpireMs, logContext, clusterResourceListeners); + this.metadataIdleMs = metadataIdleMs; + this.log = logContext.logger(ProducerMetadata.class); + this.time = time; + } + + @Override + public synchronized MetadataRequest.Builder newMetadataRequestBuilder() { + return new MetadataRequest.Builder(new ArrayList<>(topics.keySet()), true); + } + + @Override + public synchronized MetadataRequest.Builder newMetadataRequestBuilderForNewTopics() { + return new MetadataRequest.Builder(new ArrayList<>(newTopics), true); + } + + public synchronized void add(String topic, long nowMs) { + Objects.requireNonNull(topic, "topic cannot be null"); + if (topics.put(topic, nowMs + metadataIdleMs) == null) { + newTopics.add(topic); + requestUpdateForNewTopics(); + } + } + + public synchronized int requestUpdateForTopic(String topic) { + if (newTopics.contains(topic)) { + return requestUpdateForNewTopics(); + } else { + return requestUpdate(); + } + } + + // Visible for testing + synchronized Set topics() { + return topics.keySet(); + } + + // Visible for testing + synchronized Set newTopics() { + return newTopics; + } + + public synchronized boolean containsTopic(String topic) { + return topics.containsKey(topic); + } + + @Override + public synchronized boolean retainTopic(String topic, boolean isInternal, long nowMs) { + Long expireMs = topics.get(topic); + if (expireMs == null) { + return false; + } else if (newTopics.contains(topic)) { + return true; + } else if (expireMs <= nowMs) { + log.debug("Removing unused topic {} from the metadata list, expiryMs {} now {}", topic, expireMs, nowMs); + topics.remove(topic); + return false; + } else { + return true; + } + } + + /** + * Wait for metadata update until the current version is larger than the last version we know of + */ + public synchronized void awaitUpdate(final int lastVersion, final long timeoutMs) throws InterruptedException { + long currentTimeMs = time.milliseconds(); + long deadlineMs = currentTimeMs + timeoutMs < 0 ? Long.MAX_VALUE : currentTimeMs + timeoutMs; + time.waitObject(this, () -> { + // Throw fatal exceptions, if there are any. Recoverable topic errors will be handled by the caller. 
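+ // The condition below is re-evaluated each time update(), fatalError() or close() calls notifyAll() on this
+ // instance, so the wait ends as soon as a newer metadata version arrives, a fatal error is set, or the producer closes.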
+ maybeThrowFatalException(); + return updateVersion() > lastVersion || isClosed(); + }, deadlineMs); + + if (isClosed()) + throw new KafkaException("Requested metadata update after close"); + } + + @Override + public synchronized void update(int requestVersion, MetadataResponse response, boolean isPartialUpdate, long nowMs) { + super.update(requestVersion, response, isPartialUpdate, nowMs); + + // Remove all topics in the response that are in the new topic set. Note that if an error was encountered for a + // new topic's metadata, then any work to resolve the error will include the topic in a full metadata update. + if (!newTopics.isEmpty()) { + for (MetadataResponse.TopicMetadata metadata : response.topicMetadata()) { + newTopics.remove(metadata.topic()); + } + } + + notifyAll(); + } + + @Override + public synchronized void fatalError(KafkaException fatalException) { + super.fatalError(fatalException); + notifyAll(); + } + + /** + * Close this instance and notify any awaiting threads. + */ + @Override + public synchronized void close() { + super.close(); + notifyAll(); + } + +} diff --git a/clients/src/main/java/org/apache/kafka/clients/producer/internals/ProducerMetrics.java b/clients/src/main/java/org/apache/kafka/clients/producer/internals/ProducerMetrics.java new file mode 100644 index 0000000..030a232 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/producer/internals/ProducerMetrics.java @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.clients.producer.internals; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; + +import org.apache.kafka.common.MetricNameTemplate; +import org.apache.kafka.common.metrics.MetricConfig; +import org.apache.kafka.common.metrics.Metrics; + +public class ProducerMetrics { + + public final SenderMetricsRegistry senderMetrics; + private final Metrics metrics; + + public ProducerMetrics(Metrics metrics) { + this.metrics = metrics; + this.senderMetrics = new SenderMetricsRegistry(this.metrics); + } + + private List getAllTemplates() { + List l = new ArrayList<>(this.senderMetrics.allTemplates()); + return l; + } + + public static void main(String[] args) { + Map metricTags = Collections.singletonMap("client-id", "client-id"); + MetricConfig metricConfig = new MetricConfig().tags(metricTags); + Metrics metrics = new Metrics(metricConfig); + + ProducerMetrics metricsRegistry = new ProducerMetrics(metrics); + System.out.println(Metrics.toHtmlTable("kafka.producer", metricsRegistry.getAllTemplates())); + } + +} diff --git a/clients/src/main/java/org/apache/kafka/clients/producer/internals/RecordAccumulator.java b/clients/src/main/java/org/apache/kafka/clients/producer/internals/RecordAccumulator.java new file mode 100644 index 0000000..24a80b9 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/producer/internals/RecordAccumulator.java @@ -0,0 +1,846 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.clients.producer.internals; + +import java.nio.ByteBuffer; +import java.util.ArrayDeque; +import java.util.ArrayList; +import java.util.Collections; +import java.util.Deque; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Set; +import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.atomic.AtomicInteger; +import org.apache.kafka.clients.ApiVersions; +import org.apache.kafka.clients.producer.Callback; +import org.apache.kafka.common.utils.ProducerIdAndEpoch; +import org.apache.kafka.common.Cluster; +import org.apache.kafka.common.KafkaException; +import org.apache.kafka.common.MetricName; +import org.apache.kafka.common.Node; +import org.apache.kafka.common.PartitionInfo; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.errors.UnsupportedVersionException; +import org.apache.kafka.common.header.Header; +import org.apache.kafka.common.metrics.Measurable; +import org.apache.kafka.common.metrics.MetricConfig; +import org.apache.kafka.common.metrics.Metrics; +import org.apache.kafka.common.record.AbstractRecords; +import org.apache.kafka.common.record.CompressionRatioEstimator; +import org.apache.kafka.common.record.CompressionType; +import org.apache.kafka.common.record.MemoryRecords; +import org.apache.kafka.common.record.MemoryRecordsBuilder; +import org.apache.kafka.common.record.Record; +import org.apache.kafka.common.record.RecordBatch; +import org.apache.kafka.common.record.TimestampType; +import org.apache.kafka.common.utils.CopyOnWriteMap; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.common.utils.Time; +import org.slf4j.Logger; + +/** + * This class acts as a queue that accumulates records into {@link MemoryRecords} + * instances to be sent to the server. + *

        + * The accumulator uses a bounded amount of memory and append calls will block when that memory is exhausted, unless + * this behavior is explicitly disabled. + */ +public final class RecordAccumulator { + + private final Logger log; + private volatile boolean closed; + private final AtomicInteger flushesInProgress; + private final AtomicInteger appendsInProgress; + private final int batchSize; + private final CompressionType compression; + private final int lingerMs; + private final long retryBackoffMs; + private final int deliveryTimeoutMs; + private final BufferPool free; + private final Time time; + private final ApiVersions apiVersions; + private final ConcurrentMap> batches; + private final IncompleteBatches incomplete; + // The following variables are only accessed by the sender thread, so we don't need to protect them. + private final Set muted; + private int drainIndex; + private final TransactionManager transactionManager; + private long nextBatchExpiryTimeMs = Long.MAX_VALUE; // the earliest time (absolute) a batch will expire. + + /** + * Create a new record accumulator + * + * @param logContext The log context used for logging + * @param batchSize The size to use when allocating {@link MemoryRecords} instances + * @param compression The compression codec for the records + * @param lingerMs An artificial delay time to add before declaring a records instance that isn't full ready for + * sending. This allows time for more records to arrive. Setting a non-zero lingerMs will trade off some + * latency for potentially better throughput due to more batching (and hence fewer, larger requests). + * @param retryBackoffMs An artificial delay time to retry the produce request upon receiving an error. This avoids + * exhausting all retries in a short period of time. + * @param metrics The metrics + * @param time The time instance to use + * @param apiVersions Request API versions for current connected brokers + * @param transactionManager The shared transaction state object which tracks producer IDs, epochs, and sequence + * numbers per partition. 
+ */ + public RecordAccumulator(LogContext logContext, + int batchSize, + CompressionType compression, + int lingerMs, + long retryBackoffMs, + int deliveryTimeoutMs, + Metrics metrics, + String metricGrpName, + Time time, + ApiVersions apiVersions, + TransactionManager transactionManager, + BufferPool bufferPool) { + this.log = logContext.logger(RecordAccumulator.class); + this.drainIndex = 0; + this.closed = false; + this.flushesInProgress = new AtomicInteger(0); + this.appendsInProgress = new AtomicInteger(0); + this.batchSize = batchSize; + this.compression = compression; + this.lingerMs = lingerMs; + this.retryBackoffMs = retryBackoffMs; + this.deliveryTimeoutMs = deliveryTimeoutMs; + this.batches = new CopyOnWriteMap<>(); + this.free = bufferPool; + this.incomplete = new IncompleteBatches(); + this.muted = new HashSet<>(); + this.time = time; + this.apiVersions = apiVersions; + this.transactionManager = transactionManager; + registerMetrics(metrics, metricGrpName); + } + + private void registerMetrics(Metrics metrics, String metricGrpName) { + MetricName metricName = metrics.metricName("waiting-threads", metricGrpName, "The number of user threads blocked waiting for buffer memory to enqueue their records"); + Measurable waitingThreads = new Measurable() { + public double measure(MetricConfig config, long now) { + return free.queued(); + } + }; + metrics.addMetric(metricName, waitingThreads); + + metricName = metrics.metricName("buffer-total-bytes", metricGrpName, "The maximum amount of buffer memory the client can use (whether or not it is currently used)."); + Measurable totalBytes = new Measurable() { + public double measure(MetricConfig config, long now) { + return free.totalMemory(); + } + }; + metrics.addMetric(metricName, totalBytes); + + metricName = metrics.metricName("buffer-available-bytes", metricGrpName, "The total amount of buffer memory that is not being used (either unallocated or in the free list)."); + Measurable availableBytes = new Measurable() { + public double measure(MetricConfig config, long now) { + return free.availableMemory(); + } + }; + metrics.addMetric(metricName, availableBytes); + } + + /** + * Add a record to the accumulator, return the append result + *

+ * <p>
+ * The append result will contain the future metadata, and flag for whether the appended batch is full or a new batch is created
+ * <p>

        + * + * @param tp The topic/partition to which this record is being sent + * @param timestamp The timestamp of the record + * @param key The key for the record + * @param value The value for the record + * @param headers the Headers for the record + * @param callback The user-supplied callback to execute when the request is complete + * @param maxTimeToBlock The maximum time in milliseconds to block for buffer memory to be available + * @param abortOnNewBatch A boolean that indicates returning before a new batch is created and + * running the partitioner's onNewBatch method before trying to append again + * @param nowMs The current time, in milliseconds + */ + public RecordAppendResult append(TopicPartition tp, + long timestamp, + byte[] key, + byte[] value, + Header[] headers, + Callback callback, + long maxTimeToBlock, + boolean abortOnNewBatch, + long nowMs) throws InterruptedException { + // We keep track of the number of appending thread to make sure we do not miss batches in + // abortIncompleteBatches(). + appendsInProgress.incrementAndGet(); + ByteBuffer buffer = null; + if (headers == null) headers = Record.EMPTY_HEADERS; + try { + // check if we have an in-progress batch + Deque dq = getOrCreateDeque(tp); + synchronized (dq) { + if (closed) + throw new KafkaException("Producer closed while send in progress"); + RecordAppendResult appendResult = tryAppend(timestamp, key, value, headers, callback, dq, nowMs); + if (appendResult != null) + return appendResult; + } + + // we don't have an in-progress record batch try to allocate a new batch + if (abortOnNewBatch) { + // Return a result that will cause another call to append. + return new RecordAppendResult(null, false, false, true); + } + + byte maxUsableMagic = apiVersions.maxUsableProduceMagic(); + int size = Math.max(this.batchSize, AbstractRecords.estimateSizeInBytesUpperBound(maxUsableMagic, compression, key, value, headers)); + log.trace("Allocating a new {} byte message buffer for topic {} partition {} with remaining timeout {}ms", size, tp.topic(), tp.partition(), maxTimeToBlock); + buffer = free.allocate(size, maxTimeToBlock); + + // Update the current time in case the buffer allocation blocked above. + nowMs = time.milliseconds(); + synchronized (dq) { + // Need to check if producer is closed again after grabbing the dequeue lock. + if (closed) + throw new KafkaException("Producer closed while send in progress"); + + RecordAppendResult appendResult = tryAppend(timestamp, key, value, headers, callback, dq, nowMs); + if (appendResult != null) { + // Somebody else found us a batch, return the one we waited for! Hopefully this doesn't happen often... 
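+ // The buffer allocated above was not handed to a batch in this path, so the finally block below returns it to the pool.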
+ return appendResult; + } + + MemoryRecordsBuilder recordsBuilder = recordsBuilder(buffer, maxUsableMagic); + ProducerBatch batch = new ProducerBatch(tp, recordsBuilder, nowMs); + FutureRecordMetadata future = Objects.requireNonNull(batch.tryAppend(timestamp, key, value, headers, + callback, nowMs)); + + dq.addLast(batch); + incomplete.add(batch); + + // Don't deallocate this buffer in the finally block as it's being used in the record batch + buffer = null; + return new RecordAppendResult(future, dq.size() > 1 || batch.isFull(), true, false); + } + } finally { + if (buffer != null) + free.deallocate(buffer); + appendsInProgress.decrementAndGet(); + } + } + + private MemoryRecordsBuilder recordsBuilder(ByteBuffer buffer, byte maxUsableMagic) { + if (transactionManager != null && maxUsableMagic < RecordBatch.MAGIC_VALUE_V2) { + throw new UnsupportedVersionException("Attempting to use idempotence with a broker which does not " + + "support the required message format (v2). The broker must be version 0.11 or later."); + } + return MemoryRecords.builder(buffer, maxUsableMagic, compression, TimestampType.CREATE_TIME, 0L); + } + + /** + * Try to append to a ProducerBatch. + * + * If it is full, we return null and a new batch is created. We also close the batch for record appends to free up + * resources like compression buffers. The batch will be fully closed (ie. the record batch headers will be written + * and memory records built) in one of the following cases (whichever comes first): right before send, + * if it is expired, or when the producer is closed. + */ + private RecordAppendResult tryAppend(long timestamp, byte[] key, byte[] value, Header[] headers, + Callback callback, Deque deque, long nowMs) { + ProducerBatch last = deque.peekLast(); + if (last != null) { + FutureRecordMetadata future = last.tryAppend(timestamp, key, value, headers, callback, nowMs); + if (future == null) + last.closeForRecordAppends(); + else + return new RecordAppendResult(future, deque.size() > 1 || last.isFull(), false, false); + } + return null; + } + + private boolean isMuted(TopicPartition tp) { + return muted.contains(tp); + } + + public void resetNextBatchExpiryTime() { + nextBatchExpiryTimeMs = Long.MAX_VALUE; + } + + public void maybeUpdateNextBatchExpiryTime(ProducerBatch batch) { + if (batch.createdMs + deliveryTimeoutMs > 0) { + // the non-negative check is to guard us against potential overflow due to setting + // a large value for deliveryTimeoutMs + nextBatchExpiryTimeMs = Math.min(nextBatchExpiryTimeMs, batch.createdMs + deliveryTimeoutMs); + } else { + log.warn("Skipping next batch expiry time update due to addition overflow: " + + "batch.createMs={}, deliveryTimeoutMs={}", batch.createdMs, deliveryTimeoutMs); + } + } + + /** + * Get a list of batches which have been sitting in the accumulator too long and need to be expired. 
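+ * A batch is considered expired once the configured delivery timeout has elapsed since its creation
+ * (see {@link ProducerBatch#hasReachedDeliveryTimeout}).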
+ */ + public List expiredBatches(long now) { + List expiredBatches = new ArrayList<>(); + for (Map.Entry> entry : this.batches.entrySet()) { + // expire the batches in the order of sending + Deque deque = entry.getValue(); + synchronized (deque) { + while (!deque.isEmpty()) { + ProducerBatch batch = deque.getFirst(); + if (batch.hasReachedDeliveryTimeout(deliveryTimeoutMs, now)) { + deque.poll(); + batch.abortRecordAppends(); + expiredBatches.add(batch); + } else { + maybeUpdateNextBatchExpiryTime(batch); + break; + } + } + } + } + return expiredBatches; + } + + public long getDeliveryTimeoutMs() { + return deliveryTimeoutMs; + } + + /** + * Re-enqueue the given record batch in the accumulator. In Sender.completeBatch method, we check + * whether the batch has reached deliveryTimeoutMs or not. Hence we do not do the delivery timeout check here. + */ + public void reenqueue(ProducerBatch batch, long now) { + batch.reenqueued(now); + Deque deque = getOrCreateDeque(batch.topicPartition); + synchronized (deque) { + if (transactionManager != null) + insertInSequenceOrder(deque, batch); + else + deque.addFirst(batch); + } + } + + /** + * Split the big batch that has been rejected and reenqueue the split batches in to the accumulator. + * @return the number of split batches. + */ + public int splitAndReenqueue(ProducerBatch bigBatch) { + // Reset the estimated compression ratio to the initial value or the big batch compression ratio, whichever + // is bigger. There are several different ways to do the reset. We chose the most conservative one to ensure + // the split doesn't happen too often. + CompressionRatioEstimator.setEstimation(bigBatch.topicPartition.topic(), compression, + Math.max(1.0f, (float) bigBatch.compressionRatio())); + Deque dq = bigBatch.split(this.batchSize); + int numSplitBatches = dq.size(); + Deque partitionDequeue = getOrCreateDeque(bigBatch.topicPartition); + while (!dq.isEmpty()) { + ProducerBatch batch = dq.pollLast(); + incomplete.add(batch); + // We treat the newly split batches as if they are not even tried. + synchronized (partitionDequeue) { + if (transactionManager != null) { + // We should track the newly created batches since they already have assigned sequences. + transactionManager.addInFlightBatch(batch); + insertInSequenceOrder(partitionDequeue, batch); + } else { + partitionDequeue.addFirst(batch); + } + } + } + return numSplitBatches; + } + + // We will have to do extra work to ensure the queue is in order when requests are being retried and there are + // multiple requests in flight to that partition. If the first in flight request fails to append, then all the + // subsequent in flight requests will also fail because the sequence numbers will not be accepted. + // + // Further, once batches are being retried, we are reduced to a single in flight request for that partition. So when + // the subsequent batches come back in sequence order, they will have to be placed further back in the queue. + // + // Note that this assumes that all the batches in the queue which have an assigned sequence also have the current + // producer id. We will not attempt to reorder messages if the producer id has changed, we will throw an + // IllegalStateException instead. + private void insertInSequenceOrder(Deque deque, ProducerBatch batch) { + // When we are requeing and have enabled idempotence, the reenqueued batch must always have a sequence. 
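+ // (A batch is only reenqueued after it has been drained at least once, and a sequence is assigned at drain time.)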
+ if (batch.baseSequence() == RecordBatch.NO_SEQUENCE) + throw new IllegalStateException("Trying to re-enqueue a batch which doesn't have a sequence even " + + "though idempotency is enabled."); + + if (transactionManager.nextBatchBySequence(batch.topicPartition) == null) + throw new IllegalStateException("We are re-enqueueing a batch which is not tracked as part of the in flight " + + "requests. batch.topicPartition: " + batch.topicPartition + "; batch.baseSequence: " + batch.baseSequence()); + + ProducerBatch firstBatchInQueue = deque.peekFirst(); + if (firstBatchInQueue != null && firstBatchInQueue.hasSequence() && firstBatchInQueue.baseSequence() < batch.baseSequence()) { + // The incoming batch can't be inserted at the front of the queue without violating the sequence ordering. + // This means that the incoming batch should be placed somewhere further back. + // We need to find the right place for the incoming batch and insert it there. + // We will only enter this branch if we have multiple inflights sent to different brokers and we need to retry + // the inflight batches. + // + // Since we reenqueue exactly one batch a time and ensure that the queue is ordered by sequence always, it + // is a simple linear scan of a subset of the in flight batches to find the right place in the queue each time. + List orderedBatches = new ArrayList<>(); + while (deque.peekFirst() != null && deque.peekFirst().hasSequence() && deque.peekFirst().baseSequence() < batch.baseSequence()) + orderedBatches.add(deque.pollFirst()); + + log.debug("Reordered incoming batch with sequence {} for partition {}. It was placed in the queue at " + + "position {}", batch.baseSequence(), batch.topicPartition, orderedBatches.size()); + // Either we have reached a point where there are batches without a sequence (ie. never been drained + // and are hence in order by default), or the batch at the front of the queue has a sequence greater + // than the incoming batch. This is the right place to add the incoming batch. + deque.addFirst(batch); + + // Now we have to re insert the previously queued batches in the right order. + for (int i = orderedBatches.size() - 1; i >= 0; --i) { + deque.addFirst(orderedBatches.get(i)); + } + + // At this point, the incoming batch has been queued in the correct place according to its sequence. + } else { + deque.addFirst(batch); + } + } + + /** + * Get a list of nodes whose partitions are ready to be sent, and the earliest time at which any non-sendable + * partition will be ready; Also return the flag for whether there are any unknown leaders for the accumulated + * partition batches. + *

+ * <p>
+ * A destination node is ready to send data if:
+ * <ol>
+ * <li>There is at least one partition that is not backing off its send
+ * <li><b>and</b> those partitions are not muted (to prevent reordering if
+ *   {@value org.apache.kafka.clients.producer.ProducerConfig#MAX_IN_FLIGHT_REQUESTS_PER_CONNECTION}
+ *   is set to one)</li>
+ * <li><b>and <i>any</i></b> of the following are true</li>
+ * <ul>
+ *     <li>The record set is full</li>
+ *     <li>The record set has sat in the accumulator for at least lingerMs milliseconds</li>
+ *     <li>The accumulator is out of memory and threads are blocking waiting for data (in this case all partitions
+ *     are immediately considered ready).</li>
+ *     <li>The accumulator has been closed</li>
+ * </ul>
+ * </ol>
        + */ + public ReadyCheckResult ready(Cluster cluster, long nowMs) { + Set readyNodes = new HashSet<>(); + long nextReadyCheckDelayMs = Long.MAX_VALUE; + Set unknownLeaderTopics = new HashSet<>(); + + boolean exhausted = this.free.queued() > 0; + for (Map.Entry> entry : this.batches.entrySet()) { + Deque deque = entry.getValue(); + synchronized (deque) { + // When producing to a large number of partitions, this path is hot and deques are often empty. + // We check whether a batch exists first to avoid the more expensive checks whenever possible. + ProducerBatch batch = deque.peekFirst(); + if (batch != null) { + TopicPartition part = entry.getKey(); + Node leader = cluster.leaderFor(part); + if (leader == null) { + // This is a partition for which leader is not known, but messages are available to send. + // Note that entries are currently not removed from batches when deque is empty. + unknownLeaderTopics.add(part.topic()); + } else if (!readyNodes.contains(leader) && !isMuted(part)) { + long waitedTimeMs = batch.waitedTimeMs(nowMs); + boolean backingOff = batch.attempts() > 0 && waitedTimeMs < retryBackoffMs; + long timeToWaitMs = backingOff ? retryBackoffMs : lingerMs; + boolean full = deque.size() > 1 || batch.isFull(); + boolean expired = waitedTimeMs >= timeToWaitMs; + boolean transactionCompleting = transactionManager != null && transactionManager.isCompleting(); + boolean sendable = full + || expired + || exhausted + || closed + || flushInProgress() + || transactionCompleting; + if (sendable && !backingOff) { + readyNodes.add(leader); + } else { + long timeLeftMs = Math.max(timeToWaitMs - waitedTimeMs, 0); + // Note that this results in a conservative estimate since an un-sendable partition may have + // a leader that will later be found to have sendable data. However, this is good enough + // since we'll just wake up and then sleep again for the remaining time. + nextReadyCheckDelayMs = Math.min(timeLeftMs, nextReadyCheckDelayMs); + } + } + } + } + } + return new ReadyCheckResult(readyNodes, nextReadyCheckDelayMs, unknownLeaderTopics); + } + + /** + * Check whether there are any batches which haven't been drained + */ + public boolean hasUndrained() { + for (Map.Entry> entry : this.batches.entrySet()) { + Deque deque = entry.getValue(); + synchronized (deque) { + if (!deque.isEmpty()) + return true; + } + } + return false; + } + + private boolean shouldStopDrainBatchesForPartition(ProducerBatch first, TopicPartition tp) { + ProducerIdAndEpoch producerIdAndEpoch = null; + if (transactionManager != null) { + if (!transactionManager.isSendToPartitionAllowed(tp)) + return true; + + producerIdAndEpoch = transactionManager.producerIdAndEpoch(); + if (!producerIdAndEpoch.isValid()) + // we cannot send the batch until we have refreshed the producer id + return true; + + if (!first.hasSequence()) { + if (transactionManager.hasInflightBatches(tp) && transactionManager.hasStaleProducerIdAndEpoch(tp)) { + // Don't drain any new batches while the partition has in-flight batches with a different epoch + // and/or producer ID. Otherwise, a batch with a new epoch and sequence number + // 0 could be written before earlier batches complete, which would cause out of sequence errors + return true; + } + + if (transactionManager.hasUnresolvedSequence(first.topicPartition)) + // Don't drain any new batches while the state of previous sequence numbers + // is unknown. The previous batches would be unknown if they were aborted + // on the client after being sent to the broker at least once. 
+ return true; + } + + int firstInFlightSequence = transactionManager.firstInFlightSequence(first.topicPartition); + if (firstInFlightSequence != RecordBatch.NO_SEQUENCE && first.hasSequence() + && first.baseSequence() != firstInFlightSequence) + // If the queued batch already has an assigned sequence, then it is being retried. + // In this case, we wait until the next immediate batch is ready and drain that. + // We only move on when the next in line batch is complete (either successfully or due to + // a fatal broker error). This effectively reduces our in flight request count to 1. + return true; + } + return false; + } + + private List drainBatchesForOneNode(Cluster cluster, Node node, int maxSize, long now) { + int size = 0; + List parts = cluster.partitionsForNode(node.id()); + List ready = new ArrayList<>(); + /* to make starvation less likely this loop doesn't start at 0 */ + int start = drainIndex = drainIndex % parts.size(); + do { + PartitionInfo part = parts.get(drainIndex); + TopicPartition tp = new TopicPartition(part.topic(), part.partition()); + this.drainIndex = (this.drainIndex + 1) % parts.size(); + + // Only proceed if the partition has no in-flight batches. + if (isMuted(tp)) + continue; + + Deque deque = getDeque(tp); + if (deque == null) + continue; + + synchronized (deque) { + // invariant: !isMuted(tp,now) && deque != null + ProducerBatch first = deque.peekFirst(); + if (first == null) + continue; + + // first != null + boolean backoff = first.attempts() > 0 && first.waitedTimeMs(now) < retryBackoffMs; + // Only drain the batch if it is not during backoff period. + if (backoff) + continue; + + if (size + first.estimatedSizeInBytes() > maxSize && !ready.isEmpty()) { + // there is a rare case that a single batch size is larger than the request size due to + // compression; in this case we will still eventually send this batch in a single request + break; + } else { + if (shouldStopDrainBatchesForPartition(first, tp)) + break; + + boolean isTransactional = transactionManager != null && transactionManager.isTransactional(); + ProducerIdAndEpoch producerIdAndEpoch = + transactionManager != null ? transactionManager.producerIdAndEpoch() : null; + ProducerBatch batch = deque.pollFirst(); + if (producerIdAndEpoch != null && !batch.hasSequence()) { + // If the producer id/epoch of the partition do not match the latest one + // of the producer, we update it and reset the sequence. This should be + // only done when all its in-flight batches have completed. This is guarantee + // in `shouldStopDrainBatchesForPartition`. + transactionManager.maybeUpdateProducerIdAndEpoch(batch.topicPartition); + + // If the batch already has an assigned sequence, then we should not change the producer id and + // sequence number, since this may introduce duplicates. In particular, the previous attempt + // may actually have been accepted, and if we change the producer id and sequence here, this + // attempt will also be accepted, causing a duplicate. + // + // Additionally, we update the next sequence number bound for the partition, and also have + // the transaction manager track the batch so as to ensure that sequence ordering is maintained + // even if we receive out of order responses. 
+ batch.setProducerState(producerIdAndEpoch, transactionManager.sequenceNumber(batch.topicPartition), isTransactional); + transactionManager.incrementSequenceNumber(batch.topicPartition, batch.recordCount); + log.debug("Assigned producerId {} and producerEpoch {} to batch with base sequence " + + "{} being sent to partition {}", producerIdAndEpoch.producerId, + producerIdAndEpoch.epoch, batch.baseSequence(), tp); + + transactionManager.addInFlightBatch(batch); + } + batch.close(); + size += batch.records().sizeInBytes(); + ready.add(batch); + + batch.drained(now); + } + } + } while (start != drainIndex); + return ready; + } + + /** + * Drain all the data for the given nodes and collate them into a list of batches that will fit within the specified + * size on a per-node basis. This method attempts to avoid choosing the same topic-node over and over. + * + * @param cluster The current cluster metadata + * @param nodes The list of node to drain + * @param maxSize The maximum number of bytes to drain + * @param now The current unix time in milliseconds + * @return A list of {@link ProducerBatch} for each node specified with total size less than the requested maxSize. + */ + public Map> drain(Cluster cluster, Set nodes, int maxSize, long now) { + if (nodes.isEmpty()) + return Collections.emptyMap(); + + Map> batches = new HashMap<>(); + for (Node node : nodes) { + List ready = drainBatchesForOneNode(cluster, node, maxSize, now); + batches.put(node.id(), ready); + } + return batches; + } + + /** + * The earliest absolute time a batch will expire (in milliseconds) + */ + public long nextExpiryTimeMs() { + return this.nextBatchExpiryTimeMs; + } + + private Deque getDeque(TopicPartition tp) { + return batches.get(tp); + } + + /** + * Get the deque for the given topic-partition, creating it if necessary. + */ + private Deque getOrCreateDeque(TopicPartition tp) { + Deque d = this.batches.get(tp); + if (d != null) + return d; + d = new ArrayDeque<>(); + Deque previous = this.batches.putIfAbsent(tp, d); + if (previous == null) + return d; + else + return previous; + } + + /** + * Deallocate the record batch + */ + public void deallocate(ProducerBatch batch) { + incomplete.remove(batch); + // Only deallocate the batch if it is not a split batch because split batch are allocated outside the + // buffer pool. + if (!batch.isSplitBatch()) + free.deallocate(batch.buffer(), batch.initialCapacity()); + } + + /** + * Package private for unit test. Get the buffer pool remaining size in bytes. + */ + long bufferPoolAvailableMemory() { + return free.availableMemory(); + } + + /** + * Are there any threads currently waiting on a flush? + * + * package private for test + */ + boolean flushInProgress() { + return flushesInProgress.get() > 0; + } + + /* Visible for testing */ + Map> batches() { + return Collections.unmodifiableMap(batches); + } + + /** + * Initiate the flushing of data from the accumulator...this makes all requests immediately ready + */ + public void beginFlush() { + this.flushesInProgress.getAndIncrement(); + } + + /** + * Are there any threads currently appending messages? + */ + private boolean appendsInProgress() { + return appendsInProgress.get() > 0; + } + + /** + * Mark all partitions as ready to send and block until the send is complete + */ + public void awaitFlushCompletion() throws InterruptedException { + try { + // Obtain a copy of all of the incomplete ProduceRequestResult(s) at the time of the flush. 
+ // We must be careful not to hold a reference to the ProduceBatch(s) so that garbage + // collection can occur on the contents. + // The sender will remove ProducerBatch(s) from the original incomplete collection. + for (ProduceRequestResult result : this.incomplete.requestResults()) + result.await(); + } finally { + this.flushesInProgress.decrementAndGet(); + } + } + + /** + * Check whether there are any pending batches (whether sent or unsent). + */ + public boolean hasIncomplete() { + return !this.incomplete.isEmpty(); + } + + /** + * This function is only called when sender is closed forcefully. It will fail all the + * incomplete batches and return. + */ + public void abortIncompleteBatches() { + // We need to keep aborting the incomplete batch until no thread is trying to append to + // 1. Avoid losing batches. + // 2. Free up memory in case appending threads are blocked on buffer full. + // This is a tight loop but should be able to get through very quickly. + do { + abortBatches(); + } while (appendsInProgress()); + // After this point, no thread will append any messages because they will see the close + // flag set. We need to do the last abort after no thread was appending in case there was a new + // batch appended by the last appending thread. + abortBatches(); + this.batches.clear(); + } + + /** + * Go through incomplete batches and abort them. + */ + private void abortBatches() { + abortBatches(new KafkaException("Producer is closed forcefully.")); + } + + /** + * Abort all incomplete batches (whether they have been sent or not) + */ + void abortBatches(final RuntimeException reason) { + for (ProducerBatch batch : incomplete.copyAll()) { + Deque dq = getDeque(batch.topicPartition); + synchronized (dq) { + batch.abortRecordAppends(); + dq.remove(batch); + } + batch.abort(reason); + deallocate(batch); + } + } + + /** + * Abort any batches which have not been drained + */ + void abortUndrainedBatches(RuntimeException reason) { + for (ProducerBatch batch : incomplete.copyAll()) { + Deque dq = getDeque(batch.topicPartition); + boolean aborted = false; + synchronized (dq) { + if ((transactionManager != null && !batch.hasSequence()) || (transactionManager == null && !batch.isClosed())) { + aborted = true; + batch.abortRecordAppends(); + dq.remove(batch); + } + } + if (aborted) { + batch.abort(reason); + deallocate(batch); + } + } + } + + public void mutePartition(TopicPartition tp) { + muted.add(tp); + } + + public void unmutePartition(TopicPartition tp) { + muted.remove(tp); + } + + /** + * Close this accumulator and force all the record buffers to be drained + */ + public void close() { + this.closed = true; + this.free.close(); + } + + /* + * Metadata about a record just appended to the record accumulator + */ + public final static class RecordAppendResult { + public final FutureRecordMetadata future; + public final boolean batchIsFull; + public final boolean newBatchCreated; + public final boolean abortForNewBatch; + + public RecordAppendResult(FutureRecordMetadata future, boolean batchIsFull, boolean newBatchCreated, boolean abortForNewBatch) { + this.future = future; + this.batchIsFull = batchIsFull; + this.newBatchCreated = newBatchCreated; + this.abortForNewBatch = abortForNewBatch; + } + } + + /* + * The set of nodes that have at least one complete record batch in the accumulator + */ + public final static class ReadyCheckResult { + public final Set readyNodes; + public final long nextReadyCheckDelayMs; + public final Set unknownLeaderTopics; + + public 
ReadyCheckResult(Set readyNodes, long nextReadyCheckDelayMs, Set unknownLeaderTopics) { + this.readyNodes = readyNodes; + this.nextReadyCheckDelayMs = nextReadyCheckDelayMs; + this.unknownLeaderTopics = unknownLeaderTopics; + } + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/producer/internals/Sender.java b/clients/src/main/java/org/apache/kafka/clients/producer/internals/Sender.java new file mode 100644 index 0000000..2f55e62 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/producer/internals/Sender.java @@ -0,0 +1,1023 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients.producer.internals; + +import org.apache.kafka.clients.ApiVersions; +import org.apache.kafka.clients.ClientRequest; +import org.apache.kafka.clients.ClientResponse; +import org.apache.kafka.clients.KafkaClient; +import org.apache.kafka.clients.Metadata; +import org.apache.kafka.clients.NetworkClientUtils; +import org.apache.kafka.clients.RequestCompletionHandler; +import org.apache.kafka.common.Cluster; +import org.apache.kafka.common.InvalidRecordException; +import org.apache.kafka.common.KafkaException; +import org.apache.kafka.common.MetricName; +import org.apache.kafka.common.Node; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.errors.AuthenticationException; +import org.apache.kafka.common.errors.ClusterAuthorizationException; +import org.apache.kafka.common.errors.InvalidMetadataException; +import org.apache.kafka.common.errors.RetriableException; +import org.apache.kafka.common.errors.TimeoutException; +import org.apache.kafka.common.errors.TopicAuthorizationException; +import org.apache.kafka.common.errors.TransactionAbortedException; +import org.apache.kafka.common.errors.UnknownTopicOrPartitionException; +import org.apache.kafka.common.message.ProduceRequestData; +import org.apache.kafka.common.metrics.Sensor; +import org.apache.kafka.common.metrics.stats.Avg; +import org.apache.kafka.common.metrics.stats.Max; +import org.apache.kafka.common.metrics.stats.Meter; +import org.apache.kafka.common.protocol.Errors; +import org.apache.kafka.common.record.MemoryRecords; +import org.apache.kafka.common.record.RecordBatch; +import org.apache.kafka.common.requests.AbstractRequest; +import org.apache.kafka.common.requests.FindCoordinatorRequest; +import org.apache.kafka.common.requests.ProduceRequest; +import org.apache.kafka.common.requests.ProduceResponse; +import org.apache.kafka.common.requests.RequestHeader; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.common.utils.Time; +import org.slf4j.Logger; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import 
java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.function.Function; +import java.util.stream.Collectors; + +/** + * The background thread that handles the sending of produce requests to the Kafka cluster. This thread makes metadata + * requests to renew its view of the cluster and then sends produce requests to the appropriate nodes. + */ +public class Sender implements Runnable { + + private final Logger log; + + /* the state of each nodes connection */ + private final KafkaClient client; + + /* the record accumulator that batches records */ + private final RecordAccumulator accumulator; + + /* the metadata for the client */ + private final ProducerMetadata metadata; + + /* the flag indicating whether the producer should guarantee the message order on the broker or not. */ + private final boolean guaranteeMessageOrder; + + /* the maximum request size to attempt to send to the server */ + private final int maxRequestSize; + + /* the number of acknowledgements to request from the server */ + private final short acks; + + /* the number of times to retry a failed request before giving up */ + private final int retries; + + /* the clock instance used for getting the time */ + private final Time time; + + /* true while the sender thread is still running */ + private volatile boolean running; + + /* true when the caller wants to ignore all unsent/inflight messages and force close. */ + private volatile boolean forceClose; + + /* metrics */ + private final SenderMetrics sensors; + + /* the max time to wait for the server to respond to the request*/ + private final int requestTimeoutMs; + + /* The max time to wait before retrying a request which has failed */ + private final long retryBackoffMs; + + /* current request API versions supported by the known brokers */ + private final ApiVersions apiVersions; + + /* all the state related to transactions, in particular the producer id, producer epoch, and sequence numbers */ + private final TransactionManager transactionManager; + + // A per-partition queue of batches ordered by creation time for tracking the in-flight batches + private final Map> inFlightBatches; + + public Sender(LogContext logContext, + KafkaClient client, + ProducerMetadata metadata, + RecordAccumulator accumulator, + boolean guaranteeMessageOrder, + int maxRequestSize, + short acks, + int retries, + SenderMetricsRegistry metricsRegistry, + Time time, + int requestTimeoutMs, + long retryBackoffMs, + TransactionManager transactionManager, + ApiVersions apiVersions) { + this.log = logContext.logger(Sender.class); + this.client = client; + this.accumulator = accumulator; + this.metadata = metadata; + this.guaranteeMessageOrder = guaranteeMessageOrder; + this.maxRequestSize = maxRequestSize; + this.running = true; + this.acks = acks; + this.retries = retries; + this.time = time; + this.sensors = new SenderMetrics(metricsRegistry, metadata, client, time); + this.requestTimeoutMs = requestTimeoutMs; + this.retryBackoffMs = retryBackoffMs; + this.apiVersions = apiVersions; + this.transactionManager = transactionManager; + this.inFlightBatches = new HashMap<>(); + } + + public List inFlightBatches(TopicPartition tp) { + return inFlightBatches.containsKey(tp) ? 
inFlightBatches.get(tp) : new ArrayList<>(); + } + + private void maybeRemoveFromInflightBatches(ProducerBatch batch) { + List batches = inFlightBatches.get(batch.topicPartition); + if (batches != null) { + batches.remove(batch); + if (batches.isEmpty()) { + inFlightBatches.remove(batch.topicPartition); + } + } + } + + private void maybeRemoveAndDeallocateBatch(ProducerBatch batch) { + maybeRemoveFromInflightBatches(batch); + this.accumulator.deallocate(batch); + } + + /** + * Get the in-flight batches that has reached delivery timeout. + */ + private List getExpiredInflightBatches(long now) { + List expiredBatches = new ArrayList<>(); + + for (Iterator>> batchIt = inFlightBatches.entrySet().iterator(); batchIt.hasNext();) { + Map.Entry> entry = batchIt.next(); + List partitionInFlightBatches = entry.getValue(); + if (partitionInFlightBatches != null) { + Iterator iter = partitionInFlightBatches.iterator(); + while (iter.hasNext()) { + ProducerBatch batch = iter.next(); + if (batch.hasReachedDeliveryTimeout(accumulator.getDeliveryTimeoutMs(), now)) { + iter.remove(); + // expireBatches is called in Sender.sendProducerData, before client.poll. + // The !batch.isDone() invariant should always hold. An IllegalStateException + // exception will be thrown if the invariant is violated. + if (!batch.isDone()) { + expiredBatches.add(batch); + } else { + throw new IllegalStateException(batch.topicPartition + " batch created at " + + batch.createdMs + " gets unexpected final state " + batch.finalState()); + } + } else { + accumulator.maybeUpdateNextBatchExpiryTime(batch); + break; + } + } + if (partitionInFlightBatches.isEmpty()) { + batchIt.remove(); + } + } + } + return expiredBatches; + } + + private void addToInflightBatches(List batches) { + for (ProducerBatch batch : batches) { + List inflightBatchList = inFlightBatches.get(batch.topicPartition); + if (inflightBatchList == null) { + inflightBatchList = new ArrayList<>(); + inFlightBatches.put(batch.topicPartition, inflightBatchList); + } + inflightBatchList.add(batch); + } + } + + public void addToInflightBatches(Map> batches) { + for (List batchList : batches.values()) { + addToInflightBatches(batchList); + } + } + + private boolean hasPendingTransactionalRequests() { + return transactionManager != null && transactionManager.hasPendingRequests() && transactionManager.hasOngoingTransaction(); + } + + /** + * The main run loop for the sender thread + */ + @Override + public void run() { + log.debug("Starting Kafka producer I/O thread."); + + // main loop, runs until close is called + while (running) { + try { + runOnce(); + } catch (Exception e) { + log.error("Uncaught error in kafka producer I/O thread: ", e); + } + } + + log.debug("Beginning shutdown of Kafka producer I/O thread, sending remaining records."); + + // okay we stopped accepting requests but there may still be + // requests in the transaction manager, accumulator or waiting for acknowledgment, + // wait until these are completed. 
+ while (!forceClose && ((this.accumulator.hasUndrained() || this.client.inFlightRequestCount() > 0) || hasPendingTransactionalRequests())) { + try { + runOnce(); + } catch (Exception e) { + log.error("Uncaught error in kafka producer I/O thread: ", e); + } + } + + // Abort the transaction if any commit or abort didn't go through the transaction manager's queue + while (!forceClose && transactionManager != null && transactionManager.hasOngoingTransaction()) { + if (!transactionManager.isCompleting()) { + log.info("Aborting incomplete transaction due to shutdown"); + transactionManager.beginAbort(); + } + try { + runOnce(); + } catch (Exception e) { + log.error("Uncaught error in kafka producer I/O thread: ", e); + } + } + + if (forceClose) { + // We need to fail all the incomplete transactional requests and batches and wake up the threads waiting on + // the futures. + if (transactionManager != null) { + log.debug("Aborting incomplete transactional requests due to forced shutdown"); + transactionManager.close(); + } + log.debug("Aborting incomplete batches due to forced shutdown"); + this.accumulator.abortIncompleteBatches(); + } + try { + this.client.close(); + } catch (Exception e) { + log.error("Failed to close network client", e); + } + + log.debug("Shutdown of Kafka producer I/O thread has completed."); + } + + /** + * Run a single iteration of sending + * + */ + void runOnce() { + if (transactionManager != null) { + try { + transactionManager.maybeResolveSequences(); + + // do not continue sending if the transaction manager is in a failed state + if (transactionManager.hasFatalError()) { + RuntimeException lastError = transactionManager.lastError(); + if (lastError != null) + maybeAbortBatches(lastError); + client.poll(retryBackoffMs, time.milliseconds()); + return; + } + + // Check whether we need a new producerId. If so, we will enqueue an InitProducerId + // request which will be sent below + transactionManager.bumpIdempotentEpochAndResetIdIfNeeded(); + + if (maybeSendAndPollTransactionalRequest()) { + return; + } + } catch (AuthenticationException e) { + // This is already logged as error, but propagated here to perform any clean ups. + log.trace("Authentication exception while processing transactional request", e); + transactionManager.authenticationFailed(e); + } + } + + long currentTimeMs = time.milliseconds(); + long pollTimeout = sendProducerData(currentTimeMs); + client.poll(pollTimeout, currentTimeMs); + } + + private long sendProducerData(long now) { + Cluster cluster = metadata.fetch(); + // get the list of partitions with data ready to send + RecordAccumulator.ReadyCheckResult result = this.accumulator.ready(cluster, now); + + // if there are any partitions whose leaders are not known yet, force metadata update + if (!result.unknownLeaderTopics.isEmpty()) { + // The set of topics with unknown leader contains topics with leader election pending as well as + // topics which may have expired. Add the topic again to metadata to ensure it is included + // and request metadata update, since there are messages to send to the topic. 
+ for (String topic : result.unknownLeaderTopics) + this.metadata.add(topic, now); + + log.debug("Requesting metadata update due to unknown leader topics from the batched records: {}", + result.unknownLeaderTopics); + this.metadata.requestUpdate(); + } + + // remove any nodes we aren't ready to send to + Iterator iter = result.readyNodes.iterator(); + long notReadyTimeout = Long.MAX_VALUE; + while (iter.hasNext()) { + Node node = iter.next(); + if (!this.client.ready(node, now)) { + iter.remove(); + notReadyTimeout = Math.min(notReadyTimeout, this.client.pollDelayMs(node, now)); + } + } + + // create produce requests + Map> batches = this.accumulator.drain(cluster, result.readyNodes, this.maxRequestSize, now); + addToInflightBatches(batches); + if (guaranteeMessageOrder) { + // Mute all the partitions drained + for (List batchList : batches.values()) { + for (ProducerBatch batch : batchList) + this.accumulator.mutePartition(batch.topicPartition); + } + } + + accumulator.resetNextBatchExpiryTime(); + List expiredInflightBatches = getExpiredInflightBatches(now); + List expiredBatches = this.accumulator.expiredBatches(now); + expiredBatches.addAll(expiredInflightBatches); + + // Reset the producer id if an expired batch has previously been sent to the broker. Also update the metrics + // for expired batches. see the documentation of @TransactionState.resetIdempotentProducerId to understand why + // we need to reset the producer id here. + if (!expiredBatches.isEmpty()) + log.trace("Expired {} batches in accumulator", expiredBatches.size()); + for (ProducerBatch expiredBatch : expiredBatches) { + String errorMessage = "Expiring " + expiredBatch.recordCount + " record(s) for " + expiredBatch.topicPartition + + ":" + (now - expiredBatch.createdMs) + " ms has passed since batch creation"; + failBatch(expiredBatch, new TimeoutException(errorMessage), false); + if (transactionManager != null && expiredBatch.inRetry()) { + // This ensures that no new batches are drained until the current in flight batches are fully resolved. + transactionManager.markSequenceUnresolved(expiredBatch); + } + } + sensors.updateProduceRequestMetrics(batches); + + // If we have any nodes that are ready to send + have sendable data, poll with 0 timeout so this can immediately + // loop and try sending more data. Otherwise, the timeout will be the smaller value between next batch expiry + // time, and the delay time for checking data availability. Note that the nodes may have data that isn't yet + // sendable due to lingering, backing off, etc. This specifically does not include nodes with sendable data + // that aren't ready to send since they would cause busy looping. 
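+ // Illustrative example (hypothetical numbers): with linger.ms=5 and the oldest batch created 2 ms ago and nothing
+ // else pending, nextReadyCheckDelayMs is about 3 ms, so poll() wakes up roughly when that batch becomes sendable.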
+ long pollTimeout = Math.min(result.nextReadyCheckDelayMs, notReadyTimeout); + pollTimeout = Math.min(pollTimeout, this.accumulator.nextExpiryTimeMs() - now); + pollTimeout = Math.max(pollTimeout, 0); + if (!result.readyNodes.isEmpty()) { + log.trace("Nodes with data ready to send: {}", result.readyNodes); + // if some partitions are already ready to be sent, the select time would be 0; + // otherwise if some partition already has some data accumulated but not ready yet, + // the select time will be the time difference between now and its linger expiry time; + // otherwise the select time will be the time difference between now and the metadata expiry time; + pollTimeout = 0; + } + sendProduceRequests(batches, now); + return pollTimeout; + } + + /** + * Returns true if a transactional request is sent or polled, or if a FindCoordinator request is enqueued + */ + private boolean maybeSendAndPollTransactionalRequest() { + if (transactionManager.hasInFlightRequest()) { + // as long as there are outstanding transactional requests, we simply wait for them to return + client.poll(retryBackoffMs, time.milliseconds()); + return true; + } + + if (transactionManager.hasAbortableError() || transactionManager.isAborting()) { + if (accumulator.hasIncomplete()) { + // Attempt to get the last error that caused this abort. + RuntimeException exception = transactionManager.lastError(); + // If there was no error, but we are still aborting, + // then this is most likely a case where there was no fatal error. + if (exception == null) { + exception = new TransactionAbortedException(); + } + accumulator.abortUndrainedBatches(exception); + } + } + + TransactionManager.TxnRequestHandler nextRequestHandler = transactionManager.nextRequest(accumulator.hasIncomplete()); + if (nextRequestHandler == null) + return false; + + AbstractRequest.Builder requestBuilder = nextRequestHandler.requestBuilder(); + Node targetNode = null; + try { + FindCoordinatorRequest.CoordinatorType coordinatorType = nextRequestHandler.coordinatorType(); + targetNode = coordinatorType != null ? 
+                    transactionManager.coordinator(coordinatorType) :
+                    client.leastLoadedNode(time.milliseconds());
+            if (targetNode != null) {
+                if (!awaitNodeReady(targetNode, coordinatorType)) {
+                    log.trace("Target node {} not ready within request timeout, will retry when node is ready.", targetNode);
+                    maybeFindCoordinatorAndRetry(nextRequestHandler);
+                    return true;
+                }
+            } else if (coordinatorType != null) {
+                log.trace("Coordinator not known for {}, will retry {} after finding coordinator.", coordinatorType, requestBuilder.apiKey());
+                maybeFindCoordinatorAndRetry(nextRequestHandler);
+                return true;
+            } else {
+                log.trace("No nodes available to send requests, will poll and retry until a node is ready.");
+                transactionManager.retry(nextRequestHandler);
+                client.poll(retryBackoffMs, time.milliseconds());
+                return true;
+            }
+
+            if (nextRequestHandler.isRetry())
+                time.sleep(nextRequestHandler.retryBackoffMs());
+
+            long currentTimeMs = time.milliseconds();
+            ClientRequest clientRequest = client.newClientRequest(targetNode.idString(), requestBuilder, currentTimeMs,
+                true, requestTimeoutMs, nextRequestHandler);
+            log.debug("Sending transactional request {} to node {} with correlation ID {}", requestBuilder, targetNode, clientRequest.correlationId());
+            client.send(clientRequest, currentTimeMs);
+            transactionManager.setInFlightCorrelationId(clientRequest.correlationId());
+            client.poll(retryBackoffMs, time.milliseconds());
+            return true;
+        } catch (IOException e) {
+            log.debug("Disconnect from {} while trying to send request {}. Going " +
+                "to back off and retry.", targetNode, requestBuilder, e);
+            // We break here so that we pick up the FindCoordinator request immediately.
+            maybeFindCoordinatorAndRetry(nextRequestHandler);
+            return true;
+        }
+    }
+
+    private void maybeFindCoordinatorAndRetry(TransactionManager.TxnRequestHandler nextRequestHandler) {
+        if (nextRequestHandler.needsCoordinator()) {
+            transactionManager.lookupCoordinator(nextRequestHandler);
+        } else {
+            // For non-coordinator requests, sleep here to prevent a tight loop when no node is available
+            time.sleep(retryBackoffMs);
+            metadata.requestUpdate();
+        }
+
+        transactionManager.retry(nextRequestHandler);
+    }
+
+    private void maybeAbortBatches(RuntimeException exception) {
+        if (accumulator.hasIncomplete()) {
+            log.error("Aborting producer batches due to fatal error", exception);
+            accumulator.abortBatches(exception);
+        }
+    }
+
+    /**
+     * Start closing the sender (won't actually complete until all data is sent out)
+     */
+    public void initiateClose() {
+        // Ensure accumulator is closed first to guarantee that no more appends are accepted after
+        // breaking from the sender loop. Otherwise, we may miss some callbacks when shutting down.
+        this.accumulator.close();
+        this.running = false;
+        this.wakeup();
+    }
+
+    /**
+     * Closes the sender without sending out any pending messages.
+ */ + public void forceClose() { + this.forceClose = true; + initiateClose(); + } + + public boolean isRunning() { + return running; + } + + private boolean awaitNodeReady(Node node, FindCoordinatorRequest.CoordinatorType coordinatorType) throws IOException { + if (NetworkClientUtils.awaitReady(client, node, time, requestTimeoutMs)) { + if (coordinatorType == FindCoordinatorRequest.CoordinatorType.TRANSACTION) { + // Indicate to the transaction manager that the coordinator is ready, allowing it to check ApiVersions + // This allows us to bump transactional epochs even if the coordinator is temporarily unavailable at + // the time when the abortable error is handled + transactionManager.handleCoordinatorReady(); + } + return true; + } + return false; + } + + /** + * Handle a produce response + */ + private void handleProduceResponse(ClientResponse response, Map batches, long now) { + RequestHeader requestHeader = response.requestHeader(); + int correlationId = requestHeader.correlationId(); + if (response.wasDisconnected()) { + log.trace("Cancelled request with header {} due to node {} being disconnected", + requestHeader, response.destination()); + for (ProducerBatch batch : batches.values()) + completeBatch(batch, new ProduceResponse.PartitionResponse(Errors.NETWORK_EXCEPTION, String.format("Disconnected from node %s", response.destination())), + correlationId, now); + } else if (response.versionMismatch() != null) { + log.warn("Cancelled request {} due to a version mismatch with node {}", + response, response.destination(), response.versionMismatch()); + for (ProducerBatch batch : batches.values()) + completeBatch(batch, new ProduceResponse.PartitionResponse(Errors.UNSUPPORTED_VERSION), correlationId, now); + } else { + log.trace("Received produce response from node {} with correlation id {}", response.destination(), correlationId); + // if we have a response, parse it + if (response.hasResponse()) { + // Sender should exercise PartitionProduceResponse rather than ProduceResponse.PartitionResponse + // https://issues.apache.org/jira/browse/KAFKA-10696 + ProduceResponse produceResponse = (ProduceResponse) response.responseBody(); + produceResponse.data().responses().forEach(r -> r.partitionResponses().forEach(p -> { + TopicPartition tp = new TopicPartition(r.name(), p.index()); + ProduceResponse.PartitionResponse partResp = new ProduceResponse.PartitionResponse( + Errors.forCode(p.errorCode()), + p.baseOffset(), + p.logAppendTimeMs(), + p.logStartOffset(), + p.recordErrors() + .stream() + .map(e -> new ProduceResponse.RecordError(e.batchIndex(), e.batchIndexErrorMessage())) + .collect(Collectors.toList()), + p.errorMessage()); + ProducerBatch batch = batches.get(tp); + completeBatch(batch, partResp, correlationId, now); + })); + this.sensors.recordLatency(response.destination(), response.requestLatencyMs()); + } else { + // this is the acks = 0 case, just complete all requests + for (ProducerBatch batch : batches.values()) { + completeBatch(batch, new ProduceResponse.PartitionResponse(Errors.NONE), correlationId, now); + } + } + } + } + + /** + * Complete or retry the given batch of records. 
+ * + * @param batch The record batch + * @param response The produce response + * @param correlationId The correlation id for the request + * @param now The current POSIX timestamp in milliseconds + */ + private void completeBatch(ProducerBatch batch, ProduceResponse.PartitionResponse response, long correlationId, + long now) { + Errors error = response.error; + + if (error == Errors.MESSAGE_TOO_LARGE && batch.recordCount > 1 && !batch.isDone() && + (batch.magic() >= RecordBatch.MAGIC_VALUE_V2 || batch.isCompressed())) { + // If the batch is too large, we split the batch and send the split batches again. We do not decrement + // the retry attempts in this case. + log.warn( + "Got error produce response in correlation id {} on topic-partition {}, splitting and retrying ({} attempts left). Error: {}", + correlationId, + batch.topicPartition, + this.retries - batch.attempts(), + formatErrMsg(response)); + if (transactionManager != null) + transactionManager.removeInFlightBatch(batch); + this.accumulator.splitAndReenqueue(batch); + maybeRemoveAndDeallocateBatch(batch); + this.sensors.recordBatchSplit(); + } else if (error != Errors.NONE) { + if (canRetry(batch, response, now)) { + log.warn( + "Got error produce response with correlation id {} on topic-partition {}, retrying ({} attempts left). Error: {}", + correlationId, + batch.topicPartition, + this.retries - batch.attempts() - 1, + formatErrMsg(response)); + reenqueueBatch(batch, now); + } else if (error == Errors.DUPLICATE_SEQUENCE_NUMBER) { + // If we have received a duplicate sequence error, it means that the sequence number has advanced beyond + // the sequence of the current batch, and we haven't retained batch metadata on the broker to return + // the correct offset and timestamp. + // + // The only thing we can do is to return success to the user and not return a valid offset and timestamp. + completeBatch(batch, response); + } else { + // tell the user the result of their request. We only adjust sequence numbers if the batch didn't exhaust + // its retries -- if it did, we don't know whether the sequence number was accepted or not, and + // thus it is not safe to reassign the sequence. + failBatch(batch, response, batch.attempts() < this.retries); + } + if (error.exception() instanceof InvalidMetadataException) { + if (error.exception() instanceof UnknownTopicOrPartitionException) { + log.warn("Received unknown topic or partition error in produce request on partition {}. The " + + "topic-partition may not exist or the user may not have Describe access to it", + batch.topicPartition); + } else { + log.warn("Received invalid metadata error in produce request on partition {} due to {}. Going " + + "to request metadata update now", batch.topicPartition, + error.exception(response.errorMessage).toString()); + } + metadata.requestUpdate(); + } + } else { + completeBatch(batch, response); + } + + // Unmute the completed partition. + if (guaranteeMessageOrder) + this.accumulator.unmutePartition(batch.topicPartition); + } + + /** + * Format the error from a {@link ProduceResponse.PartitionResponse} in a user-friendly string + * e.g "NETWORK_EXCEPTION. Error Message: Disconnected from node 0" + */ + private String formatErrMsg(ProduceResponse.PartitionResponse response) { + String errorMessageSuffix = (response.errorMessage == null || response.errorMessage.isEmpty()) ? + "" : String.format(". 
Error Message: %s", response.errorMessage); + return String.format("%s%s", response.error, errorMessageSuffix); + } + + private void reenqueueBatch(ProducerBatch batch, long currentTimeMs) { + this.accumulator.reenqueue(batch, currentTimeMs); + maybeRemoveFromInflightBatches(batch); + this.sensors.recordRetries(batch.topicPartition.topic(), batch.recordCount); + } + + private void completeBatch(ProducerBatch batch, ProduceResponse.PartitionResponse response) { + if (transactionManager != null) { + transactionManager.handleCompletedBatch(batch, response); + } + + if (batch.complete(response.baseOffset, response.logAppendTime)) { + maybeRemoveAndDeallocateBatch(batch); + } + } + + private void failBatch(ProducerBatch batch, + ProduceResponse.PartitionResponse response, + boolean adjustSequenceNumbers) { + final RuntimeException topLevelException; + if (response.error == Errors.TOPIC_AUTHORIZATION_FAILED) + topLevelException = new TopicAuthorizationException(Collections.singleton(batch.topicPartition.topic())); + else if (response.error == Errors.CLUSTER_AUTHORIZATION_FAILED) + topLevelException = new ClusterAuthorizationException("The producer is not authorized to do idempotent sends"); + else + topLevelException = response.error.exception(response.errorMessage); + + if (response.recordErrors == null || response.recordErrors.isEmpty()) { + failBatch(batch, topLevelException, adjustSequenceNumbers); + } else { + Map recordErrorMap = new HashMap<>(response.recordErrors.size()); + for (ProduceResponse.RecordError recordError : response.recordErrors) { + // The API leaves us with some awkwardness interpreting the errors in the response. + // We cannot differentiate between different error cases (such as INVALID_TIMESTAMP) + // from the single error code at the partition level, so instead we use INVALID_RECORD + // for all failed records and rely on the message to distinguish the cases. + final String errorMessage; + if (recordError.message != null) { + errorMessage = recordError.message; + } else if (response.errorMessage != null) { + errorMessage = response.errorMessage; + } else { + errorMessage = response.error.message(); + } + + // If the batch contained only a single record error, then we can unambiguously + // use the exception type corresponding to the partition-level error code. + if (response.recordErrors.size() == 1) { + recordErrorMap.put(recordError.batchIndex, response.error.exception(errorMessage)); + } else { + recordErrorMap.put(recordError.batchIndex, new InvalidRecordException(errorMessage)); + } + } + + Function recordExceptions = batchIndex -> { + RuntimeException exception = recordErrorMap.get(batchIndex); + if (exception != null) { + return exception; + } else { + // If the response contains record errors, then the records which failed validation + // will be present in the response. To avoid confusion for the remaining records, we + // return a generic exception. 
+                    return new KafkaException("Failed to append record because it was part of a batch " +
+                        "which had one or more invalid records");
+                }
+            };
+
+            failBatch(batch, topLevelException, recordExceptions, adjustSequenceNumbers);
+        }
+    }
+
+    private void failBatch(
+        ProducerBatch batch,
+        RuntimeException topLevelException,
+        boolean adjustSequenceNumbers
+    ) {
+        failBatch(batch, topLevelException, batchIndex -> topLevelException, adjustSequenceNumbers);
+    }
+
+    private void failBatch(
+        ProducerBatch batch,
+        RuntimeException topLevelException,
+        Function<Integer, RuntimeException> recordExceptions,
+        boolean adjustSequenceNumbers
+    ) {
+        if (transactionManager != null) {
+            transactionManager.handleFailedBatch(batch, topLevelException, adjustSequenceNumbers);
+        }
+
+        this.sensors.recordErrors(batch.topicPartition.topic(), batch.recordCount);
+
+        if (batch.completeExceptionally(topLevelException, recordExceptions)) {
+            maybeRemoveAndDeallocateBatch(batch);
+        }
+    }
+
+    /**
+     * We can retry a send if the error is transient and the number of attempts taken is fewer than the maximum allowed.
+     * We can also retry OutOfOrderSequence exceptions for future batches, since if the first batch has failed, the
+     * future batches are certain to fail with an OutOfOrderSequence exception.
+     */
+    private boolean canRetry(ProducerBatch batch, ProduceResponse.PartitionResponse response, long now) {
+        return !batch.hasReachedDeliveryTimeout(accumulator.getDeliveryTimeoutMs(), now) &&
+            batch.attempts() < this.retries &&
+            !batch.isDone() &&
+            (transactionManager == null ?
+                response.error.exception() instanceof RetriableException :
+                transactionManager.canRetry(response, batch));
+    }
+
+    /**
+     * Transfer the record batches into a list of produce requests on a per-node basis
+     */
+    private void sendProduceRequests(Map<Integer, List<ProducerBatch>> collated, long now) {
+        for (Map.Entry<Integer, List<ProducerBatch>> entry : collated.entrySet())
+            sendProduceRequest(now, entry.getKey(), acks, requestTimeoutMs, entry.getValue());
+    }
+
+    /**
+     * Create a produce request from the given record batches
+     */
+    private void sendProduceRequest(long now, int destination, short acks, int timeout, List<ProducerBatch> batches) {
+        if (batches.isEmpty())
+            return;
+
+        final Map<TopicPartition, ProducerBatch> recordsByPartition = new HashMap<>(batches.size());
+
+        // find the minimum magic version used when creating the record sets
+        byte minUsedMagic = apiVersions.maxUsableProduceMagic();
+        for (ProducerBatch batch : batches) {
+            if (batch.magic() < minUsedMagic)
+                minUsedMagic = batch.magic();
+        }
+        ProduceRequestData.TopicProduceDataCollection tpd = new ProduceRequestData.TopicProduceDataCollection();
+        for (ProducerBatch batch : batches) {
+            TopicPartition tp = batch.topicPartition;
+            MemoryRecords records = batch.records();
+
+            // down convert if necessary to the minimum magic used. In general, there can be a delay between the time
+            // that the producer starts building the batch and the time that we send the request, and we may have
+            // chosen the message format based on out-dated metadata. In the worst case, we optimistically chose to use
+            // the new message format, but found that the broker didn't support it, so we need to down-convert on the
+            // client before sending. This is intended to handle edge cases around cluster upgrades where brokers may
+            // not all support the same message format version. For example, if a partition migrates from a broker
+            // which supports the new magic version to one which doesn't, then we will need to convert.
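+            // Illustrative example (hypothetical values): a batch built with magic v2 while minUsedMagic
+            // resolved to v1 above would be down-converted by the call below before being added to the request.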
+ if (!records.hasMatchingMagic(minUsedMagic)) + records = batch.records().downConvert(minUsedMagic, 0, time).records(); + ProduceRequestData.TopicProduceData tpData = tpd.find(tp.topic()); + if (tpData == null) { + tpData = new ProduceRequestData.TopicProduceData().setName(tp.topic()); + tpd.add(tpData); + } + tpData.partitionData().add(new ProduceRequestData.PartitionProduceData() + .setIndex(tp.partition()) + .setRecords(records)); + recordsByPartition.put(tp, batch); + } + + String transactionalId = null; + if (transactionManager != null && transactionManager.isTransactional()) { + transactionalId = transactionManager.transactionalId(); + } + + ProduceRequest.Builder requestBuilder = ProduceRequest.forMagic(minUsedMagic, + new ProduceRequestData() + .setAcks(acks) + .setTimeoutMs(timeout) + .setTransactionalId(transactionalId) + .setTopicData(tpd)); + RequestCompletionHandler callback = response -> handleProduceResponse(response, recordsByPartition, time.milliseconds()); + + String nodeId = Integer.toString(destination); + ClientRequest clientRequest = client.newClientRequest(nodeId, requestBuilder, now, acks != 0, + requestTimeoutMs, callback); + client.send(clientRequest, now); + log.trace("Sent produce request to {}: {}", nodeId, requestBuilder); + } + + /** + * Wake up the selector associated with this send thread + */ + public void wakeup() { + this.client.wakeup(); + } + + public static Sensor throttleTimeSensor(SenderMetricsRegistry metrics) { + Sensor produceThrottleTimeSensor = metrics.sensor("produce-throttle-time"); + produceThrottleTimeSensor.add(metrics.produceThrottleTimeAvg, new Avg()); + produceThrottleTimeSensor.add(metrics.produceThrottleTimeMax, new Max()); + return produceThrottleTimeSensor; + } + + /** + * A collection of sensors for the sender + */ + private static class SenderMetrics { + public final Sensor retrySensor; + public final Sensor errorSensor; + public final Sensor queueTimeSensor; + public final Sensor requestTimeSensor; + public final Sensor recordsPerRequestSensor; + public final Sensor batchSizeSensor; + public final Sensor compressionRateSensor; + public final Sensor maxRecordSizeSensor; + public final Sensor batchSplitSensor; + private final SenderMetricsRegistry metrics; + private final Time time; + + public SenderMetrics(SenderMetricsRegistry metrics, Metadata metadata, KafkaClient client, Time time) { + this.metrics = metrics; + this.time = time; + + this.batchSizeSensor = metrics.sensor("batch-size"); + this.batchSizeSensor.add(metrics.batchSizeAvg, new Avg()); + this.batchSizeSensor.add(metrics.batchSizeMax, new Max()); + + this.compressionRateSensor = metrics.sensor("compression-rate"); + this.compressionRateSensor.add(metrics.compressionRateAvg, new Avg()); + + this.queueTimeSensor = metrics.sensor("queue-time"); + this.queueTimeSensor.add(metrics.recordQueueTimeAvg, new Avg()); + this.queueTimeSensor.add(metrics.recordQueueTimeMax, new Max()); + + this.requestTimeSensor = metrics.sensor("request-time"); + this.requestTimeSensor.add(metrics.requestLatencyAvg, new Avg()); + this.requestTimeSensor.add(metrics.requestLatencyMax, new Max()); + + this.recordsPerRequestSensor = metrics.sensor("records-per-request"); + this.recordsPerRequestSensor.add(new Meter(metrics.recordSendRate, metrics.recordSendTotal)); + this.recordsPerRequestSensor.add(metrics.recordsPerRequestAvg, new Avg()); + + this.retrySensor = metrics.sensor("record-retries"); + this.retrySensor.add(new Meter(metrics.recordRetryRate, metrics.recordRetryTotal)); + + 
this.errorSensor = metrics.sensor("errors"); + this.errorSensor.add(new Meter(metrics.recordErrorRate, metrics.recordErrorTotal)); + + this.maxRecordSizeSensor = metrics.sensor("record-size"); + this.maxRecordSizeSensor.add(metrics.recordSizeMax, new Max()); + this.maxRecordSizeSensor.add(metrics.recordSizeAvg, new Avg()); + + this.metrics.addMetric(metrics.requestsInFlight, (config, now) -> client.inFlightRequestCount()); + this.metrics.addMetric(metrics.metadataAge, + (config, now) -> (now - metadata.lastSuccessfulUpdate()) / 1000.0); + + this.batchSplitSensor = metrics.sensor("batch-split-rate"); + this.batchSplitSensor.add(new Meter(metrics.batchSplitRate, metrics.batchSplitTotal)); + } + + private void maybeRegisterTopicMetrics(String topic) { + // if one sensor of the metrics has been registered for the topic, + // then all other sensors should have been registered; and vice versa + String topicRecordsCountName = "topic." + topic + ".records-per-batch"; + Sensor topicRecordCount = this.metrics.getSensor(topicRecordsCountName); + if (topicRecordCount == null) { + Map metricTags = Collections.singletonMap("topic", topic); + + topicRecordCount = this.metrics.sensor(topicRecordsCountName); + MetricName rateMetricName = this.metrics.topicRecordSendRate(metricTags); + MetricName totalMetricName = this.metrics.topicRecordSendTotal(metricTags); + topicRecordCount.add(new Meter(rateMetricName, totalMetricName)); + + String topicByteRateName = "topic." + topic + ".bytes"; + Sensor topicByteRate = this.metrics.sensor(topicByteRateName); + rateMetricName = this.metrics.topicByteRate(metricTags); + totalMetricName = this.metrics.topicByteTotal(metricTags); + topicByteRate.add(new Meter(rateMetricName, totalMetricName)); + + String topicCompressionRateName = "topic." + topic + ".compression-rate"; + Sensor topicCompressionRate = this.metrics.sensor(topicCompressionRateName); + MetricName m = this.metrics.topicCompressionRate(metricTags); + topicCompressionRate.add(m, new Avg()); + + String topicRetryName = "topic." + topic + ".record-retries"; + Sensor topicRetrySensor = this.metrics.sensor(topicRetryName); + rateMetricName = this.metrics.topicRecordRetryRate(metricTags); + totalMetricName = this.metrics.topicRecordRetryTotal(metricTags); + topicRetrySensor.add(new Meter(rateMetricName, totalMetricName)); + + String topicErrorName = "topic." + topic + ".record-errors"; + Sensor topicErrorSensor = this.metrics.sensor(topicErrorName); + rateMetricName = this.metrics.topicRecordErrorRate(metricTags); + totalMetricName = this.metrics.topicRecordErrorTotal(metricTags); + topicErrorSensor.add(new Meter(rateMetricName, totalMetricName)); + } + } + + public void updateProduceRequestMetrics(Map> batches) { + long now = time.milliseconds(); + for (List nodeBatch : batches.values()) { + int records = 0; + for (ProducerBatch batch : nodeBatch) { + // register all per-topic metrics at once + String topic = batch.topicPartition.topic(); + maybeRegisterTopicMetrics(topic); + + // per-topic record send rate + String topicRecordsCountName = "topic." + topic + ".records-per-batch"; + Sensor topicRecordCount = Objects.requireNonNull(this.metrics.getSensor(topicRecordsCountName)); + topicRecordCount.record(batch.recordCount); + + // per-topic bytes send rate + String topicByteRateName = "topic." 
+ topic + ".bytes"; + Sensor topicByteRate = Objects.requireNonNull(this.metrics.getSensor(topicByteRateName)); + topicByteRate.record(batch.estimatedSizeInBytes()); + + // per-topic compression rate + String topicCompressionRateName = "topic." + topic + ".compression-rate"; + Sensor topicCompressionRate = Objects.requireNonNull(this.metrics.getSensor(topicCompressionRateName)); + topicCompressionRate.record(batch.compressionRatio()); + + // global metrics + this.batchSizeSensor.record(batch.estimatedSizeInBytes(), now); + this.queueTimeSensor.record(batch.queueTimeMs(), now); + this.compressionRateSensor.record(batch.compressionRatio()); + this.maxRecordSizeSensor.record(batch.maxRecordSize, now); + records += batch.recordCount; + } + this.recordsPerRequestSensor.record(records, now); + } + } + + public void recordRetries(String topic, int count) { + long now = time.milliseconds(); + this.retrySensor.record(count, now); + String topicRetryName = "topic." + topic + ".record-retries"; + Sensor topicRetrySensor = this.metrics.getSensor(topicRetryName); + if (topicRetrySensor != null) + topicRetrySensor.record(count, now); + } + + public void recordErrors(String topic, int count) { + long now = time.milliseconds(); + this.errorSensor.record(count, now); + String topicErrorName = "topic." + topic + ".record-errors"; + Sensor topicErrorSensor = this.metrics.getSensor(topicErrorName); + if (topicErrorSensor != null) + topicErrorSensor.record(count, now); + } + + public void recordLatency(String node, long latency) { + long now = time.milliseconds(); + this.requestTimeSensor.record(latency, now); + if (!node.isEmpty()) { + String nodeTimeName = "node-" + node + ".latency"; + Sensor nodeRequestTime = this.metrics.getSensor(nodeTimeName); + if (nodeRequestTime != null) + nodeRequestTime.record(latency, now); + } + } + + void recordBatchSplit() { + this.batchSplitSensor.record(); + } + } + +} diff --git a/clients/src/main/java/org/apache/kafka/clients/producer/internals/SenderMetricsRegistry.java b/clients/src/main/java/org/apache/kafka/clients/producer/internals/SenderMetricsRegistry.java new file mode 100644 index 0000000..2ad2cba --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/producer/internals/SenderMetricsRegistry.java @@ -0,0 +1,222 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.clients.producer.internals; + +import java.util.ArrayList; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import org.apache.kafka.common.MetricName; +import org.apache.kafka.common.MetricNameTemplate; +import org.apache.kafka.common.metrics.Measurable; +import org.apache.kafka.common.metrics.Metrics; +import org.apache.kafka.common.metrics.Sensor; + +public class SenderMetricsRegistry { + + final static String TOPIC_METRIC_GROUP_NAME = "producer-topic-metrics"; + + private final List allTemplates; + + public final MetricName batchSizeAvg; + public final MetricName batchSizeMax; + public final MetricName compressionRateAvg; + public final MetricName recordQueueTimeAvg; + public final MetricName recordQueueTimeMax; + public final MetricName requestLatencyAvg; + public final MetricName requestLatencyMax; + public final MetricName produceThrottleTimeAvg; + public final MetricName produceThrottleTimeMax; + public final MetricName recordSendRate; + public final MetricName recordSendTotal; + public final MetricName recordsPerRequestAvg; + public final MetricName recordRetryRate; + public final MetricName recordRetryTotal; + public final MetricName recordErrorRate; + public final MetricName recordErrorTotal; + public final MetricName recordSizeMax; + public final MetricName recordSizeAvg; + public final MetricName requestsInFlight; + public final MetricName metadataAge; + public final MetricName batchSplitRate; + public final MetricName batchSplitTotal; + + private final MetricNameTemplate topicRecordSendRate; + private final MetricNameTemplate topicRecordSendTotal; + private final MetricNameTemplate topicByteRate; + private final MetricNameTemplate topicByteTotal; + private final MetricNameTemplate topicCompressionRate; + private final MetricNameTemplate topicRecordRetryRate; + private final MetricNameTemplate topicRecordRetryTotal; + private final MetricNameTemplate topicRecordErrorRate; + private final MetricNameTemplate topicRecordErrorTotal; + + private final Metrics metrics; + private final Set tags; + private final LinkedHashSet topicTags; + + public SenderMetricsRegistry(Metrics metrics) { + this.metrics = metrics; + this.tags = this.metrics.config().tags().keySet(); + this.allTemplates = new ArrayList<>(); + + /***** Client level *****/ + + this.batchSizeAvg = createMetricName("batch-size-avg", + "The average number of bytes sent per partition per-request."); + this.batchSizeMax = createMetricName("batch-size-max", + "The max number of bytes sent per partition per-request."); + this.compressionRateAvg = createMetricName("compression-rate-avg", + "The average compression rate of record batches, defined as the average ratio of the " + + "compressed batch size over the uncompressed size."); + this.recordQueueTimeAvg = createMetricName("record-queue-time-avg", + "The average time in ms record batches spent in the send buffer."); + this.recordQueueTimeMax = createMetricName("record-queue-time-max", + "The maximum time in ms record batches spent in the send buffer."); + this.requestLatencyAvg = createMetricName("request-latency-avg", + "The average request latency in ms"); + this.requestLatencyMax = createMetricName("request-latency-max", + "The maximum request latency in ms"); + this.recordSendRate = createMetricName("record-send-rate", + "The average number of records sent per second."); + this.recordSendTotal = createMetricName("record-send-total", + "The total number of records sent."); + 
this.recordsPerRequestAvg = createMetricName("records-per-request-avg", + "The average number of records per request."); + this.recordRetryRate = createMetricName("record-retry-rate", + "The average per-second number of retried record sends"); + this.recordRetryTotal = createMetricName("record-retry-total", + "The total number of retried record sends"); + this.recordErrorRate = createMetricName("record-error-rate", + "The average per-second number of record sends that resulted in errors"); + this.recordErrorTotal = createMetricName("record-error-total", + "The total number of record sends that resulted in errors"); + this.recordSizeMax = createMetricName("record-size-max", + "The maximum record size"); + this.recordSizeAvg = createMetricName("record-size-avg", + "The average record size"); + this.requestsInFlight = createMetricName("requests-in-flight", + "The current number of in-flight requests awaiting a response."); + this.metadataAge = createMetricName("metadata-age", + "The age in seconds of the current producer metadata being used."); + this.batchSplitRate = createMetricName("batch-split-rate", + "The average number of batch splits per second"); + this.batchSplitTotal = createMetricName("batch-split-total", + "The total number of batch splits"); + + this.produceThrottleTimeAvg = createMetricName("produce-throttle-time-avg", + "The average time in ms a request was throttled by a broker"); + this.produceThrottleTimeMax = createMetricName("produce-throttle-time-max", + "The maximum time in ms a request was throttled by a broker"); + + /***** Topic level *****/ + this.topicTags = new LinkedHashSet<>(tags); + this.topicTags.add("topic"); + + // We can't create the MetricName up front for these, because we don't know the topic name yet. + this.topicRecordSendRate = createTopicTemplate("record-send-rate", + "The average number of records sent per second for a topic."); + this.topicRecordSendTotal = createTopicTemplate("record-send-total", + "The total number of records sent for a topic."); + this.topicByteRate = createTopicTemplate("byte-rate", + "The average number of bytes sent per second for a topic."); + this.topicByteTotal = createTopicTemplate("byte-total", + "The total number of bytes sent for a topic."); + this.topicCompressionRate = createTopicTemplate("compression-rate", + "The average compression rate of record batches for a topic, defined as the average ratio " + + "of the compressed batch size over the uncompressed size."); + this.topicRecordRetryRate = createTopicTemplate("record-retry-rate", + "The average per-second number of retried record sends for a topic"); + this.topicRecordRetryTotal = createTopicTemplate("record-retry-total", + "The total number of retried record sends for a topic"); + this.topicRecordErrorRate = createTopicTemplate("record-error-rate", + "The average per-second number of record sends that resulted in errors for a topic"); + this.topicRecordErrorTotal = createTopicTemplate("record-error-total", + "The total number of record sends that resulted in errors for a topic"); + + } + + private MetricName createMetricName(String name, String description) { + return this.metrics.metricInstance(createTemplate(name, KafkaProducerMetrics.GROUP, description, this.tags)); + } + + private MetricNameTemplate createTopicTemplate(String name, String description) { + return createTemplate(name, TOPIC_METRIC_GROUP_NAME, description, this.topicTags); + } + + /** topic level metrics **/ + public MetricName topicRecordSendRate(Map tags) { + return 
this.metrics.metricInstance(this.topicRecordSendRate, tags); + } + + public MetricName topicRecordSendTotal(Map tags) { + return this.metrics.metricInstance(this.topicRecordSendTotal, tags); + } + + public MetricName topicByteRate(Map tags) { + return this.metrics.metricInstance(this.topicByteRate, tags); + } + + public MetricName topicByteTotal(Map tags) { + return this.metrics.metricInstance(this.topicByteTotal, tags); + } + + public MetricName topicCompressionRate(Map tags) { + return this.metrics.metricInstance(this.topicCompressionRate, tags); + } + + public MetricName topicRecordRetryRate(Map tags) { + return this.metrics.metricInstance(this.topicRecordRetryRate, tags); + } + + public MetricName topicRecordRetryTotal(Map tags) { + return this.metrics.metricInstance(this.topicRecordRetryTotal, tags); + } + + public MetricName topicRecordErrorRate(Map tags) { + return this.metrics.metricInstance(this.topicRecordErrorRate, tags); + } + + public MetricName topicRecordErrorTotal(Map tags) { + return this.metrics.metricInstance(this.topicRecordErrorTotal, tags); + } + + public List allTemplates() { + return allTemplates; + } + + public Sensor sensor(String name) { + return this.metrics.sensor(name); + } + + public void addMetric(MetricName m, Measurable measurable) { + this.metrics.addMetric(m, measurable); + } + + public Sensor getSensor(String name) { + return this.metrics.getSensor(name); + } + + private MetricNameTemplate createTemplate(String name, String group, String description, Set tags) { + MetricNameTemplate template = new MetricNameTemplate(name, group, description, tags); + this.allTemplates.add(template); + return template; + } + +} diff --git a/clients/src/main/java/org/apache/kafka/clients/producer/internals/StickyPartitionCache.java b/clients/src/main/java/org/apache/kafka/clients/producer/internals/StickyPartitionCache.java new file mode 100644 index 0000000..b432009 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/producer/internals/StickyPartitionCache.java @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients.producer.internals; + +import org.apache.kafka.common.Cluster; +import org.apache.kafka.common.PartitionInfo; + +import java.util.List; +import java.util.concurrent.ThreadLocalRandom; +import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.ConcurrentHashMap; +import org.apache.kafka.common.utils.Utils; + +/** + * An internal class that implements a cache used for sticky partitioning behavior. The cache tracks the current sticky + * partition for any given topic. This class should not be used externally. 
+ */ +public class StickyPartitionCache { + private final ConcurrentMap indexCache; + public StickyPartitionCache() { + this.indexCache = new ConcurrentHashMap<>(); + } + + public int partition(String topic, Cluster cluster) { + Integer part = indexCache.get(topic); + if (part == null) { + return nextPartition(topic, cluster, -1); + } + return part; + } + + public int nextPartition(String topic, Cluster cluster, int prevPartition) { + List partitions = cluster.partitionsForTopic(topic); + Integer oldPart = indexCache.get(topic); + Integer newPart = oldPart; + // Check that the current sticky partition for the topic is either not set or that the partition that + // triggered the new batch matches the sticky partition that needs to be changed. + if (oldPart == null || oldPart == prevPartition) { + List availablePartitions = cluster.availablePartitionsForTopic(topic); + if (availablePartitions.size() < 1) { + Integer random = Utils.toPositive(ThreadLocalRandom.current().nextInt()); + newPart = random % partitions.size(); + } else if (availablePartitions.size() == 1) { + newPart = availablePartitions.get(0).partition(); + } else { + while (newPart == null || newPart.equals(oldPart)) { + int random = Utils.toPositive(ThreadLocalRandom.current().nextInt()); + newPart = availablePartitions.get(random % availablePartitions.size()).partition(); + } + } + // Only change the sticky partition if it is null or prevPartition matches the current sticky partition. + if (oldPart == null) { + indexCache.putIfAbsent(topic, newPart); + } else { + indexCache.replace(topic, prevPartition, newPart); + } + return indexCache.get(topic); + } + return indexCache.get(topic); + } + +} \ No newline at end of file diff --git a/clients/src/main/java/org/apache/kafka/clients/producer/internals/TransactionManager.java b/clients/src/main/java/org/apache/kafka/clients/producer/internals/TransactionManager.java new file mode 100644 index 0000000..2de31a0 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/producer/internals/TransactionManager.java @@ -0,0 +1,1765 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.clients.producer.internals; + +import org.apache.kafka.clients.ApiVersions; +import org.apache.kafka.clients.ClientResponse; +import org.apache.kafka.clients.NodeApiVersions; +import org.apache.kafka.clients.RequestCompletionHandler; +import org.apache.kafka.clients.consumer.CommitFailedException; +import org.apache.kafka.clients.consumer.ConsumerGroupMetadata; +import org.apache.kafka.clients.consumer.OffsetAndMetadata; +import org.apache.kafka.common.errors.InvalidPidMappingException; +import org.apache.kafka.common.errors.InvalidProducerEpochException; +import org.apache.kafka.common.errors.RetriableException; +import org.apache.kafka.common.errors.UnknownProducerIdException; +import org.apache.kafka.common.message.ApiVersionsResponseData.ApiVersion; +import org.apache.kafka.common.message.FindCoordinatorResponseData.Coordinator; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.utils.ProducerIdAndEpoch; +import org.apache.kafka.common.KafkaException; +import org.apache.kafka.common.Node; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.errors.AuthenticationException; +import org.apache.kafka.common.errors.ClusterAuthorizationException; +import org.apache.kafka.common.errors.GroupAuthorizationException; +import org.apache.kafka.common.errors.OutOfOrderSequenceException; +import org.apache.kafka.common.errors.ProducerFencedException; +import org.apache.kafka.common.errors.TopicAuthorizationException; +import org.apache.kafka.common.errors.TransactionalIdAuthorizationException; +import org.apache.kafka.common.errors.UnsupportedVersionException; +import org.apache.kafka.common.message.AddOffsetsToTxnRequestData; +import org.apache.kafka.common.message.EndTxnRequestData; +import org.apache.kafka.common.message.FindCoordinatorRequestData; +import org.apache.kafka.common.message.InitProducerIdRequestData; +import org.apache.kafka.common.protocol.Errors; +import org.apache.kafka.common.record.DefaultRecordBatch; +import org.apache.kafka.common.record.RecordBatch; +import org.apache.kafka.common.requests.AbstractRequest; +import org.apache.kafka.common.requests.AbstractResponse; +import org.apache.kafka.common.requests.AddOffsetsToTxnRequest; +import org.apache.kafka.common.requests.AddOffsetsToTxnResponse; +import org.apache.kafka.common.requests.AddPartitionsToTxnRequest; +import org.apache.kafka.common.requests.AddPartitionsToTxnResponse; +import org.apache.kafka.common.requests.EndTxnRequest; +import org.apache.kafka.common.requests.EndTxnResponse; +import org.apache.kafka.common.requests.FindCoordinatorRequest; +import org.apache.kafka.common.requests.FindCoordinatorRequest.CoordinatorType; +import org.apache.kafka.common.requests.FindCoordinatorResponse; +import org.apache.kafka.common.requests.InitProducerIdRequest; +import org.apache.kafka.common.requests.InitProducerIdResponse; +import org.apache.kafka.common.requests.ProduceResponse; +import org.apache.kafka.common.requests.TransactionResult; +import org.apache.kafka.common.requests.TxnOffsetCommitRequest; +import org.apache.kafka.common.requests.TxnOffsetCommitRequest.CommittedOffset; +import org.apache.kafka.common.requests.TxnOffsetCommitResponse; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.common.utils.PrimitiveRef; +import org.slf4j.Logger; + +import java.util.ArrayList; +import java.util.Comparator; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Iterator; +import java.util.List; 
+import java.util.Locale; +import java.util.Map; +import java.util.OptionalInt; +import java.util.OptionalLong; +import java.util.PriorityQueue; +import java.util.Set; +import java.util.SortedSet; +import java.util.TreeSet; +import java.util.function.Consumer; +import java.util.function.Supplier; + +/** + * A class which maintains state for transactions. Also keeps the state necessary to ensure idempotent production. + */ +public class TransactionManager { + private static final int NO_INFLIGHT_REQUEST_CORRELATION_ID = -1; + private static final int NO_LAST_ACKED_SEQUENCE_NUMBER = -1; + + private final Logger log; + private final String transactionalId; + private final int transactionTimeoutMs; + private final ApiVersions apiVersions; + + private static class TopicPartitionBookkeeper { + + private final Map topicPartitions = new HashMap<>(); + + private TopicPartitionEntry getPartition(TopicPartition topicPartition) { + TopicPartitionEntry ent = topicPartitions.get(topicPartition); + if (ent == null) + throw new IllegalStateException("Trying to get the sequence number for " + topicPartition + + ", but the sequence number was never set for this partition."); + return ent; + } + + private TopicPartitionEntry getOrCreatePartition(TopicPartition topicPartition) { + TopicPartitionEntry ent = topicPartitions.get(topicPartition); + if (ent == null) { + ent = new TopicPartitionEntry(); + topicPartitions.put(topicPartition, ent); + } + return ent; + } + + private void addPartition(TopicPartition topicPartition) { + this.topicPartitions.putIfAbsent(topicPartition, new TopicPartitionEntry()); + } + + private boolean contains(TopicPartition topicPartition) { + return topicPartitions.containsKey(topicPartition); + } + + private void reset() { + topicPartitions.clear(); + } + + private OptionalLong lastAckedOffset(TopicPartition topicPartition) { + TopicPartitionEntry entry = topicPartitions.get(topicPartition); + if (entry != null && entry.lastAckedOffset != ProduceResponse.INVALID_OFFSET) + return OptionalLong.of(entry.lastAckedOffset); + else + return OptionalLong.empty(); + } + + private OptionalInt lastAckedSequence(TopicPartition topicPartition) { + TopicPartitionEntry entry = topicPartitions.get(topicPartition); + if (entry != null && entry.lastAckedSequence != NO_LAST_ACKED_SEQUENCE_NUMBER) + return OptionalInt.of(entry.lastAckedSequence); + else + return OptionalInt.empty(); + } + + private void startSequencesAtBeginning(TopicPartition topicPartition, ProducerIdAndEpoch newProducerIdAndEpoch) { + final PrimitiveRef.IntRef sequence = PrimitiveRef.ofInt(0); + TopicPartitionEntry topicPartitionEntry = getPartition(topicPartition); + topicPartitionEntry.resetSequenceNumbers(inFlightBatch -> { + inFlightBatch.resetProducerState(newProducerIdAndEpoch, sequence.value, inFlightBatch.isTransactional()); + sequence.value += inFlightBatch.recordCount; + }); + topicPartitionEntry.producerIdAndEpoch = newProducerIdAndEpoch; + topicPartitionEntry.nextSequence = sequence.value; + topicPartitionEntry.lastAckedSequence = NO_LAST_ACKED_SEQUENCE_NUMBER; + } + } + + private static class TopicPartitionEntry { + + // The producer id/epoch being used for a given partition. + private ProducerIdAndEpoch producerIdAndEpoch; + + // The base sequence of the next batch bound for a given partition. + private int nextSequence; + + // The sequence number of the last record of the last ack'd batch from the given partition. 
When there are no
+        // in flight requests for a partition, the lastAckedSequence(topicPartition) == nextSequence(topicPartition) - 1.
+        private int lastAckedSequence;
+
+        // Keep track of the in flight batches bound for a partition, ordered by sequence. This helps us to ensure that
+        // we continue to order batches by the sequence numbers even when the responses come back out of order during
+        // leader failover. We add a batch to the queue when it is drained, and remove it when the batch completes
+        // (either successfully or through a fatal failure).
+        private SortedSet<ProducerBatch> inflightBatchesBySequence;
+
+        // We keep track of the last acknowledged offset on a per partition basis in order to disambiguate UnknownProducer
+        // responses which are due to the retention period elapsing, and those which are due to actual lost data.
+        private long lastAckedOffset;
+
+        TopicPartitionEntry() {
+            this.producerIdAndEpoch = ProducerIdAndEpoch.NONE;
+            this.nextSequence = 0;
+            this.lastAckedSequence = NO_LAST_ACKED_SEQUENCE_NUMBER;
+            this.lastAckedOffset = ProduceResponse.INVALID_OFFSET;
+            this.inflightBatchesBySequence = new TreeSet<>(Comparator.comparingInt(ProducerBatch::baseSequence));
+        }
+
+        void resetSequenceNumbers(Consumer<ProducerBatch> resetSequence) {
+            TreeSet<ProducerBatch> newInflights = new TreeSet<>(Comparator.comparingInt(ProducerBatch::baseSequence));
+            for (ProducerBatch inflightBatch : inflightBatchesBySequence) {
+                resetSequence.accept(inflightBatch);
+                newInflights.add(inflightBatch);
+            }
+            inflightBatchesBySequence = newInflights;
+        }
+    }
+
+    private final TopicPartitionBookkeeper topicPartitionBookkeeper;
+
+    private final Map<TopicPartition, CommittedOffset> pendingTxnOffsetCommits;
+
+    // If a batch bound for a partition expired locally after being sent at least once, the partition is considered
+    // to have an unresolved state. We keep track of such partitions here, and cannot assign any more sequence numbers
+    // for this partition until the unresolved state gets cleared. This may happen if other inflight batches returned
+    // successfully (indicating that the expired batch actually made it to the broker). If we don't get any successful
+    // responses for the partition once the inflight request count falls to zero, we reset the producer id and
+    // consequently clear this data structure as well.
+    // The value of the map is the sequence number of the batch following the expired one, computed by adding its
+    // record count to its sequence number. This is used to tell if a subsequent batch is the one immediately following
+    // the expired one.
+    private final Map<TopicPartition, Integer> partitionsWithUnresolvedSequences;
+
+    // The partitions that have received an error that triggers an epoch bump. When the epoch is bumped, these
+    // partitions will have the sequences of their in-flight batches rewritten.
+    private final Set<TopicPartition> partitionsToRewriteSequences;
+
+    private final PriorityQueue<TxnRequestHandler> pendingRequests;
+    private final Set<TopicPartition> newPartitionsInTransaction;
+    private final Set<TopicPartition> pendingPartitionsInTransaction;
+    private final Set<TopicPartition> partitionsInTransaction;
+    private TransactionalRequestResult pendingResult;
+
+    // This is used by the TxnRequestHandlers to control how long to back off before a given request is retried.
+    // For instance, this value is lowered by the AddPartitionsToTxnHandler when it receives a CONCURRENT_TRANSACTIONS
+    // error for the first AddPartitionsRequest in a transaction.
+    private final long retryBackoffMs;
+
+    // The retryBackoff is overridden to the following value if the first AddPartitions receives a
+    // CONCURRENT_TRANSACTIONS error.
+ private static final long ADD_PARTITIONS_RETRY_BACKOFF_MS = 20L; + + private int inFlightRequestCorrelationId = NO_INFLIGHT_REQUEST_CORRELATION_ID; + private Node transactionCoordinator; + private Node consumerGroupCoordinator; + private boolean coordinatorSupportsBumpingEpoch; + + private volatile State currentState = State.UNINITIALIZED; + private volatile RuntimeException lastError = null; + private volatile ProducerIdAndEpoch producerIdAndEpoch; + private volatile boolean transactionStarted = false; + private volatile boolean epochBumpRequired = false; + + private enum State { + UNINITIALIZED, + INITIALIZING, + READY, + IN_TRANSACTION, + COMMITTING_TRANSACTION, + ABORTING_TRANSACTION, + ABORTABLE_ERROR, + FATAL_ERROR; + + private boolean isTransitionValid(State source, State target) { + switch (target) { + case UNINITIALIZED: + return source == READY; + case INITIALIZING: + return source == UNINITIALIZED || source == ABORTING_TRANSACTION; + case READY: + return source == INITIALIZING || source == COMMITTING_TRANSACTION || source == ABORTING_TRANSACTION; + case IN_TRANSACTION: + return source == READY; + case COMMITTING_TRANSACTION: + return source == IN_TRANSACTION; + case ABORTING_TRANSACTION: + return source == IN_TRANSACTION || source == ABORTABLE_ERROR; + case ABORTABLE_ERROR: + return source == IN_TRANSACTION || source == COMMITTING_TRANSACTION || source == ABORTABLE_ERROR; + case FATAL_ERROR: + default: + // We can transition to FATAL_ERROR unconditionally. + // FATAL_ERROR is never a valid starting state for any transition. So the only option is to close the + // producer or do purely non transactional requests. + return true; + } + } + } + + // We use the priority to determine the order in which requests need to be sent out. For instance, if we have + // a pending FindCoordinator request, that must always go first. Next, If we need a producer id, that must go second. + // The endTxn request must always go last, unless we are bumping the epoch (a special case of InitProducerId) as + // part of ending the transaction. 
+ private enum Priority { + FIND_COORDINATOR(0), + INIT_PRODUCER_ID(1), + ADD_PARTITIONS_OR_OFFSETS(2), + END_TXN(3), + EPOCH_BUMP(4); + + final int priority; + + Priority(int priority) { + this.priority = priority; + } + } + + public TransactionManager(final LogContext logContext, + final String transactionalId, + final int transactionTimeoutMs, + final long retryBackoffMs, + final ApiVersions apiVersions) { + this.producerIdAndEpoch = ProducerIdAndEpoch.NONE; + this.transactionalId = transactionalId; + this.log = logContext.logger(TransactionManager.class); + this.transactionTimeoutMs = transactionTimeoutMs; + this.transactionCoordinator = null; + this.consumerGroupCoordinator = null; + this.newPartitionsInTransaction = new HashSet<>(); + this.pendingPartitionsInTransaction = new HashSet<>(); + this.partitionsInTransaction = new HashSet<>(); + this.pendingRequests = new PriorityQueue<>(10, Comparator.comparingInt(o -> o.priority().priority)); + this.pendingTxnOffsetCommits = new HashMap<>(); + this.partitionsWithUnresolvedSequences = new HashMap<>(); + this.partitionsToRewriteSequences = new HashSet<>(); + this.retryBackoffMs = retryBackoffMs; + this.topicPartitionBookkeeper = new TopicPartitionBookkeeper(); + this.apiVersions = apiVersions; + } + + public synchronized TransactionalRequestResult initializeTransactions() { + return initializeTransactions(ProducerIdAndEpoch.NONE); + } + + synchronized TransactionalRequestResult initializeTransactions(ProducerIdAndEpoch producerIdAndEpoch) { + boolean isEpochBump = producerIdAndEpoch != ProducerIdAndEpoch.NONE; + return handleCachedTransactionRequestResult(() -> { + // If this is an epoch bump, we will transition the state as part of handling the EndTxnRequest + if (!isEpochBump) { + transitionTo(State.INITIALIZING); + log.info("Invoking InitProducerId for the first time in order to acquire a producer ID"); + } else { + log.info("Invoking InitProducerId with current producer ID and epoch {} in order to bump the epoch", producerIdAndEpoch); + } + InitProducerIdRequestData requestData = new InitProducerIdRequestData() + .setTransactionalId(transactionalId) + .setTransactionTimeoutMs(transactionTimeoutMs) + .setProducerId(producerIdAndEpoch.producerId) + .setProducerEpoch(producerIdAndEpoch.epoch); + InitProducerIdHandler handler = new InitProducerIdHandler(new InitProducerIdRequest.Builder(requestData), + isEpochBump); + enqueueRequest(handler); + return handler.result; + }, State.INITIALIZING); + } + + public synchronized void beginTransaction() { + ensureTransactional(); + maybeFailWithError(); + transitionTo(State.IN_TRANSACTION); + } + + public synchronized TransactionalRequestResult beginCommit() { + return handleCachedTransactionRequestResult(() -> { + maybeFailWithError(); + transitionTo(State.COMMITTING_TRANSACTION); + return beginCompletingTransaction(TransactionResult.COMMIT); + }, State.COMMITTING_TRANSACTION); + } + + public synchronized TransactionalRequestResult beginAbort() { + return handleCachedTransactionRequestResult(() -> { + if (currentState != State.ABORTABLE_ERROR) + maybeFailWithError(); + transitionTo(State.ABORTING_TRANSACTION); + + // We're aborting the transaction, so there should be no need to add new partitions + newPartitionsInTransaction.clear(); + return beginCompletingTransaction(TransactionResult.ABORT); + }, State.ABORTING_TRANSACTION); + } + + private TransactionalRequestResult beginCompletingTransaction(TransactionResult transactionResult) { + if (!newPartitionsInTransaction.isEmpty()) + 
enqueueRequest(addPartitionsToTransactionHandler()); + + // If the error is an INVALID_PRODUCER_ID_MAPPING error, the server will not accept an EndTxnRequest, so skip + // directly to InitProducerId. Otherwise, we must first abort the transaction, because the producer will be + // fenced if we directly call InitProducerId. + if (!(lastError instanceof InvalidPidMappingException)) { + EndTxnRequest.Builder builder = new EndTxnRequest.Builder( + new EndTxnRequestData() + .setTransactionalId(transactionalId) + .setProducerId(producerIdAndEpoch.producerId) + .setProducerEpoch(producerIdAndEpoch.epoch) + .setCommitted(transactionResult.id)); + + EndTxnHandler handler = new EndTxnHandler(builder); + enqueueRequest(handler); + if (!epochBumpRequired) { + return handler.result; + } + } + + return initializeTransactions(this.producerIdAndEpoch); + } + + public synchronized TransactionalRequestResult sendOffsetsToTransaction(final Map offsets, + final ConsumerGroupMetadata groupMetadata) { + ensureTransactional(); + maybeFailWithError(); + if (currentState != State.IN_TRANSACTION) + throw new KafkaException("Cannot send offsets to transaction either because the producer is not in an " + + "active transaction"); + + log.debug("Begin adding offsets {} for consumer group {} to transaction", offsets, groupMetadata); + AddOffsetsToTxnRequest.Builder builder = new AddOffsetsToTxnRequest.Builder( + new AddOffsetsToTxnRequestData() + .setTransactionalId(transactionalId) + .setProducerId(producerIdAndEpoch.producerId) + .setProducerEpoch(producerIdAndEpoch.epoch) + .setGroupId(groupMetadata.groupId()) + ); + AddOffsetsToTxnHandler handler = new AddOffsetsToTxnHandler(builder, offsets, groupMetadata); + + enqueueRequest(handler); + return handler.result; + } + + public synchronized void maybeAddPartitionToTransaction(TopicPartition topicPartition) { + if (isPartitionAdded(topicPartition) || isPartitionPendingAdd(topicPartition)) + return; + + log.debug("Begin adding new partition {} to transaction", topicPartition); + topicPartitionBookkeeper.addPartition(topicPartition); + newPartitionsInTransaction.add(topicPartition); + } + + RuntimeException lastError() { + return lastError; + } + + public synchronized void failIfNotReadyForSend() { + if (hasError()) + throw new KafkaException("Cannot perform send because at least one previous transactional or " + + "idempotent request has failed with errors.", lastError); + + if (isTransactional()) { + if (!hasProducerId()) + throw new IllegalStateException("Cannot perform a 'send' before completing a call to initTransactions " + + "when transactions are enabled."); + + if (currentState != State.IN_TRANSACTION) + throw new IllegalStateException("Cannot call send in state " + currentState); + } + } + + synchronized boolean isSendToPartitionAllowed(TopicPartition tp) { + if (hasFatalError()) + return false; + return !isTransactional() || partitionsInTransaction.contains(tp); + } + + public String transactionalId() { + return transactionalId; + } + + public boolean hasProducerId() { + return producerIdAndEpoch.isValid(); + } + + public boolean isTransactional() { + return transactionalId != null; + } + + synchronized boolean hasPartitionsToAdd() { + return !newPartitionsInTransaction.isEmpty() || !pendingPartitionsInTransaction.isEmpty(); + } + + synchronized boolean isCompleting() { + return currentState == State.COMMITTING_TRANSACTION || currentState == State.ABORTING_TRANSACTION; + } + + synchronized boolean hasError() { + return currentState == State.ABORTABLE_ERROR || 
currentState == State.FATAL_ERROR; + } + + synchronized boolean isAborting() { + return currentState == State.ABORTING_TRANSACTION; + } + + synchronized void transitionToAbortableError(RuntimeException exception) { + if (currentState == State.ABORTING_TRANSACTION) { + log.debug("Skipping transition to abortable error state since the transaction is already being " + + "aborted. Underlying exception: ", exception); + return; + } + + log.info("Transiting to abortable error state due to {}", exception.toString()); + transitionTo(State.ABORTABLE_ERROR, exception); + } + + synchronized void transitionToFatalError(RuntimeException exception) { + log.info("Transiting to fatal error state due to {}", exception.toString()); + transitionTo(State.FATAL_ERROR, exception); + + if (pendingResult != null) { + pendingResult.fail(exception); + } + } + + // visible for testing + synchronized boolean isPartitionAdded(TopicPartition partition) { + return partitionsInTransaction.contains(partition); + } + + // visible for testing + synchronized boolean isPartitionPendingAdd(TopicPartition partition) { + return newPartitionsInTransaction.contains(partition) || pendingPartitionsInTransaction.contains(partition); + } + + /** + * Get the current producer id and epoch without blocking. Callers must use {@link ProducerIdAndEpoch#isValid()} to + * verify that the result is valid. + * + * @return the current ProducerIdAndEpoch. + */ + ProducerIdAndEpoch producerIdAndEpoch() { + return producerIdAndEpoch; + } + + synchronized public void maybeUpdateProducerIdAndEpoch(TopicPartition topicPartition) { + if (hasStaleProducerIdAndEpoch(topicPartition) && !hasInflightBatches(topicPartition)) { + // If the batch was on a different ID and/or epoch (due to an epoch bump) and all its in-flight batches + // have completed, reset the partition sequence so that the next batch (with the new epoch) starts from 0 + topicPartitionBookkeeper.startSequencesAtBeginning(topicPartition, this.producerIdAndEpoch); + log.debug("ProducerId of partition {} set to {} with epoch {}. Reinitialize sequence at beginning.", + topicPartition, producerIdAndEpoch.producerId, producerIdAndEpoch.epoch); + } + } + + /** + * Set the producer id and epoch atomically. + */ + private void setProducerIdAndEpoch(ProducerIdAndEpoch producerIdAndEpoch) { + log.info("ProducerId set to {} with epoch {}", producerIdAndEpoch.producerId, producerIdAndEpoch.epoch); + this.producerIdAndEpoch = producerIdAndEpoch; + } + + /** + * This method resets the producer ID and epoch and sets the state to UNINITIALIZED, which will trigger a new + * InitProducerId request. This method is only called when the producer epoch is exhausted; we will bump the epoch + * instead. + */ + private void resetIdempotentProducerId() { + if (isTransactional()) + throw new IllegalStateException("Cannot reset producer state for a transactional producer. " + + "You must either abort the ongoing transaction or reinitialize the transactional producer instead"); + log.debug("Resetting idempotent producer ID. 
ID and epoch before reset are {}", this.producerIdAndEpoch); + setProducerIdAndEpoch(ProducerIdAndEpoch.NONE); + transitionTo(State.UNINITIALIZED); + } + + private void resetSequenceForPartition(TopicPartition topicPartition) { + topicPartitionBookkeeper.topicPartitions.remove(topicPartition); + this.partitionsWithUnresolvedSequences.remove(topicPartition); + } + + private void resetSequenceNumbers() { + topicPartitionBookkeeper.reset(); + this.partitionsWithUnresolvedSequences.clear(); + } + + synchronized void requestEpochBumpForPartition(TopicPartition tp) { + epochBumpRequired = true; + this.partitionsToRewriteSequences.add(tp); + } + + private void bumpIdempotentProducerEpoch() { + if (this.producerIdAndEpoch.epoch == Short.MAX_VALUE) { + resetIdempotentProducerId(); + } else { + setProducerIdAndEpoch(new ProducerIdAndEpoch(this.producerIdAndEpoch.producerId, (short) (this.producerIdAndEpoch.epoch + 1))); + log.debug("Incremented producer epoch, current producer ID and epoch are now {}", this.producerIdAndEpoch); + } + + // When the epoch is bumped, rewrite all in-flight sequences for the partition(s) that triggered the epoch bump + for (TopicPartition topicPartition : this.partitionsToRewriteSequences) { + this.topicPartitionBookkeeper.startSequencesAtBeginning(topicPartition, this.producerIdAndEpoch); + this.partitionsWithUnresolvedSequences.remove(topicPartition); + } + this.partitionsToRewriteSequences.clear(); + + epochBumpRequired = false; + } + + synchronized void bumpIdempotentEpochAndResetIdIfNeeded() { + if (!isTransactional()) { + if (epochBumpRequired) { + bumpIdempotentProducerEpoch(); + } + if (currentState != State.INITIALIZING && !hasProducerId()) { + transitionTo(State.INITIALIZING); + InitProducerIdRequestData requestData = new InitProducerIdRequestData() + .setTransactionalId(null) + .setTransactionTimeoutMs(Integer.MAX_VALUE); + InitProducerIdHandler handler = new InitProducerIdHandler(new InitProducerIdRequest.Builder(requestData), false); + enqueueRequest(handler); + } + } + } + + /** + * Returns the next sequence number to be written to the given TopicPartition. + */ + synchronized Integer sequenceNumber(TopicPartition topicPartition) { + return topicPartitionBookkeeper.getOrCreatePartition(topicPartition).nextSequence; + } + + /** + * Returns the current producer id/epoch of the given TopicPartition. + */ + synchronized ProducerIdAndEpoch producerIdAndEpoch(TopicPartition topicPartition) { + return topicPartitionBookkeeper.getOrCreatePartition(topicPartition).producerIdAndEpoch; + } + + synchronized void incrementSequenceNumber(TopicPartition topicPartition, int increment) { + Integer currentSequence = sequenceNumber(topicPartition); + + currentSequence = DefaultRecordBatch.incrementSequence(currentSequence, increment); + topicPartitionBookkeeper.getPartition(topicPartition).nextSequence = currentSequence; + } + + synchronized void addInFlightBatch(ProducerBatch batch) { + if (!batch.hasSequence()) + throw new IllegalStateException("Can't track batch for partition " + batch.topicPartition + " when sequence is not set."); + topicPartitionBookkeeper.getPartition(batch.topicPartition).inflightBatchesBySequence.add(batch); + } + + /** + * Returns the first inflight sequence for a given partition. This is the base sequence of an inflight batch with + * the lowest sequence number. + * @return the lowest inflight sequence if the transaction manager is tracking inflight requests for this partition. 
+ * If there are no inflight requests being tracked for this partition, this method will return + * RecordBatch.NO_SEQUENCE. + */ + synchronized int firstInFlightSequence(TopicPartition topicPartition) { + if (!hasInflightBatches(topicPartition)) + return RecordBatch.NO_SEQUENCE; + + SortedSet<ProducerBatch> inflightBatches = topicPartitionBookkeeper.getPartition(topicPartition).inflightBatchesBySequence; + if (inflightBatches.isEmpty()) + return RecordBatch.NO_SEQUENCE; + else + return inflightBatches.first().baseSequence(); + } + + synchronized ProducerBatch nextBatchBySequence(TopicPartition topicPartition) { + SortedSet<ProducerBatch> queue = topicPartitionBookkeeper.getPartition(topicPartition).inflightBatchesBySequence; + return queue.isEmpty() ? null : queue.first(); + } + + synchronized void removeInFlightBatch(ProducerBatch batch) { + if (hasInflightBatches(batch.topicPartition)) { + topicPartitionBookkeeper.getPartition(batch.topicPartition).inflightBatchesBySequence.remove(batch); + } + } + + private int maybeUpdateLastAckedSequence(TopicPartition topicPartition, int sequence) { + int lastAckedSequence = lastAckedSequence(topicPartition).orElse(NO_LAST_ACKED_SEQUENCE_NUMBER); + if (sequence > lastAckedSequence) { + topicPartitionBookkeeper.getPartition(topicPartition).lastAckedSequence = sequence; + return sequence; + } + + return lastAckedSequence; + } + + synchronized OptionalInt lastAckedSequence(TopicPartition topicPartition) { + return topicPartitionBookkeeper.lastAckedSequence(topicPartition); + } + + synchronized OptionalLong lastAckedOffset(TopicPartition topicPartition) { + return topicPartitionBookkeeper.lastAckedOffset(topicPartition); + } + + private void updateLastAckedOffset(ProduceResponse.PartitionResponse response, ProducerBatch batch) { + if (response.baseOffset == ProduceResponse.INVALID_OFFSET) + return; + long lastOffset = response.baseOffset + batch.recordCount - 1; + OptionalLong lastAckedOffset = lastAckedOffset(batch.topicPartition); + // It might happen that the TransactionManager has been reset while a request was reenqueued and got a valid + // response for this. This can happen only if the producer is only idempotent (not transactional) and in + // this case there will be no tracked bookkeeper entry about it, so we have to insert one.
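// Illustrative sketch (editor's note, not part of this patch): how the sequence/offset bookkeeping
// above behaves for a hypothetical batch of 10 records with base sequence 20 that the broker
// acknowledges at base offset 500. The names mirror fields used in this class; the numbers are made up.
//
//   int  baseSequence = 20;  int recordCount = 10;  long baseOffset = 500L;
//   int  lastSequence = baseSequence + recordCount - 1;   // 29  -> recorded by maybeUpdateLastAckedSequence
//   long lastOffset   = baseOffset   + recordCount - 1;   // 509 -> recorded as lastAckedOffset below
//   int  nextSequence = lastSequence + 1;                  // 30  -> base sequence of the next batch for this partition
//
// A response carrying a smaller sequence than one already acknowledged (for example a delayed retry)
// is ignored by the `sequence > lastAckedSequence` guard above, keeping the bookkeeping monotonic.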
+ if (!lastAckedOffset.isPresent() && !isTransactional()) { + topicPartitionBookkeeper.addPartition(batch.topicPartition); + } + if (lastOffset > lastAckedOffset.orElse(ProduceResponse.INVALID_OFFSET)) { + topicPartitionBookkeeper.getPartition(batch.topicPartition).lastAckedOffset = lastOffset; + } else { + log.trace("Partition {} keeps lastOffset at {}", batch.topicPartition, lastOffset); + } + } + + public synchronized void handleCompletedBatch(ProducerBatch batch, ProduceResponse.PartitionResponse response) { + int lastAckedSequence = maybeUpdateLastAckedSequence(batch.topicPartition, batch.lastSequence()); + log.debug("ProducerId: {}; Set last ack'd sequence number for topic-partition {} to {}", + batch.producerId(), + batch.topicPartition, + lastAckedSequence); + + updateLastAckedOffset(response, batch); + removeInFlightBatch(batch); + } + + private void maybeTransitionToErrorState(RuntimeException exception) { + if (exception instanceof ClusterAuthorizationException + || exception instanceof TransactionalIdAuthorizationException + || exception instanceof ProducerFencedException + || exception instanceof UnsupportedVersionException) { + transitionToFatalError(exception); + } else if (isTransactional()) { + if (canBumpEpoch() && !isCompleting()) { + epochBumpRequired = true; + } + transitionToAbortableError(exception); + } + } + + synchronized void handleFailedBatch(ProducerBatch batch, RuntimeException exception, boolean adjustSequenceNumbers) { + maybeTransitionToErrorState(exception); + removeInFlightBatch(batch); + + if (hasFatalError()) { + log.debug("Ignoring batch {} with producer id {}, epoch {}, and sequence number {} " + + "since the producer is already in fatal error state", batch, batch.producerId(), + batch.producerEpoch(), batch.baseSequence(), exception); + return; + } + + if (exception instanceof OutOfOrderSequenceException && !isTransactional()) { + log.error("The broker returned {} for topic-partition {} with producerId {}, epoch {}, and sequence number {}", + exception, batch.topicPartition, batch.producerId(), batch.producerEpoch(), batch.baseSequence()); + + // If we fail with an OutOfOrderSequenceException, we have a gap in the log. Bump the epoch for this + // partition, which will reset the sequence number to 0 and allow us to continue + requestEpochBumpForPartition(batch.topicPartition); + } else if (exception instanceof UnknownProducerIdException) { + // If we get an UnknownProducerId for a partition, then the broker has no state for that producer. It will + // therefore accept a write with sequence number 0. We reset the sequence number for the partition here so + // that the producer can continue after aborting the transaction. All inflight-requests to this partition + // will also fail with an UnknownProducerId error, so the sequence will remain at 0. Note that if the + // broker supports bumping the epoch, we will later reset all sequence numbers after calling InitProducerId + resetSequenceForPartition(batch.topicPartition); + } else { + if (adjustSequenceNumbers) { + if (!isTransactional()) { + requestEpochBumpForPartition(batch.topicPartition); + } else { + adjustSequencesDueToFailedBatch(batch); + } + } + } + } + + // If a batch is failed fatally, the sequence numbers for future batches bound for the partition must be adjusted + // so that they don't fail with the OutOfOrderSequenceException. + // + // This method must only be called when we know that the batch is question has been unequivocally failed by the broker, + // ie. 
it has received a confirmed fatal status code like 'Message Too Large' or something similar. + private void adjustSequencesDueToFailedBatch(ProducerBatch batch) { + if (!topicPartitionBookkeeper.contains(batch.topicPartition)) + // Sequence numbers are not being tracked for this partition. This could happen if the producer id was just + // reset due to a previous OutOfOrderSequenceException. + return; + log.debug("producerId: {}, send to partition {} failed fatally. Reducing future sequence numbers by {}", + batch.producerId(), batch.topicPartition, batch.recordCount); + int currentSequence = sequenceNumber(batch.topicPartition); + currentSequence -= batch.recordCount; + if (currentSequence < 0) + throw new IllegalStateException("Sequence number for partition " + batch.topicPartition + " is going to become negative: " + currentSequence); + + setNextSequence(batch.topicPartition, currentSequence); + + topicPartitionBookkeeper.getPartition(batch.topicPartition).resetSequenceNumbers(inFlightBatch -> { + if (inFlightBatch.baseSequence() < batch.baseSequence()) + return; + + int newSequence = inFlightBatch.baseSequence() - batch.recordCount; + if (newSequence < 0) + throw new IllegalStateException("Sequence number for batch with sequence " + inFlightBatch.baseSequence() + + " for partition " + batch.topicPartition + " is going to become negative: " + newSequence); + + log.info("Resetting sequence number of batch with current sequence {} for partition {} to {}", inFlightBatch.baseSequence(), batch.topicPartition, newSequence); + inFlightBatch.resetProducerState(new ProducerIdAndEpoch(inFlightBatch.producerId(), inFlightBatch.producerEpoch()), newSequence, inFlightBatch.isTransactional()); + }); + } + + synchronized boolean hasInflightBatches(TopicPartition topicPartition) { + return !topicPartitionBookkeeper.getOrCreatePartition(topicPartition).inflightBatchesBySequence.isEmpty(); + } + + synchronized boolean hasStaleProducerIdAndEpoch(TopicPartition topicPartition) { + return !producerIdAndEpoch.equals(topicPartitionBookkeeper.getOrCreatePartition(topicPartition).producerIdAndEpoch); + } + + synchronized boolean hasUnresolvedSequences() { + return !partitionsWithUnresolvedSequences.isEmpty(); + } + + synchronized boolean hasUnresolvedSequence(TopicPartition topicPartition) { + return partitionsWithUnresolvedSequences.containsKey(topicPartition); + } + + synchronized void markSequenceUnresolved(ProducerBatch batch) { + int nextSequence = batch.lastSequence() + 1; + partitionsWithUnresolvedSequences.compute(batch.topicPartition, + (k, v) -> v == null ? nextSequence : Math.max(v, nextSequence)); + log.debug("Marking partition {} unresolved with next sequence number {}", batch.topicPartition, + partitionsWithUnresolvedSequences.get(batch.topicPartition)); + } + + // Attempts to resolve unresolved sequences. If all in-flight requests are complete and some partitions are still + // unresolved, either bump the epoch if possible, or transition to a fatal error + synchronized void maybeResolveSequences() { + for (Iterator iter = partitionsWithUnresolvedSequences.keySet().iterator(); iter.hasNext(); ) { + TopicPartition topicPartition = iter.next(); + if (!hasInflightBatches(topicPartition)) { + // The partition has been fully drained. At this point, the last ack'd sequence should be one less than + // next sequence destined for the partition. If so, the partition is fully resolved. If not, we should + // reset the sequence number if necessary. 
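// Illustrative sketch (editor's note, not part of this patch): the check below with made-up numbers.
// Suppose sequenceNumber(tp) == 5, i.e. the next record for this partition would be assigned sequence 5.
//
//   lastAckedSequence(tp) == 4  ->  isNextSequence(tp, 5) is true: everything up to 4 was acknowledged,
//                                   so the partition is resolved and removed from the unresolved map.
//   lastAckedSequence(tp) == 2  ->  sequences 3 and 4 were written but never acknowledged (their batches
//                                   expired), so the gap cannot close on its own. The else-branch then
//                                   bumps the epoch for the idempotent producer; for the transactional
//                                   producer it raises an abortable error (with an epoch bump if the
//                                   coordinator supports it) or a fatal error otherwise.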
+ if (isNextSequence(topicPartition, sequenceNumber(topicPartition))) { + // This would happen when a batch was expired, but subsequent batches succeeded. + iter.remove(); + } else { + // We would enter this branch if all in flight batches were ultimately expired in the producer. + if (isTransactional()) { + // For the transactional producer, we bump the epoch if possible, otherwise we transition to a fatal error + String unackedMessagesErr = "The client hasn't received acknowledgment for some previously " + + "sent messages and can no longer retry them. "; + if (canBumpEpoch()) { + epochBumpRequired = true; + KafkaException exception = new KafkaException(unackedMessagesErr + "It is safe to abort " + + "the transaction and continue."); + transitionToAbortableError(exception); + } else { + KafkaException exception = new KafkaException(unackedMessagesErr + "It isn't safe to continue."); + transitionToFatalError(exception); + } + } else { + // For the idempotent producer, bump the epoch + log.info("No inflight batches remaining for {}, last ack'd sequence for partition is {}, next sequence is {}. " + + "Going to bump epoch and reset sequence numbers.", topicPartition, + lastAckedSequence(topicPartition).orElse(NO_LAST_ACKED_SEQUENCE_NUMBER), sequenceNumber(topicPartition)); + requestEpochBumpForPartition(topicPartition); + } + + iter.remove(); + } + } + } + } + + private boolean isNextSequence(TopicPartition topicPartition, int sequence) { + return sequence - lastAckedSequence(topicPartition).orElse(NO_LAST_ACKED_SEQUENCE_NUMBER) == 1; + } + + private void setNextSequence(TopicPartition topicPartition, int sequence) { + topicPartitionBookkeeper.getPartition(topicPartition).nextSequence = sequence; + } + + private boolean isNextSequenceForUnresolvedPartition(TopicPartition topicPartition, int sequence) { + return this.hasUnresolvedSequence(topicPartition) && + sequence == this.partitionsWithUnresolvedSequences.get(topicPartition); + } + + synchronized TxnRequestHandler nextRequest(boolean hasIncompleteBatches) { + if (!newPartitionsInTransaction.isEmpty()) + enqueueRequest(addPartitionsToTransactionHandler()); + + TxnRequestHandler nextRequestHandler = pendingRequests.peek(); + if (nextRequestHandler == null) + return null; + + // Do not send the EndTxn until all batches have been flushed + if (nextRequestHandler.isEndTxn() && hasIncompleteBatches) + return null; + + pendingRequests.poll(); + if (maybeTerminateRequestWithError(nextRequestHandler)) { + log.trace("Not sending transactional request {} because we are in an error state", + nextRequestHandler.requestBuilder()); + return null; + } + + if (nextRequestHandler.isEndTxn() && !transactionStarted) { + nextRequestHandler.result.done(); + if (currentState != State.FATAL_ERROR) { + log.debug("Not sending EndTxn for completed transaction since no partitions " + + "or offsets were successfully added"); + completeTransaction(); + } + nextRequestHandler = pendingRequests.poll(); + } + + if (nextRequestHandler != null) + log.trace("Request {} dequeued for sending", nextRequestHandler.requestBuilder()); + + return nextRequestHandler; + } + + synchronized void retry(TxnRequestHandler request) { + request.setRetry(); + enqueueRequest(request); + } + + synchronized void authenticationFailed(AuthenticationException e) { + for (TxnRequestHandler request : pendingRequests) + request.fatalError(e); + } + + synchronized void close() { + KafkaException shutdownException = new KafkaException("The producer closed forcefully"); + 
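// Illustrative sketch (editor's note, not part of this patch): what this forced shutdown looks like from
// the application. A thread blocked on a pending transactional operation has its result failed with the
// exception created above and sees it rethrown, e.g. (hypothetical application code):
//
//   try {
//       producer.commitTransaction();      // was waiting on a result that close() just failed
//   } catch (KafkaException e) {
//       // "The producer closed forcefully" -- the outcome of the transaction is unknown to this client
//   }
//
// The rethrow happens in TransactionalRequestResult.await(), defined later in this patch: fail(...)
// releases the latch and await() throws the stored error instead of returning normally.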
pendingRequests.forEach(handler -> + handler.fatalError(shutdownException)); + if (pendingResult != null) { + pendingResult.fail(shutdownException); + } + } + + Node coordinator(FindCoordinatorRequest.CoordinatorType type) { + switch (type) { + case GROUP: + return consumerGroupCoordinator; + case TRANSACTION: + return transactionCoordinator; + default: + throw new IllegalStateException("Received an invalid coordinator type: " + type); + } + } + + void lookupCoordinator(TxnRequestHandler request) { + lookupCoordinator(request.coordinatorType(), request.coordinatorKey()); + } + + void setInFlightCorrelationId(int correlationId) { + inFlightRequestCorrelationId = correlationId; + } + + private void clearInFlightCorrelationId() { + inFlightRequestCorrelationId = NO_INFLIGHT_REQUEST_CORRELATION_ID; + } + + boolean hasInFlightRequest() { + return inFlightRequestCorrelationId != NO_INFLIGHT_REQUEST_CORRELATION_ID; + } + + // visible for testing. + boolean hasFatalError() { + return currentState == State.FATAL_ERROR; + } + + // visible for testing. + boolean hasAbortableError() { + return currentState == State.ABORTABLE_ERROR; + } + + // visible for testing + synchronized boolean transactionContainsPartition(TopicPartition topicPartition) { + return partitionsInTransaction.contains(topicPartition); + } + + // visible for testing + synchronized boolean hasPendingOffsetCommits() { + return !pendingTxnOffsetCommits.isEmpty(); + } + + synchronized boolean hasPendingRequests() { + return !pendingRequests.isEmpty(); + } + + // visible for testing + synchronized boolean hasOngoingTransaction() { + // transactions are considered ongoing once started until completion or a fatal error + return currentState == State.IN_TRANSACTION || isCompleting() || hasAbortableError(); + } + + synchronized boolean canRetry(ProduceResponse.PartitionResponse response, ProducerBatch batch) { + Errors error = response.error; + + // An UNKNOWN_PRODUCER_ID means that we have lost the producer state on the broker. Depending on the log start + // offset, we may want to retry these, as described for each case below. If none of those apply, then for the + // idempotent producer, we will locally bump the epoch and reset the sequence numbers of in-flight batches from + // sequence 0, then retry the failed batch, which should now succeed. For the transactional producer, allow the + // batch to fail. When processing the failed batch, we will transition to an abortable error and set a flag + // indicating that we need to bump the epoch (if supported by the broker). + if (error == Errors.UNKNOWN_PRODUCER_ID) { + if (response.logStartOffset == -1) { + // We don't know the log start offset with this response. We should just retry the request until we get it. + // The UNKNOWN_PRODUCER_ID error code was added along with the new ProduceResponse which includes the + // logStartOffset. So the '-1' sentinel is not for backward compatibility. Instead, it is possible for + // a broker to not know the logStartOffset at when it is returning the response because the partition + // may have moved away from the broker from the time the error was initially raised to the time the + // response was being constructed. In these cases, we should just retry the request: we are guaranteed + // to eventually get a logStartOffset once things settle down. 
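// Illustrative sketch (editor's note, not part of this patch): the log-truncation case handled a few
// branches below, with made-up numbers.
//
//   lastAckedOffset(tp)     == 100   // last offset this producer saw acknowledged for the partition
//   response.logStartOffset == 150   // broker has since deleted all segments below offset 150
//
// Because 100 < 150, the broker lost its producer state together with the old segments rather than
// because of a real sequencing problem. The transactional producer therefore restarts the in-flight
// sequences for the partition at 0 and retries; the idempotent producer requests an epoch bump so the
// old (producerId, epoch, sequence) combinations are never reused.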
+ return true; + } + + if (batch.sequenceHasBeenReset()) { + // When the first inflight batch fails due to the truncation case, then the sequences of all the other + // in flight batches would have been restarted from the beginning. However, when those responses + // come back from the broker, they would also come with an UNKNOWN_PRODUCER_ID error. In this case, we should not + // reset the sequence numbers to the beginning. + return true; + } else if (lastAckedOffset(batch.topicPartition).orElse(NO_LAST_ACKED_SEQUENCE_NUMBER) < response.logStartOffset) { + // The head of the log has been removed, probably due to the retention time elapsing. In this case, + // we expect to lose the producer state. For the transactional producer, reset the sequences of all + // inflight batches to be from the beginning and retry them, so that the transaction does not need to + // be aborted. For the idempotent producer, bump the epoch to avoid reusing (sequence, epoch) pairs + if (isTransactional()) { + topicPartitionBookkeeper.startSequencesAtBeginning(batch.topicPartition, this.producerIdAndEpoch); + } else { + requestEpochBumpForPartition(batch.topicPartition); + } + return true; + } + + if (!isTransactional()) { + // For the idempotent producer, always retry UNKNOWN_PRODUCER_ID errors. If the batch has the current + // producer ID and epoch, request a bump of the epoch. Otherwise just retry the produce. + requestEpochBumpForPartition(batch.topicPartition); + return true; + } + } else if (error == Errors.OUT_OF_ORDER_SEQUENCE_NUMBER) { + if (!hasUnresolvedSequence(batch.topicPartition) && + (batch.sequenceHasBeenReset() || !isNextSequence(batch.topicPartition, batch.baseSequence()))) { + // We should retry the OutOfOrderSequenceException if the batch is _not_ the next batch, ie. its base + // sequence isn't the lastAckedSequence + 1. + return true; + } else if (!isTransactional()) { + // For the idempotent producer, retry all OUT_OF_ORDER_SEQUENCE_NUMBER errors. If there are no + // unresolved sequences, or this batch is the one immediately following an unresolved sequence, we know + // there is actually a gap in the sequences, and we bump the epoch. Otherwise, retry without bumping + // and wait to see if the sequence resolves + if (!hasUnresolvedSequence(batch.topicPartition) || + isNextSequenceForUnresolvedPartition(batch.topicPartition, batch.baseSequence())) { + requestEpochBumpForPartition(batch.topicPartition); + } + return true; + } + } + + // If neither of the above cases are true, retry if the exception is retriable + return error.exception() instanceof RetriableException; + } + + // visible for testing + synchronized boolean isReady() { + return isTransactional() && currentState == State.READY; + } + + void handleCoordinatorReady() { + NodeApiVersions nodeApiVersions = transactionCoordinator != null ? + apiVersions.get(transactionCoordinator.idString()) : + null; + ApiVersion initProducerIdVersion = nodeApiVersions != null ? + nodeApiVersions.apiVersion(ApiKeys.INIT_PRODUCER_ID) : + null; + this.coordinatorSupportsBumpingEpoch = initProducerIdVersion != null && + initProducerIdVersion.maxVersion() >= 3; + } + + private void transitionTo(State target) { + transitionTo(target, null); + } + + private void transitionTo(State target, RuntimeException error) { + if (!currentState.isTransitionValid(currentState, target)) { + String idString = transactionalId == null ? 
"" : "TransactionalId " + transactionalId + ": "; + throw new KafkaException(idString + "Invalid transition attempted from state " + + currentState.name() + " to state " + target.name()); + } + + if (target == State.FATAL_ERROR || target == State.ABORTABLE_ERROR) { + if (error == null) + throw new IllegalArgumentException("Cannot transition to " + target + " with a null exception"); + lastError = error; + } else { + lastError = null; + } + + if (lastError != null) + log.debug("Transition from state {} to error state {}", currentState, target, lastError); + else + log.debug("Transition from state {} to {}", currentState, target); + + currentState = target; + } + + private void ensureTransactional() { + if (!isTransactional()) + throw new IllegalStateException("Transactional method invoked on a non-transactional producer."); + } + + private void maybeFailWithError() { + if (hasError()) { + // for ProducerFencedException, do not wrap it as a KafkaException + // but create a new instance without the call trace since it was not thrown because of the current call + if (lastError instanceof ProducerFencedException) { + throw new ProducerFencedException("The producer has been rejected from the broker because " + + "it tried to use an old epoch with the transactionalId"); + } else if (lastError instanceof InvalidProducerEpochException) { + throw new InvalidProducerEpochException("Producer attempted to produce with an old epoch " + producerIdAndEpoch); + } else { + throw new KafkaException("Cannot execute transactional method because we are in an error state", lastError); + } + } + } + + private boolean maybeTerminateRequestWithError(TxnRequestHandler requestHandler) { + if (hasError()) { + if (hasAbortableError() && requestHandler instanceof FindCoordinatorHandler) + // No harm letting the FindCoordinator request go through if we're expecting to abort + return false; + + requestHandler.fail(lastError); + return true; + } + return false; + } + + private void enqueueRequest(TxnRequestHandler requestHandler) { + log.debug("Enqueuing transactional request {}", requestHandler.requestBuilder()); + pendingRequests.add(requestHandler); + } + + private void lookupCoordinator(FindCoordinatorRequest.CoordinatorType type, String coordinatorKey) { + switch (type) { + case GROUP: + consumerGroupCoordinator = null; + break; + case TRANSACTION: + transactionCoordinator = null; + break; + default: + throw new IllegalStateException("Invalid coordinator type: " + type); + } + + FindCoordinatorRequestData data = new FindCoordinatorRequestData() + .setKeyType(type.id()) + .setKey(coordinatorKey); + FindCoordinatorRequest.Builder builder = new FindCoordinatorRequest.Builder(data); + enqueueRequest(new FindCoordinatorHandler(builder)); + } + + private TxnRequestHandler addPartitionsToTransactionHandler() { + pendingPartitionsInTransaction.addAll(newPartitionsInTransaction); + newPartitionsInTransaction.clear(); + AddPartitionsToTxnRequest.Builder builder = + new AddPartitionsToTxnRequest.Builder(transactionalId, + producerIdAndEpoch.producerId, + producerIdAndEpoch.epoch, + new ArrayList<>(pendingPartitionsInTransaction)); + return new AddPartitionsToTxnHandler(builder); + } + + private TxnOffsetCommitHandler txnOffsetCommitHandler(TransactionalRequestResult result, + Map offsets, + ConsumerGroupMetadata groupMetadata) { + for (Map.Entry entry : offsets.entrySet()) { + OffsetAndMetadata offsetAndMetadata = entry.getValue(); + CommittedOffset committedOffset = new CommittedOffset(offsetAndMetadata.offset(), + 
offsetAndMetadata.metadata(), offsetAndMetadata.leaderEpoch()); + pendingTxnOffsetCommits.put(entry.getKey(), committedOffset); + } + + final TxnOffsetCommitRequest.Builder builder = + new TxnOffsetCommitRequest.Builder(transactionalId, + groupMetadata.groupId(), + producerIdAndEpoch.producerId, + producerIdAndEpoch.epoch, + pendingTxnOffsetCommits, + groupMetadata.memberId(), + groupMetadata.generationId(), + groupMetadata.groupInstanceId() + ); + return new TxnOffsetCommitHandler(result, builder); + } + + private TransactionalRequestResult handleCachedTransactionRequestResult( + Supplier transactionalRequestResultSupplier, + State targetState) { + ensureTransactional(); + + if (pendingResult != null && currentState == targetState) { + TransactionalRequestResult result = pendingResult; + if (result.isCompleted()) + pendingResult = null; + return result; + } + + pendingResult = transactionalRequestResultSupplier.get(); + return pendingResult; + } + + // package-private for testing + boolean canBumpEpoch() { + if (!isTransactional()) { + return true; + } + + return coordinatorSupportsBumpingEpoch; + } + + private void completeTransaction() { + if (epochBumpRequired) { + transitionTo(State.INITIALIZING); + } else { + transitionTo(State.READY); + } + lastError = null; + epochBumpRequired = false; + transactionStarted = false; + newPartitionsInTransaction.clear(); + pendingPartitionsInTransaction.clear(); + partitionsInTransaction.clear(); + } + + abstract class TxnRequestHandler implements RequestCompletionHandler { + protected final TransactionalRequestResult result; + private boolean isRetry = false; + + TxnRequestHandler(TransactionalRequestResult result) { + this.result = result; + } + + TxnRequestHandler(String operation) { + this(new TransactionalRequestResult(operation)); + } + + void fatalError(RuntimeException e) { + result.fail(e); + transitionToFatalError(e); + } + + void abortableError(RuntimeException e) { + result.fail(e); + transitionToAbortableError(e); + } + + void abortableErrorIfPossible(RuntimeException e) { + if (canBumpEpoch()) { + epochBumpRequired = true; + abortableError(e); + } else { + fatalError(e); + } + } + + void fail(RuntimeException e) { + result.fail(e); + } + + void reenqueue() { + synchronized (TransactionManager.this) { + this.isRetry = true; + enqueueRequest(this); + } + } + + long retryBackoffMs() { + return retryBackoffMs; + } + + @Override + public void onComplete(ClientResponse response) { + if (response.requestHeader().correlationId() != inFlightRequestCorrelationId) { + fatalError(new RuntimeException("Detected more than one in-flight transactional request.")); + } else { + clearInFlightCorrelationId(); + if (response.wasDisconnected()) { + log.debug("Disconnected from {}. 
Will retry.", response.destination()); + if (this.needsCoordinator()) + lookupCoordinator(this.coordinatorType(), this.coordinatorKey()); + reenqueue(); + } else if (response.versionMismatch() != null) { + fatalError(response.versionMismatch()); + } else if (response.hasResponse()) { + log.trace("Received transactional response {} for request {}", response.responseBody(), + requestBuilder()); + synchronized (TransactionManager.this) { + handleResponse(response.responseBody()); + } + } else { + fatalError(new KafkaException("Could not execute transactional request for unknown reasons")); + } + } + } + + boolean needsCoordinator() { + return coordinatorType() != null; + } + + FindCoordinatorRequest.CoordinatorType coordinatorType() { + return FindCoordinatorRequest.CoordinatorType.TRANSACTION; + } + + String coordinatorKey() { + return transactionalId; + } + + void setRetry() { + this.isRetry = true; + } + + boolean isRetry() { + return isRetry; + } + + boolean isEndTxn() { + return false; + } + + abstract AbstractRequest.Builder requestBuilder(); + + abstract void handleResponse(AbstractResponse responseBody); + + abstract Priority priority(); + } + + private class InitProducerIdHandler extends TxnRequestHandler { + private final InitProducerIdRequest.Builder builder; + private final boolean isEpochBump; + + private InitProducerIdHandler(InitProducerIdRequest.Builder builder, boolean isEpochBump) { + super("InitProducerId"); + this.builder = builder; + this.isEpochBump = isEpochBump; + } + + @Override + InitProducerIdRequest.Builder requestBuilder() { + return builder; + } + + @Override + Priority priority() { + return this.isEpochBump ? Priority.EPOCH_BUMP : Priority.INIT_PRODUCER_ID; + } + + @Override + FindCoordinatorRequest.CoordinatorType coordinatorType() { + if (isTransactional()) { + return FindCoordinatorRequest.CoordinatorType.TRANSACTION; + } else { + return null; + } + } + + @Override + public void handleResponse(AbstractResponse response) { + InitProducerIdResponse initProducerIdResponse = (InitProducerIdResponse) response; + Errors error = initProducerIdResponse.error(); + + if (error == Errors.NONE) { + ProducerIdAndEpoch producerIdAndEpoch = new ProducerIdAndEpoch(initProducerIdResponse.data().producerId(), + initProducerIdResponse.data().producerEpoch()); + setProducerIdAndEpoch(producerIdAndEpoch); + transitionTo(State.READY); + lastError = null; + if (this.isEpochBump) { + resetSequenceNumbers(); + } + result.done(); + } else if (error == Errors.NOT_COORDINATOR || error == Errors.COORDINATOR_NOT_AVAILABLE) { + lookupCoordinator(FindCoordinatorRequest.CoordinatorType.TRANSACTION, transactionalId); + reenqueue(); + } else if (error == Errors.COORDINATOR_LOAD_IN_PROGRESS || error == Errors.CONCURRENT_TRANSACTIONS) { + reenqueue(); + } else if (error == Errors.TRANSACTIONAL_ID_AUTHORIZATION_FAILED || + error == Errors.CLUSTER_AUTHORIZATION_FAILED) { + fatalError(error.exception()); + } else if (error == Errors.INVALID_PRODUCER_EPOCH || error == Errors.PRODUCER_FENCED) { + // We could still receive INVALID_PRODUCER_EPOCH from old versioned transaction coordinator, + // just treat it the same as PRODUCE_FENCED. 
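// Illustrative sketch (editor's note, not part of this patch): being fenced is not recoverable for this
// producer instance, which is why the error below is treated as fatal rather than retried. A typical
// (hypothetical) application-side reaction:
//
//   try {
//       producer.beginTransaction();
//       // ... producer.send(...) calls ...
//       producer.commitTransaction();
//   } catch (ProducerFencedException e) {
//       // another producer with the same transactional.id registered a newer epoch;
//       // close this instance and let the newer one take over
//       producer.close();
//   }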
+ fatalError(Errors.PRODUCER_FENCED.exception()); + } else { + fatalError(new KafkaException("Unexpected error in InitProducerIdResponse; " + error.message())); + } + } + } + + private class AddPartitionsToTxnHandler extends TxnRequestHandler { + private final AddPartitionsToTxnRequest.Builder builder; + private long retryBackoffMs; + + private AddPartitionsToTxnHandler(AddPartitionsToTxnRequest.Builder builder) { + super("AddPartitionsToTxn"); + this.builder = builder; + this.retryBackoffMs = TransactionManager.this.retryBackoffMs; + } + + @Override + AddPartitionsToTxnRequest.Builder requestBuilder() { + return builder; + } + + @Override + Priority priority() { + return Priority.ADD_PARTITIONS_OR_OFFSETS; + } + + @Override + public void handleResponse(AbstractResponse response) { + AddPartitionsToTxnResponse addPartitionsToTxnResponse = (AddPartitionsToTxnResponse) response; + Map errors = addPartitionsToTxnResponse.errors(); + boolean hasPartitionErrors = false; + Set unauthorizedTopics = new HashSet<>(); + retryBackoffMs = TransactionManager.this.retryBackoffMs; + + for (Map.Entry topicPartitionErrorEntry : errors.entrySet()) { + TopicPartition topicPartition = topicPartitionErrorEntry.getKey(); + Errors error = topicPartitionErrorEntry.getValue(); + + if (error == Errors.NONE) { + continue; + } else if (error == Errors.COORDINATOR_NOT_AVAILABLE || error == Errors.NOT_COORDINATOR) { + lookupCoordinator(FindCoordinatorRequest.CoordinatorType.TRANSACTION, transactionalId); + reenqueue(); + return; + } else if (error == Errors.CONCURRENT_TRANSACTIONS) { + maybeOverrideRetryBackoffMs(); + reenqueue(); + return; + } else if (error == Errors.COORDINATOR_LOAD_IN_PROGRESS || error == Errors.UNKNOWN_TOPIC_OR_PARTITION) { + reenqueue(); + return; + } else if (error == Errors.INVALID_PRODUCER_EPOCH || error == Errors.PRODUCER_FENCED) { + // We could still receive INVALID_PRODUCER_EPOCH from old versioned transaction coordinator, + // just treat it the same as PRODUCE_FENCED. + fatalError(Errors.PRODUCER_FENCED.exception()); + return; + } else if (error == Errors.TRANSACTIONAL_ID_AUTHORIZATION_FAILED) { + fatalError(error.exception()); + return; + } else if (error == Errors.INVALID_TXN_STATE) { + fatalError(new KafkaException(error.exception())); + return; + } else if (error == Errors.TOPIC_AUTHORIZATION_FAILED) { + unauthorizedTopics.add(topicPartition.topic()); + } else if (error == Errors.OPERATION_NOT_ATTEMPTED) { + log.debug("Did not attempt to add partition {} to transaction because other partitions in the " + + "batch had errors.", topicPartition); + hasPartitionErrors = true; + } else if (error == Errors.UNKNOWN_PRODUCER_ID || error == Errors.INVALID_PRODUCER_ID_MAPPING) { + abortableErrorIfPossible(error.exception()); + return; + } else { + log.error("Could not add partition {} due to unexpected error {}", topicPartition, error); + hasPartitionErrors = true; + } + } + + Set partitions = errors.keySet(); + + // Remove the partitions from the pending set regardless of the result. We use the presence + // of partitions in the pending set to know when it is not safe to send batches. However, if + // the partitions failed to be added and we enter an error state, we expect the batches to be + // aborted anyway. In this case, we must be able to continue sending the batches which are in + // retry for partitions that were successfully added. 
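// Illustrative sketch (editor's note, not part of this patch): how a partition moves through the three
// sets used for this bookkeeping.
//
//   maybeAddPartitionToTransaction(tp)      -> newPartitionsInTransaction.add(tp)
//   addPartitionsToTransactionHandler()     -> moved into pendingPartitionsInTransaction, new set cleared
//   successful AddPartitionsToTxn response  -> partitionsInTransaction.addAll(partitions)   (below)
//   completeTransaction()                   -> all three sets cleared
//
// Only partitions already in partitionsInTransaction may be sent to (see isSendToPartitionAllowed),
// which is why the pending set is emptied here even when some partitions came back with errors.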
+ pendingPartitionsInTransaction.removeAll(partitions); + + if (!unauthorizedTopics.isEmpty()) { + abortableError(new TopicAuthorizationException(unauthorizedTopics)); + } else if (hasPartitionErrors) { + abortableError(new KafkaException("Could not add partitions to transaction due to errors: " + errors)); + } else { + log.debug("Successfully added partitions {} to transaction", partitions); + partitionsInTransaction.addAll(partitions); + transactionStarted = true; + result.done(); + } + } + + @Override + public long retryBackoffMs() { + return Math.min(TransactionManager.this.retryBackoffMs, this.retryBackoffMs); + } + + private void maybeOverrideRetryBackoffMs() { + // We only want to reduce the backoff when retrying the first AddPartition which errored out due to a + // CONCURRENT_TRANSACTIONS error since this means that the previous transaction is still completing and + // we don't want to wait too long before trying to start the new one. + // + // This is only a temporary fix, the long term solution is being tracked in + // https://issues.apache.org/jira/browse/KAFKA-5482 + if (partitionsInTransaction.isEmpty()) + this.retryBackoffMs = ADD_PARTITIONS_RETRY_BACKOFF_MS; + } + } + + private class FindCoordinatorHandler extends TxnRequestHandler { + private final FindCoordinatorRequest.Builder builder; + + private FindCoordinatorHandler(FindCoordinatorRequest.Builder builder) { + super("FindCoordinator"); + this.builder = builder; + } + + @Override + FindCoordinatorRequest.Builder requestBuilder() { + return builder; + } + + @Override + Priority priority() { + return Priority.FIND_COORDINATOR; + } + + @Override + FindCoordinatorRequest.CoordinatorType coordinatorType() { + return null; + } + + @Override + String coordinatorKey() { + return null; + } + + @Override + public void handleResponse(AbstractResponse response) { + CoordinatorType coordinatorType = CoordinatorType.forId(builder.data().keyType()); + + List coordinators = ((FindCoordinatorResponse) response).coordinators(); + if (coordinators.size() != 1) { + log.error("Group coordinator lookup failed: Invalid response containing more than a single coordinator"); + fatalError(new IllegalStateException("Group coordinator lookup failed: Invalid response containing more than a single coordinator")); + } + Coordinator coordinatorData = coordinators.get(0); + // For older versions without batching, obtain key from request data since it is not included in response + String key = coordinatorData.key() == null ? 
builder.data().key() : coordinatorData.key(); + Errors error = Errors.forCode(coordinatorData.errorCode()); + if (error == Errors.NONE) { + Node node = new Node(coordinatorData.nodeId(), coordinatorData.host(), coordinatorData.port()); + switch (coordinatorType) { + case GROUP: + consumerGroupCoordinator = node; + break; + case TRANSACTION: + transactionCoordinator = node; + + } + result.done(); + log.info("Discovered {} coordinator {}", coordinatorType.toString().toLowerCase(Locale.ROOT), node); + } else if (error == Errors.COORDINATOR_NOT_AVAILABLE) { + reenqueue(); + } else if (error == Errors.TRANSACTIONAL_ID_AUTHORIZATION_FAILED) { + fatalError(error.exception()); + } else if (error == Errors.GROUP_AUTHORIZATION_FAILED) { + abortableError(GroupAuthorizationException.forGroupId(key)); + } else { + fatalError(new KafkaException(String.format("Could not find a coordinator with type %s with key %s due to " + + "unexpected error: %s", coordinatorType, key, + coordinatorData.errorMessage()))); + } + } + } + + private class EndTxnHandler extends TxnRequestHandler { + private final EndTxnRequest.Builder builder; + + private EndTxnHandler(EndTxnRequest.Builder builder) { + super("EndTxn(" + builder.data.committed() + ")"); + this.builder = builder; + } + + @Override + EndTxnRequest.Builder requestBuilder() { + return builder; + } + + @Override + Priority priority() { + return Priority.END_TXN; + } + + @Override + boolean isEndTxn() { + return true; + } + + @Override + public void handleResponse(AbstractResponse response) { + EndTxnResponse endTxnResponse = (EndTxnResponse) response; + Errors error = endTxnResponse.error(); + + if (error == Errors.NONE) { + completeTransaction(); + result.done(); + } else if (error == Errors.COORDINATOR_NOT_AVAILABLE || error == Errors.NOT_COORDINATOR) { + lookupCoordinator(FindCoordinatorRequest.CoordinatorType.TRANSACTION, transactionalId); + reenqueue(); + } else if (error == Errors.COORDINATOR_LOAD_IN_PROGRESS || error == Errors.CONCURRENT_TRANSACTIONS) { + reenqueue(); + } else if (error == Errors.INVALID_PRODUCER_EPOCH || error == Errors.PRODUCER_FENCED) { + // We could still receive INVALID_PRODUCER_EPOCH from old versioned transaction coordinator, + // just treat it the same as PRODUCE_FENCED. 
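// Illustrative sketch (editor's note, not part of this patch): the application-level flow that ends up
// in this handler. Both commitTransaction() and abortTransaction() result in an EndTxnRequest; they
// differ only in the committed flag (TransactionResult) placed on the request.
//
//   producer.initTransactions();
//   producer.beginTransaction();
//   producer.send(record1);
//   producer.send(record2);
//   producer.commitTransaction();               // -> EndTxnRequest with committed = true
//   // on failure: producer.abortTransaction(); // -> EndTxnRequest with committed = false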
+ fatalError(Errors.PRODUCER_FENCED.exception()); + } else if (error == Errors.TRANSACTIONAL_ID_AUTHORIZATION_FAILED) { + fatalError(error.exception()); + } else if (error == Errors.INVALID_TXN_STATE) { + fatalError(error.exception()); + } else if (error == Errors.UNKNOWN_PRODUCER_ID || error == Errors.INVALID_PRODUCER_ID_MAPPING) { + abortableErrorIfPossible(error.exception()); + } else { + fatalError(new KafkaException("Unhandled error in EndTxnResponse: " + error.message())); + } + } + } + + private class AddOffsetsToTxnHandler extends TxnRequestHandler { + private final AddOffsetsToTxnRequest.Builder builder; + private final Map offsets; + private final ConsumerGroupMetadata groupMetadata; + + private AddOffsetsToTxnHandler(AddOffsetsToTxnRequest.Builder builder, + Map offsets, + ConsumerGroupMetadata groupMetadata) { + super("AddOffsetsToTxn"); + this.builder = builder; + this.offsets = offsets; + this.groupMetadata = groupMetadata; + } + + @Override + AddOffsetsToTxnRequest.Builder requestBuilder() { + return builder; + } + + @Override + Priority priority() { + return Priority.ADD_PARTITIONS_OR_OFFSETS; + } + + @Override + public void handleResponse(AbstractResponse response) { + AddOffsetsToTxnResponse addOffsetsToTxnResponse = (AddOffsetsToTxnResponse) response; + Errors error = Errors.forCode(addOffsetsToTxnResponse.data().errorCode()); + + if (error == Errors.NONE) { + log.debug("Successfully added partition for consumer group {} to transaction", builder.data.groupId()); + + // note the result is not completed until the TxnOffsetCommit returns + pendingRequests.add(txnOffsetCommitHandler(result, offsets, groupMetadata)); + + transactionStarted = true; + } else if (error == Errors.COORDINATOR_NOT_AVAILABLE || error == Errors.NOT_COORDINATOR) { + lookupCoordinator(FindCoordinatorRequest.CoordinatorType.TRANSACTION, transactionalId); + reenqueue(); + } else if (error == Errors.COORDINATOR_LOAD_IN_PROGRESS || error == Errors.CONCURRENT_TRANSACTIONS) { + reenqueue(); + } else if (error == Errors.UNKNOWN_PRODUCER_ID || error == Errors.INVALID_PRODUCER_ID_MAPPING) { + abortableErrorIfPossible(error.exception()); + } else if (error == Errors.INVALID_PRODUCER_EPOCH || error == Errors.PRODUCER_FENCED) { + // We could still receive INVALID_PRODUCER_EPOCH from old versioned transaction coordinator, + // just treat it the same as PRODUCE_FENCED. 
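// Illustrative sketch (editor's note, not part of this patch): the consume-transform-produce loop that
// reaches this handler through sendOffsetsToTransaction(...), committing consumer offsets through the
// producer so they become visible atomically with the produced records. Hypothetical application code:
//
//   Map<TopicPartition, OffsetAndMetadata> offsets = new HashMap<>();
//   for (ConsumerRecord<String, String> record : records) {
//       producer.send(transform(record));       // transform(...) is a placeholder for application logic
//       offsets.put(new TopicPartition(record.topic(), record.partition()),
//                   new OffsetAndMetadata(record.offset() + 1));   // commit the next offset to read
//   }
//   producer.sendOffsetsToTransaction(offsets, consumer.groupMetadata());
//   producer.commitTransaction();
//
// A successful AddOffsetsToTxn response only registers the consumer group with the transaction; the
// offsets themselves are written by the TxnOffsetCommit request that this handler enqueues on success.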
+ fatalError(Errors.PRODUCER_FENCED.exception()); + } else if (error == Errors.TRANSACTIONAL_ID_AUTHORIZATION_FAILED) { + fatalError(error.exception()); + } else if (error == Errors.GROUP_AUTHORIZATION_FAILED) { + abortableError(GroupAuthorizationException.forGroupId(builder.data.groupId())); + } else { + fatalError(new KafkaException("Unexpected error in AddOffsetsToTxnResponse: " + error.message())); + } + } + } + + private class TxnOffsetCommitHandler extends TxnRequestHandler { + private final TxnOffsetCommitRequest.Builder builder; + + private TxnOffsetCommitHandler(TransactionalRequestResult result, + TxnOffsetCommitRequest.Builder builder) { + super(result); + this.builder = builder; + } + + @Override + TxnOffsetCommitRequest.Builder requestBuilder() { + return builder; + } + + @Override + Priority priority() { + return Priority.ADD_PARTITIONS_OR_OFFSETS; + } + + @Override + FindCoordinatorRequest.CoordinatorType coordinatorType() { + return FindCoordinatorRequest.CoordinatorType.GROUP; + } + + @Override + String coordinatorKey() { + return builder.data.groupId(); + } + + @Override + public void handleResponse(AbstractResponse response) { + TxnOffsetCommitResponse txnOffsetCommitResponse = (TxnOffsetCommitResponse) response; + boolean coordinatorReloaded = false; + Map errors = txnOffsetCommitResponse.errors(); + + log.debug("Received TxnOffsetCommit response for consumer group {}: {}", builder.data.groupId(), + errors); + + for (Map.Entry entry : errors.entrySet()) { + TopicPartition topicPartition = entry.getKey(); + Errors error = entry.getValue(); + if (error == Errors.NONE) { + pendingTxnOffsetCommits.remove(topicPartition); + } else if (error == Errors.COORDINATOR_NOT_AVAILABLE + || error == Errors.NOT_COORDINATOR + || error == Errors.REQUEST_TIMED_OUT) { + if (!coordinatorReloaded) { + coordinatorReloaded = true; + lookupCoordinator(FindCoordinatorRequest.CoordinatorType.GROUP, builder.data.groupId()); + } + } else if (error == Errors.UNKNOWN_TOPIC_OR_PARTITION + || error == Errors.COORDINATOR_LOAD_IN_PROGRESS) { + // If the topic is unknown or the coordinator is loading, retry with the current coordinator + continue; + } else if (error == Errors.GROUP_AUTHORIZATION_FAILED) { + abortableError(GroupAuthorizationException.forGroupId(builder.data.groupId())); + break; + } else if (error == Errors.FENCED_INSTANCE_ID) { + abortableError(error.exception()); + break; + } else if (error == Errors.UNKNOWN_MEMBER_ID + || error == Errors.ILLEGAL_GENERATION) { + abortableError(new CommitFailedException("Transaction offset Commit failed " + + "due to consumer group metadata mismatch: " + error.exception().getMessage())); + break; + } else if (isFatalException(error)) { + fatalError(error.exception()); + break; + } else { + fatalError(new KafkaException("Unexpected error in TxnOffsetCommitResponse: " + error.message())); + break; + } + } + + if (result.isCompleted()) { + pendingTxnOffsetCommits.clear(); + } else if (pendingTxnOffsetCommits.isEmpty()) { + result.done(); + } else { + // Retry the commits which failed with a retriable error + reenqueue(); + } + } + } + + private boolean isFatalException(Errors error) { + return error == Errors.TRANSACTIONAL_ID_AUTHORIZATION_FAILED + || error == Errors.INVALID_PRODUCER_EPOCH + || error == Errors.PRODUCER_FENCED + || error == Errors.UNSUPPORTED_FOR_MESSAGE_FORMAT; + } +} diff --git a/clients/src/main/java/org/apache/kafka/clients/producer/internals/TransactionalRequestResult.java 
b/clients/src/main/java/org/apache/kafka/clients/producer/internals/TransactionalRequestResult.java new file mode 100644 index 0000000..d442b18 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/clients/producer/internals/TransactionalRequestResult.java @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients.producer.internals; + + +import org.apache.kafka.common.errors.InterruptException; +import org.apache.kafka.common.errors.TimeoutException; + +import java.util.Locale; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; + +public final class TransactionalRequestResult { + + private final CountDownLatch latch; + private volatile RuntimeException error = null; + private final String operation; + + public TransactionalRequestResult(String operation) { + this(new CountDownLatch(1), operation); + } + + private TransactionalRequestResult(CountDownLatch latch, String operation) { + this.latch = latch; + this.operation = operation; + } + + public void fail(RuntimeException error) { + this.error = error; + this.latch.countDown(); + } + + public void done() { + this.latch.countDown(); + } + + public void await() { + boolean completed = false; + + while (!completed) { + try { + latch.await(); + completed = true; + } catch (InterruptedException e) { + // Keep waiting until done, we have no other option for these transactional requests. + } + } + + if (!isSuccessful()) + throw error(); + } + + public void await(long timeout, TimeUnit unit) { + try { + boolean success = latch.await(timeout, unit); + if (!isSuccessful()) { + throw error(); + } + if (!success) { + throw new TimeoutException("Timeout expired after " + timeout + " " + unit.name().toLowerCase(Locale.ROOT) + " while awaiting " + operation); + } + } catch (InterruptedException e) { + throw new InterruptException("Received interrupt while awaiting " + operation, e); + } + } + + public RuntimeException error() { + return error; + } + + public boolean isSuccessful() { + return error == null; + } + + public boolean isCompleted() { + return latch.getCount() == 0L; + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/Cluster.java b/clients/src/main/java/org/apache/kafka/common/Cluster.java new file mode 100644 index 0000000..7d3f6f0 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/Cluster.java @@ -0,0 +1,391 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common; + +import java.net.InetSocketAddress; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.Set; + +/** + * An immutable representation of a subset of the nodes, topics, and partitions in the Kafka cluster. + */ +public final class Cluster { + + private final boolean isBootstrapConfigured; + private final List nodes; + private final Set unauthorizedTopics; + private final Set invalidTopics; + private final Set internalTopics; + private final Node controller; + private final Map partitionsByTopicPartition; + private final Map> partitionsByTopic; + private final Map> availablePartitionsByTopic; + private final Map> partitionsByNode; + private final Map nodesById; + private final ClusterResource clusterResource; + private final Map topicIds; + private final Map topicNames; + + /** + * Create a new cluster with the given id, nodes and partitions + * @param nodes The nodes in the cluster + * @param partitions Information about a subset of the topic-partitions this cluster hosts + */ + public Cluster(String clusterId, + Collection nodes, + Collection partitions, + Set unauthorizedTopics, + Set internalTopics) { + this(clusterId, false, nodes, partitions, unauthorizedTopics, Collections.emptySet(), internalTopics, null, Collections.emptyMap()); + } + + /** + * Create a new cluster with the given id, nodes and partitions + * @param nodes The nodes in the cluster + * @param partitions Information about a subset of the topic-partitions this cluster hosts + */ + public Cluster(String clusterId, + Collection nodes, + Collection partitions, + Set unauthorizedTopics, + Set internalTopics, + Node controller) { + this(clusterId, false, nodes, partitions, unauthorizedTopics, Collections.emptySet(), internalTopics, controller, Collections.emptyMap()); + } + + /** + * Create a new cluster with the given id, nodes and partitions + * @param nodes The nodes in the cluster + * @param partitions Information about a subset of the topic-partitions this cluster hosts + */ + public Cluster(String clusterId, + Collection nodes, + Collection partitions, + Set unauthorizedTopics, + Set invalidTopics, + Set internalTopics, + Node controller) { + this(clusterId, false, nodes, partitions, unauthorizedTopics, invalidTopics, internalTopics, controller, Collections.emptyMap()); + } + + /** + * Create a new cluster with the given id, nodes, partitions and topicIds + * @param nodes The nodes in the cluster + * @param partitions Information about a subset of the topic-partitions this cluster hosts + */ + public Cluster(String clusterId, + Collection nodes, + Collection partitions, + Set unauthorizedTopics, + Set invalidTopics, + Set internalTopics, + Node controller, + Map topicIds) 
{ + this(clusterId, false, nodes, partitions, unauthorizedTopics, invalidTopics, internalTopics, controller, topicIds); + } + + private Cluster(String clusterId, + boolean isBootstrapConfigured, + Collection nodes, + Collection partitions, + Set unauthorizedTopics, + Set invalidTopics, + Set internalTopics, + Node controller, + Map topicIds) { + this.isBootstrapConfigured = isBootstrapConfigured; + this.clusterResource = new ClusterResource(clusterId); + // make a randomized, unmodifiable copy of the nodes + List copy = new ArrayList<>(nodes); + Collections.shuffle(copy); + this.nodes = Collections.unmodifiableList(copy); + + // Index the nodes for quick lookup + Map tmpNodesById = new HashMap<>(); + Map> tmpPartitionsByNode = new HashMap<>(nodes.size()); + for (Node node : nodes) { + tmpNodesById.put(node.id(), node); + // Populate the map here to make it easy to add the partitions per node efficiently when iterating over + // the partitions + tmpPartitionsByNode.put(node.id(), new ArrayList<>()); + } + this.nodesById = Collections.unmodifiableMap(tmpNodesById); + + // index the partition infos by topic, topic+partition, and node + // note that this code is performance sensitive if there are a large number of partitions so we are careful + // to avoid unnecessary work + Map tmpPartitionsByTopicPartition = new HashMap<>(partitions.size()); + Map> tmpPartitionsByTopic = new HashMap<>(); + for (PartitionInfo p : partitions) { + tmpPartitionsByTopicPartition.put(new TopicPartition(p.topic(), p.partition()), p); + tmpPartitionsByTopic.computeIfAbsent(p.topic(), topic -> new ArrayList<>()).add(p); + + // The leader may not be known + if (p.leader() == null || p.leader().isEmpty()) + continue; + + // If it is known, its node information should be available + List partitionsForNode = Objects.requireNonNull(tmpPartitionsByNode.get(p.leader().id())); + partitionsForNode.add(p); + } + + // Update the values of `tmpPartitionsByNode` to contain unmodifiable lists + for (Map.Entry> entry : tmpPartitionsByNode.entrySet()) { + tmpPartitionsByNode.put(entry.getKey(), Collections.unmodifiableList(entry.getValue())); + } + + // Populate `tmpAvailablePartitionsByTopic` and update the values of `tmpPartitionsByTopic` to contain + // unmodifiable lists + Map> tmpAvailablePartitionsByTopic = new HashMap<>(tmpPartitionsByTopic.size()); + for (Map.Entry> entry : tmpPartitionsByTopic.entrySet()) { + String topic = entry.getKey(); + List partitionsForTopic = Collections.unmodifiableList(entry.getValue()); + tmpPartitionsByTopic.put(topic, partitionsForTopic); + // Optimise for the common case where all partitions are available + boolean foundUnavailablePartition = partitionsForTopic.stream().anyMatch(p -> p.leader() == null); + List availablePartitionsForTopic; + if (foundUnavailablePartition) { + availablePartitionsForTopic = new ArrayList<>(partitionsForTopic.size()); + for (PartitionInfo p : partitionsForTopic) { + if (p.leader() != null) + availablePartitionsForTopic.add(p); + } + availablePartitionsForTopic = Collections.unmodifiableList(availablePartitionsForTopic); + } else { + availablePartitionsForTopic = partitionsForTopic; + } + tmpAvailablePartitionsByTopic.put(topic, availablePartitionsForTopic); + } + + this.partitionsByTopicPartition = Collections.unmodifiableMap(tmpPartitionsByTopicPartition); + this.partitionsByTopic = Collections.unmodifiableMap(tmpPartitionsByTopic); + this.availablePartitionsByTopic = Collections.unmodifiableMap(tmpAvailablePartitionsByTopic); + this.partitionsByNode = 
Collections.unmodifiableMap(tmpPartitionsByNode); + this.topicIds = Collections.unmodifiableMap(topicIds); + Map tmpTopicNames = new HashMap<>(); + topicIds.forEach((key, value) -> tmpTopicNames.put(value, key)); + this.topicNames = Collections.unmodifiableMap(tmpTopicNames); + + this.unauthorizedTopics = Collections.unmodifiableSet(unauthorizedTopics); + this.invalidTopics = Collections.unmodifiableSet(invalidTopics); + this.internalTopics = Collections.unmodifiableSet(internalTopics); + this.controller = controller; + } + + /** + * Create an empty cluster instance with no nodes and no topic-partitions. + */ + public static Cluster empty() { + return new Cluster(null, new ArrayList<>(0), new ArrayList<>(0), Collections.emptySet(), + Collections.emptySet(), null); + } + + /** + * Create a "bootstrap" cluster using the given list of host/ports + * @param addresses The addresses + * @return A cluster for these hosts/ports + */ + public static Cluster bootstrap(List addresses) { + List nodes = new ArrayList<>(); + int nodeId = -1; + for (InetSocketAddress address : addresses) + nodes.add(new Node(nodeId--, address.getHostString(), address.getPort())); + return new Cluster(null, true, nodes, new ArrayList<>(0), + Collections.emptySet(), Collections.emptySet(), Collections.emptySet(), null, Collections.emptyMap()); + } + + /** + * Return a copy of this cluster combined with `partitions`. + */ + public Cluster withPartitions(Map partitions) { + Map combinedPartitions = new HashMap<>(this.partitionsByTopicPartition); + combinedPartitions.putAll(partitions); + return new Cluster(clusterResource.clusterId(), this.nodes, combinedPartitions.values(), + new HashSet<>(this.unauthorizedTopics), new HashSet<>(this.invalidTopics), + new HashSet<>(this.internalTopics), this.controller); + } + + /** + * @return The known set of nodes + */ + public List nodes() { + return this.nodes; + } + + /** + * Get the node by the node id (or null if the node is not online or does not exist) + * @param id The id of the node + * @return The node, or null if the node is not online or does not exist + */ + public Node nodeById(int id) { + return this.nodesById.get(id); + } + + /** + * Get the node by node id if the replica for the given partition is online + * @param partition + * @param id + * @return the node + */ + public Optional nodeIfOnline(TopicPartition partition, int id) { + Node node = nodeById(id); + if (node != null && !Arrays.asList(partition(partition).offlineReplicas()).contains(node)) { + return Optional.of(node); + } else { + return Optional.empty(); + } + } + + /** + * Get the current leader for the given topic-partition + * @param topicPartition The topic and partition we want to know the leader for + * @return The node that is the leader for this topic-partition, or null if there is currently no leader + */ + public Node leaderFor(TopicPartition topicPartition) { + PartitionInfo info = partitionsByTopicPartition.get(topicPartition); + if (info == null) + return null; + else + return info.leader(); + } + + /** + * Get the metadata for the specified partition + * @param topicPartition The topic and partition to fetch info for + * @return The metadata about the given topic and partition, or null if none is found + */ + public PartitionInfo partition(TopicPartition topicPartition) { + return partitionsByTopicPartition.get(topicPartition); + } + + /** + * Get the list of partitions for this topic + * @param topic The topic name + * @return A list of partitions + */ + public List partitionsForTopic(String 
topic) { + return partitionsByTopic.getOrDefault(topic, Collections.emptyList()); + } + + /** + * Get the number of partitions for the given topic. + * @param topic The topic to get the number of partitions for + * @return The number of partitions or null if there is no corresponding metadata + */ + public Integer partitionCountForTopic(String topic) { + List partitions = this.partitionsByTopic.get(topic); + return partitions == null ? null : partitions.size(); + } + + /** + * Get the list of available partitions for this topic + * @param topic The topic name + * @return A list of partitions + */ + public List availablePartitionsForTopic(String topic) { + return availablePartitionsByTopic.getOrDefault(topic, Collections.emptyList()); + } + + /** + * Get the list of partitions whose leader is this node + * @param nodeId The node id + * @return A list of partitions + */ + public List partitionsForNode(int nodeId) { + return partitionsByNode.getOrDefault(nodeId, Collections.emptyList()); + } + + /** + * Get all topics. + * @return a set of all topics + */ + public Set topics() { + return partitionsByTopic.keySet(); + } + + public Set unauthorizedTopics() { + return unauthorizedTopics; + } + + public Set invalidTopics() { + return invalidTopics; + } + + public Set internalTopics() { + return internalTopics; + } + + public boolean isBootstrapConfigured() { + return isBootstrapConfigured; + } + + public ClusterResource clusterResource() { + return clusterResource; + } + + public Node controller() { + return controller; + } + + public Collection topicIds() { + return topicIds.values(); + } + + public Uuid topicId(String topic) { + return topicIds.getOrDefault(topic, Uuid.ZERO_UUID); + } + + public String topicName(Uuid topicId) { + return topicNames.get(topicId); + } + + @Override + public String toString() { + return "Cluster(id = " + clusterResource.clusterId() + ", nodes = " + this.nodes + + ", partitions = " + this.partitionsByTopicPartition.values() + ", controller = " + controller + ")"; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + Cluster cluster = (Cluster) o; + return isBootstrapConfigured == cluster.isBootstrapConfigured && + Objects.equals(nodes, cluster.nodes) && + Objects.equals(unauthorizedTopics, cluster.unauthorizedTopics) && + Objects.equals(invalidTopics, cluster.invalidTopics) && + Objects.equals(internalTopics, cluster.internalTopics) && + Objects.equals(controller, cluster.controller) && + Objects.equals(partitionsByTopicPartition, cluster.partitionsByTopicPartition) && + Objects.equals(clusterResource, cluster.clusterResource); + } + + @Override + public int hashCode() { + return Objects.hash(isBootstrapConfigured, nodes, unauthorizedTopics, invalidTopics, internalTopics, controller, + partitionsByTopicPartition, clusterResource); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/ClusterResource.java b/clients/src/main/java/org/apache/kafka/common/ClusterResource.java new file mode 100644 index 0000000..749f2d1 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/ClusterResource.java @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common; + + +import java.util.Objects; + +/** + * The ClusterResource class encapsulates metadata for a Kafka cluster. + */ +public class ClusterResource { + + private final String clusterId; + + /** + * Create {@link ClusterResource} with a cluster id. Note that cluster id may be {@code null} if the + * metadata request was sent to a broker without support for cluster ids. The first version of Kafka + * to support cluster id is 0.10.1.0. + * @param clusterId + */ + public ClusterResource(String clusterId) { + this.clusterId = clusterId; + } + + /** + * Return the cluster id. Note that it may be {@code null} if the metadata request was sent to a broker without + * support for cluster ids. The first version of Kafka to support cluster id is 0.10.1.0. + */ + public String clusterId() { + return clusterId; + } + + @Override + public String toString() { + return "ClusterResource(clusterId=" + clusterId + ")"; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + ClusterResource that = (ClusterResource) o; + return Objects.equals(clusterId, that.clusterId); + } + + @Override + public int hashCode() { + return Objects.hash(clusterId); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/ClusterResourceListener.java b/clients/src/main/java/org/apache/kafka/common/ClusterResourceListener.java new file mode 100644 index 0000000..f1939df --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/ClusterResourceListener.java @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common; + +/** + * A callback interface that users can implement when they wish to get notified about changes in the Cluster metadata. + *

 + * Users who need access to cluster metadata in interceptors, metric reporters, serializers and deserializers
 + * can implement this interface. The order of method calls for each of these types is described below.
 + * <p>
 + * <h4>Clients</h4>
 + * There will be one invocation of {@link ClusterResourceListener#onUpdate(ClusterResource)} after each metadata response.
 + * Note that the cluster id may be null when the Kafka broker version is below 0.10.1.0. If you receive a null cluster id, you can expect it to always be null unless you have a cluster with multiple broker versions, which can happen if the cluster is being upgraded while the client is running.
 + * <p>
 + * {@link org.apache.kafka.clients.producer.ProducerInterceptor} : The {@link ClusterResourceListener#onUpdate(ClusterResource)} method will be invoked after {@link org.apache.kafka.clients.producer.ProducerInterceptor#onSend(org.apache.kafka.clients.producer.ProducerRecord)}
 + * but before {@link org.apache.kafka.clients.producer.ProducerInterceptor#onAcknowledgement(org.apache.kafka.clients.producer.RecordMetadata, Exception)}.
 + * <p>
 + * {@link org.apache.kafka.clients.consumer.ConsumerInterceptor} : The {@link ClusterResourceListener#onUpdate(ClusterResource)} method will be invoked before {@link org.apache.kafka.clients.consumer.ConsumerInterceptor#onConsume(org.apache.kafka.clients.consumer.ConsumerRecords)}
 + * <p>
 + * {@link org.apache.kafka.common.serialization.Serializer} : The {@link ClusterResourceListener#onUpdate(ClusterResource)} method will be invoked before {@link org.apache.kafka.common.serialization.Serializer#serialize(String, Object)}
 + * <p>
 + * {@link org.apache.kafka.common.serialization.Deserializer} : The {@link ClusterResourceListener#onUpdate(ClusterResource)} method will be invoked before {@link org.apache.kafka.common.serialization.Deserializer#deserialize(String, byte[])}
 + * <p>
 + * {@link org.apache.kafka.common.metrics.MetricsReporter} : The {@link ClusterResourceListener#onUpdate(ClusterResource)} method will be invoked after the first {@link org.apache.kafka.clients.producer.KafkaProducer#send(org.apache.kafka.clients.producer.ProducerRecord)} invocation for Producer metrics reporters
 + * and after the first {@link org.apache.kafka.clients.consumer.KafkaConsumer#poll(java.time.Duration)} invocation for Consumer metrics
 + * reporters. The reporter may receive metric events from the network layer before this method is invoked.
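 + * <p>
 + * For illustration only, a client-side listener might be sketched as follows (the
 + * {@code ClusterIdAwareSerializer} class below is hypothetical and not part of Kafka):
 + * <pre>{@code
 + * public class ClusterIdAwareSerializer extends StringSerializer implements ClusterResourceListener {
 + *     private volatile String clusterId;
 + *
 + *     public void onUpdate(ClusterResource clusterResource) {
 + *         // called before serialize() once the first metadata response has been processed
 + *         this.clusterId = clusterResource.clusterId(); // may be null on brokers older than 0.10.1.0
 + *     }
 + * }
 + * }</pre>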

 + * <p>
 + * <h4>Broker</h4>
 + * There is a single invocation of {@link ClusterResourceListener#onUpdate(ClusterResource)} on broker start-up and the cluster metadata will never change.
 + * <p>
 + * KafkaMetricsReporter : The {@link ClusterResourceListener#onUpdate(ClusterResource)} method will be invoked during the bootup of the Kafka broker. The reporter may receive metric events from the network layer before this method is invoked.
 + * <p>
        + * {@link org.apache.kafka.common.metrics.MetricsReporter} : The {@link ClusterResourceListener#onUpdate(ClusterResource)} method will be invoked during the bootup of the Kafka broker. The reporter may receive metric events from the network layer before this method is invoked. + */ +public interface ClusterResourceListener { + /** + * A callback method that a user can implement to get updates for {@link ClusterResource}. + * @param clusterResource cluster metadata + */ + void onUpdate(ClusterResource clusterResource); +} diff --git a/clients/src/main/java/org/apache/kafka/common/Configurable.java b/clients/src/main/java/org/apache/kafka/common/Configurable.java new file mode 100644 index 0000000..ecca298 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/Configurable.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common; + +import java.util.Map; + +/** + * A Mix-in style interface for classes that are instantiated by reflection and need to take configuration parameters + */ +public interface Configurable { + + /** + * Configure this class with the given key-value pairs + */ + void configure(Map configs); + +} diff --git a/clients/src/main/java/org/apache/kafka/common/ConsumerGroupState.java b/clients/src/main/java/org/apache/kafka/common/ConsumerGroupState.java new file mode 100644 index 0000000..ebd2b53 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/ConsumerGroupState.java @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.common; + +import java.util.Arrays; +import java.util.Map; +import java.util.function.Function; +import java.util.stream.Collectors; + +/** + * The consumer group state. 
+ */ +public enum ConsumerGroupState { + UNKNOWN("Unknown"), + PREPARING_REBALANCE("PreparingRebalance"), + COMPLETING_REBALANCE("CompletingRebalance"), + STABLE("Stable"), + DEAD("Dead"), + EMPTY("Empty"); + + private final static Map NAME_TO_ENUM = Arrays.stream(values()) + .collect(Collectors.toMap(state -> state.name, Function.identity()));; + + private final String name; + + ConsumerGroupState(String name) { + this.name = name; + } + + /** + * Parse a string into a consumer group state. + */ + public static ConsumerGroupState parse(String name) { + ConsumerGroupState state = NAME_TO_ENUM.get(name); + return state == null ? UNKNOWN : state; + } + + @Override + public String toString() { + return name; + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/ElectionType.java b/clients/src/main/java/org/apache/kafka/common/ElectionType.java new file mode 100644 index 0000000..55331c5 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/ElectionType.java @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.common; + +import org.apache.kafka.common.annotation.InterfaceStability; + +import java.util.Arrays; +import java.util.Set; + +/** + * Options for {@link org.apache.kafka.clients.admin.Admin#electLeaders(ElectionType, Set, org.apache.kafka.clients.admin.ElectLeadersOptions)}. + * + * The API of this class is evolving, see {@link org.apache.kafka.clients.admin.Admin} for details. + */ +@InterfaceStability.Evolving +public enum ElectionType { + PREFERRED((byte) 0), UNCLEAN((byte) 1); + + public final byte value; + + ElectionType(byte value) { + this.value = value; + } + + public static ElectionType valueOf(byte value) { + if (value == PREFERRED.value) { + return PREFERRED; + } else if (value == UNCLEAN.value) { + return UNCLEAN; + } else { + throw new IllegalArgumentException( + String.format("Value %s must be one of %s", value, Arrays.asList(ElectionType.values()))); + } + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/Endpoint.java b/clients/src/main/java/org/apache/kafka/common/Endpoint.java new file mode 100644 index 0000000..2353de2 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/Endpoint.java @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common; + +import java.util.Objects; +import java.util.Optional; + +import org.apache.kafka.common.annotation.InterfaceStability; +import org.apache.kafka.common.security.auth.SecurityProtocol; + +/** + * Represents a broker endpoint. + */ + +@InterfaceStability.Evolving +public class Endpoint { + + private final String listenerName; + private final SecurityProtocol securityProtocol; + private final String host; + private final int port; + + public Endpoint(String listenerName, SecurityProtocol securityProtocol, String host, int port) { + this.listenerName = listenerName; + this.securityProtocol = securityProtocol; + this.host = host; + this.port = port; + } + + /** + * Returns the listener name of this endpoint. This is non-empty for endpoints provided + * to broker plugins, but may be empty when used in clients. + */ + public Optional listenerName() { + return Optional.ofNullable(listenerName); + } + + /** + * Returns the security protocol of this endpoint. + */ + public SecurityProtocol securityProtocol() { + return securityProtocol; + } + + /** + * Returns advertised host name of this endpoint. + */ + public String host() { + return host; + } + + /** + * Returns the port to which the listener is bound. + */ + public int port() { + return port; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (!(o instanceof Endpoint)) { + return false; + } + + Endpoint that = (Endpoint) o; + return Objects.equals(this.listenerName, that.listenerName) && + Objects.equals(this.securityProtocol, that.securityProtocol) && + Objects.equals(this.host, that.host) && + this.port == that.port; + + } + + @Override + public int hashCode() { + return Objects.hash(listenerName, securityProtocol, host, port); + } + + @Override + public String toString() { + return "Endpoint(" + + "listenerName='" + listenerName + '\'' + + ", securityProtocol=" + securityProtocol + + ", host='" + host + '\'' + + ", port=" + port + + ')'; + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/InvalidRecordException.java b/clients/src/main/java/org/apache/kafka/common/InvalidRecordException.java new file mode 100644 index 0000000..4c2815b --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/InvalidRecordException.java @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common; + +import org.apache.kafka.common.errors.ApiException; + +public class InvalidRecordException extends ApiException { + + private static final long serialVersionUID = 1; + + public InvalidRecordException(String s) { + super(s); + } + + public InvalidRecordException(String message, Throwable cause) { + super(message, cause); + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/IsolationLevel.java b/clients/src/main/java/org/apache/kafka/common/IsolationLevel.java new file mode 100644 index 0000000..79f0a92 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/IsolationLevel.java @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common; + +public enum IsolationLevel { + READ_UNCOMMITTED((byte) 0), READ_COMMITTED((byte) 1); + + private final byte id; + + IsolationLevel(byte id) { + this.id = id; + } + + public byte id() { + return id; + } + + public static IsolationLevel forId(byte id) { + switch (id) { + case 0: + return READ_UNCOMMITTED; + case 1: + return READ_COMMITTED; + default: + throw new IllegalArgumentException("Unknown isolation level " + id); + } + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/KafkaException.java b/clients/src/main/java/org/apache/kafka/common/KafkaException.java new file mode 100644 index 0000000..7a20691 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/KafkaException.java @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common; + +/** + * The base class of all other Kafka exceptions + */ +public class KafkaException extends RuntimeException { + + private final static long serialVersionUID = 1L; + + public KafkaException(String message, Throwable cause) { + super(message, cause); + } + + public KafkaException(String message) { + super(message); + } + + public KafkaException(Throwable cause) { + super(cause); + } + + public KafkaException() { + super(); + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/KafkaFuture.java b/clients/src/main/java/org/apache/kafka/common/KafkaFuture.java new file mode 100644 index 0000000..84aed74 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/KafkaFuture.java @@ -0,0 +1,215 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common; + +import org.apache.kafka.common.internals.KafkaFutureImpl; + +import java.util.Arrays; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionStage; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + +/** + * A flexible future which supports call chaining and other asynchronous programming patterns. + * + *

 + * <h3>Relation to {@code CompletionStage}</h3>
 + * <p>
        It is possible to obtain a {@code CompletionStage} from a + * {@code KafkaFuture} instance by calling {@link #toCompletionStage()}. + * If converting {@link KafkaFuture#whenComplete(BiConsumer)} or {@link KafkaFuture#thenApply(BaseFunction)} to + * {@link CompletableFuture#whenComplete(java.util.function.BiConsumer)} or + * {@link CompletableFuture#thenApply(java.util.function.Function)} be aware that the returned + * {@code KafkaFuture} will fail with an {@code ExecutionException}, whereas a {@code CompletionStage} fails + * with a {@code CompletionException}. + */ +public abstract class KafkaFuture implements Future { + /** + * A function which takes objects of type A and returns objects of type B. + */ + @FunctionalInterface + public interface BaseFunction { + B apply(A a); + } + + /** + * A function which takes objects of type A and returns objects of type B. + * + * @deprecated Since Kafka 3.0. Use the {@link BaseFunction} functional interface. + */ + @Deprecated + public static abstract class Function implements BaseFunction { } + + /** + * A consumer of two different types of object. + */ + @FunctionalInterface + public interface BiConsumer { + void accept(A a, B b); + } + + /** + * Returns a new KafkaFuture that is already completed with the given value. + */ + public static KafkaFuture completedFuture(U value) { + KafkaFuture future = new KafkaFutureImpl<>(); + future.complete(value); + return future; + } + + /** + * Returns a new KafkaFuture that is completed when all the given futures have completed. If + * any future throws an exception, the returned future returns it. If multiple futures throw + * an exception, which one gets returned is arbitrarily chosen. + */ + public static KafkaFuture allOf(KafkaFuture... futures) { + KafkaFutureImpl result = new KafkaFutureImpl<>(); + CompletableFuture.allOf(Arrays.stream(futures) + .map(kafkaFuture -> { + // Safe since KafkaFuture's only subclass is KafkaFuture for which toCompletionStage() + // always return a CF. + return (CompletableFuture) kafkaFuture.toCompletionStage(); + }) + .toArray(CompletableFuture[]::new)).whenComplete((value, ex) -> { + if (ex == null) { + result.complete(value); + } else { + // Have to unwrap the CompletionException which allOf() introduced + result.completeExceptionally(ex.getCause()); + } + }); + + return result; + } + + /** + * Gets a {@code CompletionStage} with the same completion properties as this {@code KafkaFuture}. + * The returned instance will complete when this future completes and in the same way + * (with the same result or exception). + * + *

 + * <p>Calling {@code toCompletableFuture()} on the returned instance will yield a {@code CompletableFuture},
 + * but invocation of the completion methods ({@code complete()} and other methods in the {@code complete*()}
 + * and {@code obtrude*()} families) on that {@code CompletableFuture} instance will result in
 + * {@code UnsupportedOperationException} being thrown. Unlike a "minimal" {@code CompletableFuture},
 + * the {@code get*()} and other methods of {@code CompletableFuture} that are not inherited from
 + * {@code CompletionStage} will work normally.
 + *
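 + * <p>For illustration only, a minimal sketch of the conversion (the value and variable names here are
 + * made up for this example):
 + * <pre>{@code
 + * KafkaFuture<String> kafkaFuture = KafkaFuture.completedFuture("value");
 + * CompletionStage<String> stage = kafkaFuture.toCompletionStage();
 + * stage.thenAccept(value -> System.out.println(value)); // runs once the future completes
 + * }</pre>
 + *
 + * <p>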

        If you want to block on the completion of a KafkaFuture you should use + * {@link #get()}, {@link #get(long, TimeUnit)} or {@link #getNow(Object)}, rather then calling + * {@code .toCompletionStage().toCompletableFuture().get()} etc. + * + * @since Kafka 3.0 + */ + public abstract CompletionStage toCompletionStage(); + + /** + * Returns a new KafkaFuture that, when this future completes normally, is executed with this + * futures's result as the argument to the supplied function. + * + * The function may be invoked by the thread that calls {@code thenApply} or it may be invoked by the thread that + * completes the future. + */ + public abstract KafkaFuture thenApply(BaseFunction function); + + /** + * @see KafkaFuture#thenApply(BaseFunction) + * + * Prefer {@link KafkaFuture#thenApply(BaseFunction)} as this function is here for backwards compatibility reasons + * and might be deprecated/removed in a future release. + */ + public abstract KafkaFuture thenApply(Function function); + + /** + * Returns a new KafkaFuture with the same result or exception as this future, that executes the given action + * when this future completes. + * + * When this future is done, the given action is invoked with the result (or null if none) and the exception + * (or null if none) of this future as arguments. + * + * The returned future is completed when the action returns. + * The supplied action should not throw an exception. However, if it does, the following rules apply: + * if this future completed normally but the supplied action throws an exception, then the returned future completes + * exceptionally with the supplied action's exception. + * Or, if this future completed exceptionally and the supplied action throws an exception, then the returned future + * completes exceptionally with this future's exception. + * + * The action may be invoked by the thread that calls {@code whenComplete} or it may be invoked by the thread that + * completes the future. + * + * @param action the action to preform + * @return the new future + */ + public abstract KafkaFuture whenComplete(BiConsumer action); + + /** + * If not already completed, sets the value returned by get() and related methods to the given + * value. + */ + protected abstract boolean complete(T newValue); + + /** + * If not already completed, causes invocations of get() and related methods to throw the given + * exception. + */ + protected abstract boolean completeExceptionally(Throwable newException); + + /** + * If not already completed, completes this future with a CancellationException. Dependent + * futures that have not already completed will also complete exceptionally, with a + * CompletionException caused by this CancellationException. + */ + @Override + public abstract boolean cancel(boolean mayInterruptIfRunning); + + /** + * Waits if necessary for this future to complete, and then returns its result. + */ + @Override + public abstract T get() throws InterruptedException, ExecutionException; + + /** + * Waits if necessary for at most the given time for this future to complete, and then returns + * its result, if available. + */ + @Override + public abstract T get(long timeout, TimeUnit unit) throws InterruptedException, ExecutionException, + TimeoutException; + + /** + * Returns the result value (or throws any encountered exception) if completed, else returns + * the given valueIfAbsent. 
+ */ + public abstract T getNow(T valueIfAbsent) throws InterruptedException, ExecutionException; + + /** + * Returns true if this CompletableFuture was cancelled before it completed normally. + */ + @Override + public abstract boolean isCancelled(); + + /** + * Returns true if this CompletableFuture completed exceptionally, in any way. + */ + public abstract boolean isCompletedExceptionally(); + + /** + * Returns true if completed in any fashion: normally, exceptionally, or via cancellation. + */ + @Override + public abstract boolean isDone(); +} diff --git a/clients/src/main/java/org/apache/kafka/common/MessageFormatter.java b/clients/src/main/java/org/apache/kafka/common/MessageFormatter.java new file mode 100644 index 0000000..390aa61 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/MessageFormatter.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common; + +import java.io.Closeable; +import java.io.PrintStream; +import java.util.Map; + +import org.apache.kafka.clients.consumer.ConsumerRecord; + +/** + * This interface allows to define Formatters that can be used to parse and format records read by a + * Consumer instance for display. + * The kafka-console-consumer has built-in support for MessageFormatter, via the --formatter flag. + * + * Kafka provides a few implementations to display records of internal topics such as __consumer_offsets, + * __transaction_state and the MirrorMaker2 topics. + * + */ +public interface MessageFormatter extends Configurable, Closeable { + + /** + * Configures the MessageFormatter + * @param configs Map to configure the formatter + */ + default void configure(Map configs) {} + + /** + * Parses and formats a record for display + * @param consumerRecord the record to format + * @param output the print stream used to output the record + */ + void writeTo(ConsumerRecord consumerRecord, PrintStream output); + + /** + * Closes the formatter + */ + default void close() {} +} diff --git a/clients/src/main/java/org/apache/kafka/common/Metric.java b/clients/src/main/java/org/apache/kafka/common/Metric.java new file mode 100644 index 0000000..01f8137 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/Metric.java @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common; + +/** + * A metric tracked for monitoring purposes. + */ +public interface Metric { + + /** + * A name for this metric + */ + MetricName metricName(); + + /** + * The value of the metric, which may be measurable or a non-measurable gauge + */ + Object metricValue(); + +} diff --git a/clients/src/main/java/org/apache/kafka/common/MetricName.java b/clients/src/main/java/org/apache/kafka/common/MetricName.java new file mode 100644 index 0000000..b1ccf30 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/MetricName.java @@ -0,0 +1,132 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common; + +import java.util.Map; +import java.util.Objects; + +/** + * The MetricName class encapsulates a metric's name, logical group and its related attributes. It should be constructed using metrics.MetricName(...). + *

 + * This class captures the following parameters:
 + * <pre>
 + *  <b>name</b> The name of the metric
 + *  <b>group</b> logical group name of the metrics to which this metric belongs.
 + *  <b>description</b> A human-readable description to include in the metric. This is optional.
 + *  <b>tags</b> additional key/value attributes of the metric. This is optional.
 + * </pre>
 + * group, tags parameters can be used to create unique metric names while reporting in JMX or any custom reporting.
 + * <p>
 + * Ex: standard JMX MBean can be constructed like <b>domainName:type=group,key1=val1,key2=val2</b>
 + * <p>
 + *
 + * Usage looks something like this:

 + * <pre>{@code
 + * // set up metrics:
 + *
 + * Map<String, String> metricTags = new LinkedHashMap<String, String>();
 + * metricTags.put("client-id", "producer-1");
 + * metricTags.put("topic", "topic");
 + *
 + * MetricConfig metricConfig = new MetricConfig().tags(metricTags);
 + * Metrics metrics = new Metrics(metricConfig); // this is the global repository of metrics and sensors
 + *
 + * Sensor sensor = metrics.sensor("message-sizes");
 + *
 + * MetricName metricName = metrics.metricName("message-size-avg", "producer-metrics", "average message size");
 + * sensor.add(metricName, new Avg());
 + *
 + * metricName = metrics.metricName("message-size-max", "producer-metrics");
 + * sensor.add(metricName, new Max());
 + *
 + * metricName = metrics.metricName("message-size-min", "producer-metrics", "message minimum size", "client-id", "my-client", "topic", "my-topic");
 + * sensor.add(metricName, new Min());
 + *
 + * // as messages are sent we record the sizes
 + * sensor.record(messageSize);
 + * }</pre>
        + */ +public final class MetricName { + + private final String name; + private final String group; + private final String description; + private Map tags; + private int hash = 0; + + /** + * Please create MetricName by method {@link org.apache.kafka.common.metrics.Metrics#metricName(String, String, String, Map)} + * + * @param name The name of the metric + * @param group logical group name of the metrics to which this metric belongs + * @param description A human-readable description to include in the metric + * @param tags additional key/value attributes of the metric + */ + public MetricName(String name, String group, String description, Map tags) { + this.name = Objects.requireNonNull(name); + this.group = Objects.requireNonNull(group); + this.description = Objects.requireNonNull(description); + this.tags = Objects.requireNonNull(tags); + } + + public String name() { + return this.name; + } + + public String group() { + return this.group; + } + + public Map tags() { + return this.tags; + } + + public String description() { + return this.description; + } + + @Override + public int hashCode() { + if (hash != 0) + return hash; + final int prime = 31; + int result = 1; + result = prime * result + group.hashCode(); + result = prime * result + name.hashCode(); + result = prime * result + tags.hashCode(); + this.hash = result; + return result; + } + + @Override + public boolean equals(Object obj) { + if (this == obj) + return true; + if (obj == null) + return false; + if (getClass() != obj.getClass()) + return false; + MetricName other = (MetricName) obj; + return group.equals(other.group) && name.equals(other.name) && tags.equals(other.tags); + } + + @Override + public String toString() { + return "MetricName [name=" + name + ", group=" + group + ", description=" + + description + ", tags=" + tags + "]"; + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/MetricNameTemplate.java b/clients/src/main/java/org/apache/kafka/common/MetricNameTemplate.java new file mode 100644 index 0000000..f9c4ef5 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/MetricNameTemplate.java @@ -0,0 +1,129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common; + +import java.util.Collections; +import java.util.LinkedHashSet; +import java.util.Objects; +import java.util.Set; + +/** + * A template for a MetricName. It contains a name, group, and description, as + * well as all the tags that will be used to create the mBean name. Tag values + * are omitted from the template, but are filled in at runtime with their + * specified values. The order of the tags is maintained, if an ordered set + * is provided, so that the mBean names can be compared and sorted lexicographically. 
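 + * <p>
 + * For illustration only, a template might be declared and later resolved with concrete tag values.
 + * The metric and tag names below are made up for this sketch, and resolving the template through
 + * {@code Metrics#metricInstance(MetricNameTemplate, String...)} is assumed typical usage rather than
 + * the only option:
 + * <pre>{@code
 + * MetricNameTemplate template = new MetricNameTemplate("record-send-rate", "producer-topic-metrics",
 + *     "The average number of records sent per second for a topic.", "client-id", "topic");
 + *
 + * Metrics metrics = new Metrics();
 + * MetricName name = metrics.metricInstance(template, "client-id", "producer-1", "topic", "orders");
 + * }</pre>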
+ */ +public class MetricNameTemplate { + private final String name; + private final String group; + private final String description; + private LinkedHashSet tags; + + /** + * Create a new template. Note that the order of the tags will be preserved if the supplied + * {@code tagsNames} set has an order. + * + * @param name the name of the metric; may not be null + * @param group the name of the group; may not be null + * @param description the description of the metric; may not be null + * @param tagsNames the set of metric tag names, which can/should be a set that maintains order; may not be null + */ + public MetricNameTemplate(String name, String group, String description, Set tagsNames) { + this.name = Objects.requireNonNull(name); + this.group = Objects.requireNonNull(group); + this.description = Objects.requireNonNull(description); + this.tags = new LinkedHashSet<>(Objects.requireNonNull(tagsNames)); + } + + /** + * Create a new template. Note that the order of the tags will be preserved. + * + * @param name the name of the metric; may not be null + * @param group the name of the group; may not be null + * @param description the description of the metric; may not be null + * @param tagsNames the names of the metric tags in the preferred order; none of the tag names should be null + */ + public MetricNameTemplate(String name, String group, String description, String... tagsNames) { + this(name, group, description, getTags(tagsNames)); + } + + private static LinkedHashSet getTags(String... keys) { + LinkedHashSet tags = new LinkedHashSet<>(); + + Collections.addAll(tags, keys); + + return tags; + } + + /** + * Get the name of the metric. + * + * @return the metric name; never null + */ + public String name() { + return this.name; + } + + /** + * Get the name of the group. + * + * @return the group name; never null + */ + public String group() { + return this.group; + } + + /** + * Get the description of the metric. + * + * @return the metric description; never null + */ + public String description() { + return this.description; + } + + /** + * Get the set of tag names for the metric. + * + * @return the ordered set of tag names; never null but possibly empty + */ + public Set tags() { + return tags; + } + + @Override + public int hashCode() { + return Objects.hash(name, group, tags); + } + + @Override + public boolean equals(Object o) { + if (this == o) + return true; + if (o == null || getClass() != o.getClass()) + return false; + MetricNameTemplate other = (MetricNameTemplate) o; + return Objects.equals(name, other.name) && Objects.equals(group, other.group) && + Objects.equals(tags, other.tags); + } + + @Override + public String toString() { + return String.format("name=%s, group=%s, tags=%s", name, group, tags); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/Node.java b/clients/src/main/java/org/apache/kafka/common/Node.java new file mode 100644 index 0000000..020d2bc --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/Node.java @@ -0,0 +1,138 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common; + +import java.util.Objects; + +/** + * Information about a Kafka node + */ +public class Node { + + private static final Node NO_NODE = new Node(-1, "", -1); + + private final int id; + private final String idString; + private final String host; + private final int port; + private final String rack; + + // Cache hashCode as it is called in performance sensitive parts of the code (e.g. RecordAccumulator.ready) + private Integer hash; + + public Node(int id, String host, int port) { + this(id, host, port, null); + } + + public Node(int id, String host, int port, String rack) { + this.id = id; + this.idString = Integer.toString(id); + this.host = host; + this.port = port; + this.rack = rack; + } + + public static Node noNode() { + return NO_NODE; + } + + /** + * Check whether this node is empty, which may be the case if noNode() is used as a placeholder + * in a response payload with an error. + * @return true if it is, false otherwise + */ + public boolean isEmpty() { + return host == null || host.isEmpty() || port < 0; + } + + /** + * The node id of this node + */ + public int id() { + return id; + } + + /** + * String representation of the node id. + * Typically the integer id is used to serialize over the wire, the string representation is used as an identifier with NetworkClient code + */ + public String idString() { + return idString; + } + + /** + * The host name for this node + */ + public String host() { + return host; + } + + /** + * The port for this node + */ + public int port() { + return port; + } + + /** + * True if this node has a defined rack + */ + public boolean hasRack() { + return rack != null; + } + + /** + * The rack for this node + */ + public String rack() { + return rack; + } + + @Override + public int hashCode() { + Integer h = this.hash; + if (h == null) { + int result = 31 + ((host == null) ? 0 : host.hashCode()); + result = 31 * result + id; + result = 31 * result + port; + result = 31 * result + ((rack == null) ? 0 : rack.hashCode()); + this.hash = result; + return result; + } else { + return h; + } + } + + @Override + public boolean equals(Object obj) { + if (this == obj) + return true; + if (obj == null || getClass() != obj.getClass()) + return false; + Node other = (Node) obj; + return id == other.id && + port == other.port && + Objects.equals(host, other.host) && + Objects.equals(rack, other.rack); + } + + @Override + public String toString() { + return host + ":" + port + " (id: " + idString + " rack: " + rack + ")"; + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/PartitionInfo.java b/clients/src/main/java/org/apache/kafka/common/PartitionInfo.java new file mode 100644 index 0000000..29bedff --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/PartitionInfo.java @@ -0,0 +1,115 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common; + +/** + * This is used to describe per-partition state in the MetadataResponse. + */ +public class PartitionInfo { + private final String topic; + private final int partition; + private final Node leader; + private final Node[] replicas; + private final Node[] inSyncReplicas; + private final Node[] offlineReplicas; + + public PartitionInfo(String topic, int partition, Node leader, Node[] replicas, Node[] inSyncReplicas) { + this(topic, partition, leader, replicas, inSyncReplicas, new Node[0]); + } + + public PartitionInfo(String topic, + int partition, + Node leader, + Node[] replicas, + Node[] inSyncReplicas, + Node[] offlineReplicas) { + this.topic = topic; + this.partition = partition; + this.leader = leader; + this.replicas = replicas; + this.inSyncReplicas = inSyncReplicas; + this.offlineReplicas = offlineReplicas; + } + + /** + * The topic name + */ + public String topic() { + return topic; + } + + /** + * The partition id + */ + public int partition() { + return partition; + } + + /** + * The node id of the node currently acting as a leader for this partition or null if there is no leader + */ + public Node leader() { + return leader; + } + + /** + * The complete set of replicas for this partition regardless of whether they are alive or up-to-date + */ + public Node[] replicas() { + return replicas; + } + + /** + * The subset of the replicas that are in sync, that is caught-up to the leader and ready to take over as leader if + * the leader should fail + */ + public Node[] inSyncReplicas() { + return inSyncReplicas; + } + + /** + * The subset of the replicas that are offline + */ + public Node[] offlineReplicas() { + return offlineReplicas; + } + + @Override + public String toString() { + return String.format("Partition(topic = %s, partition = %d, leader = %s, replicas = %s, isr = %s, offlineReplicas = %s)", + topic, + partition, + leader == null ? "none" : leader.idString(), + formatNodeIds(replicas), + formatNodeIds(inSyncReplicas), + formatNodeIds(offlineReplicas)); + } + + /* Extract the node ids from each item in the array and format for display */ + private String formatNodeIds(Node[] nodes) { + StringBuilder b = new StringBuilder("["); + if (nodes != null) { + for (int i = 0; i < nodes.length; i++) { + b.append(nodes[i].idString()); + if (i < nodes.length - 1) + b.append(','); + } + } + b.append("]"); + return b.toString(); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/Reconfigurable.java b/clients/src/main/java/org/apache/kafka/common/Reconfigurable.java new file mode 100644 index 0000000..8db9dc2 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/Reconfigurable.java @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common; + +import org.apache.kafka.common.config.ConfigException; + +import java.util.Map; +import java.util.Set; + +/** + * Interface for reconfigurable classes that support dynamic configuration. + */ +public interface Reconfigurable extends Configurable { + + /** + * Returns the names of configs that may be reconfigured. + */ + Set reconfigurableConfigs(); + + /** + * Validates the provided configuration. The provided map contains + * all configs including any reconfigurable configs that may be different + * from the initial configuration. Reconfiguration will be not performed + * if this method throws any exception. + * @throws ConfigException if the provided configs are not valid. The exception + * message from ConfigException will be returned to the client in + * the AlterConfigs response. + */ + void validateReconfiguration(Map configs) throws ConfigException; + + /** + * Reconfigures this instance with the given key-value pairs. The provided + * map contains all configs including any reconfigurable configs that + * may have changed since the object was initially configured using + * {@link Configurable#configure(Map)}. This method will only be invoked if + * the configs have passed validation using {@link #validateReconfiguration(Map)}. + */ + void reconfigure(Map configs); + +} diff --git a/clients/src/main/java/org/apache/kafka/common/TopicCollection.java b/clients/src/main/java/org/apache/kafka/common/TopicCollection.java new file mode 100644 index 0000000..5661e6c --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/TopicCollection.java @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; + +/** + * A class used to represent a collection of topics. This collection may define topics by name or ID. 
+ */ +public abstract class TopicCollection { + + private TopicCollection() {} + + /** + * @return a collection of topics defined by topic ID + */ + public static TopicIdCollection ofTopicIds(Collection topics) { + return new TopicIdCollection(topics); + } + + /** + * @return a collection of topics defined by topic name + */ + public static TopicNameCollection ofTopicNames(Collection topics) { + return new TopicNameCollection(topics); + } + + /** + * A class used to represent a collection of topics defined by their topic ID. + * Subclassing this class beyond the classes provided here is not supported. + */ + public static class TopicIdCollection extends TopicCollection { + private final Collection topicIds; + + private TopicIdCollection(Collection topicIds) { + this.topicIds = new ArrayList<>(topicIds); + } + + /** + * @return A collection of topic IDs + */ + public Collection topicIds() { + return Collections.unmodifiableCollection(topicIds); + } + } + + /** + * A class used to represent a collection of topics defined by their topic name. + * Subclassing this class beyond the classes provided here is not supported. + */ + public static class TopicNameCollection extends TopicCollection { + private final Collection topicNames; + + private TopicNameCollection(Collection topicNames) { + this.topicNames = new ArrayList<>(topicNames); + } + + /** + * @return A collection of topic names + */ + public Collection topicNames() { + return Collections.unmodifiableCollection(topicNames); + } + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/TopicIdPartition.java b/clients/src/main/java/org/apache/kafka/common/TopicIdPartition.java new file mode 100644 index 0000000..09d861e --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/TopicIdPartition.java @@ -0,0 +1,106 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common; + +import java.util.Objects; + +/** + * This represents universally unique identifier with topic id for a topic partition. This makes sure that topics + * recreated with the same name will always have unique topic identifiers. + */ +public class TopicIdPartition { + + private final Uuid topicId; + private final TopicPartition topicPartition; + + /** + * Create an instance with the provided parameters. + * + * @param topicId the topic id + * @param topicPartition the topic partition + */ + public TopicIdPartition(Uuid topicId, TopicPartition topicPartition) { + this.topicId = Objects.requireNonNull(topicId, "topicId can not be null"); + this.topicPartition = Objects.requireNonNull(topicPartition, "topicPartition can not be null"); + } + + /** + * Create an instance with the provided parameters. 
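+ * <p>
+ * Illustrative sketch only (not part of the original source); the topic name and partition are
+ * arbitrary:
+ * <pre>{@code
+ * TopicIdPartition tip = new TopicIdPartition(Uuid.randomUuid(), 0, "orders");
+ * }</pre>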
+ * + * @param topicId the topic id + * @param partition the partition id + * @param topic the topic name or null + */ + public TopicIdPartition(Uuid topicId, int partition, String topic) { + this.topicId = Objects.requireNonNull(topicId, "topicId can not be null"); + this.topicPartition = new TopicPartition(topic, partition); + } + + /** + * @return Universally unique id representing this topic partition. + */ + public Uuid topicId() { + return topicId; + } + + /** + * @return the topic name or null if it is unknown. + */ + public String topic() { + return topicPartition.topic(); + } + + /** + * @return the partition id. + */ + public int partition() { + return topicPartition.partition(); + } + + /** + * @return Topic partition representing this instance. + */ + public TopicPartition topicPartition() { + return topicPartition; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + TopicIdPartition that = (TopicIdPartition) o; + return topicId.equals(that.topicId) && + topicPartition.equals(that.topicPartition); + } + + @Override + public int hashCode() { + final int prime = 31; + int result = prime + topicId.hashCode(); + result = prime * result + topicPartition.hashCode(); + return result; + } + + @Override + public String toString() { + return topicId + ":" + topic() + "-" + partition(); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/TopicPartition.java b/clients/src/main/java/org/apache/kafka/common/TopicPartition.java new file mode 100644 index 0000000..7c8fe79 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/TopicPartition.java @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common; + +import java.io.Serializable; +import java.util.Objects; + +/** + * A topic name and partition number + */ +public final class TopicPartition implements Serializable { + private static final long serialVersionUID = -613627415771699627L; + + private int hash = 0; + private final int partition; + private final String topic; + + public TopicPartition(String topic, int partition) { + this.partition = partition; + this.topic = topic; + } + + public int partition() { + return partition; + } + + public String topic() { + return topic; + } + + @Override + public int hashCode() { + if (hash != 0) + return hash; + final int prime = 31; + int result = prime + partition; + result = prime * result + Objects.hashCode(topic); + this.hash = result; + return result; + } + + @Override + public boolean equals(Object obj) { + if (this == obj) + return true; + if (obj == null) + return false; + if (getClass() != obj.getClass()) + return false; + TopicPartition other = (TopicPartition) obj; + return partition == other.partition && Objects.equals(topic, other.topic); + } + + @Override + public String toString() { + return topic + "-" + partition; + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/TopicPartitionInfo.java b/clients/src/main/java/org/apache/kafka/common/TopicPartitionInfo.java new file mode 100644 index 0000000..60b5d37 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/TopicPartitionInfo.java @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.common; + +import org.apache.kafka.common.utils.Utils; + +import java.util.Collections; +import java.util.List; +import java.util.Objects; + +/** + * A class containing leadership, replicas and ISR information for a topic partition. + */ +public class TopicPartitionInfo { + private final int partition; + private final Node leader; + private final List replicas; + private final List isr; + + /** + * Create an instance of this class with the provided parameters. + * + * @param partition the partition id + * @param leader the leader of the partition or {@link Node#noNode()} if there is none. + * @param replicas the replicas of the partition in the same order as the replica assignment (the preferred replica + * is the head of the list) + * @param isr the in-sync replicas + */ + public TopicPartitionInfo(int partition, Node leader, List replicas, List isr) { + this.partition = partition; + this.leader = leader; + this.replicas = Collections.unmodifiableList(replicas); + this.isr = Collections.unmodifiableList(isr); + } + + /** + * Return the partition id. + */ + public int partition() { + return partition; + } + + /** + * Return the leader of the partition or null if there is none. 
+ */ + public Node leader() { + return leader; + } + + /** + * Return the replicas of the partition in the same order as the replica assignment. The preferred replica is the + * head of the list. + * + * Brokers with version lower than 0.11.0.0 return the replicas in unspecified order due to a bug. + */ + public List replicas() { + return replicas; + } + + /** + * Return the in-sync replicas of the partition. Note that the ordering of the result is unspecified. + */ + public List isr() { + return isr; + } + + public String toString() { + return "(partition=" + partition + ", leader=" + leader + ", replicas=" + + Utils.join(replicas, ", ") + ", isr=" + Utils.join(isr, ", ") + ")"; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + TopicPartitionInfo that = (TopicPartitionInfo) o; + + return partition == that.partition && + Objects.equals(leader, that.leader) && + Objects.equals(replicas, that.replicas) && + Objects.equals(isr, that.isr); + } + + @Override + public int hashCode() { + int result = partition; + result = 31 * result + (leader != null ? leader.hashCode() : 0); + result = 31 * result + (replicas != null ? replicas.hashCode() : 0); + result = 31 * result + (isr != null ? isr.hashCode() : 0); + return result; + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/TopicPartitionReplica.java b/clients/src/main/java/org/apache/kafka/common/TopicPartitionReplica.java new file mode 100644 index 0000000..0a7c419 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/TopicPartitionReplica.java @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common; + +import java.io.Serializable; +import java.util.Objects; + + +/** + * The topic name, partition number and the brokerId of the replica + */ +public final class TopicPartitionReplica implements Serializable { + + private int hash = 0; + private final int brokerId; + private final int partition; + private final String topic; + + public TopicPartitionReplica(String topic, int partition, int brokerId) { + this.topic = Objects.requireNonNull(topic); + this.partition = partition; + this.brokerId = brokerId; + } + + public String topic() { + return topic; + } + + public int partition() { + return partition; + } + + public int brokerId() { + return brokerId; + } + + @Override + public int hashCode() { + if (hash != 0) { + return hash; + } + final int prime = 31; + int result = 1; + result = prime * result + topic.hashCode(); + result = prime * result + partition; + result = prime * result + brokerId; + this.hash = result; + return result; + } + + @Override + public boolean equals(Object obj) { + if (this == obj) + return true; + if (obj == null) + return false; + if (getClass() != obj.getClass()) + return false; + TopicPartitionReplica other = (TopicPartitionReplica) obj; + return partition == other.partition && brokerId == other.brokerId && topic.equals(other.topic); + } + + @Override + public String toString() { + return String.format("%s-%d-%d", topic, partition, brokerId); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/Uuid.java b/clients/src/main/java/org/apache/kafka/common/Uuid.java new file mode 100644 index 0000000..a639f3e --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/Uuid.java @@ -0,0 +1,149 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common; + +import java.nio.ByteBuffer; +import java.util.Base64; + +/** + * This class defines an immutable universally unique identifier (UUID). It represents a 128-bit value. + * More specifically, the random UUIDs generated by this class are variant 2 (Leach-Salz) version 4 UUIDs. + * This is the same type of UUID as the ones generated by java.util.UUID. The toString() method prints + * using the base64 string encoding. Likewise, the fromString method expects a base64 string encoding. + */ +public class Uuid implements Comparable { + + /** + * A UUID for the metadata topic in KRaft mode. Will never be returned by the randomUuid method. + */ + public static final Uuid METADATA_TOPIC_ID = new Uuid(0L, 1L); + private static final java.util.UUID METADATA_TOPIC_ID_INTERNAL = new java.util.UUID(0L, 1L); + + /** + * A UUID that represents a null or empty UUID. Will never be returned by the randomUuid method. 
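+ * <p>
+ * Illustrative sketch only (not part of the original source), showing how the constants and
+ * factory methods declared in this class fit together:
+ * <pre>{@code
+ * Uuid id = Uuid.randomUuid();                    // never METADATA_TOPIC_ID or ZERO_UUID
+ * String encoded = id.toString();                 // URL-safe, unpadded base64 encoding
+ * Uuid decoded = Uuid.fromString(encoded);        // round-trips to an equal Uuid
+ * boolean unset = decoded.equals(Uuid.ZERO_UUID); // false for any randomly generated id
+ * }</pre>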
+ */ + public static final Uuid ZERO_UUID = new Uuid(0L, 0L); + private static final java.util.UUID ZERO_ID_INTERNAL = new java.util.UUID(0L, 0L); + + private final long mostSignificantBits; + private final long leastSignificantBits; + + /** + * Constructs a 128-bit type 4 UUID where the first long represents the most significant 64 bits + * and the second long represents the least significant 64 bits. + */ + public Uuid(long mostSigBits, long leastSigBits) { + this.mostSignificantBits = mostSigBits; + this.leastSignificantBits = leastSigBits; + } + + /** + * Static factory to retrieve a type 4 (pseudo randomly generated) UUID. + */ + public static Uuid randomUuid() { + java.util.UUID uuid = java.util.UUID.randomUUID(); + while (uuid.equals(METADATA_TOPIC_ID_INTERNAL) || uuid.equals(ZERO_ID_INTERNAL)) { + uuid = java.util.UUID.randomUUID(); + } + return new Uuid(uuid.getMostSignificantBits(), uuid.getLeastSignificantBits()); + } + + /** + * Returns the most significant bits of the UUID's 128 value. + */ + public long getMostSignificantBits() { + return this.mostSignificantBits; + } + + /** + * Returns the least significant bits of the UUID's 128 value. + */ + public long getLeastSignificantBits() { + return this.leastSignificantBits; + } + + /** + * Returns true iff obj is another Uuid represented by the same two long values. + */ + @Override + public boolean equals(Object obj) { + if ((null == obj) || (obj.getClass() != this.getClass())) + return false; + Uuid id = (Uuid) obj; + return this.mostSignificantBits == id.mostSignificantBits && + this.leastSignificantBits == id.leastSignificantBits; + } + + /** + * Returns a hash code for this UUID + */ + @Override + public int hashCode() { + long xor = mostSignificantBits ^ leastSignificantBits; + return (int) (xor >> 32) ^ (int) xor; + } + + /** + * Returns a base64 string encoding of the UUID. + */ + @Override + public String toString() { + return Base64.getUrlEncoder().withoutPadding().encodeToString(getBytesFromUuid()); + } + + /** + * Creates a UUID based on a base64 string encoding used in the toString() method. + */ + public static Uuid fromString(String str) { + if (str.length() > 24) { + throw new IllegalArgumentException("Input string with prefix `" + + str.substring(0, 24) + "` is too long to be decoded as a base64 UUID"); + } + + ByteBuffer uuidBytes = ByteBuffer.wrap(Base64.getUrlDecoder().decode(str)); + if (uuidBytes.remaining() != 16) { + throw new IllegalArgumentException("Input string `" + str + "` decoded as " + + uuidBytes.remaining() + " bytes, which is not equal to the expected 16 bytes " + + "of a base64-encoded UUID"); + } + + return new Uuid(uuidBytes.getLong(), uuidBytes.getLong()); + } + + private byte[] getBytesFromUuid() { + // Extract bytes for uuid which is 128 bits (or 16 bytes) long. 
+ ByteBuffer uuidBytes = ByteBuffer.wrap(new byte[16]); + uuidBytes.putLong(this.mostSignificantBits); + uuidBytes.putLong(this.leastSignificantBits); + return uuidBytes.array(); + } + + @Override + public int compareTo(Uuid other) { + if (mostSignificantBits > other.mostSignificantBits) { + return 1; + } else if (mostSignificantBits < other.mostSignificantBits) { + return -1; + } else if (leastSignificantBits > other.leastSignificantBits) { + return 1; + } else if (leastSignificantBits < other.leastSignificantBits) { + return -1; + } else { + return 0; + } + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/acl/AccessControlEntry.java b/clients/src/main/java/org/apache/kafka/common/acl/AccessControlEntry.java new file mode 100644 index 0000000..d5e05df --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/acl/AccessControlEntry.java @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.common.acl; + +import org.apache.kafka.common.annotation.InterfaceStability; + +import java.util.Objects; + +/** + * Represents an access control entry. ACEs are a tuple of principal, host, operation, and permissionType. + * + * The API for this class is still evolving and we may break compatibility in minor releases, if necessary. + */ +@InterfaceStability.Evolving +public class AccessControlEntry { + final AccessControlEntryData data; + + /** + * Create an instance of an access control entry with the provided parameters. + * + * @param principal non-null principal + * @param host non-null host + * @param operation non-null operation, ANY is not an allowed operation + * @param permissionType non-null permission type, ANY is not an allowed type + */ + public AccessControlEntry(String principal, String host, AclOperation operation, AclPermissionType permissionType) { + Objects.requireNonNull(principal); + Objects.requireNonNull(host); + Objects.requireNonNull(operation); + if (operation == AclOperation.ANY) + throw new IllegalArgumentException("operation must not be ANY"); + Objects.requireNonNull(permissionType); + if (permissionType == AclPermissionType.ANY) + throw new IllegalArgumentException("permissionType must not be ANY"); + this.data = new AccessControlEntryData(principal, host, operation, permissionType); + } + + /** + * Return the principal for this entry. + */ + public String principal() { + return data.principal(); + } + + /** + * Return the host or `*` for all hosts. + */ + public String host() { + return data.host(); + } + + /** + * Return the AclOperation. This method will never return AclOperation.ANY. + */ + public AclOperation operation() { + return data.operation(); + } + + /** + * Return the AclPermissionType. This method will never return AclPermissionType.ANY. 
+ */ + public AclPermissionType permissionType() { + return data.permissionType(); + } + + /** + * Create a filter which matches only this AccessControlEntry. + */ + public AccessControlEntryFilter toFilter() { + return new AccessControlEntryFilter(data); + } + + @Override + public String toString() { + return data.toString(); + } + + /** + * Return true if this AclResource has any UNKNOWN components. + */ + public boolean isUnknown() { + return data.isUnknown(); + } + + @Override + public boolean equals(Object o) { + if (!(o instanceof AccessControlEntry)) + return false; + AccessControlEntry other = (AccessControlEntry) o; + return data.equals(other.data); + } + + @Override + public int hashCode() { + return data.hashCode(); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/acl/AccessControlEntryData.java b/clients/src/main/java/org/apache/kafka/common/acl/AccessControlEntryData.java new file mode 100644 index 0000000..ad7660d --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/acl/AccessControlEntryData.java @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.common.acl; + +import java.util.Objects; + +/** + * An internal, private class which contains the data stored in AccessControlEntry and + * AccessControlEntryFilter objects. + */ +class AccessControlEntryData { + private final String principal; + private final String host; + private final AclOperation operation; + private final AclPermissionType permissionType; + + AccessControlEntryData(String principal, String host, AclOperation operation, AclPermissionType permissionType) { + this.principal = principal; + this.host = host; + this.operation = operation; + this.permissionType = permissionType; + } + + String principal() { + return principal; + } + + String host() { + return host; + } + + AclOperation operation() { + return operation; + } + + AclPermissionType permissionType() { + return permissionType; + } + + /** + * Returns a string describing an ANY or UNKNOWN field, or null if there is + * no such field. + */ + public String findIndefiniteField() { + if (principal() == null) + return "Principal is NULL"; + if (host() == null) + return "Host is NULL"; + if (operation() == AclOperation.ANY) + return "Operation is ANY"; + if (operation() == AclOperation.UNKNOWN) + return "Operation is UNKNOWN"; + if (permissionType() == AclPermissionType.ANY) + return "Permission type is ANY"; + if (permissionType() == AclPermissionType.UNKNOWN) + return "Permission type is UNKNOWN"; + return null; + } + + @Override + public String toString() { + return "(principal=" + (principal == null ? "" : principal) + + ", host=" + (host == null ? 
"" : host) + + ", operation=" + operation + + ", permissionType=" + permissionType + ")"; + } + + /** + * Return true if there are any UNKNOWN components. + */ + boolean isUnknown() { + return operation.isUnknown() || permissionType.isUnknown(); + } + + @Override + public boolean equals(Object o) { + if (!(o instanceof AccessControlEntryData)) + return false; + AccessControlEntryData other = (AccessControlEntryData) o; + return Objects.equals(principal, other.principal) && + Objects.equals(host, other.host) && + Objects.equals(operation, other.operation) && + Objects.equals(permissionType, other.permissionType); + } + + @Override + public int hashCode() { + return Objects.hash(principal, host, operation, permissionType); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/acl/AccessControlEntryFilter.java b/clients/src/main/java/org/apache/kafka/common/acl/AccessControlEntryFilter.java new file mode 100644 index 0000000..225e73a --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/acl/AccessControlEntryFilter.java @@ -0,0 +1,143 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.common.acl; + +import org.apache.kafka.common.annotation.InterfaceStability; + +import java.util.Objects; + +/** + * Represents a filter which matches access control entries. + * + * The API for this class is still evolving and we may break compatibility in minor releases, if necessary. + */ +@InterfaceStability.Evolving +public class AccessControlEntryFilter { + private final AccessControlEntryData data; + + /** + * Matches any access control entry. + */ + public static final AccessControlEntryFilter ANY = + new AccessControlEntryFilter(null, null, AclOperation.ANY, AclPermissionType.ANY); + + /** + * Create an instance of an access control entry filter with the provided parameters. + * + * @param principal the principal or null + * @param host the host or null + * @param operation non-null operation + * @param permissionType non-null permission type + */ + public AccessControlEntryFilter(String principal, String host, AclOperation operation, AclPermissionType permissionType) { + Objects.requireNonNull(operation); + Objects.requireNonNull(permissionType); + this.data = new AccessControlEntryData(principal, host, operation, permissionType); + } + + /** + * This is a non-public constructor used in AccessControlEntry#toFilter + * + * @param data The access control data. + */ + AccessControlEntryFilter(AccessControlEntryData data) { + this.data = data; + } + + /** + * Return the principal or null. + */ + public String principal() { + return data.principal(); + } + + /** + * Return the host or null. The value `*` means any host. 
+ */ + public String host() { + return data.host(); + } + + /** + * Return the AclOperation. + */ + public AclOperation operation() { + return data.operation(); + } + + /** + * Return the AclPermissionType. + */ + public AclPermissionType permissionType() { + return data.permissionType(); + } + + @Override + public String toString() { + return data.toString(); + } + + /** + * Return true if there are any UNKNOWN components. + */ + public boolean isUnknown() { + return data.isUnknown(); + } + + /** + * Returns true if this filter matches the given AccessControlEntry. + */ + public boolean matches(AccessControlEntry other) { + if ((principal() != null) && (!data.principal().equals(other.principal()))) + return false; + if ((host() != null) && (!host().equals(other.host()))) + return false; + if ((operation() != AclOperation.ANY) && (!operation().equals(other.operation()))) + return false; + return (permissionType() == AclPermissionType.ANY) || (permissionType().equals(other.permissionType())); + } + + /** + * Returns true if this filter could only match one ACE -- in other words, if + * there are no ANY or UNKNOWN fields. + */ + public boolean matchesAtMostOne() { + return findIndefiniteField() == null; + } + + /** + * Returns a string describing an ANY or UNKNOWN field, or null if there is + * no such field. + */ + public String findIndefiniteField() { + return data.findIndefiniteField(); + } + + @Override + public boolean equals(Object o) { + if (!(o instanceof AccessControlEntryFilter)) + return false; + AccessControlEntryFilter other = (AccessControlEntryFilter) o; + return data.equals(other.data); + } + + @Override + public int hashCode() { + return data.hashCode(); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/acl/AclBinding.java b/clients/src/main/java/org/apache/kafka/common/acl/AclBinding.java new file mode 100644 index 0000000..c323426 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/acl/AclBinding.java @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.common.acl; + +import org.apache.kafka.common.annotation.InterfaceStability; +import org.apache.kafka.common.resource.ResourcePattern; + +import java.util.Objects; + +/** + * Represents a binding between a resource pattern and an access control entry. + * + * The API for this class is still evolving and we may break compatibility in minor releases, if necessary. + */ +@InterfaceStability.Evolving +public class AclBinding { + private final ResourcePattern pattern; + private final AccessControlEntry entry; + + /** + * Create an instance of this class with the provided parameters. + * + * @param pattern non-null resource pattern. 
+ * @param entry non-null entry + */ + public AclBinding(ResourcePattern pattern, AccessControlEntry entry) { + this.pattern = Objects.requireNonNull(pattern, "pattern"); + this.entry = Objects.requireNonNull(entry, "entry"); + } + + /** + * @return true if this binding has any UNKNOWN components. + */ + public boolean isUnknown() { + return pattern.isUnknown() || entry.isUnknown(); + } + + /** + * @return the resource pattern for this binding. + */ + public ResourcePattern pattern() { + return pattern; + } + + /** + * @return the access control entry for this binding. + */ + public final AccessControlEntry entry() { + return entry; + } + + /** + * Create a filter which matches only this AclBinding. + */ + public AclBindingFilter toFilter() { + return new AclBindingFilter(pattern.toFilter(), entry.toFilter()); + } + + @Override + public String toString() { + return "(pattern=" + pattern + ", entry=" + entry + ")"; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + AclBinding that = (AclBinding) o; + return Objects.equals(pattern, that.pattern) && + Objects.equals(entry, that.entry); + } + + @Override + public int hashCode() { + return Objects.hash(pattern, entry); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/acl/AclBindingFilter.java b/clients/src/main/java/org/apache/kafka/common/acl/AclBindingFilter.java new file mode 100644 index 0000000..7682386 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/acl/AclBindingFilter.java @@ -0,0 +1,115 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.common.acl; + +import org.apache.kafka.common.annotation.InterfaceStability; +import org.apache.kafka.common.resource.ResourcePatternFilter; + +import java.util.Objects; + +/** + * A filter which can match AclBinding objects. + * + * The API for this class is still evolving and we may break compatibility in minor releases, if necessary. + */ +@InterfaceStability.Evolving +public class AclBindingFilter { + private final ResourcePatternFilter patternFilter; + private final AccessControlEntryFilter entryFilter; + + /** + * A filter which matches any ACL binding. + */ + public static final AclBindingFilter ANY = new AclBindingFilter(ResourcePatternFilter.ANY, AccessControlEntryFilter.ANY); + + /** + * Create an instance of this filter with the provided parameters. 
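+ * <p>
+ * Illustrative sketch only (not part of the original source): a filter that matches every
+ * ALLOW READ entry, regardless of resource, principal or host:
+ * <pre>{@code
+ * AclBindingFilter filter = new AclBindingFilter(
+ *     ResourcePatternFilter.ANY,
+ *     new AccessControlEntryFilter(null, null, AclOperation.READ, AclPermissionType.ALLOW));
+ * }</pre>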
+ * + * @param patternFilter non-null pattern filter + * @param entryFilter non-null access control entry filter + */ + public AclBindingFilter(ResourcePatternFilter patternFilter, AccessControlEntryFilter entryFilter) { + this.patternFilter = Objects.requireNonNull(patternFilter, "patternFilter"); + this.entryFilter = Objects.requireNonNull(entryFilter, "entryFilter"); + } + + /** + * @return {@code true} if this filter has any UNKNOWN components. + */ + public boolean isUnknown() { + return patternFilter.isUnknown() || entryFilter.isUnknown(); + } + + /** + * @return the resource pattern filter. + */ + public ResourcePatternFilter patternFilter() { + return patternFilter; + } + + /** + * @return the access control entry filter. + */ + public final AccessControlEntryFilter entryFilter() { + return entryFilter; + } + + @Override + public String toString() { + return "(patternFilter=" + patternFilter + ", entryFilter=" + entryFilter + ")"; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + AclBindingFilter that = (AclBindingFilter) o; + return Objects.equals(patternFilter, that.patternFilter) && + Objects.equals(entryFilter, that.entryFilter); + } + + /** + * Return true if the resource and entry filters can only match one ACE. In other words, if + * there are no ANY or UNKNOWN fields. + */ + public boolean matchesAtMostOne() { + return patternFilter.matchesAtMostOne() && entryFilter.matchesAtMostOne(); + } + + /** + * Return a string describing an ANY or UNKNOWN field, or null if there is no such field. + */ + public String findIndefiniteField() { + String indefinite = patternFilter.findIndefiniteField(); + if (indefinite != null) + return indefinite; + return entryFilter.findIndefiniteField(); + } + + /** + * Return true if the resource filter matches the binding's resource and the entry filter matches binding's entry. + */ + public boolean matches(AclBinding binding) { + return patternFilter.matches(binding.pattern()) && entryFilter.matches(binding.entry()); + } + + @Override + public int hashCode() { + return Objects.hash(patternFilter, entryFilter); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/acl/AclOperation.java b/clients/src/main/java/org/apache/kafka/common/acl/AclOperation.java new file mode 100644 index 0000000..6710697 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/acl/AclOperation.java @@ -0,0 +1,167 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.common.acl; + +import org.apache.kafka.common.annotation.InterfaceStability; + +import java.util.HashMap; +import java.util.Locale; + +/** + * Represents an operation which an ACL grants or denies permission to perform. 
+ *
+ * Some operations imply other operations:
+ * <ul>
+ * <li>ALLOW ALL implies ALLOW everything
+ * <li>DENY ALL implies DENY everything
+ *
+ * <li>ALLOW READ implies ALLOW DESCRIBE
+ * <li>ALLOW WRITE implies ALLOW DESCRIBE
+ * <li>ALLOW DELETE implies ALLOW DESCRIBE
+ *
+ * <li>ALLOW ALTER implies ALLOW DESCRIBE
+ *
+ * <li>ALLOW ALTER_CONFIGS implies ALLOW DESCRIBE_CONFIGS
+ * </ul>
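+ *
+ * Illustrative sketch only (not part of the original source), using the parsing helpers declared
+ * further down in this enum:
+ * <pre>{@code
+ * AclOperation byName = AclOperation.fromString("read");   // AclOperation.READ
+ * AclOperation byCode = AclOperation.fromCode((byte) 3);   // also AclOperation.READ
+ * AclOperation unknown = AclOperation.fromString("bogus"); // AclOperation.UNKNOWN
+ * }</pre>
+ *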
        + * The API for this class is still evolving and we may break compatibility in minor releases, if necessary. + */ +@InterfaceStability.Evolving +public enum AclOperation { + /** + * Represents any AclOperation which this client cannot understand, perhaps because this + * client is too old. + */ + UNKNOWN((byte) 0), + + /** + * In a filter, matches any AclOperation. + */ + ANY((byte) 1), + + /** + * ALL operation. + */ + ALL((byte) 2), + + /** + * READ operation. + */ + READ((byte) 3), + + /** + * WRITE operation. + */ + WRITE((byte) 4), + + /** + * CREATE operation. + */ + CREATE((byte) 5), + + /** + * DELETE operation. + */ + DELETE((byte) 6), + + /** + * ALTER operation. + */ + ALTER((byte) 7), + + /** + * DESCRIBE operation. + */ + DESCRIBE((byte) 8), + + /** + * CLUSTER_ACTION operation. + */ + CLUSTER_ACTION((byte) 9), + + /** + * DESCRIBE_CONFIGS operation. + */ + DESCRIBE_CONFIGS((byte) 10), + + /** + * ALTER_CONFIGS operation. + */ + ALTER_CONFIGS((byte) 11), + + /** + * IDEMPOTENT_WRITE operation. + */ + IDEMPOTENT_WRITE((byte) 12); + + // Note: we cannot have more than 30 ACL operations without modifying the format used + // to describe ACL operations in MetadataResponse. + + private final static HashMap CODE_TO_VALUE = new HashMap<>(); + + static { + for (AclOperation operation : AclOperation.values()) { + CODE_TO_VALUE.put(operation.code, operation); + } + } + + /** + * Parse the given string as an ACL operation. + * + * @param str The string to parse. + * + * @return The AclOperation, or UNKNOWN if the string could not be matched. + */ + public static AclOperation fromString(String str) throws IllegalArgumentException { + try { + return AclOperation.valueOf(str.toUpperCase(Locale.ROOT)); + } catch (IllegalArgumentException e) { + return UNKNOWN; + } + } + + /** + * Return the AclOperation with the provided code or `AclOperation.UNKNOWN` if one cannot be found. + */ + public static AclOperation fromCode(byte code) { + AclOperation operation = CODE_TO_VALUE.get(code); + if (operation == null) { + return UNKNOWN; + } + return operation; + } + + private final byte code; + + AclOperation(byte code) { + this.code = code; + } + + /** + * Return the code of this operation. + */ + public byte code() { + return code; + } + + /** + * Return true if this operation is UNKNOWN. + */ + public boolean isUnknown() { + return this == UNKNOWN; + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/acl/AclPermissionType.java b/clients/src/main/java/org/apache/kafka/common/acl/AclPermissionType.java new file mode 100644 index 0000000..c5b077c --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/acl/AclPermissionType.java @@ -0,0 +1,106 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.common.acl; + +import org.apache.kafka.common.annotation.InterfaceStability; + +import java.util.HashMap; +import java.util.Locale; + +/** + * Represents whether an ACL grants or denies permissions. + * + * The API for this class is still evolving and we may break compatibility in minor releases, if necessary. + */ +@InterfaceStability.Evolving +public enum AclPermissionType { + /** + * Represents any AclPermissionType which this client cannot understand, + * perhaps because this client is too old. + */ + UNKNOWN((byte) 0), + + /** + * In a filter, matches any AclPermissionType. + */ + ANY((byte) 1), + + /** + * Disallows access. + */ + DENY((byte) 2), + + /** + * Grants access. + */ + ALLOW((byte) 3); + + private final static HashMap CODE_TO_VALUE = new HashMap<>(); + + static { + for (AclPermissionType permissionType : AclPermissionType.values()) { + CODE_TO_VALUE.put(permissionType.code, permissionType); + } + } + + /** + * Parse the given string as an ACL permission. + * + * @param str The string to parse. + * + * @return The AclPermissionType, or UNKNOWN if the string could not be matched. + */ + public static AclPermissionType fromString(String str) { + try { + return AclPermissionType.valueOf(str.toUpperCase(Locale.ROOT)); + } catch (IllegalArgumentException e) { + return UNKNOWN; + } + } + + /** + * Return the AclPermissionType with the provided code or `AclPermissionType.UNKNOWN` if one cannot be found. + */ + public static AclPermissionType fromCode(byte code) { + AclPermissionType permissionType = CODE_TO_VALUE.get(code); + if (permissionType == null) { + return UNKNOWN; + } + return permissionType; + } + + private final byte code; + + AclPermissionType(byte code) { + this.code = code; + } + + /** + * Return the code of this permission type. + */ + public byte code() { + return code; + } + + /** + * Return true if this permission type is UNKNOWN. + */ + public boolean isUnknown() { + return this == UNKNOWN; + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/annotation/InterfaceStability.java b/clients/src/main/java/org/apache/kafka/common/annotation/InterfaceStability.java new file mode 100644 index 0000000..25e30ad --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/annotation/InterfaceStability.java @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.annotation; + +import java.lang.annotation.Documented; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; + +/** + * Annotation to inform users of how much to rely on a particular package, class or method not changing over time. + * Currently the stability can be {@link Stable}, {@link Evolving} or {@link Unstable}. 
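+ * <p>
+ * Illustrative sketch only (not part of the original source); {@code MyPublicApi} is a
+ * hypothetical class name used purely for the example:
+ * <pre>
+ * &#64;InterfaceStability.Evolving
+ * public class MyPublicApi { ... }
+ * </pre>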
+ */ +@InterfaceStability.Evolving +public class InterfaceStability { + /** + * Compatibility is maintained in major, minor and patch releases with one exception: compatibility may be broken + * in a major release (i.e. 0.m) for APIs that have been deprecated for at least one major/minor release cycle. + * In cases where the impact of breaking compatibility is significant, there is also a minimum deprecation period + * of one year. + * + * This is the default stability level for public APIs that are not annotated. + */ + @Documented + @Retention(RetentionPolicy.RUNTIME) + public @interface Stable { } + + /** + * Compatibility may be broken at minor release (i.e. m.x). + */ + @Documented + @Retention(RetentionPolicy.RUNTIME) + public @interface Evolving { } + + /** + * No guarantee is provided as to reliability or stability across any level of release granularity. + */ + @Documented + @Retention(RetentionPolicy.RUNTIME) + public @interface Unstable { } +} diff --git a/clients/src/main/java/org/apache/kafka/common/cache/Cache.java b/clients/src/main/java/org/apache/kafka/common/cache/Cache.java new file mode 100644 index 0000000..0da4907 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/cache/Cache.java @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.cache; + +/** + * Interface for caches, semi-persistent maps which store key-value mappings until either an eviction criteria is met + * or the entries are manually invalidated. Caches are not required to be thread-safe, but some implementations may be. + */ +public interface Cache { + + /** + * Look up a value in the cache. + * @param key the key to + * @return the cached value, or null if it is not present. + */ + V get(K key); + + /** + * Insert an entry into the cache. + * @param key the key to insert + * @param value the value to insert + */ + void put(K key, V value); + + /** + * Manually invalidate a key, clearing its entry from the cache. + * @param key the key to remove + * @return true if the key existed in the cache and the entry was removed or false if it was not present + */ + boolean remove(K key); + + /** + * Get the number of entries in this cache. If this cache is used by multiple threads concurrently, the returned + * value will only be approximate. 
+ * @return the number of entries in the cache + */ + long size(); +} diff --git a/clients/src/main/java/org/apache/kafka/common/cache/LRUCache.java b/clients/src/main/java/org/apache/kafka/common/cache/LRUCache.java new file mode 100644 index 0000000..672cb65 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/cache/LRUCache.java @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.cache; + +import java.util.LinkedHashMap; +import java.util.Map; + +/** + * A cache implementing a least recently used policy. + */ +public class LRUCache implements Cache { + private final LinkedHashMap cache; + + public LRUCache(final int maxSize) { + cache = new LinkedHashMap(16, .75f, true) { + @Override + protected boolean removeEldestEntry(Map.Entry eldest) { + return this.size() > maxSize; // require this. prefix to make lgtm.com happy + } + }; + } + + @Override + public V get(K key) { + return cache.get(key); + } + + @Override + public void put(K key, V value) { + cache.put(key, value); + } + + @Override + public boolean remove(K key) { + return cache.remove(key) != null; + } + + @Override + public long size() { + return cache.size(); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/cache/SynchronizedCache.java b/clients/src/main/java/org/apache/kafka/common/cache/SynchronizedCache.java new file mode 100644 index 0000000..27cc4ba --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/cache/SynchronizedCache.java @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.cache; + +/** + * Wrapper for caches that adds simple synchronization to provide a thread-safe cache. Note that this simply adds + * synchronization around each cache method on the underlying unsynchronized cache. It does not add any support for + * atomically checking for existence of an entry and computing and inserting the value if it is missing. 
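+ * <p>
+ * Illustrative sketch only (not part of the original source), wrapping the LRUCache declared in
+ * this package; the capacity of 1024 entries and the key used are arbitrary:
+ * <pre>{@code
+ * Cache cache = new SynchronizedCache(new LRUCache(1024));
+ * cache.put("key", "value");
+ * Object value = cache.get("key");
+ * }</pre>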
+ */ +public class SynchronizedCache implements Cache { + private final Cache underlying; + + public SynchronizedCache(Cache underlying) { + this.underlying = underlying; + } + + @Override + public synchronized V get(K key) { + return underlying.get(key); + } + + @Override + public synchronized void put(K key, V value) { + underlying.put(key, value); + } + + @Override + public synchronized boolean remove(K key) { + return underlying.remove(key); + } + + @Override + public synchronized long size() { + return underlying.size(); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/compress/KafkaLZ4BlockInputStream.java b/clients/src/main/java/org/apache/kafka/common/compress/KafkaLZ4BlockInputStream.java new file mode 100644 index 0000000..e2fbd5a --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/compress/KafkaLZ4BlockInputStream.java @@ -0,0 +1,318 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.compress; + +import net.jpountz.lz4.LZ4Compressor; +import net.jpountz.lz4.LZ4Exception; +import net.jpountz.lz4.LZ4Factory; +import net.jpountz.lz4.LZ4SafeDecompressor; +import net.jpountz.xxhash.XXHash32; +import net.jpountz.xxhash.XXHashFactory; + +import org.apache.kafka.common.compress.KafkaLZ4BlockOutputStream.BD; +import org.apache.kafka.common.compress.KafkaLZ4BlockOutputStream.FLG; +import org.apache.kafka.common.utils.BufferSupplier; + +import java.io.IOException; +import java.io.InputStream; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; + +import static org.apache.kafka.common.compress.KafkaLZ4BlockOutputStream.LZ4_FRAME_INCOMPRESSIBLE_MASK; +import static org.apache.kafka.common.compress.KafkaLZ4BlockOutputStream.MAGIC; + +/** + * A partial implementation of the v1.5.1 LZ4 Frame format. + * + * @see LZ4 Frame Format + * + * This class is not thread-safe. 
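+ * <p>
+ * Illustrative sketch only (not part of the original source): given a ByteBuffer holding an
+ * LZ4-framed payload, the stream is consumed like any other InputStream. The snippet assumes a
+ * {@code BufferSupplier.NO_CACHING} supplier from the imported BufferSupplier class; exception
+ * handling is elided.
+ * <pre>{@code
+ * ByteBuffer framed = ...; // an LZ4-framed payload
+ * try (InputStream in = new KafkaLZ4BlockInputStream(framed, BufferSupplier.NO_CACHING, false)) {
+ *     int b;
+ *     while ((b = in.read()) != -1) {
+ *         // consume decompressed bytes
+ *     }
+ * }
+ * }</pre>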
+ */ +public final class KafkaLZ4BlockInputStream extends InputStream { + + public static final String PREMATURE_EOS = "Stream ended prematurely"; + public static final String NOT_SUPPORTED = "Stream unsupported (invalid magic bytes)"; + public static final String BLOCK_HASH_MISMATCH = "Block checksum mismatch"; + public static final String DESCRIPTOR_HASH_MISMATCH = "Stream frame descriptor corrupted"; + + private static final LZ4SafeDecompressor DECOMPRESSOR = LZ4Factory.fastestInstance().safeDecompressor(); + private static final XXHash32 CHECKSUM = XXHashFactory.fastestInstance().hash32(); + + private static final RuntimeException BROKEN_LZ4_EXCEPTION; + // https://issues.apache.org/jira/browse/KAFKA-9203 + // detect buggy lz4 libraries on the classpath + static { + RuntimeException exception = null; + try { + detectBrokenLz4Version(); + } catch (RuntimeException e) { + exception = e; + } + BROKEN_LZ4_EXCEPTION = exception; + } + + private final ByteBuffer in; + private final boolean ignoreFlagDescriptorChecksum; + private final BufferSupplier bufferSupplier; + private final ByteBuffer decompressionBuffer; + // `flg` and `maxBlockSize` are effectively final, they are initialised in the `readHeader` method that is only + // invoked from the constructor + private FLG flg; + private int maxBlockSize; + + // If a block is compressed, this is the same as `decompressionBuffer`. If a block is not compressed, this is + // a slice of `in` to avoid unnecessary copies. + private ByteBuffer decompressedBuffer; + private boolean finished; + + /** + * Create a new {@link InputStream} that will decompress data using the LZ4 algorithm. + * + * @param in The byte buffer to decompress + * @param ignoreFlagDescriptorChecksum for compatibility with old kafka clients, ignore incorrect HC byte + * @throws IOException + */ + public KafkaLZ4BlockInputStream(ByteBuffer in, BufferSupplier bufferSupplier, boolean ignoreFlagDescriptorChecksum) throws IOException { + if (BROKEN_LZ4_EXCEPTION != null) { + throw BROKEN_LZ4_EXCEPTION; + } + this.ignoreFlagDescriptorChecksum = ignoreFlagDescriptorChecksum; + this.in = in.duplicate().order(ByteOrder.LITTLE_ENDIAN); + this.bufferSupplier = bufferSupplier; + readHeader(); + decompressionBuffer = bufferSupplier.get(maxBlockSize); + finished = false; + } + + /** + * Check whether KafkaLZ4BlockInputStream is configured to ignore the + * Frame Descriptor checksum, which is useful for compatibility with + * old client implementations that use incorrect checksum calculations. + */ + public boolean ignoreFlagDescriptorChecksum() { + return this.ignoreFlagDescriptorChecksum; + } + + /** + * Reads the magic number and frame descriptor from input buffer. 
+ * + * @throws IOException + */ + private void readHeader() throws IOException { + // read first 6 bytes into buffer to check magic and FLG/BD descriptor flags + if (in.remaining() < 6) { + throw new IOException(PREMATURE_EOS); + } + + if (MAGIC != in.getInt()) { + throw new IOException(NOT_SUPPORTED); + } + // mark start of data to checksum + in.mark(); + + flg = FLG.fromByte(in.get()); + maxBlockSize = BD.fromByte(in.get()).getBlockMaximumSize(); + + if (flg.isContentSizeSet()) { + if (in.remaining() < 8) { + throw new IOException(PREMATURE_EOS); + } + in.position(in.position() + 8); + } + + // Final byte of Frame Descriptor is HC checksum + + // Old implementations produced incorrect HC checksums + if (ignoreFlagDescriptorChecksum) { + in.position(in.position() + 1); + return; + } + + int len = in.position() - in.reset().position(); + + int hash = CHECKSUM.hash(in, in.position(), len, 0); + in.position(in.position() + len); + if (in.get() != (byte) ((hash >> 8) & 0xFF)) { + throw new IOException(DESCRIPTOR_HASH_MISMATCH); + } + } + + /** + * Decompresses (if necessary) buffered data, optionally computes and validates a XXHash32 checksum, and writes the + * result to a buffer. + * + * @throws IOException + */ + private void readBlock() throws IOException { + if (in.remaining() < 4) { + throw new IOException(PREMATURE_EOS); + } + + int blockSize = in.getInt(); + boolean compressed = (blockSize & LZ4_FRAME_INCOMPRESSIBLE_MASK) == 0; + blockSize &= ~LZ4_FRAME_INCOMPRESSIBLE_MASK; + + // Check for EndMark + if (blockSize == 0) { + finished = true; + if (flg.isContentChecksumSet()) + in.getInt(); // TODO: verify this content checksum + return; + } else if (blockSize > maxBlockSize) { + throw new IOException(String.format("Block size %s exceeded max: %s", blockSize, maxBlockSize)); + } + + if (in.remaining() < blockSize) { + throw new IOException(PREMATURE_EOS); + } + + if (compressed) { + try { + final int bufferSize = DECOMPRESSOR.decompress(in, in.position(), blockSize, decompressionBuffer, 0, + maxBlockSize); + decompressionBuffer.position(0); + decompressionBuffer.limit(bufferSize); + decompressedBuffer = decompressionBuffer; + } catch (LZ4Exception e) { + throw new IOException(e); + } + } else { + decompressedBuffer = in.slice(); + decompressedBuffer.limit(blockSize); + } + + // verify checksum + if (flg.isBlockChecksumSet()) { + int hash = CHECKSUM.hash(in, in.position(), blockSize, 0); + in.position(in.position() + blockSize); + if (hash != in.getInt()) { + throw new IOException(BLOCK_HASH_MISMATCH); + } + } else { + in.position(in.position() + blockSize); + } + } + + @Override + public int read() throws IOException { + if (finished) { + return -1; + } + if (available() == 0) { + readBlock(); + } + if (finished) { + return -1; + } + + return decompressedBuffer.get() & 0xFF; + } + + @Override + public int read(byte[] b, int off, int len) throws IOException { + net.jpountz.util.SafeUtils.checkRange(b, off, len); + if (finished) { + return -1; + } + if (available() == 0) { + readBlock(); + } + if (finished) { + return -1; + } + len = Math.min(len, available()); + + decompressedBuffer.get(b, off, len); + return len; + } + + @Override + public long skip(long n) throws IOException { + if (finished) { + return 0; + } + if (available() == 0) { + readBlock(); + } + if (finished) { + return 0; + } + int skipped = (int) Math.min(n, available()); + decompressedBuffer.position(decompressedBuffer.position() + skipped); + return skipped; + } + + @Override + public int available() { + return 
decompressedBuffer == null ? 0 : decompressedBuffer.remaining(); + } + + @Override + public void close() { + bufferSupplier.release(decompressionBuffer); + } + + @Override + public void mark(int readlimit) { + throw new RuntimeException("mark not supported"); + } + + @Override + public void reset() { + throw new RuntimeException("reset not supported"); + } + + @Override + public boolean markSupported() { + return false; + } + + /** + * Checks whether the version of lz4 on the classpath has the fix for reading from ByteBuffers with + * non-zero array offsets (see https://github.com/lz4/lz4-java/pull/65) + */ + static void detectBrokenLz4Version() { + byte[] source = new byte[]{1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3}; + final LZ4Compressor compressor = LZ4Factory.fastestInstance().fastCompressor(); + + final byte[] compressed = new byte[compressor.maxCompressedLength(source.length)]; + final int compressedLength = compressor.compress(source, 0, source.length, compressed, 0, + compressed.length); + + // allocate an array-backed ByteBuffer with non-zero array-offset containing the compressed data + // a buggy decompressor will read the data from the beginning of the underlying array instead of + // the beginning of the ByteBuffer, failing to decompress the invalid data. + final byte[] zeroes = {0, 0, 0, 0, 0}; + ByteBuffer nonZeroOffsetBuffer = ByteBuffer + .allocate(zeroes.length + compressed.length) // allocates the backing array with extra space to offset the data + .put(zeroes) // prepend invalid bytes (zeros) before the compressed data in the array + .slice() // create a new ByteBuffer sharing the underlying array, offset to start on the compressed data + .put(compressed); // write the compressed data at the beginning of this new buffer + + ByteBuffer dest = ByteBuffer.allocate(source.length); + try { + DECOMPRESSOR.decompress(nonZeroOffsetBuffer, 0, compressedLength, dest, 0, source.length); + } catch (Exception e) { + throw new RuntimeException("Kafka has detected detected a buggy lz4-java library (< 1.4.x) on the classpath." + + " If you are using Kafka client libraries, make sure your application does not" + + " accidentally override the version provided by Kafka or include multiple versions" + + " of the library on the classpath. The lz4-java version on the classpath should" + + " match the version the Kafka client libraries depend on. Adding -verbose:class" + + " to your JVM arguments may help understand which lz4-java version is getting loaded.", e); + } + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/compress/KafkaLZ4BlockOutputStream.java b/clients/src/main/java/org/apache/kafka/common/compress/KafkaLZ4BlockOutputStream.java new file mode 100644 index 0000000..5c5aee4 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/compress/KafkaLZ4BlockOutputStream.java @@ -0,0 +1,423 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.compress; + +import java.io.IOException; +import java.io.OutputStream; + +import org.apache.kafka.common.utils.ByteUtils; + +import net.jpountz.lz4.LZ4Compressor; +import net.jpountz.lz4.LZ4Factory; +import net.jpountz.xxhash.XXHash32; +import net.jpountz.xxhash.XXHashFactory; + +/** + * A partial implementation of the v1.5.1 LZ4 Frame format. + * + * @see LZ4 Frame Format + * + * This class is not thread-safe. + */ +public final class KafkaLZ4BlockOutputStream extends OutputStream { + + public static final int MAGIC = 0x184D2204; + public static final int LZ4_MAX_HEADER_LENGTH = 19; + public static final int LZ4_FRAME_INCOMPRESSIBLE_MASK = 0x80000000; + + public static final String CLOSED_STREAM = "The stream is already closed"; + + public static final int BLOCKSIZE_64KB = 4; + public static final int BLOCKSIZE_256KB = 5; + public static final int BLOCKSIZE_1MB = 6; + public static final int BLOCKSIZE_4MB = 7; + + private final LZ4Compressor compressor; + private final XXHash32 checksum; + private final boolean useBrokenFlagDescriptorChecksum; + private final FLG flg; + private final BD bd; + private final int maxBlockSize; + private OutputStream out; + private byte[] buffer; + private byte[] compressedBuffer; + private int bufferOffset; + private boolean finished; + + /** + * Create a new {@link OutputStream} that will compress data using the LZ4 algorithm. + * + * @param out The output stream to compress + * @param blockSize Default: 4. The block size used during compression. 4=64kb, 5=256kb, 6=1mb, 7=4mb. All other + * values will generate an exception + * @param blockChecksum Default: false. When true, a XXHash32 checksum is computed and appended to the stream for + * every block of data + * @param useBrokenFlagDescriptorChecksum Default: false. When true, writes an incorrect FrameDescriptor checksum + * compatible with older kafka clients. + * @throws IOException + */ + public KafkaLZ4BlockOutputStream(OutputStream out, int blockSize, boolean blockChecksum, boolean useBrokenFlagDescriptorChecksum) throws IOException { + this.out = out; + compressor = LZ4Factory.fastestInstance().fastCompressor(); + checksum = XXHashFactory.fastestInstance().hash32(); + this.useBrokenFlagDescriptorChecksum = useBrokenFlagDescriptorChecksum; + bd = new BD(blockSize); + flg = new FLG(blockChecksum); + bufferOffset = 0; + maxBlockSize = bd.getBlockMaximumSize(); + buffer = new byte[maxBlockSize]; + compressedBuffer = new byte[compressor.maxCompressedLength(maxBlockSize)]; + finished = false; + writeHeader(); + } + + /** + * Create a new {@link OutputStream} that will compress data using the LZ4 algorithm. + * + * @param out The output stream to compress + * @param blockSize Default: 4. The block size used during compression. 4=64kb, 5=256kb, 6=1mb, 7=4mb. All other + * values will generate an exception + * @param blockChecksum Default: false. 
When true, a XXHash32 checksum is computed and appended to the stream for + * every block of data + * @throws IOException + */ + public KafkaLZ4BlockOutputStream(OutputStream out, int blockSize, boolean blockChecksum) throws IOException { + this(out, blockSize, blockChecksum, false); + } + + /** + * Create a new {@link OutputStream} that will compress data using the LZ4 algorithm. + * + * @param out The stream to compress + * @param blockSize Default: 4. The block size used during compression. 4=64kb, 5=256kb, 6=1mb, 7=4mb. All other + * values will generate an exception + * @throws IOException + */ + public KafkaLZ4BlockOutputStream(OutputStream out, int blockSize) throws IOException { + this(out, blockSize, false, false); + } + + /** + * Create a new {@link OutputStream} that will compress data using the LZ4 algorithm. + * + * @param out The output stream to compress + * @throws IOException + */ + public KafkaLZ4BlockOutputStream(OutputStream out) throws IOException { + this(out, BLOCKSIZE_64KB); + } + + public KafkaLZ4BlockOutputStream(OutputStream out, boolean useBrokenHC) throws IOException { + this(out, BLOCKSIZE_64KB, false, useBrokenHC); + } + + /** + * Check whether KafkaLZ4BlockInputStream is configured to write an + * incorrect Frame Descriptor checksum, which is useful for + * compatibility with old client implementations. + */ + public boolean useBrokenFlagDescriptorChecksum() { + return this.useBrokenFlagDescriptorChecksum; + } + + /** + * Writes the magic number and frame descriptor to the underlying {@link OutputStream}. + * + * @throws IOException + */ + private void writeHeader() throws IOException { + ByteUtils.writeUnsignedIntLE(buffer, 0, MAGIC); + bufferOffset = 4; + buffer[bufferOffset++] = flg.toByte(); + buffer[bufferOffset++] = bd.toByte(); + // TODO write uncompressed content size, update flg.validate() + + // compute checksum on all descriptor fields + int offset = 4; + int len = bufferOffset - offset; + if (this.useBrokenFlagDescriptorChecksum) { + len += offset; + offset = 0; + } + byte hash = (byte) ((checksum.hash(buffer, offset, len, 0) >> 8) & 0xFF); + buffer[bufferOffset++] = hash; + + // write out frame descriptor + out.write(buffer, 0, bufferOffset); + bufferOffset = 0; + } + + /** + * Compresses buffered data, optionally computes an XXHash32 checksum, and writes the result to the underlying + * {@link OutputStream}. + * + * @throws IOException + */ + private void writeBlock() throws IOException { + if (bufferOffset == 0) { + return; + } + + int compressedLength = compressor.compress(buffer, 0, bufferOffset, compressedBuffer, 0); + byte[] bufferToWrite = compressedBuffer; + int compressMethod = 0; + + // Store block uncompressed if compressed length is greater (incompressible) + if (compressedLength >= bufferOffset) { + bufferToWrite = buffer; + compressedLength = bufferOffset; + compressMethod = LZ4_FRAME_INCOMPRESSIBLE_MASK; + } + + // Write content + ByteUtils.writeUnsignedIntLE(out, compressedLength | compressMethod); + out.write(bufferToWrite, 0, compressedLength); + + // Calculate and write block checksum + if (flg.isBlockChecksumSet()) { + int hash = checksum.hash(bufferToWrite, 0, compressedLength, 0); + ByteUtils.writeUnsignedIntLE(out, hash); + } + bufferOffset = 0; + } + + /** + * Similar to the {@link #writeBlock()} method. Writes a 0-length block (without block checksum) to signal the end + * of the block stream. 
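+ * The end mark is written as a 4-byte little-endian zero; a trailing content checksum is not yet
+ * emitted (see the TODO in the method body).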
+ * + * @throws IOException + */ + private void writeEndMark() throws IOException { + ByteUtils.writeUnsignedIntLE(out, 0); + // TODO implement content checksum, update flg.validate() + } + + @Override + public void write(int b) throws IOException { + ensureNotFinished(); + if (bufferOffset == maxBlockSize) { + writeBlock(); + } + buffer[bufferOffset++] = (byte) b; + } + + @Override + public void write(byte[] b, int off, int len) throws IOException { + net.jpountz.util.SafeUtils.checkRange(b, off, len); + ensureNotFinished(); + + int bufferRemainingLength = maxBlockSize - bufferOffset; + // while b will fill the buffer + while (len > bufferRemainingLength) { + // fill remaining space in buffer + System.arraycopy(b, off, buffer, bufferOffset, bufferRemainingLength); + bufferOffset = maxBlockSize; + writeBlock(); + // compute new offset and length + off += bufferRemainingLength; + len -= bufferRemainingLength; + bufferRemainingLength = maxBlockSize; + } + + System.arraycopy(b, off, buffer, bufferOffset, len); + bufferOffset += len; + } + + @Override + public void flush() throws IOException { + if (!finished) { + writeBlock(); + } + if (out != null) { + out.flush(); + } + } + + /** + * A simple state check to ensure the stream is still open. + */ + private void ensureNotFinished() { + if (finished) { + throw new IllegalStateException(CLOSED_STREAM); + } + } + + @Override + public void close() throws IOException { + try { + if (!finished) { + // basically flush the buffer writing the last block + writeBlock(); + // write the end block + writeEndMark(); + } + } finally { + try { + if (out != null) { + try (OutputStream outStream = out) { + outStream.flush(); + } + } + } finally { + out = null; + buffer = null; + compressedBuffer = null; + finished = true; + } + } + } + + public static class FLG { + + private static final int VERSION = 1; + + private final int reserved; + private final int contentChecksum; + private final int contentSize; + private final int blockChecksum; + private final int blockIndependence; + private final int version; + + public FLG() { + this(false); + } + + public FLG(boolean blockChecksum) { + this(0, 0, 0, blockChecksum ? 
1 : 0, 1, VERSION); + } + + private FLG(int reserved, + int contentChecksum, + int contentSize, + int blockChecksum, + int blockIndependence, + int version) { + this.reserved = reserved; + this.contentChecksum = contentChecksum; + this.contentSize = contentSize; + this.blockChecksum = blockChecksum; + this.blockIndependence = blockIndependence; + this.version = version; + validate(); + } + + public static FLG fromByte(byte flg) { + int reserved = (flg >>> 0) & 3; + int contentChecksum = (flg >>> 2) & 1; + int contentSize = (flg >>> 3) & 1; + int blockChecksum = (flg >>> 4) & 1; + int blockIndependence = (flg >>> 5) & 1; + int version = (flg >>> 6) & 3; + + return new FLG(reserved, + contentChecksum, + contentSize, + blockChecksum, + blockIndependence, + version); + } + + public byte toByte() { + return (byte) (((reserved & 3) << 0) | ((contentChecksum & 1) << 2) + | ((contentSize & 1) << 3) | ((blockChecksum & 1) << 4) | ((blockIndependence & 1) << 5) | ((version & 3) << 6)); + } + + private void validate() { + if (reserved != 0) { + throw new RuntimeException("Reserved bits must be 0"); + } + if (blockIndependence != 1) { + throw new RuntimeException("Dependent block stream is unsupported"); + } + if (version != VERSION) { + throw new RuntimeException(String.format("Version %d is unsupported", version)); + } + } + + public boolean isContentChecksumSet() { + return contentChecksum == 1; + } + + public boolean isContentSizeSet() { + return contentSize == 1; + } + + public boolean isBlockChecksumSet() { + return blockChecksum == 1; + } + + public boolean isBlockIndependenceSet() { + return blockIndependence == 1; + } + + public int getVersion() { + return version; + } + } + + public static class BD { + + private final int reserved2; + private final int blockSizeValue; + private final int reserved3; + + public BD() { + this(0, BLOCKSIZE_64KB, 0); + } + + public BD(int blockSizeValue) { + this(0, blockSizeValue, 0); + } + + private BD(int reserved2, int blockSizeValue, int reserved3) { + this.reserved2 = reserved2; + this.blockSizeValue = blockSizeValue; + this.reserved3 = reserved3; + validate(); + } + + public static BD fromByte(byte bd) { + int reserved2 = (bd >>> 0) & 15; + int blockMaximumSize = (bd >>> 4) & 7; + int reserved3 = (bd >>> 7) & 1; + + return new BD(reserved2, blockMaximumSize, reserved3); + } + + private void validate() { + if (reserved2 != 0) { + throw new RuntimeException("Reserved2 field must be 0"); + } + if (blockSizeValue < 4 || blockSizeValue > 7) { + throw new RuntimeException("Block size value must be between 4 and 7"); + } + if (reserved3 != 0) { + throw new RuntimeException("Reserved3 field must be 0"); + } + } + + // 2^(2n+8) + public int getBlockMaximumSize() { + return 1 << ((2 * blockSizeValue) + 8); + } + + public byte toByte() { + return (byte) (((reserved2 & 15) << 0) | ((blockSizeValue & 7) << 4) | ((reserved3 & 1) << 7)); + } + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/compress/SnappyFactory.java b/clients/src/main/java/org/apache/kafka/common/compress/SnappyFactory.java new file mode 100644 index 0000000..b56273d --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/compress/SnappyFactory.java @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.common.compress; + +import org.apache.kafka.common.KafkaException; +import org.apache.kafka.common.utils.ByteBufferInputStream; +import org.apache.kafka.common.utils.ByteBufferOutputStream; +import org.xerial.snappy.SnappyInputStream; +import org.xerial.snappy.SnappyOutputStream; + +import java.io.InputStream; +import java.io.OutputStream; +import java.nio.ByteBuffer; + +public class SnappyFactory { + + private SnappyFactory() { } + + public static OutputStream wrapForOutput(ByteBufferOutputStream buffer) { + try { + return new SnappyOutputStream(buffer); + } catch (Throwable e) { + throw new KafkaException(e); + } + } + + public static InputStream wrapForInput(ByteBuffer buffer) { + try { + return new SnappyInputStream(new ByteBufferInputStream(buffer)); + } catch (Throwable e) { + throw new KafkaException(e); + } + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/compress/ZstdFactory.java b/clients/src/main/java/org/apache/kafka/common/compress/ZstdFactory.java new file mode 100644 index 0000000..4664f4e --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/compress/ZstdFactory.java @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.common.compress; + +import com.github.luben.zstd.BufferPool; +import com.github.luben.zstd.RecyclingBufferPool; +import com.github.luben.zstd.ZstdInputStreamNoFinalizer; +import com.github.luben.zstd.ZstdOutputStreamNoFinalizer; +import org.apache.kafka.common.KafkaException; +import org.apache.kafka.common.utils.BufferSupplier; +import org.apache.kafka.common.utils.ByteBufferInputStream; +import org.apache.kafka.common.utils.ByteBufferOutputStream; + +import java.io.BufferedInputStream; +import java.io.BufferedOutputStream; +import java.io.InputStream; +import java.io.OutputStream; +import java.nio.ByteBuffer; + +public class ZstdFactory { + + private ZstdFactory() { } + + public static OutputStream wrapForOutput(ByteBufferOutputStream buffer) { + try { + // Set input buffer (uncompressed) to 16 KB (none by default) to ensure reasonable performance + // in cases where the caller passes a small number of bytes to write (potentially a single byte). + return new BufferedOutputStream(new ZstdOutputStreamNoFinalizer(buffer, RecyclingBufferPool.INSTANCE), 16 * 1024); + } catch (Throwable e) { + throw new KafkaException(e); + } + } + + public static InputStream wrapForInput(ByteBuffer buffer, byte messageVersion, BufferSupplier decompressionBufferSupplier) { + try { + // We use our own BufferSupplier instead of com.github.luben.zstd.RecyclingBufferPool since our + // implementation doesn't require locking or soft references. + BufferPool bufferPool = new BufferPool() { + @Override + public ByteBuffer get(int capacity) { + return decompressionBufferSupplier.get(capacity); + } + + @Override + public void release(ByteBuffer buffer) { + decompressionBufferSupplier.release(buffer); + } + }; + + // Set output buffer (uncompressed) to 16 KB (none by default) to ensure reasonable performance + // in cases where the caller reads a small number of bytes (potentially a single byte). + return new BufferedInputStream(new ZstdInputStreamNoFinalizer(new ByteBufferInputStream(buffer), + bufferPool), 16 * 1024); + } catch (Throwable e) { + throw new KafkaException(e); + } + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/config/AbstractConfig.java b/clients/src/main/java/org/apache/kafka/common/config/AbstractConfig.java new file mode 100644 index 0000000..7ef4609 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/config/AbstractConfig.java @@ -0,0 +1,678 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.config; + +import org.apache.kafka.common.Configurable; +import org.apache.kafka.common.KafkaException; +import org.apache.kafka.common.config.types.Password; +import org.apache.kafka.common.utils.Utils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.apache.kafka.common.config.provider.ConfigProvider; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.TreeMap; +import java.util.concurrent.ConcurrentHashMap; + +/** + * A convenient base class for configurations to extend. + *

        + * This class holds both the original configuration that was provided as well as the parsed + */ +public class AbstractConfig { + + private final Logger log = LoggerFactory.getLogger(getClass()); + + /** + * Configs for which values have been requested, used to detect unused configs. + * This set must be concurrent modifiable and iterable. It will be modified + * when directly accessed or as a result of RecordingMap access. + */ + private final Set used = ConcurrentHashMap.newKeySet(); + + /* the original values passed in by the user */ + private final Map originals; + + /* the parsed values */ + private final Map values; + + private final ConfigDef definition; + + public static final String CONFIG_PROVIDERS_CONFIG = "config.providers"; + + private static final String CONFIG_PROVIDERS_PARAM = ".param."; + + /** + * Construct a configuration with a ConfigDef and the configuration properties, which can include properties + * for zero or more {@link ConfigProvider} that will be used to resolve variables in configuration property + * values. + * + * The originals is a name-value pair configuration properties and optional config provider configs. The + * value of the configuration can be a variable as defined below or the actual value. This constructor will + * first instantiate the ConfigProviders using the config provider configs, then it will find all the + * variables in the values of the originals configurations, attempt to resolve the variables using the named + * ConfigProviders, and then parse and validate the configurations. + * + * ConfigProvider configs can be passed either as configs in the originals map or in the separate + * configProviderProps map. If config providers properties are passed in the configProviderProps any config + * provider properties in originals map will be ignored. If ConfigProvider properties are not provided, the + * constructor will skip the variable substitution step and will simply validate and parse the supplied + * configuration. + * + * The "{@code config.providers}" configuration property and all configuration properties that begin with the + * "{@code config.providers.}" prefix are reserved. The "{@code config.providers}" configuration property + * specifies the names of the config providers, and properties that begin with the "{@code config.providers..}" + * prefix correspond to the properties for that named provider. For example, the "{@code config.providers..class}" + * property specifies the name of the {@link ConfigProvider} implementation class that should be used for + * the provider. + * + * The keys for ConfigProvider configs in both originals and configProviderProps will start with the above + * mentioned "{@code config.providers.}" prefix. + * + * Variables have the form "${providerName:[path:]key}", where "providerName" is the name of a ConfigProvider, + * "path" is an optional string, and "key" is a required string. This variable is resolved by passing the "key" + * and optional "path" to a ConfigProvider with the specified name, and the result from the ConfigProvider is + * then used in place of the variable. Variables that cannot be resolved by the AbstractConfig constructor will + * be left unchanged in the configuration. 
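+ *
+ * For illustration only (a hypothetical sketch, not prescribed by this class), originals that reference
+ * a file-based provider could look like:
+ * <pre>
+ * Map<String, Object> originals = new HashMap<>();
+ * originals.put("config.providers", "file");
+ * originals.put("config.providers.file.class", "org.apache.kafka.common.config.provider.FileConfigProvider");
+ * originals.put("ssl.keystore.password", "${file:/path/to/secrets.properties:keystore.password}");
+ * </pre>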
+ * + * + * @param definition the definition of the configurations; may not be null + * @param originals the configuration properties plus any optional config provider properties; + * @param configProviderProps the map of properties of config providers which will be instantiated by + * the constructor to resolve any variables in {@code originals}; may be null or empty + * @param doLog whether the configurations should be logged + */ + @SuppressWarnings("unchecked") + public AbstractConfig(ConfigDef definition, Map originals, Map configProviderProps, boolean doLog) { + /* check that all the keys are really strings */ + for (Map.Entry entry : originals.entrySet()) + if (!(entry.getKey() instanceof String)) + throw new ConfigException(entry.getKey().toString(), entry.getValue(), "Key must be a string."); + + this.originals = resolveConfigVariables(configProviderProps, (Map) originals); + this.values = definition.parse(this.originals); + Map configUpdates = postProcessParsedConfig(Collections.unmodifiableMap(this.values)); + for (Map.Entry update : configUpdates.entrySet()) { + this.values.put(update.getKey(), update.getValue()); + } + definition.parse(this.values); + this.definition = definition; + if (doLog) + logAll(); + } + + /** + * Construct a configuration with a ConfigDef and the configuration properties, + * which can include properties for zero or more {@link ConfigProvider} + * that will be used to resolve variables in configuration property values. + * + * @param definition the definition of the configurations; may not be null + * @param originals the configuration properties plus any optional config provider properties; may not be null + */ + public AbstractConfig(ConfigDef definition, Map originals) { + this(definition, originals, Collections.emptyMap(), true); + } + + /** + * Construct a configuration with a ConfigDef and the configuration properties, + * which can include properties for zero or more {@link ConfigProvider} + * that will be used to resolve variables in configuration property values. + * + * @param definition the definition of the configurations; may not be null + * @param originals the configuration properties plus any optional config provider properties; may not be null + * @param doLog whether the configurations should be logged + */ + public AbstractConfig(ConfigDef definition, Map originals, boolean doLog) { + this(definition, originals, Collections.emptyMap(), doLog); + + } + + /** + * Called directly after user configs got parsed (and thus default values got set). + * This allows to change default values for "secondary defaults" if required. 
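+ * For example (a hypothetical illustration), a subclass could fill in a default {@code client.id}
+ * derived from other supplied configs when the user has not set one explicitly.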
+ * + * @param parsedValues unmodifiable map of current configuration + * @return a map of updates that should be applied to the configuration (will be validated to prevent bad updates) + */ + protected Map postProcessParsedConfig(Map parsedValues) { + return Collections.emptyMap(); + } + + protected Object get(String key) { + if (!values.containsKey(key)) + throw new ConfigException(String.format("Unknown configuration '%s'", key)); + used.add(key); + return values.get(key); + } + + public void ignore(String key) { + used.add(key); + } + + public Short getShort(String key) { + return (Short) get(key); + } + + public Integer getInt(String key) { + return (Integer) get(key); + } + + public Long getLong(String key) { + return (Long) get(key); + } + + public Double getDouble(String key) { + return (Double) get(key); + } + + @SuppressWarnings("unchecked") + public List getList(String key) { + return (List) get(key); + } + + public Boolean getBoolean(String key) { + return (Boolean) get(key); + } + + public String getString(String key) { + return (String) get(key); + } + + public ConfigDef.Type typeOf(String key) { + ConfigDef.ConfigKey configKey = definition.configKeys().get(key); + if (configKey == null) + return null; + return configKey.type; + } + + public String documentationOf(String key) { + ConfigDef.ConfigKey configKey = definition.configKeys().get(key); + if (configKey == null) + return null; + return configKey.documentation; + } + + public Password getPassword(String key) { + return (Password) get(key); + } + + public Class getClass(String key) { + return (Class) get(key); + } + + public Set unused() { + Set keys = new HashSet<>(originals.keySet()); + keys.removeAll(used); + return keys; + } + + public Map originals() { + Map copy = new RecordingMap<>(); + copy.putAll(originals); + return copy; + } + + public Map originals(Map configOverrides) { + Map copy = new RecordingMap<>(); + copy.putAll(originals); + copy.putAll(configOverrides); + return copy; + } + + /** + * Get all the original settings, ensuring that all values are of type String. + * @return the original settings + * @throws ClassCastException if any of the values are not strings + */ + public Map originalsStrings() { + Map copy = new RecordingMap<>(); + for (Map.Entry entry : originals.entrySet()) { + if (!(entry.getValue() instanceof String)) + throw new ClassCastException("Non-string value found in original settings for key " + entry.getKey() + + ": " + (entry.getValue() == null ? null : entry.getValue().getClass().getName())); + copy.put(entry.getKey(), (String) entry.getValue()); + } + return copy; + } + + /** + * Gets all original settings with the given prefix, stripping the prefix before adding it to the output. + * + * @param prefix the prefix to use as a filter + * @return a Map containing the settings with the prefix + */ + public Map originalsWithPrefix(String prefix) { + return originalsWithPrefix(prefix, true); + } + + /** + * Gets all original settings with the given prefix. 
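+ * For example (illustrative values), given originals {@code {"producer.acks" -> "all", "retries" -> "3"}},
+ * calling this with prefix {@code "producer."} and {@code strip=true} yields {@code {"acks" -> "all"}}.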
+ * + * @param prefix the prefix to use as a filter + * @param strip strip the prefix before adding to the output if set true + * @return a Map containing the settings with the prefix + */ + public Map originalsWithPrefix(String prefix, boolean strip) { + Map result = new RecordingMap<>(prefix, false); + for (Map.Entry entry : originals.entrySet()) { + if (entry.getKey().startsWith(prefix) && entry.getKey().length() > prefix.length()) { + if (strip) + result.put(entry.getKey().substring(prefix.length()), entry.getValue()); + else + result.put(entry.getKey(), entry.getValue()); + } + } + return result; + } + + /** + * Put all keys that do not start with {@code prefix} and their parsed values in the result map and then + * put all the remaining keys with the prefix stripped and their parsed values in the result map. + * + * This is useful if one wants to allow prefixed configs to override default ones. + *

+ * Two forms of prefixes are supported:
+ * <ul>
+ * <li>listener.name.{listenerName}.some.prop: If the provided prefix is `listener.name.{listenerName}.`,
+ * the key `some.prop` with the value parsed using the definition of `some.prop` is returned.</li>
+ * <li>listener.name.{listenerName}.{mechanism}.some.prop: If the provided prefix is `listener.name.{listenerName}.`,
+ * the key `{mechanism}.some.prop` with the value parsed using the definition of `some.prop` is returned.
+ * This is used to provide per-mechanism configs for a broker listener (e.g sasl.jaas.config)</li>
+ * </ul>
        + */ + public Map valuesWithPrefixOverride(String prefix) { + Map result = new RecordingMap<>(values(), prefix, true); + for (Map.Entry entry : originals.entrySet()) { + if (entry.getKey().startsWith(prefix) && entry.getKey().length() > prefix.length()) { + String keyWithNoPrefix = entry.getKey().substring(prefix.length()); + ConfigDef.ConfigKey configKey = definition.configKeys().get(keyWithNoPrefix); + if (configKey != null) + result.put(keyWithNoPrefix, definition.parseValue(configKey, entry.getValue(), true)); + else { + String keyWithNoSecondaryPrefix = keyWithNoPrefix.substring(keyWithNoPrefix.indexOf('.') + 1); + configKey = definition.configKeys().get(keyWithNoSecondaryPrefix); + if (configKey != null) + result.put(keyWithNoPrefix, definition.parseValue(configKey, entry.getValue(), true)); + } + } + } + return result; + } + + /** + * If at least one key with {@code prefix} exists, all prefixed values will be parsed and put into map. + * If no value with {@code prefix} exists all unprefixed values will be returned. + * + * This is useful if one wants to allow prefixed configs to override default ones, but wants to use either + * only prefixed configs or only regular configs, but not mix them. + */ + public Map valuesWithPrefixAllOrNothing(String prefix) { + Map withPrefix = originalsWithPrefix(prefix, true); + + if (withPrefix.isEmpty()) { + return new RecordingMap<>(values(), "", true); + } else { + Map result = new RecordingMap<>(prefix, true); + + for (Map.Entry entry : withPrefix.entrySet()) { + ConfigDef.ConfigKey configKey = definition.configKeys().get(entry.getKey()); + if (configKey != null) + result.put(entry.getKey(), definition.parseValue(configKey, entry.getValue(), true)); + } + + return result; + } + } + + public Map values() { + return new RecordingMap<>(values); + } + + public Map nonInternalValues() { + Map nonInternalConfigs = new RecordingMap<>(); + values.forEach((key, value) -> { + ConfigDef.ConfigKey configKey = definition.configKeys().get(key); + if (configKey == null || !configKey.internalConfig) { + nonInternalConfigs.put(key, value); + } + }); + return nonInternalConfigs; + } + + private void logAll() { + StringBuilder b = new StringBuilder(); + b.append(getClass().getSimpleName()); + b.append(" values: "); + b.append(Utils.NL); + + for (Map.Entry entry : new TreeMap<>(this.values).entrySet()) { + b.append('\t'); + b.append(entry.getKey()); + b.append(" = "); + b.append(entry.getValue()); + b.append(Utils.NL); + } + log.info(b.toString()); + } + + /** + * Log warnings for any unused configurations + */ + public void logUnused() { + for (String key : unused()) + log.warn("The configuration '{}' was supplied but isn't a known config.", key); + } + + private T getConfiguredInstance(Object klass, Class t, Map configPairs) { + if (klass == null) + return null; + + Object o; + if (klass instanceof String) { + try { + o = Utils.newInstance((String) klass, t); + } catch (ClassNotFoundException e) { + throw new KafkaException("Class " + klass + " cannot be found", e); + } + } else if (klass instanceof Class) { + o = Utils.newInstance((Class) klass); + } else + throw new KafkaException("Unexpected element of type " + klass.getClass().getName() + ", expected String or Class"); + if (!t.isInstance(o)) + throw new KafkaException(klass + " is not an instance of " + t.getName()); + if (o instanceof Configurable) + ((Configurable) o).configure(configPairs); + + return t.cast(o); + } + + /** + * Get a configured instance of the give class specified by the given 
configuration key. If the object implements + * Configurable configure it using the configuration. + * + * @param key The configuration key for the class + * @param t The interface the class should implement + * @return A configured instance of the class + */ + public T getConfiguredInstance(String key, Class t) { + return getConfiguredInstance(key, t, Collections.emptyMap()); + } + + /** + * Get a configured instance of the give class specified by the given configuration key. If the object implements + * Configurable configure it using the configuration. + * + * @param key The configuration key for the class + * @param t The interface the class should implement + * @param configOverrides override origin configs + * @return A configured instance of the class + */ + public T getConfiguredInstance(String key, Class t, Map configOverrides) { + Class c = getClass(key); + + return getConfiguredInstance(c, t, originals(configOverrides)); + } + + /** + * Get a list of configured instances of the given class specified by the given configuration key. The configuration + * may specify either null or an empty string to indicate no configured instances. In both cases, this method + * returns an empty list to indicate no configured instances. + * @param key The configuration key for the class + * @param t The interface the class should implement + * @return The list of configured instances + */ + public List getConfiguredInstances(String key, Class t) { + return getConfiguredInstances(key, t, Collections.emptyMap()); + } + + /** + * Get a list of configured instances of the given class specified by the given configuration key. The configuration + * may specify either null or an empty string to indicate no configured instances. In both cases, this method + * returns an empty list to indicate no configured instances. + * @param key The configuration key for the class + * @param t The interface the class should implement + * @param configOverrides Configuration overrides to use. + * @return The list of configured instances + */ + public List getConfiguredInstances(String key, Class t, Map configOverrides) { + return getConfiguredInstances(getList(key), t, configOverrides); + } + + /** + * Get a list of configured instances of the given class specified by the given configuration key. The configuration + * may specify either null or an empty string to indicate no configured instances. In both cases, this method + * returns an empty list to indicate no configured instances. + * @param classNames The list of class names of the instances to create + * @param t The interface the class should implement + * @param configOverrides Configuration overrides to use. + * @return The list of configured instances + */ + public List getConfiguredInstances(List classNames, Class t, Map configOverrides) { + List objects = new ArrayList<>(); + if (classNames == null) + return objects; + Map configPairs = originals(); + configPairs.putAll(configOverrides); + for (Object klass : classNames) { + Object o = getConfiguredInstance(klass, t, configPairs); + objects.add(t.cast(o)); + } + return objects; + } + + private Map extractPotentialVariables(Map configMap) { + // Variables are tuples of the form "${providerName:[path:]key}". From the configMap we extract the subset of configs with string + // values as potential variables. 
+ Map configMapAsString = new HashMap<>(); + for (Map.Entry entry : configMap.entrySet()) { + if (entry.getValue() instanceof String) + configMapAsString.put((String) entry.getKey(), (String) entry.getValue()); + } + + return configMapAsString; + } + + /** + * Instantiates given list of config providers and fetches the actual values of config variables from the config providers. + * returns a map of config key and resolved values. + * @param configProviderProps The map of config provider configs + * @param originals The map of raw configs. + * @return map of resolved config variable. + */ + @SuppressWarnings("unchecked") + private Map resolveConfigVariables(Map configProviderProps, Map originals) { + Map providerConfigString; + Map configProperties; + Map resolvedOriginals = new HashMap<>(); + // As variable configs are strings, parse the originals and obtain the potential variable configs. + Map indirectVariables = extractPotentialVariables(originals); + + resolvedOriginals.putAll(originals); + if (configProviderProps == null || configProviderProps.isEmpty()) { + providerConfigString = indirectVariables; + configProperties = originals; + } else { + providerConfigString = extractPotentialVariables(configProviderProps); + configProperties = configProviderProps; + } + Map providers = instantiateConfigProviders(providerConfigString, configProperties); + + if (!providers.isEmpty()) { + ConfigTransformer configTransformer = new ConfigTransformer(providers); + ConfigTransformerResult result = configTransformer.transform(indirectVariables); + if (!result.data().isEmpty()) { + resolvedOriginals.putAll(result.data()); + } + } + providers.values().forEach(x -> Utils.closeQuietly(x, "config provider")); + + return new ResolvingMap<>(resolvedOriginals, originals); + } + + private Map configProviderProperties(String configProviderPrefix, Map providerConfigProperties) { + Map result = new HashMap<>(); + for (Map.Entry entry : providerConfigProperties.entrySet()) { + String key = entry.getKey(); + if (key.startsWith(configProviderPrefix) && key.length() > configProviderPrefix.length()) { + result.put(key.substring(configProviderPrefix.length()), entry.getValue()); + } + } + return result; + } + + /** + * Instantiates and configures the ConfigProviders. The config providers configs are defined as follows: + * config.providers : A comma-separated list of names for providers. + * config.providers.{name}.class : The Java class name for a provider. + * config.providers.{name}.param.{param-name} : A parameter to be passed to the above Java class on initialization. + * returns a map of config provider name and its instance. + * @param indirectConfigs The map of potential variable configs + * @param providerConfigProperties The map of config provider configs + * @return map map of config provider name and its instance. 
+ */ + private Map instantiateConfigProviders(Map indirectConfigs, Map providerConfigProperties) { + final String configProviders = indirectConfigs.get(CONFIG_PROVIDERS_CONFIG); + + if (configProviders == null || configProviders.isEmpty()) { + return Collections.emptyMap(); + } + + Map providerMap = new HashMap<>(); + + for (String provider: configProviders.split(",")) { + String providerClass = providerClassProperty(provider); + if (indirectConfigs.containsKey(providerClass)) + providerMap.put(provider, indirectConfigs.get(providerClass)); + + } + // Instantiate Config Providers + Map configProviderInstances = new HashMap<>(); + for (Map.Entry entry : providerMap.entrySet()) { + try { + String prefix = CONFIG_PROVIDERS_CONFIG + "." + entry.getKey() + CONFIG_PROVIDERS_PARAM; + Map configProperties = configProviderProperties(prefix, providerConfigProperties); + ConfigProvider provider = Utils.newInstance(entry.getValue(), ConfigProvider.class); + provider.configure(configProperties); + configProviderInstances.put(entry.getKey(), provider); + } catch (ClassNotFoundException e) { + log.error("Could not load config provider class " + entry.getValue(), e); + throw new ConfigException(providerClassProperty(entry.getKey()), entry.getValue(), "Could not load config provider class or one of its dependencies"); + } + } + + return configProviderInstances; + } + + private static String providerClassProperty(String providerName) { + return String.format("%s.%s.class", CONFIG_PROVIDERS_CONFIG, providerName); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + AbstractConfig that = (AbstractConfig) o; + + return originals.equals(that.originals); + } + + @Override + public int hashCode() { + return originals.hashCode(); + } + + /** + * Marks keys retrieved via `get` as used. This is needed because `Configurable.configure` takes a `Map` instead + * of an `AbstractConfig` and we can't change that without breaking public API like `Partitioner`. + */ + private class RecordingMap extends HashMap { + + private final String prefix; + private final boolean withIgnoreFallback; + + RecordingMap() { + this("", false); + } + + RecordingMap(String prefix, boolean withIgnoreFallback) { + this.prefix = prefix; + this.withIgnoreFallback = withIgnoreFallback; + } + + RecordingMap(Map m) { + this(m, "", false); + } + + RecordingMap(Map m, String prefix, boolean withIgnoreFallback) { + super(m); + this.prefix = prefix; + this.withIgnoreFallback = withIgnoreFallback; + } + + @Override + public V get(Object key) { + if (key instanceof String) { + String stringKey = (String) key; + String keyWithPrefix; + if (prefix.isEmpty()) { + keyWithPrefix = stringKey; + } else { + keyWithPrefix = prefix + stringKey; + } + ignore(keyWithPrefix); + if (withIgnoreFallback) + ignore(stringKey); + } + return super.get(key); + } + } + + /** + * ResolvingMap keeps a track of the original map instance and the resolved configs. + * The originals are tracked in a separate nested map and may be a `RecordingMap`; thus + * any access to a value for a key needs to be recorded on the originals map. + * The resolved configs are kept in the inherited map and are therefore mutable, though any + * mutations are not applied to the originals. 
+ */ + private static class ResolvingMap extends HashMap { + + private final Map originals; + + ResolvingMap(Map resolved, Map originals) { + super(resolved); + this.originals = Collections.unmodifiableMap(originals); + } + + @Override + public V get(Object key) { + if (key instanceof String && originals.containsKey(key)) { + // Intentionally ignore the result; call just to mark the original entry as used + originals.get(key); + } + // But always use the resolved entry + return super.get(key); + } + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/config/Config.java b/clients/src/main/java/org/apache/kafka/common/config/Config.java new file mode 100644 index 0000000..f7fa95c --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/config/Config.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.config; + +import java.util.List; + +public class Config { + private final List configValues; + + public Config(List configValues) { + this.configValues = configValues; + } + + public List configValues() { + return configValues; + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/config/ConfigChangeCallback.java b/clients/src/main/java/org/apache/kafka/common/config/ConfigChangeCallback.java new file mode 100644 index 0000000..faa7d3d --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/config/ConfigChangeCallback.java @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.config; + +import org.apache.kafka.common.config.provider.ConfigProvider; + +/** + * A callback passed to {@link ConfigProvider} for subscribing to changes. + */ +public interface ConfigChangeCallback { + + /** + * Performs an action when configuration data changes. 
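+ * Because the interface declares a single abstract method, an implementation can be supplied as a
+ * lambda (illustrative only), e.g. {@code (path, data) -> log.info("Config at {} changed", path)}.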
+ * + * @param path the path at which the data resides + * @param data the configuration data + */ + void onChange(String path, ConfigData data); +} diff --git a/clients/src/main/java/org/apache/kafka/common/config/ConfigData.java b/clients/src/main/java/org/apache/kafka/common/config/ConfigData.java new file mode 100644 index 0000000..8661ee1 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/config/ConfigData.java @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.config; + +import org.apache.kafka.common.config.provider.ConfigProvider; + +import java.util.Map; + +/** + * Configuration data from a {@link ConfigProvider}. + */ +public class ConfigData { + + private final Map data; + private final Long ttl; + + /** + * Creates a new ConfigData with the given data and TTL (in milliseconds). + * + * @param data a Map of key-value pairs + * @param ttl the time-to-live of the data in milliseconds, or null if there is no TTL + */ + public ConfigData(Map data, Long ttl) { + this.data = data; + this.ttl = ttl; + } + + /** + * Creates a new ConfigData with the given data. + * + * @param data a Map of key-value pairs + */ + public ConfigData(Map data) { + this(data, null); + } + + /** + * Returns the data. + * + * @return data a Map of key-value pairs + */ + public Map data() { + return data; + } + + /** + * Returns the TTL (in milliseconds). + * + * @return ttl the time-to-live (in milliseconds) of the data, or null if there is no TTL + */ + public Long ttl() { + return ttl; + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/config/ConfigDef.java b/clients/src/main/java/org/apache/kafka/common/config/ConfigDef.java new file mode 100644 index 0000000..85b0103 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/config/ConfigDef.java @@ -0,0 +1,1594 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.config; + +import java.util.function.Function; +import java.util.stream.Collectors; +import org.apache.kafka.common.config.types.Password; +import org.apache.kafka.common.utils.Utils; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.LinkedHashMap; +import java.util.LinkedList; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Set; +import java.util.function.BiConsumer; +import java.util.function.Supplier; +import java.util.regex.Pattern; + +/** + * This class is used for specifying the set of expected configurations. For each configuration, you can specify + * the name, the type, the default value, the documentation, the group information, the order in the group, + * the width of the configuration value and the name suitable for display in the UI. + * + * You can provide special validation logic used for single configuration validation by overriding {@link Validator}. + * + * Moreover, you can specify the dependents of a configuration. The valid values and visibility of a configuration + * may change according to the values of other configurations. You can override {@link Recommender} to get valid + * values and set visibility of a configuration given the current configuration values. + * + *

+ * To use the class:
+ * <pre>
+ * ConfigDef defs = new ConfigDef();
+ *
+ * defs.define("config_with_default", Type.STRING, "default string value", "Configuration with default value.");
+ * defs.define("config_with_validator", Type.INT, 42, Range.atLeast(0), "Configuration with user provided validator.");
+ * defs.define("config_with_dependents", Type.INT, "Configuration with dependents.", "group", 1, "Config With Dependents", Arrays.asList("config_with_default","config_with_validator"));
+ *
+ * Map<String, String> props = new HashMap<>();
+ * props.put("config_with_default", "some value");
+ * props.put("config_with_dependents", "some other value");
+ *
+ * Map<String, Object> configs = defs.parse(props);
+ * // will return "some value"
+ * String someConfig = (String) configs.get("config_with_default");
+ * // will return default value of 42
+ * int anotherConfig = (Integer) configs.get("config_with_validator");
+ *
+ * To validate the full configuration, use:
+ * List<Config> configs = defs.validate(props);
+ * The {@link Config} contains updated configuration information given the current configuration values.
+ * </pre>
+ *

        + * This class can be used standalone or in combination with {@link AbstractConfig} which provides some additional + * functionality for accessing configs. + */ +public class ConfigDef { + + private static final Pattern COMMA_WITH_WHITESPACE = Pattern.compile("\\s*,\\s*"); + + /** + * A unique Java object which represents the lack of a default value. + */ + public static final Object NO_DEFAULT_VALUE = new Object(); + + private final Map configKeys; + private final List groups; + private Set configsWithNoParent; + + public ConfigDef() { + configKeys = new LinkedHashMap<>(); + groups = new LinkedList<>(); + configsWithNoParent = null; + } + + public ConfigDef(ConfigDef base) { + configKeys = new LinkedHashMap<>(base.configKeys); + groups = new LinkedList<>(base.groups); + // It is not safe to copy this from the parent because we may subsequently add to the set of configs and + // invalidate this + configsWithNoParent = null; + } + + /** + * Returns unmodifiable set of properties names defined in this {@linkplain ConfigDef} + * + * @return new unmodifiable {@link Set} instance containing the keys + */ + public Set names() { + return Collections.unmodifiableSet(configKeys.keySet()); + } + + public Map defaultValues() { + Map defaultValues = new HashMap<>(); + for (ConfigKey key : configKeys.values()) { + if (key.defaultValue != NO_DEFAULT_VALUE) + defaultValues.put(key.name, key.defaultValue); + } + return defaultValues; + } + + public ConfigDef define(ConfigKey key) { + if (configKeys.containsKey(key.name)) { + throw new ConfigException("Configuration " + key.name + " is defined twice."); + } + if (key.group != null && !groups.contains(key.group)) { + groups.add(key.group); + } + configKeys.put(key.name, key); + return this; + } + + /** + * Define a new configuration + * @param name the name of the config parameter + * @param type the type of the config + * @param defaultValue the default value to use if this config isn't present + * @param validator the validator to use in checking the correctness of the config + * @param importance the importance of this config + * @param documentation the documentation string for the config + * @param group the group this config belongs to + * @param orderInGroup the order of this config in the group + * @param width the width of the config + * @param displayName the name suitable for display + * @param dependents the configurations that are dependents of this configuration + * @param recommender the recommender provides valid values given the parent configuration values + * @return This ConfigDef so you can chain calls + */ + public ConfigDef define(String name, Type type, Object defaultValue, Validator validator, Importance importance, String documentation, + String group, int orderInGroup, Width width, String displayName, List dependents, Recommender recommender) { + return define(new ConfigKey(name, type, defaultValue, validator, importance, documentation, group, orderInGroup, width, displayName, dependents, recommender, false)); + } + + /** + * Define a new configuration with no custom recommender + * @param name the name of the config parameter + * @param type the type of the config + * @param defaultValue the default value to use if this config isn't present + * @param validator the validator to use in checking the correctness of the config + * @param importance the importance of this config + * @param documentation the documentation string for the config + * @param group the group this config belongs to + * @param orderInGroup the order of 
this config in the group + * @param width the width of the config + * @param displayName the name suitable for display + * @param dependents the configurations that are dependents of this configuration + * @return This ConfigDef so you can chain calls + */ + public ConfigDef define(String name, Type type, Object defaultValue, Validator validator, Importance importance, String documentation, + String group, int orderInGroup, Width width, String displayName, List dependents) { + return define(name, type, defaultValue, validator, importance, documentation, group, orderInGroup, width, displayName, dependents, null); + } + + /** + * Define a new configuration with no dependents + * @param name the name of the config parameter + * @param type the type of the config + * @param defaultValue the default value to use if this config isn't present + * @param validator the validator to use in checking the correctness of the config + * @param importance the importance of this config + * @param documentation the documentation string for the config + * @param group the group this config belongs to + * @param orderInGroup the order of this config in the group + * @param width the width of the config + * @param displayName the name suitable for display + * @param recommender the recommender provides valid values given the parent configuration values + * @return This ConfigDef so you can chain calls + */ + public ConfigDef define(String name, Type type, Object defaultValue, Validator validator, Importance importance, String documentation, + String group, int orderInGroup, Width width, String displayName, Recommender recommender) { + return define(name, type, defaultValue, validator, importance, documentation, group, orderInGroup, width, displayName, Collections.emptyList(), recommender); + } + + /** + * Define a new configuration with no dependents and no custom recommender + * @param name the name of the config parameter + * @param type the type of the config + * @param defaultValue the default value to use if this config isn't present + * @param validator the validator to use in checking the correctness of the config + * @param importance the importance of this config + * @param documentation the documentation string for the config + * @param group the group this config belongs to + * @param orderInGroup the order of this config in the group + * @param width the width of the config + * @param displayName the name suitable for display + * @return This ConfigDef so you can chain calls + */ + public ConfigDef define(String name, Type type, Object defaultValue, Validator validator, Importance importance, String documentation, + String group, int orderInGroup, Width width, String displayName) { + return define(name, type, defaultValue, validator, importance, documentation, group, orderInGroup, width, displayName, Collections.emptyList()); + } + + /** + * Define a new configuration with no special validation logic + * @param name the name of the config parameter + * @param type the type of the config + * @param defaultValue the default value to use if this config isn't present + * @param importance the importance of this config + * @param documentation the documentation string for the config + * @param group the group this config belongs to + * @param orderInGroup the order of this config in the group + * @param width the width of the config + * @param displayName the name suitable for display + * @param dependents the configurations that are dependents of this configuration + * @param recommender the 
recommender provides valid values given the parent configuration values + * @return This ConfigDef so you can chain calls + */ + public ConfigDef define(String name, Type type, Object defaultValue, Importance importance, String documentation, + String group, int orderInGroup, Width width, String displayName, List dependents, Recommender recommender) { + return define(name, type, defaultValue, null, importance, documentation, group, orderInGroup, width, displayName, dependents, recommender); + } + + /** + * Define a new configuration with no special validation logic and no custom recommender + * @param name the name of the config parameter + * @param type the type of the config + * @param defaultValue the default value to use if this config isn't present + * @param importance the importance of this config + * @param documentation the documentation string for the config + * @param group the group this config belongs to + * @param orderInGroup the order of this config in the group + * @param width the width of the config + * @param displayName the name suitable for display + * @param dependents the configurations that are dependents of this configuration + * @return This ConfigDef so you can chain calls + */ + public ConfigDef define(String name, Type type, Object defaultValue, Importance importance, String documentation, + String group, int orderInGroup, Width width, String displayName, List dependents) { + return define(name, type, defaultValue, null, importance, documentation, group, orderInGroup, width, displayName, dependents, null); + } + + /** + * Define a new configuration with no special validation logic and no custom recommender + * @param name the name of the config parameter + * @param type the type of the config + * @param defaultValue the default value to use if this config isn't present + * @param importance the importance of this config + * @param documentation the documentation string for the config + * @param group the group this config belongs to + * @param orderInGroup the order of this config in the group + * @param width the width of the config + * @param displayName the name suitable for display + * @param recommender the recommender provides valid values given the parent configuration values + * @return This ConfigDef so you can chain calls + */ + public ConfigDef define(String name, Type type, Object defaultValue, Importance importance, String documentation, + String group, int orderInGroup, Width width, String displayName, Recommender recommender) { + return define(name, type, defaultValue, null, importance, documentation, group, orderInGroup, width, displayName, Collections.emptyList(), recommender); + } + + /** + * Define a new configuration with no special validation logic, not dependents and no custom recommender + * @param name the name of the config parameter + * @param type the type of the config + * @param defaultValue the default value to use if this config isn't present + * @param importance the importance of this config + * @param documentation the documentation string for the config + * @param group the group this config belongs to + * @param orderInGroup the order of this config in the group + * @param width the width of the config + * @param displayName the name suitable for display + * @return This ConfigDef so you can chain calls + */ + public ConfigDef define(String name, Type type, Object defaultValue, Importance importance, String documentation, + String group, int orderInGroup, Width width, String displayName) { + return define(name, type, 
defaultValue, null, importance, documentation, group, orderInGroup, width, displayName, Collections.emptyList()); + } + + /** + * Define a new configuration with no default value and no special validation logic + * @param name the name of the config parameter + * @param type the type of the config + * @param importance the importance of this config + * @param documentation the documentation string for the config + * @param group the group this config belongs to + * @param orderInGroup the order of this config in the group + * @param width the width of the config + * @param displayName the name suitable for display + * @param dependents the configurations that are dependents of this configuration + * @param recommender the recommender provides valid values given the parent configuration value + * @return This ConfigDef so you can chain calls + */ + public ConfigDef define(String name, Type type, Importance importance, String documentation, String group, int orderInGroup, + Width width, String displayName, List dependents, Recommender recommender) { + return define(name, type, NO_DEFAULT_VALUE, null, importance, documentation, group, orderInGroup, width, displayName, dependents, recommender); + } + + /** + * Define a new configuration with no default value, no special validation logic and no custom recommender + * @param name the name of the config parameter + * @param type the type of the config + * @param importance the importance of this config + * @param documentation the documentation string for the config + * @param group the group this config belongs to + * @param orderInGroup the order of this config in the group + * @param width the width of the config + * @param displayName the name suitable for display + * @param dependents the configurations that are dependents of this configuration + * @return This ConfigDef so you can chain calls + */ + public ConfigDef define(String name, Type type, Importance importance, String documentation, String group, int orderInGroup, + Width width, String displayName, List dependents) { + return define(name, type, NO_DEFAULT_VALUE, null, importance, documentation, group, orderInGroup, width, displayName, dependents, null); + } + + /** + * Define a new configuration with no default value, no special validation logic and no custom recommender + * @param name the name of the config parameter + * @param type the type of the config + * @param importance the importance of this config + * @param documentation the documentation string for the config + * @param group the group this config belongs to + * @param orderInGroup the order of this config in the group + * @param width the width of the config + * @param displayName the name suitable for display + * @param recommender the recommender provides valid values given the parent configuration value + * @return This ConfigDef so you can chain calls + */ + public ConfigDef define(String name, Type type, Importance importance, String documentation, String group, int orderInGroup, + Width width, String displayName, Recommender recommender) { + return define(name, type, NO_DEFAULT_VALUE, null, importance, documentation, group, orderInGroup, width, displayName, Collections.emptyList(), recommender); + } + + /** + * Define a new configuration with no default value, no special validation logic, no dependents and no custom recommender + * @param name the name of the config parameter + * @param type the type of the config + * @param importance the importance of this config + * @param documentation the documentation 
string for the config + * @param group the group this config belongs to + * @param orderInGroup the order of this config in the group + * @param width the width of the config + * @param displayName the name suitable for display + * @return This ConfigDef so you can chain calls + */ + public ConfigDef define(String name, Type type, Importance importance, String documentation, String group, int orderInGroup, + Width width, String displayName) { + return define(name, type, NO_DEFAULT_VALUE, null, importance, documentation, group, orderInGroup, width, displayName, Collections.emptyList()); + } + + /** + * Define a new configuration with no group, no order in group, no width, no display name, no dependents and no custom recommender + * @param name the name of the config parameter + * @param type the type of the config + * @param defaultValue the default value to use if this config isn't present + * @param validator the validator to use in checking the correctness of the config + * @param importance the importance of this config + * @param documentation the documentation string for the config + * @return This ConfigDef so you can chain calls + */ + public ConfigDef define(String name, Type type, Object defaultValue, Validator validator, Importance importance, String documentation) { + return define(name, type, defaultValue, validator, importance, documentation, null, -1, Width.NONE, name); + } + + /** + * Define a new configuration with no special validation logic + * @param name The name of the config parameter + * @param type The type of the config + * @param defaultValue The default value to use if this config isn't present + * @param importance The importance of this config: is this something you will likely need to change. + * @param documentation The documentation string for the config + * @return This ConfigDef so you can chain calls + */ + public ConfigDef define(String name, Type type, Object defaultValue, Importance importance, String documentation) { + return define(name, type, defaultValue, null, importance, documentation); + } + + /** + * Define a new configuration with no default value and no special validation logic + * @param name The name of the config parameter + * @param type The type of the config + * @param importance The importance of this config: is this something you will likely need to change. + * @param documentation The documentation string for the config + * @return This ConfigDef so you can chain calls + */ + public ConfigDef define(String name, Type type, Importance importance, String documentation) { + return define(name, type, NO_DEFAULT_VALUE, null, importance, documentation); + } + + /** + * Define a new internal configuration. Internal configuration won't show up in the docs and aren't + * intended for general use. + * @param name The name of the config parameter + * @param type The type of the config + * @param defaultValue The default value to use if this config isn't present + * @param importance The importance of this config (i.e. is this something you will likely need to change?) + * @return This ConfigDef so you can chain calls + */ + public ConfigDef defineInternal(final String name, final Type type, final Object defaultValue, final Importance importance) { + return define(new ConfigKey(name, type, defaultValue, null, importance, "", "", -1, Width.NONE, name, Collections.emptyList(), null, true)); + } + + /** + * Define a new internal configuration. Internal configuration won't show up in the docs and aren't + * intended for general use. 
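
Taken together, the overloads above all delegate to define(ConfigKey), so definitions can be chained. A minimal sketch, with invented config names, using the Range and ValidString validators defined later in this class:

    ConfigDef def = new ConfigDef()
        // basic definition: name, type, default, validator, importance, documentation
        .define("cache.size.bytes", ConfigDef.Type.LONG, 1024L, ConfigDef.Range.atLeast(0),
                ConfigDef.Importance.HIGH, "Maximum cache size in bytes.")
        // definition with UI metadata: group, order in group, width and display name
        .define("cache.mode", ConfigDef.Type.STRING, "lru", ConfigDef.ValidString.in("lru", "fifo"),
                ConfigDef.Importance.MEDIUM, "Cache eviction mode.",
                "Cache", 1, ConfigDef.Width.SHORT, "Cache Mode")
        // internal definition: parsed and validated, but omitted from generated docs
        .defineInternal("cache.debug.enabled", ConfigDef.Type.BOOLEAN, false, ConfigDef.Importance.LOW);
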
+ * @param name The name of the config parameter + * @param type The type of the config + * @param defaultValue The default value to use if this config isn't present + * @param validator The validator to use in checking the correctness of the config + * @param importance The importance of this config (i.e. is this something you will likely need to change?) + * @param documentation The documentation string for the config + * @return This ConfigDef so you can chain calls + */ + public ConfigDef defineInternal(final String name, final Type type, final Object defaultValue, final Validator validator, final Importance importance, final String documentation) { + return define(new ConfigKey(name, type, defaultValue, validator, importance, documentation, "", -1, Width.NONE, name, Collections.emptyList(), null, true)); + } + + /** + * Get the configuration keys + * @return a map containing all configuration keys + */ + public Map configKeys() { + return configKeys; + } + + /** + * Get the groups for the configuration + * @return a list of group names + */ + public List groups() { + return groups; + } + + /** + * Add standard SSL client configuration options. + * @return this + */ + public ConfigDef withClientSslSupport() { + SslConfigs.addClientSslSupport(this); + return this; + } + + /** + * Add standard SASL client configuration options. + * @return this + */ + public ConfigDef withClientSaslSupport() { + SaslConfigs.addClientSaslSupport(this); + return this; + } + + /** + * Parse and validate configs against this configuration definition. The input is a map of configs. It is expected + * that the keys of the map are strings, but the values can either be strings or they may already be of the + * appropriate type (int, string, etc). This will work equally well with either java.util.Properties instances or a + * programmatically constructed map. + * + * @param props The configs to parse and validate. + * @return Parsed and validated configs. The key will be the config name and the value will be the value parsed into + * the appropriate type (int, string, etc). + */ + public Map parse(Map props) { + // Check all configurations are defined + List undefinedConfigKeys = undefinedDependentConfigs(); + if (!undefinedConfigKeys.isEmpty()) { + String joined = Utils.join(undefinedConfigKeys, ","); + throw new ConfigException("Some configurations in are referred in the dependents, but not defined: " + joined); + } + // parse all known keys + Map values = new HashMap<>(); + for (ConfigKey key : configKeys.values()) + values.put(key.name, parseValue(key, props.get(key.name), props.containsKey(key.name))); + return values; + } + + Object parseValue(ConfigKey key, Object value, boolean isSet) { + Object parsedValue; + if (isSet) { + parsedValue = parseType(key.name, value, key.type); + // props map doesn't contain setting, the key is required because no default value specified - its an error + } else if (NO_DEFAULT_VALUE.equals(key.defaultValue)) { + throw new ConfigException("Missing required configuration \"" + key.name + "\" which has no default value."); + } else { + // otherwise assign setting its default value + parsedValue = key.defaultValue; + } + if (key.validator != null) { + key.validator.ensureValid(key.name, parsedValue); + } + return parsedValue; + } + + /** + * Validate the current configuration values with the configuration definition. 
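
Continuing the sketch above: parse() converts raw string properties into typed values and throws ConfigException on the first problem, while validate()/validateAll() collect problems per config. The accessor names on ConfigValue are assumed from its usage in this file:

    Map<String, String> props = new HashMap<>();
    props.put("cache.size.bytes", "2048");          // strings are converted to the declared Type

    Map<String, Object> parsed = def.parse(props);  // throws ConfigException on bad or missing values
    long cacheSize = (Long) parsed.get("cache.size.bytes");

    Map<String, ConfigValue> results = def.validateAll(props);
    for (ConfigValue result : results.values()) {
        if (!result.errorMessages().isEmpty()) {
            System.out.println(result.name() + ": " + result.errorMessages());
        }
    }
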
+ * @param props the current configuration values + * @return List of Config, each Config contains the updated configuration information given + * the current configuration values. + */ + public List validate(Map props) { + return new ArrayList<>(validateAll(props).values()); + } + + public Map validateAll(Map props) { + Map configValues = new HashMap<>(); + for (String name: configKeys.keySet()) { + configValues.put(name, new ConfigValue(name)); + } + + List undefinedConfigKeys = undefinedDependentConfigs(); + for (String undefinedConfigKey: undefinedConfigKeys) { + ConfigValue undefinedConfigValue = new ConfigValue(undefinedConfigKey); + undefinedConfigValue.addErrorMessage(undefinedConfigKey + " is referred in the dependents, but not defined."); + undefinedConfigValue.visible(false); + configValues.put(undefinedConfigKey, undefinedConfigValue); + } + + Map parsed = parseForValidate(props, configValues); + return validate(parsed, configValues); + } + + // package accessible for testing + Map parseForValidate(Map props, Map configValues) { + Map parsed = new HashMap<>(); + Set configsWithNoParent = getConfigsWithNoParent(); + for (String name: configsWithNoParent) { + parseForValidate(name, props, parsed, configValues); + } + return parsed; + } + + + private Map validate(Map parsed, Map configValues) { + Set configsWithNoParent = getConfigsWithNoParent(); + for (String name: configsWithNoParent) { + validate(name, parsed, configValues); + } + return configValues; + } + + private List undefinedDependentConfigs() { + Set undefinedConfigKeys = new HashSet<>(); + for (ConfigKey configKey : configKeys.values()) { + for (String dependent: configKey.dependents) { + if (!configKeys.containsKey(dependent)) { + undefinedConfigKeys.add(dependent); + } + } + } + return new ArrayList<>(undefinedConfigKeys); + } + + // package accessible for testing + Set getConfigsWithNoParent() { + if (this.configsWithNoParent != null) { + return this.configsWithNoParent; + } + Set configsWithParent = new HashSet<>(); + + for (ConfigKey configKey: configKeys.values()) { + List dependents = configKey.dependents; + configsWithParent.addAll(dependents); + } + + Set configs = new HashSet<>(configKeys.keySet()); + configs.removeAll(configsWithParent); + this.configsWithNoParent = configs; + return configs; + } + + private void parseForValidate(String name, Map props, Map parsed, Map configs) { + if (!configKeys.containsKey(name)) { + return; + } + ConfigKey key = configKeys.get(name); + ConfigValue config = configs.get(name); + + Object value = null; + if (props.containsKey(key.name)) { + try { + value = parseType(key.name, props.get(key.name), key.type); + } catch (ConfigException e) { + config.addErrorMessage(e.getMessage()); + } + } else if (NO_DEFAULT_VALUE.equals(key.defaultValue)) { + config.addErrorMessage("Missing required configuration \"" + key.name + "\" which has no default value."); + } else { + value = key.defaultValue; + } + + if (key.validator != null) { + try { + key.validator.ensureValid(key.name, value); + } catch (ConfigException e) { + config.addErrorMessage(e.getMessage()); + } + } + config.value(value); + parsed.put(name, value); + for (String dependent: key.dependents) { + parseForValidate(dependent, props, parsed, configs); + } + } + + private void validate(String name, Map parsed, Map configs) { + if (!configKeys.containsKey(name)) { + return; + } + ConfigKey key = configKeys.get(name); + ConfigValue value = configs.get(name); + if (key.recommender != null) { + try { + List recommendedValues = 
key.recommender.validValues(name, parsed); + List originalRecommendedValues = value.recommendedValues(); + if (!originalRecommendedValues.isEmpty()) { + Set originalRecommendedValueSet = new HashSet<>(originalRecommendedValues); + recommendedValues.removeIf(o -> !originalRecommendedValueSet.contains(o)); + } + value.recommendedValues(recommendedValues); + value.visible(key.recommender.visible(name, parsed)); + } catch (ConfigException e) { + value.addErrorMessage(e.getMessage()); + } + } + + configs.put(name, value); + for (String dependent: key.dependents) { + validate(dependent, parsed, configs); + } + } + + /** + * Parse a value according to its expected type. + * @param name The config name + * @param value The config value + * @param type The expected type + * @return The parsed object + */ + public static Object parseType(String name, Object value, Type type) { + try { + if (value == null) return null; + + String trimmed = null; + if (value instanceof String) + trimmed = ((String) value).trim(); + + switch (type) { + case BOOLEAN: + if (value instanceof String) { + if (trimmed.equalsIgnoreCase("true")) + return true; + else if (trimmed.equalsIgnoreCase("false")) + return false; + else + throw new ConfigException(name, value, "Expected value to be either true or false"); + } else if (value instanceof Boolean) + return value; + else + throw new ConfigException(name, value, "Expected value to be either true or false"); + case PASSWORD: + if (value instanceof Password) + return value; + else if (value instanceof String) + return new Password(trimmed); + else + throw new ConfigException(name, value, "Expected value to be a string, but it was a " + value.getClass().getName()); + case STRING: + if (value instanceof String) + return trimmed; + else + throw new ConfigException(name, value, "Expected value to be a string, but it was a " + value.getClass().getName()); + case INT: + if (value instanceof Integer) { + return value; + } else if (value instanceof String) { + return Integer.parseInt(trimmed); + } else { + throw new ConfigException(name, value, "Expected value to be a 32-bit integer, but it was a " + value.getClass().getName()); + } + case SHORT: + if (value instanceof Short) { + return value; + } else if (value instanceof String) { + return Short.parseShort(trimmed); + } else { + throw new ConfigException(name, value, "Expected value to be a 16-bit integer (short), but it was a " + value.getClass().getName()); + } + case LONG: + if (value instanceof Integer) + return ((Integer) value).longValue(); + if (value instanceof Long) + return value; + else if (value instanceof String) + return Long.parseLong(trimmed); + else + throw new ConfigException(name, value, "Expected value to be a 64-bit integer (long), but it was a " + value.getClass().getName()); + case DOUBLE: + if (value instanceof Number) + return ((Number) value).doubleValue(); + else if (value instanceof String) + return Double.parseDouble(trimmed); + else + throw new ConfigException(name, value, "Expected value to be a double, but it was a " + value.getClass().getName()); + case LIST: + if (value instanceof List) + return value; + else if (value instanceof String) + if (trimmed.isEmpty()) + return Collections.emptyList(); + else + return Arrays.asList(COMMA_WITH_WHITESPACE.split(trimmed, -1)); + else + throw new ConfigException(name, value, "Expected a comma separated list."); + case CLASS: + if (value instanceof Class) + return value; + else if (value instanceof String) { + ClassLoader contextOrKafkaClassLoader = 
Utils.getContextOrKafkaClassLoader(); + // Use loadClass here instead of Class.forName because the name we use here may be an alias + // and not match the name of the class that gets loaded. If that happens, Class.forName can + // throw an exception. + Class klass = contextOrKafkaClassLoader.loadClass(trimmed); + // Invoke forName here with the true name of the requested class to cause class + // initialization to take place. + return Class.forName(klass.getName(), true, contextOrKafkaClassLoader); + } else + throw new ConfigException(name, value, "Expected a Class instance or class name."); + default: + throw new IllegalStateException("Unknown type."); + } + } catch (NumberFormatException e) { + throw new ConfigException(name, value, "Not a number of type " + type); + } catch (ClassNotFoundException e) { + throw new ConfigException(name, value, "Class " + value + " could not be found."); + } + } + + public static String convertToString(Object parsedValue, Type type) { + if (parsedValue == null) { + return null; + } + + if (type == null) { + return parsedValue.toString(); + } + + switch (type) { + case BOOLEAN: + case SHORT: + case INT: + case LONG: + case DOUBLE: + case STRING: + case PASSWORD: + return parsedValue.toString(); + case LIST: + List valueList = (List) parsedValue; + return Utils.join(valueList, ","); + case CLASS: + Class clazz = (Class) parsedValue; + return clazz.getName(); + default: + throw new IllegalStateException("Unknown type."); + } + } + + /** + * Converts a map of config (key, value) pairs to a map of strings where each value + * is converted to a string. This method should be used with care since it stores + * actual password values to String. Values from this map should never be used in log entries. + */ + public static Map convertToStringMapWithPasswordValues(Map configs) { + Map result = new HashMap<>(); + for (Map.Entry entry : configs.entrySet()) { + Object value = entry.getValue(); + String strValue; + if (value instanceof Password) + strValue = ((Password) value).value(); + else if (value instanceof List) + strValue = convertToString(value, Type.LIST); + else if (value instanceof Class) + strValue = convertToString(value, Type.CLASS); + else + strValue = convertToString(value, null); + if (strValue != null) + result.put(entry.getKey(), strValue); + } + return result; + } + + /** + * The config types + */ + public enum Type { + BOOLEAN, STRING, INT, SHORT, LONG, DOUBLE, LIST, CLASS, PASSWORD; + + public boolean isSensitive() { + return this == PASSWORD; + } + } + + /** + * The importance level for a configuration + */ + public enum Importance { + HIGH, MEDIUM, LOW + } + + /** + * The width of a configuration value + */ + public enum Width { + NONE, SHORT, MEDIUM, LONG + } + + /** + * This is used by the {@link #validate(Map)} to get valid values for a configuration given the current + * configuration values in order to perform full configuration validation and visibility modification. + * In case that there are dependencies between configurations, the valid values and visibility + * for a configuration may change given the values of other configurations. + */ + public interface Recommender { + + /** + * The valid values for the configuration given the current configuration values. + * @param name The name of the configuration + * @param parsedConfig The parsed configuration values + * @return The list of valid values. To function properly, the returned objects should have the type + * defined for the configuration using the recommender. 
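
A small Recommender sketch (the parent and child config names are invented): the recommended values and the visibility of a child config are derived from an already-parsed parent value, matching the contract described above.

    ConfigDef.Recommender cacheModeRecommender = new ConfigDef.Recommender() {
        @Override
        public List<Object> validValues(String name, Map<String, Object> parsedConfig) {
            // Only offer values while the (hypothetical) parent feature is switched on
            if (Boolean.TRUE.equals(parsedConfig.get("cache.enabled"))) {
                return Arrays.asList("lru", "fifo");
            }
            return Collections.emptyList();
        }

        @Override
        public boolean visible(String name, Map<String, Object> parsedConfig) {
            return Boolean.TRUE.equals(parsedConfig.get("cache.enabled"));
        }
    };
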
+ */ + List validValues(String name, Map parsedConfig); + + /** + * Set the visibility of the configuration given the current configuration values. + * @param name The name of the configuration + * @param parsedConfig The parsed configuration values + * @return The visibility of the configuration + */ + boolean visible(String name, Map parsedConfig); + } + + /** + * Validation logic the user may provide to perform single configuration validation. + */ + public interface Validator { + /** + * Perform single configuration validation. + * @param name The name of the configuration + * @param value The value of the configuration + * @throws ConfigException if the value is invalid. + */ + void ensureValid(String name, Object value); + } + + /** + * Validation logic for numeric ranges + */ + public static class Range implements Validator { + private final Number min; + private final Number max; + + /** + * A numeric range with inclusive upper bound and inclusive lower bound + * @param min the lower bound + * @param max the upper bound + */ + private Range(Number min, Number max) { + this.min = min; + this.max = max; + } + + /** + * A numeric range that checks only the lower bound + * + * @param min The minimum acceptable value + */ + public static Range atLeast(Number min) { + return new Range(min, null); + } + + /** + * A numeric range that checks both the upper (inclusive) and lower bound + */ + public static Range between(Number min, Number max) { + return new Range(min, max); + } + + public void ensureValid(String name, Object o) { + if (o == null) + throw new ConfigException(name, null, "Value must be non-null"); + Number n = (Number) o; + if (min != null && n.doubleValue() < min.doubleValue()) + throw new ConfigException(name, o, "Value must be at least " + min); + if (max != null && n.doubleValue() > max.doubleValue()) + throw new ConfigException(name, o, "Value must be no more than " + max); + } + + public String toString() { + if (min == null && max == null) + return "[...]"; + else if (min == null) + return "[...," + max + "]"; + else if (max == null) + return "[" + min + ",...]"; + else + return "[" + min + ",...," + max + "]"; + } + } + + public static class ValidList implements Validator { + + final ValidString validString; + + private ValidList(List validStrings) { + this.validString = new ValidString(validStrings); + } + + public static ValidList in(String... validStrings) { + return new ValidList(Arrays.asList(validStrings)); + } + + @Override + public void ensureValid(final String name, final Object value) { + @SuppressWarnings("unchecked") + List values = (List) value; + for (String string : values) { + validString.ensureValid(name, string); + } + } + + public String toString() { + return validString.toString(); + } + } + + public static class ValidString implements Validator { + final List validStrings; + + private ValidString(List validStrings) { + this.validStrings = validStrings; + } + + public static ValidString in(String... 
validStrings) { + return new ValidString(Arrays.asList(validStrings)); + } + + @Override + public void ensureValid(String name, Object o) { + String s = (String) o; + if (!validStrings.contains(s)) { + throw new ConfigException(name, o, "String must be one of: " + Utils.join(validStrings, ", ")); + } + + } + + public String toString() { + return "[" + Utils.join(validStrings, ", ") + "]"; + } + } + + public static class CaseInsensitiveValidString implements Validator { + + final Set validStrings; + + private CaseInsensitiveValidString(List validStrings) { + this.validStrings = validStrings.stream() + .map(s -> s.toUpperCase(Locale.ROOT)) + .collect(Collectors.toSet()); + } + + public static CaseInsensitiveValidString in(String... validStrings) { + return new CaseInsensitiveValidString(Arrays.asList(validStrings)); + } + + @Override + public void ensureValid(String name, Object o) { + String s = (String) o; + if (s == null || !validStrings.contains(s.toUpperCase(Locale.ROOT))) { + throw new ConfigException(name, o, "String must be one of (case insensitive): " + Utils.join(validStrings, ", ")); + } + } + + public String toString() { + return "(case insensitive) [" + Utils.join(validStrings, ", ") + "]"; + } + } + + public static class NonNullValidator implements Validator { + @Override + public void ensureValid(String name, Object value) { + if (value == null) { + // Pass in the string null to avoid the spotbugs warning + throw new ConfigException(name, "null", "entry must be non null"); + } + } + + public String toString() { + return "non-null string"; + } + } + + public static class LambdaValidator implements Validator { + BiConsumer ensureValid; + Supplier toStringFunction; + + private LambdaValidator(BiConsumer ensureValid, + Supplier toStringFunction) { + this.ensureValid = ensureValid; + this.toStringFunction = toStringFunction; + } + + public static LambdaValidator with(BiConsumer ensureValid, + Supplier toStringFunction) { + return new LambdaValidator(ensureValid, toStringFunction); + } + + @Override + public void ensureValid(String name, Object value) { + ensureValid.accept(name, value); + } + + @Override + public String toString() { + return toStringFunction.get(); + } + } + + public static class CompositeValidator implements Validator { + private final List validators; + + private CompositeValidator(List validators) { + this.validators = Collections.unmodifiableList(validators); + } + + public static CompositeValidator of(Validator... 
validators) { + return new CompositeValidator(Arrays.asList(validators)); + } + + @Override + public void ensureValid(String name, Object value) { + for (Validator validator: validators) { + validator.ensureValid(name, value); + } + } + + @Override + public String toString() { + if (validators == null) return ""; + StringBuilder desc = new StringBuilder(); + for (Validator v: validators) { + if (desc.length() > 0) { + desc.append(',').append(' '); + } + desc.append(v); + } + return desc.toString(); + } + } + + public static class NonEmptyString implements Validator { + + @Override + public void ensureValid(String name, Object o) { + String s = (String) o; + if (s != null && s.isEmpty()) { + throw new ConfigException(name, o, "String must be non-empty"); + } + } + + @Override + public String toString() { + return "non-empty string"; + } + } + + public static class NonEmptyStringWithoutControlChars implements Validator { + + public static NonEmptyStringWithoutControlChars nonEmptyStringWithoutControlChars() { + return new NonEmptyStringWithoutControlChars(); + } + + @Override + public void ensureValid(String name, Object value) { + String s = (String) value; + + if (s == null) { + // This can happen during creation of the config object due to no default value being defined for the + // name configuration - a missing name parameter is caught when checking for mandatory parameters, + // thus we can ok a null value here + return; + } else if (s.isEmpty()) { + throw new ConfigException(name, value, "String may not be empty"); + } + + // Check name string for illegal characters + ArrayList foundIllegalCharacters = new ArrayList<>(); + + for (int i = 0; i < s.length(); i++) { + if (Character.isISOControl(s.codePointAt(i))) { + foundIllegalCharacters.add(s.codePointAt(i)); + } + } + + if (!foundIllegalCharacters.isEmpty()) { + throw new ConfigException(name, value, "String may not contain control sequences but had the following ASCII chars: " + Utils.join(foundIllegalCharacters, ", ")); + } + } + + public String toString() { + return "non-empty string without ISO control characters"; + } + } + + public static class ConfigKey { + public final String name; + public final Type type; + public final String documentation; + public final Object defaultValue; + public final Validator validator; + public final Importance importance; + public final String group; + public final int orderInGroup; + public final Width width; + public final String displayName; + public final List dependents; + public final Recommender recommender; + public final boolean internalConfig; + + public ConfigKey(String name, Type type, Object defaultValue, Validator validator, + Importance importance, String documentation, String group, + int orderInGroup, Width width, String displayName, + List dependents, Recommender recommender, + boolean internalConfig) { + this.name = name; + this.type = type; + this.defaultValue = NO_DEFAULT_VALUE.equals(defaultValue) ? 
NO_DEFAULT_VALUE : parseType(name, defaultValue, type); + this.validator = validator; + this.importance = importance; + if (this.validator != null && hasDefault()) + this.validator.ensureValid(name, this.defaultValue); + this.documentation = documentation; + this.dependents = dependents; + this.group = group; + this.orderInGroup = orderInGroup; + this.width = width; + this.displayName = displayName; + this.recommender = recommender; + this.internalConfig = internalConfig; + } + + public boolean hasDefault() { + return !NO_DEFAULT_VALUE.equals(this.defaultValue); + } + + public Type type() { + return type; + } + } + + protected List headers() { + return Arrays.asList("Name", "Description", "Type", "Default", "Valid Values", "Importance"); + } + + protected String getConfigValue(ConfigKey key, String headerName) { + switch (headerName) { + case "Name": + return key.name; + case "Description": + return key.documentation; + case "Type": + return key.type.toString().toLowerCase(Locale.ROOT); + case "Default": + if (key.hasDefault()) { + if (key.defaultValue == null) + return "null"; + String defaultValueStr = convertToString(key.defaultValue, key.type); + if (defaultValueStr.isEmpty()) + return "\"\""; + else { + String suffix = ""; + if (key.name.endsWith(".bytes")) { + suffix = niceMemoryUnits(((Number) key.defaultValue).longValue()); + } else if (key.name.endsWith(".ms")) { + suffix = niceTimeUnits(((Number) key.defaultValue).longValue()); + } + return defaultValueStr + suffix; + } + } else + return ""; + case "Valid Values": + return key.validator != null ? key.validator.toString() : ""; + case "Importance": + return key.importance.toString().toLowerCase(Locale.ROOT); + default: + throw new RuntimeException("Can't find value for header '" + headerName + "' in " + key.name); + } + } + + static String niceMemoryUnits(long bytes) { + long value = bytes; + int i = 0; + while (value != 0 && i < 4) { + if (value % 1024L == 0) { + value /= 1024L; + i++; + } else { + break; + } + } + switch (i) { + case 1: + return " (" + value + " kibibyte" + (value == 1 ? ")" : "s)"); + case 2: + return " (" + value + " mebibyte" + (value == 1 ? ")" : "s)"); + case 3: + return " (" + value + " gibibyte" + (value == 1 ? ")" : "s)"); + case 4: + return " (" + value + " tebibyte" + (value == 1 ? ")" : "s)"); + default: + return ""; + } + } + + static String niceTimeUnits(long millis) { + long value = millis; + long[] divisors = {1000, 60, 60, 24}; + String[] units = {"second", "minute", "hour", "day"}; + int i = 0; + while (value != 0 && i < 4) { + if (value % divisors[i] == 0) { + value /= divisors[i]; + i++; + } else { + break; + } + } + if (i > 0) { + return " (" + value + " " + units[i - 1] + (value > 1 ? "s)" : ")"); + } + return ""; + } + + public String toHtmlTable() { + return toHtmlTable(Collections.emptyMap()); + } + + private void addHeader(StringBuilder builder, String headerName) { + builder.append(""); + builder.append(headerName); + builder.append("\n"); + } + + private void addColumnValue(StringBuilder builder, String value) { + builder.append(""); + builder.append(value); + builder.append(""); + } + + /** + * Converts this config into an HTML table that can be embedded into docs. + * If dynamicUpdateModes is non-empty, a "Dynamic Update Mode" column + * will be included n the table with the value of the update mode. Default + * mode is "read-only". 
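
For example, a documentation generator might mark selected configs as dynamically updatable (the config name and update mode below are placeholders); anything missing from the map is rendered as read-only:

    Map<String, String> updateModes = new HashMap<>();
    updateModes.put("cache.size.bytes", "cluster-wide");

    String tableHtml = def.toHtmlTable(updateModes);
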
+ * @param dynamicUpdateModes Config name -> update mode mapping + */ + public String toHtmlTable(Map dynamicUpdateModes) { + boolean hasUpdateModes = !dynamicUpdateModes.isEmpty(); + List configs = sortedConfigs(); + StringBuilder b = new StringBuilder(); + b.append("\n"); + b.append("\n"); + // print column headers + for (String headerName : headers()) { + addHeader(b, headerName); + } + if (hasUpdateModes) + addHeader(b, "Dynamic Update Mode"); + b.append("\n"); + for (ConfigKey key : configs) { + if (key.internalConfig) { + continue; + } + b.append("\n"); + // print column values + for (String headerName : headers()) { + addColumnValue(b, getConfigValue(key, headerName)); + b.append(""); + } + if (hasUpdateModes) { + String updateMode = dynamicUpdateModes.get(key.name); + if (updateMode == null) + updateMode = "read-only"; + addColumnValue(b, updateMode); + } + b.append("\n"); + } + b.append("
        "); + return b.toString(); + } + + /** + * Get the configs formatted with reStructuredText, suitable for embedding in Sphinx + * documentation. + */ + public String toRst() { + StringBuilder b = new StringBuilder(); + for (ConfigKey key : sortedConfigs()) { + if (key.internalConfig) { + continue; + } + getConfigKeyRst(key, b); + b.append("\n"); + } + return b.toString(); + } + + /** + * Configs with new metadata (group, orderInGroup, dependents) formatted with reStructuredText, suitable for embedding in Sphinx + * documentation. + */ + public String toEnrichedRst() { + StringBuilder b = new StringBuilder(); + + String lastKeyGroupName = ""; + for (ConfigKey key : sortedConfigs()) { + if (key.internalConfig) { + continue; + } + if (key.group != null) { + if (!lastKeyGroupName.equalsIgnoreCase(key.group)) { + b.append(key.group).append("\n"); + + char[] underLine = new char[key.group.length()]; + Arrays.fill(underLine, '^'); + b.append(new String(underLine)).append("\n\n"); + } + lastKeyGroupName = key.group; + } + + getConfigKeyRst(key, b); + + if (key.dependents != null && key.dependents.size() > 0) { + int j = 0; + b.append(" * Dependents: "); + for (String dependent : key.dependents) { + b.append("``"); + b.append(dependent); + if (++j == key.dependents.size()) + b.append("``"); + else + b.append("``, "); + } + b.append("\n"); + } + b.append("\n"); + } + return b.toString(); + } + + /** + * Shared content on Rst and Enriched Rst. + */ + private void getConfigKeyRst(ConfigKey key, StringBuilder b) { + b.append("``").append(key.name).append("``").append("\n"); + if (key.documentation != null) { + for (String docLine : key.documentation.split("\n")) { + if (docLine.length() == 0) { + continue; + } + b.append(" ").append(docLine).append("\n\n"); + } + } else { + b.append("\n"); + } + b.append(" * Type: ").append(getConfigValue(key, "Type")).append("\n"); + if (key.hasDefault()) { + b.append(" * Default: ").append(getConfigValue(key, "Default")).append("\n"); + } + if (key.validator != null) { + b.append(" * Valid Values: ").append(getConfigValue(key, "Valid Values")).append("\n"); + } + b.append(" * Importance: ").append(getConfigValue(key, "Importance")).append("\n"); + } + + /** + * Get a list of configs sorted taking the 'group' and 'orderInGroup' into account. + * + * If grouping is not specified, the result will reflect "natural" order: listing required fields first, then ordering by importance, and finally by name. + */ + private List sortedConfigs() { + final Map groupOrd = new HashMap<>(groups.size()); + int ord = 0; + for (String group: groups) { + groupOrd.put(group, ord++); + } + + List configs = new ArrayList<>(configKeys.values()); + Collections.sort(configs, (k1, k2) -> compare(k1, k2, groupOrd)); + return configs; + } + + private int compare(ConfigKey k1, ConfigKey k2, Map groupOrd) { + int cmp = k1.group == null + ? (k2.group == null ? 0 : -1) + : (k2.group == null ? 
1 : Integer.compare(groupOrd.get(k1.group), groupOrd.get(k2.group))); + if (cmp == 0) { + cmp = Integer.compare(k1.orderInGroup, k2.orderInGroup); + if (cmp == 0) { + // first take anything with no default value + if (!k1.hasDefault() && k2.hasDefault()) + cmp = -1; + else if (!k2.hasDefault() && k1.hasDefault()) + cmp = 1; + else { + cmp = k1.importance.compareTo(k2.importance); + if (cmp == 0) + return k1.name.compareTo(k2.name); + } + } + } + return cmp; + } + + public void embed(final String keyPrefix, final String groupPrefix, final int startingOrd, final ConfigDef child) { + int orderInGroup = startingOrd; + for (ConfigKey key : child.sortedConfigs()) { + define(new ConfigKey( + keyPrefix + key.name, + key.type, + key.defaultValue, + embeddedValidator(keyPrefix, key.validator), + key.importance, + key.documentation, + groupPrefix + (key.group == null ? "" : ": " + key.group), + orderInGroup++, + key.width, + key.displayName, + embeddedDependents(keyPrefix, key.dependents), + embeddedRecommender(keyPrefix, key.recommender), + key.internalConfig)); + } + } + + /** + * Returns a new validator instance that delegates to the base validator but unprefixes the config name along the way. + */ + private static Validator embeddedValidator(final String keyPrefix, final Validator base) { + if (base == null) return null; + return new Validator() { + public void ensureValid(String name, Object value) { + base.ensureValid(name.substring(keyPrefix.length()), value); + } + + @Override + public String toString() { + return base.toString(); + } + }; + } + + /** + * Updated list of dependent configs with the specified {@code prefix} added. + */ + private static List embeddedDependents(final String keyPrefix, final List dependents) { + if (dependents == null) return null; + final List updatedDependents = new ArrayList<>(dependents.size()); + for (String dependent : dependents) { + updatedDependents.add(keyPrefix + dependent); + } + return updatedDependents; + } + + /** + * Returns a new recommender instance that delegates to the base recommender but unprefixes the input parameters along the way. + */ + private static Recommender embeddedRecommender(final String keyPrefix, final Recommender base) { + if (base == null) return null; + return new Recommender() { + private String unprefixed(String k) { + return k.substring(keyPrefix.length()); + } + + private Map unprefixed(Map parsedConfig) { + final Map unprefixedParsedConfig = new HashMap<>(parsedConfig.size()); + for (Map.Entry e : parsedConfig.entrySet()) { + if (e.getKey().startsWith(keyPrefix)) { + unprefixedParsedConfig.put(unprefixed(e.getKey()), e.getValue()); + } + } + return unprefixedParsedConfig; + } + + @Override + public List validValues(String name, Map parsedConfig) { + return base.validValues(unprefixed(name), unprefixed(parsedConfig)); + } + + @Override + public boolean visible(String name, Map parsedConfig) { + return base.visible(unprefixed(name), unprefixed(parsedConfig)); + } + }; + } + + public String toHtml() { + return toHtml(Collections.emptyMap()); + } + + /** + * Converts this config into an HTML list that can be embedded into docs. + * @param headerDepth The top level header depth in the generated HTML. + * @param idGenerator A function for computing the HTML id attribute in the generated HTML from a given config name. 
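
The embed(...) helpers above re-register every config of a child ConfigDef under a key prefix, while the wrapped validator, dependents and recommender keep seeing unprefixed names. A hedged sketch with made-up names:

    ConfigDef child = new ConfigDef()
        .define("timeout.ms", ConfigDef.Type.INT, 1000, ConfigDef.Importance.LOW, "Request timeout.");

    ConfigDef parent = new ConfigDef();
    // child's "timeout.ms" becomes "producer.timeout.ms", grouped under "Producer"
    parent.embed("producer.", "Producer", 0, child);
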
+ */ + public String toHtml(int headerDepth, Function idGenerator) { + return toHtml(headerDepth, idGenerator, Collections.emptyMap()); + } + + /** + * Converts this config into an HTML list that can be embedded into docs. + * If dynamicUpdateModes is non-empty, a "Dynamic Update Mode" label + * will be included in the config details with the value of the update mode. Default + * mode is "read-only". + * @param dynamicUpdateModes Config name -> update mode mapping. + */ + public String toHtml(Map dynamicUpdateModes) { + return toHtml(4, Function.identity(), dynamicUpdateModes); + } + + /** + * Converts this config into an HTML list that can be embedded into docs. + * If dynamicUpdateModes is non-empty, a "Dynamic Update Mode" label + * will be included in the config details with the value of the update mode. Default + * mode is "read-only". + * @param headerDepth The top level header depth in the generated HTML. + * @param idGenerator A function for computing the HTML id attribute in the generated HTML from a given config name. + * @param dynamicUpdateModes Config name -> update mode mapping. + */ + public String toHtml(int headerDepth, Function idGenerator, + Map dynamicUpdateModes) { + boolean hasUpdateModes = !dynamicUpdateModes.isEmpty(); + List configs = sortedConfigs(); + StringBuilder b = new StringBuilder(); + b.append("
<ul class=\"config-list\">\n");
+        for (ConfigKey key : configs) {
+            if (key.internalConfig) {
+                continue;
+            }
+            b.append("<li>\n");
+            b.append(String.format("<h%1$d>" +
+                    "<a id=\"%2$s\"></a><a href=\"#%2$s\">%3$s</a>" +
+                    "</h%1$d>%n", headerDepth, idGenerator.apply(key.name), key.name));
+            b.append("<p>");
+            if (key.documentation != null) {
+                b.append(key.documentation.replaceAll("\n", "<br>"));
+            }
+            b.append("</p>\n");
+
+            b.append("<table>" +
+                    "<tbody>\n");
+            for (String detail : headers()) {
+                if (detail.equals("Name") || detail.equals("Description")) continue;
+                addConfigDetail(b, detail, getConfigValue(key, detail));
+            }
+            if (hasUpdateModes) {
+                String updateMode = dynamicUpdateModes.get(key.name);
+                if (updateMode == null)
+                    updateMode = "read-only";
+                addConfigDetail(b, "Update Mode", updateMode);
+            }
+            b.append("</tbody></table>\n");
+            b.append("</li>\n");
+        }
+        b.append("</ul>
        \n"); + return b.toString(); + } + + private static void addConfigDetail(StringBuilder builder, String name, String value) { + builder.append("" + + "" + name + ":" + + "" + value + "" + + "\n"); + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/config/ConfigException.java b/clients/src/main/java/org/apache/kafka/common/config/ConfigException.java new file mode 100644 index 0000000..c48bfc6 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/config/ConfigException.java @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.config; + +import org.apache.kafka.common.KafkaException; + +/** + * Thrown if the user supplies an invalid configuration + */ +public class ConfigException extends KafkaException { + + private static final long serialVersionUID = 1L; + + public ConfigException(String message) { + super(message); + } + + public ConfigException(String name, Object value) { + this(name, value, null); + } + + public ConfigException(String name, Object value, String message) { + super("Invalid value " + value + " for configuration " + name + (message == null ? "" : ": " + message)); + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/config/ConfigResource.java b/clients/src/main/java/org/apache/kafka/common/config/ConfigResource.java new file mode 100644 index 0000000..8870238 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/config/ConfigResource.java @@ -0,0 +1,118 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.common.config; + +import java.util.Arrays; +import java.util.Collections; +import java.util.Map; +import java.util.Objects; +import java.util.function.Function; +import java.util.stream.Collectors; + +/** + * A class representing resources that have configs. + */ +public final class ConfigResource { + + /** + * Type of resource. 
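
A short usage sketch for ConfigResource as defined below (the topic name is a placeholder): a resource is identified by a type plus a name, an empty name denotes the default resource, and the byte ids can be mapped back to a type via Type.forId.

    ConfigResource topicConfigs = new ConfigResource(ConfigResource.Type.TOPIC, "orders");
    boolean isDefault = topicConfigs.isDefault();                        // false: the name is non-empty

    ConfigResource.Type type = ConfigResource.Type.forId((byte) 2);      // TOPIC
    ConfigResource.Type unknown = ConfigResource.Type.forId((byte) 99);  // UNKNOWN
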
+ */ + public enum Type { + BROKER_LOGGER((byte) 8), BROKER((byte) 4), TOPIC((byte) 2), UNKNOWN((byte) 0); + + private static final Map TYPES = Collections.unmodifiableMap( + Arrays.stream(values()).collect(Collectors.toMap(Type::id, Function.identity())) + ); + + private final byte id; + + Type(final byte id) { + this.id = id; + } + + public byte id() { + return id; + } + + public static Type forId(final byte id) { + return TYPES.getOrDefault(id, UNKNOWN); + } + } + + private final Type type; + private final String name; + + /** + * Create an instance of this class with the provided parameters. + * + * @param type a non-null resource type + * @param name a non-null resource name + */ + public ConfigResource(Type type, String name) { + Objects.requireNonNull(type, "type should not be null"); + Objects.requireNonNull(name, "name should not be null"); + this.type = type; + this.name = name; + } + + /** + * Return the resource type. + */ + public Type type() { + return type; + } + + /** + * Return the resource name. + */ + public String name() { + return name; + } + + /** + * Returns true if this is the default resource of a resource type. + * Resource name is empty for the default resource. + */ + public boolean isDefault() { + return name.isEmpty(); + } + + @Override + public boolean equals(Object o) { + if (this == o) + return true; + if (o == null || getClass() != o.getClass()) + return false; + + ConfigResource that = (ConfigResource) o; + + return type == that.type && name.equals(that.name); + } + + @Override + public int hashCode() { + int result = type.hashCode(); + result = 31 * result + name.hashCode(); + return result; + } + + @Override + public String toString() { + return "ConfigResource(type=" + type + ", name='" + name + "')"; + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/config/ConfigTransformer.java b/clients/src/main/java/org/apache/kafka/common/config/ConfigTransformer.java new file mode 100644 index 0000000..4f078b1 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/config/ConfigTransformer.java @@ -0,0 +1,177 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.config; + +import org.apache.kafka.common.config.provider.ConfigProvider; +import org.apache.kafka.common.config.provider.FileConfigProvider; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +/** + * This class wraps a set of {@link ConfigProvider} instances and uses them to perform + * transformations. + * + *
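A minimal usage sketch of the ConfigResource class defined above (illustrative only, not part of the patch; the topic name "orders" and broker id "0" are made-up examples):

import org.apache.kafka.common.config.ConfigResource;

public class ConfigResourceExample {
    public static void main(String[] args) {
        // Identify the configs of a specific topic and of a specific broker.
        ConfigResource topic = new ConfigResource(ConfigResource.Type.TOPIC, "orders");
        ConfigResource broker = new ConfigResource(ConfigResource.Type.BROKER, "0");

        // An empty name denotes the default resource of that type.
        ConfigResource brokerDefaults = new ConfigResource(ConfigResource.Type.BROKER, "");
        System.out.println(brokerDefaults.isDefault());          // true

        // Byte ids round-trip through forId(); unknown ids map to UNKNOWN.
        System.out.println(ConfigResource.Type.forId((byte) 2)); // TOPIC
        System.out.println(topic + ", " + broker);
    }
}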

        The default variable pattern is of the form ${provider:[path:]key}, + * where the provider corresponds to a {@link ConfigProvider} instance, as passed to + * {@link ConfigTransformer#ConfigTransformer(Map)}. The pattern will extract a set + * of paths (which are optional) and keys and then pass them to {@link ConfigProvider#get(String, Set)} to obtain the + * values with which to replace the variables. + * + *

        For example, if a Map consisting of an entry with a provider name "file" and provider instance + * {@link FileConfigProvider} is passed to the {@link ConfigTransformer#ConfigTransformer(Map)}, and a Properties + * file with contents + *
        <pre>
        + * fileKey=someValue
        + * </pre>
        + * resides at the path "/tmp/properties.txt", then when a configuration Map which has an entry with a key "someKey" and + * a value "${file:/tmp/properties.txt:fileKey}" is passed to the {@link #transform(Map)} method, then the transformed + * Map will have an entry with key "someKey" and a value "someValue". + * + *
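A compilable sketch of that example (illustrative only; it assumes a file at /tmp/properties.txt containing the line fileKey=someValue, exactly as described above):

import java.util.Collections;
import java.util.Map;
import org.apache.kafka.common.config.ConfigTransformer;
import org.apache.kafka.common.config.ConfigTransformerResult;
import org.apache.kafka.common.config.provider.ConfigProvider;
import org.apache.kafka.common.config.provider.FileConfigProvider;

public class ConfigTransformerExample {
    public static void main(String[] args) {
        // Register a FileConfigProvider under the provider name "file".
        Map<String, ConfigProvider> providers =
                Collections.singletonMap("file", new FileConfigProvider());
        ConfigTransformer transformer = new ConfigTransformer(providers);

        // A config value containing a ${provider:path:key} variable to be resolved.
        Map<String, String> raw =
                Collections.singletonMap("someKey", "${file:/tmp/properties.txt:fileKey}");
        ConfigTransformerResult result = transformer.transform(raw);

        System.out.println(result.data().get("someKey")); // prints "someValue"
        System.out.println(result.ttls());                // per-path TTLs, if the provider reported any
    }
}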

        This class only depends on {@link ConfigProvider#get(String, Set)} and does not depend on subscription support + * in a {@link ConfigProvider}, such as the {@link ConfigProvider#subscribe(String, Set, ConfigChangeCallback)} and + * {@link ConfigProvider#unsubscribe(String, Set, ConfigChangeCallback)} methods. + */ +public class ConfigTransformer { + public static final Pattern DEFAULT_PATTERN = Pattern.compile("\\$\\{([^}]*?):(([^}]*?):)?([^}]*?)\\}"); + private static final String EMPTY_PATH = ""; + + private final Map configProviders; + + /** + * Creates a ConfigTransformer with the default pattern, of the form ${provider:[path:]key}. + * + * @param configProviders a Map of provider names and {@link ConfigProvider} instances. + */ + public ConfigTransformer(Map configProviders) { + this.configProviders = configProviders; + } + + /** + * Transforms the given configuration data by using the {@link ConfigProvider} instances to + * look up values to replace the variables in the pattern. + * + * @param configs the configuration values to be transformed + * @return an instance of {@link ConfigTransformerResult} + */ + public ConfigTransformerResult transform(Map configs) { + Map>> keysByProvider = new HashMap<>(); + Map>> lookupsByProvider = new HashMap<>(); + + // Collect the variables from the given configs that need transformation + for (Map.Entry config : configs.entrySet()) { + if (config.getValue() != null) { + List vars = getVars(config.getValue(), DEFAULT_PATTERN); + for (ConfigVariable var : vars) { + Map> keysByPath = keysByProvider.computeIfAbsent(var.providerName, k -> new HashMap<>()); + Set keys = keysByPath.computeIfAbsent(var.path, k -> new HashSet<>()); + keys.add(var.variable); + } + } + } + + // Retrieve requested variables from the ConfigProviders + Map ttls = new HashMap<>(); + for (Map.Entry>> entry : keysByProvider.entrySet()) { + String providerName = entry.getKey(); + ConfigProvider provider = configProviders.get(providerName); + Map> keysByPath = entry.getValue(); + if (provider != null && keysByPath != null) { + for (Map.Entry> pathWithKeys : keysByPath.entrySet()) { + String path = pathWithKeys.getKey(); + Set keys = new HashSet<>(pathWithKeys.getValue()); + ConfigData configData = provider.get(path, keys); + Map data = configData.data(); + Long ttl = configData.ttl(); + if (ttl != null && ttl >= 0) { + ttls.put(path, ttl); + } + Map> keyValuesByPath = + lookupsByProvider.computeIfAbsent(providerName, k -> new HashMap<>()); + keyValuesByPath.put(path, data); + } + } + } + + // Perform the transformations by performing variable replacements + Map data = new HashMap<>(configs); + for (Map.Entry config : configs.entrySet()) { + data.put(config.getKey(), replace(lookupsByProvider, config.getValue(), DEFAULT_PATTERN)); + } + return new ConfigTransformerResult(data, ttls); + } + + private static List getVars(String value, Pattern pattern) { + List configVars = new ArrayList<>(); + Matcher matcher = pattern.matcher(value); + while (matcher.find()) { + configVars.add(new ConfigVariable(matcher)); + } + return configVars; + } + + private static String replace(Map>> lookupsByProvider, + String value, + Pattern pattern) { + if (value == null) { + return null; + } + Matcher matcher = pattern.matcher(value); + StringBuilder builder = new StringBuilder(); + int i = 0; + while (matcher.find()) { + ConfigVariable configVar = new ConfigVariable(matcher); + Map> lookupsByPath = lookupsByProvider.get(configVar.providerName); + if (lookupsByPath != null) { + Map keyValues = 
lookupsByPath.get(configVar.path); + String replacement = keyValues.get(configVar.variable); + builder.append(value, i, matcher.start()); + if (replacement == null) { + // No replacements will be performed; just return the original value + builder.append(matcher.group(0)); + } else { + builder.append(replacement); + } + i = matcher.end(); + } + } + builder.append(value, i, value.length()); + return builder.toString(); + } + + private static class ConfigVariable { + final String providerName; + final String path; + final String variable; + + ConfigVariable(Matcher matcher) { + this.providerName = matcher.group(1); + this.path = matcher.group(3) != null ? matcher.group(3) : EMPTY_PATH; + this.variable = matcher.group(4); + } + + public String toString() { + return "(" + providerName + ":" + (path != null ? path + ":" : "") + variable + ")"; + } + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/config/ConfigTransformerResult.java b/clients/src/main/java/org/apache/kafka/common/config/ConfigTransformerResult.java new file mode 100644 index 0000000..a05669c --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/config/ConfigTransformerResult.java @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.config; + +import org.apache.kafka.common.config.provider.ConfigProvider; + +import java.util.Map; + +/** + * The result of a transformation from {@link ConfigTransformer}. + */ +public class ConfigTransformerResult { + + private Map ttls; + private Map data; + + /** + * Creates a new ConfigTransformerResult with the given data and TTL values for a set of paths. + * + * @param data a Map of key-value pairs + * @param ttls a Map of path and TTL values (in milliseconds) + */ + public ConfigTransformerResult(Map data, Map ttls) { + this.data = data; + this.ttls = ttls; + } + + /** + * Returns the transformed data, with variables replaced with corresponding values from the + * ConfigProvider instances if found. + * + *

        Modifying the transformed data that is returned does not affect the {@link ConfigProvider} nor the + * original data that was used as the source of the transformation. + * + * @return data a Map of key-value pairs + */ + public Map data() { + return data; + } + + /** + * Returns the TTL values (in milliseconds) returned from the ConfigProvider instances for a given set of paths. + * + * @return data a Map of path and TTL values + */ + public Map ttls() { + return ttls; + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/config/ConfigValue.java b/clients/src/main/java/org/apache/kafka/common/config/ConfigValue.java new file mode 100644 index 0000000..dafd7c6 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/config/ConfigValue.java @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.config; + +import java.util.ArrayList; +import java.util.List; +import java.util.Objects; + +public class ConfigValue { + + private final String name; + private Object value; + private List recommendedValues; + private final List errorMessages; + private boolean visible; + + public ConfigValue(String name) { + this(name, null, new ArrayList<>(), new ArrayList()); + } + + public ConfigValue(String name, Object value, List recommendedValues, List errorMessages) { + this.name = name; + this.value = value; + this.recommendedValues = recommendedValues; + this.errorMessages = errorMessages; + this.visible = true; + } + + public String name() { + return name; + } + + public Object value() { + return value; + } + + public List recommendedValues() { + return recommendedValues; + } + + public List errorMessages() { + return errorMessages; + } + + public boolean visible() { + return visible; + } + + public void value(Object value) { + this.value = value; + } + + public void recommendedValues(List recommendedValues) { + this.recommendedValues = recommendedValues; + } + + public void addErrorMessage(String errorMessage) { + this.errorMessages.add(errorMessage); + } + + public void visible(boolean visible) { + this.visible = visible; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + ConfigValue that = (ConfigValue) o; + return Objects.equals(name, that.name) && + Objects.equals(value, that.value) && + Objects.equals(recommendedValues, that.recommendedValues) && + Objects.equals(errorMessages, that.errorMessages) && + Objects.equals(visible, that.visible); + } + + @Override + public int hashCode() { + return Objects.hash(name, value, recommendedValues, errorMessages, visible); + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append("[") + .append(name) + 
.append(",") + .append(value) + .append(",") + .append(recommendedValues) + .append(",") + .append(errorMessages) + .append(",") + .append(visible) + .append("]"); + return sb.toString(); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/config/LogLevelConfig.java b/clients/src/main/java/org/apache/kafka/common/config/LogLevelConfig.java new file mode 100644 index 0000000..fe7e2eb --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/config/LogLevelConfig.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.common.config; + +import java.util.Arrays; +import java.util.HashSet; +import java.util.Set; + +/** + * This class holds definitions for log level configurations related to Kafka's application logging. See KIP-412 for additional information + */ +public class LogLevelConfig { + /* + * NOTE: DO NOT CHANGE EITHER CONFIG NAMES AS THESE ARE PART OF THE PUBLIC API AND CHANGE WILL BREAK USER CODE. + */ + + /** + * The FATAL level designates a very severe error + * that will lead the Kafka broker to abort. + */ + public static final String FATAL_LOG_LEVEL = "FATAL"; + + /** + * The ERROR level designates error events that + * might still allow the broker to continue running. + */ + public static final String ERROR_LOG_LEVEL = "ERROR"; + + /** + * The WARN level designates potentially harmful situations. + */ + public static final String WARN_LOG_LEVEL = "WARN"; + + /** + * The INFO level designates informational messages + * that highlight normal Kafka events at a coarse-grained level + */ + public static final String INFO_LOG_LEVEL = "INFO"; + + /** + * The DEBUG level designates fine-grained + * informational events that are most useful to debug Kafka + */ + public static final String DEBUG_LOG_LEVEL = "DEBUG"; + + /** + * The TRACE level designates finer-grained + * informational events than the DEBUG level. + */ + public static final String TRACE_LOG_LEVEL = "TRACE"; + + public static final Set VALID_LOG_LEVELS = new HashSet<>(Arrays.asList( + FATAL_LOG_LEVEL, ERROR_LOG_LEVEL, WARN_LOG_LEVEL, + INFO_LOG_LEVEL, DEBUG_LOG_LEVEL, TRACE_LOG_LEVEL + )); +} diff --git a/clients/src/main/java/org/apache/kafka/common/config/SaslConfigs.java b/clients/src/main/java/org/apache/kafka/common/config/SaslConfigs.java new file mode 100644 index 0000000..a897b5e --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/config/SaslConfigs.java @@ -0,0 +1,225 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.config; + +import org.apache.kafka.common.config.ConfigDef.Range; + +public class SaslConfigs { + + private static final String OAUTHBEARER_NOTE = " Currently applies only to OAUTHBEARER."; + + /* + * NOTE: DO NOT CHANGE EITHER CONFIG NAMES AS THESE ARE PART OF THE PUBLIC API AND CHANGE WILL BREAK USER CODE. + */ + /** SASL mechanism configuration - standard mechanism names are listed here. */ + public static final String SASL_MECHANISM = "sasl.mechanism"; + public static final String SASL_MECHANISM_DOC = "SASL mechanism used for client connections. This may be any mechanism for which a security provider is available. GSSAPI is the default mechanism."; + public static final String GSSAPI_MECHANISM = "GSSAPI"; + public static final String DEFAULT_SASL_MECHANISM = GSSAPI_MECHANISM; + + public static final String SASL_JAAS_CONFIG = "sasl.jaas.config"; + public static final String SASL_JAAS_CONFIG_DOC = "JAAS login context parameters for SASL connections in the format used by JAAS configuration files. " + + "JAAS configuration file format is described here. " + + "The format for the value is: loginModuleClass controlFlag (optionName=optionValue)*;. For brokers, " + + "the config must be prefixed with listener prefix and SASL mechanism name in lower-case. For example, " + + "listener.name.sasl_ssl.scram-sha-256.sasl.jaas.config=com.example.ScramLoginModule required;"; + + public static final String SASL_CLIENT_CALLBACK_HANDLER_CLASS = "sasl.client.callback.handler.class"; + public static final String SASL_CLIENT_CALLBACK_HANDLER_CLASS_DOC = "The fully qualified name of a SASL client callback handler class " + + "that implements the AuthenticateCallbackHandler interface."; + + public static final String SASL_LOGIN_CALLBACK_HANDLER_CLASS = "sasl.login.callback.handler.class"; + public static final String SASL_LOGIN_CALLBACK_HANDLER_CLASS_DOC = "The fully qualified name of a SASL login callback handler class " + + "that implements the AuthenticateCallbackHandler interface. For brokers, login callback handler config must be prefixed with " + + "listener prefix and SASL mechanism name in lower-case. For example, " + + "listener.name.sasl_ssl.scram-sha-256.sasl.login.callback.handler.class=com.example.CustomScramLoginCallbackHandler"; + + public static final String SASL_LOGIN_CLASS = "sasl.login.class"; + public static final String SASL_LOGIN_CLASS_DOC = "The fully qualified name of a class that implements the Login interface. " + + "For brokers, login config must be prefixed with listener prefix and SASL mechanism name in lower-case. For example, " + + "listener.name.sasl_ssl.scram-sha-256.sasl.login.class=com.example.CustomScramLogin"; + + public static final String SASL_KERBEROS_SERVICE_NAME = "sasl.kerberos.service.name"; + public static final String SASL_KERBEROS_SERVICE_NAME_DOC = "The Kerberos principal name that Kafka runs as. 
" + + "This can be defined either in Kafka's JAAS config or in Kafka's config."; + + public static final String SASL_KERBEROS_KINIT_CMD = "sasl.kerberos.kinit.cmd"; + public static final String SASL_KERBEROS_KINIT_CMD_DOC = "Kerberos kinit command path."; + public static final String DEFAULT_KERBEROS_KINIT_CMD = "/usr/bin/kinit"; + + public static final String SASL_KERBEROS_TICKET_RENEW_WINDOW_FACTOR = "sasl.kerberos.ticket.renew.window.factor"; + public static final String SASL_KERBEROS_TICKET_RENEW_WINDOW_FACTOR_DOC = "Login thread will sleep until the specified window factor of time from last refresh" + + " to ticket's expiry has been reached, at which time it will try to renew the ticket."; + public static final double DEFAULT_KERBEROS_TICKET_RENEW_WINDOW_FACTOR = 0.80; + + public static final String SASL_KERBEROS_TICKET_RENEW_JITTER = "sasl.kerberos.ticket.renew.jitter"; + public static final String SASL_KERBEROS_TICKET_RENEW_JITTER_DOC = "Percentage of random jitter added to the renewal time."; + public static final double DEFAULT_KERBEROS_TICKET_RENEW_JITTER = 0.05; + + public static final String SASL_KERBEROS_MIN_TIME_BEFORE_RELOGIN = "sasl.kerberos.min.time.before.relogin"; + public static final String SASL_KERBEROS_MIN_TIME_BEFORE_RELOGIN_DOC = "Login thread sleep time between refresh attempts."; + public static final long DEFAULT_KERBEROS_MIN_TIME_BEFORE_RELOGIN = 1 * 60 * 1000L; + + public static final String SASL_LOGIN_REFRESH_WINDOW_FACTOR = "sasl.login.refresh.window.factor"; + public static final String SASL_LOGIN_REFRESH_WINDOW_FACTOR_DOC = "Login refresh thread will sleep until the specified window factor relative to the" + + " credential's lifetime has been reached, at which time it will try to refresh the credential." + + " Legal values are between 0.5 (50%) and 1.0 (100%) inclusive; a default value of 0.8 (80%) is used" + + " if no value is specified." + + OAUTHBEARER_NOTE; + public static final double DEFAULT_LOGIN_REFRESH_WINDOW_FACTOR = 0.80; + + public static final String SASL_LOGIN_REFRESH_WINDOW_JITTER = "sasl.login.refresh.window.jitter"; + public static final String SASL_LOGIN_REFRESH_WINDOW_JITTER_DOC = "The maximum amount of random jitter relative to the credential's lifetime" + + " that is added to the login refresh thread's sleep time. Legal values are between 0 and 0.25 (25%) inclusive;" + + " a default value of 0.05 (5%) is used if no value is specified." + + OAUTHBEARER_NOTE; + public static final double DEFAULT_LOGIN_REFRESH_WINDOW_JITTER = 0.05; + + public static final String SASL_LOGIN_REFRESH_MIN_PERIOD_SECONDS = "sasl.login.refresh.min.period.seconds"; + public static final String SASL_LOGIN_REFRESH_MIN_PERIOD_SECONDS_DOC = "The desired minimum time for the login refresh thread to wait before refreshing a credential," + + " in seconds. Legal values are between 0 and 900 (15 minutes); a default value of 60 (1 minute) is used if no value is specified. This value and " + + " sasl.login.refresh.buffer.seconds are both ignored if their sum exceeds the remaining lifetime of a credential." + + OAUTHBEARER_NOTE; + public static final short DEFAULT_LOGIN_REFRESH_MIN_PERIOD_SECONDS = 60; + + public static final String SASL_LOGIN_REFRESH_BUFFER_SECONDS = "sasl.login.refresh.buffer.seconds"; + public static final String SASL_LOGIN_REFRESH_BUFFER_SECONDS_DOC = "The amount of buffer time before credential expiration to maintain when refreshing a credential," + + " in seconds. 
If a refresh would otherwise occur closer to expiration than the number of buffer seconds then the refresh will be moved up to maintain" + + " as much of the buffer time as possible. Legal values are between 0 and 3600 (1 hour); a default value of 300 (5 minutes) is used if no value is specified." + + " This value and sasl.login.refresh.min.period.seconds are both ignored if their sum exceeds the remaining lifetime of a credential." + + OAUTHBEARER_NOTE; + public static final short DEFAULT_LOGIN_REFRESH_BUFFER_SECONDS = 300; + + public static final String SASL_LOGIN_CONNECT_TIMEOUT_MS = "sasl.login.connect.timeout.ms"; + public static final String SASL_LOGIN_CONNECT_TIMEOUT_MS_DOC = "The (optional) value in milliseconds for the external authentication provider connection timeout." + + OAUTHBEARER_NOTE; + + public static final String SASL_LOGIN_READ_TIMEOUT_MS = "sasl.login.read.timeout.ms"; + public static final String SASL_LOGIN_READ_TIMEOUT_MS_DOC = "The (optional) value in milliseconds for the external authentication provider read timeout." + + OAUTHBEARER_NOTE; + + private static final String LOGIN_EXPONENTIAL_BACKOFF_NOTE = " Login uses an exponential backoff algorithm with an initial wait based on the" + + " sasl.login.retry.backoff.ms setting and will double in wait length between attempts up to a maximum wait length specified by the" + + " sasl.login.retry.backoff.max.ms setting." + + OAUTHBEARER_NOTE; + + public static final String SASL_LOGIN_RETRY_BACKOFF_MAX_MS = "sasl.login.retry.backoff.max.ms"; + public static final long DEFAULT_SASL_LOGIN_RETRY_BACKOFF_MAX_MS = 10000; + public static final String SASL_LOGIN_RETRY_BACKOFF_MAX_MS_DOC = "The (optional) value in milliseconds for the maximum wait between login attempts to the" + + " external authentication provider." + + LOGIN_EXPONENTIAL_BACKOFF_NOTE; + + public static final String SASL_LOGIN_RETRY_BACKOFF_MS = "sasl.login.retry.backoff.ms"; + public static final long DEFAULT_SASL_LOGIN_RETRY_BACKOFF_MS = 100; + public static final String SASL_LOGIN_RETRY_BACKOFF_MS_DOC = "The (optional) value in milliseconds for the initial wait between login attempts to the external" + + " authentication provider." 
+ + LOGIN_EXPONENTIAL_BACKOFF_NOTE; + + public static final String SASL_OAUTHBEARER_SCOPE_CLAIM_NAME = "sasl.oauthbearer.scope.claim.name"; + public static final String DEFAULT_SASL_OAUTHBEARER_SCOPE_CLAIM_NAME = "scope"; + public static final String SASL_OAUTHBEARER_SCOPE_CLAIM_NAME_DOC = "The OAuth claim for the scope is often named \"" + DEFAULT_SASL_OAUTHBEARER_SCOPE_CLAIM_NAME + "\", but this (optional)" + + " setting can provide a different name to use for the scope included in the JWT payload's claims if the OAuth/OIDC provider uses a different" + + " name for that claim."; + + public static final String SASL_OAUTHBEARER_SUB_CLAIM_NAME = "sasl.oauthbearer.sub.claim.name"; + public static final String DEFAULT_SASL_OAUTHBEARER_SUB_CLAIM_NAME = "sub"; + public static final String SASL_OAUTHBEARER_SUB_CLAIM_NAME_DOC = "The OAuth claim for the subject is often named \"" + DEFAULT_SASL_OAUTHBEARER_SUB_CLAIM_NAME + "\", but this (optional)" + + " setting can provide a different name to use for the subject included in the JWT payload's claims if the OAuth/OIDC provider uses a different" + + " name for that claim."; + + public static final String SASL_OAUTHBEARER_TOKEN_ENDPOINT_URL = "sasl.oauthbearer.token.endpoint.url"; + public static final String SASL_OAUTHBEARER_TOKEN_ENDPOINT_URL_DOC = "The URL for the OAuth/OIDC identity provider. If the URL is HTTP(S)-based, it is the issuer's token" + + " endpoint URL to which requests will be made to login based on the configuration in " + SASL_JAAS_CONFIG + ". If the URL is file-based, it" + + " specifies a file containing an access token (in JWT serialized form) issued by the OAuth/OIDC identity provider to use for authorization."; + + public static final String SASL_OAUTHBEARER_JWKS_ENDPOINT_URL = "sasl.oauthbearer.jwks.endpoint.url"; + public static final String SASL_OAUTHBEARER_JWKS_ENDPOINT_URL_DOC = "The OAuth/OIDC provider URL from which the provider's" + + " JWKS (JSON Web Key Set) can be retrieved. The URL can be HTTP(S)-based or file-based." + + " If the URL is HTTP(S)-based, the JWKS data will be retrieved from the OAuth/OIDC provider via the configured URL on broker startup. All then-current" + + " keys will be cached on the broker for incoming requests. If an authentication request is received for a JWT that includes a \"kid\" header claim value that" + + " isn't yet in the cache, the JWKS endpoint will be queried again on demand. However, the broker polls the URL every sasl.oauthbearer.jwks.endpoint.refresh.ms" + + " milliseconds to refresh the cache with any forthcoming keys before any JWT requests that include them are received." + + " If the URL is file-based, the broker will load the JWKS file from a configured location on startup. 
In the event that the JWT includes a \"kid\" header" + + " value that isn't in the JWKS file, the broker will reject the JWT and authentication will fail."; + + public static final String SASL_OAUTHBEARER_JWKS_ENDPOINT_REFRESH_MS = "sasl.oauthbearer.jwks.endpoint.refresh.ms"; + public static final long DEFAULT_SASL_OAUTHBEARER_JWKS_ENDPOINT_REFRESH_MS = 60 * 60 * 1000; + public static final String SASL_OAUTHBEARER_JWKS_ENDPOINT_REFRESH_MS_DOC = "The (optional) value in milliseconds for the broker to wait between refreshing its JWKS (JSON Web Key Set)" + + " cache that contains the keys to verify the signature of the JWT."; + + private static final String JWKS_EXPONENTIAL_BACKOFF_NOTE = " JWKS retrieval uses an exponential backoff algorithm with an initial wait based on the" + + " sasl.oauthbearer.jwks.endpoint.retry.backoff.ms setting and will double in wait length between attempts up to a maximum wait length specified by the" + + " sasl.oauthbearer.jwks.endpoint.retry.backoff.max.ms setting."; + + public static final String SASL_OAUTHBEARER_JWKS_ENDPOINT_RETRY_BACKOFF_MAX_MS = "sasl.oauthbearer.jwks.endpoint.retry.backoff.max.ms"; + public static final long DEFAULT_SASL_OAUTHBEARER_JWKS_ENDPOINT_RETRY_BACKOFF_MAX_MS = 10000; + public static final String SASL_OAUTHBEARER_JWKS_ENDPOINT_RETRY_BACKOFF_MAX_MS_DOC = "The (optional) value in milliseconds for the maximum wait between attempts to retrieve the JWKS (JSON Web Key Set)" + + " from the external authentication provider." + + JWKS_EXPONENTIAL_BACKOFF_NOTE; + + public static final String SASL_OAUTHBEARER_JWKS_ENDPOINT_RETRY_BACKOFF_MS = "sasl.oauthbearer.jwks.endpoint.retry.backoff.ms"; + public static final long DEFAULT_SASL_OAUTHBEARER_JWKS_ENDPOINT_RETRY_BACKOFF_MS = 100; + public static final String SASL_OAUTHBEARER_JWKS_ENDPOINT_RETRY_BACKOFF_MS_DOC = "The (optional) value in milliseconds for the initial wait between JWKS (JSON Web Key Set) retrieval attempts from the external" + + " authentication provider." + + JWKS_EXPONENTIAL_BACKOFF_NOTE; + + public static final String SASL_OAUTHBEARER_CLOCK_SKEW_SECONDS = "sasl.oauthbearer.clock.skew.seconds"; + public static final int DEFAULT_SASL_OAUTHBEARER_CLOCK_SKEW_SECONDS = 30; + public static final String SASL_OAUTHBEARER_CLOCK_SKEW_SECONDS_DOC = "The (optional) value in seconds to allow for differences between the time of the OAuth/OIDC identity provider and" + + " the broker."; + + public static final String SASL_OAUTHBEARER_EXPECTED_AUDIENCE = "sasl.oauthbearer.expected.audience"; + public static final String SASL_OAUTHBEARER_EXPECTED_AUDIENCE_DOC = "The (optional) comma-delimited setting for the broker to use to verify that the JWT was issued for one of the" + + " expected audiences. The JWT will be inspected for the standard OAuth \"aud\" claim and if this value is set, the broker will match the value from JWT's \"aud\" claim " + + " to see if there is an exact match. If there is no match, the broker will reject the JWT and authentication will fail."; + + public static final String SASL_OAUTHBEARER_EXPECTED_ISSUER = "sasl.oauthbearer.expected.issuer"; + public static final String SASL_OAUTHBEARER_EXPECTED_ISSUER_DOC = "The (optional) setting for the broker to use to verify that the JWT was created by the expected issuer. The JWT will" + + " be inspected for the standard OAuth \"iss\" claim and if this value is set, the broker will match it exactly against what is in the JWT's \"iss\" claim. 
If there is no" + + " match, the broker will reject the JWT and authentication will fail."; + + public static void addClientSaslSupport(ConfigDef config) { + config.define(SaslConfigs.SASL_KERBEROS_SERVICE_NAME, ConfigDef.Type.STRING, null, ConfigDef.Importance.MEDIUM, SaslConfigs.SASL_KERBEROS_SERVICE_NAME_DOC) + .define(SaslConfigs.SASL_KERBEROS_KINIT_CMD, ConfigDef.Type.STRING, SaslConfigs.DEFAULT_KERBEROS_KINIT_CMD, ConfigDef.Importance.LOW, SaslConfigs.SASL_KERBEROS_KINIT_CMD_DOC) + .define(SaslConfigs.SASL_KERBEROS_TICKET_RENEW_WINDOW_FACTOR, ConfigDef.Type.DOUBLE, SaslConfigs.DEFAULT_KERBEROS_TICKET_RENEW_WINDOW_FACTOR, ConfigDef.Importance.LOW, SaslConfigs.SASL_KERBEROS_TICKET_RENEW_WINDOW_FACTOR_DOC) + .define(SaslConfigs.SASL_KERBEROS_TICKET_RENEW_JITTER, ConfigDef.Type.DOUBLE, SaslConfigs.DEFAULT_KERBEROS_TICKET_RENEW_JITTER, ConfigDef.Importance.LOW, SaslConfigs.SASL_KERBEROS_TICKET_RENEW_JITTER_DOC) + .define(SaslConfigs.SASL_KERBEROS_MIN_TIME_BEFORE_RELOGIN, ConfigDef.Type.LONG, SaslConfigs.DEFAULT_KERBEROS_MIN_TIME_BEFORE_RELOGIN, ConfigDef.Importance.LOW, SaslConfigs.SASL_KERBEROS_MIN_TIME_BEFORE_RELOGIN_DOC) + .define(SaslConfigs.SASL_LOGIN_REFRESH_WINDOW_FACTOR, ConfigDef.Type.DOUBLE, SaslConfigs.DEFAULT_LOGIN_REFRESH_WINDOW_FACTOR, Range.between(0.5, 1.0), ConfigDef.Importance.LOW, SaslConfigs.SASL_LOGIN_REFRESH_WINDOW_FACTOR_DOC) + .define(SaslConfigs.SASL_LOGIN_REFRESH_WINDOW_JITTER, ConfigDef.Type.DOUBLE, SaslConfigs.DEFAULT_LOGIN_REFRESH_WINDOW_JITTER, Range.between(0.0, 0.25), ConfigDef.Importance.LOW, SaslConfigs.SASL_LOGIN_REFRESH_WINDOW_JITTER_DOC) + .define(SaslConfigs.SASL_LOGIN_REFRESH_MIN_PERIOD_SECONDS, ConfigDef.Type.SHORT, SaslConfigs.DEFAULT_LOGIN_REFRESH_MIN_PERIOD_SECONDS, Range.between(0, 900), ConfigDef.Importance.LOW, SaslConfigs.SASL_LOGIN_REFRESH_MIN_PERIOD_SECONDS_DOC) + .define(SaslConfigs.SASL_LOGIN_REFRESH_BUFFER_SECONDS, ConfigDef.Type.SHORT, SaslConfigs.DEFAULT_LOGIN_REFRESH_BUFFER_SECONDS, Range.between(0, 3600), ConfigDef.Importance.LOW, SaslConfigs.SASL_LOGIN_REFRESH_BUFFER_SECONDS_DOC) + .define(SaslConfigs.SASL_MECHANISM, ConfigDef.Type.STRING, SaslConfigs.DEFAULT_SASL_MECHANISM, ConfigDef.Importance.MEDIUM, SaslConfigs.SASL_MECHANISM_DOC) + .define(SaslConfigs.SASL_JAAS_CONFIG, ConfigDef.Type.PASSWORD, null, ConfigDef.Importance.MEDIUM, SaslConfigs.SASL_JAAS_CONFIG_DOC) + .define(SaslConfigs.SASL_CLIENT_CALLBACK_HANDLER_CLASS, ConfigDef.Type.CLASS, null, ConfigDef.Importance.MEDIUM, SaslConfigs.SASL_CLIENT_CALLBACK_HANDLER_CLASS_DOC) + .define(SaslConfigs.SASL_LOGIN_CALLBACK_HANDLER_CLASS, ConfigDef.Type.CLASS, null, ConfigDef.Importance.MEDIUM, SaslConfigs.SASL_LOGIN_CALLBACK_HANDLER_CLASS_DOC) + .define(SaslConfigs.SASL_LOGIN_CLASS, ConfigDef.Type.CLASS, null, ConfigDef.Importance.MEDIUM, SaslConfigs.SASL_LOGIN_CLASS_DOC) + .define(SaslConfigs.SASL_LOGIN_CONNECT_TIMEOUT_MS, ConfigDef.Type.INT, null, ConfigDef.Importance.LOW, SASL_LOGIN_CONNECT_TIMEOUT_MS_DOC) + .define(SaslConfigs.SASL_LOGIN_READ_TIMEOUT_MS, ConfigDef.Type.INT, null, ConfigDef.Importance.LOW, SASL_LOGIN_READ_TIMEOUT_MS_DOC) + .define(SaslConfigs.SASL_LOGIN_RETRY_BACKOFF_MAX_MS, ConfigDef.Type.LONG, DEFAULT_SASL_LOGIN_RETRY_BACKOFF_MAX_MS, ConfigDef.Importance.LOW, SASL_LOGIN_RETRY_BACKOFF_MAX_MS_DOC) + .define(SaslConfigs.SASL_LOGIN_RETRY_BACKOFF_MS, ConfigDef.Type.LONG, DEFAULT_SASL_LOGIN_RETRY_BACKOFF_MS, ConfigDef.Importance.LOW, SASL_LOGIN_RETRY_BACKOFF_MS_DOC) + .define(SaslConfigs.SASL_OAUTHBEARER_SCOPE_CLAIM_NAME, ConfigDef.Type.STRING, 
DEFAULT_SASL_OAUTHBEARER_SCOPE_CLAIM_NAME, ConfigDef.Importance.LOW, SASL_OAUTHBEARER_SCOPE_CLAIM_NAME_DOC) + .define(SaslConfigs.SASL_OAUTHBEARER_SUB_CLAIM_NAME, ConfigDef.Type.STRING, DEFAULT_SASL_OAUTHBEARER_SUB_CLAIM_NAME, ConfigDef.Importance.LOW, SASL_OAUTHBEARER_SUB_CLAIM_NAME_DOC) + .define(SaslConfigs.SASL_OAUTHBEARER_TOKEN_ENDPOINT_URL, ConfigDef.Type.STRING, null, ConfigDef.Importance.MEDIUM, SASL_OAUTHBEARER_TOKEN_ENDPOINT_URL_DOC) + .define(SaslConfigs.SASL_OAUTHBEARER_JWKS_ENDPOINT_URL, ConfigDef.Type.STRING, null, ConfigDef.Importance.MEDIUM, SASL_OAUTHBEARER_JWKS_ENDPOINT_URL_DOC) + .define(SaslConfigs.SASL_OAUTHBEARER_JWKS_ENDPOINT_REFRESH_MS, ConfigDef.Type.LONG, DEFAULT_SASL_OAUTHBEARER_JWKS_ENDPOINT_REFRESH_MS, ConfigDef.Importance.LOW, SASL_OAUTHBEARER_JWKS_ENDPOINT_REFRESH_MS_DOC) + .define(SaslConfigs.SASL_OAUTHBEARER_JWKS_ENDPOINT_RETRY_BACKOFF_MAX_MS, ConfigDef.Type.LONG, DEFAULT_SASL_OAUTHBEARER_JWKS_ENDPOINT_RETRY_BACKOFF_MAX_MS, ConfigDef.Importance.LOW, SASL_OAUTHBEARER_JWKS_ENDPOINT_RETRY_BACKOFF_MAX_MS_DOC) + .define(SaslConfigs.SASL_OAUTHBEARER_JWKS_ENDPOINT_RETRY_BACKOFF_MS, ConfigDef.Type.LONG, DEFAULT_SASL_OAUTHBEARER_JWKS_ENDPOINT_RETRY_BACKOFF_MS, ConfigDef.Importance.LOW, SASL_OAUTHBEARER_JWKS_ENDPOINT_RETRY_BACKOFF_MS_DOC) + .define(SaslConfigs.SASL_OAUTHBEARER_CLOCK_SKEW_SECONDS, ConfigDef.Type.INT, DEFAULT_SASL_OAUTHBEARER_CLOCK_SKEW_SECONDS, ConfigDef.Importance.LOW, SASL_OAUTHBEARER_CLOCK_SKEW_SECONDS_DOC) + .define(SaslConfigs.SASL_OAUTHBEARER_EXPECTED_AUDIENCE, ConfigDef.Type.LIST, null, ConfigDef.Importance.LOW, SASL_OAUTHBEARER_EXPECTED_AUDIENCE_DOC) + .define(SaslConfigs.SASL_OAUTHBEARER_EXPECTED_ISSUER, ConfigDef.Type.STRING, null, ConfigDef.Importance.LOW, SASL_OAUTHBEARER_EXPECTED_ISSUER_DOC); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/config/SecurityConfig.java b/clients/src/main/java/org/apache/kafka/common/config/SecurityConfig.java new file mode 100644 index 0000000..b4dc26c --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/config/SecurityConfig.java @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.config; + +/** + * Contains the common security config for SSL and SASL + */ +public class SecurityConfig { + + public static final String SECURITY_PROVIDERS_CONFIG = "security.providers"; + public static final String SECURITY_PROVIDERS_DOC = "A list of configurable creator classes each returning a provider" + + " implementing security algorithms. 
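To make the SaslConfigs keys above concrete, a minimal client-side sketch (illustrative only; the SCRAM mechanism, username, and password are invented placeholders, and the security.protocol key comes from the clients' common configs rather than this file):

import java.util.Properties;
import org.apache.kafka.common.config.SaslConfigs;

public class SaslClientPropsExample {
    public static Properties saslProps() {
        Properties props = new Properties();
        props.put("security.protocol", "SASL_SSL");
        props.put(SaslConfigs.SASL_MECHANISM, "SCRAM-SHA-256");
        // Format per SASL_JAAS_CONFIG_DOC: loginModuleClass controlFlag (optionName=optionValue)*;
        props.put(SaslConfigs.SASL_JAAS_CONFIG,
                "org.apache.kafka.common.security.scram.ScramLoginModule required "
                        + "username=\"alice\" password=\"alice-secret\";");
        return props;
    }
}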
These classes should implement the" + + " org.apache.kafka.common.security.auth.SecurityProviderCreator interface."; + +} diff --git a/clients/src/main/java/org/apache/kafka/common/config/SslClientAuth.java b/clients/src/main/java/org/apache/kafka/common/config/SslClientAuth.java new file mode 100644 index 0000000..9d85b18 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/config/SslClientAuth.java @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.common.config; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Locale; + +/** + * Describes whether the server should require or request client authentication. + */ +public enum SslClientAuth { + REQUIRED, + REQUESTED, + NONE; + + public static final List VALUES = + Collections.unmodifiableList(Arrays.asList(SslClientAuth.values())); + + public static SslClientAuth forConfig(String key) { + if (key == null) { + return SslClientAuth.NONE; + } + String upperCaseKey = key.toUpperCase(Locale.ROOT); + for (SslClientAuth auth : VALUES) { + if (auth.name().equals(upperCaseKey)) { + return auth; + } + } + return null; + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/config/SslConfigs.java b/clients/src/main/java/org/apache/kafka/common/config/SslConfigs.java new file mode 100644 index 0000000..d7ed803 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/config/SslConfigs.java @@ -0,0 +1,178 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.config; + +import org.apache.kafka.common.config.ConfigDef.Type; +import org.apache.kafka.common.config.internals.BrokerSecurityConfigs; +import org.apache.kafka.common.utils.Java; +import org.apache.kafka.common.utils.Utils; + +import javax.net.ssl.KeyManagerFactory; +import javax.net.ssl.TrustManagerFactory; +import java.util.Set; + +public class SslConfigs { + /* + * NOTE: DO NOT CHANGE EITHER CONFIG NAMES AS THESE ARE PART OF THE PUBLIC API AND CHANGE WILL BREAK USER CODE. + */ + + public static final String SSL_PROTOCOL_CONFIG = "ssl.protocol"; + public static final String SSL_PROTOCOL_DOC = "The SSL protocol used to generate the SSLContext. " + + "The default is 'TLSv1.3' when running with Java 11 or newer, 'TLSv1.2' otherwise. " + + "This value should be fine for most use cases. " + + "Allowed values in recent JVMs are 'TLSv1.2' and 'TLSv1.3'. 'TLS', 'TLSv1.1', 'SSL', 'SSLv2' and 'SSLv3' " + + "may be supported in older JVMs, but their usage is discouraged due to known security vulnerabilities. " + + "With the default value for this config and 'ssl.enabled.protocols', clients will downgrade to 'TLSv1.2' if " + + "the server does not support 'TLSv1.3'. If this config is set to 'TLSv1.2', clients will not use 'TLSv1.3' even " + + "if it is one of the values in ssl.enabled.protocols and the server only supports 'TLSv1.3'."; + + public static final String DEFAULT_SSL_PROTOCOL; + + public static final String SSL_PROVIDER_CONFIG = "ssl.provider"; + public static final String SSL_PROVIDER_DOC = "The name of the security provider used for SSL connections. Default value is the default security provider of the JVM."; + + public static final String SSL_CIPHER_SUITES_CONFIG = "ssl.cipher.suites"; + public static final String SSL_CIPHER_SUITES_DOC = "A list of cipher suites. This is a named combination of authentication, encryption, MAC and key exchange algorithm used to negotiate the security settings for a network connection using TLS or SSL network protocol. " + + "By default all the available cipher suites are supported."; + + public static final String SSL_ENABLED_PROTOCOLS_CONFIG = "ssl.enabled.protocols"; + public static final String SSL_ENABLED_PROTOCOLS_DOC = "The list of protocols enabled for SSL connections. " + + "The default is 'TLSv1.2,TLSv1.3' when running with Java 11 or newer, 'TLSv1.2' otherwise. With the " + + "default value for Java 11, clients and servers will prefer TLSv1.3 if both support it and fallback " + + "to TLSv1.2 otherwise (assuming both support at least TLSv1.2). This default should be fine for most " + + "cases. Also see the config documentation for `ssl.protocol`."; + public static final String DEFAULT_SSL_ENABLED_PROTOCOLS; + + static { + if (Java.IS_JAVA11_COMPATIBLE) { + DEFAULT_SSL_PROTOCOL = "TLSv1.3"; + DEFAULT_SSL_ENABLED_PROTOCOLS = "TLSv1.2,TLSv1.3"; + } else { + DEFAULT_SSL_PROTOCOL = "TLSv1.2"; + DEFAULT_SSL_ENABLED_PROTOCOLS = "TLSv1.2"; + } + } + + public static final String SSL_KEYSTORE_TYPE_CONFIG = "ssl.keystore.type"; + public static final String SSL_KEYSTORE_TYPE_DOC = "The file format of the key store file. " + + "This is optional for client."; + public static final String DEFAULT_SSL_KEYSTORE_TYPE = "JKS"; + + public static final String SSL_KEYSTORE_KEY_CONFIG = "ssl.keystore.key"; + public static final String SSL_KEYSTORE_KEY_DOC = "Private key in the format specified by 'ssl.keystore.type'. " + + "Default SSL engine factory supports only PEM format with PKCS#8 keys. 
If the key is encrypted, " + + "key password must be specified using 'ssl.key.password'"; + + public static final String SSL_KEYSTORE_CERTIFICATE_CHAIN_CONFIG = "ssl.keystore.certificate.chain"; + public static final String SSL_KEYSTORE_CERTIFICATE_CHAIN_DOC = "Certificate chain in the format specified by 'ssl.keystore.type'. " + + "Default SSL engine factory supports only PEM format with a list of X.509 certificates"; + + public static final String SSL_TRUSTSTORE_CERTIFICATES_CONFIG = "ssl.truststore.certificates"; + public static final String SSL_TRUSTSTORE_CERTIFICATES_DOC = "Trusted certificates in the format specified by 'ssl.truststore.type'. " + + "Default SSL engine factory supports only PEM format with X.509 certificates."; + + public static final String SSL_KEYSTORE_LOCATION_CONFIG = "ssl.keystore.location"; + public static final String SSL_KEYSTORE_LOCATION_DOC = "The location of the key store file. " + + "This is optional for client and can be used for two-way authentication for client."; + + public static final String SSL_KEYSTORE_PASSWORD_CONFIG = "ssl.keystore.password"; + public static final String SSL_KEYSTORE_PASSWORD_DOC = "The store password for the key store file. " + + "This is optional for client and only needed if 'ssl.keystore.location' is configured. " + + " Key store password is not supported for PEM format."; + + public static final String SSL_KEY_PASSWORD_CONFIG = "ssl.key.password"; + public static final String SSL_KEY_PASSWORD_DOC = "The password of the private key in the key store file or" + + "the PEM key specified in `ssl.keystore.key'. This is required for clients only if two-way authentication is configured."; + + public static final String SSL_TRUSTSTORE_TYPE_CONFIG = "ssl.truststore.type"; + public static final String SSL_TRUSTSTORE_TYPE_DOC = "The file format of the trust store file."; + public static final String DEFAULT_SSL_TRUSTSTORE_TYPE = "JKS"; + + public static final String SSL_TRUSTSTORE_LOCATION_CONFIG = "ssl.truststore.location"; + public static final String SSL_TRUSTSTORE_LOCATION_DOC = "The location of the trust store file. "; + + public static final String SSL_TRUSTSTORE_PASSWORD_CONFIG = "ssl.truststore.password"; + public static final String SSL_TRUSTSTORE_PASSWORD_DOC = "The password for the trust store file. " + + "If a password is not set, trust store file configured will still be used, but integrity checking is disabled. " + + "Trust store password is not supported for PEM format."; + + public static final String SSL_KEYMANAGER_ALGORITHM_CONFIG = "ssl.keymanager.algorithm"; + public static final String SSL_KEYMANAGER_ALGORITHM_DOC = "The algorithm used by key manager factory for SSL connections. " + + "Default value is the key manager factory algorithm configured for the Java Virtual Machine."; + public static final String DEFAULT_SSL_KEYMANGER_ALGORITHM = KeyManagerFactory.getDefaultAlgorithm(); + + public static final String SSL_TRUSTMANAGER_ALGORITHM_CONFIG = "ssl.trustmanager.algorithm"; + public static final String SSL_TRUSTMANAGER_ALGORITHM_DOC = "The algorithm used by trust manager factory for SSL connections. 
" + + "Default value is the trust manager factory algorithm configured for the Java Virtual Machine."; + public static final String DEFAULT_SSL_TRUSTMANAGER_ALGORITHM = TrustManagerFactory.getDefaultAlgorithm(); + + public static final String SSL_ENDPOINT_IDENTIFICATION_ALGORITHM_CONFIG = "ssl.endpoint.identification.algorithm"; + public static final String SSL_ENDPOINT_IDENTIFICATION_ALGORITHM_DOC = "The endpoint identification algorithm to validate server hostname using server certificate. "; + public static final String DEFAULT_SSL_ENDPOINT_IDENTIFICATION_ALGORITHM = "https"; + + public static final String SSL_SECURE_RANDOM_IMPLEMENTATION_CONFIG = "ssl.secure.random.implementation"; + public static final String SSL_SECURE_RANDOM_IMPLEMENTATION_DOC = "The SecureRandom PRNG implementation to use for SSL cryptography operations. "; + + public static final String SSL_ENGINE_FACTORY_CLASS_CONFIG = "ssl.engine.factory.class"; + public static final String SSL_ENGINE_FACTORY_CLASS_DOC = "The class of type org.apache.kafka.common.security.auth.SslEngineFactory to provide SSLEngine objects. Default value is org.apache.kafka.common.security.ssl.DefaultSslEngineFactory"; + + public static void addClientSslSupport(ConfigDef config) { + config.define(SslConfigs.SSL_PROTOCOL_CONFIG, ConfigDef.Type.STRING, SslConfigs.DEFAULT_SSL_PROTOCOL, ConfigDef.Importance.MEDIUM, SslConfigs.SSL_PROTOCOL_DOC) + .define(SslConfigs.SSL_PROVIDER_CONFIG, ConfigDef.Type.STRING, null, ConfigDef.Importance.MEDIUM, SslConfigs.SSL_PROVIDER_DOC) + .define(SslConfigs.SSL_CIPHER_SUITES_CONFIG, ConfigDef.Type.LIST, null, ConfigDef.Importance.LOW, SslConfigs.SSL_CIPHER_SUITES_DOC) + .define(SslConfigs.SSL_ENABLED_PROTOCOLS_CONFIG, ConfigDef.Type.LIST, SslConfigs.DEFAULT_SSL_ENABLED_PROTOCOLS, ConfigDef.Importance.MEDIUM, SslConfigs.SSL_ENABLED_PROTOCOLS_DOC) + .define(SslConfigs.SSL_KEYSTORE_TYPE_CONFIG, ConfigDef.Type.STRING, SslConfigs.DEFAULT_SSL_KEYSTORE_TYPE, ConfigDef.Importance.MEDIUM, SslConfigs.SSL_KEYSTORE_TYPE_DOC) + .define(SslConfigs.SSL_KEYSTORE_LOCATION_CONFIG, ConfigDef.Type.STRING, null, ConfigDef.Importance.HIGH, SslConfigs.SSL_KEYSTORE_LOCATION_DOC) + .define(SslConfigs.SSL_KEYSTORE_PASSWORD_CONFIG, ConfigDef.Type.PASSWORD, null, ConfigDef.Importance.HIGH, SslConfigs.SSL_KEYSTORE_PASSWORD_DOC) + .define(SslConfigs.SSL_KEY_PASSWORD_CONFIG, ConfigDef.Type.PASSWORD, null, ConfigDef.Importance.HIGH, SslConfigs.SSL_KEY_PASSWORD_DOC) + .define(SslConfigs.SSL_KEYSTORE_KEY_CONFIG, Type.PASSWORD, null, ConfigDef.Importance.HIGH, SslConfigs.SSL_KEYSTORE_KEY_DOC) + .define(SslConfigs.SSL_KEYSTORE_CERTIFICATE_CHAIN_CONFIG, ConfigDef.Type.PASSWORD, null, ConfigDef.Importance.HIGH, SslConfigs.SSL_KEYSTORE_CERTIFICATE_CHAIN_DOC) + .define(SslConfigs.SSL_TRUSTSTORE_CERTIFICATES_CONFIG, ConfigDef.Type.PASSWORD, null, ConfigDef.Importance.HIGH, SslConfigs.SSL_TRUSTSTORE_CERTIFICATES_DOC) + .define(SslConfigs.SSL_TRUSTSTORE_TYPE_CONFIG, ConfigDef.Type.STRING, SslConfigs.DEFAULT_SSL_TRUSTSTORE_TYPE, ConfigDef.Importance.MEDIUM, SslConfigs.SSL_TRUSTSTORE_TYPE_DOC) + .define(SslConfigs.SSL_TRUSTSTORE_LOCATION_CONFIG, ConfigDef.Type.STRING, null, ConfigDef.Importance.HIGH, SslConfigs.SSL_TRUSTSTORE_LOCATION_DOC) + .define(SslConfigs.SSL_TRUSTSTORE_PASSWORD_CONFIG, ConfigDef.Type.PASSWORD, null, ConfigDef.Importance.HIGH, SslConfigs.SSL_TRUSTSTORE_PASSWORD_DOC) + .define(SslConfigs.SSL_KEYMANAGER_ALGORITHM_CONFIG, ConfigDef.Type.STRING, SslConfigs.DEFAULT_SSL_KEYMANGER_ALGORITHM, ConfigDef.Importance.LOW, 
SslConfigs.SSL_KEYMANAGER_ALGORITHM_DOC) + .define(SslConfigs.SSL_TRUSTMANAGER_ALGORITHM_CONFIG, ConfigDef.Type.STRING, SslConfigs.DEFAULT_SSL_TRUSTMANAGER_ALGORITHM, ConfigDef.Importance.LOW, SslConfigs.SSL_TRUSTMANAGER_ALGORITHM_DOC) + .define(SslConfigs.SSL_ENDPOINT_IDENTIFICATION_ALGORITHM_CONFIG, ConfigDef.Type.STRING, SslConfigs.DEFAULT_SSL_ENDPOINT_IDENTIFICATION_ALGORITHM, ConfigDef.Importance.LOW, SslConfigs.SSL_ENDPOINT_IDENTIFICATION_ALGORITHM_DOC) + .define(SslConfigs.SSL_SECURE_RANDOM_IMPLEMENTATION_CONFIG, ConfigDef.Type.STRING, null, ConfigDef.Importance.LOW, SslConfigs.SSL_SECURE_RANDOM_IMPLEMENTATION_DOC) + .define(SslConfigs.SSL_ENGINE_FACTORY_CLASS_CONFIG, ConfigDef.Type.CLASS, null, ConfigDef.Importance.LOW, SslConfigs.SSL_ENGINE_FACTORY_CLASS_DOC); + } + + public static final Set RECONFIGURABLE_CONFIGS = Utils.mkSet( + SslConfigs.SSL_KEYSTORE_TYPE_CONFIG, + SslConfigs.SSL_KEYSTORE_LOCATION_CONFIG, + SslConfigs.SSL_KEYSTORE_PASSWORD_CONFIG, + SslConfigs.SSL_KEY_PASSWORD_CONFIG, + SslConfigs.SSL_TRUSTSTORE_TYPE_CONFIG, + SslConfigs.SSL_TRUSTSTORE_LOCATION_CONFIG, + SslConfigs.SSL_TRUSTSTORE_PASSWORD_CONFIG, + SslConfigs.SSL_KEYSTORE_CERTIFICATE_CHAIN_CONFIG, + SslConfigs.SSL_KEYSTORE_KEY_CONFIG, + SslConfigs.SSL_TRUSTSTORE_CERTIFICATES_CONFIG); + + public static final Set NON_RECONFIGURABLE_CONFIGS = Utils.mkSet( + BrokerSecurityConfigs.SSL_CLIENT_AUTH_CONFIG, + SslConfigs.SSL_PROTOCOL_CONFIG, + SslConfigs.SSL_PROVIDER_CONFIG, + SslConfigs.SSL_CIPHER_SUITES_CONFIG, + SslConfigs.SSL_ENABLED_PROTOCOLS_CONFIG, + SslConfigs.SSL_KEYMANAGER_ALGORITHM_CONFIG, + SslConfigs.SSL_TRUSTMANAGER_ALGORITHM_CONFIG, + SslConfigs.SSL_ENDPOINT_IDENTIFICATION_ALGORITHM_CONFIG, + SslConfigs.SSL_SECURE_RANDOM_IMPLEMENTATION_CONFIG, + SslConfigs.SSL_ENGINE_FACTORY_CLASS_CONFIG); +} diff --git a/clients/src/main/java/org/apache/kafka/common/config/TopicConfig.java b/clients/src/main/java/org/apache/kafka/common/config/TopicConfig.java new file mode 100755 index 0000000..73439c5 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/config/TopicConfig.java @@ -0,0 +1,210 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.common.config; + +/** + *
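A similarly minimal sketch for the SslConfigs keys defined above (illustrative only; the truststore path and password are placeholders):

import java.util.Properties;
import org.apache.kafka.common.config.SslConfigs;

public class SslClientPropsExample {
    public static Properties sslProps() {
        Properties props = new Properties();
        props.put("security.protocol", "SSL");
        props.put(SslConfigs.SSL_TRUSTSTORE_LOCATION_CONFIG, "/etc/kafka/client.truststore.jks");
        props.put(SslConfigs.SSL_TRUSTSTORE_PASSWORD_CONFIG, "changeit");
        // Hostname verification against the server certificate; "https" is also the default.
        props.put(SslConfigs.SSL_ENDPOINT_IDENTIFICATION_ALGORITHM_CONFIG, "https");
        return props;
    }
}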

        Keys that can be used to configure a topic. These keys are useful when creating or reconfiguring a + * topic using the AdminClient. + * + *

        The intended pattern is for broker configs to include a `log.` prefix. For example, to set the default broker + * cleanup policy, one would set log.cleanup.policy instead of cleanup.policy. Unfortunately, there are many cases + * where this pattern is not followed. + */ +// This is a public API, so we should not remove or alter keys without a discussion and a deprecation period. +// Eventually this should replace LogConfig.scala. +public class TopicConfig { + public static final String SEGMENT_BYTES_CONFIG = "segment.bytes"; + public static final String SEGMENT_BYTES_DOC = "This configuration controls the segment file size for " + + "the log. Retention and cleaning is always done a file at a time so a larger segment size means " + + "fewer files but less granular control over retention."; + + public static final String SEGMENT_MS_CONFIG = "segment.ms"; + public static final String SEGMENT_MS_DOC = "This configuration controls the period of time after " + + "which Kafka will force the log to roll even if the segment file isn't full to ensure that retention " + + "can delete or compact old data."; + + public static final String SEGMENT_JITTER_MS_CONFIG = "segment.jitter.ms"; + public static final String SEGMENT_JITTER_MS_DOC = "The maximum random jitter subtracted from the scheduled " + + "segment roll time to avoid thundering herds of segment rolling"; + + public static final String SEGMENT_INDEX_BYTES_CONFIG = "segment.index.bytes"; + public static final String SEGMENT_INDEX_BYTES_DOC = "This configuration controls the size of the index that " + + "maps offsets to file positions. We preallocate this index file and shrink it only after log " + + "rolls. You generally should not need to change this setting."; + + public static final String FLUSH_MESSAGES_INTERVAL_CONFIG = "flush.messages"; + public static final String FLUSH_MESSAGES_INTERVAL_DOC = "This setting allows specifying an interval at " + + "which we will force an fsync of data written to the log. For example if this was set to 1 " + + "we would fsync after every message; if it were 5 we would fsync after every five messages. " + + "In general we recommend you not set this and use replication for durability and allow the " + + "operating system's background flush capabilities as it is more efficient. This setting can " + + "be overridden on a per-topic basis (see the per-topic configuration section)."; + + public static final String FLUSH_MS_CONFIG = "flush.ms"; + public static final String FLUSH_MS_DOC = "This setting allows specifying a time interval at which we will " + + "force an fsync of data written to the log. For example if this was set to 1000 " + + "we would fsync after 1000 ms had passed. In general we recommend you not set " + + "this and use replication for durability and allow the operating system's background " + + "flush capabilities as it is more efficient."; + + public static final String RETENTION_BYTES_CONFIG = "retention.bytes"; + public static final String RETENTION_BYTES_DOC = "This configuration controls the maximum size a partition " + + "(which consists of log segments) can grow to before we will discard old log segments to free up space if we " + + "are using the \"delete\" retention policy. By default there is no size limit only a time limit. 
" + + "Since this limit is enforced at the partition level, multiply it by the number of partitions to compute " + + "the topic retention in bytes."; + + public static final String RETENTION_MS_CONFIG = "retention.ms"; + public static final String RETENTION_MS_DOC = "This configuration controls the maximum time we will retain a " + + "log before we will discard old log segments to free up space if we are using the " + + "\"delete\" retention policy. This represents an SLA on how soon consumers must read " + + "their data. If set to -1, no time limit is applied."; + + public static final String REMOTE_LOG_STORAGE_ENABLE_CONFIG = "remote.storage.enable"; + public static final String REMOTE_LOG_STORAGE_ENABLE_DOC = "To enable tier storage for a topic, set `remote.storage.enable` as true. " + + "You can not disable this config once it is enabled. It will be provided in future versions."; + + public static final String LOCAL_LOG_RETENTION_MS_CONFIG = "local.retention.ms"; + public static final String LOCAL_LOG_RETENTION_MS_DOC = "The number of milli seconds to keep the local log segment before it gets deleted. " + + "Default value is -2, it represents `retention.ms` value is to be used. The effective value should always be less than or equal " + + "to `retention.ms` value."; + + public static final String LOCAL_LOG_RETENTION_BYTES_CONFIG = "local.retention.bytes"; + public static final String LOCAL_LOG_RETENTION_BYTES_DOC = "The maximum size of local log segments that can grow for a partition before it " + + "deletes the old segments. Default value is -2, it represents `retention.bytes` value to be used. The effective value should always be " + + "less than or equal to `retention.bytes` value."; + + public static final String MAX_MESSAGE_BYTES_CONFIG = "max.message.bytes"; + public static final String MAX_MESSAGE_BYTES_DOC = + "The largest record batch size allowed by Kafka (after compression if compression is enabled). " + + "If this is increased and there are consumers older than 0.10.2, the consumers' fetch " + + "size must also be increased so that they can fetch record batches this large. " + + "In the latest message format version, records are always grouped into batches for efficiency. " + + "In previous message format versions, uncompressed records are not grouped into batches and this " + + "limit only applies to a single record in that case."; + + public static final String INDEX_INTERVAL_BYTES_CONFIG = "index.interval.bytes"; + public static final String INDEX_INTERVAL_BYTES_DOCS = "This setting controls how frequently " + + "Kafka adds an index entry to its offset index. The default setting ensures that we index a " + + "message roughly every 4096 bytes. More indexing allows reads to jump closer to the exact " + + "position in the log but makes the index larger. You probably don't need to change this."; + + public static final String FILE_DELETE_DELAY_MS_CONFIG = "file.delete.delay.ms"; + public static final String FILE_DELETE_DELAY_MS_DOC = "The time to wait before deleting a file from the " + + "filesystem"; + + public static final String DELETE_RETENTION_MS_CONFIG = "delete.retention.ms"; + public static final String DELETE_RETENTION_MS_DOC = "The amount of time to retain delete tombstone markers " + + "for log compacted topics. 
This setting also gives a bound " + + "on the time in which a consumer must complete a read if they begin from offset 0 " + + "to ensure that they get a valid snapshot of the final stage (otherwise delete " + + "tombstones may be collected before they complete their scan)."; + + public static final String MIN_COMPACTION_LAG_MS_CONFIG = "min.compaction.lag.ms"; + public static final String MIN_COMPACTION_LAG_MS_DOC = "The minimum time a message will remain " + + "uncompacted in the log. Only applicable for logs that are being compacted."; + + public static final String MAX_COMPACTION_LAG_MS_CONFIG = "max.compaction.lag.ms"; + public static final String MAX_COMPACTION_LAG_MS_DOC = "The maximum time a message will remain " + + "ineligible for compaction in the log. Only applicable for logs that are being compacted."; + + public static final String MIN_CLEANABLE_DIRTY_RATIO_CONFIG = "min.cleanable.dirty.ratio"; + public static final String MIN_CLEANABLE_DIRTY_RATIO_DOC = "This configuration controls how frequently " + + "the log compactor will attempt to clean the log (assuming log " + + "compaction is enabled). By default we will avoid cleaning a log where more than " + + "50% of the log has been compacted. This ratio bounds the maximum space wasted in " + + "the log by duplicates (at 50% at most 50% of the log could be duplicates). A " + + "higher ratio will mean fewer, more efficient cleanings but will mean more wasted " + + "space in the log. If the " + MAX_COMPACTION_LAG_MS_CONFIG + " or the " + MIN_COMPACTION_LAG_MS_CONFIG + + " configurations are also specified, then the log compactor considers the log to be eligible for compaction " + + "as soon as either: (i) the dirty ratio threshold has been met and the log has had dirty (uncompacted) " + + "records for at least the " + MIN_COMPACTION_LAG_MS_CONFIG + " duration, or (ii) if the log has had " + + "dirty (uncompacted) records for at most the " + MAX_COMPACTION_LAG_MS_CONFIG + " period."; + + public static final String CLEANUP_POLICY_CONFIG = "cleanup.policy"; + public static final String CLEANUP_POLICY_COMPACT = "compact"; + public static final String CLEANUP_POLICY_DELETE = "delete"; + public static final String CLEANUP_POLICY_DOC = "A string that is either \"" + CLEANUP_POLICY_DELETE + + "\" or \"" + CLEANUP_POLICY_COMPACT + "\" or both. This string designates the retention policy to use on " + + "old log segments. The default policy (\"delete\") will discard old segments when their retention " + + "time or size limit has been reached. The \"compact\" setting will enable log " + + "compaction on the topic."; + + public static final String UNCLEAN_LEADER_ELECTION_ENABLE_CONFIG = "unclean.leader.election.enable"; + public static final String UNCLEAN_LEADER_ELECTION_ENABLE_DOC = "Indicates whether to enable replicas " + + "not in the ISR set to be elected as leader as a last resort, even though doing so may result in data " + + "loss."; + + public static final String MIN_IN_SYNC_REPLICAS_CONFIG = "min.insync.replicas"; + public static final String MIN_IN_SYNC_REPLICAS_DOC = "When a producer sets acks to \"all\" (or \"-1\"), " + + "this configuration specifies the minimum number of replicas that must acknowledge " + + "a write for the write to be considered successful. If this minimum cannot be met, " + + "then the producer will raise an exception (either NotEnoughReplicas or " + + "NotEnoughReplicasAfterAppend).
        When used together, min.insync.replicas and acks " + + "allow you to enforce greater durability guarantees. A typical scenario would be to " + + "create a topic with a replication factor of 3, set min.insync.replicas to 2, and " + + "produce with acks of \"all\". This will ensure that the producer raises an exception " + + "if a majority of replicas do not receive a write."; + + public static final String COMPRESSION_TYPE_CONFIG = "compression.type"; + public static final String COMPRESSION_TYPE_DOC = "Specify the final compression type for a given topic. " + + "This configuration accepts the standard compression codecs ('gzip', 'snappy', 'lz4', 'zstd'). It additionally " + + "accepts 'uncompressed' which is equivalent to no compression; and 'producer' which means retain the " + + "original compression codec set by the producer."; + + public static final String PREALLOCATE_CONFIG = "preallocate"; + public static final String PREALLOCATE_DOC = "True if we should preallocate the file on disk when " + + "creating a new log segment."; + + /** + * @deprecated since 3.0, removal planned in 4.0. The default value for this config is appropriate + * for most situations. + */ + @Deprecated + public static final String MESSAGE_FORMAT_VERSION_CONFIG = "message.format.version"; + + /** + * @deprecated since 3.0, removal planned in 4.0. The default value for this config is appropriate + * for most situations. + */ + @Deprecated + public static final String MESSAGE_FORMAT_VERSION_DOC = "[DEPRECATED] Specify the message format version the broker " + + "will use to append messages to the logs. The value of this config is always assumed to be `3.0` if " + + "`inter.broker.protocol.version` is 3.0 or higher (the actual config value is ignored). Otherwise, the value should " + + "be a valid ApiVersion. Some examples are: 0.10.0, 1.1, 2.8, 3.0. By setting a particular message format version, the " + + "user is certifying that all the existing messages on disk are smaller or equal than the specified version. Setting " + + "this value incorrectly will cause consumers with older versions to break as they will receive messages with a format " + + "that they don't understand."; + + public static final String MESSAGE_TIMESTAMP_TYPE_CONFIG = "message.timestamp.type"; + public static final String MESSAGE_TIMESTAMP_TYPE_DOC = "Define whether the timestamp in the message is " + + "message create time or log append time. The value should be either `CreateTime` or `LogAppendTime`"; + + public static final String MESSAGE_TIMESTAMP_DIFFERENCE_MAX_MS_CONFIG = "message.timestamp.difference.max.ms"; + public static final String MESSAGE_TIMESTAMP_DIFFERENCE_MAX_MS_DOC = "The maximum difference allowed between " + + "the timestamp when a broker receives a message and the timestamp specified in the message. If " + + "message.timestamp.type=CreateTime, a message will be rejected if the difference in timestamp " + + "exceeds this threshold. This configuration is ignored if message.timestamp.type=LogAppendTime."; + + public static final String MESSAGE_DOWNCONVERSION_ENABLE_CONFIG = "message.downconversion.enable"; + public static final String MESSAGE_DOWNCONVERSION_ENABLE_DOC = "This configuration controls whether " + + "down-conversion of message formats is enabled to satisfy consume requests. When set to false, " + + "broker will not perform down-conversion for consumers expecting an older message format. The broker responds " + + "with UNSUPPORTED_VERSION error for consume requests from such older clients. 
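To make the min.insync.replicas / acks pairing described a few entries above concrete, a hedged producer-side sketch; it assumes a topic created with replication factor 3 and min.insync.replicas set to 2, and the broker address and topic name are placeholders:

import java.util.Properties;
import java.util.concurrent.ExecutionException;

import org.apache.kafka.clients.producer.KafkaProducer;
import org.apache.kafka.clients.producer.ProducerConfig;
import org.apache.kafka.clients.producer.ProducerRecord;
import org.apache.kafka.common.errors.NotEnoughReplicasException;
import org.apache.kafka.common.serialization.StringSerializer;

public class AcksAllProducerExample {
    public static void main(String[] args) throws InterruptedException {
        Properties props = new Properties();
        props.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9092");   // placeholder broker
        props.put(ProducerConfig.ACKS_CONFIG, "all");                           // wait for all in-sync replicas
        props.put(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG, StringSerializer.class.getName());
        props.put(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG, StringSerializer.class.getName());

        try (KafkaProducer<String, String> producer = new KafkaProducer<>(props)) {
            producer.send(new ProducerRecord<>("example-topic", "key", "value")).get();   // placeholder topic
        } catch (ExecutionException e) {
            if (e.getCause() instanceof NotEnoughReplicasException) {
                // Fewer than min.insync.replicas replicas are in sync, so the broker rejected the write;
                // back off and retry, or surface the error.
            }
        }
    }
}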
This configuration" + + "does not apply to any message format conversion that might be required for replication to followers."; +} diff --git a/clients/src/main/java/org/apache/kafka/common/config/internals/BrokerSecurityConfigs.java b/clients/src/main/java/org/apache/kafka/common/config/internals/BrokerSecurityConfigs.java new file mode 100644 index 0000000..0b90da8 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/config/internals/BrokerSecurityConfigs.java @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.config.internals; + +import org.apache.kafka.common.config.SaslConfigs; + +import java.util.Collections; +import java.util.List; + +/** + * Common home for broker-side security configs which need to be accessible from the libraries shared + * between the broker and the client. + * + * Note this is an internal API and subject to change without notice. + */ +public class BrokerSecurityConfigs { + + public static final String PRINCIPAL_BUILDER_CLASS_CONFIG = "principal.builder.class"; + public static final String SASL_KERBEROS_PRINCIPAL_TO_LOCAL_RULES_CONFIG = "sasl.kerberos.principal.to.local.rules"; + public static final String SSL_CLIENT_AUTH_CONFIG = "ssl.client.auth"; + public static final String SASL_ENABLED_MECHANISMS_CONFIG = "sasl.enabled.mechanisms"; + public static final String SASL_SERVER_CALLBACK_HANDLER_CLASS = "sasl.server.callback.handler.class"; + public static final String SSL_PRINCIPAL_MAPPING_RULES_CONFIG = "ssl.principal.mapping.rules"; + public static final String CONNECTIONS_MAX_REAUTH_MS = "connections.max.reauth.ms"; + + public static final String PRINCIPAL_BUILDER_CLASS_DOC = "The fully qualified name of a class that implements the " + + "KafkaPrincipalBuilder interface, which is used to build the KafkaPrincipal object used during " + + "authorization. If no principal builder is defined, the default behavior depends " + + "on the security protocol in use. For SSL authentication, the principal will be derived using the " + + "rules defined by " + SSL_PRINCIPAL_MAPPING_RULES_CONFIG + " applied on the distinguished " + + "name from the client certificate if one is provided; otherwise, if client authentication is not required, " + + "the principal name will be ANONYMOUS. For SASL authentication, the principal will be derived using the " + + "rules defined by " + SASL_KERBEROS_PRINCIPAL_TO_LOCAL_RULES_CONFIG + " if GSSAPI is in use, " + + "and the SASL authentication ID for other mechanisms. For PLAINTEXT, the principal will be ANONYMOUS."; + + public static final String SSL_PRINCIPAL_MAPPING_RULES_DOC = "A list of rules for mapping from distinguished name" + + " from the client certificate to short name. 
The rules are evaluated in order and the first rule that matches" + + " a principal name is used to map it to a short name. Any later rules in the list are ignored. By default," + + " distinguished name of the X.500 certificate will be the principal. For more details on the format please" + + " see security authorization and acls. Note that this configuration is ignored" + + " if an extension of KafkaPrincipalBuilder is provided by the " + PRINCIPAL_BUILDER_CLASS_CONFIG + "" + + " configuration."; + public static final String DEFAULT_SSL_PRINCIPAL_MAPPING_RULES = "DEFAULT"; + + public static final String SASL_KERBEROS_PRINCIPAL_TO_LOCAL_RULES_DOC = "A list of rules for mapping from principal " + + "names to short names (typically operating system usernames). The rules are evaluated in order and the " + + "first rule that matches a principal name is used to map it to a short name. Any later rules in the list are " + + "ignored. By default, principal names of the form {username}/{hostname}@{REALM} are mapped to {username}. " + + "For more details on the format please see security authorization and acls. " + + "Note that this configuration is ignored if an extension of KafkaPrincipalBuilder is provided by the " + + "" + PRINCIPAL_BUILDER_CLASS_CONFIG + " configuration."; + public static final List DEFAULT_SASL_KERBEROS_PRINCIPAL_TO_LOCAL_RULES = Collections.singletonList("DEFAULT"); + + public static final String SSL_CLIENT_AUTH_DOC = "Configures kafka broker to request client authentication." + + " The following settings are common: " + + "

        • ssl.client.auth=required If set to required client authentication is required." + + "
        • ssl.client.auth=requested This means client authentication is optional." + + " unlike required, if this option is set client can choose not to provide authentication information about itself" + + "
        • ssl.client.auth=none This means client authentication is not needed." + + "
        "; + + public static final String SASL_ENABLED_MECHANISMS_DOC = "The list of SASL mechanisms enabled in the Kafka server. " + + "The list may contain any mechanism for which a security provider is available. " + + "Only GSSAPI is enabled by default."; + public static final List DEFAULT_SASL_ENABLED_MECHANISMS = Collections.singletonList(SaslConfigs.GSSAPI_MECHANISM); + + public static final String SASL_SERVER_CALLBACK_HANDLER_CLASS_DOC = "The fully qualified name of a SASL server callback handler " + + "class that implements the AuthenticateCallbackHandler interface. Server callback handlers must be prefixed with " + + "listener prefix and SASL mechanism name in lower-case. For example, " + + "listener.name.sasl_ssl.plain.sasl.server.callback.handler.class=com.example.CustomPlainCallbackHandler."; + + public static final String CONNECTIONS_MAX_REAUTH_MS_DOC = "When explicitly set to a positive number (the default is 0, not a positive number), " + + "a session lifetime that will not exceed the configured value will be communicated to v2.2.0 or later clients when they authenticate. " + + "The broker will disconnect any such connection that is not re-authenticated within the session lifetime and that is then subsequently " + + "used for any purpose other than re-authentication. Configuration names can optionally be prefixed with listener prefix and SASL " + + "mechanism name in lower-case. For example, listener.name.sasl_ssl.oauthbearer.connections.max.reauth.ms=3600000"; +} diff --git a/clients/src/main/java/org/apache/kafka/common/config/internals/QuotaConfigs.java b/clients/src/main/java/org/apache/kafka/common/config/internals/QuotaConfigs.java new file mode 100644 index 0000000..543e67b --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/config/internals/QuotaConfigs.java @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.common.config.internals; + +import org.apache.kafka.common.config.ConfigDef; +import org.apache.kafka.common.security.scram.internals.ScramMechanism; + +import java.util.Arrays; +import java.util.HashSet; +import java.util.Set; + +/** + * Define the dynamic quota configs. Note that these are not normal configurations that exist in properties files. They + * only exist dynamically in the controller (or ZK, depending on which mode the cluster is running). 
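Because these keys exist only as dynamic overrides, they are normally written through the Admin API rather than a properties file. A minimal sketch, not part of this patch; the broker address, principal name and byte rates are placeholders, and the op names match the constants defined in this class:

import java.util.Arrays;
import java.util.Collections;
import java.util.Properties;

import org.apache.kafka.clients.admin.Admin;
import org.apache.kafka.clients.admin.AdminClientConfig;
import org.apache.kafka.common.quota.ClientQuotaAlteration;
import org.apache.kafka.common.quota.ClientQuotaEntity;

public class SetUserQuotaExample {
    public static void main(String[] args) throws Exception {
        Properties props = new Properties();
        props.put(AdminClientConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9092");   // placeholder broker

        try (Admin admin = Admin.create(props)) {
            // Quotas are attached to an entity; here a single user principal.
            ClientQuotaEntity entity = new ClientQuotaEntity(
                    Collections.singletonMap(ClientQuotaEntity.USER, "app-user"));   // placeholder user

            // The op names are the dynamic quota keys ("producer_byte_rate", "consumer_byte_rate").
            ClientQuotaAlteration alteration = new ClientQuotaAlteration(entity, Arrays.asList(
                    new ClientQuotaAlteration.Op("producer_byte_rate", 1048576.0),
                    new ClientQuotaAlteration.Op("consumer_byte_rate", 2097152.0)));

            admin.alterClientQuotas(Collections.singleton(alteration)).all().get();
        }
    }
}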
+ */ +public class QuotaConfigs { + public static final String PRODUCER_BYTE_RATE_OVERRIDE_CONFIG = "producer_byte_rate"; + public static final String CONSUMER_BYTE_RATE_OVERRIDE_CONFIG = "consumer_byte_rate"; + public static final String REQUEST_PERCENTAGE_OVERRIDE_CONFIG = "request_percentage"; + public static final String CONTROLLER_MUTATION_RATE_OVERRIDE_CONFIG = "controller_mutation_rate"; + public static final String IP_CONNECTION_RATE_OVERRIDE_CONFIG = "connection_creation_rate"; + + public static final String PRODUCER_BYTE_RATE_DOC = "A rate representing the upper bound (bytes/sec) for producer traffic."; + public static final String CONSUMER_BYTE_RATE_DOC = "A rate representing the upper bound (bytes/sec) for consumer traffic."; + public static final String REQUEST_PERCENTAGE_DOC = "A percentage representing the upper bound of time spent for processing requests."; + public static final String CONTROLLER_MUTATION_RATE_DOC = "The rate at which mutations are accepted for the create " + + "topics request, the create partitions request and the delete topics request. The rate is accumulated by " + + "the number of partitions created or deleted."; + public static final String IP_CONNECTION_RATE_DOC = "An int representing the upper bound of connections accepted " + + "for the specified IP."; + + public static final int IP_CONNECTION_RATE_DEFAULT = Integer.MAX_VALUE; + + private static Set userClientConfigNames = new HashSet<>(Arrays.asList( + PRODUCER_BYTE_RATE_OVERRIDE_CONFIG, CONSUMER_BYTE_RATE_OVERRIDE_CONFIG, + REQUEST_PERCENTAGE_OVERRIDE_CONFIG, CONTROLLER_MUTATION_RATE_OVERRIDE_CONFIG + )); + + private static void buildUserClientQuotaConfigDef(ConfigDef configDef) { + configDef.define(PRODUCER_BYTE_RATE_OVERRIDE_CONFIG, ConfigDef.Type.LONG, Long.MAX_VALUE, + ConfigDef.Importance.MEDIUM, PRODUCER_BYTE_RATE_DOC); + + configDef.define(CONSUMER_BYTE_RATE_OVERRIDE_CONFIG, ConfigDef.Type.LONG, Long.MAX_VALUE, + ConfigDef.Importance.MEDIUM, CONSUMER_BYTE_RATE_DOC); + + configDef.define(REQUEST_PERCENTAGE_OVERRIDE_CONFIG, ConfigDef.Type.DOUBLE, + Integer.valueOf(Integer.MAX_VALUE).doubleValue(), + ConfigDef.Importance.MEDIUM, REQUEST_PERCENTAGE_DOC); + + configDef.define(CONTROLLER_MUTATION_RATE_OVERRIDE_CONFIG, ConfigDef.Type.DOUBLE, + Integer.valueOf(Integer.MAX_VALUE).doubleValue(), + ConfigDef.Importance.MEDIUM, CONTROLLER_MUTATION_RATE_DOC); + } + + public static boolean isClientOrUserConfig(String name) { + return userClientConfigNames.contains(name); + } + + public static ConfigDef userConfigs() { + ConfigDef configDef = new ConfigDef(); + ScramMechanism.mechanismNames().forEach(mechanismName -> { + configDef.define(mechanismName, ConfigDef.Type.STRING, null, ConfigDef.Importance.MEDIUM, + "User credentials for SCRAM mechanism " + mechanismName); + }); + buildUserClientQuotaConfigDef(configDef); + return configDef; + } + + public static ConfigDef clientConfigs() { + ConfigDef configDef = new ConfigDef(); + buildUserClientQuotaConfigDef(configDef); + return configDef; + } + + public static ConfigDef ipConfigs() { + ConfigDef configDef = new ConfigDef(); + configDef.define(IP_CONNECTION_RATE_OVERRIDE_CONFIG, ConfigDef.Type.INT, Integer.MAX_VALUE, + ConfigDef.Range.atLeast(0), ConfigDef.Importance.MEDIUM, IP_CONNECTION_RATE_DOC); + return configDef; + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/config/provider/ConfigProvider.java b/clients/src/main/java/org/apache/kafka/common/config/provider/ConfigProvider.java new file mode 100644 index 0000000..035bcbc --- 
/dev/null +++ b/clients/src/main/java/org/apache/kafka/common/config/provider/ConfigProvider.java @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.config.provider; + +import org.apache.kafka.common.Configurable; +import org.apache.kafka.common.config.ConfigChangeCallback; +import org.apache.kafka.common.config.ConfigData; + +import java.io.Closeable; +import java.util.Set; + +/** + * A provider of configuration data, which may optionally support subscriptions to configuration changes. + * Implementations are required to safely support concurrent calls to any of the methods in this interface. + * Kafka Connect discovers configuration providers using Java's Service Provider mechanism (see {@code java.util.ServiceLoader}). + * To support this, implementations of this interface should also contain a service provider configuration file in {@code META-INF/service/org.apache.kafka.common.config.provider.ConfigProvider}. + */ +public interface ConfigProvider extends Configurable, Closeable { + + /** + * Retrieves the data at the given path. + * + * @param path the path where the data resides + * @return the configuration data + */ + ConfigData get(String path); + + /** + * Retrieves the data with the given keys at the given path. + * + * @param path the path where the data resides + * @param keys the keys whose values will be retrieved + * @return the configuration data + */ + ConfigData get(String path, Set keys); + + /** + * Subscribes to changes for the given keys at the given path (optional operation). + * + * @param path the path where the data resides + * @param keys the keys whose values will be retrieved + * @param callback the callback to invoke upon change + * @throws {@link UnsupportedOperationException} if the subscribe operation is not supported + */ + default void subscribe(String path, Set keys, ConfigChangeCallback callback) { + throw new UnsupportedOperationException(); + } + + /** + * Unsubscribes to changes for the given keys at the given path (optional operation). + * + * @param path the path where the data resides + * @param keys the keys whose values will be retrieved + * @param callback the callback to be unsubscribed from changes + * @throws {@link UnsupportedOperationException} if the unsubscribe operation is not supported + */ + default void unsubscribe(String path, Set keys, ConfigChangeCallback callback) { + throw new UnsupportedOperationException(); + } + + /** + * Clears all subscribers (optional operation). 
+ * + * @throws {@link UnsupportedOperationException} if the unsubscribeAll operation is not supported + */ + default void unsubscribeAll() { + throw new UnsupportedOperationException(); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/config/provider/DirectoryConfigProvider.java b/clients/src/main/java/org/apache/kafka/common/config/provider/DirectoryConfigProvider.java new file mode 100644 index 0000000..bcfe674 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/config/provider/DirectoryConfigProvider.java @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.config.provider; + +import org.apache.kafka.common.config.ConfigData; +import org.apache.kafka.common.config.ConfigException; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.File; +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.Map; +import java.util.Set; +import java.util.function.Predicate; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import static java.util.Collections.emptyMap; + +/** + * An implementation of {@link ConfigProvider} based on a directory of files. + * Property keys correspond to the names of the regular (i.e. non-directory) + * files in a directory given by the path parameter. + * Property values are taken from the file contents corresponding to each key. + */ +public class DirectoryConfigProvider implements ConfigProvider { + + private static final Logger log = LoggerFactory.getLogger(DirectoryConfigProvider.class); + + @Override + public void configure(Map configs) { } + + @Override + public void close() throws IOException { } + + /** + * Retrieves the data contained in regular files in the directory given by {@code path}. + * Non-regular files (such as directories) in the given directory are silently ignored. + * @param path the directory where data files reside. + * @return the configuration data. + */ + @Override + public ConfigData get(String path) { + return get(path, Files::isRegularFile); + } + + /** + * Retrieves the data contained in the regular files named by {@code keys} in the directory given by {@code path}. + * Non-regular files (such as directories) in the given directory are silently ignored. + * @param path the directory where data files reside. + * @param keys the keys whose values will be retrieved. + * @return the configuration data. 
+ */ + @Override + public ConfigData get(String path, Set keys) { + return get(path, pathname -> + Files.isRegularFile(pathname) + && keys.contains(pathname.getFileName().toString())); + } + + private static ConfigData get(String path, Predicate fileFilter) { + Map map = emptyMap(); + if (path != null && !path.isEmpty()) { + Path dir = new File(path).toPath(); + if (!Files.isDirectory(dir)) { + log.warn("The path {} is not a directory", path); + } else { + try (Stream stream = Files.list(dir)) { + map = stream + .filter(fileFilter) + .collect(Collectors.toMap( + p -> p.getFileName().toString(), + p -> read(p))); + } catch (IOException e) { + throw new ConfigException("Could not list directory " + dir, e); + } + } + } + return new ConfigData(map); + } + + private static String read(Path path) { + try { + return new String(Files.readAllBytes(path), StandardCharsets.UTF_8); + } catch (IOException e) { + throw new ConfigException("Could not read file " + path + " for property " + path.getFileName(), e); + } + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/config/provider/FileConfigProvider.java b/clients/src/main/java/org/apache/kafka/common/config/provider/FileConfigProvider.java new file mode 100644 index 0000000..3920da0 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/config/provider/FileConfigProvider.java @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.config.provider; + +import org.apache.kafka.common.config.ConfigData; +import org.apache.kafka.common.config.ConfigException; + +import java.io.IOException; +import java.io.Reader; +import java.nio.file.Files; +import java.nio.file.Paths; +import java.util.Enumeration; +import java.util.HashMap; +import java.util.Map; +import java.util.Properties; +import java.util.Set; + +/** + * An implementation of {@link ConfigProvider} that represents a Properties file. + * All property keys and values are stored as cleartext. + */ +public class FileConfigProvider implements ConfigProvider { + + public void configure(Map configs) { + } + + /** + * Retrieves the data at the given Properties file. 
+ * + * @param path the file where the data resides + * @return the configuration data + */ + public ConfigData get(String path) { + Map data = new HashMap<>(); + if (path == null || path.isEmpty()) { + return new ConfigData(data); + } + try (Reader reader = reader(path)) { + Properties properties = new Properties(); + properties.load(reader); + Enumeration keys = properties.keys(); + while (keys.hasMoreElements()) { + String key = keys.nextElement().toString(); + String value = properties.getProperty(key); + if (value != null) { + data.put(key, value); + } + } + return new ConfigData(data); + } catch (IOException e) { + throw new ConfigException("Could not read properties from file " + path, e); + } + } + + /** + * Retrieves the data with the given keys at the given Properties file. + * + * @param path the file where the data resides + * @param keys the keys whose values will be retrieved + * @return the configuration data + */ + public ConfigData get(String path, Set keys) { + Map data = new HashMap<>(); + if (path == null || path.isEmpty()) { + return new ConfigData(data); + } + try (Reader reader = reader(path)) { + Properties properties = new Properties(); + properties.load(reader); + for (String key : keys) { + String value = properties.getProperty(key); + if (value != null) { + data.put(key, value); + } + } + return new ConfigData(data); + } catch (IOException e) { + throw new ConfigException("Could not read properties from file " + path, e); + } + } + + // visible for testing + protected Reader reader(String path) throws IOException { + return Files.newBufferedReader(Paths.get(path)); + } + + public void close() { + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/config/types/Password.java b/clients/src/main/java/org/apache/kafka/common/config/types/Password.java new file mode 100644 index 0000000..eafffb9 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/config/types/Password.java @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
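As a usage note for the FileConfigProvider just shown (not part of the patch), a small sketch of resolving values directly; the file path and key name are placeholders:

import java.util.Collections;

import org.apache.kafka.common.config.ConfigData;
import org.apache.kafka.common.config.provider.FileConfigProvider;

public class FileProviderExample {
    public static void main(String[] args) {
        try (FileConfigProvider provider = new FileConfigProvider()) {
            provider.configure(Collections.emptyMap());

            // Read every key in the properties file...
            ConfigData all = provider.get("/etc/app/credentials.properties");          // placeholder path
            // ...or only the requested keys.
            ConfigData some = provider.get("/etc/app/credentials.properties",
                    Collections.singleton("db.password"));                             // placeholder key

            System.out.println(all.data().size() + " keys, requested: " + some.data().keySet());
        }
    }
}

In Kafka Connect the same provider is usually wired in through the config.providers worker settings and referenced with ${file:<path>:<key>} placeholders rather than being called directly.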
+ */ +package org.apache.kafka.common.config.types; + +/** + * A wrapper class for passwords to hide them while logging a config + */ +public class Password { + + public static final String HIDDEN = "[hidden]"; + + private final String value; + + /** + * Construct a new Password object + * @param value The value of a password + */ + public Password(String value) { + this.value = value; + } + + @Override + public int hashCode() { + return value.hashCode(); + } + + @Override + public boolean equals(Object obj) { + if (!(obj instanceof Password)) + return false; + Password other = (Password) obj; + return value.equals(other.value); + } + + /** + * Returns hidden password string + * + * @return hidden password string + */ + @Override + public String toString() { + return HIDDEN; + } + + /** + * Returns real password string + * + * @return real password string + */ + public String value() { + return value; + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/errors/ApiException.java b/clients/src/main/java/org/apache/kafka/common/errors/ApiException.java new file mode 100644 index 0000000..aa4e98c --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/errors/ApiException.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.errors; + +import org.apache.kafka.common.KafkaException; + +/** + * Any API exception that is part of the public protocol and should be a subclass of this class and be part of this + * package. + */ +public class ApiException extends KafkaException { + + private static final long serialVersionUID = 1L; + + public ApiException(String message, Throwable cause) { + super(message, cause); + } + + public ApiException(String message) { + super(message); + } + + public ApiException(Throwable cause) { + super(cause); + } + + public ApiException() { + super(); + } + + /* avoid the expensive and useless stack trace for api exceptions */ + @Override + public Throwable fillInStackTrace() { + return this; + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/errors/AuthenticationException.java b/clients/src/main/java/org/apache/kafka/common/errors/AuthenticationException.java new file mode 100644 index 0000000..7a05eba --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/errors/AuthenticationException.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.errors; + +import javax.net.ssl.SSLException; + +/** + * This exception indicates that SASL authentication has failed. + * On authentication failure, clients abort the operation requested and raise one + * of the subclasses of this exception: + *
        • {@link SaslAuthenticationException} if SASL handshake fails with invalid credentials + * or any other failure specific to the SASL mechanism used for authentication + *
        • {@link UnsupportedSaslMechanismException} if the SASL mechanism requested by the client + * is not supported on the broker. + *
        • {@link IllegalSaslStateException} if an unexpected request is received during SASL + * handshake. This could be due to a misconfigured security protocol. + *
        • {@link SslAuthenticationException} if SSL handshake failed due to any {@link SSLException}. + *
        + */ +public class AuthenticationException extends ApiException { + + private static final long serialVersionUID = 1L; + + public AuthenticationException(String message) { + super(message); + } + + public AuthenticationException(Throwable cause) { + super(cause); + } + + public AuthenticationException(String message, Throwable cause) { + super(message, cause); + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/errors/AuthorizationException.java b/clients/src/main/java/org/apache/kafka/common/errors/AuthorizationException.java new file mode 100644 index 0000000..0471fe6 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/errors/AuthorizationException.java @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.errors; + +public class AuthorizationException extends ApiException { + + public AuthorizationException(String message) { + super(message); + } + + public AuthorizationException(String message, Throwable cause) { + super(message, cause); + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/errors/BrokerIdNotRegisteredException.java b/clients/src/main/java/org/apache/kafka/common/errors/BrokerIdNotRegisteredException.java new file mode 100644 index 0000000..cc8a47a --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/errors/BrokerIdNotRegisteredException.java @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
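Taken together, these classes form the hierarchy that client code branches on: AuthenticationException and AuthorizationException signal problems that retrying cannot fix, while subclasses of RetriableException (such as the coordinator exceptions further down) are transient. A hedged sketch of the usual handling pattern, with a placeholder broker address:

import java.util.Properties;
import java.util.concurrent.ExecutionException;

import org.apache.kafka.clients.admin.Admin;
import org.apache.kafka.clients.admin.AdminClientConfig;
import org.apache.kafka.common.errors.AuthenticationException;
import org.apache.kafka.common.errors.AuthorizationException;
import org.apache.kafka.common.errors.RetriableException;

public class ErrorHandlingSketch {
    public static void main(String[] args) throws InterruptedException {
        Properties props = new Properties();
        props.put(AdminClientConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9092");   // placeholder broker

        try (Admin admin = Admin.create(props)) {
            try {
                admin.describeCluster().nodes().get();
            } catch (ExecutionException e) {
                Throwable cause = e.getCause();
                if (cause instanceof RetriableException) {
                    // Transient (e.g. a coordinator still loading or a broker briefly unavailable):
                    // back off and try the call again.
                } else if (cause instanceof AuthenticationException || cause instanceof AuthorizationException) {
                    // Wrong credentials or missing ACLs: retrying with the same configuration will not help.
                    throw new RuntimeException(cause);
                } else {
                    throw new RuntimeException(cause);
                }
            }
        }
    }
}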
+ */ +package org.apache.kafka.common.errors; + +public class BrokerIdNotRegisteredException extends ApiException { + + public BrokerIdNotRegisteredException(String message) { + super(message); + } + + public BrokerIdNotRegisteredException(String message, Throwable throwable) { + super(message, throwable); + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/errors/BrokerNotAvailableException.java b/clients/src/main/java/org/apache/kafka/common/errors/BrokerNotAvailableException.java new file mode 100644 index 0000000..26bb803 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/errors/BrokerNotAvailableException.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.errors; + +public class BrokerNotAvailableException extends ApiException { + + private static final long serialVersionUID = 1L; + + public BrokerNotAvailableException(String message) { + super(message); + } + + public BrokerNotAvailableException(String message, Throwable cause) { + super(message, cause); + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/errors/ClusterAuthorizationException.java b/clients/src/main/java/org/apache/kafka/common/errors/ClusterAuthorizationException.java new file mode 100644 index 0000000..61b8929 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/errors/ClusterAuthorizationException.java @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.errors; + +public class ClusterAuthorizationException extends AuthorizationException { + + private static final long serialVersionUID = 1L; + + public ClusterAuthorizationException(String message) { + super(message); + } + + public ClusterAuthorizationException(String message, Throwable cause) { + super(message, cause); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/errors/ConcurrentTransactionsException.java b/clients/src/main/java/org/apache/kafka/common/errors/ConcurrentTransactionsException.java new file mode 100644 index 0000000..6ad6b8a --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/errors/ConcurrentTransactionsException.java @@ -0,0 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.errors; + +public class ConcurrentTransactionsException extends ApiException { + private static final long serialVersionUID = 1L; + + public ConcurrentTransactionsException(final String message) { + super(message); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/errors/ControllerMovedException.java b/clients/src/main/java/org/apache/kafka/common/errors/ControllerMovedException.java new file mode 100644 index 0000000..124e793 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/errors/ControllerMovedException.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.errors; + +public class ControllerMovedException extends ApiException { + + private static final long serialVersionUID = 1L; + + public ControllerMovedException(String message) { + super(message); + } + + public ControllerMovedException(String message, Throwable cause) { + super(message, cause); + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/errors/CoordinatorLoadInProgressException.java b/clients/src/main/java/org/apache/kafka/common/errors/CoordinatorLoadInProgressException.java new file mode 100644 index 0000000..4bdb978 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/errors/CoordinatorLoadInProgressException.java @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.errors; + +/** + * In the context of the group coordinator, the broker returns this error code for any coordinator request if + * it is still loading the group metadata (e.g. after a leader change for that group metadata topic partition). + * + * In the context of the transactional coordinator, this error will be returned if there is a pending transactional + * request with the same transactional id, or if the transaction cache is currently being populated from the transaction + * log. + */ +public class CoordinatorLoadInProgressException extends RetriableException { + + private static final long serialVersionUID = 1L; + + public CoordinatorLoadInProgressException(String message) { + super(message); + } + + public CoordinatorLoadInProgressException(String message, Throwable cause) { + super(message, cause); + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/errors/CoordinatorNotAvailableException.java b/clients/src/main/java/org/apache/kafka/common/errors/CoordinatorNotAvailableException.java new file mode 100644 index 0000000..827ce54 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/errors/CoordinatorNotAvailableException.java @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.errors; + +/** + * In the context of the group coordinator, the broker returns this error code for metadata or offset commit + * requests if the group metadata topic has not been created yet. + * + * In the context of the transactional coordinator, this error will be returned if the underlying transactional log + * is under replicated or if an append to the log times out. + */ +public class CoordinatorNotAvailableException extends RetriableException { + public static final CoordinatorNotAvailableException INSTANCE = new CoordinatorNotAvailableException(); + + private static final long serialVersionUID = 1L; + + private CoordinatorNotAvailableException() { + super(); + } + + public CoordinatorNotAvailableException(String message) { + super(message); + } + + public CoordinatorNotAvailableException(String message, Throwable cause) { + super(message, cause); + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/errors/CorruptRecordException.java b/clients/src/main/java/org/apache/kafka/common/errors/CorruptRecordException.java new file mode 100644 index 0000000..abcf516 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/errors/CorruptRecordException.java @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.errors; + +/** + * This exception indicates a record has failed its internal CRC check, this generally indicates network or disk + * corruption. + */ +public class CorruptRecordException extends RetriableException { + + private static final long serialVersionUID = 1L; + + public CorruptRecordException() { + super("This message has failed its CRC checksum, exceeds the valid size, has a null key for a compacted topic, or is otherwise corrupt."); + } + + public CorruptRecordException(String message) { + super(message); + } + + public CorruptRecordException(Throwable cause) { + super(cause); + } + + public CorruptRecordException(String message, Throwable cause) { + super(message, cause); + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/errors/DelegationTokenAuthorizationException.java b/clients/src/main/java/org/apache/kafka/common/errors/DelegationTokenAuthorizationException.java new file mode 100644 index 0000000..ddc97c6 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/errors/DelegationTokenAuthorizationException.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.errors; + +public class DelegationTokenAuthorizationException extends AuthorizationException { + + private static final long serialVersionUID = 1L; + + public DelegationTokenAuthorizationException(String message) { + super(message); + } + + public DelegationTokenAuthorizationException(String message, Throwable cause) { + super(message, cause); + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/errors/DelegationTokenDisabledException.java b/clients/src/main/java/org/apache/kafka/common/errors/DelegationTokenDisabledException.java new file mode 100644 index 0000000..798611e --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/errors/DelegationTokenDisabledException.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.errors; + +public class DelegationTokenDisabledException extends ApiException { + + private static final long serialVersionUID = 1L; + + public DelegationTokenDisabledException(String message) { + super(message); + } + + public DelegationTokenDisabledException(String message, Throwable cause) { + super(message, cause); + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/errors/DelegationTokenExpiredException.java b/clients/src/main/java/org/apache/kafka/common/errors/DelegationTokenExpiredException.java new file mode 100644 index 0000000..4dae7f3 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/errors/DelegationTokenExpiredException.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.errors; + +public class DelegationTokenExpiredException extends ApiException { + + private static final long serialVersionUID = 1L; + + public DelegationTokenExpiredException(String message) { + super(message); + } + + public DelegationTokenExpiredException(String message, Throwable cause) { + super(message, cause); + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/errors/DelegationTokenNotFoundException.java b/clients/src/main/java/org/apache/kafka/common/errors/DelegationTokenNotFoundException.java new file mode 100644 index 0000000..5875edf --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/errors/DelegationTokenNotFoundException.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.errors; + +public class DelegationTokenNotFoundException extends ApiException { + + private static final long serialVersionUID = 1L; + + public DelegationTokenNotFoundException(String message) { + super(message); + } + + public DelegationTokenNotFoundException(String message, Throwable cause) { + super(message, cause); + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/errors/DelegationTokenOwnerMismatchException.java b/clients/src/main/java/org/apache/kafka/common/errors/DelegationTokenOwnerMismatchException.java new file mode 100644 index 0000000..5c8239e --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/errors/DelegationTokenOwnerMismatchException.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.errors; + +public class DelegationTokenOwnerMismatchException extends ApiException { + + private static final long serialVersionUID = 1L; + + public DelegationTokenOwnerMismatchException(String message) { + super(message); + } + + public DelegationTokenOwnerMismatchException(String message, Throwable cause) { + super(message, cause); + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/errors/DisconnectException.java b/clients/src/main/java/org/apache/kafka/common/errors/DisconnectException.java new file mode 100644 index 0000000..e0bc787 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/errors/DisconnectException.java @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.errors; + + +/** + * Server disconnected before a request could be completed. + */ +public class DisconnectException extends RetriableException { + public static final DisconnectException INSTANCE = new DisconnectException(); + + private static final long serialVersionUID = 1L; + + public DisconnectException() { + super(); + } + + public DisconnectException(String message, Throwable cause) { + super(message, cause); + } + + public DisconnectException(String message) { + super(message); + } + + public DisconnectException(Throwable cause) { + super(cause); + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/errors/DuplicateBrokerRegistrationException.java b/clients/src/main/java/org/apache/kafka/common/errors/DuplicateBrokerRegistrationException.java new file mode 100644 index 0000000..06f3820 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/errors/DuplicateBrokerRegistrationException.java @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.errors; + +public class DuplicateBrokerRegistrationException extends ApiException { + + public DuplicateBrokerRegistrationException(String message) { + super(message); + } + + public DuplicateBrokerRegistrationException(String message, Throwable throwable) { + super(message, throwable); + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/errors/DuplicateResourceException.java b/clients/src/main/java/org/apache/kafka/common/errors/DuplicateResourceException.java new file mode 100644 index 0000000..1c0ec43 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/errors/DuplicateResourceException.java @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.errors; + +/** + * Exception thrown due to a request that illegally refers to the same resource twice + * (for example, trying to both create and delete the same SCRAM credential for a particular user in a single request). + */ +public class DuplicateResourceException extends ApiException { + + private static final long serialVersionUID = 1L; + + private final String resource; + + /** + * Constructor + * + * @param message the exception's message + */ + public DuplicateResourceException(String message) { + this(null, message); + } + + /** + * + * @param message the exception's message + * @param cause the exception's cause + */ + public DuplicateResourceException(String message, Throwable cause) { + this(null, message, cause); + } + + /** + * Constructor + * + * @param resource the (potentially null) resource that was referred to twice + * @param message the exception's message + */ + public DuplicateResourceException(String resource, String message) { + super(message); + this.resource = resource; + } + + /** + * Constructor + * + * @param resource the (potentially null) resource that was referred to twice + * @param message the exception's message + * @param cause the exception's cause + */ + public DuplicateResourceException(String resource, String message, Throwable cause) { + super(message, cause); + this.resource = resource; + } + + /** + * + * @return the (potentially null) resource that was referred to twice + */ + public String resource() { + return this.resource; + } +} \ No newline at end of file diff --git a/clients/src/main/java/org/apache/kafka/common/errors/DuplicateSequenceException.java b/clients/src/main/java/org/apache/kafka/common/errors/DuplicateSequenceException.java new file mode 100644 index 0000000..11f81af --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/errors/DuplicateSequenceException.java @@ -0,0 +1,24 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
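/*
 * Illustrative sketch (hypothetical example, not part of the Kafka sources): a minimal use of the
 * resource-aware constructor and resource() accessor defined on DuplicateResourceException above.
 * The user name and message are made-up values.
 */
import org.apache.kafka.common.errors.DuplicateResourceException;

public final class DuplicateResourceExample {
    public static void main(String[] args) {
        // The two-argument (resource, message) constructor records which resource was duplicated.
        DuplicateResourceException e =
                new DuplicateResourceException("alice", "User 'alice' appears twice in the request");
        // resource() may return null when the duplicated resource is not known.
        System.out.println("duplicated resource: " + e.resource());
    }
}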
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.errors; + +public class DuplicateSequenceException extends ApiException { + + public DuplicateSequenceException(String message) { + super(message); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/errors/ElectionNotNeededException.java b/clients/src/main/java/org/apache/kafka/common/errors/ElectionNotNeededException.java new file mode 100644 index 0000000..74fc7d6 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/errors/ElectionNotNeededException.java @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.errors; + +public class ElectionNotNeededException extends InvalidMetadataException { + + public ElectionNotNeededException(String message) { + super(message); + } + + public ElectionNotNeededException(String message, Throwable cause) { + super(message, cause); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/errors/EligibleLeadersNotAvailableException.java b/clients/src/main/java/org/apache/kafka/common/errors/EligibleLeadersNotAvailableException.java new file mode 100644 index 0000000..8767965 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/errors/EligibleLeadersNotAvailableException.java @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.errors; + +public class EligibleLeadersNotAvailableException extends InvalidMetadataException { + + public EligibleLeadersNotAvailableException(String message) { + super(message); + } + + public EligibleLeadersNotAvailableException(String message, Throwable cause) { + super(message, cause); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/errors/FeatureUpdateFailedException.java b/clients/src/main/java/org/apache/kafka/common/errors/FeatureUpdateFailedException.java new file mode 100644 index 0000000..9f5e23d --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/errors/FeatureUpdateFailedException.java @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.errors; + +public class FeatureUpdateFailedException extends ApiException { + private static final long serialVersionUID = 1L; + + public FeatureUpdateFailedException(final String message) { + super(message); + } + + public FeatureUpdateFailedException(final String message, final Throwable cause) { + super(message, cause); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/errors/FencedInstanceIdException.java b/clients/src/main/java/org/apache/kafka/common/errors/FencedInstanceIdException.java new file mode 100644 index 0000000..78e4034 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/errors/FencedInstanceIdException.java @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.errors; + +public class FencedInstanceIdException extends ApiException { + private static final long serialVersionUID = 1L; + + public FencedInstanceIdException(String message) { + super(message); + } + + public FencedInstanceIdException(String message, Throwable cause) { + super(message, cause); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/errors/FencedLeaderEpochException.java b/clients/src/main/java/org/apache/kafka/common/errors/FencedLeaderEpochException.java new file mode 100644 index 0000000..24f0eef --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/errors/FencedLeaderEpochException.java @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.errors; + +/** + * The request contained a leader epoch which is smaller than that on the broker that received the + * request. This can happen when an operation is attempted before a pending metadata update has been + * received. Clients will typically refresh metadata before retrying. + */ +public class FencedLeaderEpochException extends InvalidMetadataException { + private static final long serialVersionUID = 1L; + + public FencedLeaderEpochException(String message) { + super(message); + } + + public FencedLeaderEpochException(String message, Throwable cause) { + super(message, cause); + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/errors/FetchSessionIdNotFoundException.java b/clients/src/main/java/org/apache/kafka/common/errors/FetchSessionIdNotFoundException.java new file mode 100644 index 0000000..2ce5f74 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/errors/FetchSessionIdNotFoundException.java @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.common.errors; + +public class FetchSessionIdNotFoundException extends RetriableException { + private static final long serialVersionUID = 1L; + + public FetchSessionIdNotFoundException() { + } + + public FetchSessionIdNotFoundException(String message) { + super(message); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/errors/FetchSessionTopicIdException.java b/clients/src/main/java/org/apache/kafka/common/errors/FetchSessionTopicIdException.java new file mode 100644 index 0000000..11a6e1d --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/errors/FetchSessionTopicIdException.java @@ -0,0 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.errors; + +public class FetchSessionTopicIdException extends RetriableException { + private static final long serialVersionUID = 1L; + + public FetchSessionTopicIdException(String message) { + super(message); + } +} \ No newline at end of file diff --git a/clients/src/main/java/org/apache/kafka/common/errors/GroupAuthorizationException.java b/clients/src/main/java/org/apache/kafka/common/errors/GroupAuthorizationException.java new file mode 100644 index 0000000..22eae3b --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/errors/GroupAuthorizationException.java @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.errors; + +public class GroupAuthorizationException extends AuthorizationException { + private final String groupId; + + public GroupAuthorizationException(String message, String groupId) { + super(message); + this.groupId = groupId; + } + + public GroupAuthorizationException(String message) { + this(message, null); + } + + /** + * Return the group ID that failed authorization. May be null if it is not known + * in the context the exception was raised in. 
+ * + * @return nullable groupId + */ + public String groupId() { + return groupId; + } + + public static GroupAuthorizationException forGroupId(String groupId) { + return new GroupAuthorizationException("Not authorized to access group: " + groupId, groupId); + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/errors/GroupIdNotFoundException.java b/clients/src/main/java/org/apache/kafka/common/errors/GroupIdNotFoundException.java new file mode 100644 index 0000000..a4d509d --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/errors/GroupIdNotFoundException.java @@ -0,0 +1,23 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.errors; + +public class GroupIdNotFoundException extends ApiException { + public GroupIdNotFoundException(String message) { + super(message); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/errors/GroupMaxSizeReachedException.java b/clients/src/main/java/org/apache/kafka/common/errors/GroupMaxSizeReachedException.java new file mode 100644 index 0000000..85d0c7d --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/errors/GroupMaxSizeReachedException.java @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
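/*
 * Illustrative sketch (hypothetical example): KafkaConsumer.poll() surfaces GroupAuthorizationException
 * (defined above) when the principal is not authorized for its consumer group, and groupId() identifies
 * the group that failed authorization. The class name, method, and poll timeout are assumptions; the
 * consumer is assumed to be configured and subscribed elsewhere.
 */
import java.time.Duration;

import org.apache.kafka.clients.consumer.KafkaConsumer;
import org.apache.kafka.common.errors.GroupAuthorizationException;

public final class GroupAuthorizationExample {

    static void pollOnce(KafkaConsumer<String, String> consumer) {
        try {
            consumer.poll(Duration.ofSeconds(1));
        } catch (GroupAuthorizationException e) {
            // groupId() may be null if the group was not known in the context that raised the error.
            System.err.println("Not authorized for consumer group: " + e.groupId());
        }
    }
}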
+ */ +package org.apache.kafka.common.errors; + +/** + * Indicates that a consumer group is already at its configured maximum capacity and cannot accommodate more members + */ +public class GroupMaxSizeReachedException extends ApiException { + private static final long serialVersionUID = 1L; + + public GroupMaxSizeReachedException(String message) { + super(message); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/errors/GroupNotEmptyException.java b/clients/src/main/java/org/apache/kafka/common/errors/GroupNotEmptyException.java new file mode 100644 index 0000000..e15b3e6 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/errors/GroupNotEmptyException.java @@ -0,0 +1,23 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.errors; + +public class GroupNotEmptyException extends ApiException { + public GroupNotEmptyException(String message) { + super(message); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/errors/GroupSubscribedToTopicException.java b/clients/src/main/java/org/apache/kafka/common/errors/GroupSubscribedToTopicException.java new file mode 100644 index 0000000..a62fe32 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/errors/GroupSubscribedToTopicException.java @@ -0,0 +1,23 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.errors; + +public class GroupSubscribedToTopicException extends ApiException { + public GroupSubscribedToTopicException(String message) { + super(message); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/errors/IllegalGenerationException.java b/clients/src/main/java/org/apache/kafka/common/errors/IllegalGenerationException.java new file mode 100644 index 0000000..efd749f --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/errors/IllegalGenerationException.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.errors; + +public class IllegalGenerationException extends ApiException { + private static final long serialVersionUID = 1L; + + public IllegalGenerationException() { + super(); + } + + public IllegalGenerationException(String message, Throwable cause) { + super(message, cause); + } + + public IllegalGenerationException(String message) { + super(message); + } + + public IllegalGenerationException(Throwable cause) { + super(cause); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/errors/IllegalSaslStateException.java b/clients/src/main/java/org/apache/kafka/common/errors/IllegalSaslStateException.java new file mode 100644 index 0000000..691244a --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/errors/IllegalSaslStateException.java @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.errors; + +/** + * This exception indicates unexpected requests prior to SASL authentication. + * This could be due to misconfigured security, e.g. if PLAINTEXT protocol + * is used to connect to a SASL endpoint. + */ +public class IllegalSaslStateException extends AuthenticationException { + + private static final long serialVersionUID = 1L; + + public IllegalSaslStateException(String message) { + super(message); + } + + public IllegalSaslStateException(String message, Throwable cause) { + super(message, cause); + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/errors/InconsistentClusterIdException.java b/clients/src/main/java/org/apache/kafka/common/errors/InconsistentClusterIdException.java new file mode 100644 index 0000000..62fed41 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/errors/InconsistentClusterIdException.java @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
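/*
 * Illustrative sketch (hypothetical example): IllegalSaslStateException above usually points at a
 * protocol mismatch, such as a PLAINTEXT client talking to a SASL listener. The properties below show
 * one plausible SASL/PLAIN client configuration; the host, port, and credentials are placeholders.
 */
import java.util.Properties;

public final class SaslClientConfigExample {

    static Properties saslPlainClientProps() {
        Properties props = new Properties();
        props.put("bootstrap.servers", "broker.example.com:9093");   // the SASL listener, not a PLAINTEXT one
        props.put("security.protocol", "SASL_PLAINTEXT");            // or SASL_SSL when TLS is enabled
        props.put("sasl.mechanism", "PLAIN");
        props.put("sasl.jaas.config",
                "org.apache.kafka.common.security.plain.PlainLoginModule required "
                        + "username=\"client\" password=\"client-secret\";");
        return props;
    }
}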
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.errors; + +public class InconsistentClusterIdException extends ApiException { + + public InconsistentClusterIdException(String message) { + super(message); + } + + public InconsistentClusterIdException(String message, Throwable throwable) { + super(message, throwable); + } +} \ No newline at end of file diff --git a/clients/src/main/java/org/apache/kafka/common/errors/InconsistentGroupProtocolException.java b/clients/src/main/java/org/apache/kafka/common/errors/InconsistentGroupProtocolException.java new file mode 100644 index 0000000..28bcbe5 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/errors/InconsistentGroupProtocolException.java @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.errors; + +public class InconsistentGroupProtocolException extends ApiException { + private static final long serialVersionUID = 1L; + + public InconsistentGroupProtocolException(String message, Throwable cause) { + super(message, cause); + } + + public InconsistentGroupProtocolException(String message) { + super(message); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/errors/InconsistentTopicIdException.java b/clients/src/main/java/org/apache/kafka/common/errors/InconsistentTopicIdException.java new file mode 100644 index 0000000..1dfe468 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/errors/InconsistentTopicIdException.java @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.errors; + +public class InconsistentTopicIdException extends InvalidMetadataException { + + private static final long serialVersionUID = 1L; + + public InconsistentTopicIdException(String message) { + super(message); + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/errors/InconsistentVoterSetException.java b/clients/src/main/java/org/apache/kafka/common/errors/InconsistentVoterSetException.java new file mode 100644 index 0000000..8a3667b --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/errors/InconsistentVoterSetException.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.errors; + +public class InconsistentVoterSetException extends ApiException { + + private static final long serialVersionUID = 1; + + public InconsistentVoterSetException(String s) { + super(s); + } + + public InconsistentVoterSetException(String message, Throwable cause) { + super(message, cause); + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/errors/InterruptException.java b/clients/src/main/java/org/apache/kafka/common/errors/InterruptException.java new file mode 100644 index 0000000..fec66bb --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/errors/InterruptException.java @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.errors; + +import org.apache.kafka.common.KafkaException; + +/** + * An unchecked wrapper for InterruptedException + */ +public class InterruptException extends KafkaException { + + private static final long serialVersionUID = 1L; + + public InterruptException(InterruptedException cause) { + super(cause); + Thread.currentThread().interrupt(); + } + + public InterruptException(String message, InterruptedException cause) { + super(message, cause); + Thread.currentThread().interrupt(); + } + + public InterruptException(String message) { + super(message, new InterruptedException()); + Thread.currentThread().interrupt(); + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/errors/InvalidCommitOffsetSizeException.java b/clients/src/main/java/org/apache/kafka/common/errors/InvalidCommitOffsetSizeException.java new file mode 100644 index 0000000..a17a30f --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/errors/InvalidCommitOffsetSizeException.java @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.errors; + +public class InvalidCommitOffsetSizeException extends ApiException { + private static final long serialVersionUID = 1L; + + public InvalidCommitOffsetSizeException(String message, Throwable cause) { + super(message, cause); + } + + public InvalidCommitOffsetSizeException(String message) { + super(message); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/errors/InvalidConfigurationException.java b/clients/src/main/java/org/apache/kafka/common/errors/InvalidConfigurationException.java new file mode 100644 index 0000000..333566a --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/errors/InvalidConfigurationException.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
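/*
 * Illustrative sketch (hypothetical example): blocking client calls such as KafkaProducer.flush()
 * translate InterruptedException into the unchecked InterruptException defined above, and its
 * constructors re-set the thread's interrupt flag. The handling shown is one reasonable pattern,
 * not one prescribed by the client.
 */
import org.apache.kafka.clients.producer.KafkaProducer;
import org.apache.kafka.common.errors.InterruptException;

public final class InterruptHandlingExample {

    static void flushInterruptibly(KafkaProducer<String, String> producer) {
        try {
            producer.flush(); // blocking call; throws InterruptException if the thread is interrupted
        } catch (InterruptException e) {
            // The interrupt flag has already been restored by the exception's constructor, so callers
            // that check Thread.currentThread().isInterrupted() will still observe the interruption.
        }
    }
}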
+ */ +package org.apache.kafka.common.errors; + +public class InvalidConfigurationException extends ApiException { + + private static final long serialVersionUID = 1L; + + public InvalidConfigurationException(String message) { + super(message); + } + + public InvalidConfigurationException(String message, Throwable cause) { + super(message, cause); + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/errors/InvalidFetchSessionEpochException.java b/clients/src/main/java/org/apache/kafka/common/errors/InvalidFetchSessionEpochException.java new file mode 100644 index 0000000..3b135c0 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/errors/InvalidFetchSessionEpochException.java @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.common.errors; + +public class InvalidFetchSessionEpochException extends RetriableException { + private static final long serialVersionUID = 1L; + + public InvalidFetchSessionEpochException() { + } + + public InvalidFetchSessionEpochException(String message) { + super(message); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/errors/InvalidFetchSizeException.java b/clients/src/main/java/org/apache/kafka/common/errors/InvalidFetchSizeException.java new file mode 100644 index 0000000..65a0aeb --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/errors/InvalidFetchSizeException.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.errors; + +public class InvalidFetchSizeException extends ApiException { + + private static final long serialVersionUID = 1L; + + public InvalidFetchSizeException(String message) { + super(message); + } + + public InvalidFetchSizeException(String message, Throwable cause) { + super(message, cause); + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/errors/InvalidGroupIdException.java b/clients/src/main/java/org/apache/kafka/common/errors/InvalidGroupIdException.java new file mode 100644 index 0000000..95e6f36 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/errors/InvalidGroupIdException.java @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.errors; + +public class InvalidGroupIdException extends ApiException { + private static final long serialVersionUID = 1L; + + public InvalidGroupIdException(String message, Throwable cause) { + super(message, cause); + } + + public InvalidGroupIdException(String message) { + super(message); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/errors/InvalidMetadataException.java b/clients/src/main/java/org/apache/kafka/common/errors/InvalidMetadataException.java new file mode 100644 index 0000000..e6663db --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/errors/InvalidMetadataException.java @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.errors; + +/** + * An exception that may indicate the client's metadata is out of date + */ +public abstract class InvalidMetadataException extends RetriableException { + + private static final long serialVersionUID = 1L; + + public InvalidMetadataException() { + super(); + } + + public InvalidMetadataException(String message) { + super(message); + } + + public InvalidMetadataException(String message, Throwable cause) { + super(message, cause); + } + + public InvalidMetadataException(Throwable cause) { + super(cause); + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/errors/InvalidOffsetException.java b/clients/src/main/java/org/apache/kafka/common/errors/InvalidOffsetException.java new file mode 100644 index 0000000..0d954f1 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/errors/InvalidOffsetException.java @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.errors; + +/** + * Thrown when the offset for a set of partitions is invalid (either undefined or out of range), + * and no reset policy has been configured. + * @see OffsetOutOfRangeException + */ +public class InvalidOffsetException extends ApiException { + + private static final long serialVersionUID = 1L; + + public InvalidOffsetException(String message) { + super(message); + } + + public InvalidOffsetException(String message, Throwable cause) { + super(message, cause); + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/errors/InvalidPartitionsException.java b/clients/src/main/java/org/apache/kafka/common/errors/InvalidPartitionsException.java new file mode 100644 index 0000000..c65ced4 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/errors/InvalidPartitionsException.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.errors; + +public class InvalidPartitionsException extends ApiException { + + private static final long serialVersionUID = 1L; + + public InvalidPartitionsException(String message) { + super(message); + } + + public InvalidPartitionsException(String message, Throwable cause) { + super(message, cause); + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/errors/InvalidPidMappingException.java b/clients/src/main/java/org/apache/kafka/common/errors/InvalidPidMappingException.java new file mode 100644 index 0000000..69fb71e --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/errors/InvalidPidMappingException.java @@ -0,0 +1,23 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.errors; + +public class InvalidPidMappingException extends ApiException { + public InvalidPidMappingException(String message) { + super(message); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/errors/InvalidPrincipalTypeException.java b/clients/src/main/java/org/apache/kafka/common/errors/InvalidPrincipalTypeException.java new file mode 100644 index 0000000..a0736e9 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/errors/InvalidPrincipalTypeException.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.errors; + +public class InvalidPrincipalTypeException extends ApiException { + + private static final long serialVersionUID = 1L; + + public InvalidPrincipalTypeException(String message) { + super(message); + } + + public InvalidPrincipalTypeException(String message, Throwable cause) { + super(message, cause); + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/errors/InvalidProducerEpochException.java b/clients/src/main/java/org/apache/kafka/common/errors/InvalidProducerEpochException.java new file mode 100644 index 0000000..79b8236 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/errors/InvalidProducerEpochException.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.errors; + +/** + * This exception indicates that the produce request sent to the partition leader + * contains a non-matching producer epoch. When encountering this exception, the user should abort the ongoing transaction + * by calling KafkaProducer#abortTransaction, which will try to send an initPidRequest and reinitialize the producer + * under the hood. + */ +public class InvalidProducerEpochException extends ApiException { + + private static final long serialVersionUID = 1L; + + public InvalidProducerEpochException(String message) { + super(message); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/errors/InvalidReplicaAssignmentException.java b/clients/src/main/java/org/apache/kafka/common/errors/InvalidReplicaAssignmentException.java new file mode 100644 index 0000000..5357d91 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/errors/InvalidReplicaAssignmentException.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
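/*
 * Illustrative sketch (hypothetical example): the Javadoc above directs the user to abort the ongoing
 * transaction when InvalidProducerEpochException is raised. The method below follows the standard
 * transactional-producer pattern; the class name, topic, record, and the assumption that the producer
 * is configured with a transactional.id and has already called initTransactions() are made up for the example.
 */
import org.apache.kafka.clients.producer.KafkaProducer;
import org.apache.kafka.clients.producer.ProducerRecord;
import org.apache.kafka.common.KafkaException;
import org.apache.kafka.common.errors.ProducerFencedException;

public final class TransactionalSendExample {

    static void sendInTransaction(KafkaProducer<String, String> producer) {
        try {
            producer.beginTransaction();
            producer.send(new ProducerRecord<>("demo-topic", "key", "value"));
            producer.commitTransaction();
        } catch (ProducerFencedException fatal) {
            producer.close(); // fenced by another producer instance: this producer cannot continue
        } catch (KafkaException e) {
            // Covers InvalidProducerEpochException among other abortable errors:
            // abort the transaction so the application can retry it.
            producer.abortTransaction();
        }
    }
}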
+ */ +package org.apache.kafka.common.errors; + +public class InvalidReplicaAssignmentException extends ApiException { + + private static final long serialVersionUID = 1L; + + public InvalidReplicaAssignmentException(String message) { + super(message); + } + + public InvalidReplicaAssignmentException(String message, Throwable cause) { + super(message, cause); + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/errors/InvalidReplicationFactorException.java b/clients/src/main/java/org/apache/kafka/common/errors/InvalidReplicationFactorException.java new file mode 100644 index 0000000..699d5a8 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/errors/InvalidReplicationFactorException.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.errors; + +public class InvalidReplicationFactorException extends ApiException { + + private static final long serialVersionUID = 1L; + + public InvalidReplicationFactorException(String message) { + super(message); + } + + public InvalidReplicationFactorException(String message, Throwable cause) { + super(message, cause); + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/errors/InvalidRequestException.java b/clients/src/main/java/org/apache/kafka/common/errors/InvalidRequestException.java new file mode 100644 index 0000000..7470f66 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/errors/InvalidRequestException.java @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.errors; + +/** + * Thrown when a request breaks basic wire protocol rules. + * This most likely occurs because of a request being malformed by the client library or + * the message was sent to an incompatible broker. 
+ */ +public class InvalidRequestException extends ApiException { + + private static final long serialVersionUID = 1L; + + public InvalidRequestException(String message) { + super(message); + } + + public InvalidRequestException(String message, Throwable cause) { + super(message, cause); + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/errors/InvalidRequiredAcksException.java b/clients/src/main/java/org/apache/kafka/common/errors/InvalidRequiredAcksException.java new file mode 100644 index 0000000..423c091 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/errors/InvalidRequiredAcksException.java @@ -0,0 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.errors; + +public class InvalidRequiredAcksException extends ApiException { + private static final long serialVersionUID = 1L; + + public InvalidRequiredAcksException(String message) { + super(message); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/errors/InvalidSessionTimeoutException.java b/clients/src/main/java/org/apache/kafka/common/errors/InvalidSessionTimeoutException.java new file mode 100644 index 0000000..a971498 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/errors/InvalidSessionTimeoutException.java @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.errors; + +public class InvalidSessionTimeoutException extends ApiException { + private static final long serialVersionUID = 1L; + + public InvalidSessionTimeoutException(String message, Throwable cause) { + super(message, cause); + } + + public InvalidSessionTimeoutException(String message) { + super(message); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/errors/InvalidTimestampException.java b/clients/src/main/java/org/apache/kafka/common/errors/InvalidTimestampException.java new file mode 100644 index 0000000..0e3cd92 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/errors/InvalidTimestampException.java @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.errors; + +/** + * Indicate the timestamp of a record is invalid. + */ +public class InvalidTimestampException extends ApiException { + + private static final long serialVersionUID = 1L; + + public InvalidTimestampException(String message) { + super(message); + } + + public InvalidTimestampException(String message, Throwable cause) { + super(message, cause); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/errors/InvalidTopicException.java b/clients/src/main/java/org/apache/kafka/common/errors/InvalidTopicException.java new file mode 100644 index 0000000..344d231 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/errors/InvalidTopicException.java @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.errors; + +import java.util.HashSet; +import java.util.Set; + + +/** + * The client has attempted to perform an operation on an invalid topic. + * For example the topic name is too long, contains invalid characters etc. + * This exception is not retriable because the operation won't suddenly become valid. 
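+ * <p>
+ * A hedged illustration (the {@code Admin} client, {@code props}, and topic name below are assumptions
+ * for the example, not part of this class): invalid names are typically reported wrapped in an
+ * {@code ExecutionException} when the admin result future is inspected.
+ * <pre>{@code
+ * try (Admin admin = Admin.create(props)) {
+ *     admin.createTopics(Collections.singleton(new NewTopic("bad/topic?name", 1, (short) 1)))
+ *          .all()
+ *          .get();
+ * } catch (InterruptedException | ExecutionException e) {
+ *     if (e.getCause() instanceof InvalidTopicException) {
+ *         // Not retriable: correct the topic name before trying again.
+ *     }
+ * }
+ * }</pre>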
+ *
+ * @see UnknownTopicOrPartitionException
+ */
+public class InvalidTopicException extends ApiException {
+    private static final long serialVersionUID = 1L;
+
+    private final Set<String> invalidTopics;
+
+    public InvalidTopicException() {
+        super();
+        invalidTopics = new HashSet<>();
+    }
+
+    public InvalidTopicException(String message, Throwable cause) {
+        super(message, cause);
+        invalidTopics = new HashSet<>();
+    }
+
+    public InvalidTopicException(String message) {
+        super(message);
+        invalidTopics = new HashSet<>();
+    }
+
+    public InvalidTopicException(Throwable cause) {
+        super(cause);
+        invalidTopics = new HashSet<>();
+    }
+
+    public InvalidTopicException(Set<String> invalidTopics) {
+        super("Invalid topics: " + invalidTopics);
+        this.invalidTopics = invalidTopics;
+    }
+
+    public InvalidTopicException(String message, Set<String> invalidTopics) {
+        super(message);
+        this.invalidTopics = invalidTopics;
+    }
+
+    public Set<String> invalidTopics() {
+        return invalidTopics;
+    }
+}
diff --git a/clients/src/main/java/org/apache/kafka/common/errors/InvalidTxnStateException.java b/clients/src/main/java/org/apache/kafka/common/errors/InvalidTxnStateException.java
new file mode 100644
index 0000000..ff06904
--- /dev/null
+++ b/clients/src/main/java/org/apache/kafka/common/errors/InvalidTxnStateException.java
@@ -0,0 +1,23 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.common.errors;
+
+public class InvalidTxnStateException extends ApiException {
+    public InvalidTxnStateException(String message) {
+        super(message);
+    }
+}
diff --git a/clients/src/main/java/org/apache/kafka/common/errors/InvalidTxnTimeoutException.java b/clients/src/main/java/org/apache/kafka/common/errors/InvalidTxnTimeoutException.java
new file mode 100644
index 0000000..f16af66
--- /dev/null
+++ b/clients/src/main/java/org/apache/kafka/common/errors/InvalidTxnTimeoutException.java
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +package org.apache.kafka.common.errors; + +/** + * The transaction coordinator returns this error code if the timeout received via the InitProducerIdRequest is larger than + * the `transaction.max.timeout.ms` config value. + */ +public class InvalidTxnTimeoutException extends ApiException { + private static final long serialVersionUID = 1L; + + public InvalidTxnTimeoutException(String message, Throwable cause) { + super(message, cause); + } + + public InvalidTxnTimeoutException(String message) { + super(message); + } +} + diff --git a/clients/src/main/java/org/apache/kafka/common/errors/InvalidUpdateVersionException.java b/clients/src/main/java/org/apache/kafka/common/errors/InvalidUpdateVersionException.java new file mode 100644 index 0000000..e41262d --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/errors/InvalidUpdateVersionException.java @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.errors; + +public class InvalidUpdateVersionException extends ApiException { + + public InvalidUpdateVersionException(String message) { + super(message); + } + + public InvalidUpdateVersionException(String message, Throwable throwable) { + super(message, throwable); + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/errors/KafkaStorageException.java b/clients/src/main/java/org/apache/kafka/common/errors/KafkaStorageException.java new file mode 100644 index 0000000..c45afb0 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/errors/KafkaStorageException.java @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.errors; + +/** + * Miscellaneous disk-related IOException occurred when handling a request. 
+ * Client should request metadata update and retry if the response shows KafkaStorageException + * + * Here are the guidelines on how to handle KafkaStorageException and IOException: + * + * 1) If the server has not finished loading logs, IOException does not need to be converted to KafkaStorageException + * 2) After the server has finished loading logs, IOException should be caught and trigger LogDirFailureChannel.maybeAddOfflineLogDir() + * Then the IOException should either be swallowed and logged, or be converted and re-thrown as KafkaStorageException + * 3) It is preferred for IOException to be caught in Log rather than in ReplicaManager or LogSegment. + * + */ +public class KafkaStorageException extends InvalidMetadataException { + + private static final long serialVersionUID = 1L; + + public KafkaStorageException() { + super(); + } + + public KafkaStorageException(String message) { + super(message); + } + + public KafkaStorageException(Throwable cause) { + super(cause); + } + + public KafkaStorageException(String message, Throwable cause) { + super(message, cause); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/errors/LeaderNotAvailableException.java b/clients/src/main/java/org/apache/kafka/common/errors/LeaderNotAvailableException.java new file mode 100644 index 0000000..69bc624 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/errors/LeaderNotAvailableException.java @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.errors; + +/** + * There is no currently available leader for the given partition (either because a leadership election is in progress + * or because all replicas are down). + */ +public class LeaderNotAvailableException extends InvalidMetadataException { + + private static final long serialVersionUID = 1L; + + public LeaderNotAvailableException(String message) { + super(message); + } + + public LeaderNotAvailableException(String message, Throwable cause) { + super(message, cause); + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/errors/ListenerNotFoundException.java b/clients/src/main/java/org/apache/kafka/common/errors/ListenerNotFoundException.java new file mode 100644 index 0000000..82c5d89 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/errors/ListenerNotFoundException.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.errors; + +/** + * The leader does not have an endpoint corresponding to the listener on which metadata was requested. + * This could indicate a broker configuration error or a transient error when listeners are updated + * dynamically and client requests are processed before all brokers have updated their listeners. + * This is currently used only for missing listeners on leader brokers, but may be used for followers + * in future. + */ +public class ListenerNotFoundException extends InvalidMetadataException { + + private static final long serialVersionUID = 1L; + + public ListenerNotFoundException(String message) { + super(message); + } + + public ListenerNotFoundException(String message, Throwable cause) { + super(message, cause); + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/errors/LogDirNotFoundException.java b/clients/src/main/java/org/apache/kafka/common/errors/LogDirNotFoundException.java new file mode 100644 index 0000000..0a4ae16 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/errors/LogDirNotFoundException.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.errors; + +/** + * Thrown when a request is made for a log directory that is not present on the broker + */ +public class LogDirNotFoundException extends ApiException { + + private static final long serialVersionUID = 1L; + + public LogDirNotFoundException(String message) { + super(message); + } + + public LogDirNotFoundException(String message, Throwable cause) { + super(message, cause); + } + + public LogDirNotFoundException(Throwable cause) { + super(cause); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/errors/MemberIdRequiredException.java b/clients/src/main/java/org/apache/kafka/common/errors/MemberIdRequiredException.java new file mode 100644 index 0000000..55393e0 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/errors/MemberIdRequiredException.java @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.errors; + +public class MemberIdRequiredException extends ApiException { + + private static final long serialVersionUID = 1L; + + public MemberIdRequiredException(String message) { + super(message); + } + + public MemberIdRequiredException(String message, Throwable cause) { + super(message, cause); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/errors/NetworkException.java b/clients/src/main/java/org/apache/kafka/common/errors/NetworkException.java new file mode 100644 index 0000000..fadd9bd --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/errors/NetworkException.java @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.errors; + +/** + * A misc. network-related IOException occurred when making a request. This could be because the client's metadata is + * out of date and it is making a request to a node that is now dead. + */ +public class NetworkException extends InvalidMetadataException { + + private static final long serialVersionUID = 1L; + + public NetworkException() { + super(); + } + + public NetworkException(String message, Throwable cause) { + super(message, cause); + } + + public NetworkException(String message) { + super(message); + } + + public NetworkException(Throwable cause) { + super(cause); + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/errors/NoReassignmentInProgressException.java b/clients/src/main/java/org/apache/kafka/common/errors/NoReassignmentInProgressException.java new file mode 100644 index 0000000..9fd8a73 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/errors/NoReassignmentInProgressException.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.common.errors; + +/** + * Thrown if a reassignment cannot be cancelled because none is in progress. + */ +public class NoReassignmentInProgressException extends ApiException { + public NoReassignmentInProgressException(String message) { + super(message); + } + + public NoReassignmentInProgressException(String message, Throwable cause) { + super(message, cause); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/errors/NotControllerException.java b/clients/src/main/java/org/apache/kafka/common/errors/NotControllerException.java new file mode 100644 index 0000000..1c3e014 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/errors/NotControllerException.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.errors; + +public class NotControllerException extends RetriableException { + + private static final long serialVersionUID = 1L; + + public NotControllerException(String message) { + super(message); + } + + public NotControllerException(String message, Throwable cause) { + super(message, cause); + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/errors/NotCoordinatorException.java b/clients/src/main/java/org/apache/kafka/common/errors/NotCoordinatorException.java new file mode 100644 index 0000000..00ca32c --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/errors/NotCoordinatorException.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.errors; + +/** + * In the context of the group coordinator, the broker returns this error code if it receives an offset fetch + * or commit request for a group it's not the coordinator of. + * + * In the context of the transactional coordinator, it returns this error when it receives a transactional + * request with a transactionalId the coordinator doesn't own. + */ +public class NotCoordinatorException extends RetriableException { + + private static final long serialVersionUID = 1L; + + public NotCoordinatorException(String message) { + super(message); + } + + public NotCoordinatorException(String message, Throwable cause) { + super(message, cause); + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/errors/NotEnoughReplicasAfterAppendException.java b/clients/src/main/java/org/apache/kafka/common/errors/NotEnoughReplicasAfterAppendException.java new file mode 100644 index 0000000..22ebc34 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/errors/NotEnoughReplicasAfterAppendException.java @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.errors; + +/** + * Number of insync replicas for the partition is lower than min.insync.replicas This exception is raised when the low + * ISR size is discovered *after* the message was already appended to the log. Producer retries will cause duplicates. + */ +public class NotEnoughReplicasAfterAppendException extends RetriableException { + private static final long serialVersionUID = 1L; + + public NotEnoughReplicasAfterAppendException(String message) { + super(message); + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/errors/NotEnoughReplicasException.java b/clients/src/main/java/org/apache/kafka/common/errors/NotEnoughReplicasException.java new file mode 100644 index 0000000..cb90e86 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/errors/NotEnoughReplicasException.java @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.errors; + +/** + * Number of insync replicas for the partition is lower than min.insync.replicas + */ +public class NotEnoughReplicasException extends RetriableException { + private static final long serialVersionUID = 1L; + + public NotEnoughReplicasException() { + super(); + } + + public NotEnoughReplicasException(String message, Throwable cause) { + super(message, cause); + } + + public NotEnoughReplicasException(String message) { + super(message); + } + + public NotEnoughReplicasException(Throwable cause) { + super(cause); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/errors/NotLeaderForPartitionException.java b/clients/src/main/java/org/apache/kafka/common/errors/NotLeaderForPartitionException.java new file mode 100644 index 0000000..30efc49 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/errors/NotLeaderForPartitionException.java @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.errors; + +/** + * This server is not the leader for the given partition. + * @deprecated since 2.6. Use {@link NotLeaderOrFollowerException}. + */ +@Deprecated +public class NotLeaderForPartitionException extends InvalidMetadataException { + + private static final long serialVersionUID = 1L; + + public NotLeaderForPartitionException() { + super(); + } + + public NotLeaderForPartitionException(String message) { + super(message); + } + + public NotLeaderForPartitionException(Throwable cause) { + super(cause); + } + + public NotLeaderForPartitionException(String message, Throwable cause) { + super(message, cause); + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/errors/NotLeaderOrFollowerException.java b/clients/src/main/java/org/apache/kafka/common/errors/NotLeaderOrFollowerException.java new file mode 100644 index 0000000..2db960b --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/errors/NotLeaderOrFollowerException.java @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.errors; + +/** + * Broker returns this error if a request could not be processed because the broker is not the leader + * or follower for a topic partition. This could be a transient exception during leader elections and + * reassignments. For `Produce` and other requests which are intended only for the leader, this exception + * indicates that the broker is not the current leader. For consumer `Fetch` requests which may be + * satisfied by a leader or follower, this exception indicates that the broker is not a replica + * of the topic partition. + */ +@SuppressWarnings("deprecation") +public class NotLeaderOrFollowerException extends NotLeaderForPartitionException { + + private static final long serialVersionUID = 1L; + + public NotLeaderOrFollowerException() { + super(); + } + + public NotLeaderOrFollowerException(String message) { + super(message); + } + + public NotLeaderOrFollowerException(Throwable cause) { + super(cause); + } + + public NotLeaderOrFollowerException(String message, Throwable cause) { + super(message, cause); + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/errors/OffsetMetadataTooLarge.java b/clients/src/main/java/org/apache/kafka/common/errors/OffsetMetadataTooLarge.java new file mode 100644 index 0000000..b77f167 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/errors/OffsetMetadataTooLarge.java @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.errors; + +/** + * The client has tried to save its offset with associated metadata larger than the maximum size allowed by the server. 
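+ * <p>
+ * A hedged sketch (the consumer, partition, and metadata string are assumptions for the example):
+ * the metadata attached to a committed offset must stay within the broker's configured limit,
+ * commonly the {@code offset.metadata.max.bytes} broker setting.
+ * <pre>{@code
+ * consumer.commitSync(Collections.singletonMap(
+ *     new TopicPartition("orders", 0),
+ *     new OffsetAndMetadata(lastConsumedOffset + 1, "small bookkeeping note")));
+ * }</pre>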
+ */ +public class OffsetMetadataTooLarge extends ApiException { + + private static final long serialVersionUID = 1L; + + public OffsetMetadataTooLarge() { + } + + public OffsetMetadataTooLarge(String message) { + super(message); + } + + public OffsetMetadataTooLarge(Throwable cause) { + super(cause); + } + + public OffsetMetadataTooLarge(String message, Throwable cause) { + super(message, cause); + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/errors/OffsetNotAvailableException.java b/clients/src/main/java/org/apache/kafka/common/errors/OffsetNotAvailableException.java new file mode 100644 index 0000000..97de3b3 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/errors/OffsetNotAvailableException.java @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.errors; + +/** + * Indicates that the leader is not able to guarantee monotonically increasing offsets + * due to the high watermark lagging behind the epoch start offset after a recent leader election + */ +public class OffsetNotAvailableException extends RetriableException { + private static final long serialVersionUID = 1L; + + public OffsetNotAvailableException(String message) { + super(message); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/errors/OffsetOutOfRangeException.java b/clients/src/main/java/org/apache/kafka/common/errors/OffsetOutOfRangeException.java new file mode 100644 index 0000000..92a70fd --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/errors/OffsetOutOfRangeException.java @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.errors; + +/** + * No reset policy has been defined, and the offsets for these partitions are either larger or smaller + * than the range of offsets the server has for the given partition. 
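+ * <p>
+ * A hedged illustration (the surrounding consumer code is assumed, not part of this class): either
+ * configure a reset policy up front, or reposition the consumer explicitly when no policy is set.
+ * <pre>{@code
+ * props.put(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "earliest"); // let the consumer reset itself
+ * // ...or, with auto.offset.reset=none, reposition manually once an out-of-range offset is reported:
+ * consumer.seekToBeginning(consumer.assignment());
+ * }</pre>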
+ */ +public class OffsetOutOfRangeException extends InvalidOffsetException { + + private static final long serialVersionUID = 1L; + + public OffsetOutOfRangeException(String message) { + super(message); + } + + public OffsetOutOfRangeException(String message, Throwable cause) { + super(message, cause); + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/errors/OperationNotAttemptedException.java b/clients/src/main/java/org/apache/kafka/common/errors/OperationNotAttemptedException.java new file mode 100644 index 0000000..96df321 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/errors/OperationNotAttemptedException.java @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.errors; + +/** + * Indicates that the broker did not attempt to execute this operation. This may happen for batched RPCs where some + * operations in the batch failed, causing the broker to respond without trying the rest. + */ +public class OperationNotAttemptedException extends ApiException { + public OperationNotAttemptedException(final String message) { + super(message); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/errors/OutOfOrderSequenceException.java b/clients/src/main/java/org/apache/kafka/common/errors/OutOfOrderSequenceException.java new file mode 100644 index 0000000..462e91e --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/errors/OutOfOrderSequenceException.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.errors; + +/** + * This exception indicates that the broker received an unexpected sequence number from the producer, + * which means that data may have been lost. If the producer is configured for idempotence only (i.e. + * if enable.idempotence is set and no transactional.id is configured), it + * is possible to continue sending with the same producer instance, but doing so risks reordering + * of sent records. 
For transactional producers, this is a fatal error and you should close the + * producer. + */ +public class OutOfOrderSequenceException extends ApiException { + + public OutOfOrderSequenceException(String msg) { + super(msg); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/errors/PolicyViolationException.java b/clients/src/main/java/org/apache/kafka/common/errors/PolicyViolationException.java new file mode 100644 index 0000000..0316938 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/errors/PolicyViolationException.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.errors; + +/** + * Exception thrown if a create topics request does not satisfy the configured policy for a topic. + */ +public class PolicyViolationException extends ApiException { + + public PolicyViolationException(String message) { + super(message); + } + + public PolicyViolationException(String message, Throwable cause) { + super(message, cause); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/errors/PositionOutOfRangeException.java b/clients/src/main/java/org/apache/kafka/common/errors/PositionOutOfRangeException.java new file mode 100644 index 0000000..c502d19 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/errors/PositionOutOfRangeException.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.errors; + +public class PositionOutOfRangeException extends ApiException { + + private static final long serialVersionUID = 1; + + public PositionOutOfRangeException(String s) { + super(s); + } + + public PositionOutOfRangeException(String message, Throwable cause) { + super(message, cause); + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/errors/PreferredLeaderNotAvailableException.java b/clients/src/main/java/org/apache/kafka/common/errors/PreferredLeaderNotAvailableException.java new file mode 100644 index 0000000..73dfd64 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/errors/PreferredLeaderNotAvailableException.java @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.errors; + +public class PreferredLeaderNotAvailableException extends InvalidMetadataException { + + public PreferredLeaderNotAvailableException(String message) { + super(message); + } + + public PreferredLeaderNotAvailableException(String message, Throwable cause) { + super(message, cause); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/errors/PrincipalDeserializationException.java b/clients/src/main/java/org/apache/kafka/common/errors/PrincipalDeserializationException.java new file mode 100644 index 0000000..d0eed95 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/errors/PrincipalDeserializationException.java @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.errors; + +/** + * Exception used to indicate a kafka principal deserialization failure during request forwarding. 
+ */ +public class PrincipalDeserializationException extends ApiException { + + private static final long serialVersionUID = 1L; + + public PrincipalDeserializationException(String message) { + super(message); + } + + public PrincipalDeserializationException(String message, Throwable cause) { + super(message, cause); + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/errors/ProducerFencedException.java b/clients/src/main/java/org/apache/kafka/common/errors/ProducerFencedException.java new file mode 100644 index 0000000..c47dbf5 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/errors/ProducerFencedException.java @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.errors; + +/** + * This fatal exception indicates that another producer with the same transactional.id has been + * started. It is only possible to have one producer instance with a transactional.id at any + * given time, and the latest one to be started "fences" the previous instances so that they can no longer + * make transactional requests. When you encounter this exception, you must close the producer instance. + */ +public class ProducerFencedException extends ApiException { + + public ProducerFencedException(String msg) { + super(msg); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/errors/ReassignmentInProgressException.java b/clients/src/main/java/org/apache/kafka/common/errors/ReassignmentInProgressException.java new file mode 100644 index 0000000..abd624b --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/errors/ReassignmentInProgressException.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.common.errors; + +/** + * Thrown if a request cannot be completed because a partition reassignment is in progress. 
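+ * <p>
+ * A hedged example (the admin client and topic name are assumptions): a partition-count change
+ * attempted while the same topic is still being reassigned may complete exceptionally with this
+ * error and can simply be retried once the reassignment finishes.
+ * <pre>{@code
+ * admin.createPartitions(Collections.singletonMap("orders", NewPartitions.increaseTo(12)))
+ *      .all()
+ *      .get();
+ * }</pre>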
+ */ +public class ReassignmentInProgressException extends ApiException { + + public ReassignmentInProgressException(String msg) { + super(msg); + } + + public ReassignmentInProgressException(String msg, Throwable cause) { + super(msg, cause); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/errors/RebalanceInProgressException.java b/clients/src/main/java/org/apache/kafka/common/errors/RebalanceInProgressException.java new file mode 100644 index 0000000..031abb7 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/errors/RebalanceInProgressException.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.errors; + +public class RebalanceInProgressException extends ApiException { + private static final long serialVersionUID = 1L; + + public RebalanceInProgressException() { + super(); + } + + public RebalanceInProgressException(String message, Throwable cause) { + super(message, cause); + } + + public RebalanceInProgressException(String message) { + super(message); + } + + public RebalanceInProgressException(Throwable cause) { + super(cause); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/errors/RecordBatchTooLargeException.java b/clients/src/main/java/org/apache/kafka/common/errors/RecordBatchTooLargeException.java new file mode 100644 index 0000000..b1ef77d --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/errors/RecordBatchTooLargeException.java @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.errors; + +/** + * This record batch is larger than the maximum allowable size + */ +public class RecordBatchTooLargeException extends ApiException { + + private static final long serialVersionUID = 1L; + + public RecordBatchTooLargeException() { + super(); + } + + public RecordBatchTooLargeException(String message, Throwable cause) { + super(message, cause); + } + + public RecordBatchTooLargeException(String message) { + super(message); + } + + public RecordBatchTooLargeException(Throwable cause) { + super(cause); + } + +} + diff --git a/clients/src/main/java/org/apache/kafka/common/errors/RecordDeserializationException.java b/clients/src/main/java/org/apache/kafka/common/errors/RecordDeserializationException.java new file mode 100644 index 0000000..a15df6c --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/errors/RecordDeserializationException.java @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.errors; + +import org.apache.kafka.common.TopicPartition; + +/** + * This exception is raised for any error that occurs while deserializing records received by the consumer using + * the configured {@link org.apache.kafka.common.serialization.Deserializer}. + */ +public class RecordDeserializationException extends SerializationException { + + private static final long serialVersionUID = 1L; + private final TopicPartition partition; + private final long offset; + + public RecordDeserializationException(TopicPartition partition, long offset, String message, Throwable cause) { + super(message, cause); + this.partition = partition; + this.offset = offset; + } + + public TopicPartition topicPartition() { + return partition; + } + + public long offset() { + return offset; + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/errors/RecordTooLargeException.java b/clients/src/main/java/org/apache/kafka/common/errors/RecordTooLargeException.java new file mode 100644 index 0000000..9ffaa87 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/errors/RecordTooLargeException.java @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
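Since RecordDeserializationException above exposes the failing partition and offset, a caller can skip the offending record and keep consuming. The helper below is a sketch under that assumption; the class and method names are invented for illustration and the consumer is assumed to be configured and subscribed elsewhere.

```java
import java.time.Duration;

import org.apache.kafka.clients.consumer.Consumer;
import org.apache.kafka.clients.consumer.ConsumerRecords;
import org.apache.kafka.common.errors.RecordDeserializationException;

public final class PoisonPillSkipper {

    // Polls once and, if a record cannot be deserialized, seeks past it so the
    // next attempt can make progress instead of failing on the same record forever.
    public static <K, V> ConsumerRecords<K, V> pollSkippingBadRecords(Consumer<K, V> consumer) {
        while (true) {
            try {
                return consumer.poll(Duration.ofSeconds(1));
            } catch (RecordDeserializationException e) {
                // Skip the record that failed to deserialize and try again.
                consumer.seek(e.topicPartition(), e.offset() + 1);
            }
        }
    }
}
```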
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.errors; + +import org.apache.kafka.common.TopicPartition; + +import java.util.Map; + +/** + * This record is larger than the maximum allowable size + */ +public class RecordTooLargeException extends ApiException { + + private static final long serialVersionUID = 1L; + private Map recordTooLargePartitions = null; + + public RecordTooLargeException() { + super(); + } + + public RecordTooLargeException(String message, Throwable cause) { + super(message, cause); + } + + public RecordTooLargeException(String message) { + super(message); + } + + public RecordTooLargeException(Throwable cause) { + super(cause); + } + + public RecordTooLargeException(String message, Map recordTooLargePartitions) { + super(message); + this.recordTooLargePartitions = recordTooLargePartitions; + } + + public Map recordTooLargePartitions() { + return recordTooLargePartitions; + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/errors/ReplicaNotAvailableException.java b/clients/src/main/java/org/apache/kafka/common/errors/ReplicaNotAvailableException.java new file mode 100644 index 0000000..07971cd --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/errors/ReplicaNotAvailableException.java @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.errors; + +/** + * The replica is not available for the requested topic partition. This may be + * a transient exception during reassignments. From version 2.6 onwards, Fetch requests + * and other requests intended only for the leader or follower of the topic partition return + * {@link NotLeaderOrFollowerException} if the broker is a not a replica of the partition. 
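For the RecordTooLargeException added above, a minimal sketch of how a producer caller might observe it; depending on where the size limit is hit the error can surface through the returned future, so the example inspects the future's cause. The broker address, topic, and record size are assumptions, not values taken from this patch.

```java
import java.util.Properties;
import java.util.concurrent.ExecutionException;

import org.apache.kafka.clients.producer.KafkaProducer;
import org.apache.kafka.clients.producer.ProducerConfig;
import org.apache.kafka.clients.producer.ProducerRecord;
import org.apache.kafka.common.errors.RecordTooLargeException;
import org.apache.kafka.common.serialization.ByteArraySerializer;

public class OversizedRecordExample {
    public static void main(String[] args) throws InterruptedException {
        Properties props = new Properties();
        props.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9092");  // hypothetical broker
        props.put(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG, ByteArraySerializer.class.getName());
        props.put(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG, ByteArraySerializer.class.getName());

        try (KafkaProducer<byte[], byte[]> producer = new KafkaProducer<>(props)) {
            byte[] oversized = new byte[10 * 1024 * 1024];  // well above the default 1 MB max.request.size
            producer.send(new ProducerRecord<>("big-records", oversized)).get();  // hypothetical topic
        } catch (ExecutionException e) {
            if (e.getCause() instanceof RecordTooLargeException) {
                // The record exceeded a client- or broker-side size limit; either shrink the
                // record or raise max.request.size / message.max.bytes as appropriate.
                System.err.println("Record rejected as too large: " + e.getCause().getMessage());
            }
        }
    }
}
```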
+ */ +public class ReplicaNotAvailableException extends InvalidMetadataException { + + private static final long serialVersionUID = 1L; + + public ReplicaNotAvailableException(String message) { + super(message); + } + + public ReplicaNotAvailableException(String message, Throwable cause) { + super(message, cause); + } + + public ReplicaNotAvailableException(Throwable cause) { + super(cause); + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/errors/ResourceNotFoundException.java b/clients/src/main/java/org/apache/kafka/common/errors/ResourceNotFoundException.java new file mode 100644 index 0000000..17dca08 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/errors/ResourceNotFoundException.java @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.errors; + +/** + * Exception thrown due to a request for a resource that does not exist. + */ +public class ResourceNotFoundException extends ApiException { + + private static final long serialVersionUID = 1L; + + private final String resource; + + /** + * Constructor + * + * @param message the exception's message + */ + public ResourceNotFoundException(String message) { + this(null, message); + } + + /** + * + * @param message the exception's message + * @param cause the exception's cause + */ + public ResourceNotFoundException(String message, Throwable cause) { + this(null, message, cause); + } + + /** + * Constructor + * + * @param resource the (potentially null) resource that was not found + * @param message the exception's message + */ + public ResourceNotFoundException(String resource, String message) { + super(message); + this.resource = resource; + } + + /** + * Constructor + * + * @param resource the (potentially null) resource that was not found + * @param message the exception's message + * @param cause the exception's cause + */ + public ResourceNotFoundException(String resource, String message, Throwable cause) { + super(message, cause); + this.resource = resource; + } + + /** + * + * @return the (potentially null) resource that was not found + */ + public String resource() { + return this.resource; + } +} \ No newline at end of file diff --git a/clients/src/main/java/org/apache/kafka/common/errors/RetriableException.java b/clients/src/main/java/org/apache/kafka/common/errors/RetriableException.java new file mode 100644 index 0000000..6d9a76d --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/errors/RetriableException.java @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.errors; + +/** + * A retriable exception is a transient exception that if retried may succeed. + */ +public abstract class RetriableException extends ApiException { + + private static final long serialVersionUID = 1L; + + public RetriableException(String message, Throwable cause) { + super(message, cause); + } + + public RetriableException(String message) { + super(message); + } + + public RetriableException(Throwable cause) { + super(cause); + } + + public RetriableException() { + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/errors/SaslAuthenticationException.java b/clients/src/main/java/org/apache/kafka/common/errors/SaslAuthenticationException.java new file mode 100644 index 0000000..c6bc8bd --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/errors/SaslAuthenticationException.java @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.errors; + +import javax.security.sasl.SaslServer; + +/** + * This exception indicates that SASL authentication has failed. The error message + * in the exception indicates the actual cause of failure. + *

+ * SASL authentication failures typically indicate invalid credentials, but
+ * could also include other failures specific to the SASL mechanism used
+ * for authentication.
+ * <p>
+ * <b>Note:</b> If {@link SaslServer#evaluateResponse(byte[])} throws this exception during
+ * authentication, the message from the exception will be sent to clients in the SaslAuthenticate
+ * response. Custom {@link SaslServer} implementations may throw this exception in order to
+ * provide custom error messages to clients, but should take care not to include any
+ * security-critical information in the message that should not be leaked to unauthenticated clients.
        + */ +public class SaslAuthenticationException extends AuthenticationException { + + private static final long serialVersionUID = 1L; + + public SaslAuthenticationException(String message) { + super(message); + } + + public SaslAuthenticationException(String message, Throwable cause) { + super(message, cause); + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/errors/SecurityDisabledException.java b/clients/src/main/java/org/apache/kafka/common/errors/SecurityDisabledException.java new file mode 100644 index 0000000..25f3f35 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/errors/SecurityDisabledException.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.errors; + +/** + * An error indicating that security is disabled on the broker. + */ +public class SecurityDisabledException extends ApiException { + private static final long serialVersionUID = 1L; + + public SecurityDisabledException(String message) { + super(message); + } + + public SecurityDisabledException(String message, Throwable cause) { + super(message, cause); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/errors/SerializationException.java b/clients/src/main/java/org/apache/kafka/common/errors/SerializationException.java new file mode 100644 index 0000000..02d0710 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/errors/SerializationException.java @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
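To illustrate the note above about SaslServer#evaluateResponse, a hypothetical custom SaslServer fragment might throw SaslAuthenticationException with a deliberately generic message. The mechanism details and the credential check are assumptions for the sketch, not part of this patch.

```java
import javax.security.sasl.SaslException;
import javax.security.sasl.SaslServer;

import org.apache.kafka.common.errors.SaslAuthenticationException;

// Fragment of a hypothetical custom SaslServer: only evaluateResponse is shown,
// the remaining SaslServer methods are left abstract.
public abstract class CustomSaslServer implements SaslServer {

    protected abstract boolean credentialsValid(byte[] response);

    @Override
    public byte[] evaluateResponse(byte[] response) throws SaslException {
        if (!credentialsValid(response)) {
            // This message is sent back to the client in the SaslAuthenticate response,
            // so it must not reveal whether the user exists or which check failed.
            throw new SaslAuthenticationException("Authentication failed: invalid credentials");
        }
        return new byte[0]; // no further challenge
    }
}
```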
+ */ +package org.apache.kafka.common.errors; + +import org.apache.kafka.common.KafkaException; + +/** + * Any exception during serialization in the producer + */ +public class SerializationException extends KafkaException { + + private static final long serialVersionUID = 1L; + + public SerializationException(String message, Throwable cause) { + super(message, cause); + } + + public SerializationException(String message) { + super(message); + } + + public SerializationException(Throwable cause) { + super(cause); + } + + public SerializationException() { + super(); + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/errors/SnapshotNotFoundException.java b/clients/src/main/java/org/apache/kafka/common/errors/SnapshotNotFoundException.java new file mode 100644 index 0000000..5b3e7ed --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/errors/SnapshotNotFoundException.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.errors; + +public class SnapshotNotFoundException extends ApiException { + + private static final long serialVersionUID = 1; + + public SnapshotNotFoundException(String s) { + super(s); + } + + public SnapshotNotFoundException(String message, Throwable cause) { + super(message, cause); + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/errors/SslAuthenticationException.java b/clients/src/main/java/org/apache/kafka/common/errors/SslAuthenticationException.java new file mode 100644 index 0000000..3cdbf2a --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/errors/SslAuthenticationException.java @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.errors; + +import javax.net.ssl.SSLException; + +/** + * This exception indicates that SSL handshake has failed. See {@link #getCause()} + * for the {@link SSLException} that caused this failure. + *

+ * SSL handshake failures in clients may indicate client authentication
+ * failure due to untrusted certificates if server is configured to request
+ * client certificates. Handshake failures could also indicate misconfigured
+ * security including protocol/cipher suite mismatch, server certificate
+ * authentication failure or server host name verification failure.
        + */ +public class SslAuthenticationException extends AuthenticationException { + + private static final long serialVersionUID = 1L; + + public SslAuthenticationException(String message) { + super(message); + } + + public SslAuthenticationException(String message, Throwable cause) { + super(message, cause); + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/errors/StaleBrokerEpochException.java b/clients/src/main/java/org/apache/kafka/common/errors/StaleBrokerEpochException.java new file mode 100644 index 0000000..a5c0b41 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/errors/StaleBrokerEpochException.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.errors; + +public class StaleBrokerEpochException extends ApiException { + + private static final long serialVersionUID = 1L; + + public StaleBrokerEpochException(String message) { + super(message); + } + + public StaleBrokerEpochException(String message, Throwable cause) { + super(message, cause); + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/errors/ThrottlingQuotaExceededException.java b/clients/src/main/java/org/apache/kafka/common/errors/ThrottlingQuotaExceededException.java new file mode 100644 index 0000000..c4d1350 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/errors/ThrottlingQuotaExceededException.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.errors; + +/** + * Exception thrown if an operation on a resource exceeds the throttling quota. 
+ */ +public class ThrottlingQuotaExceededException extends RetriableException { + private int throttleTimeMs = 0; + + public ThrottlingQuotaExceededException(String message) { + super(message); + } + + public ThrottlingQuotaExceededException(int throttleTimeMs, String message) { + super(message); + this.throttleTimeMs = throttleTimeMs; + } + + public int throttleTimeMs() { + return this.throttleTimeMs; + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/errors/TimeoutException.java b/clients/src/main/java/org/apache/kafka/common/errors/TimeoutException.java new file mode 100644 index 0000000..47fe034 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/errors/TimeoutException.java @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.errors; + +/** + * Indicates that a request timed out. + */ +public class TimeoutException extends RetriableException { + + private static final long serialVersionUID = 1L; + + public TimeoutException() { + super(); + } + + public TimeoutException(String message, Throwable cause) { + super(message, cause); + } + + public TimeoutException(String message) { + super(message); + } + + public TimeoutException(Throwable cause) { + super(cause); + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/errors/TopicAuthorizationException.java b/clients/src/main/java/org/apache/kafka/common/errors/TopicAuthorizationException.java new file mode 100644 index 0000000..e2235f8 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/errors/TopicAuthorizationException.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
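A sketch of how an admin-client caller could surface the throttle time carried by ThrottlingQuotaExceededException; by default the admin client retries quota violations internally, so the example disables that via retryOnQuotaViolation(false). The broker address and topic are placeholders.

```java
import java.util.Collections;
import java.util.Properties;
import java.util.concurrent.ExecutionException;

import org.apache.kafka.clients.admin.Admin;
import org.apache.kafka.clients.admin.AdminClientConfig;
import org.apache.kafka.clients.admin.CreateTopicsOptions;
import org.apache.kafka.clients.admin.NewTopic;
import org.apache.kafka.common.errors.ThrottlingQuotaExceededException;

public class TopicCreationQuotaExample {
    public static void main(String[] args) throws InterruptedException {
        Properties props = new Properties();
        props.put(AdminClientConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9092");  // hypothetical broker

        try (Admin admin = Admin.create(props)) {
            NewTopic topic = new NewTopic("example-topic", 3, (short) 1);  // hypothetical topic
            CreateTopicsOptions options = new CreateTopicsOptions().retryOnQuotaViolation(false);
            admin.createTopics(Collections.singletonList(topic), options).all().get();
        } catch (ExecutionException e) {
            if (e.getCause() instanceof ThrottlingQuotaExceededException) {
                ThrottlingQuotaExceededException quota = (ThrottlingQuotaExceededException) e.getCause();
                // The broker rejected the request because of a quota; the caller could
                // back off for the suggested throttle time before retrying.
                System.err.println("Throttled, retry after ms: " + quota.throttleTimeMs());
            }
        }
    }
}
```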
+ */ +package org.apache.kafka.common.errors; + +import java.util.Collections; +import java.util.Set; + +public class TopicAuthorizationException extends AuthorizationException { + private final Set unauthorizedTopics; + + public TopicAuthorizationException(String message, Set unauthorizedTopics) { + super(message); + this.unauthorizedTopics = unauthorizedTopics; + } + + public TopicAuthorizationException(Set unauthorizedTopics) { + this("Not authorized to access topics: " + unauthorizedTopics, unauthorizedTopics); + } + + public TopicAuthorizationException(String message) { + this(message, Collections.emptySet()); + } + + /** + * Get the set of topics which failed authorization. May be empty if the set is not known + * in the context the exception was raised in. + * + * @return possibly empty set of unauthorized topics + */ + public Set unauthorizedTopics() { + return unauthorizedTopics; + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/errors/TopicDeletionDisabledException.java b/clients/src/main/java/org/apache/kafka/common/errors/TopicDeletionDisabledException.java new file mode 100644 index 0000000..41577d2 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/errors/TopicDeletionDisabledException.java @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.common.errors; + +public class TopicDeletionDisabledException extends ApiException { + private static final long serialVersionUID = 1L; + + public TopicDeletionDisabledException() { + } + + public TopicDeletionDisabledException(String message) { + super(message); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/errors/TopicExistsException.java b/clients/src/main/java/org/apache/kafka/common/errors/TopicExistsException.java new file mode 100644 index 0000000..cc0c8f1 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/errors/TopicExistsException.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
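The unauthorizedTopics() accessor above lets callers report exactly which topics were denied. A small, assumption-laden sketch (the consumer is configured elsewhere; class and method names are illustrative):

```java
import java.time.Duration;
import java.util.Arrays;

import org.apache.kafka.clients.consumer.Consumer;
import org.apache.kafka.common.errors.TopicAuthorizationException;

public final class AuthorizationCheck {

    // Subscribes, attempts a single poll, and reports any topics the principal
    // is not authorized to read. Assumes the consumer is already configured.
    public static void subscribeAndVerify(Consumer<String, String> consumer, String... topics) {
        consumer.subscribe(Arrays.asList(topics));
        try {
            consumer.poll(Duration.ofSeconds(1));
        } catch (TopicAuthorizationException e) {
            // The set may be empty if it was not known where the error was raised.
            System.err.println("Not authorized to read: " + e.unauthorizedTopics());
        }
    }
}
```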
+ */ +package org.apache.kafka.common.errors; + +public class TopicExistsException extends ApiException { + + private static final long serialVersionUID = 1L; + + public TopicExistsException(String message) { + super(message); + } + + public TopicExistsException(String message, Throwable cause) { + super(message, cause); + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/errors/TransactionAbortedException.java b/clients/src/main/java/org/apache/kafka/common/errors/TransactionAbortedException.java new file mode 100644 index 0000000..c394ac5 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/errors/TransactionAbortedException.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.errors; + +/** + * This is the Exception thrown when we are aborting any undrained batches during + * a transaction which is aborted without any underlying cause - which likely means that the user chose to abort. + */ +public class TransactionAbortedException extends ApiException { + + private final static long serialVersionUID = 1L; + + public TransactionAbortedException(String message, Throwable cause) { + super(message, cause); + } + + public TransactionAbortedException(String message) { + super(message); + } + + public TransactionAbortedException() { + super("Failing batch since transaction was aborted"); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/errors/TransactionCoordinatorFencedException.java b/clients/src/main/java/org/apache/kafka/common/errors/TransactionCoordinatorFencedException.java new file mode 100644 index 0000000..583ce04 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/errors/TransactionCoordinatorFencedException.java @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
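TopicExistsException is often treated as benign when topic creation is meant to be idempotent. One possible (illustrative) pattern, with the broker address and topic settings assumed:

```java
import java.util.Collections;
import java.util.Properties;
import java.util.concurrent.ExecutionException;

import org.apache.kafka.clients.admin.Admin;
import org.apache.kafka.clients.admin.AdminClientConfig;
import org.apache.kafka.clients.admin.NewTopic;
import org.apache.kafka.common.errors.TopicExistsException;

public class EnsureTopicExists {
    public static void main(String[] args) throws InterruptedException {
        Properties props = new Properties();
        props.put(AdminClientConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9092");  // hypothetical broker

        try (Admin admin = Admin.create(props)) {
            NewTopic topic = new NewTopic("audit-log", 6, (short) 3);  // hypothetical topic
            admin.createTopics(Collections.singletonList(topic)).all().get();
        } catch (ExecutionException e) {
            if (e.getCause() instanceof TopicExistsException) {
                // Someone else created it first; for idempotent setup this is not an error.
                System.out.println("Topic already exists, continuing.");
            }
        }
    }
}
```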
+ */ +package org.apache.kafka.common.errors; + +public class TransactionCoordinatorFencedException extends ApiException { + + private static final long serialVersionUID = 1L; + + public TransactionCoordinatorFencedException(String message) { + super(message); + } + + public TransactionCoordinatorFencedException(String message, Throwable cause) { + super(message, cause); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/errors/TransactionalIdAuthorizationException.java b/clients/src/main/java/org/apache/kafka/common/errors/TransactionalIdAuthorizationException.java new file mode 100644 index 0000000..3f85513 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/errors/TransactionalIdAuthorizationException.java @@ -0,0 +1,23 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.errors; + +public class TransactionalIdAuthorizationException extends AuthorizationException { + public TransactionalIdAuthorizationException(final String message) { + super(message); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/errors/TransactionalIdNotFoundException.java b/clients/src/main/java/org/apache/kafka/common/errors/TransactionalIdNotFoundException.java new file mode 100644 index 0000000..240eaa3 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/errors/TransactionalIdNotFoundException.java @@ -0,0 +1,24 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.errors; + +public class TransactionalIdNotFoundException extends ApiException { + + public TransactionalIdNotFoundException(String message) { + super(message); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/errors/UnacceptableCredentialException.java b/clients/src/main/java/org/apache/kafka/common/errors/UnacceptableCredentialException.java new file mode 100644 index 0000000..b7cffff --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/errors/UnacceptableCredentialException.java @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.errors; + +/** + * Exception thrown when attempting to define a credential that does not meet the criteria for acceptability + * (for example, attempting to create a SCRAM credential with an empty username or password or too few/many iterations). + */ +public class UnacceptableCredentialException extends ApiException { + + private static final long serialVersionUID = 1L; + + /** + * Constructor + * + * @param message the exception's message + */ + public UnacceptableCredentialException(String message) { + super(message); + } + + /** + * + * @param message the exception's message + * @param cause the exception's cause + */ + public UnacceptableCredentialException(String message, Throwable cause) { + super(message, cause); + } +} \ No newline at end of file diff --git a/clients/src/main/java/org/apache/kafka/common/errors/UnknownLeaderEpochException.java b/clients/src/main/java/org/apache/kafka/common/errors/UnknownLeaderEpochException.java new file mode 100644 index 0000000..3714c36 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/errors/UnknownLeaderEpochException.java @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.errors; + +/** + * The request contained a leader epoch which is larger than that on the broker that received the + * request. 
This can happen if the client observes a metadata update before it has been propagated + * to all brokers. Clients need not refresh metadata before retrying. + */ +public class UnknownLeaderEpochException extends RetriableException { + private static final long serialVersionUID = 1L; + + public UnknownLeaderEpochException(String message) { + super(message); + } + + public UnknownLeaderEpochException(String message, Throwable cause) { + super(message, cause); + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/errors/UnknownMemberIdException.java b/clients/src/main/java/org/apache/kafka/common/errors/UnknownMemberIdException.java new file mode 100644 index 0000000..f6eea5b --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/errors/UnknownMemberIdException.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.errors; + +public class UnknownMemberIdException extends ApiException { + private static final long serialVersionUID = 1L; + + public UnknownMemberIdException() { + super(); + } + + public UnknownMemberIdException(String message, Throwable cause) { + super(message, cause); + } + + public UnknownMemberIdException(String message) { + super(message); + } + + public UnknownMemberIdException(Throwable cause) { + super(cause); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/errors/UnknownProducerIdException.java b/clients/src/main/java/org/apache/kafka/common/errors/UnknownProducerIdException.java new file mode 100644 index 0000000..ce17345 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/errors/UnknownProducerIdException.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.common.errors; + +/** + * This exception is raised by the broker if it could not locate the producer metadata associated with the producerId + * in question. 
This could happen if, for instance, the producer's records were deleted because their retention time + * had elapsed. Once the last records of the producerId are removed, the producer's metadata is removed from the broker, + * and future appends by the producer will return this exception. + */ +public class UnknownProducerIdException extends OutOfOrderSequenceException { + + public UnknownProducerIdException(String message) { + super(message); + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/errors/UnknownServerException.java b/clients/src/main/java/org/apache/kafka/common/errors/UnknownServerException.java new file mode 100644 index 0000000..37e003b --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/errors/UnknownServerException.java @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.errors; + +/** + * An error occurred on the server for which the client doesn't have a corresponding error code. This is generally an + * unexpected error. + * + */ +public class UnknownServerException extends ApiException { + + private static final long serialVersionUID = 1L; + + public UnknownServerException() { + } + + public UnknownServerException(String message) { + super(message); + } + + public UnknownServerException(Throwable cause) { + super(cause); + } + + public UnknownServerException(String message, Throwable cause) { + super(message, cause); + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/errors/UnknownTopicIdException.java b/clients/src/main/java/org/apache/kafka/common/errors/UnknownTopicIdException.java new file mode 100644 index 0000000..e5023eb --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/errors/UnknownTopicIdException.java @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.errors; + +public class UnknownTopicIdException extends InvalidMetadataException { + + private static final long serialVersionUID = 1L; + + public UnknownTopicIdException(String message) { + super(message); + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/errors/UnknownTopicOrPartitionException.java b/clients/src/main/java/org/apache/kafka/common/errors/UnknownTopicOrPartitionException.java new file mode 100644 index 0000000..6d10945 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/errors/UnknownTopicOrPartitionException.java @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.errors; + +/** + * This topic/partition doesn't exist. + * This exception is used in contexts where a topic doesn't seem to exist based on possibly stale metadata. + * This exception is retriable because the topic or partition might subsequently be created. + * + * @see InvalidTopicException + */ +public class UnknownTopicOrPartitionException extends InvalidMetadataException { + + private static final long serialVersionUID = 1L; + + public UnknownTopicOrPartitionException() { + } + + public UnknownTopicOrPartitionException(String message) { + super(message); + } + + public UnknownTopicOrPartitionException(Throwable throwable) { + super(throwable); + } + + public UnknownTopicOrPartitionException(String message, Throwable throwable) { + super(message, throwable); + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/errors/UnstableOffsetCommitException.java b/clients/src/main/java/org/apache/kafka/common/errors/UnstableOffsetCommitException.java new file mode 100644 index 0000000..c89d717 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/errors/UnstableOffsetCommitException.java @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
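Because UnknownTopicOrPartitionException above signals possibly stale metadata rather than a permanent failure, admin callers sometimes map it to a plain "not found" answer and let the application decide whether to wait or create the topic. A sketch under that assumption (the helper name is invented):

```java
import java.util.Collections;
import java.util.concurrent.ExecutionException;

import org.apache.kafka.clients.admin.Admin;
import org.apache.kafka.common.errors.UnknownTopicOrPartitionException;

public final class TopicExistenceCheck {

    // Returns true if the broker knows the topic, false if it reports it as unknown.
    // The Admin client is created and closed by the caller.
    public static boolean topicExists(Admin admin, String topic) throws InterruptedException {
        try {
            admin.describeTopics(Collections.singletonList(topic)).all().get();
            return true;
        } catch (ExecutionException e) {
            if (e.getCause() instanceof UnknownTopicOrPartitionException) {
                return false;
            }
            throw new RuntimeException(e.getCause());
        }
    }
}
```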
+ */ +package org.apache.kafka.common.errors; + +/** + * Exception thrown when there are unstable offsets for the requested topic partitions. + */ +public class UnstableOffsetCommitException extends RetriableException { + + private static final long serialVersionUID = 1L; + + public UnstableOffsetCommitException(String message) { + super(message); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/errors/UnsupportedByAuthenticationException.java b/clients/src/main/java/org/apache/kafka/common/errors/UnsupportedByAuthenticationException.java new file mode 100644 index 0000000..40f357c --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/errors/UnsupportedByAuthenticationException.java @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.errors; + +/** + * Authentication mechanism does not support the requested function. + */ +public class UnsupportedByAuthenticationException extends ApiException { + private static final long serialVersionUID = 1L; + + public UnsupportedByAuthenticationException(String message) { + super(message); + } + + public UnsupportedByAuthenticationException(String message, Throwable cause) { + super(message, cause); + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/errors/UnsupportedCompressionTypeException.java b/clients/src/main/java/org/apache/kafka/common/errors/UnsupportedCompressionTypeException.java new file mode 100644 index 0000000..29ffe1b --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/errors/UnsupportedCompressionTypeException.java @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.errors; + +/** + * The requesting client does not support the compression type of given partition. 
+ */ +public class UnsupportedCompressionTypeException extends ApiException { + + private static final long serialVersionUID = 1L; + + public UnsupportedCompressionTypeException(String message) { + super(message); + } + + public UnsupportedCompressionTypeException(String message, Throwable cause) { + super(message, cause); + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/errors/UnsupportedForMessageFormatException.java b/clients/src/main/java/org/apache/kafka/common/errors/UnsupportedForMessageFormatException.java new file mode 100644 index 0000000..f66298e --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/errors/UnsupportedForMessageFormatException.java @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.errors; + +/** + * The message format version does not support the requested function. For example, if idempotence is + * requested and the topic is using a message format older than 0.11.0.0, then this error will be returned. + */ +public class UnsupportedForMessageFormatException extends ApiException { + private static final long serialVersionUID = 1L; + + public UnsupportedForMessageFormatException(String message) { + super(message); + } + + public UnsupportedForMessageFormatException(String message, Throwable cause) { + super(message, cause); + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/errors/UnsupportedSaslMechanismException.java b/clients/src/main/java/org/apache/kafka/common/errors/UnsupportedSaslMechanismException.java new file mode 100644 index 0000000..4db4aee --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/errors/UnsupportedSaslMechanismException.java @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.errors; + +/** + * This exception indicates that the SASL mechanism requested by the client + * is not enabled on the broker. 
+ */ +public class UnsupportedSaslMechanismException extends AuthenticationException { + + private static final long serialVersionUID = 1L; + + public UnsupportedSaslMechanismException(String message) { + super(message); + } + + public UnsupportedSaslMechanismException(String message, Throwable cause) { + super(message, cause); + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/errors/UnsupportedVersionException.java b/clients/src/main/java/org/apache/kafka/common/errors/UnsupportedVersionException.java new file mode 100644 index 0000000..484947b --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/errors/UnsupportedVersionException.java @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.errors; + +import java.util.Map; + +/** + * Indicates that a request API or version needed by the client is not supported by the broker. This is + * typically a fatal error as Kafka clients will downgrade request versions as needed except in cases where + * a needed feature is not available in old versions. Fatal errors can generally only be handled by closing + * the client instance, although in some cases it may be possible to continue without relying on the + * underlying feature. For example, when the producer is used with idempotence enabled, this error is fatal + * since the producer does not support reverting to weaker semantics. On the other hand, if this error + * is raised from {@link org.apache.kafka.clients.consumer.KafkaConsumer#offsetsForTimes(Map)}, it would + * be possible to revert to alternative logic to set the consumer's position. + */ +public class UnsupportedVersionException extends ApiException { + private static final long serialVersionUID = 1L; + + public UnsupportedVersionException(String message, Throwable cause) { + super(message, cause); + } + + public UnsupportedVersionException(String message) { + super(message); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/errors/WakeupException.java b/clients/src/main/java/org/apache/kafka/common/errors/WakeupException.java new file mode 100644 index 0000000..f8ae840 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/errors/WakeupException.java @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
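The UnsupportedVersionException javadoc above calls out offsetsForTimes as a case where the caller can fall back to alternative positioning logic. One possible shape of that fallback (the helper name and the seek-to-beginning policy are assumptions, not a prescription):

```java
import java.util.Collection;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Collectors;

import org.apache.kafka.clients.consumer.Consumer;
import org.apache.kafka.clients.consumer.OffsetAndTimestamp;
import org.apache.kafka.common.TopicPartition;
import org.apache.kafka.common.errors.UnsupportedVersionException;

public final class SeekToTimestamp {

    // Seeks each partition to the first offset at or after the given timestamp.
    // If the broker is too old to answer offset-for-timestamp lookups, falls back
    // to seeking to the beginning instead of failing the application.
    public static void seek(Consumer<?, ?> consumer, Collection<TopicPartition> partitions, long timestampMs) {
        try {
            Map<TopicPartition, Long> request = partitions.stream()
                    .collect(Collectors.toMap(Function.identity(), tp -> timestampMs));
            Map<TopicPartition, OffsetAndTimestamp> offsets = consumer.offsetsForTimes(request);
            offsets.forEach((tp, offsetAndTimestamp) -> {
                if (offsetAndTimestamp != null)
                    consumer.seek(tp, offsetAndTimestamp.offset());
            });
        } catch (UnsupportedVersionException e) {
            // Alternative logic: position at the earliest available offset instead.
            consumer.seekToBeginning(partitions);
        }
    }
}
```

In practice a caller may prefer to fail fast; seeking to the beginning is just one of the alternatives the javadoc leaves open.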
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.errors; + +import org.apache.kafka.common.KafkaException; + +/** + * Exception used to indicate preemption of a blocking operation by an external thread. + * For example, {@link org.apache.kafka.clients.consumer.KafkaConsumer#wakeup} + * can be used to break out of an active {@link org.apache.kafka.clients.consumer.KafkaConsumer#poll(java.time.Duration)}, + * which would raise an instance of this exception. + */ +public class WakeupException extends KafkaException { + private static final long serialVersionUID = 1L; + +} diff --git a/clients/src/main/java/org/apache/kafka/common/feature/BaseVersionRange.java b/clients/src/main/java/org/apache/kafka/common/feature/BaseVersionRange.java new file mode 100644 index 0000000..2d6ce70 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/feature/BaseVersionRange.java @@ -0,0 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.feature; + +import static java.util.stream.Collectors.joining; + +import java.util.Map; +import java.util.Objects; + +import org.apache.kafka.common.utils.Utils; + +/** + * Represents an immutable basic version range using 2 attributes: min and max, each of type short. + * The min and max attributes need to satisfy 2 rules: + * - they are each expected to be >= 1, as we only consider positive version values to be valid. + * - max should be >= min. + * + * The class also provides API to convert the version range to a map. + * The class allows for configurable labels for the min/max attributes, which can be specialized by + * sub-classes (if needed). + */ +class BaseVersionRange { + // Non-empty label for the min version key, that's used only to convert to/from a map. + private final String minKeyLabel; + + // The value of the minimum version. + private final short minValue; + + // Non-empty label for the max version key, that's used only to convert to/from a map. + private final String maxKeyLabel; + + // The value of the maximum version. + private final short maxValue; + + /** + * Raises an exception unless the following condition is met: + * minValue >= 1 and maxValue >= 1 and maxValue >= minValue. + * + * @param minKeyLabel Label for the min version key, that's used only to convert to/from a map. + * @param minValue The minimum version value. 
+ * @param maxKeyLabel Label for the max version key, that's used only to convert to/from a map. + * @param maxValue The maximum version value. + * + * @throws IllegalArgumentException If any of the following conditions are true: + * - (minValue < 1) OR (maxValue < 1) OR (maxValue < minValue). + * - minKeyLabel is empty, OR, minKeyLabel is empty. + */ + protected BaseVersionRange(String minKeyLabel, short minValue, String maxKeyLabel, short maxValue) { + if (minValue < 1 || maxValue < 1 || maxValue < minValue) { + throw new IllegalArgumentException( + String.format( + "Expected minValue >= 1, maxValue >= 1 and maxValue >= minValue, but received" + + " minValue: %d, maxValue: %d", minValue, maxValue)); + } + if (minKeyLabel.isEmpty()) { + throw new IllegalArgumentException("Expected minKeyLabel to be non-empty."); + } + if (maxKeyLabel.isEmpty()) { + throw new IllegalArgumentException("Expected maxKeyLabel to be non-empty."); + } + this.minKeyLabel = minKeyLabel; + this.minValue = minValue; + this.maxKeyLabel = maxKeyLabel; + this.maxValue = maxValue; + } + + public short min() { + return minValue; + } + + public short max() { + return maxValue; + } + + public String toString() { + return String.format( + "%s[%s]", + this.getClass().getSimpleName(), + mapToString(toMap())); + } + + public Map toMap() { + return Utils.mkMap(Utils.mkEntry(minKeyLabel, min()), Utils.mkEntry(maxKeyLabel, max())); + } + + private static String mapToString(final Map map) { + return map + .entrySet() + .stream() + .map(entry -> String.format("%s:%d", entry.getKey(), entry.getValue())) + .collect(joining(", ")); + } + + @Override + public boolean equals(Object other) { + if (this == other) { + return true; + } + + if (other == null || getClass() != other.getClass()) { + return false; + } + + final BaseVersionRange that = (BaseVersionRange) other; + return Objects.equals(this.minKeyLabel, that.minKeyLabel) && + this.minValue == that.minValue && + Objects.equals(this.maxKeyLabel, that.maxKeyLabel) && + this.maxValue == that.maxValue; + } + + @Override + public int hashCode() { + return Objects.hash(minKeyLabel, minValue, maxKeyLabel, maxValue); + } + + public static short valueOrThrow(String key, Map versionRangeMap) { + final Short value = versionRangeMap.get(key); + if (value == null) { + throw new IllegalArgumentException(String.format("%s absent in [%s]", key, mapToString(versionRangeMap))); + } + return value; + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/feature/Features.java b/clients/src/main/java/org/apache/kafka/common/feature/Features.java new file mode 100644 index 0000000..4006d71 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/feature/Features.java @@ -0,0 +1,184 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.feature; + +import java.util.HashMap; +import java.util.Map; +import java.util.stream.Collectors; +import java.util.Objects; + +import static java.util.stream.Collectors.joining; + +/** + * Represents an immutable dictionary with key being feature name, and value being <VersionRangeType>. + * Also provides API to convert the features and their version ranges to/from a map. + * + * This class can be instantiated only using its factory functions, with the important ones being: + * Features.supportedFeatures(...) and Features.finalizedFeatures(...). + * + * @param <VersionRangeType> is the type of version range. + * @see SupportedVersionRange + * @see FinalizedVersionRange + */ +public class Features<VersionRangeType extends BaseVersionRange> { + private final Map<String, VersionRangeType> features; + + /** + * Constructor is made private, as for readability it is preferred the caller uses one of the + * static factory functions for instantiation (see below). + * + * @param features Map of feature name to a type of VersionRange. + */ + private Features(Map<String, VersionRangeType> features) { + Objects.requireNonNull(features, "Provided features can not be null."); + this.features = features; + }
+ + /** + * @param features Map of feature name to SupportedVersionRange. + * + * @return Returns a new Features object representing supported features. + */ + public static Features<SupportedVersionRange> supportedFeatures(Map<String, SupportedVersionRange> features) { + return new Features<>(features); + } + + /** + * @param features Map of feature name to FinalizedVersionRange. + * + * @return Returns a new Features object representing finalized features. + */ + public static Features<FinalizedVersionRange> finalizedFeatures(Map<String, FinalizedVersionRange> features) { + return new Features<>(features); + } + + // Visible for testing. + public static Features<FinalizedVersionRange> emptyFinalizedFeatures() { + return new Features<>(new HashMap<>()); + } + + public static Features<SupportedVersionRange> emptySupportedFeatures() { + return new Features<>(new HashMap<>()); + } + + public Map<String, VersionRangeType> features() { + return features; + } + + public boolean empty() { + return features.isEmpty(); + }
+ + /** + * @param feature name of the feature + * + * @return the VersionRangeType corresponding to the feature name, or null if the + * feature is absent + */ + public VersionRangeType get(String feature) { + return features.get(feature); + } + + public String toString() { + return String.format( + "Features{%s}", + features + .entrySet() + .stream() + .map(entry -> String.format("(%s -> %s)", entry.getKey(), entry.getValue())) + .collect(joining(", ")) + ); + } + + /** + * @return A map representation of the underlying features. The returned value can be converted + * back to Features using one of the from*FeaturesMap() APIs of this class. + */ + public Map<String, Map<String, Short>> toMap() { + return features.entrySet().stream().collect( + Collectors.toMap( + Map.Entry::getKey, + entry -> entry.getValue().toMap())); + }
+ + /** + * An interface that defines behavior to convert from a Map to an object of type BaseVersionRange. + */ + private interface MapToBaseVersionRangeConverter<V extends BaseVersionRange> { + + /** + * Convert the map representation of an object of type <V>, to an object of type <V>. + * + * @param baseVersionRangeMap the map representation of a BaseVersionRange object. + * + * @return the object of type <V> + */ + V fromMap(Map<String, Short> baseVersionRangeMap); + } + + private static <V extends BaseVersionRange> Features<V> fromFeaturesMap( + Map<String, Map<String, Short>> featuresMap, MapToBaseVersionRangeConverter<V> converter) { + return new Features<>(featuresMap.entrySet().stream().collect( + Collectors.toMap( + Map.Entry::getKey, + entry -> converter.fromMap(entry.getValue())))); + } + + /** + * Converts from a map to Features<FinalizedVersionRange>.
+ * + * @param featuresMap the map representation of a Features object, + * generated using the toMap() API. + * + * @return the Features object + */ + public static Features fromFinalizedFeaturesMap( + Map> featuresMap) { + return fromFeaturesMap(featuresMap, FinalizedVersionRange::fromMap); + } + + /** + * Converts from a map to Features. + * + * @param featuresMap the map representation of a Features object, + * generated using the toMap() API. + * + * @return the Features object + */ + public static Features fromSupportedFeaturesMap( + Map> featuresMap) { + return fromFeaturesMap(featuresMap, SupportedVersionRange::fromMap); + } + + @Override + public boolean equals(Object other) { + if (this == other) { + return true; + } + if (!(other instanceof Features)) { + return false; + } + + final Features that = (Features) other; + return Objects.equals(this.features, that.features); + } + + @Override + public int hashCode() { + return Objects.hash(features); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/feature/FinalizedVersionRange.java b/clients/src/main/java/org/apache/kafka/common/feature/FinalizedVersionRange.java new file mode 100644 index 0000000..27e6440 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/feature/FinalizedVersionRange.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.feature; + +import java.util.Map; + +/** + * An extended {@link BaseVersionRange} representing the min/max versions for a finalized feature. + */ +public class FinalizedVersionRange extends BaseVersionRange { + // Label for the min version key, that's used only to convert to/from a map. + private static final String MIN_VERSION_LEVEL_KEY_LABEL = "min_version_level"; + + // Label for the max version key, that's used only to convert to/from a map. + private static final String MAX_VERSION_LEVEL_KEY_LABEL = "max_version_level"; + + public FinalizedVersionRange(short minVersionLevel, short maxVersionLevel) { + super(MIN_VERSION_LEVEL_KEY_LABEL, minVersionLevel, MAX_VERSION_LEVEL_KEY_LABEL, maxVersionLevel); + } + + public static FinalizedVersionRange fromMap(Map versionRangeMap) { + return new FinalizedVersionRange( + BaseVersionRange.valueOrThrow(MIN_VERSION_LEVEL_KEY_LABEL, versionRangeMap), + BaseVersionRange.valueOrThrow(MAX_VERSION_LEVEL_KEY_LABEL, versionRangeMap)); + } + + /** + * Checks if the [min, max] version level range of this object does *NOT* fall within the + * [min, max] range of the provided SupportedVersionRange parameter. 
+ * + * @param supportedVersionRange the SupportedVersionRange to be checked + * + * @return - true, if the version levels are incompatible + * - false otherwise + */ + public boolean isIncompatibleWith(SupportedVersionRange supportedVersionRange) { + return min() < supportedVersionRange.min() || max() > supportedVersionRange.max(); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/feature/SupportedVersionRange.java b/clients/src/main/java/org/apache/kafka/common/feature/SupportedVersionRange.java new file mode 100644 index 0000000..8993014 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/feature/SupportedVersionRange.java @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.feature; + +import java.util.Map; + +/** + * An extended {@link BaseVersionRange} representing the min/max versions for a supported feature. + */ +public class SupportedVersionRange extends BaseVersionRange { + // Label for the min version key, that's used only to convert to/from a map. + private static final String MIN_VERSION_KEY_LABEL = "min_version"; + + // Label for the max version key, that's used only to convert to/from a map. + private static final String MAX_VERSION_KEY_LABEL = "max_version"; + + public SupportedVersionRange(short minVersion, short maxVersion) { + super(MIN_VERSION_KEY_LABEL, minVersion, MAX_VERSION_KEY_LABEL, maxVersion); + } + + public SupportedVersionRange(short maxVersion) { + this((short) 1, maxVersion); + } + + public static SupportedVersionRange fromMap(Map<String, Short> versionRangeMap) { + return new SupportedVersionRange( + BaseVersionRange.valueOrThrow(MIN_VERSION_KEY_LABEL, versionRangeMap), + BaseVersionRange.valueOrThrow(MAX_VERSION_KEY_LABEL, versionRangeMap)); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/header/Header.java b/clients/src/main/java/org/apache/kafka/common/header/Header.java new file mode 100644 index 0000000..58869b4 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/header/Header.java @@ -0,0 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
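For reference, a minimal usage sketch of the feature version-range classes added above (BaseVersionRange, SupportedVersionRange, FinalizedVersionRange, Features). The feature name "metadata.version", the numeric version values, and the wrapper class name are illustrative assumptions, not part of this patch:

    import org.apache.kafka.common.feature.Features;
    import org.apache.kafka.common.feature.FinalizedVersionRange;
    import org.apache.kafka.common.feature.SupportedVersionRange;

    import java.util.Collections;
    import java.util.Map;

    public class FeatureRangeSketch {
        public static void main(String[] args) {
            // A broker advertises support for a feature at versions 1..5.
            Features<SupportedVersionRange> supported = Features.supportedFeatures(
                Collections.singletonMap("metadata.version", new SupportedVersionRange((short) 1, (short) 5)));

            // The cluster has finalized the feature at version levels 2..4.
            FinalizedVersionRange finalized = new FinalizedVersionRange((short) 2, (short) 4);

            // isIncompatibleWith is true only when the finalized [min, max] falls outside
            // the supported range; here it is fully contained, so this prints false.
            System.out.println(finalized.isIncompatibleWith(supported.get("metadata.version")));

            // Round-trip through the map form produced by toMap().
            Map<String, Map<String, Short>> asMap = supported.toMap();
            Features<SupportedVersionRange> restored = Features.fromSupportedFeaturesMap(asMap);
            System.out.println(restored.equals(supported)); // true
        }
    }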
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.header; + +public interface Header { + + String key(); + + byte[] value(); + +} diff --git a/clients/src/main/java/org/apache/kafka/common/header/Headers.java b/clients/src/main/java/org/apache/kafka/common/header/Headers.java new file mode 100644 index 0000000..2353249 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/header/Headers.java @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.header; + +public interface Headers extends Iterable
<Header> { + + /** + * Adds a header (key inside) to the end, returning this instance of the Headers. + * + * @param header the Header to be added + * @return this instance of the Headers, once the header is added. + * @throws IllegalStateException is thrown if headers are in a read-only state. + */ + Headers add(Header header) throws IllegalStateException; + + /** + * Creates and adds a header to the end, returning this instance of the Headers. + * + * @param key of the header to be added. + * @param value of the header to be added. + * @return this instance of the Headers, once the header is added. + * @throws IllegalStateException is thrown if headers are in a read-only state. + */ + Headers add(String key, byte[] value) throws IllegalStateException; + + /** + * Removes all headers for the given key, returning this instance of the Headers. + * + * @param key to remove all headers for. + * @return this instance of the Headers, once the headers are removed. + * @throws IllegalStateException is thrown if headers are in a read-only state. + */ + Headers remove(String key) throws IllegalStateException; + + /** + * Returns just one (the very last) header for the given key, if present. + * + * @param key to get the last header for. + * @return the last header matching the given key, or null if none is present. + */ + Header lastHeader(String key); + + /** + * Returns all headers for the given key, in the order they were added. + * + * @param key to return the headers for. + * @return all headers for the given key, in the order they were added; if no headers are present an empty iterable is returned. + */ + Iterable<Header>
        headers(String key); + + /** + * Returns all headers as an array, in the order they were added in. + * + * @return the headers as a Header[], mutating this array will not affect the Headers, if NO headers are present an empty array is returned. + */ + Header[] toArray(); + +} diff --git a/clients/src/main/java/org/apache/kafka/common/header/internals/RecordHeader.java b/clients/src/main/java/org/apache/kafka/common/header/internals/RecordHeader.java new file mode 100644 index 0000000..2a29d9d --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/header/internals/RecordHeader.java @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.header.internals; + +import java.nio.ByteBuffer; +import java.util.Arrays; +import java.util.Objects; + +import org.apache.kafka.common.header.Header; +import org.apache.kafka.common.utils.Utils; + +public class RecordHeader implements Header { + private ByteBuffer keyBuffer; + private String key; + private ByteBuffer valueBuffer; + private byte[] value; + + public RecordHeader(String key, byte[] value) { + Objects.requireNonNull(key, "Null header keys are not permitted"); + this.key = key; + this.value = value; + } + + public RecordHeader(ByteBuffer keyBuffer, ByteBuffer valueBuffer) { + this.keyBuffer = Objects.requireNonNull(keyBuffer, "Null header keys are not permitted"); + this.valueBuffer = valueBuffer; + } + + public String key() { + if (key == null) { + key = Utils.utf8(keyBuffer, keyBuffer.remaining()); + keyBuffer = null; + } + return key; + } + + public byte[] value() { + if (value == null && valueBuffer != null) { + value = Utils.toArray(valueBuffer); + valueBuffer = null; + } + return value; + } + + @Override + public boolean equals(Object o) { + if (this == o) + return true; + if (o == null || getClass() != o.getClass()) + return false; + + RecordHeader header = (RecordHeader) o; + return Objects.equals(key(), header.key()) && + Arrays.equals(value(), header.value()); + } + + @Override + public int hashCode() { + int result = key().hashCode(); + result = 31 * result + Arrays.hashCode(value()); + return result; + } + + @Override + public String toString() { + return "RecordHeader(key = " + key() + ", value = " + Arrays.toString(value()) + ")"; + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/header/internals/RecordHeaders.java b/clients/src/main/java/org/apache/kafka/common/header/internals/RecordHeaders.java new file mode 100644 index 0000000..7137f72 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/header/internals/RecordHeaders.java @@ -0,0 +1,194 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
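For reference, a minimal sketch of the two RecordHeader construction paths shown above: one from application-supplied key/value, one from wire-level ByteBuffers that are decoded lazily. The key "trace-id", the value bytes, and the wrapper class name are illustrative assumptions:

    import org.apache.kafka.common.header.internals.RecordHeader;

    import java.nio.ByteBuffer;
    import java.nio.charset.StandardCharsets;

    public class RecordHeaderSketch {
        public static void main(String[] args) {
            // Header as read back from the wire: key and value arrive as ByteBuffers
            // and are converted to String/byte[] only when first accessed.
            RecordHeader fromBuffers = new RecordHeader(
                ByteBuffer.wrap("trace-id".getBytes(StandardCharsets.UTF_8)),
                ByteBuffer.wrap("abc".getBytes(StandardCharsets.UTF_8)));

            // Header as created by application code.
            RecordHeader fromStrings = new RecordHeader("trace-id", "abc".getBytes(StandardCharsets.UTF_8));

            // equals() compares the decoded key and value, so both forms compare equal.
            System.out.println(fromBuffers.equals(fromStrings)); // true
        }
    }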
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.header.internals; + +import org.apache.kafka.common.header.Header; +import org.apache.kafka.common.header.Headers; +import org.apache.kafka.common.record.Record; +import org.apache.kafka.common.utils.AbstractIterator; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Iterator; +import java.util.List; +import java.util.Objects; + +public class RecordHeaders implements Headers { + + private final List
<Header> headers; + private volatile boolean isReadOnly; + + public RecordHeaders() { + this((Iterable<Header>) null); + } + + public RecordHeaders(Header[] headers) { + this(headers == null ? null : Arrays.asList(headers)); + } + + public RecordHeaders(Iterable<Header> headers) { + //Use efficient copy constructor if possible, fallback to iteration otherwise + if (headers == null) { + this.headers = new ArrayList<>(); + } else if (headers instanceof RecordHeaders) { + this.headers = new ArrayList<>(((RecordHeaders) headers).headers); + } else { + this.headers = new ArrayList<>(); + for (Header header : headers) { + Objects.requireNonNull(header, "Header cannot be null."); + this.headers.add(header); + } + } + }
+ + @Override + public Headers add(Header header) throws IllegalStateException { + Objects.requireNonNull(header, "Header cannot be null."); + canWrite(); + headers.add(header); + return this; + } + + @Override + public Headers add(String key, byte[] value) throws IllegalStateException { + return add(new RecordHeader(key, value)); + } + + @Override + public Headers remove(String key) throws IllegalStateException { + canWrite(); + checkKey(key); + Iterator<Header> iterator = iterator(); + while (iterator.hasNext()) { + if (iterator.next().key().equals(key)) { + iterator.remove(); + } + } + return this; + } + + @Override + public Header lastHeader(String key) { + checkKey(key); + for (int i = headers.size() - 1; i >= 0; i--) { + Header header = headers.get(i); + if (header.key().equals(key)) { + return header; + } + } + return null; + }
+ + @Override + public Iterable<Header> headers(final String key) { + checkKey(key); + return () -> new FilterByKeyIterator(headers.iterator(), key); + } + + @Override + public Iterator<Header> iterator() { + return closeAware(headers.iterator()); + } + + public void setReadOnly() { + this.isReadOnly = true; + } + + public Header[] toArray() { + return headers.isEmpty() ? Record.EMPTY_HEADERS : headers.toArray(new Header[0]); + } + + private void checkKey(String key) { + if (key == null) + throw new IllegalArgumentException("key cannot be null."); + } + + private void canWrite() { + if (isReadOnly) + throw new IllegalStateException("RecordHeaders has been closed."); + }
+ + private Iterator<Header> closeAware(final Iterator<Header> original) { + return new Iterator<Header>() { + @Override + public boolean hasNext() { + return original.hasNext(); + } + + public Header next() { + return original.next(); + } + + @Override + public void remove() { + canWrite(); + original.remove(); + } + }; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + + RecordHeaders headers1 = (RecordHeaders) o; + + return Objects.equals(headers, headers1.headers); + } + + @Override + public int hashCode() { + return headers != null ? headers.hashCode() : 0; + } + + @Override + public String toString() { + return "RecordHeaders(" + + "headers = " + headers + + ", isReadOnly = " + isReadOnly + + ')'; + } + + private static final class FilterByKeyIterator extends AbstractIterator<Header> { + + private final Iterator<Header> original; + private final String key; + + private FilterByKeyIterator(Iterator<Header>
        original, String key) { + this.original = original; + this.key = key; + } + + protected Header makeNext() { + while (true) { + if (original.hasNext()) { + Header header = original.next(); + if (!header.key().equals(key)) + continue; + + return header; + } + return this.allDone(); + } + } + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/internals/ClusterResourceListeners.java b/clients/src/main/java/org/apache/kafka/common/internals/ClusterResourceListeners.java new file mode 100644 index 0000000..1209f38 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/internals/ClusterResourceListeners.java @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.internals; + +import org.apache.kafka.common.ClusterResource; +import org.apache.kafka.common.ClusterResourceListener; + +import java.util.ArrayList; +import java.util.List; + +public class ClusterResourceListeners { + + private final List clusterResourceListeners; + + public ClusterResourceListeners() { + this.clusterResourceListeners = new ArrayList<>(); + } + + /** + * Add only if the candidate implements {@link ClusterResourceListener}. + * @param candidate Object which might implement {@link ClusterResourceListener} + */ + public void maybeAdd(Object candidate) { + if (candidate instanceof ClusterResourceListener) { + clusterResourceListeners.add((ClusterResourceListener) candidate); + } + } + + /** + * Add all items who implement {@link ClusterResourceListener} from the list. + * @param candidateList List of objects which might implement {@link ClusterResourceListener} + */ + public void maybeAddAll(List candidateList) { + for (Object candidate : candidateList) { + this.maybeAdd(candidate); + } + } + + /** + * Send the updated cluster metadata to all {@link ClusterResourceListener}. + * @param cluster Cluster metadata + */ + public void onUpdate(ClusterResource cluster) { + for (ClusterResourceListener clusterResourceListener : clusterResourceListeners) { + clusterResourceListener.onUpdate(cluster); + } + } +} \ No newline at end of file diff --git a/clients/src/main/java/org/apache/kafka/common/internals/FatalExitError.java b/clients/src/main/java/org/apache/kafka/common/internals/FatalExitError.java new file mode 100644 index 0000000..901af45 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/internals/FatalExitError.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
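For reference, a minimal usage sketch of the RecordHeaders implementation above, showing insertion order, lastHeader, per-key iteration, and the read-only switch. The key "trace-id", the values, and the wrapper class name are illustrative assumptions:

    import org.apache.kafka.common.header.Header;
    import org.apache.kafka.common.header.internals.RecordHeader;
    import org.apache.kafka.common.header.internals.RecordHeaders;

    import java.nio.charset.StandardCharsets;

    public class HeadersSketch {
        public static void main(String[] args) {
            RecordHeaders headers = new RecordHeaders();
            headers.add("trace-id", "abc".getBytes(StandardCharsets.UTF_8))
                   .add(new RecordHeader("trace-id", "def".getBytes(StandardCharsets.UTF_8)));

            // lastHeader returns the most recently added value for the key: "def".
            Header last = headers.lastHeader("trace-id");
            System.out.println(new String(last.value(), StandardCharsets.UTF_8));

            // headers(key) iterates all values for the key in insertion order: abc, def.
            for (Header h : headers.headers("trace-id"))
                System.out.println(new String(h.value(), StandardCharsets.UTF_8));

            // Once marked read-only, any further add/remove throws IllegalStateException.
            headers.setReadOnly();
        }
    }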
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.internals; + +import org.apache.kafka.common.utils.Exit; + +/** + * An error that indicates the need to exit the JVM process. This should only be used by the server or command-line + * tools. Clients should never shutdown the JVM process. + * + * This exception is expected to be caught at the highest level of the thread so that no shared lock is held by + * the thread when it calls {@link Exit#exit(int)}. + */ +public class FatalExitError extends Error { + + private final static long serialVersionUID = 1L; + + private final int statusCode; + + public FatalExitError(int statusCode) { + if (statusCode == 0) + throw new IllegalArgumentException("statusCode must not be 0"); + this.statusCode = statusCode; + } + + public FatalExitError() { + this(1); + } + + public int statusCode() { + return statusCode; + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/internals/KafkaCompletableFuture.java b/clients/src/main/java/org/apache/kafka/common/internals/KafkaCompletableFuture.java new file mode 100644 index 0000000..8c75ea4 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/internals/KafkaCompletableFuture.java @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.internals; + +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.Executor; +import java.util.concurrent.TimeUnit; +import java.util.function.Supplier; + +/** + * This internal class exists because CompletableFuture exposes complete(), completeExceptionally() and + * other methods which would allow erroneous completion by user code of a KafkaFuture returned from a + * Kafka API to a client application. + * @param The type of the future value. + */ +public class KafkaCompletableFuture extends CompletableFuture { + + /** + * Completes this future normally. For internal use by the Kafka clients, not by user code. 
+ * @param value the result value + * @return {@code true} if this invocation caused this CompletableFuture + * to transition to a completed state, else {@code false} + */ + boolean kafkaComplete(T value) { + return super.complete(value); + } + + /** + * Completes this future exceptionally. For internal use by the Kafka clients, not by user code. + * @param throwable the exception. + * @return {@code true} if this invocation caused this CompletableFuture + * to transition to a completed state, else {@code false} + */ + boolean kafkaCompleteExceptionally(Throwable throwable) { + return super.completeExceptionally(throwable); + } + + @Override + public boolean complete(T value) { + throw erroneousCompletionException(); + } + + @Override + public boolean completeExceptionally(Throwable ex) { + throw erroneousCompletionException(); + } + + @Override + public void obtrudeValue(T value) { + throw erroneousCompletionException(); + } + + @Override + public void obtrudeException(Throwable ex) { + throw erroneousCompletionException(); + } + + //@Override // enable once Kafka no longer supports Java 8 + public CompletableFuture newIncompleteFuture() { + return new KafkaCompletableFuture<>(); + } + + //@Override // enable once Kafka no longer supports Java 8 + public CompletableFuture completeAsync(Supplier supplier, Executor executor) { + throw erroneousCompletionException(); + } + + //@Override // enable once Kafka no longer supports Java 8 + public CompletableFuture completeAsync(Supplier supplier) { + throw erroneousCompletionException(); + } + + //@Override // enable once Kafka no longer supports Java 8 + public CompletableFuture completeOnTimeout(T value, long timeout, TimeUnit unit) { + throw erroneousCompletionException(); + } + + private UnsupportedOperationException erroneousCompletionException() { + return new UnsupportedOperationException("User code should not complete futures returned from Kafka clients"); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/internals/KafkaFutureImpl.java b/clients/src/main/java/org/apache/kafka/common/internals/KafkaFutureImpl.java new file mode 100644 index 0000000..711bd25 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/internals/KafkaFutureImpl.java @@ -0,0 +1,257 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.internals; + +import java.util.concurrent.CancellationException; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; +import java.util.concurrent.CompletionStage; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + +import org.apache.kafka.common.KafkaFuture; + +/** + * A flexible future which supports call chaining and other asynchronous programming patterns. + */ +public class KafkaFutureImpl extends KafkaFuture { + + private final KafkaCompletableFuture completableFuture; + + private final boolean isDependant; + + public KafkaFutureImpl() { + this(false, new KafkaCompletableFuture<>()); + } + + private KafkaFutureImpl(boolean isDependant, KafkaCompletableFuture completableFuture) { + this.isDependant = isDependant; + this.completableFuture = completableFuture; + } + + @Override + public CompletionStage toCompletionStage() { + return completableFuture; + } + + /** + * Returns a new KafkaFuture that, when this future completes normally, is executed with this + * futures's result as the argument to the supplied function. + */ + @Override + public KafkaFuture thenApply(BaseFunction function) { + CompletableFuture appliedFuture = completableFuture.thenApply(value -> { + try { + return function.apply(value); + } catch (Throwable t) { + if (t instanceof CompletionException) { + // KafkaFuture#thenApply, when the function threw CompletionException should return + // an ExecutionException wrapping a CompletionException wrapping the exception thrown by the + // function. CompletableFuture#thenApply will just return ExecutionException wrapping the + // exception thrown by the function, so we add an extra CompletionException here to + // maintain the KafkaFuture behaviour. + throw new CompletionException(t); + } else { + throw t; + } + } + }); + return new KafkaFutureImpl<>(true, toKafkaCompletableFuture(appliedFuture)); + } + + private static KafkaCompletableFuture toKafkaCompletableFuture(CompletableFuture completableFuture) { + if (completableFuture instanceof KafkaCompletableFuture) { + return (KafkaCompletableFuture) completableFuture; + } else { + final KafkaCompletableFuture result = new KafkaCompletableFuture<>(); + completableFuture.whenComplete((x, y) -> { + if (y != null) { + result.kafkaCompleteExceptionally(y); + } else { + result.kafkaComplete(x); + } + }); + return result; + } + } + + /** + * @see KafkaFutureImpl#thenApply(BaseFunction) + * @deprecated Since Kafka 3.0. 
+ */ + @Deprecated + @Override + public KafkaFuture thenApply(Function function) { + return thenApply((BaseFunction) function); + } + + @Override + public KafkaFuture whenComplete(final BiConsumer biConsumer) { + CompletableFuture tCompletableFuture = completableFuture.whenComplete((java.util.function.BiConsumer) (a, b) -> { + try { + biConsumer.accept(a, b); + } catch (Throwable t) { + if (t instanceof CompletionException) { + throw new CompletionException(t); + } else { + throw t; + } + } + }); + return new KafkaFutureImpl<>(true, toKafkaCompletableFuture(tCompletableFuture)); + } + + + @Override + public boolean complete(T newValue) { + return completableFuture.kafkaComplete(newValue); + } + + @Override + public boolean completeExceptionally(Throwable newException) { + // CompletableFuture#get() always wraps the _cause_ of a CompletionException in ExecutionException + // (which KafkaFuture does not) so wrap CompletionException in an extra one to avoid losing the + // first CompletionException in the exception chain. + return completableFuture.kafkaCompleteExceptionally( + newException instanceof CompletionException ? new CompletionException(newException) : newException); + } + + /** + * If not already completed, completes this future with a CancellationException. Dependent + * futures that have not already completed will also complete exceptionally, with a + * CompletionException caused by this CancellationException. + */ + @Override + public boolean cancel(boolean mayInterruptIfRunning) { + return completableFuture.cancel(mayInterruptIfRunning); + } + + /** + * We need to deal with differences between KafkaFuture's historic API and the API of CompletableFuture: + * CompletableFuture#get() does not wrap CancellationException in ExecutionException (nor does KafkaFuture). + * CompletableFuture#get() always wraps the _cause_ of a CompletionException in ExecutionException + * (which KafkaFuture does not). + * + * The semantics for KafkaFuture are that all exceptional completions of the future (via #completeExceptionally() + * or exceptions from dependants) manifest as ExecutionException, as observed via both get() and getNow(). + */ + private void maybeThrowCancellationException(Throwable cause) { + if (cause instanceof CancellationException) { + throw (CancellationException) cause; + } + } + + /** + * Waits if necessary for this future to complete, and then returns its result. + */ + @Override + public T get() throws InterruptedException, ExecutionException { + try { + return completableFuture.get(); + } catch (ExecutionException e) { + maybeThrowCancellationException(e.getCause()); + throw e; + } + } + + /** + * Waits if necessary for at most the given time for this future to complete, and then returns + * its result, if available. + */ + @Override + public T get(long timeout, TimeUnit unit) throws InterruptedException, ExecutionException, + TimeoutException { + try { + return completableFuture.get(timeout, unit); + } catch (ExecutionException e) { + maybeThrowCancellationException(e.getCause()); + throw e; + } + } + + /** + * Returns the result value (or throws any encountered exception) if completed, else returns + * the given valueIfAbsent. 
+ */ + @Override + public T getNow(T valueIfAbsent) throws ExecutionException { + try { + return completableFuture.getNow(valueIfAbsent); + } catch (CompletionException e) { + maybeThrowCancellationException(e.getCause()); + // Note, unlike CompletableFuture#get() which throws ExecutionException, CompletableFuture#getNow() + // throws CompletionException, thus needs rewrapping to conform to KafkaFuture API, + // where KafkaFuture#getNow() throws ExecutionException. + throw new ExecutionException(e.getCause()); + } + } + + /** + * Returns true if this CompletableFuture was cancelled before it completed normally. + */ + @Override + public boolean isCancelled() { + if (isDependant) { + // Having isCancelled() for a dependent future just return + // CompletableFuture.isCancelled() would break the historical KafkaFuture behaviour because + // CompletableFuture#isCancelled() just checks for the exception being CancellationException + // whereas it will be a CompletionException wrapping a CancellationException + // due needing to compensate for CompletableFuture's CompletionException unwrapping + // shenanigans in other methods. + try { + completableFuture.getNow(null); + return false; + } catch (Exception e) { + return e instanceof CompletionException + && e.getCause() instanceof CancellationException; + } + } else { + return completableFuture.isCancelled(); + } + } + + /** + * Returns true if this CompletableFuture completed exceptionally, in any way. + */ + @Override + public boolean isCompletedExceptionally() { + return completableFuture.isCompletedExceptionally(); + } + + /** + * Returns true if completed in any fashion: normally, exceptionally, or via cancellation. + */ + @Override + public boolean isDone() { + return completableFuture.isDone(); + } + + @Override + public String toString() { + T value = null; + Throwable exception = null; + try { + value = completableFuture.getNow(null); + } catch (CompletionException e) { + exception = e.getCause(); + } catch (Exception e) { + exception = e; + } + return String.format("KafkaFuture{value=%s,exception=%s,done=%b}", value, exception, exception != null || value != null); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/internals/PartitionStates.java b/clients/src/main/java/org/apache/kafka/common/internals/PartitionStates.java new file mode 100644 index 0000000..96652df --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/internals/PartitionStates.java @@ -0,0 +1,196 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
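For reference, a minimal sketch of the KafkaFutureImpl completion semantics described above: dependants created with thenApply complete with the parent, and exceptional completion surfaces as ExecutionException from get(). The values and the wrapper class name are illustrative assumptions:

    import org.apache.kafka.common.KafkaFuture;
    import org.apache.kafka.common.internals.KafkaFutureImpl;

    import java.util.concurrent.ExecutionException;

    public class FutureSketch {
        public static void main(String[] args) throws Exception {
            KafkaFutureImpl<Integer> future = new KafkaFutureImpl<>();

            // A dependant future completes when the parent completes.
            KafkaFuture<String> dependant = future.thenApply(v -> "value=" + v);

            // Completion happens through KafkaFutureImpl; the CompletionStage view from
            // toCompletionStage() rejects complete()/obtrude calls (see KafkaCompletableFuture).
            future.complete(42);
            System.out.println(dependant.get());   // value=42
            System.out.println(future.getNow(-1)); // 42

            // Exceptional completion is observed as ExecutionException from get().
            KafkaFutureImpl<Integer> failed = new KafkaFutureImpl<>();
            failed.completeExceptionally(new IllegalStateException("boom"));
            try {
                failed.get();
            } catch (ExecutionException e) {
                System.out.println(e.getCause().getMessage()); // boom
            }
        }
    }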
+ */ +package org.apache.kafka.common.internals; + +import org.apache.kafka.common.TopicPartition; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.Iterator; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Set; +import java.util.function.BiConsumer; + +/** + * This class is a useful building block for doing fetch requests where topic partitions have to be rotated via + * round-robin to ensure fairness and some level of determinism given the existence of a limit on the fetch response + * size. Because the serialization of fetch requests is more efficient if all partitions for the same topic are grouped + * together, we do such grouping in the method `set`. + * + * As partitions are moved to the end, the same topic may be repeated more than once. In the optimal case, a single + * topic would "wrap around" and appear twice. However, as partitions are fetched in different orders and partition + * leadership changes, we will deviate from the optimal. If this turns out to be an issue in practice, we can improve + * it by tracking the partitions per node or calling `set` every so often. + * + * Note that this class is not thread-safe with the exception of {@link #size()} which returns the number of + * partitions currently tracked. + */ +public class PartitionStates { + + private final LinkedHashMap map = new LinkedHashMap<>(); + private final Set partitionSetView = Collections.unmodifiableSet(map.keySet()); + + /* the number of partitions that are currently assigned available in a thread safe manner */ + private volatile int size = 0; + + public PartitionStates() {} + + public void moveToEnd(TopicPartition topicPartition) { + S state = map.remove(topicPartition); + if (state != null) + map.put(topicPartition, state); + } + + public void updateAndMoveToEnd(TopicPartition topicPartition, S state) { + map.remove(topicPartition); + map.put(topicPartition, state); + updateSize(); + } + + public void update(TopicPartition topicPartition, S state) { + map.put(topicPartition, state); + updateSize(); + } + + public void remove(TopicPartition topicPartition) { + map.remove(topicPartition); + updateSize(); + } + + /** + * Returns an unmodifiable view of the partitions in random order. + * changes to this PartitionStates instance will be reflected in this view. + */ + public Set partitionSet() { + return partitionSetView; + } + + public void clear() { + map.clear(); + updateSize(); + } + + public boolean contains(TopicPartition topicPartition) { + return map.containsKey(topicPartition); + } + + public Iterator stateIterator() { + return map.values().iterator(); + } + + public void forEach(BiConsumer biConsumer) { + map.forEach(biConsumer); + } + + public Map partitionStateMap() { + return Collections.unmodifiableMap(map); + } + + /** + * Returns the partition state values in order. + */ + public List partitionStateValues() { + return new ArrayList<>(map.values()); + } + + public S stateValue(TopicPartition topicPartition) { + return map.get(topicPartition); + } + + /** + * Get the number of partitions that are currently being tracked. This is thread-safe. + */ + public int size() { + return size; + } + + /** + * Update the builder to have the received map as its state (i.e. the previous state is cleared). 
The builder will + * "batch by topic", so if we have a, b and c, each with two partitions, we may end up with something like the + * following (the order of topics and partitions within topics is dependent on the iteration order of the received + * map): a0, a1, b1, b0, c0, c1. + */ + public void set(Map partitionToState) { + map.clear(); + update(partitionToState); + updateSize(); + } + + private void updateSize() { + size = map.size(); + } + + private void update(Map partitionToState) { + LinkedHashMap> topicToPartitions = new LinkedHashMap<>(); + for (TopicPartition tp : partitionToState.keySet()) { + List partitions = topicToPartitions.computeIfAbsent(tp.topic(), k -> new ArrayList<>()); + partitions.add(tp); + } + for (Map.Entry> entry : topicToPartitions.entrySet()) { + for (TopicPartition tp : entry.getValue()) { + S state = partitionToState.get(tp); + map.put(tp, state); + } + } + } + + public static class PartitionState { + private final TopicPartition topicPartition; + private final S value; + + public PartitionState(TopicPartition topicPartition, S state) { + this.topicPartition = Objects.requireNonNull(topicPartition); + this.value = Objects.requireNonNull(state); + } + + public S value() { + return value; + } + + @Override + public boolean equals(Object o) { + if (this == o) + return true; + if (o == null || getClass() != o.getClass()) + return false; + + PartitionState that = (PartitionState) o; + + return topicPartition.equals(that.topicPartition) && value.equals(that.value); + } + + @Override + public int hashCode() { + int result = topicPartition.hashCode(); + result = 31 * result + value.hashCode(); + return result; + } + + public TopicPartition topicPartition() { + return topicPartition; + } + + @Override + public String toString() { + return "PartitionState(" + topicPartition + "=" + value + ')'; + } + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/internals/Topic.java b/clients/src/main/java/org/apache/kafka/common/internals/Topic.java new file mode 100644 index 0000000..7a5fefb --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/internals/Topic.java @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
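For reference, a minimal sketch of the round-robin rotation that PartitionStates (above) is built for. The topic names, the Long state type, and the wrapper class name are illustrative assumptions:

    import org.apache.kafka.common.TopicPartition;
    import org.apache.kafka.common.internals.PartitionStates;

    import java.util.LinkedHashMap;
    import java.util.Map;

    public class PartitionStatesSketch {
        public static void main(String[] args) {
            PartitionStates<Long> states = new PartitionStates<>();

            // set() groups partitions by topic so fetch requests serialize efficiently.
            Map<TopicPartition, Long> initial = new LinkedHashMap<>();
            initial.put(new TopicPartition("a", 0), 0L);
            initial.put(new TopicPartition("b", 0), 0L);
            initial.put(new TopicPartition("a", 1), 0L);
            states.set(initial); // iteration order: a-0, a-1, b-0

            // After fetching from a-0, rotate it to the end so other partitions get a turn.
            states.moveToEnd(new TopicPartition("a", 0));
            System.out.println(states.partitionSet()); // [a-1, b-0, a-0]
        }
    }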
+ */ +package org.apache.kafka.common.internals; + +import org.apache.kafka.common.errors.InvalidTopicException; +import org.apache.kafka.common.utils.Utils; + +import java.util.Collections; +import java.util.Set; +import java.util.function.Consumer; + +public class Topic { + + public static final String GROUP_METADATA_TOPIC_NAME = "__consumer_offsets"; + public static final String TRANSACTION_STATE_TOPIC_NAME = "__transaction_state"; + public static final String LEGAL_CHARS = "[a-zA-Z0-9._-]"; + + private static final Set INTERNAL_TOPICS = Collections.unmodifiableSet( + Utils.mkSet(GROUP_METADATA_TOPIC_NAME, TRANSACTION_STATE_TOPIC_NAME)); + + private static final int MAX_NAME_LENGTH = 249; + + public static void validate(String topic) { + validate(topic, "Topic name", message -> { + throw new InvalidTopicException(message); + }); + } + + public static void validate(String name, String logPrefix, Consumer throwableConsumer) { + if (name.isEmpty()) + throwableConsumer.accept(logPrefix + " is illegal, it can't be empty"); + if (".".equals(name) || "..".equals(name)) + throwableConsumer.accept(logPrefix + " cannot be \".\" or \"..\""); + if (name.length() > MAX_NAME_LENGTH) + throwableConsumer.accept(logPrefix + " is illegal, it can't be longer than " + MAX_NAME_LENGTH + + " characters, " + logPrefix + ": " + name); + if (!containsValidPattern(name)) + throwableConsumer.accept(logPrefix + " \"" + name + "\" is illegal, it contains a character other than " + + "ASCII alphanumerics, '.', '_' and '-'"); + } + + public static boolean isInternal(String topic) { + return INTERNAL_TOPICS.contains(topic); + } + + /** + * Due to limitations in metric names, topics with a period ('.') or underscore ('_') could collide. + * + * @param topic The topic to check for colliding character + * @return true if the topic has collision characters + */ + public static boolean hasCollisionChars(String topic) { + return topic.contains("_") || topic.contains("."); + } + + /** + * Returns true if the topicNames collide due to a period ('.') or underscore ('_') in the same position. + * + * @param topicA A topic to check for collision + * @param topicB A topic to check for collision + * @return true if the topics collide + */ + public static boolean hasCollision(String topicA, String topicB) { + return topicA.replace('.', '_').equals(topicB.replace('.', '_')); + } + + /** + * Valid characters for Kafka topics are the ASCII alphanumerics, '.', '_', and '-' + */ + static boolean containsValidPattern(String topic) { + for (int i = 0; i < topic.length(); ++i) { + char c = topic.charAt(i); + + // We don't use Character.isLetterOrDigit(c) because it's slower + boolean validChar = (c >= 'a' && c <= 'z') || (c >= '0' && c <= '9') || (c >= 'A' && c <= 'Z') || c == '.' || + c == '_' || c == '-'; + if (!validChar) + return false; + } + return true; + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/memory/GarbageCollectedMemoryPool.java b/clients/src/main/java/org/apache/kafka/common/memory/GarbageCollectedMemoryPool.java new file mode 100644 index 0000000..18f8ffe --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/memory/GarbageCollectedMemoryPool.java @@ -0,0 +1,168 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
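For reference, a minimal sketch of the topic-name rules enforced by the Topic utility above (length, legal characters, internal topics, and metric-name collisions). The topic names and the wrapper class name are illustrative assumptions:

    import org.apache.kafka.common.errors.InvalidTopicException;
    import org.apache.kafka.common.internals.Topic;

    public class TopicNameSketch {
        public static void main(String[] args) {
            Topic.validate("orders.v1"); // passes: only characters from [a-zA-Z0-9._-]
            System.out.println(Topic.isInternal("__consumer_offsets")); // true

            // "orders.v1" and "orders_v1" map to the same metric name, so they collide.
            System.out.println(Topic.hasCollision("orders.v1", "orders_v1")); // true

            try {
                Topic.validate("orders/v1"); // '/' is not a legal topic character
            } catch (InvalidTopicException e) {
                System.out.println(e.getMessage());
            }
        }
    }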
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.memory; + +import org.apache.kafka.common.metrics.Sensor; +import org.apache.kafka.common.utils.Utils; + +import java.lang.ref.ReferenceQueue; +import java.lang.ref.WeakReference; +import java.nio.ByteBuffer; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + + +/** + * An extension of SimpleMemoryPool that tracks allocated buffers and logs an error when they "leak" + * (when they are garbage-collected without having been release()ed). + * THIS IMPLEMENTATION IS A DEVELOPMENT/DEBUGGING AID AND IS NOT MEANT PRO PRODUCTION USE. + */ +public class GarbageCollectedMemoryPool extends SimpleMemoryPool implements AutoCloseable { + + private final ReferenceQueue garbageCollectedBuffers = new ReferenceQueue<>(); + //serves 2 purposes - 1st it maintains the ref objects reachable (which is a requirement for them + //to ever be enqueued), 2nd keeps some (small) metadata for every buffer allocated + private final Map buffersInFlight = new ConcurrentHashMap<>(); + private final GarbageCollectionListener gcListener = new GarbageCollectionListener(); + private final Thread gcListenerThread; + private volatile boolean alive = true; + + public GarbageCollectedMemoryPool(long sizeBytes, int maxSingleAllocationSize, boolean strict, Sensor oomPeriodSensor) { + super(sizeBytes, maxSingleAllocationSize, strict, oomPeriodSensor); + this.alive = true; + this.gcListenerThread = new Thread(gcListener, "memory pool GC listener"); + this.gcListenerThread.setDaemon(true); //so we dont need to worry about shutdown + this.gcListenerThread.start(); + } + + @Override + protected void bufferToBeReturned(ByteBuffer justAllocated) { + BufferReference ref = new BufferReference(justAllocated, garbageCollectedBuffers); + BufferMetadata metadata = new BufferMetadata(justAllocated.capacity()); + if (buffersInFlight.put(ref, metadata) != null) + //this is a bug. it means either 2 different co-existing buffers got + //the same identity or we failed to register a released/GC'ed buffer + throw new IllegalStateException("allocated buffer identity " + ref.hashCode + " already registered as in use?!"); + + log.trace("allocated buffer of size {} and identity {}", sizeBytes, ref.hashCode); + } + + @Override + protected void bufferToBeReleased(ByteBuffer justReleased) { + BufferReference ref = new BufferReference(justReleased); //used ro lookup only + BufferMetadata metadata = buffersInFlight.remove(ref); + if (metadata == null) + //its impossible for the buffer to have already been GC'ed (because we have a hard ref to it + //in the function arg) so this means either a double free or not our buffer. 
+ throw new IllegalArgumentException("returned buffer " + ref.hashCode + " was never allocated by this pool"); + if (metadata.sizeBytes != justReleased.capacity()) { + //this is a bug + throw new IllegalStateException("buffer " + ref.hashCode + " has capacity " + justReleased.capacity() + " but recorded as " + metadata.sizeBytes); + } + log.trace("released buffer of size {} and identity {}", metadata.sizeBytes, ref.hashCode); + } + + @Override + public void close() { + alive = false; + gcListenerThread.interrupt(); + } + + private class GarbageCollectionListener implements Runnable { + @Override + public void run() { + while (alive) { + try { + BufferReference ref = (BufferReference) garbageCollectedBuffers.remove(); //blocks + ref.clear(); + //this cannot race with a release() call because an object is either reachable or not, + //release() can only happen before its GC'ed, and enqueue can only happen after. + //if the ref was enqueued it must then not have been released + BufferMetadata metadata = buffersInFlight.remove(ref); + + if (metadata == null) { + //it can happen rarely that the buffer was release()ed properly (so no metadata) and yet + //the reference object to it remains reachable for a short period of time after release() + //and hence gets enqueued. this is because we keep refs in a ConcurrentHashMap which cleans + //up keys lazily. + continue; + } + + availableMemory.addAndGet(metadata.sizeBytes); + log.error("Reclaimed buffer of size {} and identity {} that was not properly release()ed. This is a bug.", metadata.sizeBytes, ref.hashCode); + } catch (InterruptedException e) { + log.debug("interrupted", e); + //ignore, we're a daemon thread + } + } + log.info("GC listener shutting down"); + } + } + + private static final class BufferMetadata { + private final int sizeBytes; + + private BufferMetadata(int sizeBytes) { + this.sizeBytes = sizeBytes; + } + } + + private static final class BufferReference extends WeakReference { + private final int hashCode; + + private BufferReference(ByteBuffer referent) { //used for lookup purposes only - no queue required. + this(referent, null); + } + + private BufferReference(ByteBuffer referent, ReferenceQueue q) { + super(referent, q); + hashCode = System.identityHashCode(referent); + } + + @Override + public boolean equals(Object o) { + if (this == o) { //this is important to find leaked buffers (by ref identity) + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + BufferReference that = (BufferReference) o; + if (hashCode != that.hashCode) { + return false; + } + ByteBuffer thisBuf = get(); + if (thisBuf == null) { + //our buffer has already been GC'ed, yet "that" is not us. 
so not same buffer + return false; + } + ByteBuffer thatBuf = that.get(); + return thisBuf == thatBuf; + } + + @Override + public int hashCode() { + return hashCode; + } + } + + @Override + public String toString() { + long allocated = sizeBytes - availableMemory.get(); + return "GarbageCollectedMemoryPool{" + Utils.formatBytes(allocated) + "/" + Utils.formatBytes(sizeBytes) + " used in " + buffersInFlight.size() + " buffers}"; + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/memory/MemoryPool.java b/clients/src/main/java/org/apache/kafka/common/memory/MemoryPool.java new file mode 100644 index 0000000..5887816 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/memory/MemoryPool.java @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.memory; + +import java.nio.ByteBuffer; + + +/** + * A common memory pool interface for non-blocking pools. + * Every buffer returned from {@link #tryAllocate(int)} must always be {@link #release(ByteBuffer) released}. + */ +public interface MemoryPool { + MemoryPool NONE = new MemoryPool() { + @Override + public ByteBuffer tryAllocate(int sizeBytes) { + return ByteBuffer.allocate(sizeBytes); + } + + @Override + public void release(ByteBuffer previouslyAllocated) { + //nop + } + + @Override + public long size() { + return Long.MAX_VALUE; + } + + @Override + public long availableMemory() { + return Long.MAX_VALUE; + } + + @Override + public boolean isOutOfMemory() { + return false; + } + + @Override + public String toString() { + return "NONE"; + } + }; + + /** + * Tries to acquire a ByteBuffer of the specified size + * @param sizeBytes size required + * @return a ByteBuffer (which later needs to be release()ed), or null if no memory available. + * the buffer will be of the exact size requested, even if backed by a larger chunk of memory + */ + ByteBuffer tryAllocate(int sizeBytes); + + /** + * Returns a previously allocated buffer to the pool. + * @param previouslyAllocated a buffer previously returned from tryAllocate() + */ + void release(ByteBuffer previouslyAllocated); + + /** + * Returns the total size of this pool + * @return total size, in bytes + */ + long size(); + + /** + * Returns the amount of memory available for allocation by this pool. + * NOTE: result may be negative (pools may over allocate to avoid starvation issues) + * @return bytes available + */ + long availableMemory(); + + /** + * Returns true if the pool cannot currently allocate any more buffers + * - meaning total outstanding buffers meets or exceeds pool size and + * some would need to be released before further allocations are possible. 
+ * + * This is equivalent to availableMemory() <= 0 + * @return true if out of memory + */ + boolean isOutOfMemory(); +} diff --git a/clients/src/main/java/org/apache/kafka/common/memory/SimpleMemoryPool.java b/clients/src/main/java/org/apache/kafka/common/memory/SimpleMemoryPool.java new file mode 100644 index 0000000..f1ab8f7 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/memory/SimpleMemoryPool.java @@ -0,0 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.memory; + +import java.nio.ByteBuffer; +import java.util.concurrent.atomic.AtomicLong; + +import org.apache.kafka.common.metrics.Sensor; +import org.apache.kafka.common.utils.Utils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + + +/** + * a simple pool implementation. this implementation just provides a limit on the total outstanding memory. + * any buffer allocated must be release()ed always otherwise memory is not marked as reclaimed (and "leak"s) + */ +public class SimpleMemoryPool implements MemoryPool { + protected final Logger log = LoggerFactory.getLogger(getClass()); //subclass-friendly + + protected final long sizeBytes; + protected final boolean strict; + protected final AtomicLong availableMemory; + protected final int maxSingleAllocationSize; + protected final AtomicLong startOfNoMemPeriod = new AtomicLong(); //nanoseconds + protected volatile Sensor oomTimeSensor; + + public SimpleMemoryPool(long sizeInBytes, int maxSingleAllocationBytes, boolean strict, Sensor oomPeriodSensor) { + if (sizeInBytes <= 0 || maxSingleAllocationBytes <= 0 || maxSingleAllocationBytes > sizeInBytes) + throw new IllegalArgumentException("must provide a positive size and max single allocation size smaller than size." + + "provided " + sizeInBytes + " and " + maxSingleAllocationBytes + " respectively"); + this.sizeBytes = sizeInBytes; + this.strict = strict; + this.availableMemory = new AtomicLong(sizeInBytes); + this.maxSingleAllocationSize = maxSingleAllocationBytes; + this.oomTimeSensor = oomPeriodSensor; + } + + @Override + public ByteBuffer tryAllocate(int sizeBytes) { + if (sizeBytes < 1) + throw new IllegalArgumentException("requested size " + sizeBytes + "<=0"); + if (sizeBytes > maxSingleAllocationSize) + throw new IllegalArgumentException("requested size " + sizeBytes + " is larger than maxSingleAllocationSize " + maxSingleAllocationSize); + + long available; + boolean success = false; + //in strict mode we will only allocate memory if we have at least the size required. + //in non-strict mode we will allocate memory if we have _any_ memory available (so available memory + //can dip into the negative and max allocated memory would be sizeBytes + maxSingleAllocationSize) + long threshold = strict ? 
sizeBytes : 1; + while ((available = availableMemory.get()) >= threshold) { + success = availableMemory.compareAndSet(available, available - sizeBytes); + if (success) + break; + } + + if (success) { + maybeRecordEndOfDrySpell(); + } else { + if (oomTimeSensor != null) { + startOfNoMemPeriod.compareAndSet(0, System.nanoTime()); + } + log.trace("refused to allocate buffer of size {}", sizeBytes); + return null; + } + + ByteBuffer allocated = ByteBuffer.allocate(sizeBytes); + bufferToBeReturned(allocated); + return allocated; + } + + @Override + public void release(ByteBuffer previouslyAllocated) { + if (previouslyAllocated == null) + throw new IllegalArgumentException("provided null buffer"); + + bufferToBeReleased(previouslyAllocated); + availableMemory.addAndGet(previouslyAllocated.capacity()); + maybeRecordEndOfDrySpell(); + } + + @Override + public long size() { + return sizeBytes; + } + + @Override + public long availableMemory() { + return availableMemory.get(); + } + + @Override + public boolean isOutOfMemory() { + return availableMemory.get() <= 0; + } + + //allows subclasses to do their own bookkeeping (and validation) _before_ memory is returned to client code. + protected void bufferToBeReturned(ByteBuffer justAllocated) { + log.trace("allocated buffer of size {} ", justAllocated.capacity()); + } + + //allows subclasses to do their own bookkeeping (and validation) _before_ memory is marked as reclaimed. + protected void bufferToBeReleased(ByteBuffer justReleased) { + log.trace("released buffer of size {}", justReleased.capacity()); + } + + @Override + public String toString() { + long allocated = sizeBytes - availableMemory.get(); + return "SimpleMemoryPool{" + Utils.formatBytes(allocated) + "/" + Utils.formatBytes(sizeBytes) + " used}"; + } + + protected void maybeRecordEndOfDrySpell() { + if (oomTimeSensor != null) { + long startOfDrySpell = startOfNoMemPeriod.getAndSet(0); + if (startOfDrySpell != 0) { + //how long were we refusing allocation requests for + oomTimeSensor.record((System.nanoTime() - startOfDrySpell) / 1000000.0); //fractional (double) millis + } + } + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/metrics/CompoundStat.java b/clients/src/main/java/org/apache/kafka/common/metrics/CompoundStat.java new file mode 100644 index 0000000..f2a7ac6 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/metrics/CompoundStat.java @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.metrics; + +import org.apache.kafka.common.MetricName; + +import java.util.List; + +/** + * A compound stat is a stat where a single measurement and associated data structure feeds many metrics. This is the + * example for a histogram which has many associated percentiles. 
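To make the pool contract concrete, here is a rough usage sketch of SimpleMemoryPool (illustrative only, not part of the patch; the pool and buffer sizes are arbitrary). GarbageCollectedMemoryPool from the same package is a drop-in debugging variant that additionally reports buffers that were garbage-collected without being released:

    import java.nio.ByteBuffer;
    import org.apache.kafka.common.memory.MemoryPool;
    import org.apache.kafka.common.memory.SimpleMemoryPool;

    public class PoolDemo {
        public static void main(String[] args) {
            // 1 MiB pool, single allocations capped at 64 KiB, strict mode, no OOM sensor
            MemoryPool pool = new SimpleMemoryPool(1024 * 1024, 64 * 1024, true, null);

            ByteBuffer buffer = pool.tryAllocate(16 * 1024);
            if (buffer == null) {
                // in strict mode the pool refuses the request when less than 16 KiB is available
                System.out.println("pool exhausted: " + pool);
                return;
            }
            try {
                // ... fill and use the buffer ...
            } finally {
                // every successful tryAllocate() must be paired with release(),
                // otherwise the memory is never marked as reclaimed
                pool.release(buffer);
            }
        }
    }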
+ */ +public interface CompoundStat extends Stat { + + List stats(); + + class NamedMeasurable { + + private final MetricName name; + private final Measurable stat; + + public NamedMeasurable(MetricName name, Measurable stat) { + super(); + this.name = name; + this.stat = stat; + } + + public MetricName name() { + return name; + } + + public Measurable stat() { + return stat; + } + + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/metrics/Gauge.java b/clients/src/main/java/org/apache/kafka/common/metrics/Gauge.java new file mode 100644 index 0000000..647942b --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/metrics/Gauge.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.metrics; + +/** + * A gauge metric is an instantaneous reading of a particular value. + */ +public interface Gauge extends MetricValueProvider { + + /** + * Returns the current value associated with this gauge. + * @param config The configuration for this metric + * @param now The POSIX time in milliseconds the measurement is being taken + */ + T value(MetricConfig config, long now); + +} diff --git a/clients/src/main/java/org/apache/kafka/common/metrics/JmxReporter.java b/clients/src/main/java/org/apache/kafka/common/metrics/JmxReporter.java new file mode 100644 index 0000000..3867091 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/metrics/JmxReporter.java @@ -0,0 +1,356 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
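Since Gauge declares a single method, an instantaneous reading can be supplied as a lambda. A minimal sketch (illustrative only, not part of the patch; the queue is an assumption, and registration normally goes through Metrics.addMetric, which appears later in this patch):

    import java.util.Queue;
    import java.util.concurrent.ConcurrentLinkedQueue;
    import org.apache.kafka.common.metrics.Gauge;

    public class GaugeExample {
        public static void main(String[] args) {
            Queue<String> pending = new ConcurrentLinkedQueue<>();
            pending.add("request-1");

            // the config and timestamp arguments are ignored here; the gauge simply
            // reports the queue depth at the moment it is read
            Gauge<Integer> queueDepth = (config, now) -> pending.size();

            // reading the gauge directly here just for illustration
            System.out.println(queueDepth.value(null, System.currentTimeMillis())); // 1
        }
    }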
+ */ +package org.apache.kafka.common.metrics; + +import org.apache.kafka.common.KafkaException; +import org.apache.kafka.common.MetricName; +import org.apache.kafka.common.config.ConfigException; +import org.apache.kafka.common.utils.ConfigUtils; +import org.apache.kafka.common.utils.Sanitizer; +import org.apache.kafka.common.utils.Utils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.management.Attribute; +import javax.management.AttributeList; +import javax.management.AttributeNotFoundException; +import javax.management.DynamicMBean; +import javax.management.JMException; +import javax.management.MBeanAttributeInfo; +import javax.management.MBeanInfo; +import javax.management.MBeanServer; +import javax.management.MalformedObjectNameException; +import javax.management.ObjectName; +import java.lang.management.ManagementFactory; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Set; +import java.util.function.Predicate; +import java.util.regex.Pattern; +import java.util.regex.PatternSyntaxException; + +/** + * Register metrics in JMX as dynamic mbeans based on the metric names + */ +public class JmxReporter implements MetricsReporter { + + public static final String METRICS_CONFIG_PREFIX = "metrics.jmx."; + + public static final String EXCLUDE_CONFIG = METRICS_CONFIG_PREFIX + "exclude"; + public static final String EXCLUDE_CONFIG_ALIAS = METRICS_CONFIG_PREFIX + "blacklist"; + + public static final String INCLUDE_CONFIG = METRICS_CONFIG_PREFIX + "include"; + public static final String INCLUDE_CONFIG_ALIAS = METRICS_CONFIG_PREFIX + "whitelist"; + + + public static final Set RECONFIGURABLE_CONFIGS = Utils.mkSet(INCLUDE_CONFIG, + INCLUDE_CONFIG_ALIAS, + EXCLUDE_CONFIG, + EXCLUDE_CONFIG_ALIAS); + + public static final String DEFAULT_INCLUDE = ".*"; + public static final String DEFAULT_EXCLUDE = ""; + + private static final Logger log = LoggerFactory.getLogger(JmxReporter.class); + private static final Object LOCK = new Object(); + private String prefix; + private final Map mbeans = new HashMap<>(); + private Predicate mbeanPredicate = s -> true; + + public JmxReporter() { + this(""); + } + + /** + * Create a JMX reporter that prefixes all metrics with the given string. + * @deprecated Since 2.6.0. Use {@link JmxReporter#JmxReporter()} + * Initialize JmxReporter with {@link JmxReporter#contextChange(MetricsContext)} + * Populate prefix by adding _namespace/prefix key value pair to {@link MetricsContext} + */ + @Deprecated + public JmxReporter(String prefix) { + this.prefix = prefix != null ? 
prefix : ""; + } + + @Override + public void configure(Map configs) { + reconfigure(configs); + } + + @Override + public Set reconfigurableConfigs() { + return RECONFIGURABLE_CONFIGS; + } + + @Override + public void validateReconfiguration(Map configs) throws ConfigException { + compilePredicate(configs); + } + + @Override + public void reconfigure(Map configs) { + synchronized (LOCK) { + this.mbeanPredicate = JmxReporter.compilePredicate(configs); + + mbeans.forEach((name, mbean) -> { + if (mbeanPredicate.test(name)) { + reregister(mbean); + } else { + unregister(mbean); + } + }); + } + } + + @Override + public void init(List metrics) { + synchronized (LOCK) { + for (KafkaMetric metric : metrics) + addAttribute(metric); + + mbeans.forEach((name, mbean) -> { + if (mbeanPredicate.test(name)) { + reregister(mbean); + } + }); + } + } + + public boolean containsMbean(String mbeanName) { + return mbeans.containsKey(mbeanName); + } + + @Override + public void metricChange(KafkaMetric metric) { + synchronized (LOCK) { + String mbeanName = addAttribute(metric); + if (mbeanName != null && mbeanPredicate.test(mbeanName)) { + reregister(mbeans.get(mbeanName)); + } + } + } + + @Override + public void metricRemoval(KafkaMetric metric) { + synchronized (LOCK) { + MetricName metricName = metric.metricName(); + String mBeanName = getMBeanName(prefix, metricName); + KafkaMbean mbean = removeAttribute(metric, mBeanName); + if (mbean != null) { + if (mbean.metrics.isEmpty()) { + unregister(mbean); + mbeans.remove(mBeanName); + } else if (mbeanPredicate.test(mBeanName)) + reregister(mbean); + } + } + } + + private KafkaMbean removeAttribute(KafkaMetric metric, String mBeanName) { + MetricName metricName = metric.metricName(); + KafkaMbean mbean = this.mbeans.get(mBeanName); + if (mbean != null) + mbean.removeAttribute(metricName.name()); + return mbean; + } + + private String addAttribute(KafkaMetric metric) { + try { + MetricName metricName = metric.metricName(); + String mBeanName = getMBeanName(prefix, metricName); + if (!this.mbeans.containsKey(mBeanName)) + mbeans.put(mBeanName, new KafkaMbean(mBeanName)); + KafkaMbean mbean = this.mbeans.get(mBeanName); + mbean.setAttribute(metricName.name(), metric); + return mBeanName; + } catch (JMException e) { + throw new KafkaException("Error creating mbean attribute for metricName :" + metric.metricName(), e); + } + } + + /** + * @param metricName + * @return standard JMX MBean name in the following format domainName:type=metricType,key1=val1,key2=val2 + */ + static String getMBeanName(String prefix, MetricName metricName) { + StringBuilder mBeanName = new StringBuilder(); + mBeanName.append(prefix); + mBeanName.append(":type="); + mBeanName.append(metricName.group()); + for (Map.Entry entry : metricName.tags().entrySet()) { + if (entry.getKey().length() <= 0 || entry.getValue().length() <= 0) + continue; + mBeanName.append(","); + mBeanName.append(entry.getKey()); + mBeanName.append("="); + mBeanName.append(Sanitizer.jmxSanitize(entry.getValue())); + } + return mBeanName.toString(); + } + + public void close() { + synchronized (LOCK) { + for (KafkaMbean mbean : this.mbeans.values()) + unregister(mbean); + } + } + + private void unregister(KafkaMbean mbean) { + MBeanServer server = ManagementFactory.getPlatformMBeanServer(); + try { + if (server.isRegistered(mbean.name())) + server.unregisterMBean(mbean.name()); + } catch (JMException e) { + throw new KafkaException("Error unregistering mbean", e); + } + } + + private void reregister(KafkaMbean mbean) { + 
unregister(mbean); + try { + ManagementFactory.getPlatformMBeanServer().registerMBean(mbean, mbean.name()); + } catch (JMException e) { + throw new KafkaException("Error registering mbean " + mbean.name(), e); + } + } + + private static class KafkaMbean implements DynamicMBean { + private final ObjectName objectName; + private final Map metrics; + + KafkaMbean(String mbeanName) throws MalformedObjectNameException { + this.metrics = new HashMap<>(); + this.objectName = new ObjectName(mbeanName); + } + + public ObjectName name() { + return objectName; + } + + void setAttribute(String name, KafkaMetric metric) { + this.metrics.put(name, metric); + } + + @Override + public Object getAttribute(String name) throws AttributeNotFoundException { + if (this.metrics.containsKey(name)) + return this.metrics.get(name).metricValue(); + else + throw new AttributeNotFoundException("Could not find attribute " + name); + } + + @Override + public AttributeList getAttributes(String[] names) { + AttributeList list = new AttributeList(); + for (String name : names) { + try { + list.add(new Attribute(name, getAttribute(name))); + } catch (Exception e) { + log.warn("Error getting JMX attribute '{}'", name, e); + } + } + return list; + } + + KafkaMetric removeAttribute(String name) { + return this.metrics.remove(name); + } + + @Override + public MBeanInfo getMBeanInfo() { + MBeanAttributeInfo[] attrs = new MBeanAttributeInfo[metrics.size()]; + int i = 0; + for (Map.Entry entry : this.metrics.entrySet()) { + String attribute = entry.getKey(); + KafkaMetric metric = entry.getValue(); + attrs[i] = new MBeanAttributeInfo(attribute, + double.class.getName(), + metric.metricName().description(), + true, + false, + false); + i += 1; + } + return new MBeanInfo(this.getClass().getName(), "", attrs, null, null, null); + } + + @Override + public Object invoke(String name, Object[] params, String[] sig) { + throw new UnsupportedOperationException("Set not allowed."); + } + + @Override + public void setAttribute(Attribute attribute) { + throw new UnsupportedOperationException("Set not allowed."); + } + + @Override + public AttributeList setAttributes(AttributeList list) { + throw new UnsupportedOperationException("Set not allowed."); + } + + } + + public static Predicate compilePredicate(Map originalConfig) { + Map configs = ConfigUtils.translateDeprecatedConfigs( + originalConfig, new String[][]{{INCLUDE_CONFIG, INCLUDE_CONFIG_ALIAS}, + {EXCLUDE_CONFIG, EXCLUDE_CONFIG_ALIAS}}); + String include = (String) configs.get(INCLUDE_CONFIG); + String exclude = (String) configs.get(EXCLUDE_CONFIG); + + if (include == null) { + include = DEFAULT_INCLUDE; + } + + if (exclude == null) { + exclude = DEFAULT_EXCLUDE; + } + + try { + Pattern includePattern = Pattern.compile(include); + Pattern excludePattern = Pattern.compile(exclude); + + return s -> includePattern.matcher(s).matches() + && !excludePattern.matcher(s).matches(); + } catch (PatternSyntaxException e) { + throw new ConfigException("JMX filter for configuration" + METRICS_CONFIG_PREFIX + + ".(include/exclude) is not a valid regular expression"); + } + } + + @Override + public void contextChange(MetricsContext metricsContext) { + String namespace = metricsContext.contextLabels().get(MetricsContext.NAMESPACE); + Objects.requireNonNull(namespace); + synchronized (LOCK) { + if (!mbeans.isEmpty()) { + throw new IllegalStateException("JMX MetricsContext can only be updated before JMX metrics are created"); + } + + // prevent prefix from getting reset back to empty for backwards 
compatibility + // with the deprecated JmxReporter(String prefix) constructor, in case contextChange gets called + // via one of the Metrics() constructor with a default empty MetricsContext() + if (namespace.isEmpty()) { + return; + } + + prefix = namespace; + } + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/metrics/KafkaMetric.java b/clients/src/main/java/org/apache/kafka/common/metrics/KafkaMetric.java new file mode 100644 index 0000000..cb29dc2 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/metrics/KafkaMetric.java @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.metrics; + +import org.apache.kafka.common.Metric; +import org.apache.kafka.common.MetricName; +import org.apache.kafka.common.utils.Time; + +public final class KafkaMetric implements Metric { + + private MetricName metricName; + private final Object lock; + private final Time time; + private final MetricValueProvider metricValueProvider; + private MetricConfig config; + + // public for testing + public KafkaMetric(Object lock, MetricName metricName, MetricValueProvider valueProvider, + MetricConfig config, Time time) { + this.metricName = metricName; + this.lock = lock; + if (!(valueProvider instanceof Measurable) && !(valueProvider instanceof Gauge)) + throw new IllegalArgumentException("Unsupported metric value provider of class " + valueProvider.getClass()); + this.metricValueProvider = valueProvider; + this.config = config; + this.time = time; + } + + public MetricConfig config() { + return this.config; + } + + @Override + public MetricName metricName() { + return this.metricName; + } + + @Override + public Object metricValue() { + long now = time.milliseconds(); + synchronized (this.lock) { + if (this.metricValueProvider instanceof Measurable) + return ((Measurable) metricValueProvider).measure(config, now); + else if (this.metricValueProvider instanceof Gauge) + return ((Gauge) metricValueProvider).value(config, now); + else + throw new IllegalStateException("Not a valid metric: " + this.metricValueProvider.getClass()); + } + } + + public Measurable measurable() { + if (this.metricValueProvider instanceof Measurable) + return (Measurable) metricValueProvider; + else + throw new IllegalStateException("Not a measurable: " + this.metricValueProvider.getClass()); + } + + double measurableValue(long timeMs) { + synchronized (this.lock) { + if (this.metricValueProvider instanceof Measurable) + return ((Measurable) metricValueProvider).measure(config, timeMs); + else + return 0; + } + } + + public void config(MetricConfig config) { + synchronized (lock) { + this.config = config; + } + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/metrics/KafkaMetricsContext.java 
b/clients/src/main/java/org/apache/kafka/common/metrics/KafkaMetricsContext.java new file mode 100644 index 0000000..43eb8cb --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/metrics/KafkaMetricsContext.java @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.metrics; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +/** + * A implementation of MetricsContext, it encapsulates required metrics context properties for Kafka services and clients + */ +public class KafkaMetricsContext implements MetricsContext { + /** + * Client or Service's contextLabels map. + */ + private final Map contextLabels = new HashMap<>(); + + /** + * Create a MetricsContext with namespace, no service or client properties + * @param namespace value for _namespace key + */ + public KafkaMetricsContext(String namespace) { + this(namespace, new HashMap<>()); + } + + /** + * Create a MetricsContext with namespace, service or client properties + * @param namespace value for _namespace key + * @param contextLabels contextLabels additional entries to add to the context. + * values will be converted to string using Object.toString() + */ + public KafkaMetricsContext(String namespace, Map contextLabels) { + this.contextLabels.put(MetricsContext.NAMESPACE, namespace); + contextLabels.forEach((key, value) -> this.contextLabels.put(key, value != null ? value.toString() : null)); + } + + @Override + public Map contextLabels() { + return Collections.unmodifiableMap(contextLabels); + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/metrics/Measurable.java b/clients/src/main/java/org/apache/kafka/common/metrics/Measurable.java new file mode 100644 index 0000000..866caba --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/metrics/Measurable.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
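The reporter and context classes above are typically wired together before any metric is registered. A rough sketch (illustrative only, not part of the patch; the namespace, label and filter values are made up):

    import java.util.HashMap;
    import java.util.Map;
    import org.apache.kafka.common.metrics.JmxReporter;
    import org.apache.kafka.common.metrics.KafkaMetricsContext;
    import org.apache.kafka.common.metrics.MetricsContext;

    public class JmxWiring {
        public static void main(String[] args) {
            JmxReporter reporter = new JmxReporter();

            // only register MBeans whose names match the include pattern
            Map<String, Object> conf = new HashMap<>();
            conf.put(JmxReporter.INCLUDE_CONFIG, "kafka\\.example:type=producer-metrics.*");
            reporter.configure(conf);

            // the _namespace label becomes the MBean domain prefix; extra labels ride
            // along as context metadata
            Map<String, Object> labels = new HashMap<>();
            labels.put("service.id", "example-service");
            MetricsContext context = new KafkaMetricsContext("kafka.example", labels);

            // contextChange must be called before any JMX metrics are created
            reporter.contextChange(context);
        }
    }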
+ */ +package org.apache.kafka.common.metrics; + +/** + * A measurable quantity that can be registered as a metric + */ +public interface Measurable extends MetricValueProvider { + + /** + * Measure this quantity and return the result as a double + * @param config The configuration for this metric + * @param now The POSIX time in milliseconds the measurement is being taken + * @return The measured value + */ + double measure(MetricConfig config, long now); + +} diff --git a/clients/src/main/java/org/apache/kafka/common/metrics/MeasurableStat.java b/clients/src/main/java/org/apache/kafka/common/metrics/MeasurableStat.java new file mode 100644 index 0000000..035449e --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/metrics/MeasurableStat.java @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.metrics; + +/** + * A MeasurableStat is a {@link Stat} that is also {@link Measurable} (i.e. can produce a single floating point value). + * This is the interface used for most of the simple statistics such as {@link org.apache.kafka.common.metrics.stats.Avg}, + * {@link org.apache.kafka.common.metrics.stats.Max}, {@link org.apache.kafka.common.metrics.stats.CumulativeCount}, etc. + */ +public interface MeasurableStat extends Stat, Measurable { + +} diff --git a/clients/src/main/java/org/apache/kafka/common/metrics/MetricConfig.java b/clients/src/main/java/org/apache/kafka/common/metrics/MetricConfig.java new file mode 100644 index 0000000..7367e96 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/metrics/MetricConfig.java @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
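Measurable differs from Gauge in that it always yields a double, which is what sensors and the stats in org.apache.kafka.common.metrics.stats build on. A minimal sketch of a lambda-based Measurable (illustrative only, not part of the patch):

    import java.util.concurrent.atomic.AtomicLong;
    import org.apache.kafka.common.metrics.Measurable;

    public class MeasurableExample {
        public static void main(String[] args) {
            AtomicLong bytesSent = new AtomicLong();
            bytesSent.addAndGet(42);

            // a Measurable produces a double; here it exposes a counter value.
            // It can be registered with Metrics.addMetric(...), shown later in this patch.
            Measurable sentBytes = (config, now) -> (double) bytesSent.get();

            System.out.println(sentBytes.measure(null, System.currentTimeMillis())); // 42.0
        }
    }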
+ */ +package org.apache.kafka.common.metrics; + +import java.util.LinkedHashMap; +import java.util.Map; +import java.util.concurrent.TimeUnit; + +/** + * Configuration values for metrics + */ +public class MetricConfig { + + private Quota quota; + private int samples; + private long eventWindow; + private long timeWindowMs; + private Map tags; + private Sensor.RecordingLevel recordingLevel; + + public MetricConfig() { + this.quota = null; + this.samples = 2; + this.eventWindow = Long.MAX_VALUE; + this.timeWindowMs = TimeUnit.MILLISECONDS.convert(30, TimeUnit.SECONDS); + this.tags = new LinkedHashMap<>(); + this.recordingLevel = Sensor.RecordingLevel.INFO; + } + + public Quota quota() { + return this.quota; + } + + public MetricConfig quota(Quota quota) { + this.quota = quota; + return this; + } + + public long eventWindow() { + return eventWindow; + } + + public MetricConfig eventWindow(long window) { + this.eventWindow = window; + return this; + } + + public long timeWindowMs() { + return timeWindowMs; + } + + public MetricConfig timeWindow(long window, TimeUnit unit) { + this.timeWindowMs = TimeUnit.MILLISECONDS.convert(window, unit); + return this; + } + + public Map tags() { + return this.tags; + } + + public MetricConfig tags(Map tags) { + this.tags = tags; + return this; + } + + public int samples() { + return this.samples; + } + + public MetricConfig samples(int samples) { + if (samples < 1) + throw new IllegalArgumentException("The number of samples must be at least 1."); + this.samples = samples; + return this; + } + + public Sensor.RecordingLevel recordLevel() { + return this.recordingLevel; + } + + public MetricConfig recordLevel(Sensor.RecordingLevel recordingLevel) { + this.recordingLevel = recordingLevel; + return this; + } + + +} diff --git a/clients/src/main/java/org/apache/kafka/common/metrics/MetricValueProvider.java b/clients/src/main/java/org/apache/kafka/common/metrics/MetricValueProvider.java new file mode 100644 index 0000000..68028e7 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/metrics/MetricValueProvider.java @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.metrics; + +/** + * Super-interface for {@link Measurable} or {@link Gauge} that provides + * metric values. + *

<p> + * In the future for Java8 and above, {@link Gauge#value(MetricConfig, long)} will be + * moved to this interface with a default implementation in {@link Measurable} that returns + * {@link Measurable#measure(MetricConfig, long)}. + * </p>
        + */ +public interface MetricValueProvider { } diff --git a/clients/src/main/java/org/apache/kafka/common/metrics/Metrics.java b/clients/src/main/java/org/apache/kafka/common/metrics/Metrics.java new file mode 100644 index 0000000..52b7794 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/metrics/Metrics.java @@ -0,0 +1,671 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.metrics; + +import org.apache.kafka.common.MetricName; +import org.apache.kafka.common.MetricNameTemplate; +import org.apache.kafka.common.metrics.internals.MetricsUtils; +import org.apache.kafka.common.utils.KafkaThread; +import org.apache.kafka.common.utils.Time; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.Closeable; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; +import java.util.Objects; +import java.util.Set; +import java.util.TreeMap; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.ScheduledThreadPoolExecutor; +import java.util.concurrent.TimeUnit; + +import static java.util.Collections.emptyList; + +/** + * A registry of sensors and metrics. + *

<p> + * A metric is a named, numerical measurement. A sensor is a handle to record numerical measurements as they occur. Each + * Sensor has zero or more associated metrics. For example a Sensor might represent message sizes and we might associate + * with this sensor a metric for the average, maximum, or other statistics computed off the sequence of message sizes + * that are recorded by the sensor. + * <p> + * Usage looks something like this: + * + * <pre> + * // set up metrics: + * Metrics metrics = new Metrics(); // this is the global repository of metrics and sensors + * Sensor sensor = metrics.sensor("message-sizes"); + * MetricName metricName = new MetricName("message-size-avg", "producer-metrics"); + * sensor.add(metricName, new Avg()); + * metricName = new MetricName("message-size-max", "producer-metrics"); + * sensor.add(metricName, new Max()); + * + * // as messages are sent we record the sizes + * sensor.record(messageSize); + * </pre>
        + */ +public class Metrics implements Closeable { + + private final MetricConfig config; + private final ConcurrentMap metrics; + private final ConcurrentMap sensors; + private final ConcurrentMap> childrenSensors; + private final List reporters; + private final Time time; + private final ScheduledThreadPoolExecutor metricsScheduler; + private static final Logger log = LoggerFactory.getLogger(Metrics.class); + + /** + * Create a metrics repository with no metric reporters and default configuration. + * Expiration of Sensors is disabled. + */ + public Metrics() { + this(new MetricConfig()); + } + + /** + * Create a metrics repository with no metric reporters and default configuration. + * Expiration of Sensors is disabled. + */ + public Metrics(Time time) { + this(new MetricConfig(), new ArrayList<>(0), time); + } + + /** + * Create a metrics repository with no metric reporters and the given default configuration. + * Expiration of Sensors is disabled. + */ + public Metrics(MetricConfig defaultConfig, Time time) { + this(defaultConfig, new ArrayList<>(0), time); + } + + + /** + * Create a metrics repository with no reporters and the given default config. This config will be used for any + * metric that doesn't override its own config. Expiration of Sensors is disabled. + * @param defaultConfig The default config to use for all metrics that don't override their config + */ + public Metrics(MetricConfig defaultConfig) { + this(defaultConfig, new ArrayList<>(0), Time.SYSTEM); + } + + /** + * Create a metrics repository with a default config and the given metric reporters. + * Expiration of Sensors is disabled. + * @param defaultConfig The default config + * @param reporters The metrics reporters + * @param time The time instance to use with the metrics + */ + public Metrics(MetricConfig defaultConfig, List reporters, Time time) { + this(defaultConfig, reporters, time, false); + } + + /** + * Create a metrics repository with a default config, metric reporters and metric context + * Expiration of Sensors is disabled. 
+ * @param defaultConfig The default config + * @param reporters The metrics reporters + * @param time The time instance to use with the metrics + * @param metricsContext The metricsContext to initialize metrics reporter with + */ + public Metrics(MetricConfig defaultConfig, List reporters, Time time, MetricsContext metricsContext) { + this(defaultConfig, reporters, time, false, metricsContext); + } + + /** + * Create a metrics repository with a default config, given metric reporters and the ability to expire eligible sensors + * @param defaultConfig The default config + * @param reporters The metrics reporters + * @param time The time instance to use with the metrics + * @param enableExpiration true if the metrics instance can garbage collect inactive sensors, false otherwise + */ + public Metrics(MetricConfig defaultConfig, List reporters, Time time, boolean enableExpiration) { + this(defaultConfig, reporters, time, enableExpiration, new KafkaMetricsContext("")); + } + + /** + * Create a metrics repository with a default config, given metric reporters, the ability to expire eligible sensors + * and MetricContext + * @param defaultConfig The default config + * @param reporters The metrics reporters + * @param time The time instance to use with the metrics + * @param enableExpiration true if the metrics instance can garbage collect inactive sensors, false otherwise + * @param metricsContext The metricsContext to initialize metrics reporter with + */ + public Metrics(MetricConfig defaultConfig, List reporters, Time time, boolean enableExpiration, + MetricsContext metricsContext) { + this.config = defaultConfig; + this.sensors = new ConcurrentHashMap<>(); + this.metrics = new ConcurrentHashMap<>(); + this.childrenSensors = new ConcurrentHashMap<>(); + this.reporters = Objects.requireNonNull(reporters); + this.time = time; + for (MetricsReporter reporter : reporters) { + reporter.contextChange(metricsContext); + reporter.init(new ArrayList<>()); + } + + // Create the ThreadPoolExecutor only if expiration of Sensors is enabled. + if (enableExpiration) { + this.metricsScheduler = new ScheduledThreadPoolExecutor(1); + // Creating a daemon thread to not block shutdown + this.metricsScheduler.setThreadFactory(runnable -> KafkaThread.daemon("SensorExpiryThread", runnable)); + this.metricsScheduler.scheduleAtFixedRate(new ExpireSensorTask(), 30, 30, TimeUnit.SECONDS); + } else { + this.metricsScheduler = null; + } + + addMetric(metricName("count", "kafka-metrics-count", "total number of registered metrics"), + (config, now) -> metrics.size()); + } + + /** + * Create a MetricName with the given name, group, description and tags, plus default tags specified in the metric + * configuration. Tag in tags takes precedence if the same tag key is specified in the default metric configuration. + * + * @param name The name of the metric + * @param group logical group name of the metrics to which this metric belongs + * @param description A human-readable description to include in the metric + * @param tags additional key/value attributes of the metric + */ + public MetricName metricName(String name, String group, String description, Map tags) { + Map combinedTag = new LinkedHashMap<>(config.tags()); + combinedTag.putAll(tags); + return new MetricName(name, group, description, combinedTag); + } + + /** + * Create a MetricName with the given name, group, description, and default tags + * specified in the metric configuration. 
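Putting the pieces together, a registry with a custom default config, a JMX reporter and a metrics context could be constructed roughly as follows (illustrative only, not part of the patch; names and values are arbitrary):

    import java.util.Collections;
    import java.util.List;
    import java.util.concurrent.TimeUnit;
    import org.apache.kafka.common.metrics.JmxReporter;
    import org.apache.kafka.common.metrics.KafkaMetricsContext;
    import org.apache.kafka.common.metrics.MetricConfig;
    import org.apache.kafka.common.metrics.Metrics;
    import org.apache.kafka.common.metrics.MetricsReporter;
    import org.apache.kafka.common.metrics.Sensor;
    import org.apache.kafka.common.utils.Time;

    public class MetricsSetup {
        public static void main(String[] args) {
            MetricConfig config = new MetricConfig()
                    .samples(2)                        // keep two samples per metric
                    .timeWindow(30, TimeUnit.SECONDS); // 30 second sample window

            List<MetricsReporter> reporters = Collections.singletonList(new JmxReporter());

            try (Metrics metrics = new Metrics(config, reporters, Time.SYSTEM,
                    new KafkaMetricsContext("kafka.example"))) {
                // the registry itself exposes a "count" metric with the number of
                // registered metrics; sensors and further metrics are created from here
                Sensor requestSizes = metrics.sensor("request-sizes");
                requestSizes.record(100.0);
            }
        }
    }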
+ * + * @param name The name of the metric + * @param group logical group name of the metrics to which this metric belongs + * @param description A human-readable description to include in the metric + */ + public MetricName metricName(String name, String group, String description) { + return metricName(name, group, description, new HashMap<>()); + } + + /** + * Create a MetricName with the given name, group and default tags specified in the metric configuration. + * + * @param name The name of the metric + * @param group logical group name of the metrics to which this metric belongs + */ + public MetricName metricName(String name, String group) { + return metricName(name, group, "", new HashMap<>()); + } + + /** + * Create a MetricName with the given name, group, description, and keyValue as tags, plus default tags specified in the metric + * configuration. Tag in keyValue takes precedence if the same tag key is specified in the default metric configuration. + * + * @param name The name of the metric + * @param group logical group name of the metrics to which this metric belongs + * @param description A human-readable description to include in the metric + * @param keyValue additional key/value attributes of the metric (must come in pairs) + */ + public MetricName metricName(String name, String group, String description, String... keyValue) { + return metricName(name, group, description, MetricsUtils.getTags(keyValue)); + } + + /** + * Create a MetricName with the given name, group and tags, plus default tags specified in the metric + * configuration. Tag in tags takes precedence if the same tag key is specified in the default metric configuration. + * + * @param name The name of the metric + * @param group logical group name of the metrics to which this metric belongs + * @param tags key/value attributes of the metric + */ + public MetricName metricName(String name, String group, Map tags) { + return metricName(name, group, "", tags); + } + + /** + * Use the specified domain and metric name templates to generate an HTML table documenting the metrics. A separate table section + * will be generated for each of the MBeans and the associated attributes. The MBean names are lexicographically sorted to + * determine the order of these sections. This order is therefore dependent upon the order of the + * tags in each {@link MetricNameTemplate}. 
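The overloads above merge per-metric tags with the default tags from the registry's MetricConfig, with the per-metric tags taking precedence. A brief sketch (illustrative only, not part of the patch; group and tag values are made up):

    import java.util.LinkedHashMap;
    import java.util.Map;
    import org.apache.kafka.common.MetricName;
    import org.apache.kafka.common.metrics.Metrics;

    public class MetricNames {
        public static void main(String[] args) {
            try (Metrics metrics = new Metrics()) {
                // varargs overload: tags are supplied as key/value pairs
                MetricName perNode = metrics.metricName("request-size-avg", "example-node-metrics",
                        "Average request size", "node-id", "node-1");

                // equivalent overload taking an explicit tag map
                Map<String, String> tags = new LinkedHashMap<>();
                tags.put("node-id", "node-1");
                MetricName samePerNode = metrics.metricName("request-size-avg", "example-node-metrics",
                        "Average request size", tags);

                System.out.println(perNode.equals(samePerNode)); // true: same name, group and tags
            }
        }
    }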
+ * + * @param domain the domain or prefix for the JMX MBean names; may not be null + * @param allMetrics the collection of all {@link MetricNameTemplate} instances each describing one metric; may not be null + * @return the string containing the HTML table; never null + */ + public static String toHtmlTable(String domain, Iterable allMetrics) { + Map> beansAndAttributes = new TreeMap<>(); + + try (Metrics metrics = new Metrics()) { + for (MetricNameTemplate template : allMetrics) { + Map tags = new LinkedHashMap<>(); + for (String s : template.tags()) { + tags.put(s, "{" + s + "}"); + } + + MetricName metricName = metrics.metricName(template.name(), template.group(), template.description(), tags); + String mBeanName = JmxReporter.getMBeanName(domain, metricName); + if (!beansAndAttributes.containsKey(mBeanName)) { + beansAndAttributes.put(mBeanName, new TreeMap<>()); + } + Map attrAndDesc = beansAndAttributes.get(mBeanName); + if (!attrAndDesc.containsKey(template.name())) { + attrAndDesc.put(template.name(), template.description()); + } else { + throw new IllegalArgumentException("mBean '" + mBeanName + "' attribute '" + template.name() + "' is defined twice."); + } + } + } + + StringBuilder b = new StringBuilder(); + b.append("\n"); + + for (Entry> e : beansAndAttributes.entrySet()) { + b.append("\n"); + b.append(""); + b.append("\n"); + + b.append("\n"); + b.append("\n"); + b.append("\n"); + b.append("\n"); + b.append("\n"); + + for (Entry e2 : e.getValue().entrySet()) { + b.append("\n"); + b.append(""); + b.append(""); + b.append(""); + b.append("\n"); + } + + } + b.append("
        "); + b.append(e.getKey()); + b.append("
        Attribute nameDescription
        "); + b.append(e2.getKey()); + b.append(""); + b.append(e2.getValue()); + b.append("
        "); + + return b.toString(); + + } + + public MetricConfig config() { + return config; + } + + /** + * Get the sensor with the given name if it exists + * @param name The name of the sensor + * @return Return the sensor or null if no such sensor exists + */ + public Sensor getSensor(String name) { + return this.sensors.get(Objects.requireNonNull(name)); + } + + /** + * Get or create a sensor with the given unique name and no parent sensors. This uses + * a default recording level of INFO. + * @param name The sensor name + * @return The sensor + */ + public Sensor sensor(String name) { + return this.sensor(name, Sensor.RecordingLevel.INFO); + } + + /** + * Get or create a sensor with the given unique name and no parent sensors and with a given + * recording level. + * @param name The sensor name. + * @param recordingLevel The recording level. + * @return The sensor + */ + public Sensor sensor(String name, Sensor.RecordingLevel recordingLevel) { + return sensor(name, null, recordingLevel, (Sensor[]) null); + } + + + /** + * Get or create a sensor with the given unique name and zero or more parent sensors. All parent sensors will + * receive every value recorded with this sensor. This uses a default recording level of INFO. + * @param name The name of the sensor + * @param parents The parent sensors + * @return The sensor that is created + */ + public Sensor sensor(String name, Sensor... parents) { + return this.sensor(name, Sensor.RecordingLevel.INFO, parents); + } + + /** + * Get or create a sensor with the given unique name and zero or more parent sensors. All parent sensors will + * receive every value recorded with this sensor. + * @param name The name of the sensor. + * @param parents The parent sensors. + * @param recordingLevel The recording level. + * @return The sensor that is created + */ + public Sensor sensor(String name, Sensor.RecordingLevel recordingLevel, Sensor... parents) { + return sensor(name, null, recordingLevel, parents); + } + + /** + * Get or create a sensor with the given unique name and zero or more parent sensors. All parent sensors will + * receive every value recorded with this sensor. This uses a default recording level of INFO. + * @param name The name of the sensor + * @param config A default configuration to use for this sensor for metrics that don't have their own config + * @param parents The parent sensors + * @return The sensor that is created + */ + public synchronized Sensor sensor(String name, MetricConfig config, Sensor... parents) { + return this.sensor(name, config, Sensor.RecordingLevel.INFO, parents); + } + + + /** + * Get or create a sensor with the given unique name and zero or more parent sensors. All parent sensors will + * receive every value recorded with this sensor. + * @param name The name of the sensor + * @param config A default configuration to use for this sensor for metrics that don't have their own config + * @param recordingLevel The recording level. + * @param parents The parent sensors + * @return The sensor that is created + */ + public synchronized Sensor sensor(String name, MetricConfig config, Sensor.RecordingLevel recordingLevel, Sensor... parents) { + return sensor(name, config, Long.MAX_VALUE, recordingLevel, parents); + } + + /** + * Get or create a sensor with the given unique name and zero or more parent sensors. All parent sensors will + * receive every value recorded with this sensor. 
+ * @param name The name of the sensor + * @param config A default configuration to use for this sensor for metrics that don't have their own config + * @param inactiveSensorExpirationTimeSeconds If no value if recorded on the Sensor for this duration of time, + * it is eligible for removal + * @param parents The parent sensors + * @param recordingLevel The recording level. + * @return The sensor that is created + */ + public synchronized Sensor sensor(String name, MetricConfig config, long inactiveSensorExpirationTimeSeconds, Sensor.RecordingLevel recordingLevel, Sensor... parents) { + Sensor s = getSensor(name); + if (s == null) { + s = new Sensor(this, name, parents, config == null ? this.config : config, time, inactiveSensorExpirationTimeSeconds, recordingLevel); + this.sensors.put(name, s); + if (parents != null) { + for (Sensor parent : parents) { + List children = childrenSensors.computeIfAbsent(parent, k -> new ArrayList<>()); + children.add(s); + } + } + log.trace("Added sensor with name {}", name); + } + return s; + } + + /** + * Get or create a sensor with the given unique name and zero or more parent sensors. All parent sensors will + * receive every value recorded with this sensor. This uses a default recording level of INFO. + * @param name The name of the sensor + * @param config A default configuration to use for this sensor for metrics that don't have their own config + * @param inactiveSensorExpirationTimeSeconds If no value if recorded on the Sensor for this duration of time, + * it is eligible for removal + * @param parents The parent sensors + * @return The sensor that is created + */ + public synchronized Sensor sensor(String name, MetricConfig config, long inactiveSensorExpirationTimeSeconds, Sensor... parents) { + return this.sensor(name, config, inactiveSensorExpirationTimeSeconds, Sensor.RecordingLevel.INFO, parents); + } + + /** + * Remove a sensor (if it exists), associated metrics and its children. + * + * @param name The name of the sensor to be removed + */ + public void removeSensor(String name) { + Sensor sensor = sensors.get(name); + if (sensor != null) { + List childSensors = null; + synchronized (sensor) { + synchronized (this) { + if (sensors.remove(name, sensor)) { + for (KafkaMetric metric : sensor.metrics()) + removeMetric(metric.metricName()); + log.trace("Removed sensor with name {}", name); + childSensors = childrenSensors.remove(sensor); + for (final Sensor parent : sensor.parents()) { + childrenSensors.getOrDefault(parent, emptyList()).remove(sensor); + } + } + } + } + if (childSensors != null) { + for (Sensor childSensor : childSensors) + removeSensor(childSensor.name()); + } + } + } + + /** + * Add a metric to monitor an object that implements measurable. This metric won't be associated with any sensor. + * This is a way to expose existing values as metrics. + * + * This method is kept for binary compatibility purposes, it has the same behaviour as + * {@link #addMetric(MetricName, MetricValueProvider)}. + * + * @param metricName The name of the metric + * @param measurable The measurable that will be measured by this metric + */ + public void addMetric(MetricName metricName, Measurable measurable) { + addMetric(metricName, null, measurable); + } + + /** + * Add a metric to monitor an object that implements Measurable. This metric won't be associated with any sensor. + * This is a way to expose existing values as metrics. 
+ * + * This method is kept for binary compatibility purposes, it has the same behaviour as + * {@link #addMetric(MetricName, MetricConfig, MetricValueProvider)}. + * + * @param metricName The name of the metric + * @param config The configuration to use when measuring this measurable + * @param measurable The measurable that will be measured by this metric + */ + public void addMetric(MetricName metricName, MetricConfig config, Measurable measurable) { + addMetric(metricName, config, (MetricValueProvider) measurable); + } + + /** + * Add a metric to monitor an object that implements MetricValueProvider. This metric won't be associated with any + * sensor. This is a way to expose existing values as metrics. User is expected to add any additional + * synchronization to update and access metric values, if required. + * + * @param metricName The name of the metric + * @param metricValueProvider The metric value provider associated with this metric + */ + public void addMetric(MetricName metricName, MetricConfig config, MetricValueProvider metricValueProvider) { + KafkaMetric m = new KafkaMetric(new Object(), + Objects.requireNonNull(metricName), + Objects.requireNonNull(metricValueProvider), + config == null ? this.config : config, + time); + registerMetric(m); + } + + /** + * Add a metric to monitor an object that implements MetricValueProvider. This metric won't be associated with any + * sensor. This is a way to expose existing values as metrics. User is expected to add any additional + * synchronization to update and access metric values, if required. + * + * @param metricName The name of the metric + * @param metricValueProvider The metric value provider associated with this metric + */ + public void addMetric(MetricName metricName, MetricValueProvider metricValueProvider) { + addMetric(metricName, null, metricValueProvider); + } + + /** + * Remove a metric if it exists and return it. Return null otherwise. If a metric is removed, `metricRemoval` + * will be invoked for each reporter. 
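+ * <p>
+ * A sketch of the add/remove round trip (the metric name, group and {@code queue} variable are
+ * invented for illustration):
+ * <pre>
+ *   MetricName name = metrics.metricName("queue-size", "example-group", "Current queue size");
+ *   Measurable queueSize = (config, now) -> queue.size();
+ *   metrics.addMetric(name, queueSize);
+ *   // ... later ...
+ *   KafkaMetric removed = metrics.removeMetric(name);   // null if no such metric was registered
+ * </pre>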
+ * + * @param metricName The name of the metric + * @return the removed `KafkaMetric` or null if no such metric exists + */ + public synchronized KafkaMetric removeMetric(MetricName metricName) { + KafkaMetric metric = this.metrics.remove(metricName); + if (metric != null) { + for (MetricsReporter reporter : reporters) { + try { + reporter.metricRemoval(metric); + } catch (Exception e) { + log.error("Error when removing metric from " + reporter.getClass().getName(), e); + } + } + log.trace("Removed metric named {}", metricName); + } + return metric; + } + + /** + * Add a MetricReporter + */ + public synchronized void addReporter(MetricsReporter reporter) { + Objects.requireNonNull(reporter).init(new ArrayList<>(metrics.values())); + this.reporters.add(reporter); + } + + /** + * Remove a MetricReporter + */ + public synchronized void removeReporter(MetricsReporter reporter) { + if (this.reporters.remove(reporter)) { + reporter.close(); + } + } + + synchronized void registerMetric(KafkaMetric metric) { + MetricName metricName = metric.metricName(); + if (this.metrics.containsKey(metricName)) + throw new IllegalArgumentException("A metric named '" + metricName + "' already exists, can't register another one."); + this.metrics.put(metricName, metric); + for (MetricsReporter reporter : reporters) { + try { + reporter.metricChange(metric); + } catch (Exception e) { + log.error("Error when registering metric on " + reporter.getClass().getName(), e); + } + } + log.trace("Registered metric named {}", metricName); + } + + /** + * Get all the metrics currently maintained indexed by metricName + */ + public Map metrics() { + return this.metrics; + } + + public List reporters() { + return this.reporters; + } + + public KafkaMetric metric(MetricName metricName) { + return this.metrics.get(metricName); + } + + /** + * This iterates over every Sensor and triggers a removeSensor if it has expired + * Package private for testing + */ + class ExpireSensorTask implements Runnable { + @Override + public void run() { + for (Map.Entry sensorEntry : sensors.entrySet()) { + // removeSensor also locks the sensor object. This is fine because synchronized is reentrant + // There is however a minor race condition here. Assume we have a parent sensor P and child sensor C. + // Calling record on C would cause a record on P as well. + // So expiration time for P == expiration time for C. If the record on P happens via C just after P is removed, + // that will cause C to also get removed. + // Since the expiration time is typically high it is not expected to be a significant concern + // and thus not necessary to optimize + synchronized (sensorEntry.getValue()) { + if (sensorEntry.getValue().hasExpired()) { + log.debug("Removing expired sensor {}", sensorEntry.getKey()); + removeSensor(sensorEntry.getKey()); + } + } + } + } + } + + /* For testing use only. */ + Map> childrenSensors() { + return Collections.unmodifiableMap(childrenSensors); + } + + public MetricName metricInstance(MetricNameTemplate template, String... keyValue) { + return metricInstance(template, MetricsUtils.getTags(keyValue)); + } + + public MetricName metricInstance(MetricNameTemplate template, Map tags) { + // check to make sure that the runtime defined tags contain all the template tags. 
+ Set runtimeTagKeys = new HashSet<>(tags.keySet()); + runtimeTagKeys.addAll(config().tags().keySet()); + + Set templateTagKeys = template.tags(); + + if (!runtimeTagKeys.equals(templateTagKeys)) { + throw new IllegalArgumentException("For '" + template.name() + "', runtime-defined metric tags do not match the tags in the template. " + + "Runtime = " + runtimeTagKeys.toString() + " Template = " + templateTagKeys.toString()); + } + + return this.metricName(template.name(), template.group(), template.description(), tags); + } + + /** + * Close this metrics repository. + */ + @Override + public void close() { + if (this.metricsScheduler != null) { + this.metricsScheduler.shutdown(); + try { + this.metricsScheduler.awaitTermination(30, TimeUnit.SECONDS); + } catch (InterruptedException ex) { + // ignore and continue shutdown + Thread.currentThread().interrupt(); + } + } + log.info("Metrics scheduler closed"); + + for (MetricsReporter reporter : reporters) { + try { + log.info("Closing reporter {}", reporter.getClass().getName()); + reporter.close(); + } catch (Exception e) { + log.error("Error when closing " + reporter.getClass().getName(), e); + } + } + log.info("Metrics reporters closed"); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/metrics/MetricsContext.java b/clients/src/main/java/org/apache/kafka/common/metrics/MetricsContext.java new file mode 100644 index 0000000..080aae6 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/metrics/MetricsContext.java @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.metrics; + +import org.apache.kafka.common.annotation.InterfaceStability; + +import java.util.Map; + +/** + * MetricsContext encapsulates additional contextLabels about metrics exposed via a + * {@link org.apache.kafka.common.metrics.MetricsReporter} + * + *

+ * <p>
+ * The {@link #contextLabels()} map provides the following information:
+ * <dl>
+ *   <dt>in all components</dt>
+ *   <dd>a {@code _namespace} field indicating the component exposing metrics
+ *   e.g. kafka.server, kafka.consumer.
+ *   The {@link JmxReporter} uses this as prefix for MBean names</dd>
+ *
+ *   <dt>for clients and streams libraries</dt>
+ *   <dd>any freeform fields passed in via
+ *   client properties in the form of {@code metrics.context.<key>=<value>}</dd>
+ *
+ *   <dt>for kafka brokers</dt>
+ *   <dd>kafka.broker.id, kafka.cluster.id</dd>
+ *
+ *   <dt>for connect workers</dt>
+ *   <dd>connect.kafka.cluster.id, connect.group.id</dd>
+ * </dl>
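+ * <p>
+ * A reporter can read these labels in {@link MetricsReporter#contextChange(MetricsContext)}; a
+ * minimal sketch (the prefix handling is illustrative only):
+ * <pre>
+ *   public void contextChange(MetricsContext metricsContext) {
+ *       String namespace = metricsContext.contextLabels().get(MetricsContext.NAMESPACE);
+ *       this.prefix = namespace == null ? "" : namespace;
+ *   }
+ * </pre>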
        + */ +@InterfaceStability.Evolving +public interface MetricsContext { + /* predefined fields */ + String NAMESPACE = "_namespace"; // metrics namespace, formerly jmx prefix + + /** + * Returns the labels for this metrics context. + * + * @return the map of label keys and values; never null but possibly empty + */ + Map contextLabels(); +} diff --git a/clients/src/main/java/org/apache/kafka/common/metrics/MetricsReporter.java b/clients/src/main/java/org/apache/kafka/common/metrics/MetricsReporter.java new file mode 100644 index 0000000..75771fb --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/metrics/MetricsReporter.java @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.metrics; + +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import org.apache.kafka.common.Reconfigurable; +import org.apache.kafka.common.annotation.InterfaceStability; +import org.apache.kafka.common.config.ConfigException; + +/** + * A plugin interface to allow things to listen as new metrics are created so they can be reported. + *
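+ * <p>
+ * A minimal reporter sketch (the class name and output are invented for illustration); the
+ * reconfiguration-related methods have defaults, so only the methods shown below are required:
+ * <pre>
+ *   public class LoggingReporter implements MetricsReporter {
+ *       public void configure(Map<String, ?> configs) { }
+ *       public void init(List<KafkaMetric> metrics) { metrics.forEach(this::metricChange); }
+ *       public void metricChange(KafkaMetric metric) { System.out.println("updated: " + metric.metricName()); }
+ *       public void metricRemoval(KafkaMetric metric) { System.out.println("removed: " + metric.metricName()); }
+ *       public void close() { }
+ *   }
+ * </pre>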

        + * Implement {@link org.apache.kafka.common.ClusterResourceListener} to receive cluster metadata once it's available. Please see the class documentation for ClusterResourceListener for more information. + */ +public interface MetricsReporter extends Reconfigurable, AutoCloseable { + + /** + * This is called when the reporter is first registered to initially register all existing metrics + * @param metrics All currently existing metrics + */ + void init(List metrics); + + /** + * This is called whenever a metric is updated or added + * @param metric + */ + void metricChange(KafkaMetric metric); + + /** + * This is called whenever a metric is removed + * @param metric + */ + void metricRemoval(KafkaMetric metric); + + /** + * Called when the metrics repository is closed. + */ + void close(); + + // default methods for backwards compatibility with reporters that only implement Configurable + default Set reconfigurableConfigs() { + return Collections.emptySet(); + } + + default void validateReconfiguration(Map configs) throws ConfigException { + } + + default void reconfigure(Map configs) { + } + + /** + * Sets the context labels for the service or library exposing metrics. This will be called before {@link #init(List)} and may be called anytime after that. + * + * @param metricsContext the metric context + */ + @InterfaceStability.Evolving + default void contextChange(MetricsContext metricsContext) { + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/metrics/Quota.java b/clients/src/main/java/org/apache/kafka/common/metrics/Quota.java new file mode 100644 index 0000000..e414133 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/metrics/Quota.java @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.metrics; + +/** + * An upper or lower bound for metrics + */ +public final class Quota { + + private final boolean upper; + private final double bound; + + public Quota(double bound, boolean upper) { + this.bound = bound; + this.upper = upper; + } + + public static Quota upperBound(double upperBound) { + return new Quota(upperBound, true); + } + + public static Quota lowerBound(double lowerBound) { + return new Quota(lowerBound, false); + } + + public boolean isUpperBound() { + return this.upper; + } + + public double bound() { + return this.bound; + } + + public boolean acceptable(double value) { + return (upper && value <= bound) || (!upper && value >= bound); + } + + @Override + public int hashCode() { + final int prime = 31; + int result = 1; + result = prime * result + (int) this.bound; + result = prime * result + (this.upper ? 
1 : 0); + return result; + } + + @Override + public boolean equals(Object obj) { + if (this == obj) + return true; + if (!(obj instanceof Quota)) + return false; + Quota that = (Quota) obj; + return (that.bound == this.bound) && (that.upper == this.upper); + } + + @Override + public String toString() { + return (upper ? "upper=" : "lower=") + bound; + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/metrics/QuotaViolationException.java b/clients/src/main/java/org/apache/kafka/common/metrics/QuotaViolationException.java new file mode 100644 index 0000000..7068d31 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/metrics/QuotaViolationException.java @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.metrics; + +import org.apache.kafka.common.KafkaException; + +/** + * Thrown when a sensor records a value that causes a metric to go outside the bounds configured as its quota + */ +public class QuotaViolationException extends KafkaException { + + private static final long serialVersionUID = 1L; + private final KafkaMetric metric; + private final double value; + private final double bound; + + public QuotaViolationException(KafkaMetric metric, double value, double bound) { + this.metric = metric; + this.value = value; + this.bound = bound; + } + + public KafkaMetric metric() { + return metric; + } + + public double value() { + return value; + } + + public double bound() { + return bound; + } + + @Override + public String toString() { + return getClass().getName() + + ": '" + + metric.metricName() + + "' violated quota. Actual: " + + value + + ", Threshold: " + + bound; + } + + /* avoid the expensive and stack trace for quota violation exceptions */ + @Override + public Throwable fillInStackTrace() { + return this; + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/metrics/Sensor.java b/clients/src/main/java/org/apache/kafka/common/metrics/Sensor.java new file mode 100644 index 0000000..5ae3b8d --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/metrics/Sensor.java @@ -0,0 +1,390 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.metrics; + +import java.util.function.Supplier; +import org.apache.kafka.common.MetricName; +import org.apache.kafka.common.metrics.CompoundStat.NamedMeasurable; +import org.apache.kafka.common.metrics.stats.TokenBucket; +import org.apache.kafka.common.utils.Time; + +import java.util.ArrayList; +import java.util.HashSet; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Objects; +import java.util.Set; +import java.util.concurrent.TimeUnit; + +import static java.util.Arrays.asList; +import static java.util.Collections.unmodifiableList; + +/** + * A sensor applies a continuous sequence of numerical values to a set of associated metrics. For example a sensor on + * message size would record a sequence of message sizes using the {@link #record(double)} api and would maintain a set + * of metrics about request sizes such as the average or max. + */ +public final class Sensor { + + private final Metrics registry; + private final String name; + private final Sensor[] parents; + private final List stats; + private final Map metrics; + private final MetricConfig config; + private final Time time; + private volatile long lastRecordTime; + private final long inactiveSensorExpirationTimeMs; + private final Object metricLock; + + private static class StatAndConfig { + private final Stat stat; + private final Supplier configSupplier; + + StatAndConfig(Stat stat, Supplier configSupplier) { + this.stat = stat; + this.configSupplier = configSupplier; + } + + public Stat stat() { + return stat; + } + + public MetricConfig config() { + return configSupplier.get(); + } + + @Override + public String toString() { + return "StatAndConfig(stat=" + stat + ')'; + } + } + + public enum RecordingLevel { + INFO(0, "INFO"), DEBUG(1, "DEBUG"), TRACE(2, "TRACE"); + + private static final RecordingLevel[] ID_TO_TYPE; + private static final int MIN_RECORDING_LEVEL_KEY = 0; + public static final int MAX_RECORDING_LEVEL_KEY; + + static { + int maxRL = -1; + for (RecordingLevel level : RecordingLevel.values()) { + maxRL = Math.max(maxRL, level.id); + } + RecordingLevel[] idToName = new RecordingLevel[maxRL + 1]; + for (RecordingLevel level : RecordingLevel.values()) { + idToName[level.id] = level; + } + ID_TO_TYPE = idToName; + MAX_RECORDING_LEVEL_KEY = maxRL; + } + + /** an english description of the api--this is for debugging and can change */ + public final String name; + + /** the permanent and immutable id of an API--this can't change ever */ + public final short id; + + RecordingLevel(int id, String name) { + this.id = (short) id; + this.name = name; + } + + public static RecordingLevel forId(int id) { + if (id < MIN_RECORDING_LEVEL_KEY || id > MAX_RECORDING_LEVEL_KEY) + throw new IllegalArgumentException(String.format("Unexpected RecordLevel id `%d`, it should be between `%d` " + + "and `%d` (inclusive)", id, MIN_RECORDING_LEVEL_KEY, MAX_RECORDING_LEVEL_KEY)); + return ID_TO_TYPE[id]; + } + + /** Case insensitive lookup by protocol name */ + public static RecordingLevel forName(String name) { + 
return RecordingLevel.valueOf(name.toUpperCase(Locale.ROOT)); + } + + public boolean shouldRecord(final int configId) { + if (configId == INFO.id) { + return this.id == INFO.id; + } else if (configId == DEBUG.id) { + return this.id == INFO.id || this.id == DEBUG.id; + } else if (configId == TRACE.id) { + return true; + } else { + throw new IllegalStateException("Did not recognize recording level " + configId); + } + } + } + + private final RecordingLevel recordingLevel; + + Sensor(Metrics registry, String name, Sensor[] parents, MetricConfig config, Time time, + long inactiveSensorExpirationTimeSeconds, RecordingLevel recordingLevel) { + super(); + this.registry = registry; + this.name = Objects.requireNonNull(name); + this.parents = parents == null ? new Sensor[0] : parents; + this.metrics = new LinkedHashMap<>(); + this.stats = new ArrayList<>(); + this.config = config; + this.time = time; + this.inactiveSensorExpirationTimeMs = TimeUnit.MILLISECONDS.convert(inactiveSensorExpirationTimeSeconds, TimeUnit.SECONDS); + this.lastRecordTime = time.milliseconds(); + this.recordingLevel = recordingLevel; + this.metricLock = new Object(); + checkForest(new HashSet<>()); + } + + /* Validate that this sensor doesn't end up referencing itself */ + private void checkForest(Set sensors) { + if (!sensors.add(this)) + throw new IllegalArgumentException("Circular dependency in sensors: " + name() + " is its own parent."); + for (Sensor parent : parents) + parent.checkForest(sensors); + } + + /** + * The name this sensor is registered with. This name will be unique among all registered sensors. + */ + public String name() { + return this.name; + } + + List parents() { + return unmodifiableList(asList(parents)); + } + + /** + * @return true if the sensor's record level indicates that the metric will be recorded, false otherwise + */ + public boolean shouldRecord() { + return this.recordingLevel.shouldRecord(config.recordLevel().id); + } + + /** + * Record an occurrence, this is just short-hand for {@link #record(double) record(1.0)} + */ + public void record() { + if (shouldRecord()) { + recordInternal(1.0d, time.milliseconds(), true); + } + } + + /** + * Record a value with this sensor + * @param value The value to record + * @throws QuotaViolationException if recording this value moves a metric beyond its configured maximum or minimum + * bound + */ + public void record(double value) { + if (shouldRecord()) { + recordInternal(value, time.milliseconds(), true); + } + } + + /** + * Record a value at a known time. This method is slightly faster than {@link #record(double)} since it will reuse + * the time stamp. + * @param value The value we are recording + * @param timeMs The current POSIX time in milliseconds + * @throws QuotaViolationException if recording this value moves a metric beyond its configured maximum or minimum + * bound + */ + public void record(double value, long timeMs) { + if (shouldRecord()) { + recordInternal(value, timeMs, true); + } + } + + /** + * Record a value at a known time. This method is slightly faster than {@link #record(double)} since it will reuse + * the time stamp. 
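+ * <p>
+ * For example (a sketch; the sensors and the recorded values are assumed to exist elsewhere),
+ * a single timestamp can be reused across several recordings:
+ * <pre>
+ *   long now = Time.SYSTEM.milliseconds();
+ *   requestSizeSensor.record(sizeInBytes, now);
+ *   throttleTimeSensor.record(throttleMs, now, false);   // record without enforcing quotas
+ * </pre>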
+ * @param value The value we are recording + * @param timeMs The current POSIX time in milliseconds + * @param checkQuotas Indicate if quota must be enforced or not + * @throws QuotaViolationException if recording this value moves a metric beyond its configured maximum or minimum + * bound + */ + public void record(double value, long timeMs, boolean checkQuotas) { + if (shouldRecord()) { + recordInternal(value, timeMs, checkQuotas); + } + } + + private void recordInternal(double value, long timeMs, boolean checkQuotas) { + this.lastRecordTime = timeMs; + synchronized (this) { + synchronized (metricLock()) { + // increment all the stats + for (StatAndConfig statAndConfig : this.stats) { + statAndConfig.stat.record(statAndConfig.config(), value, timeMs); + } + } + if (checkQuotas) + checkQuotas(timeMs); + } + for (Sensor parent : parents) + parent.record(value, timeMs, checkQuotas); + } + + /** + * Check if we have violated our quota for any metric that has a configured quota + */ + public void checkQuotas() { + checkQuotas(time.milliseconds()); + } + + public void checkQuotas(long timeMs) { + for (KafkaMetric metric : this.metrics.values()) { + MetricConfig config = metric.config(); + if (config != null) { + Quota quota = config.quota(); + if (quota != null) { + double value = metric.measurableValue(timeMs); + if (metric.measurable() instanceof TokenBucket) { + if (value < 0) { + throw new QuotaViolationException(metric, value, quota.bound()); + } + } else { + if (!quota.acceptable(value)) { + throw new QuotaViolationException(metric, value, quota.bound()); + } + } + } + } + } + } + + /** + * Register a compound statistic with this sensor with no config override + * @param stat The stat to register + * @return true if stat is added to sensor, false if sensor is expired + */ + public boolean add(CompoundStat stat) { + return add(stat, null); + } + + /** + * Register a compound statistic with this sensor which yields multiple measurable quantities (like a histogram) + * @param stat The stat to register + * @param config The configuration for this stat. If null then the stat will use the default configuration for this + * sensor. + * @return true if stat is added to sensor, false if sensor is expired + */ + public synchronized boolean add(CompoundStat stat, MetricConfig config) { + if (hasExpired()) + return false; + + final MetricConfig statConfig = config == null ? this.config : config; + stats.add(new StatAndConfig(Objects.requireNonNull(stat), () -> statConfig)); + Object lock = metricLock(); + for (NamedMeasurable m : stat.stats()) { + final KafkaMetric metric = new KafkaMetric(lock, m.name(), m.stat(), statConfig, time); + if (!metrics.containsKey(metric.metricName())) { + registry.registerMetric(metric); + metrics.put(metric.metricName(), metric); + } + } + return true; + } + + /** + * Register a metric with this sensor + * @param metricName The name of the metric + * @param stat The statistic to keep + * @return true if metric is added to sensor, false if sensor is expired + */ + public boolean add(MetricName metricName, MeasurableStat stat) { + return add(metricName, stat, null); + } + + /** + * Register a metric with this sensor + * + * @param metricName The name of the metric + * @param stat The statistic to keep + * @param config A special configuration for this metric. If null use the sensor default configuration. 
+ * @return true if metric is added to sensor, false if sensor is expired + */ + public synchronized boolean add(final MetricName metricName, final MeasurableStat stat, final MetricConfig config) { + if (hasExpired()) { + return false; + } else if (metrics.containsKey(metricName)) { + return true; + } else { + final MetricConfig statConfig = config == null ? this.config : config; + final KafkaMetric metric = new KafkaMetric( + metricLock(), + Objects.requireNonNull(metricName), + Objects.requireNonNull(stat), + statConfig, + time + ); + registry.registerMetric(metric); + metrics.put(metric.metricName(), metric); + stats.add(new StatAndConfig(Objects.requireNonNull(stat), metric::config)); + return true; + } + } + + /** + * Return if metrics were registered with this sensor. + * + * @return true if metrics were registered, false otherwise + */ + public synchronized boolean hasMetrics() { + return !metrics.isEmpty(); + } + + /** + * Return true if the Sensor is eligible for removal due to inactivity. + * false otherwise + */ + public boolean hasExpired() { + return (time.milliseconds() - this.lastRecordTime) > this.inactiveSensorExpirationTimeMs; + } + + synchronized List metrics() { + return unmodifiableList(new ArrayList<>(this.metrics.values())); + } + + /** + * KafkaMetrics of sensors which use SampledStat should be synchronized on the same lock + * for sensor record and metric value read to allow concurrent reads and updates. For simplicity, + * all sensors are synchronized on this object. + *

+ * Sensor object is not used as a lock for reading metric value since metrics reporter is
+ * invoked while holding Sensor and Metrics locks to report addition and removal of metrics
+ * and synchronized reporters may deadlock if Sensor lock is used for reading metrics values.
+ * Note that Sensor object itself is used as a lock to protect the access to stats and metrics
+ * while recording metric values, adding and deleting sensors.
+ * <p>
+ * Locking order (assume all MetricsReporter methods may be synchronized):
+ * <ul>
+ *   <li>Sensor#add: Sensor -> Metrics -> MetricsReporter</li>
+ *   <li>Metrics#removeSensor: Sensor -> Metrics -> MetricsReporter</li>
+ *   <li>KafkaMetric#metricValue: MetricsReporter -> Sensor#metricLock</li>
+ *   <li>Sensor#record: Sensor -> Sensor#metricLock</li>
+ * </ul>
        + */ + private Object metricLock() { + return metricLock; + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/metrics/Stat.java b/clients/src/main/java/org/apache/kafka/common/metrics/Stat.java new file mode 100644 index 0000000..fa5aa1a --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/metrics/Stat.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.metrics; + +/** + * A Stat is a quantity such as average, max, etc that is computed off the stream of updates to a sensor + */ +public interface Stat { + + /** + * Record the given value + * @param config The configuration to use for this metric + * @param value The value to record + * @param timeMs The POSIX time in milliseconds this value occurred + */ + void record(MetricConfig config, double value, long timeMs); + +} diff --git a/clients/src/main/java/org/apache/kafka/common/metrics/internals/IntGaugeSuite.java b/clients/src/main/java/org/apache/kafka/common/metrics/internals/IntGaugeSuite.java new file mode 100644 index 0000000..cd52759 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/metrics/internals/IntGaugeSuite.java @@ -0,0 +1,295 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.metrics.internals; + +import org.apache.kafka.common.MetricName; +import org.apache.kafka.common.metrics.Gauge; +import org.apache.kafka.common.metrics.MetricConfig; +import org.apache.kafka.common.metrics.MetricValueProvider; +import org.apache.kafka.common.metrics.Metrics; +import org.slf4j.Logger; + +import java.util.HashMap; +import java.util.HashSet; +import java.util.Iterator; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.ConcurrentLinkedDeque; +import java.util.concurrent.locks.Lock; +import java.util.concurrent.locks.ReentrantLock; +import java.util.function.Function; + +/** + * Manages a suite of integer Gauges. 
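+ * <p>
+ * A usage sketch (the suite name, metric group and tag are invented for illustration): each key
+ * maps to one gauge whose value is a reference count driven by increment/decrement.
+ * <pre>
+ *   IntGaugeSuite<String> connections = new IntGaugeSuite<>(log, "connections", metrics,
+ *       listener -> metrics.metricName("connection-count", "example-group",
+ *           Collections.singletonMap("listener", listener)),
+ *       100);
+ *   connections.increment("PLAINTEXT");
+ *   connections.decrement("PLAINTEXT");
+ * </pre>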
+ */ +public final class IntGaugeSuite implements AutoCloseable { + /** + * The log4j logger. + */ + private final Logger log; + + /** + * The name of this suite. + */ + private final String suiteName; + + /** + * The metrics object to use. + */ + private final Metrics metrics; + + /** + * A user-supplied callback which translates keys into unique metric names. + */ + private final Function metricNameCalculator; + + /** + * The maximum number of gauges that we will ever create at once. + */ + private final int maxEntries; + + /** + * A map from keys to gauges. Protected by the object monitor. + */ + private final Map gauges; + + /** + * The keys of gauges that can be removed, since their value is zero. + * Protected by the object monitor. + */ + private final Set removable; + + /** + * A lockless list of pending metrics additions and removals. + */ + private final ConcurrentLinkedDeque pending; + + /** + * A lock which serializes modifications to metrics. This lock is not + * required to create a new pending operation. + */ + private final Lock modifyMetricsLock; + + /** + * True if this suite is closed. Protected by the object monitor. + */ + private boolean closed; + + /** + * A pending metrics addition or removal. + */ + private static class PendingMetricsChange { + /** + * The name of the metric to add or remove. + */ + private final MetricName metricName; + + /** + * In an addition, this field is the MetricValueProvider to add. + * In a removal, this field is null. + */ + private final MetricValueProvider provider; + + PendingMetricsChange(MetricName metricName, MetricValueProvider provider) { + this.metricName = metricName; + this.provider = provider; + } + } + + /** + * The gauge object which we register with the metrics system. + */ + private static class StoredIntGauge implements Gauge { + private final MetricName metricName; + private int value; + + StoredIntGauge(MetricName metricName) { + this.metricName = metricName; + this.value = 1; + } + + /** + * This callback is invoked when the metrics system retrieves the value of this gauge. + */ + @Override + public synchronized Integer value(MetricConfig config, long now) { + return value; + } + + synchronized int increment() { + return ++value; + } + + synchronized int decrement() { + return --value; + } + + synchronized int value() { + return value; + } + } + + public IntGaugeSuite(Logger log, + String suiteName, + Metrics metrics, + Function metricNameCalculator, + int maxEntries) { + this.log = log; + this.suiteName = suiteName; + this.metrics = metrics; + this.metricNameCalculator = metricNameCalculator; + this.maxEntries = maxEntries; + this.gauges = new HashMap<>(1); + this.removable = new HashSet<>(); + this.pending = new ConcurrentLinkedDeque<>(); + this.modifyMetricsLock = new ReentrantLock(); + this.closed = false; + log.trace("{}: created new gauge suite with maxEntries = {}.", + suiteName, maxEntries); + } + + public void increment(K key) { + synchronized (this) { + if (closed) { + log.warn("{}: Attempted to increment {}, but the GaugeSuite was closed.", + suiteName, key.toString()); + return; + } + StoredIntGauge gauge = gauges.get(key); + if (gauge != null) { + // Fast path: increment the existing counter. 
+ if (gauge.increment() > 0) { + removable.remove(key); + } + return; + } + if (gauges.size() == maxEntries) { + if (removable.isEmpty()) { + log.debug("{}: Attempted to increment {}, but there are already {} entries.", + suiteName, key.toString(), maxEntries); + return; + } + Iterator iter = removable.iterator(); + K keyToRemove = iter.next(); + iter.remove(); + MetricName metricNameToRemove = gauges.get(keyToRemove).metricName; + gauges.remove(keyToRemove); + pending.push(new PendingMetricsChange(metricNameToRemove, null)); + log.trace("{}: Removing the metric {}, which has a value of 0.", + suiteName, keyToRemove.toString()); + } + MetricName metricNameToAdd = metricNameCalculator.apply(key); + gauge = new StoredIntGauge(metricNameToAdd); + gauges.put(key, gauge); + pending.push(new PendingMetricsChange(metricNameToAdd, gauge)); + log.trace("{}: Adding a new metric {}.", suiteName, key.toString()); + } + // Drop the object monitor and perform any pending metrics additions or removals. + performPendingMetricsOperations(); + } + + /** + * Perform pending metrics additions or removals. + * It is important to perform them in order. For example, we don't want to try + * to remove a metric that we haven't finished adding yet. + */ + private void performPendingMetricsOperations() { + modifyMetricsLock.lock(); + try { + log.trace("{}: entering performPendingMetricsOperations", suiteName); + for (PendingMetricsChange change = pending.pollLast(); + change != null; + change = pending.pollLast()) { + if (change.provider == null) { + if (log.isTraceEnabled()) { + log.trace("{}: removing metric {}", suiteName, change.metricName); + } + metrics.removeMetric(change.metricName); + } else { + if (log.isTraceEnabled()) { + log.trace("{}: adding metric {}", suiteName, change.metricName); + } + metrics.addMetric(change.metricName, change.provider); + } + } + log.trace("{}: leaving performPendingMetricsOperations", suiteName); + } finally { + modifyMetricsLock.unlock(); + } + } + + public synchronized void decrement(K key) { + if (closed) { + log.warn("{}: Attempted to decrement {}, but the gauge suite was closed.", + suiteName, key.toString()); + return; + } + StoredIntGauge gauge = gauges.get(key); + if (gauge == null) { + log.debug("{}: Attempted to decrement {}, but no such metric was registered.", + suiteName, key.toString()); + } else { + int cur = gauge.decrement(); + log.trace("{}: Removed a reference to {}. {} reference(s) remaining.", + suiteName, key.toString(), cur); + if (cur <= 0) { + removable.add(key); + } + } + } + + @Override + public synchronized void close() { + if (closed) { + log.trace("{}: gauge suite is already closed.", suiteName); + return; + } + closed = true; + int prevSize = 0; + for (Iterator iter = gauges.values().iterator(); iter.hasNext(); ) { + pending.push(new PendingMetricsChange(iter.next().metricName, null)); + prevSize++; + iter.remove(); + } + performPendingMetricsOperations(); + log.trace("{}: closed {} metric(s).", suiteName, prevSize); + } + + /** + * Get the maximum number of metrics this suite can create. + */ + public int maxEntries() { + return maxEntries; + } + + // Visible for testing only. + Metrics metrics() { + return metrics; + } + + /** + * Return a map from keys to current reference counts. + * Visible for testing only. 
+ */ + synchronized Map values() { + HashMap values = new HashMap<>(); + for (Map.Entry entry : gauges.entrySet()) { + values.put(entry.getKey(), entry.getValue().value()); + } + return values; + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/metrics/internals/MetricsUtils.java b/clients/src/main/java/org/apache/kafka/common/metrics/internals/MetricsUtils.java new file mode 100644 index 0000000..edf061f --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/metrics/internals/MetricsUtils.java @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.metrics.internals; + +import org.apache.kafka.common.metrics.Metrics; + +import java.util.LinkedHashMap; +import java.util.Map; +import java.util.concurrent.TimeUnit; + +public class MetricsUtils { + /** + * Converts the provided time from milliseconds to the requested + * time unit. + */ + public static double convert(long timeMs, TimeUnit unit) { + switch (unit) { + case NANOSECONDS: + return timeMs * 1000.0 * 1000.0; + case MICROSECONDS: + return timeMs * 1000.0; + case MILLISECONDS: + return timeMs; + case SECONDS: + return timeMs / 1000.0; + case MINUTES: + return timeMs / (60.0 * 1000.0); + case HOURS: + return timeMs / (60.0 * 60.0 * 1000.0); + case DAYS: + return timeMs / (24.0 * 60.0 * 60.0 * 1000.0); + default: + throw new IllegalStateException("Unknown unit: " + unit); + } + } + + /** + * Create a set of tags using the supplied key and value pairs. The order of the tags will be kept. + * + * @param keyValue the key and value pairs for the tags; must be an even number + * @return the map of tags that can be supplied to the {@link Metrics} methods; never null + */ + public static Map getTags(String... keyValue) { + if ((keyValue.length % 2) != 0) + throw new IllegalArgumentException("keyValue needs to be specified in pairs"); + Map tags = new LinkedHashMap<>(keyValue.length / 2); + + for (int i = 0; i < keyValue.length; i += 2) + tags.put(keyValue[i], keyValue[i + 1]); + return tags; + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/metrics/stats/Avg.java b/clients/src/main/java/org/apache/kafka/common/metrics/stats/Avg.java new file mode 100644 index 0000000..4e6c337 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/metrics/stats/Avg.java @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.metrics.stats; + +import java.util.List; + +import org.apache.kafka.common.metrics.MetricConfig; + +/** + * A {@link SampledStat} that maintains a simple average over its samples. + */ +public class Avg extends SampledStat { + + public Avg() { + super(0.0); + } + + @Override + protected void update(Sample sample, MetricConfig config, double value, long now) { + sample.value += value; + } + + @Override + public double combine(List samples, MetricConfig config, long now) { + double total = 0.0; + long count = 0; + for (Sample s : samples) { + total += s.value; + count += s.eventCount; + } + return count == 0 ? Double.NaN : total / count; + } + +} + diff --git a/clients/src/main/java/org/apache/kafka/common/metrics/stats/CumulativeCount.java b/clients/src/main/java/org/apache/kafka/common/metrics/stats/CumulativeCount.java new file mode 100644 index 0000000..85591b5 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/metrics/stats/CumulativeCount.java @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.metrics.stats; + +import org.apache.kafka.common.metrics.MetricConfig; + +/** + * A non-sampled version of {@link WindowedCount} maintained over all time. + * + * This is a special kind of {@link CumulativeSum} that always records {@code 1} instead of the provided value. + * In other words, it counts the number of + * {@link CumulativeCount#record(MetricConfig, double, long)} invocations, + * instead of summing the recorded values. + */ +public class CumulativeCount extends CumulativeSum { + @Override + public void record(final MetricConfig config, final double value, final long timeMs) { + super.record(config, 1, timeMs); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/metrics/stats/CumulativeSum.java b/clients/src/main/java/org/apache/kafka/common/metrics/stats/CumulativeSum.java new file mode 100644 index 0000000..6726b9d --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/metrics/stats/CumulativeSum.java @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.metrics.stats; + +import org.apache.kafka.common.metrics.MeasurableStat; +import org.apache.kafka.common.metrics.MetricConfig; + +/** + * An non-sampled cumulative total maintained over all time. + * This is a non-sampled version of {@link WindowedSum}. + * + * See also {@link CumulativeCount} if you just want to increment the value by 1 on each recording. + */ +public class CumulativeSum implements MeasurableStat { + + private double total; + + public CumulativeSum() { + total = 0.0; + } + + public CumulativeSum(double value) { + total = value; + } + + @Override + public void record(MetricConfig config, double value, long now) { + total += value; + } + + @Override + public double measure(MetricConfig config, long now) { + return total; + } + + @Override + public String toString() { + return "CumulativeSum(total=" + total + ")"; + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/metrics/stats/Frequencies.java b/clients/src/main/java/org/apache/kafka/common/metrics/stats/Frequencies.java new file mode 100644 index 0000000..36daea6 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/metrics/stats/Frequencies.java @@ -0,0 +1,192 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.metrics.stats; + +import org.apache.kafka.common.MetricName; +import org.apache.kafka.common.metrics.CompoundStat; +import org.apache.kafka.common.metrics.Measurable; +import org.apache.kafka.common.metrics.MetricConfig; +import org.apache.kafka.common.metrics.stats.Histogram.BinScheme; +import org.apache.kafka.common.metrics.stats.Histogram.ConstantBinScheme; + +import java.util.ArrayList; +import java.util.List; + +/** + * A {@link CompoundStat} that represents a normalized distribution with a {@link Frequency} metric for each + * bucketed value. The values of the {@link Frequency} metrics specify the frequency of the center value appearing + * relative to the total number of values recorded. + *

+ * <p>
+ * For example, consider a component that records failure or success of an operation using boolean values, with
+ * one metric to capture the percentage of operations that failed and another to capture the percentage of operations
+ * that succeeded.
+ *
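+ * <p>
+ * A sketch of the set-up described in the next paragraph (metric and sensor names are invented
+ * for illustration):
+ * <pre>
+ *   Sensor sensor = metrics.sensor("operation-outcomes");
+ *   Frequencies frequencies = Frequencies.forBooleanValues(
+ *       metrics.metricName("operation-failure-rate", "example-group"),
+ *       metrics.metricName("operation-success-rate", "example-group"));
+ *   sensor.add(frequencies);
+ *   sensor.record(succeeded ? 1.0 : 0.0);
+ * </pre>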

        + * This can be accomplish by created a {@link org.apache.kafka.common.metrics.Sensor Sensor} to record the values, + * with 0.0 for false and 1.0 for true. Then, create a single {@link Frequencies} object that has two + * {@link Frequency} metrics: one centered around 0.0 and another centered around 1.0. The {@link Frequencies} + * object is a {@link CompoundStat}, and so it can be {@link org.apache.kafka.common.metrics.Sensor#add(CompoundStat) + * added directly to a Sensor} so the metrics are created automatically. + */ +public class Frequencies extends SampledStat implements CompoundStat { + + /** + * Create a Frequencies instance with metrics for the frequency of a boolean sensor that records 0.0 for + * false and 1.0 for true. + * + * @param falseMetricName the name of the metric capturing the frequency of failures; may be null if not needed + * @param trueMetricName the name of the metric capturing the frequency of successes; may be null if not needed + * @return the Frequencies instance; never null + * @throws IllegalArgumentException if both {@code falseMetricName} and {@code trueMetricName} are null + */ + public static Frequencies forBooleanValues(MetricName falseMetricName, MetricName trueMetricName) { + List frequencies = new ArrayList<>(); + if (falseMetricName != null) { + frequencies.add(new Frequency(falseMetricName, 0.0)); + } + if (trueMetricName != null) { + frequencies.add(new Frequency(trueMetricName, 1.0)); + } + if (frequencies.isEmpty()) { + throw new IllegalArgumentException("Must specify at least one metric name"); + } + Frequency[] frequencyArray = frequencies.toArray(new Frequency[0]); + return new Frequencies(2, 0.0, 1.0, frequencyArray); + } + + private final Frequency[] frequencies; + private final BinScheme binScheme; + + /** + * Create a Frequencies that captures the values in the specified range into the given number of buckets, + * where the buckets are centered around the minimum, maximum, and intermediate values. + * + * @param buckets the number of buckets; must be at least 1 + * @param min the minimum value to be captured + * @param max the maximum value to be captured + * @param frequencies the list of {@link Frequency} metrics, which at most should be one per bucket centered + * on the bucket's value, though not every bucket need to correspond to a metric if the + * value is not needed + * @throws IllegalArgumentException if any of the {@link Frequency} objects do not have a + * {@link Frequency#centerValue() center value} within the specified range + */ + public Frequencies(int buckets, double min, double max, Frequency... 
frequencies) { + super(0.0); // initial value is unused by this implementation + if (max < min) { + throw new IllegalArgumentException("The maximum value " + max + + " must be greater than the minimum value " + min); + } + if (buckets < 1) { + throw new IllegalArgumentException("Must be at least 1 bucket"); + } + if (buckets < frequencies.length) { + throw new IllegalArgumentException("More frequencies than buckets"); + } + this.frequencies = frequencies; + for (Frequency freq : frequencies) { + if (min > freq.centerValue() || max < freq.centerValue()) { + throw new IllegalArgumentException("The frequency centered at '" + freq.centerValue() + + "' is not within the range [" + min + "," + max + "]"); + } + } + double halfBucketWidth = (max - min) / (buckets - 1) / 2.0; + this.binScheme = new ConstantBinScheme(buckets, min - halfBucketWidth, max + halfBucketWidth); + } + + @Override + public List stats() { + List ms = new ArrayList<>(frequencies.length); + for (Frequency frequency : frequencies) { + final double center = frequency.centerValue(); + ms.add(new NamedMeasurable(frequency.name(), new Measurable() { + public double measure(MetricConfig config, long now) { + return frequency(config, now, center); + } + })); + } + return ms; + } + + /** + * Return the computed frequency describing the number of occurrences of the values in the bucket for the given + * center point, relative to the total number of occurrences in the samples. + * + * @param config the metric configuration + * @param now the current time in milliseconds + * @param centerValue the value corresponding to the center point of the bucket + * @return the frequency of the values in the bucket relative to the total number of samples + */ + public double frequency(MetricConfig config, long now, double centerValue) { + purgeObsoleteSamples(config, now); + long totalCount = 0; + for (Sample sample : samples) { + totalCount += sample.eventCount; + } + if (totalCount == 0) { + return 0.0d; + } + // Add up all of the counts in the bin corresponding to the center value + float count = 0.0f; + int binNum = binScheme.toBin(centerValue); + for (Sample s : samples) { + HistogramSample sample = (HistogramSample) s; + float[] hist = sample.histogram.counts(); + count += hist[binNum]; + } + // Compute the ratio of counts to total counts + return count / (double) totalCount; + } + + double totalCount() { + long count = 0; + for (Sample sample : samples) { + count += sample.eventCount; + } + return count; + } + + @Override + public double combine(List samples, MetricConfig config, long now) { + return totalCount(); + } + + @Override + protected HistogramSample newSample(long timeMs) { + return new HistogramSample(binScheme, timeMs); + } + + @Override + protected void update(Sample sample, MetricConfig config, double value, long timeMs) { + HistogramSample hist = (HistogramSample) sample; + hist.histogram.record(value); + } + + private static class HistogramSample extends SampledStat.Sample { + + private final Histogram histogram; + + private HistogramSample(BinScheme scheme, long now) { + super(0.0, now); + histogram = new Histogram(scheme); + } + + @Override + public void reset(long now) { + super.reset(now); + histogram.clear(); + } + } +} \ No newline at end of file diff --git a/clients/src/main/java/org/apache/kafka/common/metrics/stats/Frequency.java b/clients/src/main/java/org/apache/kafka/common/metrics/stats/Frequency.java new file mode 100644 index 0000000..5222219 --- /dev/null +++ 
b/clients/src/main/java/org/apache/kafka/common/metrics/stats/Frequency.java @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.metrics.stats; + + +import org.apache.kafka.common.MetricName; + +/** + * Definition of a frequency metric used in a {@link Frequencies} compound statistic. + */ +public class Frequency { + + private final MetricName name; + private final double centerValue; + + /** + * Create an instance with the given name and center point value. + * + * @param name the name of the frequency metric; may not be null + * @param centerValue the value identifying the {@link Frequencies} bucket to be reported + */ + public Frequency(MetricName name, double centerValue) { + this.name = name; + this.centerValue = centerValue; + } + + /** + * Get the name of this metric. + * + * @return the metric name; never null + */ + public MetricName name() { + return this.name; + } + + /** + * Get the value of this metrics center point. + * + * @return the center point value + */ + public double centerValue() { + return this.centerValue; + } + + @Override + public String toString() { + return "Frequency(" + + "name=" + name + + ", centerValue=" + centerValue + + ')'; + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/metrics/stats/Histogram.java b/clients/src/main/java/org/apache/kafka/common/metrics/stats/Histogram.java new file mode 100644 index 0000000..97f9182 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/metrics/stats/Histogram.java @@ -0,0 +1,213 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
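// NOTE (editorial, not part of this patch): a minimal sketch of how the Frequencies/Frequency
// classes above might be wired to a Sensor, as the Frequencies Javadoc describes. It assumes the
// Metrics/Sensor registry API from org.apache.kafka.common.metrics; the sensor name and metric
// group below are invented for the example.
import org.apache.kafka.common.MetricName;
import org.apache.kafka.common.metrics.Metrics;
import org.apache.kafka.common.metrics.Sensor;
import org.apache.kafka.common.metrics.stats.Frequencies;

public class FrequenciesExample {
    public static void main(String[] args) {
        Metrics metrics = new Metrics();
        Sensor sensor = metrics.sensor("outcome");
        MetricName failureFrequency = metrics.metricName("failure-frequency", "example-group");
        MetricName successFrequency = metrics.metricName("success-frequency", "example-group");
        // Frequencies is a CompoundStat, so adding it registers both metrics at once.
        sensor.add(Frequencies.forBooleanValues(failureFrequency, successFrequency));
        // Record 0.0 for false and 1.0 for true, per the class Javadoc.
        sensor.record(0.0);
        sensor.record(1.0);
        sensor.record(1.0);
        metrics.close();
    }
}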
+ */ +package org.apache.kafka.common.metrics.stats; + +import java.util.Arrays; + +public class Histogram { + + private final BinScheme binScheme; + private final float[] hist; + private double count; + + public Histogram(BinScheme binScheme) { + this.hist = new float[binScheme.bins()]; + this.count = 0.0f; + this.binScheme = binScheme; + } + + public void record(double value) { + this.hist[binScheme.toBin(value)] += 1.0f; + this.count += 1.0d; + } + + public double value(double quantile) { + if (count == 0.0d) + return Double.NaN; + if (quantile > 1.00d) + return Float.POSITIVE_INFINITY; + if (quantile < 0.00d) + return Float.NEGATIVE_INFINITY; + float sum = 0.0f; + float quant = (float) quantile; + for (int i = 0; i < this.hist.length - 1; i++) { + sum += this.hist[i]; + if (sum / count > quant) + return binScheme.fromBin(i); + } + return binScheme.fromBin(this.hist.length - 1); + } + + public float[] counts() { + return this.hist; + } + + public void clear() { + Arrays.fill(this.hist, 0.0f); + this.count = 0; + } + + @Override + public String toString() { + StringBuilder b = new StringBuilder("{"); + for (int i = 0; i < this.hist.length - 1; i++) { + b.append(String.format("%.10f", binScheme.fromBin(i))); + b.append(':'); + b.append(String.format("%.0f", this.hist[i])); + b.append(','); + } + b.append(Float.POSITIVE_INFINITY); + b.append(':'); + b.append(String.format("%.0f", this.hist[this.hist.length - 1])); + b.append('}'); + return b.toString(); + } + + /** + * An algorithm for determining the bin in which a value is to be placed as well as calculating the upper end + * of each bin. + */ + public interface BinScheme { + + /** + * Get the number of bins. + * + * @return the number of bins + */ + int bins(); + + /** + * Determine the 0-based bin number in which the supplied value should be placed. + * + * @param value the value + * @return the 0-based index of the bin + */ + int toBin(double value); + + /** + * Determine the value at the upper range of the specified bin. + * + * @param bin the 0-based bin number + * @return the value at the upper end of the bin; or {@link Float#NEGATIVE_INFINITY negative infinity} + * if the bin number is negative or {@link Float#POSITIVE_INFINITY positive infinity} if the 0-based + * bin number is greater than or equal to the {@link #bins() number of bins}. + */ + double fromBin(int bin); + } + + /** + * A scheme for calculating the bins where the width of each bin is a constant determined by the range of values + * and the number of bins. + */ + public static class ConstantBinScheme implements BinScheme { + private static final int MIN_BIN_NUMBER = 0; + private final double min; + private final double max; + private final int bins; + private final double bucketWidth; + private final int maxBinNumber; + + /** + * Create a bin scheme with the specified number of bins that all have the same width. 
+ * + * @param bins the number of bins; must be at least 2 + * @param min the minimum value to be counted in the bins + * @param max the maximum value to be counted in the bins + */ + public ConstantBinScheme(int bins, double min, double max) { + if (bins < 2) + throw new IllegalArgumentException("Must have at least 2 bins."); + this.min = min; + this.max = max; + this.bins = bins; + this.bucketWidth = (max - min) / bins; + this.maxBinNumber = bins - 1; + } + + public int bins() { + return this.bins; + } + + public double fromBin(int b) { + if (b < MIN_BIN_NUMBER) { + return Float.NEGATIVE_INFINITY; + } + if (b > maxBinNumber) { + return Float.POSITIVE_INFINITY; + } + return min + b * bucketWidth; + } + + public int toBin(double x) { + int binNumber = (int) ((x - min) / bucketWidth); + if (binNumber < MIN_BIN_NUMBER) { + return MIN_BIN_NUMBER; + } + return Math.min(binNumber, maxBinNumber); + } + } + + /** + * A scheme for calculating the bins where the width of each bin is one more than the previous bin, and therefore + * the bin widths are increasing at a linear rate. However, the bin widths are scaled such that the specified range + * of values will all fit within the bins (e.g., the upper range of the last bin is equal to the maximum value). + */ + public static class LinearBinScheme implements BinScheme { + private final int bins; + private final double max; + private final double scale; + + /** + * Create a linear bin scheme with the specified number of bins and the maximum value to be counted in the bins. + * + * @param numBins the number of bins; must be at least 2 + * @param max the maximum value to be counted in the bins + */ + public LinearBinScheme(int numBins, double max) { + if (numBins < 2) + throw new IllegalArgumentException("Must have at least 2 bins."); + this.bins = numBins; + this.max = max; + double denom = numBins * (numBins - 1.0) / 2.0; + this.scale = max / denom; + } + + public int bins() { + return this.bins; + } + + public double fromBin(int b) { + if (b > this.bins - 1) { + return Float.POSITIVE_INFINITY; + } else if (b < 0.0000d) { + return Float.NEGATIVE_INFINITY; + } else { + return this.scale * (b * (b + 1.0)) / 2.0; + } + } + + public int toBin(double x) { + if (x < 0.0d) { + throw new IllegalArgumentException("Values less than 0.0 not accepted."); + } else if (x > this.max) { + return this.bins - 1; + } else { + return (int) (-0.5 + 0.5 * Math.sqrt(1.0 + 8.0 * x / this.scale)); + } + } + } +} \ No newline at end of file diff --git a/clients/src/main/java/org/apache/kafka/common/metrics/stats/Max.java b/clients/src/main/java/org/apache/kafka/common/metrics/stats/Max.java new file mode 100644 index 0000000..d91bf40 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/metrics/stats/Max.java @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
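// NOTE (editorial, not part of this patch): a minimal sketch of how the two BinScheme
// implementations defined in Histogram map values to bins. The concrete bin counts and ranges are
// chosen only for illustration.
import org.apache.kafka.common.metrics.stats.Histogram;
import org.apache.kafka.common.metrics.stats.Histogram.ConstantBinScheme;
import org.apache.kafka.common.metrics.stats.Histogram.LinearBinScheme;

public class BinSchemeExample {
    public static void main(String[] args) {
        // Five equal-width bins covering [0, 100]: each bin is 20 wide.
        ConstantBinScheme constant = new ConstantBinScheme(5, 0.0, 100.0);
        System.out.println(constant.toBin(25.0)); // 1
        System.out.println(constant.fromBin(1));  // 20.0

        // Bins whose widths grow linearly, scaled so the last bin ends at the maximum (100).
        LinearBinScheme linear = new LinearBinScheme(10, 100.0);
        System.out.println(linear.toBin(1.0));    // 0, small values land in the narrow early bins
        System.out.println(linear.fromBin(9));    // 100.0, the configured maximum

        // A Histogram built on a scheme counts one occurrence per recorded value.
        Histogram histogram = new Histogram(constant);
        histogram.record(25.0);
        histogram.record(75.0);
        System.out.println(histogram.value(0.5)); // 60.0 for these two recordings
    }
}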
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.metrics.stats; + +import java.util.List; + +import org.apache.kafka.common.metrics.MetricConfig; + +/** + * A {@link SampledStat} that gives the max over its samples. + */ +public final class Max extends SampledStat { + + public Max() { + super(Double.NEGATIVE_INFINITY); + } + + @Override + protected void update(Sample sample, MetricConfig config, double value, long now) { + sample.value = Math.max(sample.value, value); + } + + @Override + public double combine(List samples, MetricConfig config, long now) { + double max = Double.NEGATIVE_INFINITY; + long count = 0; + for (Sample sample : samples) { + max = Math.max(max, sample.value); + count += sample.eventCount; + } + return count == 0 ? Double.NaN : max; + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/metrics/stats/Meter.java b/clients/src/main/java/org/apache/kafka/common/metrics/stats/Meter.java new file mode 100644 index 0000000..0eec0c4 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/metrics/stats/Meter.java @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.metrics.stats; + +import java.util.Arrays; +import java.util.List; +import java.util.concurrent.TimeUnit; + +import org.apache.kafka.common.MetricName; +import org.apache.kafka.common.metrics.CompoundStat; +import org.apache.kafka.common.metrics.MetricConfig; + + +/** + * A compound stat that includes a rate metric and a cumulative total metric. 
+ */ +public class Meter implements CompoundStat { + private final MetricName rateMetricName; + private final MetricName totalMetricName; + private final Rate rate; + private final CumulativeSum total; + + /** + * Construct a Meter with seconds as time unit + */ + public Meter(MetricName rateMetricName, MetricName totalMetricName) { + this(TimeUnit.SECONDS, new WindowedSum(), rateMetricName, totalMetricName); + } + + /** + * Construct a Meter with provided time unit + */ + public Meter(TimeUnit unit, MetricName rateMetricName, MetricName totalMetricName) { + this(unit, new WindowedSum(), rateMetricName, totalMetricName); + } + + /** + * Construct a Meter with seconds as time unit + */ + public Meter(SampledStat rateStat, MetricName rateMetricName, MetricName totalMetricName) { + this(TimeUnit.SECONDS, rateStat, rateMetricName, totalMetricName); + } + + /** + * Construct a Meter with provided time unit + */ + public Meter(TimeUnit unit, SampledStat rateStat, MetricName rateMetricName, MetricName totalMetricName) { + if (!(rateStat instanceof WindowedSum)) { + throw new IllegalArgumentException("Meter is supported only for WindowedCount or WindowedSum."); + } + this.total = new CumulativeSum(); + this.rate = new Rate(unit, rateStat); + this.rateMetricName = rateMetricName; + this.totalMetricName = totalMetricName; + } + + @Override + public List stats() { + return Arrays.asList( + new NamedMeasurable(totalMetricName, total), + new NamedMeasurable(rateMetricName, rate)); + } + + @Override + public void record(MetricConfig config, double value, long timeMs) { + rate.record(config, value, timeMs); + // Total metrics with Count stat should record 1.0 (as recorded in the count) + double totalValue = (rate.stat instanceof WindowedCount) ? 1.0 : value; + total.record(config, totalValue, timeMs); + } + + @Override + public String toString() { + return "Meter(" + + "rate=" + rate + + ", total=" + total + + ", rateMetricName=" + rateMetricName + + ", totalMetricName=" + totalMetricName + + ')'; + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/metrics/stats/Min.java b/clients/src/main/java/org/apache/kafka/common/metrics/stats/Min.java new file mode 100644 index 0000000..3b9925a --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/metrics/stats/Min.java @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.metrics.stats; + +import java.util.List; + +import org.apache.kafka.common.metrics.MetricConfig; + +/** + * A {@link SampledStat} that gives the min over its samples. 
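// NOTE (editorial, not part of this patch): a sketch of attaching the Meter compound stat above to
// a Sensor so a per-second rate and a cumulative total are reported together. It assumes the
// Metrics/Sensor registry API; the metric names are invented for the example.
import org.apache.kafka.common.MetricName;
import org.apache.kafka.common.metrics.Metrics;
import org.apache.kafka.common.metrics.Sensor;
import org.apache.kafka.common.metrics.stats.Meter;

public class MeterExample {
    public static void main(String[] args) {
        Metrics metrics = new Metrics();
        Sensor sensor = metrics.sensor("bytes-sent");
        MetricName rate = metrics.metricName("byte-rate", "example-group");
        MetricName total = metrics.metricName("byte-total", "example-group");
        // One CompoundStat registers both the rate and the running total.
        sensor.add(new Meter(rate, total));
        sensor.record(1024.0);
        sensor.record(2048.0);
        metrics.close();
    }
}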
+ */ +public class Min extends SampledStat { + + public Min() { + super(Double.MAX_VALUE); + } + + @Override + protected void update(Sample sample, MetricConfig config, double value, long now) { + sample.value = Math.min(sample.value, value); + } + + @Override + public double combine(List samples, MetricConfig config, long now) { + double min = Double.MAX_VALUE; + long count = 0; + for (Sample sample : samples) { + min = Math.min(min, sample.value); + count += sample.eventCount; + } + return count == 0 ? Double.NaN : min; + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/metrics/stats/Percentile.java b/clients/src/main/java/org/apache/kafka/common/metrics/stats/Percentile.java new file mode 100644 index 0000000..f8ae3a1 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/metrics/stats/Percentile.java @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.metrics.stats; + +import org.apache.kafka.common.MetricName; + +public class Percentile { + + private final MetricName name; + private final double percentile; + + public Percentile(MetricName name, double percentile) { + super(); + this.name = name; + this.percentile = percentile; + } + + public MetricName name() { + return this.name; + } + + public double percentile() { + return this.percentile; + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/metrics/stats/Percentiles.java b/clients/src/main/java/org/apache/kafka/common/metrics/stats/Percentiles.java new file mode 100644 index 0000000..4cdc2ce --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/metrics/stats/Percentiles.java @@ -0,0 +1,146 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
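// NOTE (editorial, not part of this patch): a sketch of how Percentile definitions feed the
// Percentiles compound stat that follows. The histogram size, value range, and names are
// illustrative only, and the Metrics/Sensor registry API is assumed.
import org.apache.kafka.common.MetricName;
import org.apache.kafka.common.metrics.Metrics;
import org.apache.kafka.common.metrics.Sensor;
import org.apache.kafka.common.metrics.stats.Percentile;
import org.apache.kafka.common.metrics.stats.Percentiles;
import org.apache.kafka.common.metrics.stats.Percentiles.BucketSizing;

public class PercentilesExample {
    public static void main(String[] args) {
        Metrics metrics = new Metrics();
        Sensor sensor = metrics.sensor("request-latency");
        MetricName p50 = metrics.metricName("latency-p50", "example-group");
        MetricName p99 = metrics.metricName("latency-p99", "example-group");
        // 400 bytes of histogram state (100 buckets at 4 bytes each); values bounded to [0, 1000].
        sensor.add(new Percentiles(400, 0.0, 1000.0, BucketSizing.CONSTANT,
                new Percentile(p50, 50), new Percentile(p99, 99)));
        sensor.record(12.0);
        sensor.record(250.0);
        metrics.close();
    }
}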
+ */ +package org.apache.kafka.common.metrics.stats; + +import java.util.ArrayList; +import java.util.List; + +import org.apache.kafka.common.metrics.CompoundStat; +import org.apache.kafka.common.metrics.MetricConfig; +import org.apache.kafka.common.metrics.stats.Histogram.BinScheme; +import org.apache.kafka.common.metrics.stats.Histogram.ConstantBinScheme; +import org.apache.kafka.common.metrics.stats.Histogram.LinearBinScheme; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * A compound stat that reports one or more percentiles + */ +public class Percentiles extends SampledStat implements CompoundStat { + + private final Logger log = LoggerFactory.getLogger(Percentiles.class); + + public enum BucketSizing { + CONSTANT, LINEAR + } + + private final int buckets; + private final Percentile[] percentiles; + private final BinScheme binScheme; + private final double min; + private final double max; + + public Percentiles(int sizeInBytes, double max, BucketSizing bucketing, Percentile... percentiles) { + this(sizeInBytes, 0.0, max, bucketing, percentiles); + } + + public Percentiles(int sizeInBytes, double min, double max, BucketSizing bucketing, Percentile... percentiles) { + super(0.0); + this.percentiles = percentiles; + this.buckets = sizeInBytes / 4; + this.min = min; + this.max = max; + if (bucketing == BucketSizing.CONSTANT) { + this.binScheme = new ConstantBinScheme(buckets, min, max); + } else if (bucketing == BucketSizing.LINEAR) { + if (min != 0.0d) + throw new IllegalArgumentException("Linear bucket sizing requires min to be 0.0."); + this.binScheme = new LinearBinScheme(buckets, max); + } else { + throw new IllegalArgumentException("Unknown bucket type: " + bucketing); + } + } + + @Override + public List stats() { + List ms = new ArrayList<>(this.percentiles.length); + for (Percentile percentile : this.percentiles) { + final double pct = percentile.percentile(); + ms.add(new NamedMeasurable( + percentile.name(), + (config, now) -> value(config, now, pct / 100.0)) + ); + } + return ms; + } + + public double value(MetricConfig config, long now, double quantile) { + purgeObsoleteSamples(config, now); + float count = 0.0f; + for (Sample sample : this.samples) + count += sample.eventCount; + if (count == 0.0f) + return Double.NaN; + float sum = 0.0f; + float quant = (float) quantile; + for (int b = 0; b < buckets; b++) { + for (Sample s : this.samples) { + HistogramSample sample = (HistogramSample) s; + float[] hist = sample.histogram.counts(); + sum += hist[b]; + if (sum / count > quant) + return binScheme.fromBin(b); + } + } + return Double.POSITIVE_INFINITY; + } + + @Override + public double combine(List samples, MetricConfig config, long now) { + return value(config, now, 0.5); + } + + @Override + protected HistogramSample newSample(long timeMs) { + return new HistogramSample(this.binScheme, timeMs); + } + + @Override + protected void update(Sample sample, MetricConfig config, double value, long timeMs) { + final double boundedValue; + if (value > max) { + log.debug("Received value {} which is greater than max recordable value {}, will be pinned to the max value", + value, max); + boundedValue = max; + } else if (value < min) { + log.debug("Received value {} which is less than min recordable value {}, will be pinned to the min value", + value, min); + boundedValue = min; + } else { + boundedValue = value; + } + + HistogramSample hist = (HistogramSample) sample; + hist.histogram.record(boundedValue); + } + + private static class HistogramSample extends 
SampledStat.Sample { + private final Histogram histogram; + + private HistogramSample(BinScheme scheme, long now) { + super(0.0, now); + this.histogram = new Histogram(scheme); + } + + @Override + public void reset(long now) { + super.reset(now); + this.histogram.clear(); + } + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/metrics/stats/Rate.java b/clients/src/main/java/org/apache/kafka/common/metrics/stats/Rate.java new file mode 100644 index 0000000..c6b8574 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/metrics/stats/Rate.java @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.metrics.stats; + +import java.util.Locale; +import java.util.concurrent.TimeUnit; + +import org.apache.kafka.common.metrics.MeasurableStat; +import org.apache.kafka.common.metrics.MetricConfig; + +import static org.apache.kafka.common.metrics.internals.MetricsUtils.convert; + +/** + * The rate of the given quantity. By default this is the total observed over a set of samples from a sampled statistic + * divided by the elapsed time over the sample windows. Alternative {@link SampledStat} implementations can be provided, + * however, to record the rate of occurrences (e.g. the count of values measured over the time interval) or other such + * values. + */ +public class Rate implements MeasurableStat { + + protected final TimeUnit unit; + protected final SampledStat stat; + + public Rate() { + this(TimeUnit.SECONDS); + } + + public Rate(TimeUnit unit) { + this(unit, new WindowedSum()); + } + + public Rate(SampledStat stat) { + this(TimeUnit.SECONDS, stat); + } + + public Rate(TimeUnit unit, SampledStat stat) { + this.stat = stat; + this.unit = unit; + } + + public String unitName() { + return unit.name().substring(0, unit.name().length() - 2).toLowerCase(Locale.ROOT); + } + + @Override + public void record(MetricConfig config, double value, long timeMs) { + this.stat.record(config, value, timeMs); + } + + @Override + public double measure(MetricConfig config, long now) { + double value = stat.measure(config, now); + return value / convert(windowSize(config, now), unit); + } + + public long windowSize(MetricConfig config, long now) { + // purge old samples before we compute the window size + stat.purgeObsoleteSamples(config, now); + + /* + * Here we check the total amount of time elapsed since the oldest non-obsolete window. + * This give the total windowSize of the batch which is the time used for Rate computation. + * However, there is an issue if we do not have sufficient data for e.g. if only 1 second has elapsed in a 30 second + * window, the measured rate will be very high. 
+ * Hence we assume that the elapsed time is always N-1 complete windows plus whatever fraction of the final window is complete. + * + * Note that we could simply count the amount of time elapsed in the current window and add n-1 windows to get the total time, + * but this approach does not account for sleeps. SampledStat only creates samples whenever record is called, + * if no record is called for a period of time that time is not accounted for in windowSize and produces incorrect results. + */ + long totalElapsedTimeMs = now - stat.oldest(now).lastWindowMs; + // Check how many full windows of data we have currently retained + int numFullWindows = (int) (totalElapsedTimeMs / config.timeWindowMs()); + int minFullWindows = config.samples() - 1; + + // If the available windows are less than the minimum required, add the difference to the totalElapsedTime + if (numFullWindows < minFullWindows) + totalElapsedTimeMs += (minFullWindows - numFullWindows) * config.timeWindowMs(); + + return totalElapsedTimeMs; + } + + @Override + public String toString() { + return "Rate(" + + "unit=" + unit + + ", stat=" + stat + + ')'; + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/metrics/stats/SampledStat.java b/clients/src/main/java/org/apache/kafka/common/metrics/stats/SampledStat.java new file mode 100644 index 0000000..faf596a --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/metrics/stats/SampledStat.java @@ -0,0 +1,152 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.metrics.stats; + +import java.util.ArrayList; +import java.util.List; + +import org.apache.kafka.common.metrics.MeasurableStat; +import org.apache.kafka.common.metrics.MetricConfig; + +/** + * A SampledStat records a single scalar value measured over one or more samples. Each sample is recorded over a + * configurable window. The window can be defined by number of events or elapsed time (or both, if both are given the + * window is complete when either the event count or elapsed time criterion is met). + *

        + * All the samples are combined to produce the measurement. When a window is complete the oldest sample is cleared and + * recycled to begin recording the next sample. + * + * Subclasses of this class define different statistics measured using this basic pattern. + */ +public abstract class SampledStat implements MeasurableStat { + + private double initialValue; + private int current = 0; + protected List samples; + + public SampledStat(double initialValue) { + this.initialValue = initialValue; + this.samples = new ArrayList<>(2); + } + + @Override + public void record(MetricConfig config, double value, long timeMs) { + Sample sample = current(timeMs); + if (sample.isComplete(timeMs, config)) + sample = advance(config, timeMs); + update(sample, config, value, timeMs); + sample.eventCount += 1; + } + + private Sample advance(MetricConfig config, long timeMs) { + this.current = (this.current + 1) % config.samples(); + if (this.current >= samples.size()) { + Sample sample = newSample(timeMs); + this.samples.add(sample); + return sample; + } else { + Sample sample = current(timeMs); + sample.reset(timeMs); + return sample; + } + } + + protected Sample newSample(long timeMs) { + return new Sample(this.initialValue, timeMs); + } + + @Override + public double measure(MetricConfig config, long now) { + purgeObsoleteSamples(config, now); + return combine(this.samples, config, now); + } + + public Sample current(long timeMs) { + if (samples.size() == 0) + this.samples.add(newSample(timeMs)); + return this.samples.get(this.current); + } + + public Sample oldest(long now) { + if (samples.size() == 0) + this.samples.add(newSample(now)); + Sample oldest = this.samples.get(0); + for (int i = 1; i < this.samples.size(); i++) { + Sample curr = this.samples.get(i); + if (curr.lastWindowMs < oldest.lastWindowMs) + oldest = curr; + } + return oldest; + } + + @Override + public String toString() { + return "SampledStat(" + + "initialValue=" + initialValue + + ", current=" + current + + ", samples=" + samples + + ')'; + } + + protected abstract void update(Sample sample, MetricConfig config, double value, long timeMs); + + public abstract double combine(List samples, MetricConfig config, long now); + + /* Timeout any windows that have expired in the absence of any events */ + protected void purgeObsoleteSamples(MetricConfig config, long now) { + long expireAge = config.samples() * config.timeWindowMs(); + for (Sample sample : samples) { + if (now - sample.lastWindowMs >= expireAge) + sample.reset(now); + } + } + + protected static class Sample { + public double initialValue; + public long eventCount; + public long lastWindowMs; + public double value; + + public Sample(double initialValue, long now) { + this.initialValue = initialValue; + this.eventCount = 0; + this.lastWindowMs = now; + this.value = initialValue; + } + + public void reset(long now) { + this.eventCount = 0; + this.lastWindowMs = now; + this.value = initialValue; + } + + public boolean isComplete(long timeMs, MetricConfig config) { + return timeMs - lastWindowMs >= config.timeWindowMs() || eventCount >= config.eventWindow(); + } + + @Override + public String toString() { + return "Sample(" + + "value=" + value + + ", eventCount=" + eventCount + + ", lastWindowMs=" + lastWindowMs + + ", initialValue=" + initialValue + + ')'; + } + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/metrics/stats/SimpleRate.java b/clients/src/main/java/org/apache/kafka/common/metrics/stats/SimpleRate.java new file mode 100644 index 
0000000..931bd9c --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/metrics/stats/SimpleRate.java @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.metrics.stats; + +import org.apache.kafka.common.metrics.MetricConfig; + +/** + * A simple rate the rate is incrementally calculated + * based on the elapsed time between the earliest reading + * and now. + * + * An exception is made for the first window, which is + * considered of fixed size. This avoids the issue of + * an artificially high rate when the gap between readings + * is close to 0. + */ +public class SimpleRate extends Rate { + + @Override + public long windowSize(MetricConfig config, long now) { + stat.purgeObsoleteSamples(config, now); + long elapsed = now - stat.oldest(now).lastWindowMs; + return Math.max(elapsed, config.timeWindowMs()); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/metrics/stats/TokenBucket.java b/clients/src/main/java/org/apache/kafka/common/metrics/stats/TokenBucket.java new file mode 100644 index 0000000..c86ff51 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/metrics/stats/TokenBucket.java @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.metrics.stats; + +import java.util.concurrent.TimeUnit; +import org.apache.kafka.common.metrics.MeasurableStat; +import org.apache.kafka.common.metrics.MetricConfig; +import org.apache.kafka.common.metrics.Quota; + +import static org.apache.kafka.common.metrics.internals.MetricsUtils.convert; + +/** + * The {@link TokenBucket} is a {@link MeasurableStat} implementing a token bucket algorithm + * that is usable within a {@link org.apache.kafka.common.metrics.Sensor}. + * + * The {@link Quota#bound()} defined the refill rate of the bucket while the maximum burst or + * the maximum number of credits of the bucket is defined by + * {@link MetricConfig#samples() * MetricConfig#timeWindowMs() * Quota#bound()}. 
+ * + * The quota is considered as exhausted when the amount of remaining credits in the bucket + * is below zero. The enforcement is done by the {@link org.apache.kafka.common.metrics.Sensor}. + * + * Token Bucket vs Rate based Quota: + * The current sampled rate based quota does not cope well with bursty workloads. The issue is + * that a unique and large sample can hold the average above the quota until it is discarded. + * Practically, when this happens, one must wait until the sample is expired to bring the rate + * below the quota even though less time would be theoretically required. As an example, let's + * imagine that we have: + * - Quota (Q) = 5 + * - Samples (S) = 100 + * - Window (W) = 1s + * A burst of 560 brings the average rate (R) to 5.6 (560 / 100). The expected throttle time is + * computed as follow: ((R - Q / Q * S * W)) = ((5.6 - 5) / 5 * 100 * 1) = 12 secs. In practice, + * the average rate won't go below the quota before the burst is dropped from the samples so one + * must wait 100s (S * W). + * + * The token bucket relies on continuously updated amount of credits. Therefore, it does not + * suffers from the above issue. The same example would work as follow: + * - Quota (Q) = 5 + * - Burst (B) = 5 * 1 * 100 = 500 (Q * S * W) + * A burst of 560 brings the amount of credits to -60. One must wait 12s (-(-60)/5) to refill the + * bucket to zero. + */ +public class TokenBucket implements MeasurableStat { + private final TimeUnit unit; + private double tokens; + private long lastUpdateMs; + + public TokenBucket() { + this(TimeUnit.SECONDS); + } + + public TokenBucket(TimeUnit unit) { + this.unit = unit; + this.tokens = 0; + this.lastUpdateMs = 0; + } + + @Override + public double measure(final MetricConfig config, final long timeMs) { + if (config.quota() == null) + return Long.MAX_VALUE; + final double quota = config.quota().bound(); + final double burst = burst(config); + refill(quota, burst, timeMs); + return this.tokens; + } + + @Override + public void record(final MetricConfig config, final double value, final long timeMs) { + if (config.quota() == null) + return; + final double quota = config.quota().bound(); + final double burst = burst(config); + refill(quota, burst, timeMs); + this.tokens = Math.min(burst, this.tokens - value); + } + + private void refill(final double quota, final double burst, final long timeMs) { + this.tokens = Math.min(burst, this.tokens + quota * convert(timeMs - lastUpdateMs, unit)); + this.lastUpdateMs = timeMs; + } + + private double burst(final MetricConfig config) { + return config.samples() * convert(config.timeWindowMs(), unit) * config.quota().bound(); + } + + @Override + public String toString() { + return "TokenBucket(" + + "unit=" + unit + + ", tokens=" + tokens + + ", lastUpdateMs=" + lastUpdateMs + + ')'; + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/metrics/stats/Value.java b/clients/src/main/java/org/apache/kafka/common/metrics/stats/Value.java new file mode 100644 index 0000000..deb81c7 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/metrics/stats/Value.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
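// NOTE (editorial, not part of this patch): a sketch of the credit arithmetic worked through in the
// TokenBucket Javadoc above (Q = 5, S = 100, W = 1s, so burst B = 500). It assumes the fluent
// MetricConfig setters and Quota.upperBound from org.apache.kafka.common.metrics; the timestamps
// are arbitrary.
import java.util.concurrent.TimeUnit;
import org.apache.kafka.common.metrics.MetricConfig;
import org.apache.kafka.common.metrics.Quota;
import org.apache.kafka.common.metrics.stats.TokenBucket;

public class TokenBucketExample {
    public static void main(String[] args) {
        MetricConfig config = new MetricConfig()
                .quota(Quota.upperBound(5))       // refill rate Q = 5 credits per second
                .samples(100)                     // S = 100
                .timeWindow(1, TimeUnit.SECONDS); // W = 1s
        TokenBucket bucket = new TokenBucket();
        long now = 1_000_000L;
        bucket.record(config, 560, now);                           // a burst of 560 drains the bucket to -60
        System.out.println(bucket.measure(config, now));           // -60.0: the quota is exhausted
        System.out.println(bucket.measure(config, now + 12_000L)); // 0.0: refilled after 12s at 5 credits/s
    }
}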
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.metrics.stats; + +import org.apache.kafka.common.metrics.MeasurableStat; +import org.apache.kafka.common.metrics.MetricConfig; + +/** + * An instantaneous value. + */ +public class Value implements MeasurableStat { + private double value = 0; + + @Override + public double measure(MetricConfig config, long now) { + return value; + } + + @Override + public void record(MetricConfig config, double value, long timeMs) { + this.value = value; + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/metrics/stats/WindowedCount.java b/clients/src/main/java/org/apache/kafka/common/metrics/stats/WindowedCount.java new file mode 100644 index 0000000..825f404 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/metrics/stats/WindowedCount.java @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.metrics.stats; + +import org.apache.kafka.common.metrics.MetricConfig; + +/** + * A {@link SampledStat} that maintains a simple count of what it has seen. + * This is a special kind of {@link WindowedSum} that always records a value of {@code 1} instead of the provided value. + * In other words, it counts the number of + * {@link WindowedCount#record(MetricConfig, double, long)} invocations, + * instead of summing the recorded values. + * + * See also {@link CumulativeCount} for a non-sampled version of this metric. + */ +public class WindowedCount extends WindowedSum { + @Override + protected void update(Sample sample, MetricConfig config, double value, long now) { + super.update(sample, config, 1.0, now); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/metrics/stats/WindowedSum.java b/clients/src/main/java/org/apache/kafka/common/metrics/stats/WindowedSum.java new file mode 100644 index 0000000..14aa562 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/metrics/stats/WindowedSum.java @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
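// NOTE (editorial, not part of this patch): a minimal sketch contrasting WindowedCount (above) with
// WindowedSum (below). The former records 1.0 per invocation, the latter sums the recorded values;
// a default MetricConfig is assumed so all recordings fall into the current sample window.
import org.apache.kafka.common.metrics.MetricConfig;
import org.apache.kafka.common.metrics.stats.WindowedCount;
import org.apache.kafka.common.metrics.stats.WindowedSum;

public class WindowedStatsExample {
    public static void main(String[] args) {
        MetricConfig config = new MetricConfig();
        long now = System.currentTimeMillis();
        WindowedSum sum = new WindowedSum();
        WindowedCount count = new WindowedCount();
        for (double value : new double[] {5.0, 5.0, 5.0}) {
            sum.record(config, value, now);
            count.record(config, value, now);
        }
        System.out.println(sum.measure(config, now));   // 15.0: the recorded values are summed
        System.out.println(count.measure(config, now)); // 3.0: each record contributes 1.0
    }
}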
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.metrics.stats; + +import org.apache.kafka.common.metrics.MetricConfig; + +import java.util.List; + +/** + * A {@link SampledStat} that maintains the sum of what it has seen. + * This is a sampled version of {@link CumulativeSum}. + * + * See also {@link WindowedCount} if you want to increment the value by 1 on each recording. + */ +public class WindowedSum extends SampledStat { + + public WindowedSum() { + super(0); + } + + @Override + protected void update(Sample sample, MetricConfig config, double value, long now) { + sample.value += value; + } + + @Override + public double combine(List samples, MetricConfig config, long now) { + double total = 0.0; + for (Sample sample : samples) + total += sample.value; + return total; + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/network/Authenticator.java b/clients/src/main/java/org/apache/kafka/common/network/Authenticator.java new file mode 100644 index 0000000..873c1f9 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/network/Authenticator.java @@ -0,0 +1,166 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.network; + +import org.apache.kafka.common.errors.AuthenticationException; +import org.apache.kafka.common.security.auth.KafkaPrincipal; +import org.apache.kafka.common.security.auth.KafkaPrincipalSerde; + +import java.io.Closeable; +import java.io.IOException; +import java.util.Optional; + +/** + * Authentication for Channel + */ +public interface Authenticator extends Closeable { + /** + * Implements any authentication mechanism. Use transportLayer to read or write tokens. + * For security protocols PLAINTEXT and SSL, this is a no-op since no further authentication + * needs to be done. For SASL_PLAINTEXT and SASL_SSL, this performs the SASL authentication. + * + * @throws AuthenticationException if authentication fails due to invalid credentials or + * other security configuration errors + * @throws IOException if read/write fails due to an I/O error + */ + void authenticate() throws AuthenticationException, IOException; + + /** + * Perform any processing related to authentication failure. This is invoked when the channel is about to be closed + * because of an {@link AuthenticationException} thrown from a prior {@link #authenticate()} call. 
+ * @throws IOException if read/write fails due to an I/O error + */ + default void handleAuthenticationFailure() throws IOException { + } + + /** + * Returns Principal using PrincipalBuilder + */ + KafkaPrincipal principal(); + + /** + * Returns the serializer/deserializer interface for principal + */ + Optional principalSerde(); + + /** + * returns true if authentication is complete otherwise returns false; + */ + boolean complete(); + + /** + * Begins re-authentication. Uses transportLayer to read or write tokens as is + * done for {@link #authenticate()}. For security protocols PLAINTEXT and SSL, + * this is a no-op since re-authentication does not apply/is not supported, + * respectively. For SASL_PLAINTEXT and SASL_SSL, this performs a SASL + * authentication. Any in-flight responses from prior requests can/will be read + * and collected for later processing as required. There must not be partially + * written requests; any request queued for writing (for which zero bytes have + * been written) remains queued until after re-authentication succeeds. + * + * @param reauthenticationContext + * the context in which this re-authentication is occurring. This + * instance is responsible for closing the previous Authenticator + * returned by + * {@link ReauthenticationContext#previousAuthenticator()}. + * @throws AuthenticationException + * if authentication fails due to invalid credentials or other + * security configuration errors + * @throws IOException + * if read/write fails due to an I/O error + */ + default void reauthenticate(ReauthenticationContext reauthenticationContext) throws IOException { + // empty + } + + /** + * Return the session expiration time, if any, otherwise null. The value is in + * nanoseconds as per {@code System.nanoTime()} and is therefore only useful + * when compared to such a value -- it's absolute value is meaningless. This + * value may be non-null only on the server-side. It represents the time after + * which, in the absence of re-authentication, the broker will close the session + * if it receives a request unrelated to authentication. We store nanoseconds + * here to avoid having to invoke the more expensive {@code milliseconds()} call + * on the broker for every request + * + * @return the session expiration time, if any, otherwise null + */ + default Long serverSessionExpirationTimeNanos() { + return null; + } + + /** + * Return the time on or after which a client should re-authenticate this + * session, if any, otherwise null. The value is in nanoseconds as per + * {@code System.nanoTime()} and is therefore only useful when compared to such + * a value -- it's absolute value is meaningless. This value may be non-null + * only on the client-side. It will be a random time between 85% and 95% of the + * full session lifetime to account for latency between client and server and to + * avoid re-authentication storms that could be caused by many sessions + * re-authenticating simultaneously. + * + * @return the time on or after which a client should re-authenticate this + * session, if any, otherwise null + */ + default Long clientSessionReauthenticationTimeNanos() { + return null; + } + + /** + * Return the number of milliseconds that elapsed while re-authenticating this + * session from the perspective of this instance, if applicable, otherwise null. + * The server-side perspective will yield a lower value than the client-side + * perspective of the same re-authentication because the client-side observes an + * additional network round-trip. 
+ * + * @return the number of milliseconds that elapsed while re-authenticating this + * session from the perspective of this instance, if applicable, + * otherwise null + */ + default Long reauthenticationLatencyMs() { + return null; + } + + /** + * Return the next (always non-null but possibly empty) client-side + * {@link NetworkReceive} response that arrived during re-authentication that + * is unrelated to re-authentication, if any. These correspond to requests sent + * prior to the beginning of re-authentication; the requests were made when the + * channel was successfully authenticated, and the responses arrived during the + * re-authentication process. The response returned is removed from the authenticator's + * queue. Responses of requests sent after completion of re-authentication are + * processed only when the authenticator response queue is empty. + * + * @return the (always non-null but possibly empty) client-side + * {@link NetworkReceive} response that arrived during + * re-authentication that is unrelated to re-authentication, if any + */ + default Optional pollResponseReceivedDuringReauthentication() { + return Optional.empty(); + } + + /** + * Return true if this is a server-side authenticator and the connected client + * has indicated that it supports re-authentication, otherwise false + * + * @return true if this is a server-side authenticator and the connected client + * has indicated that it supports re-authentication, otherwise false + */ + default boolean connectedClientSupportsReauthentication() { + return false; + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/network/ByteBufferSend.java b/clients/src/main/java/org/apache/kafka/common/network/ByteBufferSend.java new file mode 100644 index 0000000..c6ffcde --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/network/ByteBufferSend.java @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.network; + +import java.io.EOFException; +import java.io.IOException; +import java.nio.ByteBuffer; + +/** + * A send backed by an array of byte buffers + */ +public class ByteBufferSend implements Send { + + private final long size; + protected final ByteBuffer[] buffers; + private long remaining; + private boolean pending = false; + + public ByteBufferSend(ByteBuffer... 
buffers) { + this.buffers = buffers; + for (ByteBuffer buffer : buffers) + remaining += buffer.remaining(); + this.size = remaining; + } + + public ByteBufferSend(ByteBuffer[] buffers, long size) { + this.buffers = buffers; + this.size = size; + this.remaining = size; + } + + @Override + public boolean completed() { + return remaining <= 0 && !pending; + } + + @Override + public long size() { + return this.size; + } + + @Override + public long writeTo(TransferableChannel channel) throws IOException { + long written = channel.write(buffers); + if (written < 0) + throw new EOFException("Wrote negative bytes to channel. This shouldn't happen."); + remaining -= written; + pending = channel.hasPendingWrites(); + return written; + } + + public long remaining() { + return remaining; + } + + @Override + public String toString() { + return "ByteBufferSend(" + + ", size=" + size + + ", remaining=" + remaining + + ", pending=" + pending + + ')'; + } + + public static ByteBufferSend sizePrefixed(ByteBuffer buffer) { + ByteBuffer sizeBuffer = ByteBuffer.allocate(4); + sizeBuffer.putInt(0, buffer.remaining()); + return new ByteBufferSend(sizeBuffer, buffer); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/network/ChannelBuilder.java b/clients/src/main/java/org/apache/kafka/common/network/ChannelBuilder.java new file mode 100644 index 0000000..0cf1d74 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/network/ChannelBuilder.java @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.network; + +import java.nio.channels.SelectionKey; + +import org.apache.kafka.common.KafkaException; +import org.apache.kafka.common.Configurable; +import org.apache.kafka.common.memory.MemoryPool; + + +/** + * A ChannelBuilder interface to build Channel based on configs + */ +public interface ChannelBuilder extends AutoCloseable, Configurable { + + /** + * returns a Channel with TransportLayer and Authenticator configured. 
+ * @param id channel id + * @param key SelectionKey + * @param maxReceiveSize max size of a single receive buffer to allocate + * @param memoryPool memory pool from which to allocate buffers, or null for none + * @return KafkaChannel + */ + KafkaChannel buildChannel(String id, SelectionKey key, int maxReceiveSize, + MemoryPool memoryPool, ChannelMetadataRegistry metadataRegistry) throws KafkaException; + + /** + * Closes ChannelBuilder + */ + @Override + void close(); + +} diff --git a/clients/src/main/java/org/apache/kafka/common/network/ChannelBuilders.java b/clients/src/main/java/org/apache/kafka/common/network/ChannelBuilders.java new file mode 100644 index 0000000..b2760ae --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/network/ChannelBuilders.java @@ -0,0 +1,244 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.network; + +import org.apache.kafka.common.Configurable; +import org.apache.kafka.common.config.AbstractConfig; +import org.apache.kafka.common.config.SslClientAuth; +import org.apache.kafka.common.config.internals.BrokerSecurityConfigs; +import org.apache.kafka.common.errors.InvalidConfigurationException; +import org.apache.kafka.common.requests.ApiVersionsResponse; +import org.apache.kafka.common.security.auth.SecurityProtocol; +import org.apache.kafka.common.security.JaasContext; +import org.apache.kafka.common.security.authenticator.DefaultKafkaPrincipalBuilder; +import org.apache.kafka.common.security.auth.KafkaPrincipalBuilder; +import org.apache.kafka.common.security.authenticator.CredentialCache; +import org.apache.kafka.common.security.kerberos.KerberosShortNamer; +import org.apache.kafka.common.security.ssl.SslPrincipalMapper; +import org.apache.kafka.common.security.token.delegation.internals.DelegationTokenCache; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.common.utils.Utils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.function.Supplier; + +public class ChannelBuilders { + private static final Logger log = LoggerFactory.getLogger(ChannelBuilders.class); + + private ChannelBuilders() { } + + /** + * @param securityProtocol the securityProtocol + * @param contextType the contextType, it must be non-null if `securityProtocol` is SASL_*; it is ignored otherwise + * @param config client config + * @param listenerName the listenerName if contextType is SERVER or null otherwise + * @param clientSaslMechanism SASL mechanism if mode is CLIENT, ignored otherwise + * @param time the time instance + * @param saslHandshakeRequestEnable flag 
to enable Sasl handshake requests; disabled only for SASL + * inter-broker connections with inter-broker protocol version < 0.10 + * @param logContext the log context instance + * + * @return the configured `ChannelBuilder` + * @throws IllegalArgumentException if `mode` invariants described above is not maintained + */ + public static ChannelBuilder clientChannelBuilder( + SecurityProtocol securityProtocol, + JaasContext.Type contextType, + AbstractConfig config, + ListenerName listenerName, + String clientSaslMechanism, + Time time, + boolean saslHandshakeRequestEnable, + LogContext logContext) { + + if (securityProtocol == SecurityProtocol.SASL_PLAINTEXT || securityProtocol == SecurityProtocol.SASL_SSL) { + if (contextType == null) + throw new IllegalArgumentException("`contextType` must be non-null if `securityProtocol` is `" + securityProtocol + "`"); + if (clientSaslMechanism == null) + throw new IllegalArgumentException("`clientSaslMechanism` must be non-null in client mode if `securityProtocol` is `" + securityProtocol + "`"); + } + return create(securityProtocol, Mode.CLIENT, contextType, config, listenerName, false, clientSaslMechanism, + saslHandshakeRequestEnable, null, null, time, logContext, null); + } + + /** + * @param listenerName the listenerName + * @param isInterBrokerListener whether or not this listener is used for inter-broker requests + * @param securityProtocol the securityProtocol + * @param config server config + * @param credentialCache Credential cache for SASL/SCRAM if SCRAM is enabled + * @param tokenCache Delegation token cache + * @param time the time instance + * @param logContext the log context instance + * @param apiVersionSupplier supplier for ApiVersions responses sent prior to authentication + * + * @return the configured `ChannelBuilder` + */ + public static ChannelBuilder serverChannelBuilder(ListenerName listenerName, + boolean isInterBrokerListener, + SecurityProtocol securityProtocol, + AbstractConfig config, + CredentialCache credentialCache, + DelegationTokenCache tokenCache, + Time time, + LogContext logContext, + Supplier apiVersionSupplier) { + return create(securityProtocol, Mode.SERVER, JaasContext.Type.SERVER, config, listenerName, + isInterBrokerListener, null, true, credentialCache, + tokenCache, time, logContext, apiVersionSupplier); + } + + private static ChannelBuilder create(SecurityProtocol securityProtocol, + Mode mode, + JaasContext.Type contextType, + AbstractConfig config, + ListenerName listenerName, + boolean isInterBrokerListener, + String clientSaslMechanism, + boolean saslHandshakeRequestEnable, + CredentialCache credentialCache, + DelegationTokenCache tokenCache, + Time time, + LogContext logContext, + Supplier apiVersionSupplier) { + Map configs = channelBuilderConfigs(config, listenerName); + + ChannelBuilder channelBuilder; + switch (securityProtocol) { + case SSL: + requireNonNullMode(mode, securityProtocol); + channelBuilder = new SslChannelBuilder(mode, listenerName, isInterBrokerListener, logContext); + break; + case SASL_SSL: + case SASL_PLAINTEXT: + requireNonNullMode(mode, securityProtocol); + Map jaasContexts; + String sslClientAuthOverride = null; + if (mode == Mode.SERVER) { + @SuppressWarnings("unchecked") + List enabledMechanisms = (List) configs.get(BrokerSecurityConfigs.SASL_ENABLED_MECHANISMS_CONFIG); + jaasContexts = new HashMap<>(enabledMechanisms.size()); + for (String mechanism : enabledMechanisms) + jaasContexts.put(mechanism, JaasContext.loadServerContext(listenerName, mechanism, configs)); + + // 
SSL client authentication is enabled in brokers for SASL_SSL only if listener-prefixed config is specified. + if (listenerName != null && securityProtocol == SecurityProtocol.SASL_SSL) { + String configuredClientAuth = (String) configs.get(BrokerSecurityConfigs.SSL_CLIENT_AUTH_CONFIG); + String listenerClientAuth = (String) config.originalsWithPrefix(listenerName.configPrefix(), true) + .get(BrokerSecurityConfigs.SSL_CLIENT_AUTH_CONFIG); + + // If `ssl.client.auth` is configured at the listener-level, we don't set an override and SslFactory + // uses the value from `configs`. If not, we propagate `sslClientAuthOverride=NONE` to SslFactory and + // it applies the override to the latest configs when it is configured or reconfigured. `Note that + // ssl.client.auth` cannot be dynamically altered. + if (listenerClientAuth == null) { + sslClientAuthOverride = SslClientAuth.NONE.name().toLowerCase(Locale.ROOT); + if (configuredClientAuth != null && !configuredClientAuth.equalsIgnoreCase(SslClientAuth.NONE.name())) { + log.warn("Broker configuration '{}' is applied only to SSL listeners. Listener-prefixed configuration can be used" + + " to enable SSL client authentication for SASL_SSL listeners. In future releases, broker-wide option without" + + " listener prefix may be applied to SASL_SSL listeners as well. All configuration options intended for specific" + + " listeners should be listener-prefixed.", BrokerSecurityConfigs.SSL_CLIENT_AUTH_CONFIG); + } + } + } + } else { + // Use server context for inter-broker client connections and client context for other clients + JaasContext jaasContext = contextType == JaasContext.Type.CLIENT ? JaasContext.loadClientContext(configs) : + JaasContext.loadServerContext(listenerName, clientSaslMechanism, configs); + jaasContexts = Collections.singletonMap(clientSaslMechanism, jaasContext); + } + channelBuilder = new SaslChannelBuilder(mode, + jaasContexts, + securityProtocol, + listenerName, + isInterBrokerListener, + clientSaslMechanism, + saslHandshakeRequestEnable, + credentialCache, + tokenCache, + sslClientAuthOverride, + time, + logContext, + apiVersionSupplier); + break; + case PLAINTEXT: + channelBuilder = new PlaintextChannelBuilder(listenerName); + break; + default: + throw new IllegalArgumentException("Unexpected securityProtocol " + securityProtocol); + } + + channelBuilder.configure(configs); + return channelBuilder; + } + + /** + * @return a mutable RecordingMap. The elements got from RecordingMap are marked as "used". + */ + @SuppressWarnings("unchecked") + static Map channelBuilderConfigs(final AbstractConfig config, final ListenerName listenerName) { + Map parsedConfigs; + if (listenerName == null) + parsedConfigs = (Map) config.values(); + else + parsedConfigs = config.valuesWithPrefixOverride(listenerName.configPrefix()); + + config.originals().entrySet().stream() + .filter(e -> !parsedConfigs.containsKey(e.getKey())) // exclude already parsed configs + // exclude already parsed listener prefix configs + .filter(e -> !(listenerName != null && e.getKey().startsWith(listenerName.configPrefix()) && + parsedConfigs.containsKey(e.getKey().substring(listenerName.configPrefix().length())))) + // exclude keys like `{mechanism}.some.prop` if "listener.name." prefix is present and key `some.prop` exists in parsed configs. 
+ .filter(e -> !(listenerName != null && parsedConfigs.containsKey(e.getKey().substring(e.getKey().indexOf('.') + 1)))) + .forEach(e -> parsedConfigs.put(e.getKey(), e.getValue())); + return parsedConfigs; + } + + private static void requireNonNullMode(Mode mode, SecurityProtocol securityProtocol) { + if (mode == null) + throw new IllegalArgumentException("`mode` must be non-null if `securityProtocol` is `" + securityProtocol + "`"); + } + + public static KafkaPrincipalBuilder createPrincipalBuilder(Map configs, + KerberosShortNamer kerberosShortNamer, + SslPrincipalMapper sslPrincipalMapper) { + Class principalBuilderClass = (Class) configs.get(BrokerSecurityConfigs.PRINCIPAL_BUILDER_CLASS_CONFIG); + final KafkaPrincipalBuilder builder; + + if (principalBuilderClass == null || principalBuilderClass == DefaultKafkaPrincipalBuilder.class) { + builder = new DefaultKafkaPrincipalBuilder(kerberosShortNamer, sslPrincipalMapper); + } else if (KafkaPrincipalBuilder.class.isAssignableFrom(principalBuilderClass)) { + builder = (KafkaPrincipalBuilder) Utils.newInstance(principalBuilderClass); + } else { + throw new InvalidConfigurationException("Type " + principalBuilderClass.getName() + " is not " + + "an instance of " + KafkaPrincipalBuilder.class.getName()); + } + + if (builder instanceof Configurable) + ((Configurable) builder).configure(configs); + + return builder; + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/network/ChannelMetadataRegistry.java b/clients/src/main/java/org/apache/kafka/common/network/ChannelMetadataRegistry.java new file mode 100644 index 0000000..a3453d8 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/network/ChannelMetadataRegistry.java @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.network; + +import java.io.Closeable; + +/** + * Metadata about a channel is provided in various places in the network stack. This + * registry is used as a common place to collect them. + */ +public interface ChannelMetadataRegistry extends Closeable { + + /** + * Register information about the SSL cipher we are using. + * Re-registering the information will overwrite the previous one. + */ + void registerCipherInformation(CipherInformation cipherInformation); + + /** + * Get the currently registered cipher information. + */ + CipherInformation cipherInformation(); + + /** + * Register information about the client client we are using. + * Depending on the clients, the ApiVersionsRequest could be received + * multiple times or not at all. Re-registering the information will + * overwrite the previous one. + */ + void registerClientInformation(ClientInformation clientInformation); + + /** + * Get the currently registered client information. 
+ */ + ClientInformation clientInformation(); + + /** + * Unregister everything that has been registered and close the registry. + */ + void close(); +} diff --git a/clients/src/main/java/org/apache/kafka/common/network/ChannelState.java b/clients/src/main/java/org/apache/kafka/common/network/ChannelState.java new file mode 100644 index 0000000..5f6dfb9 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/network/ChannelState.java @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.network; + +import org.apache.kafka.common.errors.AuthenticationException; + +/** + * States for KafkaChannel: + *
+ * <ul>
+ *   <li>NOT_CONNECTED: Connections are created in NOT_CONNECTED state. State is updated
+ *       on {@link TransportLayer#finishConnect()} when socket connection is established.
+ *       PLAINTEXT channels transition from NOT_CONNECTED to READY, others transition
+ *       to AUTHENTICATE. Failures in NOT_CONNECTED state typically indicate that the
+ *       remote endpoint is unavailable, which may be due to misconfigured endpoints.</li>
+ *   <li>AUTHENTICATE: SSL, SASL_SSL and SASL_PLAINTEXT channels are in AUTHENTICATE state during SSL and
+ *       SASL handshake. Disconnections in AUTHENTICATE state may indicate that authentication failed with
+ *       SSL or SASL (broker version < 1.0.0). Channels transition to READY state when authentication completes
+ *       successfully.</li>
+ *   <li>READY: Connected, authenticated channels are in READY state. Channels may transition from
+ *       READY to EXPIRED, FAILED_SEND or LOCAL_CLOSE.</li>
+ *   <li>EXPIRED: Idle connections are moved to EXPIRED state on idle timeout and the channel is closed.</li>
+ *   <li>FAILED_SEND: Channels transition from READY to FAILED_SEND state if the channel is closed due
+ *       to a send failure.</li>
+ *   <li>AUTHENTICATION_FAILED: Channels are moved to this state if the requested SASL mechanism is not
+ *       enabled in the broker or when brokers with versions 1.0.0 and above provide an error response
+ *       during SASL authentication. {@link #exception()} gives the reason provided by the broker for
+ *       authentication failure.</li>
+ *   <li>LOCAL_CLOSE: Channels are moved to LOCAL_CLOSE state if close() is initiated locally.</li>
+ * </ul>
+ * If the remote endpoint closes a channel, the state of the channel reflects the state the channel
+ * was in at the time of disconnection. This state may be useful to identify the reason for disconnection.
+ * <p>
+ * Typical transitions:
+ * <ul>
+ *   <li>PLAINTEXT Good path: NOT_CONNECTED => READY => LOCAL_CLOSE</li>
+ *   <li>SASL/SSL Good path: NOT_CONNECTED => AUTHENTICATE => READY => LOCAL_CLOSE</li>
+ *   <li>Bootstrap server misconfiguration: NOT_CONNECTED, disconnected in NOT_CONNECTED state</li>
+ *   <li>Security misconfiguration: NOT_CONNECTED => AUTHENTICATE => AUTHENTICATION_FAILED, disconnected in AUTHENTICATION_FAILED state</li>
+ *   <li>Security misconfiguration with older broker: NOT_CONNECTED => AUTHENTICATE, disconnected in AUTHENTICATE state</li>
+ * </ul>
        + */ +public class ChannelState { + public enum State { + NOT_CONNECTED, + AUTHENTICATE, + READY, + EXPIRED, + FAILED_SEND, + AUTHENTICATION_FAILED, + LOCAL_CLOSE + } + + // AUTHENTICATION_FAILED has a custom exception. For other states, + // create a reusable `ChannelState` instance per-state. + public static final ChannelState NOT_CONNECTED = new ChannelState(State.NOT_CONNECTED); + public static final ChannelState AUTHENTICATE = new ChannelState(State.AUTHENTICATE); + public static final ChannelState READY = new ChannelState(State.READY); + public static final ChannelState EXPIRED = new ChannelState(State.EXPIRED); + public static final ChannelState FAILED_SEND = new ChannelState(State.FAILED_SEND); + public static final ChannelState LOCAL_CLOSE = new ChannelState(State.LOCAL_CLOSE); + + private final State state; + private final AuthenticationException exception; + private final String remoteAddress; + + public ChannelState(State state) { + this(state, null, null); + } + + public ChannelState(State state, String remoteAddress) { + this(state, null, remoteAddress); + } + + public ChannelState(State state, AuthenticationException exception, String remoteAddress) { + this.state = state; + this.exception = exception; + this.remoteAddress = remoteAddress; + } + + public State state() { + return state; + } + + public AuthenticationException exception() { + return exception; + } + + public String remoteAddress() { + return remoteAddress; + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/network/CipherInformation.java b/clients/src/main/java/org/apache/kafka/common/network/CipherInformation.java new file mode 100644 index 0000000..d65aeb9 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/network/CipherInformation.java @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.network; + +import java.util.Objects; + +public class CipherInformation { + private final String cipher; + private final String protocol; + + public CipherInformation(String cipher, String protocol) { + this.cipher = cipher == null || cipher.isEmpty() ? "unknown" : cipher; + this.protocol = protocol == null || protocol.isEmpty() ? 
"unknown" : protocol; + } + + public String cipher() { + return cipher; + } + + public String protocol() { + return protocol; + } + + @Override + public String toString() { + return "CipherInformation(cipher=" + cipher + + ", protocol=" + protocol + ")"; + } + + @Override + public int hashCode() { + return Objects.hash(cipher, protocol); + } + + @Override + public boolean equals(Object o) { + if (o == null) { + return false; + } + if (!(o instanceof CipherInformation)) { + return false; + } + CipherInformation other = (CipherInformation) o; + return other.cipher.equals(cipher) && + other.protocol.equals(protocol); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/network/ClientInformation.java b/clients/src/main/java/org/apache/kafka/common/network/ClientInformation.java new file mode 100644 index 0000000..cb99a86 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/network/ClientInformation.java @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.common.network; + +import java.util.Objects; + +public class ClientInformation { + public static final String UNKNOWN_NAME_OR_VERSION = "unknown"; + public static final ClientInformation EMPTY = new ClientInformation(UNKNOWN_NAME_OR_VERSION, UNKNOWN_NAME_OR_VERSION); + + private final String softwareName; + private final String softwareVersion; + + public ClientInformation(String softwareName, String softwareVersion) { + this.softwareName = softwareName.isEmpty() ? UNKNOWN_NAME_OR_VERSION : softwareName; + this.softwareVersion = softwareVersion.isEmpty() ? 
UNKNOWN_NAME_OR_VERSION : softwareVersion; + } + + public String softwareName() { + return this.softwareName; + } + + public String softwareVersion() { + return this.softwareVersion; + } + + @Override + public String toString() { + return "ClientInformation(softwareName=" + softwareName + + ", softwareVersion=" + softwareVersion + ")"; + } + + @Override + public int hashCode() { + return Objects.hash(softwareName, softwareVersion); + } + + @Override + public boolean equals(Object o) { + if (o == null) { + return false; + } + if (!(o instanceof ClientInformation)) { + return false; + } + ClientInformation other = (ClientInformation) o; + return other.softwareName.equals(softwareName) && + other.softwareVersion.equals(softwareVersion); + } +}
diff --git a/clients/src/main/java/org/apache/kafka/common/network/DefaultChannelMetadataRegistry.java b/clients/src/main/java/org/apache/kafka/common/network/DefaultChannelMetadataRegistry.java new file mode 100644 index 0000000..ae9e9a8 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/network/DefaultChannelMetadataRegistry.java @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.network; + +public class DefaultChannelMetadataRegistry implements ChannelMetadataRegistry { + private CipherInformation cipherInformation; + private ClientInformation clientInformation; + + @Override + public void registerCipherInformation(final CipherInformation cipherInformation) { + this.cipherInformation = cipherInformation; + } + + @Override + public CipherInformation cipherInformation() { + return this.cipherInformation; + } + + @Override + public void registerClientInformation(final ClientInformation clientInformation) { + this.clientInformation = clientInformation; + } + + @Override + public ClientInformation clientInformation() { + return this.clientInformation; + } + + @Override + public void close() { + this.cipherInformation = null; + this.clientInformation = null; + } +}
diff --git a/clients/src/main/java/org/apache/kafka/common/network/DelayedResponseAuthenticationException.java b/clients/src/main/java/org/apache/kafka/common/network/DelayedResponseAuthenticationException.java new file mode 100644 index 0000000..8474426 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/network/DelayedResponseAuthenticationException.java @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.network; + +import org.apache.kafka.common.errors.AuthenticationException; + +public class DelayedResponseAuthenticationException extends AuthenticationException { + private static final long serialVersionUID = 1L; + + public DelayedResponseAuthenticationException(Throwable cause) { + super(cause); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/network/InvalidReceiveException.java b/clients/src/main/java/org/apache/kafka/common/network/InvalidReceiveException.java new file mode 100644 index 0000000..a56353a --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/network/InvalidReceiveException.java @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.network; + +import org.apache.kafka.common.KafkaException; + +public class InvalidReceiveException extends KafkaException { + + public InvalidReceiveException(String message) { + super(message); + } + + public InvalidReceiveException(String message, Throwable cause) { + super(message, cause); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/network/KafkaChannel.java b/clients/src/main/java/org/apache/kafka/common/network/KafkaChannel.java new file mode 100644 index 0000000..bc82280 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/network/KafkaChannel.java @@ -0,0 +1,674 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.network; + +import org.apache.kafka.common.errors.AuthenticationException; +import org.apache.kafka.common.errors.SslAuthenticationException; +import org.apache.kafka.common.memory.MemoryPool; +import org.apache.kafka.common.security.auth.KafkaPrincipal; +import org.apache.kafka.common.security.auth.KafkaPrincipalSerde; +import org.apache.kafka.common.utils.Utils; + +import java.io.IOException; +import java.net.InetAddress; +import java.net.Socket; +import java.net.SocketAddress; +import java.nio.channels.SelectionKey; +import java.nio.channels.SocketChannel; +import java.util.Optional; +import java.util.function.Supplier; + +/** + * A Kafka connection either existing on a client (which could be a broker in an + * inter-broker scenario) and representing the channel to a remote broker or the + * reverse (existing on a broker and representing the channel to a remote + * client, which could be a broker in an inter-broker scenario). + *
+ * Each instance has the following:
+ * <ul>
+ *   <li>a unique ID identifying it in the {@code KafkaClient} instance via which
+ *       the connection was made on the client-side or in the instance where it was
+ *       accepted on the server-side</li>
+ *   <li>a reference to the underlying {@link TransportLayer} to allow reading and
+ *       writing</li>
+ *   <li>an {@link Authenticator} that performs the authentication (or
+ *       re-authentication, if that feature is enabled and it applies to this
+ *       connection) by reading and writing directly from/to the same
+ *       {@link TransportLayer}.</li>
+ *   <li>a {@link MemoryPool} into which responses are read (typically the JVM
+ *       heap for clients, though smaller pools can be used for brokers and for
+ *       testing out-of-memory scenarios)</li>
+ *   <li>a {@link NetworkReceive} representing the current incomplete/in-progress
+ *       request (from the server-side perspective) or response (from the client-side
+ *       perspective) being read, if applicable; or a non-null value that has had no
+ *       data read into it yet or a null value if there is no in-progress
+ *       request/response (either could be the case)</li>
+ *   <li>a {@link Send} representing the current request (from the client-side
+ *       perspective) or response (from the server-side perspective) that is either
+ *       waiting to be sent or partially sent, if applicable, or null</li>
+ *   <li>a {@link ChannelMuteState} to document if the channel has been muted due
+ *       to memory pressure or other reasons</li>
+ * </ul>
        + */ +public class KafkaChannel implements AutoCloseable { + private static final long MIN_REAUTH_INTERVAL_ONE_SECOND_NANOS = 1000 * 1000 * 1000; + + /** + * Mute States for KafkaChannel: + *
+ * <ul>
+ *   <li>NOT_MUTED: Channel is not muted. This is the default state.</li>
+ *   <li>MUTED: Channel is muted. Channel must be in this state to be unmuted.</li>
+ *   <li>MUTED_AND_RESPONSE_PENDING: (SocketServer only) Channel is muted and SocketServer has not sent a response
+ *       back to the client yet (acks != 0) or is currently waiting to receive a
+ *       response from the API layer (acks == 0).</li>
+ *   <li>MUTED_AND_THROTTLED: (SocketServer only) Channel is muted and throttling is in progress due to quota
+ *       violation.</li>
+ *   <li>MUTED_AND_THROTTLED_AND_RESPONSE_PENDING: (SocketServer only) Channel is muted, throttling is in progress,
+ *       and a response is currently pending.</li>
+ * </ul>
        + */ + public enum ChannelMuteState { + NOT_MUTED, + MUTED, + MUTED_AND_RESPONSE_PENDING, + MUTED_AND_THROTTLED, + MUTED_AND_THROTTLED_AND_RESPONSE_PENDING + } + + /** Socket server events that will change the mute state: + *
+ * <ul>
+ *   <li>REQUEST_RECEIVED: A request has been received from the client.</li>
+ *   <li>RESPONSE_SENT: A response has been sent out to the client (ack != 0) or SocketServer has heard back from
+ *       the API layer (acks = 0)</li>
+ *   <li>THROTTLE_STARTED: Throttling started due to quota violation.</li>
+ *   <li>THROTTLE_ENDED: Throttling ended.</li>
+ * </ul>
+ *
+ * Valid transitions on each event are:
+ * <ul>
+ *   <li>REQUEST_RECEIVED: MUTED => MUTED_AND_RESPONSE_PENDING</li>
+ *   <li>RESPONSE_SENT: MUTED_AND_RESPONSE_PENDING => MUTED, MUTED_AND_THROTTLED_AND_RESPONSE_PENDING => MUTED_AND_THROTTLED</li>
+ *   <li>THROTTLE_STARTED: MUTED_AND_RESPONSE_PENDING => MUTED_AND_THROTTLED_AND_RESPONSE_PENDING</li>
+ *   <li>THROTTLE_ENDED: MUTED_AND_THROTTLED => MUTED, MUTED_AND_THROTTLED_AND_RESPONSE_PENDING => MUTED_AND_RESPONSE_PENDING</li>
+ * </ul>
        + */ + public enum ChannelMuteEvent { + REQUEST_RECEIVED, + RESPONSE_SENT, + THROTTLE_STARTED, + THROTTLE_ENDED + } + + private final String id; + private final TransportLayer transportLayer; + private final Supplier authenticatorCreator; + private Authenticator authenticator; + // Tracks accumulated network thread time. This is updated on the network thread. + // The values are read and reset after each response is sent. + private long networkThreadTimeNanos; + private final int maxReceiveSize; + private final MemoryPool memoryPool; + private final ChannelMetadataRegistry metadataRegistry; + private NetworkReceive receive; + private NetworkSend send; + // Track connection and mute state of channels to enable outstanding requests on channels to be + // processed after the channel is disconnected. + private boolean disconnected; + private ChannelMuteState muteState; + private ChannelState state; + private SocketAddress remoteAddress; + private int successfulAuthentications; + private boolean midWrite; + private long lastReauthenticationStartNanos; + + public KafkaChannel(String id, TransportLayer transportLayer, Supplier authenticatorCreator, + int maxReceiveSize, MemoryPool memoryPool, ChannelMetadataRegistry metadataRegistry) { + this.id = id; + this.transportLayer = transportLayer; + this.authenticatorCreator = authenticatorCreator; + this.authenticator = authenticatorCreator.get(); + this.networkThreadTimeNanos = 0L; + this.maxReceiveSize = maxReceiveSize; + this.memoryPool = memoryPool; + this.metadataRegistry = metadataRegistry; + this.disconnected = false; + this.muteState = ChannelMuteState.NOT_MUTED; + this.state = ChannelState.NOT_CONNECTED; + } + + public void close() throws IOException { + this.disconnected = true; + Utils.closeAll(transportLayer, authenticator, receive, metadataRegistry); + } + + /** + * Returns the principal returned by `authenticator.principal()`. + */ + public KafkaPrincipal principal() { + return authenticator.principal(); + } + + public Optional principalSerde() { + return authenticator.principalSerde(); + } + + /** + * Does handshake of transportLayer and authentication using configured authenticator. + * For SSL with client authentication enabled, {@link TransportLayer#handshake()} performs + * authentication. For SASL, authentication is performed by {@link Authenticator#authenticate()}. + */ + public void prepare() throws AuthenticationException, IOException { + boolean authenticating = false; + try { + if (!transportLayer.ready()) + transportLayer.handshake(); + if (transportLayer.ready() && !authenticator.complete()) { + authenticating = true; + authenticator.authenticate(); + } + } catch (AuthenticationException e) { + // Clients are notified of authentication exceptions to enable operations to be terminated + // without retries. Other errors are handled as network exceptions in Selector. + String remoteDesc = remoteAddress != null ? 
remoteAddress.toString() : null; + state = new ChannelState(ChannelState.State.AUTHENTICATION_FAILED, e, remoteDesc); + if (authenticating) { + delayCloseOnAuthenticationFailure(); + throw new DelayedResponseAuthenticationException(e); + } + throw e; + } + if (ready()) { + ++successfulAuthentications; + state = ChannelState.READY; + } + } + + public void disconnect() { + disconnected = true; + if (state == ChannelState.NOT_CONNECTED && remoteAddress != null) { + //if we captured the remote address we can provide more information + state = new ChannelState(ChannelState.State.NOT_CONNECTED, remoteAddress.toString()); + } + transportLayer.disconnect(); + } + + public void state(ChannelState state) { + this.state = state; + } + + public ChannelState state() { + return this.state; + } + + public boolean finishConnect() throws IOException { + //we need to grab remoteAddr before finishConnect() is called otherwise + //it becomes inaccessible if the connection was refused. + SocketChannel socketChannel = transportLayer.socketChannel(); + if (socketChannel != null) { + remoteAddress = socketChannel.getRemoteAddress(); + } + boolean connected = transportLayer.finishConnect(); + if (connected) { + if (ready()) { + state = ChannelState.READY; + } else if (remoteAddress != null) { + state = new ChannelState(ChannelState.State.AUTHENTICATE, remoteAddress.toString()); + } else { + state = ChannelState.AUTHENTICATE; + } + } + return connected; + } + + public boolean isConnected() { + return transportLayer.isConnected(); + } + + public String id() { + return id; + } + + public SelectionKey selectionKey() { + return transportLayer.selectionKey(); + } + + /** + * externally muting a channel should be done via selector to ensure proper state handling + */ + void mute() { + if (muteState == ChannelMuteState.NOT_MUTED) { + if (!disconnected) transportLayer.removeInterestOps(SelectionKey.OP_READ); + muteState = ChannelMuteState.MUTED; + } + } + + /** + * Unmute the channel. The channel can be unmuted only if it is in the MUTED state. For other muted states + * (MUTED_AND_*), this is a no-op. + * + * @return Whether or not the channel is in the NOT_MUTED state after the call + */ + boolean maybeUnmute() { + if (muteState == ChannelMuteState.MUTED) { + if (!disconnected) transportLayer.addInterestOps(SelectionKey.OP_READ); + muteState = ChannelMuteState.NOT_MUTED; + } + return muteState == ChannelMuteState.NOT_MUTED; + } + + // Handle the specified channel mute-related event and transition the mute state according to the state machine. 
+ public void handleChannelMuteEvent(ChannelMuteEvent event) { + boolean stateChanged = false; + switch (event) { + case REQUEST_RECEIVED: + if (muteState == ChannelMuteState.MUTED) { + muteState = ChannelMuteState.MUTED_AND_RESPONSE_PENDING; + stateChanged = true; + } + break; + case RESPONSE_SENT: + if (muteState == ChannelMuteState.MUTED_AND_RESPONSE_PENDING) { + muteState = ChannelMuteState.MUTED; + stateChanged = true; + } + if (muteState == ChannelMuteState.MUTED_AND_THROTTLED_AND_RESPONSE_PENDING) { + muteState = ChannelMuteState.MUTED_AND_THROTTLED; + stateChanged = true; + } + break; + case THROTTLE_STARTED: + if (muteState == ChannelMuteState.MUTED_AND_RESPONSE_PENDING) { + muteState = ChannelMuteState.MUTED_AND_THROTTLED_AND_RESPONSE_PENDING; + stateChanged = true; + } + break; + case THROTTLE_ENDED: + if (muteState == ChannelMuteState.MUTED_AND_THROTTLED) { + muteState = ChannelMuteState.MUTED; + stateChanged = true; + } + if (muteState == ChannelMuteState.MUTED_AND_THROTTLED_AND_RESPONSE_PENDING) { + muteState = ChannelMuteState.MUTED_AND_RESPONSE_PENDING; + stateChanged = true; + } + } + if (!stateChanged) { + throw new IllegalStateException("Cannot transition from " + muteState.name() + " for " + event.name()); + } + } + + public ChannelMuteState muteState() { + return muteState; + } + + /** + * Delay channel close on authentication failure. This will remove all read/write operations from the channel until + * {@link #completeCloseOnAuthenticationFailure()} is called to finish up the channel close. + */ + private void delayCloseOnAuthenticationFailure() { + transportLayer.removeInterestOps(SelectionKey.OP_WRITE); + } + + /** + * Finish up any processing on {@link #prepare()} failure. + * @throws IOException + */ + void completeCloseOnAuthenticationFailure() throws IOException { + transportLayer.addInterestOps(SelectionKey.OP_WRITE); + // Invoke the underlying handler to finish up any processing on authentication failure + authenticator.handleAuthenticationFailure(); + } + + /** + * Returns true if this channel has been explicitly muted using {@link KafkaChannel#mute()} + */ + public boolean isMuted() { + return muteState != ChannelMuteState.NOT_MUTED; + } + + public boolean isInMutableState() { + //some requests do not require memory, so if we do not know what the current (or future) request is + //(receive == null) we dont mute. we also dont mute if whatever memory required has already been + //successfully allocated (if none is required for the currently-being-read request + //receive.memoryAllocated() is expected to return true) + if (receive == null || receive.memoryAllocated()) + return false; + //also cannot mute if underlying transport is not in the ready state + return transportLayer.ready(); + } + + public boolean ready() { + return transportLayer.ready() && authenticator.complete(); + } + + public boolean hasSend() { + return send != null; + } + + /** + * Returns the address to which this channel's socket is connected or `null` if the socket has never been connected. + * + * If the socket was connected prior to being closed, then this method will continue to return the + * connected address after the socket is closed. 
+ */ + public InetAddress socketAddress() { + return transportLayer.socketChannel().socket().getInetAddress(); + } + + public String socketDescription() { + Socket socket = transportLayer.socketChannel().socket(); + if (socket.getInetAddress() == null) + return socket.getLocalAddress().toString(); + return socket.getInetAddress().toString(); + } + + public void setSend(NetworkSend send) { + if (this.send != null) + throw new IllegalStateException("Attempt to begin a send operation with prior send operation still in progress, connection id is " + id); + this.send = send; + this.transportLayer.addInterestOps(SelectionKey.OP_WRITE); + } + + public NetworkSend maybeCompleteSend() { + if (send != null && send.completed()) { + midWrite = false; + transportLayer.removeInterestOps(SelectionKey.OP_WRITE); + NetworkSend result = send; + send = null; + return result; + } + return null; + } + + public long read() throws IOException { + if (receive == null) { + receive = new NetworkReceive(maxReceiveSize, id, memoryPool); + } + + long bytesReceived = receive(this.receive); + + if (this.receive.requiredMemoryAmountKnown() && !this.receive.memoryAllocated() && isInMutableState()) { + //pool must be out of memory, mute ourselves. + mute(); + } + return bytesReceived; + } + + public NetworkReceive currentReceive() { + return receive; + } + + public NetworkReceive maybeCompleteReceive() { + if (receive != null && receive.complete()) { + receive.payload().rewind(); + NetworkReceive result = receive; + receive = null; + return result; + } + return null; + } + + public long write() throws IOException { + if (send == null) + return 0; + + midWrite = true; + return send.writeTo(transportLayer); + } + + /** + * Accumulates network thread time for this channel. + */ + public void addNetworkThreadTimeNanos(long nanos) { + networkThreadTimeNanos += nanos; + } + + /** + * Returns accumulated network thread time for this channel and resets + * the value to zero. + */ + public long getAndResetNetworkThreadTimeNanos() { + long current = networkThreadTimeNanos; + networkThreadTimeNanos = 0; + return current; + } + + private long receive(NetworkReceive receive) throws IOException { + try { + return receive.readFrom(transportLayer); + } catch (SslAuthenticationException e) { + // With TLSv1.3, post-handshake messages may throw SSLExceptions, which are + // handled as authentication failures + String remoteDesc = remoteAddress != null ? remoteAddress.toString() : null; + state = new ChannelState(ChannelState.State.AUTHENTICATION_FAILED, e, remoteDesc); + throw e; + } + } + + /** + * @return true if underlying transport has bytes remaining to be read from any underlying intermediate buffers. + */ + public boolean hasBytesBuffered() { + return transportLayer.hasBytesBuffered(); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + KafkaChannel that = (KafkaChannel) o; + return id.equals(that.id); + } + + @Override + public int hashCode() { + return id.hashCode(); + } + + @Override + public String toString() { + return super.toString() + " id=" + id; + } + + /** + * Return the number of times this instance has successfully authenticated. This + * value can only exceed 1 when re-authentication is enabled and it has + * succeeded at least once. 
+ * + * @return the number of times this instance has successfully authenticated + */ + public int successfulAuthentications() { + return successfulAuthentications; + } + + /** + * If this is a server-side connection that has an expiration time and at least + * 1 second has passed since the prior re-authentication (if any) started then + * begin the process of re-authenticating the connection and return true, + * otherwise return false + * + * @param saslHandshakeNetworkReceive + * the mandatory {@link NetworkReceive} containing the + * {@code SaslHandshakeRequest} that has been received on the server + * and that initiates re-authentication. + * @param nowNanosSupplier + * {@code Supplier} of the current time. The value must be in + * nanoseconds as per {@code System.nanoTime()} and is therefore only + * useful when compared to such a value -- it's absolute value is + * meaningless. + * + * @return true if this is a server-side connection that has an expiration time + * and at least 1 second has passed since the prior re-authentication + * (if any) started to indicate that the re-authentication process has + * begun, otherwise false + * @throws AuthenticationException + * if re-authentication fails due to invalid credentials or other + * security configuration errors + * @throws IOException + * if read/write fails due to an I/O error + * @throws IllegalStateException + * if this channel is not "ready" + */ + public boolean maybeBeginServerReauthentication(NetworkReceive saslHandshakeNetworkReceive, + Supplier nowNanosSupplier) throws AuthenticationException, IOException { + if (!ready()) + throw new IllegalStateException( + "KafkaChannel should be \"ready\" when processing SASL Handshake for potential re-authentication"); + /* + * Re-authentication is disabled if there is no session expiration time, in + * which case the SASL handshake network receive will be processed normally, + * which results in a failure result being sent to the client. Also, no need to + * check if we are muted since since we are processing a received packet when we + * invoke this. + */ + if (authenticator.serverSessionExpirationTimeNanos() == null) + return false; + /* + * We've delayed getting the time as long as possible in case we don't need it, + * but at this point we need it -- so get it now. + */ + long nowNanos = nowNanosSupplier.get(); + /* + * Cannot re-authenticate more than once every second; an attempt to do so will + * result in the SASL handshake network receive being processed normally, which + * results in a failure result being sent to the client. + */ + if (lastReauthenticationStartNanos != 0 + && nowNanos - lastReauthenticationStartNanos < MIN_REAUTH_INTERVAL_ONE_SECOND_NANOS) + return false; + lastReauthenticationStartNanos = nowNanos; + swapAuthenticatorsAndBeginReauthentication( + new ReauthenticationContext(authenticator, saslHandshakeNetworkReceive, nowNanos)); + return true; + } + + /** + * If this is a client-side connection that is not muted, there is no + * in-progress write, and there is a session expiration time defined that has + * past then begin the process of re-authenticating the connection and return + * true, otherwise return false + * + * @param nowNanosSupplier + * {@code Supplier} of the current time. The value must be in + * nanoseconds as per {@code System.nanoTime()} and is therefore only + * useful when compared to such a value -- it's absolute value is + * meaningless. 
+ * + * @return true if this is a client-side connection that is not muted, there is + * no in-progress write, and there is a session expiration time defined + * that has past to indicate that the re-authentication process has + * begun, otherwise false + * @throws AuthenticationException + * if re-authentication fails due to invalid credentials or other + * security configuration errors + * @throws IOException + * if read/write fails due to an I/O error + * @throws IllegalStateException + * if this channel is not "ready" + */ + public boolean maybeBeginClientReauthentication(Supplier nowNanosSupplier) + throws AuthenticationException, IOException { + if (!ready()) + throw new IllegalStateException( + "KafkaChannel should always be \"ready\" when it is checked for possible re-authentication"); + if (muteState != ChannelMuteState.NOT_MUTED || midWrite + || authenticator.clientSessionReauthenticationTimeNanos() == null) + return false; + /* + * We've delayed getting the time as long as possible in case we don't need it, + * but at this point we need it -- so get it now. + */ + long nowNanos = nowNanosSupplier.get(); + if (nowNanos < authenticator.clientSessionReauthenticationTimeNanos()) + return false; + swapAuthenticatorsAndBeginReauthentication(new ReauthenticationContext(authenticator, receive, nowNanos)); + receive = null; + return true; + } + + /** + * Return the number of milliseconds that elapsed while re-authenticating this + * session from the perspective of this instance, if applicable, otherwise null. + * The server-side perspective will yield a lower value than the client-side + * perspective of the same re-authentication because the client-side observes an + * additional network round-trip. + * + * @return the number of milliseconds that elapsed while re-authenticating this + * session from the perspective of this instance, if applicable, + * otherwise null + */ + public Long reauthenticationLatencyMs() { + return authenticator.reauthenticationLatencyMs(); + } + + /** + * Return true if this is a server-side channel and the given time is past the + * session expiration time, if any, otherwise false + * + * @param nowNanos + * the current time in nanoseconds as per {@code System.nanoTime()} + * @return true if this is a server-side channel and the given time is past the + * session expiration time, if any, otherwise false + */ + public boolean serverAuthenticationSessionExpired(long nowNanos) { + Long serverSessionExpirationTimeNanos = authenticator.serverSessionExpirationTimeNanos(); + return serverSessionExpirationTimeNanos != null && nowNanos - serverSessionExpirationTimeNanos > 0; + } + + /** + * Return the (always non-null but possibly empty) client-side + * {@link NetworkReceive} response that arrived during re-authentication but + * is unrelated to re-authentication. This corresponds to a request sent + * prior to the beginning of re-authentication; the request was made when the + * channel was successfully authenticated, and the response arrived during the + * re-authentication process. + * + * @return client-side {@link NetworkReceive} response that arrived during + * re-authentication that is unrelated to re-authentication. This may + * be empty. 
+ */ + public Optional pollResponseReceivedDuringReauthentication() { + return authenticator.pollResponseReceivedDuringReauthentication(); + } + + /** + * Return true if this is a server-side channel and the connected client has + * indicated that it supports re-authentication, otherwise false + * + * @return true if this is a server-side channel and the connected client has + * indicated that it supports re-authentication, otherwise false + */ + boolean connectedClientSupportsReauthentication() { + return authenticator.connectedClientSupportsReauthentication(); + } + + private void swapAuthenticatorsAndBeginReauthentication(ReauthenticationContext reauthenticationContext) + throws IOException { + // it is up to the new authenticator to close the old one + // replace with a new one and begin the process of re-authenticating + authenticator = authenticatorCreator.get(); + authenticator.reauthenticate(reauthenticationContext); + } + + public ChannelMetadataRegistry channelMetadataRegistry() { + return metadataRegistry; + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/network/ListenerName.java b/clients/src/main/java/org/apache/kafka/common/network/ListenerName.java new file mode 100644 index 0000000..2decccb --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/network/ListenerName.java @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.network; + +import org.apache.kafka.common.security.auth.SecurityProtocol; + +import java.util.Locale; +import java.util.Objects; + +public final class ListenerName { + + private static final String CONFIG_STATIC_PREFIX = "listener.name"; + + /** + * Create an instance with the security protocol name as the value. + */ + public static ListenerName forSecurityProtocol(SecurityProtocol securityProtocol) { + return new ListenerName(securityProtocol.name); + } + + /** + * Create an instance with the provided value converted to uppercase. + */ + public static ListenerName normalised(String value) { + return new ListenerName(value.toUpperCase(Locale.ROOT)); + } + + private final String value; + + public ListenerName(String value) { + Objects.requireNonNull(value, "value should not be null"); + this.value = value; + } + + public String value() { + return value; + } + + @Override + public boolean equals(Object o) { + if (!(o instanceof ListenerName)) + return false; + ListenerName that = (ListenerName) o; + return value.equals(that.value); + } + + @Override + public int hashCode() { + return value.hashCode(); + } + + @Override + public String toString() { + return "ListenerName(" + value + ")"; + } + + public String configPrefix() { + return CONFIG_STATIC_PREFIX + "." 
+ value.toLowerCase(Locale.ROOT) + "."; + } + + public String saslMechanismConfigPrefix(String saslMechanism) { + return configPrefix() + saslMechanismPrefix(saslMechanism); + } + + public static String saslMechanismPrefix(String saslMechanism) { + return saslMechanism.toLowerCase(Locale.ROOT) + "."; + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/network/ListenerReconfigurable.java b/clients/src/main/java/org/apache/kafka/common/network/ListenerReconfigurable.java new file mode 100644 index 0000000..3541212 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/network/ListenerReconfigurable.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.network; + +import org.apache.kafka.common.Reconfigurable; + +/** + * Interface for reconfigurable entities associated with a listener. + */ +public interface ListenerReconfigurable extends Reconfigurable { + + /** + * Returns the listener name associated with this reconfigurable. Listener-specific + * configs corresponding to this listener name are provided for reconfiguration. + */ + ListenerName listenerName(); +} diff --git a/clients/src/main/java/org/apache/kafka/common/network/Mode.java b/clients/src/main/java/org/apache/kafka/common/network/Mode.java new file mode 100644 index 0000000..6123970 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/network/Mode.java @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.network; + +/** + * Connection mode for SSL and SASL connections. 
+ */ +public enum Mode { CLIENT, SERVER } diff --git a/clients/src/main/java/org/apache/kafka/common/network/NetworkReceive.java b/clients/src/main/java/org/apache/kafka/common/network/NetworkReceive.java new file mode 100644 index 0000000..5332c81 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/network/NetworkReceive.java @@ -0,0 +1,164 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.network; + +import org.apache.kafka.common.memory.MemoryPool; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.EOFException; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.channels.ScatteringByteChannel; + +/** + * A size delimited Receive that consists of a 4 byte network-ordered size N followed by N bytes of content + */ +public class NetworkReceive implements Receive { + + public final static String UNKNOWN_SOURCE = ""; + public final static int UNLIMITED = -1; + private static final Logger log = LoggerFactory.getLogger(NetworkReceive.class); + private static final ByteBuffer EMPTY_BUFFER = ByteBuffer.allocate(0); + + private final String source; + private final ByteBuffer size; + private final int maxSize; + private final MemoryPool memoryPool; + private int requestedBufferSize = -1; + private ByteBuffer buffer; + + + public NetworkReceive(String source, ByteBuffer buffer) { + this.source = source; + this.buffer = buffer; + this.size = null; + this.maxSize = UNLIMITED; + this.memoryPool = MemoryPool.NONE; + } + + public NetworkReceive(String source) { + this.source = source; + this.size = ByteBuffer.allocate(4); + this.buffer = null; + this.maxSize = UNLIMITED; + this.memoryPool = MemoryPool.NONE; + } + + public NetworkReceive(int maxSize, String source) { + this.source = source; + this.size = ByteBuffer.allocate(4); + this.buffer = null; + this.maxSize = maxSize; + this.memoryPool = MemoryPool.NONE; + } + + public NetworkReceive(int maxSize, String source, MemoryPool memoryPool) { + this.source = source; + this.size = ByteBuffer.allocate(4); + this.buffer = null; + this.maxSize = maxSize; + this.memoryPool = memoryPool; + } + + public NetworkReceive() { + this(UNKNOWN_SOURCE); + } + + @Override + public String source() { + return source; + } + + @Override + public boolean complete() { + return !size.hasRemaining() && buffer != null && !buffer.hasRemaining(); + } + + public long readFrom(ScatteringByteChannel channel) throws IOException { + int read = 0; + if (size.hasRemaining()) { + int bytesRead = channel.read(size); + if (bytesRead < 0) + throw new EOFException(); + read += bytesRead; + if (!size.hasRemaining()) { + size.rewind(); + int receiveSize = size.getInt(); + if (receiveSize < 0) + throw new InvalidReceiveException("Invalid receive 
(size = " + receiveSize + ")"); + if (maxSize != UNLIMITED && receiveSize > maxSize) + throw new InvalidReceiveException("Invalid receive (size = " + receiveSize + " larger than " + maxSize + ")"); + requestedBufferSize = receiveSize; //may be 0 for some payloads (SASL) + if (receiveSize == 0) { + buffer = EMPTY_BUFFER; + } + } + } + if (buffer == null && requestedBufferSize != -1) { //we know the size we want but havent been able to allocate it yet + buffer = memoryPool.tryAllocate(requestedBufferSize); + if (buffer == null) + log.trace("Broker low on memory - could not allocate buffer of size {} for source {}", requestedBufferSize, source); + } + if (buffer != null) { + int bytesRead = channel.read(buffer); + if (bytesRead < 0) + throw new EOFException(); + read += bytesRead; + } + + return read; + } + + @Override + public boolean requiredMemoryAmountKnown() { + return requestedBufferSize != -1; + } + + @Override + public boolean memoryAllocated() { + return buffer != null; + } + + + @Override + public void close() throws IOException { + if (buffer != null && buffer != EMPTY_BUFFER) { + memoryPool.release(buffer); + buffer = null; + } + } + + public ByteBuffer payload() { + return this.buffer; + } + + public int bytesRead() { + if (buffer == null) + return size.position(); + return buffer.position() + size.position(); + } + + /** + * Returns the total size of the receive including payload and size buffer + * for use in metrics. This is consistent with {@link NetworkSend#size()} + */ + public int size() { + return payload().limit() + size.limit(); + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/network/NetworkSend.java b/clients/src/main/java/org/apache/kafka/common/network/NetworkSend.java new file mode 100644 index 0000000..2a51a56 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/network/NetworkSend.java @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
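For orientation, here is a small, self-contained sketch (not part of the patch) of the wire framing that NetworkReceive.readFrom() parses: a 4-byte network-ordered size N followed by N content bytes. The helper class and names below are hypothetical.

import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;

public class SizeDelimitedFramingSketch {
    // Builds a frame in the layout NetworkReceive expects: 4-byte size header, then payload.
    static ByteBuffer frame(byte[] payload) {
        ByteBuffer buf = ByteBuffer.allocate(4 + payload.length); // ByteBuffer is big-endian ("network order") by default
        buf.putInt(payload.length); // the size header read into NetworkReceive's `size` buffer
        buf.put(payload);           // the N content bytes read into NetworkReceive's payload `buffer`
        buf.flip();
        return buf;
    }

    public static void main(String[] args) {
        ByteBuffer framed = frame("hello".getBytes(StandardCharsets.UTF_8));
        int n = framed.getInt(); // 5, just as readFrom() recovers the size before allocating the payload buffer
        System.out.println("payload size = " + n);
    }
}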
+ */ +package org.apache.kafka.common.network; + +import java.io.IOException; + +public class NetworkSend implements Send { + private final String destinationId; + private final Send send; + + public NetworkSend(String destinationId, Send send) { + this.destinationId = destinationId; + this.send = send; + } + + public String destinationId() { + return destinationId; + } + + @Override + public boolean completed() { + return send.completed(); + } + + @Override + public long writeTo(TransferableChannel channel) throws IOException { + return send.writeTo(channel); + } + + @Override + public long size() { + return send.size(); + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/network/PlaintextChannelBuilder.java b/clients/src/main/java/org/apache/kafka/common/network/PlaintextChannelBuilder.java new file mode 100644 index 0000000..50bbc48 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/network/PlaintextChannelBuilder.java @@ -0,0 +1,121 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.network; + +import org.apache.kafka.common.KafkaException; +import org.apache.kafka.common.memory.MemoryPool; +import org.apache.kafka.common.security.auth.KafkaPrincipal; +import org.apache.kafka.common.security.auth.KafkaPrincipalBuilder; +import org.apache.kafka.common.security.auth.KafkaPrincipalSerde; +import org.apache.kafka.common.security.auth.PlaintextAuthenticationContext; +import org.apache.kafka.common.utils.Utils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.Closeable; +import java.io.IOException; +import java.net.InetAddress; +import java.nio.channels.SelectionKey; +import java.util.Map; +import java.util.Optional; +import java.util.function.Supplier; + +public class PlaintextChannelBuilder implements ChannelBuilder { + private static final Logger log = LoggerFactory.getLogger(PlaintextChannelBuilder.class); + private final ListenerName listenerName; + private Map configs; + + /** + * Constructs a plaintext channel builder. ListenerName is non-null whenever + * it's instantiated in the broker and null otherwise. 
+ */ + public PlaintextChannelBuilder(ListenerName listenerName) { + this.listenerName = listenerName; + } + + public void configure(Map configs) throws KafkaException { + this.configs = configs; + } + + @Override + public KafkaChannel buildChannel(String id, SelectionKey key, int maxReceiveSize, + MemoryPool memoryPool, ChannelMetadataRegistry metadataRegistry) throws KafkaException { + try { + PlaintextTransportLayer transportLayer = buildTransportLayer(key); + Supplier authenticatorCreator = () -> new PlaintextAuthenticator(configs, transportLayer, listenerName); + return buildChannel(id, transportLayer, authenticatorCreator, maxReceiveSize, + memoryPool != null ? memoryPool : MemoryPool.NONE, metadataRegistry); + } catch (Exception e) { + log.warn("Failed to create channel due to ", e); + throw new KafkaException(e); + } + } + + // visible for testing + KafkaChannel buildChannel(String id, TransportLayer transportLayer, Supplier authenticatorCreator, + int maxReceiveSize, MemoryPool memoryPool, ChannelMetadataRegistry metadataRegistry) { + return new KafkaChannel(id, transportLayer, authenticatorCreator, maxReceiveSize, memoryPool, metadataRegistry); + } + + protected PlaintextTransportLayer buildTransportLayer(SelectionKey key) throws IOException { + return new PlaintextTransportLayer(key); + } + + @Override + public void close() {} + + private static class PlaintextAuthenticator implements Authenticator { + private final PlaintextTransportLayer transportLayer; + private final KafkaPrincipalBuilder principalBuilder; + private final ListenerName listenerName; + + private PlaintextAuthenticator(Map configs, PlaintextTransportLayer transportLayer, ListenerName listenerName) { + this.transportLayer = transportLayer; + this.principalBuilder = ChannelBuilders.createPrincipalBuilder(configs, null, null); + this.listenerName = listenerName; + } + + @Override + public void authenticate() {} + + @Override + public KafkaPrincipal principal() { + InetAddress clientAddress = transportLayer.socketChannel().socket().getInetAddress(); + // listenerName should only be null in Client mode where principal() should not be called + if (listenerName == null) + throw new IllegalStateException("Unexpected call to principal() when listenerName is null"); + return principalBuilder.build(new PlaintextAuthenticationContext(clientAddress, listenerName.value())); + } + + @Override + public Optional principalSerde() { + return principalBuilder instanceof KafkaPrincipalSerde ? Optional.of((KafkaPrincipalSerde) principalBuilder) : Optional.empty(); + } + + @Override + public boolean complete() { + return true; + } + + @Override + public void close() { + if (principalBuilder instanceof Closeable) + Utils.closeQuietly((Closeable) principalBuilder, "principal builder"); + } + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/network/PlaintextTransportLayer.java b/clients/src/main/java/org/apache/kafka/common/network/PlaintextTransportLayer.java new file mode 100644 index 0000000..845b147 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/network/PlaintextTransportLayer.java @@ -0,0 +1,217 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.network; + +/* + * Transport layer for PLAINTEXT communication + */ + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.channels.FileChannel; +import java.nio.channels.SocketChannel; +import java.nio.channels.SelectionKey; + +import java.security.Principal; + +import org.apache.kafka.common.security.auth.KafkaPrincipal; + +public class PlaintextTransportLayer implements TransportLayer { + private final SelectionKey key; + private final SocketChannel socketChannel; + private final Principal principal = KafkaPrincipal.ANONYMOUS; + + public PlaintextTransportLayer(SelectionKey key) throws IOException { + this.key = key; + this.socketChannel = (SocketChannel) key.channel(); + } + + @Override + public boolean ready() { + return true; + } + + @Override + public boolean finishConnect() throws IOException { + boolean connected = socketChannel.finishConnect(); + if (connected) + key.interestOps(key.interestOps() & ~SelectionKey.OP_CONNECT | SelectionKey.OP_READ); + return connected; + } + + @Override + public void disconnect() { + key.cancel(); + } + + @Override + public SocketChannel socketChannel() { + return socketChannel; + } + + @Override + public SelectionKey selectionKey() { + return key; + } + + @Override + public boolean isOpen() { + return socketChannel.isOpen(); + } + + @Override + public boolean isConnected() { + return socketChannel.isConnected(); + } + + @Override + public void close() throws IOException { + socketChannel.socket().close(); + socketChannel.close(); + } + + /** + * Performs SSL handshake hence is a no-op for the non-secure + * implementation + */ + @Override + public void handshake() {} + + /** + * Reads a sequence of bytes from this channel into the given buffer. + * + * @param dst The buffer into which bytes are to be transferred + * @return The number of bytes read, possible zero or -1 if the channel has reached end-of-stream + * @throws IOException if some other I/O error occurs + */ + @Override + public int read(ByteBuffer dst) throws IOException { + return socketChannel.read(dst); + } + + /** + * Reads a sequence of bytes from this channel into the given buffers. + * + * @param dsts - The buffers into which bytes are to be transferred. + * @return The number of bytes read, possibly zero, or -1 if the channel has reached end-of-stream. + * @throws IOException if some other I/O error occurs + */ + @Override + public long read(ByteBuffer[] dsts) throws IOException { + return socketChannel.read(dsts); + } + + /** + * Reads a sequence of bytes from this channel into a subsequence of the given buffers. + * @param dsts - The buffers into which bytes are to be transferred + * @param offset - The offset within the buffer array of the first buffer into which bytes are to be transferred; must be non-negative and no larger than dsts.length. + * @param length - The maximum number of buffers to be accessed; must be non-negative and no larger than dsts.length - offset + * @return The number of bytes read, possibly zero, or -1 if the channel has reached end-of-stream. 
+ * @throws IOException if some other I/O error occurs + */ + @Override + public long read(ByteBuffer[] dsts, int offset, int length) throws IOException { + return socketChannel.read(dsts, offset, length); + } + + /** + * Writes a sequence of bytes to this channel from the given buffer. + * + * @param src The buffer from which bytes are to be retrieved + * @return The number of bytes read, possibly zero, or -1 if the channel has reached end-of-stream + * @throws IOException If some other I/O error occurs + */ + @Override + public int write(ByteBuffer src) throws IOException { + return socketChannel.write(src); + } + + /** + * Writes a sequence of bytes to this channel from the given buffer. + * + * @param srcs The buffer from which bytes are to be retrieved + * @return The number of bytes read, possibly zero, or -1 if the channel has reached end-of-stream + * @throws IOException If some other I/O error occurs + */ + @Override + public long write(ByteBuffer[] srcs) throws IOException { + return socketChannel.write(srcs); + } + + /** + * Writes a sequence of bytes to this channel from the subsequence of the given buffers. + * + * @param srcs The buffers from which bytes are to be retrieved + * @param offset The offset within the buffer array of the first buffer from which bytes are to be retrieved; must be non-negative and no larger than srcs.length. + * @param length - The maximum number of buffers to be accessed; must be non-negative and no larger than srcs.length - offset. + * @return returns no.of bytes written , possibly zero. + * @throws IOException If some other I/O error occurs + */ + @Override + public long write(ByteBuffer[] srcs, int offset, int length) throws IOException { + return socketChannel.write(srcs, offset, length); + } + + /** + * always returns false as there will be not be any + * pending writes since we directly write to socketChannel. + */ + @Override + public boolean hasPendingWrites() { + return false; + } + + /** + * Returns ANONYMOUS as Principal. + */ + @Override + public Principal peerPrincipal() { + return principal; + } + + /** + * Adds the interestOps to selectionKey. + */ + @Override + public void addInterestOps(int ops) { + key.interestOps(key.interestOps() | ops); + + } + + /** + * Removes the interestOps from selectionKey. + */ + @Override + public void removeInterestOps(int ops) { + key.interestOps(key.interestOps() & ~ops); + } + + @Override + public boolean isMute() { + return key.isValid() && (key.interestOps() & SelectionKey.OP_READ) == 0; + } + + @Override + public boolean hasBytesBuffered() { + return false; + } + + @Override + public long transferFrom(FileChannel fileChannel, long position, long count) throws IOException { + return fileChannel.transferTo(position, count, socketChannel); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/network/ReauthenticationContext.java b/clients/src/main/java/org/apache/kafka/common/network/ReauthenticationContext.java new file mode 100644 index 0000000..37e46cb --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/network/ReauthenticationContext.java @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.network; + +import java.util.Objects; + +/** + * Defines the context in which an {@link Authenticator} is to be created during + * a re-authentication. + */ +public class ReauthenticationContext { + private final NetworkReceive networkReceive; + private final Authenticator previousAuthenticator; + private final long reauthenticationBeginNanos; + + /** + * Constructor + * + * @param previousAuthenticator + * the mandatory {@link Authenticator} that was previously used to + * authenticate the channel + * @param networkReceive + * the applicable {@link NetworkReceive} instance, if any. For the + * client side this may be a response that has been partially read, a + * non-null instance that has had no data read into it yet, or null; + * if it is non-null then this is the instance that data should + * initially be read into during re-authentication. For the server + * side this is mandatory and it must contain the + * {@code SaslHandshakeRequest} that has been received on the server + * and that initiates re-authentication. + * + * @param nowNanos + * the current time. The value is in nanoseconds as per + * {@code System.nanoTime()} and is therefore only useful when + * compared to such a value -- it's absolute value is meaningless. + * This defines the moment when re-authentication begins. + */ + public ReauthenticationContext(Authenticator previousAuthenticator, NetworkReceive networkReceive, long nowNanos) { + this.previousAuthenticator = Objects.requireNonNull(previousAuthenticator); + this.networkReceive = networkReceive; + this.reauthenticationBeginNanos = nowNanos; + } + + /** + * Return the applicable {@link NetworkReceive} instance, if any. For the client + * side this may be a response that has been partially read, a non-null instance + * that has had no data read into it yet, or null; if it is non-null then this + * is the instance that data should initially be read into during + * re-authentication. For the server side this is mandatory and it must contain + * the {@code SaslHandshakeRequest} that has been received on the server and + * that initiates re-authentication. + * + * @return the applicable {@link NetworkReceive} instance, if any + */ + public NetworkReceive networkReceive() { + return networkReceive; + } + + /** + * Return the always non-null {@link Authenticator} that was previously used to + * authenticate the channel + * + * @return the always non-null {@link Authenticator} that was previously used to + * authenticate the channel + */ + public Authenticator previousAuthenticator() { + return previousAuthenticator; + } + + /** + * Return the time when re-authentication began. The value is in nanoseconds as + * per {@code System.nanoTime()} and is therefore only useful when compared to + * such a value -- it's absolute value is meaningless. 
+ * + * @return the time when re-authentication began + */ + public long reauthenticationBeginNanos() { + return reauthenticationBeginNanos; + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/network/Receive.java b/clients/src/main/java/org/apache/kafka/common/network/Receive.java new file mode 100644 index 0000000..3bc2761 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/network/Receive.java @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.network; + +import java.io.Closeable; +import java.io.IOException; +import java.nio.channels.ScatteringByteChannel; + +/** + * This interface models the in-progress reading of data from a channel to a source identified by an integer id + */ +public interface Receive extends Closeable { + + /** + * The numeric id of the source from which we are receiving data. + */ + String source(); + + /** + * Are we done receiving data? + */ + boolean complete(); + + /** + * Read bytes into this receive from the given channel + * @param channel The channel to read from + * @return The number of bytes read + * @throws IOException If the reading fails + */ + long readFrom(ScatteringByteChannel channel) throws IOException; + + /** + * Do we know yet how much memory we require to fully read this + */ + boolean requiredMemoryAmountKnown(); + + /** + * Has the underlying memory required to complete reading been allocated yet? + */ + boolean memoryAllocated(); +} diff --git a/clients/src/main/java/org/apache/kafka/common/network/SaslChannelBuilder.java b/clients/src/main/java/org/apache/kafka/common/network/SaslChannelBuilder.java new file mode 100644 index 0000000..8b390d1 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/network/SaslChannelBuilder.java @@ -0,0 +1,405 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.network; + +import org.apache.kafka.common.KafkaException; +import org.apache.kafka.common.config.SaslConfigs; +import org.apache.kafka.common.config.SslConfigs; +import org.apache.kafka.common.config.internals.BrokerSecurityConfigs; +import org.apache.kafka.common.memory.MemoryPool; +import org.apache.kafka.common.requests.ApiVersionsResponse; +import org.apache.kafka.common.security.JaasContext; +import org.apache.kafka.common.security.auth.AuthenticateCallbackHandler; +import org.apache.kafka.common.security.auth.Login; +import org.apache.kafka.common.security.auth.SecurityProtocol; +import org.apache.kafka.common.security.authenticator.CredentialCache; +import org.apache.kafka.common.security.authenticator.DefaultLogin; +import org.apache.kafka.common.security.authenticator.LoginManager; +import org.apache.kafka.common.security.authenticator.SaslClientAuthenticator; +import org.apache.kafka.common.security.authenticator.SaslClientCallbackHandler; +import org.apache.kafka.common.security.authenticator.SaslServerAuthenticator; +import org.apache.kafka.common.security.authenticator.SaslServerCallbackHandler; +import org.apache.kafka.common.security.kerberos.KerberosClientCallbackHandler; +import org.apache.kafka.common.security.kerberos.KerberosLogin; +import org.apache.kafka.common.security.kerberos.KerberosName; +import org.apache.kafka.common.security.kerberos.KerberosShortNamer; +import org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginModule; +import org.apache.kafka.common.security.oauthbearer.internals.OAuthBearerRefreshingLogin; +import org.apache.kafka.common.security.oauthbearer.internals.OAuthBearerSaslClientCallbackHandler; +import org.apache.kafka.common.security.oauthbearer.internals.unsecured.OAuthBearerUnsecuredValidatorCallbackHandler; +import org.apache.kafka.common.security.plain.internals.PlainSaslServer; +import org.apache.kafka.common.security.plain.internals.PlainServerCallbackHandler; +import org.apache.kafka.common.security.scram.ScramCredential; +import org.apache.kafka.common.security.scram.internals.ScramMechanism; +import org.apache.kafka.common.security.scram.internals.ScramServerCallbackHandler; +import org.apache.kafka.common.security.ssl.SslFactory; +import org.apache.kafka.common.security.token.delegation.internals.DelegationTokenCache; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.common.utils.Utils; +import org.ietf.jgss.GSSContext; +import org.ietf.jgss.GSSCredential; +import org.ietf.jgss.GSSException; +import org.ietf.jgss.GSSManager; +import org.ietf.jgss.GSSName; +import org.ietf.jgss.Oid; +import org.slf4j.Logger; + +import javax.security.auth.Subject; +import javax.security.auth.kerberos.KerberosPrincipal; +import java.io.IOException; +import java.net.Socket; +import java.nio.channels.SelectionKey; +import java.nio.channels.SocketChannel; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.function.Supplier; + +public class SaslChannelBuilder implements ChannelBuilder, ListenerReconfigurable { + static final String GSS_NATIVE_PROP = "sun.security.jgss.native"; + + private final SecurityProtocol securityProtocol; + private final ListenerName listenerName; + private final boolean isInterBrokerListener; + private final String clientSaslMechanism; + private final Mode mode; + private final Map jaasContexts; + private final boolean 
handshakeRequestEnable; + private final CredentialCache credentialCache; + private final DelegationTokenCache tokenCache; + private final Map loginManagers; + private final Map subjects; + private final Supplier apiVersionSupplier; + + private SslFactory sslFactory; + private Map configs; + private final String sslClientAuthOverride; + + private KerberosShortNamer kerberosShortNamer; + private Map saslCallbackHandlers; + private Map connectionsMaxReauthMsByMechanism; + private final Time time; + private final LogContext logContext; + private final Logger log; + + public SaslChannelBuilder(Mode mode, + Map jaasContexts, + SecurityProtocol securityProtocol, + ListenerName listenerName, + boolean isInterBrokerListener, + String clientSaslMechanism, + boolean handshakeRequestEnable, + CredentialCache credentialCache, + DelegationTokenCache tokenCache, + String sslClientAuthOverride, + Time time, + LogContext logContext, + Supplier apiVersionSupplier) { + this.mode = mode; + this.jaasContexts = jaasContexts; + this.loginManagers = new HashMap<>(jaasContexts.size()); + this.subjects = new HashMap<>(jaasContexts.size()); + this.securityProtocol = securityProtocol; + this.listenerName = listenerName; + this.isInterBrokerListener = isInterBrokerListener; + this.handshakeRequestEnable = handshakeRequestEnable; + this.clientSaslMechanism = clientSaslMechanism; + this.credentialCache = credentialCache; + this.tokenCache = tokenCache; + this.sslClientAuthOverride = sslClientAuthOverride; + this.saslCallbackHandlers = new HashMap<>(); + this.connectionsMaxReauthMsByMechanism = new HashMap<>(); + this.time = time; + this.logContext = logContext; + this.log = logContext.logger(getClass()); + this.apiVersionSupplier = apiVersionSupplier; + + if (mode == Mode.SERVER && apiVersionSupplier == null) { + throw new IllegalArgumentException("Server channel builder must provide an ApiVersionResponse supplier"); + } + } + + @SuppressWarnings("unchecked") + @Override + public void configure(Map configs) throws KafkaException { + try { + this.configs = configs; + if (mode == Mode.SERVER) { + createServerCallbackHandlers(configs); + createConnectionsMaxReauthMsMap(configs); + } else + createClientCallbackHandler(configs); + for (Map.Entry entry : saslCallbackHandlers.entrySet()) { + String mechanism = entry.getKey(); + entry.getValue().configure(configs, mechanism, jaasContexts.get(mechanism).configurationEntries()); + } + + Class defaultLoginClass = defaultLoginClass(); + if (mode == Mode.SERVER && jaasContexts.containsKey(SaslConfigs.GSSAPI_MECHANISM)) { + String defaultRealm; + try { + defaultRealm = defaultKerberosRealm(); + } catch (Exception ke) { + defaultRealm = ""; + } + List principalToLocalRules = (List) configs.get(BrokerSecurityConfigs.SASL_KERBEROS_PRINCIPAL_TO_LOCAL_RULES_CONFIG); + if (principalToLocalRules != null) + kerberosShortNamer = KerberosShortNamer.fromUnparsedRules(defaultRealm, principalToLocalRules); + } + for (Map.Entry entry : jaasContexts.entrySet()) { + String mechanism = entry.getKey(); + // With static JAAS configuration, use KerberosLogin if Kerberos is enabled. 
With dynamic JAAS configuration, + // use KerberosLogin only for the LoginContext corresponding to GSSAPI + LoginManager loginManager = LoginManager.acquireLoginManager(entry.getValue(), mechanism, defaultLoginClass, configs); + loginManagers.put(mechanism, loginManager); + Subject subject = loginManager.subject(); + subjects.put(mechanism, subject); + if (mode == Mode.SERVER && mechanism.equals(SaslConfigs.GSSAPI_MECHANISM)) + maybeAddNativeGssapiCredentials(subject); + } + if (this.securityProtocol == SecurityProtocol.SASL_SSL) { + // Disable SSL client authentication as we are using SASL authentication + this.sslFactory = new SslFactory(mode, sslClientAuthOverride, isInterBrokerListener); + this.sslFactory.configure(configs); + } + } catch (Throwable e) { + close(); + throw new KafkaException(e); + } + } + + @Override + public Set reconfigurableConfigs() { + return securityProtocol == SecurityProtocol.SASL_SSL ? SslConfigs.RECONFIGURABLE_CONFIGS : Collections.emptySet(); + } + + @Override + public void validateReconfiguration(Map configs) { + if (this.securityProtocol == SecurityProtocol.SASL_SSL) + sslFactory.validateReconfiguration(configs); + } + + @Override + public void reconfigure(Map configs) { + if (this.securityProtocol == SecurityProtocol.SASL_SSL) + sslFactory.reconfigure(configs); + } + + @Override + public ListenerName listenerName() { + return listenerName; + } + + @Override + public KafkaChannel buildChannel(String id, SelectionKey key, int maxReceiveSize, + MemoryPool memoryPool, ChannelMetadataRegistry metadataRegistry) throws KafkaException { + try { + SocketChannel socketChannel = (SocketChannel) key.channel(); + Socket socket = socketChannel.socket(); + TransportLayer transportLayer = buildTransportLayer(id, key, socketChannel, metadataRegistry); + Supplier authenticatorCreator; + if (mode == Mode.SERVER) { + authenticatorCreator = () -> buildServerAuthenticator(configs, + Collections.unmodifiableMap(saslCallbackHandlers), + id, + transportLayer, + Collections.unmodifiableMap(subjects), + Collections.unmodifiableMap(connectionsMaxReauthMsByMechanism), + metadataRegistry); + } else { + LoginManager loginManager = loginManagers.get(clientSaslMechanism); + authenticatorCreator = () -> buildClientAuthenticator(configs, + saslCallbackHandlers.get(clientSaslMechanism), + id, + socket.getInetAddress().getHostName(), + loginManager.serviceName(), + transportLayer, + subjects.get(clientSaslMechanism)); + } + return new KafkaChannel(id, transportLayer, authenticatorCreator, maxReceiveSize, + memoryPool != null ? 
memoryPool : MemoryPool.NONE, metadataRegistry); + } catch (Exception e) { + log.info("Failed to create channel due to ", e); + throw new KafkaException(e); + } + } + + @Override + public void close() { + for (LoginManager loginManager : loginManagers.values()) + loginManager.release(); + loginManagers.clear(); + for (AuthenticateCallbackHandler handler : saslCallbackHandlers.values()) + handler.close(); + if (sslFactory != null) sslFactory.close(); + } + + // Visible to override for testing + protected TransportLayer buildTransportLayer(String id, SelectionKey key, SocketChannel socketChannel, + ChannelMetadataRegistry metadataRegistry) throws IOException { + if (this.securityProtocol == SecurityProtocol.SASL_SSL) { + return SslTransportLayer.create(id, key, + sslFactory.createSslEngine(socketChannel.socket()), + metadataRegistry); + } else { + return new PlaintextTransportLayer(key); + } + } + + // Visible to override for testing + protected SaslServerAuthenticator buildServerAuthenticator(Map configs, + Map callbackHandlers, + String id, + TransportLayer transportLayer, + Map subjects, + Map connectionsMaxReauthMsByMechanism, + ChannelMetadataRegistry metadataRegistry) { + return new SaslServerAuthenticator(configs, callbackHandlers, id, subjects, + kerberosShortNamer, listenerName, securityProtocol, transportLayer, + connectionsMaxReauthMsByMechanism, metadataRegistry, time, apiVersionSupplier); + } + + // Visible to override for testing + protected SaslClientAuthenticator buildClientAuthenticator(Map configs, + AuthenticateCallbackHandler callbackHandler, + String id, + String serverHost, + String servicePrincipal, + TransportLayer transportLayer, Subject subject) { + return new SaslClientAuthenticator(configs, callbackHandler, id, subject, servicePrincipal, + serverHost, clientSaslMechanism, handshakeRequestEnable, transportLayer, time, logContext); + } + + // Package private for testing + Map loginManagers() { + return loginManagers; + } + + private static String defaultKerberosRealm() { + // see https://issues.apache.org/jira/browse/HADOOP-10848 for details + return new KerberosPrincipal("tmp", 1).getRealm(); + } + + private void createClientCallbackHandler(Map configs) { + @SuppressWarnings("unchecked") + Class clazz = (Class) configs.get(SaslConfigs.SASL_CLIENT_CALLBACK_HANDLER_CLASS); + if (clazz == null) + clazz = clientCallbackHandlerClass(); + AuthenticateCallbackHandler callbackHandler = Utils.newInstance(clazz); + saslCallbackHandlers.put(clientSaslMechanism, callbackHandler); + } + + private void createServerCallbackHandlers(Map configs) { + for (String mechanism : jaasContexts.keySet()) { + AuthenticateCallbackHandler callbackHandler; + String prefix = ListenerName.saslMechanismPrefix(mechanism); + @SuppressWarnings("unchecked") + Class clazz = + (Class) configs.get(prefix + BrokerSecurityConfigs.SASL_SERVER_CALLBACK_HANDLER_CLASS); + if (clazz != null) + callbackHandler = Utils.newInstance(clazz); + else if (mechanism.equals(PlainSaslServer.PLAIN_MECHANISM)) + callbackHandler = new PlainServerCallbackHandler(); + else if (ScramMechanism.isScram(mechanism)) + callbackHandler = new ScramServerCallbackHandler(credentialCache.cache(mechanism, ScramCredential.class), tokenCache); + else if (mechanism.equals(OAuthBearerLoginModule.OAUTHBEARER_MECHANISM)) + callbackHandler = new OAuthBearerUnsecuredValidatorCallbackHandler(); + else + callbackHandler = new SaslServerCallbackHandler(); + saslCallbackHandlers.put(mechanism, callbackHandler); + } + } + + private void 
createConnectionsMaxReauthMsMap(Map configs) { + for (String mechanism : jaasContexts.keySet()) { + String prefix = ListenerName.saslMechanismPrefix(mechanism); + Long connectionsMaxReauthMs = (Long) configs.get(prefix + BrokerSecurityConfigs.CONNECTIONS_MAX_REAUTH_MS); + if (connectionsMaxReauthMs == null) + connectionsMaxReauthMs = (Long) configs.get(BrokerSecurityConfigs.CONNECTIONS_MAX_REAUTH_MS); + if (connectionsMaxReauthMs != null) + connectionsMaxReauthMsByMechanism.put(mechanism, connectionsMaxReauthMs); + } + } + + protected Class defaultLoginClass() { + if (jaasContexts.containsKey(SaslConfigs.GSSAPI_MECHANISM)) + return KerberosLogin.class; + if (OAuthBearerLoginModule.OAUTHBEARER_MECHANISM.equals(clientSaslMechanism)) + return OAuthBearerRefreshingLogin.class; + return DefaultLogin.class; + } + + private Class clientCallbackHandlerClass() { + switch (clientSaslMechanism) { + case SaslConfigs.GSSAPI_MECHANISM: + return KerberosClientCallbackHandler.class; + case OAuthBearerLoginModule.OAUTHBEARER_MECHANISM: + return OAuthBearerSaslClientCallbackHandler.class; + default: + return SaslClientCallbackHandler.class; + } + } + + // As described in http://docs.oracle.com/javase/8/docs/technotes/guides/security/jgss/jgss-features.html: + // "To enable Java GSS to delegate to the native GSS library and its list of native mechanisms, + // set the system property "sun.security.jgss.native" to true" + // "In addition, when performing operations as a particular Subject, for example, Subject.doAs(...) + // or Subject.doAsPrivileged(...), the to-be-used GSSCredential should be added to Subject's + // private credential set. Otherwise, the GSS operations will fail since no credential is found." + private void maybeAddNativeGssapiCredentials(Subject subject) { + boolean usingNativeJgss = Boolean.getBoolean(GSS_NATIVE_PROP); + if (usingNativeJgss && subject.getPrivateCredentials(GSSCredential.class).isEmpty()) { + + final String servicePrincipal = SaslClientAuthenticator.firstPrincipal(subject); + KerberosName kerberosName; + try { + kerberosName = KerberosName.parse(servicePrincipal); + } catch (IllegalArgumentException e) { + throw new KafkaException("Principal has name with unexpected format " + servicePrincipal); + } + final String servicePrincipalName = kerberosName.serviceName(); + final String serviceHostname = kerberosName.hostName(); + + try { + GSSManager manager = gssManager(); + // This Oid is used to represent the Kerberos version 5 GSS-API mechanism. It is defined in + // RFC 1964. 
+ Oid krb5Mechanism = new Oid("1.2.840.113554.1.2.2"); + GSSName gssName = manager.createName(servicePrincipalName + "@" + serviceHostname, GSSName.NT_HOSTBASED_SERVICE); + GSSCredential cred = manager.createCredential(gssName, + GSSContext.INDEFINITE_LIFETIME, krb5Mechanism, GSSCredential.ACCEPT_ONLY); + subject.getPrivateCredentials().add(cred); + log.info("Configured native GSSAPI private credentials for {}@{}", serviceHostname, serviceHostname); + } catch (GSSException ex) { + log.warn("Cannot add private credential to subject; clients authentication may fail", ex); + } + } + } + + // Visibility to override for testing + protected GSSManager gssManager() { + return GSSManager.getInstance(); + } + + // Visibility for testing + protected Subject subject(String saslMechanism) { + return subjects.get(saslMechanism); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/network/Selectable.java b/clients/src/main/java/org/apache/kafka/common/network/Selectable.java new file mode 100644 index 0000000..afdd42e --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/network/Selectable.java @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.network; + + +import java.io.IOException; +import java.net.InetSocketAddress; +import java.util.Collection; +import java.util.List; +import java.util.Map; + +/** + * An interface for asynchronous, multi-channel network I/O + */ +public interface Selectable { + + /** + * See {@link #connect(String, InetSocketAddress, int, int) connect()} + */ + int USE_DEFAULT_BUFFER_SIZE = -1; + + /** + * Begin establishing a socket connection to the given address identified by the given address + * @param id The id for this connection + * @param address The address to connect to + * @param sendBufferSize The send buffer for the socket + * @param receiveBufferSize The receive buffer for the socket + * @throws IOException If we cannot begin connecting + */ + void connect(String id, InetSocketAddress address, int sendBufferSize, int receiveBufferSize) throws IOException; + + /** + * Wakeup this selector if it is blocked on I/O + */ + void wakeup(); + + /** + * Close this selector + */ + void close(); + + /** + * Close the connection identified by the given id + */ + void close(String id); + + /** + * Queue the given request for sending in the subsequent {@link #poll(long) poll()} calls + * @param send The request to send + */ + void send(NetworkSend send); + + /** + * Do I/O. Reads, writes, connection establishment, etc. + * @param timeout The amount of time to block if there is nothing to do + * @throws IOException + */ + void poll(long timeout) throws IOException; + + /** + * The list of sends that completed on the last {@link #poll(long) poll()} call. 
+ */ + List completedSends(); + + /** + * The collection of receives that completed on the last {@link #poll(long) poll()} call. + */ + Collection completedReceives(); + + /** + * The connections that finished disconnecting on the last {@link #poll(long) poll()} + * call. Channel state indicates the local channel state at the time of disconnection. + */ + Map disconnected(); + + /** + * The list of connections that completed their connection on the last {@link #poll(long) poll()} + * call. + */ + List connected(); + + /** + * Disable reads from the given connection + * @param id The id for the connection + */ + void mute(String id); + + /** + * Re-enable reads from the given connection + * @param id The id for the connection + */ + void unmute(String id); + + /** + * Disable reads from all connections + */ + void muteAll(); + + /** + * Re-enable reads from all connections + */ + void unmuteAll(); + + /** + * returns true if a channel is ready + * @param id The id for the connection + */ + boolean isChannelReady(String id); +} diff --git a/clients/src/main/java/org/apache/kafka/common/network/Selector.java b/clients/src/main/java/org/apache/kafka/common/network/Selector.java new file mode 100644 index 0000000..dc7534a --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/network/Selector.java @@ -0,0 +1,1491 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.network; + +import org.apache.kafka.common.KafkaException; +import org.apache.kafka.common.MetricName; +import org.apache.kafka.common.errors.AuthenticationException; +import org.apache.kafka.common.memory.MemoryPool; +import org.apache.kafka.common.metrics.Metrics; +import org.apache.kafka.common.metrics.Sensor; +import org.apache.kafka.common.metrics.internals.IntGaugeSuite; +import org.apache.kafka.common.metrics.stats.Avg; +import org.apache.kafka.common.metrics.stats.CumulativeSum; +import org.apache.kafka.common.metrics.stats.Max; +import org.apache.kafka.common.metrics.stats.Meter; +import org.apache.kafka.common.metrics.stats.SampledStat; +import org.apache.kafka.common.metrics.stats.WindowedCount; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.common.utils.Utils; +import org.slf4j.Logger; + +import java.io.IOException; +import java.net.InetSocketAddress; +import java.net.Socket; +import java.nio.channels.CancelledKeyException; +import java.nio.channels.SelectionKey; +import java.nio.channels.SocketChannel; +import java.nio.channels.UnresolvedAddressException; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Iterator; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; + +/** + * A nioSelector interface for doing non-blocking multi-connection network I/O. + *

+ * This class works with {@link NetworkSend} and {@link NetworkReceive} to transmit size-delimited network requests and
+ * responses.
+ * <p>
+ * A connection can be added to the nioSelector associated with an integer id by doing
+ *
+ * <pre>
+ * nioSelector.connect("42", new InetSocketAddress("google.com", server.port), 64000, 64000);
+ * </pre>
+ *
+ * The connect call does not block on the creation of the TCP connection, so the connect method only begins initiating
+ * the connection. The successful invocation of this method does not mean a valid connection has been established.
+ *
+ * Sending requests, receiving responses, processing connection completions, and disconnections on the existing
+ * connections are all done using the poll() call.
+ *
+ * <pre>
+ * nioSelector.send(new NetworkSend(myDestination, myBytes));
+ * nioSelector.send(new NetworkSend(myOtherDestination, myOtherBytes));
+ * nioSelector.poll(TIMEOUT_MS);
+ * </pre>
        + * + * The nioSelector maintains several lists that are reset by each call to poll() which are available via + * various getters. These are reset by each call to poll(). + * + * This class is not thread safe! + */ +public class Selector implements Selectable, AutoCloseable { + + public static final long NO_IDLE_TIMEOUT_MS = -1; + public static final int NO_FAILED_AUTHENTICATION_DELAY = 0; + + private enum CloseMode { + GRACEFUL(true), // process outstanding buffered receives, notify disconnect + NOTIFY_ONLY(true), // discard any outstanding receives, notify disconnect + DISCARD_NO_NOTIFY(false); // discard any outstanding receives, no disconnect notification + + boolean notifyDisconnect; + + CloseMode(boolean notifyDisconnect) { + this.notifyDisconnect = notifyDisconnect; + } + } + + private final Logger log; + private final java.nio.channels.Selector nioSelector; + private final Map channels; + private final Set explicitlyMutedChannels; + private boolean outOfMemory; + private final List completedSends; + private final LinkedHashMap completedReceives; + private final Set immediatelyConnectedKeys; + private final Map closingChannels; + private Set keysWithBufferedRead; + private final Map disconnected; + private final List connected; + private final List failedSends; + private final Time time; + private final SelectorMetrics sensors; + private final ChannelBuilder channelBuilder; + private final int maxReceiveSize; + private final boolean recordTimePerConnection; + private final IdleExpiryManager idleExpiryManager; + private final LinkedHashMap delayedClosingChannels; + private final MemoryPool memoryPool; + private final long lowMemThreshold; + private final int failedAuthenticationDelayMs; + + //indicates if the previous call to poll was able to make progress in reading already-buffered data. + //this is used to prevent tight loops when memory is not available to read any more data + private boolean madeReadProgressLastPoll = true; + + /** + * Create a new nioSelector + * @param maxReceiveSize Max size in bytes of a single network receive (use {@link NetworkReceive#UNLIMITED} for no limit) + * @param connectionMaxIdleMs Max idle connection time (use {@link #NO_IDLE_TIMEOUT_MS} to disable idle timeout) + * @param failedAuthenticationDelayMs Minimum time by which failed authentication response and channel close should be delayed by. + * Use {@link #NO_FAILED_AUTHENTICATION_DELAY} to disable this delay. 
+ * @param metrics Registry for Selector metrics + * @param time Time implementation + * @param metricGrpPrefix Prefix for the group of metrics registered by Selector + * @param metricTags Additional tags to add to metrics registered by Selector + * @param metricsPerConnection Whether or not to enable per-connection metrics + * @param channelBuilder Channel builder for every new connection + * @param logContext Context for logging with additional info + */ + public Selector(int maxReceiveSize, + long connectionMaxIdleMs, + int failedAuthenticationDelayMs, + Metrics metrics, + Time time, + String metricGrpPrefix, + Map metricTags, + boolean metricsPerConnection, + boolean recordTimePerConnection, + ChannelBuilder channelBuilder, + MemoryPool memoryPool, + LogContext logContext) { + try { + this.nioSelector = java.nio.channels.Selector.open(); + } catch (IOException e) { + throw new KafkaException(e); + } + this.maxReceiveSize = maxReceiveSize; + this.time = time; + this.channels = new HashMap<>(); + this.explicitlyMutedChannels = new HashSet<>(); + this.outOfMemory = false; + this.completedSends = new ArrayList<>(); + this.completedReceives = new LinkedHashMap<>(); + this.immediatelyConnectedKeys = new HashSet<>(); + this.closingChannels = new HashMap<>(); + this.keysWithBufferedRead = new HashSet<>(); + this.connected = new ArrayList<>(); + this.disconnected = new HashMap<>(); + this.failedSends = new ArrayList<>(); + this.log = logContext.logger(Selector.class); + this.sensors = new SelectorMetrics(metrics, metricGrpPrefix, metricTags, metricsPerConnection); + this.channelBuilder = channelBuilder; + this.recordTimePerConnection = recordTimePerConnection; + this.idleExpiryManager = connectionMaxIdleMs < 0 ? null : new IdleExpiryManager(time, connectionMaxIdleMs); + this.memoryPool = memoryPool; + this.lowMemThreshold = (long) (0.1 * this.memoryPool.size()); + this.failedAuthenticationDelayMs = failedAuthenticationDelayMs; + this.delayedClosingChannels = (failedAuthenticationDelayMs > NO_FAILED_AUTHENTICATION_DELAY) ? 
new LinkedHashMap() : null; + } + + public Selector(int maxReceiveSize, + long connectionMaxIdleMs, + Metrics metrics, + Time time, + String metricGrpPrefix, + Map metricTags, + boolean metricsPerConnection, + boolean recordTimePerConnection, + ChannelBuilder channelBuilder, + MemoryPool memoryPool, + LogContext logContext) { + this(maxReceiveSize, connectionMaxIdleMs, NO_FAILED_AUTHENTICATION_DELAY, metrics, time, metricGrpPrefix, metricTags, + metricsPerConnection, recordTimePerConnection, channelBuilder, memoryPool, logContext); + } + + public Selector(int maxReceiveSize, + long connectionMaxIdleMs, + int failedAuthenticationDelayMs, + Metrics metrics, + Time time, + String metricGrpPrefix, + Map metricTags, + boolean metricsPerConnection, + ChannelBuilder channelBuilder, + LogContext logContext) { + this(maxReceiveSize, connectionMaxIdleMs, failedAuthenticationDelayMs, metrics, time, metricGrpPrefix, metricTags, metricsPerConnection, false, channelBuilder, MemoryPool.NONE, logContext); + } + + public Selector(int maxReceiveSize, + long connectionMaxIdleMs, + Metrics metrics, + Time time, + String metricGrpPrefix, + Map metricTags, + boolean metricsPerConnection, + ChannelBuilder channelBuilder, + LogContext logContext) { + this(maxReceiveSize, connectionMaxIdleMs, NO_FAILED_AUTHENTICATION_DELAY, metrics, time, metricGrpPrefix, metricTags, metricsPerConnection, channelBuilder, logContext); + } + + public Selector(long connectionMaxIdleMS, Metrics metrics, Time time, String metricGrpPrefix, ChannelBuilder channelBuilder, LogContext logContext) { + this(NetworkReceive.UNLIMITED, connectionMaxIdleMS, metrics, time, metricGrpPrefix, Collections.emptyMap(), true, channelBuilder, logContext); + } + + public Selector(long connectionMaxIdleMS, int failedAuthenticationDelayMs, Metrics metrics, Time time, String metricGrpPrefix, ChannelBuilder channelBuilder, LogContext logContext) { + this(NetworkReceive.UNLIMITED, connectionMaxIdleMS, failedAuthenticationDelayMs, metrics, time, metricGrpPrefix, Collections.emptyMap(), true, channelBuilder, logContext); + } + + /** + * Begin connecting to the given address and add the connection to this nioSelector associated with the given id + * number. + *

        + * Note that this call only initiates the connection, which will be completed on a future {@link #poll(long)} + * call. Check {@link #connected()} to see which (if any) connections have completed after a given poll call. + * @param id The id for the new connection + * @param address The address to connect to + * @param sendBufferSize The send buffer for the new connection + * @param receiveBufferSize The receive buffer for the new connection + * @throws IllegalStateException if there is already a connection for that id + * @throws IOException if DNS resolution fails on the hostname or if the broker is down + */ + @Override + public void connect(String id, InetSocketAddress address, int sendBufferSize, int receiveBufferSize) throws IOException { + ensureNotRegistered(id); + SocketChannel socketChannel = SocketChannel.open(); + SelectionKey key = null; + try { + configureSocketChannel(socketChannel, sendBufferSize, receiveBufferSize); + boolean connected = doConnect(socketChannel, address); + key = registerChannel(id, socketChannel, SelectionKey.OP_CONNECT); + + if (connected) { + // OP_CONNECT won't trigger for immediately connected channels + log.debug("Immediately connected to node {}", id); + immediatelyConnectedKeys.add(key); + key.interestOps(0); + } + } catch (IOException | RuntimeException e) { + if (key != null) + immediatelyConnectedKeys.remove(key); + channels.remove(id); + socketChannel.close(); + throw e; + } + } + + // Visible to allow test cases to override. In particular, we use this to implement a blocking connect + // in order to simulate "immediately connected" sockets. + protected boolean doConnect(SocketChannel channel, InetSocketAddress address) throws IOException { + try { + return channel.connect(address); + } catch (UnresolvedAddressException e) { + throw new IOException("Can't resolve address: " + address, e); + } + } + + private void configureSocketChannel(SocketChannel socketChannel, int sendBufferSize, int receiveBufferSize) + throws IOException { + socketChannel.configureBlocking(false); + Socket socket = socketChannel.socket(); + socket.setKeepAlive(true); + if (sendBufferSize != Selectable.USE_DEFAULT_BUFFER_SIZE) + socket.setSendBufferSize(sendBufferSize); + if (receiveBufferSize != Selectable.USE_DEFAULT_BUFFER_SIZE) + socket.setReceiveBufferSize(receiveBufferSize); + socket.setTcpNoDelay(true); + } + + /** + * Register the nioSelector with an existing channel + * Use this on server-side, when a connection is accepted by a different thread but processed by the Selector + *

        + * If a connection already exists with the same connection id in `channels` or `closingChannels`, + * an exception is thrown. Connection ids must be chosen to avoid conflict when remote ports are reused. + * Kafka brokers add an incrementing index to the connection id to avoid reuse in the timing window + * where an existing connection may not yet have been closed by the broker when a new connection with + * the same remote host:port is processed. + *

        + * If a `KafkaChannel` cannot be created for this connection, the `socketChannel` is closed + * and its selection key cancelled. + *

        + */ + public void register(String id, SocketChannel socketChannel) throws IOException { + ensureNotRegistered(id); + registerChannel(id, socketChannel, SelectionKey.OP_READ); + this.sensors.connectionCreated.record(); + // Default to empty client information as the ApiVersionsRequest is not + // mandatory. In this case, we still want to account for the connection. + ChannelMetadataRegistry metadataRegistry = this.channel(id).channelMetadataRegistry(); + if (metadataRegistry.clientInformation() == null) + metadataRegistry.registerClientInformation(ClientInformation.EMPTY); + } + + private void ensureNotRegistered(String id) { + if (this.channels.containsKey(id)) + throw new IllegalStateException("There is already a connection for id " + id); + if (this.closingChannels.containsKey(id)) + throw new IllegalStateException("There is already a connection for id " + id + " that is still being closed"); + } + + protected SelectionKey registerChannel(String id, SocketChannel socketChannel, int interestedOps) throws IOException { + SelectionKey key = socketChannel.register(nioSelector, interestedOps); + KafkaChannel channel = buildAndAttachKafkaChannel(socketChannel, id, key); + this.channels.put(id, channel); + if (idleExpiryManager != null) + idleExpiryManager.update(channel.id(), time.nanoseconds()); + return key; + } + + private KafkaChannel buildAndAttachKafkaChannel(SocketChannel socketChannel, String id, SelectionKey key) throws IOException { + try { + KafkaChannel channel = channelBuilder.buildChannel(id, key, maxReceiveSize, memoryPool, + new SelectorChannelMetadataRegistry()); + key.attach(channel); + return channel; + } catch (Exception e) { + try { + socketChannel.close(); + } finally { + key.cancel(); + } + throw new IOException("Channel could not be created for socket " + socketChannel, e); + } + } + + /** + * Interrupt the nioSelector if it is blocked waiting to do I/O. + */ + @Override + public void wakeup() { + this.nioSelector.wakeup(); + } + + /** + * Close this selector and all associated connections + */ + @Override + public void close() { + List connections = new ArrayList<>(channels.keySet()); + AtomicReference firstException = new AtomicReference<>(); + Utils.closeAllQuietly(firstException, "release connections", + connections.stream().map(id -> (AutoCloseable) () -> close(id)).toArray(AutoCloseable[]::new)); + // If there is any exception thrown in close(id), we should still be able + // to close the remaining objects, especially the sensors because keeping + // the sensors may lead to failure to start up the ReplicaFetcherThread if + // the old sensors with the same names has not yet been cleaned up. 
+ Utils.closeQuietly(nioSelector, "nioSelector", firstException);
+ Utils.closeQuietly(sensors, "sensors", firstException);
+ Utils.closeQuietly(channelBuilder, "channelBuilder", firstException);
+ Throwable exception = firstException.get();
+ if (exception instanceof RuntimeException && !(exception instanceof SecurityException)) {
+ throw (RuntimeException) exception;
+ }
+ }
+
+ /**
+ * Queue the given request for sending in the subsequent {@link #poll(long)} calls
+ * @param send The request to send
+ */
+ public void send(NetworkSend send) {
+ String connectionId = send.destinationId();
+ KafkaChannel channel = openOrClosingChannelOrFail(connectionId);
+ if (closingChannels.containsKey(connectionId)) {
+ // ensure notification via `disconnected`, leave channel in the state in which closing was triggered
+ this.failedSends.add(connectionId);
+ } else {
+ try {
+ channel.setSend(send);
+ } catch (Exception e) {
+ // update the state for consistency, the channel will be discarded after `close`
+ channel.state(ChannelState.FAILED_SEND);
+ // ensure notification via `disconnected` when `failedSends` are processed in the next poll
+ this.failedSends.add(connectionId);
+ close(channel, CloseMode.DISCARD_NO_NOTIFY);
+ if (!(e instanceof CancelledKeyException)) {
+ log.error("Unexpected exception during send, closing connection {} and rethrowing exception {}",
+ connectionId, e);
+ throw e;
+ }
+ }
+ }
+ }
+
+ /**
+ * Do whatever I/O can be done on each connection without blocking. This includes completing connections, completing
+ * disconnections, initiating new sends, or making progress on in-progress sends or receives.
+ *
+ * When this call is completed the user can check for completed sends, receives, connections or disconnects using
+ * {@link #completedSends()}, {@link #completedReceives()}, {@link #connected()}, {@link #disconnected()}. These
+ * lists will be cleared at the beginning of each `poll` call and repopulated by the call if there is
+ * any completed I/O.
+ *
+ * In the "Plaintext" setting, we use the socketChannel directly to read from and write to the network. In the "SSL"
+ * setting, we encrypt the data before writing it to the network via the socketChannel, and decrypt it before returning
+ * the responses. This requires additional buffers to be maintained while reading from the network: because the data on
+ * the wire is encrypted, we cannot read the exact number of bytes the Kafka protocol requires. Instead we read as many
+ * bytes as we can, up to the SSLEngine's application buffer size, which means we may read more bytes than requested.
+ * If there is no further data to read from the socketChannel, the selector will not select that channel again even
+ * though bytes remain in its buffers. To overcome this, the "keysWithBufferedRead" set tracks channels that still have
+ * data in their SSL buffers. If there are channels with buffered data that can be processed, we set "timeout" to 0 and
+ * process the data even if there is no more data to read from the socket.
+ *
+ * At most one entry is added to "completedReceives" for a channel in each poll. This is necessary to guarantee that
+ * requests from a channel are processed on the broker in the order they are sent. Since outstanding requests added
+ * by SocketServer to the request queue may be processed by different request handler threads, requests on each
+ * channel must be processed one-at-a-time to guarantee ordering.
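+ *
+ * A minimal client-side usage sketch, assuming `selector` is a configured instance of this class and
+ * `networkSend` was built elsewhere for connection "node-1" (the node id, address and `handle` call are
+ * placeholders):
+ * <pre>{@code
+ * selector.connect("node-1", new InetSocketAddress("localhost", 9092),
+ *         Selectable.USE_DEFAULT_BUFFER_SIZE, Selectable.USE_DEFAULT_BUFFER_SIZE);
+ * while (!selector.connected().contains("node-1"))
+ *     selector.poll(100);                         // completes the connection initiated above
+ * selector.send(networkSend);                     // queued, written by subsequent poll() calls
+ * selector.poll(100);
+ * selector.completedReceives().forEach(receive -> handle(receive));
+ * }</pre>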
+ * + * @param timeout The amount of time to wait, in milliseconds, which must be non-negative + * @throws IllegalArgumentException If `timeout` is negative + * @throws IllegalStateException If a send is given for which we have no existing connection or for which there is + * already an in-progress send + */ + @Override + public void poll(long timeout) throws IOException { + if (timeout < 0) + throw new IllegalArgumentException("timeout should be >= 0"); + + boolean madeReadProgressLastCall = madeReadProgressLastPoll; + clear(); + + boolean dataInBuffers = !keysWithBufferedRead.isEmpty(); + + if (!immediatelyConnectedKeys.isEmpty() || (madeReadProgressLastCall && dataInBuffers)) + timeout = 0; + + if (!memoryPool.isOutOfMemory() && outOfMemory) { + //we have recovered from memory pressure. unmute any channel not explicitly muted for other reasons + log.trace("Broker no longer low on memory - unmuting incoming sockets"); + for (KafkaChannel channel : channels.values()) { + if (channel.isInMutableState() && !explicitlyMutedChannels.contains(channel)) { + channel.maybeUnmute(); + } + } + outOfMemory = false; + } + + /* check ready keys */ + long startSelect = time.nanoseconds(); + int numReadyKeys = select(timeout); + long endSelect = time.nanoseconds(); + this.sensors.selectTime.record(endSelect - startSelect, time.milliseconds()); + + if (numReadyKeys > 0 || !immediatelyConnectedKeys.isEmpty() || dataInBuffers) { + Set readyKeys = this.nioSelector.selectedKeys(); + + // Poll from channels that have buffered data (but nothing more from the underlying socket) + if (dataInBuffers) { + keysWithBufferedRead.removeAll(readyKeys); //so no channel gets polled twice + Set toPoll = keysWithBufferedRead; + keysWithBufferedRead = new HashSet<>(); //poll() calls will repopulate if needed + pollSelectionKeys(toPoll, false, endSelect); + } + + // Poll from channels where the underlying socket has more data + pollSelectionKeys(readyKeys, false, endSelect); + // Clear all selected keys so that they are included in the ready count for the next select + readyKeys.clear(); + + pollSelectionKeys(immediatelyConnectedKeys, true, endSelect); + immediatelyConnectedKeys.clear(); + } else { + madeReadProgressLastPoll = true; //no work is also "progress" + } + + long endIo = time.nanoseconds(); + this.sensors.ioTime.record(endIo - endSelect, time.milliseconds()); + + // Close channels that were delayed and are now ready to be closed + completeDelayedChannelClose(endIo); + + // we use the time at the end of select to ensure that we don't close any connections that + // have just been processed in pollSelectionKeys + maybeCloseOldestConnection(endSelect); + } + + /** + * handle any ready I/O on a set of selection keys + * @param selectionKeys set of keys to handle + * @param isImmediatelyConnected true if running over a set of keys for just-connected sockets + * @param currentTimeNanos time at which set of keys was determined + */ + // package-private for testing + void pollSelectionKeys(Set selectionKeys, + boolean isImmediatelyConnected, + long currentTimeNanos) { + for (SelectionKey key : determineHandlingOrder(selectionKeys)) { + KafkaChannel channel = channel(key); + long channelStartTimeNanos = recordTimePerConnection ? 
time.nanoseconds() : 0; + boolean sendFailed = false; + String nodeId = channel.id(); + + // register all per-connection metrics at once + sensors.maybeRegisterConnectionMetrics(nodeId); + if (idleExpiryManager != null) + idleExpiryManager.update(nodeId, currentTimeNanos); + + try { + /* complete any connections that have finished their handshake (either normally or immediately) */ + if (isImmediatelyConnected || key.isConnectable()) { + if (channel.finishConnect()) { + this.connected.add(nodeId); + this.sensors.connectionCreated.record(); + + SocketChannel socketChannel = (SocketChannel) key.channel(); + log.debug("Created socket with SO_RCVBUF = {}, SO_SNDBUF = {}, SO_TIMEOUT = {} to node {}", + socketChannel.socket().getReceiveBufferSize(), + socketChannel.socket().getSendBufferSize(), + socketChannel.socket().getSoTimeout(), + nodeId); + } else { + continue; + } + } + + /* if channel is not ready finish prepare */ + if (channel.isConnected() && !channel.ready()) { + channel.prepare(); + if (channel.ready()) { + long readyTimeMs = time.milliseconds(); + boolean isReauthentication = channel.successfulAuthentications() > 1; + if (isReauthentication) { + sensors.successfulReauthentication.record(1.0, readyTimeMs); + if (channel.reauthenticationLatencyMs() == null) + log.warn( + "Should never happen: re-authentication latency for a re-authenticated channel was null; continuing..."); + else + sensors.reauthenticationLatency + .record(channel.reauthenticationLatencyMs().doubleValue(), readyTimeMs); + } else { + sensors.successfulAuthentication.record(1.0, readyTimeMs); + if (!channel.connectedClientSupportsReauthentication()) + sensors.successfulAuthenticationNoReauth.record(1.0, readyTimeMs); + } + log.debug("Successfully {}authenticated with {}", isReauthentication ? + "re-" : "", channel.socketDescription()); + } + } + if (channel.ready() && channel.state() == ChannelState.NOT_CONNECTED) + channel.state(ChannelState.READY); + Optional responseReceivedDuringReauthentication = channel.pollResponseReceivedDuringReauthentication(); + responseReceivedDuringReauthentication.ifPresent(receive -> { + long currentTimeMs = time.milliseconds(); + addToCompletedReceives(channel, receive, currentTimeMs); + }); + + //if channel is ready and has bytes to read from socket or buffer, and has no + //previous completed receive then read from it + if (channel.ready() && (key.isReadable() || channel.hasBytesBuffered()) && !hasCompletedReceive(channel) + && !explicitlyMutedChannels.contains(channel)) { + attemptRead(channel); + } + + if (channel.hasBytesBuffered() && !explicitlyMutedChannels.contains(channel)) { + //this channel has bytes enqueued in intermediary buffers that we could not read + //(possibly because no memory). it may be the case that the underlying socket will + //not come up in the next poll() and so we need to remember this channel for the + //next poll call otherwise data may be stuck in said buffers forever. If we attempt + //to process buffered data and no progress is made, the channel buffered status is + //cleared to avoid the overhead of checking every time. + keysWithBufferedRead.add(key); + } + + /* if channel is ready write to any sockets that have space in their buffer and for which we have data */ + + long nowNanos = channelStartTimeNanos != 0 ? 
channelStartTimeNanos : currentTimeNanos; + try { + attemptWrite(key, channel, nowNanos); + } catch (Exception e) { + sendFailed = true; + throw e; + } + + /* cancel any defunct sockets */ + if (!key.isValid()) + close(channel, CloseMode.GRACEFUL); + + } catch (Exception e) { + String desc = channel.socketDescription(); + if (e instanceof IOException) { + log.debug("Connection with {} disconnected", desc, e); + } else if (e instanceof AuthenticationException) { + boolean isReauthentication = channel.successfulAuthentications() > 0; + if (isReauthentication) + sensors.failedReauthentication.record(); + else + sensors.failedAuthentication.record(); + String exceptionMessage = e.getMessage(); + if (e instanceof DelayedResponseAuthenticationException) + exceptionMessage = e.getCause().getMessage(); + log.info("Failed {}authentication with {} ({})", isReauthentication ? "re-" : "", + desc, exceptionMessage); + } else { + log.warn("Unexpected error from {}; closing connection", desc, e); + } + + if (e instanceof DelayedResponseAuthenticationException) + maybeDelayCloseOnAuthenticationFailure(channel); + else + close(channel, sendFailed ? CloseMode.NOTIFY_ONLY : CloseMode.GRACEFUL); + } finally { + maybeRecordTimePerConnection(channel, channelStartTimeNanos); + } + } + } + + private void attemptWrite(SelectionKey key, KafkaChannel channel, long nowNanos) throws IOException { + if (channel.hasSend() + && channel.ready() + && key.isWritable() + && !channel.maybeBeginClientReauthentication(() -> nowNanos)) { + write(channel); + } + } + + // package-private for testing + void write(KafkaChannel channel) throws IOException { + String nodeId = channel.id(); + long bytesSent = channel.write(); + NetworkSend send = channel.maybeCompleteSend(); + // We may complete the send with bytesSent < 1 if `TransportLayer.hasPendingWrites` was true and `channel.write()` + // caused the pending writes to be written to the socket channel buffer + if (bytesSent > 0 || send != null) { + long currentTimeMs = time.milliseconds(); + if (bytesSent > 0) + this.sensors.recordBytesSent(nodeId, bytesSent, currentTimeMs); + if (send != null) { + this.completedSends.add(send); + this.sensors.recordCompletedSend(nodeId, send.size(), currentTimeMs); + } + } + } + + private Collection determineHandlingOrder(Set selectionKeys) { + //it is possible that the iteration order over selectionKeys is the same every invocation. + //this may cause starvation of reads when memory is low. to address this we shuffle the keys if memory is low. + if (!outOfMemory && memoryPool.availableMemory() < lowMemThreshold) { + List shuffledKeys = new ArrayList<>(selectionKeys); + Collections.shuffle(shuffledKeys); + return shuffledKeys; + } else { + return selectionKeys; + } + } + + private void attemptRead(KafkaChannel channel) throws IOException { + String nodeId = channel.id(); + + long bytesReceived = channel.read(); + if (bytesReceived != 0) { + long currentTimeMs = time.milliseconds(); + sensors.recordBytesReceived(nodeId, bytesReceived, currentTimeMs); + madeReadProgressLastPoll = true; + + NetworkReceive receive = channel.maybeCompleteReceive(); + if (receive != null) { + addToCompletedReceives(channel, receive, currentTimeMs); + } + } + if (channel.isMuted()) { + outOfMemory = true; //channel has muted itself due to memory pressure. 
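+ //once the memory pool has available capacity again, poll() will unmute these implicitly muted channels;
+ //channels in explicitlyMutedChannels stay muted until unmute() is called for them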
+ } else { + madeReadProgressLastPoll = true; + } + } + + private boolean maybeReadFromClosingChannel(KafkaChannel channel) { + boolean hasPending; + if (channel.state().state() != ChannelState.State.READY) + hasPending = false; + else if (explicitlyMutedChannels.contains(channel) || hasCompletedReceive(channel)) + hasPending = true; + else { + try { + attemptRead(channel); + hasPending = hasCompletedReceive(channel); + } catch (Exception e) { + log.trace("Read from closing channel failed, ignoring exception", e); + hasPending = false; + } + } + return hasPending; + } + + // Record time spent in pollSelectionKeys for channel (moved into a method to keep checkstyle happy) + private void maybeRecordTimePerConnection(KafkaChannel channel, long startTimeNanos) { + if (recordTimePerConnection) + channel.addNetworkThreadTimeNanos(time.nanoseconds() - startTimeNanos); + } + + @Override + public List completedSends() { + return this.completedSends; + } + + @Override + public Collection completedReceives() { + return this.completedReceives.values(); + } + + @Override + public Map disconnected() { + return this.disconnected; + } + + @Override + public List connected() { + return this.connected; + } + + @Override + public void mute(String id) { + KafkaChannel channel = openOrClosingChannelOrFail(id); + mute(channel); + } + + private void mute(KafkaChannel channel) { + channel.mute(); + explicitlyMutedChannels.add(channel); + keysWithBufferedRead.remove(channel.selectionKey()); + } + + @Override + public void unmute(String id) { + KafkaChannel channel = openOrClosingChannelOrFail(id); + unmute(channel); + } + + private void unmute(KafkaChannel channel) { + // Remove the channel from explicitlyMutedChannels only if the channel has been actually unmuted. + if (channel.maybeUnmute()) { + explicitlyMutedChannels.remove(channel); + if (channel.hasBytesBuffered()) { + keysWithBufferedRead.add(channel.selectionKey()); + } + } + } + + @Override + public void muteAll() { + for (KafkaChannel channel : this.channels.values()) + mute(channel); + } + + @Override + public void unmuteAll() { + for (KafkaChannel channel : this.channels.values()) + unmute(channel); + } + + // package-private for testing + void completeDelayedChannelClose(long currentTimeNanos) { + if (delayedClosingChannels == null) + return; + + while (!delayedClosingChannels.isEmpty()) { + DelayedAuthenticationFailureClose delayedClose = delayedClosingChannels.values().iterator().next(); + if (!delayedClose.tryClose(currentTimeNanos)) + break; + } + } + + private void maybeCloseOldestConnection(long currentTimeNanos) { + if (idleExpiryManager == null) + return; + + Map.Entry expiredConnection = idleExpiryManager.pollExpiredConnection(currentTimeNanos); + if (expiredConnection != null) { + String connectionId = expiredConnection.getKey(); + KafkaChannel channel = this.channels.get(connectionId); + if (channel != null) { + if (log.isTraceEnabled()) + log.trace("About to close the idle connection from {} due to being idle for {} millis", + connectionId, (currentTimeNanos - expiredConnection.getValue()) / 1000 / 1000); + channel.state(ChannelState.EXPIRED); + close(channel, CloseMode.GRACEFUL); + } + } + } + + /** + * Clears completed receives. This is used by SocketServer to remove references to + * receive buffers after processing completed receives, without waiting for the next + * poll(). + */ + public void clearCompletedReceives() { + this.completedReceives.clear(); + } + + /** + * Clears completed sends. 
This is used by SocketServer to remove references to + * send buffers after processing completed sends, without waiting for the next + * poll(). + */ + public void clearCompletedSends() { + this.completedSends.clear(); + } + + /** + * Clears all the results from the previous poll. This is invoked by Selector at the start of + * a poll() when all the results from the previous poll are expected to have been handled. + *

        + * SocketServer uses {@link #clearCompletedSends()} and {@link #clearCompletedReceives()} to + * clear `completedSends` and `completedReceives` as soon as they are processed to avoid + * holding onto large request/response buffers from multiple connections longer than necessary. + * Clients rely on Selector invoking {@link #clear()} at the start of each poll() since memory usage + * is less critical and clearing once-per-poll provides the flexibility to process these results in + * any order before the next poll. + */ + private void clear() { + this.completedSends.clear(); + this.completedReceives.clear(); + this.connected.clear(); + this.disconnected.clear(); + + // Remove closed channels after all their buffered receives have been processed or if a send was requested + for (Iterator> it = closingChannels.entrySet().iterator(); it.hasNext(); ) { + KafkaChannel channel = it.next().getValue(); + boolean sendFailed = failedSends.remove(channel.id()); + boolean hasPending = false; + if (!sendFailed) + hasPending = maybeReadFromClosingChannel(channel); + if (!hasPending || sendFailed) { + doClose(channel, true); + it.remove(); + } + } + + for (String channel : this.failedSends) + this.disconnected.put(channel, ChannelState.FAILED_SEND); + this.failedSends.clear(); + this.madeReadProgressLastPoll = false; + } + + /** + * Check for data, waiting up to the given timeout. + * + * @param timeoutMs Length of time to wait, in milliseconds, which must be non-negative + * @return The number of keys ready + */ + private int select(long timeoutMs) throws IOException { + if (timeoutMs < 0L) + throw new IllegalArgumentException("timeout should be >= 0"); + + if (timeoutMs == 0L) + return this.nioSelector.selectNow(); + else + return this.nioSelector.select(timeoutMs); + } + + /** + * Close the connection identified by the given id + */ + public void close(String id) { + KafkaChannel channel = this.channels.get(id); + if (channel != null) { + // There is no disconnect notification for local close, but updating + // channel state here anyway to avoid confusion. + channel.state(ChannelState.LOCAL_CLOSE); + close(channel, CloseMode.DISCARD_NO_NOTIFY); + } else { + KafkaChannel closingChannel = this.closingChannels.remove(id); + // Close any closing channel, leave the channel in the state in which closing was triggered + if (closingChannel != null) + doClose(closingChannel, false); + } + } + + private void maybeDelayCloseOnAuthenticationFailure(KafkaChannel channel) { + DelayedAuthenticationFailureClose delayedClose = new DelayedAuthenticationFailureClose(channel, failedAuthenticationDelayMs); + if (delayedClosingChannels != null) + delayedClosingChannels.put(channel.id(), delayedClose); + else + delayedClose.closeNow(); + } + + private void handleCloseOnAuthenticationFailure(KafkaChannel channel) { + try { + channel.completeCloseOnAuthenticationFailure(); + } catch (Exception e) { + log.error("Exception handling close on authentication failure node {}", channel.id(), e); + } finally { + close(channel, CloseMode.GRACEFUL); + } + } + + /** + * Begin closing this connection. + * If 'closeMode' is `CloseMode.GRACEFUL`, the channel is disconnected here, but outstanding receives + * are processed. The channel is closed when there are no outstanding receives or if a send is + * requested. For other values of `closeMode`, outstanding receives are discarded and the channel + * is closed immediately. 
+ * + * The channel will be added to disconnect list when it is actually closed if `closeMode.notifyDisconnect` + * is true. + */ + private void close(KafkaChannel channel, CloseMode closeMode) { + channel.disconnect(); + + // Ensure that `connected` does not have closed channels. This could happen if `prepare` throws an exception + // in the `poll` invocation when `finishConnect` succeeds + connected.remove(channel.id()); + + // Keep track of closed channels with pending receives so that all received records + // may be processed. For example, when producer with acks=0 sends some records and + // closes its connections, a single poll() in the broker may receive records and + // handle close(). When the remote end closes its connection, the channel is retained until + // a send fails or all outstanding receives are processed. Mute state of disconnected channels + // are tracked to ensure that requests are processed one-by-one by the broker to preserve ordering. + if (closeMode == CloseMode.GRACEFUL && maybeReadFromClosingChannel(channel)) { + closingChannels.put(channel.id(), channel); + log.debug("Tracking closing connection {} to process outstanding requests", channel.id()); + } else { + doClose(channel, closeMode.notifyDisconnect); + } + this.channels.remove(channel.id()); + + if (delayedClosingChannels != null) + delayedClosingChannels.remove(channel.id()); + + if (idleExpiryManager != null) + idleExpiryManager.remove(channel.id()); + } + + private void doClose(KafkaChannel channel, boolean notifyDisconnect) { + SelectionKey key = channel.selectionKey(); + try { + immediatelyConnectedKeys.remove(key); + keysWithBufferedRead.remove(key); + channel.close(); + } catch (IOException e) { + log.error("Exception closing connection to node {}:", channel.id(), e); + } finally { + key.cancel(); + key.attach(null); + } + + this.sensors.connectionClosed.record(); + this.explicitlyMutedChannels.remove(channel); + if (notifyDisconnect) + this.disconnected.put(channel.id(), channel.state()); + } + + /** + * check if channel is ready + */ + @Override + public boolean isChannelReady(String id) { + KafkaChannel channel = this.channels.get(id); + return channel != null && channel.ready(); + } + + private KafkaChannel openOrClosingChannelOrFail(String id) { + KafkaChannel channel = this.channels.get(id); + if (channel == null) + channel = this.closingChannels.get(id); + if (channel == null) + throw new IllegalStateException("Attempt to retrieve channel for which there is no connection. Connection id " + id + " existing connections " + channels.keySet()); + return channel; + } + + /** + * Return the selector channels. + */ + public List channels() { + return new ArrayList<>(channels.values()); + } + + /** + * Return the channel associated with this connection or `null` if there is no channel associated with the + * connection. + */ + public KafkaChannel channel(String id) { + return this.channels.get(id); + } + + /** + * Return the channel with the specified id if it was disconnected, but not yet closed + * since there are outstanding messages to be processed. 
+ */ + public KafkaChannel closingChannel(String id) { + return closingChannels.get(id); + } + + /** + * Returns the lowest priority channel chosen using the following sequence: + * 1) If one or more channels are in closing state, return any one of them + * 2) If idle expiry manager is enabled, return the least recently updated channel + * 3) Otherwise return any of the channels + * + * This method is used to close a channel to accommodate a new channel on the inter-broker listener + * when broker-wide `max.connections` limit is enabled. + */ + public KafkaChannel lowestPriorityChannel() { + KafkaChannel channel = null; + if (!closingChannels.isEmpty()) { + channel = closingChannels.values().iterator().next(); + } else if (idleExpiryManager != null && !idleExpiryManager.lruConnections.isEmpty()) { + String channelId = idleExpiryManager.lruConnections.keySet().iterator().next(); + channel = channel(channelId); + } else if (!channels.isEmpty()) { + channel = channels.values().iterator().next(); + } + return channel; + } + + /** + * Get the channel associated with selectionKey + */ + private KafkaChannel channel(SelectionKey key) { + return (KafkaChannel) key.attachment(); + } + + /** + * Check if given channel has a completed receive + */ + private boolean hasCompletedReceive(KafkaChannel channel) { + return completedReceives.containsKey(channel.id()); + } + + /** + * adds a receive to completed receives + */ + private void addToCompletedReceives(KafkaChannel channel, NetworkReceive networkReceive, long currentTimeMs) { + if (hasCompletedReceive(channel)) + throw new IllegalStateException("Attempting to add second completed receive to channel " + channel.id()); + + this.completedReceives.put(channel.id(), networkReceive); + sensors.recordCompletedReceive(channel.id(), networkReceive.size(), currentTimeMs); + } + + // only for testing + public Set keys() { + return new HashSet<>(nioSelector.keys()); + } + + + class SelectorChannelMetadataRegistry implements ChannelMetadataRegistry { + private CipherInformation cipherInformation; + private ClientInformation clientInformation; + + @Override + public void registerCipherInformation(final CipherInformation cipherInformation) { + if (this.cipherInformation != null) { + if (this.cipherInformation.equals(cipherInformation)) + return; + sensors.connectionsByCipher.decrement(this.cipherInformation); + } + + this.cipherInformation = cipherInformation; + sensors.connectionsByCipher.increment(cipherInformation); + } + + @Override + public CipherInformation cipherInformation() { + return cipherInformation; + } + + @Override + public void registerClientInformation(final ClientInformation clientInformation) { + if (this.clientInformation != null) { + if (this.clientInformation.equals(clientInformation)) + return; + sensors.connectionsByClient.decrement(this.clientInformation); + } + + this.clientInformation = clientInformation; + sensors.connectionsByClient.increment(clientInformation); + } + + @Override + public ClientInformation clientInformation() { + return clientInformation; + } + + @Override + public void close() { + if (this.cipherInformation != null) { + sensors.connectionsByCipher.decrement(this.cipherInformation); + this.cipherInformation = null; + } + + if (this.clientInformation != null) { + sensors.connectionsByClient.decrement(this.clientInformation); + this.clientInformation = null; + } + } + } + + class SelectorMetrics implements AutoCloseable { + private final Metrics metrics; + private final Map metricTags; + private final boolean 
metricsPerConnection; + private final String metricGrpName; + private final String perConnectionMetricGrpName; + + public final Sensor connectionClosed; + public final Sensor connectionCreated; + public final Sensor successfulAuthentication; + public final Sensor successfulReauthentication; + public final Sensor successfulAuthenticationNoReauth; + public final Sensor reauthenticationLatency; + public final Sensor failedAuthentication; + public final Sensor failedReauthentication; + public final Sensor bytesTransferred; + public final Sensor bytesSent; + public final Sensor requestsSent; + public final Sensor bytesReceived; + public final Sensor responsesReceived; + public final Sensor selectTime; + public final Sensor ioTime; + public final IntGaugeSuite connectionsByCipher; + public final IntGaugeSuite connectionsByClient; + + /* Names of metrics that are not registered through sensors */ + private final List topLevelMetricNames = new ArrayList<>(); + private final List sensors = new ArrayList<>(); + + public SelectorMetrics(Metrics metrics, String metricGrpPrefix, Map metricTags, boolean metricsPerConnection) { + this.metrics = metrics; + this.metricTags = metricTags; + this.metricsPerConnection = metricsPerConnection; + this.metricGrpName = metricGrpPrefix + "-metrics"; + this.perConnectionMetricGrpName = metricGrpPrefix + "-node-metrics"; + StringBuilder tagsSuffix = new StringBuilder(); + + for (Map.Entry tag: metricTags.entrySet()) { + tagsSuffix.append(tag.getKey()); + tagsSuffix.append("-"); + tagsSuffix.append(tag.getValue()); + } + + this.connectionClosed = sensor("connections-closed:" + tagsSuffix); + this.connectionClosed.add(createMeter(metrics, metricGrpName, metricTags, + "connection-close", "connections closed")); + + this.connectionCreated = sensor("connections-created:" + tagsSuffix); + this.connectionCreated.add(createMeter(metrics, metricGrpName, metricTags, + "connection-creation", "new connections established")); + + this.successfulAuthentication = sensor("successful-authentication:" + tagsSuffix); + this.successfulAuthentication.add(createMeter(metrics, metricGrpName, metricTags, + "successful-authentication", "connections with successful authentication")); + + this.successfulReauthentication = sensor("successful-reauthentication:" + tagsSuffix); + this.successfulReauthentication.add(createMeter(metrics, metricGrpName, metricTags, + "successful-reauthentication", "successful re-authentication of connections")); + + this.successfulAuthenticationNoReauth = sensor("successful-authentication-no-reauth:" + tagsSuffix); + MetricName successfulAuthenticationNoReauthMetricName = metrics.metricName( + "successful-authentication-no-reauth-total", metricGrpName, + "The total number of connections with successful authentication where the client does not support re-authentication", + metricTags); + this.successfulAuthenticationNoReauth.add(successfulAuthenticationNoReauthMetricName, new CumulativeSum()); + + this.failedAuthentication = sensor("failed-authentication:" + tagsSuffix); + this.failedAuthentication.add(createMeter(metrics, metricGrpName, metricTags, + "failed-authentication", "connections with failed authentication")); + + this.failedReauthentication = sensor("failed-reauthentication:" + tagsSuffix); + this.failedReauthentication.add(createMeter(metrics, metricGrpName, metricTags, + "failed-reauthentication", "failed re-authentication of connections")); + + this.reauthenticationLatency = sensor("reauthentication-latency:" + tagsSuffix); + MetricName 
reauthenticationLatencyMaxMetricName = metrics.metricName("reauthentication-latency-max", + metricGrpName, "The max latency observed due to re-authentication", + metricTags); + this.reauthenticationLatency.add(reauthenticationLatencyMaxMetricName, new Max()); + MetricName reauthenticationLatencyAvgMetricName = metrics.metricName("reauthentication-latency-avg", + metricGrpName, "The average latency observed due to re-authentication", + metricTags); + this.reauthenticationLatency.add(reauthenticationLatencyAvgMetricName, new Avg()); + + this.bytesTransferred = sensor("bytes-sent-received:" + tagsSuffix); + bytesTransferred.add(createMeter(metrics, metricGrpName, metricTags, new WindowedCount(), + "network-io", "network operations (reads or writes) on all connections")); + + this.bytesSent = sensor("bytes-sent:" + tagsSuffix, bytesTransferred); + this.bytesSent.add(createMeter(metrics, metricGrpName, metricTags, + "outgoing-byte", "outgoing bytes sent to all servers")); + + this.requestsSent = sensor("requests-sent:" + tagsSuffix); + this.requestsSent.add(createMeter(metrics, metricGrpName, metricTags, new WindowedCount(), + "request", "requests sent")); + MetricName metricName = metrics.metricName("request-size-avg", metricGrpName, "The average size of requests sent.", metricTags); + this.requestsSent.add(metricName, new Avg()); + metricName = metrics.metricName("request-size-max", metricGrpName, "The maximum size of any request sent.", metricTags); + this.requestsSent.add(metricName, new Max()); + + this.bytesReceived = sensor("bytes-received:" + tagsSuffix, bytesTransferred); + this.bytesReceived.add(createMeter(metrics, metricGrpName, metricTags, + "incoming-byte", "bytes read off all sockets")); + + this.responsesReceived = sensor("responses-received:" + tagsSuffix); + this.responsesReceived.add(createMeter(metrics, metricGrpName, metricTags, + new WindowedCount(), "response", "responses received")); + + this.selectTime = sensor("select-time:" + tagsSuffix); + this.selectTime.add(createMeter(metrics, metricGrpName, metricTags, + new WindowedCount(), "select", "times the I/O layer checked for new I/O to perform")); + metricName = metrics.metricName("io-wait-time-ns-avg", metricGrpName, "The average length of time the I/O thread spent waiting for a socket ready for reads or writes in nanoseconds.", metricTags); + this.selectTime.add(metricName, new Avg()); + this.selectTime.add(createIOThreadRatioMeterLegacy(metrics, metricGrpName, metricTags, "io-wait", "waiting")); + this.selectTime.add(createIOThreadRatioMeter(metrics, metricGrpName, metricTags, "io-wait", "waiting")); + + this.ioTime = sensor("io-time:" + tagsSuffix); + metricName = metrics.metricName("io-time-ns-avg", metricGrpName, "The average length of time for I/O per select call in nanoseconds.", metricTags); + this.ioTime.add(metricName, new Avg()); + this.ioTime.add(createIOThreadRatioMeterLegacy(metrics, metricGrpName, metricTags, "io", "doing I/O")); + this.ioTime.add(createIOThreadRatioMeter(metrics, metricGrpName, metricTags, "io", "doing I/O")); + + this.connectionsByCipher = new IntGaugeSuite<>(log, "sslCiphers", metrics, + cipherInformation -> { + Map tags = new LinkedHashMap<>(); + tags.put("cipher", cipherInformation.cipher()); + tags.put("protocol", cipherInformation.protocol()); + tags.putAll(metricTags); + return metrics.metricName("connections", metricGrpName, "The number of connections with this SSL cipher and protocol.", tags); + }, 100); + + this.connectionsByClient = new IntGaugeSuite<>(log, "clients", metrics, 
+ clientInformation -> { + Map tags = new LinkedHashMap<>(); + tags.put("clientSoftwareName", clientInformation.softwareName()); + tags.put("clientSoftwareVersion", clientInformation.softwareVersion()); + tags.putAll(metricTags); + return metrics.metricName("connections", metricGrpName, "The number of connections with this client and version.", tags); + }, 100); + + metricName = metrics.metricName("connection-count", metricGrpName, "The current number of active connections.", metricTags); + topLevelMetricNames.add(metricName); + this.metrics.addMetric(metricName, (config, now) -> channels.size()); + } + + private Meter createMeter(Metrics metrics, String groupName, Map metricTags, + SampledStat stat, String baseName, String descriptiveName) { + MetricName rateMetricName = metrics.metricName(baseName + "-rate", groupName, + String.format("The number of %s per second", descriptiveName), metricTags); + MetricName totalMetricName = metrics.metricName(baseName + "-total", groupName, + String.format("The total number of %s", descriptiveName), metricTags); + if (stat == null) + return new Meter(rateMetricName, totalMetricName); + else + return new Meter(stat, rateMetricName, totalMetricName); + } + + private Meter createMeter(Metrics metrics, String groupName, Map metricTags, + String baseName, String descriptiveName) { + return createMeter(metrics, groupName, metricTags, null, baseName, descriptiveName); + } + + /** + * This method generates `time-total` metrics but has a couple of deficiencies: no `-ns` suffix and no dash between basename + * and `time-toal` suffix. + * @deprecated use {{@link #createIOThreadRatioMeter(Metrics, String, Map, String, String)}} for new metrics instead + */ + @Deprecated + private Meter createIOThreadRatioMeterLegacy(Metrics metrics, String groupName, Map metricTags, + String baseName, String action) { + MetricName rateMetricName = metrics.metricName(baseName + "-ratio", groupName, + String.format("*Deprecated* The fraction of time the I/O thread spent %s", action), metricTags); + MetricName totalMetricName = metrics.metricName(baseName + "time-total", groupName, + String.format("*Deprecated* The total time the I/O thread spent %s", action), metricTags); + return new Meter(TimeUnit.NANOSECONDS, rateMetricName, totalMetricName); + } + + private Meter createIOThreadRatioMeter(Metrics metrics, String groupName, Map metricTags, + String baseName, String action) { + MetricName rateMetricName = metrics.metricName(baseName + "-ratio", groupName, + String.format("The fraction of time the I/O thread spent %s", action), metricTags); + MetricName totalMetricName = metrics.metricName(baseName + "-time-ns-total", groupName, + String.format("The total time the I/O thread spent %s", action), metricTags); + return new Meter(TimeUnit.NANOSECONDS, rateMetricName, totalMetricName); + } + + private Sensor sensor(String name, Sensor... 
parents) { + Sensor sensor = metrics.sensor(name, parents); + sensors.add(sensor); + return sensor; + } + + public void maybeRegisterConnectionMetrics(String connectionId) { + if (!connectionId.isEmpty() && metricsPerConnection) { + // if one sensor of the metrics has been registered for the connection, + // then all other sensors should have been registered; and vice versa + String nodeRequestName = "node-" + connectionId + ".requests-sent"; + Sensor nodeRequest = this.metrics.getSensor(nodeRequestName); + if (nodeRequest == null) { + Map tags = new LinkedHashMap<>(metricTags); + tags.put("node-id", "node-" + connectionId); + + nodeRequest = sensor(nodeRequestName); + nodeRequest.add(createMeter(metrics, perConnectionMetricGrpName, tags, new WindowedCount(), "request", "requests sent")); + MetricName metricName = metrics.metricName("request-size-avg", perConnectionMetricGrpName, "The average size of requests sent.", tags); + nodeRequest.add(metricName, new Avg()); + metricName = metrics.metricName("request-size-max", perConnectionMetricGrpName, "The maximum size of any request sent.", tags); + nodeRequest.add(metricName, new Max()); + + String bytesSentName = "node-" + connectionId + ".bytes-sent"; + Sensor bytesSent = sensor(bytesSentName); + bytesSent.add(createMeter(metrics, perConnectionMetricGrpName, tags, "outgoing-byte", "outgoing bytes")); + + String nodeResponseName = "node-" + connectionId + ".responses-received"; + Sensor nodeResponse = sensor(nodeResponseName); + nodeResponse.add(createMeter(metrics, perConnectionMetricGrpName, tags, new WindowedCount(), "response", "responses received")); + + String bytesReceivedName = "node-" + connectionId + ".bytes-received"; + Sensor bytesReceive = sensor(bytesReceivedName); + bytesReceive.add(createMeter(metrics, perConnectionMetricGrpName, tags, "incoming-byte", "incoming bytes")); + + String nodeTimeName = "node-" + connectionId + ".latency"; + Sensor nodeRequestTime = sensor(nodeTimeName); + metricName = metrics.metricName("request-latency-avg", perConnectionMetricGrpName, tags); + nodeRequestTime.add(metricName, new Avg()); + metricName = metrics.metricName("request-latency-max", perConnectionMetricGrpName, tags); + nodeRequestTime.add(metricName, new Max()); + } + } + } + + public void recordBytesSent(String connectionId, long bytes, long currentTimeMs) { + this.bytesSent.record(bytes, currentTimeMs); + if (!connectionId.isEmpty()) { + String bytesSentName = "node-" + connectionId + ".bytes-sent"; + Sensor bytesSent = this.metrics.getSensor(bytesSentName); + if (bytesSent != null) + bytesSent.record(bytes, currentTimeMs); + } + } + + public void recordCompletedSend(String connectionId, long totalBytes, long currentTimeMs) { + requestsSent.record(totalBytes, currentTimeMs); + if (!connectionId.isEmpty()) { + String nodeRequestName = "node-" + connectionId + ".requests-sent"; + Sensor nodeRequest = this.metrics.getSensor(nodeRequestName); + if (nodeRequest != null) + nodeRequest.record(totalBytes, currentTimeMs); + } + } + + public void recordBytesReceived(String connectionId, long bytes, long currentTimeMs) { + this.bytesReceived.record(bytes, currentTimeMs); + if (!connectionId.isEmpty()) { + String bytesReceivedName = "node-" + connectionId + ".bytes-received"; + Sensor bytesReceived = this.metrics.getSensor(bytesReceivedName); + if (bytesReceived != null) + bytesReceived.record(bytes, currentTimeMs); + } + } + + public void recordCompletedReceive(String connectionId, long totalBytes, long currentTimeMs) { + 
responsesReceived.record(totalBytes, currentTimeMs); + if (!connectionId.isEmpty()) { + String nodeRequestName = "node-" + connectionId + ".responses-received"; + Sensor nodeRequest = this.metrics.getSensor(nodeRequestName); + if (nodeRequest != null) + nodeRequest.record(totalBytes, currentTimeMs); + } + } + + public void close() { + for (MetricName metricName : topLevelMetricNames) + metrics.removeMetric(metricName); + for (Sensor sensor : sensors) + metrics.removeSensor(sensor.name()); + connectionsByCipher.close(); + connectionsByClient.close(); + } + } + + /** + * Encapsulate a channel that must be closed after a specific delay has elapsed due to authentication failure. + */ + private class DelayedAuthenticationFailureClose { + private final KafkaChannel channel; + private final long endTimeNanos; + private boolean closed; + + /** + * @param channel The channel whose close is being delayed + * @param delayMs The amount of time by which the operation should be delayed + */ + public DelayedAuthenticationFailureClose(KafkaChannel channel, int delayMs) { + this.channel = channel; + this.endTimeNanos = time.nanoseconds() + (delayMs * 1000L * 1000L); + this.closed = false; + } + + /** + * Try to close this channel if the delay has expired. + * @param currentTimeNanos The current time + * @return True if the delay has expired and the channel was closed; false otherwise + */ + public final boolean tryClose(long currentTimeNanos) { + if (endTimeNanos <= currentTimeNanos) + closeNow(); + return closed; + } + + /** + * Close the channel now, regardless of whether the delay has expired or not. + */ + public final void closeNow() { + if (closed) + throw new IllegalStateException("Attempt to close a channel that has already been closed"); + handleCloseOnAuthenticationFailure(channel); + closed = true; + } + } + + // helper class for tracking least recently used connections to enable idle connection closing + private static class IdleExpiryManager { + private final Map lruConnections; + private final long connectionsMaxIdleNanos; + private long nextIdleCloseCheckTime; + + public IdleExpiryManager(Time time, long connectionsMaxIdleMs) { + this.connectionsMaxIdleNanos = connectionsMaxIdleMs * 1000 * 1000; + // initial capacity and load factor are default, we set them explicitly because we want to set accessOrder = true + this.lruConnections = new LinkedHashMap<>(16, .75F, true); + this.nextIdleCloseCheckTime = time.nanoseconds() + this.connectionsMaxIdleNanos; + } + + public void update(String connectionId, long currentTimeNanos) { + lruConnections.put(connectionId, currentTimeNanos); + } + + public Map.Entry pollExpiredConnection(long currentTimeNanos) { + if (currentTimeNanos <= nextIdleCloseCheckTime) + return null; + + if (lruConnections.isEmpty()) { + nextIdleCloseCheckTime = currentTimeNanos + connectionsMaxIdleNanos; + return null; + } + + Map.Entry oldestConnectionEntry = lruConnections.entrySet().iterator().next(); + Long connectionLastActiveTime = oldestConnectionEntry.getValue(); + nextIdleCloseCheckTime = connectionLastActiveTime + connectionsMaxIdleNanos; + + if (currentTimeNanos > nextIdleCloseCheckTime) + return oldestConnectionEntry; + else + return null; + } + + public void remove(String connectionId) { + lruConnections.remove(connectionId); + } + } + + //package-private for testing + boolean isOutOfMemory() { + return outOfMemory; + } + + //package-private for testing + boolean isMadeReadProgressLastPoll() { + return madeReadProgressLastPoll; + } + + // package-private for testing + 
Map delayedClosingChannels() { + return delayedClosingChannels; + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/network/Send.java b/clients/src/main/java/org/apache/kafka/common/network/Send.java new file mode 100644 index 0000000..1a7b0a9 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/network/Send.java @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.network; + +import java.io.IOException; + +/** + * This interface models the in-progress sending of data. + */ +public interface Send { + + /** + * Is this send complete? + */ + boolean completed(); + + /** + * Write some as-yet unwritten bytes from this send to the provided channel. It may take multiple calls for the send + * to be completely written + * @param channel The Channel to write to + * @return The number of bytes written + * @throws IOException If the write fails + */ + long writeTo(TransferableChannel channel) throws IOException; + + /** + * Size of the send + */ + long size(); + +} diff --git a/clients/src/main/java/org/apache/kafka/common/network/SslChannelBuilder.java b/clients/src/main/java/org/apache/kafka/common/network/SslChannelBuilder.java new file mode 100644 index 0000000..4dabf0a --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/network/SslChannelBuilder.java @@ -0,0 +1,184 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.network; + +import org.apache.kafka.common.KafkaException; +import org.apache.kafka.common.config.SslConfigs; +import org.apache.kafka.common.config.internals.BrokerSecurityConfigs; +import org.apache.kafka.common.memory.MemoryPool; +import org.apache.kafka.common.security.auth.KafkaPrincipal; +import org.apache.kafka.common.security.auth.KafkaPrincipalBuilder; +import org.apache.kafka.common.security.auth.KafkaPrincipalSerde; +import org.apache.kafka.common.security.auth.SslAuthenticationContext; +import org.apache.kafka.common.security.ssl.SslFactory; +import org.apache.kafka.common.security.ssl.SslPrincipalMapper; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.common.utils.Utils; +import org.slf4j.Logger; + +import java.io.Closeable; +import java.io.IOException; +import java.net.InetAddress; +import java.nio.channels.SelectionKey; +import java.nio.channels.SocketChannel; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.function.Supplier; + +public class SslChannelBuilder implements ChannelBuilder, ListenerReconfigurable { + private final ListenerName listenerName; + private final boolean isInterBrokerListener; + private SslFactory sslFactory; + private Mode mode; + private Map configs; + private SslPrincipalMapper sslPrincipalMapper; + private final Logger log; + + /** + * Constructs an SSL channel builder. ListenerName is provided only + * for server channel builder and will be null for client channel builder. + */ + public SslChannelBuilder(Mode mode, + ListenerName listenerName, + boolean isInterBrokerListener, + LogContext logContext) { + this.mode = mode; + this.listenerName = listenerName; + this.isInterBrokerListener = isInterBrokerListener; + this.log = logContext.logger(getClass()); + } + + public void configure(Map configs) throws KafkaException { + try { + this.configs = configs; + String sslPrincipalMappingRules = (String) configs.get(BrokerSecurityConfigs.SSL_PRINCIPAL_MAPPING_RULES_CONFIG); + if (sslPrincipalMappingRules != null) + sslPrincipalMapper = SslPrincipalMapper.fromRules(sslPrincipalMappingRules); + this.sslFactory = new SslFactory(mode, null, isInterBrokerListener); + this.sslFactory.configure(this.configs); + } catch (KafkaException e) { + throw e; + } catch (Exception e) { + throw new KafkaException(e); + } + } + + @Override + public Set reconfigurableConfigs() { + return SslConfigs.RECONFIGURABLE_CONFIGS; + } + + @Override + public void validateReconfiguration(Map configs) { + sslFactory.validateReconfiguration(configs); + } + + @Override + public void reconfigure(Map configs) { + sslFactory.reconfigure(configs); + } + + @Override + public ListenerName listenerName() { + return listenerName; + } + + @Override + public KafkaChannel buildChannel(String id, SelectionKey key, int maxReceiveSize, + MemoryPool memoryPool, ChannelMetadataRegistry metadataRegistry) throws KafkaException { + try { + SslTransportLayer transportLayer = buildTransportLayer(sslFactory, id, key, metadataRegistry); + Supplier authenticatorCreator = () -> + new SslAuthenticator(configs, transportLayer, listenerName, sslPrincipalMapper); + return new KafkaChannel(id, transportLayer, authenticatorCreator, maxReceiveSize, + memoryPool != null ? 
memoryPool : MemoryPool.NONE, metadataRegistry); + } catch (Exception e) { + log.info("Failed to create channel due to ", e); + throw new KafkaException(e); + } + } + + @Override + public void close() { + if (sslFactory != null) sslFactory.close(); + } + + protected SslTransportLayer buildTransportLayer(SslFactory sslFactory, String id, SelectionKey key, ChannelMetadataRegistry metadataRegistry) throws IOException { + SocketChannel socketChannel = (SocketChannel) key.channel(); + return SslTransportLayer.create(id, key, sslFactory.createSslEngine(socketChannel.socket()), + metadataRegistry); + } + + /** + * Note that client SSL authentication is handled in {@link SslTransportLayer}. This class is only used + * to transform the derived principal using a {@link KafkaPrincipalBuilder} configured by the user. + */ + private static class SslAuthenticator implements Authenticator { + private final SslTransportLayer transportLayer; + private final KafkaPrincipalBuilder principalBuilder; + private final ListenerName listenerName; + + private SslAuthenticator(Map configs, SslTransportLayer transportLayer, ListenerName listenerName, SslPrincipalMapper sslPrincipalMapper) { + this.transportLayer = transportLayer; + this.principalBuilder = ChannelBuilders.createPrincipalBuilder(configs, null, sslPrincipalMapper); + this.listenerName = listenerName; + } + /** + * No-Op for plaintext authenticator + */ + @Override + public void authenticate() {} + + /** + * Constructs Principal using configured principalBuilder. + * @return the built principal + */ + @Override + public KafkaPrincipal principal() { + InetAddress clientAddress = transportLayer.socketChannel().socket().getInetAddress(); + // listenerName should only be null in Client mode where principal() should not be called + if (listenerName == null) + throw new IllegalStateException("Unexpected call to principal() when listenerName is null"); + SslAuthenticationContext context = new SslAuthenticationContext( + transportLayer.sslSession(), + clientAddress, + listenerName.value()); + return principalBuilder.build(context); + } + + @Override + public Optional principalSerde() { + return principalBuilder instanceof KafkaPrincipalSerde ? Optional.of((KafkaPrincipalSerde) principalBuilder) : Optional.empty(); + } + + @Override + public void close() throws IOException { + if (principalBuilder instanceof Closeable) + Utils.closeQuietly((Closeable) principalBuilder, "principal builder"); + } + + /** + * SslAuthenticator doesn't implement any additional authentication mechanism. + * @return true + */ + @Override + public boolean complete() { + return true; + } + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/network/SslTransportLayer.java b/clients/src/main/java/org/apache/kafka/common/network/SslTransportLayer.java new file mode 100644 index 0000000..b9879ad --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/network/SslTransportLayer.java @@ -0,0 +1,1006 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.network; + +import java.io.IOException; +import java.io.EOFException; +import java.nio.ByteBuffer; +import java.nio.channels.FileChannel; +import java.nio.channels.SocketChannel; +import java.nio.channels.SelectionKey; +import java.nio.channels.CancelledKeyException; + +import java.security.Principal; +import javax.net.ssl.SSLEngine; +import javax.net.ssl.SSLEngineResult; +import javax.net.ssl.SSLEngineResult.HandshakeStatus; +import javax.net.ssl.SSLEngineResult.Status; +import javax.net.ssl.SSLException; +import javax.net.ssl.SSLHandshakeException; +import javax.net.ssl.SSLKeyException; +import javax.net.ssl.SSLPeerUnverifiedException; +import javax.net.ssl.SSLProtocolException; +import javax.net.ssl.SSLSession; + +import org.apache.kafka.common.errors.SslAuthenticationException; +import org.apache.kafka.common.security.auth.KafkaPrincipal; +import org.apache.kafka.common.utils.ByteUtils; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.common.utils.ByteBufferUnmapper; +import org.apache.kafka.common.utils.Utils; +import org.slf4j.Logger; + +/* + * Transport layer for SSL communication + * + * + * TLS v1.3 notes: + * https://tools.ietf.org/html/rfc8446#section-4.6 : Post-Handshake Messages + * "TLS also allows other messages to be sent after the main handshake. + * These messages use a handshake content type and are encrypted under + * the appropriate application traffic key." 
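+ *
+ * For this transport layer that means post-handshake messages (for example the TLS 1.3 NewSessionTicket)
+ * may still arrive after the handshake has completed, which is why a separate POST_HANDSHAKE state is kept
+ * before READY (see the State enum below).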
+ */ +public class SslTransportLayer implements TransportLayer { + private enum State { + // Initial state + NOT_INITIALIZED, + // SSLEngine is in handshake mode + HANDSHAKE, + // SSL handshake failed, connection will be terminated + HANDSHAKE_FAILED, + // SSLEngine has completed handshake, post-handshake messages may be pending for TLSv1.3 + POST_HANDSHAKE, + // SSLEngine has completed handshake, any post-handshake messages have been processed for TLSv1.3 + // For TLSv1.3, we move the channel to READY state when incoming data is processed after handshake + READY, + // Channel is being closed + CLOSING + } + + private final String channelId; + private final SSLEngine sslEngine; + private final SelectionKey key; + private final SocketChannel socketChannel; + private final ChannelMetadataRegistry metadataRegistry; + private final Logger log; + + private HandshakeStatus handshakeStatus; + private SSLEngineResult handshakeResult; + private State state; + private SslAuthenticationException handshakeException; + private ByteBuffer netReadBuffer; + private ByteBuffer netWriteBuffer; + private ByteBuffer appReadBuffer; + private ByteBuffer fileChannelBuffer; + private boolean hasBytesBuffered; + + public static SslTransportLayer create(String channelId, SelectionKey key, SSLEngine sslEngine, + ChannelMetadataRegistry metadataRegistry) throws IOException { + return new SslTransportLayer(channelId, key, sslEngine, metadataRegistry); + } + + // Prefer `create`, only use this in tests + SslTransportLayer(String channelId, SelectionKey key, SSLEngine sslEngine, + ChannelMetadataRegistry metadataRegistry) { + this.channelId = channelId; + this.key = key; + this.socketChannel = (SocketChannel) key.channel(); + this.sslEngine = sslEngine; + this.state = State.NOT_INITIALIZED; + this.metadataRegistry = metadataRegistry; + + final LogContext logContext = new LogContext(String.format("[SslTransportLayer channelId=%s key=%s] ", channelId, key)); + this.log = logContext.logger(getClass()); + } + + // Visible for testing + protected void startHandshake() throws IOException { + if (state != State.NOT_INITIALIZED) + throw new IllegalStateException("startHandshake() can only be called once, state " + state); + + this.netReadBuffer = ByteBuffer.allocate(netReadBufferSize()); + this.netWriteBuffer = ByteBuffer.allocate(netWriteBufferSize()); + this.appReadBuffer = ByteBuffer.allocate(applicationBufferSize()); + netWriteBuffer.limit(0); + netReadBuffer.limit(0); + + state = State.HANDSHAKE; + //initiate handshake + sslEngine.beginHandshake(); + handshakeStatus = sslEngine.getHandshakeStatus(); + } + + @Override + public boolean ready() { + return state == State.POST_HANDSHAKE || state == State.READY; + } + + /** + * does socketChannel.finishConnect() + */ + @Override + public boolean finishConnect() throws IOException { + boolean connected = socketChannel.finishConnect(); + if (connected) + key.interestOps(key.interestOps() & ~SelectionKey.OP_CONNECT | SelectionKey.OP_READ); + return connected; + } + + /** + * disconnects selectionKey. + */ + @Override + public void disconnect() { + key.cancel(); + } + + @Override + public SocketChannel socketChannel() { + return socketChannel; + } + + @Override + public SelectionKey selectionKey() { + return key; + } + + @Override + public boolean isOpen() { + return socketChannel.isOpen(); + } + + @Override + public boolean isConnected() { + return socketChannel.isConnected(); + } + + /** + * Sends an SSL close message and closes socketChannel. 
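+ * (The close message is the TLS close_notify alert, produced by wrapping an empty buffer after
+ * `closeOutbound()`; a failure to send it is logged at debug level and otherwise ignored.)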
+ */ + @Override + public void close() throws IOException { + State prevState = state; + if (state == State.CLOSING) return; + state = State.CLOSING; + sslEngine.closeOutbound(); + try { + if (prevState != State.NOT_INITIALIZED && isConnected()) { + if (!flush(netWriteBuffer)) { + throw new IOException("Remaining data in the network buffer, can't send SSL close message."); + } + //prep the buffer for the close message + netWriteBuffer.clear(); + //perform the close, since we called sslEngine.closeOutbound + SSLEngineResult wrapResult = sslEngine.wrap(ByteUtils.EMPTY_BUF, netWriteBuffer); + //we should be in a close state + if (wrapResult.getStatus() != SSLEngineResult.Status.CLOSED) { + throw new IOException("Unexpected status returned by SSLEngine.wrap, expected CLOSED, received " + + wrapResult.getStatus() + ". Will not send close message to peer."); + } + netWriteBuffer.flip(); + flush(netWriteBuffer); + } + } catch (IOException ie) { + log.debug("Failed to send SSL Close message", ie); + } finally { + socketChannel.socket().close(); + socketChannel.close(); + netReadBuffer = null; + netWriteBuffer = null; + appReadBuffer = null; + if (fileChannelBuffer != null) { + ByteBufferUnmapper.unmap("fileChannelBuffer", fileChannelBuffer); + fileChannelBuffer = null; + } + } + } + + /** + * returns true if there are any pending contents in netWriteBuffer + */ + @Override + public boolean hasPendingWrites() { + return netWriteBuffer.hasRemaining(); + } + + /** + * Reads available bytes from socket channel to `netReadBuffer`. + * Visible for testing. + * @return number of bytes read + */ + protected int readFromSocketChannel() throws IOException { + return socketChannel.read(netReadBuffer); + } + + /** + * Flushes the buffer to the network, non blocking. + * Visible for testing. + * @param buf ByteBuffer + * @return boolean true if the buffer has been emptied out, false otherwise + * @throws IOException + */ + protected boolean flush(ByteBuffer buf) throws IOException { + int remaining = buf.remaining(); + if (remaining > 0) { + int written = socketChannel.write(buf); + return written >= remaining; + } + return true; + } + + /** + * Performs SSL handshake, non blocking. + * Before application data (kafka protocols) can be sent client & kafka broker must + * perform ssl handshake. + * During the handshake SSLEngine generates encrypted data that will be transported over socketChannel. + * Each SSLEngine operation generates SSLEngineResult , of which SSLEngineResult.handshakeStatus field is used to + * determine what operation needs to occur to move handshake along. + * A typical handshake might look like this. 
+ * +-------------+----------------------------------+-------------+ + * | client | SSL/TLS message | HSStatus | + * +-------------+----------------------------------+-------------+ + * | wrap() | ClientHello | NEED_UNWRAP | + * | unwrap() | ServerHello/Cert/ServerHelloDone | NEED_WRAP | + * | wrap() | ClientKeyExchange | NEED_WRAP | + * | wrap() | ChangeCipherSpec | NEED_WRAP | + * | wrap() | Finished | NEED_UNWRAP | + * | unwrap() | ChangeCipherSpec | NEED_UNWRAP | + * | unwrap() | Finished | FINISHED | + * +-------------+----------------------------------+-------------+ + * + * @throws IOException if read/write fails + * @throws SslAuthenticationException if handshake fails with an {@link SSLException} + */ + @Override + public void handshake() throws IOException { + if (state == State.NOT_INITIALIZED) { + try { + startHandshake(); + } catch (SSLException e) { + maybeProcessHandshakeFailure(e, false, null); + } + } + if (ready()) + throw renegotiationException(); + if (state == State.CLOSING) + throw closingException(); + + int read = 0; + boolean readable = key.isReadable(); + try { + // Read any available bytes before attempting any writes to ensure that handshake failures + // reported by the peer are processed even if writes fail (since peer closes connection + // if handshake fails) + if (readable) + read = readFromSocketChannel(); + + doHandshake(); + if (ready()) + updateBytesBuffered(true); + } catch (SSLException e) { + maybeProcessHandshakeFailure(e, true, null); + } catch (IOException e) { + maybeThrowSslAuthenticationException(); + + // This exception could be due to a write. If there is data available to unwrap in the buffer, or data available + // in the socket channel to read and unwrap, process the data so that any SSL handshake exceptions are reported. + try { + do { + handshakeUnwrap(false, true); + } while (readable && readFromSocketChannel() > 0); + } catch (SSLException e1) { + maybeProcessHandshakeFailure(e1, false, e); + } + + // If we get here, this is not a handshake failure, throw the original IOException + throw e; + } + + // Read from socket failed, so throw any pending handshake exception or EOF exception. 
+ if (read == -1) { + maybeThrowSslAuthenticationException(); + throw new EOFException("EOF during handshake, handshake status is " + handshakeStatus); + } + } + + @SuppressWarnings("fallthrough") + private void doHandshake() throws IOException { + boolean read = key.isReadable(); + boolean write = key.isWritable(); + handshakeStatus = sslEngine.getHandshakeStatus(); + if (!flush(netWriteBuffer)) { + key.interestOps(key.interestOps() | SelectionKey.OP_WRITE); + return; + } + // Throw any pending handshake exception since `netWriteBuffer` has been flushed + maybeThrowSslAuthenticationException(); + + switch (handshakeStatus) { + case NEED_TASK: + log.trace("SSLHandshake NEED_TASK channelId {}, appReadBuffer pos {}, netReadBuffer pos {}, netWriteBuffer pos {}", + channelId, appReadBuffer.position(), netReadBuffer.position(), netWriteBuffer.position()); + handshakeStatus = runDelegatedTasks(); + break; + case NEED_WRAP: + log.trace("SSLHandshake NEED_WRAP channelId {}, appReadBuffer pos {}, netReadBuffer pos {}, netWriteBuffer pos {}", + channelId, appReadBuffer.position(), netReadBuffer.position(), netWriteBuffer.position()); + handshakeResult = handshakeWrap(write); + if (handshakeResult.getStatus() == Status.BUFFER_OVERFLOW) { + int currentNetWriteBufferSize = netWriteBufferSize(); + netWriteBuffer.compact(); + netWriteBuffer = Utils.ensureCapacity(netWriteBuffer, currentNetWriteBufferSize); + netWriteBuffer.flip(); + if (netWriteBuffer.limit() >= currentNetWriteBufferSize) { + throw new IllegalStateException("Buffer overflow when available data size (" + netWriteBuffer.limit() + + ") >= network buffer size (" + currentNetWriteBufferSize + ")"); + } + } else if (handshakeResult.getStatus() == Status.BUFFER_UNDERFLOW) { + throw new IllegalStateException("Should not have received BUFFER_UNDERFLOW during handshake WRAP."); + } else if (handshakeResult.getStatus() == Status.CLOSED) { + throw new EOFException(); + } + log.trace("SSLHandshake NEED_WRAP channelId {}, handshakeResult {}, appReadBuffer pos {}, netReadBuffer pos {}, netWriteBuffer pos {}", + channelId, handshakeResult, appReadBuffer.position(), netReadBuffer.position(), netWriteBuffer.position()); + //if handshake status is not NEED_UNWRAP or unable to flush netWriteBuffer contents + //we will break here otherwise we can do need_unwrap in the same call. 
+ if (handshakeStatus != HandshakeStatus.NEED_UNWRAP || !flush(netWriteBuffer)) { + key.interestOps(key.interestOps() | SelectionKey.OP_WRITE); + break; + } + case NEED_UNWRAP: + log.trace("SSLHandshake NEED_UNWRAP channelId {}, appReadBuffer pos {}, netReadBuffer pos {}, netWriteBuffer pos {}", + channelId, appReadBuffer.position(), netReadBuffer.position(), netWriteBuffer.position()); + do { + handshakeResult = handshakeUnwrap(read, false); + if (handshakeResult.getStatus() == Status.BUFFER_OVERFLOW) { + int currentAppBufferSize = applicationBufferSize(); + appReadBuffer = Utils.ensureCapacity(appReadBuffer, currentAppBufferSize); + if (appReadBuffer.position() > currentAppBufferSize) { + throw new IllegalStateException("Buffer underflow when available data size (" + appReadBuffer.position() + + ") > packet buffer size (" + currentAppBufferSize + ")"); + } + } + } while (handshakeResult.getStatus() == Status.BUFFER_OVERFLOW); + if (handshakeResult.getStatus() == Status.BUFFER_UNDERFLOW) { + int currentNetReadBufferSize = netReadBufferSize(); + netReadBuffer = Utils.ensureCapacity(netReadBuffer, currentNetReadBufferSize); + if (netReadBuffer.position() >= currentNetReadBufferSize) { + throw new IllegalStateException("Buffer underflow when there is available data"); + } + } else if (handshakeResult.getStatus() == Status.CLOSED) { + throw new EOFException("SSL handshake status CLOSED during handshake UNWRAP"); + } + log.trace("SSLHandshake NEED_UNWRAP channelId {}, handshakeResult {}, appReadBuffer pos {}, netReadBuffer pos {}, netWriteBuffer pos {}", + channelId, handshakeResult, appReadBuffer.position(), netReadBuffer.position(), netWriteBuffer.position()); + + //if handshakeStatus completed than fall-through to finished status. + //after handshake is finished there is no data left to read/write in socketChannel. + //so the selector won't invoke this channel if we don't go through the handshakeFinished here. + if (handshakeStatus != HandshakeStatus.FINISHED) { + if (handshakeStatus == HandshakeStatus.NEED_WRAP) { + key.interestOps(key.interestOps() | SelectionKey.OP_WRITE); + } else if (handshakeStatus == HandshakeStatus.NEED_UNWRAP) { + key.interestOps(key.interestOps() & ~SelectionKey.OP_WRITE); + } + break; + } + case FINISHED: + handshakeFinished(); + break; + case NOT_HANDSHAKING: + handshakeFinished(); + break; + default: + throw new IllegalStateException(String.format("Unexpected status [%s]", handshakeStatus)); + } + } + + private SSLHandshakeException renegotiationException() { + return new SSLHandshakeException("Renegotiation is not supported"); + } + + private IllegalStateException closingException() { + throw new IllegalStateException("Channel is in closing state"); + } + + /** + * Executes the SSLEngine tasks needed. + * @return HandshakeStatus + */ + private HandshakeStatus runDelegatedTasks() { + for (;;) { + Runnable task = delegatedTask(); + if (task == null) { + break; + } + task.run(); + } + return sslEngine.getHandshakeStatus(); + } + + /** + * Checks if the handshake status is finished + * Sets the interestOps for the selectionKey. + */ + private void handshakeFinished() throws IOException { + // SSLEngine.getHandshakeStatus is transient and it doesn't record FINISHED status properly. + // It can move from FINISHED status to NOT_HANDSHAKING after the handshake is completed. 
+ // Hence we also need to check handshakeResult.getHandshakeStatus() if the handshake finished or not + if (handshakeResult.getHandshakeStatus() == HandshakeStatus.FINISHED) { + //we are complete if we have delivered the last packet + //remove OP_WRITE if we are complete, otherwise we still have data to write + if (netWriteBuffer.hasRemaining()) + key.interestOps(key.interestOps() | SelectionKey.OP_WRITE); + else { + state = sslEngine.getSession().getProtocol().equals("TLSv1.3") ? State.POST_HANDSHAKE : State.READY; + key.interestOps(key.interestOps() & ~SelectionKey.OP_WRITE); + SSLSession session = sslEngine.getSession(); + log.debug("SSL handshake completed successfully with peerHost '{}' peerPort {} peerPrincipal '{}' cipherSuite '{}'", + session.getPeerHost(), session.getPeerPort(), peerPrincipal(), session.getCipherSuite()); + metadataRegistry.registerCipherInformation( + new CipherInformation(session.getCipherSuite(), session.getProtocol())); + } + + log.trace("SSLHandshake FINISHED channelId {}, appReadBuffer pos {}, netReadBuffer pos {}, netWriteBuffer pos {} ", + channelId, appReadBuffer.position(), netReadBuffer.position(), netWriteBuffer.position()); + } else { + throw new IOException("NOT_HANDSHAKING during handshake"); + } + } + + /** + * Performs the WRAP function + * @param doWrite boolean + * @return SSLEngineResult + * @throws IOException + */ + private SSLEngineResult handshakeWrap(boolean doWrite) throws IOException { + log.trace("SSLHandshake handshakeWrap {}", channelId); + if (netWriteBuffer.hasRemaining()) + throw new IllegalStateException("handshakeWrap called with netWriteBuffer not empty"); + //this should never be called with a network buffer that contains data + //so we can clear it here. + netWriteBuffer.clear(); + SSLEngineResult result = sslEngine.wrap(ByteUtils.EMPTY_BUF, netWriteBuffer); + //prepare the results to be written + netWriteBuffer.flip(); + handshakeStatus = result.getHandshakeStatus(); + if (result.getStatus() == SSLEngineResult.Status.OK && + result.getHandshakeStatus() == HandshakeStatus.NEED_TASK) { + handshakeStatus = runDelegatedTasks(); + } + + if (doWrite) flush(netWriteBuffer); + return result; + } + + /** + * Perform handshake unwrap + * @param doRead boolean If true, read more from the socket channel + * @param ignoreHandshakeStatus If true, continue to unwrap if data available regardless of handshake status + * @return SSLEngineResult + * @throws IOException + */ + private SSLEngineResult handshakeUnwrap(boolean doRead, boolean ignoreHandshakeStatus) throws IOException { + log.trace("SSLHandshake handshakeUnwrap {}", channelId); + SSLEngineResult result; + int read = 0; + if (doRead) + read = readFromSocketChannel(); + boolean cont; + do { + //prepare the buffer with the incoming data + int position = netReadBuffer.position(); + netReadBuffer.flip(); + result = sslEngine.unwrap(netReadBuffer, appReadBuffer); + netReadBuffer.compact(); + handshakeStatus = result.getHandshakeStatus(); + if (result.getStatus() == SSLEngineResult.Status.OK && + result.getHandshakeStatus() == HandshakeStatus.NEED_TASK) { + handshakeStatus = runDelegatedTasks(); + } + cont = (result.getStatus() == SSLEngineResult.Status.OK && + handshakeStatus == HandshakeStatus.NEED_UNWRAP) || + (ignoreHandshakeStatus && netReadBuffer.position() != position); + log.trace("SSLHandshake handshakeUnwrap: handshakeStatus {} status {}", handshakeStatus, result.getStatus()); + } while (netReadBuffer.position() != 0 && cont); + + // Throw EOF exception for failed read after 
processing already received data + // so that handshake failures are reported correctly + if (read == -1) + throw new EOFException("EOF during handshake, handshake status is " + handshakeStatus); + + return result; + } + + + /** + * Reads a sequence of bytes from this channel into the given buffer. Reads as much as possible + * until either the dst buffer is full or there is no more data in the socket. + * + * @param dst The buffer into which bytes are to be transferred + * @return The number of bytes read, possible zero or -1 if the channel has reached end-of-stream + * and no more data is available + * @throws IOException if some other I/O error occurs + */ + @Override + public int read(ByteBuffer dst) throws IOException { + if (state == State.CLOSING) return -1; + else if (!ready()) return 0; + + //if we have unread decrypted data in appReadBuffer read that into dst buffer. + int read = 0; + if (appReadBuffer.position() > 0) { + read = readFromAppBuffer(dst); + } + + boolean readFromNetwork = false; + boolean isClosed = false; + // Each loop reads at most once from the socket. + while (dst.remaining() > 0) { + int netread = 0; + netReadBuffer = Utils.ensureCapacity(netReadBuffer, netReadBufferSize()); + if (netReadBuffer.remaining() > 0) { + netread = readFromSocketChannel(); + if (netread > 0) + readFromNetwork = true; + } + + while (netReadBuffer.position() > 0) { + netReadBuffer.flip(); + SSLEngineResult unwrapResult; + try { + unwrapResult = sslEngine.unwrap(netReadBuffer, appReadBuffer); + if (state == State.POST_HANDSHAKE && appReadBuffer.position() != 0) { + // For TLSv1.3, we have finished processing post-handshake messages since we are now processing data + state = State.READY; + } + } catch (SSLException e) { + // For TLSv1.3, handle SSL exceptions while processing post-handshake messages as authentication exceptions + if (state == State.POST_HANDSHAKE) { + state = State.HANDSHAKE_FAILED; + throw new SslAuthenticationException("Failed to process post-handshake messages", e); + } else + throw e; + } + netReadBuffer.compact(); + // handle ssl renegotiation. + if (unwrapResult.getHandshakeStatus() != HandshakeStatus.NOT_HANDSHAKING && + unwrapResult.getHandshakeStatus() != HandshakeStatus.FINISHED && + unwrapResult.getStatus() == Status.OK) { + log.error("Renegotiation requested, but it is not supported, channelId {}, " + + "appReadBuffer pos {}, netReadBuffer pos {}, netWriteBuffer pos {} handshakeStatus {}", channelId, + appReadBuffer.position(), netReadBuffer.position(), netWriteBuffer.position(), unwrapResult.getHandshakeStatus()); + throw renegotiationException(); + } + + if (unwrapResult.getStatus() == Status.OK) { + read += readFromAppBuffer(dst); + } else if (unwrapResult.getStatus() == Status.BUFFER_OVERFLOW) { + int currentApplicationBufferSize = applicationBufferSize(); + appReadBuffer = Utils.ensureCapacity(appReadBuffer, currentApplicationBufferSize); + if (appReadBuffer.position() >= currentApplicationBufferSize) { + throw new IllegalStateException("Buffer overflow when available data size (" + appReadBuffer.position() + + ") >= application buffer size (" + currentApplicationBufferSize + ")"); + } + + // appReadBuffer will extended upto currentApplicationBufferSize + // we need to read the existing content into dst before we can do unwrap again. If there are no space in dst + // we can break here. 
+ if (dst.hasRemaining()) + read += readFromAppBuffer(dst); + else + break; + } else if (unwrapResult.getStatus() == Status.BUFFER_UNDERFLOW) { + int currentNetReadBufferSize = netReadBufferSize(); + netReadBuffer = Utils.ensureCapacity(netReadBuffer, currentNetReadBufferSize); + if (netReadBuffer.position() >= currentNetReadBufferSize) { + throw new IllegalStateException("Buffer underflow when available data size (" + netReadBuffer.position() + + ") > packet buffer size (" + currentNetReadBufferSize + ")"); + } + break; + } else if (unwrapResult.getStatus() == Status.CLOSED) { + // If data has been read and unwrapped, return the data. Close will be handled on the next poll. + if (appReadBuffer.position() == 0 && read == 0) + throw new EOFException(); + else { + isClosed = true; + break; + } + } + } + if (read == 0 && netread < 0) + throw new EOFException("EOF during read"); + if (netread <= 0 || isClosed) + break; + } + updateBytesBuffered(readFromNetwork || read > 0); + // If data has been read and unwrapped, return the data even if end-of-stream, channel will be closed + // on a subsequent poll. + return read; + } + + + /** + * Reads a sequence of bytes from this channel into the given buffers. + * + * @param dsts - The buffers into which bytes are to be transferred. + * @return The number of bytes read, possibly zero, or -1 if the channel has reached end-of-stream. + * @throws IOException if some other I/O error occurs + */ + @Override + public long read(ByteBuffer[] dsts) throws IOException { + return read(dsts, 0, dsts.length); + } + + + /** + * Reads a sequence of bytes from this channel into a subsequence of the given buffers. + * @param dsts - The buffers into which bytes are to be transferred + * @param offset - The offset within the buffer array of the first buffer into which bytes are to be transferred; must be non-negative and no larger than dsts.length. + * @param length - The maximum number of buffers to be accessed; must be non-negative and no larger than dsts.length - offset + * @return The number of bytes read, possibly zero, or -1 if the channel has reached end-of-stream. + * @throws IOException if some other I/O error occurs + */ + @Override + public long read(ByteBuffer[] dsts, int offset, int length) throws IOException { + if ((offset < 0) || (length < 0) || (offset > dsts.length - length)) + throw new IndexOutOfBoundsException(); + + int totalRead = 0; + int i = offset; + while (i < length) { + if (dsts[i].hasRemaining()) { + int read = read(dsts[i]); + if (read > 0) + totalRead += read; + else + break; + } + if (!dsts[i].hasRemaining()) { + i++; + } + } + return totalRead; + } + + + /** + * Writes a sequence of bytes to this channel from the given buffer. 
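+ * <p>
+ * Plaintext from {@code src} is wrapped into {@code netWriteBuffer} and flushed to the socket in a
+ * loop, so fewer bytes than {@code src.remaining()} may be consumed if the socket cannot accept more
+ * data; the caller is expected to retry once the channel is writable again. A rough caller-side
+ * sketch (hypothetical names):
+ * <pre>
+ *   transportLayer.write(payload);                            // consumes what the engine/socket allow
+ *   if (payload.hasRemaining() || transportLayer.hasPendingWrites())
+ *       transportLayer.addInterestOps(SelectionKey.OP_WRITE); // finish on a later poll
+ * </pre>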
+ * + * @param src The buffer from which bytes are to be retrieved + * @return The number of bytes read from src, possibly zero, or -1 if the channel has reached end-of-stream + * @throws IOException If some other I/O error occurs + */ + @Override + public int write(ByteBuffer src) throws IOException { + if (state == State.CLOSING) + throw closingException(); + if (!ready()) + return 0; + + int written = 0; + while (flush(netWriteBuffer) && src.hasRemaining()) { + netWriteBuffer.clear(); + SSLEngineResult wrapResult = sslEngine.wrap(src, netWriteBuffer); + netWriteBuffer.flip(); + + //handle ssl renegotiation + if (wrapResult.getHandshakeStatus() != HandshakeStatus.NOT_HANDSHAKING && wrapResult.getStatus() == Status.OK) + throw renegotiationException(); + + if (wrapResult.getStatus() == Status.OK) { + written += wrapResult.bytesConsumed(); + } else if (wrapResult.getStatus() == Status.BUFFER_OVERFLOW) { + // BUFFER_OVERFLOW means that the last `wrap` call had no effect, so we expand the buffer and try again + netWriteBuffer = Utils.ensureCapacity(netWriteBuffer, netWriteBufferSize()); + netWriteBuffer.position(netWriteBuffer.limit()); + } else if (wrapResult.getStatus() == Status.BUFFER_UNDERFLOW) { + throw new IllegalStateException("SSL BUFFER_UNDERFLOW during write"); + } else if (wrapResult.getStatus() == Status.CLOSED) { + throw new EOFException(); + } + } + return written; + } + + /** + * Writes a sequence of bytes to this channel from the subsequence of the given buffers. + * + * @param srcs The buffers from which bytes are to be retrieved + * @param offset The offset within the buffer array of the first buffer from which bytes are to be retrieved; must be non-negative and no larger than srcs.length. + * @param length - The maximum number of buffers to be accessed; must be non-negative and no larger than srcs.length - offset. + * @return returns no.of bytes written , possibly zero. + * @throws IOException If some other I/O error occurs + */ + @Override + public long write(ByteBuffer[] srcs, int offset, int length) throws IOException { + if ((offset < 0) || (length < 0) || (offset > srcs.length - length)) + throw new IndexOutOfBoundsException(); + int totalWritten = 0; + int i = offset; + while (i < length) { + if (srcs[i].hasRemaining() || hasPendingWrites()) { + int written = write(srcs[i]); + if (written > 0) { + totalWritten += written; + } + } + if (!srcs[i].hasRemaining() && !hasPendingWrites()) { + i++; + } else { + // if we are unable to write the current buffer to socketChannel we should break, + // as we might have reached max socket send buffer size. + break; + } + } + return totalWritten; + } + + /** + * Writes a sequence of bytes to this channel from the given buffers. + * + * @param srcs The buffers from which bytes are to be retrieved + * @return returns no.of bytes consumed by SSLEngine.wrap , possibly zero. + * @throws IOException If some other I/O error occurs + */ + @Override + public long write(ByteBuffer[] srcs) throws IOException { + return write(srcs, 0, srcs.length); + } + + + /** + * SSLSession's peerPrincipal for the remote host. 
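+ * If the peer did not present a verified certificate (for example, when client authentication is not
+ * required), the {@code SSLPeerUnverifiedException} is caught and {@code KafkaPrincipal.ANONYMOUS} is
+ * returned instead of propagating an error.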
+ * @return Principal + */ + public Principal peerPrincipal() { + try { + return sslEngine.getSession().getPeerPrincipal(); + } catch (SSLPeerUnverifiedException se) { + log.debug("SSL peer is not authenticated, returning ANONYMOUS instead"); + return KafkaPrincipal.ANONYMOUS; + } + } + + /** + * returns an SSL Session after the handshake is established + * throws IllegalStateException if the handshake is not established + */ + public SSLSession sslSession() throws IllegalStateException { + return sslEngine.getSession(); + } + + /** + * Adds interestOps to SelectionKey of the TransportLayer + * @param ops SelectionKey interestOps + */ + @Override + public void addInterestOps(int ops) { + if (!key.isValid()) + throw new CancelledKeyException(); + else if (!ready()) + throw new IllegalStateException("handshake is not completed"); + + key.interestOps(key.interestOps() | ops); + } + + /** + * removes interestOps to SelectionKey of the TransportLayer + * @param ops SelectionKey interestOps + */ + @Override + public void removeInterestOps(int ops) { + if (!key.isValid()) + throw new CancelledKeyException(); + else if (!ready()) + throw new IllegalStateException("handshake is not completed"); + + key.interestOps(key.interestOps() & ~ops); + } + + + /** + * returns delegatedTask for the SSLEngine. + */ + protected Runnable delegatedTask() { + return sslEngine.getDelegatedTask(); + } + + /** + * transfers appReadBuffer contents (decrypted data) into dst bytebuffer + * @param dst ByteBuffer + */ + private int readFromAppBuffer(ByteBuffer dst) { + appReadBuffer.flip(); + int remaining = Math.min(appReadBuffer.remaining(), dst.remaining()); + if (remaining > 0) { + int limit = appReadBuffer.limit(); + appReadBuffer.limit(appReadBuffer.position() + remaining); + dst.put(appReadBuffer); + appReadBuffer.limit(limit); + } + appReadBuffer.compact(); + return remaining; + } + + protected int netReadBufferSize() { + return sslEngine.getSession().getPacketBufferSize(); + } + + protected int netWriteBufferSize() { + return sslEngine.getSession().getPacketBufferSize(); + } + + protected int applicationBufferSize() { + return sslEngine.getSession().getApplicationBufferSize(); + } + + protected ByteBuffer netReadBuffer() { + return netReadBuffer; + } + + // Visibility for testing + protected ByteBuffer appReadBuffer() { + return appReadBuffer; + } + + /** + * SSL exceptions are propagated as authentication failures so that clients can avoid + * retries and report the failure. If `flush` is true, exceptions are propagated after + * any pending outgoing bytes are flushed to ensure that the peer is notified of the failure. + */ + private void handshakeFailure(SSLException sslException, boolean flush) throws IOException { + //Release all resources such as internal buffers that SSLEngine is managing + sslEngine.closeOutbound(); + try { + sslEngine.closeInbound(); + } catch (SSLException e) { + log.debug("SSLEngine.closeInBound() raised an exception.", e); + } + + state = State.HANDSHAKE_FAILED; + handshakeException = new SslAuthenticationException("SSL handshake failed", sslException); + + // Attempt to flush any outgoing bytes. If flush doesn't complete, delay exception handling until outgoing bytes + // are flushed. If write fails because remote end has closed the channel, log the I/O exception and continue to + // handle the handshake failure as an authentication exception. 
+ try { + if (!flush || flush(netWriteBuffer)) + throw handshakeException; + } catch (IOException e) { + log.debug("Failed to flush all bytes before closing channel", e); + throw handshakeException; + } + } + + // SSL handshake failures are typically thrown as SSLHandshakeException, SSLProtocolException, + // SSLPeerUnverifiedException or SSLKeyException if the cause is known. These exceptions indicate + // authentication failures (e.g. configuration errors) which should not be retried. But the SSL engine + // may also throw exceptions using the base class SSLException in a few cases: + // a) If there are no matching ciphers or TLS version or the private key is invalid, client will be + // unable to process the server message and an SSLException is thrown: + // javax.net.ssl.SSLException: Unrecognized SSL message, plaintext connection? + // b) If server closes the connection gracefully during handshake, client may receive close_notify + // and and an SSLException is thrown: + // javax.net.ssl.SSLException: Received close_notify during handshake + // We want to handle a) as a non-retriable SslAuthenticationException and b) as a retriable IOException. + // To do this we need to rely on the exception string. Since it is safer to throw a retriable exception + // when we are not sure, we will treat only the first exception string as a handshake exception. + private void maybeProcessHandshakeFailure(SSLException sslException, boolean flush, IOException ioException) throws IOException { + if (sslException instanceof SSLHandshakeException || sslException instanceof SSLProtocolException || + sslException instanceof SSLPeerUnverifiedException || sslException instanceof SSLKeyException || + sslException.getMessage().contains("Unrecognized SSL message") || + sslException.getMessage().contains("Received fatal alert: ")) + handshakeFailure(sslException, flush); + else if (ioException == null) + throw sslException; + else { + log.debug("SSLException while unwrapping data after IOException, original IOException will be propagated", sslException); + throw ioException; + } + } + + // If handshake has already failed, throw the authentication exception. + private void maybeThrowSslAuthenticationException() { + if (handshakeException != null) + throw handshakeException; + } + + @Override + public boolean isMute() { + return key.isValid() && (key.interestOps() & SelectionKey.OP_READ) == 0; + } + + @Override + public boolean hasBytesBuffered() { + return hasBytesBuffered; + } + + // Update `hasBytesBuffered` status. If any bytes were read from the network or + // if data was returned from read, `hasBytesBuffered` is set to true if any buffered + // data is still remaining. If not, `hasBytesBuffered` is set to false since no progress + // can be made until more data is available to read from the network. 
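+    // A single network read can leave whole records sitting in `netReadBuffer`, or decrypted data in
+    // `appReadBuffer` that did not fit into the caller's buffer, so the poller relies on
+    // hasBytesBuffered() to keep processing this channel even when the socket reports nothing new.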
+ private void updateBytesBuffered(boolean madeProgress) { + if (madeProgress) + hasBytesBuffered = netReadBuffer.position() != 0 || appReadBuffer.position() != 0; + else + hasBytesBuffered = false; + } + + @Override + public long transferFrom(FileChannel fileChannel, long position, long count) throws IOException { + if (state == State.CLOSING) + throw closingException(); + if (state != State.READY) + return 0; + + if (!flush(netWriteBuffer)) + return 0; + + long channelSize = fileChannel.size(); + if (position > channelSize) + return 0; + int totalBytesToWrite = (int) Math.min(Math.min(count, channelSize - position), Integer.MAX_VALUE); + + if (fileChannelBuffer == null) { + // Pick a size that allows for reasonably efficient disk reads, keeps the memory overhead per connection + // manageable and can typically be drained in a single `write` call. The `netWriteBuffer` is typically 16k + // and the socket send buffer is 100k by default, so 32k is a good number given the mentioned trade-offs. + int transferSize = 32768; + // Allocate a direct buffer to avoid one heap to heap buffer copy. SSLEngine copies the source + // buffer (fileChannelBuffer) to the destination buffer (netWriteBuffer) and then encrypts in-place. + // FileChannel.read() to a heap buffer requires a copy from a direct buffer to a heap buffer, which is not + // useful here. + fileChannelBuffer = ByteBuffer.allocateDirect(transferSize); + // The loop below drains any remaining bytes from the buffer before reading from disk, so we ensure there + // are no remaining bytes in the empty buffer + fileChannelBuffer.position(fileChannelBuffer.limit()); + } + + int totalBytesWritten = 0; + long pos = position; + try { + while (totalBytesWritten < totalBytesToWrite) { + if (!fileChannelBuffer.hasRemaining()) { + fileChannelBuffer.clear(); + int bytesRemaining = totalBytesToWrite - totalBytesWritten; + if (bytesRemaining < fileChannelBuffer.limit()) + fileChannelBuffer.limit(bytesRemaining); + int bytesRead = fileChannel.read(fileChannelBuffer, pos); + if (bytesRead <= 0) + break; + fileChannelBuffer.flip(); + } + int networkBytesWritten = write(fileChannelBuffer); + totalBytesWritten += networkBytesWritten; + // In the case of a partial write we only return the written bytes to the caller. As a result, the + // `position` passed in the next `transferFrom` call won't include the bytes remaining in + // `fileChannelBuffer`. By draining `fileChannelBuffer` first, we ensure we update `pos` before + // we invoke `fileChannel.read`. + if (fileChannelBuffer.hasRemaining()) + break; + pos += networkBytesWritten; + } + return totalBytesWritten; + } catch (IOException e) { + if (totalBytesWritten > 0) + return totalBytesWritten; + throw e; + } + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/network/TransferableChannel.java b/clients/src/main/java/org/apache/kafka/common/network/TransferableChannel.java new file mode 100644 index 0000000..2c635f0 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/network/TransferableChannel.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.network; + +import java.io.IOException; +import java.nio.channels.FileChannel; +import java.nio.channels.GatheringByteChannel; + +/** + * Extends GatheringByteChannel with the minimal set of methods required by the Send interface. Supporting TLS and + * efficient zero copy transfers are the main reasons for the additional methods. + * + * @see SslTransportLayer + */ +public interface TransferableChannel extends GatheringByteChannel { + + /** + * @return true if there are any pending writes. false if the implementation directly write all data to output. + */ + boolean hasPendingWrites(); + + /** + * Transfers bytes from `fileChannel` to this `TransferableChannel`. + * + * This method will delegate to {@link FileChannel#transferTo(long, long, java.nio.channels.WritableByteChannel)}, + * but it will unwrap the destination channel, if possible, in order to benefit from zero copy. This is required + * because the fast path of `transferTo` is only executed if the destination buffer inherits from an internal JDK + * class. + * + * @param fileChannel The source channel + * @param position The position within the file at which the transfer is to begin; must be non-negative + * @param count The maximum number of bytes to be transferred; must be non-negative + * @return The number of bytes, possibly zero, that were actually transferred + * @see FileChannel#transferTo(long, long, java.nio.channels.WritableByteChannel) + */ + long transferFrom(FileChannel fileChannel, long position, long count) throws IOException; +} diff --git a/clients/src/main/java/org/apache/kafka/common/network/TransportLayer.java b/clients/src/main/java/org/apache/kafka/common/network/TransportLayer.java new file mode 100644 index 0000000..db46f3b --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/network/TransportLayer.java @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.network; + +/* + * Transport layer for underlying communication. + * At very basic level it is wrapper around SocketChannel and can be used as substitute for SocketChannel + * and other network Channel implementations. + * As NetworkClient replaces BlockingChannel and other implementations we will be using KafkaChannel as + * a network I/O channel. 
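+ *
+ * A rough sketch of how a caller might use the abstraction without knowing whether the underlying
+ * implementation is plaintext or SSL (variable names are illustrative only):
+ *
+ *   TransportLayer transportLayer = ...;       // e.g. created by the channel builder for the listener
+ *   if (!transportLayer.ready())
+ *       transportLayer.handshake();            // no-op for PLAINTEXT, SSL handshake otherwise
+ *   if (transportLayer.ready())
+ *       transportLayer.read(buffer);           // bytes arrive already decrypted when SSL is in use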
+ */ + +import org.apache.kafka.common.errors.AuthenticationException; + +import java.io.IOException; +import java.nio.channels.ScatteringByteChannel; +import java.nio.channels.SelectionKey; +import java.nio.channels.SocketChannel; +import java.security.Principal; + +public interface TransportLayer extends ScatteringByteChannel, TransferableChannel { + + /** + * Returns true if the channel has handshake and authentication done. + */ + boolean ready(); + + /** + * Finishes the process of connecting a socket channel. + */ + boolean finishConnect() throws IOException; + + /** + * disconnect socketChannel + */ + void disconnect(); + + /** + * Tells whether or not this channel's network socket is connected. + */ + boolean isConnected(); + + /** + * returns underlying socketChannel + */ + SocketChannel socketChannel(); + + /** + * Get the underlying selection key + */ + SelectionKey selectionKey(); + + /** + * This a no-op for the non-secure PLAINTEXT implementation. For SSL, this performs + * SSL handshake. The SSL handshake includes client authentication if configured using + * {@link org.apache.kafka.common.config.SslConfigs#SSL_CLIENT_AUTH_CONFIG}. + * @throws AuthenticationException if handshake fails due to an {@link javax.net.ssl.SSLException}. + * @throws IOException if read or write fails with an I/O error. + */ + void handshake() throws AuthenticationException, IOException; + + /** + * Returns `SSLSession.getPeerPrincipal()` if this is an SslTransportLayer and there is an authenticated peer, + * `KafkaPrincipal.ANONYMOUS` is returned otherwise. + */ + Principal peerPrincipal() throws IOException; + + void addInterestOps(int ops); + + void removeInterestOps(int ops); + + boolean isMute(); + + /** + * @return true if channel has bytes to be read in any intermediate buffers + * which may be processed without reading additional data from the network. + */ + boolean hasBytesBuffered(); +} diff --git a/clients/src/main/java/org/apache/kafka/common/protocol/ApiKeys.java b/clients/src/main/java/org/apache/kafka/common/protocol/ApiKeys.java new file mode 100644 index 0000000..428e4a8 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/protocol/ApiKeys.java @@ -0,0 +1,281 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.protocol; + +import org.apache.kafka.common.message.ApiMessageType; +import org.apache.kafka.common.protocol.types.Schema; +import org.apache.kafka.common.protocol.types.Type; +import org.apache.kafka.common.record.RecordBatch; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.EnumMap; +import java.util.EnumSet; +import java.util.List; +import java.util.Map; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.Function; +import java.util.stream.Collectors; + +import static org.apache.kafka.common.protocol.types.Type.BYTES; +import static org.apache.kafka.common.protocol.types.Type.COMPACT_BYTES; +import static org.apache.kafka.common.protocol.types.Type.COMPACT_NULLABLE_BYTES; +import static org.apache.kafka.common.protocol.types.Type.NULLABLE_BYTES; +import static org.apache.kafka.common.protocol.types.Type.RECORDS; + +/** + * Identifiers for all the Kafka APIs + */ +public enum ApiKeys { + PRODUCE(ApiMessageType.PRODUCE), + FETCH(ApiMessageType.FETCH), + LIST_OFFSETS(ApiMessageType.LIST_OFFSETS), + METADATA(ApiMessageType.METADATA), + LEADER_AND_ISR(ApiMessageType.LEADER_AND_ISR, true), + STOP_REPLICA(ApiMessageType.STOP_REPLICA, true), + UPDATE_METADATA(ApiMessageType.UPDATE_METADATA, true), + CONTROLLED_SHUTDOWN(ApiMessageType.CONTROLLED_SHUTDOWN, true), + OFFSET_COMMIT(ApiMessageType.OFFSET_COMMIT), + OFFSET_FETCH(ApiMessageType.OFFSET_FETCH), + FIND_COORDINATOR(ApiMessageType.FIND_COORDINATOR), + JOIN_GROUP(ApiMessageType.JOIN_GROUP), + HEARTBEAT(ApiMessageType.HEARTBEAT), + LEAVE_GROUP(ApiMessageType.LEAVE_GROUP), + SYNC_GROUP(ApiMessageType.SYNC_GROUP), + DESCRIBE_GROUPS(ApiMessageType.DESCRIBE_GROUPS), + LIST_GROUPS(ApiMessageType.LIST_GROUPS), + SASL_HANDSHAKE(ApiMessageType.SASL_HANDSHAKE), + API_VERSIONS(ApiMessageType.API_VERSIONS), + CREATE_TOPICS(ApiMessageType.CREATE_TOPICS, false, true), + DELETE_TOPICS(ApiMessageType.DELETE_TOPICS, false, true), + DELETE_RECORDS(ApiMessageType.DELETE_RECORDS), + INIT_PRODUCER_ID(ApiMessageType.INIT_PRODUCER_ID), + OFFSET_FOR_LEADER_EPOCH(ApiMessageType.OFFSET_FOR_LEADER_EPOCH), + ADD_PARTITIONS_TO_TXN(ApiMessageType.ADD_PARTITIONS_TO_TXN, false, RecordBatch.MAGIC_VALUE_V2, false), + ADD_OFFSETS_TO_TXN(ApiMessageType.ADD_OFFSETS_TO_TXN, false, RecordBatch.MAGIC_VALUE_V2, false), + END_TXN(ApiMessageType.END_TXN, false, RecordBatch.MAGIC_VALUE_V2, false), + WRITE_TXN_MARKERS(ApiMessageType.WRITE_TXN_MARKERS, true, RecordBatch.MAGIC_VALUE_V2, false), + TXN_OFFSET_COMMIT(ApiMessageType.TXN_OFFSET_COMMIT, false, RecordBatch.MAGIC_VALUE_V2, false), + DESCRIBE_ACLS(ApiMessageType.DESCRIBE_ACLS), + CREATE_ACLS(ApiMessageType.CREATE_ACLS, false, true), + DELETE_ACLS(ApiMessageType.DELETE_ACLS, false, true), + DESCRIBE_CONFIGS(ApiMessageType.DESCRIBE_CONFIGS), + ALTER_CONFIGS(ApiMessageType.ALTER_CONFIGS, false, true), + ALTER_REPLICA_LOG_DIRS(ApiMessageType.ALTER_REPLICA_LOG_DIRS), + DESCRIBE_LOG_DIRS(ApiMessageType.DESCRIBE_LOG_DIRS), + SASL_AUTHENTICATE(ApiMessageType.SASL_AUTHENTICATE), + CREATE_PARTITIONS(ApiMessageType.CREATE_PARTITIONS, false, true), + CREATE_DELEGATION_TOKEN(ApiMessageType.CREATE_DELEGATION_TOKEN, false, true), + RENEW_DELEGATION_TOKEN(ApiMessageType.RENEW_DELEGATION_TOKEN, false, true), + EXPIRE_DELEGATION_TOKEN(ApiMessageType.EXPIRE_DELEGATION_TOKEN, false, true), + DESCRIBE_DELEGATION_TOKEN(ApiMessageType.DESCRIBE_DELEGATION_TOKEN), + DELETE_GROUPS(ApiMessageType.DELETE_GROUPS), + ELECT_LEADERS(ApiMessageType.ELECT_LEADERS, false, 
true),
+    INCREMENTAL_ALTER_CONFIGS(ApiMessageType.INCREMENTAL_ALTER_CONFIGS, false, true),
+    ALTER_PARTITION_REASSIGNMENTS(ApiMessageType.ALTER_PARTITION_REASSIGNMENTS, false, true),
+    LIST_PARTITION_REASSIGNMENTS(ApiMessageType.LIST_PARTITION_REASSIGNMENTS, false, true),
+    OFFSET_DELETE(ApiMessageType.OFFSET_DELETE),
+    DESCRIBE_CLIENT_QUOTAS(ApiMessageType.DESCRIBE_CLIENT_QUOTAS),
+    ALTER_CLIENT_QUOTAS(ApiMessageType.ALTER_CLIENT_QUOTAS, false, true),
+    DESCRIBE_USER_SCRAM_CREDENTIALS(ApiMessageType.DESCRIBE_USER_SCRAM_CREDENTIALS),
+    ALTER_USER_SCRAM_CREDENTIALS(ApiMessageType.ALTER_USER_SCRAM_CREDENTIALS, false, true),
+    VOTE(ApiMessageType.VOTE, true, RecordBatch.MAGIC_VALUE_V0, false),
+    BEGIN_QUORUM_EPOCH(ApiMessageType.BEGIN_QUORUM_EPOCH, true, RecordBatch.MAGIC_VALUE_V0, false),
+    END_QUORUM_EPOCH(ApiMessageType.END_QUORUM_EPOCH, true, RecordBatch.MAGIC_VALUE_V0, false),
+    DESCRIBE_QUORUM(ApiMessageType.DESCRIBE_QUORUM, true, RecordBatch.MAGIC_VALUE_V0, true),
+    ALTER_ISR(ApiMessageType.ALTER_ISR, true),
+    UPDATE_FEATURES(ApiMessageType.UPDATE_FEATURES, false, true),
+    ENVELOPE(ApiMessageType.ENVELOPE, true, RecordBatch.MAGIC_VALUE_V0, false),
+    FETCH_SNAPSHOT(ApiMessageType.FETCH_SNAPSHOT, false, RecordBatch.MAGIC_VALUE_V0, false),
+    DESCRIBE_CLUSTER(ApiMessageType.DESCRIBE_CLUSTER),
+    DESCRIBE_PRODUCERS(ApiMessageType.DESCRIBE_PRODUCERS),
+    BROKER_REGISTRATION(ApiMessageType.BROKER_REGISTRATION, true, RecordBatch.MAGIC_VALUE_V0, false),
+    BROKER_HEARTBEAT(ApiMessageType.BROKER_HEARTBEAT, true, RecordBatch.MAGIC_VALUE_V0, false),
+    UNREGISTER_BROKER(ApiMessageType.UNREGISTER_BROKER, false, RecordBatch.MAGIC_VALUE_V0, true),
+    DESCRIBE_TRANSACTIONS(ApiMessageType.DESCRIBE_TRANSACTIONS),
+    LIST_TRANSACTIONS(ApiMessageType.LIST_TRANSACTIONS),
+    ALLOCATE_PRODUCER_IDS(ApiMessageType.ALLOCATE_PRODUCER_IDS, true, true);
+
+    private static final Map<ApiMessageType.ListenerType, EnumSet<ApiKeys>> APIS_BY_LISTENER =
+        new EnumMap<>(ApiMessageType.ListenerType.class);
+
+    static {
+        for (ApiMessageType.ListenerType listenerType : ApiMessageType.ListenerType.values()) {
+            APIS_BY_LISTENER.put(listenerType, filterApisForListener(listenerType));
+        }
+    }
+
+    // The generator ensures every `ApiMessageType` has a unique id
+    private static final Map<Integer, ApiKeys> ID_TO_TYPE = Arrays.stream(ApiKeys.values())
+        .collect(Collectors.toMap(key -> (int) key.id, Function.identity()));
+
+    /** the permanent and immutable id of an API - this can't change ever */
+    public final short id;
+
+    /** An english description of the api - used for debugging and metric names, it can potentially be changed via a KIP */
+    public final String name;
+
+    /** indicates if this is a ClusterAction request used only by brokers */
+    public final boolean clusterAction;
+
+    /** indicates the minimum required inter broker magic required to support the API */
+    public final byte minRequiredInterBrokerMagic;
+
+    /** indicates whether the API is enabled for forwarding **/
+    public final boolean forwardable;
+
+    public final boolean requiresDelayedAllocation;
+
+    public final ApiMessageType messageType;
+
+    ApiKeys(ApiMessageType messageType) {
+        this(messageType, false);
+    }
+
+    ApiKeys(ApiMessageType messageType, boolean clusterAction) {
+        this(messageType, clusterAction, RecordBatch.MAGIC_VALUE_V0, false);
+    }
+
+    ApiKeys(ApiMessageType messageType, boolean clusterAction, boolean forwardable) {
+        this(messageType, clusterAction, RecordBatch.MAGIC_VALUE_V0, forwardable);
+    }
+
+    ApiKeys(
+        ApiMessageType messageType,
+        boolean clusterAction,
+        byte minRequiredInterBrokerMagic,
+        boolean forwardable
+    ) {
+        this.messageType = messageType;
+        this.id = messageType.apiKey();
+        this.name = messageType.name;
+        this.clusterAction = clusterAction;
+        this.minRequiredInterBrokerMagic = minRequiredInterBrokerMagic;
+        this.requiresDelayedAllocation = forwardable || shouldRetainsBufferReference(messageType.requestSchemas());
+        this.forwardable = forwardable;
+    }
+
+    private static boolean shouldRetainsBufferReference(Schema[] requestSchemas) {
+        boolean requestRetainsBufferReference = false;
+        for (Schema requestVersionSchema : requestSchemas) {
+            if (retainsBufferReference(requestVersionSchema)) {
+                requestRetainsBufferReference = true;
+                break;
+            }
+        }
+        return requestRetainsBufferReference;
+    }
+
+    public static ApiKeys forId(int id) {
+        ApiKeys apiKey = ID_TO_TYPE.get(id);
+        if (apiKey == null) {
+            throw new IllegalArgumentException("Unexpected api key: " + id);
+        }
+        return apiKey;
+    }
+
+    public static boolean hasId(int id) {
+        return ID_TO_TYPE.containsKey(id);
+    }
+
+    public short latestVersion() {
+        return messageType.highestSupportedVersion();
+    }
+
+    public short oldestVersion() {
+        return messageType.lowestSupportedVersion();
+    }
+
+    public List<Short> allVersions() {
+        List<Short> versions = new ArrayList<>(latestVersion() - oldestVersion() + 1);
+        for (short version = oldestVersion(); version <= latestVersion(); version++) {
+            versions.add(version);
+        }
+        return versions;
+    }
+
+    public boolean isVersionSupported(short apiVersion) {
+        return apiVersion >= oldestVersion() && apiVersion <= latestVersion();
+    }
+
+    public short requestHeaderVersion(short apiVersion) {
+        return messageType.requestHeaderVersion(apiVersion);
+    }
+
+    public short responseHeaderVersion(short apiVersion) {
+        return messageType.responseHeaderVersion(apiVersion);
+    }
+
+    public boolean inScope(ApiMessageType.ListenerType listener) {
+        return messageType.listeners().contains(listener);
+    }
+
+    private static String toHtml() {
+        final StringBuilder b = new StringBuilder();
+        b.append("<table class=\"data-table\"><tbody>\n");
+        b.append("<tr>");
+        b.append("<th>Name</th>\n");
+        b.append("<th>Key</th>\n");
+        b.append("</tr>");
+        for (ApiKeys key : zkBrokerApis()) {
+            b.append("<tr>\n");
+            b.append("<td>" + key.name + "</td>");
+            b.append("<td>");
+            b.append(key.id);
+            b.append("</td>");
+            b.append("</tr>\n");
+        }
+        b.append("</tbody></table>\n");
+        return b.toString();
+    }
+
+    public static void main(String[] args) {
+        System.out.println(toHtml());
+    }
+
+    private static boolean retainsBufferReference(Schema schema) {
+        final AtomicBoolean hasBuffer = new AtomicBoolean(false);
+        Schema.Visitor detector = new Schema.Visitor() {
+            @Override
+            public void visit(Type field) {
+                if (field == BYTES || field == NULLABLE_BYTES || field == RECORDS ||
+                        field == COMPACT_BYTES || field == COMPACT_NULLABLE_BYTES)
+                    hasBuffer.set(true);
+            }
+        };
+        schema.walk(detector);
+        return hasBuffer.get();
+    }
+
+    public static EnumSet<ApiKeys> zkBrokerApis() {
+        return apisForListener(ApiMessageType.ListenerType.ZK_BROKER);
+    }
+
+    public static EnumSet<ApiKeys> apisForListener(ApiMessageType.ListenerType listener) {
+        return APIS_BY_LISTENER.get(listener);
+    }
+
+    private static EnumSet<ApiKeys> filterApisForListener(ApiMessageType.ListenerType listener) {
+        List<ApiKeys> controllerApis = Arrays.stream(ApiKeys.values())
+            .filter(apiKey -> apiKey.messageType.listeners().contains(listener))
+            .collect(Collectors.toList());
+        return EnumSet.copyOf(controllerApis);
+    }
+
+}
diff --git a/clients/src/main/java/org/apache/kafka/common/protocol/ApiMessage.java b/clients/src/main/java/org/apache/kafka/common/protocol/ApiMessage.java
new file mode 100644
index 0000000..4f17565
--- /dev/null
+++ b/clients/src/main/java/org/apache/kafka/common/protocol/ApiMessage.java
@@ -0,0 +1,28 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kafka.common.protocol;
+
+/**
+ * A Message which is part of the top-level Kafka API.
+ */
+public interface ApiMessage extends Message {
+    /**
+     * Returns the API key of this message, or -1 if there is none.
+     */
+    short apiKey();
+}
diff --git a/clients/src/main/java/org/apache/kafka/common/protocol/ByteBufferAccessor.java b/clients/src/main/java/org/apache/kafka/common/protocol/ByteBufferAccessor.java
new file mode 100644
index 0000000..bd0925d
--- /dev/null
+++ b/clients/src/main/java/org/apache/kafka/common/protocol/ByteBufferAccessor.java
@@ -0,0 +1,148 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.common.protocol; + +import org.apache.kafka.common.utils.ByteUtils; + +import java.nio.ByteBuffer; + +public class ByteBufferAccessor implements Readable, Writable { + private final ByteBuffer buf; + + public ByteBufferAccessor(ByteBuffer buf) { + this.buf = buf; + } + + @Override + public byte readByte() { + return buf.get(); + } + + @Override + public short readShort() { + return buf.getShort(); + } + + @Override + public int readInt() { + return buf.getInt(); + } + + @Override + public long readLong() { + return buf.getLong(); + } + + @Override + public double readDouble() { + return ByteUtils.readDouble(buf); + } + + @Override + public void readArray(byte[] arr) { + buf.get(arr); + } + + @Override + public int readUnsignedVarint() { + return ByteUtils.readUnsignedVarint(buf); + } + + @Override + public ByteBuffer readByteBuffer(int length) { + ByteBuffer res = buf.slice(); + res.limit(length); + + buf.position(buf.position() + length); + + return res; + } + + @Override + public void writeByte(byte val) { + buf.put(val); + } + + @Override + public void writeShort(short val) { + buf.putShort(val); + } + + @Override + public void writeInt(int val) { + buf.putInt(val); + } + + @Override + public void writeLong(long val) { + buf.putLong(val); + } + + @Override + public void writeDouble(double val) { + ByteUtils.writeDouble(val, buf); + } + + @Override + public void writeByteArray(byte[] arr) { + buf.put(arr); + } + + @Override + public void writeUnsignedVarint(int i) { + ByteUtils.writeUnsignedVarint(i, buf); + } + + @Override + public void writeByteBuffer(ByteBuffer src) { + buf.put(src.duplicate()); + } + + @Override + public void writeVarint(int i) { + ByteUtils.writeVarint(i, buf); + } + + @Override + public void writeVarlong(long i) { + ByteUtils.writeVarlong(i, buf); + } + + @Override + public int readVarint() { + return ByteUtils.readVarint(buf); + } + + @Override + public long readVarlong() { + return ByteUtils.readVarlong(buf); + } + + @Override + public int remaining() { + return buf.remaining(); + } + + public void flip() { + buf.flip(); + } + + public ByteBuffer buffer() { + return buf; + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/protocol/DataInputStreamReadable.java b/clients/src/main/java/org/apache/kafka/common/protocol/DataInputStreamReadable.java new file mode 100644 index 0000000..70ed52d --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/protocol/DataInputStreamReadable.java @@ -0,0 +1,139 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.protocol; + +import org.apache.kafka.common.utils.ByteUtils; + +import java.io.Closeable; +import java.io.DataInputStream; +import java.io.IOException; +import java.nio.ByteBuffer; + +public class DataInputStreamReadable implements Readable, Closeable { + protected final DataInputStream input; + + public DataInputStreamReadable(DataInputStream input) { + this.input = input; + } + + @Override + public byte readByte() { + try { + return input.readByte(); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + @Override + public short readShort() { + try { + return input.readShort(); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + @Override + public int readInt() { + try { + return input.readInt(); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + @Override + public long readLong() { + try { + return input.readLong(); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + @Override + public double readDouble() { + try { + return input.readDouble(); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + @Override + public void readArray(byte[] arr) { + try { + input.readFully(arr); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + @Override + public int readUnsignedVarint() { + try { + return ByteUtils.readUnsignedVarint(input); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + @Override + public ByteBuffer readByteBuffer(int length) { + byte[] arr = new byte[length]; + readArray(arr); + return ByteBuffer.wrap(arr); + } + + @Override + public int readVarint() { + try { + return ByteUtils.readVarint(input); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + @Override + public long readVarlong() { + try { + return ByteUtils.readVarlong(input); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + @Override + public int remaining() { + try { + return input.available(); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + @Override + public void close() { + try { + input.close(); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/protocol/DataOutputStreamWritable.java b/clients/src/main/java/org/apache/kafka/common/protocol/DataOutputStreamWritable.java new file mode 100644 index 0000000..f484016 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/protocol/DataOutputStreamWritable.java @@ -0,0 +1,146 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.common.protocol; + +import org.apache.kafka.common.utils.ByteUtils; +import org.apache.kafka.common.utils.Utils; + +import java.io.Closeable; +import java.io.DataOutputStream; +import java.io.IOException; +import java.nio.ByteBuffer; + +public class DataOutputStreamWritable implements Writable, Closeable { + protected final DataOutputStream out; + + public DataOutputStreamWritable(DataOutputStream out) { + this.out = out; + } + + @Override + public void writeByte(byte val) { + try { + out.writeByte(val); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + @Override + public void writeShort(short val) { + try { + out.writeShort(val); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + @Override + public void writeInt(int val) { + try { + out.writeInt(val); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + @Override + public void writeLong(long val) { + try { + out.writeLong(val); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + @Override + public void writeDouble(double val) { + try { + out.writeDouble(val); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + @Override + public void writeByteArray(byte[] arr) { + try { + out.write(arr); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + @Override + public void writeUnsignedVarint(int i) { + try { + ByteUtils.writeUnsignedVarint(i, out); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + @Override + public void writeByteBuffer(ByteBuffer buf) { + try { + if (buf.hasArray()) { + out.write(buf.array(), buf.position(), buf.limit()); + } else { + byte[] bytes = Utils.toArray(buf); + out.write(bytes); + } + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + @Override + public void writeVarint(int i) { + try { + ByteUtils.writeVarint(i, out); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + @Override + public void writeVarlong(long i) { + try { + ByteUtils.writeVarlong(i, out); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + public void flush() { + try { + out.flush(); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + @Override + public void close() { + try { + out.close(); + } catch (IOException e) { + throw new RuntimeException(e); + } + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/protocol/Errors.java b/clients/src/main/java/org/apache/kafka/common/protocol/Errors.java new file mode 100644 index 0000000..f48ae6c --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/protocol/Errors.java @@ -0,0 +1,510 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
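For illustration, a minimal sketch (outside this patch) pairing DataOutputStreamWritable with DataInputStreamReadable over an in-memory stream; the class name is hypothetical, and note that both wrappers rethrow IOException as RuntimeException, so no checked exceptions appear here.

    import java.io.ByteArrayInputStream;
    import java.io.ByteArrayOutputStream;
    import java.io.DataInputStream;
    import java.io.DataOutputStream;
    import org.apache.kafka.common.protocol.DataInputStreamReadable;
    import org.apache.kafka.common.protocol.DataOutputStreamWritable;

    public class StreamRoundTrip {
        public static void main(String[] args) {
            ByteArrayOutputStream bos = new ByteArrayOutputStream();
            DataOutputStreamWritable writable = new DataOutputStreamWritable(new DataOutputStream(bos));
            writable.writeVarlong(123456789L);   // varints/varlongs are encoded via ByteUtils
            writable.flush();

            DataInputStreamReadable readable = new DataInputStreamReadable(
                new DataInputStream(new ByteArrayInputStream(bos.toByteArray())));
            long value = readable.readVarlong(); // 123456789
            readable.close();
            System.out.println(value);
        }
    }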
+ */ +package org.apache.kafka.common.protocol; + +import org.apache.kafka.common.InvalidRecordException; +import org.apache.kafka.common.errors.ApiException; +import org.apache.kafka.common.errors.BrokerIdNotRegisteredException; +import org.apache.kafka.common.errors.BrokerNotAvailableException; +import org.apache.kafka.common.errors.ClusterAuthorizationException; +import org.apache.kafka.common.errors.ConcurrentTransactionsException; +import org.apache.kafka.common.errors.ControllerMovedException; +import org.apache.kafka.common.errors.CoordinatorLoadInProgressException; +import org.apache.kafka.common.errors.CoordinatorNotAvailableException; +import org.apache.kafka.common.errors.CorruptRecordException; +import org.apache.kafka.common.errors.DelegationTokenAuthorizationException; +import org.apache.kafka.common.errors.DelegationTokenDisabledException; +import org.apache.kafka.common.errors.DelegationTokenExpiredException; +import org.apache.kafka.common.errors.DelegationTokenNotFoundException; +import org.apache.kafka.common.errors.DelegationTokenOwnerMismatchException; +import org.apache.kafka.common.errors.DuplicateBrokerRegistrationException; +import org.apache.kafka.common.errors.DuplicateResourceException; +import org.apache.kafka.common.errors.DuplicateSequenceException; +import org.apache.kafka.common.errors.ElectionNotNeededException; +import org.apache.kafka.common.errors.EligibleLeadersNotAvailableException; +import org.apache.kafka.common.errors.FeatureUpdateFailedException; +import org.apache.kafka.common.errors.FencedInstanceIdException; +import org.apache.kafka.common.errors.FencedLeaderEpochException; +import org.apache.kafka.common.errors.FetchSessionIdNotFoundException; +import org.apache.kafka.common.errors.FetchSessionTopicIdException; +import org.apache.kafka.common.errors.GroupAuthorizationException; +import org.apache.kafka.common.errors.GroupIdNotFoundException; +import org.apache.kafka.common.errors.GroupMaxSizeReachedException; +import org.apache.kafka.common.errors.GroupNotEmptyException; +import org.apache.kafka.common.errors.GroupSubscribedToTopicException; +import org.apache.kafka.common.errors.IllegalGenerationException; +import org.apache.kafka.common.errors.IllegalSaslStateException; +import org.apache.kafka.common.errors.InconsistentGroupProtocolException; +import org.apache.kafka.common.errors.InconsistentTopicIdException; +import org.apache.kafka.common.errors.InconsistentVoterSetException; +import org.apache.kafka.common.errors.InconsistentClusterIdException; +import org.apache.kafka.common.errors.InvalidCommitOffsetSizeException; +import org.apache.kafka.common.errors.InvalidConfigurationException; +import org.apache.kafka.common.errors.InvalidFetchSessionEpochException; +import org.apache.kafka.common.errors.InvalidFetchSizeException; +import org.apache.kafka.common.errors.InvalidGroupIdException; +import org.apache.kafka.common.errors.InvalidPartitionsException; +import org.apache.kafka.common.errors.InvalidPidMappingException; +import org.apache.kafka.common.errors.InvalidPrincipalTypeException; +import org.apache.kafka.common.errors.InvalidProducerEpochException; +import org.apache.kafka.common.errors.InvalidReplicaAssignmentException; +import org.apache.kafka.common.errors.InvalidReplicationFactorException; +import org.apache.kafka.common.errors.InvalidRequestException; +import org.apache.kafka.common.errors.InvalidRequiredAcksException; +import org.apache.kafka.common.errors.InvalidSessionTimeoutException; +import 
org.apache.kafka.common.errors.InvalidTimestampException; +import org.apache.kafka.common.errors.InvalidTopicException; +import org.apache.kafka.common.errors.InvalidTxnStateException; +import org.apache.kafka.common.errors.InvalidTxnTimeoutException; +import org.apache.kafka.common.errors.InvalidUpdateVersionException; +import org.apache.kafka.common.errors.KafkaStorageException; +import org.apache.kafka.common.errors.LeaderNotAvailableException; +import org.apache.kafka.common.errors.ListenerNotFoundException; +import org.apache.kafka.common.errors.LogDirNotFoundException; +import org.apache.kafka.common.errors.MemberIdRequiredException; +import org.apache.kafka.common.errors.NetworkException; +import org.apache.kafka.common.errors.NoReassignmentInProgressException; +import org.apache.kafka.common.errors.NotControllerException; +import org.apache.kafka.common.errors.NotCoordinatorException; +import org.apache.kafka.common.errors.NotEnoughReplicasAfterAppendException; +import org.apache.kafka.common.errors.NotEnoughReplicasException; +import org.apache.kafka.common.errors.NotLeaderOrFollowerException; +import org.apache.kafka.common.errors.OffsetMetadataTooLarge; +import org.apache.kafka.common.errors.OffsetNotAvailableException; +import org.apache.kafka.common.errors.OffsetOutOfRangeException; +import org.apache.kafka.common.errors.OperationNotAttemptedException; +import org.apache.kafka.common.errors.OutOfOrderSequenceException; +import org.apache.kafka.common.errors.PolicyViolationException; +import org.apache.kafka.common.errors.PositionOutOfRangeException; +import org.apache.kafka.common.errors.PreferredLeaderNotAvailableException; +import org.apache.kafka.common.errors.PrincipalDeserializationException; +import org.apache.kafka.common.errors.ProducerFencedException; +import org.apache.kafka.common.errors.ReassignmentInProgressException; +import org.apache.kafka.common.errors.RebalanceInProgressException; +import org.apache.kafka.common.errors.RecordBatchTooLargeException; +import org.apache.kafka.common.errors.RecordTooLargeException; +import org.apache.kafka.common.errors.ReplicaNotAvailableException; +import org.apache.kafka.common.errors.ResourceNotFoundException; +import org.apache.kafka.common.errors.RetriableException; +import org.apache.kafka.common.errors.SaslAuthenticationException; +import org.apache.kafka.common.errors.SecurityDisabledException; +import org.apache.kafka.common.errors.SnapshotNotFoundException; +import org.apache.kafka.common.errors.StaleBrokerEpochException; +import org.apache.kafka.common.errors.ThrottlingQuotaExceededException; +import org.apache.kafka.common.errors.TimeoutException; +import org.apache.kafka.common.errors.TopicAuthorizationException; +import org.apache.kafka.common.errors.TopicDeletionDisabledException; +import org.apache.kafka.common.errors.TopicExistsException; +import org.apache.kafka.common.errors.TransactionCoordinatorFencedException; +import org.apache.kafka.common.errors.TransactionalIdAuthorizationException; +import org.apache.kafka.common.errors.TransactionalIdNotFoundException; +import org.apache.kafka.common.errors.UnacceptableCredentialException; +import org.apache.kafka.common.errors.UnknownLeaderEpochException; +import org.apache.kafka.common.errors.UnknownMemberIdException; +import org.apache.kafka.common.errors.UnknownProducerIdException; +import org.apache.kafka.common.errors.UnknownServerException; +import org.apache.kafka.common.errors.UnknownTopicIdException; +import 
org.apache.kafka.common.errors.UnknownTopicOrPartitionException; +import org.apache.kafka.common.errors.UnstableOffsetCommitException; +import org.apache.kafka.common.errors.UnsupportedByAuthenticationException; +import org.apache.kafka.common.errors.UnsupportedCompressionTypeException; +import org.apache.kafka.common.errors.UnsupportedForMessageFormatException; +import org.apache.kafka.common.errors.UnsupportedSaslMechanismException; +import org.apache.kafka.common.errors.UnsupportedVersionException; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.HashMap; +import java.util.Map; +import java.util.function.Function; + +/** + * This class contains all the client-server errors--those errors that must be sent from the server to the client. These + * are thus part of the protocol. The names can be changed but the error code cannot. + * + * Note that client library will convert an unknown error code to the non-retriable UnknownServerException if the client library + * version is old and does not recognize the newly-added error code. Therefore when a new server-side error is added, + * we may need extra logic to convert the new error code to another existing error code before sending the response back to + * the client if the request version suggests that the client may not recognize the new error code. + * + * Do not add exceptions that occur only on the client or only on the server here. + * + * @see org.apache.kafka.common.network.SslTransportLayer + */ +public enum Errors { + UNKNOWN_SERVER_ERROR(-1, "The server experienced an unexpected error when processing the request.", + UnknownServerException::new), + NONE(0, null, message -> null), + OFFSET_OUT_OF_RANGE(1, "The requested offset is not within the range of offsets maintained by the server.", + OffsetOutOfRangeException::new), + CORRUPT_MESSAGE(2, "This message has failed its CRC checksum, exceeds the valid size, has a null key for a compacted topic, or is otherwise corrupt.", + CorruptRecordException::new), + UNKNOWN_TOPIC_OR_PARTITION(3, "This server does not host this topic-partition.", + UnknownTopicOrPartitionException::new), + INVALID_FETCH_SIZE(4, "The requested fetch size is invalid.", + InvalidFetchSizeException::new), + LEADER_NOT_AVAILABLE(5, "There is no leader for this topic-partition as we are in the middle of a leadership election.", + LeaderNotAvailableException::new), + NOT_LEADER_OR_FOLLOWER(6, "For requests intended only for the leader, this error indicates that the broker is not the current leader. " + + "For requests intended for any replica, this error indicates that the broker is not a replica of the topic partition.", + NotLeaderOrFollowerException::new), + REQUEST_TIMED_OUT(7, "The request timed out.", + TimeoutException::new), + BROKER_NOT_AVAILABLE(8, "The broker is not available.", + BrokerNotAvailableException::new), + REPLICA_NOT_AVAILABLE(9, "The replica is not available for the requested topic-partition. 
Produce/Fetch requests and other requests " + + "intended only for the leader or follower return NOT_LEADER_OR_FOLLOWER if the broker is not a replica of the topic-partition.", + ReplicaNotAvailableException::new), + MESSAGE_TOO_LARGE(10, "The request included a message larger than the max message size the server will accept.", + RecordTooLargeException::new), + STALE_CONTROLLER_EPOCH(11, "The controller moved to another broker.", + ControllerMovedException::new), + OFFSET_METADATA_TOO_LARGE(12, "The metadata field of the offset request was too large.", + OffsetMetadataTooLarge::new), + NETWORK_EXCEPTION(13, "The server disconnected before a response was received.", + NetworkException::new), + COORDINATOR_LOAD_IN_PROGRESS(14, "The coordinator is loading and hence can't process requests.", + CoordinatorLoadInProgressException::new), + COORDINATOR_NOT_AVAILABLE(15, "The coordinator is not available.", + CoordinatorNotAvailableException::new), + NOT_COORDINATOR(16, "This is not the correct coordinator.", + NotCoordinatorException::new), + INVALID_TOPIC_EXCEPTION(17, "The request attempted to perform an operation on an invalid topic.", + InvalidTopicException::new), + RECORD_LIST_TOO_LARGE(18, "The request included message batch larger than the configured segment size on the server.", + RecordBatchTooLargeException::new), + NOT_ENOUGH_REPLICAS(19, "Messages are rejected since there are fewer in-sync replicas than required.", + NotEnoughReplicasException::new), + NOT_ENOUGH_REPLICAS_AFTER_APPEND(20, "Messages are written to the log, but to fewer in-sync replicas than required.", + NotEnoughReplicasAfterAppendException::new), + INVALID_REQUIRED_ACKS(21, "Produce request specified an invalid value for required acks.", + InvalidRequiredAcksException::new), + ILLEGAL_GENERATION(22, "Specified group generation id is not valid.", + IllegalGenerationException::new), + INCONSISTENT_GROUP_PROTOCOL(23, + "The group member's supported protocols are incompatible with those of existing members " + + "or first group member tried to join with empty protocol type or empty protocol list.", + InconsistentGroupProtocolException::new), + INVALID_GROUP_ID(24, "The configured groupId is invalid.", + InvalidGroupIdException::new), + UNKNOWN_MEMBER_ID(25, "The coordinator is not aware of this member.", + UnknownMemberIdException::new), + INVALID_SESSION_TIMEOUT(26, + "The session timeout is not within the range allowed by the broker " + + "(as configured by group.min.session.timeout.ms and group.max.session.timeout.ms).", + InvalidSessionTimeoutException::new), + REBALANCE_IN_PROGRESS(27, "The group is rebalancing, so a rejoin is needed.", + RebalanceInProgressException::new), + INVALID_COMMIT_OFFSET_SIZE(28, "The committing offset data size is not valid.", + InvalidCommitOffsetSizeException::new), + TOPIC_AUTHORIZATION_FAILED(29, "Topic authorization failed.", TopicAuthorizationException::new), + GROUP_AUTHORIZATION_FAILED(30, "Group authorization failed.", GroupAuthorizationException::new), + CLUSTER_AUTHORIZATION_FAILED(31, "Cluster authorization failed.", + ClusterAuthorizationException::new), + INVALID_TIMESTAMP(32, "The timestamp of the message is out of acceptable range.", + InvalidTimestampException::new), + UNSUPPORTED_SASL_MECHANISM(33, "The broker does not support the requested SASL mechanism.", + UnsupportedSaslMechanismException::new), + ILLEGAL_SASL_STATE(34, "Request is not valid given the current SASL state.", + IllegalSaslStateException::new), + UNSUPPORTED_VERSION(35, "The version of API is not 
supported.", + UnsupportedVersionException::new), + TOPIC_ALREADY_EXISTS(36, "Topic with this name already exists.", + TopicExistsException::new), + INVALID_PARTITIONS(37, "Number of partitions is below 1.", + InvalidPartitionsException::new), + INVALID_REPLICATION_FACTOR(38, "Replication factor is below 1 or larger than the number of available brokers.", + InvalidReplicationFactorException::new), + INVALID_REPLICA_ASSIGNMENT(39, "Replica assignment is invalid.", + InvalidReplicaAssignmentException::new), + INVALID_CONFIG(40, "Configuration is invalid.", + InvalidConfigurationException::new), + NOT_CONTROLLER(41, "This is not the correct controller for this cluster.", + NotControllerException::new), + INVALID_REQUEST(42, "This most likely occurs because of a request being malformed by the " + + "client library or the message was sent to an incompatible broker. See the broker logs " + + "for more details.", + InvalidRequestException::new), + UNSUPPORTED_FOR_MESSAGE_FORMAT(43, "The message format version on the broker does not support the request.", + UnsupportedForMessageFormatException::new), + POLICY_VIOLATION(44, "Request parameters do not satisfy the configured policy.", + PolicyViolationException::new), + OUT_OF_ORDER_SEQUENCE_NUMBER(45, "The broker received an out of order sequence number.", + OutOfOrderSequenceException::new), + DUPLICATE_SEQUENCE_NUMBER(46, "The broker received a duplicate sequence number.", + DuplicateSequenceException::new), + INVALID_PRODUCER_EPOCH(47, "Producer attempted to produce with an old epoch.", + InvalidProducerEpochException::new), + INVALID_TXN_STATE(48, "The producer attempted a transactional operation in an invalid state.", + InvalidTxnStateException::new), + INVALID_PRODUCER_ID_MAPPING(49, "The producer attempted to use a producer id which is not currently assigned to " + + "its transactional id.", + InvalidPidMappingException::new), + INVALID_TRANSACTION_TIMEOUT(50, "The transaction timeout is larger than the maximum value allowed by " + + "the broker (as configured by transaction.max.timeout.ms).", + InvalidTxnTimeoutException::new), + CONCURRENT_TRANSACTIONS(51, "The producer attempted to update a transaction " + + "while another concurrent operation on the same transaction was ongoing.", + ConcurrentTransactionsException::new), + TRANSACTION_COORDINATOR_FENCED(52, "Indicates that the transaction coordinator sending a WriteTxnMarker " + + "is no longer the current coordinator for a given producer.", + TransactionCoordinatorFencedException::new), + TRANSACTIONAL_ID_AUTHORIZATION_FAILED(53, "Transactional Id authorization failed.", + TransactionalIdAuthorizationException::new), + SECURITY_DISABLED(54, "Security features are disabled.", + SecurityDisabledException::new), + OPERATION_NOT_ATTEMPTED(55, "The broker did not attempt to execute this operation. This may happen for " + + "batched RPCs where some operations in the batch failed, causing the broker to respond without " + + "trying the rest.", + OperationNotAttemptedException::new), + KAFKA_STORAGE_ERROR(56, "Disk error when trying to access log file on the disk.", + KafkaStorageException::new), + LOG_DIR_NOT_FOUND(57, "The user-specified log directory is not found in the broker config.", + LogDirNotFoundException::new), + SASL_AUTHENTICATION_FAILED(58, "SASL Authentication failed.", + SaslAuthenticationException::new), + UNKNOWN_PRODUCER_ID(59, "This exception is raised by the broker if it could not locate the producer metadata " + + "associated with the producerId in question. 
This could happen if, for instance, the producer's records " + + "were deleted because their retention time had elapsed. Once the last records of the producerId are " + + "removed, the producer's metadata is removed from the broker, and future appends by the producer will " + + "return this exception.", + UnknownProducerIdException::new), + REASSIGNMENT_IN_PROGRESS(60, "A partition reassignment is in progress.", + ReassignmentInProgressException::new), + DELEGATION_TOKEN_AUTH_DISABLED(61, "Delegation Token feature is not enabled.", + DelegationTokenDisabledException::new), + DELEGATION_TOKEN_NOT_FOUND(62, "Delegation Token is not found on server.", + DelegationTokenNotFoundException::new), + DELEGATION_TOKEN_OWNER_MISMATCH(63, "Specified Principal is not valid Owner/Renewer.", + DelegationTokenOwnerMismatchException::new), + DELEGATION_TOKEN_REQUEST_NOT_ALLOWED(64, "Delegation Token requests are not allowed on PLAINTEXT/1-way SSL " + + "channels and on delegation token authenticated channels.", + UnsupportedByAuthenticationException::new), + DELEGATION_TOKEN_AUTHORIZATION_FAILED(65, "Delegation Token authorization failed.", + DelegationTokenAuthorizationException::new), + DELEGATION_TOKEN_EXPIRED(66, "Delegation Token is expired.", + DelegationTokenExpiredException::new), + INVALID_PRINCIPAL_TYPE(67, "Supplied principalType is not supported.", + InvalidPrincipalTypeException::new), + NON_EMPTY_GROUP(68, "The group is not empty.", + GroupNotEmptyException::new), + GROUP_ID_NOT_FOUND(69, "The group id does not exist.", + GroupIdNotFoundException::new), + FETCH_SESSION_ID_NOT_FOUND(70, "The fetch session ID was not found.", + FetchSessionIdNotFoundException::new), + INVALID_FETCH_SESSION_EPOCH(71, "The fetch session epoch is invalid.", + InvalidFetchSessionEpochException::new), + LISTENER_NOT_FOUND(72, "There is no listener on the leader broker that matches the listener on which " + + "metadata request was processed.", + ListenerNotFoundException::new), + TOPIC_DELETION_DISABLED(73, "Topic deletion is disabled.", + TopicDeletionDisabledException::new), + FENCED_LEADER_EPOCH(74, "The leader epoch in the request is older than the epoch on the broker.", + FencedLeaderEpochException::new), + UNKNOWN_LEADER_EPOCH(75, "The leader epoch in the request is newer than the epoch on the broker.", + UnknownLeaderEpochException::new), + UNSUPPORTED_COMPRESSION_TYPE(76, "The requesting client does not support the compression type of given partition.", + UnsupportedCompressionTypeException::new), + STALE_BROKER_EPOCH(77, "Broker epoch has changed.", + StaleBrokerEpochException::new), + OFFSET_NOT_AVAILABLE(78, "The leader high watermark has not caught up from a recent leader " + + "election so the offsets cannot be guaranteed to be monotonically increasing.", + OffsetNotAvailableException::new), + MEMBER_ID_REQUIRED(79, "The group member needs to have a valid member id before actually entering a consumer group.", + MemberIdRequiredException::new), + PREFERRED_LEADER_NOT_AVAILABLE(80, "The preferred leader was not available.", + PreferredLeaderNotAvailableException::new), + GROUP_MAX_SIZE_REACHED(81, "The consumer group has reached its max size.", GroupMaxSizeReachedException::new), + FENCED_INSTANCE_ID(82, "The broker rejected this static consumer since " + + "another consumer with the same group.instance.id has registered with a different member.id.", + FencedInstanceIdException::new), + ELIGIBLE_LEADERS_NOT_AVAILABLE(83, "Eligible topic partition leaders are not available.", + 
EligibleLeadersNotAvailableException::new), + ELECTION_NOT_NEEDED(84, "Leader election not needed for topic partition.", ElectionNotNeededException::new), + NO_REASSIGNMENT_IN_PROGRESS(85, "No partition reassignment is in progress.", + NoReassignmentInProgressException::new), + GROUP_SUBSCRIBED_TO_TOPIC(86, "Deleting offsets of a topic is forbidden while the consumer group is actively subscribed to it.", + GroupSubscribedToTopicException::new), + INVALID_RECORD(87, "This record has failed the validation on broker and hence will be rejected.", InvalidRecordException::new), + UNSTABLE_OFFSET_COMMIT(88, "There are unstable offsets that need to be cleared.", UnstableOffsetCommitException::new), + THROTTLING_QUOTA_EXCEEDED(89, "The throttling quota has been exceeded.", ThrottlingQuotaExceededException::new), + PRODUCER_FENCED(90, "There is a newer producer with the same transactionalId " + + "which fences the current one.", ProducerFencedException::new), + RESOURCE_NOT_FOUND(91, "A request illegally referred to a resource that does not exist.", ResourceNotFoundException::new), + DUPLICATE_RESOURCE(92, "A request illegally referred to the same resource twice.", DuplicateResourceException::new), + UNACCEPTABLE_CREDENTIAL(93, "Requested credential would not meet criteria for acceptability.", UnacceptableCredentialException::new), + INCONSISTENT_VOTER_SET(94, "Indicates that the either the sender or recipient of a " + + "voter-only request is not one of the expected voters", InconsistentVoterSetException::new), + INVALID_UPDATE_VERSION(95, "The given update version was invalid.", InvalidUpdateVersionException::new), + FEATURE_UPDATE_FAILED(96, "Unable to update finalized features due to an unexpected server error.", FeatureUpdateFailedException::new), + PRINCIPAL_DESERIALIZATION_FAILURE(97, "Request principal deserialization failed during forwarding. 
" + + "This indicates an internal error on the broker cluster security setup.", PrincipalDeserializationException::new), + SNAPSHOT_NOT_FOUND(98, "Requested snapshot was not found", SnapshotNotFoundException::new), + POSITION_OUT_OF_RANGE( + 99, + "Requested position is not greater than or equal to zero, and less than the size of the snapshot.", + PositionOutOfRangeException::new), + UNKNOWN_TOPIC_ID(100, "This server does not host this topic ID.", UnknownTopicIdException::new), + DUPLICATE_BROKER_REGISTRATION(101, "This broker ID is already in use.", DuplicateBrokerRegistrationException::new), + BROKER_ID_NOT_REGISTERED(102, "The given broker ID was not registered.", BrokerIdNotRegisteredException::new), + INCONSISTENT_TOPIC_ID(103, "The log's topic ID did not match the topic ID in the request", InconsistentTopicIdException::new), + INCONSISTENT_CLUSTER_ID(104, "The clusterId in the request does not match that found on the server", InconsistentClusterIdException::new), + TRANSACTIONAL_ID_NOT_FOUND(105, "The transactionalId could not be found", TransactionalIdNotFoundException::new), + FETCH_SESSION_TOPIC_ID_ERROR(106, "The fetch session encountered inconsistent topic ID usage", FetchSessionTopicIdException::new); + + private static final Logger log = LoggerFactory.getLogger(Errors.class); + + private static Map, Errors> classToError = new HashMap<>(); + private static Map codeToError = new HashMap<>(); + + static { + for (Errors error : Errors.values()) { + if (codeToError.put(error.code(), error) != null) + throw new ExceptionInInitializerError("Code " + error.code() + " for error " + + error + " has already been used"); + + if (error.exception != null) + classToError.put(error.exception.getClass(), error); + } + } + + private final short code; + private final Function builder; + private final ApiException exception; + + Errors(int code, String defaultExceptionString, Function builder) { + this.code = (short) code; + this.builder = builder; + this.exception = builder.apply(defaultExceptionString); + } + + /** + * An instance of the exception + */ + public ApiException exception() { + return this.exception; + } + + /** + * Create an instance of the ApiException that contains the given error message. + * + * @param message The message string to set. + * @return The exception. + */ + public ApiException exception(String message) { + if (message == null) { + // If no error message was specified, return an exception with the default error message. + return exception; + } + // Return an exception with the given error message. + return builder.apply(message); + } + + /** + * Returns the class name of the exception or null if this is {@code Errors.NONE}. + */ + public String exceptionName() { + return exception == null ? null : exception.getClass().getName(); + } + + /** + * The error code for the exception + */ + public short code() { + return this.code; + } + + /** + * Throw the exception corresponding to this error if there is one + */ + public void maybeThrow() { + if (exception != null) { + throw this.exception; + } + } + + /** + * Get a friendly description of the error (if one is available). 
+     * @return the error message
+     */
+    public String message() {
+        if (exception != null)
+            return exception.getMessage();
+        return toString();
+    }
+
+    /**
+     * Look up the error corresponding to the given error code, falling back to
+     * UNKNOWN_SERVER_ERROR if the code is not recognized.
+     */
+    public static Errors forCode(short code) {
+        Errors error = codeToError.get(code);
+        if (error != null) {
+            return error;
+        } else {
+            log.warn("Unexpected error code: {}.", code);
+            return UNKNOWN_SERVER_ERROR;
+        }
+    }
+
+    /**
+     * Return the error instance associated with this exception or any of its superclasses (or UNKNOWN if there is none).
+     * If there are multiple matches in the class hierarchy, the first match starting from the bottom is used.
+     */
+    public static Errors forException(Throwable t) {
+        Class<?> clazz = t.getClass();
+        while (clazz != null) {
+            Errors error = classToError.get(clazz);
+            if (error != null)
+                return error;
+            clazz = clazz.getSuperclass();
+        }
+        return UNKNOWN_SERVER_ERROR;
+    }
+
+    private static String toHtml() {
+        final StringBuilder b = new StringBuilder();
+        b.append("<table class=\"data-table\"><tbody>\n");
+        b.append("<tr>");
+        b.append("<th>Error</th>\n");
+        b.append("<th>Code</th>\n");
+        b.append("<th>Retriable</th>\n");
+        b.append("<th>Description</th>\n");
+        b.append("</tr>\n");
+        for (Errors error : Errors.values()) {
+            b.append("<tr>");
+            b.append("<td>");
+            b.append(error.name());
+            b.append("</td>");
+            b.append("<td>");
+            b.append(error.code());
+            b.append("</td>");
+            b.append("<td>");
+            b.append(error.exception() != null && error.exception() instanceof RetriableException ? "True" : "False");
+            b.append("</td>");
+            b.append("<td>");
+            b.append(error.exception() != null ? error.exception().getMessage() : "");
+            b.append("</td>");
+            b.append("</tr>\n");
+        }
+        b.append("</tbody></table>
        \n"); + return b.toString(); + } + + public static void main(String[] args) { + System.out.println(toHtml()); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/protocol/Message.java b/clients/src/main/java/org/apache/kafka/common/protocol/Message.java new file mode 100644 index 0000000..e379f01 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/protocol/Message.java @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.common.protocol; + +import org.apache.kafka.common.protocol.types.RawTaggedField; + +import java.util.List; + +/** + * An object that can serialize itself. The serialization protocol is versioned. + * Messages also implement toString, equals, and hashCode. + */ +public interface Message { + /** + * Returns the lowest supported API key of this message, inclusive. + */ + short lowestSupportedVersion(); + + /** + * Returns the highest supported API key of this message, inclusive. + */ + short highestSupportedVersion(); + + /** + * Returns the number of bytes it would take to write out this message. + * + * @param cache The serialization size cache to populate. + * @param version The version to use. + * + * @throws {@see org.apache.kafka.common.errors.UnsupportedVersionException} + * If the specified version is too new to be supported + * by this software. + */ + default int size(ObjectSerializationCache cache, short version) { + MessageSizeAccumulator size = new MessageSizeAccumulator(); + addSize(size, cache, version); + return size.totalSize(); + } + + /** + * Add the size of this message to an accumulator. + * + * @param size The size accumulator to add to + * @param cache The serialization size cache to populate. + * @param version The version to use. + */ + void addSize(MessageSizeAccumulator size, ObjectSerializationCache cache, short version); + + /** + * Writes out this message to the given Writable. + * + * @param writable The destination writable. + * @param cache The object serialization cache to use. You must have + * previously populated the size cache using #{Message#size()}. + * @param version The version to use. + * + * @throws {@see org.apache.kafka.common.errors.UnsupportedVersionException} + * If the specified version is too new to be supported + * by this software. + */ + void write(Writable writable, ObjectSerializationCache cache, short version); + + /** + * Reads this message from the given Readable. This will overwrite all + * relevant fields with information from the byte buffer. + * + * @param readable The source readable. + * @param version The version to use. + * + * @throws {@see org.apache.kafka.common.errors.UnsupportedVersionException} + * If the specified version is too new to be supported + * by this software. 
+ */ + void read(Readable readable, short version); + + /** + * Returns a list of tagged fields which this software can't understand. + * + * @return The raw tagged fields. + */ + List unknownTaggedFields(); + + /** + * Make a deep copy of the message. + * + * @return A copy of the message which does not share any mutable fields. + */ + Message duplicate(); +} diff --git a/clients/src/main/java/org/apache/kafka/common/protocol/MessageSizeAccumulator.java b/clients/src/main/java/org/apache/kafka/common/protocol/MessageSizeAccumulator.java new file mode 100644 index 0000000..dac007e --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/protocol/MessageSizeAccumulator.java @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.protocol; + +/** + * Helper class which facilitates zero-copy network transmission. See {@link SendBuilder}. + */ +public class MessageSizeAccumulator { + private int totalSize = 0; + private int zeroCopySize = 0; + + /** + * Get the total size of the message. + * + * @return total size in bytes + */ + public int totalSize() { + return totalSize; + } + + /** + * Size excluding zero copy fields as specified by {@link #zeroCopySize}. This is typically the size of the byte + * buffer used to serialize messages. + */ + public int sizeExcludingZeroCopy() { + return totalSize - zeroCopySize; + } + + /** + * Get the total "zero-copy" size of the message. This is the summed + * total of all fields which have either have a type of 'bytes' with + * 'zeroCopy' enabled, or a type of 'records' + * + * @return total size of zero-copy data in the message + */ + public int zeroCopySize() { + return zeroCopySize; + } + + public void addZeroCopyBytes(int size) { + zeroCopySize += size; + totalSize += size; + } + + public void addBytes(int size) { + totalSize += size; + } + + public void add(MessageSizeAccumulator size) { + this.totalSize += size.totalSize; + this.zeroCopySize += size.zeroCopySize; + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/protocol/MessageUtil.java b/clients/src/main/java/org/apache/kafka/common/protocol/MessageUtil.java new file mode 100644 index 0000000..288ffd0 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/protocol/MessageUtil.java @@ -0,0 +1,218 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.common.protocol; + +import com.fasterxml.jackson.databind.JsonNode; +import org.apache.kafka.common.protocol.types.RawTaggedField; +import org.apache.kafka.common.utils.Utils; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.Arrays; +import java.util.Iterator; +import java.util.List; + + +public final class MessageUtil { + /** + * Copy a byte buffer into an array. This will not affect the buffer's + * position or mark. + */ + public static byte[] byteBufferToArray(ByteBuffer buf) { + byte[] arr = new byte[buf.remaining()]; + int prevPosition = buf.position(); + try { + buf.get(arr); + } finally { + buf.position(prevPosition); + } + return arr; + } + + public static String deepToString(Iterator iter) { + StringBuilder bld = new StringBuilder("["); + String prefix = ""; + while (iter.hasNext()) { + Object object = iter.next(); + bld.append(prefix); + bld.append(object.toString()); + prefix = ", "; + } + bld.append("]"); + return bld.toString(); + } + + public static byte jsonNodeToByte(JsonNode node, String about) { + int value = jsonNodeToInt(node, about); + if (value > Byte.MAX_VALUE) { + if (value <= 256) { + // It's more traditional to refer to bytes as unsigned, + // so we support that here. + value -= 128; + } else { + throw new RuntimeException(about + ": value " + value + + " does not fit in an 8-bit signed integer."); + } + } + if (value < Byte.MIN_VALUE) { + throw new RuntimeException(about + ": value " + value + + " does not fit in an 8-bit signed integer."); + } + return (byte) value; + } + + public static short jsonNodeToShort(JsonNode node, String about) { + int value = jsonNodeToInt(node, about); + if ((value < Short.MIN_VALUE) || (value > Short.MAX_VALUE)) { + throw new RuntimeException(about + ": value " + value + + " does not fit in a 16-bit signed integer."); + } + return (short) value; + } + + public static int jsonNodeToUnsignedShort(JsonNode node, String about) { + int value = jsonNodeToInt(node, about); + if (value < 0 || value > 65535) { + throw new RuntimeException(about + ": value " + value + + " does not fit in a 16-bit unsigned integer."); + } + return value; + } + + public static int jsonNodeToInt(JsonNode node, String about) { + if (node.isInt()) { + return node.asInt(); + } + if (node.isTextual()) { + throw new NumberFormatException(about + ": expected an integer or " + + "string type, but got " + node.getNodeType()); + } + String text = node.asText(); + if (text.startsWith("0x")) { + try { + return Integer.parseInt(text.substring(2), 16); + } catch (NumberFormatException e) { + throw new NumberFormatException(about + ": failed to " + + "parse hexadecimal number: " + e.getMessage()); + } + } else { + try { + return Integer.parseInt(text); + } catch (NumberFormatException e) { + throw new NumberFormatException(about + ": failed to " + + "parse number: " + e.getMessage()); + } + } + } + + public static long jsonNodeToLong(JsonNode node, String about) { + if (node.isLong()) { + return node.asLong(); + } + if (node.isTextual()) { + throw new NumberFormatException(about + ": expected an 
integer or " + + "string type, but got " + node.getNodeType()); + } + String text = node.asText(); + if (text.startsWith("0x")) { + try { + return Long.parseLong(text.substring(2), 16); + } catch (NumberFormatException e) { + throw new NumberFormatException(about + ": failed to " + + "parse hexadecimal number: " + e.getMessage()); + } + } else { + try { + return Long.parseLong(text); + } catch (NumberFormatException e) { + throw new NumberFormatException(about + ": failed to " + + "parse number: " + e.getMessage()); + } + } + } + + public static byte[] jsonNodeToBinary(JsonNode node, String about) { + if (!node.isBinary()) { + throw new RuntimeException(about + ": expected Base64-encoded binary data."); + } + try { + byte[] value = node.binaryValue(); + return value; + } catch (IOException e) { + throw new RuntimeException(about + ": unable to retrieve Base64-encoded binary data", e); + } + } + + public static double jsonNodeToDouble(JsonNode node, String about) { + if (!node.isFloatingPointNumber()) { + throw new NumberFormatException(about + ": expected a floating point " + + "type, but got " + node.getNodeType()); + } + return node.asDouble(); + } + + public static byte[] duplicate(byte[] array) { + if (array == null) + return null; + return Arrays.copyOf(array, array.length); + } + + /** + * Compare two RawTaggedFields lists. + * A null list is equivalent to an empty one in this context. + */ + public static boolean compareRawTaggedFields(List first, + List second) { + if (first == null) { + return second == null || second.isEmpty(); + } else if (second == null) { + return first.isEmpty(); + } else { + return first.equals(second); + } + } + + public static ByteBuffer toByteBuffer(final Message message, final short version) { + ObjectSerializationCache cache = new ObjectSerializationCache(); + int messageSize = message.size(cache, version); + ByteBufferAccessor bytes = new ByteBufferAccessor(ByteBuffer.allocate(messageSize)); + message.write(bytes, cache, version); + bytes.flip(); + return bytes.buffer(); + } + + public static ByteBuffer toVersionPrefixedByteBuffer(final short version, final Message message) { + ObjectSerializationCache cache = new ObjectSerializationCache(); + int messageSize = message.size(cache, version); + ByteBufferAccessor bytes = new ByteBufferAccessor(ByteBuffer.allocate(messageSize + 2)); + bytes.writeShort(version); + message.write(bytes, cache, version); + bytes.flip(); + return bytes.buffer(); + } + + public static byte[] toVersionPrefixedBytes(final short version, final Message message) { + ByteBuffer buffer = toVersionPrefixedByteBuffer(version, message); + // take the inner array directly if it is full with data + if (buffer.hasArray() && + buffer.arrayOffset() == 0 && + buffer.position() == 0 && + buffer.limit() == buffer.array().length) return buffer.array(); + else return Utils.toArray(buffer); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/protocol/ObjectSerializationCache.java b/clients/src/main/java/org/apache/kafka/common/protocol/ObjectSerializationCache.java new file mode 100644 index 0000000..208b056 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/protocol/ObjectSerializationCache.java @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.common.protocol; + +import java.util.IdentityHashMap; + +/** + * The ObjectSerializationCache stores sizes and values computed during the + * first serialization pass. This avoids recalculating and recomputing the same + * values during the second pass. + * + * It is intended to be used as part of a two-pass serialization process like: + * ObjectSerializationCache cache = new ObjectSerializationCache(); + * message.size(version, cache); + * message.write(version, cache); + */ +public final class ObjectSerializationCache { + private final IdentityHashMap map; + + public ObjectSerializationCache() { + this.map = new IdentityHashMap<>(); + } + + public void setArraySizeInBytes(Object o, Integer size) { + map.put(o, size); + } + + public Integer getArraySizeInBytes(Object o) { + return (Integer) map.get(o); + } + + public void cacheSerializedValue(Object o, byte[] val) { + map.put(o, val); + } + + public byte[] getSerializedValue(Object o) { + Object value = map.get(o); + return (byte[]) value; + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/protocol/Protocol.java b/clients/src/main/java/org/apache/kafka/common/protocol/Protocol.java new file mode 100644 index 0000000..d455b26 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/protocol/Protocol.java @@ -0,0 +1,194 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
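For illustration, a minimal sketch (outside this patch) of the two-pass pattern the ObjectSerializationCache comment describes, written with the Message interface's actual argument order of size(cache, version) and write(writable, cache, version); it is essentially what MessageUtil.toByteBuffer above does, and the helper class name is hypothetical.

    import java.nio.ByteBuffer;
    import org.apache.kafka.common.protocol.ByteBufferAccessor;
    import org.apache.kafka.common.protocol.Message;
    import org.apache.kafka.common.protocol.ObjectSerializationCache;

    public final class TwoPassSerialization {
        // Pass 1 computes and caches per-object sizes; pass 2 reuses them while writing.
        public static ByteBuffer serialize(Message message, short version) {
            ObjectSerializationCache cache = new ObjectSerializationCache();
            int size = message.size(cache, version);                  // first pass: sizing
            ByteBufferAccessor writable = new ByteBufferAccessor(ByteBuffer.allocate(size));
            message.write(writable, cache, version);                  // second pass: writing
            writable.flip();
            return writable.buffer();
        }
    }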
+ */
+package org.apache.kafka.common.protocol;
+
+import org.apache.kafka.common.message.RequestHeaderData;
+import org.apache.kafka.common.message.ResponseHeaderData;
+import org.apache.kafka.common.protocol.types.BoundField;
+import org.apache.kafka.common.protocol.types.Schema;
+import org.apache.kafka.common.protocol.types.TaggedFields;
+import org.apache.kafka.common.protocol.types.Type;
+
+import java.util.LinkedHashMap;
+import java.util.LinkedHashSet;
+import java.util.Map;
+import java.util.Set;
+
+public class Protocol {
+
+    private static String indentString(int size) {
+        StringBuilder b = new StringBuilder(size);
+        for (int i = 0; i < size; i++)
+            b.append(" ");
+        return b.toString();
+    }
+
+    private static void schemaToBnfHtml(Schema schema, StringBuilder b, int indentSize) {
+        final String indentStr = indentString(indentSize);
+        final Map<String, Type> subTypes = new LinkedHashMap<>();
+
+        // Top level fields
+        for (BoundField field: schema.fields()) {
+            Type type = field.def.type;
+            if (type.isArray()) {
+                b.append("[");
+                b.append(field.def.name);
+                b.append("] ");
+                if (!subTypes.containsKey(field.def.name)) {
+                    subTypes.put(field.def.name, type.arrayElementType().get());
+                }
+            } else if (type instanceof TaggedFields) {
+                b.append("TAG_BUFFER ");
+            } else {
+                b.append(field.def.name);
+                b.append(" ");
+                if (!subTypes.containsKey(field.def.name))
+                    subTypes.put(field.def.name, type);
+            }
+        }
+        b.append("\n");
+
+        // Sub Types/Schemas
+        for (Map.Entry<String, Type> entry: subTypes.entrySet()) {
+            if (entry.getValue() instanceof Schema) {
+                // Complex Schema Type
+                b.append(indentStr);
+                b.append(entry.getKey());
+                b.append(" => ");
+                schemaToBnfHtml((Schema) entry.getValue(), b, indentSize + 2);
+            } else {
+                // Standard Field Type
+                b.append(indentStr);
+                b.append(entry.getKey());
+                b.append(" => ");
+                b.append(entry.getValue());
+                b.append("\n");
+            }
+        }
+    }
+
+    private static void populateSchemaFields(Schema schema, Set<BoundField> fields) {
+        for (BoundField field: schema.fields()) {
+            fields.add(field);
+            if (field.def.type.isArray()) {
+                Type innerType = field.def.type.arrayElementType().get();
+                if (innerType instanceof Schema)
+                    populateSchemaFields((Schema) innerType, fields);
+            } else if (field.def.type instanceof Schema)
+                populateSchemaFields((Schema) field.def.type, fields);
+        }
+    }
+
+    private static void schemaToFieldTableHtml(Schema schema, StringBuilder b) {
+        Set<BoundField> fields = new LinkedHashSet<>();
+        populateSchemaFields(schema, fields);
+
+        b.append("<table class=\"data-table\"><tbody>\n");
+        b.append("<tr>");
+        b.append("<th>Field</th>\n");
+        b.append("<th>Description</th>\n");
+        b.append("</tr>");
+        for (BoundField field : fields) {
+            b.append("<tr>\n");
+            b.append("<td>");
+            b.append(field.def.name);
+            b.append("</td>");
+            b.append("<td>");
+            b.append(field.def.docString);
+            b.append("</td>");
+            b.append("</tr>\n");
+        }
+        b.append("</tbody></table>\n");
+    }
+
+    public static String toHtml() {
+        final StringBuilder b = new StringBuilder();
+        b.append("<h5>Headers:</h5>\n");
+
+        for (int i = 0; i < RequestHeaderData.SCHEMAS.length; i++) {
+            b.append("<pre>");
+            b.append("Request Header v").append(i).append(" => ");
+            schemaToBnfHtml(RequestHeaderData.SCHEMAS[i], b, 2);
+            b.append("</pre>\n");
+            schemaToFieldTableHtml(RequestHeaderData.SCHEMAS[i], b);
+        }
+        for (int i = 0; i < ResponseHeaderData.SCHEMAS.length; i++) {
+            b.append("<pre>");
+            b.append("Response Header v").append(i).append(" => ");
+            schemaToBnfHtml(ResponseHeaderData.SCHEMAS[i], b, 2);
+            b.append("</pre>\n");
+            schemaToFieldTableHtml(ResponseHeaderData.SCHEMAS[i], b);
+        }
+        for (ApiKeys key : ApiKeys.zkBrokerApis()) {
+            // Key
+            b.append("<h5>");
+            b.append("<a name=\"The_Messages_" + key.name + "\">");
+            b.append(key.name);
+            b.append(" API (Key: ");
+            b.append(key.id);
+            b.append("):</a></h5>\n\n");
+            // Requests
+            b.append("<b>Requests:</b><br>\n");
+            Schema[] requests = key.messageType.requestSchemas();
+            for (int i = 0; i < requests.length; i++) {
+                Schema schema = requests[i];
+                // Schema
+                if (schema != null) {
+                    b.append("<div>");
+                    // Version header
+                    b.append("<pre>");
+                    b.append(key.name);
+                    b.append(" Request (Version: ");
+                    b.append(i);
+                    b.append(") => ");
+                    schemaToBnfHtml(requests[i], b, 2);
+                    b.append("</pre>");
+                    schemaToFieldTableHtml(requests[i], b);
+                }
+                b.append("</div>\n");
+            }
+
+            // Responses
+            b.append("<b>Responses:</b><br>\n");
+            Schema[] responses = key.messageType.responseSchemas();
+            for (int i = 0; i < responses.length; i++) {
+                Schema schema = responses[i];
+                // Schema
+                if (schema != null) {
+                    b.append("<div>");
+                    // Version header
+                    b.append("<pre>");
+                    b.append(key.name);
+                    b.append(" Response (Version: ");
+                    b.append(i);
+                    b.append(") => ");
+                    schemaToBnfHtml(responses[i], b, 2);
+                    b.append("</pre>");
+                    schemaToFieldTableHtml(responses[i], b);
+                }
+                b.append("</div>
        \n"); + } + } + + return b.toString(); + } + + public static void main(String[] args) { + System.out.println(toHtml()); + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/protocol/Readable.java b/clients/src/main/java/org/apache/kafka/common/protocol/Readable.java new file mode 100644 index 0000000..9c9e461 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/protocol/Readable.java @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.common.protocol; + +import org.apache.kafka.common.Uuid; +import org.apache.kafka.common.protocol.types.RawTaggedField; +import org.apache.kafka.common.record.MemoryRecords; + +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.List; + +public interface Readable { + byte readByte(); + short readShort(); + int readInt(); + long readLong(); + double readDouble(); + void readArray(byte[] arr); + int readUnsignedVarint(); + ByteBuffer readByteBuffer(int length); + int readVarint(); + long readVarlong(); + int remaining(); + + default String readString(int length) { + byte[] arr = new byte[length]; + readArray(arr); + return new String(arr, StandardCharsets.UTF_8); + } + + default List readUnknownTaggedField(List unknowns, int tag, int size) { + if (unknowns == null) { + unknowns = new ArrayList<>(); + } + byte[] data = new byte[size]; + readArray(data); + unknowns.add(new RawTaggedField(tag, data)); + return unknowns; + } + + default MemoryRecords readRecords(int length) { + if (length < 0) { + // no records + return null; + } else { + ByteBuffer recordsBuffer = readByteBuffer(length); + return MemoryRecords.readableRecords(recordsBuffer); + } + } + + /** + * Read a UUID with the most significant digits first. + */ + default Uuid readUuid() { + return new Uuid(readLong(), readLong()); + } + + default int readUnsignedShort() { + return Short.toUnsignedInt(readShort()); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/protocol/SendBuilder.java b/clients/src/main/java/org/apache/kafka/common/protocol/SendBuilder.java new file mode 100644 index 0000000..46afacd --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/protocol/SendBuilder.java @@ -0,0 +1,228 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.protocol; + +import org.apache.kafka.common.network.ByteBufferSend; +import org.apache.kafka.common.network.Send; +import org.apache.kafka.common.record.BaseRecords; +import org.apache.kafka.common.record.MemoryRecords; +import org.apache.kafka.common.record.MultiRecordsSend; +import org.apache.kafka.common.record.UnalignedMemoryRecords; +import org.apache.kafka.common.requests.RequestHeader; +import org.apache.kafka.common.requests.ResponseHeader; +import org.apache.kafka.common.utils.ByteUtils; + +import java.nio.ByteBuffer; +import java.util.ArrayDeque; +import java.util.ArrayList; +import java.util.List; +import java.util.Queue; + +/** + * This class provides a way to build {@link Send} objects for network transmission + * from generated {@link org.apache.kafka.common.protocol.ApiMessage} types without + * allocating new space for "zero-copy" fields (see {@link #writeByteBuffer(ByteBuffer)} + * and {@link #writeRecords(BaseRecords)}). + * + * See {@link org.apache.kafka.common.requests.EnvelopeRequest#toSend(RequestHeader)} + * for example usage. + */ +public class SendBuilder implements Writable { + private final ByteBuffer buffer; + + private final Queue sends = new ArrayDeque<>(1); + private long sizeOfSends = 0; + + private final List buffers = new ArrayList<>(); + private long sizeOfBuffers = 0; + + SendBuilder(int size) { + this.buffer = ByteBuffer.allocate(size); + this.buffer.mark(); + } + + @Override + public void writeByte(byte val) { + buffer.put(val); + } + + @Override + public void writeShort(short val) { + buffer.putShort(val); + } + + @Override + public void writeInt(int val) { + buffer.putInt(val); + } + + @Override + public void writeLong(long val) { + buffer.putLong(val); + } + + @Override + public void writeDouble(double val) { + buffer.putDouble(val); + } + + @Override + public void writeByteArray(byte[] arr) { + buffer.put(arr); + } + + @Override + public void writeUnsignedVarint(int i) { + ByteUtils.writeUnsignedVarint(i, buffer); + } + + /** + * Write a byte buffer. The reference to the underlying buffer will + * be retained in the result of {@link #build()}. + * + * @param buf the buffer to write + */ + @Override + public void writeByteBuffer(ByteBuffer buf) { + flushPendingBuffer(); + addBuffer(buf.duplicate()); + } + + @Override + public void writeVarint(int i) { + ByteUtils.writeVarint(i, buffer); + } + + @Override + public void writeVarlong(long i) { + ByteUtils.writeVarlong(i, buffer); + } + + private void addBuffer(ByteBuffer buffer) { + buffers.add(buffer); + sizeOfBuffers += buffer.remaining(); + } + + private void addSend(Send send) { + sends.add(send); + sizeOfSends += send.size(); + } + + private void clearBuffers() { + buffers.clear(); + sizeOfBuffers = 0; + } + + /** + * Write a record set. The underlying record data will be retained + * in the result of {@link #build()}. See {@link BaseRecords#toSend()}. 
+ * + * @param records the records to write + */ + @Override + public void writeRecords(BaseRecords records) { + if (records instanceof MemoryRecords) { + flushPendingBuffer(); + addBuffer(((MemoryRecords) records).buffer()); + } else if (records instanceof UnalignedMemoryRecords) { + flushPendingBuffer(); + addBuffer(((UnalignedMemoryRecords) records).buffer()); + } else { + flushPendingSend(); + addSend(records.toSend()); + } + } + + private void flushPendingSend() { + flushPendingBuffer(); + if (!buffers.isEmpty()) { + ByteBuffer[] byteBufferArray = buffers.toArray(new ByteBuffer[0]); + addSend(new ByteBufferSend(byteBufferArray, sizeOfBuffers)); + clearBuffers(); + } + } + + private void flushPendingBuffer() { + int latestPosition = buffer.position(); + buffer.reset(); + + if (latestPosition > buffer.position()) { + buffer.limit(latestPosition); + addBuffer(buffer.slice()); + + buffer.position(latestPosition); + buffer.limit(buffer.capacity()); + buffer.mark(); + } + } + + public Send build() { + flushPendingSend(); + + if (sends.size() == 1) { + return sends.poll(); + } else { + return new MultiRecordsSend(sends, sizeOfSends); + } + } + + public static Send buildRequestSend( + RequestHeader header, + Message apiRequest + ) { + return buildSend( + header.data(), + header.headerVersion(), + apiRequest, + header.apiVersion() + ); + } + + public static Send buildResponseSend( + ResponseHeader header, + Message apiResponse, + short apiVersion + ) { + return buildSend( + header.data(), + header.headerVersion(), + apiResponse, + apiVersion + ); + } + + private static Send buildSend( + Message header, + short headerVersion, + Message apiMessage, + short apiVersion + ) { + ObjectSerializationCache serializationCache = new ObjectSerializationCache(); + + MessageSizeAccumulator messageSize = new MessageSizeAccumulator(); + header.addSize(messageSize, serializationCache, headerVersion); + apiMessage.addSize(messageSize, serializationCache, apiVersion); + + SendBuilder builder = new SendBuilder(messageSize.sizeExcludingZeroCopy() + 4); + builder.writeInt(messageSize.totalSize()); + header.write(builder, serializationCache, headerVersion); + apiMessage.write(builder, serializationCache, apiVersion); + + return builder.build(); + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/protocol/Writable.java b/clients/src/main/java/org/apache/kafka/common/protocol/Writable.java new file mode 100644 index 0000000..8dbec87 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/protocol/Writable.java @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.common.protocol; + +import org.apache.kafka.common.Uuid; +import org.apache.kafka.common.record.BaseRecords; +import org.apache.kafka.common.record.MemoryRecords; + +import java.nio.ByteBuffer; + +public interface Writable { + void writeByte(byte val); + void writeShort(short val); + void writeInt(int val); + void writeLong(long val); + void writeDouble(double val); + void writeByteArray(byte[] arr); + void writeUnsignedVarint(int i); + void writeByteBuffer(ByteBuffer buf); + void writeVarint(int i); + void writeVarlong(long i); + + default void writeRecords(BaseRecords records) { + if (records instanceof MemoryRecords) { + MemoryRecords memRecords = (MemoryRecords) records; + writeByteBuffer(memRecords.buffer()); + } else { + throw new UnsupportedOperationException("Unsupported record type " + records.getClass()); + } + } + + default void writeUuid(Uuid uuid) { + writeLong(uuid.getMostSignificantBits()); + writeLong(uuid.getLeastSignificantBits()); + } + + default void writeUnsignedShort(int i) { + // The setter functions in the generated code prevent us from setting + // ints outside the valid range of a short. + writeShort((short) i); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/protocol/types/ArrayOf.java b/clients/src/main/java/org/apache/kafka/common/protocol/types/ArrayOf.java new file mode 100644 index 0000000..3333084 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/protocol/types/ArrayOf.java @@ -0,0 +1,133 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.protocol.types; + +import org.apache.kafka.common.protocol.types.Type.DocumentedType; + +import java.nio.ByteBuffer; +import java.util.Optional; + +/** + * Represents a type for an array of a particular type + */ +public class ArrayOf extends DocumentedType { + + private static final String ARRAY_TYPE_NAME = "ARRAY"; + + private final Type type; + private final boolean nullable; + + public ArrayOf(Type type) { + this(type, false); + } + + public static ArrayOf nullable(Type type) { + return new ArrayOf(type, true); + } + + private ArrayOf(Type type, boolean nullable) { + this.type = type; + this.nullable = nullable; + } + + @Override + public boolean isNullable() { + return nullable; + } + + @Override + public void write(ByteBuffer buffer, Object o) { + if (o == null) { + buffer.putInt(-1); + return; + } + + Object[] objs = (Object[]) o; + int size = objs.length; + buffer.putInt(size); + + for (Object obj : objs) + type.write(buffer, obj); + } + + @Override + public Object read(ByteBuffer buffer) { + int size = buffer.getInt(); + if (size < 0 && isNullable()) + return null; + else if (size < 0) + throw new SchemaException("Array size " + size + " cannot be negative"); + + if (size > buffer.remaining()) + throw new SchemaException("Error reading array of size " + size + ", only " + buffer.remaining() + " bytes available"); + Object[] objs = new Object[size]; + for (int i = 0; i < size; i++) + objs[i] = type.read(buffer); + return objs; + } + + @Override + public int sizeOf(Object o) { + int size = 4; + if (o == null) + return size; + + Object[] objs = (Object[]) o; + for (Object obj : objs) + size += type.sizeOf(obj); + return size; + } + + @Override + public Optional arrayElementType() { + return Optional.of(type); + } + + @Override + public String toString() { + return ARRAY_TYPE_NAME + "(" + type + ")"; + } + + @Override + public Object[] validate(Object item) { + try { + if (isNullable() && item == null) + return null; + + Object[] array = (Object[]) item; + for (Object obj : array) + type.validate(obj); + return array; + } catch (ClassCastException e) { + throw new SchemaException("Not an Object[]."); + } + } + + @Override + public String typeName() { + return ARRAY_TYPE_NAME; + } + + @Override + public String documentation() { + return "Represents a sequence of objects of a given type T. " + + "Type T can be either a primitive type (e.g. " + STRING + ") or a structure. " + + "First, the length N is given as an " + INT32 + ". Then N instances of type T follow. " + + "A null array is represented with a length of -1. " + + "In protocol documentation an array of T instances is referred to as [T]."; + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/protocol/types/BoundField.java b/clients/src/main/java/org/apache/kafka/common/protocol/types/BoundField.java new file mode 100644 index 0000000..b031b4f --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/protocol/types/BoundField.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.protocol.types; + +/** + * A field definition bound to a particular schema. + */ +public class BoundField { + public final Field def; + final int index; + final Schema schema; + + public BoundField(Field def, Schema schema, int index) { + this.def = def; + this.schema = schema; + this.index = index; + } + + @Override + public String toString() { + return def.name + ":" + def.type; + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/protocol/types/CompactArrayOf.java b/clients/src/main/java/org/apache/kafka/common/protocol/types/CompactArrayOf.java new file mode 100644 index 0000000..4e9f8f8 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/protocol/types/CompactArrayOf.java @@ -0,0 +1,139 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.protocol.types; + +import org.apache.kafka.common.protocol.types.Type.DocumentedType; +import org.apache.kafka.common.utils.ByteUtils; + +import java.nio.ByteBuffer; +import java.util.Optional; + +/** + * Represents a type for a compact array of a particular type. + * A compact array represents its length with a varint rather than a + * fixed-length field. 
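+ * <p>A minimal sketch of the framing (illustrative, not normative protocol text):
+ * <pre>{@code
+ * CompactArrayOf type = CompactArrayOf.nullable(Type.INT32);
+ * Object[] values = {1, 2, 3};
+ * ByteBuffer buffer = ByteBuffer.allocate(type.sizeOf(values));
+ * type.write(buffer, values);   // the length is written as the unsigned varint 4 (N + 1)
+ * // A null array would instead be written as the single byte 0.
+ * }</pre>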
+ */ +public class CompactArrayOf extends DocumentedType { + private static final String COMPACT_ARRAY_TYPE_NAME = "COMPACT_ARRAY"; + + private final Type type; + private final boolean nullable; + + + public CompactArrayOf(Type type) { + this(type, false); + } + + public static CompactArrayOf nullable(Type type) { + return new CompactArrayOf(type, true); + } + + private CompactArrayOf(Type type, boolean nullable) { + this.type = type; + this.nullable = nullable; + } + + @Override + public boolean isNullable() { + return nullable; + } + + @Override + public void write(ByteBuffer buffer, Object o) { + if (o == null) { + ByteUtils.writeUnsignedVarint(0, buffer); + return; + } + Object[] objs = (Object[]) o; + int size = objs.length; + ByteUtils.writeUnsignedVarint(size + 1, buffer); + + for (Object obj : objs) + type.write(buffer, obj); + } + + @Override + public Object read(ByteBuffer buffer) { + int n = ByteUtils.readUnsignedVarint(buffer); + if (n == 0) { + if (isNullable()) { + return null; + } else { + throw new SchemaException("This array is not nullable."); + } + } + int size = n - 1; + if (size > buffer.remaining()) + throw new SchemaException("Error reading array of size " + size + ", only " + buffer.remaining() + " bytes available"); + Object[] objs = new Object[size]; + for (int i = 0; i < size; i++) + objs[i] = type.read(buffer); + return objs; + } + + @Override + public int sizeOf(Object o) { + if (o == null) { + return 1; + } + Object[] objs = (Object[]) o; + int size = ByteUtils.sizeOfUnsignedVarint(objs.length + 1); + for (Object obj : objs) { + size += type.sizeOf(obj); + } + return size; + } + + @Override + public Optional arrayElementType() { + return Optional.of(type); + } + + @Override + public String toString() { + return COMPACT_ARRAY_TYPE_NAME + "(" + type + ")"; + } + + @Override + public Object[] validate(Object item) { + try { + if (isNullable() && item == null) + return null; + + Object[] array = (Object[]) item; + for (Object obj : array) + type.validate(obj); + return array; + } catch (ClassCastException e) { + throw new SchemaException("Not an Object[]."); + } + } + + @Override + public String typeName() { + return COMPACT_ARRAY_TYPE_NAME; + } + + @Override + public String documentation() { + return "Represents a sequence of objects of a given type T. " + + "Type T can be either a primitive type (e.g. " + STRING + ") or a structure. " + + "First, the length N + 1 is given as an UNSIGNED_VARINT. Then N instances of type T follow. " + + "A null array is represented with a length of 0. " + + "In protocol documentation an array of T instances is referred to as [T]."; + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/protocol/types/Field.java b/clients/src/main/java/org/apache/kafka/common/protocol/types/Field.java new file mode 100644 index 0000000..44726f8 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/protocol/types/Field.java @@ -0,0 +1,196 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.protocol.types; + +public class Field { + public final String name; + public final String docString; + public final Type type; + public final boolean hasDefaultValue; + public final Object defaultValue; + + public Field(String name, Type type, String docString, boolean hasDefaultValue, Object defaultValue) { + this.name = name; + this.docString = docString; + this.type = type; + this.hasDefaultValue = hasDefaultValue; + this.defaultValue = defaultValue; + + if (hasDefaultValue) + type.validate(defaultValue); + } + + public Field(String name, Type type, String docString) { + this(name, type, docString, false, null); + } + + public Field(String name, Type type, String docString, Object defaultValue) { + this(name, type, docString, true, defaultValue); + } + + public Field(String name, Type type) { + this(name, type, null, false, null); + } + + public static class Int8 extends Field { + public Int8(String name, String docString) { + super(name, Type.INT8, docString, false, null); + } + public Int8(String name, String docString, byte defaultValue) { + super(name, Type.INT8, docString, true, defaultValue); + } + } + + public static class Int32 extends Field { + public Int32(String name, String docString) { + super(name, Type.INT32, docString, false, null); + } + + public Int32(String name, String docString, int defaultValue) { + super(name, Type.INT32, docString, true, defaultValue); + } + } + + public static class Int64 extends Field { + public Int64(String name, String docString) { + super(name, Type.INT64, docString, false, null); + } + + public Int64(String name, String docString, long defaultValue) { + super(name, Type.INT64, docString, true, defaultValue); + } + } + + public static class UUID extends Field { + public UUID(String name, String docString) { + super(name, Type.UUID, docString, false, null); + } + + public UUID(String name, String docString, UUID defaultValue) { + super(name, Type.UUID, docString, true, defaultValue); + } + } + + public static class Int16 extends Field { + public Int16(String name, String docString) { + super(name, Type.INT16, docString, false, null); + } + } + + public static class Uint16 extends Field { + public Uint16(String name, String docString) { + super(name, Type.UINT16, docString, false, null); + } + } + + public static class Float64 extends Field { + public Float64(String name, String docString) { + super(name, Type.FLOAT64, docString, false, null); + } + + public Float64(String name, String docString, double defaultValue) { + super(name, Type.FLOAT64, docString, true, defaultValue); + } + } + + public static class Str extends Field { + public Str(String name, String docString) { + super(name, Type.STRING, docString, false, null); + } + } + + public static class CompactStr extends Field { + public CompactStr(String name, String docString) { + super(name, Type.COMPACT_STRING, docString, false, null); + } + } + + public static class NullableStr extends Field { + public NullableStr(String name, String docString) { + super(name, Type.NULLABLE_STRING, docString, false, null); + } + } + + public static 
class CompactNullableStr extends Field { + public CompactNullableStr(String name, String docString) { + super(name, Type.COMPACT_NULLABLE_STRING, docString, false, null); + } + } + + public static class Bool extends Field { + public Bool(String name, String docString) { + super(name, Type.BOOLEAN, docString, false, null); + } + } + + public static class Array extends Field { + public Array(String name, Type elementType, String docString) { + super(name, new ArrayOf(elementType), docString, false, null); + } + } + + public static class CompactArray extends Field { + public CompactArray(String name, Type elementType, String docString) { + super(name, new CompactArrayOf(elementType), docString, false, null); + } + } + + public static class TaggedFieldsSection extends Field { + private static final String NAME = "_tagged_fields"; + private static final String DOC_STRING = "The tagged fields"; + + /** + * Create a new TaggedFieldsSection with the given tags and fields. + * + * @param fields This is an array containing Integer tags followed + * by associated Field objects. + * @return The new {@link TaggedFieldsSection} + */ + public static TaggedFieldsSection of(Object... fields) { + return new TaggedFieldsSection(TaggedFields.of(fields)); + } + + public TaggedFieldsSection(Type type) { + super(NAME, type, DOC_STRING, false, null); + } + } + + public static class ComplexArray { + public final String name; + public final String docString; + + public ComplexArray(String name, String docString) { + this.name = name; + this.docString = docString; + } + + public Field withFields(Field... fields) { + Schema elementType = new Schema(fields); + return new Field(name, new ArrayOf(elementType), docString, false, null); + } + + public Field nullableWithFields(Field... fields) { + Schema elementType = new Schema(fields); + return new Field(name, ArrayOf.nullable(elementType), docString, false, null); + } + + public Field withFields(String docStringOverride, Field... fields) { + Schema elementType = new Schema(fields); + return new Field(name, new ArrayOf(elementType), docStringOverride, false, null); + } + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/protocol/types/RawTaggedField.java b/clients/src/main/java/org/apache/kafka/common/protocol/types/RawTaggedField.java new file mode 100644 index 0000000..60deb5e --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/protocol/types/RawTaggedField.java @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.common.protocol.types; + +import java.util.Arrays; + +public class RawTaggedField { + private final int tag; + private final byte[] data; + + public RawTaggedField(int tag, byte[] data) { + this.tag = tag; + this.data = data; + } + + public int tag() { + return tag; + } + + public byte[] data() { + return data; + } + + public int size() { + return data.length; + } + + @Override + public boolean equals(Object o) { + if ((o == null) || (!o.getClass().equals(getClass()))) { + return false; + } + RawTaggedField other = (RawTaggedField) o; + return tag == other.tag && Arrays.equals(data, other.data); + } + + @Override + public int hashCode() { + return tag ^ Arrays.hashCode(data); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/protocol/types/RawTaggedFieldWriter.java b/clients/src/main/java/org/apache/kafka/common/protocol/types/RawTaggedFieldWriter.java new file mode 100644 index 0000000..7218d34 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/protocol/types/RawTaggedFieldWriter.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.common.protocol.types; + +import org.apache.kafka.common.protocol.Writable; + +import java.util.ArrayList; +import java.util.List; +import java.util.ListIterator; + +/** + * The RawTaggedFieldWriter is used by Message subclasses to serialize their + * lists of raw tags. + */ +public class RawTaggedFieldWriter { + private static final RawTaggedFieldWriter EMPTY_WRITER = + new RawTaggedFieldWriter(new ArrayList<>(0)); + + private final List fields; + private final ListIterator iter; + private int prevTag; + + public static RawTaggedFieldWriter forFields(List fields) { + if (fields == null) { + return EMPTY_WRITER; + } + return new RawTaggedFieldWriter(fields); + } + + private RawTaggedFieldWriter(List fields) { + this.fields = fields; + this.iter = this.fields.listIterator(); + this.prevTag = -1; + } + + public int numFields() { + return fields.size(); + } + + public void writeRawTags(Writable writable, int nextDefinedTag) { + while (iter.hasNext()) { + RawTaggedField field = iter.next(); + int tag = field.tag(); + if (tag >= nextDefinedTag) { + if (tag == nextDefinedTag) { + // We must not have a raw tag field that duplicates the tag of another field. + throw new RuntimeException("Attempted to use tag " + tag + " as an " + + "undefined tag."); + } + iter.previous(); + return; + } + if (tag <= prevTag) { + // The raw tag field list must be sorted by tag, and there must not be + // any duplicate tags. 
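+                // Note that prevTag persists across successive calls to writeRawTags,
+                // so the ordering requirement applies to the raw tag list as a whole,
+                // not just to the tags written in a single call.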
+ throw new RuntimeException("Invalid raw tag field list: tag " + tag + + " comes after tag " + prevTag + ", but is not higher than it."); + } + writable.writeUnsignedVarint(field.tag()); + writable.writeUnsignedVarint(field.data().length); + writable.writeByteArray(field.data()); + prevTag = tag; + } + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/protocol/types/Schema.java b/clients/src/main/java/org/apache/kafka/common/protocol/types/Schema.java new file mode 100644 index 0000000..aa6ffbe --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/protocol/types/Schema.java @@ -0,0 +1,235 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.protocol.types; + +import java.nio.ByteBuffer; +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; + +/** + * The schema for a compound record definition + */ +public class Schema extends Type { + private final static Object[] NO_VALUES = new Object[0]; + + private final BoundField[] fields; + private final Map fieldsByName; + private final boolean tolerateMissingFieldsWithDefaults; + private final Struct cachedStruct; + + /** + * Construct the schema with a given list of its field values + * + * @param fs the fields of this schema + * + * @throws SchemaException If the given list have duplicate fields + */ + public Schema(Field... fs) { + this(false, fs); + } + + /** + * Construct the schema with a given list of its field values and the ability to tolerate + * missing optional fields with defaults at the end of the schema definition. + * + * @param tolerateMissingFieldsWithDefaults whether to accept records with missing optional + * fields the end of the schema + * @param fs the fields of this schema + * + * @throws SchemaException If the given list have duplicate fields + */ + public Schema(boolean tolerateMissingFieldsWithDefaults, Field... fs) { + this.fields = new BoundField[fs.length]; + this.fieldsByName = new HashMap<>(); + this.tolerateMissingFieldsWithDefaults = tolerateMissingFieldsWithDefaults; + for (int i = 0; i < this.fields.length; i++) { + Field def = fs[i]; + if (fieldsByName.containsKey(def.name)) + throw new SchemaException("Schema contains a duplicate field: " + def.name); + this.fields[i] = new BoundField(def, this, i); + this.fieldsByName.put(def.name, this.fields[i]); + } + //6 schemas have no fields at the time of this writing (3 versions each of list_groups and api_versions) + //for such schemas there's no point in even creating a unique Struct object when deserializing. + this.cachedStruct = this.fields.length > 0 ? 
null : new Struct(this, NO_VALUES); + } + + /** + * Write a struct to the buffer + */ + @Override + public void write(ByteBuffer buffer, Object o) { + Struct r = (Struct) o; + for (BoundField field : fields) { + try { + Object value = field.def.type.validate(r.get(field)); + field.def.type.write(buffer, value); + } catch (Exception e) { + throw new SchemaException("Error writing field '" + field.def.name + "': " + + (e.getMessage() == null ? e.getClass().getName() : e.getMessage())); + } + } + } + + /** + * Read a struct from the buffer. If this schema is configured to tolerate missing + * optional fields at the end of the buffer, these fields are replaced with their default + * values; otherwise, if the schema does not tolerate missing fields, or if missing fields + * don't have a default value, a {@code SchemaException} is thrown to signify that mandatory + * fields are missing. + */ + @Override + public Struct read(ByteBuffer buffer) { + if (cachedStruct != null) { + return cachedStruct; + } + Object[] objects = new Object[fields.length]; + for (int i = 0; i < fields.length; i++) { + try { + if (tolerateMissingFieldsWithDefaults) { + if (buffer.hasRemaining()) { + objects[i] = fields[i].def.type.read(buffer); + } else if (fields[i].def.hasDefaultValue) { + objects[i] = fields[i].def.defaultValue; + } else { + throw new SchemaException("Missing value for field '" + fields[i].def.name + + "' which has no default value."); + } + } else { + objects[i] = fields[i].def.type.read(buffer); + } + } catch (Exception e) { + throw new SchemaException("Error reading field '" + fields[i].def.name + "': " + + (e.getMessage() == null ? e.getClass().getName() : e.getMessage())); + } + } + return new Struct(this, objects); + } + + /** + * The size of the given record + */ + @Override + public int sizeOf(Object o) { + int size = 0; + Struct r = (Struct) o; + for (BoundField field : fields) { + try { + size += field.def.type.sizeOf(r.get(field)); + } catch (Exception e) { + throw new SchemaException("Error computing size for field '" + field.def.name + "': " + + (e.getMessage() == null ? 
e.getClass().getName() : e.getMessage())); + } + } + return size; + } + + /** + * The number of fields in this schema + */ + public int numFields() { + return this.fields.length; + } + + /** + * Get a field by its slot in the record array + * + * @param slot The slot at which this field sits + * @return The field + */ + public BoundField get(int slot) { + return this.fields[slot]; + } + + /** + * Get a field by its name + * + * @param name The name of the field + * @return The field + */ + public BoundField get(String name) { + return this.fieldsByName.get(name); + } + + /** + * Get all the fields in this schema + */ + public BoundField[] fields() { + return this.fields; + } + + /** + * Display a string representation of the schema + */ + @Override + public String toString() { + StringBuilder b = new StringBuilder(); + b.append('{'); + for (int i = 0; i < this.fields.length; i++) { + b.append(this.fields[i].toString()); + if (i < this.fields.length - 1) + b.append(','); + } + b.append("}"); + return b.toString(); + } + + @Override + public Struct validate(Object item) { + try { + Struct struct = (Struct) item; + for (BoundField field : fields) { + try { + field.def.type.validate(struct.get(field)); + } catch (SchemaException e) { + throw new SchemaException("Invalid value for field '" + field.def.name + "': " + e.getMessage()); + } + } + return struct; + } catch (ClassCastException e) { + throw new SchemaException("Not a Struct."); + } + } + + public void walk(Visitor visitor) { + Objects.requireNonNull(visitor, "visitor must be non-null"); + handleNode(this, visitor); + } + + private static void handleNode(Type node, Visitor visitor) { + if (node instanceof Schema) { + Schema schema = (Schema) node; + visitor.visit(schema); + for (BoundField f : schema.fields()) + handleNode(f.def.type, visitor); + } else if (node.isArray()) { + visitor.visit(node); + handleNode(node.arrayElementType().get(), visitor); + } else { + visitor.visit(node); + } + } + + /** + * Override one or more of the visit methods with the desired logic. + */ + public static abstract class Visitor { + public void visit(Schema schema) {} + public void visit(Type field) {} + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/protocol/types/SchemaException.java b/clients/src/main/java/org/apache/kafka/common/protocol/types/SchemaException.java new file mode 100644 index 0000000..8bbab32 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/protocol/types/SchemaException.java @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.protocol.types; + +import org.apache.kafka.common.KafkaException; + +/** + * Thrown if the protocol schema validation fails while parsing request or response. 
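+ * <p>For example (the schema and buffer names are illustrative):
+ * <pre>{@code
+ * try {
+ *     Struct struct = requestSchema.read(buffer);
+ * } catch (SchemaException e) {
+ *     // the bytes in the buffer did not match the expected schema
+ * }
+ * }</pre>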
+ */ +public class SchemaException extends KafkaException { + + private static final long serialVersionUID = 1L; + + public SchemaException(String message) { + super(message); + } + + public SchemaException(String message, Throwable cause) { + super(message, cause); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/protocol/types/Struct.java b/clients/src/main/java/org/apache/kafka/common/protocol/types/Struct.java new file mode 100644 index 0000000..9b9b5e6 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/protocol/types/Struct.java @@ -0,0 +1,606 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.protocol.types; + +import org.apache.kafka.common.Uuid; +import org.apache.kafka.common.record.BaseRecords; + +import java.nio.ByteBuffer; +import java.util.Arrays; +import java.util.Objects; + +/** + * A record that can be serialized and deserialized according to a pre-defined schema + */ +public class Struct { + private final Schema schema; + private final Object[] values; + + Struct(Schema schema, Object[] values) { + this.schema = schema; + this.values = values; + } + + public Struct(Schema schema) { + this.schema = schema; + this.values = new Object[this.schema.numFields()]; + } + + /** + * The schema for this struct. + */ + public Schema schema() { + return this.schema; + } + + /** + * Return the value of the given pre-validated field, or if the value is missing return the default value. + * + * @param field The field for which to get the default value + * @throws SchemaException if the field has no value and has no default. + */ + private Object getFieldOrDefault(BoundField field) { + Object value = this.values[field.index]; + if (value != null) + return value; + else if (field.def.hasDefaultValue) + return field.def.defaultValue; + else if (field.def.type.isNullable()) + return null; + else + throw new SchemaException("Missing value for field '" + field.def.name + "' which has no default value."); + } + + /** + * Get the value for the field directly by the field index with no lookup needed (faster!) + * + * @param field The field to look up + * @return The value for that field. + * @throws SchemaException if the field has no value and has no default. 
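+ * (for example, a non-nullable field that was never set and whose definition
+ * supplies no default value)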
+ */ + public Object get(BoundField field) { + validateField(field); + return getFieldOrDefault(field); + } + + public Byte get(Field.Int8 field) { + return getByte(field.name); + } + + public Integer get(Field.Int32 field) { + return getInt(field.name); + } + + public Long get(Field.Int64 field) { + return getLong(field.name); + } + + public Uuid get(Field.UUID field) { + return getUuid(field.name); + } + + public Integer get(Field.Uint16 field) { + return getInt(field.name); + } + + public Short get(Field.Int16 field) { + return getShort(field.name); + } + + public Double get(Field.Float64 field) { + return getDouble(field.name); + } + + public String get(Field.Str field) { + return getString(field.name); + } + + public String get(Field.NullableStr field) { + return getString(field.name); + } + + public Boolean get(Field.Bool field) { + return getBoolean(field.name); + } + + public Object[] get(Field.Array field) { + return getArray(field.name); + } + + public Object[] get(Field.ComplexArray field) { + return getArray(field.name); + } + + public Long getOrElse(Field.Int64 field, long alternative) { + if (hasField(field.name)) + return getLong(field.name); + return alternative; + } + + public Uuid getOrElse(Field.UUID field, Uuid alternative) { + if (hasField(field.name)) + return getUuid(field.name); + return alternative; + } + + public Short getOrElse(Field.Int16 field, short alternative) { + if (hasField(field.name)) + return getShort(field.name); + return alternative; + } + + public Byte getOrElse(Field.Int8 field, byte alternative) { + if (hasField(field.name)) + return getByte(field.name); + return alternative; + } + + public Integer getOrElse(Field.Int32 field, int alternative) { + if (hasField(field.name)) + return getInt(field.name); + return alternative; + } + + public Double getOrElse(Field.Float64 field, double alternative) { + if (hasField(field.name)) + return getDouble(field.name); + return alternative; + } + + public String getOrElse(Field.NullableStr field, String alternative) { + if (hasField(field.name)) + return getString(field.name); + return alternative; + } + + public String getOrElse(Field.Str field, String alternative) { + if (hasField(field.name)) + return getString(field.name); + return alternative; + } + + public boolean getOrElse(Field.Bool field, boolean alternative) { + if (hasField(field.name)) + return getBoolean(field.name); + return alternative; + } + + public Object[] getOrEmpty(Field.Array field) { + if (hasField(field.name)) + return getArray(field.name); + return new Object[0]; + } + + public Object[] getOrEmpty(Field.ComplexArray field) { + if (hasField(field.name)) + return getArray(field.name); + return new Object[0]; + } + + /** + * Get the record value for the field with the given name by doing a hash table lookup (slower!) + * + * @param name The name of the field + * @return The value in the record + * @throws SchemaException If no such field exists + */ + public Object get(String name) { + BoundField field = schema.get(name); + if (field == null) + throw new SchemaException("No such field: " + name); + return getFieldOrDefault(field); + } + + /** + * Check if the struct contains a field. + * @param name + * @return Whether a field exists. 
+ */ + public boolean hasField(String name) { + return schema.get(name) != null; + } + + public boolean hasField(Field def) { + return schema.get(def.name) != null; + } + + public boolean hasField(Field.ComplexArray def) { + return schema.get(def.name) != null; + } + + public Struct getStruct(BoundField field) { + return (Struct) get(field); + } + + public Struct getStruct(String name) { + return (Struct) get(name); + } + + public Byte getByte(BoundField field) { + return (Byte) get(field); + } + + public byte getByte(String name) { + return (Byte) get(name); + } + + public BaseRecords getRecords(String name) { + return (BaseRecords) get(name); + } + + public Short getShort(BoundField field) { + return (Short) get(field); + } + + public Short getShort(String name) { + return (Short) get(name); + } + + public Integer getUnsignedShort(BoundField field) { + return (Integer) get(field); + } + + public Integer getUnsignedShort(String name) { + return (Integer) get(name); + } + + public Integer getInt(BoundField field) { + return (Integer) get(field); + } + + public Integer getInt(String name) { + return (Integer) get(name); + } + + public Long getUnsignedInt(String name) { + return (Long) get(name); + } + + public Long getLong(BoundField field) { + return (Long) get(field); + } + + public Long getLong(String name) { + return (Long) get(name); + } + + public Uuid getUuid(BoundField field) { + return (Uuid) get(field); + } + + public Uuid getUuid(String name) { + return (Uuid) get(name); + } + + public Double getDouble(BoundField field) { + return (Double) get(field); + } + + public Double getDouble(String name) { + return (Double) get(name); + } + + public Object[] getArray(BoundField field) { + return (Object[]) get(field); + } + + public Object[] getArray(String name) { + return (Object[]) get(name); + } + + public String getString(BoundField field) { + return (String) get(field); + } + + public String getString(String name) { + return (String) get(name); + } + + public Boolean getBoolean(BoundField field) { + return (Boolean) get(field); + } + + public Boolean getBoolean(String name) { + return (Boolean) get(name); + } + + public ByteBuffer getBytes(BoundField field) { + Object result = get(field); + if (result instanceof byte[]) + return ByteBuffer.wrap((byte[]) result); + return (ByteBuffer) result; + } + + public ByteBuffer getBytes(String name) { + Object result = get(name); + if (result instanceof byte[]) + return ByteBuffer.wrap((byte[]) result); + return (ByteBuffer) result; + } + + public byte[] getByteArray(String name) { + Object result = get(name); + if (result instanceof byte[]) + return (byte[]) result; + ByteBuffer buf = (ByteBuffer) result; + byte[] arr = new byte[buf.remaining()]; + buf.get(arr); + buf.flip(); + return arr; + } + + /** + * Set the given field to the specified value + * + * @param field The field + * @param value The value + * @throws SchemaException If the validation of the field failed + */ + public Struct set(BoundField field, Object value) { + validateField(field); + this.values[field.index] = value; + return this; + } + + /** + * Set the field specified by the given name to the value + * + * @param name The name of the field + * @param value The value to set + * @throws SchemaException If the field is not known + */ + public Struct set(String name, Object value) { + BoundField field = this.schema.get(name); + if (field == null) + throw new SchemaException("Unknown field: " + name); + this.values[field.index] = value; + return this; + } + + public Struct 
set(Field.Str def, String value) { + return set(def.name, value); + } + + public Struct set(Field.NullableStr def, String value) { + return set(def.name, value); + } + + public Struct set(Field.Int8 def, byte value) { + return set(def.name, value); + } + + public Struct set(Field.Int32 def, int value) { + return set(def.name, value); + } + + public Struct set(Field.Int64 def, long value) { + return set(def.name, value); + } + + public Struct set(Field.UUID def, Uuid value) { + return set(def.name, value); + } + + public Struct set(Field.Int16 def, short value) { + return set(def.name, value); + } + + public Struct set(Field.Uint16 def, int value) { + if (value < 0 || value > 65535) { + throw new RuntimeException("Invalid value for unsigned short for " + + def.name + ": " + value); + } + return set(def.name, value); + } + + public Struct set(Field.Float64 def, double value) { + return set(def.name, value); + } + + public Struct set(Field.Bool def, boolean value) { + return set(def.name, value); + } + + public Struct set(Field.Array def, Object[] value) { + return set(def.name, value); + } + + public Struct set(Field.ComplexArray def, Object[] value) { + return set(def.name, value); + } + + public Struct setByteArray(String name, byte[] value) { + ByteBuffer buf = value == null ? null : ByteBuffer.wrap(value); + return set(name, buf); + } + + public Struct setIfExists(Field.Array def, Object[] value) { + return setIfExists(def.name, value); + } + + public Struct setIfExists(Field.ComplexArray def, Object[] value) { + return setIfExists(def.name, value); + } + + public Struct setIfExists(Field def, Object value) { + return setIfExists(def.name, value); + } + + public Struct setIfExists(String fieldName, Object value) { + BoundField field = this.schema.get(fieldName); + if (field != null) + this.values[field.index] = value; + return this; + } + + /** + * Create a struct for the schema of a container type (struct or array). Note that for array type, this method + * assumes that the type is an array of schema and creates a struct of that schema. Arrays of other types can't be + * instantiated with this method. 
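+ * <p>A hypothetical sketch (the schema and field names are made up for illustration):
+ * <pre>{@code
+ * Struct response = new Struct(responseSchema);
+ * Struct partition = response.instance("partition_data");   // field must be a struct or an array of structs
+ * partition.set("partition_index", 0);
+ * response.set("partition_data", new Object[]{partition});
+ * }</pre>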
+ * + * @param field The field to create an instance of + * @return The struct + * @throws SchemaException If the given field is not a container type + */ + public Struct instance(BoundField field) { + validateField(field); + if (field.def.type instanceof Schema) { + return new Struct((Schema) field.def.type); + } else if (field.def.type.isArray()) { + return new Struct((Schema) field.def.type.arrayElementType().get()); + } else { + throw new SchemaException("Field '" + field.def.name + "' is not a container type, it is of type " + field.def.type); + } + } + + /** + * Create a struct instance for the given field which must be a container type (struct or array) + * + * @param field The name of the field to create (field must be a schema type) + * @return The struct + * @throws SchemaException If the given field is not a container type + */ + public Struct instance(String field) { + return instance(schema.get(field)); + } + + public Struct instance(Field field) { + return instance(schema.get(field.name)); + } + + public Struct instance(Field.ComplexArray field) { + return instance(schema.get(field.name)); + } + + /** + * Empty all the values from this record + */ + public void clear() { + Arrays.fill(this.values, null); + } + + /** + * Get the serialized size of this object + */ + public int sizeOf() { + return this.schema.sizeOf(this); + } + + /** + * Write this struct to a buffer + */ + public void writeTo(ByteBuffer buffer) { + this.schema.write(buffer, this); + } + + /** + * Ensure the user doesn't try to access fields from the wrong schema + * + * @throws SchemaException If validation fails + */ + private void validateField(BoundField field) { + Objects.requireNonNull(field, "`field` must be non-null"); + if (this.schema != field.schema) + throw new SchemaException("Attempt to access field '" + field.def.name + "' from a different schema instance."); + if (field.index > values.length) + throw new SchemaException("Invalid field index: " + field.index); + } + + /** + * Validate the contents of this struct against its schema + * + * @throws SchemaException If validation fails + */ + public void validate() { + this.schema.validate(this); + } + + @Override + public String toString() { + StringBuilder b = new StringBuilder(); + b.append('{'); + for (int i = 0; i < this.values.length; i++) { + BoundField f = this.schema.get(i); + b.append(f.def.name); + b.append('='); + if (f.def.type.isArray() && this.values[i] != null) { + Object[] arrayValue = (Object[]) this.values[i]; + b.append('['); + for (int j = 0; j < arrayValue.length; j++) { + b.append(arrayValue[j]); + if (j < arrayValue.length - 1) + b.append(','); + } + b.append(']'); + } else + b.append(this.values[i]); + if (i < this.values.length - 1) + b.append(','); + } + b.append('}'); + return b.toString(); + } + + @Override + public int hashCode() { + final int prime = 31; + int result = 1; + for (int i = 0; i < this.values.length; i++) { + BoundField f = this.schema.get(i); + if (f.def.type.isArray()) { + if (this.get(f) != null) { + Object[] arrayObject = (Object[]) this.get(f); + for (Object arrayItem: arrayObject) + result = prime * result + arrayItem.hashCode(); + } + } else { + Object field = this.get(f); + if (field != null) { + result = prime * result + field.hashCode(); + } + } + } + return result; + } + + @Override + public boolean equals(Object obj) { + if (this == obj) + return true; + if (obj == null) + return false; + if (getClass() != obj.getClass()) + return false; + Struct other = (Struct) obj; + if (schema != 
other.schema) + return false; + for (int i = 0; i < this.values.length; i++) { + BoundField f = this.schema.get(i); + boolean result; + if (f.def.type.isArray()) { + result = Arrays.equals((Object[]) this.get(f), (Object[]) other.get(f)); + } else { + Object thisField = this.get(f); + Object otherField = other.get(f); + result = Objects.equals(thisField, otherField); + } + if (!result) + return false; + } + return true; + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/protocol/types/TaggedFields.java b/clients/src/main/java/org/apache/kafka/common/protocol/types/TaggedFields.java new file mode 100644 index 0000000..4e1ab0d --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/protocol/types/TaggedFields.java @@ -0,0 +1,181 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.protocol.types; + +import org.apache.kafka.common.protocol.types.Type.DocumentedType; +import org.apache.kafka.common.utils.ByteUtils; + +import java.nio.ByteBuffer; +import java.util.Collections; +import java.util.Map; +import java.util.NavigableMap; +import java.util.TreeMap; + +/** + * Represents a tagged fields section. + */ +public class TaggedFields extends DocumentedType { + private static final String TAGGED_FIELDS_TYPE_NAME = "TAGGED_FIELDS"; + + private final Map fields; + + /** + * Create a new TaggedFields object with the given tags and fields. + * + * @param fields This is an array containing Integer tags followed + * by associated Field objects. + * @return The new {@link TaggedFields} + */ + @SuppressWarnings("unchecked") + public static TaggedFields of(Object... 
fields) { + if (fields.length % 2 != 0) { + throw new RuntimeException("TaggedFields#of takes an even " + + "number of parameters."); + } + TreeMap newFields = new TreeMap<>(); + for (int i = 0; i < fields.length; i += 2) { + Integer tag = (Integer) fields[i]; + Field field = (Field) fields[i + 1]; + newFields.put(tag, field); + } + return new TaggedFields(newFields); + } + + public TaggedFields(Map fields) { + this.fields = fields; + } + + @Override + public boolean isNullable() { + return false; + } + + @SuppressWarnings("unchecked") + @Override + public void write(ByteBuffer buffer, Object o) { + NavigableMap objects = (NavigableMap) o; + ByteUtils.writeUnsignedVarint(objects.size(), buffer); + for (Map.Entry entry : objects.entrySet()) { + Integer tag = entry.getKey(); + Field field = fields.get(tag); + ByteUtils.writeUnsignedVarint(tag, buffer); + if (field == null) { + RawTaggedField value = (RawTaggedField) entry.getValue(); + ByteUtils.writeUnsignedVarint(value.data().length, buffer); + buffer.put(value.data()); + } else { + ByteUtils.writeUnsignedVarint(field.type.sizeOf(entry.getValue()), buffer); + field.type.write(buffer, entry.getValue()); + } + } + } + + @SuppressWarnings("unchecked") + @Override + public NavigableMap read(ByteBuffer buffer) { + int numTaggedFields = ByteUtils.readUnsignedVarint(buffer); + if (numTaggedFields == 0) { + return Collections.emptyNavigableMap(); + } + NavigableMap objects = new TreeMap<>(); + int prevTag = -1; + for (int i = 0; i < numTaggedFields; i++) { + int tag = ByteUtils.readUnsignedVarint(buffer); + if (tag <= prevTag) { + throw new RuntimeException("Invalid or out-of-order tag " + tag); + } + prevTag = tag; + int size = ByteUtils.readUnsignedVarint(buffer); + Field field = fields.get(tag); + if (field == null) { + byte[] bytes = new byte[size]; + buffer.get(bytes); + objects.put(tag, new RawTaggedField(tag, bytes)); + } else { + objects.put(tag, field.type.read(buffer)); + } + } + return objects; + } + + @SuppressWarnings("unchecked") + @Override + public int sizeOf(Object o) { + int size = 0; + NavigableMap objects = (NavigableMap) o; + size += ByteUtils.sizeOfUnsignedVarint(objects.size()); + for (Map.Entry entry : objects.entrySet()) { + Integer tag = entry.getKey(); + size += ByteUtils.sizeOfUnsignedVarint(tag); + Field field = fields.get(tag); + if (field == null) { + RawTaggedField value = (RawTaggedField) entry.getValue(); + size += value.data().length + ByteUtils.sizeOfUnsignedVarint(value.data().length); + } else { + int valueSize = field.type.sizeOf(entry.getValue()); + size += valueSize + ByteUtils.sizeOfUnsignedVarint(valueSize); + } + } + return size; + } + + @Override + public String toString() { + StringBuilder bld = new StringBuilder("TAGGED_FIELDS_TYPE_NAME("); + String prefix = ""; + for (Map.Entry field : fields.entrySet()) { + bld.append(prefix); + prefix = ", "; + bld.append(field.getKey()).append(" -> ").append(field.getValue().toString()); + } + bld.append(")"); + return bld.toString(); + } + + @SuppressWarnings("unchecked") + @Override + public Map validate(Object item) { + try { + NavigableMap objects = (NavigableMap) item; + for (Map.Entry entry : objects.entrySet()) { + Integer tag = entry.getKey(); + Field field = fields.get(tag); + if (field == null) { + if (!(entry.getValue() instanceof RawTaggedField)) { + throw new SchemaException("The value associated with tag " + tag + + " must be a RawTaggedField in this version of the software."); + } + } else { + field.type.validate(entry.getValue()); + } + } + return 
objects; + } catch (ClassCastException e) { + throw new SchemaException("Not a NavigableMap."); + } + } + + @Override + public String typeName() { + return TAGGED_FIELDS_TYPE_NAME; + } + + @Override + public String documentation() { + return "Represents a series of tagged fields."; + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/protocol/types/Type.java b/clients/src/main/java/org/apache/kafka/common/protocol/types/Type.java new file mode 100644 index 0000000..46a59bd --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/protocol/types/Type.java @@ -0,0 +1,1130 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.protocol.types; + +import org.apache.kafka.common.Uuid; +import org.apache.kafka.common.record.BaseRecords; +import org.apache.kafka.common.record.MemoryRecords; +import org.apache.kafka.common.utils.ByteUtils; +import org.apache.kafka.common.utils.Utils; + +import java.nio.ByteBuffer; +import java.util.Optional; + +/** + * A serializable type + */ +public abstract class Type { + + /** + * Write the typed object to the buffer + * + * @throws SchemaException If the object is not valid for its type + */ + public abstract void write(ByteBuffer buffer, Object o); + + /** + * Read the typed object from the buffer + * + * @throws SchemaException If the object is not valid for its type + */ + public abstract Object read(ByteBuffer buffer); + + /** + * Validate the object. If succeeded return its typed object. + * + * @throws SchemaException If validation failed + */ + public abstract Object validate(Object o); + + /** + * Return the size of the object in bytes + */ + public abstract int sizeOf(Object o); + + /** + * Check if the type supports null values + * @return whether or not null is a valid value for the type implementation + */ + public boolean isNullable() { + return false; + } + + /** + * If the type is an array, return the type of the array elements. Otherwise, return empty. + */ + public Optional arrayElementType() { + return Optional.empty(); + } + + /** + * Returns true if the type is an array. + */ + public final boolean isArray() { + return arrayElementType().isPresent(); + } + + /** + * A Type that can return its description for documentation purposes. + */ + public static abstract class DocumentedType extends Type { + + /** + * Short name of the type to identify it in documentation; + * @return the name of the type + */ + public abstract String typeName(); + + /** + * Documentation of the Type. 
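+ * <p>For example, the {@code INT32} type documents itself as a four-byte integer in
+ * network byte order (big-endian) together with its valid value range.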
+ * + * @return details about valid values, representation + */ + public abstract String documentation(); + + @Override + public String toString() { + return typeName(); + } + } + /** + * The Boolean type represents a boolean value in a byte by using + * the value of 0 to represent false, and 1 to represent true. + * + * If for some reason a value that is not 0 or 1 is read, + * then any non-zero value will return true. + */ + public static final DocumentedType BOOLEAN = new DocumentedType() { + @Override + public void write(ByteBuffer buffer, Object o) { + if ((Boolean) o) + buffer.put((byte) 1); + else + buffer.put((byte) 0); + } + + @Override + public Object read(ByteBuffer buffer) { + byte value = buffer.get(); + return value != 0; + } + + @Override + public int sizeOf(Object o) { + return 1; + } + + @Override + public String typeName() { + return "BOOLEAN"; + } + + @Override + public Boolean validate(Object item) { + if (item instanceof Boolean) + return (Boolean) item; + else + throw new SchemaException(item + " is not a Boolean."); + } + + @Override + public String documentation() { + return "Represents a boolean value in a byte. " + + "Values 0 and 1 are used to represent false and true respectively. " + + "When reading a boolean value, any non-zero value is considered true."; + } + }; + + public static final DocumentedType INT8 = new DocumentedType() { + @Override + public void write(ByteBuffer buffer, Object o) { + buffer.put((Byte) o); + } + + @Override + public Object read(ByteBuffer buffer) { + return buffer.get(); + } + + @Override + public int sizeOf(Object o) { + return 1; + } + + @Override + public String typeName() { + return "INT8"; + } + + @Override + public Byte validate(Object item) { + if (item instanceof Byte) + return (Byte) item; + else + throw new SchemaException(item + " is not a Byte."); + } + + @Override + public String documentation() { + return "Represents an integer between -27 and 27-1 inclusive."; + } + }; + + public static final DocumentedType INT16 = new DocumentedType() { + @Override + public void write(ByteBuffer buffer, Object o) { + buffer.putShort((Short) o); + } + + @Override + public Object read(ByteBuffer buffer) { + return buffer.getShort(); + } + + @Override + public int sizeOf(Object o) { + return 2; + } + + @Override + public String typeName() { + return "INT16"; + } + + @Override + public Short validate(Object item) { + if (item instanceof Short) + return (Short) item; + else + throw new SchemaException(item + " is not a Short."); + } + + @Override + public String documentation() { + return "Represents an integer between -215 and 215-1 inclusive. " + + "The values are encoded using two bytes in network byte order (big-endian)."; + } + }; + + public static final DocumentedType UINT16 = new DocumentedType() { + @Override + public void write(ByteBuffer buffer, Object o) { + Integer value = (Integer) o; + buffer.putShort((short) value.intValue()); + } + + @Override + public Object read(ByteBuffer buffer) { + short value = buffer.getShort(); + return Integer.valueOf(Short.toUnsignedInt(value)); + } + + @Override + public int sizeOf(Object o) { + return 2; + } + + @Override + public String typeName() { + return "UINT16"; + } + + @Override + public Integer validate(Object item) { + if (item instanceof Integer) + return (Integer) item; + else + throw new SchemaException(item + " is not an a Integer (encoding an unsigned short)"); + } + + @Override + public String documentation() { + return "Represents an integer between 0 and 65535 inclusive. 
" + + "The values are encoded using two bytes in network byte order (big-endian)."; + } + }; + + public static final DocumentedType INT32 = new DocumentedType() { + @Override + public void write(ByteBuffer buffer, Object o) { + buffer.putInt((Integer) o); + } + + @Override + public Object read(ByteBuffer buffer) { + return buffer.getInt(); + } + + @Override + public int sizeOf(Object o) { + return 4; + } + + @Override + public String typeName() { + return "INT32"; + } + + @Override + public Integer validate(Object item) { + if (item instanceof Integer) + return (Integer) item; + else + throw new SchemaException(item + " is not an Integer."); + } + + @Override + public String documentation() { + return "Represents an integer between -231 and 231-1 inclusive. " + + "The values are encoded using four bytes in network byte order (big-endian)."; + } + }; + + public static final DocumentedType UNSIGNED_INT32 = new DocumentedType() { + @Override + public void write(ByteBuffer buffer, Object o) { + ByteUtils.writeUnsignedInt(buffer, (long) o); + } + + @Override + public Object read(ByteBuffer buffer) { + return ByteUtils.readUnsignedInt(buffer); + } + + @Override + public int sizeOf(Object o) { + return 4; + } + + @Override + public String typeName() { + return "UINT32"; + } + + @Override + public Long validate(Object item) { + if (item instanceof Long) + return (Long) item; + else + throw new SchemaException(item + " is not an a Long (encoding an unsigned integer)."); + } + + @Override + public String documentation() { + return "Represents an integer between 0 and 232-1 inclusive. " + + "The values are encoded using four bytes in network byte order (big-endian)."; + } + }; + + public static final DocumentedType INT64 = new DocumentedType() { + @Override + public void write(ByteBuffer buffer, Object o) { + buffer.putLong((Long) o); + } + + @Override + public Object read(ByteBuffer buffer) { + return buffer.getLong(); + } + + @Override + public int sizeOf(Object o) { + return 8; + } + + @Override + public String typeName() { + return "INT64"; + } + + @Override + public Long validate(Object item) { + if (item instanceof Long) + return (Long) item; + else + throw new SchemaException(item + " is not a Long."); + } + + @Override + public String documentation() { + return "Represents an integer between -263 and 263-1 inclusive. " + + "The values are encoded using eight bytes in network byte order (big-endian)."; + } + }; + + public static final DocumentedType UUID = new DocumentedType() { + @Override + public void write(ByteBuffer buffer, Object o) { + final Uuid uuid = (Uuid) o; + buffer.putLong(uuid.getMostSignificantBits()); + buffer.putLong(uuid.getLeastSignificantBits()); + } + + @Override + public Object read(ByteBuffer buffer) { + return new Uuid(buffer.getLong(), buffer.getLong()); + } + + @Override + public int sizeOf(Object o) { + return 16; + } + + @Override + public String typeName() { + return "UUID"; + } + + @Override + public Uuid validate(Object item) { + if (item instanceof Uuid) + return (Uuid) item; + else + throw new SchemaException(item + " is not a Uuid."); + } + + @Override + public String documentation() { + return "Represents a type 4 immutable universally unique identifier (Uuid). 
" + + "The values are encoded using sixteen bytes in network byte order (big-endian)."; + } + }; + + public static final DocumentedType FLOAT64 = new DocumentedType() { + @Override + public void write(ByteBuffer buffer, Object o) { + ByteUtils.writeDouble((Double) o, buffer); + } + + @Override + public Object read(ByteBuffer buffer) { + return ByteUtils.readDouble(buffer); + } + + @Override + public int sizeOf(Object o) { + return 8; + } + + @Override + public String typeName() { + return "FLOAT64"; + } + + @Override + public Double validate(Object item) { + if (item instanceof Double) + return (Double) item; + else + throw new SchemaException(item + " is not a Double."); + } + + @Override + public String documentation() { + return "Represents a double-precision 64-bit format IEEE 754 value. " + + "The values are encoded using eight bytes in network byte order (big-endian)."; + } + }; + + public static final DocumentedType STRING = new DocumentedType() { + @Override + public void write(ByteBuffer buffer, Object o) { + byte[] bytes = Utils.utf8((String) o); + if (bytes.length > Short.MAX_VALUE) + throw new SchemaException("String length " + bytes.length + " is larger than the maximum string length."); + buffer.putShort((short) bytes.length); + buffer.put(bytes); + } + + @Override + public String read(ByteBuffer buffer) { + short length = buffer.getShort(); + if (length < 0) + throw new SchemaException("String length " + length + " cannot be negative"); + if (length > buffer.remaining()) + throw new SchemaException("Error reading string of length " + length + ", only " + buffer.remaining() + " bytes available"); + String result = Utils.utf8(buffer, length); + buffer.position(buffer.position() + length); + return result; + } + + @Override + public int sizeOf(Object o) { + return 2 + Utils.utf8Length((String) o); + } + + @Override + public String typeName() { + return "STRING"; + } + + @Override + public String validate(Object item) { + if (item instanceof String) + return (String) item; + else + throw new SchemaException(item + " is not a String."); + } + + @Override + public String documentation() { + return "Represents a sequence of characters. First the length N is given as an " + INT16 + + ". Then N bytes follow which are the UTF-8 encoding of the character sequence. 
" + + "Length must not be negative."; + } + }; + + public static final DocumentedType COMPACT_STRING = new DocumentedType() { + @Override + public void write(ByteBuffer buffer, Object o) { + byte[] bytes = Utils.utf8((String) o); + if (bytes.length > Short.MAX_VALUE) + throw new SchemaException("String length " + bytes.length + " is larger than the maximum string length."); + ByteUtils.writeUnsignedVarint(bytes.length + 1, buffer); + buffer.put(bytes); + } + + @Override + public String read(ByteBuffer buffer) { + int length = ByteUtils.readUnsignedVarint(buffer) - 1; + if (length < 0) + throw new SchemaException("String length " + length + " cannot be negative"); + if (length > Short.MAX_VALUE) + throw new SchemaException("String length " + length + " is larger than the maximum string length."); + if (length > buffer.remaining()) + throw new SchemaException("Error reading string of length " + length + ", only " + buffer.remaining() + " bytes available"); + String result = Utils.utf8(buffer, length); + buffer.position(buffer.position() + length); + return result; + } + + @Override + public int sizeOf(Object o) { + int length = Utils.utf8Length((String) o); + return ByteUtils.sizeOfUnsignedVarint(length + 1) + length; + } + + @Override + public String typeName() { + return "COMPACT_STRING"; + } + + @Override + public String validate(Object item) { + if (item instanceof String) + return (String) item; + else + throw new SchemaException(item + " is not a String."); + } + + @Override + public String documentation() { + return "Represents a sequence of characters. First the length N + 1 is given as an UNSIGNED_VARINT " + + ". Then N bytes follow which are the UTF-8 encoding of the character sequence."; + } + }; + + public static final DocumentedType NULLABLE_STRING = new DocumentedType() { + @Override + public boolean isNullable() { + return true; + } + + @Override + public void write(ByteBuffer buffer, Object o) { + if (o == null) { + buffer.putShort((short) -1); + return; + } + + byte[] bytes = Utils.utf8((String) o); + if (bytes.length > Short.MAX_VALUE) + throw new SchemaException("String length " + bytes.length + " is larger than the maximum string length."); + buffer.putShort((short) bytes.length); + buffer.put(bytes); + } + + @Override + public String read(ByteBuffer buffer) { + short length = buffer.getShort(); + if (length < 0) + return null; + if (length > buffer.remaining()) + throw new SchemaException("Error reading string of length " + length + ", only " + buffer.remaining() + " bytes available"); + String result = Utils.utf8(buffer, length); + buffer.position(buffer.position() + length); + return result; + } + + @Override + public int sizeOf(Object o) { + if (o == null) + return 2; + + return 2 + Utils.utf8Length((String) o); + } + + @Override + public String typeName() { + return "NULLABLE_STRING"; + } + + @Override + public String validate(Object item) { + if (item == null) + return null; + + if (item instanceof String) + return (String) item; + else + throw new SchemaException(item + " is not a String."); + } + + @Override + public String documentation() { + return "Represents a sequence of characters or null. For non-null strings, first the length N is given as an " + INT16 + + ". Then N bytes follow which are the UTF-8 encoding of the character sequence. 
" + + "A null value is encoded with length of -1 and there are no following bytes."; + } + }; + + public static final DocumentedType COMPACT_NULLABLE_STRING = new DocumentedType() { + @Override + public boolean isNullable() { + return true; + } + + @Override + public void write(ByteBuffer buffer, Object o) { + if (o == null) { + ByteUtils.writeUnsignedVarint(0, buffer); + } else { + byte[] bytes = Utils.utf8((String) o); + if (bytes.length > Short.MAX_VALUE) + throw new SchemaException("String length " + bytes.length + " is larger than the maximum string length."); + ByteUtils.writeUnsignedVarint(bytes.length + 1, buffer); + buffer.put(bytes); + } + } + + @Override + public String read(ByteBuffer buffer) { + int length = ByteUtils.readUnsignedVarint(buffer) - 1; + if (length < 0) { + return null; + } else if (length > Short.MAX_VALUE) { + throw new SchemaException("String length " + length + " is larger than the maximum string length."); + } else if (length > buffer.remaining()) { + throw new SchemaException("Error reading string of length " + length + ", only " + buffer.remaining() + " bytes available"); + } else { + String result = Utils.utf8(buffer, length); + buffer.position(buffer.position() + length); + return result; + } + } + + @Override + public int sizeOf(Object o) { + if (o == null) { + return 1; + } + int length = Utils.utf8Length((String) o); + return ByteUtils.sizeOfUnsignedVarint(length + 1) + length; + } + + @Override + public String typeName() { + return "COMPACT_NULLABLE_STRING"; + } + + @Override + public String validate(Object item) { + if (item == null) { + return null; + } else if (item instanceof String) { + return (String) item; + } else { + throw new SchemaException(item + " is not a String."); + } + } + + @Override + public String documentation() { + return "Represents a sequence of characters. First the length N + 1 is given as an UNSIGNED_VARINT " + + ". Then N bytes follow which are the UTF-8 encoding of the character sequence. " + + "A null string is represented with a length of 0."; + } + }; + + public static final DocumentedType BYTES = new DocumentedType() { + @Override + public void write(ByteBuffer buffer, Object o) { + ByteBuffer arg = (ByteBuffer) o; + int pos = arg.position(); + buffer.putInt(arg.remaining()); + buffer.put(arg); + arg.position(pos); + } + + @Override + public Object read(ByteBuffer buffer) { + int size = buffer.getInt(); + if (size < 0) + throw new SchemaException("Bytes size " + size + " cannot be negative"); + if (size > buffer.remaining()) + throw new SchemaException("Error reading bytes of size " + size + ", only " + buffer.remaining() + " bytes available"); + + ByteBuffer val = buffer.slice(); + val.limit(size); + buffer.position(buffer.position() + size); + return val; + } + + @Override + public int sizeOf(Object o) { + ByteBuffer buffer = (ByteBuffer) o; + return 4 + buffer.remaining(); + } + + @Override + public String typeName() { + return "BYTES"; + } + + @Override + public ByteBuffer validate(Object item) { + if (item instanceof ByteBuffer) + return (ByteBuffer) item; + else + throw new SchemaException(item + " is not a java.nio.ByteBuffer."); + } + + @Override + public String documentation() { + return "Represents a raw sequence of bytes. First the length N is given as an " + INT32 + + ". 
Then N bytes follow."; + } + }; + + public static final DocumentedType COMPACT_BYTES = new DocumentedType() { + @Override + public void write(ByteBuffer buffer, Object o) { + ByteBuffer arg = (ByteBuffer) o; + int pos = arg.position(); + ByteUtils.writeUnsignedVarint(arg.remaining() + 1, buffer); + buffer.put(arg); + arg.position(pos); + } + + @Override + public Object read(ByteBuffer buffer) { + int size = ByteUtils.readUnsignedVarint(buffer) - 1; + if (size < 0) + throw new SchemaException("Bytes size " + size + " cannot be negative"); + if (size > buffer.remaining()) + throw new SchemaException("Error reading bytes of size " + size + ", only " + buffer.remaining() + " bytes available"); + + ByteBuffer val = buffer.slice(); + val.limit(size); + buffer.position(buffer.position() + size); + return val; + } + + @Override + public int sizeOf(Object o) { + ByteBuffer buffer = (ByteBuffer) o; + int remaining = buffer.remaining(); + return ByteUtils.sizeOfUnsignedVarint(remaining + 1) + remaining; + } + + @Override + public String typeName() { + return "COMPACT_BYTES"; + } + + @Override + public ByteBuffer validate(Object item) { + if (item instanceof ByteBuffer) + return (ByteBuffer) item; + else + throw new SchemaException(item + " is not a java.nio.ByteBuffer."); + } + + @Override + public String documentation() { + return "Represents a raw sequence of bytes. First the length N+1 is given as an UNSIGNED_VARINT." + + "Then N bytes follow."; + } + }; + + public static final DocumentedType NULLABLE_BYTES = new DocumentedType() { + @Override + public boolean isNullable() { + return true; + } + + @Override + public void write(ByteBuffer buffer, Object o) { + if (o == null) { + buffer.putInt(-1); + return; + } + + ByteBuffer arg = (ByteBuffer) o; + int pos = arg.position(); + buffer.putInt(arg.remaining()); + buffer.put(arg); + arg.position(pos); + } + + @Override + public Object read(ByteBuffer buffer) { + int size = buffer.getInt(); + if (size < 0) + return null; + if (size > buffer.remaining()) + throw new SchemaException("Error reading bytes of size " + size + ", only " + buffer.remaining() + " bytes available"); + + ByteBuffer val = buffer.slice(); + val.limit(size); + buffer.position(buffer.position() + size); + return val; + } + + @Override + public int sizeOf(Object o) { + if (o == null) + return 4; + + ByteBuffer buffer = (ByteBuffer) o; + return 4 + buffer.remaining(); + } + + @Override + public String typeName() { + return "NULLABLE_BYTES"; + } + + @Override + public ByteBuffer validate(Object item) { + if (item == null) + return null; + + if (item instanceof ByteBuffer) + return (ByteBuffer) item; + + throw new SchemaException(item + " is not a java.nio.ByteBuffer."); + } + + @Override + public String documentation() { + return "Represents a raw sequence of bytes or null. For non-null values, first the length N is given as an " + INT32 + + ". Then N bytes follow. 
A null value is encoded with length of -1 and there are no following bytes."; + } + }; + + public static final DocumentedType COMPACT_NULLABLE_BYTES = new DocumentedType() { + @Override + public boolean isNullable() { + return true; + } + + @Override + public void write(ByteBuffer buffer, Object o) { + if (o == null) { + ByteUtils.writeUnsignedVarint(0, buffer); + } else { + ByteBuffer arg = (ByteBuffer) o; + int pos = arg.position(); + ByteUtils.writeUnsignedVarint(arg.remaining() + 1, buffer); + buffer.put(arg); + arg.position(pos); + } + } + + @Override + public Object read(ByteBuffer buffer) { + int size = ByteUtils.readUnsignedVarint(buffer) - 1; + if (size < 0) + return null; + if (size > buffer.remaining()) + throw new SchemaException("Error reading bytes of size " + size + ", only " + buffer.remaining() + " bytes available"); + + ByteBuffer val = buffer.slice(); + val.limit(size); + buffer.position(buffer.position() + size); + return val; + } + + @Override + public int sizeOf(Object o) { + if (o == null) { + return 1; + } + ByteBuffer buffer = (ByteBuffer) o; + int remaining = buffer.remaining(); + return ByteUtils.sizeOfUnsignedVarint(remaining + 1) + remaining; + } + + @Override + public String typeName() { + return "COMPACT_NULLABLE_BYTES"; + } + + @Override + public ByteBuffer validate(Object item) { + if (item == null) + return null; + + if (item instanceof ByteBuffer) + return (ByteBuffer) item; + + throw new SchemaException(item + " is not a java.nio.ByteBuffer."); + } + + @Override + public String documentation() { + return "Represents a raw sequence of bytes. First the length N+1 is given as an UNSIGNED_VARINT." + + "Then N bytes follow. A null object is represented with a length of 0."; + } + }; + + public static final DocumentedType COMPACT_RECORDS = new DocumentedType() { + @Override + public boolean isNullable() { + return true; + } + + @Override + public void write(ByteBuffer buffer, Object o) { + if (o == null) { + COMPACT_NULLABLE_BYTES.write(buffer, null); + } else if (o instanceof MemoryRecords) { + MemoryRecords records = (MemoryRecords) o; + COMPACT_NULLABLE_BYTES.write(buffer, records.buffer().duplicate()); + } else { + throw new IllegalArgumentException("Unexpected record type: " + o.getClass()); + } + } + + @Override + public MemoryRecords read(ByteBuffer buffer) { + ByteBuffer recordsBuffer = (ByteBuffer) COMPACT_NULLABLE_BYTES.read(buffer); + if (recordsBuffer == null) { + return null; + } else { + return MemoryRecords.readableRecords(recordsBuffer); + } + } + + @Override + public int sizeOf(Object o) { + if (o == null) { + return 1; + } + + BaseRecords records = (BaseRecords) o; + int recordsSize = records.sizeInBytes(); + return ByteUtils.sizeOfUnsignedVarint(recordsSize + 1) + recordsSize; + } + + @Override + public String typeName() { + return "COMPACT_RECORDS"; + } + + @Override + public BaseRecords validate(Object item) { + if (item == null) + return null; + + if (item instanceof BaseRecords) + return (BaseRecords) item; + + throw new SchemaException(item + " is not an instance of " + BaseRecords.class.getName()); + } + + @Override + public String documentation() { + return "Represents a sequence of Kafka records as " + COMPACT_NULLABLE_BYTES + ". 
" + + "For a detailed description of records see " + + "Message Sets."; + } + }; + + public static final DocumentedType RECORDS = new DocumentedType() { + @Override + public boolean isNullable() { + return true; + } + + @Override + public void write(ByteBuffer buffer, Object o) { + if (o == null) { + NULLABLE_BYTES.write(buffer, null); + } else if (o instanceof MemoryRecords) { + MemoryRecords records = (MemoryRecords) o; + NULLABLE_BYTES.write(buffer, records.buffer().duplicate()); + } else { + throw new IllegalArgumentException("Unexpected record type: " + o.getClass()); + } + } + + @Override + public MemoryRecords read(ByteBuffer buffer) { + ByteBuffer recordsBuffer = (ByteBuffer) NULLABLE_BYTES.read(buffer); + if (recordsBuffer == null) { + return null; + } else { + return MemoryRecords.readableRecords(recordsBuffer); + } + } + + @Override + public int sizeOf(Object o) { + if (o == null) + return 4; + + BaseRecords records = (BaseRecords) o; + return 4 + records.sizeInBytes(); + } + + @Override + public String typeName() { + return "RECORDS"; + } + + @Override + public BaseRecords validate(Object item) { + if (item == null) + return null; + + if (item instanceof BaseRecords) + return (BaseRecords) item; + + throw new SchemaException(item + " is not an instance of " + BaseRecords.class.getName()); + } + + @Override + public String documentation() { + return "Represents a sequence of Kafka records as " + NULLABLE_BYTES + ". " + + "For a detailed description of records see " + + "Message Sets."; + } + }; + + public static final DocumentedType VARINT = new DocumentedType() { + @Override + public void write(ByteBuffer buffer, Object o) { + ByteUtils.writeVarint((Integer) o, buffer); + } + + @Override + public Integer read(ByteBuffer buffer) { + return ByteUtils.readVarint(buffer); + } + + @Override + public Integer validate(Object item) { + if (item instanceof Integer) + return (Integer) item; + throw new SchemaException(item + " is not an integer"); + } + + public String typeName() { + return "VARINT"; + } + + @Override + public int sizeOf(Object o) { + return ByteUtils.sizeOfVarint((Integer) o); + } + + @Override + public String documentation() { + return "Represents an integer between -231 and 231-1 inclusive. " + + "Encoding follows the variable-length zig-zag encoding from " + + " Google Protocol Buffers."; + } + }; + + public static final DocumentedType VARLONG = new DocumentedType() { + @Override + public void write(ByteBuffer buffer, Object o) { + ByteUtils.writeVarlong((Long) o, buffer); + } + + @Override + public Long read(ByteBuffer buffer) { + return ByteUtils.readVarlong(buffer); + } + + @Override + public Long validate(Object item) { + if (item instanceof Long) + return (Long) item; + throw new SchemaException(item + " is not a long"); + } + + public String typeName() { + return "VARLONG"; + } + + @Override + public int sizeOf(Object o) { + return ByteUtils.sizeOfVarlong((Long) o); + } + + @Override + public String documentation() { + return "Represents an integer between -263 and 263-1 inclusive. 
" + + "Encoding follows the variable-length zig-zag encoding from " + + " Google Protocol Buffers."; + } + }; + + private static String toHtml() { + DocumentedType[] types = { + BOOLEAN, INT8, INT16, INT32, INT64, + UNSIGNED_INT32, VARINT, VARLONG, UUID, FLOAT64, + STRING, COMPACT_STRING, NULLABLE_STRING, COMPACT_NULLABLE_STRING, + BYTES, COMPACT_BYTES, NULLABLE_BYTES, COMPACT_NULLABLE_BYTES, + RECORDS, new ArrayOf(STRING), new CompactArrayOf(COMPACT_STRING)}; + final StringBuilder b = new StringBuilder(); + b.append("\n"); + b.append(""); + b.append("\n"); + b.append("\n"); + b.append("\n"); + for (DocumentedType type : types) { + b.append(""); + b.append(""); + b.append(""); + b.append("\n"); + } + b.append("
<tr><th>Type</th><th>Description</th></tr>
        "); + b.append(type.typeName()); + b.append(""); + b.append(type.documentation()); + b.append("
        \n"); + return b.toString(); + } + + public static void main(String[] args) { + System.out.println(toHtml()); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/quota/ClientQuotaAlteration.java b/clients/src/main/java/org/apache/kafka/common/quota/ClientQuotaAlteration.java new file mode 100644 index 0000000..88670ce --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/quota/ClientQuotaAlteration.java @@ -0,0 +1,106 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.common.quota; + +import java.util.Collection; +import java.util.Objects; + +/** + * Describes a configuration alteration to be made to a client quota entity. + */ +public class ClientQuotaAlteration { + + public static class Op { + private final String key; + private final Double value; + + /** + * @param key the quota type to alter + * @param value if set then the existing value is updated, + * otherwise if null, the existing value is cleared + */ + public Op(String key, Double value) { + this.key = key; + this.value = value; + } + + /** + * @return the quota type to alter + */ + public String key() { + return this.key; + } + + /** + * @return if set then the existing value is updated, + * otherwise if null, the existing value is cleared + */ + public Double value() { + return this.value; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + Op that = (Op) o; + return Objects.equals(key, that.key) && Objects.equals(value, that.value); + } + + @Override + public int hashCode() { + return Objects.hash(key, value); + } + + @Override + public String toString() { + return "ClientQuotaAlteration.Op(key=" + key + ", value=" + value + ")"; + } + } + + private final ClientQuotaEntity entity; + private final Collection ops; + + /** + * @param entity the entity whose config will be modified + * @param ops the alteration to perform + */ + public ClientQuotaAlteration(ClientQuotaEntity entity, Collection ops) { + this.entity = entity; + this.ops = ops; + } + + /** + * @return the entity whose config will be modified + */ + public ClientQuotaEntity entity() { + return this.entity; + } + + /** + * @return the alteration to perform + */ + public Collection ops() { + return this.ops; + } + + @Override + public String toString() { + return "ClientQuotaAlteration(entity=" + entity + ", ops=" + ops + ")"; + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/quota/ClientQuotaEntity.java b/clients/src/main/java/org/apache/kafka/common/quota/ClientQuotaEntity.java new file mode 100644 index 0000000..d6dffdb --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/quota/ClientQuotaEntity.java @@ -0,0 +1,77 @@ +/* + * Licensed 
to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.common.quota; + +import java.util.Map; +import java.util.Objects; + +/** + * Describes a client quota entity, which is a mapping of entity types to their names. + */ +public class ClientQuotaEntity { + + private final Map entries; + + /** + * The type of an entity entry. + */ + public static final String USER = "user"; + public static final String CLIENT_ID = "client-id"; + public static final String IP = "ip"; + + public static boolean isValidEntityType(String entityType) { + return Objects.equals(entityType, USER) || + Objects.equals(entityType, CLIENT_ID) || + Objects.equals(entityType, IP); + } + + /** + * Constructs a quota entity for the given types and names. If a name is null, + * then it is mapped to the built-in default entity name. + * + * @param entries maps entity type to its name + */ + public ClientQuotaEntity(Map entries) { + this.entries = entries; + } + + /** + * @return map of entity type to its name + */ + public Map entries() { + return this.entries; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + ClientQuotaEntity that = (ClientQuotaEntity) o; + return Objects.equals(entries, that.entries); + } + + @Override + public int hashCode() { + return Objects.hash(entries); + } + + @Override + public String toString() { + return "ClientQuotaEntity(entries=" + entries + ")"; + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/quota/ClientQuotaFilter.java b/clients/src/main/java/org/apache/kafka/common/quota/ClientQuotaFilter.java new file mode 100644 index 0000000..e8a6a72 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/quota/ClientQuotaFilter.java @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.common.quota; + +import java.util.Collection; +import java.util.Collections; +import java.util.Objects; + +/** + * Describes a client quota entity filter. 
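 * <p>
 * Usage sketch (hypothetical; the entity name "alice" is illustrative only): a strict filter that
 * matches quotas keyed solely by that user, versus a filter over every configured entity.
 * <pre>{@code
 * ClientQuotaFilter userOnly = ClientQuotaFilter.containsOnly(
 *     Collections.singleton(ClientQuotaFilterComponent.ofEntity(ClientQuotaEntity.USER, "alice")));
 * ClientQuotaFilter everything = ClientQuotaFilter.all();
 * }</pre>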
+ */ +public class ClientQuotaFilter { + + private final Collection components; + private final boolean strict; + + /** + * A filter to be applied to matching client quotas. + * + * @param components the components to filter on + * @param strict whether the filter only includes specified components + */ + private ClientQuotaFilter(Collection components, boolean strict) { + this.components = components; + this.strict = strict; + } + + /** + * Constructs and returns a quota filter that matches all provided components. Matching entities + * with entity types that are not specified by a component will also be included in the result. + * + * @param components the components for the filter + */ + public static ClientQuotaFilter contains(Collection components) { + return new ClientQuotaFilter(components, false); + } + + /** + * Constructs and returns a quota filter that matches all provided components. Matching entities + * with entity types that are not specified by a component will *not* be included in the result. + * + * @param components the components for the filter + */ + public static ClientQuotaFilter containsOnly(Collection components) { + return new ClientQuotaFilter(components, true); + } + + /** + * Constructs and returns a quota filter that matches all configured entities. + */ + public static ClientQuotaFilter all() { + return new ClientQuotaFilter(Collections.emptyList(), false); + } + + /** + * @return the filter's components + */ + public Collection components() { + return this.components; + } + + /** + * @return whether the filter is strict, i.e. only includes specified components + */ + public boolean strict() { + return this.strict; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + ClientQuotaFilter that = (ClientQuotaFilter) o; + return Objects.equals(components, that.components) && Objects.equals(strict, that.strict); + } + + @Override + public int hashCode() { + return Objects.hash(components, strict); + } + + @Override + public String toString() { + return "ClientQuotaFilter(components=" + components + ", strict=" + strict + ")"; + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/quota/ClientQuotaFilterComponent.java b/clients/src/main/java/org/apache/kafka/common/quota/ClientQuotaFilterComponent.java new file mode 100644 index 0000000..9fc8f72 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/quota/ClientQuotaFilterComponent.java @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.common.quota; + +import java.util.Objects; +import java.util.Optional; + +/** + * Describes a component for applying a client quota filter. 
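 * <p>
 * Sketch of the three match modes (the user name is hypothetical):
 * <pre>{@code
 * ClientQuotaFilterComponent.ofEntity(ClientQuotaEntity.USER, "alice"); // match() is Optional.of("alice")
 * ClientQuotaFilterComponent.ofDefaultEntity(ClientQuotaEntity.USER);   // match() is Optional.empty(): the default entity
 * ClientQuotaFilterComponent.ofEntityType(ClientQuotaEntity.USER);      // match() is null: any specified name
 * }</pre>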
+ */ +public class ClientQuotaFilterComponent { + + private final String entityType; + private final Optional match; + + /** + * A filter to be applied. + * + * @param entityType the entity type the filter component applies to + * @param match if present, the name that's matched exactly + * if empty, matches the default name + * if null, matches any specified name + */ + private ClientQuotaFilterComponent(String entityType, Optional match) { + this.entityType = Objects.requireNonNull(entityType); + this.match = match; + } + + /** + * Constructs and returns a filter component that exactly matches the provided entity + * name for the entity type. + * + * @param entityType the entity type the filter component applies to + * @param entityName the entity name that's matched exactly + */ + public static ClientQuotaFilterComponent ofEntity(String entityType, String entityName) { + return new ClientQuotaFilterComponent(entityType, Optional.of(Objects.requireNonNull(entityName))); + } + + /** + * Constructs and returns a filter component that matches the built-in default entity name + * for the entity type. + * + * @param entityType the entity type the filter component applies to + */ + public static ClientQuotaFilterComponent ofDefaultEntity(String entityType) { + return new ClientQuotaFilterComponent(entityType, Optional.empty()); + } + + /** + * Constructs and returns a filter component that matches any specified name for the + * entity type. + * + * @param entityType the entity type the filter component applies to + */ + public static ClientQuotaFilterComponent ofEntityType(String entityType) { + return new ClientQuotaFilterComponent(entityType, null); + } + + /** + * @return the component's entity type + */ + public String entityType() { + return this.entityType; + } + + /** + * @return the optional match string, where: + * if present, the name that's matched exactly + * if empty, matches the default name + * if null, matches any specified name + */ + public Optional match() { + return this.match; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + ClientQuotaFilterComponent that = (ClientQuotaFilterComponent) o; + return Objects.equals(that.entityType, entityType) && Objects.equals(that.match, match); + } + + @Override + public int hashCode() { + return Objects.hash(entityType, match); + } + + @Override + public String toString() { + return "ClientQuotaFilterComponent(entityType=" + entityType + ", match=" + match + ")"; + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/record/AbstractLegacyRecordBatch.java b/clients/src/main/java/org/apache/kafka/common/record/AbstractLegacyRecordBatch.java new file mode 100644 index 0000000..0f2ccde --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/record/AbstractLegacyRecordBatch.java @@ -0,0 +1,624 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.record; + +import org.apache.kafka.common.InvalidRecordException; +import org.apache.kafka.common.KafkaException; +import org.apache.kafka.common.errors.CorruptRecordException; +import org.apache.kafka.common.header.Header; +import org.apache.kafka.common.utils.AbstractIterator; +import org.apache.kafka.common.utils.BufferSupplier; +import org.apache.kafka.common.utils.ByteBufferOutputStream; +import org.apache.kafka.common.utils.ByteUtils; +import org.apache.kafka.common.utils.CloseableIterator; +import org.apache.kafka.common.utils.Utils; + +import java.io.DataOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.nio.ByteBuffer; +import java.util.ArrayDeque; +import java.util.Iterator; +import java.util.NoSuchElementException; +import java.util.Objects; +import java.util.OptionalLong; + +import static org.apache.kafka.common.record.Records.LOG_OVERHEAD; +import static org.apache.kafka.common.record.Records.OFFSET_OFFSET; + +/** + * This {@link RecordBatch} implementation is for magic versions 0 and 1. In addition to implementing + * {@link RecordBatch}, it also implements {@link Record}, which exposes the duality of the old message + * format in its handling of compressed messages. The wrapper record is considered the record batch in this + * interface, while the inner records are considered the log records (though they both share the same schema). + * + * In general, this class should not be used directly. Instances of {@link Records} provides access to this + * class indirectly through the {@link RecordBatch} interface. 
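 * <p>
 * Rough framing sketch for a single legacy entry as written by {@link #writeTo(ByteBuffer)}
 * (an outline for orientation, not a normative description of the format):
 * <pre>
 *   offset : int64 (8 bytes)
 *   size   : int32 (4 bytes, the size of the message that follows)
 *   record : LegacyRecord (size bytes; for compressed sets the record value embeds the inner records)
 * </pre>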
+ */ +public abstract class AbstractLegacyRecordBatch extends AbstractRecordBatch implements Record { + + public abstract LegacyRecord outerRecord(); + + @Override + public long lastOffset() { + return offset(); + } + + @Override + public boolean isValid() { + return outerRecord().isValid(); + } + + @Override + public void ensureValid() { + outerRecord().ensureValid(); + } + + @Override + public int keySize() { + return outerRecord().keySize(); + } + + @Override + public boolean hasKey() { + return outerRecord().hasKey(); + } + + @Override + public ByteBuffer key() { + return outerRecord().key(); + } + + @Override + public int valueSize() { + return outerRecord().valueSize(); + } + + @Override + public boolean hasValue() { + return !outerRecord().hasNullValue(); + } + + @Override + public ByteBuffer value() { + return outerRecord().value(); + } + + @Override + public Header[] headers() { + return Record.EMPTY_HEADERS; + } + + @Override + public boolean hasMagic(byte magic) { + return magic == outerRecord().magic(); + } + + @Override + public boolean hasTimestampType(TimestampType timestampType) { + return outerRecord().timestampType() == timestampType; + } + + @Override + public long checksum() { + return outerRecord().checksum(); + } + + @Override + public long maxTimestamp() { + return timestamp(); + } + + @Override + public long timestamp() { + return outerRecord().timestamp(); + } + + @Override + public TimestampType timestampType() { + return outerRecord().timestampType(); + } + + @Override + public long baseOffset() { + return iterator().next().offset(); + } + + @Override + public byte magic() { + return outerRecord().magic(); + } + + @Override + public CompressionType compressionType() { + return outerRecord().compressionType(); + } + + @Override + public int sizeInBytes() { + return outerRecord().sizeInBytes() + LOG_OVERHEAD; + } + + @Override + public Integer countOrNull() { + return null; + } + + @Override + public String toString() { + return "LegacyRecordBatch(offset=" + offset() + ", " + outerRecord() + ")"; + } + + @Override + public void writeTo(ByteBuffer buffer) { + writeHeader(buffer, offset(), outerRecord().sizeInBytes()); + buffer.put(outerRecord().buffer().duplicate()); + } + + @Override + public long producerId() { + return RecordBatch.NO_PRODUCER_ID; + } + + @Override + public short producerEpoch() { + return RecordBatch.NO_PRODUCER_EPOCH; + } + + @Override + public boolean hasProducerId() { + return false; + } + + @Override + public int sequence() { + return RecordBatch.NO_SEQUENCE; + } + + @Override + public int baseSequence() { + return RecordBatch.NO_SEQUENCE; + } + + @Override + public int lastSequence() { + return RecordBatch.NO_SEQUENCE; + } + + @Override + public boolean isTransactional() { + return false; + } + + @Override + public int partitionLeaderEpoch() { + return RecordBatch.NO_PARTITION_LEADER_EPOCH; + } + + @Override + public boolean isControlBatch() { + return false; + } + + @Override + public OptionalLong deleteHorizonMs() { + return OptionalLong.empty(); + } + + /** + * Get an iterator for the nested entries contained within this batch. Note that + * if the batch is not compressed, then this method will return an iterator over the + * shallow record only (i.e. this object). 
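 * <p>
 * Minimal sketch of the uncompressed case (assuming {@code batch} refers to this object):
 * <pre>{@code
 * Iterator<Record> it = batch.iterator();
 * Record only = it.next();   // the shallow record is the batch itself
 * assert !it.hasNext();
 * }</pre>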
+ * @return An iterator over the records contained within this batch + */ + @Override + public Iterator iterator() { + return iterator(BufferSupplier.NO_CACHING); + } + + CloseableIterator iterator(BufferSupplier bufferSupplier) { + if (isCompressed()) + return new DeepRecordsIterator(this, false, Integer.MAX_VALUE, bufferSupplier); + + return new CloseableIterator() { + private boolean hasNext = true; + + @Override + public void close() {} + + @Override + public boolean hasNext() { + return hasNext; + } + + @Override + public Record next() { + if (!hasNext) + throw new NoSuchElementException(); + hasNext = false; + return AbstractLegacyRecordBatch.this; + } + + @Override + public void remove() { + throw new UnsupportedOperationException(); + } + }; + } + + @Override + public CloseableIterator streamingIterator(BufferSupplier bufferSupplier) { + // the older message format versions do not support streaming, so we return the normal iterator + return iterator(bufferSupplier); + } + + static void writeHeader(ByteBuffer buffer, long offset, int size) { + buffer.putLong(offset); + buffer.putInt(size); + } + + static void writeHeader(DataOutputStream out, long offset, int size) throws IOException { + out.writeLong(offset); + out.writeInt(size); + } + + private static final class DataLogInputStream implements LogInputStream { + private final InputStream stream; + protected final int maxMessageSize; + private final ByteBuffer offsetAndSizeBuffer; + + DataLogInputStream(InputStream stream, int maxMessageSize) { + this.stream = stream; + this.maxMessageSize = maxMessageSize; + this.offsetAndSizeBuffer = ByteBuffer.allocate(Records.LOG_OVERHEAD); + } + + public AbstractLegacyRecordBatch nextBatch() throws IOException { + offsetAndSizeBuffer.clear(); + Utils.readFully(stream, offsetAndSizeBuffer); + if (offsetAndSizeBuffer.hasRemaining()) + return null; + + long offset = offsetAndSizeBuffer.getLong(Records.OFFSET_OFFSET); + int size = offsetAndSizeBuffer.getInt(Records.SIZE_OFFSET); + if (size < LegacyRecord.RECORD_OVERHEAD_V0) + throw new CorruptRecordException(String.format("Record size is less than the minimum record overhead (%d)", LegacyRecord.RECORD_OVERHEAD_V0)); + if (size > maxMessageSize) + throw new CorruptRecordException(String.format("Record size exceeds the largest allowable message size (%d).", maxMessageSize)); + + ByteBuffer batchBuffer = ByteBuffer.allocate(size); + Utils.readFully(stream, batchBuffer); + if (batchBuffer.hasRemaining()) + return null; + batchBuffer.flip(); + + return new BasicLegacyRecordBatch(offset, new LegacyRecord(batchBuffer)); + } + } + + private static class DeepRecordsIterator extends AbstractIterator implements CloseableIterator { + private final ArrayDeque innerEntries; + private final long absoluteBaseOffset; + private final byte wrapperMagic; + + private DeepRecordsIterator(AbstractLegacyRecordBatch wrapperEntry, + boolean ensureMatchingMagic, + int maxMessageSize, + BufferSupplier bufferSupplier) { + LegacyRecord wrapperRecord = wrapperEntry.outerRecord(); + this.wrapperMagic = wrapperRecord.magic(); + if (wrapperMagic != RecordBatch.MAGIC_VALUE_V0 && wrapperMagic != RecordBatch.MAGIC_VALUE_V1) + throw new InvalidRecordException("Invalid wrapper magic found in legacy deep record iterator " + wrapperMagic); + + CompressionType compressionType = wrapperRecord.compressionType(); + if (compressionType == CompressionType.ZSTD) + throw new InvalidRecordException("Invalid wrapper compressionType found in legacy deep record iterator " + wrapperMagic); + 
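            // The wrapper record's value holds the nested (possibly compressed) message set; it is
            // decompressed eagerly below so that relative offsets can be converted to absolute ones.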
ByteBuffer wrapperValue = wrapperRecord.value(); + if (wrapperValue == null) + throw new InvalidRecordException("Found invalid compressed record set with null value (magic = " + + wrapperMagic + ")"); + + InputStream stream = compressionType.wrapForInput(wrapperValue, wrapperRecord.magic(), bufferSupplier); + LogInputStream logStream = new DataLogInputStream(stream, maxMessageSize); + + long lastOffsetFromWrapper = wrapperEntry.lastOffset(); + long timestampFromWrapper = wrapperRecord.timestamp(); + this.innerEntries = new ArrayDeque<>(); + + // If relative offset is used, we need to decompress the entire message first to compute + // the absolute offset. For simplicity and because it's a format that is on its way out, we + // do the same for message format version 0 + try { + while (true) { + AbstractLegacyRecordBatch innerEntry = logStream.nextBatch(); + if (innerEntry == null) + break; + + LegacyRecord record = innerEntry.outerRecord(); + byte magic = record.magic(); + + if (ensureMatchingMagic && magic != wrapperMagic) + throw new InvalidRecordException("Compressed message magic " + magic + + " does not match wrapper magic " + wrapperMagic); + + if (magic == RecordBatch.MAGIC_VALUE_V1) { + LegacyRecord recordWithTimestamp = new LegacyRecord( + record.buffer(), + timestampFromWrapper, + wrapperRecord.timestampType()); + innerEntry = new BasicLegacyRecordBatch(innerEntry.lastOffset(), recordWithTimestamp); + } + + innerEntries.addLast(innerEntry); + } + + if (innerEntries.isEmpty()) + throw new InvalidRecordException("Found invalid compressed record set with no inner records"); + + if (wrapperMagic == RecordBatch.MAGIC_VALUE_V1) { + if (lastOffsetFromWrapper == 0) { + // The outer offset may be 0 if this is produce data from certain versions of librdkafka. + this.absoluteBaseOffset = 0; + } else { + long lastInnerOffset = innerEntries.getLast().offset(); + if (lastOffsetFromWrapper < lastInnerOffset) + throw new InvalidRecordException("Found invalid wrapper offset in compressed v1 message set, " + + "wrapper offset '" + lastOffsetFromWrapper + "' is less than the last inner message " + + "offset '" + lastInnerOffset + "' and it is not zero."); + this.absoluteBaseOffset = lastOffsetFromWrapper - lastInnerOffset; + } + } else { + this.absoluteBaseOffset = -1; + } + } catch (IOException e) { + throw new KafkaException(e); + } finally { + Utils.closeQuietly(stream, "records iterator stream"); + } + } + + @Override + protected Record makeNext() { + if (innerEntries.isEmpty()) + return allDone(); + + AbstractLegacyRecordBatch entry = innerEntries.remove(); + + // Convert offset to absolute offset if needed. 
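                // Worked example with hypothetical numbers: if the wrapper's last offset is 1005 and the
                // inner records carry relative offsets 0, 1 and 2, then the constructor computed
                // absoluteBaseOffset = 1005 - 2 = 1003 and the records surface at offsets 1003, 1004, 1005.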
+ if (wrapperMagic == RecordBatch.MAGIC_VALUE_V1) { + long absoluteOffset = absoluteBaseOffset + entry.offset(); + entry = new BasicLegacyRecordBatch(absoluteOffset, entry.outerRecord()); + } + + if (entry.isCompressed()) + throw new InvalidRecordException("Inner messages must not be compressed"); + + return entry; + } + + @Override + public void close() {} + } + + private static class BasicLegacyRecordBatch extends AbstractLegacyRecordBatch { + private final LegacyRecord record; + private final long offset; + + private BasicLegacyRecordBatch(long offset, LegacyRecord record) { + this.offset = offset; + this.record = record; + } + + @Override + public long offset() { + return offset; + } + + @Override + public LegacyRecord outerRecord() { + return record; + } + + @Override + public boolean equals(Object o) { + if (this == o) + return true; + if (o == null || getClass() != o.getClass()) + return false; + + BasicLegacyRecordBatch that = (BasicLegacyRecordBatch) o; + + return offset == that.offset && + Objects.equals(record, that.record); + } + + @Override + public int hashCode() { + int result = record != null ? record.hashCode() : 0; + result = 31 * result + Long.hashCode(offset); + return result; + } + } + + static class ByteBufferLegacyRecordBatch extends AbstractLegacyRecordBatch implements MutableRecordBatch { + private final ByteBuffer buffer; + private final LegacyRecord record; + + ByteBufferLegacyRecordBatch(ByteBuffer buffer) { + this.buffer = buffer; + buffer.position(LOG_OVERHEAD); + this.record = new LegacyRecord(buffer.slice()); + buffer.position(OFFSET_OFFSET); + } + + @Override + public long offset() { + return buffer.getLong(OFFSET_OFFSET); + } + + @Override + public OptionalLong deleteHorizonMs() { + return OptionalLong.empty(); + } + + @Override + public LegacyRecord outerRecord() { + return record; + } + + @Override + public void setLastOffset(long offset) { + buffer.putLong(OFFSET_OFFSET, offset); + } + + @Override + public void setMaxTimestamp(TimestampType timestampType, long timestamp) { + if (record.magic() == RecordBatch.MAGIC_VALUE_V0) + throw new UnsupportedOperationException("Cannot set timestamp for a record with magic = 0"); + + long currentTimestamp = record.timestamp(); + // We don't need to recompute crc if the timestamp is not updated. + if (record.timestampType() == timestampType && currentTimestamp == timestamp) + return; + + setTimestampAndUpdateCrc(timestampType, timestamp); + } + + @Override + public void setPartitionLeaderEpoch(int epoch) { + throw new UnsupportedOperationException("Magic versions prior to 2 do not support partition leader epoch"); + } + + private void setTimestampAndUpdateCrc(TimestampType timestampType, long timestamp) { + byte attributes = LegacyRecord.computeAttributes(magic(), compressionType(), timestampType); + buffer.put(LOG_OVERHEAD + LegacyRecord.ATTRIBUTES_OFFSET, attributes); + buffer.putLong(LOG_OVERHEAD + LegacyRecord.TIMESTAMP_OFFSET, timestamp); + long crc = record.computeChecksum(); + ByteUtils.writeUnsignedInt(buffer, LOG_OVERHEAD + LegacyRecord.CRC_OFFSET, crc); + } + + /** + * LegacyRecordBatch does not implement this iterator and would hence fallback to the normal iterator. 
+ * + * @return An iterator over the records contained within this batch + */ + @Override + public CloseableIterator skipKeyValueIterator(BufferSupplier bufferSupplier) { + return CloseableIterator.wrap(iterator(bufferSupplier)); + } + + @Override + public void writeTo(ByteBufferOutputStream outputStream) { + outputStream.write(buffer.duplicate()); + } + + @Override + public boolean equals(Object o) { + if (this == o) + return true; + if (o == null || getClass() != o.getClass()) + return false; + + ByteBufferLegacyRecordBatch that = (ByteBufferLegacyRecordBatch) o; + + return Objects.equals(buffer, that.buffer); + } + + @Override + public int hashCode() { + return buffer != null ? buffer.hashCode() : 0; + } + } + + static class LegacyFileChannelRecordBatch extends FileLogInputStream.FileChannelRecordBatch { + + LegacyFileChannelRecordBatch(long offset, + byte magic, + FileRecords fileRecords, + int position, + int batchSize) { + super(offset, magic, fileRecords, position, batchSize); + } + + @Override + protected RecordBatch toMemoryRecordBatch(ByteBuffer buffer) { + return new ByteBufferLegacyRecordBatch(buffer); + } + + @Override + public long baseOffset() { + return loadFullBatch().baseOffset(); + } + + @Override + public OptionalLong deleteHorizonMs() { + return OptionalLong.empty(); + } + + @Override + public long lastOffset() { + return offset; + } + + @Override + public long producerId() { + return RecordBatch.NO_PRODUCER_ID; + } + + @Override + public short producerEpoch() { + return RecordBatch.NO_PRODUCER_EPOCH; + } + + @Override + public int baseSequence() { + return RecordBatch.NO_SEQUENCE; + } + + @Override + public int lastSequence() { + return RecordBatch.NO_SEQUENCE; + } + + @Override + public Integer countOrNull() { + return null; + } + + @Override + public boolean isTransactional() { + return false; + } + + @Override + public boolean isControlBatch() { + return false; + } + + @Override + public int partitionLeaderEpoch() { + return RecordBatch.NO_PARTITION_LEADER_EPOCH; + } + + @Override + protected int headerSize() { + return LOG_OVERHEAD + LegacyRecord.headerSize(magic); + } + + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/record/AbstractRecordBatch.java b/clients/src/main/java/org/apache/kafka/common/record/AbstractRecordBatch.java new file mode 100644 index 0000000..d104fcd --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/record/AbstractRecordBatch.java @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.record; + + +abstract class AbstractRecordBatch implements RecordBatch { + @Override + public boolean hasProducerId() { + return RecordBatch.NO_PRODUCER_ID < producerId(); + } + + @Override + public long nextOffset() { + return lastOffset() + 1; + } + + @Override + public boolean isCompressed() { + return compressionType() != CompressionType.NONE; + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/record/AbstractRecords.java b/clients/src/main/java/org/apache/kafka/common/record/AbstractRecords.java new file mode 100644 index 0000000..265ef5b --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/record/AbstractRecords.java @@ -0,0 +1,152 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.record; + +import org.apache.kafka.common.header.Header; +import org.apache.kafka.common.utils.AbstractIterator; +import org.apache.kafka.common.utils.Utils; + +import java.nio.ByteBuffer; +import java.util.Iterator; + +public abstract class AbstractRecords implements Records { + + private final Iterable records = this::recordsIterator; + + @Override + public boolean hasMatchingMagic(byte magic) { + for (RecordBatch batch : batches()) + if (batch.magic() != magic) + return false; + return true; + } + + public RecordBatch firstBatch() { + Iterator iterator = batches().iterator(); + + if (!iterator.hasNext()) + return null; + + return iterator.next(); + } + + /** + * Get an iterator over the deep records. 
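The "deep records" iterator described here flattens every batch into its individual records, whereas `batches()` stays at the batch level. A minimal sketch of the difference, assuming the public `MemoryRecords.withRecords` factory and `SimpleRecord` helper from the same package (not shown in this diff):

```java
import org.apache.kafka.common.record.CompressionType;
import org.apache.kafka.common.record.MemoryRecords;
import org.apache.kafka.common.record.Record;
import org.apache.kafka.common.record.RecordBatch;
import org.apache.kafka.common.record.SimpleRecord;

public class DeepVsShallowIteration {
    public static void main(String[] args) {
        // Two records compressed together into a single batch.
        MemoryRecords records = MemoryRecords.withRecords(CompressionType.GZIP,
                new SimpleRecord("k1".getBytes(), "v1".getBytes()),
                new SimpleRecord("k2".getBytes(), "v2".getBytes()));

        // Shallow iteration: one batch, records are not individually materialized.
        for (RecordBatch batch : records.batches())
            System.out.println("batch: baseOffset=" + batch.baseOffset() + ", count=" + batch.countOrNull());

        // Deep iteration via AbstractRecords.records(): each contained record in turn.
        for (Record record : records.records())
            System.out.println("record: offset=" + record.offset() + ", keySize=" + record.keySize());
    }
}
```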
+ * @return An iterator over the records + */ + @Override + public Iterable records() { + return records; + } + + @Override + public DefaultRecordsSend toSend() { + return new DefaultRecordsSend<>(this); + } + + private Iterator recordsIterator() { + return new AbstractIterator() { + private final Iterator batches = batches().iterator(); + private Iterator records; + + @Override + protected Record makeNext() { + if (records != null && records.hasNext()) + return records.next(); + + if (batches.hasNext()) { + records = batches.next().iterator(); + return makeNext(); + } + + return allDone(); + } + }; + } + + public static int estimateSizeInBytes(byte magic, + long baseOffset, + CompressionType compressionType, + Iterable records) { + int size = 0; + if (magic <= RecordBatch.MAGIC_VALUE_V1) { + for (Record record : records) + size += Records.LOG_OVERHEAD + LegacyRecord.recordSize(magic, record.key(), record.value()); + } else { + size = DefaultRecordBatch.sizeInBytes(baseOffset, records); + } + return estimateCompressedSizeInBytes(size, compressionType); + } + + public static int estimateSizeInBytes(byte magic, + CompressionType compressionType, + Iterable records) { + int size = 0; + if (magic <= RecordBatch.MAGIC_VALUE_V1) { + for (SimpleRecord record : records) + size += Records.LOG_OVERHEAD + LegacyRecord.recordSize(magic, record.key(), record.value()); + } else { + size = DefaultRecordBatch.sizeInBytes(records); + } + return estimateCompressedSizeInBytes(size, compressionType); + } + + private static int estimateCompressedSizeInBytes(int size, CompressionType compressionType) { + return compressionType == CompressionType.NONE ? size : Math.min(Math.max(size / 2, 1024), 1 << 16); + } + + /** + * Get an upper bound estimate on the batch size needed to hold a record with the given fields. This is only + * an estimate because it does not take into account overhead from the compression algorithm. + */ + public static int estimateSizeInBytesUpperBound(byte magic, CompressionType compressionType, byte[] key, byte[] value, Header[] headers) { + return estimateSizeInBytesUpperBound(magic, compressionType, Utils.wrapNullable(key), Utils.wrapNullable(value), headers); + } + + /** + * Get an upper bound estimate on the batch size needed to hold a record with the given fields. This is only + * an estimate because it does not take into account overhead from the compression algorithm. + */ + public static int estimateSizeInBytesUpperBound(byte magic, CompressionType compressionType, ByteBuffer key, + ByteBuffer value, Header[] headers) { + if (magic >= RecordBatch.MAGIC_VALUE_V2) + return DefaultRecordBatch.estimateBatchSizeUpperBound(key, value, headers); + else if (compressionType != CompressionType.NONE) + return Records.LOG_OVERHEAD + LegacyRecord.recordOverhead(magic) + LegacyRecord.recordSize(magic, key, value); + else + return Records.LOG_OVERHEAD + LegacyRecord.recordSize(magic, key, value); + } + + /** + * Return the size of the record batch header. + * + * For V0 and V1 with no compression, it's unclear if Records.LOG_OVERHEAD or 0 should be chosen. There is no header + * per batch, but a sequence of batches is preceded by the offset and size. This method returns `0` as it's what + * `MemoryRecordsBuilder` requires. 
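The estimation helpers above use exact per-record sizes when data is uncompressed, but for compressed data they fall back to the rough heuristic `min(max(size / 2, 1024), 64 KiB)`. A hedged usage sketch of these static helpers (the `SimpleRecord` values are made-up illustration data):

```java
import java.util.Arrays;
import java.util.List;

import org.apache.kafka.common.record.AbstractRecords;
import org.apache.kafka.common.record.CompressionType;
import org.apache.kafka.common.record.Record;
import org.apache.kafka.common.record.RecordBatch;
import org.apache.kafka.common.record.SimpleRecord;

public class SizeEstimationExample {
    public static void main(String[] args) {
        List<SimpleRecord> records = Arrays.asList(
                new SimpleRecord("key-1".getBytes(), "value-1".getBytes()),
                new SimpleRecord("key-2".getBytes(), "value-2".getBytes()));

        // Uncompressed v2 estimate: batch overhead plus exact serialized record sizes.
        int plain = AbstractRecords.estimateSizeInBytes(
                RecordBatch.MAGIC_VALUE_V2, CompressionType.NONE, records);

        // Compressed estimate: the same size fed through min(max(size / 2, 1024), 64 KiB).
        int gzip = AbstractRecords.estimateSizeInBytes(
                RecordBatch.MAGIC_VALUE_V2, CompressionType.GZIP, records);

        // Upper bound for a single record, e.g. to reject oversized writes up front.
        int upperBound = AbstractRecords.estimateSizeInBytesUpperBound(
                RecordBatch.MAGIC_VALUE_V2, CompressionType.NONE,
                "key-1".getBytes(), "value-1".getBytes(), Record.EMPTY_HEADERS);

        System.out.printf("plain=%d gzip=%d upperBound=%d%n", plain, gzip, upperBound);
    }
}
```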
+ */ + public static int recordBatchHeaderSizeInBytes(byte magic, CompressionType compressionType) { + if (magic > RecordBatch.MAGIC_VALUE_V1) { + return DefaultRecordBatch.RECORD_BATCH_OVERHEAD; + } else if (compressionType != CompressionType.NONE) { + return Records.LOG_OVERHEAD + LegacyRecord.recordOverhead(magic); + } else { + return 0; + } + } + + +} diff --git a/clients/src/main/java/org/apache/kafka/common/record/BaseRecords.java b/clients/src/main/java/org/apache/kafka/common/record/BaseRecords.java new file mode 100644 index 0000000..49e316f --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/record/BaseRecords.java @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.record; + +/** + * Base interface for accessing records which could be contained in the log, or an in-memory materialization of log records. + */ +public interface BaseRecords { + /** + * The size of these records in bytes. + * @return The size in bytes of the records + */ + int sizeInBytes(); + + /** + * Encapsulate this {@link BaseRecords} object into {@link RecordsSend} + * @return Initialized {@link RecordsSend} object + */ + RecordsSend toSend(); +} diff --git a/clients/src/main/java/org/apache/kafka/common/record/ByteBufferLogInputStream.java b/clients/src/main/java/org/apache/kafka/common/record/ByteBufferLogInputStream.java new file mode 100644 index 0000000..572abd8 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/record/ByteBufferLogInputStream.java @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.record; + +import org.apache.kafka.common.errors.CorruptRecordException; + +import java.nio.ByteBuffer; + +import static org.apache.kafka.common.record.Records.HEADER_SIZE_UP_TO_MAGIC; +import static org.apache.kafka.common.record.Records.LOG_OVERHEAD; +import static org.apache.kafka.common.record.Records.MAGIC_OFFSET; +import static org.apache.kafka.common.record.Records.SIZE_OFFSET; + +/** + * A byte buffer backed log input stream. This class avoids the need to copy records by returning + * slices from the underlying byte buffer. + */ +class ByteBufferLogInputStream implements LogInputStream { + private final ByteBuffer buffer; + private final int maxMessageSize; + + ByteBufferLogInputStream(ByteBuffer buffer, int maxMessageSize) { + this.buffer = buffer; + this.maxMessageSize = maxMessageSize; + } + + public MutableRecordBatch nextBatch() { + int remaining = buffer.remaining(); + + Integer batchSize = nextBatchSize(); + if (batchSize == null || remaining < batchSize) + return null; + + byte magic = buffer.get(buffer.position() + MAGIC_OFFSET); + + ByteBuffer batchSlice = buffer.slice(); + batchSlice.limit(batchSize); + buffer.position(buffer.position() + batchSize); + + if (magic > RecordBatch.MAGIC_VALUE_V1) + return new DefaultRecordBatch(batchSlice); + else + return new AbstractLegacyRecordBatch.ByteBufferLegacyRecordBatch(batchSlice); + } + + /** + * Validates the header of the next batch and returns batch size. + * @return next batch size including LOG_OVERHEAD if buffer contains header up to + * magic byte, null otherwise + * @throws CorruptRecordException if record size or magic is invalid + */ + Integer nextBatchSize() throws CorruptRecordException { + int remaining = buffer.remaining(); + if (remaining < LOG_OVERHEAD) + return null; + int recordSize = buffer.getInt(buffer.position() + SIZE_OFFSET); + // V0 has the smallest overhead, stricter checking is done later + if (recordSize < LegacyRecord.RECORD_OVERHEAD_V0) + throw new CorruptRecordException(String.format("Record size %d is less than the minimum record overhead (%d)", + recordSize, LegacyRecord.RECORD_OVERHEAD_V0)); + if (recordSize > maxMessageSize) + throw new CorruptRecordException(String.format("Record size %d exceeds the largest allowable message size (%d).", + recordSize, maxMessageSize)); + + if (remaining < HEADER_SIZE_UP_TO_MAGIC) + return null; + + byte magic = buffer.get(buffer.position() + MAGIC_OFFSET); + if (magic < 0 || magic > RecordBatch.CURRENT_MAGIC_VALUE) + throw new CorruptRecordException("Invalid magic found in record: " + magic); + + return recordSize + LOG_OVERHEAD; + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/record/CompressionRatioEstimator.java b/clients/src/main/java/org/apache/kafka/common/record/CompressionRatioEstimator.java new file mode 100644 index 0000000..264525b --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/record/CompressionRatioEstimator.java @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.common.record; + +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; + + +/** + * This class help estimate the compression ratio for each topic and compression type combination. + */ +public class CompressionRatioEstimator { + // The constant speed to increase compression ratio when a batch compresses better than expected. + public static final float COMPRESSION_RATIO_IMPROVING_STEP = 0.005f; + // The minimum speed to decrease compression ratio when a batch compresses worse than expected. + public static final float COMPRESSION_RATIO_DETERIORATE_STEP = 0.05f; + private static final ConcurrentMap COMPRESSION_RATIO = new ConcurrentHashMap<>(); + + /** + * Update the compression ratio estimation for a topic and compression type. + * + * @param topic the topic to update compression ratio estimation. + * @param type the compression type. + * @param observedRatio the observed compression ratio. + * @return the compression ratio estimation after the update. + */ + public static float updateEstimation(String topic, CompressionType type, float observedRatio) { + float[] compressionRatioForTopic = getAndCreateEstimationIfAbsent(topic); + float currentEstimation = compressionRatioForTopic[type.id]; + synchronized (compressionRatioForTopic) { + if (observedRatio > currentEstimation) + compressionRatioForTopic[type.id] = Math.max(currentEstimation + COMPRESSION_RATIO_DETERIORATE_STEP, observedRatio); + else if (observedRatio < currentEstimation) { + compressionRatioForTopic[type.id] = Math.max(currentEstimation - COMPRESSION_RATIO_IMPROVING_STEP, observedRatio); + } + } + return compressionRatioForTopic[type.id]; + } + + /** + * Get the compression ratio estimation for a topic and compression type. + */ + public static float estimation(String topic, CompressionType type) { + float[] compressionRatioForTopic = getAndCreateEstimationIfAbsent(topic); + return compressionRatioForTopic[type.id]; + } + + /** + * Reset the compression ratio estimation to the initial values for a topic. + */ + public static void resetEstimation(String topic) { + float[] compressionRatioForTopic = getAndCreateEstimationIfAbsent(topic); + synchronized (compressionRatioForTopic) { + for (CompressionType type : CompressionType.values()) { + compressionRatioForTopic[type.id] = type.rate; + } + } + } + + /** + * Remove the compression ratio estimation for a topic. + */ + public static void removeEstimation(String topic) { + COMPRESSION_RATIO.remove(topic); + } + + /** + * Set the compression estimation for a topic compression type combination. This method is for unit test purpose. 
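The update rule above is deliberately asymmetric: when a batch compresses worse than expected, the estimate jumps up by at least `COMPRESSION_RATIO_DETERIORATE_STEP` (0.05), but when it compresses better, the estimate only creeps down by `COMPRESSION_RATIO_IMPROVING_STEP` (0.005), so the producer errs toward over-allocating. A short worked example (the topic name is hypothetical):

```java
import org.apache.kafka.common.record.CompressionRatioEstimator;
import org.apache.kafka.common.record.CompressionType;

public class RatioEstimatorExample {
    public static void main(String[] args) {
        String topic = "demo-topic";  // hypothetical topic name

        // The initial estimate is CompressionType.rate, i.e. 1.0 for GZIP.
        float initial = CompressionRatioEstimator.estimation(topic, CompressionType.GZIP);

        // A batch that compressed much better than expected (ratio 0.4):
        // the estimate only improves slowly -> max(1.0 - 0.005, 0.4) = 0.995.
        float afterGoodBatch = CompressionRatioEstimator.updateEstimation(topic, CompressionType.GZIP, 0.4f);

        // A batch that compressed worse than expected (ratio 1.2):
        // the estimate deteriorates quickly -> max(0.995 + 0.05, 1.2) = 1.2.
        float afterBadBatch = CompressionRatioEstimator.updateEstimation(topic, CompressionType.GZIP, 1.2f);

        System.out.printf("initial=%.3f afterGoodBatch=%.3f afterBadBatch=%.3f%n",
                initial, afterGoodBatch, afterBadBatch);
    }
}
```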
+ */ + public static void setEstimation(String topic, CompressionType type, float ratio) { + float[] compressionRatioForTopic = getAndCreateEstimationIfAbsent(topic); + synchronized (compressionRatioForTopic) { + compressionRatioForTopic[type.id] = ratio; + } + } + + private static float[] getAndCreateEstimationIfAbsent(String topic) { + float[] compressionRatioForTopic = COMPRESSION_RATIO.get(topic); + if (compressionRatioForTopic == null) { + compressionRatioForTopic = initialCompressionRatio(); + float[] existingCompressionRatio = COMPRESSION_RATIO.putIfAbsent(topic, compressionRatioForTopic); + // Someone created the compression ratio array before us, use it. + if (existingCompressionRatio != null) + return existingCompressionRatio; + } + return compressionRatioForTopic; + } + + private static float[] initialCompressionRatio() { + float[] compressionRatio = new float[CompressionType.values().length]; + for (CompressionType type : CompressionType.values()) { + compressionRatio[type.id] = type.rate; + } + return compressionRatio; + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/record/CompressionType.java b/clients/src/main/java/org/apache/kafka/common/record/CompressionType.java new file mode 100644 index 0000000..1b9754f --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/record/CompressionType.java @@ -0,0 +1,193 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.record; + +import org.apache.kafka.common.KafkaException; +import org.apache.kafka.common.compress.KafkaLZ4BlockInputStream; +import org.apache.kafka.common.compress.KafkaLZ4BlockOutputStream; +import org.apache.kafka.common.compress.SnappyFactory; +import org.apache.kafka.common.compress.ZstdFactory; +import org.apache.kafka.common.utils.BufferSupplier; +import org.apache.kafka.common.utils.ByteBufferInputStream; +import org.apache.kafka.common.utils.ByteBufferOutputStream; + +import java.io.BufferedInputStream; +import java.io.BufferedOutputStream; +import java.io.InputStream; +import java.io.OutputStream; +import java.nio.ByteBuffer; +import java.util.zip.GZIPInputStream; +import java.util.zip.GZIPOutputStream; + +/** + * The compression type to use + */ +public enum CompressionType { + NONE(0, "none", 1.0f) { + @Override + public OutputStream wrapForOutput(ByteBufferOutputStream buffer, byte messageVersion) { + return buffer; + } + + @Override + public InputStream wrapForInput(ByteBuffer buffer, byte messageVersion, BufferSupplier decompressionBufferSupplier) { + return new ByteBufferInputStream(buffer); + } + }, + + // Shipped with the JDK + GZIP(1, "gzip", 1.0f) { + @Override + public OutputStream wrapForOutput(ByteBufferOutputStream buffer, byte messageVersion) { + try { + // Set input buffer (uncompressed) to 16 KB (none by default) and output buffer (compressed) to + // 8 KB (0.5 KB by default) to ensure reasonable performance in cases where the caller passes a small + // number of bytes to write (potentially a single byte) + return new BufferedOutputStream(new GZIPOutputStream(buffer, 8 * 1024), 16 * 1024); + } catch (Exception e) { + throw new KafkaException(e); + } + } + + @Override + public InputStream wrapForInput(ByteBuffer buffer, byte messageVersion, BufferSupplier decompressionBufferSupplier) { + try { + // Set output buffer (uncompressed) to 16 KB (none by default) and input buffer (compressed) to + // 8 KB (0.5 KB by default) to ensure reasonable performance in cases where the caller reads a small + // number of bytes (potentially a single byte) + return new BufferedInputStream(new GZIPInputStream(new ByteBufferInputStream(buffer), 8 * 1024), + 16 * 1024); + } catch (Exception e) { + throw new KafkaException(e); + } + } + }, + + // We should only load classes from a given compression library when we actually use said compression library. This + // is because compression libraries include native code for a set of platforms and we want to avoid errors + // in case the platform is not supported and the compression library is not actually used. + // To ensure this, we only reference compression library code from classes that are only invoked when actual usage + // happens. 
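The comment above is why the SNAPPY, LZ4, and ZSTD constants below delegate to factory classes rather than touching the native libraries directly. As a round-trip illustration of `wrapForOutput`/`wrapForInput`, a minimal sketch assuming `ByteBufferOutputStream` and `BufferSupplier.NO_CACHING` from `org.apache.kafka.common.utils` (not part of this diff):

```java
import java.io.ByteArrayOutputStream;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;

import org.apache.kafka.common.record.CompressionType;
import org.apache.kafka.common.record.RecordBatch;
import org.apache.kafka.common.utils.BufferSupplier;
import org.apache.kafka.common.utils.ByteBufferOutputStream;

public class CompressionRoundTrip {
    public static void main(String[] args) throws Exception {
        byte[] payload = "hello, kafka".getBytes(StandardCharsets.UTF_8);

        // Compress into an automatically growing ByteBufferOutputStream.
        ByteBufferOutputStream bufferStream = new ByteBufferOutputStream(64);
        try (OutputStream out = CompressionType.GZIP.wrapForOutput(bufferStream, RecordBatch.CURRENT_MAGIC_VALUE)) {
            out.write(payload);
        }

        // Flip the underlying buffer so the compressed bytes can be read back.
        ByteBuffer compressed = bufferStream.buffer();
        compressed.flip();

        // Decompress; NO_CACHING means no decompression buffers are reused.
        ByteArrayOutputStream decompressed = new ByteArrayOutputStream();
        try (InputStream in = CompressionType.GZIP.wrapForInput(compressed, RecordBatch.CURRENT_MAGIC_VALUE,
                BufferSupplier.NO_CACHING)) {
            byte[] chunk = new byte[256];
            int n;
            while ((n = in.read(chunk)) != -1)
                decompressed.write(chunk, 0, n);
        }

        System.out.println(new String(decompressed.toByteArray(), StandardCharsets.UTF_8));
    }
}
```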
+ + SNAPPY(2, "snappy", 1.0f) { + @Override + public OutputStream wrapForOutput(ByteBufferOutputStream buffer, byte messageVersion) { + return SnappyFactory.wrapForOutput(buffer); + } + + @Override + public InputStream wrapForInput(ByteBuffer buffer, byte messageVersion, BufferSupplier decompressionBufferSupplier) { + return SnappyFactory.wrapForInput(buffer); + } + }, + + LZ4(3, "lz4", 1.0f) { + @Override + public OutputStream wrapForOutput(ByteBufferOutputStream buffer, byte messageVersion) { + try { + return new KafkaLZ4BlockOutputStream(buffer, messageVersion == RecordBatch.MAGIC_VALUE_V0); + } catch (Throwable e) { + throw new KafkaException(e); + } + } + + @Override + public InputStream wrapForInput(ByteBuffer inputBuffer, byte messageVersion, BufferSupplier decompressionBufferSupplier) { + try { + return new KafkaLZ4BlockInputStream(inputBuffer, decompressionBufferSupplier, + messageVersion == RecordBatch.MAGIC_VALUE_V0); + } catch (Throwable e) { + throw new KafkaException(e); + } + } + }, + + ZSTD(4, "zstd", 1.0f) { + @Override + public OutputStream wrapForOutput(ByteBufferOutputStream buffer, byte messageVersion) { + return ZstdFactory.wrapForOutput(buffer); + } + + @Override + public InputStream wrapForInput(ByteBuffer buffer, byte messageVersion, BufferSupplier decompressionBufferSupplier) { + return ZstdFactory.wrapForInput(buffer, messageVersion, decompressionBufferSupplier); + } + }; + + public final int id; + public final String name; + public final float rate; + + CompressionType(int id, String name, float rate) { + this.id = id; + this.name = name; + this.rate = rate; + } + + /** + * Wrap bufferStream with an OutputStream that will compress data with this CompressionType. + * + * Note: Unlike {@link #wrapForInput}, {@link #wrapForOutput} cannot take {@link ByteBuffer}s directly. + * Currently, {@link MemoryRecordsBuilder#writeDefaultBatchHeader()} and {@link MemoryRecordsBuilder#writeLegacyCompressedWrapperHeader()} + * write to the underlying buffer in the given {@link ByteBufferOutputStream} after the compressed data has been written. + * In the event that the buffer needs to be expanded while writing the data, access to the underlying buffer needs to be preserved. + */ + public abstract OutputStream wrapForOutput(ByteBufferOutputStream bufferStream, byte messageVersion); + + /** + * Wrap buffer with an InputStream that will decompress data with this CompressionType. + * + * @param decompressionBufferSupplier The supplier of ByteBuffer(s) used for decompression if supported. + * For small record batches, allocating a potentially large buffer (64 KB for LZ4) + * will dominate the cost of decompressing and iterating over the records in the + * batch. As such, a supplier that reuses buffers will have a significant + * performance impact. 
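The `decompressionBufferSupplier` parameter described just above matters most when iterating many small compressed batches, since reusing buffers avoids repeatedly allocating, for example, the 64 KB LZ4 block buffer. A hedged sketch of a consumer-style loop, assuming the `streamingIterator` method on the public `RecordBatch` interface and `BufferSupplier.create()` (neither shown in this diff):

```java
import org.apache.kafka.common.record.Record;
import org.apache.kafka.common.record.RecordBatch;
import org.apache.kafka.common.record.Records;
import org.apache.kafka.common.utils.BufferSupplier;
import org.apache.kafka.common.utils.CloseableIterator;

public class BufferReuseExample {

    // Iterate every record of every batch, reusing decompression buffers across batches.
    public static long countRecords(Records records) {
        BufferSupplier bufferSupplier = BufferSupplier.create();  // caching supplier, typically one per thread
        long count = 0;
        for (RecordBatch batch : records.batches()) {
            try (CloseableIterator<Record> it = batch.streamingIterator(bufferSupplier)) {
                while (it.hasNext()) {
                    it.next();
                    count++;
                }
            }
        }
        bufferSupplier.close();
        return count;
    }
}
```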
+ */ + public abstract InputStream wrapForInput(ByteBuffer buffer, byte messageVersion, BufferSupplier decompressionBufferSupplier); + + public static CompressionType forId(int id) { + switch (id) { + case 0: + return NONE; + case 1: + return GZIP; + case 2: + return SNAPPY; + case 3: + return LZ4; + case 4: + return ZSTD; + default: + throw new IllegalArgumentException("Unknown compression type id: " + id); + } + } + + public static CompressionType forName(String name) { + if (NONE.name.equals(name)) + return NONE; + else if (GZIP.name.equals(name)) + return GZIP; + else if (SNAPPY.name.equals(name)) + return SNAPPY; + else if (LZ4.name.equals(name)) + return LZ4; + else if (ZSTD.name.equals(name)) + return ZSTD; + else + throw new IllegalArgumentException("Unknown compression name: " + name); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/record/ControlRecordType.java b/clients/src/main/java/org/apache/kafka/common/record/ControlRecordType.java new file mode 100644 index 0000000..40c64f8 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/record/ControlRecordType.java @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.record; + +import org.apache.kafka.common.InvalidRecordException; +import org.apache.kafka.common.protocol.types.Field; +import org.apache.kafka.common.protocol.types.Schema; +import org.apache.kafka.common.protocol.types.Struct; +import org.apache.kafka.common.protocol.types.Type; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.nio.ByteBuffer; + +/** + * Control records specify a schema for the record key which includes a version and type: + * + * Key => Version Type + * Version => Int16 + * Type => Int16 + * + * In the future, the version can be bumped to indicate a new schema, but it must be backwards compatible + * with the current schema. In general, this means we can add new fields, but we cannot remove old ones. + * + * Note that control records are not considered for compaction by the log cleaner. + * + * The schema for the value field is left to the control record type to specify. + */ +public enum ControlRecordType { + ABORT((short) 0), + COMMIT((short) 1), + + // Raft quorum related control messages. 
+ LEADER_CHANGE((short) 2), + SNAPSHOT_HEADER((short) 3), + SNAPSHOT_FOOTER((short) 4), + + // UNKNOWN is used to indicate a control type which the client is not aware of and should be ignored + UNKNOWN((short) -1); + + private static final Logger log = LoggerFactory.getLogger(ControlRecordType.class); + + static final short CURRENT_CONTROL_RECORD_KEY_VERSION = 0; + static final int CURRENT_CONTROL_RECORD_KEY_SIZE = 4; + private static final Schema CONTROL_RECORD_KEY_SCHEMA_VERSION_V0 = new Schema( + new Field("version", Type.INT16), + new Field("type", Type.INT16)); + + final short type; + + ControlRecordType(short type) { + this.type = type; + } + + public Struct recordKey() { + if (this == UNKNOWN) + throw new IllegalArgumentException("Cannot serialize UNKNOWN control record type"); + + Struct struct = new Struct(CONTROL_RECORD_KEY_SCHEMA_VERSION_V0); + struct.set("version", CURRENT_CONTROL_RECORD_KEY_VERSION); + struct.set("type", type); + return struct; + } + + public static short parseTypeId(ByteBuffer key) { + if (key.remaining() < CURRENT_CONTROL_RECORD_KEY_SIZE) + throw new InvalidRecordException("Invalid value size found for end control record key. Must have " + + "at least " + CURRENT_CONTROL_RECORD_KEY_SIZE + " bytes, but found only " + key.remaining()); + + short version = key.getShort(0); + if (version < 0) + throw new InvalidRecordException("Invalid version found for control record: " + version + + ". May indicate data corruption"); + + if (version != CURRENT_CONTROL_RECORD_KEY_VERSION) + log.debug("Received unknown control record key version {}. Parsing as version {}", version, + CURRENT_CONTROL_RECORD_KEY_VERSION); + return key.getShort(2); + } + + public static ControlRecordType fromTypeId(short typeId) { + switch (typeId) { + case 0: + return ABORT; + case 1: + return COMMIT; + case 2: + return LEADER_CHANGE; + case 3: + return SNAPSHOT_HEADER; + case 4: + return SNAPSHOT_FOOTER; + + default: + return UNKNOWN; + } + } + + public static ControlRecordType parse(ByteBuffer key) { + return fromTypeId(parseTypeId(key)); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/record/ControlRecordUtils.java b/clients/src/main/java/org/apache/kafka/common/record/ControlRecordUtils.java new file mode 100644 index 0000000..e74f641 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/record/ControlRecordUtils.java @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.record; + +import org.apache.kafka.common.message.LeaderChangeMessage; +import org.apache.kafka.common.message.SnapshotHeaderRecord; +import org.apache.kafka.common.message.SnapshotFooterRecord; +import org.apache.kafka.common.protocol.ByteBufferAccessor; + +import java.nio.ByteBuffer; + +/** + * Utility class for easy interaction with control records. + */ +public class ControlRecordUtils { + + public static final short LEADER_CHANGE_SCHEMA_HIGHEST_VERSION = new LeaderChangeMessage().highestSupportedVersion(); + public static final short SNAPSHOT_HEADER_HIGHEST_VERSION = new SnapshotHeaderRecord().highestSupportedVersion(); + public static final short SNAPSHOT_FOOTER_HIGHEST_VERSION = new SnapshotFooterRecord().highestSupportedVersion(); + + public static LeaderChangeMessage deserializeLeaderChangeMessage(Record record) { + ControlRecordType recordType = ControlRecordType.parse(record.key()); + if (recordType != ControlRecordType.LEADER_CHANGE) { + throw new IllegalArgumentException( + "Expected LEADER_CHANGE control record type(2), but found " + recordType.toString()); + } + return deserializeLeaderChangeMessage(record.value().duplicate()); + } + + public static LeaderChangeMessage deserializeLeaderChangeMessage(ByteBuffer data) { + ByteBufferAccessor byteBufferAccessor = new ByteBufferAccessor(data.duplicate()); + return new LeaderChangeMessage(byteBufferAccessor, LEADER_CHANGE_SCHEMA_HIGHEST_VERSION); + } + + public static SnapshotHeaderRecord deserializedSnapshotHeaderRecord(Record record) { + ControlRecordType recordType = ControlRecordType.parse(record.key()); + if (recordType != ControlRecordType.SNAPSHOT_HEADER) { + throw new IllegalArgumentException( + "Expected SNAPSHOT_HEADER control record type(3), but found " + recordType.toString()); + } + return deserializedSnapshotHeaderRecord(record.value().duplicate()); + } + + public static SnapshotHeaderRecord deserializedSnapshotHeaderRecord(ByteBuffer data) { + ByteBufferAccessor byteBufferAccessor = new ByteBufferAccessor(data.duplicate()); + return new SnapshotHeaderRecord(byteBufferAccessor, SNAPSHOT_HEADER_HIGHEST_VERSION); + } + + public static SnapshotFooterRecord deserializedSnapshotFooterRecord(Record record) { + ControlRecordType recordType = ControlRecordType.parse(record.key()); + if (recordType != ControlRecordType.SNAPSHOT_FOOTER) { + throw new IllegalArgumentException( + "Expected SNAPSHOT_FOOTER control record type(4), but found " + recordType.toString()); + } + return deserializedSnapshotFooterRecord(record.value().duplicate()); + } + + public static SnapshotFooterRecord deserializedSnapshotFooterRecord(ByteBuffer data) { + ByteBufferAccessor byteBufferAccessor = new ByteBufferAccessor(data.duplicate()); + return new SnapshotFooterRecord(byteBufferAccessor, SNAPSHOT_FOOTER_HIGHEST_VERSION); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/record/ConvertedRecords.java b/clients/src/main/java/org/apache/kafka/common/record/ConvertedRecords.java new file mode 100644 index 0000000..d9150e5 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/record/ConvertedRecords.java @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.record; + +public class ConvertedRecords { + + private final T records; + private final RecordConversionStats recordConversionStats; + + public ConvertedRecords(T records, RecordConversionStats recordConversionStats) { + this.records = records; + this.recordConversionStats = recordConversionStats; + } + + public T records() { + return records; + } + + public RecordConversionStats recordConversionStats() { + return recordConversionStats; + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/record/DefaultRecord.java b/clients/src/main/java/org/apache/kafka/common/record/DefaultRecord.java new file mode 100644 index 0000000..8772556 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/record/DefaultRecord.java @@ -0,0 +1,622 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.record; + +import org.apache.kafka.common.InvalidRecordException; +import org.apache.kafka.common.header.Header; +import org.apache.kafka.common.header.internals.RecordHeader; +import org.apache.kafka.common.utils.ByteUtils; +import org.apache.kafka.common.utils.PrimitiveRef; +import org.apache.kafka.common.utils.PrimitiveRef.IntRef; +import org.apache.kafka.common.utils.Utils; + +import java.io.DataInput; +import java.io.DataOutputStream; +import java.io.IOException; +import java.nio.BufferUnderflowException; +import java.nio.ByteBuffer; +import java.util.Arrays; +import java.util.Objects; + +import static org.apache.kafka.common.record.RecordBatch.MAGIC_VALUE_V2; + +/** + * This class implements the inner record format for magic 2 and above. The schema is as follows: + * + * + * Record => + * Length => Varint + * Attributes => Int8 + * TimestampDelta => Varlong + * OffsetDelta => Varint + * Key => Bytes + * Value => Bytes + * Headers => [HeaderKey HeaderValue] + * HeaderKey => String + * HeaderValue => Bytes + * + * Note that in this schema, the Bytes and String types use a variable length integer to represent + * the length of the field. The array type used for the headers also uses a Varint for the number of + * headers. 
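Because every length in this schema is a varint, the overhead shrinks with the record: a tiny record needs only a handful of framing bytes, versus the 21-byte `MAX_RECORD_OVERHEAD` worst case. A small sketch using the static helpers defined further down in this class plus `ByteUtils.sizeOfVarint`:

```java
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;

import org.apache.kafka.common.record.DefaultRecord;
import org.apache.kafka.common.record.Record;
import org.apache.kafka.common.utils.ByteUtils;

public class VarintSizingExample {
    public static void main(String[] args) {
        ByteBuffer key = ByteBuffer.wrap("k".getBytes(StandardCharsets.UTF_8));
        ByteBuffer value = ByteBuffer.wrap("v".getBytes(StandardCharsets.UTF_8));

        // Total serialized size: length varint + attributes + deltas + key + value + header count.
        int size = DefaultRecord.sizeInBytes(0 /* offsetDelta */, 0L /* timestampDelta */,
                key, value, Record.EMPTY_HEADERS);
        System.out.println("tiny record occupies " + size + " bytes"
                + " (worst-case overhead is " + DefaultRecord.MAX_RECORD_OVERHEAD + ")");

        // Varints are zig-zag encoded: small magnitudes take one byte and grow with the value.
        System.out.println("sizeOfVarint(1)   = " + ByteUtils.sizeOfVarint(1));    // 1
        System.out.println("sizeOfVarint(300) = " + ByteUtils.sizeOfVarint(300));  // 2
        System.out.println("sizeOfVarint(-1)  = " + ByteUtils.sizeOfVarint(-1));   // 1 (null marker)
    }
}
```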
+ * + * The current record attributes are depicted below: + * + * ---------------- + * | Unused (0-7) | + * ---------------- + * + * The offset and timestamp deltas compute the difference relative to the base offset and + * base timestamp of the batch that this record is contained in. + */ +public class DefaultRecord implements Record { + + // excluding key, value and headers: 5 bytes length + 10 bytes timestamp + 5 bytes offset + 1 byte attributes + public static final int MAX_RECORD_OVERHEAD = 21; + + private static final int NULL_VARINT_SIZE_BYTES = ByteUtils.sizeOfVarint(-1); + + private final int sizeInBytes; + private final byte attributes; + private final long offset; + private final long timestamp; + private final int sequence; + private final ByteBuffer key; + private final ByteBuffer value; + private final Header[] headers; + + DefaultRecord(int sizeInBytes, + byte attributes, + long offset, + long timestamp, + int sequence, + ByteBuffer key, + ByteBuffer value, + Header[] headers) { + this.sizeInBytes = sizeInBytes; + this.attributes = attributes; + this.offset = offset; + this.timestamp = timestamp; + this.sequence = sequence; + this.key = key; + this.value = value; + this.headers = headers; + } + + @Override + public long offset() { + return offset; + } + + @Override + public int sequence() { + return sequence; + } + + @Override + public int sizeInBytes() { + return sizeInBytes; + } + + @Override + public long timestamp() { + return timestamp; + } + + public byte attributes() { + return attributes; + } + + @Override + public void ensureValid() {} + + @Override + public int keySize() { + return key == null ? -1 : key.remaining(); + } + + @Override + public int valueSize() { + return value == null ? -1 : value.remaining(); + } + + @Override + public boolean hasKey() { + return key != null; + } + + @Override + public ByteBuffer key() { + return key == null ? null : key.duplicate(); + } + + @Override + public boolean hasValue() { + return value != null; + } + + @Override + public ByteBuffer value() { + return value == null ? null : value.duplicate(); + } + + @Override + public Header[] headers() { + return headers; + } + + /** + * Write the record to `out` and return its size. 
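The `writeTo`/`readFrom` pair defined below round-trips a single record body relative to its batch's base offset and base timestamp. A minimal sketch (the offsets and timestamps are made-up illustration values):

```java
import java.io.ByteArrayOutputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;

import org.apache.kafka.common.record.DefaultRecord;
import org.apache.kafka.common.record.Record;
import org.apache.kafka.common.record.RecordBatch;

public class RecordRoundTrip {
    public static void main(String[] args) throws IOException {
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        DataOutputStream out = new DataOutputStream(bytes);

        // Serialize one record body: deltas are relative to the enclosing batch.
        DefaultRecord.writeTo(out, 3 /* offsetDelta */, 25L /* timestampDelta */,
                ByteBuffer.wrap("key".getBytes(StandardCharsets.UTF_8)),
                ByteBuffer.wrap("value".getBytes(StandardCharsets.UTF_8)),
                Record.EMPTY_HEADERS);
        out.flush();

        // Deserialize it back, supplying the batch-level base offset and base timestamp.
        ByteBuffer buffer = ByteBuffer.wrap(bytes.toByteArray());
        DefaultRecord record = DefaultRecord.readFrom(buffer,
                100L /* baseOffset */, 1_000L /* baseTimestamp */,
                RecordBatch.NO_SEQUENCE, null /* logAppendTime */);

        System.out.println(record.offset());     // 103  = baseOffset + offsetDelta
        System.out.println(record.timestamp());  // 1025 = baseTimestamp + timestampDelta
    }
}
```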
+ */ + public static int writeTo(DataOutputStream out, + int offsetDelta, + long timestampDelta, + ByteBuffer key, + ByteBuffer value, + Header[] headers) throws IOException { + int sizeInBytes = sizeOfBodyInBytes(offsetDelta, timestampDelta, key, value, headers); + ByteUtils.writeVarint(sizeInBytes, out); + + byte attributes = 0; // there are no used record attributes at the moment + out.write(attributes); + + ByteUtils.writeVarlong(timestampDelta, out); + ByteUtils.writeVarint(offsetDelta, out); + + if (key == null) { + ByteUtils.writeVarint(-1, out); + } else { + int keySize = key.remaining(); + ByteUtils.writeVarint(keySize, out); + Utils.writeTo(out, key, keySize); + } + + if (value == null) { + ByteUtils.writeVarint(-1, out); + } else { + int valueSize = value.remaining(); + ByteUtils.writeVarint(valueSize, out); + Utils.writeTo(out, value, valueSize); + } + + if (headers == null) + throw new IllegalArgumentException("Headers cannot be null"); + + ByteUtils.writeVarint(headers.length, out); + + for (Header header : headers) { + String headerKey = header.key(); + if (headerKey == null) + throw new IllegalArgumentException("Invalid null header key found in headers"); + + byte[] utf8Bytes = Utils.utf8(headerKey); + ByteUtils.writeVarint(utf8Bytes.length, out); + out.write(utf8Bytes); + + byte[] headerValue = header.value(); + if (headerValue == null) { + ByteUtils.writeVarint(-1, out); + } else { + ByteUtils.writeVarint(headerValue.length, out); + out.write(headerValue); + } + } + + return ByteUtils.sizeOfVarint(sizeInBytes) + sizeInBytes; + } + + @Override + public boolean hasMagic(byte magic) { + return magic >= MAGIC_VALUE_V2; + } + + @Override + public boolean isCompressed() { + return false; + } + + @Override + public boolean hasTimestampType(TimestampType timestampType) { + return false; + } + + @Override + public String toString() { + return String.format("DefaultRecord(offset=%d, timestamp=%d, key=%d bytes, value=%d bytes)", + offset, + timestamp, + key == null ? 0 : key.limit(), + value == null ? 0 : value.limit()); + } + + @Override + public boolean equals(Object o) { + if (this == o) + return true; + if (o == null || getClass() != o.getClass()) + return false; + + DefaultRecord that = (DefaultRecord) o; + return sizeInBytes == that.sizeInBytes && + attributes == that.attributes && + offset == that.offset && + timestamp == that.timestamp && + sequence == that.sequence && + Objects.equals(key, that.key) && + Objects.equals(value, that.value) && + Arrays.equals(headers, that.headers); + } + + @Override + public int hashCode() { + int result = sizeInBytes; + result = 31 * result + (int) attributes; + result = 31 * result + Long.hashCode(offset); + result = 31 * result + Long.hashCode(timestamp); + result = 31 * result + sequence; + result = 31 * result + (key != null ? key.hashCode() : 0); + result = 31 * result + (value != null ? 
value.hashCode() : 0); + result = 31 * result + Arrays.hashCode(headers); + return result; + } + + public static DefaultRecord readFrom(DataInput input, + long baseOffset, + long baseTimestamp, + int baseSequence, + Long logAppendTime) throws IOException { + int sizeOfBodyInBytes = ByteUtils.readVarint(input); + ByteBuffer recordBuffer = ByteBuffer.allocate(sizeOfBodyInBytes); + input.readFully(recordBuffer.array(), 0, sizeOfBodyInBytes); + int totalSizeInBytes = ByteUtils.sizeOfVarint(sizeOfBodyInBytes) + sizeOfBodyInBytes; + return readFrom(recordBuffer, totalSizeInBytes, sizeOfBodyInBytes, baseOffset, baseTimestamp, + baseSequence, logAppendTime); + } + + public static DefaultRecord readFrom(ByteBuffer buffer, + long baseOffset, + long baseTimestamp, + int baseSequence, + Long logAppendTime) { + int sizeOfBodyInBytes = ByteUtils.readVarint(buffer); + if (buffer.remaining() < sizeOfBodyInBytes) + throw new InvalidRecordException("Invalid record size: expected " + sizeOfBodyInBytes + + " bytes in record payload, but instead the buffer has only " + buffer.remaining() + + " remaining bytes."); + + int totalSizeInBytes = ByteUtils.sizeOfVarint(sizeOfBodyInBytes) + sizeOfBodyInBytes; + return readFrom(buffer, totalSizeInBytes, sizeOfBodyInBytes, baseOffset, baseTimestamp, + baseSequence, logAppendTime); + } + + private static DefaultRecord readFrom(ByteBuffer buffer, + int sizeInBytes, + int sizeOfBodyInBytes, + long baseOffset, + long baseTimestamp, + int baseSequence, + Long logAppendTime) { + try { + int recordStart = buffer.position(); + byte attributes = buffer.get(); + long timestampDelta = ByteUtils.readVarlong(buffer); + long timestamp = baseTimestamp + timestampDelta; + if (logAppendTime != null) + timestamp = logAppendTime; + + int offsetDelta = ByteUtils.readVarint(buffer); + long offset = baseOffset + offsetDelta; + int sequence = baseSequence >= 0 ? 
+ DefaultRecordBatch.incrementSequence(baseSequence, offsetDelta) : + RecordBatch.NO_SEQUENCE; + + ByteBuffer key = null; + int keySize = ByteUtils.readVarint(buffer); + if (keySize >= 0) { + key = buffer.slice(); + key.limit(keySize); + buffer.position(buffer.position() + keySize); + } + + ByteBuffer value = null; + int valueSize = ByteUtils.readVarint(buffer); + if (valueSize >= 0) { + value = buffer.slice(); + value.limit(valueSize); + buffer.position(buffer.position() + valueSize); + } + + int numHeaders = ByteUtils.readVarint(buffer); + if (numHeaders < 0) + throw new InvalidRecordException("Found invalid number of record headers " + numHeaders); + + final Header[] headers; + if (numHeaders == 0) + headers = Record.EMPTY_HEADERS; + else + headers = readHeaders(buffer, numHeaders); + + // validate whether we have read all header bytes in the current record + if (buffer.position() - recordStart != sizeOfBodyInBytes) + throw new InvalidRecordException("Invalid record size: expected to read " + sizeOfBodyInBytes + + " bytes in record payload, but instead read " + (buffer.position() - recordStart)); + + return new DefaultRecord(sizeInBytes, attributes, offset, timestamp, sequence, key, value, headers); + } catch (BufferUnderflowException | IllegalArgumentException e) { + throw new InvalidRecordException("Found invalid record structure", e); + } + } + + public static PartialDefaultRecord readPartiallyFrom(DataInput input, + byte[] skipArray, + long baseOffset, + long baseTimestamp, + int baseSequence, + Long logAppendTime) throws IOException { + int sizeOfBodyInBytes = ByteUtils.readVarint(input); + int totalSizeInBytes = ByteUtils.sizeOfVarint(sizeOfBodyInBytes) + sizeOfBodyInBytes; + + return readPartiallyFrom(input, skipArray, totalSizeInBytes, sizeOfBodyInBytes, baseOffset, baseTimestamp, + baseSequence, logAppendTime); + } + + private static PartialDefaultRecord readPartiallyFrom(DataInput input, + byte[] skipArray, + int sizeInBytes, + int sizeOfBodyInBytes, + long baseOffset, + long baseTimestamp, + int baseSequence, + Long logAppendTime) throws IOException { + ByteBuffer skipBuffer = ByteBuffer.wrap(skipArray); + // set its limit to 0 to indicate no bytes readable yet + skipBuffer.limit(0); + + try { + // reading the attributes / timestamp / offset and key-size does not require + // any byte array allocation and therefore we can just read them straight-forwardly + IntRef bytesRemaining = PrimitiveRef.ofInt(sizeOfBodyInBytes); + + byte attributes = readByte(skipBuffer, input, bytesRemaining); + long timestampDelta = readVarLong(skipBuffer, input, bytesRemaining); + long timestamp = baseTimestamp + timestampDelta; + if (logAppendTime != null) + timestamp = logAppendTime; + + int offsetDelta = readVarInt(skipBuffer, input, bytesRemaining); + long offset = baseOffset + offsetDelta; + int sequence = baseSequence >= 0 ? 
+ DefaultRecordBatch.incrementSequence(baseSequence, offsetDelta) : + RecordBatch.NO_SEQUENCE; + + // first skip key + int keySize = skipLengthDelimitedField(skipBuffer, input, bytesRemaining); + + // then skip value + int valueSize = skipLengthDelimitedField(skipBuffer, input, bytesRemaining); + + // then skip header + int numHeaders = readVarInt(skipBuffer, input, bytesRemaining); + if (numHeaders < 0) + throw new InvalidRecordException("Found invalid number of record headers " + numHeaders); + for (int i = 0; i < numHeaders; i++) { + int headerKeySize = skipLengthDelimitedField(skipBuffer, input, bytesRemaining); + if (headerKeySize < 0) + throw new InvalidRecordException("Invalid negative header key size " + headerKeySize); + + // headerValueSize + skipLengthDelimitedField(skipBuffer, input, bytesRemaining); + } + + if (bytesRemaining.value > 0 || skipBuffer.remaining() > 0) + throw new InvalidRecordException("Invalid record size: expected to read " + sizeOfBodyInBytes + + " bytes in record payload, but there are still bytes remaining"); + + return new PartialDefaultRecord(sizeInBytes, attributes, offset, timestamp, sequence, keySize, valueSize); + } catch (BufferUnderflowException | IllegalArgumentException e) { + throw new InvalidRecordException("Found invalid record structure", e); + } + } + + private static byte readByte(ByteBuffer buffer, DataInput input, IntRef bytesRemaining) throws IOException { + if (buffer.remaining() < 1 && bytesRemaining.value > 0) { + readMore(buffer, input, bytesRemaining); + } + + return buffer.get(); + } + + private static long readVarLong(ByteBuffer buffer, DataInput input, IntRef bytesRemaining) throws IOException { + if (buffer.remaining() < 10 && bytesRemaining.value > 0) { + readMore(buffer, input, bytesRemaining); + } + + return ByteUtils.readVarlong(buffer); + } + + private static int readVarInt(ByteBuffer buffer, DataInput input, IntRef bytesRemaining) throws IOException { + if (buffer.remaining() < 5 && bytesRemaining.value > 0) { + readMore(buffer, input, bytesRemaining); + } + + return ByteUtils.readVarint(buffer); + } + + private static int skipLengthDelimitedField(ByteBuffer buffer, DataInput input, IntRef bytesRemaining) throws IOException { + boolean needMore = false; + int sizeInBytes = -1; + int bytesToSkip = -1; + + while (true) { + if (needMore) { + readMore(buffer, input, bytesRemaining); + needMore = false; + } + + if (bytesToSkip < 0) { + if (buffer.remaining() < 5 && bytesRemaining.value > 0) { + needMore = true; + } else { + sizeInBytes = ByteUtils.readVarint(buffer); + if (sizeInBytes <= 0) + return sizeInBytes; + else + bytesToSkip = sizeInBytes; + + } + } else { + if (bytesToSkip > buffer.remaining()) { + bytesToSkip -= buffer.remaining(); + buffer.position(buffer.limit()); + needMore = true; + } else { + buffer.position(buffer.position() + bytesToSkip); + return sizeInBytes; + } + } + } + } + + private static void readMore(ByteBuffer buffer, DataInput input, IntRef bytesRemaining) throws IOException { + if (bytesRemaining.value > 0) { + byte[] array = buffer.array(); + + // first copy the remaining bytes to the beginning of the array; + // at most 4 bytes would be shifted here + int stepsToLeftShift = buffer.position(); + int bytesToLeftShift = buffer.remaining(); + for (int i = 0; i < bytesToLeftShift; i++) { + array[i] = array[i + stepsToLeftShift]; + } + + // then try to read more bytes to the remaining of the array + int bytesRead = Math.min(bytesRemaining.value, array.length - bytesToLeftShift); + input.readFully(array, 
bytesToLeftShift, bytesRead); + buffer.rewind(); + // only those many bytes are readable + buffer.limit(bytesToLeftShift + bytesRead); + + bytesRemaining.value -= bytesRead; + } else { + throw new InvalidRecordException("Invalid record size: expected to read more bytes in record payload"); + } + } + + private static Header[] readHeaders(ByteBuffer buffer, int numHeaders) { + Header[] headers = new Header[numHeaders]; + for (int i = 0; i < numHeaders; i++) { + int headerKeySize = ByteUtils.readVarint(buffer); + if (headerKeySize < 0) + throw new InvalidRecordException("Invalid negative header key size " + headerKeySize); + + ByteBuffer headerKeyBuffer = buffer.slice(); + headerKeyBuffer.limit(headerKeySize); + buffer.position(buffer.position() + headerKeySize); + + ByteBuffer headerValue = null; + int headerValueSize = ByteUtils.readVarint(buffer); + if (headerValueSize >= 0) { + headerValue = buffer.slice(); + headerValue.limit(headerValueSize); + buffer.position(buffer.position() + headerValueSize); + } + + headers[i] = new RecordHeader(headerKeyBuffer, headerValue); + } + + return headers; + } + + public static int sizeInBytes(int offsetDelta, + long timestampDelta, + ByteBuffer key, + ByteBuffer value, + Header[] headers) { + int bodySize = sizeOfBodyInBytes(offsetDelta, timestampDelta, key, value, headers); + return bodySize + ByteUtils.sizeOfVarint(bodySize); + } + + public static int sizeInBytes(int offsetDelta, + long timestampDelta, + int keySize, + int valueSize, + Header[] headers) { + int bodySize = sizeOfBodyInBytes(offsetDelta, timestampDelta, keySize, valueSize, headers); + return bodySize + ByteUtils.sizeOfVarint(bodySize); + } + + private static int sizeOfBodyInBytes(int offsetDelta, + long timestampDelta, + ByteBuffer key, + ByteBuffer value, + Header[] headers) { + int keySize = key == null ? -1 : key.remaining(); + int valueSize = value == null ? -1 : value.remaining(); + return sizeOfBodyInBytes(offsetDelta, timestampDelta, keySize, valueSize, headers); + } + + public static int sizeOfBodyInBytes(int offsetDelta, + long timestampDelta, + int keySize, + int valueSize, + Header[] headers) { + int size = 1; // always one byte for attributes + size += ByteUtils.sizeOfVarint(offsetDelta); + size += ByteUtils.sizeOfVarlong(timestampDelta); + size += sizeOf(keySize, valueSize, headers); + return size; + } + + private static int sizeOf(int keySize, int valueSize, Header[] headers) { + int size = 0; + if (keySize < 0) + size += NULL_VARINT_SIZE_BYTES; + else + size += ByteUtils.sizeOfVarint(keySize) + keySize; + + if (valueSize < 0) + size += NULL_VARINT_SIZE_BYTES; + else + size += ByteUtils.sizeOfVarint(valueSize) + valueSize; + + if (headers == null) + throw new IllegalArgumentException("Headers cannot be null"); + + size += ByteUtils.sizeOfVarint(headers.length); + for (Header header : headers) { + String headerKey = header.key(); + if (headerKey == null) + throw new IllegalArgumentException("Invalid null header key found in headers"); + + int headerKeySize = Utils.utf8Length(headerKey); + size += ByteUtils.sizeOfVarint(headerKeySize) + headerKeySize; + + byte[] headerValue = header.value(); + if (headerValue == null) { + size += NULL_VARINT_SIZE_BYTES; + } else { + size += ByteUtils.sizeOfVarint(headerValue.length) + headerValue.length; + } + } + return size; + } + + static int recordSizeUpperBound(ByteBuffer key, ByteBuffer value, Header[] headers) { + int keySize = key == null ? -1 : key.remaining(); + int valueSize = value == null ? 
-1 : value.remaining(); + return MAX_RECORD_OVERHEAD + sizeOf(keySize, valueSize, headers); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/record/DefaultRecordBatch.java b/clients/src/main/java/org/apache/kafka/common/record/DefaultRecordBatch.java new file mode 100644 index 0000000..ec3c720 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/record/DefaultRecordBatch.java @@ -0,0 +1,749 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.record; + +import org.apache.kafka.common.InvalidRecordException; +import org.apache.kafka.common.KafkaException; +import org.apache.kafka.common.errors.CorruptRecordException; +import org.apache.kafka.common.header.Header; +import org.apache.kafka.common.utils.BufferSupplier; +import org.apache.kafka.common.utils.ByteBufferOutputStream; +import org.apache.kafka.common.utils.ByteUtils; +import org.apache.kafka.common.utils.CloseableIterator; +import org.apache.kafka.common.utils.Crc32C; + +import java.io.DataInputStream; +import java.io.EOFException; +import java.io.IOException; +import java.nio.BufferUnderflowException; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Collections; +import java.util.Iterator; +import java.util.List; +import java.util.NoSuchElementException; +import java.util.Objects; +import java.util.OptionalLong; + +import static org.apache.kafka.common.record.Records.LOG_OVERHEAD; + +/** + * RecordBatch implementation for magic 2 and above. The schema is given below: + * + * RecordBatch => + * BaseOffset => Int64 + * Length => Int32 + * PartitionLeaderEpoch => Int32 + * Magic => Int8 + * CRC => Uint32 + * Attributes => Int16 + * LastOffsetDelta => Int32 // also serves as LastSequenceDelta + * FirstTimestamp => Int64 + * MaxTimestamp => Int64 + * ProducerId => Int64 + * ProducerEpoch => Int16 + * BaseSequence => Int32 + * Records => [Record] + * + * Note that when compression is enabled (see attributes below), the compressed record data is serialized + * directly following the count of the number of records. + * + * The CRC covers the data from the attributes to the end of the batch (i.e. all the bytes that follow the CRC). It is + * located after the magic byte, which means that clients must parse the magic byte before deciding how to interpret + * the bytes between the batch length and the magic byte. The partition leader epoch field is not included in the CRC + * computation to avoid the need to recompute the CRC when this field is assigned for every batch that is received by + * the broker. The CRC-32C (Castagnoli) polynomial is used for the computation. 
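A hedged sketch of verifying the CRC described above by hand: the numeric offsets (stored CRC at byte 17, attributes at byte 21 from the start of the batch) follow from the field layout in this javadoc, while `MemoryRecords.withRecords`, `Crc32C`, and `ByteUtils.readUnsignedInt` are assumed from the surrounding packages rather than this diff:

```java
import java.nio.ByteBuffer;

import org.apache.kafka.common.record.CompressionType;
import org.apache.kafka.common.record.MemoryRecords;
import org.apache.kafka.common.record.SimpleRecord;
import org.apache.kafka.common.utils.ByteUtils;
import org.apache.kafka.common.utils.Crc32C;

public class BatchCrcCheck {
    public static void main(String[] args) {
        MemoryRecords records = MemoryRecords.withRecords(CompressionType.NONE,
                new SimpleRecord("k".getBytes(), "v".getBytes()));
        ByteBuffer buffer = records.buffer();  // a single v2 batch starting at position 0

        // Stored CRC: an unsigned 32-bit integer right after the magic byte (offset 17).
        long stored = ByteUtils.readUnsignedInt(buffer, 17);

        // Recompute CRC-32C over everything from the attributes (offset 21) to the end of the batch;
        // the partition leader epoch and earlier fields are deliberately excluded.
        long computed = Crc32C.compute(buffer, 21, buffer.limit() - 21);

        System.out.println(stored == computed);  // true unless the batch is corrupt
    }
}
```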
+ * + * On Compaction: Unlike the older message formats, magic v2 and above preserves the first and last offset/sequence + * numbers from the original batch when the log is cleaned. This is required in order to be able to restore the + * producer's state when the log is reloaded. If we did not retain the last sequence number, then following + * a partition leader failure, once the new leader has rebuilt the producer state from the log, the next sequence + * expected number would no longer be in sync with what was written by the client. This would cause an + * unexpected OutOfOrderSequence error, which is typically fatal. The base sequence number must be preserved for + * duplicate checking: the broker checks incoming Produce requests for duplicates by verifying that the first and + * last sequence numbers of the incoming batch match the last from that producer. + * + * Note that if all of the records in a batch are removed during compaction, the broker may still retain an empty + * batch header in order to preserve the producer sequence information as described above. These empty batches + * are retained only until either a new sequence number is written by the corresponding producer or the producerId + * is expired from lack of activity. + * + * There is no similar need to preserve the timestamp from the original batch after compaction. The FirstTimestamp + * field therefore always reflects the timestamp of the first record in the batch. If the batch is empty, the + * FirstTimestamp will be set to -1 (NO_TIMESTAMP). + * + * Similarly, the MaxTimestamp field reflects the maximum timestamp of the current records if the timestamp type + * is CREATE_TIME. For LOG_APPEND_TIME, on the other hand, the MaxTimestamp field reflects the timestamp set + * by the broker and is preserved after compaction. Additionally, the MaxTimestamp of an empty batch always retains + * the previous value prior to becoming empty. + * + * The delete horizon flag for the sixth bit is used to determine if the first timestamp of the batch had been set to + * the time for which tombstones / transaction markers need to be removed. If it is true, then the first timestamp is + * the delete horizon, otherwise, it is merely the first timestamp of the record batch. 
+ * + * The current attributes are given below: + * + * --------------------------------------------------------------------------------------------------------------------------- + * | Unused (7-15) | Delete Horizon Flag (6) | Control (5) | Transactional (4) | Timestamp Type (3) | Compression Type (0-2) | + * --------------------------------------------------------------------------------------------------------------------------- + */ +public class DefaultRecordBatch extends AbstractRecordBatch implements MutableRecordBatch { + static final int BASE_OFFSET_OFFSET = 0; + static final int BASE_OFFSET_LENGTH = 8; + static final int LENGTH_OFFSET = BASE_OFFSET_OFFSET + BASE_OFFSET_LENGTH; + static final int LENGTH_LENGTH = 4; + static final int PARTITION_LEADER_EPOCH_OFFSET = LENGTH_OFFSET + LENGTH_LENGTH; + static final int PARTITION_LEADER_EPOCH_LENGTH = 4; + static final int MAGIC_OFFSET = PARTITION_LEADER_EPOCH_OFFSET + PARTITION_LEADER_EPOCH_LENGTH; + static final int MAGIC_LENGTH = 1; + static final int CRC_OFFSET = MAGIC_OFFSET + MAGIC_LENGTH; + static final int CRC_LENGTH = 4; + static final int ATTRIBUTES_OFFSET = CRC_OFFSET + CRC_LENGTH; + static final int ATTRIBUTE_LENGTH = 2; + public static final int LAST_OFFSET_DELTA_OFFSET = ATTRIBUTES_OFFSET + ATTRIBUTE_LENGTH; + static final int LAST_OFFSET_DELTA_LENGTH = 4; + static final int BASE_TIMESTAMP_OFFSET = LAST_OFFSET_DELTA_OFFSET + LAST_OFFSET_DELTA_LENGTH; + static final int BASE_TIMESTAMP_LENGTH = 8; + static final int MAX_TIMESTAMP_OFFSET = BASE_TIMESTAMP_OFFSET + BASE_TIMESTAMP_LENGTH; + static final int MAX_TIMESTAMP_LENGTH = 8; + static final int PRODUCER_ID_OFFSET = MAX_TIMESTAMP_OFFSET + MAX_TIMESTAMP_LENGTH; + static final int PRODUCER_ID_LENGTH = 8; + static final int PRODUCER_EPOCH_OFFSET = PRODUCER_ID_OFFSET + PRODUCER_ID_LENGTH; + static final int PRODUCER_EPOCH_LENGTH = 2; + static final int BASE_SEQUENCE_OFFSET = PRODUCER_EPOCH_OFFSET + PRODUCER_EPOCH_LENGTH; + static final int BASE_SEQUENCE_LENGTH = 4; + public static final int RECORDS_COUNT_OFFSET = BASE_SEQUENCE_OFFSET + BASE_SEQUENCE_LENGTH; + static final int RECORDS_COUNT_LENGTH = 4; + static final int RECORDS_OFFSET = RECORDS_COUNT_OFFSET + RECORDS_COUNT_LENGTH; + public static final int RECORD_BATCH_OVERHEAD = RECORDS_OFFSET; + + private static final byte COMPRESSION_CODEC_MASK = 0x07; + private static final byte TRANSACTIONAL_FLAG_MASK = 0x10; + private static final int CONTROL_FLAG_MASK = 0x20; + private static final byte DELETE_HORIZON_FLAG_MASK = 0x40; + private static final byte TIMESTAMP_TYPE_MASK = 0x08; + + private static final int MAX_SKIP_BUFFER_SIZE = 2048; + + private final ByteBuffer buffer; + + DefaultRecordBatch(ByteBuffer buffer) { + this.buffer = buffer; + } + + @Override + public byte magic() { + return buffer.get(MAGIC_OFFSET); + } + + @Override + public void ensureValid() { + if (sizeInBytes() < RECORD_BATCH_OVERHEAD) + throw new CorruptRecordException("Record batch is corrupt (the size " + sizeInBytes() + + " is smaller than the minimum allowed overhead " + RECORD_BATCH_OVERHEAD + ")"); + + if (!isValid()) + throw new CorruptRecordException("Record is corrupt (stored crc = " + checksum() + + ", computed crc = " + computeChecksum() + ")"); + } + + /** + * Gets the base timestamp of the batch which is used to calculate the record timestamps from the deltas. 
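A small decoding sketch for the attributes word laid out in the table above; the mask values mirror the constants defined in this class, and only the low seven bits are currently meaningful:

final class AttributesSketch {

    static void describe(short attributes) {
        int compressionId     = attributes & 0x07;          // bits 0-2: compression codec id
        boolean logAppendTime = (attributes & 0x08) != 0;   // bit 3: timestamp type
        boolean transactional = (attributes & 0x10) != 0;   // bit 4
        boolean control       = (attributes & 0x20) != 0;   // bit 5
        boolean deleteHorizon = (attributes & 0x40) != 0;   // bit 6
        System.out.println("compression=" + compressionId
                + " timestampType=" + (logAppendTime ? "LOG_APPEND_TIME" : "CREATE_TIME")
                + " transactional=" + transactional
                + " control=" + control
                + " deleteHorizonSet=" + deleteHorizon);
    }
}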
+ * + * @return The base timestamp + */ + public long baseTimestamp() { + return buffer.getLong(BASE_TIMESTAMP_OFFSET); + } + + @Override + public long maxTimestamp() { + return buffer.getLong(MAX_TIMESTAMP_OFFSET); + } + + @Override + public TimestampType timestampType() { + return (attributes() & TIMESTAMP_TYPE_MASK) == 0 ? TimestampType.CREATE_TIME : TimestampType.LOG_APPEND_TIME; + } + + @Override + public long baseOffset() { + return buffer.getLong(BASE_OFFSET_OFFSET); + } + + @Override + public long lastOffset() { + return baseOffset() + lastOffsetDelta(); + } + + @Override + public long producerId() { + return buffer.getLong(PRODUCER_ID_OFFSET); + } + + @Override + public short producerEpoch() { + return buffer.getShort(PRODUCER_EPOCH_OFFSET); + } + + @Override + public int baseSequence() { + return buffer.getInt(BASE_SEQUENCE_OFFSET); + } + + private int lastOffsetDelta() { + return buffer.getInt(LAST_OFFSET_DELTA_OFFSET); + } + + @Override + public int lastSequence() { + int baseSequence = baseSequence(); + if (baseSequence == RecordBatch.NO_SEQUENCE) + return RecordBatch.NO_SEQUENCE; + return incrementSequence(baseSequence, lastOffsetDelta()); + } + + @Override + public CompressionType compressionType() { + return CompressionType.forId(attributes() & COMPRESSION_CODEC_MASK); + } + + @Override + public int sizeInBytes() { + return LOG_OVERHEAD + buffer.getInt(LENGTH_OFFSET); + } + + private int count() { + return buffer.getInt(RECORDS_COUNT_OFFSET); + } + + @Override + public Integer countOrNull() { + return count(); + } + + @Override + public void writeTo(ByteBuffer buffer) { + buffer.put(this.buffer.duplicate()); + } + + @Override + public void writeTo(ByteBufferOutputStream outputStream) { + outputStream.write(this.buffer.duplicate()); + } + + @Override + public boolean isTransactional() { + return (attributes() & TRANSACTIONAL_FLAG_MASK) > 0; + } + + private boolean hasDeleteHorizonMs() { + return (attributes() & DELETE_HORIZON_FLAG_MASK) > 0; + } + + @Override + public OptionalLong deleteHorizonMs() { + if (hasDeleteHorizonMs()) + return OptionalLong.of(buffer.getLong(BASE_TIMESTAMP_OFFSET)); + else + return OptionalLong.empty(); + } + + @Override + public boolean isControlBatch() { + return (attributes() & CONTROL_FLAG_MASK) > 0; + } + + @Override + public int partitionLeaderEpoch() { + return buffer.getInt(PARTITION_LEADER_EPOCH_OFFSET); + } + + public DataInputStream recordInputStream(BufferSupplier bufferSupplier) { + final ByteBuffer buffer = this.buffer.duplicate(); + buffer.position(RECORDS_OFFSET); + return new DataInputStream(compressionType().wrapForInput(buffer, magic(), bufferSupplier)); + } + + private CloseableIterator compressedIterator(BufferSupplier bufferSupplier, boolean skipKeyValue) { + final DataInputStream inputStream = recordInputStream(bufferSupplier); + + if (skipKeyValue) { + // this buffer is used to skip length delimited fields like key, value, headers + byte[] skipArray = new byte[MAX_SKIP_BUFFER_SIZE]; + + return new StreamRecordIterator(inputStream) { + @Override + protected Record doReadRecord(long baseOffset, long firstTimestamp, int baseSequence, Long logAppendTime) throws IOException { + return DefaultRecord.readPartiallyFrom(inputStream, skipArray, baseOffset, firstTimestamp, baseSequence, logAppendTime); + } + }; + } else { + return new StreamRecordIterator(inputStream) { + @Override + protected Record doReadRecord(long baseOffset, long firstTimestamp, int baseSequence, Long logAppendTime) throws IOException { + return 
DefaultRecord.readFrom(inputStream, baseOffset, firstTimestamp, baseSequence, logAppendTime); + } + }; + } + } + + private CloseableIterator uncompressedIterator() { + final ByteBuffer buffer = this.buffer.duplicate(); + buffer.position(RECORDS_OFFSET); + return new RecordIterator() { + @Override + protected Record readNext(long baseOffset, long baseTimestamp, int baseSequence, Long logAppendTime) { + try { + return DefaultRecord.readFrom(buffer, baseOffset, baseTimestamp, baseSequence, logAppendTime); + } catch (BufferUnderflowException e) { + throw new InvalidRecordException("Incorrect declared batch size, premature EOF reached"); + } + } + @Override + protected boolean ensureNoneRemaining() { + return !buffer.hasRemaining(); + } + @Override + public void close() {} + }; + } + + @Override + public Iterator iterator() { + if (count() == 0) + return Collections.emptyIterator(); + + if (!isCompressed()) + return uncompressedIterator(); + + // for a normal iterator, we cannot ensure that the underlying compression stream is closed, + // so we decompress the full record set here. Use cases which call for a lower memory footprint + // can use `streamingIterator` at the cost of additional complexity + try (CloseableIterator iterator = compressedIterator(BufferSupplier.NO_CACHING, false)) { + List records = new ArrayList<>(count()); + while (iterator.hasNext()) + records.add(iterator.next()); + return records.iterator(); + } + } + + @Override + public CloseableIterator skipKeyValueIterator(BufferSupplier bufferSupplier) { + if (count() == 0) { + return CloseableIterator.wrap(Collections.emptyIterator()); + } + + /* + * For uncompressed iterator, it is actually not worth skipping key / value / headers at all since + * its ByteBufferInputStream's skip() function is less efficient compared with just reading it actually + * as it will allocate new byte array. + */ + if (!isCompressed()) + return uncompressedIterator(); + + // we define this to be a closable iterator so that caller (i.e. the log validator) needs to close it + // while we can save memory footprint of not decompressing the full record set ahead of time + return compressedIterator(bufferSupplier, true); + } + + @Override + public CloseableIterator streamingIterator(BufferSupplier bufferSupplier) { + if (isCompressed()) + return compressedIterator(bufferSupplier, false); + else + return uncompressedIterator(); + } + + @Override + public void setLastOffset(long offset) { + buffer.putLong(BASE_OFFSET_OFFSET, offset - lastOffsetDelta()); + } + + @Override + public void setMaxTimestamp(TimestampType timestampType, long maxTimestamp) { + long currentMaxTimestamp = maxTimestamp(); + // We don't need to recompute crc if the timestamp is not updated. 
+ if (timestampType() == timestampType && currentMaxTimestamp == maxTimestamp) + return; + + byte attributes = computeAttributes(compressionType(), timestampType, isTransactional(), isControlBatch(), hasDeleteHorizonMs()); + buffer.putShort(ATTRIBUTES_OFFSET, attributes); + buffer.putLong(MAX_TIMESTAMP_OFFSET, maxTimestamp); + long crc = computeChecksum(); + ByteUtils.writeUnsignedInt(buffer, CRC_OFFSET, crc); + } + + @Override + public void setPartitionLeaderEpoch(int epoch) { + buffer.putInt(PARTITION_LEADER_EPOCH_OFFSET, epoch); + } + + @Override + public long checksum() { + return ByteUtils.readUnsignedInt(buffer, CRC_OFFSET); + } + + public boolean isValid() { + return sizeInBytes() >= RECORD_BATCH_OVERHEAD && checksum() == computeChecksum(); + } + + private long computeChecksum() { + return Crc32C.compute(buffer, ATTRIBUTES_OFFSET, buffer.limit() - ATTRIBUTES_OFFSET); + } + + private byte attributes() { + // note we're not using the second byte of attributes + return (byte) buffer.getShort(ATTRIBUTES_OFFSET); + } + + @Override + public boolean equals(Object o) { + if (this == o) + return true; + if (o == null || getClass() != o.getClass()) + return false; + + DefaultRecordBatch that = (DefaultRecordBatch) o; + return Objects.equals(buffer, that.buffer); + } + + @Override + public int hashCode() { + return buffer != null ? buffer.hashCode() : 0; + } + + private static byte computeAttributes(CompressionType type, TimestampType timestampType, + boolean isTransactional, boolean isControl, boolean isDeleteHorizonSet) { + if (timestampType == TimestampType.NO_TIMESTAMP_TYPE) + throw new IllegalArgumentException("Timestamp type must be provided to compute attributes for message " + + "format v2 and above"); + + byte attributes = isTransactional ? TRANSACTIONAL_FLAG_MASK : 0; + if (isControl) + attributes |= CONTROL_FLAG_MASK; + if (type.id > 0) + attributes |= COMPRESSION_CODEC_MASK & type.id; + if (timestampType == TimestampType.LOG_APPEND_TIME) + attributes |= TIMESTAMP_TYPE_MASK; + if (isDeleteHorizonSet) + attributes |= DELETE_HORIZON_FLAG_MASK; + return attributes; + } + + public static void writeEmptyHeader(ByteBuffer buffer, + byte magic, + long producerId, + short producerEpoch, + int baseSequence, + long baseOffset, + long lastOffset, + int partitionLeaderEpoch, + TimestampType timestampType, + long timestamp, + boolean isTransactional, + boolean isControlRecord) { + int offsetDelta = (int) (lastOffset - baseOffset); + writeHeader(buffer, baseOffset, offsetDelta, DefaultRecordBatch.RECORD_BATCH_OVERHEAD, magic, + CompressionType.NONE, timestampType, RecordBatch.NO_TIMESTAMP, timestamp, producerId, + producerEpoch, baseSequence, isTransactional, isControlRecord, false, partitionLeaderEpoch, 0); + } + + public static void writeHeader(ByteBuffer buffer, + long baseOffset, + int lastOffsetDelta, + int sizeInBytes, + byte magic, + CompressionType compressionType, + TimestampType timestampType, + long firstTimestamp, + long maxTimestamp, + long producerId, + short epoch, + int sequence, + boolean isTransactional, + boolean isControlBatch, + boolean isDeleteHorizonSet, + int partitionLeaderEpoch, + int numRecords) { + if (magic < RecordBatch.CURRENT_MAGIC_VALUE) + throw new IllegalArgumentException("Invalid magic value " + magic); + if (firstTimestamp < 0 && firstTimestamp != NO_TIMESTAMP) + throw new IllegalArgumentException("Invalid message timestamp " + firstTimestamp); + + short attributes = computeAttributes(compressionType, timestampType, isTransactional, isControlBatch, 
isDeleteHorizonSet); + + int position = buffer.position(); + buffer.putLong(position + BASE_OFFSET_OFFSET, baseOffset); + buffer.putInt(position + LENGTH_OFFSET, sizeInBytes - LOG_OVERHEAD); + buffer.putInt(position + PARTITION_LEADER_EPOCH_OFFSET, partitionLeaderEpoch); + buffer.put(position + MAGIC_OFFSET, magic); + buffer.putShort(position + ATTRIBUTES_OFFSET, attributes); + buffer.putLong(position + BASE_TIMESTAMP_OFFSET, firstTimestamp); + buffer.putLong(position + MAX_TIMESTAMP_OFFSET, maxTimestamp); + buffer.putInt(position + LAST_OFFSET_DELTA_OFFSET, lastOffsetDelta); + buffer.putLong(position + PRODUCER_ID_OFFSET, producerId); + buffer.putShort(position + PRODUCER_EPOCH_OFFSET, epoch); + buffer.putInt(position + BASE_SEQUENCE_OFFSET, sequence); + buffer.putInt(position + RECORDS_COUNT_OFFSET, numRecords); + long crc = Crc32C.compute(buffer, ATTRIBUTES_OFFSET, sizeInBytes - ATTRIBUTES_OFFSET); + buffer.putInt(position + CRC_OFFSET, (int) crc); + buffer.position(position + RECORD_BATCH_OVERHEAD); + } + + @Override + public String toString() { + return "RecordBatch(magic=" + magic() + ", offsets=[" + baseOffset() + ", " + lastOffset() + "], " + + "sequence=[" + baseSequence() + ", " + lastSequence() + "], " + + "isTransactional=" + isTransactional() + ", isControlBatch=" + isControlBatch() + ", " + + "compression=" + compressionType() + ", timestampType=" + timestampType() + ", crc=" + checksum() + ")"; + } + + public static int sizeInBytes(long baseOffset, Iterable records) { + Iterator iterator = records.iterator(); + if (!iterator.hasNext()) + return 0; + + int size = RECORD_BATCH_OVERHEAD; + Long firstTimestamp = null; + while (iterator.hasNext()) { + Record record = iterator.next(); + int offsetDelta = (int) (record.offset() - baseOffset); + if (firstTimestamp == null) + firstTimestamp = record.timestamp(); + long timestampDelta = record.timestamp() - firstTimestamp; + size += DefaultRecord.sizeInBytes(offsetDelta, timestampDelta, record.key(), record.value(), + record.headers()); + } + return size; + } + + public static int sizeInBytes(Iterable records) { + Iterator iterator = records.iterator(); + if (!iterator.hasNext()) + return 0; + + int size = RECORD_BATCH_OVERHEAD; + int offsetDelta = 0; + Long firstTimestamp = null; + while (iterator.hasNext()) { + SimpleRecord record = iterator.next(); + if (firstTimestamp == null) + firstTimestamp = record.timestamp(); + long timestampDelta = record.timestamp() - firstTimestamp; + size += DefaultRecord.sizeInBytes(offsetDelta++, timestampDelta, record.key(), record.value(), + record.headers()); + } + return size; + } + + /** + * Get an upper bound on the size of a batch with only a single record using a given key and value. This + * is only an estimate because it does not take into account additional overhead from the compression + * algorithm used. 
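The two static sizeInBytes overloads above and the upper-bound estimate below all start from the 61-byte batch overhead (RECORD_BATCH_OVERHEAD) and add each record's varint-encoded size. A usage sketch for sizing a buffer up front, assuming the usual SimpleRecord (timestamp, key, value) constructor:

import org.apache.kafka.common.record.DefaultRecordBatch;
import org.apache.kafka.common.record.SimpleRecord;

import java.nio.ByteBuffer;
import java.util.Arrays;
import java.util.List;

final class BatchSizingSketch {
    public static void main(String[] args) {
        long now = System.currentTimeMillis();
        List<SimpleRecord> records = Arrays.asList(
                new SimpleRecord(now, "k1".getBytes(), "v1".getBytes()),
                new SimpleRecord(now + 1, "k2".getBytes(), "v2".getBytes()));

        // Uncompressed estimate: 61 bytes of batch overhead plus each record's encoded size.
        int sizeEstimate = DefaultRecordBatch.sizeInBytes(records);
        ByteBuffer buffer = ByteBuffer.allocate(sizeEstimate);
        System.out.println("allocated " + buffer.capacity() + " bytes for " + records.size() + " records");
    }
}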
+ */ + static int estimateBatchSizeUpperBound(ByteBuffer key, ByteBuffer value, Header[] headers) { + return RECORD_BATCH_OVERHEAD + DefaultRecord.recordSizeUpperBound(key, value, headers); + } + + public static int incrementSequence(int sequence, int increment) { + if (sequence > Integer.MAX_VALUE - increment) + return increment - (Integer.MAX_VALUE - sequence) - 1; + return sequence + increment; + } + + public static int decrementSequence(int sequence, int decrement) { + if (sequence < decrement) + return Integer.MAX_VALUE - (decrement - sequence) + 1; + return sequence - decrement; + } + + private abstract class RecordIterator implements CloseableIterator { + private final Long logAppendTime; + private final long baseOffset; + private final long baseTimestamp; + private final int baseSequence; + private final int numRecords; + private int readRecords = 0; + + RecordIterator() { + this.logAppendTime = timestampType() == TimestampType.LOG_APPEND_TIME ? maxTimestamp() : null; + this.baseOffset = baseOffset(); + this.baseTimestamp = baseTimestamp(); + this.baseSequence = baseSequence(); + int numRecords = count(); + if (numRecords < 0) + throw new InvalidRecordException("Found invalid record count " + numRecords + " in magic v" + + magic() + " batch"); + this.numRecords = numRecords; + } + + @Override + public boolean hasNext() { + return readRecords < numRecords; + } + + @Override + public Record next() { + if (readRecords >= numRecords) + throw new NoSuchElementException(); + + readRecords++; + Record rec = readNext(baseOffset, baseTimestamp, baseSequence, logAppendTime); + if (readRecords == numRecords) { + // Validate that the actual size of the batch is equal to declared size + // by checking that after reading declared number of items, there no items left + // (overflow case, i.e. reading past buffer end is checked elsewhere). 
+ if (!ensureNoneRemaining()) + throw new InvalidRecordException("Incorrect declared batch size, records still remaining in file"); + } + return rec; + } + + protected abstract Record readNext(long baseOffset, long baseTimestamp, int baseSequence, Long logAppendTime); + + protected abstract boolean ensureNoneRemaining(); + + @Override + public void remove() { + throw new UnsupportedOperationException(); + } + + } + + private abstract class StreamRecordIterator extends RecordIterator { + private final DataInputStream inputStream; + + StreamRecordIterator(DataInputStream inputStream) { + super(); + this.inputStream = inputStream; + } + + abstract Record doReadRecord(long baseOffset, long firstTimestamp, int baseSequence, Long logAppendTime) throws IOException; + + @Override + protected Record readNext(long baseOffset, long baseTimestamp, int baseSequence, Long logAppendTime) { + try { + return doReadRecord(baseOffset, baseTimestamp, baseSequence, logAppendTime); + } catch (EOFException e) { + throw new InvalidRecordException("Incorrect declared batch size, premature EOF reached"); + } catch (IOException e) { + throw new KafkaException("Failed to decompress record stream", e); + } + } + + @Override + protected boolean ensureNoneRemaining() { + try { + return inputStream.read() == -1; + } catch (IOException e) { + throw new KafkaException("Error checking for remaining bytes after reading batch", e); + } + } + + @Override + public void close() { + try { + inputStream.close(); + } catch (IOException e) { + throw new KafkaException("Failed to close record stream", e); + } + } + } + + static class DefaultFileChannelRecordBatch extends FileLogInputStream.FileChannelRecordBatch { + + DefaultFileChannelRecordBatch(long offset, + byte magic, + FileRecords fileRecords, + int position, + int batchSize) { + super(offset, magic, fileRecords, position, batchSize); + } + + @Override + protected RecordBatch toMemoryRecordBatch(ByteBuffer buffer) { + return new DefaultRecordBatch(buffer); + } + + @Override + public long baseOffset() { + return offset; + } + + @Override + public long lastOffset() { + return loadBatchHeader().lastOffset(); + } + + @Override + public long producerId() { + return loadBatchHeader().producerId(); + } + + @Override + public short producerEpoch() { + return loadBatchHeader().producerEpoch(); + } + + @Override + public int baseSequence() { + return loadBatchHeader().baseSequence(); + } + + @Override + public int lastSequence() { + return loadBatchHeader().lastSequence(); + } + + @Override + public long checksum() { + return loadBatchHeader().checksum(); + } + + @Override + public Integer countOrNull() { + return loadBatchHeader().countOrNull(); + } + + @Override + public boolean isTransactional() { + return loadBatchHeader().isTransactional(); + } + + @Override + public OptionalLong deleteHorizonMs() { + return loadBatchHeader().deleteHorizonMs(); + } + + @Override + public boolean isControlBatch() { + return loadBatchHeader().isControlBatch(); + } + + @Override + public int partitionLeaderEpoch() { + return loadBatchHeader().partitionLeaderEpoch(); + } + + @Override + protected int headerSize() { + return RECORD_BATCH_OVERHEAD; + } + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/record/DefaultRecordsSend.java b/clients/src/main/java/org/apache/kafka/common/record/DefaultRecordsSend.java new file mode 100644 index 0000000..bbb17d4 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/record/DefaultRecordsSend.java @@ -0,0 +1,36 @@ +/* + * Licensed to 
the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.record; + +import org.apache.kafka.common.network.TransferableChannel; + +import java.io.IOException; + +public class DefaultRecordsSend extends RecordsSend { + public DefaultRecordsSend(T records) { + this(records, records.sizeInBytes()); + } + + public DefaultRecordsSend(T records, int maxBytesToWrite) { + super(records, maxBytesToWrite); + } + + @Override + protected long writeTo(TransferableChannel channel, long previouslyWritten, int remaining) throws IOException { + return records().writeTo(channel, previouslyWritten, remaining); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/record/EndTransactionMarker.java b/clients/src/main/java/org/apache/kafka/common/record/EndTransactionMarker.java new file mode 100644 index 0000000..4bf1ebf --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/record/EndTransactionMarker.java @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.record; + +import org.apache.kafka.common.InvalidRecordException; +import org.apache.kafka.common.protocol.types.Field; +import org.apache.kafka.common.protocol.types.Schema; +import org.apache.kafka.common.protocol.types.Struct; +import org.apache.kafka.common.protocol.types.Type; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.nio.ByteBuffer; + +/** + * This class represents the control record which is written to the log to indicate the completion + * of a transaction. The record key specifies the {@link ControlRecordType control type} and the + * value embeds information useful for write validation (for now, just the coordinator epoch). 
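The marker value on the wire is a fixed 6-byte structure: a 16-bit version followed by the 32-bit coordinator epoch. A minimal sketch of that layout, independent of the Struct/Schema machinery the class itself uses:

import java.nio.ByteBuffer;

// Illustrative only: the 6-byte end-transaction marker value (int16 version, int32 coordinator epoch).
final class EndTxnMarkerValueSketch {

    static ByteBuffer serialize(short version, int coordinatorEpoch) {
        ByteBuffer value = ByteBuffer.allocate(6);
        value.putShort(version);
        value.putInt(coordinatorEpoch);
        value.flip();
        return value;
    }

    static int coordinatorEpoch(ByteBuffer value) {
        return value.getInt(2); // the epoch follows the 2-byte version
    }
}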
+ */ +public class EndTransactionMarker { + private static final Logger log = LoggerFactory.getLogger(EndTransactionMarker.class); + + private static final short CURRENT_END_TXN_MARKER_VERSION = 0; + private static final Schema END_TXN_MARKER_SCHEMA_VERSION_V0 = new Schema( + new Field("version", Type.INT16), + new Field("coordinator_epoch", Type.INT32)); + static final int CURRENT_END_TXN_MARKER_VALUE_SIZE = 6; + static final int CURRENT_END_TXN_SCHEMA_RECORD_SIZE = DefaultRecord.sizeInBytes(0, 0L, + ControlRecordType.CURRENT_CONTROL_RECORD_KEY_SIZE, + EndTransactionMarker.CURRENT_END_TXN_MARKER_VALUE_SIZE, + Record.EMPTY_HEADERS); + + private final ControlRecordType type; + private final int coordinatorEpoch; + + public EndTransactionMarker(ControlRecordType type, int coordinatorEpoch) { + ensureTransactionMarkerControlType(type); + this.type = type; + this.coordinatorEpoch = coordinatorEpoch; + } + + public int coordinatorEpoch() { + return coordinatorEpoch; + } + + public ControlRecordType controlType() { + return type; + } + + private Struct buildRecordValue() { + Struct struct = new Struct(END_TXN_MARKER_SCHEMA_VERSION_V0); + struct.set("version", CURRENT_END_TXN_MARKER_VERSION); + struct.set("coordinator_epoch", coordinatorEpoch); + return struct; + } + + public ByteBuffer serializeValue() { + Struct valueStruct = buildRecordValue(); + ByteBuffer value = ByteBuffer.allocate(valueStruct.sizeOf()); + valueStruct.writeTo(value); + value.flip(); + return value; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + EndTransactionMarker that = (EndTransactionMarker) o; + return coordinatorEpoch == that.coordinatorEpoch && type == that.type; + } + + @Override + public int hashCode() { + int result = type != null ? type.hashCode() : 0; + result = 31 * result + coordinatorEpoch; + return result; + } + + private static void ensureTransactionMarkerControlType(ControlRecordType type) { + if (type != ControlRecordType.COMMIT && type != ControlRecordType.ABORT) + throw new IllegalArgumentException("Invalid control record type for end transaction marker" + type); + } + + public static EndTransactionMarker deserialize(Record record) { + ControlRecordType type = ControlRecordType.parse(record.key()); + return deserializeValue(type, record.value()); + } + + static EndTransactionMarker deserializeValue(ControlRecordType type, ByteBuffer value) { + ensureTransactionMarkerControlType(type); + + if (value.remaining() < CURRENT_END_TXN_MARKER_VALUE_SIZE) + throw new InvalidRecordException("Invalid value size found for end transaction marker. Must have " + + "at least " + CURRENT_END_TXN_MARKER_VALUE_SIZE + " bytes, but found only " + value.remaining()); + + short version = value.getShort(0); + if (version < 0) + throw new InvalidRecordException("Invalid version found for end transaction marker: " + version + + ". May indicate data corruption"); + + if (version > CURRENT_END_TXN_MARKER_VERSION) + log.debug("Received end transaction marker value version {}. 
Parsing as version {}", version, + CURRENT_END_TXN_MARKER_VERSION); + + int coordinatorEpoch = value.getInt(2); + return new EndTransactionMarker(type, coordinatorEpoch); + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/record/FileLogInputStream.java b/clients/src/main/java/org/apache/kafka/common/record/FileLogInputStream.java new file mode 100644 index 0000000..10837d6 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/record/FileLogInputStream.java @@ -0,0 +1,259 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.record; + +import org.apache.kafka.common.KafkaException; +import org.apache.kafka.common.errors.CorruptRecordException; +import org.apache.kafka.common.record.AbstractLegacyRecordBatch.LegacyFileChannelRecordBatch; +import org.apache.kafka.common.record.DefaultRecordBatch.DefaultFileChannelRecordBatch; +import org.apache.kafka.common.utils.BufferSupplier; +import org.apache.kafka.common.utils.CloseableIterator; +import org.apache.kafka.common.utils.Utils; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.channels.FileChannel; +import java.util.Iterator; +import java.util.Objects; + +import static org.apache.kafka.common.record.Records.LOG_OVERHEAD; +import static org.apache.kafka.common.record.Records.HEADER_SIZE_UP_TO_MAGIC; +import static org.apache.kafka.common.record.Records.MAGIC_OFFSET; +import static org.apache.kafka.common.record.Records.OFFSET_OFFSET; +import static org.apache.kafka.common.record.Records.SIZE_OFFSET; + +/** + * A log input stream which is backed by a {@link FileChannel}. 
+ */ +public class FileLogInputStream implements LogInputStream { + private int position; + private final int end; + private final FileRecords fileRecords; + private final ByteBuffer logHeaderBuffer = ByteBuffer.allocate(HEADER_SIZE_UP_TO_MAGIC); + + /** + * Create a new log input stream over the FileChannel + * @param records Underlying FileRecords instance + * @param start Position in the file channel to start from + * @param end Position in the file channel not to read past + */ + FileLogInputStream(FileRecords records, + int start, + int end) { + this.fileRecords = records; + this.position = start; + this.end = end; + } + + @Override + public FileChannelRecordBatch nextBatch() throws IOException { + FileChannel channel = fileRecords.channel(); + if (position >= end - HEADER_SIZE_UP_TO_MAGIC) + return null; + + logHeaderBuffer.rewind(); + Utils.readFullyOrFail(channel, logHeaderBuffer, position, "log header"); + + logHeaderBuffer.rewind(); + long offset = logHeaderBuffer.getLong(OFFSET_OFFSET); + int size = logHeaderBuffer.getInt(SIZE_OFFSET); + + // V0 has the smallest overhead, stricter checking is done later + if (size < LegacyRecord.RECORD_OVERHEAD_V0) + throw new CorruptRecordException(String.format("Found record size %d smaller than minimum record " + + "overhead (%d) in file %s.", size, LegacyRecord.RECORD_OVERHEAD_V0, fileRecords.file())); + + if (position > end - LOG_OVERHEAD - size) + return null; + + byte magic = logHeaderBuffer.get(MAGIC_OFFSET); + final FileChannelRecordBatch batch; + + if (magic < RecordBatch.MAGIC_VALUE_V2) + batch = new LegacyFileChannelRecordBatch(offset, magic, fileRecords, position, size); + else + batch = new DefaultFileChannelRecordBatch(offset, magic, fileRecords, position, size); + + position += batch.sizeInBytes(); + return batch; + } + + /** + * Log entry backed by an underlying FileChannel. This allows iteration over the record batches + * without needing to read the record data into memory until it is needed. The downside + * is that entries will generally no longer be readable when the underlying channel is closed. 
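Because a FileChannelRecordBatch only materializes its payload on demand, header-level queries such as lastOffset() or sizeInBytes() never decompress the record data. A hedged usage sketch (the segment path is a placeholder, and the file must stay open while the batches are iterated):

import org.apache.kafka.common.record.FileRecords;
import org.apache.kafka.common.record.RecordBatch;

import java.io.File;
import java.io.IOException;

final class SegmentScanSketch {
    public static void main(String[] args) throws IOException {
        try (FileRecords segment = FileRecords.open(new File("/tmp/00000000000000000000.log"))) {
            for (RecordBatch batch : segment.batches()) {
                // Only the batch header is loaded here; the records themselves are not read.
                System.out.println("offsets=[" + batch.baseOffset() + ", " + batch.lastOffset()
                        + "] bytes=" + batch.sizeInBytes());
            }
        }
    }
}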
+ */ + public abstract static class FileChannelRecordBatch extends AbstractRecordBatch { + protected final long offset; + protected final byte magic; + protected final FileRecords fileRecords; + protected final int position; + protected final int batchSize; + + private RecordBatch fullBatch; + private RecordBatch batchHeader; + + FileChannelRecordBatch(long offset, + byte magic, + FileRecords fileRecords, + int position, + int batchSize) { + this.offset = offset; + this.magic = magic; + this.fileRecords = fileRecords; + this.position = position; + this.batchSize = batchSize; + } + + @Override + public CompressionType compressionType() { + return loadBatchHeader().compressionType(); + } + + @Override + public TimestampType timestampType() { + return loadBatchHeader().timestampType(); + } + + @Override + public long checksum() { + return loadBatchHeader().checksum(); + } + + @Override + public long maxTimestamp() { + return loadBatchHeader().maxTimestamp(); + } + + public int position() { + return position; + } + + @Override + public byte magic() { + return magic; + } + + @Override + public Iterator iterator() { + return loadFullBatch().iterator(); + } + + @Override + public CloseableIterator streamingIterator(BufferSupplier bufferSupplier) { + return loadFullBatch().streamingIterator(bufferSupplier); + } + + @Override + public boolean isValid() { + return loadFullBatch().isValid(); + } + + @Override + public void ensureValid() { + loadFullBatch().ensureValid(); + } + + @Override + public int sizeInBytes() { + return LOG_OVERHEAD + batchSize; + } + + @Override + public void writeTo(ByteBuffer buffer) { + FileChannel channel = fileRecords.channel(); + try { + int limit = buffer.limit(); + buffer.limit(buffer.position() + sizeInBytes()); + Utils.readFully(channel, buffer, position); + buffer.limit(limit); + } catch (IOException e) { + throw new KafkaException("Failed to read record batch at position " + position + " from " + fileRecords, e); + } + } + + protected abstract RecordBatch toMemoryRecordBatch(ByteBuffer buffer); + + protected abstract int headerSize(); + + protected RecordBatch loadFullBatch() { + if (fullBatch == null) { + batchHeader = null; + fullBatch = loadBatchWithSize(sizeInBytes(), "full record batch"); + } + return fullBatch; + } + + protected RecordBatch loadBatchHeader() { + if (fullBatch != null) + return fullBatch; + + if (batchHeader == null) + batchHeader = loadBatchWithSize(headerSize(), "record batch header"); + + return batchHeader; + } + + private RecordBatch loadBatchWithSize(int size, String description) { + FileChannel channel = fileRecords.channel(); + try { + ByteBuffer buffer = ByteBuffer.allocate(size); + Utils.readFullyOrFail(channel, buffer, position, description); + buffer.rewind(); + return toMemoryRecordBatch(buffer); + } catch (IOException e) { + throw new KafkaException("Failed to load record batch at position " + position + " from " + fileRecords, e); + } + } + + @Override + public boolean equals(Object o) { + if (this == o) + return true; + if (o == null || getClass() != o.getClass()) + return false; + + FileChannelRecordBatch that = (FileChannelRecordBatch) o; + + FileChannel channel = fileRecords == null ? null : fileRecords.channel(); + FileChannel thatChannel = that.fileRecords == null ? 
null : that.fileRecords.channel(); + + return offset == that.offset && + position == that.position && + batchSize == that.batchSize && + Objects.equals(channel, thatChannel); + } + + @Override + public int hashCode() { + FileChannel channel = fileRecords == null ? null : fileRecords.channel(); + + int result = Long.hashCode(offset); + result = 31 * result + (channel != null ? channel.hashCode() : 0); + result = 31 * result + position; + result = 31 * result + batchSize; + return result; + } + + @Override + public String toString() { + return "FileChannelRecordBatch(magic: " + magic + + ", offset: " + offset + + ", size: " + batchSize + ")"; + } + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/record/FileRecords.java b/clients/src/main/java/org/apache/kafka/common/record/FileRecords.java new file mode 100644 index 0000000..17a41e2 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/record/FileRecords.java @@ -0,0 +1,556 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.record; + +import org.apache.kafka.common.KafkaException; +import org.apache.kafka.common.network.TransferableChannel; +import org.apache.kafka.common.record.FileLogInputStream.FileChannelRecordBatch; +import org.apache.kafka.common.utils.AbstractIterator; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.common.utils.Utils; + +import java.io.Closeable; +import java.io.File; +import java.io.IOException; +import java.io.RandomAccessFile; +import java.nio.ByteBuffer; +import java.nio.channels.FileChannel; +import java.nio.file.Files; +import java.nio.file.StandardOpenOption; +import java.util.Objects; +import java.util.Optional; +import java.util.concurrent.atomic.AtomicInteger; + +/** + * A {@link Records} implementation backed by a file. An optional start and end position can be applied to this + * instance to enable slicing a range of the log records. + */ +public class FileRecords extends AbstractRecords implements Closeable { + private final boolean isSlice; + private final int start; + private final int end; + + private final Iterable batches; + + // mutable state + private final AtomicInteger size; + private final FileChannel channel; + private volatile File file; + + /** + * The {@code FileRecords.open} methods should be used instead of this constructor whenever possible. + * The constructor is visible for tests. 
+ */ + FileRecords(File file, + FileChannel channel, + int start, + int end, + boolean isSlice) throws IOException { + this.file = file; + this.channel = channel; + this.start = start; + this.end = end; + this.isSlice = isSlice; + this.size = new AtomicInteger(); + + if (isSlice) { + // don't check the file size if this is just a slice view + size.set(end - start); + } else { + if (channel.size() > Integer.MAX_VALUE) + throw new KafkaException("The size of segment " + file + " (" + channel.size() + + ") is larger than the maximum allowed segment size of " + Integer.MAX_VALUE); + + int limit = Math.min((int) channel.size(), end); + size.set(limit - start); + + // if this is not a slice, update the file pointer to the end of the file + // set the file position to the last byte in the file + channel.position(limit); + } + + batches = batchesFrom(start); + } + + @Override + public int sizeInBytes() { + return size.get(); + } + + /** + * Get the underlying file. + * @return The file + */ + public File file() { + return file; + } + + /** + * Get the underlying file channel. + * @return The file channel + */ + public FileChannel channel() { + return channel; + } + + /** + * Read log batches into the given buffer until there are no bytes remaining in the buffer or the end of the file + * is reached. + * + * @param buffer The buffer to write the batches to + * @param position Position in the buffer to read from + * @throws IOException If an I/O error occurs, see {@link FileChannel#read(ByteBuffer, long)} for details on the + * possible exceptions + */ + public void readInto(ByteBuffer buffer, int position) throws IOException { + Utils.readFully(channel, buffer, position + this.start); + buffer.flip(); + } + + /** + * Return a slice of records from this instance, which is a view into this set starting from the given position + * and with the given size limit. + * + * If the size is beyond the end of the file, the end will be based on the size of the file at the time of the read. + * + * If this message set is already sliced, the position will be taken relative to that slicing. + * + * @param position The start position to begin the read from + * @param size The number of bytes after the start position to include + * @return A sliced wrapper on this message set limited based on the given position and size + */ + public FileRecords slice(int position, int size) throws IOException { + int availableBytes = availableBytes(position, size); + int startPosition = this.start + position; + return new FileRecords(file, channel, startPosition, startPosition + availableBytes, true); + } + + /** + * Return a slice of records from this instance, the difference with {@link FileRecords#slice(int, int)} is + * that the position is not necessarily on an offset boundary. + * + * This method is reserved for cases where offset alignment is not necessary, such as in the replication of raft + * snapshots. 
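Slices are plain views over the same underlying channel, so they copy no data and inherit the parent's start position. A brief usage sketch with illustrative positions:

import org.apache.kafka.common.record.FileRecords;

import java.io.File;
import java.io.IOException;

final class SliceSketch {
    public static void main(String[] args) throws IOException {
        FileRecords log = FileRecords.open(new File("/tmp/00000000000000000000.log")); // placeholder path
        // A view of at most 4096 bytes starting at position 0 of this instance; no bytes are copied.
        FileRecords firstChunk = log.slice(0, 4096);
        System.out.println("segment=" + log.sizeInBytes() + " bytes, slice=" + firstChunk.sizeInBytes() + " bytes");
        log.close(); // closing the parent closes the shared channel; the slice is not closed separately
    }
}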
+ * + * @param position The start position to begin the read from + * @param size The number of bytes after the start position to include + * @return A unaligned slice of records on this message set limited based on the given position and size + */ + public UnalignedFileRecords sliceUnaligned(int position, int size) { + int availableBytes = availableBytes(position, size); + return new UnalignedFileRecords(channel, this.start + position, availableBytes); + } + + private int availableBytes(int position, int size) { + // Cache current size in case concurrent write changes it + int currentSizeInBytes = sizeInBytes(); + + if (position < 0) + throw new IllegalArgumentException("Invalid position: " + position + " in read from " + this); + if (position > currentSizeInBytes - start) + throw new IllegalArgumentException("Slice from position " + position + " exceeds end position of " + this); + if (size < 0) + throw new IllegalArgumentException("Invalid size: " + size + " in read from " + this); + + int end = this.start + position + size; + // Handle integer overflow or if end is beyond the end of the file + if (end < 0 || end > start + currentSizeInBytes) + end = this.start + currentSizeInBytes; + return end - (this.start + position); + } + + /** + * Append a set of records to the file. This method is not thread-safe and must be + * protected with a lock. + * + * @param records The records to append + * @return the number of bytes written to the underlying file + */ + public int append(MemoryRecords records) throws IOException { + if (records.sizeInBytes() > Integer.MAX_VALUE - size.get()) + throw new IllegalArgumentException("Append of size " + records.sizeInBytes() + + " bytes is too large for segment with current file position at " + size.get()); + + int written = records.writeFullyTo(channel); + size.getAndAdd(written); + return written; + } + + /** + * Commit all written data to the physical disk + */ + public void flush() throws IOException { + channel.force(true); + } + + /** + * Close this record set + */ + public void close() throws IOException { + flush(); + trim(); + channel.close(); + } + + /** + * Close file handlers used by the FileChannel but don't write to disk. This is used when the disk may have failed + */ + public void closeHandlers() throws IOException { + channel.close(); + } + + /** + * Delete this message set from the filesystem + * @throws IOException if deletion fails due to an I/O error + * @return {@code true} if the file was deleted by this method; {@code false} if the file could not be deleted + * because it did not exist + */ + public boolean deleteIfExists() throws IOException { + Utils.closeQuietly(channel, "FileChannel"); + return Files.deleteIfExists(file.toPath()); + } + + /** + * Trim file when close or roll to next file + */ + public void trim() throws IOException { + truncateTo(sizeInBytes()); + } + + /** + * Update the parent directory (to be used with caution since this does not reopen the file channel) + * @param parentDir The new parent directory + */ + public void updateParentDir(File parentDir) { + this.file = new File(parentDir, file.getName()); + } + + /** + * Rename the file that backs this message set + * @throws IOException if rename fails. + */ + public void renameTo(File f) throws IOException { + try { + Utils.atomicMoveWithFallback(file.toPath(), f.toPath(), false); + } finally { + this.file = f; + } + } + + /** + * Truncate this file message set to the given size in bytes. 
Note that this API does no checking that the + * given size falls on a valid message boundary. + * In some versions of the JDK truncating to the same size as the file message set will cause an + * update of the files mtime, so truncate is only performed if the targetSize is smaller than the + * size of the underlying FileChannel. + * It is expected that no other threads will do writes to the log when this function is called. + * @param targetSize The size to truncate to. Must be between 0 and sizeInBytes. + * @return The number of bytes truncated off + */ + public int truncateTo(int targetSize) throws IOException { + int originalSize = sizeInBytes(); + if (targetSize > originalSize || targetSize < 0) + throw new KafkaException("Attempt to truncate log segment " + file + " to " + targetSize + " bytes failed, " + + " size of this log segment is " + originalSize + " bytes."); + if (targetSize < (int) channel.size()) { + channel.truncate(targetSize); + size.set(targetSize); + } + return originalSize - targetSize; + } + + @Override + public ConvertedRecords downConvert(byte toMagic, long firstOffset, Time time) { + ConvertedRecords convertedRecords = RecordsUtil.downConvert(batches, toMagic, firstOffset, time); + if (convertedRecords.recordConversionStats().numRecordsConverted() == 0) { + // This indicates that the message is too large, which means that the buffer is not large + // enough to hold a full record batch. We just return all the bytes in this instance. + // Even though the record batch does not have the right format version, we expect old clients + // to raise an error to the user after reading the record batch size and seeing that there + // are not enough available bytes in the response to read it fully. Note that this is + // only possible prior to KIP-74, after which the broker was changed to always return at least + // one full record batch, even if it requires exceeding the max fetch size requested by the client. + return new ConvertedRecords<>(this, RecordConversionStats.EMPTY); + } else { + return convertedRecords; + } + } + + @Override + public long writeTo(TransferableChannel destChannel, long offset, int length) throws IOException { + long newSize = Math.min(channel.size(), end) - start; + int oldSize = sizeInBytes(); + if (newSize < oldSize) + throw new KafkaException(String.format( + "Size of FileRecords %s has been truncated during write: old size %d, new size %d", + file.getAbsolutePath(), oldSize, newSize)); + + long position = start + offset; + long count = Math.min(length, oldSize - offset); + return destChannel.transferFrom(channel, position, count); + } + + /** + * Search forward for the file position of the last offset that is greater than or equal to the target offset + * and return its physical position and the size of the message (including log overhead) at the returned offset. If + * no such offsets are found, return null. + * + * @param targetOffset The offset to search for. + * @param startingPosition The starting position in the file to begin searching from. + */ + public LogOffsetPosition searchForOffsetWithSize(long targetOffset, int startingPosition) { + for (FileChannelRecordBatch batch : batchesFrom(startingPosition)) { + long offset = batch.lastOffset(); + if (offset >= targetOffset) + return new LogOffsetPosition(offset, batch.position(), batch.sizeInBytes()); + } + return null; + } + + /** + * Search forward for the first message that meets the following requirements: + * - Message's timestamp is greater than or equals to the targetTimestamp. 
+ * - Message's position in the log file is greater than or equals to the startingPosition. + * - Message's offset is greater than or equals to the startingOffset. + * + * @param targetTimestamp The timestamp to search for. + * @param startingPosition The starting position to search. + * @param startingOffset The starting offset to search. + * @return The timestamp and offset of the message found. Null if no message is found. + */ + public TimestampAndOffset searchForTimestamp(long targetTimestamp, int startingPosition, long startingOffset) { + for (RecordBatch batch : batchesFrom(startingPosition)) { + if (batch.maxTimestamp() >= targetTimestamp) { + // We found a message + for (Record record : batch) { + long timestamp = record.timestamp(); + if (timestamp >= targetTimestamp && record.offset() >= startingOffset) + return new TimestampAndOffset(timestamp, record.offset(), + maybeLeaderEpoch(batch.partitionLeaderEpoch())); + } + } + } + return null; + } + + /** + * Return the largest timestamp of the messages after a given position in this file message set. + * @param startingPosition The starting position. + * @return The largest timestamp of the messages after the given position. + */ + public TimestampAndOffset largestTimestampAfter(int startingPosition) { + long maxTimestamp = RecordBatch.NO_TIMESTAMP; + long offsetOfMaxTimestamp = -1L; + int leaderEpochOfMaxTimestamp = RecordBatch.NO_PARTITION_LEADER_EPOCH; + + for (RecordBatch batch : batchesFrom(startingPosition)) { + long timestamp = batch.maxTimestamp(); + if (timestamp > maxTimestamp) { + maxTimestamp = timestamp; + offsetOfMaxTimestamp = batch.lastOffset(); + leaderEpochOfMaxTimestamp = batch.partitionLeaderEpoch(); + } + } + return new TimestampAndOffset(maxTimestamp, offsetOfMaxTimestamp, + maybeLeaderEpoch(leaderEpochOfMaxTimestamp)); + } + + private Optional maybeLeaderEpoch(int leaderEpoch) { + return leaderEpoch == RecordBatch.NO_PARTITION_LEADER_EPOCH ? + Optional.empty() : Optional.of(leaderEpoch); + } + + /** + * Get an iterator over the record batches in the file. Note that the batches are + * backed by the open file channel. When the channel is closed (i.e. when this instance + * is closed), the batches will generally no longer be readable. + * @return An iterator over the batches + */ + @Override + public Iterable batches() { + return batches; + } + + @Override + public String toString() { + return "FileRecords(size=" + sizeInBytes() + + ", file=" + file + + ", start=" + start + + ", end=" + end + + ")"; + } + + /** + * Get an iterator over the record batches in the file, starting at a specific position. This is similar to + * {@link #batches()} except that callers specify a particular position to start reading the batches from. This + * method must be used with caution: the start position passed in must be a known start of a batch. 
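These scans are what time-based offset lookups build on: skip ahead by batch maxTimestamp, then scan the records inside the first qualifying batch. A hedged usage sketch (path and target timestamp are placeholders):

import org.apache.kafka.common.record.FileRecords;

import java.io.File;
import java.io.IOException;

final class TimestampLookupSketch {
    public static void main(String[] args) throws IOException {
        try (FileRecords segment = FileRecords.open(new File("/tmp/00000000000000000000.log"))) {
            long targetTimestamp = System.currentTimeMillis() - 60_000L; // "offsets from one minute ago"
            FileRecords.TimestampAndOffset result =
                    segment.searchForTimestamp(targetTimestamp, 0, 0L); // scan from the start of the segment
            if (result == null)
                System.out.println("no record at or after the target timestamp");
            else
                System.out.println("offset=" + result.offset + " timestamp=" + result.timestamp);
        }
    }
}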
+ * @param start The position to start record iteration from; must be a known position for start of a batch + * @return An iterator over batches starting from {@code start} + */ + public Iterable batchesFrom(final int start) { + return () -> batchIterator(start); + } + + @Override + public AbstractIterator batchIterator() { + return batchIterator(start); + } + + private AbstractIterator batchIterator(int start) { + final int end; + if (isSlice) + end = this.end; + else + end = this.sizeInBytes(); + FileLogInputStream inputStream = new FileLogInputStream(this, start, end); + return new RecordBatchIterator<>(inputStream); + } + + public static FileRecords open(File file, + boolean mutable, + boolean fileAlreadyExists, + int initFileSize, + boolean preallocate) throws IOException { + FileChannel channel = openChannel(file, mutable, fileAlreadyExists, initFileSize, preallocate); + int end = (!fileAlreadyExists && preallocate) ? 0 : Integer.MAX_VALUE; + return new FileRecords(file, channel, 0, end, false); + } + + public static FileRecords open(File file, + boolean fileAlreadyExists, + int initFileSize, + boolean preallocate) throws IOException { + return open(file, true, fileAlreadyExists, initFileSize, preallocate); + } + + public static FileRecords open(File file, boolean mutable) throws IOException { + return open(file, mutable, false, 0, false); + } + + public static FileRecords open(File file) throws IOException { + return open(file, true); + } + + /** + * Open a channel for the given file + * For windows NTFS and some old LINUX file system, set preallocate to true and initFileSize + * with one value (for example 512 * 1025 *1024 ) can improve the kafka produce performance. + * @param file File path + * @param mutable mutable + * @param fileAlreadyExists File already exists or not + * @param initFileSize The size used for pre allocate file, for example 512 * 1025 *1024 + * @param preallocate Pre-allocate file or not, gotten from configuration. 
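Preallocation sizes the file up front via RandomAccessFile.setLength, which is what the note above about NTFS and older Linux filesystems refers to; a fresh preallocated file still reports a logical size of zero because end is forced to 0 in that case. A usage sketch of the matching open(...) overload (path and size are illustrative):

import org.apache.kafka.common.record.FileRecords;

import java.io.File;
import java.io.IOException;

final class PreallocateSketch {
    public static void main(String[] args) throws IOException {
        File segmentFile = new File("/tmp/00000000000000000000.log"); // placeholder path
        boolean fileAlreadyExists = segmentFile.exists();
        int initFileSize = 512 * 1024 * 1024;        // preallocate 512 MB
        boolean preallocate = !fileAlreadyExists;

        FileRecords segment = FileRecords.open(segmentFile, fileAlreadyExists, initFileSize, preallocate);
        // For a freshly preallocated file the logical size starts at 0 even though the
        // physical file is already initFileSize bytes long.
        System.out.println("logical size=" + segment.sizeInBytes() + " bytes");
        segment.close();
    }
}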
+ */ + private static FileChannel openChannel(File file, + boolean mutable, + boolean fileAlreadyExists, + int initFileSize, + boolean preallocate) throws IOException { + if (mutable) { + if (fileAlreadyExists || !preallocate) { + return FileChannel.open(file.toPath(), StandardOpenOption.CREATE, StandardOpenOption.READ, + StandardOpenOption.WRITE); + } else { + RandomAccessFile randomAccessFile = new RandomAccessFile(file, "rw"); + randomAccessFile.setLength(initFileSize); + return randomAccessFile.getChannel(); + } + } else { + return FileChannel.open(file.toPath()); + } + } + + public static class LogOffsetPosition { + public final long offset; + public final int position; + public final int size; + + public LogOffsetPosition(long offset, int position, int size) { + this.offset = offset; + this.position = position; + this.size = size; + } + + @Override + public boolean equals(Object o) { + if (this == o) + return true; + if (o == null || getClass() != o.getClass()) + return false; + + LogOffsetPosition that = (LogOffsetPosition) o; + + return offset == that.offset && + position == that.position && + size == that.size; + + } + + @Override + public int hashCode() { + int result = Long.hashCode(offset); + result = 31 * result + position; + result = 31 * result + size; + return result; + } + + @Override + public String toString() { + return "LogOffsetPosition(" + + "offset=" + offset + + ", position=" + position + + ", size=" + size + + ')'; + } + } + + public static class TimestampAndOffset { + public final long timestamp; + public final long offset; + public final Optional leaderEpoch; + + public TimestampAndOffset(long timestamp, long offset, Optional leaderEpoch) { + this.timestamp = timestamp; + this.offset = offset; + this.leaderEpoch = leaderEpoch; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + TimestampAndOffset that = (TimestampAndOffset) o; + return timestamp == that.timestamp && + offset == that.offset && + Objects.equals(leaderEpoch, that.leaderEpoch); + } + + @Override + public int hashCode() { + return Objects.hash(timestamp, offset, leaderEpoch); + } + + @Override + public String toString() { + return "TimestampAndOffset(" + + "timestamp=" + timestamp + + ", offset=" + offset + + ", leaderEpoch=" + leaderEpoch + + ')'; + } + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/record/LazyDownConversionRecords.java b/clients/src/main/java/org/apache/kafka/common/record/LazyDownConversionRecords.java new file mode 100644 index 0000000..56ef8e1 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/record/LazyDownConversionRecords.java @@ -0,0 +1,190 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.record; + +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.utils.AbstractIterator; +import org.apache.kafka.common.utils.Time; + +import java.util.ArrayList; +import java.util.List; +import java.util.Objects; + +/** + * Encapsulation for holding records that require down-conversion in a lazy, chunked manner (KIP-283). See + * {@link LazyDownConversionRecordsSend} for the actual chunked send implementation. + */ +public class LazyDownConversionRecords implements BaseRecords { + private final TopicPartition topicPartition; + private final Records records; + private final byte toMagic; + private final long firstOffset; + private ConvertedRecords firstConvertedBatch; + private final int sizeInBytes; + private final Time time; + + /** + * @param topicPartition The topic-partition to which records belong + * @param records Records to lazily down-convert + * @param toMagic Magic version to down-convert to + * @param firstOffset The starting offset for down-converted records. This only impacts some cases. See + * {@link RecordsUtil#downConvert(Iterable, byte, long, Time)} for an explanation. + * @param time The time instance to use + * + * @throws org.apache.kafka.common.errors.UnsupportedCompressionTypeException If the first batch to down-convert + * has a compression type which we do not support down-conversion for. + */ + public LazyDownConversionRecords(TopicPartition topicPartition, Records records, byte toMagic, long firstOffset, Time time) { + this.topicPartition = Objects.requireNonNull(topicPartition); + this.records = Objects.requireNonNull(records); + this.toMagic = toMagic; + this.firstOffset = firstOffset; + this.time = Objects.requireNonNull(time); + + // To make progress, kafka consumers require at least one full record batch per partition, i.e. we need to + // ensure we can accommodate one full batch of down-converted messages. We achieve this by having `sizeInBytes` + // factor in the size of the first down-converted batch and we return at least that many bytes. + java.util.Iterator> it = iterator(0); + if (it.hasNext()) { + firstConvertedBatch = it.next(); + sizeInBytes = Math.max(records.sizeInBytes(), firstConvertedBatch.records().sizeInBytes()); + } else { + // If there are messages before down-conversion and no messages after down-conversion, + // make sure we are able to send at least an overflow message to the consumer so that it can throw + // a RecordTooLargeException. Typically, the consumer would need to increase the fetch size in such cases. + // If there are no messages before down-conversion, we return an empty record batch. + firstConvertedBatch = null; + sizeInBytes = records.batches().iterator().hasNext() ? 
LazyDownConversionRecordsSend.MIN_OVERFLOW_MESSAGE_LENGTH : 0; + } + } + + @Override + public int sizeInBytes() { + return sizeInBytes; + } + + @Override + public LazyDownConversionRecordsSend toSend() { + return new LazyDownConversionRecordsSend(this); + } + + public TopicPartition topicPartition() { + return topicPartition; + } + + @Override + public boolean equals(Object o) { + if (o instanceof LazyDownConversionRecords) { + LazyDownConversionRecords that = (LazyDownConversionRecords) o; + return toMagic == that.toMagic && + firstOffset == that.firstOffset && + topicPartition.equals(that.topicPartition) && + records.equals(that.records); + } + return false; + } + + @Override + public int hashCode() { + int result = toMagic; + result = 31 * result + Long.hashCode(firstOffset); + result = 31 * result + topicPartition.hashCode(); + result = 31 * result + records.hashCode(); + return result; + } + + @Override + public String toString() { + return "LazyDownConversionRecords(size=" + sizeInBytes + + ", underlying=" + records + + ", toMagic=" + toMagic + + ", firstOffset=" + firstOffset + + ")"; + } + + public java.util.Iterator> iterator(long maximumReadSize) { + // We typically expect only one iterator instance to be created, so null out the first converted batch after + // first use to make it available for GC. + ConvertedRecords firstBatch = firstConvertedBatch; + firstConvertedBatch = null; + return new Iterator(records, maximumReadSize, firstBatch); + } + + /** + * Implementation for being able to iterate over down-converted records. Goal of this implementation is to keep + * it as memory-efficient as possible by not having to maintain all down-converted records in-memory. Maintains + * a view into batches of down-converted records. + */ + private class Iterator extends AbstractIterator> { + private final AbstractIterator batchIterator; + private final long maximumReadSize; + private ConvertedRecords firstConvertedBatch; + + /** + * @param recordsToDownConvert Records that require down-conversion + * @param maximumReadSize Maximum possible size of underlying records that will be down-converted in each call to + * {@link #makeNext()}. This is a soft limit as {@link #makeNext()} will always convert + * and return at least one full message batch. 
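A rough usage sketch of the class so far (topic name, record contents and chunk size are illustrative): wrap an in-memory v2 batch, down-convert it lazily to magic v1, and walk the converted chunks.

import java.nio.charset.StandardCharsets;
import java.util.Iterator;

import org.apache.kafka.common.TopicPartition;
import org.apache.kafka.common.record.CompressionType;
import org.apache.kafka.common.record.ConvertedRecords;
import org.apache.kafka.common.record.LazyDownConversionRecords;
import org.apache.kafka.common.record.MemoryRecords;
import org.apache.kafka.common.record.RecordBatch;
import org.apache.kafka.common.record.SimpleRecord;
import org.apache.kafka.common.utils.Time;

public class LazyDownConversionSketch {
    public static void main(String[] args) {
        MemoryRecords v2Records = MemoryRecords.withRecords(CompressionType.NONE,
                new SimpleRecord("a".getBytes(StandardCharsets.UTF_8)),
                new SimpleRecord("b".getBytes(StandardCharsets.UTF_8)));

        LazyDownConversionRecords lazy = new LazyDownConversionRecords(
                new TopicPartition("demo-topic", 0), v2Records, RecordBatch.MAGIC_VALUE_V1, 0L, Time.SYSTEM);

        // sizeInBytes() already accounts for the eagerly converted first batch, so a fetch response
        // sized from it can always carry at least one complete down-converted batch.
        System.out.println("advertised size: " + lazy.sizeInBytes());

        // Chunked iteration: each element covers up to roughly maximumReadSize bytes of source
        // batches, converted on demand instead of all at once.
        Iterator<ConvertedRecords<?>> chunks = lazy.iterator(16 * 1024);
        while (chunks.hasNext())
            System.out.println("converted chunk bytes: " + chunks.next().records().sizeInBytes());
    }
}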
+ */ + private Iterator(Records recordsToDownConvert, long maximumReadSize, ConvertedRecords firstConvertedBatch) { + this.batchIterator = recordsToDownConvert.batchIterator(); + this.maximumReadSize = maximumReadSize; + this.firstConvertedBatch = firstConvertedBatch; + // If we already have the first down-converted batch, advance the underlying records iterator to next batch + if (firstConvertedBatch != null) + this.batchIterator.next(); + } + + /** + * Make next set of down-converted records + * @return Down-converted records + */ + @Override + protected ConvertedRecords makeNext() { + // If we have cached the first down-converted batch, return that now + if (firstConvertedBatch != null) { + ConvertedRecords convertedBatch = firstConvertedBatch; + firstConvertedBatch = null; + return convertedBatch; + } + + while (batchIterator.hasNext()) { + final List batches = new ArrayList<>(); + boolean isFirstBatch = true; + long sizeSoFar = 0; + + // Figure out batches we should down-convert based on the size constraints + while (batchIterator.hasNext() && + (isFirstBatch || (batchIterator.peek().sizeInBytes() + sizeSoFar) <= maximumReadSize)) { + RecordBatch currentBatch = batchIterator.next(); + batches.add(currentBatch); + sizeSoFar += currentBatch.sizeInBytes(); + isFirstBatch = false; + } + + ConvertedRecords convertedRecords = RecordsUtil.downConvert(batches, toMagic, firstOffset, time); + // During conversion, it is possible that we drop certain batches because they do not have an equivalent + // representation in the message format we want to convert to. For example, V0 and V1 message formats + // have no notion of transaction markers which were introduced in V2 so they get dropped during conversion. + // We return converted records only when we have at least one valid batch of messages after conversion. + if (convertedRecords.records().sizeInBytes() > 0) + return convertedRecords; + } + return allDone(); + } + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/record/LazyDownConversionRecordsSend.java b/clients/src/main/java/org/apache/kafka/common/record/LazyDownConversionRecordsSend.java new file mode 100644 index 0000000..17addef --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/record/LazyDownConversionRecordsSend.java @@ -0,0 +1,106 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.record; + +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.errors.UnsupportedCompressionTypeException; +import org.apache.kafka.common.network.TransferableChannel; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.Iterator; + +/** + * Encapsulation for {@link RecordsSend} for {@link LazyDownConversionRecords}. Records are down-converted in batches and + * on-demand when {@link #writeTo} method is called. + */ +public final class LazyDownConversionRecordsSend extends RecordsSend { + private static final Logger log = LoggerFactory.getLogger(LazyDownConversionRecordsSend.class); + private static final int MAX_READ_SIZE = 128 * 1024; + static final int MIN_OVERFLOW_MESSAGE_LENGTH = Records.LOG_OVERHEAD; + + private RecordConversionStats recordConversionStats; + private RecordsSend convertedRecordsWriter; + private Iterator> convertedRecordsIterator; + + public LazyDownConversionRecordsSend(LazyDownConversionRecords records) { + super(records, records.sizeInBytes()); + convertedRecordsWriter = null; + recordConversionStats = new RecordConversionStats(); + convertedRecordsIterator = records().iterator(MAX_READ_SIZE); + } + + private MemoryRecords buildOverflowBatch(int remaining) { + // We do not have any records left to down-convert. Construct an overflow message for the length remaining. + // This message will be ignored by the consumer because its length will be past the length of maximum + // possible response size. + // DefaultRecordBatch => + // BaseOffset => Int64 + // Length => Int32 + // ... + ByteBuffer overflowMessageBatch = ByteBuffer.allocate( + Math.max(MIN_OVERFLOW_MESSAGE_LENGTH, Math.min(remaining + 1, MAX_READ_SIZE))); + overflowMessageBatch.putLong(-1L); + + // Fill in the length of the overflow batch. A valid batch must be at least as long as the minimum batch + // overhead. + overflowMessageBatch.putInt(Math.max(remaining + 1, DefaultRecordBatch.RECORD_BATCH_OVERHEAD)); + log.debug("Constructed overflow message batch for partition {} with length={}", topicPartition(), remaining); + return MemoryRecords.readableRecords(overflowMessageBatch); + } + + @Override + public long writeTo(TransferableChannel channel, long previouslyWritten, int remaining) throws IOException { + if (convertedRecordsWriter == null || convertedRecordsWriter.completed()) { + MemoryRecords convertedRecords; + + try { + // Check if we have more chunks left to down-convert + if (convertedRecordsIterator.hasNext()) { + // Get next chunk of down-converted messages + ConvertedRecords recordsAndStats = convertedRecordsIterator.next(); + convertedRecords = (MemoryRecords) recordsAndStats.records(); + recordConversionStats.add(recordsAndStats.recordConversionStats()); + log.debug("Down-converted records for partition {} with length={}", topicPartition(), convertedRecords.sizeInBytes()); + } else { + convertedRecords = buildOverflowBatch(remaining); + } + } catch (UnsupportedCompressionTypeException e) { + // We have encountered a compression type which does not support down-conversion (e.g. zstd). + // Since we have already sent at least one batch and we have committed to the fetch size, we + // send an overflow batch. The consumer will read the first few records and then fetch from the + // offset of the batch which has the unsupported compression type. 
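A simplified, standalone mirror of buildOverflowBatch(...) above (the remaining value is illustrative and the MAX_READ_SIZE clamp is omitted): it shows why the consumer ignores the overflow batch, since the claimed length exceeds the bytes actually left in the response.

import java.nio.ByteBuffer;

import org.apache.kafka.common.record.DefaultRecordBatch;
import org.apache.kafka.common.record.Records;

public class OverflowBatchSketch {
    public static void main(String[] args) {
        int remaining = 100; // bytes left in the fetch response (illustrative)

        // Log overhead is an 8-byte base offset plus a 4-byte length. The length written here is
        // deliberately larger than `remaining`, so the receiver sees a truncated batch and skips it
        // (or surfaces RecordTooLargeException if nothing else was returned for the partition).
        ByteBuffer overflow = ByteBuffer.allocate(Math.max(Records.LOG_OVERHEAD, remaining + 1));
        overflow.putLong(-1L);                                                              // base offset
        overflow.putInt(Math.max(remaining + 1, DefaultRecordBatch.RECORD_BATCH_OVERHEAD)); // claimed length

        overflow.flip();
        System.out.println("base offset    = " + overflow.getLong()); // -1
        System.out.println("claimed length = " + overflow.getInt());  // 101 > remaining, never readable
    }
}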
At that time, we will + send back the UNSUPPORTED_COMPRESSION_TYPE error which will allow the consumer to fail gracefully. + convertedRecords = buildOverflowBatch(remaining); + } + + convertedRecordsWriter = new DefaultRecordsSend<>(convertedRecords, Math.min(convertedRecords.sizeInBytes(), remaining)); + } + return convertedRecordsWriter.writeTo(channel); + } + + public RecordConversionStats recordConversionStats() { + return recordConversionStats; + } + + public TopicPartition topicPartition() { + return records().topicPartition(); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/record/LegacyRecord.java b/clients/src/main/java/org/apache/kafka/common/record/LegacyRecord.java new file mode 100644 index 0000000..32c5aa8 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/record/LegacyRecord.java @@ -0,0 +1,577 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.record; + +import org.apache.kafka.common.KafkaException; +import org.apache.kafka.common.errors.CorruptRecordException; +import org.apache.kafka.common.utils.ByteBufferOutputStream; +import org.apache.kafka.common.utils.ByteUtils; +import org.apache.kafka.common.utils.Checksums; +import org.apache.kafka.common.utils.Crc32; +import org.apache.kafka.common.utils.Utils; + +import java.io.DataOutputStream; +import java.io.IOException; +import java.nio.ByteBuffer; + +import static org.apache.kafka.common.utils.Utils.wrapNullable; + +/** + * This class represents the serialized key and value along with the associated CRC and other fields + * of message format versions 0 and 1. Note that it is uncommon to need to access this class directly. + * Usually it should be accessed indirectly through the {@link Record} interface which is exposed + * through the {@link Records} object.
+ */ +public final class LegacyRecord { + + /** + * The current offset and size for all the fixed-length fields + */ + public static final int CRC_OFFSET = 0; + public static final int CRC_LENGTH = 4; + public static final int MAGIC_OFFSET = CRC_OFFSET + CRC_LENGTH; + public static final int MAGIC_LENGTH = 1; + public static final int ATTRIBUTES_OFFSET = MAGIC_OFFSET + MAGIC_LENGTH; + public static final int ATTRIBUTES_LENGTH = 1; + public static final int TIMESTAMP_OFFSET = ATTRIBUTES_OFFSET + ATTRIBUTES_LENGTH; + public static final int TIMESTAMP_LENGTH = 8; + public static final int KEY_SIZE_OFFSET_V0 = ATTRIBUTES_OFFSET + ATTRIBUTES_LENGTH; + public static final int KEY_SIZE_OFFSET_V1 = TIMESTAMP_OFFSET + TIMESTAMP_LENGTH; + public static final int KEY_SIZE_LENGTH = 4; + public static final int KEY_OFFSET_V0 = KEY_SIZE_OFFSET_V0 + KEY_SIZE_LENGTH; + public static final int KEY_OFFSET_V1 = KEY_SIZE_OFFSET_V1 + KEY_SIZE_LENGTH; + public static final int VALUE_SIZE_LENGTH = 4; + + /** + * The size for the record header + */ + public static final int HEADER_SIZE_V0 = CRC_LENGTH + MAGIC_LENGTH + ATTRIBUTES_LENGTH; + public static final int HEADER_SIZE_V1 = CRC_LENGTH + MAGIC_LENGTH + ATTRIBUTES_LENGTH + TIMESTAMP_LENGTH; + + /** + * The amount of overhead bytes in a record + */ + public static final int RECORD_OVERHEAD_V0 = HEADER_SIZE_V0 + KEY_SIZE_LENGTH + VALUE_SIZE_LENGTH; + + /** + * The amount of overhead bytes in a record + */ + public static final int RECORD_OVERHEAD_V1 = HEADER_SIZE_V1 + KEY_SIZE_LENGTH + VALUE_SIZE_LENGTH; + + /** + * Specifies the mask for the compression code. 3 bits to hold the compression codec. 0 is reserved to indicate no + * compression + */ + private static final int COMPRESSION_CODEC_MASK = 0x07; + + /** + * Specify the mask of timestamp type: 0 for CreateTime, 1 for LogAppendTime. 
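The offsets above pin down the fixed per-record overhead of the two legacy formats; a small sketch of the arithmetic (key/value sizes are arbitrary):

import org.apache.kafka.common.record.LegacyRecord;
import org.apache.kafka.common.record.RecordBatch;

public class LegacyRecordSizeSketch {
    public static void main(String[] args) {
        // v0: crc(4) + magic(1) + attributes(1) + keyLength(4) + valueLength(4)                = 14 bytes
        // v1: crc(4) + magic(1) + attributes(1) + timestamp(8) + keyLength(4) + valueLength(4) = 22 bytes
        System.out.println("v0 overhead = " + LegacyRecord.RECORD_OVERHEAD_V0); // 14
        System.out.println("v1 overhead = " + LegacyRecord.RECORD_OVERHEAD_V1); // 22

        // Full serialized record size for a 3-byte key and 5-byte value, excluding the log
        // overhead (offset + size) that precedes each record in the log.
        System.out.println("v0 record = " + LegacyRecord.recordSize(RecordBatch.MAGIC_VALUE_V0, 3, 5)); // 22
        System.out.println("v1 record = " + LegacyRecord.recordSize(RecordBatch.MAGIC_VALUE_V1, 3, 5)); // 30
    }
}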
+ */ + private static final byte TIMESTAMP_TYPE_MASK = 0x08; + + /** + * Timestamp value for records without a timestamp + */ + public static final long NO_TIMESTAMP = -1L; + + private final ByteBuffer buffer; + private final Long wrapperRecordTimestamp; + private final TimestampType wrapperRecordTimestampType; + + public LegacyRecord(ByteBuffer buffer) { + this(buffer, null, null); + } + + public LegacyRecord(ByteBuffer buffer, Long wrapperRecordTimestamp, TimestampType wrapperRecordTimestampType) { + this.buffer = buffer; + this.wrapperRecordTimestamp = wrapperRecordTimestamp; + this.wrapperRecordTimestampType = wrapperRecordTimestampType; + } + + /** + * Compute the checksum of the record from the record contents + */ + public long computeChecksum() { + return Crc32.crc32(buffer, MAGIC_OFFSET, buffer.limit() - MAGIC_OFFSET); + } + + /** + * Retrieve the previously computed CRC for this record + */ + public long checksum() { + return ByteUtils.readUnsignedInt(buffer, CRC_OFFSET); + } + + /** + * Returns true if the crc stored with the record matches the crc computed off the record contents + */ + public boolean isValid() { + return sizeInBytes() >= RECORD_OVERHEAD_V0 && checksum() == computeChecksum(); + } + + public Long wrapperRecordTimestamp() { + return wrapperRecordTimestamp; + } + + public TimestampType wrapperRecordTimestampType() { + return wrapperRecordTimestampType; + } + + /** + * Throw an InvalidRecordException if isValid is false for this record + */ + public void ensureValid() { + if (sizeInBytes() < RECORD_OVERHEAD_V0) + throw new CorruptRecordException("Record is corrupt (crc could not be retrieved as the record is too " + + "small, size = " + sizeInBytes() + ")"); + + if (!isValid()) + throw new CorruptRecordException("Record is corrupt (stored crc = " + checksum() + + ", computed crc = " + computeChecksum() + ")"); + } + + /** + * The complete serialized size of this record in bytes (including crc, header attributes, etc), but + * excluding the log overhead (offset and record size). + * @return the size in bytes + */ + public int sizeInBytes() { + return buffer.limit(); + } + + /** + * The length of the key in bytes + * @return the size in bytes of the key (0 if the key is null) + */ + public int keySize() { + if (magic() == RecordBatch.MAGIC_VALUE_V0) + return buffer.getInt(KEY_SIZE_OFFSET_V0); + else + return buffer.getInt(KEY_SIZE_OFFSET_V1); + } + + /** + * Does the record have a key? + * @return true if so, false otherwise + */ + public boolean hasKey() { + return keySize() >= 0; + } + + /** + * The position where the value size is stored + */ + private int valueSizeOffset() { + if (magic() == RecordBatch.MAGIC_VALUE_V0) + return KEY_OFFSET_V0 + Math.max(0, keySize()); + else + return KEY_OFFSET_V1 + Math.max(0, keySize()); + } + + /** + * The length of the value in bytes + * @return the size in bytes of the value (0 if the value is null) + */ + public int valueSize() { + return buffer.getInt(valueSizeOffset()); + } + + /** + * Check whether the value field of this record is null. + * @return true if the value is null, false otherwise + */ + public boolean hasNullValue() { + return valueSize() < 0; + } + + /** + * The magic value (i.e. 
message format version) of this record + * @return the magic value + */ + public byte magic() { + return buffer.get(MAGIC_OFFSET); + } + + /** + * The attributes stored with this record + * @return the attributes + */ + public byte attributes() { + return buffer.get(ATTRIBUTES_OFFSET); + } + + /** + * When magic value is greater than 0, the timestamp of a record is determined in the following way: + * 1. wrapperRecordTimestampType = null and wrapperRecordTimestamp is null - Uncompressed message, timestamp is in the message. + * 2. wrapperRecordTimestampType = LOG_APPEND_TIME and WrapperRecordTimestamp is not null - Compressed message using LOG_APPEND_TIME + * 3. wrapperRecordTimestampType = CREATE_TIME and wrapperRecordTimestamp is not null - Compressed message using CREATE_TIME + * + * @return the timestamp as determined above + */ + public long timestamp() { + if (magic() == RecordBatch.MAGIC_VALUE_V0) + return RecordBatch.NO_TIMESTAMP; + else { + // case 2 + if (wrapperRecordTimestampType == TimestampType.LOG_APPEND_TIME && wrapperRecordTimestamp != null) + return wrapperRecordTimestamp; + // Case 1, 3 + else + return buffer.getLong(TIMESTAMP_OFFSET); + } + } + + /** + * Get the timestamp type of the record. + * + * @return The timestamp type or {@link TimestampType#NO_TIMESTAMP_TYPE} if the magic is 0. + */ + public TimestampType timestampType() { + return timestampType(magic(), wrapperRecordTimestampType, attributes()); + } + + /** + * The compression type used with this record + */ + public CompressionType compressionType() { + return CompressionType.forId(buffer.get(ATTRIBUTES_OFFSET) & COMPRESSION_CODEC_MASK); + } + + /** + * A ByteBuffer containing the value of this record + * @return the value or null if the value for this record is null + */ + public ByteBuffer value() { + return Utils.sizeDelimited(buffer, valueSizeOffset()); + } + + /** + * A ByteBuffer containing the message key + * @return the buffer or null if the key for this record is null + */ + public ByteBuffer key() { + if (magic() == RecordBatch.MAGIC_VALUE_V0) + return Utils.sizeDelimited(buffer, KEY_SIZE_OFFSET_V0); + else + return Utils.sizeDelimited(buffer, KEY_SIZE_OFFSET_V1); + } + + /** + * Get the underlying buffer backing this record instance. + * + * @return the buffer + */ + public ByteBuffer buffer() { + return this.buffer; + } + + public String toString() { + if (magic() > 0) + return String.format("Record(magic=%d, attributes=%d, compression=%s, crc=%d, %s=%d, key=%d bytes, value=%d bytes)", + magic(), + attributes(), + compressionType(), + checksum(), + timestampType(), + timestamp(), + key() == null ? 0 : key().limit(), + value() == null ? 0 : value().limit()); + else + return String.format("Record(magic=%d, attributes=%d, compression=%s, crc=%d, key=%d bytes, value=%d bytes)", + magic(), + attributes(), + compressionType(), + checksum(), + key() == null ? 0 : key().limit(), + value() == null ? 0 : value().limit()); + } + + public boolean equals(Object other) { + if (this == other) + return true; + if (other == null) + return false; + if (!other.getClass().equals(LegacyRecord.class)) + return false; + LegacyRecord record = (LegacyRecord) other; + return this.buffer.equals(record.buffer); + } + + public int hashCode() { + return buffer.hashCode(); + } + + /** + * Create a new record instance. 
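A small sketch of the timestamp cases described above, using create(...) from further down in this class; the timestamps are arbitrary:

import java.nio.charset.StandardCharsets;

import org.apache.kafka.common.record.LegacyRecord;
import org.apache.kafka.common.record.RecordBatch;
import org.apache.kafka.common.record.TimestampType;

public class LegacyRecordTimestampSketch {
    public static void main(String[] args) {
        byte[] key = "k".getBytes(StandardCharsets.UTF_8);
        byte[] value = "v".getBytes(StandardCharsets.UTF_8);

        // Case 1: an uncompressed v1 record carries its own timestamp (CREATE_TIME by default).
        LegacyRecord plain = LegacyRecord.create(RecordBatch.MAGIC_VALUE_V1, 1234L, key, value);
        System.out.println(plain.timestampType() + " " + plain.timestamp()); // the record's own 1234

        // Cases 2/3: viewed through a compressed wrapper, the wrapper's timestamp type decides.
        // With LOG_APPEND_TIME the wrapper timestamp overrides whatever the inner record stored.
        LegacyRecord inWrapper = new LegacyRecord(plain.buffer(), 9999L, TimestampType.LOG_APPEND_TIME);
        System.out.println(inWrapper.timestampType() + " " + inWrapper.timestamp()); // wrapper's 9999
    }
}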
If the record's compression type is not none, then + * its value payload should be already compressed with the specified type; the constructor + * would always write the value payload as is and will not do the compression itself. + * + * @param magic The magic value to use + * @param timestamp The timestamp of the record + * @param key The key of the record (null, if none) + * @param value The record value + * @param compressionType The compression type used on the contents of the record (if any) + * @param timestampType The timestamp type to be used for this record + */ + public static LegacyRecord create(byte magic, + long timestamp, + byte[] key, + byte[] value, + CompressionType compressionType, + TimestampType timestampType) { + int keySize = key == null ? 0 : key.length; + int valueSize = value == null ? 0 : value.length; + ByteBuffer buffer = ByteBuffer.allocate(recordSize(magic, keySize, valueSize)); + write(buffer, magic, timestamp, wrapNullable(key), wrapNullable(value), compressionType, timestampType); + buffer.rewind(); + return new LegacyRecord(buffer); + } + + public static LegacyRecord create(byte magic, long timestamp, byte[] key, byte[] value) { + return create(magic, timestamp, key, value, CompressionType.NONE, TimestampType.CREATE_TIME); + } + + /** + * Write the header for a compressed record set in-place (i.e. assuming the compressed record data has already + * been written at the value offset in a wrapped record). This lets you dynamically create a compressed message + * set, and then go back later and fill in its size and CRC, which saves the need for copying to another buffer. + * + * @param buffer The buffer containing the compressed record data positioned at the first offset of the + * @param magic The magic value of the record set + * @param recordSize The size of the record (including record overhead) + * @param timestamp The timestamp of the wrapper record + * @param compressionType The compression type used + * @param timestampType The timestamp type of the wrapper record + */ + public static void writeCompressedRecordHeader(ByteBuffer buffer, + byte magic, + int recordSize, + long timestamp, + CompressionType compressionType, + TimestampType timestampType) { + int recordPosition = buffer.position(); + int valueSize = recordSize - recordOverhead(magic); + + // write the record header with a null value (the key is always null for the wrapper) + write(buffer, magic, timestamp, null, null, compressionType, timestampType); + buffer.position(recordPosition); + + // now fill in the value size + buffer.putInt(recordPosition + keyOffset(magic), valueSize); + + // compute and fill the crc from the beginning of the message + long crc = Crc32.crc32(buffer, MAGIC_OFFSET, recordSize - MAGIC_OFFSET); + ByteUtils.writeUnsignedInt(buffer, recordPosition + CRC_OFFSET, crc); + } + + private static void write(ByteBuffer buffer, + byte magic, + long timestamp, + ByteBuffer key, + ByteBuffer value, + CompressionType compressionType, + TimestampType timestampType) { + try { + DataOutputStream out = new DataOutputStream(new ByteBufferOutputStream(buffer)); + write(out, magic, timestamp, key, value, compressionType, timestampType); + } catch (IOException e) { + throw new KafkaException(e); + } + } + + /** + * Write the record data with the given compression type and return the computed crc. 
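To make the CRC relationship concrete, a short sketch that builds a record with create(...), confirms the stored checksum, and then corrupts one byte so ensureValid() rejects it (key/value contents are arbitrary):

import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;

import org.apache.kafka.common.errors.CorruptRecordException;
import org.apache.kafka.common.record.LegacyRecord;
import org.apache.kafka.common.record.RecordBatch;

public class LegacyRecordChecksumSketch {
    public static void main(String[] args) {
        LegacyRecord record = LegacyRecord.create(RecordBatch.MAGIC_VALUE_V1, 42L,
                "k".getBytes(StandardCharsets.UTF_8), "v".getBytes(StandardCharsets.UTF_8));

        // The stored CRC (first four bytes) covers everything from the magic byte onwards.
        System.out.println("stored=" + record.checksum()
                + " computed=" + record.computeChecksum()
                + " valid=" + record.isValid());

        // Flip one bit of the value: the stored CRC no longer matches, so ensureValid() throws.
        ByteBuffer buffer = record.buffer();
        int last = buffer.limit() - 1;
        buffer.put(last, (byte) (buffer.get(last) ^ 0x01));
        try {
            record.ensureValid();
        } catch (CorruptRecordException e) {
            System.out.println("corruption detected: " + e.getMessage());
        }
    }
}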
+ * + * @param out The output stream to write to + * @param magic The magic value to be used + * @param timestamp The timestamp of the record + * @param key The record key + * @param value The record value + * @param compressionType The compression type + * @param timestampType The timestamp type + * @return the computed CRC for this record. + * @throws IOException for any IO errors writing to the output stream. + */ + public static long write(DataOutputStream out, + byte magic, + long timestamp, + byte[] key, + byte[] value, + CompressionType compressionType, + TimestampType timestampType) throws IOException { + return write(out, magic, timestamp, wrapNullable(key), wrapNullable(value), compressionType, timestampType); + } + + public static long write(DataOutputStream out, + byte magic, + long timestamp, + ByteBuffer key, + ByteBuffer value, + CompressionType compressionType, + TimestampType timestampType) throws IOException { + byte attributes = computeAttributes(magic, compressionType, timestampType); + long crc = computeChecksum(magic, attributes, timestamp, key, value); + write(out, magic, crc, attributes, timestamp, key, value); + return crc; + } + + /** + * Write a record using raw fields (without validation). This should only be used in testing. + */ + public static void write(DataOutputStream out, + byte magic, + long crc, + byte attributes, + long timestamp, + byte[] key, + byte[] value) throws IOException { + write(out, magic, crc, attributes, timestamp, wrapNullable(key), wrapNullable(value)); + } + + // Write a record to the buffer, if the record's compression type is none, then + // its value payload should be already compressed with the specified type + private static void write(DataOutputStream out, + byte magic, + long crc, + byte attributes, + long timestamp, + ByteBuffer key, + ByteBuffer value) throws IOException { + if (magic != RecordBatch.MAGIC_VALUE_V0 && magic != RecordBatch.MAGIC_VALUE_V1) + throw new IllegalArgumentException("Invalid magic value " + magic); + if (timestamp < 0 && timestamp != RecordBatch.NO_TIMESTAMP) + throw new IllegalArgumentException("Invalid message timestamp " + timestamp); + + // write crc + out.writeInt((int) (crc & 0xffffffffL)); + // write magic value + out.writeByte(magic); + // write attributes + out.writeByte(attributes); + + // maybe write timestamp + if (magic > RecordBatch.MAGIC_VALUE_V0) + out.writeLong(timestamp); + + // write the key + if (key == null) { + out.writeInt(-1); + } else { + int size = key.remaining(); + out.writeInt(size); + Utils.writeTo(out, key, size); + } + // write the value + if (value == null) { + out.writeInt(-1); + } else { + int size = value.remaining(); + out.writeInt(size); + Utils.writeTo(out, value, size); + } + } + + static int recordSize(byte magic, ByteBuffer key, ByteBuffer value) { + return recordSize(magic, key == null ? 0 : key.limit(), value == null ? 
0 : value.limit()); + } + + public static int recordSize(byte magic, int keySize, int valueSize) { + return recordOverhead(magic) + keySize + valueSize; + } + + // visible only for testing + public static byte computeAttributes(byte magic, CompressionType type, TimestampType timestampType) { + byte attributes = 0; + if (type.id > 0) + attributes |= COMPRESSION_CODEC_MASK & type.id; + if (magic > RecordBatch.MAGIC_VALUE_V0) { + if (timestampType == TimestampType.NO_TIMESTAMP_TYPE) + throw new IllegalArgumentException("Timestamp type must be provided to compute attributes for " + + "message format v1"); + if (timestampType == TimestampType.LOG_APPEND_TIME) + attributes |= TIMESTAMP_TYPE_MASK; + } + return attributes; + } + + // visible only for testing + public static long computeChecksum(byte magic, byte attributes, long timestamp, byte[] key, byte[] value) { + return computeChecksum(magic, attributes, timestamp, wrapNullable(key), wrapNullable(value)); + } + + /** + * Compute the checksum of the record from the attributes, key and value payloads + */ + private static long computeChecksum(byte magic, byte attributes, long timestamp, ByteBuffer key, ByteBuffer value) { + Crc32 crc = new Crc32(); + crc.update(magic); + crc.update(attributes); + if (magic > RecordBatch.MAGIC_VALUE_V0) + Checksums.updateLong(crc, timestamp); + // update for the key + if (key == null) { + Checksums.updateInt(crc, -1); + } else { + int size = key.remaining(); + Checksums.updateInt(crc, size); + Checksums.update(crc, key, size); + } + // update for the value + if (value == null) { + Checksums.updateInt(crc, -1); + } else { + int size = value.remaining(); + Checksums.updateInt(crc, size); + Checksums.update(crc, value, size); + } + return crc.getValue(); + } + + static int recordOverhead(byte magic) { + if (magic == 0) + return RECORD_OVERHEAD_V0; + else if (magic == 1) + return RECORD_OVERHEAD_V1; + throw new IllegalArgumentException("Invalid magic used in LegacyRecord: " + magic); + } + + static int headerSize(byte magic) { + if (magic == 0) + return HEADER_SIZE_V0; + else if (magic == 1) + return HEADER_SIZE_V1; + throw new IllegalArgumentException("Invalid magic used in LegacyRecord: " + magic); + } + + private static int keyOffset(byte magic) { + if (magic == 0) + return KEY_OFFSET_V0; + else if (magic == 1) + return KEY_OFFSET_V1; + throw new IllegalArgumentException("Invalid magic used in LegacyRecord: " + magic); + } + + public static TimestampType timestampType(byte magic, TimestampType wrapperRecordTimestampType, byte attributes) { + if (magic == 0) + return TimestampType.NO_TIMESTAMP_TYPE; + else if (wrapperRecordTimestampType != null) + return wrapperRecordTimestampType; + else + return (attributes & TIMESTAMP_TYPE_MASK) == 0 ? TimestampType.CREATE_TIME : TimestampType.LOG_APPEND_TIME; + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/record/LogInputStream.java b/clients/src/main/java/org/apache/kafka/common/record/LogInputStream.java new file mode 100644 index 0000000..0c2bb8c --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/record/LogInputStream.java @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.record; + +import java.io.IOException; + +/** + * An abstraction between an underlying input stream and record iterators, a {@link LogInputStream} only returns + * the batches at one level. For magic values 0 and 1, this means that it can either handle iteration + * at the top level of the log or deep iteration within the payload of a single message, but it does not attempt + * to handle both. For magic value 2, this is only used for iterating over the top-level record batches (inner + * records do not follow the {@link RecordBatch} interface. + * + * The generic typing allows for implementations which present only a view of the log entries, which enables more + * efficient iteration when the record data is not actually needed. See for example + * {@link FileLogInputStream.FileChannelRecordBatch} in which the record is not brought into memory until needed. + * + * @param Type parameter of the log entry + */ +interface LogInputStream { + + /** + * Get the next record batch from the underlying input stream. + * + * @return The next record batch or null if there is none + * @throws IOException for any IO errors + */ + T nextBatch() throws IOException; +} diff --git a/clients/src/main/java/org/apache/kafka/common/record/MemoryRecords.java b/clients/src/main/java/org/apache/kafka/common/record/MemoryRecords.java new file mode 100644 index 0000000..eacc211 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/record/MemoryRecords.java @@ -0,0 +1,808 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.record; + +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.errors.CorruptRecordException; +import org.apache.kafka.common.message.LeaderChangeMessage; +import org.apache.kafka.common.message.SnapshotHeaderRecord; +import org.apache.kafka.common.message.SnapshotFooterRecord; +import org.apache.kafka.common.network.TransferableChannel; +import org.apache.kafka.common.record.MemoryRecords.RecordFilter.BatchRetention; +import org.apache.kafka.common.record.MemoryRecords.RecordFilter.BatchRetentionResult; +import org.apache.kafka.common.utils.AbstractIterator; +import org.apache.kafka.common.utils.BufferSupplier; +import org.apache.kafka.common.utils.ByteBufferOutputStream; +import org.apache.kafka.common.utils.CloseableIterator; +import org.apache.kafka.common.utils.Time; + +import org.apache.kafka.common.utils.Utils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.channels.GatheringByteChannel; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Objects; + +/** + * A {@link Records} implementation backed by a ByteBuffer. This is used only for reading or + * modifying in-place an existing buffer of record batches. To create a new buffer see {@link MemoryRecordsBuilder}, + * or one of the {@link #builder(ByteBuffer, byte, CompressionType, TimestampType, long)} variants. + */ +public class MemoryRecords extends AbstractRecords { + private static final Logger log = LoggerFactory.getLogger(MemoryRecords.class); + public static final MemoryRecords EMPTY = MemoryRecords.readableRecords(ByteBuffer.allocate(0)); + + private final ByteBuffer buffer; + + private final Iterable batches = this::batchIterator; + + private int validBytes = -1; + + // Construct a writable memory records + private MemoryRecords(ByteBuffer buffer) { + Objects.requireNonNull(buffer, "buffer should not be null"); + this.buffer = buffer; + } + + @Override + public int sizeInBytes() { + return buffer.limit(); + } + + @Override + public long writeTo(TransferableChannel channel, long position, int length) throws IOException { + if (position > Integer.MAX_VALUE) + throw new IllegalArgumentException("position should not be greater than Integer.MAX_VALUE: " + position); + if (position + length > buffer.limit()) + throw new IllegalArgumentException("position+length should not be greater than buffer.limit(), position: " + + position + ", length: " + length + ", buffer.limit(): " + buffer.limit()); + + return Utils.tryWriteTo(channel, (int) position, length, buffer); + } + + /** + * Write all records to the given channel (including partial records). + * @param channel The channel to write to + * @return The number of bytes written + * @throws IOException For any IO errors writing to the channel + */ + public int writeFullyTo(GatheringByteChannel channel) throws IOException { + buffer.mark(); + int written = 0; + while (written < sizeInBytes()) + written += channel.write(buffer); + buffer.reset(); + return written; + } + + /** + * The total number of bytes in this message set not including any partial, trailing messages. This + * may be smaller than what is returned by {@link #sizeInBytes()}. 
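A short sketch of building readable records and writing them out in full; withRecords(...) is defined later in this class, and the temp-file path is illustrative:

import java.io.IOException;
import java.nio.channels.FileChannel;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.StandardOpenOption;

import org.apache.kafka.common.record.CompressionType;
import org.apache.kafka.common.record.MemoryRecords;
import org.apache.kafka.common.record.SimpleRecord;

public class MemoryRecordsWriteSketch {
    public static void main(String[] args) throws IOException {
        MemoryRecords records = MemoryRecords.withRecords(CompressionType.NONE,
                new SimpleRecord("k1".getBytes(StandardCharsets.UTF_8), "v1".getBytes(StandardCharsets.UTF_8)),
                new SimpleRecord("k2".getBytes(StandardCharsets.UTF_8), "v2".getBytes(StandardCharsets.UTF_8)));

        // For a freshly built buffer with no trailing partial batch, validBytes() equals sizeInBytes().
        System.out.println("size=" + records.sizeInBytes() + " valid=" + records.validBytes());

        // writeFullyTo loops until every byte (including any partial trailing batch) reaches the
        // channel, and restores the buffer position afterwards via mark()/reset().
        Path path = Files.createTempFile("memory-records-sketch", ".log");
        try (FileChannel channel = FileChannel.open(path, StandardOpenOption.WRITE)) {
            System.out.println("wrote " + records.writeFullyTo(channel) + " bytes to " + path);
        }
    }
}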
+ * @return The number of valid bytes + */ + public int validBytes() { + if (validBytes >= 0) + return validBytes; + + int bytes = 0; + for (RecordBatch batch : batches()) + bytes += batch.sizeInBytes(); + + this.validBytes = bytes; + return bytes; + } + + @Override + public ConvertedRecords downConvert(byte toMagic, long firstOffset, Time time) { + return RecordsUtil.downConvert(batches(), toMagic, firstOffset, time); + } + + @Override + public AbstractIterator batchIterator() { + return new RecordBatchIterator<>(new ByteBufferLogInputStream(buffer.duplicate(), Integer.MAX_VALUE)); + } + + /** + * Validates the header of the first batch and returns batch size. + * @return first batch size including LOG_OVERHEAD if buffer contains header up to + * magic byte, null otherwise + * @throws CorruptRecordException if record size or magic is invalid + */ + public Integer firstBatchSize() { + if (buffer.remaining() < HEADER_SIZE_UP_TO_MAGIC) + return null; + return new ByteBufferLogInputStream(buffer, Integer.MAX_VALUE).nextBatchSize(); + } + + /** + * Filter the records into the provided ByteBuffer. + * + * @param partition The partition that is filtered (used only for logging) + * @param filter The filter function + * @param destinationBuffer The byte buffer to write the filtered records to + * @param maxRecordBatchSize The maximum record batch size. Note this is not a hard limit: if a batch + * exceeds this after filtering, we log a warning, but the batch will still be + * created. + * @param decompressionBufferSupplier The supplier of ByteBuffer(s) used for decompression if supported. For small + * record batches, allocating a potentially large buffer (64 KB for LZ4) will + * dominate the cost of decompressing and iterating over the records in the + * batch. As such, a supplier that reuses buffers will have a significant + * performance impact. + * @return A FilterResult with a summary of the output (for metrics) and potentially an overflow buffer + */ + public FilterResult filterTo(TopicPartition partition, RecordFilter filter, ByteBuffer destinationBuffer, + int maxRecordBatchSize, BufferSupplier decompressionBufferSupplier) { + return filterTo(partition, batches(), filter, destinationBuffer, maxRecordBatchSize, decompressionBufferSupplier); + } + + /** + * Note: This method is also used to convert the first timestamp of the batch (which is usually the timestamp of the first record) + * to the delete horizon of the tombstones or txn markers which are present in the batch. + */ + private static FilterResult filterTo(TopicPartition partition, Iterable batches, + RecordFilter filter, ByteBuffer destinationBuffer, int maxRecordBatchSize, + BufferSupplier decompressionBufferSupplier) { + FilterResult filterResult = new FilterResult(destinationBuffer); + ByteBufferOutputStream bufferOutputStream = new ByteBufferOutputStream(destinationBuffer); + for (MutableRecordBatch batch : batches) { + final BatchRetentionResult batchRetentionResult = filter.checkBatchRetention(batch); + final boolean containsMarkerForEmptyTxn = batchRetentionResult.containsMarkerForEmptyTxn; + final BatchRetention batchRetention = batchRetentionResult.batchRetention; + + filterResult.bytesRead += batch.sizeInBytes(); + + if (batchRetention == BatchRetention.DELETE) + continue; + + // We use the absolute offset to decide whether to retain the message or not. 
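A hedged end-to-end sketch of filterTo(...): an anonymous RecordFilter that drops tombstones from a small v2 batch (topic name, retention arguments and buffer size are illustrative):

import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;

import org.apache.kafka.common.TopicPartition;
import org.apache.kafka.common.record.CompressionType;
import org.apache.kafka.common.record.MemoryRecords;
import org.apache.kafka.common.record.MemoryRecords.FilterResult;
import org.apache.kafka.common.record.MemoryRecords.RecordFilter;
import org.apache.kafka.common.record.MemoryRecords.RecordFilter.BatchRetention;
import org.apache.kafka.common.record.MemoryRecords.RecordFilter.BatchRetentionResult;
import org.apache.kafka.common.record.Record;
import org.apache.kafka.common.record.RecordBatch;
import org.apache.kafka.common.record.SimpleRecord;
import org.apache.kafka.common.utils.BufferSupplier;

public class FilterToSketch {
    public static void main(String[] args) {
        byte[] key = "k".getBytes(StandardCharsets.UTF_8);
        MemoryRecords records = MemoryRecords.withRecords(CompressionType.NONE,
                new SimpleRecord(key, "v".getBytes(StandardCharsets.UTF_8)),
                new SimpleRecord(key, (byte[]) null)); // a tombstone

        // Inspect every batch, then retain only records that still have a value.
        RecordFilter dropTombstones = new RecordFilter(System.currentTimeMillis(), 0) {
            @Override
            protected BatchRetentionResult checkBatchRetention(RecordBatch batch) {
                return new BatchRetentionResult(BatchRetention.DELETE_EMPTY, false);
            }

            @Override
            protected boolean shouldRetainRecord(RecordBatch batch, Record record) {
                return record.hasValue();
            }
        };

        ByteBuffer destination = ByteBuffer.allocate(records.sizeInBytes());
        FilterResult result = records.filterTo(new TopicPartition("demo-topic", 0), dropTombstones,
                destination, Integer.MAX_VALUE, BufferSupplier.create());

        System.out.println("read=" + result.messagesRead()
                + " retained=" + result.messagesRetained()
                + " bytesRetained=" + result.bytesRetained());
    }
}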
Due to KAFKA-4298, we have to + // allow for the possibility that a previous version corrupted the log by writing a compressed record batch + // with a magic value not matching the magic of the records (magic < 2). This will be fixed as we + // recopy the messages to the destination buffer. + byte batchMagic = batch.magic(); + List retainedRecords = new ArrayList<>(); + + final BatchFilterResult iterationResult = filterBatch(batch, decompressionBufferSupplier, filterResult, filter, + batchMagic, true, retainedRecords); + boolean containsTombstones = iterationResult.containsTombstones; + boolean writeOriginalBatch = iterationResult.writeOriginalBatch; + long maxOffset = iterationResult.maxOffset; + + if (!retainedRecords.isEmpty()) { + // we check if the delete horizon should be set to a new value + // in which case, we need to reset the base timestamp and overwrite the timestamp deltas + // if the batch does not contain tombstones, then we don't need to overwrite batch + boolean needToSetDeleteHorizon = batch.magic() >= RecordBatch.MAGIC_VALUE_V2 && (containsTombstones || containsMarkerForEmptyTxn) + && !batch.deleteHorizonMs().isPresent(); + if (writeOriginalBatch && !needToSetDeleteHorizon) { + batch.writeTo(bufferOutputStream); + filterResult.updateRetainedBatchMetadata(batch, retainedRecords.size(), false); + } else { + final MemoryRecordsBuilder builder; + long deleteHorizonMs; + if (needToSetDeleteHorizon) + deleteHorizonMs = filter.currentTime + filter.deleteRetentionMs; + else + deleteHorizonMs = batch.deleteHorizonMs().orElse(RecordBatch.NO_TIMESTAMP); + builder = buildRetainedRecordsInto(batch, retainedRecords, bufferOutputStream, deleteHorizonMs); + + MemoryRecords records = builder.build(); + int filteredBatchSize = records.sizeInBytes(); + if (filteredBatchSize > batch.sizeInBytes() && filteredBatchSize > maxRecordBatchSize) + log.warn("Record batch from {} with last offset {} exceeded max record batch size {} after cleaning " + + "(new size is {}). Consumers with version earlier than 0.10.1.0 may need to " + + "increase their fetch sizes.", + partition, batch.lastOffset(), maxRecordBatchSize, filteredBatchSize); + + MemoryRecordsBuilder.RecordsInfo info = builder.info(); + filterResult.updateRetainedBatchMetadata(info.maxTimestamp, info.shallowOffsetOfMaxTimestamp, + maxOffset, retainedRecords.size(), filteredBatchSize); + } + } else if (batchRetention == BatchRetention.RETAIN_EMPTY) { + if (batchMagic < RecordBatch.MAGIC_VALUE_V2) + throw new IllegalStateException("Empty batches are only supported for magic v2 and above"); + + bufferOutputStream.ensureRemaining(DefaultRecordBatch.RECORD_BATCH_OVERHEAD); + DefaultRecordBatch.writeEmptyHeader(bufferOutputStream.buffer(), batchMagic, batch.producerId(), + batch.producerEpoch(), batch.baseSequence(), batch.baseOffset(), batch.lastOffset(), + batch.partitionLeaderEpoch(), batch.timestampType(), batch.maxTimestamp(), + batch.isTransactional(), batch.isControlBatch()); + filterResult.updateRetainedBatchMetadata(batch, 0, true); + } + + // If we had to allocate a new buffer to fit the filtered buffer (see KAFKA-5316), return early to + // avoid the need for additional allocations. 
+ ByteBuffer outputBuffer = bufferOutputStream.buffer(); + if (outputBuffer != destinationBuffer) { + filterResult.outputBuffer = outputBuffer; + return filterResult; + } + } + + return filterResult; + } + + private static BatchFilterResult filterBatch(RecordBatch batch, + BufferSupplier decompressionBufferSupplier, + FilterResult filterResult, + RecordFilter filter, + byte batchMagic, + boolean writeOriginalBatch, + List retainedRecords) { + long maxOffset = -1; + boolean containsTombstones = false; + try (final CloseableIterator iterator = batch.streamingIterator(decompressionBufferSupplier)) { + while (iterator.hasNext()) { + Record record = iterator.next(); + filterResult.messagesRead += 1; + + if (filter.shouldRetainRecord(batch, record)) { + // Check for log corruption due to KAFKA-4298. If we find it, make sure that we overwrite + // the corrupted batch with correct data. + if (!record.hasMagic(batchMagic)) + writeOriginalBatch = false; + + if (record.offset() > maxOffset) + maxOffset = record.offset(); + + retainedRecords.add(record); + + if (!record.hasValue()) { + containsTombstones = true; + } + } else { + writeOriginalBatch = false; + } + } + return new BatchFilterResult(writeOriginalBatch, containsTombstones, maxOffset); + } + } + + private static class BatchFilterResult { + private final boolean writeOriginalBatch; + private final boolean containsTombstones; + private final long maxOffset; + private BatchFilterResult(final boolean writeOriginalBatch, + final boolean containsTombstones, + final long maxOffset) { + this.writeOriginalBatch = writeOriginalBatch; + this.containsTombstones = containsTombstones; + this.maxOffset = maxOffset; + } + } + + private static MemoryRecordsBuilder buildRetainedRecordsInto(RecordBatch originalBatch, + List retainedRecords, + ByteBufferOutputStream bufferOutputStream, + final long deleteHorizonMs) { + byte magic = originalBatch.magic(); + TimestampType timestampType = originalBatch.timestampType(); + long logAppendTime = timestampType == TimestampType.LOG_APPEND_TIME ? + originalBatch.maxTimestamp() : RecordBatch.NO_TIMESTAMP; + long baseOffset = magic >= RecordBatch.MAGIC_VALUE_V2 ? + originalBatch.baseOffset() : retainedRecords.get(0).offset(); + + MemoryRecordsBuilder builder = new MemoryRecordsBuilder(bufferOutputStream, magic, + originalBatch.compressionType(), timestampType, baseOffset, logAppendTime, originalBatch.producerId(), + originalBatch.producerEpoch(), originalBatch.baseSequence(), originalBatch.isTransactional(), + originalBatch.isControlBatch(), originalBatch.partitionLeaderEpoch(), bufferOutputStream.limit(), deleteHorizonMs); + + for (Record record : retainedRecords) + builder.append(record); + + if (magic >= RecordBatch.MAGIC_VALUE_V2) + // we must preserve the last offset from the initial batch in order to ensure that the + // last sequence number from the batch remains even after compaction. Otherwise, the producer + // could incorrectly see an out of sequence error. + builder.overrideLastOffset(originalBatch.lastOffset()); + + return builder; + } + + /** + * Get the byte buffer that backs this instance for reading. 
+ */ + public ByteBuffer buffer() { + return buffer.duplicate(); + } + + @Override + public Iterable batches() { + return batches; + } + + @Override + public String toString() { + return "MemoryRecords(size=" + sizeInBytes() + + ", buffer=" + buffer + + ")"; + } + + @Override + public boolean equals(Object o) { + if (this == o) + return true; + if (o == null || getClass() != o.getClass()) + return false; + + MemoryRecords that = (MemoryRecords) o; + + return buffer.equals(that.buffer); + } + + @Override + public int hashCode() { + return buffer.hashCode(); + } + + public static abstract class RecordFilter { + public final long currentTime; + public final long deleteRetentionMs; + + public RecordFilter(final long currentTime, final long deleteRetentionMs) { + this.currentTime = currentTime; + this.deleteRetentionMs = deleteRetentionMs; + } + + public static class BatchRetentionResult { + public final BatchRetention batchRetention; + public final boolean containsMarkerForEmptyTxn; + public BatchRetentionResult(final BatchRetention batchRetention, + final boolean containsMarkerForEmptyTxn) { + this.batchRetention = batchRetention; + this.containsMarkerForEmptyTxn = containsMarkerForEmptyTxn; + } + } + + public enum BatchRetention { + DELETE, // Delete the batch without inspecting records + RETAIN_EMPTY, // Retain the batch even if it is empty + DELETE_EMPTY // Delete the batch if it is empty + } + + /** + * Check whether the full batch can be discarded (i.e. whether we even need to + * check the records individually). + */ + protected abstract BatchRetentionResult checkBatchRetention(RecordBatch batch); + + /** + * Check whether a record should be retained in the log. Note that {@link #checkBatchRetention(RecordBatch)} + * is used prior to checking individual record retention. Only records from batches which were not + * explicitly discarded with {@link BatchRetention#DELETE} will be considered. + */ + protected abstract boolean shouldRetainRecord(RecordBatch recordBatch, Record record); + } + + public static class FilterResult { + private ByteBuffer outputBuffer; + private int messagesRead = 0; + // Note that `bytesRead` should contain only bytes from batches that have been processed, i.e. bytes from + // `messagesRead` and any discarded batches. + private int bytesRead = 0; + private int messagesRetained = 0; + private int bytesRetained = 0; + private long maxOffset = -1L; + private long maxTimestamp = RecordBatch.NO_TIMESTAMP; + private long shallowOffsetOfMaxTimestamp = -1L; + + private FilterResult(ByteBuffer outputBuffer) { + this.outputBuffer = outputBuffer; + } + + private void updateRetainedBatchMetadata(MutableRecordBatch retainedBatch, int numMessagesInBatch, boolean headerOnly) { + int bytesRetained = headerOnly ? 
DefaultRecordBatch.RECORD_BATCH_OVERHEAD : retainedBatch.sizeInBytes(); + updateRetainedBatchMetadata(retainedBatch.maxTimestamp(), retainedBatch.lastOffset(), + retainedBatch.lastOffset(), numMessagesInBatch, bytesRetained); + } + + private void updateRetainedBatchMetadata(long maxTimestamp, long shallowOffsetOfMaxTimestamp, long maxOffset, + int messagesRetained, int bytesRetained) { + validateBatchMetadata(maxTimestamp, shallowOffsetOfMaxTimestamp, maxOffset); + if (maxTimestamp > this.maxTimestamp) { + this.maxTimestamp = maxTimestamp; + this.shallowOffsetOfMaxTimestamp = shallowOffsetOfMaxTimestamp; + } + this.maxOffset = Math.max(maxOffset, this.maxOffset); + this.messagesRetained += messagesRetained; + this.bytesRetained += bytesRetained; + } + + private void validateBatchMetadata(long maxTimestamp, long shallowOffsetOfMaxTimestamp, long maxOffset) { + if (maxTimestamp != RecordBatch.NO_TIMESTAMP && shallowOffsetOfMaxTimestamp < 0) + throw new IllegalArgumentException("shallowOffset undefined for maximum timestamp " + maxTimestamp); + if (maxOffset < 0) + throw new IllegalArgumentException("maxOffset undefined"); + } + + public ByteBuffer outputBuffer() { + return outputBuffer; + } + + public int messagesRead() { + return messagesRead; + } + + public int bytesRead() { + return bytesRead; + } + + public int messagesRetained() { + return messagesRetained; + } + + public int bytesRetained() { + return bytesRetained; + } + + public long maxOffset() { + return maxOffset; + } + + public long maxTimestamp() { + return maxTimestamp; + } + + public long shallowOffsetOfMaxTimestamp() { + return shallowOffsetOfMaxTimestamp; + } + } + + public static MemoryRecords readableRecords(ByteBuffer buffer) { + return new MemoryRecords(buffer); + } + + public static MemoryRecordsBuilder builder(ByteBuffer buffer, + CompressionType compressionType, + TimestampType timestampType, + long baseOffset) { + return builder(buffer, RecordBatch.CURRENT_MAGIC_VALUE, compressionType, timestampType, baseOffset); + } + + public static MemoryRecordsBuilder builder(ByteBuffer buffer, + CompressionType compressionType, + TimestampType timestampType, + long baseOffset, + int maxSize) { + long logAppendTime = RecordBatch.NO_TIMESTAMP; + if (timestampType == TimestampType.LOG_APPEND_TIME) + logAppendTime = System.currentTimeMillis(); + + return new MemoryRecordsBuilder(buffer, RecordBatch.CURRENT_MAGIC_VALUE, compressionType, timestampType, baseOffset, + logAppendTime, RecordBatch.NO_PRODUCER_ID, RecordBatch.NO_PRODUCER_EPOCH, RecordBatch.NO_SEQUENCE, + false, false, RecordBatch.NO_PARTITION_LEADER_EPOCH, maxSize); + } + + public static MemoryRecordsBuilder idempotentBuilder(ByteBuffer buffer, + CompressionType compressionType, + long baseOffset, + long producerId, + short producerEpoch, + int baseSequence) { + return builder(buffer, RecordBatch.CURRENT_MAGIC_VALUE, compressionType, TimestampType.CREATE_TIME, + baseOffset, System.currentTimeMillis(), producerId, producerEpoch, baseSequence); + } + + public static MemoryRecordsBuilder builder(ByteBuffer buffer, + byte magic, + CompressionType compressionType, + TimestampType timestampType, + long baseOffset, + long logAppendTime) { + return builder(buffer, magic, compressionType, timestampType, baseOffset, logAppendTime, + RecordBatch.NO_PRODUCER_ID, RecordBatch.NO_PRODUCER_EPOCH, RecordBatch.NO_SEQUENCE, false, + RecordBatch.NO_PARTITION_LEADER_EPOCH); + } + + public static MemoryRecordsBuilder builder(ByteBuffer buffer, + byte magic, + CompressionType compressionType, + 
TimestampType timestampType, + long baseOffset) { + long logAppendTime = RecordBatch.NO_TIMESTAMP; + if (timestampType == TimestampType.LOG_APPEND_TIME) + logAppendTime = System.currentTimeMillis(); + return builder(buffer, magic, compressionType, timestampType, baseOffset, logAppendTime, + RecordBatch.NO_PRODUCER_ID, RecordBatch.NO_PRODUCER_EPOCH, RecordBatch.NO_SEQUENCE, false, + RecordBatch.NO_PARTITION_LEADER_EPOCH); + } + + public static MemoryRecordsBuilder builder(ByteBuffer buffer, + byte magic, + CompressionType compressionType, + TimestampType timestampType, + long baseOffset, + long logAppendTime, + int partitionLeaderEpoch) { + return builder(buffer, magic, compressionType, timestampType, baseOffset, logAppendTime, + RecordBatch.NO_PRODUCER_ID, RecordBatch.NO_PRODUCER_EPOCH, RecordBatch.NO_SEQUENCE, false, partitionLeaderEpoch); + } + + public static MemoryRecordsBuilder builder(ByteBuffer buffer, + CompressionType compressionType, + long baseOffset, + long producerId, + short producerEpoch, + int baseSequence, + boolean isTransactional) { + return builder(buffer, RecordBatch.CURRENT_MAGIC_VALUE, compressionType, TimestampType.CREATE_TIME, baseOffset, + RecordBatch.NO_TIMESTAMP, producerId, producerEpoch, baseSequence, isTransactional, + RecordBatch.NO_PARTITION_LEADER_EPOCH); + } + + public static MemoryRecordsBuilder builder(ByteBuffer buffer, + byte magic, + CompressionType compressionType, + TimestampType timestampType, + long baseOffset, + long logAppendTime, + long producerId, + short producerEpoch, + int baseSequence) { + return builder(buffer, magic, compressionType, timestampType, baseOffset, logAppendTime, + producerId, producerEpoch, baseSequence, false, RecordBatch.NO_PARTITION_LEADER_EPOCH); + } + + public static MemoryRecordsBuilder builder(ByteBuffer buffer, + byte magic, + CompressionType compressionType, + TimestampType timestampType, + long baseOffset, + long logAppendTime, + long producerId, + short producerEpoch, + int baseSequence, + boolean isTransactional, + int partitionLeaderEpoch) { + return builder(buffer, magic, compressionType, timestampType, baseOffset, + logAppendTime, producerId, producerEpoch, baseSequence, isTransactional, false, partitionLeaderEpoch); + } + + public static MemoryRecordsBuilder builder(ByteBuffer buffer, + byte magic, + CompressionType compressionType, + TimestampType timestampType, + long baseOffset, + long logAppendTime, + long producerId, + short producerEpoch, + int baseSequence, + boolean isTransactional, + boolean isControlBatch, + int partitionLeaderEpoch) { + return new MemoryRecordsBuilder(buffer, magic, compressionType, timestampType, baseOffset, + logAppendTime, producerId, producerEpoch, baseSequence, isTransactional, isControlBatch, partitionLeaderEpoch, + buffer.remaining()); + } + + public static MemoryRecords withRecords(CompressionType compressionType, SimpleRecord... records) { + return withRecords(RecordBatch.CURRENT_MAGIC_VALUE, compressionType, records); + } + + public static MemoryRecords withRecords(CompressionType compressionType, int partitionLeaderEpoch, SimpleRecord... records) { + return withRecords(RecordBatch.CURRENT_MAGIC_VALUE, 0L, compressionType, TimestampType.CREATE_TIME, + RecordBatch.NO_PRODUCER_ID, RecordBatch.NO_PRODUCER_EPOCH, RecordBatch.NO_SEQUENCE, + partitionLeaderEpoch, false, records); + } + + public static MemoryRecords withRecords(byte magic, CompressionType compressionType, SimpleRecord... 
records) { + return withRecords(magic, 0L, compressionType, TimestampType.CREATE_TIME, records); + } + + public static MemoryRecords withRecords(long initialOffset, CompressionType compressionType, SimpleRecord... records) { + return withRecords(RecordBatch.CURRENT_MAGIC_VALUE, initialOffset, compressionType, TimestampType.CREATE_TIME, + records); + } + + public static MemoryRecords withRecords(byte magic, long initialOffset, CompressionType compressionType, SimpleRecord... records) { + return withRecords(magic, initialOffset, compressionType, TimestampType.CREATE_TIME, records); + } + + public static MemoryRecords withRecords(long initialOffset, CompressionType compressionType, Integer partitionLeaderEpoch, SimpleRecord... records) { + return withRecords(RecordBatch.CURRENT_MAGIC_VALUE, initialOffset, compressionType, TimestampType.CREATE_TIME, RecordBatch.NO_PRODUCER_ID, + RecordBatch.NO_PRODUCER_EPOCH, RecordBatch.NO_SEQUENCE, partitionLeaderEpoch, false, records); + } + + public static MemoryRecords withIdempotentRecords(CompressionType compressionType, long producerId, + short producerEpoch, int baseSequence, SimpleRecord... records) { + return withRecords(RecordBatch.CURRENT_MAGIC_VALUE, 0L, compressionType, TimestampType.CREATE_TIME, producerId, producerEpoch, + baseSequence, RecordBatch.NO_PARTITION_LEADER_EPOCH, false, records); + } + + public static MemoryRecords withIdempotentRecords(byte magic, long initialOffset, CompressionType compressionType, + long producerId, short producerEpoch, int baseSequence, + int partitionLeaderEpoch, SimpleRecord... records) { + return withRecords(magic, initialOffset, compressionType, TimestampType.CREATE_TIME, producerId, producerEpoch, + baseSequence, partitionLeaderEpoch, false, records); + } + + public static MemoryRecords withIdempotentRecords(long initialOffset, CompressionType compressionType, long producerId, + short producerEpoch, int baseSequence, int partitionLeaderEpoch, + SimpleRecord... records) { + return withRecords(RecordBatch.CURRENT_MAGIC_VALUE, initialOffset, compressionType, TimestampType.CREATE_TIME, + producerId, producerEpoch, baseSequence, partitionLeaderEpoch, false, records); + } + + public static MemoryRecords withTransactionalRecords(CompressionType compressionType, long producerId, + short producerEpoch, int baseSequence, SimpleRecord... records) { + return withRecords(RecordBatch.CURRENT_MAGIC_VALUE, 0L, compressionType, TimestampType.CREATE_TIME, + producerId, producerEpoch, baseSequence, RecordBatch.NO_PARTITION_LEADER_EPOCH, true, records); + } + + public static MemoryRecords withTransactionalRecords(byte magic, long initialOffset, CompressionType compressionType, + long producerId, short producerEpoch, int baseSequence, + int partitionLeaderEpoch, SimpleRecord... records) { + return withRecords(magic, initialOffset, compressionType, TimestampType.CREATE_TIME, producerId, producerEpoch, + baseSequence, partitionLeaderEpoch, true, records); + } + + public static MemoryRecords withTransactionalRecords(long initialOffset, CompressionType compressionType, long producerId, + short producerEpoch, int baseSequence, int partitionLeaderEpoch, + SimpleRecord... records) { + return withTransactionalRecords(RecordBatch.CURRENT_MAGIC_VALUE, initialOffset, compressionType, + producerId, producerEpoch, baseSequence, partitionLeaderEpoch, records); + } + + public static MemoryRecords withRecords(byte magic, long initialOffset, CompressionType compressionType, + TimestampType timestampType, SimpleRecord... 
records) { + return withRecords(magic, initialOffset, compressionType, timestampType, RecordBatch.NO_PRODUCER_ID, + RecordBatch.NO_PRODUCER_EPOCH, RecordBatch.NO_SEQUENCE, RecordBatch.NO_PARTITION_LEADER_EPOCH, + false, records); + } + + public static MemoryRecords withRecords(byte magic, long initialOffset, CompressionType compressionType, + TimestampType timestampType, long producerId, short producerEpoch, + int baseSequence, int partitionLeaderEpoch, boolean isTransactional, + SimpleRecord... records) { + if (records.length == 0) + return MemoryRecords.EMPTY; + int sizeEstimate = AbstractRecords.estimateSizeInBytes(magic, compressionType, Arrays.asList(records)); + ByteBufferOutputStream bufferStream = new ByteBufferOutputStream(sizeEstimate); + long logAppendTime = RecordBatch.NO_TIMESTAMP; + if (timestampType == TimestampType.LOG_APPEND_TIME) + logAppendTime = System.currentTimeMillis(); + MemoryRecordsBuilder builder = new MemoryRecordsBuilder(bufferStream, magic, compressionType, timestampType, + initialOffset, logAppendTime, producerId, producerEpoch, baseSequence, isTransactional, false, + partitionLeaderEpoch, sizeEstimate); + for (SimpleRecord record : records) + builder.append(record); + return builder.build(); + } + + public static MemoryRecords withEndTransactionMarker(long producerId, short producerEpoch, EndTransactionMarker marker) { + return withEndTransactionMarker(0L, System.currentTimeMillis(), RecordBatch.NO_PARTITION_LEADER_EPOCH, + producerId, producerEpoch, marker); + } + + public static MemoryRecords withEndTransactionMarker(long timestamp, long producerId, short producerEpoch, + EndTransactionMarker marker) { + return withEndTransactionMarker(0L, timestamp, RecordBatch.NO_PARTITION_LEADER_EPOCH, producerId, + producerEpoch, marker); + } + + public static MemoryRecords withEndTransactionMarker(long initialOffset, long timestamp, int partitionLeaderEpoch, + long producerId, short producerEpoch, + EndTransactionMarker marker) { + int endTxnMarkerBatchSize = DefaultRecordBatch.RECORD_BATCH_OVERHEAD + + EndTransactionMarker.CURRENT_END_TXN_SCHEMA_RECORD_SIZE; + ByteBuffer buffer = ByteBuffer.allocate(endTxnMarkerBatchSize); + writeEndTransactionalMarker(buffer, initialOffset, timestamp, partitionLeaderEpoch, producerId, + producerEpoch, marker); + buffer.flip(); + return MemoryRecords.readableRecords(buffer); + } + + public static void writeEndTransactionalMarker(ByteBuffer buffer, long initialOffset, long timestamp, + int partitionLeaderEpoch, long producerId, short producerEpoch, + EndTransactionMarker marker) { + boolean isTransactional = true; + try (MemoryRecordsBuilder builder = new MemoryRecordsBuilder(buffer, RecordBatch.CURRENT_MAGIC_VALUE, CompressionType.NONE, + TimestampType.CREATE_TIME, initialOffset, timestamp, producerId, producerEpoch, + RecordBatch.NO_SEQUENCE, isTransactional, true, partitionLeaderEpoch, + buffer.capacity()) + ) { + builder.appendEndTxnMarker(timestamp, marker); + } + } + + public static MemoryRecords withLeaderChangeMessage( + long initialOffset, + long timestamp, + int leaderEpoch, + ByteBuffer buffer, + LeaderChangeMessage leaderChangeMessage + ) { + writeLeaderChangeMessage(buffer, initialOffset, timestamp, leaderEpoch, leaderChangeMessage); + buffer.flip(); + return MemoryRecords.readableRecords(buffer); + } + + private static void writeLeaderChangeMessage(ByteBuffer buffer, + long initialOffset, + long timestamp, + int leaderEpoch, + LeaderChangeMessage leaderChangeMessage) { + try (MemoryRecordsBuilder builder = new 
MemoryRecordsBuilder( + buffer, RecordBatch.CURRENT_MAGIC_VALUE, CompressionType.NONE, + TimestampType.CREATE_TIME, initialOffset, timestamp, + RecordBatch.NO_PRODUCER_ID, RecordBatch.NO_PRODUCER_EPOCH, RecordBatch.NO_SEQUENCE, + false, true, leaderEpoch, buffer.capacity()) + ) { + builder.appendLeaderChangeMessage(timestamp, leaderChangeMessage); + } + } + + public static MemoryRecords withSnapshotHeaderRecord( + long initialOffset, + long timestamp, + int leaderEpoch, + ByteBuffer buffer, + SnapshotHeaderRecord snapshotHeaderRecord + ) { + writeSnapshotHeaderRecord(buffer, initialOffset, timestamp, leaderEpoch, snapshotHeaderRecord); + buffer.flip(); + return MemoryRecords.readableRecords(buffer); + } + + private static void writeSnapshotHeaderRecord(ByteBuffer buffer, + long initialOffset, + long timestamp, + int leaderEpoch, + SnapshotHeaderRecord snapshotHeaderRecord + ) { + try (MemoryRecordsBuilder builder = new MemoryRecordsBuilder( + buffer, RecordBatch.CURRENT_MAGIC_VALUE, CompressionType.NONE, + TimestampType.CREATE_TIME, initialOffset, timestamp, + RecordBatch.NO_PRODUCER_ID, RecordBatch.NO_PRODUCER_EPOCH, RecordBatch.NO_SEQUENCE, + false, true, leaderEpoch, buffer.capacity()) + ) { + builder.appendSnapshotHeaderMessage(timestamp, snapshotHeaderRecord); + } + } + + public static MemoryRecords withSnapshotFooterRecord( + long initialOffset, + long timestamp, + int leaderEpoch, + ByteBuffer buffer, + SnapshotFooterRecord snapshotFooterRecord + ) { + writeSnapshotFooterRecord(buffer, initialOffset, timestamp, leaderEpoch, snapshotFooterRecord); + buffer.flip(); + return MemoryRecords.readableRecords(buffer); + } + + private static void writeSnapshotFooterRecord(ByteBuffer buffer, + long initialOffset, + long timestamp, + int leaderEpoch, + SnapshotFooterRecord snapshotFooterRecord + ) { + try (MemoryRecordsBuilder builder = new MemoryRecordsBuilder( + buffer, RecordBatch.CURRENT_MAGIC_VALUE, CompressionType.NONE, + TimestampType.CREATE_TIME, initialOffset, timestamp, + RecordBatch.NO_PRODUCER_ID, RecordBatch.NO_PRODUCER_EPOCH, RecordBatch.NO_SEQUENCE, + false, true, leaderEpoch, buffer.capacity()) + ) { + builder.appendSnapshotFooterMessage(timestamp, snapshotFooterRecord); + } + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/record/MemoryRecordsBuilder.java b/clients/src/main/java/org/apache/kafka/common/record/MemoryRecordsBuilder.java new file mode 100644 index 0000000..b825a93 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/record/MemoryRecordsBuilder.java @@ -0,0 +1,881 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.record; + +import org.apache.kafka.common.KafkaException; +import org.apache.kafka.common.header.Header; +import org.apache.kafka.common.message.LeaderChangeMessage; +import org.apache.kafka.common.message.SnapshotHeaderRecord; +import org.apache.kafka.common.message.SnapshotFooterRecord; +import org.apache.kafka.common.protocol.MessageUtil; +import org.apache.kafka.common.protocol.types.Struct; +import org.apache.kafka.common.utils.ByteBufferOutputStream; +import org.apache.kafka.common.utils.Utils; + +import java.io.DataOutputStream; +import java.io.IOException; +import java.io.OutputStream; +import java.nio.ByteBuffer; + +import static org.apache.kafka.common.utils.Utils.wrapNullable; + +/** + * This class is used to write new log data in memory, i.e. this is the write path for {@link MemoryRecords}. + * It transparently handles compression and exposes methods for appending new records, possibly with message + * format conversion. + * + * In cases where keeping memory retention low is important and there's a gap between the time that record appends stop + * and the builder is closed (e.g. the Producer), it's important to call `closeForRecordAppends` when the former happens. + * This will release resources like compression buffers that can be relatively large (64 KB for LZ4). + */ +public class MemoryRecordsBuilder implements AutoCloseable { + private static final float COMPRESSION_RATE_ESTIMATION_FACTOR = 1.05f; + private static final DataOutputStream CLOSED_STREAM = new DataOutputStream(new OutputStream() { + @Override + public void write(int b) { + throw new IllegalStateException("MemoryRecordsBuilder is closed for record appends"); + } + }); + + private final TimestampType timestampType; + private final CompressionType compressionType; + // Used to hold a reference to the underlying ByteBuffer so that we can write the record batch header and access + // the written bytes. ByteBufferOutputStream allocates a new ByteBuffer if the existing one is not large enough, + // so it's not safe to hold a direct reference to the underlying ByteBuffer. + private final ByteBufferOutputStream bufferStream; + private final byte magic; + private final int initialPosition; + private final long baseOffset; + private final long logAppendTime; + private final boolean isControlBatch; + private final int partitionLeaderEpoch; + private final int writeLimit; + private final int batchHeaderSizeInBytes; + + // Use a conservative estimate of the compression ratio. The producer overrides this using statistics + // from previous batches before appending any records. 
+ private float estimatedCompressionRatio = 1.0F; + + // Used to append records, may compress data on the fly + private DataOutputStream appendStream; + private boolean isTransactional; + private long producerId; + private short producerEpoch; + private int baseSequence; + private int uncompressedRecordsSizeInBytes = 0; // Number of bytes (excluding the header) written before compression + private int numRecords = 0; + private float actualCompressionRatio = 1; + private long maxTimestamp = RecordBatch.NO_TIMESTAMP; + private long deleteHorizonMs; + private long offsetOfMaxTimestamp = -1; + private Long lastOffset = null; + private Long baseTimestamp = null; + + private MemoryRecords builtRecords; + private boolean aborted = false; + + public MemoryRecordsBuilder(ByteBufferOutputStream bufferStream, + byte magic, + CompressionType compressionType, + TimestampType timestampType, + long baseOffset, + long logAppendTime, + long producerId, + short producerEpoch, + int baseSequence, + boolean isTransactional, + boolean isControlBatch, + int partitionLeaderEpoch, + int writeLimit, + long deleteHorizonMs) { + if (magic > RecordBatch.MAGIC_VALUE_V0 && timestampType == TimestampType.NO_TIMESTAMP_TYPE) + throw new IllegalArgumentException("TimestampType must be set for magic >= 0"); + if (magic < RecordBatch.MAGIC_VALUE_V2) { + if (isTransactional) + throw new IllegalArgumentException("Transactional records are not supported for magic " + magic); + if (isControlBatch) + throw new IllegalArgumentException("Control records are not supported for magic " + magic); + if (compressionType == CompressionType.ZSTD) + throw new IllegalArgumentException("ZStandard compression is not supported for magic " + magic); + if (deleteHorizonMs != RecordBatch.NO_TIMESTAMP) + throw new IllegalArgumentException("Delete horizon timestamp is not supported for magic " + magic); + } + + this.magic = magic; + this.timestampType = timestampType; + this.compressionType = compressionType; + this.baseOffset = baseOffset; + this.logAppendTime = logAppendTime; + this.numRecords = 0; + this.uncompressedRecordsSizeInBytes = 0; + this.actualCompressionRatio = 1; + this.maxTimestamp = RecordBatch.NO_TIMESTAMP; + this.producerId = producerId; + this.producerEpoch = producerEpoch; + this.baseSequence = baseSequence; + this.isTransactional = isTransactional; + this.isControlBatch = isControlBatch; + this.deleteHorizonMs = deleteHorizonMs; + this.partitionLeaderEpoch = partitionLeaderEpoch; + this.writeLimit = writeLimit; + this.initialPosition = bufferStream.position(); + this.batchHeaderSizeInBytes = AbstractRecords.recordBatchHeaderSizeInBytes(magic, compressionType); + + bufferStream.position(initialPosition + batchHeaderSizeInBytes); + this.bufferStream = bufferStream; + this.appendStream = new DataOutputStream(compressionType.wrapForOutput(this.bufferStream, magic)); + + if (hasDeleteHorizonMs()) { + this.baseTimestamp = deleteHorizonMs; + } + } + + public MemoryRecordsBuilder(ByteBufferOutputStream bufferStream, + byte magic, + CompressionType compressionType, + TimestampType timestampType, + long baseOffset, + long logAppendTime, + long producerId, + short producerEpoch, + int baseSequence, + boolean isTransactional, + boolean isControlBatch, + int partitionLeaderEpoch, + int writeLimit) { + this(bufferStream, magic, compressionType, timestampType, baseOffset, logAppendTime, producerId, + producerEpoch, baseSequence, isTransactional, isControlBatch, partitionLeaderEpoch, writeLimit, + RecordBatch.NO_TIMESTAMP); + } + + /** + * 
Construct a new builder. + * + * @param buffer The underlying buffer to use (note that this class will allocate a new buffer if necessary + * to fit the records appended) + * @param magic The magic value to use + * @param compressionType The compression codec to use + * @param timestampType The desired timestamp type. For magic > 0, this cannot be {@link TimestampType#NO_TIMESTAMP_TYPE}. + * @param baseOffset The initial offset to use for + * @param logAppendTime The log append time of this record set. Can be set to NO_TIMESTAMP if CREATE_TIME is used. + * @param producerId The producer ID associated with the producer writing this record set + * @param producerEpoch The epoch of the producer + * @param baseSequence The sequence number of the first record in this set + * @param isTransactional Whether or not the records are part of a transaction + * @param isControlBatch Whether or not this is a control batch (e.g. for transaction markers) + * @param partitionLeaderEpoch The epoch of the partition leader appending the record set to the log + * @param writeLimit The desired limit on the total bytes for this record set (note that this can be exceeded + * when compression is used since size estimates are rough, and in the case that the first + * record added exceeds the size). + */ + public MemoryRecordsBuilder(ByteBuffer buffer, + byte magic, + CompressionType compressionType, + TimestampType timestampType, + long baseOffset, + long logAppendTime, + long producerId, + short producerEpoch, + int baseSequence, + boolean isTransactional, + boolean isControlBatch, + int partitionLeaderEpoch, + int writeLimit) { + this(new ByteBufferOutputStream(buffer), magic, compressionType, timestampType, baseOffset, logAppendTime, + producerId, producerEpoch, baseSequence, isTransactional, isControlBatch, partitionLeaderEpoch, + writeLimit); + } + + public ByteBuffer buffer() { + return bufferStream.buffer(); + } + + public int initialCapacity() { + return bufferStream.initialCapacity(); + } + + public double compressionRatio() { + return actualCompressionRatio; + } + + public CompressionType compressionType() { + return compressionType; + } + + public boolean isControlBatch() { + return isControlBatch; + } + + public boolean isTransactional() { + return isTransactional; + } + + public boolean hasDeleteHorizonMs() { + return magic >= RecordBatch.MAGIC_VALUE_V2 && deleteHorizonMs >= 0L; + } + + /** + * Close this builder and return the resulting buffer. + * @return The built log buffer + */ + public MemoryRecords build() { + if (aborted) { + throw new IllegalStateException("Attempting to build an aborted record batch"); + } + close(); + return builtRecords; + } + + /** + * Get the max timestamp and its offset. The details of the offset returned are a bit subtle. + * + * If the log append time is used, the offset will be the last offset unless no compression is used and + * the message format version is 0 or 1, in which case, it will be the first offset. + * + * If create time is used, the offset will be the last offset unless no compression is used and the message + * format version is 0 or 1, in which case, it will be the offset of the record with the max timestamp. 
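+ *
+ * <p>For illustration only (not part of this patch), a minimal sketch of how a caller might read this
+ * information; the buffer size and record contents below are arbitrary assumptions:
+ * <pre>{@code
+ * MemoryRecordsBuilder builder = MemoryRecords.builder(
+ *     ByteBuffer.allocate(1024), CompressionType.NONE, TimestampType.CREATE_TIME, 0L);
+ * builder.append(100L, "a".getBytes(), "1".getBytes());
+ * builder.append(200L, "b".getBytes(), "2".getBytes());
+ * builder.build();
+ * RecordsInfo info = builder.info();
+ * // info.maxTimestamp == 200L; with create time, magic v2 and no compression the
+ * // shallowOffsetOfMaxTimestamp reported here is the last offset of the batch (1)
+ * }</pre>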
+ * + * @return The max timestamp and its offset + */ + public RecordsInfo info() { + if (timestampType == TimestampType.LOG_APPEND_TIME) { + long shallowOffsetOfMaxTimestamp; + // Use the last offset when dealing with record batches + if (compressionType != CompressionType.NONE || magic >= RecordBatch.MAGIC_VALUE_V2) + shallowOffsetOfMaxTimestamp = lastOffset; + else + shallowOffsetOfMaxTimestamp = baseOffset; + return new RecordsInfo(logAppendTime, shallowOffsetOfMaxTimestamp); + } else if (maxTimestamp == RecordBatch.NO_TIMESTAMP) { + return new RecordsInfo(RecordBatch.NO_TIMESTAMP, lastOffset); + } else { + long shallowOffsetOfMaxTimestamp; + // Use the last offset when dealing with record batches + if (compressionType != CompressionType.NONE || magic >= RecordBatch.MAGIC_VALUE_V2) + shallowOffsetOfMaxTimestamp = lastOffset; + else + shallowOffsetOfMaxTimestamp = offsetOfMaxTimestamp; + return new RecordsInfo(maxTimestamp, shallowOffsetOfMaxTimestamp); + } + } + + public int numRecords() { + return numRecords; + } + + /** + * Return the sum of the size of the batch header (always uncompressed) and the records (before compression). + */ + public int uncompressedBytesWritten() { + return uncompressedRecordsSizeInBytes + batchHeaderSizeInBytes; + } + + public void setProducerState(long producerId, short producerEpoch, int baseSequence, boolean isTransactional) { + if (isClosed()) { + // Sequence numbers are assigned when the batch is closed while the accumulator is being drained. + // If the resulting ProduceRequest to the partition leader failed for a retriable error, the batch will + // be re queued. In this case, we should not attempt to set the state again, since changing the producerId and sequence + // once a batch has been sent to the broker risks introducing duplicates. + throw new IllegalStateException("Trying to set producer state of an already closed batch. This indicates a bug on the client."); + } + this.producerId = producerId; + this.producerEpoch = producerEpoch; + this.baseSequence = baseSequence; + this.isTransactional = isTransactional; + } + + public void overrideLastOffset(long lastOffset) { + if (builtRecords != null) + throw new IllegalStateException("Cannot override the last offset after the records have been built"); + this.lastOffset = lastOffset; + } + + /** + * Release resources required for record appends (e.g. compression buffers). Once this method is called, it's only + * possible to update the RecordBatch header. 
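+ *
+ * <p>Illustrative sketch (not part of this patch) of the producer-style usage described above; the
+ * buffer size, records and timing are assumptions:
+ * <pre>{@code
+ * MemoryRecordsBuilder builder = MemoryRecords.builder(
+ *     ByteBuffer.allocate(16384), CompressionType.LZ4, TimestampType.CREATE_TIME, 0L);
+ * builder.append(System.currentTimeMillis(), "key".getBytes(), "value".getBytes());
+ * // no more appends will happen, but the batch is not drained yet: release the large LZ4 buffer now
+ * builder.closeForRecordAppends();
+ * // ... later, when the batch is actually sent ...
+ * MemoryRecords records = builder.build();
+ * }</pre>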
+ */ + public void closeForRecordAppends() { + if (appendStream != CLOSED_STREAM) { + try { + appendStream.close(); + } catch (IOException e) { + throw new KafkaException(e); + } finally { + appendStream = CLOSED_STREAM; + } + } + } + + public void abort() { + closeForRecordAppends(); + buffer().position(initialPosition); + aborted = true; + } + + public void reopenAndRewriteProducerState(long producerId, short producerEpoch, int baseSequence, boolean isTransactional) { + if (aborted) + throw new IllegalStateException("Should not reopen a batch which is already aborted."); + builtRecords = null; + this.producerId = producerId; + this.producerEpoch = producerEpoch; + this.baseSequence = baseSequence; + this.isTransactional = isTransactional; + } + + + public void close() { + if (aborted) + throw new IllegalStateException("Cannot close MemoryRecordsBuilder as it has already been aborted"); + + if (builtRecords != null) + return; + + validateProducerState(); + + closeForRecordAppends(); + + if (numRecords == 0L) { + buffer().position(initialPosition); + builtRecords = MemoryRecords.EMPTY; + } else { + if (magic > RecordBatch.MAGIC_VALUE_V1) + this.actualCompressionRatio = (float) writeDefaultBatchHeader() / this.uncompressedRecordsSizeInBytes; + else if (compressionType != CompressionType.NONE) + this.actualCompressionRatio = (float) writeLegacyCompressedWrapperHeader() / this.uncompressedRecordsSizeInBytes; + + ByteBuffer buffer = buffer().duplicate(); + buffer.flip(); + buffer.position(initialPosition); + builtRecords = MemoryRecords.readableRecords(buffer.slice()); + } + } + + private void validateProducerState() { + if (isTransactional && producerId == RecordBatch.NO_PRODUCER_ID) + throw new IllegalArgumentException("Cannot write transactional messages without a valid producer ID"); + + if (producerId != RecordBatch.NO_PRODUCER_ID) { + if (producerEpoch == RecordBatch.NO_PRODUCER_EPOCH) + throw new IllegalArgumentException("Invalid negative producer epoch"); + + if (baseSequence < 0 && !isControlBatch) + throw new IllegalArgumentException("Invalid negative sequence number used"); + + if (magic < RecordBatch.MAGIC_VALUE_V2) + throw new IllegalArgumentException("Idempotent messages are not supported for magic " + magic); + } + } + + /** + * Write the header to the default batch. + * @return the written compressed bytes. + */ + private int writeDefaultBatchHeader() { + ensureOpenForRecordBatchWrite(); + ByteBuffer buffer = bufferStream.buffer(); + int pos = buffer.position(); + buffer.position(initialPosition); + int size = pos - initialPosition; + int writtenCompressed = size - DefaultRecordBatch.RECORD_BATCH_OVERHEAD; + int offsetDelta = (int) (lastOffset - baseOffset); + + final long maxTimestamp; + if (timestampType == TimestampType.LOG_APPEND_TIME) + maxTimestamp = logAppendTime; + else + maxTimestamp = this.maxTimestamp; + + DefaultRecordBatch.writeHeader(buffer, baseOffset, offsetDelta, size, magic, compressionType, timestampType, + baseTimestamp, maxTimestamp, producerId, producerEpoch, baseSequence, isTransactional, isControlBatch, + hasDeleteHorizonMs(), partitionLeaderEpoch, numRecords); + + buffer.position(pos); + return writtenCompressed; + } + + /** + * Write the header to the legacy batch. + * @return the written compressed bytes. 
+ */ + private int writeLegacyCompressedWrapperHeader() { + ensureOpenForRecordBatchWrite(); + ByteBuffer buffer = bufferStream.buffer(); + int pos = buffer.position(); + buffer.position(initialPosition); + + int wrapperSize = pos - initialPosition - Records.LOG_OVERHEAD; + int writtenCompressed = wrapperSize - LegacyRecord.recordOverhead(magic); + AbstractLegacyRecordBatch.writeHeader(buffer, lastOffset, wrapperSize); + + long timestamp = timestampType == TimestampType.LOG_APPEND_TIME ? logAppendTime : maxTimestamp; + LegacyRecord.writeCompressedRecordHeader(buffer, magic, wrapperSize, timestamp, compressionType, timestampType); + + buffer.position(pos); + return writtenCompressed; + } + + /** + * Append a new record at the given offset. + */ + private void appendWithOffset(long offset, boolean isControlRecord, long timestamp, ByteBuffer key, + ByteBuffer value, Header[] headers) { + try { + if (isControlRecord != isControlBatch) + throw new IllegalArgumentException("Control records can only be appended to control batches"); + + if (lastOffset != null && offset <= lastOffset) + throw new IllegalArgumentException(String.format("Illegal offset %s following previous offset %s " + + "(Offsets must increase monotonically).", offset, lastOffset)); + + if (timestamp < 0 && timestamp != RecordBatch.NO_TIMESTAMP) + throw new IllegalArgumentException("Invalid negative timestamp " + timestamp); + + if (magic < RecordBatch.MAGIC_VALUE_V2 && headers != null && headers.length > 0) + throw new IllegalArgumentException("Magic v" + magic + " does not support record headers"); + + if (baseTimestamp == null) + baseTimestamp = timestamp; + + if (magic > RecordBatch.MAGIC_VALUE_V1) { + appendDefaultRecord(offset, timestamp, key, value, headers); + } else { + appendLegacyRecord(offset, timestamp, key, value, magic); + } + } catch (IOException e) { + throw new KafkaException("I/O exception when writing to the append stream, closing", e); + } + } + + /** + * Append a new record at the given offset. + * @param offset The absolute offset of the record in the log buffer + * @param timestamp The record timestamp + * @param key The record key + * @param value The record value + * @param headers The record headers if there are any + */ + public void appendWithOffset(long offset, long timestamp, byte[] key, byte[] value, Header[] headers) { + appendWithOffset(offset, false, timestamp, wrapNullable(key), wrapNullable(value), headers); + } + + /** + * Append a new record at the given offset. + * @param offset The absolute offset of the record in the log buffer + * @param timestamp The record timestamp + * @param key The record key + * @param value The record value + * @param headers The record headers if there are any + */ + public void appendWithOffset(long offset, long timestamp, ByteBuffer key, ByteBuffer value, Header[] headers) { + appendWithOffset(offset, false, timestamp, key, value, headers); + } + + /** + * Append a new record at the given offset. + * @param offset The absolute offset of the record in the log buffer + * @param timestamp The record timestamp + * @param key The record key + * @param value The record value + */ + public void appendWithOffset(long offset, long timestamp, byte[] key, byte[] value) { + appendWithOffset(offset, timestamp, wrapNullable(key), wrapNullable(value), Record.EMPTY_HEADERS); + } + + /** + * Append a new record at the given offset. 
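+ *
+ * <p>A hedged usage sketch (not part of this patch); offsets and payloads below are arbitrary. Offsets
+ * are stored as deltas from the builder's base offset and must increase monotonically:
+ * <pre>{@code
+ * MemoryRecordsBuilder builder = MemoryRecords.builder(
+ *     ByteBuffer.allocate(1024), CompressionType.NONE, TimestampType.CREATE_TIME, 42L);
+ * long now = System.currentTimeMillis();
+ * builder.appendWithOffset(42L, now, "k1".getBytes(), "v1".getBytes());
+ * builder.appendWithOffset(43L, now, "k2".getBytes(), "v2".getBytes());
+ * MemoryRecords records = builder.build();
+ * }</pre>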
+ * @param offset The absolute offset of the record in the log buffer + * @param timestamp The record timestamp + * @param key The record key + * @param value The record value + */ + public void appendWithOffset(long offset, long timestamp, ByteBuffer key, ByteBuffer value) { + appendWithOffset(offset, timestamp, key, value, Record.EMPTY_HEADERS); + } + + /** + * Append a new record at the given offset. + * @param offset The absolute offset of the record in the log buffer + * @param record The record to append + */ + public void appendWithOffset(long offset, SimpleRecord record) { + appendWithOffset(offset, record.timestamp(), record.key(), record.value(), record.headers()); + } + + /** + * Append a control record at the given offset. The control record type must be known or + * this method will raise an error. + * + * @param offset The absolute offset of the record in the log buffer + * @param record The record to append + */ + public void appendControlRecordWithOffset(long offset, SimpleRecord record) { + short typeId = ControlRecordType.parseTypeId(record.key()); + ControlRecordType type = ControlRecordType.fromTypeId(typeId); + if (type == ControlRecordType.UNKNOWN) + throw new IllegalArgumentException("Cannot append record with unknown control record type " + typeId); + + appendWithOffset(offset, true, record.timestamp(), + record.key(), record.value(), record.headers()); + } + + /** + * Append a new record at the next sequential offset. + * @param timestamp The record timestamp + * @param key The record key + * @param value The record value + */ + public void append(long timestamp, ByteBuffer key, ByteBuffer value) { + append(timestamp, key, value, Record.EMPTY_HEADERS); + } + + /** + * Append a new record at the next sequential offset. + * @param timestamp The record timestamp + * @param key The record key + * @param value The record value + * @param headers The record headers if there are any + * @return CRC of the record or null if record-level CRC is not supported for the message format + */ + public void append(long timestamp, ByteBuffer key, ByteBuffer value, Header[] headers) { + appendWithOffset(nextSequentialOffset(), timestamp, key, value, headers); + } + + /** + * Append a new record at the next sequential offset. + * @param timestamp The record timestamp + * @param key The record key + * @param value The record value + * @return CRC of the record or null if record-level CRC is not supported for the message format + */ + public void append(long timestamp, byte[] key, byte[] value) { + append(timestamp, wrapNullable(key), wrapNullable(value), Record.EMPTY_HEADERS); + } + + /** + * Append a new record at the next sequential offset. + * @param timestamp The record timestamp + * @param key The record key + * @param value The record value + * @param headers The record headers if there are any + */ + public void append(long timestamp, byte[] key, byte[] value, Header[] headers) { + append(timestamp, wrapNullable(key), wrapNullable(value), headers); + } + + /** + * Append a new record at the next sequential offset. + * @param record The record to append + */ + public void append(SimpleRecord record) { + appendWithOffset(nextSequentialOffset(), record); + } + + /** + * Append a control record at the next sequential offset. 
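+ *
+ * <p>Control records are normally produced through the {@link MemoryRecords} factories rather than by
+ * calling this method directly; an illustrative sketch (not part of this patch, the identifiers below
+ * are assumed caller state):
+ * <pre>{@code
+ * EndTransactionMarker marker = new EndTransactionMarker(ControlRecordType.COMMIT, coordinatorEpoch);
+ * MemoryRecords commitMarker = MemoryRecords.withEndTransactionMarker(producerId, producerEpoch, marker);
+ * }</pre>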
+ * @param timestamp The record timestamp + * @param type The control record type (cannot be UNKNOWN) + * @param value The control record value + */ + private void appendControlRecord(long timestamp, ControlRecordType type, ByteBuffer value) { + Struct keyStruct = type.recordKey(); + ByteBuffer key = ByteBuffer.allocate(keyStruct.sizeOf()); + keyStruct.writeTo(key); + key.flip(); + appendWithOffset(nextSequentialOffset(), true, timestamp, key, value, Record.EMPTY_HEADERS); + } + + public void appendEndTxnMarker(long timestamp, EndTransactionMarker marker) { + if (producerId == RecordBatch.NO_PRODUCER_ID) + throw new IllegalArgumentException("End transaction marker requires a valid producerId"); + if (!isTransactional) + throw new IllegalArgumentException("End transaction marker depends on batch transactional flag being enabled"); + ByteBuffer value = marker.serializeValue(); + appendControlRecord(timestamp, marker.controlType(), value); + } + + public void appendLeaderChangeMessage(long timestamp, LeaderChangeMessage leaderChangeMessage) { + if (partitionLeaderEpoch == RecordBatch.NO_PARTITION_LEADER_EPOCH) { + throw new IllegalArgumentException("Partition leader epoch must be valid, but get " + partitionLeaderEpoch); + } + appendControlRecord(timestamp, ControlRecordType.LEADER_CHANGE, + MessageUtil.toByteBuffer(leaderChangeMessage, ControlRecordUtils.LEADER_CHANGE_SCHEMA_HIGHEST_VERSION)); + } + + public void appendSnapshotHeaderMessage(long timestamp, SnapshotHeaderRecord snapshotHeaderRecord) { + appendControlRecord(timestamp, ControlRecordType.SNAPSHOT_HEADER, + MessageUtil.toByteBuffer(snapshotHeaderRecord, ControlRecordUtils.SNAPSHOT_HEADER_HIGHEST_VERSION)); + } + + public void appendSnapshotFooterMessage(long timestamp, SnapshotFooterRecord snapshotHeaderRecord) { + appendControlRecord(timestamp, ControlRecordType.SNAPSHOT_FOOTER, + MessageUtil.toByteBuffer(snapshotHeaderRecord, ControlRecordUtils.SNAPSHOT_FOOTER_HIGHEST_VERSION)); + } + + /** + * Add a legacy record without doing offset/magic validation (this should only be used in testing). + * @param offset The offset of the record + * @param record The record to add + */ + public void appendUncheckedWithOffset(long offset, LegacyRecord record) { + ensureOpenForRecordAppend(); + try { + int size = record.sizeInBytes(); + AbstractLegacyRecordBatch.writeHeader(appendStream, toInnerOffset(offset), size); + + ByteBuffer buffer = record.buffer().duplicate(); + appendStream.write(buffer.array(), buffer.arrayOffset(), buffer.limit()); + + recordWritten(offset, record.timestamp(), size + Records.LOG_OVERHEAD); + } catch (IOException e) { + throw new KafkaException("I/O exception when writing to the append stream, closing", e); + } + } + + /** + * Append a record without doing offset/magic validation (this should only be used in testing). 
+ * + * @param offset The offset of the record + * @param record The record to add + */ + public void appendUncheckedWithOffset(long offset, SimpleRecord record) throws IOException { + if (magic >= RecordBatch.MAGIC_VALUE_V2) { + int offsetDelta = (int) (offset - baseOffset); + long timestamp = record.timestamp(); + if (baseTimestamp == null) + baseTimestamp = timestamp; + + int sizeInBytes = DefaultRecord.writeTo(appendStream, + offsetDelta, + timestamp - baseTimestamp, + record.key(), + record.value(), + record.headers()); + recordWritten(offset, timestamp, sizeInBytes); + } else { + LegacyRecord legacyRecord = LegacyRecord.create(magic, + record.timestamp(), + Utils.toNullableArray(record.key()), + Utils.toNullableArray(record.value())); + appendUncheckedWithOffset(offset, legacyRecord); + } + } + + /** + * Append a record at the next sequential offset. + * @param record the record to add + */ + public void append(Record record) { + appendWithOffset(record.offset(), isControlBatch, record.timestamp(), record.key(), record.value(), record.headers()); + } + + /** + * Append a log record using a different offset + * @param offset The offset of the record + * @param record The record to add + */ + public void appendWithOffset(long offset, Record record) { + appendWithOffset(offset, record.timestamp(), record.key(), record.value(), record.headers()); + } + + /** + * Add a record with a given offset. The record must have a magic which matches the magic use to + * construct this builder and the offset must be greater than the last appended record. + * @param offset The offset of the record + * @param record The record to add + */ + public void appendWithOffset(long offset, LegacyRecord record) { + appendWithOffset(offset, record.timestamp(), record.key(), record.value()); + } + + /** + * Append the record at the next consecutive offset. If no records have been appended yet, use the base + * offset of this builder. 
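+ *
+ * <p>The same next-consecutive-offset behaviour applies to the {@link SimpleRecord} overloads; a short
+ * sketch for illustration (not part of this patch, payloads are arbitrary):
+ * <pre>{@code
+ * // base offset 5: the first append lands at offset 5, the second at offset 6
+ * MemoryRecordsBuilder builder = MemoryRecords.builder(
+ *     ByteBuffer.allocate(512), CompressionType.NONE, TimestampType.CREATE_TIME, 5L);
+ * builder.append(new SimpleRecord(System.currentTimeMillis(), "k1".getBytes(), "v1".getBytes()));
+ * builder.append(new SimpleRecord(System.currentTimeMillis(), "k2".getBytes(), "v2".getBytes()));
+ * }</pre>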
+ * @param record The record to add + */ + public void append(LegacyRecord record) { + appendWithOffset(nextSequentialOffset(), record); + } + + private void appendDefaultRecord(long offset, long timestamp, ByteBuffer key, ByteBuffer value, + Header[] headers) throws IOException { + ensureOpenForRecordAppend(); + int offsetDelta = (int) (offset - baseOffset); + long timestampDelta = timestamp - baseTimestamp; + int sizeInBytes = DefaultRecord.writeTo(appendStream, offsetDelta, timestampDelta, key, value, headers); + recordWritten(offset, timestamp, sizeInBytes); + } + + private long appendLegacyRecord(long offset, long timestamp, ByteBuffer key, ByteBuffer value, byte magic) throws IOException { + ensureOpenForRecordAppend(); + if (compressionType == CompressionType.NONE && timestampType == TimestampType.LOG_APPEND_TIME) + timestamp = logAppendTime; + + int size = LegacyRecord.recordSize(magic, key, value); + AbstractLegacyRecordBatch.writeHeader(appendStream, toInnerOffset(offset), size); + + if (timestampType == TimestampType.LOG_APPEND_TIME) + timestamp = logAppendTime; + long crc = LegacyRecord.write(appendStream, magic, timestamp, key, value, CompressionType.NONE, timestampType); + recordWritten(offset, timestamp, size + Records.LOG_OVERHEAD); + return crc; + } + + private long toInnerOffset(long offset) { + // use relative offsets for compressed messages with magic v1 + if (magic > 0 && compressionType != CompressionType.NONE) + return offset - baseOffset; + return offset; + } + + private void recordWritten(long offset, long timestamp, int size) { + if (numRecords == Integer.MAX_VALUE) + throw new IllegalArgumentException("Maximum number of records per batch exceeded, max records: " + Integer.MAX_VALUE); + if (offset - baseOffset > Integer.MAX_VALUE) + throw new IllegalArgumentException("Maximum offset delta exceeded, base offset: " + baseOffset + + ", last offset: " + offset); + + numRecords += 1; + uncompressedRecordsSizeInBytes += size; + lastOffset = offset; + + if (magic > RecordBatch.MAGIC_VALUE_V0 && timestamp > maxTimestamp) { + maxTimestamp = timestamp; + offsetOfMaxTimestamp = offset; + } + } + + private void ensureOpenForRecordAppend() { + if (appendStream == CLOSED_STREAM) + throw new IllegalStateException("Tried to append a record, but MemoryRecordsBuilder is closed for record appends"); + } + + private void ensureOpenForRecordBatchWrite() { + if (isClosed()) + throw new IllegalStateException("Tried to write record batch header, but MemoryRecordsBuilder is closed"); + if (aborted) + throw new IllegalStateException("Tried to write record batch header, but MemoryRecordsBuilder is aborted"); + } + + /** + * Get an estimate of the number of bytes written (based on the estimation factor hard-coded in {@link CompressionType}. + * @return The estimated number of bytes written + */ + private int estimatedBytesWritten() { + if (compressionType == CompressionType.NONE) { + return batchHeaderSizeInBytes + uncompressedRecordsSizeInBytes; + } else { + // estimate the written bytes to the underlying byte buffer based on uncompressed written bytes + return batchHeaderSizeInBytes + (int) (uncompressedRecordsSizeInBytes * estimatedCompressionRatio * COMPRESSION_RATE_ESTIMATION_FACTOR); + } + } + + /** + * Set the estimated compression ratio for the memory records builder. 
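+ *
+ * <p>The estimate feeds {@code estimatedBytesWritten()} and therefore {@code hasRoomFor(...)} and
+ * {@code isFull()}. As a worked example (illustrative numbers only): with 1000 uncompressed record
+ * bytes and an estimated ratio of 0.5, the estimated size is the batch header size plus
+ * 1000 * 0.5 * 1.05 = 525 bytes, where 1.05 is {@code COMPRESSION_RATE_ESTIMATION_FACTOR}.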
+ */ + public void setEstimatedCompressionRatio(float estimatedCompressionRatio) { + this.estimatedCompressionRatio = estimatedCompressionRatio; + } + + /** + * Check if we have room for a new record containing the given key/value pair. If no records have been + * appended, then this returns true. + */ + public boolean hasRoomFor(long timestamp, byte[] key, byte[] value, Header[] headers) { + return hasRoomFor(timestamp, wrapNullable(key), wrapNullable(value), headers); + } + + /** + * Check if we have room for a new record containing the given key/value pair. If no records have been + * appended, then this returns true. + * + * Note that the return value is based on the estimate of the bytes written to the compressor, which may not be + * accurate if compression is used. When this happens, the following append may cause dynamic buffer + * re-allocation in the underlying byte buffer stream. + */ + public boolean hasRoomFor(long timestamp, ByteBuffer key, ByteBuffer value, Header[] headers) { + if (isFull()) + return false; + + // We always allow at least one record to be appended (the ByteBufferOutputStream will grow as needed) + if (numRecords == 0) + return true; + + final int recordSize; + if (magic < RecordBatch.MAGIC_VALUE_V2) { + recordSize = Records.LOG_OVERHEAD + LegacyRecord.recordSize(magic, key, value); + } else { + int nextOffsetDelta = lastOffset == null ? 0 : (int) (lastOffset - baseOffset + 1); + long timestampDelta = baseTimestamp == null ? 0 : timestamp - baseTimestamp; + recordSize = DefaultRecord.sizeInBytes(nextOffsetDelta, timestampDelta, key, value, headers); + } + + // Be conservative and not take compression of the new record into consideration. + return this.writeLimit >= estimatedBytesWritten() + recordSize; + } + + public boolean isClosed() { + return builtRecords != null; + } + + public boolean isFull() { + // note that the write limit is respected only after the first record is added which ensures we can always + // create non-empty batches (this is used to disable batching when the producer's batch size is set to 0). + return appendStream == CLOSED_STREAM || (this.numRecords > 0 && this.writeLimit <= estimatedBytesWritten()); + } + + /** + * Get an estimate of the number of bytes written to the underlying buffer. The returned value + * is exactly correct if the record set is not compressed or if the builder has been closed. + */ + public int estimatedSizeInBytes() { + return builtRecords != null ? builtRecords.sizeInBytes() : estimatedBytesWritten(); + } + + public byte magic() { + return magic; + } + + private long nextSequentialOffset() { + return lastOffset == null ? baseOffset : lastOffset + 1; + } + + public static class RecordsInfo { + public final long maxTimestamp; + public final long shallowOffsetOfMaxTimestamp; + + public RecordsInfo(long maxTimestamp, + long shallowOffsetOfMaxTimestamp) { + this.maxTimestamp = maxTimestamp; + this.shallowOffsetOfMaxTimestamp = shallowOffsetOfMaxTimestamp; + } + } + + /** + * Return the producer id of the RecordBatches created by this builder. 
+ */ + public long producerId() { + return this.producerId; + } + + public short producerEpoch() { + return this.producerEpoch; + } + + public int baseSequence() { + return this.baseSequence; + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/record/MultiRecordsSend.java b/clients/src/main/java/org/apache/kafka/common/record/MultiRecordsSend.java new file mode 100644 index 0000000..22883b2 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/record/MultiRecordsSend.java @@ -0,0 +1,141 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.record; + +import org.apache.kafka.common.KafkaException; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.network.Send; +import org.apache.kafka.common.network.TransferableChannel; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; +import java.util.Queue; + +/** + * A set of composite sends with nested {@link RecordsSend}, sent one after another + */ +public class MultiRecordsSend implements Send { + private static final Logger log = LoggerFactory.getLogger(MultiRecordsSend.class); + + private final Queue sendQueue; + private final long size; + private Map recordConversionStats; + + private long totalWritten = 0; + private Send current; + + /** + * Construct a MultiRecordsSend from a queue of Send objects. The queue will be consumed as the MultiRecordsSend + * progresses (on completion, it will be empty). 
+ */ + public MultiRecordsSend(Queue sends) { + this.sendQueue = sends; + + long size = 0; + for (Send send : sends) + size += send.size(); + this.size = size; + + this.current = sendQueue.poll(); + } + + public MultiRecordsSend(Queue sends, long size) { + this.sendQueue = sends; + this.size = size; + this.current = sendQueue.poll(); + } + + @Override + public long size() { + return size; + } + + @Override + public boolean completed() { + return current == null; + } + + // Visible for testing + int numResidentSends() { + int count = 0; + if (current != null) + count += 1; + count += sendQueue.size(); + return count; + } + + @Override + public long writeTo(TransferableChannel channel) throws IOException { + if (completed()) + throw new KafkaException("This operation cannot be invoked on a complete request."); + + int totalWrittenPerCall = 0; + boolean sendComplete; + do { + long written = current.writeTo(channel); + totalWrittenPerCall += written; + sendComplete = current.completed(); + if (sendComplete) { + updateRecordConversionStats(current); + current = sendQueue.poll(); + } + } while (!completed() && sendComplete); + + totalWritten += totalWrittenPerCall; + + if (completed() && totalWritten != size) + log.error("mismatch in sending bytes over socket; expected: {} actual: {}", size, totalWritten); + + log.trace("Bytes written as part of multi-send call: {}, total bytes written so far: {}, expected bytes to write: {}", + totalWrittenPerCall, totalWritten, size); + + return totalWrittenPerCall; + } + + /** + * Get any statistics that were recorded as part of executing this {@link MultiRecordsSend}. + * @return Records processing statistics (could be null if no statistics were collected) + */ + public Map recordConversionStats() { + return recordConversionStats; + } + + @Override + public String toString() { + return "MultiRecordsSend(" + + "size=" + size + + ", totalWritten=" + totalWritten + + ')'; + } + + private void updateRecordConversionStats(Send completedSend) { + // The underlying send might have accumulated statistics that need to be recorded. For example, + // LazyDownConversionRecordsSend accumulates statistics related to the number of bytes down-converted, the amount + // of temporary memory used for down-conversion, etc. Pull out any such statistics from the underlying send + // and fold it up appropriately. + if (completedSend instanceof LazyDownConversionRecordsSend) { + if (recordConversionStats == null) + recordConversionStats = new HashMap<>(); + LazyDownConversionRecordsSend lazyRecordsSend = (LazyDownConversionRecordsSend) completedSend; + recordConversionStats.put(lazyRecordsSend.topicPartition(), lazyRecordsSend.recordConversionStats()); + } + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/record/MutableRecordBatch.java b/clients/src/main/java/org/apache/kafka/common/record/MutableRecordBatch.java new file mode 100644 index 0000000..fc924b0 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/record/MutableRecordBatch.java @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.record; + +import org.apache.kafka.common.utils.BufferSupplier; +import org.apache.kafka.common.utils.ByteBufferOutputStream; +import org.apache.kafka.common.utils.CloseableIterator; + +/** + * A mutable record batch is one that can be modified in place (without copying). This is used by the broker + * to override certain fields in the batch before appending it to the log. + */ +public interface MutableRecordBatch extends RecordBatch { + + /** + * Set the last offset of this batch. + * @param offset The last offset to use + */ + void setLastOffset(long offset); + + /** + * Set the max timestamp for this batch. When using log append time, this effectively overrides the individual + * timestamps of all the records contained in the batch. To avoid recompression, the record fields are not updated + * by this method, but clients ignore them if the timestamp time is log append time. Note that firstTimestamp is not + * updated by this method. + * + * This typically requires re-computation of the batch's CRC. + * + * @param timestampType The timestamp type + * @param maxTimestamp The maximum timestamp + */ + void setMaxTimestamp(TimestampType timestampType, long maxTimestamp); + + /** + * Set the partition leader epoch for this batch of records. + * @param epoch The partition leader epoch to use + */ + void setPartitionLeaderEpoch(int epoch); + + /** + * Write this record batch into an output stream. + * @param outputStream The buffer to write the batch to + */ + void writeTo(ByteBufferOutputStream outputStream); + + /** + * Return an iterator which skips parsing key, value and headers from the record stream, and therefore the resulted + * {@code org.apache.kafka.common.record.Record}'s key and value fields would be empty. This iterator is used + * when the read record's key and value are not needed and hence can save some byte buffer allocating / GC overhead. + * + * @return The closeable iterator + */ + CloseableIterator skipKeyValueIterator(BufferSupplier bufferSupplier); +} diff --git a/clients/src/main/java/org/apache/kafka/common/record/PartialDefaultRecord.java b/clients/src/main/java/org/apache/kafka/common/record/PartialDefaultRecord.java new file mode 100644 index 0000000..67ca1ab --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/record/PartialDefaultRecord.java @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.record; + +import org.apache.kafka.common.header.Header; + +import java.nio.ByteBuffer; + +public class PartialDefaultRecord extends DefaultRecord { + + private final int keySize; + private final int valueSize; + + PartialDefaultRecord(int sizeInBytes, + byte attributes, + long offset, + long timestamp, + int sequence, + int keySize, + int valueSize) { + super(sizeInBytes, attributes, offset, timestamp, sequence, null, null, null); + + this.keySize = keySize; + this.valueSize = valueSize; + } + + @Override + public boolean equals(Object o) { + return super.equals(o) && + this.keySize == ((PartialDefaultRecord) o).keySize && + this.valueSize == ((PartialDefaultRecord) o).valueSize; + } + + @Override + public int hashCode() { + int result = super.hashCode(); + result = 31 * result + keySize; + result = 31 * result + valueSize; + return result; + } + + @Override + public String toString() { + return String.format("PartialDefaultRecord(offset=%d, timestamp=%d, key=%d bytes, value=%d bytes)", + offset(), + timestamp(), + keySize, + valueSize); + } + + @Override + public int keySize() { + return keySize; + } + + @Override + public boolean hasKey() { + return keySize >= 0; + } + + @Override + public ByteBuffer key() { + throw new UnsupportedOperationException("key is skipped in PartialDefaultRecord"); + } + + @Override + public int valueSize() { + return valueSize; + } + + @Override + public boolean hasValue() { + return valueSize >= 0; + } + + @Override + public ByteBuffer value() { + throw new UnsupportedOperationException("value is skipped in PartialDefaultRecord"); + } + + @Override + public Header[] headers() { + throw new UnsupportedOperationException("headers is skipped in PartialDefaultRecord"); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/record/Record.java b/clients/src/main/java/org/apache/kafka/common/record/Record.java new file mode 100644 index 0000000..4a387f7 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/record/Record.java @@ -0,0 +1,129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.record; + +import java.nio.ByteBuffer; + +import org.apache.kafka.common.header.Header; + +/** + * A log record is a tuple consisting of a unique offset in the log, a sequence number assigned by + * the producer, a timestamp, a key and a value. + */ +public interface Record { + + Header[] EMPTY_HEADERS = new Header[0]; + + /** + * The offset of this record in the log + * @return the offset + */ + long offset(); + + /** + * Get the sequence number assigned by the producer. 
+ * @return the sequence number + */ + int sequence(); + + /** + * Get the size in bytes of this record. + * @return the size of the record in bytes + */ + int sizeInBytes(); + + /** + * Get the record's timestamp. + * @return the record's timestamp + */ + long timestamp(); + + /** + * Raise a {@link org.apache.kafka.common.errors.CorruptRecordException} if the record does not have a valid checksum. + */ + void ensureValid(); + + /** + * Get the size in bytes of the key. + * @return the size of the key, or -1 if there is no key + */ + int keySize(); + + /** + * Check whether this record has a key + * @return true if there is a key, false otherwise + */ + boolean hasKey(); + + /** + * Get the record's key. + * @return the key or null if there is none + */ + ByteBuffer key(); + + /** + * Get the size in bytes of the value. + * @return the size of the value, or -1 if the value is null + */ + int valueSize(); + + /** + * Check whether a value is present (i.e. if the value is not null) + * @return true if so, false otherwise + */ + boolean hasValue(); + + /** + * Get the record's value + * @return the (nullable) value + */ + ByteBuffer value(); + + /** + * Check whether the record has a particular magic. For versions prior to 2, the record contains its own magic, + * so this function can be used to check whether it matches a particular value. For version 2 and above, this + * method returns true if the passed magic is greater than or equal to 2. + * + * @param magic the magic value to check + * @return true if the record has a magic field (versions prior to 2) and the value matches + */ + boolean hasMagic(byte magic); + + /** + * For versions prior to 2, check whether the record is compressed (and therefore + * has nested record content). For versions 2 and above, this always returns false. + * @return true if the magic is lower than 2 and the record is compressed + */ + boolean isCompressed(); + + /** + * For versions prior to 2, the record contained a timestamp type attribute. This method can be + * used to check whether the value of that attribute matches a particular timestamp type. For versions + * 2 and above, this will always be false. + * + * @param timestampType the timestamp type to compare + * @return true if the version is lower than 2 and the timestamp type matches + */ + boolean hasTimestampType(TimestampType timestampType); + + /** + * Get the headers. For magic versions 1 and below, this always returns an empty array. + * + * @return the array of headers + */ + Header[] headers(); +} diff --git a/clients/src/main/java/org/apache/kafka/common/record/RecordBatch.java b/clients/src/main/java/org/apache/kafka/common/record/RecordBatch.java new file mode 100644 index 0000000..7d231c1 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/record/RecordBatch.java @@ -0,0 +1,248 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.record; + +import org.apache.kafka.common.utils.BufferSupplier; +import org.apache.kafka.common.utils.CloseableIterator; + +import java.nio.ByteBuffer; +import java.util.Iterator; +import java.util.OptionalLong; + +/** + * A record batch is a container for records. In old versions of the record format (versions 0 and 1), + * a batch consisted always of a single record if no compression was enabled, but could contain + * many records otherwise. Newer versions (magic versions 2 and above) will generally contain many records + * regardless of compression. + */ +public interface RecordBatch extends Iterable { + + /** + * The "magic" values + */ + byte MAGIC_VALUE_V0 = 0; + byte MAGIC_VALUE_V1 = 1; + byte MAGIC_VALUE_V2 = 2; + + /** + * The current "magic" value + */ + byte CURRENT_MAGIC_VALUE = MAGIC_VALUE_V2; + + /** + * Timestamp value for records without a timestamp + */ + long NO_TIMESTAMP = -1L; + + /** + * Values used in the v2 record format by non-idempotent/non-transactional producers or when + * up-converting from an older format. + */ + long NO_PRODUCER_ID = -1L; + short NO_PRODUCER_EPOCH = -1; + int NO_SEQUENCE = -1; + + /** + * Used to indicate an unknown leader epoch, which will be the case when the record set is + * first created by the producer. + */ + int NO_PARTITION_LEADER_EPOCH = -1; + + /** + * Check whether the checksum of this batch is correct. + * + * @return true If so, false otherwise + */ + boolean isValid(); + + /** + * Raise an exception if the checksum is not valid. + */ + void ensureValid(); + + /** + * Get the checksum of this record batch, which covers the batch header as well as all of the records. + * + * @return The 4-byte unsigned checksum represented as a long + */ + long checksum(); + + /** + * Get the max timestamp or log append time of this record batch. + * + * If the timestamp type is create time, this is the max timestamp among all records contained in this batch and + * the value is updated during compaction. + * + * @return The max timestamp + */ + long maxTimestamp(); + + /** + * Get the timestamp type of this record batch. This will be {@link TimestampType#NO_TIMESTAMP_TYPE} + * if the batch has magic 0. + * + * @return The timestamp type + */ + TimestampType timestampType(); + + /** + * Get the base offset contained in this record batch. For magic version prior to 2, the base offset will + * always be the offset of the first message in the batch. This generally requires deep iteration and will + * return the offset of the first record in the record batch. For magic version 2 and above, this will return + * the first offset of the original record batch (i.e. prior to compaction). For non-compacted topics, the + * behavior is equivalent. + * + * Because this requires deep iteration for older magic versions, this method should be used with + * caution. Generally {@link #lastOffset()} is safer since access is efficient for all magic versions. + * + * @return The base offset of this record batch (which may or may not be the offset of the first record + * as described above). + */ + long baseOffset(); + + /** + * Get the last offset in this record batch (inclusive). Just like {@link #baseOffset()}, the last offset + * always reflects the offset of the last record in the original batch, even if it is removed during log + * compaction. 
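+ * <p>For example (an illustrative sketch), together with {@link #baseOffset()} this gives the offset span
+ * covered by the batch; note that after compaction the span may exceed the number of records still present:
+ * <pre>{@code
+ * long offsetSpan = batch.lastOffset() - batch.baseOffset() + 1;
+ * }</pre>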
+ * + * @return The offset of the last record in this batch + */ + long lastOffset(); + + /** + * Get the offset following this record batch (i.e. the last offset contained in this batch plus one). + * + * @return the next consecutive offset following this batch + */ + long nextOffset(); + + /** + * Get the record format version of this record batch (i.e its magic value). + * + * @return the magic byte + */ + byte magic(); + + /** + * Get the producer id for this log record batch. For older magic versions, this will return -1. + * + * @return The producer id or -1 if there is none + */ + long producerId(); + + /** + * Get the producer epoch for this log record batch. + * + * @return The producer epoch, or -1 if there is none + */ + short producerEpoch(); + + /** + * Does the batch have a valid producer id set. + */ + boolean hasProducerId(); + + /** + * Get the base sequence number of this record batch. Like {@link #baseOffset()}, this value is not + * affected by compaction: it always retains the base sequence number from the original batch. + * + * @return The first sequence number or -1 if there is none + */ + int baseSequence(); + + /** + * Get the last sequence number of this record batch. Like {@link #lastOffset()}, the last sequence number + * always reflects the sequence number of the last record in the original batch, even if it is removed during log + * compaction. + * + * @return The last sequence number or -1 if there is none + */ + int lastSequence(); + + /** + * Get the compression type of this record batch. + * + * @return The compression type + */ + CompressionType compressionType(); + + /** + * Get the size in bytes of this batch, including the size of the record and the batch overhead. + * @return The size in bytes of this batch + */ + int sizeInBytes(); + + /** + * Get the count if it is efficiently supported by the record format (which is only the case + * for magic 2 and higher). + * + * @return The number of records in the batch or null for magic versions 0 and 1. + */ + Integer countOrNull(); + + /** + * Check whether this record batch is compressed. + * @return true if so, false otherwise + */ + boolean isCompressed(); + + /** + * Write this record batch into a buffer. + * @param buffer The buffer to write the batch to + */ + void writeTo(ByteBuffer buffer); + + /** + * Whether or not this record batch is part of a transaction. + * @return true if it is, false otherwise + */ + boolean isTransactional(); + + /** + * Get the delete horizon, returns OptionalLong.EMPTY if the first timestamp is not the delete horizon + * @return timestamp of the delete horizon + */ + OptionalLong deleteHorizonMs(); + + /** + * Get the partition leader epoch of this record batch. + * @return The leader epoch or -1 if it is unknown + */ + int partitionLeaderEpoch(); + + /** + * Return a streaming iterator which basically delays decompression of the record stream until the records + * are actually asked for using {@link Iterator#next()}. If the message format does not support streaming + * iteration, then the normal iterator is returned. Either way, callers should ensure that the iterator is closed. + * + * @param decompressionBufferSupplier The supplier of ByteBuffer(s) used for decompression if supported. + * For small record batches, allocating a potentially large buffer (64 KB for LZ4) + * will dominate the cost of decompressing and iterating over the records in the + * batch. As such, a supplier that reuses buffers will have a significant + * performance impact. 
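+ * <p>An illustrative usage sketch (assuming the caller owns the iterator): use try-with-resources so any
+ * pooled decompression buffers are released promptly:
+ * <pre>{@code
+ * try (CloseableIterator<Record> iterator = batch.streamingIterator(BufferSupplier.create())) {
+ *     while (iterator.hasNext()) {
+ *         Record record = iterator.next();
+ *         // process the record
+ *     }
+ * }
+ * }</pre>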
+ * @return The closeable iterator + */ + CloseableIterator streamingIterator(BufferSupplier decompressionBufferSupplier); + + /** + * Check whether this is a control batch (i.e. whether the control bit is set in the batch attributes). + * For magic versions prior to 2, this is always false. + * + * @return Whether this is a batch containing control records + */ + boolean isControlBatch(); +} diff --git a/clients/src/main/java/org/apache/kafka/common/record/RecordBatchIterator.java b/clients/src/main/java/org/apache/kafka/common/record/RecordBatchIterator.java new file mode 100644 index 0000000..88af039 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/record/RecordBatchIterator.java @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.record; + +import org.apache.kafka.common.KafkaException; +import org.apache.kafka.common.utils.AbstractIterator; + +import java.io.IOException; + +class RecordBatchIterator extends AbstractIterator { + + private final LogInputStream logInputStream; + + RecordBatchIterator(LogInputStream logInputStream) { + this.logInputStream = logInputStream; + } + + @Override + protected T makeNext() { + try { + T batch = logInputStream.nextBatch(); + if (batch == null) + return allDone(); + return batch; + } catch (IOException e) { + throw new KafkaException(e); + } + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/record/RecordConversionStats.java b/clients/src/main/java/org/apache/kafka/common/record/RecordConversionStats.java new file mode 100644 index 0000000..4f0bca5 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/record/RecordConversionStats.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.record; + +public class RecordConversionStats { + + public static final RecordConversionStats EMPTY = new RecordConversionStats(); + + private long temporaryMemoryBytes; + private int numRecordsConverted; + private long conversionTimeNanos; + + public RecordConversionStats(long temporaryMemoryBytes, int numRecordsConverted, long conversionTimeNanos) { + this.temporaryMemoryBytes = temporaryMemoryBytes; + this.numRecordsConverted = numRecordsConverted; + this.conversionTimeNanos = conversionTimeNanos; + } + + public RecordConversionStats() { + this(0, 0, 0); + } + + public void add(RecordConversionStats stats) { + temporaryMemoryBytes += stats.temporaryMemoryBytes; + numRecordsConverted += stats.numRecordsConverted; + conversionTimeNanos += stats.conversionTimeNanos; + } + + /** + * Returns the number of temporary memory bytes allocated to process the records. + * This size depends on whether the records need decompression and/or conversion: + *
+ * <ul>
+ *   <li>Non compressed, no conversion: zero</li>
+ *   <li>Non compressed, with conversion: size of the converted buffer</li>
+ *   <li>Compressed, no conversion: size of the original buffer after decompression</li>
+ *   <li>Compressed, with conversion: size of the original buffer after decompression + size of the converted buffer uncompressed</li>
+ * </ul>
        + */ + public long temporaryMemoryBytes() { + return temporaryMemoryBytes; + } + + public int numRecordsConverted() { + return numRecordsConverted; + } + + public long conversionTimeNanos() { + return conversionTimeNanos; + } + + @Override + public String toString() { + return String.format("RecordConversionStats(temporaryMemoryBytes=%d, numRecordsConverted=%d, conversionTimeNanos=%d)", + temporaryMemoryBytes, numRecordsConverted, conversionTimeNanos); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/record/RecordVersion.java b/clients/src/main/java/org/apache/kafka/common/record/RecordVersion.java new file mode 100644 index 0000000..8406d53 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/record/RecordVersion.java @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.record; + +/** + * Defines the record format versions supported by Kafka. + * + * For historical reasons, the record format version is also known as `magic` and `message format version`. Note that + * the version actually applies to the {@link RecordBatch} (instead of the {@link Record}). Finally, the + * `message.format.version` topic config confusingly expects an ApiVersion instead of a RecordVersion. + */ +public enum RecordVersion { + V0(0), V1(1), V2(2); + + private static final RecordVersion[] VALUES = values(); + + public final byte value; + + RecordVersion(int value) { + this.value = (byte) value; + } + + /** + * Check whether this version precedes another version. + * + * @return true only if the magic value is less than the other's + */ + public boolean precedes(RecordVersion other) { + return this.value < other.value; + } + + public static RecordVersion lookup(byte value) { + if (value < 0 || value >= VALUES.length) + throw new IllegalArgumentException("Unknown record version: " + value); + return VALUES[value]; + } + + public static RecordVersion current() { + return V2; + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/record/Records.java b/clients/src/main/java/org/apache/kafka/common/record/Records.java new file mode 100644 index 0000000..2179c7c --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/record/Records.java @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.record; + +import org.apache.kafka.common.utils.AbstractIterator; +import org.apache.kafka.common.utils.Time; + +import java.util.Iterator; + + +/** + * Interface for accessing the records contained in a log. The log itself is represented as a sequence of record + * batches (see {@link RecordBatch}). + * + * For magic versions 1 and below, each batch consists of an 8 byte offset, a 4 byte record size, and a "shallow" + * {@link Record record}. If the batch is not compressed, then each batch will have only the shallow record contained + * inside it. If it is compressed, the batch contains "deep" records, which are packed into the value field of the + * shallow record. To iterate over the shallow batches, use {@link Records#batches()}; for the deep records, use + * {@link Records#records()}. Note that the deep iterator handles both compressed and non-compressed batches: + * if the batch is not compressed, the shallow record is returned; otherwise, the shallow batch is decompressed and the + * deep records are returned. + * + * For magic version 2, every batch contains 1 or more log record, regardless of compression. You can iterate + * over the batches directly using {@link Records#batches()}. Records can be iterated either directly from an individual + * batch or through {@link Records#records()}. Just as in previous versions, iterating over the records typically involves + * decompression and should therefore be used with caution. + * + * See {@link MemoryRecords} for the in-memory representation and {@link FileRecords} for the on-disk representation. + */ +public interface Records extends TransferableRecords { + int OFFSET_OFFSET = 0; + int OFFSET_LENGTH = 8; + int SIZE_OFFSET = OFFSET_OFFSET + OFFSET_LENGTH; + int SIZE_LENGTH = 4; + int LOG_OVERHEAD = SIZE_OFFSET + SIZE_LENGTH; + + // the magic offset is at the same offset for all current message formats, but the 4 bytes + // between the size and the magic is dependent on the version. + int MAGIC_OFFSET = 16; + int MAGIC_LENGTH = 1; + int HEADER_SIZE_UP_TO_MAGIC = MAGIC_OFFSET + MAGIC_LENGTH; + + /** + * Get the record batches. Note that the signature allows subclasses + * to return a more specific batch type. This enables optimizations such as in-place offset + * assignment (see for example {@link DefaultRecordBatch}), and partial reading of + * record data (see {@link FileLogInputStream.FileChannelRecordBatch#magic()}. + * @return An iterator over the record batches of the log + */ + Iterable batches(); + + /** + * Get an iterator over the record batches. This is similar to {@link #batches()} but returns an {@link AbstractIterator} + * instead of {@link Iterator}, so that clients can use methods like {@link AbstractIterator#peek() peek}. + * @return An iterator over the record batches of the log + */ + AbstractIterator batchIterator(); + + /** + * Check whether all batches in this buffer have a certain magic value. 
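+ * <p>A hedged broker-style sketch (the variable names are assumptions beyond this interface): check whether
+ * every batch already carries the magic an older client understands before paying for down-conversion:
+ * <pre>{@code
+ * if (!records.hasMatchingMagic(RecordBatch.MAGIC_VALUE_V1)) {
+ *     ConvertedRecords<? extends Records> converted =
+ *         records.downConvert(RecordBatch.MAGIC_VALUE_V1, fetchStartOffset, Time.SYSTEM);
+ *     // respond with converted.records() instead of the original records
+ * }
+ * }</pre>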
+ * @param magic The magic value to check + * @return true if all record batches have a matching magic value, false otherwise + */ + boolean hasMatchingMagic(byte magic); + + /** + * Convert all batches in this buffer to the format passed as a parameter. Note that this requires + * deep iteration since all of the deep records must also be converted to the desired format. + * @param toMagic The magic value to convert to + * @param firstOffset The starting offset for returned records. This only impacts some cases. See + * {@link RecordsUtil#downConvert(Iterable, byte, long, Time)} for an explanation. + * @param time instance used for reporting stats + * @return A ConvertedRecords instance which may or may not contain the same instance in its records field. + */ + ConvertedRecords downConvert(byte toMagic, long firstOffset, Time time); + + /** + * Get an iterator over the records in this log. Note that this generally requires decompression, + * and should therefore be used with care. + * @return The record iterator + */ + Iterable records(); +} diff --git a/clients/src/main/java/org/apache/kafka/common/record/RecordsSend.java b/clients/src/main/java/org/apache/kafka/common/record/RecordsSend.java new file mode 100644 index 0000000..b582ec2 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/record/RecordsSend.java @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.record; + +import org.apache.kafka.common.network.Send; +import org.apache.kafka.common.network.TransferableChannel; + +import java.io.EOFException; +import java.io.IOException; +import java.nio.ByteBuffer; + +public abstract class RecordsSend implements Send { + private static final ByteBuffer EMPTY_BYTE_BUFFER = ByteBuffer.allocate(0); + + private final T records; + private final int maxBytesToWrite; + private int remaining; + private boolean pending = false; + + protected RecordsSend(T records, int maxBytesToWrite) { + this.records = records; + this.maxBytesToWrite = maxBytesToWrite; + this.remaining = maxBytesToWrite; + } + + @Override + public boolean completed() { + return remaining <= 0 && !pending; + } + + @Override + public final long writeTo(TransferableChannel channel) throws IOException { + long written = 0; + + if (remaining > 0) { + written = writeTo(channel, size() - remaining, remaining); + if (written < 0) + throw new EOFException("Wrote negative bytes to channel. 
This shouldn't happen."); + remaining -= written; + } + + pending = channel.hasPendingWrites(); + if (remaining <= 0 && pending) + channel.write(EMPTY_BYTE_BUFFER); + + return written; + } + + @Override + public long size() { + return maxBytesToWrite; + } + + protected T records() { + return records; + } + + /** + * Write records up to `remaining` bytes to `channel`. The implementation is allowed to be stateful. The contract + * from the caller is that the first invocation will be with `previouslyWritten` equal to 0, and `remaining` equal to + * the to maximum bytes we want to write the to `channel`. `previouslyWritten` and `remaining` will be adjusted + * appropriately for every subsequent invocation. See {@link #writeTo} for example expected usage. + * @param channel The channel to write to + * @param previouslyWritten Bytes written in previous calls to {@link #writeTo(TransferableChannel, long, int)}; 0 if being called for the first time + * @param remaining Number of bytes remaining to be written + * @return The number of bytes actually written + * @throws IOException For any IO errors + */ + protected abstract long writeTo(TransferableChannel channel, long previouslyWritten, int remaining) throws IOException; +} diff --git a/clients/src/main/java/org/apache/kafka/common/record/RecordsUtil.java b/clients/src/main/java/org/apache/kafka/common/record/RecordsUtil.java new file mode 100644 index 0000000..423d1e1 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/record/RecordsUtil.java @@ -0,0 +1,140 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.record; + +import org.apache.kafka.common.errors.UnsupportedCompressionTypeException; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.common.utils.Utils; + +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.List; + +public class RecordsUtil { + /** + * Down convert batches to the provided message format version. The first offset parameter is only relevant in the + * conversion from uncompressed v2 or higher to v1 or lower. The reason is that uncompressed records in v0 and v1 + * are not batched (put another way, each batch always has 1 record). + * + * If a client requests records in v1 format starting from the middle of an uncompressed batch in v2 format, we + * need to drop records from the batch during the conversion. Some versions of librdkafka rely on this for + * correctness. + * + * The temporaryMemoryBytes computation assumes that the batches are not loaded into the heap + * (via classes like FileChannelRecordBatch) before this method is called. This is the case in the broker (we + * only load records into the heap when down converting), but it's not for the producer. 
However, down converting + * in the producer is very uncommon and the extra complexity to handle that case is not worth it. + */ + protected static ConvertedRecords downConvert(Iterable batches, byte toMagic, + long firstOffset, Time time) { + // maintain the batch along with the decompressed records to avoid the need to decompress again + List recordBatchAndRecordsList = new ArrayList<>(); + int totalSizeEstimate = 0; + long startNanos = time.nanoseconds(); + + for (RecordBatch batch : batches) { + if (toMagic < RecordBatch.MAGIC_VALUE_V2) { + if (batch.isControlBatch()) + continue; + + if (batch.compressionType() == CompressionType.ZSTD) + throw new UnsupportedCompressionTypeException("Down-conversion of zstandard-compressed batches " + + "is not supported"); + } + + if (batch.magic() <= toMagic) { + totalSizeEstimate += batch.sizeInBytes(); + recordBatchAndRecordsList.add(new RecordBatchAndRecords(batch, null, null)); + } else { + List records = new ArrayList<>(); + for (Record record : batch) { + // See the method javadoc for an explanation + if (toMagic > RecordBatch.MAGIC_VALUE_V1 || batch.isCompressed() || record.offset() >= firstOffset) + records.add(record); + } + if (records.isEmpty()) + continue; + final long baseOffset; + if (batch.magic() >= RecordBatch.MAGIC_VALUE_V2 && toMagic >= RecordBatch.MAGIC_VALUE_V2) + baseOffset = batch.baseOffset(); + else + baseOffset = records.get(0).offset(); + totalSizeEstimate += AbstractRecords.estimateSizeInBytes(toMagic, baseOffset, batch.compressionType(), records); + recordBatchAndRecordsList.add(new RecordBatchAndRecords(batch, records, baseOffset)); + } + } + + ByteBuffer buffer = ByteBuffer.allocate(totalSizeEstimate); + long temporaryMemoryBytes = 0; + int numRecordsConverted = 0; + + for (RecordBatchAndRecords recordBatchAndRecords : recordBatchAndRecordsList) { + temporaryMemoryBytes += recordBatchAndRecords.batch.sizeInBytes(); + if (recordBatchAndRecords.batch.magic() <= toMagic) { + buffer = Utils.ensureCapacity(buffer, buffer.position() + recordBatchAndRecords.batch.sizeInBytes()); + recordBatchAndRecords.batch.writeTo(buffer); + } else { + MemoryRecordsBuilder builder = convertRecordBatch(toMagic, buffer, recordBatchAndRecords); + buffer = builder.buffer(); + temporaryMemoryBytes += builder.uncompressedBytesWritten(); + numRecordsConverted += builder.numRecords(); + } + } + + buffer.flip(); + RecordConversionStats stats = new RecordConversionStats(temporaryMemoryBytes, numRecordsConverted, + time.nanoseconds() - startNanos); + return new ConvertedRecords<>(MemoryRecords.readableRecords(buffer), stats); + } + + /** + * Return a buffer containing the converted record batches. The returned buffer may not be the same as the received + * one (e.g. it may require expansion). + */ + private static MemoryRecordsBuilder convertRecordBatch(byte magic, ByteBuffer buffer, RecordBatchAndRecords recordBatchAndRecords) { + RecordBatch batch = recordBatchAndRecords.batch; + final TimestampType timestampType = batch.timestampType(); + long logAppendTime = timestampType == TimestampType.LOG_APPEND_TIME ? batch.maxTimestamp() : RecordBatch.NO_TIMESTAMP; + + MemoryRecordsBuilder builder = MemoryRecords.builder(buffer, magic, batch.compressionType(), + timestampType, recordBatchAndRecords.baseOffset, logAppendTime); + for (Record record : recordBatchAndRecords.records) { + // Down-convert this record. 
Ignore headers when down-converting to V0 and V1 since they are not supported + if (magic > RecordBatch.MAGIC_VALUE_V1) + builder.append(record); + else + builder.appendWithOffset(record.offset(), record.timestamp(), record.key(), record.value()); + } + + builder.close(); + return builder; + } + + + private static class RecordBatchAndRecords { + private final RecordBatch batch; + private final List records; + private final Long baseOffset; + + private RecordBatchAndRecords(RecordBatch batch, List records, Long baseOffset) { + this.batch = batch; + this.records = records; + this.baseOffset = baseOffset; + } + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/record/SimpleRecord.java b/clients/src/main/java/org/apache/kafka/common/record/SimpleRecord.java new file mode 100644 index 0000000..9b8c87f --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/record/SimpleRecord.java @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.record; + +import org.apache.kafka.common.header.Header; +import org.apache.kafka.common.utils.Utils; + +import java.nio.ByteBuffer; +import java.util.Arrays; +import java.util.Objects; + +/** + * High-level representation of a kafka record. This is useful when building record sets to + * avoid depending on a specific magic version. 
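+ * <p>For example (an illustrative sketch using the constructors declared below):
+ * <pre>{@code
+ * SimpleRecord keyed = new SimpleRecord(System.currentTimeMillis(), "key".getBytes(), "value".getBytes());
+ * SimpleRecord valueOnly = new SimpleRecord("value".getBytes()); // timestamp defaults to RecordBatch.NO_TIMESTAMP
+ * }</pre>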
+ */ +public class SimpleRecord { + private final ByteBuffer key; + private final ByteBuffer value; + private final long timestamp; + private final Header[] headers; + + public SimpleRecord(long timestamp, ByteBuffer key, ByteBuffer value, Header[] headers) { + Objects.requireNonNull(headers, "Headers must be non-null"); + this.key = key; + this.value = value; + this.timestamp = timestamp; + this.headers = headers; + } + + public SimpleRecord(long timestamp, byte[] key, byte[] value, Header[] headers) { + this(timestamp, Utils.wrapNullable(key), Utils.wrapNullable(value), headers); + } + + public SimpleRecord(long timestamp, ByteBuffer key, ByteBuffer value) { + this(timestamp, key, value, Record.EMPTY_HEADERS); + } + + public SimpleRecord(long timestamp, byte[] key, byte[] value) { + this(timestamp, Utils.wrapNullable(key), Utils.wrapNullable(value)); + } + + public SimpleRecord(long timestamp, byte[] value) { + this(timestamp, null, value); + } + + public SimpleRecord(byte[] value) { + this(RecordBatch.NO_TIMESTAMP, null, value); + } + + public SimpleRecord(ByteBuffer value) { + this(RecordBatch.NO_TIMESTAMP, null, value); + } + + public SimpleRecord(byte[] key, byte[] value) { + this(RecordBatch.NO_TIMESTAMP, key, value); + } + + public SimpleRecord(Record record) { + this(record.timestamp(), record.key(), record.value(), record.headers()); + } + + public ByteBuffer key() { + return key; + } + + public ByteBuffer value() { + return value; + } + + public long timestamp() { + return timestamp; + } + + public Header[] headers() { + return headers; + } + + @Override + public boolean equals(Object o) { + if (this == o) + return true; + if (o == null || getClass() != o.getClass()) + return false; + + SimpleRecord that = (SimpleRecord) o; + return timestamp == that.timestamp && + Objects.equals(key, that.key) && + Objects.equals(value, that.value) && + Arrays.equals(headers, that.headers); + } + + @Override + public int hashCode() { + int result = key != null ? key.hashCode() : 0; + result = 31 * result + (value != null ? value.hashCode() : 0); + result = 31 * result + Long.hashCode(timestamp); + result = 31 * result + Arrays.hashCode(headers); + return result; + } + + @Override + public String toString() { + return String.format("SimpleRecord(timestamp=%d, key=%d bytes, value=%d bytes)", + timestamp(), + key == null ? 0 : key.limit(), + value == null ? 0 : value.limit()); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/record/TimestampType.java b/clients/src/main/java/org/apache/kafka/common/record/TimestampType.java new file mode 100644 index 0000000..becde9d --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/record/TimestampType.java @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.record; + +import java.util.NoSuchElementException; + +/** + * The timestamp type of the records. + */ +public enum TimestampType { + NO_TIMESTAMP_TYPE(-1, "NoTimestampType"), CREATE_TIME(0, "CreateTime"), LOG_APPEND_TIME(1, "LogAppendTime"); + + public final int id; + public final String name; + + TimestampType(int id, String name) { + this.id = id; + this.name = name; + } + + public static TimestampType forName(String name) { + for (TimestampType t : values()) + if (t.name.equals(name)) + return t; + throw new NoSuchElementException("Invalid timestamp type " + name); + } + + @Override + public String toString() { + return name; + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/record/TransferableRecords.java b/clients/src/main/java/org/apache/kafka/common/record/TransferableRecords.java new file mode 100644 index 0000000..09c0304 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/record/TransferableRecords.java @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.record; + +import org.apache.kafka.common.network.TransferableChannel; + +import java.io.IOException; + +/** + * Represents a record set which can be transferred to a channel + * @see Records + * @see UnalignedRecords + */ +public interface TransferableRecords extends BaseRecords { + + /** + * Attempts to write the contents of this buffer to a channel. + * @param channel The channel to write to + * @param position The position in the buffer to write from + * @param length The number of bytes to write + * @return The number of bytes actually written + * @throws IOException For any IO errors + */ + long writeTo(TransferableChannel channel, long position, int length) throws IOException; +} diff --git a/clients/src/main/java/org/apache/kafka/common/record/UnalignedFileRecords.java b/clients/src/main/java/org/apache/kafka/common/record/UnalignedFileRecords.java new file mode 100644 index 0000000..96970f9 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/record/UnalignedFileRecords.java @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.record; + +import org.apache.kafka.common.network.TransferableChannel; + +import java.io.IOException; +import java.nio.channels.FileChannel; + +/** + * Represents a file record set which is not necessarily offset-aligned + */ +public class UnalignedFileRecords implements UnalignedRecords { + + private final FileChannel channel; + private final long position; + private final int size; + + public UnalignedFileRecords(FileChannel channel, long position, int size) { + this.channel = channel; + this.position = position; + this.size = size; + } + + @Override + public int sizeInBytes() { + return size; + } + + @Override + public long writeTo(TransferableChannel destChannel, long previouslyWritten, int remaining) throws IOException { + long position = this.position + previouslyWritten; + long count = Math.min(remaining, sizeInBytes() - previouslyWritten); + return destChannel.transferFrom(channel, position, count); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/record/UnalignedMemoryRecords.java b/clients/src/main/java/org/apache/kafka/common/record/UnalignedMemoryRecords.java new file mode 100644 index 0000000..23795e3 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/record/UnalignedMemoryRecords.java @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.record; + +import org.apache.kafka.common.network.TransferableChannel; +import org.apache.kafka.common.utils.Utils; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.Objects; + +/** + * Represents a memory record set which is not necessarily offset-aligned + */ +public class UnalignedMemoryRecords implements UnalignedRecords { + + private final ByteBuffer buffer; + + public UnalignedMemoryRecords(ByteBuffer buffer) { + this.buffer = Objects.requireNonNull(buffer); + } + + public ByteBuffer buffer() { + return buffer.duplicate(); + } + + @Override + public int sizeInBytes() { + return buffer.remaining(); + } + + @Override + public long writeTo(TransferableChannel channel, long position, int length) throws IOException { + if (position > Integer.MAX_VALUE) + throw new IllegalArgumentException("position should not be greater than Integer.MAX_VALUE: " + position); + if (position + length > buffer.limit()) + throw new IllegalArgumentException("position+length should not be greater than buffer.limit(), position: " + + position + ", length: " + length + ", buffer.limit(): " + buffer.limit()); + return Utils.tryWriteTo(channel, (int) position, length, buffer); + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/record/UnalignedRecords.java b/clients/src/main/java/org/apache/kafka/common/record/UnalignedRecords.java new file mode 100644 index 0000000..561d1af --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/record/UnalignedRecords.java @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.record; + +/** + * Represents a record set which is not necessarily offset-aligned, and is + * only used when fetching raft snapshot + */ +public interface UnalignedRecords extends TransferableRecords { + + @Override + default RecordsSend toSend() { + return new DefaultRecordsSend<>(this, sizeInBytes()); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/replica/ClientMetadata.java b/clients/src/main/java/org/apache/kafka/common/replica/ClientMetadata.java new file mode 100644 index 0000000..b328733 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/replica/ClientMetadata.java @@ -0,0 +1,124 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.replica; + +import org.apache.kafka.common.security.auth.KafkaPrincipal; + +import java.net.InetAddress; +import java.util.Objects; + +/** + * Holder for all the client metadata required to determine a preferred replica. + */ +public interface ClientMetadata { + + /** + * Rack ID sent by the client + */ + String rackId(); + + /** + * Client ID sent by the client + */ + String clientId(); + + /** + * Incoming address of the client + */ + InetAddress clientAddress(); + + /** + * Security principal of the client + */ + KafkaPrincipal principal(); + + /** + * Listener name for the client + */ + String listenerName(); + + + class DefaultClientMetadata implements ClientMetadata { + private final String rackId; + private final String clientId; + private final InetAddress clientAddress; + private final KafkaPrincipal principal; + private final String listenerName; + + public DefaultClientMetadata(String rackId, String clientId, InetAddress clientAddress, + KafkaPrincipal principal, String listenerName) { + this.rackId = rackId; + this.clientId = clientId; + this.clientAddress = clientAddress; + this.principal = principal; + this.listenerName = listenerName; + } + + @Override + public String rackId() { + return rackId; + } + + @Override + public String clientId() { + return clientId; + } + + @Override + public InetAddress clientAddress() { + return clientAddress; + } + + @Override + public KafkaPrincipal principal() { + return principal; + } + + @Override + public String listenerName() { + return listenerName; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + DefaultClientMetadata that = (DefaultClientMetadata) o; + return Objects.equals(rackId, that.rackId) && + Objects.equals(clientId, that.clientId) && + Objects.equals(clientAddress, that.clientAddress) && + Objects.equals(principal, that.principal) && + Objects.equals(listenerName, that.listenerName); + } + + @Override + public int hashCode() { + return Objects.hash(rackId, clientId, clientAddress, principal, listenerName); + } + + @Override + public String toString() { + return "DefaultClientMetadata{" + + "rackId='" + rackId + '\'' + + ", clientId='" + clientId + '\'' + + ", clientAddress=" + clientAddress + + ", principal=" + principal + + ", listenerName='" + listenerName + '\'' + + '}'; + } + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/replica/PartitionView.java b/clients/src/main/java/org/apache/kafka/common/replica/PartitionView.java new file mode 100644 index 0000000..8174e63 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/replica/PartitionView.java @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.replica; + +import java.util.Collections; +import java.util.Objects; +import java.util.Set; + +/** + * View of a partition used by {@link ReplicaSelector} to determine a preferred replica. + */ +public interface PartitionView { + Set replicas(); + + ReplicaView leader(); + + class DefaultPartitionView implements PartitionView { + private final Set replicas; + private final ReplicaView leader; + + public DefaultPartitionView(Set replicas, ReplicaView leader) { + this.replicas = Collections.unmodifiableSet(replicas); + this.leader = leader; + } + + @Override + public Set replicas() { + return replicas; + } + + @Override + public ReplicaView leader() { + return leader; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + DefaultPartitionView that = (DefaultPartitionView) o; + return Objects.equals(replicas, that.replicas) && + Objects.equals(leader, that.leader); + } + + @Override + public int hashCode() { + return Objects.hash(replicas, leader); + } + + @Override + public String toString() { + return "DefaultPartitionView{" + + "replicas=" + replicas + + ", leader=" + leader + + '}'; + } + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/replica/RackAwareReplicaSelector.java b/clients/src/main/java/org/apache/kafka/common/replica/RackAwareReplicaSelector.java new file mode 100644 index 0000000..8ae6872 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/replica/RackAwareReplicaSelector.java @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.replica; + +import org.apache.kafka.common.TopicPartition; + +import java.util.Optional; +import java.util.Set; +import java.util.stream.Collectors; + +/** + * Returns a replica whose rack id is equal to the rack id specified in the client request metadata. If no such replica + * is found, returns the leader. 
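+ * <p>A configuration sketch (these property names come from the follower-fetching feature rather than this file,
+ * and the rack value is only an example): the broker enables the selector and the consumer advertises its rack:
+ * <pre>{@code
+ * // broker configuration
+ * replica.selector.class=org.apache.kafka.common.replica.RackAwareReplicaSelector
+ *
+ * // consumer configuration
+ * client.rack=us-east-1a
+ * }</pre>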
+ */ +public class RackAwareReplicaSelector implements ReplicaSelector { + + @Override + public Optional select(TopicPartition topicPartition, + ClientMetadata clientMetadata, + PartitionView partitionView) { + if (clientMetadata.rackId() != null && !clientMetadata.rackId().isEmpty()) { + Set sameRackReplicas = partitionView.replicas().stream() + .filter(replicaInfo -> clientMetadata.rackId().equals(replicaInfo.endpoint().rack())) + .collect(Collectors.toSet()); + if (sameRackReplicas.isEmpty()) { + return Optional.of(partitionView.leader()); + } else { + if (sameRackReplicas.contains(partitionView.leader())) { + // Use the leader if it's in this rack + return Optional.of(partitionView.leader()); + } else { + // Otherwise, get the most caught-up replica + return sameRackReplicas.stream().max(ReplicaView.comparator()); + } + } + } else { + return Optional.of(partitionView.leader()); + } + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/replica/ReplicaSelector.java b/clients/src/main/java/org/apache/kafka/common/replica/ReplicaSelector.java new file mode 100644 index 0000000..301fc9f --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/replica/ReplicaSelector.java @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.replica; + +import org.apache.kafka.common.Configurable; +import org.apache.kafka.common.TopicPartition; + +import java.io.Closeable; +import java.io.IOException; +import java.util.Map; +import java.util.Optional; + +/** + * Plug-able interface for selecting a preferred read replica given the current set of replicas for a partition + * and metadata from the client. + */ +public interface ReplicaSelector extends Configurable, Closeable { + + /** + * Select the preferred replica a client should use for fetching. If no replica is available, this will return an + * empty optional. + */ + Optional select(TopicPartition topicPartition, + ClientMetadata clientMetadata, + PartitionView partitionView); + @Override + default void close() throws IOException { + // No-op by default + } + + @Override + default void configure(Map configs) { + // No-op by default + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/replica/ReplicaView.java b/clients/src/main/java/org/apache/kafka/common/replica/ReplicaView.java new file mode 100644 index 0000000..69c6cf7 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/replica/ReplicaView.java @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.replica; + +import org.apache.kafka.common.Node; + +import java.util.Comparator; +import java.util.Objects; + +/** + * View of a replica used by {@link ReplicaSelector} to determine a preferred replica. + */ +public interface ReplicaView { + + /** + * The endpoint information for this replica (hostname, port, rack, etc) + */ + Node endpoint(); + + /** + * The log end offset for this replica + */ + long logEndOffset(); + + /** + * The number of milliseconds (if any) since the last time this replica was caught up to the high watermark. + * For a leader replica, this is always zero. + */ + long timeSinceLastCaughtUpMs(); + + /** + * Comparator for ReplicaView that returns in the order of "most caught up". This is used for deterministic + * selection of a replica when there is a tie from a selector. + */ + static Comparator comparator() { + return Comparator.comparingLong(ReplicaView::logEndOffset) + .thenComparing(Comparator.comparingLong(ReplicaView::timeSinceLastCaughtUpMs).reversed()) + .thenComparing(replicaInfo -> replicaInfo.endpoint().id()); + } + + class DefaultReplicaView implements ReplicaView { + private final Node endpoint; + private final long logEndOffset; + private final long timeSinceLastCaughtUpMs; + + public DefaultReplicaView(Node endpoint, long logEndOffset, long timeSinceLastCaughtUpMs) { + this.endpoint = endpoint; + this.logEndOffset = logEndOffset; + this.timeSinceLastCaughtUpMs = timeSinceLastCaughtUpMs; + } + + @Override + public Node endpoint() { + return endpoint; + } + + @Override + public long logEndOffset() { + return logEndOffset; + } + + @Override + public long timeSinceLastCaughtUpMs() { + return timeSinceLastCaughtUpMs; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + DefaultReplicaView that = (DefaultReplicaView) o; + return logEndOffset == that.logEndOffset && + Objects.equals(endpoint, that.endpoint) && + Objects.equals(timeSinceLastCaughtUpMs, that.timeSinceLastCaughtUpMs); + } + + @Override + public int hashCode() { + return Objects.hash(endpoint, logEndOffset, timeSinceLastCaughtUpMs); + } + + @Override + public String toString() { + return "DefaultReplicaView{" + + "endpoint=" + endpoint + + ", logEndOffset=" + logEndOffset + + ", timeSinceLastCaughtUpMs=" + timeSinceLastCaughtUpMs + + '}'; + } + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/AbstractControlRequest.java b/clients/src/main/java/org/apache/kafka/common/requests/AbstractControlRequest.java new file mode 100644 index 0000000..dc4a1e2 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/AbstractControlRequest.java @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.protocol.ApiKeys; + +// Abstract class for all control requests including UpdateMetadataRequest, LeaderAndIsrRequest and StopReplicaRequest +public abstract class AbstractControlRequest extends AbstractRequest { + + public static final long UNKNOWN_BROKER_EPOCH = -1L; + + public static abstract class Builder extends AbstractRequest.Builder { + protected final int controllerId; + protected final int controllerEpoch; + protected final long brokerEpoch; + + protected Builder(ApiKeys api, short version, int controllerId, int controllerEpoch, long brokerEpoch) { + super(api, version); + this.controllerId = controllerId; + this.controllerEpoch = controllerEpoch; + this.brokerEpoch = brokerEpoch; + } + + } + + protected AbstractControlRequest(ApiKeys api, short version) { + super(api, version); + } + + public abstract int controllerId(); + + public abstract int controllerEpoch(); + + public abstract long brokerEpoch(); + +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/AbstractRequest.java b/clients/src/main/java/org/apache/kafka/common/requests/AbstractRequest.java new file mode 100644 index 0000000..0c38e99 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/AbstractRequest.java @@ -0,0 +1,311 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.errors.UnsupportedVersionException; +import org.apache.kafka.common.network.Send; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.Errors; +import org.apache.kafka.common.protocol.MessageUtil; +import org.apache.kafka.common.protocol.ObjectSerializationCache; +import org.apache.kafka.common.protocol.SendBuilder; + +import java.nio.ByteBuffer; +import java.util.Map; + +public abstract class AbstractRequest implements AbstractRequestResponse { + + public static abstract class Builder { + private final ApiKeys apiKey; + private final short oldestAllowedVersion; + private final short latestAllowedVersion; + + /** + * Construct a new builder which allows any supported version + */ + public Builder(ApiKeys apiKey) { + this(apiKey, apiKey.oldestVersion(), apiKey.latestVersion()); + } + + /** + * Construct a new builder which allows only a specific version + */ + public Builder(ApiKeys apiKey, short allowedVersion) { + this(apiKey, allowedVersion, allowedVersion); + } + + /** + * Construct a new builder which allows an inclusive range of versions + */ + public Builder(ApiKeys apiKey, short oldestAllowedVersion, short latestAllowedVersion) { + this.apiKey = apiKey; + this.oldestAllowedVersion = oldestAllowedVersion; + this.latestAllowedVersion = latestAllowedVersion; + } + + public ApiKeys apiKey() { + return apiKey; + } + + public short oldestAllowedVersion() { + return oldestAllowedVersion; + } + + public short latestAllowedVersion() { + return latestAllowedVersion; + } + + public T build() { + return build(latestAllowedVersion()); + } + + public abstract T build(short version); + } + + private final short version; + private final ApiKeys apiKey; + + public AbstractRequest(ApiKeys apiKey, short version) { + if (!apiKey.isVersionSupported(version)) + throw new UnsupportedVersionException("The " + apiKey + " protocol does not support version " + version); + this.version = version; + this.apiKey = apiKey; + } + + /** + * Get the version of this AbstractRequest object. + */ + public short version() { + return version; + } + + public ApiKeys apiKey() { + return apiKey; + } + + public final Send toSend(RequestHeader header) { + return SendBuilder.buildRequestSend(header, data()); + } + + /** + * Serializes header and body without prefixing with size (unlike `toSend`, which does include a size prefix). 
+ */ + public final ByteBuffer serializeWithHeader(RequestHeader header) { + if (header.apiKey() != apiKey) { + throw new IllegalArgumentException("Could not build request " + apiKey + " with header api key " + header.apiKey()); + } + if (header.apiVersion() != version) { + throw new IllegalArgumentException("Could not build request version " + version + " with header version " + header.apiVersion()); + } + return RequestUtils.serialize(header.data(), header.headerVersion(), data(), version); + } + + // Visible for testing + public final ByteBuffer serialize() { + return MessageUtil.toByteBuffer(data(), version); + } + + // Visible for testing + final int sizeInBytes() { + return data().size(new ObjectSerializationCache(), version); + } + + public String toString(boolean verbose) { + return data().toString(); + } + + @Override + public final String toString() { + return toString(true); + } + + /** + * Get an error response for a request + */ + public AbstractResponse getErrorResponse(Throwable e) { + return getErrorResponse(AbstractResponse.DEFAULT_THROTTLE_TIME, e); + } + + /** + * Get an error response for a request with specified throttle time in the response if applicable + */ + public abstract AbstractResponse getErrorResponse(int throttleTimeMs, Throwable e); + + /** + * Get the error counts corresponding to an error response. This is overridden for requests + * where response may be null (e.g produce with acks=0). + */ + public Map errorCounts(Throwable e) { + AbstractResponse response = getErrorResponse(0, e); + if (response == null) + throw new IllegalStateException("Error counts could not be obtained for request " + this); + else + return response.errorCounts(); + } + + /** + * Factory method for getting a request object based on ApiKey ID and a version + */ + public static RequestAndSize parseRequest(ApiKeys apiKey, short apiVersion, ByteBuffer buffer) { + int bufferSize = buffer.remaining(); + return new RequestAndSize(doParseRequest(apiKey, apiVersion, buffer), bufferSize); + } + + private static AbstractRequest doParseRequest(ApiKeys apiKey, short apiVersion, ByteBuffer buffer) { + switch (apiKey) { + case PRODUCE: + return ProduceRequest.parse(buffer, apiVersion); + case FETCH: + return FetchRequest.parse(buffer, apiVersion); + case LIST_OFFSETS: + return ListOffsetsRequest.parse(buffer, apiVersion); + case METADATA: + return MetadataRequest.parse(buffer, apiVersion); + case OFFSET_COMMIT: + return OffsetCommitRequest.parse(buffer, apiVersion); + case OFFSET_FETCH: + return OffsetFetchRequest.parse(buffer, apiVersion); + case FIND_COORDINATOR: + return FindCoordinatorRequest.parse(buffer, apiVersion); + case JOIN_GROUP: + return JoinGroupRequest.parse(buffer, apiVersion); + case HEARTBEAT: + return HeartbeatRequest.parse(buffer, apiVersion); + case LEAVE_GROUP: + return LeaveGroupRequest.parse(buffer, apiVersion); + case SYNC_GROUP: + return SyncGroupRequest.parse(buffer, apiVersion); + case STOP_REPLICA: + return StopReplicaRequest.parse(buffer, apiVersion); + case CONTROLLED_SHUTDOWN: + return ControlledShutdownRequest.parse(buffer, apiVersion); + case UPDATE_METADATA: + return UpdateMetadataRequest.parse(buffer, apiVersion); + case LEADER_AND_ISR: + return LeaderAndIsrRequest.parse(buffer, apiVersion); + case DESCRIBE_GROUPS: + return DescribeGroupsRequest.parse(buffer, apiVersion); + case LIST_GROUPS: + return ListGroupsRequest.parse(buffer, apiVersion); + case SASL_HANDSHAKE: + return SaslHandshakeRequest.parse(buffer, apiVersion); + case API_VERSIONS: + return 
ApiVersionsRequest.parse(buffer, apiVersion); + case CREATE_TOPICS: + return CreateTopicsRequest.parse(buffer, apiVersion); + case DELETE_TOPICS: + return DeleteTopicsRequest.parse(buffer, apiVersion); + case DELETE_RECORDS: + return DeleteRecordsRequest.parse(buffer, apiVersion); + case INIT_PRODUCER_ID: + return InitProducerIdRequest.parse(buffer, apiVersion); + case OFFSET_FOR_LEADER_EPOCH: + return OffsetsForLeaderEpochRequest.parse(buffer, apiVersion); + case ADD_PARTITIONS_TO_TXN: + return AddPartitionsToTxnRequest.parse(buffer, apiVersion); + case ADD_OFFSETS_TO_TXN: + return AddOffsetsToTxnRequest.parse(buffer, apiVersion); + case END_TXN: + return EndTxnRequest.parse(buffer, apiVersion); + case WRITE_TXN_MARKERS: + return WriteTxnMarkersRequest.parse(buffer, apiVersion); + case TXN_OFFSET_COMMIT: + return TxnOffsetCommitRequest.parse(buffer, apiVersion); + case DESCRIBE_ACLS: + return DescribeAclsRequest.parse(buffer, apiVersion); + case CREATE_ACLS: + return CreateAclsRequest.parse(buffer, apiVersion); + case DELETE_ACLS: + return DeleteAclsRequest.parse(buffer, apiVersion); + case DESCRIBE_CONFIGS: + return DescribeConfigsRequest.parse(buffer, apiVersion); + case ALTER_CONFIGS: + return AlterConfigsRequest.parse(buffer, apiVersion); + case ALTER_REPLICA_LOG_DIRS: + return AlterReplicaLogDirsRequest.parse(buffer, apiVersion); + case DESCRIBE_LOG_DIRS: + return DescribeLogDirsRequest.parse(buffer, apiVersion); + case SASL_AUTHENTICATE: + return SaslAuthenticateRequest.parse(buffer, apiVersion); + case CREATE_PARTITIONS: + return CreatePartitionsRequest.parse(buffer, apiVersion); + case CREATE_DELEGATION_TOKEN: + return CreateDelegationTokenRequest.parse(buffer, apiVersion); + case RENEW_DELEGATION_TOKEN: + return RenewDelegationTokenRequest.parse(buffer, apiVersion); + case EXPIRE_DELEGATION_TOKEN: + return ExpireDelegationTokenRequest.parse(buffer, apiVersion); + case DESCRIBE_DELEGATION_TOKEN: + return DescribeDelegationTokenRequest.parse(buffer, apiVersion); + case DELETE_GROUPS: + return DeleteGroupsRequest.parse(buffer, apiVersion); + case ELECT_LEADERS: + return ElectLeadersRequest.parse(buffer, apiVersion); + case INCREMENTAL_ALTER_CONFIGS: + return IncrementalAlterConfigsRequest.parse(buffer, apiVersion); + case ALTER_PARTITION_REASSIGNMENTS: + return AlterPartitionReassignmentsRequest.parse(buffer, apiVersion); + case LIST_PARTITION_REASSIGNMENTS: + return ListPartitionReassignmentsRequest.parse(buffer, apiVersion); + case OFFSET_DELETE: + return OffsetDeleteRequest.parse(buffer, apiVersion); + case DESCRIBE_CLIENT_QUOTAS: + return DescribeClientQuotasRequest.parse(buffer, apiVersion); + case ALTER_CLIENT_QUOTAS: + return AlterClientQuotasRequest.parse(buffer, apiVersion); + case DESCRIBE_USER_SCRAM_CREDENTIALS: + return DescribeUserScramCredentialsRequest.parse(buffer, apiVersion); + case ALTER_USER_SCRAM_CREDENTIALS: + return AlterUserScramCredentialsRequest.parse(buffer, apiVersion); + case VOTE: + return VoteRequest.parse(buffer, apiVersion); + case BEGIN_QUORUM_EPOCH: + return BeginQuorumEpochRequest.parse(buffer, apiVersion); + case END_QUORUM_EPOCH: + return EndQuorumEpochRequest.parse(buffer, apiVersion); + case DESCRIBE_QUORUM: + return DescribeQuorumRequest.parse(buffer, apiVersion); + case ALTER_ISR: + return AlterIsrRequest.parse(buffer, apiVersion); + case UPDATE_FEATURES: + return UpdateFeaturesRequest.parse(buffer, apiVersion); + case ENVELOPE: + return EnvelopeRequest.parse(buffer, apiVersion); + case FETCH_SNAPSHOT: + return 
FetchSnapshotRequest.parse(buffer, apiVersion); + case DESCRIBE_CLUSTER: + return DescribeClusterRequest.parse(buffer, apiVersion); + case DESCRIBE_PRODUCERS: + return DescribeProducersRequest.parse(buffer, apiVersion); + case BROKER_REGISTRATION: + return BrokerRegistrationRequest.parse(buffer, apiVersion); + case BROKER_HEARTBEAT: + return BrokerHeartbeatRequest.parse(buffer, apiVersion); + case UNREGISTER_BROKER: + return UnregisterBrokerRequest.parse(buffer, apiVersion); + case DESCRIBE_TRANSACTIONS: + return DescribeTransactionsRequest.parse(buffer, apiVersion); + case LIST_TRANSACTIONS: + return ListTransactionsRequest.parse(buffer, apiVersion); + case ALLOCATE_PRODUCER_IDS: + return AllocateProducerIdsRequest.parse(buffer, apiVersion); + default: + throw new AssertionError(String.format("ApiKey %s is not currently handled in `parseRequest`, the " + + "code should be updated to do so.", apiKey)); + } + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/AbstractRequestResponse.java b/clients/src/main/java/org/apache/kafka/common/requests/AbstractRequestResponse.java new file mode 100644 index 0000000..0f4b180 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/AbstractRequestResponse.java @@ -0,0 +1,24 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.protocol.ApiMessage; + +public interface AbstractRequestResponse { + + ApiMessage data(); +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/AbstractResponse.java b/clients/src/main/java/org/apache/kafka/common/requests/AbstractResponse.java new file mode 100644 index 0000000..47f2b3c --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/AbstractResponse.java @@ -0,0 +1,274 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.network.Send; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.Errors; +import org.apache.kafka.common.protocol.MessageUtil; +import org.apache.kafka.common.protocol.SendBuilder; + +import java.nio.ByteBuffer; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +public abstract class AbstractResponse implements AbstractRequestResponse { + public static final int DEFAULT_THROTTLE_TIME = 0; + + private final ApiKeys apiKey; + + protected AbstractResponse(ApiKeys apiKey) { + this.apiKey = apiKey; + } + + public final Send toSend(ResponseHeader header, short version) { + return SendBuilder.buildResponseSend(header, data(), version); + } + + /** + * Serializes header and body without prefixing with size (unlike `toSend`, which does include a size prefix). + */ + final ByteBuffer serializeWithHeader(ResponseHeader header, short version) { + return RequestUtils.serialize(header.data(), header.headerVersion(), data(), version); + } + + // Visible for testing + final ByteBuffer serialize(short version) { + return MessageUtil.toByteBuffer(data(), version); + } + + /** + * The number of each type of error in the response, including {@link Errors#NONE} and top-level errors as well as + * more specifically scoped errors (such as topic or partition-level errors). + * @return A count of errors. + */ + public abstract Map errorCounts(); + + protected Map errorCounts(Errors error) { + return Collections.singletonMap(error, 1); + } + + protected Map errorCounts(Stream errors) { + return errors.collect(Collectors.groupingBy(e -> e, Collectors.summingInt(e -> 1))); + } + + protected Map errorCounts(Collection errors) { + Map errorCounts = new HashMap<>(); + for (Errors error : errors) + updateErrorCounts(errorCounts, error); + return errorCounts; + } + + protected Map apiErrorCounts(Map errors) { + Map errorCounts = new HashMap<>(); + for (ApiError apiError : errors.values()) + updateErrorCounts(errorCounts, apiError.error()); + return errorCounts; + } + + protected void updateErrorCounts(Map errorCounts, Errors error) { + Integer count = errorCounts.getOrDefault(error, 0); + errorCounts.put(error, count + 1); + } + + /** + * Parse a response from the provided buffer. The buffer is expected to hold both + * the {@link ResponseHeader} as well as the response payload. 
+ */ + public static AbstractResponse parseResponse(ByteBuffer buffer, RequestHeader requestHeader) { + ApiKeys apiKey = requestHeader.apiKey(); + short apiVersion = requestHeader.apiVersion(); + + ResponseHeader responseHeader = ResponseHeader.parse(buffer, apiKey.responseHeaderVersion(apiVersion)); + + if (requestHeader.correlationId() != responseHeader.correlationId()) { + throw new CorrelationIdMismatchException("Correlation id for response (" + + responseHeader.correlationId() + ") does not match request (" + + requestHeader.correlationId() + "), request header: " + requestHeader, + requestHeader.correlationId(), responseHeader.correlationId()); + } + + return AbstractResponse.parseResponse(apiKey, buffer, apiVersion); + } + + public static AbstractResponse parseResponse(ApiKeys apiKey, ByteBuffer responseBuffer, short version) { + switch (apiKey) { + case PRODUCE: + return ProduceResponse.parse(responseBuffer, version); + case FETCH: + return FetchResponse.parse(responseBuffer, version); + case LIST_OFFSETS: + return ListOffsetsResponse.parse(responseBuffer, version); + case METADATA: + return MetadataResponse.parse(responseBuffer, version); + case OFFSET_COMMIT: + return OffsetCommitResponse.parse(responseBuffer, version); + case OFFSET_FETCH: + return OffsetFetchResponse.parse(responseBuffer, version); + case FIND_COORDINATOR: + return FindCoordinatorResponse.parse(responseBuffer, version); + case JOIN_GROUP: + return JoinGroupResponse.parse(responseBuffer, version); + case HEARTBEAT: + return HeartbeatResponse.parse(responseBuffer, version); + case LEAVE_GROUP: + return LeaveGroupResponse.parse(responseBuffer, version); + case SYNC_GROUP: + return SyncGroupResponse.parse(responseBuffer, version); + case STOP_REPLICA: + return StopReplicaResponse.parse(responseBuffer, version); + case CONTROLLED_SHUTDOWN: + return ControlledShutdownResponse.parse(responseBuffer, version); + case UPDATE_METADATA: + return UpdateMetadataResponse.parse(responseBuffer, version); + case LEADER_AND_ISR: + return LeaderAndIsrResponse.parse(responseBuffer, version); + case DESCRIBE_GROUPS: + return DescribeGroupsResponse.parse(responseBuffer, version); + case LIST_GROUPS: + return ListGroupsResponse.parse(responseBuffer, version); + case SASL_HANDSHAKE: + return SaslHandshakeResponse.parse(responseBuffer, version); + case API_VERSIONS: + return ApiVersionsResponse.parse(responseBuffer, version); + case CREATE_TOPICS: + return CreateTopicsResponse.parse(responseBuffer, version); + case DELETE_TOPICS: + return DeleteTopicsResponse.parse(responseBuffer, version); + case DELETE_RECORDS: + return DeleteRecordsResponse.parse(responseBuffer, version); + case INIT_PRODUCER_ID: + return InitProducerIdResponse.parse(responseBuffer, version); + case OFFSET_FOR_LEADER_EPOCH: + return OffsetsForLeaderEpochResponse.parse(responseBuffer, version); + case ADD_PARTITIONS_TO_TXN: + return AddPartitionsToTxnResponse.parse(responseBuffer, version); + case ADD_OFFSETS_TO_TXN: + return AddOffsetsToTxnResponse.parse(responseBuffer, version); + case END_TXN: + return EndTxnResponse.parse(responseBuffer, version); + case WRITE_TXN_MARKERS: + return WriteTxnMarkersResponse.parse(responseBuffer, version); + case TXN_OFFSET_COMMIT: + return TxnOffsetCommitResponse.parse(responseBuffer, version); + case DESCRIBE_ACLS: + return DescribeAclsResponse.parse(responseBuffer, version); + case CREATE_ACLS: + return CreateAclsResponse.parse(responseBuffer, version); + case DELETE_ACLS: + return DeleteAclsResponse.parse(responseBuffer, 
version); + case DESCRIBE_CONFIGS: + return DescribeConfigsResponse.parse(responseBuffer, version); + case ALTER_CONFIGS: + return AlterConfigsResponse.parse(responseBuffer, version); + case ALTER_REPLICA_LOG_DIRS: + return AlterReplicaLogDirsResponse.parse(responseBuffer, version); + case DESCRIBE_LOG_DIRS: + return DescribeLogDirsResponse.parse(responseBuffer, version); + case SASL_AUTHENTICATE: + return SaslAuthenticateResponse.parse(responseBuffer, version); + case CREATE_PARTITIONS: + return CreatePartitionsResponse.parse(responseBuffer, version); + case CREATE_DELEGATION_TOKEN: + return CreateDelegationTokenResponse.parse(responseBuffer, version); + case RENEW_DELEGATION_TOKEN: + return RenewDelegationTokenResponse.parse(responseBuffer, version); + case EXPIRE_DELEGATION_TOKEN: + return ExpireDelegationTokenResponse.parse(responseBuffer, version); + case DESCRIBE_DELEGATION_TOKEN: + return DescribeDelegationTokenResponse.parse(responseBuffer, version); + case DELETE_GROUPS: + return DeleteGroupsResponse.parse(responseBuffer, version); + case ELECT_LEADERS: + return ElectLeadersResponse.parse(responseBuffer, version); + case INCREMENTAL_ALTER_CONFIGS: + return IncrementalAlterConfigsResponse.parse(responseBuffer, version); + case ALTER_PARTITION_REASSIGNMENTS: + return AlterPartitionReassignmentsResponse.parse(responseBuffer, version); + case LIST_PARTITION_REASSIGNMENTS: + return ListPartitionReassignmentsResponse.parse(responseBuffer, version); + case OFFSET_DELETE: + return OffsetDeleteResponse.parse(responseBuffer, version); + case DESCRIBE_CLIENT_QUOTAS: + return DescribeClientQuotasResponse.parse(responseBuffer, version); + case ALTER_CLIENT_QUOTAS: + return AlterClientQuotasResponse.parse(responseBuffer, version); + case DESCRIBE_USER_SCRAM_CREDENTIALS: + return DescribeUserScramCredentialsResponse.parse(responseBuffer, version); + case ALTER_USER_SCRAM_CREDENTIALS: + return AlterUserScramCredentialsResponse.parse(responseBuffer, version); + case VOTE: + return VoteResponse.parse(responseBuffer, version); + case BEGIN_QUORUM_EPOCH: + return BeginQuorumEpochResponse.parse(responseBuffer, version); + case END_QUORUM_EPOCH: + return EndQuorumEpochResponse.parse(responseBuffer, version); + case DESCRIBE_QUORUM: + return DescribeQuorumResponse.parse(responseBuffer, version); + case ALTER_ISR: + return AlterIsrResponse.parse(responseBuffer, version); + case UPDATE_FEATURES: + return UpdateFeaturesResponse.parse(responseBuffer, version); + case ENVELOPE: + return EnvelopeResponse.parse(responseBuffer, version); + case FETCH_SNAPSHOT: + return FetchSnapshotResponse.parse(responseBuffer, version); + case DESCRIBE_CLUSTER: + return DescribeClusterResponse.parse(responseBuffer, version); + case DESCRIBE_PRODUCERS: + return DescribeProducersResponse.parse(responseBuffer, version); + case BROKER_REGISTRATION: + return BrokerRegistrationResponse.parse(responseBuffer, version); + case BROKER_HEARTBEAT: + return BrokerHeartbeatResponse.parse(responseBuffer, version); + case UNREGISTER_BROKER: + return UnregisterBrokerResponse.parse(responseBuffer, version); + case DESCRIBE_TRANSACTIONS: + return DescribeTransactionsResponse.parse(responseBuffer, version); + case LIST_TRANSACTIONS: + return ListTransactionsResponse.parse(responseBuffer, version); + case ALLOCATE_PRODUCER_IDS: + return AllocateProducerIdsResponse.parse(responseBuffer, version); + default: + throw new AssertionError(String.format("ApiKey %s is not currently handled in `parseResponse`, the " + + "code should be updated to do so.", 
apiKey)); + } + } + + /** + * Returns whether or not client should throttle upon receiving a response of the specified version with a non-zero + * throttle time. Client-side throttling is needed when communicating with a newer version of broker which, on + * quota violation, sends out responses before throttling. + */ + public boolean shouldClientThrottle(short version) { + return false; + } + + public ApiKeys apiKey() { + return apiKey; + } + + public abstract int throttleTimeMs(); + + public String toString() { + return data().toString(); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/AddOffsetsToTxnRequest.java b/clients/src/main/java/org/apache/kafka/common/requests/AddOffsetsToTxnRequest.java new file mode 100644 index 0000000..1e5f986 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/AddOffsetsToTxnRequest.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.message.AddOffsetsToTxnRequestData; +import org.apache.kafka.common.message.AddOffsetsToTxnResponseData; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.protocol.Errors; + +import java.nio.ByteBuffer; + +public class AddOffsetsToTxnRequest extends AbstractRequest { + + private final AddOffsetsToTxnRequestData data; + + public static class Builder extends AbstractRequest.Builder { + public AddOffsetsToTxnRequestData data; + + public Builder(AddOffsetsToTxnRequestData data) { + super(ApiKeys.ADD_OFFSETS_TO_TXN); + this.data = data; + } + + @Override + public AddOffsetsToTxnRequest build(short version) { + return new AddOffsetsToTxnRequest(data, version); + } + + @Override + public String toString() { + return data.toString(); + } + } + + public AddOffsetsToTxnRequest(AddOffsetsToTxnRequestData data, short version) { + super(ApiKeys.ADD_OFFSETS_TO_TXN, version); + this.data = data; + } + + @Override + public AddOffsetsToTxnRequestData data() { + return data; + } + + @Override + public AddOffsetsToTxnResponse getErrorResponse(int throttleTimeMs, Throwable e) { + return new AddOffsetsToTxnResponse(new AddOffsetsToTxnResponseData() + .setErrorCode(Errors.forException(e).code()) + .setThrottleTimeMs(throttleTimeMs)); + } + + public static AddOffsetsToTxnRequest parse(ByteBuffer buffer, short version) { + return new AddOffsetsToTxnRequest(new AddOffsetsToTxnRequestData(new ByteBufferAccessor(buffer), version), version); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/AddOffsetsToTxnResponse.java b/clients/src/main/java/org/apache/kafka/common/requests/AddOffsetsToTxnResponse.java new file mode 100644 index 0000000..ce9a6cf 
--- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/AddOffsetsToTxnResponse.java @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.message.AddOffsetsToTxnResponseData; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.protocol.Errors; + +import java.nio.ByteBuffer; +import java.util.Map; + +/** + * Possible error codes: + * + * - {@link Errors#NOT_COORDINATOR} + * - {@link Errors#COORDINATOR_NOT_AVAILABLE} + * - {@link Errors#COORDINATOR_LOAD_IN_PROGRESS} + * - {@link Errors#INVALID_PRODUCER_ID_MAPPING} + * - {@link Errors#INVALID_PRODUCER_EPOCH} // for version <=1 + * - {@link Errors#PRODUCER_FENCED} + * - {@link Errors#INVALID_TXN_STATE} + * - {@link Errors#GROUP_AUTHORIZATION_FAILED} + * - {@link Errors#TRANSACTIONAL_ID_AUTHORIZATION_FAILED} + */ +public class AddOffsetsToTxnResponse extends AbstractResponse { + + private final AddOffsetsToTxnResponseData data; + + public AddOffsetsToTxnResponse(AddOffsetsToTxnResponseData data) { + super(ApiKeys.ADD_OFFSETS_TO_TXN); + this.data = data; + } + + @Override + public Map errorCounts() { + return errorCounts(Errors.forCode(data.errorCode())); + } + + @Override + public int throttleTimeMs() { + return data.throttleTimeMs(); + } + + @Override + public AddOffsetsToTxnResponseData data() { + return data; + } + + public static AddOffsetsToTxnResponse parse(ByteBuffer buffer, short version) { + return new AddOffsetsToTxnResponse(new AddOffsetsToTxnResponseData(new ByteBufferAccessor(buffer), version)); + } + + @Override + public String toString() { + return data.toString(); + } + + @Override + public boolean shouldClientThrottle(short version) { + return version >= 1; + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/AddPartitionsToTxnRequest.java b/clients/src/main/java/org/apache/kafka/common/requests/AddPartitionsToTxnRequest.java new file mode 100644 index 0000000..1034c0f --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/AddPartitionsToTxnRequest.java @@ -0,0 +1,131 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
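On the response side, errorCounts() collapses the body into per-error tallies and shouldClientThrottle() tells the client whether it must apply the broker-supplied throttle itself (true from version 1 onward, where brokers respond before throttling). A small sketch with an assumed error code:

import org.apache.kafka.common.message.AddOffsetsToTxnResponseData;
import org.apache.kafka.common.protocol.Errors;
import org.apache.kafka.common.requests.AddOffsetsToTxnResponse;

AddOffsetsToTxnResponse response = new AddOffsetsToTxnResponse(
    new AddOffsetsToTxnResponseData()
        .setErrorCode(Errors.COORDINATOR_LOAD_IN_PROGRESS.code())
        .setThrottleTimeMs(100));

response.errorCounts();                    // {COORDINATOR_LOAD_IN_PROGRESS=1}
response.throttleTimeMs();                 // 100
response.shouldClientThrottle((short) 1);  // true, so the client backs off itself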
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.message.AddPartitionsToTxnRequestData; +import org.apache.kafka.common.message.AddPartitionsToTxnRequestData.AddPartitionsToTxnTopic; +import org.apache.kafka.common.message.AddPartitionsToTxnRequestData.AddPartitionsToTxnTopicCollection; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.protocol.Errors; + +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +public class AddPartitionsToTxnRequest extends AbstractRequest { + + private final AddPartitionsToTxnRequestData data; + + private List cachedPartitions = null; + + public static class Builder extends AbstractRequest.Builder { + public final AddPartitionsToTxnRequestData data; + + public Builder(final AddPartitionsToTxnRequestData data) { + super(ApiKeys.ADD_PARTITIONS_TO_TXN); + this.data = data; + } + + public Builder(final String transactionalId, + final long producerId, + final short producerEpoch, + final List partitions) { + super(ApiKeys.ADD_PARTITIONS_TO_TXN); + + Map> partitionMap = new HashMap<>(); + for (TopicPartition topicPartition : partitions) { + String topicName = topicPartition.topic(); + + partitionMap.compute(topicName, (key, subPartitions) -> { + if (subPartitions == null) { + subPartitions = new ArrayList<>(); + } + subPartitions.add(topicPartition.partition()); + return subPartitions; + }); + } + + AddPartitionsToTxnTopicCollection topics = new AddPartitionsToTxnTopicCollection(); + for (Map.Entry> partitionEntry : partitionMap.entrySet()) { + topics.add(new AddPartitionsToTxnTopic() + .setName(partitionEntry.getKey()) + .setPartitions(partitionEntry.getValue())); + } + + this.data = new AddPartitionsToTxnRequestData() + .setTransactionalId(transactionalId) + .setProducerId(producerId) + .setProducerEpoch(producerEpoch) + .setTopics(topics); + } + + @Override + public AddPartitionsToTxnRequest build(short version) { + return new AddPartitionsToTxnRequest(data, version); + } + + static List getPartitions(AddPartitionsToTxnRequestData data) { + List partitions = new ArrayList<>(); + for (AddPartitionsToTxnTopic topicCollection : data.topics()) { + for (Integer partition : topicCollection.partitions()) { + partitions.add(new TopicPartition(topicCollection.name(), partition)); + } + } + return partitions; + } + + @Override + public String toString() { + return data.toString(); + } + } + + public AddPartitionsToTxnRequest(final AddPartitionsToTxnRequestData data, short version) { + super(ApiKeys.ADD_PARTITIONS_TO_TXN, version); + this.data = data; + } + + public List partitions() { + if (cachedPartitions != null) { + return cachedPartitions; + } + cachedPartitions = Builder.getPartitions(data); + return cachedPartitions; + } + + @Override + public AddPartitionsToTxnRequestData data() { + return data; + } + + @Override + public AddPartitionsToTxnResponse getErrorResponse(int throttleTimeMs, Throwable e) { + 
final HashMap errors = new HashMap<>(); + for (TopicPartition partition : partitions()) { + errors.put(partition, Errors.forException(e)); + } + return new AddPartitionsToTxnResponse(throttleTimeMs, errors); + } + + public static AddPartitionsToTxnRequest parse(ByteBuffer buffer, short version) { + return new AddPartitionsToTxnRequest(new AddPartitionsToTxnRequestData(new ByteBufferAccessor(buffer), version), version); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/AddPartitionsToTxnResponse.java b/clients/src/main/java/org/apache/kafka/common/requests/AddPartitionsToTxnResponse.java new file mode 100644 index 0000000..57b2a5a --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/AddPartitionsToTxnResponse.java @@ -0,0 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.message.AddPartitionsToTxnResponseData; +import org.apache.kafka.common.message.AddPartitionsToTxnResponseData.AddPartitionsToTxnPartitionResult; +import org.apache.kafka.common.message.AddPartitionsToTxnResponseData.AddPartitionsToTxnPartitionResultCollection; +import org.apache.kafka.common.message.AddPartitionsToTxnResponseData.AddPartitionsToTxnTopicResult; +import org.apache.kafka.common.message.AddPartitionsToTxnResponseData.AddPartitionsToTxnTopicResultCollection; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.protocol.Errors; + +import java.nio.ByteBuffer; +import java.util.HashMap; +import java.util.Map; + +/** + * Possible error codes: + * + * - {@link Errors#NOT_COORDINATOR} + * - {@link Errors#COORDINATOR_NOT_AVAILABLE} + * - {@link Errors#COORDINATOR_LOAD_IN_PROGRESS} + * - {@link Errors#INVALID_TXN_STATE} + * - {@link Errors#INVALID_PRODUCER_ID_MAPPING} + * - {@link Errors#INVALID_PRODUCER_EPOCH} // for version <=1 + * - {@link Errors#PRODUCER_FENCED} + * - {@link Errors#TOPIC_AUTHORIZATION_FAILED} + * - {@link Errors#TRANSACTIONAL_ID_AUTHORIZATION_FAILED} + * - {@link Errors#UNKNOWN_TOPIC_OR_PARTITION} + */ +public class AddPartitionsToTxnResponse extends AbstractResponse { + + private final AddPartitionsToTxnResponseData data; + + private Map cachedErrorsMap = null; + + public AddPartitionsToTxnResponse(AddPartitionsToTxnResponseData data) { + super(ApiKeys.ADD_PARTITIONS_TO_TXN); + this.data = data; + } + + public AddPartitionsToTxnResponse(int throttleTimeMs, Map errors) { + super(ApiKeys.ADD_PARTITIONS_TO_TXN); + + Map resultMap = new HashMap<>(); + + for (Map.Entry entry : errors.entrySet()) { + TopicPartition topicPartition = entry.getKey(); + String topicName = 
topicPartition.topic(); + + AddPartitionsToTxnPartitionResult partitionResult = + new AddPartitionsToTxnPartitionResult() + .setErrorCode(entry.getValue().code()) + .setPartitionIndex(topicPartition.partition()); + + AddPartitionsToTxnPartitionResultCollection partitionResultCollection = resultMap.getOrDefault( + topicName, new AddPartitionsToTxnPartitionResultCollection() + ); + + partitionResultCollection.add(partitionResult); + resultMap.put(topicName, partitionResultCollection); + } + + AddPartitionsToTxnTopicResultCollection topicCollection = new AddPartitionsToTxnTopicResultCollection(); + for (Map.Entry entry : resultMap.entrySet()) { + topicCollection.add(new AddPartitionsToTxnTopicResult() + .setName(entry.getKey()) + .setResults(entry.getValue())); + } + + this.data = new AddPartitionsToTxnResponseData() + .setThrottleTimeMs(throttleTimeMs) + .setResults(topicCollection); + } + + @Override + public int throttleTimeMs() { + return data.throttleTimeMs(); + } + + public Map errors() { + if (cachedErrorsMap != null) { + return cachedErrorsMap; + } + + cachedErrorsMap = new HashMap<>(); + + for (AddPartitionsToTxnTopicResult topicResult : this.data.results()) { + for (AddPartitionsToTxnPartitionResult partitionResult : topicResult.results()) { + cachedErrorsMap.put(new TopicPartition( + topicResult.name(), partitionResult.partitionIndex()), + Errors.forCode(partitionResult.errorCode())); + } + } + return cachedErrorsMap; + } + + @Override + public Map errorCounts() { + return errorCounts(errors().values()); + } + + @Override + public AddPartitionsToTxnResponseData data() { + return data; + } + + public static AddPartitionsToTxnResponse parse(ByteBuffer buffer, short version) { + return new AddPartitionsToTxnResponse(new AddPartitionsToTxnResponseData(new ByteBufferAccessor(buffer), version)); + } + + @Override + public String toString() { + return data.toString(); + } + + @Override + public boolean shouldClientThrottle(short version) { + return version >= 1; + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/AllocateProducerIdsRequest.java b/clients/src/main/java/org/apache/kafka/common/requests/AllocateProducerIdsRequest.java new file mode 100644 index 0000000..7938f92 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/AllocateProducerIdsRequest.java @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
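The convenience constructor of AddPartitionsToTxnRequest.Builder groups a flat partition list by topic before it is written out, and partitions() rebuilds the flat list on the receiving side. A short sketch with made-up ids:

import java.util.Arrays;
import java.util.List;
import org.apache.kafka.common.TopicPartition;
import org.apache.kafka.common.requests.AddPartitionsToTxnRequest;

List<TopicPartition> partitions = Arrays.asList(
    new TopicPartition("orders", 0),
    new TopicPartition("orders", 1),
    new TopicPartition("payments", 0));

AddPartitionsToTxnRequest request =
    new AddPartitionsToTxnRequest.Builder("my-txn", 4000L, (short) 1, partitions).build();

// The wire form carries two topic entries ("orders" with partitions [0, 1] and
// "payments" with [0]); request.partitions() returns the three TopicPartitions again.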
+ */ + +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.message.AllocateProducerIdsRequestData; +import org.apache.kafka.common.message.AllocateProducerIdsResponseData; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.protocol.Errors; + +import java.nio.ByteBuffer; + +public class AllocateProducerIdsRequest extends AbstractRequest { + private final AllocateProducerIdsRequestData data; + + public AllocateProducerIdsRequest(AllocateProducerIdsRequestData data, short version) { + super(ApiKeys.ALLOCATE_PRODUCER_IDS, version); + this.data = data; + } + + @Override + public AbstractResponse getErrorResponse(int throttleTimeMs, Throwable e) { + return new AllocateProducerIdsResponse(new AllocateProducerIdsResponseData() + .setThrottleTimeMs(throttleTimeMs) + .setErrorCode(Errors.forException(e).code())); + } + + @Override + public AllocateProducerIdsRequestData data() { + return data; + } + + public static class Builder extends AbstractRequest.Builder { + + private final AllocateProducerIdsRequestData data; + + public Builder(AllocateProducerIdsRequestData data) { + super(ApiKeys.ALLOCATE_PRODUCER_IDS); + this.data = data; + } + + @Override + public AllocateProducerIdsRequest build(short version) { + return new AllocateProducerIdsRequest(data, version); + } + + @Override + public String toString() { + return data.toString(); + } + } + + public static AllocateProducerIdsRequest parse(ByteBuffer buffer, short version) { + return new AllocateProducerIdsRequest(new AllocateProducerIdsRequestData( + new ByteBufferAccessor(buffer), version), version); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/AllocateProducerIdsResponse.java b/clients/src/main/java/org/apache/kafka/common/requests/AllocateProducerIdsResponse.java new file mode 100644 index 0000000..5d48c39 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/AllocateProducerIdsResponse.java @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.message.AllocateProducerIdsResponseData; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.protocol.Errors; + +import java.nio.ByteBuffer; +import java.util.Collections; +import java.util.Map; + +public class AllocateProducerIdsResponse extends AbstractResponse { + + private final AllocateProducerIdsResponseData data; + + public AllocateProducerIdsResponse(AllocateProducerIdsResponseData data) { + super(ApiKeys.ALLOCATE_PRODUCER_IDS); + this.data = data; + } + + @Override + public AllocateProducerIdsResponseData data() { + return data; + } + + /** + * The number of each type of error in the response, including {@link Errors#NONE} and top-level errors as well as + * more specifically scoped errors (such as topic or partition-level errors). + * + * @return A count of errors. + */ + @Override + public Map errorCounts() { + return Collections.singletonMap(Errors.forCode(data.errorCode()), 1); + } + + @Override + public int throttleTimeMs() { + return data.throttleTimeMs(); + } + + public static AllocateProducerIdsResponse parse(ByteBuffer buffer, short version) { + return new AllocateProducerIdsResponse(new AllocateProducerIdsResponseData( + new ByteBufferAccessor(buffer), version)); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/AlterClientQuotasRequest.java b/clients/src/main/java/org/apache/kafka/common/requests/AlterClientQuotasRequest.java new file mode 100644 index 0000000..3b06348 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/AlterClientQuotasRequest.java @@ -0,0 +1,144 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.message.AlterClientQuotasRequestData; +import org.apache.kafka.common.message.AlterClientQuotasRequestData.EntityData; +import org.apache.kafka.common.message.AlterClientQuotasRequestData.EntryData; +import org.apache.kafka.common.message.AlterClientQuotasRequestData.OpData; +import org.apache.kafka.common.message.AlterClientQuotasResponseData; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.protocol.Errors; +import org.apache.kafka.common.quota.ClientQuotaAlteration; +import org.apache.kafka.common.quota.ClientQuotaEntity; + +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +public class AlterClientQuotasRequest extends AbstractRequest { + + public static class Builder extends AbstractRequest.Builder { + + private final AlterClientQuotasRequestData data; + + public Builder(Collection entries, boolean validateOnly) { + super(ApiKeys.ALTER_CLIENT_QUOTAS); + + List entryData = new ArrayList<>(entries.size()); + for (ClientQuotaAlteration entry : entries) { + List entityData = new ArrayList<>(entry.entity().entries().size()); + for (Map.Entry entityEntries : entry.entity().entries().entrySet()) { + entityData.add(new EntityData() + .setEntityType(entityEntries.getKey()) + .setEntityName(entityEntries.getValue())); + } + + List opData = new ArrayList<>(entry.ops().size()); + for (ClientQuotaAlteration.Op op : entry.ops()) { + opData.add(new OpData() + .setKey(op.key()) + .setValue(op.value() == null ? 0.0 : op.value()) + .setRemove(op.value() == null)); + } + + entryData.add(new EntryData() + .setEntity(entityData) + .setOps(opData)); + } + + this.data = new AlterClientQuotasRequestData() + .setEntries(entryData) + .setValidateOnly(validateOnly); + } + + @Override + public AlterClientQuotasRequest build(short version) { + return new AlterClientQuotasRequest(data, version); + } + + @Override + public String toString() { + return data.toString(); + } + } + + private final AlterClientQuotasRequestData data; + + public AlterClientQuotasRequest(AlterClientQuotasRequestData data, short version) { + super(ApiKeys.ALTER_CLIENT_QUOTAS, version); + this.data = data; + } + + public List entries() { + List entries = new ArrayList<>(data.entries().size()); + for (EntryData entryData : data.entries()) { + Map entity = new HashMap<>(entryData.entity().size()); + for (EntityData entityData : entryData.entity()) { + entity.put(entityData.entityType(), entityData.entityName()); + } + + List ops = new ArrayList<>(entryData.ops().size()); + for (OpData opData : entryData.ops()) { + Double value = opData.remove() ? 
null : opData.value(); + ops.add(new ClientQuotaAlteration.Op(opData.key(), value)); + } + + entries.add(new ClientQuotaAlteration(new ClientQuotaEntity(entity), ops)); + } + return entries; + } + + public boolean validateOnly() { + return data.validateOnly(); + } + + @Override + public AlterClientQuotasRequestData data() { + return data; + } + + @Override + public AlterClientQuotasResponse getErrorResponse(int throttleTimeMs, Throwable e) { + Errors error = Errors.forException(e); + List responseEntries = new ArrayList<>(); + for (EntryData entryData : data.entries()) { + List responseEntities = new ArrayList<>(); + for (EntityData entityData : entryData.entity()) { + responseEntities.add(new AlterClientQuotasResponseData.EntityData() + .setEntityType(entityData.entityType()) + .setEntityName(entityData.entityName())); + } + responseEntries.add(new AlterClientQuotasResponseData.EntryData() + .setEntity(responseEntities) + .setErrorCode(error.code()) + .setErrorMessage(error.message())); + } + AlterClientQuotasResponseData responseData = new AlterClientQuotasResponseData() + .setThrottleTimeMs(throttleTimeMs) + .setEntries(responseEntries); + return new AlterClientQuotasResponse(responseData); + } + + public static AlterClientQuotasRequest parse(ByteBuffer buffer, short version) { + return new AlterClientQuotasRequest(new AlterClientQuotasRequestData(new ByteBufferAccessor(buffer), version), version); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/AlterClientQuotasResponse.java b/clients/src/main/java/org/apache/kafka/common/requests/AlterClientQuotasResponse.java new file mode 100644 index 0000000..fcacc5d --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/AlterClientQuotasResponse.java @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
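On the wire, each quota op carries an explicit remove flag because the value field cannot be null, so a null value in ClientQuotaAlteration.Op is translated to remove=true by the builder and back to null by entries(). A sketch; the "user" entity type and the quota key names are the standard strings and are used here as assumptions.

import java.util.Arrays;
import java.util.Collections;
import org.apache.kafka.common.quota.ClientQuotaAlteration;
import org.apache.kafka.common.quota.ClientQuotaEntity;
import org.apache.kafka.common.requests.AlterClientQuotasRequest;

ClientQuotaEntity entity = new ClientQuotaEntity(Collections.singletonMap("user", "alice"));
ClientQuotaAlteration alteration = new ClientQuotaAlteration(entity, Arrays.asList(
    new ClientQuotaAlteration.Op("consumer_byte_rate", 1048576.0),  // set a quota
    new ClientQuotaAlteration.Op("producer_byte_rate", null)));     // null value => remove

AlterClientQuotasRequest request =
    new AlterClientQuotasRequest.Builder(Arrays.asList(alteration), false).build();
// request.entries() round-trips to the same alteration, with the removed op's
// value restored to null.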
+ */ +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.internals.KafkaFutureImpl; +import org.apache.kafka.common.message.AlterClientQuotasResponseData; +import org.apache.kafka.common.message.AlterClientQuotasResponseData.EntityData; +import org.apache.kafka.common.message.AlterClientQuotasResponseData.EntryData; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.protocol.Errors; +import org.apache.kafka.common.quota.ClientQuotaEntity; + +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +public class AlterClientQuotasResponse extends AbstractResponse { + + private final AlterClientQuotasResponseData data; + + public AlterClientQuotasResponse(AlterClientQuotasResponseData data) { + super(ApiKeys.ALTER_CLIENT_QUOTAS); + this.data = data; + } + + public void complete(Map> futures) { + for (EntryData entryData : data.entries()) { + Map entityEntries = new HashMap<>(entryData.entity().size()); + for (EntityData entityData : entryData.entity()) { + entityEntries.put(entityData.entityType(), entityData.entityName()); + } + ClientQuotaEntity entity = new ClientQuotaEntity(entityEntries); + + KafkaFutureImpl future = futures.get(entity); + if (future == null) { + throw new IllegalArgumentException("Future map must contain entity " + entity); + } + + Errors error = Errors.forCode(entryData.errorCode()); + if (error == Errors.NONE) { + future.complete(null); + } else { + future.completeExceptionally(error.exception(entryData.errorMessage())); + } + } + } + + @Override + public int throttleTimeMs() { + return data.throttleTimeMs(); + } + + @Override + public Map errorCounts() { + Map counts = new HashMap<>(); + data.entries().forEach(entry -> + updateErrorCounts(counts, Errors.forCode(entry.errorCode())) + ); + return counts; + } + + @Override + public AlterClientQuotasResponseData data() { + return data; + } + + private static List toEntityData(ClientQuotaEntity entity) { + List entityData = new ArrayList<>(entity.entries().size()); + for (Map.Entry entry : entity.entries().entrySet()) { + entityData.add(new AlterClientQuotasResponseData.EntityData() + .setEntityType(entry.getKey()) + .setEntityName(entry.getValue())); + } + return entityData; + } + + public static AlterClientQuotasResponse parse(ByteBuffer buffer, short version) { + return new AlterClientQuotasResponse(new AlterClientQuotasResponseData(new ByteBufferAccessor(buffer), version)); + } + + public static AlterClientQuotasResponse fromQuotaEntities(Map result, int throttleTimeMs) { + List entries = new ArrayList<>(result.size()); + for (Map.Entry entry : result.entrySet()) { + ApiError e = entry.getValue(); + entries.add(new EntryData() + .setErrorCode(e.error().code()) + .setErrorMessage(e.message()) + .setEntity(toEntityData(entry.getKey()))); + } + + return new AlterClientQuotasResponse(new AlterClientQuotasResponseData() + .setThrottleTimeMs(throttleTimeMs) + .setEntries(entries)); + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/AlterConfigsRequest.java b/clients/src/main/java/org/apache/kafka/common/requests/AlterConfigsRequest.java new file mode 100644 index 0000000..b4d35d5 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/AlterConfigsRequest.java @@ -0,0 +1,138 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.config.ConfigResource; +import org.apache.kafka.common.message.AlterConfigsRequestData; +import org.apache.kafka.common.message.AlterConfigsResponseData; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; + +import java.nio.ByteBuffer; +import java.util.Collection; +import java.util.Map; +import java.util.Objects; +import java.util.stream.Collectors; + +public class AlterConfigsRequest extends AbstractRequest { + + public static class Config { + private final Collection entries; + + public Config(Collection entries) { + this.entries = Objects.requireNonNull(entries, "entries"); + } + + public Collection entries() { + return entries; + } + } + + public static class ConfigEntry { + private final String name; + private final String value; + + public ConfigEntry(String name, String value) { + this.name = Objects.requireNonNull(name, "name"); + this.value = Objects.requireNonNull(value, "value"); + } + + public String name() { + return name; + } + + public String value() { + return value; + } + + } + + public static class Builder extends AbstractRequest.Builder { + + private final AlterConfigsRequestData data = new AlterConfigsRequestData(); + + public Builder(Map configs, boolean validateOnly) { + super(ApiKeys.ALTER_CONFIGS); + Objects.requireNonNull(configs, "configs"); + for (Map.Entry entry : configs.entrySet()) { + AlterConfigsRequestData.AlterConfigsResource resource = + new AlterConfigsRequestData.AlterConfigsResource() + .setResourceName(entry.getKey().name()) + .setResourceType(entry.getKey().type().id()); + for (ConfigEntry x : entry.getValue().entries) { + resource.configs().add(new AlterConfigsRequestData.AlterableConfig() + .setName(x.name()) + .setValue(x.value())); + } + this.data.resources().add(resource); + } + this.data.setValidateOnly(validateOnly); + } + + @Override + public AlterConfigsRequest build(short version) { + return new AlterConfigsRequest(data, version); + } + } + + private final AlterConfigsRequestData data; + + public AlterConfigsRequest(AlterConfigsRequestData data, short version) { + super(ApiKeys.ALTER_CONFIGS, version); + this.data = data; + } + + public Map configs() { + return data.resources().stream().collect(Collectors.toMap( + resource -> new ConfigResource( + ConfigResource.Type.forId(resource.resourceType()), + resource.resourceName()), + resource -> new Config(resource.configs().stream() + .map(entry -> new ConfigEntry(entry.name(), entry.value())) + .collect(Collectors.toList())))); + } + + public boolean validateOnly() { + return data.validateOnly(); + } + + @Override + public AlterConfigsRequestData data() { + return data; + } + + @Override + public AbstractResponse getErrorResponse(int throttleTimeMs, Throwable e) { + 
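+        // On failure, every resource named in the request gets back the same error code and message.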
ApiError error = ApiError.fromThrowable(e); + AlterConfigsResponseData data = new AlterConfigsResponseData() + .setThrottleTimeMs(throttleTimeMs); + for (AlterConfigsRequestData.AlterConfigsResource resource : this.data.resources()) { + data.responses().add(new AlterConfigsResponseData.AlterConfigsResourceResponse() + .setResourceType(resource.resourceType()) + .setResourceName(resource.resourceName()) + .setErrorMessage(error.message()) + .setErrorCode(error.error().code())); + } + return new AlterConfigsResponse(data); + + } + + public static AlterConfigsRequest parse(ByteBuffer buffer, short version) { + return new AlterConfigsRequest(new AlterConfigsRequestData(new ByteBufferAccessor(buffer), version), version); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/AlterConfigsResponse.java b/clients/src/main/java/org/apache/kafka/common/requests/AlterConfigsResponse.java new file mode 100644 index 0000000..1115f06 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/AlterConfigsResponse.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
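As an illustrative aside, the sketch below shows how the Builder defined above is meant to be used; the topic name "orders" and the config "retention.ms" are example values only, and the class is a standalone sketch rather than part of this patch.

import org.apache.kafka.common.config.ConfigResource;
import org.apache.kafka.common.requests.AlterConfigsRequest;

import java.util.Collections;

class AlterConfigsRequestSketch {
    static AlterConfigsRequest validateOnlyRequest(short version) {
        ConfigResource topic = new ConfigResource(ConfigResource.Type.TOPIC, "orders");
        AlterConfigsRequest.Config config = new AlterConfigsRequest.Config(
            Collections.singletonList(new AlterConfigsRequest.ConfigEntry("retention.ms", "86400000")));
        // validateOnly=true asks the broker to check the change without applying it.
        return new AlterConfigsRequest.Builder(Collections.singletonMap(topic, config), true).build(version);
    }
}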
+ */ + +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.config.ConfigResource; +import org.apache.kafka.common.message.AlterConfigsResponseData; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.protocol.Errors; + +import java.nio.ByteBuffer; +import java.util.Map; +import java.util.stream.Collectors; + +public class AlterConfigsResponse extends AbstractResponse { + + private final AlterConfigsResponseData data; + + public AlterConfigsResponse(AlterConfigsResponseData data) { + super(ApiKeys.ALTER_CONFIGS); + this.data = data; + } + + public Map errors() { + return data.responses().stream().collect(Collectors.toMap( + response -> new ConfigResource( + ConfigResource.Type.forId(response.resourceType()), + response.resourceName()), + response -> new ApiError(Errors.forCode(response.errorCode()), response.errorMessage()) + )); + } + + @Override + public Map errorCounts() { + return apiErrorCounts(errors()); + } + + @Override + public int throttleTimeMs() { + return data.throttleTimeMs(); + } + + @Override + public AlterConfigsResponseData data() { + return data; + } + + public static AlterConfigsResponse parse(ByteBuffer buffer, short version) { + return new AlterConfigsResponse(new AlterConfigsResponseData(new ByteBufferAccessor(buffer), version)); + } + + @Override + public boolean shouldClientThrottle(short version) { + return version >= 1; + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/AlterIsrRequest.java b/clients/src/main/java/org/apache/kafka/common/requests/AlterIsrRequest.java new file mode 100644 index 0000000..516c2ce --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/AlterIsrRequest.java @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.message.AlterIsrRequestData; +import org.apache.kafka.common.message.AlterIsrResponseData; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.protocol.Errors; + +import java.nio.ByteBuffer; + +public class AlterIsrRequest extends AbstractRequest { + + private final AlterIsrRequestData data; + + public AlterIsrRequest(AlterIsrRequestData data, short apiVersion) { + super(ApiKeys.ALTER_ISR, apiVersion); + this.data = data; + } + + @Override + public AlterIsrRequestData data() { + return data; + } + + /** + * Get an error response for a request with specified throttle time in the response if applicable + */ + @Override + public AbstractResponse getErrorResponse(int throttleTimeMs, Throwable e) { + return new AlterIsrResponse(new AlterIsrResponseData() + .setThrottleTimeMs(throttleTimeMs) + .setErrorCode(Errors.forException(e).code())); + } + + public static AlterIsrRequest parse(ByteBuffer buffer, short version) { + return new AlterIsrRequest(new AlterIsrRequestData(new ByteBufferAccessor(buffer), version), version); + } + + public static class Builder extends AbstractRequest.Builder { + + private final AlterIsrRequestData data; + + public Builder(AlterIsrRequestData data) { + super(ApiKeys.ALTER_ISR); + this.data = data; + } + + @Override + public AlterIsrRequest build(short version) { + return new AlterIsrRequest(data, version); + } + + @Override + public String toString() { + return data.toString(); + } + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/AlterIsrResponse.java b/clients/src/main/java/org/apache/kafka/common/requests/AlterIsrResponse.java new file mode 100644 index 0000000..c3106ed --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/AlterIsrResponse.java @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.message.AlterIsrResponseData; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.protocol.Errors; + +import java.nio.ByteBuffer; +import java.util.HashMap; +import java.util.Map; + +public class AlterIsrResponse extends AbstractResponse { + + private final AlterIsrResponseData data; + + public AlterIsrResponse(AlterIsrResponseData data) { + super(ApiKeys.ALTER_ISR); + this.data = data; + } + + @Override + public AlterIsrResponseData data() { + return data; + } + + @Override + public Map errorCounts() { + Map counts = new HashMap<>(); + updateErrorCounts(counts, Errors.forCode(data.errorCode())); + data.topics().forEach(topicResponse -> topicResponse.partitions().forEach(partitionResponse -> { + updateErrorCounts(counts, Errors.forCode(partitionResponse.errorCode())); + })); + return counts; + } + + @Override + public int throttleTimeMs() { + return data.throttleTimeMs(); + } + + public static AlterIsrResponse parse(ByteBuffer buffer, short version) { + return new AlterIsrResponse(new AlterIsrResponseData(new ByteBufferAccessor(buffer), version)); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/AlterPartitionReassignmentsRequest.java b/clients/src/main/java/org/apache/kafka/common/requests/AlterPartitionReassignmentsRequest.java new file mode 100644 index 0000000..2d289cc --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/AlterPartitionReassignmentsRequest.java @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.message.AlterPartitionReassignmentsRequestData; +import org.apache.kafka.common.message.AlterPartitionReassignmentsRequestData.ReassignableTopic; +import org.apache.kafka.common.message.AlterPartitionReassignmentsResponseData; +import org.apache.kafka.common.message.AlterPartitionReassignmentsResponseData.ReassignablePartitionResponse; +import org.apache.kafka.common.message.AlterPartitionReassignmentsResponseData.ReassignableTopicResponse; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; + +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.List; +import java.util.stream.Collectors; + +public class AlterPartitionReassignmentsRequest extends AbstractRequest { + + public static class Builder extends AbstractRequest.Builder { + private final AlterPartitionReassignmentsRequestData data; + + public Builder(AlterPartitionReassignmentsRequestData data) { + super(ApiKeys.ALTER_PARTITION_REASSIGNMENTS); + this.data = data; + } + + @Override + public AlterPartitionReassignmentsRequest build(short version) { + return new AlterPartitionReassignmentsRequest(data, version); + } + + @Override + public String toString() { + return data.toString(); + } + } + + private final AlterPartitionReassignmentsRequestData data; + + private AlterPartitionReassignmentsRequest(AlterPartitionReassignmentsRequestData data, short version) { + super(ApiKeys.ALTER_PARTITION_REASSIGNMENTS, version); + this.data = data; + } + + public static AlterPartitionReassignmentsRequest parse(ByteBuffer buffer, short version) { + return new AlterPartitionReassignmentsRequest(new AlterPartitionReassignmentsRequestData( + new ByteBufferAccessor(buffer), version), version); + } + + public AlterPartitionReassignmentsRequestData data() { + return data; + } + + @Override + public AbstractResponse getErrorResponse(int throttleTimeMs, Throwable e) { + ApiError apiError = ApiError.fromThrowable(e); + List topicResponses = new ArrayList<>(); + + for (ReassignableTopic topic : data.topics()) { + List partitionResponses = topic.partitions().stream().map(partition -> + new ReassignablePartitionResponse() + .setPartitionIndex(partition.partitionIndex()) + .setErrorCode(apiError.error().code()) + .setErrorMessage(apiError.message()) + ).collect(Collectors.toList()); + topicResponses.add( + new ReassignableTopicResponse() + .setName(topic.name()) + .setPartitions(partitionResponses) + ); + } + + AlterPartitionReassignmentsResponseData responseData = new AlterPartitionReassignmentsResponseData() + .setResponses(topicResponses) + .setErrorCode(apiError.error().code()) + .setErrorMessage(apiError.message()) + .setThrottleTimeMs(throttleTimeMs); + return new AlterPartitionReassignmentsResponse(responseData); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/AlterPartitionReassignmentsResponse.java b/clients/src/main/java/org/apache/kafka/common/requests/AlterPartitionReassignmentsResponse.java new file mode 100644 index 0000000..ab166b8 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/AlterPartitionReassignmentsResponse.java @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.message.AlterPartitionReassignmentsResponseData; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.protocol.Errors; + +import java.nio.ByteBuffer; +import java.util.HashMap; +import java.util.Map; + +public class AlterPartitionReassignmentsResponse extends AbstractResponse { + + private final AlterPartitionReassignmentsResponseData data; + + public AlterPartitionReassignmentsResponse(AlterPartitionReassignmentsResponseData data) { + super(ApiKeys.ALTER_PARTITION_REASSIGNMENTS); + this.data = data; + } + + public static AlterPartitionReassignmentsResponse parse(ByteBuffer buffer, short version) { + return new AlterPartitionReassignmentsResponse( + new AlterPartitionReassignmentsResponseData(new ByteBufferAccessor(buffer), version)); + } + + @Override + public AlterPartitionReassignmentsResponseData data() { + return data; + } + + @Override + public boolean shouldClientThrottle(short version) { + return true; + } + + @Override + public int throttleTimeMs() { + return data.throttleTimeMs(); + } + + @Override + public Map errorCounts() { + Map counts = new HashMap<>(); + updateErrorCounts(counts, Errors.forCode(data.errorCode())); + + data.responses().forEach(topicResponse -> + topicResponse.partitions().forEach(partitionResponse -> + updateErrorCounts(counts, Errors.forCode(partitionResponse.errorCode())) + )); + return counts; + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/AlterReplicaLogDirsRequest.java b/clients/src/main/java/org/apache/kafka/common/requests/AlterReplicaLogDirsRequest.java new file mode 100644 index 0000000..68a87e6 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/AlterReplicaLogDirsRequest.java @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.common.requests; + + +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.protocol.Errors; + +import java.nio.ByteBuffer; +import java.util.HashMap; +import java.util.Map; +import java.util.stream.Collectors; + +import org.apache.kafka.common.message.AlterReplicaLogDirsRequestData; +import org.apache.kafka.common.message.AlterReplicaLogDirsResponseData; +import org.apache.kafka.common.message.AlterReplicaLogDirsResponseData.AlterReplicaLogDirTopicResult; + +public class AlterReplicaLogDirsRequest extends AbstractRequest { + + private final AlterReplicaLogDirsRequestData data; + + public static class Builder extends AbstractRequest.Builder { + private final AlterReplicaLogDirsRequestData data; + + public Builder(AlterReplicaLogDirsRequestData data) { + super(ApiKeys.ALTER_REPLICA_LOG_DIRS); + this.data = data; + } + + @Override + public AlterReplicaLogDirsRequest build(short version) { + return new AlterReplicaLogDirsRequest(data, version); + } + + @Override + public String toString() { + return data.toString(); + } + } + + public AlterReplicaLogDirsRequest(AlterReplicaLogDirsRequestData data, short version) { + super(ApiKeys.ALTER_REPLICA_LOG_DIRS, version); + this.data = data; + } + + @Override + public AlterReplicaLogDirsRequestData data() { + return data; + } + + public AlterReplicaLogDirsResponse getErrorResponse(int throttleTimeMs, Throwable e) { + AlterReplicaLogDirsResponseData data = new AlterReplicaLogDirsResponseData(); + data.setResults(this.data.dirs().stream().flatMap(alterDir -> + alterDir.topics().stream().map(topic -> + new AlterReplicaLogDirTopicResult() + .setTopicName(topic.name()) + .setPartitions(topic.partitions().stream().map(partitionId -> + new AlterReplicaLogDirsResponseData.AlterReplicaLogDirPartitionResult() + .setErrorCode(Errors.forException(e).code()) + .setPartitionIndex(partitionId)).collect(Collectors.toList())))).collect(Collectors.toList())); + return new AlterReplicaLogDirsResponse(data.setThrottleTimeMs(throttleTimeMs)); + } + + public Map partitionDirs() { + Map result = new HashMap<>(); + data.dirs().forEach(alterDir -> + alterDir.topics().forEach(topic -> + topic.partitions().forEach(partition -> + result.put(new TopicPartition(topic.name(), partition.intValue()), alterDir.path()))) + ); + return result; + } + + public static AlterReplicaLogDirsRequest parse(ByteBuffer buffer, short version) { + return new AlterReplicaLogDirsRequest(new AlterReplicaLogDirsRequestData(new ByteBufferAccessor(buffer), version), version); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/AlterReplicaLogDirsResponse.java b/clients/src/main/java/org/apache/kafka/common/requests/AlterReplicaLogDirsResponse.java new file mode 100644 index 0000000..afa658d --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/AlterReplicaLogDirsResponse.java @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.message.AlterReplicaLogDirsResponseData; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.protocol.Errors; + +import java.nio.ByteBuffer; +import java.util.HashMap; +import java.util.Map; + +/** + * Possible error codes: + * + * {@link Errors#LOG_DIR_NOT_FOUND} + * {@link Errors#KAFKA_STORAGE_ERROR} + * {@link Errors#REPLICA_NOT_AVAILABLE} + * {@link Errors#UNKNOWN_SERVER_ERROR} + */ +public class AlterReplicaLogDirsResponse extends AbstractResponse { + + private final AlterReplicaLogDirsResponseData data; + + public AlterReplicaLogDirsResponse(AlterReplicaLogDirsResponseData data) { + super(ApiKeys.ALTER_REPLICA_LOG_DIRS); + this.data = data; + } + + @Override + public AlterReplicaLogDirsResponseData data() { + return data; + } + + @Override + public int throttleTimeMs() { + return data.throttleTimeMs(); + } + + @Override + public Map errorCounts() { + Map errorCounts = new HashMap<>(); + data.results().forEach(topicResult -> + topicResult.partitions().forEach(partitionResult -> + updateErrorCounts(errorCounts, Errors.forCode(partitionResult.errorCode())))); + return errorCounts; + } + + public static AlterReplicaLogDirsResponse parse(ByteBuffer buffer, short version) { + return new AlterReplicaLogDirsResponse(new AlterReplicaLogDirsResponseData(new ByteBufferAccessor(buffer), version)); + } + + @Override + public boolean shouldClientThrottle(short version) { + return version >= 1; + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/AlterUserScramCredentialsRequest.java b/clients/src/main/java/org/apache/kafka/common/requests/AlterUserScramCredentialsRequest.java new file mode 100644 index 0000000..1ca7ea7 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/AlterUserScramCredentialsRequest.java @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.message.AlterUserScramCredentialsRequestData; +import org.apache.kafka.common.message.AlterUserScramCredentialsResponseData; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; + +import java.nio.ByteBuffer; +import java.util.List; +import java.util.Set; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +public class AlterUserScramCredentialsRequest extends AbstractRequest { + + public static class Builder extends AbstractRequest.Builder { + private final AlterUserScramCredentialsRequestData data; + + public Builder(AlterUserScramCredentialsRequestData data) { + super(ApiKeys.ALTER_USER_SCRAM_CREDENTIALS); + this.data = data; + } + + @Override + public AlterUserScramCredentialsRequest build(short version) { + return new AlterUserScramCredentialsRequest(data, version); + } + + @Override + public String toString() { + return data.toString(); + } + } + + private final AlterUserScramCredentialsRequestData data; + + private AlterUserScramCredentialsRequest(AlterUserScramCredentialsRequestData data, short version) { + super(ApiKeys.ALTER_USER_SCRAM_CREDENTIALS, version); + this.data = data; + } + + public static AlterUserScramCredentialsRequest parse(ByteBuffer buffer, short version) { + return new AlterUserScramCredentialsRequest(new AlterUserScramCredentialsRequestData(new ByteBufferAccessor(buffer), version), version); + } + + @Override + public AlterUserScramCredentialsRequestData data() { + return data; + } + + @Override + public AbstractResponse getErrorResponse(int throttleTimeMs, Throwable e) { + ApiError apiError = ApiError.fromThrowable(e); + short errorCode = apiError.error().code(); + String errorMessage = apiError.message(); + Set users = Stream.concat( + this.data.deletions().stream().map(deletion -> deletion.name()), + this.data.upsertions().stream().map(upsertion -> upsertion.name())) + .collect(Collectors.toSet()); + List results = + users.stream().sorted().map(user -> + new AlterUserScramCredentialsResponseData.AlterUserScramCredentialsResult() + .setUser(user) + .setErrorCode(errorCode) + .setErrorMessage(errorMessage)) + .collect(Collectors.toList()); + return new AlterUserScramCredentialsResponse(new AlterUserScramCredentialsResponseData().setResults(results)); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/AlterUserScramCredentialsResponse.java b/clients/src/main/java/org/apache/kafka/common/requests/AlterUserScramCredentialsResponse.java new file mode 100644 index 0000000..97c0b7d --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/AlterUserScramCredentialsResponse.java @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.message.AlterUserScramCredentialsResponseData; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.protocol.Errors; + +import java.nio.ByteBuffer; +import java.util.Map; + +public class AlterUserScramCredentialsResponse extends AbstractResponse { + + private final AlterUserScramCredentialsResponseData data; + + public AlterUserScramCredentialsResponse(AlterUserScramCredentialsResponseData responseData) { + super(ApiKeys.ALTER_USER_SCRAM_CREDENTIALS); + this.data = responseData; + } + + @Override + public AlterUserScramCredentialsResponseData data() { + return data; + } + + @Override + public boolean shouldClientThrottle(short version) { + return true; + } + + @Override + public int throttleTimeMs() { + return data.throttleTimeMs(); + } + + @Override + public Map errorCounts() { + return errorCounts(data.results().stream().map(r -> Errors.forCode(r.errorCode()))); + } + + public static AlterUserScramCredentialsResponse parse(ByteBuffer buffer, short version) { + return new AlterUserScramCredentialsResponse(new AlterUserScramCredentialsResponseData(new ByteBufferAccessor(buffer), version)); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/ApiError.java b/clients/src/main/java/org/apache/kafka/common/requests/ApiError.java new file mode 100644 index 0000000..0196653 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/ApiError.java @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.errors.ApiException; +import org.apache.kafka.common.protocol.Errors; + +import java.util.Objects; +import java.util.concurrent.CompletionException; +import java.util.concurrent.ExecutionException; + +/** + * Encapsulates an error code (via the Errors enum) and an optional message. Generally, the optional message is only + * defined if it adds information over the default message associated with the error code. + * + * This is an internal class (like every class in the requests package). + */ +public class ApiError { + + public static final ApiError NONE = new ApiError(Errors.NONE, null); + + private final Errors error; + private final String message; + + public static ApiError fromThrowable(Throwable t) { + Throwable throwableToBeEncoded = t; + // Get the underlying cause for common exception types from the concurrent library. 
+ // This is useful to handle cases where exceptions may be raised from a future or a + // completion stage (as might be the case for requests sent to the controller in `ControllerApis`) + if (t instanceof CompletionException || t instanceof ExecutionException) { + throwableToBeEncoded = t.getCause(); + } + // Avoid populating the error message if it's a generic one. Also don't populate error + // message for UNKNOWN_SERVER_ERROR to ensure we don't leak sensitive information. + Errors error = Errors.forException(throwableToBeEncoded); + String message = error == Errors.UNKNOWN_SERVER_ERROR || + error.message().equals(throwableToBeEncoded.getMessage()) ? null : throwableToBeEncoded.getMessage(); + return new ApiError(error, message); + } + + public ApiError(Errors error) { + this(error, error.message()); + } + + public ApiError(Errors error, String message) { + this.error = error; + this.message = message; + } + + public ApiError(short code, String message) { + this.error = Errors.forCode(code); + this.message = message; + } + + public boolean is(Errors error) { + return this.error == error; + } + + public boolean isFailure() { + return !isSuccess(); + } + + public boolean isSuccess() { + return is(Errors.NONE); + } + + public Errors error() { + return error; + } + + /** + * Return the optional error message or null. Consider using {@link #messageWithFallback()} instead. + */ + public String message() { + return message; + } + + /** + * If `message` is defined, return it. Otherwise fallback to the default error message associated with the error + * code. + */ + public String messageWithFallback() { + if (message == null) + return error.message(); + return message; + } + + public ApiException exception() { + return error.exception(message); + } + + @Override + public int hashCode() { + return Objects.hash(error, message); + } + + @Override + public boolean equals(Object o) { + if (!(o instanceof ApiError)) { + return false; + } + ApiError other = (ApiError) o; + return Objects.equals(error, other.error) && + Objects.equals(message, other.message); + } + + @Override + public String toString() { + return "ApiError(error=" + error + ", message=" + message + ")"; + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/ApiVersionsRequest.java b/clients/src/main/java/org/apache/kafka/common/requests/ApiVersionsRequest.java new file mode 100644 index 0000000..4ae819c --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/ApiVersionsRequest.java @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
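As an illustrative aside, the sketch below exercises the unwrapping behaviour described in the fromThrowable comment above; the exception type and message are example values, and the class is a standalone sketch rather than part of this patch.

import org.apache.kafka.common.errors.InvalidRequestException;
import org.apache.kafka.common.protocol.Errors;
import org.apache.kafka.common.requests.ApiError;

import java.util.concurrent.CompletionException;

class ApiErrorSketch {
    static void demo() {
        ApiError error = ApiError.fromThrowable(
            new CompletionException(new InvalidRequestException("malformed alter-configs payload")));
        // The CompletionException wrapper is stripped, so the encoded error is INVALID_REQUEST,
        // and the custom message is kept because it differs from the default message for that code.
        assert error.error() == Errors.INVALID_REQUEST;
        assert "malformed alter-configs payload".equals(error.message());
    }
}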
+ */ +package org.apache.kafka.common.requests; + +import java.util.regex.Pattern; +import org.apache.kafka.common.message.ApiVersionsRequestData; +import org.apache.kafka.common.message.ApiVersionsResponseData; +import org.apache.kafka.common.message.ApiVersionsResponseData.ApiVersionCollection; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.protocol.Errors; +import org.apache.kafka.common.utils.AppInfoParser; + +import java.nio.ByteBuffer; + +public class ApiVersionsRequest extends AbstractRequest { + + public static class Builder extends AbstractRequest.Builder { + private static final String DEFAULT_CLIENT_SOFTWARE_NAME = "apache-kafka-java"; + + private static final ApiVersionsRequestData DATA = new ApiVersionsRequestData() + .setClientSoftwareName(DEFAULT_CLIENT_SOFTWARE_NAME) + .setClientSoftwareVersion(AppInfoParser.getVersion()); + + public Builder() { + super(ApiKeys.API_VERSIONS); + } + + public Builder(short version) { + super(ApiKeys.API_VERSIONS, version); + } + + @Override + public ApiVersionsRequest build(short version) { + return new ApiVersionsRequest(DATA, version); + } + + @Override + public String toString() { + return DATA.toString(); + } + } + + private static final Pattern SOFTWARE_NAME_VERSION_PATTERN = Pattern.compile("[a-zA-Z0-9](?:[a-zA-Z0-9\\-.]*[a-zA-Z0-9])?"); + + private final Short unsupportedRequestVersion; + + private final ApiVersionsRequestData data; + + public ApiVersionsRequest(ApiVersionsRequestData data, short version) { + this(data, version, null); + } + + public ApiVersionsRequest(ApiVersionsRequestData data, short version, Short unsupportedRequestVersion) { + super(ApiKeys.API_VERSIONS, version); + this.data = data; + + // Unlike other request types, the broker handles ApiVersion requests with higher versions than + // supported. It does so by treating the request as if it were v0 and returns a response using + // the v0 response schema. The reason for this is that the client does not yet know what versions + // a broker supports when this request is sent, so instead of assuming the lowest supported version, + // it can use the most recent version and only fallback to the old version when necessary. + this.unsupportedRequestVersion = unsupportedRequestVersion; + } + + public boolean hasUnsupportedRequestVersion() { + return unsupportedRequestVersion != null; + } + + public boolean isValid() { + if (version() >= 3) { + return SOFTWARE_NAME_VERSION_PATTERN.matcher(data.clientSoftwareName()).matches() && + SOFTWARE_NAME_VERSION_PATTERN.matcher(data.clientSoftwareVersion()).matches(); + } else { + return true; + } + } + + @Override + public ApiVersionsRequestData data() { + return data; + } + + @Override + public ApiVersionsResponse getErrorResponse(int throttleTimeMs, Throwable e) { + ApiVersionsResponseData data = new ApiVersionsResponseData() + .setErrorCode(Errors.forException(e).code()); + + if (version() >= 1) { + data.setThrottleTimeMs(throttleTimeMs); + } + + // Starting from Apache Kafka 2.4 (KIP-511), ApiKeys field is populated with the supported + // versions of the ApiVersionsRequest when an UNSUPPORTED_VERSION error is returned. 
+ if (Errors.forException(e) == Errors.UNSUPPORTED_VERSION) { + ApiVersionCollection apiKeys = new ApiVersionCollection(); + apiKeys.add(ApiVersionsResponse.toApiVersion(ApiKeys.API_VERSIONS)); + data.setApiKeys(apiKeys); + } + + return new ApiVersionsResponse(data); + } + + public static ApiVersionsRequest parse(ByteBuffer buffer, short version) { + return new ApiVersionsRequest(new ApiVersionsRequestData(new ByteBufferAccessor(buffer), version), version); + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/ApiVersionsResponse.java b/clients/src/main/java/org/apache/kafka/common/requests/ApiVersionsResponse.java new file mode 100644 index 0000000..1190989 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/ApiVersionsResponse.java @@ -0,0 +1,274 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.feature.Features; +import org.apache.kafka.common.feature.FinalizedVersionRange; +import org.apache.kafka.common.feature.SupportedVersionRange; +import org.apache.kafka.common.message.ApiMessageType; +import org.apache.kafka.common.message.ApiVersionsResponseData; +import org.apache.kafka.common.message.ApiVersionsResponseData.ApiVersion; +import org.apache.kafka.common.message.ApiVersionsResponseData.ApiVersionCollection; +import org.apache.kafka.common.message.ApiVersionsResponseData.FinalizedFeatureKey; +import org.apache.kafka.common.message.ApiVersionsResponseData.FinalizedFeatureKeyCollection; +import org.apache.kafka.common.message.ApiVersionsResponseData.SupportedFeatureKey; +import org.apache.kafka.common.message.ApiVersionsResponseData.SupportedFeatureKeyCollection; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.protocol.Errors; +import org.apache.kafka.common.record.RecordVersion; + +import java.nio.ByteBuffer; +import java.util.Map; +import java.util.Optional; +import java.util.Set; + +/** + * Possible error codes: + * - {@link Errors#UNSUPPORTED_VERSION} + * - {@link Errors#INVALID_REQUEST} + */ +public class ApiVersionsResponse extends AbstractResponse { + + public static final long UNKNOWN_FINALIZED_FEATURES_EPOCH = -1L; + + private final ApiVersionsResponseData data; + + public ApiVersionsResponse(ApiVersionsResponseData data) { + super(ApiKeys.API_VERSIONS); + this.data = data; + } + + @Override + public ApiVersionsResponseData data() { + return data; + } + + public ApiVersion apiVersion(short apiKey) { + return data.apiKeys().find(apiKey); + } + + @Override + public Map errorCounts() { + return errorCounts(Errors.forCode(this.data.errorCode())); + } + + @Override + public int throttleTimeMs() { + return 
this.data.throttleTimeMs(); + } + + @Override + public boolean shouldClientThrottle(short version) { + return version >= 2; + } + + public static ApiVersionsResponse parse(ByteBuffer buffer, short version) { + // Fallback to version 0 for ApiVersions response. If a client sends an ApiVersionsRequest + // using a version higher than that supported by the broker, a version 0 response is sent + // to the client indicating UNSUPPORTED_VERSION. When the client receives the response, it + // falls back while parsing it which means that the version received by this + // method is not necessarily the real one. It may be version 0 as well. + int prev = buffer.position(); + try { + return new ApiVersionsResponse(new ApiVersionsResponseData(new ByteBufferAccessor(buffer), version)); + } catch (RuntimeException e) { + buffer.position(prev); + if (version != 0) + return new ApiVersionsResponse(new ApiVersionsResponseData(new ByteBufferAccessor(buffer), (short) 0)); + else + throw e; + } + } + + public static ApiVersionsResponse defaultApiVersionsResponse( + ApiMessageType.ListenerType listenerType + ) { + return defaultApiVersionsResponse(0, listenerType); + } + + public static ApiVersionsResponse defaultApiVersionsResponse( + int throttleTimeMs, + ApiMessageType.ListenerType listenerType + ) { + return createApiVersionsResponse(throttleTimeMs, filterApis(RecordVersion.current(), listenerType)); + } + + public static ApiVersionsResponse createApiVersionsResponse( + int throttleTimeMs, + ApiVersionCollection apiVersions + ) { + return createApiVersionsResponse( + throttleTimeMs, + apiVersions, + Features.emptySupportedFeatures(), + Features.emptyFinalizedFeatures(), + UNKNOWN_FINALIZED_FEATURES_EPOCH + ); + } + + public static ApiVersionsResponse createApiVersionsResponse( + int throttleTimeMs, + ApiVersionCollection apiVersions, + Features latestSupportedFeatures, + Features finalizedFeatures, + long finalizedFeaturesEpoch + ) { + return new ApiVersionsResponse( + createApiVersionsResponseData( + throttleTimeMs, + Errors.NONE, + apiVersions, + latestSupportedFeatures, + finalizedFeatures, + finalizedFeaturesEpoch + ) + ); + } + + public static ApiVersionCollection filterApis( + RecordVersion minRecordVersion, + ApiMessageType.ListenerType listenerType + ) { + ApiVersionCollection apiKeys = new ApiVersionCollection(); + for (ApiKeys apiKey : ApiKeys.apisForListener(listenerType)) { + if (apiKey.minRequiredInterBrokerMagic <= minRecordVersion.value) { + apiKeys.add(ApiVersionsResponse.toApiVersion(apiKey)); + } + } + return apiKeys; + } + + public static ApiVersionCollection collectApis(Set apiKeys) { + ApiVersionCollection res = new ApiVersionCollection(); + for (ApiKeys apiKey : apiKeys) { + res.add(ApiVersionsResponse.toApiVersion(apiKey)); + } + return res; + } + + /** + * Find the common range of supported API versions between the locally + * known range and that of another set. 
+ * + * @param listenerType the listener type which constrains the set of exposed APIs + * @param minRecordVersion min inter broker magic + * @param activeControllerApiVersions controller ApiVersions + * @return commonly agreed ApiVersion collection + */ + public static ApiVersionCollection intersectForwardableApis( + final ApiMessageType.ListenerType listenerType, + final RecordVersion minRecordVersion, + final Map activeControllerApiVersions + ) { + ApiVersionCollection apiKeys = new ApiVersionCollection(); + for (ApiKeys apiKey : ApiKeys.apisForListener(listenerType)) { + if (apiKey.minRequiredInterBrokerMagic <= minRecordVersion.value) { + ApiVersion brokerApiVersion = toApiVersion(apiKey); + + final ApiVersion finalApiVersion; + if (!apiKey.forwardable) { + finalApiVersion = brokerApiVersion; + } else { + Optional intersectVersion = intersect(brokerApiVersion, + activeControllerApiVersions.getOrDefault(apiKey, null)); + if (intersectVersion.isPresent()) { + finalApiVersion = intersectVersion.get(); + } else { + // Controller doesn't support this API key, or there is no intersection. + continue; + } + } + + apiKeys.add(finalApiVersion.duplicate()); + } + } + return apiKeys; + } + + private static ApiVersionsResponseData createApiVersionsResponseData( + final int throttleTimeMs, + final Errors error, + final ApiVersionCollection apiKeys, + final Features latestSupportedFeatures, + final Features finalizedFeatures, + final long finalizedFeaturesEpoch + ) { + final ApiVersionsResponseData data = new ApiVersionsResponseData(); + data.setThrottleTimeMs(throttleTimeMs); + data.setErrorCode(error.code()); + data.setApiKeys(apiKeys); + data.setSupportedFeatures(createSupportedFeatureKeys(latestSupportedFeatures)); + data.setFinalizedFeatures(createFinalizedFeatureKeys(finalizedFeatures)); + data.setFinalizedFeaturesEpoch(finalizedFeaturesEpoch); + + return data; + } + + private static SupportedFeatureKeyCollection createSupportedFeatureKeys( + Features latestSupportedFeatures) { + SupportedFeatureKeyCollection converted = new SupportedFeatureKeyCollection(); + for (Map.Entry feature : latestSupportedFeatures.features().entrySet()) { + final SupportedFeatureKey key = new SupportedFeatureKey(); + final SupportedVersionRange versionRange = feature.getValue(); + key.setName(feature.getKey()); + key.setMinVersion(versionRange.min()); + key.setMaxVersion(versionRange.max()); + converted.add(key); + } + + return converted; + } + + private static FinalizedFeatureKeyCollection createFinalizedFeatureKeys( + Features finalizedFeatures) { + FinalizedFeatureKeyCollection converted = new FinalizedFeatureKeyCollection(); + for (Map.Entry feature : finalizedFeatures.features().entrySet()) { + final FinalizedFeatureKey key = new FinalizedFeatureKey(); + final FinalizedVersionRange versionLevelRange = feature.getValue(); + key.setName(feature.getKey()); + key.setMinVersionLevel(versionLevelRange.min()); + key.setMaxVersionLevel(versionLevelRange.max()); + converted.add(key); + } + + return converted; + } + + public static Optional intersect(ApiVersion thisVersion, + ApiVersion other) { + if (thisVersion == null || other == null) return Optional.empty(); + if (thisVersion.apiKey() != other.apiKey()) + throw new IllegalArgumentException("thisVersion.apiKey: " + thisVersion.apiKey() + + " must be equal to other.apiKey: " + other.apiKey()); + short minVersion = (short) Math.max(thisVersion.minVersion(), other.minVersion()); + short maxVersion = (short) Math.min(thisVersion.maxVersion(), other.maxVersion()); + 
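+        // When the two ranges do not overlap, there is no usable common version and the result is empty.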
return minVersion > maxVersion + ? Optional.empty() + : Optional.of(new ApiVersion() + .setApiKey(thisVersion.apiKey()) + .setMinVersion(minVersion) + .setMaxVersion(maxVersion)); + } + + public static ApiVersion toApiVersion(ApiKeys apiKey) { + return new ApiVersion() + .setApiKey(apiKey.id) + .setMinVersion(apiKey.oldestVersion()) + .setMaxVersion(apiKey.latestVersion()); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/BeginQuorumEpochRequest.java b/clients/src/main/java/org/apache/kafka/common/requests/BeginQuorumEpochRequest.java new file mode 100644 index 0000000..0794fb4 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/BeginQuorumEpochRequest.java @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.message.BeginQuorumEpochRequestData; +import org.apache.kafka.common.message.BeginQuorumEpochResponseData; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.protocol.Errors; + +import java.nio.ByteBuffer; +import java.util.Collections; + +public class BeginQuorumEpochRequest extends AbstractRequest { + public static class Builder extends AbstractRequest.Builder { + private final BeginQuorumEpochRequestData data; + + public Builder(BeginQuorumEpochRequestData data) { + super(ApiKeys.BEGIN_QUORUM_EPOCH); + this.data = data; + } + + @Override + public BeginQuorumEpochRequest build(short version) { + return new BeginQuorumEpochRequest(data, version); + } + + @Override + public String toString() { + return data.toString(); + } + } + + private final BeginQuorumEpochRequestData data; + + private BeginQuorumEpochRequest(BeginQuorumEpochRequestData data, short version) { + super(ApiKeys.BEGIN_QUORUM_EPOCH, version); + this.data = data; + } + + @Override + public BeginQuorumEpochRequestData data() { + return data; + } + + @Override + public BeginQuorumEpochResponse getErrorResponse(int throttleTimeMs, Throwable e) { + return new BeginQuorumEpochResponse(new BeginQuorumEpochResponseData() + .setErrorCode(Errors.forException(e).code())); + } + + public static BeginQuorumEpochRequest parse(ByteBuffer buffer, short version) { + return new BeginQuorumEpochRequest(new BeginQuorumEpochRequestData(new ByteBufferAccessor(buffer), version), version); + } + + public static BeginQuorumEpochRequestData singletonRequest(TopicPartition topicPartition, + int leaderEpoch, + int leaderId) { + return singletonRequest(topicPartition, null, leaderEpoch, leaderId); + } + + public static BeginQuorumEpochRequestData singletonRequest(TopicPartition topicPartition, + String clusterId, + int 
leaderEpoch, + int leaderId) { + return new BeginQuorumEpochRequestData() + .setClusterId(clusterId) + .setTopics(Collections.singletonList( + new BeginQuorumEpochRequestData.TopicData() + .setTopicName(topicPartition.topic()) + .setPartitions(Collections.singletonList( + new BeginQuorumEpochRequestData.PartitionData() + .setPartitionIndex(topicPartition.partition()) + .setLeaderEpoch(leaderEpoch) + .setLeaderId(leaderId)))) + ); + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/BeginQuorumEpochResponse.java b/clients/src/main/java/org/apache/kafka/common/requests/BeginQuorumEpochResponse.java new file mode 100644 index 0000000..c8c0328 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/BeginQuorumEpochResponse.java @@ -0,0 +1,102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.message.BeginQuorumEpochResponseData; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.protocol.Errors; + +import java.nio.ByteBuffer; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +/** + * Possible error codes. 
+ * + * Top level errors: + * - {@link Errors#CLUSTER_AUTHORIZATION_FAILED} + * - {@link Errors#BROKER_NOT_AVAILABLE} + * + * Partition level errors: + * - {@link Errors#FENCED_LEADER_EPOCH} + * - {@link Errors#INVALID_REQUEST} + * - {@link Errors#INCONSISTENT_VOTER_SET} + * - {@link Errors#UNKNOWN_TOPIC_OR_PARTITION} + */ +public class BeginQuorumEpochResponse extends AbstractResponse { + private final BeginQuorumEpochResponseData data; + + public BeginQuorumEpochResponse(BeginQuorumEpochResponseData data) { + super(ApiKeys.BEGIN_QUORUM_EPOCH); + this.data = data; + } + + public static BeginQuorumEpochResponseData singletonResponse( + Errors topLevelError, + TopicPartition topicPartition, + Errors partitionLevelError, + int leaderEpoch, + int leaderId + ) { + return new BeginQuorumEpochResponseData() + .setErrorCode(topLevelError.code()) + .setTopics(Collections.singletonList( + new BeginQuorumEpochResponseData.TopicData() + .setTopicName(topicPartition.topic()) + .setPartitions(Collections.singletonList( + new BeginQuorumEpochResponseData.PartitionData() + .setErrorCode(partitionLevelError.code()) + .setLeaderId(leaderId) + .setLeaderEpoch(leaderEpoch) + ))) + ); + } + + @Override + public Map errorCounts() { + Map errors = new HashMap<>(); + + errors.put(Errors.forCode(data.errorCode()), 1); + + for (BeginQuorumEpochResponseData.TopicData topicResponse : data.topics()) { + for (BeginQuorumEpochResponseData.PartitionData partitionResponse : topicResponse.partitions()) { + errors.compute(Errors.forCode(partitionResponse.errorCode()), + (error, count) -> count == null ? 1 : count + 1); + } + } + return errors; + } + + @Override + public BeginQuorumEpochResponseData data() { + return data; + } + + @Override + public int throttleTimeMs() { + return DEFAULT_THROTTLE_TIME; + } + + public static BeginQuorumEpochResponse parse(ByteBuffer buffer, short version) { + return new BeginQuorumEpochResponse(new BeginQuorumEpochResponseData(new ByteBufferAccessor(buffer), version)); + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/BrokerHeartbeatRequest.java b/clients/src/main/java/org/apache/kafka/common/requests/BrokerHeartbeatRequest.java new file mode 100644 index 0000000..3c3f350 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/BrokerHeartbeatRequest.java @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.message.BrokerHeartbeatRequestData; +import org.apache.kafka.common.message.BrokerHeartbeatResponseData; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.protocol.Errors; + +import java.nio.ByteBuffer; + +public class BrokerHeartbeatRequest extends AbstractRequest { + + public static class Builder extends AbstractRequest.Builder { + private final BrokerHeartbeatRequestData data; + + public Builder(BrokerHeartbeatRequestData data) { + super(ApiKeys.BROKER_HEARTBEAT); + this.data = data; + } + + @Override + public BrokerHeartbeatRequest build(short version) { + return new BrokerHeartbeatRequest(data, version); + } + + @Override + public String toString() { + return data.toString(); + } + } + + private final BrokerHeartbeatRequestData data; + + public BrokerHeartbeatRequest(BrokerHeartbeatRequestData data, short version) { + super(ApiKeys.BROKER_HEARTBEAT, version); + this.data = data; + } + + @Override + public BrokerHeartbeatRequestData data() { + return data; + } + + @Override + public BrokerHeartbeatResponse getErrorResponse(int throttleTimeMs, Throwable e) { + Errors error = Errors.forException(e); + return new BrokerHeartbeatResponse(new BrokerHeartbeatResponseData() + .setThrottleTimeMs(throttleTimeMs) + .setErrorCode(error.code())); + } + + public static BrokerHeartbeatRequest parse(ByteBuffer buffer, short version) { + return new BrokerHeartbeatRequest(new BrokerHeartbeatRequestData(new ByteBufferAccessor(buffer), version), + version); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/BrokerHeartbeatResponse.java b/clients/src/main/java/org/apache/kafka/common/requests/BrokerHeartbeatResponse.java new file mode 100644 index 0000000..e7d01e5 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/BrokerHeartbeatResponse.java @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
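A minimal sketch of the Builder and getErrorResponse flow in BrokerHeartbeatRequest above. It assumes only the generated no-argument BrokerHeartbeatRequestData constructor; the schema-derived field setters are not part of this diff, so the request is built with default field values.

    import org.apache.kafka.common.errors.ClusterAuthorizationException;
    import org.apache.kafka.common.message.BrokerHeartbeatRequestData;
    import org.apache.kafka.common.requests.BrokerHeartbeatRequest;
    import org.apache.kafka.common.requests.BrokerHeartbeatResponse;

    public class BrokerHeartbeatRequestExample {
        public static void main(String[] args) {
            // Build a heartbeat request with default field values.
            BrokerHeartbeatRequest request =
                    new BrokerHeartbeatRequest.Builder(new BrokerHeartbeatRequestData()).build((short) 0);

            // Translate an exception into the wire-level error response, as getErrorResponse does.
            BrokerHeartbeatResponse error =
                    request.getErrorResponse(100, new ClusterAuthorizationException("not authorized"));
            System.out.println(error.errorCounts());    // {CLUSTER_AUTHORIZATION_FAILED=1}
            System.out.println(error.throttleTimeMs()); // 100
        }
    }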
+ */ + +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.message.BrokerHeartbeatResponseData; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.protocol.Errors; + +import java.nio.ByteBuffer; +import java.util.HashMap; +import java.util.Map; + +public class BrokerHeartbeatResponse extends AbstractResponse { + private final BrokerHeartbeatResponseData data; + + public BrokerHeartbeatResponse(BrokerHeartbeatResponseData data) { + super(ApiKeys.BROKER_HEARTBEAT); + this.data = data; + } + + @Override + public BrokerHeartbeatResponseData data() { + return data; + } + + @Override + public int throttleTimeMs() { + return data.throttleTimeMs(); + } + + @Override + public Map errorCounts() { + Map errorCounts = new HashMap<>(); + errorCounts.put(Errors.forCode(data.errorCode()), 1); + return errorCounts; + } + + public static BrokerHeartbeatResponse parse(ByteBuffer buffer, short version) { + return new BrokerHeartbeatResponse(new BrokerHeartbeatResponseData(new ByteBufferAccessor(buffer), version)); + } + + @Override + public boolean shouldClientThrottle(short version) { + return true; + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/BrokerRegistrationRequest.java b/clients/src/main/java/org/apache/kafka/common/requests/BrokerRegistrationRequest.java new file mode 100644 index 0000000..2ba1529 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/BrokerRegistrationRequest.java @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.message.BrokerRegistrationRequestData; +import org.apache.kafka.common.message.BrokerRegistrationResponseData; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.protocol.Errors; + +import java.nio.ByteBuffer; + +public class BrokerRegistrationRequest extends AbstractRequest { + + public static class Builder extends AbstractRequest.Builder { + private final BrokerRegistrationRequestData data; + + public Builder(BrokerRegistrationRequestData data) { + super(ApiKeys.BROKER_REGISTRATION); + this.data = data; + } + + @Override + public BrokerRegistrationRequest build(short version) { + return new BrokerRegistrationRequest(data, version); + } + + @Override + public String toString() { + return data.toString(); + } + } + + private final BrokerRegistrationRequestData data; + + public BrokerRegistrationRequest(BrokerRegistrationRequestData data, short version) { + super(ApiKeys.BROKER_REGISTRATION, version); + this.data = data; + } + + @Override + public BrokerRegistrationRequestData data() { + return data; + } + + @Override + public BrokerRegistrationResponse getErrorResponse(int throttleTimeMs, Throwable e) { + Errors error = Errors.forException(e); + return new BrokerRegistrationResponse(new BrokerRegistrationResponseData() + .setThrottleTimeMs(throttleTimeMs) + .setErrorCode(error.code())); + } + + public static BrokerRegistrationRequest parse(ByteBuffer buffer, short version) { + return new BrokerRegistrationRequest(new BrokerRegistrationRequestData(new ByteBufferAccessor(buffer), version), + version); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/BrokerRegistrationResponse.java b/clients/src/main/java/org/apache/kafka/common/requests/BrokerRegistrationResponse.java new file mode 100644 index 0000000..8296d7a --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/BrokerRegistrationResponse.java @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.message.BrokerRegistrationResponseData; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.protocol.Errors; + +import java.nio.ByteBuffer; +import java.util.HashMap; +import java.util.Map; + +public class BrokerRegistrationResponse extends AbstractResponse { + private final BrokerRegistrationResponseData data; + + public BrokerRegistrationResponse(BrokerRegistrationResponseData data) { + super(ApiKeys.BROKER_REGISTRATION); + this.data = data; + } + + @Override + public BrokerRegistrationResponseData data() { + return data; + } + + @Override + public int throttleTimeMs() { + return data.throttleTimeMs(); + } + + @Override + public Map errorCounts() { + Map errorCounts = new HashMap<>(); + errorCounts.put(Errors.forCode(data.errorCode()), 1); + return errorCounts; + } + + public static BrokerRegistrationResponse parse(ByteBuffer buffer, short version) { + return new BrokerRegistrationResponse(new BrokerRegistrationResponseData(new ByteBufferAccessor(buffer), version)); + } + + @Override + public boolean shouldClientThrottle(short version) { + return true; + } + + @Override + public String toString() { + return data.toString(); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/ControlledShutdownRequest.java b/clients/src/main/java/org/apache/kafka/common/requests/ControlledShutdownRequest.java new file mode 100644 index 0000000..088c351 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/ControlledShutdownRequest.java @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.message.ControlledShutdownRequestData; +import org.apache.kafka.common.message.ControlledShutdownResponseData; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.protocol.Errors; + +import java.nio.ByteBuffer; + +public class ControlledShutdownRequest extends AbstractRequest { + + public static class Builder extends AbstractRequest.Builder { + + private final ControlledShutdownRequestData data; + + public Builder(ControlledShutdownRequestData data, short desiredVersion) { + super(ApiKeys.CONTROLLED_SHUTDOWN, desiredVersion); + this.data = data; + } + + @Override + public ControlledShutdownRequest build(short version) { + return new ControlledShutdownRequest(data, version); + } + + @Override + public String toString() { + return data.toString(); + } + } + + private final ControlledShutdownRequestData data; + + private ControlledShutdownRequest(ControlledShutdownRequestData data, short version) { + super(ApiKeys.CONTROLLED_SHUTDOWN, version); + this.data = data; + } + + @Override + public ControlledShutdownResponse getErrorResponse(int throttleTimeMs, Throwable e) { + ControlledShutdownResponseData data = new ControlledShutdownResponseData() + .setErrorCode(Errors.forException(e).code()); + return new ControlledShutdownResponse(data); + } + + public static ControlledShutdownRequest parse(ByteBuffer buffer, short version) { + return new ControlledShutdownRequest(new ControlledShutdownRequestData(new ByteBufferAccessor(buffer), version), + version); + } + + @Override + public ControlledShutdownRequestData data() { + return data; + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/ControlledShutdownResponse.java b/clients/src/main/java/org/apache/kafka/common/requests/ControlledShutdownResponse.java new file mode 100644 index 0000000..73b6a50 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/ControlledShutdownResponse.java @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.message.ControlledShutdownResponseData; +import org.apache.kafka.common.message.ControlledShutdownResponseData.RemainingPartition; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.protocol.Errors; + +import java.nio.ByteBuffer; +import java.util.Map; +import java.util.Set; + + +public class ControlledShutdownResponse extends AbstractResponse { + + /** + * Possible error codes: + * + * UNKNOWN(-1) (this is because IllegalStateException may be thrown in `KafkaController.shutdownBroker`, it would be good to improve this) + * BROKER_NOT_AVAILABLE(8) + * STALE_CONTROLLER_EPOCH(11) + */ + private final ControlledShutdownResponseData data; + + public ControlledShutdownResponse(ControlledShutdownResponseData data) { + super(ApiKeys.CONTROLLED_SHUTDOWN); + this.data = data; + } + + public Errors error() { + return Errors.forCode(data.errorCode()); + } + + @Override + public Map errorCounts() { + return errorCounts(error()); + } + + @Override + public int throttleTimeMs() { + return DEFAULT_THROTTLE_TIME; + } + + public static ControlledShutdownResponse parse(ByteBuffer buffer, short version) { + return new ControlledShutdownResponse(new ControlledShutdownResponseData(new ByteBufferAccessor(buffer), version)); + } + + @Override + public ControlledShutdownResponseData data() { + return data; + } + + public static ControlledShutdownResponse prepareResponse(Errors error, Set tps) { + ControlledShutdownResponseData data = new ControlledShutdownResponseData(); + data.setErrorCode(error.code()); + ControlledShutdownResponseData.RemainingPartitionCollection pSet = new ControlledShutdownResponseData.RemainingPartitionCollection(); + tps.forEach(tp -> { + pSet.add(new RemainingPartition() + .setTopicName(tp.topic()) + .setPartitionIndex(tp.partition())); + }); + data.setRemainingPartitions(pSet); + return new ControlledShutdownResponse(data); + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/CorrelationIdMismatchException.java b/clients/src/main/java/org/apache/kafka/common/requests/CorrelationIdMismatchException.java new file mode 100644 index 0000000..2610a27 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/CorrelationIdMismatchException.java @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.requests; + +/** + * Raised if the correlationId in a response header does not match + * the expected value from the request header. 
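An illustrative use of ControlledShutdownResponse.prepareResponse above, reporting partitions that still need new leaders before shutdown can complete. The topic name is a placeholder and the generated remainingPartitions() accessor is assumed from the response schema.

    import org.apache.kafka.common.TopicPartition;
    import org.apache.kafka.common.protocol.Errors;
    import org.apache.kafka.common.requests.ControlledShutdownResponse;

    import java.util.Collections;
    import java.util.Set;

    public class ControlledShutdownResponseExample {
        public static void main(String[] args) {
            Set<TopicPartition> remaining = Collections.singleton(new TopicPartition("orders", 3));
            ControlledShutdownResponse response =
                    ControlledShutdownResponse.prepareResponse(Errors.NONE, remaining);

            System.out.println(response.error());                             // NONE
            System.out.println(response.data().remainingPartitions().size()); // 1 (assumed generated accessor)
        }
    }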
+ */ +public class CorrelationIdMismatchException extends IllegalStateException { + private final int requestCorrelationId; + private final int responseCorrelationId; + + public CorrelationIdMismatchException( + String message, + int requestCorrelationId, + int responseCorrelationId + ) { + super(message); + this.requestCorrelationId = requestCorrelationId; + this.responseCorrelationId = responseCorrelationId; + } + + public int requestCorrelationId() { + return requestCorrelationId; + } + + public int responseCorrelationId() { + return responseCorrelationId; + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/CreateAclsRequest.java b/clients/src/main/java/org/apache/kafka/common/requests/CreateAclsRequest.java new file mode 100644 index 0000000..29df832 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/CreateAclsRequest.java @@ -0,0 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.acl.AccessControlEntry; +import org.apache.kafka.common.acl.AclBinding; +import org.apache.kafka.common.acl.AclOperation; +import org.apache.kafka.common.acl.AclPermissionType; +import org.apache.kafka.common.errors.UnsupportedVersionException; +import org.apache.kafka.common.message.CreateAclsRequestData; +import org.apache.kafka.common.message.CreateAclsRequestData.AclCreation; +import org.apache.kafka.common.message.CreateAclsResponseData; +import org.apache.kafka.common.message.CreateAclsResponseData.AclCreationResult; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.resource.PatternType; +import org.apache.kafka.common.resource.ResourcePattern; +import org.apache.kafka.common.resource.ResourceType; + +import java.nio.ByteBuffer; +import java.util.Collections; +import java.util.List; + +public class CreateAclsRequest extends AbstractRequest { + + public static class Builder extends AbstractRequest.Builder { + private final CreateAclsRequestData data; + + public Builder(CreateAclsRequestData data) { + super(ApiKeys.CREATE_ACLS); + this.data = data; + } + + @Override + public CreateAclsRequest build(short version) { + return new CreateAclsRequest(data, version); + } + + @Override + public String toString() { + return data.toString(); + } + } + + private final CreateAclsRequestData data; + + CreateAclsRequest(CreateAclsRequestData data, short version) { + super(ApiKeys.CREATE_ACLS, version); + validate(data); + this.data = data; + } + + public List aclCreations() { + return data.creations(); + } + + @Override + public CreateAclsRequestData data() { + return data; + } + + @Override + public AbstractResponse getErrorResponse(int 
throttleTimeMs, Throwable throwable) { + CreateAclsResponseData.AclCreationResult result = CreateAclsRequest.aclResult(throwable); + List results = Collections.nCopies(data.creations().size(), result); + return new CreateAclsResponse(new CreateAclsResponseData() + .setThrottleTimeMs(throttleTimeMs) + .setResults(results)); + } + + public static CreateAclsRequest parse(ByteBuffer buffer, short version) { + return new CreateAclsRequest(new CreateAclsRequestData(new ByteBufferAccessor(buffer), version), version); + } + + private void validate(CreateAclsRequestData data) { + if (version() == 0) { + final boolean unsupported = data.creations().stream().anyMatch(creation -> + creation.resourcePatternType() != PatternType.LITERAL.code()); + if (unsupported) + throw new UnsupportedVersionException("Version 0 only supports literal resource pattern types"); + } + + final boolean unknown = data.creations().stream().anyMatch(creation -> + creation.resourcePatternType() == PatternType.UNKNOWN.code() + || creation.resourceType() == ResourceType.UNKNOWN.code() + || creation.permissionType() == AclPermissionType.UNKNOWN.code() + || creation.operation() == AclOperation.UNKNOWN.code()); + if (unknown) + throw new IllegalArgumentException("CreatableAcls contain unknown elements: " + data.creations()); + } + + public static AclBinding aclBinding(AclCreation acl) { + ResourcePattern pattern = new ResourcePattern( + ResourceType.fromCode(acl.resourceType()), + acl.resourceName(), + PatternType.fromCode(acl.resourcePatternType())); + AccessControlEntry entry = new AccessControlEntry( + acl.principal(), + acl.host(), + AclOperation.fromCode(acl.operation()), + AclPermissionType.fromCode(acl.permissionType())); + return new AclBinding(pattern, entry); + } + + public static AclCreation aclCreation(AclBinding binding) { + return new AclCreation() + .setHost(binding.entry().host()) + .setOperation(binding.entry().operation().code()) + .setPermissionType(binding.entry().permissionType().code()) + .setPrincipal(binding.entry().principal()) + .setResourceName(binding.pattern().name()) + .setResourceType(binding.pattern().resourceType().code()) + .setResourcePatternType(binding.pattern().patternType().code()); + } + + private static AclCreationResult aclResult(Throwable throwable) { + ApiError apiError = ApiError.fromThrowable(throwable); + return new AclCreationResult() + .setErrorCode(apiError.error().code()) + .setErrorMessage(apiError.message()); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/CreateAclsResponse.java b/clients/src/main/java/org/apache/kafka/common/requests/CreateAclsResponse.java new file mode 100644 index 0000000..8bc6643 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/CreateAclsResponse.java @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
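A sketch of the AclBinding/AclCreation conversion helpers in CreateAclsRequest above; the resource and principal names are placeholders.

    import org.apache.kafka.common.acl.AccessControlEntry;
    import org.apache.kafka.common.acl.AclBinding;
    import org.apache.kafka.common.acl.AclOperation;
    import org.apache.kafka.common.acl.AclPermissionType;
    import org.apache.kafka.common.message.CreateAclsRequestData.AclCreation;
    import org.apache.kafka.common.requests.CreateAclsRequest;
    import org.apache.kafka.common.resource.PatternType;
    import org.apache.kafka.common.resource.ResourcePattern;
    import org.apache.kafka.common.resource.ResourceType;

    public class CreateAclsConversionExample {
        public static void main(String[] args) {
            AclBinding binding = new AclBinding(
                    new ResourcePattern(ResourceType.TOPIC, "orders", PatternType.LITERAL),
                    new AccessControlEntry("User:alice", "*", AclOperation.READ, AclPermissionType.ALLOW));

            // Domain object -> wire representation, and back again.
            AclCreation creation = CreateAclsRequest.aclCreation(binding);
            AclBinding roundTripped = CreateAclsRequest.aclBinding(creation);
            System.out.println(binding.equals(roundTripped)); // true
        }
    }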
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.message.CreateAclsResponseData; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.protocol.Errors; + +import java.nio.ByteBuffer; +import java.util.List; +import java.util.Map; + +public class CreateAclsResponse extends AbstractResponse { + private final CreateAclsResponseData data; + + public CreateAclsResponse(CreateAclsResponseData data) { + super(ApiKeys.CREATE_ACLS); + this.data = data; + } + + @Override + public CreateAclsResponseData data() { + return data; + } + + @Override + public int throttleTimeMs() { + return data.throttleTimeMs(); + } + + public List results() { + return data.results(); + } + + @Override + public Map errorCounts() { + return errorCounts(results().stream().map(r -> Errors.forCode(r.errorCode()))); + } + + public static CreateAclsResponse parse(ByteBuffer buffer, short version) { + return new CreateAclsResponse(new CreateAclsResponseData(new ByteBufferAccessor(buffer), version)); + } + + @Override + public boolean shouldClientThrottle(short version) { + return version >= 1; + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/CreateDelegationTokenRequest.java b/clients/src/main/java/org/apache/kafka/common/requests/CreateDelegationTokenRequest.java new file mode 100644 index 0000000..1fee1b7 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/CreateDelegationTokenRequest.java @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
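A sketch of how per-creation results aggregate into CreateAclsResponse.errorCounts() above. It relies only on setters already used in this diff (setThrottleTimeMs, setResults, setErrorCode, setErrorMessage) plus the assumption that an unset error code defaults to 0 (NONE) in the generated class.

    import org.apache.kafka.common.message.CreateAclsResponseData;
    import org.apache.kafka.common.message.CreateAclsResponseData.AclCreationResult;
    import org.apache.kafka.common.protocol.Errors;
    import org.apache.kafka.common.requests.CreateAclsResponse;

    import java.util.Arrays;

    public class CreateAclsResponseExample {
        public static void main(String[] args) {
            CreateAclsResponse response = new CreateAclsResponse(new CreateAclsResponseData()
                    .setThrottleTimeMs(0)
                    .setResults(Arrays.asList(
                            new AclCreationResult(),  // error code assumed to default to 0 (NONE)
                            new AclCreationResult()
                                    .setErrorCode(Errors.CLUSTER_AUTHORIZATION_FAILED.code())
                                    .setErrorMessage("not authorized"))));

            // e.g. {NONE=1, CLUSTER_AUTHORIZATION_FAILED=1}
            System.out.println(response.errorCounts());
        }
    }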
+ */ +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.message.CreateDelegationTokenRequestData; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.protocol.Errors; +import org.apache.kafka.common.security.auth.KafkaPrincipal; + +import java.nio.ByteBuffer; + +public class CreateDelegationTokenRequest extends AbstractRequest { + + private final CreateDelegationTokenRequestData data; + + private CreateDelegationTokenRequest(CreateDelegationTokenRequestData data, short version) { + super(ApiKeys.CREATE_DELEGATION_TOKEN, version); + this.data = data; + } + + public static CreateDelegationTokenRequest parse(ByteBuffer buffer, short version) { + return new CreateDelegationTokenRequest(new CreateDelegationTokenRequestData(new ByteBufferAccessor(buffer), version), + version); + } + + @Override + public CreateDelegationTokenRequestData data() { + return data; + } + + @Override + public AbstractResponse getErrorResponse(int throttleTimeMs, Throwable e) { + return CreateDelegationTokenResponse.prepareResponse(throttleTimeMs, Errors.forException(e), KafkaPrincipal.ANONYMOUS); + } + + public static class Builder extends AbstractRequest.Builder { + private final CreateDelegationTokenRequestData data; + + public Builder(CreateDelegationTokenRequestData data) { + super(ApiKeys.CREATE_DELEGATION_TOKEN); + this.data = data; + } + + @Override + public CreateDelegationTokenRequest build(short version) { + return new CreateDelegationTokenRequest(data, version); + } + + @Override + public String toString() { + return data.toString(); + } + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/CreateDelegationTokenResponse.java b/clients/src/main/java/org/apache/kafka/common/requests/CreateDelegationTokenResponse.java new file mode 100644 index 0000000..b679a30 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/CreateDelegationTokenResponse.java @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.message.CreateDelegationTokenResponseData; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.protocol.Errors; +import org.apache.kafka.common.security.auth.KafkaPrincipal; + +import java.nio.ByteBuffer; +import java.util.Map; + +public class CreateDelegationTokenResponse extends AbstractResponse { + + private final CreateDelegationTokenResponseData data; + + public CreateDelegationTokenResponse(CreateDelegationTokenResponseData data) { + super(ApiKeys.CREATE_DELEGATION_TOKEN); + this.data = data; + } + + public static CreateDelegationTokenResponse parse(ByteBuffer buffer, short version) { + return new CreateDelegationTokenResponse( + new CreateDelegationTokenResponseData(new ByteBufferAccessor(buffer), version)); + } + + public static CreateDelegationTokenResponse prepareResponse(int throttleTimeMs, + Errors error, + KafkaPrincipal owner, + long issueTimestamp, + long expiryTimestamp, + long maxTimestamp, + String tokenId, + ByteBuffer hmac) { + CreateDelegationTokenResponseData data = new CreateDelegationTokenResponseData() + .setThrottleTimeMs(throttleTimeMs) + .setErrorCode(error.code()) + .setPrincipalType(owner.getPrincipalType()) + .setPrincipalName(owner.getName()) + .setIssueTimestampMs(issueTimestamp) + .setExpiryTimestampMs(expiryTimestamp) + .setMaxTimestampMs(maxTimestamp) + .setTokenId(tokenId) + .setHmac(hmac.array()); + return new CreateDelegationTokenResponse(data); + } + + public static CreateDelegationTokenResponse prepareResponse(int throttleTimeMs, Errors error, KafkaPrincipal owner) { + return prepareResponse(throttleTimeMs, error, owner, -1, -1, -1, "", ByteBuffer.wrap(new byte[] {})); + } + + @Override + public CreateDelegationTokenResponseData data() { + return data; + } + + @Override + public Map errorCounts() { + return errorCounts(error()); + } + + @Override + public int throttleTimeMs() { + return data.throttleTimeMs(); + } + + public Errors error() { + return Errors.forCode(data.errorCode()); + } + + public boolean hasError() { + return error() != Errors.NONE; + } + + @Override + public boolean shouldClientThrottle(short version) { + return version >= 1; + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/CreatePartitionsRequest.java b/clients/src/main/java/org/apache/kafka/common/requests/CreatePartitionsRequest.java new file mode 100644 index 0000000..d371bbb --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/CreatePartitionsRequest.java @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
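An illustrative use of the CreateDelegationTokenResponse.prepareResponse overloads above; the principal name is a placeholder.

    import org.apache.kafka.common.protocol.Errors;
    import org.apache.kafka.common.requests.CreateDelegationTokenResponse;
    import org.apache.kafka.common.security.auth.KafkaPrincipal;

    public class CreateDelegationTokenResponseExample {
        public static void main(String[] args) {
            KafkaPrincipal owner = new KafkaPrincipal(KafkaPrincipal.USER_TYPE, "alice");

            // Error path: the short overload fills in empty token fields.
            CreateDelegationTokenResponse failure = CreateDelegationTokenResponse.prepareResponse(
                    0, Errors.DELEGATION_TOKEN_AUTH_DISABLED, owner);
            System.out.println(failure.hasError()); // true
            System.out.println(failure.error());    // DELEGATION_TOKEN_AUTH_DISABLED
        }
    }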
+ */ + +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.message.CreatePartitionsRequestData; +import org.apache.kafka.common.message.CreatePartitionsRequestData.CreatePartitionsTopic; +import org.apache.kafka.common.message.CreatePartitionsResponseData; +import org.apache.kafka.common.message.CreatePartitionsResponseData.CreatePartitionsTopicResult; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; + +import java.nio.ByteBuffer; + +public class CreatePartitionsRequest extends AbstractRequest { + + private final CreatePartitionsRequestData data; + + public static class Builder extends AbstractRequest.Builder { + + private final CreatePartitionsRequestData data; + + public Builder(CreatePartitionsRequestData data) { + super(ApiKeys.CREATE_PARTITIONS); + this.data = data; + } + + @Override + public CreatePartitionsRequest build(short version) { + return new CreatePartitionsRequest(data, version); + } + + @Override + public String toString() { + return data.toString(); + } + } + + CreatePartitionsRequest(CreatePartitionsRequestData data, short apiVersion) { + super(ApiKeys.CREATE_PARTITIONS, apiVersion); + this.data = data; + } + + @Override + public CreatePartitionsRequestData data() { + return data; + } + + @Override + public AbstractResponse getErrorResponse(int throttleTimeMs, Throwable e) { + CreatePartitionsResponseData response = new CreatePartitionsResponseData(); + response.setThrottleTimeMs(throttleTimeMs); + + ApiError apiError = ApiError.fromThrowable(e); + for (CreatePartitionsTopic topic : data.topics()) { + response.results().add(new CreatePartitionsTopicResult() + .setName(topic.name()) + .setErrorCode(apiError.error().code()) + .setErrorMessage(apiError.message()) + ); + } + return new CreatePartitionsResponse(response); + } + + public static CreatePartitionsRequest parse(ByteBuffer buffer, short version) { + return new CreatePartitionsRequest(new CreatePartitionsRequestData(new ByteBufferAccessor(buffer), version), version); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/CreatePartitionsResponse.java b/clients/src/main/java/org/apache/kafka/common/requests/CreatePartitionsResponse.java new file mode 100644 index 0000000..e59ac98 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/CreatePartitionsResponse.java @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.message.CreatePartitionsResponseData; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.protocol.Errors; + +import java.nio.ByteBuffer; +import java.util.HashMap; +import java.util.Map; + +public class CreatePartitionsResponse extends AbstractResponse { + + private final CreatePartitionsResponseData data; + + public CreatePartitionsResponse(CreatePartitionsResponseData data) { + super(ApiKeys.CREATE_PARTITIONS); + this.data = data; + } + + @Override + public CreatePartitionsResponseData data() { + return data; + } + + @Override + public Map errorCounts() { + Map counts = new HashMap<>(); + data.results().forEach(result -> + updateErrorCounts(counts, Errors.forCode(result.errorCode())) + ); + return counts; + } + + public static CreatePartitionsResponse parse(ByteBuffer buffer, short version) { + return new CreatePartitionsResponse(new CreatePartitionsResponseData(new ByteBufferAccessor(buffer), version)); + } + + @Override + public boolean shouldClientThrottle(short version) { + return version >= 1; + } + + @Override + public int throttleTimeMs() { + return data.throttleTimeMs(); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/CreateTopicsRequest.java b/clients/src/main/java/org/apache/kafka/common/requests/CreateTopicsRequest.java new file mode 100644 index 0000000..a003c2d --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/CreateTopicsRequest.java @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.requests; + +import java.nio.ByteBuffer; +import java.util.List; +import java.util.stream.Collectors; +import org.apache.kafka.common.errors.UnsupportedVersionException; +import org.apache.kafka.common.message.CreateTopicsRequestData; +import org.apache.kafka.common.message.CreateTopicsRequestData.CreatableTopic; +import org.apache.kafka.common.message.CreateTopicsResponseData; +import org.apache.kafka.common.message.CreateTopicsResponseData.CreatableTopicResult; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; + +public class CreateTopicsRequest extends AbstractRequest { + public static class Builder extends AbstractRequest.Builder { + private final CreateTopicsRequestData data; + + public Builder(CreateTopicsRequestData data) { + super(ApiKeys.CREATE_TOPICS); + this.data = data; + } + + @Override + public CreateTopicsRequest build(short version) { + if (data.validateOnly() && version == 0) + throw new UnsupportedVersionException("validateOnly is not supported in version 0 of " + + "CreateTopicsRequest"); + + final List topicsWithDefaults = data.topics() + .stream() + .filter(topic -> topic.assignments().isEmpty()) + .filter(topic -> + topic.numPartitions() == CreateTopicsRequest.NO_NUM_PARTITIONS + || topic.replicationFactor() == CreateTopicsRequest.NO_REPLICATION_FACTOR) + .map(CreatableTopic::name) + .collect(Collectors.toList()); + + if (!topicsWithDefaults.isEmpty() && version < 4) { + throw new UnsupportedVersionException("Creating topics with default " + + "partitions/replication factor are only supported in CreateTopicRequest " + + "version 4+. The following topics need values for partitions and replicas: " + + topicsWithDefaults); + } + + return new CreateTopicsRequest(data, version); + } + + @Override + public String toString() { + return data.toString(); + } + + @Override + public boolean equals(Object other) { + return other instanceof Builder && this.data.equals(((Builder) other).data); + } + + @Override + public int hashCode() { + return data.hashCode(); + } + } + + private final CreateTopicsRequestData data; + + public static final int NO_NUM_PARTITIONS = -1; + public static final short NO_REPLICATION_FACTOR = -1; + + public CreateTopicsRequest(CreateTopicsRequestData data, short version) { + super(ApiKeys.CREATE_TOPICS, version); + this.data = data; + } + + @Override + public CreateTopicsRequestData data() { + return data; + } + + @Override + public AbstractResponse getErrorResponse(int throttleTimeMs, Throwable e) { + CreateTopicsResponseData response = new CreateTopicsResponseData(); + if (version() >= 2) { + response.setThrottleTimeMs(throttleTimeMs); + } + ApiError apiError = ApiError.fromThrowable(e); + for (CreatableTopic topic : data.topics()) { + response.topics().add(new CreatableTopicResult(). + setName(topic.name()). + setErrorCode(apiError.error().code()). 
+ setErrorMessage(apiError.message())); + } + return new CreateTopicsResponse(response); + } + + public static CreateTopicsRequest parse(ByteBuffer buffer, short version) { + return new CreateTopicsRequest(new CreateTopicsRequestData(new ByteBufferAccessor(buffer), version), version); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/CreateTopicsResponse.java b/clients/src/main/java/org/apache/kafka/common/requests/CreateTopicsResponse.java new file mode 100644 index 0000000..dd06277 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/CreateTopicsResponse.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.message.CreateTopicsResponseData; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.protocol.Errors; + +import java.nio.ByteBuffer; +import java.util.HashMap; +import java.util.Map; + +public class CreateTopicsResponse extends AbstractResponse { + /** + * Possible error codes: + * + * REQUEST_TIMED_OUT(7) + * INVALID_TOPIC_EXCEPTION(17) + * TOPIC_AUTHORIZATION_FAILED(29) + * TOPIC_ALREADY_EXISTS(36) + * INVALID_PARTITIONS(37) + * INVALID_REPLICATION_FACTOR(38) + * INVALID_REPLICA_ASSIGNMENT(39) + * INVALID_CONFIG(40) + * NOT_CONTROLLER(41) + * INVALID_REQUEST(42) + * POLICY_VIOLATION(44) + */ + + private final CreateTopicsResponseData data; + + public CreateTopicsResponse(CreateTopicsResponseData data) { + super(ApiKeys.CREATE_TOPICS); + this.data = data; + } + + @Override + public CreateTopicsResponseData data() { + return data; + } + + @Override + public int throttleTimeMs() { + return data.throttleTimeMs(); + } + + @Override + public Map errorCounts() { + HashMap counts = new HashMap<>(); + data.topics().forEach(result -> + updateErrorCounts(counts, Errors.forCode(result.errorCode())) + ); + return counts; + } + + public static CreateTopicsResponse parse(ByteBuffer buffer, short version) { + return new CreateTopicsResponse(new CreateTopicsResponseData(new ByteBufferAccessor(buffer), version)); + } + + @Override + public boolean shouldClientThrottle(short version) { + return version >= 3; + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/DeleteAclsRequest.java b/clients/src/main/java/org/apache/kafka/common/requests/DeleteAclsRequest.java new file mode 100644 index 0000000..98fd658 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/DeleteAclsRequest.java @@ -0,0 +1,146 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
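An illustrative sketch of the version guard in CreateTopicsRequest.Builder.build() above. The schema-generated mutators on CreatableTopic (setName, setNumPartitions, setReplicationFactor) and the collection returned by topics() are assumed from the message schema; they are not part of this diff.

    import org.apache.kafka.common.errors.UnsupportedVersionException;
    import org.apache.kafka.common.message.CreateTopicsRequestData;
    import org.apache.kafka.common.message.CreateTopicsRequestData.CreatableTopic;
    import org.apache.kafka.common.requests.CreateTopicsRequest;

    public class CreateTopicsVersionGuardExample {
        public static void main(String[] args) {
            CreateTopicsRequestData data = new CreateTopicsRequestData();
            data.topics().add(new CreatableTopic()              // assumed generated collection/setters
                    .setName("orders")
                    .setNumPartitions(CreateTopicsRequest.NO_NUM_PARTITIONS)         // broker-side default
                    .setReplicationFactor(CreateTopicsRequest.NO_REPLICATION_FACTOR));

            CreateTopicsRequest.Builder builder = new CreateTopicsRequest.Builder(data);
            try {
                builder.build((short) 3); // topic defaults require CreateTopics version 4+
            } catch (UnsupportedVersionException e) {
                System.out.println("expected: " + e.getMessage());
            }
            builder.build((short) 5);     // accepted on newer versions
        }
    }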
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.requests; + +import java.util.Collections; +import java.util.stream.Collectors; +import org.apache.kafka.common.acl.AccessControlEntryFilter; +import org.apache.kafka.common.acl.AclBindingFilter; +import org.apache.kafka.common.acl.AclOperation; +import org.apache.kafka.common.acl.AclPermissionType; +import org.apache.kafka.common.errors.UnsupportedVersionException; +import org.apache.kafka.common.message.DeleteAclsRequestData; +import org.apache.kafka.common.message.DeleteAclsRequestData.DeleteAclsFilter; +import org.apache.kafka.common.message.DeleteAclsResponseData; +import org.apache.kafka.common.message.DeleteAclsResponseData.DeleteAclsFilterResult; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.resource.PatternType; +import org.apache.kafka.common.resource.ResourcePatternFilter; +import org.apache.kafka.common.resource.ResourceType; + +import java.nio.ByteBuffer; +import java.util.List; + +import static org.apache.kafka.common.protocol.ApiKeys.DELETE_ACLS; + +public class DeleteAclsRequest extends AbstractRequest { + public static class Builder extends AbstractRequest.Builder { + private final DeleteAclsRequestData data; + + public Builder(DeleteAclsRequestData data) { + super(DELETE_ACLS); + this.data = data; + } + + @Override + public DeleteAclsRequest build(short version) { + return new DeleteAclsRequest(data, version); + } + + @Override + public String toString() { + return data.toString(); + } + + } + + private final DeleteAclsRequestData data; + + private DeleteAclsRequest(DeleteAclsRequestData data, short version) { + super(ApiKeys.DELETE_ACLS, version); + this.data = data; + normalizeAndValidate(); + } + + private void normalizeAndValidate() { + if (version() == 0) { + for (DeleteAclsRequestData.DeleteAclsFilter filter : data.filters()) { + PatternType patternType = PatternType.fromCode(filter.patternTypeFilter()); + + // On older brokers, no pattern types existed except LITERAL (effectively). So even though ANY is not + // directly supported on those brokers, we can get the same effect as ANY by setting the pattern type + // to LITERAL. Note that the wildcard `*` is considered `LITERAL` for compatibility reasons. 
+ if (patternType == PatternType.ANY) + filter.setPatternTypeFilter(PatternType.LITERAL.code()); + else if (patternType != PatternType.LITERAL) + throw new UnsupportedVersionException("Version 0 does not support pattern type " + + patternType + " (only LITERAL and ANY are supported)"); + } + } + + final boolean unknown = data.filters().stream().anyMatch(filter -> + filter.patternTypeFilter() == PatternType.UNKNOWN.code() + || filter.resourceTypeFilter() == ResourceType.UNKNOWN.code() + || filter.operation() == AclOperation.UNKNOWN.code() + || filter.permissionType() == AclPermissionType.UNKNOWN.code() + ); + + if (unknown) { + throw new IllegalArgumentException("Filters contain UNKNOWN elements, filters: " + data.filters()); + } + } + + public List filters() { + return data.filters().stream().map(DeleteAclsRequest::aclBindingFilter).collect(Collectors.toList()); + } + + @Override + public DeleteAclsRequestData data() { + return data; + } + + @Override + public AbstractResponse getErrorResponse(int throttleTimeMs, Throwable throwable) { + ApiError apiError = ApiError.fromThrowable(throwable); + List filterResults = Collections.nCopies(data.filters().size(), + new DeleteAclsResponseData.DeleteAclsFilterResult() + .setErrorCode(apiError.error().code()) + .setErrorMessage(apiError.message())); + return new DeleteAclsResponse(new DeleteAclsResponseData() + .setThrottleTimeMs(throttleTimeMs) + .setFilterResults(filterResults), version()); + } + + public static DeleteAclsRequest parse(ByteBuffer buffer, short version) { + return new DeleteAclsRequest(new DeleteAclsRequestData(new ByteBufferAccessor(buffer), version), version); + } + + public static DeleteAclsFilter deleteAclsFilter(AclBindingFilter filter) { + return new DeleteAclsFilter() + .setResourceNameFilter(filter.patternFilter().name()) + .setResourceTypeFilter(filter.patternFilter().resourceType().code()) + .setPatternTypeFilter(filter.patternFilter().patternType().code()) + .setHostFilter(filter.entryFilter().host()) + .setOperation(filter.entryFilter().operation().code()) + .setPermissionType(filter.entryFilter().permissionType().code()) + .setPrincipalFilter(filter.entryFilter().principal()); + } + + private static AclBindingFilter aclBindingFilter(DeleteAclsFilter filter) { + ResourcePatternFilter patternFilter = new ResourcePatternFilter( + ResourceType.fromCode(filter.resourceTypeFilter()), + filter.resourceNameFilter(), + PatternType.fromCode(filter.patternTypeFilter())); + AccessControlEntryFilter entryFilter = new AccessControlEntryFilter( + filter.principalFilter(), + filter.hostFilter(), + AclOperation.fromCode(filter.operation()), + AclPermissionType.fromCode(filter.permissionType())); + return new AclBindingFilter(patternFilter, entryFilter); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/DeleteAclsResponse.java b/clients/src/main/java/org/apache/kafka/common/requests/DeleteAclsResponse.java new file mode 100644 index 0000000..7482953 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/DeleteAclsResponse.java @@ -0,0 +1,143 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
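A sketch of the deleteAclsFilter conversion helper above, using the match-everything AclBindingFilter.ANY constant from the public ACL API.

    import org.apache.kafka.common.acl.AclBindingFilter;
    import org.apache.kafka.common.message.DeleteAclsRequestData.DeleteAclsFilter;
    import org.apache.kafka.common.requests.DeleteAclsRequest;
    import org.apache.kafka.common.resource.PatternType;

    public class DeleteAclsFilterExample {
        public static void main(String[] args) {
            // Convert the "match everything" binding filter into its wire representation.
            DeleteAclsFilter wireFilter = DeleteAclsRequest.deleteAclsFilter(AclBindingFilter.ANY);

            // Pattern type ANY is only normalized to LITERAL for version 0 requests,
            // inside normalizeAndValidate() above; the converted filter itself keeps ANY.
            System.out.println(PatternType.fromCode(wireFilter.patternTypeFilter())); // ANY
        }
    }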
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.acl.AccessControlEntry; +import org.apache.kafka.common.acl.AclBinding; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.resource.ResourcePattern; +import org.apache.kafka.common.acl.AclOperation; +import org.apache.kafka.common.acl.AclPermissionType; +import org.apache.kafka.common.errors.UnsupportedVersionException; +import org.apache.kafka.common.message.DeleteAclsResponseData; +import org.apache.kafka.common.message.DeleteAclsResponseData.DeleteAclsFilterResult; +import org.apache.kafka.common.message.DeleteAclsResponseData.DeleteAclsMatchingAcl; +import org.apache.kafka.common.protocol.Errors; +import org.apache.kafka.common.resource.PatternType; +import org.apache.kafka.common.resource.ResourceType; +import org.apache.kafka.server.authorizer.AclDeleteResult; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.nio.ByteBuffer; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +public class DeleteAclsResponse extends AbstractResponse { + public static final Logger log = LoggerFactory.getLogger(DeleteAclsResponse.class); + + private final DeleteAclsResponseData data; + + public DeleteAclsResponse(DeleteAclsResponseData data, short version) { + super(ApiKeys.DELETE_ACLS); + this.data = data; + validate(version); + } + + @Override + public DeleteAclsResponseData data() { + return data; + } + + @Override + public int throttleTimeMs() { + return data.throttleTimeMs(); + } + + public List filterResults() { + return data.filterResults(); + } + + @Override + public Map errorCounts() { + return errorCounts(filterResults().stream().map(r -> Errors.forCode(r.errorCode()))); + } + + public static DeleteAclsResponse parse(ByteBuffer buffer, short version) { + return new DeleteAclsResponse(new DeleteAclsResponseData(new ByteBufferAccessor(buffer), version), version); + } + + public String toString() { + return data.toString(); + } + + @Override + public boolean shouldClientThrottle(short version) { + return version >= 1; + } + + private void validate(short version) { + if (version == 0) { + final boolean unsupported = filterResults().stream() + .flatMap(r -> r.matchingAcls().stream()) + .anyMatch(matchingAcl -> matchingAcl.patternType() != PatternType.LITERAL.code()); + if (unsupported) + throw new UnsupportedVersionException("Version 0 only supports literal resource pattern types"); + } + + final boolean unknown = filterResults().stream() + .flatMap(r -> r.matchingAcls().stream()) + .anyMatch(matchingAcl -> matchingAcl.patternType() == PatternType.UNKNOWN.code() + || matchingAcl.resourceType() == ResourceType.UNKNOWN.code() + || matchingAcl.permissionType() == AclPermissionType.UNKNOWN.code() + || matchingAcl.operation() == AclOperation.UNKNOWN.code()); + if (unknown) + throw new IllegalArgumentException("DeleteAclsMatchingAcls contain UNKNOWN elements"); + } + + public static DeleteAclsFilterResult filterResult(AclDeleteResult result) { + ApiError error = 
result.exception().map(e -> ApiError.fromThrowable(e)).orElse(ApiError.NONE); + List matchingAcls = result.aclBindingDeleteResults().stream() + .map(DeleteAclsResponse::matchingAcl) + .collect(Collectors.toList()); + return new DeleteAclsFilterResult() + .setErrorCode(error.error().code()) + .setErrorMessage(error.message()) + .setMatchingAcls(matchingAcls); + } + + private static DeleteAclsMatchingAcl matchingAcl(AclDeleteResult.AclBindingDeleteResult result) { + ApiError error = result.exception().map(e -> ApiError.fromThrowable(e)).orElse(ApiError.NONE); + AclBinding acl = result.aclBinding(); + return matchingAcl(acl, error); + } + + // Visible for testing + public static DeleteAclsMatchingAcl matchingAcl(AclBinding acl, ApiError error) { + return new DeleteAclsMatchingAcl() + .setErrorCode(error.error().code()) + .setErrorMessage(error.message()) + .setResourceName(acl.pattern().name()) + .setResourceType(acl.pattern().resourceType().code()) + .setPatternType(acl.pattern().patternType().code()) + .setHost(acl.entry().host()) + .setOperation(acl.entry().operation().code()) + .setPermissionType(acl.entry().permissionType().code()) + .setPrincipal(acl.entry().principal()); + } + + public static AclBinding aclBinding(DeleteAclsMatchingAcl matchingAcl) { + ResourcePattern resourcePattern = new ResourcePattern(ResourceType.fromCode(matchingAcl.resourceType()), + matchingAcl.resourceName(), PatternType.fromCode(matchingAcl.patternType())); + AccessControlEntry accessControlEntry = new AccessControlEntry(matchingAcl.principal(), matchingAcl.host(), + AclOperation.fromCode(matchingAcl.operation()), AclPermissionType.fromCode(matchingAcl.permissionType())); + return new AclBinding(resourcePattern, accessControlEntry); + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/DeleteGroupsRequest.java b/clients/src/main/java/org/apache/kafka/common/requests/DeleteGroupsRequest.java new file mode 100644 index 0000000..87d6dee --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/DeleteGroupsRequest.java @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
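A sketch of the matchingAcl/aclBinding round trip in DeleteAclsResponse above; the resource and principal names are placeholders.

    import org.apache.kafka.common.acl.AccessControlEntry;
    import org.apache.kafka.common.acl.AclBinding;
    import org.apache.kafka.common.acl.AclOperation;
    import org.apache.kafka.common.acl.AclPermissionType;
    import org.apache.kafka.common.message.DeleteAclsResponseData.DeleteAclsMatchingAcl;
    import org.apache.kafka.common.requests.ApiError;
    import org.apache.kafka.common.requests.DeleteAclsResponse;
    import org.apache.kafka.common.resource.PatternType;
    import org.apache.kafka.common.resource.ResourcePattern;
    import org.apache.kafka.common.resource.ResourceType;

    public class DeleteAclsMatchingAclExample {
        public static void main(String[] args) {
            AclBinding deleted = new AclBinding(
                    new ResourcePattern(ResourceType.GROUP, "payments", PatternType.LITERAL),
                    new AccessControlEntry("User:bob", "*", AclOperation.DESCRIBE, AclPermissionType.ALLOW));

            // Wire representation of a successfully deleted ACL, then back to the domain object.
            DeleteAclsMatchingAcl matching = DeleteAclsResponse.matchingAcl(deleted, ApiError.NONE);
            System.out.println(DeleteAclsResponse.aclBinding(matching).equals(deleted)); // true
        }
    }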
+ */ +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.message.DeleteGroupsRequestData; +import org.apache.kafka.common.message.DeleteGroupsResponseData; +import org.apache.kafka.common.message.DeleteGroupsResponseData.DeletableGroupResult; +import org.apache.kafka.common.message.DeleteGroupsResponseData.DeletableGroupResultCollection; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.protocol.Errors; + +import java.nio.ByteBuffer; + +public class DeleteGroupsRequest extends AbstractRequest { + public static class Builder extends AbstractRequest.Builder { + private final DeleteGroupsRequestData data; + + public Builder(DeleteGroupsRequestData data) { + super(ApiKeys.DELETE_GROUPS); + this.data = data; + } + + @Override + public DeleteGroupsRequest build(short version) { + return new DeleteGroupsRequest(data, version); + } + + @Override + public String toString() { + return data.toString(); + } + } + + private final DeleteGroupsRequestData data; + + public DeleteGroupsRequest(DeleteGroupsRequestData data, short version) { + super(ApiKeys.DELETE_GROUPS, version); + this.data = data; + } + + @Override + public AbstractResponse getErrorResponse(int throttleTimeMs, Throwable e) { + Errors error = Errors.forException(e); + DeletableGroupResultCollection groupResults = new DeletableGroupResultCollection(); + for (String groupId : data.groupsNames()) { + groupResults.add(new DeletableGroupResult() + .setGroupId(groupId) + .setErrorCode(error.code())); + } + + return new DeleteGroupsResponse( + new DeleteGroupsResponseData() + .setResults(groupResults) + .setThrottleTimeMs(throttleTimeMs) + ); + } + + public static DeleteGroupsRequest parse(ByteBuffer buffer, short version) { + return new DeleteGroupsRequest(new DeleteGroupsRequestData(new ByteBufferAccessor(buffer), version), version); + } + + @Override + public DeleteGroupsRequestData data() { + return data; + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/DeleteGroupsResponse.java b/clients/src/main/java/org/apache/kafka/common/requests/DeleteGroupsResponse.java new file mode 100644 index 0000000..4cbffda --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/DeleteGroupsResponse.java @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.message.DeleteGroupsResponseData; +import org.apache.kafka.common.message.DeleteGroupsResponseData.DeletableGroupResult; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.protocol.Errors; + +import java.nio.ByteBuffer; +import java.util.HashMap; +import java.util.Map; + +/** + * Possible error codes: + * + * COORDINATOR_LOAD_IN_PROGRESS (14) + * COORDINATOR_NOT_AVAILABLE(15) + * NOT_COORDINATOR (16) + * INVALID_GROUP_ID(24) + * GROUP_AUTHORIZATION_FAILED(30) + * NON_EMPTY_GROUP(68) + * GROUP_ID_NOT_FOUND(69) + */ +public class DeleteGroupsResponse extends AbstractResponse { + + private final DeleteGroupsResponseData data; + + public DeleteGroupsResponse(DeleteGroupsResponseData data) { + super(ApiKeys.DELETE_GROUPS); + this.data = data; + } + + @Override + public DeleteGroupsResponseData data() { + return data; + } + + public Map errors() { + Map errorMap = new HashMap<>(); + for (DeletableGroupResult result : data.results()) { + errorMap.put(result.groupId(), Errors.forCode(result.errorCode())); + } + return errorMap; + } + + public Errors get(String group) throws IllegalArgumentException { + DeletableGroupResult result = data.results().find(group); + if (result == null) { + throw new IllegalArgumentException("could not find group " + group + " in the delete group response"); + } + return Errors.forCode(result.errorCode()); + } + + @Override + public Map errorCounts() { + Map counts = new HashMap<>(); + data.results().forEach(result -> + updateErrorCounts(counts, Errors.forCode(result.errorCode())) + ); + return counts; + } + + public static DeleteGroupsResponse parse(ByteBuffer buffer, short version) { + return new DeleteGroupsResponse(new DeleteGroupsResponseData(new ByteBufferAccessor(buffer), version)); + } + + @Override + public int throttleTimeMs() { + return data.throttleTimeMs(); + } + + @Override + public boolean shouldClientThrottle(short version) { + return version >= 1; + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/DeleteRecordsRequest.java b/clients/src/main/java/org/apache/kafka/common/requests/DeleteRecordsRequest.java new file mode 100644 index 0000000..a4f62e1 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/DeleteRecordsRequest.java @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
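Aside, not part of the patch: the DeleteGroups request/response pair above fans a single broker-side error out to every requested group, and the response then exposes the per-group outcome through errors() and get(). A hedged sketch; the group names and the use of the latest API version are illustrative.

// Illustrative sketch only; group names are hypothetical.
import java.util.Arrays;
import org.apache.kafka.common.message.DeleteGroupsRequestData;
import org.apache.kafka.common.protocol.ApiKeys;
import org.apache.kafka.common.protocol.Errors;
import org.apache.kafka.common.requests.DeleteGroupsRequest;
import org.apache.kafka.common.requests.DeleteGroupsResponse;

public class DeleteGroupsExample {
    public static void main(String[] args) {
        DeleteGroupsRequestData data = new DeleteGroupsRequestData()
            .setGroupsNames(Arrays.asList("group-a", "group-b"));
        DeleteGroupsRequest request =
            new DeleteGroupsRequest.Builder(data).build(ApiKeys.DELETE_GROUPS.latestVersion());

        // getErrorResponse maps the same error onto every requested group.
        DeleteGroupsResponse response = (DeleteGroupsResponse) request.getErrorResponse(
            0, Errors.COORDINATOR_NOT_AVAILABLE.exception());

        System.out.println(response.errors());       // {group-a=COORDINATOR_NOT_AVAILABLE, group-b=COORDINATOR_NOT_AVAILABLE}
        System.out.println(response.get("group-a")); // COORDINATOR_NOT_AVAILABLE
    }
}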
+ */ + +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.message.DeleteRecordsRequestData; +import org.apache.kafka.common.message.DeleteRecordsRequestData.DeleteRecordsTopic; +import org.apache.kafka.common.message.DeleteRecordsResponseData; +import org.apache.kafka.common.message.DeleteRecordsResponseData.DeleteRecordsTopicResult; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.protocol.Errors; + +import java.nio.ByteBuffer; + +public class DeleteRecordsRequest extends AbstractRequest { + + public static final long HIGH_WATERMARK = -1L; + + private final DeleteRecordsRequestData data; + + public static class Builder extends AbstractRequest.Builder { + private DeleteRecordsRequestData data; + + public Builder(DeleteRecordsRequestData data) { + super(ApiKeys.DELETE_RECORDS); + this.data = data; + } + + @Override + public DeleteRecordsRequest build(short version) { + return new DeleteRecordsRequest(data, version); + } + + @Override + public String toString() { + return data.toString(); + } + } + + private DeleteRecordsRequest(DeleteRecordsRequestData data, short version) { + super(ApiKeys.DELETE_RECORDS, version); + this.data = data; + } + + @Override + public DeleteRecordsRequestData data() { + return data; + } + + @Override + public AbstractResponse getErrorResponse(int throttleTimeMs, Throwable e) { + DeleteRecordsResponseData result = new DeleteRecordsResponseData().setThrottleTimeMs(throttleTimeMs); + short errorCode = Errors.forException(e).code(); + for (DeleteRecordsTopic topic : data.topics()) { + DeleteRecordsTopicResult topicResult = new DeleteRecordsTopicResult().setName(topic.name()); + result.topics().add(topicResult); + for (DeleteRecordsRequestData.DeleteRecordsPartition partition : topic.partitions()) { + topicResult.partitions().add(new DeleteRecordsResponseData.DeleteRecordsPartitionResult() + .setPartitionIndex(partition.partitionIndex()) + .setErrorCode(errorCode) + .setLowWatermark(DeleteRecordsResponse.INVALID_LOW_WATERMARK)); + } + } + return new DeleteRecordsResponse(result); + } + + public static DeleteRecordsRequest parse(ByteBuffer buffer, short version) { + return new DeleteRecordsRequest(new DeleteRecordsRequestData(new ByteBufferAccessor(buffer), version), version); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/DeleteRecordsResponse.java b/clients/src/main/java/org/apache/kafka/common/requests/DeleteRecordsResponse.java new file mode 100644 index 0000000..b090543 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/DeleteRecordsResponse.java @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
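Aside, not part of the patch: DeleteRecordsRequest.HIGH_WATERMARK (-1) asks the broker to delete everything up to a partition's current high watermark, and the getErrorResponse above fills each partition with DeleteRecordsResponse.INVALID_LOW_WATERMARK on failure. A minimal sketch of building such a request; the topic name and timeout are made up.

// Illustrative sketch only; topic name and timeout are hypothetical.
import org.apache.kafka.common.message.DeleteRecordsRequestData;
import org.apache.kafka.common.message.DeleteRecordsRequestData.DeleteRecordsPartition;
import org.apache.kafka.common.message.DeleteRecordsRequestData.DeleteRecordsTopic;
import org.apache.kafka.common.protocol.ApiKeys;
import org.apache.kafka.common.requests.DeleteRecordsRequest;

public class DeleteRecordsExample {
    public static void main(String[] args) {
        // Delete partition 0 of "payments" up to its current high watermark.
        DeleteRecordsTopic topic = new DeleteRecordsTopic().setName("payments");
        topic.partitions().add(new DeleteRecordsPartition()
            .setPartitionIndex(0)
            .setOffset(DeleteRecordsRequest.HIGH_WATERMARK));

        DeleteRecordsRequestData data = new DeleteRecordsRequestData().setTimeoutMs(30_000);
        data.topics().add(topic);

        DeleteRecordsRequest request =
            new DeleteRecordsRequest.Builder(data).build(ApiKeys.DELETE_RECORDS.latestVersion());
        System.out.println(request.data().topics().size()); // 1
    }
}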
+ */ + +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.message.DeleteRecordsResponseData; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.protocol.Errors; + +import java.nio.ByteBuffer; +import java.util.HashMap; +import java.util.Map; + +public class DeleteRecordsResponse extends AbstractResponse { + + public static final long INVALID_LOW_WATERMARK = -1L; + private final DeleteRecordsResponseData data; + + /** + * Possible error code: + * + * OFFSET_OUT_OF_RANGE (1) + * UNKNOWN_TOPIC_OR_PARTITION (3) + * NOT_LEADER_OR_FOLLOWER (6) + * REQUEST_TIMED_OUT (7) + * UNKNOWN (-1) + */ + + public DeleteRecordsResponse(DeleteRecordsResponseData data) { + super(ApiKeys.DELETE_RECORDS); + this.data = data; + } + + @Override + public DeleteRecordsResponseData data() { + return data; + } + + @Override + public int throttleTimeMs() { + return data.throttleTimeMs(); + } + + @Override + public Map errorCounts() { + Map errorCounts = new HashMap<>(); + data.topics().forEach(topicResponses -> + topicResponses.partitions().forEach(response -> + updateErrorCounts(errorCounts, Errors.forCode(response.errorCode())) + ) + ); + return errorCounts; + } + + public static DeleteRecordsResponse parse(ByteBuffer buffer, short version) { + return new DeleteRecordsResponse(new DeleteRecordsResponseData(new ByteBufferAccessor(buffer), version)); + } + + @Override + public boolean shouldClientThrottle(short version) { + return version >= 1; + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/DeleteTopicsRequest.java b/clients/src/main/java/org/apache/kafka/common/requests/DeleteTopicsRequest.java new file mode 100644 index 0000000..1322630 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/DeleteTopicsRequest.java @@ -0,0 +1,121 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.Uuid; +import org.apache.kafka.common.message.DeleteTopicsRequestData; +import org.apache.kafka.common.message.DeleteTopicsResponseData; +import org.apache.kafka.common.message.DeleteTopicsResponseData.DeletableTopicResult; +import org.apache.kafka.common.message.DeleteTopicsRequestData.DeleteTopicState; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; + +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.stream.Collectors; + +public class DeleteTopicsRequest extends AbstractRequest { + + public static class Builder extends AbstractRequest.Builder { + private DeleteTopicsRequestData data; + + public Builder(DeleteTopicsRequestData data) { + super(ApiKeys.DELETE_TOPICS); + this.data = data; + } + + @Override + public DeleteTopicsRequest build(short version) { + if (version >= 6 && !data.topicNames().isEmpty()) { + data.setTopics(groupByTopic(data.topicNames())); + } + return new DeleteTopicsRequest(data, version); + } + + private List groupByTopic(List topics) { + List topicStates = new ArrayList<>(); + for (String topic : topics) { + topicStates.add(new DeleteTopicState().setName(topic)); + } + return topicStates; + } + + @Override + public String toString() { + return data.toString(); + } + } + + private DeleteTopicsRequestData data; + + private DeleteTopicsRequest(DeleteTopicsRequestData data, short version) { + super(ApiKeys.DELETE_TOPICS, version); + this.data = data; + } + + @Override + public DeleteTopicsRequestData data() { + return data; + } + + @Override + public AbstractResponse getErrorResponse(int throttleTimeMs, Throwable e) { + DeleteTopicsResponseData response = new DeleteTopicsResponseData(); + if (version() >= 1) { + response.setThrottleTimeMs(throttleTimeMs); + } + ApiError apiError = ApiError.fromThrowable(e); + for (DeleteTopicState topic : topics()) { + response.responses().add(new DeletableTopicResult() + .setName(topic.name()) + .setTopicId(topic.topicId()) + .setErrorCode(apiError.error().code())); + } + return new DeleteTopicsResponse(response); + } + + public List topicNames() { + if (version() >= 6) + return data.topics().stream().map(topic -> topic.name()).collect(Collectors.toList()); + return data.topicNames(); + } + + public int numberOfTopics() { + if (version() >= 6) + return data.topics().size(); + return data.topicNames().size(); + } + + public List topicIds() { + if (version() >= 6) + return data.topics().stream().map(topic -> topic.topicId()).collect(Collectors.toList()); + return Collections.emptyList(); + } + + public List topics() { + if (version() >= 6) + return data.topics(); + return data.topicNames().stream().map(name -> new DeleteTopicState().setName(name)).collect(Collectors.toList()); + } + + public static DeleteTopicsRequest parse(ByteBuffer buffer, short version) { + return new DeleteTopicsRequest(new DeleteTopicsRequestData(new ByteBufferAccessor(buffer), version), version); + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/DeleteTopicsResponse.java b/clients/src/main/java/org/apache/kafka/common/requests/DeleteTopicsResponse.java new file mode 100644 index 0000000..2090c4f --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/DeleteTopicsResponse.java @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license 
agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.message.DeleteTopicsResponseData; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.protocol.Errors; + +import java.nio.ByteBuffer; +import java.util.HashMap; +import java.util.Map; + + +public class DeleteTopicsResponse extends AbstractResponse { + + /** + * Possible error codes: + * + * REQUEST_TIMED_OUT(7) + * INVALID_TOPIC_EXCEPTION(17) + * TOPIC_AUTHORIZATION_FAILED(29) + * NOT_CONTROLLER(41) + * INVALID_REQUEST(42) + * TOPIC_DELETION_DISABLED(73) + */ + private final DeleteTopicsResponseData data; + + public DeleteTopicsResponse(DeleteTopicsResponseData data) { + super(ApiKeys.DELETE_TOPICS); + this.data = data; + } + + @Override + public int throttleTimeMs() { + return data.throttleTimeMs(); + } + + @Override + public DeleteTopicsResponseData data() { + return data; + } + + @Override + public Map errorCounts() { + HashMap counts = new HashMap<>(); + data.responses().forEach(result -> + updateErrorCounts(counts, Errors.forCode(result.errorCode())) + ); + return counts; + } + + public static DeleteTopicsResponse parse(ByteBuffer buffer, short version) { + return new DeleteTopicsResponse(new DeleteTopicsResponseData(new ByteBufferAccessor(buffer), version)); + } + + @Override + public boolean shouldClientThrottle(short version) { + return version >= 2; + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/DescribeAclsRequest.java b/clients/src/main/java/org/apache/kafka/common/requests/DescribeAclsRequest.java new file mode 100644 index 0000000..1ddf5bf --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/DescribeAclsRequest.java @@ -0,0 +1,124 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
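Aside, not part of the patch: looking back at DeleteTopicsRequest.Builder, the same builder produces both wire shapes. Below version 6 the request carries a flat topicNames list; from version 6 onward build() regroups the names into DeleteTopicState entries (which can also carry topic IDs). A hedged sketch; the topic names and the explicit versions are illustrative.

// Illustrative sketch only; topic names and versions are hypothetical.
import java.util.Arrays;
import org.apache.kafka.common.message.DeleteTopicsRequestData;
import org.apache.kafka.common.requests.DeleteTopicsRequest;

public class DeleteTopicsVersionsExample {
    public static void main(String[] args) {
        DeleteTopicsRequest.Builder builder = new DeleteTopicsRequest.Builder(
            new DeleteTopicsRequestData().setTopicNames(Arrays.asList("orders", "payments")));

        DeleteTopicsRequest v5 = builder.build((short) 5);
        System.out.println(v5.data().topicNames()); // [orders, payments]
        System.out.println(v5.data().topics());     // []  (names are not regrouped below v6)

        DeleteTopicsRequest v6 = builder.build((short) 6);
        System.out.println(v6.topics().size());     // 2   (one DeleteTopicState per name)
        System.out.println(v6.topicNames());        // [orders, payments], read back via topics()
    }
}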
+ */ +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.acl.AccessControlEntryFilter; +import org.apache.kafka.common.acl.AclBindingFilter; +import org.apache.kafka.common.acl.AclOperation; +import org.apache.kafka.common.acl.AclPermissionType; +import org.apache.kafka.common.errors.UnsupportedVersionException; +import org.apache.kafka.common.message.DescribeAclsRequestData; +import org.apache.kafka.common.message.DescribeAclsResponseData; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.resource.PatternType; +import org.apache.kafka.common.resource.ResourcePatternFilter; +import org.apache.kafka.common.resource.ResourceType; + +import java.nio.ByteBuffer; + +public class DescribeAclsRequest extends AbstractRequest { + + public static class Builder extends AbstractRequest.Builder { + private final DescribeAclsRequestData data; + + public Builder(AclBindingFilter filter) { + super(ApiKeys.DESCRIBE_ACLS); + ResourcePatternFilter patternFilter = filter.patternFilter(); + AccessControlEntryFilter entryFilter = filter.entryFilter(); + data = new DescribeAclsRequestData() + .setHostFilter(entryFilter.host()) + .setOperation(entryFilter.operation().code()) + .setPermissionType(entryFilter.permissionType().code()) + .setPrincipalFilter(entryFilter.principal()) + .setResourceNameFilter(patternFilter.name()) + .setPatternTypeFilter(patternFilter.patternType().code()) + .setResourceTypeFilter(patternFilter.resourceType().code()); + } + + @Override + public DescribeAclsRequest build(short version) { + return new DescribeAclsRequest(data, version); + } + + @Override + public String toString() { + return data.toString(); + } + } + + private final DescribeAclsRequestData data; + + private DescribeAclsRequest(DescribeAclsRequestData data, short version) { + super(ApiKeys.DESCRIBE_ACLS, version); + this.data = data; + normalizeAndValidate(version); + } + + private void normalizeAndValidate(short version) { + if (version == 0) { + PatternType patternType = PatternType.fromCode(data.patternTypeFilter()); + // On older brokers, no pattern types existed except LITERAL (effectively). So even though ANY is not + // directly supported on those brokers, we can get the same effect as ANY by setting the pattern type + // to LITERAL. Note that the wildcard `*` is considered `LITERAL` for compatibility reasons. 
+ if (patternType == PatternType.ANY) + data.setPatternTypeFilter(PatternType.LITERAL.code()); + else if (patternType != PatternType.LITERAL) + throw new UnsupportedVersionException("Version 0 only supports literal resource pattern types"); + } + + if (data.patternTypeFilter() == PatternType.UNKNOWN.code() + || data.resourceTypeFilter() == ResourceType.UNKNOWN.code() + || data.permissionType() == AclPermissionType.UNKNOWN.code() + || data.operation() == AclOperation.UNKNOWN.code()) { + throw new IllegalArgumentException("DescribeAclsRequest contains UNKNOWN elements: " + data); + } + } + + @Override + public DescribeAclsRequestData data() { + return data; + } + + @Override + public AbstractResponse getErrorResponse(int throttleTimeMs, Throwable throwable) { + ApiError error = ApiError.fromThrowable(throwable); + DescribeAclsResponseData response = new DescribeAclsResponseData() + .setThrottleTimeMs(throttleTimeMs) + .setErrorCode(error.error().code()) + .setErrorMessage(error.message()); + return new DescribeAclsResponse(response, version()); + } + + public static DescribeAclsRequest parse(ByteBuffer buffer, short version) { + return new DescribeAclsRequest(new DescribeAclsRequestData(new ByteBufferAccessor(buffer), version), version); + } + + public AclBindingFilter filter() { + ResourcePatternFilter rpf = new ResourcePatternFilter( + ResourceType.fromCode(data.resourceTypeFilter()), + data.resourceNameFilter(), + PatternType.fromCode(data.patternTypeFilter())); + AccessControlEntryFilter acef = new AccessControlEntryFilter( + data.principalFilter(), + data.hostFilter(), + AclOperation.fromCode(data.operation()), + AclPermissionType.fromCode(data.permissionType())); + return new AclBindingFilter(rpf, acef); + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/DescribeAclsResponse.java b/clients/src/main/java/org/apache/kafka/common/requests/DescribeAclsResponse.java new file mode 100644 index 0000000..c4190e6 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/DescribeAclsResponse.java @@ -0,0 +1,161 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
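Aside, not part of the patch: a sketch of what the version-0 normalisation in DescribeAclsRequest above does to a filter. ANY pattern types degrade to LITERAL, while a PREFIXED filter cannot be expressed at all on a v0 broker and fails fast. The filter contents below are made up.

// Illustrative sketch only; resource names are hypothetical.
import org.apache.kafka.common.acl.AccessControlEntryFilter;
import org.apache.kafka.common.acl.AclBindingFilter;
import org.apache.kafka.common.errors.UnsupportedVersionException;
import org.apache.kafka.common.requests.DescribeAclsRequest;
import org.apache.kafka.common.resource.PatternType;
import org.apache.kafka.common.resource.ResourcePatternFilter;
import org.apache.kafka.common.resource.ResourceType;

public class DescribeAclsVersion0Example {
    public static void main(String[] args) {
        // ANY pattern type is rewritten to LITERAL when targeting version 0.
        DescribeAclsRequest v0 = new DescribeAclsRequest.Builder(AclBindingFilter.ANY).build((short) 0);
        System.out.println(v0.data().patternTypeFilter() == PatternType.LITERAL.code()); // true

        // PREFIXED filters cannot be sent to brokers that only understand version 0.
        AclBindingFilter prefixed = new AclBindingFilter(
            new ResourcePatternFilter(ResourceType.TOPIC, "orders-", PatternType.PREFIXED),
            AccessControlEntryFilter.ANY);
        try {
            new DescribeAclsRequest.Builder(prefixed).build((short) 0);
        } catch (UnsupportedVersionException e) {
            System.out.println(e.getMessage()); // Version 0 only supports literal resource pattern types
        }
    }
}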
+ */ + +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.acl.AccessControlEntry; +import org.apache.kafka.common.acl.AclBinding; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.resource.PatternType; +import org.apache.kafka.common.resource.ResourcePattern; +import org.apache.kafka.common.errors.UnsupportedVersionException; +import org.apache.kafka.common.protocol.Errors; + +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; +import java.util.Optional; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import org.apache.kafka.common.acl.AclOperation; +import org.apache.kafka.common.acl.AclPermissionType; +import org.apache.kafka.common.message.DescribeAclsResponseData; +import org.apache.kafka.common.message.DescribeAclsResponseData.AclDescription; +import org.apache.kafka.common.message.DescribeAclsResponseData.DescribeAclsResource; +import org.apache.kafka.common.resource.ResourceType; + +public class DescribeAclsResponse extends AbstractResponse { + + private final DescribeAclsResponseData data; + + public DescribeAclsResponse(DescribeAclsResponseData data, short version) { + super(ApiKeys.DESCRIBE_ACLS); + this.data = data; + validate(Optional.of(version)); + } + + // Skips version validation, visible for testing + DescribeAclsResponse(DescribeAclsResponseData data) { + super(ApiKeys.DESCRIBE_ACLS); + this.data = data; + validate(Optional.empty()); + } + + @Override + public DescribeAclsResponseData data() { + return data; + } + + @Override + public int throttleTimeMs() { + return data.throttleTimeMs(); + } + + public ApiError error() { + return new ApiError(Errors.forCode(data.errorCode()), data.errorMessage()); + } + + @Override + public Map errorCounts() { + return errorCounts(Errors.forCode(data.errorCode())); + } + + public List acls() { + return data.resources(); + } + + public static DescribeAclsResponse parse(ByteBuffer buffer, short version) { + return new DescribeAclsResponse(new DescribeAclsResponseData(new ByteBufferAccessor(buffer), version), version); + } + + @Override + public boolean shouldClientThrottle(short version) { + return version >= 1; + } + + private void validate(Optional version) { + if (version.isPresent() && version.get() == 0) { + final boolean unsupported = acls().stream() + .anyMatch(acl -> acl.patternType() != PatternType.LITERAL.code()); + if (unsupported) { + throw new UnsupportedVersionException("Version 0 only supports literal resource pattern types"); + } + } + + for (DescribeAclsResource resource : acls()) { + if (resource.patternType() == PatternType.UNKNOWN.code() || resource.resourceType() == ResourceType.UNKNOWN.code()) + throw new IllegalArgumentException("Contain UNKNOWN elements"); + for (AclDescription acl : resource.acls()) { + if (acl.operation() == AclOperation.UNKNOWN.code() || acl.permissionType() == AclPermissionType.UNKNOWN.code()) { + throw new IllegalArgumentException("Contain UNKNOWN elements"); + } + } + } + } + + private static Stream aclBindings(DescribeAclsResource resource) { + return resource.acls().stream().map(acl -> { + ResourcePattern pattern = new ResourcePattern( + ResourceType.fromCode(resource.resourceType()), + resource.resourceName(), + PatternType.fromCode(resource.patternType())); + AccessControlEntry entry = new AccessControlEntry( + 
acl.principal(), + acl.host(), + AclOperation.fromCode(acl.operation()), + AclPermissionType.fromCode(acl.permissionType())); + return new AclBinding(pattern, entry); + }); + } + + public static List aclBindings(List resources) { + return resources.stream().flatMap(DescribeAclsResponse::aclBindings).collect(Collectors.toList()); + } + + public static List aclsResources(Collection acls) { + Map> patternToEntries = new HashMap<>(); + for (AclBinding acl : acls) { + patternToEntries.computeIfAbsent(acl.pattern(), v -> new ArrayList<>()).add(acl.entry()); + } + List resources = new ArrayList<>(patternToEntries.size()); + for (Entry> entry : patternToEntries.entrySet()) { + ResourcePattern key = entry.getKey(); + List aclDescriptions = new ArrayList<>(); + for (AccessControlEntry ace : entry.getValue()) { + AclDescription ad = new AclDescription() + .setHost(ace.host()) + .setOperation(ace.operation().code()) + .setPermissionType(ace.permissionType().code()) + .setPrincipal(ace.principal()); + aclDescriptions.add(ad); + } + DescribeAclsResource dar = new DescribeAclsResource() + .setResourceName(key.name()) + .setPatternType(key.patternType().code()) + .setResourceType(key.resourceType().code()) + .setAcls(aclDescriptions); + resources.add(dar); + } + return resources; + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/DescribeClientQuotasRequest.java b/clients/src/main/java/org/apache/kafka/common/requests/DescribeClientQuotasRequest.java new file mode 100644 index 0000000..3d95f42 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/DescribeClientQuotasRequest.java @@ -0,0 +1,128 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.message.DescribeClientQuotasRequestData; +import org.apache.kafka.common.message.DescribeClientQuotasRequestData.ComponentData; +import org.apache.kafka.common.message.DescribeClientQuotasResponseData; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.quota.ClientQuotaFilter; +import org.apache.kafka.common.quota.ClientQuotaFilterComponent; + +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.List; + +public class DescribeClientQuotasRequest extends AbstractRequest { + // These values must not change. 
+ public static final byte MATCH_TYPE_EXACT = 0; + public static final byte MATCH_TYPE_DEFAULT = 1; + public static final byte MATCH_TYPE_SPECIFIED = 2; + + public static class Builder extends AbstractRequest.Builder { + + private final DescribeClientQuotasRequestData data; + + public Builder(ClientQuotaFilter filter) { + super(ApiKeys.DESCRIBE_CLIENT_QUOTAS); + + List componentData = new ArrayList<>(filter.components().size()); + for (ClientQuotaFilterComponent component : filter.components()) { + ComponentData fd = new ComponentData().setEntityType(component.entityType()); + if (component.match() == null) { + fd.setMatchType(MATCH_TYPE_SPECIFIED); + fd.setMatch(null); + } else if (component.match().isPresent()) { + fd.setMatchType(MATCH_TYPE_EXACT); + fd.setMatch(component.match().get()); + } else { + fd.setMatchType(MATCH_TYPE_DEFAULT); + fd.setMatch(null); + } + componentData.add(fd); + } + this.data = new DescribeClientQuotasRequestData() + .setComponents(componentData) + .setStrict(filter.strict()); + } + + @Override + public DescribeClientQuotasRequest build(short version) { + return new DescribeClientQuotasRequest(data, version); + } + + @Override + public String toString() { + return data.toString(); + } + } + + private final DescribeClientQuotasRequestData data; + + public DescribeClientQuotasRequest(DescribeClientQuotasRequestData data, short version) { + super(ApiKeys.DESCRIBE_CLIENT_QUOTAS, version); + this.data = data; + } + + public ClientQuotaFilter filter() { + List components = new ArrayList<>(data.components().size()); + for (ComponentData componentData : data.components()) { + ClientQuotaFilterComponent component; + switch (componentData.matchType()) { + case MATCH_TYPE_EXACT: + component = ClientQuotaFilterComponent.ofEntity(componentData.entityType(), componentData.match()); + break; + case MATCH_TYPE_DEFAULT: + component = ClientQuotaFilterComponent.ofDefaultEntity(componentData.entityType()); + break; + case MATCH_TYPE_SPECIFIED: + component = ClientQuotaFilterComponent.ofEntityType(componentData.entityType()); + break; + default: + throw new IllegalArgumentException("Unexpected match type: " + componentData.matchType()); + } + components.add(component); + } + if (data.strict()) { + return ClientQuotaFilter.containsOnly(components); + } else { + return ClientQuotaFilter.contains(components); + } + } + + @Override + public DescribeClientQuotasRequestData data() { + return data; + } + + @Override + public DescribeClientQuotasResponse getErrorResponse(int throttleTimeMs, Throwable e) { + ApiError error = ApiError.fromThrowable(e); + return new DescribeClientQuotasResponse(new DescribeClientQuotasResponseData() + .setThrottleTimeMs(throttleTimeMs) + .setErrorCode(error.error().code()) + .setErrorMessage(error.message()) + .setEntries(null)); + } + + public static DescribeClientQuotasRequest parse(ByteBuffer buffer, short version) { + return new DescribeClientQuotasRequest(new DescribeClientQuotasRequestData(new ByteBufferAccessor(buffer), version), + version); + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/DescribeClientQuotasResponse.java b/clients/src/main/java/org/apache/kafka/common/requests/DescribeClientQuotasResponse.java new file mode 100644 index 0000000..1474143 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/DescribeClientQuotasResponse.java @@ -0,0 +1,118 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
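Aside, not part of the patch: the three ClientQuotaFilterComponent factory methods map one-to-one onto the MATCH_TYPE_* constants above (ofEntity to EXACT, ofDefaultEntity to DEFAULT, ofEntityType to SPECIFIED), and filter() reverses the mapping on the receiving side. A hedged sketch; the entity names are illustrative.

// Illustrative sketch only; entity names are hypothetical.
import java.util.Arrays;
import org.apache.kafka.common.message.DescribeClientQuotasRequestData.ComponentData;
import org.apache.kafka.common.protocol.ApiKeys;
import org.apache.kafka.common.quota.ClientQuotaEntity;
import org.apache.kafka.common.quota.ClientQuotaFilter;
import org.apache.kafka.common.quota.ClientQuotaFilterComponent;
import org.apache.kafka.common.requests.DescribeClientQuotasRequest;

public class DescribeClientQuotasFilterExample {
    public static void main(String[] args) {
        ClientQuotaFilter filter = ClientQuotaFilter.contains(Arrays.asList(
            ClientQuotaFilterComponent.ofEntity(ClientQuotaEntity.USER, "alice"), // -> MATCH_TYPE_EXACT
            ClientQuotaFilterComponent.ofDefaultEntity(ClientQuotaEntity.USER),   // -> MATCH_TYPE_DEFAULT
            ClientQuotaFilterComponent.ofEntityType(ClientQuotaEntity.CLIENT_ID)  // -> MATCH_TYPE_SPECIFIED
        ));

        DescribeClientQuotasRequest request = new DescribeClientQuotasRequest.Builder(filter)
            .build(ApiKeys.DESCRIBE_CLIENT_QUOTAS.latestVersion());

        for (ComponentData component : request.data().components()) {
            System.out.println(component.entityType() + " -> matchType " + component.matchType());
        }
        // filter() rebuilds an equivalent ClientQuotaFilter from the wire representation.
        System.out.println(request.filter().components().size()); // 3
    }
}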
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.internals.KafkaFutureImpl; +import org.apache.kafka.common.message.DescribeClientQuotasResponseData; +import org.apache.kafka.common.message.DescribeClientQuotasResponseData.EntityData; +import org.apache.kafka.common.message.DescribeClientQuotasResponseData.EntryData; +import org.apache.kafka.common.message.DescribeClientQuotasResponseData.ValueData; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.protocol.Errors; +import org.apache.kafka.common.quota.ClientQuotaEntity; + +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +public class DescribeClientQuotasResponse extends AbstractResponse { + + private final DescribeClientQuotasResponseData data; + + public DescribeClientQuotasResponse(DescribeClientQuotasResponseData data) { + super(ApiKeys.DESCRIBE_CLIENT_QUOTAS); + this.data = data; + } + + public void complete(KafkaFutureImpl>> future) { + Errors error = Errors.forCode(data.errorCode()); + if (error != Errors.NONE) { + future.completeExceptionally(error.exception(data.errorMessage())); + return; + } + + Map> result = new HashMap<>(data.entries().size()); + for (EntryData entries : data.entries()) { + Map entity = new HashMap<>(entries.entity().size()); + for (EntityData entityData : entries.entity()) { + entity.put(entityData.entityType(), entityData.entityName()); + } + + Map values = new HashMap<>(entries.values().size()); + for (ValueData valueData : entries.values()) { + values.put(valueData.key(), valueData.value()); + } + + result.put(new ClientQuotaEntity(entity), values); + } + future.complete(result); + } + + @Override + public int throttleTimeMs() { + return data.throttleTimeMs(); + } + + @Override + public DescribeClientQuotasResponseData data() { + return data; + } + + @Override + public Map errorCounts() { + return errorCounts(Errors.forCode(data.errorCode())); + } + + public static DescribeClientQuotasResponse parse(ByteBuffer buffer, short version) { + return new DescribeClientQuotasResponse(new DescribeClientQuotasResponseData(new ByteBufferAccessor(buffer), version)); + } + + public static DescribeClientQuotasResponse fromQuotaEntities(Map> entities, + int throttleTimeMs) { + List entries = new ArrayList<>(entities.size()); + for (Map.Entry> entry : entities.entrySet()) { + ClientQuotaEntity quotaEntity = entry.getKey(); + List entityData = new ArrayList<>(quotaEntity.entries().size()); + for (Map.Entry entityEntry : quotaEntity.entries().entrySet()) { + entityData.add(new EntityData() + .setEntityType(entityEntry.getKey()) + .setEntityName(entityEntry.getValue())); + } + + Map quotaValues = entry.getValue(); + List valueData = 
new ArrayList<>(quotaValues.size()); + for (Map.Entry valuesEntry : entry.getValue().entrySet()) { + valueData.add(new ValueData() + .setKey(valuesEntry.getKey()) + .setValue(valuesEntry.getValue())); + } + + entries.add(new EntryData() + .setEntity(entityData) + .setValues(valueData)); + } + + return new DescribeClientQuotasResponse(new DescribeClientQuotasResponseData() + .setThrottleTimeMs(throttleTimeMs) + .setErrorCode((short) 0) + .setErrorMessage(null) + .setEntries(entries)); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/DescribeClusterRequest.java b/clients/src/main/java/org/apache/kafka/common/requests/DescribeClusterRequest.java new file mode 100644 index 0000000..02f2dda --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/DescribeClusterRequest.java @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.common.requests; + +import java.nio.ByteBuffer; +import org.apache.kafka.common.message.DescribeClusterRequestData; +import org.apache.kafka.common.message.DescribeClusterResponseData; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; + +public class DescribeClusterRequest extends AbstractRequest { + + public static class Builder extends AbstractRequest.Builder { + + private final DescribeClusterRequestData data; + + public Builder(DescribeClusterRequestData data) { + super(ApiKeys.DESCRIBE_CLUSTER); + this.data = data; + } + + @Override + public DescribeClusterRequest build(final short version) { + return new DescribeClusterRequest(data, version); + } + + @Override + public String toString() { + return data.toString(); + } + } + + private final DescribeClusterRequestData data; + + public DescribeClusterRequest(DescribeClusterRequestData data, short version) { + super(ApiKeys.DESCRIBE_CLUSTER, version); + this.data = data; + } + + @Override + public DescribeClusterRequestData data() { + return data; + } + + @Override + public AbstractResponse getErrorResponse(final int throttleTimeMs, final Throwable e) { + ApiError apiError = ApiError.fromThrowable(e); + return new DescribeClusterResponse(new DescribeClusterResponseData() + .setErrorCode(apiError.error().code()) + .setErrorMessage(apiError.message())); + } + + @Override + public String toString(final boolean verbose) { + return data.toString(); + } + + public static DescribeClusterRequest parse(ByteBuffer buffer, short version) { + return new DescribeClusterRequest(new DescribeClusterRequestData(new ByteBufferAccessor(buffer), version), version); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/DescribeClusterResponse.java 
b/clients/src/main/java/org/apache/kafka/common/requests/DescribeClusterResponse.java new file mode 100644 index 0000000..60d9311 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/DescribeClusterResponse.java @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.common.requests; + +import java.nio.ByteBuffer; +import java.util.Map; +import java.util.function.Function; +import java.util.stream.Collectors; +import org.apache.kafka.common.Node; +import org.apache.kafka.common.message.DescribeClusterResponseData; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.protocol.Errors; + +public class DescribeClusterResponse extends AbstractResponse { + + private final DescribeClusterResponseData data; + + public DescribeClusterResponse(DescribeClusterResponseData data) { + super(ApiKeys.DESCRIBE_CLUSTER); + this.data = data; + } + + public Map nodes() { + return data.brokers().valuesList().stream() + .map(b -> new Node(b.brokerId(), b.host(), b.port(), b.rack())) + .collect(Collectors.toMap(Node::id, Function.identity())); + } + + @Override + public Map errorCounts() { + return errorCounts(Errors.forCode(data.errorCode())); + } + + @Override + public int throttleTimeMs() { + return data.throttleTimeMs(); + } + + @Override + public DescribeClusterResponseData data() { + return data; + } + + public static DescribeClusterResponse parse(ByteBuffer buffer, short version) { + return new DescribeClusterResponse(new DescribeClusterResponseData(new ByteBufferAccessor(buffer), version)); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/DescribeConfigsRequest.java b/clients/src/main/java/org/apache/kafka/common/requests/DescribeConfigsRequest.java new file mode 100644 index 0000000..d612ca8 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/DescribeConfigsRequest.java @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
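Aside, not part of the patch: DescribeClusterResponse.nodes() above turns the broker list in the response data into a map of Node objects keyed by broker id. A small sketch; the cluster id, host names and rack are made up.

// Illustrative sketch only; cluster id, hosts and rack are hypothetical.
import java.util.Map;
import org.apache.kafka.common.Node;
import org.apache.kafka.common.message.DescribeClusterResponseData;
import org.apache.kafka.common.message.DescribeClusterResponseData.DescribeClusterBroker;
import org.apache.kafka.common.requests.DescribeClusterResponse;

public class DescribeClusterNodesExample {
    public static void main(String[] args) {
        DescribeClusterResponseData data = new DescribeClusterResponseData()
            .setClusterId("demo-cluster")
            .setControllerId(0);
        data.brokers().add(new DescribeClusterBroker().setBrokerId(0).setHost("broker0").setPort(9092));
        data.brokers().add(new DescribeClusterBroker().setBrokerId(1).setHost("broker1").setPort(9092).setRack("rack-a"));

        Map<Integer, Node> nodes = new DescribeClusterResponse(data).nodes();
        System.out.println(nodes.get(1).host()); // broker1
        System.out.println(nodes.get(1).rack()); // rack-a
    }
}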
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.message.DescribeConfigsRequestData; +import org.apache.kafka.common.message.DescribeConfigsResponseData; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.protocol.Errors; + +import java.nio.ByteBuffer; +import java.util.stream.Collectors; + +public class DescribeConfigsRequest extends AbstractRequest { + + public static class Builder extends AbstractRequest.Builder { + private final DescribeConfigsRequestData data; + + public Builder(DescribeConfigsRequestData data) { + super(ApiKeys.DESCRIBE_CONFIGS); + this.data = data; + } + + @Override + public DescribeConfigsRequest build(short version) { + return new DescribeConfigsRequest(data, version); + } + } + + private final DescribeConfigsRequestData data; + + public DescribeConfigsRequest(DescribeConfigsRequestData data, short version) { + super(ApiKeys.DESCRIBE_CONFIGS, version); + this.data = data; + } + + @Override + public DescribeConfigsRequestData data() { + return data; + } + + @Override + public DescribeConfigsResponse getErrorResponse(int throttleTimeMs, Throwable e) { + Errors error = Errors.forException(e); + return new DescribeConfigsResponse(new DescribeConfigsResponseData() + .setThrottleTimeMs(throttleTimeMs) + .setResults(data.resources().stream().map(result -> { + return new DescribeConfigsResponseData.DescribeConfigsResult().setErrorCode(error.code()) + .setErrorMessage(error.message()) + .setResourceName(result.resourceName()) + .setResourceType(result.resourceType()); + }).collect(Collectors.toList()) + )); + } + + public static DescribeConfigsRequest parse(ByteBuffer buffer, short version) { + return new DescribeConfigsRequest(new DescribeConfigsRequestData(new ByteBufferAccessor(buffer), version), version); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/DescribeConfigsResponse.java b/clients/src/main/java/org/apache/kafka/common/requests/DescribeConfigsResponse.java new file mode 100644 index 0000000..aa7a713 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/DescribeConfigsResponse.java @@ -0,0 +1,275 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
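Aside, not part of the patch: a hedged sketch of describing all configs of a single topic with DescribeConfigsRequest, and of how getErrorResponse echoes each requested resource back with the mapped error code. The topic name and the chosen error are illustrative, and passing null configuration keys to mean "all keys" is my reading of the generated request schema rather than something stated in this patch.

// Illustrative sketch only; topic name and error are hypothetical, and the null
// ConfigurationKeys convention is an assumption about the generated schema.
import org.apache.kafka.common.config.ConfigResource;
import org.apache.kafka.common.message.DescribeConfigsRequestData;
import org.apache.kafka.common.message.DescribeConfigsRequestData.DescribeConfigsResource;
import org.apache.kafka.common.protocol.ApiKeys;
import org.apache.kafka.common.protocol.Errors;
import org.apache.kafka.common.requests.DescribeConfigsRequest;
import org.apache.kafka.common.requests.DescribeConfigsResponse;

public class DescribeConfigsExample {
    public static void main(String[] args) {
        DescribeConfigsRequestData data = new DescribeConfigsRequestData();
        data.resources().add(new DescribeConfigsResource()
            .setResourceType(ConfigResource.Type.TOPIC.id())
            .setResourceName("payments")
            .setConfigurationKeys(null)); // assumed to mean "all configuration keys"

        DescribeConfigsRequest request =
            new DescribeConfigsRequest.Builder(data).build(ApiKeys.DESCRIBE_CONFIGS.latestVersion());

        // Every requested resource reappears in the error response with the same error code.
        DescribeConfigsResponse error =
            request.getErrorResponse(0, Errors.CLUSTER_AUTHORIZATION_FAILED.exception());
        System.out.println(error.data().results().size()); // 1
    }
}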
+ */ + +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.config.ConfigResource; +import org.apache.kafka.common.message.DescribeConfigsResponseData; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.protocol.Errors; + +import java.nio.ByteBuffer; +import java.util.Collection; +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; +import java.util.function.Function; +import java.util.stream.Collectors; + +public class DescribeConfigsResponse extends AbstractResponse { + + public static class Config { + private final ApiError error; + private final Collection entries; + + public Config(ApiError error, Collection entries) { + this.error = Objects.requireNonNull(error, "error"); + this.entries = Objects.requireNonNull(entries, "entries"); + } + + public ApiError error() { + return error; + } + + public Collection entries() { + return entries; + } + } + + public static class ConfigEntry { + private final String name; + private final String value; + private final boolean isSensitive; + private final ConfigSource source; + private final boolean readOnly; + private final Collection synonyms; + private final ConfigType type; + private final String documentation; + + public ConfigEntry(String name, String value, ConfigSource source, boolean isSensitive, boolean readOnly, + Collection synonyms) { + this(name, value, source, isSensitive, readOnly, synonyms, ConfigType.UNKNOWN, null); + } + + public ConfigEntry(String name, String value, ConfigSource source, boolean isSensitive, boolean readOnly, + Collection synonyms, ConfigType type, String documentation) { + + this.name = Objects.requireNonNull(name, "name"); + this.value = value; + this.source = Objects.requireNonNull(source, "source"); + this.isSensitive = isSensitive; + this.readOnly = readOnly; + this.synonyms = Objects.requireNonNull(synonyms, "synonyms"); + this.type = type; + this.documentation = documentation; + } + + public String name() { + return name; + } + + public String value() { + return value; + } + + public boolean isSensitive() { + return isSensitive; + } + + public ConfigSource source() { + return source; + } + + public boolean isReadOnly() { + return readOnly; + } + + public Collection synonyms() { + return synonyms; + } + + public ConfigType type() { + return type; + } + + public String documentation() { + return documentation; + } + } + + public enum ConfigSource { + UNKNOWN((byte) 0, org.apache.kafka.clients.admin.ConfigEntry.ConfigSource.UNKNOWN), + TOPIC_CONFIG((byte) 1, org.apache.kafka.clients.admin.ConfigEntry.ConfigSource.DYNAMIC_TOPIC_CONFIG), + DYNAMIC_BROKER_CONFIG((byte) 2, org.apache.kafka.clients.admin.ConfigEntry.ConfigSource.DYNAMIC_BROKER_CONFIG), + DYNAMIC_DEFAULT_BROKER_CONFIG((byte) 3, org.apache.kafka.clients.admin.ConfigEntry.ConfigSource.DYNAMIC_DEFAULT_BROKER_CONFIG), + STATIC_BROKER_CONFIG((byte) 4, org.apache.kafka.clients.admin.ConfigEntry.ConfigSource.STATIC_BROKER_CONFIG), + DEFAULT_CONFIG((byte) 5, org.apache.kafka.clients.admin.ConfigEntry.ConfigSource.DEFAULT_CONFIG), + DYNAMIC_BROKER_LOGGER_CONFIG((byte) 6, org.apache.kafka.clients.admin.ConfigEntry.ConfigSource.DYNAMIC_BROKER_LOGGER_CONFIG); + + final byte id; + private final org.apache.kafka.clients.admin.ConfigEntry.ConfigSource source; + private static final ConfigSource[] VALUES = values(); + + ConfigSource(byte id, org.apache.kafka.clients.admin.ConfigEntry.ConfigSource source) { + this.id = id; + 
this.source = source; + } + + public byte id() { + return id; + } + + public static ConfigSource forId(byte id) { + if (id < 0) + throw new IllegalArgumentException("id should be positive, id: " + id); + if (id >= VALUES.length) + return UNKNOWN; + return VALUES[id]; + } + + public org.apache.kafka.clients.admin.ConfigEntry.ConfigSource source() { + return source; + } + } + + public enum ConfigType { + UNKNOWN((byte) 0, org.apache.kafka.clients.admin.ConfigEntry.ConfigType.UNKNOWN), + BOOLEAN((byte) 1, org.apache.kafka.clients.admin.ConfigEntry.ConfigType.BOOLEAN), + STRING((byte) 2, org.apache.kafka.clients.admin.ConfigEntry.ConfigType.STRING), + INT((byte) 3, org.apache.kafka.clients.admin.ConfigEntry.ConfigType.INT), + SHORT((byte) 4, org.apache.kafka.clients.admin.ConfigEntry.ConfigType.SHORT), + LONG((byte) 5, org.apache.kafka.clients.admin.ConfigEntry.ConfigType.LONG), + DOUBLE((byte) 6, org.apache.kafka.clients.admin.ConfigEntry.ConfigType.DOUBLE), + LIST((byte) 7, org.apache.kafka.clients.admin.ConfigEntry.ConfigType.LIST), + CLASS((byte) 8, org.apache.kafka.clients.admin.ConfigEntry.ConfigType.CLASS), + PASSWORD((byte) 9, org.apache.kafka.clients.admin.ConfigEntry.ConfigType.PASSWORD); + + final byte id; + final org.apache.kafka.clients.admin.ConfigEntry.ConfigType type; + private static final ConfigType[] VALUES = values(); + + ConfigType(byte id, org.apache.kafka.clients.admin.ConfigEntry.ConfigType type) { + this.id = id; + this.type = type; + } + + public byte id() { + return id; + } + + public static ConfigType forId(byte id) { + if (id < 0) + throw new IllegalArgumentException("id should be positive, id: " + id); + if (id >= VALUES.length) + return UNKNOWN; + return VALUES[id]; + } + + public org.apache.kafka.clients.admin.ConfigEntry.ConfigType type() { + return type; + } + } + + public static class ConfigSynonym { + private final String name; + private final String value; + private final ConfigSource source; + + public ConfigSynonym(String name, String value, ConfigSource source) { + this.name = Objects.requireNonNull(name, "name"); + this.value = value; + this.source = Objects.requireNonNull(source, "source"); + } + + public String name() { + return name; + } + public String value() { + return value; + } + public ConfigSource source() { + return source; + } + } + + public Map resultMap() { + return data().results().stream().collect(Collectors.toMap( + configsResult -> + new ConfigResource(ConfigResource.Type.forId(configsResult.resourceType()), + configsResult.resourceName()), + Function.identity())); + } + + private final DescribeConfigsResponseData data; + + public DescribeConfigsResponse(DescribeConfigsResponseData data) { + super(ApiKeys.DESCRIBE_CONFIGS); + this.data = data; + } + + // This constructor should only be used after deserialization, it has special handling for version 0 + private DescribeConfigsResponse(DescribeConfigsResponseData data, short version) { + super(ApiKeys.DESCRIBE_CONFIGS); + this.data = data; + if (version == 0) { + for (DescribeConfigsResponseData.DescribeConfigsResult result : data.results()) { + for (DescribeConfigsResponseData.DescribeConfigsResourceResult config : result.configs()) { + if (config.isDefault()) { + config.setConfigSource(ConfigSource.DEFAULT_CONFIG.id); + } else { + if (result.resourceType() == ConfigResource.Type.BROKER.id()) { + config.setConfigSource(ConfigSource.STATIC_BROKER_CONFIG.id); + } else if (result.resourceType() == ConfigResource.Type.TOPIC.id()) { + config.setConfigSource(ConfigSource.TOPIC_CONFIG.id); + } 
else { + config.setConfigSource(ConfigSource.UNKNOWN.id); + } + } + } + } + } + } + + @Override + public DescribeConfigsResponseData data() { + return data; + } + + @Override + public int throttleTimeMs() { + return data.throttleTimeMs(); + } + + @Override + public Map errorCounts() { + Map errorCounts = new HashMap<>(); + data.results().forEach(response -> + updateErrorCounts(errorCounts, Errors.forCode(response.errorCode())) + ); + return errorCounts; + } + + public static DescribeConfigsResponse parse(ByteBuffer buffer, short version) { + return new DescribeConfigsResponse(new DescribeConfigsResponseData(new ByteBufferAccessor(buffer), version), version); + } + + @Override + public boolean shouldClientThrottle(short version) { + return version >= 2; + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/DescribeDelegationTokenRequest.java b/clients/src/main/java/org/apache/kafka/common/requests/DescribeDelegationTokenRequest.java new file mode 100644 index 0000000..9bf59e8 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/DescribeDelegationTokenRequest.java @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.message.DescribeDelegationTokenRequestData; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.protocol.Errors; +import org.apache.kafka.common.security.auth.KafkaPrincipal; + +import java.nio.ByteBuffer; +import java.util.List; +import java.util.stream.Collectors; + +public class DescribeDelegationTokenRequest extends AbstractRequest { + + public static class Builder extends AbstractRequest.Builder { + private final DescribeDelegationTokenRequestData data; + + public Builder(List owners) { + super(ApiKeys.DESCRIBE_DELEGATION_TOKEN); + this.data = new DescribeDelegationTokenRequestData() + .setOwners(owners == null ? 
null : owners + .stream() + .map(owner -> new DescribeDelegationTokenRequestData.DescribeDelegationTokenOwner() + .setPrincipalName(owner.getName()) + .setPrincipalType(owner.getPrincipalType())) + .collect(Collectors.toList())); + } + + @Override + public DescribeDelegationTokenRequest build(short version) { + return new DescribeDelegationTokenRequest(data, version); + } + + @Override + public String toString() { + return data.toString(); + } + } + + private final DescribeDelegationTokenRequestData data; + + public DescribeDelegationTokenRequest(DescribeDelegationTokenRequestData data, short version) { + super(ApiKeys.DESCRIBE_DELEGATION_TOKEN, version); + this.data = data; + } + + @Override + public DescribeDelegationTokenRequestData data() { + return data; + } + + public boolean ownersListEmpty() { + return data.owners() != null && data.owners().isEmpty(); + } + + @Override + public AbstractResponse getErrorResponse(int throttleTimeMs, Throwable e) { + return new DescribeDelegationTokenResponse(throttleTimeMs, Errors.forException(e)); + } + + public static DescribeDelegationTokenRequest parse(ByteBuffer buffer, short version) { + return new DescribeDelegationTokenRequest(new DescribeDelegationTokenRequestData( + new ByteBufferAccessor(buffer), version), version); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/DescribeDelegationTokenResponse.java b/clients/src/main/java/org/apache/kafka/common/requests/DescribeDelegationTokenResponse.java new file mode 100644 index 0000000..4a2162f --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/DescribeDelegationTokenResponse.java @@ -0,0 +1,118 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
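The DescribeDelegationTokenRequest.Builder above takes a list of KafkaPrincipal owners and passes a null list straight through (the delegation-token API treats that as "describe tokens for all owners"), while ownersListEmpty() distinguishes an explicitly empty list from a null one. A minimal sketch of driving the builder, assuming the standard KafkaPrincipal.USER_TYPE constant and the ApiKeys.latestVersion() helper; the sketch class name and the "alice" principal are illustrative only:

import java.util.Arrays;
import java.util.List;

import org.apache.kafka.common.protocol.ApiKeys;
import org.apache.kafka.common.requests.DescribeDelegationTokenRequest;
import org.apache.kafka.common.security.auth.KafkaPrincipal;

public class DescribeDelegationTokenRequestSketch {
    public static void main(String[] args) {
        List<KafkaPrincipal> owners =
            Arrays.asList(new KafkaPrincipal(KafkaPrincipal.USER_TYPE, "alice"));

        DescribeDelegationTokenRequest request =
            new DescribeDelegationTokenRequest.Builder(owners)
                .build(ApiKeys.DESCRIBE_DELEGATION_TOKEN.latestVersion());

        // false here: a non-null, non-empty owner list was supplied.
        System.out.println(request.ownersListEmpty());
    }
}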
+ */ +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.message.DescribeDelegationTokenResponseData; +import org.apache.kafka.common.message.DescribeDelegationTokenResponseData.DescribedDelegationToken; +import org.apache.kafka.common.message.DescribeDelegationTokenResponseData.DescribedDelegationTokenRenewer; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.protocol.Errors; +import org.apache.kafka.common.security.auth.KafkaPrincipal; +import org.apache.kafka.common.security.token.delegation.DelegationToken; +import org.apache.kafka.common.security.token.delegation.TokenInformation; + +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +public class DescribeDelegationTokenResponse extends AbstractResponse { + + private final DescribeDelegationTokenResponseData data; + + public DescribeDelegationTokenResponse(int throttleTimeMs, Errors error, List tokens) { + super(ApiKeys.DESCRIBE_DELEGATION_TOKEN); + List describedDelegationTokenList = tokens + .stream() + .map(dt -> new DescribedDelegationToken() + .setTokenId(dt.tokenInfo().tokenId()) + .setPrincipalType(dt.tokenInfo().owner().getPrincipalType()) + .setPrincipalName(dt.tokenInfo().owner().getName()) + .setIssueTimestamp(dt.tokenInfo().issueTimestamp()) + .setMaxTimestamp(dt.tokenInfo().maxTimestamp()) + .setExpiryTimestamp(dt.tokenInfo().expiryTimestamp()) + .setHmac(dt.hmac()) + .setRenewers(dt.tokenInfo().renewers() + .stream() + .map(r -> new DescribedDelegationTokenRenewer().setPrincipalName(r.getName()).setPrincipalType(r.getPrincipalType())) + .collect(Collectors.toList()))) + .collect(Collectors.toList()); + + this.data = new DescribeDelegationTokenResponseData() + .setThrottleTimeMs(throttleTimeMs) + .setErrorCode(error.code()) + .setTokens(describedDelegationTokenList); + } + + public DescribeDelegationTokenResponse(int throttleTimeMs, Errors error) { + this(throttleTimeMs, error, new ArrayList<>()); + } + + public DescribeDelegationTokenResponse(DescribeDelegationTokenResponseData data) { + super(ApiKeys.DESCRIBE_DELEGATION_TOKEN); + this.data = data; + } + + public static DescribeDelegationTokenResponse parse(ByteBuffer buffer, short version) { + return new DescribeDelegationTokenResponse(new DescribeDelegationTokenResponseData( + new ByteBufferAccessor(buffer), version)); + } + + @Override + public Map errorCounts() { + return errorCounts(error()); + } + + @Override + public DescribeDelegationTokenResponseData data() { + return data; + } + + @Override + public int throttleTimeMs() { + return data.throttleTimeMs(); + } + + public Errors error() { + return Errors.forCode(data.errorCode()); + } + + public List tokens() { + return data.tokens() + .stream() + .map(ddt -> new DelegationToken(new TokenInformation( + ddt.tokenId(), + new KafkaPrincipal(ddt.principalType(), ddt.principalName()), + ddt.renewers() + .stream() + .map(ddtr -> new KafkaPrincipal(ddtr.principalType(), ddtr.principalName())) + .collect(Collectors.toList()), ddt.issueTimestamp(), ddt.maxTimestamp(), ddt.expiryTimestamp()), + ddt.hmac())) + .collect(Collectors.toList()); + } + + public boolean hasError() { + return error() != Errors.NONE; + } + + @Override + public boolean shouldClientThrottle(short version) { + return version >= 1; + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/DescribeGroupsRequest.java 
b/clients/src/main/java/org/apache/kafka/common/requests/DescribeGroupsRequest.java new file mode 100644 index 0000000..eff5bb9 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/DescribeGroupsRequest.java @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.message.DescribeGroupsRequestData; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.protocol.Errors; + +import java.nio.ByteBuffer; + +import static org.apache.kafka.common.requests.AbstractResponse.DEFAULT_THROTTLE_TIME; + +public class DescribeGroupsRequest extends AbstractRequest { + public static class Builder extends AbstractRequest.Builder { + private final DescribeGroupsRequestData data; + + public Builder(DescribeGroupsRequestData data) { + super(ApiKeys.DESCRIBE_GROUPS); + this.data = data; + } + + @Override + public DescribeGroupsRequest build(short version) { + return new DescribeGroupsRequest(data, version); + } + + @Override + public String toString() { + return data.toString(); + } + } + + private final DescribeGroupsRequestData data; + + private DescribeGroupsRequest(DescribeGroupsRequestData data, short version) { + super(ApiKeys.DESCRIBE_GROUPS, version); + this.data = data; + } + + @Override + public DescribeGroupsRequestData data() { + return data; + } + + @Override + public AbstractResponse getErrorResponse(int throttleTimeMs, Throwable e) { + if (version() == 0) { + return DescribeGroupsResponse.fromError(DEFAULT_THROTTLE_TIME, Errors.forException(e), data.groups()); + } else { + return DescribeGroupsResponse.fromError(throttleTimeMs, Errors.forException(e), data.groups()); + } + } + + public static DescribeGroupsRequest parse(ByteBuffer buffer, short version) { + return new DescribeGroupsRequest(new DescribeGroupsRequestData(new ByteBufferAccessor(buffer), version), version); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/DescribeGroupsResponse.java b/clients/src/main/java/org/apache/kafka/common/requests/DescribeGroupsResponse.java new file mode 100644 index 0000000..360caf0 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/DescribeGroupsResponse.java @@ -0,0 +1,151 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.message.DescribeGroupsResponseData; +import org.apache.kafka.common.message.DescribeGroupsResponseData.DescribedGroup; +import org.apache.kafka.common.message.DescribeGroupsResponseData.DescribedGroupMember; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.protocol.Errors; +import org.apache.kafka.common.utils.Utils; + +import java.nio.ByteBuffer; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; + +public class DescribeGroupsResponse extends AbstractResponse { + + public static final int AUTHORIZED_OPERATIONS_OMITTED = Integer.MIN_VALUE; + + /** + * Possible per-group error codes: + * + * COORDINATOR_LOAD_IN_PROGRESS (14) + * COORDINATOR_NOT_AVAILABLE (15) + * NOT_COORDINATOR (16) + * AUTHORIZATION_FAILED (29) + */ + + private final DescribeGroupsResponseData data; + + public DescribeGroupsResponse(DescribeGroupsResponseData data) { + super(ApiKeys.DESCRIBE_GROUPS); + this.data = data; + } + + public static DescribedGroupMember groupMember( + final String memberId, + final String groupInstanceId, + final String clientId, + final String clientHost, + final byte[] assignment, + final byte[] metadata) { + return new DescribedGroupMember() + .setMemberId(memberId) + .setGroupInstanceId(groupInstanceId) + .setClientId(clientId) + .setClientHost(clientHost) + .setMemberAssignment(assignment) + .setMemberMetadata(metadata); + } + + public static DescribedGroup groupMetadata( + final String groupId, + final Errors error, + final String state, + final String protocolType, + final String protocol, + final List members, + final Set authorizedOperations) { + DescribedGroup groupMetadata = new DescribedGroup(); + groupMetadata.setGroupId(groupId) + .setErrorCode(error.code()) + .setGroupState(state) + .setProtocolType(protocolType) + .setProtocolData(protocol) + .setMembers(members) + .setAuthorizedOperations(Utils.to32BitField(authorizedOperations)); + return groupMetadata; + } + + public static DescribedGroup groupMetadata( + final String groupId, + final Errors error, + final String state, + final String protocolType, + final String protocol, + final List members, + final int authorizedOperations) { + DescribedGroup groupMetadata = new DescribedGroup(); + groupMetadata.setGroupId(groupId) + .setErrorCode(error.code()) + .setGroupState(state) + .setProtocolType(protocolType) + .setProtocolData(protocol) + .setMembers(members) + .setAuthorizedOperations(authorizedOperations); + return groupMetadata; + } + + @Override + public DescribeGroupsResponseData data() { + return data; + } + + @Override + public int throttleTimeMs() { + return data.throttleTimeMs(); + } + + public static final String UNKNOWN_STATE = ""; + public static final String UNKNOWN_PROTOCOL_TYPE = ""; + public static final String UNKNOWN_PROTOCOL = ""; + + @Override + public Map errorCounts() { + Map errorCounts = new HashMap<>(); + data.groups().forEach(describedGroup -> + 
updateErrorCounts(errorCounts, Errors.forCode(describedGroup.errorCode()))); + return errorCounts; + } + + public static DescribedGroup forError(String groupId, Errors error) { + return groupMetadata(groupId, error, DescribeGroupsResponse.UNKNOWN_STATE, DescribeGroupsResponse.UNKNOWN_PROTOCOL_TYPE, + DescribeGroupsResponse.UNKNOWN_PROTOCOL, Collections.emptyList(), AUTHORIZED_OPERATIONS_OMITTED); + } + + public static DescribeGroupsResponse fromError(int throttleTimeMs, Errors error, List groupIds) { + DescribeGroupsResponseData describeGroupsResponseData = new DescribeGroupsResponseData(); + describeGroupsResponseData.setThrottleTimeMs(throttleTimeMs); + for (String groupId : groupIds) + describeGroupsResponseData.groups().add(DescribeGroupsResponse.forError(groupId, error)); + return new DescribeGroupsResponse(describeGroupsResponseData); + } + + public static DescribeGroupsResponse parse(ByteBuffer buffer, short version) { + return new DescribeGroupsResponse(new DescribeGroupsResponseData(new ByteBufferAccessor(buffer), version)); + } + + @Override + public boolean shouldClientThrottle(short version) { + return version >= 2; + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/DescribeLogDirsRequest.java b/clients/src/main/java/org/apache/kafka/common/requests/DescribeLogDirsRequest.java new file mode 100644 index 0000000..05ca4f0 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/DescribeLogDirsRequest.java @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
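DescribeGroupsResponse.fromError() fans a single request-level error out to one DescribedGroup per requested group id, and errorCounts() then tallies the error once per group. A small sketch using only the static helpers defined above; the group names and the sketch class name are illustrative:

import java.util.Arrays;
import java.util.Map;

import org.apache.kafka.common.protocol.Errors;
import org.apache.kafka.common.requests.DescribeGroupsResponse;

public class DescribeGroupsResponseSketch {
    public static void main(String[] args) {
        // The all-groups-failed response produced when the whole request is rejected.
        DescribeGroupsResponse response = DescribeGroupsResponse.fromError(
            0, Errors.COORDINATOR_NOT_AVAILABLE, Arrays.asList("group-a", "group-b"));

        // Both groups carry the same error code, so the tally is 2.
        Map<Errors, Integer> counts = response.errorCounts();
        System.out.println(counts.get(Errors.COORDINATOR_NOT_AVAILABLE));
    }
}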
+ */ + +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.message.DescribeLogDirsRequestData; +import org.apache.kafka.common.message.DescribeLogDirsResponseData; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; + +import java.nio.ByteBuffer; + +public class DescribeLogDirsRequest extends AbstractRequest { + + private final DescribeLogDirsRequestData data; + + public static class Builder extends AbstractRequest.Builder { + private final DescribeLogDirsRequestData data; + + public Builder(DescribeLogDirsRequestData data) { + super(ApiKeys.DESCRIBE_LOG_DIRS); + this.data = data; + } + + @Override + public DescribeLogDirsRequest build(short version) { + return new DescribeLogDirsRequest(data, version); + } + + @Override + public String toString() { + return data.toString(); + } + } + + public DescribeLogDirsRequest(DescribeLogDirsRequestData data, short version) { + super(ApiKeys.DESCRIBE_LOG_DIRS, version); + this.data = data; + } + + @Override + public DescribeLogDirsRequestData data() { + return data; + } + + @Override + public AbstractResponse getErrorResponse(int throttleTimeMs, Throwable e) { + return new DescribeLogDirsResponse(new DescribeLogDirsResponseData().setThrottleTimeMs(throttleTimeMs)); + } + + public boolean isAllTopicPartitions() { + return data.topics() == null; + } + + public static DescribeLogDirsRequest parse(ByteBuffer buffer, short version) { + return new DescribeLogDirsRequest(new DescribeLogDirsRequestData(new ByteBufferAccessor(buffer), version), version); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/DescribeLogDirsResponse.java b/clients/src/main/java/org/apache/kafka/common/requests/DescribeLogDirsResponse.java new file mode 100644 index 0000000..cd1326b --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/DescribeLogDirsResponse.java @@ -0,0 +1,138 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
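DescribeLogDirsRequest treats a null topic list as "describe every topic partition", which is what isAllTopicPartitions() checks. A minimal sketch, assuming the generated DescribeLogDirsRequestData setter accepts a null topic collection; the sketch class name is illustrative:

import org.apache.kafka.common.message.DescribeLogDirsRequestData;
import org.apache.kafka.common.protocol.ApiKeys;
import org.apache.kafka.common.requests.DescribeLogDirsRequest;

public class DescribeLogDirsRequestSketch {
    public static void main(String[] args) {
        // A null topic list is the wire-level way to ask about all topic partitions.
        DescribeLogDirsRequestData data = new DescribeLogDirsRequestData().setTopics(null);

        DescribeLogDirsRequest request = new DescribeLogDirsRequest.Builder(data)
            .build(ApiKeys.DESCRIBE_LOG_DIRS.latestVersion());

        System.out.println(request.isAllTopicPartitions()); // true
    }
}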
+ */ + +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.message.DescribeLogDirsResponseData; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.protocol.Errors; + +import java.nio.ByteBuffer; +import java.util.HashMap; +import java.util.Map; + + +public class DescribeLogDirsResponse extends AbstractResponse { + + public static final long INVALID_OFFSET_LAG = -1L; + + private final DescribeLogDirsResponseData data; + + public DescribeLogDirsResponse(DescribeLogDirsResponseData data) { + super(ApiKeys.DESCRIBE_LOG_DIRS); + this.data = data; + } + + @Override + public DescribeLogDirsResponseData data() { + return data; + } + + @Override + public int throttleTimeMs() { + return data.throttleTimeMs(); + } + + @Override + public Map errorCounts() { + Map errorCounts = new HashMap<>(); + data.results().forEach(result -> { + updateErrorCounts(errorCounts, Errors.forCode(result.errorCode())); + }); + return errorCounts; + } + + public static DescribeLogDirsResponse parse(ByteBuffer buffer, short version) { + return new DescribeLogDirsResponse(new DescribeLogDirsResponseData(new ByteBufferAccessor(buffer), version)); + } + + // Note this class is part of the public API, reachable from Admin.describeLogDirs() + /** + * Possible error code: + * + * KAFKA_STORAGE_ERROR (56) + * UNKNOWN (-1) + * + * @deprecated Deprecated Since Kafka 2.7. + * Use {@link org.apache.kafka.clients.admin.DescribeLogDirsResult#descriptions()} + * and {@link org.apache.kafka.clients.admin.DescribeLogDirsResult#allDescriptions()} to access the replacement + * class {@link org.apache.kafka.clients.admin.LogDirDescription}. + */ + @Deprecated + static public class LogDirInfo { + public final Errors error; + public final Map replicaInfos; + + public LogDirInfo(Errors error, Map replicaInfos) { + this.error = error; + this.replicaInfos = replicaInfos; + } + + @Override + public String toString() { + StringBuilder builder = new StringBuilder(); + builder.append("(error=") + .append(error) + .append(", replicas=") + .append(replicaInfos) + .append(")"); + return builder.toString(); + } + } + + // Note this class is part of the public API, reachable from Admin.describeLogDirs() + + /** + * @deprecated Deprecated Since Kafka 2.7. + * Use {@link org.apache.kafka.clients.admin.DescribeLogDirsResult#descriptions()} + * and {@link org.apache.kafka.clients.admin.DescribeLogDirsResult#allDescriptions()} to access the replacement + * class {@link org.apache.kafka.clients.admin.ReplicaInfo}. 
+ */ + @Deprecated + static public class ReplicaInfo { + + public final long size; + public final long offsetLag; + public final boolean isFuture; + + public ReplicaInfo(long size, long offsetLag, boolean isFuture) { + this.size = size; + this.offsetLag = offsetLag; + this.isFuture = isFuture; + } + + @Override + public String toString() { + StringBuilder builder = new StringBuilder(); + builder.append("(size=") + .append(size) + .append(", offsetLag=") + .append(offsetLag) + .append(", isFuture=") + .append(isFuture) + .append(")"); + return builder.toString(); + } + } + + @Override + public boolean shouldClientThrottle(short version) { + return version >= 1; + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/DescribeProducersRequest.java b/clients/src/main/java/org/apache/kafka/common/requests/DescribeProducersRequest.java new file mode 100644 index 0000000..39aab22 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/DescribeProducersRequest.java @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
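The deprecation notes on LogDirInfo and ReplicaInfo above point callers at the Admin API replacement. A hedged sketch of that replacement path, assuming the public Admin.describeLogDirs(), DescribeLogDirsResult.allDescriptions() and LogDirDescription APIs; the bootstrap address and broker id 0 are placeholders, not values from this patch:

import java.util.Collections;
import java.util.Map;
import java.util.Properties;

import org.apache.kafka.clients.admin.Admin;
import org.apache.kafka.clients.admin.AdminClientConfig;
import org.apache.kafka.clients.admin.DescribeLogDirsResult;
import org.apache.kafka.clients.admin.LogDirDescription;

public class DescribeLogDirsAdminSketch {
    public static void main(String[] args) throws Exception {
        Properties props = new Properties();
        props.put(AdminClientConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9092"); // placeholder

        try (Admin admin = Admin.create(props)) {
            // Broker id 0 is an example; pass whichever broker ids you want to inspect.
            DescribeLogDirsResult result = admin.describeLogDirs(Collections.singleton(0));

            // allDescriptions() replaces the deprecated LogDirInfo map shown above.
            Map<Integer, Map<String, LogDirDescription>> byBroker = result.allDescriptions().get();
            byBroker.forEach((broker, dirs) -> dirs.forEach((dir, description) ->
                System.out.println(broker + " " + dir + " error=" + description.error())));
        }
    }
}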
+ */ +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.message.DescribeProducersRequestData; +import org.apache.kafka.common.message.DescribeProducersRequestData.TopicRequest; +import org.apache.kafka.common.message.DescribeProducersResponseData; +import org.apache.kafka.common.message.DescribeProducersResponseData.PartitionResponse; +import org.apache.kafka.common.message.DescribeProducersResponseData.TopicResponse; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.protocol.Errors; + +import java.nio.ByteBuffer; + +public class DescribeProducersRequest extends AbstractRequest { + public static class Builder extends AbstractRequest.Builder { + public final DescribeProducersRequestData data; + + public Builder(DescribeProducersRequestData data) { + super(ApiKeys.DESCRIBE_PRODUCERS); + this.data = data; + } + + public DescribeProducersRequestData.TopicRequest addTopic(String topic) { + DescribeProducersRequestData.TopicRequest topicRequest = + new DescribeProducersRequestData.TopicRequest().setName(topic); + data.topics().add(topicRequest); + return topicRequest; + } + + @Override + public DescribeProducersRequest build(short version) { + return new DescribeProducersRequest(data, version); + } + + @Override + public String toString() { + return data.toString(); + } + } + + private final DescribeProducersRequestData data; + + private DescribeProducersRequest(DescribeProducersRequestData data, short version) { + super(ApiKeys.DESCRIBE_PRODUCERS, version); + this.data = data; + } + + @Override + public DescribeProducersRequestData data() { + return data; + } + + @Override + public DescribeProducersResponse getErrorResponse(int throttleTimeMs, Throwable e) { + Errors error = Errors.forException(e); + DescribeProducersResponseData response = new DescribeProducersResponseData(); + for (TopicRequest topicRequest : data.topics()) { + TopicResponse topicResponse = new TopicResponse() + .setName(topicRequest.name()); + for (int partitionId : topicRequest.partitionIndexes()) { + topicResponse.partitions().add( + new PartitionResponse() + .setPartitionIndex(partitionId) + .setErrorCode(error.code()) + ); + } + response.topics().add(topicResponse); + } + return new DescribeProducersResponse(response); + } + + public static DescribeProducersRequest parse(ByteBuffer buffer, short version) { + return new DescribeProducersRequest(new DescribeProducersRequestData( + new ByteBufferAccessor(buffer), version), version); + } + + @Override + public String toString(boolean verbose) { + return data.toString(); + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/DescribeProducersResponse.java b/clients/src/main/java/org/apache/kafka/common/requests/DescribeProducersResponse.java new file mode 100644 index 0000000..74e9437 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/DescribeProducersResponse.java @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.message.DescribeProducersResponseData; +import org.apache.kafka.common.message.DescribeProducersResponseData.PartitionResponse; +import org.apache.kafka.common.message.DescribeProducersResponseData.TopicResponse; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.protocol.Errors; + +import java.nio.ByteBuffer; +import java.util.HashMap; +import java.util.Map; + +public class DescribeProducersResponse extends AbstractResponse { + private final DescribeProducersResponseData data; + + public DescribeProducersResponse(DescribeProducersResponseData data) { + super(ApiKeys.DESCRIBE_PRODUCERS); + this.data = data; + } + + @Override + public DescribeProducersResponseData data() { + return data; + } + + @Override + public Map errorCounts() { + Map errorCounts = new HashMap<>(); + for (TopicResponse topicResponse : data.topics()) { + for (PartitionResponse partitionResponse : topicResponse.partitions()) { + updateErrorCounts(errorCounts, Errors.forCode(partitionResponse.errorCode())); + } + } + return errorCounts; + } + + public static DescribeProducersResponse parse(ByteBuffer buffer, short version) { + return new DescribeProducersResponse(new DescribeProducersResponseData( + new ByteBufferAccessor(buffer), version)); + } + + @Override + public String toString() { + return data.toString(); + } + + @Override + public int throttleTimeMs() { + return data.throttleTimeMs(); + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/DescribeQuorumRequest.java b/clients/src/main/java/org/apache/kafka/common/requests/DescribeQuorumRequest.java new file mode 100644 index 0000000..acdb11c --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/DescribeQuorumRequest.java @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
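DescribeProducersRequest.Builder.addTopic() registers a topic and hands back its TopicRequest so partitions can be attached before building. A minimal sketch, assuming the generated setPartitionIndexes() setter on TopicRequest; the "payments" topic and the sketch class name are illustrative:

import java.util.Arrays;

import org.apache.kafka.common.message.DescribeProducersRequestData;
import org.apache.kafka.common.protocol.ApiKeys;
import org.apache.kafka.common.requests.DescribeProducersRequest;

public class DescribeProducersRequestSketch {
    public static void main(String[] args) {
        DescribeProducersRequest.Builder builder =
            new DescribeProducersRequest.Builder(new DescribeProducersRequestData());

        // addTopic() adds the topic to the request data and returns its TopicRequest.
        builder.addTopic("payments").setPartitionIndexes(Arrays.asList(0, 1));

        DescribeProducersRequest request =
            builder.build(ApiKeys.DESCRIBE_PRODUCERS.latestVersion());
        System.out.println(request.data());
    }
}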
+ */ +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.message.DescribeQuorumRequestData; +import org.apache.kafka.common.message.DescribeQuorumResponseData; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.protocol.Errors; + +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.stream.Collectors; + +public class DescribeQuorumRequest extends AbstractRequest { + public static class Builder extends AbstractRequest.Builder { + private final DescribeQuorumRequestData data; + + public Builder(DescribeQuorumRequestData data) { + super(ApiKeys.DESCRIBE_QUORUM); + this.data = data; + } + + @Override + public DescribeQuorumRequest build(short version) { + return new DescribeQuorumRequest(data, version); + } + + @Override + public String toString() { + return data.toString(); + } + } + + private final DescribeQuorumRequestData data; + + private DescribeQuorumRequest(DescribeQuorumRequestData data, short version) { + super(ApiKeys.DESCRIBE_QUORUM, version); + this.data = data; + } + + public static DescribeQuorumRequest parse(ByteBuffer buffer, short version) { + return new DescribeQuorumRequest(new DescribeQuorumRequestData(new ByteBufferAccessor(buffer), version), version); + } + + public static DescribeQuorumRequestData singletonRequest(TopicPartition topicPartition) { + return new DescribeQuorumRequestData() + .setTopics(Collections.singletonList( + new DescribeQuorumRequestData.TopicData() + .setTopicName(topicPartition.topic()) + .setPartitions(Collections.singletonList( + new DescribeQuorumRequestData.PartitionData() + .setPartitionIndex(topicPartition.partition())) + ))); + } + + @Override + public DescribeQuorumRequestData data() { + return data; + } + + @Override + public AbstractResponse getErrorResponse(int throttleTimeMs, Throwable e) { + return new DescribeQuorumResponse(getTopLevelErrorResponse(Errors.forException(e))); + } + + public static DescribeQuorumResponseData getPartitionLevelErrorResponse(DescribeQuorumRequestData data, Errors error) { + short errorCode = error.code(); + + List topicResponses = new ArrayList<>(); + for (DescribeQuorumRequestData.TopicData topic : data.topics()) { + topicResponses.add( + new DescribeQuorumResponseData.TopicData() + .setTopicName(topic.topicName()) + .setPartitions(topic.partitions().stream().map( + requestPartition -> new DescribeQuorumResponseData.PartitionData() + .setPartitionIndex(requestPartition.partitionIndex()) + .setErrorCode(errorCode) + ).collect(Collectors.toList()))); + } + + return new DescribeQuorumResponseData().setTopics(topicResponses); + } + + public static DescribeQuorumResponseData getTopLevelErrorResponse(Errors topLevelError) { + return new DescribeQuorumResponseData().setErrorCode(topLevelError.code()); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/DescribeQuorumResponse.java b/clients/src/main/java/org/apache/kafka/common/requests/DescribeQuorumResponse.java new file mode 100644 index 0000000..cbf945b --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/DescribeQuorumResponse.java @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.message.DescribeQuorumResponseData; +import org.apache.kafka.common.message.DescribeQuorumResponseData.ReplicaState; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.protocol.Errors; + +import java.nio.ByteBuffer; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** + * Possible error codes. + * + * Top level errors: + * - {@link Errors#CLUSTER_AUTHORIZATION_FAILED} + * - {@link Errors#BROKER_NOT_AVAILABLE} + * + * Partition level errors: + * - {@link Errors#INVALID_REQUEST} + * - {@link Errors#UNKNOWN_TOPIC_OR_PARTITION} + */ +public class DescribeQuorumResponse extends AbstractResponse { + private final DescribeQuorumResponseData data; + + public DescribeQuorumResponse(DescribeQuorumResponseData data) { + super(ApiKeys.DESCRIBE_QUORUM); + this.data = data; + } + + @Override + public Map errorCounts() { + Map errors = new HashMap<>(); + + errors.put(Errors.forCode(data.errorCode()), 1); + + for (DescribeQuorumResponseData.TopicData topicResponse : data.topics()) { + for (DescribeQuorumResponseData.PartitionData partitionResponse : topicResponse.partitions()) { + updateErrorCounts(errors, Errors.forCode(partitionResponse.errorCode())); + } + } + return errors; + } + + @Override + public DescribeQuorumResponseData data() { + return data; + } + + @Override + public int throttleTimeMs() { + return DEFAULT_THROTTLE_TIME; + } + + public static DescribeQuorumResponseData singletonResponse(TopicPartition topicPartition, + int leaderId, + int leaderEpoch, + long highWatermark, + List voterStates, + List observerStates) { + return new DescribeQuorumResponseData() + .setTopics(Collections.singletonList(new DescribeQuorumResponseData.TopicData() + .setTopicName(topicPartition.topic()) + .setPartitions(Collections.singletonList(new DescribeQuorumResponseData.PartitionData() + .setErrorCode(Errors.NONE.code()) + .setLeaderId(leaderId) + .setLeaderEpoch(leaderEpoch) + .setHighWatermark(highWatermark) + .setCurrentVoters(voterStates) + .setObservers(observerStates))))); + } + + public static DescribeQuorumResponse parse(ByteBuffer buffer, short version) { + return new DescribeQuorumResponse(new DescribeQuorumResponseData(new ByteBufferAccessor(buffer), version)); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/DescribeTransactionsRequest.java b/clients/src/main/java/org/apache/kafka/common/requests/DescribeTransactionsRequest.java new file mode 100644 index 0000000..a6e44fa --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/DescribeTransactionsRequest.java @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
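DescribeQuorumRequest.singletonRequest() wraps a single topic partition into the request data, and getTopLevelErrorResponse() is the shape returned when the whole request is rejected. A short sketch using only those static helpers; the "__cluster_metadata" partition name is written out literally here rather than taken from a Kafka constant:

import org.apache.kafka.common.TopicPartition;
import org.apache.kafka.common.message.DescribeQuorumRequestData;
import org.apache.kafka.common.message.DescribeQuorumResponseData;
import org.apache.kafka.common.protocol.Errors;
import org.apache.kafka.common.requests.DescribeQuorumRequest;

public class DescribeQuorumSketch {
    public static void main(String[] args) {
        TopicPartition metadataPartition = new TopicPartition("__cluster_metadata", 0);

        DescribeQuorumRequestData data =
            DescribeQuorumRequest.singletonRequest(metadataPartition);
        System.out.println(data.topics().size()); // 1

        // Top-level rejection: only the request-level error code is populated.
        DescribeQuorumResponseData error =
            DescribeQuorumRequest.getTopLevelErrorResponse(Errors.CLUSTER_AUTHORIZATION_FAILED);
        System.out.println(Errors.forCode(error.errorCode()));
    }
}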
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.message.DescribeTransactionsRequestData; +import org.apache.kafka.common.message.DescribeTransactionsResponseData; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.protocol.Errors; + +import java.nio.ByteBuffer; + +public class DescribeTransactionsRequest extends AbstractRequest { + public static class Builder extends AbstractRequest.Builder { + public final DescribeTransactionsRequestData data; + + public Builder(DescribeTransactionsRequestData data) { + super(ApiKeys.DESCRIBE_TRANSACTIONS); + this.data = data; + } + + @Override + public DescribeTransactionsRequest build(short version) { + return new DescribeTransactionsRequest(data, version); + } + + @Override + public String toString() { + return data.toString(); + } + } + + private final DescribeTransactionsRequestData data; + + private DescribeTransactionsRequest(DescribeTransactionsRequestData data, short version) { + super(ApiKeys.DESCRIBE_TRANSACTIONS, version); + this.data = data; + } + + @Override + public DescribeTransactionsRequestData data() { + return data; + } + + @Override + public DescribeTransactionsResponse getErrorResponse(int throttleTimeMs, Throwable e) { + Errors error = Errors.forException(e); + DescribeTransactionsResponseData response = new DescribeTransactionsResponseData() + .setThrottleTimeMs(throttleTimeMs); + + for (String transactionalId : data.transactionalIds()) { + DescribeTransactionsResponseData.TransactionState transactionState = + new DescribeTransactionsResponseData.TransactionState() + .setTransactionalId(transactionalId) + .setErrorCode(error.code()); + response.transactionStates().add(transactionState); + } + return new DescribeTransactionsResponse(response); + } + + public static DescribeTransactionsRequest parse(ByteBuffer buffer, short version) { + return new DescribeTransactionsRequest(new DescribeTransactionsRequestData( + new ByteBufferAccessor(buffer), version), version); + } + + @Override + public String toString(boolean verbose) { + return data.toString(); + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/DescribeTransactionsResponse.java b/clients/src/main/java/org/apache/kafka/common/requests/DescribeTransactionsResponse.java new file mode 100644 index 0000000..cf151b3 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/DescribeTransactionsResponse.java @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.message.DescribeTransactionsResponseData; +import org.apache.kafka.common.message.DescribeTransactionsResponseData.TransactionState; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.protocol.Errors; + +import java.nio.ByteBuffer; +import java.util.HashMap; +import java.util.Map; + +public class DescribeTransactionsResponse extends AbstractResponse { + private final DescribeTransactionsResponseData data; + + public DescribeTransactionsResponse(DescribeTransactionsResponseData data) { + super(ApiKeys.DESCRIBE_TRANSACTIONS); + this.data = data; + } + + @Override + public DescribeTransactionsResponseData data() { + return data; + } + + @Override + public Map errorCounts() { + Map errorCounts = new HashMap<>(); + for (TransactionState transactionState : data.transactionStates()) { + Errors error = Errors.forCode(transactionState.errorCode()); + updateErrorCounts(errorCounts, error); + } + return errorCounts; + } + + public static DescribeTransactionsResponse parse(ByteBuffer buffer, short version) { + return new DescribeTransactionsResponse(new DescribeTransactionsResponseData( + new ByteBufferAccessor(buffer), version)); + } + + @Override + public String toString() { + return data.toString(); + } + + @Override + public int throttleTimeMs() { + return data.throttleTimeMs(); + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/DescribeUserScramCredentialsRequest.java b/clients/src/main/java/org/apache/kafka/common/requests/DescribeUserScramCredentialsRequest.java new file mode 100644 index 0000000..0142e5a --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/DescribeUserScramCredentialsRequest.java @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
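DescribeTransactionsRequest.getErrorResponse() fans a single failure out to every requested transactional id, which DescribeTransactionsResponse.errorCounts() then tallies per id. A sketch of that path, assuming the generated setTransactionalIds() setter; the transactional ids and the sketch class name are illustrative:

import java.util.Arrays;
import java.util.Map;

import org.apache.kafka.common.errors.TransactionalIdAuthorizationException;
import org.apache.kafka.common.message.DescribeTransactionsRequestData;
import org.apache.kafka.common.protocol.ApiKeys;
import org.apache.kafka.common.protocol.Errors;
import org.apache.kafka.common.requests.DescribeTransactionsRequest;
import org.apache.kafka.common.requests.DescribeTransactionsResponse;

public class DescribeTransactionsSketch {
    public static void main(String[] args) {
        DescribeTransactionsRequest request = new DescribeTransactionsRequest.Builder(
            new DescribeTransactionsRequestData()
                .setTransactionalIds(Arrays.asList("txn-1", "txn-2")))
            .build(ApiKeys.DESCRIBE_TRANSACTIONS.latestVersion());

        // One TransactionState per requested id, each carrying the mapped error code.
        DescribeTransactionsResponse response = request.getErrorResponse(
            0, new TransactionalIdAuthorizationException("denied"));

        Map<Errors, Integer> counts = response.errorCounts();
        System.out.println(counts.get(Errors.TRANSACTIONAL_ID_AUTHORIZATION_FAILED)); // 2
    }
}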
+ */ +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.message.DescribeUserScramCredentialsRequestData; +import org.apache.kafka.common.message.DescribeUserScramCredentialsResponseData; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; + +import java.nio.ByteBuffer; + +public class DescribeUserScramCredentialsRequest extends AbstractRequest { + + public static class Builder extends AbstractRequest.Builder { + private final DescribeUserScramCredentialsRequestData data; + + public Builder(DescribeUserScramCredentialsRequestData data) { + super(ApiKeys.DESCRIBE_USER_SCRAM_CREDENTIALS); + this.data = data; + } + + @Override + public DescribeUserScramCredentialsRequest build(short version) { + return new DescribeUserScramCredentialsRequest(data, version); + } + + @Override + public String toString() { + return data.toString(); + } + } + + private final DescribeUserScramCredentialsRequestData data; + private final short version; + + private DescribeUserScramCredentialsRequest(DescribeUserScramCredentialsRequestData data, short version) { + super(ApiKeys.DESCRIBE_USER_SCRAM_CREDENTIALS, version); + this.data = data; + this.version = version; + } + + public static DescribeUserScramCredentialsRequest parse(ByteBuffer buffer, short version) { + return new DescribeUserScramCredentialsRequest(new DescribeUserScramCredentialsRequestData( + new ByteBufferAccessor(buffer), version), version); + } + + @Override + public DescribeUserScramCredentialsRequestData data() { + return data; + } + + @Override + public AbstractResponse getErrorResponse(int throttleTimeMs, Throwable e) { + ApiError apiError = ApiError.fromThrowable(e); + return new DescribeUserScramCredentialsResponse(new DescribeUserScramCredentialsResponseData() + .setErrorCode(apiError.error().code()) + .setErrorMessage(apiError.message())); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/DescribeUserScramCredentialsResponse.java b/clients/src/main/java/org/apache/kafka/common/requests/DescribeUserScramCredentialsResponse.java new file mode 100644 index 0000000..001cefa --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/DescribeUserScramCredentialsResponse.java @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.message.DescribeUserScramCredentialsResponseData; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.protocol.Errors; + +import java.nio.ByteBuffer; +import java.util.Map; + +public class DescribeUserScramCredentialsResponse extends AbstractResponse { + + private final DescribeUserScramCredentialsResponseData data; + + public DescribeUserScramCredentialsResponse(DescribeUserScramCredentialsResponseData responseData) { + super(ApiKeys.DESCRIBE_USER_SCRAM_CREDENTIALS); + this.data = responseData; + } + + @Override + public DescribeUserScramCredentialsResponseData data() { + return data; + } + + @Override + public boolean shouldClientThrottle(short version) { + return true; + } + + @Override + public int throttleTimeMs() { + return data.throttleTimeMs(); + } + + @Override + public Map errorCounts() { + return errorCounts(data.results().stream().map(r -> Errors.forCode(r.errorCode()))); + } + + public static DescribeUserScramCredentialsResponse parse(ByteBuffer buffer, short version) { + return new DescribeUserScramCredentialsResponse(new DescribeUserScramCredentialsResponseData(new ByteBufferAccessor(buffer), version)); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/ElectLeadersRequest.java b/clients/src/main/java/org/apache/kafka/common/requests/ElectLeadersRequest.java new file mode 100644 index 0000000..febb030 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/ElectLeadersRequest.java @@ -0,0 +1,130 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
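DescribeUserScramCredentialsRequest.getErrorResponse() fills only the top-level error code and message of the response. A minimal sketch built from a default data object (no explicit user filter shown here); the exception message and the sketch class name are illustrative:

import org.apache.kafka.common.errors.ClusterAuthorizationException;
import org.apache.kafka.common.message.DescribeUserScramCredentialsRequestData;
import org.apache.kafka.common.protocol.ApiKeys;
import org.apache.kafka.common.requests.AbstractResponse;
import org.apache.kafka.common.requests.DescribeUserScramCredentialsRequest;
import org.apache.kafka.common.requests.DescribeUserScramCredentialsResponse;

public class DescribeUserScramCredentialsSketch {
    public static void main(String[] args) {
        DescribeUserScramCredentialsRequest request =
            new DescribeUserScramCredentialsRequest.Builder(
                new DescribeUserScramCredentialsRequestData())
                .build(ApiKeys.DESCRIBE_USER_SCRAM_CREDENTIALS.latestVersion());

        // The request-level error path sets only the top-level error code and message.
        AbstractResponse response = request.getErrorResponse(
            0, new ClusterAuthorizationException("not authorized"));
        DescribeUserScramCredentialsResponse scramResponse =
            (DescribeUserScramCredentialsResponse) response;
        System.out.println(scramResponse.data().errorMessage());
    }
}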
+ */ + +package org.apache.kafka.common.requests; + +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; +import org.apache.kafka.common.ElectionType; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.errors.UnsupportedVersionException; +import org.apache.kafka.common.message.ElectLeadersRequestData.TopicPartitions; +import org.apache.kafka.common.message.ElectLeadersRequestData; +import org.apache.kafka.common.message.ElectLeadersResponseData.PartitionResult; +import org.apache.kafka.common.message.ElectLeadersResponseData.ReplicaElectionResult; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.protocol.MessageUtil; + +public class ElectLeadersRequest extends AbstractRequest { + public static class Builder extends AbstractRequest.Builder { + private final ElectionType electionType; + private final Collection topicPartitions; + private final int timeoutMs; + + public Builder(ElectionType electionType, Collection topicPartitions, int timeoutMs) { + super(ApiKeys.ELECT_LEADERS); + this.electionType = electionType; + this.topicPartitions = topicPartitions; + this.timeoutMs = timeoutMs; + } + + @Override + public ElectLeadersRequest build(short version) { + return new ElectLeadersRequest(toRequestData(version), version); + } + + @Override + public String toString() { + return "ElectLeadersRequest(" + + "electionType=" + electionType + + ", topicPartitions=" + ((topicPartitions == null) ? "null" : MessageUtil.deepToString(topicPartitions.iterator())) + + ", timeoutMs=" + timeoutMs + + ")"; + } + + private ElectLeadersRequestData toRequestData(short version) { + if (electionType != ElectionType.PREFERRED && version == 0) { + throw new UnsupportedVersionException("API Version 0 only supports PREFERRED election type"); + } + + ElectLeadersRequestData data = new ElectLeadersRequestData() + .setTimeoutMs(timeoutMs); + + if (topicPartitions != null) { + topicPartitions.forEach(tp -> { + ElectLeadersRequestData.TopicPartitions tps = data.topicPartitions().find(tp.topic()); + if (tps == null) { + tps = new ElectLeadersRequestData.TopicPartitions().setTopic(tp.topic()); + data.topicPartitions().add(tps); + } + tps.partitions().add(tp.partition()); + }); + } else { + data.setTopicPartitions(null); + } + + data.setElectionType(electionType.value); + + return data; + } + } + + private final ElectLeadersRequestData data; + + private ElectLeadersRequest(ElectLeadersRequestData data, short version) { + super(ApiKeys.ELECT_LEADERS, version); + this.data = data; + } + + @Override + public ElectLeadersRequestData data() { + return data; + } + + @Override + public AbstractResponse getErrorResponse(int throttleTimeMs, Throwable e) { + ApiError apiError = ApiError.fromThrowable(e); + List electionResults = new ArrayList<>(); + + if (data.topicPartitions() != null) { + for (TopicPartitions topic : data.topicPartitions()) { + ReplicaElectionResult electionResult = new ReplicaElectionResult(); + + electionResult.setTopic(topic.topic()); + for (Integer partitionId : topic.partitions()) { + PartitionResult partitionResult = new PartitionResult(); + partitionResult.setPartitionId(partitionId); + partitionResult.setErrorCode(apiError.error().code()); + partitionResult.setErrorMessage(apiError.message()); + + electionResult.partitionResult().add(partitionResult); + } + + electionResults.add(electionResult); + } + } + + return new 
ElectLeadersResponse(throttleTimeMs, apiError.error().code(), electionResults, version()); + } + + public static ElectLeadersRequest parse(ByteBuffer buffer, short version) { + return new ElectLeadersRequest(new ElectLeadersRequestData(new ByteBufferAccessor(buffer), version), version); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/ElectLeadersResponse.java b/clients/src/main/java/org/apache/kafka/common/requests/ElectLeadersResponse.java new file mode 100644 index 0000000..88d4d19 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/ElectLeadersResponse.java @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.common.requests; + +import java.nio.ByteBuffer; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.message.ElectLeadersResponseData; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.message.ElectLeadersResponseData.ReplicaElectionResult; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.Errors; + +public class ElectLeadersResponse extends AbstractResponse { + + private final ElectLeadersResponseData data; + + public ElectLeadersResponse(ElectLeadersResponseData data) { + super(ApiKeys.ELECT_LEADERS); + this.data = data; + } + + public ElectLeadersResponse( + int throttleTimeMs, + short errorCode, + List electionResults, + short version) { + super(ApiKeys.ELECT_LEADERS); + this.data = new ElectLeadersResponseData(); + data.setThrottleTimeMs(throttleTimeMs); + if (version >= 1) + data.setErrorCode(errorCode); + data.setReplicaElectionResults(electionResults); + } + + @Override + public ElectLeadersResponseData data() { + return data; + } + + @Override + public int throttleTimeMs() { + return data.throttleTimeMs(); + } + + @Override + public Map errorCounts() { + HashMap counts = new HashMap<>(); + updateErrorCounts(counts, Errors.forCode(data.errorCode())); + data.replicaElectionResults().forEach(result -> + result.partitionResult().forEach(partitionResult -> + updateErrorCounts(counts, Errors.forCode(partitionResult.errorCode())) + ) + ); + return counts; + } + + public static ElectLeadersResponse parse(ByteBuffer buffer, short version) { + return new ElectLeadersResponse(new ElectLeadersResponseData(new ByteBufferAccessor(buffer), version)); + } + + @Override + public boolean shouldClientThrottle(short version) { + return true; + } + + public static Map> electLeadersResult(ElectLeadersResponseData data) { + Map> map = new HashMap<>(); + + for (ElectLeadersResponseData.ReplicaElectionResult topicResults : 
data.replicaElectionResults()) { + for (ElectLeadersResponseData.PartitionResult partitionResult : topicResults.partitionResult()) { + Optional value = Optional.empty(); + Errors error = Errors.forCode(partitionResult.errorCode()); + if (error != Errors.NONE) { + value = Optional.of(error.exception(partitionResult.errorMessage())); + } + + map.put(new TopicPartition(topicResults.topic(), partitionResult.partitionId()), + value); + } + } + + return map; + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/EndQuorumEpochRequest.java b/clients/src/main/java/org/apache/kafka/common/requests/EndQuorumEpochRequest.java new file mode 100644 index 0000000..136bc54 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/EndQuorumEpochRequest.java @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.message.EndQuorumEpochRequestData; +import org.apache.kafka.common.message.EndQuorumEpochResponseData; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.protocol.Errors; + +import java.nio.ByteBuffer; +import java.util.Collections; +import java.util.List; + +public class EndQuorumEpochRequest extends AbstractRequest { + public static class Builder extends AbstractRequest.Builder { + private final EndQuorumEpochRequestData data; + + public Builder(EndQuorumEpochRequestData data) { + super(ApiKeys.END_QUORUM_EPOCH); + this.data = data; + } + + @Override + public EndQuorumEpochRequest build(short version) { + return new EndQuorumEpochRequest(data, version); + } + + @Override + public String toString() { + return data.toString(); + } + } + + private final EndQuorumEpochRequestData data; + + private EndQuorumEpochRequest(EndQuorumEpochRequestData data, short version) { + super(ApiKeys.END_QUORUM_EPOCH, version); + this.data = data; + } + + @Override + public EndQuorumEpochRequestData data() { + return data; + } + + @Override + public EndQuorumEpochResponse getErrorResponse(int throttleTimeMs, Throwable e) { + return new EndQuorumEpochResponse(new EndQuorumEpochResponseData() + .setErrorCode(Errors.forException(e).code())); + } + + public static EndQuorumEpochRequest parse(ByteBuffer buffer, short version) { + return new EndQuorumEpochRequest(new EndQuorumEpochRequestData(new ByteBufferAccessor(buffer), version), version); + } + + public static EndQuorumEpochRequestData singletonRequest(TopicPartition topicPartition, + int leaderEpoch, + int leaderId, + List preferredSuccessors) { + return singletonRequest(topicPartition, null, leaderEpoch, leaderId, preferredSuccessors); + } + + public static 
EndQuorumEpochRequestData singletonRequest(TopicPartition topicPartition, + String clusterId, + int leaderEpoch, + int leaderId, + List preferredSuccessors) { + return new EndQuorumEpochRequestData() + .setClusterId(clusterId) + .setTopics(Collections.singletonList( + new EndQuorumEpochRequestData.TopicData() + .setTopicName(topicPartition.topic()) + .setPartitions(Collections.singletonList( + new EndQuorumEpochRequestData.PartitionData() + .setPartitionIndex(topicPartition.partition()) + .setLeaderEpoch(leaderEpoch) + .setLeaderId(leaderId) + .setPreferredSuccessors(preferredSuccessors)))) + ); + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/EndQuorumEpochResponse.java b/clients/src/main/java/org/apache/kafka/common/requests/EndQuorumEpochResponse.java new file mode 100644 index 0000000..ac2c0c5 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/EndQuorumEpochResponse.java @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.message.EndQuorumEpochResponseData; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.protocol.Errors; + +import java.nio.ByteBuffer; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +/** + * Possible error codes. 
+ * + * Top level errors: + * - {@link Errors#CLUSTER_AUTHORIZATION_FAILED} + * - {@link Errors#BROKER_NOT_AVAILABLE} + * + * Partition level errors: + * - {@link Errors#FENCED_LEADER_EPOCH} + * - {@link Errors#INVALID_REQUEST} + * - {@link Errors#INCONSISTENT_VOTER_SET} + * - {@link Errors#UNKNOWN_TOPIC_OR_PARTITION} + */ +public class EndQuorumEpochResponse extends AbstractResponse { + private final EndQuorumEpochResponseData data; + + public EndQuorumEpochResponse(EndQuorumEpochResponseData data) { + super(ApiKeys.END_QUORUM_EPOCH); + this.data = data; + } + + @Override + public Map errorCounts() { + Map errors = new HashMap<>(); + + errors.put(Errors.forCode(data.errorCode()), 1); + + for (EndQuorumEpochResponseData.TopicData topicResponse : data.topics()) { + for (EndQuorumEpochResponseData.PartitionData partitionResponse : topicResponse.partitions()) { + updateErrorCounts(errors, Errors.forCode(partitionResponse.errorCode())); + } + } + return errors; + } + + @Override + public EndQuorumEpochResponseData data() { + return data; + } + + @Override + public int throttleTimeMs() { + return DEFAULT_THROTTLE_TIME; + } + + public static EndQuorumEpochResponseData singletonResponse( + Errors topLevelError, + TopicPartition topicPartition, + Errors partitionLevelError, + int leaderEpoch, + int leaderId + ) { + return new EndQuorumEpochResponseData() + .setErrorCode(topLevelError.code()) + .setTopics(Collections.singletonList( + new EndQuorumEpochResponseData.TopicData() + .setTopicName(topicPartition.topic()) + .setPartitions(Collections.singletonList( + new EndQuorumEpochResponseData.PartitionData() + .setErrorCode(partitionLevelError.code()) + .setLeaderId(leaderId) + .setLeaderEpoch(leaderEpoch) + ))) + ); + } + + public static EndQuorumEpochResponse parse(ByteBuffer buffer, short version) { + return new EndQuorumEpochResponse(new EndQuorumEpochResponseData(new ByteBufferAccessor(buffer), version)); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/EndTxnRequest.java b/clients/src/main/java/org/apache/kafka/common/requests/EndTxnRequest.java new file mode 100644 index 0000000..c9ea980 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/EndTxnRequest.java @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
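[Editor's aside, not part of the patch: a minimal sketch of how a resigning leader might assemble a single-partition EndQuorumEpoch request with the `singletonRequest(...)` helper added above. The topic, cluster id, epoch, broker ids and protocol version below are invented for illustration.]

```java
// Hedged sketch: build an EndQuorumEpoch request for one partition.
import java.util.Arrays;
import org.apache.kafka.common.TopicPartition;
import org.apache.kafka.common.message.EndQuorumEpochRequestData;
import org.apache.kafka.common.requests.EndQuorumEpochRequest;

public class EndQuorumEpochRequestExample {
    public static void main(String[] args) {
        EndQuorumEpochRequestData data = EndQuorumEpochRequest.singletonRequest(
            new TopicPartition("__cluster_metadata", 0),
            "example-cluster-id",   // clusterId; the shorter overload passes null
            5,                      // leaderEpoch being ended
            1,                      // leaderId of the resigning leader
            Arrays.asList(2, 3));   // preferredSuccessors, in order of preference

        // The request constructor is private, so callers go through the Builder.
        EndQuorumEpochRequest request = new EndQuorumEpochRequest.Builder(data).build((short) 0);
        System.out.println(request.data());
    }
}
```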
+ */ +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.message.EndTxnRequestData; +import org.apache.kafka.common.message.EndTxnResponseData; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.protocol.Errors; + +import java.nio.ByteBuffer; + +public class EndTxnRequest extends AbstractRequest { + + private final EndTxnRequestData data; + + public static class Builder extends AbstractRequest.Builder { + public final EndTxnRequestData data; + + public Builder(EndTxnRequestData data) { + super(ApiKeys.END_TXN); + this.data = data; + } + + @Override + public EndTxnRequest build(short version) { + return new EndTxnRequest(data, version); + } + + @Override + public String toString() { + return data.toString(); + } + } + + private EndTxnRequest(EndTxnRequestData data, short version) { + super(ApiKeys.END_TXN, version); + this.data = data; + } + + public TransactionResult result() { + if (data.committed()) + return TransactionResult.COMMIT; + else + return TransactionResult.ABORT; + } + + @Override + public EndTxnRequestData data() { + return data; + } + + @Override + public EndTxnResponse getErrorResponse(int throttleTimeMs, Throwable e) { + return new EndTxnResponse(new EndTxnResponseData() + .setErrorCode(Errors.forException(e).code()) + .setThrottleTimeMs(throttleTimeMs) + ); + } + + public static EndTxnRequest parse(ByteBuffer buffer, short version) { + return new EndTxnRequest(new EndTxnRequestData(new ByteBufferAccessor(buffer), version), version); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/EndTxnResponse.java b/clients/src/main/java/org/apache/kafka/common/requests/EndTxnResponse.java new file mode 100644 index 0000000..029e7d0 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/EndTxnResponse.java @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
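[Editor's aside, not part of the patch: a hedged sketch of building an EndTxn request and reading the commit/abort flag back through `result()`. The `EndTxnRequestData` setters follow the EndTxn schema (transactionalId, producerId, producerEpoch, committed) and are assumptions here; the ids and version are invented.]

```java
// Hedged sketch: end a transaction with a COMMIT marker.
import org.apache.kafka.common.message.EndTxnRequestData;
import org.apache.kafka.common.requests.EndTxnRequest;
import org.apache.kafka.common.requests.TransactionResult;

public class EndTxnRequestExample {
    public static void main(String[] args) {
        EndTxnRequest request = new EndTxnRequest.Builder(
            new EndTxnRequestData()
                .setTransactionalId("example-txn") // assumed setter names, per the EndTxn schema
                .setProducerId(1234L)
                .setProducerEpoch((short) 0)
                .setCommitted(true))               // true -> COMMIT, false -> ABORT
            .build((short) 2);

        TransactionResult result = request.result(); // TransactionResult.COMMIT
        System.out.println(result);
    }
}
```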
+ */ +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.message.EndTxnResponseData; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.protocol.Errors; + +import java.nio.ByteBuffer; +import java.util.Map; + +/** + * Possible error codes: + * + * - {@link Errors#NOT_COORDINATOR} + * - {@link Errors#COORDINATOR_NOT_AVAILABLE} + * - {@link Errors#COORDINATOR_LOAD_IN_PROGRESS} + * - {@link Errors#INVALID_TXN_STATE} + * - {@link Errors#INVALID_PRODUCER_ID_MAPPING} + * - {@link Errors#INVALID_PRODUCER_EPOCH} // for version <=1 + * - {@link Errors#PRODUCER_FENCED} + * - {@link Errors#TRANSACTIONAL_ID_AUTHORIZATION_FAILED} + */ +public class EndTxnResponse extends AbstractResponse { + + private final EndTxnResponseData data; + + public EndTxnResponse(EndTxnResponseData data) { + super(ApiKeys.END_TXN); + this.data = data; + } + + @Override + public int throttleTimeMs() { + return data.throttleTimeMs(); + } + + + public Errors error() { + return Errors.forCode(data.errorCode()); + } + + @Override + public Map errorCounts() { + return errorCounts(error()); + } + + @Override + public EndTxnResponseData data() { + return data; + } + + public static EndTxnResponse parse(ByteBuffer buffer, short version) { + return new EndTxnResponse(new EndTxnResponseData(new ByteBufferAccessor(buffer), version)); + } + + @Override + public String toString() { + return data.toString(); + } + + @Override + public boolean shouldClientThrottle(short version) { + return version >= 1; + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/EnvelopeRequest.java b/clients/src/main/java/org/apache/kafka/common/requests/EnvelopeRequest.java new file mode 100644 index 0000000..5e8d3fa --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/EnvelopeRequest.java @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.message.EnvelopeRequestData; +import org.apache.kafka.common.message.EnvelopeResponseData; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.Errors; + +import java.nio.ByteBuffer; + +public class EnvelopeRequest extends AbstractRequest { + + public static class Builder extends AbstractRequest.Builder { + + private final EnvelopeRequestData data; + + public Builder(ByteBuffer requestData, + byte[] serializedPrincipal, + byte[] clientAddress) { + super(ApiKeys.ENVELOPE); + this.data = new EnvelopeRequestData() + .setRequestData(requestData) + .setRequestPrincipal(serializedPrincipal) + .setClientHostAddress(clientAddress); + } + + @Override + public EnvelopeRequest build(short version) { + return new EnvelopeRequest(data, version); + } + + @Override + public String toString() { + return data.toString(); + } + } + + private final EnvelopeRequestData data; + + public EnvelopeRequest(EnvelopeRequestData data, short version) { + super(ApiKeys.ENVELOPE, version); + this.data = data; + } + + public ByteBuffer requestData() { + return data.requestData(); + } + + public byte[] clientAddress() { + return data.clientHostAddress(); + } + + public byte[] requestPrincipal() { + return data.requestPrincipal(); + } + + @Override + public AbstractResponse getErrorResponse(int throttleTimeMs, Throwable e) { + return new EnvelopeResponse(new EnvelopeResponseData() + .setErrorCode(Errors.forException(e).code())); + } + + public static EnvelopeRequest parse(ByteBuffer buffer, short version) { + return new EnvelopeRequest(new EnvelopeRequestData(new ByteBufferAccessor(buffer), version), version); + } + + @Override + public EnvelopeRequestData data() { + return data; + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/EnvelopeResponse.java b/clients/src/main/java/org/apache/kafka/common/requests/EnvelopeResponse.java new file mode 100644 index 0000000..529f616 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/EnvelopeResponse.java @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
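[Editor's aside, not part of the patch: a sketch of the `EnvelopeRequest.Builder` above, which wraps an already-serialized request together with the forwarding principal and client address. The byte payloads are placeholders, not a real serialization format.]

```java
// Hedged sketch: wrap a serialized inner request in an Envelope request.
import java.net.InetAddress;
import java.nio.ByteBuffer;
import org.apache.kafka.common.requests.EnvelopeRequest;

public class EnvelopeRequestExample {
    public static void main(String[] args) {
        ByteBuffer innerRequest = ByteBuffer.wrap(new byte[0]);  // serialized inner request (placeholder)
        byte[] principal = new byte[0];                          // serialized principal (placeholder)
        byte[] clientAddress = InetAddress.getLoopbackAddress().getAddress();

        EnvelopeRequest envelope = new EnvelopeRequest.Builder(innerRequest, principal, clientAddress)
            .build((short) 0);
        System.out.println(envelope.requestData().remaining() + " bytes forwarded");
    }
}
```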
+ */ +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.message.EnvelopeResponseData; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.protocol.Errors; + +import java.nio.ByteBuffer; +import java.util.Map; + +public class EnvelopeResponse extends AbstractResponse { + + private final EnvelopeResponseData data; + + public EnvelopeResponse(ByteBuffer responseData, Errors error) { + super(ApiKeys.ENVELOPE); + this.data = new EnvelopeResponseData() + .setResponseData(responseData) + .setErrorCode(error.code()); + } + + public EnvelopeResponse(Errors error) { + this(null, error); + } + + public EnvelopeResponse(EnvelopeResponseData data) { + super(ApiKeys.ENVELOPE); + this.data = data; + } + + public ByteBuffer responseData() { + return data.responseData(); + } + + @Override + public Map errorCounts() { + return errorCounts(error()); + } + + public Errors error() { + return Errors.forCode(data.errorCode()); + } + + @Override + public EnvelopeResponseData data() { + return data; + } + + @Override + public int throttleTimeMs() { + return DEFAULT_THROTTLE_TIME; + } + + public static EnvelopeResponse parse(ByteBuffer buffer, short version) { + return new EnvelopeResponse(new EnvelopeResponseData(new ByteBufferAccessor(buffer), version)); + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/ExpireDelegationTokenRequest.java b/clients/src/main/java/org/apache/kafka/common/requests/ExpireDelegationTokenRequest.java new file mode 100644 index 0000000..85b0238 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/ExpireDelegationTokenRequest.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.requests; + +import java.nio.ByteBuffer; + +import org.apache.kafka.common.message.ExpireDelegationTokenRequestData; +import org.apache.kafka.common.message.ExpireDelegationTokenResponseData; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.protocol.Errors; + +public class ExpireDelegationTokenRequest extends AbstractRequest { + + private final ExpireDelegationTokenRequestData data; + + private ExpireDelegationTokenRequest(ExpireDelegationTokenRequestData data, short version) { + super(ApiKeys.EXPIRE_DELEGATION_TOKEN, version); + this.data = data; + } + + public static ExpireDelegationTokenRequest parse(ByteBuffer buffer, short version) { + return new ExpireDelegationTokenRequest( + new ExpireDelegationTokenRequestData(new ByteBufferAccessor(buffer), version), version); + } + + @Override + public ExpireDelegationTokenRequestData data() { + return data; + } + + @Override + public AbstractResponse getErrorResponse(int throttleTimeMs, Throwable e) { + return new ExpireDelegationTokenResponse( + new ExpireDelegationTokenResponseData() + .setErrorCode(Errors.forException(e).code()) + .setThrottleTimeMs(throttleTimeMs)); + } + + public ByteBuffer hmac() { + return ByteBuffer.wrap(data.hmac()); + } + + public long expiryTimePeriod() { + return data.expiryTimePeriodMs(); + } + + public static class Builder extends AbstractRequest.Builder { + private final ExpireDelegationTokenRequestData data; + + public Builder(ExpireDelegationTokenRequestData data) { + super(ApiKeys.EXPIRE_DELEGATION_TOKEN); + this.data = data; + } + + @Override + public ExpireDelegationTokenRequest build(short version) { + return new ExpireDelegationTokenRequest(data, version); + } + + @Override + public String toString() { + return data.toString(); + } + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/ExpireDelegationTokenResponse.java b/clients/src/main/java/org/apache/kafka/common/requests/ExpireDelegationTokenResponse.java new file mode 100644 index 0000000..163ee78 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/ExpireDelegationTokenResponse.java @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
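[Editor's aside, not part of the patch: a sketch of asking the broker to expire a delegation token via the request class above. The `ExpireDelegationTokenRequestData` setters follow the schema fields echoed by `hmac()` and `expiryTimePeriod()` and are assumptions here; the HMAC bytes are a placeholder.]

```java
// Hedged sketch: request immediate expiration of a delegation token.
import org.apache.kafka.common.message.ExpireDelegationTokenRequestData;
import org.apache.kafka.common.requests.ExpireDelegationTokenRequest;

public class ExpireDelegationTokenExample {
    public static void main(String[] args) {
        ExpireDelegationTokenRequest request = new ExpireDelegationTokenRequest.Builder(
            new ExpireDelegationTokenRequestData()
                .setHmac(new byte[] {0x01, 0x02})  // HMAC of the token to expire (placeholder)
                .setExpiryTimePeriodMs(-1L))       // a negative period asks for immediate expiration
            .build((short) 1);

        System.out.println(request.expiryTimePeriod()); // -1
    }
}
```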
+ */ +package org.apache.kafka.common.requests; + +import java.nio.ByteBuffer; +import java.util.Map; + +import org.apache.kafka.common.message.ExpireDelegationTokenResponseData; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.protocol.Errors; + +public class ExpireDelegationTokenResponse extends AbstractResponse { + + private final ExpireDelegationTokenResponseData data; + + public ExpireDelegationTokenResponse(ExpireDelegationTokenResponseData data) { + super(ApiKeys.EXPIRE_DELEGATION_TOKEN); + this.data = data; + } + + public static ExpireDelegationTokenResponse parse(ByteBuffer buffer, short version) { + return new ExpireDelegationTokenResponse(new ExpireDelegationTokenResponseData(new ByteBufferAccessor(buffer), + version)); + } + + public Errors error() { + return Errors.forCode(data.errorCode()); + } + + public long expiryTimestamp() { + return data.expiryTimestampMs(); + } + + @Override + public Map errorCounts() { + return errorCounts(error()); + } + + @Override + public ExpireDelegationTokenResponseData data() { + return data; + } + + @Override + public int throttleTimeMs() { + return data.throttleTimeMs(); + } + + public boolean hasError() { + return error() != Errors.NONE; + } + + @Override + public boolean shouldClientThrottle(short version) { + return version >= 1; + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/FetchMetadata.java b/clients/src/main/java/org/apache/kafka/common/requests/FetchMetadata.java new file mode 100644 index 0000000..feb6953 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/FetchMetadata.java @@ -0,0 +1,154 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.common.requests; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Objects; + +public class FetchMetadata { + public static final Logger log = LoggerFactory.getLogger(FetchMetadata.class); + + /** + * The session ID used by clients with no session. + */ + public static final int INVALID_SESSION_ID = 0; + + /** + * The first epoch. When used in a fetch request, indicates that the client + * wants to create or recreate a session. + */ + public static final int INITIAL_EPOCH = 0; + + /** + * An invalid epoch. When used in a fetch request, indicates that the client + * wants to close any existing session, and not create a new one. + */ + public static final int FINAL_EPOCH = -1; + + /** + * The FetchMetadata that is used when initializing a new FetchSessionHandler. 
+ */ + public static final FetchMetadata INITIAL = new FetchMetadata(INVALID_SESSION_ID, INITIAL_EPOCH); + + /** + * The FetchMetadata that is implicitly used for handling older FetchRequests that + * don't include fetch metadata. + */ + public static final FetchMetadata LEGACY = new FetchMetadata(INVALID_SESSION_ID, FINAL_EPOCH); + + /** + * Returns the next epoch. + * + * @param prevEpoch The previous epoch. + * @return The next epoch. + */ + public static int nextEpoch(int prevEpoch) { + if (prevEpoch < 0) { + // The next epoch after FINAL_EPOCH is always FINAL_EPOCH itself. + return FINAL_EPOCH; + } else if (prevEpoch == Integer.MAX_VALUE) { + return 1; + } else { + return prevEpoch + 1; + } + } + + /** + * The fetch session ID. + */ + private final int sessionId; + + /** + * The fetch session epoch. + */ + private final int epoch; + + public FetchMetadata(int sessionId, int epoch) { + this.sessionId = sessionId; + this.epoch = epoch; + } + + /** + * Returns true if this is a full fetch request. + */ + public boolean isFull() { + return (this.epoch == INITIAL_EPOCH) || (this.epoch == FINAL_EPOCH); + } + + public int sessionId() { + return sessionId; + } + + public int epoch() { + return epoch; + } + + @Override + public int hashCode() { + return Objects.hash(sessionId, epoch); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + FetchMetadata that = (FetchMetadata) o; + return sessionId == that.sessionId && epoch == that.epoch; + } + + /** + * Return the metadata for the next error response. + */ + public FetchMetadata nextCloseExisting() { + return new FetchMetadata(sessionId, INITIAL_EPOCH); + } + + /** + * Return the metadata for the next full fetch request. + */ + public static FetchMetadata newIncremental(int sessionId) { + return new FetchMetadata(sessionId, nextEpoch(INITIAL_EPOCH)); + } + + /** + * Return the metadata for the next incremental response. + */ + public FetchMetadata nextIncremental() { + return new FetchMetadata(sessionId, nextEpoch(epoch)); + } + + @Override + public String toString() { + StringBuilder bld = new StringBuilder(); + if (sessionId == INVALID_SESSION_ID) { + bld.append("(sessionId=INVALID, "); + } else { + bld.append("(sessionId=").append(sessionId).append(", "); + } + if (epoch == INITIAL_EPOCH) { + bld.append("epoch=INITIAL)"); + } else if (epoch == FINAL_EPOCH) { + bld.append("epoch=FINAL)"); + } else { + bld.append("epoch=").append(epoch).append(")"); + } + return bld.toString(); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/FetchRequest.java b/clients/src/main/java/org/apache/kafka/common/requests/FetchRequest.java new file mode 100644 index 0000000..48ba022 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/FetchRequest.java @@ -0,0 +1,430 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
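[Editor's aside, not part of the patch: a small sketch of the session/epoch progression that the `FetchMetadata` javadoc above describes. The session id 42 is invented.]

```java
// Hedged sketch: fetch-session metadata progression.
import org.apache.kafka.common.requests.FetchMetadata;

public class FetchMetadataExample {
    public static void main(String[] args) {
        FetchMetadata first = FetchMetadata.INITIAL;                  // sessionId=0 (INVALID), epoch=0 -> full fetch
        FetchMetadata incremental = FetchMetadata.newIncremental(42); // sessionId=42, epoch=1
        FetchMetadata next = incremental.nextIncremental();           // epoch=2

        System.out.println(first.isFull()); // true
        System.out.println(next.isFull());  // false

        // nextEpoch skips back to 1 after Integer.MAX_VALUE and never leaves FINAL_EPOCH:
        System.out.println(FetchMetadata.nextEpoch(Integer.MAX_VALUE));         // 1
        System.out.println(FetchMetadata.nextEpoch(FetchMetadata.FINAL_EPOCH)); // -1
    }
}
```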
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.IsolationLevel; +import org.apache.kafka.common.TopicIdPartition; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.Uuid; +import org.apache.kafka.common.message.FetchRequestData; +import org.apache.kafka.common.message.FetchRequestData.ForgottenTopic; +import org.apache.kafka.common.message.FetchResponseData; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.protocol.Errors; +import org.apache.kafka.common.record.RecordBatch; +import org.apache.kafka.common.utils.Utils; + +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Collections; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.stream.Collectors; + +public class FetchRequest extends AbstractRequest { + + public static final int CONSUMER_REPLICA_ID = -1; + + // default values for older versions where a request level limit did not exist + public static final int DEFAULT_RESPONSE_MAX_BYTES = Integer.MAX_VALUE; + public static final long INVALID_LOG_START_OFFSET = -1L; + + private final FetchRequestData data; + private volatile LinkedHashMap fetchData = null; + private volatile List toForget = null; + + // This is an immutable read-only structures derived from FetchRequestData + private final FetchMetadata metadata; + + public static final class PartitionData { + public final Uuid topicId; + public final long fetchOffset; + public final long logStartOffset; + public final int maxBytes; + public final Optional currentLeaderEpoch; + public final Optional lastFetchedEpoch; + + public PartitionData( + Uuid topicId, + long fetchOffset, + long logStartOffset, + int maxBytes, + Optional currentLeaderEpoch + ) { + this(topicId, fetchOffset, logStartOffset, maxBytes, currentLeaderEpoch, Optional.empty()); + } + + public PartitionData( + Uuid topicId, + long fetchOffset, + long logStartOffset, + int maxBytes, + Optional currentLeaderEpoch, + Optional lastFetchedEpoch + ) { + this.topicId = topicId; + this.fetchOffset = fetchOffset; + this.logStartOffset = logStartOffset; + this.maxBytes = maxBytes; + this.currentLeaderEpoch = currentLeaderEpoch; + this.lastFetchedEpoch = lastFetchedEpoch; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + PartitionData that = (PartitionData) o; + return Objects.equals(topicId, that.topicId) && + fetchOffset == that.fetchOffset && + logStartOffset == that.logStartOffset && + maxBytes == that.maxBytes && + Objects.equals(currentLeaderEpoch, that.currentLeaderEpoch) && + Objects.equals(lastFetchedEpoch, that.lastFetchedEpoch); + } + + @Override + public int hashCode() { + return Objects.hash(topicId, fetchOffset, logStartOffset, maxBytes, currentLeaderEpoch, lastFetchedEpoch); + } + + @Override + public String toString() { + return "PartitionData(" + + "topicId=" + topicId + + ", 
fetchOffset=" + fetchOffset + + ", logStartOffset=" + logStartOffset + + ", maxBytes=" + maxBytes + + ", currentLeaderEpoch=" + currentLeaderEpoch + + ", lastFetchedEpoch=" + lastFetchedEpoch + + ')'; + } + } + + private static Optional optionalEpoch(int rawEpochValue) { + if (rawEpochValue < 0) { + return Optional.empty(); + } else { + return Optional.of(rawEpochValue); + } + } + + public static class Builder extends AbstractRequest.Builder { + private final int maxWait; + private final int minBytes; + private final int replicaId; + private final Map toFetch; + private IsolationLevel isolationLevel = IsolationLevel.READ_UNCOMMITTED; + private int maxBytes = DEFAULT_RESPONSE_MAX_BYTES; + private FetchMetadata metadata = FetchMetadata.LEGACY; + private List removed = Collections.emptyList(); + private List replaced = Collections.emptyList(); + private String rackId = ""; + + public static Builder forConsumer(short maxVersion, int maxWait, int minBytes, Map fetchData) { + return new Builder(ApiKeys.FETCH.oldestVersion(), maxVersion, + CONSUMER_REPLICA_ID, maxWait, minBytes, fetchData); + } + + public static Builder forReplica(short allowedVersion, int replicaId, int maxWait, int minBytes, + Map fetchData) { + return new Builder(allowedVersion, allowedVersion, replicaId, maxWait, minBytes, fetchData); + } + + public Builder(short minVersion, short maxVersion, int replicaId, int maxWait, int minBytes, + Map fetchData) { + super(ApiKeys.FETCH, minVersion, maxVersion); + this.replicaId = replicaId; + this.maxWait = maxWait; + this.minBytes = minBytes; + this.toFetch = fetchData; + } + + public Builder isolationLevel(IsolationLevel isolationLevel) { + this.isolationLevel = isolationLevel; + return this; + } + + public Builder metadata(FetchMetadata metadata) { + this.metadata = metadata; + return this; + } + + public Builder rackId(String rackId) { + this.rackId = rackId; + return this; + } + + public Map fetchData() { + return this.toFetch; + } + + public Builder setMaxBytes(int maxBytes) { + this.maxBytes = maxBytes; + return this; + } + + public List removed() { + return removed; + } + + public Builder removed(List removed) { + this.removed = removed; + return this; + } + + public List replaced() { + return replaced; + } + + public Builder replaced(List replaced) { + this.replaced = replaced; + return this; + } + + private void addToForgottenTopicMap(List toForget, Map forgottenTopicMap) { + toForget.forEach(topicIdPartition -> { + FetchRequestData.ForgottenTopic forgottenTopic = forgottenTopicMap.get(topicIdPartition.topic()); + if (forgottenTopic == null) { + forgottenTopic = new ForgottenTopic() + .setTopic(topicIdPartition.topic()) + .setTopicId(topicIdPartition.topicId()); + forgottenTopicMap.put(topicIdPartition.topic(), forgottenTopic); + } + forgottenTopic.partitions().add(topicIdPartition.partition()); + }); + } + + @Override + public FetchRequest build(short version) { + if (version < 3) { + maxBytes = DEFAULT_RESPONSE_MAX_BYTES; + } + + FetchRequestData fetchRequestData = new FetchRequestData(); + fetchRequestData.setReplicaId(replicaId); + fetchRequestData.setMaxWaitMs(maxWait); + fetchRequestData.setMinBytes(minBytes); + fetchRequestData.setMaxBytes(maxBytes); + fetchRequestData.setIsolationLevel(isolationLevel.id()); + fetchRequestData.setForgottenTopicsData(new ArrayList<>()); + + Map forgottenTopicMap = new LinkedHashMap<>(); + addToForgottenTopicMap(removed, forgottenTopicMap); + + // If a version older than v13 is used, topic-partition which were replaced + // by a 
topic-partition with the same name but a different topic ID are not + // sent out in the "forget" set in order to not remove the newly added + // partition in the "fetch" set. + if (version >= 13) { + addToForgottenTopicMap(replaced, forgottenTopicMap); + } + + forgottenTopicMap.forEach((topic, forgottenTopic) -> fetchRequestData.forgottenTopicsData().add(forgottenTopic)); + + // We collect the partitions in a single FetchTopic only if they appear sequentially in the fetchData + fetchRequestData.setTopics(new ArrayList<>()); + FetchRequestData.FetchTopic fetchTopic = null; + for (Map.Entry entry : toFetch.entrySet()) { + TopicPartition topicPartition = entry.getKey(); + PartitionData partitionData = entry.getValue(); + + if (fetchTopic == null || !topicPartition.topic().equals(fetchTopic.topic())) { + fetchTopic = new FetchRequestData.FetchTopic() + .setTopic(topicPartition.topic()) + .setTopicId(partitionData.topicId) + .setPartitions(new ArrayList<>()); + fetchRequestData.topics().add(fetchTopic); + } + + FetchRequestData.FetchPartition fetchPartition = new FetchRequestData.FetchPartition() + .setPartition(topicPartition.partition()) + .setCurrentLeaderEpoch(partitionData.currentLeaderEpoch.orElse(RecordBatch.NO_PARTITION_LEADER_EPOCH)) + .setLastFetchedEpoch(partitionData.lastFetchedEpoch.orElse(RecordBatch.NO_PARTITION_LEADER_EPOCH)) + .setFetchOffset(partitionData.fetchOffset) + .setLogStartOffset(partitionData.logStartOffset) + .setPartitionMaxBytes(partitionData.maxBytes); + + fetchTopic.partitions().add(fetchPartition); + } + + if (metadata != null) { + fetchRequestData.setSessionEpoch(metadata.epoch()); + fetchRequestData.setSessionId(metadata.sessionId()); + } + fetchRequestData.setRackId(rackId); + + return new FetchRequest(fetchRequestData, version); + } + + @Override + public String toString() { + StringBuilder bld = new StringBuilder(); + bld.append("(type=FetchRequest"). + append(", replicaId=").append(replicaId). + append(", maxWait=").append(maxWait). + append(", minBytes=").append(minBytes). + append(", maxBytes=").append(maxBytes). + append(", fetchData=").append(toFetch). + append(", isolationLevel=").append(isolationLevel). + append(", removed=").append(Utils.join(removed, ", ")). + append(", replaced=").append(Utils.join(replaced, ", ")). + append(", metadata=").append(metadata). + append(", rackId=").append(rackId). + append(")"); + return bld.toString(); + } + } + + public FetchRequest(FetchRequestData fetchRequestData, short version) { + super(ApiKeys.FETCH, version); + this.data = fetchRequestData; + this.metadata = new FetchMetadata(fetchRequestData.sessionId(), fetchRequestData.sessionEpoch()); + } + + @Override + public AbstractResponse getErrorResponse(int throttleTimeMs, Throwable e) { + // For versions 13+ the error is indicated by setting the top-level error code, and no partitions will be returned. + // For earlier versions, the error is indicated in two ways: by setting the same error code in all partitions, + // and by setting the top-level error code. The form where we set the same error code in all partitions + // is needed in order to maintain backwards compatibility with older versions of the protocol + // in which there was no top-level error code. Note that for incremental fetch responses, there + // may not be any partitions at all in the response. For this reason, the top-level error code + // is essential for them. 
+ Errors error = Errors.forException(e); + List topicResponseList = new ArrayList<>(); + // For version 13+, we know the client can handle a top level error code, so we don't need to send back partitions too. + if (version() < 13) { + data.topics().forEach(topic -> { + List partitionResponses = topic.partitions().stream().map(partition -> + FetchResponse.partitionResponse(partition.partition(), error)).collect(Collectors.toList()); + topicResponseList.add(new FetchResponseData.FetchableTopicResponse() + .setTopic(topic.topic()) + .setTopicId(topic.topicId()) + .setPartitions(partitionResponses)); + }); + } + return new FetchResponse(new FetchResponseData() + .setThrottleTimeMs(throttleTimeMs) + .setErrorCode(error.code()) + .setSessionId(data.sessionId()) + .setResponses(topicResponseList)); + } + + public int replicaId() { + return data.replicaId(); + } + + public int maxWait() { + return data.maxWaitMs(); + } + + public int minBytes() { + return data.minBytes(); + } + + public int maxBytes() { + return data.maxBytes(); + } + + // For versions < 13, builds the partitionData map using only the FetchRequestData. + // For versions 13+, builds the partitionData map using both the FetchRequestData and a mapping of topic IDs to names. + public Map fetchData(Map topicNames) { + if (fetchData == null) { + synchronized (this) { + if (fetchData == null) { + fetchData = new LinkedHashMap<>(); + short version = version(); + data.topics().forEach(fetchTopic -> { + String name; + if (version < 13) { + name = fetchTopic.topic(); // can't be null + } else { + name = topicNames.get(fetchTopic.topicId()); + } + fetchTopic.partitions().forEach(fetchPartition -> + // Topic name may be null here if the topic name was unable to be resolved using the topicNames map. + fetchData.put(new TopicIdPartition(fetchTopic.topicId(), new TopicPartition(name, fetchPartition.partition())), + new PartitionData( + fetchTopic.topicId(), + fetchPartition.fetchOffset(), + fetchPartition.logStartOffset(), + fetchPartition.partitionMaxBytes(), + optionalEpoch(fetchPartition.currentLeaderEpoch()), + optionalEpoch(fetchPartition.lastFetchedEpoch()) + ) + ) + ); + }); + } + } + } + return fetchData; + } + + // For versions < 13, builds the forgotten topics list using only the FetchRequestData. + // For versions 13+, builds the forgotten topics list using both the FetchRequestData and a mapping of topic IDs to names. + public List forgottenTopics(Map topicNames) { + if (toForget == null) { + synchronized (this) { + if (toForget == null) { + toForget = new ArrayList<>(); + data.forgottenTopicsData().forEach(forgottenTopic -> { + String name; + if (version() < 13) { + name = forgottenTopic.topic(); // can't be null + } else { + name = topicNames.get(forgottenTopic.topicId()); + } + // Topic name may be null here if the topic name was unable to be resolved using the topicNames map. 
+ forgottenTopic.partitions().forEach(partitionId -> toForget.add(new TopicIdPartition(forgottenTopic.topicId(), new TopicPartition(name, partitionId)))); + }); + } + } + } + return toForget; + } + + public boolean isFromFollower() { + return replicaId() >= 0; + } + + public IsolationLevel isolationLevel() { + return IsolationLevel.forId(data.isolationLevel()); + } + + public FetchMetadata metadata() { + return metadata; + } + + public String rackId() { + return data.rackId(); + } + + public static FetchRequest parse(ByteBuffer buffer, short version) { + return new FetchRequest(new FetchRequestData(new ByteBufferAccessor(buffer), version), version); + } + + @Override + public FetchRequestData data() { + return data; + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/FetchResponse.java b/clients/src/main/java/org/apache/kafka/common/requests/FetchResponse.java new file mode 100644 index 0000000..2e0a02e --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/FetchResponse.java @@ -0,0 +1,277 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
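[Editor's aside, not part of the patch: a consumer-style fetch assembled with the `FetchRequest.Builder` above. Topic name, topic id, offsets and the protocol version are invented for illustration.]

```java
// Hedged sketch: build a fetch request for one partition of a topic.
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Optional;
import org.apache.kafka.common.IsolationLevel;
import org.apache.kafka.common.TopicPartition;
import org.apache.kafka.common.Uuid;
import org.apache.kafka.common.requests.FetchMetadata;
import org.apache.kafka.common.requests.FetchRequest;

public class FetchRequestExample {
    public static void main(String[] args) {
        Uuid topicId = Uuid.randomUuid();
        Map<TopicPartition, FetchRequest.PartitionData> toFetch = new LinkedHashMap<>();
        toFetch.put(new TopicPartition("events", 0),
            new FetchRequest.PartitionData(topicId, 100L /* fetchOffset */, 0L /* logStartOffset */,
                1024 * 1024 /* maxBytes */, Optional.empty() /* currentLeaderEpoch */));

        FetchRequest request = FetchRequest.Builder
            .forConsumer((short) 13, 500 /* maxWaitMs */, 1 /* minBytes */, toFetch)
            .isolationLevel(IsolationLevel.READ_COMMITTED)
            .metadata(FetchMetadata.INITIAL)   // start (or restart) a fetch session
            .rackId("rack-a")
            .build((short) 13);

        System.out.println(request.isolationLevel());
    }
}
```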
+ */ +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.TopicIdPartition; +import org.apache.kafka.common.Uuid; +import org.apache.kafka.common.message.FetchResponseData; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.protocol.Errors; +import org.apache.kafka.common.protocol.ObjectSerializationCache; +import org.apache.kafka.common.record.MemoryRecords; +import org.apache.kafka.common.record.Records; + +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.Iterator; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.stream.Collectors; + +import static org.apache.kafka.common.requests.FetchMetadata.INVALID_SESSION_ID; + +/** + * This wrapper supports all versions of the Fetch API + * + * Possible error codes: + * + * - {@link Errors#OFFSET_OUT_OF_RANGE} If the fetch offset is out of range for a requested partition + * - {@link Errors#TOPIC_AUTHORIZATION_FAILED} If the user does not have READ access to a requested topic + * - {@link Errors#REPLICA_NOT_AVAILABLE} If the request is received by a broker with version < 2.6 which is not a replica + * - {@link Errors#NOT_LEADER_OR_FOLLOWER} If the broker is not a leader or follower and either the provided leader epoch + * matches the known leader epoch on the broker or is empty + * - {@link Errors#FENCED_LEADER_EPOCH} If the epoch is lower than the broker's epoch + * - {@link Errors#UNKNOWN_LEADER_EPOCH} If the epoch is larger than the broker's epoch + * - {@link Errors#UNKNOWN_TOPIC_OR_PARTITION} If the broker does not have metadata for a topic or partition + * - {@link Errors#KAFKA_STORAGE_ERROR} If the log directory for one of the requested partitions is offline + * - {@link Errors#UNSUPPORTED_COMPRESSION_TYPE} If a fetched topic is using a compression type which is + * not supported by the fetch request version + * - {@link Errors#CORRUPT_MESSAGE} If corrupt message encountered, e.g. when the broker scans the log to find + * the fetch offset after the index lookup + * - {@link Errors#UNKNOWN_TOPIC_ID} If the request contains a topic ID unknown to the broker + * - {@link Errors#FETCH_SESSION_TOPIC_ID_ERROR} If the request version supports topic IDs but the session does not or vice versa, + * or a topic ID in the request is inconsistent with a topic ID in the session + * - {@link Errors#INCONSISTENT_TOPIC_ID} If a topic ID in the session does not match the topic ID in the log + * - {@link Errors#UNKNOWN_SERVER_ERROR} For any unexpected errors + */ +public class FetchResponse extends AbstractResponse { + public static final long INVALID_HIGH_WATERMARK = -1L; + public static final long INVALID_LAST_STABLE_OFFSET = -1L; + public static final long INVALID_LOG_START_OFFSET = -1L; + public static final int INVALID_PREFERRED_REPLICA_ID = -1; + + private final FetchResponseData data; + // we build responseData when needed. + private volatile LinkedHashMap responseData = null; + + @Override + public FetchResponseData data() { + return data; + } + + /** + * From version 3 or later, the authorized and existing entries in `FetchRequest.fetchData` should be in the same order in `responseData`. + * Version 13 introduces topic IDs which can lead to a few new errors. 
If there is any unknown topic ID in the request, the + * response will contain a partition-level UNKNOWN_TOPIC_ID error for that partition. + * If a request's topic ID usage is inconsistent with the session, we will return a top level FETCH_SESSION_TOPIC_ID_ERROR error. + * We may also return INCONSISTENT_TOPIC_ID error as a partition-level error when a partition in the session has a topic ID + * inconsistent with the log. + */ + public FetchResponse(FetchResponseData fetchResponseData) { + super(ApiKeys.FETCH); + this.data = fetchResponseData; + } + + public Errors error() { + return Errors.forCode(data.errorCode()); + } + + public LinkedHashMap responseData(Map topicNames, short version) { + if (responseData == null) { + synchronized (this) { + if (responseData == null) { + responseData = new LinkedHashMap<>(); + data.responses().forEach(topicResponse -> { + String name; + if (version < 13) { + name = topicResponse.topic(); + } else { + name = topicNames.get(topicResponse.topicId()); + } + if (name != null) { + topicResponse.partitions().forEach(partition -> + responseData.put(new TopicPartition(name, partition.partitionIndex()), partition)); + } + }); + } + } + } + return responseData; + } + + @Override + public int throttleTimeMs() { + return data.throttleTimeMs(); + } + + public int sessionId() { + return data.sessionId(); + } + + @Override + public Map errorCounts() { + Map errorCounts = new HashMap<>(); + updateErrorCounts(errorCounts, error()); + data.responses().forEach(topicResponse -> + topicResponse.partitions().forEach(partition -> + updateErrorCounts(errorCounts, Errors.forCode(partition.errorCode()))) + ); + return errorCounts; + } + + public static FetchResponse parse(ByteBuffer buffer, short version) { + return new FetchResponse(new FetchResponseData(new ByteBufferAccessor(buffer), version)); + } + + // Fetch versions 13 and above should have topic IDs for all topics. + // Fetch versions < 13 should return the empty set. + public Set topicIds() { + return data.responses().stream().map(FetchResponseData.FetchableTopicResponse::topicId).filter(id -> !id.equals(Uuid.ZERO_UUID)).collect(Collectors.toSet()); + } + + /** + * Convenience method to find the size of a response. + * + * @param version The version of the response to use. + * @param partIterator The partition iterator. + * @return The response size in bytes. + */ + public static int sizeOf(short version, + Iterator> partIterator) { + // Since the throttleTimeMs and metadata field sizes are constant and fixed, we can + // use arbitrary values here without affecting the result. + FetchResponseData data = toMessage(Errors.NONE, 0, INVALID_SESSION_ID, partIterator); + ObjectSerializationCache cache = new ObjectSerializationCache(); + return 4 + data.size(cache, version); + } + + @Override + public boolean shouldClientThrottle(short version) { + return version >= 8; + } + + public static Optional divergingEpoch(FetchResponseData.PartitionData partitionResponse) { + return partitionResponse.divergingEpoch().epoch() < 0 ? Optional.empty() + : Optional.of(partitionResponse.divergingEpoch()); + } + + public static boolean isDivergingEpoch(FetchResponseData.PartitionData partitionResponse) { + return partitionResponse.divergingEpoch().epoch() >= 0; + } + + public static Optional preferredReadReplica(FetchResponseData.PartitionData partitionResponse) { + return partitionResponse.preferredReadReplica() == INVALID_PREFERRED_REPLICA_ID ? 
Optional.empty() + : Optional.of(partitionResponse.preferredReadReplica()); + } + + public static boolean isPreferredReplica(FetchResponseData.PartitionData partitionResponse) { + return partitionResponse.preferredReadReplica() != INVALID_PREFERRED_REPLICA_ID; + } + + public static FetchResponseData.PartitionData partitionResponse(TopicIdPartition topicIdPartition, Errors error) { + return partitionResponse(topicIdPartition.topicPartition().partition(), error); + } + + public static FetchResponseData.PartitionData partitionResponse(int partition, Errors error) { + return new FetchResponseData.PartitionData() + .setPartitionIndex(partition) + .setErrorCode(error.code()) + .setHighWatermark(FetchResponse.INVALID_HIGH_WATERMARK); + } + + /** + * Returns `partition.records` as `Records` (instead of `BaseRecords`). If `records` is `null`, returns `MemoryRecords.EMPTY`. + * + * If this response was deserialized after a fetch, this method should never fail. An example where this would + * fail is a down-converted response (e.g. LazyDownConversionRecords) on the broker (before it's serialized and + * sent on the wire). + * + * @param partition partition data + * @return Records or empty record if the records in PartitionData is null. + */ + public static Records recordsOrFail(FetchResponseData.PartitionData partition) { + if (partition.records() == null) return MemoryRecords.EMPTY; + if (partition.records() instanceof Records) return (Records) partition.records(); + throw new ClassCastException("The record type is " + partition.records().getClass().getSimpleName() + ", which is not a subtype of " + + Records.class.getSimpleName() + ". This method is only safe to call if the `FetchResponse` was deserialized from bytes."); + } + + /** + * @return The size in bytes of the records. 0 is returned if records of input partition is null. + */ + public static int recordsSize(FetchResponseData.PartitionData partition) { + return partition.records() == null ? 0 : partition.records().sizeInBytes(); + } + + // TODO: remove as a part of KAFKA-12410 + public static FetchResponse of(Errors error, + int throttleTimeMs, + int sessionId, + LinkedHashMap responseData) { + return new FetchResponse(toMessage(error, throttleTimeMs, sessionId, responseData.entrySet().iterator())); + } + + private static boolean matchingTopic(FetchResponseData.FetchableTopicResponse previousTopic, TopicIdPartition currentTopic) { + if (previousTopic == null) + return false; + if (!previousTopic.topicId().equals(Uuid.ZERO_UUID)) + return previousTopic.topicId().equals(currentTopic.topicId()); + else + return previousTopic.topic().equals(currentTopic.topicPartition().topic()); + + } + + private static FetchResponseData toMessage(Errors error, + int throttleTimeMs, + int sessionId, + Iterator> partIterator) { + List topicResponseList = new ArrayList<>(); + while (partIterator.hasNext()) { + Map.Entry entry = partIterator.next(); + FetchResponseData.PartitionData partitionData = entry.getValue(); + // Since PartitionData alone doesn't know the partition ID, we set it here + partitionData.setPartitionIndex(entry.getKey().topicPartition().partition()); + // We have to keep the order of input topic-partition. Hence, we batch the partitions only if the last + // batch is in the same topic group. + FetchResponseData.FetchableTopicResponse previousTopic = topicResponseList.isEmpty() ? 
null + : topicResponseList.get(topicResponseList.size() - 1); + if (matchingTopic(previousTopic, entry.getKey())) + previousTopic.partitions().add(partitionData); + else { + List partitionResponses = new ArrayList<>(); + partitionResponses.add(partitionData); + topicResponseList.add(new FetchResponseData.FetchableTopicResponse() + .setTopic(entry.getKey().topicPartition().topic()) + .setTopicId(entry.getKey().topicId()) + .setPartitions(partitionResponses)); + } + } + + return new FetchResponseData() + .setThrottleTimeMs(throttleTimeMs) + .setErrorCode(error.code()) + .setSessionId(sessionId) + .setResponses(topicResponseList); + } +} \ No newline at end of file diff --git a/clients/src/main/java/org/apache/kafka/common/requests/FetchSnapshotRequest.java b/clients/src/main/java/org/apache/kafka/common/requests/FetchSnapshotRequest.java new file mode 100644 index 0000000..1769e94 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/FetchSnapshotRequest.java @@ -0,0 +1,124 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.message.FetchSnapshotRequestData; +import org.apache.kafka.common.message.FetchSnapshotResponseData; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.protocol.Errors; + +import java.nio.ByteBuffer; +import java.util.Collections; +import java.util.Optional; +import java.util.function.UnaryOperator; + +final public class FetchSnapshotRequest extends AbstractRequest { + private final FetchSnapshotRequestData data; + + public FetchSnapshotRequest(FetchSnapshotRequestData data, short version) { + super(ApiKeys.FETCH_SNAPSHOT, version); + this.data = data; + } + + @Override + public FetchSnapshotResponse getErrorResponse(int throttleTimeMs, Throwable e) { + return new FetchSnapshotResponse( + new FetchSnapshotResponseData() + .setThrottleTimeMs(throttleTimeMs) + .setErrorCode(Errors.forException(e).code()) + ); + } + + @Override + public FetchSnapshotRequestData data() { + return data; + } + + /** + * Creates a FetchSnapshotRequestData with a single PartitionSnapshot for the topic partition. + * + * The partition index will already be populated when calling operator. 
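[Editor's aside, not part of the patch, referring back to the `FetchResponse` helpers just above: building a response from an ordered map of per-partition results with `partitionResponse(...)` and `of(...)`. Topic name, id, error and session id are invented.]

```java
// Hedged sketch: assemble a FetchResponse with a single errored partition.
import java.util.LinkedHashMap;
import org.apache.kafka.common.TopicIdPartition;
import org.apache.kafka.common.TopicPartition;
import org.apache.kafka.common.Uuid;
import org.apache.kafka.common.message.FetchResponseData;
import org.apache.kafka.common.protocol.Errors;
import org.apache.kafka.common.requests.FetchResponse;

public class FetchResponseExample {
    public static void main(String[] args) {
        TopicIdPartition tidp = new TopicIdPartition(Uuid.randomUuid(), new TopicPartition("events", 0));

        LinkedHashMap<TopicIdPartition, FetchResponseData.PartitionData> partitions = new LinkedHashMap<>();
        // partitionResponse(...) fills in the partition index, the error code and an invalid high watermark.
        partitions.put(tidp, FetchResponse.partitionResponse(tidp, Errors.NOT_LEADER_OR_FOLLOWER));

        FetchResponse response = FetchResponse.of(Errors.NONE, 0 /* throttleTimeMs */, 123 /* sessionId */, partitions);
        System.out.println(response.errorCounts());
    }
}
```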
+ * + * @param topicPartition the topic partition to include + * @param operator unary operator responsible for populating all the appropriate fields + * @return the created fetch snapshot request data + */ + public static FetchSnapshotRequestData singleton( + String clusterId, + TopicPartition topicPartition, + UnaryOperator operator + ) { + FetchSnapshotRequestData.PartitionSnapshot partitionSnapshot = operator.apply( + new FetchSnapshotRequestData.PartitionSnapshot().setPartition(topicPartition.partition()) + ); + + return new FetchSnapshotRequestData() + .setClusterId(clusterId) + .setTopics( + Collections.singletonList( + new FetchSnapshotRequestData.TopicSnapshot() + .setName(topicPartition.topic()) + .setPartitions(Collections.singletonList(partitionSnapshot)) + ) + ); + } + + /** + * Finds the PartitionSnapshot for a given topic partition. + * + * @param data the fetch snapshot request data + * @param topicPartition the topic partition to find + * @return the request partition snapshot if found, otherwise an empty Optional + */ + public static Optional forTopicPartition( + FetchSnapshotRequestData data, + TopicPartition topicPartition + ) { + return data + .topics() + .stream() + .filter(topic -> topic.name().equals(topicPartition.topic())) + .flatMap(topic -> topic.partitions().stream()) + .filter(partition -> partition.partition() == topicPartition.partition()) + .findAny(); + } + + public static FetchSnapshotRequest parse(ByteBuffer buffer, short version) { + return new FetchSnapshotRequest(new FetchSnapshotRequestData(new ByteBufferAccessor(buffer), version), version); + } + + public static class Builder extends AbstractRequest.Builder { + private final FetchSnapshotRequestData data; + + public Builder(FetchSnapshotRequestData data) { + super(ApiKeys.FETCH_SNAPSHOT); + this.data = data; + } + + @Override + public FetchSnapshotRequest build(short version) { + return new FetchSnapshotRequest(data, version); + } + + @Override + public String toString() { + return data.toString(); + } + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/FetchSnapshotResponse.java b/clients/src/main/java/org/apache/kafka/common/requests/FetchSnapshotResponse.java new file mode 100644 index 0000000..7c1ce27 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/FetchSnapshotResponse.java @@ -0,0 +1,129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
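[Editor's aside, not part of the patch: a sketch of the `FetchSnapshotRequest.singleton(...)` and `forTopicPartition(...)` helpers above. The identity operator leaves everything except the partition index at its defaults; the cluster id and topic are invented.]

```java
// Hedged sketch: build a single-partition FetchSnapshot request and look it up again.
import java.util.Optional;
import org.apache.kafka.common.TopicPartition;
import org.apache.kafka.common.message.FetchSnapshotRequestData;
import org.apache.kafka.common.requests.FetchSnapshotRequest;

public class FetchSnapshotRequestExample {
    public static void main(String[] args) {
        TopicPartition metadataPartition = new TopicPartition("__cluster_metadata", 0);

        FetchSnapshotRequestData data = FetchSnapshotRequest.singleton(
            "example-cluster-id",
            metadataPartition,
            snapshot -> snapshot);   // a real caller would populate the remaining snapshot fields here

        Optional<FetchSnapshotRequestData.PartitionSnapshot> found =
            FetchSnapshotRequest.forTopicPartition(data, metadataPartition);
        System.out.println(found.isPresent()); // true
    }
}
```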
+ */
+package org.apache.kafka.common.requests;
+
+import org.apache.kafka.common.TopicPartition;
+import org.apache.kafka.common.message.FetchSnapshotResponseData;
+import org.apache.kafka.common.protocol.ApiKeys;
+import org.apache.kafka.common.protocol.ByteBufferAccessor;
+import org.apache.kafka.common.protocol.Errors;
+
+import java.nio.ByteBuffer;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Optional;
+import java.util.function.UnaryOperator;
+
+final public class FetchSnapshotResponse extends AbstractResponse {
+    private final FetchSnapshotResponseData data;
+
+    public FetchSnapshotResponse(FetchSnapshotResponseData data) {
+        super(ApiKeys.FETCH_SNAPSHOT);
+        this.data = data;
+    }
+
+    @Override
+    public Map<Errors, Integer> errorCounts() {
+        Map<Errors, Integer> errors = new HashMap<>();
+
+        Errors topLevelError = Errors.forCode(data.errorCode());
+        if (topLevelError != Errors.NONE) {
+            errors.put(topLevelError, 1);
+        }
+
+        for (FetchSnapshotResponseData.TopicSnapshot topicResponse : data.topics()) {
+            for (FetchSnapshotResponseData.PartitionSnapshot partitionResponse : topicResponse.partitions()) {
+                errors.compute(Errors.forCode(partitionResponse.errorCode()),
+                    (error, count) -> count == null ? 1 : count + 1);
+            }
+        }
+
+        return errors;
+    }
+
+    @Override
+    public int throttleTimeMs() {
+        return data.throttleTimeMs();
+    }
+
+    @Override
+    public FetchSnapshotResponseData data() {
+        return data;
+    }
+
+    /**
+     * Creates a FetchSnapshotResponseData with a top level error.
+     *
+     * @param error the top level error
+     * @return the created fetch snapshot response data
+     */
+    public static FetchSnapshotResponseData withTopLevelError(Errors error) {
+        return new FetchSnapshotResponseData().setErrorCode(error.code());
+    }
+
+    /**
+     * Creates a FetchSnapshotResponseData with a single PartitionSnapshot for the topic partition.
+     *
+     * The partition index will already be populated when calling operator.
+     *
+     * @param topicPartition the topic partition to include
+     * @param operator unary operator responsible for populating all the appropriate fields
+     * @return the created fetch snapshot response data
+     */
+    public static FetchSnapshotResponseData singleton(
+        TopicPartition topicPartition,
+        UnaryOperator<FetchSnapshotResponseData.PartitionSnapshot> operator
+    ) {
+        FetchSnapshotResponseData.PartitionSnapshot partitionSnapshot = operator.apply(
+            new FetchSnapshotResponseData.PartitionSnapshot().setIndex(topicPartition.partition())
+        );
+
+        return new FetchSnapshotResponseData()
+            .setTopics(
+                Collections.singletonList(
+                    new FetchSnapshotResponseData.TopicSnapshot()
+                        .setName(topicPartition.topic())
+                        .setPartitions(Collections.singletonList(partitionSnapshot))
+                )
+            );
+    }
+
+    /**
+     * Finds the PartitionSnapshot for a given topic partition.
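+     *
+     * For example, a minimal lookup sketch (the {@code responseData} variable below is illustrative):
+     * <pre>{@code
+     * Optional<FetchSnapshotResponseData.PartitionSnapshot> partitionSnapshot =
+     *     FetchSnapshotResponse.forTopicPartition(responseData, new TopicPartition("__cluster_metadata", 0));
+     * }</pre>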
+     *
+     * @param data the fetch snapshot response data
+     * @param topicPartition the topic partition to find
+     * @return the response partition snapshot if found, otherwise an empty Optional
+     */
+    public static Optional<FetchSnapshotResponseData.PartitionSnapshot> forTopicPartition(
+        FetchSnapshotResponseData data,
+        TopicPartition topicPartition
+    ) {
+        return data
+            .topics()
+            .stream()
+            .filter(topic -> topic.name().equals(topicPartition.topic()))
+            .flatMap(topic -> topic.partitions().stream())
+            .filter(partition -> partition.index() == topicPartition.partition())
+            .findAny();
+    }
+
+    public static FetchSnapshotResponse parse(ByteBuffer buffer, short version) {
+        return new FetchSnapshotResponse(new FetchSnapshotResponseData(new ByteBufferAccessor(buffer), version));
+    }
+}
diff --git a/clients/src/main/java/org/apache/kafka/common/requests/FindCoordinatorRequest.java b/clients/src/main/java/org/apache/kafka/common/requests/FindCoordinatorRequest.java
new file mode 100644
index 0000000..fcac7de
--- /dev/null
+++ b/clients/src/main/java/org/apache/kafka/common/requests/FindCoordinatorRequest.java
@@ -0,0 +1,147 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.Node; +import org.apache.kafka.common.errors.InvalidRequestException; +import org.apache.kafka.common.errors.UnsupportedVersionException; +import org.apache.kafka.common.message.FindCoordinatorRequestData; +import org.apache.kafka.common.message.FindCoordinatorResponseData; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.protocol.Errors; + +import java.nio.ByteBuffer; +import java.util.Collections; + +public class FindCoordinatorRequest extends AbstractRequest { + + public static final short MIN_BATCHED_VERSION = 4; + + public static class Builder extends AbstractRequest.Builder { + private final FindCoordinatorRequestData data; + + public Builder(FindCoordinatorRequestData data) { + super(ApiKeys.FIND_COORDINATOR); + this.data = data; + } + + @Override + public FindCoordinatorRequest build(short version) { + if (version < 1 && data.keyType() == CoordinatorType.TRANSACTION.id()) { + throw new UnsupportedVersionException("Cannot create a v" + version + " FindCoordinator request " + + "because we require features supported only in 2 or later."); + } + int batchedKeys = data.coordinatorKeys().size(); + if (version < MIN_BATCHED_VERSION) { + if (batchedKeys > 1) + throw new NoBatchedFindCoordinatorsException("Cannot create a v" + version + " FindCoordinator request " + + "because we require features supported only in " + MIN_BATCHED_VERSION + " or later."); + if (batchedKeys == 1) { + data.setKey(data.coordinatorKeys().get(0)); + data.setCoordinatorKeys(Collections.emptyList()); + } + } else if (batchedKeys == 0 && data.key() != null) { + data.setCoordinatorKeys(Collections.singletonList(data.key())); + data.setKey(""); // default value + } + return new FindCoordinatorRequest(data, version); + } + + @Override + public String toString() { + return data.toString(); + } + + public FindCoordinatorRequestData data() { + return data; + } + } + + /** + * Indicates that it is not possible to lookup coordinators in batches with FindCoordinator. Instead + * coordinators must be looked up one by one. 
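+     *
+     * A hedged sketch of the per-key fallback a caller might use when this is thrown (the group id
+     * and variable names are illustrative, not part of this patch):
+     * <pre>{@code
+     * FindCoordinatorRequest.Builder builder = new FindCoordinatorRequest.Builder(
+     *     new FindCoordinatorRequestData()
+     *         .setKeyType(CoordinatorType.GROUP.id())
+     *         .setKey("example-group"));   // one key per request on versions before 4
+     * }</pre>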
+ */ + public static class NoBatchedFindCoordinatorsException extends UnsupportedVersionException { + private static final long serialVersionUID = 1L; + + public NoBatchedFindCoordinatorsException(String message, Throwable cause) { + super(message, cause); + } + + public NoBatchedFindCoordinatorsException(String message) { + super(message); + } + } + + private final FindCoordinatorRequestData data; + + private FindCoordinatorRequest(FindCoordinatorRequestData data, short version) { + super(ApiKeys.FIND_COORDINATOR, version); + this.data = data; + } + + @Override + public AbstractResponse getErrorResponse(int throttleTimeMs, Throwable e) { + FindCoordinatorResponseData response = new FindCoordinatorResponseData(); + if (version() >= 2) { + response.setThrottleTimeMs(throttleTimeMs); + } + Errors error = Errors.forException(e); + if (version() < MIN_BATCHED_VERSION) { + return FindCoordinatorResponse.prepareOldResponse(error, Node.noNode()); + } else { + return FindCoordinatorResponse.prepareErrorResponse(error, data.coordinatorKeys()); + } + } + + public static FindCoordinatorRequest parse(ByteBuffer buffer, short version) { + return new FindCoordinatorRequest(new FindCoordinatorRequestData(new ByteBufferAccessor(buffer), version), + version); + } + + @Override + public FindCoordinatorRequestData data() { + return data; + } + + public enum CoordinatorType { + GROUP((byte) 0), TRANSACTION((byte) 1); + + final byte id; + + CoordinatorType(byte id) { + this.id = id; + } + + public byte id() { + return id; + } + + public static CoordinatorType forId(byte id) { + switch (id) { + case 0: + return GROUP; + case 1: + return TRANSACTION; + default: + throw new InvalidRequestException("Unknown coordinator type received: " + id); + } + } + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/FindCoordinatorResponse.java b/clients/src/main/java/org/apache/kafka/common/requests/FindCoordinatorResponse.java new file mode 100644 index 0000000..080ba24 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/FindCoordinatorResponse.java @@ -0,0 +1,157 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.Node; +import org.apache.kafka.common.message.FindCoordinatorResponseData; +import org.apache.kafka.common.message.FindCoordinatorResponseData.Coordinator; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.protocol.Errors; + +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + + +public class FindCoordinatorResponse extends AbstractResponse { + + /** + * Possible error codes: + * + * COORDINATOR_LOAD_IN_PROGRESS (14) + * COORDINATOR_NOT_AVAILABLE (15) + * GROUP_AUTHORIZATION_FAILED (30) + * INVALID_REQUEST (42) + * TRANSACTIONAL_ID_AUTHORIZATION_FAILED (53) + */ + + private final FindCoordinatorResponseData data; + + public FindCoordinatorResponse(FindCoordinatorResponseData data) { + super(ApiKeys.FIND_COORDINATOR); + this.data = data; + } + + @Override + public FindCoordinatorResponseData data() { + return data; + } + + public Node node() { + return new Node(data.nodeId(), data.host(), data.port()); + } + + @Override + public int throttleTimeMs() { + return data.throttleTimeMs(); + } + + public boolean hasError() { + return error() != Errors.NONE; + } + + public Errors error() { + return Errors.forCode(data.errorCode()); + } + + @Override + public Map errorCounts() { + if (!data.coordinators().isEmpty()) { + Map errorCounts = new HashMap<>(); + for (Coordinator coordinator : data.coordinators()) { + updateErrorCounts(errorCounts, Errors.forCode(coordinator.errorCode())); + } + return errorCounts; + } else { + return errorCounts(error()); + } + } + + public static FindCoordinatorResponse parse(ByteBuffer buffer, short version) { + return new FindCoordinatorResponse(new FindCoordinatorResponseData(new ByteBufferAccessor(buffer), version)); + } + + @Override + public String toString() { + return data.toString(); + } + + @Override + public boolean shouldClientThrottle(short version) { + return version >= 2; + } + + public List coordinators() { + if (!data.coordinators().isEmpty()) + return data.coordinators(); + else { + FindCoordinatorResponseData.Coordinator coordinator = new Coordinator() + .setErrorCode(data.errorCode()) + .setErrorMessage(data.errorMessage()) + .setKey(null) + .setNodeId(data.nodeId()) + .setHost(data.host()) + .setPort(data.port()); + return Collections.singletonList(coordinator); + } + } + + public static FindCoordinatorResponse prepareOldResponse(Errors error, Node node) { + FindCoordinatorResponseData data = new FindCoordinatorResponseData(); + data.setErrorCode(error.code()) + .setErrorMessage(error.message()) + .setNodeId(node.id()) + .setHost(node.host()) + .setPort(node.port()); + return new FindCoordinatorResponse(data); + } + + public static FindCoordinatorResponse prepareResponse(Errors error, String key, Node node) { + FindCoordinatorResponseData data = new FindCoordinatorResponseData(); + data.setCoordinators(Collections.singletonList( + new FindCoordinatorResponseData.Coordinator() + .setErrorCode(error.code()) + .setErrorMessage(error.message()) + .setKey(key) + .setHost(node.host()) + .setPort(node.port()) + .setNodeId(node.id()))); + return new FindCoordinatorResponse(data); + } + + public static FindCoordinatorResponse prepareErrorResponse(Errors error, List keys) { + FindCoordinatorResponseData data = new FindCoordinatorResponseData(); + List coordinators = new 
ArrayList<>(keys.size()); + for (String key : keys) { + FindCoordinatorResponseData.Coordinator coordinator = new FindCoordinatorResponseData.Coordinator() + .setErrorCode(error.code()) + .setErrorMessage(error.message()) + .setKey(key) + .setHost(Node.noNode().host()) + .setPort(Node.noNode().port()) + .setNodeId(Node.noNode().id()); + coordinators.add(coordinator); + } + data.setCoordinators(coordinators); + return new FindCoordinatorResponse(data); + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/HeartbeatRequest.java b/clients/src/main/java/org/apache/kafka/common/requests/HeartbeatRequest.java new file mode 100644 index 0000000..482e61a --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/HeartbeatRequest.java @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.errors.UnsupportedVersionException; +import org.apache.kafka.common.message.HeartbeatRequestData; +import org.apache.kafka.common.message.HeartbeatResponseData; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.protocol.Errors; + +import java.nio.ByteBuffer; + +public class HeartbeatRequest extends AbstractRequest { + + public static class Builder extends AbstractRequest.Builder { + private final HeartbeatRequestData data; + + public Builder(HeartbeatRequestData data) { + super(ApiKeys.HEARTBEAT); + this.data = data; + } + + @Override + public HeartbeatRequest build(short version) { + if (data.groupInstanceId() != null && version < 3) { + throw new UnsupportedVersionException("The broker heartbeat protocol version " + + version + " does not support usage of config group.instance.id."); + } + return new HeartbeatRequest(data, version); + } + + @Override + public String toString() { + return data.toString(); + } + } + + private final HeartbeatRequestData data; + + private HeartbeatRequest(HeartbeatRequestData data, short version) { + super(ApiKeys.HEARTBEAT, version); + this.data = data; + } + + @Override + public AbstractResponse getErrorResponse(int throttleTimeMs, Throwable e) { + HeartbeatResponseData responseData = new HeartbeatResponseData(). 
+ setErrorCode(Errors.forException(e).code()); + if (version() >= 1) { + responseData.setThrottleTimeMs(throttleTimeMs); + } + return new HeartbeatResponse(responseData); + } + + public static HeartbeatRequest parse(ByteBuffer buffer, short version) { + return new HeartbeatRequest(new HeartbeatRequestData(new ByteBufferAccessor(buffer), version), version); + } + + @Override + public HeartbeatRequestData data() { + return data; + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/HeartbeatResponse.java b/clients/src/main/java/org/apache/kafka/common/requests/HeartbeatResponse.java new file mode 100644 index 0000000..eb402fc --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/HeartbeatResponse.java @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.message.HeartbeatResponseData; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.protocol.Errors; + +import java.nio.ByteBuffer; +import java.util.Map; + +public class HeartbeatResponse extends AbstractResponse { + + /** + * Possible error codes: + * + * GROUP_COORDINATOR_NOT_AVAILABLE (15) + * NOT_COORDINATOR (16) + * ILLEGAL_GENERATION (22) + * UNKNOWN_MEMBER_ID (25) + * REBALANCE_IN_PROGRESS (27) + * GROUP_AUTHORIZATION_FAILED (30) + */ + private final HeartbeatResponseData data; + + public HeartbeatResponse(HeartbeatResponseData data) { + super(ApiKeys.HEARTBEAT); + this.data = data; + } + + @Override + public int throttleTimeMs() { + return data.throttleTimeMs(); + } + + public Errors error() { + return Errors.forCode(data.errorCode()); + } + + @Override + public Map errorCounts() { + return errorCounts(error()); + } + + @Override + public HeartbeatResponseData data() { + return data; + } + + public static HeartbeatResponse parse(ByteBuffer buffer, short version) { + return new HeartbeatResponse(new HeartbeatResponseData(new ByteBufferAccessor(buffer), version)); + } + + @Override + public boolean shouldClientThrottle(short version) { + return version >= 2; + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/IncrementalAlterConfigsRequest.java b/clients/src/main/java/org/apache/kafka/common/requests/IncrementalAlterConfigsRequest.java new file mode 100644 index 0000000..2bc5914 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/IncrementalAlterConfigsRequest.java @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.common.requests; + +import org.apache.kafka.clients.admin.AlterConfigOp; +import org.apache.kafka.common.config.ConfigResource; +import org.apache.kafka.common.message.IncrementalAlterConfigsRequestData; +import org.apache.kafka.common.message.IncrementalAlterConfigsRequestData.AlterConfigsResource; +import org.apache.kafka.common.message.IncrementalAlterConfigsResponseData; +import org.apache.kafka.common.message.IncrementalAlterConfigsResponseData.AlterConfigsResourceResponse; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; + +import java.nio.ByteBuffer; +import java.util.Collection; +import java.util.Map; + +public class IncrementalAlterConfigsRequest extends AbstractRequest { + + public static class Builder extends AbstractRequest.Builder { + private final IncrementalAlterConfigsRequestData data; + + public Builder(IncrementalAlterConfigsRequestData data) { + super(ApiKeys.INCREMENTAL_ALTER_CONFIGS); + this.data = data; + } + + public Builder(final Collection resources, + final Map> configs, + final boolean validateOnly) { + super(ApiKeys.INCREMENTAL_ALTER_CONFIGS); + this.data = new IncrementalAlterConfigsRequestData() + .setValidateOnly(validateOnly); + for (ConfigResource resource : resources) { + IncrementalAlterConfigsRequestData.AlterableConfigCollection alterableConfigSet = + new IncrementalAlterConfigsRequestData.AlterableConfigCollection(); + for (AlterConfigOp configEntry : configs.get(resource)) + alterableConfigSet.add(new IncrementalAlterConfigsRequestData.AlterableConfig() + .setName(configEntry.configEntry().name()) + .setValue(configEntry.configEntry().value()) + .setConfigOperation(configEntry.opType().id())); + IncrementalAlterConfigsRequestData.AlterConfigsResource alterConfigsResource = new IncrementalAlterConfigsRequestData.AlterConfigsResource(); + alterConfigsResource.setResourceType(resource.type().id()) + .setResourceName(resource.name()).setConfigs(alterableConfigSet); + data.resources().add(alterConfigsResource); + } + } + + public Builder(final Map> configs, + final boolean validateOnly) { + this(configs.keySet(), configs, validateOnly); + } + + @Override + public IncrementalAlterConfigsRequest build(short version) { + return new IncrementalAlterConfigsRequest(data, version); + } + + @Override + public String toString() { + return data.toString(); + } + } + + private final IncrementalAlterConfigsRequestData data; + private final short version; + + private IncrementalAlterConfigsRequest(IncrementalAlterConfigsRequestData data, short version) { + super(ApiKeys.INCREMENTAL_ALTER_CONFIGS, version); + this.data = data; + this.version = version; + } + + public static IncrementalAlterConfigsRequest parse(ByteBuffer buffer, short version) { + return new IncrementalAlterConfigsRequest(new IncrementalAlterConfigsRequestData( + new ByteBufferAccessor(buffer), version), version); + } + + @Override + 
public IncrementalAlterConfigsRequestData data() { + return data; + } + + @Override + public AbstractResponse getErrorResponse(final int throttleTimeMs, final Throwable e) { + IncrementalAlterConfigsResponseData response = new IncrementalAlterConfigsResponseData(); + ApiError apiError = ApiError.fromThrowable(e); + for (AlterConfigsResource resource : data.resources()) { + response.responses().add(new AlterConfigsResourceResponse() + .setResourceName(resource.resourceName()) + .setResourceType(resource.resourceType()) + .setErrorCode(apiError.error().code()) + .setErrorMessage(apiError.message())); + } + return new IncrementalAlterConfigsResponse(response); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/IncrementalAlterConfigsResponse.java b/clients/src/main/java/org/apache/kafka/common/requests/IncrementalAlterConfigsResponse.java new file mode 100644 index 0000000..b5887de --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/IncrementalAlterConfigsResponse.java @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.config.ConfigResource; +import org.apache.kafka.common.message.IncrementalAlterConfigsResponseData; +import org.apache.kafka.common.message.IncrementalAlterConfigsResponseData.AlterConfigsResourceResponse; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.protocol.Errors; + +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +public class IncrementalAlterConfigsResponse extends AbstractResponse { + + public IncrementalAlterConfigsResponse(final int requestThrottleMs, + final Map results) { + super(ApiKeys.INCREMENTAL_ALTER_CONFIGS); + final List newResults = new ArrayList<>(results.size()); + results.forEach( + (resource, error) -> newResults.add( + new AlterConfigsResourceResponse() + .setErrorCode(error.error().code()) + .setErrorMessage(error.message()) + .setResourceName(resource.name()) + .setResourceType(resource.type().id())) + ); + + this.data = new IncrementalAlterConfigsResponseData() + .setResponses(newResults) + .setThrottleTimeMs(requestThrottleMs); + } + + public static Map fromResponseData(final IncrementalAlterConfigsResponseData data) { + Map map = new HashMap<>(); + for (AlterConfigsResourceResponse response : data.responses()) { + map.put(new ConfigResource(ConfigResource.Type.forId(response.resourceType()), response.resourceName()), + new ApiError(Errors.forCode(response.errorCode()), response.errorMessage())); + } + return map; + } + + private final IncrementalAlterConfigsResponseData data; + + public IncrementalAlterConfigsResponse(IncrementalAlterConfigsResponseData data) { + super(ApiKeys.INCREMENTAL_ALTER_CONFIGS); + this.data = data; + } + + @Override + public IncrementalAlterConfigsResponseData data() { + return data; + } + + @Override + public Map errorCounts() { + HashMap counts = new HashMap<>(); + data.responses().forEach(response -> + updateErrorCounts(counts, Errors.forCode(response.errorCode())) + ); + return counts; + } + + @Override + public boolean shouldClientThrottle(short version) { + return version >= 0; + } + + @Override + public int throttleTimeMs() { + return data.throttleTimeMs(); + } + + public static IncrementalAlterConfigsResponse parse(ByteBuffer buffer, short version) { + return new IncrementalAlterConfigsResponse(new IncrementalAlterConfigsResponseData( + new ByteBufferAccessor(buffer), version)); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/InitProducerIdRequest.java b/clients/src/main/java/org/apache/kafka/common/requests/InitProducerIdRequest.java new file mode 100644 index 0000000..5c24b41 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/InitProducerIdRequest.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.message.InitProducerIdRequestData; +import org.apache.kafka.common.message.InitProducerIdResponseData; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.protocol.Errors; +import org.apache.kafka.common.record.RecordBatch; + +import java.nio.ByteBuffer; + +public class InitProducerIdRequest extends AbstractRequest { + public static class Builder extends AbstractRequest.Builder { + private final InitProducerIdRequestData data; + + public Builder(InitProducerIdRequestData data) { + super(ApiKeys.INIT_PRODUCER_ID); + this.data = data; + } + + @Override + public InitProducerIdRequest build(short version) { + if (data.transactionTimeoutMs() <= 0) + throw new IllegalArgumentException("transaction timeout value is not positive: " + data.transactionTimeoutMs()); + + if (data.transactionalId() != null && data.transactionalId().isEmpty()) + throw new IllegalArgumentException("Must set either a null or a non-empty transactional id."); + + return new InitProducerIdRequest(data, version); + } + + @Override + public String toString() { + return data.toString(); + } + } + + private final InitProducerIdRequestData data; + + private InitProducerIdRequest(InitProducerIdRequestData data, short version) { + super(ApiKeys.INIT_PRODUCER_ID, version); + this.data = data; + } + + @Override + public AbstractResponse getErrorResponse(int throttleTimeMs, Throwable e) { + InitProducerIdResponseData response = new InitProducerIdResponseData() + .setErrorCode(Errors.forException(e).code()) + .setProducerId(RecordBatch.NO_PRODUCER_ID) + .setProducerEpoch(RecordBatch.NO_PRODUCER_EPOCH) + .setThrottleTimeMs(0); + return new InitProducerIdResponse(response); + } + + public static InitProducerIdRequest parse(ByteBuffer buffer, short version) { + return new InitProducerIdRequest(new InitProducerIdRequestData(new ByteBufferAccessor(buffer), version), version); + } + + @Override + public InitProducerIdRequestData data() { + return data; + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/InitProducerIdResponse.java b/clients/src/main/java/org/apache/kafka/common/requests/InitProducerIdResponse.java new file mode 100644 index 0000000..f8451d7 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/InitProducerIdResponse.java @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.message.InitProducerIdResponseData; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.protocol.Errors; + +import java.nio.ByteBuffer; +import java.util.Map; + +/** + * Possible error codes: + * + * - {@link Errors#NOT_COORDINATOR} + * - {@link Errors#COORDINATOR_NOT_AVAILABLE} + * - {@link Errors#COORDINATOR_LOAD_IN_PROGRESS} + * - {@link Errors#TRANSACTIONAL_ID_AUTHORIZATION_FAILED} + * - {@link Errors#CLUSTER_AUTHORIZATION_FAILED} + * - {@link Errors#INVALID_PRODUCER_EPOCH} // for version <=3 + * - {@link Errors#PRODUCER_FENCED} + */ +public class InitProducerIdResponse extends AbstractResponse { + private final InitProducerIdResponseData data; + + public InitProducerIdResponse(InitProducerIdResponseData data) { + super(ApiKeys.INIT_PRODUCER_ID); + this.data = data; + } + + @Override + public int throttleTimeMs() { + return data.throttleTimeMs(); + } + + @Override + public Map errorCounts() { + return errorCounts(Errors.forCode(data.errorCode())); + } + + @Override + public InitProducerIdResponseData data() { + return data; + } + + public static InitProducerIdResponse parse(ByteBuffer buffer, short version) { + return new InitProducerIdResponse(new InitProducerIdResponseData(new ByteBufferAccessor(buffer), version)); + } + + @Override + public String toString() { + return data.toString(); + } + + public Errors error() { + return Errors.forCode(data.errorCode()); + } + + @Override + public boolean shouldClientThrottle(short version) { + return version >= 1; + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/JoinGroupRequest.java b/clients/src/main/java/org/apache/kafka/common/requests/JoinGroupRequest.java new file mode 100644 index 0000000..220a59d --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/JoinGroupRequest.java @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.errors.InvalidConfigurationException; +import org.apache.kafka.common.errors.UnsupportedVersionException; +import org.apache.kafka.common.internals.Topic; +import org.apache.kafka.common.message.JoinGroupRequestData; +import org.apache.kafka.common.message.JoinGroupResponseData; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.protocol.Errors; + +import java.nio.ByteBuffer; +import java.util.Collections; + +public class JoinGroupRequest extends AbstractRequest { + + public static class Builder extends AbstractRequest.Builder { + + private final JoinGroupRequestData data; + + public Builder(JoinGroupRequestData data) { + super(ApiKeys.JOIN_GROUP); + this.data = data; + } + + @Override + public JoinGroupRequest build(short version) { + if (data.groupInstanceId() != null && version < 5) { + throw new UnsupportedVersionException("The broker join group protocol version " + + version + " does not support usage of config group.instance.id."); + } + return new JoinGroupRequest(data, version); + } + + @Override + public String toString() { + return data.toString(); + } + } + + private final JoinGroupRequestData data; + + public static final String UNKNOWN_MEMBER_ID = ""; + public static final int UNKNOWN_GENERATION_ID = -1; + public static final String UNKNOWN_PROTOCOL_NAME = ""; + + /** + * Ported from class Topic in {@link org.apache.kafka.common.internals} to restrict the charset for + * static member id. + */ + public static void validateGroupInstanceId(String id) { + Topic.validate(id, "Group instance id", message -> { + throw new InvalidConfigurationException(message); + }); + } + + public JoinGroupRequest(JoinGroupRequestData data, short version) { + super(ApiKeys.JOIN_GROUP, version); + this.data = data; + maybeOverrideRebalanceTimeout(version); + } + + private void maybeOverrideRebalanceTimeout(short version) { + if (version == 0) { + // Version 0 has no rebalance timeout, so we use the session timeout + // to be consistent with the original behavior of the API. + data.setRebalanceTimeoutMs(data.sessionTimeoutMs()); + } + } + + @Override + public JoinGroupRequestData data() { + return data; + } + + @Override + public AbstractResponse getErrorResponse(int throttleTimeMs, Throwable e) { + JoinGroupResponseData data = new JoinGroupResponseData() + .setThrottleTimeMs(throttleTimeMs) + .setErrorCode(Errors.forException(e).code()) + .setGenerationId(UNKNOWN_GENERATION_ID) + .setProtocolName(UNKNOWN_PROTOCOL_NAME) + .setLeader(UNKNOWN_MEMBER_ID) + .setMemberId(UNKNOWN_MEMBER_ID) + .setMembers(Collections.emptyList()); + + if (version() >= 7) + data.setProtocolName(null); + else + data.setProtocolName(UNKNOWN_PROTOCOL_NAME); + + return new JoinGroupResponse(data); + } + + public static JoinGroupRequest parse(ByteBuffer buffer, short version) { + return new JoinGroupRequest(new JoinGroupRequestData(new ByteBufferAccessor(buffer), version), version); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/JoinGroupResponse.java b/clients/src/main/java/org/apache/kafka/common/requests/JoinGroupResponse.java new file mode 100644 index 0000000..336c824 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/JoinGroupResponse.java @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.message.JoinGroupResponseData; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.protocol.Errors; + +import java.nio.ByteBuffer; +import java.util.Map; + +public class JoinGroupResponse extends AbstractResponse { + + private final JoinGroupResponseData data; + + public JoinGroupResponse(JoinGroupResponseData data) { + super(ApiKeys.JOIN_GROUP); + this.data = data; + } + + @Override + public JoinGroupResponseData data() { + return data; + } + + public boolean isLeader() { + return data.memberId().equals(data.leader()); + } + + @Override + public int throttleTimeMs() { + return data.throttleTimeMs(); + } + + public Errors error() { + return Errors.forCode(data.errorCode()); + } + + @Override + public Map errorCounts() { + return errorCounts(Errors.forCode(data.errorCode())); + } + + public static JoinGroupResponse parse(ByteBuffer buffer, short version) { + return new JoinGroupResponse(new JoinGroupResponseData(new ByteBufferAccessor(buffer), version)); + } + + @Override + public String toString() { + return data.toString(); + } + + @Override + public boolean shouldClientThrottle(short version) { + return version >= 3; + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/LeaderAndIsrRequest.java b/clients/src/main/java/org/apache/kafka/common/requests/LeaderAndIsrRequest.java new file mode 100644 index 0000000..d738286 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/LeaderAndIsrRequest.java @@ -0,0 +1,205 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.Node; +import org.apache.kafka.common.Uuid; +import org.apache.kafka.common.message.LeaderAndIsrRequestData; +import org.apache.kafka.common.message.LeaderAndIsrRequestData.LeaderAndIsrLiveLeader; +import org.apache.kafka.common.message.LeaderAndIsrRequestData.LeaderAndIsrTopicState; +import org.apache.kafka.common.message.LeaderAndIsrRequestData.LeaderAndIsrPartitionState; +import org.apache.kafka.common.message.LeaderAndIsrResponseData; +import org.apache.kafka.common.message.LeaderAndIsrResponseData.LeaderAndIsrTopicError; +import org.apache.kafka.common.message.LeaderAndIsrResponseData.LeaderAndIsrPartitionError; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.protocol.Errors; +import org.apache.kafka.common.utils.FlattenedIterator; +import org.apache.kafka.common.utils.Utils; + +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +public class LeaderAndIsrRequest extends AbstractControlRequest { + + public static class Builder extends AbstractControlRequest.Builder { + + private final List partitionStates; + private final Map topicIds; + private final Collection liveLeaders; + + public Builder(short version, int controllerId, int controllerEpoch, long brokerEpoch, + List partitionStates, Map topicIds, + Collection liveLeaders) { + super(ApiKeys.LEADER_AND_ISR, version, controllerId, controllerEpoch, brokerEpoch); + this.partitionStates = partitionStates; + this.topicIds = topicIds; + this.liveLeaders = liveLeaders; + } + + @Override + public LeaderAndIsrRequest build(short version) { + List leaders = liveLeaders.stream().map(n -> new LeaderAndIsrLiveLeader() + .setBrokerId(n.id()) + .setHostName(n.host()) + .setPort(n.port()) + ).collect(Collectors.toList()); + + LeaderAndIsrRequestData data = new LeaderAndIsrRequestData() + .setControllerId(controllerId) + .setControllerEpoch(controllerEpoch) + .setBrokerEpoch(brokerEpoch) + .setLiveLeaders(leaders); + + if (version >= 2) { + Map topicStatesMap = groupByTopic(partitionStates, topicIds); + data.setTopicStates(new ArrayList<>(topicStatesMap.values())); + } else { + data.setUngroupedPartitionStates(partitionStates); + } + + return new LeaderAndIsrRequest(data, version); + } + + private static Map groupByTopic(List partitionStates, Map topicIds) { + Map topicStates = new HashMap<>(); + // We don't null out the topic name in LeaderAndIsrRequestPartition since it's ignored by + // the generated code if version >= 2 + for (LeaderAndIsrPartitionState partition : partitionStates) { + LeaderAndIsrTopicState topicState = topicStates.computeIfAbsent(partition.topicName(), t -> new LeaderAndIsrTopicState() + .setTopicName(partition.topicName()) + .setTopicId(topicIds.getOrDefault(partition.topicName(), Uuid.ZERO_UUID))); + topicState.partitionStates().add(partition); + } + return topicStates; + } + + @Override + public String toString() { + StringBuilder bld = new StringBuilder(); + bld.append("(type=LeaderAndIsRequest") + .append(", controllerId=").append(controllerId) + .append(", controllerEpoch=").append(controllerEpoch) + .append(", brokerEpoch=").append(brokerEpoch) + .append(", partitionStates=").append(partitionStates) + .append(", topicIds=").append(topicIds) + .append(", 
liveLeaders=(").append(Utils.join(liveLeaders, ", ")).append(")") + .append(")"); + return bld.toString(); + + } + } + + private final LeaderAndIsrRequestData data; + + LeaderAndIsrRequest(LeaderAndIsrRequestData data, short version) { + super(ApiKeys.LEADER_AND_ISR, version); + this.data = data; + // Do this from the constructor to make it thread-safe (even though it's only needed when some methods are called) + normalize(); + } + + private void normalize() { + if (version() >= 2) { + for (LeaderAndIsrTopicState topicState : data.topicStates()) { + for (LeaderAndIsrPartitionState partitionState : topicState.partitionStates()) { + // Set the topic name so that we can always present the ungrouped view to callers + partitionState.setTopicName(topicState.topicName()); + } + } + } + } + + @Override + public LeaderAndIsrResponse getErrorResponse(int throttleTimeMs, Throwable e) { + LeaderAndIsrResponseData responseData = new LeaderAndIsrResponseData(); + Errors error = Errors.forException(e); + responseData.setErrorCode(error.code()); + + if (version() < 5) { + List partitions = new ArrayList<>(); + for (LeaderAndIsrPartitionState partition : partitionStates()) { + partitions.add(new LeaderAndIsrPartitionError() + .setTopicName(partition.topicName()) + .setPartitionIndex(partition.partitionIndex()) + .setErrorCode(error.code())); + } + responseData.setPartitionErrors(partitions); + } else { + for (LeaderAndIsrTopicState topicState : data.topicStates()) { + List partitions = new ArrayList<>( + topicState.partitionStates().size()); + for (LeaderAndIsrPartitionState partition : topicState.partitionStates()) { + partitions.add(new LeaderAndIsrPartitionError() + .setPartitionIndex(partition.partitionIndex()) + .setErrorCode(error.code())); + } + responseData.topics().add(new LeaderAndIsrTopicError() + .setTopicId(topicState.topicId()) + .setPartitionErrors(partitions)); + } + } + + return new LeaderAndIsrResponse(responseData, version()); + } + + @Override + public int controllerId() { + return data.controllerId(); + } + + @Override + public int controllerEpoch() { + return data.controllerEpoch(); + } + + @Override + public long brokerEpoch() { + return data.brokerEpoch(); + } + + public Iterable partitionStates() { + if (version() >= 2) + return () -> new FlattenedIterator<>(data.topicStates().iterator(), + topicState -> topicState.partitionStates().iterator()); + return data.ungroupedPartitionStates(); + } + + public Map topicIds() { + return data.topicStates().stream() + .collect(Collectors.toMap(LeaderAndIsrTopicState::topicName, LeaderAndIsrTopicState::topicId)); + } + + public List liveLeaders() { + return Collections.unmodifiableList(data.liveLeaders()); + } + + @Override + public LeaderAndIsrRequestData data() { + return data; + } + + public static LeaderAndIsrRequest parse(ByteBuffer buffer, short version) { + return new LeaderAndIsrRequest(new LeaderAndIsrRequestData(new ByteBufferAccessor(buffer), version), version); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/LeaderAndIsrResponse.java b/clients/src/main/java/org/apache/kafka/common/requests/LeaderAndIsrResponse.java new file mode 100644 index 0000000..c7c04e2 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/LeaderAndIsrResponse.java @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.Uuid; +import org.apache.kafka.common.message.LeaderAndIsrResponseData; +import org.apache.kafka.common.message.LeaderAndIsrResponseData.LeaderAndIsrTopicError; +import org.apache.kafka.common.message.LeaderAndIsrResponseData.LeaderAndIsrTopicErrorCollection; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.protocol.Errors; + +import java.nio.ByteBuffer; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +public class LeaderAndIsrResponse extends AbstractResponse { + + /** + * Possible error code: + * + * STALE_CONTROLLER_EPOCH (11) + * STALE_BROKER_EPOCH (77) + */ + private final LeaderAndIsrResponseData data; + private final short version; + + public LeaderAndIsrResponse(LeaderAndIsrResponseData data, short version) { + super(ApiKeys.LEADER_AND_ISR); + this.data = data; + this.version = version; + } + + public LeaderAndIsrTopicErrorCollection topics() { + return this.data.topics(); + } + + public Errors error() { + return Errors.forCode(data.errorCode()); + } + + @Override + public Map errorCounts() { + Errors error = error(); + if (error != Errors.NONE) { + // Minor optimization since the top-level error applies to all partitions + if (version < 5) + return Collections.singletonMap(error, data.partitionErrors().size() + 1); + return Collections.singletonMap(error, + data.topics().stream().mapToInt(t -> t.partitionErrors().size()).sum() + 1); + } + Map errors; + if (version < 5) + errors = errorCounts(data.partitionErrors().stream().map(l -> Errors.forCode(l.errorCode()))); + else + errors = errorCounts(data.topics().stream().flatMap(t -> t.partitionErrors().stream()).map(l -> + Errors.forCode(l.errorCode()))); + updateErrorCounts(errors, Errors.NONE); + return errors; + } + + public Map partitionErrors(Map topicNames) { + Map errors = new HashMap<>(); + if (version < 5) { + data.partitionErrors().forEach(partition -> + errors.put(new TopicPartition(partition.topicName(), partition.partitionIndex()), + Errors.forCode(partition.errorCode()))); + } else { + for (LeaderAndIsrTopicError topic : data.topics()) { + String topicName = topicNames.get(topic.topicId()); + if (topicName != null) { + topic.partitionErrors().forEach(partition -> + errors.put(new TopicPartition(topicName, partition.partitionIndex()), + Errors.forCode(partition.errorCode()))); + } + } + } + return errors; + } + + @Override + public int throttleTimeMs() { + return DEFAULT_THROTTLE_TIME; + } + + public static LeaderAndIsrResponse parse(ByteBuffer buffer, short version) { + return new LeaderAndIsrResponse(new LeaderAndIsrResponseData(new ByteBufferAccessor(buffer), version), version); + } + + @Override + public LeaderAndIsrResponseData data() { + return data; + } + + @Override + public String 
toString() { + return data.toString(); + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/LeaveGroupRequest.java b/clients/src/main/java/org/apache/kafka/common/requests/LeaveGroupRequest.java new file mode 100644 index 0000000..8ce9535 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/LeaveGroupRequest.java @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.errors.UnsupportedVersionException; +import org.apache.kafka.common.message.LeaveGroupRequestData; +import org.apache.kafka.common.message.LeaveGroupRequestData.MemberIdentity; +import org.apache.kafka.common.message.LeaveGroupResponseData; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.protocol.Errors; +import org.apache.kafka.common.protocol.MessageUtil; + +import java.nio.ByteBuffer; +import java.util.Collections; +import java.util.List; + +public class LeaveGroupRequest extends AbstractRequest { + + public static class Builder extends AbstractRequest.Builder { + private final String groupId; + private final List members; + + public Builder(String groupId, List members) { + this(groupId, members, ApiKeys.LEAVE_GROUP.oldestVersion(), ApiKeys.LEAVE_GROUP.latestVersion()); + } + + Builder(String groupId, List members, short oldestVersion, short latestVersion) { + super(ApiKeys.LEAVE_GROUP, oldestVersion, latestVersion); + this.groupId = groupId; + this.members = members; + if (members.isEmpty()) { + throw new IllegalArgumentException("leaving members should not be empty"); + } + } + + /** + * Based on the request version to choose fields. + */ + @Override + public LeaveGroupRequest build(short version) { + final LeaveGroupRequestData data; + // Starting from version 3, all the leave group request will be in batch. 
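+            // Before version 3 the request can carry only a single member at the top level, so the
+            // else-branch below copies that one member id into the request and rejects larger batches.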
+ if (version >= 3) { + data = new LeaveGroupRequestData() + .setGroupId(groupId) + .setMembers(members); + } else { + if (members.size() != 1) { + throw new UnsupportedVersionException("Version " + version + " leave group request only " + + "supports single member instance than " + members.size() + " members"); + } + + data = new LeaveGroupRequestData() + .setGroupId(groupId) + .setMemberId(members.get(0).memberId()); + } + return new LeaveGroupRequest(data, version); + } + + @Override + public String toString() { + return "(type=LeaveGroupRequest" + + ", groupId=" + groupId + + ", members=" + MessageUtil.deepToString(members.iterator()) + + ")"; + } + } + private final LeaveGroupRequestData data; + + private LeaveGroupRequest(LeaveGroupRequestData data, short version) { + super(ApiKeys.LEAVE_GROUP, version); + this.data = data; + } + + @Override + public LeaveGroupRequestData data() { + return data; + } + + public List members() { + // Before version 3, leave group request is still in single mode + return version() <= 2 ? Collections.singletonList( + new MemberIdentity() + .setMemberId(data.memberId())) : data.members(); + } + + @Override + public AbstractResponse getErrorResponse(int throttleTimeMs, Throwable e) { + LeaveGroupResponseData responseData = new LeaveGroupResponseData() + .setErrorCode(Errors.forException(e).code()); + + if (version() >= 1) { + responseData.setThrottleTimeMs(throttleTimeMs); + } + return new LeaveGroupResponse(responseData); + } + + public static LeaveGroupRequest parse(ByteBuffer buffer, short version) { + return new LeaveGroupRequest(new LeaveGroupRequestData(new ByteBufferAccessor(buffer), version), version); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/LeaveGroupResponse.java b/clients/src/main/java/org/apache/kafka/common/requests/LeaveGroupResponse.java new file mode 100644 index 0000000..9a59139 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/LeaveGroupResponse.java @@ -0,0 +1,153 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.message.LeaveGroupResponseData; +import org.apache.kafka.common.message.LeaveGroupResponseData.MemberResponse; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.protocol.Errors; + +import java.nio.ByteBuffer; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; + +/** + * Possible error codes. 
+ * + * Top level errors: + * - {@link Errors#COORDINATOR_LOAD_IN_PROGRESS} + * - {@link Errors#COORDINATOR_NOT_AVAILABLE} + * - {@link Errors#NOT_COORDINATOR} + * - {@link Errors#GROUP_AUTHORIZATION_FAILED} + * + * Member level errors: + * - {@link Errors#FENCED_INSTANCE_ID} + * - {@link Errors#UNKNOWN_MEMBER_ID} + * + * If the top level error code is set, normally this indicates that broker early stops the request + * handling due to some severe global error, so it is expected to see the member level errors to be empty. + * For older version response, we may populate member level error towards top level because older client + * couldn't parse member level. + */ +public class LeaveGroupResponse extends AbstractResponse { + + private final LeaveGroupResponseData data; + + public LeaveGroupResponse(LeaveGroupResponseData data) { + super(ApiKeys.LEAVE_GROUP); + this.data = data; + } + + public LeaveGroupResponse(List memberResponses, + Errors topLevelError, + final int throttleTimeMs, + final short version) { + super(ApiKeys.LEAVE_GROUP); + if (version <= 2) { + // Populate member level error. + final short errorCode = getError(topLevelError, memberResponses).code(); + + this.data = new LeaveGroupResponseData() + .setErrorCode(errorCode); + } else { + this.data = new LeaveGroupResponseData() + .setErrorCode(topLevelError.code()) + .setMembers(memberResponses); + } + + if (version >= 1) { + this.data.setThrottleTimeMs(throttleTimeMs); + } + } + + @Override + public int throttleTimeMs() { + return data.throttleTimeMs(); + } + + public List memberResponses() { + return data.members(); + } + + public Errors error() { + return getError(Errors.forCode(data.errorCode()), data.members()); + } + + public Errors topLevelError() { + return Errors.forCode(data.errorCode()); + } + + private static Errors getError(Errors topLevelError, List memberResponses) { + if (topLevelError != Errors.NONE) { + return topLevelError; + } else { + for (MemberResponse memberResponse : memberResponses) { + Errors memberError = Errors.forCode(memberResponse.errorCode()); + if (memberError != Errors.NONE) { + return memberError; + } + } + return Errors.NONE; + } + } + + @Override + public Map errorCounts() { + Map combinedErrorCounts = new HashMap<>(); + // Top level error. + updateErrorCounts(combinedErrorCounts, Errors.forCode(data.errorCode())); + + // Member level error. 
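+        // Each member-level entry (if any) contributes its own error code to the tally.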
+ data.members().forEach(memberResponse -> { + updateErrorCounts(combinedErrorCounts, Errors.forCode(memberResponse.errorCode())); + }); + return combinedErrorCounts; + } + + @Override + public LeaveGroupResponseData data() { + return data; + } + + public static LeaveGroupResponse parse(ByteBuffer buffer, short version) { + return new LeaveGroupResponse(new LeaveGroupResponseData(new ByteBufferAccessor(buffer), version)); + } + + @Override + public boolean shouldClientThrottle(short version) { + return version >= 2; + } + + @Override + public boolean equals(Object other) { + return other instanceof LeaveGroupResponse && + ((LeaveGroupResponse) other).data.equals(this.data); + } + + @Override + public int hashCode() { + return Objects.hashCode(data); + } + + @Override + public String toString() { + return data.toString(); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/ListGroupsRequest.java b/clients/src/main/java/org/apache/kafka/common/requests/ListGroupsRequest.java new file mode 100644 index 0000000..ab7ec61 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/ListGroupsRequest.java @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.errors.UnsupportedVersionException; +import org.apache.kafka.common.message.ListGroupsRequestData; +import org.apache.kafka.common.message.ListGroupsResponseData; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.protocol.Errors; + +import java.nio.ByteBuffer; +import java.util.Collections; + +/** + * Possible error codes: + * + * COORDINATOR_LOAD_IN_PROGRESS (14) + * COORDINATOR_NOT_AVAILABLE (15) + * AUTHORIZATION_FAILED (29) + */ +public class ListGroupsRequest extends AbstractRequest { + + public static class Builder extends AbstractRequest.Builder { + + private final ListGroupsRequestData data; + + public Builder(ListGroupsRequestData data) { + super(ApiKeys.LIST_GROUPS); + this.data = data; + } + + @Override + public ListGroupsRequest build(short version) { + if (!data.statesFilter().isEmpty() && version < 4) { + throw new UnsupportedVersionException("The broker only supports ListGroups " + + "v" + version + ", but we need v4 or newer to request groups by states."); + } + return new ListGroupsRequest(data, version); + } + + @Override + public String toString() { + return data.toString(); + } + } + + private final ListGroupsRequestData data; + + public ListGroupsRequest(ListGroupsRequestData data, short version) { + super(ApiKeys.LIST_GROUPS, version); + this.data = data; + } + + @Override + public ListGroupsResponse getErrorResponse(int throttleTimeMs, Throwable e) { + ListGroupsResponseData listGroupsResponseData = new ListGroupsResponseData(). + setGroups(Collections.emptyList()). + setErrorCode(Errors.forException(e).code()); + if (version() >= 1) { + listGroupsResponseData.setThrottleTimeMs(throttleTimeMs); + } + return new ListGroupsResponse(listGroupsResponseData); + } + + public static ListGroupsRequest parse(ByteBuffer buffer, short version) { + return new ListGroupsRequest(new ListGroupsRequestData(new ByteBufferAccessor(buffer), version), version); + } + + @Override + public ListGroupsRequestData data() { + return data; + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/ListGroupsResponse.java b/clients/src/main/java/org/apache/kafka/common/requests/ListGroupsResponse.java new file mode 100644 index 0000000..270c43c --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/ListGroupsResponse.java @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.message.ListGroupsResponseData; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.protocol.Errors; + +import java.nio.ByteBuffer; +import java.util.Map; + +public class ListGroupsResponse extends AbstractResponse { + + private final ListGroupsResponseData data; + + public ListGroupsResponse(ListGroupsResponseData data) { + super(ApiKeys.LIST_GROUPS); + this.data = data; + } + + @Override + public ListGroupsResponseData data() { + return data; + } + + @Override + public int throttleTimeMs() { + return data.throttleTimeMs(); + } + + @Override + public Map errorCounts() { + return errorCounts(Errors.forCode(data.errorCode())); + } + + public static ListGroupsResponse parse(ByteBuffer buffer, short version) { + return new ListGroupsResponse(new ListGroupsResponseData(new ByteBufferAccessor(buffer), version)); + } + + @Override + public boolean shouldClientThrottle(short version) { + return version >= 2; + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/ListOffsetsRequest.java b/clients/src/main/java/org/apache/kafka/common/requests/ListOffsetsRequest.java new file mode 100644 index 0000000..6b7734a --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/ListOffsetsRequest.java @@ -0,0 +1,186 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.requests; + +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import org.apache.kafka.common.IsolationLevel; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.message.ListOffsetsRequestData; +import org.apache.kafka.common.message.ListOffsetsRequestData.ListOffsetsPartition; +import org.apache.kafka.common.message.ListOffsetsRequestData.ListOffsetsTopic; +import org.apache.kafka.common.message.ListOffsetsResponseData; +import org.apache.kafka.common.message.ListOffsetsResponseData.ListOffsetsPartitionResponse; +import org.apache.kafka.common.message.ListOffsetsResponseData.ListOffsetsTopicResponse; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.protocol.Errors; + +public class ListOffsetsRequest extends AbstractRequest { + public static final long EARLIEST_TIMESTAMP = -2L; + public static final long LATEST_TIMESTAMP = -1L; + public static final long MAX_TIMESTAMP = -3L; + + public static final int CONSUMER_REPLICA_ID = -1; + public static final int DEBUGGING_REPLICA_ID = -2; + + private final ListOffsetsRequestData data; + private final Set duplicatePartitions; + + public static class Builder extends AbstractRequest.Builder { + private final ListOffsetsRequestData data; + + public static Builder forReplica(short allowedVersion, int replicaId) { + return new Builder((short) 0, allowedVersion, replicaId, IsolationLevel.READ_UNCOMMITTED); + } + + public static Builder forConsumer(boolean requireTimestamp, IsolationLevel isolationLevel, boolean requireMaxTimestamp) { + short minVersion = 0; + if (requireMaxTimestamp) + minVersion = 7; + else if (isolationLevel == IsolationLevel.READ_COMMITTED) + minVersion = 2; + else if (requireTimestamp) + minVersion = 1; + return new Builder(minVersion, ApiKeys.LIST_OFFSETS.latestVersion(), CONSUMER_REPLICA_ID, isolationLevel); + } + + private Builder(short oldestAllowedVersion, + short latestAllowedVersion, + int replicaId, + IsolationLevel isolationLevel) { + super(ApiKeys.LIST_OFFSETS, oldestAllowedVersion, latestAllowedVersion); + data = new ListOffsetsRequestData() + .setIsolationLevel(isolationLevel.id()) + .setReplicaId(replicaId); + } + + public Builder setTargetTimes(List topics) { + data.setTopics(topics); + return this; + } + + @Override + public ListOffsetsRequest build(short version) { + return new ListOffsetsRequest(data, version); + } + + @Override + public String toString() { + return data.toString(); + } + } + + /** + * Private constructor with a specified version. 
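+     * Topic-partitions that appear more than once in the request are recorded here
+     * and exposed via duplicatePartitions().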
+ */ + private ListOffsetsRequest(ListOffsetsRequestData data, short version) { + super(ApiKeys.LIST_OFFSETS, version); + this.data = data; + duplicatePartitions = new HashSet<>(); + Set partitions = new HashSet<>(); + for (ListOffsetsTopic topic : data.topics()) { + for (ListOffsetsPartition partition : topic.partitions()) { + TopicPartition tp = new TopicPartition(topic.name(), partition.partitionIndex()); + if (!partitions.add(tp)) { + duplicatePartitions.add(tp); + } + } + } + } + + @Override + public AbstractResponse getErrorResponse(int throttleTimeMs, Throwable e) { + short versionId = version(); + short errorCode = Errors.forException(e).code(); + + List responses = new ArrayList<>(); + for (ListOffsetsTopic topic : data.topics()) { + ListOffsetsTopicResponse topicResponse = new ListOffsetsTopicResponse().setName(topic.name()); + List partitions = new ArrayList<>(); + for (ListOffsetsPartition partition : topic.partitions()) { + ListOffsetsPartitionResponse partitionResponse = new ListOffsetsPartitionResponse() + .setErrorCode(errorCode) + .setPartitionIndex(partition.partitionIndex()); + if (versionId == 0) { + partitionResponse.setOldStyleOffsets(Collections.emptyList()); + } else { + partitionResponse.setOffset(ListOffsetsResponse.UNKNOWN_OFFSET) + .setTimestamp(ListOffsetsResponse.UNKNOWN_TIMESTAMP); + } + partitions.add(partitionResponse); + } + topicResponse.setPartitions(partitions); + responses.add(topicResponse); + } + ListOffsetsResponseData responseData = new ListOffsetsResponseData() + .setThrottleTimeMs(throttleTimeMs) + .setTopics(responses); + return new ListOffsetsResponse(responseData); + } + + @Override + public ListOffsetsRequestData data() { + return data; + } + + public int replicaId() { + return data.replicaId(); + } + + public IsolationLevel isolationLevel() { + return IsolationLevel.forId(data.isolationLevel()); + } + + public List topics() { + return data.topics(); + } + + public Set duplicatePartitions() { + return duplicatePartitions; + } + + public static ListOffsetsRequest parse(ByteBuffer buffer, short version) { + return new ListOffsetsRequest(new ListOffsetsRequestData(new ByteBufferAccessor(buffer), version), version); + } + + public static List toListOffsetsTopics(Map timestampsToSearch) { + Map topics = new HashMap<>(); + for (Map.Entry entry : timestampsToSearch.entrySet()) { + TopicPartition tp = entry.getKey(); + ListOffsetsTopic topic = topics.computeIfAbsent(tp.topic(), k -> new ListOffsetsTopic().setName(tp.topic())); + topic.partitions().add(entry.getValue()); + } + return new ArrayList<>(topics.values()); + } + + public static ListOffsetsTopic singletonRequestData(String topic, int partitionIndex, long timestamp, int maxNumOffsets) { + return new ListOffsetsTopic() + .setName(topic) + .setPartitions(Collections.singletonList(new ListOffsetsPartition() + .setPartitionIndex(partitionIndex) + .setTimestamp(timestamp) + .setMaxNumOffsets(maxNumOffsets))); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/ListOffsetsResponse.java b/clients/src/main/java/org/apache/kafka/common/requests/ListOffsetsResponse.java new file mode 100644 index 0000000..8c4a51b --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/ListOffsetsResponse.java @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.requests; + +import java.nio.ByteBuffer; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.message.ListOffsetsResponseData; +import org.apache.kafka.common.message.ListOffsetsResponseData.ListOffsetsPartitionResponse; +import org.apache.kafka.common.message.ListOffsetsResponseData.ListOffsetsTopicResponse; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.protocol.Errors; +import org.apache.kafka.common.record.RecordBatch; + +/** + * Possible error codes: + * + * - {@link Errors#UNSUPPORTED_FOR_MESSAGE_FORMAT} If the message format does not support lookup by timestamp + * - {@link Errors#TOPIC_AUTHORIZATION_FAILED} If the user does not have DESCRIBE access to a requested topic + * - {@link Errors#REPLICA_NOT_AVAILABLE} If the request is received by a broker with version < 2.6 which is not a replica + * - {@link Errors#NOT_LEADER_OR_FOLLOWER} If the broker is not a leader or follower and either the provided leader epoch + * matches the known leader epoch on the broker or is empty + * - {@link Errors#FENCED_LEADER_EPOCH} If the epoch is lower than the broker's epoch + * - {@link Errors#UNKNOWN_LEADER_EPOCH} If the epoch is larger than the broker's epoch + * - {@link Errors#UNKNOWN_TOPIC_OR_PARTITION} If the broker does not have metadata for a topic or partition + * - {@link Errors#KAFKA_STORAGE_ERROR} If the log directory for one of the requested partitions is offline + * - {@link Errors#UNKNOWN_SERVER_ERROR} For any unexpected errors + * - {@link Errors#LEADER_NOT_AVAILABLE} The leader's HW has not caught up after recent election (v4 protocol) + * - {@link Errors#OFFSET_NOT_AVAILABLE} The leader's HW has not caught up after recent election (v5+ protocol) + */ +public class ListOffsetsResponse extends AbstractResponse { + public static final long UNKNOWN_TIMESTAMP = -1L; + public static final long UNKNOWN_OFFSET = -1L; + public static final int UNKNOWN_EPOCH = RecordBatch.NO_PARTITION_LEADER_EPOCH; + + private final ListOffsetsResponseData data; + + public ListOffsetsResponse(ListOffsetsResponseData data) { + super(ApiKeys.LIST_OFFSETS); + this.data = data; + } + + @Override + public int throttleTimeMs() { + return data.throttleTimeMs(); + } + + @Override + public ListOffsetsResponseData data() { + return data; + } + + public List topics() { + return data.topics(); + } + + @Override + public Map errorCounts() { + Map errorCounts = new HashMap<>(); + topics().forEach(topic -> + topic.partitions().forEach(partition -> + updateErrorCounts(errorCounts, Errors.forCode(partition.errorCode())) + ) + ); + return errorCounts; + } + + public static ListOffsetsResponse parse(ByteBuffer buffer, short version) { + return new ListOffsetsResponse(new 
ListOffsetsResponseData(new ByteBufferAccessor(buffer), version)); + } + + @Override + public String toString() { + return data.toString(); + } + + @Override + public boolean shouldClientThrottle(short version) { + return version >= 3; + } + + public static ListOffsetsTopicResponse singletonListOffsetsTopicResponse(TopicPartition tp, Errors error, long timestamp, long offset, int epoch) { + return new ListOffsetsTopicResponse() + .setName(tp.topic()) + .setPartitions(Collections.singletonList(new ListOffsetsPartitionResponse() + .setPartitionIndex(tp.partition()) + .setErrorCode(error.code()) + .setTimestamp(timestamp) + .setOffset(offset) + .setLeaderEpoch(epoch))); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/ListPartitionReassignmentsRequest.java b/clients/src/main/java/org/apache/kafka/common/requests/ListPartitionReassignmentsRequest.java new file mode 100644 index 0000000..03affd1 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/ListPartitionReassignmentsRequest.java @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.message.ListPartitionReassignmentsRequestData; +import org.apache.kafka.common.message.ListPartitionReassignmentsResponseData; +import org.apache.kafka.common.message.ListPartitionReassignmentsResponseData.OngoingPartitionReassignment; +import org.apache.kafka.common.message.ListPartitionReassignmentsResponseData.OngoingTopicReassignment; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; + +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.List; +import java.util.stream.Collectors; + +import static org.apache.kafka.common.message.ListPartitionReassignmentsRequestData.ListPartitionReassignmentsTopics; + +public class ListPartitionReassignmentsRequest extends AbstractRequest { + + public static class Builder extends AbstractRequest.Builder { + private final ListPartitionReassignmentsRequestData data; + + public Builder(ListPartitionReassignmentsRequestData data) { + super(ApiKeys.LIST_PARTITION_REASSIGNMENTS); + this.data = data; + } + + @Override + public ListPartitionReassignmentsRequest build(short version) { + return new ListPartitionReassignmentsRequest(data, version); + } + + @Override + public String toString() { + return data.toString(); + } + } + + private ListPartitionReassignmentsRequestData data; + + private ListPartitionReassignmentsRequest(ListPartitionReassignmentsRequestData data, short version) { + super(ApiKeys.LIST_PARTITION_REASSIGNMENTS, version); + this.data = data; + } + + public static ListPartitionReassignmentsRequest parse(ByteBuffer buffer, short version) { + return new ListPartitionReassignmentsRequest(new ListPartitionReassignmentsRequestData( + new ByteBufferAccessor(buffer), version), version); + } + + @Override + public ListPartitionReassignmentsRequestData data() { + return data; + } + + @Override + public AbstractResponse getErrorResponse(int throttleTimeMs, Throwable e) { + ApiError apiError = ApiError.fromThrowable(e); + + List ongoingTopicReassignments = new ArrayList<>(); + if (data.topics() != null) { + for (ListPartitionReassignmentsTopics topic : data.topics()) { + ongoingTopicReassignments.add( + new OngoingTopicReassignment() + .setName(topic.name()) + .setPartitions(topic.partitionIndexes().stream().map(partitionIndex -> + new OngoingPartitionReassignment().setPartitionIndex(partitionIndex)).collect(Collectors.toList())) + ); + } + } + ListPartitionReassignmentsResponseData responseData = new ListPartitionReassignmentsResponseData() + .setTopics(ongoingTopicReassignments) + .setErrorCode(apiError.error().code()) + .setErrorMessage(apiError.message()) + .setThrottleTimeMs(throttleTimeMs); + return new ListPartitionReassignmentsResponse(responseData); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/ListPartitionReassignmentsResponse.java b/clients/src/main/java/org/apache/kafka/common/requests/ListPartitionReassignmentsResponse.java new file mode 100644 index 0000000..4a890e8 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/ListPartitionReassignmentsResponse.java @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.message.ListPartitionReassignmentsResponseData; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.protocol.Errors; + +import java.nio.ByteBuffer; +import java.util.Map; + +public class ListPartitionReassignmentsResponse extends AbstractResponse { + + private final ListPartitionReassignmentsResponseData data; + + public ListPartitionReassignmentsResponse(ListPartitionReassignmentsResponseData responseData) { + super(ApiKeys.LIST_PARTITION_REASSIGNMENTS); + this.data = responseData; + } + + public static ListPartitionReassignmentsResponse parse(ByteBuffer buffer, short version) { + return new ListPartitionReassignmentsResponse(new ListPartitionReassignmentsResponseData( + new ByteBufferAccessor(buffer), version)); + } + + @Override + public ListPartitionReassignmentsResponseData data() { + return data; + } + + @Override + public boolean shouldClientThrottle(short version) { + return true; + } + + @Override + public int throttleTimeMs() { + return data.throttleTimeMs(); + } + + @Override + public Map errorCounts() { + return errorCounts(Errors.forCode(data.errorCode())); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/ListTransactionsRequest.java b/clients/src/main/java/org/apache/kafka/common/requests/ListTransactionsRequest.java new file mode 100644 index 0000000..0651f1f --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/ListTransactionsRequest.java @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.message.ListTransactionsRequestData; +import org.apache.kafka.common.message.ListTransactionsResponseData; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.protocol.Errors; + +import java.nio.ByteBuffer; + +public class ListTransactionsRequest extends AbstractRequest { + public static class Builder extends AbstractRequest.Builder { + public final ListTransactionsRequestData data; + + public Builder(ListTransactionsRequestData data) { + super(ApiKeys.LIST_TRANSACTIONS); + this.data = data; + } + + @Override + public ListTransactionsRequest build(short version) { + return new ListTransactionsRequest(data, version); + } + + @Override + public String toString() { + return data.toString(); + } + } + + private final ListTransactionsRequestData data; + + private ListTransactionsRequest(ListTransactionsRequestData data, short version) { + super(ApiKeys.LIST_TRANSACTIONS, version); + this.data = data; + } + + public ListTransactionsRequestData data() { + return data; + } + + @Override + public ListTransactionsResponse getErrorResponse(int throttleTimeMs, Throwable e) { + Errors error = Errors.forException(e); + ListTransactionsResponseData response = new ListTransactionsResponseData() + .setErrorCode(error.code()) + .setThrottleTimeMs(throttleTimeMs); + return new ListTransactionsResponse(response); + } + + public static ListTransactionsRequest parse(ByteBuffer buffer, short version) { + return new ListTransactionsRequest(new ListTransactionsRequestData( + new ByteBufferAccessor(buffer), version), version); + } + + @Override + public String toString(boolean verbose) { + return data.toString(); + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/ListTransactionsResponse.java b/clients/src/main/java/org/apache/kafka/common/requests/ListTransactionsResponse.java new file mode 100644 index 0000000..13ed184 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/ListTransactionsResponse.java @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.message.ListTransactionsResponseData; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.protocol.Errors; + +import java.nio.ByteBuffer; +import java.util.HashMap; +import java.util.Map; + +public class ListTransactionsResponse extends AbstractResponse { + private final ListTransactionsResponseData data; + + public ListTransactionsResponse(ListTransactionsResponseData data) { + super(ApiKeys.LIST_TRANSACTIONS); + this.data = data; + } + + public ListTransactionsResponseData data() { + return data; + } + + @Override + public Map errorCounts() { + Map errorCounts = new HashMap<>(); + updateErrorCounts(errorCounts, Errors.forCode(data.errorCode())); + return errorCounts; + } + + public static ListTransactionsResponse parse(ByteBuffer buffer, short version) { + return new ListTransactionsResponse(new ListTransactionsResponseData( + new ByteBufferAccessor(buffer), version)); + } + + @Override + public String toString() { + return data.toString(); + } + + @Override + public int throttleTimeMs() { + return data.throttleTimeMs(); + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/MetadataRequest.java b/clients/src/main/java/org/apache/kafka/common/requests/MetadataRequest.java new file mode 100644 index 0000000..aab5fc6 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/MetadataRequest.java @@ -0,0 +1,209 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.Uuid; +import org.apache.kafka.common.errors.UnsupportedVersionException; +import org.apache.kafka.common.message.MetadataRequestData; +import org.apache.kafka.common.message.MetadataRequestData.MetadataRequestTopic; +import org.apache.kafka.common.message.MetadataResponseData; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.protocol.Errors; + +import java.nio.ByteBuffer; +import java.util.Collection; +import java.util.Collections; +import java.util.List; +import java.util.stream.Collectors; + +public class MetadataRequest extends AbstractRequest { + + public static class Builder extends AbstractRequest.Builder { + private static final MetadataRequestData ALL_TOPICS_REQUEST_DATA = new MetadataRequestData(). 
+ setTopics(null).setAllowAutoTopicCreation(true); + + private final MetadataRequestData data; + + public Builder(MetadataRequestData data) { + super(ApiKeys.METADATA); + this.data = data; + } + + public Builder(List topics, boolean allowAutoTopicCreation, short allowedVersion) { + this(topics, allowAutoTopicCreation, allowedVersion, allowedVersion); + } + + public Builder(List topics, boolean allowAutoTopicCreation, short minVersion, short maxVersion) { + super(ApiKeys.METADATA, minVersion, maxVersion); + MetadataRequestData data = new MetadataRequestData(); + if (topics == null) + data.setTopics(null); + else { + topics.forEach(topic -> data.topics().add(new MetadataRequestTopic().setName(topic))); + } + + data.setAllowAutoTopicCreation(allowAutoTopicCreation); + this.data = data; + } + + public Builder(List topics, boolean allowAutoTopicCreation) { + this(topics, allowAutoTopicCreation, ApiKeys.METADATA.oldestVersion(), ApiKeys.METADATA.latestVersion()); + } + + public Builder(List topicIds) { + super(ApiKeys.METADATA, ApiKeys.METADATA.oldestVersion(), ApiKeys.METADATA.latestVersion()); + MetadataRequestData data = new MetadataRequestData(); + if (topicIds == null) + data.setTopics(null); + else { + topicIds.forEach(topicId -> data.topics().add(new MetadataRequestTopic().setTopicId(topicId))); + } + + // It's impossible to create topic with topicId + data.setAllowAutoTopicCreation(false); + this.data = data; + } + + public static Builder allTopics() { + // This never causes auto-creation, but we set the boolean to true because that is the default value when + // deserializing V2 and older. This way, the value is consistent after serialization and deserialization. + return new Builder(ALL_TOPICS_REQUEST_DATA); + } + + public boolean emptyTopicList() { + return data.topics().isEmpty(); + } + + public boolean isAllTopics() { + return data.topics() == null; + } + + public List topics() { + return data.topics() + .stream() + .map(MetadataRequestTopic::name) + .collect(Collectors.toList()); + } + + @Override + public MetadataRequest build(short version) { + if (version < 1) + throw new UnsupportedVersionException("MetadataRequest versions older than 1 are not supported."); + if (!data.allowAutoTopicCreation() && version < 4) + throw new UnsupportedVersionException("MetadataRequest versions older than 4 don't support the " + + "allowAutoTopicCreation field"); + if (data.topics() != null) { + data.topics().forEach(topic -> { + if (topic.name() == null && version < 12) + throw new UnsupportedVersionException("MetadataRequest version " + version + + " does not support null topic names."); + if (topic.topicId() != Uuid.ZERO_UUID && version < 12) + throw new UnsupportedVersionException("MetadataRequest version " + version + + " does not support non-zero topic IDs."); + }); + } + return new MetadataRequest(data, version); + } + + @Override + public String toString() { + return data.toString(); + } + } + + private final MetadataRequestData data; + + public MetadataRequest(MetadataRequestData data, short version) { + super(ApiKeys.METADATA, version); + this.data = data; + } + + @Override + public MetadataRequestData data() { + return data; + } + + @Override + public AbstractResponse getErrorResponse(int throttleTimeMs, Throwable e) { + Errors error = Errors.forException(e); + MetadataResponseData responseData = new MetadataResponseData(); + if (data.topics() != null) { + for (MetadataRequestTopic topic : data.topics()) { + // the response does not allow null, so convert to empty string if necessary + 
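+                // (A null name can occur on v12+ requests that identify topics purely by topic ID.)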
String topicName = topic.name() == null ? "" : topic.name(); + responseData.topics().add(new MetadataResponseData.MetadataResponseTopic() + .setName(topicName) + .setTopicId(topic.topicId()) + .setErrorCode(error.code()) + .setIsInternal(false) + .setPartitions(Collections.emptyList())); + } + } + + responseData.setThrottleTimeMs(throttleTimeMs); + return new MetadataResponse(responseData, true); + } + + public boolean isAllTopics() { + return (data.topics() == null) || + (data.topics().isEmpty() && version() == 0); // In version 0, an empty topic list indicates + // "request metadata for all topics." + } + + public List topics() { + if (isAllTopics()) // In version 0, we return null for empty topic list + return null; + else + return data.topics() + .stream() + .map(MetadataRequestTopic::name) + .collect(Collectors.toList()); + } + + public List topicIds() { + if (isAllTopics()) + return Collections.emptyList(); + else if (version() < 10) + return Collections.emptyList(); + else + return data.topics() + .stream() + .map(MetadataRequestTopic::topicId) + .collect(Collectors.toList()); + } + + public boolean allowAutoTopicCreation() { + return data.allowAutoTopicCreation(); + } + + public static MetadataRequest parse(ByteBuffer buffer, short version) { + return new MetadataRequest(new MetadataRequestData(new ByteBufferAccessor(buffer), version), version); + } + + public static List convertToMetadataRequestTopic(final Collection topics) { + return topics.stream().map(topic -> new MetadataRequestTopic() + .setName(topic)) + .collect(Collectors.toList()); + } + + public static List convertTopicIdsToMetadataRequestTopic(final Collection topicIds) { + return topicIds.stream().map(topicId -> new MetadataRequestTopic() + .setTopicId(topicId)) + .collect(Collectors.toList()); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/MetadataResponse.java b/clients/src/main/java/org/apache/kafka/common/requests/MetadataResponse.java new file mode 100644 index 0000000..d539fa8 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/MetadataResponse.java @@ -0,0 +1,508 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.Cluster; +import org.apache.kafka.common.Node; +import org.apache.kafka.common.PartitionInfo; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.Uuid; +import org.apache.kafka.common.message.MetadataResponseData; +import org.apache.kafka.common.message.MetadataResponseData.MetadataResponseBroker; +import org.apache.kafka.common.message.MetadataResponseData.MetadataResponsePartition; +import org.apache.kafka.common.message.MetadataResponseData.MetadataResponseTopic; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.protocol.Errors; +import org.apache.kafka.common.utils.Utils; + +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.Set; +import java.util.function.Function; +import java.util.stream.Collectors; + +/** + * Possible topic-level error codes: + * UnknownTopic (3) + * LeaderNotAvailable (5) + * InvalidTopic (17) + * TopicAuthorizationFailed (29) + + * Possible partition-level error codes: + * LeaderNotAvailable (5) + * ReplicaNotAvailable (9) + */ +public class MetadataResponse extends AbstractResponse { + public static final int NO_CONTROLLER_ID = -1; + public static final int NO_LEADER_ID = -1; + public static final int AUTHORIZED_OPERATIONS_OMITTED = Integer.MIN_VALUE; + + private final MetadataResponseData data; + private volatile Holder holder; + private final boolean hasReliableLeaderEpochs; + + public MetadataResponse(MetadataResponseData data, short version) { + this(data, hasReliableLeaderEpochs(version)); + } + + MetadataResponse(MetadataResponseData data, boolean hasReliableLeaderEpochs) { + super(ApiKeys.METADATA); + this.data = data; + this.hasReliableLeaderEpochs = hasReliableLeaderEpochs; + } + + @Override + public MetadataResponseData data() { + return data; + } + + @Override + public int throttleTimeMs() { + return data.throttleTimeMs(); + } + + /** + * Get a map of the topics which had metadata errors + * @return the map + */ + public Map errors() { + Map errors = new HashMap<>(); + for (MetadataResponseTopic metadata : data.topics()) { + if (metadata.name() == null) { + throw new IllegalStateException("Use errorsByTopicId() when managing topic using topic id"); + } + if (metadata.errorCode() != Errors.NONE.code()) + errors.put(metadata.name(), Errors.forCode(metadata.errorCode())); + } + return errors; + } + + /** + * Get a map of the topicIds which had metadata errors + * @return the map + */ + public Map errorsByTopicId() { + Map errors = new HashMap<>(); + for (MetadataResponseTopic metadata : data.topics()) { + if (metadata.topicId() == Uuid.ZERO_UUID) { + throw new IllegalStateException("Use errors() when managing topic using topic name"); + } + if (metadata.errorCode() != Errors.NONE.code()) + errors.put(metadata.topicId(), Errors.forCode(metadata.errorCode())); + } + return errors; + } + + @Override + public Map errorCounts() { + Map errorCounts = new HashMap<>(); + data.topics().forEach(metadata -> { + metadata.partitions().forEach(p -> updateErrorCounts(errorCounts, Errors.forCode(p.errorCode()))); + updateErrorCounts(errorCounts, Errors.forCode(metadata.errorCode())); + }); + return errorCounts; + } + + /** + * Returns 
the set of topics with the specified error + */ + public Set topicsByError(Errors error) { + Set errorTopics = new HashSet<>(); + for (MetadataResponseTopic metadata : data.topics()) { + if (metadata.errorCode() == error.code()) + errorTopics.add(metadata.name()); + } + return errorTopics; + } + + /** + * Get a snapshot of the cluster metadata from this response + * @return the cluster snapshot + */ + public Cluster buildCluster() { + Set internalTopics = new HashSet<>(); + List partitions = new ArrayList<>(); + Map topicIds = new HashMap<>(); + + for (TopicMetadata metadata : topicMetadata()) { + if (metadata.error == Errors.NONE) { + if (metadata.isInternal) + internalTopics.add(metadata.topic); + if (metadata.topicId() != null && metadata.topicId() != Uuid.ZERO_UUID) { + topicIds.put(metadata.topic, metadata.topicId()); + } + for (PartitionMetadata partitionMetadata : metadata.partitionMetadata) { + partitions.add(toPartitionInfo(partitionMetadata, holder().brokers)); + } + } + } + return new Cluster(data.clusterId(), brokers(), partitions, topicsByError(Errors.TOPIC_AUTHORIZATION_FAILED), + topicsByError(Errors.INVALID_TOPIC_EXCEPTION), internalTopics, controller(), topicIds); + } + + public static PartitionInfo toPartitionInfo(PartitionMetadata metadata, Map nodesById) { + return new PartitionInfo(metadata.topic(), + metadata.partition(), + metadata.leaderId.map(nodesById::get).orElse(null), + convertToNodeArray(metadata.replicaIds, nodesById), + convertToNodeArray(metadata.inSyncReplicaIds, nodesById), + convertToNodeArray(metadata.offlineReplicaIds, nodesById)); + } + + private static Node[] convertToNodeArray(List replicaIds, Map nodesById) { + return replicaIds.stream().map(replicaId -> { + Node node = nodesById.get(replicaId); + if (node == null) + return new Node(replicaId, "", -1); + return node; + }).toArray(Node[]::new); + } + + /** + * Returns a 32-bit bitfield to represent authorized operations for this topic. + */ + public Optional topicAuthorizedOperations(String topicName) { + MetadataResponseTopic topic = data.topics().find(topicName); + if (topic == null) + return Optional.empty(); + else + return Optional.of(topic.topicAuthorizedOperations()); + } + + /** + * Returns a 32-bit bitfield to represent authorized operations for this cluster. + */ + public int clusterAuthorizedOperations() { + return data.clusterAuthorizedOperations(); + } + + private Holder holder() { + if (holder == null) { + synchronized (data) { + if (holder == null) + holder = new Holder(data); + } + } + return holder; + } + + /** + * Get all brokers returned in metadata response + * @return the brokers + */ + public Collection brokers() { + return holder().brokers.values(); + } + + public Map brokersById() { + return holder().brokers; + } + + /** + * Get all topic metadata returned in the metadata response + * @return the topicMetadata + */ + public Collection topicMetadata() { + return holder().topicMetadata; + } + + /** + * The controller node returned in metadata response + * @return the controller node or null if it doesn't exist + */ + public Node controller() { + return holder().controller; + } + + /** + * The cluster identifier returned in the metadata response. + * @return cluster identifier if it is present in the response, null otherwise. + */ + public String clusterId() { + return this.data.clusterId(); + } + + /** + * Check whether the leader epochs returned from the response can be relied on + * for epoch validation in Fetch, ListOffsets, and OffsetsForLeaderEpoch requests. 
+ * If not, then the client will not retain the leader epochs and hence will not + * forward them in requests. + * + * @return true if the epoch can be used for validation + */ + public boolean hasReliableLeaderEpochs() { + return hasReliableLeaderEpochs; + } + + // Prior to Kafka version 2.4 (which coincides with Metadata version 9), the broker + // does not propagate leader epoch information accurately while a reassignment is in + // progress. Relying on a stale epoch can lead to FENCED_LEADER_EPOCH errors which + // can prevent consumption throughout the course of a reassignment. It is safer in + // this case to revert to the behavior in previous protocol versions which checks + // leader status only. + private static boolean hasReliableLeaderEpochs(short version) { + return version >= 9; + } + + public static MetadataResponse parse(ByteBuffer buffer, short version) { + return new MetadataResponse(new MetadataResponseData(new ByteBufferAccessor(buffer), version), + hasReliableLeaderEpochs(version)); + } + + public static class TopicMetadata { + private final Errors error; + private final String topic; + private final Uuid topicId; + private final boolean isInternal; + private final List partitionMetadata; + private int authorizedOperations; + + public TopicMetadata(Errors error, + String topic, + Uuid topicId, + boolean isInternal, + List partitionMetadata, + int authorizedOperations) { + this.error = error; + this.topic = topic; + this.topicId = topicId; + this.isInternal = isInternal; + this.partitionMetadata = partitionMetadata; + this.authorizedOperations = authorizedOperations; + } + + public TopicMetadata(Errors error, + String topic, + boolean isInternal, + List partitionMetadata) { + this(error, topic, Uuid.ZERO_UUID, isInternal, partitionMetadata, AUTHORIZED_OPERATIONS_OMITTED); + } + + public Errors error() { + return error; + } + + public String topic() { + return topic; + } + + public Uuid topicId() { + return topicId; + } + + public boolean isInternal() { + return isInternal; + } + + public List partitionMetadata() { + return partitionMetadata; + } + + public void authorizedOperations(int authorizedOperations) { + this.authorizedOperations = authorizedOperations; + } + + public int authorizedOperations() { + return authorizedOperations; + } + + @Override + public boolean equals(final Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + final TopicMetadata that = (TopicMetadata) o; + return isInternal == that.isInternal && + error == that.error && + Objects.equals(topic, that.topic) && + Objects.equals(topicId, that.topicId) && + Objects.equals(partitionMetadata, that.partitionMetadata) && + Objects.equals(authorizedOperations, that.authorizedOperations); + } + + @Override + public int hashCode() { + return Objects.hash(error, topic, isInternal, partitionMetadata, authorizedOperations); + } + + @Override + public String toString() { + return "TopicMetadata{" + + "error=" + error + + ", topic='" + topic + '\'' + + ", topicId='" + topicId + '\'' + + ", isInternal=" + isInternal + + ", partitionMetadata=" + partitionMetadata + + ", authorizedOperations=" + authorizedOperations + + '}'; + } + } + + // This is used to describe per-partition state in the MetadataResponse + public static class PartitionMetadata { + public final TopicPartition topicPartition; + public final Errors error; + public final Optional leaderId; + public final Optional leaderEpoch; + public final List replicaIds; + public final List inSyncReplicaIds; + 
public final List offlineReplicaIds; + + public PartitionMetadata(Errors error, + TopicPartition topicPartition, + Optional leaderId, + Optional leaderEpoch, + List replicaIds, + List inSyncReplicaIds, + List offlineReplicaIds) { + this.error = error; + this.topicPartition = topicPartition; + this.leaderId = leaderId; + this.leaderEpoch = leaderEpoch; + this.replicaIds = replicaIds; + this.inSyncReplicaIds = inSyncReplicaIds; + this.offlineReplicaIds = offlineReplicaIds; + } + + public int partition() { + return topicPartition.partition(); + } + + public String topic() { + return topicPartition.topic(); + } + + public PartitionMetadata withoutLeaderEpoch() { + return new PartitionMetadata(error, + topicPartition, + leaderId, + Optional.empty(), + replicaIds, + inSyncReplicaIds, + offlineReplicaIds); + } + + @Override + public String toString() { + return "PartitionMetadata(" + + "error=" + error + + ", partition=" + topicPartition + + ", leader=" + leaderId + + ", leaderEpoch=" + leaderEpoch + + ", replicas=" + Utils.join(replicaIds, ",") + + ", isr=" + Utils.join(inSyncReplicaIds, ",") + + ", offlineReplicas=" + Utils.join(offlineReplicaIds, ",") + ')'; + } + } + + private static class Holder { + private final Map brokers; + private final Node controller; + private final Collection topicMetadata; + + Holder(MetadataResponseData data) { + this.brokers = Collections.unmodifiableMap(createBrokers(data)); + this.topicMetadata = createTopicMetadata(data); + this.controller = brokers.get(data.controllerId()); + } + + private Map createBrokers(MetadataResponseData data) { + return data.brokers().valuesList().stream().map(b -> new Node(b.nodeId(), b.host(), b.port(), b.rack())) + .collect(Collectors.toMap(Node::id, Function.identity())); + } + + private Collection createTopicMetadata(MetadataResponseData data) { + List topicMetadataList = new ArrayList<>(); + for (MetadataResponseTopic topicMetadata : data.topics()) { + Errors topicError = Errors.forCode(topicMetadata.errorCode()); + String topic = topicMetadata.name(); + Uuid topicId = topicMetadata.topicId(); + boolean isInternal = topicMetadata.isInternal(); + List partitionMetadataList = new ArrayList<>(); + + for (MetadataResponsePartition partitionMetadata : topicMetadata.partitions()) { + Errors partitionError = Errors.forCode(partitionMetadata.errorCode()); + int partitionIndex = partitionMetadata.partitionIndex(); + + int leaderId = partitionMetadata.leaderId(); + Optional leaderIdOpt = leaderId < 0 ? 
Optional.empty() : Optional.of(leaderId); + + Optional leaderEpoch = RequestUtils.getLeaderEpoch(partitionMetadata.leaderEpoch()); + TopicPartition topicPartition = new TopicPartition(topic, partitionIndex); + partitionMetadataList.add(new PartitionMetadata(partitionError, topicPartition, leaderIdOpt, + leaderEpoch, partitionMetadata.replicaNodes(), partitionMetadata.isrNodes(), + partitionMetadata.offlineReplicas())); + } + + topicMetadataList.add(new TopicMetadata(topicError, topic, topicId, isInternal, partitionMetadataList, + topicMetadata.topicAuthorizedOperations())); + } + return topicMetadataList; + } + + } + + public static MetadataResponse prepareResponse(short version, + int throttleTimeMs, + Collection brokers, + String clusterId, + int controllerId, + List topics, + int clusterAuthorizedOperations) { + return prepareResponse(hasReliableLeaderEpochs(version), throttleTimeMs, brokers, clusterId, controllerId, + topics, clusterAuthorizedOperations); + } + + // Visible for testing + public static MetadataResponse prepareResponse(boolean hasReliableEpoch, + int throttleTimeMs, + Collection brokers, + String clusterId, + int controllerId, + List topics, + int clusterAuthorizedOperations) { + MetadataResponseData responseData = new MetadataResponseData(); + responseData.setThrottleTimeMs(throttleTimeMs); + brokers.forEach(broker -> + responseData.brokers().add(new MetadataResponseBroker() + .setNodeId(broker.id()) + .setHost(broker.host()) + .setPort(broker.port()) + .setRack(broker.rack())) + ); + + responseData.setClusterId(clusterId); + responseData.setControllerId(controllerId); + responseData.setClusterAuthorizedOperations(clusterAuthorizedOperations); + + topics.forEach(topicMetadata -> responseData.topics().add(topicMetadata)); + return new MetadataResponse(responseData, hasReliableEpoch); + } + + @Override + public boolean shouldClientThrottle(short version) { + return version >= 6; + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/OffsetCommitRequest.java b/clients/src/main/java/org/apache/kafka/common/requests/OffsetCommitRequest.java new file mode 100644 index 0000000..9869da5 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/OffsetCommitRequest.java @@ -0,0 +1,124 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.errors.UnsupportedVersionException; +import org.apache.kafka.common.message.OffsetCommitRequestData; +import org.apache.kafka.common.message.OffsetCommitRequestData.OffsetCommitRequestTopic; +import org.apache.kafka.common.message.OffsetCommitResponseData; +import org.apache.kafka.common.message.OffsetCommitResponseData.OffsetCommitResponsePartition; +import org.apache.kafka.common.message.OffsetCommitResponseData.OffsetCommitResponseTopic; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.protocol.Errors; + +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +public class OffsetCommitRequest extends AbstractRequest { + // default values for the current version + public static final int DEFAULT_GENERATION_ID = -1; + public static final String DEFAULT_MEMBER_ID = ""; + public static final long DEFAULT_RETENTION_TIME = -1L; + + // default values for old versions, will be removed after these versions are no longer supported + public static final long DEFAULT_TIMESTAMP = -1L; // for V0, V1 + + private final OffsetCommitRequestData data; + + public static class Builder extends AbstractRequest.Builder { + + private final OffsetCommitRequestData data; + + public Builder(OffsetCommitRequestData data) { + super(ApiKeys.OFFSET_COMMIT); + this.data = data; + } + + @Override + public OffsetCommitRequest build(short version) { + if (data.groupInstanceId() != null && version < 7) { + throw new UnsupportedVersionException("The broker offset commit protocol version " + + version + " does not support usage of config group.instance.id."); + } + return new OffsetCommitRequest(data, version); + } + + @Override + public String toString() { + return data.toString(); + } + } + + public OffsetCommitRequest(OffsetCommitRequestData data, short version) { + super(ApiKeys.OFFSET_COMMIT, version); + this.data = data; + } + + @Override + public OffsetCommitRequestData data() { + return data; + } + + public Map offsets() { + Map offsets = new HashMap<>(); + for (OffsetCommitRequestTopic topic : data.topics()) { + for (OffsetCommitRequestData.OffsetCommitRequestPartition partition : topic.partitions()) { + offsets.put(new TopicPartition(topic.name(), partition.partitionIndex()), + partition.committedOffset()); + } + } + return offsets; + } + + public static List getErrorResponseTopics( + List requestTopics, + Errors e) { + List responseTopicData = new ArrayList<>(); + for (OffsetCommitRequestTopic entry : requestTopics) { + List responsePartitions = + new ArrayList<>(); + for (OffsetCommitRequestData.OffsetCommitRequestPartition requestPartition : entry.partitions()) { + responsePartitions.add(new OffsetCommitResponsePartition() + .setPartitionIndex(requestPartition.partitionIndex()) + .setErrorCode(e.code())); + } + responseTopicData.add(new OffsetCommitResponseTopic() + .setName(entry.name()) + .setPartitions(responsePartitions) + ); + } + return responseTopicData; + } + + @Override + public OffsetCommitResponse getErrorResponse(int throttleTimeMs, Throwable e) { + List + responseTopicData = getErrorResponseTopics(data.topics(), Errors.forException(e)); + return new OffsetCommitResponse(new OffsetCommitResponseData() + .setTopics(responseTopicData) + .setThrottleTimeMs(throttleTimeMs)); + } + + public static 
OffsetCommitRequest parse(ByteBuffer buffer, short version) { + return new OffsetCommitRequest(new OffsetCommitRequestData(new ByteBufferAccessor(buffer), version), version); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/OffsetCommitResponse.java b/clients/src/main/java/org/apache/kafka/common/requests/OffsetCommitResponse.java new file mode 100644 index 0000000..2ed0e31 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/OffsetCommitResponse.java @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.message.OffsetCommitResponseData; +import org.apache.kafka.common.message.OffsetCommitResponseData.OffsetCommitResponsePartition; +import org.apache.kafka.common.message.OffsetCommitResponseData.OffsetCommitResponseTopic; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.protocol.Errors; + +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.Map; + +/** + * Possible error codes: + * + * - {@link Errors#UNKNOWN_TOPIC_OR_PARTITION} + * - {@link Errors#REQUEST_TIMED_OUT} + * - {@link Errors#OFFSET_METADATA_TOO_LARGE} + * - {@link Errors#COORDINATOR_LOAD_IN_PROGRESS} + * - {@link Errors#COORDINATOR_NOT_AVAILABLE} + * - {@link Errors#NOT_COORDINATOR} + * - {@link Errors#ILLEGAL_GENERATION} + * - {@link Errors#UNKNOWN_MEMBER_ID} + * - {@link Errors#REBALANCE_IN_PROGRESS} + * - {@link Errors#INVALID_COMMIT_OFFSET_SIZE} + * - {@link Errors#TOPIC_AUTHORIZATION_FAILED} + * - {@link Errors#GROUP_AUTHORIZATION_FAILED} + */ +public class OffsetCommitResponse extends AbstractResponse { + + private final OffsetCommitResponseData data; + + public OffsetCommitResponse(OffsetCommitResponseData data) { + super(ApiKeys.OFFSET_COMMIT); + this.data = data; + } + + public OffsetCommitResponse(int requestThrottleMs, Map responseData) { + super(ApiKeys.OFFSET_COMMIT); + Map + responseTopicDataMap = new HashMap<>(); + + for (Map.Entry entry : responseData.entrySet()) { + TopicPartition topicPartition = entry.getKey(); + String topicName = topicPartition.topic(); + + OffsetCommitResponseTopic topic = responseTopicDataMap.getOrDefault( + topicName, new OffsetCommitResponseTopic().setName(topicName)); + + topic.partitions().add(new OffsetCommitResponsePartition() + .setErrorCode(entry.getValue().code()) + .setPartitionIndex(topicPartition.partition())); + responseTopicDataMap.put(topicName, topic); + } + + data = new OffsetCommitResponseData() + .setTopics(new ArrayList<>(responseTopicDataMap.values())) + .setThrottleTimeMs(requestThrottleMs); + } 
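A minimal sketch of how the constructor above folds a flat per-partition error map into the grouped response layout; the topic name and throttle value are illustrative:

    import java.util.HashMap;
    import java.util.Map;

    import org.apache.kafka.common.TopicPartition;
    import org.apache.kafka.common.protocol.Errors;
    import org.apache.kafka.common.requests.OffsetCommitResponse;

    public class OffsetCommitResponseSketch {
        public static void main(String[] args) {
            // Per-partition results as a broker-side handler might report them.
            Map<TopicPartition, Errors> results = new HashMap<>();
            results.put(new TopicPartition("orders", 0), Errors.NONE);
            results.put(new TopicPartition("orders", 1), Errors.UNKNOWN_TOPIC_OR_PARTITION);

            // Partitions of the same topic end up under a single OffsetCommitResponseTopic entry.
            OffsetCommitResponse response = new OffsetCommitResponse(0, results);
            System.out.println(response);               // prints the underlying OffsetCommitResponseData
            System.out.println(response.errorCounts()); // one NONE, one UNKNOWN_TOPIC_OR_PARTITION
        }
    }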
+ + public OffsetCommitResponse(Map responseData) { + this(DEFAULT_THROTTLE_TIME, responseData); + } + + @Override + public OffsetCommitResponseData data() { + return data; + } + + @Override + public Map errorCounts() { + return errorCounts(data.topics().stream().flatMap(topicResult -> + topicResult.partitions().stream().map(partitionResult -> + Errors.forCode(partitionResult.errorCode())))); + } + + public static OffsetCommitResponse parse(ByteBuffer buffer, short version) { + return new OffsetCommitResponse(new OffsetCommitResponseData(new ByteBufferAccessor(buffer), version)); + } + + @Override + public String toString() { + return data.toString(); + } + + @Override + public int throttleTimeMs() { + return data.throttleTimeMs(); + } + + @Override + public boolean shouldClientThrottle(short version) { + return version >= 4; + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/OffsetDeleteRequest.java b/clients/src/main/java/org/apache/kafka/common/requests/OffsetDeleteRequest.java new file mode 100644 index 0000000..28b763d --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/OffsetDeleteRequest.java @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
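For the request side, a sketch of building an OffsetCommitRequest and reading the committed offsets back via offsets(). The fluent setters on the generated OffsetCommitRequestData classes are assumed to mirror the response-side setters used above; the group, topic and offset values are illustrative:

    import java.util.Collections;

    import org.apache.kafka.common.message.OffsetCommitRequestData;
    import org.apache.kafka.common.message.OffsetCommitRequestData.OffsetCommitRequestPartition;
    import org.apache.kafka.common.message.OffsetCommitRequestData.OffsetCommitRequestTopic;
    import org.apache.kafka.common.protocol.ApiKeys;
    import org.apache.kafka.common.requests.OffsetCommitRequest;

    public class OffsetCommitRequestSketch {
        public static void main(String[] args) {
            OffsetCommitRequestData data = new OffsetCommitRequestData()
                    .setGroupId("my-group")                      // assumed generated setter
                    .setTopics(Collections.singletonList(
                            new OffsetCommitRequestTopic()
                                    .setName("orders")           // assumed generated setter
                                    .setPartitions(Collections.singletonList(
                                            new OffsetCommitRequestPartition()
                                                    .setPartitionIndex(0)
                                                    .setCommittedOffset(42L)))));

            OffsetCommitRequest request = new OffsetCommitRequest.Builder(data)
                    .build(ApiKeys.OFFSET_COMMIT.latestVersion());

            // offsets() flattens the nested topic/partition structure back into a map.
            System.out.println(request.offsets()); // {orders-0=42}
        }
    }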
+ */ +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.message.OffsetDeleteRequestData; +import org.apache.kafka.common.message.OffsetDeleteResponseData; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.protocol.Errors; + +import java.nio.ByteBuffer; + +public class OffsetDeleteRequest extends AbstractRequest { + + public static class Builder extends AbstractRequest.Builder { + + private final OffsetDeleteRequestData data; + + public Builder(OffsetDeleteRequestData data) { + super(ApiKeys.OFFSET_DELETE); + this.data = data; + } + + @Override + public OffsetDeleteRequest build(short version) { + return new OffsetDeleteRequest(data, version); + } + + @Override + public String toString() { + return data.toString(); + } + } + + private final OffsetDeleteRequestData data; + + public OffsetDeleteRequest(OffsetDeleteRequestData data, short version) { + super(ApiKeys.OFFSET_DELETE, version); + this.data = data; + } + + public AbstractResponse getErrorResponse(int throttleTimeMs, Errors error) { + return new OffsetDeleteResponse( + new OffsetDeleteResponseData() + .setThrottleTimeMs(throttleTimeMs) + .setErrorCode(error.code()) + ); + } + + @Override + public AbstractResponse getErrorResponse(int throttleTimeMs, Throwable e) { + return getErrorResponse(throttleTimeMs, Errors.forException(e)); + } + + public static OffsetDeleteRequest parse(ByteBuffer buffer, short version) { + return new OffsetDeleteRequest(new OffsetDeleteRequestData(new ByteBufferAccessor(buffer), version), version); + } + + @Override + public OffsetDeleteRequestData data() { + return data; + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/OffsetDeleteResponse.java b/clients/src/main/java/org/apache/kafka/common/requests/OffsetDeleteResponse.java new file mode 100644 index 0000000..79f6f4e --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/OffsetDeleteResponse.java @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
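A sketch of the error-response path defined above, using an empty OffsetDeleteRequestData purely for illustration (a real request would carry a group id and topics):

    import org.apache.kafka.common.message.OffsetDeleteRequestData;
    import org.apache.kafka.common.protocol.ApiKeys;
    import org.apache.kafka.common.protocol.Errors;
    import org.apache.kafka.common.requests.OffsetDeleteRequest;
    import org.apache.kafka.common.requests.OffsetDeleteResponse;

    public class OffsetDeleteErrorSketch {
        public static void main(String[] args) {
            OffsetDeleteRequest request = new OffsetDeleteRequest.Builder(new OffsetDeleteRequestData())
                    .build(ApiKeys.OFFSET_DELETE.latestVersion());

            // getErrorResponse maps a group/coordinator-level error onto the whole response.
            OffsetDeleteResponse response =
                    (OffsetDeleteResponse) request.getErrorResponse(0, Errors.NOT_COORDINATOR);
            System.out.println(Errors.forCode(response.data().errorCode())); // NOT_COORDINATOR
        }
    }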
+ */ +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.message.OffsetDeleteResponseData; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.protocol.Errors; + +import java.nio.ByteBuffer; +import java.util.HashMap; +import java.util.Map; + +/** + * Possible error codes: + * + * - Partition errors: + * - {@link Errors#GROUP_SUBSCRIBED_TO_TOPIC} + * - {@link Errors#TOPIC_AUTHORIZATION_FAILED} + * - {@link Errors#UNKNOWN_TOPIC_OR_PARTITION} + * + * - Group or coordinator errors: + * - {@link Errors#COORDINATOR_LOAD_IN_PROGRESS} + * - {@link Errors#COORDINATOR_NOT_AVAILABLE} + * - {@link Errors#NOT_COORDINATOR} + * - {@link Errors#GROUP_AUTHORIZATION_FAILED} + * - {@link Errors#INVALID_GROUP_ID} + * - {@link Errors#GROUP_ID_NOT_FOUND} + * - {@link Errors#NON_EMPTY_GROUP} + */ +public class OffsetDeleteResponse extends AbstractResponse { + + private final OffsetDeleteResponseData data; + + public OffsetDeleteResponse(OffsetDeleteResponseData data) { + super(ApiKeys.OFFSET_DELETE); + this.data = data; + } + + @Override + public OffsetDeleteResponseData data() { + return data; + } + + @Override + public Map errorCounts() { + Map counts = new HashMap<>(); + updateErrorCounts(counts, Errors.forCode(data.errorCode())); + data.topics().forEach(topic -> + topic.partitions().forEach(partition -> + updateErrorCounts(counts, Errors.forCode(partition.errorCode())) + ) + ); + return counts; + } + + public static OffsetDeleteResponse parse(ByteBuffer buffer, short version) { + return new OffsetDeleteResponse(new OffsetDeleteResponseData(new ByteBufferAccessor(buffer), version)); + } + + @Override + public int throttleTimeMs() { + return data.throttleTimeMs(); + } + + @Override + public boolean shouldClientThrottle(short version) { + return version >= 0; + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/OffsetFetchRequest.java b/clients/src/main/java/org/apache/kafka/common/requests/OffsetFetchRequest.java new file mode 100644 index 0000000..c5c094a --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/OffsetFetchRequest.java @@ -0,0 +1,331 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.requests; + +import java.util.Collections; +import java.util.Map.Entry; +import java.util.stream.Collectors; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.errors.UnsupportedVersionException; +import org.apache.kafka.common.message.OffsetFetchRequestData; +import org.apache.kafka.common.message.OffsetFetchRequestData.OffsetFetchRequestGroup; +import org.apache.kafka.common.message.OffsetFetchRequestData.OffsetFetchRequestTopic; +import org.apache.kafka.common.message.OffsetFetchRequestData.OffsetFetchRequestTopics; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.protocol.Errors; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +public class OffsetFetchRequest extends AbstractRequest { + + private static final Logger log = LoggerFactory.getLogger(OffsetFetchRequest.class); + + private static final List ALL_TOPIC_PARTITIONS = null; + private static final List ALL_TOPIC_PARTITIONS_BATCH = null; + private final OffsetFetchRequestData data; + + public static class Builder extends AbstractRequest.Builder { + + public final OffsetFetchRequestData data; + private final boolean throwOnFetchStableOffsetsUnsupported; + + public Builder(String groupId, + boolean requireStable, + List partitions, + boolean throwOnFetchStableOffsetsUnsupported) { + super(ApiKeys.OFFSET_FETCH); + + final List topics; + if (partitions != null) { + Map offsetFetchRequestTopicMap = new HashMap<>(); + for (TopicPartition topicPartition : partitions) { + String topicName = topicPartition.topic(); + OffsetFetchRequestTopic topic = offsetFetchRequestTopicMap.getOrDefault( + topicName, new OffsetFetchRequestTopic().setName(topicName)); + topic.partitionIndexes().add(topicPartition.partition()); + offsetFetchRequestTopicMap.put(topicName, topic); + } + topics = new ArrayList<>(offsetFetchRequestTopicMap.values()); + } else { + // If passed in partition list is null, it is requesting offsets for all topic partitions. 
+ topics = ALL_TOPIC_PARTITIONS; + } + + this.data = new OffsetFetchRequestData() + .setGroupId(groupId) + .setRequireStable(requireStable) + .setTopics(topics); + this.throwOnFetchStableOffsetsUnsupported = throwOnFetchStableOffsetsUnsupported; + } + + boolean isAllTopicPartitions() { + return this.data.topics() == ALL_TOPIC_PARTITIONS; + } + + public Builder(Map> groupIdToTopicPartitionMap, + boolean requireStable, + boolean throwOnFetchStableOffsetsUnsupported) { + super(ApiKeys.OFFSET_FETCH); + + List groups = new ArrayList<>(); + for (Entry> entry : groupIdToTopicPartitionMap.entrySet()) { + String groupName = entry.getKey(); + List tpList = entry.getValue(); + final List topics; + if (tpList != null) { + Map offsetFetchRequestTopicMap = + new HashMap<>(); + for (TopicPartition topicPartition : tpList) { + String topicName = topicPartition.topic(); + OffsetFetchRequestTopics topic = offsetFetchRequestTopicMap.getOrDefault( + topicName, new OffsetFetchRequestTopics().setName(topicName)); + topic.partitionIndexes().add(topicPartition.partition()); + offsetFetchRequestTopicMap.put(topicName, topic); + } + topics = new ArrayList<>(offsetFetchRequestTopicMap.values()); + } else { + topics = ALL_TOPIC_PARTITIONS_BATCH; + } + groups.add(new OffsetFetchRequestGroup() + .setGroupId(groupName) + .setTopics(topics)); + } + this.data = new OffsetFetchRequestData() + .setGroups(groups) + .setRequireStable(requireStable); + this.throwOnFetchStableOffsetsUnsupported = throwOnFetchStableOffsetsUnsupported; + } + + @Override + public OffsetFetchRequest build(short version) { + if (isAllTopicPartitions() && version < 2) { + throw new UnsupportedVersionException("The broker only supports OffsetFetchRequest " + + "v" + version + ", but we need v2 or newer to request all topic partitions."); + } + if (data.groups().size() > 1 && version < 8) { + throw new NoBatchedOffsetFetchRequestException("Broker does not support" + + " batching groups for fetch offset request on version " + version); + } + if (data.requireStable() && version < 7) { + if (throwOnFetchStableOffsetsUnsupported) { + throw new UnsupportedVersionException("Broker unexpectedly " + + "doesn't support requireStable flag on version " + version); + } else { + log.trace("Fallback the requireStable flag to false as broker " + + "only supports OffsetFetchRequest version {}. Need " + + "v7 or newer to enable this feature", version); + data.setRequireStable(false); + } + } + // convert data to use the appropriate version since version 8 uses different format + if (version < 8) { + OffsetFetchRequestData oldDataFormat = null; + if (!data.groups().isEmpty()) { + OffsetFetchRequestGroup group = data.groups().get(0); + String groupName = group.groupId(); + List topics = group.topics(); + List oldFormatTopics = null; + if (topics != null) { + oldFormatTopics = topics + .stream() + .map(t -> + new OffsetFetchRequestTopic() + .setName(t.name()) + .setPartitionIndexes(t.partitionIndexes())) + .collect(Collectors.toList()); + } + oldDataFormat = new OffsetFetchRequestData() + .setGroupId(groupName) + .setTopics(oldFormatTopics) + .setRequireStable(data.requireStable()); + } + return new OffsetFetchRequest(oldDataFormat == null ? 
data : oldDataFormat, version); + } else { + if (data.groups().isEmpty()) { + String groupName = data.groupId(); + List oldFormatTopics = data.topics(); + List topics = null; + if (oldFormatTopics != null) { + topics = oldFormatTopics + .stream() + .map(t -> new OffsetFetchRequestTopics() + .setName(t.name()) + .setPartitionIndexes(t.partitionIndexes())) + .collect(Collectors.toList()); + } + OffsetFetchRequestData convertedDataFormat = + new OffsetFetchRequestData() + .setGroups(Collections.singletonList( + new OffsetFetchRequestGroup() + .setGroupId(groupName) + .setTopics(topics))) + .setRequireStable(data.requireStable()); + return new OffsetFetchRequest(convertedDataFormat, version); + } + } + return new OffsetFetchRequest(data, version); + } + + @Override + public String toString() { + return data.toString(); + } + } + + /** + * Indicates that it is not possible to fetch consumer groups in batches with FetchOffset. + * Instead consumer groups' offsets must be fetched one by one. + */ + public static class NoBatchedOffsetFetchRequestException extends UnsupportedVersionException { + private static final long serialVersionUID = 1L; + + public NoBatchedOffsetFetchRequestException(String message) { + super(message); + } + } + + public List partitions() { + if (isAllPartitions()) { + return null; + } + List partitions = new ArrayList<>(); + for (OffsetFetchRequestTopic topic : data.topics()) { + for (Integer partitionIndex : topic.partitionIndexes()) { + partitions.add(new TopicPartition(topic.name(), partitionIndex)); + } + } + return partitions; + } + + public String groupId() { + return data.groupId(); + } + + public boolean requireStable() { + return data.requireStable(); + } + + public Map> groupIdsToPartitions() { + Map> groupIdsToPartitions = new HashMap<>(); + for (OffsetFetchRequestGroup group : data.groups()) { + List tpList = null; + if (group.topics() != ALL_TOPIC_PARTITIONS_BATCH) { + tpList = new ArrayList<>(); + for (OffsetFetchRequestTopics topic : group.topics()) { + for (Integer partitionIndex : topic.partitionIndexes()) { + tpList.add(new TopicPartition(topic.name(), partitionIndex)); + } + } + } + groupIdsToPartitions.put(group.groupId(), tpList); + } + return groupIdsToPartitions; + } + + public Map> groupIdsToTopics() { + Map> groupIdsToTopics = + new HashMap<>(data.groups().size()); + data.groups().forEach(g -> groupIdsToTopics.put(g.groupId(), g.topics())); + return groupIdsToTopics; + } + + public List groupIds() { + return data.groups() + .stream() + .map(OffsetFetchRequestGroup::groupId) + .collect(Collectors.toList()); + } + + private OffsetFetchRequest(OffsetFetchRequestData data, short version) { + super(ApiKeys.OFFSET_FETCH, version); + this.data = data; + } + + public OffsetFetchResponse getErrorResponse(Errors error) { + return getErrorResponse(AbstractResponse.DEFAULT_THROTTLE_TIME, error); + } + + public OffsetFetchResponse getErrorResponse(int throttleTimeMs, Errors error) { + Map responsePartitions = new HashMap<>(); + if (version() < 2) { + OffsetFetchResponse.PartitionData partitionError = new OffsetFetchResponse.PartitionData( + OffsetFetchResponse.INVALID_OFFSET, + Optional.empty(), + OffsetFetchResponse.NO_METADATA, + error); + + for (OffsetFetchRequestTopic topic : this.data.topics()) { + for (int partitionIndex : topic.partitionIndexes()) { + responsePartitions.put( + new TopicPartition(topic.name(), partitionIndex), partitionError); + } + } + return new OffsetFetchResponse(error, responsePartitions); + } + if (version() == 2) { + return new 
OffsetFetchResponse(error, responsePartitions); + } + if (version() >= 3 && version() < 8) { + return new OffsetFetchResponse(throttleTimeMs, error, responsePartitions); + } + List groupIds = groupIds(); + Map errorsMap = new HashMap<>(groupIds.size()); + Map> partitionMap = + new HashMap<>(groupIds.size()); + for (String g : groupIds) { + errorsMap.put(g, error); + partitionMap.put(g, responsePartitions); + } + return new OffsetFetchResponse(throttleTimeMs, errorsMap, partitionMap); + } + + @Override + public OffsetFetchResponse getErrorResponse(int throttleTimeMs, Throwable e) { + return getErrorResponse(throttleTimeMs, Errors.forException(e)); + } + + public static OffsetFetchRequest parse(ByteBuffer buffer, short version) { + return new OffsetFetchRequest(new OffsetFetchRequestData(new ByteBufferAccessor(buffer), version), version); + } + + public boolean isAllPartitions() { + return data.topics() == ALL_TOPIC_PARTITIONS; + } + + public boolean isAllPartitionsForGroup(String groupId) { + OffsetFetchRequestGroup group = data + .groups() + .stream() + .filter(g -> g.groupId().equals(groupId)) + .collect(Collectors.toList()) + .get(0); + return group.topics() == ALL_TOPIC_PARTITIONS_BATCH; + } + + @Override + public OffsetFetchRequestData data() { + return data; + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/OffsetFetchResponse.java b/clients/src/main/java/org/apache/kafka/common/requests/OffsetFetchResponse.java new file mode 100644 index 0000000..213182e --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/OffsetFetchResponse.java @@ -0,0 +1,347 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
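A sketch of the single-group builder above: an explicit partition list is grouped by topic, while a null list requests all committed offsets of the group and needs at least version 2. The group and topic names are illustrative, and a pre-batching version (below 8) is used so the single-group fields are kept as-is:

    import java.util.Arrays;
    import java.util.List;

    import org.apache.kafka.common.TopicPartition;
    import org.apache.kafka.common.requests.OffsetFetchRequest;

    public class OffsetFetchRequestSketch {
        public static void main(String[] args) {
            List<TopicPartition> partitions = Arrays.asList(
                    new TopicPartition("orders", 0),
                    new TopicPartition("orders", 1));

            OffsetFetchRequest request =
                    new OffsetFetchRequest.Builder("my-group", false, partitions, false)
                            .build((short) 7);

            System.out.println(request.groupId());         // my-group
            System.out.println(request.partitions());      // [orders-0, orders-1]
            System.out.println(request.isAllPartitions()); // false

            // A null partition list means "fetch everything the group has committed".
            OffsetFetchRequest all =
                    new OffsetFetchRequest.Builder("my-group", false, null, false)
                            .build((short) 7);
            System.out.println(all.isAllPartitions());     // true
        }
    }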
+ */ +package org.apache.kafka.common.requests; + +import java.util.Map.Entry; +import java.util.stream.Collectors; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.message.OffsetFetchResponseData; +import org.apache.kafka.common.message.OffsetFetchResponseData.OffsetFetchResponseGroup; +import org.apache.kafka.common.message.OffsetFetchResponseData.OffsetFetchResponsePartition; +import org.apache.kafka.common.message.OffsetFetchResponseData.OffsetFetchResponsePartitions; +import org.apache.kafka.common.message.OffsetFetchResponseData.OffsetFetchResponseTopic; +import org.apache.kafka.common.message.OffsetFetchResponseData.OffsetFetchResponseTopics; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.protocol.Errors; + +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; + +import static org.apache.kafka.common.record.RecordBatch.NO_PARTITION_LEADER_EPOCH; + +/** + * Possible error codes: + * + * - Partition errors: + * - {@link Errors#UNKNOWN_TOPIC_OR_PARTITION} + * - {@link Errors#TOPIC_AUTHORIZATION_FAILED} + * - {@link Errors#UNSTABLE_OFFSET_COMMIT} + * + * - Group or coordinator errors: + * - {@link Errors#COORDINATOR_LOAD_IN_PROGRESS} + * - {@link Errors#COORDINATOR_NOT_AVAILABLE} + * - {@link Errors#NOT_COORDINATOR} + * - {@link Errors#GROUP_AUTHORIZATION_FAILED} + */ +public class OffsetFetchResponse extends AbstractResponse { + public static final long INVALID_OFFSET = -1L; + public static final String NO_METADATA = ""; + public static final PartitionData UNKNOWN_PARTITION = new PartitionData(INVALID_OFFSET, + Optional.empty(), + NO_METADATA, + Errors.UNKNOWN_TOPIC_OR_PARTITION); + public static final PartitionData UNAUTHORIZED_PARTITION = new PartitionData(INVALID_OFFSET, + Optional.empty(), + NO_METADATA, + Errors.TOPIC_AUTHORIZATION_FAILED); + private static final List PARTITION_ERRORS = Arrays.asList( + Errors.UNKNOWN_TOPIC_OR_PARTITION, Errors.TOPIC_AUTHORIZATION_FAILED); + + private final OffsetFetchResponseData data; + private final Errors error; + private final Map groupLevelErrors = new HashMap<>(); + + public static final class PartitionData { + public final long offset; + public final String metadata; + public final Errors error; + public final Optional leaderEpoch; + + public PartitionData(long offset, + Optional leaderEpoch, + String metadata, + Errors error) { + this.offset = offset; + this.leaderEpoch = leaderEpoch; + this.metadata = metadata; + this.error = error; + } + + public boolean hasError() { + return this.error != Errors.NONE; + } + + @Override + public boolean equals(Object other) { + if (!(other instanceof PartitionData)) + return false; + PartitionData otherPartition = (PartitionData) other; + return Objects.equals(this.offset, otherPartition.offset) + && Objects.equals(this.leaderEpoch, otherPartition.leaderEpoch) + && Objects.equals(this.metadata, otherPartition.metadata) + && Objects.equals(this.error, otherPartition.error); + } + + @Override + public String toString() { + return "PartitionData(" + + "offset=" + offset + + ", leaderEpoch=" + leaderEpoch.orElse(NO_PARTITION_LEADER_EPOCH) + + ", metadata=" + metadata + + ", error='" + error.toString() + + ")"; + } + + @Override + public int hashCode() { + return Objects.hash(offset, leaderEpoch, metadata, error); + } + } + + 
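A small sketch of the PartitionData holder defined above; the offset and epoch values are illustrative:

    import java.util.Optional;

    import org.apache.kafka.common.protocol.Errors;
    import org.apache.kafka.common.requests.OffsetFetchResponse;

    public class OffsetFetchPartitionDataSketch {
        public static void main(String[] args) {
            // A committed offset of 42 at leader epoch 5, with no metadata and no error.
            OffsetFetchResponse.PartitionData committed = new OffsetFetchResponse.PartitionData(
                    42L, Optional.of(5), OffsetFetchResponse.NO_METADATA, Errors.NONE);
            System.out.println(committed.hasError()); // false

            // The predefined sentinel returned for partitions unknown to the coordinator.
            System.out.println(OffsetFetchResponse.UNKNOWN_PARTITION.hasError()); // true
        }
    }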
public OffsetFetchResponse(OffsetFetchResponseData data) { + super(ApiKeys.OFFSET_FETCH); + this.data = data; + this.error = null; + } + + /** + * Constructor without throttle time. + * @param error Potential coordinator or group level error code (for api version 2 and later) + * @param responseData Fetched offset information grouped by topic-partition + */ + public OffsetFetchResponse(Errors error, Map responseData) { + this(DEFAULT_THROTTLE_TIME, error, responseData); + } + + /** + * Constructor with throttle time for version 0 to 7 + * @param throttleTimeMs The time in milliseconds that this response was throttled + * @param error Potential coordinator or group level error code (for api version 2 and later) + * @param responseData Fetched offset information grouped by topic-partition + */ + public OffsetFetchResponse(int throttleTimeMs, Errors error, Map responseData) { + super(ApiKeys.OFFSET_FETCH); + Map offsetFetchResponseTopicMap = new HashMap<>(); + for (Map.Entry entry : responseData.entrySet()) { + String topicName = entry.getKey().topic(); + OffsetFetchResponseTopic topic = offsetFetchResponseTopicMap.getOrDefault( + topicName, new OffsetFetchResponseTopic().setName(topicName)); + PartitionData partitionData = entry.getValue(); + topic.partitions().add(new OffsetFetchResponsePartition() + .setPartitionIndex(entry.getKey().partition()) + .setErrorCode(partitionData.error.code()) + .setCommittedOffset(partitionData.offset) + .setCommittedLeaderEpoch( + partitionData.leaderEpoch.orElse(NO_PARTITION_LEADER_EPOCH)) + .setMetadata(partitionData.metadata) + ); + offsetFetchResponseTopicMap.put(topicName, topic); + } + + this.data = new OffsetFetchResponseData() + .setTopics(new ArrayList<>(offsetFetchResponseTopicMap.values())) + .setErrorCode(error.code()) + .setThrottleTimeMs(throttleTimeMs); + this.error = error; + } + + /** + * Constructor with throttle time for version 8 and above. 
+ * @param throttleTimeMs The time in milliseconds that this response was throttled + * @param errors Potential coordinator or group level error code + * @param responseData Fetched offset information grouped by topic-partition and by group + */ + public OffsetFetchResponse(int throttleTimeMs, + Map errors, Map> responseData) { + super(ApiKeys.OFFSET_FETCH); + List groupList = new ArrayList<>(); + for (Entry> entry : responseData.entrySet()) { + String groupName = entry.getKey(); + Map partitionDataMap = entry.getValue(); + Map offsetFetchResponseTopicsMap = new HashMap<>(); + for (Entry partitionEntry : partitionDataMap.entrySet()) { + String topicName = partitionEntry.getKey().topic(); + OffsetFetchResponseTopics topic = + offsetFetchResponseTopicsMap.getOrDefault(topicName, + new OffsetFetchResponseTopics().setName(topicName)); + PartitionData partitionData = partitionEntry.getValue(); + topic.partitions().add(new OffsetFetchResponsePartitions() + .setPartitionIndex(partitionEntry.getKey().partition()) + .setErrorCode(partitionData.error.code()) + .setCommittedOffset(partitionData.offset) + .setCommittedLeaderEpoch( + partitionData.leaderEpoch.orElse(NO_PARTITION_LEADER_EPOCH)) + .setMetadata(partitionData.metadata)); + offsetFetchResponseTopicsMap.put(topicName, topic); + } + groupList.add(new OffsetFetchResponseGroup() + .setGroupId(groupName) + .setTopics(new ArrayList<>(offsetFetchResponseTopicsMap.values())) + .setErrorCode(errors.get(groupName).code())); + groupLevelErrors.put(groupName, errors.get(groupName)); + } + this.data = new OffsetFetchResponseData() + .setGroups(groupList) + .setThrottleTimeMs(throttleTimeMs); + this.error = null; + } + + public OffsetFetchResponse(OffsetFetchResponseData data, short version) { + super(ApiKeys.OFFSET_FETCH); + this.data = data; + // for version 2 and later use the top-level error code (in ERROR_CODE_KEY_NAME) from the response. + // for older versions there is no top-level error in the response and all errors are partition errors, + // so if there is a group or coordinator error at the partition level use that as the top-level error. + // this way clients can depend on the top-level error regardless of the offset fetch version. + // we return the error differently starting with version 8, so we will only populate the + // error field if we are between version 2 and 7. if we are in version 8 or greater, then + // we will populate the map of group id to error codes. + if (version < 8) { + this.error = version >= 2 ? 
Errors.forCode(data.errorCode()) : topLevelError(data); + } else { + for (OffsetFetchResponseGroup group : data.groups()) { + this.groupLevelErrors.put(group.groupId(), Errors.forCode(group.errorCode())); + } + this.error = null; + } + } + + private static Errors topLevelError(OffsetFetchResponseData data) { + for (OffsetFetchResponseTopic topic : data.topics()) { + for (OffsetFetchResponsePartition partition : topic.partitions()) { + Errors partitionError = Errors.forCode(partition.errorCode()); + if (partitionError != Errors.NONE && !PARTITION_ERRORS.contains(partitionError)) { + return partitionError; + } + } + } + return Errors.NONE; + } + + @Override + public int throttleTimeMs() { + return data.throttleTimeMs(); + } + + public boolean hasError() { + return error != Errors.NONE; + } + + public boolean groupHasError(String groupId) { + return groupLevelErrors.get(groupId) != Errors.NONE; + } + + public Errors error() { + return error; + } + + public Errors groupLevelError(String groupId) { + if (error != null) { + return error; + } + return groupLevelErrors.get(groupId); + } + + @Override + public Map errorCounts() { + Map counts = new HashMap<>(); + if (!groupLevelErrors.isEmpty()) { + // built response with v8 or above + for (Map.Entry entry : groupLevelErrors.entrySet()) { + updateErrorCounts(counts, entry.getValue()); + } + for (OffsetFetchResponseGroup group : data.groups()) { + group.topics().forEach(topic -> + topic.partitions().forEach(partition -> + updateErrorCounts(counts, Errors.forCode(partition.errorCode())))); + } + } else { + // built response with v0-v7 + updateErrorCounts(counts, error); + data.topics().forEach(topic -> + topic.partitions().forEach(partition -> + updateErrorCounts(counts, Errors.forCode(partition.errorCode())))); + } + return counts; + } + + // package-private for testing purposes + Map responseDataV0ToV7() { + Map responseData = new HashMap<>(); + for (OffsetFetchResponseTopic topic : data.topics()) { + for (OffsetFetchResponsePartition partition : topic.partitions()) { + responseData.put(new TopicPartition(topic.name(), partition.partitionIndex()), + new PartitionData(partition.committedOffset(), + RequestUtils.getLeaderEpoch(partition.committedLeaderEpoch()), + partition.metadata(), + Errors.forCode(partition.errorCode())) + ); + } + } + return responseData; + } + + private Map buildResponseData(String groupId) { + Map responseData = new HashMap<>(); + OffsetFetchResponseGroup group = data + .groups() + .stream() + .filter(g -> g.groupId().equals(groupId)) + .collect(Collectors.toList()) + .get(0); + for (OffsetFetchResponseTopics topic : group.topics()) { + for (OffsetFetchResponsePartitions partition : topic.partitions()) { + responseData.put(new TopicPartition(topic.name(), partition.partitionIndex()), + new PartitionData(partition.committedOffset(), + RequestUtils.getLeaderEpoch(partition.committedLeaderEpoch()), + partition.metadata(), + Errors.forCode(partition.errorCode())) + ); + } + } + return responseData; + } + + public Map partitionDataMap(String groupId) { + if (groupLevelErrors.isEmpty()) { + return responseDataV0ToV7(); + } + return buildResponseData(groupId); + } + + public static OffsetFetchResponse parse(ByteBuffer buffer, short version) { + return new OffsetFetchResponse(new OffsetFetchResponseData(new ByteBufferAccessor(buffer), version), version); + } + + @Override + public OffsetFetchResponseData data() { + return data; + } + + @Override + public boolean shouldClientThrottle(short version) { + return version >= 4; + } +} diff 
--git a/clients/src/main/java/org/apache/kafka/common/requests/OffsetsForLeaderEpochRequest.java b/clients/src/main/java/org/apache/kafka/common/requests/OffsetsForLeaderEpochRequest.java new file mode 100644 index 0000000..f9f4f23 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/OffsetsForLeaderEpochRequest.java @@ -0,0 +1,134 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.errors.UnsupportedVersionException; +import org.apache.kafka.common.message.OffsetForLeaderEpochRequestData; +import org.apache.kafka.common.message.OffsetForLeaderEpochRequestData.OffsetForLeaderTopicCollection; +import org.apache.kafka.common.message.OffsetForLeaderEpochResponseData; +import org.apache.kafka.common.message.OffsetForLeaderEpochResponseData.EpochEndOffset; +import org.apache.kafka.common.message.OffsetForLeaderEpochResponseData.OffsetForLeaderTopicResult; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.protocol.Errors; + +import java.nio.ByteBuffer; + +import static org.apache.kafka.common.requests.OffsetsForLeaderEpochResponse.UNDEFINED_EPOCH; +import static org.apache.kafka.common.requests.OffsetsForLeaderEpochResponse.UNDEFINED_EPOCH_OFFSET; + +public class OffsetsForLeaderEpochRequest extends AbstractRequest { + + /** + * Sentinel replica_id value to indicate a regular consumer rather than another broker + */ + public static final int CONSUMER_REPLICA_ID = -1; + + /** + * Sentinel replica_id which indicates either a debug consumer or a replica which is using + * an old version of the protocol. + */ + public static final int DEBUGGING_REPLICA_ID = -2; + + private final OffsetForLeaderEpochRequestData data; + + public static class Builder extends AbstractRequest.Builder { + private final OffsetForLeaderEpochRequestData data; + + Builder(short oldestAllowedVersion, short latestAllowedVersion, OffsetForLeaderEpochRequestData data) { + super(ApiKeys.OFFSET_FOR_LEADER_EPOCH, oldestAllowedVersion, latestAllowedVersion); + this.data = data; + } + + public static Builder forConsumer(OffsetForLeaderTopicCollection epochsByPartition) { + // Old versions of this API require CLUSTER permission which is not typically granted + // to clients. Beginning with version 3, the broker requires only TOPIC Describe + // permission for the topic of each requested partition. In order to ensure client + // compatibility, we only send this request when we can guarantee the relaxed permissions. 
+ OffsetForLeaderEpochRequestData data = new OffsetForLeaderEpochRequestData(); + data.setReplicaId(CONSUMER_REPLICA_ID); + data.setTopics(epochsByPartition); + return new Builder((short) 3, ApiKeys.OFFSET_FOR_LEADER_EPOCH.latestVersion(), data); + } + + public static Builder forFollower(short version, OffsetForLeaderTopicCollection epochsByPartition, int replicaId) { + OffsetForLeaderEpochRequestData data = new OffsetForLeaderEpochRequestData(); + data.setReplicaId(replicaId); + data.setTopics(epochsByPartition); + return new Builder(version, version, data); + } + + @Override + public OffsetsForLeaderEpochRequest build(short version) { + if (version < oldestAllowedVersion() || version > latestAllowedVersion()) + throw new UnsupportedVersionException("Cannot build " + this + " with version " + version); + + return new OffsetsForLeaderEpochRequest(data, version); + } + + @Override + public String toString() { + return data.toString(); + } + } + + public OffsetsForLeaderEpochRequest(OffsetForLeaderEpochRequestData data, short version) { + super(ApiKeys.OFFSET_FOR_LEADER_EPOCH, version); + this.data = data; + } + + @Override + public OffsetForLeaderEpochRequestData data() { + return data; + } + + public int replicaId() { + return data.replicaId(); + } + + public static OffsetsForLeaderEpochRequest parse(ByteBuffer buffer, short version) { + return new OffsetsForLeaderEpochRequest(new OffsetForLeaderEpochRequestData(new ByteBufferAccessor(buffer), version), version); + } + + @Override + public AbstractResponse getErrorResponse(int throttleTimeMs, Throwable e) { + Errors error = Errors.forException(e); + + OffsetForLeaderEpochResponseData responseData = new OffsetForLeaderEpochResponseData(); + data.topics().forEach(topic -> { + OffsetForLeaderTopicResult topicData = new OffsetForLeaderTopicResult() + .setTopic(topic.topic()); + topic.partitions().forEach(partition -> + topicData.partitions().add(new EpochEndOffset() + .setPartition(partition.partition()) + .setErrorCode(error.code()) + .setLeaderEpoch(UNDEFINED_EPOCH) + .setEndOffset(UNDEFINED_EPOCH_OFFSET))); + responseData.topics().add(topicData); + }); + + return new OffsetsForLeaderEpochResponse(responseData); + } + + /** + * Check whether a broker allows Topic-level permissions in order to use the + * OffsetForLeaderEpoch API. Old versions require Cluster permission. + */ + public static boolean supportsTopicPermission(short latestUsableVersion) { + return latestUsableVersion >= 3; + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/OffsetsForLeaderEpochResponse.java b/clients/src/main/java/org/apache/kafka/common/requests/OffsetsForLeaderEpochResponse.java new file mode 100644 index 0000000..893d5a2 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/OffsetsForLeaderEpochResponse.java @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
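A sketch of the follower-side builder above. The no-argument OffsetForLeaderTopicCollection constructor is assumed from the usual generated-collection pattern, and the replica id is illustrative:

    import org.apache.kafka.common.message.OffsetForLeaderEpochRequestData.OffsetForLeaderTopicCollection;
    import org.apache.kafka.common.protocol.ApiKeys;
    import org.apache.kafka.common.requests.OffsetsForLeaderEpochRequest;

    public class OffsetsForLeaderEpochRequestSketch {
        public static void main(String[] args) {
            // Empty here for illustration; a follower would add one entry per topic it replicates.
            OffsetForLeaderTopicCollection epochsByPartition = new OffsetForLeaderTopicCollection();

            short version = ApiKeys.OFFSET_FOR_LEADER_EPOCH.latestVersion();
            OffsetsForLeaderEpochRequest request = OffsetsForLeaderEpochRequest.Builder
                    .forFollower(version, epochsByPartition, 1)
                    .build(version);

            System.out.println(request.replicaId()); // 1
        }
    }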
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.message.OffsetForLeaderEpochResponseData; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.protocol.Errors; + +import java.nio.ByteBuffer; +import java.util.HashMap; +import java.util.Map; + +import static org.apache.kafka.common.record.RecordBatch.NO_PARTITION_LEADER_EPOCH; + +/** + * Possible error codes: + * - {@link Errors#TOPIC_AUTHORIZATION_FAILED} If the user does not have DESCRIBE access to a requested topic + * - {@link Errors#REPLICA_NOT_AVAILABLE} If the request is received by a broker with version < 2.6 which is not a replica + * - {@link Errors#NOT_LEADER_OR_FOLLOWER} If the broker is not a leader or follower and either the provided leader epoch + * matches the known leader epoch on the broker or is empty + * - {@link Errors#FENCED_LEADER_EPOCH} If the epoch is lower than the broker's epoch + * - {@link Errors#UNKNOWN_LEADER_EPOCH} If the epoch is larger than the broker's epoch + * - {@link Errors#UNKNOWN_TOPIC_OR_PARTITION} If the broker does not have metadata for a topic or partition + * - {@link Errors#KAFKA_STORAGE_ERROR} If the log directory for one of the requested partitions is offline + * - {@link Errors#UNKNOWN_SERVER_ERROR} For any unexpected errors + */ +public class OffsetsForLeaderEpochResponse extends AbstractResponse { + public static final long UNDEFINED_EPOCH_OFFSET = NO_PARTITION_LEADER_EPOCH; + public static final int UNDEFINED_EPOCH = NO_PARTITION_LEADER_EPOCH; + + private final OffsetForLeaderEpochResponseData data; + + public OffsetsForLeaderEpochResponse(OffsetForLeaderEpochResponseData data) { + super(ApiKeys.OFFSET_FOR_LEADER_EPOCH); + this.data = data; + } + + @Override + public OffsetForLeaderEpochResponseData data() { + return data; + } + + @Override + public Map errorCounts() { + Map errorCounts = new HashMap<>(); + data.topics().forEach(topic -> + topic.partitions().forEach(partition -> + updateErrorCounts(errorCounts, Errors.forCode(partition.errorCode())))); + return errorCounts; + } + + public int throttleTimeMs() { + return data.throttleTimeMs(); + } + + public static OffsetsForLeaderEpochResponse parse(ByteBuffer buffer, short version) { + return new OffsetsForLeaderEpochResponse(new OffsetForLeaderEpochResponseData(new ByteBufferAccessor(buffer), version)); + } + + @Override + public String toString() { + return data.toString(); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/ProduceRequest.java b/clients/src/main/java/org/apache/kafka/common/requests/ProduceRequest.java new file mode 100644 index 0000000..758631a --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/ProduceRequest.java @@ -0,0 +1,276 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.InvalidRecordException; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.errors.UnsupportedCompressionTypeException; +import org.apache.kafka.common.message.ProduceRequestData; +import org.apache.kafka.common.message.ProduceResponseData; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.protocol.Errors; +import org.apache.kafka.common.record.BaseRecords; +import org.apache.kafka.common.record.CompressionType; +import org.apache.kafka.common.record.RecordBatch; +import org.apache.kafka.common.record.Records; +import org.apache.kafka.common.utils.Utils; + +import java.nio.ByteBuffer; +import java.util.Collections; +import java.util.HashMap; +import java.util.Iterator; +import java.util.Map; +import java.util.stream.Collectors; + +import static org.apache.kafka.common.requests.ProduceResponse.INVALID_OFFSET; + +public class ProduceRequest extends AbstractRequest { + + public static Builder forMagic(byte magic, ProduceRequestData data) { + // Message format upgrades correspond with a bump in the produce request version. Older + // message format versions are generally not supported by the produce request versions + // following the bump. + + final short minVersion; + final short maxVersion; + if (magic < RecordBatch.MAGIC_VALUE_V2) { + minVersion = 2; + maxVersion = 2; + } else { + minVersion = 3; + maxVersion = ApiKeys.PRODUCE.latestVersion(); + } + return new Builder(minVersion, maxVersion, data); + } + + public static Builder forCurrentMagic(ProduceRequestData data) { + return forMagic(RecordBatch.CURRENT_MAGIC_VALUE, data); + } + + public static class Builder extends AbstractRequest.Builder { + private final ProduceRequestData data; + + public Builder(short minVersion, + short maxVersion, + ProduceRequestData data) { + super(ApiKeys.PRODUCE, minVersion, maxVersion); + this.data = data; + } + + @Override + public ProduceRequest build(short version) { + return build(version, true); + } + + // Visible for testing only + public ProduceRequest buildUnsafe(short version) { + return build(version, false); + } + + private ProduceRequest build(short version, boolean validate) { + if (validate) { + // Validate the given records first + data.topicData().forEach(tpd -> + tpd.partitionData().forEach(partitionProduceData -> + ProduceRequest.validateRecords(version, partitionProduceData.records()))); + } + return new ProduceRequest(data, version); + } + + @Override + public String toString() { + StringBuilder bld = new StringBuilder(); + bld.append("(type=ProduceRequest") + .append(", acks=").append(data.acks()) + .append(", timeout=").append(data.timeoutMs()) + .append(", partitionRecords=(").append(data.topicData().stream().flatMap(d -> d.partitionData().stream()).collect(Collectors.toList())) + .append("), transactionalId='").append(data.transactionalId() != null ? 
data.transactionalId() : "") + .append("'"); + return bld.toString(); + } + } + + /** + * We have to copy acks, timeout, transactionalId and partitionSizes from data since data maybe reset to eliminate + * the reference to ByteBuffer but those metadata are still useful. + */ + private final short acks; + private final int timeout; + private final String transactionalId; + // This is set to null by `clearPartitionRecords` to prevent unnecessary memory retention when a produce request is + // put in the purgatory (due to client throttling, it can take a while before the response is sent). + // Care should be taken in methods that use this field. + private volatile ProduceRequestData data; + // the partitionSizes is lazily initialized since it is used by server-side in production. + private volatile Map partitionSizes; + + public ProduceRequest(ProduceRequestData produceRequestData, short version) { + super(ApiKeys.PRODUCE, version); + this.data = produceRequestData; + this.acks = data.acks(); + this.timeout = data.timeoutMs(); + this.transactionalId = data.transactionalId(); + } + + // visible for testing + Map partitionSizes() { + if (partitionSizes == null) { + // this method may be called by different thread (see the comment on data) + synchronized (this) { + if (partitionSizes == null) { + partitionSizes = new HashMap<>(); + data.topicData().forEach(topicData -> + topicData.partitionData().forEach(partitionData -> + partitionSizes.compute(new TopicPartition(topicData.name(), partitionData.index()), + (ignored, previousValue) -> + partitionData.records().sizeInBytes() + (previousValue == null ? 0 : previousValue)) + ) + ); + } + } + } + return partitionSizes; + } + + /** + * @return data or IllegalStateException if the data is removed (to prevent unnecessary memory retention). 
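A sketch of the lifecycle described above, using a bare ProduceRequestData purely to exercise the contract (a real request would carry topic data and records):

    import org.apache.kafka.common.message.ProduceRequestData;
    import org.apache.kafka.common.protocol.ApiKeys;
    import org.apache.kafka.common.requests.ProduceRequest;

    public class ProduceRequestLifecycleSketch {
        public static void main(String[] args) {
            ProduceRequest request =
                    new ProduceRequest(new ProduceRequestData(), ApiKeys.PRODUCE.latestVersion());

            System.out.println(request.data()); // available before the records are cleared

            // Once the request sits in purgatory, the records can be dropped to save memory ...
            request.clearPartitionRecords();

            // ... after which data() refuses to hand out the cleared reference.
            try {
                request.data();
            } catch (IllegalStateException expected) {
                System.out.println(expected.getMessage());
            }
        }
    }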
+ */ + @Override + public ProduceRequestData data() { + // Store it in a local variable to protect against concurrent updates + ProduceRequestData tmp = data; + if (tmp == null) + throw new IllegalStateException("The partition records are no longer available because clearPartitionRecords() has been invoked."); + return tmp; + } + + @Override + public String toString(boolean verbose) { + // Use the same format as `Struct.toString()` + StringBuilder bld = new StringBuilder(); + bld.append("{acks=").append(acks) + .append(",timeout=").append(timeout); + + if (verbose) + bld.append(",partitionSizes=").append(Utils.mkString(partitionSizes(), "[", "]", "=", ",")); + else + bld.append(",numPartitions=").append(partitionSizes().size()); + + bld.append("}"); + return bld.toString(); + } + + @Override + public ProduceResponse getErrorResponse(int throttleTimeMs, Throwable e) { + /* In case the producer doesn't actually want any response */ + if (acks == 0) return null; + ApiError apiError = ApiError.fromThrowable(e); + ProduceResponseData data = new ProduceResponseData().setThrottleTimeMs(throttleTimeMs); + partitionSizes().forEach((tp, ignored) -> { + ProduceResponseData.TopicProduceResponse tpr = data.responses().find(tp.topic()); + if (tpr == null) { + tpr = new ProduceResponseData.TopicProduceResponse().setName(tp.topic()); + data.responses().add(tpr); + } + tpr.partitionResponses().add(new ProduceResponseData.PartitionProduceResponse() + .setIndex(tp.partition()) + .setRecordErrors(Collections.emptyList()) + .setBaseOffset(INVALID_OFFSET) + .setLogAppendTimeMs(RecordBatch.NO_TIMESTAMP) + .setLogStartOffset(INVALID_OFFSET) + .setErrorMessage(apiError.message()) + .setErrorCode(apiError.error().code())); + }); + return new ProduceResponse(data); + } + + @Override + public Map errorCounts(Throwable e) { + Errors error = Errors.forException(e); + return Collections.singletonMap(error, partitionSizes().size()); + } + + public short acks() { + return acks; + } + + public int timeout() { + return timeout; + } + + public String transactionalId() { + return transactionalId; + } + + public void clearPartitionRecords() { + // lazily initialize partitionSizes. + partitionSizes(); + data = null; + } + + public static void validateRecords(short version, BaseRecords baseRecords) { + if (version >= 3) { + if (baseRecords instanceof Records) { + Records records = (Records) baseRecords; + Iterator iterator = records.batches().iterator(); + if (!iterator.hasNext()) + throw new InvalidRecordException("Produce requests with version " + version + " must have at least " + + "one record batch"); + + RecordBatch entry = iterator.next(); + if (entry.magic() != RecordBatch.MAGIC_VALUE_V2) + throw new InvalidRecordException("Produce requests with version " + version + " are only allowed to " + + "contain record batches with magic version 2"); + if (version < 7 && entry.compressionType() == CompressionType.ZSTD) { + throw new UnsupportedCompressionTypeException("Produce requests with version " + version + " are not allowed to " + + "use ZStandard compression"); + } + + if (iterator.hasNext()) + throw new InvalidRecordException("Produce requests with version " + version + " are only allowed to " + + "contain exactly one record batch"); + } + } + + // Note that we do not do similar validation for older versions to ensure compatibility with + // clients which send the wrong magic version in the wrong version of the produce request. The broker + // did not do this validation before, so we maintain that behavior here. 
+ } + + public static ProduceRequest parse(ByteBuffer buffer, short version) { + return new ProduceRequest(new ProduceRequestData(new ByteBufferAccessor(buffer), version), version); + } + + public static byte requiredMagicForVersion(short produceRequestVersion) { + if (produceRequestVersion < ApiKeys.PRODUCE.oldestVersion() || produceRequestVersion > ApiKeys.PRODUCE.latestVersion()) + throw new IllegalArgumentException("Magic value to use for produce request version " + + produceRequestVersion + " is not known"); + + switch (produceRequestVersion) { + case 0: + case 1: + return RecordBatch.MAGIC_VALUE_V0; + + case 2: + return RecordBatch.MAGIC_VALUE_V1; + + default: + return RecordBatch.MAGIC_VALUE_V2; + } + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/ProduceResponse.java b/clients/src/main/java/org/apache/kafka/common/requests/ProduceResponse.java new file mode 100644 index 0000000..9b94536 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/ProduceResponse.java @@ -0,0 +1,239 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.message.ProduceResponseData; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.protocol.Errors; +import org.apache.kafka.common.record.RecordBatch; + +import java.nio.ByteBuffer; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.stream.Collectors; + +/** + * This wrapper supports both v0 and v8 of ProduceResponse. 
+ * + * Possible error codes: + * + * {@link Errors#CORRUPT_MESSAGE} + * {@link Errors#UNKNOWN_TOPIC_OR_PARTITION} + * {@link Errors#NOT_LEADER_OR_FOLLOWER} + * {@link Errors#MESSAGE_TOO_LARGE} + * {@link Errors#INVALID_TOPIC_EXCEPTION} + * {@link Errors#RECORD_LIST_TOO_LARGE} + * {@link Errors#NOT_ENOUGH_REPLICAS} + * {@link Errors#NOT_ENOUGH_REPLICAS_AFTER_APPEND} + * {@link Errors#INVALID_REQUIRED_ACKS} + * {@link Errors#TOPIC_AUTHORIZATION_FAILED} + * {@link Errors#UNSUPPORTED_FOR_MESSAGE_FORMAT} + * {@link Errors#INVALID_PRODUCER_EPOCH} + * {@link Errors#CLUSTER_AUTHORIZATION_FAILED} + * {@link Errors#TRANSACTIONAL_ID_AUTHORIZATION_FAILED} + * {@link Errors#INVALID_RECORD} + */ +public class ProduceResponse extends AbstractResponse { + public static final long INVALID_OFFSET = -1L; + private final ProduceResponseData data; + + public ProduceResponse(ProduceResponseData produceResponseData) { + super(ApiKeys.PRODUCE); + this.data = produceResponseData; + } + + /** + * Constructor for Version 0 + * @param responses Produced data grouped by topic-partition + */ + @Deprecated + public ProduceResponse(Map<TopicPartition, PartitionResponse> responses) { + this(responses, DEFAULT_THROTTLE_TIME); + } + + /** + * Constructor for the latest version + * @param responses Produced data grouped by topic-partition + * @param throttleTimeMs Time in milliseconds the response was throttled + */ + @Deprecated + public ProduceResponse(Map<TopicPartition, PartitionResponse> responses, int throttleTimeMs) { + this(toData(responses, throttleTimeMs)); + } + + private static ProduceResponseData toData(Map<TopicPartition, PartitionResponse> responses, int throttleTimeMs) { + ProduceResponseData data = new ProduceResponseData().setThrottleTimeMs(throttleTimeMs); + responses.forEach((tp, response) -> { + ProduceResponseData.TopicProduceResponse tpr = data.responses().find(tp.topic()); + if (tpr == null) { + tpr = new ProduceResponseData.TopicProduceResponse().setName(tp.topic()); + data.responses().add(tpr); + } + tpr.partitionResponses() + .add(new ProduceResponseData.PartitionProduceResponse() + .setIndex(tp.partition()) + .setBaseOffset(response.baseOffset) + .setLogStartOffset(response.logStartOffset) + .setLogAppendTimeMs(response.logAppendTime) + .setErrorMessage(response.errorMessage) + .setErrorCode(response.error.code()) + .setRecordErrors(response.recordErrors + .stream() + .map(e -> new ProduceResponseData.BatchIndexAndErrorMessage() + .setBatchIndex(e.batchIndex) + .setBatchIndexErrorMessage(e.message)) + .collect(Collectors.toList()))); + }); + return data; + } + + @Override + public ProduceResponseData data() { + return this.data; + } + + @Override + public int throttleTimeMs() { + return this.data.throttleTimeMs(); + } + + @Override + public Map<Errors, Integer> errorCounts() { + Map<Errors, Integer> errorCounts = new HashMap<>(); + data.responses().forEach(t -> t.partitionResponses().forEach(p -> updateErrorCounts(errorCounts, Errors.forCode(p.errorCode())))); + return errorCounts; + } + + public static final class PartitionResponse { + public Errors error; + public long baseOffset; + public long logAppendTime; + public long logStartOffset; + public List<RecordError> recordErrors; + public String errorMessage; + + public PartitionResponse(Errors error) { + this(error, INVALID_OFFSET, RecordBatch.NO_TIMESTAMP, INVALID_OFFSET); + } + + public PartitionResponse(Errors error, String errorMessage) { + this(error, INVALID_OFFSET, RecordBatch.NO_TIMESTAMP, INVALID_OFFSET, Collections.emptyList(), errorMessage); + } + + public PartitionResponse(Errors error, long baseOffset, long logAppendTime, long logStartOffset) { + this(error, baseOffset,
logAppendTime, logStartOffset, Collections.emptyList(), null); + } + + public PartitionResponse(Errors error, long baseOffset, long logAppendTime, long logStartOffset, List recordErrors) { + this(error, baseOffset, logAppendTime, logStartOffset, recordErrors, null); + } + + public PartitionResponse(Errors error, long baseOffset, long logAppendTime, long logStartOffset, List recordErrors, String errorMessage) { + this.error = error; + this.baseOffset = baseOffset; + this.logAppendTime = logAppendTime; + this.logStartOffset = logStartOffset; + this.recordErrors = recordErrors; + this.errorMessage = errorMessage; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + PartitionResponse that = (PartitionResponse) o; + return baseOffset == that.baseOffset && + logAppendTime == that.logAppendTime && + logStartOffset == that.logStartOffset && + error == that.error && + Objects.equals(recordErrors, that.recordErrors) && + Objects.equals(errorMessage, that.errorMessage); + } + + @Override + public int hashCode() { + return Objects.hash(error, baseOffset, logAppendTime, logStartOffset, recordErrors, errorMessage); + } + + @Override + public String toString() { + StringBuilder b = new StringBuilder(); + b.append('{'); + b.append("error: "); + b.append(error); + b.append(",offset: "); + b.append(baseOffset); + b.append(",logAppendTime: "); + b.append(logAppendTime); + b.append(", logStartOffset: "); + b.append(logStartOffset); + b.append(", recordErrors: "); + b.append(recordErrors); + b.append(", errorMessage: "); + if (errorMessage != null) { + b.append(errorMessage); + } else { + b.append("null"); + } + b.append('}'); + return b.toString(); + } + } + + public static final class RecordError { + public final int batchIndex; + public final String message; + + public RecordError(int batchIndex, String message) { + this.batchIndex = batchIndex; + this.message = message; + } + + public RecordError(int batchIndex) { + this.batchIndex = batchIndex; + this.message = null; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + RecordError that = (RecordError) o; + return batchIndex == that.batchIndex && + Objects.equals(message, that.message); + } + + @Override + public int hashCode() { + return Objects.hash(batchIndex, message); + } + } + + public static ProduceResponse parse(ByteBuffer buffer, short version) { + return new ProduceResponse(new ProduceResponseData(new ByteBufferAccessor(buffer), version)); + } + + @Override + public boolean shouldClientThrottle(short version) { + return version >= 6; + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/RenewDelegationTokenRequest.java b/clients/src/main/java/org/apache/kafka/common/requests/RenewDelegationTokenRequest.java new file mode 100644 index 0000000..91a9f96 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/RenewDelegationTokenRequest.java @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.requests; + +import java.nio.ByteBuffer; + +import org.apache.kafka.common.message.RenewDelegationTokenRequestData; +import org.apache.kafka.common.message.RenewDelegationTokenResponseData; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.protocol.Errors; + +public class RenewDelegationTokenRequest extends AbstractRequest { + + private final RenewDelegationTokenRequestData data; + + public RenewDelegationTokenRequest(RenewDelegationTokenRequestData data, short version) { + super(ApiKeys.RENEW_DELEGATION_TOKEN, version); + this.data = data; + } + + public static RenewDelegationTokenRequest parse(ByteBuffer buffer, short version) { + return new RenewDelegationTokenRequest(new RenewDelegationTokenRequestData( + new ByteBufferAccessor(buffer), version), version); + } + + @Override + public RenewDelegationTokenRequestData data() { + return data; + } + + @Override + public AbstractResponse getErrorResponse(int throttleTimeMs, Throwable e) { + return new RenewDelegationTokenResponse( + new RenewDelegationTokenResponseData() + .setThrottleTimeMs(throttleTimeMs) + .setErrorCode(Errors.forException(e).code())); + } + + public static class Builder extends AbstractRequest.Builder { + private final RenewDelegationTokenRequestData data; + + public Builder(RenewDelegationTokenRequestData data) { + super(ApiKeys.RENEW_DELEGATION_TOKEN); + this.data = data; + } + + @Override + public RenewDelegationTokenRequest build(short version) { + return new RenewDelegationTokenRequest(data, version); + } + + @Override + public String toString() { + return data.toString(); + } + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/RenewDelegationTokenResponse.java b/clients/src/main/java/org/apache/kafka/common/requests/RenewDelegationTokenResponse.java new file mode 100644 index 0000000..30708ff --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/RenewDelegationTokenResponse.java @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.requests; + +import java.nio.ByteBuffer; +import java.util.Map; + +import org.apache.kafka.common.message.RenewDelegationTokenResponseData; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.protocol.Errors; + +public class RenewDelegationTokenResponse extends AbstractResponse { + + private final RenewDelegationTokenResponseData data; + + public RenewDelegationTokenResponse(RenewDelegationTokenResponseData data) { + super(ApiKeys.RENEW_DELEGATION_TOKEN); + this.data = data; + } + + public static RenewDelegationTokenResponse parse(ByteBuffer buffer, short version) { + return new RenewDelegationTokenResponse(new RenewDelegationTokenResponseData( + new ByteBufferAccessor(buffer), version)); + } + + @Override + public Map errorCounts() { + return errorCounts(error()); + } + + @Override + public RenewDelegationTokenResponseData data() { + return data; + } + + @Override + public int throttleTimeMs() { + return data.throttleTimeMs(); + } + + public Errors error() { + return Errors.forCode(data.errorCode()); + } + + public long expiryTimestamp() { + return data.expiryTimestampMs(); + } + + public boolean hasError() { + return error() != Errors.NONE; + } + + @Override + public boolean shouldClientThrottle(short version) { + return version >= 1; + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/RequestAndSize.java b/clients/src/main/java/org/apache/kafka/common/requests/RequestAndSize.java new file mode 100644 index 0000000..4f94a09 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/RequestAndSize.java @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.requests; + +public class RequestAndSize { + public final AbstractRequest request; + public final int size; + + public RequestAndSize(AbstractRequest request, int size) { + this.request = request; + this.size = size; + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/RequestContext.java b/clients/src/main/java/org/apache/kafka/common/requests/RequestContext.java new file mode 100644 index 0000000..d7a6df1 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/RequestContext.java @@ -0,0 +1,193 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.errors.InvalidRequestException; +import org.apache.kafka.common.message.ApiVersionsRequestData; +import org.apache.kafka.common.network.ClientInformation; +import org.apache.kafka.common.network.ListenerName; +import org.apache.kafka.common.network.Send; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.security.auth.KafkaPrincipal; +import org.apache.kafka.common.security.auth.KafkaPrincipalSerde; +import org.apache.kafka.common.security.auth.SecurityProtocol; +import org.apache.kafka.server.authorizer.AuthorizableRequestContext; + +import java.net.InetAddress; +import java.nio.ByteBuffer; +import java.util.Optional; + +import static org.apache.kafka.common.protocol.ApiKeys.API_VERSIONS; + +public class RequestContext implements AuthorizableRequestContext { + public final RequestHeader header; + public final String connectionId; + public final InetAddress clientAddress; + public final KafkaPrincipal principal; + public final ListenerName listenerName; + public final SecurityProtocol securityProtocol; + public final ClientInformation clientInformation; + public final boolean fromPrivilegedListener; + public final Optional principalSerde; + + public RequestContext(RequestHeader header, + String connectionId, + InetAddress clientAddress, + KafkaPrincipal principal, + ListenerName listenerName, + SecurityProtocol securityProtocol, + ClientInformation clientInformation, + boolean fromPrivilegedListener) { + this(header, + connectionId, + clientAddress, + principal, + listenerName, + securityProtocol, + clientInformation, + fromPrivilegedListener, + Optional.empty()); + } + + public RequestContext(RequestHeader header, + String connectionId, + InetAddress clientAddress, + KafkaPrincipal principal, + ListenerName listenerName, + SecurityProtocol securityProtocol, + ClientInformation clientInformation, + boolean fromPrivilegedListener, + Optional principalSerde) { + this.header = header; + this.connectionId = connectionId; + this.clientAddress = clientAddress; + this.principal = principal; + this.listenerName = listenerName; + this.securityProtocol = securityProtocol; + this.clientInformation = clientInformation; + this.fromPrivilegedListener = fromPrivilegedListener; + this.principalSerde = principalSerde; + } + + public RequestAndSize parseRequest(ByteBuffer buffer) { + if (isUnsupportedApiVersionsRequest()) { + // Unsupported ApiVersion requests are treated as v0 requests and are not parsed + ApiVersionsRequest apiVersionsRequest = new ApiVersionsRequest(new ApiVersionsRequestData(), (short) 0, header.apiVersion()); + return new RequestAndSize(apiVersionsRequest, 0); + } else { + ApiKeys apiKey = header.apiKey(); + try { + short apiVersion = header.apiVersion(); + return AbstractRequest.parseRequest(apiKey, apiVersion, buffer); + } catch (Throwable ex) { + throw new InvalidRequestException("Error getting request for apiKey: " + apiKey + + ", apiVersion: " + header.apiVersion() + + ", connectionId: " + connectionId + + ", listenerName: " + listenerName + + ", principal: 
" + principal, ex); + } + } + } + + /** + * Build a {@link Send} for direct transmission of the provided response + * over the network. + */ + public Send buildResponseSend(AbstractResponse body) { + return body.toSend(header.toResponseHeader(), apiVersion()); + } + + /** + * Serialize a response into a {@link ByteBuffer}. This is used when the response + * will be encapsulated in an {@link EnvelopeResponse}. The buffer will contain + * both the serialized {@link ResponseHeader} as well as the bytes from the response. + * There is no `size` prefix unlike the output from {@link #buildResponseSend(AbstractResponse)}. + * + * Note that envelope requests are reserved only for APIs which have set the + * {@link ApiKeys#forwardable} flag. Notably the `Fetch` API cannot be forwarded, + * so we do not lose the benefit of "zero copy" transfers from disk. + */ + public ByteBuffer buildResponseEnvelopePayload(AbstractResponse body) { + return body.serializeWithHeader(header.toResponseHeader(), apiVersion()); + } + + private boolean isUnsupportedApiVersionsRequest() { + return header.apiKey() == API_VERSIONS && !API_VERSIONS.isVersionSupported(header.apiVersion()); + } + + public short apiVersion() { + // Use v0 when serializing an unhandled ApiVersion response + if (isUnsupportedApiVersionsRequest()) + return 0; + return header.apiVersion(); + } + + @Override + public String listenerName() { + return listenerName.value(); + } + + @Override + public SecurityProtocol securityProtocol() { + return securityProtocol; + } + + @Override + public KafkaPrincipal principal() { + return principal; + } + + @Override + public InetAddress clientAddress() { + return clientAddress; + } + + @Override + public int requestType() { + return header.apiKey().id; + } + + @Override + public int requestVersion() { + return header.apiVersion(); + } + + @Override + public String clientId() { + return header.clientId(); + } + + @Override + public int correlationId() { + return header.correlationId(); + } + + @Override + public String toString() { + return "RequestContext(" + + "header=" + header + + ", connectionId='" + connectionId + '\'' + + ", clientAddress=" + clientAddress + + ", principal=" + principal + + ", listenerName=" + listenerName + + ", securityProtocol=" + securityProtocol + + ", clientInformation=" + clientInformation + + ", fromPrivilegedListener=" + fromPrivilegedListener + + ", principalSerde=" + principalSerde + + ')'; + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/RequestHeader.java b/clients/src/main/java/org/apache/kafka/common/requests/RequestHeader.java new file mode 100644 index 0000000..30a8f70 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/RequestHeader.java @@ -0,0 +1,132 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.errors.InvalidRequestException; +import org.apache.kafka.common.errors.UnsupportedVersionException; +import org.apache.kafka.common.message.RequestHeaderData; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.protocol.ObjectSerializationCache; + +import java.nio.ByteBuffer; + +/** + * The header for a request in the Kafka protocol + */ +public class RequestHeader implements AbstractRequestResponse { + private final RequestHeaderData data; + private final short headerVersion; + + public RequestHeader(ApiKeys requestApiKey, short requestVersion, String clientId, int correlationId) { + this(new RequestHeaderData(). + setRequestApiKey(requestApiKey.id). + setRequestApiVersion(requestVersion). + setClientId(clientId). + setCorrelationId(correlationId), + requestApiKey.requestHeaderVersion(requestVersion)); + } + + public RequestHeader(RequestHeaderData data, short headerVersion) { + this.data = data; + this.headerVersion = headerVersion; + } + + public ApiKeys apiKey() { + return ApiKeys.forId(data.requestApiKey()); + } + + public short apiVersion() { + return data.requestApiVersion(); + } + + public short headerVersion() { + return headerVersion; + } + + public String clientId() { + return data.clientId(); + } + + public int correlationId() { + return data.correlationId(); + } + + public RequestHeaderData data() { + return data; + } + + public void write(ByteBuffer buffer, ObjectSerializationCache serializationCache) { + data.write(new ByteBufferAccessor(buffer), serializationCache, headerVersion); + } + + public int size(ObjectSerializationCache serializationCache) { + return data.size(serializationCache, headerVersion); + } + + public ResponseHeader toResponseHeader() { + return new ResponseHeader(data.correlationId(), apiKey().responseHeaderVersion(apiVersion())); + } + + public static RequestHeader parse(ByteBuffer buffer) { + short apiKey = -1; + try { + // We derive the header version from the request api version, so we read that first. + // The request api version is part of `RequestHeaderData`, so we reset the buffer position after the read. + int position = buffer.position(); + apiKey = buffer.getShort(); + short apiVersion = buffer.getShort(); + short headerVersion = ApiKeys.forId(apiKey).requestHeaderVersion(apiVersion); + buffer.position(position); + RequestHeaderData headerData = new RequestHeaderData( + new ByteBufferAccessor(buffer), headerVersion); + // Due to a quirk in the protocol, client ID is marked as nullable. + // However, we treat a null client ID as equivalent to an empty client ID. + if (headerData.clientId() == null) { + headerData.setClientId(""); + } + return new RequestHeader(headerData, headerVersion); + } catch (UnsupportedVersionException e) { + throw new InvalidRequestException("Unknown API key " + apiKey, e); + } catch (Throwable ex) { + throw new InvalidRequestException("Error parsing request header. 
Our best guess of the apiKey is: " + + apiKey, ex); + } + } + + @Override + public String toString() { + return "RequestHeader(apiKey=" + apiKey() + + ", apiVersion=" + apiVersion() + + ", clientId=" + clientId() + + ", correlationId=" + correlationId() + + ")"; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + RequestHeader that = (RequestHeader) o; + return this.data.equals(that.data); + } + + @Override + public int hashCode() { + return this.data.hashCode(); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/RequestUtils.java b/clients/src/main/java/org/apache/kafka/common/requests/RequestUtils.java new file mode 100644 index 0000000..cc6e5a2 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/RequestUtils.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.message.ProduceRequestData; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.protocol.Message; +import org.apache.kafka.common.protocol.ObjectSerializationCache; +import org.apache.kafka.common.record.RecordBatch; +import org.apache.kafka.common.record.Records; + +import java.nio.ByteBuffer; +import java.util.Iterator; +import java.util.Optional; +import java.util.function.Predicate; + +public final class RequestUtils { + + private RequestUtils() {} + + public static Optional getLeaderEpoch(int leaderEpoch) { + return leaderEpoch == RecordBatch.NO_PARTITION_LEADER_EPOCH ? + Optional.empty() : Optional.of(leaderEpoch); + } + + public static boolean hasTransactionalRecords(ProduceRequest request) { + return flag(request, RecordBatch::isTransactional); + } + + /** + * find a flag from all records of a produce request. + * @param request produce request + * @param predicate used to predicate the record + * @return true if there is any matched flag in the produce request. 
Otherwise, false + */ + static boolean flag(ProduceRequest request, Predicate predicate) { + for (ProduceRequestData.TopicProduceData tp : request.data().topicData()) { + for (ProduceRequestData.PartitionProduceData p : tp.partitionData()) { + if (p.records() instanceof Records) { + Iterator iter = (((Records) p.records())).batchIterator(); + if (iter.hasNext() && predicate.test(iter.next())) return true; + } + } + } + return false; + } + + public static ByteBuffer serialize( + Message header, + short headerVersion, + Message apiMessage, + short apiVersion + ) { + ObjectSerializationCache cache = new ObjectSerializationCache(); + + int headerSize = header.size(cache, headerVersion); + int messageSize = apiMessage.size(cache, apiVersion); + ByteBufferAccessor writable = new ByteBufferAccessor(ByteBuffer.allocate(headerSize + messageSize)); + + header.write(writable, cache, headerVersion); + apiMessage.write(writable, cache, apiVersion); + + writable.flip(); + return writable.buffer(); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/ResponseHeader.java b/clients/src/main/java/org/apache/kafka/common/requests/ResponseHeader.java new file mode 100644 index 0000000..d45e28d --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/ResponseHeader.java @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.message.ResponseHeaderData; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.protocol.ObjectSerializationCache; + +import java.nio.ByteBuffer; +import java.util.Objects; + +/** + * A response header in the kafka protocol. 
+ */ +public class ResponseHeader implements AbstractRequestResponse { + private final ResponseHeaderData data; + private final short headerVersion; + + public ResponseHeader(int correlationId, short headerVersion) { + this(new ResponseHeaderData().setCorrelationId(correlationId), headerVersion); + } + + public ResponseHeader(ResponseHeaderData data, short headerVersion) { + this.data = data; + this.headerVersion = headerVersion; + } + + public int size(ObjectSerializationCache serializationCache) { + return data().size(serializationCache, headerVersion); + } + + public int correlationId() { + return this.data.correlationId(); + } + + public short headerVersion() { + return headerVersion; + } + + public ResponseHeaderData data() { + return data; + } + + public void write(ByteBuffer buffer, ObjectSerializationCache serializationCache) { + data.write(new ByteBufferAccessor(buffer), serializationCache, headerVersion); + } + + @Override + public String toString() { + return "ResponseHeader(" + + "correlationId=" + data.correlationId() + + ", headerVersion=" + headerVersion + + ")"; + } + + public static ResponseHeader parse(ByteBuffer buffer, short headerVersion) { + return new ResponseHeader( + new ResponseHeaderData(new ByteBufferAccessor(buffer), headerVersion), + headerVersion); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + ResponseHeader that = (ResponseHeader) o; + return headerVersion == that.headerVersion && + Objects.equals(data, that.data); + } + + @Override + public int hashCode() { + return Objects.hash(data, headerVersion); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/SaslAuthenticateRequest.java b/clients/src/main/java/org/apache/kafka/common/requests/SaslAuthenticateRequest.java new file mode 100644 index 0000000..e2080ce --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/SaslAuthenticateRequest.java @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.message.SaslAuthenticateRequestData; +import org.apache.kafka.common.message.SaslAuthenticateResponseData; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; + +import java.nio.ByteBuffer; + + +/** + * Request from SASL client containing client SASL authentication token as defined by the + * SASL protocol for the configured SASL mechanism. + *

        + * For interoperability with versions prior to Kafka 1.0.0, this request is used only with broker + * version 1.0.0 and higher that support SaslHandshake request v1. Clients connecting to older + * brokers will send SaslHandshake request v0 followed by SASL tokens without the Kafka request headers. + */ +public class SaslAuthenticateRequest extends AbstractRequest { + + public static class Builder extends AbstractRequest.Builder { + private final SaslAuthenticateRequestData data; + + public Builder(SaslAuthenticateRequestData data) { + super(ApiKeys.SASL_AUTHENTICATE); + this.data = data; + } + + @Override + public SaslAuthenticateRequest build(short version) { + return new SaslAuthenticateRequest(data, version); + } + + @Override + public String toString() { + return "(type=SaslAuthenticateRequest)"; + } + } + + private final SaslAuthenticateRequestData data; + + public SaslAuthenticateRequest(SaslAuthenticateRequestData data, short version) { + super(ApiKeys.SASL_AUTHENTICATE, version); + this.data = data; + } + + @Override + public SaslAuthenticateRequestData data() { + return data; + } + + @Override + public AbstractResponse getErrorResponse(int throttleTimeMs, Throwable e) { + ApiError apiError = ApiError.fromThrowable(e); + SaslAuthenticateResponseData response = new SaslAuthenticateResponseData() + .setErrorCode(apiError.error().code()) + .setErrorMessage(apiError.message()); + return new SaslAuthenticateResponse(response); + } + + public static SaslAuthenticateRequest parse(ByteBuffer buffer, short version) { + return new SaslAuthenticateRequest(new SaslAuthenticateRequestData(new ByteBufferAccessor(buffer), version), + version); + } +} + diff --git a/clients/src/main/java/org/apache/kafka/common/requests/SaslAuthenticateResponse.java b/clients/src/main/java/org/apache/kafka/common/requests/SaslAuthenticateResponse.java new file mode 100644 index 0000000..bd12d3d --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/SaslAuthenticateResponse.java @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.message.SaslAuthenticateResponseData; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.protocol.Errors; + +import java.nio.ByteBuffer; +import java.util.Map; + +/** + * Response from SASL server which for a SASL challenge as defined by the SASL protocol + * for the mechanism configured for the client. 
+ */ +public class SaslAuthenticateResponse extends AbstractResponse { + + private final SaslAuthenticateResponseData data; + + public SaslAuthenticateResponse(SaslAuthenticateResponseData data) { + super(ApiKeys.SASL_AUTHENTICATE); + this.data = data; + } + + /** + * Possible error codes: + * SASL_AUTHENTICATION_FAILED(57) : Authentication failed + */ + public Errors error() { + return Errors.forCode(data.errorCode()); + } + + @Override + public Map errorCounts() { + return errorCounts(Errors.forCode(data.errorCode())); + } + + public String errorMessage() { + return data.errorMessage(); + } + + public long sessionLifetimeMs() { + return data.sessionLifetimeMs(); + } + + public byte[] saslAuthBytes() { + return data.authBytes(); + } + + @Override + public int throttleTimeMs() { + return DEFAULT_THROTTLE_TIME; + } + + @Override + public SaslAuthenticateResponseData data() { + return data; + } + + public static SaslAuthenticateResponse parse(ByteBuffer buffer, short version) { + return new SaslAuthenticateResponse(new SaslAuthenticateResponseData(new ByteBufferAccessor(buffer), version)); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/SaslHandshakeRequest.java b/clients/src/main/java/org/apache/kafka/common/requests/SaslHandshakeRequest.java new file mode 100644 index 0000000..09d3a87 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/SaslHandshakeRequest.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.requests; + + +import org.apache.kafka.common.message.SaslHandshakeRequestData; +import org.apache.kafka.common.message.SaslHandshakeResponseData; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; + +import java.nio.ByteBuffer; + +/** + * Request from SASL client containing client SASL mechanism. + *

        + * For interoperability with Kafka 0.9.0.x, the mechanism flow may be omitted when using GSSAPI. Hence + * this request should not conflict with the first GSSAPI client packet. For GSSAPI, the first context + * establishment packet starts with byte 0x60 (APPLICATION-0 tag) followed by a variable-length encoded size. + * This handshake request starts with a request header two-byte API key set to 17, followed by a mechanism name, + * making it easy to distinguish from a GSSAPI packet. + */ +public class SaslHandshakeRequest extends AbstractRequest { + + public static class Builder extends AbstractRequest.Builder { + private final SaslHandshakeRequestData data; + + public Builder(SaslHandshakeRequestData data) { + super(ApiKeys.SASL_HANDSHAKE); + this.data = data; + } + + @Override + public SaslHandshakeRequest build(short version) { + return new SaslHandshakeRequest(data, version); + } + + @Override + public String toString() { + return data.toString(); + } + } + + private final SaslHandshakeRequestData data; + + public SaslHandshakeRequest(SaslHandshakeRequestData data, short version) { + super(ApiKeys.SASL_HANDSHAKE, version); + this.data = data; + } + + @Override + public SaslHandshakeRequestData data() { + return data; + } + + @Override + public AbstractResponse getErrorResponse(int throttleTimeMs, Throwable e) { + SaslHandshakeResponseData response = new SaslHandshakeResponseData(); + response.setErrorCode(ApiError.fromThrowable(e).error().code()); + return new SaslHandshakeResponse(response); + } + + public static SaslHandshakeRequest parse(ByteBuffer buffer, short version) { + return new SaslHandshakeRequest(new SaslHandshakeRequestData(new ByteBufferAccessor(buffer), version), version); + } +} + diff --git a/clients/src/main/java/org/apache/kafka/common/requests/SaslHandshakeResponse.java b/clients/src/main/java/org/apache/kafka/common/requests/SaslHandshakeResponse.java new file mode 100644 index 0000000..63c047a --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/SaslHandshakeResponse.java @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.message.SaslHandshakeResponseData; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.protocol.Errors; + +import java.nio.ByteBuffer; +import java.util.List; +import java.util.Map; + +/** + * Response from SASL server which indicates if the client-chosen mechanism is enabled in the server. + * For error responses, the list of enabled mechanisms is included in the response. 
+ */ +public class SaslHandshakeResponse extends AbstractResponse { + + private final SaslHandshakeResponseData data; + + public SaslHandshakeResponse(SaslHandshakeResponseData data) { + super(ApiKeys.SASL_HANDSHAKE); + this.data = data; + } + + /* + * Possible error codes: + * UNSUPPORTED_SASL_MECHANISM(33): Client mechanism not enabled in server + * ILLEGAL_SASL_STATE(34) : Invalid request during SASL handshake + */ + public Errors error() { + return Errors.forCode(data.errorCode()); + } + + @Override + public Map errorCounts() { + return errorCounts(Errors.forCode(data.errorCode())); + } + + @Override + public int throttleTimeMs() { + return DEFAULT_THROTTLE_TIME; + } + + @Override + public SaslHandshakeResponseData data() { + return data; + } + + public List enabledMechanisms() { + return data.mechanisms(); + } + + public static SaslHandshakeResponse parse(ByteBuffer buffer, short version) { + return new SaslHandshakeResponse(new SaslHandshakeResponseData(new ByteBufferAccessor(buffer), version)); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/StopReplicaRequest.java b/clients/src/main/java/org/apache/kafka/common/requests/StopReplicaRequest.java new file mode 100644 index 0000000..4326aaf --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/StopReplicaRequest.java @@ -0,0 +1,217 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.message.StopReplicaRequestData; +import org.apache.kafka.common.message.StopReplicaRequestData.StopReplicaPartitionState; +import org.apache.kafka.common.message.StopReplicaRequestData.StopReplicaPartitionV0; +import org.apache.kafka.common.message.StopReplicaRequestData.StopReplicaTopicV1; +import org.apache.kafka.common.message.StopReplicaRequestData.StopReplicaTopicState; +import org.apache.kafka.common.message.StopReplicaResponseData; +import org.apache.kafka.common.message.StopReplicaResponseData.StopReplicaPartitionError; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.protocol.Errors; +import org.apache.kafka.common.utils.MappedIterator; +import org.apache.kafka.common.utils.Utils; + +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +public class StopReplicaRequest extends AbstractControlRequest { + + public static class Builder extends AbstractControlRequest.Builder<StopReplicaRequest> { + private final boolean deletePartitions; + private final List<StopReplicaTopicState> topicStates; + + public Builder(short version, int controllerId, int controllerEpoch, long brokerEpoch, + boolean deletePartitions, List<StopReplicaTopicState> topicStates) { + super(ApiKeys.STOP_REPLICA, version, controllerId, controllerEpoch, brokerEpoch); + this.deletePartitions = deletePartitions; + this.topicStates = topicStates; + } + + public StopReplicaRequest build(short version) { + StopReplicaRequestData data = new StopReplicaRequestData() + .setControllerId(controllerId) + .setControllerEpoch(controllerEpoch) + .setBrokerEpoch(brokerEpoch); + + if (version >= 3) { + data.setTopicStates(topicStates); + } else if (version >= 1) { + data.setDeletePartitions(deletePartitions); + List<StopReplicaTopicV1> topics = topicStates.stream().map(topic -> + new StopReplicaTopicV1() + .setName(topic.topicName()) + .setPartitionIndexes(topic.partitionStates().stream() + .map(StopReplicaPartitionState::partitionIndex) + .collect(Collectors.toList()))) + .collect(Collectors.toList()); + data.setTopics(topics); + } else { + data.setDeletePartitions(deletePartitions); + List<StopReplicaPartitionV0> partitions = topicStates.stream().flatMap(topic -> + topic.partitionStates().stream().map(partition -> + new StopReplicaPartitionV0() + .setTopicName(topic.topicName()) + .setPartitionIndex(partition.partitionIndex()))) + .collect(Collectors.toList()); + data.setUngroupedPartitions(partitions); + } + + return new StopReplicaRequest(data, version); + } + + @Override + public String toString() { + StringBuilder bld = new StringBuilder(); + bld.append("(type=StopReplicaRequest"). + append(", controllerId=").append(controllerId). + append(", controllerEpoch=").append(controllerEpoch). + append(", brokerEpoch=").append(brokerEpoch). + append(", deletePartitions=").append(deletePartitions). + append(", topicStates=").append(Utils.join(topicStates, ",")).
+ append(")"); + return bld.toString(); + } + } + + private final StopReplicaRequestData data; + + private StopReplicaRequest(StopReplicaRequestData data, short version) { + super(ApiKeys.STOP_REPLICA, version); + this.data = data; + } + + @Override + public StopReplicaResponse getErrorResponse(int throttleTimeMs, Throwable e) { + Errors error = Errors.forException(e); + + StopReplicaResponseData data = new StopReplicaResponseData(); + data.setErrorCode(error.code()); + + List partitions = new ArrayList<>(); + for (StopReplicaTopicState topic : topicStates()) { + for (StopReplicaPartitionState partition : topic.partitionStates()) { + partitions.add(new StopReplicaPartitionError() + .setTopicName(topic.topicName()) + .setPartitionIndex(partition.partitionIndex()) + .setErrorCode(error.code())); + } + } + data.setPartitionErrors(partitions); + + return new StopReplicaResponse(data); + } + + /** + * Note that this method has allocation overhead per iterated element, so callers should copy the result into + * another collection if they need to iterate more than once. + * + * Implementation note: we should strive to avoid allocation overhead per element, see + * `UpdateMetadataRequest.partitionStates()` for the preferred approach. That's not possible in this case and + * StopReplicaRequest should be relatively rare in comparison to other request types. + */ + public Iterable topicStates() { + if (version() < 1) { + Map topicStates = new HashMap<>(); + for (StopReplicaPartitionV0 partition : data.ungroupedPartitions()) { + StopReplicaTopicState topicState = topicStates.computeIfAbsent(partition.topicName(), + topic -> new StopReplicaTopicState().setTopicName(topic)); + topicState.partitionStates().add(new StopReplicaPartitionState() + .setPartitionIndex(partition.partitionIndex()) + .setDeletePartition(data.deletePartitions())); + } + return topicStates.values(); + } else if (version() < 3) { + return () -> new MappedIterator<>(data.topics().iterator(), topic -> + new StopReplicaTopicState() + .setTopicName(topic.name()) + .setPartitionStates(topic.partitionIndexes().stream() + .map(partition -> new StopReplicaPartitionState() + .setPartitionIndex(partition) + .setDeletePartition(data.deletePartitions())) + .collect(Collectors.toList()))); + } else { + return data.topicStates(); + } + } + + public Map partitionStates() { + Map partitionStates = new HashMap<>(); + + if (version() < 1) { + for (StopReplicaPartitionV0 partition : data.ungroupedPartitions()) { + partitionStates.put( + new TopicPartition(partition.topicName(), partition.partitionIndex()), + new StopReplicaPartitionState() + .setPartitionIndex(partition.partitionIndex()) + .setDeletePartition(data.deletePartitions())); + } + } else if (version() < 3) { + for (StopReplicaTopicV1 topic : data.topics()) { + for (Integer partitionIndex : topic.partitionIndexes()) { + partitionStates.put( + new TopicPartition(topic.name(), partitionIndex), + new StopReplicaPartitionState() + .setPartitionIndex(partitionIndex) + .setDeletePartition(data.deletePartitions())); + } + } + } else { + for (StopReplicaTopicState topicState : data.topicStates()) { + for (StopReplicaPartitionState partitionState: topicState.partitionStates()) { + partitionStates.put( + new TopicPartition(topicState.topicName(), partitionState.partitionIndex()), + partitionState); + } + } + } + + return partitionStates; + } + + @Override + public int controllerId() { + return data.controllerId(); + } + + @Override + public int controllerEpoch() { + return data.controllerEpoch(); + } 
+ + @Override + public long brokerEpoch() { + return data.brokerEpoch(); + } + + public static StopReplicaRequest parse(ByteBuffer buffer, short version) { + return new StopReplicaRequest(new StopReplicaRequestData(new ByteBufferAccessor(buffer), version), version); + } + + @Override + public StopReplicaRequestData data() { + return data; + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/StopReplicaResponse.java b/clients/src/main/java/org/apache/kafka/common/requests/StopReplicaResponse.java new file mode 100644 index 0000000..10ab153 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/StopReplicaResponse.java @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.message.StopReplicaResponseData; +import org.apache.kafka.common.message.StopReplicaResponseData.StopReplicaPartitionError; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.protocol.Errors; + +import java.nio.ByteBuffer; +import java.util.Collections; +import java.util.List; +import java.util.Map; + +public class StopReplicaResponse extends AbstractResponse { + + /** + * Possible error codes: + * - {@link Errors#STALE_CONTROLLER_EPOCH} + * - {@link Errors#STALE_BROKER_EPOCH} + * - {@link Errors#FENCED_LEADER_EPOCH} + * - {@link Errors#KAFKA_STORAGE_ERROR} + */ + private final StopReplicaResponseData data; + + public StopReplicaResponse(StopReplicaResponseData data) { + super(ApiKeys.STOP_REPLICA); + this.data = data; + } + + public List<StopReplicaPartitionError> partitionErrors() { + return data.partitionErrors(); + } + + public Errors error() { + return Errors.forCode(data.errorCode()); + } + + @Override + public Map<Errors, Integer> errorCounts() { + if (data.errorCode() != Errors.NONE.code()) + // Minor optimization since the top-level error applies to all partitions + return Collections.singletonMap(error(), data.partitionErrors().size() + 1); + Map<Errors, Integer> errors = errorCounts(data.partitionErrors().stream().map(p -> Errors.forCode(p.errorCode()))); + updateErrorCounts(errors, Errors.forCode(data.errorCode())); // top level error + return errors; + } + + public static StopReplicaResponse parse(ByteBuffer buffer, short version) { + return new StopReplicaResponse(new StopReplicaResponseData(new ByteBufferAccessor(buffer), version)); + } + + @Override + public int throttleTimeMs() { + return DEFAULT_THROTTLE_TIME; + } + + @Override + public StopReplicaResponseData data() { + return data; + } + + @Override + public String toString() { + return data.toString(); + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/SyncGroupRequest.java
b/clients/src/main/java/org/apache/kafka/common/requests/SyncGroupRequest.java new file mode 100644 index 0000000..8242b71 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/SyncGroupRequest.java @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.errors.UnsupportedVersionException; +import org.apache.kafka.common.message.SyncGroupRequestData; +import org.apache.kafka.common.message.SyncGroupResponseData; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.protocol.Errors; + +import java.nio.ByteBuffer; +import java.util.HashMap; +import java.util.Map; + +public class SyncGroupRequest extends AbstractRequest { + + public static class Builder extends AbstractRequest.Builder<SyncGroupRequest> { + + private final SyncGroupRequestData data; + + public Builder(SyncGroupRequestData data) { + super(ApiKeys.SYNC_GROUP); + this.data = data; + } + + @Override + public SyncGroupRequest build(short version) { + if (data.groupInstanceId() != null && version < 3) { + throw new UnsupportedVersionException("The broker sync group protocol version " + + version + " does not support usage of config group.instance.id."); + } + return new SyncGroupRequest(data, version); + } + + @Override + public String toString() { + return data.toString(); + } + } + + private final SyncGroupRequestData data; + + public SyncGroupRequest(SyncGroupRequestData data, short version) { + super(ApiKeys.SYNC_GROUP, version); + this.data = data; + } + + @Override + public AbstractResponse getErrorResponse(int throttleTimeMs, Throwable e) { + return new SyncGroupResponse(new SyncGroupResponseData() + .setErrorCode(Errors.forException(e).code()) + .setAssignment(new byte[0]) + .setThrottleTimeMs(throttleTimeMs)); + } + + public Map<String, ByteBuffer> groupAssignments() { + Map<String, ByteBuffer> groupAssignments = new HashMap<>(); + for (SyncGroupRequestData.SyncGroupRequestAssignment assignment : data.assignments()) { + groupAssignments.put(assignment.memberId(), ByteBuffer.wrap(assignment.assignment())); + } + return groupAssignments; + } + + /** + * ProtocolType and ProtocolName are mandatory since version 5. This method verifies that + * they are defined for version 5 or higher, and returns true unconditionally for older versions.
+ */ + public boolean areMandatoryProtocolTypeAndNamePresent() { + if (version() >= 5) + return data.protocolType() != null && data.protocolName() != null; + else + return true; + } + + public static SyncGroupRequest parse(ByteBuffer buffer, short version) { + return new SyncGroupRequest(new SyncGroupRequestData(new ByteBufferAccessor(buffer), version), version); + } + + @Override + public SyncGroupRequestData data() { + return data; + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/SyncGroupResponse.java b/clients/src/main/java/org/apache/kafka/common/requests/SyncGroupResponse.java new file mode 100644 index 0000000..822a3e7 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/SyncGroupResponse.java @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.message.SyncGroupResponseData; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.protocol.Errors; + +import java.nio.ByteBuffer; +import java.util.Map; + +public class SyncGroupResponse extends AbstractResponse { + + private final SyncGroupResponseData data; + + public SyncGroupResponse(SyncGroupResponseData data) { + super(ApiKeys.SYNC_GROUP); + this.data = data; + } + + @Override + public int throttleTimeMs() { + return data.throttleTimeMs(); + } + + public Errors error() { + return Errors.forCode(data.errorCode()); + } + + @Override + public Map errorCounts() { + return errorCounts(Errors.forCode(data.errorCode())); + } + + @Override + public SyncGroupResponseData data() { + return data; + } + + @Override + public String toString() { + return data.toString(); + } + + public static SyncGroupResponse parse(ByteBuffer buffer, short version) { + return new SyncGroupResponse(new SyncGroupResponseData(new ByteBufferAccessor(buffer), version)); + } + + @Override + public boolean shouldClientThrottle(short version) { + return version >= 2; + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/TransactionResult.java b/clients/src/main/java/org/apache/kafka/common/requests/TransactionResult.java new file mode 100644 index 0000000..d0448af --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/TransactionResult.java @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.requests; + +public enum TransactionResult { + ABORT(false), COMMIT(true); + + public final boolean id; + + TransactionResult(boolean id) { + this.id = id; + } + + public static TransactionResult forId(boolean id) { + if (id) { + return TransactionResult.COMMIT; + } + return TransactionResult.ABORT; + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/TxnOffsetCommitRequest.java b/clients/src/main/java/org/apache/kafka/common/requests/TxnOffsetCommitRequest.java new file mode 100644 index 0000000..27b13ba --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/TxnOffsetCommitRequest.java @@ -0,0 +1,223 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.errors.UnsupportedVersionException; +import org.apache.kafka.common.message.TxnOffsetCommitRequestData; +import org.apache.kafka.common.message.TxnOffsetCommitRequestData.TxnOffsetCommitRequestPartition; +import org.apache.kafka.common.message.TxnOffsetCommitRequestData.TxnOffsetCommitRequestTopic; +import org.apache.kafka.common.message.TxnOffsetCommitResponseData; +import org.apache.kafka.common.message.TxnOffsetCommitResponseData.TxnOffsetCommitResponsePartition; +import org.apache.kafka.common.message.TxnOffsetCommitResponseData.TxnOffsetCommitResponseTopic; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.protocol.Errors; +import org.apache.kafka.common.record.RecordBatch; + +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.stream.Collectors; + +public class TxnOffsetCommitRequest extends AbstractRequest { + + private final TxnOffsetCommitRequestData data; + + public static class Builder extends AbstractRequest.Builder { + + public final TxnOffsetCommitRequestData data; + + + public Builder(final String transactionalId, + final String consumerGroupId, + final long producerId, + final short producerEpoch, + final Map pendingTxnOffsetCommits) { + this(transactionalId, + consumerGroupId, + producerId, + producerEpoch, + pendingTxnOffsetCommits, + JoinGroupRequest.UNKNOWN_MEMBER_ID, + JoinGroupRequest.UNKNOWN_GENERATION_ID, + Optional.empty()); + } + + public Builder(final String transactionalId, + final String consumerGroupId, + final long producerId, + final short producerEpoch, + final Map pendingTxnOffsetCommits, + final String memberId, + final int generationId, + final Optional groupInstanceId) { + super(ApiKeys.TXN_OFFSET_COMMIT); + this.data = new TxnOffsetCommitRequestData() + .setTransactionalId(transactionalId) + .setGroupId(consumerGroupId) + .setProducerId(producerId) + .setProducerEpoch(producerEpoch) + .setTopics(getTopics(pendingTxnOffsetCommits)) + .setMemberId(memberId) + .setGenerationId(generationId) + .setGroupInstanceId(groupInstanceId.orElse(null)); + } + + @Override + public TxnOffsetCommitRequest build(short version) { + if (version < 3 && groupMetadataSet()) { + throw new UnsupportedVersionException("Broker doesn't support group metadata commit API on version " + version + + ", minimum supported request version is 3 which requires brokers to be on version 2.5 or above."); + } + return new TxnOffsetCommitRequest(data, version); + } + + private boolean groupMetadataSet() { + return !data.memberId().equals(JoinGroupRequest.UNKNOWN_MEMBER_ID) || + data.generationId() != JoinGroupRequest.UNKNOWN_GENERATION_ID || + data.groupInstanceId() != null; + } + + @Override + public String toString() { + return data.toString(); + } + } + + public TxnOffsetCommitRequest(TxnOffsetCommitRequestData data, short version) { + super(ApiKeys.TXN_OFFSET_COMMIT, version); + this.data = data; + } + + public Map offsets() { + List topics = data.topics(); + Map offsetMap = new HashMap<>(); + for (TxnOffsetCommitRequestTopic topic : topics) { + for (TxnOffsetCommitRequestPartition partition : topic.partitions()) { + offsetMap.put(new TopicPartition(topic.name(), partition.partitionIndex()), + new 
CommittedOffset(partition.committedOffset(), + partition.committedMetadata(), + RequestUtils.getLeaderEpoch(partition.committedLeaderEpoch())) + ); + } + } + return offsetMap; + } + + static List getTopics(Map pendingTxnOffsetCommits) { + Map> topicPartitionMap = new HashMap<>(); + for (Map.Entry entry : pendingTxnOffsetCommits.entrySet()) { + TopicPartition topicPartition = entry.getKey(); + CommittedOffset offset = entry.getValue(); + + List partitions = + topicPartitionMap.getOrDefault(topicPartition.topic(), new ArrayList<>()); + partitions.add(new TxnOffsetCommitRequestPartition() + .setPartitionIndex(topicPartition.partition()) + .setCommittedOffset(offset.offset) + .setCommittedLeaderEpoch(offset.leaderEpoch.orElse(RecordBatch.NO_PARTITION_LEADER_EPOCH)) + .setCommittedMetadata(offset.metadata) + ); + topicPartitionMap.put(topicPartition.topic(), partitions); + } + return topicPartitionMap.entrySet().stream() + .map(entry -> new TxnOffsetCommitRequestTopic() + .setName(entry.getKey()) + .setPartitions(entry.getValue())) + .collect(Collectors.toList()); + } + + @Override + public TxnOffsetCommitRequestData data() { + return data; + } + + static List getErrorResponseTopics(List requestTopics, + Errors e) { + List responseTopicData = new ArrayList<>(); + for (TxnOffsetCommitRequestTopic entry : requestTopics) { + List responsePartitions = new ArrayList<>(); + for (TxnOffsetCommitRequestPartition requestPartition : entry.partitions()) { + responsePartitions.add(new TxnOffsetCommitResponsePartition() + .setPartitionIndex(requestPartition.partitionIndex()) + .setErrorCode(e.code())); + } + responseTopicData.add(new TxnOffsetCommitResponseTopic() + .setName(entry.name()) + .setPartitions(responsePartitions) + ); + } + return responseTopicData; + } + + @Override + public TxnOffsetCommitResponse getErrorResponse(int throttleTimeMs, Throwable e) { + List responseTopicData = + getErrorResponseTopics(data.topics(), Errors.forException(e)); + + return new TxnOffsetCommitResponse(new TxnOffsetCommitResponseData() + .setThrottleTimeMs(throttleTimeMs) + .setTopics(responseTopicData)); + } + + public static TxnOffsetCommitRequest parse(ByteBuffer buffer, short version) { + return new TxnOffsetCommitRequest(new TxnOffsetCommitRequestData( + new ByteBufferAccessor(buffer), version), version); + } + + public static class CommittedOffset { + public final long offset; + public final String metadata; + public final Optional leaderEpoch; + + public CommittedOffset(long offset, String metadata, Optional leaderEpoch) { + this.offset = offset; + this.metadata = metadata; + this.leaderEpoch = leaderEpoch; + } + + @Override + public String toString() { + return "CommittedOffset(" + + "offset=" + offset + + ", leaderEpoch=" + leaderEpoch + + ", metadata='" + metadata + "')"; + } + + @Override + public boolean equals(Object other) { + if (!(other instanceof CommittedOffset)) { + return false; + } + CommittedOffset otherOffset = (CommittedOffset) other; + + return this.offset == otherOffset.offset + && this.leaderEpoch.equals(otherOffset.leaderEpoch) + && Objects.equals(this.metadata, otherOffset.metadata); + } + + @Override + public int hashCode() { + return Objects.hash(offset, leaderEpoch, metadata); + } + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/TxnOffsetCommitResponse.java b/clients/src/main/java/org/apache/kafka/common/requests/TxnOffsetCommitResponse.java new file mode 100644 index 0000000..b4de547 --- /dev/null +++ 
b/clients/src/main/java/org/apache/kafka/common/requests/TxnOffsetCommitResponse.java @@ -0,0 +1,121 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.message.TxnOffsetCommitResponseData; +import org.apache.kafka.common.message.TxnOffsetCommitResponseData.TxnOffsetCommitResponsePartition; +import org.apache.kafka.common.message.TxnOffsetCommitResponseData.TxnOffsetCommitResponseTopic; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.protocol.Errors; + +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.Map; + +/** + * Possible error codes: + * + * - {@link Errors#INVALID_PRODUCER_EPOCH} + * - {@link Errors#NOT_COORDINATOR} + * - {@link Errors#COORDINATOR_NOT_AVAILABLE} + * - {@link Errors#COORDINATOR_LOAD_IN_PROGRESS} + * - {@link Errors#OFFSET_METADATA_TOO_LARGE} + * - {@link Errors#GROUP_AUTHORIZATION_FAILED} + * - {@link Errors#INVALID_COMMIT_OFFSET_SIZE} + * - {@link Errors#TRANSACTIONAL_ID_AUTHORIZATION_FAILED} + * - {@link Errors#REQUEST_TIMED_OUT} + * - {@link Errors#UNKNOWN_MEMBER_ID} + * - {@link Errors#FENCED_INSTANCE_ID} + * - {@link Errors#ILLEGAL_GENERATION} + */ +public class TxnOffsetCommitResponse extends AbstractResponse { + + private final TxnOffsetCommitResponseData data; + + public TxnOffsetCommitResponse(TxnOffsetCommitResponseData data) { + super(ApiKeys.TXN_OFFSET_COMMIT); + this.data = data; + } + + public TxnOffsetCommitResponse(int requestThrottleMs, Map responseData) { + super(ApiKeys.TXN_OFFSET_COMMIT); + Map responseTopicDataMap = new HashMap<>(); + + for (Map.Entry entry : responseData.entrySet()) { + TopicPartition topicPartition = entry.getKey(); + String topicName = topicPartition.topic(); + + TxnOffsetCommitResponseTopic topic = responseTopicDataMap.getOrDefault( + topicName, new TxnOffsetCommitResponseTopic().setName(topicName)); + + topic.partitions().add(new TxnOffsetCommitResponsePartition() + .setErrorCode(entry.getValue().code()) + .setPartitionIndex(topicPartition.partition()) + ); + responseTopicDataMap.put(topicName, topic); + } + + data = new TxnOffsetCommitResponseData() + .setTopics(new ArrayList<>(responseTopicDataMap.values())) + .setThrottleTimeMs(requestThrottleMs); + } + + @Override + public TxnOffsetCommitResponseData data() { + return data; + } + + @Override + public int throttleTimeMs() { + return data.throttleTimeMs(); + } + + @Override + public Map errorCounts() { + return errorCounts(data.topics().stream().flatMap(topic -> + topic.partitions().stream().map(partition -> + Errors.forCode(partition.errorCode())))); + } + 
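Illustrative sketch (assumptions noted): the snippet below shows one way the TxnOffsetCommit request/response classes added in this patch could be exercised together, building a request from a map of pending offsets and letting getErrorResponse fan a single client-side exception out to every partition. The topic, offset, version and exception values are made up for illustration, and the generic parameters (Map<TopicPartition, CommittedOffset>, Optional<Integer>) are assumed from the surrounding code, since type arguments are not visible in this rendering of the diff.

import org.apache.kafka.common.TopicPartition;
import org.apache.kafka.common.errors.NotCoordinatorException;
import org.apache.kafka.common.protocol.Errors;
import org.apache.kafka.common.requests.TxnOffsetCommitRequest;
import org.apache.kafka.common.requests.TxnOffsetCommitRequest.CommittedOffset;
import org.apache.kafka.common.requests.TxnOffsetCommitResponse;

import java.util.HashMap;
import java.util.Map;
import java.util.Optional;

public class TxnOffsetCommitSketch {
    public static void main(String[] args) {
        // Offsets to commit as part of the transaction, keyed by partition (illustrative values).
        Map<TopicPartition, CommittedOffset> pending = new HashMap<>();
        pending.put(new TopicPartition("payments", 0),
                new CommittedOffset(42L, "checkpoint", Optional.of(5)));

        // The five-argument Builder leaves member id, generation and group.instance.id
        // at their "unknown" defaults, so any request version may be built.
        TxnOffsetCommitRequest request = new TxnOffsetCommitRequest.Builder(
                "txn-1", "group-1", 1000L, (short) 1, pending).build((short) 3);

        // getErrorResponse maps one exception onto every partition of the request.
        TxnOffsetCommitResponse response =
                request.getErrorResponse(0, new NotCoordinatorException("coordinator moved"));
        Map<TopicPartition, Errors> errors = response.errors();
        System.out.println(errors); // expected: {payments-0=NOT_COORDINATOR}
    }
}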
+ public Map errors() { + Map errorMap = new HashMap<>(); + for (TxnOffsetCommitResponseTopic topic : data.topics()) { + for (TxnOffsetCommitResponsePartition partition : topic.partitions()) { + errorMap.put(new TopicPartition(topic.name(), partition.partitionIndex()), + Errors.forCode(partition.errorCode())); + } + } + return errorMap; + } + + public static TxnOffsetCommitResponse parse(ByteBuffer buffer, short version) { + return new TxnOffsetCommitResponse(new TxnOffsetCommitResponseData(new ByteBufferAccessor(buffer), version)); + } + + @Override + public String toString() { + return data.toString(); + } + + @Override + public boolean shouldClientThrottle(short version) { + return version >= 1; + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/UnregisterBrokerRequest.java b/clients/src/main/java/org/apache/kafka/common/requests/UnregisterBrokerRequest.java new file mode 100644 index 0000000..253499f --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/UnregisterBrokerRequest.java @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.message.UnregisterBrokerRequestData; +import org.apache.kafka.common.message.UnregisterBrokerResponseData; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.protocol.Errors; + +import java.nio.ByteBuffer; + +public class UnregisterBrokerRequest extends AbstractRequest { + + public static class Builder extends AbstractRequest.Builder { + private final UnregisterBrokerRequestData data; + + public Builder(UnregisterBrokerRequestData data) { + super(ApiKeys.UNREGISTER_BROKER); + this.data = data; + } + + @Override + public UnregisterBrokerRequest build(short version) { + return new UnregisterBrokerRequest(data, version); + } + } + + private final UnregisterBrokerRequestData data; + + public UnregisterBrokerRequest(UnregisterBrokerRequestData data, short version) { + super(ApiKeys.UNREGISTER_BROKER, version); + this.data = data; + } + + @Override + public UnregisterBrokerRequestData data() { + return data; + } + + @Override + public UnregisterBrokerResponse getErrorResponse(int throttleTimeMs, Throwable e) { + Errors error = Errors.forException(e); + return new UnregisterBrokerResponse(new UnregisterBrokerResponseData() + .setThrottleTimeMs(throttleTimeMs) + .setErrorCode(error.code())); + } + + public static UnregisterBrokerRequest parse(ByteBuffer buffer, short version) { + return new UnregisterBrokerRequest(new UnregisterBrokerRequestData(new ByteBufferAccessor(buffer), version), + version); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/UnregisterBrokerResponse.java b/clients/src/main/java/org/apache/kafka/common/requests/UnregisterBrokerResponse.java new file mode 100644 index 0000000..b508ac3 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/UnregisterBrokerResponse.java @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.message.UnregisterBrokerResponseData; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.protocol.Errors; + +import java.nio.ByteBuffer; +import java.util.HashMap; +import java.util.Map; + +public class UnregisterBrokerResponse extends AbstractResponse { + private final UnregisterBrokerResponseData data; + + public UnregisterBrokerResponse(UnregisterBrokerResponseData data) { + super(ApiKeys.UNREGISTER_BROKER); + this.data = data; + } + + @Override + public UnregisterBrokerResponseData data() { + return data; + } + + @Override + public int throttleTimeMs() { + return data.throttleTimeMs(); + } + + @Override + public Map errorCounts() { + Map errorCounts = new HashMap<>(); + if (data.errorCode() != 0) { + errorCounts.put(Errors.forCode(data.errorCode()), 1); + } + return errorCounts; + } + + public static UnregisterBrokerResponse parse(ByteBuffer buffer, short version) { + return new UnregisterBrokerResponse(new UnregisterBrokerResponseData(new ByteBufferAccessor(buffer), version)); + } + + @Override + public boolean shouldClientThrottle(short version) { + return true; + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/UpdateFeaturesRequest.java b/clients/src/main/java/org/apache/kafka/common/requests/UpdateFeaturesRequest.java new file mode 100644 index 0000000..7a6bf66 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/UpdateFeaturesRequest.java @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.message.UpdateFeaturesRequestData; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; + +import java.nio.ByteBuffer; +import java.util.Collections; + +public class UpdateFeaturesRequest extends AbstractRequest { + + public static class Builder extends AbstractRequest.Builder { + + private final UpdateFeaturesRequestData data; + + public Builder(UpdateFeaturesRequestData data) { + super(ApiKeys.UPDATE_FEATURES); + this.data = data; + } + + @Override + public UpdateFeaturesRequest build(short version) { + return new UpdateFeaturesRequest(data, version); + } + + @Override + public String toString() { + return data.toString(); + } + } + + private final UpdateFeaturesRequestData data; + + public UpdateFeaturesRequest(UpdateFeaturesRequestData data, short version) { + super(ApiKeys.UPDATE_FEATURES, version); + this.data = data; + } + + @Override + public UpdateFeaturesResponse getErrorResponse(int throttleTimeMs, Throwable e) { + return UpdateFeaturesResponse.createWithErrors( + ApiError.fromThrowable(e), + Collections.emptyMap(), + throttleTimeMs + ); + } + + @Override + public UpdateFeaturesRequestData data() { + return data; + } + + public static UpdateFeaturesRequest parse(ByteBuffer buffer, short version) { + return new UpdateFeaturesRequest(new UpdateFeaturesRequestData(new ByteBufferAccessor(buffer), version), version); + } + + public static boolean isDeleteRequest(UpdateFeaturesRequestData.FeatureUpdateKey update) { + return update.maxVersionLevel() < 1 && update.allowDowngrade(); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/UpdateFeaturesResponse.java b/clients/src/main/java/org/apache/kafka/common/requests/UpdateFeaturesResponse.java new file mode 100644 index 0000000..26825a0 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/UpdateFeaturesResponse.java @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
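Illustrative sketch (assumptions noted): isDeleteRequest above treats an update whose maxVersionLevel is below 1 and whose allowDowngrade flag is set as a request to delete the finalized feature. The snippet assumes the generated UpdateFeaturesRequestData.FeatureUpdateKey exposes the usual generated-message chained setters (setFeature, setMaxVersionLevel, setAllowDowngrade), which are not shown in this hunk; the feature name and levels are made up.

import org.apache.kafka.common.message.UpdateFeaturesRequestData.FeatureUpdateKey;
import org.apache.kafka.common.requests.UpdateFeaturesRequest;

public class FeatureDeleteCheckSketch {
    public static void main(String[] args) {
        // maxVersionLevel < 1 together with allowDowngrade marks a deletion.
        FeatureUpdateKey delete = new FeatureUpdateKey()
                .setFeature("example.feature")
                .setMaxVersionLevel((short) 0)
                .setAllowDowngrade(true);
        System.out.println(UpdateFeaturesRequest.isDeleteRequest(delete)); // true

        // A plain upgrade to level 2 is not a deletion.
        FeatureUpdateKey upgrade = new FeatureUpdateKey()
                .setFeature("example.feature")
                .setMaxVersionLevel((short) 2);
        System.out.println(UpdateFeaturesRequest.isDeleteRequest(upgrade)); // false
    }
}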
+ */ +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.message.UpdateFeaturesResponseData; +import org.apache.kafka.common.message.UpdateFeaturesResponseData.UpdatableFeatureResult; +import org.apache.kafka.common.message.UpdateFeaturesResponseData.UpdatableFeatureResultCollection; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.protocol.Errors; + +import java.nio.ByteBuffer; +import java.util.HashMap; +import java.util.Map; + +/** + * Possible error codes: + * + * - {@link Errors#CLUSTER_AUTHORIZATION_FAILED} + * - {@link Errors#NOT_CONTROLLER} + * - {@link Errors#INVALID_REQUEST} + * - {@link Errors#FEATURE_UPDATE_FAILED} + */ +public class UpdateFeaturesResponse extends AbstractResponse { + + private final UpdateFeaturesResponseData data; + + public UpdateFeaturesResponse(UpdateFeaturesResponseData data) { + super(ApiKeys.UPDATE_FEATURES); + this.data = data; + } + + public ApiError topLevelError() { + return new ApiError(Errors.forCode(data.errorCode()), data.errorMessage()); + } + + @Override + public Map errorCounts() { + Map errorCounts = new HashMap<>(); + updateErrorCounts(errorCounts, Errors.forCode(data.errorCode())); + for (UpdatableFeatureResult result : data.results()) { + updateErrorCounts(errorCounts, Errors.forCode(result.errorCode())); + } + return errorCounts; + } + + @Override + public int throttleTimeMs() { + return data.throttleTimeMs(); + } + + @Override + public String toString() { + return data.toString(); + } + + @Override + public UpdateFeaturesResponseData data() { + return data; + } + + public static UpdateFeaturesResponse parse(ByteBuffer buffer, short version) { + return new UpdateFeaturesResponse(new UpdateFeaturesResponseData(new ByteBufferAccessor(buffer), version)); + } + + public static UpdateFeaturesResponse createWithErrors(ApiError topLevelError, Map updateErrors, int throttleTimeMs) { + final UpdatableFeatureResultCollection results = new UpdatableFeatureResultCollection(); + for (final Map.Entry updateError : updateErrors.entrySet()) { + final String feature = updateError.getKey(); + final ApiError error = updateError.getValue(); + final UpdatableFeatureResult result = new UpdatableFeatureResult(); + result.setFeature(feature) + .setErrorCode(error.error().code()) + .setErrorMessage(error.message()); + results.add(result); + } + final UpdateFeaturesResponseData responseData = new UpdateFeaturesResponseData() + .setThrottleTimeMs(throttleTimeMs) + .setErrorCode(topLevelError.error().code()) + .setErrorMessage(topLevelError.message()) + .setResults(results) + .setThrottleTimeMs(throttleTimeMs); + return new UpdateFeaturesResponse(responseData); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/UpdateMetadataRequest.java b/clients/src/main/java/org/apache/kafka/common/requests/UpdateMetadataRequest.java new file mode 100644 index 0000000..845bdd9 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/UpdateMetadataRequest.java @@ -0,0 +1,227 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.Uuid; +import org.apache.kafka.common.errors.UnsupportedVersionException; +import org.apache.kafka.common.message.UpdateMetadataRequestData; +import org.apache.kafka.common.message.UpdateMetadataRequestData.UpdateMetadataBroker; +import org.apache.kafka.common.message.UpdateMetadataRequestData.UpdateMetadataEndpoint; +import org.apache.kafka.common.message.UpdateMetadataRequestData.UpdateMetadataPartitionState; +import org.apache.kafka.common.message.UpdateMetadataRequestData.UpdateMetadataTopicState; +import org.apache.kafka.common.message.UpdateMetadataResponseData; +import org.apache.kafka.common.network.ListenerName; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.protocol.Errors; +import org.apache.kafka.common.security.auth.SecurityProtocol; +import org.apache.kafka.common.utils.FlattenedIterator; +import org.apache.kafka.common.utils.Utils; + +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static java.util.Collections.singletonList; + +public class UpdateMetadataRequest extends AbstractControlRequest { + + public static class Builder extends AbstractControlRequest.Builder { + private final List partitionStates; + private final List liveBrokers; + private final Map topicIds; + + public Builder(short version, int controllerId, int controllerEpoch, long brokerEpoch, + List partitionStates, List liveBrokers, + Map topicIds) { + super(ApiKeys.UPDATE_METADATA, version, controllerId, controllerEpoch, brokerEpoch); + this.partitionStates = partitionStates; + this.liveBrokers = liveBrokers; + this.topicIds = topicIds; + } + + @Override + public UpdateMetadataRequest build(short version) { + if (version < 3) { + for (UpdateMetadataBroker broker : liveBrokers) { + if (version == 0) { + if (broker.endpoints().size() != 1) + throw new UnsupportedVersionException("UpdateMetadataRequest v0 requires a single endpoint"); + if (broker.endpoints().get(0).securityProtocol() != SecurityProtocol.PLAINTEXT.id) + throw new UnsupportedVersionException("UpdateMetadataRequest v0 only handles PLAINTEXT endpoints"); + // Don't null out `endpoints` since it's ignored by the generated code if version >= 1 + UpdateMetadataEndpoint endpoint = broker.endpoints().get(0); + broker.setV0Host(endpoint.host()); + broker.setV0Port(endpoint.port()); + } else { + if (broker.endpoints().stream().anyMatch(endpoint -> !endpoint.listener().isEmpty() && + !endpoint.listener().equals(listenerNameFromSecurityProtocol(endpoint)))) { + throw new UnsupportedVersionException("UpdateMetadataRequest v0-v3 does not support custom " + + "listeners, request version: " + version + ", endpoints: " + broker.endpoints()); + } + } + } + } + + UpdateMetadataRequestData data = new UpdateMetadataRequestData() + .setControllerId(controllerId) + .setControllerEpoch(controllerEpoch) + .setBrokerEpoch(brokerEpoch) + 
.setLiveBrokers(liveBrokers); + + if (version >= 5) { + Map topicStatesMap = groupByTopic(topicIds, partitionStates); + data.setTopicStates(new ArrayList<>(topicStatesMap.values())); + } else { + data.setUngroupedPartitionStates(partitionStates); + } + + return new UpdateMetadataRequest(data, version); + } + + private static Map groupByTopic(Map topicIds, List partitionStates) { + Map topicStates = new HashMap<>(); + for (UpdateMetadataPartitionState partition : partitionStates) { + // We don't null out the topic name in UpdateMetadataPartitionState since it's ignored by the generated + // code if version >= 5 + UpdateMetadataTopicState topicState = topicStates.computeIfAbsent(partition.topicName(), + t -> new UpdateMetadataTopicState() + .setTopicName(partition.topicName()) + .setTopicId(topicIds.getOrDefault(partition.topicName(), Uuid.ZERO_UUID)) + + ); + topicState.partitionStates().add(partition); + } + return topicStates; + } + + @Override + public String toString() { + StringBuilder bld = new StringBuilder(); + bld.append("(type: UpdateMetadataRequest="). + append(", controllerId=").append(controllerId). + append(", controllerEpoch=").append(controllerEpoch). + append(", brokerEpoch=").append(brokerEpoch). + append(", partitionStates=").append(partitionStates). + append(", liveBrokers=").append(Utils.join(liveBrokers, ", ")). + append(")"); + return bld.toString(); + } + } + + private final UpdateMetadataRequestData data; + + UpdateMetadataRequest(UpdateMetadataRequestData data, short version) { + super(ApiKeys.UPDATE_METADATA, version); + this.data = data; + // Do this from the constructor to make it thread-safe (even though it's only needed when some methods are called) + normalize(); + } + + private void normalize() { + // Version 0 only supported a single host and port and the protocol was always plaintext + // Version 1 added support for multiple endpoints, each with its own security protocol + // Version 2 added support for rack + // Version 3 added support for listener name, which we can infer from the security protocol for older versions + if (version() < 3) { + for (UpdateMetadataBroker liveBroker : data.liveBrokers()) { + // Set endpoints so that callers can rely on it always being present + if (version() == 0 && liveBroker.endpoints().isEmpty()) { + SecurityProtocol securityProtocol = SecurityProtocol.PLAINTEXT; + liveBroker.setEndpoints(singletonList(new UpdateMetadataEndpoint() + .setHost(liveBroker.v0Host()) + .setPort(liveBroker.v0Port()) + .setSecurityProtocol(securityProtocol.id) + .setListener(ListenerName.forSecurityProtocol(securityProtocol).value()))); + } else { + for (UpdateMetadataEndpoint endpoint : liveBroker.endpoints()) { + // Set listener so that callers can rely on it always being present + if (endpoint.listener().isEmpty()) + endpoint.setListener(listenerNameFromSecurityProtocol(endpoint)); + } + } + } + } + + if (version() >= 5) { + for (UpdateMetadataTopicState topicState : data.topicStates()) { + for (UpdateMetadataPartitionState partitionState : topicState.partitionStates()) { + // Set the topic name so that we can always present the ungrouped view to callers + partitionState.setTopicName(topicState.topicName()); + } + } + } + } + + private static String listenerNameFromSecurityProtocol(UpdateMetadataEndpoint endpoint) { + SecurityProtocol securityProtocol = SecurityProtocol.forId(endpoint.securityProtocol()); + return ListenerName.forSecurityProtocol(securityProtocol).value(); + } + + @Override + public int controllerId() { + return 
data.controllerId(); + } + + @Override + public int controllerEpoch() { + return data.controllerEpoch(); + } + + @Override + public long brokerEpoch() { + return data.brokerEpoch(); + } + + @Override + public UpdateMetadataResponse getErrorResponse(int throttleTimeMs, Throwable e) { + UpdateMetadataResponseData data = new UpdateMetadataResponseData() + .setErrorCode(Errors.forException(e).code()); + return new UpdateMetadataResponse(data); + } + + public Iterable partitionStates() { + if (version() >= 5) { + return () -> new FlattenedIterator<>(data.topicStates().iterator(), + topicState -> topicState.partitionStates().iterator()); + } + return data.ungroupedPartitionStates(); + } + + public List topicStates() { + if (version() >= 5) { + return data.topicStates(); + } + return Collections.emptyList(); + } + + public List liveBrokers() { + return data.liveBrokers(); + } + + @Override + public UpdateMetadataRequestData data() { + return data; + } + + public static UpdateMetadataRequest parse(ByteBuffer buffer, short version) { + return new UpdateMetadataRequest(new UpdateMetadataRequestData(new ByteBufferAccessor(buffer), version), version); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/UpdateMetadataResponse.java b/clients/src/main/java/org/apache/kafka/common/requests/UpdateMetadataResponse.java new file mode 100644 index 0000000..cc7749a --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/UpdateMetadataResponse.java @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.message.UpdateMetadataResponseData; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.protocol.Errors; + +import java.nio.ByteBuffer; +import java.util.Map; + +public class UpdateMetadataResponse extends AbstractResponse { + + private final UpdateMetadataResponseData data; + + public UpdateMetadataResponse(UpdateMetadataResponseData data) { + super(ApiKeys.UPDATE_METADATA); + this.data = data; + } + + public Errors error() { + return Errors.forCode(data.errorCode()); + } + + @Override + public Map errorCounts() { + return errorCounts(error()); + } + + @Override + public int throttleTimeMs() { + return DEFAULT_THROTTLE_TIME; + } + + public static UpdateMetadataResponse parse(ByteBuffer buffer, short version) { + return new UpdateMetadataResponse(new UpdateMetadataResponseData(new ByteBufferAccessor(buffer), version)); + } + + @Override + public UpdateMetadataResponseData data() { + return data; + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/VoteRequest.java b/clients/src/main/java/org/apache/kafka/common/requests/VoteRequest.java new file mode 100644 index 0000000..8fba2f0 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/VoteRequest.java @@ -0,0 +1,106 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.message.VoteRequestData; +import org.apache.kafka.common.message.VoteResponseData; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.protocol.Errors; + +import java.nio.ByteBuffer; +import java.util.Collections; + +public class VoteRequest extends AbstractRequest { + + public static class Builder extends AbstractRequest.Builder { + private final VoteRequestData data; + + public Builder(VoteRequestData data) { + super(ApiKeys.VOTE); + this.data = data; + } + + @Override + public VoteRequest build(short version) { + return new VoteRequest(data, version); + } + + @Override + public String toString() { + return data.toString(); + } + } + + private final VoteRequestData data; + + private VoteRequest(VoteRequestData data, short version) { + super(ApiKeys.VOTE, version); + this.data = data; + } + + @Override + public VoteRequestData data() { + return data; + } + + @Override + public AbstractResponse getErrorResponse(int throttleTimeMs, Throwable e) { + return new VoteResponse(new VoteResponseData() + .setErrorCode(Errors.forException(e).code())); + } + + public static VoteRequest parse(ByteBuffer buffer, short version) { + return new VoteRequest(new VoteRequestData(new ByteBufferAccessor(buffer), version), version); + } + + public static VoteRequestData singletonRequest(TopicPartition topicPartition, + int candidateEpoch, + int candidateId, + int lastEpoch, + long lastEpochEndOffset) { + return singletonRequest(topicPartition, + null, + candidateEpoch, + candidateId, + lastEpoch, + lastEpochEndOffset); + } + + public static VoteRequestData singletonRequest(TopicPartition topicPartition, + String clusterId, + int candidateEpoch, + int candidateId, + int lastEpoch, + long lastEpochEndOffset) { + return new VoteRequestData() + .setClusterId(clusterId) + .setTopics(Collections.singletonList( + new VoteRequestData.TopicData() + .setTopicName(topicPartition.topic()) + .setPartitions(Collections.singletonList( + new VoteRequestData.PartitionData() + .setPartitionIndex(topicPartition.partition()) + .setCandidateEpoch(candidateEpoch) + .setCandidateId(candidateId) + .setLastOffsetEpoch(lastEpoch) + .setLastOffset(lastEpochEndOffset)) + ))); + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/VoteResponse.java b/clients/src/main/java/org/apache/kafka/common/requests/VoteResponse.java new file mode 100644 index 0000000..51991ad --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/VoteResponse.java @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.message.VoteResponseData; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.protocol.Errors; + +import java.nio.ByteBuffer; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +/** + * Possible error codes. + * + * Top level errors: + * - {@link Errors#CLUSTER_AUTHORIZATION_FAILED} + * - {@link Errors#BROKER_NOT_AVAILABLE} + * + * Partition level errors: + * - {@link Errors#FENCED_LEADER_EPOCH} + * - {@link Errors#INVALID_REQUEST} + * - {@link Errors#INCONSISTENT_VOTER_SET} + * - {@link Errors#UNKNOWN_TOPIC_OR_PARTITION} + */ +public class VoteResponse extends AbstractResponse { + private final VoteResponseData data; + + public VoteResponse(VoteResponseData data) { + super(ApiKeys.VOTE); + this.data = data; + } + + public static VoteResponseData singletonResponse(Errors topLevelError, + TopicPartition topicPartition, + Errors partitionLevelError, + int leaderEpoch, + int leaderId, + boolean voteGranted) { + return new VoteResponseData() + .setErrorCode(topLevelError.code()) + .setTopics(Collections.singletonList( + new VoteResponseData.TopicData() + .setTopicName(topicPartition.topic()) + .setPartitions(Collections.singletonList( + new VoteResponseData.PartitionData() + .setErrorCode(partitionLevelError.code()) + .setLeaderId(leaderId) + .setLeaderEpoch(leaderEpoch) + .setVoteGranted(voteGranted))))); + } + + @Override + public Map errorCounts() { + Map errors = new HashMap<>(); + + errors.put(Errors.forCode(data.errorCode()), 1); + + for (VoteResponseData.TopicData topicResponse : data.topics()) { + for (VoteResponseData.PartitionData partitionResponse : topicResponse.partitions()) { + updateErrorCounts(errors, Errors.forCode(partitionResponse.errorCode())); + } + } + return errors; + } + + @Override + public VoteResponseData data() { + return data; + } + + @Override + public int throttleTimeMs() { + return DEFAULT_THROTTLE_TIME; + } + + public static VoteResponse parse(ByteBuffer buffer, short version) { + return new VoteResponse(new VoteResponseData(new ByteBufferAccessor(buffer), version)); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/WriteTxnMarkersRequest.java b/clients/src/main/java/org/apache/kafka/common/requests/WriteTxnMarkersRequest.java new file mode 100644 index 0000000..73a6b6c --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/WriteTxnMarkersRequest.java @@ -0,0 +1,208 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.message.WriteTxnMarkersRequestData; +import org.apache.kafka.common.message.WriteTxnMarkersRequestData.WritableTxnMarker; +import org.apache.kafka.common.message.WriteTxnMarkersRequestData.WritableTxnMarkerTopic; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.protocol.Errors; + +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; + +public class WriteTxnMarkersRequest extends AbstractRequest { + + public static class TxnMarkerEntry { + private final long producerId; + private final short producerEpoch; + private final int coordinatorEpoch; + private final TransactionResult result; + private final List partitions; + + public TxnMarkerEntry(long producerId, + short producerEpoch, + int coordinatorEpoch, + TransactionResult result, + List partitions) { + this.producerId = producerId; + this.producerEpoch = producerEpoch; + this.coordinatorEpoch = coordinatorEpoch; + this.result = result; + this.partitions = partitions; + } + + public long producerId() { + return producerId; + } + + public short producerEpoch() { + return producerEpoch; + } + + public int coordinatorEpoch() { + return coordinatorEpoch; + } + + public TransactionResult transactionResult() { + return result; + } + + public List partitions() { + return partitions; + } + + @Override + public String toString() { + return "TxnMarkerEntry{" + + "producerId=" + producerId + + ", producerEpoch=" + producerEpoch + + ", coordinatorEpoch=" + coordinatorEpoch + + ", result=" + result + + ", partitions=" + partitions + + '}'; + } + + @Override + public boolean equals(final Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + final TxnMarkerEntry that = (TxnMarkerEntry) o; + return producerId == that.producerId && + producerEpoch == that.producerEpoch && + coordinatorEpoch == that.coordinatorEpoch && + result == that.result && + Objects.equals(partitions, that.partitions); + } + + @Override + public int hashCode() { + return Objects.hash(producerId, producerEpoch, coordinatorEpoch, result, partitions); + } + } + + public static class Builder extends AbstractRequest.Builder { + + public final WriteTxnMarkersRequestData data; + + public Builder(WriteTxnMarkersRequestData data) { + super(ApiKeys.WRITE_TXN_MARKERS); + this.data = data; + } + + public Builder(short version, final List markers) { + super(ApiKeys.WRITE_TXN_MARKERS, version); + List dataMarkers = new ArrayList<>(); + for (TxnMarkerEntry marker : markers) { + final Map topicMap = new HashMap<>(); + for (TopicPartition topicPartition : marker.partitions) { + WritableTxnMarkerTopic topic = topicMap.getOrDefault(topicPartition.topic(), + new WritableTxnMarkerTopic() + .setName(topicPartition.topic())); + topic.partitionIndexes().add(topicPartition.partition()); + topicMap.put(topicPartition.topic(), topic); + } + + dataMarkers.add(new WritableTxnMarker() + .setProducerId(marker.producerId) + .setProducerEpoch(marker.producerEpoch) + .setCoordinatorEpoch(marker.coordinatorEpoch) + .setTransactionResult(marker.transactionResult().id) + .setTopics(new ArrayList<>(topicMap.values()))); + } + this.data = new WriteTxnMarkersRequestData().setMarkers(dataMarkers); + } + + @Override + public WriteTxnMarkersRequest 
build(short version) { + return new WriteTxnMarkersRequest(data, version); + } + } + + private final WriteTxnMarkersRequestData data; + + private WriteTxnMarkersRequest(WriteTxnMarkersRequestData data, short version) { + super(ApiKeys.WRITE_TXN_MARKERS, version); + this.data = data; + } + + @Override + public WriteTxnMarkersRequestData data() { + return data; + } + + @Override + public WriteTxnMarkersResponse getErrorResponse(int throttleTimeMs, Throwable e) { + Errors error = Errors.forException(e); + + final Map> errors = new HashMap<>(data.markers().size()); + for (WritableTxnMarker markerEntry : data.markers()) { + Map errorsPerPartition = new HashMap<>(); + for (WritableTxnMarkerTopic topic : markerEntry.topics()) { + for (Integer partitionIdx : topic.partitionIndexes()) { + errorsPerPartition.put(new TopicPartition(topic.name(), partitionIdx), error); + } + } + errors.put(markerEntry.producerId(), errorsPerPartition); + } + + return new WriteTxnMarkersResponse(errors); + } + + public List markers() { + List markers = new ArrayList<>(); + for (WritableTxnMarker markerEntry : data.markers()) { + List topicPartitions = new ArrayList<>(); + for (WritableTxnMarkerTopic topic : markerEntry.topics()) { + for (Integer partitionIdx : topic.partitionIndexes()) { + topicPartitions.add(new TopicPartition(topic.name(), partitionIdx)); + } + } + markers.add(new TxnMarkerEntry( + markerEntry.producerId(), + markerEntry.producerEpoch(), + markerEntry.coordinatorEpoch(), + TransactionResult.forId(markerEntry.transactionResult()), + topicPartitions) + ); + } + return markers; + } + + public static WriteTxnMarkersRequest parse(ByteBuffer buffer, short version) { + return new WriteTxnMarkersRequest(new WriteTxnMarkersRequestData(new ByteBufferAccessor(buffer), version), version); + } + + @Override + public boolean equals(final Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + final WriteTxnMarkersRequest that = (WriteTxnMarkersRequest) o; + return Objects.equals(this.data, that.data); + } + + @Override + public int hashCode() { + return Objects.hash(this.data); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/requests/WriteTxnMarkersResponse.java b/clients/src/main/java/org/apache/kafka/common/requests/WriteTxnMarkersResponse.java new file mode 100644 index 0000000..fd2a834 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/requests/WriteTxnMarkersResponse.java @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
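Illustrative sketch (assumptions noted): the Builder(short version, List<TxnMarkerEntry>) constructor above regroups each marker's partitions by topic before handing them to the generated WriteTxnMarkersRequestData, and markers() reverses that grouping. The producer id, epochs and topic below are made up, and the List<TxnMarkerEntry> and List<TopicPartition> type arguments are assumed from the surrounding code.

import org.apache.kafka.common.TopicPartition;
import org.apache.kafka.common.requests.TransactionResult;
import org.apache.kafka.common.requests.WriteTxnMarkersRequest;
import org.apache.kafka.common.requests.WriteTxnMarkersRequest.TxnMarkerEntry;

import java.util.Collections;

public class TxnMarkerSketch {
    public static void main(String[] args) {
        // One marker: commit the transaction of producer 2000 on partition example-0.
        TxnMarkerEntry marker = new TxnMarkerEntry(
                2000L,             // producerId
                (short) 1,         // producerEpoch
                7,                 // coordinatorEpoch
                TransactionResult.COMMIT,
                Collections.singletonList(new TopicPartition("example", 0)));

        WriteTxnMarkersRequest request =
                new WriteTxnMarkersRequest.Builder((short) 0, Collections.singletonList(marker))
                        .build((short) 0);

        // markers() rebuilds the same logical entries from the wire-format data.
        System.out.println(request.markers());
    }
}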
+ */ +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.message.WriteTxnMarkersResponseData; +import org.apache.kafka.common.message.WriteTxnMarkersResponseData.WritableTxnMarkerPartitionResult; +import org.apache.kafka.common.message.WriteTxnMarkersResponseData.WritableTxnMarkerResult; +import org.apache.kafka.common.message.WriteTxnMarkersResponseData.WritableTxnMarkerTopicResult; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.protocol.Errors; + +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** + * Possible error codes: + * + * - {@link Errors#CORRUPT_MESSAGE} + * - {@link Errors#INVALID_PRODUCER_EPOCH} + * - {@link Errors#UNKNOWN_TOPIC_OR_PARTITION} + * - {@link Errors#NOT_LEADER_OR_FOLLOWER} + * - {@link Errors#MESSAGE_TOO_LARGE} + * - {@link Errors#RECORD_LIST_TOO_LARGE} + * - {@link Errors#NOT_ENOUGH_REPLICAS} + * - {@link Errors#NOT_ENOUGH_REPLICAS_AFTER_APPEND} + * - {@link Errors#INVALID_REQUIRED_ACKS} + * - {@link Errors#TRANSACTION_COORDINATOR_FENCED} + * - {@link Errors#REQUEST_TIMED_OUT} + * - {@link Errors#CLUSTER_AUTHORIZATION_FAILED} + */ +public class WriteTxnMarkersResponse extends AbstractResponse { + + private final WriteTxnMarkersResponseData data; + + public WriteTxnMarkersResponse(Map> errors) { + super(ApiKeys.WRITE_TXN_MARKERS); + List markers = new ArrayList<>(); + for (Map.Entry> markerEntry : errors.entrySet()) { + Map responseTopicDataMap = new HashMap<>(); + for (Map.Entry topicEntry : markerEntry.getValue().entrySet()) { + TopicPartition topicPartition = topicEntry.getKey(); + String topicName = topicPartition.topic(); + + WritableTxnMarkerTopicResult topic = + responseTopicDataMap.getOrDefault(topicName, new WritableTxnMarkerTopicResult().setName(topicName)); + topic.partitions().add(new WritableTxnMarkerPartitionResult() + .setErrorCode(topicEntry.getValue().code()) + .setPartitionIndex(topicPartition.partition()) + ); + responseTopicDataMap.put(topicName, topic); + } + + markers.add(new WritableTxnMarkerResult() + .setProducerId(markerEntry.getKey()) + .setTopics(new ArrayList<>(responseTopicDataMap.values())) + ); + } + this.data = new WriteTxnMarkersResponseData() + .setMarkers(markers); + } + + public WriteTxnMarkersResponse(WriteTxnMarkersResponseData data) { + super(ApiKeys.WRITE_TXN_MARKERS); + this.data = data; + } + + @Override + public WriteTxnMarkersResponseData data() { + return data; + } + + public Map> errorsByProducerId() { + Map> errors = new HashMap<>(); + for (WritableTxnMarkerResult marker : data.markers()) { + Map topicPartitionErrorsMap = new HashMap<>(); + for (WritableTxnMarkerTopicResult topic : marker.topics()) { + for (WritableTxnMarkerPartitionResult partitionResult : topic.partitions()) { + topicPartitionErrorsMap.put(new TopicPartition(topic.name(), partitionResult.partitionIndex()), + Errors.forCode(partitionResult.errorCode())); + } + } + errors.put(marker.producerId(), topicPartitionErrorsMap); + } + return errors; + } + + @Override + public int throttleTimeMs() { + return DEFAULT_THROTTLE_TIME; + } + + @Override + public Map errorCounts() { + Map errorCounts = new HashMap<>(); + for (WritableTxnMarkerResult marker : data.markers()) { + for (WritableTxnMarkerTopicResult topic : marker.topics()) { + for (WritableTxnMarkerPartitionResult partitionResult : 
topic.partitions()) + updateErrorCounts(errorCounts, Errors.forCode(partitionResult.errorCode())); + } + } + return errorCounts; + } + + public static WriteTxnMarkersResponse parse(ByteBuffer buffer, short version) { + return new WriteTxnMarkersResponse(new WriteTxnMarkersResponseData(new ByteBufferAccessor(buffer), version)); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/resource/PatternType.java b/clients/src/main/java/org/apache/kafka/common/resource/PatternType.java new file mode 100644 index 0000000..0c05a0b --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/resource/PatternType.java @@ -0,0 +1,124 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.common.resource; + +import org.apache.kafka.common.annotation.InterfaceStability; + +import java.util.Arrays; +import java.util.Collections; +import java.util.Map; +import java.util.function.Function; +import java.util.stream.Collectors; + +/** + * Resource pattern type. + */ +@InterfaceStability.Evolving +public enum PatternType { + /** + * Represents any PatternType which this client cannot understand, perhaps because this client is too old. + */ + UNKNOWN((byte) 0), + + /** + * In a filter, matches any resource pattern type. + */ + ANY((byte) 1), + + /** + * In a filter, will perform pattern matching. + * + * e.g. Given a filter of {@code ResourcePatternFilter(TOPIC, "payments.received", MATCH)`}, the filter match + * any {@link ResourcePattern} that matches topic 'payments.received'. This might include: + *

+ * <ul>
+ *     <li>A Literal pattern with the same type and name, e.g. {@code ResourcePattern(TOPIC, "payments.received", LITERAL)}</li>
+ *     <li>A Wildcard pattern with the same type, e.g. {@code ResourcePattern(TOPIC, "*", LITERAL)}</li>
+ *     <li>A Prefixed pattern with the same type and where the name is a matching prefix, e.g. {@code ResourcePattern(TOPIC, "payments.", PREFIXED)}</li>
+ * </ul>
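+ * <p>
+ * For example, given the classes in this package (the topic names are illustrative), the filter above would
+ * behave as follows:
+ * <pre>
+ * {@code
+ * ResourcePatternFilter filter =
+ *     new ResourcePatternFilter(ResourceType.TOPIC, "payments.received", PatternType.MATCH);
+ * filter.matches(new ResourcePattern(ResourceType.TOPIC, "payments.received", PatternType.LITERAL)); // true
+ * filter.matches(new ResourcePattern(ResourceType.TOPIC, "*", PatternType.LITERAL));                 // true
+ * filter.matches(new ResourcePattern(ResourceType.TOPIC, "payments.", PatternType.PREFIXED));        // true
+ * filter.matches(new ResourcePattern(ResourceType.TOPIC, "orders.", PatternType.PREFIXED));          // false
+ * }
+ * </pre>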
        + */ + MATCH((byte) 2), + + /** + * A literal resource name. + * + * A literal name defines the full name of a resource, e.g. topic with name 'foo', or group with name 'bob'. + * + * The special wildcard character {@code *} can be used to represent a resource with any name. + */ + LITERAL((byte) 3), + + /** + * A prefixed resource name. + * + * A prefixed name defines a prefix for a resource, e.g. topics with names that start with 'foo'. + */ + PREFIXED((byte) 4); + + private final static Map CODE_TO_VALUE = + Collections.unmodifiableMap( + Arrays.stream(PatternType.values()) + .collect(Collectors.toMap(PatternType::code, Function.identity())) + ); + + private final static Map NAME_TO_VALUE = + Collections.unmodifiableMap( + Arrays.stream(PatternType.values()) + .collect(Collectors.toMap(PatternType::name, Function.identity())) + ); + + private final byte code; + + PatternType(byte code) { + this.code = code; + } + + /** + * @return the code of this resource. + */ + public byte code() { + return code; + } + + /** + * @return whether this resource pattern type is UNKNOWN. + */ + public boolean isUnknown() { + return this == UNKNOWN; + } + + /** + * @return whether this resource pattern type is a concrete type, rather than UNKNOWN or one of the filter types. + */ + public boolean isSpecific() { + return this != UNKNOWN && this != ANY && this != MATCH; + } + + /** + * Return the PatternType with the provided code or {@link #UNKNOWN} if one cannot be found. + */ + public static PatternType fromCode(byte code) { + return CODE_TO_VALUE.getOrDefault(code, UNKNOWN); + } + + /** + * Return the PatternType with the provided name or {@link #UNKNOWN} if one cannot be found. + */ + public static PatternType fromString(String name) { + return NAME_TO_VALUE.getOrDefault(name, UNKNOWN); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/resource/Resource.java b/clients/src/main/java/org/apache/kafka/common/resource/Resource.java new file mode 100644 index 0000000..ebc5b8e --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/resource/Resource.java @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.common.resource; + +import org.apache.kafka.common.annotation.InterfaceStability; + +import java.util.Objects; + +/** + * Represents a cluster resource with a tuple of (type, name). + * + * The API for this class is still evolving and we may break compatibility in minor releases, if necessary. + */ +@InterfaceStability.Evolving +public class Resource { + private final ResourceType resourceType; + private final String name; + + /** + * The name of the CLUSTER resource. + */ + public final static String CLUSTER_NAME = "kafka-cluster"; + + /** + * A resource representing the whole cluster. 
+ */ + public final static Resource CLUSTER = new Resource(ResourceType.CLUSTER, CLUSTER_NAME); + + /** + * Create an instance of this class with the provided parameters. + * + * @param resourceType non-null resource type + * @param name non-null resource name + */ + public Resource(ResourceType resourceType, String name) { + Objects.requireNonNull(resourceType); + this.resourceType = resourceType; + Objects.requireNonNull(name); + this.name = name; + } + + /** + * Return the resource type. + */ + public ResourceType resourceType() { + return resourceType; + } + + /** + * Return the resource name. + */ + public String name() { + return name; + } + + @Override + public String toString() { + return "(resourceType=" + resourceType + ", name=" + ((name == null) ? "" : name) + ")"; + } + + /** + * Return true if this Resource has any UNKNOWN components. + */ + public boolean isUnknown() { + return resourceType.isUnknown(); + } + + @Override + public boolean equals(Object o) { + if (!(o instanceof Resource)) + return false; + Resource other = (Resource) o; + return resourceType.equals(other.resourceType) && Objects.equals(name, other.name); + } + + @Override + public int hashCode() { + return Objects.hash(resourceType, name); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/resource/ResourcePattern.java b/clients/src/main/java/org/apache/kafka/common/resource/ResourcePattern.java new file mode 100644 index 0000000..2b7504f --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/resource/ResourcePattern.java @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.common.resource; + +import org.apache.kafka.common.annotation.InterfaceStability; + +import java.util.Objects; + +/** + * Represents a pattern that is used by ACLs to match zero or more + * {@link org.apache.kafka.common.resource.Resource Resources}. + * + * The API for this class is still evolving and we may break compatibility in minor releases, if necessary. + */ +@InterfaceStability.Evolving +public class ResourcePattern { + /** + * A special literal resource name that corresponds to 'all resources of a certain type'. + */ + public static final String WILDCARD_RESOURCE = "*"; + + private final ResourceType resourceType; + private final String name; + private final PatternType patternType; + + /** + * Create a pattern using the supplied parameters. + * + * @param resourceType non-null, specific, resource type + * @param name non-null resource name, which can be the {@link #WILDCARD_RESOURCE}. + * @param patternType non-null, specific, resource pattern type, which controls how the pattern will match resource names. 
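+ * <p>
+ * For example, a pattern matching every topic whose name starts with {@code payments.} (the prefix is
+ * illustrative) could be created as:
+ * <pre>
+ * {@code
+ * ResourcePattern pattern = new ResourcePattern(ResourceType.TOPIC, "payments.", PatternType.PREFIXED);
+ * }
+ * </pre>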
+ */ + public ResourcePattern(ResourceType resourceType, String name, PatternType patternType) { + this.resourceType = Objects.requireNonNull(resourceType, "resourceType"); + this.name = Objects.requireNonNull(name, "name"); + this.patternType = Objects.requireNonNull(patternType, "patternType"); + + if (resourceType == ResourceType.ANY) { + throw new IllegalArgumentException("resourceType must not be ANY"); + } + + if (patternType == PatternType.MATCH || patternType == PatternType.ANY) { + throw new IllegalArgumentException("patternType must not be " + patternType); + } + } + + /** + * @return the specific resource type this pattern matches + */ + public ResourceType resourceType() { + return resourceType; + } + + /** + * @return the resource name. + */ + public String name() { + return name; + } + + /** + * @return the resource pattern type. + */ + public PatternType patternType() { + return patternType; + } + + /** + * @return a filter which matches only this pattern. + */ + public ResourcePatternFilter toFilter() { + return new ResourcePatternFilter(resourceType, name, patternType); + } + + @Override + public String toString() { + return "ResourcePattern(resourceType=" + resourceType + ", name=" + ((name == null) ? "" : name) + ", patternType=" + patternType + ")"; + } + + /** + * @return {@code true} if this Resource has any UNKNOWN components. + */ + public boolean isUnknown() { + return resourceType.isUnknown() || patternType.isUnknown(); + } + + @Override + public boolean equals(Object o) { + if (this == o) + return true; + if (o == null || getClass() != o.getClass()) + return false; + + final ResourcePattern resource = (ResourcePattern) o; + return resourceType == resource.resourceType && + Objects.equals(name, resource.name) && + patternType == resource.patternType; + } + + @Override + public int hashCode() { + return Objects.hash(resourceType, name, patternType); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/resource/ResourcePatternFilter.java b/clients/src/main/java/org/apache/kafka/common/resource/ResourcePatternFilter.java new file mode 100644 index 0000000..6f511c9 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/resource/ResourcePatternFilter.java @@ -0,0 +1,170 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.common.resource; + +import org.apache.kafka.common.annotation.InterfaceStability; + +import java.util.Objects; + +import static org.apache.kafka.common.resource.ResourcePattern.WILDCARD_RESOURCE; + +/** + * Represents a filter that can match {@link ResourcePattern}. + *

        + * The API for this class is still evolving and we may break compatibility in minor releases, if necessary. + */ +@InterfaceStability.Evolving +public class ResourcePatternFilter { + /** + * Matches any resource pattern. + */ + public static final ResourcePatternFilter ANY = new ResourcePatternFilter(ResourceType.ANY, null, PatternType.ANY); + + private final ResourceType resourceType; + private final String name; + private final PatternType patternType; + + /** + * Create a filter using the supplied parameters. + * + * @param resourceType non-null resource type. + * If {@link ResourceType#ANY}, the filter will ignore the resource type of the pattern. + * If any other resource type, the filter will match only patterns with the same type. + * @param name resource name or {@code null}. + * If {@code null}, the filter will ignore the name of resources. + * If {@link ResourcePattern#WILDCARD_RESOURCE}, will match only wildcard patterns. + * @param patternType non-null resource pattern type. + * If {@link PatternType#ANY}, the filter will match patterns regardless of pattern type. + * If {@link PatternType#MATCH}, the filter will match patterns that would match the supplied + * {@code name}, including a matching prefixed and wildcards patterns. + * If any other resource pattern type, the filter will match only patterns with the same type. + */ + public ResourcePatternFilter(ResourceType resourceType, String name, PatternType patternType) { + this.resourceType = Objects.requireNonNull(resourceType, "resourceType"); + this.name = name; + this.patternType = Objects.requireNonNull(patternType, "patternType"); + } + + /** + * @return {@code true} if this filter has any UNKNOWN components. + */ + public boolean isUnknown() { + return resourceType.isUnknown() || patternType.isUnknown(); + } + + /** + * @return the specific resource type this pattern matches + */ + public ResourceType resourceType() { + return resourceType; + } + + /** + * @return the resource name. + */ + public String name() { + return name; + } + + /** + * @return the resource pattern type. + */ + public PatternType patternType() { + return patternType; + } + + /** + * @return {@code true} if this filter matches the given pattern. + */ + public boolean matches(ResourcePattern pattern) { + if (!resourceType.equals(ResourceType.ANY) && !resourceType.equals(pattern.resourceType())) { + return false; + } + + if (!patternType.equals(PatternType.ANY) && !patternType.equals(PatternType.MATCH) && !patternType.equals(pattern.patternType())) { + return false; + } + + if (name == null) { + return true; + } + + if (patternType.equals(PatternType.ANY) || patternType.equals(pattern.patternType())) { + return name.equals(pattern.name()); + } + + switch (pattern.patternType()) { + case LITERAL: + return name.equals(pattern.name()) || pattern.name().equals(WILDCARD_RESOURCE); + + case PREFIXED: + return name.startsWith(pattern.name()); + + default: + throw new IllegalArgumentException("Unsupported PatternType: " + pattern.patternType()); + } + } + + /** + * @return {@code true} if this filter could only match one pattern. + * In other words, if there are no ANY or UNKNOWN fields. + */ + public boolean matchesAtMostOne() { + return findIndefiniteField() == null; + } + + /** + * @return a string describing any ANY or UNKNOWN field, or null if there is no such field. 
+ */ + public String findIndefiniteField() { + if (resourceType == ResourceType.ANY) + return "Resource type is ANY."; + if (resourceType == ResourceType.UNKNOWN) + return "Resource type is UNKNOWN."; + if (name == null) + return "Resource name is NULL."; + if (patternType == PatternType.MATCH) + return "Resource pattern type is MATCH."; + if (patternType == PatternType.UNKNOWN) + return "Resource pattern type is UNKNOWN."; + return null; + } + + @Override + public String toString() { + return "ResourcePattern(resourceType=" + resourceType + ", name=" + ((name == null) ? "" : name) + ", patternType=" + patternType + ")"; + } + + @Override + public boolean equals(Object o) { + if (this == o) + return true; + if (o == null || getClass() != o.getClass()) + return false; + + final ResourcePatternFilter resource = (ResourcePatternFilter) o; + return resourceType == resource.resourceType && + Objects.equals(name, resource.name) && + patternType == resource.patternType; + } + + @Override + public int hashCode() { + return Objects.hash(resourceType, name, patternType); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/resource/ResourceType.java b/clients/src/main/java/org/apache/kafka/common/resource/ResourceType.java new file mode 100644 index 0000000..2ce653f --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/resource/ResourceType.java @@ -0,0 +1,121 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.common.resource; + +import org.apache.kafka.common.annotation.InterfaceStability; + +import java.util.HashMap; +import java.util.Locale; + +/** + * Represents a type of resource which an ACL can be applied to. + * + * The API for this class is still evolving and we may break compatibility in minor releases, if necessary. + */ +@InterfaceStability.Evolving +public enum ResourceType { + /** + * Represents any ResourceType which this client cannot understand, + * perhaps because this client is too old. + */ + UNKNOWN((byte) 0), + + /** + * In a filter, matches any ResourceType. + */ + ANY((byte) 1), + + /** + * A Kafka topic. + */ + TOPIC((byte) 2), + + /** + * A consumer group. + */ + GROUP((byte) 3), + + /** + * The cluster as a whole. + */ + CLUSTER((byte) 4), + + /** + * A transactional ID. + */ + TRANSACTIONAL_ID((byte) 5), + + /** + * A token ID. + */ + DELEGATION_TOKEN((byte) 6); + + private final static HashMap CODE_TO_VALUE = new HashMap<>(); + + static { + for (ResourceType resourceType : ResourceType.values()) { + CODE_TO_VALUE.put(resourceType.code, resourceType); + } + } + + /** + * Parse the given string as an ACL resource type. + * + * @param str The string to parse. + * + * @return The ResourceType, or UNKNOWN if the string could not be matched. 
+ */ + public static ResourceType fromString(String str) throws IllegalArgumentException { + try { + return ResourceType.valueOf(str.toUpperCase(Locale.ROOT)); + } catch (IllegalArgumentException e) { + return UNKNOWN; + } + } + + /** + * Return the ResourceType with the provided code or `ResourceType.UNKNOWN` if one cannot be found. + */ + public static ResourceType fromCode(byte code) { + ResourceType resourceType = CODE_TO_VALUE.get(code); + if (resourceType == null) { + return UNKNOWN; + } + return resourceType; + } + + private final byte code; + + ResourceType(byte code) { + this.code = code; + } + + /** + * Return the code of this resource. + */ + public byte code() { + return code; + } + + /** + * Return whether this resource type is UNKNOWN. + */ + public boolean isUnknown() { + return this == UNKNOWN; + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/security/JaasConfig.java b/clients/src/main/java/org/apache/kafka/common/security/JaasConfig.java new file mode 100644 index 0000000..5e837a6 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/security/JaasConfig.java @@ -0,0 +1,124 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.security; + +import java.io.IOException; +import java.io.StreamTokenizer; +import java.io.StringReader; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Locale; +import java.util.Map; + +import javax.security.auth.login.AppConfigurationEntry; +import javax.security.auth.login.Configuration; +import javax.security.auth.login.AppConfigurationEntry.LoginModuleControlFlag; + +import org.apache.kafka.common.KafkaException; +import org.apache.kafka.common.config.SaslConfigs; + +/** + * JAAS configuration parser that constructs a JAAS configuration object with a single + * login context from the Kafka configuration option {@link SaslConfigs#SASL_JAAS_CONFIG}. + *

        + * JAAS configuration file format is described here. + * The format of the property value is: + *

+ * <pre>
+ * {@code
+ *     <loginModuleClass> <controlFlag> (<optionName>=<optionValue>)*;
+ * }
+ * </pre>
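+ * <p>
+ * For example, a value configuring the {@code PLAIN} login module (the credential values are
+ * placeholders) would look like:
+ * <pre>
+ * {@code
+ *     org.apache.kafka.common.security.plain.PlainLoginModule required username="alice" password="alice-secret";
+ * }
+ * </pre>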
        + */ +class JaasConfig extends Configuration { + + private final String loginContextName; + private final List configEntries; + + public JaasConfig(String loginContextName, String jaasConfigParams) { + StreamTokenizer tokenizer = new StreamTokenizer(new StringReader(jaasConfigParams)); + tokenizer.slashSlashComments(true); + tokenizer.slashStarComments(true); + tokenizer.wordChars('-', '-'); + tokenizer.wordChars('_', '_'); + tokenizer.wordChars('$', '$'); + + try { + configEntries = new ArrayList<>(); + while (tokenizer.nextToken() != StreamTokenizer.TT_EOF) { + configEntries.add(parseAppConfigurationEntry(tokenizer)); + } + if (configEntries.isEmpty()) + throw new IllegalArgumentException("Login module not specified in JAAS config"); + + this.loginContextName = loginContextName; + + } catch (IOException e) { + throw new KafkaException("Unexpected exception while parsing JAAS config"); + } + } + + @Override + public AppConfigurationEntry[] getAppConfigurationEntry(String name) { + if (this.loginContextName.equals(name)) + return configEntries.toArray(new AppConfigurationEntry[0]); + else + return null; + } + + private LoginModuleControlFlag loginModuleControlFlag(String flag) { + if (flag == null) + throw new IllegalArgumentException("Login module control flag is not available in the JAAS config"); + + LoginModuleControlFlag controlFlag; + switch (flag.toUpperCase(Locale.ROOT)) { + case "REQUIRED": + controlFlag = LoginModuleControlFlag.REQUIRED; + break; + case "REQUISITE": + controlFlag = LoginModuleControlFlag.REQUISITE; + break; + case "SUFFICIENT": + controlFlag = LoginModuleControlFlag.SUFFICIENT; + break; + case "OPTIONAL": + controlFlag = LoginModuleControlFlag.OPTIONAL; + break; + default: + throw new IllegalArgumentException("Invalid login module control flag '" + flag + "' in JAAS config"); + } + return controlFlag; + } + + private AppConfigurationEntry parseAppConfigurationEntry(StreamTokenizer tokenizer) throws IOException { + String loginModule = tokenizer.sval; + if (tokenizer.nextToken() == StreamTokenizer.TT_EOF) + throw new IllegalArgumentException("Login module control flag not specified in JAAS config"); + LoginModuleControlFlag controlFlag = loginModuleControlFlag(tokenizer.sval); + Map options = new HashMap<>(); + while (tokenizer.nextToken() != StreamTokenizer.TT_EOF && tokenizer.ttype != ';') { + String key = tokenizer.sval; + if (tokenizer.nextToken() != '=' || tokenizer.nextToken() == StreamTokenizer.TT_EOF || tokenizer.sval == null) + throw new IllegalArgumentException("Value not specified for key '" + key + "' in JAAS config"); + String value = tokenizer.sval; + options.put(key, value); + } + if (tokenizer.ttype != ';') + throw new IllegalArgumentException("JAAS config entry not terminated by semi-colon"); + return new AppConfigurationEntry(loginModule, controlFlag, options); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/security/JaasContext.java b/clients/src/main/java/org/apache/kafka/common/security/JaasContext.java new file mode 100644 index 0000000..48216a8 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/security/JaasContext.java @@ -0,0 +1,195 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.security; + +import org.apache.kafka.common.config.SaslConfigs; +import org.apache.kafka.common.config.types.Password; +import org.apache.kafka.common.network.ListenerName; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.security.auth.login.AppConfigurationEntry; +import javax.security.auth.login.Configuration; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Locale; +import java.util.Map; + +public class JaasContext { + + private static final Logger LOG = LoggerFactory.getLogger(JaasContext.class); + + private static final String GLOBAL_CONTEXT_NAME_SERVER = "KafkaServer"; + private static final String GLOBAL_CONTEXT_NAME_CLIENT = "KafkaClient"; + + /** + * Returns an instance of this class. + * + * The context will contain the configuration specified by the JAAS configuration property + * {@link SaslConfigs#SASL_JAAS_CONFIG} with prefix `listener.name.{listenerName}.{mechanism}.` + * with listenerName and mechanism in lower case. The context `KafkaServer` will be returned + * with a single login context entry loaded from the property. + *

        + * If the property is not defined, the context will contain the default Configuration and + * the context name will be one of: + *

+ * <ol>
+ *   <li>Lowercased listener name followed by a period and the string `KafkaServer`</li>
+ *   <li>The string `KafkaServer`</li>
+ * </ol>
        + * If both are valid entries in the default JAAS configuration, the first option is chosen. + *
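+ * <p>
+ * For example, for a listener named {@code SASL_SSL} using the {@code PLAIN} mechanism, the dynamic JAAS
+ * configuration is typically supplied through a broker property of the form (listener name and mechanism
+ * in lower case, credential values are placeholders):
+ * <pre>
+ * {@code
+ * listener.name.sasl_ssl.plain.sasl.jaas.config=org.apache.kafka.common.security.plain.PlainLoginModule required \
+ *     username="alice" password="alice-secret";
+ * }
+ * </pre>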

        + * + * @throws IllegalArgumentException if listenerName or mechanism is not defined. + */ + public static JaasContext loadServerContext(ListenerName listenerName, String mechanism, Map configs) { + if (listenerName == null) + throw new IllegalArgumentException("listenerName should not be null for SERVER"); + if (mechanism == null) + throw new IllegalArgumentException("mechanism should not be null for SERVER"); + String listenerContextName = listenerName.value().toLowerCase(Locale.ROOT) + "." + GLOBAL_CONTEXT_NAME_SERVER; + Password dynamicJaasConfig = (Password) configs.get(mechanism.toLowerCase(Locale.ROOT) + "." + SaslConfigs.SASL_JAAS_CONFIG); + if (dynamicJaasConfig == null && configs.get(SaslConfigs.SASL_JAAS_CONFIG) != null) + LOG.warn("Server config {} should be prefixed with SASL mechanism name, ignoring config", SaslConfigs.SASL_JAAS_CONFIG); + return load(Type.SERVER, listenerContextName, GLOBAL_CONTEXT_NAME_SERVER, dynamicJaasConfig); + } + + /** + * Returns an instance of this class. + * + * If JAAS configuration property @link SaslConfigs#SASL_JAAS_CONFIG} is specified, + * the configuration object is created by parsing the property value. Otherwise, the default Configuration + * is returned. The context name is always `KafkaClient`. + * + */ + public static JaasContext loadClientContext(Map configs) { + Password dynamicJaasConfig = (Password) configs.get(SaslConfigs.SASL_JAAS_CONFIG); + return load(JaasContext.Type.CLIENT, null, GLOBAL_CONTEXT_NAME_CLIENT, dynamicJaasConfig); + } + + static JaasContext load(JaasContext.Type contextType, String listenerContextName, + String globalContextName, Password dynamicJaasConfig) { + if (dynamicJaasConfig != null) { + JaasConfig jaasConfig = new JaasConfig(globalContextName, dynamicJaasConfig.value()); + AppConfigurationEntry[] contextModules = jaasConfig.getAppConfigurationEntry(globalContextName); + if (contextModules == null || contextModules.length == 0) + throw new IllegalArgumentException("JAAS config property does not contain any login modules"); + else if (contextModules.length != 1) + throw new IllegalArgumentException("JAAS config property contains " + contextModules.length + " login modules, should be 1 module"); + return new JaasContext(globalContextName, contextType, jaasConfig, dynamicJaasConfig); + } else + return defaultContext(contextType, listenerContextName, globalContextName); + } + + private static JaasContext defaultContext(JaasContext.Type contextType, String listenerContextName, + String globalContextName) { + String jaasConfigFile = System.getProperty(JaasUtils.JAVA_LOGIN_CONFIG_PARAM); + if (jaasConfigFile == null) { + if (contextType == Type.CLIENT) { + LOG.debug("System property '" + JaasUtils.JAVA_LOGIN_CONFIG_PARAM + "' and Kafka SASL property '" + + SaslConfigs.SASL_JAAS_CONFIG + "' are not set, using default JAAS configuration."); + } else { + LOG.debug("System property '" + JaasUtils.JAVA_LOGIN_CONFIG_PARAM + "' is not set, using default JAAS " + + "configuration."); + } + } + + Configuration jaasConfig = Configuration.getConfiguration(); + + AppConfigurationEntry[] configEntries = null; + String contextName = globalContextName; + + if (listenerContextName != null) { + configEntries = jaasConfig.getAppConfigurationEntry(listenerContextName); + if (configEntries != null) + contextName = listenerContextName; + } + + if (configEntries == null) + configEntries = jaasConfig.getAppConfigurationEntry(globalContextName); + + if (configEntries == null) { + String listenerNameText = listenerContextName == 
null ? "" : " or '" + listenerContextName + "'"; + String errorMessage = "Could not find a '" + globalContextName + "'" + listenerNameText + " entry in the JAAS " + + "configuration. System property '" + JaasUtils.JAVA_LOGIN_CONFIG_PARAM + "' is " + + (jaasConfigFile == null ? "not set" : jaasConfigFile); + throw new IllegalArgumentException(errorMessage); + } + + return new JaasContext(contextName, contextType, jaasConfig, null); + } + + /** + * The type of the SASL login context, it should be SERVER for the broker and CLIENT for the clients (consumer, producer, + * etc.). This is used to validate behaviour (e.g. some functionality is only available in the broker or clients). + */ + public enum Type { CLIENT, SERVER } + + private final String name; + private final Type type; + private final Configuration configuration; + private final List configurationEntries; + private final Password dynamicJaasConfig; + + public JaasContext(String name, Type type, Configuration configuration, Password dynamicJaasConfig) { + this.name = name; + this.type = type; + this.configuration = configuration; + AppConfigurationEntry[] entries = configuration.getAppConfigurationEntry(name); + if (entries == null) + throw new IllegalArgumentException("Could not find a '" + name + "' entry in this JAAS configuration."); + this.configurationEntries = Collections.unmodifiableList(new ArrayList<>(Arrays.asList(entries))); + this.dynamicJaasConfig = dynamicJaasConfig; + } + + public String name() { + return name; + } + + public Type type() { + return type; + } + + public Configuration configuration() { + return configuration; + } + + public List configurationEntries() { + return configurationEntries; + } + + public Password dynamicJaasConfig() { + return dynamicJaasConfig; + } + + /** + * Returns the configuration option for key from this context. + * If login module name is specified, return option value only from that module. + */ + public static String configEntryOption(List configurationEntries, String key, String loginModuleName) { + for (AppConfigurationEntry entry : configurationEntries) { + if (loginModuleName != null && !loginModuleName.equals(entry.getLoginModuleName())) + continue; + Object val = entry.getOptions().get(key); + if (val != null) + return (String) val; + } + return null; + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/security/JaasUtils.java b/clients/src/main/java/org/apache/kafka/common/security/JaasUtils.java new file mode 100644 index 0000000..baff563 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/security/JaasUtils.java @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.security; + +import org.apache.kafka.common.KafkaException; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.security.auth.login.Configuration; + +public final class JaasUtils { + private static final Logger LOG = LoggerFactory.getLogger(JaasUtils.class); + public static final String JAVA_LOGIN_CONFIG_PARAM = "java.security.auth.login.config"; + + public static final String SERVICE_NAME = "serviceName"; + + public static final String ZK_SASL_CLIENT = "zookeeper.sasl.client"; + public static final String ZK_LOGIN_CONTEXT_NAME_KEY = "zookeeper.sasl.clientconfig"; + + private static final String DEFAULT_ZK_LOGIN_CONTEXT_NAME = "Client"; + private static final String DEFAULT_ZK_SASL_CLIENT = "true"; + + private JaasUtils() {} + + public static String zkSecuritySysConfigString() { + String loginConfig = System.getProperty(JAVA_LOGIN_CONFIG_PARAM); + String clientEnabled = System.getProperty(ZK_SASL_CLIENT, "default:" + DEFAULT_ZK_SASL_CLIENT); + String contextName = System.getProperty(ZK_LOGIN_CONTEXT_NAME_KEY, "default:" + DEFAULT_ZK_LOGIN_CONTEXT_NAME); + return "[" + + JAVA_LOGIN_CONFIG_PARAM + "=" + loginConfig + + ", " + + ZK_SASL_CLIENT + "=" + clientEnabled + + ", " + + ZK_LOGIN_CONTEXT_NAME_KEY + "=" + contextName + + "]"; + } + + public static boolean isZkSaslEnabled() { + // Technically a client must also check if TLS mutual authentication has been configured, + // but we will leave that up to the client code to determine since direct connectivity to ZooKeeper + // has been deprecated in many clients and we don't wish to re-introduce a ZooKeeper jar dependency here. + boolean zkSaslEnabled = Boolean.parseBoolean(System.getProperty(ZK_SASL_CLIENT, DEFAULT_ZK_SASL_CLIENT)); + String zkLoginContextName = System.getProperty(ZK_LOGIN_CONTEXT_NAME_KEY, DEFAULT_ZK_LOGIN_CONTEXT_NAME); + + LOG.debug("Checking login config for Zookeeper JAAS context {}", zkSecuritySysConfigString()); + + boolean foundLoginConfigEntry; + try { + Configuration loginConf = Configuration.getConfiguration(); + foundLoginConfigEntry = loginConf.getAppConfigurationEntry(zkLoginContextName) != null; + } catch (Exception e) { + throw new KafkaException("Exception while loading Zookeeper JAAS login context " + + zkSecuritySysConfigString(), e); + } + + if (foundLoginConfigEntry && !zkSaslEnabled) { + LOG.error("JAAS configuration is present, but system property " + + ZK_SASL_CLIENT + " is set to false, which disables " + + "SASL in the ZooKeeper client"); + throw new KafkaException("Exception while determining if ZooKeeper is secure " + + zkSecuritySysConfigString()); + } + + return foundLoginConfigEntry; + } +} + diff --git a/clients/src/main/java/org/apache/kafka/common/security/auth/AuthenticateCallbackHandler.java b/clients/src/main/java/org/apache/kafka/common/security/auth/AuthenticateCallbackHandler.java new file mode 100644 index 0000000..8951d3a --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/security/auth/AuthenticateCallbackHandler.java @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.common.security.auth; + +import java.util.List; +import java.util.Map; + +import javax.security.auth.callback.CallbackHandler; +import javax.security.auth.login.AppConfigurationEntry; + +/* + * Callback handler for SASL-based authentication + */ +public interface AuthenticateCallbackHandler extends CallbackHandler { + + /** + * Configures this callback handler for the specified SASL mechanism. + * + * @param configs Key-value pairs containing the parsed configuration options of + * the client or broker. Note that these are the Kafka configuration options + * and not the JAAS configuration options. JAAS config options may be obtained + * from `jaasConfigEntries` for callbacks which obtain some configs from the + * JAAS configuration. For configs that may be specified as both Kafka config + * as well as JAAS config (e.g. sasl.kerberos.service.name), the configuration + * is treated as invalid if conflicting values are provided. + * @param saslMechanism Negotiated SASL mechanism. For clients, this is the SASL + * mechanism configured for the client. For brokers, this is the mechanism + * negotiated with the client and is one of the mechanisms enabled on the broker. + * @param jaasConfigEntries JAAS configuration entries from the JAAS login context. + * This list contains a single entry for clients and may contain more than + * one entry for brokers if multiple mechanisms are enabled on a listener using + * static JAAS configuration where there is no mapping between mechanisms and + * login module entries. In this case, callback handlers can use the login module in + * `jaasConfigEntries` to identify the entry corresponding to `saslMechanism`. + * Alternatively, dynamic JAAS configuration option + * {@link org.apache.kafka.common.config.SaslConfigs#SASL_JAAS_CONFIG} may be + * configured on brokers with listener and mechanism prefix, in which case + * only the configuration entry corresponding to `saslMechanism` will be provided + * in `jaasConfigEntries`. + */ + void configure(Map configs, String saslMechanism, List jaasConfigEntries); + + /** + * Closes this instance. + */ + void close(); +} diff --git a/clients/src/main/java/org/apache/kafka/common/security/auth/AuthenticationContext.java b/clients/src/main/java/org/apache/kafka/common/security/auth/AuthenticationContext.java new file mode 100644 index 0000000..a8abea8 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/security/auth/AuthenticationContext.java @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.security.auth; + +import java.net.InetAddress; + + +/** + * An object representing contextual information from the authentication session. See + * {@link PlaintextAuthenticationContext}, {@link SaslAuthenticationContext} + * and {@link SslAuthenticationContext}. This class is only used in the broker. + */ +public interface AuthenticationContext { + /** + * Underlying security protocol of the authentication session. + */ + SecurityProtocol securityProtocol(); + + /** + * Address of the authenticated client + */ + InetAddress clientAddress(); + + /** + * Name of the listener used for the connection + */ + String listenerName(); +} diff --git a/clients/src/main/java/org/apache/kafka/common/security/auth/KafkaPrincipal.java b/clients/src/main/java/org/apache/kafka/common/security/auth/KafkaPrincipal.java new file mode 100644 index 0000000..8b83e32 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/security/auth/KafkaPrincipal.java @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.security.auth; + +import java.security.Principal; + +import static java.util.Objects.requireNonNull; + +/** + *

+ * Principals in Kafka are defined by a type and a name. The principal type will always be "User"
+ * for the simple authorizer that is enabled by default, but custom authorizers can leverage different
+ * principal types (such as to enable group or role-based ACLs). The {@link KafkaPrincipalBuilder} interface
+ * is used when you need to derive a different principal type from the authentication context, or when
+ * you need to represent relations between different principals. For example, you could extend
+ * {@link KafkaPrincipal} in order to link a user principal to one or more role principals.
+ *

+ * For custom extensions of {@link KafkaPrincipal}, there are two key points to keep in mind:
+ *
+ * <ol>
+ * <li>To be compatible with the ACL APIs provided by Kafka (including the command line tool), each ACL
+ * can only represent a permission granted to a single principal (consisting of a principal type and name).
+ * It is possible to use richer ACL semantics, but you must implement your own mechanisms for adding
+ * and removing ACLs.</li>
+ * <li>In general, {@link KafkaPrincipal} extensions are only useful when the corresponding Authorizer
+ * is also aware of the extension. If you have a {@link KafkaPrincipalBuilder} which derives user groups
+ * from the authentication context (e.g. from an SSL client certificate), then you need a custom
+ * authorizer which is capable of using the additional group information.</li>
+ * </ol>
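+ * <p>
+ * A minimal sketch of such an extension (the class and accessor names are illustrative only):
+ * <pre>
+ * {@code
+ * public class RoleAwarePrincipal extends KafkaPrincipal {
+ *     private final Set<String> roles;
+ *
+ *     public RoleAwarePrincipal(String name, Set<String> roles) {
+ *         super(KafkaPrincipal.USER_TYPE, name);
+ *         this.roles = roles;
+ *     }
+ *
+ *     public Set<String> roles() {
+ *         return roles;
+ *     }
+ * }
+ * }
+ * </pre>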
        + */ +public class KafkaPrincipal implements Principal { + public static final String USER_TYPE = "User"; + public final static KafkaPrincipal ANONYMOUS = new KafkaPrincipal(KafkaPrincipal.USER_TYPE, "ANONYMOUS"); + + private final String principalType; + private final String name; + private volatile boolean tokenAuthenticated; + + public KafkaPrincipal(String principalType, String name) { + this(principalType, name, false); + } + + public KafkaPrincipal(String principalType, String name, boolean tokenAuthenticated) { + this.principalType = requireNonNull(principalType, "Principal type cannot be null"); + this.name = requireNonNull(name, "Principal name cannot be null"); + this.tokenAuthenticated = tokenAuthenticated; + } + + @Override + public String toString() { + return principalType + ":" + name; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null) return false; + if (getClass() != o.getClass()) return false; + + KafkaPrincipal that = (KafkaPrincipal) o; + return principalType.equals(that.principalType) && name.equals(that.name); + } + + @Override + public int hashCode() { + int result = principalType != null ? principalType.hashCode() : 0; + result = 31 * result + (name != null ? name.hashCode() : 0); + return result; + } + + @Override + public String getName() { + return name; + } + + public String getPrincipalType() { + return principalType; + } + + public void tokenAuthenticated(boolean tokenAuthenticated) { + this.tokenAuthenticated = tokenAuthenticated; + } + + public boolean tokenAuthenticated() { + return tokenAuthenticated; + } +} + diff --git a/clients/src/main/java/org/apache/kafka/common/security/auth/KafkaPrincipalBuilder.java b/clients/src/main/java/org/apache/kafka/common/security/auth/KafkaPrincipalBuilder.java new file mode 100644 index 0000000..941d3b1 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/security/auth/KafkaPrincipalBuilder.java @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.security.auth; + +/** + * Pluggable principal builder interface which supports both SSL authentication through + * {@link SslAuthenticationContext} and SASL through {@link SaslAuthenticationContext}. + * + * Note that the {@link org.apache.kafka.common.Configurable} and {@link java.io.Closeable} + * interfaces are respected if implemented. Additionally, implementations must provide a + * default no-arg constructor. + */ +public interface KafkaPrincipalBuilder { + /** + * Build a kafka principal from the authentication context. 
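+ * <p>
+ * A minimal sketch of an implementation (the class name is illustrative only):
+ * <pre>
+ * {@code
+ * public class SimplePrincipalBuilder implements KafkaPrincipalBuilder {
+ *     public KafkaPrincipal build(AuthenticationContext context) {
+ *         if (context instanceof SaslAuthenticationContext) {
+ *             SaslAuthenticationContext saslContext = (SaslAuthenticationContext) context;
+ *             return new KafkaPrincipal(KafkaPrincipal.USER_TYPE, saslContext.server().getAuthorizationID());
+ *         }
+ *         return KafkaPrincipal.ANONYMOUS;
+ *     }
+ * }
+ * }
+ * </pre>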
+ * @param context The authentication context (either {@link SslAuthenticationContext} or + * {@link SaslAuthenticationContext}) + * @return The built principal which may provide additional enrichment through a subclass of + * {@link KafkaPrincipalBuilder}. + */ + KafkaPrincipal build(AuthenticationContext context); +} diff --git a/clients/src/main/java/org/apache/kafka/common/security/auth/KafkaPrincipalSerde.java b/clients/src/main/java/org/apache/kafka/common/security/auth/KafkaPrincipalSerde.java new file mode 100644 index 0000000..c32f7f5 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/security/auth/KafkaPrincipalSerde.java @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.security.auth; + +import org.apache.kafka.common.errors.SerializationException; + +/** + * Serializer/Deserializer interface for {@link KafkaPrincipal} for the purpose of inter-broker forwarding. + * Any serialization/deserialization failure should raise a {@link SerializationException} to be consistent. + */ +public interface KafkaPrincipalSerde { + + /** + * Serialize a {@link KafkaPrincipal} into byte array. + * + * @param principal principal to be serialized + * @return serialized bytes + * @throws SerializationException + */ + byte[] serialize(KafkaPrincipal principal) throws SerializationException; + + /** + * Deserialize a {@link KafkaPrincipal} from byte array. + * @param bytes byte array to be deserialized + * @return the deserialized principal + * @throws SerializationException + */ + KafkaPrincipal deserialize(byte[] bytes) throws SerializationException; +} diff --git a/clients/src/main/java/org/apache/kafka/common/security/auth/Login.java b/clients/src/main/java/org/apache/kafka/common/security/auth/Login.java new file mode 100644 index 0000000..eda5e7a --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/security/auth/Login.java @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.security.auth; + +import java.util.Map; + +import javax.security.auth.Subject; +import javax.security.auth.login.Configuration; +import javax.security.auth.login.LoginContext; +import javax.security.auth.login.LoginException; + +/** + * Login interface for authentication. + */ +public interface Login { + + /** + * Configures this login instance. + * @param configs Key-value pairs containing the parsed configuration options of + * the client or broker. Note that these are the Kafka configuration options + * and not the JAAS configuration options. The JAAS options may be obtained + * from `jaasConfiguration`. + * @param contextName JAAS context name for this login which may be used to obtain + * the login context from `jaasConfiguration`. + * @param jaasConfiguration JAAS configuration containing the login context named + * `contextName`. If static JAAS configuration is used, this `Configuration` + * may also contain other login contexts. + * @param loginCallbackHandler Login callback handler instance to use for this Login. + * Login callback handler class may be configured using + * {@link org.apache.kafka.common.config.SaslConfigs#SASL_LOGIN_CALLBACK_HANDLER_CLASS}. + */ + void configure(Map configs, String contextName, Configuration jaasConfiguration, + AuthenticateCallbackHandler loginCallbackHandler); + + /** + * Performs login for each login module specified for the login context of this instance. + */ + LoginContext login() throws LoginException; + + /** + * Returns the authenticated subject of this login context. + */ + Subject subject(); + + /** + * Returns the service name to be used for SASL. + */ + String serviceName(); + + /** + * Closes this instance. + */ + void close(); +} diff --git a/clients/src/main/java/org/apache/kafka/common/security/auth/PlaintextAuthenticationContext.java b/clients/src/main/java/org/apache/kafka/common/security/auth/PlaintextAuthenticationContext.java new file mode 100644 index 0000000..a111f21 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/security/auth/PlaintextAuthenticationContext.java @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.security.auth; + +import java.net.InetAddress; + +public class PlaintextAuthenticationContext implements AuthenticationContext { + private final InetAddress clientAddress; + private final String listenerName; + + public PlaintextAuthenticationContext(InetAddress clientAddress, String listenerName) { + this.clientAddress = clientAddress; + this.listenerName = listenerName; + } + + @Override + public SecurityProtocol securityProtocol() { + return SecurityProtocol.PLAINTEXT; + } + + @Override + public InetAddress clientAddress() { + return clientAddress; + } + + @Override + public String listenerName() { + return listenerName; + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/security/auth/SaslAuthenticationContext.java b/clients/src/main/java/org/apache/kafka/common/security/auth/SaslAuthenticationContext.java new file mode 100644 index 0000000..5b22625 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/security/auth/SaslAuthenticationContext.java @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.security.auth; + +import javax.net.ssl.SSLSession; +import javax.security.sasl.SaslServer; + +import java.net.InetAddress; +import java.util.Optional; + +public class SaslAuthenticationContext implements AuthenticationContext { + private final SaslServer server; + private final SecurityProtocol securityProtocol; + private final InetAddress clientAddress; + private final String listenerName; + private final Optional sslSession; + + public SaslAuthenticationContext(SaslServer server, SecurityProtocol securityProtocol, InetAddress clientAddress, String listenerName) { + this(server, securityProtocol, clientAddress, listenerName, Optional.empty()); + } + + public SaslAuthenticationContext(SaslServer server, SecurityProtocol securityProtocol, + InetAddress clientAddress, + String listenerName, + Optional sslSession) { + this.server = server; + this.securityProtocol = securityProtocol; + this.clientAddress = clientAddress; + this.listenerName = listenerName; + this.sslSession = sslSession; + } + + public SaslServer server() { + return server; + } + + /** + * Returns SSL session for the connection if security protocol is SASL_SSL. If SSL + * mutual client authentication is enabled for the listener, peer principal can be + * determined using {@link SSLSession#getPeerPrincipal()}. 
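+ * <p>
+ * A minimal sketch of how a custom principal builder might use this (variable names are illustrative;
+ * the peer principal is only available when SSL client authentication was performed):
+ * <pre>
+ * {@code
+ * Optional<SSLSession> session = saslContext.sslSession();
+ * if (session.isPresent()) {
+ *     try {
+ *         Principal peer = session.get().getPeerPrincipal();
+ *         // e.g. derive a KafkaPrincipal from the certificate's distinguished name
+ *     } catch (SSLPeerUnverifiedException e) {
+ *         // no client certificate was presented by the peer
+ *     }
+ * }
+ * }
+ * </pre>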
+ */ + public Optional sslSession() { + return sslSession; + } + + @Override + public SecurityProtocol securityProtocol() { + return securityProtocol; + } + + @Override + public InetAddress clientAddress() { + return clientAddress; + } + + @Override + public String listenerName() { + return listenerName; + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/security/auth/SaslExtensions.java b/clients/src/main/java/org/apache/kafka/common/security/auth/SaslExtensions.java new file mode 100644 index 0000000..c129f1e --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/security/auth/SaslExtensions.java @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.security.auth; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +/** + * A simple immutable value object class holding customizable SASL extensions + */ +public class SaslExtensions { + /** + * An "empty" instance indicating no SASL extensions + */ + public static final SaslExtensions NO_SASL_EXTENSIONS = new SaslExtensions(Collections.emptyMap()); + private final Map extensionsMap; + + public SaslExtensions(Map extensionsMap) { + this.extensionsMap = Collections.unmodifiableMap(new HashMap<>(extensionsMap)); + } + + /** + * Returns an immutable map of the extension names and their values + */ + public Map map() { + return extensionsMap; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + return extensionsMap.equals(((SaslExtensions) o).extensionsMap); + } + + @Override + public String toString() { + return extensionsMap.toString(); + } + + @Override + public int hashCode() { + return extensionsMap.hashCode(); + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/security/auth/SaslExtensionsCallback.java b/clients/src/main/java/org/apache/kafka/common/security/auth/SaslExtensionsCallback.java new file mode 100644 index 0000000..c5bd449 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/security/auth/SaslExtensionsCallback.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.common.security.auth; + +import java.util.Objects; + +import javax.security.auth.callback.Callback; + +/** + * Optional callback used for SASL mechanisms if any extensions need to be set + * in the SASL exchange. + */ +public class SaslExtensionsCallback implements Callback { + private SaslExtensions extensions = SaslExtensions.NO_SASL_EXTENSIONS; + + /** + * Returns always non-null {@link SaslExtensions} consisting of the extension + * names and values that are sent by the client to the server in the initial + * client SASL authentication message. The default value is + * {@link SaslExtensions#NO_SASL_EXTENSIONS} so that if this callback is + * unhandled the client will see a non-null value. + */ + public SaslExtensions extensions() { + return extensions; + } + + /** + * Sets the SASL extensions on this callback. + * + * @param extensions + * the mandatory extensions to set + */ + public void extensions(SaslExtensions extensions) { + this.extensions = Objects.requireNonNull(extensions, "extensions must not be null"); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/security/auth/SecurityProtocol.java b/clients/src/main/java/org/apache/kafka/common/security/auth/SecurityProtocol.java new file mode 100644 index 0000000..f48a194 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/security/auth/SecurityProtocol.java @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
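The SaslExtensionsCallback defined above is typically populated by a login callback handler. A hypothetical handler that attaches a single custom extension to the initial client SASL message; the extension name and value are invented for illustration:

import java.util.Collections;
import java.util.List;
import java.util.Map;

import javax.security.auth.callback.Callback;
import javax.security.auth.callback.UnsupportedCallbackException;
import javax.security.auth.login.AppConfigurationEntry;

import org.apache.kafka.common.security.auth.AuthenticateCallbackHandler;
import org.apache.kafka.common.security.auth.SaslExtensions;
import org.apache.kafka.common.security.auth.SaslExtensionsCallback;

public class TraceIdExtensionsHandler implements AuthenticateCallbackHandler {
    @Override
    public void configure(Map<String, ?> configs, String saslMechanism, List<AppConfigurationEntry> jaasConfigEntries) {
    }

    @Override
    public void handle(Callback[] callbacks) throws UnsupportedCallbackException {
        for (Callback callback : callbacks) {
            if (callback instanceof SaslExtensionsCallback) {
                // SaslExtensions copies the map and exposes it as an unmodifiable view.
                SaslExtensions extensions =
                        new SaslExtensions(Collections.singletonMap("traceId", "abc123"));
                ((SaslExtensionsCallback) callback).extensions(extensions);
            } else {
                throw new UnsupportedCallbackException(callback);
            }
        }
    }

    @Override
    public void close() {
    }
}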
+ */ +package org.apache.kafka.common.security.auth; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Locale; +import java.util.Map; + +public enum SecurityProtocol { + /** Un-authenticated, non-encrypted channel */ + PLAINTEXT(0, "PLAINTEXT"), + /** SSL channel */ + SSL(1, "SSL"), + /** SASL authenticated, non-encrypted channel */ + SASL_PLAINTEXT(2, "SASL_PLAINTEXT"), + /** SASL authenticated, SSL channel */ + SASL_SSL(3, "SASL_SSL"); + + private static final Map CODE_TO_SECURITY_PROTOCOL; + private static final List NAMES; + + static { + SecurityProtocol[] protocols = SecurityProtocol.values(); + List names = new ArrayList<>(protocols.length); + Map codeToSecurityProtocol = new HashMap<>(protocols.length); + for (SecurityProtocol proto : protocols) { + codeToSecurityProtocol.put(proto.id, proto); + names.add(proto.name); + } + CODE_TO_SECURITY_PROTOCOL = Collections.unmodifiableMap(codeToSecurityProtocol); + NAMES = Collections.unmodifiableList(names); + } + + /** The permanent and immutable id of a security protocol -- this can't change, and must match kafka.cluster.SecurityProtocol */ + public final short id; + + /** Name of the security protocol. This may be used by client configuration. */ + public final String name; + + SecurityProtocol(int id, String name) { + this.id = (short) id; + this.name = name; + } + + public static List names() { + return NAMES; + } + + public static SecurityProtocol forId(short id) { + return CODE_TO_SECURITY_PROTOCOL.get(id); + } + + /** Case insensitive lookup by protocol name */ + public static SecurityProtocol forName(String name) { + return SecurityProtocol.valueOf(name.toUpperCase(Locale.ROOT)); + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/security/auth/SecurityProviderCreator.java b/clients/src/main/java/org/apache/kafka/common/security/auth/SecurityProviderCreator.java new file mode 100644 index 0000000..ae56f9a --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/security/auth/SecurityProviderCreator.java @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.security.auth; + +import org.apache.kafka.common.Configurable; +import org.apache.kafka.common.annotation.InterfaceStability; + +import java.security.Provider; +import java.util.Map; + +/** + * An interface for generating security providers. 
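A minimal sketch of a SecurityProviderCreator implementation; the creator class and provider name are placeholders. Creators like this are registered through the security.providers configuration so that the returned Provider is installed before SSL/SASL objects are created:

import java.security.Provider;
import java.util.Map;

import org.apache.kafka.common.security.auth.SecurityProviderCreator;

public class MyProviderCreator implements SecurityProviderCreator {

    // Placeholder provider; a real implementation would register its algorithms in the constructor.
    private static class MyProvider extends Provider {
        MyProvider() {
            super("MY_PROVIDER", 1.0, "Example security provider");
        }
    }

    private Provider provider;

    @Override
    public void configure(Map<String, ?> config) {
        // Read any creator-specific settings from `config` here if the provider needs them.
        this.provider = new MyProvider();
    }

    @Override
    public Provider getProvider() {
        return provider;
    }
}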
+ */ +@InterfaceStability.Evolving +public interface SecurityProviderCreator extends Configurable { + + /** + * Configure method is used to configure the generator to create the Security Provider + * @param config configuration parameters for initialising security provider + */ + default void configure(Map config) { + + } + + /** + * Generate the security provider configured + */ + Provider getProvider(); +} diff --git a/clients/src/main/java/org/apache/kafka/common/security/auth/SslAuthenticationContext.java b/clients/src/main/java/org/apache/kafka/common/security/auth/SslAuthenticationContext.java new file mode 100644 index 0000000..88819f9 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/security/auth/SslAuthenticationContext.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.security.auth; + +import javax.net.ssl.SSLSession; +import java.net.InetAddress; + +public class SslAuthenticationContext implements AuthenticationContext { + private final SSLSession session; + private final InetAddress clientAddress; + private final String listenerName; + + public SslAuthenticationContext(SSLSession session, InetAddress clientAddress, String listenerName) { + this.session = session; + this.clientAddress = clientAddress; + this.listenerName = listenerName; + } + + public SSLSession session() { + return session; + } + + @Override + public SecurityProtocol securityProtocol() { + return SecurityProtocol.SSL; + } + + @Override + public InetAddress clientAddress() { + return clientAddress; + } + + @Override + public String listenerName() { + return listenerName; + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/security/auth/SslEngineFactory.java b/clients/src/main/java/org/apache/kafka/common/security/auth/SslEngineFactory.java new file mode 100644 index 0000000..586017d --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/security/auth/SslEngineFactory.java @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.security.auth; + +import org.apache.kafka.common.Configurable; + +import javax.net.ssl.SSLEngine; +import java.io.Closeable; +import java.security.KeyStore; +import java.util.Map; +import java.util.Set; + +/** + * Plugin interface for allowing creation of SSLEngine object in a custom way. + * For example, you can use this to customize loading your key material and trust material needed for SSLContext. + * This is complementary to the existing Java Security Provider mechanism which allows the entire provider + * to be replaced with a custom provider. In scenarios where only the configuration mechanism for SSL engines + * need to be updated, this interface provides a convenient method for overriding the default implementation. + */ +public interface SslEngineFactory extends Configurable, Closeable { + + /** + * Creates a new SSLEngine object to be used by the client. + * + * @param peerHost The peer host to use. This is used in client mode if endpoint validation is enabled. + * @param peerPort The peer port to use. This is a hint and not used for validation. + * @param endpointIdentification Endpoint identification algorithm for client mode. + * @return The new SSLEngine. + */ + SSLEngine createClientSslEngine(String peerHost, int peerPort, String endpointIdentification); + + /** + * Creates a new SSLEngine object to be used by the server. + * + * @param peerHost The peer host to use. This is a hint and not used for validation. + * @param peerPort The peer port to use. This is a hint and not used for validation. + * @return The new SSLEngine. + */ + SSLEngine createServerSslEngine(String peerHost, int peerPort); + + /** + * Returns true if SSLEngine needs to be rebuilt. This method will be called when reconfiguration is triggered on + * the SslFactory used to create SSL engines. Based on the new configs provided in nextConfigs, this method + * will decide whether underlying SSLEngine object needs to be rebuilt. If this method returns true, the + * SslFactory will create a new instance of this object with nextConfigs and run other + * checks before deciding to use the new object for new incoming connection requests. Existing connections + * are not impacted by this and will not see any changes done as part of reconfiguration. + *
        + * For example, if the implementation depends on file-based key material, it can check if the file was updated + * compared to the previous/last-loaded timestamp and return true. + *
        + * + * @param nextConfigs The new configuration we want to use. + * @return True only if the underlying SSLEngine object should be rebuilt. + */ + boolean shouldBeRebuilt(Map nextConfigs); + + /** + * Returns the names of configs that may be reconfigured. + * @return Names of configuration options that are dynamically reconfigurable. + */ + Set reconfigurableConfigs(); + + /** + * Returns keystore configured for this factory. + * @return The keystore for this factory or null if a keystore is not configured. + */ + KeyStore keystore(); + + /** + * Returns truststore configured for this factory. + * @return The truststore for this factory or null if a truststore is not configured. + */ + KeyStore truststore(); +} \ No newline at end of file diff --git a/clients/src/main/java/org/apache/kafka/common/security/authenticator/AbstractLogin.java b/clients/src/main/java/org/apache/kafka/common/security/authenticator/AbstractLogin.java new file mode 100644 index 0000000..7e13508 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/security/authenticator/AbstractLogin.java @@ -0,0 +1,115 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.security.authenticator; + +import javax.security.auth.login.AppConfigurationEntry; +import javax.security.auth.login.Configuration; +import javax.security.auth.login.LoginContext; +import javax.security.auth.login.LoginException; +import javax.security.sasl.RealmCallback; +import javax.security.auth.callback.Callback; +import javax.security.auth.callback.NameCallback; +import javax.security.auth.callback.PasswordCallback; +import javax.security.auth.callback.UnsupportedCallbackException; +import javax.security.auth.Subject; + +import org.apache.kafka.common.security.auth.AuthenticateCallbackHandler; +import org.apache.kafka.common.security.auth.Login; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.List; +import java.util.Map; + +/** + * Base login class that implements methods common to typical SASL mechanisms. 
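The file-timestamp check that the SslEngineFactory#shouldBeRebuilt javadoc above describes might look roughly like the sketch below. Only the reconfiguration-related methods are shown (engine creation is left abstract), and the my.ssl.keystore.location key is a placeholder for whatever custom config the factory reads, not a real Kafka config:

import java.io.File;
import java.util.Map;

import org.apache.kafka.common.security.auth.SslEngineFactory;

public abstract class FileWatchingSslEngineFactory implements SslEngineFactory {
    private String keystorePath;
    private long lastModifiedAtConfigure;

    @Override
    public void configure(Map<String, ?> configs) {
        this.keystorePath = (String) configs.get("my.ssl.keystore.location");
        this.lastModifiedAtConfigure = new File(keystorePath).lastModified();
        // ... load key material and build the SSLContext here ...
    }

    @Override
    public boolean shouldBeRebuilt(Map<String, Object> nextConfigs) {
        String nextPath = (String) nextConfigs.get("my.ssl.keystore.location");
        // Rebuild if the keystore path changed or the file on disk was modified since configure().
        return !keystorePath.equals(nextPath)
                || new File(nextPath).lastModified() != lastModifiedAtConfigure;
    }
}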
+ */ +public abstract class AbstractLogin implements Login { + private static final Logger log = LoggerFactory.getLogger(AbstractLogin.class); + + private String contextName; + private Configuration configuration; + private LoginContext loginContext; + private AuthenticateCallbackHandler loginCallbackHandler; + + @Override + public void configure(Map configs, String contextName, Configuration configuration, + AuthenticateCallbackHandler loginCallbackHandler) { + this.contextName = contextName; + this.configuration = configuration; + this.loginCallbackHandler = loginCallbackHandler; + } + + @Override + public LoginContext login() throws LoginException { + loginContext = new LoginContext(contextName, null, loginCallbackHandler, configuration); + loginContext.login(); + log.info("Successfully logged in."); + return loginContext; + } + + @Override + public Subject subject() { + return loginContext.getSubject(); + } + + protected String contextName() { + return contextName; + } + + protected Configuration configuration() { + return configuration; + } + + /** + * Callback handler for creating login context. Login callback handlers + * should support the callbacks required for the login modules used by + * the KafkaServer and KafkaClient contexts. Kafka does not support + * callback handlers which require additional user input. + * + */ + public static class DefaultLoginCallbackHandler implements AuthenticateCallbackHandler { + + @Override + public void configure(Map configs, String saslMechanism, List jaasConfigEntries) { + } + + @Override + public void handle(Callback[] callbacks) throws UnsupportedCallbackException { + for (Callback callback : callbacks) { + if (callback instanceof NameCallback) { + NameCallback nc = (NameCallback) callback; + nc.setName(nc.getDefaultName()); + } else if (callback instanceof PasswordCallback) { + String errorMessage = "Could not login: the client is being asked for a password, but the Kafka" + + " client code does not currently support obtaining a password from the user."; + throw new UnsupportedCallbackException(callback, errorMessage); + } else if (callback instanceof RealmCallback) { + RealmCallback rc = (RealmCallback) callback; + rc.setText(rc.getDefaultText()); + } else { + throw new UnsupportedCallbackException(callback, "Unrecognized SASL Login callback"); + } + } + } + + @Override + public void close() { + } + } +} + diff --git a/clients/src/main/java/org/apache/kafka/common/security/authenticator/CredentialCache.java b/clients/src/main/java/org/apache/kafka/common/security/authenticator/CredentialCache.java new file mode 100644 index 0000000..ecf3ea9 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/security/authenticator/CredentialCache.java @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.security.authenticator; + +import java.util.concurrent.ConcurrentHashMap; + +public class CredentialCache { + + private final ConcurrentHashMap> cacheMap = new ConcurrentHashMap<>(); + + public Cache createCache(String mechanism, Class credentialClass) { + Cache cache = new Cache<>(credentialClass); + @SuppressWarnings("unchecked") + Cache oldCache = (Cache) cacheMap.putIfAbsent(mechanism, cache); + return oldCache == null ? cache : oldCache; + } + + @SuppressWarnings("unchecked") + public Cache cache(String mechanism, Class credentialClass) { + Cache cache = cacheMap.get(mechanism); + if (cache != null) { + if (cache.credentialClass() != credentialClass) + throw new IllegalArgumentException("Invalid credential class " + credentialClass + ", expected " + cache.credentialClass()); + return (Cache) cache; + } else + return null; + } + + public static class Cache { + private final Class credentialClass; + private final ConcurrentHashMap credentials; + + public Cache(Class credentialClass) { + this.credentialClass = credentialClass; + this.credentials = new ConcurrentHashMap<>(); + } + + public C get(String username) { + return credentials.get(username); + } + + public C put(String username, C credential) { + return credentials.put(username, credential); + } + + public C remove(String username) { + return credentials.remove(username); + } + + public Class credentialClass() { + return credentialClass; + } + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/security/authenticator/DefaultKafkaPrincipalBuilder.java b/clients/src/main/java/org/apache/kafka/common/security/authenticator/DefaultKafkaPrincipalBuilder.java new file mode 100644 index 0000000..cae0796 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/security/authenticator/DefaultKafkaPrincipalBuilder.java @@ -0,0 +1,135 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
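Usage of the CredentialCache above is straightforward: one Cache per SASL mechanism, keyed by user name. A small sketch with an invented credential type (real callers store e.g. SCRAM credentials):

import org.apache.kafka.common.security.authenticator.CredentialCache;

public class CredentialCacheSketch {
    // Toy credential type used only for this example.
    static class TokenCredential {
        final String secret;
        TokenCredential(String secret) {
            this.secret = secret;
        }
    }

    public static void main(String[] args) {
        CredentialCache credentialCache = new CredentialCache();
        CredentialCache.Cache<TokenCredential> cache =
                credentialCache.createCache("SCRAM-SHA-256", TokenCredential.class);
        cache.put("alice", new TokenCredential("alice-secret"));

        // Looking the cache up again by mechanism and credential class returns the same Cache instance.
        CredentialCache.Cache<TokenCredential> sameCache =
                credentialCache.cache("SCRAM-SHA-256", TokenCredential.class);
        System.out.println(sameCache.get("alice").secret);
    }
}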
+ */ +package org.apache.kafka.common.security.authenticator; + +import javax.security.auth.x500.X500Principal; + +import org.apache.kafka.common.KafkaException; +import org.apache.kafka.common.config.SaslConfigs; +import org.apache.kafka.common.errors.SerializationException; +import org.apache.kafka.common.message.DefaultPrincipalData; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.protocol.MessageUtil; +import org.apache.kafka.common.security.auth.AuthenticationContext; +import org.apache.kafka.common.security.auth.KafkaPrincipal; +import org.apache.kafka.common.security.auth.KafkaPrincipalBuilder; +import org.apache.kafka.common.security.auth.KafkaPrincipalSerde; +import org.apache.kafka.common.security.auth.PlaintextAuthenticationContext; +import org.apache.kafka.common.security.auth.SaslAuthenticationContext; +import org.apache.kafka.common.security.auth.SslAuthenticationContext; +import org.apache.kafka.common.security.kerberos.KerberosName; +import org.apache.kafka.common.security.kerberos.KerberosShortNamer; + +import javax.net.ssl.SSLPeerUnverifiedException; +import javax.net.ssl.SSLSession; +import javax.security.sasl.SaslServer; +import org.apache.kafka.common.security.ssl.SslPrincipalMapper; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.security.Principal; + +/** + * Default implementation of {@link KafkaPrincipalBuilder} which provides basic support for + * SSL authentication and SASL authentication. In the latter case, when GSSAPI is used, this + * class applies {@link org.apache.kafka.common.security.kerberos.KerberosShortNamer} to transform + * the name. + * + * NOTE: This is an internal class and can change without notice. + */ +public class DefaultKafkaPrincipalBuilder implements KafkaPrincipalBuilder, KafkaPrincipalSerde { + private final KerberosShortNamer kerberosShortNamer; + private final SslPrincipalMapper sslPrincipalMapper; + + /** + * Construct a new instance. 
+ * + * @param kerberosShortNamer Kerberos name rewrite rules or null if none have been configured + * @param sslPrincipalMapper SSL Principal mapper or null if none have been configured + */ + public DefaultKafkaPrincipalBuilder(KerberosShortNamer kerberosShortNamer, SslPrincipalMapper sslPrincipalMapper) { + this.kerberosShortNamer = kerberosShortNamer; + this.sslPrincipalMapper = sslPrincipalMapper; + } + + @Override + public KafkaPrincipal build(AuthenticationContext context) { + if (context instanceof PlaintextAuthenticationContext) { + return KafkaPrincipal.ANONYMOUS; + } else if (context instanceof SslAuthenticationContext) { + SSLSession sslSession = ((SslAuthenticationContext) context).session(); + try { + return applySslPrincipalMapper(sslSession.getPeerPrincipal()); + } catch (SSLPeerUnverifiedException se) { + return KafkaPrincipal.ANONYMOUS; + } + } else if (context instanceof SaslAuthenticationContext) { + SaslServer saslServer = ((SaslAuthenticationContext) context).server(); + if (SaslConfigs.GSSAPI_MECHANISM.equals(saslServer.getMechanismName())) + return applyKerberosShortNamer(saslServer.getAuthorizationID()); + else + return new KafkaPrincipal(KafkaPrincipal.USER_TYPE, saslServer.getAuthorizationID()); + } else { + throw new IllegalArgumentException("Unhandled authentication context type: " + context.getClass().getName()); + } + } + + private KafkaPrincipal applyKerberosShortNamer(String authorizationId) { + KerberosName kerberosName = KerberosName.parse(authorizationId); + try { + String shortName = kerberosShortNamer.shortName(kerberosName); + return new KafkaPrincipal(KafkaPrincipal.USER_TYPE, shortName); + } catch (IOException e) { + throw new KafkaException("Failed to set name for '" + kerberosName + + "' based on Kerberos authentication rules.", e); + } + } + + private KafkaPrincipal applySslPrincipalMapper(Principal principal) { + try { + if (!(principal instanceof X500Principal) || principal == KafkaPrincipal.ANONYMOUS) { + return new KafkaPrincipal(KafkaPrincipal.USER_TYPE, principal.getName()); + } else { + return new KafkaPrincipal(KafkaPrincipal.USER_TYPE, sslPrincipalMapper.getName(principal.getName())); + } + } catch (IOException e) { + throw new KafkaException("Failed to map name for '" + principal.getName() + + "' based on SSL principal mapping rules.", e); + } + } + + @Override + public byte[] serialize(KafkaPrincipal principal) { + DefaultPrincipalData data = new DefaultPrincipalData() + .setType(principal.getPrincipalType()) + .setName(principal.getName()) + .setTokenAuthenticated(principal.tokenAuthenticated()); + return MessageUtil.toVersionPrefixedBytes(DefaultPrincipalData.HIGHEST_SUPPORTED_VERSION, data); + } + + @Override + public KafkaPrincipal deserialize(byte[] bytes) { + ByteBuffer buffer = ByteBuffer.wrap(bytes); + short version = buffer.getShort(); + if (version < DefaultPrincipalData.LOWEST_SUPPORTED_VERSION || version > DefaultPrincipalData.HIGHEST_SUPPORTED_VERSION) { + throw new SerializationException("Invalid principal data version " + version); + } + + DefaultPrincipalData data = new DefaultPrincipalData(new ByteBufferAccessor(buffer), version); + return new KafkaPrincipal(data.type(), data.name(), data.tokenAuthenticated()); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/security/authenticator/DefaultLogin.java b/clients/src/main/java/org/apache/kafka/common/security/authenticator/DefaultLogin.java new file mode 100644 index 0000000..a902d7d --- /dev/null +++ 
b/clients/src/main/java/org/apache/kafka/common/security/authenticator/DefaultLogin.java @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.security.authenticator; + +public class DefaultLogin extends AbstractLogin { + + @Override + public String serviceName() { + return "kafka"; + } + + @Override + public void close() { + } +} + diff --git a/clients/src/main/java/org/apache/kafka/common/security/authenticator/LoginManager.java b/clients/src/main/java/org/apache/kafka/common/security/authenticator/LoginManager.java new file mode 100644 index 0000000..6613fd1 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/security/authenticator/LoginManager.java @@ -0,0 +1,225 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.security.authenticator; + +import org.apache.kafka.common.config.ConfigException; +import org.apache.kafka.common.config.SaslConfigs; +import org.apache.kafka.common.config.types.Password; +import org.apache.kafka.common.network.ListenerName; +import org.apache.kafka.common.security.JaasContext; +import org.apache.kafka.common.security.auth.AuthenticateCallbackHandler; +import org.apache.kafka.common.security.auth.Login; +import org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginModule; +import org.apache.kafka.common.security.oauthbearer.internals.unsecured.OAuthBearerUnsecuredLoginCallbackHandler; +import org.apache.kafka.common.utils.SecurityUtils; +import org.apache.kafka.common.utils.Utils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.security.auth.Subject; +import javax.security.auth.login.LoginException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; + +public class LoginManager { + + private static final Logger LOGGER = LoggerFactory.getLogger(LoginManager.class); + + // static configs (broker or client) + private static final Map, LoginManager> STATIC_INSTANCES = new HashMap<>(); + + // dynamic configs (broker or client) + private static final Map, LoginManager> DYNAMIC_INSTANCES = new HashMap<>(); + + private final Login login; + private final LoginMetadata loginMetadata; + private final AuthenticateCallbackHandler loginCallbackHandler; + private int refCount; + + private LoginManager(JaasContext jaasContext, String saslMechanism, Map configs, + LoginMetadata loginMetadata) throws LoginException { + this.loginMetadata = loginMetadata; + this.login = Utils.newInstance(loginMetadata.loginClass); + loginCallbackHandler = Utils.newInstance(loginMetadata.loginCallbackClass); + loginCallbackHandler.configure(configs, saslMechanism, jaasContext.configurationEntries()); + login.configure(configs, jaasContext.name(), jaasContext.configuration(), loginCallbackHandler); + login.login(); + } + + /** + * Returns an instance of `LoginManager` and increases its reference count. + * + * `release()` should be invoked when the `LoginManager` is no longer needed. This method will try to reuse an + * existing `LoginManager` for the provided context type. If `jaasContext` was loaded from a dynamic config, + * login managers are reused for the same dynamic config value. For `jaasContext` loaded from static JAAS + * configuration, login managers are reused for static contexts with the same login context name. + * + * This is a bit ugly and it would be nicer if we could pass the `LoginManager` to `ChannelBuilders.create` and + * shut it down when the broker or clients are closed. It's straightforward to do the former, but it's more + * complicated to do the latter without making the consumer API more complex. + * + * @param jaasContext Static or dynamic JAAS context. `jaasContext.dynamicJaasConfig()` is non-null for dynamic context. + * For static contexts, this may contain multiple login modules if the context type is SERVER. + * For CLIENT static contexts and dynamic contexts of CLIENT and SERVER, 'jaasContext` contains + * only one login module. + * @param saslMechanism SASL mechanism for which login manager is being acquired. For dynamic contexts, the single + * login module in `jaasContext` corresponds to this SASL mechanism. Hence `Login` class is + * chosen based on this mechanism. 
+ * @param defaultLoginClass Default login class to use if an override is not specified in `configs` + * @param configs Config options used to configure `Login` if a new login manager is created. + * + */ + public static LoginManager acquireLoginManager(JaasContext jaasContext, String saslMechanism, + Class defaultLoginClass, + Map configs) throws LoginException { + Class loginClass = configuredClassOrDefault(configs, jaasContext, + saslMechanism, SaslConfigs.SASL_LOGIN_CLASS, defaultLoginClass); + Class defaultLoginCallbackHandlerClass = OAuthBearerLoginModule.OAUTHBEARER_MECHANISM + .equals(saslMechanism) ? OAuthBearerUnsecuredLoginCallbackHandler.class + : AbstractLogin.DefaultLoginCallbackHandler.class; + Class loginCallbackClass = configuredClassOrDefault(configs, jaasContext, + saslMechanism, SaslConfigs.SASL_LOGIN_CALLBACK_HANDLER_CLASS, defaultLoginCallbackHandlerClass); + synchronized (LoginManager.class) { + LoginManager loginManager; + Password jaasConfigValue = jaasContext.dynamicJaasConfig(); + if (jaasConfigValue != null) { + LoginMetadata loginMetadata = new LoginMetadata<>(jaasConfigValue, loginClass, loginCallbackClass); + loginManager = DYNAMIC_INSTANCES.get(loginMetadata); + if (loginManager == null) { + loginManager = new LoginManager(jaasContext, saslMechanism, configs, loginMetadata); + DYNAMIC_INSTANCES.put(loginMetadata, loginManager); + } + } else { + LoginMetadata loginMetadata = new LoginMetadata<>(jaasContext.name(), loginClass, loginCallbackClass); + loginManager = STATIC_INSTANCES.get(loginMetadata); + if (loginManager == null) { + loginManager = new LoginManager(jaasContext, saslMechanism, configs, loginMetadata); + STATIC_INSTANCES.put(loginMetadata, loginManager); + } + } + SecurityUtils.addConfiguredSecurityProviders(configs); + return loginManager.acquire(); + } + } + + public Subject subject() { + return login.subject(); + } + + public String serviceName() { + return login.serviceName(); + } + + // Only for testing + Object cacheKey() { + return loginMetadata.configInfo; + } + + private LoginManager acquire() { + ++refCount; + LOGGER.trace("{} acquired", this); + return this; + } + + /** + * Decrease the reference count for this instance and release resources if it reaches 0. + */ + public void release() { + synchronized (LoginManager.class) { + if (refCount == 0) + throw new IllegalStateException("release() called on disposed " + this); + else if (refCount == 1) { + if (loginMetadata.configInfo instanceof Password) { + DYNAMIC_INSTANCES.remove(loginMetadata); + } else { + STATIC_INSTANCES.remove(loginMetadata); + } + login.close(); + loginCallbackHandler.close(); + } + --refCount; + LOGGER.trace("{} released", this); + } + } + + @Override + public String toString() { + return "LoginManager(serviceName=" + serviceName() + + // subject.toString() exposes private credentials, so we can't use it + ", publicCredentials=" + subject().getPublicCredentials() + + ", refCount=" + refCount + ')'; + } + + /* Should only be used in tests. 
*/ + public static void closeAll() { + synchronized (LoginManager.class) { + for (LoginMetadata key : new ArrayList<>(STATIC_INSTANCES.keySet())) + STATIC_INSTANCES.remove(key).login.close(); + for (LoginMetadata key : new ArrayList<>(DYNAMIC_INSTANCES.keySet())) + DYNAMIC_INSTANCES.remove(key).login.close(); + } + } + + private static Class configuredClassOrDefault(Map configs, + JaasContext jaasContext, + String saslMechanism, + String configName, + Class defaultClass) { + String prefix = jaasContext.type() == JaasContext.Type.SERVER ? ListenerName.saslMechanismPrefix(saslMechanism) : ""; + @SuppressWarnings("unchecked") + Class clazz = (Class) configs.get(prefix + configName); + if (clazz != null && jaasContext.configurationEntries().size() != 1) { + String errorMessage = configName + " cannot be specified with multiple login modules in the JAAS context. " + + SaslConfigs.SASL_JAAS_CONFIG + " must be configured to override mechanism-specific configs."; + throw new ConfigException(errorMessage); + } + if (clazz == null) + clazz = defaultClass; + return clazz; + } + + private static class LoginMetadata { + final T configInfo; + final Class loginClass; + final Class loginCallbackClass; + + LoginMetadata(T configInfo, Class loginClass, + Class loginCallbackClass) { + this.configInfo = configInfo; + this.loginClass = loginClass; + this.loginCallbackClass = loginCallbackClass; + } + + @Override + public int hashCode() { + return Objects.hash(configInfo, loginClass, loginCallbackClass); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + LoginMetadata loginMetadata = (LoginMetadata) o; + return Objects.equals(configInfo, loginMetadata.configInfo) && + Objects.equals(loginClass, loginMetadata.loginClass) && + Objects.equals(loginCallbackClass, loginMetadata.loginCallbackClass); + } + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/security/authenticator/SaslClientAuthenticator.java b/clients/src/main/java/org/apache/kafka/common/security/authenticator/SaslClientAuthenticator.java new file mode 100644 index 0000000..e502f80 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/security/authenticator/SaslClientAuthenticator.java @@ -0,0 +1,711 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
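The acquire/release pairing described in the LoginManager javadoc above usually lives in a channel builder: acquire during configuration, release on close, with the manager reference-counted across users of the same JAAS context. A sketch under those assumptions; the surrounding class is hypothetical, and PLAIN with DefaultLogin is just one possible combination:

import java.util.Map;

import javax.security.auth.Subject;
import javax.security.auth.login.LoginException;

import org.apache.kafka.common.security.JaasContext;
import org.apache.kafka.common.security.authenticator.DefaultLogin;
import org.apache.kafka.common.security.authenticator.LoginManager;

final class LoginManagerUsageSketch {
    private LoginManager loginManager;

    void configure(JaasContext jaasContext, Map<String, ?> configs) throws LoginException {
        // Reuses an existing LoginManager for the same static context name or dynamic jaas config value.
        loginManager = LoginManager.acquireLoginManager(jaasContext, "PLAIN", DefaultLogin.class, configs);
        Subject subject = loginManager.subject();
        // ... build SASL authenticators that authenticate as `subject` ...
    }

    void close() {
        if (loginManager != null)
            loginManager.release();   // the last release() closes the underlying Login and callback handler
    }
}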
+ */ +package org.apache.kafka.common.security.authenticator; + +import org.apache.kafka.clients.CommonClientConfigs; +import org.apache.kafka.clients.NetworkClient; +import org.apache.kafka.common.KafkaException; +import org.apache.kafka.common.config.SaslConfigs; +import org.apache.kafka.common.errors.IllegalSaslStateException; +import org.apache.kafka.common.errors.SaslAuthenticationException; +import org.apache.kafka.common.errors.UnsupportedSaslMechanismException; +import org.apache.kafka.common.message.ApiVersionsResponseData.ApiVersion; +import org.apache.kafka.common.message.RequestHeaderData; +import org.apache.kafka.common.message.SaslAuthenticateRequestData; +import org.apache.kafka.common.message.SaslHandshakeRequestData; +import org.apache.kafka.common.network.Authenticator; +import org.apache.kafka.common.network.ByteBufferSend; +import org.apache.kafka.common.network.NetworkReceive; +import org.apache.kafka.common.network.ReauthenticationContext; +import org.apache.kafka.common.network.Send; +import org.apache.kafka.common.network.TransportLayer; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.Errors; +import org.apache.kafka.common.protocol.types.SchemaException; +import org.apache.kafka.common.requests.AbstractResponse; +import org.apache.kafka.common.requests.ApiVersionsRequest; +import org.apache.kafka.common.requests.ApiVersionsResponse; +import org.apache.kafka.common.requests.RequestHeader; +import org.apache.kafka.common.requests.SaslAuthenticateRequest; +import org.apache.kafka.common.requests.SaslAuthenticateResponse; +import org.apache.kafka.common.requests.SaslHandshakeRequest; +import org.apache.kafka.common.requests.SaslHandshakeResponse; +import org.apache.kafka.common.security.auth.AuthenticateCallbackHandler; +import org.apache.kafka.common.security.auth.KafkaPrincipal; +import org.apache.kafka.common.security.auth.KafkaPrincipalSerde; +import org.apache.kafka.common.security.kerberos.KerberosError; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.common.utils.Utils; +import org.slf4j.Logger; + +import javax.security.auth.Subject; +import javax.security.sasl.Sasl; +import javax.security.sasl.SaslClient; +import javax.security.sasl.SaslException; +import java.io.IOException; +import java.nio.BufferUnderflowException; +import java.nio.ByteBuffer; +import java.nio.channels.SelectionKey; +import java.security.Principal; +import java.security.PrivilegedActionException; +import java.security.PrivilegedExceptionAction; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.Random; +import java.util.Set; + +public class SaslClientAuthenticator implements Authenticator { + /** + * The internal state transitions for initial authentication of a channel are + * declared in order, starting with {@link #SEND_APIVERSIONS_REQUEST} and ending + * in either {@link #COMPLETE} or {@link #FAILED}. + *
        + * Re-authentication of a channel starts with the state + * {@link #REAUTH_PROCESS_ORIG_APIVERSIONS_RESPONSE} and then flows to + * {@link #REAUTH_SEND_HANDSHAKE_REQUEST} followed by + * {@link #REAUTH_RECEIVE_HANDSHAKE_OR_OTHER_RESPONSE} and then + * {@link #REAUTH_INITIAL}; after that the flow joins the authentication flow + * at the {@link #INTERMEDIATE} state and ends at either {@link #COMPLETE} or + * {@link #FAILED}. + */ + public enum SaslState { + SEND_APIVERSIONS_REQUEST, // Initial state for authentication: client sends ApiVersionsRequest in this state when authenticating + RECEIVE_APIVERSIONS_RESPONSE, // Awaiting ApiVersionsResponse from server + SEND_HANDSHAKE_REQUEST, // Received ApiVersionsResponse, send SaslHandshake request + RECEIVE_HANDSHAKE_RESPONSE, // Awaiting SaslHandshake response from server when authenticating + INITIAL, // Initial authentication state starting SASL token exchange for configured mechanism, send first token + INTERMEDIATE, // Intermediate state during SASL token exchange, process challenges and send responses + CLIENT_COMPLETE, // Sent response to last challenge. If using SaslAuthenticate, wait for authentication status from server, else COMPLETE + COMPLETE, // Authentication sequence complete. If using SaslAuthenticate, this state implies successful authentication. + FAILED, // Failed authentication due to an error at some stage + REAUTH_PROCESS_ORIG_APIVERSIONS_RESPONSE, // Initial state for re-authentication: process ApiVersionsResponse from original authentication + REAUTH_SEND_HANDSHAKE_REQUEST, // Processed original ApiVersionsResponse, send SaslHandshake request as part of re-authentication + REAUTH_RECEIVE_HANDSHAKE_OR_OTHER_RESPONSE, // Awaiting SaslHandshake response from server when re-authenticating, and may receive other, in-flight responses sent prior to start of re-authentication as well + REAUTH_INITIAL, // Initial re-authentication state starting SASL token exchange for configured mechanism, send first token + } + + private static final short DISABLE_KAFKA_SASL_AUTHENTICATE_HEADER = -1; + private static final Random RNG = new Random(); + + /** + * the reserved range of correlation id for Sasl requests. + * + * Noted: there is a story about reserved range. The response of LIST_OFFSET is compatible to response of SASL_HANDSHAKE. + * Hence, we could miss the schema error when using schema of SASL_HANDSHAKE to parse response of LIST_OFFSET. + * For example: the IllegalStateException caused by mismatched correlation id is thrown if following steps happens. + * 1) sent LIST_OFFSET + * 2) sent SASL_HANDSHAKE + * 3) receive response of LIST_OFFSET + * 4) succeed to use schema of SASL_HANDSHAKE to parse response of LIST_OFFSET + * 5) throw IllegalStateException due to mismatched correlation id + * As a simple approach, we force Sasl requests to use a reserved correlation id which is separated from those + * used in NetworkClient for Kafka requests. Hence, we can guarantee that every SASL request will throw + * SchemaException due to correlation id mismatch during reauthentication + */ + public static final int MAX_RESERVED_CORRELATION_ID = Integer.MAX_VALUE; + + /** + * We only expect one request in-flight a time during authentication so the small range is fine. + */ + public static final int MIN_RESERVED_CORRELATION_ID = MAX_RESERVED_CORRELATION_ID - 7; + + /** + * @return true if the correlation id is reserved for SASL request. 
otherwise, false + */ + public static boolean isReserved(int correlationId) { + return correlationId >= MIN_RESERVED_CORRELATION_ID; + } + + private final Subject subject; + private final String servicePrincipal; + private final String host; + private final String node; + private final String mechanism; + private final TransportLayer transportLayer; + private final SaslClient saslClient; + private final Map configs; + private final String clientPrincipalName; + private final AuthenticateCallbackHandler callbackHandler; + private final Time time; + private final Logger log; + private final ReauthInfo reauthInfo; + + // buffers used in `authenticate` + private NetworkReceive netInBuffer; + private Send netOutBuffer; + + // Current SASL state + private SaslState saslState; + // Next SASL state to be set when outgoing writes associated with the current SASL state complete + private SaslState pendingSaslState; + // Correlation ID for the next request + private int correlationId; + // Request header for which response from the server is pending + private RequestHeader currentRequestHeader; + // Version of SaslAuthenticate request/responses + private short saslAuthenticateVersion; + // Version of SaslHandshake request/responses + private short saslHandshakeVersion; + + public SaslClientAuthenticator(Map configs, + AuthenticateCallbackHandler callbackHandler, + String node, + Subject subject, + String servicePrincipal, + String host, + String mechanism, + boolean handshakeRequestEnable, + TransportLayer transportLayer, + Time time, + LogContext logContext) { + this.node = node; + this.subject = subject; + this.callbackHandler = callbackHandler; + this.host = host; + this.servicePrincipal = servicePrincipal; + this.mechanism = mechanism; + this.correlationId = 0; + this.transportLayer = transportLayer; + this.configs = configs; + this.saslAuthenticateVersion = DISABLE_KAFKA_SASL_AUTHENTICATE_HEADER; + this.time = time; + this.log = logContext.logger(getClass()); + this.reauthInfo = new ReauthInfo(); + + try { + setSaslState(handshakeRequestEnable ? SaslState.SEND_APIVERSIONS_REQUEST : SaslState.INITIAL); + + // determine client principal from subject for Kerberos to use as authorization id for the SaslClient. + // For other mechanisms, the authenticated principal (username for PLAIN and SCRAM) is used as + // authorization id. Hence the principal is not specified for creating the SaslClient. 
+ if (mechanism.equals(SaslConfigs.GSSAPI_MECHANISM)) + this.clientPrincipalName = firstPrincipal(subject); + else + this.clientPrincipalName = null; + + saslClient = createSaslClient(); + } catch (Exception e) { + throw new SaslAuthenticationException("Failed to configure SaslClientAuthenticator", e); + } + } + + // visible for testing + SaslClient createSaslClient() { + try { + return Subject.doAs(subject, (PrivilegedExceptionAction) () -> { + String[] mechs = {mechanism}; + log.debug("Creating SaslClient: client={};service={};serviceHostname={};mechs={}", + clientPrincipalName, servicePrincipal, host, Arrays.toString(mechs)); + SaslClient retvalSaslClient = Sasl.createSaslClient(mechs, clientPrincipalName, servicePrincipal, host, configs, callbackHandler); + if (retvalSaslClient == null) { + throw new SaslAuthenticationException("Failed to create SaslClient with mechanism " + mechanism); + } + return retvalSaslClient; + }); + } catch (PrivilegedActionException e) { + throw new SaslAuthenticationException("Failed to create SaslClient with mechanism " + mechanism, e.getCause()); + } + } + + /** + * Sends an empty message to the server to initiate the authentication process. It then evaluates server challenges + * via `SaslClient.evaluateChallenge` and returns client responses until authentication succeeds or fails. + * + * The messages are sent and received as size delimited bytes that consists of a 4 byte network-ordered size N + * followed by N bytes representing the opaque payload. + */ + @SuppressWarnings("fallthrough") + public void authenticate() throws IOException { + if (netOutBuffer != null && !flushNetOutBufferAndUpdateInterestOps()) + return; + + switch (saslState) { + case SEND_APIVERSIONS_REQUEST: + // Always use version 0 request since brokers treat requests with schema exceptions as GSSAPI tokens + ApiVersionsRequest apiVersionsRequest = new ApiVersionsRequest.Builder().build((short) 0); + send(apiVersionsRequest.toSend(nextRequestHeader(ApiKeys.API_VERSIONS, apiVersionsRequest.version()))); + setSaslState(SaslState.RECEIVE_APIVERSIONS_RESPONSE); + break; + case RECEIVE_APIVERSIONS_RESPONSE: + ApiVersionsResponse apiVersionsResponse = (ApiVersionsResponse) receiveKafkaResponse(); + if (apiVersionsResponse == null) + break; + else { + setSaslAuthenticateAndHandshakeVersions(apiVersionsResponse); + reauthInfo.apiVersionsResponseReceivedFromBroker = apiVersionsResponse; + setSaslState(SaslState.SEND_HANDSHAKE_REQUEST); + // Fall through to send handshake request with the latest supported version + } + case SEND_HANDSHAKE_REQUEST: + sendHandshakeRequest(saslHandshakeVersion); + setSaslState(SaslState.RECEIVE_HANDSHAKE_RESPONSE); + break; + case RECEIVE_HANDSHAKE_RESPONSE: + SaslHandshakeResponse handshakeResponse = (SaslHandshakeResponse) receiveKafkaResponse(); + if (handshakeResponse == null) + break; + else { + handleSaslHandshakeResponse(handshakeResponse); + setSaslState(SaslState.INITIAL); + // Fall through and start SASL authentication using the configured client mechanism + } + case INITIAL: + sendInitialToken(); + setSaslState(SaslState.INTERMEDIATE); + break; + case REAUTH_PROCESS_ORIG_APIVERSIONS_RESPONSE: + setSaslAuthenticateAndHandshakeVersions(reauthInfo.apiVersionsResponseFromOriginalAuthentication); + setSaslState(SaslState.REAUTH_SEND_HANDSHAKE_REQUEST); // Will set immediately + // Fall through to send handshake request with the latest supported version + case REAUTH_SEND_HANDSHAKE_REQUEST: + sendHandshakeRequest(saslHandshakeVersion); + 
setSaslState(SaslState.REAUTH_RECEIVE_HANDSHAKE_OR_OTHER_RESPONSE); + break; + case REAUTH_RECEIVE_HANDSHAKE_OR_OTHER_RESPONSE: + handshakeResponse = (SaslHandshakeResponse) receiveKafkaResponse(); + if (handshakeResponse == null) + break; + handleSaslHandshakeResponse(handshakeResponse); + setSaslState(SaslState.REAUTH_INITIAL); // Will set immediately + /* + * Fall through and start SASL authentication using the configured client + * mechanism. Note that we have to either fall through or add a loop to enter + * the switch statement again. We will fall through to avoid adding the loop and + * therefore minimize the changes to authentication-related code due to the + * changes related to re-authentication. + */ + case REAUTH_INITIAL: + sendInitialToken(); + setSaslState(SaslState.INTERMEDIATE); + break; + case INTERMEDIATE: + byte[] serverToken = receiveToken(); + boolean noResponsesPending = serverToken != null && !sendSaslClientToken(serverToken, false); + // For versions without SASL_AUTHENTICATE header, SASL exchange may be complete after a token is sent to server. + // For versions with SASL_AUTHENTICATE header, server always sends a response to each SASL_AUTHENTICATE request. + if (saslClient.isComplete()) { + if (saslAuthenticateVersion == DISABLE_KAFKA_SASL_AUTHENTICATE_HEADER || noResponsesPending) + setSaslState(SaslState.COMPLETE); + else + setSaslState(SaslState.CLIENT_COMPLETE); + } + break; + case CLIENT_COMPLETE: + byte[] serverResponse = receiveToken(); + if (serverResponse != null) + setSaslState(SaslState.COMPLETE); + break; + case COMPLETE: + break; + case FAILED: + // Should never get here since exception would have been propagated earlier + throw new IllegalStateException("SASL handshake has already failed"); + } + } + + private void sendHandshakeRequest(short version) throws IOException { + SaslHandshakeRequest handshakeRequest = createSaslHandshakeRequest(version); + send(handshakeRequest.toSend(nextRequestHeader(ApiKeys.SASL_HANDSHAKE, handshakeRequest.version()))); + } + + private void sendInitialToken() throws IOException { + sendSaslClientToken(new byte[0], true); + } + + @Override + public void reauthenticate(ReauthenticationContext reauthenticationContext) throws IOException { + SaslClientAuthenticator previousSaslClientAuthenticator = (SaslClientAuthenticator) Objects + .requireNonNull(reauthenticationContext).previousAuthenticator(); + ApiVersionsResponse apiVersionsResponseFromOriginalAuthentication = previousSaslClientAuthenticator.reauthInfo + .apiVersionsResponse(); + previousSaslClientAuthenticator.close(); + reauthInfo.reauthenticating(apiVersionsResponseFromOriginalAuthentication, + reauthenticationContext.reauthenticationBeginNanos()); + NetworkReceive netInBufferFromChannel = reauthenticationContext.networkReceive(); + netInBuffer = netInBufferFromChannel; + setSaslState(SaslState.REAUTH_PROCESS_ORIG_APIVERSIONS_RESPONSE); // Will set immediately + authenticate(); + } + + @Override + public Optional pollResponseReceivedDuringReauthentication() { + return reauthInfo.pollResponseReceivedDuringReauthentication(); + } + + @Override + public Long clientSessionReauthenticationTimeNanos() { + return reauthInfo.clientSessionReauthenticationTimeNanos; + } + + @Override + public Long reauthenticationLatencyMs() { + return reauthInfo.reauthenticationLatencyMs(); + } + + // visible for testing + int nextCorrelationId() { + if (!isReserved(correlationId)) + correlationId = MIN_RESERVED_CORRELATION_ID; + return correlationId++; + } + + private RequestHeader 
nextRequestHeader(ApiKeys apiKey, short version) { + String clientId = (String) configs.get(CommonClientConfigs.CLIENT_ID_CONFIG); + short requestApiKey = apiKey.id; + currentRequestHeader = new RequestHeader( + new RequestHeaderData(). + setRequestApiKey(requestApiKey). + setRequestApiVersion(version). + setClientId(clientId). + setCorrelationId(nextCorrelationId()), + apiKey.requestHeaderVersion(version)); + return currentRequestHeader; + } + + // Visible to override for testing + protected SaslHandshakeRequest createSaslHandshakeRequest(short version) { + return new SaslHandshakeRequest.Builder( + new SaslHandshakeRequestData().setMechanism(mechanism)).build(version); + } + + // Visible to override for testing + protected void setSaslAuthenticateAndHandshakeVersions(ApiVersionsResponse apiVersionsResponse) { + ApiVersion authenticateVersion = apiVersionsResponse.apiVersion(ApiKeys.SASL_AUTHENTICATE.id); + if (authenticateVersion != null) { + this.saslAuthenticateVersion = (short) Math.min(authenticateVersion.maxVersion(), + ApiKeys.SASL_AUTHENTICATE.latestVersion()); + } + ApiVersion handshakeVersion = apiVersionsResponse.apiVersion(ApiKeys.SASL_HANDSHAKE.id); + if (handshakeVersion != null) { + this.saslHandshakeVersion = (short) Math.min(handshakeVersion.maxVersion(), + ApiKeys.SASL_HANDSHAKE.latestVersion()); + } + } + + private void setSaslState(SaslState saslState) { + if (netOutBuffer != null && !netOutBuffer.completed()) + pendingSaslState = saslState; + else { + this.pendingSaslState = null; + this.saslState = saslState; + log.debug("Set SASL client state to {}", saslState); + if (saslState == SaslState.COMPLETE) { + reauthInfo.setAuthenticationEndAndSessionReauthenticationTimes(time.nanoseconds()); + if (!reauthInfo.reauthenticating()) + transportLayer.removeInterestOps(SelectionKey.OP_WRITE); + else + /* + * Re-authentication is triggered by a write, so we have to make sure that + * pending write is actually sent. + */ + transportLayer.addInterestOps(SelectionKey.OP_WRITE); + } + } + } + + /** + * Sends a SASL client token to server if required. This may be an initial token to start + * SASL token exchange or response to a challenge from the server. 
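The version selection in setSaslAuthenticateAndHandshakeVersions above reduces to one rule: use the highest SaslHandshake/SaslAuthenticate version that both the broker advertises and this client implements. A one-method sketch of that rule, with hypothetical names:

// Sketch of the negotiation rule in setSaslAuthenticateAndHandshakeVersions above:
// pick the highest request version supported by both the broker and this client.
final class VersionNegotiationSketch {
    static short negotiatedVersion(short brokerMaxVersion, short clientLatestVersion) {
        return (short) Math.min(brokerMaxVersion, clientLatestVersion);
    }
}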
+ * @return true if a token was sent to the server + */ + private boolean sendSaslClientToken(byte[] serverToken, boolean isInitial) throws IOException { + if (!saslClient.isComplete()) { + byte[] saslToken = createSaslToken(serverToken, isInitial); + if (saslToken != null) { + ByteBuffer tokenBuf = ByteBuffer.wrap(saslToken); + Send send; + if (saslAuthenticateVersion == DISABLE_KAFKA_SASL_AUTHENTICATE_HEADER) { + send = ByteBufferSend.sizePrefixed(tokenBuf); + } else { + SaslAuthenticateRequestData data = new SaslAuthenticateRequestData() + .setAuthBytes(tokenBuf.array()); + SaslAuthenticateRequest request = new SaslAuthenticateRequest.Builder(data).build(saslAuthenticateVersion); + send = request.toSend(nextRequestHeader(ApiKeys.SASL_AUTHENTICATE, saslAuthenticateVersion)); + } + send(send); + return true; + } + } + return false; + } + + private void send(Send send) throws IOException { + try { + netOutBuffer = send; + flushNetOutBufferAndUpdateInterestOps(); + } catch (IOException e) { + setSaslState(SaslState.FAILED); + throw e; + } + } + + private boolean flushNetOutBufferAndUpdateInterestOps() throws IOException { + boolean flushedCompletely = flushNetOutBuffer(); + if (flushedCompletely) { + transportLayer.removeInterestOps(SelectionKey.OP_WRITE); + if (pendingSaslState != null) + setSaslState(pendingSaslState); + } else + transportLayer.addInterestOps(SelectionKey.OP_WRITE); + return flushedCompletely; + } + + private byte[] receiveResponseOrToken() throws IOException { + if (netInBuffer == null) netInBuffer = new NetworkReceive(node); + netInBuffer.readFrom(transportLayer); + byte[] serverPacket = null; + if (netInBuffer.complete()) { + netInBuffer.payload().rewind(); + serverPacket = new byte[netInBuffer.payload().remaining()]; + netInBuffer.payload().get(serverPacket, 0, serverPacket.length); + netInBuffer = null; // reset the networkReceive as we read all the data. + } + return serverPacket; + } + + public KafkaPrincipal principal() { + return new KafkaPrincipal(KafkaPrincipal.USER_TYPE, clientPrincipalName); + } + + @Override + public Optional principalSerde() { + return Optional.empty(); + } + + public boolean complete() { + return saslState == SaslState.COMPLETE; + } + + public void close() throws IOException { + if (saslClient != null) + saslClient.dispose(); + } + + private byte[] receiveToken() throws IOException { + if (saslAuthenticateVersion == DISABLE_KAFKA_SASL_AUTHENTICATE_HEADER) { + return receiveResponseOrToken(); + } else { + SaslAuthenticateResponse response = (SaslAuthenticateResponse) receiveKafkaResponse(); + if (response != null) { + Errors error = response.error(); + if (error != Errors.NONE) { + setSaslState(SaslState.FAILED); + String errMsg = response.errorMessage(); + throw errMsg == null ? 
error.exception() : error.exception(errMsg); + } + long sessionLifetimeMs = response.sessionLifetimeMs(); + if (sessionLifetimeMs > 0L) + reauthInfo.positiveSessionLifetimeMs = sessionLifetimeMs; + return Utils.copyArray(response.saslAuthBytes()); + } else + return null; + } + } + + + private byte[] createSaslToken(final byte[] saslToken, boolean isInitial) throws SaslException { + if (saslToken == null) + throw new IllegalSaslStateException("Error authenticating with the Kafka Broker: received a `null` saslToken."); + + try { + if (isInitial && !saslClient.hasInitialResponse()) + return saslToken; + else + return Subject.doAs(subject, (PrivilegedExceptionAction) () -> saslClient.evaluateChallenge(saslToken)); + } catch (PrivilegedActionException e) { + String error = "An error: (" + e + ") occurred when evaluating SASL token received from the Kafka Broker."; + KerberosError kerberosError = KerberosError.fromException(e); + // Try to provide hints to use about what went wrong so they can fix their configuration. + if (kerberosError == KerberosError.SERVER_NOT_FOUND) { + error += " This may be caused by Java's being unable to resolve the Kafka Broker's" + + " hostname correctly. You may want to try to adding" + + " '-Dsun.net.spi.nameservice.provider.1=dns,sun' to your client's JVMFLAGS environment." + + " Users must configure FQDN of kafka brokers when authenticating using SASL and" + + " `socketChannel.socket().getInetAddress().getHostName()` must match the hostname in `principal/hostname@realm`"; + } + //Unwrap the SaslException inside `PrivilegedActionException` + Throwable cause = e.getCause(); + // Treat transient Kerberos errors as non-fatal SaslExceptions that are processed as I/O exceptions + // and all other failures as fatal SaslAuthenticationException. + if ((kerberosError != null && kerberosError.retriable()) || (kerberosError == null && KerberosError.isRetriableClientGssException(e))) { + error += " Kafka Client will retry."; + throw new SaslException(error, cause); + } else { + error += " Kafka Client will go to AUTHENTICATION_FAILED state."; + throw new SaslAuthenticationException(error, cause); + } + } + } + + private boolean flushNetOutBuffer() throws IOException { + if (!netOutBuffer.completed()) { + netOutBuffer.writeTo(transportLayer); + } + return netOutBuffer.completed(); + } + + private AbstractResponse receiveKafkaResponse() throws IOException { + if (netInBuffer == null) + netInBuffer = new NetworkReceive(node); + NetworkReceive receive = netInBuffer; + try { + byte[] responseBytes = receiveResponseOrToken(); + if (responseBytes == null) + return null; + else { + AbstractResponse response = NetworkClient.parseResponse(ByteBuffer.wrap(responseBytes), currentRequestHeader); + currentRequestHeader = null; + return response; + } + } catch (BufferUnderflowException | SchemaException | IllegalArgumentException e) { + /* + * Account for the fact that during re-authentication there may be responses + * arriving for requests that were sent in the past. + */ + if (reauthInfo.reauthenticating()) { + /* + * It didn't match the current request header, so it must be unrelated to + * re-authentication. Save it so it can be processed later. 
+ */ + receive.payload().rewind(); + reauthInfo.pendingAuthenticatedReceives.add(receive); + return null; + } + log.debug("Invalid SASL mechanism response, server may be expecting only GSSAPI tokens"); + setSaslState(SaslState.FAILED); + throw new IllegalSaslStateException("Invalid SASL mechanism response, server may be expecting a different protocol", e); + } + } + + private void handleSaslHandshakeResponse(SaslHandshakeResponse response) { + Errors error = response.error(); + if (error != Errors.NONE) + setSaslState(SaslState.FAILED); + switch (error) { + case NONE: + break; + case UNSUPPORTED_SASL_MECHANISM: + throw new UnsupportedSaslMechanismException(String.format("Client SASL mechanism '%s' not enabled in the server, enabled mechanisms are %s", + mechanism, response.enabledMechanisms())); + case ILLEGAL_SASL_STATE: + throw new IllegalSaslStateException(String.format("Unexpected handshake request with client mechanism %s, enabled mechanisms are %s", + mechanism, response.enabledMechanisms())); + default: + throw new IllegalSaslStateException(String.format("Unknown error code %s, client mechanism is %s, enabled mechanisms are %s", + response.error(), mechanism, response.enabledMechanisms())); + } + } + + /** + * Returns the first Principal from Subject. + * @throws KafkaException if there are no Principals in the Subject. + * During Kerberos re-login, principal is reset on Subject. An exception is + * thrown so that the connection is retried after any configured backoff. + */ + public static String firstPrincipal(Subject subject) { + Set principals = subject.getPrincipals(); + synchronized (principals) { + Iterator iterator = principals.iterator(); + if (iterator.hasNext()) + return iterator.next().getName(); + else + throw new KafkaException("Principal could not be determined from Subject, this may be a transient failure due to Kerberos re-login"); + } + } + + /** + * Information related to re-authentication + */ + private class ReauthInfo { + public ApiVersionsResponse apiVersionsResponseFromOriginalAuthentication; + public long reauthenticationBeginNanos; + public List pendingAuthenticatedReceives = new ArrayList<>(); + public ApiVersionsResponse apiVersionsResponseReceivedFromBroker; + public Long positiveSessionLifetimeMs; + public long authenticationEndNanos; + public Long clientSessionReauthenticationTimeNanos; + + public void reauthenticating(ApiVersionsResponse apiVersionsResponseFromOriginalAuthentication, + long reauthenticationBeginNanos) { + this.apiVersionsResponseFromOriginalAuthentication = Objects + .requireNonNull(apiVersionsResponseFromOriginalAuthentication); + this.reauthenticationBeginNanos = reauthenticationBeginNanos; + } + + public boolean reauthenticating() { + return apiVersionsResponseFromOriginalAuthentication != null; + } + + public ApiVersionsResponse apiVersionsResponse() { + return reauthenticating() ? apiVersionsResponseFromOriginalAuthentication + : apiVersionsResponseReceivedFromBroker; + } + + /** + * Return the (always non-null but possibly empty) NetworkReceive response that + * arrived during re-authentication that is unrelated to re-authentication, if + * any. This corresponds to a request sent prior to the beginning of + * re-authentication; the request was made when the channel was successfully + * authenticated, and the response arrived during the re-authentication + * process. 
+ * + * @return the (always non-null but possibly empty) NetworkReceive response + * that arrived during re-authentication that is unrelated to + * re-authentication, if any + */ + public Optional pollResponseReceivedDuringReauthentication() { + if (pendingAuthenticatedReceives.isEmpty()) + return Optional.empty(); + return Optional.of(pendingAuthenticatedReceives.remove(0)); + } + + public void setAuthenticationEndAndSessionReauthenticationTimes(long nowNanos) { + authenticationEndNanos = nowNanos; + long sessionLifetimeMsToUse = 0; + if (positiveSessionLifetimeMs != null) { + // pick a random percentage between 85% and 95% for session re-authentication + double pctWindowFactorToTakeNetworkLatencyAndClockDriftIntoAccount = 0.85; + double pctWindowJitterToAvoidReauthenticationStormAcrossManyChannelsSimultaneously = 0.10; + double pctToUse = pctWindowFactorToTakeNetworkLatencyAndClockDriftIntoAccount + RNG.nextDouble() + * pctWindowJitterToAvoidReauthenticationStormAcrossManyChannelsSimultaneously; + sessionLifetimeMsToUse = (long) (positiveSessionLifetimeMs * pctToUse); + clientSessionReauthenticationTimeNanos = authenticationEndNanos + 1000 * 1000 * sessionLifetimeMsToUse; + log.debug( + "Finished {} with session expiration in {} ms and session re-authentication on or after {} ms", + authenticationOrReauthenticationText(), positiveSessionLifetimeMs, sessionLifetimeMsToUse); + } else + log.debug("Finished {} with no session expiration and no session re-authentication", + authenticationOrReauthenticationText()); + } + + public Long reauthenticationLatencyMs() { + return reauthenticating() + ? Math.round((authenticationEndNanos - reauthenticationBeginNanos) / 1000.0 / 1000.0) + : null; + } + + private String authenticationOrReauthenticationText() { + return reauthenticating() ? "re-authentication" : "authentication"; + } + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/security/authenticator/SaslClientCallbackHandler.java b/clients/src/main/java/org/apache/kafka/common/security/authenticator/SaslClientCallbackHandler.java new file mode 100644 index 0000000..e141bb6 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/security/authenticator/SaslClientCallbackHandler.java @@ -0,0 +1,106 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
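ReauthInfo.setAuthenticationEndAndSessionReauthenticationTimes above schedules client-side re-authentication at a random point between 85% and 95% of the broker-granted session lifetime, so network latency and clock drift are absorbed and many channels do not re-authenticate at the same instant. A small sketch of that calculation, with illustrative names and ThreadLocalRandom standing in for the class's own RNG:

import java.util.concurrent.ThreadLocalRandom;

// Sketch of the client-side re-authentication scheduling described above; names are
// illustrative, not the authenticator's actual fields.
final class ReauthScheduleSketch {
    static long clientReauthTimeNanos(long authenticationEndNanos, long sessionLifetimeMs) {
        double pct = 0.85 + ThreadLocalRandom.current().nextDouble() * 0.10; // 85%..95%
        long reauthAfterMs = (long) (sessionLifetimeMs * pct);
        return authenticationEndNanos + reauthAfterMs * 1_000_000L; // ms -> ns
    }
}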
+ */ +package org.apache.kafka.common.security.authenticator; + +import java.security.AccessController; +import java.util.List; +import java.util.Map; + +import javax.security.auth.Subject; +import javax.security.auth.callback.Callback; +import javax.security.auth.callback.NameCallback; +import javax.security.auth.callback.PasswordCallback; +import javax.security.auth.callback.UnsupportedCallbackException; +import javax.security.auth.login.AppConfigurationEntry; +import javax.security.sasl.AuthorizeCallback; +import javax.security.sasl.RealmCallback; + +import org.apache.kafka.common.config.SaslConfigs; +import org.apache.kafka.common.security.auth.SaslExtensionsCallback; +import org.apache.kafka.common.security.auth.AuthenticateCallbackHandler; +import org.apache.kafka.common.security.auth.SaslExtensions; +import org.apache.kafka.common.security.scram.ScramExtensionsCallback; +import org.apache.kafka.common.security.scram.internals.ScramMechanism; + +/** + * Default callback handler for Sasl clients. The callbacks required for the SASL mechanism + * configured for the client should be supported by this callback handler. See + * Java SASL API + * for the list of SASL callback handlers required for each SASL mechanism. + * + * For adding custom SASL extensions, a {@link SaslExtensions} may be added to the subject's public credentials + */ +public class SaslClientCallbackHandler implements AuthenticateCallbackHandler { + + private String mechanism; + + @Override + public void configure(Map configs, String saslMechanism, List jaasConfigEntries) { + this.mechanism = saslMechanism; + } + + @Override + public void handle(Callback[] callbacks) throws UnsupportedCallbackException { + Subject subject = Subject.getSubject(AccessController.getContext()); + for (Callback callback : callbacks) { + if (callback instanceof NameCallback) { + NameCallback nc = (NameCallback) callback; + if (subject != null && !subject.getPublicCredentials(String.class).isEmpty()) { + nc.setName(subject.getPublicCredentials(String.class).iterator().next()); + } else + nc.setName(nc.getDefaultName()); + } else if (callback instanceof PasswordCallback) { + if (subject != null && !subject.getPrivateCredentials(String.class).isEmpty()) { + char[] password = subject.getPrivateCredentials(String.class).iterator().next().toCharArray(); + ((PasswordCallback) callback).setPassword(password); + } else { + String errorMessage = "Could not login: the client is being asked for a password, but the Kafka" + + " client code does not currently support obtaining a password from the user."; + throw new UnsupportedCallbackException(callback, errorMessage); + } + } else if (callback instanceof RealmCallback) { + RealmCallback rc = (RealmCallback) callback; + rc.setText(rc.getDefaultText()); + } else if (callback instanceof AuthorizeCallback) { + AuthorizeCallback ac = (AuthorizeCallback) callback; + String authId = ac.getAuthenticationID(); + String authzId = ac.getAuthorizationID(); + ac.setAuthorized(authId.equals(authzId)); + if (ac.isAuthorized()) + ac.setAuthorizedID(authzId); + } else if (callback instanceof ScramExtensionsCallback) { + if (ScramMechanism.isScram(mechanism) && subject != null && !subject.getPublicCredentials(Map.class).isEmpty()) { + @SuppressWarnings("unchecked") + Map extensions = (Map) subject.getPublicCredentials(Map.class).iterator().next(); + ((ScramExtensionsCallback) callback).extensions(extensions); + } + } else if (callback instanceof SaslExtensionsCallback) { + if 
(!SaslConfigs.GSSAPI_MECHANISM.equals(mechanism) && + subject != null && !subject.getPublicCredentials(SaslExtensions.class).isEmpty()) { + SaslExtensions extensions = subject.getPublicCredentials(SaslExtensions.class).iterator().next(); + ((SaslExtensionsCallback) callback).extensions(extensions); + } + } else { + throw new UnsupportedCallbackException(callback, "Unrecognized SASL ClientCallback"); + } + } + } + + @Override + public void close() { + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/security/authenticator/SaslInternalConfigs.java b/clients/src/main/java/org/apache/kafka/common/security/authenticator/SaslInternalConfigs.java new file mode 100644 index 0000000..c1793eb --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/security/authenticator/SaslInternalConfigs.java @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.security.authenticator; + +import org.apache.kafka.common.config.internals.BrokerSecurityConfigs; + +public class SaslInternalConfigs { + /** + * The server (broker) specifies a positive session length in milliseconds to a + * SASL client when {@link BrokerSecurityConfigs#CONNECTIONS_MAX_REAUTH_MS} is + * positive as per KIP + * 368: Allow SASL Connections to Periodically Re-Authenticate. The session + * length is the minimum of the configured value and any session length implied + * by the credential presented during authentication. The lifetime defined by + * the credential, in terms of milliseconds since the epoch, is available via a + * negotiated property on the SASL Server instance, and that value can be + * converted to a session length by subtracting the time at which authentication + * occurred. This variable defines the negotiated property key that is used to + * communicate the credential lifetime in milliseconds since the epoch. + */ + public static final String CREDENTIAL_LIFETIME_MS_SASL_NEGOTIATED_PROPERTY_KEY = "CREDENTIAL.LIFETIME.MS"; + + private SaslInternalConfigs() { + // empty + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/security/authenticator/SaslServerAuthenticator.java b/clients/src/main/java/org/apache/kafka/common/security/authenticator/SaslServerAuthenticator.java new file mode 100644 index 0000000..6e35ee7 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/security/authenticator/SaslServerAuthenticator.java @@ -0,0 +1,726 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.security.authenticator; + +import org.apache.kafka.common.KafkaException; +import org.apache.kafka.common.config.SaslConfigs; +import org.apache.kafka.common.config.internals.BrokerSecurityConfigs; +import org.apache.kafka.common.errors.AuthenticationException; +import org.apache.kafka.common.errors.IllegalSaslStateException; +import org.apache.kafka.common.errors.InvalidRequestException; +import org.apache.kafka.common.errors.SaslAuthenticationException; +import org.apache.kafka.common.errors.UnsupportedSaslMechanismException; +import org.apache.kafka.common.errors.UnsupportedVersionException; +import org.apache.kafka.common.message.SaslAuthenticateResponseData; +import org.apache.kafka.common.message.SaslHandshakeResponseData; +import org.apache.kafka.common.network.Authenticator; +import org.apache.kafka.common.network.ByteBufferSend; +import org.apache.kafka.common.network.ChannelBuilders; +import org.apache.kafka.common.network.ChannelMetadataRegistry; +import org.apache.kafka.common.network.ClientInformation; +import org.apache.kafka.common.network.ListenerName; +import org.apache.kafka.common.network.NetworkReceive; +import org.apache.kafka.common.network.ReauthenticationContext; +import org.apache.kafka.common.network.Send; +import org.apache.kafka.common.network.SslTransportLayer; +import org.apache.kafka.common.network.TransportLayer; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.Errors; +import org.apache.kafka.common.requests.AbstractResponse; +import org.apache.kafka.common.requests.ApiVersionsRequest; +import org.apache.kafka.common.requests.ApiVersionsResponse; +import org.apache.kafka.common.requests.RequestAndSize; +import org.apache.kafka.common.requests.RequestContext; +import org.apache.kafka.common.requests.RequestHeader; +import org.apache.kafka.common.requests.SaslAuthenticateRequest; +import org.apache.kafka.common.requests.SaslAuthenticateResponse; +import org.apache.kafka.common.requests.SaslHandshakeRequest; +import org.apache.kafka.common.requests.SaslHandshakeResponse; +import org.apache.kafka.common.security.auth.AuthenticateCallbackHandler; +import org.apache.kafka.common.security.auth.KafkaPrincipal; +import org.apache.kafka.common.security.auth.KafkaPrincipalBuilder; +import org.apache.kafka.common.security.auth.KafkaPrincipalSerde; +import org.apache.kafka.common.security.auth.SaslAuthenticationContext; +import org.apache.kafka.common.security.auth.SecurityProtocol; +import org.apache.kafka.common.security.kerberos.KerberosError; +import org.apache.kafka.common.security.kerberos.KerberosName; +import org.apache.kafka.common.security.kerberos.KerberosShortNamer; +import org.apache.kafka.common.security.scram.ScramLoginModule; +import org.apache.kafka.common.security.scram.internals.ScramMechanism; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.common.utils.Utils; +import 
org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.net.ssl.SSLSession; +import javax.security.auth.Subject; +import javax.security.sasl.Sasl; +import javax.security.sasl.SaslException; +import javax.security.sasl.SaslServer; +import java.io.Closeable; +import java.io.IOException; +import java.net.InetAddress; +import java.nio.ByteBuffer; +import java.nio.channels.SelectionKey; +import java.security.PrivilegedActionException; +import java.security.PrivilegedExceptionAction; +import java.util.ArrayList; +import java.util.Date; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.function.Supplier; + +public class SaslServerAuthenticator implements Authenticator { + // GSSAPI limits requests to 64K, but we allow a bit extra for custom SASL mechanisms + static final int MAX_RECEIVE_SIZE = 524288; + private static final Logger LOG = LoggerFactory.getLogger(SaslServerAuthenticator.class); + + /** + * The internal state transitions for initial authentication of a channel on the + * server side are declared in order, starting with {@link #INITIAL_REQUEST} and + * ending in either {@link #COMPLETE} or {@link #FAILED}. + *

        + * Re-authentication of a channel on the server side starts with the state + * {@link #REAUTH_PROCESS_HANDSHAKE}. It may then flow to + * {@link #REAUTH_BAD_MECHANISM} before a transition to {@link #FAILED} if + * re-authentication is attempted with a mechanism different than the original + * one; otherwise it joins the authentication flow at the {@link #AUTHENTICATE} + * state and likewise ends at either {@link #COMPLETE} or {@link #FAILED}. + */ + private enum SaslState { + INITIAL_REQUEST, // May be GSSAPI token, SaslHandshake or ApiVersions for authentication + HANDSHAKE_OR_VERSIONS_REQUEST, // May be SaslHandshake or ApiVersions + HANDSHAKE_REQUEST, // After an ApiVersions request, next request must be SaslHandshake + AUTHENTICATE, // Authentication tokens (SaslHandshake v1 and above indicate SaslAuthenticate headers) + COMPLETE, // Authentication completed successfully + FAILED, // Authentication failed + REAUTH_PROCESS_HANDSHAKE, // Initial state for re-authentication, processes SASL handshake request + REAUTH_BAD_MECHANISM, // When re-authentication requested with wrong mechanism, generate exception + } + + private final SecurityProtocol securityProtocol; + private final ListenerName listenerName; + private final String connectionId; + private final Map subjects; + private final TransportLayer transportLayer; + private final List enabledMechanisms; + private final Map configs; + private final KafkaPrincipalBuilder principalBuilder; + private final Map callbackHandlers; + private final Map connectionsMaxReauthMsByMechanism; + private final Time time; + private final ReauthInfo reauthInfo; + private final ChannelMetadataRegistry metadataRegistry; + private final Supplier apiVersionSupplier; + + // Current SASL state + private SaslState saslState = SaslState.INITIAL_REQUEST; + // Next SASL state to be set when outgoing writes associated with the current SASL state complete + private SaslState pendingSaslState = null; + // Exception that will be thrown by `authenticate()` when SaslState is set to FAILED after outbound writes complete + private AuthenticationException pendingException = null; + private SaslServer saslServer; + private String saslMechanism; + + // buffers used in `authenticate` + private NetworkReceive netInBuffer; + private Send netOutBuffer; + private Send authenticationFailureSend = null; + // flag indicating if sasl tokens are sent as Kafka SaslAuthenticate request/responses + private boolean enableKafkaSaslAuthenticateHeaders; + + public SaslServerAuthenticator(Map configs, + Map callbackHandlers, + String connectionId, + Map subjects, + KerberosShortNamer kerberosNameParser, + ListenerName listenerName, + SecurityProtocol securityProtocol, + TransportLayer transportLayer, + Map connectionsMaxReauthMsByMechanism, + ChannelMetadataRegistry metadataRegistry, + Time time, + Supplier apiVersionSupplier) { + this.callbackHandlers = callbackHandlers; + this.connectionId = connectionId; + this.subjects = subjects; + this.listenerName = listenerName; + this.securityProtocol = securityProtocol; + this.enableKafkaSaslAuthenticateHeaders = false; + this.transportLayer = transportLayer; + this.connectionsMaxReauthMsByMechanism = connectionsMaxReauthMsByMechanism; + this.time = time; + this.reauthInfo = new ReauthInfo(); + this.metadataRegistry = metadataRegistry; + this.apiVersionSupplier = apiVersionSupplier; + + this.configs = configs; + @SuppressWarnings("unchecked") + List enabledMechanisms = (List) 
this.configs.get(BrokerSecurityConfigs.SASL_ENABLED_MECHANISMS_CONFIG); + if (enabledMechanisms == null || enabledMechanisms.isEmpty()) + throw new IllegalArgumentException("No SASL mechanisms are enabled"); + this.enabledMechanisms = new ArrayList<>(new HashSet<>(enabledMechanisms)); + for (String mechanism : this.enabledMechanisms) { + if (!callbackHandlers.containsKey(mechanism)) + throw new IllegalArgumentException("Callback handler not specified for SASL mechanism " + mechanism); + if (!subjects.containsKey(mechanism)) + throw new IllegalArgumentException("Subject cannot be null for SASL mechanism " + mechanism); + LOG.trace("{} for mechanism={}: {}", BrokerSecurityConfigs.CONNECTIONS_MAX_REAUTH_MS, mechanism, + connectionsMaxReauthMsByMechanism.get(mechanism)); + } + + // Note that the old principal builder does not support SASL, so we do not need to pass the + // authenticator or the transport layer + this.principalBuilder = ChannelBuilders.createPrincipalBuilder(configs, kerberosNameParser, null); + } + + private void createSaslServer(String mechanism) throws IOException { + this.saslMechanism = mechanism; + Subject subject = subjects.get(mechanism); + final AuthenticateCallbackHandler callbackHandler = callbackHandlers.get(mechanism); + if (mechanism.equals(SaslConfigs.GSSAPI_MECHANISM)) { + saslServer = createSaslKerberosServer(callbackHandler, configs, subject); + } else { + try { + saslServer = Subject.doAs(subject, (PrivilegedExceptionAction) () -> + Sasl.createSaslServer(saslMechanism, "kafka", serverAddress().getHostName(), configs, callbackHandler)); + if (saslServer == null) { + throw new SaslException("Kafka Server failed to create a SaslServer to interact with a client during session authentication with server mechanism " + saslMechanism); + } + } catch (PrivilegedActionException e) { + throw new SaslException("Kafka Server failed to create a SaslServer to interact with a client during session authentication with server mechanism " + saslMechanism, e.getCause()); + } + } + } + + private SaslServer createSaslKerberosServer(final AuthenticateCallbackHandler saslServerCallbackHandler, final Map configs, Subject subject) throws IOException { + // server is using a JAAS-authenticated subject: determine service principal name and hostname from kafka server's subject. + final String servicePrincipal = SaslClientAuthenticator.firstPrincipal(subject); + KerberosName kerberosName; + try { + kerberosName = KerberosName.parse(servicePrincipal); + } catch (IllegalArgumentException e) { + throw new KafkaException("Principal has name with unexpected format " + servicePrincipal); + } + final String servicePrincipalName = kerberosName.serviceName(); + final String serviceHostname = kerberosName.hostName(); + + LOG.debug("Creating SaslServer for {} with mechanism {}", kerberosName, saslMechanism); + + try { + return Subject.doAs(subject, (PrivilegedExceptionAction) () -> + Sasl.createSaslServer(saslMechanism, servicePrincipalName, serviceHostname, configs, saslServerCallbackHandler)); + } catch (PrivilegedActionException e) { + throw new SaslException("Kafka Server failed to create a SaslServer to interact with a client during session authentication", e.getCause()); + } + } + + /** + * Evaluates client responses via `SaslServer.evaluateResponse` and returns the issued challenge to the client until + * authentication succeeds or fails. 
+ * + * The messages are sent and received as size delimited bytes that consists of a 4 byte network-ordered size N + * followed by N bytes representing the opaque payload. + */ + @SuppressWarnings("fallthrough") + @Override + public void authenticate() throws IOException { + if (saslState != SaslState.REAUTH_PROCESS_HANDSHAKE) { + if (netOutBuffer != null && !flushNetOutBufferAndUpdateInterestOps()) + return; + + if (saslServer != null && saslServer.isComplete()) { + setSaslState(SaslState.COMPLETE); + return; + } + + // allocate on heap (as opposed to any socket server memory pool) + if (netInBuffer == null) netInBuffer = new NetworkReceive(MAX_RECEIVE_SIZE, connectionId); + + netInBuffer.readFrom(transportLayer); + if (!netInBuffer.complete()) + return; + netInBuffer.payload().rewind(); + } + byte[] clientToken = new byte[netInBuffer.payload().remaining()]; + netInBuffer.payload().get(clientToken, 0, clientToken.length); + netInBuffer = null; // reset the networkReceive as we read all the data. + try { + switch (saslState) { + case REAUTH_PROCESS_HANDSHAKE: + case HANDSHAKE_OR_VERSIONS_REQUEST: + case HANDSHAKE_REQUEST: + handleKafkaRequest(clientToken); + break; + case REAUTH_BAD_MECHANISM: + throw new SaslAuthenticationException(reauthInfo.badMechanismErrorMessage); + case INITIAL_REQUEST: + if (handleKafkaRequest(clientToken)) + break; + // For default GSSAPI, fall through to authenticate using the client token as the first GSSAPI packet. + // This is required for interoperability with 0.9.0.x clients which do not send handshake request + case AUTHENTICATE: + handleSaslToken(clientToken); + // When the authentication exchange is complete and no more tokens are expected from the client, + // update SASL state. Current SASL state will be updated when outgoing writes to the client complete. + if (saslServer.isComplete()) + setSaslState(SaslState.COMPLETE); + break; + default: + break; + } + } catch (AuthenticationException e) { + // Exception will be propagated after response is sent to client + setSaslState(SaslState.FAILED, e); + } catch (Exception e) { + // In the case of IOExceptions and other unexpected exceptions, fail immediately + saslState = SaslState.FAILED; + LOG.debug("Failed during {}: {}", reauthInfo.authenticationOrReauthenticationText(), e.getMessage()); + throw e; + } + } + + @Override + public KafkaPrincipal principal() { + Optional sslSession = transportLayer instanceof SslTransportLayer ? + Optional.of(((SslTransportLayer) transportLayer).sslSession()) : Optional.empty(); + SaslAuthenticationContext context = new SaslAuthenticationContext(saslServer, securityProtocol, + clientAddress(), listenerName.value(), sslSession); + KafkaPrincipal principal = principalBuilder.build(context); + if (ScramMechanism.isScram(saslMechanism) && Boolean.parseBoolean((String) saslServer.getNegotiatedProperty(ScramLoginModule.TOKEN_AUTH_CONFIG))) { + principal.tokenAuthenticated(true); + } + return principal; + } + + @Override + public Optional principalSerde() { + return principalBuilder instanceof KafkaPrincipalSerde ? 
Optional.of((KafkaPrincipalSerde) principalBuilder) : Optional.empty(); + } + + @Override + public boolean complete() { + return saslState == SaslState.COMPLETE; + } + + @Override + public void handleAuthenticationFailure() throws IOException { + sendAuthenticationFailureResponse(); + } + + @Override + public void close() throws IOException { + if (principalBuilder instanceof Closeable) + Utils.closeQuietly((Closeable) principalBuilder, "principal builder"); + if (saslServer != null) + saslServer.dispose(); + } + + @Override + public void reauthenticate(ReauthenticationContext reauthenticationContext) throws IOException { + NetworkReceive saslHandshakeReceive = reauthenticationContext.networkReceive(); + if (saslHandshakeReceive == null) + throw new IllegalArgumentException( + "Invalid saslHandshakeReceive in server-side re-authentication context: null"); + SaslServerAuthenticator previousSaslServerAuthenticator = (SaslServerAuthenticator) reauthenticationContext.previousAuthenticator(); + reauthInfo.reauthenticating(previousSaslServerAuthenticator.saslMechanism, + previousSaslServerAuthenticator.principal(), reauthenticationContext.reauthenticationBeginNanos()); + previousSaslServerAuthenticator.close(); + netInBuffer = saslHandshakeReceive; + LOG.debug("Beginning re-authentication: {}", this); + netInBuffer.payload().rewind(); + setSaslState(SaslState.REAUTH_PROCESS_HANDSHAKE); + authenticate(); + } + + @Override + public Long serverSessionExpirationTimeNanos() { + return reauthInfo.sessionExpirationTimeNanos; + } + + @Override + public Long reauthenticationLatencyMs() { + return reauthInfo.reauthenticationLatencyMs(); + } + + @Override + public boolean connectedClientSupportsReauthentication() { + return reauthInfo.connectedClientSupportsReauthentication; + } + + private void setSaslState(SaslState saslState) { + setSaslState(saslState, null); + } + + private void setSaslState(SaslState saslState, AuthenticationException exception) { + if (netOutBuffer != null && !netOutBuffer.completed()) { + pendingSaslState = saslState; + pendingException = exception; + } else { + this.saslState = saslState; + LOG.debug("Set SASL server state to {} during {}", saslState, reauthInfo.authenticationOrReauthenticationText()); + this.pendingSaslState = null; + this.pendingException = null; + if (exception != null) + throw exception; + } + } + + private boolean flushNetOutBufferAndUpdateInterestOps() throws IOException { + boolean flushedCompletely = flushNetOutBuffer(); + if (flushedCompletely) { + transportLayer.removeInterestOps(SelectionKey.OP_WRITE); + if (pendingSaslState != null) + setSaslState(pendingSaslState, pendingException); + } else + transportLayer.addInterestOps(SelectionKey.OP_WRITE); + return flushedCompletely; + } + + private boolean flushNetOutBuffer() throws IOException { + if (!netOutBuffer.completed()) + netOutBuffer.writeTo(transportLayer); + return netOutBuffer.completed(); + } + + private InetAddress serverAddress() { + return transportLayer.socketChannel().socket().getLocalAddress(); + } + + private InetAddress clientAddress() { + return transportLayer.socketChannel().socket().getInetAddress(); + } + + private void handleSaslToken(byte[] clientToken) throws IOException { + if (!enableKafkaSaslAuthenticateHeaders) { + byte[] response = saslServer.evaluateResponse(clientToken); + if (saslServer.isComplete()) { + reauthInfo.calcCompletionTimesAndReturnSessionLifetimeMs(); + if (reauthInfo.reauthenticating()) + reauthInfo.ensurePrincipalUnchanged(principal()); + } + if (response != 
null) { + netOutBuffer = ByteBufferSend.sizePrefixed(ByteBuffer.wrap(response)); + flushNetOutBufferAndUpdateInterestOps(); + } + } else { + ByteBuffer requestBuffer = ByteBuffer.wrap(clientToken); + RequestHeader header = RequestHeader.parse(requestBuffer); + ApiKeys apiKey = header.apiKey(); + short version = header.apiVersion(); + RequestContext requestContext = new RequestContext(header, connectionId, clientAddress(), + KafkaPrincipal.ANONYMOUS, listenerName, securityProtocol, ClientInformation.EMPTY, false); + RequestAndSize requestAndSize = requestContext.parseRequest(requestBuffer); + if (apiKey != ApiKeys.SASL_AUTHENTICATE) { + IllegalSaslStateException e = new IllegalSaslStateException("Unexpected Kafka request of type " + apiKey + " during SASL authentication."); + buildResponseOnAuthenticateFailure(requestContext, requestAndSize.request.getErrorResponse(e)); + throw e; + } + if (!apiKey.isVersionSupported(version)) { + // We cannot create an error response if the request version of SaslAuthenticate is not supported + // This should not normally occur since clients typically check supported versions using ApiVersionsRequest + throw new UnsupportedVersionException("Version " + version + " is not supported for apiKey " + apiKey); + } + /* + * The client sends multiple SASL_AUTHENTICATE requests, and the client is known + * to support the required version if any one of them indicates it supports that + * version. + */ + if (!reauthInfo.connectedClientSupportsReauthentication) + reauthInfo.connectedClientSupportsReauthentication = version > 0; + SaslAuthenticateRequest saslAuthenticateRequest = (SaslAuthenticateRequest) requestAndSize.request; + + try { + byte[] responseToken = saslServer.evaluateResponse( + Utils.copyArray(saslAuthenticateRequest.data().authBytes())); + if (reauthInfo.reauthenticating() && saslServer.isComplete()) + reauthInfo.ensurePrincipalUnchanged(principal()); + // For versions with SASL_AUTHENTICATE header, send a response to SASL_AUTHENTICATE request even if token is empty. + byte[] responseBytes = responseToken == null ? new byte[0] : responseToken; + long sessionLifetimeMs = !saslServer.isComplete() ? 0L + : reauthInfo.calcCompletionTimesAndReturnSessionLifetimeMs(); + sendKafkaResponse(requestContext, new SaslAuthenticateResponse( + new SaslAuthenticateResponseData() + .setErrorCode(Errors.NONE.code()) + .setAuthBytes(responseBytes) + .setSessionLifetimeMs(sessionLifetimeMs))); + } catch (SaslAuthenticationException e) { + buildResponseOnAuthenticateFailure(requestContext, + new SaslAuthenticateResponse( + new SaslAuthenticateResponseData() + .setErrorCode(Errors.SASL_AUTHENTICATION_FAILED.code()) + .setErrorMessage(e.getMessage()))); + throw e; + } catch (SaslException e) { + KerberosError kerberosError = KerberosError.fromException(e); + if (kerberosError != null && kerberosError.retriable()) { + // Handle retriable Kerberos exceptions as I/O exceptions rather than authentication exceptions + throw e; + } else { + // DO NOT include error message from the `SaslException` in the client response since it may + // contain sensitive data like the existence of the user. 
+ String errorMessage = "Authentication failed during " + + reauthInfo.authenticationOrReauthenticationText() + + " due to invalid credentials with SASL mechanism " + saslMechanism; + buildResponseOnAuthenticateFailure(requestContext, new SaslAuthenticateResponse( + new SaslAuthenticateResponseData() + .setErrorCode(Errors.SASL_AUTHENTICATION_FAILED.code()) + .setErrorMessage(errorMessage))); + throw new SaslAuthenticationException(errorMessage, e); + } + } + } + } + + private boolean handleKafkaRequest(byte[] requestBytes) throws IOException, AuthenticationException { + boolean isKafkaRequest = false; + String clientMechanism = null; + try { + ByteBuffer requestBuffer = ByteBuffer.wrap(requestBytes); + RequestHeader header = RequestHeader.parse(requestBuffer); + ApiKeys apiKey = header.apiKey(); + + // A valid Kafka request header was received. SASL authentication tokens are now expected only + // following a SaslHandshakeRequest since this is not a GSSAPI client token from a Kafka 0.9.0.x client. + if (saslState == SaslState.INITIAL_REQUEST) + setSaslState(SaslState.HANDSHAKE_OR_VERSIONS_REQUEST); + isKafkaRequest = true; + + // Raise an error prior to parsing if the api cannot be handled at this layer. This avoids + // unnecessary exposure to some of the more complex schema types. + if (apiKey != ApiKeys.API_VERSIONS && apiKey != ApiKeys.SASL_HANDSHAKE) + throw new IllegalSaslStateException("Unexpected Kafka request of type " + apiKey + " during SASL handshake."); + + LOG.debug("Handling Kafka request {} during {}", apiKey, reauthInfo.authenticationOrReauthenticationText()); + + + RequestContext requestContext = new RequestContext(header, connectionId, clientAddress(), + KafkaPrincipal.ANONYMOUS, listenerName, securityProtocol, ClientInformation.EMPTY, false); + RequestAndSize requestAndSize = requestContext.parseRequest(requestBuffer); + if (apiKey == ApiKeys.API_VERSIONS) + handleApiVersionsRequest(requestContext, (ApiVersionsRequest) requestAndSize.request); + else + clientMechanism = handleHandshakeRequest(requestContext, (SaslHandshakeRequest) requestAndSize.request); + } catch (InvalidRequestException e) { + if (saslState == SaslState.INITIAL_REQUEST) { + // InvalidRequestException is thrown if the request is not in Kafka format or if the API key + // is invalid. For compatibility with 0.9.0.x where the first packet is a GSSAPI token + // starting with 0x60, revert to GSSAPI for both these exceptions. 
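The fallback above relies on the behaviour noted in the comment: a 0.9.0.x GSSAPI client's first packet is a Kerberos token whose first byte is 0x60 rather than a Kafka request. A rough sketch of that distinction; the production code detects it indirectly by catching the InvalidRequestException thrown while parsing the request header.

// Rough sketch of the compatibility heuristic described above: an initial GSSAPI (Kerberos)
// token is an ASN.1 [APPLICATION 0] structure whose first byte is 0x60, whereas a Kafka
// request begins with a request header. Illustrative only.
final class GssapiTokenSketch {
    static boolean looksLikeGssapiToken(byte[] firstPacket) {
        return firstPacket.length > 0 && (firstPacket[0] & 0xFF) == 0x60;
    }
}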
+ if (LOG.isDebugEnabled()) { + StringBuilder tokenBuilder = new StringBuilder(); + for (byte b : requestBytes) { + tokenBuilder.append(String.format("%02x", b)); + if (tokenBuilder.length() >= 20) + break; + } + LOG.debug("Received client packet of length {} starting with bytes 0x{}, process as GSSAPI packet", requestBytes.length, tokenBuilder); + } + if (enabledMechanisms.contains(SaslConfigs.GSSAPI_MECHANISM)) { + LOG.debug("First client packet is not a SASL mechanism request, using default mechanism GSSAPI"); + clientMechanism = SaslConfigs.GSSAPI_MECHANISM; + } else + throw new UnsupportedSaslMechanismException("Exception handling first SASL packet from client, GSSAPI is not supported by server", e); + } else + throw e; + } + if (clientMechanism != null && (!reauthInfo.reauthenticating() + || reauthInfo.saslMechanismUnchanged(clientMechanism))) { + createSaslServer(clientMechanism); + setSaslState(SaslState.AUTHENTICATE); + } + return isKafkaRequest; + } + + private String handleHandshakeRequest(RequestContext context, SaslHandshakeRequest handshakeRequest) throws IOException, UnsupportedSaslMechanismException { + String clientMechanism = handshakeRequest.data().mechanism(); + short version = context.header.apiVersion(); + if (version >= 1) + this.enableKafkaSaslAuthenticateHeaders(true); + if (enabledMechanisms.contains(clientMechanism)) { + LOG.debug("Using SASL mechanism '{}' provided by client", clientMechanism); + sendKafkaResponse(context, new SaslHandshakeResponse( + new SaslHandshakeResponseData().setErrorCode(Errors.NONE.code()).setMechanisms(enabledMechanisms))); + return clientMechanism; + } else { + LOG.debug("SASL mechanism '{}' requested by client is not supported", clientMechanism); + buildResponseOnAuthenticateFailure(context, new SaslHandshakeResponse( + new SaslHandshakeResponseData().setErrorCode(Errors.UNSUPPORTED_SASL_MECHANISM.code()).setMechanisms(enabledMechanisms))); + throw new UnsupportedSaslMechanismException("Unsupported SASL mechanism " + clientMechanism); + } + } + + // Visible to override for testing + protected void enableKafkaSaslAuthenticateHeaders(boolean flag) { + this.enableKafkaSaslAuthenticateHeaders = flag; + } + + private void handleApiVersionsRequest(RequestContext context, ApiVersionsRequest apiVersionsRequest) throws IOException { + if (saslState != SaslState.HANDSHAKE_OR_VERSIONS_REQUEST) + throw new IllegalStateException("Unexpected ApiVersions request received during SASL authentication state " + saslState); + + if (apiVersionsRequest.hasUnsupportedRequestVersion()) + sendKafkaResponse(context, apiVersionsRequest.getErrorResponse(0, Errors.UNSUPPORTED_VERSION.exception())); + else if (!apiVersionsRequest.isValid()) + sendKafkaResponse(context, apiVersionsRequest.getErrorResponse(0, Errors.INVALID_REQUEST.exception())); + else { + metadataRegistry.registerClientInformation(new ClientInformation(apiVersionsRequest.data().clientSoftwareName(), + apiVersionsRequest.data().clientSoftwareVersion())); + sendKafkaResponse(context, apiVersionSupplier.get()); + setSaslState(SaslState.HANDSHAKE_REQUEST); + } + } + + /** + * Build a {@link Send} response on {@link #authenticate()} failure. The actual response is sent out when + * {@link #sendAuthenticationFailureResponse()} is called. + */ + private void buildResponseOnAuthenticateFailure(RequestContext context, AbstractResponse response) { + authenticationFailureSend = context.buildResponseSend(response); + } + + /** + * Send any authentication failure response that may have been previously built. 
+ */ + private void sendAuthenticationFailureResponse() throws IOException { + if (authenticationFailureSend == null) + return; + sendKafkaResponse(authenticationFailureSend); + authenticationFailureSend = null; + } + + private void sendKafkaResponse(RequestContext context, AbstractResponse response) throws IOException { + sendKafkaResponse(context.buildResponseSend(response)); + } + + private void sendKafkaResponse(Send send) throws IOException { + netOutBuffer = send; + flushNetOutBufferAndUpdateInterestOps(); + } + + /** + * Information related to re-authentication + */ + private class ReauthInfo { + public String previousSaslMechanism; + public KafkaPrincipal previousKafkaPrincipal; + public long reauthenticationBeginNanos; + public Long sessionExpirationTimeNanos; + public boolean connectedClientSupportsReauthentication; + public long authenticationEndNanos; + public String badMechanismErrorMessage; + + public void reauthenticating(String previousSaslMechanism, KafkaPrincipal previousKafkaPrincipal, + long reauthenticationBeginNanos) { + this.previousSaslMechanism = Objects.requireNonNull(previousSaslMechanism); + this.previousKafkaPrincipal = Objects.requireNonNull(previousKafkaPrincipal); + this.reauthenticationBeginNanos = reauthenticationBeginNanos; + } + + public boolean reauthenticating() { + return previousSaslMechanism != null; + } + + public String authenticationOrReauthenticationText() { + return reauthenticating() ? "re-authentication" : "authentication"; + } + + public void ensurePrincipalUnchanged(KafkaPrincipal reauthenticatedKafkaPrincipal) throws SaslAuthenticationException { + if (!previousKafkaPrincipal.equals(reauthenticatedKafkaPrincipal)) { + throw new SaslAuthenticationException(String.format( + "Cannot change principals during re-authentication from %s.%s: %s.%s", + previousKafkaPrincipal.getPrincipalType(), previousKafkaPrincipal.getName(), + reauthenticatedKafkaPrincipal.getPrincipalType(), reauthenticatedKafkaPrincipal.getName())); + } + } + + /* + * We define the REAUTH_BAD_MECHANISM state because the failed re-authentication + * metric does not get updated if we send back an error immediately upon the + * start of re-authentication. 
+ */ + public boolean saslMechanismUnchanged(String clientMechanism) { + if (previousSaslMechanism.equals(clientMechanism)) + return true; + badMechanismErrorMessage = String.format( + "SASL mechanism '%s' requested by client is not supported for re-authentication of mechanism '%s'", + clientMechanism, previousSaslMechanism); + LOG.debug(badMechanismErrorMessage); + setSaslState(SaslState.REAUTH_BAD_MECHANISM); + return false; + } + + private long calcCompletionTimesAndReturnSessionLifetimeMs() { + long retvalSessionLifetimeMs = 0L; + long authenticationEndMs = time.milliseconds(); + authenticationEndNanos = time.nanoseconds(); + Long credentialExpirationMs = (Long) saslServer + .getNegotiatedProperty(SaslInternalConfigs.CREDENTIAL_LIFETIME_MS_SASL_NEGOTIATED_PROPERTY_KEY); + Long connectionsMaxReauthMs = connectionsMaxReauthMsByMechanism.get(saslMechanism); + if (credentialExpirationMs != null || connectionsMaxReauthMs != null) { + if (credentialExpirationMs == null) + retvalSessionLifetimeMs = zeroIfNegative(connectionsMaxReauthMs); + else if (connectionsMaxReauthMs == null) + retvalSessionLifetimeMs = zeroIfNegative(credentialExpirationMs - authenticationEndMs); + else + retvalSessionLifetimeMs = zeroIfNegative( + Math.min(credentialExpirationMs - authenticationEndMs, connectionsMaxReauthMs)); + if (retvalSessionLifetimeMs > 0L) + sessionExpirationTimeNanos = authenticationEndNanos + 1000 * 1000 * retvalSessionLifetimeMs; + } + if (credentialExpirationMs != null) { + if (sessionExpirationTimeNanos != null) + LOG.debug( + "Authentication complete; session max lifetime from broker config={} ms, credential expiration={} ({} ms); session expiration = {} ({} ms), sending {} ms to client", + connectionsMaxReauthMs, new Date(credentialExpirationMs), + credentialExpirationMs - authenticationEndMs, + new Date(authenticationEndMs + retvalSessionLifetimeMs), retvalSessionLifetimeMs, + retvalSessionLifetimeMs); + else + LOG.debug( + "Authentication complete; session max lifetime from broker config={} ms, credential expiration={} ({} ms); no session expiration, sending 0 ms to client", + connectionsMaxReauthMs, new Date(credentialExpirationMs), + credentialExpirationMs - authenticationEndMs); + } else { + if (sessionExpirationTimeNanos != null) + LOG.debug( + "Authentication complete; session max lifetime from broker config={} ms, no credential expiration; session expiration = {} ({} ms), sending {} ms to client", + connectionsMaxReauthMs, new Date(authenticationEndMs + retvalSessionLifetimeMs), + retvalSessionLifetimeMs, retvalSessionLifetimeMs); + else + LOG.debug( + "Authentication complete; session max lifetime from broker config={} ms, no credential expiration; no session expiration, sending 0 ms to client", + connectionsMaxReauthMs); + } + return retvalSessionLifetimeMs; + } + + public Long reauthenticationLatencyMs() { + if (!reauthenticating()) + return null; + // record at least 1 ms if there is some latency + long latencyNanos = authenticationEndNanos - reauthenticationBeginNanos; + return latencyNanos == 0L ? 
0L : Math.max(1L, Math.round(latencyNanos / 1000.0 / 1000.0)); + } + + private long zeroIfNegative(long value) { + return Math.max(0L, value); + } + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/security/authenticator/SaslServerCallbackHandler.java b/clients/src/main/java/org/apache/kafka/common/security/authenticator/SaslServerCallbackHandler.java new file mode 100644 index 0000000..d3d43cb --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/security/authenticator/SaslServerCallbackHandler.java @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.security.authenticator; + +import java.util.List; +import java.util.Map; + +import javax.security.auth.callback.Callback; +import javax.security.auth.callback.UnsupportedCallbackException; +import javax.security.auth.login.AppConfigurationEntry; +import javax.security.sasl.AuthorizeCallback; +import javax.security.sasl.RealmCallback; + +import org.apache.kafka.common.config.SaslConfigs; +import org.apache.kafka.common.security.auth.AuthenticateCallbackHandler; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Default callback handler for Sasl servers. The callbacks required for all the SASL + * mechanisms enabled in the server should be supported by this callback handler. See + * Java SASL API + * for the list of SASL callback handlers required for each SASL mechanism. 
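The server-side calcCompletionTimesAndReturnSessionLifetimeMs above reduces to taking the smaller of the credential's remaining lifetime and the connections.max.reauth.ms setting, clamped at zero. A condensed sketch under those assumptions, with illustrative names:

// Condensed sketch of the session lifetime rule in SaslServerAuthenticator.ReauthInfo above:
// the lifetime sent to the client is the smaller of the credential's remaining lifetime and
// connections.max.reauth.ms, never negative; 0 means no session expiration.
final class SessionLifetimeSketch {
    static long sessionLifetimeMs(Long credentialExpirationMs, Long connectionsMaxReauthMs, long nowMs) {
        if (credentialExpirationMs == null && connectionsMaxReauthMs == null)
            return 0L;
        if (credentialExpirationMs == null)
            return Math.max(0L, connectionsMaxReauthMs);
        if (connectionsMaxReauthMs == null)
            return Math.max(0L, credentialExpirationMs - nowMs);
        return Math.max(0L, Math.min(credentialExpirationMs - nowMs, connectionsMaxReauthMs));
    }
}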
+ */ +public class SaslServerCallbackHandler implements AuthenticateCallbackHandler { + private static final Logger LOG = LoggerFactory.getLogger(SaslServerCallbackHandler.class); + + private String mechanism; + + @Override + public void configure(Map configs, String mechanism, List jaasConfigEntries) { + this.mechanism = mechanism; + } + + @Override + public void handle(Callback[] callbacks) throws UnsupportedCallbackException { + for (Callback callback : callbacks) { + if (callback instanceof RealmCallback) + handleRealmCallback((RealmCallback) callback); + else if (callback instanceof AuthorizeCallback && mechanism.equals(SaslConfigs.GSSAPI_MECHANISM)) + handleAuthorizeCallback((AuthorizeCallback) callback); + else + throw new UnsupportedCallbackException(callback); + } + } + + private void handleRealmCallback(RealmCallback rc) { + LOG.trace("Client supplied realm: {} ", rc.getDefaultText()); + rc.setText(rc.getDefaultText()); + } + + private void handleAuthorizeCallback(AuthorizeCallback ac) { + String authenticationID = ac.getAuthenticationID(); + String authorizationID = ac.getAuthorizationID(); + LOG.info("Successfully authenticated client: authenticationID={}; authorizationID={}.", + authenticationID, authorizationID); + ac.setAuthorized(true); + ac.setAuthorizedID(authenticationID); + } + + @Override + public void close() { + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/security/kerberos/BadFormatString.java b/clients/src/main/java/org/apache/kafka/common/security/kerberos/BadFormatString.java new file mode 100644 index 0000000..3f9070b --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/security/kerberos/BadFormatString.java @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.security.kerberos; + +import java.io.IOException; + +public class BadFormatString extends IOException { + BadFormatString(String msg) { + super(msg); + } + BadFormatString(String msg, Throwable err) { + super(msg, err); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/security/kerberos/KerberosClientCallbackHandler.java b/clients/src/main/java/org/apache/kafka/common/security/kerberos/KerberosClientCallbackHandler.java new file mode 100644 index 0000000..fa9cad2 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/security/kerberos/KerberosClientCallbackHandler.java @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.security.kerberos; + +import org.apache.kafka.common.config.SaslConfigs; +import org.apache.kafka.common.security.auth.AuthenticateCallbackHandler; + +import javax.security.auth.callback.Callback; +import javax.security.auth.callback.NameCallback; +import javax.security.auth.callback.PasswordCallback; +import javax.security.auth.callback.UnsupportedCallbackException; +import javax.security.auth.login.AppConfigurationEntry; +import javax.security.sasl.AuthorizeCallback; +import javax.security.sasl.RealmCallback; +import java.util.List; +import java.util.Map; + +/** + * Callback handler for SASL/GSSAPI clients. + */ +public class KerberosClientCallbackHandler implements AuthenticateCallbackHandler { + + @Override + public void configure(Map configs, String saslMechanism, List jaasConfigEntries) { + if (!saslMechanism.equals(SaslConfigs.GSSAPI_MECHANISM)) + throw new IllegalStateException("Kerberos callback handler should only be used with GSSAPI"); + } + + @Override + public void handle(Callback[] callbacks) throws UnsupportedCallbackException { + for (Callback callback : callbacks) { + if (callback instanceof NameCallback) { + NameCallback nc = (NameCallback) callback; + nc.setName(nc.getDefaultName()); + } else if (callback instanceof PasswordCallback) { + String errorMessage = "Could not login: the client is being asked for a password, but the Kafka" + + " client code does not currently support obtaining a password from the user."; + errorMessage += " Make sure -Djava.security.auth.login.config property passed to JVM and" + + " the client is configured to use a ticket cache (using" + + " the JAAS configuration setting 'useTicketCache=true)'. Make sure you are using" + + " FQDN of the Kafka broker you are trying to connect to."; + throw new UnsupportedCallbackException(callback, errorMessage); + } else if (callback instanceof RealmCallback) { + RealmCallback rc = (RealmCallback) callback; + rc.setText(rc.getDefaultText()); + } else if (callback instanceof AuthorizeCallback) { + AuthorizeCallback ac = (AuthorizeCallback) callback; + String authId = ac.getAuthenticationID(); + String authzId = ac.getAuthorizationID(); + ac.setAuthorized(authId.equals(authzId)); + if (ac.isAuthorized()) + ac.setAuthorizedID(authzId); + } else { + throw new UnsupportedCallbackException(callback, "Unrecognized SASL ClientCallback"); + } + } + } + + @Override + public void close() { + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/security/kerberos/KerberosError.java b/clients/src/main/java/org/apache/kafka/common/security/kerberos/KerberosError.java new file mode 100644 index 0000000..4b8e8e0 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/security/kerberos/KerberosError.java @@ -0,0 +1,131 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.security.kerberos; + +import org.apache.kafka.common.KafkaException; +import org.apache.kafka.common.security.authenticator.SaslClientAuthenticator; +import org.apache.kafka.common.utils.Java; +import org.ietf.jgss.GSSException; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.security.sasl.SaslClient; +import java.lang.reflect.Method; + +/** + * Kerberos exceptions that may require special handling. The standard Kerberos error codes + * for these errors are retrieved using KrbException#errorCode() from the underlying Kerberos + * exception thrown during {@link SaslClient#evaluateChallenge(byte[])}. + */ +public enum KerberosError { + // (Mechanism level: Server not found in Kerberos database (7) - UNKNOWN_SERVER) + // This is retriable, but included here to add extra logging for this case. + SERVER_NOT_FOUND(7, false), + // (Mechanism level: Client not yet valid - try again later (21)) + CLIENT_NOT_YET_VALID(21, true), + // (Mechanism level: Ticket not yet valid (33) - Ticket not yet valid)]) + // This could be a small timing window. + TICKET_NOT_YET_VALID(33, true), + // (Mechanism level: Request is a replay (34) - Request is a replay) + // Replay detection used to prevent DoS attacks can result in false positives, so retry on error. 
+ REPLAY(34, true); + + private static final Logger log = LoggerFactory.getLogger(SaslClientAuthenticator.class); + private static final Class KRB_EXCEPTION_CLASS; + private static final Method KRB_EXCEPTION_RETURN_CODE_METHOD; + + static { + try { + // different IBM JDKs versions include different security implementations + if (Java.isIbmJdk() && canLoad("com.ibm.security.krb5.KrbException")) { + KRB_EXCEPTION_CLASS = Class.forName("com.ibm.security.krb5.KrbException"); + } else if (Java.isIbmJdk() && canLoad("com.ibm.security.krb5.internal.KrbException")) { + KRB_EXCEPTION_CLASS = Class.forName("com.ibm.security.krb5.internal.KrbException"); + } else { + KRB_EXCEPTION_CLASS = Class.forName("sun.security.krb5.KrbException"); + } + KRB_EXCEPTION_RETURN_CODE_METHOD = KRB_EXCEPTION_CLASS.getMethod("returnCode"); + } catch (Exception e) { + throw new KafkaException("Kerberos exceptions could not be initialized", e); + } + } + + private static boolean canLoad(String clazz) { + try { + Class.forName(clazz); + return true; + } catch (Exception e) { + return false; + } + } + + private final int errorCode; + private final boolean retriable; + + KerberosError(int errorCode, boolean retriable) { + this.errorCode = errorCode; + this.retriable = retriable; + } + + public boolean retriable() { + return retriable; + } + + public static KerberosError fromException(Exception exception) { + Throwable cause = exception.getCause(); + while (cause != null && !KRB_EXCEPTION_CLASS.isInstance(cause)) { + cause = cause.getCause(); + } + if (cause == null) + return null; + else { + try { + Integer errorCode = (Integer) KRB_EXCEPTION_RETURN_CODE_METHOD.invoke(cause); + return fromErrorCode(errorCode); + } catch (Exception e) { + log.trace("Kerberos return code could not be determined from {} due to {}", exception, e); + return null; + } + } + } + + private static KerberosError fromErrorCode(int errorCode) { + for (KerberosError error : values()) { + if (error.errorCode == errorCode) + return error; + } + return null; + } + + /** + * Returns true if the exception should be handled as a transient failure on clients. + * We handle GSSException.NO_CRED as retriable on the client-side since this may + * occur during re-login if a clients attempts to authentication after logout, but + * before the subsequent login. + */ + public static boolean isRetriableClientGssException(Exception exception) { + Throwable cause = exception.getCause(); + while (cause != null && !(cause instanceof GSSException)) { + cause = cause.getCause(); + } + if (cause != null) { + GSSException gssException = (GSSException) cause; + return gssException.getMajor() == GSSException.NO_CRED; + } + return false; + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/security/kerberos/KerberosLogin.java b/clients/src/main/java/org/apache/kafka/common/security/kerberos/KerberosLogin.java new file mode 100644 index 0000000..f2b25a5 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/security/kerberos/KerberosLogin.java @@ -0,0 +1,396 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.security.kerberos; + +import javax.security.auth.kerberos.KerberosPrincipal; +import javax.security.auth.login.AppConfigurationEntry; +import javax.security.auth.login.Configuration; +import javax.security.auth.login.LoginContext; +import javax.security.auth.login.LoginException; +import javax.security.auth.kerberos.KerberosTicket; +import javax.security.auth.Subject; + +import org.apache.kafka.common.security.JaasContext; +import org.apache.kafka.common.security.JaasUtils; +import org.apache.kafka.common.security.auth.AuthenticateCallbackHandler; +import org.apache.kafka.common.security.authenticator.AbstractLogin; +import org.apache.kafka.common.config.SaslConfigs; +import org.apache.kafka.common.utils.KafkaThread; +import org.apache.kafka.common.utils.Shell; +import org.apache.kafka.common.utils.Time; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Arrays; +import java.util.Date; +import java.util.List; +import java.util.Map; +import java.util.Random; +import java.util.Set; + +/** + * This class is responsible for refreshing Kerberos credentials for + * logins for both Kafka client and server. + */ +public class KerberosLogin extends AbstractLogin { + private static final Logger log = LoggerFactory.getLogger(KerberosLogin.class); + + private static final Random RNG = new Random(); + + private final Time time = Time.SYSTEM; + private Thread t; + private boolean isKrbTicket; + private boolean isUsingTicketCache; + + private String principal; + + // LoginThread will sleep until 80% of time from last refresh to + // ticket's expiry has been reached, at which time it will wake + // and try to renew the ticket. + private double ticketRenewWindowFactor; + + /** + * Percentage of random jitter added to the renewal time + */ + private double ticketRenewJitter; + + // Regardless of ticketRenewWindowFactor setting above and the ticket expiry time, + // thread will not sleep between refresh attempts any less than 1 minute (60*1000 milliseconds = 1 minute). + // Change the '1' to e.g. 5, to change this to 5 minutes. 
+ private long minTimeBeforeRelogin; + + private String kinitCmd; + + private volatile Subject subject; + + private LoginContext loginContext; + private String serviceName; + private long lastLogin; + + @Override + public void configure(Map configs, String contextName, Configuration configuration, + AuthenticateCallbackHandler callbackHandler) { + super.configure(configs, contextName, configuration, callbackHandler); + this.ticketRenewWindowFactor = (Double) configs.get(SaslConfigs.SASL_KERBEROS_TICKET_RENEW_WINDOW_FACTOR); + this.ticketRenewJitter = (Double) configs.get(SaslConfigs.SASL_KERBEROS_TICKET_RENEW_JITTER); + this.minTimeBeforeRelogin = (Long) configs.get(SaslConfigs.SASL_KERBEROS_MIN_TIME_BEFORE_RELOGIN); + this.kinitCmd = (String) configs.get(SaslConfigs.SASL_KERBEROS_KINIT_CMD); + this.serviceName = getServiceName(configs, contextName, configuration); + } + + /** + * Performs login for each login module specified for the login context of this instance and starts the thread used + * to periodically re-login to the Kerberos Ticket Granting Server. + */ + @Override + public LoginContext login() throws LoginException { + + this.lastLogin = currentElapsedTime(); + loginContext = super.login(); + subject = loginContext.getSubject(); + isKrbTicket = !subject.getPrivateCredentials(KerberosTicket.class).isEmpty(); + + AppConfigurationEntry[] entries = configuration().getAppConfigurationEntry(contextName()); + if (entries.length == 0) { + isUsingTicketCache = false; + principal = null; + } else { + // there will only be a single entry + AppConfigurationEntry entry = entries[0]; + if (entry.getOptions().get("useTicketCache") != null) { + String val = (String) entry.getOptions().get("useTicketCache"); + isUsingTicketCache = val.equals("true"); + } else + isUsingTicketCache = false; + if (entry.getOptions().get("principal") != null) + principal = (String) entry.getOptions().get("principal"); + else + principal = null; + } + + if (!isKrbTicket) { + log.debug("[Principal={}]: It is not a Kerberos ticket", principal); + t = null; + // if no TGT, do not bother with ticket management. + return loginContext; + } + log.debug("[Principal={}]: It is a Kerberos ticket", principal); + + // Refresh the Ticket Granting Ticket (TGT) periodically. How often to refresh is determined by the + // TGT's existing expiry date and the configured minTimeBeforeRelogin. For testing and development, + // you can decrease the interval of expiration of tickets (for example, to 3 minutes) by running: + // "modprinc -maxlife 3mins " in kadmin. + t = KafkaThread.daemon(String.format("kafka-kerberos-refresh-thread-%s", principal), () -> { + log.info("[Principal={}]: TGT refresh thread started.", principal); + while (true) { // renewal thread's main loop. if it exits from here, thread will exit. + KerberosTicket tgt = getTGT(); + long now = currentWallTime(); + long nextRefresh; + Date nextRefreshDate; + if (tgt == null) { + nextRefresh = now + minTimeBeforeRelogin; + nextRefreshDate = new Date(nextRefresh); + log.warn("[Principal={}]: No TGT found: will try again at {}", principal, nextRefreshDate); + } else { + nextRefresh = getRefreshTime(tgt); + long expiry = tgt.getEndTime().getTime(); + Date expiryDate = new Date(expiry); + if (isUsingTicketCache && tgt.getRenewTill() != null && tgt.getRenewTill().getTime() < expiry) { + log.warn("The TGT cannot be renewed beyond the next expiry date: {}." 
+ + "This process will not be able to authenticate new SASL connections after that " + + "time (for example, it will not be able to authenticate a new connection with a Kafka " + + "Broker). Ask your system administrator to either increase the " + + "'renew until' time by doing : 'modprinc -maxrenewlife {} ' within " + + "kadmin, or instead, to generate a keytab for {}. Because the TGT's " + + "expiry cannot be further extended by refreshing, exiting refresh thread now.", + expiryDate, principal, principal); + return; + } + // determine how long to sleep from looking at ticket's expiry. + // We should not allow the ticket to expire, but we should take into consideration + // minTimeBeforeRelogin. Will not sleep less than minTimeBeforeRelogin, unless doing so + // would cause ticket expiration. + if ((nextRefresh > expiry) || (minTimeBeforeRelogin > expiry - now)) { + // expiry is before next scheduled refresh). + log.info("[Principal={}]: Refreshing now because expiry is before next scheduled refresh time.", principal); + nextRefresh = now; + } else { + if (nextRefresh - now < minTimeBeforeRelogin) { + // next scheduled refresh is sooner than (now + MIN_TIME_BEFORE_LOGIN). + Date until = new Date(nextRefresh); + Date newUntil = new Date(now + minTimeBeforeRelogin); + log.warn("[Principal={}]: TGT refresh thread time adjusted from {} to {} since the former is sooner " + + "than the minimum refresh interval ({} seconds) from now.", + principal, until, newUntil, minTimeBeforeRelogin / 1000); + } + nextRefresh = Math.max(nextRefresh, now + minTimeBeforeRelogin); + } + nextRefreshDate = new Date(nextRefresh); + if (nextRefresh > expiry) { + log.error("[Principal={}]: Next refresh: {} is later than expiry {}. This may indicate a clock skew problem." + + "Check that this host and the KDC hosts' clocks are in sync. Exiting refresh thread.", + principal, nextRefreshDate, expiryDate); + return; + } + } + if (now < nextRefresh) { + Date until = new Date(nextRefresh); + log.info("[Principal={}]: TGT refresh sleeping until: {}", principal, until); + try { + Thread.sleep(nextRefresh - now); + } catch (InterruptedException ie) { + log.warn("[Principal={}]: TGT renewal thread has been interrupted and will exit.", principal); + return; + } + } else { + log.error("[Principal={}]: NextRefresh: {} is in the past: exiting refresh thread. Check" + + " clock sync between this host and KDC - (KDC's clock is likely ahead of this host)." + + " Manual intervention will be required for this client to successfully authenticate." + + " Exiting refresh thread.", principal, nextRefreshDate); + return; + } + if (isUsingTicketCache) { + String kinitArgs = "-R"; + int retry = 1; + while (retry >= 0) { + try { + log.debug("[Principal={}]: Running ticket cache refresh command: {} {}", principal, kinitCmd, kinitArgs); + Shell.execCommand(kinitCmd, kinitArgs); + break; + } catch (Exception e) { + if (retry > 0) { + log.warn("[Principal={}]: Error when trying to renew with TicketCache, but will retry ", principal, e); + --retry; + // sleep for 10 seconds + try { + Thread.sleep(10 * 1000); + } catch (InterruptedException ie) { + log.error("[Principal={}]: Interrupted while renewing TGT, exiting Login thread", principal); + return; + } + } else { + log.warn("[Principal={}]: Could not renew TGT due to problem running shell command: '{} {}'. 
" + + "Exiting refresh thread.", principal, kinitCmd, kinitArgs, e); + return; + } + } + } + } + try { + int retry = 1; + while (retry >= 0) { + try { + reLogin(); + break; + } catch (LoginException le) { + if (retry > 0) { + log.warn("[Principal={}]: Error when trying to re-Login, but will retry ", principal, le); + --retry; + // sleep for 10 seconds. + try { + Thread.sleep(10 * 1000); + } catch (InterruptedException e) { + log.error("[Principal={}]: Interrupted during login retry after LoginException:", principal, le); + throw le; + } + } else { + log.error("[Principal={}]: Could not refresh TGT.", principal, le); + } + } + } + } catch (LoginException le) { + log.error("[Principal={}]: Failed to refresh TGT: refresh thread exiting now.", principal, le); + return; + } + } + }); + t.start(); + return loginContext; + } + + @Override + public void close() { + if ((t != null) && (t.isAlive())) { + t.interrupt(); + try { + t.join(); + } catch (InterruptedException e) { + log.warn("[Principal={}]: Error while waiting for Login thread to shutdown.", principal, e); + Thread.currentThread().interrupt(); + } + } + } + + @Override + public Subject subject() { + return subject; + } + + @Override + public String serviceName() { + return serviceName; + } + + private static String getServiceName(Map configs, String contextName, Configuration configuration) { + List configEntries = Arrays.asList(configuration.getAppConfigurationEntry(contextName)); + String jaasServiceName = JaasContext.configEntryOption(configEntries, JaasUtils.SERVICE_NAME, null); + String configServiceName = (String) configs.get(SaslConfigs.SASL_KERBEROS_SERVICE_NAME); + if (jaasServiceName != null && configServiceName != null && !jaasServiceName.equals(configServiceName)) { + String message = String.format("Conflicting serviceName values found in JAAS and Kafka configs " + + "value in JAAS file %s, value in Kafka config %s", jaasServiceName, configServiceName); + throw new IllegalArgumentException(message); + } + + if (jaasServiceName != null) + return jaasServiceName; + if (configServiceName != null) + return configServiceName; + + throw new IllegalArgumentException("No serviceName defined in either JAAS or Kafka config"); + } + + + private long getRefreshTime(KerberosTicket tgt) { + long start = tgt.getStartTime().getTime(); + long expires = tgt.getEndTime().getTime(); + log.info("[Principal={}]: TGT valid starting at: {}", principal, tgt.getStartTime()); + log.info("[Principal={}]: TGT expires: {}", principal, tgt.getEndTime()); + long proposedRefresh = start + (long) ((expires - start) * + (ticketRenewWindowFactor + (ticketRenewJitter * RNG.nextDouble()))); + + if (proposedRefresh > expires) + // proposedRefresh is too far in the future: it's after ticket expires: simply return now. 
+ return currentWallTime(); + else + return proposedRefresh; + } + + private KerberosTicket getTGT() { + Set tickets = subject.getPrivateCredentials(KerberosTicket.class); + for (KerberosTicket ticket : tickets) { + KerberosPrincipal server = ticket.getServer(); + if (server.getName().equals("krbtgt/" + server.getRealm() + "@" + server.getRealm())) { + log.debug("Found TGT with client principal '{}' and server principal '{}'.", ticket.getClient().getName(), + ticket.getServer().getName()); + return ticket; + } + } + return null; + } + + private boolean hasSufficientTimeElapsed() { + long now = currentElapsedTime(); + if (now - lastLogin < minTimeBeforeRelogin) { + log.warn("[Principal={}]: Not attempting to re-login since the last re-login was attempted less than {} seconds before.", + principal, minTimeBeforeRelogin / 1000); + return false; + } + return true; + } + + /** + * Re-login a principal. This method assumes that {@link #login()} has happened already. + * @throws javax.security.auth.login.LoginException on a failure + */ + protected void reLogin() throws LoginException { + if (!isKrbTicket) { + return; + } + if (loginContext == null) { + throw new LoginException("Login must be done first"); + } + if (!hasSufficientTimeElapsed()) { + return; + } + synchronized (KerberosLogin.class) { + log.info("Initiating logout for {}", principal); + // register most recent relogin attempt + lastLogin = currentElapsedTime(); + //clear up the kerberos state. But the tokens are not cleared! As per + //the Java kerberos login module code, only the kerberos credentials + //are cleared. If previous logout succeeded but login failed, we shouldn't + //logout again since duplicate logout causes NPE from Java 9 onwards. + if (subject != null && !subject.getPrincipals().isEmpty()) { + logout(); + } + //login and also update the subject field of this instance to + //have the new credentials (pass it to the LoginContext constructor) + loginContext = new LoginContext(contextName(), subject, null, configuration()); + log.info("Initiating re-login for {}", principal); + login(loginContext); + } + } + + // Visibility to override for testing + protected void login(LoginContext loginContext) throws LoginException { + loginContext.login(); + } + + // Visibility to override for testing + protected void logout() throws LoginException { + loginContext.logout(); + } + + private long currentElapsedTime() { + return time.hiResClockMs(); + } + + private long currentWallTime() { + return time.milliseconds(); + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/security/kerberos/KerberosName.java b/clients/src/main/java/org/apache/kafka/common/security/kerberos/KerberosName.java new file mode 100644 index 0000000..8ac6d7e --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/security/kerberos/KerberosName.java @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.security.kerberos; + +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +public class KerberosName { + + /** + * A pattern that matches a Kerberos name with at most 3 components. + */ + private static final Pattern NAME_PARSER = Pattern.compile("([^/@]*)(/([^/@]*))?@([^/@]*)"); + + /** The first component of the name */ + private final String serviceName; + /** The second component of the name. It may be null. */ + private final String hostName; + /** The realm of the name. */ + private final String realm; + + /** + * Creates an instance of `KerberosName` with the provided parameters. + */ + public KerberosName(String serviceName, String hostName, String realm) { + if (serviceName == null) + throw new IllegalArgumentException("serviceName must not be null"); + this.serviceName = serviceName; + this.hostName = hostName; + this.realm = realm; + } + + /** + * Create a name from the full Kerberos principal name. + */ + public static KerberosName parse(String principalName) { + Matcher match = NAME_PARSER.matcher(principalName); + if (!match.matches()) { + if (principalName.contains("@")) { + throw new IllegalArgumentException("Malformed Kerberos name: " + principalName); + } else { + return new KerberosName(principalName, null, null); + } + } else { + return new KerberosName(match.group(1), match.group(3), match.group(4)); + } + } + + /** + * Put the name back together from the parts. + */ + @Override + public String toString() { + StringBuilder result = new StringBuilder(); + result.append(serviceName); + if (hostName != null) { + result.append('/'); + result.append(hostName); + } + if (realm != null) { + result.append('@'); + result.append(realm); + } + return result.toString(); + } + + /** + * Get the first component of the name. + * @return the first section of the Kerberos principal name + */ + public String serviceName() { + return serviceName; + } + + /** + * Get the second component of the name. + * @return the second section of the Kerberos principal name, and may be null + */ + public String hostName() { + return hostName; + } + + /** + * Get the realm of the name. + * @return the realm of the name, may be null + */ + public String realm() { + return realm; + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/security/kerberos/KerberosRule.java b/clients/src/main/java/org/apache/kafka/common/security/kerberos/KerberosRule.java new file mode 100644 index 0000000..91280ca --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/security/kerberos/KerberosRule.java @@ -0,0 +1,206 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.security.kerberos; + +import java.io.IOException; +import java.util.Locale; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +/** + * An encoding of a rule for translating kerberos names. + */ +class KerberosRule { + + /** + * A pattern that matches a string without '$' and then a single + * parameter with $n. + */ + private static final Pattern PARAMETER_PATTERN = Pattern.compile("([^$]*)(\\$(\\d*))?"); + + /** + * A pattern that recognizes simple/non-simple names. + */ + private static final Pattern NON_SIMPLE_PATTERN = Pattern.compile("[/@]"); + + private final String defaultRealm; + private final boolean isDefault; + private final int numOfComponents; + private final String format; + private final Pattern match; + private final Pattern fromPattern; + private final String toPattern; + private final boolean repeat; + private final boolean toLowerCase; + private final boolean toUpperCase; + + KerberosRule(String defaultRealm) { + this.defaultRealm = defaultRealm; + isDefault = true; + numOfComponents = 0; + format = null; + match = null; + fromPattern = null; + toPattern = null; + repeat = false; + toLowerCase = false; + toUpperCase = false; + } + + KerberosRule(String defaultRealm, int numOfComponents, String format, String match, String fromPattern, + String toPattern, boolean repeat, boolean toLowerCase, boolean toUpperCase) { + this.defaultRealm = defaultRealm; + isDefault = false; + this.numOfComponents = numOfComponents; + this.format = format; + this.match = match == null ? null : Pattern.compile(match); + this.fromPattern = + fromPattern == null ? null : Pattern.compile(fromPattern); + this.toPattern = toPattern; + this.repeat = repeat; + this.toLowerCase = toLowerCase; + this.toUpperCase = toUpperCase; + } + + @Override + public String toString() { + StringBuilder buf = new StringBuilder(); + if (isDefault) { + buf.append("DEFAULT"); + } else { + buf.append("RULE:["); + buf.append(numOfComponents); + buf.append(':'); + buf.append(format); + buf.append(']'); + if (match != null) { + buf.append('('); + buf.append(match); + buf.append(')'); + } + if (fromPattern != null) { + buf.append("s/"); + buf.append(fromPattern); + buf.append('/'); + buf.append(toPattern); + buf.append('/'); + if (repeat) { + buf.append('g'); + } + } + if (toLowerCase) { + buf.append("/L"); + } + if (toUpperCase) { + buf.append("/U"); + } + } + return buf.toString(); + } + + /** + * Replace the numbered parameters of the form $n where n is from 0 to + * the length of params - 1. Normal text is copied directly and $n is replaced + * by the corresponding parameter. + * @param format the string to replace parameters again + * @param params the list of parameters + * @return the generated string with the parameter references replaced. 
+ * @throws BadFormatString + */ + static String replaceParameters(String format, + String[] params) throws BadFormatString { + Matcher match = PARAMETER_PATTERN.matcher(format); + int start = 0; + StringBuilder result = new StringBuilder(); + while (start < format.length() && match.find(start)) { + result.append(match.group(1)); + String paramNum = match.group(3); + if (paramNum != null) { + try { + int num = Integer.parseInt(paramNum); + if (num < 0 || num >= params.length) { + throw new BadFormatString("index " + num + " from " + format + + " is outside of the valid range 0 to " + + (params.length - 1)); + } + result.append(params[num]); + } catch (NumberFormatException nfe) { + throw new BadFormatString("bad format in username mapping in " + + paramNum, nfe); + } + + } + start = match.end(); + } + return result.toString(); + } + + /** + * Replace the matches of the from pattern in the base string with the value + * of the to string. + * @param base the string to transform + * @param from the pattern to look for in the base string + * @param to the string to replace matches of the pattern with + * @param repeat whether the substitution should be repeated + * @return + */ + static String replaceSubstitution(String base, Pattern from, String to, + boolean repeat) { + Matcher match = from.matcher(base); + if (repeat) { + return match.replaceAll(to); + } else { + return match.replaceFirst(to); + } + } + + /** + * Try to apply this rule to the given name represented as a parameter + * array. + * @param params first element is the realm, second and later elements are + * are the components of the name "a/b@FOO" -> {"FOO", "a", "b"} + * @return the short name if this rule applies or null + * @throws IOException throws if something is wrong with the rules + */ + String apply(String[] params) throws IOException { + String result = null; + if (isDefault) { + if (defaultRealm.equals(params[0])) { + result = params[1]; + } + } else if (params.length - 1 == numOfComponents) { + String base = replaceParameters(format, params); + if (match == null || match.matcher(base).matches()) { + if (fromPattern == null) { + result = base; + } else { + result = replaceSubstitution(base, fromPattern, toPattern, repeat); + } + } + } + if (result != null && NON_SIMPLE_PATTERN.matcher(result).find()) { + throw new NoMatchingRule("Non-simple name " + result + " after auth_to_local rule " + this); + } + if (toLowerCase && result != null) { + result = result.toLowerCase(Locale.ENGLISH); + } else if (toUpperCase && result != null) { + result = result.toUpperCase(Locale.ENGLISH); + } + + return result; + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/security/kerberos/KerberosShortNamer.java b/clients/src/main/java/org/apache/kafka/common/security/kerberos/KerberosShortNamer.java new file mode 100644 index 0000000..96e01f1 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/security/kerberos/KerberosShortNamer.java @@ -0,0 +1,106 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.security.kerberos; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +/** + * This class implements parsing and handling of Kerberos principal names. In + * particular, it splits them apart and translates them down into local + * operating system names. + */ +public class KerberosShortNamer { + + /** + * A pattern for parsing a auth_to_local rule. + */ + private static final Pattern RULE_PARSER = Pattern.compile("((DEFAULT)|((RULE:\\[(\\d*):([^\\]]*)](\\(([^)]*)\\))?(s/([^/]*)/([^/]*)/(g)?)?/?(L|U)?)))"); + + /* Rules for the translation of the principal name into an operating system name */ + private final List principalToLocalRules; + + public KerberosShortNamer(List principalToLocalRules) { + this.principalToLocalRules = principalToLocalRules; + } + + public static KerberosShortNamer fromUnparsedRules(String defaultRealm, List principalToLocalRules) { + List rules = principalToLocalRules == null ? Collections.singletonList("DEFAULT") : principalToLocalRules; + return new KerberosShortNamer(parseRules(defaultRealm, rules)); + } + + private static List parseRules(String defaultRealm, List rules) { + List result = new ArrayList<>(); + for (String rule : rules) { + Matcher matcher = RULE_PARSER.matcher(rule); + if (!matcher.lookingAt()) { + throw new IllegalArgumentException("Invalid rule: " + rule); + } + if (rule.length() != matcher.end()) + throw new IllegalArgumentException("Invalid rule: `" + rule + "`, unmatched substring: `" + rule.substring(matcher.end()) + "`"); + if (matcher.group(2) != null) { + result.add(new KerberosRule(defaultRealm)); + } else { + result.add(new KerberosRule(defaultRealm, + Integer.parseInt(matcher.group(5)), + matcher.group(6), + matcher.group(8), + matcher.group(10), + matcher.group(11), + "g".equals(matcher.group(12)), + "L".equals(matcher.group(13)), + "U".equals(matcher.group(13)))); + + } + } + return result; + } + + /** + * Get the translation of the principal name into an operating system + * user name. 
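// Illustrative sketch (not part of the patch itself): mapping a Kerberos principal to a
// short name with an auth_to_local-style rule, as handled by KerberosShortNamer above.
// The realm, rule and principal are hypothetical.
import java.util.Arrays;
import org.apache.kafka.common.security.kerberos.KerberosName;
import org.apache.kafka.common.security.kerberos.KerberosShortNamer;

public class KerberosShortNamerSketch {
    public static void main(String[] args) throws Exception {
        KerberosShortNamer shortNamer = KerberosShortNamer.fromUnparsedRules("EXAMPLE.COM",
                Arrays.asList("RULE:[2:$1@$0](kafka@EXAMPLE.COM)s/.*/kafka/", "DEFAULT"));
        KerberosName name = KerberosName.parse("kafka/broker1.example.com@EXAMPLE.COM");
        // The rule applies to two-component principals: "$1@$0" builds "kafka@EXAMPLE.COM",
        // the match pattern accepts it, and the substitution rewrites it to "kafka".
        System.out.println(shortNamer.shortName(name)); // prints: kafka
    }
}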
+ * @return the short name + * @throws IOException + */ + public String shortName(KerberosName kerberosName) throws IOException { + String[] params; + if (kerberosName.hostName() == null) { + // if it is already simple, just return it + if (kerberosName.realm() == null) + return kerberosName.serviceName(); + params = new String[]{kerberosName.realm(), kerberosName.serviceName()}; + } else { + params = new String[]{kerberosName.realm(), kerberosName.serviceName(), kerberosName.hostName()}; + } + for (KerberosRule r : principalToLocalRules) { + String result = r.apply(params); + if (result != null) + return result; + } + throw new NoMatchingRule("No rules apply to " + kerberosName + ", rules " + principalToLocalRules); + } + + @Override + public String toString() { + return "KerberosShortNamer(principalToLocalRules = " + principalToLocalRules + ")"; + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/security/kerberos/NoMatchingRule.java b/clients/src/main/java/org/apache/kafka/common/security/kerberos/NoMatchingRule.java new file mode 100644 index 0000000..387c222 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/security/kerberos/NoMatchingRule.java @@ -0,0 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.security.kerberos; + +import java.io.IOException; + +public class NoMatchingRule extends IOException { + NoMatchingRule(String msg) { + super(msg); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/OAuthBearerExtensionsValidatorCallback.java b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/OAuthBearerExtensionsValidatorCallback.java new file mode 100644 index 0000000..eab208b --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/OAuthBearerExtensionsValidatorCallback.java @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.security.oauthbearer; + +import org.apache.kafka.common.security.auth.SaslExtensions; + +import javax.security.auth.callback.Callback; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; + +import static org.apache.kafka.common.utils.CollectionUtils.subtractMap; + +/** + * A {@code Callback} for use by the {@code SaslServer} implementation when it + * needs to validate the SASL extensions for the OAUTHBEARER mechanism + * Callback handlers should use the {@link #valid(String)} + * method to communicate valid extensions back to the SASL server. + * Callback handlers should use the + * {@link #error(String, String)} method to communicate validation errors back to + * the SASL Server. + * As per RFC-7628 (https://tools.ietf.org/html/rfc7628#section-3.1), unknown extensions must be ignored by the server. + * The callback handler implementation should simply ignore unknown extensions, + * not calling {@link #error(String, String)} nor {@link #valid(String)}. + * Callback handlers should communicate other problems by raising an {@code IOException}. + *
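// Illustrative sketch (not part of the patch itself): how a broker-side callback handler's
// handle() logic might use this callback to validate one custom extension. The extension
// name "traceId" is hypothetical.
import org.apache.kafka.common.security.oauthbearer.OAuthBearerExtensionsValidatorCallback;

public class ExtensionsValidationSketch {
    static void validateTraceId(OAuthBearerExtensionsValidatorCallback callback) {
        String traceId = callback.inputExtensions().map().get("traceId");
        if (traceId == null || traceId.isEmpty())
            callback.error("traceId", "traceId extension must be present and non-empty");
        else
            callback.valid("traceId");
        // Extensions that are neither validated nor rejected are reported via
        // callback.ignoredExtensions() and are ignored by the SASL server per RFC 7628.
    }
}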

        + * The OAuth bearer token is provided in the callback for better context in extension validation. + * It is very important that token validation is done in its own {@link OAuthBearerValidatorCallback} + * irregardless of provided extensions, as they are inherently insecure. + */ +public class OAuthBearerExtensionsValidatorCallback implements Callback { + private final OAuthBearerToken token; + private final SaslExtensions inputExtensions; + private final Map validatedExtensions = new HashMap<>(); + private final Map invalidExtensions = new HashMap<>(); + + public OAuthBearerExtensionsValidatorCallback(OAuthBearerToken token, SaslExtensions extensions) { + this.token = Objects.requireNonNull(token); + this.inputExtensions = Objects.requireNonNull(extensions); + } + + /** + * @return {@link OAuthBearerToken} the OAuth bearer token of the client + */ + public OAuthBearerToken token() { + return token; + } + + /** + * @return {@link SaslExtensions} consisting of the unvalidated extension names and values that were sent by the client + */ + public SaslExtensions inputExtensions() { + return inputExtensions; + } + + /** + * @return an unmodifiable {@link Map} consisting of the validated and recognized by the server extension names and values + */ + public Map validatedExtensions() { + return Collections.unmodifiableMap(validatedExtensions); + } + + /** + * @return An immutable {@link Map} consisting of the name->error messages of extensions which failed validation + */ + public Map invalidExtensions() { + return Collections.unmodifiableMap(invalidExtensions); + } + + /** + * @return An immutable {@link Map} consisting of the extensions that have neither been validated nor invalidated + */ + public Map ignoredExtensions() { + return Collections.unmodifiableMap(subtractMap(subtractMap(inputExtensions.map(), invalidExtensions), validatedExtensions)); + } + + /** + * Validates a specific extension in the original {@code inputExtensions} map + * @param extensionName - the name of the extension which was validated + */ + public void valid(String extensionName) { + if (!inputExtensions.map().containsKey(extensionName)) + throw new IllegalArgumentException(String.format("Extension %s was not found in the original extensions", extensionName)); + validatedExtensions.put(extensionName, inputExtensions.map().get(extensionName)); + } + /** + * Set the error value for a specific extension key-value pair if validation has failed + * + * @param invalidExtensionName + * the mandatory extension name which caused the validation failure + * @param errorMessage + * error message describing why the validation failed + */ + public void error(String invalidExtensionName, String errorMessage) { + if (Objects.requireNonNull(invalidExtensionName).isEmpty()) + throw new IllegalArgumentException("extension name must not be empty"); + this.invalidExtensions.put(invalidExtensionName, errorMessage); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/OAuthBearerLoginModule.java b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/OAuthBearerLoginModule.java new file mode 100644 index 0000000..e7976b5 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/OAuthBearerLoginModule.java @@ -0,0 +1,428 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.security.oauthbearer; + +import java.io.IOException; +import java.util.Collections; +import java.util.Iterator; +import java.util.Map; +import java.util.Objects; + +import javax.security.auth.Subject; +import javax.security.auth.callback.Callback; +import javax.security.auth.callback.CallbackHandler; +import javax.security.auth.callback.UnsupportedCallbackException; +import javax.security.auth.login.LoginException; +import javax.security.auth.spi.LoginModule; + +import org.apache.kafka.common.config.SaslConfigs; +import org.apache.kafka.common.security.auth.AuthenticateCallbackHandler; +import org.apache.kafka.common.security.auth.Login; +import org.apache.kafka.common.security.auth.SaslExtensionsCallback; +import org.apache.kafka.common.security.auth.SaslExtensions; +import org.apache.kafka.common.security.oauthbearer.internals.OAuthBearerSaslClientProvider; +import org.apache.kafka.common.security.oauthbearer.internals.OAuthBearerSaslServerProvider; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * The {@code LoginModule} for the SASL/OAUTHBEARER mechanism. When a client + * (whether a non-broker client or a broker when SASL/OAUTHBEARER is the + * inter-broker protocol) connects to Kafka the {@code OAuthBearerLoginModule} + * instance asks its configured {@link AuthenticateCallbackHandler} + * implementation to handle an instance of {@link OAuthBearerTokenCallback} and + * return an instance of {@link OAuthBearerToken}. A default, builtin + * {@link AuthenticateCallbackHandler} implementation creates an unsecured token + * as defined by these JAAS module options: + *

+ * <table>
+ * <tr><th>JAAS Module Option for Unsecured Token Retrieval</th><th>Documentation</th></tr>
+ * <tr><td>{@code unsecuredLoginStringClaim_<claimname>="value"}</td>
+ * <td>Creates a {@code String} claim with the given name and value. Any valid
+ * claim name can be specified except '{@code iat}' and '{@code exp}' (these are
+ * automatically generated).</td></tr>
+ * <tr><td>{@code unsecuredLoginNumberClaim_<claimname>="value"}</td>
+ * <td>Creates a {@code Number} claim with the given name and value. Any valid
+ * claim name can be specified except '{@code iat}' and '{@code exp}' (these are
+ * automatically generated).</td></tr>
+ * <tr><td>{@code unsecuredLoginListClaim_<claimname>="value"}</td>
+ * <td>Creates a {@code String List} claim with the given name and values parsed
+ * from the given value where the first character is taken as the delimiter. For
+ * example: {@code unsecuredLoginListClaim_fubar="|value1|value2"}. Any valid
+ * claim name can be specified except '{@code iat}' and '{@code exp}' (these are
+ * automatically generated).</td></tr>
+ * <tr><td>{@code unsecuredLoginPrincipalClaimName}</td>
+ * <td>Set to a custom claim name if you wish the name of the {@code String}
+ * claim holding the principal name to be something other than '{@code sub}'.</td></tr>
+ * <tr><td>{@code unsecuredLoginLifetimeSeconds}</td>
+ * <td>Set to an integer value if the token expiration is to be set to something
+ * other than the default value of 3600 seconds (which is 1 hour). The
+ * '{@code exp}' claim will be set to reflect the expiration time.</td></tr>
+ * <tr><td>{@code unsecuredLoginScopeClaimName}</td>
+ * <td>Set to a custom claim name if you wish the name of the {@code String} or
+ * {@code String List} claim holding any token scope to be something other than
+ * '{@code scope}'.</td></tr>
+ * </table>

        + *

+ * You can also add custom unsecured SASL extensions when using the default, builtin {@link AuthenticateCallbackHandler}
+ * implementation by using the configurable option {@code unsecuredLoginExtension_<extensionname>}. Note that there
+ * are validations for the key/values in order to conform to the SASL/OAUTHBEARER standard
+ * (https://tools.ietf.org/html/rfc7628#section-3.1), including the reserved key at
+ * {@link org.apache.kafka.common.security.oauthbearer.internals.OAuthBearerClientInitialResponse#AUTH_KEY}.
+ * The {@code OAuthBearerLoginModule} instance also asks its configured {@link AuthenticateCallbackHandler}
+ * implementation to handle an instance of {@link SaslExtensionsCallback} and return an instance of {@link SaslExtensions}.
+ * The configured callback handler does not need to handle this callback, though -- any {@code UnsupportedCallbackException}
+ * that is thrown is ignored, and no SASL extensions will be associated with the login.
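// Illustrative sketch (not part of the patch itself): passing a custom SASL extension
// through the unsecured login options described above. The extension name "traceId",
// the principal "alice" and the values are hypothetical.
public class UnsecuredExtensionJaasSketch {
    // A value suitable for a client's sasl.jaas.config property.
    static final String JAAS_CONFIG =
            "org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginModule required"
            + " unsecuredLoginStringClaim_sub=\"alice\""
            + " unsecuredLoginExtension_traceId=\"abc123\";";
}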

        + * Production use cases will require writing an implementation of + * {@link AuthenticateCallbackHandler} that can handle an instance of + * {@link OAuthBearerTokenCallback} and declaring it via either the + * {@code sasl.login.callback.handler.class} configuration option for a + * non-broker client or via the + * {@code listener.name.sasl_ssl.oauthbearer.sasl.login.callback.handler.class} + * configuration option for brokers (when SASL/OAUTHBEARER is the inter-broker + * protocol). + *

        + * This class stores the retrieved {@link OAuthBearerToken} in the + * {@code Subject}'s private credentials where the {@code SaslClient} can + * retrieve it. An appropriate, builtin {@code SaslClient} implementation is + * automatically used and configured such that it can perform that retrieval. + *

        + * Here is a typical, basic JAAS configuration for a client leveraging unsecured + * SASL/OAUTHBEARER authentication: + * + *

        + * KafkaClient {
        + *      org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginModule Required
        + *      unsecuredLoginStringClaim_sub="thePrincipalName";
        + * };
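// Illustrative sketch (not part of the patch itself): the same unsecured login expressed
// through client configuration properties instead of a JAAS file. The extra claims and
// the use of SASL_PLAINTEXT are hypothetical choices for the example.
import java.util.Properties;

public class UnsecuredClientConfigSketch {
    public static Properties clientProps() {
        Properties props = new Properties();
        props.put("security.protocol", "SASL_PLAINTEXT");
        props.put("sasl.mechanism", "OAUTHBEARER");
        props.put("sasl.jaas.config",
                "org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginModule required"
                + " unsecuredLoginStringClaim_sub=\"thePrincipalName\""
                + " unsecuredLoginListClaim_scope=\"|kafka-read|kafka-write\""
                + " unsecuredLoginLifetimeSeconds=\"7200\";");
        return props;
    }
}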
        + * 
        + * + * An implementation of the {@link Login} interface specific to the + * {@code OAUTHBEARER} mechanism is automatically applied; it periodically + * refreshes any token before it expires so that the client can continue to make + * connections to brokers. The parameters that impact how the refresh algorithm + * operates are specified as part of the producer/consumer/broker configuration + * and are as follows. See the documentation for these properties elsewhere for + * details. + *

+ * <table>
+ * <tr><th>Producer/Consumer/Broker Configuration Property</th></tr>
+ * <tr><td>{@code sasl.login.refresh.window.factor}</td></tr>
+ * <tr><td>{@code sasl.login.refresh.window.jitter}</td></tr>
+ * <tr><td>{@code sasl.login.refresh.min.period.seconds}</td></tr>
+ * <tr><td>{@code sasl.login.refresh.min.buffer.seconds}</td></tr>
+ * </table>
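// Illustrative sketch (not part of the patch itself): tuning some of the refresh
// properties listed above on a client. The numeric values are examples only.
import java.util.Properties;

public class RefreshTuningSketch {
    public static Properties refreshProps() {
        Properties props = new Properties();
        props.put("sasl.login.refresh.window.factor", "0.7");   // refresh at roughly 70% of the token lifetime
        props.put("sasl.login.refresh.window.jitter", "0.07");  // plus up to 7% random jitter
        props.put("sasl.login.refresh.min.period.seconds", "120");
        return props;
    }
}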

        + * When a broker accepts a SASL/OAUTHBEARER connection the instance of the + * builtin {@code SaslServer} implementation asks its configured + * {@link AuthenticateCallbackHandler} implementation to handle an instance of + * {@link OAuthBearerValidatorCallback} constructed with the OAuth 2 Bearer + * Token's compact serialization and return an instance of + * {@link OAuthBearerToken} if the value validates. A default, builtin + * {@link AuthenticateCallbackHandler} implementation validates an unsecured + * token as defined by these JAAS module options: + *

+ * <table>
+ * <tr><th>JAAS Module Option for Unsecured Token Validation</th><th>Documentation</th></tr>
+ * <tr><td>{@code unsecuredValidatorPrincipalClaimName="value"}</td>
+ * <td>Set to a non-empty value if you wish a particular {@code String} claim
+ * holding a principal name to be checked for existence; the default is to check
+ * for the existence of the '{@code sub}' claim.</td></tr>
+ * <tr><td>{@code unsecuredValidatorScopeClaimName="value"}</td>
+ * <td>Set to a custom claim name if you wish the name of the {@code String} or
+ * {@code String List} claim holding any token scope to be something other than
+ * '{@code scope}'.</td></tr>
+ * <tr><td>{@code unsecuredValidatorRequiredScope="value"}</td>
+ * <td>Set to a space-delimited list of scope values if you wish the
+ * {@code String/String List} claim holding the token scope to be checked to
+ * make sure it contains certain values.</td></tr>
+ * <tr><td>{@code unsecuredValidatorAllowableClockSkewMs="value"}</td>
+ * <td>Set to a positive integer value if you wish to allow up to some number of
+ * positive milliseconds of clock skew (the default is 0).</td></tr>
+ * </table>

        + * Here is a typical, basic JAAS configuration for a broker leveraging unsecured + * SASL/OAUTHBEARER validation: + * + *

        + * KafkaServer {
        + *      org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginModule Required
        + *      unsecuredLoginStringClaim_sub="thePrincipalName";
        + * };
        + * 
        + * + * Production use cases will require writing an implementation of + * {@link AuthenticateCallbackHandler} that can handle an instance of + * {@link OAuthBearerValidatorCallback} and declaring it via the + * {@code listener.name.sasl_ssl.oauthbearer.sasl.server.callback.handler.class} + * broker configuration option. + *
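// Illustrative sketch (not part of the patch itself): wiring a custom validator callback
// handler into a broker listener using the property named in the paragraph above. The
// class com.example.MyOAuthBearerValidatorCallbackHandler is hypothetical.
import java.util.Properties;

public class BrokerValidatorWiringSketch {
    public static Properties brokerOverrides() {
        Properties props = new Properties();
        props.put("listener.name.sasl_ssl.oauthbearer.sasl.server.callback.handler.class",
                "com.example.MyOAuthBearerValidatorCallbackHandler");
        return props;
    }
}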

        + * The builtin {@code SaslServer} implementation for SASL/OAUTHBEARER in Kafka + * makes the instance of {@link OAuthBearerToken} available upon successful + * authentication via the negotiated property "{@code OAUTHBEARER.token}"; the + * token could be used in a custom authorizer (to authorize based on JWT claims + * rather than ACLs, for example). + *
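// Illustrative sketch (not part of the patch itself): reading the negotiated
// "OAUTHBEARER.token" property mentioned above on the broker, for example inside a
// custom principal builder, so that authorization can use JWT claims. The class itself
// is hypothetical.
import org.apache.kafka.common.security.auth.AuthenticationContext;
import org.apache.kafka.common.security.auth.KafkaPrincipal;
import org.apache.kafka.common.security.auth.KafkaPrincipalBuilder;
import org.apache.kafka.common.security.auth.SaslAuthenticationContext;
import org.apache.kafka.common.security.oauthbearer.OAuthBearerToken;

public class TokenAwarePrincipalBuilderSketch implements KafkaPrincipalBuilder {
    @Override
    public KafkaPrincipal build(AuthenticationContext context) {
        if (context instanceof SaslAuthenticationContext) {
            SaslAuthenticationContext saslContext = (SaslAuthenticationContext) context;
            OAuthBearerToken token =
                    (OAuthBearerToken) saslContext.server().getNegotiatedProperty("OAUTHBEARER.token");
            if (token != null)
                return new KafkaPrincipal(KafkaPrincipal.USER_TYPE, token.principalName());
        }
        return KafkaPrincipal.ANONYMOUS;
    }
}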

        + * This implementation's {@code logout()} method will logout the specific token + * that this instance logged in if it's {@code Subject} instance is shared + * across multiple {@code LoginContext}s and there happen to be multiple tokens + * on the {@code Subject}. This functionality is useful because it means a new + * token with a longer lifetime can be created before a soon-to-expire token is + * actually logged out. Otherwise, if multiple simultaneous tokens were not + * supported like this, the soon-to-be expired token would have to be logged out + * first, and then if the new token could not be retrieved (maybe the + * authorization server is temporarily unavailable, for example) the client + * would be left without a token and would be unable to create new connections. + * Better to mitigate this possibility by leaving the existing token (which + * still has some lifetime left) in place until a new replacement token is + * actually retrieved. This implementation supports this. + * + * @see SaslConfigs#SASL_LOGIN_REFRESH_WINDOW_FACTOR_DOC + * @see SaslConfigs#SASL_LOGIN_REFRESH_WINDOW_JITTER_DOC + * @see SaslConfigs#SASL_LOGIN_REFRESH_MIN_PERIOD_SECONDS_DOC + * @see SaslConfigs#SASL_LOGIN_REFRESH_BUFFER_SECONDS_DOC + */ +public class OAuthBearerLoginModule implements LoginModule { + + /** + * Login state transitions: + * Initial state: NOT_LOGGED_IN + * login() : NOT_LOGGED_IN => LOGGED_IN_NOT_COMMITTED + * commit() : LOGGED_IN_NOT_COMMITTED => COMMITTED + * abort() : LOGGED_IN_NOT_COMMITTED => NOT_LOGGED_IN + * logout() : Any state => NOT_LOGGED_IN + */ + private enum LoginState { + NOT_LOGGED_IN, + LOGGED_IN_NOT_COMMITTED, + COMMITTED + } + + /** + * The SASL Mechanism name for OAuth 2: {@code OAUTHBEARER} + */ + public static final String OAUTHBEARER_MECHANISM = "OAUTHBEARER"; + private static final Logger log = LoggerFactory.getLogger(OAuthBearerLoginModule.class); + private static final SaslExtensions EMPTY_EXTENSIONS = new SaslExtensions(Collections.emptyMap()); + private Subject subject = null; + private AuthenticateCallbackHandler callbackHandler = null; + private OAuthBearerToken tokenRequiringCommit = null; + private OAuthBearerToken myCommittedToken = null; + private SaslExtensions extensionsRequiringCommit = null; + private SaslExtensions myCommittedExtensions = null; + private LoginState loginState; + + static { + OAuthBearerSaslClientProvider.initialize(); // not part of public API + OAuthBearerSaslServerProvider.initialize(); // not part of public API + } + + @Override + public void initialize(Subject subject, CallbackHandler callbackHandler, Map sharedState, + Map options) { + this.subject = Objects.requireNonNull(subject); + if (!(Objects.requireNonNull(callbackHandler) instanceof AuthenticateCallbackHandler)) + throw new IllegalArgumentException(String.format("Callback handler must be castable to %s: %s", + AuthenticateCallbackHandler.class.getName(), callbackHandler.getClass().getName())); + this.callbackHandler = (AuthenticateCallbackHandler) callbackHandler; + } + + @Override + public boolean login() throws LoginException { + if (loginState == LoginState.LOGGED_IN_NOT_COMMITTED) { + if (tokenRequiringCommit != null) + throw new IllegalStateException(String.format( + "Already have an uncommitted token with private credential token count=%d", committedTokenCount())); + else + throw new IllegalStateException("Already logged in without a token"); + } + if (loginState == LoginState.COMMITTED) { + if (myCommittedToken != null) + throw new 
IllegalStateException(String.format( + "Already have a committed token with private credential token count=%d; must login on another login context or logout here first before reusing the same login context", + committedTokenCount())); + else + throw new IllegalStateException("Login has already been committed without a token"); + } + + identifyToken(); + if (tokenRequiringCommit != null) + identifyExtensions(); + else + log.debug("Logged in without a token, this login cannot be used to establish client connections"); + + loginState = LoginState.LOGGED_IN_NOT_COMMITTED; + log.debug("Login succeeded; invoke commit() to commit it; current committed token count={}", + committedTokenCount()); + return true; + } + + private void identifyToken() throws LoginException { + OAuthBearerTokenCallback tokenCallback = new OAuthBearerTokenCallback(); + try { + callbackHandler.handle(new Callback[] {tokenCallback}); + } catch (IOException | UnsupportedCallbackException e) { + log.error(e.getMessage(), e); + throw new LoginException("An internal error occurred while retrieving token from callback handler"); + } + + tokenRequiringCommit = tokenCallback.token(); + if (tokenCallback.errorCode() != null) { + log.info("Login failed: {} : {} (URI={})", tokenCallback.errorCode(), tokenCallback.errorDescription(), + tokenCallback.errorUri()); + throw new LoginException(tokenCallback.errorDescription()); + } + } + + /** + * Attaches SASL extensions to the Subject + */ + private void identifyExtensions() throws LoginException { + SaslExtensionsCallback extensionsCallback = new SaslExtensionsCallback(); + try { + callbackHandler.handle(new Callback[] {extensionsCallback}); + extensionsRequiringCommit = extensionsCallback.extensions(); + } catch (IOException e) { + log.error(e.getMessage(), e); + throw new LoginException("An internal error occurred while retrieving SASL extensions from callback handler"); + } catch (UnsupportedCallbackException e) { + extensionsRequiringCommit = EMPTY_EXTENSIONS; + log.debug("CallbackHandler {} does not support SASL extensions. No extensions will be added", callbackHandler.getClass().getName()); + } + if (extensionsRequiringCommit == null) { + log.error("SASL Extensions cannot be null. 
Check whether your callback handler is explicitly setting them as null."); + throw new LoginException("Extensions cannot be null."); + } + } + + @Override + public boolean logout() { + if (loginState == LoginState.LOGGED_IN_NOT_COMMITTED) + throw new IllegalStateException( + "Cannot call logout() immediately after login(); need to first invoke commit() or abort()"); + if (loginState != LoginState.COMMITTED) { + log.debug("Nothing here to log out"); + return false; + } + if (myCommittedToken != null) { + log.trace("Logging out my token; current committed token count = {}", committedTokenCount()); + for (Iterator iterator = subject.getPrivateCredentials().iterator(); iterator.hasNext(); ) { + Object privateCredential = iterator.next(); + if (privateCredential == myCommittedToken) { + iterator.remove(); + myCommittedToken = null; + break; + } + } + log.debug("Done logging out my token; committed token count is now {}", committedTokenCount()); + } else + log.debug("No tokens to logout for this login"); + + if (myCommittedExtensions != null) { + log.trace("Logging out my extensions"); + if (subject.getPublicCredentials().removeIf(e -> myCommittedExtensions == e)) + myCommittedExtensions = null; + log.debug("Done logging out my extensions"); + } else + log.debug("No extensions to logout for this login"); + + loginState = LoginState.NOT_LOGGED_IN; + return true; + } + + @Override + public boolean commit() { + if (loginState != LoginState.LOGGED_IN_NOT_COMMITTED) { + log.debug("Nothing here to commit"); + return false; + } + + if (tokenRequiringCommit != null) { + log.trace("Committing my token; current committed token count = {}", committedTokenCount()); + subject.getPrivateCredentials().add(tokenRequiringCommit); + myCommittedToken = tokenRequiringCommit; + tokenRequiringCommit = null; + log.debug("Done committing my token; committed token count is now {}", committedTokenCount()); + } else + log.debug("No tokens to commit, this login cannot be used to establish client connections"); + + if (extensionsRequiringCommit != null) { + subject.getPublicCredentials().add(extensionsRequiringCommit); + myCommittedExtensions = extensionsRequiringCommit; + extensionsRequiringCommit = null; + } + + loginState = LoginState.COMMITTED; + return true; + } + + @Override + public boolean abort() { + if (loginState == LoginState.LOGGED_IN_NOT_COMMITTED) { + log.debug("Login aborted"); + tokenRequiringCommit = null; + extensionsRequiringCommit = null; + loginState = LoginState.NOT_LOGGED_IN; + return true; + } + log.debug("Nothing here to abort"); + return false; + } + + private int committedTokenCount() { + return subject.getPrivateCredentials(OAuthBearerToken.class).size(); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/OAuthBearerToken.java b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/OAuthBearerToken.java new file mode 100644 index 0000000..ee443ed --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/OAuthBearerToken.java @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.security.oauthbearer; + +import java.util.Set; + +import org.apache.kafka.common.annotation.InterfaceStability; + +/** + * The b64token value as defined in + * RFC 6750 Section + * 2.1 along with the token's specific scope and lifetime and principal + * name. + *

        + * A network request would be required to re-hydrate an opaque token, and that + * could result in (for example) an {@code IOException}, but retrievers for + * various attributes ({@link #scope()}, {@link #lifetimeMs()}, etc.) declare no + * exceptions. Therefore, if a network request is required for any of these + * retriever methods, that request could be performed at construction time so + * that the various attributes can be reliably provided thereafter. For example, + * a constructor might declare {@code throws IOException} in such a case. + * Alternatively, the retrievers could throw unchecked exceptions. + *
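As a rough sketch of the pattern described above (the class name and constructor are illustrative, not part of Kafka), a token whose attributes are all resolved before construction completes might look like this:

    import java.util.Collections;
    import java.util.Set;

    import org.apache.kafka.common.security.oauthbearer.OAuthBearerToken;

    // Illustrative token: every attribute is supplied up front, so the retrievers
    // below never need to perform I/O or declare checked exceptions.
    final class StaticOAuthBearerToken implements OAuthBearerToken {
        private final String value;
        private final String principalName;
        private final long lifetimeMs;
        private final Long startTimeMs;

        StaticOAuthBearerToken(String value, String principalName, long lifetimeMs, Long startTimeMs) {
            this.value = value;
            this.principalName = principalName;
            this.lifetimeMs = lifetimeMs;
            this.startTimeMs = startTimeMs;
        }

        @Override public String value() { return value; }
        @Override public Set<String> scope() { return Collections.emptySet(); } // no scope claims in this sketch
        @Override public long lifetimeMs() { return lifetimeMs; }
        @Override public String principalName() { return principalName; }
        @Override public Long startTimeMs() { return startTimeMs; }
    }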

        + * This interface was introduced in 2.0.0 and, while it feels stable, it could + * evolve. We will try to evolve the API in a compatible manner (easier now that + * Java 7 and its lack of default methods doesn't have to be supported), but we + * reserve the right to make breaking changes in minor releases, if necessary. + * We will update the {@code InterfaceStability} annotation and this notice once + * the API is considered stable. + * + * @see RFC 6749 + * Section 1.4 and + * RFC 6750 + * Section 2.1 + */ +@InterfaceStability.Evolving +public interface OAuthBearerToken { + /** + * The b64token value as defined in + * RFC 6750 Section + * 2.1 + * + * @return b64token value as defined in + * RFC 6750 + * Section 2.1 + */ + String value(); + + /** + * The token's scope of access, as per + * RFC 6749 Section + * 1.4 + * + * @return the token's (always non-null but potentially empty) scope of access, + * as per RFC + * 6749 Section 1.4. Note that all values in the returned set will + * be trimmed of preceding and trailing whitespace, and the result will + * never contain the empty string. + */ + Set scope(); + + /** + * The token's lifetime, expressed as the number of milliseconds since the + * epoch, as per RFC + * 6749 Section 1.4 + * + * @return the token'slifetime, expressed as the number of milliseconds since + * the epoch, as per + * RFC 6749 + * Section 1.4. + */ + long lifetimeMs(); + + /** + * The name of the principal to which this credential applies + * + * @return the always non-null/non-empty principal name + */ + String principalName(); + + /** + * When the credential became valid, in terms of the number of milliseconds + * since the epoch, if known, otherwise null. An expiring credential may not + * necessarily indicate when it was created -- just when it expires -- so we + * need to support a null return value here. + * + * @return the time when the credential became valid, in terms of the number of + * milliseconds since the epoch, if known, otherwise null + */ + Long startTimeMs(); +} diff --git a/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/OAuthBearerTokenCallback.java b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/OAuthBearerTokenCallback.java new file mode 100644 index 0000000..3f4f269 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/OAuthBearerTokenCallback.java @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.security.oauthbearer; + +import java.util.Objects; + +import javax.security.auth.callback.Callback; + +import org.apache.kafka.common.annotation.InterfaceStability; + +/** + * A {@code Callback} for use by the {@code SaslClient} and {@code Login} + * implementations when they require an OAuth 2 bearer token. Callback handlers + * should use the {@link #error(String, String, String)} method to communicate + * errors returned by the authorization server as per + * RFC 6749: The OAuth + * 2.0 Authorization Framework. Callback handlers should communicate other + * problems by raising an {@code IOException}. + *
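A hedged sketch of the two outcomes described above (the helper below is hypothetical, and "invalid_request" is just one example error code from RFC 6749): a handler might complete the callback like this.

    import org.apache.kafka.common.security.oauthbearer.OAuthBearerToken;
    import org.apache.kafka.common.security.oauthbearer.OAuthBearerTokenCallback;

    // Hypothetical helper: attach a token on success, or report an RFC 6749 error
    // triple (code, description, URI) when no token could be retrieved.
    final class TokenCallbackCompletion {
        static void complete(OAuthBearerTokenCallback callback, OAuthBearerToken tokenOrNull) {
            if (tokenOrNull != null)
                callback.token(tokenOrNull);
            else
                callback.error("invalid_request", "no token could be retrieved", null);
        }
    }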

        + * This class was introduced in 2.0.0 and, while it feels stable, it could + * evolve. We will try to evolve the API in a compatible manner, but we reserve + * the right to make breaking changes in minor releases, if necessary. We will + * update the {@code InterfaceStability} annotation and this notice once the API + * is considered stable. + */ +@InterfaceStability.Evolving +public class OAuthBearerTokenCallback implements Callback { + private OAuthBearerToken token = null; + private String errorCode = null; + private String errorDescription = null; + private String errorUri = null; + + /** + * Return the (potentially null) token + * + * @return the (potentially null) token + */ + public OAuthBearerToken token() { + return token; + } + + /** + * Return the optional (but always non-empty if not null) error code as per + * RFC 6749: The OAuth + * 2.0 Authorization Framework. + * + * @return the optional (but always non-empty if not null) error code + */ + public String errorCode() { + return errorCode; + } + + /** + * Return the (potentially null) error description as per + * RFC 6749: The OAuth + * 2.0 Authorization Framework. + * + * @return the (potentially null) error description + */ + public String errorDescription() { + return errorDescription; + } + + /** + * Return the (potentially null) error URI as per + * RFC 6749: The OAuth + * 2.0 Authorization Framework. + * + * @return the (potentially null) error URI + */ + public String errorUri() { + return errorUri; + } + + /** + * Set the token. All error-related values are cleared. + * + * @param token + * the optional token to set + */ + public void token(OAuthBearerToken token) { + this.token = token; + this.errorCode = null; + this.errorDescription = null; + this.errorUri = null; + } + + /** + * Set the error values as per + * RFC 6749: The OAuth + * 2.0 Authorization Framework. Any token is cleared. + * + * @param errorCode + * the mandatory error code to set + * @param errorDescription + * the optional error description to set + * @param errorUri + * the optional error URI to set + */ + public void error(String errorCode, String errorDescription, String errorUri) { + if (Objects.requireNonNull(errorCode).isEmpty()) + throw new IllegalArgumentException("error code must not be empty"); + this.errorCode = errorCode; + this.errorDescription = errorDescription; + this.errorUri = errorUri; + this.token = null; + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/OAuthBearerValidatorCallback.java b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/OAuthBearerValidatorCallback.java new file mode 100644 index 0000000..36bcf08 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/OAuthBearerValidatorCallback.java @@ -0,0 +1,154 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.security.oauthbearer; + +import java.util.Objects; + +import javax.security.auth.callback.Callback; + +import org.apache.kafka.common.annotation.InterfaceStability; + +/** + * A {@code Callback} for use by the {@code SaslServer} implementation when it + * needs to provide an OAuth 2 bearer token compact serialization for + * validation. Callback handlers should use the + * {@link #error(String, String, String)} method to communicate errors back to + * the SASL Client as per + * RFC 6749: The OAuth + * 2.0 Authorization Framework and the IANA + * OAuth Extensions Error Registry. Callback handlers should communicate + * other problems by raising an {@code IOException}. + *
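A minimal sketch of how a broker-side callback handler might complete this callback (the helper is hypothetical; "invalid_token" is one status value from the IANA OAuth Extensions Error Registry):

    import org.apache.kafka.common.security.oauthbearer.OAuthBearerToken;
    import org.apache.kafka.common.security.oauthbearer.OAuthBearerValidatorCallback;

    // Hypothetical helper: attach the validated token on success, or report an
    // error status that is sent back to the SASL client on failure.
    final class ValidatorCallbackCompletion {
        static void complete(OAuthBearerValidatorCallback callback, OAuthBearerToken validatedTokenOrNull) {
            if (validatedTokenOrNull != null)
                callback.token(validatedTokenOrNull);
            else
                callback.error("invalid_token", null, null);
        }
    }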

        + * This class was introduced in 2.0.0 and, while it feels stable, it could + * evolve. We will try to evolve the API in a compatible manner, but we reserve + * the right to make breaking changes in minor releases, if necessary. We will + * update the {@code InterfaceStability} annotation and this notice once the API + * is considered stable. + */ +@InterfaceStability.Evolving +public class OAuthBearerValidatorCallback implements Callback { + private final String tokenValue; + private OAuthBearerToken token = null; + private String errorStatus = null; + private String errorScope = null; + private String errorOpenIDConfiguration = null; + + /** + * Constructor + * + * @param tokenValue + * the mandatory/non-blank token value + */ + public OAuthBearerValidatorCallback(String tokenValue) { + if (Objects.requireNonNull(tokenValue).isEmpty()) + throw new IllegalArgumentException("token value must not be empty"); + this.tokenValue = tokenValue; + } + + /** + * Return the (always non-null) token value + * + * @return the (always non-null) token value + */ + public String tokenValue() { + return tokenValue; + } + + /** + * Return the (potentially null) token + * + * @return the (potentially null) token + */ + public OAuthBearerToken token() { + return token; + } + + /** + * Return the (potentially null) error status value as per + * RFC 7628: A Set + * of Simple Authentication and Security Layer (SASL) Mechanisms for OAuth + * and the IANA + * OAuth Extensions Error Registry. + * + * @return the (potentially null) error status value + */ + public String errorStatus() { + return errorStatus; + } + + /** + * Return the (potentially null) error scope value as per + * RFC 7628: A Set + * of Simple Authentication and Security Layer (SASL) Mechanisms for OAuth. + * + * @return the (potentially null) error scope value + */ + public String errorScope() { + return errorScope; + } + + /** + * Return the (potentially null) error openid-configuration value as per + * RFC 7628: A Set + * of Simple Authentication and Security Layer (SASL) Mechanisms for OAuth. + * + * @return the (potentially null) error openid-configuration value + */ + public String errorOpenIDConfiguration() { + return errorOpenIDConfiguration; + } + + /** + * Set the token. The token value is unchanged and is expected to match the + * provided token's value. All error values are cleared. + * + * @param token + * the mandatory token to set + */ + public void token(OAuthBearerToken token) { + this.token = Objects.requireNonNull(token); + this.errorStatus = null; + this.errorScope = null; + this.errorOpenIDConfiguration = null; + } + + /** + * Set the error values as per + * RFC 7628: A Set + * of Simple Authentication and Security Layer (SASL) Mechanisms for OAuth. + * Any token is cleared. 
+ * + * @param errorStatus + * the mandatory error status value from the IANA + * OAuth Extensions Error Registry to set + * @param errorScope + * the optional error scope value to set + * @param errorOpenIDConfiguration + * the optional error openid-configuration value to set + */ + public void error(String errorStatus, String errorScope, String errorOpenIDConfiguration) { + if (Objects.requireNonNull(errorStatus).isEmpty()) + throw new IllegalArgumentException("error status must not be empty"); + this.errorStatus = errorStatus; + this.errorScope = errorScope; + this.errorOpenIDConfiguration = errorOpenIDConfiguration; + this.token = null; + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/OAuthBearerClientInitialResponse.java b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/OAuthBearerClientInitialResponse.java new file mode 100644 index 0000000..a356f0d --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/OAuthBearerClientInitialResponse.java @@ -0,0 +1,189 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.security.oauthbearer.internals; + +import org.apache.kafka.common.security.auth.SaslExtensions; +import org.apache.kafka.common.utils.Utils; + +import javax.security.sasl.SaslException; +import java.nio.charset.StandardCharsets; +import java.util.Map; +import java.util.Objects; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +public class OAuthBearerClientInitialResponse { + static final String SEPARATOR = "\u0001"; + + private static final String SASLNAME = "(?:[\\x01-\\x7F&&[^=,]]|=2C|=3D)+"; + private static final String KEY = "[A-Za-z]+"; + private static final String VALUE = "[\\x21-\\x7E \t\r\n]+"; + + private static final String KVPAIRS = String.format("(%s=%s%s)*", KEY, VALUE, SEPARATOR); + private static final Pattern AUTH_PATTERN = Pattern.compile("(?[\\w]+)[ ]+(?[-_\\.a-zA-Z0-9]+)"); + private static final Pattern CLIENT_INITIAL_RESPONSE_PATTERN = Pattern.compile( + String.format("n,(a=(?%s))?,%s(?%s)%s", SASLNAME, SEPARATOR, KVPAIRS, SEPARATOR)); + public static final String AUTH_KEY = "auth"; + + private final String tokenValue; + private final String authorizationId; + private SaslExtensions saslExtensions; + + public static final Pattern EXTENSION_KEY_PATTERN = Pattern.compile(KEY); + public static final Pattern EXTENSION_VALUE_PATTERN = Pattern.compile(VALUE); + + public OAuthBearerClientInitialResponse(byte[] response) throws SaslException { + String responseMsg = new String(response, StandardCharsets.UTF_8); + Matcher matcher = CLIENT_INITIAL_RESPONSE_PATTERN.matcher(responseMsg); + if (!matcher.matches()) + throw new SaslException("Invalid OAUTHBEARER client first message"); + String authzid = matcher.group("authzid"); + this.authorizationId = authzid == null ? "" : authzid; + String kvPairs = matcher.group("kvpairs"); + Map properties = Utils.parseMap(kvPairs, "=", SEPARATOR); + String auth = properties.get(AUTH_KEY); + if (auth == null) + throw new SaslException("Invalid OAUTHBEARER client first message: 'auth' not specified"); + properties.remove(AUTH_KEY); + SaslExtensions extensions = new SaslExtensions(properties); + validateExtensions(extensions); + this.saslExtensions = extensions; + + Matcher authMatcher = AUTH_PATTERN.matcher(auth); + if (!authMatcher.matches()) + throw new SaslException("Invalid OAUTHBEARER client first message: invalid 'auth' format"); + if (!"bearer".equalsIgnoreCase(authMatcher.group("scheme"))) { + String msg = String.format("Invalid scheme in OAUTHBEARER client first message: %s", + matcher.group("scheme")); + throw new SaslException(msg); + } + this.tokenValue = authMatcher.group("token"); + } + + /** + * Constructor + * + * @param tokenValue + * the mandatory token value + * @param extensions + * the optional extensions + * @throws SaslException + * if any extension name or value fails to conform to the required + * regular expression as defined by the specification, or if the + * reserved {@code auth} appears as a key + */ + public OAuthBearerClientInitialResponse(String tokenValue, SaslExtensions extensions) throws SaslException { + this(tokenValue, "", extensions); + } + + /** + * Constructor + * + * @param tokenValue + * the mandatory token value + * @param authorizationId + * the optional authorization ID + * @param extensions + * the optional extensions + * @throws SaslException + * if any extension name or value fails to conform to the required + * regular expression as defined by the specification, or if the + * reserved {@code auth} appears as a key + */ + public 
OAuthBearerClientInitialResponse(String tokenValue, String authorizationId, SaslExtensions extensions) throws SaslException { + this.tokenValue = Objects.requireNonNull(tokenValue, "token value must not be null"); + this.authorizationId = authorizationId == null ? "" : authorizationId; + validateExtensions(extensions); + this.saslExtensions = extensions != null ? extensions : SaslExtensions.NO_SASL_EXTENSIONS; + } + + /** + * Return the always non-null extensions + * + * @return the always non-null extensions + */ + public SaslExtensions extensions() { + return saslExtensions; + } + + public byte[] toBytes() { + String authzid = authorizationId.isEmpty() ? "" : "a=" + authorizationId; + String extensions = extensionsMessage(); + if (extensions.length() > 0) + extensions = SEPARATOR + extensions; + + String message = String.format("n,%s,%sauth=Bearer %s%s%s%s", authzid, + SEPARATOR, tokenValue, extensions, SEPARATOR, SEPARATOR); + + return message.getBytes(StandardCharsets.UTF_8); + } + + /** + * Return the always non-null token value + * + * @return the always non-null toklen value + */ + public String tokenValue() { + return tokenValue; + } + + /** + * Return the always non-null authorization ID + * + * @return the always non-null authorization ID + */ + public String authorizationId() { + return authorizationId; + } + + /** + * Validates that the given extensions conform to the standard. They should also not contain the reserve key name {@link OAuthBearerClientInitialResponse#AUTH_KEY} + * + * @param extensions + * optional extensions to validate + * @throws SaslException + * if any extension name or value fails to conform to the required + * regular expression as defined by the specification, or if the + * reserved {@code auth} appears as a key + * + * @see RFC 7628, + * Section 3.1 + */ + public static void validateExtensions(SaslExtensions extensions) throws SaslException { + if (extensions == null) + return; + if (extensions.map().containsKey(OAuthBearerClientInitialResponse.AUTH_KEY)) + throw new SaslException("Extension name " + OAuthBearerClientInitialResponse.AUTH_KEY + " is invalid"); + + for (Map.Entry entry : extensions.map().entrySet()) { + String extensionName = entry.getKey(); + String extensionValue = entry.getValue(); + + if (!EXTENSION_KEY_PATTERN.matcher(extensionName).matches()) + throw new SaslException("Extension name " + extensionName + " is invalid"); + if (!EXTENSION_VALUE_PATTERN.matcher(extensionValue).matches()) + throw new SaslException("Extension value (" + extensionValue + ") for extension " + extensionName + " is invalid"); + } + } + + /** + * Converts the SASLExtensions to an OAuth protocol-friendly string + */ + private String extensionsMessage() { + return Utils.mkString(saslExtensions.map(), "", "", "=", SEPARATOR); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/OAuthBearerRefreshingLogin.java b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/OAuthBearerRefreshingLogin.java new file mode 100644 index 0000000..4adbe39 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/OAuthBearerRefreshingLogin.java @@ -0,0 +1,153 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.security.oauthbearer.internals; + +import java.util.Map; +import java.util.Set; + +import javax.security.auth.Subject; +import javax.security.auth.login.Configuration; +import javax.security.auth.login.LoginContext; +import javax.security.auth.login.LoginException; + +import org.apache.kafka.common.config.SaslConfigs; +import org.apache.kafka.common.security.auth.AuthenticateCallbackHandler; +import org.apache.kafka.common.security.auth.Login; +import org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginModule; +import org.apache.kafka.common.security.oauthbearer.OAuthBearerToken; +import org.apache.kafka.common.security.oauthbearer.internals.expiring.ExpiringCredential; +import org.apache.kafka.common.security.oauthbearer.internals.expiring.ExpiringCredentialRefreshConfig; +import org.apache.kafka.common.security.oauthbearer.internals.expiring.ExpiringCredentialRefreshingLogin; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * This class is responsible for refreshing logins for both Kafka client and + * server when the credential is an OAuth 2 bearer token communicated over + * SASL/OAUTHBEARER. An OAuth 2 bearer token has a limited lifetime, and an + * instance of this class periodically refreshes it so that the client can + * create new connections to brokers on an ongoing basis. + *

        + * This class does not need to be explicitly set via the + * {@code sasl.login.class} client configuration property or the + * {@code listener.name.sasl_[plaintext|ssl].oauthbearer.sasl.login.class} + * broker configuration property when the SASL mechanism is OAUTHBEARER; it is + * automatically set by default in that case. + *

        + * The parameters that impact how the refresh algorithm operates are specified + * as part of the producer/consumer/broker configuration and are as follows. See + * the documentation for these properties elsewhere for details. + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + *
+ * <table>
+ * <tr><th>Producer/Consumer/Broker Configuration Property</th></tr>
+ * <tr><td>{@code sasl.login.refresh.window.factor}</td></tr>
+ * <tr><td>{@code sasl.login.refresh.window.jitter}</td></tr>
+ * <tr><td>{@code sasl.login.refresh.min.period.seconds}</td></tr>
+ * <tr><td>{@code sasl.login.refresh.buffer.seconds}</td></tr>
+ * </table>
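For orientation only, these settings might be supplied in client configuration as shown below; the values are illustrative examples, and the documented defaults apply when the properties are omitted:

    import java.util.Properties;

    // Illustrative refresh-related login settings (values are examples, not recommendations).
    final class RefreshConfigExample {
        static Properties refreshProperties() {
            Properties props = new Properties();
            props.put("sasl.login.refresh.window.factor", "0.8");
            props.put("sasl.login.refresh.window.jitter", "0.05");
            props.put("sasl.login.refresh.min.period.seconds", "60");
            props.put("sasl.login.refresh.buffer.seconds", "300");
            return props;
        }
    }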
        + * + * @see OAuthBearerLoginModule + * @see SaslConfigs#SASL_LOGIN_REFRESH_WINDOW_FACTOR_DOC + * @see SaslConfigs#SASL_LOGIN_REFRESH_WINDOW_JITTER_DOC + * @see SaslConfigs#SASL_LOGIN_REFRESH_MIN_PERIOD_SECONDS_DOC + * @see SaslConfigs#SASL_LOGIN_REFRESH_BUFFER_SECONDS_DOC + */ +public class OAuthBearerRefreshingLogin implements Login { + private static final Logger log = LoggerFactory.getLogger(OAuthBearerRefreshingLogin.class); + private ExpiringCredentialRefreshingLogin expiringCredentialRefreshingLogin = null; + + @Override + public void configure(Map configs, String contextName, Configuration configuration, + AuthenticateCallbackHandler loginCallbackHandler) { + /* + * Specify this class as the one to synchronize on so that only one OAuth 2 + * Bearer Token is refreshed at a given time. Specify null if we don't mind + * multiple simultaneously refreshes. Refreshes happen on the order of minutes + * rather than seconds or milliseconds, and there are typically minutes of + * lifetime remaining when the refresh occurs, so serializing them seems + * reasonable. + */ + Class classToSynchronizeOnPriorToRefresh = OAuthBearerRefreshingLogin.class; + expiringCredentialRefreshingLogin = new ExpiringCredentialRefreshingLogin(contextName, configuration, + new ExpiringCredentialRefreshConfig(configs, true), loginCallbackHandler, + classToSynchronizeOnPriorToRefresh) { + @Override + public ExpiringCredential expiringCredential() { + Set privateCredentialTokens = expiringCredentialRefreshingLogin.subject() + .getPrivateCredentials(OAuthBearerToken.class); + if (privateCredentialTokens.isEmpty()) + return null; + final OAuthBearerToken token = privateCredentialTokens.iterator().next(); + if (log.isDebugEnabled()) + log.debug("Found expiring credential with principal '{}'.", token.principalName()); + return new ExpiringCredential() { + @Override + public String principalName() { + return token.principalName(); + } + + @Override + public Long startTimeMs() { + return token.startTimeMs(); + } + + @Override + public long expireTimeMs() { + return token.lifetimeMs(); + } + + @Override + public Long absoluteLastRefreshTimeMs() { + return null; + } + }; + } + }; + } + + @Override + public void close() { + if (expiringCredentialRefreshingLogin != null) + expiringCredentialRefreshingLogin.close(); + } + + @Override + public Subject subject() { + return expiringCredentialRefreshingLogin != null ? expiringCredentialRefreshingLogin.subject() : null; + } + + @Override + public String serviceName() { + return expiringCredentialRefreshingLogin != null ? expiringCredentialRefreshingLogin.serviceName() : null; + } + + @Override + public synchronized LoginContext login() throws LoginException { + if (expiringCredentialRefreshingLogin != null) + return expiringCredentialRefreshingLogin.login(); + throw new LoginException("Login was not configured properly"); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/OAuthBearerSaslClient.java b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/OAuthBearerSaslClient.java new file mode 100644 index 0000000..e32e16d --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/OAuthBearerSaslClient.java @@ -0,0 +1,197 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.security.oauthbearer.internals; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.Arrays; +import java.util.Map; +import java.util.Objects; + +import javax.security.auth.callback.Callback; +import javax.security.auth.callback.CallbackHandler; +import javax.security.auth.callback.UnsupportedCallbackException; +import javax.security.sasl.SaslClient; +import javax.security.sasl.SaslClientFactory; +import javax.security.sasl.SaslException; + +import org.apache.kafka.common.errors.IllegalSaslStateException; +import org.apache.kafka.common.security.auth.SaslExtensions; +import org.apache.kafka.common.security.auth.SaslExtensionsCallback; +import org.apache.kafka.common.security.auth.AuthenticateCallbackHandler; +import org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginModule; +import org.apache.kafka.common.security.oauthbearer.OAuthBearerToken; +import org.apache.kafka.common.security.oauthbearer.OAuthBearerTokenCallback; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * {@code SaslClient} implementation for SASL/OAUTHBEARER in Kafka. This + * implementation requires an instance of {@code AuthenticateCallbackHandler} + * that can handle an instance of {@link OAuthBearerTokenCallback} and return + * the {@link OAuthBearerToken} generated by the {@code login()} event on the + * {@code LoginContext}. Said handler can also optionally handle an instance of {@link SaslExtensionsCallback} + * to return any extensions generated by the {@code login()} event on the {@code LoginContext}. 
+ * + * @see RFC 6750, + * Section 2.1 + * + */ +public class OAuthBearerSaslClient implements SaslClient { + static final byte BYTE_CONTROL_A = (byte) 0x01; + private static final Logger log = LoggerFactory.getLogger(OAuthBearerSaslClient.class); + private final CallbackHandler callbackHandler; + + enum State { + SEND_CLIENT_FIRST_MESSAGE, RECEIVE_SERVER_FIRST_MESSAGE, RECEIVE_SERVER_MESSAGE_AFTER_FAILURE, COMPLETE, FAILED + } + + private State state; + + public OAuthBearerSaslClient(AuthenticateCallbackHandler callbackHandler) { + this.callbackHandler = Objects.requireNonNull(callbackHandler); + setState(State.SEND_CLIENT_FIRST_MESSAGE); + } + + public CallbackHandler callbackHandler() { + return callbackHandler; + } + + @Override + public String getMechanismName() { + return OAuthBearerLoginModule.OAUTHBEARER_MECHANISM; + } + + @Override + public boolean hasInitialResponse() { + return true; + } + + @Override + public byte[] evaluateChallenge(byte[] challenge) throws SaslException { + try { + OAuthBearerTokenCallback callback = new OAuthBearerTokenCallback(); + switch (state) { + case SEND_CLIENT_FIRST_MESSAGE: + if (challenge != null && challenge.length != 0) + throw new SaslException("Expected empty challenge"); + callbackHandler().handle(new Callback[] {callback}); + SaslExtensions extensions = retrieveCustomExtensions(); + + setState(State.RECEIVE_SERVER_FIRST_MESSAGE); + + return new OAuthBearerClientInitialResponse(callback.token().value(), extensions).toBytes(); + case RECEIVE_SERVER_FIRST_MESSAGE: + if (challenge != null && challenge.length != 0) { + String jsonErrorResponse = new String(challenge, StandardCharsets.UTF_8); + if (log.isDebugEnabled()) + log.debug("Sending %%x01 response to server after receiving an error: {}", + jsonErrorResponse); + setState(State.RECEIVE_SERVER_MESSAGE_AFTER_FAILURE); + return new byte[] {BYTE_CONTROL_A}; + } + callbackHandler().handle(new Callback[] {callback}); + if (log.isDebugEnabled()) + log.debug("Successfully authenticated as {}", callback.token().principalName()); + setState(State.COMPLETE); + return null; + default: + throw new IllegalSaslStateException("Unexpected challenge in Sasl client state " + state); + } + } catch (SaslException e) { + setState(State.FAILED); + throw e; + } catch (IOException | UnsupportedCallbackException e) { + setState(State.FAILED); + throw new SaslException(e.getMessage(), e); + } + } + + @Override + public boolean isComplete() { + return state == State.COMPLETE; + } + + @Override + public byte[] unwrap(byte[] incoming, int offset, int len) { + if (!isComplete()) + throw new IllegalStateException("Authentication exchange has not completed"); + return Arrays.copyOfRange(incoming, offset, offset + len); + } + + @Override + public byte[] wrap(byte[] outgoing, int offset, int len) { + if (!isComplete()) + throw new IllegalStateException("Authentication exchange has not completed"); + return Arrays.copyOfRange(outgoing, offset, offset + len); + } + + @Override + public Object getNegotiatedProperty(String propName) { + if (!isComplete()) + throw new IllegalStateException("Authentication exchange has not completed"); + return null; + } + + @Override + public void dispose() { + } + + private void setState(State state) { + log.debug("Setting SASL/{} client state to {}", OAuthBearerLoginModule.OAUTHBEARER_MECHANISM, state); + this.state = state; + } + + private SaslExtensions retrieveCustomExtensions() throws SaslException { + SaslExtensionsCallback extensionsCallback = new SaslExtensionsCallback(); + try { + 
callbackHandler().handle(new Callback[] {extensionsCallback}); + } catch (UnsupportedCallbackException e) { + log.debug("Extensions callback is not supported by client callback handler {}, no extensions will be added", + callbackHandler()); + } catch (Exception e) { + throw new SaslException("SASL extensions could not be obtained", e); + } + + return extensionsCallback.extensions(); + } + + public static class OAuthBearerSaslClientFactory implements SaslClientFactory { + @Override + public SaslClient createSaslClient(String[] mechanisms, String authorizationId, String protocol, + String serverName, Map props, CallbackHandler callbackHandler) { + String[] mechanismNamesCompatibleWithPolicy = getMechanismNames(props); + for (String mechanism : mechanisms) { + for (int i = 0; i < mechanismNamesCompatibleWithPolicy.length; i++) { + if (mechanismNamesCompatibleWithPolicy[i].equals(mechanism)) { + if (!(Objects.requireNonNull(callbackHandler) instanceof AuthenticateCallbackHandler)) + throw new IllegalArgumentException(String.format( + "Callback handler must be castable to %s: %s", + AuthenticateCallbackHandler.class.getName(), callbackHandler.getClass().getName())); + return new OAuthBearerSaslClient((AuthenticateCallbackHandler) callbackHandler); + } + } + } + return null; + } + + @Override + public String[] getMechanismNames(Map props) { + return OAuthBearerSaslServer.mechanismNamesCompatibleWithPolicy(props); + } + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/OAuthBearerSaslClientCallbackHandler.java b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/OAuthBearerSaslClientCallbackHandler.java new file mode 100644 index 0000000..bca55be --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/OAuthBearerSaslClientCallbackHandler.java @@ -0,0 +1,143 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.security.oauthbearer.internals; + +import java.io.IOException; +import java.security.AccessController; +import java.util.Collections; +import java.util.Comparator; +import java.util.Date; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.SortedSet; +import java.util.TreeSet; + +import javax.security.auth.Subject; +import javax.security.auth.callback.Callback; +import javax.security.auth.callback.UnsupportedCallbackException; +import javax.security.auth.login.AppConfigurationEntry; + +import org.apache.kafka.common.security.auth.SaslExtensionsCallback; +import org.apache.kafka.common.security.auth.AuthenticateCallbackHandler; +import org.apache.kafka.common.security.auth.SaslExtensions; +import org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginModule; +import org.apache.kafka.common.security.oauthbearer.OAuthBearerToken; +import org.apache.kafka.common.security.oauthbearer.OAuthBearerTokenCallback; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * An implementation of {@code AuthenticateCallbackHandler} that recognizes + * {@link OAuthBearerTokenCallback} and retrieves OAuth 2 Bearer Token that was + * created when the {@code OAuthBearerLoginModule} logged in by looking for an + * instance of {@link OAuthBearerToken} in the {@code Subject}'s private + * credentials. This class also recognizes {@link SaslExtensionsCallback} and retrieves any SASL extensions that were + * created when the {@code OAuthBearerLoginModule} logged in by looking for an instance of {@link SaslExtensions} + * in the {@code Subject}'s public credentials + *
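As a sketch of the lookup described above (the helper class is hypothetical), the token committed by the login module can be located on the Subject of the current access control context:

    import java.security.AccessController;

    import javax.security.auth.Subject;

    import org.apache.kafka.common.security.oauthbearer.OAuthBearerToken;

    // Hypothetical helper: mirrors the handler's lookup of the committed token in
    // the Subject's private credentials; returns null when no token is present.
    final class SubjectTokenLookup {
        static OAuthBearerToken currentToken() {
            Subject subject = Subject.getSubject(AccessController.getContext());
            if (subject == null)
                return null;
            return subject.getPrivateCredentials(OAuthBearerToken.class).stream()
                    .findFirst()
                    .orElse(null);
        }
    }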

        + * Use of this class is configured automatically and does not need to be + * explicitly set via the {@code sasl.client.callback.handler.class} + * configuration property. + */ +public class OAuthBearerSaslClientCallbackHandler implements AuthenticateCallbackHandler { + private static final Logger log = LoggerFactory.getLogger(OAuthBearerSaslClientCallbackHandler.class); + private boolean configured = false; + + /** + * Return true if this instance has been configured, otherwise false + * + * @return true if this instance has been configured, otherwise false + */ + public boolean configured() { + return configured; + } + + @Override + public void configure(Map configs, String saslMechanism, List jaasConfigEntries) { + if (!OAuthBearerLoginModule.OAUTHBEARER_MECHANISM.equals(saslMechanism)) + throw new IllegalArgumentException(String.format("Unexpected SASL mechanism: %s", saslMechanism)); + configured = true; + } + + @Override + public void handle(Callback[] callbacks) throws IOException, UnsupportedCallbackException { + if (!configured()) + throw new IllegalStateException("Callback handler not configured"); + for (Callback callback : callbacks) { + if (callback instanceof OAuthBearerTokenCallback) + handleCallback((OAuthBearerTokenCallback) callback); + else if (callback instanceof SaslExtensionsCallback) + handleCallback((SaslExtensionsCallback) callback, Subject.getSubject(AccessController.getContext())); + else + throw new UnsupportedCallbackException(callback); + } + } + + @Override + public void close() { + // empty + } + + private void handleCallback(OAuthBearerTokenCallback callback) throws IOException { + if (callback.token() != null) + throw new IllegalArgumentException("Callback had a token already"); + Subject subject = Subject.getSubject(AccessController.getContext()); + Set privateCredentials = subject != null + ? subject.getPrivateCredentials(OAuthBearerToken.class) + : Collections.emptySet(); + if (privateCredentials.size() == 0) + throw new IOException("No OAuth Bearer tokens in Subject's private credentials"); + if (privateCredentials.size() == 1) + callback.token(privateCredentials.iterator().next()); + else { + /* + * There a very small window of time upon token refresh (on the order of milliseconds) + * where both an old and a new token appear on the Subject's private credentials. + * Rather than implement a lock to eliminate this window, we will deal with it by + * checking for the existence of multiple tokens and choosing the one that has the + * longest lifetime. It is also possible that a bug could cause multiple tokens to + * exist (e.g. KAFKA-7902), so dealing with the unlikely possibility that occurs + * during normal operation also allows us to deal more robustly with potential bugs. 
+ */ + SortedSet sortedByLifetime = + new TreeSet<>( + new Comparator() { + @Override + public int compare(OAuthBearerToken o1, OAuthBearerToken o2) { + return Long.compare(o1.lifetimeMs(), o2.lifetimeMs()); + } + }); + sortedByLifetime.addAll(privateCredentials); + log.warn("Found {} OAuth Bearer tokens in Subject's private credentials; the oldest expires at {}, will use the newest, which expires at {}", + sortedByLifetime.size(), + new Date(sortedByLifetime.first().lifetimeMs()), + new Date(sortedByLifetime.last().lifetimeMs())); + callback.token(sortedByLifetime.last()); + } + } + + /** + * Attaches the first {@link SaslExtensions} found in the public credentials of the Subject + */ + private static void handleCallback(SaslExtensionsCallback extensionsCallback, Subject subject) { + if (subject != null && !subject.getPublicCredentials(SaslExtensions.class).isEmpty()) { + SaslExtensions extensions = subject.getPublicCredentials(SaslExtensions.class).iterator().next(); + extensionsCallback.extensions(extensions); + } + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/OAuthBearerSaslClientProvider.java b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/OAuthBearerSaslClientProvider.java new file mode 100644 index 0000000..08777ef --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/OAuthBearerSaslClientProvider.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.security.oauthbearer.internals; + +import java.security.Provider; +import java.security.Security; + +import org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginModule; +import org.apache.kafka.common.security.oauthbearer.internals.OAuthBearerSaslClient.OAuthBearerSaslClientFactory; + +public class OAuthBearerSaslClientProvider extends Provider { + private static final long serialVersionUID = 1L; + + protected OAuthBearerSaslClientProvider() { + super("SASL/OAUTHBEARER Client Provider", 1.0, "SASL/OAUTHBEARER Client Provider for Kafka"); + put("SaslClientFactory." 
+ OAuthBearerLoginModule.OAUTHBEARER_MECHANISM, + OAuthBearerSaslClientFactory.class.getName()); + } + + public static void initialize() { + Security.addProvider(new OAuthBearerSaslClientProvider()); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/OAuthBearerSaslServer.java b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/OAuthBearerSaslServer.java new file mode 100644 index 0000000..8735f49 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/OAuthBearerSaslServer.java @@ -0,0 +1,245 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.security.oauthbearer.internals; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.Arrays; +import java.util.Map; +import java.util.Objects; + +import javax.security.auth.callback.Callback; +import javax.security.auth.callback.CallbackHandler; +import javax.security.auth.callback.UnsupportedCallbackException; +import javax.security.sasl.Sasl; +import javax.security.sasl.SaslException; +import javax.security.sasl.SaslServer; +import javax.security.sasl.SaslServerFactory; + +import org.apache.kafka.common.errors.SaslAuthenticationException; +import org.apache.kafka.common.security.auth.SaslExtensions; +import org.apache.kafka.common.security.authenticator.SaslInternalConfigs; +import org.apache.kafka.common.security.auth.AuthenticateCallbackHandler; +import org.apache.kafka.common.security.oauthbearer.OAuthBearerExtensionsValidatorCallback; +import org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginModule; +import org.apache.kafka.common.security.oauthbearer.OAuthBearerToken; +import org.apache.kafka.common.security.oauthbearer.OAuthBearerValidatorCallback; +import org.apache.kafka.common.utils.Utils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * {@code SaslServer} implementation for SASL/OAUTHBEARER in Kafka. An instance + * of {@link OAuthBearerToken} is available upon successful authentication via + * the negotiated property "{@code OAUTHBEARER.token}"; the token could be used + * in a custom authorizer (to authorize based on JWT claims rather than ACLs, + * for example). 
+ */ +public class OAuthBearerSaslServer implements SaslServer { + + private static final Logger log = LoggerFactory.getLogger(OAuthBearerSaslServer.class); + private static final String NEGOTIATED_PROPERTY_KEY_TOKEN = OAuthBearerLoginModule.OAUTHBEARER_MECHANISM + ".token"; + private static final String INTERNAL_ERROR_ON_SERVER = "Authentication could not be performed due to an internal error on the server"; + + private final AuthenticateCallbackHandler callbackHandler; + + private boolean complete; + private OAuthBearerToken tokenForNegotiatedProperty = null; + private String errorMessage = null; + private SaslExtensions extensions; + + public OAuthBearerSaslServer(CallbackHandler callbackHandler) { + if (!(Objects.requireNonNull(callbackHandler) instanceof AuthenticateCallbackHandler)) + throw new IllegalArgumentException(String.format("Callback handler must be castable to %s: %s", + AuthenticateCallbackHandler.class.getName(), callbackHandler.getClass().getName())); + this.callbackHandler = (AuthenticateCallbackHandler) callbackHandler; + } + + /** + * @throws SaslAuthenticationException + * if access token cannot be validated + *

        + * Note: This method may throw + * {@link SaslAuthenticationException} to provide custom error + * messages to clients. But care should be taken to avoid including + * any information in the exception message that should not be + * leaked to unauthenticated clients. It may be safer to throw + * {@link SaslException} in some cases so that a standard error + * message is returned to clients. + *

        + */ + @Override + public byte[] evaluateResponse(byte[] response) throws SaslException, SaslAuthenticationException { + if (response.length == 1 && response[0] == OAuthBearerSaslClient.BYTE_CONTROL_A && errorMessage != null) { + log.debug("Received %x01 response from client after it received our error"); + throw new SaslAuthenticationException(errorMessage); + } + errorMessage = null; + + OAuthBearerClientInitialResponse clientResponse; + try { + clientResponse = new OAuthBearerClientInitialResponse(response); + } catch (SaslException e) { + log.debug(e.getMessage()); + throw e; + } + + return process(clientResponse.tokenValue(), clientResponse.authorizationId(), clientResponse.extensions()); + } + + @Override + public String getAuthorizationID() { + if (!complete) + throw new IllegalStateException("Authentication exchange has not completed"); + return tokenForNegotiatedProperty.principalName(); + } + + @Override + public String getMechanismName() { + return OAuthBearerLoginModule.OAUTHBEARER_MECHANISM; + } + + @Override + public Object getNegotiatedProperty(String propName) { + if (!complete) + throw new IllegalStateException("Authentication exchange has not completed"); + if (NEGOTIATED_PROPERTY_KEY_TOKEN.equals(propName)) + return tokenForNegotiatedProperty; + if (SaslInternalConfigs.CREDENTIAL_LIFETIME_MS_SASL_NEGOTIATED_PROPERTY_KEY.equals(propName)) + return tokenForNegotiatedProperty.lifetimeMs(); + return extensions.map().get(propName); + } + + @Override + public boolean isComplete() { + return complete; + } + + @Override + public byte[] unwrap(byte[] incoming, int offset, int len) { + if (!complete) + throw new IllegalStateException("Authentication exchange has not completed"); + return Arrays.copyOfRange(incoming, offset, offset + len); + } + + @Override + public byte[] wrap(byte[] outgoing, int offset, int len) { + if (!complete) + throw new IllegalStateException("Authentication exchange has not completed"); + return Arrays.copyOfRange(outgoing, offset, offset + len); + } + + @Override + public void dispose() { + complete = false; + tokenForNegotiatedProperty = null; + extensions = null; + } + + private byte[] process(String tokenValue, String authorizationId, SaslExtensions extensions) throws SaslException { + OAuthBearerValidatorCallback callback = new OAuthBearerValidatorCallback(tokenValue); + try { + callbackHandler.handle(new Callback[] {callback}); + } catch (IOException | UnsupportedCallbackException e) { + handleCallbackError(e); + } + OAuthBearerToken token = callback.token(); + if (token == null) { + errorMessage = jsonErrorResponse(callback.errorStatus(), callback.errorScope(), + callback.errorOpenIDConfiguration()); + log.debug(errorMessage); + return errorMessage.getBytes(StandardCharsets.UTF_8); + } + /* + * We support the client specifying an authorization ID as per the SASL + * specification, but it must match the principal name if it is specified. 
+ */ + if (!authorizationId.isEmpty() && !authorizationId.equals(token.principalName())) + throw new SaslAuthenticationException(String.format( + "Authentication failed: Client requested an authorization id (%s) that is different from the token's principal name (%s)", + authorizationId, token.principalName())); + + Map validExtensions = processExtensions(token, extensions); + + tokenForNegotiatedProperty = token; + this.extensions = new SaslExtensions(validExtensions); + complete = true; + log.debug("Successfully authenticate User={}", token.principalName()); + return new byte[0]; + } + + private Map processExtensions(OAuthBearerToken token, SaslExtensions extensions) throws SaslException { + OAuthBearerExtensionsValidatorCallback extensionsCallback = new OAuthBearerExtensionsValidatorCallback(token, extensions); + try { + callbackHandler.handle(new Callback[] {extensionsCallback}); + } catch (UnsupportedCallbackException e) { + // backwards compatibility - no extensions will be added + } catch (IOException e) { + handleCallbackError(e); + } + if (!extensionsCallback.invalidExtensions().isEmpty()) { + String errorMessage = String.format("Authentication failed: %d extensions are invalid! They are: %s", + extensionsCallback.invalidExtensions().size(), + Utils.mkString(extensionsCallback.invalidExtensions(), "", "", ": ", "; ")); + log.debug(errorMessage); + throw new SaslAuthenticationException(errorMessage); + } + + return extensionsCallback.validatedExtensions(); + } + + private static String jsonErrorResponse(String errorStatus, String errorScope, String errorOpenIDConfiguration) { + String jsonErrorResponse = String.format("{\"status\":\"%s\"", errorStatus); + if (errorScope != null) + jsonErrorResponse = String.format("%s, \"scope\":\"%s\"", jsonErrorResponse, errorScope); + if (errorOpenIDConfiguration != null) + jsonErrorResponse = String.format("%s, \"openid-configuration\":\"%s\"", jsonErrorResponse, + errorOpenIDConfiguration); + jsonErrorResponse = String.format("%s}", jsonErrorResponse); + return jsonErrorResponse; + } + + private void handleCallbackError(Exception e) throws SaslException { + String msg = String.format("%s: %s", INTERNAL_ERROR_ON_SERVER, e.getMessage()); + log.debug(msg, e); + throw new SaslException(msg); + } + + public static String[] mechanismNamesCompatibleWithPolicy(Map props) { + return props != null && "true".equals(String.valueOf(props.get(Sasl.POLICY_NOPLAINTEXT))) ? 
new String[] {} + : new String[] {OAuthBearerLoginModule.OAUTHBEARER_MECHANISM}; + } + + public static class OAuthBearerSaslServerFactory implements SaslServerFactory { + @Override + public SaslServer createSaslServer(String mechanism, String protocol, String serverName, Map props, + CallbackHandler callbackHandler) { + String[] mechanismNamesCompatibleWithPolicy = getMechanismNames(props); + for (int i = 0; i < mechanismNamesCompatibleWithPolicy.length; i++) { + if (mechanismNamesCompatibleWithPolicy[i].equals(mechanism)) { + return new OAuthBearerSaslServer(callbackHandler); + } + } + return null; + } + + @Override + public String[] getMechanismNames(Map props) { + return OAuthBearerSaslServer.mechanismNamesCompatibleWithPolicy(props); + } + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/OAuthBearerSaslServerProvider.java b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/OAuthBearerSaslServerProvider.java new file mode 100644 index 0000000..2e179ce --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/OAuthBearerSaslServerProvider.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.security.oauthbearer.internals; + +import java.security.Provider; +import java.security.Security; + +import org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginModule; +import org.apache.kafka.common.security.oauthbearer.internals.OAuthBearerSaslServer.OAuthBearerSaslServerFactory; + +public class OAuthBearerSaslServerProvider extends Provider { + private static final long serialVersionUID = 1L; + + protected OAuthBearerSaslServerProvider() { + super("SASL/OAUTHBEARER Server Provider", 1.0, "SASL/OAUTHBEARER Server Provider for Kafka"); + put("SaslServerFactory." + OAuthBearerLoginModule.OAUTHBEARER_MECHANISM, + OAuthBearerSaslServerFactory.class.getName()); + } + + public static void initialize() { + Security.addProvider(new OAuthBearerSaslServerProvider()); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/expiring/ExpiringCredential.java b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/expiring/ExpiringCredential.java new file mode 100644 index 0000000..1bfa4b2 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/expiring/ExpiringCredential.java @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.security.oauthbearer.internals.expiring; + +/** + * A credential that expires and that can potentially be refreshed + * + * @see ExpiringCredentialRefreshingLogin + */ +public interface ExpiringCredential { + /** + * The name of the principal to which this credential applies (used only for + * logging) + * + * @return the always non-null/non-empty principal name + */ + String principalName(); + + /** + * When the credential became valid, in terms of the number of milliseconds + * since the epoch, if known, otherwise null. An expiring credential may not + * necessarily indicate when it was created -- just when it expires -- so we + * need to support a null return value here. + * + * @return the time when the credential became valid, in terms of the number of + * milliseconds since the epoch, if known, otherwise null + */ + Long startTimeMs(); + + /** + * When the credential expires, in terms of the number of milliseconds since the + * epoch. All expiring credentials by definition must indicate their expiration + * time -- thus, unlike other methods, we do not support a null return value + * here. + * + * @return the time when the credential expires, in terms of the number of + * milliseconds since the epoch + */ + long expireTimeMs(); + + /** + * The point after which the credential can no longer be refreshed, in terms of + * the number of milliseconds since the epoch, if any, otherwise null. Some + * expiring credentials can be refreshed over and over again without limit, so + * we support a null return value here. + * + * @return the point after which the credential can no longer be refreshed, in + * terms of the number of milliseconds since the epoch, if any, + * otherwise null + */ + Long absoluteLastRefreshTimeMs(); +} diff --git a/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/expiring/ExpiringCredentialRefreshConfig.java b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/expiring/ExpiringCredentialRefreshConfig.java new file mode 100644 index 0000000..1df69f7 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/expiring/ExpiringCredentialRefreshConfig.java @@ -0,0 +1,124 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.security.oauthbearer.internals.expiring; + +import java.util.Map; +import java.util.Objects; + +import org.apache.kafka.common.config.SaslConfigs; + +/** + * Immutable refresh-related configuration for expiring credentials that can be + * parsed from a producer/consumer/broker config. + */ +public class ExpiringCredentialRefreshConfig { + private final double loginRefreshWindowFactor; + private final double loginRefreshWindowJitter; + private final short loginRefreshMinPeriodSeconds; + private final short loginRefreshBufferSeconds; + private final boolean loginRefreshReloginAllowedBeforeLogout; + + /** + * Constructor based on producer/consumer/broker configs and the indicated value + * for whether or not client relogin is allowed before logout + * + * @param configs + * the mandatory (but possibly empty) producer/consumer/broker + * configs upon which to build this instance + * @param clientReloginAllowedBeforeLogout + * if the {@code LoginModule} and {@code SaslClient} implementations + * support multiple simultaneous login contexts on a single + * {@code Subject} at the same time. If true, then upon refresh, + * logout will only be invoked on the original {@code LoginContext} + * after a new one successfully logs in. This can be helpful if the + * original credential still has some lifetime left when an attempt + * to refresh the credential fails; the client will still be able to + * create new connections as long as the original credential remains + * valid. Otherwise, if logout is immediately invoked prior to + * relogin, a relogin failure leaves the client without the ability + * to connect until relogin does in fact succeed. + */ + public ExpiringCredentialRefreshConfig(Map configs, boolean clientReloginAllowedBeforeLogout) { + Objects.requireNonNull(configs); + this.loginRefreshWindowFactor = (Double) configs.get(SaslConfigs.SASL_LOGIN_REFRESH_WINDOW_FACTOR); + this.loginRefreshWindowJitter = (Double) configs.get(SaslConfigs.SASL_LOGIN_REFRESH_WINDOW_JITTER); + this.loginRefreshMinPeriodSeconds = (Short) configs.get(SaslConfigs.SASL_LOGIN_REFRESH_MIN_PERIOD_SECONDS); + this.loginRefreshBufferSeconds = (Short) configs.get(SaslConfigs.SASL_LOGIN_REFRESH_BUFFER_SECONDS); + this.loginRefreshReloginAllowedBeforeLogout = clientReloginAllowedBeforeLogout; + } + + /** + * Background login refresh thread will sleep until the specified window factor + * relative to the credential's total lifetime has been reached, at which time + * it will try to refresh the credential. + * + * @return the login refresh window factor + */ + public double loginRefreshWindowFactor() { + return loginRefreshWindowFactor; + } + + /** + * Amount of random jitter added to the background login refresh thread's sleep + * time. + * + * @return the login refresh window jitter + */ + public double loginRefreshWindowJitter() { + return loginRefreshWindowJitter; + } + + /** + * The desired minimum time between checks by the background login refresh + * thread, in seconds + * + * @return the desired minimum refresh period, in seconds + */ + public short loginRefreshMinPeriodSeconds() { + return loginRefreshMinPeriodSeconds; + } + + /** + * The amount of buffer time before expiration to maintain when refreshing. 
If a + * refresh is scheduled to occur closer to expiration than the number of seconds + * defined here then the refresh will be moved up to maintain as much of the + * desired buffer as possible. + * + * @return the refresh buffer, in seconds + */ + public short loginRefreshBufferSeconds() { + return loginRefreshBufferSeconds; + } + + /** + * If the LoginModule and SaslClient implementations support multiple + * simultaneous login contexts on a single Subject at the same time. If true, + * then upon refresh, logout will only be invoked on the original LoginContext + * after a new one successfully logs in. This can be helpful if the original + * credential still has some lifetime left when an attempt to refresh the + * credential fails; the client will still be able to create new connections as + * long as the original credential remains valid. Otherwise, if logout is + * immediately invoked prior to relogin, a relogin failure leaves the client + * without the ability to connect until relogin does in fact succeed. + * + * @return true if relogin is allowed prior to discarding an existing + * (presumably unexpired) credential, otherwise false + */ + public boolean loginRefreshReloginAllowedBeforeLogout() { + return loginRefreshReloginAllowedBeforeLogout; + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/expiring/ExpiringCredentialRefreshingLogin.java b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/expiring/ExpiringCredentialRefreshingLogin.java new file mode 100644 index 0000000..ab2f303 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/expiring/ExpiringCredentialRefreshingLogin.java @@ -0,0 +1,444 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.security.oauthbearer.internals.expiring; + +import java.util.Date; +import java.util.Objects; +import java.util.Random; + +import javax.security.auth.Subject; +import javax.security.auth.login.Configuration; +import javax.security.auth.login.LoginContext; +import javax.security.auth.login.LoginException; + +import org.apache.kafka.common.security.auth.AuthenticateCallbackHandler; +import org.apache.kafka.common.security.auth.Login; +import org.apache.kafka.common.utils.KafkaThread; +import org.apache.kafka.common.utils.Time; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * This class is responsible for refreshing logins for both Kafka client and + * server when the login is a type that has a limited lifetime/will expire. The + * credentials for the login must implement {@link ExpiringCredential}. 
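To make the contract concrete, here is a hedged sketch (not part of the patch) of a minimal ExpiringCredential plus a concrete subclass of the refreshing login described below. The class names and the hard-coded one-hour lifetime are assumptions for illustration only.

```java
// Hedged sketch, not part of the patch: a minimal ExpiringCredential plus a
// concrete refreshing login. Class names and the one-hour lifetime are
// illustrative assumptions.
import javax.security.auth.login.Configuration;

import org.apache.kafka.common.security.auth.AuthenticateCallbackHandler;
import org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginModule;
import org.apache.kafka.common.security.oauthbearer.internals.expiring.ExpiringCredential;
import org.apache.kafka.common.security.oauthbearer.internals.expiring.ExpiringCredentialRefreshConfig;
import org.apache.kafka.common.security.oauthbearer.internals.expiring.ExpiringCredentialRefreshingLogin;

class SketchExpiringCredential implements ExpiringCredential {
    private final long issuedAtMs = System.currentTimeMillis();

    @Override public String principalName() { return "thePrincipalName"; }
    @Override public Long startTimeMs() { return issuedAtMs; }
    @Override public long expireTimeMs() { return issuedAtMs + 3_600_000L; } // one hour
    @Override public Long absoluteLastRefreshTimeMs() { return null; }       // no refresh deadline
}

class SketchRefreshingLogin extends ExpiringCredentialRefreshingLogin {
    SketchRefreshingLogin(String contextName, Configuration configuration,
                          ExpiringCredentialRefreshConfig refreshConfig,
                          AuthenticateCallbackHandler callbackHandler) {
        // The final argument is the class whose monitor guards refreshes,
        // for example the login module class.
        super(contextName, configuration, refreshConfig, callbackHandler, OAuthBearerLoginModule.class);
    }

    @Override
    public ExpiringCredential expiringCredential() {
        // A real implementation would locate the credential that the login module
        // deposited in subject().getPrivateCredentials(); hard-coded for brevity.
        return new SketchExpiringCredential();
    }
}
```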
+ */ +public abstract class ExpiringCredentialRefreshingLogin implements AutoCloseable { + /** + * Class that can be overridden for testing + */ + static class LoginContextFactory { + public LoginContext createLoginContext(ExpiringCredentialRefreshingLogin expiringCredentialRefreshingLogin) + throws LoginException { + return new LoginContext(expiringCredentialRefreshingLogin.contextName(), + expiringCredentialRefreshingLogin.subject(), expiringCredentialRefreshingLogin.callbackHandler(), + expiringCredentialRefreshingLogin.configuration()); + } + + public void refresherThreadStarted() { + // empty + } + + public void refresherThreadDone() { + // empty + } + } + + private static class ExitRefresherThreadDueToIllegalStateException extends Exception { + private static final long serialVersionUID = -6108495378411920380L; + + public ExitRefresherThreadDueToIllegalStateException(String message) { + super(message); + } + } + + private class Refresher implements Runnable { + @Override + public void run() { + log.info("[Principal={}]: Expiring credential re-login thread started.", principalLogText()); + while (true) { + /* + * Refresh thread's main loop. Each expiring credential lives for one iteration + * of the loop. Thread will exit if the loop exits from here. + */ + long nowMs = currentMs(); + Long nextRefreshMs = refreshMs(nowMs); + if (nextRefreshMs == null) { + loginContextFactory.refresherThreadDone(); + return; + } + // safety check motivated by KAFKA-7945, + // should generally never happen except due to a bug + if (nextRefreshMs.longValue() < nowMs) { + log.warn("[Principal={}]: Expiring credential re-login sleep time was calculated to be in the past! Will explicitly adjust. ({})", principalLogText(), + new Date(nextRefreshMs)); + nextRefreshMs = Long.valueOf(nowMs + 10 * 1000); // refresh in 10 seconds + } + log.info("[Principal={}]: Expiring credential re-login sleeping until: {}", principalLogText(), + new Date(nextRefreshMs)); + time.sleep(nextRefreshMs - nowMs); + if (Thread.currentThread().isInterrupted()) { + log.info("[Principal={}]: Expiring credential re-login thread has been interrupted and will exit.", + principalLogText()); + loginContextFactory.refresherThreadDone(); + return; + } + while (true) { + /* + * Perform a re-login over and over again with some intervening delay + * unless/until either the refresh succeeds or we are interrupted. 
+ */ + try { + reLogin(); + break; // success + } catch (ExitRefresherThreadDueToIllegalStateException e) { + log.error(e.getMessage(), e); + loginContextFactory.refresherThreadDone(); + return; + } catch (LoginException loginException) { + log.warn(String.format( + "[Principal=%s]: LoginException during login retry; will sleep %d seconds before trying again.", + principalLogText(), DELAY_SECONDS_BEFORE_NEXT_RETRY_WHEN_RELOGIN_FAILS), + loginException); + // Sleep and allow loop to run/try again unless interrupted + time.sleep(DELAY_SECONDS_BEFORE_NEXT_RETRY_WHEN_RELOGIN_FAILS * 1000); + if (Thread.currentThread().isInterrupted()) { + log.error( + "[Principal={}]: Interrupted while trying to perform a subsequent expiring credential re-login after one or more initial re-login failures: re-login thread exiting now: {}", + principalLogText(), String.valueOf(loginException.getMessage())); + loginContextFactory.refresherThreadDone(); + return; + } + } + } + } + } + } + + private static final Logger log = LoggerFactory.getLogger(ExpiringCredentialRefreshingLogin.class); + private static final long DELAY_SECONDS_BEFORE_NEXT_RETRY_WHEN_RELOGIN_FAILS = 10L; + private static final Random RNG = new Random(); + private final Time time; + private Thread refresherThread; + + private final LoginContextFactory loginContextFactory; + private final String contextName; + private final Configuration configuration; + private final ExpiringCredentialRefreshConfig expiringCredentialRefreshConfig; + private final AuthenticateCallbackHandler callbackHandler; + + // mark volatile due to existence of public subject() method + private volatile Subject subject = null; + private boolean hasExpiringCredential = false; + private String principalName = null; + private LoginContext loginContext = null; + private ExpiringCredential expiringCredential = null; + private final Class mandatoryClassToSynchronizeOnPriorToRefresh; + + public ExpiringCredentialRefreshingLogin(String contextName, Configuration configuration, + ExpiringCredentialRefreshConfig expiringCredentialRefreshConfig, + AuthenticateCallbackHandler callbackHandler, Class mandatoryClassToSynchronizeOnPriorToRefresh) { + this(contextName, configuration, expiringCredentialRefreshConfig, callbackHandler, + mandatoryClassToSynchronizeOnPriorToRefresh, new LoginContextFactory(), Time.SYSTEM); + } + + public ExpiringCredentialRefreshingLogin(String contextName, Configuration configuration, + ExpiringCredentialRefreshConfig expiringCredentialRefreshConfig, + AuthenticateCallbackHandler callbackHandler, Class mandatoryClassToSynchronizeOnPriorToRefresh, + LoginContextFactory loginContextFactory, Time time) { + this.contextName = Objects.requireNonNull(contextName); + this.configuration = Objects.requireNonNull(configuration); + this.expiringCredentialRefreshConfig = Objects.requireNonNull(expiringCredentialRefreshConfig); + this.callbackHandler = callbackHandler; + this.mandatoryClassToSynchronizeOnPriorToRefresh = Objects + .requireNonNull(mandatoryClassToSynchronizeOnPriorToRefresh); + this.loginContextFactory = loginContextFactory; + this.time = Objects.requireNonNull(time); + } + + public Subject subject() { + return subject; // field requires volatile keyword + } + + public String contextName() { + return contextName; + } + + public Configuration configuration() { + return configuration; + } + + public AuthenticateCallbackHandler callbackHandler() { + return callbackHandler; + } + + public String serviceName() { + return "kafka"; + } + + /** + * Performs login for 
each login module specified for the login context of this + * instance and starts the thread used to periodically re-login. + *
        + * The synchronized keyword is not necessary because an implementation of + * {@link Login} will delegate to this code (e.g. OAuthBearerRefreshingLogin}, + * and the {@code login()} method on the delegating class will itself be + * synchronized if necessary. + */ + public LoginContext login() throws LoginException { + LoginContext tmpLoginContext = loginContextFactory.createLoginContext(this); + tmpLoginContext.login(); + log.info("Successfully logged in."); + loginContext = tmpLoginContext; + subject = loginContext.getSubject(); + expiringCredential = expiringCredential(); + hasExpiringCredential = expiringCredential != null; + if (!hasExpiringCredential) { + // do not bother with re-logins. + log.debug("No Expiring Credential"); + principalName = null; + refresherThread = null; + return loginContext; + } + + principalName = expiringCredential.principalName(); + + // Check for a clock skew problem + long expireTimeMs = expiringCredential.expireTimeMs(); + long nowMs = currentMs(); + if (nowMs > expireTimeMs) { + log.error( + "[Principal={}]: Current clock: {} is later than expiry {}. This may indicate a clock skew problem." + + " Check that this host's and remote host's clocks are in sync. Not starting refresh thread." + + " This process is likely unable to authenticate SASL connections (for example, it is unlikely" + + " to be able to authenticate a connection with a Kafka Broker).", + principalLogText(), new Date(nowMs), new Date(expireTimeMs)); + return loginContext; + } + + if (log.isDebugEnabled()) + log.debug("[Principal={}]: It is an expiring credential", principalLogText()); + + /* + * Re-login periodically. How often is determined by the expiration date of the + * credential and refresh-related configuration values. + */ + refresherThread = KafkaThread.daemon(String.format("kafka-expiring-relogin-thread-%s", principalName), + new Refresher()); + refresherThread.start(); + loginContextFactory.refresherThreadStarted(); + return loginContext; + } + + public void close() { + if (refresherThread != null && refresherThread.isAlive()) { + refresherThread.interrupt(); + try { + refresherThread.join(); + } catch (InterruptedException e) { + log.warn("[Principal={}]: Interrupted while waiting for re-login thread to shutdown.", + principalLogText(), e); + Thread.currentThread().interrupt(); + } + } + } + + public abstract ExpiringCredential expiringCredential(); + + /** + * Determine when to sleep until before performing a refresh + * + * @param relativeToMs + * the point (in terms of number of milliseconds since the epoch) at + * which to perform the calculation + * @return null if no refresh should occur, otherwise the time to sleep until + * (in terms of the number of milliseconds since the epoch) before + * performing a refresh + */ + private Long refreshMs(long relativeToMs) { + if (expiringCredential == null) { + /* + * Re-login failed because our login() invocation did not generate a credential + * but also did not generate an exception. Try logging in again after some delay + * (it seems likely to be a bug, but it doesn't hurt to keep trying to refresh). 
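Condensed for reference, the scheduling rule implemented in the remainder of this method picks a point a configured fraction (plus jitter) of the way through the credential's lifetime and then clamps it against the end buffer and the minimum refresh period. The following is a hedged sketch that omits the error, already-expired, and short-lifetime special cases handled below.

```java
// Hedged sketch, not part of the patch: the core refresh-time calculation,
// with the error, already-expired, and short-lifetime branches omitted.
import java.util.Random;

final class RefreshTimeSketch {
    static long proposedRefreshMs(long startMs, long expireMs, long nowMs,
                                  double windowFactor, double windowJitter,
                                  long minPeriodSeconds, long bufferSeconds, Random rng) {
        double pct = windowFactor + windowJitter * rng.nextDouble();
        long proposed = startMs + (long) ((expireMs - startMs) * pct);
        long beginningOfEndBuffer = expireMs - bufferSeconds * 1000L;
        if (proposed > beginningOfEndBuffer)   // keep the requested buffer before expiry
            proposed = beginningOfEndBuffer;
        long endOfMinRefreshBuffer = nowMs + minPeriodSeconds * 1000L;
        if (proposed < endOfMinRefreshBuffer)  // respect the minimum refresh period
            proposed = endOfMinRefreshBuffer;
        return proposed;
    }
}
```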
+ */ + long retvalNextRefreshMs = relativeToMs + DELAY_SECONDS_BEFORE_NEXT_RETRY_WHEN_RELOGIN_FAILS * 1000L; + log.warn("[Principal={}]: No Expiring credential found: will try again at {}", principalLogText(), + new Date(retvalNextRefreshMs)); + return retvalNextRefreshMs; + } + long expireTimeMs = expiringCredential.expireTimeMs(); + if (relativeToMs > expireTimeMs) { + boolean logoutRequiredBeforeLoggingBackIn = isLogoutRequiredBeforeLoggingBackIn(); + if (logoutRequiredBeforeLoggingBackIn) { + log.error( + "[Principal={}]: Current clock: {} is later than expiry {}. This may indicate a clock skew problem." + + " Check that this host's and remote host's clocks are in sync. Exiting refresh thread.", + principalLogText(), new Date(relativeToMs), new Date(expireTimeMs)); + return null; + } else { + /* + * Since the current soon-to-expire credential isn't logged out until we have a + * new credential with a refreshed lifetime, it is possible that the current + * credential could expire if the re-login continually fails over and over again + * making us unable to get the new credential. Therefore keep trying rather than + * exiting. + */ + long retvalNextRefreshMs = relativeToMs + DELAY_SECONDS_BEFORE_NEXT_RETRY_WHEN_RELOGIN_FAILS * 1000L; + log.warn("[Principal={}]: Expiring credential already expired at {}: will try to refresh again at {}", + principalLogText(), new Date(expireTimeMs), new Date(retvalNextRefreshMs)); + return retvalNextRefreshMs; + } + } + Long absoluteLastRefreshTimeMs = expiringCredential.absoluteLastRefreshTimeMs(); + if (absoluteLastRefreshTimeMs != null && absoluteLastRefreshTimeMs.longValue() < expireTimeMs) { + log.warn("[Principal={}]: Expiring credential refresh thread exiting because the" + + " expiring credential's current expiration time ({}) exceeds the latest possible refresh time ({})." + + " This process will not be able to authenticate new SASL connections after that" + + " time (for example, it will not be able to authenticate a new connection with a Kafka Broker).", + principalLogText(), new Date(expireTimeMs), new Date(absoluteLastRefreshTimeMs.longValue())); + return null; + } + Long optionalStartTime = expiringCredential.startTimeMs(); + long startMs = optionalStartTime != null ? optionalStartTime.longValue() : relativeToMs; + log.info("[Principal={}]: Expiring credential valid from {} to {}", expiringCredential.principalName(), + new java.util.Date(startMs), new java.util.Date(expireTimeMs)); + + double pct = expiringCredentialRefreshConfig.loginRefreshWindowFactor() + + (expiringCredentialRefreshConfig.loginRefreshWindowJitter() * RNG.nextDouble()); + /* + * Ignore buffer times if the credential's remaining lifetime is less than their + * sum. + */ + long refreshMinPeriodSeconds = expiringCredentialRefreshConfig.loginRefreshMinPeriodSeconds(); + long clientRefreshBufferSeconds = expiringCredentialRefreshConfig.loginRefreshBufferSeconds(); + if (relativeToMs + 1000L * (refreshMinPeriodSeconds + clientRefreshBufferSeconds) > expireTimeMs) { + long retvalRefreshMs = relativeToMs + (long) ((expireTimeMs - relativeToMs) * pct); + log.warn( + "[Principal={}]: Expiring credential expires at {}, so buffer times of {} and {} seconds" + + " at the front and back, respectively, cannot be accommodated. 
We will refresh at {}.", + principalLogText(), new Date(expireTimeMs), refreshMinPeriodSeconds, clientRefreshBufferSeconds, + new Date(retvalRefreshMs)); + return retvalRefreshMs; + } + long proposedRefreshMs = startMs + (long) ((expireTimeMs - startMs) * pct); + // Don't let it violate the requested end buffer time + long beginningOfEndBufferTimeMs = expireTimeMs - clientRefreshBufferSeconds * 1000; + if (proposedRefreshMs > beginningOfEndBufferTimeMs) { + log.info( + "[Principal={}]: Proposed refresh time of {} extends into the desired buffer time of {} seconds before expiration, so refresh it at the desired buffer begin point, at {}", + expiringCredential.principalName(), new Date(proposedRefreshMs), clientRefreshBufferSeconds, + new Date(beginningOfEndBufferTimeMs)); + return beginningOfEndBufferTimeMs; + } + // Don't let it violate the minimum refresh period + long endOfMinRefreshBufferTime = relativeToMs + 1000 * refreshMinPeriodSeconds; + if (proposedRefreshMs < endOfMinRefreshBufferTime) { + log.info( + "[Principal={}]: Expiring credential re-login thread time adjusted from {} to {} since the former is sooner " + + "than the minimum refresh interval ({} seconds from now).", + principalLogText(), new Date(proposedRefreshMs), new Date(endOfMinRefreshBufferTime), + refreshMinPeriodSeconds); + return endOfMinRefreshBufferTime; + } + // Proposed refresh time doesn't violate any constraints + return proposedRefreshMs; + } + + private void reLogin() throws LoginException, ExitRefresherThreadDueToIllegalStateException { + synchronized (mandatoryClassToSynchronizeOnPriorToRefresh) { + // Only perform one refresh of a particular type at a time + boolean logoutRequiredBeforeLoggingBackIn = isLogoutRequiredBeforeLoggingBackIn(); + if (hasExpiringCredential && logoutRequiredBeforeLoggingBackIn) { + String principalLogTextPriorToLogout = principalLogText(); + log.info("Initiating logout for {}", principalLogTextPriorToLogout); + loginContext.logout(); + // Make absolutely sure we were logged out + expiringCredential = expiringCredential(); + hasExpiringCredential = expiringCredential != null; + if (hasExpiringCredential) + // We can't force the removal because we don't know how to do it, so abort + throw new ExitRefresherThreadDueToIllegalStateException(String.format( + "Subject's private credentials still contains an instance of %s even though logout() was invoked; exiting refresh thread", + expiringCredential.getClass().getName())); + } + /* + * Perform a login, making note of any credential that might need a logout() + * afterwards + */ + ExpiringCredential optionalCredentialToLogout = expiringCredential; + LoginContext optionalLoginContextToLogout = loginContext; + boolean cleanLogin = false; // remember to restore the original if necessary + try { + loginContext = loginContextFactory.createLoginContext(ExpiringCredentialRefreshingLogin.this); + log.info("Initiating re-login for {}, logout() still needs to be called on a previous login = {}", + principalName, optionalCredentialToLogout != null); + loginContext.login(); + cleanLogin = true; // no need to restore the original + // Perform a logout() on any original credential if necessary + if (optionalCredentialToLogout != null) + optionalLoginContextToLogout.logout(); + } finally { + if (!cleanLogin) + // restore the original + loginContext = optionalLoginContextToLogout; + } + /* + * Get the new credential and make sure it is not any old one that required a + * logout() after the login() + */ + expiringCredential = expiringCredential(); + 
hasExpiringCredential = expiringCredential != null; + if (!hasExpiringCredential) { + /* + * Re-login has failed because our login() invocation has not generated a + * credential but has also not generated an exception. We won't exit here; + * instead we will allow login retries in case we can somehow fix the issue (it + * seems likely to be a bug, but it doesn't hurt to keep trying to refresh). + */ + log.error("No Expiring Credential after a supposedly-successful re-login"); + principalName = null; + } else { + if (expiringCredential == optionalCredentialToLogout) + /* + * The login() didn't identify a new credential; we still have the old one. We + * don't know how to fix this, so abort. + */ + throw new ExitRefresherThreadDueToIllegalStateException(String.format( + "Subject's private credentials still contains the previous, soon-to-expire instance of %s even though login() followed by logout() was invoked; exiting refresh thread", + expiringCredential.getClass().getName())); + principalName = expiringCredential.principalName(); + if (log.isDebugEnabled()) + log.debug("[Principal={}]: It is an expiring credential after re-login as expected", + principalLogText()); + } + } + } + + private String principalLogText() { + return expiringCredential == null ? principalName + : expiringCredential.getClass().getSimpleName() + ":" + principalName; + } + + private long currentMs() { + return time.milliseconds(); + } + + private boolean isLogoutRequiredBeforeLoggingBackIn() { + return !expiringCredentialRefreshConfig.loginRefreshReloginAllowedBeforeLogout(); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/unsecured/OAuthBearerConfigException.java b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/unsecured/OAuthBearerConfigException.java new file mode 100644 index 0000000..3dcdcd0 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/unsecured/OAuthBearerConfigException.java @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.security.oauthbearer.internals.unsecured; + +import org.apache.kafka.common.KafkaException; + +/** + * Exception thrown when there is a problem with the configuration (an invalid + * option in a JAAS config, for example). 
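One concrete source of this exception is scope parsing: OAuthBearerScopeUtils.parseScope, added later in this patch, throws it when a scope value contains an illegal character. A hedged usage sketch follows; the scope strings are illustrative only.

```java
// Hedged usage sketch, not part of the patch: parseScope() reports a malformed
// scope value via OAuthBearerConfigException. Scope strings are illustrative.
import org.apache.kafka.common.security.oauthbearer.internals.unsecured.OAuthBearerConfigException;
import org.apache.kafka.common.security.oauthbearer.internals.unsecured.OAuthBearerScopeUtils;

class ScopeConfigSketch {
    public static void main(String[] args) {
        try {
            System.out.println(OAuthBearerScopeUtils.parseScope("read write")); // [read, write]
            OAuthBearerScopeUtils.parseScope("read wr\"ite");                   // '"' is not a legal scope character
        } catch (OAuthBearerConfigException e) {
            System.err.println("Bad scope in configuration: " + e.getMessage());
        }
    }
}
```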
+ */ +public class OAuthBearerConfigException extends KafkaException { + private static final long serialVersionUID = -8056105648062343518L; + + public OAuthBearerConfigException(String s) { + super(s); + } + + public OAuthBearerConfigException(String message, Throwable cause) { + super(message, cause); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/unsecured/OAuthBearerIllegalTokenException.java b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/unsecured/OAuthBearerIllegalTokenException.java new file mode 100644 index 0000000..7885900 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/unsecured/OAuthBearerIllegalTokenException.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.security.oauthbearer.internals.unsecured; + +import java.util.Objects; + +import org.apache.kafka.common.KafkaException; + +/** + * Exception thrown when token validation fails due to a problem with the token + * itself (as opposed to a missing remote resource or a configuration problem) + */ +public class OAuthBearerIllegalTokenException extends KafkaException { + private static final long serialVersionUID = -5275276640051316350L; + private final OAuthBearerValidationResult reason; + + /** + * Constructor + * + * @param reason + * the mandatory reason for the validation failure; it must indicate + * failure + */ + public OAuthBearerIllegalTokenException(OAuthBearerValidationResult reason) { + super(Objects.requireNonNull(reason).failureDescription()); + if (reason.success()) + throw new IllegalArgumentException("The reason indicates success; it must instead indicate failure"); + this.reason = reason; + } + + /** + * Return the (always non-null) reason for the validation failure + * + * @return the reason for the validation failure + */ + public OAuthBearerValidationResult reason() { + return reason; + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/unsecured/OAuthBearerScopeUtils.java b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/unsecured/OAuthBearerScopeUtils.java new file mode 100644 index 0000000..7cae41a --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/unsecured/OAuthBearerScopeUtils.java @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.security.oauthbearer.internals.unsecured; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Objects; +import java.util.regex.Pattern; + +/** + * Utility class for help dealing with + * Access Token + * Scopes + */ +public class OAuthBearerScopeUtils { + private static final Pattern INDIVIDUAL_SCOPE_ITEM_PATTERN = Pattern.compile("[\\x23-\\x5B\\x5D-\\x7E\\x21]+"); + + /** + * Return true if the given value meets the definition of a valid scope item as + * per RFC 6749 + * Section 3.3, otherwise false + * + * @param scopeItem + * the mandatory scope item to check for validity + * @return true if the given value meets the definition of a valid scope item, + * otherwise false + */ + public static boolean isValidScopeItem(String scopeItem) { + return INDIVIDUAL_SCOPE_ITEM_PATTERN.matcher(Objects.requireNonNull(scopeItem)).matches(); + } + + /** + * Convert a space-delimited list of scope values (for example, + * "scope1 scope2") to a List containing the individual elements + * ("scope1" and "scope2") + * + * @param spaceDelimitedScope + * the mandatory (but possibly empty) space-delimited scope values, + * each of which must be valid according to + * {@link #isValidScopeItem(String)} + * @return the list of the given (possibly empty) space-delimited values + * @throws OAuthBearerConfigException + * if any of the individual scope values are malformed/illegal + */ + public static List parseScope(String spaceDelimitedScope) throws OAuthBearerConfigException { + List retval = new ArrayList<>(); + for (String individualScopeItem : Objects.requireNonNull(spaceDelimitedScope).split(" ")) { + if (!individualScopeItem.isEmpty()) { + if (!isValidScopeItem(individualScopeItem)) + throw new OAuthBearerConfigException(String.format("Invalid scope value: %s", individualScopeItem)); + retval.add(individualScopeItem); + } + } + return Collections.unmodifiableList(retval); + } + + private OAuthBearerScopeUtils() { + // empty + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/unsecured/OAuthBearerUnsecuredJws.java b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/unsecured/OAuthBearerUnsecuredJws.java new file mode 100644 index 0000000..fa175b3 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/unsecured/OAuthBearerUnsecuredJws.java @@ -0,0 +1,370 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.security.oauthbearer.internals.unsecured; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Base64; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; +import java.util.Objects; +import java.util.Set; + +import org.apache.kafka.common.security.oauthbearer.OAuthBearerToken; +import org.apache.kafka.common.utils.Utils; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.node.JsonNodeType; + +/** + * A simple unsecured JWS implementation. The '{@code nbf}' claim is ignored if + * it is given because the related logic is not required for Kafka testing and + * development purposes. + * + * @see RFC 7515 + */ +public class OAuthBearerUnsecuredJws implements OAuthBearerToken { + private final String compactSerialization; + private final List splits; + private final Map header; + private final String principalClaimName; + private final String scopeClaimName; + private final Map claims; + private final Set scope; + private final long lifetime; + private final String principalName; + private final Long startTimeMs; + + /** + * Constructor with the given principal and scope claim names + * + * @param compactSerialization + * the compact serialization to parse as an unsecured JWS + * @param principalClaimName + * the required principal claim name + * @param scopeClaimName + * the required scope claim name + * @throws OAuthBearerIllegalTokenException + * if the compact serialization is not a valid unsecured JWS + * (meaning it did not have 3 dot-separated Base64URL sections + * without an empty digital signature; or the header or claims + * either are not valid Base 64 URL encoded values or are not JSON + * after decoding; or the mandatory '{@code alg}' header value is + * not "{@code none}") + */ + public OAuthBearerUnsecuredJws(String compactSerialization, String principalClaimName, String scopeClaimName) + throws OAuthBearerIllegalTokenException { + this.compactSerialization = Objects.requireNonNull(compactSerialization); + if (compactSerialization.contains("..")) + throw new OAuthBearerIllegalTokenException( + OAuthBearerValidationResult.newFailure("Malformed compact serialization contains '..'")); + this.splits = extractCompactSerializationSplits(); + this.header = toMap(splits().get(0)); + String claimsSplit = splits.get(1); + this.claims = toMap(claimsSplit); + String alg = Objects.requireNonNull(header().get("alg"), "JWS header must have an Algorithm value").toString(); + if (!"none".equals(alg)) + throw new OAuthBearerIllegalTokenException( + OAuthBearerValidationResult.newFailure("Unsecured JWS must have 'none' for an algorithm")); + String digitalSignatureSplit = splits.get(2); + if (!digitalSignatureSplit.isEmpty()) + throw new OAuthBearerIllegalTokenException( + OAuthBearerValidationResult.newFailure("Unsecured JWS must not contain a digital signature")); + 
this.principalClaimName = Objects.requireNonNull(principalClaimName).trim(); + if (this.principalClaimName.isEmpty()) + throw new IllegalArgumentException("Must specify a non-blank principal claim name"); + this.scopeClaimName = Objects.requireNonNull(scopeClaimName).trim(); + if (this.scopeClaimName.isEmpty()) + throw new IllegalArgumentException("Must specify a non-blank scope claim name"); + this.scope = calculateScope(); + Number expirationTimeSeconds = expirationTime(); + if (expirationTimeSeconds == null) + throw new OAuthBearerIllegalTokenException( + OAuthBearerValidationResult.newFailure("No expiration time in JWT")); + lifetime = convertClaimTimeInSecondsToMs(expirationTimeSeconds); + String principalName = claim(this.principalClaimName, String.class); + if (Utils.isBlank(principalName)) + throw new OAuthBearerIllegalTokenException(OAuthBearerValidationResult + .newFailure("No principal name in JWT claim: " + this.principalClaimName)); + this.principalName = principalName; + this.startTimeMs = calculateStartTimeMs(); + } + + @Override + public String value() { + return compactSerialization; + } + + /** + * Return the 3 or 5 dot-separated sections of the JWT compact serialization + * + * @return the 3 or 5 dot-separated sections of the JWT compact serialization + */ + public List splits() { + return splits; + } + + /** + * Return the JOSE Header as a {@code Map} + * + * @return the JOSE header + */ + public Map header() { + return header; + } + + @Override + public String principalName() { + return principalName; + } + + @Override + public Long startTimeMs() { + return startTimeMs; + } + + @Override + public long lifetimeMs() { + return lifetime; + } + + @Override + public Set scope() throws OAuthBearerIllegalTokenException { + return scope; + } + + /** + * Return the JWT Claim Set as a {@code Map} + * + * @return the (always non-null but possibly empty) claims + */ + public Map claims() { + return claims; + } + + /** + * Return the (always non-null/non-empty) principal claim name + * + * @return the (always non-null/non-empty) principal claim name + */ + public String principalClaimName() { + return principalClaimName; + } + + /** + * Return the (always non-null/non-empty) scope claim name + * + * @return the (always non-null/non-empty) scope claim name + */ + public String scopeClaimName() { + return scopeClaimName; + } + + /** + * Indicate if the claim exists and is the given type + * + * @param claimName + * the mandatory JWT claim name + * @param type + * the mandatory type, which should either be String.class, + * Number.class, or List.class + * @return true if the claim exists and is the given type, otherwise false + */ + public boolean isClaimType(String claimName, Class type) { + Object value = rawClaim(claimName); + Objects.requireNonNull(type); + if (value == null) + return false; + if (type == String.class && value instanceof String) + return true; + if (type == Number.class && value instanceof Number) + return true; + return type == List.class && value instanceof List; + } + + /** + * Extract a claim of the given type + * + * @param claimName + * the mandatory JWT claim name + * @param type + * the mandatory type, which must either be String.class, + * Number.class, or List.class + * @return the claim if it exists, otherwise null + * @throws OAuthBearerIllegalTokenException + * if the claim exists but is not the given type + */ + public T claim(String claimName, Class type) throws OAuthBearerIllegalTokenException { + Object value = rawClaim(claimName); + try { + 
return Objects.requireNonNull(type).cast(value); + } catch (ClassCastException e) { + throw new OAuthBearerIllegalTokenException( + OAuthBearerValidationResult.newFailure(String.format("The '%s' claim was not of type %s: %s", + claimName, type.getSimpleName(), value.getClass().getSimpleName()))); + } + } + + /** + * Extract a claim in its raw form + * + * @param claimName + * the mandatory JWT claim name + * @return the raw claim value, if it exists, otherwise null + */ + public Object rawClaim(String claimName) { + return claims().get(Objects.requireNonNull(claimName)); + } + + /** + * Return the + * Expiration + * Time claim + * + * @return the Expiration + * Time claim if available, otherwise null + * @throws OAuthBearerIllegalTokenException + * if the claim value is the incorrect type + */ + public Number expirationTime() throws OAuthBearerIllegalTokenException { + return claim("exp", Number.class); + } + + /** + * Return the Issued + * At claim + * + * @return the + * Issued + * At claim if available, otherwise null + * @throws OAuthBearerIllegalTokenException + * if the claim value is the incorrect type + */ + public Number issuedAt() throws OAuthBearerIllegalTokenException { + return claim("iat", Number.class); + } + + /** + * Return the + * Subject claim + * + * @return the Subject claim + * if available, otherwise null + * @throws OAuthBearerIllegalTokenException + * if the claim value is the incorrect type + */ + public String subject() throws OAuthBearerIllegalTokenException { + return claim("sub", String.class); + } + + /** + * Decode the given Base64URL-encoded value, parse the resulting JSON as a JSON + * object, and return the map of member names to their values (each value being + * represented as either a String, a Number, or a List of Strings). 
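As a hedged sketch (not part of the patch), this is the shape of compact serialization the class parses: a Base64URL-encoded '{"alg":"none"}' header, a Base64URL-encoded claims object, and an empty signature, joined by dots. The claim values below are illustrative; the construction mirrors what OAuthBearerUnsecuredLoginCallbackHandler does later in this patch.

```java
// Hedged sketch, not part of the patch: building and parsing an unsecured JWS
// compact serialization. Claim values are illustrative only.
import java.nio.charset.StandardCharsets;
import java.util.Base64;

import org.apache.kafka.common.security.oauthbearer.internals.unsecured.OAuthBearerUnsecuredJws;

class UnsecuredJwsSketch {
    public static void main(String[] args) {
        Base64.Encoder enc = Base64.getUrlEncoder().withoutPadding();
        long nowSeconds = System.currentTimeMillis() / 1000;
        String headerJson = "{\"alg\":\"none\"}";
        String claimsJson = String.format(
                "{\"sub\":\"thePrincipalName\",\"iat\":%d,\"exp\":%d,\"scope\":\"test\"}",
                nowSeconds, nowSeconds + 3600);
        String compact = enc.encodeToString(headerJson.getBytes(StandardCharsets.UTF_8)) + "."
                + enc.encodeToString(claimsJson.getBytes(StandardCharsets.UTF_8)) + "."; // empty signature part
        OAuthBearerUnsecuredJws jws = new OAuthBearerUnsecuredJws(compact, "sub", "scope");
        System.out.println(jws.principalName() + " token expires at epoch ms " + jws.lifetimeMs());
    }
}
```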
+ * + * @param split + * the value to decode and parse + * @return the map of JSON member names to their String, Number, or String List + * value + * @throws OAuthBearerIllegalTokenException + * if the given Base64URL-encoded value cannot be decoded or parsed + */ + public static Map toMap(String split) throws OAuthBearerIllegalTokenException { + Map retval = new HashMap<>(); + try { + byte[] decode = Base64.getDecoder().decode(split); + JsonNode jsonNode = new ObjectMapper().readTree(decode); + if (jsonNode == null) + throw new OAuthBearerIllegalTokenException(OAuthBearerValidationResult.newFailure("malformed JSON")); + for (Iterator> iterator = jsonNode.fields(); iterator.hasNext();) { + Entry entry = iterator.next(); + retval.put(entry.getKey(), convert(entry.getValue())); + } + return Collections.unmodifiableMap(retval); + } catch (IllegalArgumentException e) { + // potentially thrown by java.util.Base64.Decoder implementations + throw new OAuthBearerIllegalTokenException( + OAuthBearerValidationResult.newFailure("malformed Base64 URL encoded value")); + } catch (IOException e) { + throw new OAuthBearerIllegalTokenException(OAuthBearerValidationResult.newFailure("malformed JSON")); + } + } + + private List extractCompactSerializationSplits() { + List tmpSplits = new ArrayList<>(Arrays.asList(compactSerialization.split("\\."))); + if (compactSerialization.endsWith(".")) + tmpSplits.add(""); + if (tmpSplits.size() != 3) + throw new OAuthBearerIllegalTokenException(OAuthBearerValidationResult.newFailure( + "Unsecured JWS compact serializations must have 3 dot-separated Base64URL-encoded values")); + return Collections.unmodifiableList(tmpSplits); + } + + private static Object convert(JsonNode value) { + if (value.isArray()) { + List retvalList = new ArrayList<>(); + for (JsonNode arrayElement : value) + retvalList.add(arrayElement.asText()); + return retvalList; + } + return value.getNodeType() == JsonNodeType.NUMBER ? value.numberValue() : value.asText(); + } + + private Long calculateStartTimeMs() throws OAuthBearerIllegalTokenException { + Number issuedAtSeconds = claim("iat", Number.class); + return issuedAtSeconds == null ? 
null : convertClaimTimeInSecondsToMs(issuedAtSeconds); + } + + private static long convertClaimTimeInSecondsToMs(Number claimValue) { + return Math.round(claimValue.doubleValue() * 1000); + } + + private Set calculateScope() { + String scopeClaimName = scopeClaimName(); + if (isClaimType(scopeClaimName, String.class)) { + String scopeClaimValue = claim(scopeClaimName, String.class); + if (Utils.isBlank(scopeClaimValue)) + return Collections.emptySet(); + else { + Set retval = new HashSet<>(); + retval.add(scopeClaimValue.trim()); + return Collections.unmodifiableSet(retval); + } + } + List scopeClaimValue = claim(scopeClaimName, List.class); + if (scopeClaimValue == null || scopeClaimValue.isEmpty()) + return Collections.emptySet(); + @SuppressWarnings("unchecked") + List stringList = (List) scopeClaimValue; + Set retval = new HashSet<>(); + for (String scope : stringList) { + if (!Utils.isBlank(scope)) { + retval.add(scope.trim()); + } + } + return Collections.unmodifiableSet(retval); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/unsecured/OAuthBearerUnsecuredLoginCallbackHandler.java b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/unsecured/OAuthBearerUnsecuredLoginCallbackHandler.java new file mode 100644 index 0000000..eb4c7db --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/unsecured/OAuthBearerUnsecuredLoginCallbackHandler.java @@ -0,0 +1,343 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.security.oauthbearer.internals.unsecured; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.Arrays; +import java.util.Base64; +import java.util.Base64.Encoder; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Set; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +import javax.security.auth.callback.Callback; +import javax.security.auth.callback.UnsupportedCallbackException; +import javax.security.auth.login.AppConfigurationEntry; +import javax.security.sasl.SaslException; + +import org.apache.kafka.common.KafkaException; +import org.apache.kafka.common.config.ConfigException; +import org.apache.kafka.common.security.auth.AuthenticateCallbackHandler; +import org.apache.kafka.common.security.auth.SaslExtensionsCallback; +import org.apache.kafka.common.security.auth.SaslExtensions; +import org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginModule; +import org.apache.kafka.common.security.oauthbearer.OAuthBearerTokenCallback; +import org.apache.kafka.common.security.oauthbearer.internals.OAuthBearerClientInitialResponse; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.common.utils.Utils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * A {@code CallbackHandler} that recognizes {@link OAuthBearerTokenCallback} + * to return an unsecured OAuth 2 bearer token and {@link SaslExtensionsCallback} to return SASL extensions + *
 + * Claims and their values on the returned token can be specified using
 + * {@code unsecuredLoginStringClaim_<claimname>},
 + * {@code unsecuredLoginNumberClaim_<claimname>}, and
 + * {@code unsecuredLoginListClaim_<claimname>} options. The first character of
 + * the value is taken as the delimiter for list claims. You may define any claim
 + * name and value except '{@code iat}' and '{@code exp}', both of which are
 + * calculated automatically.
 + *
 + * You can also add custom unsecured SASL extensions using
 + * {@code unsecuredLoginExtension_<extensionname>}. Extension keys and values are
 + * subject to regex validation. The extension key must also not be equal to the
 + * reserved key {@link OAuthBearerClientInitialResponse#AUTH_KEY}.
 + *
 + * This implementation also accepts the following options:
 + *
 + * • {@code unsecuredLoginPrincipalClaimName} set to a custom claim name if
 + *   you wish the name of the String claim holding the principal name to be
 + *   something other than '{@code sub}'.
 + * • {@code unsecuredLoginLifetimeSeconds} set to an integer value if the
 + *   token expiration is to be set to something other than the default value of
 + *   3600 seconds (which is 1 hour). The '{@code exp}' claim reflects the
 + *   expiration time.
 + * • {@code unsecuredLoginScopeClaimName} set to a custom claim name if you
 + *   wish the name of the String or String List claim holding any token scope to
 + *   be something other than '{@code scope}'.
 + *
 + * For example:
 + *
 + * KafkaClient {
 + *      org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginModule Required
 + *      unsecuredLoginStringClaim_sub="thePrincipalName"
 + *      unsecuredLoginListClaim_scope="|scopeValue1|scopeValue2"
 + *      unsecuredLoginLifetimeSeconds="60"
 + *      unsecuredLoginExtension_traceId="123";
 + * };
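Equivalently, a client can supply the example above inline through the sasl.jaas.config property instead of a JAAS file. The following is a hedged sketch; the broker address and property values are illustrative only.

```java
// Hedged sketch, not part of the patch: the JAAS example above supplied inline
// via sasl.jaas.config. Broker address and values are illustrative only.
import java.util.Properties;

class UnsecuredOAuthClientConfigSketch {
    static Properties clientProps() {
        Properties props = new Properties();
        props.put("bootstrap.servers", "localhost:9092");
        props.put("security.protocol", "SASL_PLAINTEXT");
        props.put("sasl.mechanism", "OAUTHBEARER");
        props.put("sasl.jaas.config",
                "org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginModule required"
                        + " unsecuredLoginStringClaim_sub=\"thePrincipalName\""
                        + " unsecuredLoginListClaim_scope=\"|scopeValue1|scopeValue2\""
                        + " unsecuredLoginLifetimeSeconds=\"60\""
                        + " unsecuredLoginExtension_traceId=\"123\";");
        return props;
    }
}
```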
        + * + * This class is the default when the SASL mechanism is OAUTHBEARER and no value + * is explicitly set via either the {@code sasl.login.callback.handler.class} + * client configuration property or the + * {@code listener.name.sasl_[plaintext|ssl].oauthbearer.sasl.login.callback.handler.class} + * broker configuration property. + */ +public class OAuthBearerUnsecuredLoginCallbackHandler implements AuthenticateCallbackHandler { + private final Logger log = LoggerFactory.getLogger(OAuthBearerUnsecuredLoginCallbackHandler.class); + private static final String OPTION_PREFIX = "unsecuredLogin"; + private static final String PRINCIPAL_CLAIM_NAME_OPTION = OPTION_PREFIX + "PrincipalClaimName"; + private static final String LIFETIME_SECONDS_OPTION = OPTION_PREFIX + "LifetimeSeconds"; + private static final String SCOPE_CLAIM_NAME_OPTION = OPTION_PREFIX + "ScopeClaimName"; + private static final Set RESERVED_CLAIMS = Collections + .unmodifiableSet(new HashSet<>(Arrays.asList("iat", "exp"))); + private static final String DEFAULT_PRINCIPAL_CLAIM_NAME = "sub"; + private static final String DEFAULT_LIFETIME_SECONDS_ONE_HOUR = "3600"; + private static final String DEFAULT_SCOPE_CLAIM_NAME = "scope"; + private static final String STRING_CLAIM_PREFIX = OPTION_PREFIX + "StringClaim_"; + private static final String NUMBER_CLAIM_PREFIX = OPTION_PREFIX + "NumberClaim_"; + private static final String LIST_CLAIM_PREFIX = OPTION_PREFIX + "ListClaim_"; + private static final String EXTENSION_PREFIX = OPTION_PREFIX + "Extension_"; + private static final String QUOTE = "\""; + private Time time = Time.SYSTEM; + private Map moduleOptions = null; + private boolean configured = false; + + private static final Pattern DOUBLEQUOTE = Pattern.compile("\"", Pattern.LITERAL); + + private static final Pattern BACKSLASH = Pattern.compile("\\", Pattern.LITERAL); + + /** + * For testing + * + * @param time + * the mandatory time to set + */ + void time(Time time) { + this.time = Objects.requireNonNull(time); + } + + /** + * Return true if this instance has been configured, otherwise false + * + * @return true if this instance has been configured, otherwise false + */ + public boolean configured() { + return configured; + } + + @SuppressWarnings("unchecked") + @Override + public void configure(Map configs, String saslMechanism, List jaasConfigEntries) { + if (!OAuthBearerLoginModule.OAUTHBEARER_MECHANISM.equals(saslMechanism)) + throw new IllegalArgumentException(String.format("Unexpected SASL mechanism: %s", saslMechanism)); + if (Objects.requireNonNull(jaasConfigEntries).size() != 1 || jaasConfigEntries.get(0) == null) + throw new IllegalArgumentException( + String.format("Must supply exactly 1 non-null JAAS mechanism configuration (size was %d)", + jaasConfigEntries.size())); + this.moduleOptions = Collections.unmodifiableMap((Map) jaasConfigEntries.get(0).getOptions()); + configured = true; + } + + @Override + public void handle(Callback[] callbacks) throws IOException, UnsupportedCallbackException { + if (!configured()) + throw new IllegalStateException("Callback handler not configured"); + for (Callback callback : callbacks) { + if (callback instanceof OAuthBearerTokenCallback) + try { + handleTokenCallback((OAuthBearerTokenCallback) callback); + } catch (KafkaException e) { + throw new IOException(e.getMessage(), e); + } + else if (callback instanceof SaslExtensionsCallback) + try { + handleExtensionsCallback((SaslExtensionsCallback) callback); + } catch (KafkaException e) { + throw new 
IOException(e.getMessage(), e); + } + else + throw new UnsupportedCallbackException(callback); + } + } + + @Override + public void close() { + // empty + } + + private void handleTokenCallback(OAuthBearerTokenCallback callback) { + if (callback.token() != null) + throw new IllegalArgumentException("Callback had a token already"); + if (moduleOptions.isEmpty()) { + log.debug("Token not provided, this login cannot be used to establish client connections"); + callback.token(null); + return; + } + if (moduleOptions.keySet().stream().noneMatch(name -> !name.startsWith(EXTENSION_PREFIX))) { + throw new OAuthBearerConfigException("Extensions provided in login context without a token"); + } + String principalClaimNameValue = optionValue(PRINCIPAL_CLAIM_NAME_OPTION); + String principalClaimName = Utils.isBlank(principalClaimNameValue) ? DEFAULT_PRINCIPAL_CLAIM_NAME : principalClaimNameValue.trim(); + String scopeClaimNameValue = optionValue(SCOPE_CLAIM_NAME_OPTION); + String scopeClaimName = Utils.isBlank(scopeClaimNameValue) ? DEFAULT_SCOPE_CLAIM_NAME : scopeClaimNameValue.trim(); + String headerJson = "{" + claimOrHeaderJsonText("alg", "none") + "}"; + String lifetimeSecondsValueToUse = optionValue(LIFETIME_SECONDS_OPTION, DEFAULT_LIFETIME_SECONDS_ONE_HOUR); + String claimsJson; + try { + claimsJson = String.format("{%s,%s%s}", expClaimText(Long.parseLong(lifetimeSecondsValueToUse)), + claimOrHeaderJsonText("iat", time.milliseconds() / 1000.0), + commaPrependedStringNumberAndListClaimsJsonText()); + } catch (NumberFormatException e) { + throw new OAuthBearerConfigException(e.getMessage()); + } + try { + Encoder urlEncoderNoPadding = Base64.getUrlEncoder().withoutPadding(); + OAuthBearerUnsecuredJws jws = new OAuthBearerUnsecuredJws( + String.format("%s.%s.", + urlEncoderNoPadding.encodeToString(headerJson.getBytes(StandardCharsets.UTF_8)), + urlEncoderNoPadding.encodeToString(claimsJson.getBytes(StandardCharsets.UTF_8))), + principalClaimName, scopeClaimName); + log.info("Retrieved token with principal {}", jws.principalName()); + callback.token(jws); + } catch (OAuthBearerIllegalTokenException e) { + // occurs if the principal claim doesn't exist or has an empty value + throw new OAuthBearerConfigException(e.getMessage(), e); + } + } + + /** + * Add and validate all the configured extensions. 
+ * Token keys, apart from passing regex validation, must not be equal to the reserved key {@link OAuthBearerClientInitialResponse#AUTH_KEY} + */ + private void handleExtensionsCallback(SaslExtensionsCallback callback) { + Map extensions = new HashMap<>(); + for (Map.Entry configEntry : this.moduleOptions.entrySet()) { + String key = configEntry.getKey(); + if (!key.startsWith(EXTENSION_PREFIX)) + continue; + + extensions.put(key.substring(EXTENSION_PREFIX.length()), configEntry.getValue()); + } + + SaslExtensions saslExtensions = new SaslExtensions(extensions); + try { + OAuthBearerClientInitialResponse.validateExtensions(saslExtensions); + } catch (SaslException e) { + throw new ConfigException(e.getMessage()); + } + + callback.extensions(saslExtensions); + } + + private String commaPrependedStringNumberAndListClaimsJsonText() throws OAuthBearerConfigException { + StringBuilder sb = new StringBuilder(); + for (String key : moduleOptions.keySet()) { + if (key.startsWith(STRING_CLAIM_PREFIX) && key.length() > STRING_CLAIM_PREFIX.length()) + sb.append(',').append(claimOrHeaderJsonText( + confirmNotReservedClaimName(key.substring(STRING_CLAIM_PREFIX.length())), optionValue(key))); + else if (key.startsWith(NUMBER_CLAIM_PREFIX) && key.length() > NUMBER_CLAIM_PREFIX.length()) + sb.append(',') + .append(claimOrHeaderJsonText( + confirmNotReservedClaimName(key.substring(NUMBER_CLAIM_PREFIX.length())), + Double.valueOf(optionValue(key)))); + else if (key.startsWith(LIST_CLAIM_PREFIX) && key.length() > LIST_CLAIM_PREFIX.length()) + sb.append(',') + .append(claimOrHeaderJsonArrayText( + confirmNotReservedClaimName(key.substring(LIST_CLAIM_PREFIX.length())), + listJsonText(optionValue(key)))); + } + return sb.toString(); + } + + private String confirmNotReservedClaimName(String claimName) throws OAuthBearerConfigException { + if (RESERVED_CLAIMS.contains(claimName)) + throw new OAuthBearerConfigException(String.format("Cannot explicitly set the '%s' claim", claimName)); + return claimName; + } + + private String listJsonText(String value) { + if (value.isEmpty() || value.length() <= 1) + return "[]"; + String delimiter; + String unescapedDelimiterChar = value.substring(0, 1); + switch (unescapedDelimiterChar) { + case "\\": + case ".": + case "[": + case "(": + case "{": + case "|": + case "^": + case "$": + delimiter = "\\" + unescapedDelimiterChar; + break; + default: + delimiter = unescapedDelimiterChar; + break; + } + String listText = value.substring(1); + String[] elements = listText.split(delimiter); + StringBuilder sb = new StringBuilder(); + for (String element : elements) { + sb.append(sb.length() == 0 ? '[' : ','); + sb.append('"').append(escape(element)).append('"'); + } + if (listText.startsWith(unescapedDelimiterChar) || listText.endsWith(unescapedDelimiterChar) + || listText.contains(unescapedDelimiterChar + unescapedDelimiterChar)) + sb.append(",\"\""); + return sb.append(']').toString(); + } + + private String optionValue(String key) { + return optionValue(key, null); + } + + private String optionValue(String key, String defaultValue) { + String explicitValue = option(key); + return explicitValue != null ? 
explicitValue : defaultValue; + } + + private String option(String key) { + if (!configured) + throw new IllegalStateException("Callback handler not configured"); + return moduleOptions.get(Objects.requireNonNull(key)); + } + + private String claimOrHeaderJsonText(String claimName, Number claimValue) { + return QUOTE + escape(claimName) + QUOTE + ":" + claimValue; + } + + private String claimOrHeaderJsonText(String claimName, String claimValue) { + return QUOTE + escape(claimName) + QUOTE + ":" + QUOTE + escape(claimValue) + QUOTE; + } + + private String claimOrHeaderJsonArrayText(String claimName, String escapedClaimValue) { + if (!escapedClaimValue.startsWith("[") || !escapedClaimValue.endsWith("]")) + throw new IllegalArgumentException(String.format("Illegal JSON array: %s", escapedClaimValue)); + return QUOTE + escape(claimName) + QUOTE + ":" + escapedClaimValue; + } + + private String escape(String jsonStringValue) { + String replace1 = DOUBLEQUOTE.matcher(jsonStringValue).replaceAll(Matcher.quoteReplacement("\\\"")); + return BACKSLASH.matcher(replace1).replaceAll(Matcher.quoteReplacement("\\\\")); + } + + private String expClaimText(long lifetimeSeconds) { + return claimOrHeaderJsonText("exp", time.milliseconds() / 1000.0 + lifetimeSeconds); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/unsecured/OAuthBearerUnsecuredValidatorCallbackHandler.java b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/unsecured/OAuthBearerUnsecuredValidatorCallbackHandler.java new file mode 100644 index 0000000..7a81521 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/unsecured/OAuthBearerUnsecuredValidatorCallbackHandler.java @@ -0,0 +1,214 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.security.oauthbearer.internals.unsecured; + +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Objects; + +import javax.security.auth.callback.Callback; +import javax.security.auth.callback.UnsupportedCallbackException; +import javax.security.auth.login.AppConfigurationEntry; + +import org.apache.kafka.common.security.auth.AuthenticateCallbackHandler; +import org.apache.kafka.common.security.oauthbearer.OAuthBearerExtensionsValidatorCallback; +import org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginModule; +import org.apache.kafka.common.security.oauthbearer.OAuthBearerValidatorCallback; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.common.utils.Utils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * A {@code CallbackHandler} that recognizes + * {@link OAuthBearerValidatorCallback} and validates an unsecured OAuth 2 + * bearer token. It requires there to be an "exp" (Expiration Time) + * claim of type Number. If "iat" (Issued At) or + * "nbf" (Not Before) claims are present each must be a number that + * precedes the Expiration Time claim, and if both are present the Not Before + * claim must not precede the Issued At claim. It also accepts the following + * options, none of which are required: + *
        + * <ul>
        + * <li>{@code unsecuredValidatorPrincipalClaimName} set to a non-empty value if + * you wish a particular String claim holding a principal name to be checked for + * existence; the default is to check for the existence of the '{@code sub}' + * claim</li>
        + * <li>{@code unsecuredValidatorScopeClaimName} set to a custom claim name if + * you wish the name of the String or String List claim holding any token scope + * to be something other than '{@code scope}'</li>
        + * <li>{@code unsecuredValidatorRequiredScope} set to a space-delimited list of + * scope values if you wish the String/String List claim holding the token scope + * to be checked to make sure it contains certain values</li>
        + * <li>{@code unsecuredValidatorAllowableClockSkewMs} set to a positive integer + * value if you wish to allow up to some number of positive milliseconds of + * clock skew (the default is 0)</li>
        + * </ul>
            + * For example: + * + *
            + * KafkaServer {
            + *      org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginModule Required
            + *      unsecuredLoginStringClaim_sub="thePrincipalName"
            + *      unsecuredLoginListClaim_scope=",KAFKA_BROKER,LOGIN_TO_KAFKA"
            + *      unsecuredValidatorRequiredScope="LOGIN_TO_KAFKA"
            + *      unsecuredValidatorAllowableClockSkewMs="3000";
            + * };
            + * 
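As a rough usage sketch (not part of this patch), the handler can be driven directly with a hand-built unsecured JWS of the form base64url(header) + "." + base64url(claims) + "."; the handler and callback types come from this patch, the JAAS options mirror the example above, and the claim values are illustrative.

import java.nio.charset.StandardCharsets;
import java.util.Base64;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import javax.security.auth.callback.Callback;
import javax.security.auth.login.AppConfigurationEntry;
import javax.security.auth.login.AppConfigurationEntry.LoginModuleControlFlag;
import org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginModule;
import org.apache.kafka.common.security.oauthbearer.OAuthBearerValidatorCallback;
import org.apache.kafka.common.security.oauthbearer.internals.unsecured.OAuthBearerUnsecuredValidatorCallbackHandler;

public class UnsecuredValidatorSketch {
    public static void main(String[] args) throws Exception {
        // JAAS options mirroring the example above (values are illustrative)
        Map<String, String> options = new HashMap<>();
        options.put("unsecuredValidatorRequiredScope", "LOGIN_TO_KAFKA");
        options.put("unsecuredValidatorAllowableClockSkewMs", "3000");
        AppConfigurationEntry jaasEntry = new AppConfigurationEntry(
                OAuthBearerLoginModule.class.getName(), LoginModuleControlFlag.REQUIRED, options);

        OAuthBearerUnsecuredValidatorCallbackHandler handler = new OAuthBearerUnsecuredValidatorCallbackHandler();
        handler.configure(Collections.emptyMap(), OAuthBearerLoginModule.OAUTHBEARER_MECHANISM,
                Collections.singletonList(jaasEntry));

        // An unsecured JWS is "<base64url(header)>.<base64url(claims)>." with an empty signature part
        long nowSeconds = System.currentTimeMillis() / 1000;
        String header = "{\"alg\":\"none\"}";
        String claims = "{\"sub\":\"thePrincipalName\",\"iat\":" + nowSeconds
                + ",\"exp\":" + (nowSeconds + 3600) + ",\"scope\":[\"LOGIN_TO_KAFKA\"]}";
        Base64.Encoder encoder = Base64.getUrlEncoder().withoutPadding();
        String jws = encoder.encodeToString(header.getBytes(StandardCharsets.UTF_8)) + "."
                + encoder.encodeToString(claims.getBytes(StandardCharsets.UTF_8)) + ".";

        OAuthBearerValidatorCallback callback = new OAuthBearerValidatorCallback(jws);
        handler.handle(new Callback[]{callback});
        System.out.println(callback.token() != null
                ? "validated principal: " + callback.token().principalName()
                : "rejected: " + callback.errorStatus());
    }
}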
            + * It also recognizes {@link OAuthBearerExtensionsValidatorCallback} and validates every extension passed to it. + * + * This class is the default when the SASL mechanism is OAUTHBEARER and no value + * is explicitly set via the + * {@code listener.name.sasl_[plaintext|ssl].oauthbearer.sasl.server.callback.handler.class} + * broker configuration property. + * It is worth noting that this class is not suitable for production use due to the use of unsecured JWT tokens and + * validation of every given extension. + */ +public class OAuthBearerUnsecuredValidatorCallbackHandler implements AuthenticateCallbackHandler { + private static final Logger log = LoggerFactory.getLogger(OAuthBearerUnsecuredValidatorCallbackHandler.class); + private static final String OPTION_PREFIX = "unsecuredValidator"; + private static final String PRINCIPAL_CLAIM_NAME_OPTION = OPTION_PREFIX + "PrincipalClaimName"; + private static final String SCOPE_CLAIM_NAME_OPTION = OPTION_PREFIX + "ScopeClaimName"; + private static final String REQUIRED_SCOPE_OPTION = OPTION_PREFIX + "RequiredScope"; + private static final String ALLOWABLE_CLOCK_SKEW_MILLIS_OPTION = OPTION_PREFIX + "AllowableClockSkewMs"; + private Time time = Time.SYSTEM; + private Map moduleOptions = null; + private boolean configured = false; + + /** + * For testing + * + * @param time + * the mandatory time to set + */ + void time(Time time) { + this.time = Objects.requireNonNull(time); + } + + /** + * Return true if this instance has been configured, otherwise false + * + * @return true if this instance has been configured, otherwise false + */ + public boolean configured() { + return configured; + } + + @SuppressWarnings("unchecked") + @Override + public void configure(Map configs, String saslMechanism, List jaasConfigEntries) { + if (!OAuthBearerLoginModule.OAUTHBEARER_MECHANISM.equals(saslMechanism)) + throw new IllegalArgumentException(String.format("Unexpected SASL mechanism: %s", saslMechanism)); + if (Objects.requireNonNull(jaasConfigEntries).size() != 1 || jaasConfigEntries.get(0) == null) + throw new IllegalArgumentException( + String.format("Must supply exactly 1 non-null JAAS mechanism configuration (size was %d)", + jaasConfigEntries.size())); + final Map unmodifiableModuleOptions = Collections + .unmodifiableMap((Map) jaasConfigEntries.get(0).getOptions()); + this.moduleOptions = unmodifiableModuleOptions; + configured = true; + } + + @Override + public void handle(Callback[] callbacks) throws UnsupportedCallbackException { + if (!configured()) + throw new IllegalStateException("Callback handler not configured"); + for (Callback callback : callbacks) { + if (callback instanceof OAuthBearerValidatorCallback) { + OAuthBearerValidatorCallback validationCallback = (OAuthBearerValidatorCallback) callback; + try { + handleCallback(validationCallback); + } catch (OAuthBearerIllegalTokenException e) { + OAuthBearerValidationResult failureReason = e.reason(); + String failureScope = failureReason.failureScope(); + validationCallback.error(failureScope != null ? 
"insufficient_scope" : "invalid_token", + failureScope, failureReason.failureOpenIdConfig()); + } + } else if (callback instanceof OAuthBearerExtensionsValidatorCallback) { + OAuthBearerExtensionsValidatorCallback extensionsCallback = (OAuthBearerExtensionsValidatorCallback) callback; + extensionsCallback.inputExtensions().map().forEach((extensionName, v) -> extensionsCallback.valid(extensionName)); + } else + throw new UnsupportedCallbackException(callback); + } + } + + @Override + public void close() { + // empty + } + + private void handleCallback(OAuthBearerValidatorCallback callback) { + String tokenValue = callback.tokenValue(); + if (tokenValue == null) + throw new IllegalArgumentException("Callback missing required token value"); + String principalClaimName = principalClaimName(); + String scopeClaimName = scopeClaimName(); + List requiredScope = requiredScope(); + int allowableClockSkewMs = allowableClockSkewMs(); + OAuthBearerUnsecuredJws unsecuredJwt = new OAuthBearerUnsecuredJws(tokenValue, principalClaimName, + scopeClaimName); + long now = time.milliseconds(); + OAuthBearerValidationUtils + .validateClaimForExistenceAndType(unsecuredJwt, true, principalClaimName, String.class) + .throwExceptionIfFailed(); + OAuthBearerValidationUtils.validateIssuedAt(unsecuredJwt, false, now, allowableClockSkewMs) + .throwExceptionIfFailed(); + OAuthBearerValidationUtils.validateExpirationTime(unsecuredJwt, now, allowableClockSkewMs) + .throwExceptionIfFailed(); + OAuthBearerValidationUtils.validateTimeConsistency(unsecuredJwt).throwExceptionIfFailed(); + OAuthBearerValidationUtils.validateScope(unsecuredJwt, requiredScope).throwExceptionIfFailed(); + log.info("Successfully validated token with principal {}: {}", unsecuredJwt.principalName(), + unsecuredJwt.claims()); + callback.token(unsecuredJwt); + } + + private String principalClaimName() { + String principalClaimNameValue = option(PRINCIPAL_CLAIM_NAME_OPTION); + return Utils.isBlank(principalClaimNameValue) ? "sub" : principalClaimNameValue.trim(); + } + + private String scopeClaimName() { + String scopeClaimNameValue = option(SCOPE_CLAIM_NAME_OPTION); + return Utils.isBlank(scopeClaimNameValue) ? "scope" : scopeClaimNameValue.trim(); + } + + private List requiredScope() { + String requiredSpaceDelimitedScope = option(REQUIRED_SCOPE_OPTION); + return Utils.isBlank(requiredSpaceDelimitedScope) ? Collections.emptyList() : OAuthBearerScopeUtils.parseScope(requiredSpaceDelimitedScope.trim()); + } + + private int allowableClockSkewMs() { + String allowableClockSkewMsValue = option(ALLOWABLE_CLOCK_SKEW_MILLIS_OPTION); + int allowableClockSkewMs = 0; + try { + allowableClockSkewMs = Utils.isBlank(allowableClockSkewMsValue) ? 
0 : Integer.parseInt(allowableClockSkewMsValue.trim()); + } catch (NumberFormatException e) { + throw new OAuthBearerConfigException(e.getMessage(), e); + } + if (allowableClockSkewMs < 0) { + throw new OAuthBearerConfigException( + String.format("Allowable clock skew millis must not be negative: %s", allowableClockSkewMsValue)); + } + return allowableClockSkewMs; + } + + private String option(String key) { + if (!configured) + throw new IllegalStateException("Callback handler not configured"); + return moduleOptions.get(Objects.requireNonNull(key)); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/unsecured/OAuthBearerValidationResult.java b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/unsecured/OAuthBearerValidationResult.java new file mode 100644 index 0000000..2806b4d --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/unsecured/OAuthBearerValidationResult.java @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.security.oauthbearer.internals.unsecured; + +import java.io.Serializable; + +/** + * The result of some kind of token validation + */ +public class OAuthBearerValidationResult implements Serializable { + private static final long serialVersionUID = 5774669940899777373L; + private final boolean success; + private final String failureDescription; + private final String failureScope; + private final String failureOpenIdConfig; + + /** + * Return an instance indicating success + * + * @return an instance indicating success + */ + public static OAuthBearerValidationResult newSuccess() { + return new OAuthBearerValidationResult(true, null, null, null); + } + + /** + * Return a new validation failure instance + * + * @param failureDescription + * optional description of the failure + * @return a new validation failure instance + */ + public static OAuthBearerValidationResult newFailure(String failureDescription) { + return newFailure(failureDescription, null, null); + } + + /** + * Return a new validation failure instance + * + * @param failureDescription + * optional description of the failure + * @param failureScope + * optional scope to be reported with the failure + * @param failureOpenIdConfig + * optional OpenID Connect configuration to be reported with the + * failure + * @return a new validation failure instance + */ + public static OAuthBearerValidationResult newFailure(String failureDescription, String failureScope, + String failureOpenIdConfig) { + return new OAuthBearerValidationResult(false, failureDescription, failureScope, failureOpenIdConfig); + } + + private OAuthBearerValidationResult(boolean success, String failureDescription, String failureScope, + String failureOpenIdConfig) { + if (success && (failureScope != null || failureOpenIdConfig != null)) + throw new IllegalArgumentException("success was indicated but failure scope/OpenIdConfig were provided"); + this.success = success; + this.failureDescription = failureDescription; + this.failureScope = failureScope; + this.failureOpenIdConfig = failureOpenIdConfig; + } + + /** + * Return true if this instance indicates success, otherwise false + * + * @return true if this instance indicates success, otherwise false + */ + public boolean success() { + return success; + } + + /** + * Return the (potentially null) descriptive message for the failure + * + * @return the (potentially null) descriptive message for the failure + */ + public String failureDescription() { + return failureDescription; + } + + /** + * Return the (potentially null) scope to be reported with the failure + * + * @return the (potentially null) scope to be reported with the failure + */ + public String failureScope() { + return failureScope; + } + + /** + * Return the (potentially null) OpenID Connect configuration to be reported + * with the failure + * + * @return the (potentially null) OpenID Connect configuration to be reported + * with the failure + */ + public String failureOpenIdConfig() { + return failureOpenIdConfig; + } + + /** + * Raise an exception if this instance indicates failure, otherwise do nothing + * + * @throws OAuthBearerIllegalTokenException + * if this instance indicates failure + */ + public void throwExceptionIfFailed() throws OAuthBearerIllegalTokenException { + if (!success()) + throw new OAuthBearerIllegalTokenException(this); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/unsecured/OAuthBearerValidationUtils.java 
b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/unsecured/OAuthBearerValidationUtils.java new file mode 100644 index 0000000..ce1b62b --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/unsecured/OAuthBearerValidationUtils.java @@ -0,0 +1,200 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.security.oauthbearer.internals.unsecured; + +import java.util.List; +import java.util.Objects; +import java.util.Set; + +import org.apache.kafka.common.security.oauthbearer.OAuthBearerToken; + +public class OAuthBearerValidationUtils { + /** + * Validate the given claim for existence and type. It can be required to exist + * in the given claims, and if it exists it must be one of the types indicated + * + * @param jwt + * the mandatory JWT to which the validation will be applied + * @param required + * true if the claim is required to exist + * @param claimName + * the required claim name identifying the claim to be checked + * @param allowedTypes + * one or more of {@code String.class}, {@code Number.class}, and + * {@code List.class} identifying the type(s) that the claim value is + * allowed to be if it exists + * @return the result of the validation + */ + public static OAuthBearerValidationResult validateClaimForExistenceAndType(OAuthBearerUnsecuredJws jwt, + boolean required, String claimName, Class... allowedTypes) { + Object rawClaim = Objects.requireNonNull(jwt).rawClaim(Objects.requireNonNull(claimName)); + if (rawClaim == null) + return required + ? OAuthBearerValidationResult.newFailure(String.format("Required claim missing: %s", claimName)) + : OAuthBearerValidationResult.newSuccess(); + for (Class allowedType : allowedTypes) { + if (allowedType != null && allowedType.isAssignableFrom(rawClaim.getClass())) + return OAuthBearerValidationResult.newSuccess(); + } + return OAuthBearerValidationResult.newFailure(String.format("The %s claim had the incorrect type: %s", + claimName, rawClaim.getClass().getSimpleName())); + } + + /** + * Validate the 'iat' (Issued At) claim. It can be required to exist in the + * given claims, and if it exists it must be a (potentially fractional) number + * of seconds since the epoch defining when the JWT was issued; it is a + * validation error if the Issued At time is after the time at which the check + * is being done (plus any allowable clock skew). 
+ * + * @param jwt + * the mandatory JWT to which the validation will be applied + * @param required + * true if the claim is required to exist + * @param whenCheckTimeMs + * the time relative to which the validation is to occur + * @param allowableClockSkewMs + * non-negative number to take into account some potential clock skew + * @return the result of the validation + * @throws OAuthBearerConfigException + * if the given allowable clock skew is negative + */ + public static OAuthBearerValidationResult validateIssuedAt(OAuthBearerUnsecuredJws jwt, boolean required, + long whenCheckTimeMs, int allowableClockSkewMs) throws OAuthBearerConfigException { + Number value; + try { + value = Objects.requireNonNull(jwt).issuedAt(); + } catch (OAuthBearerIllegalTokenException e) { + return e.reason(); + } + boolean exists = value != null; + if (!exists) + return doesNotExistResult(required, "iat"); + double doubleValue = value.doubleValue(); + return 1000 * doubleValue > whenCheckTimeMs + confirmNonNegative(allowableClockSkewMs) + ? OAuthBearerValidationResult.newFailure(String.format( + "The Issued At value (%f seconds) was after the indicated time (%d ms) plus allowable clock skew (%d ms)", + doubleValue, whenCheckTimeMs, allowableClockSkewMs)) + : OAuthBearerValidationResult.newSuccess(); + } + + /** + * Validate the 'exp' (Expiration Time) claim. It must exist and it must be a + * (potentially fractional) number of seconds defining the point at which the + * JWT expires. It is a validation error if the time at which the check is being + * done (minus any allowable clock skew) is on or after the Expiration Time + * time. + * + * @param jwt + * the mandatory JWT to which the validation will be applied + * @param whenCheckTimeMs + * the time relative to which the validation is to occur + * @param allowableClockSkewMs + * non-negative number to take into account some potential clock skew + * @return the result of the validation + * @throws OAuthBearerConfigException + * if the given allowable clock skew is negative + */ + public static OAuthBearerValidationResult validateExpirationTime(OAuthBearerUnsecuredJws jwt, long whenCheckTimeMs, + int allowableClockSkewMs) throws OAuthBearerConfigException { + Number value; + try { + value = Objects.requireNonNull(jwt).expirationTime(); + } catch (OAuthBearerIllegalTokenException e) { + return e.reason(); + } + boolean exists = value != null; + if (!exists) + return doesNotExistResult(true, "exp"); + double doubleValue = value.doubleValue(); + return whenCheckTimeMs - confirmNonNegative(allowableClockSkewMs) >= 1000 * doubleValue + ? OAuthBearerValidationResult.newFailure(String.format( + "The indicated time (%d ms) minus allowable clock skew (%d ms) was on or after the Expiration Time value (%f seconds)", + whenCheckTimeMs, allowableClockSkewMs, doubleValue)) + : OAuthBearerValidationResult.newSuccess(); + } + + /** + * Validate the 'iat' (Issued At) and 'exp' (Expiration Time) claims for + * internal consistency. The following must be true if both claims exist: + * + *
            +     * exp > iat
            +     * 
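For orientation, a condensed sketch (not part of this patch) of how these static checks are typically chained on a parsed unsecured JWS, mirroring the validator callback handler above; the 3000 ms clock skew value is illustrative.

import org.apache.kafka.common.security.oauthbearer.internals.unsecured.OAuthBearerUnsecuredJws;
import org.apache.kafka.common.security.oauthbearer.internals.unsecured.OAuthBearerValidationUtils;
import org.apache.kafka.common.utils.Time;

public class ValidationUtilsSketch {
    // Applies the issued-at, expiration and time-consistency checks in sequence;
    // the first failing check raises OAuthBearerIllegalTokenException via throwExceptionIfFailed().
    static void validateTimes(OAuthBearerUnsecuredJws jws) {
        long now = Time.SYSTEM.milliseconds();
        OAuthBearerValidationUtils.validateIssuedAt(jws, false, now, 3000).throwExceptionIfFailed();
        OAuthBearerValidationUtils.validateExpirationTime(jws, now, 3000).throwExceptionIfFailed();
        OAuthBearerValidationUtils.validateTimeConsistency(jws).throwExceptionIfFailed();
    }
}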
            + * + * @param jwt + * the mandatory JWT to which the validation will be applied + * @return the result of the validation + */ + public static OAuthBearerValidationResult validateTimeConsistency(OAuthBearerUnsecuredJws jwt) { + Number issuedAt; + Number expirationTime; + try { + issuedAt = Objects.requireNonNull(jwt).issuedAt(); + expirationTime = jwt.expirationTime(); + } catch (OAuthBearerIllegalTokenException e) { + return e.reason(); + } + if (expirationTime != null && issuedAt != null && expirationTime.doubleValue() <= issuedAt.doubleValue()) + return OAuthBearerValidationResult.newFailure( + String.format("The Expiration Time time (%f seconds) was not after the Issued At time (%f seconds)", + expirationTime.doubleValue(), issuedAt.doubleValue())); + return OAuthBearerValidationResult.newSuccess(); + } + + /** + * Validate the given token's scope against the required scope. Every required + * scope element (if any) must exist in the provided token's scope for the + * validation to succeed. + * + * @param token + * the required token for which the scope will to validate + * @param requiredScope + * the optional required scope against which the given token's scope + * will be validated + * @return the result of the validation + */ + public static OAuthBearerValidationResult validateScope(OAuthBearerToken token, List requiredScope) { + final Set tokenScope = token.scope(); + if (requiredScope == null || requiredScope.isEmpty()) + return OAuthBearerValidationResult.newSuccess(); + for (String requiredScopeElement : requiredScope) { + if (!tokenScope.contains(requiredScopeElement)) + return OAuthBearerValidationResult.newFailure(String.format( + "The provided scope (%s) was mising a required scope (%s). All required scope elements: %s", + String.valueOf(tokenScope), requiredScopeElement, requiredScope.toString()), + requiredScope.toString(), null); + } + return OAuthBearerValidationResult.newSuccess(); + } + + private static int confirmNonNegative(int allowableClockSkewMs) throws OAuthBearerConfigException { + if (allowableClockSkewMs < 0) + throw new OAuthBearerConfigException( + String.format("Allowable clock skew must not be negative: %d", allowableClockSkewMs)); + return allowableClockSkewMs; + } + + private static OAuthBearerValidationResult doesNotExistResult(boolean required, String claimName) { + return required ? OAuthBearerValidationResult.newFailure(String.format("Required claim missing: %s", claimName)) + : OAuthBearerValidationResult.newSuccess(); + } + + private OAuthBearerValidationUtils() { + // empty + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/secured/AccessTokenRetriever.java b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/secured/AccessTokenRetriever.java new file mode 100644 index 0000000..e4ae599 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/secured/AccessTokenRetriever.java @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.common.security.oauthbearer.secured; + +import java.io.Closeable; +import java.io.IOException; + +/** + * An AccessTokenRetriever is the internal API by which the login module will + * retrieve an access token for use in authorization by the broker. The implementation may + * involve authentication to a remote system, or it can be as simple as loading the contents + * of a file or configuration setting. + * + * Retrieval is a separate concern from validation, so it isn't necessary for + * the AccessTokenRetriever implementation to validate the integrity of the JWT + * access token. + * + * @see HttpAccessTokenRetriever + * @see FileTokenRetriever + */ + +public interface AccessTokenRetriever extends Initable, Closeable { + + /** + * Retrieves a JWT access token in its serialized three-part form. The implementation + * is free to determine how it should be retrieved but should not perform validation + * on the result. + * + * Note: This is a blocking function and callers should be aware that the + * implementation may be communicating over a network, with the file system, coordinating + * threads, etc. The facility in the {@link javax.security.auth.spi.LoginModule} from + * which this is ultimately called does not provide an asynchronous approach. + * + * @return Non-null JWT access token string + * + * @throws IOException Thrown on errors related to IO during retrieval + */ + + String retrieve() throws IOException; + + /** + * Lifecycle method to perform a clean shutdown of the retriever. This must + * be performed by the caller to ensure the correct state, freeing up and releasing any + * resources performed in {@link #init()}. + * + * @throws IOException Thrown on errors related to IO during closure + */ + + default void close() throws IOException { + // This method left intentionally blank. + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/secured/AccessTokenRetrieverFactory.java b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/secured/AccessTokenRetrieverFactory.java new file mode 100644 index 0000000..e7b3b5c --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/secured/AccessTokenRetrieverFactory.java @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.common.security.oauthbearer.secured; + +import static org.apache.kafka.common.config.SaslConfigs.SASL_LOGIN_CONNECT_TIMEOUT_MS; +import static org.apache.kafka.common.config.SaslConfigs.SASL_LOGIN_READ_TIMEOUT_MS; +import static org.apache.kafka.common.config.SaslConfigs.SASL_LOGIN_RETRY_BACKOFF_MAX_MS; +import static org.apache.kafka.common.config.SaslConfigs.SASL_LOGIN_RETRY_BACKOFF_MS; +import static org.apache.kafka.common.config.SaslConfigs.SASL_OAUTHBEARER_TOKEN_ENDPOINT_URL; +import static org.apache.kafka.common.security.oauthbearer.secured.OAuthBearerLoginCallbackHandler.CLIENT_ID_CONFIG; +import static org.apache.kafka.common.security.oauthbearer.secured.OAuthBearerLoginCallbackHandler.CLIENT_SECRET_CONFIG; +import static org.apache.kafka.common.security.oauthbearer.secured.OAuthBearerLoginCallbackHandler.SCOPE_CONFIG; + +import java.net.URL; +import java.util.Locale; +import java.util.Map; +import javax.net.ssl.SSLSocketFactory; + +public class AccessTokenRetrieverFactory { + + /** + * Create an {@link AccessTokenRetriever} from the given SASL and JAAS configuration. + * + * Note: the returned AccessTokenRetriever is not initialized + * here and must be done by the caller prior to use. + * + * @param configs SASL configuration + * @param jaasConfig JAAS configuration + * + * @return Non-null {@link AccessTokenRetriever} + */ + + public static AccessTokenRetriever create(Map configs, Map jaasConfig) { + return create(configs, null, jaasConfig); + } + + public static AccessTokenRetriever create(Map configs, + String saslMechanism, + Map jaasConfig) { + ConfigurationUtils cu = new ConfigurationUtils(configs, saslMechanism); + URL tokenEndpointUrl = cu.validateUrl(SASL_OAUTHBEARER_TOKEN_ENDPOINT_URL); + + if (tokenEndpointUrl.getProtocol().toLowerCase(Locale.ROOT).equals("file")) { + return new FileTokenRetriever(cu.validateFile(SASL_OAUTHBEARER_TOKEN_ENDPOINT_URL)); + } else { + JaasOptionsUtils jou = new JaasOptionsUtils(jaasConfig); + String clientId = jou.validateString(CLIENT_ID_CONFIG); + String clientSecret = jou.validateString(CLIENT_SECRET_CONFIG); + String scope = jou.validateString(SCOPE_CONFIG, false); + + SSLSocketFactory sslSocketFactory = null; + + if (jou.shouldCreateSSLSocketFactory(tokenEndpointUrl)) + sslSocketFactory = jou.createSSLSocketFactory(); + + return new HttpAccessTokenRetriever(clientId, + clientSecret, + scope, + sslSocketFactory, + tokenEndpointUrl.toString(), + cu.validateLong(SASL_LOGIN_RETRY_BACKOFF_MS), + cu.validateLong(SASL_LOGIN_RETRY_BACKOFF_MAX_MS), + cu.validateInteger(SASL_LOGIN_CONNECT_TIMEOUT_MS, false), + cu.validateInteger(SASL_LOGIN_READ_TIMEOUT_MS, false)); + } + } + +} \ No newline at end of file diff --git a/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/secured/AccessTokenValidator.java b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/secured/AccessTokenValidator.java new file mode 100644 index 0000000..2a8c2b0 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/secured/AccessTokenValidator.java @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.common.security.oauthbearer.secured; + +import org.apache.kafka.common.security.oauthbearer.OAuthBearerToken; + +/** + * An instance of AccessTokenValidator acts as a function object that, given an access + * token in base-64 encoded JWT format, can parse the data, perform validation, and construct an + * {@link OAuthBearerToken} for use by the caller. + * + * The primary reason for this abstraction is that client and broker may have different libraries + * available to them to perform these operations. Additionally, the exact steps for validation may + * differ between implementations. To put this more concretely: the implementation in the Kafka + * client does not have bundled a robust library to perform this logic, and it is not the + * responsibility of the client to perform vigorous validation. However, the Kafka broker ships with + * a richer set of library dependencies that can perform more substantial validation and is also + * expected to perform a trust-but-verify test of the access token's signature. + * + * See: + * + * + * + * @see LoginAccessTokenValidator A basic AccessTokenValidator used by client-side login + * authentication + * @see ValidatorAccessTokenValidator A more robust AccessTokenValidator that is used on the broker + * to validate the token's contents and verify the signature + */ + +public interface AccessTokenValidator { + + /** + * Accepts an OAuth JWT access token in base-64 encoded format, validates, and returns an + * OAuthBearerToken. + * + * @param accessToken Non-null JWT access token + * + * @return {@link OAuthBearerToken} + * + * @throws ValidateException Thrown on errors performing validation of given token + */ + + OAuthBearerToken validate(String accessToken) throws ValidateException; + +} diff --git a/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/secured/AccessTokenValidatorFactory.java b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/secured/AccessTokenValidatorFactory.java new file mode 100644 index 0000000..232ebc1 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/secured/AccessTokenValidatorFactory.java @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.common.security.oauthbearer.secured; + +import static org.apache.kafka.common.config.SaslConfigs.SASL_OAUTHBEARER_CLOCK_SKEW_SECONDS; +import static org.apache.kafka.common.config.SaslConfigs.SASL_OAUTHBEARER_EXPECTED_AUDIENCE; +import static org.apache.kafka.common.config.SaslConfigs.SASL_OAUTHBEARER_EXPECTED_ISSUER; +import static org.apache.kafka.common.config.SaslConfigs.SASL_OAUTHBEARER_SCOPE_CLAIM_NAME; +import static org.apache.kafka.common.config.SaslConfigs.SASL_OAUTHBEARER_SUB_CLAIM_NAME; + +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import org.jose4j.keys.resolvers.VerificationKeyResolver; + +public class AccessTokenValidatorFactory { + + public static AccessTokenValidator create(Map configs) { + return create(configs, (String) null); + } + + public static AccessTokenValidator create(Map configs, String saslMechanism) { + ConfigurationUtils cu = new ConfigurationUtils(configs, saslMechanism); + String scopeClaimName = cu.get(SASL_OAUTHBEARER_SCOPE_CLAIM_NAME); + String subClaimName = cu.get(SASL_OAUTHBEARER_SUB_CLAIM_NAME); + return new LoginAccessTokenValidator(scopeClaimName, subClaimName); + } + + public static AccessTokenValidator create(Map configs, + VerificationKeyResolver verificationKeyResolver) { + return create(configs, null, verificationKeyResolver); + } + + public static AccessTokenValidator create(Map configs, + String saslMechanism, + VerificationKeyResolver verificationKeyResolver) { + ConfigurationUtils cu = new ConfigurationUtils(configs, saslMechanism); + Set expectedAudiences = null; + List l = cu.get(SASL_OAUTHBEARER_EXPECTED_AUDIENCE); + + if (l != null) + expectedAudiences = Collections.unmodifiableSet(new HashSet<>(l)); + + Integer clockSkew = cu.validateInteger(SASL_OAUTHBEARER_CLOCK_SKEW_SECONDS, false); + String expectedIssuer = cu.validateString(SASL_OAUTHBEARER_EXPECTED_ISSUER, false); + String scopeClaimName = cu.validateString(SASL_OAUTHBEARER_SCOPE_CLAIM_NAME); + String subClaimName = cu.validateString(SASL_OAUTHBEARER_SUB_CLAIM_NAME); + + return new ValidatorAccessTokenValidator(clockSkew, + expectedAudiences, + expectedIssuer, + verificationKeyResolver, + scopeClaimName, + subClaimName); + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/secured/BasicOAuthBearerToken.java b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/secured/BasicOAuthBearerToken.java new file mode 100644 index 0000000..8527f80 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/secured/BasicOAuthBearerToken.java @@ -0,0 +1,163 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.common.security.oauthbearer.secured; + +import java.util.Set; +import java.util.StringJoiner; +import org.apache.kafka.common.security.oauthbearer.OAuthBearerToken; + +/** + * An implementation of the {@link OAuthBearerToken} that fairly straightforwardly stores the values + * given to its constructor (except the scope set which is copied to avoid modifications). + * + * Very little validation is applied here with respect to the validity of the given values. All + * validation is assumed to happen by users of this class. + * + * @see RFC 7515: JSON Web Signature (JWS) + */ + +public class BasicOAuthBearerToken implements OAuthBearerToken { + + private final String token; + + private final Set scopes; + + private final Long lifetimeMs; + + private final String principalName; + + private final Long startTimeMs; + + /** + * Creates a new OAuthBearerToken instance around the given values. + * + * @param token Value containing the compact serialization as a base 64 string that + * can be parsed, decoded, and validated as a well-formed JWS. Must be + * non-null, non-blank, and non-whitespace only. + * @param scopes Set of non-null scopes. May contain case-sensitive + * "duplicates". The given set is copied and made unmodifiable so neither + * the caller of this constructor nor any downstream users can modify it. + * @param lifetimeMs The token's lifetime, expressed as the number of milliseconds since the + * epoch. Must be non-negative. + * @param principalName The name of the principal to which this credential applies. Must be + * non-null, non-blank, and non-whitespace only. + * @param startTimeMs The token's start time, expressed as the number of milliseconds since + * the epoch, if available, otherwise null. Must be + * non-negative if a non-null value is provided. + */ + + public BasicOAuthBearerToken(String token, + Set scopes, + long lifetimeMs, + String principalName, + Long startTimeMs) { + this.token = token; + this.scopes = scopes; + this.lifetimeMs = lifetimeMs; + this.principalName = principalName; + this.startTimeMs = startTimeMs; + } + + /** + * The b64token value as defined in + * RFC 6750 Section + * 2.1 + * + * @return b64token value as defined in + * RFC 6750 + * Section 2.1 + */ + + @Override + public String value() { + return token; + } + + /** + * The token's scope of access, as per + * RFC 6749 Section + * 1.4 + * + * @return the token's (always non-null but potentially empty) scope of access, + * as per RFC + * 6749 Section 1.4. Note that all values in the returned set will + * be trimmed of preceding and trailing whitespace, and the result will + * never contain the empty string. + */ + + @Override + public Set scope() { + // Immutability of the set is performed in the constructor/validation utils class, so + // we don't need to repeat it here. + return scopes; + } + + /** + * The token's lifetime, expressed as the number of milliseconds since the + * epoch, as per RFC + * 6749 Section 1.4 + * + * @return the token's lifetime, expressed as the number of milliseconds since + * the epoch, as per + * RFC 6749 + * Section 1.4. + */ + + @Override + public long lifetimeMs() { + return lifetimeMs; + } + + /** + * The name of the principal to which this credential applies + * + * @return the always non-null/non-empty principal name + */ + + @Override + public String principalName() { + return principalName; + } + + /** + * When the credential became valid, in terms of the number of milliseconds + * since the epoch, if known, otherwise null. 
An expiring credential may not + * necessarily indicate when it was created -- just when it expires -- so we + * need to support a null return value here. + * + * @return the time when the credential became valid, in terms of the number of + * milliseconds since the epoch, if known, otherwise null + */ + + @Override + public Long startTimeMs() { + return startTimeMs; + } + + @Override + public String toString() { + return new StringJoiner(", ", BasicOAuthBearerToken.class.getSimpleName() + "[", "]") + .add("token='" + token + "'") + .add("scopes=" + scopes) + .add("lifetimeMs=" + lifetimeMs) + .add("principalName='" + principalName + "'") + .add("startTimeMs=" + startTimeMs) + .toString(); + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/secured/ClaimValidationUtils.java b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/secured/ClaimValidationUtils.java new file mode 100644 index 0000000..bb08ec5 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/secured/ClaimValidationUtils.java @@ -0,0 +1,182 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.common.security.oauthbearer.secured; + +import java.util.Collection; +import java.util.Collections; +import java.util.HashSet; +import java.util.Set; + +/** + * Simple utility class to perform basic cleaning and validation on input values so that they're + * performed consistently throughout the code base. + */ + +public class ClaimValidationUtils { + + /** + * Validates that the scopes are valid, where invalid means any of + * the following: + * + *
            + * <ul>
            + * <li>Collection is null</li>
            + * <li>Collection has duplicates</li>
            + * <li>Any of the elements in the collection are null</li>
            + * <li>Any of the elements in the collection are zero length</li>
            + * <li>Any of the elements in the collection are whitespace only</li>
            + * </ul>
            + * + * @param scopeClaimName Name of the claim used for the scope values + * @param scopes Collection of String scopes + * + * @return Unmodifiable {@link Set} that includes the values of the original set, but with + * each value trimmed + * + * @throws ValidateException Thrown if the value is null, contains duplicates, or + * if any of the values in the set are null, empty, + * or whitespace only + */ + + public static Set validateScopes(String scopeClaimName, Collection scopes) throws ValidateException { + if (scopes == null) + throw new ValidateException(String.format("%s value must be non-null", scopeClaimName)); + + Set copy = new HashSet<>(); + + for (String scope : scopes) { + scope = validateString(scopeClaimName, scope); + + if (copy.contains(scope)) + throw new ValidateException(String.format("%s value must not contain duplicates - %s already present", scopeClaimName, scope)); + + copy.add(scope); + } + + return Collections.unmodifiableSet(copy); + } + + /** + * Validates that the given lifetime is valid, where invalid means any of + * the following: + * + *
            + * <ul>
            + * <li>null</li>
            + * <li>Negative</li>
            + * </ul>
            + * + * @param claimName Name of the claim + * @param claimValue Expiration time (in milliseconds) + * + * @return Input parameter, as provided + * + * @throws ValidateException Thrown if the value is null or negative + */ + + public static long validateExpiration(String claimName, Long claimValue) throws ValidateException { + if (claimValue == null) + throw new ValidateException(String.format("%s value must be non-null", claimName)); + + if (claimValue < 0) + throw new ValidateException(String.format("%s value must be non-negative; value given was \"%s\"", claimName, claimValue)); + + return claimValue; + } + + /** + * Validates that the given claim value is valid, where invalid means any of + * the following: + * + *
            + * <ul>
            + * <li>null</li>
            + * <li>Zero length</li>
            + * <li>Whitespace only</li>
            + * </ul>
            + * + * @param claimName Name of the claim + * @param claimValue Name of the subject + * + * @return Trimmed version of the claimValue parameter + * + * @throws ValidateException Thrown if the value is null, empty, or whitespace only + */ + + public static String validateSubject(String claimName, String claimValue) throws ValidateException { + return validateString(claimName, claimValue); + } + + /** + * Validates that the given issued at claim name is valid, where invalid means any of + * the following: + * + *
            + * <ul>
            + * <li>Negative</li>
            + * </ul>
            + * + * @param claimName Name of the claim + * @param claimValue Start time (in milliseconds) or null if not used + * + * @return Input parameter, as provided + * + * @throws ValidateException Thrown if the value is negative + */ + + public static Long validateIssuedAt(String claimName, Long claimValue) throws ValidateException { + if (claimValue != null && claimValue < 0) + throw new ValidateException(String.format("%s value must be null or non-negative; value given was \"%s\"", claimName, claimValue)); + + return claimValue; + } + + /** + * Validates that the given claim name override is valid, where invalid means + * any of the following: + * + *
            + * <ul>
            + * <li>null</li>
            + * <li>Zero length</li>
            + * <li>Whitespace only</li>
            + * </ul>
            + * + * @param name "Standard" name of the claim, e.g. sub + * @param value "Override" name of the claim, e.g. email + * + * @return Trimmed version of the value parameter + * + * @throws ValidateException Thrown if the value is null, empty, or whitespace only + */ + + public static String validateClaimNameOverride(String name, String value) throws ValidateException { + return validateString(name, value); + } + + private static String validateString(String name, String value) throws ValidateException { + if (value == null) + throw new ValidateException(String.format("%s value must be non-null", name)); + + if (value.isEmpty()) + throw new ValidateException(String.format("%s value must be non-empty", name)); + + value = value.trim(); + + if (value.isEmpty()) + throw new ValidateException(String.format("%s value must not contain only whitespace", name)); + + return value; + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/secured/CloseableVerificationKeyResolver.java b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/secured/CloseableVerificationKeyResolver.java new file mode 100644 index 0000000..b74aaa1 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/secured/CloseableVerificationKeyResolver.java @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.common.security.oauthbearer.secured; + +import java.io.Closeable; +import java.io.IOException; +import org.jose4j.keys.resolvers.VerificationKeyResolver; + +/** + * The {@link OAuthBearerValidatorCallbackHandler} uses a {@link VerificationKeyResolver} as + * part of its validation of the incoming JWT. Some of the VerificationKeyResolver + * implementations use resources like threads, connections, etc. that should be properly closed + * when no longer needed. Since the VerificationKeyResolver interface itself doesn't + * define a close method, we provide a means to do that here. + * + * @see OAuthBearerValidatorCallbackHandler + * @see VerificationKeyResolver + * @see Closeable + */ + +public interface CloseableVerificationKeyResolver extends Initable, Closeable, VerificationKeyResolver { + + /** + * Lifecycle method to perform a clean shutdown of the {@link VerificationKeyResolver}. + * This must be performed by the caller to ensure the correct state, freeing up + * and releasing any resources performed in {@link #init()}. + * + * @throws IOException Thrown on errors related to IO during closure + */ + + default void close() throws IOException { + // This method left intentionally blank. 
+ } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/secured/ConfigurationUtils.java b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/secured/ConfigurationUtils.java new file mode 100644 index 0000000..f17295d --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/secured/ConfigurationUtils.java @@ -0,0 +1,221 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.common.security.oauthbearer.secured; + +import java.io.File; +import java.net.MalformedURLException; +import java.net.URISyntaxException; +import java.net.URL; +import java.nio.file.Path; +import java.util.Locale; +import java.util.Map; +import org.apache.kafka.common.config.ConfigException; +import org.apache.kafka.common.network.ListenerName; + +/** + * ConfigurationUtils is a utility class to perform basic configuration-related + * logic and is separated out here for easier, more direct testing. + */ + +public class ConfigurationUtils { + + private final Map configs; + + private final String prefix; + + public ConfigurationUtils(Map configs) { + this(configs, null); + } + + public ConfigurationUtils(Map configs, String saslMechanism) { + this.configs = configs; + + if (saslMechanism != null && !saslMechanism.trim().isEmpty()) + this.prefix = ListenerName.saslMechanismPrefix(saslMechanism.trim()); + else + this.prefix = null; + } + + /** + * Validates that, if a value is supplied, is a file that: + * + *
+ * <ul>
+ *   <li>exists</li>
+ *   <li>has read permission</li>
+ *   <li>points to a file</li>
+ * </ul>
          • + * + * If the value is null or an empty string, it is assumed to be an "empty" value and thus. + * ignored. Any whitespace is trimmed off of the beginning and end. + */ + + public Path validateFile(String name) { + URL url = validateUrl(name); + File file; + + try { + file = new File(url.toURI().getRawPath()).getAbsoluteFile(); + } catch (URISyntaxException e) { + throw new ConfigException(name, url.toString(), String.format("The OAuth configuration option %s contains a URL (%s) that is malformed: %s", name, url, e.getMessage())); + } + + if (!file.exists()) + throw new ConfigException(name, file, String.format("The OAuth configuration option %s contains a file (%s) that doesn't exist", name, file)); + + if (!file.canRead()) + throw new ConfigException(name, file, String.format("The OAuth configuration option %s contains a file (%s) that doesn't have read permission", name, file)); + + if (file.isDirectory()) + throw new ConfigException(name, file, String.format("The OAuth configuration option %s references a directory (%s), not a file", name, file)); + + return file.toPath(); + } + + /** + * Validates that, if a value is supplied, is a value that: + * + *
+ * <ul>
+ *   <li>is an Integer</li>
+ *   <li>has a value that is not less than the provided minimum value</li>
+ * </ul>
          • + * + * If the value is null or an empty string, it is assumed to be an "empty" value and thus + * ignored. Any whitespace is trimmed off of the beginning and end. + */ + + public Integer validateInteger(String name, boolean isRequired) { + Integer value = get(name); + + if (value == null) { + if (isRequired) + throw new ConfigException(name, null, String.format("The OAuth configuration option %s must be non-null", name)); + else + return null; + } + + return value; + } + + /** + * Validates that, if a value is supplied, is a value that: + * + *
+ * <ul>
+ *   <li>is a Long</li>
+ *   <li>has a value that is not less than the provided minimum value</li>
+ * </ul>
          • + * + * If the value is null or an empty string, it is assumed to be an "empty" value and thus + * ignored. Any whitespace is trimmed off of the beginning and end. + */ + + public Long validateLong(String name) { + return validateLong(name, true); + } + + public Long validateLong(String name, boolean isRequired) { + return validateLong(name, isRequired, null); + } + + public Long validateLong(String name, boolean isRequired, Long min) { + Long value = get(name); + + if (value == null) { + if (isRequired) + throw new ConfigException(name, null, String.format("The OAuth configuration option %s must be non-null", name)); + else + return null; + } + + if (min != null && value < min) + throw new ConfigException(name, value, String.format("The OAuth configuration option %s value must be at least %s", name, min)); + + return value; + } + + /** + * Validates that the configured URL that: + * + *
+ * <ul>
+ *   <li>is well-formed</li>
+ *   <li>contains a scheme</li>
+ *   <li>uses either HTTP, HTTPS, or file protocols</li>
+ * </ul>
          • + * + * No effort is made to connect to the URL in the validation step. + */ + + public URL validateUrl(String name) { + String value = validateString(name); + URL url; + + try { + url = new URL(value); + } catch (MalformedURLException e) { + throw new ConfigException(name, value, String.format("The OAuth configuration option %s contains a URL (%s) that is malformed: %s", name, value, e.getMessage())); + } + + String protocol = url.getProtocol(); + + if (protocol == null || protocol.trim().isEmpty()) + throw new ConfigException(name, value, String.format("The OAuth configuration option %s contains a URL (%s) that is missing the protocol", name, value)); + + protocol = protocol.toLowerCase(Locale.ROOT); + + if (!(protocol.equals("http") || protocol.equals("https") || protocol.equals("file"))) + throw new ConfigException(name, value, String.format("The OAuth configuration option %s contains a URL (%s) that contains an invalid protocol (%s); only \"http\", \"https\", and \"file\" protocol are supported", name, value, protocol)); + + return url; + } + + public String validateString(String name) throws ValidateException { + return validateString(name, true); + } + + public String validateString(String name, boolean isRequired) throws ValidateException { + String value = get(name); + + if (value == null) { + if (isRequired) + throw new ConfigException(String.format("The OAuth configuration option %s value must be non-null", name)); + else + return null; + } + + value = value.trim(); + + if (value.isEmpty()) { + if (isRequired) + throw new ConfigException(String.format("The OAuth configuration option %s value must not contain only whitespace", name)); + else + return null; + } + + return value; + } + + @SuppressWarnings("unchecked") + public T get(String name) { + T value = (T) configs.get(prefix + name); + + if (value != null) + return value; + + return (T) configs.get(name); + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/secured/FileTokenRetriever.java b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/secured/FileTokenRetriever.java new file mode 100644 index 0000000..3ffa4c8 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/secured/FileTokenRetriever.java @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.common.security.oauthbearer.secured; + +import java.io.IOException; +import java.nio.file.Path; +import org.apache.kafka.common.utils.Utils; + +/** + * FileTokenRetriever is an {@link AccessTokenRetriever} that will load the contents, + * interpreting them as a JWT access key in the serialized form. 
+ * + * @see AccessTokenRetriever + */ + +public class FileTokenRetriever implements AccessTokenRetriever { + + private final Path accessTokenFile; + + private String accessToken; + + public FileTokenRetriever(Path accessTokenFile) { + this.accessTokenFile = accessTokenFile; + } + + @Override + public void init() throws IOException { + this.accessToken = Utils.readFileAsString(accessTokenFile.toFile().getPath()); + } + + @Override + public String retrieve() throws IOException { + if (accessToken == null) + throw new IllegalStateException("Access token is null; please call init() first"); + + return accessToken; + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/secured/HttpAccessTokenRetriever.java b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/secured/HttpAccessTokenRetriever.java new file mode 100644 index 0000000..b52952a --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/secured/HttpAccessTokenRetriever.java @@ -0,0 +1,348 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.common.security.oauthbearer.secured; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.io.UnsupportedEncodingException; +import java.net.HttpURLConnection; +import java.net.URL; +import java.net.URLEncoder; +import java.nio.charset.StandardCharsets; +import java.util.Base64; +import java.util.Collections; +import java.util.HashSet; +import java.util.Map; +import java.util.Objects; +import java.util.Set; +import java.util.concurrent.ExecutionException; +import javax.net.ssl.HttpsURLConnection; +import javax.net.ssl.SSLSocketFactory; +import org.apache.kafka.common.KafkaException; +import org.apache.kafka.common.config.SaslConfigs; +import org.apache.kafka.common.utils.Utils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * HttpAccessTokenRetriever is an {@link AccessTokenRetriever} that will + * communicate with an OAuth/OIDC provider directly via HTTP to post client credentials + * ({@link OAuthBearerLoginCallbackHandler#CLIENT_ID_CONFIG}/{@link OAuthBearerLoginCallbackHandler#CLIENT_SECRET_CONFIG}) + * to a publicized token endpoint URL + * ({@link SaslConfigs#SASL_OAUTHBEARER_TOKEN_ENDPOINT_URL}). 
+ * + * @see AccessTokenRetriever + * @see OAuthBearerLoginCallbackHandler#CLIENT_ID_CONFIG + * @see OAuthBearerLoginCallbackHandler#CLIENT_SECRET_CONFIG + * @see OAuthBearerLoginCallbackHandler#SCOPE_CONFIG + * @see SaslConfigs#SASL_OAUTHBEARER_TOKEN_ENDPOINT_URL + */ + +public class HttpAccessTokenRetriever implements AccessTokenRetriever { + + private static final Logger log = LoggerFactory.getLogger(HttpAccessTokenRetriever.class); + + private static final Set UNRETRYABLE_HTTP_CODES; + + private static final int MAX_RESPONSE_BODY_LENGTH = 1000; + + public static final String AUTHORIZATION_HEADER = "Authorization"; + + static { + // This does not have to be an exhaustive list. There are other HTTP codes that + // are defined in different RFCs (e.g. https://datatracker.ietf.org/doc/html/rfc6585) + // that we won't worry about yet. The worst case if a status code is missing from + // this set is that the request will be retried. + UNRETRYABLE_HTTP_CODES = new HashSet<>(); + UNRETRYABLE_HTTP_CODES.add(HttpURLConnection.HTTP_BAD_REQUEST); + UNRETRYABLE_HTTP_CODES.add(HttpURLConnection.HTTP_UNAUTHORIZED); + UNRETRYABLE_HTTP_CODES.add(HttpURLConnection.HTTP_PAYMENT_REQUIRED); + UNRETRYABLE_HTTP_CODES.add(HttpURLConnection.HTTP_FORBIDDEN); + UNRETRYABLE_HTTP_CODES.add(HttpURLConnection.HTTP_NOT_FOUND); + UNRETRYABLE_HTTP_CODES.add(HttpURLConnection.HTTP_BAD_METHOD); + UNRETRYABLE_HTTP_CODES.add(HttpURLConnection.HTTP_NOT_ACCEPTABLE); + UNRETRYABLE_HTTP_CODES.add(HttpURLConnection.HTTP_PROXY_AUTH); + UNRETRYABLE_HTTP_CODES.add(HttpURLConnection.HTTP_CONFLICT); + UNRETRYABLE_HTTP_CODES.add(HttpURLConnection.HTTP_GONE); + UNRETRYABLE_HTTP_CODES.add(HttpURLConnection.HTTP_LENGTH_REQUIRED); + UNRETRYABLE_HTTP_CODES.add(HttpURLConnection.HTTP_PRECON_FAILED); + UNRETRYABLE_HTTP_CODES.add(HttpURLConnection.HTTP_ENTITY_TOO_LARGE); + UNRETRYABLE_HTTP_CODES.add(HttpURLConnection.HTTP_REQ_TOO_LONG); + UNRETRYABLE_HTTP_CODES.add(HttpURLConnection.HTTP_UNSUPPORTED_TYPE); + UNRETRYABLE_HTTP_CODES.add(HttpURLConnection.HTTP_NOT_IMPLEMENTED); + UNRETRYABLE_HTTP_CODES.add(HttpURLConnection.HTTP_VERSION); + } + + private final String clientId; + + private final String clientSecret; + + private final String scope; + + private final SSLSocketFactory sslSocketFactory; + + private final String tokenEndpointUrl; + + private final long loginRetryBackoffMs; + + private final long loginRetryBackoffMaxMs; + + private final Integer loginConnectTimeoutMs; + + private final Integer loginReadTimeoutMs; + + public HttpAccessTokenRetriever(String clientId, + String clientSecret, + String scope, + SSLSocketFactory sslSocketFactory, + String tokenEndpointUrl, + long loginRetryBackoffMs, + long loginRetryBackoffMaxMs, + Integer loginConnectTimeoutMs, + Integer loginReadTimeoutMs) { + this.clientId = Objects.requireNonNull(clientId); + this.clientSecret = Objects.requireNonNull(clientSecret); + this.scope = scope; + this.sslSocketFactory = sslSocketFactory; + this.tokenEndpointUrl = Objects.requireNonNull(tokenEndpointUrl); + this.loginRetryBackoffMs = loginRetryBackoffMs; + this.loginRetryBackoffMaxMs = loginRetryBackoffMaxMs; + this.loginConnectTimeoutMs = loginConnectTimeoutMs; + this.loginReadTimeoutMs = loginReadTimeoutMs; + } + + /** + * Retrieves a JWT access token in its serialized three-part form. The implementation + * is free to determine how it should be retrieved but should not perform validation + * on the result. 
+ * + * Note: This is a blocking function and callers should be aware that the + * implementation communicates over a network. The facility in the + * {@link javax.security.auth.spi.LoginModule} from which this is ultimately called + * does not provide an asynchronous approach. + * + * @return Non-null JWT access token string + * + * @throws IOException Thrown on errors related to IO during retrieval + */ + + @Override + public String retrieve() throws IOException { + String authorizationHeader = formatAuthorizationHeader(clientId, clientSecret); + String requestBody = formatRequestBody(scope); + Retry retry = new Retry<>(loginRetryBackoffMs, loginRetryBackoffMaxMs); + Map headers = Collections.singletonMap(AUTHORIZATION_HEADER, authorizationHeader); + + String responseBody; + + try { + responseBody = retry.execute(() -> { + HttpURLConnection con = null; + + try { + con = (HttpURLConnection) new URL(tokenEndpointUrl).openConnection(); + + if (sslSocketFactory != null && con instanceof HttpsURLConnection) + ((HttpsURLConnection) con).setSSLSocketFactory(sslSocketFactory); + + return post(con, headers, requestBody, loginConnectTimeoutMs, loginReadTimeoutMs); + } catch (IOException e) { + throw new ExecutionException(e); + } finally { + if (con != null) + con.disconnect(); + } + }); + } catch (ExecutionException e) { + if (e.getCause() instanceof IOException) + throw (IOException) e.getCause(); + else + throw new KafkaException(e.getCause()); + } + + return parseAccessToken(responseBody); + } + + public static String post(HttpURLConnection con, + Map headers, + String requestBody, + Integer connectTimeoutMs, + Integer readTimeoutMs) + throws IOException, UnretryableException { + handleInput(con, headers, requestBody, connectTimeoutMs, readTimeoutMs); + return handleOutput(con); + } + + private static void handleInput(HttpURLConnection con, + Map headers, + String requestBody, + Integer connectTimeoutMs, + Integer readTimeoutMs) + throws IOException, UnretryableException { + log.debug("handleInput - starting post for {}", con.getURL()); + con.setRequestMethod("POST"); + con.setRequestProperty("Accept", "application/json"); + + if (headers != null) { + for (Map.Entry header : headers.entrySet()) + con.setRequestProperty(header.getKey(), header.getValue()); + } + + con.setRequestProperty("Cache-Control", "no-cache"); + + if (requestBody != null) { + con.setRequestProperty("Content-Length", String.valueOf(requestBody.length())); + con.setDoOutput(true); + } + + con.setUseCaches(false); + + if (connectTimeoutMs != null) + con.setConnectTimeout(connectTimeoutMs); + + if (readTimeoutMs != null) + con.setReadTimeout(readTimeoutMs); + + log.debug("handleInput - preparing to connect to {}", con.getURL()); + con.connect(); + + if (requestBody != null) { + try (OutputStream os = con.getOutputStream()) { + ByteArrayInputStream is = new ByteArrayInputStream(requestBody.getBytes(StandardCharsets.UTF_8)); + log.debug("handleInput - preparing to write request body to {}", con.getURL()); + copy(is, os); + } + } + } + + static String handleOutput(final HttpURLConnection con) throws IOException { + int responseCode = con.getResponseCode(); + log.debug("handleOutput - responseCode: {}", responseCode); + + String responseBody = null; + + try (InputStream is = con.getInputStream()) { + ByteArrayOutputStream os = new ByteArrayOutputStream(); + log.debug("handleOutput - preparing to read response body from {}", con.getURL()); + copy(is, os); + responseBody = os.toString(StandardCharsets.UTF_8.name()); + } catch 
(Exception e) { + log.warn("handleOutput - error retrieving data", e); + } + + if (responseCode == HttpURLConnection.HTTP_OK || responseCode == HttpURLConnection.HTTP_CREATED) { + log.debug("handleOutput - responseCode: {}, response: {}", responseCode, responseBody); + + if (responseBody == null || responseBody.isEmpty()) + throw new IOException(String.format("The token endpoint response was unexpectedly empty despite response code %s from %s", responseCode, con.getURL())); + + return responseBody; + } else { + log.warn("handleOutput - error response code: {}, error response body: {}", responseCode, responseBody); + + if (UNRETRYABLE_HTTP_CODES.contains(responseCode)) { + // We know that this is a non-transient error, so let's not keep retrying the + // request unnecessarily. + throw new UnretryableException(new IOException(String.format("The response code %s was encountered reading the token endpoint response; will not attempt further retries", responseCode))); + } else { + // We don't know if this is a transient (retryable) error or not, so let's assume + // it is. + throw new IOException(String.format("The unexpected response code %s was encountered reading the token endpoint response", responseCode)); + } + } + } + + static void copy(InputStream is, OutputStream os) throws IOException { + byte[] buf = new byte[4096]; + int b; + + while ((b = is.read(buf)) != -1) + os.write(buf, 0, b); + } + + static String parseAccessToken(String responseBody) throws IOException { + log.debug("parseAccessToken - responseBody: {}", responseBody); + ObjectMapper mapper = new ObjectMapper(); + JsonNode rootNode = mapper.readTree(responseBody); + JsonNode accessTokenNode = rootNode.at("/access_token"); + + if (accessTokenNode == null) { + // Only grab the first N characters so that if the response body is huge, we don't + // blow up. + String snippet = responseBody; + + if (snippet.length() > MAX_RESPONSE_BODY_LENGTH) { + int actualLength = responseBody.length(); + String s = responseBody.substring(0, MAX_RESPONSE_BODY_LENGTH); + snippet = String.format("%s (trimmed to first %s characters out of %s total)", s, MAX_RESPONSE_BODY_LENGTH, actualLength); + } + + throw new IOException(String.format("The token endpoint response did not contain an access_token value. Response: (%s)", snippet)); + } + + return sanitizeString("the token endpoint response's access_token JSON attribute", accessTokenNode.textValue()); + } + + static String formatAuthorizationHeader(String clientId, String clientSecret) { + clientId = sanitizeString("the token endpoint request client ID parameter", clientId); + clientSecret = sanitizeString("the token endpoint request client secret parameter", clientSecret); + + String s = String.format("%s:%s", clientId, clientSecret); + String encoded = Base64.getUrlEncoder().encodeToString(Utils.utf8(s)); + return String.format("Basic %s", encoded); + } + + static String formatRequestBody(String scope) throws IOException { + try { + StringBuilder requestParameters = new StringBuilder(); + requestParameters.append("grant_type=client_credentials"); + + if (scope != null && !scope.trim().isEmpty()) { + scope = scope.trim(); + String encodedScope = URLEncoder.encode(scope, StandardCharsets.UTF_8.name()); + requestParameters.append("&scope=").append(encodedScope); + } + + return requestParameters.toString(); + } catch (UnsupportedEncodingException e) { + // The world has gone crazy! 
+ throw new IOException(String.format("Encoding %s not supported", StandardCharsets.UTF_8.name())); + } + } + + private static String sanitizeString(String name, String value) { + if (value == null) + throw new IllegalArgumentException(String.format("The value for %s must be non-null", name)); + + if (value.isEmpty()) + throw new IllegalArgumentException(String.format("The value for %s must be non-empty", name)); + + value = value.trim(); + + if (value.isEmpty()) + throw new IllegalArgumentException(String.format("The value for %s must not contain only whitespace", name)); + + return value; + } + +} \ No newline at end of file diff --git a/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/secured/Initable.java b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/secured/Initable.java new file mode 100644 index 0000000..bf4115e --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/secured/Initable.java @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.common.security.oauthbearer.secured; + +import java.io.IOException; + +public interface Initable { + + /** + * Lifecycle method to perform any one-time initialization of the retriever. This must + * be performed by the caller to ensure the correct state before methods are invoked. + * + * @throws IOException Thrown on errors related to IO during initialization + */ + + default void init() throws IOException { + // This method left intentionally blank. + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/secured/JaasOptionsUtils.java b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/secured/JaasOptionsUtils.java new file mode 100644 index 0000000..e728881 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/secured/JaasOptionsUtils.java @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.common.security.oauthbearer.secured; + +import java.net.URL; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import javax.net.ssl.SSLSocketFactory; +import javax.security.auth.login.AppConfigurationEntry; +import org.apache.kafka.common.config.AbstractConfig; +import org.apache.kafka.common.config.ConfigDef; +import org.apache.kafka.common.config.ConfigException; +import org.apache.kafka.common.network.Mode; +import org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginModule; +import org.apache.kafka.common.security.ssl.DefaultSslEngineFactory; +import org.apache.kafka.common.security.ssl.SslFactory; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * JaasOptionsUtils is a utility class to perform logic for the JAAS options and + * is separated out here for easier, more direct testing. + */ + +public class JaasOptionsUtils { + + private static final Logger log = LoggerFactory.getLogger(JaasOptionsUtils.class); + + private final Map options; + + public JaasOptionsUtils(Map options) { + this.options = options; + } + + public static Map getOptions(String saslMechanism, List jaasConfigEntries) { + if (!OAuthBearerLoginModule.OAUTHBEARER_MECHANISM.equals(saslMechanism)) + throw new IllegalArgumentException(String.format("Unexpected SASL mechanism: %s", saslMechanism)); + + if (Objects.requireNonNull(jaasConfigEntries).size() != 1 || jaasConfigEntries.get(0) == null) + throw new IllegalArgumentException(String.format("Must supply exactly 1 non-null JAAS mechanism configuration (size was %d)", jaasConfigEntries.size())); + + return Collections.unmodifiableMap(jaasConfigEntries.get(0).getOptions()); + } + + public boolean shouldCreateSSLSocketFactory(URL url) { + return url.getProtocol().equalsIgnoreCase("https"); + } + + public Map getSslClientConfig() { + ConfigDef sslConfigDef = new ConfigDef(); + sslConfigDef.withClientSslSupport(); + AbstractConfig sslClientConfig = new AbstractConfig(sslConfigDef, options); + return sslClientConfig.values(); + } + + public SSLSocketFactory createSSLSocketFactory() { + Map sslClientConfig = getSslClientConfig(); + SslFactory sslFactory = new SslFactory(Mode.CLIENT); + sslFactory.configure(sslClientConfig); + SSLSocketFactory socketFactory = ((DefaultSslEngineFactory) sslFactory.sslEngineFactory()).sslContext().getSocketFactory(); + log.debug("Created SSLSocketFactory: {}", sslClientConfig); + return socketFactory; + } + + public String validateString(String name) throws ValidateException { + return validateString(name, true); + } + + public String validateString(String name, boolean isRequired) throws ValidateException { + String value = (String) options.get(name); + + if (value == null) { + if (isRequired) + throw new ConfigException(String.format("The OAuth configuration option %s value must be non-null", name)); + else + return null; + } + + value = value.trim(); + + if (value.isEmpty()) { + if (isRequired) + throw new ConfigException(String.format("The OAuth configuration option %s value must not contain only whitespace", name)); + else + return null; + } + + return value; + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/secured/JwksFileVerificationKeyResolver.java b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/secured/JwksFileVerificationKeyResolver.java new file mode 100644 index 0000000..19ed749 --- /dev/null +++ 
b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/secured/JwksFileVerificationKeyResolver.java @@ -0,0 +1,117 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.common.security.oauthbearer.secured; + +import java.io.IOException; +import java.nio.file.Path; +import java.security.Key; +import java.util.List; +import org.apache.kafka.common.utils.Utils; +import org.jose4j.jwk.JsonWebKeySet; +import org.jose4j.jws.JsonWebSignature; +import org.jose4j.jwx.JsonWebStructure; +import org.jose4j.keys.resolvers.JwksVerificationKeyResolver; +import org.jose4j.keys.resolvers.VerificationKeyResolver; +import org.jose4j.lang.JoseException; +import org.jose4j.lang.UnresolvableKeyException; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * JwksFileVerificationKeyResolver is a {@link VerificationKeyResolver} implementation + * that will load the JWKS from the given file system directory. + * + * A JWKS (JSON Web Key Set) + * is a JSON document provided by the OAuth/OIDC provider that lists the keys used to sign the JWTs + * it issues. + * + * Here is a sample JWKS JSON document: + * + *
            + * {
            + *   "keys": [
            + *     {
            + *       "kty": "RSA",
            + *       "alg": "RS256",
            + *       "kid": "abc123",
            + *       "use": "sig",
            + *       "e": "AQAB",
            + *       "n": "..."
            + *     },
            + *     {
            + *       "kty": "RSA",
            + *       "alg": "RS256",
            + *       "kid": "def456",
            + *       "use": "sig",
            + *       "e": "AQAB",
            + *       "n": "..."
            + *     }
            + *   ]
            + * }
            + * 
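+ * <p>
+ * As a rough illustration, a resolver for a JWKS document like the one above can be created
+ * with the same jose4j calls this class uses in {@link #init()} (json is assumed to hold the
+ * JWKS document read from the file):
+ * <pre>
+ *   JsonWebKeySet jwks = new JsonWebKeySet(json);
+ *   VerificationKeyResolver resolver = new JwksVerificationKeyResolver(jwks.getJsonWebKeys());
+ * </pre>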
            + * + * Without going into too much detail, the array of keys enumerates the key data that the provider + * is using to sign the JWT. The key ID (kid) is referenced by the JWT's header in + * order to match up the JWT's signing key with the key in the JWKS. During the validation step of + * the broker, the jose4j OAuth library will use the contents of the appropriate key in the JWKS + * to validate the signature. + * + * Given that the JWKS is referenced by the JWT, the JWKS must be made available by the + * OAuth/OIDC provider so that a JWT can be validated. + * + * @see org.apache.kafka.common.config.SaslConfigs#SASL_OAUTHBEARER_TOKEN_ENDPOINT_URL + * @see VerificationKeyResolver + */ + +public class JwksFileVerificationKeyResolver implements CloseableVerificationKeyResolver { + + private static final Logger log = LoggerFactory.getLogger(JwksFileVerificationKeyResolver.class); + + private final Path jwksFile; + + private VerificationKeyResolver delegate; + + public JwksFileVerificationKeyResolver(Path jwksFile) { + this.jwksFile = jwksFile; + } + + @Override + public void init() throws IOException { + log.debug("Starting creation of new VerificationKeyResolver from {}", jwksFile); + String json = Utils.readFileAsString(jwksFile.toFile().getPath()); + + JsonWebKeySet jwks; + + try { + jwks = new JsonWebKeySet(json); + } catch (JoseException e) { + throw new IOException(e); + } + + delegate = new JwksVerificationKeyResolver(jwks.getJsonWebKeys()); + } + + @Override + public Key resolveKey(JsonWebSignature jws, List nestingContext) throws UnresolvableKeyException { + if (delegate == null) + throw new UnresolvableKeyException("VerificationKeyResolver delegate is null; please call init() first"); + + return delegate.resolveKey(jws, nestingContext); + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/secured/LoginAccessTokenValidator.java b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/secured/LoginAccessTokenValidator.java new file mode 100644 index 0000000..b67ffb2 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/secured/LoginAccessTokenValidator.java @@ -0,0 +1,132 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.common.security.oauthbearer.secured; + +import static org.apache.kafka.common.config.SaslConfigs.DEFAULT_SASL_OAUTHBEARER_SCOPE_CLAIM_NAME; +import static org.apache.kafka.common.config.SaslConfigs.DEFAULT_SASL_OAUTHBEARER_SUB_CLAIM_NAME; + +import java.util.Collection; +import java.util.Collections; +import java.util.Map; +import java.util.Set; +import org.apache.kafka.common.security.oauthbearer.OAuthBearerToken; +import org.apache.kafka.common.security.oauthbearer.internals.unsecured.OAuthBearerIllegalTokenException; +import org.apache.kafka.common.security.oauthbearer.internals.unsecured.OAuthBearerUnsecuredJws; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * LoginAccessTokenValidator is an implementation of {@link AccessTokenValidator} that is used + * by the client to perform some rudimentary validation of the JWT access token that is received + * as part of the response from posting the client credentials to the OAuth/OIDC provider's + * token endpoint. + * + * The validation steps performed are: + * + *
              + *
+ * <ol>
+ *   <li>Basic structural validation of the b64token value as defined in RFC 6750 Section 2.1</li>
+ *   <li>Basic conversion of the token into an in-memory map</li>
+ *   <li>Presence of scope, exp, subject, and iat claims</li>
+ * </ol>
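+ * <p>
+ * A minimal usage sketch (assuming the default "scope" and "sub" claim names and a variable
+ * accessToken holding the serialized JWT returned by the token endpoint):
+ * <pre>
+ *   AccessTokenValidator validator = new LoginAccessTokenValidator("scope", "sub");
+ *   OAuthBearerToken token = validator.validate(accessToken);
+ *   long lifetimeMs = token.lifetimeMs();
+ * </pre>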
            + */ + +public class LoginAccessTokenValidator implements AccessTokenValidator { + + private static final Logger log = LoggerFactory.getLogger(LoginAccessTokenValidator.class); + + public static final String EXPIRATION_CLAIM_NAME = "exp"; + + public static final String ISSUED_AT_CLAIM_NAME = "iat"; + + private final String scopeClaimName; + + private final String subClaimName; + + /** + * Creates a new LoginAccessTokenValidator that will be used by the client for lightweight + * validation of the JWT. + * + * @param scopeClaimName Name of the scope claim to use; must be non-null + * @param subClaimName Name of the subject claim to use; must be non-null + */ + + public LoginAccessTokenValidator(String scopeClaimName, String subClaimName) { + this.scopeClaimName = ClaimValidationUtils.validateClaimNameOverride(DEFAULT_SASL_OAUTHBEARER_SCOPE_CLAIM_NAME, scopeClaimName); + this.subClaimName = ClaimValidationUtils.validateClaimNameOverride(DEFAULT_SASL_OAUTHBEARER_SUB_CLAIM_NAME, subClaimName); + } + + /** + * Accepts an OAuth JWT access token in base-64 encoded format, validates, and returns an + * OAuthBearerToken. + * + * @param accessToken Non-null JWT access token + * @return {@link OAuthBearerToken} + * @throws ValidateException Thrown on errors performing validation of given token + */ + + @SuppressWarnings("unchecked") + public OAuthBearerToken validate(String accessToken) throws ValidateException { + SerializedJwt serializedJwt = new SerializedJwt(accessToken); + Map payload; + + try { + payload = OAuthBearerUnsecuredJws.toMap(serializedJwt.getPayload()); + } catch (OAuthBearerIllegalTokenException e) { + throw new ValidateException(String.format("Could not validate the access token: %s", e.getMessage()), e); + } + + Object scopeRaw = getClaim(payload, scopeClaimName); + Collection scopeRawCollection; + + if (scopeRaw instanceof String) + scopeRawCollection = Collections.singletonList((String) scopeRaw); + else if (scopeRaw instanceof Collection) + scopeRawCollection = (Collection) scopeRaw; + else + scopeRawCollection = Collections.emptySet(); + + Number expirationRaw = (Number) getClaim(payload, EXPIRATION_CLAIM_NAME); + String subRaw = (String) getClaim(payload, subClaimName); + Number issuedAtRaw = (Number) getClaim(payload, ISSUED_AT_CLAIM_NAME); + + Set scopes = ClaimValidationUtils.validateScopes(scopeClaimName, scopeRawCollection); + long expiration = ClaimValidationUtils.validateExpiration(EXPIRATION_CLAIM_NAME, + expirationRaw != null ? expirationRaw.longValue() * 1000L : null); + String subject = ClaimValidationUtils.validateSubject(subClaimName, subRaw); + Long issuedAt = ClaimValidationUtils.validateIssuedAt(ISSUED_AT_CLAIM_NAME, + issuedAtRaw != null ? 
issuedAtRaw.longValue() * 1000L : null); + + OAuthBearerToken token = new BasicOAuthBearerToken(accessToken, + scopes, + expiration, + subject, + issuedAt); + + return token; + } + + private Object getClaim(Map payload, String claimName) { + Object value = payload.get(claimName); + log.debug("getClaim - {}: {}", claimName, value); + return value; + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/secured/OAuthBearerLoginCallbackHandler.java b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/secured/OAuthBearerLoginCallbackHandler.java new file mode 100644 index 0000000..da426f0 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/secured/OAuthBearerLoginCallbackHandler.java @@ -0,0 +1,293 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.common.security.oauthbearer.secured; + +import static org.apache.kafka.common.config.SaslConfigs.SASL_OAUTHBEARER_TOKEN_ENDPOINT_URL; + +import java.io.IOException; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import javax.security.auth.callback.Callback; +import javax.security.auth.callback.UnsupportedCallbackException; +import javax.security.auth.login.AppConfigurationEntry; +import javax.security.sasl.SaslException; + +import org.apache.kafka.common.KafkaException; +import org.apache.kafka.common.config.ConfigDef; +import org.apache.kafka.common.config.ConfigException; +import org.apache.kafka.common.security.auth.AuthenticateCallbackHandler; +import org.apache.kafka.common.security.auth.SaslExtensions; +import org.apache.kafka.common.security.auth.SaslExtensionsCallback; +import org.apache.kafka.common.security.oauthbearer.OAuthBearerToken; +import org.apache.kafka.common.security.oauthbearer.OAuthBearerTokenCallback; +import org.apache.kafka.common.security.oauthbearer.internals.OAuthBearerClientInitialResponse; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + *

+ * OAuthBearerLoginCallbackHandler is an {@link AuthenticateCallbackHandler} that
+ * accepts {@link OAuthBearerTokenCallback} and {@link SaslExtensionsCallback} callbacks to
+ * perform the steps to request a JWT from an OAuth/OIDC provider using the
+ * client credentials grant. This grant type is commonly used for non-interactive
+ * "service accounts" where there is no user available to interactively supply credentials.

            + * + *

            + * The OAuthBearerLoginCallbackHandler is used on the client side to retrieve a JWT + * and the {@link OAuthBearerValidatorCallbackHandler} is used on the broker to validate the JWT + * that was sent to it by the client to allow access. Both the brokers and clients will need to + * be configured with their appropriate callback handlers and respective configuration for OAuth + * functionality to work. + *

            + * + *

+ * Note that while this callback handler class must be specified for a Kafka client that wants to
+ * use OAuth functionality, in the case of OAuth-based inter-broker communication, the callback
+ * handler must be used on the Kafka broker side as well.

            + * + *

            + * This {@link AuthenticateCallbackHandler} is enabled by specifying its class name in the Kafka + * configuration. For client use, specify the class name in the + * {@link org.apache.kafka.common.config.SaslConfigs#SASL_LOGIN_CALLBACK_HANDLER_CLASS} + * configuration like so: + * + * + * sasl.login.callback.handler.class=org.apache.kafka.common.security.oauthbearer.secured.OAuthBearerLoginCallbackHandler + * + *

            + * + *

            + * If using OAuth login on the broker side (for inter-broker communication), the callback handler + * class will be specified with a listener-based property: + * listener.name..oauthbearer.sasl.login.callback.handler.class like so: + * + * + * listener.name..oauthbearer.sasl.login.callback.handler.class=org.apache.kafka.common.security.oauthbearer.secured.OAuthBearerLoginCallbackHandler + * + *

            + * + *

            + * The Kafka configuration must also include JAAS configuration which includes the following + * OAuth-specific options: + * + *

              + *
+ * <ul>
+ *   <li>clientId: OAuth client ID (required)</li>
+ *   <li>clientSecret: OAuth client secret (required)</li>
+ *   <li>scope: OAuth scope (optional)</li>
+ * </ul>
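+ * <p>
+ * In addition, any JAAS option prefixed with extension_ is passed through by this handler as a
+ * SASL extension with the prefix stripped. The option name below is purely illustrative:
+ * <pre>
+ *   extension_traceId="123"
+ * </pre>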

            + * + *

            + * The JAAS configuration can also include any SSL options that are needed. The configuration + * options are the same as those specified by the configuration in + * {@link org.apache.kafka.common.config.SslConfigs#addClientSslSupport(ConfigDef)}. + *

            + * + *

            + * Here's an example of the JAAS configuration for a Kafka client: + * + * + * sasl.jaas.config=org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginModule required \ + * clientId="foo" \ + * clientSecret="bar" \ + * scope="baz" \ + * ssl.protocol="SSL" ; + * + *

            + * + *

            + * The configuration option + * {@link org.apache.kafka.common.config.SaslConfigs#SASL_OAUTHBEARER_TOKEN_ENDPOINT_URL} + * is also required in order for the client to contact the OAuth/OIDC provider. For example: + * + * + * sasl.oauthbearer.token.endpoint.url=https://example.com/oauth2/v1/token + * + * + * Please see the OAuth/OIDC providers documentation for the token endpoint URL. + *

            + * + *

            + * The following is a list of all the configuration options that are available for the login + * callback handler: + * + *

              + *
+ * <ul>
+ *   <li>{@link org.apache.kafka.common.config.SaslConfigs#SASL_LOGIN_CALLBACK_HANDLER_CLASS}</li>
+ *   <li>{@link org.apache.kafka.common.config.SaslConfigs#SASL_LOGIN_CONNECT_TIMEOUT_MS}</li>
+ *   <li>{@link org.apache.kafka.common.config.SaslConfigs#SASL_LOGIN_READ_TIMEOUT_MS}</li>
+ *   <li>{@link org.apache.kafka.common.config.SaslConfigs#SASL_LOGIN_RETRY_BACKOFF_MS}</li>
+ *   <li>{@link org.apache.kafka.common.config.SaslConfigs#SASL_LOGIN_RETRY_BACKOFF_MAX_MS}</li>
+ *   <li>{@link org.apache.kafka.common.config.SaslConfigs#SASL_JAAS_CONFIG}</li>
+ *   <li>{@link org.apache.kafka.common.config.SaslConfigs#SASL_OAUTHBEARER_TOKEN_ENDPOINT_URL}</li>
+ *   <li>{@link org.apache.kafka.common.config.SaslConfigs#SASL_OAUTHBEARER_SCOPE_CLAIM_NAME}</li>
+ *   <li>{@link org.apache.kafka.common.config.SaslConfigs#SASL_OAUTHBEARER_SUB_CLAIM_NAME}</li>
+ * </ul>
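+ * <p>
+ * For illustration, the client-side options above might be wired together roughly as follows;
+ * the endpoint URL and credentials are placeholders taken from the examples in this Javadoc,
+ * used alongside the usual sasl.mechanism=OAUTHBEARER setting:
+ * <pre>
+ *   Properties props = new Properties();
+ *   props.put("sasl.mechanism", "OAUTHBEARER");
+ *   props.put("sasl.login.callback.handler.class",
+ *       "org.apache.kafka.common.security.oauthbearer.secured.OAuthBearerLoginCallbackHandler");
+ *   props.put("sasl.oauthbearer.token.endpoint.url", "https://example.com/oauth2/v1/token");
+ *   props.put("sasl.jaas.config",
+ *       "org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginModule required "
+ *           + "clientId=\"foo\" clientSecret=\"bar\" scope=\"baz\";");
+ * </pre>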

            + */ + +public class OAuthBearerLoginCallbackHandler implements AuthenticateCallbackHandler { + + private static final Logger log = LoggerFactory.getLogger(OAuthBearerLoginCallbackHandler.class); + + public static final String CLIENT_ID_CONFIG = "clientId"; + public static final String CLIENT_SECRET_CONFIG = "clientSecret"; + public static final String SCOPE_CONFIG = "scope"; + + public static final String CLIENT_ID_DOC = "The OAuth/OIDC identity provider-issued " + + "client ID to uniquely identify the service account to use for authentication for " + + "this client. The value must be paired with a corresponding " + CLIENT_SECRET_CONFIG + " " + + "value and is provided to the OAuth provider using the OAuth " + + "clientcredentials grant type."; + + public static final String CLIENT_SECRET_DOC = "The OAuth/OIDC identity provider-issued " + + "client secret serves a similar function as a password to the " + CLIENT_ID_CONFIG + " " + + "account and identifies the service account to use for authentication for " + + "this client. The value must be paired with a corresponding " + CLIENT_ID_CONFIG + " " + + "value and is provided to the OAuth provider using the OAuth " + + "clientcredentials grant type."; + + public static final String SCOPE_DOC = "The (optional) HTTP/HTTPS login request to the " + + "token endpoint (" + SASL_OAUTHBEARER_TOKEN_ENDPOINT_URL + ") may need to specify an " + + "OAuth \"scope\". If so, the " + SCOPE_CONFIG + " is used to provide the value to " + + "include with the login request."; + + private static final String EXTENSION_PREFIX = "extension_"; + + private Map moduleOptions; + + private AccessTokenRetriever accessTokenRetriever; + + private AccessTokenValidator accessTokenValidator; + + private boolean isInitialized = false; + + @Override + public void configure(Map configs, String saslMechanism, List jaasConfigEntries) { + moduleOptions = JaasOptionsUtils.getOptions(saslMechanism, jaasConfigEntries); + AccessTokenRetriever accessTokenRetriever = AccessTokenRetrieverFactory.create(configs, saslMechanism, moduleOptions); + AccessTokenValidator accessTokenValidator = AccessTokenValidatorFactory.create(configs, saslMechanism); + init(accessTokenRetriever, accessTokenValidator); + } + + /* + * Package-visible for testing. + */ + + void init(AccessTokenRetriever accessTokenRetriever, AccessTokenValidator accessTokenValidator) { + this.accessTokenRetriever = accessTokenRetriever; + this.accessTokenValidator = accessTokenValidator; + + try { + this.accessTokenRetriever.init(); + } catch (IOException e) { + throw new KafkaException("The OAuth login configuration encountered an error when initializing the AccessTokenRetriever", e); + } + + isInitialized = true; + } + + /* + * Package-visible for testing. 
+ */ + + AccessTokenRetriever getAccessTokenRetriever() { + return accessTokenRetriever; + } + + @Override + public void close() { + if (accessTokenRetriever != null) { + try { + this.accessTokenRetriever.close(); + } catch (IOException e) { + log.warn("The OAuth login configuration encountered an error when closing the AccessTokenRetriever", e); + } + } + } + + @Override + public void handle(Callback[] callbacks) throws IOException, UnsupportedCallbackException { + checkInitialized(); + + for (Callback callback : callbacks) { + if (callback instanceof OAuthBearerTokenCallback) { + handleTokenCallback((OAuthBearerTokenCallback) callback); + } else if (callback instanceof SaslExtensionsCallback) { + handleExtensionsCallback((SaslExtensionsCallback) callback); + } else { + throw new UnsupportedCallbackException(callback); + } + } + } + + private void handleTokenCallback(OAuthBearerTokenCallback callback) throws IOException { + checkInitialized(); + String accessToken = accessTokenRetriever.retrieve(); + + try { + OAuthBearerToken token = accessTokenValidator.validate(accessToken); + callback.token(token); + } catch (ValidateException e) { + log.warn(e.getMessage(), e); + callback.error("invalid_token", e.getMessage(), null); + } + } + + private void handleExtensionsCallback(SaslExtensionsCallback callback) { + checkInitialized(); + + Map extensions = new HashMap<>(); + + for (Map.Entry configEntry : this.moduleOptions.entrySet()) { + String key = configEntry.getKey(); + + if (!key.startsWith(EXTENSION_PREFIX)) + continue; + + Object valueRaw = configEntry.getValue(); + String value; + + if (valueRaw instanceof String) + value = (String) valueRaw; + else + value = String.valueOf(valueRaw); + + extensions.put(key.substring(EXTENSION_PREFIX.length()), value); + } + + SaslExtensions saslExtensions = new SaslExtensions(extensions); + + try { + OAuthBearerClientInitialResponse.validateExtensions(saslExtensions); + } catch (SaslException e) { + throw new ConfigException(e.getMessage()); + } + + callback.extensions(saslExtensions); + } + + private void checkInitialized() { + if (!isInitialized) + throw new IllegalStateException(String.format("To use %s, first call the configure or init method", getClass().getSimpleName())); + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/secured/OAuthBearerValidatorCallbackHandler.java b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/secured/OAuthBearerValidatorCallbackHandler.java new file mode 100644 index 0000000..5ba7378 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/secured/OAuthBearerValidatorCallbackHandler.java @@ -0,0 +1,277 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.common.security.oauthbearer.secured; + +import java.io.IOException; +import java.security.Key; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.concurrent.atomic.AtomicInteger; +import javax.security.auth.callback.Callback; +import javax.security.auth.callback.UnsupportedCallbackException; +import javax.security.auth.login.AppConfigurationEntry; +import org.apache.kafka.common.KafkaException; +import org.apache.kafka.common.security.auth.AuthenticateCallbackHandler; +import org.apache.kafka.common.security.oauthbearer.OAuthBearerExtensionsValidatorCallback; +import org.apache.kafka.common.security.oauthbearer.OAuthBearerToken; +import org.apache.kafka.common.security.oauthbearer.OAuthBearerValidatorCallback; +import org.jose4j.jws.JsonWebSignature; +import org.jose4j.jwx.JsonWebStructure; +import org.jose4j.lang.UnresolvableKeyException; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + *

+ * <p>
+ * <code>OAuthBearerValidatorCallbackHandler</code> is an {@link AuthenticateCallbackHandler} that
+ * accepts {@link OAuthBearerValidatorCallback} and {@link OAuthBearerExtensionsValidatorCallback}
+ * callbacks to implement OAuth/OIDC validation. This callback handler is intended only to be used
+ * on the Kafka broker side as it will receive an {@link OAuthBearerValidatorCallback} that includes
+ * the JWT provided by the Kafka client. That JWT is validated in terms of format, expiration,
+ * signature, and audience and issuer (if desired). This callback handler is the broker side of the
+ * OAuth functionality, whereas {@link OAuthBearerLoginCallbackHandler} is used by clients.
+ * </p>
+ *
+ * <p>
+ * This {@link AuthenticateCallbackHandler} is enabled in the broker configuration by setting the
+ * {@link org.apache.kafka.common.config.internals.BrokerSecurityConfigs#SASL_SERVER_CALLBACK_HANDLER_CLASS}
+ * like so:
+ * </p>
+ *
+ * <pre>
+ * listener.name.&lt;listener name&gt;.oauthbearer.sasl.server.callback.handler.class=org.apache.kafka.common.security.oauthbearer.secured.OAuthBearerValidatorCallbackHandler
+ * </pre>
+ *
+ * <p>
+ * The JAAS configuration for OAuth is also needed. If using OAuth for inter-broker communication,
+ * the options are those specified in {@link OAuthBearerLoginCallbackHandler}.
+ * </p>
+ *
+ * <p>
+ * The configuration option
+ * {@link org.apache.kafka.common.config.SaslConfigs#SASL_OAUTHBEARER_JWKS_ENDPOINT_URL}
+ * is also required in order to contact the OAuth/OIDC provider to retrieve the JWKS for use in
+ * JWT signature validation. For example:
+ * </p>
+ *
+ * <pre>
+ * listener.name.&lt;listener name&gt;.oauthbearer.sasl.oauthbearer.jwks.endpoint.url=https://example.com/oauth2/v1/keys
+ * </pre>
+ *
+ * <p>
+ * Please see the OAuth/OIDC provider's documentation for the JWKS endpoint URL.
+ * </p>
+ *
+ * <p>
+ * The following is a list of all the configuration options that are available for the broker
+ * validation callback handler (a combined example follows the list):
+ * </p>
+ *
+ * <ul>
+ *   <li>{@link org.apache.kafka.common.config.internals.BrokerSecurityConfigs#SASL_SERVER_CALLBACK_HANDLER_CLASS}</li>
+ *   <li>{@link org.apache.kafka.common.config.SaslConfigs#SASL_JAAS_CONFIG}</li>
+ *   <li>{@link org.apache.kafka.common.config.SaslConfigs#SASL_OAUTHBEARER_CLOCK_SKEW_SECONDS}</li>
+ *   <li>{@link org.apache.kafka.common.config.SaslConfigs#SASL_OAUTHBEARER_EXPECTED_AUDIENCE}</li>
+ *   <li>{@link org.apache.kafka.common.config.SaslConfigs#SASL_OAUTHBEARER_EXPECTED_ISSUER}</li>
+ *   <li>{@link org.apache.kafka.common.config.SaslConfigs#SASL_OAUTHBEARER_JWKS_ENDPOINT_REFRESH_MS}</li>
+ *   <li>{@link org.apache.kafka.common.config.SaslConfigs#SASL_OAUTHBEARER_JWKS_ENDPOINT_RETRY_BACKOFF_MAX_MS}</li>
+ *   <li>{@link org.apache.kafka.common.config.SaslConfigs#SASL_OAUTHBEARER_JWKS_ENDPOINT_RETRY_BACKOFF_MS}</li>
+ *   <li>{@link org.apache.kafka.common.config.SaslConfigs#SASL_OAUTHBEARER_JWKS_ENDPOINT_URL}</li>
+ *   <li>{@link org.apache.kafka.common.config.SaslConfigs#SASL_OAUTHBEARER_SCOPE_CLAIM_NAME}</li>
+ *   <li>{@link org.apache.kafka.common.config.SaslConfigs#SASL_OAUTHBEARER_SUB_CLAIM_NAME}</li>
+ * </ul>
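+ *
+ * <p>
+ * For illustration only, a broker listener named "broker" might combine these settings roughly as
+ * follows; the listener name, URLs, audience, and issuer below are hypothetical placeholders, not
+ * values prescribed by this class:
+ * </p>
+ *
+ * <pre>
+ * listener.name.broker.oauthbearer.sasl.server.callback.handler.class=org.apache.kafka.common.security.oauthbearer.secured.OAuthBearerValidatorCallbackHandler
+ * listener.name.broker.oauthbearer.sasl.oauthbearer.jwks.endpoint.url=https://example.com/oauth2/v1/keys
+ * listener.name.broker.oauthbearer.sasl.oauthbearer.expected.audience=kafka-cluster
+ * listener.name.broker.oauthbearer.sasl.oauthbearer.expected.issuer=https://example.com/oauth2/default
+ * </pre>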

            + */ + +public class OAuthBearerValidatorCallbackHandler implements AuthenticateCallbackHandler { + + private static final Logger log = LoggerFactory.getLogger(OAuthBearerValidatorCallbackHandler.class); + + /** + * Because a {@link CloseableVerificationKeyResolver} instance can spawn threads and issue + * HTTP(S) calls ({@link RefreshingHttpsJwksVerificationKeyResolver}), we only want to create + * a new instance for each particular set of configuration. Because each set of configuration + * may have multiple instances, we want to reuse the single instance. + */ + + private static final Map VERIFICATION_KEY_RESOLVER_CACHE = new HashMap<>(); + + private CloseableVerificationKeyResolver verificationKeyResolver; + + private AccessTokenValidator accessTokenValidator; + + private boolean isInitialized = false; + + @Override + public void configure(Map configs, String saslMechanism, List jaasConfigEntries) { + Map moduleOptions = JaasOptionsUtils.getOptions(saslMechanism, jaasConfigEntries); + CloseableVerificationKeyResolver verificationKeyResolver; + + // Here's the logic which keeps our VerificationKeyResolvers down to a single instance. + synchronized (VERIFICATION_KEY_RESOLVER_CACHE) { + VerificationKeyResolverKey key = new VerificationKeyResolverKey(configs, moduleOptions); + verificationKeyResolver = VERIFICATION_KEY_RESOLVER_CACHE.computeIfAbsent(key, k -> + new RefCountingVerificationKeyResolver(VerificationKeyResolverFactory.create(configs, saslMechanism, moduleOptions))); + } + + AccessTokenValidator accessTokenValidator = AccessTokenValidatorFactory.create(configs, saslMechanism, verificationKeyResolver); + init(verificationKeyResolver, accessTokenValidator); + } + + /* + * Package-visible for testing. + */ + + void init(CloseableVerificationKeyResolver verificationKeyResolver, AccessTokenValidator accessTokenValidator) { + this.verificationKeyResolver = verificationKeyResolver; + this.accessTokenValidator = accessTokenValidator; + + try { + verificationKeyResolver.init(); + } catch (Exception e) { + throw new KafkaException("The OAuth validator configuration encountered an error when initializing the VerificationKeyResolver", e); + } + + isInitialized = true; + } + + @Override + public void close() { + if (verificationKeyResolver != null) { + try { + verificationKeyResolver.close(); + } catch (Exception e) { + log.error(e.getMessage(), e); + } + } + } + + @Override + public void handle(Callback[] callbacks) throws IOException, UnsupportedCallbackException { + checkInitialized(); + + for (Callback callback : callbacks) { + if (callback instanceof OAuthBearerValidatorCallback) { + handleValidatorCallback((OAuthBearerValidatorCallback) callback); + } else if (callback instanceof OAuthBearerExtensionsValidatorCallback) { + handleExtensionsValidatorCallback((OAuthBearerExtensionsValidatorCallback) callback); + } else { + throw new UnsupportedCallbackException(callback); + } + } + } + + private void handleValidatorCallback(OAuthBearerValidatorCallback callback) { + checkInitialized(); + + OAuthBearerToken token; + + try { + token = accessTokenValidator.validate(callback.tokenValue()); + callback.token(token); + } catch (ValidateException e) { + log.warn(e.getMessage(), e); + callback.error("invalid_token", null, null); + } + } + + private void handleExtensionsValidatorCallback(OAuthBearerExtensionsValidatorCallback extensionsValidatorCallback) { + checkInitialized(); + + extensionsValidatorCallback.inputExtensions().map().forEach((extensionName, v) -> 
extensionsValidatorCallback.valid(extensionName)); + } + + private void checkInitialized() { + if (!isInitialized) + throw new IllegalStateException(String.format("To use %s, first call the configure or init method", getClass().getSimpleName())); + } + + /** + * VkrKey is a simple structure which encapsulates the criteria for different + * sets of configuration. This will allow us to use this object as a key in a {@link Map} + * to keep a single instance per key. + */ + + private static class VerificationKeyResolverKey { + + private final Map configs; + + private final Map moduleOptions; + + public VerificationKeyResolverKey(Map configs, Map moduleOptions) { + this.configs = configs; + this.moduleOptions = moduleOptions; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + + if (o == null || getClass() != o.getClass()) { + return false; + } + + VerificationKeyResolverKey that = (VerificationKeyResolverKey) o; + return configs.equals(that.configs) && moduleOptions.equals(that.moduleOptions); + } + + @Override + public int hashCode() { + return Objects.hash(configs, moduleOptions); + } + + } + + /** + * RefCountingVerificationKeyResolver allows us to share a single + * {@link CloseableVerificationKeyResolver} instance between multiple + * {@link AuthenticateCallbackHandler} instances and perform the lifecycle methods the + * appropriate number of times. + */ + + private static class RefCountingVerificationKeyResolver implements CloseableVerificationKeyResolver { + + private final CloseableVerificationKeyResolver delegate; + + private final AtomicInteger count = new AtomicInteger(0); + + public RefCountingVerificationKeyResolver(CloseableVerificationKeyResolver delegate) { + this.delegate = delegate; + } + + @Override + public Key resolveKey(JsonWebSignature jws, List nestingContext) throws UnresolvableKeyException { + return delegate.resolveKey(jws, nestingContext); + } + + @Override + public void init() throws IOException { + if (count.incrementAndGet() == 1) + delegate.init(); + } + + @Override + public void close() throws IOException { + if (count.decrementAndGet() == 0) + delegate.close(); + } + + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/secured/RefreshingHttpsJwks.java b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/secured/RefreshingHttpsJwks.java new file mode 100644 index 0000000..4003a44 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/secured/RefreshingHttpsJwks.java @@ -0,0 +1,364 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.common.security.oauthbearer.secured; + +import java.io.Closeable; +import java.io.IOException; +import java.util.Collections; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Executors; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.TimeUnit; + +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.locks.ReadWriteLock; +import java.util.concurrent.locks.ReentrantReadWriteLock; +import org.apache.kafka.common.utils.Time; +import org.jose4j.jwk.HttpsJwks; +import org.jose4j.jwk.JsonWebKey; +import org.jose4j.lang.JoseException; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Implementation of {@link HttpsJwks} that will periodically refresh the JWKS cache to reduce or + * even prevent HTTP/HTTPS traffic in the hot path of validation. It is assumed that it's + * possible to receive a JWT that contains a kid that points to yet-unknown JWK, + * thus requiring a connection to the OAuth/OIDC provider to be made. Hopefully, in practice, + * keys are made available for some amount of time before they're used within JWTs. + * + * This instance is created and provided to the + * {@link org.jose4j.keys.resolvers.HttpsJwksVerificationKeyResolver} that is used when using + * an HTTP-/HTTPS-based {@link org.jose4j.keys.resolvers.VerificationKeyResolver}, which is then + * provided to the {@link ValidatorAccessTokenValidator} to use in validating the signature of + * a JWT. + * + * @see org.jose4j.keys.resolvers.HttpsJwksVerificationKeyResolver + * @see org.jose4j.keys.resolvers.VerificationKeyResolver + * @see ValidatorAccessTokenValidator + */ + +public final class RefreshingHttpsJwks implements Initable, Closeable { + + private static final Logger log = LoggerFactory.getLogger(RefreshingHttpsJwks.class); + + private static final int MISSING_KEY_ID_CACHE_MAX_ENTRIES = 16; + + static final long MISSING_KEY_ID_CACHE_IN_FLIGHT_MS = 60000; + + static final int MISSING_KEY_ID_MAX_KEY_LENGTH = 1000; + + private static final int SHUTDOWN_TIMEOUT = 10; + + private static final TimeUnit SHUTDOWN_TIME_UNIT = TimeUnit.SECONDS; + + /** + * {@link HttpsJwks} does the actual work of contacting the OAuth/OIDC endpoint to get the + * JWKS. In some cases, the call to {@link HttpsJwks#getJsonWebKeys()} will trigger a call + * to {@link HttpsJwks#refresh()} which will block the current thread in network I/O. We cache + * the JWKS ourselves (see {@link #jsonWebKeys}) to avoid the network I/O. + * + * We want to be very careful where we use the {@link HttpsJwks} instance so that we don't + * perform any operation (directly or indirectly) that could cause blocking. This is because + * the JWKS logic is part of the larger authentication logic which operates on Kafka's network + * thread. It's OK to execute {@link HttpsJwks#getJsonWebKeys()} (which calls + * {@link HttpsJwks#refresh()}) from within {@link #init()} as that method is called only at + * startup, and we can afford the blocking hit there. + */ + + private final HttpsJwks httpsJwks; + + private final ScheduledExecutorService executorService; + + private final Time time; + + private final long refreshMs; + + private final long refreshRetryBackoffMs; + + private final long refreshRetryBackoffMaxMs; + + /** + * Protects {@link #missingKeyIds} and {@link #jsonWebKeys}. 
+ */ + + private final ReadWriteLock refreshLock = new ReentrantReadWriteLock(); + + private final Map missingKeyIds; + + /** + * Flag to prevent concurrent refresh invocations. + */ + + private final AtomicBoolean refreshInProgressFlag = new AtomicBoolean(false); + + /** + * As mentioned in the comments for {@link #httpsJwks}, we cache the JWKS ourselves so that + * we can return the list immediately without any network I/O. They are only cached within + * calls to {@link #refresh()}. + */ + + private List jsonWebKeys; + + private boolean isInitialized; + + /** + * Creates a RefreshingHttpsJwks that will be used by the + * {@link RefreshingHttpsJwksVerificationKeyResolver} to resolve new key IDs in JWTs. + * + * @param time {@link Time} instance + * @param httpsJwks {@link HttpsJwks} instance from which to retrieve the JWKS + * based on the OAuth/OIDC standard + * @param refreshMs The number of milliseconds between refresh passes to connect + * to the OAuth/OIDC JWKS endpoint to retrieve the latest set + * @param refreshRetryBackoffMs Time for delay after initial failed attempt to retrieve JWKS + * @param refreshRetryBackoffMaxMs Maximum time to retrieve JWKS + */ + + public RefreshingHttpsJwks(Time time, + HttpsJwks httpsJwks, + long refreshMs, + long refreshRetryBackoffMs, + long refreshRetryBackoffMaxMs) { + if (refreshMs <= 0) + throw new IllegalArgumentException("JWKS validation key refresh configuration value retryWaitMs value must be positive"); + + this.httpsJwks = httpsJwks; + this.time = time; + this.refreshMs = refreshMs; + this.refreshRetryBackoffMs = refreshRetryBackoffMs; + this.refreshRetryBackoffMaxMs = refreshRetryBackoffMaxMs; + this.executorService = Executors.newSingleThreadScheduledExecutor(); + this.missingKeyIds = new LinkedHashMap(MISSING_KEY_ID_CACHE_MAX_ENTRIES, .75f, true) { + @Override + protected boolean removeEldestEntry(Map.Entry eldest) { + return this.size() > MISSING_KEY_ID_CACHE_MAX_ENTRIES; + } + }; + } + + @Override + public void init() throws IOException { + try { + log.debug("init started"); + + List localJWKs; + + try { + localJWKs = httpsJwks.getJsonWebKeys(); + } catch (JoseException e) { + throw new IOException("Could not refresh JWKS", e); + } + + try { + refreshLock.writeLock().lock(); + jsonWebKeys = Collections.unmodifiableList(localJWKs); + } finally { + refreshLock.writeLock().unlock(); + } + + // Since we just grabbed the keys (which will have invoked a HttpsJwks.refresh() + // internally), we can delay our first invocation by refreshMs. + // + // Note: we refer to this as a _scheduled_ refresh. 
+ executorService.scheduleAtFixedRate(this::refresh, + refreshMs, + refreshMs, + TimeUnit.MILLISECONDS); + + log.info("JWKS validation key refresh thread started with a refresh interval of {} ms", refreshMs); + } finally { + isInitialized = true; + + log.debug("init completed"); + } + } + + @Override + public void close() { + try { + log.debug("close started"); + + try { + log.debug("JWKS validation key refresh thread shutting down"); + executorService.shutdown(); + + if (!executorService.awaitTermination(SHUTDOWN_TIMEOUT, SHUTDOWN_TIME_UNIT)) { + log.warn("JWKS validation key refresh thread termination did not end after {} {}", + SHUTDOWN_TIMEOUT, SHUTDOWN_TIME_UNIT); + } + } catch (InterruptedException e) { + log.warn("JWKS validation key refresh thread error during close", e); + } + } finally { + log.debug("close completed"); + } + } + + /** + * Our implementation avoids the blocking call within {@link HttpsJwks#refresh()} that is + * sometimes called internal to {@link HttpsJwks#getJsonWebKeys()}. We want to avoid any + * blocking I/O as this code is running in the authentication path on the Kafka network thread. + * + * The list may be stale up to {@link #refreshMs}. + * + * @return {@link List} of {@link JsonWebKey} instances + * + * @throws JoseException Thrown if a problem is encountered parsing the JSON content into JWKs + * @throws IOException Thrown f a problem is encountered making the HTTP request + */ + + public List getJsonWebKeys() throws JoseException, IOException { + if (!isInitialized) + throw new IllegalStateException("Please call init() first"); + + try { + refreshLock.readLock().lock(); + return jsonWebKeys; + } finally { + refreshLock.readLock().unlock(); + } + } + + public String getLocation() { + return httpsJwks.getLocation(); + } + + /** + *

+ * <p>
+ * <code>refresh</code> is an internal method that will refresh the JWKS cache and is
+ * invoked in one of two ways:
+ * </p>
+ *
+ * <ol>
+ *   <li>Scheduled</li>
+ *   <li>Expedited</li>
+ * </ol>
+ *
+ * <p>
+ * The scheduled refresh is scheduled in {@link #init()} and runs every
+ * {@link #refreshMs} milliseconds. An expedited refresh is performed when an
+ * incoming JWT refers to a key ID that isn't in our JWKS cache ({@link #jsonWebKeys})
+ * and we try to perform a refresh sooner than the next scheduled refresh.
+ * </p>
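+ *
+ * <p>
+ * As a rough sketch, mirroring the calls made elsewhere in this class, the two paths reduce to:
+ * </p>
+ *
+ * <pre>{@code
+ * // Scheduled: set up once in init() to run every refreshMs milliseconds.
+ * executorService.scheduleAtFixedRate(this::refresh, refreshMs, refreshMs, TimeUnit.MILLISECONDS);
+ *
+ * // Expedited: requested from maybeExpediteRefresh(String) when an unknown key ID is seen.
+ * executorService.schedule(this::refresh, 0, TimeUnit.MILLISECONDS);
+ * }</pre>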

            + */ + + private void refresh() { + if (!refreshInProgressFlag.compareAndSet(false, true)) { + log.debug("OAuth JWKS refresh is already in progress; ignoring concurrent refresh"); + return; + } + + try { + log.info("OAuth JWKS refresh of {} starting", httpsJwks.getLocation()); + Retry> retry = new Retry<>(refreshRetryBackoffMs, refreshRetryBackoffMaxMs); + List localJWKs = retry.execute(() -> { + try { + log.debug("JWKS validation key calling refresh of {} starting", httpsJwks.getLocation()); + // Call the *actual* refresh implementation that will more than likely issue + // HTTP(S) calls over the network. + httpsJwks.refresh(); + List jwks = httpsJwks.getJsonWebKeys(); + log.debug("JWKS validation key refresh of {} complete", httpsJwks.getLocation()); + return jwks; + } catch (Exception e) { + throw new ExecutionException(e); + } + }); + + try { + refreshLock.writeLock().lock(); + + for (JsonWebKey jwk : localJWKs) + missingKeyIds.remove(jwk.getKeyId()); + + jsonWebKeys = Collections.unmodifiableList(localJWKs); + } finally { + refreshLock.writeLock().unlock(); + } + + log.info("OAuth JWKS refresh of {} complete", httpsJwks.getLocation()); + } catch (ExecutionException e) { + log.warn("OAuth JWKS refresh of {} encountered an error; not updating local JWKS cache", httpsJwks.getLocation(), e); + } finally { + refreshInProgressFlag.set(false); + } + } + + /** + *

+ * <p>
+ * <code>maybeExpediteRefresh</code> is a public method that will trigger a refresh of
+ * the JWKS cache if all of the following conditions are met:
+ * </p>
+ *
+ * <ul>
+ *   <li>The given <code>keyId</code> parameter is no longer than
+ *   {@link #MISSING_KEY_ID_MAX_KEY_LENGTH} characters</li>
+ *   <li>The key isn't in the process of being expedited already</li>
+ * </ul>
+ *
+ * <p>
+ * This expedited refresh is scheduled immediately; a sketch of a typical caller follows.
+ * </p>
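+ *
+ * <p>
+ * For illustration, the sketch below mirrors how {@link RefreshingHttpsJwksVerificationKeyResolver}
+ * reacts to an unrecognized key ID (variable names are illustrative):
+ * </p>
+ *
+ * <pre>{@code
+ * String keyId = jws.getKeyIdHeaderValue();
+ *
+ * if (refreshingHttpsJwks.maybeExpediteRefresh(keyId))
+ *     log.debug("Expedited JWKS refresh scheduled for unrecognized key ID {}", keyId);
+ * }</pre>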

            + * + * @param keyId JWT key ID + * @return true if an expedited refresh was scheduled, false otherwise + */ + + public boolean maybeExpediteRefresh(String keyId) { + if (keyId.length() > MISSING_KEY_ID_MAX_KEY_LENGTH) { + // Although there's no limit on the length of the key ID, they're generally + // "reasonably" short. If we have a very long key ID length, we're going to assume + // the JWT is malformed, and we will not actually try to resolve the key. + // + // In this case, let's prevent blowing out our memory in two ways: + // + // 1. Don't try to resolve the key as the large ID will sit in our cache + // 2. Report the issue in the logs but include only the first N characters + int actualLength = keyId.length(); + String s = keyId.substring(0, MISSING_KEY_ID_MAX_KEY_LENGTH); + String snippet = String.format("%s (trimmed to first %s characters out of %s total)", s, MISSING_KEY_ID_MAX_KEY_LENGTH, actualLength); + log.warn("Key ID {} was too long to cache", snippet); + return false; + } else { + try { + refreshLock.writeLock().lock(); + + Long nextCheckTime = missingKeyIds.get(keyId); + long currTime = time.milliseconds(); + log.debug("For key ID {}, nextCheckTime: {}, currTime: {}", keyId, nextCheckTime, currTime); + + if (nextCheckTime == null || nextCheckTime <= currTime) { + // If there's no entry in the missing key ID cache for the incoming key ID, + // or it has expired, schedule a refresh ASAP. + nextCheckTime = currTime + MISSING_KEY_ID_CACHE_IN_FLIGHT_MS; + missingKeyIds.put(keyId, nextCheckTime); + executorService.schedule(this::refresh, 0, TimeUnit.MILLISECONDS); + return true; + } else { + return false; + } + } finally { + refreshLock.writeLock().unlock(); + } + } + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/secured/RefreshingHttpsJwksVerificationKeyResolver.java b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/secured/RefreshingHttpsJwksVerificationKeyResolver.java new file mode 100644 index 0000000..b496720 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/secured/RefreshingHttpsJwksVerificationKeyResolver.java @@ -0,0 +1,153 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.common.security.oauthbearer.secured; + +import java.io.IOException; +import java.security.Key; +import java.util.List; +import org.jose4j.jwk.HttpsJwks; +import org.jose4j.jwk.JsonWebKey; +import org.jose4j.jwk.VerificationJwkSelector; +import org.jose4j.jws.JsonWebSignature; +import org.jose4j.jwx.JsonWebStructure; +import org.jose4j.keys.resolvers.VerificationKeyResolver; +import org.jose4j.lang.JoseException; +import org.jose4j.lang.UnresolvableKeyException; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * RefreshingHttpsJwksVerificationKeyResolver is a + * {@link VerificationKeyResolver} implementation that will periodically refresh the + * JWKS using its {@link HttpsJwks} instance. + * + * A JWKS (JSON Web Key Set) + * is a JSON document provided by the OAuth/OIDC provider that lists the keys used to sign the JWTs + * it issues. + * + * Here is a sample JWKS JSON document: + * + *
            + * {
            + *   "keys": [
            + *     {
            + *       "kty": "RSA",
            + *       "alg": "RS256",
            + *       "kid": "abc123",
            + *       "use": "sig",
            + *       "e": "AQAB",
            + *       "n": "..."
            + *     },
            + *     {
            + *       "kty": "RSA",
            + *       "alg": "RS256",
            + *       "kid": "def456",
            + *       "use": "sig",
            + *       "e": "AQAB",
            + *       "n": "..."
            + *     }
            + *   ]
            + * }
            + * 
            + * + * Without going into too much detail, the array of keys enumerates the key data that the provider + * is using to sign the JWT. The key ID (kid) is referenced by the JWT's header in + * order to match up the JWT's signing key with the key in the JWKS. During the validation step of + * the broker, the jose4j OAuth library will use the contents of the appropriate key in the JWKS + * to validate the signature. + * + * Given that the JWKS is referenced by the JWT, the JWKS must be made available by the + * OAuth/OIDC provider so that a JWT can be validated. + * + * @see CloseableVerificationKeyResolver + * @see VerificationKeyResolver + * @see RefreshingHttpsJwks + * @see HttpsJwks + */ + +public class RefreshingHttpsJwksVerificationKeyResolver implements CloseableVerificationKeyResolver { + + private static final Logger log = LoggerFactory.getLogger(RefreshingHttpsJwksVerificationKeyResolver.class); + + private final RefreshingHttpsJwks refreshingHttpsJwks; + + private final VerificationJwkSelector verificationJwkSelector; + + private boolean isInitialized; + + public RefreshingHttpsJwksVerificationKeyResolver(RefreshingHttpsJwks refreshingHttpsJwks) { + this.refreshingHttpsJwks = refreshingHttpsJwks; + this.verificationJwkSelector = new VerificationJwkSelector(); + } + + @Override + public void init() throws IOException { + try { + log.debug("init started"); + + refreshingHttpsJwks.init(); + } finally { + isInitialized = true; + + log.debug("init completed"); + } + } + + @Override + public void close() { + try { + log.debug("close started"); + + refreshingHttpsJwks.close(); + } finally { + log.debug("close completed"); + } + } + + @Override + public Key resolveKey(JsonWebSignature jws, List nestingContext) throws UnresolvableKeyException { + if (!isInitialized) + throw new IllegalStateException("Please call init() first"); + + try { + List jwks = refreshingHttpsJwks.getJsonWebKeys(); + JsonWebKey jwk = verificationJwkSelector.select(jws, jwks); + + if (jwk != null) + return jwk.getKey(); + + String keyId = jws.getKeyIdHeaderValue(); + + if (refreshingHttpsJwks.maybeExpediteRefresh(keyId)) + log.debug("Refreshing JWKs from {} as no suitable verification key for JWS w/ header {} was found in {}", refreshingHttpsJwks.getLocation(), jws.getHeaders().getFullHeaderAsJsonString(), jwks); + + StringBuilder sb = new StringBuilder(); + sb.append("Unable to find a suitable verification key for JWS w/ header ").append(jws.getHeaders().getFullHeaderAsJsonString()); + sb.append(" from JWKs ").append(jwks).append(" obtained from ").append( + refreshingHttpsJwks.getLocation()); + throw new UnresolvableKeyException(sb.toString()); + } catch (JoseException | IOException e) { + StringBuilder sb = new StringBuilder(); + sb.append("Unable to find a suitable verification key for JWS w/ header ").append(jws.getHeaders().getFullHeaderAsJsonString()); + sb.append(" due to an unexpected exception (").append(e).append(") while obtaining or using keys from JWKS endpoint at ").append( + refreshingHttpsJwks.getLocation()); + throw new UnresolvableKeyException(sb.toString(), e); + } + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/secured/Retry.java b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/secured/Retry.java new file mode 100644 index 0000000..ffa5672 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/secured/Retry.java @@ -0,0 +1,106 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under 
one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.common.security.oauthbearer.secured; + +import java.util.concurrent.ExecutionException; +import org.apache.kafka.common.utils.Time; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Retry encapsulates the mechanism to perform a retry and then exponential + * backoff using provided wait times between attempts. + * + * @param Result type + */ + +public class Retry { + + private static final Logger log = LoggerFactory.getLogger(Retry.class); + + private final Time time; + + private final long retryBackoffMs; + + private final long retryBackoffMaxMs; + + public Retry(long retryBackoffMs, long retryBackoffMaxMs) { + this(Time.SYSTEM, retryBackoffMs, retryBackoffMaxMs); + } + + public Retry(Time time, long retryBackoffMs, long retryBackoffMaxMs) { + this.time = time; + this.retryBackoffMs = retryBackoffMs; + this.retryBackoffMaxMs = retryBackoffMaxMs; + + if (this.retryBackoffMs < 0) + throw new IllegalArgumentException(String.format("retryBackoffMs value (%s) must be non-negative", retryBackoffMs)); + + if (this.retryBackoffMaxMs < 0) + throw new IllegalArgumentException(String.format("retryBackoffMaxMs value (%s) must be non-negative", retryBackoffMaxMs)); + + if (this.retryBackoffMaxMs < this.retryBackoffMs) + throw new IllegalArgumentException(String.format("retryBackoffMaxMs value (%s) is less than retryBackoffMs value (%s)", retryBackoffMaxMs, retryBackoffMs)); + } + + public R execute(Retryable retryable) throws ExecutionException { + long endMs = time.milliseconds() + retryBackoffMaxMs; + int currAttempt = 0; + ExecutionException error = null; + + while (time.milliseconds() <= endMs) { + currAttempt++; + + try { + return retryable.call(); + } catch (UnretryableException e) { + // We've deemed this error to not be worth retrying, so collect the error and + // fail immediately. + if (error == null) + error = new ExecutionException(e); + + break; + } catch (ExecutionException e) { + log.warn("Error during retry attempt {}", currAttempt, e); + + if (error == null) + error = e; + + long waitMs = retryBackoffMs * (long) Math.pow(2, currAttempt - 1); + long diff = endMs - time.milliseconds(); + waitMs = Math.min(waitMs, diff); + + if (waitMs <= 0) + break; + + String message = String.format("Attempt %s to make call resulted in an error; sleeping %s ms before retrying", + currAttempt, waitMs); + log.warn(message, e); + + time.sleep(waitMs); + } + } + + if (error == null) + // Really shouldn't ever get to here, but... 
+ error = new ExecutionException(new IllegalStateException("Exhausted all retry attempts but no attempt returned value or encountered exception")); + + throw error; + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/secured/Retryable.java b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/secured/Retryable.java new file mode 100644 index 0000000..67967ad --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/secured/Retryable.java @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.common.security.oauthbearer.secured; + +import java.util.concurrent.ExecutionException; + +/** + * Simple interface to abstract out the call that is made so that it can be retried. + * + * @param Result type + * + * @see Retry + * @see UnretryableException + */ + +public interface Retryable { + + /** + * Perform the operation and return the data from the response. + * + * @return Return response data, formatted in the given data type + * + * @throws ExecutionException Thrown on errors connecting, writing, reading, timeouts, etc. + * that can likely be tried again + * @throws UnretryableException Thrown on errors that we can determine should not be tried again + */ + + R call() throws ExecutionException, UnretryableException; + +} diff --git a/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/secured/SerializedJwt.java b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/secured/SerializedJwt.java new file mode 100644 index 0000000..962d720 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/secured/SerializedJwt.java @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.common.security.oauthbearer.secured; + +/** + * SerializedJwt provides a modicum of structure and validation around a JWT's serialized form by + * splitting and making the three sections (header, payload, and signature) available to the user. + */ + +public class SerializedJwt { + + private final String token; + + private final String header; + + private final String payload; + + private final String signature; + + public SerializedJwt(String token) { + if (token == null) + token = ""; + else + token = token.trim(); + + if (token.isEmpty()) + throw new ValidateException("Empty JWT provided; expected three sections (header, payload, and signature)"); + + String[] splits = token.split("\\."); + + if (splits.length != 3) + throw new ValidateException(String.format("Malformed JWT provided (%s); expected three sections (header, payload, and signature), but %s sections provided", + token, splits.length)); + + this.token = token.trim(); + this.header = validateSection(splits[0], "header"); + this.payload = validateSection(splits[1], "payload"); + this.signature = validateSection(splits[2], "signature"); + } + + /** + * Returns the entire base 64-encoded JWT. + * + * @return JWT + */ + + public String getToken() { + return token; + } + + /** + * Returns the first section--the JWT header--in its base 64-encoded form. + * + * @return Header section of the JWT + */ + + public String getHeader() { + return header; + } + + /** + * Returns the second section--the JWT payload--in its base 64-encoded form. + * + * @return Payload section of the JWT + */ + + public String getPayload() { + return payload; + } + + /** + * Returns the third section--the JWT signature--in its base 64-encoded form. + * + * @return Signature section of the JWT + */ + + public String getSignature() { + return signature; + } + + private String validateSection(String section, String sectionName) throws ValidateException { + section = section.trim(); + + if (section.isEmpty()) + throw new ValidateException(String.format( + "Malformed JWT provided; expected at least three sections (header, payload, and signature), but %s section missing", + sectionName)); + + return section; + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/secured/UnretryableException.java b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/secured/UnretryableException.java new file mode 100644 index 0000000..1964cfb --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/secured/UnretryableException.java @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.common.security.oauthbearer.secured; + +import org.apache.kafka.common.KafkaException; + +public class UnretryableException extends KafkaException { + + public UnretryableException(String message) { + super(message); + } + + public UnretryableException(Throwable cause) { + super(cause); + } + + public UnretryableException(String message, Throwable cause) { + super(message, cause); + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/secured/ValidateException.java b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/secured/ValidateException.java new file mode 100644 index 0000000..2ebebeb --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/secured/ValidateException.java @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.common.security.oauthbearer.secured; + +import javax.security.auth.callback.Callback; +import org.apache.kafka.common.KafkaException; + +/** + * ValidateException is thrown in cases where a JWT access token cannot be determined to be + * valid for one reason or another. It is intended to be used when errors arise within the + * processing of a {@link javax.security.auth.callback.CallbackHandler#handle(Callback[])}. + * This error, however, is not thrown from that method directly. + * + * @see AccessTokenValidator#validate(String) + */ + +public class ValidateException extends KafkaException { + + public ValidateException(String message) { + super(message); + } + + public ValidateException(Throwable cause) { + super(cause); + } + + public ValidateException(String message, Throwable cause) { + super(message, cause); + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/secured/ValidatorAccessTokenValidator.java b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/secured/ValidatorAccessTokenValidator.java new file mode 100644 index 0000000..7668438 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/secured/ValidatorAccessTokenValidator.java @@ -0,0 +1,210 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.common.security.oauthbearer.secured; + +import static org.jose4j.jwa.AlgorithmConstraints.DISALLOW_NONE; + +import java.util.Collection; +import java.util.Collections; +import java.util.Set; +import org.apache.kafka.common.security.oauthbearer.OAuthBearerToken; +import org.jose4j.jwt.JwtClaims; +import org.jose4j.jwt.MalformedClaimException; +import org.jose4j.jwt.NumericDate; +import org.jose4j.jwt.ReservedClaimNames; +import org.jose4j.jwt.consumer.InvalidJwtException; +import org.jose4j.jwt.consumer.JwtConsumer; +import org.jose4j.jwt.consumer.JwtConsumerBuilder; +import org.jose4j.jwt.consumer.JwtContext; +import org.jose4j.keys.resolvers.VerificationKeyResolver; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * ValidatorAccessTokenValidator is an implementation of {@link AccessTokenValidator} that is used + * by the broker to perform more extensive validation of the JWT access token that is received + * from the client, but ultimately from posting the client credentials to the OAuth/OIDC provider's + * token endpoint. + * + * The validation steps performed (primary by the jose4j library) are: + * + *
+ * <ol>
+ *   <li>Basic structural validation of the <code>b64token</code> value as defined in
+ *   RFC 6750 Section 2.1</li>
+ *   <li>Basic conversion of the token into an in-memory data structure</li>
+ *   <li>Presence of <code>scope</code>, <code>exp</code>, subject, <code>iss</code>, and
+ *   <code>iat</code> claims</li>
+ *   <li>Signature matching validation against the <code>kid</code> and those provided by
+ *   the OAuth/OIDC provider's JWKS</li>
+ * </ol>
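+ *
+ * <p>
+ * For illustration only, a hypothetical decoded JWT payload that would satisfy the above checks;
+ * all claim values are made up:
+ * </p>
+ *
+ * <pre>
+ * {
+ *   "iss": "https://example.com/oauth2/default",
+ *   "aud": "kafka-cluster",
+ *   "sub": "kafka-client",
+ *   "iat": 1640991600,
+ *   "exp": 1640995200,
+ *   "scope": "kafka"
+ * }
+ * </pre>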
            + */ + +public class ValidatorAccessTokenValidator implements AccessTokenValidator { + + private static final Logger log = LoggerFactory.getLogger(ValidatorAccessTokenValidator.class); + + private final JwtConsumer jwtConsumer; + + private final String scopeClaimName; + + private final String subClaimName; + + /** + * Creates a new ValidatorAccessTokenValidator that will be used by the broker for more + * thorough validation of the JWT. + * + * @param clockSkew The optional value (in seconds) to allow for differences + * between the time of the OAuth/OIDC identity provider and + * the broker. If null is provided, the broker + * and the OAUth/OIDC identity provider are assumed to have + * very close clock settings. + * @param expectedAudiences The (optional) set the broker will use to verify that + * the JWT was issued for one of the expected audiences. + * The JWT will be inspected for the standard OAuth + * aud claim and if this value is set, the + * broker will match the value from JWT's aud + * claim to see if there is an exact match. If there is no + * match, the broker will reject the JWT and authentication + * will fail. May be null to not perform any + * check to verify the JWT's aud claim matches any + * fixed set of known/expected audiences. + * @param expectedIssuer The (optional) value for the broker to use to verify that + * the JWT was created by the expected issuer. The JWT will + * be inspected for the standard OAuth iss claim + * and if this value is set, the broker will match it + * exactly against what is in the JWT's iss + * claim. If there is no match, the broker will reject the JWT + * and authentication will fail. May be null to not + * perform any check to verify the JWT's iss claim + * matches a specific issuer. + * @param verificationKeyResolver jose4j-based {@link VerificationKeyResolver} that is used + * to validate the signature matches the contents of the header + * and payload + * @param scopeClaimName Name of the scope claim to use; must be non-null + * @param subClaimName Name of the subject claim to use; must be + * non-null + * + * @see JwtConsumerBuilder + * @see JwtConsumer + * @see VerificationKeyResolver + */ + + public ValidatorAccessTokenValidator(Integer clockSkew, + Set expectedAudiences, + String expectedIssuer, + VerificationKeyResolver verificationKeyResolver, + String scopeClaimName, + String subClaimName) { + final JwtConsumerBuilder jwtConsumerBuilder = new JwtConsumerBuilder(); + + if (clockSkew != null) + jwtConsumerBuilder.setAllowedClockSkewInSeconds(clockSkew); + + if (expectedAudiences != null && !expectedAudiences.isEmpty()) + jwtConsumerBuilder.setExpectedAudience(expectedAudiences.toArray(new String[0])); + + if (expectedIssuer != null) + jwtConsumerBuilder.setExpectedIssuer(expectedIssuer); + + this.jwtConsumer = jwtConsumerBuilder + .setJwsAlgorithmConstraints(DISALLOW_NONE) + .setRequireExpirationTime() + .setRequireIssuedAt() + .setRequireSubject() + .setVerificationKeyResolver(verificationKeyResolver) + .build(); + this.scopeClaimName = scopeClaimName; + this.subClaimName = subClaimName; + } + + /** + * Accepts an OAuth JWT access token in base-64 encoded format, validates, and returns an + * OAuthBearerToken. 
+ * + * @param accessToken Non-null JWT access token + * @return {@link OAuthBearerToken} + * @throws ValidateException Thrown on errors performing validation of given token + */ + + @SuppressWarnings("unchecked") + public OAuthBearerToken validate(String accessToken) throws ValidateException { + SerializedJwt serializedJwt = new SerializedJwt(accessToken); + + JwtContext jwt; + + try { + jwt = jwtConsumer.process(serializedJwt.getToken()); + } catch (InvalidJwtException e) { + throw new ValidateException(String.format("Could not validate the access token: %s", e.getMessage()), e); + } + + JwtClaims claims = jwt.getJwtClaims(); + + Object scopeRaw = getClaim(() -> claims.getClaimValue(scopeClaimName), scopeClaimName); + Collection scopeRawCollection; + + if (scopeRaw instanceof String) + scopeRawCollection = Collections.singletonList((String) scopeRaw); + else if (scopeRaw instanceof Collection) + scopeRawCollection = (Collection) scopeRaw; + else + scopeRawCollection = Collections.emptySet(); + + NumericDate expirationRaw = getClaim(claims::getExpirationTime, ReservedClaimNames.EXPIRATION_TIME); + String subRaw = getClaim(() -> claims.getStringClaimValue(subClaimName), subClaimName); + NumericDate issuedAtRaw = getClaim(claims::getIssuedAt, ReservedClaimNames.ISSUED_AT); + + Set scopes = ClaimValidationUtils.validateScopes(scopeClaimName, scopeRawCollection); + long expiration = ClaimValidationUtils.validateExpiration(ReservedClaimNames.EXPIRATION_TIME, + expirationRaw != null ? expirationRaw.getValueInMillis() : null); + String sub = ClaimValidationUtils.validateSubject(subClaimName, subRaw); + Long issuedAt = ClaimValidationUtils.validateIssuedAt(ReservedClaimNames.ISSUED_AT, + issuedAtRaw != null ? issuedAtRaw.getValueInMillis() : null); + + OAuthBearerToken token = new BasicOAuthBearerToken(accessToken, + scopes, + expiration, + sub, + issuedAt); + + return token; + } + + private T getClaim(ClaimSupplier supplier, String claimName) throws ValidateException { + try { + T value = supplier.get(); + log.debug("getClaim - {}: {}", claimName, value); + return value; + } catch (MalformedClaimException e) { + throw new ValidateException(String.format("Could not extract the '%s' claim from the access token", claimName), e); + } + } + + public interface ClaimSupplier { + + T get() throws MalformedClaimException; + + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/secured/VerificationKeyResolverFactory.java b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/secured/VerificationKeyResolverFactory.java new file mode 100644 index 0000000..b6ec46a --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/secured/VerificationKeyResolverFactory.java @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.common.security.oauthbearer.secured; + +import static org.apache.kafka.common.config.SaslConfigs.SASL_OAUTHBEARER_JWKS_ENDPOINT_REFRESH_MS; +import static org.apache.kafka.common.config.SaslConfigs.SASL_OAUTHBEARER_JWKS_ENDPOINT_RETRY_BACKOFF_MAX_MS; +import static org.apache.kafka.common.config.SaslConfigs.SASL_OAUTHBEARER_JWKS_ENDPOINT_RETRY_BACKOFF_MS; +import static org.apache.kafka.common.config.SaslConfigs.SASL_OAUTHBEARER_JWKS_ENDPOINT_URL; + +import java.net.URL; +import java.nio.file.Path; +import java.util.Locale; +import java.util.Map; +import javax.net.ssl.SSLSocketFactory; +import org.apache.kafka.common.utils.Time; +import org.jose4j.http.Get; +import org.jose4j.jwk.HttpsJwks; + +public class VerificationKeyResolverFactory { + + /** + * Create an {@link AccessTokenRetriever} from the given + * {@link org.apache.kafka.common.config.SaslConfigs}. + * + * Note: the returned CloseableVerificationKeyResolver is not + * initialized here and must be done by the caller. + * + * Primarily exposed here for unit testing. + * + * @param configs SASL configuration + * + * @return Non-null {@link CloseableVerificationKeyResolver} + */ + public static CloseableVerificationKeyResolver create(Map configs, + Map jaasConfig) { + return create(configs, null, jaasConfig); + } + + public static CloseableVerificationKeyResolver create(Map configs, + String saslMechanism, + Map jaasConfig) { + ConfigurationUtils cu = new ConfigurationUtils(configs, saslMechanism); + URL jwksEndpointUrl = cu.validateUrl(SASL_OAUTHBEARER_JWKS_ENDPOINT_URL); + + if (jwksEndpointUrl.getProtocol().toLowerCase(Locale.ROOT).equals("file")) { + Path p = cu.validateFile(SASL_OAUTHBEARER_JWKS_ENDPOINT_URL); + return new JwksFileVerificationKeyResolver(p); + } else { + long refreshIntervalMs = cu.validateLong(SASL_OAUTHBEARER_JWKS_ENDPOINT_REFRESH_MS, true, 0L); + JaasOptionsUtils jou = new JaasOptionsUtils(jaasConfig); + SSLSocketFactory sslSocketFactory = null; + + if (jou.shouldCreateSSLSocketFactory(jwksEndpointUrl)) + sslSocketFactory = jou.createSSLSocketFactory(); + + HttpsJwks httpsJwks = new HttpsJwks(jwksEndpointUrl.toString()); + httpsJwks.setDefaultCacheDuration(refreshIntervalMs); + + if (sslSocketFactory != null) { + Get get = new Get(); + get.setSslSocketFactory(sslSocketFactory); + httpsJwks.setSimpleHttpGet(get); + } + + RefreshingHttpsJwks refreshingHttpsJwks = new RefreshingHttpsJwks(Time.SYSTEM, + httpsJwks, + refreshIntervalMs, + cu.validateLong(SASL_OAUTHBEARER_JWKS_ENDPOINT_RETRY_BACKOFF_MS), + cu.validateLong(SASL_OAUTHBEARER_JWKS_ENDPOINT_RETRY_BACKOFF_MAX_MS)); + return new RefreshingHttpsJwksVerificationKeyResolver(refreshingHttpsJwks); + } + } + +} \ No newline at end of file diff --git a/clients/src/main/java/org/apache/kafka/common/security/plain/PlainAuthenticateCallback.java b/clients/src/main/java/org/apache/kafka/common/security/plain/PlainAuthenticateCallback.java new file mode 100644 index 0000000..7f42645 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/security/plain/PlainAuthenticateCallback.java @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.common.security.plain; + +import javax.security.auth.callback.Callback; + +/* + * Authentication callback for SASL/PLAIN authentication. Callback handler must + * set authenticated flag to true if the client provided password in the callback + * matches the expected password. + */ +public class PlainAuthenticateCallback implements Callback { + private final char[] password; + private boolean authenticated; + + /** + * Creates a callback with the password provided by the client + * @param password The password provided by the client during SASL/PLAIN authentication + */ + public PlainAuthenticateCallback(char[] password) { + this.password = password; + } + + /** + * Returns the password provided by the client during SASL/PLAIN authentication + */ + public char[] password() { + return password; + } + + /** + * Returns true if client password matches expected password, false otherwise. + * This state is set the server-side callback handler. + */ + public boolean authenticated() { + return this.authenticated; + } + + /** + * Sets the authenticated state. This is set by the server-side callback handler + * by matching the client provided password with the expected password. + * + * @param authenticated true indicates successful authentication + */ + public void authenticated(boolean authenticated) { + this.authenticated = authenticated; + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/security/plain/PlainLoginModule.java b/clients/src/main/java/org/apache/kafka/common/security/plain/PlainLoginModule.java new file mode 100644 index 0000000..4085168 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/security/plain/PlainLoginModule.java @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.security.plain; + +import org.apache.kafka.common.security.plain.internals.PlainSaslServerProvider; + +import java.util.Map; + +import javax.security.auth.Subject; +import javax.security.auth.callback.CallbackHandler; +import javax.security.auth.spi.LoginModule; + +public class PlainLoginModule implements LoginModule { + + private static final String USERNAME_CONFIG = "username"; + private static final String PASSWORD_CONFIG = "password"; + + static { + PlainSaslServerProvider.initialize(); + } + + @Override + public void initialize(Subject subject, CallbackHandler callbackHandler, Map sharedState, Map options) { + String username = (String) options.get(USERNAME_CONFIG); + if (username != null) + subject.getPublicCredentials().add(username); + String password = (String) options.get(PASSWORD_CONFIG); + if (password != null) + subject.getPrivateCredentials().add(password); + } + + @Override + public boolean login() { + return true; + } + + @Override + public boolean logout() { + return true; + } + + @Override + public boolean commit() { + return true; + } + + @Override + public boolean abort() { + return false; + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/security/plain/internals/PlainSaslServer.java b/clients/src/main/java/org/apache/kafka/common/security/plain/internals/PlainSaslServer.java new file mode 100644 index 0000000..7a65fe6 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/security/plain/internals/PlainSaslServer.java @@ -0,0 +1,201 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.security.plain.internals; + +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Map; + +import javax.security.auth.callback.Callback; +import javax.security.auth.callback.CallbackHandler; +import javax.security.auth.callback.NameCallback; +import javax.security.sasl.Sasl; +import javax.security.sasl.SaslException; +import javax.security.sasl.SaslServer; +import javax.security.sasl.SaslServerFactory; + +import org.apache.kafka.common.errors.SaslAuthenticationException; +import org.apache.kafka.common.security.plain.PlainAuthenticateCallback; + +/** + * Simple SaslServer implementation for SASL/PLAIN. In order to make this implementation + * fully pluggable, authentication of username/password is fully contained within the + * server implementation. + *

+ * Valid users with passwords are specified in the JAAS configuration file. Each user + * is specified with user_<username> as key and <password> as value. This is consistent + * with the ZooKeeper Digest-MD5 implementation. + *
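A hedged illustration (not part of this patch) of the JAAS wiring described above: the listener name, user names and passwords are made-up examples, and the sasl.jaas.config / listener.name.* property names are assumed to be the standard Kafka client and broker settings.

import java.util.Properties;

// Sketch: how the options read by PlainLoginModule (username/password) and the
// user_<username> entries consulted by the server callback handler are usually supplied.
public class PlainJaasConfigExample {
    public static void main(String[] args) {
        Properties brokerProps = new Properties();
        brokerProps.put("listener.name.sasl_plaintext.plain.sasl.jaas.config",
                "org.apache.kafka.common.security.plain.PlainLoginModule required "
                + "username=\"admin\" password=\"admin-secret\" "
                + "user_admin=\"admin-secret\" "
                + "user_alice=\"alice-secret\";");

        Properties clientProps = new Properties();
        clientProps.put("sasl.jaas.config",
                "org.apache.kafka.common.security.plain.PlainLoginModule required "
                + "username=\"alice\" password=\"alice-secret\";");

        System.out.println(brokerProps);
        System.out.println(clientProps);
    }
}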

            + * To avoid storing clear passwords on disk or to integrate with external authentication + * servers in production systems, this module can be replaced with a different implementation. + * + */ +public class PlainSaslServer implements SaslServer { + + public static final String PLAIN_MECHANISM = "PLAIN"; + + private final CallbackHandler callbackHandler; + private boolean complete; + private String authorizationId; + + public PlainSaslServer(CallbackHandler callbackHandler) { + this.callbackHandler = callbackHandler; + } + + /** + * @throws SaslAuthenticationException if username/password combination is invalid or if the requested + * authorization id is not the same as username. + *

            + * Note: This method may throw {@link SaslAuthenticationException} to provide custom error messages + * to clients. But care should be taken to avoid including any information in the exception message that + * should not be leaked to unauthenticated clients. It may be safer to throw {@link SaslException} in + * some cases so that a standard error message is returned to clients. + *

            + */ + @Override + public byte[] evaluateResponse(byte[] responseBytes) throws SaslAuthenticationException { + /* + * Message format (from https://tools.ietf.org/html/rfc4616): + * + * message = [authzid] UTF8NUL authcid UTF8NUL passwd + * authcid = 1*SAFE ; MUST accept up to 255 octets + * authzid = 1*SAFE ; MUST accept up to 255 octets + * passwd = 1*SAFE ; MUST accept up to 255 octets + * UTF8NUL = %x00 ; UTF-8 encoded NUL character + * + * SAFE = UTF1 / UTF2 / UTF3 / UTF4 + * ;; any UTF-8 encoded Unicode character except NUL + */ + + String response = new String(responseBytes, StandardCharsets.UTF_8); + List tokens = extractTokens(response); + String authorizationIdFromClient = tokens.get(0); + String username = tokens.get(1); + String password = tokens.get(2); + + if (username.isEmpty()) { + throw new SaslAuthenticationException("Authentication failed: username not specified"); + } + if (password.isEmpty()) { + throw new SaslAuthenticationException("Authentication failed: password not specified"); + } + + NameCallback nameCallback = new NameCallback("username", username); + PlainAuthenticateCallback authenticateCallback = new PlainAuthenticateCallback(password.toCharArray()); + try { + callbackHandler.handle(new Callback[]{nameCallback, authenticateCallback}); + } catch (Throwable e) { + throw new SaslAuthenticationException("Authentication failed: credentials for user could not be verified", e); + } + if (!authenticateCallback.authenticated()) + throw new SaslAuthenticationException("Authentication failed: Invalid username or password"); + if (!authorizationIdFromClient.isEmpty() && !authorizationIdFromClient.equals(username)) + throw new SaslAuthenticationException("Authentication failed: Client requested an authorization id that is different from username"); + + this.authorizationId = username; + + complete = true; + return new byte[0]; + } + + private List extractTokens(String string) { + List tokens = new ArrayList<>(); + int startIndex = 0; + for (int i = 0; i < 4; ++i) { + int endIndex = string.indexOf("\u0000", startIndex); + if (endIndex == -1) { + tokens.add(string.substring(startIndex)); + break; + } + tokens.add(string.substring(startIndex, endIndex)); + startIndex = endIndex + 1; + } + + if (tokens.size() != 3) + throw new SaslAuthenticationException("Invalid SASL/PLAIN response: expected 3 tokens, got " + + tokens.size()); + + return tokens; + } + + @Override + public String getAuthorizationID() { + if (!complete) + throw new IllegalStateException("Authentication exchange has not completed"); + return authorizationId; + } + + @Override + public String getMechanismName() { + return PLAIN_MECHANISM; + } + + @Override + public Object getNegotiatedProperty(String propName) { + if (!complete) + throw new IllegalStateException("Authentication exchange has not completed"); + return null; + } + + @Override + public boolean isComplete() { + return complete; + } + + @Override + public byte[] unwrap(byte[] incoming, int offset, int len) { + if (!complete) + throw new IllegalStateException("Authentication exchange has not completed"); + return Arrays.copyOfRange(incoming, offset, offset + len); + } + + @Override + public byte[] wrap(byte[] outgoing, int offset, int len) { + if (!complete) + throw new IllegalStateException("Authentication exchange has not completed"); + return Arrays.copyOfRange(outgoing, offset, offset + len); + } + + @Override + public void dispose() { + } + + public static class PlainSaslServerFactory implements SaslServerFactory { + + @Override + 
public SaslServer createSaslServer(String mechanism, String protocol, String serverName, Map props, CallbackHandler cbh) + throws SaslException { + + if (!PLAIN_MECHANISM.equals(mechanism)) + throw new SaslException(String.format("Mechanism \'%s\' is not supported. Only PLAIN is supported.", mechanism)); + + return new PlainSaslServer(cbh); + } + + @Override + public String[] getMechanismNames(Map props) { + if (props == null) return new String[]{PLAIN_MECHANISM}; + String noPlainText = (String) props.get(Sasl.POLICY_NOPLAINTEXT); + if ("true".equals(noPlainText)) + return new String[]{}; + else + return new String[]{PLAIN_MECHANISM}; + } + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/security/plain/internals/PlainSaslServerProvider.java b/clients/src/main/java/org/apache/kafka/common/security/plain/internals/PlainSaslServerProvider.java new file mode 100644 index 0000000..33c3dc5 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/security/plain/internals/PlainSaslServerProvider.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.security.plain.internals; + +import java.security.Provider; +import java.security.Security; + +import org.apache.kafka.common.security.plain.internals.PlainSaslServer.PlainSaslServerFactory; + +public class PlainSaslServerProvider extends Provider { + + private static final long serialVersionUID = 1L; + + @SuppressWarnings("deprecation") + protected PlainSaslServerProvider() { + super("Simple SASL/PLAIN Server Provider", 1.0, "Simple SASL/PLAIN Server Provider for Kafka"); + put("SaslServerFactory." + PlainSaslServer.PLAIN_MECHANISM, PlainSaslServerFactory.class.getName()); + } + + public static void initialize() { + Security.addProvider(new PlainSaslServerProvider()); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/security/plain/internals/PlainServerCallbackHandler.java b/clients/src/main/java/org/apache/kafka/common/security/plain/internals/PlainServerCallbackHandler.java new file mode 100644 index 0000000..10f5817 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/security/plain/internals/PlainServerCallbackHandler.java @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
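A minimal sketch (not part of this patch) of the RFC 4616 initial response that PlainSaslServer.evaluateResponse() parses above, laid out as [authzid] NUL authcid NUL passwd; the user name and password are hypothetical.

import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;

public class PlainResponseFormatExample {
    public static void main(String[] args) {
        String authzid = "";              // empty authzid: act as the authenticated user
        String authcid = "alice";         // authentication id, matched against user_alice
        String password = "alice-secret";

        byte[] response = (authzid + "\u0000" + authcid + "\u0000" + password)
                .getBytes(StandardCharsets.UTF_8);

        // Splitting on the NUL separators recovers the three tokens, as extractTokens() does.
        String message = new String(response, StandardCharsets.UTF_8);
        List<String> tokens = new ArrayList<>();
        int start = 0;
        int end;
        while ((end = message.indexOf('\u0000', start)) != -1) {
            tokens.add(message.substring(start, end));
            start = end + 1;
        }
        tokens.add(message.substring(start));
        System.out.println(tokens); // [, alice, alice-secret]
    }
}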
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.common.security.plain.internals; + +import org.apache.kafka.common.security.JaasContext; +import org.apache.kafka.common.security.auth.AuthenticateCallbackHandler; +import org.apache.kafka.common.KafkaException; +import org.apache.kafka.common.security.plain.PlainAuthenticateCallback; +import org.apache.kafka.common.security.plain.PlainLoginModule; +import org.apache.kafka.common.utils.Utils; + +import java.io.IOException; +import java.util.List; +import java.util.Map; + +import javax.security.auth.callback.Callback; +import javax.security.auth.callback.NameCallback; +import javax.security.auth.callback.UnsupportedCallbackException; +import javax.security.auth.login.AppConfigurationEntry; + +public class PlainServerCallbackHandler implements AuthenticateCallbackHandler { + + private static final String JAAS_USER_PREFIX = "user_"; + private List jaasConfigEntries; + + @Override + public void configure(Map configs, String mechanism, List jaasConfigEntries) { + this.jaasConfigEntries = jaasConfigEntries; + } + + @Override + public void handle(Callback[] callbacks) throws IOException, UnsupportedCallbackException { + String username = null; + for (Callback callback: callbacks) { + if (callback instanceof NameCallback) + username = ((NameCallback) callback).getDefaultName(); + else if (callback instanceof PlainAuthenticateCallback) { + PlainAuthenticateCallback plainCallback = (PlainAuthenticateCallback) callback; + boolean authenticated = authenticate(username, plainCallback.password()); + plainCallback.authenticated(authenticated); + } else + throw new UnsupportedCallbackException(callback); + } + } + + protected boolean authenticate(String username, char[] password) throws IOException { + if (username == null) + return false; + else { + String expectedPassword = JaasContext.configEntryOption(jaasConfigEntries, + JAAS_USER_PREFIX + username, + PlainLoginModule.class.getName()); + return expectedPassword != null && Utils.isEqualConstantTime(password, expectedPassword.toCharArray()); + } + } + + @Override + public void close() throws KafkaException { + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/security/scram/ScramCredential.java b/clients/src/main/java/org/apache/kafka/common/security/scram/ScramCredential.java new file mode 100644 index 0000000..dfbfef1 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/security/scram/ScramCredential.java @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
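The PlainSaslServer Javadoc notes that the JAAS-file password store can be replaced; one possible approach (an assumption, not something this patch prescribes) is to subclass the PlainServerCallbackHandler above and override its protected authenticate() hook. The in-memory map below stands in for a hypothetical external credential store.

import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.Map;

import org.apache.kafka.common.security.plain.internals.PlainServerCallbackHandler;

public class ExternalStorePlainCallbackHandler extends PlainServerCallbackHandler {

    // Hypothetical credential source; a real implementation would query an external system.
    private static final Map<String, char[]> STORE =
            Collections.singletonMap("alice", "alice-secret".toCharArray());

    @Override
    protected boolean authenticate(String username, char[] password) throws IOException {
        if (username == null || password == null)
            return false;
        char[] expected = STORE.get(username);
        // Production code should compare in constant time, as the default handler does
        // with Utils.isEqualConstantTime.
        return expected != null && Arrays.equals(expected, password);
    }
}

Such a handler would then be wired in on the broker through the SASL server callback handler class setting for the PLAIN mechanism, assuming the standard broker configuration is used.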
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.security.scram; + +/** + * SCRAM credential class that encapsulates the credential data persisted for each user that is + * accessible to the server. See RFC rfc5802 + * for details. + */ +public class ScramCredential { + + private final byte[] salt; + private final byte[] serverKey; + private final byte[] storedKey; + private final int iterations; + + /** + * Constructs a new credential. + */ + public ScramCredential(byte[] salt, byte[] storedKey, byte[] serverKey, int iterations) { + this.salt = salt; + this.serverKey = serverKey; + this.storedKey = storedKey; + this.iterations = iterations; + } + + /** + * Returns the salt used to process this credential using the SCRAM algorithm. + */ + public byte[] salt() { + return salt; + } + + /** + * Server key computed from the client password using the SCRAM algorithm. + */ + public byte[] serverKey() { + return serverKey; + } + + /** + * Stored key computed from the client password using the SCRAM algorithm. + */ + public byte[] storedKey() { + return storedKey; + } + + /** + * Number of iterations used to process this credential using the SCRAM algorithm. + */ + public int iterations() { + return iterations; + } +} \ No newline at end of file diff --git a/clients/src/main/java/org/apache/kafka/common/security/scram/ScramCredentialCallback.java b/clients/src/main/java/org/apache/kafka/common/security/scram/ScramCredentialCallback.java new file mode 100644 index 0000000..d5988cb --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/security/scram/ScramCredentialCallback.java @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.security.scram; + +import javax.security.auth.callback.Callback; + +/** + * Callback used for SCRAM mechanisms. + */ +public class ScramCredentialCallback implements Callback { + private ScramCredential scramCredential; + + /** + * Sets the SCRAM credential for this instance. + */ + public void scramCredential(ScramCredential scramCredential) { + this.scramCredential = scramCredential; + } + + /** + * Returns the SCRAM credential if set on this instance. 
+ */ + public ScramCredential scramCredential() { + return scramCredential; + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/security/scram/ScramExtensionsCallback.java b/clients/src/main/java/org/apache/kafka/common/security/scram/ScramExtensionsCallback.java new file mode 100644 index 0000000..b83c94e --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/security/scram/ScramExtensionsCallback.java @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.security.scram; + +import javax.security.auth.callback.Callback; +import java.util.Collections; +import java.util.Map; + + +/** + * Optional callback used for SCRAM mechanisms if any extensions need to be set + * in the SASL/SCRAM exchange. + */ +public class ScramExtensionsCallback implements Callback { + private Map extensions = Collections.emptyMap(); + + /** + * Returns map of the extension names and values that are sent by the client to + * the server in the initial client SCRAM authentication message. + * Default is an empty unmodifiable map. + */ + public Map extensions() { + return extensions; + } + + /** + * Sets the SCRAM extensions on this callback. Maps passed in should be unmodifiable + */ + public void extensions(Map extensions) { + this.extensions = extensions; + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/security/scram/ScramLoginModule.java b/clients/src/main/java/org/apache/kafka/common/security/scram/ScramLoginModule.java new file mode 100644 index 0000000..104b3fc --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/security/scram/ScramLoginModule.java @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
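Before the ScramLoginModule added below, a hedged sketch (not part of this patch) of how its username, password and tokenauth options are typically supplied on the client side; the mechanism, credentials and property names other than tokenauth are assumptions based on standard Kafka client settings.

import java.util.Properties;

public class ScramJaasConfigExample {
    public static void main(String[] args) {
        Properties clientProps = new Properties();
        clientProps.put("security.protocol", "SASL_SSL");
        clientProps.put("sasl.mechanism", "SCRAM-SHA-256");
        clientProps.put("sasl.jaas.config",
                "org.apache.kafka.common.security.scram.ScramLoginModule required "
                + "username=\"alice\" password=\"alice-secret\";");

        // For delegation-token authentication the token id and HMAC would be used as the
        // credentials and tokenauth=true added, which ScramLoginModule turns into a
        // SCRAM extension:
        //   ... ScramLoginModule required username="<tokenId>" password="<tokenHmac>" tokenauth=true;
        System.out.println(clientProps);
    }
}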
+ */ +package org.apache.kafka.common.security.scram; + +import org.apache.kafka.common.security.scram.internals.ScramSaslClientProvider; +import org.apache.kafka.common.security.scram.internals.ScramSaslServerProvider; + +import java.util.Collections; +import java.util.Map; + +import javax.security.auth.Subject; +import javax.security.auth.callback.CallbackHandler; +import javax.security.auth.spi.LoginModule; + +public class ScramLoginModule implements LoginModule { + + private static final String USERNAME_CONFIG = "username"; + private static final String PASSWORD_CONFIG = "password"; + public static final String TOKEN_AUTH_CONFIG = "tokenauth"; + + static { + ScramSaslClientProvider.initialize(); + ScramSaslServerProvider.initialize(); + } + + @Override + public void initialize(Subject subject, CallbackHandler callbackHandler, Map sharedState, Map options) { + String username = (String) options.get(USERNAME_CONFIG); + if (username != null) + subject.getPublicCredentials().add(username); + String password = (String) options.get(PASSWORD_CONFIG); + if (password != null) + subject.getPrivateCredentials().add(password); + + Boolean useTokenAuthentication = "true".equalsIgnoreCase((String) options.get(TOKEN_AUTH_CONFIG)); + if (useTokenAuthentication) { + Map scramExtensions = Collections.singletonMap(TOKEN_AUTH_CONFIG, "true"); + subject.getPublicCredentials().add(scramExtensions); + } + } + + @Override + public boolean login() { + return true; + } + + @Override + public boolean logout() { + return true; + } + + @Override + public boolean commit() { + return true; + } + + @Override + public boolean abort() { + return false; + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/security/scram/internals/ScramCredentialUtils.java b/clients/src/main/java/org/apache/kafka/common/security/scram/internals/ScramCredentialUtils.java new file mode 100644 index 0000000..0ce51a5 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/security/scram/internals/ScramCredentialUtils.java @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.security.scram.internals; + +import java.util.Base64; +import java.util.Collection; +import java.util.Properties; + +import org.apache.kafka.common.security.authenticator.CredentialCache; +import org.apache.kafka.common.security.scram.ScramCredential; + +/** + * SCRAM Credential persistence utility functions. Implements format conversion used + * for the credential store implemented in Kafka. Credentials are persisted as a comma-separated + * String of key-value pairs: + *
+ *   salt=<salt>,stored_key=<stored_key>,server_key=<server_key>,iterations=<iterations>
            + * 
            + * + */ +public final class ScramCredentialUtils { + private static final String SALT = "salt"; + private static final String STORED_KEY = "stored_key"; + private static final String SERVER_KEY = "server_key"; + private static final String ITERATIONS = "iterations"; + + private ScramCredentialUtils() {} + + public static String credentialToString(ScramCredential credential) { + return String.format("%s=%s,%s=%s,%s=%s,%s=%d", + SALT, + Base64.getEncoder().encodeToString(credential.salt()), + STORED_KEY, + Base64.getEncoder().encodeToString(credential.storedKey()), + SERVER_KEY, + Base64.getEncoder().encodeToString(credential.serverKey()), + ITERATIONS, + credential.iterations()); + } + + public static ScramCredential credentialFromString(String str) { + Properties props = toProps(str); + if (props.size() != 4 || !props.containsKey(SALT) || !props.containsKey(STORED_KEY) || + !props.containsKey(SERVER_KEY) || !props.containsKey(ITERATIONS)) { + throw new IllegalArgumentException("Credentials not valid: " + str); + } + byte[] salt = Base64.getDecoder().decode(props.getProperty(SALT)); + byte[] storedKey = Base64.getDecoder().decode(props.getProperty(STORED_KEY)); + byte[] serverKey = Base64.getDecoder().decode(props.getProperty(SERVER_KEY)); + int iterations = Integer.parseInt(props.getProperty(ITERATIONS)); + return new ScramCredential(salt, storedKey, serverKey, iterations); + } + + private static Properties toProps(String str) { + Properties props = new Properties(); + String[] tokens = str.split(","); + for (String token : tokens) { + int index = token.indexOf('='); + if (index <= 0) + throw new IllegalArgumentException("Credentials not valid: " + str); + props.put(token.substring(0, index), token.substring(index + 1)); + } + return props; + } + + public static void createCache(CredentialCache cache, Collection mechanisms) { + for (String mechanism : ScramMechanism.mechanismNames()) { + if (mechanisms.contains(mechanism)) + cache.createCache(mechanism, ScramCredential.class); + } + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/security/scram/internals/ScramExtensions.java b/clients/src/main/java/org/apache/kafka/common/security/scram/internals/ScramExtensions.java new file mode 100644 index 0000000..439bcfe --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/security/scram/internals/ScramExtensions.java @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
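A hedged sketch (not part of this patch) that round-trips a credential through the comma-separated persistence format implemented by ScramCredentialUtils above; the password and iteration count are made-up example values, and ScramFormatter is the class added later in this patch.

import org.apache.kafka.common.security.scram.ScramCredential;
import org.apache.kafka.common.security.scram.internals.ScramCredentialUtils;
import org.apache.kafka.common.security.scram.internals.ScramFormatter;
import org.apache.kafka.common.security.scram.internals.ScramMechanism;

public class ScramCredentialPersistenceExample {
    public static void main(String[] args) throws Exception {
        ScramFormatter formatter = new ScramFormatter(ScramMechanism.SCRAM_SHA_256);
        ScramCredential credential = formatter.generateCredential("alice-secret", 4096);

        // Serialized form, e.g. salt=...,stored_key=...,server_key=...,iterations=4096
        String persisted = ScramCredentialUtils.credentialToString(credential);
        System.out.println(persisted);

        // Parse it back and verify the iteration count survived the round trip.
        ScramCredential restored = ScramCredentialUtils.credentialFromString(persisted);
        System.out.println(restored.iterations()); // 4096
    }
}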
+ */ +package org.apache.kafka.common.security.scram.internals; + +import org.apache.kafka.common.security.auth.SaslExtensions; +import org.apache.kafka.common.security.scram.ScramLoginModule; +import org.apache.kafka.common.utils.Utils; + +import java.util.Collections; +import java.util.Map; + +public class ScramExtensions extends SaslExtensions { + + public ScramExtensions() { + this(Collections.emptyMap()); + } + + public ScramExtensions(String extensions) { + this(Utils.parseMap(extensions, "=", ",")); + } + + public ScramExtensions(Map extensionMap) { + super(extensionMap); + } + + public boolean tokenAuthenticated() { + return Boolean.parseBoolean(map().get(ScramLoginModule.TOKEN_AUTH_CONFIG)); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/security/scram/internals/ScramFormatter.java b/clients/src/main/java/org/apache/kafka/common/security/scram/internals/ScramFormatter.java new file mode 100644 index 0000000..4c03fa1 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/security/scram/internals/ScramFormatter.java @@ -0,0 +1,193 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.security.scram.internals; + +import java.math.BigInteger; +import java.nio.charset.StandardCharsets; +import java.security.InvalidKeyException; +import java.security.MessageDigest; +import java.security.NoSuchAlgorithmException; +import java.security.SecureRandom; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +import javax.crypto.Mac; +import javax.crypto.spec.SecretKeySpec; + +import org.apache.kafka.common.KafkaException; +import org.apache.kafka.common.security.scram.ScramCredential; +import org.apache.kafka.common.security.scram.internals.ScramMessages.ClientFinalMessage; +import org.apache.kafka.common.security.scram.internals.ScramMessages.ClientFirstMessage; +import org.apache.kafka.common.security.scram.internals.ScramMessages.ServerFirstMessage; + +/** + * Scram message salt and hash functions defined in RFC 5802. 
+ */ +public class ScramFormatter { + + private static final Pattern EQUAL = Pattern.compile("=", Pattern.LITERAL); + private static final Pattern COMMA = Pattern.compile(",", Pattern.LITERAL); + private static final Pattern EQUAL_TWO_C = Pattern.compile("=2C", Pattern.LITERAL); + private static final Pattern EQUAL_THREE_D = Pattern.compile("=3D", Pattern.LITERAL); + + private final MessageDigest messageDigest; + private final Mac mac; + private final SecureRandom random; + + public ScramFormatter(ScramMechanism mechanism) throws NoSuchAlgorithmException { + this.messageDigest = MessageDigest.getInstance(mechanism.hashAlgorithm()); + this.mac = Mac.getInstance(mechanism.macAlgorithm()); + this.random = new SecureRandom(); + } + + public byte[] hmac(byte[] key, byte[] bytes) throws InvalidKeyException { + mac.init(new SecretKeySpec(key, mac.getAlgorithm())); + return mac.doFinal(bytes); + } + + public byte[] hash(byte[] str) { + return messageDigest.digest(str); + } + + public static byte[] xor(byte[] first, byte[] second) { + if (first.length != second.length) + throw new IllegalArgumentException("Argument arrays must be of the same length"); + byte[] result = new byte[first.length]; + for (int i = 0; i < result.length; i++) + result[i] = (byte) (first[i] ^ second[i]); + return result; + } + + public byte[] hi(byte[] str, byte[] salt, int iterations) throws InvalidKeyException { + mac.init(new SecretKeySpec(str, mac.getAlgorithm())); + mac.update(salt); + byte[] u1 = mac.doFinal(new byte[]{0, 0, 0, 1}); + byte[] prev = u1; + byte[] result = u1; + for (int i = 2; i <= iterations; i++) { + byte[] ui = hmac(str, prev); + result = xor(result, ui); + prev = ui; + } + return result; + } + + public static byte[] normalize(String str) { + return toBytes(str); + } + + public byte[] saltedPassword(String password, byte[] salt, int iterations) throws InvalidKeyException { + return hi(normalize(password), salt, iterations); + } + + public byte[] clientKey(byte[] saltedPassword) throws InvalidKeyException { + return hmac(saltedPassword, toBytes("Client Key")); + } + + public byte[] storedKey(byte[] clientKey) { + return hash(clientKey); + } + + public static String saslName(String username) { + String replace1 = EQUAL.matcher(username).replaceAll(Matcher.quoteReplacement("=3D")); + return COMMA.matcher(replace1).replaceAll(Matcher.quoteReplacement("=2C")); + } + + public static String username(String saslName) { + String username = EQUAL_TWO_C.matcher(saslName).replaceAll(Matcher.quoteReplacement(",")); + if (EQUAL_THREE_D.matcher(username).replaceAll(Matcher.quoteReplacement("")).indexOf('=') >= 0) { + throw new IllegalArgumentException("Invalid username: " + saslName); + } + return EQUAL_THREE_D.matcher(username).replaceAll(Matcher.quoteReplacement("=")); + } + + public static String authMessage(String clientFirstMessageBare, String serverFirstMessage, String clientFinalMessageWithoutProof) { + return clientFirstMessageBare + "," + serverFirstMessage + "," + clientFinalMessageWithoutProof; + } + + public byte[] clientSignature(byte[] storedKey, ClientFirstMessage clientFirstMessage, ServerFirstMessage serverFirstMessage, ClientFinalMessage clientFinalMessage) throws InvalidKeyException { + byte[] authMessage = authMessage(clientFirstMessage, serverFirstMessage, clientFinalMessage); + return hmac(storedKey, authMessage); + } + + public byte[] clientProof(byte[] saltedPassword, ClientFirstMessage clientFirstMessage, ServerFirstMessage serverFirstMessage, ClientFinalMessage clientFinalMessage) throws 
InvalidKeyException { + byte[] clientKey = clientKey(saltedPassword); + byte[] storedKey = hash(clientKey); + byte[] clientSignature = hmac(storedKey, authMessage(clientFirstMessage, serverFirstMessage, clientFinalMessage)); + return xor(clientKey, clientSignature); + } + + private byte[] authMessage(ClientFirstMessage clientFirstMessage, ServerFirstMessage serverFirstMessage, ClientFinalMessage clientFinalMessage) { + return toBytes(authMessage(clientFirstMessage.clientFirstMessageBare(), + serverFirstMessage.toMessage(), + clientFinalMessage.clientFinalMessageWithoutProof())); + } + + public byte[] storedKey(byte[] clientSignature, byte[] clientProof) { + return hash(xor(clientSignature, clientProof)); + } + + public byte[] serverKey(byte[] saltedPassword) throws InvalidKeyException { + return hmac(saltedPassword, toBytes("Server Key")); + } + + public byte[] serverSignature(byte[] serverKey, ClientFirstMessage clientFirstMessage, ServerFirstMessage serverFirstMessage, ClientFinalMessage clientFinalMessage) throws InvalidKeyException { + byte[] authMessage = authMessage(clientFirstMessage, serverFirstMessage, clientFinalMessage); + return hmac(serverKey, authMessage); + } + + public String secureRandomString() { + return secureRandomString(random); + } + + public static String secureRandomString(SecureRandom random) { + return new BigInteger(130, random).toString(Character.MAX_RADIX); + } + + public byte[] secureRandomBytes() { + return secureRandomBytes(random); + } + + public static byte[] secureRandomBytes(SecureRandom random) { + return toBytes(secureRandomString(random)); + } + + public static byte[] toBytes(String str) { + return str.getBytes(StandardCharsets.UTF_8); + } + + public ScramCredential generateCredential(String password, int iterations) { + try { + byte[] salt = secureRandomBytes(); + byte[] saltedPassword = saltedPassword(password, salt, iterations); + return generateCredential(salt, saltedPassword, iterations); + } catch (InvalidKeyException e) { + throw new KafkaException("Could not create credential", e); + } + } + + public ScramCredential generateCredential(byte[] salt, byte[] saltedPassword, int iterations) { + try { + byte[] clientKey = clientKey(saltedPassword); + byte[] storedKey = storedKey(clientKey); + byte[] serverKey = serverKey(saltedPassword); + return new ScramCredential(salt, storedKey, serverKey, iterations); + } catch (InvalidKeyException e) { + throw new KafkaException("Could not create credential", e); + } + } +} \ No newline at end of file diff --git a/clients/src/main/java/org/apache/kafka/common/security/scram/internals/ScramMechanism.java b/clients/src/main/java/org/apache/kafka/common/security/scram/internals/ScramMechanism.java new file mode 100644 index 0000000..9f6e69d --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/security/scram/internals/ScramMechanism.java @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
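A hedged sketch (not part of this patch) of the RFC 5802 key-derivation chain that the ScramFormatter above implements: SaltedPassword = Hi(password, salt, i), ClientKey = HMAC(SaltedPassword, "Client Key"), StoredKey = H(ClientKey) and ServerKey = HMAC(SaltedPassword, "Server Key"); the password, salt and iteration count are made-up inputs.

import java.nio.charset.StandardCharsets;

import org.apache.kafka.common.security.scram.internals.ScramFormatter;
import org.apache.kafka.common.security.scram.internals.ScramMechanism;

public class ScramDerivationExample {
    public static void main(String[] args) throws Exception {
        ScramFormatter formatter = new ScramFormatter(ScramMechanism.SCRAM_SHA_512);
        byte[] salt = "example-salt".getBytes(StandardCharsets.UTF_8);

        byte[] saltedPassword = formatter.saltedPassword("alice-secret", salt, 4096);
        byte[] clientKey = formatter.clientKey(saltedPassword);
        byte[] storedKey = formatter.storedKey(clientKey);
        byte[] serverKey = formatter.serverKey(saltedPassword);

        // Only salt, storedKey, serverKey and the iteration count need to be persisted;
        // the server can later verify a client proof without the clear-text password.
        System.out.printf("storedKey=%d bytes, serverKey=%d bytes%n",
                storedKey.length, serverKey.length);
    }
}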
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.security.scram.internals; + +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +public enum ScramMechanism { + + SCRAM_SHA_256("SHA-256", "HmacSHA256", 4096), + SCRAM_SHA_512("SHA-512", "HmacSHA512", 4096); + + private final String mechanismName; + private final String hashAlgorithm; + private final String macAlgorithm; + private final int minIterations; + + private static final Map MECHANISMS_MAP; + + static { + Map map = new HashMap<>(); + for (ScramMechanism mech : values()) + map.put(mech.mechanismName, mech); + MECHANISMS_MAP = Collections.unmodifiableMap(map); + } + + ScramMechanism(String hashAlgorithm, String macAlgorithm, int minIterations) { + this.mechanismName = "SCRAM-" + hashAlgorithm; + this.hashAlgorithm = hashAlgorithm; + this.macAlgorithm = macAlgorithm; + this.minIterations = minIterations; + } + + public final String mechanismName() { + return mechanismName; + } + + public String hashAlgorithm() { + return hashAlgorithm; + } + + public String macAlgorithm() { + return macAlgorithm; + } + + public int minIterations() { + return minIterations; + } + + public static ScramMechanism forMechanismName(String mechanismName) { + return MECHANISMS_MAP.get(mechanismName); + } + + public static Collection mechanismNames() { + return MECHANISMS_MAP.keySet(); + } + + public static boolean isScram(String mechanismName) { + return MECHANISMS_MAP.containsKey(mechanismName); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/security/scram/internals/ScramMessages.java b/clients/src/main/java/org/apache/kafka/common/security/scram/internals/ScramMessages.java new file mode 100644 index 0000000..0551296 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/security/scram/internals/ScramMessages.java @@ -0,0 +1,288 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
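A small usage sketch (not part of this patch) of the ScramMechanism enum above, resolving a SASL mechanism name to its hash and MAC algorithms and its minimum iteration count.

import org.apache.kafka.common.security.scram.internals.ScramMechanism;

public class ScramMechanismLookupExample {
    public static void main(String[] args) {
        // e.g. [SCRAM-SHA-256, SCRAM-SHA-512] (iteration order is not guaranteed)
        System.out.println(ScramMechanism.mechanismNames());

        ScramMechanism mechanism = ScramMechanism.forMechanismName("SCRAM-SHA-256");
        if (mechanism != null) {
            System.out.println(mechanism.hashAlgorithm()); // SHA-256
            System.out.println(mechanism.macAlgorithm());  // HmacSHA256
            System.out.println(mechanism.minIterations()); // 4096
        }
        System.out.println(ScramMechanism.isScram("PLAIN")); // false
    }
}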
+ */ +package org.apache.kafka.common.security.scram.internals; + +import org.apache.kafka.common.utils.Utils; + +import java.nio.charset.StandardCharsets; +import java.util.Base64; +import java.util.Map; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +import javax.security.sasl.SaslException; + +/** + * SCRAM request/response message creation and parsing based on + * RFC 5802 + * + */ +public class ScramMessages { + + static abstract class AbstractScramMessage { + + static final String ALPHA = "[A-Za-z]+"; + static final String VALUE_SAFE = "[\\x01-\\x7F&&[^=,]]+"; + static final String VALUE = "[\\x01-\\x7F&&[^,]]+"; + static final String PRINTABLE = "[\\x21-\\x7E&&[^,]]+"; + static final String SASLNAME = "(?:[\\x01-\\x7F&&[^=,]]|=2C|=3D)+"; + static final String BASE64_CHAR = "[a-zA-Z0-9/+]"; + static final String BASE64 = String.format("(?:%s{4})*(?:%s{3}=|%s{2}==)?", BASE64_CHAR, BASE64_CHAR, BASE64_CHAR); + static final String RESERVED = String.format("(m=%s,)?", VALUE); + static final String EXTENSIONS = String.format("(,%s=%s)*", ALPHA, VALUE); + + abstract String toMessage(); + + public byte[] toBytes() { + return toMessage().getBytes(StandardCharsets.UTF_8); + } + + protected String toMessage(byte[] messageBytes) { + return new String(messageBytes, StandardCharsets.UTF_8); + } + } + + /** + * Format: + * gs2-header [reserved-mext ","] username "," nonce ["," extensions] + * Limitations: + * Only gs2-header "n" is supported. + * Extensions are ignored. + * + */ + public static class ClientFirstMessage extends AbstractScramMessage { + private static final Pattern PATTERN = Pattern.compile(String.format( + "n,(a=(?%s))?,%sn=(?%s),r=(?%s)(?%s)", + SASLNAME, + RESERVED, + SASLNAME, + PRINTABLE, + EXTENSIONS)); + + + private final String saslName; + private final String nonce; + private final String authorizationId; + private final ScramExtensions extensions; + public ClientFirstMessage(byte[] messageBytes) throws SaslException { + String message = toMessage(messageBytes); + Matcher matcher = PATTERN.matcher(message); + if (!matcher.matches()) + throw new SaslException("Invalid SCRAM client first message format: " + message); + String authzid = matcher.group("authzid"); + this.authorizationId = authzid != null ? authzid : ""; + this.saslName = matcher.group("saslname"); + this.nonce = matcher.group("nonce"); + String extString = matcher.group("extensions"); + + this.extensions = extString.startsWith(",") ? 
new ScramExtensions(extString.substring(1)) : new ScramExtensions(); + } + public ClientFirstMessage(String saslName, String nonce, Map extensions) { + this.saslName = saslName; + this.nonce = nonce; + this.extensions = new ScramExtensions(extensions); + this.authorizationId = ""; // Optional authzid not specified in gs2-header + } + public String saslName() { + return saslName; + } + public String nonce() { + return nonce; + } + public String authorizationId() { + return authorizationId; + } + public String gs2Header() { + return "n," + authorizationId + ","; + } + public ScramExtensions extensions() { + return extensions; + } + + public String clientFirstMessageBare() { + String extensionStr = Utils.mkString(extensions.map(), "", "", "=", ","); + + if (extensionStr.isEmpty()) + return String.format("n=%s,r=%s", saslName, nonce); + else + return String.format("n=%s,r=%s,%s", saslName, nonce, extensionStr); + } + String toMessage() { + return gs2Header() + clientFirstMessageBare(); + } + } + + /** + * Format: + * [reserved-mext ","] nonce "," salt "," iteration-count ["," extensions] + * Limitations: + * Extensions are ignored. + * + */ + public static class ServerFirstMessage extends AbstractScramMessage { + private static final Pattern PATTERN = Pattern.compile(String.format( + "%sr=(?%s),s=(?%s),i=(?[0-9]+)%s", + RESERVED, + PRINTABLE, + BASE64, + EXTENSIONS)); + + private final String nonce; + private final byte[] salt; + private final int iterations; + public ServerFirstMessage(byte[] messageBytes) throws SaslException { + String message = toMessage(messageBytes); + Matcher matcher = PATTERN.matcher(message); + if (!matcher.matches()) + throw new SaslException("Invalid SCRAM server first message format: " + message); + try { + this.iterations = Integer.parseInt(matcher.group("iterations")); + if (this.iterations <= 0) + throw new SaslException("Invalid SCRAM server first message format: invalid iterations " + iterations); + } catch (NumberFormatException e) { + throw new SaslException("Invalid SCRAM server first message format: invalid iterations"); + } + this.nonce = matcher.group("nonce"); + String salt = matcher.group("salt"); + this.salt = Base64.getDecoder().decode(salt); + } + public ServerFirstMessage(String clientNonce, String serverNonce, byte[] salt, int iterations) { + this.nonce = clientNonce + serverNonce; + this.salt = salt; + this.iterations = iterations; + } + public String nonce() { + return nonce; + } + public byte[] salt() { + return salt; + } + public int iterations() { + return iterations; + } + String toMessage() { + return String.format("r=%s,s=%s,i=%d", nonce, Base64.getEncoder().encodeToString(salt), iterations); + } + } + /** + * Format: + * channel-binding "," nonce ["," extensions]"," proof + * Limitations: + * Extensions are ignored. 
+ * + */ + public static class ClientFinalMessage extends AbstractScramMessage { + private static final Pattern PATTERN = Pattern.compile(String.format( + "c=(?%s),r=(?%s)%s,p=(?%s)", + BASE64, + PRINTABLE, + EXTENSIONS, + BASE64)); + + private final byte[] channelBinding; + private final String nonce; + private byte[] proof; + public ClientFinalMessage(byte[] messageBytes) throws SaslException { + String message = toMessage(messageBytes); + Matcher matcher = PATTERN.matcher(message); + if (!matcher.matches()) + throw new SaslException("Invalid SCRAM client final message format: " + message); + + this.channelBinding = Base64.getDecoder().decode(matcher.group("channel")); + this.nonce = matcher.group("nonce"); + this.proof = Base64.getDecoder().decode(matcher.group("proof")); + } + public ClientFinalMessage(byte[] channelBinding, String nonce) { + this.channelBinding = channelBinding; + this.nonce = nonce; + } + public byte[] channelBinding() { + return channelBinding; + } + public String nonce() { + return nonce; + } + public byte[] proof() { + return proof; + } + public void proof(byte[] proof) { + this.proof = proof; + } + public String clientFinalMessageWithoutProof() { + return String.format("c=%s,r=%s", + Base64.getEncoder().encodeToString(channelBinding), + nonce); + } + String toMessage() { + return String.format("%s,p=%s", + clientFinalMessageWithoutProof(), + Base64.getEncoder().encodeToString(proof)); + } + } + /** + * Format: + * ("e=" server-error-value | "v=" base64_server_signature) ["," extensions] + * Limitations: + * Extensions are ignored. + * + */ + public static class ServerFinalMessage extends AbstractScramMessage { + private static final Pattern PATTERN = Pattern.compile(String.format( + "(?:e=(?%s))|(?:v=(?%s))%s", + VALUE_SAFE, + BASE64, + EXTENSIONS)); + + private final String error; + private final byte[] serverSignature; + public ServerFinalMessage(byte[] messageBytes) throws SaslException { + String message = toMessage(messageBytes); + Matcher matcher = PATTERN.matcher(message); + if (!matcher.matches()) + throw new SaslException("Invalid SCRAM server final message format: " + message); + String error = null; + try { + error = matcher.group("error"); + } catch (IllegalArgumentException e) { + // ignore + } + if (error == null) { + this.serverSignature = Base64.getDecoder().decode(matcher.group("signature")); + this.error = null; + } else { + this.serverSignature = null; + this.error = error; + } + } + public ServerFinalMessage(String error, byte[] serverSignature) { + this.error = error; + this.serverSignature = serverSignature; + } + public String error() { + return error; + } + public byte[] serverSignature() { + return serverSignature; + } + String toMessage() { + if (error != null) + return "e=" + error; + else + return "v=" + Base64.getEncoder().encodeToString(serverSignature); + } + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/security/scram/internals/ScramSaslClient.java b/clients/src/main/java/org/apache/kafka/common/security/scram/internals/ScramSaslClient.java new file mode 100644 index 0000000..536e409 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/security/scram/internals/ScramSaslClient.java @@ -0,0 +1,248 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
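A hedged sketch (not part of this patch) that builds the RFC 5802 client-first-message handled by ScramMessages above and parses it back; the user name is a made-up example and the nonce is random.

import java.nio.charset.StandardCharsets;
import java.security.SecureRandom;
import java.util.Collections;

import org.apache.kafka.common.security.scram.internals.ScramFormatter;
import org.apache.kafka.common.security.scram.internals.ScramMessages.ClientFirstMessage;

public class ScramClientFirstMessageExample {
    public static void main(String[] args) throws Exception {
        String saslName = ScramFormatter.saslName("alice"); // '=' and ',' would be escaped
        String nonce = ScramFormatter.secureRandomString(new SecureRandom());

        ClientFirstMessage message =
                new ClientFirstMessage(saslName, nonce, Collections.emptyMap());
        byte[] wire = message.toBytes();
        // Wire form: n,,n=alice,r=<nonce>
        System.out.println(new String(wire, StandardCharsets.UTF_8));

        // The server-side constructor parses the same bytes back into their parts.
        ClientFirstMessage parsed = new ClientFirstMessage(wire);
        System.out.println(parsed.saslName() + " / " + parsed.nonce());
    }
}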
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.security.scram.internals; + +import java.nio.charset.StandardCharsets; +import java.security.InvalidKeyException; +import java.security.MessageDigest; +import java.security.NoSuchAlgorithmException; +import java.util.Arrays; +import java.util.Collection; +import java.util.Map; + +import javax.security.auth.callback.Callback; +import javax.security.auth.callback.CallbackHandler; +import javax.security.auth.callback.NameCallback; +import javax.security.auth.callback.PasswordCallback; +import javax.security.auth.callback.UnsupportedCallbackException; +import javax.security.sasl.SaslClient; +import javax.security.sasl.SaslClientFactory; +import javax.security.sasl.SaslException; + +import org.apache.kafka.common.errors.IllegalSaslStateException; +import org.apache.kafka.common.security.scram.ScramExtensionsCallback; +import org.apache.kafka.common.security.scram.internals.ScramMessages.ClientFinalMessage; +import org.apache.kafka.common.security.scram.internals.ScramMessages.ServerFinalMessage; +import org.apache.kafka.common.security.scram.internals.ScramMessages.ServerFirstMessage; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * SaslClient implementation for SASL/SCRAM. + *

            + * This implementation expects a login module that populates username as + * the Subject's public credential and password as the private credential. + * + * @see RFC 5802 + * + */ +public class ScramSaslClient implements SaslClient { + + private static final Logger log = LoggerFactory.getLogger(ScramSaslClient.class); + + enum State { + SEND_CLIENT_FIRST_MESSAGE, + RECEIVE_SERVER_FIRST_MESSAGE, + RECEIVE_SERVER_FINAL_MESSAGE, + COMPLETE, + FAILED + } + + private final ScramMechanism mechanism; + private final CallbackHandler callbackHandler; + private final ScramFormatter formatter; + private String clientNonce; + private State state; + private byte[] saltedPassword; + private ScramMessages.ClientFirstMessage clientFirstMessage; + private ScramMessages.ServerFirstMessage serverFirstMessage; + private ScramMessages.ClientFinalMessage clientFinalMessage; + + public ScramSaslClient(ScramMechanism mechanism, CallbackHandler cbh) throws NoSuchAlgorithmException { + this.mechanism = mechanism; + this.callbackHandler = cbh; + this.formatter = new ScramFormatter(mechanism); + setState(State.SEND_CLIENT_FIRST_MESSAGE); + } + + @Override + public String getMechanismName() { + return mechanism.mechanismName(); + } + + @Override + public boolean hasInitialResponse() { + return true; + } + + @Override + public byte[] evaluateChallenge(byte[] challenge) throws SaslException { + try { + switch (state) { + case SEND_CLIENT_FIRST_MESSAGE: + if (challenge != null && challenge.length != 0) + throw new SaslException("Expected empty challenge"); + clientNonce = formatter.secureRandomString(); + NameCallback nameCallback = new NameCallback("Name:"); + ScramExtensionsCallback extensionsCallback = new ScramExtensionsCallback(); + + try { + callbackHandler.handle(new Callback[]{nameCallback}); + try { + callbackHandler.handle(new Callback[]{extensionsCallback}); + } catch (UnsupportedCallbackException e) { + log.debug("Extensions callback is not supported by client callback handler {}, no extensions will be added", + callbackHandler); + } + } catch (Throwable e) { + throw new SaslException("User name or extensions could not be obtained", e); + } + + String username = nameCallback.getName(); + String saslName = ScramFormatter.saslName(username); + Map extensions = extensionsCallback.extensions(); + this.clientFirstMessage = new ScramMessages.ClientFirstMessage(saslName, clientNonce, extensions); + setState(State.RECEIVE_SERVER_FIRST_MESSAGE); + return clientFirstMessage.toBytes(); + + case RECEIVE_SERVER_FIRST_MESSAGE: + this.serverFirstMessage = new ServerFirstMessage(challenge); + if (!serverFirstMessage.nonce().startsWith(clientNonce)) + throw new SaslException("Invalid server nonce: does not start with client nonce"); + if (serverFirstMessage.iterations() < mechanism.minIterations()) + throw new SaslException("Requested iterations " + serverFirstMessage.iterations() + " is less than the minimum " + mechanism.minIterations() + " for " + mechanism); + PasswordCallback passwordCallback = new PasswordCallback("Password:", false); + try { + callbackHandler.handle(new Callback[]{passwordCallback}); + } catch (Throwable e) { + throw new SaslException("User name could not be obtained", e); + } + this.clientFinalMessage = handleServerFirstMessage(passwordCallback.getPassword()); + setState(State.RECEIVE_SERVER_FINAL_MESSAGE); + return clientFinalMessage.toBytes(); + + case RECEIVE_SERVER_FINAL_MESSAGE: + ServerFinalMessage serverFinalMessage = new ServerFinalMessage(challenge); + if 
(serverFinalMessage.error() != null) + throw new SaslException("Sasl authentication using " + mechanism + " failed with error: " + serverFinalMessage.error()); + handleServerFinalMessage(serverFinalMessage.serverSignature()); + setState(State.COMPLETE); + return null; + + default: + throw new IllegalSaslStateException("Unexpected challenge in Sasl client state " + state); + } + } catch (SaslException e) { + setState(State.FAILED); + throw e; + } + } + + @Override + public boolean isComplete() { + return state == State.COMPLETE; + } + + @Override + public byte[] unwrap(byte[] incoming, int offset, int len) { + if (!isComplete()) + throw new IllegalStateException("Authentication exchange has not completed"); + return Arrays.copyOfRange(incoming, offset, offset + len); + } + + @Override + public byte[] wrap(byte[] outgoing, int offset, int len) { + if (!isComplete()) + throw new IllegalStateException("Authentication exchange has not completed"); + return Arrays.copyOfRange(outgoing, offset, offset + len); + } + + @Override + public Object getNegotiatedProperty(String propName) { + if (!isComplete()) + throw new IllegalStateException("Authentication exchange has not completed"); + return null; + } + + @Override + public void dispose() { + } + + private void setState(State state) { + log.debug("Setting SASL/{} client state to {}", mechanism, state); + this.state = state; + } + + private ClientFinalMessage handleServerFirstMessage(char[] password) throws SaslException { + try { + byte[] passwordBytes = ScramFormatter.normalize(new String(password)); + this.saltedPassword = formatter.hi(passwordBytes, serverFirstMessage.salt(), serverFirstMessage.iterations()); + + ClientFinalMessage clientFinalMessage = new ClientFinalMessage("n,,".getBytes(StandardCharsets.UTF_8), serverFirstMessage.nonce()); + byte[] clientProof = formatter.clientProof(saltedPassword, clientFirstMessage, serverFirstMessage, clientFinalMessage); + clientFinalMessage.proof(clientProof); + return clientFinalMessage; + } catch (InvalidKeyException e) { + throw new SaslException("Client final message could not be created", e); + } + } + + private void handleServerFinalMessage(byte[] signature) throws SaslException { + try { + byte[] serverKey = formatter.serverKey(saltedPassword); + byte[] serverSignature = formatter.serverSignature(serverKey, clientFirstMessage, serverFirstMessage, clientFinalMessage); + if (!MessageDigest.isEqual(signature, serverSignature)) + throw new SaslException("Invalid server signature in server final message"); + } catch (InvalidKeyException e) { + throw new SaslException("Sasl server signature verification failed", e); + } + } + + public static class ScramSaslClientFactory implements SaslClientFactory { + + @Override + public SaslClient createSaslClient(String[] mechanisms, + String authorizationId, + String protocol, + String serverName, + Map props, + CallbackHandler cbh) throws SaslException { + + ScramMechanism mechanism = null; + for (String mech : mechanisms) { + mechanism = ScramMechanism.forMechanismName(mech); + if (mechanism != null) + break; + } + if (mechanism == null) + throw new SaslException(String.format("Requested mechanisms '%s' not supported. 
Supported mechanisms are '%s'.", + Arrays.asList(mechanisms), ScramMechanism.mechanismNames())); + + try { + return new ScramSaslClient(mechanism, cbh); + } catch (NoSuchAlgorithmException e) { + throw new SaslException("Hash algorithm not supported for mechanism " + mechanism, e); + } + } + + @Override + public String[] getMechanismNames(Map props) { + Collection mechanisms = ScramMechanism.mechanismNames(); + return mechanisms.toArray(new String[0]); + } + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/security/scram/internals/ScramSaslClientProvider.java b/clients/src/main/java/org/apache/kafka/common/security/scram/internals/ScramSaslClientProvider.java new file mode 100644 index 0000000..8c5b85a --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/security/scram/internals/ScramSaslClientProvider.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.security.scram.internals; + +import java.security.Provider; +import java.security.Security; + +import org.apache.kafka.common.security.scram.internals.ScramSaslClient.ScramSaslClientFactory; + +public class ScramSaslClientProvider extends Provider { + + private static final long serialVersionUID = 1L; + + @SuppressWarnings("deprecation") + protected ScramSaslClientProvider() { + super("SASL/SCRAM Client Provider", 1.0, "SASL/SCRAM Client Provider for Kafka"); + for (ScramMechanism mechanism : ScramMechanism.values()) + put("SaslClientFactory." + mechanism.mechanismName(), ScramSaslClientFactory.class.getName()); + } + + public static void initialize() { + Security.addProvider(new ScramSaslClientProvider()); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/security/scram/internals/ScramSaslServer.java b/clients/src/main/java/org/apache/kafka/common/security/scram/internals/ScramSaslServer.java new file mode 100644 index 0000000..d5d55a6 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/security/scram/internals/ScramSaslServer.java @@ -0,0 +1,266 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
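A hedged sketch (not part of this patch) showing how, once ScramSaslClientProvider has registered the factory, the standard javax.security.sasl API can create the ScramSaslClient defined above; the server name, protocol and credentials are hypothetical, and Kafka itself drives this client through its own authenticators rather than directly like this.

import java.nio.charset.StandardCharsets;
import java.util.Collections;

import javax.security.auth.callback.Callback;
import javax.security.auth.callback.NameCallback;
import javax.security.auth.callback.PasswordCallback;
import javax.security.sasl.Sasl;
import javax.security.sasl.SaslClient;

import org.apache.kafka.common.security.scram.internals.ScramSaslClientProvider;

public class ScramSaslClientExample {
    public static void main(String[] args) throws Exception {
        ScramSaslClientProvider.initialize();

        SaslClient client = Sasl.createSaslClient(
                new String[]{"SCRAM-SHA-256"}, null, "kafka", "broker.example.com",
                Collections.emptyMap(),
                callbacks -> {
                    for (Callback callback : callbacks) {
                        if (callback instanceof NameCallback)
                            ((NameCallback) callback).setName("alice");
                        else if (callback instanceof PasswordCallback)
                            ((PasswordCallback) callback).setPassword("alice-secret".toCharArray());
                        // ScramExtensionsCallback is optional; ignoring it means no extensions.
                    }
                });

        // First step of the exchange: the client-first-message, e.g. n,,n=alice,r=<nonce>
        byte[] clientFirst = client.evaluateChallenge(new byte[0]);
        System.out.println(new String(clientFirst, StandardCharsets.UTF_8));
    }
}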
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.security.scram.internals; + +import java.security.InvalidKeyException; +import java.security.MessageDigest; +import java.security.NoSuchAlgorithmException; +import java.util.Arrays; +import java.util.Collection; +import java.util.Map; +import java.util.Set; + +import javax.security.auth.callback.Callback; +import javax.security.auth.callback.CallbackHandler; +import javax.security.auth.callback.NameCallback; +import javax.security.sasl.SaslException; +import javax.security.sasl.SaslServer; +import javax.security.sasl.SaslServerFactory; + +import org.apache.kafka.common.errors.AuthenticationException; +import org.apache.kafka.common.errors.IllegalSaslStateException; +import org.apache.kafka.common.errors.SaslAuthenticationException; +import org.apache.kafka.common.security.authenticator.SaslInternalConfigs; +import org.apache.kafka.common.security.scram.ScramCredential; +import org.apache.kafka.common.security.scram.ScramCredentialCallback; +import org.apache.kafka.common.security.scram.ScramLoginModule; +import org.apache.kafka.common.security.scram.internals.ScramMessages.ClientFinalMessage; +import org.apache.kafka.common.security.scram.internals.ScramMessages.ClientFirstMessage; +import org.apache.kafka.common.security.scram.internals.ScramMessages.ServerFinalMessage; +import org.apache.kafka.common.security.scram.internals.ScramMessages.ServerFirstMessage; +import org.apache.kafka.common.security.token.delegation.internals.DelegationTokenCredentialCallback; +import org.apache.kafka.common.utils.Utils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * SaslServer implementation for SASL/SCRAM. This server is configured with a callback + * handler for integration with a credential manager. Kafka brokers provide callbacks + * based on a Zookeeper-based password store. + * + * @see RFC 5802 + */ +public class ScramSaslServer implements SaslServer { + + private static final Logger log = LoggerFactory.getLogger(ScramSaslServer.class); + private static final Set SUPPORTED_EXTENSIONS = Utils.mkSet(ScramLoginModule.TOKEN_AUTH_CONFIG); + + enum State { + RECEIVE_CLIENT_FIRST_MESSAGE, + RECEIVE_CLIENT_FINAL_MESSAGE, + COMPLETE, + FAILED + } + + private final ScramMechanism mechanism; + private final ScramFormatter formatter; + private final CallbackHandler callbackHandler; + private State state; + private String username; + private ClientFirstMessage clientFirstMessage; + private ServerFirstMessage serverFirstMessage; + private ScramExtensions scramExtensions; + private ScramCredential scramCredential; + private String authorizationId; + private Long tokenExpiryTimestamp; + + public ScramSaslServer(ScramMechanism mechanism, Map props, CallbackHandler callbackHandler) throws NoSuchAlgorithmException { + this.mechanism = mechanism; + this.formatter = new ScramFormatter(mechanism); + this.callbackHandler = callbackHandler; + setState(State.RECEIVE_CLIENT_FIRST_MESSAGE); + } + + /** + * @throws SaslAuthenticationException if the requested authorization id is not the same as username. + *
+ * Note: This method may throw {@link SaslAuthenticationException} to provide custom error messages + * to clients. But care should be taken to avoid including any information in the exception message that + * should not be leaked to unauthenticated clients. It may be safer to throw {@link SaslException} in + * most cases so that a standard error message is returned to clients. + *
            + */ + @Override + public byte[] evaluateResponse(byte[] response) throws SaslException, SaslAuthenticationException { + try { + switch (state) { + case RECEIVE_CLIENT_FIRST_MESSAGE: + this.clientFirstMessage = new ClientFirstMessage(response); + this.scramExtensions = clientFirstMessage.extensions(); + if (!SUPPORTED_EXTENSIONS.containsAll(scramExtensions.map().keySet())) { + log.debug("Unsupported extensions will be ignored, supported {}, provided {}", + SUPPORTED_EXTENSIONS, scramExtensions.map().keySet()); + } + String serverNonce = formatter.secureRandomString(); + try { + String saslName = clientFirstMessage.saslName(); + this.username = ScramFormatter.username(saslName); + NameCallback nameCallback = new NameCallback("username", username); + ScramCredentialCallback credentialCallback; + if (scramExtensions.tokenAuthenticated()) { + DelegationTokenCredentialCallback tokenCallback = new DelegationTokenCredentialCallback(); + credentialCallback = tokenCallback; + callbackHandler.handle(new Callback[]{nameCallback, tokenCallback}); + if (tokenCallback.tokenOwner() == null) + throw new SaslException("Token Authentication failed: Invalid tokenId : " + username); + this.authorizationId = tokenCallback.tokenOwner(); + this.tokenExpiryTimestamp = tokenCallback.tokenExpiryTimestamp(); + } else { + credentialCallback = new ScramCredentialCallback(); + callbackHandler.handle(new Callback[]{nameCallback, credentialCallback}); + this.authorizationId = username; + this.tokenExpiryTimestamp = null; + } + this.scramCredential = credentialCallback.scramCredential(); + if (scramCredential == null) + throw new SaslException("Authentication failed: Invalid user credentials"); + String authorizationIdFromClient = clientFirstMessage.authorizationId(); + if (!authorizationIdFromClient.isEmpty() && !authorizationIdFromClient.equals(username)) + throw new SaslAuthenticationException("Authentication failed: Client requested an authorization id that is different from username"); + + if (scramCredential.iterations() < mechanism.minIterations()) + throw new SaslException("Iterations " + scramCredential.iterations() + " is less than the minimum " + mechanism.minIterations() + " for " + mechanism); + this.serverFirstMessage = new ServerFirstMessage(clientFirstMessage.nonce(), + serverNonce, + scramCredential.salt(), + scramCredential.iterations()); + setState(State.RECEIVE_CLIENT_FINAL_MESSAGE); + return serverFirstMessage.toBytes(); + } catch (SaslException | AuthenticationException e) { + throw e; + } catch (Throwable e) { + throw new SaslException("Authentication failed: Credentials could not be obtained", e); + } + + case RECEIVE_CLIENT_FINAL_MESSAGE: + try { + ClientFinalMessage clientFinalMessage = new ClientFinalMessage(response); + verifyClientProof(clientFinalMessage); + byte[] serverKey = scramCredential.serverKey(); + byte[] serverSignature = formatter.serverSignature(serverKey, clientFirstMessage, serverFirstMessage, clientFinalMessage); + ServerFinalMessage serverFinalMessage = new ServerFinalMessage(null, serverSignature); + clearCredentials(); + setState(State.COMPLETE); + return serverFinalMessage.toBytes(); + } catch (InvalidKeyException e) { + throw new SaslException("Authentication failed: Invalid client final message", e); + } + + default: + throw new IllegalSaslStateException("Unexpected challenge in Sasl server state " + state); + } + } catch (SaslException | AuthenticationException e) { + clearCredentials(); + setState(State.FAILED); + throw e; + } + } + + @Override + public 
String getAuthorizationID() { + if (!isComplete()) + throw new IllegalStateException("Authentication exchange has not completed"); + return authorizationId; + } + + @Override + public String getMechanismName() { + return mechanism.mechanismName(); + } + + @Override + public Object getNegotiatedProperty(String propName) { + if (!isComplete()) + throw new IllegalStateException("Authentication exchange has not completed"); + if (SaslInternalConfigs.CREDENTIAL_LIFETIME_MS_SASL_NEGOTIATED_PROPERTY_KEY.equals(propName)) + return tokenExpiryTimestamp; // will be null if token not used + if (SUPPORTED_EXTENSIONS.contains(propName)) + return scramExtensions.map().get(propName); + else + return null; + } + + @Override + public boolean isComplete() { + return state == State.COMPLETE; + } + + @Override + public byte[] unwrap(byte[] incoming, int offset, int len) { + if (!isComplete()) + throw new IllegalStateException("Authentication exchange has not completed"); + return Arrays.copyOfRange(incoming, offset, offset + len); + } + + @Override + public byte[] wrap(byte[] outgoing, int offset, int len) { + if (!isComplete()) + throw new IllegalStateException("Authentication exchange has not completed"); + return Arrays.copyOfRange(outgoing, offset, offset + len); + } + + @Override + public void dispose() { + } + + private void setState(State state) { + log.debug("Setting SASL/{} server state to {}", mechanism, state); + this.state = state; + } + + private void verifyClientProof(ClientFinalMessage clientFinalMessage) throws SaslException { + try { + byte[] expectedStoredKey = scramCredential.storedKey(); + byte[] clientSignature = formatter.clientSignature(expectedStoredKey, clientFirstMessage, serverFirstMessage, clientFinalMessage); + byte[] computedStoredKey = formatter.storedKey(clientSignature, clientFinalMessage.proof()); + if (!MessageDigest.isEqual(computedStoredKey, expectedStoredKey)) + throw new SaslException("Invalid client credentials"); + } catch (InvalidKeyException e) { + throw new SaslException("Sasl client verification failed", e); + } + } + + private void clearCredentials() { + scramCredential = null; + clientFirstMessage = null; + serverFirstMessage = null; + } + + public static class ScramSaslServerFactory implements SaslServerFactory { + + @Override + public SaslServer createSaslServer(String mechanism, String protocol, String serverName, Map props, CallbackHandler cbh) + throws SaslException { + + if (!ScramMechanism.isScram(mechanism)) { + throw new SaslException(String.format("Requested mechanism '%s' is not supported. 
Supported mechanisms are '%s'.", + mechanism, ScramMechanism.mechanismNames())); + } + try { + return new ScramSaslServer(ScramMechanism.forMechanismName(mechanism), props, cbh); + } catch (NoSuchAlgorithmException e) { + throw new SaslException("Hash algorithm not supported for mechanism " + mechanism, e); + } + } + + @Override + public String[] getMechanismNames(Map props) { + Collection mechanisms = ScramMechanism.mechanismNames(); + return mechanisms.toArray(new String[0]); + } + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/security/scram/internals/ScramSaslServerProvider.java b/clients/src/main/java/org/apache/kafka/common/security/scram/internals/ScramSaslServerProvider.java new file mode 100644 index 0000000..6a86860 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/security/scram/internals/ScramSaslServerProvider.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.security.scram.internals; + +import java.security.Provider; +import java.security.Security; + +import org.apache.kafka.common.security.scram.internals.ScramSaslServer.ScramSaslServerFactory; + +public class ScramSaslServerProvider extends Provider { + + private static final long serialVersionUID = 1L; + + @SuppressWarnings("deprecation") + protected ScramSaslServerProvider() { + super("SASL/SCRAM Server Provider", 1.0, "SASL/SCRAM Server Provider for Kafka"); + for (ScramMechanism mechanism : ScramMechanism.values()) + put("SaslServerFactory." + mechanism.mechanismName(), ScramSaslServerFactory.class.getName()); + } + + public static void initialize() { + Security.addProvider(new ScramSaslServerProvider()); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/security/scram/internals/ScramServerCallbackHandler.java b/clients/src/main/java/org/apache/kafka/common/security/scram/internals/ScramServerCallbackHandler.java new file mode 100644 index 0000000..1af38e9 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/security/scram/internals/ScramServerCallbackHandler.java @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.security.scram.internals; + +import java.util.List; +import java.util.Map; + +import javax.security.auth.callback.Callback; +import javax.security.auth.callback.NameCallback; +import javax.security.auth.callback.UnsupportedCallbackException; +import javax.security.auth.login.AppConfigurationEntry; + +import org.apache.kafka.common.security.auth.AuthenticateCallbackHandler; +import org.apache.kafka.common.security.authenticator.CredentialCache; +import org.apache.kafka.common.security.scram.ScramCredential; +import org.apache.kafka.common.security.scram.ScramCredentialCallback; +import org.apache.kafka.common.security.token.delegation.TokenInformation; +import org.apache.kafka.common.security.token.delegation.internals.DelegationTokenCache; +import org.apache.kafka.common.security.token.delegation.internals.DelegationTokenCredentialCallback; + +public class ScramServerCallbackHandler implements AuthenticateCallbackHandler { + + private final CredentialCache.Cache credentialCache; + private final DelegationTokenCache tokenCache; + private String saslMechanism; + + public ScramServerCallbackHandler(CredentialCache.Cache credentialCache, + DelegationTokenCache tokenCache) { + this.credentialCache = credentialCache; + this.tokenCache = tokenCache; + } + + @Override + public void configure(Map configs, String mechanism, List jaasConfigEntries) { + this.saslMechanism = mechanism; + } + + @Override + public void handle(Callback[] callbacks) throws UnsupportedCallbackException { + String username = null; + for (Callback callback : callbacks) { + if (callback instanceof NameCallback) + username = ((NameCallback) callback).getDefaultName(); + else if (callback instanceof DelegationTokenCredentialCallback) { + DelegationTokenCredentialCallback tokenCallback = (DelegationTokenCredentialCallback) callback; + tokenCallback.scramCredential(tokenCache.credential(saslMechanism, username)); + tokenCallback.tokenOwner(tokenCache.owner(username)); + TokenInformation tokenInfo = tokenCache.token(username); + if (tokenInfo != null) + tokenCallback.tokenExpiryTimestamp(tokenInfo.expiryTimestamp()); + } else if (callback instanceof ScramCredentialCallback) { + ScramCredentialCallback sc = (ScramCredentialCallback) callback; + sc.scramCredential(credentialCache.get(username)); + } else + throw new UnsupportedCallbackException(callback); + } + } + + @Override + public void close() { + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/security/ssl/DefaultSslEngineFactory.java b/clients/src/main/java/org/apache/kafka/common/security/ssl/DefaultSslEngineFactory.java new file mode 100644 index 0000000..1528a4a --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/security/ssl/DefaultSslEngineFactory.java @@ -0,0 +1,584 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.security.ssl; + +import org.apache.kafka.common.KafkaException; +import org.apache.kafka.common.config.SslClientAuth; +import org.apache.kafka.common.config.SslConfigs; +import org.apache.kafka.common.config.internals.BrokerSecurityConfigs; +import org.apache.kafka.common.config.types.Password; +import org.apache.kafka.common.errors.InvalidConfigurationException; +import org.apache.kafka.common.network.Mode; +import org.apache.kafka.common.security.auth.SslEngineFactory; +import org.apache.kafka.common.utils.SecurityUtils; +import org.apache.kafka.common.utils.Utils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.nio.file.Files; +import java.nio.file.Paths; +import java.security.GeneralSecurityException; +import java.security.Key; +import java.security.KeyFactory; +import java.security.KeyStore; +import java.security.PrivateKey; +import java.security.SecureRandom; +import java.security.cert.Certificate; +import java.security.cert.CertificateFactory; +import java.security.spec.InvalidKeySpecException; +import java.security.spec.PKCS8EncodedKeySpec; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Base64; +import java.util.Collections; +import java.util.Date; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Set; +import java.util.regex.Matcher; +import java.util.regex.Pattern; +import java.util.stream.Collectors; + +import javax.crypto.Cipher; +import javax.crypto.EncryptedPrivateKeyInfo; +import javax.crypto.SecretKey; +import javax.crypto.SecretKeyFactory; +import javax.crypto.spec.PBEKeySpec; +import javax.net.ssl.KeyManager; +import javax.net.ssl.KeyManagerFactory; +import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLEngine; +import javax.net.ssl.SSLParameters; +import javax.net.ssl.TrustManagerFactory; + +public final class DefaultSslEngineFactory implements SslEngineFactory { + + private static final Logger log = LoggerFactory.getLogger(DefaultSslEngineFactory.class); + public static final String PEM_TYPE = "PEM"; + + private Map configs; + private String protocol; + private String provider; + private String kmfAlgorithm; + private String tmfAlgorithm; + private SecurityStore keystore; + private SecurityStore truststore; + private String[] cipherSuites; + private String[] enabledProtocols; + private SecureRandom secureRandomImplementation; + private SSLContext sslContext; + private SslClientAuth sslClientAuth; + + + @Override + public SSLEngine createClientSslEngine(String peerHost, int peerPort, String endpointIdentification) { + return createSslEngine(Mode.CLIENT, peerHost, peerPort, endpointIdentification); + } + + @Override + public SSLEngine createServerSslEngine(String peerHost, int peerPort) { + return createSslEngine(Mode.SERVER, peerHost, peerPort, null); + } + + 
@Override + public boolean shouldBeRebuilt(Map nextConfigs) { + if (!nextConfigs.equals(configs)) { + return true; + } + if (truststore != null && truststore.modified()) { + return true; + } + if (keystore != null && keystore.modified()) { + return true; + } + return false; + } + + @Override + public Set reconfigurableConfigs() { + return SslConfigs.RECONFIGURABLE_CONFIGS; + } + + @Override + public KeyStore keystore() { + return this.keystore != null ? this.keystore.get() : null; + } + + @Override + public KeyStore truststore() { + return this.truststore != null ? this.truststore.get() : null; + } + + @SuppressWarnings("unchecked") + @Override + public void configure(Map configs) { + this.configs = Collections.unmodifiableMap(configs); + this.protocol = (String) configs.get(SslConfigs.SSL_PROTOCOL_CONFIG); + this.provider = (String) configs.get(SslConfigs.SSL_PROVIDER_CONFIG); + SecurityUtils.addConfiguredSecurityProviders(this.configs); + + List cipherSuitesList = (List) configs.get(SslConfigs.SSL_CIPHER_SUITES_CONFIG); + if (cipherSuitesList != null && !cipherSuitesList.isEmpty()) { + this.cipherSuites = cipherSuitesList.toArray(new String[0]); + } else { + this.cipherSuites = null; + } + + List enabledProtocolsList = (List) configs.get(SslConfigs.SSL_ENABLED_PROTOCOLS_CONFIG); + if (enabledProtocolsList != null && !enabledProtocolsList.isEmpty()) { + this.enabledProtocols = enabledProtocolsList.toArray(new String[0]); + } else { + this.enabledProtocols = null; + } + + this.secureRandomImplementation = createSecureRandom((String) + configs.get(SslConfigs.SSL_SECURE_RANDOM_IMPLEMENTATION_CONFIG)); + + this.sslClientAuth = createSslClientAuth((String) configs.get( + BrokerSecurityConfigs.SSL_CLIENT_AUTH_CONFIG)); + + this.kmfAlgorithm = (String) configs.get(SslConfigs.SSL_KEYMANAGER_ALGORITHM_CONFIG); + this.tmfAlgorithm = (String) configs.get(SslConfigs.SSL_TRUSTMANAGER_ALGORITHM_CONFIG); + + this.keystore = createKeystore((String) configs.get(SslConfigs.SSL_KEYSTORE_TYPE_CONFIG), + (String) configs.get(SslConfigs.SSL_KEYSTORE_LOCATION_CONFIG), + (Password) configs.get(SslConfigs.SSL_KEYSTORE_PASSWORD_CONFIG), + (Password) configs.get(SslConfigs.SSL_KEY_PASSWORD_CONFIG), + (Password) configs.get(SslConfigs.SSL_KEYSTORE_KEY_CONFIG), + (Password) configs.get(SslConfigs.SSL_KEYSTORE_CERTIFICATE_CHAIN_CONFIG)); + + this.truststore = createTruststore((String) configs.get(SslConfigs.SSL_TRUSTSTORE_TYPE_CONFIG), + (String) configs.get(SslConfigs.SSL_TRUSTSTORE_LOCATION_CONFIG), + (Password) configs.get(SslConfigs.SSL_TRUSTSTORE_PASSWORD_CONFIG), + (Password) configs.get(SslConfigs.SSL_TRUSTSTORE_CERTIFICATES_CONFIG)); + + this.sslContext = createSSLContext(keystore, truststore); + } + + @Override + public void close() { + this.sslContext = null; + } + + //For Test only + public SSLContext sslContext() { + return this.sslContext; + } + + private SSLEngine createSslEngine(Mode mode, String peerHost, int peerPort, String endpointIdentification) { + SSLEngine sslEngine = sslContext.createSSLEngine(peerHost, peerPort); + if (cipherSuites != null) sslEngine.setEnabledCipherSuites(cipherSuites); + if (enabledProtocols != null) sslEngine.setEnabledProtocols(enabledProtocols); + + if (mode == Mode.SERVER) { + sslEngine.setUseClientMode(false); + switch (sslClientAuth) { + case REQUIRED: + sslEngine.setNeedClientAuth(true); + break; + case REQUESTED: + sslEngine.setWantClientAuth(true); + break; + case NONE: + break; + } + sslEngine.setUseClientMode(false); + } else { + sslEngine.setUseClientMode(true); 
+ SSLParameters sslParams = sslEngine.getSSLParameters(); + // SSLParameters#setEndpointIdentificationAlgorithm enables endpoint validation + // only in client mode. Hence, validation is enabled only for clients. + sslParams.setEndpointIdentificationAlgorithm(endpointIdentification); + sslEngine.setSSLParameters(sslParams); + } + return sslEngine; + } + private static SslClientAuth createSslClientAuth(String key) { + SslClientAuth auth = SslClientAuth.forConfig(key); + if (auth != null) { + return auth; + } + log.warn("Unrecognized client authentication configuration {}. Falling " + + "back to NONE. Recognized client authentication configurations are {}.", + key, String.join(", ", SslClientAuth.VALUES.stream(). + map(Enum::name).collect(Collectors.toList()))); + return SslClientAuth.NONE; + } + + private static SecureRandom createSecureRandom(String key) { + if (key == null) { + return null; + } + try { + return SecureRandom.getInstance(key); + } catch (GeneralSecurityException e) { + throw new KafkaException(e); + } + } + + private SSLContext createSSLContext(SecurityStore keystore, SecurityStore truststore) { + try { + SSLContext sslContext; + if (provider != null) + sslContext = SSLContext.getInstance(protocol, provider); + else + sslContext = SSLContext.getInstance(protocol); + + KeyManager[] keyManagers = null; + if (keystore != null || kmfAlgorithm != null) { + String kmfAlgorithm = this.kmfAlgorithm != null ? + this.kmfAlgorithm : KeyManagerFactory.getDefaultAlgorithm(); + KeyManagerFactory kmf = KeyManagerFactory.getInstance(kmfAlgorithm); + if (keystore != null) { + kmf.init(keystore.get(), keystore.keyPassword()); + } else { + kmf.init(null, null); + } + keyManagers = kmf.getKeyManagers(); + } + + String tmfAlgorithm = this.tmfAlgorithm != null ? this.tmfAlgorithm : TrustManagerFactory.getDefaultAlgorithm(); + TrustManagerFactory tmf = TrustManagerFactory.getInstance(tmfAlgorithm); + KeyStore ts = truststore == null ? 
null : truststore.get(); + tmf.init(ts); + + sslContext.init(keyManagers, tmf.getTrustManagers(), this.secureRandomImplementation); + log.debug("Created SSL context with keystore {}, truststore {}, provider {}.", + keystore, truststore, sslContext.getProvider().getName()); + return sslContext; + } catch (Exception e) { + throw new KafkaException(e); + } + } + + // Visibility to override for testing + protected SecurityStore createKeystore(String type, String path, Password password, Password keyPassword, Password privateKey, Password certificateChain) { + if (privateKey != null) { + if (!PEM_TYPE.equals(type)) + throw new InvalidConfigurationException("SSL private key can be specified only for PEM, but key store type is " + type + "."); + else if (certificateChain == null) + throw new InvalidConfigurationException("SSL private key is specified, but certificate chain is not specified."); + else if (path != null) + throw new InvalidConfigurationException("Both SSL key store location and separate private key are specified."); + else if (password != null) + throw new InvalidConfigurationException("SSL key store password cannot be specified with PEM format, only key password may be specified."); + else + return new PemStore(certificateChain, privateKey, keyPassword); + } else if (certificateChain != null) { + throw new InvalidConfigurationException("SSL certificate chain is specified, but private key is not specified"); + } else if (PEM_TYPE.equals(type) && path != null) { + if (password != null) + throw new InvalidConfigurationException("SSL key store password cannot be specified with PEM format, only key password may be specified"); + else if (keyPassword == null) + throw new InvalidConfigurationException("SSL PEM key store is specified, but key password is not specified."); + else + return new FileBasedPemStore(path, keyPassword, true); + } else if (path == null && password != null) { + throw new InvalidConfigurationException("SSL key store is not specified, but key store password is specified."); + } else if (path != null && password == null) { + throw new InvalidConfigurationException("SSL key store is specified, but key store password is not specified."); + } else if (path != null && password != null) { + return new FileBasedStore(type, path, password, keyPassword, true); + } else + return null; // path == null, clients may use this path with brokers that don't require client auth + } + + private static SecurityStore createTruststore(String type, String path, Password password, Password trustStoreCerts) { + if (trustStoreCerts != null) { + if (!PEM_TYPE.equals(type)) + throw new InvalidConfigurationException("SSL trust store certs can be specified only for PEM, but trust store type is " + type + "."); + else if (path != null) + throw new InvalidConfigurationException("Both SSL trust store location and separate trust certificates are specified."); + else if (password != null) + throw new InvalidConfigurationException("SSL trust store password cannot be specified for PEM format."); + else + return new PemStore(trustStoreCerts); + } else if (PEM_TYPE.equals(type) && path != null) { + if (password != null) + throw new InvalidConfigurationException("SSL trust store password cannot be specified for PEM format."); + else + return new FileBasedPemStore(path, null, false); + } else if (path == null && password != null) { + throw new InvalidConfigurationException("SSL trust store is not specified, but trust store password is specified."); + } else if (path != null) { + return new 
FileBasedStore(type, path, password, null, false); + } else + return null; + } + + static interface SecurityStore { + KeyStore get(); + char[] keyPassword(); + boolean modified(); + } + + // package access for testing + static class FileBasedStore implements SecurityStore { + private final String type; + protected final String path; + private final Password password; + protected final Password keyPassword; + private final Long fileLastModifiedMs; + private final KeyStore keyStore; + + FileBasedStore(String type, String path, Password password, Password keyPassword, boolean isKeyStore) { + Objects.requireNonNull(type, "type must not be null"); + this.type = type; + this.path = path; + this.password = password; + this.keyPassword = keyPassword; + fileLastModifiedMs = lastModifiedMs(path); + this.keyStore = load(isKeyStore); + } + + @Override + public KeyStore get() { + return keyStore; + } + + @Override + public char[] keyPassword() { + Password passwd = keyPassword != null ? keyPassword : password; + return passwd == null ? null : passwd.value().toCharArray(); + } + + /** + * Loads this keystore + * @return the keystore + * @throws KafkaException if the file could not be read or if the keystore could not be loaded + * using the specified configs (e.g. if the password or keystore type is invalid) + */ + protected KeyStore load(boolean isKeyStore) { + try (InputStream in = Files.newInputStream(Paths.get(path))) { + KeyStore ks = KeyStore.getInstance(type); + // If a password is not set access to the truststore is still available, but integrity checking is disabled. + char[] passwordChars = password != null ? password.value().toCharArray() : null; + ks.load(in, passwordChars); + return ks; + } catch (GeneralSecurityException | IOException e) { + throw new KafkaException("Failed to load SSL keystore " + path + " of type " + type, e); + } + } + + private Long lastModifiedMs(String path) { + try { + return Files.getLastModifiedTime(Paths.get(path)).toMillis(); + } catch (IOException e) { + log.error("Modification time of key store could not be obtained: " + path, e); + return null; + } + } + + public boolean modified() { + Long modifiedMs = lastModifiedMs(path); + return modifiedMs != null && !Objects.equals(modifiedMs, this.fileLastModifiedMs); + } + + @Override + public String toString() { + return "SecurityStore(" + + "path=" + path + + ", modificationTime=" + (fileLastModifiedMs == null ? null : new Date(fileLastModifiedMs)) + ")"; + } + } + + static class FileBasedPemStore extends FileBasedStore { + FileBasedPemStore(String path, Password keyPassword, boolean isKeyStore) { + super(PEM_TYPE, path, null, keyPassword, isKeyStore); + } + + @Override + protected KeyStore load(boolean isKeyStore) { + try { + Password storeContents = new Password(Utils.readFileAsString(path)); + PemStore pemStore = isKeyStore ? 
new PemStore(storeContents, storeContents, keyPassword) : + new PemStore(storeContents); + return pemStore.keyStore; + } catch (Exception e) { + throw new InvalidConfigurationException("Failed to load PEM SSL keystore " + path, e); + } + } + } + + static class PemStore implements SecurityStore { + private static final PemParser CERTIFICATE_PARSER = new PemParser("CERTIFICATE"); + private static final PemParser PRIVATE_KEY_PARSER = new PemParser("PRIVATE KEY"); + private static final List KEY_FACTORIES = Arrays.asList( + keyFactory("RSA"), + keyFactory("DSA"), + keyFactory("EC") + ); + + private final char[] keyPassword; + private final KeyStore keyStore; + + PemStore(Password certificateChain, Password privateKey, Password keyPassword) { + this.keyPassword = keyPassword == null ? null : keyPassword.value().toCharArray(); + keyStore = createKeyStoreFromPem(privateKey.value(), certificateChain.value(), this.keyPassword); + } + + PemStore(Password trustStoreCerts) { + this.keyPassword = null; + keyStore = createTrustStoreFromPem(trustStoreCerts.value()); + } + + @Override + public KeyStore get() { + return keyStore; + } + + @Override + public char[] keyPassword() { + return keyPassword; + } + + @Override + public boolean modified() { + return false; + } + + private KeyStore createKeyStoreFromPem(String privateKeyPem, String certChainPem, char[] keyPassword) { + try { + KeyStore ks = KeyStore.getInstance("PKCS12"); + ks.load(null, null); + Key key = privateKey(privateKeyPem, keyPassword); + Certificate[] certChain = certs(certChainPem); + ks.setKeyEntry("kafka", key, keyPassword, certChain); + return ks; + } catch (Exception e) { + throw new InvalidConfigurationException("Invalid PEM keystore configs", e); + } + } + + private KeyStore createTrustStoreFromPem(String trustedCertsPem) { + try { + KeyStore ts = KeyStore.getInstance("PKCS12"); + ts.load(null, null); + Certificate[] certs = certs(trustedCertsPem); + for (int i = 0; i < certs.length; i++) { + ts.setCertificateEntry("kafka" + i, certs[i]); + } + return ts; + } catch (InvalidConfigurationException e) { + throw e; + } catch (Exception e) { + throw new InvalidConfigurationException("Invalid PEM keystore configs", e); + } + } + + private Certificate[] certs(String pem) throws GeneralSecurityException { + List certEntries = CERTIFICATE_PARSER.pemEntries(pem); + if (certEntries.isEmpty()) + throw new InvalidConfigurationException("At least one certificate expected, but none found"); + + Certificate[] certs = new Certificate[certEntries.size()]; + for (int i = 0; i < certs.length; i++) { + certs[i] = CertificateFactory.getInstance("X.509") + .generateCertificate(new ByteArrayInputStream(certEntries.get(i))); + } + return certs; + } + + private PrivateKey privateKey(String pem, char[] keyPassword) throws Exception { + List keyEntries = PRIVATE_KEY_PARSER.pemEntries(pem); + if (keyEntries.isEmpty()) + throw new InvalidConfigurationException("Private key not provided"); + if (keyEntries.size() != 1) + throw new InvalidConfigurationException("Expected one private key, but found " + keyEntries.size()); + + byte[] keyBytes = keyEntries.get(0); + PKCS8EncodedKeySpec keySpec; + if (keyPassword == null) { + keySpec = new PKCS8EncodedKeySpec(keyBytes); + } else { + EncryptedPrivateKeyInfo keyInfo = new EncryptedPrivateKeyInfo(keyBytes); + String algorithm = keyInfo.getAlgName(); + SecretKeyFactory keyFactory = SecretKeyFactory.getInstance(algorithm); + SecretKey pbeKey = keyFactory.generateSecret(new PBEKeySpec(keyPassword)); + Cipher cipher = 
Cipher.getInstance(algorithm); + cipher.init(Cipher.DECRYPT_MODE, pbeKey, keyInfo.getAlgParameters()); + keySpec = keyInfo.getKeySpec(cipher); + } + + InvalidKeySpecException firstException = null; + for (KeyFactory factory : KEY_FACTORIES) { + try { + return factory.generatePrivate(keySpec); + } catch (InvalidKeySpecException e) { + if (firstException == null) + firstException = e; + } + } + throw new InvalidConfigurationException("Private key could not be loaded", firstException); + } + + private static KeyFactory keyFactory(String algorithm) { + try { + return KeyFactory.getInstance(algorithm); + } catch (Exception e) { + throw new InvalidConfigurationException("Could not create key factory for algorithm " + algorithm, e); + } + } + } + + /** + * Parser to process certificate/private key entries from PEM files + * Examples: + * -----BEGIN CERTIFICATE----- + * Base64 cert + * -----END CERTIFICATE----- + * + * -----BEGIN ENCRYPTED PRIVATE KEY----- + * Base64 private key + * -----END ENCRYPTED PRIVATE KEY----- + * Additional data may be included before headers, so we match all entries within the PEM. + */ + static class PemParser { + private final String name; + private final Pattern pattern; + + PemParser(String name) { + this.name = name; + String beginOrEndFormat = "-+%s\\s*.*%s[^-]*-+\\s+"; + String nameIgnoreSpace = name.replace(" ", "\\s+"); + + String encodingParams = "\\s*[^\\r\\n]*:[^\\r\\n]*[\\r\\n]+"; + String base64Pattern = "([a-zA-Z0-9/+=\\s]*)"; + String patternStr = String.format(beginOrEndFormat, "BEGIN", nameIgnoreSpace) + + String.format("(?:%s)*", encodingParams) + + base64Pattern + + String.format(beginOrEndFormat, "END", nameIgnoreSpace); + pattern = Pattern.compile(patternStr); + } + + private List pemEntries(String pem) { + Matcher matcher = pattern.matcher(pem + "\n"); // allow last newline to be omitted in value + List entries = new ArrayList<>(); + while (matcher.find()) { + String base64Str = matcher.group(1).replaceAll("\\s", ""); + entries.add(Base64.getDecoder().decode(base64Str)); + } + if (entries.isEmpty()) + throw new InvalidConfigurationException("No matching " + name + " entries in PEM file"); + return entries; + } + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/security/ssl/SslFactory.java b/clients/src/main/java/org/apache/kafka/common/security/ssl/SslFactory.java new file mode 100644 index 0000000..d0cc4cc --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/security/ssl/SslFactory.java @@ -0,0 +1,482 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.security.ssl; + +import org.apache.kafka.common.KafkaException; +import org.apache.kafka.common.Reconfigurable; +import org.apache.kafka.common.config.ConfigException; +import org.apache.kafka.common.config.SslConfigs; +import org.apache.kafka.common.config.internals.BrokerSecurityConfigs; +import org.apache.kafka.common.network.Mode; +import org.apache.kafka.common.security.auth.SslEngineFactory; +import org.apache.kafka.common.utils.Utils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.net.ssl.SSLEngine; +import javax.net.ssl.SSLEngineResult; +import javax.net.ssl.SSLException; +import java.io.Closeable; +import java.net.InetSocketAddress; +import java.net.Socket; +import java.nio.ByteBuffer; +import java.security.cert.Certificate; +import java.security.cert.X509Certificate; +import java.security.GeneralSecurityException; +import java.security.KeyStore; +import java.security.Principal; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.Enumeration; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Set; +import java.util.HashSet; + +public class SslFactory implements Reconfigurable, Closeable { + private static final Logger log = LoggerFactory.getLogger(SslFactory.class); + + private final Mode mode; + private final String clientAuthConfigOverride; + private final boolean keystoreVerifiableUsingTruststore; + private String endpointIdentification; + private SslEngineFactory sslEngineFactory; + private Map sslEngineFactoryConfig; + + public SslFactory(Mode mode) { + this(mode, null, false); + } + + /** + * Create an SslFactory. + * + * @param mode Whether to use client or server mode. + * @param clientAuthConfigOverride The value to override ssl.client.auth with, or null + * if we don't want to override it. + * @param keystoreVerifiableUsingTruststore True if we should require the keystore to be verifiable + * using the truststore. + */ + public SslFactory(Mode mode, + String clientAuthConfigOverride, + boolean keystoreVerifiableUsingTruststore) { + this.mode = mode; + this.clientAuthConfigOverride = clientAuthConfigOverride; + this.keystoreVerifiableUsingTruststore = keystoreVerifiableUsingTruststore; + } + + @SuppressWarnings("unchecked") + @Override + public void configure(Map configs) throws KafkaException { + if (sslEngineFactory != null) { + throw new IllegalStateException("SslFactory was already configured."); + } + this.endpointIdentification = (String) configs.get(SslConfigs.SSL_ENDPOINT_IDENTIFICATION_ALGORITHM_CONFIG); + + // The input map must be a mutable RecordingMap in production. 
+ Map nextConfigs = (Map) configs; + if (clientAuthConfigOverride != null) { + nextConfigs.put(BrokerSecurityConfigs.SSL_CLIENT_AUTH_CONFIG, clientAuthConfigOverride); + } + SslEngineFactory builder = instantiateSslEngineFactory(nextConfigs); + if (keystoreVerifiableUsingTruststore) { + try { + SslEngineValidator.validate(builder, builder); + } catch (Exception e) { + throw new ConfigException("A client SSLEngine created with the provided settings " + + "can't connect to a server SSLEngine created with those settings.", e); + } + } + this.sslEngineFactory = builder; + } + + @Override + public Set reconfigurableConfigs() { + return sslEngineFactory.reconfigurableConfigs(); + } + + @Override + public void validateReconfiguration(Map newConfigs) { + createNewSslEngineFactory(newConfigs); + } + + @Override + public void reconfigure(Map newConfigs) throws KafkaException { + SslEngineFactory newSslEngineFactory = createNewSslEngineFactory(newConfigs); + if (newSslEngineFactory != this.sslEngineFactory) { + Utils.closeQuietly(this.sslEngineFactory, "close stale ssl engine factory"); + this.sslEngineFactory = newSslEngineFactory; + log.info("Created new {} SSL engine builder with keystore {} truststore {}", mode, + newSslEngineFactory.keystore(), newSslEngineFactory.truststore()); + } + } + + private SslEngineFactory instantiateSslEngineFactory(Map configs) { + @SuppressWarnings("unchecked") + Class sslEngineFactoryClass = + (Class) configs.get(SslConfigs.SSL_ENGINE_FACTORY_CLASS_CONFIG); + SslEngineFactory sslEngineFactory; + if (sslEngineFactoryClass == null) { + sslEngineFactory = new DefaultSslEngineFactory(); + } else { + sslEngineFactory = Utils.newInstance(sslEngineFactoryClass); + } + sslEngineFactory.configure(configs); + this.sslEngineFactoryConfig = configs; + return sslEngineFactory; + } + + private SslEngineFactory createNewSslEngineFactory(Map newConfigs) { + if (sslEngineFactory == null) { + throw new IllegalStateException("SslFactory has not been configured."); + } + Map nextConfigs = new HashMap<>(sslEngineFactoryConfig); + copyMapEntries(nextConfigs, newConfigs, reconfigurableConfigs()); + if (clientAuthConfigOverride != null) { + nextConfigs.put(BrokerSecurityConfigs.SSL_CLIENT_AUTH_CONFIG, clientAuthConfigOverride); + } + if (!sslEngineFactory.shouldBeRebuilt(nextConfigs)) { + return sslEngineFactory; + } + try { + SslEngineFactory newSslEngineFactory = instantiateSslEngineFactory(nextConfigs); + if (sslEngineFactory.keystore() == null) { + if (newSslEngineFactory.keystore() != null) { + throw new ConfigException("Cannot add SSL keystore to an existing listener for " + + "which no keystore was configured."); + } + } else { + if (newSslEngineFactory.keystore() == null) { + throw new ConfigException("Cannot remove the SSL keystore from an existing listener for " + + "which a keystore was configured."); + } + + CertificateEntries.ensureCompatible(newSslEngineFactory.keystore(), sslEngineFactory.keystore()); + } + if (sslEngineFactory.truststore() == null && newSslEngineFactory.truststore() != null) { + throw new ConfigException("Cannot add SSL truststore to an existing listener for which no " + + "truststore was configured."); + } + if (keystoreVerifiableUsingTruststore) { + if (sslEngineFactory.truststore() != null || sslEngineFactory.keystore() != null) { + SslEngineValidator.validate(sslEngineFactory, newSslEngineFactory); + } + } + return newSslEngineFactory; + } catch (Exception e) { + log.debug("Validation of dynamic config update of SSLFactory failed.", e); + throw new 
ConfigException("Validation of dynamic config update of SSLFactory failed: " + e); + } + } + + public SSLEngine createSslEngine(Socket socket) { + return createSslEngine(peerHost(socket), socket.getPort()); + } + + /** + * Prefer `createSslEngine(Socket)` if a `Socket` instance is available. If using this overload, + * avoid reverse DNS resolution in the computation of `peerHost`. + */ + public SSLEngine createSslEngine(String peerHost, int peerPort) { + if (sslEngineFactory == null) { + throw new IllegalStateException("SslFactory has not been configured."); + } + if (mode == Mode.SERVER) { + return sslEngineFactory.createServerSslEngine(peerHost, peerPort); + } else { + return sslEngineFactory.createClientSslEngine(peerHost, peerPort, endpointIdentification); + } + } + + /** + * Returns host/IP address of remote host without reverse DNS lookup to be used as the host + * for creating SSL engine. This is used as a hint for session reuse strategy and also for + * hostname verification of server hostnames. + *
+ * Scenarios:
+ * <ul>
+ *   <li>Server-side
+ *     <ul>
+ *       <li>Server accepts connection from a client. Server knows only client IP + * address. We want to avoid reverse DNS lookup of the client IP address since the server + * does not verify or use client hostname. The IP address can be used directly.</li>
+ *     </ul>
+ *   </li>
+ *   <li>Client-side
+ *     <ul>
+ *       <li>Client connects to server using hostname. No lookup is necessary + * and the hostname should be used to create the SSL engine. This hostname is validated + * against the hostname in SubjectAltName (dns) or CommonName in the certificate if + * hostname verification is enabled. Authentication fails if hostname does not match.</li>
+ *       <li>Client connects to server using IP address, but certificate contains only + * SubjectAltName (dns). Use of reverse DNS lookup to determine hostname introduces + * a security vulnerability since authentication would be reliant on a secure DNS. + * Hence hostname verification should fail in this case.</li>
+ *       <li>Client connects to server using IP address and certificate contains + * SubjectAltName (ipaddress). This could be used when Kafka is on a private network. + * If reverse DNS lookup is used, authentication would succeed using IP address if lookup + * fails and IP address is used, but authentication would fail if lookup succeeds and + * dns name is used. For consistency and to avoid dependency on a potentially insecure + * DNS, reverse DNS lookup should be avoided and the IP address specified by the client for + * connection should be used to create the SSL engine.</li>
+ *     </ul>
+ *   </li>
+ * </ul>
            + */ + private String peerHost(Socket socket) { + return new InetSocketAddress(socket.getInetAddress(), 0).getHostString(); + } + + public SslEngineFactory sslEngineFactory() { + return sslEngineFactory; + } + + /** + * Copy entries from one map into another. + * + * @param destMap The map to copy entries into. + * @param srcMap The map to copy entries from. + * @param keySet Only entries with these keys will be copied. + * @param The map key type. + * @param The map value type. + */ + private static void copyMapEntries(Map destMap, + Map srcMap, + Set keySet) { + for (K k : keySet) { + copyMapEntry(destMap, srcMap, k); + } + } + + /** + * Copy entry from one map into another. + * + * @param destMap The map to copy entries into. + * @param srcMap The map to copy entries from. + * @param key The entry with this key will be copied + * @param The map key type. + * @param The map value type. + */ + private static void copyMapEntry(Map destMap, + Map srcMap, + K key) { + if (srcMap.containsKey(key)) { + destMap.put(key, srcMap.get(key)); + } + } + + @Override + public void close() { + Utils.closeQuietly(sslEngineFactory, "close engine factory"); + } + + static class CertificateEntries { + private final String alias; + private final Principal subjectPrincipal; + private final Set> subjectAltNames; + + static List create(KeyStore keystore) throws GeneralSecurityException { + Enumeration aliases = keystore.aliases(); + List entries = new ArrayList<>(); + while (aliases.hasMoreElements()) { + String alias = aliases.nextElement(); + Certificate cert = keystore.getCertificate(alias); + if (cert instanceof X509Certificate) + entries.add(new CertificateEntries(alias, (X509Certificate) cert)); + } + return entries; + } + + static void ensureCompatible(KeyStore newKeystore, KeyStore oldKeystore) throws GeneralSecurityException { + List newEntries = CertificateEntries.create(newKeystore); + List oldEntries = CertificateEntries.create(oldKeystore); + if (newEntries.size() != oldEntries.size()) { + throw new ConfigException(String.format("Keystore entries do not match, existing store contains %d entries, new store contains %d entries", + oldEntries.size(), newEntries.size())); + } + for (int i = 0; i < newEntries.size(); i++) { + CertificateEntries newEntry = newEntries.get(i); + CertificateEntries oldEntry = oldEntries.get(i); + if (!Objects.equals(newEntry.subjectPrincipal, oldEntry.subjectPrincipal)) { + throw new ConfigException(String.format("Keystore DistinguishedName does not match: " + + " existing={alias=%s, DN=%s}, new={alias=%s, DN=%s}", + oldEntry.alias, oldEntry.subjectPrincipal, newEntry.alias, newEntry.subjectPrincipal)); + } + if (!newEntry.subjectAltNames.containsAll(oldEntry.subjectAltNames)) { + throw new ConfigException(String.format("Keystore SubjectAltNames do not match: " + + " existing={alias=%s, SAN=%s}, new={alias=%s, SAN=%s}", + oldEntry.alias, oldEntry.subjectAltNames, newEntry.alias, newEntry.subjectAltNames)); + } + } + } + + CertificateEntries(String alias, X509Certificate cert) throws GeneralSecurityException { + this.alias = alias; + this.subjectPrincipal = cert.getSubjectX500Principal(); + Collection> altNames = cert.getSubjectAlternativeNames(); + // use a set for comparison + this.subjectAltNames = altNames != null ? 
new HashSet<>(altNames) : Collections.emptySet(); + } + + @Override + public int hashCode() { + return Objects.hash(subjectPrincipal, subjectAltNames); + } + + @Override + public boolean equals(Object obj) { + if (!(obj instanceof CertificateEntries)) + return false; + CertificateEntries other = (CertificateEntries) obj; + return Objects.equals(subjectPrincipal, other.subjectPrincipal) && + Objects.equals(subjectAltNames, other.subjectAltNames); + } + + @Override + public String toString() { + return "subjectPrincipal=" + subjectPrincipal + + ", subjectAltNames=" + subjectAltNames; + } + } + + /** + * Validator used to verify dynamic update of keystore used in inter-broker communication. + * The validator checks that a successful handshake can be performed using the keystore and + * truststore configured on this SslFactory. + */ + private static class SslEngineValidator { + private static final ByteBuffer EMPTY_BUF = ByteBuffer.allocate(0); + private final SSLEngine sslEngine; + private SSLEngineResult handshakeResult; + private ByteBuffer appBuffer; + private ByteBuffer netBuffer; + + static void validate(SslEngineFactory oldEngineBuilder, + SslEngineFactory newEngineBuilder) throws SSLException { + validate(createSslEngineForValidation(oldEngineBuilder, Mode.SERVER), + createSslEngineForValidation(newEngineBuilder, Mode.CLIENT)); + validate(createSslEngineForValidation(newEngineBuilder, Mode.SERVER), + createSslEngineForValidation(oldEngineBuilder, Mode.CLIENT)); + } + + private static SSLEngine createSslEngineForValidation(SslEngineFactory sslEngineFactory, Mode mode) { + // Use empty hostname, disable hostname verification + if (mode == Mode.SERVER) { + return sslEngineFactory.createServerSslEngine("", 0); + } else { + return sslEngineFactory.createClientSslEngine("", 0, ""); + } + } + + static void validate(SSLEngine clientEngine, SSLEngine serverEngine) throws SSLException { + SslEngineValidator clientValidator = new SslEngineValidator(clientEngine); + SslEngineValidator serverValidator = new SslEngineValidator(serverEngine); + try { + clientValidator.beginHandshake(); + serverValidator.beginHandshake(); + while (!serverValidator.complete() || !clientValidator.complete()) { + clientValidator.handshake(serverValidator); + serverValidator.handshake(clientValidator); + } + } finally { + clientValidator.close(); + serverValidator.close(); + } + } + + private SslEngineValidator(SSLEngine engine) { + this.sslEngine = engine; + appBuffer = ByteBuffer.allocate(sslEngine.getSession().getApplicationBufferSize()); + netBuffer = ByteBuffer.allocate(sslEngine.getSession().getPacketBufferSize()); + } + + void beginHandshake() throws SSLException { + sslEngine.beginHandshake(); + } + void handshake(SslEngineValidator peerValidator) throws SSLException { + SSLEngineResult.HandshakeStatus handshakeStatus = sslEngine.getHandshakeStatus(); + while (true) { + switch (handshakeStatus) { + case NEED_WRAP: + if (netBuffer.position() != 0) // Wait for peer to consume previously wrapped data + return; + handshakeResult = sslEngine.wrap(EMPTY_BUF, netBuffer); + switch (handshakeResult.getStatus()) { + case OK: break; + case BUFFER_OVERFLOW: + netBuffer.compact(); + netBuffer = Utils.ensureCapacity(netBuffer, sslEngine.getSession().getPacketBufferSize()); + netBuffer.flip(); + break; + case BUFFER_UNDERFLOW: + case CLOSED: + default: + throw new SSLException("Unexpected handshake status: " + handshakeResult.getStatus()); + } + return; + case NEED_UNWRAP: + if (peerValidator.netBuffer.position() == 0) // no 
data to unwrap, return to process peer + return; + peerValidator.netBuffer.flip(); // unwrap the data from peer + handshakeResult = sslEngine.unwrap(peerValidator.netBuffer, appBuffer); + peerValidator.netBuffer.compact(); + handshakeStatus = handshakeResult.getHandshakeStatus(); + switch (handshakeResult.getStatus()) { + case OK: break; + case BUFFER_OVERFLOW: + appBuffer = Utils.ensureCapacity(appBuffer, sslEngine.getSession().getApplicationBufferSize()); + break; + case BUFFER_UNDERFLOW: + netBuffer = Utils.ensureCapacity(netBuffer, sslEngine.getSession().getPacketBufferSize()); + break; + case CLOSED: + default: + throw new SSLException("Unexpected handshake status: " + handshakeResult.getStatus()); + } + break; + case NEED_TASK: + sslEngine.getDelegatedTask().run(); + handshakeStatus = sslEngine.getHandshakeStatus(); + break; + case FINISHED: + return; + case NOT_HANDSHAKING: + if (handshakeResult.getHandshakeStatus() != SSLEngineResult.HandshakeStatus.FINISHED) + throw new SSLException("Did not finish handshake"); + return; + default: + throw new IllegalStateException("Unexpected handshake status " + handshakeStatus); + } + } + } + + boolean complete() { + return sslEngine.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.FINISHED || + sslEngine.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING; + } + + void close() { + sslEngine.closeOutbound(); + try { + sslEngine.closeInbound(); + } catch (Exception e) { + // ignore + } + } + } +} \ No newline at end of file diff --git a/clients/src/main/java/org/apache/kafka/common/security/ssl/SslPrincipalMapper.java b/clients/src/main/java/org/apache/kafka/common/security/ssl/SslPrincipalMapper.java new file mode 100644 index 0000000..33da964 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/security/ssl/SslPrincipalMapper.java @@ -0,0 +1,215 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.security.ssl; + +import java.io.IOException; +import java.util.List; +import java.util.ArrayList; +import java.util.Locale; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +import static org.apache.kafka.common.config.internals.BrokerSecurityConfigs.DEFAULT_SSL_PRINCIPAL_MAPPING_RULES; + +public class SslPrincipalMapper { + + private static final String RULE_PATTERN = "(DEFAULT)|RULE:((\\\\.|[^\\\\/])*)/((\\\\.|[^\\\\/])*)/([LU]?).*?|(.*?)"; + private static final Pattern RULE_SPLITTER = Pattern.compile("\\s*(" + RULE_PATTERN + ")\\s*(,\\s*|$)"); + private static final Pattern RULE_PARSER = Pattern.compile(RULE_PATTERN); + + private final List rules; + + public SslPrincipalMapper(String sslPrincipalMappingRules) { + this.rules = parseRules(splitRules(sslPrincipalMappingRules)); + } + + public static SslPrincipalMapper fromRules(String sslPrincipalMappingRules) { + return new SslPrincipalMapper(sslPrincipalMappingRules); + } + + private static List splitRules(String sslPrincipalMappingRules) { + if (sslPrincipalMappingRules == null) { + sslPrincipalMappingRules = DEFAULT_SSL_PRINCIPAL_MAPPING_RULES; + } + + List result = new ArrayList<>(); + Matcher matcher = RULE_SPLITTER.matcher(sslPrincipalMappingRules.trim()); + while (matcher.find()) { + result.add(matcher.group(1)); + } + + return result; + } + + private static List parseRules(List rules) { + List result = new ArrayList<>(); + for (String rule : rules) { + Matcher matcher = RULE_PARSER.matcher(rule); + if (!matcher.lookingAt()) { + throw new IllegalArgumentException("Invalid rule: " + rule); + } + if (rule.length() != matcher.end()) { + throw new IllegalArgumentException("Invalid rule: `" + rule + "`, unmatched substring: `" + rule.substring(matcher.end()) + "`"); + } + + // empty rules are ignored + if (matcher.group(1) != null) { + result.add(new Rule()); + } else if (matcher.group(2) != null) { + result.add(new Rule(matcher.group(2), + matcher.group(4), + "L".equals(matcher.group(6)), + "U".equals(matcher.group(6)))); + } + } + + return result; + } + + public String getName(String distinguishedName) throws IOException { + for (Rule r : rules) { + String principalName = r.apply(distinguishedName); + if (principalName != null) { + return principalName; + } + } + throw new NoMatchingRule("No rules apply to " + distinguishedName + ", rules " + rules); + } + + @Override + public String toString() { + return "SslPrincipalMapper(rules = " + rules + ")"; + } + + public static class NoMatchingRule extends IOException { + NoMatchingRule(String msg) { + super(msg); + } + } + + private static class Rule { + private static final Pattern BACK_REFERENCE_PATTERN = Pattern.compile("\\$(\\d+)"); + + private final boolean isDefault; + private final Pattern pattern; + private final String replacement; + private final boolean toLowerCase; + private final boolean toUpperCase; + + Rule() { + isDefault = true; + pattern = null; + replacement = null; + toLowerCase = false; + toUpperCase = false; + } + + Rule(String pattern, String replacement, boolean toLowerCase, boolean toUpperCase) { + isDefault = false; + this.pattern = pattern == null ? 
null : Pattern.compile(pattern); + this.replacement = replacement; + this.toLowerCase = toLowerCase; + this.toUpperCase = toUpperCase; + } + + String apply(String distinguishedName) { + if (isDefault) { + return distinguishedName; + } + + String result = null; + final Matcher m = pattern.matcher(distinguishedName); + + if (m.matches()) { + result = distinguishedName.replaceAll(pattern.pattern(), escapeLiteralBackReferences(replacement, m.groupCount())); + } + + if (toLowerCase && result != null) { + result = result.toLowerCase(Locale.ENGLISH); + } else if (toUpperCase & result != null) { + result = result.toUpperCase(Locale.ENGLISH); + } + + return result; + } + + //If we find a back reference that is not valid, then we will treat it as a literal string. For example, if we have 3 capturing + //groups and the Replacement Value has the value is "$1@$4", then we want to treat the $4 as a literal "$4", rather + //than attempting to use it as a back reference. + //This method was taken from Apache Nifi project : org.apache.nifi.authorization.util.IdentityMappingUtil + private String escapeLiteralBackReferences(final String unescaped, final int numCapturingGroups) { + if (numCapturingGroups == 0) { + return unescaped; + } + + String value = unescaped; + final Matcher backRefMatcher = BACK_REFERENCE_PATTERN.matcher(value); + while (backRefMatcher.find()) { + final String backRefNum = backRefMatcher.group(1); + if (backRefNum.startsWith("0")) { + continue; + } + int backRefIndex = Integer.parseInt(backRefNum); + + + // if we have a replacement value like $123, and we have less than 123 capturing groups, then + // we want to truncate the 3 and use capturing group 12; if we have less than 12 capturing groups, + // then we want to truncate the 2 and use capturing group 1; if we don't have a capturing group then + // we want to truncate the 1 and get 0. + while (backRefIndex > numCapturingGroups && backRefIndex >= 10) { + backRefIndex /= 10; + } + + if (backRefIndex > numCapturingGroups) { + final StringBuilder sb = new StringBuilder(value.length() + 1); + final int groupStart = backRefMatcher.start(1); + + sb.append(value.substring(0, groupStart - 1)); + sb.append("\\"); + sb.append(value.substring(groupStart - 1)); + value = sb.toString(); + } + } + + return value; + } + + @Override + public String toString() { + StringBuilder buf = new StringBuilder(); + if (isDefault) { + buf.append("DEFAULT"); + } else { + buf.append("RULE:"); + if (pattern != null) { + buf.append(pattern); + } + if (replacement != null) { + buf.append("/"); + buf.append(replacement); + } + if (toLowerCase) { + buf.append("/L"); + } else if (toUpperCase) { + buf.append("/U"); + } + } + return buf.toString(); + } + + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/security/token/delegation/DelegationToken.java b/clients/src/main/java/org/apache/kafka/common/security/token/delegation/DelegationToken.java new file mode 100644 index 0000000..a2141b5 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/security/token/delegation/DelegationToken.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
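For illustration only (not part of this patch): a minimal, standalone sketch of how a mapping rule of the form RULE:pattern/replacement/[LU] rewrites an X.500 distinguished name into a principal name, mirroring Rule.apply above. The rule and distinguished name below are made-up examples.

import java.util.Locale;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

public class PrincipalRuleSketch {
    public static void main(String[] args) {
        // Hypothetical rule: RULE:^CN=(.*?),.*$/$1/L  (keep the CN, then lower-case it)
        String dn = "CN=Kafka-Broker0,OU=Streaming,O=Example,C=US";
        Pattern pattern = Pattern.compile("^CN=(.*?),.*$");
        Matcher m = pattern.matcher(dn);
        if (m.matches()) {
            String principal = dn.replaceAll(pattern.pattern(), "$1").toLowerCase(Locale.ENGLISH);
            System.out.println(principal); // prints: kafka-broker0
        }
    }
}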
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.security.token.delegation; + +import org.apache.kafka.common.annotation.InterfaceStability; + +import java.security.MessageDigest; +import java.util.Arrays; +import java.util.Base64; +import java.util.Objects; + +/** + * A class representing a delegation token. + * + */ +@InterfaceStability.Evolving +public class DelegationToken { + private TokenInformation tokenInformation; + private byte[] hmac; + + public DelegationToken(TokenInformation tokenInformation, byte[] hmac) { + this.tokenInformation = tokenInformation; + this.hmac = hmac; + } + + public TokenInformation tokenInfo() { + return tokenInformation; + } + + public byte[] hmac() { + return hmac; + } + + public String hmacAsBase64String() { + return Base64.getEncoder().encodeToString(hmac); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + + DelegationToken token = (DelegationToken) o; + + return Objects.equals(tokenInformation, token.tokenInformation) && MessageDigest.isEqual(hmac, token.hmac); + } + + @Override + public int hashCode() { + int result = tokenInformation != null ? tokenInformation.hashCode() : 0; + result = 31 * result + Arrays.hashCode(hmac); + return result; + } + + @Override + public String toString() { + return "DelegationToken{" + + "tokenInformation=" + tokenInformation + + ", hmac=[*******]" + + '}'; + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/security/token/delegation/TokenInformation.java b/clients/src/main/java/org/apache/kafka/common/security/token/delegation/TokenInformation.java new file mode 100644 index 0000000..9903eb5 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/security/token/delegation/TokenInformation.java @@ -0,0 +1,133 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.security.token.delegation; + +import org.apache.kafka.common.annotation.InterfaceStability; +import org.apache.kafka.common.security.auth.KafkaPrincipal; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.Objects; + +/** + * A class representing a delegation token details. 
+ * + */ +@InterfaceStability.Evolving +public class TokenInformation { + + private KafkaPrincipal owner; + private Collection renewers; + private long issueTimestamp; + private long maxTimestamp; + private long expiryTimestamp; + private String tokenId; + + public TokenInformation(String tokenId, KafkaPrincipal owner, Collection renewers, + long issueTimestamp, long maxTimestamp, long expiryTimestamp) { + this.tokenId = tokenId; + this.owner = owner; + this.renewers = renewers; + this.issueTimestamp = issueTimestamp; + this.maxTimestamp = maxTimestamp; + this.expiryTimestamp = expiryTimestamp; + } + + public KafkaPrincipal owner() { + return owner; + } + + public String ownerAsString() { + return owner.toString(); + } + + public Collection renewers() { + return renewers; + } + + public Collection renewersAsString() { + Collection renewerList = new ArrayList<>(); + for (KafkaPrincipal renewer : renewers) { + renewerList.add(renewer.toString()); + } + return renewerList; + } + + public long issueTimestamp() { + return issueTimestamp; + } + + public long expiryTimestamp() { + return expiryTimestamp; + } + + public void setExpiryTimestamp(long expiryTimestamp) { + this.expiryTimestamp = expiryTimestamp; + } + + public String tokenId() { + return tokenId; + } + + public long maxTimestamp() { + return maxTimestamp; + } + + public boolean ownerOrRenewer(KafkaPrincipal principal) { + return owner.equals(principal) || renewers.contains(principal); + } + + @Override + public String toString() { + return "TokenInformation{" + + "owner=" + owner + + ", renewers=" + renewers + + ", issueTimestamp=" + issueTimestamp + + ", maxTimestamp=" + maxTimestamp + + ", expiryTimestamp=" + expiryTimestamp + + ", tokenId='" + tokenId + '\'' + + '}'; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + + TokenInformation that = (TokenInformation) o; + + return issueTimestamp == that.issueTimestamp && + maxTimestamp == that.maxTimestamp && + Objects.equals(owner, that.owner) && + Objects.equals(renewers, that.renewers) && + Objects.equals(tokenId, that.tokenId); + } + + @Override + public int hashCode() { + int result = owner != null ? owner.hashCode() : 0; + result = 31 * result + (renewers != null ? renewers.hashCode() : 0); + result = 31 * result + Long.hashCode(issueTimestamp); + result = 31 * result + Long.hashCode(maxTimestamp); + result = 31 * result + (tokenId != null ? tokenId.hashCode() : 0); + return result; + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/security/token/delegation/internals/DelegationTokenCache.java b/clients/src/main/java/org/apache/kafka/common/security/token/delegation/internals/DelegationTokenCache.java new file mode 100644 index 0000000..9cc913f --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/security/token/delegation/internals/DelegationTokenCache.java @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
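For illustration only (not part of this patch): a small sketch that builds a TokenInformation and wraps it in a DelegationToken using the classes added above. The token id, timestamps, and HMAC bytes are made-up placeholders; a real HMAC is computed by the broker.

import org.apache.kafka.common.security.auth.KafkaPrincipal;
import org.apache.kafka.common.security.token.delegation.DelegationToken;
import org.apache.kafka.common.security.token.delegation.TokenInformation;

import java.nio.charset.StandardCharsets;
import java.util.Collections;

public class DelegationTokenSketch {
    public static void main(String[] args) {
        KafkaPrincipal owner = new KafkaPrincipal(KafkaPrincipal.USER_TYPE, "owner");
        KafkaPrincipal renewer = new KafkaPrincipal(KafkaPrincipal.USER_TYPE, "renewer");
        long now = System.currentTimeMillis();
        TokenInformation info = new TokenInformation(
                "token-1",                              // tokenId (placeholder)
                owner,
                Collections.singletonList(renewer),
                now,                                    // issueTimestamp
                now + 7 * 24 * 60 * 60 * 1000L,         // maxTimestamp
                now + 24 * 60 * 60 * 1000L);            // expiryTimestamp
        DelegationToken token = new DelegationToken(info,
                "placeholder-hmac".getBytes(StandardCharsets.UTF_8));
        System.out.println(token.hmacAsBase64String());
        System.out.println(info.ownerOrRenewer(renewer)); // true
    }
}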
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.common.security.token.delegation.internals; + +import org.apache.kafka.common.security.authenticator.CredentialCache; +import org.apache.kafka.common.security.scram.ScramCredential; +import org.apache.kafka.common.security.scram.internals.ScramCredentialUtils; +import org.apache.kafka.common.security.scram.internals.ScramMechanism; +import org.apache.kafka.common.security.token.delegation.DelegationToken; +import org.apache.kafka.common.security.token.delegation.TokenInformation; + +import java.util.Collection; +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +public class DelegationTokenCache { + + private CredentialCache credentialCache = new CredentialCache(); + + //Cache to hold all the tokens + private Map tokenCache = new ConcurrentHashMap<>(); + + //Cache to hold hmac->tokenId mapping. This is required for renew, expire requests + private Map hmacTokenIdCache = new ConcurrentHashMap<>(); + + //Cache to hold tokenId->hmac mapping. This is required for removing entry from hmacTokenIdCache using tokenId. + private Map tokenIdHmacCache = new ConcurrentHashMap<>(); + + public DelegationTokenCache(Collection scramMechanisms) { + //Create caches for scramMechanisms + ScramCredentialUtils.createCache(credentialCache, scramMechanisms); + } + + public ScramCredential credential(String mechanism, String tokenId) { + CredentialCache.Cache cache = credentialCache.cache(mechanism, ScramCredential.class); + return cache == null ? null : cache.get(tokenId); + } + + public String owner(String tokenId) { + TokenInformation tokenInfo = tokenCache.get(tokenId); + return tokenInfo == null ? null : tokenInfo.owner().getName(); + } + + public void updateCache(DelegationToken token, Map scramCredentialMap) { + //Update TokenCache + String tokenId = token.tokenInfo().tokenId(); + addToken(tokenId, token.tokenInfo()); + String hmac = token.hmacAsBase64String(); + //Update Scram Credentials + updateCredentials(tokenId, scramCredentialMap); + //Update hmac-id cache + hmacTokenIdCache.put(hmac, tokenId); + tokenIdHmacCache.put(tokenId, hmac); + } + + public void removeCache(String tokenId) { + removeToken(tokenId); + updateCredentials(tokenId, new HashMap<>()); + } + + public String tokenIdForHmac(String base64hmac) { + return hmacTokenIdCache.get(base64hmac); + } + + public TokenInformation tokenForHmac(String base64hmac) { + String tokenId = hmacTokenIdCache.get(base64hmac); + return tokenId == null ? 
null : tokenCache.get(tokenId); + } + + public TokenInformation addToken(String tokenId, TokenInformation tokenInfo) { + return tokenCache.put(tokenId, tokenInfo); + } + + public void removeToken(String tokenId) { + TokenInformation tokenInfo = tokenCache.remove(tokenId); + if (tokenInfo != null) { + String hmac = tokenIdHmacCache.remove(tokenInfo.tokenId()); + if (hmac != null) { + hmacTokenIdCache.remove(hmac); + } + } + } + + public Collection tokens() { + return tokenCache.values(); + } + + public TokenInformation token(String tokenId) { + return tokenCache.get(tokenId); + } + + public CredentialCache.Cache credentialCache(String mechanism) { + return credentialCache.cache(mechanism, ScramCredential.class); + } + + private void updateCredentials(String tokenId, Map scramCredentialMap) { + for (String mechanism : ScramMechanism.mechanismNames()) { + CredentialCache.Cache cache = credentialCache.cache(mechanism, ScramCredential.class); + if (cache != null) { + ScramCredential credential = scramCredentialMap.get(mechanism); + if (credential == null) { + cache.remove(tokenId); + } else { + cache.put(tokenId, credential); + } + } + } + } +} \ No newline at end of file diff --git a/clients/src/main/java/org/apache/kafka/common/security/token/delegation/internals/DelegationTokenCredentialCallback.java b/clients/src/main/java/org/apache/kafka/common/security/token/delegation/internals/DelegationTokenCredentialCallback.java new file mode 100644 index 0000000..5d9eee9 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/security/token/delegation/internals/DelegationTokenCredentialCallback.java @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
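For illustration only (not part of this patch): how the DelegationTokenCache above ties the pieces together, indexing a token both by id and by Base64 HMAC. SCRAM credentials are omitted (empty map) to keep the sketch small, and the HMAC bytes are placeholders.

import org.apache.kafka.common.security.auth.KafkaPrincipal;
import org.apache.kafka.common.security.scram.internals.ScramMechanism;
import org.apache.kafka.common.security.token.delegation.DelegationToken;
import org.apache.kafka.common.security.token.delegation.TokenInformation;
import org.apache.kafka.common.security.token.delegation.internals.DelegationTokenCache;

import java.util.Collections;
import java.util.HashMap;

public class TokenCacheSketch {
    public static void main(String[] args) {
        DelegationTokenCache cache = new DelegationTokenCache(ScramMechanism.mechanismNames());
        KafkaPrincipal owner = new KafkaPrincipal(KafkaPrincipal.USER_TYPE, "owner");
        long now = System.currentTimeMillis();
        TokenInformation info = new TokenInformation("token-1", owner,
                Collections.emptyList(), now, now + 86_400_000L, now + 3_600_000L);
        DelegationToken token = new DelegationToken(info, new byte[] {1, 2, 3}); // placeholder HMAC
        cache.updateCache(token, new HashMap<>());   // no SCRAM credentials in this sketch
        System.out.println(cache.owner("token-1"));                               // owner
        System.out.println(cache.tokenIdForHmac(token.hmacAsBase64String()));     // token-1
        cache.removeCache("token-1");
    }
}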
+ */ +package org.apache.kafka.common.security.token.delegation.internals; + +import org.apache.kafka.common.security.scram.ScramCredentialCallback; + +public class DelegationTokenCredentialCallback extends ScramCredentialCallback { + private String tokenOwner; + private Long tokenExpiryTimestamp; + + public void tokenOwner(String tokenOwner) { + this.tokenOwner = tokenOwner; + } + + public String tokenOwner() { + return tokenOwner; + } + + public void tokenExpiryTimestamp(Long tokenExpiryTimestamp) { + this.tokenExpiryTimestamp = tokenExpiryTimestamp; + } + + public Long tokenExpiryTimestamp() { + return tokenExpiryTimestamp; + } +} \ No newline at end of file diff --git a/clients/src/main/java/org/apache/kafka/common/serialization/ByteArrayDeserializer.java b/clients/src/main/java/org/apache/kafka/common/serialization/ByteArrayDeserializer.java new file mode 100644 index 0000000..1147f45 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/serialization/ByteArrayDeserializer.java @@ -0,0 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.serialization; + +public class ByteArrayDeserializer implements Deserializer { + + @Override + public byte[] deserialize(String topic, byte[] data) { + return data; + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/serialization/ByteArraySerializer.java b/clients/src/main/java/org/apache/kafka/common/serialization/ByteArraySerializer.java new file mode 100644 index 0000000..6bebaa6 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/serialization/ByteArraySerializer.java @@ -0,0 +1,24 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.serialization; + +public class ByteArraySerializer implements Serializer { + @Override + public byte[] serialize(String topic, byte[] data) { + return data; + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/serialization/ByteBufferDeserializer.java b/clients/src/main/java/org/apache/kafka/common/serialization/ByteBufferDeserializer.java new file mode 100644 index 0000000..0dfcf5f --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/serialization/ByteBufferDeserializer.java @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.serialization; + +import java.nio.ByteBuffer; + +public class ByteBufferDeserializer implements Deserializer { + public ByteBuffer deserialize(String topic, byte[] data) { + if (data == null) + return null; + + return ByteBuffer.wrap(data); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/serialization/ByteBufferSerializer.java b/clients/src/main/java/org/apache/kafka/common/serialization/ByteBufferSerializer.java new file mode 100644 index 0000000..9fb1254 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/serialization/ByteBufferSerializer.java @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
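For illustration only (not part of this patch): serializers such as ByteArraySerializer are usually not instantiated directly but configured by class name on a client, for example through ProducerConfig in the clients module. The bootstrap address below is a placeholder.

import org.apache.kafka.clients.producer.ProducerConfig;

import java.util.Properties;

public class SerializerConfigSketch {
    public static void main(String[] args) {
        Properties props = new Properties();
        props.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9092"); // placeholder
        props.put(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG,
                "org.apache.kafka.common.serialization.ByteArraySerializer");
        props.put(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG,
                "org.apache.kafka.common.serialization.ByteArraySerializer");
        System.out.println(props);
    }
}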
+ */ +package org.apache.kafka.common.serialization; + +import java.nio.ByteBuffer; + +public class ByteBufferSerializer implements Serializer { + public byte[] serialize(String topic, ByteBuffer data) { + if (data == null) + return null; + + data.rewind(); + + if (data.hasArray()) { + byte[] arr = data.array(); + if (data.arrayOffset() == 0 && arr.length == data.remaining()) { + return arr; + } + } + + byte[] ret = new byte[data.remaining()]; + data.get(ret, 0, ret.length); + data.rewind(); + return ret; + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/serialization/BytesDeserializer.java b/clients/src/main/java/org/apache/kafka/common/serialization/BytesDeserializer.java new file mode 100644 index 0000000..1350dca --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/serialization/BytesDeserializer.java @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.serialization; + +import org.apache.kafka.common.utils.Bytes; + +public class BytesDeserializer implements Deserializer { + public Bytes deserialize(String topic, byte[] data) { + if (data == null) + return null; + + return new Bytes(data); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/serialization/BytesSerializer.java b/clients/src/main/java/org/apache/kafka/common/serialization/BytesSerializer.java new file mode 100644 index 0000000..62ea6ec --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/serialization/BytesSerializer.java @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
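For illustration only (not part of this patch): a round trip through ByteBufferSerializer and ByteBufferDeserializer. When the buffer exactly wraps its backing array, serialize() returns that array without copying; otherwise it copies the remaining bytes after rewinding.

import org.apache.kafka.common.serialization.ByteBufferDeserializer;
import org.apache.kafka.common.serialization.ByteBufferSerializer;

import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;

public class ByteBufferSerdeSketch {
    public static void main(String[] args) {
        ByteBuffer buffer = ByteBuffer.wrap("payload".getBytes(StandardCharsets.UTF_8));
        byte[] bytes = new ByteBufferSerializer().serialize("topic", buffer);
        ByteBuffer copy = new ByteBufferDeserializer().deserialize("topic", bytes);
        System.out.println(StandardCharsets.UTF_8.decode(copy)); // payload
    }
}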
+ */ +package org.apache.kafka.common.serialization; + +import org.apache.kafka.common.utils.Bytes; + +public class BytesSerializer implements Serializer { + public byte[] serialize(String topic, Bytes data) { + if (data == null) + return null; + + return data.get(); + } +} + diff --git a/clients/src/main/java/org/apache/kafka/common/serialization/Deserializer.java b/clients/src/main/java/org/apache/kafka/common/serialization/Deserializer.java new file mode 100644 index 0000000..eb56485 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/serialization/Deserializer.java @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.serialization; + +import org.apache.kafka.common.header.Headers; + +import java.io.Closeable; +import java.util.Map; + +/** + * An interface for converting bytes to objects. + * + * A class that implements this interface is expected to have a constructor with no parameters. + *
+ * <p>
            + * Implement {@link org.apache.kafka.common.ClusterResourceListener} to receive cluster metadata once it's available. Please see the class documentation for ClusterResourceListener for more information. + * + * @param Type to be deserialized into. + */ +public interface Deserializer extends Closeable { + + /** + * Configure this class. + * @param configs configs in key/value pairs + * @param isKey whether is for key or value + */ + default void configure(Map configs, boolean isKey) { + // intentionally left blank + } + + /** + * Deserialize a record value from a byte array into a value or object. + * @param topic topic associated with the data + * @param data serialized bytes; may be null; implementations are recommended to handle null by returning a value or null rather than throwing an exception. + * @return deserialized typed data; may be null + */ + T deserialize(String topic, byte[] data); + + /** + * Deserialize a record value from a byte array into a value or object. + * @param topic topic associated with the data + * @param headers headers associated with the record; may be empty. + * @param data serialized bytes; may be null; implementations are recommended to handle null by returning a value or null rather than throwing an exception. + * @return deserialized typed data; may be null + */ + default T deserialize(String topic, Headers headers, byte[] data) { + return deserialize(topic, data); + } + + /** + * Close this deserializer. + *
+     * <p>
            + * This method must be idempotent as it may be called multiple times. + */ + @Override + default void close() { + // intentionally left blank + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/serialization/DoubleDeserializer.java b/clients/src/main/java/org/apache/kafka/common/serialization/DoubleDeserializer.java new file mode 100644 index 0000000..0fa1cce --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/serialization/DoubleDeserializer.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.serialization; + +import org.apache.kafka.common.errors.SerializationException; + +public class DoubleDeserializer implements Deserializer { + + @Override + public Double deserialize(String topic, byte[] data) { + if (data == null) + return null; + if (data.length != 8) { + throw new SerializationException("Size of data received by Deserializer is not 8"); + } + + long value = 0; + for (byte b : data) { + value <<= 8; + value |= b & 0xFF; + } + return Double.longBitsToDouble(value); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/serialization/DoubleSerializer.java b/clients/src/main/java/org/apache/kafka/common/serialization/DoubleSerializer.java new file mode 100644 index 0000000..99781b5 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/serialization/DoubleSerializer.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
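For illustration only (not part of this patch): a minimal Deserializer implementation following the contract described above (public no-arg constructor, null-tolerant deserialize). The class name is made up.

import org.apache.kafka.common.serialization.Deserializer;

import java.nio.charset.StandardCharsets;

public class Utf8StringDeserializer implements Deserializer<String> {
    @Override
    public String deserialize(String topic, byte[] data) {
        // Return null for null input instead of throwing, as the Javadoc above recommends.
        return data == null ? null : new String(data, StandardCharsets.UTF_8);
    }
    // configure() and close() keep their default no-op implementations.
}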
+ */ +package org.apache.kafka.common.serialization; + +public class DoubleSerializer implements Serializer { + @Override + public byte[] serialize(String topic, Double data) { + if (data == null) + return null; + + long bits = Double.doubleToLongBits(data); + return new byte[] { + (byte) (bits >>> 56), + (byte) (bits >>> 48), + (byte) (bits >>> 40), + (byte) (bits >>> 32), + (byte) (bits >>> 24), + (byte) (bits >>> 16), + (byte) (bits >>> 8), + (byte) bits + }; + } +} \ No newline at end of file diff --git a/clients/src/main/java/org/apache/kafka/common/serialization/FloatDeserializer.java b/clients/src/main/java/org/apache/kafka/common/serialization/FloatDeserializer.java new file mode 100644 index 0000000..0903177 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/serialization/FloatDeserializer.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.serialization; + +import org.apache.kafka.common.errors.SerializationException; + +public class FloatDeserializer implements Deserializer { + @Override + public Float deserialize(final String topic, final byte[] data) { + if (data == null) + return null; + if (data.length != 4) { + throw new SerializationException("Size of data received by Deserializer is not 4"); + } + + int value = 0; + for (byte b : data) { + value <<= 8; + value |= b & 0xFF; + } + return Float.intBitsToFloat(value); + } +} \ No newline at end of file diff --git a/clients/src/main/java/org/apache/kafka/common/serialization/FloatSerializer.java b/clients/src/main/java/org/apache/kafka/common/serialization/FloatSerializer.java new file mode 100644 index 0000000..aa72d43 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/serialization/FloatSerializer.java @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
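For illustration only (not part of this patch): DoubleSerializer and DoubleDeserializer round-trip a value through the fixed 8-byte big-endian encoding shown above.

import org.apache.kafka.common.serialization.DoubleDeserializer;
import org.apache.kafka.common.serialization.DoubleSerializer;

public class DoubleSerdeSketch {
    public static void main(String[] args) {
        byte[] bytes = new DoubleSerializer().serialize("topic", 3.25d);
        System.out.println(bytes.length);                                  // 8
        Double copy = new DoubleDeserializer().deserialize("topic", bytes);
        System.out.println(copy);                                          // 3.25
    }
}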
+ */ +package org.apache.kafka.common.serialization; + +public class FloatSerializer implements Serializer { + @Override + public byte[] serialize(final String topic, final Float data) { + if (data == null) + return null; + + long bits = Float.floatToRawIntBits(data); + return new byte[] { + (byte) (bits >>> 24), + (byte) (bits >>> 16), + (byte) (bits >>> 8), + (byte) bits + }; + } +} \ No newline at end of file diff --git a/clients/src/main/java/org/apache/kafka/common/serialization/IntegerDeserializer.java b/clients/src/main/java/org/apache/kafka/common/serialization/IntegerDeserializer.java new file mode 100644 index 0000000..20ca63f --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/serialization/IntegerDeserializer.java @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.serialization; + +import org.apache.kafka.common.errors.SerializationException; + +public class IntegerDeserializer implements Deserializer { + public Integer deserialize(String topic, byte[] data) { + if (data == null) + return null; + if (data.length != 4) { + throw new SerializationException("Size of data received by IntegerDeserializer is not 4"); + } + + int value = 0; + for (byte b : data) { + value <<= 8; + value |= b & 0xFF; + } + return value; + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/serialization/IntegerSerializer.java b/clients/src/main/java/org/apache/kafka/common/serialization/IntegerSerializer.java new file mode 100644 index 0000000..8ab5310 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/serialization/IntegerSerializer.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.serialization; + +public class IntegerSerializer implements Serializer { + public byte[] serialize(String topic, Integer data) { + if (data == null) + return null; + + return new byte[] { + (byte) (data >>> 24), + (byte) (data >>> 16), + (byte) (data >>> 8), + data.byteValue() + }; + } +} \ No newline at end of file diff --git a/clients/src/main/java/org/apache/kafka/common/serialization/ListDeserializer.java b/clients/src/main/java/org/apache/kafka/common/serialization/ListDeserializer.java new file mode 100644 index 0000000..ad82c0b --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/serialization/ListDeserializer.java @@ -0,0 +1,201 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.serialization; + +import static org.apache.kafka.common.serialization.Serdes.ListSerde.SerializationStrategy; +import static org.apache.kafka.common.utils.Utils.mkEntry; +import static org.apache.kafka.common.utils.Utils.mkMap; + +import java.io.ByteArrayInputStream; +import java.io.DataInputStream; +import java.io.IOException; +import java.lang.reflect.Constructor; +import java.lang.reflect.InvocationTargetException; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.kafka.clients.CommonClientConfigs; +import org.apache.kafka.common.KafkaException; +import org.apache.kafka.common.config.ConfigException; +import org.apache.kafka.common.errors.SerializationException; +import org.apache.kafka.common.serialization.Serdes.ListSerde; +import org.apache.kafka.common.utils.Utils; + +public class ListDeserializer implements Deserializer> { + + final Logger log = LoggerFactory.getLogger(ListDeserializer.class); + + private static final Map>, Integer> FIXED_LENGTH_DESERIALIZERS = mkMap( + mkEntry(ShortDeserializer.class, Short.BYTES), + mkEntry(IntegerDeserializer.class, Integer.BYTES), + mkEntry(FloatDeserializer.class, Float.BYTES), + mkEntry(LongDeserializer.class, Long.BYTES), + mkEntry(DoubleDeserializer.class, Double.BYTES), + mkEntry(UUIDDeserializer.class, 36) + ); + + private Deserializer inner; + private Class listClass; + private Integer primitiveSize; + + public ListDeserializer() {} + + public > ListDeserializer(Class listClass, Deserializer inner) { + if (listClass == null || inner == null) { + log.error("Could not construct ListDeserializer as not all required parameters were present -- listClass: {}, inner: {}", listClass, inner); + throw new IllegalArgumentException("ListDeserializer requires both \"listClass\" and \"innerDeserializer\" parameters to be provided during initialization"); + } + this.listClass = listClass; + this.inner = inner; + this.primitiveSize = 
FIXED_LENGTH_DESERIALIZERS.get(inner.getClass()); + } + + public Deserializer innerDeserializer() { + return inner; + } + + @Override + public void configure(Map configs, boolean isKey) { + if (listClass != null || inner != null) { + log.error("Could not configure ListDeserializer as some parameters were already set -- listClass: {}, inner: {}", listClass, inner); + throw new ConfigException("List deserializer was already initialized using a non-default constructor"); + } + configureListClass(configs, isKey); + configureInnerSerde(configs, isKey); + } + + private void configureListClass(Map configs, boolean isKey) { + String listTypePropertyName = isKey ? CommonClientConfigs.DEFAULT_LIST_KEY_SERDE_TYPE_CLASS : CommonClientConfigs.DEFAULT_LIST_VALUE_SERDE_TYPE_CLASS; + final Object listClassOrName = configs.get(listTypePropertyName); + if (listClassOrName == null) { + throw new ConfigException("Not able to determine the list class because it was neither passed via the constructor nor set in the config."); + } + try { + if (listClassOrName instanceof String) { + listClass = Utils.loadClass((String) listClassOrName, Object.class); + } else if (listClassOrName instanceof Class) { + listClass = (Class) listClassOrName; + } else { + throw new KafkaException("Could not determine the list class instance using \"" + listTypePropertyName + "\" property."); + } + } catch (final ClassNotFoundException e) { + throw new ConfigException(listTypePropertyName, listClassOrName, "Deserializer's list class \"" + listClassOrName + "\" could not be found."); + } + } + + @SuppressWarnings("unchecked") + private void configureInnerSerde(Map configs, boolean isKey) { + String innerSerdePropertyName = isKey ? CommonClientConfigs.DEFAULT_LIST_KEY_SERDE_INNER_CLASS : CommonClientConfigs.DEFAULT_LIST_VALUE_SERDE_INNER_CLASS; + final Object innerSerdeClassOrName = configs.get(innerSerdePropertyName); + if (innerSerdeClassOrName == null) { + throw new ConfigException("Not able to determine the inner serde class because it was neither passed via the constructor nor set in the config."); + } + try { + if (innerSerdeClassOrName instanceof String) { + inner = Utils.newInstance((String) innerSerdeClassOrName, Serde.class).deserializer(); + } else if (innerSerdeClassOrName instanceof Class) { + inner = (Deserializer) ((Serde) Utils.newInstance((Class) innerSerdeClassOrName)).deserializer(); + } else { + throw new KafkaException("Could not determine the inner serde class instance using \"" + innerSerdePropertyName + "\" property."); + } + inner.configure(configs, isKey); + primitiveSize = FIXED_LENGTH_DESERIALIZERS.get(inner.getClass()); + } catch (final ClassNotFoundException e) { + throw new ConfigException(innerSerdePropertyName, innerSerdeClassOrName, "Deserializer's inner serde class \"" + innerSerdeClassOrName + "\" could not be found."); + } + } + + @SuppressWarnings("unchecked") + private List createListInstance(int listSize) { + try { + Constructor> listConstructor; + try { + listConstructor = (Constructor>) listClass.getConstructor(Integer.TYPE); + return listConstructor.newInstance(listSize); + } catch (NoSuchMethodException e) { + listConstructor = (Constructor>) listClass.getConstructor(); + return listConstructor.newInstance(); + } + } catch (InstantiationException | IllegalAccessException | NoSuchMethodException | + IllegalArgumentException | InvocationTargetException e) { + log.error("Failed to construct list due to ", e); + throw new KafkaException("Could not construct a list instance of \"" + 
listClass.getCanonicalName() + "\"", e); + } + } + + private SerializationStrategy parseSerializationStrategyFlag(final int serializationStrategyFlag) throws IOException { + if (serializationStrategyFlag < 0 || serializationStrategyFlag >= SerializationStrategy.VALUES.length) { + throw new SerializationException("Invalid serialization strategy flag value"); + } + return SerializationStrategy.VALUES[serializationStrategyFlag]; + } + + private List deserializeNullIndexList(final DataInputStream dis) throws IOException { + int nullIndexListSize = dis.readInt(); + List nullIndexList = new ArrayList<>(nullIndexListSize); + while (nullIndexListSize != 0) { + nullIndexList.add(dis.readInt()); + nullIndexListSize--; + } + return nullIndexList; + } + + @Override + public List deserialize(String topic, byte[] data) { + if (data == null) { + return null; + } + try (final DataInputStream dis = new DataInputStream(new ByteArrayInputStream(data))) { + SerializationStrategy serStrategy = parseSerializationStrategyFlag(dis.readByte()); + List nullIndexList = null; + if (serStrategy == SerializationStrategy.CONSTANT_SIZE) { + // In CONSTANT_SIZE strategy, indexes of null entries are decoded from a null index list + nullIndexList = deserializeNullIndexList(dis); + } + final int size = dis.readInt(); + List deserializedList = createListInstance(size); + for (int i = 0; i < size; i++) { + int entrySize = serStrategy == SerializationStrategy.CONSTANT_SIZE ? primitiveSize : dis.readInt(); + if (entrySize == ListSerde.NULL_ENTRY_VALUE || (nullIndexList != null && nullIndexList.contains(i))) { + deserializedList.add(null); + continue; + } + byte[] payload = new byte[entrySize]; + if (dis.read(payload) == -1) { + log.error("Ran out of bytes in serialized list"); + log.trace("Deserialized list so far: {}", deserializedList); // avoid logging actual data above TRACE level since it may contain sensitive information + throw new SerializationException("End of the stream was reached prematurely"); + } + deserializedList.add(inner.deserialize(topic, payload)); + } + return deserializedList; + } catch (IOException e) { + throw new KafkaException("Unable to deserialize into a List", e); + } + } + + @Override + public void close() { + if (inner != null) { + inner.close(); + } + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/serialization/ListSerializer.java b/clients/src/main/java/org/apache/kafka/common/serialization/ListSerializer.java new file mode 100644 index 0000000..2c15256 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/serialization/ListSerializer.java @@ -0,0 +1,149 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.serialization; + +import java.util.ArrayList; +import java.util.Iterator; +import org.apache.kafka.clients.CommonClientConfigs; +import org.apache.kafka.common.KafkaException; +import org.apache.kafka.common.config.ConfigException; +import org.apache.kafka.common.utils.Utils; + +import java.io.ByteArrayOutputStream; +import java.io.DataOutputStream; +import java.io.IOException; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import static org.apache.kafka.common.serialization.Serdes.ListSerde.SerializationStrategy; + +public class ListSerializer implements Serializer> { + + final Logger log = LoggerFactory.getLogger(ListSerializer.class); + + private static final List>> FIXED_LENGTH_SERIALIZERS = Arrays.asList( + ShortSerializer.class, + IntegerSerializer.class, + FloatSerializer.class, + LongSerializer.class, + DoubleSerializer.class, + UUIDSerializer.class); + + private Serializer inner; + private SerializationStrategy serStrategy; + + public ListSerializer() {} + + public ListSerializer(Serializer inner) { + if (inner == null) { + throw new IllegalArgumentException("ListSerializer requires \"serializer\" parameter to be provided during initialization"); + } + this.inner = inner; + this.serStrategy = FIXED_LENGTH_SERIALIZERS.contains(inner.getClass()) ? SerializationStrategy.CONSTANT_SIZE : SerializationStrategy.VARIABLE_SIZE; + } + + public Serializer getInnerSerializer() { + return inner; + } + + @SuppressWarnings("unchecked") + @Override + public void configure(Map configs, boolean isKey) { + if (inner != null) { + log.error("Could not configure ListSerializer as the parameter has already been set -- inner: {}", inner); + throw new ConfigException("List serializer was already initialized using a non-default constructor"); + } + final String innerSerdePropertyName = isKey ? CommonClientConfigs.DEFAULT_LIST_KEY_SERDE_INNER_CLASS : CommonClientConfigs.DEFAULT_LIST_VALUE_SERDE_INNER_CLASS; + final Object innerSerdeClassOrName = configs.get(innerSerdePropertyName); + if (innerSerdeClassOrName == null) { + throw new ConfigException("Not able to determine the serializer class because it was neither passed via the constructor nor set in the config."); + } + try { + if (innerSerdeClassOrName instanceof String) { + inner = Utils.newInstance((String) innerSerdeClassOrName, Serde.class).serializer(); + } else if (innerSerdeClassOrName instanceof Class) { + inner = (Serializer) ((Serde) Utils.newInstance((Class) innerSerdeClassOrName)).serializer(); + } else { + throw new KafkaException("Could not create a serializer class instance using \"" + innerSerdePropertyName + "\" property."); + } + inner.configure(configs, isKey); + serStrategy = FIXED_LENGTH_SERIALIZERS.contains(inner.getClass()) ? 
SerializationStrategy.CONSTANT_SIZE : SerializationStrategy.VARIABLE_SIZE; + } catch (final ClassNotFoundException e) { + throw new ConfigException(innerSerdePropertyName, innerSerdeClassOrName, "Serializer class " + innerSerdeClassOrName + " could not be found."); + } + } + + private void serializeNullIndexList(final DataOutputStream out, List data) throws IOException { + int i = 0; + List nullIndexList = new ArrayList<>(); + for (Iterator it = data.listIterator(); it.hasNext(); i++) { + if (it.next() == null) { + nullIndexList.add(i); + } + } + out.writeInt(nullIndexList.size()); + for (int nullIndex : nullIndexList) { + out.writeInt(nullIndex); + } + } + + @Override + public byte[] serialize(String topic, List data) { + if (data == null) { + return null; + } + try (final ByteArrayOutputStream baos = new ByteArrayOutputStream(); + final DataOutputStream out = new DataOutputStream(baos)) { + out.writeByte(serStrategy.ordinal()); // write serialization strategy flag + if (serStrategy == SerializationStrategy.CONSTANT_SIZE) { + // In CONSTANT_SIZE strategy, indexes of null entries are encoded in a null index list + serializeNullIndexList(out, data); + } + final int size = data.size(); + out.writeInt(size); + for (Inner entry : data) { + if (entry == null) { + if (serStrategy == SerializationStrategy.VARIABLE_SIZE) { + out.writeInt(Serdes.ListSerde.NULL_ENTRY_VALUE); + } + } else { + final byte[] bytes = inner.serialize(topic, entry); + if (serStrategy == SerializationStrategy.VARIABLE_SIZE) { + out.writeInt(bytes.length); + } + out.write(bytes); + } + } + return baos.toByteArray(); + } catch (IOException e) { + log.error("Failed to serialize list due to", e); + log.trace("List that could not be serialized: {}", data); // avoid logging actual data above TRACE level since it may contain sensitive information + throw new KafkaException("Failed to serialize List", e); + } + } + + @Override + public void close() { + if (inner != null) { + inner.close(); + } + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/serialization/LongDeserializer.java b/clients/src/main/java/org/apache/kafka/common/serialization/LongDeserializer.java new file mode 100644 index 0000000..1e445d2 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/serialization/LongDeserializer.java @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
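For illustration only (not part of this patch): ListSerializer and ListDeserializer round-trip a list. With a fixed-width inner serializer such as IntegerSerializer, the CONSTANT_SIZE strategy is chosen and null entries are recorded in the null-index list rather than with per-entry length prefixes.

import org.apache.kafka.common.serialization.IntegerDeserializer;
import org.apache.kafka.common.serialization.IntegerSerializer;
import org.apache.kafka.common.serialization.ListDeserializer;
import org.apache.kafka.common.serialization.ListSerializer;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class ListSerdeSketch {
    public static void main(String[] args) {
        ListSerializer<Integer> serializer = new ListSerializer<>(new IntegerSerializer());
        ListDeserializer<Integer> deserializer =
                new ListDeserializer<>(ArrayList.class, new IntegerDeserializer());
        List<Integer> original = Arrays.asList(1, null, 3);    // nulls survive the round trip
        byte[] bytes = serializer.serialize("topic", original);
        List<Integer> copy = deserializer.deserialize("topic", bytes);
        System.out.println(copy);                              // [1, null, 3]
    }
}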
+ */ +package org.apache.kafka.common.serialization; + +import org.apache.kafka.common.errors.SerializationException; + +public class LongDeserializer implements Deserializer { + public Long deserialize(String topic, byte[] data) { + if (data == null) + return null; + if (data.length != 8) { + throw new SerializationException("Size of data received by LongDeserializer is not 8"); + } + + long value = 0; + for (byte b : data) { + value <<= 8; + value |= b & 0xFF; + } + return value; + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/serialization/LongSerializer.java b/clients/src/main/java/org/apache/kafka/common/serialization/LongSerializer.java new file mode 100644 index 0000000..436f0e0 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/serialization/LongSerializer.java @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.serialization; + +public class LongSerializer implements Serializer { + public byte[] serialize(String topic, Long data) { + if (data == null) + return null; + + return new byte[] { + (byte) (data >>> 56), + (byte) (data >>> 48), + (byte) (data >>> 40), + (byte) (data >>> 32), + (byte) (data >>> 24), + (byte) (data >>> 16), + (byte) (data >>> 8), + data.byteValue() + }; + } +} \ No newline at end of file diff --git a/clients/src/main/java/org/apache/kafka/common/serialization/Serde.java b/clients/src/main/java/org/apache/kafka/common/serialization/Serde.java new file mode 100644 index 0000000..5b052e6 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/serialization/Serde.java @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.serialization; + +import java.io.Closeable; +import java.util.Map; + +/** + * The interface for wrapping a serializer and deserializer for the given data type. + * + * @param Type to be serialized from and deserialized into. 
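As a quick check of the big-endian layout produced by the LongSerializer above and reversed by LongDeserializer, a small round-trip sketch (the topic name is irrelevant to these serializers):

import org.apache.kafka.common.serialization.LongDeserializer;
import org.apache.kafka.common.serialization.LongSerializer;

public class LongRoundTrip {
    public static void main(String[] args) {
        byte[] bytes = new LongSerializer().serialize("topic", 0x0102030405060708L);
        // bytes = {0x01, 0x02, ..., 0x08}: most significant byte first
        long back = new LongDeserializer().deserialize("topic", bytes);
        // back == 0x0102030405060708L: each byte is shifted in via value = (value << 8) | (b & 0xFF)
    }
}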
+ * + * A class that implements this interface is expected to have a constructor with no parameter. + */ +public interface Serde extends Closeable { + + /** + * Configure this class, which will configure the underlying serializer and deserializer. + * + * @param configs configs in key/value pairs + * @param isKey whether is for key or value + */ + default void configure(Map configs, boolean isKey) { + // intentionally left blank + } + + /** + * Close this serde class, which will close the underlying serializer and deserializer. + *
<p>
            + * This method has to be idempotent because it might be called multiple times. + */ + @Override + default void close() { + // intentionally left blank + } + + Serializer serializer(); + + Deserializer deserializer(); +} diff --git a/clients/src/main/java/org/apache/kafka/common/serialization/Serdes.java b/clients/src/main/java/org/apache/kafka/common/serialization/Serdes.java new file mode 100644 index 0000000..4a150e0 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/serialization/Serdes.java @@ -0,0 +1,298 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.serialization; + +import org.apache.kafka.common.utils.Bytes; + +import java.nio.ByteBuffer; +import java.util.List; +import java.util.Map; +import java.util.UUID; + +/** + * Factory for creating serializers / deserializers. + */ +public class Serdes { + + static public class WrapperSerde implements Serde { + final private Serializer serializer; + final private Deserializer deserializer; + + public WrapperSerde(Serializer serializer, Deserializer deserializer) { + this.serializer = serializer; + this.deserializer = deserializer; + } + + @Override + public void configure(Map configs, boolean isKey) { + serializer.configure(configs, isKey); + deserializer.configure(configs, isKey); + } + + @Override + public void close() { + serializer.close(); + deserializer.close(); + } + + @Override + public Serializer serializer() { + return serializer; + } + + @Override + public Deserializer deserializer() { + return deserializer; + } + } + + static public final class VoidSerde extends WrapperSerde { + public VoidSerde() { + super(new VoidSerializer(), new VoidDeserializer()); + } + } + + static public final class LongSerde extends WrapperSerde { + public LongSerde() { + super(new LongSerializer(), new LongDeserializer()); + } + } + + static public final class IntegerSerde extends WrapperSerde { + public IntegerSerde() { + super(new IntegerSerializer(), new IntegerDeserializer()); + } + } + + static public final class ShortSerde extends WrapperSerde { + public ShortSerde() { + super(new ShortSerializer(), new ShortDeserializer()); + } + } + + static public final class FloatSerde extends WrapperSerde { + public FloatSerde() { + super(new FloatSerializer(), new FloatDeserializer()); + } + } + + static public final class DoubleSerde extends WrapperSerde { + public DoubleSerde() { + super(new DoubleSerializer(), new DoubleDeserializer()); + } + } + + static public final class StringSerde extends WrapperSerde { + public StringSerde() { + super(new StringSerializer(), new StringDeserializer()); + } + } + + static public final class ByteBufferSerde extends WrapperSerde { + public ByteBufferSerde() { + super(new ByteBufferSerializer(), new 
ByteBufferDeserializer()); + } + } + + static public final class BytesSerde extends WrapperSerde { + public BytesSerde() { + super(new BytesSerializer(), new BytesDeserializer()); + } + } + + static public final class ByteArraySerde extends WrapperSerde { + public ByteArraySerde() { + super(new ByteArraySerializer(), new ByteArrayDeserializer()); + } + } + + static public final class UUIDSerde extends WrapperSerde { + public UUIDSerde() { + super(new UUIDSerializer(), new UUIDDeserializer()); + } + } + + static public final class ListSerde extends WrapperSerde> { + + final static int NULL_ENTRY_VALUE = -1; + + enum SerializationStrategy { + CONSTANT_SIZE, + VARIABLE_SIZE; + + public static final SerializationStrategy[] VALUES = SerializationStrategy.values(); + } + + public ListSerde() { + super(new ListSerializer<>(), new ListDeserializer<>()); + } + + public > ListSerde(Class listClass, Serde serde) { + super(new ListSerializer<>(serde.serializer()), new ListDeserializer<>(listClass, serde.deserializer())); + } + + } + + @SuppressWarnings("unchecked") + static public Serde serdeFrom(Class type) { + if (String.class.isAssignableFrom(type)) { + return (Serde) String(); + } + + if (Short.class.isAssignableFrom(type)) { + return (Serde) Short(); + } + + if (Integer.class.isAssignableFrom(type)) { + return (Serde) Integer(); + } + + if (Long.class.isAssignableFrom(type)) { + return (Serde) Long(); + } + + if (Float.class.isAssignableFrom(type)) { + return (Serde) Float(); + } + + if (Double.class.isAssignableFrom(type)) { + return (Serde) Double(); + } + + if (byte[].class.isAssignableFrom(type)) { + return (Serde) ByteArray(); + } + + if (ByteBuffer.class.isAssignableFrom(type)) { + return (Serde) ByteBuffer(); + } + + if (Bytes.class.isAssignableFrom(type)) { + return (Serde) Bytes(); + } + + if (UUID.class.isAssignableFrom(type)) { + return (Serde) UUID(); + } + + // TODO: we can also serializes objects of type T using generic Java serialization by default + throw new IllegalArgumentException("Unknown class for built-in serializer. Supported types are: " + + "String, Short, Integer, Long, Float, Double, ByteArray, ByteBuffer, Bytes, UUID"); + } + + /** + * Construct a serde object from separate serializer and deserializer + * + * @param serializer must not be null. + * @param deserializer must not be null. + */ + static public Serde serdeFrom(final Serializer serializer, final Deserializer deserializer) { + if (serializer == null) { + throw new IllegalArgumentException("serializer must not be null"); + } + if (deserializer == null) { + throw new IllegalArgumentException("deserializer must not be null"); + } + + return new WrapperSerde<>(serializer, deserializer); + } + + /** + * A serde for nullable {@code Long} type. + */ + static public Serde Long() { + return new LongSerde(); + } + + /** + * A serde for nullable {@code Integer} type. + */ + static public Serde Integer() { + return new IntegerSerde(); + } + + /** + * A serde for nullable {@code Short} type. + */ + static public Serde Short() { + return new ShortSerde(); + } + + /** + * A serde for nullable {@code Float} type. + */ + static public Serde Float() { + return new FloatSerde(); + } + + /** + * A serde for nullable {@code Double} type. + */ + static public Serde Double() { + return new DoubleSerde(); + } + + /** + * A serde for nullable {@code String} type. + */ + static public Serde String() { + return new StringSerde(); + } + + /** + * A serde for nullable {@code ByteBuffer} type. 
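The two serdeFrom overloads above can be used as in the following sketch: serdeFrom(Class) resolves to the matching built-in serde, while serdeFrom(serializer, deserializer) wraps any pair in a WrapperSerde.

import org.apache.kafka.common.serialization.Serde;
import org.apache.kafka.common.serialization.Serdes;
import org.apache.kafka.common.serialization.StringDeserializer;
import org.apache.kafka.common.serialization.StringSerializer;

public class SerdeFromExample {
    public static void main(String[] args) {
        Serde<String> byType = Serdes.serdeFrom(String.class);                // built-in StringSerde
        Serde<String> byPair = Serdes.serdeFrom(new StringSerializer(),
                                                new StringDeserializer());    // WrapperSerde

        byte[] bytes = byPair.serializer().serialize("topic", "hello");
        String back = byType.deserializer().deserialize("topic", bytes);      // "hello"
    }
}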
+ */ + static public Serde ByteBuffer() { + return new ByteBufferSerde(); + } + + /** + * A serde for nullable {@code Bytes} type. + */ + static public Serde Bytes() { + return new BytesSerde(); + } + + /** + * A serde for nullable {@code UUID} type + */ + static public Serde UUID() { + return new UUIDSerde(); + } + + /** + * A serde for nullable {@code byte[]} type. + */ + static public Serde ByteArray() { + return new ByteArraySerde(); + } + + /** + * A serde for {@code Void} type. + */ + static public Serde Void() { + return new VoidSerde(); + } + + /* + * A serde for {@code List} type + */ + static public , Inner> Serde> ListSerde(Class listClass, Serde innerSerde) { + return new ListSerde<>(listClass, innerSerde); + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/serialization/Serializer.java b/clients/src/main/java/org/apache/kafka/common/serialization/Serializer.java new file mode 100644 index 0000000..144b5ab --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/serialization/Serializer.java @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.serialization; + +import org.apache.kafka.common.header.Headers; + +import java.io.Closeable; +import java.util.Map; + +/** + * An interface for converting objects to bytes. + * + * A class that implements this interface is expected to have a constructor with no parameter. + *
<p>
            + * Implement {@link org.apache.kafka.common.ClusterResourceListener} to receive cluster metadata once it's available. Please see the class documentation for ClusterResourceListener for more information. + * + * @param Type to be serialized from. + */ +public interface Serializer extends Closeable { + + /** + * Configure this class. + * @param configs configs in key/value pairs + * @param isKey whether is for key or value + */ + default void configure(Map configs, boolean isKey) { + // intentionally left blank + } + + /** + * Convert {@code data} into a byte array. + * + * @param topic topic associated with data + * @param data typed data + * @return serialized bytes + */ + byte[] serialize(String topic, T data); + + /** + * Convert {@code data} into a byte array. + * + * @param topic topic associated with data + * @param headers headers associated with the record + * @param data typed data + * @return serialized bytes + */ + default byte[] serialize(String topic, Headers headers, T data) { + return serialize(topic, data); + } + + /** + * Close this serializer. + *
<p>
            + * This method must be idempotent as it may be called multiple times. + */ + @Override + default void close() { + // intentionally left blank + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/serialization/ShortDeserializer.java b/clients/src/main/java/org/apache/kafka/common/serialization/ShortDeserializer.java new file mode 100644 index 0000000..7814a7b --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/serialization/ShortDeserializer.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.serialization; + +import org.apache.kafka.common.errors.SerializationException; + +public class ShortDeserializer implements Deserializer { + + public Short deserialize(String topic, byte[] data) { + if (data == null) + return null; + if (data.length != 2) { + throw new SerializationException("Size of data received by ShortDeserializer is not 2"); + } + + short value = 0; + for (byte b : data) { + value <<= 8; + value |= b & 0xFF; + } + return value; + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/serialization/ShortSerializer.java b/clients/src/main/java/org/apache/kafka/common/serialization/ShortSerializer.java new file mode 100644 index 0000000..e54354b --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/serialization/ShortSerializer.java @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
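A minimal custom implementation of the Serializer interface above, as a sketch; Point is a hypothetical value class introduced only for this example, and only serialize(topic, data) needs to be overridden because configure() and close() have default bodies.

import java.nio.ByteBuffer;
import org.apache.kafka.common.serialization.Serializer;

// Hypothetical two-int value class, used only to illustrate the interface.
class Point {
    final int x;
    final int y;
    Point(int x, int y) { this.x = x; this.y = y; }
}

public class PointSerializer implements Serializer<Point> {
    @Override
    public byte[] serialize(String topic, Point data) {
        if (data == null)
            return null; // the built-in serializers also map null data to a null payload
        return ByteBuffer.allocate(8).putInt(data.x).putInt(data.y).array();
    }
}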
+ */ +package org.apache.kafka.common.serialization; + +public class ShortSerializer implements Serializer { + public byte[] serialize(String topic, Short data) { + if (data == null) + return null; + + return new byte[] { + (byte) (data >>> 8), + data.byteValue() + }; + } +} \ No newline at end of file diff --git a/clients/src/main/java/org/apache/kafka/common/serialization/StringDeserializer.java b/clients/src/main/java/org/apache/kafka/common/serialization/StringDeserializer.java new file mode 100644 index 0000000..3d8b7bb --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/serialization/StringDeserializer.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.serialization; + +import org.apache.kafka.common.errors.SerializationException; + +import java.io.UnsupportedEncodingException; +import java.nio.charset.StandardCharsets; +import java.util.Map; + +/** + * String encoding defaults to UTF8 and can be customized by setting the property key.deserializer.encoding, + * value.deserializer.encoding or deserializer.encoding. The first two take precedence over the last. + */ +public class StringDeserializer implements Deserializer { + private String encoding = StandardCharsets.UTF_8.name(); + + @Override + public void configure(Map configs, boolean isKey) { + String propertyName = isKey ? "key.deserializer.encoding" : "value.deserializer.encoding"; + Object encodingValue = configs.get(propertyName); + if (encodingValue == null) + encodingValue = configs.get("deserializer.encoding"); + if (encodingValue instanceof String) + encoding = (String) encodingValue; + } + + @Override + public String deserialize(String topic, byte[] data) { + try { + if (data == null) + return null; + else + return new String(data, encoding); + } catch (UnsupportedEncodingException e) { + throw new SerializationException("Error when deserializing byte[] to string due to unsupported encoding " + encoding); + } + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/serialization/StringSerializer.java b/clients/src/main/java/org/apache/kafka/common/serialization/StringSerializer.java new file mode 100644 index 0000000..ee01f1a --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/serialization/StringSerializer.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.serialization; + +import org.apache.kafka.common.errors.SerializationException; + +import java.io.UnsupportedEncodingException; +import java.nio.charset.StandardCharsets; +import java.util.Map; + +/** + * String encoding defaults to UTF8 and can be customized by setting the property key.serializer.encoding, + * value.serializer.encoding or serializer.encoding. The first two take precedence over the last. + */ +public class StringSerializer implements Serializer { + private String encoding = StandardCharsets.UTF_8.name(); + + @Override + public void configure(Map configs, boolean isKey) { + String propertyName = isKey ? "key.serializer.encoding" : "value.serializer.encoding"; + Object encodingValue = configs.get(propertyName); + if (encodingValue == null) + encodingValue = configs.get("serializer.encoding"); + if (encodingValue instanceof String) + encoding = (String) encodingValue; + } + + @Override + public byte[] serialize(String topic, String data) { + try { + if (data == null) + return null; + else + return data.getBytes(encoding); + } catch (UnsupportedEncodingException e) { + throw new SerializationException("Error when serializing string to byte[] due to unsupported encoding " + encoding); + } + } +} \ No newline at end of file diff --git a/clients/src/main/java/org/apache/kafka/common/serialization/UUIDDeserializer.java b/clients/src/main/java/org/apache/kafka/common/serialization/UUIDDeserializer.java new file mode 100644 index 0000000..779a9bd --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/serialization/UUIDDeserializer.java @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.serialization; + +import org.apache.kafka.common.errors.SerializationException; + +import java.io.UnsupportedEncodingException; +import java.nio.charset.StandardCharsets; +import java.util.Map; +import java.util.UUID; + +/** + * We are converting the byte array to String before deserializing to UUID. String encoding defaults to UTF8 and can be customized by setting + * the property key.deserializer.encoding, value.deserializer.encoding or deserializer.encoding. The first two take precedence over the last. 
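The encoding override described for StringSerializer and StringDeserializer above can be exercised as in the following sketch; the UTF-16 value is purely illustrative, and UTF-8 remains the default when no property is set.

import java.util.HashMap;
import java.util.Map;
import org.apache.kafka.common.serialization.StringDeserializer;
import org.apache.kafka.common.serialization.StringSerializer;

public class StringEncodingExample {
    public static void main(String[] args) {
        Map<String, Object> configs = new HashMap<>();
        configs.put("value.serializer.encoding", "UTF-16");   // read by StringSerializer.configure
        configs.put("value.deserializer.encoding", "UTF-16"); // read by StringDeserializer.configure

        StringSerializer serializer = new StringSerializer();
        serializer.configure(configs, false);                 // false = configuring the value side, not the key

        StringDeserializer deserializer = new StringDeserializer();
        deserializer.configure(configs, false);

        byte[] bytes = serializer.serialize("topic", "hello");
        String back = deserializer.deserialize("topic", bytes); // "hello"
    }
}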
+ */ +public class UUIDDeserializer implements Deserializer { + private String encoding = StandardCharsets.UTF_8.name(); + + @Override + public void configure(Map configs, boolean isKey) { + String propertyName = isKey ? "key.deserializer.encoding" : "value.deserializer.encoding"; + Object encodingValue = configs.get(propertyName); + if (encodingValue == null) + encodingValue = configs.get("deserializer.encoding"); + if (encodingValue instanceof String) + encoding = (String) encodingValue; + } + + @Override + public UUID deserialize(String topic, byte[] data) { + try { + if (data == null) + return null; + else + return UUID.fromString(new String(data, encoding)); + } catch (UnsupportedEncodingException e) { + throw new SerializationException("Error when deserializing byte[] to UUID due to unsupported encoding " + encoding, e); + } catch (IllegalArgumentException e) { + throw new SerializationException("Error parsing data into UUID", e); + } + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/serialization/UUIDSerializer.java b/clients/src/main/java/org/apache/kafka/common/serialization/UUIDSerializer.java new file mode 100644 index 0000000..1477546 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/serialization/UUIDSerializer.java @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.serialization; + +import org.apache.kafka.common.errors.SerializationException; + +import java.io.UnsupportedEncodingException; +import java.nio.charset.StandardCharsets; +import java.util.Map; +import java.util.UUID; + +/** + * We are converting UUID to String before serializing. String encoding defaults to UTF8 and can be customized by setting + * the property key.deserializer.encoding, value.deserializer.encoding or deserializer.encoding. The first two take precedence over the last. + */ +public class UUIDSerializer implements Serializer { + private String encoding = StandardCharsets.UTF_8.name(); + + @Override + public void configure(Map configs, boolean isKey) { + String propertyName = isKey ? 
"key.serializer.encoding" : "value.serializer.encoding"; + Object encodingValue = configs.get(propertyName); + if (encodingValue == null) + encodingValue = configs.get("serializer.encoding"); + if (encodingValue instanceof String) + encoding = (String) encodingValue; + } + + @Override + public byte[] serialize(String topic, UUID data) { + try { + if (data == null) + return null; + else + return data.toString().getBytes(encoding); + } catch (UnsupportedEncodingException e) { + throw new SerializationException("Error when serializing UUID to byte[] due to unsupported encoding " + encoding); + } + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/serialization/VoidDeserializer.java b/clients/src/main/java/org/apache/kafka/common/serialization/VoidDeserializer.java new file mode 100644 index 0000000..08ff57a --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/serialization/VoidDeserializer.java @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.serialization; + +public class VoidDeserializer implements Deserializer { + @Override + public Void deserialize(String topic, byte[] data) { + if (data != null) + throw new IllegalArgumentException("Data should be null for a VoidDeserializer."); + + return null; + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/serialization/VoidSerializer.java b/clients/src/main/java/org/apache/kafka/common/serialization/VoidSerializer.java new file mode 100644 index 0000000..f1f2c60 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/serialization/VoidSerializer.java @@ -0,0 +1,24 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.serialization; + +public class VoidSerializer implements Serializer { + @Override + public byte[] serialize(String topic, Void data) { + return null; + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/utils/AbstractIterator.java b/clients/src/main/java/org/apache/kafka/common/utils/AbstractIterator.java new file mode 100644 index 0000000..daf89e6 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/utils/AbstractIterator.java @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.utils; + +import java.util.Iterator; +import java.util.NoSuchElementException; + +/** + * A base class that simplifies implementing an iterator + * @param The type of thing we are iterating over + */ +public abstract class AbstractIterator implements Iterator { + + private enum State { + READY, NOT_READY, DONE, FAILED + } + + private State state = State.NOT_READY; + private T next; + + @Override + public boolean hasNext() { + switch (state) { + case FAILED: + throw new IllegalStateException("Iterator is in failed state"); + case DONE: + return false; + case READY: + return true; + default: + return maybeComputeNext(); + } + } + + @Override + public T next() { + if (!hasNext()) + throw new NoSuchElementException(); + state = State.NOT_READY; + if (next == null) + throw new IllegalStateException("Expected item but none found."); + return next; + } + + @Override + public void remove() { + throw new UnsupportedOperationException("Removal not supported"); + } + + public T peek() { + if (!hasNext()) + throw new NoSuchElementException(); + return next; + } + + protected T allDone() { + state = State.DONE; + return null; + } + + protected abstract T makeNext(); + + private Boolean maybeComputeNext() { + state = State.FAILED; + next = makeNext(); + if (state == State.DONE) { + return false; + } else { + state = State.READY; + return true; + } + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/utils/AppInfoParser.java b/clients/src/main/java/org/apache/kafka/common/utils/AppInfoParser.java new file mode 100644 index 0000000..19f98d1 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/utils/AppInfoParser.java @@ -0,0 +1,153 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
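The AbstractIterator base class above only requires makeNext() from subclasses; a minimal sketch (RangeIterator is a made-up example type):

import org.apache.kafka.common.utils.AbstractIterator;

// Iterates over the integers in [start, end).
public class RangeIterator extends AbstractIterator<Integer> {
    private final int end;
    private int current;

    public RangeIterator(int start, int end) {
        this.current = start;
        this.end = end;
    }

    @Override
    protected Integer makeNext() {
        if (current >= end)
            return allDone(); // flips the state to DONE so hasNext() returns false
        return current++;
    }
}

A new RangeIterator(0, 3) then yields 0, 1, 2 through the usual hasNext()/next() calls, and peek() looks at the next element without consuming it.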
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.utils; + +import java.io.InputStream; +import java.lang.management.ManagementFactory; +import java.util.Properties; + +import javax.management.JMException; +import javax.management.MBeanServer; +import javax.management.ObjectName; + +import org.apache.kafka.common.MetricName; +import org.apache.kafka.common.metrics.Gauge; +import org.apache.kafka.common.metrics.MetricConfig; +import org.apache.kafka.common.metrics.Metrics; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class AppInfoParser { + private static final Logger log = LoggerFactory.getLogger(AppInfoParser.class); + private static final String VERSION; + private static final String COMMIT_ID; + + protected static final String DEFAULT_VALUE = "unknown"; + + static { + Properties props = new Properties(); + try (InputStream resourceStream = AppInfoParser.class.getResourceAsStream("/kafka/kafka-version.properties")) { + props.load(resourceStream); + } catch (Exception e) { + log.warn("Error while loading kafka-version.properties: {}", e.getMessage()); + } + VERSION = props.getProperty("version", DEFAULT_VALUE).trim(); + COMMIT_ID = props.getProperty("commitId", DEFAULT_VALUE).trim(); + } + + public static String getVersion() { + return VERSION; + } + + public static String getCommitId() { + return COMMIT_ID; + } + + public static synchronized void registerAppInfo(String prefix, String id, Metrics metrics, long nowMs) { + try { + ObjectName name = new ObjectName(prefix + ":type=app-info,id=" + Sanitizer.jmxSanitize(id)); + AppInfo mBean = new AppInfo(nowMs); + ManagementFactory.getPlatformMBeanServer().registerMBean(mBean, name); + + registerMetrics(metrics, mBean); // prefix will be added later by JmxReporter + } catch (JMException e) { + log.warn("Error registering AppInfo mbean", e); + } + } + + public static synchronized void unregisterAppInfo(String prefix, String id, Metrics metrics) { + MBeanServer server = ManagementFactory.getPlatformMBeanServer(); + try { + ObjectName name = new ObjectName(prefix + ":type=app-info,id=" + Sanitizer.jmxSanitize(id)); + if (server.isRegistered(name)) + server.unregisterMBean(name); + + unregisterMetrics(metrics); + } catch (JMException e) { + log.warn("Error unregistering AppInfo mbean", e); + } finally { + log.info("App info {} for {} unregistered", prefix, id); + } + } + + private static MetricName metricName(Metrics metrics, String name) { + return metrics.metricName(name, "app-info", "Metric indicating " + name); + } + + private static void registerMetrics(Metrics metrics, AppInfo appInfo) { + if (metrics != null) { + metrics.addMetric(metricName(metrics, "version"), new ImmutableValue<>(appInfo.getVersion())); + metrics.addMetric(metricName(metrics, "commit-id"), new ImmutableValue<>(appInfo.getCommitId())); + metrics.addMetric(metricName(metrics, "start-time-ms"), new ImmutableValue<>(appInfo.getStartTimeMs())); + } + } + + private static void unregisterMetrics(Metrics metrics) { + if (metrics != null) { + metrics.removeMetric(metricName(metrics, "version")); + metrics.removeMetric(metricName(metrics, "commit-id")); + 
metrics.removeMetric(metricName(metrics, "start-time-ms")); + } + } + + public interface AppInfoMBean { + String getVersion(); + String getCommitId(); + Long getStartTimeMs(); + } + + public static class AppInfo implements AppInfoMBean { + + private final Long startTimeMs; + + public AppInfo(long startTimeMs) { + this.startTimeMs = startTimeMs; + log.info("Kafka version: {}", AppInfoParser.getVersion()); + log.info("Kafka commitId: {}", AppInfoParser.getCommitId()); + log.info("Kafka startTimeMs: {}", startTimeMs); + } + + @Override + public String getVersion() { + return AppInfoParser.getVersion(); + } + + @Override + public String getCommitId() { + return AppInfoParser.getCommitId(); + } + + @Override + public Long getStartTimeMs() { + return startTimeMs; + } + + } + + static class ImmutableValue implements Gauge { + private final T value; + + public ImmutableValue(T value) { + this.value = value; + } + + @Override + public T value(MetricConfig config, long now) { + return value; + } + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/utils/BufferSupplier.java b/clients/src/main/java/org/apache/kafka/common/utils/BufferSupplier.java new file mode 100644 index 0000000..1688d10 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/utils/BufferSupplier.java @@ -0,0 +1,128 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.common.utils; + +import java.nio.ByteBuffer; +import java.util.ArrayDeque; +import java.util.Deque; +import java.util.HashMap; +import java.util.Map; + +/** + * Simple non-threadsafe interface for caching byte buffers. This is suitable for simple cases like ensuring that + * a given KafkaConsumer reuses the same decompression buffer when iterating over fetched records. For small record + * batches, allocating a potentially large buffer (64 KB for LZ4) will dominate the cost of decompressing and + * iterating over the records in the batch. + */ +public abstract class BufferSupplier implements AutoCloseable { + + public static final BufferSupplier NO_CACHING = new BufferSupplier() { + @Override + public ByteBuffer get(int capacity) { + return ByteBuffer.allocate(capacity); + } + + @Override + public void release(ByteBuffer buffer) {} + + @Override + public void close() {} + }; + + public static BufferSupplier create() { + return new DefaultSupplier(); + } + + /** + * Supply a buffer with the required capacity. This may return a cached buffer or allocate a new instance. + */ + public abstract ByteBuffer get(int capacity); + + /** + * Return the provided buffer to be reused by a subsequent call to `get`. + */ + public abstract void release(ByteBuffer buffer); + + /** + * Release all resources associated with this supplier. 
+ */ + public abstract void close(); + + private static class DefaultSupplier extends BufferSupplier { + // We currently use a single block size, so optimise for that case + private final Map> bufferMap = new HashMap<>(1); + + @Override + public ByteBuffer get(int size) { + Deque bufferQueue = bufferMap.get(size); + if (bufferQueue == null || bufferQueue.isEmpty()) + return ByteBuffer.allocate(size); + else + return bufferQueue.pollFirst(); + } + + @Override + public void release(ByteBuffer buffer) { + buffer.clear(); + Deque bufferQueue = bufferMap.get(buffer.capacity()); + if (bufferQueue == null) { + // We currently keep a single buffer in flight, so optimise for that case + bufferQueue = new ArrayDeque<>(1); + bufferMap.put(buffer.capacity(), bufferQueue); + } + bufferQueue.addLast(buffer); + } + + @Override + public void close() { + bufferMap.clear(); + } + } + + /** + * Simple buffer supplier for single-threaded usage. It caches a single buffer, which grows + * monotonically as needed to fulfill the allocation request. + */ + public static class GrowableBufferSupplier extends BufferSupplier { + private ByteBuffer cachedBuffer; + + @Override + public ByteBuffer get(int minCapacity) { + if (cachedBuffer != null && cachedBuffer.capacity() >= minCapacity) { + ByteBuffer res = cachedBuffer; + cachedBuffer = null; + return res; + } else { + cachedBuffer = null; + return ByteBuffer.allocate(minCapacity); + } + } + + @Override + public void release(ByteBuffer buffer) { + buffer.clear(); + cachedBuffer = buffer; + } + + @Override + public void close() { + cachedBuffer = null; + } + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/utils/ByteBufferInputStream.java b/clients/src/main/java/org/apache/kafka/common/utils/ByteBufferInputStream.java new file mode 100644 index 0000000..1266d4b --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/utils/ByteBufferInputStream.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
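A usage sketch for the buffer suppliers above, showing the size-keyed caching contract of the default implementation:

import java.nio.ByteBuffer;
import org.apache.kafka.common.utils.BufferSupplier;

public class BufferSupplierExample {
    public static void main(String[] args) {
        BufferSupplier supplier = BufferSupplier.create(); // DefaultSupplier, keyed by capacity

        ByteBuffer first = supplier.get(64 * 1024);        // nothing cached yet: freshly allocated
        supplier.release(first);                           // cleared and queued under capacity 65536

        ByteBuffer second = supplier.get(64 * 1024);       // the same instance is handed back
        supplier.close();                                  // drops all cached buffers
    }
}

GrowableBufferSupplier behaves similarly but keeps a single buffer and only reuses it when the cached capacity already covers the requested minimum.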
+ */ +package org.apache.kafka.common.utils; + +import java.io.InputStream; +import java.nio.ByteBuffer; + +/** + * A byte buffer backed input inputStream + */ +public final class ByteBufferInputStream extends InputStream { + private final ByteBuffer buffer; + + public ByteBufferInputStream(ByteBuffer buffer) { + this.buffer = buffer; + } + + public int read() { + if (!buffer.hasRemaining()) { + return -1; + } + return buffer.get() & 0xFF; + } + + public int read(byte[] bytes, int off, int len) { + if (len == 0) { + return 0; + } + if (!buffer.hasRemaining()) { + return -1; + } + + len = Math.min(len, buffer.remaining()); + buffer.get(bytes, off, len); + return len; + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/utils/ByteBufferOutputStream.java b/clients/src/main/java/org/apache/kafka/common/utils/ByteBufferOutputStream.java new file mode 100644 index 0000000..43e3bba --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/utils/ByteBufferOutputStream.java @@ -0,0 +1,133 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.utils; + +import java.io.OutputStream; +import java.nio.ByteBuffer; + +/** + * A ByteBuffer-backed OutputStream that expands the internal ByteBuffer as required. Given this, the caller should + * always access the underlying ByteBuffer via the {@link #buffer()} method until all writes are completed. + * + * This class is typically used for 2 purposes: + * + * 1. Write to a ByteBuffer when there is a chance that we may need to expand it in order to fit all the desired data + * 2. Write to a ByteBuffer via methods that expect an OutputStream interface + * + * Hard to track bugs can happen when this class is used for the second reason and unexpected buffer expansion happens. + * So, it's best to assume that buffer expansion can always happen. An improvement would be to create a separate class + * that throws an error if buffer expansion is required to avoid the issue altogether. + */ +public class ByteBufferOutputStream extends OutputStream { + + private static final float REALLOCATION_FACTOR = 1.1f; + + private final int initialCapacity; + private final int initialPosition; + private ByteBuffer buffer; + + /** + * Creates an instance of this class that will write to the received `buffer` up to its `limit`. If necessary to + * satisfy `write` or `position` calls, larger buffers will be allocated so the {@link #buffer()} method may return + * a different buffer than the received `buffer` parameter. + * + * Prefer one of the constructors that allocate the internal buffer for clearer semantics. 
+ */ + public ByteBufferOutputStream(ByteBuffer buffer) { + this.buffer = buffer; + this.initialPosition = buffer.position(); + this.initialCapacity = buffer.capacity(); + } + + public ByteBufferOutputStream(int initialCapacity) { + this(initialCapacity, false); + } + + public ByteBufferOutputStream(int initialCapacity, boolean directBuffer) { + this(directBuffer ? ByteBuffer.allocateDirect(initialCapacity) : ByteBuffer.allocate(initialCapacity)); + } + + public void write(int b) { + ensureRemaining(1); + buffer.put((byte) b); + } + + public void write(byte[] bytes, int off, int len) { + ensureRemaining(len); + buffer.put(bytes, off, len); + } + + public void write(ByteBuffer sourceBuffer) { + ensureRemaining(sourceBuffer.remaining()); + buffer.put(sourceBuffer); + } + + public ByteBuffer buffer() { + return buffer; + } + + public int position() { + return buffer.position(); + } + + public int remaining() { + return buffer.remaining(); + } + + public int limit() { + return buffer.limit(); + } + + public void position(int position) { + ensureRemaining(position - buffer.position()); + buffer.position(position); + } + + /** + * The capacity of the first internal ByteBuffer used by this class. This is useful in cases where a pooled + * ByteBuffer was passed via the constructor and it needs to be returned to the pool. + */ + public int initialCapacity() { + return initialCapacity; + } + + /** + * Ensure there is enough space to write some number of bytes, expanding the underlying buffer if necessary. + * This can be used to avoid incremental expansions through calls to {@link #write(int)} when you know how + * many total bytes are needed. + * + * @param remainingBytesRequired The number of bytes required + */ + public void ensureRemaining(int remainingBytesRequired) { + if (remainingBytesRequired > buffer.remaining()) + expandBuffer(remainingBytesRequired); + } + + private void expandBuffer(int remainingRequired) { + int expandSize = Math.max((int) (buffer.limit() * REALLOCATION_FACTOR), buffer.position() + remainingRequired); + ByteBuffer temp = ByteBuffer.allocate(expandSize); + int limit = limit(); + buffer.flip(); + temp.put(buffer); + buffer.limit(limit); + // reset the old buffer's position so that the partial data in the new buffer cannot be mistakenly consumed + // we should ideally only do this for the original buffer, but the additional complexity doesn't seem worth it + buffer.position(initialPosition); + buffer = temp; + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/utils/ByteBufferUnmapper.java b/clients/src/main/java/org/apache/kafka/common/utils/ByteBufferUnmapper.java new file mode 100644 index 0000000..4777f7b --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/utils/ByteBufferUnmapper.java @@ -0,0 +1,138 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
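A small sketch of the expansion behaviour documented for ByteBufferOutputStream above; the essential point is to re-read buffer() after writing, because an expansion swaps in a different ByteBuffer.

import java.nio.ByteBuffer;
import org.apache.kafka.common.utils.ByteBufferOutputStream;

public class ByteBufferOutputStreamExample {
    public static void main(String[] args) {
        ByteBufferOutputStream out = new ByteBufferOutputStream(16); // 16-byte heap buffer

        out.write(new byte[64], 0, 64);   // ensureRemaining(64) triggers expandBuffer()

        ByteBuffer result = out.buffer(); // no longer the original 16-byte buffer
        result.flip();                    // result.remaining() == 64, ready to read
    }
}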
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.common.utils; + +import java.io.IOException; +import java.lang.invoke.MethodHandle; +import java.lang.invoke.MethodHandles; +import java.lang.reflect.Field; +import java.lang.reflect.Method; +import java.nio.ByteBuffer; + +import static java.lang.invoke.MethodHandles.constant; +import static java.lang.invoke.MethodHandles.dropArguments; +import static java.lang.invoke.MethodHandles.filterReturnValue; +import static java.lang.invoke.MethodHandles.guardWithTest; +import static java.lang.invoke.MethodHandles.lookup; +import static java.lang.invoke.MethodType.methodType; + +/** + * Provides a mechanism to unmap mapped and direct byte buffers. + * + * The implementation was inspired by the one in Lucene's MMapDirectory. + */ +public final class ByteBufferUnmapper { + + // null if unmap is not supported + private static final MethodHandle UNMAP; + + // null if unmap is supported + private static final RuntimeException UNMAP_NOT_SUPPORTED_EXCEPTION; + + static { + Object unmap = null; + RuntimeException exception = null; + try { + unmap = lookupUnmapMethodHandle(); + } catch (RuntimeException e) { + exception = e; + } + if (unmap != null) { + UNMAP = (MethodHandle) unmap; + UNMAP_NOT_SUPPORTED_EXCEPTION = null; + } else { + UNMAP = null; + UNMAP_NOT_SUPPORTED_EXCEPTION = exception; + } + } + + private ByteBufferUnmapper() {} + + /** + * Unmap the provided mapped or direct byte buffer. + * + * This buffer cannot be referenced after this call, so it's highly recommended that any fields referencing it + * should be set to null. + * + * @throws IllegalArgumentException if buffer is not mapped or direct. + */ + public static void unmap(String resourceDescription, ByteBuffer buffer) throws IOException { + if (!buffer.isDirect()) + throw new IllegalArgumentException("Unmapping only works with direct buffers"); + if (UNMAP == null) + throw UNMAP_NOT_SUPPORTED_EXCEPTION; + + try { + UNMAP.invokeExact(buffer); + } catch (Throwable throwable) { + throw new IOException("Unable to unmap the mapped buffer: " + resourceDescription, throwable); + } + } + + private static MethodHandle lookupUnmapMethodHandle() { + final MethodHandles.Lookup lookup = lookup(); + try { + if (Java.IS_JAVA9_COMPATIBLE) + return unmapJava9(lookup); + else + return unmapJava7Or8(lookup); + } catch (ReflectiveOperationException | RuntimeException e1) { + throw new UnsupportedOperationException("Unmapping is not supported on this platform, because internal " + + "Java APIs are not compatible with this Kafka version", e1); + } + } + + private static MethodHandle unmapJava7Or8(MethodHandles.Lookup lookup) throws ReflectiveOperationException { + /* "Compile" a MethodHandle that is roughly equivalent to the following lambda: + * + * (ByteBuffer buffer) -> { + * sun.misc.Cleaner cleaner = ((java.nio.DirectByteBuffer) byteBuffer).cleaner(); + * if (nonNull(cleaner)) + * cleaner.clean(); + * else + * noop(cleaner); // the noop is needed because MethodHandles#guardWithTest always needs both if and else + * } + */ + Class directBufferClass = Class.forName("java.nio.DirectByteBuffer"); + Method m = directBufferClass.getMethod("cleaner"); + m.setAccessible(true); + MethodHandle directBufferCleanerMethod = lookup.unreflect(m); + Class cleanerClass = directBufferCleanerMethod.type().returnType(); + MethodHandle cleanMethod = lookup.findVirtual(cleanerClass, "clean", methodType(void.class)); + MethodHandle 
nonNullTest = lookup.findStatic(ByteBufferUnmapper.class, "nonNull", + methodType(boolean.class, Object.class)).asType(methodType(boolean.class, cleanerClass)); + MethodHandle noop = dropArguments(constant(Void.class, null).asType(methodType(void.class)), 0, cleanerClass); + MethodHandle unmapper = filterReturnValue(directBufferCleanerMethod, guardWithTest(nonNullTest, cleanMethod, noop)) + .asType(methodType(void.class, ByteBuffer.class)); + return unmapper; + } + + private static MethodHandle unmapJava9(MethodHandles.Lookup lookup) throws ReflectiveOperationException { + Class unsafeClass = Class.forName("sun.misc.Unsafe"); + MethodHandle unmapper = lookup.findVirtual(unsafeClass, "invokeCleaner", + methodType(void.class, ByteBuffer.class)); + Field f = unsafeClass.getDeclaredField("theUnsafe"); + f.setAccessible(true); + Object theUnsafe = f.get(null); + return unmapper.bindTo(theUnsafe); + } + + private static boolean nonNull(Object o) { + return o != null; + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/utils/ByteUtils.java b/clients/src/main/java/org/apache/kafka/common/utils/ByteUtils.java new file mode 100644 index 0000000..1586872 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/utils/ByteUtils.java @@ -0,0 +1,436 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.utils; + +import java.io.DataInput; +import java.io.DataOutput; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.nio.ByteBuffer; + +/** + * This classes exposes low-level methods for reading/writing from byte streams or buffers. + */ +public final class ByteUtils { + + public static final ByteBuffer EMPTY_BUF = ByteBuffer.wrap(new byte[0]); + + private ByteUtils() {} + + /** + * Read an unsigned integer from the current position in the buffer, incrementing the position by 4 bytes + * + * @param buffer The buffer to read from + * @return The integer read, as a long to avoid signedness + */ + public static long readUnsignedInt(ByteBuffer buffer) { + return buffer.getInt() & 0xffffffffL; + } + + /** + * Read an unsigned integer from the given position without modifying the buffers position + * + * @param buffer the buffer to read from + * @param index the index from which to read the integer + * @return The integer read, as a long to avoid signedness + */ + public static long readUnsignedInt(ByteBuffer buffer, int index) { + return buffer.getInt(index) & 0xffffffffL; + } + + /** + * Read an unsigned integer stored in little-endian format from the {@link InputStream}. 
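A tiny worked example for readUnsignedInt above and the matching writeUnsignedInt overloads that follow it (a sketch): 0xFFFFFFFF does not fit in a signed int, so the value travels as a long on the Java side while still occupying four bytes on the wire.

import java.nio.ByteBuffer;
import org.apache.kafka.common.utils.ByteUtils;

public class UnsignedIntExample {
    public static void main(String[] args) {
        ByteBuffer buf = ByteBuffer.allocate(4);
        ByteUtils.writeUnsignedInt(buf, 0xFFFFFFFFL); // stored as the int bit pattern -1
        buf.flip();
        long value = ByteUtils.readUnsignedInt(buf);  // 4294967295L rather than -1
    }
}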
+ * + * @param in The stream to read from + * @return The integer read (MUST BE TREATED WITH SPECIAL CARE TO AVOID SIGNEDNESS) + */ + public static int readUnsignedIntLE(InputStream in) throws IOException { + return in.read() + | (in.read() << 8) + | (in.read() << 16) + | (in.read() << 24); + } + + /** + * Read an unsigned integer stored in little-endian format from a byte array + * at a given offset. + * + * @param buffer The byte array to read from + * @param offset The position in buffer to read from + * @return The integer read (MUST BE TREATED WITH SPECIAL CARE TO AVOID SIGNEDNESS) + */ + public static int readUnsignedIntLE(byte[] buffer, int offset) { + return (buffer[offset] << 0 & 0xff) + | ((buffer[offset + 1] & 0xff) << 8) + | ((buffer[offset + 2] & 0xff) << 16) + | ((buffer[offset + 3] & 0xff) << 24); + } + + /** + * Write the given long value as a 4 byte unsigned integer. Overflow is ignored. + * + * @param buffer The buffer to write to + * @param index The position in the buffer at which to begin writing + * @param value The value to write + */ + public static void writeUnsignedInt(ByteBuffer buffer, int index, long value) { + buffer.putInt(index, (int) (value & 0xffffffffL)); + } + + /** + * Write the given long value as a 4 byte unsigned integer. Overflow is ignored. + * + * @param buffer The buffer to write to + * @param value The value to write + */ + public static void writeUnsignedInt(ByteBuffer buffer, long value) { + buffer.putInt((int) (value & 0xffffffffL)); + } + + /** + * Write an unsigned integer in little-endian format to the {@link OutputStream}. + * + * @param out The stream to write to + * @param value The value to write + */ + public static void writeUnsignedIntLE(OutputStream out, int value) throws IOException { + out.write(value); + out.write(value >>> 8); + out.write(value >>> 16); + out.write(value >>> 24); + } + + /** + * Write an unsigned integer in little-endian format to a byte array + * at a given offset. + * + * @param buffer The byte array to write to + * @param offset The position in buffer to write to + * @param value The value to write + */ + public static void writeUnsignedIntLE(byte[] buffer, int offset, int value) { + buffer[offset] = (byte) value; + buffer[offset + 1] = (byte) (value >>> 8); + buffer[offset + 2] = (byte) (value >>> 16); + buffer[offset + 3] = (byte) (value >>> 24); + } + + /** + * Read an integer stored in variable-length format using unsigned decoding from + * Google Protocol Buffers. + * + * @param buffer The buffer to read from + * @return The integer read + * + * @throws IllegalArgumentException if variable-length value does not terminate after 5 bytes have been read + */ + public static int readUnsignedVarint(ByteBuffer buffer) { + int value = 0; + int i = 0; + int b; + while (((b = buffer.get()) & 0x80) != 0) { + value |= (b & 0x7f) << i; + i += 7; + if (i > 28) + throw illegalVarintException(value); + } + value |= b << i; + return value; + } + + /** + * Read an integer stored in variable-length format using unsigned decoding from + * Google Protocol Buffers. 
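+ *
+ * <p>A round-trip sketch (assumes the {@code java.io} stream classes; checked
+ * {@code IOException} handling is omitted for brevity):
+ * <pre>{@code
+ * ByteArrayOutputStream bos = new ByteArrayOutputStream();
+ * ByteUtils.writeUnsignedVarint(300, new DataOutputStream(bos));    // encodes as 0xAC 0x02
+ * DataInput in = new DataInputStream(new ByteArrayInputStream(bos.toByteArray()));
+ * int decoded = ByteUtils.readUnsignedVarint(in);                   // decoded == 300
+ * }</pre>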
+ * + * @param in The input to read from + * @return The integer read + * + * @throws IllegalArgumentException if variable-length value does not terminate after 5 bytes have been read + * @throws IOException if {@link DataInput} throws {@link IOException} + */ + public static int readUnsignedVarint(DataInput in) throws IOException { + int value = 0; + int i = 0; + int b; + while (((b = in.readByte()) & 0x80) != 0) { + value |= (b & 0x7f) << i; + i += 7; + if (i > 28) + throw illegalVarintException(value); + } + value |= b << i; + return value; + } + + /** + * Read an integer stored in variable-length format using zig-zag decoding from + * Google Protocol Buffers. + * + * @param buffer The buffer to read from + * @return The integer read + * + * @throws IllegalArgumentException if variable-length value does not terminate after 5 bytes have been read + */ + public static int readVarint(ByteBuffer buffer) { + int value = readUnsignedVarint(buffer); + return (value >>> 1) ^ -(value & 1); + } + + /** + * Read an integer stored in variable-length format using zig-zag decoding from + * Google Protocol Buffers. + * + * @param in The input to read from + * @return The integer read + * + * @throws IllegalArgumentException if variable-length value does not terminate after 5 bytes have been read + * @throws IOException if {@link DataInput} throws {@link IOException} + */ + public static int readVarint(DataInput in) throws IOException { + int value = readUnsignedVarint(in); + return (value >>> 1) ^ -(value & 1); + } + + /** + * Read a long stored in variable-length format using zig-zag decoding from + * Google Protocol Buffers. + * + * @param in The input to read from + * @return The long value read + * + * @throws IllegalArgumentException if variable-length value does not terminate after 10 bytes have been read + * @throws IOException if {@link DataInput} throws {@link IOException} + */ + public static long readVarlong(DataInput in) throws IOException { + long value = 0L; + int i = 0; + long b; + while (((b = in.readByte()) & 0x80) != 0) { + value |= (b & 0x7f) << i; + i += 7; + if (i > 63) + throw illegalVarlongException(value); + } + value |= b << i; + return (value >>> 1) ^ -(value & 1); + } + + /** + * Read a long stored in variable-length format using zig-zag decoding from + * Google Protocol Buffers. + * + * @param buffer The buffer to read from + * @return The long value read + * + * @throws IllegalArgumentException if variable-length value does not terminate after 10 bytes have been read + */ + public static long readVarlong(ByteBuffer buffer) { + long value = 0L; + int i = 0; + long b; + while (((b = buffer.get()) & 0x80) != 0) { + value |= (b & 0x7f) << i; + i += 7; + if (i > 63) + throw illegalVarlongException(value); + } + value |= b << i; + return (value >>> 1) ^ -(value & 1); + } + + /** + * Read a double-precision 64-bit format IEEE 754 value. + * + * @param in The input to read from + * @return The double value read + */ + public static double readDouble(DataInput in) throws IOException { + return in.readDouble(); + } + + /** + * Read a double-precision 64-bit format IEEE 754 value. + * + * @param buffer The buffer to read from + * @return The long value read + */ + public static double readDouble(ByteBuffer buffer) { + return buffer.getDouble(); + } + + /** + * Write the given integer following the variable-length unsigned encoding from + * Google Protocol Buffers + * into the buffer. 
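+ *
+ * <p>A minimal sketch of writing a value and reading it back through a {@link ByteBuffer}:
+ * <pre>{@code
+ * ByteBuffer buf = ByteBuffer.allocate(5);              // 5 bytes covers any unsigned varint
+ * ByteUtils.writeUnsignedVarint(300, buf);              // writes 0xAC 0x02
+ * buf.flip();
+ * int roundTripped = ByteUtils.readUnsignedVarint(buf); // roundTripped == 300
+ * }</pre>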
+ * + * @param value The value to write + * @param buffer The output to write to + */ + public static void writeUnsignedVarint(int value, ByteBuffer buffer) { + while ((value & 0xffffff80) != 0L) { + byte b = (byte) ((value & 0x7f) | 0x80); + buffer.put(b); + value >>>= 7; + } + buffer.put((byte) value); + } + + /** + * Write the given integer following the variable-length unsigned encoding from + * Google Protocol Buffers + * into the buffer. + * + * @param value The value to write + * @param out The output to write to + */ + public static void writeUnsignedVarint(int value, DataOutput out) throws IOException { + while ((value & 0xffffff80) != 0L) { + byte b = (byte) ((value & 0x7f) | 0x80); + out.writeByte(b); + value >>>= 7; + } + out.writeByte((byte) value); + } + + /** + * Write the given integer following the variable-length zig-zag encoding from + * Google Protocol Buffers + * into the output. + * + * @param value The value to write + * @param out The output to write to + */ + public static void writeVarint(int value, DataOutput out) throws IOException { + writeUnsignedVarint((value << 1) ^ (value >> 31), out); + } + + /** + * Write the given integer following the variable-length zig-zag encoding from + * Google Protocol Buffers + * into the buffer. + * + * @param value The value to write + * @param buffer The output to write to + */ + public static void writeVarint(int value, ByteBuffer buffer) { + writeUnsignedVarint((value << 1) ^ (value >> 31), buffer); + } + + /** + * Write the given integer following the variable-length zig-zag encoding from + * Google Protocol Buffers + * into the output. + * + * @param value The value to write + * @param out The output to write to + */ + public static void writeVarlong(long value, DataOutput out) throws IOException { + long v = (value << 1) ^ (value >> 63); + while ((v & 0xffffffffffffff80L) != 0L) { + out.writeByte(((int) v & 0x7f) | 0x80); + v >>>= 7; + } + out.writeByte((byte) v); + } + + /** + * Write the given integer following the variable-length zig-zag encoding from + * Google Protocol Buffers + * into the buffer. + * + * @param value The value to write + * @param buffer The buffer to write to + */ + public static void writeVarlong(long value, ByteBuffer buffer) { + long v = (value << 1) ^ (value >> 63); + while ((v & 0xffffffffffffff80L) != 0L) { + byte b = (byte) ((v & 0x7f) | 0x80); + buffer.put(b); + v >>>= 7; + } + buffer.put((byte) v); + } + + /** + * Write the given double following the double-precision 64-bit format IEEE 754 value into the output. + * + * @param value The value to write + * @param out The output to write to + */ + public static void writeDouble(double value, DataOutput out) throws IOException { + out.writeDouble(value); + } + + /** + * Write the given double following the double-precision 64-bit format IEEE 754 value into the buffer. + * + * @param value The value to write + * @param buffer The buffer to write to + */ + public static void writeDouble(double value, ByteBuffer buffer) { + buffer.putDouble(value); + } + + /** + * Number of bytes needed to encode an integer in unsigned variable-length format. + * + * @param value The signed value + */ + public static int sizeOfUnsignedVarint(int value) { + int bytes = 1; + while ((value & 0xffffff80) != 0L) { + bytes += 1; + value >>>= 7; + } + return bytes; + } + + /** + * Number of bytes needed to encode an integer in variable-length format. 
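+ *
+ * <p>A small sketch of sizing a buffer before writing (zig-zag encoding fits -1 in a single byte):
+ * <pre>{@code
+ * int value = -1;
+ * ByteBuffer buf = ByteBuffer.allocate(ByteUtils.sizeOfVarint(value));  // capacity 1 here
+ * ByteUtils.writeVarint(value, buf);
+ * }</pre>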
+ * + * @param value The signed value + */ + public static int sizeOfVarint(int value) { + return sizeOfUnsignedVarint((value << 1) ^ (value >> 31)); + } + + /** + * Number of bytes needed to encode a long in variable-length format. + * + * @param value The signed value + */ + public static int sizeOfVarlong(long value) { + long v = (value << 1) ^ (value >> 63); + int bytes = 1; + while ((v & 0xffffffffffffff80L) != 0L) { + bytes += 1; + v >>>= 7; + } + return bytes; + } + + private static IllegalArgumentException illegalVarintException(int value) { + throw new IllegalArgumentException("Varint is too long, the most significant bit in the 5th byte is set, " + + "converted value: " + Integer.toHexString(value)); + } + + private static IllegalArgumentException illegalVarlongException(long value) { + throw new IllegalArgumentException("Varlong is too long, most significant bit in the 10th byte is set, " + + "converted value: " + Long.toHexString(value)); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/utils/Bytes.java b/clients/src/main/java/org/apache/kafka/common/utils/Bytes.java new file mode 100644 index 0000000..df75459 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/utils/Bytes.java @@ -0,0 +1,210 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.utils; + +import java.io.Serializable; +import java.util.Arrays; +import java.util.Comparator; + +/** + * Utility class that handles immutable byte arrays. + */ +public class Bytes implements Comparable { + + public static final byte[] EMPTY = new byte[0]; + + private static final char[] HEX_CHARS_UPPER = {'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E', 'F'}; + + private final byte[] bytes; + + // cache the hash code for the string, default to 0 + private int hashCode; + + public static Bytes wrap(byte[] bytes) { + if (bytes == null) + return null; + return new Bytes(bytes); + } + + /** + * Create a Bytes using the byte array. + * + * @param bytes This array becomes the backing storage for the object. + */ + public Bytes(byte[] bytes) { + this.bytes = bytes; + + // initialize hash code to 0 + hashCode = 0; + } + + /** + * Get the data from the Bytes. + * @return The underlying byte array + */ + public byte[] get() { + return this.bytes; + } + + /** + * The hashcode is cached except for the case where it is computed as 0, in which + * case we compute the hashcode on every call. 
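+ *
+ * <p>An illustrative sketch of where the cached hash pays off: using {@code Bytes} as a
+ * hash-map key (assumes {@code java.util.HashMap} and {@code StandardCharsets}):
+ * <pre>{@code
+ * Map<Bytes, String> index = new HashMap<>();
+ * index.put(Bytes.wrap("key".getBytes(StandardCharsets.UTF_8)), "value");
+ * // equal contents hash and compare equal, so this lookup returns "value"
+ * String v = index.get(Bytes.wrap("key".getBytes(StandardCharsets.UTF_8)));
+ * }</pre>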
+ * + * @return the hashcode + */ + @Override + public int hashCode() { + if (hashCode == 0) { + hashCode = Arrays.hashCode(bytes); + } + + return hashCode; + } + + @Override + public boolean equals(Object other) { + if (this == other) + return true; + if (other == null) + return false; + + // we intentionally use the function to compute hashcode here + if (this.hashCode() != other.hashCode()) + return false; + + if (other instanceof Bytes) + return Arrays.equals(this.bytes, ((Bytes) other).get()); + + return false; + } + + @Override + public int compareTo(Bytes that) { + return BYTES_LEXICO_COMPARATOR.compare(this.bytes, that.bytes); + } + + @Override + public String toString() { + return Bytes.toString(bytes, 0, bytes.length); + } + + /** + * Write a printable representation of a byte array. Non-printable + * characters are hex escaped in the format \\x%02X, eg: + * \x00 \x05 etc. + * + * This function is brought from org.apache.hadoop.hbase.util.Bytes + * + * @param b array to write out + * @param off offset to start at + * @param len length to write + * @return string output + */ + private static String toString(final byte[] b, int off, int len) { + StringBuilder result = new StringBuilder(); + + if (b == null) + return result.toString(); + + // just in case we are passed a 'len' that is > buffer length... + if (off >= b.length) + return result.toString(); + + if (off + len > b.length) + len = b.length - off; + + for (int i = off; i < off + len; ++i) { + int ch = b[i] & 0xFF; + if (ch >= ' ' && ch <= '~' && ch != '\\') { + result.append((char) ch); + } else { + result.append("\\x"); + result.append(HEX_CHARS_UPPER[ch / 0x10]); + result.append(HEX_CHARS_UPPER[ch % 0x10]); + } + } + return result.toString(); + } + + /** + * Increment the underlying byte array by adding 1. Throws an IndexOutOfBoundsException if incrementing would cause + * the underlying input byte array to overflow. + * + * @param input - The byte array to increment + * @return A new copy of the incremented byte array. + */ + public static Bytes increment(Bytes input) throws IndexOutOfBoundsException { + byte[] inputArr = input.get(); + byte[] ret = new byte[inputArr.length]; + int carry = 1; + for (int i = inputArr.length - 1; i >= 0; i--) { + if (inputArr[i] == (byte) 0xFF && carry == 1) { + ret[i] = (byte) 0x00; + } else { + ret[i] = (byte) (inputArr[i] + carry); + carry = 0; + } + } + if (carry == 0) { + return wrap(ret); + } else { + throw new IndexOutOfBoundsException(); + } + } + + /** + * A byte array comparator based on lexicograpic ordering. 
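+ *
+ * <p>A minimal comparison sketch (bytes are compared as unsigned values):
+ * <pre>{@code
+ * byte[] a = new byte[] {0x01, 0x02};
+ * byte[] b = new byte[] {0x01, 0x03};
+ * int cmp = Bytes.BYTES_LEXICO_COMPARATOR.compare(a, b);   // negative: a sorts before b
+ * }</pre>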
+ */ + public final static ByteArrayComparator BYTES_LEXICO_COMPARATOR = new LexicographicByteArrayComparator(); + + public interface ByteArrayComparator extends Comparator, Serializable { + + int compare(final byte[] buffer1, int offset1, int length1, + final byte[] buffer2, int offset2, int length2); + } + + private static class LexicographicByteArrayComparator implements ByteArrayComparator { + + @Override + public int compare(byte[] buffer1, byte[] buffer2) { + return compare(buffer1, 0, buffer1.length, buffer2, 0, buffer2.length); + } + + public int compare(final byte[] buffer1, int offset1, int length1, + final byte[] buffer2, int offset2, int length2) { + + // short circuit equal case + if (buffer1 == buffer2 && + offset1 == offset2 && + length1 == length2) { + return 0; + } + + // similar to Arrays.compare() but considers offset and length + int end1 = offset1 + length1; + int end2 = offset2 + length2; + for (int i = offset1, j = offset2; i < end1 && j < end2; i++, j++) { + int a = buffer1[i] & 0xff; + int b = buffer2[j] & 0xff; + if (a != b) { + return a - b; + } + } + return length1 - length2; + } + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/utils/Checksums.java b/clients/src/main/java/org/apache/kafka/common/utils/Checksums.java new file mode 100644 index 0000000..679b592 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/utils/Checksums.java @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.utils; + +import java.nio.ByteBuffer; +import java.util.zip.Checksum; + +/** + * Utility methods for `Checksum` instances. + * + * Implementation note: we can add methods to our implementations of CRC32 and CRC32C, but we cannot do the same for + * the Java implementations (we prefer the Java 9 implementation of CRC32C if available). A utility class is the + * simplest way to add methods that are useful for all Checksum implementations. + * + * NOTE: This class is intended for INTERNAL usage only within Kafka. 
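+ *
+ * <p>A minimal sketch of checksumming the readable bytes of a buffer (any
+ * {@code java.util.zip.Checksum} implementation works; {@code Crc32} is used here):
+ * <pre>{@code
+ * Checksum crc = new Crc32();
+ * ByteBuffer payload = ByteBuffer.wrap(new byte[] {1, 2, 3, 4});
+ * Checksums.update(crc, payload, payload.remaining());   // does not advance the buffer position
+ * long checksum = crc.getValue();
+ * }</pre>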
+ */ +public final class Checksums { + + private Checksums() { + } + + public static void update(Checksum checksum, ByteBuffer buffer, int length) { + update(checksum, buffer, 0, length); + } + + public static void update(Checksum checksum, ByteBuffer buffer, int offset, int length) { + if (buffer.hasArray()) { + checksum.update(buffer.array(), buffer.position() + buffer.arrayOffset() + offset, length); + } else { + int start = buffer.position() + offset; + for (int i = start; i < start + length; i++) + checksum.update(buffer.get(i)); + } + } + + public static void updateInt(Checksum checksum, int input) { + checksum.update((byte) (input >> 24)); + checksum.update((byte) (input >> 16)); + checksum.update((byte) (input >> 8)); + checksum.update((byte) input /* >> 0 */); + } + + public static void updateLong(Checksum checksum, long input) { + checksum.update((byte) (input >> 56)); + checksum.update((byte) (input >> 48)); + checksum.update((byte) (input >> 40)); + checksum.update((byte) (input >> 32)); + checksum.update((byte) (input >> 24)); + checksum.update((byte) (input >> 16)); + checksum.update((byte) (input >> 8)); + checksum.update((byte) input /* >> 0 */); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/utils/CircularIterator.java b/clients/src/main/java/org/apache/kafka/common/utils/CircularIterator.java new file mode 100644 index 0000000..925f4ad --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/utils/CircularIterator.java @@ -0,0 +1,106 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.common.utils; + +import java.util.Collection; +import java.util.ConcurrentModificationException; +import java.util.Iterator; +import java.util.Objects; + +/** + * An iterator that cycles through the {@code Iterator} of a {@code Collection} + * indefinitely. Useful for tasks such as round-robin load balancing. This class + * does not provide thread-safe access. This {@code Iterator} supports + * {@code null} elements in the underlying {@code Collection}. This + * {@code Iterator} does not support any modification to the underlying + * {@code Collection} after it has been wrapped by this class. Changing the + * underlying {@code Collection} may cause a + * {@link ConcurrentModificationException} or some other undefined behavior. + */ +public class CircularIterator implements Iterator { + + private final Iterable iterable; + private Iterator iterator; + private T nextValue; + + /** + * Create a new instance of a CircularIterator. The ordering of this + * Iterator will be dictated by the Iterator returned by Collection itself. 
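+ *
+ * <p>A minimal sketch (assumes {@code java.util.Arrays}):
+ * <pre>{@code
+ * CircularIterator<String> it = new CircularIterator<>(Arrays.asList("a", "b"));
+ * it.next();   // "a"
+ * it.next();   // "b"
+ * it.next();   // "a" again: the iteration wraps around indefinitely
+ * }</pre>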
+ * + * @param col The collection to iterate indefinitely + * + * @throws NullPointerException if col is {@code null} + * @throws IllegalArgumentException if col is empty. + */ + public CircularIterator(final Collection col) { + this.iterable = Objects.requireNonNull(col); + this.iterator = col.iterator(); + if (col.isEmpty()) { + throw new IllegalArgumentException("CircularIterator can only be used on non-empty lists"); + } + this.nextValue = advance(); + } + + /** + * Returns true since the iteration will forever cycle through the provided + * {@code Collection}. + * + * @return Always true + */ + @Override + public boolean hasNext() { + return true; + } + + @Override + public T next() { + final T next = nextValue; + nextValue = advance(); + return next; + } + + /** + * Return the next value in the {@code Iterator}, restarting the + * {@code Iterator} if necessary. + * + * @return The next value in the iterator + */ + private T advance() { + if (!iterator.hasNext()) { + iterator = iterable.iterator(); + } + return iterator.next(); + } + + /** + * Peek at the next value in the Iterator. Calling this method multiple + * times will return the same element without advancing this Iterator. The + * value returned by this method will be the next item returned by + * {@code next()}. + * + * @return The next value in this {@code Iterator} + */ + public T peek() { + return nextValue; + } + + @Override + public void remove() { + throw new UnsupportedOperationException(); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/utils/CloseableIterator.java b/clients/src/main/java/org/apache/kafka/common/utils/CloseableIterator.java new file mode 100644 index 0000000..50b0636 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/utils/CloseableIterator.java @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.utils; + +import java.io.Closeable; +import java.util.Iterator; + +/** + * Iterators that need to be closed in order to release resources should implement this interface. + * + * Warning: before implementing this interface, consider if there are better options. The chance of misuse is + * a bit high since people are used to iterating without closing. 
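+ *
+ * <p>A minimal sketch of the intended try-with-resources usage (the {@code wrap} helper
+ * below yields a no-op {@code close()}, but real implementations release resources there):
+ * <pre>{@code
+ * List<String> names = Arrays.asList("a", "b");
+ * try (CloseableIterator<String> it = CloseableIterator.wrap(names.iterator())) {
+ *     while (it.hasNext())
+ *         System.out.println(it.next());
+ * }
+ * }</pre>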
+ */ +public interface CloseableIterator extends Iterator, Closeable { + void close(); + + static CloseableIterator wrap(Iterator inner) { + return new CloseableIterator() { + @Override + public void close() {} + + @Override + public boolean hasNext() { + return inner.hasNext(); + } + + @Override + public R next() { + return inner.next(); + } + + @Override + public void remove() { + inner.remove(); + } + }; + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/utils/CollectionUtils.java b/clients/src/main/java/org/apache/kafka/common/utils/CollectionUtils.java new file mode 100644 index 0000000..3ebbd91 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/utils/CollectionUtils.java @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.utils; + +import org.apache.kafka.common.TopicPartition; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Collection; +import java.util.function.BiConsumer; +import java.util.function.Function; +import java.util.stream.Collectors; + +public final class CollectionUtils { + + private CollectionUtils() {} + + /** + * Given two maps (A, B), returns all the key-value pairs in A whose keys are not contained in B + */ + public static Map subtractMap(Map minuend, Map subtrahend) { + return minuend.entrySet().stream() + .filter(entry -> !subtrahend.containsKey(entry.getKey())) + .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); + } + + /** + * group data by topic + * + * @param data Data to be partitioned + * @param Partition data type + * @return partitioned data + */ + public static Map> groupPartitionDataByTopic(Map data) { + Map> dataByTopic = new HashMap<>(); + for (Map.Entry entry : data.entrySet()) { + String topic = entry.getKey().topic(); + int partition = entry.getKey().partition(); + Map topicData = dataByTopic.computeIfAbsent(topic, t -> new HashMap<>()); + topicData.put(partition, entry.getValue()); + } + return dataByTopic; + } + + /** + * Group a list of partitions by the topic name. 
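+ *
+ * <p>An illustrative sketch (the topic names are made up for the example):
+ * <pre>{@code
+ * List<TopicPartition> partitions = Arrays.asList(
+ *     new TopicPartition("orders", 0),
+ *     new TopicPartition("orders", 1),
+ *     new TopicPartition("payments", 0));
+ * Map<String, List<Integer>> byTopic = CollectionUtils.groupPartitionsByTopic(partitions);
+ * // byTopic is {orders=[0, 1], payments=[0]}
+ * }</pre>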
+ * + * @param partitions The partitions to collect + * @return partitions per topic + */ + public static Map> groupPartitionsByTopic(Collection partitions) { + return groupPartitionsByTopic( + partitions, + topic -> new ArrayList<>(), + List::add + ); + } + + /** + * Group a collection of partitions by topic + * + * @return The map used to group the partitions + */ + public static Map groupPartitionsByTopic( + Collection partitions, + Function buildGroup, + BiConsumer addToGroup + ) { + Map dataByTopic = new HashMap<>(); + for (TopicPartition tp : partitions) { + String topic = tp.topic(); + T topicData = dataByTopic.computeIfAbsent(topic, buildGroup); + addToGroup.accept(topicData, tp.partition()); + } + return dataByTopic; + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/utils/ConfigUtils.java b/clients/src/main/java/org/apache/kafka/common/utils/ConfigUtils.java new file mode 100644 index 0000000..0f839ff --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/utils/ConfigUtils.java @@ -0,0 +1,147 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.common.utils; + +import org.apache.kafka.common.config.ConfigDef; +import org.apache.kafka.common.config.ConfigDef.ConfigKey; +import org.apache.kafka.common.config.ConfigDef.Type; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Set; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +public class ConfigUtils { + + private static final Logger log = LoggerFactory.getLogger(ConfigUtils.class); + + /** + * Translates deprecated configurations into their non-deprecated equivalents + * + * This is a convenience method for {@link ConfigUtils#translateDeprecatedConfigs(Map, Map)} + * until we can use Java 9+ {@code Map.of(..)} and {@code Set.of(...)} + * + * @param configs the input configuration + * @param aliasGroups An array of arrays of synonyms. Each synonym array begins with the non-deprecated synonym + * For example, new String[][] { { a, b }, { c, d, e} } + * would declare b as a deprecated synonym for a, + * and d and e as deprecated synonyms for c. + * The ordering of synonyms determines the order of precedence + * (e.g. 
the first synonym takes precedence over the second one) + * @return a new configuration map with deprecated keys translated to their non-deprecated equivalents + */ + public static Map translateDeprecatedConfigs(Map configs, String[][] aliasGroups) { + return translateDeprecatedConfigs(configs, Stream.of(aliasGroups) + .collect(Collectors.toMap(x -> x[0], x -> Stream.of(x).skip(1).collect(Collectors.toList())))); + } + + /** + * Translates deprecated configurations into their non-deprecated equivalents + * + * @param configs the input configuration + * @param aliasGroups A map of config to synonyms. Each key is the non-deprecated synonym + * For example, Map.of(a , Set.of(b), c, Set.of(d, e)) + * would declare b as a deprecated synonym for a, + * and d and e as deprecated synonyms for c. + * The ordering of synonyms determines the order of precedence + * (e.g. the first synonym takes precedence over the second one) + * @return a new configuration map with deprecated keys translated to their non-deprecated equivalents + */ + public static Map translateDeprecatedConfigs(Map configs, + Map> aliasGroups) { + Set aliasSet = Stream.concat( + aliasGroups.keySet().stream(), + aliasGroups.values().stream().flatMap(Collection::stream)) + .collect(Collectors.toSet()); + + // pass through all configurations without aliases + Map newConfigs = configs.entrySet().stream() + .filter(e -> !aliasSet.contains(e.getKey())) + // filter out null values + .filter(e -> Objects.nonNull(e.getValue())) + .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); + + aliasGroups.forEach((target, aliases) -> { + List deprecated = aliases.stream() + .filter(configs::containsKey) + .collect(Collectors.toList()); + + if (deprecated.isEmpty()) { + // No deprecated key(s) found. + if (configs.containsKey(target)) { + newConfigs.put(target, configs.get(target)); + } + return; + } + + String aliasString = String.join(", ", deprecated); + + if (configs.containsKey(target)) { + // Ignore the deprecated key(s) because the actual key was set. + log.error(target + " was configured, as well as the deprecated alias(es) " + + aliasString + ". Using the value of " + target); + newConfigs.put(target, configs.get(target)); + } else if (deprecated.size() > 1) { + log.error("The configuration keys " + aliasString + " are deprecated and may be " + + "removed in the future. Additionally, this configuration is ambigous because " + + "these configuration keys are all aliases for " + target + ". Please update " + + "your configuration to have only " + target + " set."); + newConfigs.put(target, configs.get(deprecated.get(0))); + } else { + log.warn("Configuration key " + deprecated.get(0) + " is deprecated and may be removed " + + "in the future. 
Please update your configuration to use " + target + " instead."); + newConfigs.put(target, configs.get(deprecated.get(0))); + } + }); + + return newConfigs; + } + + public static String configMapToRedactedString(Map map, ConfigDef configDef) { + StringBuilder bld = new StringBuilder("{"); + List keys = new ArrayList<>(map.keySet()); + Collections.sort(keys); + String prefix = ""; + for (String key : keys) { + bld.append(prefix).append(key).append("="); + ConfigKey configKey = configDef.configKeys().get(key); + if (configKey == null || configKey.type().isSensitive()) { + bld.append("(redacted)"); + } else { + Object value = map.get(key); + if (value == null) { + bld.append("null"); + } else if (configKey.type() == Type.STRING) { + bld.append("\"").append(value).append("\""); + } else { + bld.append(value); + } + } + prefix = ", "; + } + bld.append("}"); + return bld.toString(); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/utils/CopyOnWriteMap.java b/clients/src/main/java/org/apache/kafka/common/utils/CopyOnWriteMap.java new file mode 100644 index 0000000..1a3351f --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/utils/CopyOnWriteMap.java @@ -0,0 +1,146 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.utils; + +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.ConcurrentMap; + +/** + * A simple read-optimized map implementation that synchronizes only writes and does a full copy on each modification + */ +public class CopyOnWriteMap implements ConcurrentMap { + + private volatile Map map; + + public CopyOnWriteMap() { + this.map = Collections.emptyMap(); + } + + public CopyOnWriteMap(Map map) { + this.map = Collections.unmodifiableMap(map); + } + + @Override + public boolean containsKey(Object k) { + return map.containsKey(k); + } + + @Override + public boolean containsValue(Object v) { + return map.containsValue(v); + } + + @Override + public Set> entrySet() { + return map.entrySet(); + } + + @Override + public V get(Object k) { + return map.get(k); + } + + @Override + public boolean isEmpty() { + return map.isEmpty(); + } + + @Override + public Set keySet() { + return map.keySet(); + } + + @Override + public int size() { + return map.size(); + } + + @Override + public Collection values() { + return map.values(); + } + + @Override + public synchronized void clear() { + this.map = Collections.emptyMap(); + } + + @Override + public synchronized V put(K k, V v) { + Map copy = new HashMap(this.map); + V prev = copy.put(k, v); + this.map = Collections.unmodifiableMap(copy); + return prev; + } + + @Override + public synchronized void putAll(Map entries) { + Map copy = new HashMap(this.map); + copy.putAll(entries); + this.map = Collections.unmodifiableMap(copy); + } + + @Override + public synchronized V remove(Object key) { + Map copy = new HashMap(this.map); + V prev = copy.remove(key); + this.map = Collections.unmodifiableMap(copy); + return prev; + } + + @Override + public synchronized V putIfAbsent(K k, V v) { + if (!containsKey(k)) + return put(k, v); + else + return get(k); + } + + @Override + public synchronized boolean remove(Object k, Object v) { + if (containsKey(k) && get(k).equals(v)) { + remove(k); + return true; + } else { + return false; + } + } + + @Override + public synchronized boolean replace(K k, V original, V replacement) { + if (containsKey(k) && get(k).equals(original)) { + put(k, replacement); + return true; + } else { + return false; + } + } + + @Override + public synchronized V replace(K k, V v) { + if (containsKey(k)) { + return put(k, v); + } else { + return null; + } + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/utils/Crc32.java b/clients/src/main/java/org/apache/kafka/common/utils/Crc32.java new file mode 100644 index 0000000..777ea2b --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/utils/Crc32.java @@ -0,0 +1,400 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.utils; + +import java.nio.ByteBuffer; +import java.util.zip.Checksum; + +/** + * This class was taken from Hadoop org.apache.hadoop.util.PureJavaCrc32 + * + * A pure-java implementation of the CRC32 checksum that uses the same polynomial as the built-in native CRC32. + * + * This is to avoid the JNI overhead for certain uses of Checksumming where many small pieces of data are checksummed in + * succession. + * + * The current version is ~10x to 1.8x as fast as Sun's native java.util.zip.CRC32 in Java 1.6 + * + * @see java.util.zip.CRC32 + */ +public class Crc32 implements Checksum { + + /** + * Compute the CRC32 of the byte array + * + * @param bytes The array to compute the checksum for + * @return The CRC32 + */ + public static long crc32(byte[] bytes) { + return crc32(bytes, 0, bytes.length); + } + + /** + * Compute the CRC32 of the segment of the byte array given by the specified size and offset + * + * @param bytes The bytes to checksum + * @param offset the offset at which to begin checksumming + * @param size the number of bytes to checksum + * @return The CRC32 + */ + public static long crc32(byte[] bytes, int offset, int size) { + Crc32 crc = new Crc32(); + crc.update(bytes, offset, size); + return crc.getValue(); + } + + /** + * Compute the CRC32 of a byte buffer from a given offset (relative to the buffer's current position) + * + * @param buffer The buffer with the underlying data + * @param offset The offset relative to the current position + * @param size The number of bytes beginning from the offset to include + * @return The CRC32 + */ + public static long crc32(ByteBuffer buffer, int offset, int size) { + Crc32 crc = new Crc32(); + Checksums.update(crc, buffer, offset, size); + return crc.getValue(); + } + + /** the current CRC value, bit-flipped */ + private int crc; + + /** Create a new PureJavaCrc32 object. 
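+ *
+ * <p>A minimal usage sketch (assumes {@code StandardCharsets}):
+ * <pre>{@code
+ * byte[] data = "kafka".getBytes(StandardCharsets.UTF_8);
+ * long oneShot = Crc32.crc32(data);          // static convenience helper
+ * Crc32 crc = new Crc32();                   // or incrementally via the Checksum API
+ * crc.update(data, 0, data.length);
+ * long incremental = crc.getValue();         // equal to oneShot
+ * }</pre>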
*/ + public Crc32() { + reset(); + } + + @Override + public long getValue() { + return (~crc) & 0xffffffffL; + } + + @Override + public void reset() { + crc = 0xffffffff; + } + + @SuppressWarnings("fallthrough") + @Override + public void update(byte[] b, int off, int len) { + if (off < 0 || len < 0 || off > b.length - len) + throw new ArrayIndexOutOfBoundsException(); + + int localCrc = crc; + + while (len > 7) { + final int c0 = (b[off + 0] ^ localCrc) & 0xff; + final int c1 = (b[off + 1] ^ (localCrc >>>= 8)) & 0xff; + final int c2 = (b[off + 2] ^ (localCrc >>>= 8)) & 0xff; + final int c3 = (b[off + 3] ^ (localCrc >>>= 8)) & 0xff; + localCrc = (T[T8_7_START + c0] ^ T[T8_6_START + c1]) ^ (T[T8_5_START + c2] ^ T[T8_4_START + c3]); + + final int c4 = b[off + 4] & 0xff; + final int c5 = b[off + 5] & 0xff; + final int c6 = b[off + 6] & 0xff; + final int c7 = b[off + 7] & 0xff; + + localCrc ^= (T[T8_3_START + c4] ^ T[T8_2_START + c5]) ^ (T[T8_1_START + c6] ^ T[T8_0_START + c7]); + + off += 8; + len -= 8; + } + + /* loop unroll - duff's device style */ + switch (len) { + case 7: + localCrc = (localCrc >>> 8) ^ T[T8_0_START + ((localCrc ^ b[off++]) & 0xff)]; + case 6: + localCrc = (localCrc >>> 8) ^ T[T8_0_START + ((localCrc ^ b[off++]) & 0xff)]; + case 5: + localCrc = (localCrc >>> 8) ^ T[T8_0_START + ((localCrc ^ b[off++]) & 0xff)]; + case 4: + localCrc = (localCrc >>> 8) ^ T[T8_0_START + ((localCrc ^ b[off++]) & 0xff)]; + case 3: + localCrc = (localCrc >>> 8) ^ T[T8_0_START + ((localCrc ^ b[off++]) & 0xff)]; + case 2: + localCrc = (localCrc >>> 8) ^ T[T8_0_START + ((localCrc ^ b[off++]) & 0xff)]; + case 1: + localCrc = (localCrc >>> 8) ^ T[T8_0_START + ((localCrc ^ b[off++]) & 0xff)]; + default: + /* nothing */ + } + + // Publish crc out to object + crc = localCrc; + } + + @Override + final public void update(int b) { + crc = (crc >>> 8) ^ T[T8_0_START + ((crc ^ b) & 0xff)]; + } + + /* + * CRC-32 lookup tables generated by the polynomial 0xEDB88320. See also TestPureJavaCrc32.Table. 
+ */ + private static final int T8_0_START = 0 * 256; + private static final int T8_1_START = 1 * 256; + private static final int T8_2_START = 2 * 256; + private static final int T8_3_START = 3 * 256; + private static final int T8_4_START = 4 * 256; + private static final int T8_5_START = 5 * 256; + private static final int T8_6_START = 6 * 256; + private static final int T8_7_START = 7 * 256; + + private static final int[] T = new int[] { + /* T8_0 */ + 0x00000000, 0x77073096, 0xEE0E612C, 0x990951BA, 0x076DC419, 0x706AF48F, 0xE963A535, 0x9E6495A3, 0x0EDB8832, + 0x79DCB8A4, 0xE0D5E91E, 0x97D2D988, 0x09B64C2B, 0x7EB17CBD, 0xE7B82D07, 0x90BF1D91, 0x1DB71064, 0x6AB020F2, + 0xF3B97148, 0x84BE41DE, 0x1ADAD47D, 0x6DDDE4EB, 0xF4D4B551, 0x83D385C7, 0x136C9856, 0x646BA8C0, 0xFD62F97A, + 0x8A65C9EC, 0x14015C4F, 0x63066CD9, 0xFA0F3D63, 0x8D080DF5, 0x3B6E20C8, 0x4C69105E, 0xD56041E4, 0xA2677172, + 0x3C03E4D1, 0x4B04D447, 0xD20D85FD, 0xA50AB56B, 0x35B5A8FA, 0x42B2986C, 0xDBBBC9D6, 0xACBCF940, 0x32D86CE3, + 0x45DF5C75, 0xDCD60DCF, 0xABD13D59, 0x26D930AC, 0x51DE003A, 0xC8D75180, 0xBFD06116, 0x21B4F4B5, 0x56B3C423, + 0xCFBA9599, 0xB8BDA50F, 0x2802B89E, 0x5F058808, 0xC60CD9B2, 0xB10BE924, 0x2F6F7C87, 0x58684C11, 0xC1611DAB, + 0xB6662D3D, 0x76DC4190, 0x01DB7106, 0x98D220BC, 0xEFD5102A, 0x71B18589, 0x06B6B51F, 0x9FBFE4A5, 0xE8B8D433, + 0x7807C9A2, 0x0F00F934, 0x9609A88E, 0xE10E9818, 0x7F6A0DBB, 0x086D3D2D, 0x91646C97, 0xE6635C01, 0x6B6B51F4, + 0x1C6C6162, 0x856530D8, 0xF262004E, 0x6C0695ED, 0x1B01A57B, 0x8208F4C1, 0xF50FC457, 0x65B0D9C6, 0x12B7E950, + 0x8BBEB8EA, 0xFCB9887C, 0x62DD1DDF, 0x15DA2D49, 0x8CD37CF3, 0xFBD44C65, 0x4DB26158, 0x3AB551CE, 0xA3BC0074, + 0xD4BB30E2, 0x4ADFA541, 0x3DD895D7, 0xA4D1C46D, 0xD3D6F4FB, 0x4369E96A, 0x346ED9FC, 0xAD678846, 0xDA60B8D0, + 0x44042D73, 0x33031DE5, 0xAA0A4C5F, 0xDD0D7CC9, 0x5005713C, 0x270241AA, 0xBE0B1010, 0xC90C2086, 0x5768B525, + 0x206F85B3, 0xB966D409, 0xCE61E49F, 0x5EDEF90E, 0x29D9C998, 0xB0D09822, 0xC7D7A8B4, 0x59B33D17, 0x2EB40D81, + 0xB7BD5C3B, 0xC0BA6CAD, 0xEDB88320, 0x9ABFB3B6, 0x03B6E20C, 0x74B1D29A, 0xEAD54739, 0x9DD277AF, 0x04DB2615, + 0x73DC1683, 0xE3630B12, 0x94643B84, 0x0D6D6A3E, 0x7A6A5AA8, 0xE40ECF0B, 0x9309FF9D, 0x0A00AE27, 0x7D079EB1, + 0xF00F9344, 0x8708A3D2, 0x1E01F268, 0x6906C2FE, 0xF762575D, 0x806567CB, 0x196C3671, 0x6E6B06E7, 0xFED41B76, + 0x89D32BE0, 0x10DA7A5A, 0x67DD4ACC, 0xF9B9DF6F, 0x8EBEEFF9, 0x17B7BE43, 0x60B08ED5, 0xD6D6A3E8, 0xA1D1937E, + 0x38D8C2C4, 0x4FDFF252, 0xD1BB67F1, 0xA6BC5767, 0x3FB506DD, 0x48B2364B, 0xD80D2BDA, 0xAF0A1B4C, 0x36034AF6, + 0x41047A60, 0xDF60EFC3, 0xA867DF55, 0x316E8EEF, 0x4669BE79, 0xCB61B38C, 0xBC66831A, 0x256FD2A0, 0x5268E236, + 0xCC0C7795, 0xBB0B4703, 0x220216B9, 0x5505262F, 0xC5BA3BBE, 0xB2BD0B28, 0x2BB45A92, 0x5CB36A04, 0xC2D7FFA7, + 0xB5D0CF31, 0x2CD99E8B, 0x5BDEAE1D, 0x9B64C2B0, 0xEC63F226, 0x756AA39C, 0x026D930A, 0x9C0906A9, 0xEB0E363F, + 0x72076785, 0x05005713, 0x95BF4A82, 0xE2B87A14, 0x7BB12BAE, 0x0CB61B38, 0x92D28E9B, 0xE5D5BE0D, 0x7CDCEFB7, + 0x0BDBDF21, 0x86D3D2D4, 0xF1D4E242, 0x68DDB3F8, 0x1FDA836E, 0x81BE16CD, 0xF6B9265B, 0x6FB077E1, 0x18B74777, + 0x88085AE6, 0xFF0F6A70, 0x66063BCA, 0x11010B5C, 0x8F659EFF, 0xF862AE69, 0x616BFFD3, 0x166CCF45, 0xA00AE278, + 0xD70DD2EE, 0x4E048354, 0x3903B3C2, 0xA7672661, 0xD06016F7, 0x4969474D, 0x3E6E77DB, 0xAED16A4A, 0xD9D65ADC, + 0x40DF0B66, 0x37D83BF0, 0xA9BCAE53, 0xDEBB9EC5, 0x47B2CF7F, 0x30B5FFE9, 0xBDBDF21C, 0xCABAC28A, 0x53B39330, + 0x24B4A3A6, 0xBAD03605, 0xCDD70693, 0x54DE5729, 0x23D967BF, 0xB3667A2E, 0xC4614AB8, 0x5D681B02, 0x2A6F2B94, + 0xB40BBE37, 
0xC30C8EA1, 0x5A05DF1B, 0x2D02EF8D, + /* T8_1 */ + 0x00000000, 0x191B3141, 0x32366282, 0x2B2D53C3, 0x646CC504, 0x7D77F445, 0x565AA786, 0x4F4196C7, 0xC8D98A08, + 0xD1C2BB49, 0xFAEFE88A, 0xE3F4D9CB, 0xACB54F0C, 0xB5AE7E4D, 0x9E832D8E, 0x87981CCF, 0x4AC21251, 0x53D92310, + 0x78F470D3, 0x61EF4192, 0x2EAED755, 0x37B5E614, 0x1C98B5D7, 0x05838496, 0x821B9859, 0x9B00A918, 0xB02DFADB, + 0xA936CB9A, 0xE6775D5D, 0xFF6C6C1C, 0xD4413FDF, 0xCD5A0E9E, 0x958424A2, 0x8C9F15E3, 0xA7B24620, 0xBEA97761, + 0xF1E8E1A6, 0xE8F3D0E7, 0xC3DE8324, 0xDAC5B265, 0x5D5DAEAA, 0x44469FEB, 0x6F6BCC28, 0x7670FD69, 0x39316BAE, + 0x202A5AEF, 0x0B07092C, 0x121C386D, 0xDF4636F3, 0xC65D07B2, 0xED705471, 0xF46B6530, 0xBB2AF3F7, 0xA231C2B6, + 0x891C9175, 0x9007A034, 0x179FBCFB, 0x0E848DBA, 0x25A9DE79, 0x3CB2EF38, 0x73F379FF, 0x6AE848BE, 0x41C51B7D, + 0x58DE2A3C, 0xF0794F05, 0xE9627E44, 0xC24F2D87, 0xDB541CC6, 0x94158A01, 0x8D0EBB40, 0xA623E883, 0xBF38D9C2, + 0x38A0C50D, 0x21BBF44C, 0x0A96A78F, 0x138D96CE, 0x5CCC0009, 0x45D73148, 0x6EFA628B, 0x77E153CA, 0xBABB5D54, + 0xA3A06C15, 0x888D3FD6, 0x91960E97, 0xDED79850, 0xC7CCA911, 0xECE1FAD2, 0xF5FACB93, 0x7262D75C, 0x6B79E61D, + 0x4054B5DE, 0x594F849F, 0x160E1258, 0x0F152319, 0x243870DA, 0x3D23419B, 0x65FD6BA7, 0x7CE65AE6, 0x57CB0925, + 0x4ED03864, 0x0191AEA3, 0x188A9FE2, 0x33A7CC21, 0x2ABCFD60, 0xAD24E1AF, 0xB43FD0EE, 0x9F12832D, 0x8609B26C, + 0xC94824AB, 0xD05315EA, 0xFB7E4629, 0xE2657768, 0x2F3F79F6, 0x362448B7, 0x1D091B74, 0x04122A35, 0x4B53BCF2, + 0x52488DB3, 0x7965DE70, 0x607EEF31, 0xE7E6F3FE, 0xFEFDC2BF, 0xD5D0917C, 0xCCCBA03D, 0x838A36FA, 0x9A9107BB, + 0xB1BC5478, 0xA8A76539, 0x3B83984B, 0x2298A90A, 0x09B5FAC9, 0x10AECB88, 0x5FEF5D4F, 0x46F46C0E, 0x6DD93FCD, + 0x74C20E8C, 0xF35A1243, 0xEA412302, 0xC16C70C1, 0xD8774180, 0x9736D747, 0x8E2DE606, 0xA500B5C5, 0xBC1B8484, + 0x71418A1A, 0x685ABB5B, 0x4377E898, 0x5A6CD9D9, 0x152D4F1E, 0x0C367E5F, 0x271B2D9C, 0x3E001CDD, 0xB9980012, + 0xA0833153, 0x8BAE6290, 0x92B553D1, 0xDDF4C516, 0xC4EFF457, 0xEFC2A794, 0xF6D996D5, 0xAE07BCE9, 0xB71C8DA8, + 0x9C31DE6B, 0x852AEF2A, 0xCA6B79ED, 0xD37048AC, 0xF85D1B6F, 0xE1462A2E, 0x66DE36E1, 0x7FC507A0, 0x54E85463, + 0x4DF36522, 0x02B2F3E5, 0x1BA9C2A4, 0x30849167, 0x299FA026, 0xE4C5AEB8, 0xFDDE9FF9, 0xD6F3CC3A, 0xCFE8FD7B, + 0x80A96BBC, 0x99B25AFD, 0xB29F093E, 0xAB84387F, 0x2C1C24B0, 0x350715F1, 0x1E2A4632, 0x07317773, 0x4870E1B4, + 0x516BD0F5, 0x7A468336, 0x635DB277, 0xCBFAD74E, 0xD2E1E60F, 0xF9CCB5CC, 0xE0D7848D, 0xAF96124A, 0xB68D230B, + 0x9DA070C8, 0x84BB4189, 0x03235D46, 0x1A386C07, 0x31153FC4, 0x280E0E85, 0x674F9842, 0x7E54A903, 0x5579FAC0, + 0x4C62CB81, 0x8138C51F, 0x9823F45E, 0xB30EA79D, 0xAA1596DC, 0xE554001B, 0xFC4F315A, 0xD7626299, 0xCE7953D8, + 0x49E14F17, 0x50FA7E56, 0x7BD72D95, 0x62CC1CD4, 0x2D8D8A13, 0x3496BB52, 0x1FBBE891, 0x06A0D9D0, 0x5E7EF3EC, + 0x4765C2AD, 0x6C48916E, 0x7553A02F, 0x3A1236E8, 0x230907A9, 0x0824546A, 0x113F652B, 0x96A779E4, 0x8FBC48A5, + 0xA4911B66, 0xBD8A2A27, 0xF2CBBCE0, 0xEBD08DA1, 0xC0FDDE62, 0xD9E6EF23, 0x14BCE1BD, 0x0DA7D0FC, 0x268A833F, + 0x3F91B27E, 0x70D024B9, 0x69CB15F8, 0x42E6463B, 0x5BFD777A, 0xDC656BB5, 0xC57E5AF4, 0xEE530937, 0xF7483876, + 0xB809AEB1, 0xA1129FF0, 0x8A3FCC33, 0x9324FD72, + /* T8_2 */ + 0x00000000, 0x01C26A37, 0x0384D46E, 0x0246BE59, 0x0709A8DC, 0x06CBC2EB, 0x048D7CB2, 0x054F1685, 0x0E1351B8, + 0x0FD13B8F, 0x0D9785D6, 0x0C55EFE1, 0x091AF964, 0x08D89353, 0x0A9E2D0A, 0x0B5C473D, 0x1C26A370, 0x1DE4C947, + 0x1FA2771E, 0x1E601D29, 0x1B2F0BAC, 0x1AED619B, 0x18ABDFC2, 0x1969B5F5, 0x1235F2C8, 0x13F798FF, 0x11B126A6, + 0x10734C91, 0x153C5A14, 
0x14FE3023, 0x16B88E7A, 0x177AE44D, 0x384D46E0, 0x398F2CD7, 0x3BC9928E, 0x3A0BF8B9, + 0x3F44EE3C, 0x3E86840B, 0x3CC03A52, 0x3D025065, 0x365E1758, 0x379C7D6F, 0x35DAC336, 0x3418A901, 0x3157BF84, + 0x3095D5B3, 0x32D36BEA, 0x331101DD, 0x246BE590, 0x25A98FA7, 0x27EF31FE, 0x262D5BC9, 0x23624D4C, 0x22A0277B, + 0x20E69922, 0x2124F315, 0x2A78B428, 0x2BBADE1F, 0x29FC6046, 0x283E0A71, 0x2D711CF4, 0x2CB376C3, 0x2EF5C89A, + 0x2F37A2AD, 0x709A8DC0, 0x7158E7F7, 0x731E59AE, 0x72DC3399, 0x7793251C, 0x76514F2B, 0x7417F172, 0x75D59B45, + 0x7E89DC78, 0x7F4BB64F, 0x7D0D0816, 0x7CCF6221, 0x798074A4, 0x78421E93, 0x7A04A0CA, 0x7BC6CAFD, 0x6CBC2EB0, + 0x6D7E4487, 0x6F38FADE, 0x6EFA90E9, 0x6BB5866C, 0x6A77EC5B, 0x68315202, 0x69F33835, 0x62AF7F08, 0x636D153F, + 0x612BAB66, 0x60E9C151, 0x65A6D7D4, 0x6464BDE3, 0x662203BA, 0x67E0698D, 0x48D7CB20, 0x4915A117, 0x4B531F4E, + 0x4A917579, 0x4FDE63FC, 0x4E1C09CB, 0x4C5AB792, 0x4D98DDA5, 0x46C49A98, 0x4706F0AF, 0x45404EF6, 0x448224C1, + 0x41CD3244, 0x400F5873, 0x4249E62A, 0x438B8C1D, 0x54F16850, 0x55330267, 0x5775BC3E, 0x56B7D609, 0x53F8C08C, + 0x523AAABB, 0x507C14E2, 0x51BE7ED5, 0x5AE239E8, 0x5B2053DF, 0x5966ED86, 0x58A487B1, 0x5DEB9134, 0x5C29FB03, + 0x5E6F455A, 0x5FAD2F6D, 0xE1351B80, 0xE0F771B7, 0xE2B1CFEE, 0xE373A5D9, 0xE63CB35C, 0xE7FED96B, 0xE5B86732, + 0xE47A0D05, 0xEF264A38, 0xEEE4200F, 0xECA29E56, 0xED60F461, 0xE82FE2E4, 0xE9ED88D3, 0xEBAB368A, 0xEA695CBD, + 0xFD13B8F0, 0xFCD1D2C7, 0xFE976C9E, 0xFF5506A9, 0xFA1A102C, 0xFBD87A1B, 0xF99EC442, 0xF85CAE75, 0xF300E948, + 0xF2C2837F, 0xF0843D26, 0xF1465711, 0xF4094194, 0xF5CB2BA3, 0xF78D95FA, 0xF64FFFCD, 0xD9785D60, 0xD8BA3757, + 0xDAFC890E, 0xDB3EE339, 0xDE71F5BC, 0xDFB39F8B, 0xDDF521D2, 0xDC374BE5, 0xD76B0CD8, 0xD6A966EF, 0xD4EFD8B6, + 0xD52DB281, 0xD062A404, 0xD1A0CE33, 0xD3E6706A, 0xD2241A5D, 0xC55EFE10, 0xC49C9427, 0xC6DA2A7E, 0xC7184049, + 0xC25756CC, 0xC3953CFB, 0xC1D382A2, 0xC011E895, 0xCB4DAFA8, 0xCA8FC59F, 0xC8C97BC6, 0xC90B11F1, 0xCC440774, + 0xCD866D43, 0xCFC0D31A, 0xCE02B92D, 0x91AF9640, 0x906DFC77, 0x922B422E, 0x93E92819, 0x96A63E9C, 0x976454AB, + 0x9522EAF2, 0x94E080C5, 0x9FBCC7F8, 0x9E7EADCF, 0x9C381396, 0x9DFA79A1, 0x98B56F24, 0x99770513, 0x9B31BB4A, + 0x9AF3D17D, 0x8D893530, 0x8C4B5F07, 0x8E0DE15E, 0x8FCF8B69, 0x8A809DEC, 0x8B42F7DB, 0x89044982, 0x88C623B5, + 0x839A6488, 0x82580EBF, 0x801EB0E6, 0x81DCDAD1, 0x8493CC54, 0x8551A663, 0x8717183A, 0x86D5720D, 0xA9E2D0A0, + 0xA820BA97, 0xAA6604CE, 0xABA46EF9, 0xAEEB787C, 0xAF29124B, 0xAD6FAC12, 0xACADC625, 0xA7F18118, 0xA633EB2F, + 0xA4755576, 0xA5B73F41, 0xA0F829C4, 0xA13A43F3, 0xA37CFDAA, 0xA2BE979D, 0xB5C473D0, 0xB40619E7, 0xB640A7BE, + 0xB782CD89, 0xB2CDDB0C, 0xB30FB13B, 0xB1490F62, 0xB08B6555, 0xBBD72268, 0xBA15485F, 0xB853F606, 0xB9919C31, + 0xBCDE8AB4, 0xBD1CE083, 0xBF5A5EDA, 0xBE9834ED, + /* T8_3 */ + 0x00000000, 0xB8BC6765, 0xAA09C88B, 0x12B5AFEE, 0x8F629757, 0x37DEF032, 0x256B5FDC, 0x9DD738B9, 0xC5B428EF, + 0x7D084F8A, 0x6FBDE064, 0xD7018701, 0x4AD6BFB8, 0xF26AD8DD, 0xE0DF7733, 0x58631056, 0x5019579F, 0xE8A530FA, + 0xFA109F14, 0x42ACF871, 0xDF7BC0C8, 0x67C7A7AD, 0x75720843, 0xCDCE6F26, 0x95AD7F70, 0x2D111815, 0x3FA4B7FB, + 0x8718D09E, 0x1ACFE827, 0xA2738F42, 0xB0C620AC, 0x087A47C9, 0xA032AF3E, 0x188EC85B, 0x0A3B67B5, 0xB28700D0, + 0x2F503869, 0x97EC5F0C, 0x8559F0E2, 0x3DE59787, 0x658687D1, 0xDD3AE0B4, 0xCF8F4F5A, 0x7733283F, 0xEAE41086, + 0x525877E3, 0x40EDD80D, 0xF851BF68, 0xF02BF8A1, 0x48979FC4, 0x5A22302A, 0xE29E574F, 0x7F496FF6, 0xC7F50893, + 0xD540A77D, 0x6DFCC018, 0x359FD04E, 0x8D23B72B, 0x9F9618C5, 0x272A7FA0, 0xBAFD4719, 0x0241207C, 
0x10F48F92, + 0xA848E8F7, 0x9B14583D, 0x23A83F58, 0x311D90B6, 0x89A1F7D3, 0x1476CF6A, 0xACCAA80F, 0xBE7F07E1, 0x06C36084, + 0x5EA070D2, 0xE61C17B7, 0xF4A9B859, 0x4C15DF3C, 0xD1C2E785, 0x697E80E0, 0x7BCB2F0E, 0xC377486B, 0xCB0D0FA2, + 0x73B168C7, 0x6104C729, 0xD9B8A04C, 0x446F98F5, 0xFCD3FF90, 0xEE66507E, 0x56DA371B, 0x0EB9274D, 0xB6054028, + 0xA4B0EFC6, 0x1C0C88A3, 0x81DBB01A, 0x3967D77F, 0x2BD27891, 0x936E1FF4, 0x3B26F703, 0x839A9066, 0x912F3F88, + 0x299358ED, 0xB4446054, 0x0CF80731, 0x1E4DA8DF, 0xA6F1CFBA, 0xFE92DFEC, 0x462EB889, 0x549B1767, 0xEC277002, + 0x71F048BB, 0xC94C2FDE, 0xDBF98030, 0x6345E755, 0x6B3FA09C, 0xD383C7F9, 0xC1366817, 0x798A0F72, 0xE45D37CB, + 0x5CE150AE, 0x4E54FF40, 0xF6E89825, 0xAE8B8873, 0x1637EF16, 0x048240F8, 0xBC3E279D, 0x21E91F24, 0x99557841, + 0x8BE0D7AF, 0x335CB0CA, 0xED59B63B, 0x55E5D15E, 0x47507EB0, 0xFFEC19D5, 0x623B216C, 0xDA874609, 0xC832E9E7, + 0x708E8E82, 0x28ED9ED4, 0x9051F9B1, 0x82E4565F, 0x3A58313A, 0xA78F0983, 0x1F336EE6, 0x0D86C108, 0xB53AA66D, + 0xBD40E1A4, 0x05FC86C1, 0x1749292F, 0xAFF54E4A, 0x322276F3, 0x8A9E1196, 0x982BBE78, 0x2097D91D, 0x78F4C94B, + 0xC048AE2E, 0xD2FD01C0, 0x6A4166A5, 0xF7965E1C, 0x4F2A3979, 0x5D9F9697, 0xE523F1F2, 0x4D6B1905, 0xF5D77E60, + 0xE762D18E, 0x5FDEB6EB, 0xC2098E52, 0x7AB5E937, 0x680046D9, 0xD0BC21BC, 0x88DF31EA, 0x3063568F, 0x22D6F961, + 0x9A6A9E04, 0x07BDA6BD, 0xBF01C1D8, 0xADB46E36, 0x15080953, 0x1D724E9A, 0xA5CE29FF, 0xB77B8611, 0x0FC7E174, + 0x9210D9CD, 0x2AACBEA8, 0x38191146, 0x80A57623, 0xD8C66675, 0x607A0110, 0x72CFAEFE, 0xCA73C99B, 0x57A4F122, + 0xEF189647, 0xFDAD39A9, 0x45115ECC, 0x764DEE06, 0xCEF18963, 0xDC44268D, 0x64F841E8, 0xF92F7951, 0x41931E34, + 0x5326B1DA, 0xEB9AD6BF, 0xB3F9C6E9, 0x0B45A18C, 0x19F00E62, 0xA14C6907, 0x3C9B51BE, 0x842736DB, 0x96929935, + 0x2E2EFE50, 0x2654B999, 0x9EE8DEFC, 0x8C5D7112, 0x34E11677, 0xA9362ECE, 0x118A49AB, 0x033FE645, 0xBB838120, + 0xE3E09176, 0x5B5CF613, 0x49E959FD, 0xF1553E98, 0x6C820621, 0xD43E6144, 0xC68BCEAA, 0x7E37A9CF, 0xD67F4138, + 0x6EC3265D, 0x7C7689B3, 0xC4CAEED6, 0x591DD66F, 0xE1A1B10A, 0xF3141EE4, 0x4BA87981, 0x13CB69D7, 0xAB770EB2, + 0xB9C2A15C, 0x017EC639, 0x9CA9FE80, 0x241599E5, 0x36A0360B, 0x8E1C516E, 0x866616A7, 0x3EDA71C2, 0x2C6FDE2C, + 0x94D3B949, 0x090481F0, 0xB1B8E695, 0xA30D497B, 0x1BB12E1E, 0x43D23E48, 0xFB6E592D, 0xE9DBF6C3, 0x516791A6, + 0xCCB0A91F, 0x740CCE7A, 0x66B96194, 0xDE0506F1, + /* T8_4 */ + 0x00000000, 0x3D6029B0, 0x7AC05360, 0x47A07AD0, 0xF580A6C0, 0xC8E08F70, 0x8F40F5A0, 0xB220DC10, 0x30704BC1, + 0x0D106271, 0x4AB018A1, 0x77D03111, 0xC5F0ED01, 0xF890C4B1, 0xBF30BE61, 0x825097D1, 0x60E09782, 0x5D80BE32, + 0x1A20C4E2, 0x2740ED52, 0x95603142, 0xA80018F2, 0xEFA06222, 0xD2C04B92, 0x5090DC43, 0x6DF0F5F3, 0x2A508F23, + 0x1730A693, 0xA5107A83, 0x98705333, 0xDFD029E3, 0xE2B00053, 0xC1C12F04, 0xFCA106B4, 0xBB017C64, 0x866155D4, + 0x344189C4, 0x0921A074, 0x4E81DAA4, 0x73E1F314, 0xF1B164C5, 0xCCD14D75, 0x8B7137A5, 0xB6111E15, 0x0431C205, + 0x3951EBB5, 0x7EF19165, 0x4391B8D5, 0xA121B886, 0x9C419136, 0xDBE1EBE6, 0xE681C256, 0x54A11E46, 0x69C137F6, + 0x2E614D26, 0x13016496, 0x9151F347, 0xAC31DAF7, 0xEB91A027, 0xD6F18997, 0x64D15587, 0x59B17C37, 0x1E1106E7, + 0x23712F57, 0x58F35849, 0x659371F9, 0x22330B29, 0x1F532299, 0xAD73FE89, 0x9013D739, 0xD7B3ADE9, 0xEAD38459, + 0x68831388, 0x55E33A38, 0x124340E8, 0x2F236958, 0x9D03B548, 0xA0639CF8, 0xE7C3E628, 0xDAA3CF98, 0x3813CFCB, + 0x0573E67B, 0x42D39CAB, 0x7FB3B51B, 0xCD93690B, 0xF0F340BB, 0xB7533A6B, 0x8A3313DB, 0x0863840A, 0x3503ADBA, + 0x72A3D76A, 0x4FC3FEDA, 0xFDE322CA, 0xC0830B7A, 0x872371AA, 
0xBA43581A, 0x9932774D, 0xA4525EFD, 0xE3F2242D, + 0xDE920D9D, 0x6CB2D18D, 0x51D2F83D, 0x167282ED, 0x2B12AB5D, 0xA9423C8C, 0x9422153C, 0xD3826FEC, 0xEEE2465C, + 0x5CC29A4C, 0x61A2B3FC, 0x2602C92C, 0x1B62E09C, 0xF9D2E0CF, 0xC4B2C97F, 0x8312B3AF, 0xBE729A1F, 0x0C52460F, + 0x31326FBF, 0x7692156F, 0x4BF23CDF, 0xC9A2AB0E, 0xF4C282BE, 0xB362F86E, 0x8E02D1DE, 0x3C220DCE, 0x0142247E, + 0x46E25EAE, 0x7B82771E, 0xB1E6B092, 0x8C869922, 0xCB26E3F2, 0xF646CA42, 0x44661652, 0x79063FE2, 0x3EA64532, + 0x03C66C82, 0x8196FB53, 0xBCF6D2E3, 0xFB56A833, 0xC6368183, 0x74165D93, 0x49767423, 0x0ED60EF3, 0x33B62743, + 0xD1062710, 0xEC660EA0, 0xABC67470, 0x96A65DC0, 0x248681D0, 0x19E6A860, 0x5E46D2B0, 0x6326FB00, 0xE1766CD1, + 0xDC164561, 0x9BB63FB1, 0xA6D61601, 0x14F6CA11, 0x2996E3A1, 0x6E369971, 0x5356B0C1, 0x70279F96, 0x4D47B626, + 0x0AE7CCF6, 0x3787E546, 0x85A73956, 0xB8C710E6, 0xFF676A36, 0xC2074386, 0x4057D457, 0x7D37FDE7, 0x3A978737, + 0x07F7AE87, 0xB5D77297, 0x88B75B27, 0xCF1721F7, 0xF2770847, 0x10C70814, 0x2DA721A4, 0x6A075B74, 0x576772C4, + 0xE547AED4, 0xD8278764, 0x9F87FDB4, 0xA2E7D404, 0x20B743D5, 0x1DD76A65, 0x5A7710B5, 0x67173905, 0xD537E515, + 0xE857CCA5, 0xAFF7B675, 0x92979FC5, 0xE915E8DB, 0xD475C16B, 0x93D5BBBB, 0xAEB5920B, 0x1C954E1B, 0x21F567AB, + 0x66551D7B, 0x5B3534CB, 0xD965A31A, 0xE4058AAA, 0xA3A5F07A, 0x9EC5D9CA, 0x2CE505DA, 0x11852C6A, 0x562556BA, + 0x6B457F0A, 0x89F57F59, 0xB49556E9, 0xF3352C39, 0xCE550589, 0x7C75D999, 0x4115F029, 0x06B58AF9, 0x3BD5A349, + 0xB9853498, 0x84E51D28, 0xC34567F8, 0xFE254E48, 0x4C059258, 0x7165BBE8, 0x36C5C138, 0x0BA5E888, 0x28D4C7DF, + 0x15B4EE6F, 0x521494BF, 0x6F74BD0F, 0xDD54611F, 0xE03448AF, 0xA794327F, 0x9AF41BCF, 0x18A48C1E, 0x25C4A5AE, + 0x6264DF7E, 0x5F04F6CE, 0xED242ADE, 0xD044036E, 0x97E479BE, 0xAA84500E, 0x4834505D, 0x755479ED, 0x32F4033D, + 0x0F942A8D, 0xBDB4F69D, 0x80D4DF2D, 0xC774A5FD, 0xFA148C4D, 0x78441B9C, 0x4524322C, 0x028448FC, 0x3FE4614C, + 0x8DC4BD5C, 0xB0A494EC, 0xF704EE3C, 0xCA64C78C, + /* T8_5 */ + 0x00000000, 0xCB5CD3A5, 0x4DC8A10B, 0x869472AE, 0x9B914216, 0x50CD91B3, 0xD659E31D, 0x1D0530B8, 0xEC53826D, + 0x270F51C8, 0xA19B2366, 0x6AC7F0C3, 0x77C2C07B, 0xBC9E13DE, 0x3A0A6170, 0xF156B2D5, 0x03D6029B, 0xC88AD13E, + 0x4E1EA390, 0x85427035, 0x9847408D, 0x531B9328, 0xD58FE186, 0x1ED33223, 0xEF8580F6, 0x24D95353, 0xA24D21FD, + 0x6911F258, 0x7414C2E0, 0xBF481145, 0x39DC63EB, 0xF280B04E, 0x07AC0536, 0xCCF0D693, 0x4A64A43D, 0x81387798, + 0x9C3D4720, 0x57619485, 0xD1F5E62B, 0x1AA9358E, 0xEBFF875B, 0x20A354FE, 0xA6372650, 0x6D6BF5F5, 0x706EC54D, + 0xBB3216E8, 0x3DA66446, 0xF6FAB7E3, 0x047A07AD, 0xCF26D408, 0x49B2A6A6, 0x82EE7503, 0x9FEB45BB, 0x54B7961E, + 0xD223E4B0, 0x197F3715, 0xE82985C0, 0x23755665, 0xA5E124CB, 0x6EBDF76E, 0x73B8C7D6, 0xB8E41473, 0x3E7066DD, + 0xF52CB578, 0x0F580A6C, 0xC404D9C9, 0x4290AB67, 0x89CC78C2, 0x94C9487A, 0x5F959BDF, 0xD901E971, 0x125D3AD4, + 0xE30B8801, 0x28575BA4, 0xAEC3290A, 0x659FFAAF, 0x789ACA17, 0xB3C619B2, 0x35526B1C, 0xFE0EB8B9, 0x0C8E08F7, + 0xC7D2DB52, 0x4146A9FC, 0x8A1A7A59, 0x971F4AE1, 0x5C439944, 0xDAD7EBEA, 0x118B384F, 0xE0DD8A9A, 0x2B81593F, + 0xAD152B91, 0x6649F834, 0x7B4CC88C, 0xB0101B29, 0x36846987, 0xFDD8BA22, 0x08F40F5A, 0xC3A8DCFF, 0x453CAE51, + 0x8E607DF4, 0x93654D4C, 0x58399EE9, 0xDEADEC47, 0x15F13FE2, 0xE4A78D37, 0x2FFB5E92, 0xA96F2C3C, 0x6233FF99, + 0x7F36CF21, 0xB46A1C84, 0x32FE6E2A, 0xF9A2BD8F, 0x0B220DC1, 0xC07EDE64, 0x46EAACCA, 0x8DB67F6F, 0x90B34FD7, + 0x5BEF9C72, 0xDD7BEEDC, 0x16273D79, 0xE7718FAC, 0x2C2D5C09, 0xAAB92EA7, 0x61E5FD02, 0x7CE0CDBA, 0xB7BC1E1F, + 0x31286CB1, 0xFA74BF14, 
0x1EB014D8, 0xD5ECC77D, 0x5378B5D3, 0x98246676, 0x852156CE, 0x4E7D856B, 0xC8E9F7C5, + 0x03B52460, 0xF2E396B5, 0x39BF4510, 0xBF2B37BE, 0x7477E41B, 0x6972D4A3, 0xA22E0706, 0x24BA75A8, 0xEFE6A60D, + 0x1D661643, 0xD63AC5E6, 0x50AEB748, 0x9BF264ED, 0x86F75455, 0x4DAB87F0, 0xCB3FF55E, 0x006326FB, 0xF135942E, + 0x3A69478B, 0xBCFD3525, 0x77A1E680, 0x6AA4D638, 0xA1F8059D, 0x276C7733, 0xEC30A496, 0x191C11EE, 0xD240C24B, + 0x54D4B0E5, 0x9F886340, 0x828D53F8, 0x49D1805D, 0xCF45F2F3, 0x04192156, 0xF54F9383, 0x3E134026, 0xB8873288, + 0x73DBE12D, 0x6EDED195, 0xA5820230, 0x2316709E, 0xE84AA33B, 0x1ACA1375, 0xD196C0D0, 0x5702B27E, 0x9C5E61DB, + 0x815B5163, 0x4A0782C6, 0xCC93F068, 0x07CF23CD, 0xF6999118, 0x3DC542BD, 0xBB513013, 0x700DE3B6, 0x6D08D30E, + 0xA65400AB, 0x20C07205, 0xEB9CA1A0, 0x11E81EB4, 0xDAB4CD11, 0x5C20BFBF, 0x977C6C1A, 0x8A795CA2, 0x41258F07, + 0xC7B1FDA9, 0x0CED2E0C, 0xFDBB9CD9, 0x36E74F7C, 0xB0733DD2, 0x7B2FEE77, 0x662ADECF, 0xAD760D6A, 0x2BE27FC4, + 0xE0BEAC61, 0x123E1C2F, 0xD962CF8A, 0x5FF6BD24, 0x94AA6E81, 0x89AF5E39, 0x42F38D9C, 0xC467FF32, 0x0F3B2C97, + 0xFE6D9E42, 0x35314DE7, 0xB3A53F49, 0x78F9ECEC, 0x65FCDC54, 0xAEA00FF1, 0x28347D5F, 0xE368AEFA, 0x16441B82, + 0xDD18C827, 0x5B8CBA89, 0x90D0692C, 0x8DD55994, 0x46898A31, 0xC01DF89F, 0x0B412B3A, 0xFA1799EF, 0x314B4A4A, + 0xB7DF38E4, 0x7C83EB41, 0x6186DBF9, 0xAADA085C, 0x2C4E7AF2, 0xE712A957, 0x15921919, 0xDECECABC, 0x585AB812, + 0x93066BB7, 0x8E035B0F, 0x455F88AA, 0xC3CBFA04, 0x089729A1, 0xF9C19B74, 0x329D48D1, 0xB4093A7F, 0x7F55E9DA, + 0x6250D962, 0xA90C0AC7, 0x2F987869, 0xE4C4ABCC, + /* T8_6 */ + 0x00000000, 0xA6770BB4, 0x979F1129, 0x31E81A9D, 0xF44F2413, 0x52382FA7, 0x63D0353A, 0xC5A73E8E, 0x33EF4E67, + 0x959845D3, 0xA4705F4E, 0x020754FA, 0xC7A06A74, 0x61D761C0, 0x503F7B5D, 0xF64870E9, 0x67DE9CCE, 0xC1A9977A, + 0xF0418DE7, 0x56368653, 0x9391B8DD, 0x35E6B369, 0x040EA9F4, 0xA279A240, 0x5431D2A9, 0xF246D91D, 0xC3AEC380, + 0x65D9C834, 0xA07EF6BA, 0x0609FD0E, 0x37E1E793, 0x9196EC27, 0xCFBD399C, 0x69CA3228, 0x582228B5, 0xFE552301, + 0x3BF21D8F, 0x9D85163B, 0xAC6D0CA6, 0x0A1A0712, 0xFC5277FB, 0x5A257C4F, 0x6BCD66D2, 0xCDBA6D66, 0x081D53E8, + 0xAE6A585C, 0x9F8242C1, 0x39F54975, 0xA863A552, 0x0E14AEE6, 0x3FFCB47B, 0x998BBFCF, 0x5C2C8141, 0xFA5B8AF5, + 0xCBB39068, 0x6DC49BDC, 0x9B8CEB35, 0x3DFBE081, 0x0C13FA1C, 0xAA64F1A8, 0x6FC3CF26, 0xC9B4C492, 0xF85CDE0F, + 0x5E2BD5BB, 0x440B7579, 0xE27C7ECD, 0xD3946450, 0x75E36FE4, 0xB044516A, 0x16335ADE, 0x27DB4043, 0x81AC4BF7, + 0x77E43B1E, 0xD19330AA, 0xE07B2A37, 0x460C2183, 0x83AB1F0D, 0x25DC14B9, 0x14340E24, 0xB2430590, 0x23D5E9B7, + 0x85A2E203, 0xB44AF89E, 0x123DF32A, 0xD79ACDA4, 0x71EDC610, 0x4005DC8D, 0xE672D739, 0x103AA7D0, 0xB64DAC64, + 0x87A5B6F9, 0x21D2BD4D, 0xE47583C3, 0x42028877, 0x73EA92EA, 0xD59D995E, 0x8BB64CE5, 0x2DC14751, 0x1C295DCC, + 0xBA5E5678, 0x7FF968F6, 0xD98E6342, 0xE86679DF, 0x4E11726B, 0xB8590282, 0x1E2E0936, 0x2FC613AB, 0x89B1181F, + 0x4C162691, 0xEA612D25, 0xDB8937B8, 0x7DFE3C0C, 0xEC68D02B, 0x4A1FDB9F, 0x7BF7C102, 0xDD80CAB6, 0x1827F438, + 0xBE50FF8C, 0x8FB8E511, 0x29CFEEA5, 0xDF879E4C, 0x79F095F8, 0x48188F65, 0xEE6F84D1, 0x2BC8BA5F, 0x8DBFB1EB, + 0xBC57AB76, 0x1A20A0C2, 0x8816EAF2, 0x2E61E146, 0x1F89FBDB, 0xB9FEF06F, 0x7C59CEE1, 0xDA2EC555, 0xEBC6DFC8, + 0x4DB1D47C, 0xBBF9A495, 0x1D8EAF21, 0x2C66B5BC, 0x8A11BE08, 0x4FB68086, 0xE9C18B32, 0xD82991AF, 0x7E5E9A1B, + 0xEFC8763C, 0x49BF7D88, 0x78576715, 0xDE206CA1, 0x1B87522F, 0xBDF0599B, 0x8C184306, 0x2A6F48B2, 0xDC27385B, + 0x7A5033EF, 0x4BB82972, 0xEDCF22C6, 0x28681C48, 0x8E1F17FC, 0xBFF70D61, 0x198006D5, 0x47ABD36E, 
0xE1DCD8DA, + 0xD034C247, 0x7643C9F3, 0xB3E4F77D, 0x1593FCC9, 0x247BE654, 0x820CEDE0, 0x74449D09, 0xD23396BD, 0xE3DB8C20, + 0x45AC8794, 0x800BB91A, 0x267CB2AE, 0x1794A833, 0xB1E3A387, 0x20754FA0, 0x86024414, 0xB7EA5E89, 0x119D553D, + 0xD43A6BB3, 0x724D6007, 0x43A57A9A, 0xE5D2712E, 0x139A01C7, 0xB5ED0A73, 0x840510EE, 0x22721B5A, 0xE7D525D4, + 0x41A22E60, 0x704A34FD, 0xD63D3F49, 0xCC1D9F8B, 0x6A6A943F, 0x5B828EA2, 0xFDF58516, 0x3852BB98, 0x9E25B02C, + 0xAFCDAAB1, 0x09BAA105, 0xFFF2D1EC, 0x5985DA58, 0x686DC0C5, 0xCE1ACB71, 0x0BBDF5FF, 0xADCAFE4B, 0x9C22E4D6, + 0x3A55EF62, 0xABC30345, 0x0DB408F1, 0x3C5C126C, 0x9A2B19D8, 0x5F8C2756, 0xF9FB2CE2, 0xC813367F, 0x6E643DCB, + 0x982C4D22, 0x3E5B4696, 0x0FB35C0B, 0xA9C457BF, 0x6C636931, 0xCA146285, 0xFBFC7818, 0x5D8B73AC, 0x03A0A617, + 0xA5D7ADA3, 0x943FB73E, 0x3248BC8A, 0xF7EF8204, 0x519889B0, 0x6070932D, 0xC6079899, 0x304FE870, 0x9638E3C4, + 0xA7D0F959, 0x01A7F2ED, 0xC400CC63, 0x6277C7D7, 0x539FDD4A, 0xF5E8D6FE, 0x647E3AD9, 0xC209316D, 0xF3E12BF0, + 0x55962044, 0x90311ECA, 0x3646157E, 0x07AE0FE3, 0xA1D90457, 0x579174BE, 0xF1E67F0A, 0xC00E6597, 0x66796E23, + 0xA3DE50AD, 0x05A95B19, 0x34414184, 0x92364A30, + /* T8_7 */ + 0x00000000, 0xCCAA009E, 0x4225077D, 0x8E8F07E3, 0x844A0EFA, 0x48E00E64, 0xC66F0987, 0x0AC50919, 0xD3E51BB5, + 0x1F4F1B2B, 0x91C01CC8, 0x5D6A1C56, 0x57AF154F, 0x9B0515D1, 0x158A1232, 0xD92012AC, 0x7CBB312B, 0xB01131B5, + 0x3E9E3656, 0xF23436C8, 0xF8F13FD1, 0x345B3F4F, 0xBAD438AC, 0x767E3832, 0xAF5E2A9E, 0x63F42A00, 0xED7B2DE3, + 0x21D12D7D, 0x2B142464, 0xE7BE24FA, 0x69312319, 0xA59B2387, 0xF9766256, 0x35DC62C8, 0xBB53652B, 0x77F965B5, + 0x7D3C6CAC, 0xB1966C32, 0x3F196BD1, 0xF3B36B4F, 0x2A9379E3, 0xE639797D, 0x68B67E9E, 0xA41C7E00, 0xAED97719, + 0x62737787, 0xECFC7064, 0x205670FA, 0x85CD537D, 0x496753E3, 0xC7E85400, 0x0B42549E, 0x01875D87, 0xCD2D5D19, + 0x43A25AFA, 0x8F085A64, 0x562848C8, 0x9A824856, 0x140D4FB5, 0xD8A74F2B, 0xD2624632, 0x1EC846AC, 0x9047414F, + 0x5CED41D1, 0x299DC2ED, 0xE537C273, 0x6BB8C590, 0xA712C50E, 0xADD7CC17, 0x617DCC89, 0xEFF2CB6A, 0x2358CBF4, + 0xFA78D958, 0x36D2D9C6, 0xB85DDE25, 0x74F7DEBB, 0x7E32D7A2, 0xB298D73C, 0x3C17D0DF, 0xF0BDD041, 0x5526F3C6, + 0x998CF358, 0x1703F4BB, 0xDBA9F425, 0xD16CFD3C, 0x1DC6FDA2, 0x9349FA41, 0x5FE3FADF, 0x86C3E873, 0x4A69E8ED, + 0xC4E6EF0E, 0x084CEF90, 0x0289E689, 0xCE23E617, 0x40ACE1F4, 0x8C06E16A, 0xD0EBA0BB, 0x1C41A025, 0x92CEA7C6, + 0x5E64A758, 0x54A1AE41, 0x980BAEDF, 0x1684A93C, 0xDA2EA9A2, 0x030EBB0E, 0xCFA4BB90, 0x412BBC73, 0x8D81BCED, + 0x8744B5F4, 0x4BEEB56A, 0xC561B289, 0x09CBB217, 0xAC509190, 0x60FA910E, 0xEE7596ED, 0x22DF9673, 0x281A9F6A, + 0xE4B09FF4, 0x6A3F9817, 0xA6959889, 0x7FB58A25, 0xB31F8ABB, 0x3D908D58, 0xF13A8DC6, 0xFBFF84DF, 0x37558441, + 0xB9DA83A2, 0x7570833C, 0x533B85DA, 0x9F918544, 0x111E82A7, 0xDDB48239, 0xD7718B20, 0x1BDB8BBE, 0x95548C5D, + 0x59FE8CC3, 0x80DE9E6F, 0x4C749EF1, 0xC2FB9912, 0x0E51998C, 0x04949095, 0xC83E900B, 0x46B197E8, 0x8A1B9776, + 0x2F80B4F1, 0xE32AB46F, 0x6DA5B38C, 0xA10FB312, 0xABCABA0B, 0x6760BA95, 0xE9EFBD76, 0x2545BDE8, 0xFC65AF44, + 0x30CFAFDA, 0xBE40A839, 0x72EAA8A7, 0x782FA1BE, 0xB485A120, 0x3A0AA6C3, 0xF6A0A65D, 0xAA4DE78C, 0x66E7E712, + 0xE868E0F1, 0x24C2E06F, 0x2E07E976, 0xE2ADE9E8, 0x6C22EE0B, 0xA088EE95, 0x79A8FC39, 0xB502FCA7, 0x3B8DFB44, + 0xF727FBDA, 0xFDE2F2C3, 0x3148F25D, 0xBFC7F5BE, 0x736DF520, 0xD6F6D6A7, 0x1A5CD639, 0x94D3D1DA, 0x5879D144, + 0x52BCD85D, 0x9E16D8C3, 0x1099DF20, 0xDC33DFBE, 0x0513CD12, 0xC9B9CD8C, 0x4736CA6F, 0x8B9CCAF1, 0x8159C3E8, + 0x4DF3C376, 0xC37CC495, 0x0FD6C40B, 0x7AA64737, 0xB60C47A9, 
0x3883404A, 0xF42940D4, 0xFEEC49CD, 0x32464953, + 0xBCC94EB0, 0x70634E2E, 0xA9435C82, 0x65E95C1C, 0xEB665BFF, 0x27CC5B61, 0x2D095278, 0xE1A352E6, 0x6F2C5505, + 0xA386559B, 0x061D761C, 0xCAB77682, 0x44387161, 0x889271FF, 0x825778E6, 0x4EFD7878, 0xC0727F9B, 0x0CD87F05, + 0xD5F86DA9, 0x19526D37, 0x97DD6AD4, 0x5B776A4A, 0x51B26353, 0x9D1863CD, 0x1397642E, 0xDF3D64B0, 0x83D02561, + 0x4F7A25FF, 0xC1F5221C, 0x0D5F2282, 0x079A2B9B, 0xCB302B05, 0x45BF2CE6, 0x89152C78, 0x50353ED4, 0x9C9F3E4A, + 0x121039A9, 0xDEBA3937, 0xD47F302E, 0x18D530B0, 0x965A3753, 0x5AF037CD, 0xFF6B144A, 0x33C114D4, 0xBD4E1337, + 0x71E413A9, 0x7B211AB0, 0xB78B1A2E, 0x39041DCD, 0xF5AE1D53, 0x2C8E0FFF, 0xE0240F61, 0x6EAB0882, 0xA201081C, + 0xA8C40105, 0x646E019B, 0xEAE10678, 0x264B06E6 }; +} diff --git a/clients/src/main/java/org/apache/kafka/common/utils/Crc32C.java b/clients/src/main/java/org/apache/kafka/common/utils/Crc32C.java new file mode 100644 index 0000000..dfe22e8 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/utils/Crc32C.java @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.common.utils; + +import java.lang.invoke.MethodHandle; +import java.lang.invoke.MethodHandles; +import java.lang.invoke.MethodType; +import java.nio.ByteBuffer; +import java.util.zip.Checksum; + +/** + * A class that can be used to compute the CRC32C (Castagnoli) of a ByteBuffer or array of bytes. + * + * We use java.util.zip.CRC32C (introduced in Java 9) if it is available and fallback to PureJavaCrc32C, otherwise. + * java.util.zip.CRC32C is significantly faster on reasonably modern CPUs as it uses the CRC32 instruction introduced + * in SSE4.2. + * + * NOTE: This class is intended for INTERNAL usage only within Kafka. 
+ */ +public final class Crc32C { + + private static final ChecksumFactory CHECKSUM_FACTORY; + + static { + if (Java.IS_JAVA9_COMPATIBLE) + CHECKSUM_FACTORY = new Java9ChecksumFactory(); + else + CHECKSUM_FACTORY = new PureJavaChecksumFactory(); + } + + private Crc32C() {} + + /** + * Compute the CRC32C (Castagnoli) of the segment of the byte array given by the specified size and offset + * + * @param bytes The bytes to checksum + * @param offset the offset at which to begin the checksum computation + * @param size the number of bytes to checksum + * @return The CRC32C + */ + public static long compute(byte[] bytes, int offset, int size) { + Checksum crc = create(); + crc.update(bytes, offset, size); + return crc.getValue(); + } + + /** + * Compute the CRC32C (Castagnoli) of a byte buffer from a given offset (relative to the buffer's current position) + * + * @param buffer The buffer with the underlying data + * @param offset The offset relative to the current position + * @param size The number of bytes beginning from the offset to include + * @return The CRC32C + */ + public static long compute(ByteBuffer buffer, int offset, int size) { + Checksum crc = create(); + Checksums.update(crc, buffer, offset, size); + return crc.getValue(); + } + + public static Checksum create() { + return CHECKSUM_FACTORY.create(); + } + + private interface ChecksumFactory { + Checksum create(); + } + + private static class Java9ChecksumFactory implements ChecksumFactory { + private static final MethodHandle CONSTRUCTOR; + + static { + try { + Class cls = Class.forName("java.util.zip.CRC32C"); + CONSTRUCTOR = MethodHandles.publicLookup().findConstructor(cls, MethodType.methodType(void.class)); + } catch (ReflectiveOperationException e) { + // Should never happen + throw new RuntimeException(e); + } + } + + @Override + public Checksum create() { + try { + return (Checksum) CONSTRUCTOR.invoke(); + } catch (Throwable throwable) { + // Should never happen + throw new RuntimeException(throwable); + } + } + } + + private static class PureJavaChecksumFactory implements ChecksumFactory { + @Override + public Checksum create() { + return new PureJavaCrc32C(); + } + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/utils/Exit.java b/clients/src/main/java/org/apache/kafka/common/utils/Exit.java new file mode 100644 index 0000000..45f92e4 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/utils/Exit.java @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.utils; + +/** + * Internal class that should be used instead of `Exit.exit()` and `Runtime.getRuntime().halt()` so that tests can + * easily change the behaviour. 
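Stepping back to the Crc32C utility completed above: it exposes only the static compute() overloads and a create() factory, so a caller never deals with the Java 9 / pure-Java fallback choice directly. A short usage sketch follows; the class name Crc32CExample is illustrative, and the javadoc above notes the class is intended for internal Kafka use only.

```java
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;

import org.apache.kafka.common.utils.Crc32C;

public class Crc32CExample {
    public static void main(String[] args) {
        byte[] payload = "kafka-record".getBytes(StandardCharsets.UTF_8);

        // Checksum a slice of a byte array; offset/size allow checksumming part of it without copying.
        long arrayCrc = Crc32C.compute(payload, 0, payload.length);

        // The ByteBuffer overload reads size bytes starting at the given offset,
        // measured relative to the buffer's current position.
        long bufferCrc = Crc32C.compute(ByteBuffer.wrap(payload), 0, payload.length);

        System.out.printf("array:  0x%08X%n", arrayCrc);
        System.out.printf("buffer: 0x%08X%n", bufferCrc);  // same value as the array overload
    }
}
```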
+ */ +public class Exit { + + public interface Procedure { + void execute(int statusCode, String message); + } + + public interface ShutdownHookAdder { + void addShutdownHook(String name, Runnable runnable); + } + + private static final Procedure DEFAULT_HALT_PROCEDURE = new Procedure() { + @Override + public void execute(int statusCode, String message) { + Runtime.getRuntime().halt(statusCode); + } + }; + + private static final Procedure DEFAULT_EXIT_PROCEDURE = new Procedure() { + @Override + public void execute(int statusCode, String message) { + System.exit(statusCode); + } + }; + + private static final ShutdownHookAdder DEFAULT_SHUTDOWN_HOOK_ADDER = new ShutdownHookAdder() { + @Override + public void addShutdownHook(String name, Runnable runnable) { + if (name != null) + Runtime.getRuntime().addShutdownHook(KafkaThread.nonDaemon(name, runnable)); + else + Runtime.getRuntime().addShutdownHook(new Thread(runnable)); + } + }; + + private volatile static Procedure exitProcedure = DEFAULT_EXIT_PROCEDURE; + private volatile static Procedure haltProcedure = DEFAULT_HALT_PROCEDURE; + private volatile static ShutdownHookAdder shutdownHookAdder = DEFAULT_SHUTDOWN_HOOK_ADDER; + + public static void exit(int statusCode) { + exit(statusCode, null); + } + + public static void exit(int statusCode, String message) { + exitProcedure.execute(statusCode, message); + } + + public static void halt(int statusCode) { + halt(statusCode, null); + } + + public static void halt(int statusCode, String message) { + haltProcedure.execute(statusCode, message); + } + + public static void addShutdownHook(String name, Runnable runnable) { + shutdownHookAdder.addShutdownHook(name, runnable); + } + + public static void setExitProcedure(Procedure procedure) { + exitProcedure = procedure; + } + + public static void setHaltProcedure(Procedure procedure) { + haltProcedure = procedure; + } + + public static void setShutdownHookAdder(ShutdownHookAdder shutdownHookAdder) { + Exit.shutdownHookAdder = shutdownHookAdder; + } + + public static void resetExitProcedure() { + exitProcedure = DEFAULT_EXIT_PROCEDURE; + } + + public static void resetHaltProcedure() { + haltProcedure = DEFAULT_HALT_PROCEDURE; + } + + public static void resetShutdownHookAdder() { + shutdownHookAdder = DEFAULT_SHUTDOWN_HOOK_ADDER; + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/utils/ExponentialBackoff.java b/clients/src/main/java/org/apache/kafka/common/utils/ExponentialBackoff.java new file mode 100644 index 0000000..7550184 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/utils/ExponentialBackoff.java @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
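The Exit class just shown exists so that tests can intercept JVM termination instead of actually exiting. A hedged sketch of how a test might swap in a recording exit procedure and restore the default afterwards (the class name and captured values are illustrative):

```java
import java.util.concurrent.atomic.AtomicInteger;

import org.apache.kafka.common.utils.Exit;

public class ExitExample {
    public static void main(String[] args) {
        AtomicInteger lastStatus = new AtomicInteger(Integer.MIN_VALUE);

        // Procedure is a single-method interface, so a lambda can record the status
        // code instead of terminating the JVM.
        Exit.setExitProcedure((statusCode, message) -> lastStatus.set(statusCode));
        try {
            Exit.exit(1, "simulated fatal error");
            System.out.println("captured exit status: " + lastStatus.get()); // 1
        } finally {
            // Always restore the default behaviour so later code is unaffected.
            Exit.resetExitProcedure();
        }
    }
}
```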
+ */ + +package org.apache.kafka.common.utils; + +import java.util.concurrent.ThreadLocalRandom; + +/** + * A utility class for keeping the parameters and providing the value of exponential + * retry backoff, exponential reconnect backoff, exponential timeout, etc. + * The formula is: + * Backoff(attempts) = random(1 - jitter, 1 + jitter) * initialInterval * multiplier ^ attempts + * If initialInterval is greater than or equal to maxInterval, a constant backoff of initialInterval will be provided. + * This class is thread-safe. + */ +public class ExponentialBackoff { + private final int multiplier; + private final double expMax; + private final long initialInterval; + private final double jitter; + + public ExponentialBackoff(long initialInterval, int multiplier, long maxInterval, double jitter) { + this.initialInterval = initialInterval; + this.multiplier = multiplier; + this.jitter = jitter; + this.expMax = maxInterval > initialInterval ? + Math.log(maxInterval / (double) Math.max(initialInterval, 1)) / Math.log(multiplier) : 0; + } + + public long backoff(long attempts) { + if (expMax == 0) { + return initialInterval; + } + double exp = Math.min(attempts, this.expMax); + double term = initialInterval * Math.pow(multiplier, exp); + double randomFactor = jitter < Double.MIN_NORMAL ? 1.0 : + ThreadLocalRandom.current().nextDouble(1 - jitter, 1 + jitter); + return (long) (randomFactor * term); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/utils/FixedOrderMap.java b/clients/src/main/java/org/apache/kafka/common/utils/FixedOrderMap.java new file mode 100644 index 0000000..175282e --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/utils/FixedOrderMap.java @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.utils; + +import java.util.LinkedHashMap; +import java.util.Map; + +/** + * An ordered map (LinkedHashMap) implementation for which the order is immutable. + * To accomplish this, all methods of removing mappings are disabled (they are marked + * deprecated and throw an exception). + * + * This class is final to prevent subclasses from violating the desired property.
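For the ExponentialBackoff class above, a short sketch of the values it produces. The concrete parameters are illustrative: with an initial interval of 100 ms, a multiplier of 2, a cap of 1600 ms and 20% jitter, the deterministic part of the sequence is 100, 200, 400, 800, 1600 and then stays at the cap.

```java
import org.apache.kafka.common.utils.ExponentialBackoff;

public class ExponentialBackoffExample {
    public static void main(String[] args) {
        // initialInterval=100ms, multiplier=2, maxInterval=1600ms, jitter=0.2
        ExponentialBackoff backoff = new ExponentialBackoff(100, 2, 1600, 0.2);

        for (long attempt = 0; attempt <= 6; attempt++) {
            // Roughly 100, 200, 400, 800, 1600, 1600, 1600 ms, each multiplied by a
            // random factor drawn from [0.8, 1.2).
            System.out.printf("attempt %d -> backoff %d ms%n", attempt, backoff.backoff(attempt));
        }
    }
}
```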
+ * + * @param The key type + * @param The value type + */ +public final class FixedOrderMap extends LinkedHashMap { + private static final long serialVersionUID = -6504110858733236170L; + + @Deprecated + @Override + protected boolean removeEldestEntry(final Map.Entry eldest) { + return false; + } + + @Deprecated + @Override + public V remove(final Object key) { + throw new UnsupportedOperationException("Removing from registeredStores is not allowed"); + } + + @Deprecated + @Override + public boolean remove(final Object key, final Object value) { + throw new UnsupportedOperationException("Removing from registeredStores is not allowed"); + } + + @Override + public FixedOrderMap clone() { + throw new UnsupportedOperationException(); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/utils/FlattenedIterator.java b/clients/src/main/java/org/apache/kafka/common/utils/FlattenedIterator.java new file mode 100644 index 0000000..48bf3b7 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/utils/FlattenedIterator.java @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.utils; + +import java.util.Iterator; +import java.util.function.Function; + +/** + * Provides a flattened iterator over the inner elements of an outer iterator. + */ +public final class FlattenedIterator extends AbstractIterator { + private final Iterator outerIterator; + private final Function> innerIteratorFunction; + private Iterator innerIterator; + + public FlattenedIterator(Iterator outerIterator, Function> innerIteratorFunction) { + this.outerIterator = outerIterator; + this.innerIteratorFunction = innerIteratorFunction; + } + + @Override + public I makeNext() { + while (innerIterator == null || !innerIterator.hasNext()) { + if (outerIterator.hasNext()) + innerIterator = innerIteratorFunction.apply(outerIterator.next()); + else + return allDone(); + } + return innerIterator.next(); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/utils/ImplicitLinkedHashCollection.java b/clients/src/main/java/org/apache/kafka/common/utils/ImplicitLinkedHashCollection.java new file mode 100644 index 0000000..ef33f5f --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/utils/ImplicitLinkedHashCollection.java @@ -0,0 +1,695 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
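A brief usage sketch for the FlattenedIterator shown above, which lazily walks the inner iterators produced from each outer element and silently skips empty ones. The two type arguments written explicitly here, the outer element type and the inner element type, are assumptions inferred from the shapes of the field declarations and the constructor.

```java
import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;

import org.apache.kafka.common.utils.FlattenedIterator;

public class FlattenedIteratorExample {
    public static void main(String[] args) {
        List<List<String>> nested = Arrays.asList(
                Arrays.asList("a", "b"),
                Collections.<String>emptyList(),  // empty inner iterators are skipped
                Arrays.asList("c"));

        // Outer element type: List<String>; inner element type: String.
        Iterator<String> flattened =
                new FlattenedIterator<List<String>, String>(nested.iterator(), List::iterator);

        while (flattened.hasNext()) {
            System.out.println(flattened.next());  // prints a, b, c
        }
    }
}
```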
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.common.utils; + +import java.util.AbstractCollection; +import java.util.AbstractSequentialList; +import java.util.AbstractSet; +import java.util.ArrayList; +import java.util.Comparator; +import java.util.Iterator; +import java.util.List; +import java.util.ListIterator; +import java.util.NoSuchElementException; +import java.util.Set; + +/** + * A memory-efficient hash set which tracks the order of insertion of elements. + * + * Like java.util.LinkedHashSet, this collection maintains a linked list of elements. + * However, rather than using a separate linked list, this collection embeds the next + * and previous fields into the elements themselves. This reduces memory consumption, + * because it means that we only have to store one Java object per element, rather + * than multiple. + * + * The next and previous fields are stored as array indices rather than pointers. + * This ensures that the fields only take 32 bits, even when pointers are 64 bits. + * It also makes the garbage collector's job easier, because it reduces the number of + * pointers that it must chase. + * + * This class uses linear probing. Unlike HashMap (but like HashTable), we don't force + * the size to be a power of 2. This saves memory. + * + * This set does not allow null elements. It does not have internal synchronization. + */ +public class ImplicitLinkedHashCollection extends AbstractCollection { + /** + * The interface which elements of this collection must implement. The prev, + * setPrev, next, and setNext functions handle manipulating the implicit linked + * list which these elements reside in inside the collection. + * elementKeysAreEqual() is the function which this collection uses to compare + * elements. + */ + public interface Element { + int prev(); + void setPrev(int prev); + int next(); + void setNext(int next); + default boolean elementKeysAreEqual(Object other) { + return equals(other); + } + } + + /** + * A special index value used to indicate that the next or previous field is + * the head. + */ + private static final int HEAD_INDEX = -1; + + /** + * A special index value used for next and previous indices which have not + * been initialized. + */ + public static final int INVALID_INDEX = -2; + + /** + * The minimum new capacity for a non-empty implicit hash set. + */ + private static final int MIN_NONEMPTY_CAPACITY = 5; + + /** + * A static empty array used to avoid object allocations when the capacity is zero. 
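The Element interface above is what makes the collection "implicit": each element carries its own previous and next slot indices, so no separate node objects are allocated. A hypothetical element implementation, sketched only to show the contract; the name TopicNameElement and its field are not part of the patch. Note that a freshly created element must report INVALID_INDEX for both indices, otherwise add() will refuse it.

```java
import org.apache.kafka.common.utils.ImplicitLinkedHashCollection;

public class TopicNameElement implements ImplicitLinkedHashCollection.Element {
    private final String topicName;

    // The linked-list structure lives inside the element itself, stored as array
    // indices rather than object references. INVALID_INDEX marks "not in a collection".
    private int prev = ImplicitLinkedHashCollection.INVALID_INDEX;
    private int next = ImplicitLinkedHashCollection.INVALID_INDEX;

    public TopicNameElement(String topicName) {
        this.topicName = topicName;
    }

    @Override public int prev() { return prev; }
    @Override public void setPrev(int prev) { this.prev = prev; }
    @Override public int next() { return next; }
    @Override public void setNext(int next) { this.next = next; }

    // equals()/hashCode() define the "key" used by contains() and find(), because the
    // default elementKeysAreEqual() simply delegates to equals().
    @Override
    public boolean equals(Object o) {
        return o instanceof TopicNameElement
                && ((TopicNameElement) o).topicName.equals(topicName);
    }

    @Override
    public int hashCode() {
        return topicName.hashCode();
    }
}
```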
+ */ + private static final Element[] EMPTY_ELEMENTS = new Element[0]; + + private static class HeadElement implements Element { + static final HeadElement EMPTY = new HeadElement(); + + private int prev = HEAD_INDEX; + private int next = HEAD_INDEX; + + @Override + public int prev() { + return prev; + } + + @Override + public void setPrev(int prev) { + this.prev = prev; + } + + @Override + public int next() { + return next; + } + + @Override + public void setNext(int next) { + this.next = next; + } + } + + private static Element indexToElement(Element head, Element[] elements, int index) { + if (index == HEAD_INDEX) { + return head; + } + return elements[index]; + } + + private static void addToListTail(Element head, Element[] elements, int elementIdx) { + int oldTailIdx = head.prev(); + Element element = indexToElement(head, elements, elementIdx); + Element oldTail = indexToElement(head, elements, oldTailIdx); + head.setPrev(elementIdx); + oldTail.setNext(elementIdx); + element.setPrev(oldTailIdx); + element.setNext(HEAD_INDEX); + } + + private static void removeFromList(Element head, Element[] elements, int elementIdx) { + Element element = indexToElement(head, elements, elementIdx); + elements[elementIdx] = null; + int prevIdx = element.prev(); + int nextIdx = element.next(); + Element prev = indexToElement(head, elements, prevIdx); + Element next = indexToElement(head, elements, nextIdx); + prev.setNext(nextIdx); + next.setPrev(prevIdx); + element.setNext(INVALID_INDEX); + element.setPrev(INVALID_INDEX); + } + + private class ImplicitLinkedHashCollectionIterator implements ListIterator { + private int index = 0; + private Element cur; + private Element lastReturned; + + ImplicitLinkedHashCollectionIterator(int index) { + this.cur = indexToElement(head, elements, head.next()); + for (int i = 0; i < index; ++i) { + next(); + } + this.lastReturned = null; + } + + @Override + public boolean hasNext() { + return cur != head; + } + + @Override + public boolean hasPrevious() { + return indexToElement(head, elements, cur.prev()) != head; + } + + @Override + public E next() { + if (!hasNext()) { + throw new NoSuchElementException(); + } + @SuppressWarnings("unchecked") + E returnValue = (E) cur; + lastReturned = cur; + cur = indexToElement(head, elements, cur.next()); + ++index; + return returnValue; + } + + @Override + public E previous() { + Element prev = indexToElement(head, elements, cur.prev()); + if (prev == head) { + throw new NoSuchElementException(); + } + cur = prev; + --index; + lastReturned = cur; + @SuppressWarnings("unchecked") + E returnValue = (E) cur; + return returnValue; + } + + @Override + public int nextIndex() { + return index; + } + + @Override + public int previousIndex() { + return index - 1; + } + + @Override + public void remove() { + if (lastReturned == null) { + throw new IllegalStateException(); + } + Element nextElement = indexToElement(head, elements, lastReturned.next()); + ImplicitLinkedHashCollection.this.removeElementAtSlot(nextElement.prev()); + if (lastReturned == cur) { + // If the element we are removing was cur, set cur to cur->next. + cur = nextElement; + } else { + // If the element we are removing comes before cur, decrement the index, + // since there are now fewer entries before cur. 
+ --index; + } + lastReturned = null; + } + + @Override + public void set(E e) { + throw new UnsupportedOperationException(); + } + + @Override + public void add(E e) { + throw new UnsupportedOperationException(); + } + } + + private class ImplicitLinkedHashCollectionListView extends AbstractSequentialList { + + @Override + public ListIterator listIterator(int index) { + if (index < 0 || index > size) { + throw new IndexOutOfBoundsException(); + } + + return ImplicitLinkedHashCollection.this.listIterator(index); + } + + @Override + public int size() { + return size; + } + } + + private class ImplicitLinkedHashCollectionSetView extends AbstractSet { + + @Override + public Iterator iterator() { + return ImplicitLinkedHashCollection.this.iterator(); + } + + @Override + public int size() { + return size; + } + + @Override + public boolean add(E newElement) { + return ImplicitLinkedHashCollection.this.add(newElement); + } + + @Override + public boolean remove(Object key) { + return ImplicitLinkedHashCollection.this.remove(key); + } + + @Override + public boolean contains(Object key) { + return ImplicitLinkedHashCollection.this.contains(key); + } + + @Override + public void clear() { + ImplicitLinkedHashCollection.this.clear(); + } + } + + private Element head; + + Element[] elements; + + private int size; + + /** + * Returns an iterator that will yield every element in the set. + * The elements will be returned in the order that they were inserted in. + * + * Do not modify the set while you are iterating over it (except by calling + * remove on the iterator itself, of course.) + */ + @Override + final public Iterator iterator() { + return listIterator(0); + } + + private ListIterator listIterator(int index) { + return new ImplicitLinkedHashCollectionIterator(index); + } + + final int slot(Element[] curElements, Object e) { + return (e.hashCode() & 0x7fffffff) % curElements.length; + } + + /** + * Find an element matching an example element. + * + * Using the element's hash code, we can look up the slot where it belongs. + * However, it may not have ended up in exactly this slot, due to a collision. + * Therefore, we must search forward in the array until we hit a null, before + * concluding that the element is not present. + * + * @param key The element to match. + * @return The match index, or INVALID_INDEX if no match was found. + */ + final private int findIndexOfEqualElement(Object key) { + if (key == null || size == 0) { + return INVALID_INDEX; + } + int slot = slot(elements, key); + for (int seen = 0; seen < elements.length; seen++) { + Element element = elements[slot]; + if (element == null) { + return INVALID_INDEX; + } + if (element.elementKeysAreEqual(key)) { + return slot; + } + slot = (slot + 1) % elements.length; + } + return INVALID_INDEX; + } + + /** + * An element e in the collection such that e.elementKeysAreEqual(key) and + * e.hashCode() == key.hashCode(). + * + * @param key The element to match. + * @return The matching element, or null if there were none. + */ + final public E find(E key) { + int index = findIndexOfEqualElement(key); + if (index == INVALID_INDEX) { + return null; + } + @SuppressWarnings("unchecked") + E result = (E) elements[index]; + return result; + } + + /** + * Returns the number of elements in the set. + */ + @Override + final public int size() { + return size; + } + + /** + * Returns true if there is at least one element e in the collection such + * that key.elementKeysAreEqual(e) and key.hashCode() == e.hashCode(). 
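The lookup logic above relies on open addressing with linear probing: the slot is derived from the element's hash code, and collisions are resolved by scanning forward until either the element or an empty slot is found. The following self-contained sketch shows the same probing idea on a plain String table; it is a simplification for illustration, not the collection's actual internals, which additionally maintain the embedded linked list.

```java
public class LinearProbingSketch {
    // Mirrors the slot() computation above: mask off the sign bit, then take the modulus.
    static int slot(String[] table, Object key) {
        return (key.hashCode() & 0x7fffffff) % table.length;
    }

    // Returns the index holding the key, or -1 once an empty slot proves it is absent.
    static int find(String[] table, String key) {
        int slot = slot(table, key);
        for (int seen = 0; seen < table.length; seen++) {
            String candidate = table[slot];
            if (candidate == null) {
                return -1;                          // hit a hole: the key cannot be further along
            }
            if (candidate.equals(key)) {
                return slot;                        // found, possibly displaced from its home slot
            }
            slot = (slot + 1) % table.length;       // probe the next slot, wrapping around
        }
        return -1;
    }

    public static void main(String[] args) {
        String[] table = new String[5];
        for (String s : new String[]{"orders", "payments", "audit"}) {
            int slot = slot(table, s);
            while (table[slot] != null) {
                slot = (slot + 1) % table.length;   // linear probing on insert as well
            }
            table[slot] = s;
        }
        System.out.println(find(table, "payments") >= 0); // true
        System.out.println(find(table, "missing") >= 0);  // false
    }
}
```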
+ * + * @param key The object to try to match. + */ + @Override + final public boolean contains(Object key) { + return findIndexOfEqualElement(key) != INVALID_INDEX; + } + + private static int calculateCapacity(int expectedNumElements) { + // Avoid using even-sized capacities, to get better key distribution. + int newCapacity = (2 * expectedNumElements) + 1; + // Don't use a capacity that is too small. + return Math.max(newCapacity, MIN_NONEMPTY_CAPACITY); + } + + /** + * Add a new element to the collection. + * + * @param newElement The new element. + * + * @return True if the element was added to the collection; + * false if it was not, because there was an existing equal element. + */ + @Override + final public boolean add(E newElement) { + if (newElement == null) { + return false; + } + if (newElement.prev() != INVALID_INDEX || newElement.next() != INVALID_INDEX) { + return false; + } + if ((size + 1) >= elements.length / 2) { + changeCapacity(calculateCapacity(elements.length)); + } + int slot = addInternal(newElement, elements); + if (slot >= 0) { + addToListTail(head, elements, slot); + size++; + return true; + } + return false; + } + + final public void mustAdd(E newElement) { + if (!add(newElement)) { + throw new RuntimeException("Unable to add " + newElement); + } + } + + /** + * Adds a new element to the appropriate place in the elements array. + * + * @param newElement The new element to add. + * @param addElements The elements array. + * @return The index at which the element was inserted, or INVALID_INDEX + * if the element could not be inserted. + */ + int addInternal(Element newElement, Element[] addElements) { + int slot = slot(addElements, newElement); + for (int seen = 0; seen < addElements.length; seen++) { + Element element = addElements[slot]; + if (element == null) { + addElements[slot] = newElement; + return slot; + } + if (element.elementKeysAreEqual(newElement)) { + return INVALID_INDEX; + } + slot = (slot + 1) % addElements.length; + } + throw new RuntimeException("Not enough hash table slots to add a new element."); + } + + private void changeCapacity(int newCapacity) { + Element[] newElements = new Element[newCapacity]; + HeadElement newHead = new HeadElement(); + int oldSize = size; + for (Iterator iter = iterator(); iter.hasNext(); ) { + Element element = iter.next(); + iter.remove(); + int newSlot = addInternal(element, newElements); + addToListTail(newHead, newElements, newSlot); + } + this.elements = newElements; + this.head = newHead; + this.size = oldSize; + } + + /** + * Remove the first element e such that key.elementKeysAreEqual(e) + * and key.hashCode == e.hashCode. + * + * @param key The object to try to match. + * @return True if an element was removed; false otherwise. + */ + @Override + final public boolean remove(Object key) { + int slot = findElementToRemove(key); + if (slot == INVALID_INDEX) { + return false; + } + removeElementAtSlot(slot); + return true; + } + + int findElementToRemove(Object key) { + return findIndexOfEqualElement(key); + } + + /** + * Remove an element in a particular slot. + * + * @param slot The slot of the element to remove. + * + * @return True if an element was removed; false otherwise. 
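For a sense of the sizing policy in add() and calculateCapacity() above: the capacity is always odd (2 * expected + 1, with a floor of 5), and the table grows once it is roughly half full by feeding the current capacity back into calculateCapacity(). A small sketch that mirrors, but does not call, that arithmetic:

```java
public class CapacityGrowthSketch {
    // Mirrors calculateCapacity() above: odd capacity, never below 5.
    static int calculateCapacity(int expectedNumElements) {
        return Math.max(2 * expectedNumElements + 1, 5);
    }

    public static void main(String[] args) {
        int capacity = calculateCapacity(0);        // the first insertion into an empty table yields 5
        System.out.print(capacity);
        // add() resizes when (size + 1) >= capacity / 2, passing the current capacity
        // back into calculateCapacity(), i.e. newCapacity = 2 * oldCapacity + 1.
        for (int i = 0; i < 4; i++) {
            capacity = calculateCapacity(capacity);
            System.out.print(" -> " + capacity);    // 5 -> 11 -> 23 -> 47 -> 95
        }
        System.out.println();
    }
}
```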
+ */ + private boolean removeElementAtSlot(int slot) { + size--; + removeFromList(head, elements, slot); + slot = (slot + 1) % elements.length; + + // Find the next empty slot + int endSlot = slot; + for (int seen = 0; seen < elements.length; seen++) { + Element element = elements[endSlot]; + if (element == null) { + break; + } + endSlot = (endSlot + 1) % elements.length; + } + + // We must preserve the denseness invariant. The denseness invariant says that + // any element is either in the slot indicated by its hash code, or a slot which + // is not separated from that slot by any nulls. + // Reseat all elements in between the deleted element and the next empty slot. + while (slot != endSlot) { + reseat(slot); + slot = (slot + 1) % elements.length; + } + return true; + } + + private void reseat(int prevSlot) { + Element element = elements[prevSlot]; + int newSlot = slot(elements, element); + for (int seen = 0; seen < elements.length; seen++) { + Element e = elements[newSlot]; + if ((e == null) || (e == element)) { + break; + } + newSlot = (newSlot + 1) % elements.length; + } + if (newSlot == prevSlot) { + return; + } + Element prev = indexToElement(head, elements, element.prev()); + prev.setNext(newSlot); + Element next = indexToElement(head, elements, element.next()); + next.setPrev(newSlot); + elements[prevSlot] = null; + elements[newSlot] = element; + } + + /** + * Create a new ImplicitLinkedHashCollection. + */ + public ImplicitLinkedHashCollection() { + this(0); + } + + /** + * Create a new ImplicitLinkedHashCollection. + * + * @param expectedNumElements The number of elements we expect to have in this set. + * This is used to optimize by setting the capacity ahead + * of time rather than growing incrementally. + */ + public ImplicitLinkedHashCollection(int expectedNumElements) { + clear(expectedNumElements); + } + + /** + * Create a new ImplicitLinkedHashCollection. + * + * @param iter We will add all the elements accessible through this iterator + * to the set. + */ + public ImplicitLinkedHashCollection(Iterator iter) { + clear(0); + while (iter.hasNext()) { + mustAdd(iter.next()); + } + } + + /** + * Removes all of the elements from this set. + */ + @Override + final public void clear() { + clear(elements.length); + } + + /** + * Moves an element which is already in the collection so that it comes last + * in iteration order. + */ + final public void moveToEnd(E element) { + if (element.prev() == INVALID_INDEX || element.next() == INVALID_INDEX) { + throw new RuntimeException("Element " + element + " is not in the collection."); + } + Element prevElement = indexToElement(head, elements, element.prev()); + Element nextElement = indexToElement(head, elements, element.next()); + int slot = prevElement.next(); + prevElement.setNext(element.next()); + nextElement.setPrev(element.prev()); + addToListTail(head, elements, slot); + } + + /** + * Removes all of the elements from this set, and resets the set capacity + * based on the provided expected number of elements. + */ + final public void clear(int expectedNumElements) { + if (expectedNumElements == 0) { + // Optimize away object allocations for empty sets. + this.head = HeadElement.EMPTY; + this.elements = EMPTY_ELEMENTS; + this.size = 0; + } else { + this.head = new HeadElement(); + this.elements = new Element[calculateCapacity(expectedNumElements)]; + this.size = 0; + } + } + + /** + * Compares the specified object with this collection for equality. 
Two + * {@code ImplicitLinkedHashCollection} objects are equal if they contain the + * same elements (as determined by the element's {@code equals} method), and + * those elements were inserted in the same order. Because + * {@code ImplicitLinkedHashCollectionListIterator} iterates over the elements + * in insertion order, it is sufficient to call {@code valuesList.equals}. + * + * Note that {@link ImplicitLinkedHashMultiCollection} does not override + * {@code equals} and uses this method as well. This means that two + * {@code ImplicitLinkedHashMultiCollection} objects will be considered equal even + * if they each contain two elements A and B such that A.equals(B) but A != B and + * A and B have switched insertion positions between the two collections. This + * is an acceptable definition of equality, because the collections are still + * equal in terms of the order and value of each element. + * + * @param o object to be compared for equality with this collection + * @return true is the specified object is equal to this collection + */ + @Override + public boolean equals(Object o) { + if (o == this) + return true; + + if (!(o instanceof ImplicitLinkedHashCollection)) + return false; + + ImplicitLinkedHashCollection ilhs = (ImplicitLinkedHashCollection) o; + return this.valuesList().equals(ilhs.valuesList()); + } + + /** + * Returns the hash code value for this collection. Because + * {@code ImplicitLinkedHashCollection.equals} compares the {@code valuesList} + * of two {@code ImplicitLinkedHashCollection} objects to determine equality, + * this method uses the @{code valuesList} to compute the has code value as well. + * + * @return the hash code value for this collection + */ + @Override + public int hashCode() { + return this.valuesList().hashCode(); + } + + // Visible for testing + final int numSlots() { + return elements.length; + } + + /** + * Returns a {@link List} view of the elements contained in the collection, + * ordered by order of insertion into the collection. The list is backed by the + * collection, so changes to the collection are reflected in the list and + * vice-versa. The list supports element removal, which removes the corresponding + * element from the collection, but does not support the {@code add} or + * {@code set} operations. + * + * The list is implemented as a circular linked list, so all index-based + * operations, such as {@code List.get}, run in O(n) time. + * + * @return a list view of the elements contained in this collection + */ + public List valuesList() { + return new ImplicitLinkedHashCollectionListView(); + } + + /** + * Returns a {@link Set} view of the elements contained in the collection. The + * set is backed by the collection, so changes to the collection are reflected in + * the set, and vice versa. The set supports element removal and addition, which + * removes from or adds to the collection, respectively. 
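Putting the pieces together, a usage sketch of ImplicitLinkedHashCollection with the hypothetical TopicNameElement sketched earlier. The generic type argument is an assumption about the class's single element type parameter; the behaviour shown (rejection of key-equal duplicates, insertion-ordered views) follows from the add() and valuesList() code above.

```java
import org.apache.kafka.common.utils.ImplicitLinkedHashCollection;

public class ImplicitCollectionExample {
    public static void main(String[] args) {
        ImplicitLinkedHashCollection<TopicNameElement> topics =
                new ImplicitLinkedHashCollection<>(2);   // expected number of elements

        topics.add(new TopicNameElement("orders"));
        topics.add(new TopicNameElement("payments"));

        // A key-equal element is rejected: add() returns false.
        System.out.println(topics.add(new TopicNameElement("orders")));        // false

        System.out.println(topics.contains(new TopicNameElement("payments"))); // true

        // valuesList() is a live view in insertion order: orders, then payments.
        System.out.println(topics.valuesList().size());                        // 2
    }
}
```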
+ * + * @return a set view of the elements contained in this collection + */ + public Set valuesSet() { + return new ImplicitLinkedHashCollectionSetView(); + } + + public void sort(Comparator comparator) { + ArrayList array = new ArrayList<>(size); + Iterator iterator = iterator(); + while (iterator.hasNext()) { + E e = iterator.next(); + iterator.remove(); + array.add(e); + } + array.sort(comparator); + for (E e : array) { + add(e); + } + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/utils/ImplicitLinkedHashMultiCollection.java b/clients/src/main/java/org/apache/kafka/common/utils/ImplicitLinkedHashMultiCollection.java new file mode 100644 index 0000000..b95b7de --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/utils/ImplicitLinkedHashMultiCollection.java @@ -0,0 +1,142 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.common.utils; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.Iterator; +import java.util.List; + +/** + * A memory-efficient hash multiset which tracks the order of insertion of elements. + * See org.apache.kafka.common.utils.ImplicitLinkedHashCollection for implementation details. + * + * This class is a multi-set because it allows multiple elements to be inserted that + * have equivalent keys. + * + * We use reference equality when adding elements to the set. A new element A can + * be added if there is no existing element B such that A == B. If an element B + * exists such that A.elementKeysAreEqual(B), A will still be added. + * + * When deleting an element A from the set, we will try to delete the element B such + * that A == B. If no such element can be found, we will try to delete an element B + * such that A.elementKeysAreEqual(B). + * + * contains() and find() are unchanged from the base class-- they will look for element + * based on object equality via elementKeysAreEqual, not reference equality. + * + * This multiset does not allow null elements. It does not have internal synchronization. + */ +public class ImplicitLinkedHashMultiCollection + extends ImplicitLinkedHashCollection { + public ImplicitLinkedHashMultiCollection() { + super(0); + } + + public ImplicitLinkedHashMultiCollection(int expectedNumElements) { + super(expectedNumElements); + } + + public ImplicitLinkedHashMultiCollection(Iterator iter) { + super(iter); + } + + + /** + * Adds a new element to the appropriate place in the elements array. + * + * @param newElement The new element to add. + * @param addElements The elements array. + * @return The index at which the element was inserted, or INVALID_INDEX + * if the element could not be inserted. 
+ */ + @Override + int addInternal(Element newElement, Element[] addElements) { + int slot = slot(addElements, newElement); + for (int seen = 0; seen < addElements.length; seen++) { + Element element = addElements[slot]; + if (element == null) { + addElements[slot] = newElement; + return slot; + } + if (element == newElement) { + return INVALID_INDEX; + } + slot = (slot + 1) % addElements.length; + } + throw new RuntimeException("Not enough hash table slots to add a new element."); + } + + /** + * Find an element matching an example element. + * + * @param key The element to match. + * + * @return The match index, or INVALID_INDEX if no match was found. + */ + @Override + int findElementToRemove(Object key) { + if (key == null || size() == 0) { + return INVALID_INDEX; + } + int slot = slot(elements, key); + int bestSlot = INVALID_INDEX; + for (int seen = 0; seen < elements.length; seen++) { + Element element = elements[slot]; + if (element == null) { + return bestSlot; + } + if (key == element) { + return slot; + } else if (element.elementKeysAreEqual(key)) { + bestSlot = slot; + } + slot = (slot + 1) % elements.length; + } + return INVALID_INDEX; + } + + /** + * Returns all of the elements e in the collection such that + * key.elementKeysAreEqual(e) and key.hashCode() == e.hashCode(). + * + * @param key The element to match. + * + * @return All of the matching elements. + */ + final public List findAll(E key) { + if (key == null || size() == 0) { + return Collections.emptyList(); + } + ArrayList results = new ArrayList<>(); + int slot = slot(elements, key); + for (int seen = 0; seen < elements.length; seen++) { + Element element = elements[slot]; + if (element == null) { + break; + } + if (key.elementKeysAreEqual(element)) { + @SuppressWarnings("unchecked") + E result = (E) elements[slot]; + results.add(result); + } + slot = (slot + 1) % elements.length; + } + return results; + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/utils/Java.java b/clients/src/main/java/org/apache/kafka/common/utils/Java.java new file mode 100644 index 0000000..c0c0a89 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/utils/Java.java @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
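To make the multiset semantics of ImplicitLinkedHashMultiCollection above concrete, a sketch that again uses the hypothetical TopicNameElement: insertion is guarded by reference equality, so two distinct objects with equal keys can coexist, while re-adding the very same object fails; findAll() then returns every key-equal match.

```java
import org.apache.kafka.common.utils.ImplicitLinkedHashMultiCollection;

public class MultiCollectionExample {
    public static void main(String[] args) {
        ImplicitLinkedHashMultiCollection<TopicNameElement> topics =
                new ImplicitLinkedHashMultiCollection<>();

        TopicNameElement first = new TopicNameElement("orders");
        TopicNameElement second = new TopicNameElement("orders"); // key-equal, but a distinct object

        System.out.println(topics.add(first));   // true
        System.out.println(topics.add(second));  // true: key-equal duplicates are allowed
        System.out.println(topics.add(first));   // false: the same object (==) cannot be added twice

        // findAll() returns every element whose key matches the example element.
        System.out.println(topics.findAll(new TopicNameElement("orders")).size()); // 2
    }
}
```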
+ */ +package org.apache.kafka.common.utils; + +import java.util.StringTokenizer; + +public final class Java { + + private Java() { } + + private static final Version VERSION = parseVersion(System.getProperty("java.specification.version")); + + // Package private for testing + static Version parseVersion(String versionString) { + final StringTokenizer st = new StringTokenizer(versionString, "."); + int majorVersion = Integer.parseInt(st.nextToken()); + int minorVersion; + if (st.hasMoreTokens()) + minorVersion = Integer.parseInt(st.nextToken()); + else + minorVersion = 0; + return new Version(majorVersion, minorVersion); + } + + // Having these as static final provides the best opportunity for compilar optimization + public static final boolean IS_JAVA9_COMPATIBLE = VERSION.isJava9Compatible(); + public static final boolean IS_JAVA11_COMPATIBLE = VERSION.isJava11Compatible(); + + public static boolean isIbmJdk() { + return System.getProperty("java.vendor").contains("IBM"); + } + + // Package private for testing + static class Version { + public final int majorVersion; + public final int minorVersion; + + private Version(int majorVersion, int minorVersion) { + this.majorVersion = majorVersion; + this.minorVersion = minorVersion; + } + + @Override + public String toString() { + return "Version(majorVersion=" + majorVersion + + ", minorVersion=" + minorVersion + ")"; + } + + // Package private for testing + boolean isJava9Compatible() { + return majorVersion >= 9; + } + + boolean isJava11Compatible() { + return majorVersion >= 11; + } + + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/utils/KafkaThread.java b/clients/src/main/java/org/apache/kafka/common/utils/KafkaThread.java new file mode 100644 index 0000000..13430e4 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/utils/KafkaThread.java @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
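For the Java version helper above, a small sketch of how callers branch on the precomputed flags. Only the public members are used; the comment about specification version strings reflects how the JDK reports them ("1.8" on Java 8, plain "9", "11", "17" and so on afterwards), which is why the check reduces to a simple integer comparison on the major version.

```java
import org.apache.kafka.common.utils.Java;

public class JavaVersionExample {
    public static void main(String[] args) {
        // The flags are computed once from "java.specification.version" at class load time.
        if (Java.IS_JAVA9_COMPATIBLE) {
            System.out.println("Running on Java 9 or newer (java.util.zip.CRC32C is available)");
        } else {
            System.out.println("Running on Java 8 or older (pure-Java CRC32C fallback)");
        }
        System.out.println("IBM JDK: " + Java.isIbmJdk());
    }
}
```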
+ */ +package org.apache.kafka.common.utils; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * A wrapper for Thread that sets things up nicely + */ +public class KafkaThread extends Thread { + + private final Logger log = LoggerFactory.getLogger(getClass()); + + public static KafkaThread daemon(final String name, Runnable runnable) { + return new KafkaThread(name, runnable, true); + } + + public static KafkaThread nonDaemon(final String name, Runnable runnable) { + return new KafkaThread(name, runnable, false); + } + + public KafkaThread(final String name, boolean daemon) { + super(name); + configureThread(name, daemon); + } + + public KafkaThread(final String name, Runnable runnable, boolean daemon) { + super(runnable, name); + configureThread(name, daemon); + } + + private void configureThread(final String name, boolean daemon) { + setDaemon(daemon); + setUncaughtExceptionHandler((t, e) -> log.error("Uncaught exception in thread '{}':", name, e)); + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/utils/LogContext.java b/clients/src/main/java/org/apache/kafka/common/utils/LogContext.java new file mode 100644 index 0000000..774961f --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/utils/LogContext.java @@ -0,0 +1,793 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.utils; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.slf4j.Marker; +import org.slf4j.helpers.FormattingTuple; +import org.slf4j.helpers.MessageFormatter; +import org.slf4j.spi.LocationAwareLogger; + +/** + * This class provides a way to instrument loggers with a common context which can be used to + * automatically enrich log messages. For example, in the KafkaConsumer, it is often useful to know + * the groupId of the consumer, so this can be added to a context object which can then be passed to + * all of the dependent components in order to build new loggers. This removes the need to manually + * add the groupId to each message. + */ +public class LogContext { + + private final String logPrefix; + + public LogContext(String logPrefix) { + this.logPrefix = logPrefix == null ? 
"" : logPrefix; + } + + public LogContext() { + this(""); + } + + public Logger logger(Class clazz) { + Logger logger = LoggerFactory.getLogger(clazz); + if (logger instanceof LocationAwareLogger) { + return new LocationAwareKafkaLogger(logPrefix, (LocationAwareLogger) logger); + } else { + return new LocationIgnorantKafkaLogger(logPrefix, logger); + } + } + + public String logPrefix() { + return logPrefix; + } + + private static abstract class AbstractKafkaLogger implements Logger { + private final String prefix; + + protected AbstractKafkaLogger(final String prefix) { + this.prefix = prefix; + } + + protected String addPrefix(final String message) { + return prefix + message; + } + } + + private static class LocationAwareKafkaLogger extends AbstractKafkaLogger { + private final LocationAwareLogger logger; + private final String fqcn; + + LocationAwareKafkaLogger(String logPrefix, LocationAwareLogger logger) { + super(logPrefix); + this.logger = logger; + this.fqcn = LocationAwareKafkaLogger.class.getName(); + } + + @Override + public String getName() { + return logger.getName(); + } + + @Override + public boolean isTraceEnabled() { + return logger.isTraceEnabled(); + } + + @Override + public boolean isTraceEnabled(Marker marker) { + return logger.isTraceEnabled(marker); + } + + @Override + public boolean isDebugEnabled() { + return logger.isDebugEnabled(); + } + + @Override + public boolean isDebugEnabled(Marker marker) { + return logger.isDebugEnabled(marker); + } + + @Override + public boolean isInfoEnabled() { + return logger.isInfoEnabled(); + } + + @Override + public boolean isInfoEnabled(Marker marker) { + return logger.isInfoEnabled(marker); + } + + @Override + public boolean isWarnEnabled() { + return logger.isWarnEnabled(); + } + + @Override + public boolean isWarnEnabled(Marker marker) { + return logger.isWarnEnabled(marker); + } + + @Override + public boolean isErrorEnabled() { + return logger.isErrorEnabled(); + } + + @Override + public boolean isErrorEnabled(Marker marker) { + return logger.isErrorEnabled(marker); + } + + @Override + public void trace(String message) { + if (logger.isTraceEnabled()) { + writeLog(null, LocationAwareLogger.TRACE_INT, message, null, null); + } + } + + @Override + public void trace(String format, Object arg) { + if (logger.isTraceEnabled()) { + writeLog(null, LocationAwareLogger.TRACE_INT, format, new Object[]{arg}, null); + } + } + + @Override + public void trace(String format, Object arg1, Object arg2) { + if (logger.isTraceEnabled()) { + writeLog(null, LocationAwareLogger.TRACE_INT, format, new Object[]{arg1, arg2}, null); + } + } + + @Override + public void trace(String format, Object... 
args) { + if (logger.isTraceEnabled()) { + writeLog(null, LocationAwareLogger.TRACE_INT, format, args, null); + } + } + + @Override + public void trace(String msg, Throwable t) { + if (logger.isTraceEnabled()) { + writeLog(null, LocationAwareLogger.TRACE_INT, msg, null, t); + } + } + + @Override + public void trace(Marker marker, String msg) { + if (logger.isTraceEnabled()) { + writeLog(marker, LocationAwareLogger.TRACE_INT, msg, null, null); + } + } + + @Override + public void trace(Marker marker, String format, Object arg) { + if (logger.isTraceEnabled()) { + writeLog(marker, LocationAwareLogger.TRACE_INT, format, new Object[]{arg}, null); + } + } + + @Override + public void trace(Marker marker, String format, Object arg1, Object arg2) { + if (logger.isTraceEnabled()) { + writeLog(marker, LocationAwareLogger.TRACE_INT, format, new Object[]{arg1, arg2}, null); + } + } + + @Override + public void trace(Marker marker, String format, Object... argArray) { + if (logger.isTraceEnabled()) { + writeLog(marker, LocationAwareLogger.TRACE_INT, format, argArray, null); + } + } + + @Override + public void trace(Marker marker, String msg, Throwable t) { + if (logger.isTraceEnabled()) { + writeLog(marker, LocationAwareLogger.TRACE_INT, msg, null, t); + } + } + + @Override + public void debug(String message) { + if (logger.isDebugEnabled()) { + writeLog(null, LocationAwareLogger.DEBUG_INT, message, null, null); + } + } + + @Override + public void debug(String format, Object arg) { + if (logger.isDebugEnabled()) { + writeLog(null, LocationAwareLogger.DEBUG_INT, format, new Object[]{arg}, null); + } + } + + @Override + public void debug(String format, Object arg1, Object arg2) { + if (logger.isDebugEnabled()) { + writeLog(null, LocationAwareLogger.DEBUG_INT, format, new Object[]{arg1, arg2}, null); + } + } + + @Override + public void debug(String format, Object... args) { + if (logger.isDebugEnabled()) { + writeLog(null, LocationAwareLogger.DEBUG_INT, format, args, null); + } + } + + @Override + public void debug(String msg, Throwable t) { + if (logger.isDebugEnabled()) { + writeLog(null, LocationAwareLogger.DEBUG_INT, msg, null, t); + } + } + + @Override + public void debug(Marker marker, String msg) { + if (logger.isDebugEnabled()) { + writeLog(marker, LocationAwareLogger.DEBUG_INT, msg, null, null); + } + } + + @Override + public void debug(Marker marker, String format, Object arg) { + if (logger.isDebugEnabled()) { + writeLog(marker, LocationAwareLogger.DEBUG_INT, format, new Object[]{arg}, null); + } + } + + @Override + public void debug(Marker marker, String format, Object arg1, Object arg2) { + if (logger.isDebugEnabled()) { + writeLog(marker, LocationAwareLogger.DEBUG_INT, format, new Object[]{arg1, arg2}, null); + } + } + + @Override + public void debug(Marker marker, String format, Object... 
arguments) { + if (logger.isDebugEnabled()) { + writeLog(marker, LocationAwareLogger.DEBUG_INT, format, arguments, null); + } + } + + @Override + public void debug(Marker marker, String msg, Throwable t) { + if (logger.isDebugEnabled()) { + writeLog(marker, LocationAwareLogger.DEBUG_INT, msg, null, t); + } + } + + @Override + public void warn(String message) { + writeLog(null, LocationAwareLogger.WARN_INT, message, null, null); + } + + @Override + public void warn(String format, Object arg) { + writeLog(null, LocationAwareLogger.WARN_INT, format, new Object[]{arg}, null); + } + + @Override + public void warn(String message, Object arg1, Object arg2) { + writeLog(null, LocationAwareLogger.WARN_INT, message, new Object[]{arg1, arg2}, null); + } + + @Override + public void warn(String format, Object... args) { + writeLog(null, LocationAwareLogger.WARN_INT, format, args, null); + } + + @Override + public void warn(String msg, Throwable t) { + writeLog(null, LocationAwareLogger.WARN_INT, msg, null, t); + } + + @Override + public void warn(Marker marker, String msg) { + writeLog(marker, LocationAwareLogger.WARN_INT, msg, null, null); + } + + @Override + public void warn(Marker marker, String format, Object arg) { + writeLog(marker, LocationAwareLogger.WARN_INT, format, new Object[]{arg}, null); + } + + @Override + public void warn(Marker marker, String format, Object arg1, Object arg2) { + writeLog(marker, LocationAwareLogger.WARN_INT, format, new Object[]{arg1, arg2}, null); + } + + @Override + public void warn(Marker marker, String format, Object... arguments) { + writeLog(marker, LocationAwareLogger.WARN_INT, format, arguments, null); + } + + @Override + public void warn(Marker marker, String msg, Throwable t) { + writeLog(marker, LocationAwareLogger.WARN_INT, msg, null, t); + } + + @Override + public void error(String message) { + writeLog(null, LocationAwareLogger.ERROR_INT, message, null, null); + } + + @Override + public void error(String format, Object arg) { + writeLog(null, LocationAwareLogger.ERROR_INT, format, new Object[]{arg}, null); + } + + @Override + public void error(String format, Object arg1, Object arg2) { + writeLog(null, LocationAwareLogger.ERROR_INT, format, new Object[]{arg1, arg2}, null); + } + + @Override + public void error(String format, Object... args) { + writeLog(null, LocationAwareLogger.ERROR_INT, format, args, null); + } + + @Override + public void error(String msg, Throwable t) { + writeLog(null, LocationAwareLogger.ERROR_INT, msg, null, t); + } + + @Override + public void error(Marker marker, String msg) { + writeLog(marker, LocationAwareLogger.ERROR_INT, msg, null, null); + } + + @Override + public void error(Marker marker, String format, Object arg) { + writeLog(marker, LocationAwareLogger.ERROR_INT, format, new Object[]{arg}, null); + } + + @Override + public void error(Marker marker, String format, Object arg1, Object arg2) { + writeLog(marker, LocationAwareLogger.ERROR_INT, format, new Object[]{arg1, arg2}, null); + } + + @Override + public void error(Marker marker, String format, Object... 
arguments) { + writeLog(marker, LocationAwareLogger.ERROR_INT, format, arguments, null); + } + + @Override + public void error(Marker marker, String msg, Throwable t) { + writeLog(marker, LocationAwareLogger.ERROR_INT, msg, null, t); + } + + @Override + public void info(String msg) { + writeLog(null, LocationAwareLogger.INFO_INT, msg, null, null); + } + + @Override + public void info(String format, Object arg) { + writeLog(null, LocationAwareLogger.INFO_INT, format, new Object[]{arg}, null); + } + + @Override + public void info(String format, Object arg1, Object arg2) { + writeLog(null, LocationAwareLogger.INFO_INT, format, new Object[]{arg1, arg2}, null); + } + + @Override + public void info(String format, Object... args) { + writeLog(null, LocationAwareLogger.INFO_INT, format, args, null); + } + + @Override + public void info(String msg, Throwable t) { + writeLog(null, LocationAwareLogger.INFO_INT, msg, null, t); + } + + @Override + public void info(Marker marker, String msg) { + writeLog(marker, LocationAwareLogger.INFO_INT, msg, null, null); + } + + @Override + public void info(Marker marker, String format, Object arg) { + writeLog(marker, LocationAwareLogger.INFO_INT, format, new Object[]{arg}, null); + } + + @Override + public void info(Marker marker, String format, Object arg1, Object arg2) { + writeLog(marker, LocationAwareLogger.INFO_INT, format, new Object[]{arg1, arg2}, null); + } + + @Override + public void info(Marker marker, String format, Object... arguments) { + writeLog(marker, LocationAwareLogger.INFO_INT, format, arguments, null); + } + + @Override + public void info(Marker marker, String msg, Throwable t) { + writeLog(marker, LocationAwareLogger.INFO_INT, msg, null, t); + } + + private void writeLog(Marker marker, int level, String format, Object[] args, Throwable exception) { + String message = format; + if (args != null && args.length > 0) { + FormattingTuple formatted = MessageFormatter.arrayFormat(format, args); + if (exception == null && formatted.getThrowable() != null) { + exception = formatted.getThrowable(); + } + message = formatted.getMessage(); + } + logger.log(marker, fqcn, level, addPrefix(message), null, exception); + } + } + + private static class LocationIgnorantKafkaLogger extends AbstractKafkaLogger { + private final Logger logger; + + LocationIgnorantKafkaLogger(String logPrefix, Logger logger) { + super(logPrefix); + this.logger = logger; + } + + @Override + public String getName() { + return logger.getName(); + } + + @Override + public boolean isTraceEnabled() { + return logger.isTraceEnabled(); + } + + @Override + public boolean isTraceEnabled(Marker marker) { + return logger.isTraceEnabled(marker); + } + + @Override + public boolean isDebugEnabled() { + return logger.isDebugEnabled(); + } + + @Override + public boolean isDebugEnabled(Marker marker) { + return logger.isDebugEnabled(marker); + } + + @Override + public boolean isInfoEnabled() { + return logger.isInfoEnabled(); + } + + @Override + public boolean isInfoEnabled(Marker marker) { + return logger.isInfoEnabled(marker); + } + + @Override + public boolean isWarnEnabled() { + return logger.isWarnEnabled(); + } + + @Override + public boolean isWarnEnabled(Marker marker) { + return logger.isWarnEnabled(marker); + } + + @Override + public boolean isErrorEnabled() { + return logger.isErrorEnabled(); + } + + @Override + public boolean isErrorEnabled(Marker marker) { + return logger.isErrorEnabled(marker); + } + + @Override + public void trace(String message) { + if (logger.isTraceEnabled()) { + 
logger.trace(addPrefix(message)); + } + } + + @Override + public void trace(String message, Object arg) { + if (logger.isTraceEnabled()) { + logger.trace(addPrefix(message), arg); + } + } + + @Override + public void trace(String message, Object arg1, Object arg2) { + if (logger.isTraceEnabled()) { + logger.trace(addPrefix(message), arg1, arg2); + } + } + + @Override + public void trace(String message, Object... args) { + if (logger.isTraceEnabled()) { + logger.trace(addPrefix(message), args); + } + } + + @Override + public void trace(String msg, Throwable t) { + if (logger.isTraceEnabled()) { + logger.trace(addPrefix(msg), t); + } + } + + @Override + public void trace(Marker marker, String msg) { + if (logger.isTraceEnabled()) { + logger.trace(marker, addPrefix(msg)); + } + } + + @Override + public void trace(Marker marker, String format, Object arg) { + if (logger.isTraceEnabled()) { + logger.trace(marker, addPrefix(format), arg); + } + } + + @Override + public void trace(Marker marker, String format, Object arg1, Object arg2) { + if (logger.isTraceEnabled()) { + logger.trace(marker, addPrefix(format), arg1, arg2); + } + } + + @Override + public void trace(Marker marker, String format, Object... argArray) { + if (logger.isTraceEnabled()) { + logger.trace(marker, addPrefix(format), argArray); + } + } + + @Override + public void trace(Marker marker, String msg, Throwable t) { + if (logger.isTraceEnabled()) { + logger.trace(marker, addPrefix(msg), t); + } + } + + @Override + public void debug(String message) { + if (logger.isDebugEnabled()) { + logger.debug(addPrefix(message)); + } + } + + @Override + public void debug(String message, Object arg) { + if (logger.isDebugEnabled()) { + logger.debug(addPrefix(message), arg); + } + } + + @Override + public void debug(String message, Object arg1, Object arg2) { + if (logger.isDebugEnabled()) { + logger.debug(addPrefix(message), arg1, arg2); + } + } + + @Override + public void debug(String message, Object... args) { + if (logger.isDebugEnabled()) { + logger.debug(addPrefix(message), args); + } + } + + @Override + public void debug(String msg, Throwable t) { + if (logger.isDebugEnabled()) { + logger.debug(addPrefix(msg), t); + } + } + + @Override + public void debug(Marker marker, String msg) { + if (logger.isDebugEnabled()) { + logger.debug(marker, addPrefix(msg)); + } + } + + @Override + public void debug(Marker marker, String format, Object arg) { + if (logger.isDebugEnabled()) { + logger.debug(marker, addPrefix(format), arg); + } + } + + @Override + public void debug(Marker marker, String format, Object arg1, Object arg2) { + if (logger.isDebugEnabled()) { + logger.debug(marker, addPrefix(format), arg1, arg2); + } + } + + @Override + public void debug(Marker marker, String format, Object... arguments) { + if (logger.isDebugEnabled()) { + logger.debug(marker, addPrefix(format), arguments); + } + } + + @Override + public void debug(Marker marker, String msg, Throwable t) { + if (logger.isDebugEnabled()) { + logger.debug(marker, addPrefix(msg), t); + } + } + + @Override + public void warn(String message) { + logger.warn(addPrefix(message)); + } + + @Override + public void warn(String message, Object arg) { + logger.warn(addPrefix(message), arg); + } + + @Override + public void warn(String message, Object arg1, Object arg2) { + logger.warn(addPrefix(message), arg1, arg2); + } + + @Override + public void warn(String message, Object... 
args) { + logger.warn(addPrefix(message), args); + } + + @Override + public void warn(String msg, Throwable t) { + logger.warn(addPrefix(msg), t); + } + + @Override + public void warn(Marker marker, String msg) { + logger.warn(marker, addPrefix(msg)); + } + + @Override + public void warn(Marker marker, String format, Object arg) { + logger.warn(marker, addPrefix(format), arg); + } + + @Override + public void warn(Marker marker, String format, Object arg1, Object arg2) { + logger.warn(marker, addPrefix(format), arg1, arg2); + } + + @Override + public void warn(Marker marker, String format, Object... arguments) { + logger.warn(marker, addPrefix(format), arguments); + } + + @Override + public void warn(Marker marker, String msg, Throwable t) { + logger.warn(marker, addPrefix(msg), t); + } + + @Override + public void error(String message) { + logger.error(addPrefix(message)); + } + + @Override + public void error(String message, Object arg) { + logger.error(addPrefix(message), arg); + } + + @Override + public void error(String message, Object arg1, Object arg2) { + logger.error(addPrefix(message), arg1, arg2); + } + + @Override + public void error(String message, Object... args) { + logger.error(addPrefix(message), args); + } + + @Override + public void error(String msg, Throwable t) { + logger.error(addPrefix(msg), t); + } + + @Override + public void error(Marker marker, String msg) { + logger.error(marker, addPrefix(msg)); + } + + @Override + public void error(Marker marker, String format, Object arg) { + logger.error(marker, addPrefix(format), arg); + } + + @Override + public void error(Marker marker, String format, Object arg1, Object arg2) { + logger.error(marker, addPrefix(format), arg1, arg2); + } + + @Override + public void error(Marker marker, String format, Object... arguments) { + logger.error(marker, addPrefix(format), arguments); + } + + @Override + public void error(Marker marker, String msg, Throwable t) { + logger.error(marker, addPrefix(msg), t); + } + + @Override + public void info(String message) { + logger.info(addPrefix(message)); + } + + @Override + public void info(String message, Object arg) { + logger.info(addPrefix(message), arg); + } + + @Override + public void info(String message, Object arg1, Object arg2) { + logger.info(addPrefix(message), arg1, arg2); + } + + @Override + public void info(String message, Object... args) { + logger.info(addPrefix(message), args); + } + + @Override + public void info(String msg, Throwable t) { + logger.info(addPrefix(msg), t); + } + + @Override + public void info(Marker marker, String msg) { + logger.info(marker, addPrefix(msg)); + } + + @Override + public void info(Marker marker, String format, Object arg) { + logger.info(marker, addPrefix(format), arg); + } + + @Override + public void info(Marker marker, String format, Object arg1, Object arg2) { + logger.info(marker, addPrefix(format), arg1, arg2); + } + + @Override + public void info(Marker marker, String format, Object... 
arguments) { + logger.info(marker, addPrefix(format), arguments); + } + + @Override + public void info(Marker marker, String msg, Throwable t) { + logger.info(marker, addPrefix(msg), t); + } + + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/utils/LoggingSignalHandler.java b/clients/src/main/java/org/apache/kafka/common/utils/LoggingSignalHandler.java new file mode 100644 index 0000000..112d7fd --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/utils/LoggingSignalHandler.java @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.utils; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.lang.reflect.Constructor; +import java.lang.reflect.InvocationHandler; +import java.lang.reflect.Method; +import java.lang.reflect.Proxy; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +public class LoggingSignalHandler { + + private static final Logger log = LoggerFactory.getLogger(LoggingSignalHandler.class); + + private static final List SIGNALS = Arrays.asList("TERM", "INT", "HUP"); + + private final Constructor signalConstructor; + private final Class signalHandlerClass; + private final Class signalClass; + private final Method signalHandleMethod; + private final Method signalGetNameMethod; + private final Method signalHandlerHandleMethod; + + /** + * Create an instance of this class. + * + * @throws ReflectiveOperationException if the underlying API has changed in an incompatible manner. + */ + public LoggingSignalHandler() throws ReflectiveOperationException { + signalClass = Class.forName("sun.misc.Signal"); + signalConstructor = signalClass.getConstructor(String.class); + signalHandlerClass = Class.forName("sun.misc.SignalHandler"); + signalHandlerHandleMethod = signalHandlerClass.getMethod("handle", signalClass); + signalHandleMethod = signalClass.getMethod("handle", signalClass, signalHandlerClass); + signalGetNameMethod = signalClass.getMethod("getName"); + } + + /** + * Register signal handler to log termination due to SIGTERM, SIGHUP and SIGINT (control-c). This method + * does not currently work on Windows. + * + * @implNote sun.misc.Signal and sun.misc.SignalHandler are described as "not encapsulated" in + * http://openjdk.java.net/jeps/260. However, they are not available in the compile classpath if the `--release` + * flag is used. As a workaround, we rely on reflection. 
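+ *
+ * <p>For illustration, a typical call site might register the handler once during process start-up
+ * and treat failures as non-fatal, since the reflective lookup can fail on JVMs that do not ship
+ * {@code sun.misc.Signal} (the {@code log} field below is assumed to be an slf4j logger):
+ * <pre>{@code
+ * try {
+ *     new LoggingSignalHandler().register();
+ * } catch (ReflectiveOperationException e) {
+ *     log.warn("Unable to register optional signal handler for termination logging", e);
+ * }
+ * }</pre>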
+ */ + public void register() throws ReflectiveOperationException { + Map jvmSignalHandlers = new ConcurrentHashMap<>(); + + for (String signal : SIGNALS) { + register(signal, jvmSignalHandlers); + } + log.info("Registered signal handlers for " + String.join(", ", SIGNALS)); + } + + private Object createSignalHandler(final Map jvmSignalHandlers) { + InvocationHandler invocationHandler = new InvocationHandler() { + + private String getName(Object signal) throws ReflectiveOperationException { + return (String) signalGetNameMethod.invoke(signal); + } + + private void handle(Object signalHandler, Object signal) throws ReflectiveOperationException { + signalHandlerHandleMethod.invoke(signalHandler, signal); + } + + @Override + public Object invoke(Object proxy, Method method, Object[] args) throws Throwable { + Object signal = args[0]; + log.info("Terminating process due to signal {}", signal); + Object handler = jvmSignalHandlers.get(getName(signal)); + if (handler != null) + handle(handler, signal); + return null; + } + }; + return Proxy.newProxyInstance(Utils.getContextOrKafkaClassLoader(), new Class[] {signalHandlerClass}, + invocationHandler); + } + + private void register(String signalName, final Map jvmSignalHandlers) throws ReflectiveOperationException { + Object signal = signalConstructor.newInstance(signalName); + Object signalHandler = createSignalHandler(jvmSignalHandlers); + Object oldHandler = signalHandleMethod.invoke(null, signal, signalHandler); + if (oldHandler != null) + jvmSignalHandlers.put(signalName, oldHandler); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/utils/MappedIterator.java b/clients/src/main/java/org/apache/kafka/common/utils/MappedIterator.java new file mode 100644 index 0000000..f6eb270 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/utils/MappedIterator.java @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.utils; + +import java.util.Iterator; +import java.util.function.Function; + +/** + * An iterator that maps another iterator's elements from type `F` to type `T`. 
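+ *
+ * <p>For illustration, a minimal usage sketch (the {@code names} list here is hypothetical):
+ * <pre>{@code
+ * List<String> names = Arrays.asList("alice", "bob");
+ * Iterator<Integer> lengths = new MappedIterator<>(names.iterator(), String::length);
+ * // lengths lazily yields 5, then 3, without copying the underlying collection
+ * }</pre>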
+ */ +public final class MappedIterator<F, T> implements Iterator<T> { + private final Iterator<? extends F> underlyingIterator; + private final Function<F, T> mapper; + + public MappedIterator(Iterator<? extends F> underlyingIterator, Function<F, T> mapper) { + this.underlyingIterator = underlyingIterator; + this.mapper = mapper; + } + + @Override + public final boolean hasNext() { + return underlyingIterator.hasNext(); + } + + @Override + public final T next() { + return mapper.apply(underlyingIterator.next()); + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/utils/OperatingSystem.java b/clients/src/main/java/org/apache/kafka/common/utils/OperatingSystem.java new file mode 100644 index 0000000..8dc8b86 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/utils/OperatingSystem.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.utils; + +import java.util.Locale; + +public final class OperatingSystem { + + private OperatingSystem() { + } + + public static final String NAME; + + public static final boolean IS_WINDOWS; + + public static final boolean IS_ZOS; + + static { + NAME = System.getProperty("os.name").toLowerCase(Locale.ROOT); + IS_WINDOWS = NAME.startsWith("windows"); + IS_ZOS = NAME.startsWith("z/os"); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/utils/PrimitiveRef.java b/clients/src/main/java/org/apache/kafka/common/utils/PrimitiveRef.java new file mode 100644 index 0000000..e1bbfe3 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/utils/PrimitiveRef.java @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.utils; + +/** + * Primitive reference used to pass primitive typed values as parameter-by-reference. + * + * This is cheaper than using Atomic references. 
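+ *
+ * <p>For illustration, a minimal usage sketch ({@code readVarint} is a hypothetical helper that
+ * advances {@code position.value} as it reads from the buffer):
+ * <pre>{@code
+ * PrimitiveRef.IntRef position = PrimitiveRef.ofInt(0);
+ * int first = readVarint(buffer, position);  // position.value now points past the first varint
+ * int second = readVarint(buffer, position); // the second call continues where the first stopped
+ * }</pre>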
+ */ +public class PrimitiveRef { + public static IntRef ofInt(int value) { + return new IntRef(value); + } + + public static class IntRef { + public int value; + + IntRef(int value) { + this.value = value; + } + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/utils/ProducerIdAndEpoch.java b/clients/src/main/java/org/apache/kafka/common/utils/ProducerIdAndEpoch.java new file mode 100644 index 0000000..674b423 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/utils/ProducerIdAndEpoch.java @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.utils; + +import org.apache.kafka.common.record.RecordBatch; + +public class ProducerIdAndEpoch { + public static final ProducerIdAndEpoch NONE = new ProducerIdAndEpoch(RecordBatch.NO_PRODUCER_ID, RecordBatch.NO_PRODUCER_EPOCH); + + public final long producerId; + public final short epoch; + + public ProducerIdAndEpoch(long producerId, short epoch) { + this.producerId = producerId; + this.epoch = epoch; + } + + public boolean isValid() { + return RecordBatch.NO_PRODUCER_ID < producerId; + } + + @Override + public String toString() { + return "(producerId=" + producerId + ", epoch=" + epoch + ")"; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + ProducerIdAndEpoch that = (ProducerIdAndEpoch) o; + + if (producerId != that.producerId) return false; + return epoch == that.epoch; + } + + @Override + public int hashCode() { + int result = (int) (producerId ^ (producerId >>> 32)); + result = 31 * result + (int) epoch; + return result; + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/utils/PureJavaCrc32C.java b/clients/src/main/java/org/apache/kafka/common/utils/PureJavaCrc32C.java new file mode 100644 index 0000000..8abc93d --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/utils/PureJavaCrc32C.java @@ -0,0 +1,645 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +/* + * Some portions of this file Copyright (c) 2004-2006 Intel Corporation and + * licensed under the BSD license. + */ +package org.apache.kafka.common.utils; + +import java.util.zip.Checksum; + +/** + * This class was taken from Hadoop: org.apache.hadoop.util.PureJavaCrc32C. + * + * A pure-java implementation of the CRC32 checksum that uses + * the CRC32-C polynomial, the same polynomial used by iSCSI + * and implemented on many Intel chipsets supporting SSE4.2. + * + * NOTE: This class is intended for INTERNAL usage only within Kafka. + */ +// The exact version that was retrieved from Hadoop: +// https://github.com/apache/hadoop/blob/224de4f92c222a7b915e9c5d6bdd1a4a3fcbcf31/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/util/PureJavaCrc32C.java +public class PureJavaCrc32C implements Checksum { + + /** the current CRC value, bit-flipped */ + private int crc; + + public PureJavaCrc32C() { + reset(); + } + + @Override + public long getValue() { + long ret = crc; + return (~ret) & 0xffffffffL; + } + + @Override + public void reset() { + crc = 0xffffffff; + } + + @SuppressWarnings("fallthrough") + @Override + public void update(byte[] b, int off, int len) { + int localCrc = crc; + + while (len > 7) { + final int c0 = (b[off + 0] ^ localCrc) & 0xff; + final int c1 = (b[off + 1] ^ (localCrc >>>= 8)) & 0xff; + final int c2 = (b[off + 2] ^ (localCrc >>>= 8)) & 0xff; + final int c3 = (b[off + 3] ^ (localCrc >>>= 8)) & 0xff; + localCrc = (T[T8_7_START + c0] ^ T[T8_6_START + c1]) + ^ (T[T8_5_START + c2] ^ T[T8_4_START + c3]); + + final int c4 = b[off + 4] & 0xff; + final int c5 = b[off + 5] & 0xff; + final int c6 = b[off + 6] & 0xff; + final int c7 = b[off + 7] & 0xff; + + localCrc ^= (T[T8_3_START + c4] ^ T[T8_2_START + c5]) + ^ (T[T8_1_START + c6] ^ T[T8_0_START + c7]); + + off += 8; + len -= 8; + } + + /* loop unroll - duff's device style */ + switch (len) { + case 7: + localCrc = (localCrc >>> 8) ^ T[T8_0_START + ((localCrc ^ b[off++]) & 0xff)]; + case 6: + localCrc = (localCrc >>> 8) ^ T[T8_0_START + ((localCrc ^ b[off++]) & 0xff)]; + case 5: + localCrc = (localCrc >>> 8) ^ T[T8_0_START + ((localCrc ^ b[off++]) & 0xff)]; + case 4: + localCrc = (localCrc >>> 8) ^ T[T8_0_START + ((localCrc ^ b[off++]) & 0xff)]; + case 3: + localCrc = (localCrc >>> 8) ^ T[T8_0_START + ((localCrc ^ b[off++]) & 0xff)]; + case 2: + localCrc = (localCrc >>> 8) ^ T[T8_0_START + ((localCrc ^ b[off++]) & 0xff)]; + case 1: + localCrc = (localCrc >>> 8) ^ T[T8_0_START + ((localCrc ^ b[off++]) & 0xff)]; + default: + /* nothing */ + } + + // Publish crc out to object + crc = localCrc; + } + + @Override + final public void update(int b) { + crc = (crc >>> 8) ^ T[T8_0_START + ((crc ^ b) & 0xff)]; + } + + // CRC polynomial tables generated by: + // java -cp build/test/classes/:build/classes/ \ + // org.apache.hadoop.util.TestPureJavaCrc32\$Table 82F63B78 + + private static final int T8_0_START = 0 * 256; + private static final int T8_1_START = 1 * 256; + private static final int T8_2_START = 2 * 256; + private static final int T8_3_START = 3 * 256; + private static final int T8_4_START = 4 * 256; + private static final int T8_5_START = 5 * 256; + private static final int T8_6_START = 6 * 256; + private static final int T8_7_START = 7 * 256; + + private static final int[] T = new int[]{ + /* T8_0 */ + 0x00000000, 0xF26B8303, 0xE13B70F7, 0x1350F3F4, + 0xC79A971F, 0x35F1141C, 0x26A1E7E8, 
0xD4CA64EB, + 0x8AD958CF, 0x78B2DBCC, 0x6BE22838, 0x9989AB3B, + 0x4D43CFD0, 0xBF284CD3, 0xAC78BF27, 0x5E133C24, + 0x105EC76F, 0xE235446C, 0xF165B798, 0x030E349B, + 0xD7C45070, 0x25AFD373, 0x36FF2087, 0xC494A384, + 0x9A879FA0, 0x68EC1CA3, 0x7BBCEF57, 0x89D76C54, + 0x5D1D08BF, 0xAF768BBC, 0xBC267848, 0x4E4DFB4B, + 0x20BD8EDE, 0xD2D60DDD, 0xC186FE29, 0x33ED7D2A, + 0xE72719C1, 0x154C9AC2, 0x061C6936, 0xF477EA35, + 0xAA64D611, 0x580F5512, 0x4B5FA6E6, 0xB93425E5, + 0x6DFE410E, 0x9F95C20D, 0x8CC531F9, 0x7EAEB2FA, + 0x30E349B1, 0xC288CAB2, 0xD1D83946, 0x23B3BA45, + 0xF779DEAE, 0x05125DAD, 0x1642AE59, 0xE4292D5A, + 0xBA3A117E, 0x4851927D, 0x5B016189, 0xA96AE28A, + 0x7DA08661, 0x8FCB0562, 0x9C9BF696, 0x6EF07595, + 0x417B1DBC, 0xB3109EBF, 0xA0406D4B, 0x522BEE48, + 0x86E18AA3, 0x748A09A0, 0x67DAFA54, 0x95B17957, + 0xCBA24573, 0x39C9C670, 0x2A993584, 0xD8F2B687, + 0x0C38D26C, 0xFE53516F, 0xED03A29B, 0x1F682198, + 0x5125DAD3, 0xA34E59D0, 0xB01EAA24, 0x42752927, + 0x96BF4DCC, 0x64D4CECF, 0x77843D3B, 0x85EFBE38, + 0xDBFC821C, 0x2997011F, 0x3AC7F2EB, 0xC8AC71E8, + 0x1C661503, 0xEE0D9600, 0xFD5D65F4, 0x0F36E6F7, + 0x61C69362, 0x93AD1061, 0x80FDE395, 0x72966096, + 0xA65C047D, 0x5437877E, 0x4767748A, 0xB50CF789, + 0xEB1FCBAD, 0x197448AE, 0x0A24BB5A, 0xF84F3859, + 0x2C855CB2, 0xDEEEDFB1, 0xCDBE2C45, 0x3FD5AF46, + 0x7198540D, 0x83F3D70E, 0x90A324FA, 0x62C8A7F9, + 0xB602C312, 0x44694011, 0x5739B3E5, 0xA55230E6, + 0xFB410CC2, 0x092A8FC1, 0x1A7A7C35, 0xE811FF36, + 0x3CDB9BDD, 0xCEB018DE, 0xDDE0EB2A, 0x2F8B6829, + 0x82F63B78, 0x709DB87B, 0x63CD4B8F, 0x91A6C88C, + 0x456CAC67, 0xB7072F64, 0xA457DC90, 0x563C5F93, + 0x082F63B7, 0xFA44E0B4, 0xE9141340, 0x1B7F9043, + 0xCFB5F4A8, 0x3DDE77AB, 0x2E8E845F, 0xDCE5075C, + 0x92A8FC17, 0x60C37F14, 0x73938CE0, 0x81F80FE3, + 0x55326B08, 0xA759E80B, 0xB4091BFF, 0x466298FC, + 0x1871A4D8, 0xEA1A27DB, 0xF94AD42F, 0x0B21572C, + 0xDFEB33C7, 0x2D80B0C4, 0x3ED04330, 0xCCBBC033, + 0xA24BB5A6, 0x502036A5, 0x4370C551, 0xB11B4652, + 0x65D122B9, 0x97BAA1BA, 0x84EA524E, 0x7681D14D, + 0x2892ED69, 0xDAF96E6A, 0xC9A99D9E, 0x3BC21E9D, + 0xEF087A76, 0x1D63F975, 0x0E330A81, 0xFC588982, + 0xB21572C9, 0x407EF1CA, 0x532E023E, 0xA145813D, + 0x758FE5D6, 0x87E466D5, 0x94B49521, 0x66DF1622, + 0x38CC2A06, 0xCAA7A905, 0xD9F75AF1, 0x2B9CD9F2, + 0xFF56BD19, 0x0D3D3E1A, 0x1E6DCDEE, 0xEC064EED, + 0xC38D26C4, 0x31E6A5C7, 0x22B65633, 0xD0DDD530, + 0x0417B1DB, 0xF67C32D8, 0xE52CC12C, 0x1747422F, + 0x49547E0B, 0xBB3FFD08, 0xA86F0EFC, 0x5A048DFF, + 0x8ECEE914, 0x7CA56A17, 0x6FF599E3, 0x9D9E1AE0, + 0xD3D3E1AB, 0x21B862A8, 0x32E8915C, 0xC083125F, + 0x144976B4, 0xE622F5B7, 0xF5720643, 0x07198540, + 0x590AB964, 0xAB613A67, 0xB831C993, 0x4A5A4A90, + 0x9E902E7B, 0x6CFBAD78, 0x7FAB5E8C, 0x8DC0DD8F, + 0xE330A81A, 0x115B2B19, 0x020BD8ED, 0xF0605BEE, + 0x24AA3F05, 0xD6C1BC06, 0xC5914FF2, 0x37FACCF1, + 0x69E9F0D5, 0x9B8273D6, 0x88D28022, 0x7AB90321, + 0xAE7367CA, 0x5C18E4C9, 0x4F48173D, 0xBD23943E, + 0xF36E6F75, 0x0105EC76, 0x12551F82, 0xE03E9C81, + 0x34F4F86A, 0xC69F7B69, 0xD5CF889D, 0x27A40B9E, + 0x79B737BA, 0x8BDCB4B9, 0x988C474D, 0x6AE7C44E, + 0xBE2DA0A5, 0x4C4623A6, 0x5F16D052, 0xAD7D5351, + /* T8_1 */ + 0x00000000, 0x13A29877, 0x274530EE, 0x34E7A899, + 0x4E8A61DC, 0x5D28F9AB, 0x69CF5132, 0x7A6DC945, + 0x9D14C3B8, 0x8EB65BCF, 0xBA51F356, 0xA9F36B21, + 0xD39EA264, 0xC03C3A13, 0xF4DB928A, 0xE7790AFD, + 0x3FC5F181, 0x2C6769F6, 0x1880C16F, 0x0B225918, + 0x714F905D, 0x62ED082A, 0x560AA0B3, 0x45A838C4, + 0xA2D13239, 0xB173AA4E, 0x859402D7, 0x96369AA0, + 0xEC5B53E5, 0xFFF9CB92, 0xCB1E630B, 0xD8BCFB7C, + 0x7F8BE302, 0x6C297B75, 
0x58CED3EC, 0x4B6C4B9B, + 0x310182DE, 0x22A31AA9, 0x1644B230, 0x05E62A47, + 0xE29F20BA, 0xF13DB8CD, 0xC5DA1054, 0xD6788823, + 0xAC154166, 0xBFB7D911, 0x8B507188, 0x98F2E9FF, + 0x404E1283, 0x53EC8AF4, 0x670B226D, 0x74A9BA1A, + 0x0EC4735F, 0x1D66EB28, 0x298143B1, 0x3A23DBC6, + 0xDD5AD13B, 0xCEF8494C, 0xFA1FE1D5, 0xE9BD79A2, + 0x93D0B0E7, 0x80722890, 0xB4958009, 0xA737187E, + 0xFF17C604, 0xECB55E73, 0xD852F6EA, 0xCBF06E9D, + 0xB19DA7D8, 0xA23F3FAF, 0x96D89736, 0x857A0F41, + 0x620305BC, 0x71A19DCB, 0x45463552, 0x56E4AD25, + 0x2C896460, 0x3F2BFC17, 0x0BCC548E, 0x186ECCF9, + 0xC0D23785, 0xD370AFF2, 0xE797076B, 0xF4359F1C, + 0x8E585659, 0x9DFACE2E, 0xA91D66B7, 0xBABFFEC0, + 0x5DC6F43D, 0x4E646C4A, 0x7A83C4D3, 0x69215CA4, + 0x134C95E1, 0x00EE0D96, 0x3409A50F, 0x27AB3D78, + 0x809C2506, 0x933EBD71, 0xA7D915E8, 0xB47B8D9F, + 0xCE1644DA, 0xDDB4DCAD, 0xE9537434, 0xFAF1EC43, + 0x1D88E6BE, 0x0E2A7EC9, 0x3ACDD650, 0x296F4E27, + 0x53028762, 0x40A01F15, 0x7447B78C, 0x67E52FFB, + 0xBF59D487, 0xACFB4CF0, 0x981CE469, 0x8BBE7C1E, + 0xF1D3B55B, 0xE2712D2C, 0xD69685B5, 0xC5341DC2, + 0x224D173F, 0x31EF8F48, 0x050827D1, 0x16AABFA6, + 0x6CC776E3, 0x7F65EE94, 0x4B82460D, 0x5820DE7A, + 0xFBC3FAF9, 0xE861628E, 0xDC86CA17, 0xCF245260, + 0xB5499B25, 0xA6EB0352, 0x920CABCB, 0x81AE33BC, + 0x66D73941, 0x7575A136, 0x419209AF, 0x523091D8, + 0x285D589D, 0x3BFFC0EA, 0x0F186873, 0x1CBAF004, + 0xC4060B78, 0xD7A4930F, 0xE3433B96, 0xF0E1A3E1, + 0x8A8C6AA4, 0x992EF2D3, 0xADC95A4A, 0xBE6BC23D, + 0x5912C8C0, 0x4AB050B7, 0x7E57F82E, 0x6DF56059, + 0x1798A91C, 0x043A316B, 0x30DD99F2, 0x237F0185, + 0x844819FB, 0x97EA818C, 0xA30D2915, 0xB0AFB162, + 0xCAC27827, 0xD960E050, 0xED8748C9, 0xFE25D0BE, + 0x195CDA43, 0x0AFE4234, 0x3E19EAAD, 0x2DBB72DA, + 0x57D6BB9F, 0x447423E8, 0x70938B71, 0x63311306, + 0xBB8DE87A, 0xA82F700D, 0x9CC8D894, 0x8F6A40E3, + 0xF50789A6, 0xE6A511D1, 0xD242B948, 0xC1E0213F, + 0x26992BC2, 0x353BB3B5, 0x01DC1B2C, 0x127E835B, + 0x68134A1E, 0x7BB1D269, 0x4F567AF0, 0x5CF4E287, + 0x04D43CFD, 0x1776A48A, 0x23910C13, 0x30339464, + 0x4A5E5D21, 0x59FCC556, 0x6D1B6DCF, 0x7EB9F5B8, + 0x99C0FF45, 0x8A626732, 0xBE85CFAB, 0xAD2757DC, + 0xD74A9E99, 0xC4E806EE, 0xF00FAE77, 0xE3AD3600, + 0x3B11CD7C, 0x28B3550B, 0x1C54FD92, 0x0FF665E5, + 0x759BACA0, 0x663934D7, 0x52DE9C4E, 0x417C0439, + 0xA6050EC4, 0xB5A796B3, 0x81403E2A, 0x92E2A65D, + 0xE88F6F18, 0xFB2DF76F, 0xCFCA5FF6, 0xDC68C781, + 0x7B5FDFFF, 0x68FD4788, 0x5C1AEF11, 0x4FB87766, + 0x35D5BE23, 0x26772654, 0x12908ECD, 0x013216BA, + 0xE64B1C47, 0xF5E98430, 0xC10E2CA9, 0xD2ACB4DE, + 0xA8C17D9B, 0xBB63E5EC, 0x8F844D75, 0x9C26D502, + 0x449A2E7E, 0x5738B609, 0x63DF1E90, 0x707D86E7, + 0x0A104FA2, 0x19B2D7D5, 0x2D557F4C, 0x3EF7E73B, + 0xD98EEDC6, 0xCA2C75B1, 0xFECBDD28, 0xED69455F, + 0x97048C1A, 0x84A6146D, 0xB041BCF4, 0xA3E32483, + /* T8_2 */ + 0x00000000, 0xA541927E, 0x4F6F520D, 0xEA2EC073, + 0x9EDEA41A, 0x3B9F3664, 0xD1B1F617, 0x74F06469, + 0x38513EC5, 0x9D10ACBB, 0x773E6CC8, 0xD27FFEB6, + 0xA68F9ADF, 0x03CE08A1, 0xE9E0C8D2, 0x4CA15AAC, + 0x70A27D8A, 0xD5E3EFF4, 0x3FCD2F87, 0x9A8CBDF9, + 0xEE7CD990, 0x4B3D4BEE, 0xA1138B9D, 0x045219E3, + 0x48F3434F, 0xEDB2D131, 0x079C1142, 0xA2DD833C, + 0xD62DE755, 0x736C752B, 0x9942B558, 0x3C032726, + 0xE144FB14, 0x4405696A, 0xAE2BA919, 0x0B6A3B67, + 0x7F9A5F0E, 0xDADBCD70, 0x30F50D03, 0x95B49F7D, + 0xD915C5D1, 0x7C5457AF, 0x967A97DC, 0x333B05A2, + 0x47CB61CB, 0xE28AF3B5, 0x08A433C6, 0xADE5A1B8, + 0x91E6869E, 0x34A714E0, 0xDE89D493, 0x7BC846ED, + 0x0F382284, 0xAA79B0FA, 0x40577089, 0xE516E2F7, + 0xA9B7B85B, 0x0CF62A25, 0xE6D8EA56, 0x43997828, + 0x37691C41, 
0x92288E3F, 0x78064E4C, 0xDD47DC32, + 0xC76580D9, 0x622412A7, 0x880AD2D4, 0x2D4B40AA, + 0x59BB24C3, 0xFCFAB6BD, 0x16D476CE, 0xB395E4B0, + 0xFF34BE1C, 0x5A752C62, 0xB05BEC11, 0x151A7E6F, + 0x61EA1A06, 0xC4AB8878, 0x2E85480B, 0x8BC4DA75, + 0xB7C7FD53, 0x12866F2D, 0xF8A8AF5E, 0x5DE93D20, + 0x29195949, 0x8C58CB37, 0x66760B44, 0xC337993A, + 0x8F96C396, 0x2AD751E8, 0xC0F9919B, 0x65B803E5, + 0x1148678C, 0xB409F5F2, 0x5E273581, 0xFB66A7FF, + 0x26217BCD, 0x8360E9B3, 0x694E29C0, 0xCC0FBBBE, + 0xB8FFDFD7, 0x1DBE4DA9, 0xF7908DDA, 0x52D11FA4, + 0x1E704508, 0xBB31D776, 0x511F1705, 0xF45E857B, + 0x80AEE112, 0x25EF736C, 0xCFC1B31F, 0x6A802161, + 0x56830647, 0xF3C29439, 0x19EC544A, 0xBCADC634, + 0xC85DA25D, 0x6D1C3023, 0x8732F050, 0x2273622E, + 0x6ED23882, 0xCB93AAFC, 0x21BD6A8F, 0x84FCF8F1, + 0xF00C9C98, 0x554D0EE6, 0xBF63CE95, 0x1A225CEB, + 0x8B277743, 0x2E66E53D, 0xC448254E, 0x6109B730, + 0x15F9D359, 0xB0B84127, 0x5A968154, 0xFFD7132A, + 0xB3764986, 0x1637DBF8, 0xFC191B8B, 0x595889F5, + 0x2DA8ED9C, 0x88E97FE2, 0x62C7BF91, 0xC7862DEF, + 0xFB850AC9, 0x5EC498B7, 0xB4EA58C4, 0x11ABCABA, + 0x655BAED3, 0xC01A3CAD, 0x2A34FCDE, 0x8F756EA0, + 0xC3D4340C, 0x6695A672, 0x8CBB6601, 0x29FAF47F, + 0x5D0A9016, 0xF84B0268, 0x1265C21B, 0xB7245065, + 0x6A638C57, 0xCF221E29, 0x250CDE5A, 0x804D4C24, + 0xF4BD284D, 0x51FCBA33, 0xBBD27A40, 0x1E93E83E, + 0x5232B292, 0xF77320EC, 0x1D5DE09F, 0xB81C72E1, + 0xCCEC1688, 0x69AD84F6, 0x83834485, 0x26C2D6FB, + 0x1AC1F1DD, 0xBF8063A3, 0x55AEA3D0, 0xF0EF31AE, + 0x841F55C7, 0x215EC7B9, 0xCB7007CA, 0x6E3195B4, + 0x2290CF18, 0x87D15D66, 0x6DFF9D15, 0xC8BE0F6B, + 0xBC4E6B02, 0x190FF97C, 0xF321390F, 0x5660AB71, + 0x4C42F79A, 0xE90365E4, 0x032DA597, 0xA66C37E9, + 0xD29C5380, 0x77DDC1FE, 0x9DF3018D, 0x38B293F3, + 0x7413C95F, 0xD1525B21, 0x3B7C9B52, 0x9E3D092C, + 0xEACD6D45, 0x4F8CFF3B, 0xA5A23F48, 0x00E3AD36, + 0x3CE08A10, 0x99A1186E, 0x738FD81D, 0xD6CE4A63, + 0xA23E2E0A, 0x077FBC74, 0xED517C07, 0x4810EE79, + 0x04B1B4D5, 0xA1F026AB, 0x4BDEE6D8, 0xEE9F74A6, + 0x9A6F10CF, 0x3F2E82B1, 0xD50042C2, 0x7041D0BC, + 0xAD060C8E, 0x08479EF0, 0xE2695E83, 0x4728CCFD, + 0x33D8A894, 0x96993AEA, 0x7CB7FA99, 0xD9F668E7, + 0x9557324B, 0x3016A035, 0xDA386046, 0x7F79F238, + 0x0B899651, 0xAEC8042F, 0x44E6C45C, 0xE1A75622, + 0xDDA47104, 0x78E5E37A, 0x92CB2309, 0x378AB177, + 0x437AD51E, 0xE63B4760, 0x0C158713, 0xA954156D, + 0xE5F54FC1, 0x40B4DDBF, 0xAA9A1DCC, 0x0FDB8FB2, + 0x7B2BEBDB, 0xDE6A79A5, 0x3444B9D6, 0x91052BA8, + /* T8_3 */ + 0x00000000, 0xDD45AAB8, 0xBF672381, 0x62228939, + 0x7B2231F3, 0xA6679B4B, 0xC4451272, 0x1900B8CA, + 0xF64463E6, 0x2B01C95E, 0x49234067, 0x9466EADF, + 0x8D665215, 0x5023F8AD, 0x32017194, 0xEF44DB2C, + 0xE964B13D, 0x34211B85, 0x560392BC, 0x8B463804, + 0x924680CE, 0x4F032A76, 0x2D21A34F, 0xF06409F7, + 0x1F20D2DB, 0xC2657863, 0xA047F15A, 0x7D025BE2, + 0x6402E328, 0xB9474990, 0xDB65C0A9, 0x06206A11, + 0xD725148B, 0x0A60BE33, 0x6842370A, 0xB5079DB2, + 0xAC072578, 0x71428FC0, 0x136006F9, 0xCE25AC41, + 0x2161776D, 0xFC24DDD5, 0x9E0654EC, 0x4343FE54, + 0x5A43469E, 0x8706EC26, 0xE524651F, 0x3861CFA7, + 0x3E41A5B6, 0xE3040F0E, 0x81268637, 0x5C632C8F, + 0x45639445, 0x98263EFD, 0xFA04B7C4, 0x27411D7C, + 0xC805C650, 0x15406CE8, 0x7762E5D1, 0xAA274F69, + 0xB327F7A3, 0x6E625D1B, 0x0C40D422, 0xD1057E9A, + 0xABA65FE7, 0x76E3F55F, 0x14C17C66, 0xC984D6DE, + 0xD0846E14, 0x0DC1C4AC, 0x6FE34D95, 0xB2A6E72D, + 0x5DE23C01, 0x80A796B9, 0xE2851F80, 0x3FC0B538, + 0x26C00DF2, 0xFB85A74A, 0x99A72E73, 0x44E284CB, + 0x42C2EEDA, 0x9F874462, 0xFDA5CD5B, 0x20E067E3, + 0x39E0DF29, 0xE4A57591, 0x8687FCA8, 0x5BC25610, + 
0xB4868D3C, 0x69C32784, 0x0BE1AEBD, 0xD6A40405, + 0xCFA4BCCF, 0x12E11677, 0x70C39F4E, 0xAD8635F6, + 0x7C834B6C, 0xA1C6E1D4, 0xC3E468ED, 0x1EA1C255, + 0x07A17A9F, 0xDAE4D027, 0xB8C6591E, 0x6583F3A6, + 0x8AC7288A, 0x57828232, 0x35A00B0B, 0xE8E5A1B3, + 0xF1E51979, 0x2CA0B3C1, 0x4E823AF8, 0x93C79040, + 0x95E7FA51, 0x48A250E9, 0x2A80D9D0, 0xF7C57368, + 0xEEC5CBA2, 0x3380611A, 0x51A2E823, 0x8CE7429B, + 0x63A399B7, 0xBEE6330F, 0xDCC4BA36, 0x0181108E, + 0x1881A844, 0xC5C402FC, 0xA7E68BC5, 0x7AA3217D, + 0x52A0C93F, 0x8FE56387, 0xEDC7EABE, 0x30824006, + 0x2982F8CC, 0xF4C75274, 0x96E5DB4D, 0x4BA071F5, + 0xA4E4AAD9, 0x79A10061, 0x1B838958, 0xC6C623E0, + 0xDFC69B2A, 0x02833192, 0x60A1B8AB, 0xBDE41213, + 0xBBC47802, 0x6681D2BA, 0x04A35B83, 0xD9E6F13B, + 0xC0E649F1, 0x1DA3E349, 0x7F816A70, 0xA2C4C0C8, + 0x4D801BE4, 0x90C5B15C, 0xF2E73865, 0x2FA292DD, + 0x36A22A17, 0xEBE780AF, 0x89C50996, 0x5480A32E, + 0x8585DDB4, 0x58C0770C, 0x3AE2FE35, 0xE7A7548D, + 0xFEA7EC47, 0x23E246FF, 0x41C0CFC6, 0x9C85657E, + 0x73C1BE52, 0xAE8414EA, 0xCCA69DD3, 0x11E3376B, + 0x08E38FA1, 0xD5A62519, 0xB784AC20, 0x6AC10698, + 0x6CE16C89, 0xB1A4C631, 0xD3864F08, 0x0EC3E5B0, + 0x17C35D7A, 0xCA86F7C2, 0xA8A47EFB, 0x75E1D443, + 0x9AA50F6F, 0x47E0A5D7, 0x25C22CEE, 0xF8878656, + 0xE1873E9C, 0x3CC29424, 0x5EE01D1D, 0x83A5B7A5, + 0xF90696D8, 0x24433C60, 0x4661B559, 0x9B241FE1, + 0x8224A72B, 0x5F610D93, 0x3D4384AA, 0xE0062E12, + 0x0F42F53E, 0xD2075F86, 0xB025D6BF, 0x6D607C07, + 0x7460C4CD, 0xA9256E75, 0xCB07E74C, 0x16424DF4, + 0x106227E5, 0xCD278D5D, 0xAF050464, 0x7240AEDC, + 0x6B401616, 0xB605BCAE, 0xD4273597, 0x09629F2F, + 0xE6264403, 0x3B63EEBB, 0x59416782, 0x8404CD3A, + 0x9D0475F0, 0x4041DF48, 0x22635671, 0xFF26FCC9, + 0x2E238253, 0xF36628EB, 0x9144A1D2, 0x4C010B6A, + 0x5501B3A0, 0x88441918, 0xEA669021, 0x37233A99, + 0xD867E1B5, 0x05224B0D, 0x6700C234, 0xBA45688C, + 0xA345D046, 0x7E007AFE, 0x1C22F3C7, 0xC167597F, + 0xC747336E, 0x1A0299D6, 0x782010EF, 0xA565BA57, + 0xBC65029D, 0x6120A825, 0x0302211C, 0xDE478BA4, + 0x31035088, 0xEC46FA30, 0x8E647309, 0x5321D9B1, + 0x4A21617B, 0x9764CBC3, 0xF54642FA, 0x2803E842, + /* T8_4 */ + 0x00000000, 0x38116FAC, 0x7022DF58, 0x4833B0F4, + 0xE045BEB0, 0xD854D11C, 0x906761E8, 0xA8760E44, + 0xC5670B91, 0xFD76643D, 0xB545D4C9, 0x8D54BB65, + 0x2522B521, 0x1D33DA8D, 0x55006A79, 0x6D1105D5, + 0x8F2261D3, 0xB7330E7F, 0xFF00BE8B, 0xC711D127, + 0x6F67DF63, 0x5776B0CF, 0x1F45003B, 0x27546F97, + 0x4A456A42, 0x725405EE, 0x3A67B51A, 0x0276DAB6, + 0xAA00D4F2, 0x9211BB5E, 0xDA220BAA, 0xE2336406, + 0x1BA8B557, 0x23B9DAFB, 0x6B8A6A0F, 0x539B05A3, + 0xFBED0BE7, 0xC3FC644B, 0x8BCFD4BF, 0xB3DEBB13, + 0xDECFBEC6, 0xE6DED16A, 0xAEED619E, 0x96FC0E32, + 0x3E8A0076, 0x069B6FDA, 0x4EA8DF2E, 0x76B9B082, + 0x948AD484, 0xAC9BBB28, 0xE4A80BDC, 0xDCB96470, + 0x74CF6A34, 0x4CDE0598, 0x04EDB56C, 0x3CFCDAC0, + 0x51EDDF15, 0x69FCB0B9, 0x21CF004D, 0x19DE6FE1, + 0xB1A861A5, 0x89B90E09, 0xC18ABEFD, 0xF99BD151, + 0x37516AAE, 0x0F400502, 0x4773B5F6, 0x7F62DA5A, + 0xD714D41E, 0xEF05BBB2, 0xA7360B46, 0x9F2764EA, + 0xF236613F, 0xCA270E93, 0x8214BE67, 0xBA05D1CB, + 0x1273DF8F, 0x2A62B023, 0x625100D7, 0x5A406F7B, + 0xB8730B7D, 0x806264D1, 0xC851D425, 0xF040BB89, + 0x5836B5CD, 0x6027DA61, 0x28146A95, 0x10050539, + 0x7D1400EC, 0x45056F40, 0x0D36DFB4, 0x3527B018, + 0x9D51BE5C, 0xA540D1F0, 0xED736104, 0xD5620EA8, + 0x2CF9DFF9, 0x14E8B055, 0x5CDB00A1, 0x64CA6F0D, + 0xCCBC6149, 0xF4AD0EE5, 0xBC9EBE11, 0x848FD1BD, + 0xE99ED468, 0xD18FBBC4, 0x99BC0B30, 0xA1AD649C, + 0x09DB6AD8, 0x31CA0574, 0x79F9B580, 0x41E8DA2C, + 0xA3DBBE2A, 0x9BCAD186, 0xD3F96172, 
0xEBE80EDE, + 0x439E009A, 0x7B8F6F36, 0x33BCDFC2, 0x0BADB06E, + 0x66BCB5BB, 0x5EADDA17, 0x169E6AE3, 0x2E8F054F, + 0x86F90B0B, 0xBEE864A7, 0xF6DBD453, 0xCECABBFF, + 0x6EA2D55C, 0x56B3BAF0, 0x1E800A04, 0x269165A8, + 0x8EE76BEC, 0xB6F60440, 0xFEC5B4B4, 0xC6D4DB18, + 0xABC5DECD, 0x93D4B161, 0xDBE70195, 0xE3F66E39, + 0x4B80607D, 0x73910FD1, 0x3BA2BF25, 0x03B3D089, + 0xE180B48F, 0xD991DB23, 0x91A26BD7, 0xA9B3047B, + 0x01C50A3F, 0x39D46593, 0x71E7D567, 0x49F6BACB, + 0x24E7BF1E, 0x1CF6D0B2, 0x54C56046, 0x6CD40FEA, + 0xC4A201AE, 0xFCB36E02, 0xB480DEF6, 0x8C91B15A, + 0x750A600B, 0x4D1B0FA7, 0x0528BF53, 0x3D39D0FF, + 0x954FDEBB, 0xAD5EB117, 0xE56D01E3, 0xDD7C6E4F, + 0xB06D6B9A, 0x887C0436, 0xC04FB4C2, 0xF85EDB6E, + 0x5028D52A, 0x6839BA86, 0x200A0A72, 0x181B65DE, + 0xFA2801D8, 0xC2396E74, 0x8A0ADE80, 0xB21BB12C, + 0x1A6DBF68, 0x227CD0C4, 0x6A4F6030, 0x525E0F9C, + 0x3F4F0A49, 0x075E65E5, 0x4F6DD511, 0x777CBABD, + 0xDF0AB4F9, 0xE71BDB55, 0xAF286BA1, 0x9739040D, + 0x59F3BFF2, 0x61E2D05E, 0x29D160AA, 0x11C00F06, + 0xB9B60142, 0x81A76EEE, 0xC994DE1A, 0xF185B1B6, + 0x9C94B463, 0xA485DBCF, 0xECB66B3B, 0xD4A70497, + 0x7CD10AD3, 0x44C0657F, 0x0CF3D58B, 0x34E2BA27, + 0xD6D1DE21, 0xEEC0B18D, 0xA6F30179, 0x9EE26ED5, + 0x36946091, 0x0E850F3D, 0x46B6BFC9, 0x7EA7D065, + 0x13B6D5B0, 0x2BA7BA1C, 0x63940AE8, 0x5B856544, + 0xF3F36B00, 0xCBE204AC, 0x83D1B458, 0xBBC0DBF4, + 0x425B0AA5, 0x7A4A6509, 0x3279D5FD, 0x0A68BA51, + 0xA21EB415, 0x9A0FDBB9, 0xD23C6B4D, 0xEA2D04E1, + 0x873C0134, 0xBF2D6E98, 0xF71EDE6C, 0xCF0FB1C0, + 0x6779BF84, 0x5F68D028, 0x175B60DC, 0x2F4A0F70, + 0xCD796B76, 0xF56804DA, 0xBD5BB42E, 0x854ADB82, + 0x2D3CD5C6, 0x152DBA6A, 0x5D1E0A9E, 0x650F6532, + 0x081E60E7, 0x300F0F4B, 0x783CBFBF, 0x402DD013, + 0xE85BDE57, 0xD04AB1FB, 0x9879010F, 0xA0686EA3, + /* T8_5 */ + 0x00000000, 0xEF306B19, 0xDB8CA0C3, 0x34BCCBDA, + 0xB2F53777, 0x5DC55C6E, 0x697997B4, 0x8649FCAD, + 0x6006181F, 0x8F367306, 0xBB8AB8DC, 0x54BAD3C5, + 0xD2F32F68, 0x3DC34471, 0x097F8FAB, 0xE64FE4B2, + 0xC00C303E, 0x2F3C5B27, 0x1B8090FD, 0xF4B0FBE4, + 0x72F90749, 0x9DC96C50, 0xA975A78A, 0x4645CC93, + 0xA00A2821, 0x4F3A4338, 0x7B8688E2, 0x94B6E3FB, + 0x12FF1F56, 0xFDCF744F, 0xC973BF95, 0x2643D48C, + 0x85F4168D, 0x6AC47D94, 0x5E78B64E, 0xB148DD57, + 0x370121FA, 0xD8314AE3, 0xEC8D8139, 0x03BDEA20, + 0xE5F20E92, 0x0AC2658B, 0x3E7EAE51, 0xD14EC548, + 0x570739E5, 0xB83752FC, 0x8C8B9926, 0x63BBF23F, + 0x45F826B3, 0xAAC84DAA, 0x9E748670, 0x7144ED69, + 0xF70D11C4, 0x183D7ADD, 0x2C81B107, 0xC3B1DA1E, + 0x25FE3EAC, 0xCACE55B5, 0xFE729E6F, 0x1142F576, + 0x970B09DB, 0x783B62C2, 0x4C87A918, 0xA3B7C201, + 0x0E045BEB, 0xE13430F2, 0xD588FB28, 0x3AB89031, + 0xBCF16C9C, 0x53C10785, 0x677DCC5F, 0x884DA746, + 0x6E0243F4, 0x813228ED, 0xB58EE337, 0x5ABE882E, + 0xDCF77483, 0x33C71F9A, 0x077BD440, 0xE84BBF59, + 0xCE086BD5, 0x213800CC, 0x1584CB16, 0xFAB4A00F, + 0x7CFD5CA2, 0x93CD37BB, 0xA771FC61, 0x48419778, + 0xAE0E73CA, 0x413E18D3, 0x7582D309, 0x9AB2B810, + 0x1CFB44BD, 0xF3CB2FA4, 0xC777E47E, 0x28478F67, + 0x8BF04D66, 0x64C0267F, 0x507CEDA5, 0xBF4C86BC, + 0x39057A11, 0xD6351108, 0xE289DAD2, 0x0DB9B1CB, + 0xEBF65579, 0x04C63E60, 0x307AF5BA, 0xDF4A9EA3, + 0x5903620E, 0xB6330917, 0x828FC2CD, 0x6DBFA9D4, + 0x4BFC7D58, 0xA4CC1641, 0x9070DD9B, 0x7F40B682, + 0xF9094A2F, 0x16392136, 0x2285EAEC, 0xCDB581F5, + 0x2BFA6547, 0xC4CA0E5E, 0xF076C584, 0x1F46AE9D, + 0x990F5230, 0x763F3929, 0x4283F2F3, 0xADB399EA, + 0x1C08B7D6, 0xF338DCCF, 0xC7841715, 0x28B47C0C, + 0xAEFD80A1, 0x41CDEBB8, 0x75712062, 0x9A414B7B, + 0x7C0EAFC9, 0x933EC4D0, 0xA7820F0A, 0x48B26413, + 0xCEFB98BE, 0x21CBF3A7, 
0x1577387D, 0xFA475364, + 0xDC0487E8, 0x3334ECF1, 0x0788272B, 0xE8B84C32, + 0x6EF1B09F, 0x81C1DB86, 0xB57D105C, 0x5A4D7B45, + 0xBC029FF7, 0x5332F4EE, 0x678E3F34, 0x88BE542D, + 0x0EF7A880, 0xE1C7C399, 0xD57B0843, 0x3A4B635A, + 0x99FCA15B, 0x76CCCA42, 0x42700198, 0xAD406A81, + 0x2B09962C, 0xC439FD35, 0xF08536EF, 0x1FB55DF6, + 0xF9FAB944, 0x16CAD25D, 0x22761987, 0xCD46729E, + 0x4B0F8E33, 0xA43FE52A, 0x90832EF0, 0x7FB345E9, + 0x59F09165, 0xB6C0FA7C, 0x827C31A6, 0x6D4C5ABF, + 0xEB05A612, 0x0435CD0B, 0x308906D1, 0xDFB96DC8, + 0x39F6897A, 0xD6C6E263, 0xE27A29B9, 0x0D4A42A0, + 0x8B03BE0D, 0x6433D514, 0x508F1ECE, 0xBFBF75D7, + 0x120CEC3D, 0xFD3C8724, 0xC9804CFE, 0x26B027E7, + 0xA0F9DB4A, 0x4FC9B053, 0x7B757B89, 0x94451090, + 0x720AF422, 0x9D3A9F3B, 0xA98654E1, 0x46B63FF8, + 0xC0FFC355, 0x2FCFA84C, 0x1B736396, 0xF443088F, + 0xD200DC03, 0x3D30B71A, 0x098C7CC0, 0xE6BC17D9, + 0x60F5EB74, 0x8FC5806D, 0xBB794BB7, 0x544920AE, + 0xB206C41C, 0x5D36AF05, 0x698A64DF, 0x86BA0FC6, + 0x00F3F36B, 0xEFC39872, 0xDB7F53A8, 0x344F38B1, + 0x97F8FAB0, 0x78C891A9, 0x4C745A73, 0xA344316A, + 0x250DCDC7, 0xCA3DA6DE, 0xFE816D04, 0x11B1061D, + 0xF7FEE2AF, 0x18CE89B6, 0x2C72426C, 0xC3422975, + 0x450BD5D8, 0xAA3BBEC1, 0x9E87751B, 0x71B71E02, + 0x57F4CA8E, 0xB8C4A197, 0x8C786A4D, 0x63480154, + 0xE501FDF9, 0x0A3196E0, 0x3E8D5D3A, 0xD1BD3623, + 0x37F2D291, 0xD8C2B988, 0xEC7E7252, 0x034E194B, + 0x8507E5E6, 0x6A378EFF, 0x5E8B4525, 0xB1BB2E3C, + /* T8_6 */ + 0x00000000, 0x68032CC8, 0xD0065990, 0xB8057558, + 0xA5E0C5D1, 0xCDE3E919, 0x75E69C41, 0x1DE5B089, + 0x4E2DFD53, 0x262ED19B, 0x9E2BA4C3, 0xF628880B, + 0xEBCD3882, 0x83CE144A, 0x3BCB6112, 0x53C84DDA, + 0x9C5BFAA6, 0xF458D66E, 0x4C5DA336, 0x245E8FFE, + 0x39BB3F77, 0x51B813BF, 0xE9BD66E7, 0x81BE4A2F, + 0xD27607F5, 0xBA752B3D, 0x02705E65, 0x6A7372AD, + 0x7796C224, 0x1F95EEEC, 0xA7909BB4, 0xCF93B77C, + 0x3D5B83BD, 0x5558AF75, 0xED5DDA2D, 0x855EF6E5, + 0x98BB466C, 0xF0B86AA4, 0x48BD1FFC, 0x20BE3334, + 0x73767EEE, 0x1B755226, 0xA370277E, 0xCB730BB6, + 0xD696BB3F, 0xBE9597F7, 0x0690E2AF, 0x6E93CE67, + 0xA100791B, 0xC90355D3, 0x7106208B, 0x19050C43, + 0x04E0BCCA, 0x6CE39002, 0xD4E6E55A, 0xBCE5C992, + 0xEF2D8448, 0x872EA880, 0x3F2BDDD8, 0x5728F110, + 0x4ACD4199, 0x22CE6D51, 0x9ACB1809, 0xF2C834C1, + 0x7AB7077A, 0x12B42BB2, 0xAAB15EEA, 0xC2B27222, + 0xDF57C2AB, 0xB754EE63, 0x0F519B3B, 0x6752B7F3, + 0x349AFA29, 0x5C99D6E1, 0xE49CA3B9, 0x8C9F8F71, + 0x917A3FF8, 0xF9791330, 0x417C6668, 0x297F4AA0, + 0xE6ECFDDC, 0x8EEFD114, 0x36EAA44C, 0x5EE98884, + 0x430C380D, 0x2B0F14C5, 0x930A619D, 0xFB094D55, + 0xA8C1008F, 0xC0C22C47, 0x78C7591F, 0x10C475D7, + 0x0D21C55E, 0x6522E996, 0xDD279CCE, 0xB524B006, + 0x47EC84C7, 0x2FEFA80F, 0x97EADD57, 0xFFE9F19F, + 0xE20C4116, 0x8A0F6DDE, 0x320A1886, 0x5A09344E, + 0x09C17994, 0x61C2555C, 0xD9C72004, 0xB1C40CCC, + 0xAC21BC45, 0xC422908D, 0x7C27E5D5, 0x1424C91D, + 0xDBB77E61, 0xB3B452A9, 0x0BB127F1, 0x63B20B39, + 0x7E57BBB0, 0x16549778, 0xAE51E220, 0xC652CEE8, + 0x959A8332, 0xFD99AFFA, 0x459CDAA2, 0x2D9FF66A, + 0x307A46E3, 0x58796A2B, 0xE07C1F73, 0x887F33BB, + 0xF56E0EF4, 0x9D6D223C, 0x25685764, 0x4D6B7BAC, + 0x508ECB25, 0x388DE7ED, 0x808892B5, 0xE88BBE7D, + 0xBB43F3A7, 0xD340DF6F, 0x6B45AA37, 0x034686FF, + 0x1EA33676, 0x76A01ABE, 0xCEA56FE6, 0xA6A6432E, + 0x6935F452, 0x0136D89A, 0xB933ADC2, 0xD130810A, + 0xCCD53183, 0xA4D61D4B, 0x1CD36813, 0x74D044DB, + 0x27180901, 0x4F1B25C9, 0xF71E5091, 0x9F1D7C59, + 0x82F8CCD0, 0xEAFBE018, 0x52FE9540, 0x3AFDB988, + 0xC8358D49, 0xA036A181, 0x1833D4D9, 0x7030F811, + 0x6DD54898, 0x05D66450, 0xBDD31108, 0xD5D03DC0, + 0x8618701A, 
0xEE1B5CD2, 0x561E298A, 0x3E1D0542, + 0x23F8B5CB, 0x4BFB9903, 0xF3FEEC5B, 0x9BFDC093, + 0x546E77EF, 0x3C6D5B27, 0x84682E7F, 0xEC6B02B7, + 0xF18EB23E, 0x998D9EF6, 0x2188EBAE, 0x498BC766, + 0x1A438ABC, 0x7240A674, 0xCA45D32C, 0xA246FFE4, + 0xBFA34F6D, 0xD7A063A5, 0x6FA516FD, 0x07A63A35, + 0x8FD9098E, 0xE7DA2546, 0x5FDF501E, 0x37DC7CD6, + 0x2A39CC5F, 0x423AE097, 0xFA3F95CF, 0x923CB907, + 0xC1F4F4DD, 0xA9F7D815, 0x11F2AD4D, 0x79F18185, + 0x6414310C, 0x0C171DC4, 0xB412689C, 0xDC114454, + 0x1382F328, 0x7B81DFE0, 0xC384AAB8, 0xAB878670, + 0xB66236F9, 0xDE611A31, 0x66646F69, 0x0E6743A1, + 0x5DAF0E7B, 0x35AC22B3, 0x8DA957EB, 0xE5AA7B23, + 0xF84FCBAA, 0x904CE762, 0x2849923A, 0x404ABEF2, + 0xB2828A33, 0xDA81A6FB, 0x6284D3A3, 0x0A87FF6B, + 0x17624FE2, 0x7F61632A, 0xC7641672, 0xAF673ABA, + 0xFCAF7760, 0x94AC5BA8, 0x2CA92EF0, 0x44AA0238, + 0x594FB2B1, 0x314C9E79, 0x8949EB21, 0xE14AC7E9, + 0x2ED97095, 0x46DA5C5D, 0xFEDF2905, 0x96DC05CD, + 0x8B39B544, 0xE33A998C, 0x5B3FECD4, 0x333CC01C, + 0x60F48DC6, 0x08F7A10E, 0xB0F2D456, 0xD8F1F89E, + 0xC5144817, 0xAD1764DF, 0x15121187, 0x7D113D4F, + /* T8_7 */ + 0x00000000, 0x493C7D27, 0x9278FA4E, 0xDB448769, + 0x211D826D, 0x6821FF4A, 0xB3657823, 0xFA590504, + 0x423B04DA, 0x0B0779FD, 0xD043FE94, 0x997F83B3, + 0x632686B7, 0x2A1AFB90, 0xF15E7CF9, 0xB86201DE, + 0x847609B4, 0xCD4A7493, 0x160EF3FA, 0x5F328EDD, + 0xA56B8BD9, 0xEC57F6FE, 0x37137197, 0x7E2F0CB0, + 0xC64D0D6E, 0x8F717049, 0x5435F720, 0x1D098A07, + 0xE7508F03, 0xAE6CF224, 0x7528754D, 0x3C14086A, + 0x0D006599, 0x443C18BE, 0x9F789FD7, 0xD644E2F0, + 0x2C1DE7F4, 0x65219AD3, 0xBE651DBA, 0xF759609D, + 0x4F3B6143, 0x06071C64, 0xDD439B0D, 0x947FE62A, + 0x6E26E32E, 0x271A9E09, 0xFC5E1960, 0xB5626447, + 0x89766C2D, 0xC04A110A, 0x1B0E9663, 0x5232EB44, + 0xA86BEE40, 0xE1579367, 0x3A13140E, 0x732F6929, + 0xCB4D68F7, 0x827115D0, 0x593592B9, 0x1009EF9E, + 0xEA50EA9A, 0xA36C97BD, 0x782810D4, 0x31146DF3, + 0x1A00CB32, 0x533CB615, 0x8878317C, 0xC1444C5B, + 0x3B1D495F, 0x72213478, 0xA965B311, 0xE059CE36, + 0x583BCFE8, 0x1107B2CF, 0xCA4335A6, 0x837F4881, + 0x79264D85, 0x301A30A2, 0xEB5EB7CB, 0xA262CAEC, + 0x9E76C286, 0xD74ABFA1, 0x0C0E38C8, 0x453245EF, + 0xBF6B40EB, 0xF6573DCC, 0x2D13BAA5, 0x642FC782, + 0xDC4DC65C, 0x9571BB7B, 0x4E353C12, 0x07094135, + 0xFD504431, 0xB46C3916, 0x6F28BE7F, 0x2614C358, + 0x1700AEAB, 0x5E3CD38C, 0x857854E5, 0xCC4429C2, + 0x361D2CC6, 0x7F2151E1, 0xA465D688, 0xED59ABAF, + 0x553BAA71, 0x1C07D756, 0xC743503F, 0x8E7F2D18, + 0x7426281C, 0x3D1A553B, 0xE65ED252, 0xAF62AF75, + 0x9376A71F, 0xDA4ADA38, 0x010E5D51, 0x48322076, + 0xB26B2572, 0xFB575855, 0x2013DF3C, 0x692FA21B, + 0xD14DA3C5, 0x9871DEE2, 0x4335598B, 0x0A0924AC, + 0xF05021A8, 0xB96C5C8F, 0x6228DBE6, 0x2B14A6C1, + 0x34019664, 0x7D3DEB43, 0xA6796C2A, 0xEF45110D, + 0x151C1409, 0x5C20692E, 0x8764EE47, 0xCE589360, + 0x763A92BE, 0x3F06EF99, 0xE44268F0, 0xAD7E15D7, + 0x572710D3, 0x1E1B6DF4, 0xC55FEA9D, 0x8C6397BA, + 0xB0779FD0, 0xF94BE2F7, 0x220F659E, 0x6B3318B9, + 0x916A1DBD, 0xD856609A, 0x0312E7F3, 0x4A2E9AD4, + 0xF24C9B0A, 0xBB70E62D, 0x60346144, 0x29081C63, + 0xD3511967, 0x9A6D6440, 0x4129E329, 0x08159E0E, + 0x3901F3FD, 0x703D8EDA, 0xAB7909B3, 0xE2457494, + 0x181C7190, 0x51200CB7, 0x8A648BDE, 0xC358F6F9, + 0x7B3AF727, 0x32068A00, 0xE9420D69, 0xA07E704E, + 0x5A27754A, 0x131B086D, 0xC85F8F04, 0x8163F223, + 0xBD77FA49, 0xF44B876E, 0x2F0F0007, 0x66337D20, + 0x9C6A7824, 0xD5560503, 0x0E12826A, 0x472EFF4D, + 0xFF4CFE93, 0xB67083B4, 0x6D3404DD, 0x240879FA, + 0xDE517CFE, 0x976D01D9, 0x4C2986B0, 0x0515FB97, + 0x2E015D56, 0x673D2071, 0xBC79A718, 0xF545DA3F, + 
0x0F1CDF3B, 0x4620A21C, 0x9D642575, 0xD4585852, + 0x6C3A598C, 0x250624AB, 0xFE42A3C2, 0xB77EDEE5, + 0x4D27DBE1, 0x041BA6C6, 0xDF5F21AF, 0x96635C88, + 0xAA7754E2, 0xE34B29C5, 0x380FAEAC, 0x7133D38B, + 0x8B6AD68F, 0xC256ABA8, 0x19122CC1, 0x502E51E6, + 0xE84C5038, 0xA1702D1F, 0x7A34AA76, 0x3308D751, + 0xC951D255, 0x806DAF72, 0x5B29281B, 0x1215553C, + 0x230138CF, 0x6A3D45E8, 0xB179C281, 0xF845BFA6, + 0x021CBAA2, 0x4B20C785, 0x906440EC, 0xD9583DCB, + 0x613A3C15, 0x28064132, 0xF342C65B, 0xBA7EBB7C, + 0x4027BE78, 0x091BC35F, 0xD25F4436, 0x9B633911, + 0xA777317B, 0xEE4B4C5C, 0x350FCB35, 0x7C33B612, + 0x866AB316, 0xCF56CE31, 0x14124958, 0x5D2E347F, + 0xE54C35A1, 0xAC704886, 0x7734CFEF, 0x3E08B2C8, + 0xC451B7CC, 0x8D6DCAEB, 0x56294D82, 0x1F1530A5 + }; +} diff --git a/clients/src/main/java/org/apache/kafka/common/utils/Sanitizer.java b/clients/src/main/java/org/apache/kafka/common/utils/Sanitizer.java new file mode 100644 index 0000000..f921590 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/utils/Sanitizer.java @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.utils; + +import java.io.UnsupportedEncodingException; +import java.net.URLDecoder; +import java.net.URLEncoder; +import java.nio.charset.StandardCharsets; +import java.util.regex.Pattern; + +import javax.management.ObjectName; + +import org.apache.kafka.common.KafkaException; + +/** + * Utility class for sanitizing/desanitizing/quoting values used in JMX metric names + * or as ZooKeeper node name. + *
* <p>
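+ * For illustration (the input value below is hypothetical, not taken from this codebase):
+ * <pre>{@code
+ *   Sanitizer.sanitize("user 1*");       // returns "user%201%2A"
+ *   Sanitizer.desanitize("user%201%2A"); // returns "user 1*"
+ * }</pre>
+ * <p>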
            + * User principals and client-ids are URL-encoded using ({@link #sanitize(String)} + * for use as ZooKeeper node names. User principals are URL-encoded in all metric + * names as well. All other metric tags including client-id are quoted if they + * contain special characters using {@link #jmxSanitize(String)} when + * registering in JMX. + */ +public class Sanitizer { + + /** + * Even though only a small number of characters are disallowed in JMX, quote any + * string containing special characters to be safe. All characters in strings sanitized + * using {@link #sanitize(String)} are safe for JMX and hence included here. + */ + private static final Pattern MBEAN_PATTERN = Pattern.compile("[\\w-%\\. \t]*"); + + /** + * Sanitize `name` for safe use as JMX metric name as well as ZooKeeper node name + * using URL-encoding. + */ + public static String sanitize(String name) { + String encoded = ""; + try { + encoded = URLEncoder.encode(name, StandardCharsets.UTF_8.name()); + StringBuilder builder = new StringBuilder(); + for (int i = 0; i < encoded.length(); i++) { + char c = encoded.charAt(i); + if (c == '*') { // Metric ObjectName treats * as pattern + builder.append("%2A"); + } else if (c == '+') { // Space URL-encoded as +, replace with percent encoding + builder.append("%20"); + } else { + builder.append(c); + } + } + return builder.toString(); + } catch (UnsupportedEncodingException e) { + throw new KafkaException(e); + } + } + + /** + * Desanitize name that was URL-encoded using {@link #sanitize(String)}. This + * is used to obtain the desanitized version of node names in ZooKeeper. + */ + public static String desanitize(String name) { + try { + return URLDecoder.decode(name, StandardCharsets.UTF_8.name()); + } catch (UnsupportedEncodingException e) { + throw new KafkaException(e); + } + } + + /** + * Quote `name` using {@link ObjectName#quote(String)} if `name` contains + * characters that are not safe for use in JMX. User principals that are + * already sanitized using {@link #sanitize(String)} will not be quoted + * since they are safe for JMX. + */ + public static String jmxSanitize(String name) { + return MBEAN_PATTERN.matcher(name).matches() ? name : ObjectName.quote(name); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/utils/Scheduler.java b/clients/src/main/java/org/apache/kafka/common/utils/Scheduler.java new file mode 100644 index 0000000..a8ada65 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/utils/Scheduler.java @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.utils; + +import java.util.concurrent.Callable; +import java.util.concurrent.Future; +import java.util.concurrent.ScheduledExecutorService; + +/** + * An interface for scheduling tasks for the future. + * + * Implementations of this class should be thread-safe. + */ +public interface Scheduler { + Scheduler SYSTEM = new SystemScheduler(); + + /** + * Get the timekeeper associated with this scheduler. + */ + Time time(); + + /** + * Schedule a callable to be executed in the future on a + * ScheduledExecutorService. Note that the Callable may not be queued on + * the executor until the designated time arrives. + * + * @param executor The executor to use. + * @param callable The callable to execute. + * @param delayMs The delay to use, in milliseconds. + * @param The return type of the callable. + * @return A future which will complete when the callable is finished. + */ + Future schedule(final ScheduledExecutorService executor, + final Callable callable, long delayMs); +} diff --git a/clients/src/main/java/org/apache/kafka/common/utils/SecurityUtils.java b/clients/src/main/java/org/apache/kafka/common/utils/SecurityUtils.java new file mode 100644 index 0000000..88a4cfc --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/utils/SecurityUtils.java @@ -0,0 +1,179 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.utils; + +import org.apache.kafka.common.acl.AclOperation; +import org.apache.kafka.common.acl.AclPermissionType; +import org.apache.kafka.common.config.SecurityConfig; +import org.apache.kafka.common.resource.PatternType; +import org.apache.kafka.common.resource.ResourcePattern; +import org.apache.kafka.common.resource.ResourceType; +import org.apache.kafka.common.security.auth.SecurityProviderCreator; +import org.apache.kafka.common.security.auth.KafkaPrincipal; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.security.Security; +import java.util.HashMap; +import java.util.Locale; +import java.util.Map; + +public class SecurityUtils { + + private static final Logger LOGGER = LoggerFactory.getLogger(SecurityConfig.class); + + private static final Map NAME_TO_RESOURCE_TYPES; + private static final Map NAME_TO_OPERATIONS; + private static final Map NAME_TO_PERMISSION_TYPES; + + static { + NAME_TO_RESOURCE_TYPES = new HashMap<>(ResourceType.values().length); + NAME_TO_OPERATIONS = new HashMap<>(AclOperation.values().length); + NAME_TO_PERMISSION_TYPES = new HashMap<>(AclPermissionType.values().length); + + for (ResourceType resourceType : ResourceType.values()) { + String resourceTypeName = toPascalCase(resourceType.name()); + NAME_TO_RESOURCE_TYPES.put(resourceTypeName, resourceType); + NAME_TO_RESOURCE_TYPES.put(resourceTypeName.toUpperCase(Locale.ROOT), resourceType); + } + for (AclOperation operation : AclOperation.values()) { + String operationName = toPascalCase(operation.name()); + NAME_TO_OPERATIONS.put(operationName, operation); + NAME_TO_OPERATIONS.put(operationName.toUpperCase(Locale.ROOT), operation); + } + for (AclPermissionType permissionType : AclPermissionType.values()) { + String permissionName = toPascalCase(permissionType.name()); + NAME_TO_PERMISSION_TYPES.put(permissionName, permissionType); + NAME_TO_PERMISSION_TYPES.put(permissionName.toUpperCase(Locale.ROOT), permissionType); + } + } + + public static KafkaPrincipal parseKafkaPrincipal(String str) { + if (str == null || str.isEmpty()) { + throw new IllegalArgumentException("expected a string in format principalType:principalName but got " + str); + } + + String[] split = str.split(":", 2); + + if (split.length != 2) { + throw new IllegalArgumentException("expected a string in format principalType:principalName but got " + str); + } + + return new KafkaPrincipal(split[0], split[1]); + } + + public static void addConfiguredSecurityProviders(Map configs) { + String securityProviderClassesStr = (String) configs.get(SecurityConfig.SECURITY_PROVIDERS_CONFIG); + if (securityProviderClassesStr == null || securityProviderClassesStr.equals("")) { + return; + } + try { + String[] securityProviderClasses = securityProviderClassesStr.replaceAll("\\s+", "").split(","); + for (int index = 0; index < securityProviderClasses.length; index++) { + SecurityProviderCreator securityProviderCreator = + (SecurityProviderCreator) Class.forName(securityProviderClasses[index]).getConstructor().newInstance(); + securityProviderCreator.configure(configs); + Security.insertProviderAt(securityProviderCreator.getProvider(), index + 1); + } + } catch (ClassCastException e) { + LOGGER.error("Creators provided through " + SecurityConfig.SECURITY_PROVIDERS_CONFIG + + " are expected to be sub-classes of SecurityProviderCreator"); + } catch (ClassNotFoundException cnfe) { + LOGGER.error("Unrecognized security provider creator class", cnfe); + } catch (ReflectiveOperationException e) { + 
LOGGER.error("Unexpected implementation of security provider creator class", e); + } + } + + public static ResourceType resourceType(String name) { + return valueFromMap(NAME_TO_RESOURCE_TYPES, name, ResourceType.UNKNOWN); + } + + public static AclOperation operation(String name) { + return valueFromMap(NAME_TO_OPERATIONS, name, AclOperation.UNKNOWN); + } + + public static AclPermissionType permissionType(String name) { + return valueFromMap(NAME_TO_PERMISSION_TYPES, name, AclPermissionType.UNKNOWN); + } + + // We use Pascal-case to store these values, so lookup using provided key first to avoid + // case conversion for the common case. For backward compatibility, also perform + // case-insensitive look up (without underscores) by converting the key to upper-case. + private static T valueFromMap(Map map, String key, T unknown) { + T value = map.get(key); + if (value == null) { + value = map.get(key.toUpperCase(Locale.ROOT)); + } + return value == null ? unknown : value; + } + + public static String resourceTypeName(ResourceType resourceType) { + return toPascalCase(resourceType.name()); + } + + public static String operationName(AclOperation operation) { + return toPascalCase(operation.name()); + } + + public static String permissionTypeName(AclPermissionType permissionType) { + return toPascalCase(permissionType.name()); + } + + private static String toPascalCase(String name) { + StringBuilder builder = new StringBuilder(); + boolean capitalizeNext = true; + for (char c : name.toCharArray()) { + if (c == '_') + capitalizeNext = true; + else if (capitalizeNext) { + builder.append(Character.toUpperCase(c)); + capitalizeNext = false; + } else + builder.append(Character.toLowerCase(c)); + } + return builder.toString(); + } + + public static void authorizeByResourceTypeCheckArgs(AclOperation op, + ResourceType type) { + if (type == ResourceType.ANY) { + throw new IllegalArgumentException( + "Must specify a non-filter resource type for authorizeByResourceType"); + } + + if (type == ResourceType.UNKNOWN) { + throw new IllegalArgumentException( + "Unknown resource type"); + } + + if (op == AclOperation.ANY) { + throw new IllegalArgumentException( + "Must specify a non-filter operation type for authorizeByResourceType"); + } + + if (op == AclOperation.UNKNOWN) { + throw new IllegalArgumentException( + "Unknown operation type"); + } + } + + public static boolean denyAll(ResourcePattern pattern) { + return pattern.patternType() == PatternType.LITERAL + && pattern.name().equals(ResourcePattern.WILDCARD_RESOURCE); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/utils/Shell.java b/clients/src/main/java/org/apache/kafka/common/utils/Shell.java new file mode 100644 index 0000000..a9b93ec --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/utils/Shell.java @@ -0,0 +1,299 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.utils; + +import java.io.BufferedReader; +import java.io.IOException; +import java.io.InputStreamReader; +import java.nio.charset.StandardCharsets; +import java.util.Timer; +import java.util.TimerTask; +import java.util.concurrent.atomic.AtomicBoolean; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * A base class for running a Unix command. + * + * Shell can be used to run unix commands like du or + * df. + */ +abstract public class Shell { + + private static final Logger LOG = LoggerFactory.getLogger(Shell.class); + + /** Return an array containing the command name and its parameters */ + protected abstract String[] execString(); + + /** Parse the execution result */ + protected abstract void parseExecResult(BufferedReader lines) throws IOException; + + private final long timeout; + + private int exitCode; + private Process process; // sub process used to execute the command + + /* If or not script finished executing */ + private volatile AtomicBoolean completed; + + /** + * @param timeout Specifies the time in milliseconds, after which the command will be killed. -1 means no timeout. + */ + public Shell(long timeout) { + this.timeout = timeout; + } + + /** get the exit code + * @return the exit code of the process + */ + public int exitCode() { + return exitCode; + } + + /** get the current sub-process executing the given command + * @return process executing the command + */ + public Process process() { + return process; + } + + protected void run() throws IOException { + exitCode = 0; // reset for next run + runCommand(); + } + + /** Run a command */ + private void runCommand() throws IOException { + ProcessBuilder builder = new ProcessBuilder(execString()); + Timer timeoutTimer = null; + completed = new AtomicBoolean(false); + + process = builder.start(); + if (timeout > -1) { + timeoutTimer = new Timer(); + //One time scheduling. 
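+ // The ShellTimeoutTimerTask below is a best-effort guard: when it fires it checks the
+ // completed flag and destroys the sub-process if the command has not finished yet.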
+ timeoutTimer.schedule(new ShellTimeoutTimerTask(this), timeout); + } + final BufferedReader errReader = new BufferedReader( + new InputStreamReader(process.getErrorStream(), StandardCharsets.UTF_8)); + BufferedReader inReader = new BufferedReader( + new InputStreamReader(process.getInputStream(), StandardCharsets.UTF_8)); + final StringBuffer errMsg = new StringBuffer(); + + // read error and input streams as this would free up the buffers + // free the error stream buffer + Thread errThread = KafkaThread.nonDaemon("kafka-shell-thread", new Runnable() { + @Override + public void run() { + try { + String line = errReader.readLine(); + while ((line != null) && !Thread.currentThread().isInterrupted()) { + errMsg.append(line); + errMsg.append(System.getProperty("line.separator")); + line = errReader.readLine(); + } + } catch (IOException ioe) { + LOG.warn("Error reading the error stream", ioe); + } + } + }); + errThread.start(); + + try { + parseExecResult(inReader); // parse the output + // wait for the process to finish and check the exit code + exitCode = process.waitFor(); + try { + // make sure that the error thread exits + errThread.join(); + } catch (InterruptedException ie) { + LOG.warn("Interrupted while reading the error stream", ie); + } + completed.set(true); + //the timeout thread handling + //taken care in finally block + if (exitCode != 0) { + throw new ExitCodeException(exitCode, errMsg.toString()); + } + } catch (InterruptedException ie) { + throw new IOException(ie.toString()); + } finally { + if (timeoutTimer != null) + timeoutTimer.cancel(); + + // close the input stream + try { + inReader.close(); + } catch (IOException ioe) { + LOG.warn("Error while closing the input stream", ioe); + } + if (!completed.get()) + errThread.interrupt(); + + try { + errReader.close(); + } catch (IOException ioe) { + LOG.warn("Error while closing the error stream", ioe); + } + + process.destroy(); + } + } + + + /** + * This is an IOException with exit code added. + */ + @SuppressWarnings("serial") + public static class ExitCodeException extends IOException { + int exitCode; + + public ExitCodeException(int exitCode, String message) { + super(message); + this.exitCode = exitCode; + } + + public int getExitCode() { + return exitCode; + } + } + + /** + * A simple shell command executor. + * + * ShellCommandExecutorshould be used in cases where the output + * of the command needs no explicit parsing and where the command, working + * directory and the environment remains unchanged. The output of the command + * is stored as-is and is expected to be small. + */ + public static class ShellCommandExecutor extends Shell { + + private final String[] command; + private StringBuffer output; + + /** + * Create a new instance of the ShellCommandExecutor to execute a command. + * + * @param execString The command to execute with arguments + * @param timeout Specifies the time in milliseconds, after which the + * command will be killed. -1 means no timeout. + */ + + public ShellCommandExecutor(String[] execString, long timeout) { + super(timeout); + command = execString.clone(); + } + + + /** Execute the shell command. 
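+     * A minimal usage sketch (the command and timeout below are illustrative only, not defaults):
+     * <pre>{@code
+     *   Shell.ShellCommandExecutor shell =
+     *       new Shell.ShellCommandExecutor(new String[]{"df", "-k"}, 10000);
+     *   shell.execute();
+     *   String output = shell.output();
+     * }</pre>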
*/ + public void execute() throws IOException { + this.run(); + } + + protected String[] execString() { + return command; + } + + protected void parseExecResult(BufferedReader reader) throws IOException { + output = new StringBuffer(); + char[] buf = new char[512]; + int nRead; + while ((nRead = reader.read(buf, 0, buf.length)) > 0) { + output.append(buf, 0, nRead); + } + } + + /** Get the output of the shell command.*/ + public String output() { + return (output == null) ? "" : output.toString(); + } + + /** + * Returns the commands of this instance. + * Arguments with spaces in are presented with quotes round; other + * arguments are presented raw + * + * @return a string representation of the object. + */ + public String toString() { + StringBuilder builder = new StringBuilder(); + String[] args = execString(); + for (String s : args) { + if (s.indexOf(' ') >= 0) { + builder.append('"').append(s).append('"'); + } else { + builder.append(s); + } + builder.append(' '); + } + return builder.toString(); + } + } + + /** + * Static method to execute a shell command. + * Covers most of the simple cases without requiring the user to implement + * the Shell interface. + * @param cmd shell command to execute. + * @return the output of the executed command. + */ + public static String execCommand(String... cmd) throws IOException { + return execCommand(cmd, -1); + } + + /** + * Static method to execute a shell command. + * Covers most of the simple cases without requiring the user to implement + * the Shell interface. + * @param cmd shell command to execute. + * @param timeout time in milliseconds after which script should be killed. -1 means no timeout. + * @return the output of the executed command. + */ + public static String execCommand(String[] cmd, long timeout) throws IOException { + ShellCommandExecutor exec = new ShellCommandExecutor(cmd, timeout); + exec.execute(); + return exec.output(); + } + + /** + * Timer which is used to timeout scripts spawned off by shell. + */ + private static class ShellTimeoutTimerTask extends TimerTask { + + private final Shell shell; + + public ShellTimeoutTimerTask(Shell shell) { + this.shell = shell; + } + + @Override + public void run() { + Process p = shell.process(); + try { + p.exitValue(); + } catch (Exception e) { + //Process has not terminated. + //So check if it has completed + //if not just destroy it. + if (p != null && !shell.completed.get()) { + p.destroy(); + } + } + } + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/utils/SystemScheduler.java b/clients/src/main/java/org/apache/kafka/common/utils/SystemScheduler.java new file mode 100644 index 0000000..c8c1148 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/utils/SystemScheduler.java @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.utils; + +import java.util.concurrent.Callable; +import java.util.concurrent.Future; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.TimeUnit; + +/** + * A scheduler implementation that uses the system clock. + * + * Use Scheduler.SYSTEM instead of constructing an instance of this class. + */ +public class SystemScheduler implements Scheduler { + SystemScheduler() { + } + + @Override + public Time time() { + return Time.SYSTEM; + } + + @Override + public Future schedule(final ScheduledExecutorService executor, + final Callable callable, long delayMs) { + return executor.schedule(callable, delayMs, TimeUnit.MILLISECONDS); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/utils/SystemTime.java b/clients/src/main/java/org/apache/kafka/common/utils/SystemTime.java new file mode 100644 index 0000000..31919a2 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/utils/SystemTime.java @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.utils; + +import org.apache.kafka.common.errors.TimeoutException; + +import java.util.function.Supplier; + +/** + * A time implementation that uses the system clock and sleep call. Use `Time.SYSTEM` instead of creating an instance + * of this class. + */ +public class SystemTime implements Time { + + @Override + public long milliseconds() { + return System.currentTimeMillis(); + } + + @Override + public long nanoseconds() { + return System.nanoTime(); + } + + @Override + public void sleep(long ms) { + Utils.sleep(ms); + } + + @Override + public void waitObject(Object obj, Supplier condition, long deadlineMs) throws InterruptedException { + synchronized (obj) { + while (true) { + if (condition.get()) + return; + + long currentTimeMs = milliseconds(); + if (currentTimeMs >= deadlineMs) + throw new TimeoutException("Condition not satisfied before deadline"); + + obj.wait(deadlineMs - currentTimeMs); + } + } + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/utils/ThreadUtils.java b/clients/src/main/java/org/apache/kafka/common/utils/ThreadUtils.java new file mode 100644 index 0000000..750c8d7 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/utils/ThreadUtils.java @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.common.utils; + +import java.util.concurrent.ThreadFactory; +import java.util.concurrent.atomic.AtomicLong; + +/** + * Utilities for working with threads. + */ +public class ThreadUtils { + /** + * Create a new ThreadFactory. + * + * @param pattern The pattern to use. If this contains %d, it will be + * replaced with a thread number. It should not contain more + * than one %d. + * @param daemon True if we want daemon threads. + * @return The new ThreadFactory. + */ + public static ThreadFactory createThreadFactory(final String pattern, + final boolean daemon) { + return new ThreadFactory() { + private final AtomicLong threadEpoch = new AtomicLong(0); + + @Override + public Thread newThread(Runnable r) { + String threadName; + if (pattern.contains("%d")) { + threadName = String.format(pattern, threadEpoch.addAndGet(1)); + } else { + threadName = pattern; + } + Thread thread = new Thread(r, threadName); + thread.setDaemon(daemon); + return thread; + } + }; + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/utils/Time.java b/clients/src/main/java/org/apache/kafka/common/utils/Time.java new file mode 100644 index 0000000..9e0a475 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/utils/Time.java @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.utils; + +import java.time.Duration; +import java.util.concurrent.TimeUnit; +import java.util.function.Supplier; + +/** + * An interface abstracting the clock to use in unit testing classes that make use of clock time. + * + * Implementations of this class should be thread-safe. + */ +public interface Time { + + Time SYSTEM = new SystemTime(); + + /** + * Returns the current time in milliseconds. + */ + long milliseconds(); + + /** + * Returns the value returned by `nanoseconds` converted into milliseconds. + */ + default long hiResClockMs() { + return TimeUnit.NANOSECONDS.toMillis(nanoseconds()); + } + + /** + * Returns the current value of the running JVM's high-resolution time source, in nanoseconds. + * + *
<p>
            This method can only be used to measure elapsed time and is + * not related to any other notion of system or wall-clock time. + * The value returned represents nanoseconds since some fixed but + * arbitrary origin time (perhaps in the future, so values + * may be negative). The same origin is used by all invocations of + * this method in an instance of a Java virtual machine; other + * virtual machine instances are likely to use a different origin. + */ + long nanoseconds(); + + /** + * Sleep for the given number of milliseconds + */ + void sleep(long ms); + + /** + * Wait for a condition using the monitor of a given object. This avoids the implicit + * dependence on system time when calling {@link Object#wait()}. + * + * @param obj The object that will be waited with {@link Object#wait()}. Note that it is the responsibility + * of the caller to call notify on this object when the condition is satisfied. + * @param condition The condition we are awaiting + * @param deadlineMs The deadline timestamp at which to raise a timeout error + * + * @throws org.apache.kafka.common.errors.TimeoutException if the timeout expires before the condition is satisfied + */ + void waitObject(Object obj, Supplier condition, long deadlineMs) throws InterruptedException; + + /** + * Get a timer which is bound to this time instance and expires after the given timeout + */ + default Timer timer(long timeoutMs) { + return new Timer(this, timeoutMs); + } + + /** + * Get a timer which is bound to this time instance and expires after the given timeout + */ + default Timer timer(Duration timeout) { + return timer(timeout.toMillis()); + } + +} diff --git a/clients/src/main/java/org/apache/kafka/common/utils/Timer.java b/clients/src/main/java/org/apache/kafka/common/utils/Timer.java new file mode 100644 index 0000000..98b09a3 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/utils/Timer.java @@ -0,0 +1,206 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.common.utils; + +/** + * This is a helper class which makes blocking methods with a timeout easier to implement. + * In particular it enables use cases where a high-level blocking call with a timeout is + * composed of several lower level calls, each of which has their own respective timeouts. The idea + * is to create a single timer object for the high level timeout and carry it along to + * all of the lower level methods. This class also handles common problems such as integer overflow. + * This class also ensures monotonic updates to the timer even if the underlying clock is subject + * to non-monotonic behavior. For example, the remaining time returned by {@link #remainingMs()} is + * guaranteed to decrease monotonically until it hits zero. 
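+ * For instance (the numbers here are purely illustrative): if the cached time is 1000 ms and the
+ * underlying clock later reports 900 ms, {@link #update()} keeps the cached value at 1000 ms, so
+ * {@link #remainingMs()} never grows back.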
+ * + * Note that it is up to the caller to ensure progress of the timer using one of the + * {@link #update()} methods or {@link #sleep(long)}. The timer will cache the current time and + * return it indefinitely until the timer has been updated. This allows the caller to limit + * unnecessary system calls and update the timer only when needed. For example, a timer which is + * waiting a request sent through the {@link org.apache.kafka.clients.NetworkClient} should call + * {@link #update()} following each blocking call to + * {@link org.apache.kafka.clients.NetworkClient#poll(long, long)}. + * + * A typical usage might look something like this: + * + *
<pre>
+ *     Time time = Time.SYSTEM;
+ *     Timer timer = time.timer(500);
+ *
+ *     while (!conditionSatisfied() && timer.notExpired()) {
+ *         client.poll(timer.remainingMs(), timer.currentTimeMs());
+ *         timer.update();
+ *     }
+ * </pre>
            + */ +public class Timer { + private final Time time; + private long startMs; + private long currentTimeMs; + private long deadlineMs; + private long timeoutMs; + + Timer(Time time, long timeoutMs) { + this.time = time; + update(); + reset(timeoutMs); + } + + /** + * Check timer expiration. Like {@link #remainingMs()}, this depends on the current cached + * time in milliseconds, which is only updated through one of the {@link #update()} methods + * or with {@link #sleep(long)}; + * + * @return true if the timer has expired, false otherwise + */ + public boolean isExpired() { + return currentTimeMs >= deadlineMs; + } + + /** + * Check whether the timer has not yet expired. + * @return true if there is still time remaining before expiration + */ + public boolean notExpired() { + return !isExpired(); + } + + /** + * Reset the timer to the specific timeout. This will use the underlying {@link #Timer(Time, long)} + * implementation to update the current cached time in milliseconds and it will set a new timer + * deadline. + * + * @param timeoutMs The new timeout in milliseconds + */ + public void updateAndReset(long timeoutMs) { + update(); + reset(timeoutMs); + } + + /** + * Reset the timer using a new timeout. Note that this does not update the cached current time + * in milliseconds, so it typically must be accompanied with a separate call to {@link #update()}. + * Typically, you can just use {@link #updateAndReset(long)}. + * + * @param timeoutMs The new timeout in milliseconds + */ + public void reset(long timeoutMs) { + if (timeoutMs < 0) + throw new IllegalArgumentException("Invalid negative timeout " + timeoutMs); + + this.timeoutMs = timeoutMs; + this.startMs = this.currentTimeMs; + + if (currentTimeMs > Long.MAX_VALUE - timeoutMs) + this.deadlineMs = Long.MAX_VALUE; + else + this.deadlineMs = currentTimeMs + timeoutMs; + } + + /** + * Reset the timer's deadline directly. + * + * @param deadlineMs The new deadline in milliseconds + */ + public void resetDeadline(long deadlineMs) { + if (deadlineMs < 0) + throw new IllegalArgumentException("Invalid negative deadline " + deadlineMs); + + this.timeoutMs = Math.max(0, deadlineMs - this.currentTimeMs); + this.startMs = this.currentTimeMs; + this.deadlineMs = deadlineMs; + } + + /** + * Use the underlying {@link Time} implementation to update the current cached time. If + * the underlying time returns a value which is smaller than the current cached time, + * the update will be ignored. + */ + public void update() { + update(time.milliseconds()); + } + + /** + * Update the cached current time to a specific value. In some contexts, the caller may already + * have an accurate time, so this avoids unnecessary calls to system time. + * + * Note that if the updated current time is smaller than the cached time, then the update + * is ignored. + * + * @param currentTimeMs The current time in milliseconds to cache + */ + public void update(long currentTimeMs) { + this.currentTimeMs = Math.max(currentTimeMs, this.currentTimeMs); + } + + /** + * Get the remaining time in milliseconds until the timer expires. Like {@link #currentTimeMs}, + * this depends on the cached current time, so the returned value will not change until the timer + * has been updated using one of the {@link #update()} methods or {@link #sleep(long)}. + * + * @return The cached remaining time in milliseconds until timer expiration + */ + public long remainingMs() { + return Math.max(0, deadlineMs - currentTimeMs); + } + + /** + * Get the current time in milliseconds. 
This will return the same cached value until the timer + * has been updated using one of the {@link #update()} methods or {@link #sleep(long)} is used. + * + * Note that the value returned is guaranteed to increase monotonically even if the underlying + * {@link Time} implementation goes backwards. Effectively, the timer will just wait for the + * time to catch up. + * + * @return The current cached time in milliseconds + */ + public long currentTimeMs() { + return currentTimeMs; + } + + /** + * Get the amount of time that has elapsed since the timer began. If the timer was reset, this + * will be the amount of time since the last reset. + * + * @return The elapsed time since construction or the last reset + */ + public long elapsedMs() { + return currentTimeMs - startMs; + } + + /** + * Get the current timeout value specified through {@link #reset(long)} or {@link #resetDeadline(long)}. + * This value is constant until altered by one of these API calls. + * + * @return The timeout in milliseconds + */ + public long timeoutMs() { + return timeoutMs; + } + + /** + * Sleep for the requested duration and update the timer. Return when either the duration has + * elapsed or the timer has expired. + * + * @param durationMs The duration in milliseconds to sleep + */ + public void sleep(long durationMs) { + long sleepDurationMs = Math.min(durationMs, remainingMs()); + time.sleep(sleepDurationMs); + update(); + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/utils/Utils.java b/clients/src/main/java/org/apache/kafka/common/utils/Utils.java new file mode 100755 index 0000000..a04f5c5 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/utils/Utils.java @@ -0,0 +1,1417 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.utils; + +import java.nio.BufferUnderflowException; +import java.nio.file.StandardOpenOption; +import java.util.AbstractMap; +import java.util.EnumSet; +import java.util.Map.Entry; +import java.util.SortedSet; +import java.util.TreeSet; +import org.apache.kafka.common.KafkaException; +import org.apache.kafka.common.config.ConfigException; +import org.apache.kafka.common.network.TransferableChannel; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.Closeable; +import java.io.DataOutput; +import java.io.EOFException; +import java.io.File; +import java.io.IOException; +import java.io.InputStream; +import java.io.PrintWriter; +import java.io.StringWriter; +import java.lang.reflect.Constructor; +import java.lang.reflect.InvocationTargetException; +import java.nio.ByteBuffer; +import java.nio.channels.FileChannel; +import java.nio.charset.StandardCharsets; +import java.nio.file.FileVisitResult; +import java.nio.file.Files; +import java.nio.file.NoSuchFileException; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.nio.file.SimpleFileVisitor; +import java.nio.file.StandardCopyOption; +import java.nio.file.attribute.BasicFileAttributes; +import java.text.DecimalFormat; +import java.text.DecimalFormatSymbols; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Iterator; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Objects; +import java.util.Properties; +import java.util.Set; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.BiConsumer; +import java.util.function.BinaryOperator; +import java.util.function.Function; +import java.util.function.Predicate; +import java.util.function.Supplier; +import java.util.regex.Matcher; +import java.util.regex.Pattern; +import java.util.stream.Collector; +import java.util.stream.Collectors; +import java.util.stream.Stream; +import java.text.ParseException; +import java.text.SimpleDateFormat; +import java.util.Date; + +public final class Utils { + + private Utils() {} + + // This matches URIs of formats: host:port and protocol:\\host:port + // IPv6 is supported with [ip] pattern + private static final Pattern HOST_PORT_PATTERN = Pattern.compile(".*?\\[?([0-9a-zA-Z\\-%._:]*)\\]?:([0-9]+)"); + + private static final Pattern VALID_HOST_CHARACTERS = Pattern.compile("([0-9a-zA-Z\\-%._:]*)"); + + // Prints up to 2 decimal digits. Used for human readable printing + private static final DecimalFormat TWO_DIGIT_FORMAT = new DecimalFormat("0.##", + DecimalFormatSymbols.getInstance(Locale.ENGLISH)); + + private static final String[] BYTE_SCALE_SUFFIXES = new String[] {"B", "KB", "MB", "GB", "TB", "PB", "EB", "ZB", "YB"}; + + public static final String NL = System.getProperty("line.separator"); + + private static final Logger log = LoggerFactory.getLogger(Utils.class); + + /** + * Get a sorted list representation of a collection. 
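+ * For example (illustrative values): {@code sorted(Arrays.asList(3, 1, 2))} returns an
+ * unmodifiable {@code [1, 2, 3]}.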
+ * @param collection The collection to sort + * @param The class of objects in the collection + * @return An unmodifiable sorted list with the contents of the collection + */ + public static > List sorted(Collection collection) { + List res = new ArrayList<>(collection); + Collections.sort(res); + return Collections.unmodifiableList(res); + } + + /** + * Turn the given UTF8 byte array into a string + * + * @param bytes The byte array + * @return The string + */ + public static String utf8(byte[] bytes) { + return new String(bytes, StandardCharsets.UTF_8); + } + + /** + * Read a UTF8 string from a byte buffer. Note that the position of the byte buffer is not affected + * by this method. + * + * @param buffer The buffer to read from + * @param length The length of the string in bytes + * @return The UTF8 string + */ + public static String utf8(ByteBuffer buffer, int length) { + return utf8(buffer, 0, length); + } + + /** + * Read a UTF8 string from the current position till the end of a byte buffer. The position of the byte buffer is + * not affected by this method. + * + * @param buffer The buffer to read from + * @return The UTF8 string + */ + public static String utf8(ByteBuffer buffer) { + return utf8(buffer, buffer.remaining()); + } + + /** + * Read a UTF8 string from a byte buffer at a given offset. Note that the position of the byte buffer + * is not affected by this method. + * + * @param buffer The buffer to read from + * @param offset The offset relative to the current position in the buffer + * @param length The length of the string in bytes + * @return The UTF8 string + */ + public static String utf8(ByteBuffer buffer, int offset, int length) { + if (buffer.hasArray()) + return new String(buffer.array(), buffer.arrayOffset() + buffer.position() + offset, length, StandardCharsets.UTF_8); + else + return utf8(toArray(buffer, offset, length)); + } + + /** + * Turn a string into a utf8 byte[] + * + * @param string The string + * @return The byte[] + */ + public static byte[] utf8(String string) { + return string.getBytes(StandardCharsets.UTF_8); + } + + /** + * Get the absolute value of the given number. If the number is Int.MinValue return 0. This is different from + * java.lang.Math.abs or scala.math.abs in that they return Int.MinValue (!). + */ + public static int abs(int n) { + return (n == Integer.MIN_VALUE) ? 0 : Math.abs(n); + } + + /** + * Get the minimum of some long values. + * @param first Used to ensure at least one value + * @param rest The remaining values to compare + * @return The minimum of all passed values + */ + public static long min(long first, long... rest) { + long min = first; + for (long r : rest) { + if (r < min) + min = r; + } + return min; + } + + /** + * Get the maximum of some long values. + * @param first Used to ensure at least one value + * @param rest The remaining values to compare + * @return The maximum of all passed values + */ + public static long max(long first, long... 
rest) { + long max = first; + for (long r : rest) { + if (r > max) + max = r; + } + return max; + } + + + public static short min(short first, short second) { + return (short) Math.min(first, second); + } + + /** + * Get the length for UTF8-encoding a string without encoding it first + * + * @param s The string to calculate the length for + * @return The length when serialized + */ + public static int utf8Length(CharSequence s) { + int count = 0; + for (int i = 0, len = s.length(); i < len; i++) { + char ch = s.charAt(i); + if (ch <= 0x7F) { + count++; + } else if (ch <= 0x7FF) { + count += 2; + } else if (Character.isHighSurrogate(ch)) { + count += 4; + ++i; + } else { + count += 3; + } + } + return count; + } + + /** + * Read the given byte buffer from its current position to its limit into a byte array. + * @param buffer The buffer to read from + */ + public static byte[] toArray(ByteBuffer buffer) { + return toArray(buffer, 0, buffer.remaining()); + } + + /** + * Read a byte array from its current position given the size in the buffer + * @param buffer The buffer to read from + * @param size The number of bytes to read into the array + */ + public static byte[] toArray(ByteBuffer buffer, int size) { + return toArray(buffer, 0, size); + } + + /** + * Convert a ByteBuffer to a nullable array. + * @param buffer The buffer to convert + * @return The resulting array or null if the buffer is null + */ + public static byte[] toNullableArray(ByteBuffer buffer) { + return buffer == null ? null : toArray(buffer); + } + + /** + * Wrap an array as a nullable ByteBuffer. + * @param array The nullable array to wrap + * @return The wrapping ByteBuffer or null if array is null + */ + public static ByteBuffer wrapNullable(byte[] array) { + return array == null ? null : ByteBuffer.wrap(array); + } + + /** + * Read a byte array from the given offset and size in the buffer + * @param buffer The buffer to read from + * @param offset The offset relative to the current position of the buffer + * @param size The number of bytes to read into the array + */ + public static byte[] toArray(ByteBuffer buffer, int offset, int size) { + byte[] dest = new byte[size]; + if (buffer.hasArray()) { + System.arraycopy(buffer.array(), buffer.position() + buffer.arrayOffset() + offset, dest, 0, size); + } else { + int pos = buffer.position(); + buffer.position(pos + offset); + buffer.get(dest); + buffer.position(pos); + } + return dest; + } + + /** + * Starting from the current position, read an integer indicating the size of the byte array to read, + * then read the array. Consumes the buffer: upon returning, the buffer's position is after the array + * that is returned. + * @param buffer The buffer to read a size-prefixed array from + * @return The array + */ + public static byte[] getNullableSizePrefixedArray(final ByteBuffer buffer) { + final int size = buffer.getInt(); + return getNullableArray(buffer, size); + } + + /** + * Read a byte array of the given size. Consumes the buffer: upon returning, the buffer's position + * is after the array that is returned. + * @param buffer The buffer to read a size-prefixed array from + * @param size The number of bytes to read out of the buffer + * @return The array + */ + public static byte[] getNullableArray(final ByteBuffer buffer, final int size) { + if (size > buffer.remaining()) { + // preemptively throw this when the read is doomed to fail, so we don't have to allocate the array. + throw new BufferUnderflowException(); + } + final byte[] oldBytes = size == -1 ? 
null : new byte[size]; + if (oldBytes != null) { + buffer.get(oldBytes); + } + return oldBytes; + } + + /** + * Returns a copy of src byte array + * @param src The byte array to copy + * @return The copy + */ + public static byte[] copyArray(byte[] src) { + return Arrays.copyOf(src, src.length); + } + + /** + * Compares two character arrays for equality using a constant-time algorithm, which is needed + * for comparing passwords. Two arrays are equal if they have the same length and all + * characters at corresponding positions are equal. + * + * All characters in the first array are examined to determine equality. + * The calculation time depends only on the length of this first character array; it does not + * depend on the length of the second character array or the contents of either array. + * + * @param first the first array to compare + * @param second the second array to compare + * @return true if the arrays are equal, or false otherwise + */ + public static boolean isEqualConstantTime(char[] first, char[] second) { + if (first == second) { + return true; + } + if (first == null || second == null) { + return false; + } + + if (second.length == 0) { + return first.length == 0; + } + + // time-constant comparison that always compares all characters in first array + boolean matches = first.length == second.length; + for (int i = 0; i < first.length; ++i) { + int j = i < second.length ? i : 0; + if (first[i] != second[j]) { + matches = false; + } + } + return matches; + } + + /** + * Sleep for a bit + * @param ms The duration of the sleep + */ + public static void sleep(long ms) { + try { + Thread.sleep(ms); + } catch (InterruptedException e) { + // this is okay, we just wake up early + Thread.currentThread().interrupt(); + } + } + + /** + * Instantiate the class + */ + public static T newInstance(Class c) { + if (c == null) + throw new KafkaException("class cannot be null"); + try { + return c.getDeclaredConstructor().newInstance(); + } catch (NoSuchMethodException e) { + throw new KafkaException("Could not find a public no-argument constructor for " + c.getName(), e); + } catch (ReflectiveOperationException | RuntimeException e) { + throw new KafkaException("Could not instantiate class " + c.getName(), e); + } + } + + /** + * Look up the class by name and instantiate it. + * @param klass class name + * @param base super class of the class to be instantiated + * @param the type of the base class + * @return the new instance + */ + public static T newInstance(String klass, Class base) throws ClassNotFoundException { + return Utils.newInstance(loadClass(klass, base)); + } + + /** + * Look up a class by name. + * @param klass class name + * @param base super class of the class for verification + * @param the type of the base class + * @return the new class + */ + public static Class loadClass(String klass, Class base) throws ClassNotFoundException { + return Class.forName(klass, true, Utils.getContextOrKafkaClassLoader()).asSubclass(base); + } + + /** + * Cast {@code klass} to {@code base} and instantiate it. + * @param klass The class to instantiate + * @param base A know baseclass of klass. + * @param the type of the base class + * @throws ClassCastException If {@code klass} is not a subclass of {@code base}. + * @return the new instance. + */ + public static T newInstance(Class klass, Class base) { + return Utils.newInstance(klass.asSubclass(base)); + } + + /** + * Construct a new object using a class name and parameters. 
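+ * A hypothetical call (the class name and argument values below are illustrative, not part of Kafka):
+ * <pre>{@code
+ *   MyConverter converter = Utils.newParameterizedInstance("com.example.MyConverter",
+ *       String.class, "UTF-8",
+ *       Integer.TYPE, 1024);
+ * }</pre>
+ * This resolves and invokes a {@code MyConverter(String, int)} constructor reflectively.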
+ * + * @param className The full name of the class to construct. + * @param params A sequence of (type, object) elements. + * @param The type of object to construct. + * @return The new object. + * @throws ClassNotFoundException If there was a problem constructing the object. + */ + public static T newParameterizedInstance(String className, Object... params) + throws ClassNotFoundException { + Class[] argTypes = new Class[params.length / 2]; + Object[] args = new Object[params.length / 2]; + try { + Class c = Class.forName(className, true, Utils.getContextOrKafkaClassLoader()); + for (int i = 0; i < params.length / 2; i++) { + argTypes[i] = (Class) params[2 * i]; + args[i] = params[(2 * i) + 1]; + } + @SuppressWarnings("unchecked") + Constructor constructor = (Constructor) c.getConstructor(argTypes); + return constructor.newInstance(args); + } catch (NoSuchMethodException e) { + throw new ClassNotFoundException(String.format("Failed to find " + + "constructor with %s for %s", Utils.join(argTypes, ", "), className), e); + } catch (InstantiationException e) { + throw new ClassNotFoundException(String.format("Failed to instantiate " + + "%s", className), e); + } catch (IllegalAccessException e) { + throw new ClassNotFoundException(String.format("Unable to access " + + "constructor of %s", className), e); + } catch (InvocationTargetException e) { + throw new ClassNotFoundException(String.format("Unable to invoke " + + "constructor of %s", className), e); + } + } + + /** + * Generates 32 bit murmur2 hash from byte array + * @param data byte array to hash + * @return 32 bit hash of the given array + */ + @SuppressWarnings("fallthrough") + public static int murmur2(final byte[] data) { + int length = data.length; + int seed = 0x9747b28c; + // 'm' and 'r' are mixing constants generated offline. + // They're not really 'magic', they just happen to work well. + final int m = 0x5bd1e995; + final int r = 24; + + // Initialize the hash to a random value + int h = seed ^ length; + int length4 = length / 4; + + for (int i = 0; i < length4; i++) { + final int i4 = i * 4; + int k = (data[i4 + 0] & 0xff) + ((data[i4 + 1] & 0xff) << 8) + ((data[i4 + 2] & 0xff) << 16) + ((data[i4 + 3] & 0xff) << 24); + k *= m; + k ^= k >>> r; + k *= m; + h *= m; + h ^= k; + } + + // Handle the last few bytes of the input array + switch (length % 4) { + case 3: + h ^= (data[(length & ~3) + 2] & 0xff) << 16; + case 2: + h ^= (data[(length & ~3) + 1] & 0xff) << 8; + case 1: + h ^= data[length & ~3] & 0xff; + h *= m; + } + + h ^= h >>> 13; + h *= m; + h ^= h >>> 15; + + return h; + } + + /** + * Extracts the hostname from a "host:port" address string. + * @param address address string to parse + * @return hostname or null if the given address is incorrect + */ + public static String getHost(String address) { + Matcher matcher = HOST_PORT_PATTERN.matcher(address); + return matcher.matches() ? matcher.group(1) : null; + } + + /** + * Extracts the port number from a "host:port" address string. + * @param address address string to parse + * @return port number or null if the given address is incorrect + */ + public static Integer getPort(String address) { + Matcher matcher = HOST_PORT_PATTERN.matcher(address); + return matcher.matches() ? Integer.parseInt(matcher.group(2)) : null; + } + + /** + * Basic validation of the supplied address. 
checks for valid characters + * @param address hostname string to validate + * @return true if address contains valid characters + */ + public static boolean validHostPattern(String address) { + return VALID_HOST_CHARACTERS.matcher(address).matches(); + } + + /** + * Formats hostname and port number as a "host:port" address string, + * surrounding IPv6 addresses with braces '[', ']' + * @param host hostname + * @param port port number + * @return address string + */ + public static String formatAddress(String host, Integer port) { + return host.contains(":") + ? "[" + host + "]:" + port // IPv6 + : host + ":" + port; + } + + /** + * Formats a byte number as a human readable String ("3.2 MB") + * @param bytes some size in bytes + * @return + */ + public static String formatBytes(long bytes) { + if (bytes < 0) { + return String.valueOf(bytes); + } + double asDouble = (double) bytes; + int ordinal = (int) Math.floor(Math.log(asDouble) / Math.log(1024.0)); + double scale = Math.pow(1024.0, ordinal); + double scaled = asDouble / scale; + String formatted = TWO_DIGIT_FORMAT.format(scaled); + try { + return formatted + " " + BYTE_SCALE_SUFFIXES[ordinal]; + } catch (IndexOutOfBoundsException e) { + //huge number? + return String.valueOf(asDouble); + } + } + + /** + * Create a string representation of an array joined by the given separator + * @param strs The array of items + * @param separator The separator + * @return The string representation. + */ + public static String join(T[] strs, String separator) { + return join(Arrays.asList(strs), separator); + } + + /** + * Create a string representation of a collection joined by the given separator + * @param collection The list of items + * @param separator The separator + * @return The string representation. + */ + public static String join(Collection collection, String separator) { + Objects.requireNonNull(collection); + StringBuilder sb = new StringBuilder(); + Iterator iter = collection.iterator(); + while (iter.hasNext()) { + sb.append(iter.next()); + if (iter.hasNext()) + sb.append(separator); + } + return sb.toString(); + } + + /** + * Converts a {@code Map} class into a string, concatenating keys and values + * Example: + * {@code mkString({ key: "hello", keyTwo: "hi" }, "|START|", "|END|", "=", ",") + * => "|START|key=hello,keyTwo=hi|END|"} + */ + public static String mkString(Map map, String begin, String end, + String keyValueSeparator, String elementSeparator) { + StringBuilder bld = new StringBuilder(); + bld.append(begin); + String prefix = ""; + for (Map.Entry entry : map.entrySet()) { + bld.append(prefix).append(entry.getKey()). + append(keyValueSeparator).append(entry.getValue()); + prefix = elementSeparator; + } + bld.append(end); + return bld.toString(); + } + + /** + * Converts an extensions string into a {@code Map}. 
+ * + * Example: + * {@code parseMap("key=hey,keyTwo=hi,keyThree=hello", "=", ",") => { key: "hey", keyTwo: "hi", keyThree: "hello" }} + * + */ + public static Map parseMap(String mapStr, String keyValueSeparator, String elementSeparator) { + Map map = new HashMap<>(); + + if (!mapStr.isEmpty()) { + String[] attrvals = mapStr.split(elementSeparator); + for (String attrval : attrvals) { + String[] array = attrval.split(keyValueSeparator, 2); + map.put(array[0], array[1]); + } + } + return map; + } + + /** + * Read a properties file from the given path + * @param filename The path of the file to read + * @return the loaded properties + */ + public static Properties loadProps(String filename) throws IOException { + return loadProps(filename, null); + } + + /** + * Read a properties file from the given path + * @param filename The path of the file to read + * @param onlyIncludeKeys When non-null, only return values associated with these keys and ignore all others + * @return the loaded properties + */ + public static Properties loadProps(String filename, List onlyIncludeKeys) throws IOException { + Properties props = new Properties(); + + if (filename != null) { + try (InputStream propStream = Files.newInputStream(Paths.get(filename))) { + props.load(propStream); + } + } else { + System.out.println("Did not load any properties since the property file is not specified"); + } + + if (onlyIncludeKeys == null || onlyIncludeKeys.isEmpty()) + return props; + Properties requestedProps = new Properties(); + onlyIncludeKeys.forEach(key -> { + String value = props.getProperty(key); + if (value != null) + requestedProps.setProperty(key, value); + }); + return requestedProps; + } + + /** + * Converts a Properties object to a Map, calling {@link #toString} to ensure all keys and values + * are Strings. + */ + public static Map propsToStringMap(Properties props) { + Map result = new HashMap<>(); + for (Map.Entry entry : props.entrySet()) + result.put(entry.getKey().toString(), entry.getValue().toString()); + return result; + } + + /** + * Get the stack trace from an exception as a string + */ + public static String stackTrace(Throwable e) { + StringWriter sw = new StringWriter(); + PrintWriter pw = new PrintWriter(sw); + e.printStackTrace(pw); + return sw.toString(); + } + + /** + * Read a buffer into a Byte array for the given offset and length + */ + public static byte[] readBytes(ByteBuffer buffer, int offset, int length) { + byte[] dest = new byte[length]; + if (buffer.hasArray()) { + System.arraycopy(buffer.array(), buffer.arrayOffset() + offset, dest, 0, length); + } else { + buffer.mark(); + buffer.position(offset); + buffer.get(dest); + buffer.reset(); + } + return dest; + } + + /** + * Read the given byte buffer into a Byte array + */ + public static byte[] readBytes(ByteBuffer buffer) { + return Utils.readBytes(buffer, 0, buffer.limit()); + } + + /** + * Read a file as string and return the content. The file is treated as a stream and no seek is performed. + * This allows the program to read from a regular file as well as from a pipe/fifo. + */ + public static String readFileAsString(String path) throws IOException { + try { + byte[] allBytes = Files.readAllBytes(Paths.get(path)); + return new String(allBytes, StandardCharsets.UTF_8); + } catch (IOException ex) { + throw new IOException("Unable to read file " + path, ex); + } + } + + /** + * Check if the given ByteBuffer capacity + * @param existingBuffer ByteBuffer capacity to check + * @param newLength new length for the ByteBuffer. 
+ * returns ByteBuffer + */ + public static ByteBuffer ensureCapacity(ByteBuffer existingBuffer, int newLength) { + if (newLength > existingBuffer.capacity()) { + ByteBuffer newBuffer = ByteBuffer.allocate(newLength); + existingBuffer.flip(); + newBuffer.put(existingBuffer); + return newBuffer; + } + return existingBuffer; + } + + /** + * Creates a set + * @param elems the elements + * @param the type of element + * @return Set + */ + @SafeVarargs + public static Set mkSet(T... elems) { + Set result = new HashSet<>((int) (elems.length / 0.75) + 1); + for (T elem : elems) + result.add(elem); + return result; + } + + /** + * Creates a sorted set + * @param elems the elements + * @param the type of element, must be comparable + * @return SortedSet + */ + @SafeVarargs + public static > SortedSet mkSortedSet(T... elems) { + SortedSet result = new TreeSet<>(); + for (T elem : elems) + result.add(elem); + return result; + } + + /** + * Creates a map entry (for use with {@link Utils#mkMap(java.util.Map.Entry[])}) + * + * @param k The key + * @param v The value + * @param The key type + * @param The value type + * @return An entry + */ + public static Map.Entry mkEntry(final K k, final V v) { + return new AbstractMap.SimpleEntry<>(k, v); + } + + /** + * Creates a map from a sequence of entries + * + * @param entries The entries to map + * @param The key type + * @param The value type + * @return A map + */ + @SafeVarargs + public static Map mkMap(final Map.Entry... entries) { + final LinkedHashMap result = new LinkedHashMap<>(); + for (final Map.Entry entry : entries) { + result.put(entry.getKey(), entry.getValue()); + } + return result; + } + + /** + * Creates a {@link Properties} from a map + * + * @param properties A map of properties to add + * @return The properties object + */ + public static Properties mkProperties(final Map properties) { + final Properties result = new Properties(); + for (final Map.Entry entry : properties.entrySet()) { + result.setProperty(entry.getKey(), entry.getValue()); + } + return result; + } + + /** + * Creates a {@link Properties} from a map + * + * @param properties A map of properties to add + * @return The properties object + */ + public static Properties mkObjectProperties(final Map properties) { + final Properties result = new Properties(); + for (final Map.Entry entry : properties.entrySet()) { + result.put(entry.getKey(), entry.getValue()); + } + return result; + } + + /** + * Recursively delete the given file/directory and any subfiles (if any exist) + * + * @param rootFile The root file at which to begin deleting + */ + public static void delete(final File rootFile) throws IOException { + if (rootFile == null) + return; + Files.walkFileTree(rootFile.toPath(), new SimpleFileVisitor() { + @Override + public FileVisitResult visitFileFailed(Path path, IOException exc) throws IOException { + // If the root path did not exist, ignore the error; otherwise throw it. 
+ if (exc instanceof NoSuchFileException && path.toFile().equals(rootFile)) + return FileVisitResult.TERMINATE; + throw exc; + } + + @Override + public FileVisitResult visitFile(Path path, BasicFileAttributes attrs) throws IOException { + Files.delete(path); + return FileVisitResult.CONTINUE; + } + + @Override + public FileVisitResult postVisitDirectory(Path path, IOException exc) throws IOException { + // KAFKA-8999: if there's an exception thrown previously already, we should throw it + if (exc != null) { + throw exc; + } + + Files.delete(path); + return FileVisitResult.CONTINUE; + } + }); + } + + /** + * Returns an empty collection if this list is null + * @param other + * @return + */ + public static List safe(List other) { + return other == null ? Collections.emptyList() : other; + } + + /** + * Get the ClassLoader which loaded Kafka. + */ + public static ClassLoader getKafkaClassLoader() { + return Utils.class.getClassLoader(); + } + + /** + * Get the Context ClassLoader on this thread or, if not present, the ClassLoader that + * loaded Kafka. + * + * This should be used whenever passing a ClassLoader to Class.forName + */ + public static ClassLoader getContextOrKafkaClassLoader() { + ClassLoader cl = Thread.currentThread().getContextClassLoader(); + if (cl == null) + return getKafkaClassLoader(); + else + return cl; + } + + /** + * Attempts to move source to target atomically and falls back to a non-atomic move if it fails. + * This function also flushes the parent directory to guarantee crash consistency. + * + * @throws IOException if both atomic and non-atomic moves fail + */ + public static void atomicMoveWithFallback(Path source, Path target) throws IOException { + atomicMoveWithFallback(source, target, true); + } + + /** + * Attempts to move source to target atomically and falls back to a non-atomic move if it fails. + * This function allows callers to decide whether to flush the parent directory. This is needed + * when a sequence of atomicMoveWithFallback is called for the same directory and we don't want + * to repeatedly flush the same parent directory. + * + * @throws IOException if both atomic and non-atomic moves fail + */ + public static void atomicMoveWithFallback(Path source, Path target, boolean needFlushParentDir) throws IOException { + try { + Files.move(source, target, StandardCopyOption.ATOMIC_MOVE); + } catch (IOException outer) { + try { + Files.move(source, target, StandardCopyOption.REPLACE_EXISTING); + log.debug("Non-atomic move of {} to {} succeeded after atomic move failed due to {}", source, target, + outer.getMessage()); + } catch (IOException inner) { + inner.addSuppressed(outer); + throw inner; + } + } finally { + if (needFlushParentDir) { + flushDir(target.toAbsolutePath().normalize().getParent()); + } + } + } + + /** + * Flushes dirty directories to guarantee crash consistency. + * + * Note: We don't fsync directories on Windows OS because otherwise it'll throw AccessDeniedException (KAFKA-13391) + * + * @throws IOException if flushing the directory fails. + */ + public static void flushDir(Path path) throws IOException { + if (path != null && !OperatingSystem.IS_WINDOWS) { + try (FileChannel dir = FileChannel.open(path, StandardOpenOption.READ)) { + dir.force(true); + } + } + } + + /** + * Closes all the provided closeables. + * @throws IOException if any of the close methods throws an IOException. + * The first IOException is thrown with subsequent exceptions + * added as suppressed exceptions. + */ + public static void closeAll(Closeable... 
closeables) throws IOException { + IOException exception = null; + for (Closeable closeable : closeables) { + try { + if (closeable != null) + closeable.close(); + } catch (IOException e) { + if (exception != null) + exception.addSuppressed(e); + else + exception = e; + } + } + if (exception != null) + throw exception; + } + + /** + * An {@link AutoCloseable} interface without a throws clause in the signature + * + * This is used with lambda expressions in try-with-resources clauses + * to avoid casting un-checked exceptions to checked exceptions unnecessarily. + */ + @FunctionalInterface + public interface UncheckedCloseable extends AutoCloseable { + @Override + void close(); + } + + /** + * Closes {@code closeable} and if an exception is thrown, it is logged at the WARN level. + */ + public static void closeQuietly(AutoCloseable closeable, String name) { + if (closeable != null) { + try { + closeable.close(); + } catch (Throwable t) { + log.warn("Failed to close {} with type {}", name, closeable.getClass().getName(), t); + } + } + } + + public static void closeQuietly(AutoCloseable closeable, String name, AtomicReference firstException) { + if (closeable != null) { + try { + closeable.close(); + } catch (Throwable t) { + firstException.compareAndSet(null, t); + log.error("Failed to close {} with type {}", name, closeable.getClass().getName(), t); + } + } + } + + /** + * close all closable objects even if one of them throws exception. + * @param firstException keeps the first exception + * @param name message of closing those objects + * @param closeables closable objects + */ + public static void closeAllQuietly(AtomicReference firstException, String name, AutoCloseable... closeables) { + for (AutoCloseable closeable : closeables) closeQuietly(closeable, name, firstException); + } + + /** + * A cheap way to deterministically convert a number to a positive value. When the input is + * positive, the original value is returned. When the input number is negative, the returned + * positive value is the original value bit AND against 0x7fffffff which is not its absolute + * value. + * + * Note: changing this method in the future will possibly cause partition selection not to be + * compatible with the existing messages already placed on a partition since it is used + * in producer's {@link org.apache.kafka.clients.producer.internals.DefaultPartitioner} + * + * @param number a given number + * @return a positive number. + */ + public static int toPositive(int number) { + return number & 0x7fffffff; + } + + /** + * Read a size-delimited byte buffer starting at the given offset. + * @param buffer Buffer containing the size and data + * @param start Offset in the buffer to read from + * @return A slice of the buffer containing only the delimited data (excluding the size) + */ + public static ByteBuffer sizeDelimited(ByteBuffer buffer, int start) { + int size = buffer.getInt(start); + if (size < 0) { + return null; + } else { + ByteBuffer b = buffer.duplicate(); + b.position(start + 4); + b = b.slice(); + b.limit(size); + b.rewind(); + return b; + } + } + + /** + * Read data from the channel to the given byte buffer until there are no bytes remaining in the buffer. If the end + * of the file is reached while there are bytes remaining in the buffer, an EOFException is thrown. 
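A hedged caller-side sketch of `readFullyOrFail` as just described; the path and header size are invented, and the parameters are documented immediately below.

```java
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.FileChannel;
import java.nio.file.Paths;
import java.nio.file.StandardOpenOption;

import org.apache.kafka.common.utils.Utils;

public class ReadFullyOrFailExample {
    public static void main(String[] args) throws IOException {
        // Hypothetical file; the read fails with EOFException unless at least 8 bytes exist at offset 0.
        try (FileChannel channel = FileChannel.open(Paths.get("/tmp/example-segment"), StandardOpenOption.READ)) {
            ByteBuffer header = ByteBuffer.allocate(8);
            Utils.readFullyOrFail(channel, header, 0, "example header");
            header.flip();
            System.out.println("First long in file: " + header.getLong());
        }
    }
}
```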
+ * + * @param channel File channel containing the data to read from + * @param destinationBuffer The buffer into which bytes are to be transferred + * @param position The file position at which the transfer is to begin; it must be non-negative + * @param description A description of what is being read, this will be included in the EOFException if it is thrown + * + * @throws IllegalArgumentException If position is negative + * @throws EOFException If the end of the file is reached while there are remaining bytes in the destination buffer + * @throws IOException If an I/O error occurs, see {@link FileChannel#read(ByteBuffer, long)} for details on the + * possible exceptions + */ + public static void readFullyOrFail(FileChannel channel, ByteBuffer destinationBuffer, long position, + String description) throws IOException { + if (position < 0) { + throw new IllegalArgumentException("The file channel position cannot be negative, but it is " + position); + } + int expectedReadBytes = destinationBuffer.remaining(); + readFully(channel, destinationBuffer, position); + if (destinationBuffer.hasRemaining()) { + throw new EOFException(String.format("Failed to read `%s` from file channel `%s`. Expected to read %d bytes, " + + "but reached end of file after reading %d bytes. Started read from position %d.", + description, channel, expectedReadBytes, expectedReadBytes - destinationBuffer.remaining(), position)); + } + } + + /** + * Read data from the channel to the given byte buffer until there are no bytes remaining in the buffer or the end + * of the file has been reached. + * + * @param channel File channel containing the data to read from + * @param destinationBuffer The buffer into which bytes are to be transferred + * @param position The file position at which the transfer is to begin; it must be non-negative + * + * @throws IllegalArgumentException If position is negative + * @throws IOException If an I/O error occurs, see {@link FileChannel#read(ByteBuffer, long)} for details on the + * possible exceptions + */ + public static void readFully(FileChannel channel, ByteBuffer destinationBuffer, long position) throws IOException { + if (position < 0) { + throw new IllegalArgumentException("The file channel position cannot be negative, but it is " + position); + } + long currentPosition = position; + int bytesRead; + do { + bytesRead = channel.read(destinationBuffer, currentPosition); + currentPosition += bytesRead; + } while (bytesRead != -1 && destinationBuffer.hasRemaining()); + } + + /** + * Read data from the input stream to the given byte buffer until there are no bytes remaining in the buffer or the + * end of the stream has been reached. 
+ * + * @param inputStream Input stream to read from + * @param destinationBuffer The buffer into which bytes are to be transferred (it must be backed by an array) + * + * @throws IOException If an I/O error occurs + */ + public static void readFully(InputStream inputStream, ByteBuffer destinationBuffer) throws IOException { + if (!destinationBuffer.hasArray()) + throw new IllegalArgumentException("destinationBuffer must be backed by an array"); + int initialOffset = destinationBuffer.arrayOffset() + destinationBuffer.position(); + byte[] array = destinationBuffer.array(); + int length = destinationBuffer.remaining(); + int totalBytesRead = 0; + do { + int bytesRead = inputStream.read(array, initialOffset + totalBytesRead, length - totalBytesRead); + if (bytesRead == -1) + break; + totalBytesRead += bytesRead; + } while (length > totalBytesRead); + destinationBuffer.position(destinationBuffer.position() + totalBytesRead); + } + + public static void writeFully(FileChannel channel, ByteBuffer sourceBuffer) throws IOException { + while (sourceBuffer.hasRemaining()) + channel.write(sourceBuffer); + } + + /** + * Trying to write data in source buffer to a {@link TransferableChannel}, we may need to call this method multiple + * times since this method doesn't ensure the data in the source buffer can be fully written to the destination channel. + * + * @param destChannel The destination channel + * @param position From which the source buffer will be written + * @param length The max size of bytes can be written + * @param sourceBuffer The source buffer + * + * @return The length of the actual written data + * @throws IOException If an I/O error occurs + */ + public static long tryWriteTo(TransferableChannel destChannel, + int position, + int length, + ByteBuffer sourceBuffer) throws IOException { + + ByteBuffer dup = sourceBuffer.duplicate(); + dup.position(position); + dup.limit(position + length); + return destChannel.write(dup); + } + + /** + * Write the contents of a buffer to an output stream. The bytes are copied from the current position + * in the buffer. 
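A short, hedged sketch of `writeTo` as just described, ahead of its parameter documentation below.

```java
import java.io.ByteArrayOutputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;

import org.apache.kafka.common.utils.Utils;

public class WriteToExample {
    public static void main(String[] args) throws IOException {
        ByteBuffer buffer = ByteBuffer.wrap("hello".getBytes(StandardCharsets.UTF_8));
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        try (DataOutputStream out = new DataOutputStream(bytes)) {
            // Copies buffer.remaining() bytes starting at the buffer's current position.
            Utils.writeTo(out, buffer, buffer.remaining());
        }
        System.out.println(new String(bytes.toByteArray(), StandardCharsets.UTF_8)); // prints "hello"
    }
}
```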
+ * @param out The output to write to + * @param buffer The buffer to write from + * @param length The number of bytes to write + * @throws IOException For any errors writing to the output + */ + public static void writeTo(DataOutput out, ByteBuffer buffer, int length) throws IOException { + if (buffer.hasArray()) { + out.write(buffer.array(), buffer.position() + buffer.arrayOffset(), length); + } else { + int pos = buffer.position(); + for (int i = pos; i < length + pos; i++) + out.writeByte(buffer.get(i)); + } + } + + public static List toList(Iterable iterable) { + return toList(iterable.iterator()); + } + + public static List toList(Iterator iterator) { + List res = new ArrayList<>(); + while (iterator.hasNext()) + res.add(iterator.next()); + return res; + } + + public static List toList(Iterator iterator, Predicate predicate) { + List res = new ArrayList<>(); + while (iterator.hasNext()) { + T e = iterator.next(); + if (predicate.test(e)) { + res.add(e); + } + } + return res; + } + + public static List concatListsUnmodifiable(List left, List right) { + return concatLists(left, right, Collections::unmodifiableList); + } + + public static List concatLists(List left, List right, Function, List> finisher) { + return Stream.concat(left.stream(), right.stream()) + .collect(Collectors.collectingAndThen(Collectors.toList(), finisher)); + } + + public static int to32BitField(final Set bytes) { + int value = 0; + for (final byte b : bytes) + value |= 1 << checkRange(b); + return value; + } + + private static byte checkRange(final byte i) { + if (i > 31) + throw new IllegalArgumentException("out of range: i>31, i = " + i); + if (i < 0) + throw new IllegalArgumentException("out of range: i<0, i = " + i); + return i; + } + + public static Set from32BitField(final int intValue) { + Set result = new HashSet<>(); + for (int itr = intValue, count = 0; itr != 0; itr >>>= 1) { + if ((itr & 1) != 0) + result.add((byte) count); + count++; + } + return result; + } + + public static Map transformMap( + Map map, + Function keyMapper, + Function valueMapper) { + return map.entrySet().stream().collect( + Collectors.toMap( + entry -> keyMapper.apply(entry.getKey()), + entry -> valueMapper.apply(entry.getValue()) + ) + ); + } + + /** + * A Collector that offers two kinds of convenience: + * 1. You can specify the concrete type of the returned Map + * 2. You can turn a stream of Entries directly into a Map without having to mess with a key function + * and a value function. In particular, this is handy if all you need to do is apply a filter to a Map's entries. + * + * + * One thing to be wary of: These types are too "distant" for IDE type checkers to warn you if you + * try to do something like build a TreeMap of non-Comparable elements. You'd get a runtime exception for that. + * + * @param mapSupplier The constructor for your concrete map type. + * @param The Map key type + * @param The Map value type + * @param The type of the Map itself. 
+ * @return new Collector, M, M> + */ + public static > Collector, M, M> entriesToMap(final Supplier mapSupplier) { + return new Collector, M, M>() { + @Override + public Supplier supplier() { + return mapSupplier; + } + + @Override + public BiConsumer> accumulator() { + return (map, entry) -> map.put(entry.getKey(), entry.getValue()); + } + + @Override + public BinaryOperator combiner() { + return (map, map2) -> { + map.putAll(map2); + return map; + }; + } + + @Override + public Function finisher() { + return map -> map; + } + + @Override + public Set characteristics() { + return EnumSet.of(Characteristics.UNORDERED, Characteristics.IDENTITY_FINISH); + } + }; + } + + @SafeVarargs + public static Set union(final Supplier> constructor, final Set... set) { + final Set result = constructor.get(); + for (final Set s : set) { + result.addAll(s); + } + return result; + } + + @SafeVarargs + public static Set intersection(final Supplier> constructor, final Set first, final Set... set) { + final Set result = constructor.get(); + result.addAll(first); + for (final Set s : set) { + result.retainAll(s); + } + return result; + } + + public static Set diff(final Supplier> constructor, final Set left, final Set right) { + final Set result = constructor.get(); + result.addAll(left); + result.removeAll(right); + return result; + } + + public static Map filterMap(final Map map, final Predicate> filterPredicate) { + return map.entrySet().stream().filter(filterPredicate).collect(Collectors.toMap(Entry::getKey, Entry::getValue)); + } + + /** + * Convert a properties to map. All keys in properties must be string type. Otherwise, a ConfigException is thrown. + * @param properties to be converted + * @return a map including all elements in properties + */ + public static Map propsToMap(Properties properties) { + Map map = new HashMap<>(properties.size()); + for (Map.Entry entry : properties.entrySet()) { + if (entry.getKey() instanceof String) { + String k = (String) entry.getKey(); + map.put(k, properties.get(k)); + } else { + throw new ConfigException(entry.getKey().toString(), entry.getValue(), "Key must be a string."); + } + } + return map; + } + + /** + * Convert timestamp to an epoch value + * @param timestamp the timestamp to be converted, the accepted formats are: + * (1) yyyy-MM-dd'T'HH:mm:ss.SSS, ex: 2020-11-10T16:51:38.198 + * (2) yyyy-MM-dd'T'HH:mm:ss.SSSZ, ex: 2020-11-10T16:51:38.198+0800 + * (3) yyyy-MM-dd'T'HH:mm:ss.SSSX, ex: 2020-11-10T16:51:38.198+08 + * (4) yyyy-MM-dd'T'HH:mm:ss.SSSXX, ex: 2020-11-10T16:51:38.198+0800 + * (5) yyyy-MM-dd'T'HH:mm:ss.SSSXXX, ex: 2020-11-10T16:51:38.198+08:00 + * + * @return epoch value of a given timestamp (i.e. the number of milliseconds since January 1, 1970, 00:00:00 GMT) + * @throws ParseException for timestamp that doesn't follow ISO8601 format or the format is not expected + */ + public static long getDateTime(String timestamp) throws ParseException, IllegalArgumentException { + if (timestamp == null) { + throw new IllegalArgumentException("Error parsing timestamp with null value"); + } + + final String[] timestampParts = timestamp.split("T"); + if (timestampParts.length < 2) { + throw new ParseException("Error parsing timestamp. 
It does not contain a 'T' according to ISO8601 format", timestamp.length()); + } + + final String secondPart = timestampParts[1]; + if (!(secondPart.contains("+") || secondPart.contains("-") || secondPart.contains("Z"))) { + timestamp = timestamp + "Z"; + } + + SimpleDateFormat simpleDateFormat = new SimpleDateFormat(); + // strictly parsing the date/time format + simpleDateFormat.setLenient(false); + try { + simpleDateFormat.applyPattern("yyyy-MM-dd'T'HH:mm:ss.SSSXXX"); + final Date date = simpleDateFormat.parse(timestamp); + return date.getTime(); + } catch (final ParseException e) { + simpleDateFormat.applyPattern("yyyy-MM-dd'T'HH:mm:ss.SSSX"); + final Date date = simpleDateFormat.parse(timestamp); + return date.getTime(); + } + } + + @SuppressWarnings("unchecked") + public static Iterator covariantCast(Iterator iterator) { + return (Iterator) iterator; + } + + /** + * Checks if a string is null, empty or whitespace only. + * @param str a string to be checked + * @return true if the string is null, empty or whitespace only; otherwise, return false. + */ + public static boolean isBlank(String str) { + return str == null || str.trim().isEmpty(); + } + + public static Map initializeMap(Collection keys, Supplier valueSupplier) { + Map res = new HashMap<>(keys.size()); + keys.forEach(key -> res.put(key, valueSupplier.get())); + return res; + } + +} diff --git a/clients/src/main/java/org/apache/kafka/server/authorizer/AclCreateResult.java b/clients/src/main/java/org/apache/kafka/server/authorizer/AclCreateResult.java new file mode 100644 index 0000000..70b9c00 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/server/authorizer/AclCreateResult.java @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.server.authorizer; + +import org.apache.kafka.common.annotation.InterfaceStability; +import org.apache.kafka.common.errors.ApiException; + +import java.util.Optional; + +@InterfaceStability.Evolving +public class AclCreateResult { + public static final AclCreateResult SUCCESS = new AclCreateResult(); + + private final ApiException exception; + + private AclCreateResult() { + this(null); + } + + public AclCreateResult(ApiException exception) { + this.exception = exception; + } + + /** + * Returns any exception during create. If exception is empty, the request has succeeded. + */ + public Optional exception() { + return exception == null ? 
Optional.empty() : Optional.of(exception); + } +} diff --git a/clients/src/main/java/org/apache/kafka/server/authorizer/AclDeleteResult.java b/clients/src/main/java/org/apache/kafka/server/authorizer/AclDeleteResult.java new file mode 100644 index 0000000..994d6fd --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/server/authorizer/AclDeleteResult.java @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.server.authorizer; + +import java.util.Collections; +import java.util.Collection; +import java.util.Optional; + +import org.apache.kafka.common.acl.AclBinding; +import org.apache.kafka.common.annotation.InterfaceStability; +import org.apache.kafka.common.errors.ApiException; + +@InterfaceStability.Evolving +public class AclDeleteResult { + private final ApiException exception; + private final Collection aclBindingDeleteResults; + + public AclDeleteResult(ApiException exception) { + this(Collections.emptySet(), exception); + } + + public AclDeleteResult(Collection deleteResults) { + this(deleteResults, null); + } + + private AclDeleteResult(Collection deleteResults, ApiException exception) { + this.aclBindingDeleteResults = deleteResults; + this.exception = exception; + } + + /** + * Returns any exception while attempting to match ACL filter to delete ACLs. + * If exception is empty, filtering has succeeded. See {@link #aclBindingDeleteResults()} + * for deletion results for each filter. + */ + public Optional exception() { + return exception == null ? Optional.empty() : Optional.of(exception); + } + + /** + * Returns delete result for each matching ACL binding. + */ + public Collection aclBindingDeleteResults() { + return aclBindingDeleteResults; + } + + + /** + * Delete result for each ACL binding that matched a delete filter. + */ + public static class AclBindingDeleteResult { + private final AclBinding aclBinding; + private final ApiException exception; + + public AclBindingDeleteResult(AclBinding aclBinding) { + this(aclBinding, null); + } + + public AclBindingDeleteResult(AclBinding aclBinding, ApiException exception) { + this.aclBinding = aclBinding; + this.exception = exception; + } + + /** + * Returns ACL binding that matched the delete filter. If {@link #exception()} is + * empty, the ACL binding was successfully deleted. + */ + public AclBinding aclBinding() { + return aclBinding; + } + + /** + * Returns any exception that resulted in failure to delete ACL binding. + * If exception is empty, the ACL binding was successfully deleted. + */ + public Optional exception() { + return exception == null ? 
Optional.empty() : Optional.of(exception); + } + } +} diff --git a/clients/src/main/java/org/apache/kafka/server/authorizer/Action.java b/clients/src/main/java/org/apache/kafka/server/authorizer/Action.java new file mode 100644 index 0000000..60af34b --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/server/authorizer/Action.java @@ -0,0 +1,129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.server.authorizer; + +import java.util.Objects; +import org.apache.kafka.common.acl.AclOperation; +import org.apache.kafka.common.annotation.InterfaceStability; +import org.apache.kafka.common.resource.ResourcePattern; + +@InterfaceStability.Evolving +public class Action { + + private final ResourcePattern resourcePattern; + private final AclOperation operation; + private final int resourceReferenceCount; + private final boolean logIfAllowed; + private final boolean logIfDenied; + + /** + * @param operation non-null operation being performed + * @param resourcePattern non-null resource pattern on which this action is being performed + */ + public Action(AclOperation operation, + ResourcePattern resourcePattern, + int resourceReferenceCount, + boolean logIfAllowed, + boolean logIfDenied) { + this.operation = Objects.requireNonNull(operation, "operation can't be null"); + this.resourcePattern = Objects.requireNonNull(resourcePattern, "resourcePattern can't be null"); + this.logIfAllowed = logIfAllowed; + this.logIfDenied = logIfDenied; + this.resourceReferenceCount = resourceReferenceCount; + } + + /** + * @return a non-null resource pattern on which this action is being performed + */ + public ResourcePattern resourcePattern() { + return resourcePattern; + } + + /** + * + * @return a non-null operation being performed + */ + public AclOperation operation() { + return operation; + } + + /** + * Indicates if audit logs tracking ALLOWED access should include this action if result is + * ALLOWED. The flag is true if access to a resource is granted while processing the request as a + * result of this authorization. The flag is false only for requests used to describe access where + * no operation on the resource is actually performed based on the authorization result. + */ + public boolean logIfAllowed() { + return logIfAllowed; + } + + /** + * Indicates if audit logs tracking DENIED access should include this action if result is + * DENIED. The flag is true if access to a resource was explicitly requested and request + * is denied as a result of this authorization request. The flag is false if request was + * filtering out authorized resources (e.g. to subscribe to regex pattern). 
The flag is also + * false if this is an optional authorization where an alternative resource authorization is + * applied if this fails (e.g. Cluster:Create which is subsequently overridden by Topic:Create). + */ + public boolean logIfDenied() { + return logIfDenied; + } + + /** + * Number of times the resource being authorized is referenced within the request. For example, a single + * request may reference `n` topic partitions of the same topic. Brokers will authorize the topic once + * with `resourceReferenceCount=n`. Authorizers may include the count in audit logs. + */ + public int resourceReferenceCount() { + return resourceReferenceCount; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (!(o instanceof Action)) { + return false; + } + + Action that = (Action) o; + return Objects.equals(this.resourcePattern, that.resourcePattern) && + Objects.equals(this.operation, that.operation) && + this.resourceReferenceCount == that.resourceReferenceCount && + this.logIfAllowed == that.logIfAllowed && + this.logIfDenied == that.logIfDenied; + + } + + @Override + public int hashCode() { + return Objects.hash(resourcePattern, operation, resourceReferenceCount, logIfAllowed, logIfDenied); + } + + @Override + public String toString() { + return "Action(" + + ", resourcePattern='" + resourcePattern + '\'' + + ", operation='" + operation + '\'' + + ", resourceReferenceCount='" + resourceReferenceCount + '\'' + + ", logIfAllowed='" + logIfAllowed + '\'' + + ", logIfDenied='" + logIfDenied + '\'' + + ')'; + } +} diff --git a/clients/src/main/java/org/apache/kafka/server/authorizer/AuthorizableRequestContext.java b/clients/src/main/java/org/apache/kafka/server/authorizer/AuthorizableRequestContext.java new file mode 100644 index 0000000..f68b938 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/server/authorizer/AuthorizableRequestContext.java @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.server.authorizer; + +import java.net.InetAddress; +import org.apache.kafka.common.annotation.InterfaceStability; +import org.apache.kafka.common.security.auth.KafkaPrincipal; +import org.apache.kafka.common.security.auth.SecurityProtocol; + +/** + * Request context interface that provides data from request header as well as connection + * and authentication information to plugins. + */ +@InterfaceStability.Evolving +public interface AuthorizableRequestContext { + + /** + * Returns name of listener on which request was received. + */ + String listenerName(); + + /** + * Returns the security protocol for the listener on which request was received. 
+ */ + SecurityProtocol securityProtocol(); + + /** + * Returns authenticated principal for the connection on which request was received. + */ + KafkaPrincipal principal(); + + /** + * Returns client IP address from which request was sent. + */ + InetAddress clientAddress(); + + /** + * 16-bit API key of the request from the request header. See + * https://kafka.apache.org/protocol#protocol_api_keys for request types. + */ + int requestType(); + + /** + * Returns the request version from the request header. + */ + int requestVersion(); + + /** + * Returns the client id from the request header. + */ + String clientId(); + + /** + * Returns the correlation id from the request header. + */ + int correlationId(); +} diff --git a/clients/src/main/java/org/apache/kafka/server/authorizer/AuthorizationResult.java b/clients/src/main/java/org/apache/kafka/server/authorizer/AuthorizationResult.java new file mode 100644 index 0000000..d4ad15d --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/server/authorizer/AuthorizationResult.java @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.server.authorizer; + +import org.apache.kafka.common.annotation.InterfaceStability; + +@InterfaceStability.Evolving +public enum AuthorizationResult { + ALLOWED, + DENIED +} diff --git a/clients/src/main/java/org/apache/kafka/server/authorizer/Authorizer.java b/clients/src/main/java/org/apache/kafka/server/authorizer/Authorizer.java new file mode 100644 index 0000000..17348a7 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/server/authorizer/Authorizer.java @@ -0,0 +1,287 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.server.authorizer; + +import org.apache.kafka.common.Configurable; +import org.apache.kafka.common.Endpoint; +import org.apache.kafka.common.acl.AccessControlEntryFilter; +import org.apache.kafka.common.acl.AclBinding; +import org.apache.kafka.common.acl.AclBindingFilter; +import org.apache.kafka.common.acl.AclOperation; +import org.apache.kafka.common.acl.AclPermissionType; +import org.apache.kafka.common.annotation.InterfaceStability; +import org.apache.kafka.common.resource.PatternType; +import org.apache.kafka.common.resource.ResourcePattern; +import org.apache.kafka.common.resource.ResourcePatternFilter; +import org.apache.kafka.common.resource.ResourceType; +import org.apache.kafka.common.security.auth.KafkaPrincipal; +import org.apache.kafka.common.utils.SecurityUtils; + +import java.io.Closeable; +import java.util.Collections; +import java.util.EnumMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.CompletionStage; + +/** + * + * Pluggable authorizer interface for Kafka brokers. + * + * Startup sequence in brokers: + *
 + *   1. Broker creates authorizer instance if configured in `authorizer.class.name`.
 + *   2. Broker configures and starts authorizer instance. Authorizer implementation starts loading its metadata.
 + *   3. Broker starts SocketServer to accept connections and process requests.
 + *   4. For each listener, SocketServer waits for authorization metadata to be available in the
 + *      authorizer before accepting connections. The future returned by {@link #start(AuthorizerServerInfo)}
 + *      for each listener must return only when authorizer is ready to authorize requests on the listener.
 + *   5. Broker accepts connections. For each connection, broker performs authentication and then accepts Kafka requests.
 + *      For each request, broker invokes {@link #authorize(AuthorizableRequestContext, List)} to authorize
 + *      actions performed by the request.
 + *
 + * Authorizer implementation class may optionally implement @{@link org.apache.kafka.common.Reconfigurable}
 + * to enable dynamic reconfiguration without restarting the broker.
 + *
 + * Threading model:
 + *   - All authorizer operations including authorization and ACL updates must be thread-safe.
 + *   - ACL update methods are asynchronous. Implementations with low update latency may return a
 + *     completed future using {@link java.util.concurrent.CompletableFuture#completedFuture(Object)}.
 + *     This ensures that the request will be handled synchronously by the caller without using a
 + *     purgatory to wait for the result. If ACL updates require remote communication which may block,
 + *     return a future that is completed asynchronously when the remote operation completes. This enables
 + *     the caller to process other requests on the request threads without blocking.
 + *   - Any threads or thread pools used for processing remote operations asynchronously can be started during
 + *     {@link #start(AuthorizerServerInfo)}. These threads must be shutdown during {@link Authorizer#close()}.
 + *
            + */ +@InterfaceStability.Evolving +public interface Authorizer extends Configurable, Closeable { + + /** + * Starts loading authorization metadata and returns futures that can be used to wait until + * metadata for authorizing requests on each listener is available. Each listener will be + * started only after its metadata is available and authorizer is ready to start authorizing + * requests on that listener. + * + * @param serverInfo Metadata for the broker including broker id and listener endpoints + * @return CompletionStage for each endpoint that completes when authorizer is ready to + * start authorizing requests on that listener. + */ + Map> start(AuthorizerServerInfo serverInfo); + + /** + * Authorizes the specified action. Additional metadata for the action is specified + * in `requestContext`. + *
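To make the startup, threading, and authorization contracts described above concrete, here is a minimal, hedged sketch of an implementation that has no external metadata to load and allows every action; it is illustrative only and not part of this change.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage;

import org.apache.kafka.common.Endpoint;
import org.apache.kafka.common.acl.AclBinding;
import org.apache.kafka.common.acl.AclBindingFilter;
import org.apache.kafka.server.authorizer.AclCreateResult;
import org.apache.kafka.server.authorizer.AclDeleteResult;
import org.apache.kafka.server.authorizer.Action;
import org.apache.kafka.server.authorizer.AuthorizableRequestContext;
import org.apache.kafka.server.authorizer.AuthorizationResult;
import org.apache.kafka.server.authorizer.Authorizer;
import org.apache.kafka.server.authorizer.AuthorizerServerInfo;

public class AllowEverythingAuthorizer implements Authorizer {

    @Override
    public void configure(Map<String, ?> configs) {
        // Broker configs are passed here before start() is invoked.
    }

    @Override
    public Map<Endpoint, ? extends CompletionStage<Void>> start(AuthorizerServerInfo serverInfo) {
        // No metadata to load, so every listener can start immediately: return completed futures.
        Map<Endpoint, CompletableFuture<Void>> futures = new HashMap<>();
        for (Endpoint endpoint : serverInfo.endpoints())
            futures.put(endpoint, CompletableFuture.completedFuture(null));
        return futures;
    }

    @Override
    public List<AuthorizationResult> authorize(AuthorizableRequestContext requestContext, List<Action> actions) {
        // Synchronous and invoked on request threads: must be cheap and thread-safe.
        List<AuthorizationResult> results = new ArrayList<>(actions.size());
        for (Action action : actions)
            results.add(AuthorizationResult.ALLOWED);
        return results;
    }

    @Override
    public List<? extends CompletionStage<AclCreateResult>> createAcls(AuthorizableRequestContext requestContext,
                                                                       List<AclBinding> aclBindings) {
        // Low-latency update: completed futures let the caller finish the request synchronously.
        List<CompletableFuture<AclCreateResult>> results = new ArrayList<>(aclBindings.size());
        for (AclBinding binding : aclBindings)
            results.add(CompletableFuture.completedFuture(AclCreateResult.SUCCESS));
        return results;
    }

    @Override
    public List<? extends CompletionStage<AclDeleteResult>> deleteAcls(AuthorizableRequestContext requestContext,
                                                                       List<AclBindingFilter> aclBindingFilters) {
        List<CompletableFuture<AclDeleteResult>> results = new ArrayList<>(aclBindingFilters.size());
        for (AclBindingFilter filter : aclBindingFilters)
            results.add(CompletableFuture.completedFuture(new AclDeleteResult(Collections.emptySet())));
        return results;
    }

    @Override
    public Iterable<AclBinding> acls(AclBindingFilter filter) {
        // This sketch stores no ACL bindings.
        return Collections.emptyList();
    }

    @Override
    public void close() {
        // Any threads started in start() would be shut down here.
    }
}
```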

            + * This is a synchronous API designed for use with locally cached ACLs. Since this method is invoked on the + * request thread while processing each request, implementations of this method should avoid time-consuming + * remote communication that may block request threads. + * + * @param requestContext Request context including request type, security protocol and listener name + * @param actions Actions being authorized including resource and operation for each action + * @return List of authorization results for each action in the same order as the provided actions + */ + List authorize(AuthorizableRequestContext requestContext, List actions); + + /** + * Creates new ACL bindings. + *

            + * This is an asynchronous API that enables the caller to avoid blocking during the update. Implementations of this + * API can return completed futures using {@link java.util.concurrent.CompletableFuture#completedFuture(Object)} + * to process the update synchronously on the request thread. + * + * @param requestContext Request context if the ACL is being created by a broker to handle + * a client request to create ACLs. This may be null if ACLs are created directly in ZooKeeper + * using AclCommand. + * @param aclBindings ACL bindings to create + * + * @return Create result for each ACL binding in the same order as in the input list. Each result + * is returned as a CompletionStage that completes when the result is available. + */ + List> createAcls(AuthorizableRequestContext requestContext, List aclBindings); + + /** + * Deletes all ACL bindings that match the provided filters. + *
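Before the asynchronous contract is spelled out further below, a hedged caller-side sketch of `createAcls` as defined above; the principal, host, and topic name are invented for illustration.

```java
import java.util.Collections;
import java.util.List;
import java.util.concurrent.CompletionStage;

import org.apache.kafka.common.acl.AccessControlEntry;
import org.apache.kafka.common.acl.AclBinding;
import org.apache.kafka.common.acl.AclOperation;
import org.apache.kafka.common.acl.AclPermissionType;
import org.apache.kafka.common.resource.PatternType;
import org.apache.kafka.common.resource.ResourcePattern;
import org.apache.kafka.common.resource.ResourceType;
import org.apache.kafka.server.authorizer.AclCreateResult;
import org.apache.kafka.server.authorizer.AuthorizableRequestContext;
import org.apache.kafka.server.authorizer.Authorizer;

public class CreateAclsExample {
    static void grantRead(Authorizer authorizer, AuthorizableRequestContext context) {
        AclBinding binding = new AclBinding(
                new ResourcePattern(ResourceType.TOPIC, "payments", PatternType.LITERAL),
                new AccessControlEntry("User:alice", "*", AclOperation.READ, AclPermissionType.ALLOW));
        List<? extends CompletionStage<AclCreateResult>> results =
                authorizer.createAcls(context, Collections.singletonList(binding));
        // Each stage completes when the corresponding binding has been persisted (possibly immediately).
        results.get(0).toCompletableFuture().whenComplete((result, error) -> {
            if (error != null || result.exception().isPresent())
                System.err.println("Failed to create ACL binding " + binding);
        });
    }
}
```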

            + * This is an asynchronous API that enables the caller to avoid blocking during the update. Implementations of this + * API can return completed futures using {@link java.util.concurrent.CompletableFuture#completedFuture(Object)} + * to process the update synchronously on the request thread. + *

            + * Refer to the authorizer implementation docs for details on concurrent update guarantees. + * + * @param requestContext Request context if the ACL is being deleted by a broker to handle + * a client request to delete ACLs. This may be null if ACLs are deleted directly in ZooKeeper + * using AclCommand. + * @param aclBindingFilters Filters to match ACL bindings that are to be deleted + * + * @return Delete result for each filter in the same order as in the input list. + * Each result indicates which ACL bindings were actually deleted as well as any + * bindings that matched but could not be deleted. Each result is returned as a + * CompletionStage that completes when the result is available. + */ + List> deleteAcls(AuthorizableRequestContext requestContext, List aclBindingFilters); + + /** + * Returns ACL bindings which match the provided filter. + *

            + * This is a synchronous API designed for use with locally cached ACLs. This method is invoked on the request + * thread while processing DescribeAcls requests and should avoid time-consuming remote communication that may + * block request threads. + * + * @return Iterator for ACL bindings, which may be populated lazily. + */ + Iterable acls(AclBindingFilter filter); + + /** + * Check if the caller is authorized to perform the given ACL operation on at least one + * resource of the given type. + * + * Custom authorizer implementations should consider overriding this default implementation because: + * 1. The default implementation iterates all AclBindings multiple times, without any caching + * by principal, host, operation, permission types, and resource types. More efficient + * implementations may be added in custom authorizers that directly access cached entries. + * 2. The default implementation cannot integrate with any audit logging included in the + * authorizer implementation. + * 3. The default implementation does not support any custom authorizer configs or other access + * rules apart from ACLs. + * + * @param requestContext Request context including request resourceType, security protocol and listener name + * @param op The ACL operation to check + * @param resourceType The resource type to check + * @return Return {@link AuthorizationResult#ALLOWED} if the caller is authorized + * to perform the given ACL operation on at least one resource of the + * given type. Return {@link AuthorizationResult#DENIED} otherwise. + */ + default AuthorizationResult authorizeByResourceType(AuthorizableRequestContext requestContext, AclOperation op, ResourceType resourceType) { + SecurityUtils.authorizeByResourceTypeCheckArgs(op, resourceType); + + // Check a hard-coded name to ensure that super users are granted + // access regardless of DENY ACLs. 
+ if (authorize(requestContext, Collections.singletonList(new Action( + op, new ResourcePattern(resourceType, "hardcode", PatternType.LITERAL), + 0, true, false))) + .get(0) == AuthorizationResult.ALLOWED) { + return AuthorizationResult.ALLOWED; + } + + // Filter out all the resource pattern corresponding to the RequestContext, + // AclOperation, and ResourceType + ResourcePatternFilter resourceTypeFilter = new ResourcePatternFilter( + resourceType, null, PatternType.ANY); + AclBindingFilter aclFilter = new AclBindingFilter( + resourceTypeFilter, AccessControlEntryFilter.ANY); + + EnumMap> denyPatterns = + new EnumMap>(PatternType.class) {{ + put(PatternType.LITERAL, new HashSet<>()); + put(PatternType.PREFIXED, new HashSet<>()); + }}; + EnumMap> allowPatterns = + new EnumMap>(PatternType.class) {{ + put(PatternType.LITERAL, new HashSet<>()); + put(PatternType.PREFIXED, new HashSet<>()); + }}; + + boolean hasWildCardAllow = false; + + KafkaPrincipal principal = new KafkaPrincipal( + requestContext.principal().getPrincipalType(), + requestContext.principal().getName()); + String hostAddr = requestContext.clientAddress().getHostAddress(); + + for (AclBinding binding : acls(aclFilter)) { + if (!binding.entry().host().equals(hostAddr) && !binding.entry().host().equals("*")) + continue; + + if (!SecurityUtils.parseKafkaPrincipal(binding.entry().principal()).equals(principal) + && !binding.entry().principal().equals("User:*")) + continue; + + if (binding.entry().operation() != op + && binding.entry().operation() != AclOperation.ALL) + continue; + + if (binding.entry().permissionType() == AclPermissionType.DENY) { + switch (binding.pattern().patternType()) { + case LITERAL: + // If wildcard deny exists, return deny directly + if (binding.pattern().name().equals(ResourcePattern.WILDCARD_RESOURCE)) + return AuthorizationResult.DENIED; + denyPatterns.get(PatternType.LITERAL).add(binding.pattern().name()); + break; + case PREFIXED: + denyPatterns.get(PatternType.PREFIXED).add(binding.pattern().name()); + break; + default: + } + continue; + } + + if (binding.entry().permissionType() != AclPermissionType.ALLOW) + continue; + + switch (binding.pattern().patternType()) { + case LITERAL: + if (binding.pattern().name().equals(ResourcePattern.WILDCARD_RESOURCE)) { + hasWildCardAllow = true; + continue; + } + allowPatterns.get(PatternType.LITERAL).add(binding.pattern().name()); + break; + case PREFIXED: + allowPatterns.get(PatternType.PREFIXED).add(binding.pattern().name()); + break; + default: + } + } + + if (hasWildCardAllow) { + return AuthorizationResult.ALLOWED; + } + + // For any literal allowed, if there's no dominant literal and prefix denied, return allow. + // For any prefix allowed, if there's no dominant prefix denied, return allow. 
+ for (Map.Entry> entry : allowPatterns.entrySet()) { + for (String allowStr : entry.getValue()) { + if (entry.getKey() == PatternType.LITERAL + && denyPatterns.get(PatternType.LITERAL).contains(allowStr)) + continue; + StringBuilder sb = new StringBuilder(); + boolean hasDominatedDeny = false; + for (char ch : allowStr.toCharArray()) { + sb.append(ch); + if (denyPatterns.get(PatternType.PREFIXED).contains(sb.toString())) { + hasDominatedDeny = true; + break; + } + } + if (!hasDominatedDeny) + return AuthorizationResult.ALLOWED; + } + } + + return AuthorizationResult.DENIED; + } + +} diff --git a/clients/src/main/java/org/apache/kafka/server/authorizer/AuthorizerServerInfo.java b/clients/src/main/java/org/apache/kafka/server/authorizer/AuthorizerServerInfo.java new file mode 100644 index 0000000..51e23fb --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/server/authorizer/AuthorizerServerInfo.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.server.authorizer; + +import java.util.Collection; +import org.apache.kafka.common.ClusterResource; +import org.apache.kafka.common.Endpoint; +import org.apache.kafka.common.annotation.InterfaceStability; + +/** + * Runtime broker configuration metadata provided to authorizers during start up. + */ +@InterfaceStability.Evolving +public interface AuthorizerServerInfo { + + /** + * Returns cluster metadata for the broker running this authorizer including cluster id. + */ + ClusterResource clusterResource(); + + /** + * Returns broker id. This may be a generated broker id if `broker.id` was not configured. + */ + int brokerId(); + + /** + * Returns endpoints for all listeners including the advertised host and port to which + * the listener is bound. + */ + Collection endpoints(); + + /** + * Returns the inter-broker endpoint. This is one of the endpoints returned by {@link #endpoints()}. + */ + Endpoint interBrokerEndpoint(); +} diff --git a/clients/src/main/java/org/apache/kafka/server/policy/AlterConfigPolicy.java b/clients/src/main/java/org/apache/kafka/server/policy/AlterConfigPolicy.java new file mode 100644 index 0000000..5710a60 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/server/policy/AlterConfigPolicy.java @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.server.policy; + +import org.apache.kafka.common.Configurable; +import org.apache.kafka.common.config.ConfigResource; +import org.apache.kafka.common.errors.PolicyViolationException; + +import java.util.Map; +import java.util.Objects; + +/** + *
An interface for enforcing a policy on alter configs requests. + * + *
Common use cases are requiring that the replication factor, min.insync.replicas and/or retention settings for a + * topic remain within an allowable range. + * + *
            If alter.config.policy.class.name is defined, Kafka will create an instance of the specified class + * using the default constructor and will then pass the broker configs to its configure() method. During + * broker shutdown, the close() method will be invoked so that resources can be released (if necessary). + */ +public interface AlterConfigPolicy extends Configurable, AutoCloseable { + + /** + * Class containing the create request parameters. + */ + class RequestMetadata { + + private final ConfigResource resource; + private final Map configs; + + /** + * Create an instance of this class with the provided parameters. + * + * This constructor is public to make testing of AlterConfigPolicy implementations easier. + */ + public RequestMetadata(ConfigResource resource, Map configs) { + this.resource = resource; + this.configs = configs; + } + + /** + * Return the configs in the request. + */ + public Map configs() { + return configs; + } + + public ConfigResource resource() { + return resource; + } + + @Override + public int hashCode() { + return Objects.hash(resource, configs); + } + + @Override + public boolean equals(Object o) { + if (o == null || o.getClass() != o.getClass()) return false; + RequestMetadata other = (RequestMetadata) o; + return resource.equals(other.resource) && + configs.equals(other.configs); + } + + @Override + public String toString() { + return "AlterConfigPolicy.RequestMetadata(resource=" + resource + + ", configs=" + configs + ")"; + } + } + + /** + * Validate the request parameters and throw a PolicyViolationException with a suitable error + * message if the alter configs request parameters for the provided resource do not satisfy this policy. + * + * Clients will receive the POLICY_VIOLATION error code along with the exception's message. Note that validation + * failure only affects the relevant resource, other resources in the request will still be processed. + * + * @param requestMetadata the alter configs request parameters for the provided resource (topic is the only resource + * type whose configs can be updated currently). + * @throws PolicyViolationException if the request parameters do not satisfy this policy. + */ + void validate(RequestMetadata requestMetadata) throws PolicyViolationException; +} diff --git a/clients/src/main/java/org/apache/kafka/server/policy/CreateTopicPolicy.java b/clients/src/main/java/org/apache/kafka/server/policy/CreateTopicPolicy.java new file mode 100644 index 0000000..dd9bacf --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/server/policy/CreateTopicPolicy.java @@ -0,0 +1,147 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
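To make the contract above concrete, a minimal AlterConfigPolicy implementation could look like the following sketch. The class name, package and the min.insync.replicas floor are illustrative assumptions, not part of the patch; such a class would be enabled by setting alter.config.policy.class.name to its fully qualified name.

```java
package example;

import java.util.Map;
import org.apache.kafka.common.config.ConfigResource;
import org.apache.kafka.common.errors.PolicyViolationException;
import org.apache.kafka.server.policy.AlterConfigPolicy;

// Sketch only: rejects topic config alterations that lower min.insync.replicas below a fixed floor.
public class MinIsrAlterConfigPolicy implements AlterConfigPolicy {

    private static final int MIN_ISR_FLOOR = 2;  // assumed floor, for illustration

    @Override
    public void configure(Map<String, ?> configs) {
        // Broker configs are passed in here; this sketch does not use them.
    }

    @Override
    public void validate(RequestMetadata requestMetadata) throws PolicyViolationException {
        if (requestMetadata.resource().type() != ConfigResource.Type.TOPIC)
            return;
        Object minIsr = requestMetadata.configs().get("min.insync.replicas");
        if (minIsr != null && Integer.parseInt(minIsr.toString()) < MIN_ISR_FLOOR) {
            throw new PolicyViolationException(
                "min.insync.replicas may not be set below " + MIN_ISR_FLOOR);
        }
    }

    @Override
    public void close() {
        // No resources to release.
    }
}
```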
+ */ +package org.apache.kafka.server.policy; + +import org.apache.kafka.common.Configurable; +import org.apache.kafka.common.errors.PolicyViolationException; + +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Objects; + +/** + *
An interface for enforcing a policy on create topics requests. + * + *
Common use cases are requiring that the replication factor, min.insync.replicas and/or retention settings for a + * topic are within an allowable range. + * + *
            If create.topic.policy.class.name is defined, Kafka will create an instance of the specified class + * using the default constructor and will then pass the broker configs to its configure() method. During + * broker shutdown, the close() method will be invoked so that resources can be released (if necessary). + */ +public interface CreateTopicPolicy extends Configurable, AutoCloseable { + + /** + * Class containing the create request parameters. + */ + class RequestMetadata { + private final String topic; + private final Integer numPartitions; + private final Short replicationFactor; + private final Map> replicasAssignments; + private final Map configs; + + /** + * Create an instance of this class with the provided parameters. + * + * This constructor is public to make testing of CreateTopicPolicy implementations easier. + * + * @param topic the name of the topic to created. + * @param numPartitions the number of partitions to create or null if replicasAssignments is set. + * @param replicationFactor the replication factor for the topic or null if replicaAssignments is set. + * @param replicasAssignments replica assignments or null if numPartitions and replicationFactor is set. The + * assignment is a map from partition id to replica (broker) ids. + * @param configs topic configs for the topic to be created, not including broker defaults. Broker configs are + * passed via the {@code configure()} method of the policy implementation. + */ + public RequestMetadata(String topic, Integer numPartitions, Short replicationFactor, + Map> replicasAssignments, Map configs) { + this.topic = topic; + this.numPartitions = numPartitions; + this.replicationFactor = replicationFactor; + this.replicasAssignments = replicasAssignments == null ? null : Collections.unmodifiableMap(replicasAssignments); + this.configs = Collections.unmodifiableMap(configs); + } + + /** + * Return the name of the topic to create. + */ + public String topic() { + return topic; + } + + /** + * Return the number of partitions to create or null if replicaAssignments is not null. + */ + public Integer numPartitions() { + return numPartitions; + } + + /** + * Return the number of replicas to create or null if replicaAssignments is not null. + */ + public Short replicationFactor() { + return replicationFactor; + } + + /** + * Return a map from partition id to replica (broker) ids or null if numPartitions and replicationFactor are + * set instead. + */ + public Map> replicasAssignments() { + return replicasAssignments; + } + + /** + * Return topic configs in the request, not including broker defaults. Broker configs are passed via + * the {@code configure()} method of the policy implementation. 
+ */ + public Map configs() { + return configs; + } + + @Override + public int hashCode() { + return Objects.hash(topic, numPartitions, replicationFactor, + replicasAssignments, configs); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + RequestMetadata other = (RequestMetadata) o; + return topic.equals(other.topic) && + Objects.equals(numPartitions, other.numPartitions) && + Objects.equals(replicationFactor, other.replicationFactor) && + Objects.equals(replicasAssignments, other.replicasAssignments) && + configs.equals(other.configs); + } + + @Override + public String toString() { + return "CreateTopicPolicy.RequestMetadata(topic=" + topic + + ", numPartitions=" + numPartitions + + ", replicationFactor=" + replicationFactor + + ", replicasAssignments=" + replicasAssignments + + ", configs=" + configs + ")"; + } + } + + /** + * Validate the request parameters and throw a PolicyViolationException with a suitable error + * message if the create topics request parameters for the provided topic do not satisfy this policy. + * + * Clients will receive the POLICY_VIOLATION error code along with the exception's message. Note that validation + * failure only affects the relevant topic, other topics in the request will still be processed. + * + * @param requestMetadata the create topics request parameters for the provided topic. + * @throws PolicyViolationException if the request parameters do not satisfy this policy. + */ + void validate(RequestMetadata requestMetadata) throws PolicyViolationException; +} diff --git a/clients/src/main/java/org/apache/kafka/server/quota/ClientQuotaCallback.java b/clients/src/main/java/org/apache/kafka/server/quota/ClientQuotaCallback.java new file mode 100644 index 0000000..210e9f4 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/server/quota/ClientQuotaCallback.java @@ -0,0 +1,106 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.server.quota; + +import org.apache.kafka.common.Cluster; +import org.apache.kafka.common.Configurable; +import org.apache.kafka.common.security.auth.KafkaPrincipal; + +import java.util.Map; + +/** + * Quota callback interface for brokers that enables customization of client quota computation. + */ +public interface ClientQuotaCallback extends Configurable { + + /** + * Quota callback invoked to determine the quota metric tags to be applied for a request. + * Quota limits are associated with quota metrics and all clients which use the same + * metric tags share the quota limit. 
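A minimal CreateTopicPolicy implementation, shown as a sketch under assumptions (the class name and the replication-factor threshold are illustrative), would be wired in via create.topic.policy.class.name:

```java
package example;

import java.util.Map;
import org.apache.kafka.common.errors.PolicyViolationException;
import org.apache.kafka.server.policy.CreateTopicPolicy;

// Sketch only: rejects topic creation with a replication factor below an assumed minimum.
public class MinReplicationFactorPolicy implements CreateTopicPolicy {

    private static final short MIN_REPLICATION_FACTOR = 3;  // assumed minimum, for illustration

    @Override
    public void configure(Map<String, ?> configs) {
        // Broker configs arrive here; unused in this sketch.
    }

    @Override
    public void validate(RequestMetadata requestMetadata) throws PolicyViolationException {
        Short replicationFactor = requestMetadata.replicationFactor();
        if (replicationFactor != null && replicationFactor < MIN_REPLICATION_FACTOR) {
            throw new PolicyViolationException("Topic " + requestMetadata.topic()
                + " must use a replication factor of at least " + MIN_REPLICATION_FACTOR
                + ", but " + replicationFactor + " was requested");
        }
        // A real policy would also inspect replicasAssignments() when it is set instead of
        // replicationFactor(); omitted to keep the sketch short.
    }

    @Override
    public void close() {
        // Nothing to release.
    }
}
```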
+ * + * @param quotaType Type of quota requested + * @param principal The user principal of the connection for which quota is requested + * @param clientId The client id associated with the request + * @return quota metric tags that indicate which other clients share this quota + */ + Map quotaMetricTags(ClientQuotaType quotaType, KafkaPrincipal principal, String clientId); + + /** + * Returns the quota limit associated with the provided metric tags. These tags were returned from + * a previous call to {@link #quotaMetricTags(ClientQuotaType, KafkaPrincipal, String)}. This method is + * invoked by quota managers to obtain the current quota limit applied to a metric when the first request + * using these tags is processed. It is also invoked after a quota update or cluster metadata change. + * If the tags are no longer in use after the update, (e.g. this is a {user, client-id} quota metric + * and the quota now in use is a {user} quota), null is returned. + * + * @param quotaType Type of quota requested + * @param metricTags Metric tags for a quota metric of type `quotaType` + * @return the quota limit for the provided metric tags or null if the metric tags are no longer in use + */ + Double quotaLimit(ClientQuotaType quotaType, Map metricTags); + + /** + * Quota configuration update callback that is invoked when quota configuration for an entity is + * updated in ZooKeeper. This is useful to track configured quotas if built-in quota configuration + * tools are used for quota management. + * + * @param quotaType Type of quota being updated + * @param quotaEntity The quota entity for which quota is being updated + * @param newValue The new quota value + */ + void updateQuota(ClientQuotaType quotaType, ClientQuotaEntity quotaEntity, double newValue); + + /** + * Quota configuration removal callback that is invoked when quota configuration for an entity is + * removed in ZooKeeper. This is useful to track configured quotas if built-in quota configuration + * tools are used for quota management. + * + * @param quotaType Type of quota being updated + * @param quotaEntity The quota entity for which quota is being updated + */ + void removeQuota(ClientQuotaType quotaType, ClientQuotaEntity quotaEntity); + + /** + * Returns true if any of the existing quota configs may have been updated since the last call + * to this method for the provided quota type. Quota updates as a result of calls to + * {@link #updateClusterMetadata(Cluster)}, {@link #updateQuota(ClientQuotaType, ClientQuotaEntity, double)} + * and {@link #removeQuota(ClientQuotaType, ClientQuotaEntity)} are automatically processed. + * So callbacks that rely only on built-in quota configuration tools always return false. Quota callbacks + * with external quota configuration or custom reconfigurable quota configs that affect quota limits must + * return true if existing metric configs may need to be updated. This method is invoked on every request + * and hence is expected to be handled by callbacks as a simple flag that is updated when quotas change. + * + * @param quotaType Type of quota + */ + boolean quotaResetRequired(ClientQuotaType quotaType); + + /** + * Metadata update callback that is invoked whenever UpdateMetadata request is received from + * the controller. This is useful if quota computation takes partitions into account. + * Topics that are being deleted will not be included in `cluster`. 
+ * + * @param cluster Cluster metadata including partitions and their leaders if known + * @return true if quotas have changed and metric configs may need to be updated + */ + boolean updateClusterMetadata(Cluster cluster); + + /** + * Closes this instance. + */ + void close(); +} + diff --git a/clients/src/main/java/org/apache/kafka/server/quota/ClientQuotaEntity.java b/clients/src/main/java/org/apache/kafka/server/quota/ClientQuotaEntity.java new file mode 100644 index 0000000..0f0eb62 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/server/quota/ClientQuotaEntity.java @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.server.quota; + +import java.util.List; + +/** + * The metadata for an entity for which quota is configured. Quotas may be defined at + * different levels and `configEntities` gives the list of config entities that define + * the level of this quota entity. + */ +public interface ClientQuotaEntity { + + /** + * Entity type of a {@link ConfigEntity} + */ + enum ConfigEntityType { + USER, + CLIENT_ID, + DEFAULT_USER, + DEFAULT_CLIENT_ID + } + + /** + * Interface representing a quota configuration entity. Quota may be + * configured at levels that include one or more configuration entities. + * For example, {user, client-id} quota is represented using two + * instances of ConfigEntity with entity types USER and CLIENT_ID. + */ + interface ConfigEntity { + /** + * Returns the name of this entity. For default quotas, an empty string is returned. + */ + String name(); + + /** + * Returns the type of this entity. + */ + ConfigEntityType entityType(); + } + + /** + * Returns the list of configuration entities that this quota entity is comprised of. + * For {user} or {clientId} quota, this is a single entity and for {user, clientId} + * quota, this is a list of two entities. + */ + List configEntities(); +} diff --git a/clients/src/main/java/org/apache/kafka/server/quota/ClientQuotaType.java b/clients/src/main/java/org/apache/kafka/server/quota/ClientQuotaType.java new file mode 100644 index 0000000..5b0828a --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/server/quota/ClientQuotaType.java @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
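A small end-to-end sketch of the callback described above: the implementation below groups all connections of a user under one quota metric and applies a single fixed produce limit. The class name and the limit are illustrative assumptions; a callback like this would be registered through the client.quota.callback.class broker config.

```java
package example;

import java.util.Collections;
import java.util.Map;
import org.apache.kafka.common.Cluster;
import org.apache.kafka.common.security.auth.KafkaPrincipal;
import org.apache.kafka.server.quota.ClientQuotaCallback;
import org.apache.kafka.server.quota.ClientQuotaEntity;
import org.apache.kafka.server.quota.ClientQuotaType;

// Sketch only: one fixed produce quota per user, independent of client-id and cluster metadata.
public class PerUserProduceQuotaCallback implements ClientQuotaCallback {

    private static final double PRODUCE_BYTES_PER_SECOND = 1_048_576.0;  // assumed 1 MiB/s per user

    @Override
    public void configure(Map<String, ?> configs) { }

    @Override
    public Map<String, String> quotaMetricTags(ClientQuotaType quotaType, KafkaPrincipal principal, String clientId) {
        // All connections of the same user share one quota metric; client-id is ignored.
        return Collections.singletonMap("user", principal.getName());
    }

    @Override
    public Double quotaLimit(ClientQuotaType quotaType, Map<String, String> metricTags) {
        // Only produce traffic is limited in this sketch; other quota types get no custom limit.
        return quotaType == ClientQuotaType.PRODUCE ? PRODUCE_BYTES_PER_SECOND : null;
    }

    @Override
    public void updateQuota(ClientQuotaType quotaType, ClientQuotaEntity quotaEntity, double newValue) { }

    @Override
    public void removeQuota(ClientQuotaType quotaType, ClientQuotaEntity quotaEntity) { }

    @Override
    public boolean quotaResetRequired(ClientQuotaType quotaType) {
        return false;  // the limit here never changes at runtime
    }

    @Override
    public boolean updateClusterMetadata(Cluster cluster) {
        return false;  // partition placement does not affect this quota
    }

    @Override
    public void close() { }
}
```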
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.server.quota; + +/** + * Types of quotas that may be configured on brokers for client requests. + */ +public enum ClientQuotaType { + PRODUCE, + FETCH, + REQUEST, + CONTROLLER_MUTATION +} diff --git a/clients/src/main/resources/META-INF/services/org.apache.kafka.common.config.provider.ConfigProvider b/clients/src/main/resources/META-INF/services/org.apache.kafka.common.config.provider.ConfigProvider new file mode 100644 index 0000000..409080f --- /dev/null +++ b/clients/src/main/resources/META-INF/services/org.apache.kafka.common.config.provider.ConfigProvider @@ -0,0 +1,17 @@ + # Licensed to the Apache Software Foundation (ASF) under one or more + # contributor license agreements. See the NOTICE file distributed with + # this work for additional information regarding copyright ownership. + # The ASF licenses this file to You under the Apache License, Version 2.0 + # (the "License"); you may not use this file except in compliance with + # the License. You may obtain a copy of the License at + # + # http://www.apache.org/licenses/LICENSE-2.0 + # + # Unless required by applicable law or agreed to in writing, software + # distributed under the License is distributed on an "AS IS" BASIS, + # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + # See the License for the specific language governing permissions and + # limitations under the License. + +org.apache.kafka.common.config.provider.FileConfigProvider +org.apache.kafka.common.config.provider.DirectoryConfigProvider diff --git a/clients/src/main/resources/common/message/AddOffsetsToTxnRequest.json b/clients/src/main/resources/common/message/AddOffsetsToTxnRequest.json new file mode 100644 index 0000000..ade3fc7 --- /dev/null +++ b/clients/src/main/resources/common/message/AddOffsetsToTxnRequest.json @@ -0,0 +1,38 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 25, + "type": "request", + "listeners": ["zkBroker", "broker"], + "name": "AddOffsetsToTxnRequest", + // Version 1 is the same as version 0. + // + // Version 2 adds the support for new error code PRODUCER_FENCED. + // + // Version 3 enables flexible versions. 
+ "validVersions": "0-3", + "flexibleVersions": "3+", + "fields": [ + { "name": "TransactionalId", "type": "string", "versions": "0+", "entityType": "transactionalId", + "about": "The transactional id corresponding to the transaction."}, + { "name": "ProducerId", "type": "int64", "versions": "0+", "entityType": "producerId", + "about": "Current producer id in use by the transactional id." }, + { "name": "ProducerEpoch", "type": "int16", "versions": "0+", + "about": "Current epoch associated with the producer id." }, + { "name": "GroupId", "type": "string", "versions": "0+", "entityType": "groupId", + "about": "The unique group identifier." } + ] +} diff --git a/clients/src/main/resources/common/message/AddOffsetsToTxnResponse.json b/clients/src/main/resources/common/message/AddOffsetsToTxnResponse.json new file mode 100644 index 0000000..71fa655 --- /dev/null +++ b/clients/src/main/resources/common/message/AddOffsetsToTxnResponse.json @@ -0,0 +1,33 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 25, + "type": "response", + "name": "AddOffsetsToTxnResponse", + // Starting in version 1, on quota violation brokers send out responses before throttling. + // + // Version 2 adds the support for new error code PRODUCER_FENCED. + // + // Version 3 enables flexible versions. + "validVersions": "0-3", + "flexibleVersions": "3+", + "fields": [ + { "name": "ThrottleTimeMs", "type": "int32", "versions": "0+", + "about": "Duration in milliseconds for which the request was throttled due to a quota violation, or zero if the request did not violate any quota." }, + { "name": "ErrorCode", "type": "int16", "versions": "0+", + "about": "The response error code, or 0 if there was no error." } + ] +} diff --git a/clients/src/main/resources/common/message/AddPartitionsToTxnRequest.json b/clients/src/main/resources/common/message/AddPartitionsToTxnRequest.json new file mode 100644 index 0000000..4920da1 --- /dev/null +++ b/clients/src/main/resources/common/message/AddPartitionsToTxnRequest.json @@ -0,0 +1,43 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
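The schemas in this directory are compiled by Kafka's message generator into data classes under org.apache.kafka.common.message. As a rough illustration of what the AddOffsetsToTxn request schema above produces (assuming the generated fluent setters follow the JSON field names, which is the generator's convention):

```java
import org.apache.kafka.common.message.AddOffsetsToTxnRequestData;

// Illustrative only: populate the generated request data for the schema above.
public class AddOffsetsToTxnDataSketch {
    public static void main(String[] args) {
        AddOffsetsToTxnRequestData data = new AddOffsetsToTxnRequestData()
            .setTransactionalId("example-txn")   // TransactionalId
            .setProducerId(4000L)                // ProducerId
            .setProducerEpoch((short) 1)         // ProducerEpoch
            .setGroupId("example-group");        // GroupId
        System.out.println(data);                // generated toString lists all fields
    }
}
```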
+// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 24, + "type": "request", + "listeners": ["zkBroker", "broker"], + "name": "AddPartitionsToTxnRequest", + // Version 1 is the same as version 0. + // + // Version 2 adds the support for new error code PRODUCER_FENCED. + // + // Version 3 enables flexible versions. + "validVersions": "0-3", + "flexibleVersions": "3+", + "fields": [ + { "name": "TransactionalId", "type": "string", "versions": "0+", "entityType": "transactionalId", + "about": "The transactional id corresponding to the transaction."}, + { "name": "ProducerId", "type": "int64", "versions": "0+", "entityType": "producerId", + "about": "Current producer id in use by the transactional id." }, + { "name": "ProducerEpoch", "type": "int16", "versions": "0+", + "about": "Current epoch associated with the producer id." }, + { "name": "Topics", "type": "[]AddPartitionsToTxnTopic", "versions": "0+", + "about": "The partitions to add to the transaction.", "fields": [ + { "name": "Name", "type": "string", "versions": "0+", "mapKey": true, "entityType": "topicName", + "about": "The name of the topic." }, + { "name": "Partitions", "type": "[]int32", "versions": "0+", + "about": "The partition indexes to add to the transaction" } + ]} + ] +} diff --git a/clients/src/main/resources/common/message/AddPartitionsToTxnResponse.json b/clients/src/main/resources/common/message/AddPartitionsToTxnResponse.json new file mode 100644 index 0000000..4241dc7 --- /dev/null +++ b/clients/src/main/resources/common/message/AddPartitionsToTxnResponse.json @@ -0,0 +1,43 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 24, + "type": "response", + "name": "AddPartitionsToTxnResponse", + // Starting in version 1, on quota violation brokers send out responses before throttling. + // + // Version 2 adds the support for new error code PRODUCER_FENCED. + // + // Version 3 enables flexible versions. + "validVersions": "0-3", + "flexibleVersions": "3+", + "fields": [ + { "name": "ThrottleTimeMs", "type": "int32", "versions": "0+", + "about": "Duration in milliseconds for which the request was throttled due to a quota violation, or zero if the request did not violate any quota." }, + { "name": "Results", "type": "[]AddPartitionsToTxnTopicResult", "versions": "0+", + "about": "The results for each topic.", "fields": [ + { "name": "Name", "type": "string", "versions": "0+", "mapKey": true, "entityType": "topicName", + "about": "The topic name." 
}, + { "name": "Results", "type": "[]AddPartitionsToTxnPartitionResult", "versions": "0+", + "about": "The results for each partition", "fields": [ + { "name": "PartitionIndex", "type": "int32", "versions": "0+", "mapKey": true, + "about": "The partition indexes." }, + { "name": "ErrorCode", "type": "int16", "versions": "0+", + "about": "The response error code."} + ]} + ]} + ] +} diff --git a/clients/src/main/resources/common/message/AllocateProducerIdsRequest.json b/clients/src/main/resources/common/message/AllocateProducerIdsRequest.json new file mode 100644 index 0000000..7256c6b --- /dev/null +++ b/clients/src/main/resources/common/message/AllocateProducerIdsRequest.json @@ -0,0 +1,29 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implie +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 67, + "type": "request", + "listeners": ["zkBroker", "controller"], + "name": "AllocateProducerIdsRequest", + "validVersions": "0", + "flexibleVersions": "0+", + "fields": [ + { "name": "BrokerId", "type": "int32", "versions": "0+", "entityType": "brokerId", + "about": "The ID of the requesting broker" }, + { "name": "BrokerEpoch", "type": "int64", "versions": "0+", "default": "-1", + "about": "The epoch of the requesting broker" } + ] +} diff --git a/clients/src/main/resources/common/message/AllocateProducerIdsResponse.json b/clients/src/main/resources/common/message/AllocateProducerIdsResponse.json new file mode 100644 index 0000000..0d849c0 --- /dev/null +++ b/clients/src/main/resources/common/message/AllocateProducerIdsResponse.json @@ -0,0 +1,32 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 67, + "type": "response", + "name": "AllocateProducerIdsResponse", + "validVersions": "0", + "flexibleVersions": "0+", + "fields": [ + { "name": "ThrottleTimeMs", "type": "int32", "versions": "0+", + "about": "The duration in milliseconds for which the request was throttled due to a quota violation, or zero if the request did not violate any quota." 
}, + { "name": "ErrorCode", "type": "int16", "versions": "0+", + "about": "The top level response error code" }, + { "name": "ProducerIdStart", "type": "int64", "versions": "0+", "entityType": "producerId", + "about": "The first producer ID in this range, inclusive"}, + { "name": "ProducerIdLen", "type": "int32", "versions": "0+", + "about": "The number of producer IDs in this range"} + ] +} diff --git a/clients/src/main/resources/common/message/AlterClientQuotasRequest.json b/clients/src/main/resources/common/message/AlterClientQuotasRequest.json new file mode 100644 index 0000000..6bfdc92 --- /dev/null +++ b/clients/src/main/resources/common/message/AlterClientQuotasRequest.json @@ -0,0 +1,47 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 49, + "type": "request", + "listeners": ["zkBroker", "broker", "controller"], + "name": "AlterClientQuotasRequest", + "validVersions": "0-1", + // Version 1 enables flexible versions. + "flexibleVersions": "1+", + "fields": [ + { "name": "Entries", "type": "[]EntryData", "versions": "0+", + "about": "The quota configuration entries to alter.", "fields": [ + { "name": "Entity", "type": "[]EntityData", "versions": "0+", + "about": "The quota entity to alter.", "fields": [ + { "name": "EntityType", "type": "string", "versions": "0+", + "about": "The entity type." }, + { "name": "EntityName", "type": "string", "versions": "0+", "nullableVersions": "0+", + "about": "The name of the entity, or null if the default." } + ]}, + { "name": "Ops", "type": "[]OpData", "versions": "0+", + "about": "An individual quota configuration entry to alter.", "fields": [ + { "name": "Key", "type": "string", "versions": "0+", + "about": "The quota configuration key." }, + { "name": "Value", "type": "float64", "versions": "0+", + "about": "The value to set, otherwise ignored if the value is to be removed." }, + { "name": "Remove", "type": "bool", "versions": "0+", + "about": "Whether the quota configuration value should be removed, otherwise set." } + ]} + ]}, + { "name": "ValidateOnly", "type": "bool", "versions": "0+", + "about": "Whether the alteration should be validated, but not performed." } + ] +} diff --git a/clients/src/main/resources/common/message/AlterClientQuotasResponse.json b/clients/src/main/resources/common/message/AlterClientQuotasResponse.json new file mode 100644 index 0000000..326b85d --- /dev/null +++ b/clients/src/main/resources/common/message/AlterClientQuotasResponse.json @@ -0,0 +1,41 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. 
+// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 49, + "type": "response", + "name": "AlterClientQuotasResponse", + // Version 1 enables flexible versions. + "validVersions": "0-1", + "flexibleVersions": "1+", + "fields": [ + { "name": "ThrottleTimeMs", "type": "int32", "versions": "0+", + "about": "The duration in milliseconds for which the request was throttled due to a quota violation, or zero if the request did not violate any quota." }, + { "name": "Entries", "type": "[]EntryData", "versions": "0+", + "about": "The quota configuration entries to alter.", "fields": [ + { "name": "ErrorCode", "type": "int16", "versions": "0+", + "about": "The error code, or `0` if the quota alteration succeeded." }, + { "name": "ErrorMessage", "type": "string", "versions": "0+", "nullableVersions": "0+", + "about": "The error message, or `null` if the quota alteration succeeded." }, + { "name": "Entity", "type": "[]EntityData", "versions": "0+", + "about": "The quota entity to alter.", "fields": [ + { "name": "EntityType", "type": "string", "versions": "0+", + "about": "The entity type." }, + { "name": "EntityName", "type": "string", "versions": "0+", "nullableVersions": "0+", + "about": "The name of the entity, or null if the default." } + ]} + ]} + ] +} diff --git a/clients/src/main/resources/common/message/AlterConfigsRequest.json b/clients/src/main/resources/common/message/AlterConfigsRequest.json new file mode 100644 index 0000000..31057e3 --- /dev/null +++ b/clients/src/main/resources/common/message/AlterConfigsRequest.json @@ -0,0 +1,43 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 33, + "type": "request", + "listeners": ["zkBroker", "broker", "controller"], + "name": "AlterConfigsRequest", + // Version 1 is the same as version 0. + // Version 2 enables flexible versions. + "validVersions": "0-2", + "flexibleVersions": "2+", + "fields": [ + { "name": "Resources", "type": "[]AlterConfigsResource", "versions": "0+", + "about": "The updates for each resource.", "fields": [ + { "name": "ResourceType", "type": "int8", "versions": "0+", "mapKey": true, + "about": "The resource type." }, + { "name": "ResourceName", "type": "string", "versions": "0+", "mapKey": true, + "about": "The resource name." 
}, + { "name": "Configs", "type": "[]AlterableConfig", "versions": "0+", + "about": "The configurations.", "fields": [ + { "name": "Name", "type": "string", "versions": "0+", "mapKey": true, + "about": "The configuration key name." }, + { "name": "Value", "type": "string", "versions": "0+", "nullableVersions": "0+", + "about": "The value to set for the configuration key."} + ]} + ]}, + { "name": "ValidateOnly", "type": "bool", "versions": "0+", + "about": "True if we should validate the request, but not change the configurations."} + ] +} diff --git a/clients/src/main/resources/common/message/AlterConfigsResponse.json b/clients/src/main/resources/common/message/AlterConfigsResponse.json new file mode 100644 index 0000000..fab102b --- /dev/null +++ b/clients/src/main/resources/common/message/AlterConfigsResponse.json @@ -0,0 +1,39 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 33, + "type": "response", + "name": "AlterConfigsResponse", + // Starting in version 1, on quota violation brokers send out responses before throttling. + // Version 2 enables flexible versions. + "validVersions": "0-2", + "flexibleVersions": "2+", + "fields": [ + { "name": "ThrottleTimeMs", "type": "int32", "versions": "0+", + "about": "Duration in milliseconds for which the request was throttled due to a quota violation, or zero if the request did not violate any quota." }, + { "name": "Responses", "type": "[]AlterConfigsResourceResponse", "versions": "0+", + "about": "The responses for each resource.", "fields": [ + { "name": "ErrorCode", "type": "int16", "versions": "0+", + "about": "The resource error code." }, + { "name": "ErrorMessage", "type": "string", "nullableVersions": "0+", "versions": "0+", + "about": "The resource error message, or null if there was no error." }, + { "name": "ResourceType", "type": "int8", "versions": "0+", + "about": "The resource type." }, + { "name": "ResourceName", "type": "string", "versions": "0+", + "about": "The resource name." } + ]} + ] +} diff --git a/clients/src/main/resources/common/message/AlterIsrRequest.json b/clients/src/main/resources/common/message/AlterIsrRequest.json new file mode 100644 index 0000000..70736db --- /dev/null +++ b/clients/src/main/resources/common/message/AlterIsrRequest.json @@ -0,0 +1,43 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implie +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 56, + "type": "request", + "listeners": ["zkBroker", "controller"], + "name": "AlterIsrRequest", + "validVersions": "0", + "flexibleVersions": "0+", + "fields": [ + { "name": "BrokerId", "type": "int32", "versions": "0+", "entityType": "brokerId", + "about": "The ID of the requesting broker" }, + { "name": "BrokerEpoch", "type": "int64", "versions": "0+", "default": "-1", + "about": "The epoch of the requesting broker" }, + { "name": "Topics", "type": "[]TopicData", "versions": "0+", "fields": [ + { "name": "Name", "type": "string", "versions": "0+", "entityType": "topicName", + "about": "The name of the topic to alter ISRs for" }, + { "name": "Partitions", "type": "[]PartitionData", "versions": "0+", "fields": [ + { "name": "PartitionIndex", "type": "int32", "versions": "0+", + "about": "The partition index" }, + { "name": "LeaderEpoch", "type": "int32", "versions": "0+", + "about": "The leader epoch of this partition" }, + { "name": "NewIsr", "type": "[]int32", "versions": "0+", "entityType": "brokerId", + "about": "The ISR for this partition"}, + { "name": "CurrentIsrVersion", "type": "int32", "versions": "0+", + "about": "The expected version of ISR which is being updated"} + ]} + ]} + ] +} diff --git a/clients/src/main/resources/common/message/AlterIsrResponse.json b/clients/src/main/resources/common/message/AlterIsrResponse.json new file mode 100644 index 0000000..3383799 --- /dev/null +++ b/clients/src/main/resources/common/message/AlterIsrResponse.json @@ -0,0 +1,46 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 56, + "type": "response", + "name": "AlterIsrResponse", + "validVersions": "0", + "flexibleVersions": "0+", + "fields": [ + { "name": "ThrottleTimeMs", "type": "int32", "versions": "0+", + "about": "The duration in milliseconds for which the request was throttled due to a quota violation, or zero if the request did not violate any quota." 
}, + { "name": "ErrorCode", "type": "int16", "versions": "0+", + "about": "The top level response error code" }, + { "name": "Topics", "type": "[]TopicData", "versions": "0+", "fields": [ + { "name": "Name", "type": "string", "versions": "0+", "entityType": "topicName", + "about": "The name of the topic" }, + { "name": "Partitions", "type": "[]PartitionData", "versions": "0+", "fields": [ + { "name": "PartitionIndex", "type": "int32", "versions": "0+", + "about": "The partition index" }, + { "name": "ErrorCode", "type": "int16", "versions": "0+", + "about": "The partition level error code" }, + { "name": "LeaderId", "type": "int32", "versions": "0+", "entityType": "brokerId", + "about": "The broker ID of the leader." }, + { "name": "LeaderEpoch", "type": "int32", "versions": "0+", + "about": "The leader epoch." }, + { "name": "Isr", "type": "[]int32", "versions": "0+", "entityType": "brokerId", + "about": "The in-sync replica IDs." }, + { "name": "CurrentIsrVersion", "type": "int32", "versions": "0+", + "about": "The current ISR version." } + ]} + ]} + ] +} \ No newline at end of file diff --git a/clients/src/main/resources/common/message/AlterPartitionReassignmentsRequest.json b/clients/src/main/resources/common/message/AlterPartitionReassignmentsRequest.json new file mode 100644 index 0000000..47043ff --- /dev/null +++ b/clients/src/main/resources/common/message/AlterPartitionReassignmentsRequest.json @@ -0,0 +1,39 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 45, + "type": "request", + "listeners": ["broker", "controller", "zkBroker"], + "name": "AlterPartitionReassignmentsRequest", + "validVersions": "0", + "flexibleVersions": "0+", + "fields": [ + { "name": "TimeoutMs", "type": "int32", "versions": "0+", "default": "60000", + "about": "The time in ms to wait for the request to complete." }, + { "name": "Topics", "type": "[]ReassignableTopic", "versions": "0+", + "about": "The topics to reassign.", "fields": [ + { "name": "Name", "type": "string", "versions": "0+", "entityType": "topicName", + "about": "The topic name." }, + { "name": "Partitions", "type": "[]ReassignablePartition", "versions": "0+", + "about": "The partitions to reassign.", "fields": [ + { "name": "PartitionIndex", "type": "int32", "versions": "0+", + "about": "The partition index." }, + { "name": "Replicas", "type": "[]int32", "versions": "0+", "nullableVersions": "0+", "default": "null", "entityType": "brokerId", + "about": "The replicas to place the partitions on, or null to cancel a pending reassignment for this partition." 
} + ]} + ]} + ] +} diff --git a/clients/src/main/resources/common/message/AlterPartitionReassignmentsResponse.json b/clients/src/main/resources/common/message/AlterPartitionReassignmentsResponse.json new file mode 100644 index 0000000..3fa0888 --- /dev/null +++ b/clients/src/main/resources/common/message/AlterPartitionReassignmentsResponse.json @@ -0,0 +1,44 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 45, + "type": "response", + "name": "AlterPartitionReassignmentsResponse", + "validVersions": "0", + "flexibleVersions": "0+", + "fields": [ + { "name": "ThrottleTimeMs", "type": "int32", "versions": "0+", + "about": "The duration in milliseconds for which the request was throttled due to a quota violation, or zero if the request did not violate any quota." }, + { "name": "ErrorCode", "type": "int16", "versions": "0+", + "about": "The top-level error code, or 0 if there was no error." }, + { "name": "ErrorMessage", "type": "string", "versions": "0+", "nullableVersions": "0+", + "about": "The top-level error message, or null if there was no error." }, + { "name": "Responses", "type": "[]ReassignableTopicResponse", "versions": "0+", + "about": "The responses to topics to reassign.", "fields": [ + { "name": "Name", "type": "string", "versions": "0+", "entityType": "topicName", + "about": "The topic name" }, + { "name": "Partitions", "type": "[]ReassignablePartitionResponse", "versions": "0+", + "about": "The responses to partitions to reassign", "fields": [ + { "name": "PartitionIndex", "type": "int32", "versions": "0+", + "about": "The partition index." }, + { "name": "ErrorCode", "type": "int16", "versions": "0+", + "about": "The error code for this partition, or 0 if there was no error." }, + { "name": "ErrorMessage", "type": "string", "versions": "0+", "nullableVersions": "0+", + "about": "The error message for this partition, or null if there was no error." } + ]} + ]} + ] +} diff --git a/clients/src/main/resources/common/message/AlterReplicaLogDirsRequest.json b/clients/src/main/resources/common/message/AlterReplicaLogDirsRequest.json new file mode 100644 index 0000000..2306caa --- /dev/null +++ b/clients/src/main/resources/common/message/AlterReplicaLogDirsRequest.json @@ -0,0 +1,39 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 34, + "type": "request", + "listeners": ["zkBroker", "broker"], + "name": "AlterReplicaLogDirsRequest", + // Version 1 is the same as version 0. + // Version 2 enables flexible versions. + "validVersions": "0-2", + "flexibleVersions": "2+", + "fields": [ + { "name": "Dirs", "type": "[]AlterReplicaLogDir", "versions": "0+", + "about": "The alterations to make for each directory.", "fields": [ + { "name": "Path", "type": "string", "versions": "0+", "mapKey": true, + "about": "The absolute directory path." }, + { "name": "Topics", "type": "[]AlterReplicaLogDirTopic", "versions": "0+", + "about": "The topics to add to the directory.", "fields": [ + { "name": "Name", "type": "string", "versions": "0+", "mapKey": true, "entityType": "topicName", + "about": "The topic name." }, + { "name": "Partitions", "type": "[]int32", "versions": "0+", + "about": "The partition indexes." } + ]} + ]} + ] +} diff --git a/clients/src/main/resources/common/message/AlterReplicaLogDirsResponse.json b/clients/src/main/resources/common/message/AlterReplicaLogDirsResponse.json new file mode 100644 index 0000000..386e24e --- /dev/null +++ b/clients/src/main/resources/common/message/AlterReplicaLogDirsResponse.json @@ -0,0 +1,40 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 34, + "type": "response", + "name": "AlterReplicaLogDirsResponse", + // Starting in version 1, on quota violation brokers send out responses before throttling. + // Version 2 enables flexible versions. + "validVersions": "0-2", + "flexibleVersions": "2+", + "fields": [ + { "name": "ThrottleTimeMs", "type": "int32", "versions": "0+", + "about": "Duration in milliseconds for which the request was throttled due to a quota violation, or zero if the request did not violate any quota." }, + { "name": "Results", "type": "[]AlterReplicaLogDirTopicResult", "versions": "0+", + "about": "The results for each topic.", "fields": [ + { "name": "TopicName", "type": "string", "versions": "0+", "entityType": "topicName", + "about": "The name of the topic." 
}, + { "name": "Partitions", "type": "[]AlterReplicaLogDirPartitionResult", "versions": "0+", + "about": "The results for each partition.", "fields": [ + { "name": "PartitionIndex", "type": "int32", "versions": "0+", + "about": "The partition index."}, + { "name": "ErrorCode", "type": "int16", "versions": "0+", + "about": "The error code, or 0 if there was no error." } + ]} + ]} + ] +} diff --git a/clients/src/main/resources/common/message/AlterUserScramCredentialsRequest.json b/clients/src/main/resources/common/message/AlterUserScramCredentialsRequest.json new file mode 100644 index 0000000..70e1483 --- /dev/null +++ b/clients/src/main/resources/common/message/AlterUserScramCredentialsRequest.json @@ -0,0 +1,45 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 51, + "type": "request", + "listeners": ["zkBroker"], + "name": "AlterUserScramCredentialsRequest", + "validVersions": "0", + "flexibleVersions": "0+", + "fields": [ + { "name": "Deletions", "type": "[]ScramCredentialDeletion", "versions": "0+", + "about": "The SCRAM credentials to remove.", "fields": [ + { "name": "Name", "type": "string", "versions": "0+", + "about": "The user name." }, + { "name": "Mechanism", "type": "int8", "versions": "0+", + "about": "The SCRAM mechanism." } + ]}, + { "name": "Upsertions", "type": "[]ScramCredentialUpsertion", "versions": "0+", + "about": "The SCRAM credentials to update/insert.", "fields": [ + { "name": "Name", "type": "string", "versions": "0+", + "about": "The user name." }, + { "name": "Mechanism", "type": "int8", "versions": "0+", + "about": "The SCRAM mechanism." }, + { "name": "Iterations", "type": "int32", "versions": "0+", + "about": "The number of iterations." }, + { "name": "Salt", "type": "bytes", "versions": "0+", + "about": "A random salt generated by the client." }, + { "name": "SaltedPassword", "type": "bytes", "versions": "0+", + "about": "The salted password." } + ]} + ] +} diff --git a/clients/src/main/resources/common/message/AlterUserScramCredentialsResponse.json b/clients/src/main/resources/common/message/AlterUserScramCredentialsResponse.json new file mode 100644 index 0000000..92b62d5 --- /dev/null +++ b/clients/src/main/resources/common/message/AlterUserScramCredentialsResponse.json @@ -0,0 +1,35 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 51, + "type": "response", + "name": "AlterUserScramCredentialsResponse", + "validVersions": "0", + "flexibleVersions": "0+", + "fields": [ + { "name": "ThrottleTimeMs", "type": "int32", "versions": "0+", + "about": "The duration in milliseconds for which the request was throttled due to a quota violation, or zero if the request did not violate any quota." }, + { "name": "Results", "type": "[]AlterUserScramCredentialsResult", "versions": "0+", + "about": "The results for deletions and alterations, one per affected user.", "fields": [ + { "name": "User", "type": "string", "versions": "0+", + "about": "The user name." }, + { "name": "ErrorCode", "type": "int16", "versions": "0+", + "about": "The error code." }, + { "name": "ErrorMessage", "type": "string", "versions": "0+", "nullableVersions": "0+", + "about": "The error message, if any." } + ]} + ] +} diff --git a/clients/src/main/resources/common/message/ApiVersionsRequest.json b/clients/src/main/resources/common/message/ApiVersionsRequest.json new file mode 100644 index 0000000..b86edbf --- /dev/null +++ b/clients/src/main/resources/common/message/ApiVersionsRequest.json @@ -0,0 +1,32 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 18, + "type": "request", + "listeners": ["zkBroker", "broker", "controller"], + "name": "ApiVersionsRequest", + // Versions 0 through 2 of ApiVersionsRequest are the same. + // + // Version 3 is the first flexible version and adds ClientSoftwareName and ClientSoftwareVersion. + "validVersions": "0-3", + "flexibleVersions": "3+", + "fields": [ + { "name": "ClientSoftwareName", "type": "string", "versions": "3+", + "ignorable": true, "about": "The name of the client." }, + { "name": "ClientSoftwareVersion", "type": "string", "versions": "3+", + "ignorable": true, "about": "The version of the client." } + ] +} diff --git a/clients/src/main/resources/common/message/ApiVersionsResponse.json b/clients/src/main/resources/common/message/ApiVersionsResponse.json new file mode 100644 index 0000000..06c343b --- /dev/null +++ b/clients/src/main/resources/common/message/ApiVersionsResponse.json @@ -0,0 +1,74 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. 
+// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 18, + "type": "response", + "name": "ApiVersionsResponse", + // Version 1 adds throttle time to the response. + // + // Starting in version 2, on quota violation, brokers send out responses before throttling. + // + // Version 3 is the first flexible version. Tagged fields are only supported in the body but + // not in the header. The length of the header must not change in order to guarantee the + // backward compatibility. + // + // Starting from Apache Kafka 2.4 (KIP-511), ApiKeys field is populated with the supported + // versions of the ApiVersionsRequest when an UNSUPPORTED_VERSION error is returned. + "validVersions": "0-3", + "flexibleVersions": "3+", + "fields": [ + { "name": "ErrorCode", "type": "int16", "versions": "0+", + "about": "The top-level error code." }, + { "name": "ApiKeys", "type": "[]ApiVersion", "versions": "0+", + "about": "The APIs supported by the broker.", "fields": [ + { "name": "ApiKey", "type": "int16", "versions": "0+", "mapKey": true, + "about": "The API index." }, + { "name": "MinVersion", "type": "int16", "versions": "0+", + "about": "The minimum supported version, inclusive." }, + { "name": "MaxVersion", "type": "int16", "versions": "0+", + "about": "The maximum supported version, inclusive." } + ]}, + { "name": "ThrottleTimeMs", "type": "int32", "versions": "1+", "ignorable": true, + "about": "The duration in milliseconds for which the request was throttled due to a quota violation, or zero if the request did not violate any quota." }, + { "name": "SupportedFeatures", "type": "[]SupportedFeatureKey", "ignorable": true, + "versions": "3+", "tag": 0, "taggedVersions": "3+", + "about": "Features supported by the broker.", + "fields": [ + { "name": "Name", "type": "string", "versions": "3+", "mapKey": true, + "about": "The name of the feature." }, + { "name": "MinVersion", "type": "int16", "versions": "3+", + "about": "The minimum supported version for the feature." }, + { "name": "MaxVersion", "type": "int16", "versions": "3+", + "about": "The maximum supported version for the feature." } + ] + }, + { "name": "FinalizedFeaturesEpoch", "type": "int64", "versions": "3+", + "tag": 1, "taggedVersions": "3+", "default": "-1", "ignorable": true, + "about": "The monotonically increasing epoch for the finalized features information. Valid values are >= 0. A value of -1 is special and represents unknown epoch."}, + { "name": "FinalizedFeatures", "type": "[]FinalizedFeatureKey", "ignorable": true, + "versions": "3+", "tag": 2, "taggedVersions": "3+", + "about": "List of cluster-wide finalized features. 
The information is valid only if FinalizedFeaturesEpoch >= 0.", + "fields": [ + {"name": "Name", "type": "string", "versions": "3+", "mapKey": true, + "about": "The name of the feature."}, + {"name": "MaxVersionLevel", "type": "int16", "versions": "3+", + "about": "The cluster-wide finalized max version level for the feature."}, + {"name": "MinVersionLevel", "type": "int16", "versions": "3+", + "about": "The cluster-wide finalized min version level for the feature."} + ] + } + ] +} diff --git a/clients/src/main/resources/common/message/BeginQuorumEpochRequest.json b/clients/src/main/resources/common/message/BeginQuorumEpochRequest.json new file mode 100644 index 0000000..d9d6d92 --- /dev/null +++ b/clients/src/main/resources/common/message/BeginQuorumEpochRequest.json @@ -0,0 +1,41 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 53, + "type": "request", + "listeners": ["controller"], + "name": "BeginQuorumEpochRequest", + "validVersions": "0", + "flexibleVersions": "none", + "fields": [ + { "name": "ClusterId", "type": "string", "versions": "0+", + "nullableVersions": "0+", "default": "null"}, + { "name": "Topics", "type": "[]TopicData", + "versions": "0+", "fields": [ + { "name": "TopicName", "type": "string", "versions": "0+", "entityType": "topicName", + "about": "The topic name." }, + { "name": "Partitions", "type": "[]PartitionData", + "versions": "0+", "fields": [ + { "name": "PartitionIndex", "type": "int32", "versions": "0+", + "about": "The partition index." }, + { "name": "LeaderId", "type": "int32", "versions": "0+", "entityType": "brokerId", + "about": "The ID of the newly elected leader"}, + { "name": "LeaderEpoch", "type": "int32", "versions": "0+", + "about": "The epoch of the newly elected leader"} + ]} + ]} + ] +} diff --git a/clients/src/main/resources/common/message/BeginQuorumEpochResponse.json b/clients/src/main/resources/common/message/BeginQuorumEpochResponse.json new file mode 100644 index 0000000..4b7d7f5 --- /dev/null +++ b/clients/src/main/resources/common/message/BeginQuorumEpochResponse.json @@ -0,0 +1,41 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 53, + "type": "response", + "name": "BeginQuorumEpochResponse", + "validVersions": "0", + "flexibleVersions": "none", + "fields": [ + { "name": "ErrorCode", "type": "int16", "versions": "0+", + "about": "The top level error code."}, + { "name": "Topics", "type": "[]TopicData", + "versions": "0+", "fields": [ + { "name": "TopicName", "type": "string", "versions": "0+", "entityType": "topicName", + "about": "The topic name." }, + { "name": "Partitions", "type": "[]PartitionData", + "versions": "0+", "fields": [ + { "name": "PartitionIndex", "type": "int32", "versions": "0+", + "about": "The partition index." }, + { "name": "ErrorCode", "type": "int16", "versions": "0+"}, + { "name": "LeaderId", "type": "int32", "versions": "0+", "entityType": "brokerId", + "about": "The ID of the current leader or -1 if the leader is unknown."}, + { "name": "LeaderEpoch", "type": "int32", "versions": "0+", + "about": "The latest known leader epoch"} + ]} + ]} + ] +} diff --git a/clients/src/main/resources/common/message/BrokerHeartbeatRequest.json b/clients/src/main/resources/common/message/BrokerHeartbeatRequest.json new file mode 100644 index 0000000..2cf2577 --- /dev/null +++ b/clients/src/main/resources/common/message/BrokerHeartbeatRequest.json @@ -0,0 +1,35 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 63, + "type": "request", + "listeners": ["controller"], + "name": "BrokerHeartbeatRequest", + "validVersions": "0", + "flexibleVersions": "0+", + "fields": [ + { "name": "BrokerId", "type": "int32", "versions": "0+", "entityType": "brokerId", + "about": "The broker ID." }, + { "name": "BrokerEpoch", "type": "int64", "versions": "0+", "default": "-1", + "about": "The broker epoch." }, + { "name": "CurrentMetadataOffset", "type": "int64", "versions": "0+", + "about": "The highest metadata offset which the broker has reached." }, + { "name": "WantFence", "type": "bool", "versions": "0+", + "about": "True if the broker wants to be fenced, false otherwise." }, + { "name": "WantShutDown", "type": "bool", "versions": "0+", + "about": "True if the broker wants to be shut down, false otherwise." } + ] +} diff --git a/clients/src/main/resources/common/message/BrokerHeartbeatResponse.json b/clients/src/main/resources/common/message/BrokerHeartbeatResponse.json new file mode 100644 index 0000000..ce9aaba --- /dev/null +++ b/clients/src/main/resources/common/message/BrokerHeartbeatResponse.json @@ -0,0 +1,34 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. 
+// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 63, + "type": "response", + "name": "BrokerHeartbeatResponse", + "validVersions": "0", + "flexibleVersions": "0+", + "fields": [ + { "name": "ThrottleTimeMs", "type": "int32", "versions": "0+", + "about": "Duration in milliseconds for which the request was throttled due to a quota violation, or zero if the request did not violate any quota." }, + { "name": "ErrorCode", "type": "int16", "versions": "0+", + "about": "The error code, or 0 if there was no error." }, + { "name": "IsCaughtUp", "type": "bool", "versions": "0+", "default": "false", + "about": "True if the broker has approximately caught up with the latest metadata." }, + { "name": "IsFenced", "type": "bool", "versions": "0+", "default": "true", + "about": "True if the broker is fenced." }, + { "name": "ShouldShutDown", "type": "bool", "versions": "0+", + "about": "True if the broker should proceed with its shutdown." } + ] +} diff --git a/clients/src/main/resources/common/message/BrokerRegistrationRequest.json b/clients/src/main/resources/common/message/BrokerRegistrationRequest.json new file mode 100644 index 0000000..d96369c --- /dev/null +++ b/clients/src/main/resources/common/message/BrokerRegistrationRequest.json @@ -0,0 +1,55 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey":62, + "type": "request", + "listeners": ["controller"], + "name": "BrokerRegistrationRequest", + "validVersions": "0", + "flexibleVersions": "0+", + "fields": [ + { "name": "BrokerId", "type": "int32", "versions": "0+", "entityType": "brokerId", + "about": "The broker ID." }, + { "name": "ClusterId", "type": "string", "versions": "0+", + "about": "The cluster id of the broker process." }, + { "name": "IncarnationId", "type": "uuid", "versions": "0+", + "about": "The incarnation id of the broker process." }, + { "name": "Listeners", "type": "[]Listener", + "about": "The listeners of this broker", "versions": "0+", "fields": [ + { "name": "Name", "type": "string", "versions": "0+", "mapKey": true, + "about": "The name of the endpoint." }, + { "name": "Host", "type": "string", "versions": "0+", + "about": "The hostname." }, + { "name": "Port", "type": "uint16", "versions": "0+", + "about": "The port." 
}, + { "name": "SecurityProtocol", "type": "int16", "versions": "0+", + "about": "The security protocol." } + ] + }, + { "name": "Features", "type": "[]Feature", + "about": "The features on this broker", "versions": "0+", "fields": [ + { "name": "Name", "type": "string", "versions": "0+", "mapKey": true, + "about": "The feature name." }, + { "name": "MinSupportedVersion", "type": "int16", "versions": "0+", + "about": "The minimum supported feature level." }, + { "name": "MaxSupportedVersion", "type": "int16", "versions": "0+", + "about": "The maximum supported feature level." } + ] + }, + { "name": "Rack", "type": "string", "versions": "0+", "nullableVersions": "0+", + "about": "The rack which this broker is in." } + ] +} diff --git a/clients/src/main/resources/common/message/BrokerRegistrationResponse.json b/clients/src/main/resources/common/message/BrokerRegistrationResponse.json new file mode 100644 index 0000000..1f12123 --- /dev/null +++ b/clients/src/main/resources/common/message/BrokerRegistrationResponse.json @@ -0,0 +1,30 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 62, + "type": "response", + "name": "BrokerRegistrationResponse", + "validVersions": "0", + "flexibleVersions": "0+", + "fields": [ + { "name": "ThrottleTimeMs", "type": "int32", "versions": "0+", + "about": "Duration in milliseconds for which the request was throttled due to a quota violation, or zero if the request did not violate any quota." }, + { "name": "ErrorCode", "type": "int16", "versions": "0+", + "about": "The error code, or 0 if there was no error." }, + { "name": "BrokerEpoch", "type": "int64", "versions": "0+", "default": "-1", + "about": "The broker's assigned epoch, or -1 if none was assigned." } + ] +} diff --git a/clients/src/main/resources/common/message/ConsumerProtocolAssignment.json b/clients/src/main/resources/common/message/ConsumerProtocolAssignment.json new file mode 100644 index 0000000..50a9706 --- /dev/null +++ b/clients/src/main/resources/common/message/ConsumerProtocolAssignment.json @@ -0,0 +1,36 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "type": "data", + "name": "ConsumerProtocolAssignment", + // Assignment part of the Consumer Protocol. + // + // The current implementation assumes that future versions will not break compatibility. When + // it encounters a newer version, it parses it using the current format. This basically means + // that new versions cannot remove or reorder any of the existing fields. + "validVersions": "0-1", + "flexibleVersions": "none", + "fields": [ + { "name": "AssignedPartitions", "type": "[]TopicPartition", "versions": "0+", + "fields": [ + { "name": "Topic", "type": "string", "mapKey": true, "versions": "0+", "entityType": "topicName" }, + { "name": "Partitions", "type": "[]int32", "versions": "0+" } + ] + }, + { "name": "UserData", "type": "bytes", "versions": "0+", "nullableVersions": "0+", + "default": "null", "zeroCopy": true } + ] +} diff --git a/clients/src/main/resources/common/message/ConsumerProtocolSubscription.json b/clients/src/main/resources/common/message/ConsumerProtocolSubscription.json new file mode 100644 index 0000000..e33c16f --- /dev/null +++ b/clients/src/main/resources/common/message/ConsumerProtocolSubscription.json @@ -0,0 +1,37 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "type": "data", + "name": "ConsumerProtocolSubscription", + // Subscription part of the Consumer Protocol. + // + // The current implementation assumes that future versions will not break compatibility. When + // it encounters a newer version, it parses it using the current format. This basically means + // that new versions cannot remove or reorder any of the existing fields. + "validVersions": "0-1", + "flexibleVersions": "none", + "fields": [ + { "name": "Topics", "type": "[]string", "versions": "0+" }, + { "name": "UserData", "type": "bytes", "versions": "0+", "nullableVersions": "0+", + "default": "null", "zeroCopy": true }, + { "name": "OwnedPartitions", "type": "[]TopicPartition", "versions": "1+", "ignorable": true, + "fields": [ + { "name": "Topic", "type": "string", "mapKey": true, "versions": "1+", "entityType": "topicName" }, + { "name": "Partitions", "type": "[]int32", "versions": "1+"} + ] + } + ] +} diff --git a/clients/src/main/resources/common/message/ControlledShutdownRequest.json b/clients/src/main/resources/common/message/ControlledShutdownRequest.json new file mode 100644 index 0000000..49561f7 --- /dev/null +++ b/clients/src/main/resources/common/message/ControlledShutdownRequest.json @@ -0,0 +1,36 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. 
+// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 7, + "type": "request", + "listeners": ["zkBroker", "controller"], + "name": "ControlledShutdownRequest", + // Version 0 of ControlledShutdownRequest has a non-standard request header + // which does not include clientId. Version 1 and later use the standard + // request header. + // + // Version 1 is the same as version 0. + // + // Version 2 adds BrokerEpoch. + "validVersions": "0-3", + "flexibleVersions": "3+", + "fields": [ + { "name": "BrokerId", "type": "int32", "versions": "0+", "entityType": "brokerId", + "about": "The id of the broker for which controlled shutdown has been requested." }, + { "name": "BrokerEpoch", "type": "int64", "versions": "2+", "default": "-1", "ignorable": true, + "about": "The broker epoch." } + ] +} diff --git a/clients/src/main/resources/common/message/ControlledShutdownResponse.json b/clients/src/main/resources/common/message/ControlledShutdownResponse.json new file mode 100644 index 0000000..27feb1b --- /dev/null +++ b/clients/src/main/resources/common/message/ControlledShutdownResponse.json @@ -0,0 +1,34 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 7, + "type": "response", + "name": "ControlledShutdownResponse", + // Versions 1 and 2 are the same as version 0. + "validVersions": "0-3", + "flexibleVersions": "3+", + "fields": [ + { "name": "ErrorCode", "type": "int16", "versions": "0+", + "about": "The top-level error code." }, + { "name": "RemainingPartitions", "type": "[]RemainingPartition", "versions": "0+", + "about": "The partitions that the broker still leads.", "fields": [ + { "name": "TopicName", "type": "string", "versions": "0+", "mapKey": true, "entityType": "topicName", + "about": "The name of the topic." }, + { "name": "PartitionIndex", "type": "int32", "versions": "0+", "mapKey": true, + "about": "The index of the partition." 
} + ]} + ] +} diff --git a/clients/src/main/resources/common/message/CreateAclsRequest.json b/clients/src/main/resources/common/message/CreateAclsRequest.json new file mode 100644 index 0000000..5b3bfed --- /dev/null +++ b/clients/src/main/resources/common/message/CreateAclsRequest.json @@ -0,0 +1,44 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 30, + "type": "request", + "listeners": ["zkBroker", "broker", "controller"], + "name": "CreateAclsRequest", + // Version 1 adds resource pattern type. + // Version 2 enables flexible versions. + "validVersions": "0-2", + "flexibleVersions": "2+", + "fields": [ + { "name": "Creations", "type": "[]AclCreation", "versions": "0+", + "about": "The ACLs that we want to create.", "fields": [ + { "name": "ResourceType", "type": "int8", "versions": "0+", + "about": "The type of the resource." }, + { "name": "ResourceName", "type": "string", "versions": "0+", + "about": "The resource name for the ACL." }, + { "name": "ResourcePatternType", "type": "int8", "versions": "1+", "default": "3", + "about": "The pattern type for the ACL." }, + { "name": "Principal", "type": "string", "versions": "0+", + "about": "The principal for the ACL." }, + { "name": "Host", "type": "string", "versions": "0+", + "about": "The host for the ACL." }, + { "name": "Operation", "type": "int8", "versions": "0+", + "about": "The operation type for the ACL (read, write, etc.)." }, + { "name": "PermissionType", "type": "int8", "versions": "0+", + "about": "The permission type for the ACL (allow, deny, etc.)." } + ]} + ] +} diff --git a/clients/src/main/resources/common/message/CreateAclsResponse.json b/clients/src/main/resources/common/message/CreateAclsResponse.json new file mode 100644 index 0000000..7b0de7e --- /dev/null +++ b/clients/src/main/resources/common/message/CreateAclsResponse.json @@ -0,0 +1,35 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +{ + "apiKey": 30, + "type": "response", + "name": "CreateAclsResponse", + // Starting in version 1, on quota violation, brokers send out responses before throttling. + // Version 2 enables flexible versions. + "validVersions": "0-2", + "flexibleVersions": "2+", + "fields": [ + { "name": "ThrottleTimeMs", "type": "int32", "versions": "0+", + "about": "The duration in milliseconds for which the request was throttled due to a quota violation, or zero if the request did not violate any quota." }, + { "name": "Results", "type": "[]AclCreationResult", "versions": "0+", + "about": "The results for each ACL creation.", "fields": [ + { "name": "ErrorCode", "type": "int16", "versions": "0+", + "about": "The result error, or zero if there was no error." }, + { "name": "ErrorMessage", "type": "string", "nullableVersions": "0+", "versions": "0+", + "about": "The result message, or null if there was no error." } + ]} + ] +} diff --git a/clients/src/main/resources/common/message/CreateDelegationTokenRequest.json b/clients/src/main/resources/common/message/CreateDelegationTokenRequest.json new file mode 100644 index 0000000..d65d490 --- /dev/null +++ b/clients/src/main/resources/common/message/CreateDelegationTokenRequest.json @@ -0,0 +1,37 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 38, + "type": "request", + "listeners": ["zkBroker"], + "name": "CreateDelegationTokenRequest", + // Version 1 is the same as version 0. + // + // Version 2 is the first flexible version. + "validVersions": "0-2", + "flexibleVersions": "2+", + "fields": [ + { "name": "Renewers", "type": "[]CreatableRenewers", "versions": "0+", + "about": "A list of those who are allowed to renew this token before it expires.", "fields": [ + { "name": "PrincipalType", "type": "string", "versions": "0+", + "about": "The type of the Kafka principal." }, + { "name": "PrincipalName", "type": "string", "versions": "0+", + "about": "The name of the Kafka principal." } + ]}, + { "name": "MaxLifetimeMs", "type": "int64", "versions": "0+", + "about": "The maximum lifetime of the token in milliseconds, or -1 to use the server side default." } + ] +} diff --git a/clients/src/main/resources/common/message/CreateDelegationTokenResponse.json b/clients/src/main/resources/common/message/CreateDelegationTokenResponse.json new file mode 100644 index 0000000..74ad905 --- /dev/null +++ b/clients/src/main/resources/common/message/CreateDelegationTokenResponse.json @@ -0,0 +1,45 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. 
+// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 38, + "type": "response", + "name": "CreateDelegationTokenResponse", + // Starting in version 1, on quota violation, brokers send out responses before throttling. + // + // Version 2 is the first flexible version. + "validVersions": "0-2", + "flexibleVersions": "2+", + "fields": [ + { "name": "ErrorCode", "type": "int16", "versions": "0+", + "about": "The top-level error, or zero if there was no error."}, + { "name": "PrincipalType", "type": "string", "versions": "0+", + "about": "The principal type of the token owner." }, + { "name": "PrincipalName", "type": "string", "versions": "0+", + "about": "The name of the token owner." }, + { "name": "IssueTimestampMs", "type": "int64", "versions": "0+", + "about": "When this token was generated." }, + { "name": "ExpiryTimestampMs", "type": "int64", "versions": "0+", + "about": "When this token expires." }, + { "name": "MaxTimestampMs", "type": "int64", "versions": "0+", + "about": "The maximum lifetime of this token." }, + { "name": "TokenId", "type": "string", "versions": "0+", + "about": "The token UUID." }, + { "name": "Hmac", "type": "bytes", "versions": "0+", + "about": "HMAC of the delegation token." }, + { "name": "ThrottleTimeMs", "type": "int32", "versions": "0+", + "about": "The duration in milliseconds for which the request was throttled due to a quota violation, or zero if the request did not violate any quota." } + ] +} diff --git a/clients/src/main/resources/common/message/CreatePartitionsRequest.json b/clients/src/main/resources/common/message/CreatePartitionsRequest.json new file mode 100644 index 0000000..6e24949 --- /dev/null +++ b/clients/src/main/resources/common/message/CreatePartitionsRequest.json @@ -0,0 +1,47 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 37, + "type": "request", + "listeners": ["zkBroker", "broker", "controller"], + "name": "CreatePartitionsRequest", + // Version 1 is the same as version 0. + // + // Version 2 adds flexible version support + // + // Version 3 is identical to version 2 but may return a THROTTLING_QUOTA_EXCEEDED error + // in the response if the partitions creation is throttled (KIP-599). 
+ "validVersions": "0-3", + "flexibleVersions": "2+", + "fields": [ + { "name": "Topics", "type": "[]CreatePartitionsTopic", "versions": "0+", + "about": "Each topic that we want to create new partitions inside.", "fields": [ + { "name": "Name", "type": "string", "versions": "0+", "mapKey": true, "entityType": "topicName", + "about": "The topic name." }, + { "name": "Count", "type": "int32", "versions": "0+", + "about": "The new partition count." }, + { "name": "Assignments", "type": "[]CreatePartitionsAssignment", "versions": "0+", "nullableVersions": "0+", + "about": "The new partition assignments.", "fields": [ + { "name": "BrokerIds", "type": "[]int32", "versions": "0+", "entityType": "brokerId", + "about": "The assigned broker IDs." } + ]} + ]}, + { "name": "TimeoutMs", "type": "int32", "versions": "0+", + "about": "The time in ms to wait for the partitions to be created." }, + { "name": "ValidateOnly", "type": "bool", "versions": "0+", + "about": "If true, then validate the request, but don't actually increase the number of partitions." } + ] +} diff --git a/clients/src/main/resources/common/message/CreatePartitionsResponse.json b/clients/src/main/resources/common/message/CreatePartitionsResponse.json new file mode 100644 index 0000000..ef9f1f6 --- /dev/null +++ b/clients/src/main/resources/common/message/CreatePartitionsResponse.json @@ -0,0 +1,41 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 37, + "type": "response", + "name": "CreatePartitionsResponse", + // Starting in version 1, on quota violation, brokers send out responses before throttling. + // + // Version 2 adds flexible version support + // + // Version 3 is identical to version 2 but may return a THROTTLING_QUOTA_EXCEEDED error + // in the response if the partitions creation is throttled (KIP-599). + "validVersions": "0-3", + "flexibleVersions": "2+", + "fields": [ + { "name": "ThrottleTimeMs", "type": "int32", "versions": "0+", + "about": "The duration in milliseconds for which the request was throttled due to a quota violation, or zero if the request did not violate any quota." }, + { "name": "Results", "type": "[]CreatePartitionsTopicResult", "versions": "0+", + "about": "The partition creation results for each topic.", "fields": [ + { "name": "Name", "type": "string", "versions": "0+", "entityType": "topicName", + "about": "The topic name." 
}, + { "name": "ErrorCode", "type": "int16", "versions": "0+", + "about": "The result error, or zero if there was no error."}, + { "name": "ErrorMessage", "type": "string", "versions": "0+", "nullableVersions": "0+", + "default": "null", "about": "The result message, or null if there was no error."} + ]} + ] +} diff --git a/clients/src/main/resources/common/message/CreateTopicsRequest.json b/clients/src/main/resources/common/message/CreateTopicsRequest.json new file mode 100644 index 0000000..0882de9 --- /dev/null +++ b/clients/src/main/resources/common/message/CreateTopicsRequest.json @@ -0,0 +1,63 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 19, + "type": "request", + "listeners": ["zkBroker", "broker", "controller"], + "name": "CreateTopicsRequest", + // Version 1 adds validateOnly. + // + // Version 4 makes partitions/replicationFactor optional even when assignments are not present (KIP-464) + // + // Version 5 is the first flexible version. + // Version 5 also returns topic configs in the response (KIP-525). + // + // Version 6 is identical to version 5 but may return a THROTTLING_QUOTA_EXCEEDED error + // in the response if the topics creation is throttled (KIP-599). + // + // Version 7 is the same as version 6. + "validVersions": "0-7", + "flexibleVersions": "5+", + "fields": [ + { "name": "Topics", "type": "[]CreatableTopic", "versions": "0+", + "about": "The topics to create.", "fields": [ + { "name": "Name", "type": "string", "versions": "0+", "mapKey": true, "entityType": "topicName", + "about": "The topic name." }, + { "name": "NumPartitions", "type": "int32", "versions": "0+", + "about": "The number of partitions to create in the topic, or -1 if we are either specifying a manual partition assignment or using the default partitions." }, + { "name": "ReplicationFactor", "type": "int16", "versions": "0+", + "about": "The number of replicas to create for each partition in the topic, or -1 if we are either specifying a manual partition assignment or using the default replication factor." }, + { "name": "Assignments", "type": "[]CreatableReplicaAssignment", "versions": "0+", + "about": "The manual partition assignment, or the empty array if we are using automatic assignment.", "fields": [ + { "name": "PartitionIndex", "type": "int32", "versions": "0+", "mapKey": true, + "about": "The partition index." }, + { "name": "BrokerIds", "type": "[]int32", "versions": "0+", "entityType": "brokerId", + "about": "The brokers to place the partition on." } + ]}, + { "name": "Configs", "type": "[]CreateableTopicConfig", "versions": "0+", + "about": "The custom topic configurations to set.", "fields": [ + { "name": "Name", "type": "string", "versions": "0+" , "mapKey": true, + "about": "The configuration name." 
}, + { "name": "Value", "type": "string", "versions": "0+", "nullableVersions": "0+", + "about": "The configuration value." } + ]} + ]}, + { "name": "timeoutMs", "type": "int32", "versions": "0+", "default": "60000", + "about": "How long to wait in milliseconds before timing out the request." }, + { "name": "validateOnly", "type": "bool", "versions": "1+", "default": "false", "ignorable": false, + "about": "If true, check that the topics can be created as specified, but don't create anything." } + ] +} diff --git a/clients/src/main/resources/common/message/CreateTopicsResponse.json b/clients/src/main/resources/common/message/CreateTopicsResponse.json new file mode 100644 index 0000000..c1bf882 --- /dev/null +++ b/clients/src/main/resources/common/message/CreateTopicsResponse.json @@ -0,0 +1,70 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 19, + "type": "response", + "name": "CreateTopicsResponse", + // Version 1 adds a per-topic error message string. + // + // Version 2 adds the throttle time. + // + // Starting in version 3, on quota violation, brokers send out responses before throttling. + // + // Version 4 makes partitions/replicationFactor optional even when assignments are not present (KIP-464). + // + // Version 5 is the first flexible version. + // Version 5 also returns topic configs in the response (KIP-525). + // + // Version 6 is identical to version 5 but may return a THROTTLING_QUOTA_EXCEEDED error + // in the response if the topics creation is throttled (KIP-599). + // + // Version 7 returns the topic ID of the newly created topic if creation is sucessful. + "validVersions": "0-7", + "flexibleVersions": "5+", + "fields": [ + { "name": "ThrottleTimeMs", "type": "int32", "versions": "2+", "ignorable": true, + "about": "The duration in milliseconds for which the request was throttled due to a quota violation, or zero if the request did not violate any quota." }, + { "name": "Topics", "type": "[]CreatableTopicResult", "versions": "0+", + "about": "Results for each topic we tried to create.", "fields": [ + { "name": "Name", "type": "string", "versions": "0+", "mapKey": true, "entityType": "topicName", + "about": "The topic name." }, + { "name": "TopicId", "type": "uuid", "versions": "7+", "ignorable": true, "about": "The unique topic ID"}, + { "name": "ErrorCode", "type": "int16", "versions": "0+", + "about": "The error code, or 0 if there was no error." }, + { "name": "ErrorMessage", "type": "string", "versions": "1+", "nullableVersions": "0+", "ignorable": true, + "about": "The error message, or null if there was no error." 
}, + { "name": "TopicConfigErrorCode", "type": "int16", "versions": "5+", "tag": 0, "taggedVersions": "5+", "ignorable": true, + "about": "Optional topic config error returned if configs are not returned in the response." }, + { "name": "NumPartitions", "type": "int32", "versions": "5+", "default": "-1", "ignorable": true, + "about": "Number of partitions of the topic." }, + { "name": "ReplicationFactor", "type": "int16", "versions": "5+", "default": "-1", "ignorable": true, + "about": "Replication factor of the topic." }, + { "name": "Configs", "type": "[]CreatableTopicConfigs", "versions": "5+", "nullableVersions": "5+", "ignorable": true, + "about": "Configuration of the topic.", "fields": [ + { "name": "Name", "type": "string", "versions": "5+", + "about": "The configuration name." }, + { "name": "Value", "type": "string", "versions": "5+", "nullableVersions": "5+", + "about": "The configuration value." }, + { "name": "ReadOnly", "type": "bool", "versions": "5+", + "about": "True if the configuration is read-only." }, + { "name": "ConfigSource", "type": "int8", "versions": "5+", "default": "-1", "ignorable": true, + "about": "The configuration source." }, + { "name": "IsSensitive", "type": "bool", "versions": "5+", + "about": "True if this configuration is sensitive." } + ]} + ]} + ] +} diff --git a/clients/src/main/resources/common/message/DefaultPrincipalData.json b/clients/src/main/resources/common/message/DefaultPrincipalData.json new file mode 100644 index 0000000..e06295d --- /dev/null +++ b/clients/src/main/resources/common/message/DefaultPrincipalData.json @@ -0,0 +1,31 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "type": "data", + "name": "DefaultPrincipalData", + // The encoding format for default Kafka principal in + // org.apache.kafka.common.security.authenticator.DefaultKafkaPrincipalBuilder. + "validVersions": "0", + "flexibleVersions": "0+", + "fields": [ + {"name": "Type", "type": "string", "versions": "0+", + "about": "The principal type"}, + {"name": "Name", "type": "string", "versions": "0+", + "about": "The principal name"}, + {"name": "TokenAuthenticated", "type": "bool", "versions": "0+", + "about": "Whether the principal was authenticated by a delegation token on the forwarding broker."} + ] +} diff --git a/clients/src/main/resources/common/message/DeleteAclsRequest.json b/clients/src/main/resources/common/message/DeleteAclsRequest.json new file mode 100644 index 0000000..fd7c152 --- /dev/null +++ b/clients/src/main/resources/common/message/DeleteAclsRequest.json @@ -0,0 +1,44 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. 
+// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 31, + "type": "request", + "listeners": ["zkBroker", "broker", "controller"], + "name": "DeleteAclsRequest", + // Version 1 adds the pattern type. + // Version 2 enables flexible versions. + "validVersions": "0-2", + "flexibleVersions": "2+", + "fields": [ + { "name": "Filters", "type": "[]DeleteAclsFilter", "versions": "0+", + "about": "The filters to use when deleting ACLs.", "fields": [ + { "name": "ResourceTypeFilter", "type": "int8", "versions": "0+", + "about": "The resource type." }, + { "name": "ResourceNameFilter", "type": "string", "versions": "0+", "nullableVersions": "0+", + "about": "The resource name." }, + { "name": "PatternTypeFilter", "type": "int8", "versions": "1+", "default": "3", "ignorable": false, + "about": "The pattern type." }, + { "name": "PrincipalFilter", "type": "string", "versions": "0+", "nullableVersions": "0+", + "about": "The principal filter, or null to accept all principals." }, + { "name": "HostFilter", "type": "string", "versions": "0+", "nullableVersions": "0+", + "about": "The host filter, or null to accept all hosts." }, + { "name": "Operation", "type": "int8", "versions": "0+", + "about": "The ACL operation." }, + { "name": "PermissionType", "type": "int8", "versions": "0+", + "about": "The permission type." } + ]} + ] +} diff --git a/clients/src/main/resources/common/message/DeleteAclsResponse.json b/clients/src/main/resources/common/message/DeleteAclsResponse.json new file mode 100644 index 0000000..08f5702 --- /dev/null +++ b/clients/src/main/resources/common/message/DeleteAclsResponse.json @@ -0,0 +1,57 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 31, + "type": "response", + "name": "DeleteAclsResponse", + // Version 1 adds the resource pattern type. + // Starting in version 1, on quota violation, brokers send out responses before throttling. + // Version 2 enables flexible versions. + "validVersions": "0-2", + "flexibleVersions": "2+", + "fields": [ + { "name": "ThrottleTimeMs", "type": "int32", "versions": "0+", + "about": "The duration in milliseconds for which the request was throttled due to a quota violation, or zero if the request did not violate any quota." 
}, + { "name": "FilterResults", "type": "[]DeleteAclsFilterResult", "versions": "0+", + "about": "The results for each filter.", "fields": [ + { "name": "ErrorCode", "type": "int16", "versions": "0+", + "about": "The error code, or 0 if the filter succeeded." }, + { "name": "ErrorMessage", "type": "string", "versions": "0+", "nullableVersions": "0+", + "about": "The error message, or null if the filter succeeded." }, + { "name": "MatchingAcls", "type": "[]DeleteAclsMatchingAcl", "versions": "0+", + "about": "The ACLs which matched this filter.", "fields": [ + { "name": "ErrorCode", "type": "int16", "versions": "0+", + "about": "The deletion error code, or 0 if the deletion succeeded." }, + { "name": "ErrorMessage", "type": "string", "versions": "0+", "nullableVersions": "0+", + "about": "The deletion error message, or null if the deletion succeeded." }, + { "name": "ResourceType", "type": "int8", "versions": "0+", + "about": "The ACL resource type." }, + { "name": "ResourceName", "type": "string", "versions": "0+", + "about": "The ACL resource name." }, + { "name": "PatternType", "type": "int8", "versions": "1+", "default": "3", "ignorable": false, + "about": "The ACL resource pattern type." }, + { "name": "Principal", "type": "string", "versions": "0+", + "about": "The ACL principal." }, + { "name": "Host", "type": "string", "versions": "0+", + "about": "The ACL host." }, + { "name": "Operation", "type": "int8", "versions": "0+", + "about": "The ACL operation." }, + { "name": "PermissionType", "type": "int8", "versions": "0+", + "about": "The ACL permission type." } + ]} + ]} + ] +} diff --git a/clients/src/main/resources/common/message/DeleteGroupsRequest.json b/clients/src/main/resources/common/message/DeleteGroupsRequest.json new file mode 100644 index 0000000..1ac6a05 --- /dev/null +++ b/clients/src/main/resources/common/message/DeleteGroupsRequest.json @@ -0,0 +1,30 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 42, + "type": "request", + "listeners": ["zkBroker", "broker"], + "name": "DeleteGroupsRequest", + // Version 1 is the same as version 0. + // + // Version 2 is the first flexible version. + "validVersions": "0-2", + "flexibleVersions": "2+", + "fields": [ + { "name": "GroupsNames", "type": "[]string", "versions": "0+", "entityType": "groupId", + "about": "The group names to delete." } + ] +} diff --git a/clients/src/main/resources/common/message/DeleteGroupsResponse.json b/clients/src/main/resources/common/message/DeleteGroupsResponse.json new file mode 100644 index 0000000..37e06a5 --- /dev/null +++ b/clients/src/main/resources/common/message/DeleteGroupsResponse.json @@ -0,0 +1,36 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. 
See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 42, + "type": "response", + "name": "DeleteGroupsResponse", + // Starting in version 1, on quota violation, brokers send out responses before throttling. + // + // Version 2 is the first flexible version. + "validVersions": "0-2", + "flexibleVersions": "2+", + "fields": [ + { "name": "ThrottleTimeMs", "type": "int32", "versions": "0+", + "about": "The duration in milliseconds for which the request was throttled due to a quota violation, or zero if the request did not violate any quota." }, + { "name": "Results", "type": "[]DeletableGroupResult", "versions": "0+", + "about": "The deletion results", "fields": [ + { "name": "GroupId", "type": "string", "versions": "0+", "mapKey": true, "entityType": "groupId", + "about": "The group id" }, + { "name": "ErrorCode", "type": "int16", "versions": "0+", + "about": "The deletion error, or 0 if the deletion succeeded." } + ]} + ] +} diff --git a/clients/src/main/resources/common/message/DeleteRecordsRequest.json b/clients/src/main/resources/common/message/DeleteRecordsRequest.json new file mode 100644 index 0000000..06a12d8 --- /dev/null +++ b/clients/src/main/resources/common/message/DeleteRecordsRequest.json @@ -0,0 +1,42 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 21, + "type": "request", + "listeners": ["zkBroker", "broker"], + "name": "DeleteRecordsRequest", + // Version 1 is the same as version 0. + + // Version 2 is the first flexible version. + "validVersions": "0-2", + "flexibleVersions": "2+", + "fields": [ + { "name": "Topics", "type": "[]DeleteRecordsTopic", "versions": "0+", + "about": "Each topic that we want to delete records from.", "fields": [ + { "name": "Name", "type": "string", "versions": "0+", "entityType": "topicName", + "about": "The topic name." }, + { "name": "Partitions", "type": "[]DeleteRecordsPartition", "versions": "0+", + "about": "Each partition that we want to delete records from.", "fields": [ + { "name": "PartitionIndex", "type": "int32", "versions": "0+", + "about": "The partition index." 
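The DeleteGroups pair above backs `Admin#deleteConsumerGroups`. A minimal sketch, assuming a reachable broker and made-up group ids:

```java
import java.util.List;
import java.util.Properties;

import org.apache.kafka.clients.admin.Admin;
import org.apache.kafka.clients.admin.AdminClientConfig;

public class DeleteGroupsExample {
    public static void main(String[] args) throws Exception {
        Properties props = new Properties();
        props.put(AdminClientConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9092"); // placeholder broker

        try (Admin admin = Admin.create(props)) {
            // Serialized as a DeleteGroupsRequest; per-group error codes from the
            // DeleteGroupsResponse surface as exceptions on the returned futures.
            admin.deleteConsumerGroups(List.of("inactive-group-1", "inactive-group-2")) // placeholder ids
                 .all()
                 .get();
            System.out.println("Groups deleted");
        }
    }
}
```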
}, + { "name": "Offset", "type": "int64", "versions": "0+", + "about": "The deletion offset." } + ]} + ]}, + { "name": "TimeoutMs", "type": "int32", "versions": "0+", + "about": "How long to wait for the deletion to complete, in milliseconds." } + ] +} diff --git a/clients/src/main/resources/common/message/DeleteRecordsResponse.json b/clients/src/main/resources/common/message/DeleteRecordsResponse.json new file mode 100644 index 0000000..bfc0a56 --- /dev/null +++ b/clients/src/main/resources/common/message/DeleteRecordsResponse.json @@ -0,0 +1,43 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 21, + "type": "response", + "name": "DeleteRecordsResponse", + // Starting in version 1, on quota violation, brokers send out responses before throttling. + + // Version 2 is the first flexible version. + "validVersions": "0-2", + "flexibleVersions": "2+", + "fields": [ + { "name": "ThrottleTimeMs", "type": "int32", "versions": "0+", + "about": "The duration in milliseconds for which the request was throttled due to a quota violation, or zero if the request did not violate any quota." }, + { "name": "Topics", "type": "[]DeleteRecordsTopicResult", "versions": "0+", + "about": "Each topic that we wanted to delete records from.", "fields": [ + { "name": "Name", "type": "string", "versions": "0+", "mapKey": true, "entityType": "topicName", + "about": "The topic name." }, + { "name": "Partitions", "type": "[]DeleteRecordsPartitionResult", "versions": "0+", + "about": "Each partition that we wanted to delete records from.", "fields": [ + { "name": "PartitionIndex", "type": "int32", "versions": "0+", "mapKey": true, + "about": "The partition index." }, + { "name": "LowWatermark", "type": "int64", "versions": "0+", + "about": "The partition low water mark." }, + { "name": "ErrorCode", "type": "int16", "versions": "0+", + "about": "The deletion error code, or 0 if the deletion succeeded." } + ]} + ]} + ] +} diff --git a/clients/src/main/resources/common/message/DeleteTopicsRequest.json b/clients/src/main/resources/common/message/DeleteTopicsRequest.json new file mode 100644 index 0000000..19bdc8a --- /dev/null +++ b/clients/src/main/resources/common/message/DeleteTopicsRequest.json @@ -0,0 +1,42 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 20, + "type": "request", + "listeners": ["zkBroker", "broker", "controller"], + "name": "DeleteTopicsRequest", + // Versions 0, 1, 2, and 3 are the same. + // + // Version 4 is the first flexible version. + // + // Version 5 adds ErrorMessage in the response and may return a THROTTLING_QUOTA_EXCEEDED error + // in the response if the topics deletion is throttled (KIP-599). + // + // Version 6 reorganizes topics, adds topic IDs and allows topic names to be null. + "validVersions": "0-6", + "flexibleVersions": "4+", + "fields": [ + { "name": "Topics", "type": "[]DeleteTopicState", "versions": "6+", "about": "The name or topic ID of the topic", + "fields": [ + {"name": "Name", "type": "string", "versions": "6+", "nullableVersions": "6+", "default": "null", "entityType": "topicName", "about": "The topic name"}, + {"name": "TopicId", "type": "uuid", "versions": "6+", "about": "The unique topic ID"} + ]}, + { "name": "TopicNames", "type": "[]string", "versions": "0-5", "entityType": "topicName", "ignorable": true, + "about": "The names of the topics to delete" }, + { "name": "TimeoutMs", "type": "int32", "versions": "0+", + "about": "The length of time in milliseconds to wait for the deletions to complete." } + ] +} diff --git a/clients/src/main/resources/common/message/DeleteTopicsResponse.json b/clients/src/main/resources/common/message/DeleteTopicsResponse.json new file mode 100644 index 0000000..19a8163 --- /dev/null +++ b/clients/src/main/resources/common/message/DeleteTopicsResponse.json @@ -0,0 +1,50 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 20, + "type": "response", + "name": "DeleteTopicsResponse", + // Version 1 adds the throttle time. + // + // Starting in version 2, on quota violation, brokers send out responses before throttling. + // + // Starting in version 3, a TOPIC_DELETION_DISABLED error code may be returned. + // + // Version 4 is the first flexible version. + // + // Version 5 adds ErrorMessage in the response and may return a THROTTLING_QUOTA_EXCEEDED error + // in the response if the topics deletion is throttled (KIP-599). + // + // Version 6 adds topic ID to responses. An UNSUPPORTED_VERSION error code will be returned when attempting to + // delete using topic IDs when IBP < 2.8. 
UNKNOWN_TOPIC_ID error code will be returned when IBP is at least 2.8, but + // the topic ID was not found. + "validVersions": "0-6", + "flexibleVersions": "4+", + "fields": [ + { "name": "ThrottleTimeMs", "type": "int32", "versions": "1+", "ignorable": true, + "about": "The duration in milliseconds for which the request was throttled due to a quota violation, or zero if the request did not violate any quota." }, + { "name": "Responses", "type": "[]DeletableTopicResult", "versions": "0+", + "about": "The results for each topic we tried to delete.", "fields": [ + { "name": "Name", "type": "string", "versions": "0+", "nullableVersions": "6+", "mapKey": true, "entityType": "topicName", + "about": "The topic name" }, + {"name": "TopicId", "type": "uuid", "versions": "6+", "ignorable": true, "about": "the unique topic ID"}, + { "name": "ErrorCode", "type": "int16", "versions": "0+", + "about": "The deletion error, or 0 if the deletion succeeded." }, + { "name": "ErrorMessage", "type": "string", "versions": "5+", "nullableVersions": "5+", "ignorable": true, "default": "null", + "about": "The error message, or null if there was no error." } + ]} + ] +} diff --git a/clients/src/main/resources/common/message/DescribeAclsRequest.json b/clients/src/main/resources/common/message/DescribeAclsRequest.json new file mode 100644 index 0000000..58886da --- /dev/null +++ b/clients/src/main/resources/common/message/DescribeAclsRequest.json @@ -0,0 +1,41 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 29, + "type": "request", + "listeners": ["zkBroker", "broker", "controller"], + "name": "DescribeAclsRequest", + // Version 1 adds resource pattern type. + // Version 2 enables flexible versions. + "validVersions": "0-2", + "flexibleVersions": "2+", + "fields": [ + { "name": "ResourceTypeFilter", "type": "int8", "versions": "0+", + "about": "The resource type." }, + { "name": "ResourceNameFilter", "type": "string", "versions": "0+", "nullableVersions": "0+", + "about": "The resource name, or null to match any resource name." }, + { "name": "PatternTypeFilter", "type": "int8", "versions": "1+", "default": "3", "ignorable": false, + "about": "The resource pattern to match." }, + { "name": "PrincipalFilter", "type": "string", "versions": "0+", "nullableVersions": "0+", + "about": "The principal to match, or null to match any principal." }, + { "name": "HostFilter", "type": "string", "versions": "0+", "nullableVersions": "0+", + "about": "The host to match, or null to match any host." }, + { "name": "Operation", "type": "int8", "versions": "0+", + "about": "The operation to match." }, + { "name": "PermissionType", "type": "int8", "versions": "0+", + "about": "The permission type to match." 
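The DeleteTopics pair above backs `Admin#deleteTopics`. A minimal sketch deleting by name (the topic-ID path added in version 6 is not shown); the topic names and timeout are illustrative only:

```java
import java.util.List;
import java.util.Properties;

import org.apache.kafka.clients.admin.Admin;
import org.apache.kafka.clients.admin.AdminClientConfig;
import org.apache.kafka.clients.admin.DeleteTopicsOptions;

public class DeleteTopicsExample {
    public static void main(String[] args) throws Exception {
        Properties props = new Properties();
        props.put(AdminClientConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9092"); // placeholder broker

        try (Admin admin = Admin.create(props)) {
            // The client-side timeout maps onto the TimeoutMs field of DeleteTopicsRequest.
            DeleteTopicsOptions options = new DeleteTopicsOptions().timeoutMs(30_000);

            admin.deleteTopics(List.of("old-topic-a", "old-topic-b"), options) // placeholder topic names
                 .all()
                 .get();
            System.out.println("Topics deleted");
        }
    }
}
```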
} + ] +} diff --git a/clients/src/main/resources/common/message/DescribeAclsResponse.json b/clients/src/main/resources/common/message/DescribeAclsResponse.json new file mode 100644 index 0000000..0ae72d6 --- /dev/null +++ b/clients/src/main/resources/common/message/DescribeAclsResponse.json @@ -0,0 +1,53 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 29, + "type": "response", + "name": "DescribeAclsResponse", + // Version 1 adds PatternType. + // Starting in version 1, on quota violation, brokers send out responses before throttling. + // Version 2 enables flexible versions. + "validVersions": "0-2", + "flexibleVersions": "2+", + "fields": [ + { "name": "ThrottleTimeMs", "type": "int32", "versions": "0+", + "about": "The duration in milliseconds for which the request was throttled due to a quota violation, or zero if the request did not violate any quota." }, + { "name": "ErrorCode", "type": "int16", "versions": "0+", + "about": "The error code, or 0 if there was no error." }, + { "name": "ErrorMessage", "type": "string", "versions": "0+", "nullableVersions": "0+", + "about": "The error message, or null if there was no error." }, + { "name": "Resources", "type": "[]DescribeAclsResource", "versions": "0+", + "about": "Each Resource that is referenced in an ACL.", "fields": [ + { "name": "ResourceType", "type": "int8", "versions": "0+", + "about": "The resource type." }, + { "name": "ResourceName", "type": "string", "versions": "0+", + "about": "The resource name." }, + { "name": "PatternType", "type": "int8", "versions": "1+", "default": "3", "ignorable": false, + "about": "The resource pattern type." }, + { "name": "Acls", "type": "[]AclDescription", "versions": "0+", + "about": "The ACLs.", "fields": [ + { "name": "Principal", "type": "string", "versions": "0+", + "about": "The ACL principal." }, + { "name": "Host", "type": "string", "versions": "0+", + "about": "The ACL host." }, + { "name": "Operation", "type": "int8", "versions": "0+", + "about": "The ACL operation." }, + { "name": "PermissionType", "type": "int8", "versions": "0+", + "about": "The ACL permission type." } + ]} + ]} + ] +} diff --git a/clients/src/main/resources/common/message/DescribeClientQuotasRequest.json b/clients/src/main/resources/common/message/DescribeClientQuotasRequest.json new file mode 100644 index 0000000..d14cfc9 --- /dev/null +++ b/clients/src/main/resources/common/message/DescribeClientQuotasRequest.json @@ -0,0 +1,37 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. 
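The DescribeAcls pair above backs `Admin#describeAcls`, which takes a single ACL binding filter. A minimal sketch that lists every ACL the caller is authorized to see, with a placeholder bootstrap address:

```java
import java.util.Collection;
import java.util.Properties;

import org.apache.kafka.clients.admin.Admin;
import org.apache.kafka.clients.admin.AdminClientConfig;
import org.apache.kafka.common.acl.AclBinding;
import org.apache.kafka.common.acl.AclBindingFilter;

public class DescribeAclsExample {
    public static void main(String[] args) throws Exception {
        Properties props = new Properties();
        props.put(AdminClientConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9092"); // placeholder broker

        try (Admin admin = Admin.create(props)) {
            // AclBindingFilter.ANY matches every resource pattern and every ACL entry.
            Collection<AclBinding> acls = admin.describeAcls(AclBindingFilter.ANY).values().get();
            acls.forEach(System.out::println);
        }
    }
}
```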
+// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 48, + "type": "request", + "listeners": ["zkBroker", "broker"], + "name": "DescribeClientQuotasRequest", + // Version 1 enables flexible versions. + "validVersions": "0-1", + "flexibleVersions": "1+", + "fields": [ + { "name": "Components", "type": "[]ComponentData", "versions": "0+", + "about": "Filter components to apply to quota entities.", "fields": [ + { "name": "EntityType", "type": "string", "versions": "0+", + "about": "The entity type that the filter component applies to." }, + { "name": "MatchType", "type": "int8", "versions": "0+", + "about": "How to match the entity {0 = exact name, 1 = default name, 2 = any specified name}." }, + { "name": "Match", "type": "string", "versions": "0+", "nullableVersions": "0+", + "about": "The string to match against, or null if unused for the match type." } + ]}, + { "name": "Strict", "type": "bool", "versions": "0+", + "about": "Whether the match is strict, i.e. should exclude entities with unspecified entity types." } + ] +} diff --git a/clients/src/main/resources/common/message/DescribeClientQuotasResponse.json b/clients/src/main/resources/common/message/DescribeClientQuotasResponse.json new file mode 100644 index 0000000..0dd0c9c --- /dev/null +++ b/clients/src/main/resources/common/message/DescribeClientQuotasResponse.json @@ -0,0 +1,47 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +{ + "apiKey": 48, + "type": "response", + "name": "DescribeClientQuotasResponse", + // Version 1 enables flexible versions. + "validVersions": "0-1", + "flexibleVersions": "1+", + "fields": [ + { "name": "ThrottleTimeMs", "type": "int32", "versions": "0+", + "about": "The duration in milliseconds for which the request was throttled due to a quota violation, or zero if the request did not violate any quota." }, + { "name": "ErrorCode", "type": "int16", "versions": "0+", + "about": "The error code, or `0` if the quota description succeeded." }, + { "name": "ErrorMessage", "type": "string", "versions": "0+", "nullableVersions": "0+", + "about": "The error message, or `null` if the quota description succeeded." 
}, + { "name": "Entries", "type": "[]EntryData", "versions": "0+", "nullableVersions": "0+", + "about": "A result entry.", "fields": [ + { "name": "Entity", "type": "[]EntityData", "versions": "0+", + "about": "The quota entity description.", "fields": [ + { "name": "EntityType", "type": "string", "versions": "0+", + "about": "The entity type." }, + { "name": "EntityName", "type": "string", "versions": "0+", "nullableVersions": "0+", + "about": "The entity name, or null if the default." } + ]}, + { "name": "Values", "type": "[]ValueData", "versions": "0+", + "about": "The quota values for the entity.", "fields": [ + { "name": "Key", "type": "string", "versions": "0+", + "about": "The quota configuration key." }, + { "name": "Value", "type": "float64", "versions": "0+", + "about": "The quota configuration value." } + ]} + ]} + ] +} diff --git a/clients/src/main/resources/common/message/DescribeClusterRequest.json b/clients/src/main/resources/common/message/DescribeClusterRequest.json new file mode 100644 index 0000000..192e4d8 --- /dev/null +++ b/clients/src/main/resources/common/message/DescribeClusterRequest.json @@ -0,0 +1,27 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 60, + "type": "request", + "listeners": ["zkBroker", "broker"], + "name": "DescribeClusterRequest", + "validVersions": "0", + "flexibleVersions": "0+", + "fields": [ + { "name": "IncludeClusterAuthorizedOperations", "type": "bool", "versions": "0+", + "about": "Whether to include cluster authorized operations." } + ] +} diff --git a/clients/src/main/resources/common/message/DescribeClusterResponse.json b/clients/src/main/resources/common/message/DescribeClusterResponse.json new file mode 100644 index 0000000..084ff54 --- /dev/null +++ b/clients/src/main/resources/common/message/DescribeClusterResponse.json @@ -0,0 +1,47 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +{ + "apiKey": 60, + "type": "response", + "name": "DescribeClusterResponse", + "validVersions": "0", + "flexibleVersions": "0+", + "fields": [ + { "name": "ThrottleTimeMs", "type": "int32", "versions": "0+", + "about": "The duration in milliseconds for which the request was throttled due to a quota violation, or zero if the request did not violate any quota." }, + { "name": "ErrorCode", "type": "int16", "versions": "0+", + "about": "The top-level error code, or 0 if there was no error" }, + { "name": "ErrorMessage", "type": "string", "versions": "0+", "nullableVersions": "0+", "default": "null", + "about": "The top-level error message, or null if there was no error." }, + { "name": "ClusterId", "type": "string", "versions": "0+", + "about": "The cluster ID that responding broker belongs to." }, + { "name": "ControllerId", "type": "int32", "versions": "0+", "default": "-1", "entityType": "brokerId", + "about": "The ID of the controller broker." }, + { "name": "Brokers", "type": "[]DescribeClusterBroker", "versions": "0+", + "about": "Each broker in the response.", "fields": [ + { "name": "BrokerId", "type": "int32", "versions": "0+", "mapKey": true, "entityType": "brokerId", + "about": "The broker ID." }, + { "name": "Host", "type": "string", "versions": "0+", + "about": "The broker hostname." }, + { "name": "Port", "type": "int32", "versions": "0+", + "about": "The broker port." }, + { "name": "Rack", "type": "string", "versions": "0+", "nullableVersions": "0+", "default": "null", + "about": "The rack of the broker, or null if it has not been assigned to a rack." } + ]}, + { "name": "ClusterAuthorizedOperations", "type": "int32", "versions": "0+", "default": "-2147483648", + "about": "32-bit bitfield to represent authorized operations for this cluster." } + ] +} diff --git a/clients/src/main/resources/common/message/DescribeConfigsRequest.json b/clients/src/main/resources/common/message/DescribeConfigsRequest.json new file mode 100644 index 0000000..f48b168 --- /dev/null +++ b/clients/src/main/resources/common/message/DescribeConfigsRequest.json @@ -0,0 +1,41 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 32, + "type": "request", + "listeners": ["zkBroker", "broker"], + "name": "DescribeConfigsRequest", + // Version 1 adds IncludeSynonyms. + // Version 2 is the same as version 1. + // Version 4 enables flexible versions. + "validVersions": "0-4", + "flexibleVersions": "4+", + "fields": [ + { "name": "Resources", "type": "[]DescribeConfigsResource", "versions": "0+", + "about": "The resources whose configurations we want to describe.", "fields": [ + { "name": "ResourceType", "type": "int8", "versions": "0+", + "about": "The resource type." 
}, + { "name": "ResourceName", "type": "string", "versions": "0+", + "about": "The resource name." }, + { "name": "ConfigurationKeys", "type": "[]string", "versions": "0+", "nullableVersions": "0+", + "about": "The configuration keys to list, or null to list all configuration keys." } + ]}, + { "name": "IncludeSynonyms", "type": "bool", "versions": "1+", "default": "false", "ignorable": false, + "about": "True if we should include all synonyms." }, + { "name": "IncludeDocumentation", "type": "bool", "versions": "3+", "default": "false", "ignorable": false, + "about": "True if we should include configuration documentation." } + ] +} diff --git a/clients/src/main/resources/common/message/DescribeConfigsResponse.json b/clients/src/main/resources/common/message/DescribeConfigsResponse.json new file mode 100644 index 0000000..f2f57ad --- /dev/null +++ b/clients/src/main/resources/common/message/DescribeConfigsResponse.json @@ -0,0 +1,71 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 32, + "type": "response", + "name": "DescribeConfigsResponse", + // Version 1 adds ConfigSource and the synonyms. + // Starting in version 2, on quota violation, brokers send out responses before throttling. + // Version 4 enables flexible versions. + "validVersions": "0-4", + "flexibleVersions": "4+", + "fields": [ + { "name": "ThrottleTimeMs", "type": "int32", "versions": "0+", + "about": "The duration in milliseconds for which the request was throttled due to a quota violation, or zero if the request did not violate any quota." }, + { "name": "Results", "type": "[]DescribeConfigsResult", "versions": "0+", + "about": "The results for each resource.", "fields": [ + { "name": "ErrorCode", "type": "int16", "versions": "0+", + "about": "The error code, or 0 if we were able to successfully describe the configurations." }, + { "name": "ErrorMessage", "type": "string", "versions": "0+", "nullableVersions": "0+", + "about": "The error message, or null if we were able to successfully describe the configurations." }, + { "name": "ResourceType", "type": "int8", "versions": "0+", + "about": "The resource type." }, + { "name": "ResourceName", "type": "string", "versions": "0+", + "about": "The resource name." }, + { "name": "Configs", "type": "[]DescribeConfigsResourceResult", "versions": "0+", + "about": "Each listed configuration.", "fields": [ + { "name": "Name", "type": "string", "versions": "0+", + "about": "The configuration name." }, + { "name": "Value", "type": "string", "versions": "0+", "nullableVersions": "0+", + "about": "The configuration value." }, + { "name": "ReadOnly", "type": "bool", "versions": "0+", + "about": "True if the configuration is read-only." 
}, + { "name": "IsDefault", "type": "bool", "versions": "0", + "about": "True if the configuration is not set." }, + // Note: the v0 default for this field that should be exposed to callers is + // context-dependent. For example, if the resource is a broker, this should default to 4. + // -1 is just a placeholder value. + { "name": "ConfigSource", "type": "int8", "versions": "1+", "default": "-1", "ignorable": true, + "about": "The configuration source." }, + { "name": "IsSensitive", "type": "bool", "versions": "0+", + "about": "True if this configuration is sensitive." }, + { "name": "Synonyms", "type": "[]DescribeConfigsSynonym", "versions": "1+", "ignorable": true, + "about": "The synonyms for this configuration key.", "fields": [ + { "name": "Name", "type": "string", "versions": "1+", + "about": "The synonym name." }, + { "name": "Value", "type": "string", "versions": "1+", "nullableVersions": "0+", + "about": "The synonym value." }, + { "name": "Source", "type": "int8", "versions": "1+", + "about": "The synonym source." } + ]}, + { "name": "ConfigType", "type": "int8", "versions": "3+", "default": "0", "ignorable": true, + "about": "The configuration data type. Type can be one of the following values - BOOLEAN, STRING, INT, SHORT, LONG, DOUBLE, LIST, CLASS, PASSWORD" }, + { "name": "Documentation", "type": "string", "versions": "3+", "nullableVersions": "0+", "ignorable": true, + "about": "The configuration documentation." } + ]} + ]} + ] +} diff --git a/clients/src/main/resources/common/message/DescribeDelegationTokenRequest.json b/clients/src/main/resources/common/message/DescribeDelegationTokenRequest.json new file mode 100644 index 0000000..79c342e --- /dev/null +++ b/clients/src/main/resources/common/message/DescribeDelegationTokenRequest.json @@ -0,0 +1,34 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 41, + "type": "request", + "listeners": ["zkBroker"], + "name": "DescribeDelegationTokenRequest", + // Version 1 is the same as version 0. + // Version 2 adds flexible version support + "validVersions": "0-2", + "flexibleVersions": "2+", + "fields": [ + { "name": "Owners", "type": "[]DescribeDelegationTokenOwner", "versions": "0+", "nullableVersions": "0+", + "about": "Each owner that we want to describe delegation tokens for, or null to describe all tokens.", "fields": [ + { "name": "PrincipalType", "type": "string", "versions": "0+", + "about": "The owner principal type." }, + { "name": "PrincipalName", "type": "string", "versions": "0+", + "about": "The owner principal name." 
} + ]} + ] +} diff --git a/clients/src/main/resources/common/message/DescribeDelegationTokenResponse.json b/clients/src/main/resources/common/message/DescribeDelegationTokenResponse.json new file mode 100644 index 0000000..09f69ce --- /dev/null +++ b/clients/src/main/resources/common/message/DescribeDelegationTokenResponse.json @@ -0,0 +1,54 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 41, + "type": "response", + "name": "DescribeDelegationTokenResponse", + // Starting in version 1, on quota violation, brokers send out responses before throttling. + // Version 2 adds flexible version support + "validVersions": "0-2", + "flexibleVersions": "2+", + "fields": [ + { "name": "ErrorCode", "type": "int16", "versions": "0+", + "about": "The error code, or 0 if there was no error." }, + { "name": "Tokens", "type": "[]DescribedDelegationToken", "versions": "0+", + "about": "The tokens.", "fields": [ + { "name": "PrincipalType", "type": "string", "versions": "0+", + "about": "The token principal type." }, + { "name": "PrincipalName", "type": "string", "versions": "0+", + "about": "The token principal name." }, + { "name": "IssueTimestamp", "type": "int64", "versions": "0+", + "about": "The token issue timestamp in milliseconds." }, + { "name": "ExpiryTimestamp", "type": "int64", "versions": "0+", + "about": "The token expiry timestamp in milliseconds." }, + { "name": "MaxTimestamp", "type": "int64", "versions": "0+", + "about": "The token maximum timestamp length in milliseconds." }, + { "name": "TokenId", "type": "string", "versions": "0+", + "about": "The token ID." }, + { "name": "Hmac", "type": "bytes", "versions": "0+", + "about": "The token HMAC." }, + { "name": "Renewers", "type": "[]DescribedDelegationTokenRenewer", "versions": "0+", + "about": "Those who are able to renew this token before it expires.", "fields": [ + { "name": "PrincipalType", "type": "string", "versions": "0+", + "about": "The renewer principal type" }, + { "name": "PrincipalName", "type": "string", "versions": "0+", + "about": "The renewer principal name" } + ]} + ]}, + { "name": "ThrottleTimeMs", "type": "int32", "versions": "0+", + "about": "The duration in milliseconds for which the request was throttled due to a quota violation, or zero if the request did not violate any quota." } + ] +} diff --git a/clients/src/main/resources/common/message/DescribeGroupsRequest.json b/clients/src/main/resources/common/message/DescribeGroupsRequest.json new file mode 100644 index 0000000..6b10b06 --- /dev/null +++ b/clients/src/main/resources/common/message/DescribeGroupsRequest.json @@ -0,0 +1,36 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. 
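The DescribeDelegationToken pair above backs `Admin#describeDelegationToken`. A minimal sketch; it assumes delegation tokens are enabled and the client is authenticated via SASL, neither of which this patch covers:

```java
import java.util.List;
import java.util.Properties;

import org.apache.kafka.clients.admin.Admin;
import org.apache.kafka.clients.admin.AdminClientConfig;
import org.apache.kafka.common.security.token.delegation.DelegationToken;

public class DescribeDelegationTokenExample {
    public static void main(String[] args) throws Exception {
        Properties props = new Properties();
        props.put(AdminClientConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9092"); // placeholder broker

        try (Admin admin = Admin.create(props)) {
            // With no owner filter, the broker returns the tokens the caller may describe.
            List<DelegationToken> tokens = admin.describeDelegationToken().delegationTokens().get();
            tokens.forEach(token ->
                System.out.println("token " + token.tokenInfo().tokenId()
                    + " owner=" + token.tokenInfo().owner()
                    + " expires=" + token.tokenInfo().expiryTimestamp()));
        }
    }
}
```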
See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 15, + "type": "request", + "listeners": ["zkBroker", "broker"], + "name": "DescribeGroupsRequest", + // Versions 1 and 2 are the same as version 0. + // + // Starting in version 3, authorized operations can be requested. + // + // Starting in version 4, the response will include group.instance.id info for members. + // + // Version 5 is the first flexible version. + "validVersions": "0-5", + "flexibleVersions": "5+", + "fields": [ + { "name": "Groups", "type": "[]string", "versions": "0+", "entityType": "groupId", + "about": "The names of the groups to describe" }, + { "name": "IncludeAuthorizedOperations", "type": "bool", "versions": "3+", + "about": "Whether to include authorized operations." } + ] +} diff --git a/clients/src/main/resources/common/message/DescribeGroupsResponse.json b/clients/src/main/resources/common/message/DescribeGroupsResponse.json new file mode 100644 index 0000000..f195843 --- /dev/null +++ b/clients/src/main/resources/common/message/DescribeGroupsResponse.json @@ -0,0 +1,70 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 15, + "type": "response", + "name": "DescribeGroupsResponse", + // Version 1 added throttle time. + // + // Starting in version 2, on quota violation, brokers send out responses before throttling. + // + // Starting in version 3, brokers can send authorized operations. + // + // Starting in version 4, the response will optionally include group.instance.id info for members. + // + // Version 5 is the first flexible version. + "validVersions": "0-5", + "flexibleVersions": "5+", + "fields": [ + { "name": "ThrottleTimeMs", "type": "int32", "versions": "1+", "ignorable": true, + "about": "The duration in milliseconds for which the request was throttled due to a quota violation, or zero if the request did not violate any quota." }, + { "name": "Groups", "type": "[]DescribedGroup", "versions": "0+", + "about": "Each described group.", "fields": [ + { "name": "ErrorCode", "type": "int16", "versions": "0+", + "about": "The describe error, or 0 if there was no error." 
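The DescribeGroups request above and the response schema that follows back `Admin#describeConsumerGroups`; `includeAuthorizedOperations` maps onto the v3+ request flag. A sketch with a made-up group id:

```java
import java.util.List;
import java.util.Map;
import java.util.Properties;

import org.apache.kafka.clients.admin.Admin;
import org.apache.kafka.clients.admin.AdminClientConfig;
import org.apache.kafka.clients.admin.ConsumerGroupDescription;
import org.apache.kafka.clients.admin.DescribeConsumerGroupsOptions;

public class DescribeGroupsExample {
    public static void main(String[] args) throws Exception {
        Properties props = new Properties();
        props.put(AdminClientConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9092"); // placeholder broker

        try (Admin admin = Admin.create(props)) {
            DescribeConsumerGroupsOptions options =
                new DescribeConsumerGroupsOptions().includeAuthorizedOperations(true);

            Map<String, ConsumerGroupDescription> groups =
                admin.describeConsumerGroups(List.of("payments-app"), options).all().get(); // placeholder id

            groups.forEach((groupId, description) -> {
                System.out.println(groupId + " is " + description.state());
                description.members().forEach(member ->
                    System.out.println("  member " + member.consumerId() + " on " + member.host()));
            });
        }
    }
}
```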
}, + { "name": "GroupId", "type": "string", "versions": "0+", "entityType": "groupId", + "about": "The group ID string." }, + { "name": "GroupState", "type": "string", "versions": "0+", + "about": "The group state string, or the empty string." }, + { "name": "ProtocolType", "type": "string", "versions": "0+", + "about": "The group protocol type, or the empty string." }, + // ProtocolData is currently only filled in if the group state is in the Stable state. + { "name": "ProtocolData", "type": "string", "versions": "0+", + "about": "The group protocol data, or the empty string." }, + // N.B. If the group is in the Dead state, the members array will always be empty. + { "name": "Members", "type": "[]DescribedGroupMember", "versions": "0+", + "about": "The group members.", "fields": [ + { "name": "MemberId", "type": "string", "versions": "0+", + "about": "The member ID assigned by the group coordinator." }, + { "name": "GroupInstanceId", "type": "string", "versions": "4+", "ignorable": true, + "nullableVersions": "4+", "default": "null", + "about": "The unique identifier of the consumer instance provided by end user." }, + { "name": "ClientId", "type": "string", "versions": "0+", + "about": "The client ID used in the member's latest join group request." }, + { "name": "ClientHost", "type": "string", "versions": "0+", + "about": "The client host." }, + // This is currently only provided if the group is in the Stable state. + { "name": "MemberMetadata", "type": "bytes", "versions": "0+", + "about": "The metadata corresponding to the current group protocol in use." }, + // This is currently only provided if the group is in the Stable state. + { "name": "MemberAssignment", "type": "bytes", "versions": "0+", + "about": "The current assignment provided by the group leader." } + ]}, + { "name": "AuthorizedOperations", "type": "int32", "versions": "3+", "default": "-2147483648", + "about": "32-bit bitfield to represent authorized operations for this group." } + ]} + ] +} diff --git a/clients/src/main/resources/common/message/DescribeLogDirsRequest.json b/clients/src/main/resources/common/message/DescribeLogDirsRequest.json new file mode 100644 index 0000000..cfb160f --- /dev/null +++ b/clients/src/main/resources/common/message/DescribeLogDirsRequest.json @@ -0,0 +1,34 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 35, + "type": "request", + "listeners": ["zkBroker", "broker"], + "name": "DescribeLogDirsRequest", + // Version 1 is the same as version 0. + "validVersions": "0-2", + // Version 2 is the first flexible version. 
+ "flexibleVersions": "2+", + "fields": [ + { "name": "Topics", "type": "[]DescribableLogDirTopic", "versions": "0+", "nullableVersions": "0+", + "about": "Each topic that we want to describe log directories for, or null for all topics.", "fields": [ + { "name": "Topic", "type": "string", "versions": "0+", "entityType": "topicName", "mapKey": true, + "about": "The topic name" }, + { "name": "Partitions", "type": "[]int32", "versions": "0+", + "about": "The partition indxes." } + ]} + ] +} diff --git a/clients/src/main/resources/common/message/DescribeLogDirsResponse.json b/clients/src/main/resources/common/message/DescribeLogDirsResponse.json new file mode 100644 index 0000000..4322f1c --- /dev/null +++ b/clients/src/main/resources/common/message/DescribeLogDirsResponse.json @@ -0,0 +1,50 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 35, + "type": "response", + "name": "DescribeLogDirsResponse", + // Starting in version 1, on quota violation, brokers send out responses before throttling. + "validVersions": "0-2", + // Version 2 is the first flexible version. + "flexibleVersions": "2+", + "fields": [ + { "name": "ThrottleTimeMs", "type": "int32", "versions": "0+", + "about": "The duration in milliseconds for which the request was throttled due to a quota violation, or zero if the request did not violate any quota." }, + { "name": "Results", "type": "[]DescribeLogDirsResult", "versions": "0+", + "about": "The log directories.", "fields": [ + { "name": "ErrorCode", "type": "int16", "versions": "0+", + "about": "The error code, or 0 if there was no error." }, + { "name": "LogDir", "type": "string", "versions": "0+", + "about": "The absolute log directory path." }, + { "name": "Topics", "type": "[]DescribeLogDirsTopic", "versions": "0+", + "about": "Each topic.", "fields": [ + { "name": "Name", "type": "string", "versions": "0+", "entityType": "topicName", + "about": "The topic name." }, + { "name": "Partitions", "type": "[]DescribeLogDirsPartition", "versions": "0+", "fields": [ + { "name": "PartitionIndex", "type": "int32", "versions": "0+", + "about": "The partition index." }, + { "name": "PartitionSize", "type": "int64", "versions": "0+", + "about": "The size of the log segments in this partition in bytes." }, + { "name": "OffsetLag", "type": "int64", "versions": "0+", + "about": "The lag of the log's LEO w.r.t. partition's HW (if it is the current log for the partition) or current replica's LEO (if it is the future log for the partition)" }, + { "name": "IsFutureKey", "type": "bool", "versions": "0+", + "about": "True if this log is created by AlterReplicaLogDirsRequest and will replace the current log of the replica in the future." 
} + ]} + ]} + ]} + ] +} diff --git a/clients/src/main/resources/common/message/DescribeProducersRequest.json b/clients/src/main/resources/common/message/DescribeProducersRequest.json new file mode 100644 index 0000000..0e3813b --- /dev/null +++ b/clients/src/main/resources/common/message/DescribeProducersRequest.json @@ -0,0 +1,31 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 61, + "type": "request", + "listeners": ["zkBroker", "broker"], + "name": "DescribeProducersRequest", + "validVersions": "0", + "flexibleVersions": "0+", + "fields": [ + { "name": "Topics", "type": "[]TopicRequest", "versions": "0+", "fields": [ + { "name": "Name", "type": "string", "versions": "0+", "entityType": "topicName", + "about": "The topic name." }, + { "name": "PartitionIndexes", "type": "[]int32", "versions": "0+", + "about": "The indexes of the partitions to list producers for." } + ]} + ] +} diff --git a/clients/src/main/resources/common/message/DescribeProducersResponse.json b/clients/src/main/resources/common/message/DescribeProducersResponse.json new file mode 100644 index 0000000..c456ee4 --- /dev/null +++ b/clients/src/main/resources/common/message/DescribeProducersResponse.json @@ -0,0 +1,48 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 61, + "type": "response", + "name": "DescribeProducersResponse", + "validVersions": "0", + "flexibleVersions": "0+", + "fields": [ + { "name": "ThrottleTimeMs", "type": "int32", "versions": "0+", + "about": "The duration in milliseconds for which the request was throttled due to a quota violation, or zero if the request did not violate any quota." 
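The DescribeLogDirs pair above backs `Admin#describeLogDirs`. A minimal sketch, assuming a client new enough (2.7+) to expose `allDescriptions()`; the broker id is a placeholder:

```java
import java.util.List;
import java.util.Map;
import java.util.Properties;

import org.apache.kafka.clients.admin.Admin;
import org.apache.kafka.clients.admin.AdminClientConfig;
import org.apache.kafka.clients.admin.LogDirDescription;

public class DescribeLogDirsExample {
    public static void main(String[] args) throws Exception {
        Properties props = new Properties();
        props.put(AdminClientConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9092"); // placeholder broker

        try (Admin admin = Admin.create(props)) {
            // One DescribeLogDirsRequest is sent to each listed broker id (here only broker 0).
            Map<Integer, Map<String, LogDirDescription>> logDirs =
                admin.describeLogDirs(List.of(0)).allDescriptions().get();

            logDirs.forEach((brokerId, dirs) ->
                dirs.forEach((path, description) ->
                    description.replicaInfos().forEach((partition, replica) ->
                        System.out.println("broker " + brokerId + " " + path + " " + partition
                            + " size=" + replica.size() + " lag=" + replica.offsetLag()))));
        }
    }
}
```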
}, + { "name": "Topics", "type": "[]TopicResponse", "versions": "0+", + "about": "Each topic in the response.", "fields": [ + { "name": "Name", "type": "string", "versions": "0+", "entityType": "topicName", + "about": "The topic name" }, + { "name": "Partitions", "type": "[]PartitionResponse", "versions": "0+", + "about": "Each partition in the response.", "fields": [ + { "name": "PartitionIndex", "type": "int32", "versions": "0+", + "about": "The partition index." }, + { "name": "ErrorCode", "type": "int16", "versions": "0+", + "about": "The partition error code, or 0 if there was no error." }, + { "name": "ErrorMessage", "type": "string", "versions": "0+", "nullableVersions": "0+", "default": "null", + "about": "The partition error message, which may be null if no additional details are available" }, + { "name": "ActiveProducers", "type": "[]ProducerState", "versions": "0+", "fields": [ + { "name": "ProducerId", "type": "int64", "versions": "0+", "entityType": "producerId" }, + { "name": "ProducerEpoch", "type": "int32", "versions": "0+" }, + { "name": "LastSequence", "type": "int32", "versions": "0+", "default": "-1" }, + { "name": "LastTimestamp", "type": "int64", "versions": "0+", "default": "-1" }, + { "name": "CoordinatorEpoch", "type": "int32", "versions": "0+" }, + { "name": "CurrentTxnStartOffset", "type": "int64", "versions": "0+", "default": "-1" } + ]} + ]} + ]} + ] +} diff --git a/clients/src/main/resources/common/message/DescribeQuorumRequest.json b/clients/src/main/resources/common/message/DescribeQuorumRequest.json new file mode 100644 index 0000000..cd4a7f1 --- /dev/null +++ b/clients/src/main/resources/common/message/DescribeQuorumRequest.json @@ -0,0 +1,36 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 55, + "type": "request", + "listeners": ["broker", "controller"], + "name": "DescribeQuorumRequest", + "validVersions": "0", + "flexibleVersions": "0+", + "fields": [ + { "name": "Topics", "type": "[]TopicData", + "versions": "0+", "fields": [ + { "name": "TopicName", "type": "string", "versions": "0+", "entityType": "topicName", + "about": "The topic name." }, + { "name": "Partitions", "type": "[]PartitionData", + "versions": "0+", "fields": [ + { "name": "PartitionIndex", "type": "int32", "versions": "0+", + "about": "The partition index." } + ] + }] + } + ] +} diff --git a/clients/src/main/resources/common/message/DescribeQuorumResponse.json b/clients/src/main/resources/common/message/DescribeQuorumResponse.json new file mode 100644 index 0000000..444fee3 --- /dev/null +++ b/clients/src/main/resources/common/message/DescribeQuorumResponse.json @@ -0,0 +1,50 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. 
See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 55, + "type": "response", + "name": "DescribeQuorumResponse", + "validVersions": "0", + "flexibleVersions": "0+", + "fields": [ + { "name": "ErrorCode", "type": "int16", "versions": "0+", + "about": "The top level error code."}, + { "name": "Topics", "type": "[]TopicData", + "versions": "0+", "fields": [ + { "name": "TopicName", "type": "string", "versions": "0+", "entityType": "topicName", + "about": "The topic name." }, + { "name": "Partitions", "type": "[]PartitionData", + "versions": "0+", "fields": [ + { "name": "PartitionIndex", "type": "int32", "versions": "0+", + "about": "The partition index." }, + { "name": "ErrorCode", "type": "int16", "versions": "0+"}, + { "name": "LeaderId", "type": "int32", "versions": "0+", "entityType": "brokerId", + "about": "The ID of the current leader or -1 if the leader is unknown."}, + { "name": "LeaderEpoch", "type": "int32", "versions": "0+", + "about": "The latest known leader epoch"}, + { "name": "HighWatermark", "type": "int64", "versions": "0+"}, + { "name": "CurrentVoters", "type": "[]ReplicaState", "versions": "0+" }, + { "name": "Observers", "type": "[]ReplicaState", "versions": "0+" } + ]} + ]}], + "commonStructs": [ + { "name": "ReplicaState", "versions": "0+", "fields": [ + { "name": "ReplicaId", "type": "int32", "versions": "0+", "entityType": "brokerId" }, + { "name": "LogEndOffset", "type": "int64", "versions": "0+", + "about": "The last known log end offset of the follower or -1 if it is unknown"} + ]} + ] +} diff --git a/clients/src/main/resources/common/message/DescribeTransactionsRequest.json b/clients/src/main/resources/common/message/DescribeTransactionsRequest.json new file mode 100644 index 0000000..442f11f --- /dev/null +++ b/clients/src/main/resources/common/message/DescribeTransactionsRequest.json @@ -0,0 +1,27 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +{ + "apiKey": 65, + "type": "request", + "listeners": ["zkBroker", "broker"], + "name": "DescribeTransactionsRequest", + "validVersions": "0", + "flexibleVersions": "0+", + "fields": [ + { "name": "TransactionalIds", "entityType": "transactionalId", "type": "[]string", "versions": "0+", + "about": "Array of transactionalIds to include in describe results. If empty, then no results will be returned." } + ] +} diff --git a/clients/src/main/resources/common/message/DescribeTransactionsResponse.json b/clients/src/main/resources/common/message/DescribeTransactionsResponse.json new file mode 100644 index 0000000..15f52a4 --- /dev/null +++ b/clients/src/main/resources/common/message/DescribeTransactionsResponse.json @@ -0,0 +1,42 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 65, + "type": "response", + "name": "DescribeTransactionsResponse", + "validVersions": "0", + "flexibleVersions": "0+", + "fields": [ + { "name": "ThrottleTimeMs", "type": "int32", "versions": "0+", + "about": "The duration in milliseconds for which the request was throttled due to a quota violation, or zero if the request did not violate any quota." }, + { "name": "TransactionStates", "type": "[]TransactionState", "versions": "0+", "fields": [ + { "name": "ErrorCode", "type": "int16", "versions": "0+" }, + { "name": "TransactionalId", "type": "string", "versions": "0+", "entityType": "transactionalId" }, + { "name": "TransactionState", "type": "string", "versions": "0+" }, + { "name": "TransactionTimeoutMs", "type": "int32", "versions": "0+" }, + { "name": "TransactionStartTimeMs", "type": "int64", "versions": "0+" }, + { "name": "ProducerId", "type": "int64", "versions": "0+", "entityType": "producerId" }, + { "name": "ProducerEpoch", "type": "int16", "versions": "0+" }, + { "name": "Topics", "type": "[]TopicData", "versions": "0+", + "about": "The set of partitions included in the current transaction (if active). When a transaction is preparing to commit or abort, this will include only partitions which do not have markers.", + "fields": [ + { "name": "Topic", "type": "string", "versions": "0+", "entityType": "topicName", "mapKey": true }, + { "name": "Partitions", "type": "[]int32", "versions": "0+" } + ] + } + ]} + ] +} diff --git a/clients/src/main/resources/common/message/DescribeUserScramCredentialsRequest.json b/clients/src/main/resources/common/message/DescribeUserScramCredentialsRequest.json new file mode 100644 index 0000000..cef8929 --- /dev/null +++ b/clients/src/main/resources/common/message/DescribeUserScramCredentialsRequest.json @@ -0,0 +1,30 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. 
See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 50, + "type": "request", + "listeners": ["zkBroker"], + "name": "DescribeUserScramCredentialsRequest", + "validVersions": "0", + "flexibleVersions": "0+", + "fields": [ + { "name": "Users", "type": "[]UserName", "versions": "0+", "nullableVersions": "0+", + "about": "The users to describe, or null/empty to describe all users.", "fields": [ + { "name": "Name", "type": "string", "versions": "0+", + "about": "The user name." } + ]} + ] +} diff --git a/clients/src/main/resources/common/message/DescribeUserScramCredentialsResponse.json b/clients/src/main/resources/common/message/DescribeUserScramCredentialsResponse.json new file mode 100644 index 0000000..9e8b035 --- /dev/null +++ b/clients/src/main/resources/common/message/DescribeUserScramCredentialsResponse.json @@ -0,0 +1,45 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 50, + "type": "response", + "name": "DescribeUserScramCredentialsResponse", + "validVersions": "0", + "flexibleVersions": "0+", + "fields": [ + { "name": "ThrottleTimeMs", "type": "int32", "versions": "0+", + "about": "The duration in milliseconds for which the request was throttled due to a quota violation, or zero if the request did not violate any quota." }, + { "name": "ErrorCode", "type": "int16", "versions": "0+", + "about": "The message-level error code, 0 except for user authorization or infrastructure issues." }, + { "name": "ErrorMessage", "type": "string", "versions": "0+", "nullableVersions": "0+", + "about": "The message-level error message, if any." }, + { "name": "Results", "type": "[]DescribeUserScramCredentialsResult", "versions": "0+", + "about": "The results for descriptions, one per user.", "fields": [ + { "name": "User", "type": "string", "versions": "0+", + "about": "The user name." }, + { "name": "ErrorCode", "type": "int16", "versions": "0+", + "about": "The user-level error code." }, + { "name": "ErrorMessage", "type": "string", "versions": "0+", "nullableVersions": "0+", + "about": "The user-level error message, if any." 
}, + { "name": "CredentialInfos", "type": "[]CredentialInfo", "versions": "0+", + "about": "The mechanism and related information associated with the user's SCRAM credentials.", "fields": [ + { "name": "Mechanism", "type": "int8", "versions": "0+", + "about": "The SCRAM mechanism." }, + { "name": "Iterations", "type": "int32", "versions": "0+", + "about": "The number of iterations used in the SCRAM credential." }]} + ]} + ] +} diff --git a/clients/src/main/resources/common/message/ElectLeadersRequest.json b/clients/src/main/resources/common/message/ElectLeadersRequest.json new file mode 100644 index 0000000..dd9fa21 --- /dev/null +++ b/clients/src/main/resources/common/message/ElectLeadersRequest.json @@ -0,0 +1,41 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 43, + "type": "request", + "listeners": ["zkBroker", "broker", "controller"], + "name": "ElectLeadersRequest", + // Version 1 implements multiple leader election types, as described by KIP-460. + // + // Version 2 is the first flexible version. + "validVersions": "0-2", + "flexibleVersions": "2+", + "fields": [ + { "name": "ElectionType", "type": "int8", "versions": "1+", + "about": "Type of elections to conduct for the partition. A value of '0' elects the preferred replica. A value of '1' elects the first live replica if there are no in-sync replicas." }, + { "name": "TopicPartitions", "type": "[]TopicPartitions", "versions": "0+", "nullableVersions": "0+", + "about": "The topic partitions to elect leaders.", + "fields": [ + { "name": "Topic", "type": "string", "versions": "0+", "entityType": "topicName", "mapKey": true, + "about": "The name of a topic." }, + { "name": "Partitions", "type": "[]int32", "versions": "0+", + "about": "The partitions of this topic whose leader should be elected." } + ] + }, + { "name": "TimeoutMs", "type": "int32", "versions": "0+", "default": "60000", + "about": "The time in ms to wait for the election to complete." } + ] +} diff --git a/clients/src/main/resources/common/message/ElectLeadersResponse.json b/clients/src/main/resources/common/message/ElectLeadersResponse.json new file mode 100644 index 0000000..15468c7 --- /dev/null +++ b/clients/src/main/resources/common/message/ElectLeadersResponse.json @@ -0,0 +1,45 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License.
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 43, + "type": "response", + "name": "ElectLeadersResponse", + // Version 1 adds a top-level error code. + // + // Version 2 is the first flexible version. + "validVersions": "0-2", + "flexibleVersions": "2+", + "fields": [ + { "name": "ThrottleTimeMs", "type": "int32", "versions": "0+", + "about": "The duration in milliseconds for which the request was throttled due to a quota violation, or zero if the request did not violate any quota." }, + { "name": "ErrorCode", "type": "int16", "versions": "1+", "ignorable": false, + "about": "The top level response error code." }, + { "name": "ReplicaElectionResults", "type": "[]ReplicaElectionResult", "versions": "0+", + "about": "The election results, or an empty array if the requester did not have permission and the request asks for all partitions.", "fields": [ + { "name": "Topic", "type": "string", "versions": "0+", "entityType": "topicName", + "about": "The topic name" }, + { "name": "PartitionResult", "type": "[]PartitionResult", "versions": "0+", + "about": "The results for each partition", "fields": [ + { "name": "PartitionId", "type": "int32", "versions": "0+", + "about": "The partition id" }, + { "name": "ErrorCode", "type": "int16", "versions": "0+", + "about": "The result error, or zero if there was no error."}, + { "name": "ErrorMessage", "type": "string", "versions": "0+", "nullableVersions": "0+", + "about": "The result message, or null if there was no error."} + ]} + ]} + ] +} diff --git a/clients/src/main/resources/common/message/EndQuorumEpochRequest.json b/clients/src/main/resources/common/message/EndQuorumEpochRequest.json new file mode 100644 index 0000000..a6e4076 --- /dev/null +++ b/clients/src/main/resources/common/message/EndQuorumEpochRequest.json @@ -0,0 +1,43 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 54, + "type": "request", + "listeners": ["controller"], + "name": "EndQuorumEpochRequest", + "validVersions": "0", + "flexibleVersions": "none", + "fields": [ + { "name": "ClusterId", "type": "string", "versions": "0+", + "nullableVersions": "0+", "default": "null"}, + { "name": "Topics", "type": "[]TopicData", + "versions": "0+", "fields": [ + { "name": "TopicName", "type": "string", "versions": "0+", "entityType": "topicName", + "about": "The topic name." 
}, + { "name": "Partitions", "type": "[]PartitionData", + "versions": "0+", "fields": [ + { "name": "PartitionIndex", "type": "int32", "versions": "0+", + "about": "The partition index." }, + { "name": "LeaderId", "type": "int32", "versions": "0+", "entityType": "brokerId", + "about": "The current leader ID that is resigning"}, + { "name": "LeaderEpoch", "type": "int32", "versions": "0+", + "about": "The current epoch"}, + { "name": "PreferredSuccessors", "type": "[]int32", "versions": "0+", + "about": "A sorted list of preferred successors to start the election"} + ]} + ]} + ] +} diff --git a/clients/src/main/resources/common/message/EndQuorumEpochResponse.json b/clients/src/main/resources/common/message/EndQuorumEpochResponse.json new file mode 100644 index 0000000..cd23247 --- /dev/null +++ b/clients/src/main/resources/common/message/EndQuorumEpochResponse.json @@ -0,0 +1,41 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 54, + "type": "response", + "name": "EndQuorumEpochResponse", + "validVersions": "0", + "flexibleVersions": "none", + "fields": [ + { "name": "ErrorCode", "type": "int16", "versions": "0+", + "about": "The top level error code."}, + { "name": "Topics", "type": "[]TopicData", + "versions": "0+", "fields": [ + { "name": "TopicName", "type": "string", "versions": "0+", "entityType": "topicName", + "about": "The topic name." }, + { "name": "Partitions", "type": "[]PartitionData", + "versions": "0+", "fields": [ + { "name": "PartitionIndex", "type": "int32", "versions": "0+", + "about": "The partition index." }, + { "name": "ErrorCode", "type": "int16", "versions": "0+"}, + { "name": "LeaderId", "type": "int32", "versions": "0+", "entityType": "brokerId", + "about": "The ID of the current leader or -1 if the leader is unknown."}, + { "name": "LeaderEpoch", "type": "int32", "versions": "0+", + "about": "The latest known leader epoch"} + ]} + ]} + ] +} diff --git a/clients/src/main/resources/common/message/EndTxnRequest.json b/clients/src/main/resources/common/message/EndTxnRequest.json new file mode 100644 index 0000000..f16ef76 --- /dev/null +++ b/clients/src/main/resources/common/message/EndTxnRequest.json @@ -0,0 +1,38 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 26, + "type": "request", + "listeners": ["zkBroker", "broker"], + "name": "EndTxnRequest", + // Version 1 is the same as version 0. + // + // Version 2 adds the support for new error code PRODUCER_FENCED. + // + // Version 3 enables flexible versions. + "validVersions": "0-3", + "flexibleVersions": "3+", + "fields": [ + { "name": "TransactionalId", "type": "string", "versions": "0+", "entityType": "transactionalId", + "about": "The ID of the transaction to end." }, + { "name": "ProducerId", "type": "int64", "versions": "0+", "entityType": "producerId", + "about": "The producer ID." }, + { "name": "ProducerEpoch", "type": "int16", "versions": "0+", + "about": "The current epoch associated with the producer." }, + { "name": "Committed", "type": "bool", "versions": "0+", + "about": "True if the transaction was committed, false if it was aborted." } + ] +} diff --git a/clients/src/main/resources/common/message/EndTxnResponse.json b/clients/src/main/resources/common/message/EndTxnResponse.json new file mode 100644 index 0000000..3071953 --- /dev/null +++ b/clients/src/main/resources/common/message/EndTxnResponse.json @@ -0,0 +1,33 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 26, + "type": "response", + "name": "EndTxnResponse", + // Starting in version 1, on quota violation, brokers send out responses before throttling. + // + // Version 2 adds the support for new error code PRODUCER_FENCED. + // + // Version 3 enables flexible versions. + "validVersions": "0-3", + "flexibleVersions": "3+", + "fields": [ + { "name": "ThrottleTimeMs", "type": "int32", "versions": "0+", + "about": "The duration in milliseconds for which the request was throttled due to a quota violation, or zero if the request did not violate any quota." }, + { "name": "ErrorCode", "type": "int16", "versions": "0+", + "about": "The error code, or 0 if there was no error." } + ] +} diff --git a/clients/src/main/resources/common/message/EnvelopeRequest.json b/clients/src/main/resources/common/message/EnvelopeRequest.json new file mode 100644 index 0000000..1f6ff62 --- /dev/null +++ b/clients/src/main/resources/common/message/EnvelopeRequest.json @@ -0,0 +1,32 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. 
See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 58, + "type": "request", + "listeners": ["controller"], + "name": "EnvelopeRequest", + // Request struct for forwarding. + "validVersions": "0", + "flexibleVersions": "0+", + "fields": [ + { "name": "RequestData", "type": "bytes", "versions": "0+", "zeroCopy": true, + "about": "The embedded request header and data."}, + { "name": "RequestPrincipal", "type": "bytes", "versions": "0+", "nullableVersions": "0+", + "about": "Value of the initial client principal when the request is redirected by a broker." }, + { "name": "ClientHostAddress", "type": "bytes", "versions": "0+", + "about": "The original client's address in bytes." } + ] +} diff --git a/clients/src/main/resources/common/message/EnvelopeResponse.json b/clients/src/main/resources/common/message/EnvelopeResponse.json new file mode 100644 index 0000000..9008052 --- /dev/null +++ b/clients/src/main/resources/common/message/EnvelopeResponse.json @@ -0,0 +1,30 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 58, + "type": "response", + "name": "EnvelopeResponse", + // Response struct for forwarding. + "validVersions": "0", + "flexibleVersions": "0+", + "fields": [ + { "name": "ResponseData", "type": "bytes", "versions": "0+", "nullableVersions": "0+", + "zeroCopy": true, "default": "null", + "about": "The embedded response header and data."}, + { "name": "ErrorCode", "type": "int16", "versions": "0+", + "about": "The error code, or 0 if there was no error." } + ] +} diff --git a/clients/src/main/resources/common/message/ExpireDelegationTokenRequest.json b/clients/src/main/resources/common/message/ExpireDelegationTokenRequest.json new file mode 100644 index 0000000..736f1df --- /dev/null +++ b/clients/src/main/resources/common/message/ExpireDelegationTokenRequest.json @@ -0,0 +1,31 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. 
+// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 40, + "type": "request", + "listeners": ["zkBroker"], + "name": "ExpireDelegationTokenRequest", + // Version 1 is the same as version 0. + // Version 2 adds flexible version support + "validVersions": "0-2", + "flexibleVersions": "2+", + "fields": [ + { "name": "Hmac", "type": "bytes", "versions": "0+", + "about": "The HMAC of the delegation token to be expired." }, + { "name": "ExpiryTimePeriodMs", "type": "int64", "versions": "0+", + "about": "The expiry time period in milliseconds." } + ] +} diff --git a/clients/src/main/resources/common/message/ExpireDelegationTokenResponse.json b/clients/src/main/resources/common/message/ExpireDelegationTokenResponse.json new file mode 100644 index 0000000..f2d4bf4 --- /dev/null +++ b/clients/src/main/resources/common/message/ExpireDelegationTokenResponse.json @@ -0,0 +1,32 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 40, + "type": "response", + "name": "ExpireDelegationTokenResponse", + // Starting in version 1, on quota violation, brokers send out responses before throttling. + // Version 2 adds flexible version support + "validVersions": "0-2", + "flexibleVersions": "2+", + "fields": [ + { "name": "ErrorCode", "type": "int16", "versions": "0+", + "about": "The error code, or 0 if there was no error." }, + { "name": "ExpiryTimestampMs", "type": "int64", "versions": "0+", + "about": "The timestamp in milliseconds at which this token expires." }, + { "name": "ThrottleTimeMs", "type": "int32", "versions": "0+", + "about": "The duration in milliseconds for which the request was throttled due to a quota violation, or zero if the request did not violate any quota." } + ] +} diff --git a/clients/src/main/resources/common/message/FetchRequest.json b/clients/src/main/resources/common/message/FetchRequest.json new file mode 100644 index 0000000..df63957 --- /dev/null +++ b/clients/src/main/resources/common/message/FetchRequest.json @@ -0,0 +1,103 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. 
+// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 1, + "type": "request", + "listeners": ["zkBroker", "broker", "controller"], + "name": "FetchRequest", + // + // Version 1 is the same as version 0. + // + // Starting in Version 2, the requestor must be able to handle Kafka Log + // Message format version 1. + // + // Version 3 adds MaxBytes. Starting in version 3, the partition ordering in + // the request is now relevant. Partitions will be processed in the order + // they appear in the request. + // + // Version 4 adds IsolationLevel. Starting in version 4, the requestor must be + // able to handle Kafka log message format version 2. + // + // Version 5 adds LogStartOffset to indicate the earliest available offset of + // partition data that can be consumed. + // + // Version 6 is the same as version 5. + // + // Version 7 adds incremental fetch request support. + // + // Version 8 is the same as version 7. + // + // Version 9 adds CurrentLeaderEpoch, as described in KIP-320. + // + // Version 10 indicates that we can use the ZStd compression algorithm, as + // described in KIP-110. + // Version 12 adds flexible versions support as well as epoch validation through + // the `LastFetchedEpoch` field + // + // Version 13 replaces topic names with topic IDs (KIP-516). May return UNKNOWN_TOPIC_ID error code. + "validVersions": "0-13", + "flexibleVersions": "12+", + "fields": [ + { "name": "ClusterId", "type": "string", "versions": "12+", "nullableVersions": "12+", "default": "null", + "taggedVersions": "12+", "tag": 0, "ignorable": true, + "about": "The clusterId if known. This is used to validate metadata fetches prior to broker registration." }, + { "name": "ReplicaId", "type": "int32", "versions": "0+", "entityType": "brokerId", + "about": "The broker ID of the follower, or -1 if this request is from a consumer." }, + { "name": "MaxWaitMs", "type": "int32", "versions": "0+", + "about": "The maximum time in milliseconds to wait for the response." }, + { "name": "MinBytes", "type": "int32", "versions": "0+", + "about": "The minimum bytes to accumulate in the response." }, + { "name": "MaxBytes", "type": "int32", "versions": "3+", "default": "0x7fffffff", "ignorable": true, + "about": "The maximum bytes to fetch. See KIP-74 for cases where this limit may not be honored." }, + { "name": "IsolationLevel", "type": "int8", "versions": "4+", "default": "0", "ignorable": true, + "about": "This setting controls the visibility of transactional records. Using READ_UNCOMMITTED (isolation_level = 0) makes all records visible. With READ_COMMITTED (isolation_level = 1), non-transactional and COMMITTED transactional records are visible. 
To be more concrete, READ_COMMITTED returns all data from offsets smaller than the current LSO (last stable offset), and enables the inclusion of the list of aborted transactions in the result, which allows consumers to discard ABORTED transactional records" }, + { "name": "SessionId", "type": "int32", "versions": "7+", "default": "0", "ignorable": true, + "about": "The fetch session ID." }, + { "name": "SessionEpoch", "type": "int32", "versions": "7+", "default": "-1", "ignorable": true, + "about": "The fetch session epoch, which is used for ordering requests in a session." }, + { "name": "Topics", "type": "[]FetchTopic", "versions": "0+", + "about": "The topics to fetch.", "fields": [ + { "name": "Topic", "type": "string", "versions": "0-12", "entityType": "topicName", "ignorable": true, + "about": "The name of the topic to fetch." }, + { "name": "TopicId", "type": "uuid", "versions": "13+", "ignorable": true, "about": "The unique topic ID"}, + { "name": "Partitions", "type": "[]FetchPartition", "versions": "0+", + "about": "The partitions to fetch.", "fields": [ + { "name": "Partition", "type": "int32", "versions": "0+", + "about": "The partition index." }, + { "name": "CurrentLeaderEpoch", "type": "int32", "versions": "9+", "default": "-1", "ignorable": true, + "about": "The current leader epoch of the partition." }, + { "name": "FetchOffset", "type": "int64", "versions": "0+", + "about": "The message offset." }, + { "name": "LastFetchedEpoch", "type": "int32", "versions": "12+", "default": "-1", "ignorable": false, + "about": "The epoch of the last fetched record or -1 if there is none"}, + { "name": "LogStartOffset", "type": "int64", "versions": "5+", "default": "-1", "ignorable": true, + "about": "The earliest available offset of the follower replica. The field is only used when the request is sent by the follower."}, + { "name": "PartitionMaxBytes", "type": "int32", "versions": "0+", + "about": "The maximum bytes to fetch from this partition. See KIP-74 for cases where this limit may not be honored." } + ]} + ]}, + { "name": "ForgottenTopicsData", "type": "[]ForgottenTopic", "versions": "7+", "ignorable": false, + "about": "In an incremental fetch request, the partitions to remove.", "fields": [ + { "name": "Topic", "type": "string", "versions": "7-12", "entityType": "topicName", "ignorable": true, + "about": "The topic name." }, + { "name": "TopicId", "type": "uuid", "versions": "13+", "ignorable": true, "about": "The unique topic ID"}, + { "name": "Partitions", "type": "[]int32", "versions": "7+", + "about": "The partition indexes to forget." } + ]}, + { "name": "RackId", "type": "string", "versions": "11+", "default": "", "ignorable": true, + "about": "Rack ID of the consumer making this request"} + ] +} diff --git a/clients/src/main/resources/common/message/FetchResponse.json b/clients/src/main/resources/common/message/FetchResponse.json new file mode 100644 index 0000000..9ae28b7 --- /dev/null +++ b/clients/src/main/resources/common/message/FetchResponse.json @@ -0,0 +1,103 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License.
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 1, + "type": "response", + "name": "FetchResponse", + // + // Version 1 adds throttle time. + // + // Version 2 and 3 are the same as version 1. + // + // Version 4 adds features for transactional consumption. + // + // Version 5 adds LogStartOffset to indicate the earliest available offset of + // partition data that can be consumed. + // + // Starting in version 6, we may return KAFKA_STORAGE_ERROR as an error code. + // + // Version 7 adds incremental fetch request support. + // + // Starting in version 8, on quota violation, brokers send out responses before throttling. + // + // Version 9 is the same as version 8. + // + // Version 10 indicates that the response data can use the ZStd compression + // algorithm, as described in KIP-110. + // Version 12 adds support for flexible versions, epoch detection through the `TruncationOffset` field, + // and leader discovery through the `CurrentLeader` field + // + // Version 13 replaces the topic name field with topic ID (KIP-516). + "validVersions": "0-13", + "flexibleVersions": "12+", + "fields": [ + { "name": "ThrottleTimeMs", "type": "int32", "versions": "1+", "ignorable": true, + "about": "The duration in milliseconds for which the request was throttled due to a quota violation, or zero if the request did not violate any quota." }, + { "name": "ErrorCode", "type": "int16", "versions": "7+", "ignorable": true, + "about": "The top level response error code." }, + { "name": "SessionId", "type": "int32", "versions": "7+", "default": "0", "ignorable": false, + "about": "The fetch session ID, or 0 if this is not part of a fetch session." }, + { "name": "Responses", "type": "[]FetchableTopicResponse", "versions": "0+", + "about": "The response topics.", "fields": [ + { "name": "Topic", "type": "string", "versions": "0-12", "ignorable": true, "entityType": "topicName", + "about": "The topic name." }, + { "name": "TopicId", "type": "uuid", "versions": "13+", "ignorable": true, "about": "The unique topic ID"}, + { "name": "Partitions", "type": "[]PartitionData", "versions": "0+", + "about": "The topic partitions.", "fields": [ + { "name": "PartitionIndex", "type": "int32", "versions": "0+", + "about": "The partition index." }, + { "name": "ErrorCode", "type": "int16", "versions": "0+", + "about": "The error code, or 0 if there was no fetch error." }, + { "name": "HighWatermark", "type": "int64", "versions": "0+", + "about": "The current high water mark." }, + { "name": "LastStableOffset", "type": "int64", "versions": "4+", "default": "-1", "ignorable": true, + "about": "The last stable offset (or LSO) of the partition. This is the last offset such that the state of all transactional records prior to this offset have been decided (ABORTED or COMMITTED)" }, + { "name": "LogStartOffset", "type": "int64", "versions": "5+", "default": "-1", "ignorable": true, + "about": "The current log start offset." 
}, + { "name": "DivergingEpoch", "type": "EpochEndOffset", "versions": "12+", "taggedVersions": "12+", "tag": 0, + "about": "In case divergence is detected based on the `LastFetchedEpoch` and `FetchOffset` in the request, this field indicates the largest epoch and its end offset such that subsequent records are known to diverge", + "fields": [ + { "name": "Epoch", "type": "int32", "versions": "12+", "default": "-1" }, + { "name": "EndOffset", "type": "int64", "versions": "12+", "default": "-1" } + ]}, + { "name": "CurrentLeader", "type": "LeaderIdAndEpoch", + "versions": "12+", "taggedVersions": "12+", "tag": 1, "fields": [ + { "name": "LeaderId", "type": "int32", "versions": "12+", "default": "-1", "entityType": "brokerId", + "about": "The ID of the current leader or -1 if the leader is unknown."}, + { "name": "LeaderEpoch", "type": "int32", "versions": "12+", "default": "-1", + "about": "The latest known leader epoch"} + ]}, + { "name": "SnapshotId", "type": "SnapshotId", + "versions": "12+", "taggedVersions": "12+", "tag": 2, + "about": "In the case of fetching an offset less than the LogStartOffset, this is the end offset and epoch that should be used in the FetchSnapshot request.", + "fields": [ + { "name": "EndOffset", "type": "int64", "versions": "0+", "default": "-1" }, + { "name": "Epoch", "type": "int32", "versions": "0+", "default": "-1" } + ]}, + { "name": "AbortedTransactions", "type": "[]AbortedTransaction", "versions": "4+", "nullableVersions": "4+", "ignorable": true, + "about": "The aborted transactions.", "fields": [ + { "name": "ProducerId", "type": "int64", "versions": "4+", "entityType": "producerId", + "about": "The producer id associated with the aborted transaction." }, + { "name": "FirstOffset", "type": "int64", "versions": "4+", + "about": "The first offset in the aborted transaction." } + ]}, + { "name": "PreferredReadReplica", "type": "int32", "versions": "11+", "default": "-1", "ignorable": false, "entityType": "brokerId", + "about": "The preferred read replica for the consumer to use on its next fetch request"}, + { "name": "Records", "type": "records", "versions": "0+", "nullableVersions": "0+", "about": "The record data."} + ]} + ]} + ] +} diff --git a/clients/src/main/resources/common/message/FetchSnapshotRequest.json b/clients/src/main/resources/common/message/FetchSnapshotRequest.json new file mode 100644 index 0000000..358ef2e --- /dev/null +++ b/clients/src/main/resources/common/message/FetchSnapshotRequest.json @@ -0,0 +1,51 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
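
The IsolationLevel field in the Fetch schemas above maps directly to the consumer's isolation.level configuration: under read_committed the broker only returns records below the last stable offset and attaches the aborted-transaction list so the client can filter them out. A minimal consumer sketch follows; the broker address, group id, and topic name are placeholders.

```java
import java.time.Duration;
import java.util.Collections;
import java.util.Properties;
import org.apache.kafka.clients.consumer.ConsumerConfig;
import org.apache.kafka.clients.consumer.ConsumerRecords;
import org.apache.kafka.clients.consumer.KafkaConsumer;
import org.apache.kafka.common.serialization.StringDeserializer;

public class ReadCommittedConsumerExample {
    public static void main(String[] args) {
        Properties props = new Properties();
        props.put(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9092");   // placeholder broker
        props.put(ConsumerConfig.GROUP_ID_CONFIG, "example-group");             // placeholder group
        props.put(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, StringDeserializer.class.getName());
        props.put(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, StringDeserializer.class.getName());
        // Sent as IsolationLevel = 1 in FetchRequest; aborted transactional records are filtered out.
        props.put(ConsumerConfig.ISOLATION_LEVEL_CONFIG, "read_committed");
        try (KafkaConsumer<String, String> consumer = new KafkaConsumer<>(props)) {
            consumer.subscribe(Collections.singletonList("example-topic"));     // placeholder topic
            ConsumerRecords<String, String> records = consumer.poll(Duration.ofMillis(500));
            records.forEach(r -> System.out.printf("%d: %s%n", r.offset(), r.value()));
        }
    }
}
```

Application code never builds FetchRequest objects itself; the fetch session fields (SessionId, SessionEpoch) and the rack id are filled in by the client internals.
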
+ +{ + "apiKey": 59, + "type": "request", + "listeners": ["controller"], + "name": "FetchSnapshotRequest", + "validVersions": "0", + "flexibleVersions": "0+", + "fields": [ + { "name": "ClusterId", "type": "string", "versions": "0+", "nullableVersions": "0+", "default": "null", "taggedVersions": "0+", "tag": 0, + "about": "The clusterId if known, this is used to validate metadata fetches prior to broker registration" }, + { "name": "ReplicaId", "type": "int32", "versions": "0+", "default": "-1", "entityType": "brokerId", + "about": "The broker ID of the follower" }, + { "name": "MaxBytes", "type": "int32", "versions": "0+", "default": "0x7fffffff", + "about": "The maximum bytes to fetch from all of the snapshots" }, + { "name": "Topics", "type": "[]TopicSnapshot", "versions": "0+", + "about": "The topics to fetch", "fields": [ + { "name": "Name", "type": "string", "versions": "0+", "entityType": "topicName", + "about": "The name of the topic to fetch" }, + { "name": "Partitions", "type": "[]PartitionSnapshot", "versions": "0+", + "about": "The partitions to fetch", "fields": [ + { "name": "Partition", "type": "int32", "versions": "0+", + "about": "The partition index" }, + { "name": "CurrentLeaderEpoch", "type": "int32", "versions": "0+", + "about": "The current leader epoch of the partition, -1 for unknown leader epoch" }, + { "name": "SnapshotId", "type": "SnapshotId", "versions": "0+", + "about": "The snapshot endOffset and epoch to fetch", + "fields": [ + { "name": "EndOffset", "type": "int64", "versions": "0+" }, + { "name": "Epoch", "type": "int32", "versions": "0+" } + ]}, + { "name": "Position", "type": "int64", "versions": "0+", + "about": "The byte position within the snapshot to start fetching from" } + ]} + ]} + ] +} diff --git a/clients/src/main/resources/common/message/FetchSnapshotResponse.json b/clients/src/main/resources/common/message/FetchSnapshotResponse.json new file mode 100644 index 0000000..887a5e4 --- /dev/null +++ b/clients/src/main/resources/common/message/FetchSnapshotResponse.json @@ -0,0 +1,59 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 59, + "type": "response", + "name": "FetchSnapshotResponse", + "validVersions": "0", + "flexibleVersions": "0+", + "fields": [ + { "name": "ThrottleTimeMs", "type": "int32", "versions": "0+", "ignorable": true, + "about": "The duration in milliseconds for which the request was throttled due to a quota violation, or zero if the request did not violate any quota." }, + { "name": "ErrorCode", "type": "int16", "versions": "0+", "ignorable": false, + "about": "The top level response error code." 
}, + { "name": "Topics", "type": "[]TopicSnapshot", "versions": "0+", + "about": "The topics to fetch.", "fields": [ + { "name": "Name", "type": "string", "versions": "0+", "entityType": "topicName", + "about": "The name of the topic to fetch." }, + { "name": "Partitions", "type": "[]PartitionSnapshot", "versions": "0+", + "about": "The partitions to fetch.", "fields": [ + { "name": "Index", "type": "int32", "versions": "0+", + "about": "The partition index." }, + { "name": "ErrorCode", "type": "int16", "versions": "0+", + "about": "The error code, or 0 if there was no fetch error." }, + { "name": "SnapshotId", "type": "SnapshotId", "versions": "0+", + "about": "The snapshot endOffset and epoch fetched", + "fields": [ + { "name": "EndOffset", "type": "int64", "versions": "0+" }, + { "name": "Epoch", "type": "int32", "versions": "0+" } + ]}, + { "name": "CurrentLeader", "type": "LeaderIdAndEpoch", + "versions": "0+", "taggedVersions": "0+", "tag": 0, "fields": [ + { "name": "LeaderId", "type": "int32", "versions": "0+", "entityType": "brokerId", + "about": "The ID of the current leader or -1 if the leader is unknown."}, + { "name": "LeaderEpoch", "type": "int32", "versions": "0+", + "about": "The latest known leader epoch"} + ]}, + { "name": "Size", "type": "int64", "versions": "0+", + "about": "The total size of the snapshot." }, + { "name": "Position", "type": "int64", "versions": "0+", + "about": "The starting byte position within the snapshot included in the Bytes field." }, + { "name": "UnalignedRecords", "type": "records", "versions": "0+", + "about": "Snapshot data in records format which may not be aligned on an offset boundary" } + ]} + ]} + ] +} diff --git a/clients/src/main/resources/common/message/FindCoordinatorRequest.json b/clients/src/main/resources/common/message/FindCoordinatorRequest.json new file mode 100644 index 0000000..a475cc2 --- /dev/null +++ b/clients/src/main/resources/common/message/FindCoordinatorRequest.json @@ -0,0 +1,38 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 10, + "type": "request", + "listeners": ["zkBroker", "broker"], + "name": "FindCoordinatorRequest", + // Version 1 adds KeyType. + // + // Version 2 is the same as version 1. + // + // Version 3 is the first flexible version. + // + // Version 4 adds support for batching via CoordinatorKeys (KIP-699) + "validVersions": "0-4", + "flexibleVersions": "3+", + "fields": [ + { "name": "Key", "type": "string", "versions": "0-3", + "about": "The coordinator key." }, + { "name": "KeyType", "type": "int8", "versions": "1+", "default": "0", "ignorable": false, + "about": "The coordinator key type. (Group, transaction, etc.)" }, + { "name": "CoordinatorKeys", "type": "[]string", "versions": "4+", + "about": "The coordinator keys." 
} + ] +} diff --git a/clients/src/main/resources/common/message/FindCoordinatorResponse.json b/clients/src/main/resources/common/message/FindCoordinatorResponse.json new file mode 100644 index 0000000..9309c01 --- /dev/null +++ b/clients/src/main/resources/common/message/FindCoordinatorResponse.json @@ -0,0 +1,54 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 10, + "type": "response", + "name": "FindCoordinatorResponse", + // Version 1 adds throttle time and error messages. + // + // Starting in version 2, on quota violation, brokers send out responses before throttling. + // + // Version 3 is the first flexible version. + // + // Version 4 adds support for batching via Coordinators (KIP-699) + "validVersions": "0-4", + "flexibleVersions": "3+", + "fields": [ + { "name": "ThrottleTimeMs", "type": "int32", "versions": "1+", "ignorable": true, + "about": "The duration in milliseconds for which the request was throttled due to a quota violation, or zero if the request did not violate any quota." }, + { "name": "ErrorCode", "type": "int16", "versions": "0-3", + "about": "The error code, or 0 if there was no error." }, + { "name": "ErrorMessage", "type": "string", "versions": "1-3", "nullableVersions": "1-3", "ignorable": true, + "about": "The error message, or null if there was no error." }, + { "name": "NodeId", "type": "int32", "versions": "0-3", "entityType": "brokerId", + "about": "The node id." }, + { "name": "Host", "type": "string", "versions": "0-3", + "about": "The host name." }, + { "name": "Port", "type": "int32", "versions": "0-3", + "about": "The port." }, + { "name": "Coordinators", "type": "[]Coordinator", "versions": "4+", "about": "Each coordinator result in the response", "fields": [ + { "name": "Key", "type": "string", "versions": "4+", "about": "The coordinator key." }, + { "name": "NodeId", "type": "int32", "versions": "4+", "entityType": "brokerId", + "about": "The node id." }, + { "name": "Host", "type": "string", "versions": "4+", "about": "The host name." }, + { "name": "Port", "type": "int32", "versions": "4+", "about": "The port." }, + { "name": "ErrorCode", "type": "int16", "versions": "4+", + "about": "The error code, or 0 if there was no error." }, + { "name": "ErrorMessage", "type": "string", "versions": "4+", "nullableVersions": "4+", "ignorable": true, + "about": "The error message, or null if there was no error." 
} + ]} + ] +} diff --git a/clients/src/main/resources/common/message/HeartbeatRequest.json b/clients/src/main/resources/common/message/HeartbeatRequest.json new file mode 100644 index 0000000..dcf776d --- /dev/null +++ b/clients/src/main/resources/common/message/HeartbeatRequest.json @@ -0,0 +1,39 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 12, + "type": "request", + "listeners": ["zkBroker", "broker"], + "name": "HeartbeatRequest", + // Version 1 and version 2 are the same as version 0. + // + // Starting from version 3, we add a new field called groupInstanceId to indicate member identity across restarts. + // + // Version 4 is the first flexible version. + "validVersions": "0-4", + "flexibleVersions": "4+", + "fields": [ + { "name": "GroupId", "type": "string", "versions": "0+", "entityType": "groupId", + "about": "The group id." }, + { "name": "GenerationId", "type": "int32", "versions": "0+", + "about": "The generation of the group." }, + { "name": "MemberId", "type": "string", "versions": "0+", + "about": "The member ID." }, + { "name": "GroupInstanceId", "type": "string", "versions": "3+", + "nullableVersions": "3+", "default": "null", + "about": "The unique identifier of the consumer instance provided by end user." } + ] +} diff --git a/clients/src/main/resources/common/message/HeartbeatResponse.json b/clients/src/main/resources/common/message/HeartbeatResponse.json new file mode 100644 index 0000000..280ba11 --- /dev/null +++ b/clients/src/main/resources/common/message/HeartbeatResponse.json @@ -0,0 +1,35 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 12, + "type": "response", + "name": "HeartbeatResponse", + // Version 1 adds throttle time. + // + // Starting in version 2, on quota violation, brokers send out responses before throttling. + // + // Starting from version 3, heartbeatRequest supports a new field called groupInstanceId to indicate member identity across restarts. + // + // Version 4 is the first flexible version. 
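
Stepping back to the ElectLeaders pair defined earlier in this patch (apiKey 43), that RPC is surfaced through the Admin client's electLeaders method. The following sketch requests a preferred-leader election for a single partition; the broker address, topic, and partition number are illustrative only.

```java
import java.util.Collections;
import java.util.Map;
import java.util.Optional;
import java.util.Properties;
import org.apache.kafka.clients.admin.Admin;
import org.apache.kafka.clients.admin.AdminClientConfig;
import org.apache.kafka.common.ElectionType;
import org.apache.kafka.common.TopicPartition;

public class ElectLeadersExample {
    public static void main(String[] args) throws Exception {
        Properties props = new Properties();
        props.put(AdminClientConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9092"); // placeholder broker
        try (Admin admin = Admin.create(props)) {
            // Triggers an ElectLeadersRequest asking for preferred-leader election
            // on a single partition; topic and partition number are illustrative.
            TopicPartition tp = new TopicPartition("example-topic", 0);
            Map<TopicPartition, Optional<Throwable>> results = admin
                .electLeaders(ElectionType.PREFERRED, Collections.singleton(tp))
                .partitions()
                .get();
            results.forEach((partition, error) ->
                System.out.println(partition + " -> "
                    + error.map(Throwable::getMessage).orElse("ok")));
        }
    }
}
```
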
+ "validVersions": "0-4", + "flexibleVersions": "4+", + "fields": [ + { "name": "ThrottleTimeMs", "type": "int32", "versions": "1+", "ignorable": true, + "about": "The duration in milliseconds for which the request was throttled due to a quota violation, or zero if the request did not violate any quota." }, + { "name": "ErrorCode", "type": "int16", "versions": "0+", + "about": "The error code, or 0 if there was no error." } + ] +} diff --git a/clients/src/main/resources/common/message/IncrementalAlterConfigsRequest.json b/clients/src/main/resources/common/message/IncrementalAlterConfigsRequest.json new file mode 100644 index 0000000..d4955c9 --- /dev/null +++ b/clients/src/main/resources/common/message/IncrementalAlterConfigsRequest.json @@ -0,0 +1,44 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 44, + "type": "request", + "listeners": ["zkBroker", "broker", "controller"], + "name": "IncrementalAlterConfigsRequest", + // Version 1 is the first flexible version. + "validVersions": "0-1", + "flexibleVersions": "1+", + "fields": [ + { "name": "Resources", "type": "[]AlterConfigsResource", "versions": "0+", + "about": "The incremental updates for each resource.", "fields": [ + { "name": "ResourceType", "type": "int8", "versions": "0+", "mapKey": true, + "about": "The resource type." }, + { "name": "ResourceName", "type": "string", "versions": "0+", "mapKey": true, + "about": "The resource name." }, + { "name": "Configs", "type": "[]AlterableConfig", "versions": "0+", + "about": "The configurations.", "fields": [ + { "name": "Name", "type": "string", "versions": "0+", "mapKey": true, + "about": "The configuration key name." }, + { "name": "ConfigOperation", "type": "int8", "versions": "0+", "mapKey": true, + "about": "The type (Set, Delete, Append, Subtract) of operation." }, + { "name": "Value", "type": "string", "versions": "0+", "nullableVersions": "0+", + "about": "The value to set for the configuration key."} + ]} + ]}, + { "name": "ValidateOnly", "type": "bool", "versions": "0+", + "about": "True if we should validate the request, but not change the configurations."} + ] +} diff --git a/clients/src/main/resources/common/message/IncrementalAlterConfigsResponse.json b/clients/src/main/resources/common/message/IncrementalAlterConfigsResponse.json new file mode 100644 index 0000000..d4dad29 --- /dev/null +++ b/clients/src/main/resources/common/message/IncrementalAlterConfigsResponse.json @@ -0,0 +1,38 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. 
+// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 44, + "type": "response", + "name": "IncrementalAlterConfigsResponse", + // Version 1 is the first flexible version. + "validVersions": "0-1", + "flexibleVersions": "1+", + "fields": [ + { "name": "ThrottleTimeMs", "type": "int32", "versions": "0+", + "about": "Duration in milliseconds for which the request was throttled due to a quota violation, or zero if the request did not violate any quota." }, + { "name": "Responses", "type": "[]AlterConfigsResourceResponse", "versions": "0+", + "about": "The responses for each resource.", "fields": [ + { "name": "ErrorCode", "type": "int16", "versions": "0+", + "about": "The resource error code." }, + { "name": "ErrorMessage", "type": "string", "nullableVersions": "0+", "versions": "0+", + "about": "The resource error message, or null if there was no error." }, + { "name": "ResourceType", "type": "int8", "versions": "0+", + "about": "The resource type." }, + { "name": "ResourceName", "type": "string", "versions": "0+", + "about": "The resource name." } + ]} + ] +} diff --git a/clients/src/main/resources/common/message/InitProducerIdRequest.json b/clients/src/main/resources/common/message/InitProducerIdRequest.json new file mode 100644 index 0000000..4e75352 --- /dev/null +++ b/clients/src/main/resources/common/message/InitProducerIdRequest.json @@ -0,0 +1,40 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 22, + "type": "request", + "listeners": ["zkBroker", "broker"], + "name": "InitProducerIdRequest", + // Version 1 is the same as version 0. + // + // Version 2 is the first flexible version. + // + // Version 3 adds ProducerId and ProducerEpoch, allowing producers to try to resume after an INVALID_PRODUCER_EPOCH error + // + // Version 4 adds the support for new error code PRODUCER_FENCED. + "validVersions": "0-4", + "flexibleVersions": "2+", + "fields": [ + { "name": "TransactionalId", "type": "string", "versions": "0+", "nullableVersions": "0+", "entityType": "transactionalId", + "about": "The transactional id, or null if the producer is not transactional." 
}, + { "name": "TransactionTimeoutMs", "type": "int32", "versions": "0+", + "about": "The time in ms to wait before aborting idle transactions sent by this producer. This is only relevant if a TransactionalId has been defined." }, + { "name": "ProducerId", "type": "int64", "versions": "3+", "default": "-1", "entityType": "producerId", + "about": "The producer id. This is used to disambiguate requests if a transactional id is reused following its expiration." }, + { "name": "ProducerEpoch", "type": "int16", "versions": "3+", "default": "-1", + "about": "The producer's current epoch. This will be checked against the producer epoch on the broker, and the request will return an error if they do not match." } + ] +} diff --git a/clients/src/main/resources/common/message/InitProducerIdResponse.json b/clients/src/main/resources/common/message/InitProducerIdResponse.json new file mode 100644 index 0000000..f56c2fe --- /dev/null +++ b/clients/src/main/resources/common/message/InitProducerIdResponse.json @@ -0,0 +1,39 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 22, + "type": "response", + "name": "InitProducerIdResponse", + // Starting in version 1, on quota violation, brokers send out responses before throttling. + // + // Version 2 is the first flexible version. + // + // Version 3 is the same as version 2. + // + // Version 4 adds the support for new error code PRODUCER_FENCED. + "validVersions": "0-4", + "flexibleVersions": "2+", + "fields": [ + { "name": "ThrottleTimeMs", "type": "int32", "versions": "0+", "ignorable": true, + "about": "The duration in milliseconds for which the request was throttled due to a quota violation, or zero if the request did not violate any quota." }, + { "name": "ErrorCode", "type": "int16", "versions": "0+", + "about": "The error code, or 0 if there was no error." }, + { "name": "ProducerId", "type": "int64", "versions": "0+", "entityType": "producerId", + "default": -1, "about": "The current producer id." }, + { "name": "ProducerEpoch", "type": "int16", "versions": "0+", + "about": "The current epoch associated with the producer id." } + ] +} diff --git a/clients/src/main/resources/common/message/JoinGroupRequest.json b/clients/src/main/resources/common/message/JoinGroupRequest.json new file mode 100644 index 0000000..d9113b7 --- /dev/null +++ b/clients/src/main/resources/common/message/JoinGroupRequest.json @@ -0,0 +1,59 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 11, + "type": "request", + "listeners": ["zkBroker", "broker"], + "name": "JoinGroupRequest", + // Version 1 adds RebalanceTimeoutMs. + // + // Version 2 and 3 are the same as version 1. + // + // Starting from version 4, the client needs to issue a second request to join group + // with assigned id. + // + // Starting from version 5, we add a new field called groupInstanceId to indicate member identity across restarts. + // + // Version 6 is the first flexible version. + // + // Version 7 is the same as version 6. + "validVersions": "0-7", + "flexibleVersions": "6+", + "fields": [ + { "name": "GroupId", "type": "string", "versions": "0+", "entityType": "groupId", + "about": "The group identifier." }, + { "name": "SessionTimeoutMs", "type": "int32", "versions": "0+", + "about": "The coordinator considers the consumer dead if it receives no heartbeat after this timeout in milliseconds." }, + // Note: if RebalanceTimeoutMs is not present, SessionTimeoutMs should be + // used instead. The default of -1 here is just intended as a placeholder. + { "name": "RebalanceTimeoutMs", "type": "int32", "versions": "1+", "default": "-1", "ignorable": true, + "about": "The maximum time in milliseconds that the coordinator will wait for each member to rejoin when rebalancing the group." }, + { "name": "MemberId", "type": "string", "versions": "0+", + "about": "The member id assigned by the group coordinator." }, + { "name": "GroupInstanceId", "type": "string", "versions": "5+", + "nullableVersions": "5+", "default": "null", + "about": "The unique identifier of the consumer instance provided by end user." }, + { "name": "ProtocolType", "type": "string", "versions": "0+", + "about": "The unique name for the class of protocols implemented by the group we want to join." }, + { "name": "Protocols", "type": "[]JoinGroupRequestProtocol", "versions": "0+", + "about": "The list of protocols that the member supports.", "fields": [ + { "name": "Name", "type": "string", "versions": "0+", "mapKey": true, + "about": "The protocol name." }, + { "name": "Metadata", "type": "bytes", "versions": "0+", + "about": "The protocol metadata." } + ]} + ] +} diff --git a/clients/src/main/resources/common/message/JoinGroupResponse.json b/clients/src/main/resources/common/message/JoinGroupResponse.json new file mode 100644 index 0000000..f95ec01 --- /dev/null +++ b/clients/src/main/resources/common/message/JoinGroupResponse.json @@ -0,0 +1,62 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License.
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 11, + "type": "response", + "name": "JoinGroupResponse", + // Version 1 is the same as version 0. + // + // Version 2 adds throttle time. + // + // Starting in version 3, on quota violation, brokers send out responses before throttling. + // + // Starting in version 4, the client needs to issue a second request to join group + // with assigned id. + // + // Version 5 is bumped to apply group.instance.id to identify member across restarts. + // + // Version 6 is the first flexible version. + // + // Starting from version 7, the broker sends back the Protocol Type to the client (KIP-559). + "validVersions": "0-7", + "flexibleVersions": "6+", + "fields": [ + { "name": "ThrottleTimeMs", "type": "int32", "versions": "2+", "ignorable": true, + "about": "The duration in milliseconds for which the request was throttled due to a quota violation, or zero if the request did not violate any quota." }, + { "name": "ErrorCode", "type": "int16", "versions": "0+", + "about": "The error code, or 0 if there was no error." }, + { "name": "GenerationId", "type": "int32", "versions": "0+", "default": "-1", + "about": "The generation ID of the group." }, + { "name": "ProtocolType", "type": "string", "versions": "7+", + "nullableVersions": "7+", "default": "null", "ignorable": true, + "about": "The group protocol name." }, + { "name": "ProtocolName", "type": "string", "versions": "0+", "nullableVersions": "7+", + "about": "The group protocol selected by the coordinator." }, + { "name": "Leader", "type": "string", "versions": "0+", + "about": "The leader of the group." }, + { "name": "MemberId", "type": "string", "versions": "0+", + "about": "The member ID assigned by the group coordinator." }, + { "name": "Members", "type": "[]JoinGroupResponseMember", "versions": "0+", "fields": [ + { "name": "MemberId", "type": "string", "versions": "0+", + "about": "The group member ID." }, + { "name": "GroupInstanceId", "type": "string", "versions": "5+", + "nullableVersions": "5+", "default": "null", + "about": "The unique identifier of the consumer instance provided by end user." }, + { "name": "Metadata", "type": "bytes", "versions": "0+", + "about": "The group member metadata." } + ]} + ] +} diff --git a/clients/src/main/resources/common/message/LeaderAndIsrRequest.json b/clients/src/main/resources/common/message/LeaderAndIsrRequest.json new file mode 100644 index 0000000..c38f21e --- /dev/null +++ b/clients/src/main/resources/common/message/LeaderAndIsrRequest.json @@ -0,0 +1,91 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 4, + "type": "request", + "listeners": ["zkBroker"], + "name": "LeaderAndIsrRequest", + // Version 1 adds IsNew. + // + // Version 2 adds broker epoch and reorganizes the partitions by topic. + // + // Version 3 adds AddingReplicas and RemovingReplicas. + // + // Version 4 is the first flexible version. + // + // Version 5 adds Topic ID and Type to the TopicStates, as described in KIP-516. + "validVersions": "0-5", + "flexibleVersions": "4+", + "fields": [ + { "name": "ControllerId", "type": "int32", "versions": "0+", "entityType": "brokerId", + "about": "The current controller ID." }, + { "name": "ControllerEpoch", "type": "int32", "versions": "0+", + "about": "The current controller epoch." }, + { "name": "BrokerEpoch", "type": "int64", "versions": "2+", "ignorable": true, "default": "-1", + "about": "The current broker epoch." }, + { "name": "Type", "type": "int8", "versions": "5+", + "about": "The type that indicates whether all topics are included in the request"}, + { "name": "UngroupedPartitionStates", "type": "[]LeaderAndIsrPartitionState", "versions": "0-1", + "about": "The state of each partition, in a v0 or v1 message." }, + // In v0 or v1 requests, each partition is listed alongside its topic name. + // In v2+ requests, partitions are organized by topic, so that each topic name + // only needs to be listed once. + { "name": "TopicStates", "type": "[]LeaderAndIsrTopicState", "versions": "2+", + "about": "Each topic.", "fields": [ + { "name": "TopicName", "type": "string", "versions": "2+", "entityType": "topicName", + "about": "The topic name." }, + { "name": "TopicId", "type": "uuid", "versions": "5+", "ignorable": true, + "about": "The unique topic ID." }, + { "name": "PartitionStates", "type": "[]LeaderAndIsrPartitionState", "versions": "2+", + "about": "The state of each partition" } + ]}, + { "name": "LiveLeaders", "type": "[]LeaderAndIsrLiveLeader", "versions": "0+", + "about": "The current live leaders.", "fields": [ + { "name": "BrokerId", "type": "int32", "versions": "0+", "entityType": "brokerId", + "about": "The leader's broker ID." }, + { "name": "HostName", "type": "string", "versions": "0+", + "about": "The leader's hostname." }, + { "name": "Port", "type": "int32", "versions": "0+", + "about": "The leader's port." } + ]} + ], + "commonStructs": [ + { "name": "LeaderAndIsrPartitionState", "versions": "0+", "fields": [ + { "name": "TopicName", "type": "string", "versions": "0-1", "entityType": "topicName", "ignorable": true, + "about": "The topic name. This is only present in v0 or v1." }, + { "name": "PartitionIndex", "type": "int32", "versions": "0+", + "about": "The partition index." }, + { "name": "ControllerEpoch", "type": "int32", "versions": "0+", + "about": "The controller epoch." }, + { "name": "Leader", "type": "int32", "versions": "0+", "entityType": "brokerId", + "about": "The broker ID of the leader." }, + { "name": "LeaderEpoch", "type": "int32", "versions": "0+", + "about": "The leader epoch." }, + { "name": "Isr", "type": "[]int32", "versions": "0+", "entityType": "brokerId", + "about": "The in-sync replica IDs." 
}, + { "name": "ZkVersion", "type": "int32", "versions": "0+", + "about": "The ZooKeeper version." }, + { "name": "Replicas", "type": "[]int32", "versions": "0+", "entityType": "brokerId", + "about": "The replica IDs." }, + { "name": "AddingReplicas", "type": "[]int32", "versions": "3+", "ignorable": true, "entityType": "brokerId", + "about": "The replica IDs that we are adding this partition to, or null if no replicas are being added." }, + { "name": "RemovingReplicas", "type": "[]int32", "versions": "3+", "ignorable": true, "entityType": "brokerId", + "about": "The replica IDs that we are removing this partition from, or null if no replicas are being removed." }, + { "name": "IsNew", "type": "bool", "versions": "1+", "default": "false", "ignorable": true, + "about": "Whether the replica should have existed on the broker or not." } + ]} + ] +} diff --git a/clients/src/main/resources/common/message/LeaderAndIsrResponse.json b/clients/src/main/resources/common/message/LeaderAndIsrResponse.json new file mode 100644 index 0000000..958448b --- /dev/null +++ b/clients/src/main/resources/common/message/LeaderAndIsrResponse.json @@ -0,0 +1,55 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 4, + "type": "response", + "name": "LeaderAndIsrResponse", + // Version 1 adds KAFKA_STORAGE_ERROR as a valid error code. + // + // Version 2 is the same as version 1. + // + // Version 3 is the same as version 2. + // + // Version 4 is the first flexible version. + // + // Version 5 removes TopicName and replaces it with TopicId and reorganizes + // the partitions by topic, as described by KIP-516. + "validVersions": "0-5", + "flexibleVersions": "4+", + "fields": [ + { "name": "ErrorCode", "type": "int16", "versions": "0+", + "about": "The error code, or 0 if there was no error." }, + { "name": "PartitionErrors", "type": "[]LeaderAndIsrPartitionError", "versions": "0-4", + "about": "Each partition in v0 to v4 message."}, + { "name": "Topics", "type": "[]LeaderAndIsrTopicError", "versions": "5+", + "about": "Each topic", "fields": [ + { "name": "TopicId", "type": "uuid", "versions": "5+", "mapKey": true, + "about": "The unique topic ID" }, + { "name": "PartitionErrors", "type": "[]LeaderAndIsrPartitionError", "versions": "5+", + "about": "Each partition."} + ]} + ], + "commonStructs": [ + { "name": "LeaderAndIsrPartitionError", "versions": "0+", "fields": [ + { "name": "TopicName", "type": "string", "versions": "0-4", "entityType": "topicName", "ignorable": true, + "about": "The topic name."}, + { "name": "PartitionIndex", "type": "int32", "versions": "0+", + "about": "The partition index." }, + { "name": "ErrorCode", "type": "int16", "versions": "0+", + "about": "The partition error code, or 0 if there was no error." 
} + ]} + ] +} diff --git a/clients/src/main/resources/common/message/LeaderChangeMessage.json b/clients/src/main/resources/common/message/LeaderChangeMessage.json new file mode 100644 index 0000000..fdd7733 --- /dev/null +++ b/clients/src/main/resources/common/message/LeaderChangeMessage.json @@ -0,0 +1,36 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "type": "data", + "name": "LeaderChangeMessage", + "validVersions": "0", + "flexibleVersions": "0+", + "fields": [ + {"name": "Version", "type": "int16", "versions": "0+", + "about": "The version of the leader change message"}, + {"name": "LeaderId", "type": "int32", "versions": "0+", "entityType": "brokerId", + "about": "The ID of the newly elected leader"}, + {"name": "Voters", "type": "[]Voter", "versions": "0+", + "about": "The set of voters in the quorum for this epoch"}, + {"name": "GrantingVoters", "type": "[]Voter", "versions": "0+", + "about": "The voters who voted for the leader at the time of election"} + ], + "commonStructs": [ + { "name": "Voter", "versions": "0+", "fields": [ + {"name": "VoterId", "type": "int32", "versions": "0+"} + ]} + ] +} diff --git a/clients/src/main/resources/common/message/LeaveGroupRequest.json b/clients/src/main/resources/common/message/LeaveGroupRequest.json new file mode 100644 index 0000000..893c945 --- /dev/null +++ b/clients/src/main/resources/common/message/LeaveGroupRequest.json @@ -0,0 +1,42 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 13, + "type": "request", + "listeners": ["zkBroker", "broker"], + "name": "LeaveGroupRequest", + // Version 1 and 2 are the same as version 0. + // + // Version 3 defines batch processing scheme with group.instance.id + member.id for identity + // + // Version 4 is the first flexible version. + "validVersions": "0-4", + "flexibleVersions": "4+", + "fields": [ + { "name": "GroupId", "type": "string", "versions": "0+", "entityType": "groupId", + "about": "The ID of the group to leave." 
}, + { "name": "MemberId", "type": "string", "versions": "0-2", + "about": "The member ID to remove from the group." }, + { "name": "Members", "type": "[]MemberIdentity", "versions": "3+", + "about": "List of leaving member identities.", "fields": [ + { "name": "MemberId", "type": "string", "versions": "3+", + "about": "The member ID to remove from the group." }, + { "name": "GroupInstanceId", "type": "string", + "versions": "3+", "nullableVersions": "3+", "default": "null", + "about": "The group instance ID to remove from the group." } + ]} + ] +} diff --git a/clients/src/main/resources/common/message/LeaveGroupResponse.json b/clients/src/main/resources/common/message/LeaveGroupResponse.json new file mode 100644 index 0000000..0ddb4c6 --- /dev/null +++ b/clients/src/main/resources/common/message/LeaveGroupResponse.json @@ -0,0 +1,45 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 13, + "type": "response", + "name": "LeaveGroupResponse", + // Version 1 adds the throttle time. + // + // Starting in version 2, on quota violation, brokers send out responses before throttling. + // + // Starting in version 3, we will make leave group request into batch mode and add group.instance.id. + // + // Version 4 is the first flexible version. + "validVersions": "0-4", + "flexibleVersions": "4+", + "fields": [ + { "name": "ThrottleTimeMs", "type": "int32", "versions": "1+", "ignorable": true, + "about": "The duration in milliseconds for which the request was throttled due to a quota violation, or zero if the request did not violate any quota." }, + { "name": "ErrorCode", "type": "int16", "versions": "0+", + "about": "The error code, or 0 if there was no error." }, + + { "name": "Members", "type": "[]MemberResponse", "versions": "3+", + "about": "List of leaving member responses.", "fields": [ + { "name": "MemberId", "type": "string", "versions": "3+", + "about": "The member ID to remove from the group." }, + { "name": "GroupInstanceId", "type": "string", "versions": "3+", "nullableVersions": "3+", + "about": "The group instance ID to remove from the group." }, + { "name": "ErrorCode", "type": "int16", "versions": "3+", + "about": "The error code, or 0 if there was no error." } + ]} + ] +} diff --git a/clients/src/main/resources/common/message/ListGroupsRequest.json b/clients/src/main/resources/common/message/ListGroupsRequest.json new file mode 100644 index 0000000..3f62e28 --- /dev/null +++ b/clients/src/main/resources/common/message/ListGroupsRequest.json @@ -0,0 +1,33 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. 
+// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 16, + "type": "request", + "listeners": ["zkBroker", "broker"], + "name": "ListGroupsRequest", + // Version 1 and 2 are the same as version 0. + // + // Version 3 is the first flexible version. + // + // Version 4 adds the StatesFilter field (KIP-518). + "validVersions": "0-4", + "flexibleVersions": "3+", + "fields": [ + { "name": "StatesFilter", "type": "[]string", "versions": "4+", + "about": "The states of the groups we want to list. If empty all groups are returned with their state." + } + ] +} diff --git a/clients/src/main/resources/common/message/ListGroupsResponse.json b/clients/src/main/resources/common/message/ListGroupsResponse.json new file mode 100644 index 0000000..87561c2 --- /dev/null +++ b/clients/src/main/resources/common/message/ListGroupsResponse.json @@ -0,0 +1,44 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 16, + "type": "response", + "name": "ListGroupsResponse", + // Version 1 adds the throttle time. + // + // Starting in version 2, on quota violation, brokers send out responses before throttling. + // + // Version 3 is the first flexible version. + // + // Version 4 adds the GroupState field (KIP-518). + "validVersions": "0-4", + "flexibleVersions": "3+", + "fields": [ + { "name": "ThrottleTimeMs", "type": "int32", "versions": "1+", "ignorable": true, + "about": "The duration in milliseconds for which the request was throttled due to a quota violation, or zero if the request did not violate any quota." }, + { "name": "ErrorCode", "type": "int16", "versions": "0+", + "about": "The error code, or 0 if there was no error." }, + { "name": "Groups", "type": "[]ListedGroup", "versions": "0+", + "about": "Each group in the response.", "fields": [ + { "name": "GroupId", "type": "string", "versions": "0+", "entityType": "groupId", + "about": "The group ID." }, + { "name": "ProtocolType", "type": "string", "versions": "0+", + "about": "The group protocol type." }, + { "name": "GroupState", "type": "string", "versions": "4+", "ignorable": true, + "about": "The group state name." 
} + ]} + ] +} diff --git a/clients/src/main/resources/common/message/ListOffsetsRequest.json b/clients/src/main/resources/common/message/ListOffsetsRequest.json new file mode 100644 index 0000000..93c920e --- /dev/null +++ b/clients/src/main/resources/common/message/ListOffsetsRequest.json @@ -0,0 +1,59 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 2, + "type": "request", + "listeners": ["zkBroker", "broker"], + "name": "ListOffsetsRequest", + // Version 1 removes MaxNumOffsets. From this version forward, only a single + // offset can be returned. + // + // Version 2 adds the isolation level, which is used for transactional reads. + // + // Version 3 is the same as version 2. + // + // Version 4 adds the current leader epoch, which is used for fencing. + // + // Version 5 is the same as version 4. + // + // Version 6 enables flexible versions. + // + // Version 7 enables listing offsets by max timestamp (KIP-734). + "validVersions": "0-7", + "flexibleVersions": "6+", + "fields": [ + { "name": "ReplicaId", "type": "int32", "versions": "0+", "entityType": "brokerId", + "about": "The broker ID of the requestor, or -1 if this request is being made by a normal consumer." }, + { "name": "IsolationLevel", "type": "int8", "versions": "2+", + "about": "This setting controls the visibility of transactional records. Using READ_UNCOMMITTED (isolation_level = 0) makes all records visible. With READ_COMMITTED (isolation_level = 1), non-transactional and COMMITTED transactional records are visible. To be more concrete, READ_COMMITTED returns all data from offsets smaller than the current LSO (last stable offset), and enables the inclusion of the list of aborted transactions in the result, which allows consumers to discard ABORTED transactional records" }, + { "name": "Topics", "type": "[]ListOffsetsTopic", "versions": "0+", + "about": "Each topic in the request.", "fields": [ + { "name": "Name", "type": "string", "versions": "0+", "entityType": "topicName", + "about": "The topic name." }, + { "name": "Partitions", "type": "[]ListOffsetsPartition", "versions": "0+", + "about": "Each partition in the request.", "fields": [ + { "name": "PartitionIndex", "type": "int32", "versions": "0+", + "about": "The partition index." }, + { "name": "CurrentLeaderEpoch", "type": "int32", "versions": "4+", "default": "-1", "ignorable": true, + "about": "The current leader epoch." }, + { "name": "Timestamp", "type": "int64", "versions": "0+", + "about": "The current timestamp." }, + { "name": "MaxNumOffsets", "type": "int32", "versions": "0", "default": "1", + "about": "The maximum number of offsets to report." 
} + ]} + ]} + ] +} diff --git a/clients/src/main/resources/common/message/ListOffsetsResponse.json b/clients/src/main/resources/common/message/ListOffsetsResponse.json new file mode 100644 index 0000000..6d6be0f --- /dev/null +++ b/clients/src/main/resources/common/message/ListOffsetsResponse.json @@ -0,0 +1,59 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 2, + "type": "response", + "name": "ListOffsetsResponse", + // Version 1 removes the offsets array in favor of returning a single offset. + // Version 1 also adds the timestamp associated with the returned offset. + // + // Version 2 adds the throttle time. + // + // Starting in version 3, on quota violation, brokers send out responses before throttling. + // + // Version 4 adds the leader epoch, which is used for fencing. + // + // Version 5 adds a new error code, OFFSET_NOT_AVAILABLE. + // + // Version 6 enables flexible versions. + // + // Version 7 is the same as version 6 (KIP-734). + "validVersions": "0-7", + "flexibleVersions": "6+", + "fields": [ + { "name": "ThrottleTimeMs", "type": "int32", "versions": "2+", "ignorable": true, + "about": "The duration in milliseconds for which the request was throttled due to a quota violation, or zero if the request did not violate any quota." }, + { "name": "Topics", "type": "[]ListOffsetsTopicResponse", "versions": "0+", + "about": "Each topic in the response.", "fields": [ + { "name": "Name", "type": "string", "versions": "0+", "entityType": "topicName", + "about": "The topic name" }, + { "name": "Partitions", "type": "[]ListOffsetsPartitionResponse", "versions": "0+", + "about": "Each partition in the response.", "fields": [ + { "name": "PartitionIndex", "type": "int32", "versions": "0+", + "about": "The partition index." }, + { "name": "ErrorCode", "type": "int16", "versions": "0+", + "about": "The partition error code, or 0 if there was no error." }, + { "name": "OldStyleOffsets", "type": "[]int64", "versions": "0", "ignorable": false, + "about": "The result offsets." }, + { "name": "Timestamp", "type": "int64", "versions": "1+", "default": "-1", "ignorable": false, + "about": "The timestamp associated with the returned offset." }, + { "name": "Offset", "type": "int64", "versions": "1+", "default": "-1", "ignorable": false, + "about": "The returned offset." 
}, + { "name": "LeaderEpoch", "type": "int32", "versions": "4+", "default": "-1" } + ]} + ]} + ] +} diff --git a/clients/src/main/resources/common/message/ListPartitionReassignmentsRequest.json b/clients/src/main/resources/common/message/ListPartitionReassignmentsRequest.json new file mode 100644 index 0000000..6102209 --- /dev/null +++ b/clients/src/main/resources/common/message/ListPartitionReassignmentsRequest.json @@ -0,0 +1,34 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 46, + "type": "request", + "listeners": ["broker", "controller", "zkBroker"], + "name": "ListPartitionReassignmentsRequest", + "validVersions": "0", + "flexibleVersions": "0+", + "fields": [ + { "name": "TimeoutMs", "type": "int32", "versions": "0+", "default": "60000", + "about": "The time in ms to wait for the request to complete." }, + { "name": "Topics", "type": "[]ListPartitionReassignmentsTopics", "versions": "0+", "nullableVersions": "0+", "default": "null", + "about": "The topics to list partition reassignments for, or null to list everything.", "fields": [ + { "name": "Name", "type": "string", "versions": "0+", "entityType": "topicName", + "about": "The topic name" }, + { "name": "PartitionIndexes", "type": "[]int32", "versions": "0+", + "about": "The partitions to list partition reassignments for." } + ]} + ] +} diff --git a/clients/src/main/resources/common/message/ListPartitionReassignmentsResponse.json b/clients/src/main/resources/common/message/ListPartitionReassignmentsResponse.json new file mode 100644 index 0000000..753d9bf --- /dev/null +++ b/clients/src/main/resources/common/message/ListPartitionReassignmentsResponse.json @@ -0,0 +1,46 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +{ + "apiKey": 46, + "type": "response", + "name": "ListPartitionReassignmentsResponse", + "validVersions": "0", + "flexibleVersions": "0+", + "fields": [ + { "name": "ThrottleTimeMs", "type": "int32", "versions": "0+", + "about": "The duration in milliseconds for which the request was throttled due to a quota violation, or zero if the request did not violate any quota." }, + { "name": "ErrorCode", "type": "int16", "versions": "0+", + "about": "The top-level error code, or 0 if there was no error" }, + { "name": "ErrorMessage", "type": "string", "versions": "0+", "nullableVersions": "0+", + "about": "The top-level error message, or null if there was no error." }, + { "name": "Topics", "type": "[]OngoingTopicReassignment", "versions": "0+", + "about": "The ongoing reassignments for each topic.", "fields": [ + { "name": "Name", "type": "string", "versions": "0+", "entityType": "topicName", + "about": "The topic name." }, + { "name": "Partitions", "type": "[]OngoingPartitionReassignment", "versions": "0+", + "about": "The ongoing reassignments for each partition.", "fields": [ + { "name": "PartitionIndex", "type": "int32", "versions": "0+", + "about": "The index of the partition." }, + { "name": "Replicas", "type": "[]int32", "versions": "0+", "entityType": "brokerId", + "about": "The current replica set." }, + { "name": "AddingReplicas", "type": "[]int32", "versions": "0+", "entityType": "brokerId", + "about": "The set of replicas we are currently adding." }, + { "name": "RemovingReplicas", "type": "[]int32", "versions": "0+", "entityType": "brokerId", + "about": "The set of replicas we are currently removing." } + ]} + ]} + ] +} diff --git a/clients/src/main/resources/common/message/ListTransactionsRequest.json b/clients/src/main/resources/common/message/ListTransactionsRequest.json new file mode 100644 index 0000000..21f4552 --- /dev/null +++ b/clients/src/main/resources/common/message/ListTransactionsRequest.json @@ -0,0 +1,31 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +{ + "apiKey": 66, + "type": "request", + "listeners": ["zkBroker", "broker"], + "name": "ListTransactionsRequest", + "validVersions": "0", + "flexibleVersions": "0+", + "fields": [ + { "name": "StateFilters", "type": "[]string", "versions": "0+", + "about": "The transaction states to filter by: if empty, all transactions are returned; if non-empty, then only transactions matching one of the filtered states will be returned" + }, + { "name": "ProducerIdFilters", "type": "[]int64", "versions": "0+", "entityType": "producerId", + "about": "The producerIds to filter by: if empty, all transactions will be returned; if non-empty, only transactions which match one of the filtered producerIds will be returned" + } + ] +} diff --git a/clients/src/main/resources/common/message/ListTransactionsResponse.json b/clients/src/main/resources/common/message/ListTransactionsResponse.json new file mode 100644 index 0000000..2f17873 --- /dev/null +++ b/clients/src/main/resources/common/message/ListTransactionsResponse.json @@ -0,0 +1,35 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 66, + "type": "response", + "name": "ListTransactionsResponse", + "validVersions": "0", + "flexibleVersions": "0+", + "fields": [ + { "name": "ThrottleTimeMs", "type": "int32", "versions": "0+", + "about": "The duration in milliseconds for which the request was throttled due to a quota violation, or zero if the request did not violate any quota." }, + { "name": "ErrorCode", "type": "int16", "versions": "0+" }, + { "name": "UnknownStateFilters", "type": "[]string", "versions": "0+", + "about": "Set of state filters provided in the request which were unknown to the transaction coordinator" }, + { "name": "TransactionStates", "type": "[]TransactionState", "versions": "0+", "fields": [ + { "name": "TransactionalId", "type": "string", "versions": "0+", "entityType": "transactionalId" }, + { "name": "ProducerId", "type": "int64", "versions": "0+", "entityType": "producerId" }, + { "name": "TransactionState", "type": "string", "versions": "0+", + "about": "The current transaction state of the producer" } + ]} + ] +} diff --git a/clients/src/main/resources/common/message/MetadataRequest.json b/clients/src/main/resources/common/message/MetadataRequest.json new file mode 100644 index 0000000..5da95cf --- /dev/null +++ b/clients/src/main/resources/common/message/MetadataRequest.json @@ -0,0 +1,55 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 3, + "type": "request", + "listeners": ["zkBroker", "broker"], + "name": "MetadataRequest", + "validVersions": "0-12", + "flexibleVersions": "9+", + "fields": [ + // In version 0, an empty array indicates "request metadata for all topics." In version 1 and + // higher, an empty array indicates "request metadata for no topics," and a null array is used to + // indicate "request metadata for all topics." + // + // Version 2 and 3 are the same as version 1. + // + // Version 4 adds AllowAutoTopicCreation. + // + // Starting in version 8, authorized operations can be requested for cluster and topic resource. + // + // Version 9 is the first flexible version. + // + // Version 10 adds topicId and allows name field to be null. However, this functionality was not implemented on the server. + // Versions 10 and 11 should not use the topicId field or set topic name to null. + // + // Version 11 deprecates IncludeClusterAuthorizedOperations field. This is now exposed + // by the DescribeCluster API (KIP-700). + // Version 12 supports topic Id. + { "name": "Topics", "type": "[]MetadataRequestTopic", "versions": "0+", "nullableVersions": "1+", + "about": "The topics to fetch metadata for.", "fields": [ + { "name": "TopicId", "type": "uuid", "versions": "10+", "ignorable": true, "about": "The topic id." }, + { "name": "Name", "type": "string", "versions": "0+", "entityType": "topicName", "nullableVersions": "10+", + "about": "The topic name." } + ]}, + { "name": "AllowAutoTopicCreation", "type": "bool", "versions": "4+", "default": "true", "ignorable": false, + "about": "If this is true, the broker may auto-create topics that we requested which do not already exist, if it is configured to do so." }, + { "name": "IncludeClusterAuthorizedOperations", "type": "bool", "versions": "8-10", + "about": "Whether to include cluster authorized operations." }, + { "name": "IncludeTopicAuthorizedOperations", "type": "bool", "versions": "8+", + "about": "Whether to include topic authorized operations." } + ] +} diff --git a/clients/src/main/resources/common/message/MetadataResponse.json b/clients/src/main/resources/common/message/MetadataResponse.json new file mode 100644 index 0000000..714b28b --- /dev/null +++ b/clients/src/main/resources/common/message/MetadataResponse.json @@ -0,0 +1,97 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
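The Metadata request above and the response that follows back topic discovery on the clients; at this point in the codebase Admin#describeTopics is essentially a wrapper over it. A minimal sketch, assuming a broker at localhost:9092 and a hypothetical topic "orders":

import java.util.List;
import java.util.Properties;
import org.apache.kafka.clients.admin.Admin;
import org.apache.kafka.clients.admin.AdminClientConfig;
import org.apache.kafka.clients.admin.TopicDescription;

public class DescribeTopicExample {
    public static void main(String[] args) throws Exception {
        Properties props = new Properties();
        props.put(AdminClientConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9092"); // hypothetical broker address
        try (Admin admin = Admin.create(props)) {
            // describeTopics issues a Metadata request for the named topics only.
            TopicDescription desc = admin.describeTopics(List.of("orders")) // hypothetical topic
                .all().get().get("orders");
            System.out.printf("topicId=%s partitions=%d internal=%b%n",
                desc.topicId(), desc.partitions().size(), desc.isInternal());
        }
    }
}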
+ +{ + "apiKey": 3, + "type": "response", + "name": "MetadataResponse", + // Version 1 adds fields for the rack of each broker, the controller id, and + // whether or not the topic is internal. + // + // Version 2 adds the cluster ID field. + // + // Version 3 adds the throttle time. + // + // Version 4 is the same as version 3. + // + // Version 5 adds a per-partition offline_replicas field. This field specifies + // the list of replicas that are offline. + // + // Starting in version 6, on quota violation, brokers send out responses before throttling. + // + // Version 7 adds the leader epoch to the partition metadata. + // + // Starting in version 8, brokers can send authorized operations for topic and cluster. + // + // Version 9 is the first flexible version. + // + // Version 10 adds topicId. + // + // Version 11 deprecates ClusterAuthorizedOperations. This is now exposed + // by the DescribeCluster API (KIP-700). + // Version 12 supports topicId. + "validVersions": "0-12", + "flexibleVersions": "9+", + "fields": [ + { "name": "ThrottleTimeMs", "type": "int32", "versions": "3+", "ignorable": true, + "about": "The duration in milliseconds for which the request was throttled due to a quota violation, or zero if the request did not violate any quota." }, + { "name": "Brokers", "type": "[]MetadataResponseBroker", "versions": "0+", + "about": "Each broker in the response.", "fields": [ + { "name": "NodeId", "type": "int32", "versions": "0+", "mapKey": true, "entityType": "brokerId", + "about": "The broker ID." }, + { "name": "Host", "type": "string", "versions": "0+", + "about": "The broker hostname." }, + { "name": "Port", "type": "int32", "versions": "0+", + "about": "The broker port." }, + { "name": "Rack", "type": "string", "versions": "1+", "nullableVersions": "1+", "ignorable": true, "default": "null", + "about": "The rack of the broker, or null if it has not been assigned to a rack." } + ]}, + { "name": "ClusterId", "type": "string", "nullableVersions": "2+", "versions": "2+", "ignorable": true, "default": "null", + "about": "The cluster ID that responding broker belongs to." }, + { "name": "ControllerId", "type": "int32", "versions": "1+", "default": "-1", "ignorable": true, "entityType": "brokerId", + "about": "The ID of the controller broker." }, + { "name": "Topics", "type": "[]MetadataResponseTopic", "versions": "0+", + "about": "Each topic in the response.", "fields": [ + { "name": "ErrorCode", "type": "int16", "versions": "0+", + "about": "The topic error, or 0 if there was no error." }, + { "name": "Name", "type": "string", "versions": "0+", "mapKey": true, "entityType": "topicName", "nullableVersions": "12+", + "about": "The topic name." }, + { "name": "TopicId", "type": "uuid", "versions": "10+", "ignorable": true, "about": "The topic id." }, + { "name": "IsInternal", "type": "bool", "versions": "1+", "default": "false", "ignorable": true, + "about": "True if the topic is internal." }, + { "name": "Partitions", "type": "[]MetadataResponsePartition", "versions": "0+", + "about": "Each partition in the topic.", "fields": [ + { "name": "ErrorCode", "type": "int16", "versions": "0+", + "about": "The partition error, or 0 if there was no error." }, + { "name": "PartitionIndex", "type": "int32", "versions": "0+", + "about": "The partition index." }, + { "name": "LeaderId", "type": "int32", "versions": "0+", "entityType": "brokerId", + "about": "The ID of the leader broker." 
}, + { "name": "LeaderEpoch", "type": "int32", "versions": "7+", "default": "-1", "ignorable": true, + "about": "The leader epoch of this partition." }, + { "name": "ReplicaNodes", "type": "[]int32", "versions": "0+", "entityType": "brokerId", + "about": "The set of all nodes that host this partition." }, + { "name": "IsrNodes", "type": "[]int32", "versions": "0+", "entityType": "brokerId", + "about": "The set of nodes that are in sync with the leader for this partition." }, + { "name": "OfflineReplicas", "type": "[]int32", "versions": "5+", "ignorable": true, "entityType": "brokerId", + "about": "The set of offline replicas of this partition." } + ]}, + { "name": "TopicAuthorizedOperations", "type": "int32", "versions": "8+", "default": "-2147483648", + "about": "32-bit bitfield to represent authorized operations for this topic." } + ]}, + { "name": "ClusterAuthorizedOperations", "type": "int32", "versions": "8-10", "default": "-2147483648", + "about": "32-bit bitfield to represent authorized operations for this cluster." } + ] +} diff --git a/clients/src/main/resources/common/message/OffsetCommitRequest.json b/clients/src/main/resources/common/message/OffsetCommitRequest.json new file mode 100644 index 0000000..cf112e1 --- /dev/null +++ b/clients/src/main/resources/common/message/OffsetCommitRequest.json @@ -0,0 +1,68 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 8, + "type": "request", + "listeners": ["zkBroker", "broker"], + "name": "OffsetCommitRequest", + // Version 1 adds timestamp and group membership information, as well as the commit timestamp. + // + // Version 2 adds retention time. It removes the commit timestamp added in version 1. + // + // Version 3 and 4 are the same as version 2. + // + // Version 5 removes the retention time, which is now controlled only by a broker configuration. + // + // Version 6 adds the leader epoch for fencing. + // + // version 7 adds a new field called groupInstanceId to indicate member identity across restarts. + // + // Version 8 is the first flexible version. + "validVersions": "0-8", + "flexibleVersions": "8+", + "fields": [ + { "name": "GroupId", "type": "string", "versions": "0+", "entityType": "groupId", + "about": "The unique group identifier." }, + { "name": "GenerationId", "type": "int32", "versions": "1+", "default": "-1", "ignorable": true, + "about": "The generation of the group." }, + { "name": "MemberId", "type": "string", "versions": "1+", "ignorable": true, + "about": "The member ID assigned by the group coordinator." }, + { "name": "GroupInstanceId", "type": "string", "versions": "7+", + "nullableVersions": "7+", "default": "null", + "about": "The unique identifier of the consumer instance provided by end user." 
}, + { "name": "RetentionTimeMs", "type": "int64", "versions": "2-4", "default": "-1", "ignorable": true, + "about": "The time period in ms to retain the offset." }, + { "name": "Topics", "type": "[]OffsetCommitRequestTopic", "versions": "0+", + "about": "The topics to commit offsets for.", "fields": [ + { "name": "Name", "type": "string", "versions": "0+", "entityType": "topicName", + "about": "The topic name." }, + { "name": "Partitions", "type": "[]OffsetCommitRequestPartition", "versions": "0+", + "about": "Each partition to commit offsets for.", "fields": [ + { "name": "PartitionIndex", "type": "int32", "versions": "0+", + "about": "The partition index." }, + { "name": "CommittedOffset", "type": "int64", "versions": "0+", + "about": "The message offset to be committed." }, + { "name": "CommittedLeaderEpoch", "type": "int32", "versions": "6+", "default": "-1", "ignorable": true, + "about": "The leader epoch of this partition." }, + // CommitTimestamp has been removed from v2 and later. + { "name": "CommitTimestamp", "type": "int64", "versions": "1", "default": "-1", + "about": "The timestamp of the commit." }, + { "name": "CommittedMetadata", "type": "string", "versions": "0+", "nullableVersions": "0+", + "about": "Any associated metadata the client wants to keep." } + ]} + ]} + ] +} diff --git a/clients/src/main/resources/common/message/OffsetCommitResponse.json b/clients/src/main/resources/common/message/OffsetCommitResponse.json new file mode 100644 index 0000000..3d54779 --- /dev/null +++ b/clients/src/main/resources/common/message/OffsetCommitResponse.json @@ -0,0 +1,49 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 8, + "type": "response", + "name": "OffsetCommitResponse", + // Versions 1 and 2 are the same as version 0. + // + // Version 3 adds the throttle time to the response. + // + // Starting in version 4, on quota violation, brokers send out responses before throttling. + // + // Versions 5 and 6 are the same as version 4. + // + // Version 7 offsetCommitRequest supports a new field called groupInstanceId to indicate member identity across restarts. + // + // Version 8 is the first flexible version. + "validVersions": "0-8", + "flexibleVersions": "8+", + "fields": [ + { "name": "ThrottleTimeMs", "type": "int32", "versions": "3+", "ignorable": true, + "about": "The duration in milliseconds for which the request was throttled due to a quota violation, or zero if the request did not violate any quota." }, + { "name": "Topics", "type": "[]OffsetCommitResponseTopic", "versions": "0+", + "about": "The responses for each topic.", "fields": [ + { "name": "Name", "type": "string", "versions": "0+", "entityType": "topicName", + "about": "The topic name." 
}, + { "name": "Partitions", "type": "[]OffsetCommitResponsePartition", "versions": "0+", + "about": "The responses for each partition in the topic.", "fields": [ + { "name": "PartitionIndex", "type": "int32", "versions": "0+", + "about": "The partition index." }, + { "name": "ErrorCode", "type": "int16", "versions": "0+", + "about": "The error code, or 0 if there was no error." } + ]} + ]} + ] +} diff --git a/clients/src/main/resources/common/message/OffsetDeleteRequest.json b/clients/src/main/resources/common/message/OffsetDeleteRequest.json new file mode 100644 index 0000000..4a9dea6 --- /dev/null +++ b/clients/src/main/resources/common/message/OffsetDeleteRequest.json @@ -0,0 +1,39 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 47, + "type": "request", + "listeners": ["zkBroker", "broker"], + "name": "OffsetDeleteRequest", + "validVersions": "0", + "flexibleVersions": "none", + "fields": [ + { "name": "GroupId", "type": "string", "versions": "0+", "entityType": "groupId", + "about": "The unique group identifier." }, + { "name": "Topics", "type": "[]OffsetDeleteRequestTopic", "versions": "0+", + "about": "The topics to delete offsets for", "fields": [ + { "name": "Name", "type": "string", "versions": "0+", "mapKey": true, "entityType": "topicName", + "about": "The topic name." }, + { "name": "Partitions", "type": "[]OffsetDeleteRequestPartition", "versions": "0+", + "about": "Each partition to delete offsets for.", "fields": [ + { "name": "PartitionIndex", "type": "int32", "versions": "0+", + "about": "The partition index." } + ] + } + ] + } + ] +} diff --git a/clients/src/main/resources/common/message/OffsetDeleteResponse.json b/clients/src/main/resources/common/message/OffsetDeleteResponse.json new file mode 100644 index 0000000..d32b36f --- /dev/null +++ b/clients/src/main/resources/common/message/OffsetDeleteResponse.json @@ -0,0 +1,42 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +{ + "apiKey": 47, + "type": "response", + "name": "OffsetDeleteResponse", + "validVersions": "0", + "flexibleVersions": "none", + "fields": [ + { "name": "ErrorCode", "type": "int16", "versions": "0+", + "about": "The top-level error code, or 0 if there was no error." }, + { "name": "ThrottleTimeMs", "type": "int32", "versions": "0+", "ignorable": true, + "about": "The duration in milliseconds for which the request was throttled due to a quota violation, or zero if the request did not violate any quota." }, + { "name": "Topics", "type": "[]OffsetDeleteResponseTopic", "versions": "0+", + "about": "The responses for each topic.", "fields": [ + { "name": "Name", "type": "string", "versions": "0+", "mapKey": true, "entityType": "topicName", + "about": "The topic name." }, + { "name": "Partitions", "type": "[]OffsetDeleteResponsePartition", "versions": "0+", + "about": "The responses for each partition in the topic.", "fields": [ + { "name": "PartitionIndex", "type": "int32", "versions": "0+", "mapKey": true, + "about": "The partition index." }, + { "name": "ErrorCode", "type": "int16", "versions": "0+", + "about": "The error code, or 0 if there was no error." } + ] + } + ] + } + ] +} diff --git a/clients/src/main/resources/common/message/OffsetFetchRequest.json b/clients/src/main/resources/common/message/OffsetFetchRequest.json new file mode 100644 index 0000000..8f3c414 --- /dev/null +++ b/clients/src/main/resources/common/message/OffsetFetchRequest.json @@ -0,0 +1,63 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 9, + "type": "request", + "listeners": ["zkBroker", "broker"], + "name": "OffsetFetchRequest", + // In version 0, the request read offsets from ZK. + // + // Starting in version 1, the broker supports fetching offsets from the internal __consumer_offsets topic. + // + // Starting in version 2, the request can contain a null topics array to indicate that offsets + // for all topics should be fetched. It also returns a top level error code + // for group or coordinator level errors. + // + // Version 3, 4, and 5 are the same as version 2. + // + // Version 6 is the first flexible version. + // + // Version 7 is adding the require stable flag. + // + // Version 8 is adding support for fetching offsets for multiple groups at a time + "validVersions": "0-8", + "flexibleVersions": "6+", + "fields": [ + { "name": "GroupId", "type": "string", "versions": "0-7", "entityType": "groupId", + "about": "The group to fetch offsets for." 
}, + { "name": "Topics", "type": "[]OffsetFetchRequestTopic", "versions": "0-7", "nullableVersions": "2-7", + "about": "Each topic we would like to fetch offsets for, or null to fetch offsets for all topics.", "fields": [ + { "name": "Name", "type": "string", "versions": "0-7", "entityType": "topicName", + "about": "The topic name."}, + { "name": "PartitionIndexes", "type": "[]int32", "versions": "0-7", + "about": "The partition indexes we would like to fetch offsets for." } + ]}, + { "name": "Groups", "type": "[]OffsetFetchRequestGroup", "versions": "8+", + "about": "Each group we would like to fetch offsets for", "fields": [ + { "name": "groupId", "type": "string", "versions": "8+", "entityType": "groupId", + "about": "The group ID."}, + { "name": "Topics", "type": "[]OffsetFetchRequestTopics", "versions": "8+", "nullableVersions": "8+", + "about": "Each topic we would like to fetch offsets for, or null to fetch offsets for all topics.", "fields": [ + { "name": "Name", "type": "string", "versions": "8+", "entityType": "topicName", + "about": "The topic name."}, + { "name": "PartitionIndexes", "type": "[]int32", "versions": "8+", + "about": "The partition indexes we would like to fetch offsets for." } + ]} + ]}, + {"name": "RequireStable", "type": "bool", "versions": "7+", "default": "false", + "about": "Whether broker should hold on returning unstable offsets but set a retriable error code for the partitions."} + ] +} diff --git a/clients/src/main/resources/common/message/OffsetFetchResponse.json b/clients/src/main/resources/common/message/OffsetFetchResponse.json new file mode 100644 index 0000000..dfad60e --- /dev/null +++ b/clients/src/main/resources/common/message/OffsetFetchResponse.json @@ -0,0 +1,86 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 9, + "type": "response", + "name": "OffsetFetchResponse", + // Version 1 is the same as version 0. + // + // Version 2 adds a top-level error code. + // + // Version 3 adds the throttle time. + // + // Starting in version 4, on quota violation, brokers send out responses before throttling. + // + // Version 5 adds the leader epoch to the committed offset. + // + // Version 6 is the first flexible version. + // + // Version 7 adds pending offset commit as new error response on partition level. + // + // Version 8 is adding support for fetching offsets for multiple groups + "validVersions": "0-8", + "flexibleVersions": "6+", + "fields": [ + { "name": "ThrottleTimeMs", "type": "int32", "versions": "3+", "ignorable": true, + "about": "The duration in milliseconds for which the request was throttled due to a quota violation, or zero if the request did not violate any quota." 
}, + { "name": "Topics", "type": "[]OffsetFetchResponseTopic", "versions": "0-7", + "about": "The responses per topic.", "fields": [ + { "name": "Name", "type": "string", "versions": "0-7", "entityType": "topicName", + "about": "The topic name." }, + { "name": "Partitions", "type": "[]OffsetFetchResponsePartition", "versions": "0-7", + "about": "The responses per partition", "fields": [ + { "name": "PartitionIndex", "type": "int32", "versions": "0-7", + "about": "The partition index." }, + { "name": "CommittedOffset", "type": "int64", "versions": "0-7", + "about": "The committed message offset." }, + { "name": "CommittedLeaderEpoch", "type": "int32", "versions": "5-7", "default": "-1", + "ignorable": true, "about": "The leader epoch." }, + { "name": "Metadata", "type": "string", "versions": "0-7", "nullableVersions": "0-7", + "about": "The partition metadata." }, + { "name": "ErrorCode", "type": "int16", "versions": "0-7", + "about": "The error code, or 0 if there was no error." } + ]} + ]}, + { "name": "ErrorCode", "type": "int16", "versions": "2-7", "default": "0", "ignorable": true, + "about": "The top-level error code, or 0 if there was no error." }, + {"name": "Groups", "type": "[]OffsetFetchResponseGroup", "versions": "8+", + "about": "The responses per group id.", "fields": [ + { "name": "groupId", "type": "string", "versions": "8+", "entityType": "groupId", + "about": "The group ID." }, + { "name": "Topics", "type": "[]OffsetFetchResponseTopics", "versions": "8+", + "about": "The responses per topic.", "fields": [ + { "name": "Name", "type": "string", "versions": "8+", "entityType": "topicName", + "about": "The topic name." }, + { "name": "Partitions", "type": "[]OffsetFetchResponsePartitions", "versions": "8+", + "about": "The responses per partition", "fields": [ + { "name": "PartitionIndex", "type": "int32", "versions": "8+", + "about": "The partition index." }, + { "name": "CommittedOffset", "type": "int64", "versions": "8+", + "about": "The committed message offset." }, + { "name": "CommittedLeaderEpoch", "type": "int32", "versions": "8+", "default": "-1", + "ignorable": true, "about": "The leader epoch." }, + { "name": "Metadata", "type": "string", "versions": "8+", "nullableVersions": "8+", + "about": "The partition metadata." }, + { "name": "ErrorCode", "type": "int16", "versions": "8+", + "about": "The partition-level error code, or 0 if there was no error." } + ]} + ]}, + { "name": "ErrorCode", "type": "int16", "versions": "8+", "default": "0", + "about": "The group-level error code, or 0 if there was no error." } + ]} + ] +} diff --git a/clients/src/main/resources/common/message/OffsetForLeaderEpochRequest.json b/clients/src/main/resources/common/message/OffsetForLeaderEpochRequest.json new file mode 100644 index 0000000..6645ad2 --- /dev/null +++ b/clients/src/main/resources/common/message/OffsetForLeaderEpochRequest.json @@ -0,0 +1,50 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 23, + "type": "request", + "listeners": ["zkBroker", "broker"], + "name": "OffsetForLeaderEpochRequest", + // Version 1 is the same as version 0. + // + // Version 2 adds the current leader epoch to support fencing. + // + // Version 3 adds ReplicaId (the default is -2 which conventionally represents a + // "debug" consumer which is allowed to see offsets beyond the high watermark). + // Followers will use this replicaId when using an older version of the protocol. + // + // Version 4 enables flexible versions. + "validVersions": "0-4", + "flexibleVersions": "4+", + "fields": [ + { "name": "ReplicaId", "type": "int32", "versions": "3+", "default": -2, "ignorable": true, "entityType": "brokerId", + "about": "The broker ID of the follower, or -1 if this request is from a consumer." }, + { "name": "Topics", "type": "[]OffsetForLeaderTopic", "versions": "0+", + "about": "Each topic to get offsets for.", "fields": [ + { "name": "Topic", "type": "string", "versions": "0+", "entityType": "topicName", + "mapKey": true, "about": "The topic name." }, + { "name": "Partitions", "type": "[]OffsetForLeaderPartition", "versions": "0+", + "about": "Each partition to get offsets for.", "fields": [ + { "name": "Partition", "type": "int32", "versions": "0+", + "about": "The partition index." }, + { "name": "CurrentLeaderEpoch", "type": "int32", "versions": "2+", "default": "-1", "ignorable": true, + "about": "An epoch used to fence consumers/replicas with old metadata. If the epoch provided by the client is larger than the current epoch known to the broker, then the UNKNOWN_LEADER_EPOCH error code will be returned. If the provided epoch is smaller, then the FENCED_LEADER_EPOCH error code will be returned." }, + { "name": "LeaderEpoch", "type": "int32", "versions": "0+", + "about": "The epoch to look up an offset for." } + ]} + ]} + ] +} diff --git a/clients/src/main/resources/common/message/OffsetForLeaderEpochResponse.json b/clients/src/main/resources/common/message/OffsetForLeaderEpochResponse.json new file mode 100644 index 0000000..2b0810e --- /dev/null +++ b/clients/src/main/resources/common/message/OffsetForLeaderEpochResponse.json @@ -0,0 +1,49 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 23, + "type": "response", + "name": "OffsetForLeaderEpochResponse", + // Version 1 added the leader epoch to the response.
+ // + // Version 2 added the throttle time. + // + // Version 3 is the same as version 2. + // + // Version 4 enables flexible versions. + "validVersions": "0-4", + "flexibleVersions": "4+", + "fields": [ + { "name": "ThrottleTimeMs", "type": "int32", "versions": "2+", "ignorable": true, + "about": "The duration in milliseconds for which the request was throttled due to a quota violation, or zero if the request did not violate any quota." }, + { "name": "Topics", "type": "[]OffsetForLeaderTopicResult", "versions": "0+", + "about": "Each topic we fetched offsets for.", "fields": [ + { "name": "Topic", "type": "string", "versions": "0+", "entityType": "topicName", + "mapKey": true, "about": "The topic name." }, + { "name": "Partitions", "type": "[]EpochEndOffset", "versions": "0+", + "about": "Each partition in the topic we fetched offsets for.", "fields": [ + { "name": "ErrorCode", "type": "int16", "versions": "0+", + "about": "The error code, or 0 if there was no error." }, + { "name": "Partition", "type": "int32", "versions": "0+", + "about": "The partition index." }, + { "name": "LeaderEpoch", "type": "int32", "versions": "1+", "default": "-1", "ignorable": true, + "about": "The leader epoch of the partition." }, + { "name": "EndOffset", "type": "int64", "versions": "0+", "default": "-1", + "about": "The end offset of the epoch." } + ]} + ]} + ] +} diff --git a/clients/src/main/resources/common/message/ProduceRequest.json b/clients/src/main/resources/common/message/ProduceRequest.json new file mode 100644 index 0000000..90900af --- /dev/null +++ b/clients/src/main/resources/common/message/ProduceRequest.json @@ -0,0 +1,58 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 0, + "type": "request", + "listeners": ["zkBroker", "broker"], + "name": "ProduceRequest", + // Versions 1 and 2 are the same as version 0. + // + // Version 3 adds the transactional ID, which is used for authorization when attempting to write + // transactional data. Version 3 also adds support for Kafka Message Format v2. + // + // Version 4 is the same as version 3, but the requestor must be prepared to handle a + // KAFKA_STORAGE_ERROR. + // + // Versions 5 and 6 are the same as version 3. + // + // Starting in version 7, records can be produced using ZStandard compression. See KIP-110. + // + // Starting in version 8, the response has RecordErrors and ErrorMessage. See KIP-467. + // + // Version 9 enables flexible versions. + "validVersions": "0-9", + "flexibleVersions": "9+", + "fields": [ + { "name": "TransactionalId", "type": "string", "versions": "3+", "nullableVersions": "3+", "default": "null", "entityType": "transactionalId", + "about": "The transactional ID, or null if the producer is not transactional."
}, + { "name": "Acks", "type": "int16", "versions": "0+", + "about": "The number of acknowledgments the producer requires the leader to have received before considering a request complete. Allowed values: 0 for no acknowledgments, 1 for only the leader and -1 for the full ISR." }, + { "name": "TimeoutMs", "type": "int32", "versions": "0+", + "about": "The timeout to await a response in milliseconds." }, + { "name": "TopicData", "type": "[]TopicProduceData", "versions": "0+", + "about": "Each topic to produce to.", "fields": [ + { "name": "Name", "type": "string", "versions": "0+", "entityType": "topicName", "mapKey": true, + "about": "The topic name." }, + { "name": "PartitionData", "type": "[]PartitionProduceData", "versions": "0+", + "about": "Each partition to produce to.", "fields": [ + { "name": "Index", "type": "int32", "versions": "0+", + "about": "The partition index." }, + { "name": "Records", "type": "records", "versions": "0+", "nullableVersions": "0+", + "about": "The record data to be produced." } + ]} + ]} + ] +} diff --git a/clients/src/main/resources/common/message/ProduceResponse.json b/clients/src/main/resources/common/message/ProduceResponse.json new file mode 100644 index 0000000..0c47f6d --- /dev/null +++ b/clients/src/main/resources/common/message/ProduceResponse.json @@ -0,0 +1,68 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 0, + "type": "response", + "name": "ProduceResponse", + // Version 1 added the throttle time. + // + // Version 2 added the log append time. + // + // Version 3 is the same as version 2. + // + // Version 4 added KAFKA_STORAGE_ERROR as a possible error code. + // + // Version 5 added LogStartOffset to filter out spurious + // OutOfOrderSequenceExceptions on the client. + // + // Version 8 added RecordErrors and ErrorMessage to include information about + // records that cause the whole batch to be dropped. See KIP-467 for details. + // + // Version 9 enables flexible versions. + "validVersions": "0-9", + "flexibleVersions": "9+", + "fields": [ + { "name": "Responses", "type": "[]TopicProduceResponse", "versions": "0+", + "about": "Each produce response", "fields": [ + { "name": "Name", "type": "string", "versions": "0+", "entityType": "topicName", "mapKey": true, + "about": "The topic name" }, + { "name": "PartitionResponses", "type": "[]PartitionProduceResponse", "versions": "0+", + "about": "Each partition that we produced to within the topic.", "fields": [ + { "name": "Index", "type": "int32", "versions": "0+", + "about": "The partition index." }, + { "name": "ErrorCode", "type": "int16", "versions": "0+", + "about": "The error code, or 0 if there was no error." }, + { "name": "BaseOffset", "type": "int64", "versions": "0+", + "about": "The base offset." 
}, + { "name": "LogAppendTimeMs", "type": "int64", "versions": "2+", "default": "-1", "ignorable": true, + "about": "The timestamp returned by broker after appending the messages. If CreateTime is used for the topic, the timestamp will be -1. If LogAppendTime is used for the topic, the timestamp will be the broker local time when the messages are appended." }, + { "name": "LogStartOffset", "type": "int64", "versions": "5+", "default": "-1", "ignorable": true, + "about": "The log start offset." }, + { "name": "RecordErrors", "type": "[]BatchIndexAndErrorMessage", "versions": "8+", "ignorable": true, + "about": "The batch indices of records that caused the batch to be dropped", "fields": [ + { "name": "BatchIndex", "type": "int32", "versions": "8+", + "about": "The batch index of the record that cause the batch to be dropped" }, + { "name": "BatchIndexErrorMessage", "type": "string", "default": "null", "versions": "8+", "nullableVersions": "8+", + "about": "The error message of the record that caused the batch to be dropped"} + ]}, + { "name": "ErrorMessage", "type": "string", "default": "null", "versions": "8+", "nullableVersions": "8+", "ignorable": true, + "about": "The global error message summarizing the common root cause of the records that caused the batch to be dropped"} + ]} + ]}, + { "name": "ThrottleTimeMs", "type": "int32", "versions": "1+", "ignorable": true, "default": "0", + "about": "The duration in milliseconds for which the request was throttled due to a quota violation, or zero if the request did not violate any quota." } + ] +} diff --git a/clients/src/main/resources/common/message/README.md b/clients/src/main/resources/common/message/README.md new file mode 100644 index 0000000..6185b87 --- /dev/null +++ b/clients/src/main/resources/common/message/README.md @@ -0,0 +1,267 @@ +Apache Kafka Message Definitions +================================ + +Introduction +------------ +The JSON files in this directory define the Apache Kafka message protocol. +This protocol describes what information clients and servers send to each +other, and how it is serialized. Note that this version of JSON supports +comments. Comments begin with a double forward slash. + +When Kafka is compiled, these specification files are translated into Java code +to read and write messages. Any change to these JSON files will trigger a +recompilation of this generated code. + +These specification files replace an older system where hand-written +serialization code was used. Over time, we will migrate all messages to using +automatically generated serialization and deserialization code. + +Requests and Responses +---------------------- +The Kafka protocol features requests and responses. Requests are sent to a +server in order to get a response. Each request is uniquely identified by a +16-bit integer called the "api key". The API key of the response will always +match that of the request. + +Each message has a unique 16-bit version number. The schema might be different +for each version of the message. Sometimes, the version is incremented even +though the schema has not changed. This may indicate that the server should +behave differently in some way. The version of a response must always match +the version of the corresponding request. + +Each request or response has a top-level field named "validVersions." This +specifies the versions of the protocol that our code understands. For example, +specifying "0-2" indicates that we understand versions 0, 1, and 2. 
You must +always specify the highest message version which is supported. + +The only old message versions that are no longer supported are version 0 of +MetadataRequest and MetadataResponse. In general, since we adopted KIP-97, +dropping support for old message versions is no longer allowed without a KIP. +Therefore, please be careful not to increase the lower end of the version +support interval for any message. + +MessageData Objects +------------------- +Using the JSON files in this directory, we generate Java code for MessageData +objects. These objects store request and response data for kafka. MessageData +objects do not contain a version number. Instead, a single MessageData object +represents every possible version of a Message. This makes working with +messages more convenient, because the same code path can be used for every +version of a message. + +Fields +------ +Each message contains an array of fields. Fields specify the data that should +be sent with the message. In general, fields have a name, a type, and version +information associated with them. + +The order that fields appear in a message is important. Fields which come +first in the message definition will be sent first over the network. Changing +the order of the fields in a message is an incompatible change. + +In each new message version, we may add or subtract fields. For example, if we +are creating a new version 3 of a message, we can add a new field with the +version spec "3+". This specifies that the field only appears in version 3 and +later. If a field is being removed, we should change its version from "0+" to +"0-2" to indicate that it will not appear in version 3 and later. + +Field Types +----------- +There are several primitive field types available. + +* "boolean": either true or false. + +* "int8": an 8-bit integer. + +* "int16": a 16-bit integer. + +* "uint16": a 16-bit unsigned integer. + +* "int32": a 32-bit integer. + +* "uint32": a 32-bit unsigned integer. + +* "int64": a 64-bit integer. + +* "float64": is a double-precision floating point number (IEEE 754). + +* "string": a UTF-8 string. + +* "uuid": a type 4 immutable universally unique identifier. + +* "bytes": binary data. + +* "records": recordset such as memory recordset. + +In addition to these primitive field types, there is also an array type. Array +types start with a "[]" and end with the name of the element type. For +example, []Foo declares an array of "Foo" objects. Array fields have their own +array of fields, which specifies what is in the contained objects. + +For information about how fields are serialized, see the [Kafka Protocol +Guide](https://kafka.apache.org/protocol.html). + +Nullable Fields +--------------- +Booleans, ints, and floats can never be null. However, fields that are strings, +bytes, uuid, records, or arrays may optionally be "nullable". When a field is +"nullable", that simply means that we are prepared to serialize and deserialize +null entries for that field. + +If you want to declare a field as nullable, you set "nullableVersions" for that +field. Nullability is implemented as a version range in order to accommodate a +very common pattern in Kafka where a field that was originally not nullable +becomes nullable in a later version. + +If a field is declared as non-nullable, and it is present in the message +version you are using, you should set it to a non-null value before serializing +the message. Otherwise, you will get a runtime error. 
+ +Tagged Fields +------------- +Tagged fields are an extension to the Kafka protocol which allows optional data +to be attached to messages. Tagged fields can appear at the root level of +messages, or within any structure in the message. + +Unlike mandatory fields, tagged fields can be added to message versions that +already exist. Older servers will ignore new tagged fields which they do not +understand. + +In order to make a field tagged, set a "tag" for the field, and also set up +tagged versions for the field. The taggedVersions you specify should be +open-ended; that is, they should specify a start version, but not an end +version. + +You can remove support for a tagged field from a specific version of a message, +but you can't reuse a tag once it has been used for something else; doing so +would break compatibility. + +Note that tagged fields can only be added to "flexible" message versions. + +Flexible Versions +----------------- +Kafka serialization has been improved over time to be more flexible and +efficient. Message versions that contain these improvements are referred to as +"flexible versions." + +In flexible versions, variable-length fields such as strings, arrays, and bytes +fields are serialized in a more efficient way that saves space. The new +serialization types are prefixed with COMPACT_. For example, COMPACT_STRING is +a more efficient form of STRING. + +Serializing Messages +-------------------- +The Message#write method writes out a message to a buffer. The fields that are +written out will depend on the version number that you supply to write(). When +you write out a message using an older version, fields that are not present in +that version's schema will be omitted. + +When working with older message versions, please verify that the older message +schema includes all the data that needs to be sent. For example, it is probably +OK to skip sending a timeout field. However, a field which radically alters the +meaning of the request, such as a "validateOnly" boolean, should not be ignored. + +It's often useful to know how much space a message will take up before writing +it out to a buffer. You can find this out by calling the Message#size method. + +Deserializing Messages +---------------------- +Message objects may be deserialized using the Message#read method. This method +overwrites all the data currently in the message object with new data. + +Any fields in the message object that are not present in the version that you +are deserializing will be reset to default values. Unless a custom default has +been set, the defaults are: + +* Integer fields default to 0. + +* Floats default to 0. + +* Booleans default to false. + +* Strings default to the empty string. + +* Bytes fields default to the empty byte array. + +* Uuid fields default to the zero uuid. + +* Records fields default to null. + +* Array fields default to empty. + +You can specify "null" as a default value for a string field by specifying the +literal string "null". Note that you can only specify null as a default if all +versions of the field are nullable. + +Custom Default Values +--------------------- +You may set a custom default for fields that are integers, booleans, floats, or +strings. Just add a "default" entry in the JSON object. The custom default +overrides the normal default for the type. For example, you could make a +boolean field default to true rather than false. + +Note that the default must be valid for the field type.
So the default for an +int16 field must be an integer that fits in 16 bits, and so forth. You may +specify hex or octal values, as long as they are prefixed with 0x or 0. It is +currently not possible to specify a custom default for bytes or array fields. + +Custom defaults are useful when an older message version lacked some +information. For example, if an older request lacked a timeout field, you may +want to specify that the server should assume that the timeout for such a +request is 5000 ms (or some other arbitrary value). + +Ignorable Fields +---------------- +When we write messages using an older or newer format, not all fields may be +present. The message receiver will fill in the default value for the field +during deserialization. Therefore, if the source field was set to a non-default +value, that information will be lost. + +In some cases, this information loss is acceptable. For example, if a timeout +field does not get preserved, this is not a problem. However, in other cases, +the field is really quite important and should not be discarded. One example is +a "verify only" boolean which changes the whole meaning of the request. + +By default, we assume that information loss is not acceptable. The message +serialization code will throw an exception if a field that would be omitted is +not set to its default value. If information loss for a field is OK, please set +"ignorable" to true for the field to disable this behavior. When ignorable is +set to true, the field may be silently omitted during serialization. + +Hash Sets +--------- +One very common pattern in Kafka is to load array elements from a message into +a Map or Set for easier access. The message protocol makes this easier with +the "mapKey" concept. + +If one or more fields of an array's element structure are annotated with +"mapKey": true, the entire array will be treated as a linked hash set rather +than a list. Elements in this set will be accessible in O(1) time with an +automatically generated "find" function. The order of elements in the set will +still be preserved, however. New entries that are added to the set always +appear last in the ordering. + +Incompatible Changes +-------------------- +It's very important to avoid making incompatible changes to the message +protocol. Here are some examples of incompatible changes: + +#### Making changes to a protocol version which has already been released. +Protocol versions that have been released must be regarded as done. If there +were mistakes, they should be corrected in a new version rather than changing +the existing version. + +#### Re-ordering existing fields. +It is OK to add new fields before or after existing fields. However, existing +fields should not be re-ordered with respect to each other. + +#### Changing the default of an existing field. +You must never change the default of a field which already exists. Otherwise, +new clients and old servers will not agree on the default value for the field. + +#### Changing the type of an existing field. +In general, the type of an existing field cannot be changed. The one exception +is that an array of primitives may be changed to an array of structures +containing the same data, as long as the conversion is done correctly. The +Kafka protocol does not do any "boxing" of structures, so an array of structs +that contain a single int32 is the same as an array of int32s.
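+
+As a closing illustration of the Tagged Fields and Hash Sets sections above,
+the sketch below shows one way a tagged field and a map key might be declared.
+As before, the api key, message name, and field names are invented for this
+example; only the "tag", "taggedVersions", and "mapKey" attributes are meant to
+reflect the conventions described in this document.
+
+```json
+{
+  "apiKey": 9999,
+  "type": "response",
+  "name": "ExampleResponse",
+  // Version 1 is the first flexible version.
+  "validVersions": "0-1",
+  "flexibleVersions": "1+",
+  "fields": [
+    { "name": "Brokers", "type": "[]ExampleBroker", "versions": "0+",
+      "about": "Each broker in the response.", "fields": [
+      // NodeId is a mapKey, so the generated collection behaves like a linked
+      // hash set with an O(1) "find" lookup by node ID.
+      { "name": "NodeId", "type": "int32", "versions": "0+", "mapKey": true, "entityType": "brokerId",
+        "about": "The broker ID." },
+      { "name": "Host", "type": "string", "versions": "0+",
+        "about": "The broker hostname." }
+    ]},
+    // A tagged field: optional, only allowed in flexible versions, and ignored
+    // by older servers that do not understand it.
+    { "name": "DebugInfo", "type": "string", "versions": "1+", "nullableVersions": "1+",
+      "default": "null", "tag": 0, "taggedVersions": "1+",
+      "about": "Optional diagnostic information, or null if not set." }
+  ]
+}
+```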
diff --git a/clients/src/main/resources/common/message/RenewDelegationTokenRequest.json b/clients/src/main/resources/common/message/RenewDelegationTokenRequest.json new file mode 100644 index 0000000..7240ac3 --- /dev/null +++ b/clients/src/main/resources/common/message/RenewDelegationTokenRequest.json @@ -0,0 +1,31 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 39, + "type": "request", + "listeners": ["zkBroker"], + "name": "RenewDelegationTokenRequest", + // Version 1 is the same as version 0. + // Version 2 adds flexible version support + "validVersions": "0-2", + "flexibleVersions": "2+", + "fields": [ + { "name": "Hmac", "type": "bytes", "versions": "0+", + "about": "The HMAC of the delegation token to be renewed." }, + { "name": "RenewPeriodMs", "type": "int64", "versions": "0+", + "about": "The renewal time period in milliseconds." } + ] +} diff --git a/clients/src/main/resources/common/message/RenewDelegationTokenResponse.json b/clients/src/main/resources/common/message/RenewDelegationTokenResponse.json new file mode 100644 index 0000000..c429dad --- /dev/null +++ b/clients/src/main/resources/common/message/RenewDelegationTokenResponse.json @@ -0,0 +1,32 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 39, + "type": "response", + "name": "RenewDelegationTokenResponse", + // Starting in version 1, on quota violation, brokers send out responses before throttling. + // Version 2 adds flexible version support + "validVersions": "0-2", + "flexibleVersions": "2+", + "fields": [ + { "name": "ErrorCode", "type": "int16", "versions": "0+", + "about": "The error code, or 0 if there was no error." }, + { "name": "ExpiryTimestampMs", "type": "int64", "versions": "0+", + "about": "The timestamp in milliseconds at which this token expires." }, + { "name": "ThrottleTimeMs", "type": "int32", "versions": "0+", + "about": "The duration in milliseconds for which the request was throttled due to a quota violation, or zero if the request did not violate any quota." 
} + ] +} diff --git a/clients/src/main/resources/common/message/RequestHeader.json b/clients/src/main/resources/common/message/RequestHeader.json new file mode 100644 index 0000000..fbf4e2c --- /dev/null +++ b/clients/src/main/resources/common/message/RequestHeader.json @@ -0,0 +1,42 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "type": "header", + "name": "RequestHeader", + // Version 0 of the RequestHeader is only used by v0 of ControlledShutdownRequest. + // + // Version 1 is the first version with ClientId. + // + // Version 2 is the first flexible version. + "validVersions": "0-2", + "flexibleVersions": "2+", + "fields": [ + { "name": "RequestApiKey", "type": "int16", "versions": "0+", + "about": "The API key of this request." }, + { "name": "RequestApiVersion", "type": "int16", "versions": "0+", + "about": "The API version of this request." }, + { "name": "CorrelationId", "type": "int32", "versions": "0+", + "about": "The correlation ID of this request." }, + + // The ClientId string must be serialized with the old-style two-byte length prefix. + // The reason is that older brokers must be able to read the request header for any + // ApiVersionsRequest, even if it is from a newer version. + // Since the client is sending the ApiVersionsRequest in order to discover what + // versions are supported, the client does not know the best version to use. + { "name": "ClientId", "type": "string", "versions": "1+", "nullableVersions": "1+", "ignorable": true, + "flexibleVersions": "none", "about": "The client ID string." } + ] +} diff --git a/clients/src/main/resources/common/message/ResponseHeader.json b/clients/src/main/resources/common/message/ResponseHeader.json new file mode 100644 index 0000000..7736736 --- /dev/null +++ b/clients/src/main/resources/common/message/ResponseHeader.json @@ -0,0 +1,26 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "type": "header", + "name": "ResponseHeader", + // Version 1 is the first flexible version. 
+ "validVersions": "0-1", + "flexibleVersions": "1+", + "fields": [ + { "name": "CorrelationId", "type": "int32", "versions": "0+", + "about": "The correlation ID of this response." } + ] +} diff --git a/clients/src/main/resources/common/message/SaslAuthenticateRequest.json b/clients/src/main/resources/common/message/SaslAuthenticateRequest.json new file mode 100644 index 0000000..3f5558b --- /dev/null +++ b/clients/src/main/resources/common/message/SaslAuthenticateRequest.json @@ -0,0 +1,29 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 36, + "type": "request", + "listeners": ["zkBroker", "broker", "controller"], + "name": "SaslAuthenticateRequest", + // Version 1 is the same as version 0. + // Version 2 adds flexible version support + "validVersions": "0-2", + "flexibleVersions": "2+", + "fields": [ + { "name": "AuthBytes", "type": "bytes", "versions": "0+", + "about": "The SASL authentication bytes from the client, as defined by the SASL mechanism." } + ] +} diff --git a/clients/src/main/resources/common/message/SaslAuthenticateResponse.json b/clients/src/main/resources/common/message/SaslAuthenticateResponse.json new file mode 100644 index 0000000..0e26a51 --- /dev/null +++ b/clients/src/main/resources/common/message/SaslAuthenticateResponse.json @@ -0,0 +1,34 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 36, + "type": "response", + "name": "SaslAuthenticateResponse", + // Version 1 adds the session lifetime. + // Version 2 adds flexible version support + "validVersions": "0-2", + "flexibleVersions": "2+", + "fields": [ + { "name": "ErrorCode", "type": "int16", "versions": "0+", + "about": "The error code, or 0 if there was no error." }, + { "name": "ErrorMessage", "type": "string", "versions": "0+", "nullableVersions": "0+", + "about": "The error message, or null if there was no error." }, + { "name": "AuthBytes", "type": "bytes", "versions": "0+", + "about": "The SASL authentication bytes from the server, as defined by the SASL mechanism." 
}, + { "name": "SessionLifetimeMs", "type": "int64", "versions": "1+", "default": "0", "ignorable": true, + "about": "The session lifetime in milliseconds." } + ] +} diff --git a/clients/src/main/resources/common/message/SaslHandshakeRequest.json b/clients/src/main/resources/common/message/SaslHandshakeRequest.json new file mode 100644 index 0000000..a370a80 --- /dev/null +++ b/clients/src/main/resources/common/message/SaslHandshakeRequest.json @@ -0,0 +1,31 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 17, + "type": "request", + "listeners": ["zkBroker", "broker", "controller"], + "name": "SaslHandshakeRequest", + // Version 1 supports SASL_AUTHENTICATE. + // NOTE: Version cannot be easily bumped due to incorrect + // client negotiation for clients <= 2.4. + // See https://issues.apache.org/jira/browse/KAFKA-9577 + "validVersions": "0-1", + "flexibleVersions": "none", + "fields": [ + { "name": "Mechanism", "type": "string", "versions": "0+", + "about": "The SASL mechanism chosen by the client." } + ] +} diff --git a/clients/src/main/resources/common/message/SaslHandshakeResponse.json b/clients/src/main/resources/common/message/SaslHandshakeResponse.json new file mode 100644 index 0000000..a1567c6 --- /dev/null +++ b/clients/src/main/resources/common/message/SaslHandshakeResponse.json @@ -0,0 +1,32 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 17, + "type": "response", + "name": "SaslHandshakeResponse", + // Version 1 is the same as version 0. + // NOTE: Version cannot be easily bumped due to incorrect + // client negotiation for clients <= 2.4. + // See https://issues.apache.org/jira/browse/KAFKA-9577 + "validVersions": "0-1", + "flexibleVersions": "none", + "fields": [ + { "name": "ErrorCode", "type": "int16", "versions": "0+", + "about": "The error code, or 0 if there was no error." }, + { "name": "Mechanisms", "type": "[]string", "versions": "0+", + "about": "The mechanisms enabled in the server."
} + ] +} diff --git a/clients/src/main/resources/common/message/SnapshotFooterRecord.json b/clients/src/main/resources/common/message/SnapshotFooterRecord.json new file mode 100644 index 0000000..0d776b3 --- /dev/null +++ b/clients/src/main/resources/common/message/SnapshotFooterRecord.json @@ -0,0 +1,25 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "type": "data", + "name": "SnapshotFooterRecord", + "validVersions": "0", + "flexibleVersions": "0+", + "fields": [ + {"name": "Version", "type": "int16", "versions": "0+", + "about": "The version of the snapshot footer record"} + ] +} diff --git a/clients/src/main/resources/common/message/SnapshotHeaderRecord.json b/clients/src/main/resources/common/message/SnapshotHeaderRecord.json new file mode 100644 index 0000000..0a03b9c --- /dev/null +++ b/clients/src/main/resources/common/message/SnapshotHeaderRecord.json @@ -0,0 +1,27 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "type": "data", + "name": "SnapshotHeaderRecord", + "validVersions": "0", + "flexibleVersions": "0+", + "fields": [ + {"name": "Version", "type": "int16", "versions": "0+", + "about": "The version of the snapshot header record"}, + {"name": "LastContainedLogTimestamp", "type": "int64", "versions": "0+", + "about": "The append time of the last record from the log contained in this snapshot"} + ] +} diff --git a/clients/src/main/resources/common/message/StopReplicaRequest.json b/clients/src/main/resources/common/message/StopReplicaRequest.json new file mode 100644 index 0000000..b10154f --- /dev/null +++ b/clients/src/main/resources/common/message/StopReplicaRequest.json @@ -0,0 +1,67 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 5, + "type": "request", + "listeners": ["zkBroker"], + "name": "StopReplicaRequest", + // Version 1 adds the broker epoch and reorganizes the partitions to be stored + // per topic. + // + // Version 2 is the first flexible version. + // + // Version 3 adds the leader epoch per partition (KIP-570). + "validVersions": "0-3", + "flexibleVersions": "2+", + "fields": [ + { "name": "ControllerId", "type": "int32", "versions": "0+", "entityType": "brokerId", + "about": "The controller id." }, + { "name": "ControllerEpoch", "type": "int32", "versions": "0+", + "about": "The controller epoch." }, + { "name": "BrokerEpoch", "type": "int64", "versions": "1+", "default": "-1", "ignorable": true, + "about": "The broker epoch." }, + { "name": "DeletePartitions", "type": "bool", "versions": "0-2", + "about": "Whether these partitions should be deleted." }, + { "name": "UngroupedPartitions", "type": "[]StopReplicaPartitionV0", "versions": "0", + "about": "The partitions to stop.", "fields": [ + { "name": "TopicName", "type": "string", "versions": "0", "entityType": "topicName", + "about": "The topic name." }, + { "name": "PartitionIndex", "type": "int32", "versions": "0", + "about": "The partition index." } + ]}, + { "name": "Topics", "type": "[]StopReplicaTopicV1", "versions": "1-2", + "about": "The topics to stop.", "fields": [ + { "name": "Name", "type": "string", "versions": "1-2", "entityType": "topicName", + "about": "The topic name." }, + { "name": "PartitionIndexes", "type": "[]int32", "versions": "1-2", + "about": "The partition indexes." } + ]}, + { "name": "TopicStates", "type": "[]StopReplicaTopicState", "versions": "3+", + "about": "Each topic.", "fields": [ + { "name": "TopicName", "type": "string", "versions": "3+", "entityType": "topicName", + "about": "The topic name." }, + { "name": "PartitionStates", "type": "[]StopReplicaPartitionState", "versions": "3+", + "about": "The state of each partition", "fields": [ + { "name": "PartitionIndex", "type": "int32", "versions": "3+", + "about": "The partition index." }, + { "name": "LeaderEpoch", "type": "int32", "versions": "3+", "default": "-1", + "about": "The leader epoch." }, + { "name": "DeletePartition", "type": "bool", "versions": "3+", + "about": "Whether this partition should be deleted." } + ]} + ]} + ] +} diff --git a/clients/src/main/resources/common/message/StopReplicaResponse.json b/clients/src/main/resources/common/message/StopReplicaResponse.json new file mode 100644 index 0000000..64b355e --- /dev/null +++ b/clients/src/main/resources/common/message/StopReplicaResponse.json @@ -0,0 +1,40 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 5, + "type": "response", + "name": "StopReplicaResponse", + // Version 1 is the same as version 0. + // + // Version 2 is the first flexible version. + // + // Version 3 returns FENCED_LEADER_EPOCH if the epoch of the leader is stale (KIP-570). + "validVersions": "0-3", + "flexibleVersions": "2+", + "fields": [ + { "name": "ErrorCode", "type": "int16", "versions": "0+", + "about": "The top-level error code, or 0 if there was no top-level error." }, + { "name": "PartitionErrors", "type": "[]StopReplicaPartitionError", "versions": "0+", + "about": "The responses for each partition.", "fields": [ + { "name": "TopicName", "type": "string", "versions": "0+", "entityType": "topicName", + "about": "The topic name." }, + { "name": "PartitionIndex", "type": "int32", "versions": "0+", + "about": "The partition index." }, + { "name": "ErrorCode", "type": "int16", "versions": "0+", + "about": "The partition error code, or 0 if there was no partition error." } + ]} + ] +} diff --git a/clients/src/main/resources/common/message/SyncGroupRequest.json b/clients/src/main/resources/common/message/SyncGroupRequest.json new file mode 100644 index 0000000..5525844 --- /dev/null +++ b/clients/src/main/resources/common/message/SyncGroupRequest.json @@ -0,0 +1,56 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 14, + "type": "request", + "listeners": ["zkBroker", "broker"], + "name": "SyncGroupRequest", + // Versions 1 and 2 are the same as version 0. + // + // Starting from version 3, we add a new field called groupInstanceId to indicate member identity across restarts. + // + // Version 4 is the first flexible version. + // + // Starting from version 5, the client sends the Protocol Type and the Protocol Name + // to the broker (KIP-559). The broker will reject the request if they are inconsistent + // with the Type and Name known by the broker. + "validVersions": "0-5", + "flexibleVersions": "4+", + "fields": [ + { "name": "GroupId", "type": "string", "versions": "0+", "entityType": "groupId", + "about": "The unique group identifier." }, + { "name": "GenerationId", "type": "int32", "versions": "0+", + "about": "The generation of the group." }, + { "name": "MemberId", "type": "string", "versions": "0+", + "about": "The member ID assigned by the group." 
}, + { "name": "GroupInstanceId", "type": "string", "versions": "3+", + "nullableVersions": "3+", "default": "null", + "about": "The unique identifier of the consumer instance provided by end user." }, + { "name": "ProtocolType", "type": "string", "versions": "5+", + "nullableVersions": "5+", "default": "null", "ignorable": true, + "about": "The group protocol type." }, + { "name": "ProtocolName", "type": "string", "versions": "5+", + "nullableVersions": "5+", "default": "null", "ignorable": true, + "about": "The group protocol name." }, + { "name": "Assignments", "type": "[]SyncGroupRequestAssignment", "versions": "0+", + "about": "Each assignment.", "fields": [ + { "name": "MemberId", "type": "string", "versions": "0+", + "about": "The ID of the member to assign." }, + { "name": "Assignment", "type": "bytes", "versions": "0+", + "about": "The member assignment." } + ]} + ] +} diff --git a/clients/src/main/resources/common/message/SyncGroupResponse.json b/clients/src/main/resources/common/message/SyncGroupResponse.json new file mode 100644 index 0000000..4aa17e0 --- /dev/null +++ b/clients/src/main/resources/common/message/SyncGroupResponse.json @@ -0,0 +1,46 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 14, + "type": "response", + "name": "SyncGroupResponse", + // Version 1 adds throttle time. + // + // Starting in version 2, on quota violation, brokers send out responses before throttling. + // + // Starting from version 3, syncGroupRequest supports a new field called groupInstanceId to indicate member identity across restarts. + // + // Version 4 is the first flexible version. + // + // Starting from version 5, the broker sends back the Protocol Type and the Protocol Name + // to the client (KIP-559). + "validVersions": "0-5", + "flexibleVersions": "4+", + "fields": [ + { "name": "ThrottleTimeMs", "type": "int32", "versions": "1+", "ignorable": true, + "about": "The duration in milliseconds for which the request was throttled due to a quota violation, or zero if the request did not violate any quota." }, + { "name": "ErrorCode", "type": "int16", "versions": "0+", + "about": "The error code, or 0 if there was no error." }, + { "name": "ProtocolType", "type": "string", "versions": "5+", + "nullableVersions": "5+", "default": "null", "ignorable": true, + "about": "The group protocol type." }, + { "name": "ProtocolName", "type": "string", "versions": "5+", + "nullableVersions": "5+", "default": "null", "ignorable": true, + "about": "The group protocol name." }, + { "name": "Assignment", "type": "bytes", "versions": "0+", + "about": "The member assignment." 
} + ] +} diff --git a/clients/src/main/resources/common/message/TxnOffsetCommitRequest.json b/clients/src/main/resources/common/message/TxnOffsetCommitRequest.json new file mode 100644 index 0000000..a832ef7 --- /dev/null +++ b/clients/src/main/resources/common/message/TxnOffsetCommitRequest.json @@ -0,0 +1,61 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 28, + "type": "request", + "listeners": ["zkBroker", "broker"], + "name": "TxnOffsetCommitRequest", + // Version 1 is the same as version 0. + // + // Version 2 adds the committed leader epoch. + // + // Version 3 adds the member.id, group.instance.id and generation.id. + "validVersions": "0-3", + "flexibleVersions": "3+", + "fields": [ + { "name": "TransactionalId", "type": "string", "versions": "0+", "entityType": "transactionalId", + "about": "The ID of the transaction." }, + { "name": "GroupId", "type": "string", "versions": "0+", "entityType": "groupId", + "about": "The ID of the group." }, + { "name": "ProducerId", "type": "int64", "versions": "0+", "entityType": "producerId", + "about": "The current producer ID in use by the transactional ID." }, + { "name": "ProducerEpoch", "type": "int16", "versions": "0+", + "about": "The current epoch associated with the producer ID." }, + { "name": "GenerationId", "type": "int32", "versions": "3+", "default": "-1", + "about": "The generation of the consumer." }, + { "name": "MemberId", "type": "string", "versions": "3+", "default": "", + "about": "The member ID assigned by the group coordinator." }, + { "name": "GroupInstanceId", "type": "string", "versions": "3+", + "nullableVersions": "3+", "default": "null", + "about": "The unique identifier of the consumer instance provided by end user." }, + { "name": "Topics", "type" : "[]TxnOffsetCommitRequestTopic", "versions": "0+", + "about": "Each topic that we want to commit offsets for.", "fields": [ + { "name": "Name", "type": "string", "versions": "0+", "entityType": "topicName", + "about": "The topic name." }, + { "name": "Partitions", "type": "[]TxnOffsetCommitRequestPartition", "versions": "0+", + "about": "The partitions inside the topic that we want to commit offsets for.", "fields": [ + { "name": "PartitionIndex", "type": "int32", "versions": "0+", + "about": "The index of the partition within the topic." }, + { "name": "CommittedOffset", "type": "int64", "versions": "0+", + "about": "The message offset to be committed." }, + { "name": "CommittedLeaderEpoch", "type": "int32", "versions": "2+", "default": "-1", "ignorable": true, + "about": "The leader epoch of the last consumed record." }, + { "name": "CommittedMetadata", "type": "string", "versions": "0+", "nullableVersions": "0+", + "about": "Any associated metadata the client wants to keep."
} + ]} + ]} + ] +} diff --git a/clients/src/main/resources/common/message/TxnOffsetCommitResponse.json b/clients/src/main/resources/common/message/TxnOffsetCommitResponse.json new file mode 100644 index 0000000..96b03a0 --- /dev/null +++ b/clients/src/main/resources/common/message/TxnOffsetCommitResponse.json @@ -0,0 +1,43 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 28, + "type": "response", + "name": "TxnOffsetCommitResponse", + // Starting in version 1, on quota violation, brokers send out responses before throttling. + // + // Version 2 is the same as version 1. + // + // Version 3 adds illegal generation, fenced instance id, and unknown member id errors. + "validVersions": "0-3", + "flexibleVersions": "3+", + "fields": [ + { "name": "ThrottleTimeMs", "type": "int32", "versions": "0+", + "about": "The duration in milliseconds for which the request was throttled due to a quota violation, or zero if the request did not violate any quota." }, + { "name": "Topics", "type": "[]TxnOffsetCommitResponseTopic", "versions": "0+", + "about": "The responses for each topic.", "fields": [ + { "name": "Name", "type": "string", "versions": "0+", "entityType": "topicName", + "about": "The topic name." }, + { "name": "Partitions", "type": "[]TxnOffsetCommitResponsePartition", "versions": "0+", + "about": "The responses for each partition in the topic.", "fields": [ + { "name": "PartitionIndex", "type": "int32", "versions": "0+", + "about": "The partition index." }, + { "name": "ErrorCode", "type": "int16", "versions": "0+", + "about": "The error code, or 0 if there was no error." } + ]} + ]} + ] +} diff --git a/clients/src/main/resources/common/message/UnregisterBrokerRequest.json b/clients/src/main/resources/common/message/UnregisterBrokerRequest.json new file mode 100644 index 0000000..4fb8d8d --- /dev/null +++ b/clients/src/main/resources/common/message/UnregisterBrokerRequest.json @@ -0,0 +1,27 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +{ + "apiKey": 64, + "type": "request", + "listeners": ["controller"], + "name": "UnregisterBrokerRequest", + "validVersions": "0", + "flexibleVersions": "0+", + "fields": [ + { "name": "BrokerId", "type": "int32", "versions": "0+", "entityType": "brokerId", + "about": "The broker ID to unregister." } + ] +} diff --git a/clients/src/main/resources/common/message/UnregisterBrokerResponse.json b/clients/src/main/resources/common/message/UnregisterBrokerResponse.json new file mode 100644 index 0000000..3a11c1a --- /dev/null +++ b/clients/src/main/resources/common/message/UnregisterBrokerResponse.json @@ -0,0 +1,30 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 64, + "type": "response", + "name": "UnregisterBrokerResponse", + "validVersions": "0", + "flexibleVersions": "0+", + "fields": [ + { "name": "ThrottleTimeMs", "type": "int32", "versions": "0+", + "about": "Duration in milliseconds for which the request was throttled due to a quota violation, or zero if the request did not violate any quota." }, + { "name": "ErrorCode", "type": "int16", "versions": "0+", + "about": "The error code, or 0 if there was no error." }, + { "name": "ErrorMessage", "type": "string", "versions": "0+", "nullableVersions": "0+", + "about": "The top-level error message, or `null` if there was no top-level error." } + ] +} diff --git a/clients/src/main/resources/common/message/UpdateFeaturesRequest.json b/clients/src/main/resources/common/message/UpdateFeaturesRequest.json new file mode 100644 index 0000000..2b31813 --- /dev/null +++ b/clients/src/main/resources/common/message/UpdateFeaturesRequest.json @@ -0,0 +1,36 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 57, + "type": "request", + "listeners": ["zkBroker", "broker"], + "name": "UpdateFeaturesRequest", + "validVersions": "0", + "flexibleVersions": "0+", + "fields": [ + { "name": "timeoutMs", "type": "int32", "versions": "0+", "default": "60000", + "about": "How long to wait in milliseconds before timing out the request." 
}, + { "name": "FeatureUpdates", "type": "[]FeatureUpdateKey", "versions": "0+", + "about": "The list of updates to finalized features.", "fields": [ + {"name": "Feature", "type": "string", "versions": "0+", "mapKey": true, + "about": "The name of the finalized feature to be updated."}, + {"name": "MaxVersionLevel", "type": "int16", "versions": "0+", + "about": "The new maximum version level for the finalized feature. A value >= 1 is valid. A value < 1, is special, and can be used to request the deletion of the finalized feature."}, + {"name": "AllowDowngrade", "type": "bool", "versions": "0+", + "about": "When set to true, the finalized feature version level is allowed to be downgraded/deleted. The downgrade request will fail if the new maximum version level is a value that's not lower than the existing maximum finalized version level."} + ]} + ] +} diff --git a/clients/src/main/resources/common/message/UpdateFeaturesResponse.json b/clients/src/main/resources/common/message/UpdateFeaturesResponse.json new file mode 100644 index 0000000..63e84ff --- /dev/null +++ b/clients/src/main/resources/common/message/UpdateFeaturesResponse.json @@ -0,0 +1,39 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 57, + "type": "response", + "name": "UpdateFeaturesResponse", + "validVersions": "0", + "flexibleVersions": "0+", + "fields": [ + { "name": "ThrottleTimeMs", "type": "int32", "versions": "0+", + "about": "The duration in milliseconds for which the request was throttled due to a quota violation, or zero if the request did not violate any quota." }, + { "name": "ErrorCode", "type": "int16", "versions": "0+", + "about": "The top-level error code, or `0` if there was no top-level error." }, + { "name": "ErrorMessage", "type": "string", "versions": "0+", "nullableVersions": "0+", + "about": "The top-level error message, or `null` if there was no top-level error." }, + { "name": "Results", "type": "[]UpdatableFeatureResult", "versions": "0+", + "about": "Results for each feature update.", "fields": [ + { "name": "Feature", "type": "string", "versions": "0+", "mapKey": true, + "about": "The name of the finalized feature."}, + { "name": "ErrorCode", "type": "int16", "versions": "0+", + "about": "The feature update error code or `0` if the feature update succeeded." }, + { "name": "ErrorMessage", "type": "string", "versions": "0+", "nullableVersions": "0+", + "about": "The feature update error, or `null` if the feature update succeeded." 
} + ]} + ] +} diff --git a/clients/src/main/resources/common/message/UpdateMetadataRequest.json b/clients/src/main/resources/common/message/UpdateMetadataRequest.json new file mode 100644 index 0000000..5f397a9 --- /dev/null +++ b/clients/src/main/resources/common/message/UpdateMetadataRequest.json @@ -0,0 +1,96 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 6, + "type": "request", + "listeners": ["zkBroker"], + "name": "UpdateMetadataRequest", + // Version 1 allows specifying multiple endpoints for each broker. + // + // Version 2 adds the rack. + // + // Version 3 adds the listener name. + // + // Version 4 adds the offline replica list. + // + // Version 5 adds the broker epoch field and normalizes partitions by topic. + // Version 7 adds topicId + "validVersions": "0-7", + "flexibleVersions": "6+", + "fields": [ + { "name": "ControllerId", "type": "int32", "versions": "0+", "entityType": "brokerId", + "about": "The controller id." }, + { "name": "ControllerEpoch", "type": "int32", "versions": "0+", + "about": "The controller epoch." }, + { "name": "BrokerEpoch", "type": "int64", "versions": "5+", "ignorable": true, "default": "-1", + "about": "The broker epoch." }, + { "name": "UngroupedPartitionStates", "type": "[]UpdateMetadataPartitionState", "versions": "0-4", + "about": "In older versions of this RPC, each partition that we would like to update." }, + { "name": "TopicStates", "type": "[]UpdateMetadataTopicState", "versions": "5+", + "about": "In newer versions of this RPC, each topic that we would like to update.", "fields": [ + { "name": "TopicName", "type": "string", "versions": "5+", "entityType": "topicName", + "about": "The topic name." }, + { "name": "TopicId", "type": "uuid", "versions": "7+", "ignorable": true, "about": "The topic id."}, + { "name": "PartitionStates", "type": "[]UpdateMetadataPartitionState", "versions": "5+", + "about": "The partition that we would like to update." } + ]}, + { "name": "LiveBrokers", "type": "[]UpdateMetadataBroker", "versions": "0+", "fields": [ + { "name": "Id", "type": "int32", "versions": "0+", "entityType": "brokerId", + "about": "The broker id." }, + // Version 0 of the protocol only allowed specifying a single host and + // port per broker, rather than an array of endpoints. + { "name": "V0Host", "type": "string", "versions": "0", "ignorable": true, + "about": "The broker hostname." }, + { "name": "V0Port", "type": "int32", "versions": "0", "ignorable": true, + "about": "The broker port." 
}, + { "name": "Endpoints", "type": "[]UpdateMetadataEndpoint", "versions": "1+", "ignorable": true, + "about": "The broker endpoints.", "fields": [ + { "name": "Port", "type": "int32", "versions": "1+", + "about": "The port of this endpoint" }, + { "name": "Host", "type": "string", "versions": "1+", + "about": "The hostname of this endpoint" }, + { "name": "Listener", "type": "string", "versions": "3+", "ignorable": true, + "about": "The listener name." }, + { "name": "SecurityProtocol", "type": "int16", "versions": "1+", + "about": "The security protocol type." } + ]}, + { "name": "Rack", "type": "string", "versions": "2+", "nullableVersions": "0+", "ignorable": true, + "about": "The rack which this broker belongs to." } + ]} + ], + "commonStructs": [ + { "name": "UpdateMetadataPartitionState", "versions": "0+", "fields": [ + { "name": "TopicName", "type": "string", "versions": "0-4", "entityType": "topicName", "ignorable": true, + "about": "In older versions of this RPC, the topic name." }, + { "name": "PartitionIndex", "type": "int32", "versions": "0+", + "about": "The partition index." }, + { "name": "ControllerEpoch", "type": "int32", "versions": "0+", + "about": "The controller epoch." }, + { "name": "Leader", "type": "int32", "versions": "0+", "entityType": "brokerId", + "about": "The ID of the broker which is the current partition leader." }, + { "name": "LeaderEpoch", "type": "int32", "versions": "0+", + "about": "The leader epoch of this partition." }, + { "name": "Isr", "type": "[]int32", "versions": "0+", "entityType": "brokerId", + "about": "The brokers which are in the ISR for this partition." }, + { "name": "ZkVersion", "type": "int32", "versions": "0+", + "about": "The Zookeeper version." }, + { "name": "Replicas", "type": "[]int32", "versions": "0+", "entityType": "brokerId", + "about": "All the replicas of this partition." }, + { "name": "OfflineReplicas", "type": "[]int32", "versions": "4+", "entityType": "brokerId", "ignorable": true, + "about": "The replicas of this partition which are offline." } + ]} + ] +} diff --git a/clients/src/main/resources/common/message/UpdateMetadataResponse.json b/clients/src/main/resources/common/message/UpdateMetadataResponse.json new file mode 100644 index 0000000..6220322 --- /dev/null +++ b/clients/src/main/resources/common/message/UpdateMetadataResponse.json @@ -0,0 +1,27 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 6, + "type": "response", + "name": "UpdateMetadataResponse", + // Versions 1, 2, 3, 4, and 5 are the same as version 0 + "validVersions": "0-7", + "flexibleVersions": "6+", + "fields": [ + { "name": "ErrorCode", "type": "int16", "versions": "0+", + "about": "The error code, or 0 if there was no error." 
} + ] +} diff --git a/clients/src/main/resources/common/message/VoteRequest.json b/clients/src/main/resources/common/message/VoteRequest.json new file mode 100644 index 0000000..35583a7 --- /dev/null +++ b/clients/src/main/resources/common/message/VoteRequest.json @@ -0,0 +1,47 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 52, + "type": "request", + "listeners": ["controller"], + "name": "VoteRequest", + "validVersions": "0", + "flexibleVersions": "0+", + "fields": [ + { "name": "ClusterId", "type": "string", "versions": "0+", + "nullableVersions": "0+", "default": "null"}, + { "name": "Topics", "type": "[]TopicData", + "versions": "0+", "fields": [ + { "name": "TopicName", "type": "string", "versions": "0+", "entityType": "topicName", + "about": "The topic name." }, + { "name": "Partitions", "type": "[]PartitionData", + "versions": "0+", "fields": [ + { "name": "PartitionIndex", "type": "int32", "versions": "0+", + "about": "The partition index." }, + { "name": "CandidateEpoch", "type": "int32", "versions": "0+", + "about": "The bumped epoch of the candidate sending the request"}, + { "name": "CandidateId", "type": "int32", "versions": "0+", "entityType": "brokerId", + "about": "The ID of the voter sending the request"}, + { "name": "LastOffsetEpoch", "type": "int32", "versions": "0+", + "about": "The epoch of the last record written to the metadata log"}, + { "name": "LastOffset", "type": "int64", "versions": "0+", + "about": "The offset of the last record written to the metadata log"} + ] + } + ] + } + ] +} diff --git a/clients/src/main/resources/common/message/VoteResponse.json b/clients/src/main/resources/common/message/VoteResponse.json new file mode 100644 index 0000000..b92d007 --- /dev/null +++ b/clients/src/main/resources/common/message/VoteResponse.json @@ -0,0 +1,45 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +{ + "apiKey": 52, + "type": "response", + "name": "VoteResponse", + "validVersions": "0", + "flexibleVersions": "0+", + "fields": [ + { "name": "ErrorCode", "type": "int16", "versions": "0+", + "about": "The top level error code."}, + { "name": "Topics", "type": "[]TopicData", + "versions": "0+", "fields": [ + { "name": "TopicName", "type": "string", "versions": "0+", "entityType": "topicName", + "about": "The topic name." }, + { "name": "Partitions", "type": "[]PartitionData", + "versions": "0+", "fields": [ + { "name": "PartitionIndex", "type": "int32", "versions": "0+", + "about": "The partition index." }, + { "name": "ErrorCode", "type": "int16", "versions": "0+"}, + { "name": "LeaderId", "type": "int32", "versions": "0+", "entityType": "brokerId", + "about": "The ID of the current leader or -1 if the leader is unknown."}, + { "name": "LeaderEpoch", "type": "int32", "versions": "0+", + "about": "The latest known leader epoch"}, + { "name": "VoteGranted", "type": "bool", "versions": "0+", + "about": "True if the vote was granted and false otherwise"} + ] + } + ] + } + ] +} diff --git a/clients/src/main/resources/common/message/WriteTxnMarkersRequest.json b/clients/src/main/resources/common/message/WriteTxnMarkersRequest.json new file mode 100644 index 0000000..9e29fb3 --- /dev/null +++ b/clients/src/main/resources/common/message/WriteTxnMarkersRequest.json @@ -0,0 +1,44 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 27, + "type": "request", + "listeners": ["zkBroker", "broker"], + "name": "WriteTxnMarkersRequest", + // Version 1 enables flexible versions. + "validVersions": "0-1", + "flexibleVersions": "1+", + "fields": [ + { "name": "Markers", "type": "[]WritableTxnMarker", "versions": "0+", + "about": "The transaction markers to be written.", "fields": [ + { "name": "ProducerId", "type": "int64", "versions": "0+", "entityType": "producerId", + "about": "The current producer ID."}, + { "name": "ProducerEpoch", "type": "int16", "versions": "0+", + "about": "The current epoch associated with the producer ID." }, + { "name": "TransactionResult", "type": "bool", "versions": "0+", + "about": "The result of the transaction to write to the partitions (false = ABORT, true = COMMIT)." }, + { "name": "Topics", "type": "[]WritableTxnMarkerTopic", "versions": "0+", + "about": "Each topic that we want to write transaction marker(s) for.", "fields": [ + { "name": "Name", "type": "string", "versions": "0+", "entityType": "topicName", + "about": "The topic name." }, + { "name": "PartitionIndexes", "type": "[]int32", "versions": "0+", + "about": "The indexes of the partitions to write transaction markers for." 
} + ]}, + { "name": "CoordinatorEpoch", "type": "int32", "versions": "0+", + "about": "Epoch associated with the transaction state partition hosted by this transaction coordinator" } + ]} + ] +} diff --git a/clients/src/main/resources/common/message/WriteTxnMarkersResponse.json b/clients/src/main/resources/common/message/WriteTxnMarkersResponse.json new file mode 100644 index 0000000..59b8d66 --- /dev/null +++ b/clients/src/main/resources/common/message/WriteTxnMarkersResponse.json @@ -0,0 +1,42 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 27, + "type": "response", + "name": "WriteTxnMarkersResponse", + "validVersions": "0-1", + // Version 1 enables flexible versions. + "flexibleVersions": "1+", + "fields": [ + { "name": "Markers", "type": "[]WritableTxnMarkerResult", "versions": "0+", + "about": "The results for writing makers.", "fields": [ + { "name": "ProducerId", "type": "int64", "versions": "0+", "entityType": "producerId", + "about": "The current producer ID in use by the transactional ID." }, + { "name": "Topics", "type": "[]WritableTxnMarkerTopicResult", "versions": "0+", + "about": "The results by topic.", "fields": [ + { "name": "Name", "type": "string", "versions": "0+", "entityType": "topicName", + "about": "The topic name." }, + { "name": "Partitions", "type": "[]WritableTxnMarkerPartitionResult", "versions": "0+", + "about": "The results by partition.", "fields": [ + { "name": "PartitionIndex", "type": "int32", "versions": "0+", + "about": "The partition index." }, + { "name": "ErrorCode", "type": "int16", "versions": "0+", + "about": "The error code, or 0 if there was no error." } + ]} + ]} + ]} + ] +} diff --git a/clients/src/test/java/org/apache/kafka/clients/AddressChangeHostResolver.java b/clients/src/test/java/org/apache/kafka/clients/AddressChangeHostResolver.java new file mode 100644 index 0000000..28f9c88 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/clients/AddressChangeHostResolver.java @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients; + +import java.net.InetAddress; + +class AddressChangeHostResolver implements HostResolver { + private boolean useNewAddresses; + private InetAddress[] initialAddresses; + private InetAddress[] newAddresses; + private int resolutionCount = 0; + + public AddressChangeHostResolver(InetAddress[] initialAddresses, InetAddress[] newAddresses) { + this.initialAddresses = initialAddresses; + this.newAddresses = newAddresses; + } + + @Override + public InetAddress[] resolve(String host) { + ++resolutionCount; + return useNewAddresses ? newAddresses : initialAddresses; + } + + public void changeAddresses() { + useNewAddresses = true; + } + + public boolean useNewAddresses() { + return useNewAddresses; + } + + public int resolutionCount() { + return resolutionCount; + } +} diff --git a/clients/src/test/java/org/apache/kafka/clients/ApiVersionsTest.java b/clients/src/test/java/org/apache/kafka/clients/ApiVersionsTest.java new file mode 100644 index 0000000..206e95e --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/clients/ApiVersionsTest.java @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.clients; + +import org.apache.kafka.common.message.ApiVersionsResponseData; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.record.RecordBatch; +import org.junit.jupiter.api.Test; + +import java.util.Collections; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class ApiVersionsTest { + + @Test + public void testMaxUsableProduceMagic() { + ApiVersions apiVersions = new ApiVersions(); + assertEquals(RecordBatch.CURRENT_MAGIC_VALUE, apiVersions.maxUsableProduceMagic()); + + apiVersions.update("0", NodeApiVersions.create()); + assertEquals(RecordBatch.CURRENT_MAGIC_VALUE, apiVersions.maxUsableProduceMagic()); + + apiVersions.update("1", NodeApiVersions.create(ApiKeys.PRODUCE.id, (short) 0, (short) 2)); + assertEquals(RecordBatch.MAGIC_VALUE_V1, apiVersions.maxUsableProduceMagic()); + + apiVersions.remove("1"); + assertEquals(RecordBatch.CURRENT_MAGIC_VALUE, apiVersions.maxUsableProduceMagic()); + } + + @Test + public void testMaxUsableProduceMagicWithRaftController() { + ApiVersions apiVersions = new ApiVersions(); + assertEquals(RecordBatch.CURRENT_MAGIC_VALUE, apiVersions.maxUsableProduceMagic()); + + // something that doesn't support PRODUCE, which is the case with Raft-based controllers + apiVersions.update("2", new NodeApiVersions(Collections.singleton( + new ApiVersionsResponseData.ApiVersion() + .setApiKey(ApiKeys.FETCH.id) + .setMinVersion((short) 0) + .setMaxVersion((short) 2)))); + assertEquals(RecordBatch.CURRENT_MAGIC_VALUE, apiVersions.maxUsableProduceMagic()); + } +} diff --git a/clients/src/test/java/org/apache/kafka/clients/ClientUtilsTest.java b/clients/src/test/java/org/apache/kafka/clients/ClientUtilsTest.java new file mode 100644 index 0000000..7ef55eb --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/clients/ClientUtilsTest.java @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.clients; + +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.net.UnknownHostException; +import java.util.Collections; +import java.util.List; +import java.util.stream.Collectors; +import org.apache.kafka.common.config.ConfigException; +import org.junit.jupiter.api.Test; + +import static java.util.Arrays.asList; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class ClientUtilsTest { + + private final HostResolver hostResolver = new DefaultHostResolver(); + + @Test + public void testParseAndValidateAddresses() { + checkWithoutLookup("127.0.0.1:8000"); + checkWithoutLookup("localhost:8080"); + checkWithoutLookup("[::1]:8000"); + checkWithoutLookup("[2001:db8:85a3:8d3:1319:8a2e:370:7348]:1234", "localhost:10000"); + List<InetSocketAddress> validatedAddresses = checkWithoutLookup("localhost:10000"); + assertEquals(1, validatedAddresses.size()); + InetSocketAddress onlyAddress = validatedAddresses.get(0); + assertEquals("localhost", onlyAddress.getHostName()); + assertEquals(10000, onlyAddress.getPort()); + } + + @Test + public void testParseAndValidateAddressesWithReverseLookup() { + checkWithoutLookup("127.0.0.1:8000"); + checkWithoutLookup("localhost:8080"); + checkWithoutLookup("[::1]:8000"); + checkWithoutLookup("[2001:db8:85a3:8d3:1319:8a2e:370:7348]:1234", "localhost:10000"); + + // With lookup of example.com, either one or two addresses are expected depending on + // whether ipv4 and ipv6 are enabled + List<InetSocketAddress> validatedAddresses = checkWithLookup(asList("example.com:10000")); + assertTrue(validatedAddresses.size() >= 1, "Unexpected addresses " + validatedAddresses); + List<String> validatedHostNames = validatedAddresses.stream().map(InetSocketAddress::getHostName) + .collect(Collectors.toList()); + List<String> expectedHostNames = asList("93.184.216.34", "2606:2800:220:1:248:1893:25c8:1946"); + assertTrue(expectedHostNames.containsAll(validatedHostNames), "Unexpected addresses " + validatedHostNames); + validatedAddresses.forEach(address -> assertEquals(10000, address.getPort())); + } + + @Test + public void testInvalidConfig() { + assertThrows(IllegalArgumentException.class, + () -> ClientUtils.parseAndValidateAddresses(Collections.singletonList("localhost:10000"), "random.value")); + } + + @Test + public void testNoPort() { + assertThrows(ConfigException.class, () -> checkWithoutLookup("127.0.0.1")); + } + + @Test + public void testOnlyBadHostname() { + assertThrows(ConfigException.class, () -> checkWithoutLookup("some.invalid.hostname.foo.bar.local:9999")); + } + + @Test + public void testFilterPreferredAddresses() throws UnknownHostException { + InetAddress ipv4 = InetAddress.getByName("192.0.0.1"); + InetAddress ipv6 = InetAddress.getByName("::1"); + + InetAddress[] ipv4First = new InetAddress[]{ipv4, ipv6, ipv4}; + List<InetAddress> result = ClientUtils.filterPreferredAddresses(ipv4First); + assertTrue(result.contains(ipv4)); + assertFalse(result.contains(ipv6)); + assertEquals(2, result.size()); + + InetAddress[] ipv6First = new InetAddress[]{ipv6, ipv4, ipv4}; + result = ClientUtils.filterPreferredAddresses(ipv6First); + assertTrue(result.contains(ipv6)); + assertFalse(result.contains(ipv4)); + assertEquals(1, result.size()); + } + + @Test + public void testResolveUnknownHostException() { + assertThrows(UnknownHostException.class, + () ->
ClientUtils.resolve("some.invalid.hostname.foo.bar.local", hostResolver)); + } + + @Test + public void testResolveDnsLookup() throws UnknownHostException { + InetAddress[] addresses = new InetAddress[] { + InetAddress.getByName("198.51.100.0"), InetAddress.getByName("198.51.100.5") + }; + HostResolver hostResolver = new AddressChangeHostResolver(addresses, addresses); + assertEquals(asList(addresses), ClientUtils.resolve("kafka.apache.org", hostResolver)); + } + + private List checkWithoutLookup(String... url) { + return ClientUtils.parseAndValidateAddresses(asList(url), ClientDnsLookup.USE_ALL_DNS_IPS); + } + + private List checkWithLookup(List url) { + return ClientUtils.parseAndValidateAddresses(url, ClientDnsLookup.RESOLVE_CANONICAL_BOOTSTRAP_SERVERS_ONLY); + } + +} diff --git a/clients/src/test/java/org/apache/kafka/clients/ClusterConnectionStatesTest.java b/clients/src/test/java/org/apache/kafka/clients/ClusterConnectionStatesTest.java new file mode 100644 index 0000000..72cc123 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/clients/ClusterConnectionStatesTest.java @@ -0,0 +1,429 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.clients; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotSame; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertSame; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; + +import java.net.InetAddress; +import java.net.UnknownHostException; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import org.apache.kafka.common.errors.AuthenticationException; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.common.utils.MockTime; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +public class ClusterConnectionStatesTest { + + private static ArrayList<InetAddress> initialAddresses; + private static ArrayList<InetAddress> newAddresses; + + static { + try { + initialAddresses = new ArrayList<>(Arrays.asList( + InetAddress.getByName("10.200.20.100"), + InetAddress.getByName("10.200.20.101"), + InetAddress.getByName("10.200.20.102") + )); + newAddresses = new ArrayList<>(Arrays.asList( + InetAddress.getByName("10.200.20.103"), + InetAddress.getByName("10.200.20.104"), + InetAddress.getByName("10.200.20.105") + )); + } catch (UnknownHostException e) { + fail("Attempted to create an invalid InetAddress, this should not happen"); + } + } + + private final MockTime time = new MockTime(); + private final long reconnectBackoffMs = 10 * 1000; + private final long reconnectBackoffMax = 60 * 1000; + private final long connectionSetupTimeoutMs = 10 * 1000; + private final long connectionSetupTimeoutMaxMs = 127 * 1000; + private final int reconnectBackoffExpBase = ClusterConnectionStates.RECONNECT_BACKOFF_EXP_BASE; + private final double reconnectBackoffJitter = ClusterConnectionStates.RECONNECT_BACKOFF_JITTER; + private final int connectionSetupTimeoutExpBase = ClusterConnectionStates.CONNECTION_SETUP_TIMEOUT_EXP_BASE; + private final double connectionSetupTimeoutJitter = ClusterConnectionStates.CONNECTION_SETUP_TIMEOUT_JITTER; + private final String nodeId1 = "1001"; + private final String nodeId2 = "2002"; + private final String nodeId3 = "3003"; + private final String hostTwoIps = "multiple.ip.address"; + private ClusterConnectionStates connectionStates; + + // For testing nodes with a single IP address, use localhost and default DNS resolution + private DefaultHostResolver singleIPHostResolver = new DefaultHostResolver(); + + // For testing nodes with multiple IP addresses, mock DNS resolution to get consistent results + private AddressChangeHostResolver multipleIPHostResolver = new AddressChangeHostResolver( + initialAddresses.toArray(new InetAddress[0]), newAddresses.toArray(new InetAddress[0])); + + @BeforeEach + public void setup() { + this.connectionStates = new ClusterConnectionStates(reconnectBackoffMs, reconnectBackoffMax, + connectionSetupTimeoutMs, connectionSetupTimeoutMaxMs, new LogContext(), this.singleIPHostResolver); + } + + @Test + public void testClusterConnectionStateChanges() { + assertTrue(connectionStates.canConnect(nodeId1, time.milliseconds())); + assertEquals(0, connectionStates.connectionDelay(nodeId1, time.milliseconds())); + + // Start connecting to Node and check state + connectionStates.connecting(nodeId1, time.milliseconds(), "localhost"); + assertEquals(ConnectionState.CONNECTING, connectionStates.connectionState(nodeId1)); +
assertTrue(connectionStates.isConnecting(nodeId1)); + assertFalse(connectionStates.isReady(nodeId1, time.milliseconds())); + assertFalse(connectionStates.isBlackedOut(nodeId1, time.milliseconds())); + assertFalse(connectionStates.hasReadyNodes(time.milliseconds())); + long connectionDelay = connectionStates.connectionDelay(nodeId1, time.milliseconds()); + double connectionDelayDelta = connectionSetupTimeoutMs * connectionSetupTimeoutJitter; + assertEquals(connectionSetupTimeoutMs, connectionDelay, connectionDelayDelta); + + time.sleep(100); + + // Successful connection + connectionStates.ready(nodeId1); + assertEquals(ConnectionState.READY, connectionStates.connectionState(nodeId1)); + assertTrue(connectionStates.isReady(nodeId1, time.milliseconds())); + assertTrue(connectionStates.hasReadyNodes(time.milliseconds())); + assertFalse(connectionStates.isConnecting(nodeId1)); + assertFalse(connectionStates.isBlackedOut(nodeId1, time.milliseconds())); + assertEquals(Long.MAX_VALUE, connectionStates.connectionDelay(nodeId1, time.milliseconds())); + + time.sleep(15000); + + // Disconnected from broker + connectionStates.disconnected(nodeId1, time.milliseconds()); + assertEquals(ConnectionState.DISCONNECTED, connectionStates.connectionState(nodeId1)); + assertTrue(connectionStates.isDisconnected(nodeId1)); + assertTrue(connectionStates.isBlackedOut(nodeId1, time.milliseconds())); + assertFalse(connectionStates.isConnecting(nodeId1)); + assertFalse(connectionStates.hasReadyNodes(time.milliseconds())); + assertFalse(connectionStates.canConnect(nodeId1, time.milliseconds())); + + // After disconnecting we expect a backoff value equal to the reconnect.backoff.ms setting (plus minus 20% jitter) + double backoffTolerance = reconnectBackoffMs * reconnectBackoffJitter; + long currentBackoff = connectionStates.connectionDelay(nodeId1, time.milliseconds()); + assertEquals(reconnectBackoffMs, currentBackoff, backoffTolerance); + + time.sleep(currentBackoff + 1); + // after waiting for the current backoff value we should be allowed to connect again + assertTrue(connectionStates.canConnect(nodeId1, time.milliseconds())); + } + + @Test + public void testMultipleNodeConnectionStates() { + // Check initial state, allowed to connect to all nodes, but no nodes shown as ready + assertTrue(connectionStates.canConnect(nodeId1, time.milliseconds())); + assertTrue(connectionStates.canConnect(nodeId2, time.milliseconds())); + assertFalse(connectionStates.hasReadyNodes(time.milliseconds())); + + // Start connecting one node and check that the pool only shows ready nodes after + // successful connect + connectionStates.connecting(nodeId2, time.milliseconds(), "localhost"); + assertFalse(connectionStates.hasReadyNodes(time.milliseconds())); + time.sleep(1000); + connectionStates.ready(nodeId2); + assertTrue(connectionStates.hasReadyNodes(time.milliseconds())); + + // Connect second node and check that both are shown as ready, pool should immediately + // show ready nodes, since node2 is already connected + connectionStates.connecting(nodeId1, time.milliseconds(), "localhost"); + assertTrue(connectionStates.hasReadyNodes(time.milliseconds())); + time.sleep(1000); + connectionStates.ready(nodeId1); + assertTrue(connectionStates.hasReadyNodes(time.milliseconds())); + + time.sleep(12000); + + // disconnect nodes and check proper state of pool throughout + connectionStates.disconnected(nodeId2, time.milliseconds()); + assertTrue(connectionStates.hasReadyNodes(time.milliseconds())); + 
assertTrue(connectionStates.isBlackedOut(nodeId2, time.milliseconds())); + assertFalse(connectionStates.isBlackedOut(nodeId1, time.milliseconds())); + time.sleep(connectionStates.connectionDelay(nodeId2, time.milliseconds())); + // by the time node1 disconnects node2 should have been unblocked again + connectionStates.disconnected(nodeId1, time.milliseconds() + 1); + assertTrue(connectionStates.isBlackedOut(nodeId1, time.milliseconds())); + assertFalse(connectionStates.isBlackedOut(nodeId2, time.milliseconds())); + assertFalse(connectionStates.hasReadyNodes(time.milliseconds())); + } + + @Test + public void testAuthorizationFailed() { + // Try connecting + connectionStates.connecting(nodeId1, time.milliseconds(), "localhost"); + + time.sleep(100); + + connectionStates.authenticationFailed(nodeId1, time.milliseconds(), new AuthenticationException("No path to CA for certificate!")); + time.sleep(1000); + assertEquals(connectionStates.connectionState(nodeId1), ConnectionState.AUTHENTICATION_FAILED); + assertTrue(connectionStates.authenticationException(nodeId1) instanceof AuthenticationException); + assertFalse(connectionStates.hasReadyNodes(time.milliseconds())); + assertFalse(connectionStates.canConnect(nodeId1, time.milliseconds())); + + time.sleep(connectionStates.connectionDelay(nodeId1, time.milliseconds()) + 1); + + assertTrue(connectionStates.canConnect(nodeId1, time.milliseconds())); + connectionStates.ready(nodeId1); + assertNull(connectionStates.authenticationException(nodeId1)); + } + + @Test + public void testRemoveNode() { + connectionStates.connecting(nodeId1, time.milliseconds(), "localhost"); + time.sleep(1000); + connectionStates.ready(nodeId1); + time.sleep(10000); + + connectionStates.disconnected(nodeId1, time.milliseconds()); + // Node is disconnected and blocked, removing it from the list should reset all blocks + connectionStates.remove(nodeId1); + assertTrue(connectionStates.canConnect(nodeId1, time.milliseconds())); + assertFalse(connectionStates.isBlackedOut(nodeId1, time.milliseconds())); + assertEquals(connectionStates.connectionDelay(nodeId1, time.milliseconds()), 0L); + } + + @Test + public void testMaxReconnectBackoff() { + long effectiveMaxReconnectBackoff = Math.round(reconnectBackoffMax * (1 + reconnectBackoffJitter)); + connectionStates.connecting(nodeId1, time.milliseconds(), "localhost"); + time.sleep(1000); + connectionStates.disconnected(nodeId1, time.milliseconds()); + + // Do 100 reconnect attempts and check that MaxReconnectBackoff (plus jitter) is not exceeded + for (int i = 0; i < 100; i++) { + long reconnectBackoff = connectionStates.connectionDelay(nodeId1, time.milliseconds()); + assertTrue(reconnectBackoff <= effectiveMaxReconnectBackoff); + assertFalse(connectionStates.canConnect(nodeId1, time.milliseconds())); + time.sleep(reconnectBackoff + 1); + assertTrue(connectionStates.canConnect(nodeId1, time.milliseconds())); + connectionStates.connecting(nodeId1, time.milliseconds(), "localhost"); + time.sleep(10); + connectionStates.disconnected(nodeId1, time.milliseconds()); + } + } + + @Test + public void testExponentialReconnectBackoff() { + double reconnectBackoffMaxExp = Math.log(reconnectBackoffMax / (double) Math.max(reconnectBackoffMs, 1)) + / Math.log(reconnectBackoffExpBase); + + // Run through 10 disconnects and check that reconnect backoff value is within expected range for every attempt + for (int i = 0; i < 10; i++) { + connectionStates.connecting(nodeId1, time.milliseconds(), "localhost"); + connectionStates.disconnected(nodeId1, 
time.milliseconds()); + // Calculate expected backoff value without jitter + long expectedBackoff = Math.round(Math.pow(reconnectBackoffExpBase, Math.min(i, reconnectBackoffMaxExp)) + * reconnectBackoffMs); + long currentBackoff = connectionStates.connectionDelay(nodeId1, time.milliseconds()); + assertEquals(expectedBackoff, currentBackoff, reconnectBackoffJitter * expectedBackoff); + time.sleep(connectionStates.connectionDelay(nodeId1, time.milliseconds()) + 1); + } + } + + @Test + public void testThrottled() { + connectionStates.connecting(nodeId1, time.milliseconds(), "localhost"); + time.sleep(1000); + connectionStates.ready(nodeId1); + time.sleep(10000); + + // Initially not throttled. + assertEquals(0, connectionStates.throttleDelayMs(nodeId1, time.milliseconds())); + + // Throttle for 100ms from now. + connectionStates.throttle(nodeId1, time.milliseconds() + 100); + assertEquals(100, connectionStates.throttleDelayMs(nodeId1, time.milliseconds())); + + // Still throttled after 50ms. The remaining delay is 50ms. The poll delay should be same as throttling delay. + time.sleep(50); + assertEquals(50, connectionStates.throttleDelayMs(nodeId1, time.milliseconds())); + assertEquals(50, connectionStates.pollDelayMs(nodeId1, time.milliseconds())); + + // Not throttled anymore when the deadline is reached. The poll delay should be same as connection delay. + time.sleep(50); + assertEquals(0, connectionStates.throttleDelayMs(nodeId1, time.milliseconds())); + assertEquals(connectionStates.connectionDelay(nodeId1, time.milliseconds()), + connectionStates.pollDelayMs(nodeId1, time.milliseconds())); + } + + @Test + public void testSingleIP() throws UnknownHostException { + assertEquals(1, ClientUtils.resolve("localhost", singleIPHostResolver).size()); + + connectionStates.connecting(nodeId1, time.milliseconds(), "localhost"); + InetAddress currAddress = connectionStates.currentAddress(nodeId1); + connectionStates.connecting(nodeId1, time.milliseconds(), "localhost"); + assertSame(currAddress, connectionStates.currentAddress(nodeId1)); + } + + @Test + public void testMultipleIPs() throws UnknownHostException { + setupMultipleIPs(); + + assertTrue(ClientUtils.resolve(hostTwoIps, multipleIPHostResolver).size() > 1); + + connectionStates.connecting(nodeId1, time.milliseconds(), hostTwoIps); + InetAddress addr1 = connectionStates.currentAddress(nodeId1); + connectionStates.connecting(nodeId1, time.milliseconds(), hostTwoIps); + InetAddress addr2 = connectionStates.currentAddress(nodeId1); + assertNotSame(addr1, addr2); + connectionStates.connecting(nodeId1, time.milliseconds(), hostTwoIps); + InetAddress addr3 = connectionStates.currentAddress(nodeId1); + assertNotSame(addr1, addr3); + } + + @Test + public void testHostResolveChange() throws UnknownHostException { + setupMultipleIPs(); + + assertTrue(ClientUtils.resolve(hostTwoIps, multipleIPHostResolver).size() > 1); + + connectionStates.connecting(nodeId1, time.milliseconds(), hostTwoIps); + InetAddress addr1 = connectionStates.currentAddress(nodeId1); + + multipleIPHostResolver.changeAddresses(); + connectionStates.connecting(nodeId1, time.milliseconds(), "localhost"); + InetAddress addr2 = connectionStates.currentAddress(nodeId1); + + assertNotSame(addr1, addr2); + } + + @Test + public void testNodeWithNewHostname() throws UnknownHostException { + setupMultipleIPs(); + + connectionStates.connecting(nodeId1, time.milliseconds(), "localhost"); + InetAddress addr1 = connectionStates.currentAddress(nodeId1); + + 
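+        // Note (editorial): swap the mocked resolver's address set so that reconnecting under the new
+        // hostname below resolves to a different address than the one recorded above.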
this.multipleIPHostResolver.changeAddresses(); + connectionStates.connecting(nodeId1, time.milliseconds(), hostTwoIps); + InetAddress addr2 = connectionStates.currentAddress(nodeId1); + + assertNotSame(addr1, addr2); + } + + @Test + public void testIsPreparingConnection() { + assertFalse(connectionStates.isPreparingConnection(nodeId1)); + connectionStates.connecting(nodeId1, time.milliseconds(), "localhost"); + assertTrue(connectionStates.isPreparingConnection(nodeId1)); + connectionStates.checkingApiVersions(nodeId1); + assertTrue(connectionStates.isPreparingConnection(nodeId1)); + connectionStates.disconnected(nodeId1, time.milliseconds()); + assertFalse(connectionStates.isPreparingConnection(nodeId1)); + } + + @Test + public void testExponentialConnectionSetupTimeout() { + assertTrue(connectionStates.canConnect(nodeId1, time.milliseconds())); + + // Check the exponential timeout growth + for (int n = 0; n <= Math.log((double) connectionSetupTimeoutMaxMs / connectionSetupTimeoutMs) / Math.log(connectionSetupTimeoutExpBase); n++) { + connectionStates.connecting(nodeId1, time.milliseconds(), "localhost"); + assertTrue(connectionStates.connectingNodes().contains(nodeId1)); + assertEquals(connectionSetupTimeoutMs * Math.pow(connectionSetupTimeoutExpBase, n), + connectionStates.connectionSetupTimeoutMs(nodeId1), + connectionSetupTimeoutMs * Math.pow(connectionSetupTimeoutExpBase, n) * connectionSetupTimeoutJitter); + connectionStates.disconnected(nodeId1, time.milliseconds()); + assertFalse(connectionStates.connectingNodes().contains(nodeId1)); + } + + // Check the timeout value upper bound + connectionStates.connecting(nodeId1, time.milliseconds(), "localhost"); + assertEquals(connectionSetupTimeoutMaxMs, + connectionStates.connectionSetupTimeoutMs(nodeId1), + connectionSetupTimeoutMaxMs * connectionSetupTimeoutJitter); + assertTrue(connectionStates.connectingNodes().contains(nodeId1)); + + // Should reset the timeout value to the init value + connectionStates.ready(nodeId1); + assertEquals(connectionSetupTimeoutMs, + connectionStates.connectionSetupTimeoutMs(nodeId1), + connectionSetupTimeoutMs * connectionSetupTimeoutJitter); + assertFalse(connectionStates.connectingNodes().contains(nodeId1)); + connectionStates.disconnected(nodeId1, time.milliseconds()); + + // Check if the connection state transition from ready to disconnected + // won't increase the timeout value + connectionStates.connecting(nodeId1, time.milliseconds(), "localhost"); + assertEquals(connectionSetupTimeoutMs, + connectionStates.connectionSetupTimeoutMs(nodeId1), + connectionSetupTimeoutMs * connectionSetupTimeoutJitter); + assertTrue(connectionStates.connectingNodes().contains(nodeId1)); + } + + @Test + public void testTimedOutConnections() { + // Initiate two connections + connectionStates.connecting(nodeId1, time.milliseconds(), "localhost"); + connectionStates.connecting(nodeId2, time.milliseconds(), "localhost"); + + // Expect no timed out connections + assertEquals(0, connectionStates.nodesWithConnectionSetupTimeout(time.milliseconds()).size()); + + // Advance time by half of the connection setup timeout + time.sleep(connectionSetupTimeoutMs / 2); + + // Initiate a third connection + connectionStates.connecting(nodeId3, time.milliseconds(), "localhost"); + + // Advance time beyond the connection setup timeout (+ max jitter) for the first two connections + time.sleep((long) (connectionSetupTimeoutMs / 2 + connectionSetupTimeoutMs * connectionSetupTimeoutJitter)); + + // Expect two timed out connections. 
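+        // (nodeId1 and nodeId2 have been connecting for a full timeout plus jitter by now,
+        // while nodeId3 has only been connecting for half a timeout.)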
+        List<String> timedOutConnections = connectionStates.nodesWithConnectionSetupTimeout(time.milliseconds());
+        assertEquals(2, timedOutConnections.size());
+        assertTrue(timedOutConnections.contains(nodeId1));
+        assertTrue(timedOutConnections.contains(nodeId2));
+
+        // Disconnect the first two connections
+        connectionStates.disconnected(nodeId1, time.milliseconds());
+        connectionStates.disconnected(nodeId2, time.milliseconds());
+
+        // Advance time beyond the connection setup timeout (+ max jitter) for the third connection
+        time.sleep((long) (connectionSetupTimeoutMs / 2 + connectionSetupTimeoutMs * connectionSetupTimeoutJitter));
+
+        // Expect one timed out connection
+        timedOutConnections = connectionStates.nodesWithConnectionSetupTimeout(time.milliseconds());
+        assertEquals(1, timedOutConnections.size());
+        assertTrue(timedOutConnections.contains(nodeId3));
+
+        // Disconnect the third connection
+        connectionStates.disconnected(nodeId3, time.milliseconds());
+
+        // Expect no timed out connections
+        assertEquals(0, connectionStates.nodesWithConnectionSetupTimeout(time.milliseconds()).size());
+    }
+
+    private void setupMultipleIPs() {
+        this.connectionStates = new ClusterConnectionStates(reconnectBackoffMs, reconnectBackoffMax,
+                connectionSetupTimeoutMs, connectionSetupTimeoutMaxMs, new LogContext(), this.multipleIPHostResolver);
+    }
+}
diff --git a/clients/src/test/java/org/apache/kafka/clients/CommonClientConfigsTest.java b/clients/src/test/java/org/apache/kafka/clients/CommonClientConfigsTest.java
new file mode 100644
index 0000000..007e149
--- /dev/null
+++ b/clients/src/test/java/org/apache/kafka/clients/CommonClientConfigsTest.java
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +package org.apache.kafka.clients; + +import org.apache.kafka.common.config.AbstractConfig; +import org.apache.kafka.common.config.ConfigDef; +import org.junit.jupiter.api.Test; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +import static org.apache.kafka.common.config.ConfigDef.Range.atLeast; +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class CommonClientConfigsTest { + private static class TestConfig extends AbstractConfig { + private static final ConfigDef CONFIG; + static { + CONFIG = new ConfigDef() + .define(CommonClientConfigs.RECONNECT_BACKOFF_MS_CONFIG, + ConfigDef.Type.LONG, + 50L, + atLeast(0L), + ConfigDef.Importance.LOW, + "") + .define(CommonClientConfigs.RECONNECT_BACKOFF_MAX_MS_CONFIG, + ConfigDef.Type.LONG, + 1000L, + atLeast(0L), + ConfigDef.Importance.LOW, + ""); + } + + @Override + protected Map postProcessParsedConfig(final Map parsedValues) { + return CommonClientConfigs.postProcessReconnectBackoffConfigs(this, parsedValues); + } + + public TestConfig(Map props) { + super(CONFIG, props); + } + } + + @Test + public void testExponentialBackoffDefaults() { + TestConfig defaultConf = new TestConfig(Collections.emptyMap()); + assertEquals(Long.valueOf(50L), + defaultConf.getLong(CommonClientConfigs.RECONNECT_BACKOFF_MS_CONFIG)); + assertEquals(Long.valueOf(1000L), + defaultConf.getLong(CommonClientConfigs.RECONNECT_BACKOFF_MAX_MS_CONFIG)); + + TestConfig bothSetConfig = new TestConfig(new HashMap() {{ + put(CommonClientConfigs.RECONNECT_BACKOFF_MS_CONFIG, "123"); + put(CommonClientConfigs.RECONNECT_BACKOFF_MAX_MS_CONFIG, "12345"); + }}); + assertEquals(Long.valueOf(123L), + bothSetConfig.getLong(CommonClientConfigs.RECONNECT_BACKOFF_MS_CONFIG)); + assertEquals(Long.valueOf(12345L), + bothSetConfig.getLong(CommonClientConfigs.RECONNECT_BACKOFF_MAX_MS_CONFIG)); + + TestConfig reconnectBackoffSetConf = new TestConfig(new HashMap() {{ + put(CommonClientConfigs.RECONNECT_BACKOFF_MS_CONFIG, "123"); + }}); + assertEquals(Long.valueOf(123L), + reconnectBackoffSetConf.getLong(CommonClientConfigs.RECONNECT_BACKOFF_MS_CONFIG)); + assertEquals(Long.valueOf(123L), + reconnectBackoffSetConf.getLong(CommonClientConfigs.RECONNECT_BACKOFF_MAX_MS_CONFIG)); + } +} diff --git a/clients/src/test/java/org/apache/kafka/clients/FetchSessionHandlerTest.java b/clients/src/test/java/org/apache/kafka/clients/FetchSessionHandlerTest.java new file mode 100644 index 0000000..4bf53d9 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/clients/FetchSessionHandlerTest.java @@ -0,0 +1,777 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.clients; + +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.TopicIdPartition; +import org.apache.kafka.common.Uuid; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.message.FetchResponseData; +import org.apache.kafka.common.protocol.Errors; +import org.apache.kafka.common.requests.FetchMetadata; +import org.apache.kafka.common.requests.FetchRequest; +import org.apache.kafka.common.requests.FetchResponse; +import org.apache.kafka.common.utils.LogContext; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import org.junit.jupiter.params.provider.ValueSource; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.Comparator; +import java.util.HashMap; +import java.util.Iterator; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.stream.Stream; +import java.util.TreeSet; + +import static org.apache.kafka.common.requests.FetchMetadata.INITIAL_EPOCH; +import static org.apache.kafka.common.requests.FetchMetadata.INVALID_SESSION_ID; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; + +/** + * A unit test for FetchSessionHandler. + */ +@Timeout(120) +public class FetchSessionHandlerTest { + private static final LogContext LOG_CONTEXT = new LogContext("[FetchSessionHandler]="); + + /** + * Create a set of TopicPartitions. We use a TreeSet, in order to get a deterministic + * ordering for test purposes. + */ + private static Set toSet(TopicPartition... 
arr) { + TreeSet set = new TreeSet<>(new Comparator() { + @Override + public int compare(TopicPartition o1, TopicPartition o2) { + return o1.toString().compareTo(o2.toString()); + } + }); + set.addAll(Arrays.asList(arr)); + return set; + } + + @Test + public void testFindMissing() { + TopicPartition foo0 = new TopicPartition("foo", 0); + TopicPartition foo1 = new TopicPartition("foo", 1); + TopicPartition bar0 = new TopicPartition("bar", 0); + TopicPartition bar1 = new TopicPartition("bar", 1); + TopicPartition baz0 = new TopicPartition("baz", 0); + TopicPartition baz1 = new TopicPartition("baz", 1); + assertEquals(toSet(), FetchSessionHandler.findMissing(toSet(foo0), toSet(foo0))); + assertEquals(toSet(foo0), FetchSessionHandler.findMissing(toSet(foo0), toSet(foo1))); + assertEquals(toSet(foo0, foo1), + FetchSessionHandler.findMissing(toSet(foo0, foo1), toSet(baz0))); + assertEquals(toSet(bar1, foo0, foo1), + FetchSessionHandler.findMissing(toSet(foo0, foo1, bar0, bar1), + toSet(bar0, baz0, baz1))); + assertEquals(toSet(), + FetchSessionHandler.findMissing(toSet(foo0, foo1, bar0, bar1, baz1), + toSet(foo0, foo1, bar0, bar1, baz0, baz1))); + } + + private static final class ReqEntry { + final TopicPartition part; + final FetchRequest.PartitionData data; + + ReqEntry(String topic, Uuid topicId, int partition, long fetchOffset, long logStartOffset, int maxBytes) { + this.part = new TopicPartition(topic, partition); + this.data = new FetchRequest.PartitionData(topicId, fetchOffset, logStartOffset, maxBytes, Optional.empty()); + } + } + + private static LinkedHashMap reqMap(ReqEntry... entries) { + LinkedHashMap map = new LinkedHashMap<>(); + for (ReqEntry entry : entries) { + map.put(entry.part, entry.data); + } + return map; + } + + private static void assertMapEquals(Map expected, + Map actual) { + Iterator> expectedIter = + expected.entrySet().iterator(); + Iterator> actualIter = + actual.entrySet().iterator(); + int i = 1; + while (expectedIter.hasNext()) { + Map.Entry expectedEntry = expectedIter.next(); + if (!actualIter.hasNext()) { + fail("Element " + i + " not found."); + } + Map.Entry actuaLEntry = actualIter.next(); + assertEquals(expectedEntry.getKey(), actuaLEntry.getKey(), "Element " + i + + " had a different TopicPartition than expected."); + assertEquals(expectedEntry.getValue(), actuaLEntry.getValue(), "Element " + i + + " had different PartitionData than expected."); + i++; + } + if (actualIter.hasNext()) { + fail("Unexpected element " + i + " found."); + } + } + + @SafeVarargs + private static void assertMapsEqual(Map expected, + Map... 
actuals) { + for (Map actual : actuals) { + assertMapEquals(expected, actual); + } + } + + private static void assertListEquals(List expected, List actual) { + for (TopicIdPartition expectedPart : expected) { + if (!actual.contains(expectedPart)) { + fail("Failed to find expected partition " + expectedPart); + } + } + for (TopicIdPartition actualPart : actual) { + if (!expected.contains(actualPart)) { + fail("Found unexpected partition " + actualPart); + } + } + } + + private static final class RespEntry { + final TopicIdPartition part; + final FetchResponseData.PartitionData data; + + RespEntry(String topic, int partition, Uuid topicId, long highWatermark, long lastStableOffset) { + this.part = new TopicIdPartition(topicId, new TopicPartition(topic, partition)); + + this.data = new FetchResponseData.PartitionData() + .setPartitionIndex(partition) + .setHighWatermark(highWatermark) + .setLastStableOffset(lastStableOffset) + .setLogStartOffset(0); + } + + RespEntry(String topic, int partition, Uuid topicId, Errors error) { + this.part = new TopicIdPartition(topicId, new TopicPartition(topic, partition)); + + this.data = new FetchResponseData.PartitionData() + .setPartitionIndex(partition) + .setErrorCode(error.code()) + .setHighWatermark(FetchResponse.INVALID_HIGH_WATERMARK); + } + } + + private static LinkedHashMap respMap(RespEntry... entries) { + LinkedHashMap map = new LinkedHashMap<>(); + for (RespEntry entry : entries) { + map.put(entry.part, entry.data); + } + return map; + } + + /** + * Test the handling of SESSIONLESS responses. + * Pre-KIP-227 brokers always supply this kind of response. + */ + @Test + public void testSessionless() { + Map topicIds = new HashMap<>(); + Map topicNames = new HashMap<>(); + // We want to test both on older versions that do not use topic IDs and on newer versions that do. + List versions = Arrays.asList((short) 12, ApiKeys.FETCH.latestVersion()); + versions.forEach(version -> { + FetchSessionHandler handler = new FetchSessionHandler(LOG_CONTEXT, 1); + FetchSessionHandler.Builder builder = handler.newBuilder(); + addTopicId(topicIds, topicNames, "foo", version); + Uuid fooId = topicIds.getOrDefault("foo", Uuid.ZERO_UUID); + builder.add(new TopicPartition("foo", 0), + new FetchRequest.PartitionData(fooId, 0, 100, 200, Optional.empty())); + builder.add(new TopicPartition("foo", 1), + new FetchRequest.PartitionData(fooId, 10, 110, 210, Optional.empty())); + FetchSessionHandler.FetchRequestData data = builder.build(); + assertMapsEqual(reqMap(new ReqEntry("foo", fooId, 0, 0, 100, 200), + new ReqEntry("foo", fooId, 1, 10, 110, 210)), + data.toSend(), data.sessionPartitions()); + assertEquals(INVALID_SESSION_ID, data.metadata().sessionId()); + assertEquals(INITIAL_EPOCH, data.metadata().epoch()); + + FetchResponse resp = FetchResponse.of(Errors.NONE, 0, INVALID_SESSION_ID, + respMap(new RespEntry("foo", 0, fooId, 0, 0), + new RespEntry("foo", 1, fooId, 0, 0))); + handler.handleResponse(resp, version); + + FetchSessionHandler.Builder builder2 = handler.newBuilder(); + builder2.add(new TopicPartition("foo", 0), + new FetchRequest.PartitionData(fooId, 0, 100, 200, Optional.empty())); + FetchSessionHandler.FetchRequestData data2 = builder2.build(); + assertEquals(INVALID_SESSION_ID, data2.metadata().sessionId()); + assertEquals(INITIAL_EPOCH, data2.metadata().epoch()); + assertMapsEqual(reqMap(new ReqEntry("foo", fooId, 0, 0, 100, 200)), + data2.toSend(), data2.sessionPartitions()); + }); + } + + /** + * Test handling an incremental fetch session. 
+ */ + @Test + public void testIncrementals() { + Map topicIds = new HashMap<>(); + Map topicNames = new HashMap<>(); + // We want to test both on older versions that do not use topic IDs and on newer versions that do. + List versions = Arrays.asList((short) 12, ApiKeys.FETCH.latestVersion()); + versions.forEach(version -> { + FetchSessionHandler handler = new FetchSessionHandler(LOG_CONTEXT, 1); + FetchSessionHandler.Builder builder = handler.newBuilder(); + addTopicId(topicIds, topicNames, "foo", version); + Uuid fooId = topicIds.getOrDefault("foo", Uuid.ZERO_UUID); + TopicPartition foo0 = new TopicPartition("foo", 0); + TopicPartition foo1 = new TopicPartition("foo", 1); + builder.add(foo0, new FetchRequest.PartitionData(fooId, 0, 100, 200, Optional.empty())); + builder.add(foo1, new FetchRequest.PartitionData(fooId, 10, 110, 210, Optional.empty())); + FetchSessionHandler.FetchRequestData data = builder.build(); + assertMapsEqual(reqMap(new ReqEntry("foo", fooId, 0, 0, 100, 200), + new ReqEntry("foo", fooId, 1, 10, 110, 210)), + data.toSend(), data.sessionPartitions()); + assertEquals(INVALID_SESSION_ID, data.metadata().sessionId()); + assertEquals(INITIAL_EPOCH, data.metadata().epoch()); + + FetchResponse resp = FetchResponse.of(Errors.NONE, 0, 123, + respMap(new RespEntry("foo", 0, fooId, 10, 20), + new RespEntry("foo", 1, fooId, 10, 20))); + handler.handleResponse(resp, version); + + // Test an incremental fetch request which adds one partition and modifies another. + FetchSessionHandler.Builder builder2 = handler.newBuilder(); + addTopicId(topicIds, topicNames, "bar", version); + Uuid barId = topicIds.getOrDefault("bar", Uuid.ZERO_UUID); + TopicPartition bar0 = new TopicPartition("bar", 0); + builder2.add(foo0, new FetchRequest.PartitionData(fooId, 0, 100, 200, Optional.empty())); + builder2.add(foo1, new FetchRequest.PartitionData(fooId, 10, 120, 210, Optional.empty())); + builder2.add(bar0, new FetchRequest.PartitionData(barId, 20, 200, 200, Optional.empty())); + FetchSessionHandler.FetchRequestData data2 = builder2.build(); + assertFalse(data2.metadata().isFull()); + assertMapEquals(reqMap(new ReqEntry("foo", fooId, 0, 0, 100, 200), + new ReqEntry("foo", fooId, 1, 10, 120, 210), + new ReqEntry("bar", barId, 0, 20, 200, 200)), + data2.sessionPartitions()); + assertMapEquals(reqMap(new ReqEntry("bar", barId, 0, 20, 200, 200), + new ReqEntry("foo", fooId, 1, 10, 120, 210)), + data2.toSend()); + + FetchResponse resp2 = FetchResponse.of(Errors.NONE, 0, 123, + respMap(new RespEntry("foo", 1, fooId, 20, 20))); + handler.handleResponse(resp2, version); + + // Skip building a new request. Test that handling an invalid fetch session epoch response results + // in a request which closes the session. 
+ FetchResponse resp3 = FetchResponse.of(Errors.INVALID_FETCH_SESSION_EPOCH, 0, INVALID_SESSION_ID, + respMap()); + handler.handleResponse(resp3, version); + + FetchSessionHandler.Builder builder4 = handler.newBuilder(); + builder4.add(foo0, new FetchRequest.PartitionData(fooId, 0, 100, 200, Optional.empty())); + builder4.add(foo1, new FetchRequest.PartitionData(fooId, 10, 120, 210, Optional.empty())); + builder4.add(bar0, new FetchRequest.PartitionData(barId, 20, 200, 200, Optional.empty())); + FetchSessionHandler.FetchRequestData data4 = builder4.build(); + assertTrue(data4.metadata().isFull()); + assertEquals(data2.metadata().sessionId(), data4.metadata().sessionId()); + assertEquals(INITIAL_EPOCH, data4.metadata().epoch()); + assertMapsEqual(reqMap(new ReqEntry("foo", fooId, 0, 0, 100, 200), + new ReqEntry("foo", fooId, 1, 10, 120, 210), + new ReqEntry("bar", barId, 0, 20, 200, 200)), + data4.sessionPartitions(), data4.toSend()); + }); + } + + /** + * Test that calling FetchSessionHandler#Builder#build twice fails. + */ + @Test + public void testDoubleBuild() { + FetchSessionHandler handler = new FetchSessionHandler(LOG_CONTEXT, 1); + FetchSessionHandler.Builder builder = handler.newBuilder(); + builder.add(new TopicPartition("foo", 0), + new FetchRequest.PartitionData(Uuid.randomUuid(), 0, 100, 200, Optional.empty())); + builder.build(); + try { + builder.build(); + fail("Expected calling build twice to fail."); + } catch (Throwable t) { + // expected + } + } + + @Test + public void testIncrementalPartitionRemoval() { + Map topicIds = new HashMap<>(); + Map topicNames = new HashMap<>(); + // We want to test both on older versions that do not use topic IDs and on newer versions that do. + List versions = Arrays.asList((short) 12, ApiKeys.FETCH.latestVersion()); + versions.forEach(version -> { + FetchSessionHandler handler = new FetchSessionHandler(LOG_CONTEXT, 1); + FetchSessionHandler.Builder builder = handler.newBuilder(); + addTopicId(topicIds, topicNames, "foo", version); + addTopicId(topicIds, topicNames, "bar", version); + Uuid fooId = topicIds.getOrDefault("foo", Uuid.ZERO_UUID); + Uuid barId = topicIds.getOrDefault("bar", Uuid.ZERO_UUID); + TopicPartition foo0 = new TopicPartition("foo", 0); + TopicPartition foo1 = new TopicPartition("foo", 1); + TopicPartition bar0 = new TopicPartition("bar", 0); + builder.add(foo0, new FetchRequest.PartitionData(fooId, 0, 100, 200, Optional.empty())); + builder.add(foo1, new FetchRequest.PartitionData(fooId, 10, 110, 210, Optional.empty())); + builder.add(bar0, new FetchRequest.PartitionData(barId, 20, 120, 220, Optional.empty())); + FetchSessionHandler.FetchRequestData data = builder.build(); + assertMapsEqual(reqMap(new ReqEntry("foo", fooId, 0, 0, 100, 200), + new ReqEntry("foo", fooId, 1, 10, 110, 210), + new ReqEntry("bar", barId, 0, 20, 120, 220)), + data.toSend(), data.sessionPartitions()); + assertTrue(data.metadata().isFull()); + + FetchResponse resp = FetchResponse.of(Errors.NONE, 0, 123, + respMap(new RespEntry("foo", 0, fooId, 10, 20), + new RespEntry("foo", 1, fooId, 10, 20), + new RespEntry("bar", 0, barId, 10, 20))); + handler.handleResponse(resp, version); + + // Test an incremental fetch request which removes two partitions. 
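+            // Only foo-1 is re-added below, so foo-0 and bar-0 are expected to show up in toForget.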
+ FetchSessionHandler.Builder builder2 = handler.newBuilder(); + builder2.add(foo1, new FetchRequest.PartitionData(fooId, 10, 110, 210, Optional.empty())); + FetchSessionHandler.FetchRequestData data2 = builder2.build(); + assertFalse(data2.metadata().isFull()); + assertEquals(123, data2.metadata().sessionId()); + assertEquals(1, data2.metadata().epoch()); + assertMapEquals(reqMap(new ReqEntry("foo", fooId, 1, 10, 110, 210)), + data2.sessionPartitions()); + assertMapEquals(reqMap(), data2.toSend()); + ArrayList expectedToForget2 = new ArrayList<>(); + expectedToForget2.add(new TopicIdPartition(fooId, foo0)); + expectedToForget2.add(new TopicIdPartition(barId, bar0)); + assertListEquals(expectedToForget2, data2.toForget()); + + // A FETCH_SESSION_ID_NOT_FOUND response triggers us to close the session. + // The next request is a session establishing FULL request. + FetchResponse resp2 = FetchResponse.of(Errors.FETCH_SESSION_ID_NOT_FOUND, 0, INVALID_SESSION_ID, + respMap()); + handler.handleResponse(resp2, version); + + FetchSessionHandler.Builder builder3 = handler.newBuilder(); + builder3.add(foo0, new FetchRequest.PartitionData(fooId, 0, 100, 200, Optional.empty())); + FetchSessionHandler.FetchRequestData data3 = builder3.build(); + assertTrue(data3.metadata().isFull()); + assertEquals(INVALID_SESSION_ID, data3.metadata().sessionId()); + assertEquals(INITIAL_EPOCH, data3.metadata().epoch()); + assertMapsEqual(reqMap(new ReqEntry("foo", fooId, 0, 0, 100, 200)), + data3.sessionPartitions(), data3.toSend()); + }); + } + + @Test + public void testTopicIdUsageGrantedOnIdUpgrade() { + // We want to test adding a topic ID to an existing partition and a new partition in the incremental request. + // 0 is the existing partition and 1 is the new one. + List partitions = Arrays.asList(0, 1); + partitions.forEach(partition -> { + String testType = partition == 0 ? "updating a partition" : "adding a new partition"; + FetchSessionHandler handler = new FetchSessionHandler(LOG_CONTEXT, 1); + FetchSessionHandler.Builder builder = handler.newBuilder(); + builder.add(new TopicPartition("foo", 0), + new FetchRequest.PartitionData(Uuid.ZERO_UUID, 0, 100, 200, Optional.empty())); + FetchSessionHandler.FetchRequestData data = builder.build(); + assertMapsEqual(reqMap(new ReqEntry("foo", Uuid.ZERO_UUID, 0, 0, 100, 200)), + data.toSend(), data.sessionPartitions()); + assertTrue(data.metadata().isFull()); + assertFalse(data.canUseTopicIds()); + + FetchResponse resp = FetchResponse.of(Errors.NONE, 0, 123, + respMap(new RespEntry("foo", 0, Uuid.ZERO_UUID, 10, 20))); + handler.handleResponse(resp, (short) 12); + + // Try to add a topic ID to an already existing topic partition (0) or a new partition (1) in the session. + Uuid topicId = Uuid.randomUuid(); + FetchSessionHandler.Builder builder2 = handler.newBuilder(); + builder2.add(new TopicPartition("foo", partition), + new FetchRequest.PartitionData(topicId, 10, 110, 210, Optional.empty())); + FetchSessionHandler.FetchRequestData data2 = builder2.build(); + // Should have the same session ID, and next epoch and can only use topic IDs if the partition was updated. + boolean updated = partition == 0; + // The receiving broker will handle closing the session. 
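+            // Adding an ID to the existing partition (0) lets the session upgrade to topic IDs,
+            // whereas adding the ID only on a new partition (1) does not.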
+ assertEquals(123, data2.metadata().sessionId(), "Did not use same session when " + testType); + assertEquals(1, data2.metadata().epoch(), "Did not have correct epoch when " + testType); + assertEquals(updated, data2.canUseTopicIds()); + }); + } + + @Test + public void testIdUsageRevokedOnIdDowngrade() { + // We want to test removing topic ID from an existing partition and adding a new partition without an ID in the incremental request. + // 0 is the existing partition and 1 is the new one. + List partitions = Arrays.asList(0, 1); + partitions.forEach(partition -> { + String testType = partition == 0 ? "updating a partition" : "adding a new partition"; + Uuid fooId = Uuid.randomUuid(); + FetchSessionHandler handler = new FetchSessionHandler(LOG_CONTEXT, 1); + FetchSessionHandler.Builder builder = handler.newBuilder(); + builder.add(new TopicPartition("foo", 0), + new FetchRequest.PartitionData(fooId, 0, 100, 200, Optional.empty())); + FetchSessionHandler.FetchRequestData data = builder.build(); + assertMapsEqual(reqMap(new ReqEntry("foo", fooId, 0, 0, 100, 200)), + data.toSend(), data.sessionPartitions()); + assertTrue(data.metadata().isFull()); + assertTrue(data.canUseTopicIds()); + + FetchResponse resp = FetchResponse.of(Errors.NONE, 0, 123, + respMap(new RespEntry("foo", 0, fooId, 10, 20))); + handler.handleResponse(resp, ApiKeys.FETCH.latestVersion()); + + // Try to remove a topic ID from an existing topic partition (0) or add a new topic partition (1) without an ID. + FetchSessionHandler.Builder builder2 = handler.newBuilder(); + builder2.add(new TopicPartition("foo", partition), + new FetchRequest.PartitionData(Uuid.ZERO_UUID, 10, 110, 210, Optional.empty())); + FetchSessionHandler.FetchRequestData data2 = builder2.build(); + // Should have the same session ID, and next epoch and can no longer use topic IDs. + // The receiving broker will handle closing the session. + assertEquals(123, data2.metadata().sessionId(), "Did not use same session when " + testType); + assertEquals(1, data2.metadata().epoch(), "Did not have correct epoch when " + testType); + assertFalse(data2.canUseTopicIds()); + }); + } + + private static Stream idUsageCombinations() { + return Stream.of( + Arguments.of(true, true), + Arguments.of(true, false), + Arguments.of(false, true), + Arguments.of(false, false) + ); + } + + @ParameterizedTest + @MethodSource("idUsageCombinations") + public void testTopicIdReplaced(boolean startsWithTopicIds, boolean endsWithTopicIds) { + TopicPartition tp = new TopicPartition("foo", 0); + FetchSessionHandler handler = new FetchSessionHandler(LOG_CONTEXT, 1); + FetchSessionHandler.Builder builder = handler.newBuilder(); + Uuid topicId1 = startsWithTopicIds ? Uuid.randomUuid() : Uuid.ZERO_UUID; + builder.add(tp, new FetchRequest.PartitionData(topicId1, 0, 100, 200, Optional.empty())); + FetchSessionHandler.FetchRequestData data = builder.build(); + assertMapsEqual(reqMap(new ReqEntry("foo", topicId1, 0, 0, 100, 200)), + data.toSend(), data.sessionPartitions()); + assertTrue(data.metadata().isFull()); + assertEquals(startsWithTopicIds, data.canUseTopicIds()); + + FetchResponse resp = FetchResponse.of(Errors.NONE, 0, 123, respMap(new RespEntry("foo", 0, topicId1, 10, 20))); + short version = startsWithTopicIds ? ApiKeys.FETCH.latestVersion() : 12; + handler.handleResponse(resp, version); + + // Try to add a new topic ID. + FetchSessionHandler.Builder builder2 = handler.newBuilder(); + Uuid topicId2 = endsWithTopicIds ? 
Uuid.randomUuid() : Uuid.ZERO_UUID;
+        // Use the same data besides the topic ID.
+        FetchRequest.PartitionData partitionData = new FetchRequest.PartitionData(topicId2, 0, 100, 200, Optional.empty());
+        builder2.add(tp, partitionData);
+        FetchSessionHandler.FetchRequestData data2 = builder2.build();
+
+        if (startsWithTopicIds && endsWithTopicIds) {
+            // If we started with an ID, only a new ID counts as a replacement.
+            // The old topic ID partition should be in toReplace, and the new one should be in toSend.
+            assertEquals(Collections.singletonList(new TopicIdPartition(topicId1, tp)), data2.toReplace());
+            assertMapsEqual(reqMap(new ReqEntry("foo", topicId2, 0, 0, 100, 200)),
+                    data2.toSend(), data2.sessionPartitions());
+
+            // sessionTopicNames should contain only the second topic ID.
+            assertEquals(Collections.singletonMap(topicId2, tp.topic()), handler.sessionTopicNames());
+
+        } else if (startsWithTopicIds || endsWithTopicIds) {
+            // If we switched between using and not using topic IDs, we still want to send this data,
+            // but we do not mark the partition as replaced. In this scenario, the session should close
+            // because the request type changed.
+            // The new topic ID will be in the session partition map.
+            assertEquals(Collections.emptyList(), data2.toReplace());
+            assertMapsEqual(reqMap(new ReqEntry("foo", topicId2, 0, 0, 100, 200)),
+                    data2.toSend(), data2.sessionPartitions());
+            // The topicNames map will have the new topic ID if it is valid.
+            // The old topic ID should be removed, as the map will be empty if the request doesn't use topic IDs.
+            if (endsWithTopicIds) {
+                assertEquals(Collections.singletonMap(topicId2, tp.topic()), handler.sessionTopicNames());
+            } else {
+                assertEquals(Collections.emptyMap(), handler.sessionTopicNames());
+            }
+
+        } else {
+            // Otherwise, there is no partition in toReplace, and since neither the partition nor the
+            // topic ID was updated, there is no data to send.
+            assertEquals(Collections.emptyList(), data2.toReplace());
+            assertEquals(Collections.emptyMap(), data2.toSend());
+            assertMapsEqual(reqMap(new ReqEntry("foo", topicId2, 0, 0, 100, 200)), data2.sessionPartitions());
+            // There is also nothing in the sessionTopicNames map, as no topic IDs are used.
+            assertEquals(Collections.emptyMap(), handler.sessionTopicNames());
+        }
+
+        // Should have the same session ID and the next epoch, and can use topic IDs only if the request ended with topic IDs.
+        assertEquals(123, data2.metadata().sessionId(), "Did not use same session");
+        assertEquals(1, data2.metadata().epoch(), "Did not have correct epoch");
+        assertEquals(endsWithTopicIds, data2.canUseTopicIds());
+    }
+
+    @ParameterizedTest
+    @ValueSource(booleans = {true, false})
+    public void testSessionEpochWhenMixedUsageOfTopicIDs(boolean startsWithTopicIds) {
+        Uuid fooId = startsWithTopicIds ? Uuid.randomUuid() : Uuid.ZERO_UUID;
+        Uuid barId = startsWithTopicIds ? Uuid.ZERO_UUID : Uuid.randomUuid();
+        short responseVersion = startsWithTopicIds ?
ApiKeys.FETCH.latestVersion() : 12; + + TopicPartition tp0 = new TopicPartition("foo", 0); + TopicPartition tp1 = new TopicPartition("bar", 1); + + FetchSessionHandler handler = new FetchSessionHandler(LOG_CONTEXT, 1); + FetchSessionHandler.Builder builder = handler.newBuilder(); + builder.add(tp0, new FetchRequest.PartitionData(fooId, 0, 100, 200, Optional.empty())); + FetchSessionHandler.FetchRequestData data = builder.build(); + assertMapsEqual(reqMap(new ReqEntry("foo", fooId, 0, 0, 100, 200)), + data.toSend(), data.sessionPartitions()); + assertTrue(data.metadata().isFull()); + assertEquals(startsWithTopicIds, data.canUseTopicIds()); + + FetchResponse resp = FetchResponse.of(Errors.NONE, 0, 123, + respMap(new RespEntry("foo", 0, fooId, 10, 20))); + handler.handleResponse(resp, responseVersion); + + // Re-add the first partition. Then add a partition with opposite ID usage. + FetchSessionHandler.Builder builder2 = handler.newBuilder(); + builder2.add(tp0, new FetchRequest.PartitionData(fooId, 10, 110, 210, Optional.empty())); + builder2.add(tp1, new FetchRequest.PartitionData(barId, 0, 100, 200, Optional.empty())); + FetchSessionHandler.FetchRequestData data2 = builder2.build(); + // Should have the same session ID, and the next epoch and can not use topic IDs. + // The receiving broker will handle closing the session. + assertEquals(123, data2.metadata().sessionId(), "Did not use same session"); + assertEquals(1, data2.metadata().epoch(), "Did not have final epoch"); + assertFalse(data2.canUseTopicIds()); + } + + + @ParameterizedTest + @ValueSource(booleans = {true, false}) + public void testIdUsageWithAllForgottenPartitions(boolean useTopicIds) { + // We want to test when all topics are removed from the session + TopicPartition foo0 = new TopicPartition("foo", 0); + Uuid topicId = useTopicIds ? Uuid.randomUuid() : Uuid.ZERO_UUID; + short responseVersion = useTopicIds ? ApiKeys.FETCH.latestVersion() : 12; + FetchSessionHandler handler = new FetchSessionHandler(LOG_CONTEXT, 1); + + // Add topic foo to the session + FetchSessionHandler.Builder builder = handler.newBuilder(); + builder.add(foo0, new FetchRequest.PartitionData(topicId, 0, 100, 200, Optional.empty())); + FetchSessionHandler.FetchRequestData data = builder.build(); + assertMapsEqual(reqMap(new ReqEntry("foo", topicId, 0, 0, 100, 200)), + data.toSend(), data.sessionPartitions()); + assertTrue(data.metadata().isFull()); + assertEquals(useTopicIds, data.canUseTopicIds()); + + FetchResponse resp = FetchResponse.of(Errors.NONE, 0, 123, + respMap(new RespEntry("foo", 0, topicId, 10, 20))); + handler.handleResponse(resp, responseVersion); + + // Remove the topic from the session + FetchSessionHandler.Builder builder2 = handler.newBuilder(); + FetchSessionHandler.FetchRequestData data2 = builder2.build(); + assertEquals(Collections.singletonList(new TopicIdPartition(topicId, foo0)), data2.toForget()); + // Should have the same session ID, next epoch, and same ID usage. 
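+        // Even though nothing is sent, the empty incremental request still advances the session epoch.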
+ assertEquals(123, data2.metadata().sessionId(), "Did not use same session when useTopicIds was " + useTopicIds); + assertEquals(1, data2.metadata().epoch(), "Did not have correct epoch when useTopicIds was " + useTopicIds); + assertEquals(useTopicIds, data2.canUseTopicIds()); + } + + @Test + public void testOkToAddNewIdAfterTopicRemovedFromSession() { + Uuid topicId = Uuid.randomUuid(); + FetchSessionHandler handler = new FetchSessionHandler(LOG_CONTEXT, 1); + FetchSessionHandler.Builder builder = handler.newBuilder(); + builder.add(new TopicPartition("foo", 0), + new FetchRequest.PartitionData(topicId, 0, 100, 200, Optional.empty())); + FetchSessionHandler.FetchRequestData data = builder.build(); + assertMapsEqual(reqMap(new ReqEntry("foo", topicId, 0, 0, 100, 200)), + data.toSend(), data.sessionPartitions()); + assertTrue(data.metadata().isFull()); + assertTrue(data.canUseTopicIds()); + + FetchResponse resp = FetchResponse.of(Errors.NONE, 0, 123, + respMap(new RespEntry("foo", 0, topicId, 10, 20))); + handler.handleResponse(resp, ApiKeys.FETCH.latestVersion()); + + // Remove the partition from the session. Return a session ID as though the session is still open. + FetchSessionHandler.Builder builder2 = handler.newBuilder(); + FetchSessionHandler.FetchRequestData data2 = builder2.build(); + assertMapsEqual(new LinkedHashMap<>(), + data2.toSend(), data2.sessionPartitions()); + FetchResponse resp2 = FetchResponse.of(Errors.NONE, 0, 123, + new LinkedHashMap<>()); + handler.handleResponse(resp2, ApiKeys.FETCH.latestVersion()); + + // After the topic is removed, add a recreated topic with a new ID. + FetchSessionHandler.Builder builder3 = handler.newBuilder(); + builder3.add(new TopicPartition("foo", 0), + new FetchRequest.PartitionData(Uuid.randomUuid(), 0, 100, 200, Optional.empty())); + FetchSessionHandler.FetchRequestData data3 = builder3.build(); + // Should have the same session ID and epoch 2. + assertEquals(123, data3.metadata().sessionId(), "Did not use same session"); + assertEquals(2, data3.metadata().epoch(), "Did not have the correct session epoch"); + assertTrue(data.canUseTopicIds()); + } + + @Test + public void testVerifyFullFetchResponsePartitions() { + Map topicIds = new HashMap<>(); + Map topicNames = new HashMap<>(); + // We want to test both on older versions that do not use topic IDs and on newer versions that do. 
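+        // Version 12 here stands in for the pre-topic-ID protocol; addTopicId only registers IDs for version 13 and above.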
+ List versions = Arrays.asList((short) 12, ApiKeys.FETCH.latestVersion()); + versions.forEach(version -> { + FetchSessionHandler handler = new FetchSessionHandler(LOG_CONTEXT, 1); + addTopicId(topicIds, topicNames, "foo", version); + addTopicId(topicIds, topicNames, "bar", version); + Uuid fooId = topicIds.getOrDefault("foo", Uuid.ZERO_UUID); + Uuid barId = topicIds.getOrDefault("bar", Uuid.ZERO_UUID); + TopicPartition foo0 = new TopicPartition("foo", 0); + TopicPartition foo1 = new TopicPartition("foo", 1); + TopicPartition bar0 = new TopicPartition("bar", 0); + FetchResponse resp1 = FetchResponse.of(Errors.NONE, 0, INVALID_SESSION_ID, + respMap(new RespEntry("foo", 0, fooId, 10, 20), + new RespEntry("foo", 1, fooId, 10, 20), + new RespEntry("bar", 0, barId, 10, 20))); + String issue = handler.verifyFullFetchResponsePartitions(resp1.responseData(topicNames, version).keySet(), + resp1.topicIds(), version); + assertTrue(issue.contains("extraPartitions=")); + assertFalse(issue.contains("omittedPartitions=")); + FetchSessionHandler.Builder builder = handler.newBuilder(); + builder.add(foo0, new FetchRequest.PartitionData(fooId, 0, 100, 200, Optional.empty())); + builder.add(foo1, new FetchRequest.PartitionData(fooId, 10, 110, 210, Optional.empty())); + builder.add(bar0, new FetchRequest.PartitionData(barId, 20, 120, 220, Optional.empty())); + builder.build(); + FetchResponse resp2 = FetchResponse.of(Errors.NONE, 0, INVALID_SESSION_ID, + respMap(new RespEntry("foo", 0, fooId, 10, 20), + new RespEntry("foo", 1, fooId, 10, 20), + new RespEntry("bar", 0, barId, 10, 20))); + String issue2 = handler.verifyFullFetchResponsePartitions(resp2.responseData(topicNames, version).keySet(), + resp2.topicIds(), version); + assertNull(issue2); + FetchResponse resp3 = FetchResponse.of(Errors.NONE, 0, INVALID_SESSION_ID, + respMap(new RespEntry("foo", 0, fooId, 10, 20), + new RespEntry("foo", 1, fooId, 10, 20))); + String issue3 = handler.verifyFullFetchResponsePartitions(resp3.responseData(topicNames, version).keySet(), + resp3.topicIds(), version); + assertFalse(issue3.contains("extraPartitions=")); + assertTrue(issue3.contains("omittedPartitions=")); + }); + } + + @Test + public void testVerifyFullFetchResponsePartitionsWithTopicIds() { + Map topicIds = new HashMap<>(); + Map topicNames = new HashMap<>(); + FetchSessionHandler handler = new FetchSessionHandler(LOG_CONTEXT, 1); + addTopicId(topicIds, topicNames, "foo", ApiKeys.FETCH.latestVersion()); + addTopicId(topicIds, topicNames, "bar", ApiKeys.FETCH.latestVersion()); + addTopicId(topicIds, topicNames, "extra2", ApiKeys.FETCH.latestVersion()); + FetchResponse resp1 = FetchResponse.of(Errors.NONE, 0, INVALID_SESSION_ID, + respMap(new RespEntry("foo", 0, topicIds.get("foo"), 10, 20), + new RespEntry("extra2", 1, topicIds.get("extra2"), 10, 20), + new RespEntry("bar", 0, topicIds.get("bar"), 10, 20))); + String issue = handler.verifyFullFetchResponsePartitions(resp1.responseData(topicNames, ApiKeys.FETCH.latestVersion()).keySet(), + resp1.topicIds(), ApiKeys.FETCH.latestVersion()); + assertTrue(issue.contains("extraPartitions=")); + assertFalse(issue.contains("omittedPartitions=")); + FetchSessionHandler.Builder builder = handler.newBuilder(); + builder.add(new TopicPartition("foo", 0), + new FetchRequest.PartitionData(topicIds.get("foo"), 0, 100, 200, Optional.empty())); + builder.add(new TopicPartition("bar", 0), + new FetchRequest.PartitionData(topicIds.get("bar"), 20, 120, 220, Optional.empty())); + builder.build(); + FetchResponse resp2 = 
FetchResponse.of(Errors.NONE, 0, INVALID_SESSION_ID, + respMap(new RespEntry("foo", 0, topicIds.get("foo"), 10, 20), + new RespEntry("extra2", 1, topicIds.get("extra2"), 10, 20), + new RespEntry("bar", 0, topicIds.get("bar"), 10, 20))); + String issue2 = handler.verifyFullFetchResponsePartitions(resp2.responseData(topicNames, ApiKeys.FETCH.latestVersion()).keySet(), + resp2.topicIds(), ApiKeys.FETCH.latestVersion()); + assertTrue(issue2.contains("extraPartitions=")); + assertFalse(issue2.contains("omittedPartitions=")); + FetchResponse resp3 = FetchResponse.of(Errors.NONE, 0, INVALID_SESSION_ID, + respMap(new RespEntry("foo", 0, topicIds.get("foo"), 10, 20), + new RespEntry("bar", 0, topicIds.get("bar"), 10, 20))); + String issue3 = handler.verifyFullFetchResponsePartitions(resp3.responseData(topicNames, ApiKeys.FETCH.latestVersion()).keySet(), + resp3.topicIds(), ApiKeys.FETCH.latestVersion()); + assertNull(issue3); + } + + @Test + public void testTopLevelErrorResetsMetadata() { + Map topicIds = new HashMap<>(); + Map topicNames = new HashMap<>(); + FetchSessionHandler handler = new FetchSessionHandler(LOG_CONTEXT, 1); + FetchSessionHandler.Builder builder = handler.newBuilder(); + addTopicId(topicIds, topicNames, "foo", ApiKeys.FETCH.latestVersion()); + Uuid fooId = topicIds.getOrDefault("foo", Uuid.ZERO_UUID); + builder.add(new TopicPartition("foo", 0), + new FetchRequest.PartitionData(fooId, 0, 100, 200, Optional.empty())); + builder.add(new TopicPartition("foo", 1), + new FetchRequest.PartitionData(fooId, 10, 110, 210, Optional.empty())); + FetchSessionHandler.FetchRequestData data = builder.build(); + assertEquals(INVALID_SESSION_ID, data.metadata().sessionId()); + assertEquals(INITIAL_EPOCH, data.metadata().epoch()); + + FetchResponse resp = FetchResponse.of(Errors.NONE, 0, 123, + respMap(new RespEntry("foo", 0, topicIds.get("foo"), 10, 20), + new RespEntry("foo", 1, topicIds.get("foo"), 10, 20))); + handler.handleResponse(resp, ApiKeys.FETCH.latestVersion()); + + // Test an incremental fetch request which adds an ID unknown to the broker. + FetchSessionHandler.Builder builder2 = handler.newBuilder(); + addTopicId(topicIds, topicNames, "unknown", ApiKeys.FETCH.latestVersion()); + builder2.add(new TopicPartition("unknown", 0), + new FetchRequest.PartitionData(topicIds.getOrDefault("unknown", Uuid.ZERO_UUID), 0, 100, 200, Optional.empty())); + FetchSessionHandler.FetchRequestData data2 = builder2.build(); + assertFalse(data2.metadata().isFull()); + assertEquals(123, data2.metadata().sessionId()); + assertEquals(FetchMetadata.nextEpoch(INITIAL_EPOCH), data2.metadata().epoch()); + + // Return and handle a response with a top level error + FetchResponse resp2 = FetchResponse.of(Errors.UNKNOWN_TOPIC_ID, 0, 123, + respMap(new RespEntry("unknown", 0, Uuid.randomUuid(), Errors.UNKNOWN_TOPIC_ID))); + assertFalse(handler.handleResponse(resp2, ApiKeys.FETCH.latestVersion())); + + // Ensure we start with a new epoch. This will close the session in the next request. 
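+        // The request built below keeps the old session id but resets the epoch to INITIAL_EPOCH.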
+ FetchSessionHandler.Builder builder3 = handler.newBuilder(); + FetchSessionHandler.FetchRequestData data3 = builder3.build(); + assertEquals(123, data3.metadata().sessionId()); + assertEquals(INITIAL_EPOCH, data3.metadata().epoch()); + } + + private void addTopicId(Map topicIds, Map topicNames, String name, short version) { + if (version >= 13) { + Uuid id = Uuid.randomUuid(); + topicIds.put(name, id); + topicNames.put(id, name); + } + } +} diff --git a/clients/src/test/java/org/apache/kafka/clients/InFlightRequestsTest.java b/clients/src/test/java/org/apache/kafka/clients/InFlightRequestsTest.java new file mode 100644 index 0000000..0902266 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/clients/InFlightRequestsTest.java @@ -0,0 +1,128 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.clients; + +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.requests.RequestHeader; +import org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.test.TestUtils; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +public class InFlightRequestsTest { + + private InFlightRequests inFlightRequests; + private int correlationId; + private String dest = "dest"; + + @BeforeEach + public void setup() { + inFlightRequests = new InFlightRequests(12); + correlationId = 0; + } + + @Test + public void testCompleteLastSent() { + int correlationId1 = addRequest(dest); + int correlationId2 = addRequest(dest); + assertEquals(2, inFlightRequests.count()); + + assertEquals(correlationId2, inFlightRequests.completeLastSent(dest).header.correlationId()); + assertEquals(1, inFlightRequests.count()); + + assertEquals(correlationId1, inFlightRequests.completeLastSent(dest).header.correlationId()); + assertEquals(0, inFlightRequests.count()); + } + + @Test + public void testClearAll() { + int correlationId1 = addRequest(dest); + int correlationId2 = addRequest(dest); + + List clearedRequests = TestUtils.toList(this.inFlightRequests.clearAll(dest)); + assertEquals(0, inFlightRequests.count()); + assertEquals(2, clearedRequests.size()); + assertEquals(correlationId1, clearedRequests.get(0).header.correlationId()); + assertEquals(correlationId2, clearedRequests.get(1).header.correlationId()); + } + + @Test + public void testTimedOutNodes() { + Time time = new MockTime(); + + addRequest("A", time.milliseconds(), 50); + addRequest("B", time.milliseconds(), 200); + addRequest("B", time.milliseconds(), 100); + + time.sleep(50); + 
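+        // At exactly 50 ms the request to node A has reached, but not yet exceeded, its timeout.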
assertEquals(Collections.emptyList(), inFlightRequests.nodesWithTimedOutRequests(time.milliseconds())); + + time.sleep(25); + assertEquals(Collections.singletonList("A"), inFlightRequests.nodesWithTimedOutRequests(time.milliseconds())); + + time.sleep(50); + assertEquals(Arrays.asList("A", "B"), inFlightRequests.nodesWithTimedOutRequests(time.milliseconds())); + } + + @Test + public void testCompleteNext() { + int correlationId1 = addRequest(dest); + int correlationId2 = addRequest(dest); + assertEquals(2, inFlightRequests.count()); + + assertEquals(correlationId1, inFlightRequests.completeNext(dest).header.correlationId()); + assertEquals(1, inFlightRequests.count()); + + assertEquals(correlationId2, inFlightRequests.completeNext(dest).header.correlationId()); + assertEquals(0, inFlightRequests.count()); + } + + @Test + public void testCompleteNextThrowsIfNoInflights() { + assertThrows(IllegalStateException.class, () -> inFlightRequests.completeNext(dest)); + } + + @Test + public void testCompleteLastSentThrowsIfNoInFlights() { + assertThrows(IllegalStateException.class, () -> inFlightRequests.completeLastSent(dest)); + } + + private int addRequest(String destination) { + return addRequest(destination, 0, 10000); + } + + private int addRequest(String destination, long sendTimeMs, int requestTimeoutMs) { + int correlationId = this.correlationId; + this.correlationId += 1; + + RequestHeader requestHeader = new RequestHeader(ApiKeys.METADATA, (short) 0, "clientId", correlationId); + NetworkClient.InFlightRequest ifr = new NetworkClient.InFlightRequest(requestHeader, requestTimeoutMs, 0, + destination, null, false, false, null, null, sendTimeMs); + inFlightRequests.add(ifr); + return correlationId; + } + +} diff --git a/clients/src/test/java/org/apache/kafka/clients/MetadataCacheTest.java b/clients/src/test/java/org/apache/kafka/clients/MetadataCacheTest.java new file mode 100644 index 0000000..387643c --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/clients/MetadataCacheTest.java @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.clients; + +import org.apache.kafka.common.Cluster; +import org.apache.kafka.common.Node; +import org.apache.kafka.common.PartitionInfo; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.protocol.Errors; +import org.apache.kafka.common.requests.MetadataResponse; +import org.junit.jupiter.api.Test; + +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.Optional; +import java.util.function.Function; +import java.util.stream.Collectors; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class MetadataCacheTest { + + @Test + public void testMissingLeaderEndpoint() { + // Although the broker attempts to ensure leader information is available, the + // client metadata cache may retain partition metadata across multiple responses. + // For example, separate responses may contain conflicting leader epochs for + // separate partitions and the client will always retain the highest. + + TopicPartition topicPartition = new TopicPartition("topic", 0); + + MetadataResponse.PartitionMetadata partitionMetadata = new MetadataResponse.PartitionMetadata( + Errors.NONE, + topicPartition, + Optional.of(5), + Optional.of(10), + Arrays.asList(5, 6, 7), + Arrays.asList(5, 6, 7), + Collections.emptyList()); + + Map nodesById = new HashMap<>(); + nodesById.put(6, new Node(6, "localhost", 2077)); + nodesById.put(7, new Node(7, "localhost", 2078)); + nodesById.put(8, new Node(8, "localhost", 2079)); + + MetadataCache cache = new MetadataCache("clusterId", + nodesById, + Collections.singleton(partitionMetadata), + Collections.emptySet(), + Collections.emptySet(), + Collections.emptySet(), + null, + Collections.emptyMap()); + + Cluster cluster = cache.cluster(); + assertNull(cluster.leaderFor(topicPartition)); + + PartitionInfo partitionInfo = cluster.partition(topicPartition); + Map replicas = Arrays.stream(partitionInfo.replicas()) + .collect(Collectors.toMap(Node::id, Function.identity())); + assertNull(partitionInfo.leader()); + assertEquals(3, replicas.size()); + assertTrue(replicas.get(5).isEmpty()); + assertEquals(nodesById.get(6), replicas.get(6)); + assertEquals(nodesById.get(7), replicas.get(7)); + } + +} diff --git a/clients/src/test/java/org/apache/kafka/clients/MetadataTest.java b/clients/src/test/java/org/apache/kafka/clients/MetadataTest.java new file mode 100644 index 0000000..a4383d8 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/clients/MetadataTest.java @@ -0,0 +1,1021 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.clients; + +import org.apache.kafka.common.Cluster; +import org.apache.kafka.common.Node; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.Uuid; +import org.apache.kafka.common.errors.InvalidTopicException; +import org.apache.kafka.common.errors.TopicAuthorizationException; +import org.apache.kafka.common.internals.ClusterResourceListeners; +import org.apache.kafka.common.internals.Topic; +import org.apache.kafka.common.message.MetadataResponseData; +import org.apache.kafka.common.message.MetadataResponseData.MetadataResponseBrokerCollection; +import org.apache.kafka.common.message.MetadataResponseData.MetadataResponsePartition; +import org.apache.kafka.common.message.MetadataResponseData.MetadataResponseTopic; +import org.apache.kafka.common.message.MetadataResponseData.MetadataResponseTopicCollection; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.Errors; +import org.apache.kafka.common.protocol.MessageUtil; +import org.apache.kafka.common.requests.MetadataRequest; +import org.apache.kafka.common.requests.MetadataResponse; +import org.apache.kafka.common.requests.RequestTestUtils; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.test.MockClusterResourceListener; +import org.junit.jupiter.api.Test; + +import java.net.InetSocketAddress; +import java.nio.ByteBuffer; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.Set; +import java.util.concurrent.atomic.AtomicReference; + +import static org.apache.kafka.test.TestUtils.assertOptional; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class MetadataTest { + + private long refreshBackoffMs = 100; + private long metadataExpireMs = 1000; + private Metadata metadata = new Metadata(refreshBackoffMs, metadataExpireMs, new LogContext(), + new ClusterResourceListeners()); + + private static MetadataResponse emptyMetadataResponse() { + return RequestTestUtils.metadataResponse( + Collections.emptyList(), + null, + -1, + Collections.emptyList()); + } + + @Test + public void testMetadataUpdateAfterClose() { + metadata.close(); + assertThrows(IllegalStateException.class, () -> metadata.updateWithCurrentRequestVersion(emptyMetadataResponse(), false, 1000)); + } + + private static void checkTimeToNextUpdate(long refreshBackoffMs, long metadataExpireMs) { + long now = 10000; + + // Metadata timeToNextUpdate is implicitly relying on the premise that the currentTimeMillis is always + // larger than the metadataExpireMs or refreshBackoffMs. + // It won't be a problem practically since all usages of Metadata calls first update() immediately after + // it's construction. 
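+        // The guard below enforces that premise for the parameters exercised by this test.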
+ if (metadataExpireMs > now || refreshBackoffMs > now) { + throw new IllegalArgumentException( + "metadataExpireMs and refreshBackoffMs must be smaller than 'now'"); + } + + long largerOfBackoffAndExpire = Math.max(refreshBackoffMs, metadataExpireMs); + Metadata metadata = new Metadata(refreshBackoffMs, metadataExpireMs, new LogContext(), + new ClusterResourceListeners()); + + assertEquals(0, metadata.timeToNextUpdate(now)); + + // lastSuccessfulRefreshMs updated to now. + metadata.updateWithCurrentRequestVersion(emptyMetadataResponse(), false, now); + + // The last update was successful so the remaining time to expire the current metadata should be returned. + assertEquals(largerOfBackoffAndExpire, metadata.timeToNextUpdate(now)); + + // Metadata update requested explicitly + metadata.requestUpdate(); + // Update requested so metadataExpireMs should no longer take effect. + assertEquals(refreshBackoffMs, metadata.timeToNextUpdate(now)); + + // Reset needUpdate to false. + metadata.updateWithCurrentRequestVersion(emptyMetadataResponse(), false, now); + assertEquals(largerOfBackoffAndExpire, metadata.timeToNextUpdate(now)); + + // Both metadataExpireMs and refreshBackoffMs elapsed. + now += largerOfBackoffAndExpire; + assertEquals(0, metadata.timeToNextUpdate(now)); + assertEquals(0, metadata.timeToNextUpdate(now + 1)); + } + + @Test + public void testUpdateMetadataAllowedImmediatelyAfterBootstrap() { + MockTime time = new MockTime(); + + Metadata metadata = new Metadata(refreshBackoffMs, metadataExpireMs, new LogContext(), + new ClusterResourceListeners()); + metadata.bootstrap(Collections.singletonList(new InetSocketAddress("localhost", 9002))); + + assertEquals(0, metadata.timeToAllowUpdate(time.milliseconds())); + assertEquals(0, metadata.timeToNextUpdate(time.milliseconds())); + } + + @Test + public void testTimeToNextUpdate() { + checkTimeToNextUpdate(100, 1000); + checkTimeToNextUpdate(1000, 100); + checkTimeToNextUpdate(0, 0); + checkTimeToNextUpdate(0, 100); + checkTimeToNextUpdate(100, 0); + } + + @Test + public void testTimeToNextUpdateRetryBackoff() { + long now = 10000; + + // lastRefreshMs updated to now. + metadata.failedUpdate(now); + + // Backing off. Remaining time until next try should be returned. + assertEquals(refreshBackoffMs, metadata.timeToNextUpdate(now)); + + // Even though metadata update requested explicitly, still respects backoff. + metadata.requestUpdate(); + assertEquals(refreshBackoffMs, metadata.timeToNextUpdate(now)); + + // refreshBackoffMs elapsed. + now += refreshBackoffMs; + // It should return 0 to let next try. + assertEquals(0, metadata.timeToNextUpdate(now)); + assertEquals(0, metadata.timeToNextUpdate(now + 1)); + } + + /** + * Prior to Kafka version 2.4 (which coincides with Metadata version 9), the broker does not propagate leader epoch + * information accurately while a reassignment is in progress, so we cannot rely on it. This is explained in more + * detail in MetadataResponse's constructor. 
+ */ + @Test + public void testIgnoreLeaderEpochInOlderMetadataResponse() { + TopicPartition tp = new TopicPartition("topic", 0); + + MetadataResponsePartition partitionMetadata = new MetadataResponsePartition() + .setPartitionIndex(tp.partition()) + .setLeaderId(5) + .setLeaderEpoch(10) + .setReplicaNodes(Arrays.asList(1, 2, 3)) + .setIsrNodes(Arrays.asList(1, 2, 3)) + .setOfflineReplicas(Collections.emptyList()) + .setErrorCode(Errors.NONE.code()); + + MetadataResponseTopic topicMetadata = new MetadataResponseTopic() + .setName(tp.topic()) + .setErrorCode(Errors.NONE.code()) + .setPartitions(Collections.singletonList(partitionMetadata)) + .setIsInternal(false); + + MetadataResponseTopicCollection topics = new MetadataResponseTopicCollection(); + topics.add(topicMetadata); + + MetadataResponseData data = new MetadataResponseData() + .setClusterId("clusterId") + .setControllerId(0) + .setTopics(topics) + .setBrokers(new MetadataResponseBrokerCollection()); + + for (short version = ApiKeys.METADATA.oldestVersion(); version < 9; version++) { + ByteBuffer buffer = MessageUtil.toByteBuffer(data, version); + MetadataResponse response = MetadataResponse.parse(buffer, version); + assertFalse(response.hasReliableLeaderEpochs()); + metadata.updateWithCurrentRequestVersion(response, false, 100); + assertTrue(metadata.partitionMetadataIfCurrent(tp).isPresent()); + MetadataResponse.PartitionMetadata responseMetadata = this.metadata.partitionMetadataIfCurrent(tp).get(); + assertEquals(Optional.empty(), responseMetadata.leaderEpoch); + } + + for (short version = 9; version <= ApiKeys.METADATA.latestVersion(); version++) { + ByteBuffer buffer = MessageUtil.toByteBuffer(data, version); + MetadataResponse response = MetadataResponse.parse(buffer, version); + assertTrue(response.hasReliableLeaderEpochs()); + metadata.updateWithCurrentRequestVersion(response, false, 100); + assertTrue(metadata.partitionMetadataIfCurrent(tp).isPresent()); + MetadataResponse.PartitionMetadata responseMetadata = metadata.partitionMetadataIfCurrent(tp).get(); + assertEquals(Optional.of(10), responseMetadata.leaderEpoch); + } + } + + @Test + public void testStaleMetadata() { + TopicPartition tp = new TopicPartition("topic", 0); + + MetadataResponsePartition partitionMetadata = new MetadataResponsePartition() + .setPartitionIndex(tp.partition()) + .setLeaderId(1) + .setLeaderEpoch(10) + .setReplicaNodes(Arrays.asList(1, 2, 3)) + .setIsrNodes(Arrays.asList(1, 2, 3)) + .setOfflineReplicas(Collections.emptyList()) + .setErrorCode(Errors.NONE.code()); + + MetadataResponseTopic topicMetadata = new MetadataResponseTopic() + .setName(tp.topic()) + .setErrorCode(Errors.NONE.code()) + .setPartitions(Collections.singletonList(partitionMetadata)) + .setIsInternal(false); + + MetadataResponseTopicCollection topics = new MetadataResponseTopicCollection(); + topics.add(topicMetadata); + + MetadataResponseData data = new MetadataResponseData() + .setClusterId("clusterId") + .setControllerId(0) + .setTopics(topics) + .setBrokers(new MetadataResponseBrokerCollection()); + + metadata.updateWithCurrentRequestVersion(new MetadataResponse(data, ApiKeys.METADATA.latestVersion()), false, 100); + + // Older epoch with changed ISR should be ignored + partitionMetadata + .setPartitionIndex(tp.partition()) + .setLeaderId(1) + .setLeaderEpoch(9) + .setReplicaNodes(Arrays.asList(1, 2, 3)) + .setIsrNodes(Arrays.asList(1, 2)) + .setOfflineReplicas(Collections.emptyList()) + .setErrorCode(Errors.NONE.code()); + + metadata.updateWithCurrentRequestVersion(new 
MetadataResponse(data, ApiKeys.METADATA.latestVersion()), false, 101); + assertEquals(Optional.of(10), metadata.lastSeenLeaderEpoch(tp)); + + assertTrue(metadata.partitionMetadataIfCurrent(tp).isPresent()); + MetadataResponse.PartitionMetadata responseMetadata = this.metadata.partitionMetadataIfCurrent(tp).get(); + + assertEquals(Arrays.asList(1, 2, 3), responseMetadata.inSyncReplicaIds); + assertEquals(Optional.of(10), responseMetadata.leaderEpoch); + } + + @Test + public void testFailedUpdate() { + long time = 100; + metadata.updateWithCurrentRequestVersion(emptyMetadataResponse(), false, time); + + assertEquals(100, metadata.timeToNextUpdate(1000)); + metadata.failedUpdate(1100); + + assertEquals(100, metadata.timeToNextUpdate(1100)); + assertEquals(100, metadata.lastSuccessfulUpdate()); + + metadata.updateWithCurrentRequestVersion(emptyMetadataResponse(), false, time); + assertEquals(100, metadata.timeToNextUpdate(1000)); + } + + @Test + public void testClusterListenerGetsNotifiedOfUpdate() { + MockClusterResourceListener mockClusterListener = new MockClusterResourceListener(); + ClusterResourceListeners listeners = new ClusterResourceListeners(); + listeners.maybeAdd(mockClusterListener); + metadata = new Metadata(refreshBackoffMs, metadataExpireMs, new LogContext(), listeners); + + String hostName = "www.example.com"; + metadata.bootstrap(Collections.singletonList(new InetSocketAddress(hostName, 9002))); + assertFalse(MockClusterResourceListener.IS_ON_UPDATE_CALLED.get(), + "ClusterResourceListener should not called when metadata is updated with bootstrap Cluster"); + + Map partitionCounts = new HashMap<>(); + partitionCounts.put("topic", 1); + partitionCounts.put("topic1", 1); + MetadataResponse metadataResponse = RequestTestUtils.metadataUpdateWith("dummy", 1, partitionCounts); + metadata.updateWithCurrentRequestVersion(metadataResponse, false, 100); + + assertEquals("dummy", mockClusterListener.clusterResource().clusterId(), + "MockClusterResourceListener did not get cluster metadata correctly"); + assertTrue(MockClusterResourceListener.IS_ON_UPDATE_CALLED.get(), + "MockClusterResourceListener should be called when metadata is updated with non-bootstrap Cluster"); + } + + @Test + public void testRequestUpdate() { + assertFalse(metadata.updateRequested()); + + int[] epochs = {42, 42, 41, 41, 42, 43, 43, 42, 41, 44}; + boolean[] updateResult = {true, false, false, false, false, true, false, false, false, true}; + TopicPartition tp = new TopicPartition("topic", 0); + + MetadataResponse metadataResponse = RequestTestUtils.metadataUpdateWith("dummy", 1, + Collections.emptyMap(), Collections.singletonMap("topic", 1), _tp -> 0); + metadata.updateWithCurrentRequestVersion(metadataResponse, false, 10L); + + for (int i = 0; i < epochs.length; i++) { + metadata.updateLastSeenEpochIfNewer(tp, epochs[i]); + if (updateResult[i]) { + assertTrue(metadata.updateRequested(), "Expected metadata update to be requested [" + i + "]"); + } else { + assertFalse(metadata.updateRequested(), "Did not expect metadata update to be requested [" + i + "]"); + } + metadata.updateWithCurrentRequestVersion(emptyMetadataResponse(), false, 0L); + assertFalse(metadata.updateRequested()); + } + } + + @Test + public void testUpdateLastEpoch() { + TopicPartition tp = new TopicPartition("topic-1", 0); + + MetadataResponse metadataResponse = emptyMetadataResponse(); + metadata.updateWithCurrentRequestVersion(metadataResponse, false, 0L); + + // if we have no leader epoch, this call shouldn't do anything + 
assertFalse(metadata.updateLastSeenEpochIfNewer(tp, 0)); + assertFalse(metadata.updateLastSeenEpochIfNewer(tp, 1)); + assertFalse(metadata.updateLastSeenEpochIfNewer(tp, 2)); + assertFalse(metadata.lastSeenLeaderEpoch(tp).isPresent()); + + // Metadata with newer epoch is handled + metadataResponse = RequestTestUtils.metadataUpdateWith("dummy", 1, Collections.emptyMap(), Collections.singletonMap("topic-1", 1), _tp -> 10); + metadata.updateWithCurrentRequestVersion(metadataResponse, false, 1L); + assertOptional(metadata.lastSeenLeaderEpoch(tp), leaderAndEpoch -> assertEquals(leaderAndEpoch.intValue(), 10)); + + // Don't update to an older one + assertFalse(metadata.updateLastSeenEpochIfNewer(tp, 1)); + assertOptional(metadata.lastSeenLeaderEpoch(tp), leaderAndEpoch -> assertEquals(leaderAndEpoch.intValue(), 10)); + + // Don't cause update if it's the same one + assertFalse(metadata.updateLastSeenEpochIfNewer(tp, 10)); + assertOptional(metadata.lastSeenLeaderEpoch(tp), leaderAndEpoch -> assertEquals(leaderAndEpoch.intValue(), 10)); + + // Update if we see newer epoch + assertTrue(metadata.updateLastSeenEpochIfNewer(tp, 12)); + assertOptional(metadata.lastSeenLeaderEpoch(tp), leaderAndEpoch -> assertEquals(leaderAndEpoch.intValue(), 12)); + + metadataResponse = RequestTestUtils.metadataUpdateWith("dummy", 1, Collections.emptyMap(), Collections.singletonMap("topic-1", 1), _tp -> 12); + metadata.updateWithCurrentRequestVersion(metadataResponse, false, 2L); + assertOptional(metadata.lastSeenLeaderEpoch(tp), leaderAndEpoch -> assertEquals(leaderAndEpoch.intValue(), 12)); + + // Don't overwrite metadata with older epoch + metadataResponse = RequestTestUtils.metadataUpdateWith("dummy", 1, Collections.emptyMap(), Collections.singletonMap("topic-1", 1), _tp -> 11); + metadata.updateWithCurrentRequestVersion(metadataResponse, false, 3L); + assertOptional(metadata.lastSeenLeaderEpoch(tp), leaderAndEpoch -> assertEquals(leaderAndEpoch.intValue(), 12)); + } + + @Test + public void testEpochUpdateAfterTopicDeletion() { + TopicPartition tp = new TopicPartition("topic-1", 0); + + MetadataResponse metadataResponse = emptyMetadataResponse(); + metadata.updateWithCurrentRequestVersion(metadataResponse, false, 0L); + + // Start with a Topic topic-1 with a random topic ID + Map topicIds = Collections.singletonMap("topic-1", Uuid.randomUuid()); + metadataResponse = RequestTestUtils.metadataUpdateWithIds("dummy", 1, Collections.emptyMap(), Collections.singletonMap("topic-1", 1), _tp -> 10, topicIds); + metadata.updateWithCurrentRequestVersion(metadataResponse, false, 1L); + assertEquals(Optional.of(10), metadata.lastSeenLeaderEpoch(tp)); + + // Topic topic-1 is now deleted so Response contains an Error. LeaderEpoch should still maintain Old value + metadataResponse = RequestTestUtils.metadataUpdateWith("dummy", 1, Collections.singletonMap("topic-1", Errors.UNKNOWN_TOPIC_OR_PARTITION), Collections.emptyMap()); + metadata.updateWithCurrentRequestVersion(metadataResponse, false, 1L); + assertEquals(Optional.of(10), metadata.lastSeenLeaderEpoch(tp)); + + // Create topic-1 again but this time with a different topic ID. LeaderEpoch should be updated to new even if lower. 
+        Map<String, Uuid> newTopicIds = Collections.singletonMap("topic-1", Uuid.randomUuid());
+        metadataResponse = RequestTestUtils.metadataUpdateWithIds("dummy", 1, Collections.emptyMap(), Collections.singletonMap("topic-1", 1), _tp -> 5, newTopicIds);
+        metadata.updateWithCurrentRequestVersion(metadataResponse, false, 1L);
+        assertEquals(Optional.of(5), metadata.lastSeenLeaderEpoch(tp));
+    }
+
+    @Test
+    public void testEpochUpdateOnChangedTopicIds() {
+        TopicPartition tp = new TopicPartition("topic-1", 0);
+        Map<String, Uuid> topicIds = Collections.singletonMap("topic-1", Uuid.randomUuid());
+
+        MetadataResponse metadataResponse = emptyMetadataResponse();
+        metadata.updateWithCurrentRequestVersion(metadataResponse, false, 0L);
+
+        // Start with a topic with no topic ID
+        metadataResponse = RequestTestUtils.metadataUpdateWith("dummy", 1, Collections.emptyMap(), Collections.singletonMap("topic-1", 1), _tp -> 100);
+        metadata.updateWithCurrentRequestVersion(metadataResponse, false, 1L);
+        assertEquals(Optional.of(100), metadata.lastSeenLeaderEpoch(tp));
+
+        // If the older topic ID is null, we should go with the new topic ID as the leader epoch
+        metadataResponse = RequestTestUtils.metadataUpdateWithIds("dummy", 1, Collections.emptyMap(), Collections.singletonMap("topic-1", 1), _tp -> 10, topicIds);
+        metadata.updateWithCurrentRequestVersion(metadataResponse, false, 2L);
+        assertEquals(Optional.of(10), metadata.lastSeenLeaderEpoch(tp));
+
+        // Don't cause update if it's the same one
+        metadataResponse = RequestTestUtils.metadataUpdateWithIds("dummy", 1, Collections.emptyMap(), Collections.singletonMap("topic-1", 1), _tp -> 10, topicIds);
+        metadata.updateWithCurrentRequestVersion(metadataResponse, false, 3L);
+        assertEquals(Optional.of(10), metadata.lastSeenLeaderEpoch(tp));
+
+        // Update if we see newer epoch
+        metadataResponse = RequestTestUtils.metadataUpdateWithIds("dummy", 1, Collections.emptyMap(), Collections.singletonMap("topic-1", 1), _tp -> 12, topicIds);
+        metadata.updateWithCurrentRequestVersion(metadataResponse, false, 4L);
+        assertEquals(Optional.of(12), metadata.lastSeenLeaderEpoch(tp));
+
+        // We should also update if we see a new topicId even if the epoch is lower
+        Map<String, Uuid> newTopicIds = Collections.singletonMap("topic-1", Uuid.randomUuid());
+        metadataResponse = RequestTestUtils.metadataUpdateWithIds("dummy", 1, Collections.emptyMap(), Collections.singletonMap("topic-1", 1), _tp -> 3, newTopicIds);
+        metadata.updateWithCurrentRequestVersion(metadataResponse, false, 5L);
+        assertEquals(Optional.of(3), metadata.lastSeenLeaderEpoch(tp));
+
+        // Finally, update when the topic ID is new and the epoch is higher
+        Map<String, Uuid> newTopicIds2 = Collections.singletonMap("topic-1", Uuid.randomUuid());
+        metadataResponse = RequestTestUtils.metadataUpdateWithIds("dummy", 1, Collections.emptyMap(), Collections.singletonMap("topic-1", 1), _tp -> 20, newTopicIds2);
+        metadata.updateWithCurrentRequestVersion(metadataResponse, false, 6L);
+        assertEquals(Optional.of(20), metadata.lastSeenLeaderEpoch(tp));
+    }
+
+    @Test
+    public void testRejectOldMetadata() {
+        Map<String, Integer> partitionCounts = new HashMap<>();
+        partitionCounts.put("topic-1", 1);
+        TopicPartition tp = new TopicPartition("topic-1", 0);
+
+        metadata.updateWithCurrentRequestVersion(emptyMetadataResponse(), false, 0L);
+
+        // First epoch seen, accept it
+        {
+            MetadataResponse metadataResponse = RequestTestUtils.metadataUpdateWith("dummy", 1, Collections.emptyMap(), partitionCounts, _tp -> 100);
+            metadata.updateWithCurrentRequestVersion(metadataResponse, false, 10L);
+
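+            // Epoch 100 is now the last-seen leader epoch; the following blocks verify that responses carrying older epochs cannot overwrite it.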
assertNotNull(metadata.fetch().partition(tp)); + assertTrue(metadata.lastSeenLeaderEpoch(tp).isPresent()); + assertEquals(metadata.lastSeenLeaderEpoch(tp).get().longValue(), 100); + } + + // Fake an empty ISR, but with an older epoch, should reject it + { + MetadataResponse metadataResponse = RequestTestUtils.metadataUpdateWith("dummy", 1, Collections.emptyMap(), partitionCounts, _tp -> 99, + (error, partition, leader, leaderEpoch, replicas, isr, offlineReplicas) -> + new MetadataResponse.PartitionMetadata(error, partition, leader, + leaderEpoch, replicas, Collections.emptyList(), offlineReplicas), ApiKeys.METADATA.latestVersion(), Collections.emptyMap()); + metadata.updateWithCurrentRequestVersion(metadataResponse, false, 20L); + assertEquals(metadata.fetch().partition(tp).inSyncReplicas().length, 1); + assertEquals(metadata.lastSeenLeaderEpoch(tp).get().longValue(), 100); + } + + // Fake an empty ISR, with same epoch, accept it + { + MetadataResponse metadataResponse = RequestTestUtils.metadataUpdateWith("dummy", 1, Collections.emptyMap(), partitionCounts, _tp -> 100, + (error, partition, leader, leaderEpoch, replicas, isr, offlineReplicas) -> + new MetadataResponse.PartitionMetadata(error, partition, leader, + leaderEpoch, replicas, Collections.emptyList(), offlineReplicas), ApiKeys.METADATA.latestVersion(), Collections.emptyMap()); + metadata.updateWithCurrentRequestVersion(metadataResponse, false, 20L); + assertEquals(metadata.fetch().partition(tp).inSyncReplicas().length, 0); + assertEquals(metadata.lastSeenLeaderEpoch(tp).get().longValue(), 100); + } + + // Empty metadata response, should not keep old partition but should keep the last-seen epoch + { + MetadataResponse metadataResponse = RequestTestUtils.metadataUpdateWith("dummy", 1, Collections.emptyMap(), Collections.emptyMap()); + metadata.updateWithCurrentRequestVersion(metadataResponse, false, 20L); + assertNull(metadata.fetch().partition(tp)); + assertEquals(metadata.lastSeenLeaderEpoch(tp).get().longValue(), 100); + } + + // Back in the metadata, with old epoch, should not get added + { + MetadataResponse metadataResponse = RequestTestUtils.metadataUpdateWith("dummy", 1, Collections.emptyMap(), partitionCounts, _tp -> 99); + metadata.updateWithCurrentRequestVersion(metadataResponse, false, 10L); + assertNull(metadata.fetch().partition(tp)); + assertEquals(metadata.lastSeenLeaderEpoch(tp).get().longValue(), 100); + } + } + + @Test + public void testOutOfBandEpochUpdate() { + Map partitionCounts = new HashMap<>(); + partitionCounts.put("topic-1", 5); + TopicPartition tp = new TopicPartition("topic-1", 0); + + metadata.updateWithCurrentRequestVersion(emptyMetadataResponse(), false, 0L); + + assertFalse(metadata.updateLastSeenEpochIfNewer(tp, 99)); + + // Update epoch to 100 + MetadataResponse metadataResponse = RequestTestUtils.metadataUpdateWith("dummy", 1, Collections.emptyMap(), partitionCounts, _tp -> 100); + metadata.updateWithCurrentRequestVersion(metadataResponse, false, 10L); + assertNotNull(metadata.fetch().partition(tp)); + assertTrue(metadata.lastSeenLeaderEpoch(tp).isPresent()); + assertEquals(metadata.lastSeenLeaderEpoch(tp).get().longValue(), 100); + + // Simulate a leader epoch from another response, like a fetch response or list offsets + assertTrue(metadata.updateLastSeenEpochIfNewer(tp, 101)); + + // Cache of partition stays, but current partition info is not available since it's stale + assertNotNull(metadata.fetch().partition(tp)); + 
assertEquals(Objects.requireNonNull(metadata.fetch().partitionCountForTopic("topic-1")).longValue(), 5); + assertFalse(metadata.partitionMetadataIfCurrent(tp).isPresent()); + assertEquals(metadata.lastSeenLeaderEpoch(tp).get().longValue(), 101); + + // Metadata with older epoch is rejected, metadata state is unchanged + metadata.updateWithCurrentRequestVersion(metadataResponse, false, 20L); + assertNotNull(metadata.fetch().partition(tp)); + assertEquals(Objects.requireNonNull(metadata.fetch().partitionCountForTopic("topic-1")).longValue(), 5); + assertFalse(metadata.partitionMetadataIfCurrent(tp).isPresent()); + assertEquals(metadata.lastSeenLeaderEpoch(tp).get().longValue(), 101); + + // Metadata with equal or newer epoch is accepted + metadataResponse = RequestTestUtils.metadataUpdateWith("dummy", 1, Collections.emptyMap(), partitionCounts, _tp -> 101); + metadata.updateWithCurrentRequestVersion(metadataResponse, false, 30L); + assertNotNull(metadata.fetch().partition(tp)); + assertEquals(Objects.requireNonNull(metadata.fetch().partitionCountForTopic("topic-1")).longValue(), 5); + assertTrue(metadata.partitionMetadataIfCurrent(tp).isPresent()); + assertEquals(metadata.lastSeenLeaderEpoch(tp).get().longValue(), 101); + } + + @Test + public void testNoEpoch() { + metadata.updateWithCurrentRequestVersion(emptyMetadataResponse(), false, 0L); + MetadataResponse metadataResponse = RequestTestUtils.metadataUpdateWith("dummy", 1, Collections.emptyMap(), Collections.singletonMap("topic-1", 1)); + metadata.updateWithCurrentRequestVersion(metadataResponse, false, 10L); + + TopicPartition tp = new TopicPartition("topic-1", 0); + + // no epoch + assertFalse(metadata.lastSeenLeaderEpoch(tp).isPresent()); + + // still works + assertTrue(metadata.partitionMetadataIfCurrent(tp).isPresent()); + assertEquals(0, metadata.partitionMetadataIfCurrent(tp).get().partition()); + assertEquals(Optional.of(0), metadata.partitionMetadataIfCurrent(tp).get().leaderId); + + // Since epoch was null, this shouldn't update it + metadata.updateLastSeenEpochIfNewer(tp, 10); + assertTrue(metadata.partitionMetadataIfCurrent(tp).isPresent()); + assertFalse(metadata.partitionMetadataIfCurrent(tp).get().leaderEpoch.isPresent()); + } + + @Test + public void testClusterCopy() { + Map counts = new HashMap<>(); + Map errors = new HashMap<>(); + counts.put("topic1", 2); + counts.put("topic2", 3); + counts.put(Topic.GROUP_METADATA_TOPIC_NAME, 3); + errors.put("topic3", Errors.INVALID_TOPIC_EXCEPTION); + errors.put("topic4", Errors.TOPIC_AUTHORIZATION_FAILED); + + MetadataResponse metadataResponse = RequestTestUtils.metadataUpdateWith("dummy", 4, errors, counts); + metadata.updateWithCurrentRequestVersion(metadataResponse, false, 0L); + + Cluster cluster = metadata.fetch(); + assertEquals(cluster.clusterResource().clusterId(), "dummy"); + assertEquals(cluster.nodes().size(), 4); + + // topic counts + assertEquals(cluster.invalidTopics(), Collections.singleton("topic3")); + assertEquals(cluster.unauthorizedTopics(), Collections.singleton("topic4")); + assertEquals(cluster.topics().size(), 3); + assertEquals(cluster.internalTopics(), Collections.singleton(Topic.GROUP_METADATA_TOPIC_NAME)); + + // partition counts + assertEquals(cluster.partitionsForTopic("topic1").size(), 2); + assertEquals(cluster.partitionsForTopic("topic2").size(), 3); + + // Sentinel instances + InetSocketAddress address = InetSocketAddress.createUnresolved("localhost", 0); + Cluster fromMetadata = MetadataCache.bootstrap(Collections.singletonList(address)).cluster(); 
+ Cluster fromCluster = Cluster.bootstrap(Collections.singletonList(address)); + assertEquals(fromMetadata, fromCluster); + + Cluster fromMetadataEmpty = MetadataCache.empty().cluster(); + Cluster fromClusterEmpty = Cluster.empty(); + assertEquals(fromMetadataEmpty, fromClusterEmpty); + } + + @Test + public void testRequestVersion() { + Time time = new MockTime(); + + metadata.requestUpdate(); + Metadata.MetadataRequestAndVersion versionAndBuilder = metadata.newMetadataRequestAndVersion(time.milliseconds()); + metadata.update(versionAndBuilder.requestVersion, + RequestTestUtils.metadataUpdateWith(1, Collections.singletonMap("topic", 1)), false, time.milliseconds()); + assertFalse(metadata.updateRequested()); + + // bump the request version for new topics added to the metadata + metadata.requestUpdateForNewTopics(); + + // simulating a bump while a metadata request is in flight + versionAndBuilder = metadata.newMetadataRequestAndVersion(time.milliseconds()); + metadata.requestUpdateForNewTopics(); + metadata.update(versionAndBuilder.requestVersion, + RequestTestUtils.metadataUpdateWith(1, Collections.singletonMap("topic", 1)), true, time.milliseconds()); + + // metadata update is still needed + assertTrue(metadata.updateRequested()); + + // the next update will resolve it + versionAndBuilder = metadata.newMetadataRequestAndVersion(time.milliseconds()); + metadata.update(versionAndBuilder.requestVersion, + RequestTestUtils.metadataUpdateWith(1, Collections.singletonMap("topic", 1)), true, time.milliseconds()); + assertFalse(metadata.updateRequested()); + } + + @Test + public void testPartialMetadataUpdate() { + Time time = new MockTime(); + + metadata = new Metadata(refreshBackoffMs, metadataExpireMs, new LogContext(), new ClusterResourceListeners()) { + @Override + protected MetadataRequest.Builder newMetadataRequestBuilderForNewTopics() { + return newMetadataRequestBuilder(); + } + }; + + assertFalse(metadata.updateRequested()); + + // Request a metadata update. This must force a full metadata update request. + metadata.requestUpdate(); + Metadata.MetadataRequestAndVersion versionAndBuilder = metadata.newMetadataRequestAndVersion(time.milliseconds()); + assertFalse(versionAndBuilder.isPartialUpdate); + metadata.update(versionAndBuilder.requestVersion, + RequestTestUtils.metadataUpdateWith(1, Collections.singletonMap("topic", 1)), false, time.milliseconds()); + assertFalse(metadata.updateRequested()); + + // Request a metadata update for a new topic. This should perform a partial metadata update. + metadata.requestUpdateForNewTopics(); + versionAndBuilder = metadata.newMetadataRequestAndVersion(time.milliseconds()); + assertTrue(versionAndBuilder.isPartialUpdate); + metadata.update(versionAndBuilder.requestVersion, + RequestTestUtils.metadataUpdateWith(1, Collections.singletonMap("topic", 1)), true, time.milliseconds()); + assertFalse(metadata.updateRequested()); + + // Request both types of metadata updates. This should always perform a full update. + metadata.requestUpdate(); + metadata.requestUpdateForNewTopics(); + versionAndBuilder = metadata.newMetadataRequestAndVersion(time.milliseconds()); + assertFalse(versionAndBuilder.isPartialUpdate); + metadata.update(versionAndBuilder.requestVersion, + RequestTestUtils.metadataUpdateWith(1, Collections.singletonMap("topic", 1)), false, time.milliseconds()); + assertFalse(metadata.updateRequested()); + + // Request only a partial metadata update, but elapse enough time such that a full refresh is needed. 
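+        // (advancing the request time past metadataExpireMs below is what should upgrade the partial request to a full one)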
+ metadata.requestUpdateForNewTopics(); + final long refreshTimeMs = time.milliseconds() + metadata.metadataExpireMs(); + versionAndBuilder = metadata.newMetadataRequestAndVersion(refreshTimeMs); + assertFalse(versionAndBuilder.isPartialUpdate); + metadata.update(versionAndBuilder.requestVersion, + RequestTestUtils.metadataUpdateWith(1, Collections.singletonMap("topic", 1)), true, refreshTimeMs); + assertFalse(metadata.updateRequested()); + + // Request two partial metadata updates that are overlapping. + metadata.requestUpdateForNewTopics(); + versionAndBuilder = metadata.newMetadataRequestAndVersion(time.milliseconds()); + assertTrue(versionAndBuilder.isPartialUpdate); + metadata.requestUpdateForNewTopics(); + Metadata.MetadataRequestAndVersion overlappingVersionAndBuilder = metadata.newMetadataRequestAndVersion(time.milliseconds()); + assertTrue(overlappingVersionAndBuilder.isPartialUpdate); + assertTrue(metadata.updateRequested()); + metadata.update(versionAndBuilder.requestVersion, + RequestTestUtils.metadataUpdateWith(1, Collections.singletonMap("topic-1", 1)), true, time.milliseconds()); + assertTrue(metadata.updateRequested()); + metadata.update(overlappingVersionAndBuilder.requestVersion, + RequestTestUtils.metadataUpdateWith(1, Collections.singletonMap("topic-2", 1)), true, time.milliseconds()); + assertFalse(metadata.updateRequested()); + } + + @Test + public void testInvalidTopicError() { + Time time = new MockTime(); + + String invalidTopic = "topic dfsa"; + MetadataResponse invalidTopicResponse = RequestTestUtils.metadataUpdateWith("clusterId", 1, + Collections.singletonMap(invalidTopic, Errors.INVALID_TOPIC_EXCEPTION), Collections.emptyMap()); + metadata.updateWithCurrentRequestVersion(invalidTopicResponse, false, time.milliseconds()); + + InvalidTopicException e = assertThrows(InvalidTopicException.class, () -> metadata.maybeThrowAnyException()); + + assertEquals(Collections.singleton(invalidTopic), e.invalidTopics()); + // We clear the exception once it has been raised to the user + metadata.maybeThrowAnyException(); + + // Reset the invalid topic error + metadata.updateWithCurrentRequestVersion(invalidTopicResponse, false, time.milliseconds()); + + // If we get a good update, the error should clear even if we haven't had a chance to raise it to the user + metadata.updateWithCurrentRequestVersion(emptyMetadataResponse(), false, time.milliseconds()); + metadata.maybeThrowAnyException(); + } + + @Test + public void testTopicAuthorizationError() { + Time time = new MockTime(); + + String invalidTopic = "foo"; + MetadataResponse unauthorizedTopicResponse = RequestTestUtils.metadataUpdateWith("clusterId", 1, + Collections.singletonMap(invalidTopic, Errors.TOPIC_AUTHORIZATION_FAILED), Collections.emptyMap()); + metadata.updateWithCurrentRequestVersion(unauthorizedTopicResponse, false, time.milliseconds()); + + TopicAuthorizationException e = assertThrows(TopicAuthorizationException.class, () -> metadata.maybeThrowAnyException()); + assertEquals(Collections.singleton(invalidTopic), e.unauthorizedTopics()); + // We clear the exception once it has been raised to the user + metadata.maybeThrowAnyException(); + + // Reset the unauthorized topic error + metadata.updateWithCurrentRequestVersion(unauthorizedTopicResponse, false, time.milliseconds()); + + // If we get a good update, the error should clear even if we haven't had a chance to raise it to the user + metadata.updateWithCurrentRequestVersion(emptyMetadataResponse(), false, time.milliseconds()); + 
metadata.maybeThrowAnyException(); + } + + @Test + public void testMetadataTopicErrors() { + Time time = new MockTime(); + + Map topicErrors = new HashMap<>(3); + topicErrors.put("invalidTopic", Errors.INVALID_TOPIC_EXCEPTION); + topicErrors.put("sensitiveTopic1", Errors.TOPIC_AUTHORIZATION_FAILED); + topicErrors.put("sensitiveTopic2", Errors.TOPIC_AUTHORIZATION_FAILED); + MetadataResponse metadataResponse = RequestTestUtils.metadataUpdateWith("clusterId", 1, topicErrors, Collections.emptyMap()); + + metadata.updateWithCurrentRequestVersion(metadataResponse, false, time.milliseconds()); + TopicAuthorizationException e1 = assertThrows(TopicAuthorizationException.class, + () -> metadata.maybeThrowExceptionForTopic("sensitiveTopic1")); + assertEquals(Collections.singleton("sensitiveTopic1"), e1.unauthorizedTopics()); + // We clear the exception once it has been raised to the user + metadata.maybeThrowAnyException(); + + metadata.updateWithCurrentRequestVersion(metadataResponse, false, time.milliseconds()); + TopicAuthorizationException e2 = assertThrows(TopicAuthorizationException.class, + () -> metadata.maybeThrowExceptionForTopic("sensitiveTopic2")); + assertEquals(Collections.singleton("sensitiveTopic2"), e2.unauthorizedTopics()); + metadata.maybeThrowAnyException(); + + metadata.updateWithCurrentRequestVersion(metadataResponse, false, time.milliseconds()); + InvalidTopicException e3 = assertThrows(InvalidTopicException.class, + () -> metadata.maybeThrowExceptionForTopic("invalidTopic")); + assertEquals(Collections.singleton("invalidTopic"), e3.invalidTopics()); + metadata.maybeThrowAnyException(); + + // Other topics should not throw exception, but they should clear existing exception + metadata.updateWithCurrentRequestVersion(metadataResponse, false, time.milliseconds()); + metadata.maybeThrowExceptionForTopic("anotherTopic"); + metadata.maybeThrowAnyException(); + } + + @Test + public void testNodeIfOffline() { + Map partitionCounts = new HashMap<>(); + partitionCounts.put("topic-1", 1); + Node node0 = new Node(0, "localhost", 9092); + Node node1 = new Node(1, "localhost", 9093); + + MetadataResponse metadataResponse = RequestTestUtils.metadataUpdateWith("dummy", 2, Collections.emptyMap(), partitionCounts, _tp -> 99, + (error, partition, leader, leaderEpoch, replicas, isr, offlineReplicas) -> + new MetadataResponse.PartitionMetadata(error, partition, Optional.of(node0.id()), leaderEpoch, + Collections.singletonList(node0.id()), Collections.emptyList(), + Collections.singletonList(node1.id())), ApiKeys.METADATA.latestVersion(), Collections.emptyMap()); + metadata.updateWithCurrentRequestVersion(emptyMetadataResponse(), false, 0L); + metadata.updateWithCurrentRequestVersion(metadataResponse, false, 10L); + + TopicPartition tp = new TopicPartition("topic-1", 0); + + assertOptional(metadata.fetch().nodeIfOnline(tp, 0), node -> assertEquals(node.id(), 0)); + assertFalse(metadata.fetch().nodeIfOnline(tp, 1).isPresent()); + assertEquals(metadata.fetch().nodeById(0).id(), 0); + assertEquals(metadata.fetch().nodeById(1).id(), 1); + } + + @Test + public void testLeaderMetadataInconsistentWithBrokerMetadata() { + // Tests a reordering scenario which can lead to inconsistent leader state. + // A partition initially has one broker offline. That broker comes online and + // is elected leader. The client sees these two events in the opposite order. 
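+        // The higher epoch (10) should win, which leaves the leader unresolved because broker 0 is absent from the second response.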
+ + TopicPartition tp = new TopicPartition("topic", 0); + + Node node0 = new Node(0, "localhost", 9092); + Node node1 = new Node(1, "localhost", 9093); + Node node2 = new Node(2, "localhost", 9094); + + // The first metadata received by broker (epoch=10) + MetadataResponsePartition firstPartitionMetadata = new MetadataResponsePartition() + .setPartitionIndex(tp.partition()) + .setErrorCode(Errors.NONE.code()) + .setLeaderEpoch(10) + .setLeaderId(0) + .setReplicaNodes(Arrays.asList(0, 1, 2)) + .setIsrNodes(Arrays.asList(0, 1, 2)) + .setOfflineReplicas(Collections.emptyList()); + + // The second metadata received has stale metadata (epoch=8) + MetadataResponsePartition secondPartitionMetadata = new MetadataResponsePartition() + .setPartitionIndex(tp.partition()) + .setErrorCode(Errors.NONE.code()) + .setLeaderEpoch(8) + .setLeaderId(1) + .setReplicaNodes(Arrays.asList(0, 1, 2)) + .setIsrNodes(Arrays.asList(1, 2)) + .setOfflineReplicas(Collections.singletonList(0)); + + metadata.updateWithCurrentRequestVersion(new MetadataResponse(new MetadataResponseData() + .setTopics(buildTopicCollection(tp.topic(), firstPartitionMetadata)) + .setBrokers(buildBrokerCollection(Arrays.asList(node0, node1, node2))), + ApiKeys.METADATA.latestVersion()), + false, 10L); + + metadata.updateWithCurrentRequestVersion(new MetadataResponse(new MetadataResponseData() + .setTopics(buildTopicCollection(tp.topic(), secondPartitionMetadata)) + .setBrokers(buildBrokerCollection(Arrays.asList(node1, node2))), + ApiKeys.METADATA.latestVersion()), + false, 20L); + + assertNull(metadata.fetch().leaderFor(tp)); + assertEquals(Optional.of(10), metadata.lastSeenLeaderEpoch(tp)); + assertFalse(metadata.currentLeader(tp).leader.isPresent()); + } + + private MetadataResponseTopicCollection buildTopicCollection(String topic, MetadataResponsePartition partitionMetadata) { + MetadataResponseTopic topicMetadata = new MetadataResponseTopic() + .setErrorCode(Errors.NONE.code()) + .setName(topic) + .setIsInternal(false); + + topicMetadata.setPartitions(Collections.singletonList(partitionMetadata)); + + MetadataResponseTopicCollection topics = new MetadataResponseTopicCollection(); + topics.add(topicMetadata); + return topics; + } + + private MetadataResponseBrokerCollection buildBrokerCollection(List nodes) { + MetadataResponseBrokerCollection brokers = new MetadataResponseBrokerCollection(); + for (Node node : nodes) { + MetadataResponseData.MetadataResponseBroker broker = new MetadataResponseData.MetadataResponseBroker() + .setNodeId(node.id()) + .setHost(node.host()) + .setPort(node.port()) + .setRack(node.rack()); + brokers.add(broker); + } + return brokers; + } + + @Test + public void testMetadataMerge() { + Time time = new MockTime(); + Map topicIds = new HashMap<>(); + + final AtomicReference> retainTopics = new AtomicReference<>(new HashSet<>()); + metadata = new Metadata(refreshBackoffMs, metadataExpireMs, new LogContext(), new ClusterResourceListeners()) { + @Override + protected boolean retainTopic(String topic, boolean isInternal, long nowMs) { + return retainTopics.get().contains(topic); + } + }; + + // Initialize a metadata instance with two topic variants "old" and "keep". Both will be retained. 
+        String oldClusterId = "oldClusterId";
+        int oldNodes = 2;
+        Map<String, Errors> oldTopicErrors = new HashMap<>();
+        oldTopicErrors.put("oldInvalidTopic", Errors.INVALID_TOPIC_EXCEPTION);
+        oldTopicErrors.put("keepInvalidTopic", Errors.INVALID_TOPIC_EXCEPTION);
+        oldTopicErrors.put("oldUnauthorizedTopic", Errors.TOPIC_AUTHORIZATION_FAILED);
+        oldTopicErrors.put("keepUnauthorizedTopic", Errors.TOPIC_AUTHORIZATION_FAILED);
+        Map<String, Integer> oldTopicPartitionCounts = new HashMap<>();
+        oldTopicPartitionCounts.put("oldValidTopic", 2);
+        oldTopicPartitionCounts.put("keepValidTopic", 3);
+
+        retainTopics.set(Utils.mkSet(
+            "oldInvalidTopic",
+            "keepInvalidTopic",
+            "oldUnauthorizedTopic",
+            "keepUnauthorizedTopic",
+            "oldValidTopic",
+            "keepValidTopic"));
+
+        topicIds.put("oldValidTopic", Uuid.randomUuid());
+        topicIds.put("keepValidTopic", Uuid.randomUuid());
+        MetadataResponse metadataResponse =
+            RequestTestUtils.metadataUpdateWithIds(oldClusterId, oldNodes, oldTopicErrors, oldTopicPartitionCounts, _tp -> 100, topicIds);
+        metadata.updateWithCurrentRequestVersion(metadataResponse, true, time.milliseconds());
+        Map<String, Uuid> metadataTopicIds1 = metadata.topicIds();
+        retainTopics.get().forEach(topic -> assertEquals(metadataTopicIds1.get(topic), topicIds.get(topic)));
+
+        // Update the metadata to add a new topic variant, "new", which will be retained with "keep". Note this
+        // means that all of the "old" topics should be dropped.
+        Cluster cluster = metadata.fetch();
+        assertEquals(cluster.clusterResource().clusterId(), oldClusterId);
+        assertEquals(cluster.nodes().size(), oldNodes);
+        assertEquals(cluster.invalidTopics(), new HashSet<>(Arrays.asList("oldInvalidTopic", "keepInvalidTopic")));
+        assertEquals(cluster.unauthorizedTopics(), new HashSet<>(Arrays.asList("oldUnauthorizedTopic", "keepUnauthorizedTopic")));
+        assertEquals(cluster.topics(), new HashSet<>(Arrays.asList("oldValidTopic", "keepValidTopic")));
+        assertEquals(cluster.partitionsForTopic("oldValidTopic").size(), 2);
+        assertEquals(cluster.partitionsForTopic("keepValidTopic").size(), 3);
+        assertEquals(new HashSet<>(cluster.topicIds()), new HashSet<>(topicIds.values()));
+
+        String newClusterId = "newClusterId";
+        int newNodes = oldNodes + 1;
+        Map<String, Errors> newTopicErrors = new HashMap<>();
+        newTopicErrors.put("newInvalidTopic", Errors.INVALID_TOPIC_EXCEPTION);
+        newTopicErrors.put("newUnauthorizedTopic", Errors.TOPIC_AUTHORIZATION_FAILED);
+        Map<String, Integer> newTopicPartitionCounts = new HashMap<>();
+        newTopicPartitionCounts.put("keepValidTopic", 2);
+        newTopicPartitionCounts.put("newValidTopic", 4);
+
+        retainTopics.set(Utils.mkSet(
+            "keepInvalidTopic",
+            "newInvalidTopic",
+            "keepUnauthorizedTopic",
+            "newUnauthorizedTopic",
+            "keepValidTopic",
+            "newValidTopic"));
+
+        topicIds.put("newValidTopic", Uuid.randomUuid());
+        metadataResponse = RequestTestUtils.metadataUpdateWithIds(newClusterId, newNodes, newTopicErrors, newTopicPartitionCounts, _tp -> 200, topicIds);
+        metadata.updateWithCurrentRequestVersion(metadataResponse, true, time.milliseconds());
+        topicIds.remove("oldValidTopic");
+        Map<String, Uuid> metadataTopicIds2 = metadata.topicIds();
+        retainTopics.get().forEach(topic -> assertEquals(metadataTopicIds2.get(topic), topicIds.get(topic)));
+        assertNull(metadataTopicIds2.get("oldValidTopic"));
+
+        cluster = metadata.fetch();
+        assertEquals(cluster.clusterResource().clusterId(), newClusterId);
+        assertEquals(cluster.nodes().size(), newNodes);
+        assertEquals(cluster.invalidTopics(), new HashSet<>(Arrays.asList("keepInvalidTopic", "newInvalidTopic")));
+        assertEquals(cluster.unauthorizedTopics(), new HashSet<>(Arrays.asList("keepUnauthorizedTopic", "newUnauthorizedTopic")));
+        assertEquals(cluster.topics(), new HashSet<>(Arrays.asList("keepValidTopic", "newValidTopic")));
+        assertEquals(cluster.partitionsForTopic("keepValidTopic").size(), 2);
+        assertEquals(cluster.partitionsForTopic("newValidTopic").size(), 4);
+        assertEquals(new HashSet<>(cluster.topicIds()), new HashSet<>(topicIds.values()));
+
+        // Perform another metadata update, but this time all topic metadata should be cleared.
+        retainTopics.set(Collections.emptySet());
+
+        metadataResponse = RequestTestUtils.metadataUpdateWithIds(newClusterId, newNodes, newTopicErrors, newTopicPartitionCounts, _tp -> 300, topicIds);
+        metadata.updateWithCurrentRequestVersion(metadataResponse, true, time.milliseconds());
+        Map<String, Uuid> metadataTopicIds3 = metadata.topicIds();
+        topicIds.forEach((topicName, topicId) -> assertNull(metadataTopicIds3.get(topicName)));
+
+        cluster = metadata.fetch();
+        assertEquals(cluster.clusterResource().clusterId(), newClusterId);
+        assertEquals(cluster.nodes().size(), newNodes);
+        assertEquals(cluster.invalidTopics(), Collections.emptySet());
+        assertEquals(cluster.unauthorizedTopics(), Collections.emptySet());
+        assertEquals(cluster.topics(), Collections.emptySet());
+        assertTrue(cluster.topicIds().isEmpty());
+    }
+
+    @Test
+    public void testMetadataMergeOnIdDowngrade() {
+        Time time = new MockTime();
+        Map<String, Uuid> topicIds = new HashMap<>();
+
+        final AtomicReference<Set<String>> retainTopics = new AtomicReference<>(new HashSet<>());
+        metadata = new Metadata(refreshBackoffMs, metadataExpireMs, new LogContext(), new ClusterResourceListeners()) {
+            @Override
+            protected boolean retainTopic(String topic, boolean isInternal, long nowMs) {
+                return retainTopics.get().contains(topic);
+            }
+        };
+
+        // Initialize a metadata instance with two topics. Both will be retained.
+        String clusterId = "clusterId";
+        int nodes = 2;
+        Map<String, Integer> topicPartitionCounts = new HashMap<>();
+        topicPartitionCounts.put("validTopic1", 2);
+        topicPartitionCounts.put("validTopic2", 3);
+
+        retainTopics.set(Utils.mkSet(
+            "validTopic1",
+            "validTopic2"));
+
+        topicIds.put("validTopic1", Uuid.randomUuid());
+        topicIds.put("validTopic2", Uuid.randomUuid());
+        MetadataResponse metadataResponse =
+            RequestTestUtils.metadataUpdateWithIds(clusterId, nodes, Collections.emptyMap(), topicPartitionCounts, _tp -> 100, topicIds);
+        metadata.updateWithCurrentRequestVersion(metadataResponse, true, time.milliseconds());
+        Map<String, Uuid> metadataTopicIds1 = metadata.topicIds();
+        retainTopics.get().forEach(topic -> assertEquals(metadataTopicIds1.get(topic), topicIds.get(topic)));
+
+        // Try removing the topic ID from validTopic1 (simulating receiving a request from a controller with an older IBP)
+        topicIds.remove("validTopic1");
+        metadataResponse = RequestTestUtils.metadataUpdateWithIds(clusterId, nodes, Collections.emptyMap(), topicPartitionCounts, _tp -> 200, topicIds);
+        metadata.updateWithCurrentRequestVersion(metadataResponse, true, time.milliseconds());
+        Map<String, Uuid> metadataTopicIds2 = metadata.topicIds();
+        retainTopics.get().forEach(topic -> assertEquals(metadataTopicIds2.get(topic), topicIds.get(topic)));
+
+        Cluster cluster = metadata.fetch();
+        // We still have the topic, but it just doesn't have an ID.
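+        // A topic without an ID is reported as Uuid.ZERO_UUID in the cluster view rather than being dropped.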
+ assertEquals(Utils.mkSet("validTopic1", "validTopic2"), cluster.topics()); + assertEquals(2, cluster.partitionsForTopic("validTopic1").size()); + assertEquals(new HashSet<>(topicIds.values()), new HashSet<>(cluster.topicIds())); + assertEquals(Uuid.ZERO_UUID, cluster.topicId("validTopic1")); + } + +} diff --git a/clients/src/test/java/org/apache/kafka/clients/MockClient.java b/clients/src/test/java/org/apache/kafka/clients/MockClient.java new file mode 100644 index 0000000..28d363b --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/clients/MockClient.java @@ -0,0 +1,807 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients; + +import org.apache.kafka.common.Node; +import org.apache.kafka.common.errors.AuthenticationException; +import org.apache.kafka.common.errors.InterruptException; +import org.apache.kafka.common.errors.UnsupportedVersionException; +import org.apache.kafka.common.protocol.Errors; +import org.apache.kafka.common.requests.AbstractRequest; +import org.apache.kafka.common.requests.AbstractResponse; +import org.apache.kafka.common.requests.MetadataRequest; +import org.apache.kafka.common.requests.MetadataResponse; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.test.TestCondition; +import org.apache.kafka.test.TestUtils; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Queue; +import java.util.Set; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ConcurrentLinkedDeque; +import java.util.function.Consumer; +import java.util.stream.Collectors; + +/** + * A mock network client for use testing code + */ +public class MockClient implements KafkaClient { + public static final RequestMatcher ALWAYS_TRUE = body -> true; + + private static class FutureResponse { + private final Node node; + private final RequestMatcher requestMatcher; + private final AbstractResponse responseBody; + private final boolean disconnected; + private final boolean isUnsupportedRequest; + + public FutureResponse(Node node, + RequestMatcher requestMatcher, + AbstractResponse responseBody, + boolean disconnected, + boolean isUnsupportedRequest) { + this.node = node; + this.requestMatcher = requestMatcher; + this.responseBody = responseBody; + this.disconnected = disconnected; + this.isUnsupportedRequest = isUnsupportedRequest; + } + + } + + private int correlation; + private Runnable wakeupHook; + private final Time time; + private final MockMetadataUpdater metadataUpdater; + private final Map connections = new HashMap<>(); + private final Map pendingAuthenticationErrors = new HashMap<>(); + private final Map 
<Node, AuthenticationException> authenticationErrors = new HashMap<>();
+    // Use concurrent queue for requests so that requests may be queried from a different thread
+    private final Queue<ClientRequest> requests = new ConcurrentLinkedDeque<>();
+    // Use concurrent queue for responses so that responses may be updated during poll() from a different thread.
+    private final Queue<ClientResponse> responses = new ConcurrentLinkedDeque<>();
+    private final Queue<FutureResponse> futureResponses = new ConcurrentLinkedDeque<>();
+    private final Queue<MetadataUpdate> metadataUpdates = new ConcurrentLinkedDeque<>();
+    private volatile NodeApiVersions nodeApiVersions = NodeApiVersions.create();
+    private volatile int numBlockingWakeups = 0;
+    private volatile boolean active = true;
+    private volatile CompletableFuture<String> disconnectFuture;
+    private volatile Consumer<Node> readyCallback;
+
+    public MockClient(Time time) {
+        this(time, new NoOpMetadataUpdater());
+    }
+
+    public MockClient(Time time, Metadata metadata) {
+        this(time, new DefaultMockMetadataUpdater(metadata));
+    }
+
+    public MockClient(Time time, MockMetadataUpdater metadataUpdater) {
+        this.time = time;
+        this.metadataUpdater = metadataUpdater;
+    }
+
+    public boolean isConnected(String idString) {
+        return connectionState(idString).state == ConnectionState.State.CONNECTED;
+    }
+
+    private ConnectionState connectionState(String idString) {
+        ConnectionState connectionState = connections.get(idString);
+        if (connectionState == null) {
+            connectionState = new ConnectionState();
+            connections.put(idString, connectionState);
+        }
+        return connectionState;
+    }
+
+    @Override
+    public boolean isReady(Node node, long now) {
+        return connectionState(node.idString()).isReady(now);
+    }
+
+    @Override
+    public boolean ready(Node node, long now) {
+        if (readyCallback != null) {
+            readyCallback.accept(node);
+        }
+        return connectionState(node.idString()).ready(now);
+    }
+
+    @Override
+    public long connectionDelay(Node node, long now) {
+        return connectionState(node.idString()).connectionDelay(now);
+    }
+
+    @Override
+    public long pollDelayMs(Node node, long now) {
+        return connectionDelay(node, now);
+    }
+
+    public void backoff(Node node, long durationMs) {
+        connectionState(node.idString()).backoff(time.milliseconds() + durationMs);
+    }
+
+    public void setUnreachable(Node node, long durationMs) {
+        disconnect(node.idString());
+        connectionState(node.idString()).setUnreachable(time.milliseconds() + durationMs);
+    }
+
+    public void throttle(Node node, long durationMs) {
+        connectionState(node.idString()).throttle(time.milliseconds() + durationMs);
+    }
+
+    public void delayReady(Node node, long durationMs) {
+        connectionState(node.idString()).setReadyDelayed(time.milliseconds() + durationMs);
+    }
+
+    public void authenticationFailed(Node node, long backoffMs) {
+        pendingAuthenticationErrors.remove(node);
+        authenticationErrors.put(node, (AuthenticationException) Errors.SASL_AUTHENTICATION_FAILED.exception());
+        disconnect(node.idString());
+        backoff(node, backoffMs);
+    }
+
+    public void createPendingAuthenticationError(Node node, long backoffMs) {
+        pendingAuthenticationErrors.put(node, backoffMs);
+    }
+
+    @Override
+    public boolean connectionFailed(Node node) {
+        return connectionState(node.idString()).isBackingOff(time.milliseconds());
+    }
+
+    @Override
+    public AuthenticationException authenticationException(Node node) {
+        return authenticationErrors.get(node);
+    }
+
+    public void setReadyCallback(Consumer<Node> onReadyCall) {
+        this.readyCallback = onReadyCall;
+    }
+
+    public void setDisconnectFuture(CompletableFuture<String> disconnectFuture) {
+        this.disconnectFuture
= disconnectFuture; + } + + @Override + public void disconnect(String node) { + long now = time.milliseconds(); + Iterator iter = requests.iterator(); + while (iter.hasNext()) { + ClientRequest request = iter.next(); + if (request.destination().equals(node)) { + short version = request.requestBuilder().latestAllowedVersion(); + responses.add(new ClientResponse(request.makeHeader(version), request.callback(), request.destination(), + request.createdTimeMs(), now, true, null, null, null)); + iter.remove(); + } + } + CompletableFuture curDisconnectFuture = disconnectFuture; + if (curDisconnectFuture != null) { + curDisconnectFuture.complete(node); + } + connectionState(node).disconnect(); + } + + @Override + public void send(ClientRequest request, long now) { + if (!connectionState(request.destination()).isReady(now)) + throw new IllegalStateException("Cannot send " + request + " since the destination is not ready"); + + // Check if the request is directed to a node with a pending authentication error. + for (Iterator> authErrorIter = + pendingAuthenticationErrors.entrySet().iterator(); authErrorIter.hasNext(); ) { + Map.Entry entry = authErrorIter.next(); + Node node = entry.getKey(); + long backoffMs = entry.getValue(); + if (node.idString().equals(request.destination())) { + authErrorIter.remove(); + // Set up a disconnected ClientResponse and create an authentication error + // for the affected node. + authenticationFailed(node, backoffMs); + AbstractRequest.Builder builder = request.requestBuilder(); + short version = nodeApiVersions.latestUsableVersion(request.apiKey(), builder.oldestAllowedVersion(), + builder.latestAllowedVersion()); + ClientResponse resp = new ClientResponse(request.makeHeader(version), request.callback(), request.destination(), + request.createdTimeMs(), time.milliseconds(), true, null, + new AuthenticationException("Authentication failed"), null); + responses.add(resp); + return; + } + } + Iterator iterator = futureResponses.iterator(); + while (iterator.hasNext()) { + FutureResponse futureResp = iterator.next(); + if (futureResp.node != null && !request.destination().equals(futureResp.node.idString())) + continue; + + AbstractRequest.Builder builder = request.requestBuilder(); + + try { + short version = nodeApiVersions.latestUsableVersion(request.apiKey(), builder.oldestAllowedVersion(), + builder.latestAllowedVersion()); + + UnsupportedVersionException unsupportedVersionException = null; + if (futureResp.isUnsupportedRequest) { + unsupportedVersionException = new UnsupportedVersionException( + "Api " + request.apiKey() + " with version " + version); + } else { + AbstractRequest abstractRequest = request.requestBuilder().build(version); + if (!futureResp.requestMatcher.matches(abstractRequest)) + throw new IllegalStateException("Request matcher did not match next-in-line request " + + abstractRequest + " with prepared response " + futureResp.responseBody); + } + + ClientResponse resp = new ClientResponse(request.makeHeader(version), request.callback(), request.destination(), + request.createdTimeMs(), time.milliseconds(), futureResp.disconnected, + unsupportedVersionException, null, futureResp.responseBody); + responses.add(resp); + } catch (UnsupportedVersionException unsupportedVersionException) { + ClientResponse resp = new ClientResponse(request.makeHeader(builder.latestAllowedVersion()), request.callback(), request.destination(), + request.createdTimeMs(), time.milliseconds(), false, unsupportedVersionException, null, null); + responses.add(resp); + } + 
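// A matching prepared response (or one carrying an UnsupportedVersionException) has been queued at this + // point, so the consumed FutureResponse is removed below and the request completes without ever entering + // the pending-requests queue. +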
iterator.remove(); + return; + } + + this.requests.add(request); + } + + /** + * Simulate a blocking poll in order to test wakeup behavior. + * + * @param numBlockingWakeups The number of polls which will block until woken up + */ + public synchronized void enableBlockingUntilWakeup(int numBlockingWakeups) { + this.numBlockingWakeups = numBlockingWakeups; + } + + @Override + public synchronized void wakeup() { + if (numBlockingWakeups > 0) { + numBlockingWakeups--; + notify(); + } + if (wakeupHook != null) { + wakeupHook.run(); + } + } + + private synchronized void maybeAwaitWakeup() { + try { + int remainingBlockingWakeups = numBlockingWakeups; + if (remainingBlockingWakeups <= 0) + return; + + TestUtils.waitForCondition(() -> { + if (numBlockingWakeups == remainingBlockingWakeups) + MockClient.this.wait(500); + return numBlockingWakeups < remainingBlockingWakeups; + }, 5000, "Failed to receive expected wakeup"); + } catch (InterruptedException e) { + throw new InterruptException(e); + } + } + + @Override + public List poll(long timeoutMs, long now) { + maybeAwaitWakeup(); + checkTimeoutOfPendingRequests(now); + + // We skip metadata updates if all nodes are currently blacked out + if (metadataUpdater.isUpdateNeeded() && leastLoadedNode(now) != null) { + MetadataUpdate metadataUpdate = metadataUpdates.poll(); + if (metadataUpdate != null) { + metadataUpdater.update(time, metadataUpdate); + } else { + metadataUpdater.updateWithCurrentMetadata(time); + } + } + + List copy = new ArrayList<>(); + ClientResponse response; + while ((response = this.responses.poll()) != null) { + response.onComplete(); + copy.add(response); + } + + return copy; + } + + private long elapsedTimeMs(long currentTimeMs, long startTimeMs) { + return Math.max(0, currentTimeMs - startTimeMs); + } + + private void checkTimeoutOfPendingRequests(long nowMs) { + ClientRequest request = requests.peek(); + while (request != null && elapsedTimeMs(nowMs, request.createdTimeMs()) >= request.requestTimeoutMs()) { + disconnect(request.destination()); + requests.poll(); + request = requests.peek(); + } + } + + public Queue requests() { + return this.requests; + } + + public Queue responses() { + return this.responses; + } + + public Queue futureResponses() { + return this.futureResponses; + } + + public void respond(AbstractResponse response) { + respond(response, false); + } + + public void respond(RequestMatcher matcher, AbstractResponse response) { + ClientRequest nextRequest = requests.peek(); + if (nextRequest == null) + throw new IllegalStateException("No current requests queued"); + + AbstractRequest request = nextRequest.requestBuilder().build(); + if (!matcher.matches(request)) + throw new IllegalStateException("Request matcher did not match next-in-line request " + request); + + respond(response); + } + + // Utility method to enable out of order responses + public void respondToRequest(ClientRequest clientRequest, AbstractResponse response) { + requests.remove(clientRequest); + short version = clientRequest.requestBuilder().latestAllowedVersion(); + responses.add(new ClientResponse(clientRequest.makeHeader(version), clientRequest.callback(), clientRequest.destination(), + clientRequest.createdTimeMs(), time.milliseconds(), false, null, null, response)); + } + + + public void respond(AbstractResponse response, boolean disconnected) { + if (requests.isEmpty()) + throw new IllegalStateException("No requests pending for inbound response " + response); + ClientRequest request = requests.poll(); + short version = 
request.requestBuilder().latestAllowedVersion(); + responses.add(new ClientResponse(request.makeHeader(version), request.callback(), request.destination(), + request.createdTimeMs(), time.milliseconds(), disconnected, null, null, response)); + } + + public void respondFrom(AbstractResponse response, Node node) { + respondFrom(response, node, false); + } + + public void respondFrom(AbstractResponse response, Node node, boolean disconnected) { + Iterator iterator = requests.iterator(); + while (iterator.hasNext()) { + ClientRequest request = iterator.next(); + if (request.destination().equals(node.idString())) { + iterator.remove(); + short version = request.requestBuilder().latestAllowedVersion(); + responses.add(new ClientResponse(request.makeHeader(version), request.callback(), request.destination(), + request.createdTimeMs(), time.milliseconds(), disconnected, null, null, response)); + return; + } + } + throw new IllegalArgumentException("No requests available to node " + node); + } + + public void prepareResponse(AbstractResponse response) { + prepareResponse(ALWAYS_TRUE, response, false); + } + + public void prepareResponseFrom(AbstractResponse response, Node node) { + prepareResponseFrom(ALWAYS_TRUE, response, node, false, false); + } + + /** + * Prepare a response for a request matching the provided matcher. If the matcher does not + * match, {@link KafkaClient#send(ClientRequest, long)} will throw IllegalStateException + * @param matcher The matcher to apply + * @param response The response body + */ + public void prepareResponse(RequestMatcher matcher, AbstractResponse response) { + prepareResponse(matcher, response, false); + } + + public void prepareResponseFrom(RequestMatcher matcher, AbstractResponse response, Node node) { + prepareResponseFrom(matcher, response, node, false, false); + } + + public void prepareResponseFrom(RequestMatcher matcher, AbstractResponse response, Node node, boolean disconnected) { + prepareResponseFrom(matcher, response, node, disconnected, false); + } + + public void prepareResponse(AbstractResponse response, boolean disconnected) { + prepareResponse(ALWAYS_TRUE, response, disconnected); + } + + public void prepareResponseFrom(AbstractResponse response, Node node, boolean disconnected) { + prepareResponseFrom(ALWAYS_TRUE, response, node, disconnected, false); + } + + /** + * Prepare a response for a request matching the provided matcher. If the matcher does not + * match, {@link KafkaClient#send(ClientRequest, long)} will throw IllegalStateException. + * @param matcher The request matcher to apply + * @param response The response body + * @param disconnected Whether the request was disconnected + */ + public void prepareResponse(RequestMatcher matcher, AbstractResponse response, boolean disconnected) { + prepareResponseFrom(matcher, response, null, disconnected, false); + } + + /** + * Raise an unsupported version error on the next request if it matches the given matcher. + * If the matcher does not match, {@link KafkaClient#send(ClientRequest, long)} will throw IllegalStateException. 
+ * @param matcher The request matcher to apply + */ + public void prepareUnsupportedVersionResponse(RequestMatcher matcher) { + prepareResponseFrom(matcher, null, null, false, true); + } + + private void prepareResponseFrom(RequestMatcher matcher, + AbstractResponse response, + Node node, + boolean disconnected, + boolean isUnsupportedVersion) { + futureResponses.add(new FutureResponse(node, matcher, response, disconnected, isUnsupportedVersion)); + } + + public void waitForRequests(final int minRequests, long maxWaitMs) throws InterruptedException { + TestUtils.waitForCondition(new TestCondition() { + @Override + public boolean conditionMet() { + return requests.size() >= minRequests; + } + }, maxWaitMs, "Expected requests have not been sent"); + } + + public void reset() { + connections.clear(); + requests.clear(); + responses.clear(); + futureResponses.clear(); + metadataUpdates.clear(); + authenticationErrors.clear(); + } + + public boolean hasPendingMetadataUpdates() { + return !metadataUpdates.isEmpty(); + } + + public int numAwaitingResponses() { + return futureResponses.size(); + } + + public void prepareMetadataUpdate(MetadataResponse updateResponse) { + prepareMetadataUpdate(updateResponse, false); + } + + public void prepareMetadataUpdate(MetadataResponse updateResponse, + boolean expectMatchMetadataTopics) { + metadataUpdates.add(new MetadataUpdate(updateResponse, expectMatchMetadataTopics)); + } + + public void updateMetadata(MetadataResponse updateResponse) { + metadataUpdater.update(time, new MetadataUpdate(updateResponse, false)); + } + + @Override + public int inFlightRequestCount() { + return requests.size(); + } + + @Override + public boolean hasInFlightRequests() { + return !requests.isEmpty(); + } + + public boolean hasPendingResponses() { + return !responses.isEmpty() || !futureResponses.isEmpty(); + } + + @Override + public int inFlightRequestCount(String node) { + int result = 0; + for (ClientRequest req : requests) { + if (req.destination().equals(node)) + ++result; + } + return result; + } + + @Override + public boolean hasInFlightRequests(String node) { + return inFlightRequestCount(node) > 0; + } + + @Override + public boolean hasReadyNodes(long now) { + return connections.values().stream().anyMatch(cxn -> cxn.isReady(now)); + } + + @Override + public ClientRequest newClientRequest(String nodeId, AbstractRequest.Builder requestBuilder, long createdTimeMs, + boolean expectResponse) { + return newClientRequest(nodeId, requestBuilder, createdTimeMs, expectResponse, 5000, null); + } + + @Override + public ClientRequest newClientRequest(String nodeId, + AbstractRequest.Builder requestBuilder, + long createdTimeMs, + boolean expectResponse, + int requestTimeoutMs, + RequestCompletionHandler callback) { + return new ClientRequest(nodeId, requestBuilder, correlation++, "mockClientId", createdTimeMs, + expectResponse, requestTimeoutMs, callback); + } + + @Override + public void initiateClose() { + close(); + } + + @Override + public boolean active() { + return active; + } + + @Override + public void close() { + active = false; + metadataUpdater.close(); + } + + @Override + public void close(String node) { + connections.remove(node); + } + + @Override + public Node leastLoadedNode(long now) { + // Consistent with NetworkClient, we do not return nodes awaiting reconnect backoff + for (Node node : metadataUpdater.fetchNodes()) { + if (!connectionState(node.idString()).isBackingOff(now)) + return node; + } + return null; + } + + public void setWakeupHook(Runnable 
wakeupHook) { + this.wakeupHook = wakeupHook; + } + + /** + * The RequestMatcher provides a way to match a particular request to a response prepared + * through {@link #prepareResponse(RequestMatcher, AbstractResponse)}. Basically this allows testers + * to inspect the request body for the type of the request or for specific fields that should be set, + * and to fail the test if it doesn't match. + */ + @FunctionalInterface + public interface RequestMatcher { + boolean matches(AbstractRequest body); + } + + public void setNodeApiVersions(NodeApiVersions nodeApiVersions) { + this.nodeApiVersions = nodeApiVersions; + } + + public static class MetadataUpdate { + final MetadataResponse updateResponse; + final boolean expectMatchRefreshTopics; + + MetadataUpdate(MetadataResponse updateResponse, boolean expectMatchRefreshTopics) { + this.updateResponse = updateResponse; + this.expectMatchRefreshTopics = expectMatchRefreshTopics; + } + + private Set topics() { + return updateResponse.topicMetadata().stream() + .map(MetadataResponse.TopicMetadata::topic) + .collect(Collectors.toSet()); + } + } + + /** + * This is a dumbed down version of {@link MetadataUpdater} which is used to facilitate + * metadata tracking primarily in order to serve {@link KafkaClient#leastLoadedNode(long)} + * and bookkeeping through {@link Metadata}. The extensibility allows AdminClient, which does + * not rely on {@link Metadata} to do its own thing. + */ + public interface MockMetadataUpdater { + List fetchNodes(); + + boolean isUpdateNeeded(); + + void update(Time time, MetadataUpdate update); + + default void updateWithCurrentMetadata(Time time) {} + + default void close() {} + } + + private static class NoOpMetadataUpdater implements MockMetadataUpdater { + @Override + public List fetchNodes() { + return Collections.emptyList(); + } + + @Override + public boolean isUpdateNeeded() { + return false; + } + + @Override + public void update(Time time, MetadataUpdate update) { + throw new UnsupportedOperationException(); + } + } + + private static class DefaultMockMetadataUpdater implements MockMetadataUpdater { + private final Metadata metadata; + private MetadataUpdate lastUpdate; + + public DefaultMockMetadataUpdater(Metadata metadata) { + this.metadata = metadata; + } + + @Override + public List fetchNodes() { + return metadata.fetch().nodes(); + } + + @Override + public boolean isUpdateNeeded() { + return metadata.updateRequested(); + } + + @Override + public void updateWithCurrentMetadata(Time time) { + if (lastUpdate == null) + throw new IllegalStateException("No previous metadata update to use"); + update(time, lastUpdate); + } + + private void maybeCheckExpectedTopics(MetadataUpdate update, MetadataRequest.Builder builder) { + if (update.expectMatchRefreshTopics) { + if (builder.isAllTopics()) + throw new IllegalStateException("The metadata topics does not match expectation. " + + "Expected topics: " + update.topics() + + ", asked topics: ALL"); + + Set requestedTopics = new HashSet<>(builder.topics()); + if (!requestedTopics.equals(update.topics())) { + throw new IllegalStateException("The metadata topics does not match expectation. 
" + + "Expected topics: " + update.topics() + + ", asked topics: " + requestedTopics); + } + } + } + + @Override + public void update(Time time, MetadataUpdate update) { + MetadataRequest.Builder builder = metadata.newMetadataRequestBuilder(); + maybeCheckExpectedTopics(update, builder); + metadata.updateWithCurrentRequestVersion(update.updateResponse, false, time.milliseconds()); + this.lastUpdate = update; + } + + @Override + public void close() { + metadata.close(); + } + } + + private static class ConnectionState { + enum State { CONNECTING, CONNECTED, DISCONNECTED } + + private long throttledUntilMs = 0L; + private long readyDelayedUntilMs = 0L; + private long backingOffUntilMs = 0L; + private long unreachableUntilMs = 0L; + private State state = State.DISCONNECTED; + + void backoff(long untilMs) { + backingOffUntilMs = untilMs; + } + + void throttle(long untilMs) { + throttledUntilMs = untilMs; + } + + void setUnreachable(long untilMs) { + unreachableUntilMs = untilMs; + } + + void setReadyDelayed(long untilMs) { + readyDelayedUntilMs = untilMs; + } + + boolean isReady(long now) { + return state == State.CONNECTED && notThrottled(now); + } + + boolean isReadyDelayed(long now) { + return now < readyDelayedUntilMs; + } + + boolean notThrottled(long now) { + return now > throttledUntilMs; + } + + boolean isBackingOff(long now) { + return now < backingOffUntilMs; + } + + boolean isUnreachable(long now) { + return now < unreachableUntilMs; + } + + void disconnect() { + state = State.DISCONNECTED; + } + + long connectionDelay(long now) { + if (state != State.DISCONNECTED) + return Long.MAX_VALUE; + + if (backingOffUntilMs > now) + return backingOffUntilMs - now; + + return 0; + } + + boolean ready(long now) { + switch (state) { + case CONNECTED: + return notThrottled(now); + + case CONNECTING: + if (isReadyDelayed(now)) + return false; + state = State.CONNECTED; + return ready(now); + + case DISCONNECTED: + if (isBackingOff(now)) { + return false; + } else if (isUnreachable(now)) { + backingOffUntilMs = now + 100; + return false; + } + + state = State.CONNECTING; + return ready(now); + + default: + throw new IllegalArgumentException("Invalid state: " + state); + } + } + + } + +} diff --git a/clients/src/test/java/org/apache/kafka/clients/NetworkClientTest.java b/clients/src/test/java/org/apache/kafka/clients/NetworkClientTest.java new file mode 100644 index 0000000..4fbfd42 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/clients/NetworkClientTest.java @@ -0,0 +1,1157 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.clients; + +import org.apache.kafka.common.Cluster; +import org.apache.kafka.common.KafkaException; +import org.apache.kafka.common.Node; +import org.apache.kafka.common.errors.AuthenticationException; +import org.apache.kafka.common.errors.UnsupportedVersionException; +import org.apache.kafka.common.internals.ClusterResourceListeners; +import org.apache.kafka.common.message.ApiMessageType; +import org.apache.kafka.common.message.ApiVersionsResponseData; +import org.apache.kafka.common.message.ApiVersionsResponseData.ApiVersion; +import org.apache.kafka.common.message.ApiVersionsResponseData.ApiVersionCollection; +import org.apache.kafka.common.message.ProduceRequestData; +import org.apache.kafka.common.message.ProduceResponseData; +import org.apache.kafka.common.network.NetworkReceive; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.Errors; +import org.apache.kafka.common.requests.AbstractResponse; +import org.apache.kafka.common.requests.ApiVersionsResponse; +import org.apache.kafka.common.requests.MetadataRequest; +import org.apache.kafka.common.requests.MetadataResponse; +import org.apache.kafka.common.requests.ProduceRequest; +import org.apache.kafka.common.requests.ProduceResponse; +import org.apache.kafka.common.requests.RequestHeader; +import org.apache.kafka.common.requests.RequestTestUtils; +import org.apache.kafka.common.security.authenticator.SaslClientAuthenticator; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.test.DelayedReceive; +import org.apache.kafka.test.MockSelector; +import org.apache.kafka.test.TestUtils; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.net.InetAddress; +import java.net.UnknownHostException; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Optional; +import java.util.Set; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +import static org.apache.kafka.common.protocol.ApiKeys.PRODUCE; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; + +public class NetworkClientTest { + + protected final int defaultRequestTimeoutMs = 1000; + protected final MockTime time = new MockTime(); + protected final MockSelector selector = new MockSelector(time); + protected final Node node = TestUtils.singletonCluster().nodes().iterator().next(); + protected final long reconnectBackoffMsTest = 10 * 1000; + protected final long reconnectBackoffMaxMsTest = 10 * 10000; + protected final long connectionSetupTimeoutMsTest = 5 * 1000; + protected final long connectionSetupTimeoutMaxMsTest = 127 * 1000; + private final TestMetadataUpdater metadataUpdater = new TestMetadataUpdater(Collections.singletonList(node)); + private final NetworkClient client = createNetworkClient(reconnectBackoffMaxMsTest); + private final NetworkClient 
clientWithNoExponentialBackoff = createNetworkClient(reconnectBackoffMsTest); + private final NetworkClient clientWithStaticNodes = createNetworkClientWithStaticNodes(); + private final NetworkClient clientWithNoVersionDiscovery = createNetworkClientWithNoVersionDiscovery(); + + private static ArrayList initialAddresses; + private static ArrayList newAddresses; + + static { + try { + initialAddresses = new ArrayList<>(Arrays.asList( + InetAddress.getByName("10.200.20.100"), + InetAddress.getByName("10.200.20.101"), + InetAddress.getByName("10.200.20.102") + )); + newAddresses = new ArrayList<>(Arrays.asList( + InetAddress.getByName("10.200.20.103"), + InetAddress.getByName("10.200.20.104"), + InetAddress.getByName("10.200.20.105") + )); + } catch (UnknownHostException e) { + fail("Attempted to create an invalid InetAddress, this should not happen"); + } + } + + private NetworkClient createNetworkClient(long reconnectBackoffMaxMs) { + return new NetworkClient(selector, metadataUpdater, "mock", Integer.MAX_VALUE, + reconnectBackoffMsTest, reconnectBackoffMaxMs, 64 * 1024, 64 * 1024, + defaultRequestTimeoutMs, connectionSetupTimeoutMsTest, connectionSetupTimeoutMaxMsTest, time, true, new ApiVersions(), new LogContext()); + } + + private NetworkClient createNetworkClientWithMultipleNodes(long reconnectBackoffMaxMs, long connectionSetupTimeoutMsTest, int nodeNumber) { + List nodes = TestUtils.clusterWith(nodeNumber).nodes(); + TestMetadataUpdater metadataUpdater = new TestMetadataUpdater(nodes); + return new NetworkClient(selector, metadataUpdater, "mock", Integer.MAX_VALUE, + reconnectBackoffMsTest, reconnectBackoffMaxMs, 64 * 1024, 64 * 1024, + defaultRequestTimeoutMs, connectionSetupTimeoutMsTest, connectionSetupTimeoutMaxMsTest, time, true, new ApiVersions(), new LogContext()); + } + + private NetworkClient createNetworkClientWithStaticNodes() { + return new NetworkClient(selector, metadataUpdater, + "mock-static", Integer.MAX_VALUE, 0, 0, 64 * 1024, 64 * 1024, defaultRequestTimeoutMs, + connectionSetupTimeoutMsTest, connectionSetupTimeoutMaxMsTest, time, true, new ApiVersions(), new LogContext()); + } + + private NetworkClient createNetworkClientWithNoVersionDiscovery(Metadata metadata) { + return new NetworkClient(selector, metadata, "mock", Integer.MAX_VALUE, + reconnectBackoffMsTest, 0, 64 * 1024, 64 * 1024, + defaultRequestTimeoutMs, connectionSetupTimeoutMsTest, connectionSetupTimeoutMaxMsTest, time, false, new ApiVersions(), new LogContext()); + } + + private NetworkClient createNetworkClientWithNoVersionDiscovery() { + return new NetworkClient(selector, metadataUpdater, "mock", Integer.MAX_VALUE, + reconnectBackoffMsTest, reconnectBackoffMaxMsTest, + 64 * 1024, 64 * 1024, defaultRequestTimeoutMs, + connectionSetupTimeoutMsTest, connectionSetupTimeoutMaxMsTest, time, false, new ApiVersions(), new LogContext()); + } + + @BeforeEach + public void setup() { + selector.reset(); + } + + @Test + public void testSendToUnreadyNode() { + MetadataRequest.Builder builder = new MetadataRequest.Builder(Collections.singletonList("test"), true); + long now = time.milliseconds(); + ClientRequest request = client.newClientRequest("5", builder, now, false); + assertThrows(IllegalStateException.class, () -> client.send(request, now)); + } + + @Test + public void testSimpleRequestResponse() { + checkSimpleRequestResponse(client); + } + + @Test + public void testSimpleRequestResponseWithStaticNodes() { + checkSimpleRequestResponse(clientWithStaticNodes); + } + + @Test + public void 
testSimpleRequestResponseWithNoBrokerDiscovery() { + checkSimpleRequestResponse(clientWithNoVersionDiscovery); + } + + @Test + public void testDnsLookupFailure() { + /* Fail cleanly when the node has a bad hostname */ + assertFalse(client.ready(new Node(1234, "badhost", 1234), time.milliseconds())); + } + + @Test + public void testClose() { + client.ready(node, time.milliseconds()); + awaitReady(client, node); + client.poll(1, time.milliseconds()); + assertTrue(client.isReady(node, time.milliseconds()), "The client should be ready"); + + ProduceRequest.Builder builder = ProduceRequest.forCurrentMagic(new ProduceRequestData() + .setTopicData(new ProduceRequestData.TopicProduceDataCollection()) + .setAcks((short) 1) + .setTimeoutMs(1000)); + ClientRequest request = client.newClientRequest(node.idString(), builder, time.milliseconds(), true); + client.send(request, time.milliseconds()); + assertEquals(1, client.inFlightRequestCount(node.idString()), + "There should be 1 in-flight request after send"); + assertTrue(client.hasInFlightRequests(node.idString())); + assertTrue(client.hasInFlightRequests()); + + client.close(node.idString()); + assertEquals(0, client.inFlightRequestCount(node.idString()), "There should be no in-flight request after close"); + assertFalse(client.hasInFlightRequests(node.idString())); + assertFalse(client.hasInFlightRequests()); + assertFalse(client.isReady(node, 0), "Connection should not be ready after close"); + } + + @Test + public void testUnsupportedVersionDuringInternalMetadataRequest() { + List topics = Collections.singletonList("topic_1"); + + // disabling auto topic creation for versions less than 4 is not supported + MetadataRequest.Builder builder = new MetadataRequest.Builder(topics, false, (short) 3); + client.sendInternalMetadataRequest(builder, node.idString(), time.milliseconds()); + assertEquals(UnsupportedVersionException.class, metadataUpdater.getAndClearFailure().getClass()); + } + + private void checkSimpleRequestResponse(NetworkClient networkClient) { + awaitReady(networkClient, node); // has to be before creating any request, as it may send ApiVersionsRequest and its response is mocked with correlation id 0 + short requestVersion = PRODUCE.latestVersion(); + ProduceRequest.Builder builder = new ProduceRequest.Builder( + requestVersion, + requestVersion, + new ProduceRequestData() + .setAcks((short) 1) + .setTimeoutMs(1000)); + TestCallbackHandler handler = new TestCallbackHandler(); + ClientRequest request = networkClient.newClientRequest(node.idString(), builder, time.milliseconds(), + true, defaultRequestTimeoutMs, handler); + networkClient.send(request, time.milliseconds()); + networkClient.poll(1, time.milliseconds()); + assertEquals(1, networkClient.inFlightRequestCount()); + ProduceResponse produceResponse = new ProduceResponse(new ProduceResponseData()); + ByteBuffer buffer = RequestTestUtils.serializeResponseWithHeader(produceResponse, requestVersion, request.correlationId()); + selector.completeReceive(new NetworkReceive(node.idString(), buffer)); + List responses = networkClient.poll(1, time.milliseconds()); + assertEquals(1, responses.size()); + assertTrue(handler.executed, "The handler should have executed."); + assertTrue(handler.response.hasResponse(), "Should have a response body."); + assertEquals(request.correlationId(), handler.response.requestHeader().correlationId(), + "Should be correlated to the original request"); + } + + private void delayedApiVersionsResponse(int correlationId, short version, ApiVersionsResponse 
response) { + ByteBuffer buffer = RequestTestUtils.serializeResponseWithHeader(response, version, correlationId); + selector.delayedReceive(new DelayedReceive(node.idString(), new NetworkReceive(node.idString(), buffer))); + } + + private void setExpectedApiVersionsResponse(ApiVersionsResponse response) { + short apiVersionsResponseVersion = response.apiVersion(ApiKeys.API_VERSIONS.id).maxVersion(); + delayedApiVersionsResponse(0, apiVersionsResponseVersion, response); + } + + private void awaitReady(NetworkClient client, Node node) { + if (client.discoverBrokerVersions()) { + setExpectedApiVersionsResponse(ApiVersionsResponse.defaultApiVersionsResponse( + ApiMessageType.ListenerType.ZK_BROKER)); + } + while (!client.ready(node, time.milliseconds())) + client.poll(1, time.milliseconds()); + selector.clear(); + } + + @Test + public void testInvalidApiVersionsRequest() { + // initiate the connection + client.ready(node, time.milliseconds()); + + // handle the connection, send the ApiVersionsRequest + client.poll(0, time.milliseconds()); + + // check that the ApiVersionsRequest has been initiated + assertTrue(client.hasInFlightRequests(node.idString())); + + // prepare response + delayedApiVersionsResponse(0, ApiKeys.API_VERSIONS.latestVersion(), + new ApiVersionsResponse( + new ApiVersionsResponseData() + .setErrorCode(Errors.INVALID_REQUEST.code()) + .setThrottleTimeMs(0) + )); + + // handle completed receives + client.poll(0, time.milliseconds()); + + // the ApiVersionsRequest is gone + assertFalse(client.hasInFlightRequests(node.idString())); + + // various assertions + assertFalse(client.isReady(node, time.milliseconds())); + } + + @Test + public void testApiVersionsRequest() { + // initiate the connection + client.ready(node, time.milliseconds()); + + // handle the connection, send the ApiVersionsRequest + client.poll(0, time.milliseconds()); + + // check that the ApiVersionsRequest has been initiated + assertTrue(client.hasInFlightRequests(node.idString())); + + // prepare response + delayedApiVersionsResponse(0, ApiKeys.API_VERSIONS.latestVersion(), defaultApiVersionsResponse()); + + // handle completed receives + client.poll(0, time.milliseconds()); + + // the ApiVersionsRequest is gone + assertFalse(client.hasInFlightRequests(node.idString())); + + // various assertions + assertTrue(client.isReady(node, time.milliseconds())); + } + + @Test + public void testUnsupportedApiVersionsRequestWithVersionProvidedByTheBroker() { + // initiate the connection + client.ready(node, time.milliseconds()); + + // handle the connection, initiate first ApiVersionsRequest + client.poll(0, time.milliseconds()); + + // ApiVersionsRequest is in flight but not sent yet + assertTrue(client.hasInFlightRequests(node.idString())); + + // completes initiated sends + client.poll(0, time.milliseconds()); + assertEquals(1, selector.completedSends().size()); + + ByteBuffer buffer = selector.completedSendBuffers().get(0).buffer(); + RequestHeader header = parseHeader(buffer); + assertEquals(ApiKeys.API_VERSIONS, header.apiKey()); + assertEquals(3, header.apiVersion()); + + // prepare response + ApiVersionCollection apiKeys = new ApiVersionCollection(); + apiKeys.add(new ApiVersion() + .setApiKey(ApiKeys.API_VERSIONS.id) + .setMinVersion((short) 0) + .setMaxVersion((short) 2)); + delayedApiVersionsResponse(0, (short) 0, + new ApiVersionsResponse( + new ApiVersionsResponseData() + .setErrorCode(Errors.UNSUPPORTED_VERSION.code()) + .setApiKeys(apiKeys) + )); + + // handle ApiVersionResponse, initiate second 
ApiVersionRequest + client.poll(0, time.milliseconds()); + + // ApiVersionsRequest is in flight but not sent yet + assertTrue(client.hasInFlightRequests(node.idString())); + + // ApiVersionsResponse has been received + assertEquals(1, selector.completedReceives().size()); + + // clean up the buffers + selector.completedSends().clear(); + selector.completedSendBuffers().clear(); + selector.completedReceives().clear(); + + // completes initiated sends + client.poll(0, time.milliseconds()); + + // ApiVersionsRequest has been sent + assertEquals(1, selector.completedSends().size()); + + buffer = selector.completedSendBuffers().get(0).buffer(); + header = parseHeader(buffer); + assertEquals(ApiKeys.API_VERSIONS, header.apiKey()); + assertEquals(2, header.apiVersion()); + + // prepare response + delayedApiVersionsResponse(1, (short) 0, defaultApiVersionsResponse()); + + // handle completed receives + client.poll(0, time.milliseconds()); + + // the ApiVersionsRequest is gone + assertFalse(client.hasInFlightRequests(node.idString())); + assertEquals(1, selector.completedReceives().size()); + + // the client is ready + assertTrue(client.isReady(node, time.milliseconds())); + } + + @Test + public void testUnsupportedApiVersionsRequestWithoutVersionProvidedByTheBroker() { + // initiate the connection + client.ready(node, time.milliseconds()); + + // handle the connection, initiate first ApiVersionsRequest + client.poll(0, time.milliseconds()); + + // ApiVersionsRequest is in flight but not sent yet + assertTrue(client.hasInFlightRequests(node.idString())); + + // completes initiated sends + client.poll(0, time.milliseconds()); + assertEquals(1, selector.completedSends().size()); + + ByteBuffer buffer = selector.completedSendBuffers().get(0).buffer(); + RequestHeader header = parseHeader(buffer); + assertEquals(ApiKeys.API_VERSIONS, header.apiKey()); + assertEquals(3, header.apiVersion()); + + // prepare response + delayedApiVersionsResponse(0, (short) 0, + new ApiVersionsResponse( + new ApiVersionsResponseData() + .setErrorCode(Errors.UNSUPPORTED_VERSION.code()) + )); + + // handle ApiVersionResponse, initiate second ApiVersionRequest + client.poll(0, time.milliseconds()); + + // ApiVersionsRequest is in flight but not sent yet + assertTrue(client.hasInFlightRequests(node.idString())); + + // ApiVersionsResponse has been received + assertEquals(1, selector.completedReceives().size()); + + // clean up the buffers + selector.completedSends().clear(); + selector.completedSendBuffers().clear(); + selector.completedReceives().clear(); + + // completes initiated sends + client.poll(0, time.milliseconds()); + + // ApiVersionsRequest has been sent + assertEquals(1, selector.completedSends().size()); + + buffer = selector.completedSendBuffers().get(0).buffer(); + header = parseHeader(buffer); + assertEquals(ApiKeys.API_VERSIONS, header.apiKey()); + assertEquals(0, header.apiVersion()); + + // prepare response + delayedApiVersionsResponse(1, (short) 0, defaultApiVersionsResponse()); + + // handle completed receives + client.poll(0, time.milliseconds()); + + // the ApiVersionsRequest is gone + assertFalse(client.hasInFlightRequests(node.idString())); + assertEquals(1, selector.completedReceives().size()); + + // the client is ready + assertTrue(client.isReady(node, time.milliseconds())); + } + + @Test + public void testRequestTimeout() { + awaitReady(client, node); // has to be before creating any request, as it may send ApiVersionsRequest and its response is mocked with correlation id 0 + 
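// Build a produce request with an explicit per-request timeout larger than the client default and verify + // that it is only failed (with a disconnected response) once that timeout elapses; here the explicit + // timeout is defaultRequestTimeoutMs + 5000 = 6000 ms. +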
ProduceRequest.Builder builder = ProduceRequest.forCurrentMagic(new ProduceRequestData() + .setTopicData(new ProduceRequestData.TopicProduceDataCollection()) + .setAcks((short) 1) + .setTimeoutMs(1000)); + TestCallbackHandler handler = new TestCallbackHandler(); + int requestTimeoutMs = defaultRequestTimeoutMs + 5000; + ClientRequest request = client.newClientRequest(node.idString(), builder, time.milliseconds(), true, + requestTimeoutMs, handler); + assertEquals(requestTimeoutMs, request.requestTimeoutMs()); + testRequestTimeout(request); + } + + @Test + public void testDefaultRequestTimeout() { + awaitReady(client, node); // has to be before creating any request, as it may send ApiVersionsRequest and its response is mocked with correlation id 0 + ProduceRequest.Builder builder = ProduceRequest.forCurrentMagic(new ProduceRequestData() + .setTopicData(new ProduceRequestData.TopicProduceDataCollection()) + .setAcks((short) 1) + .setTimeoutMs(1000)); + ClientRequest request = client.newClientRequest(node.idString(), builder, time.milliseconds(), true); + assertEquals(defaultRequestTimeoutMs, request.requestTimeoutMs()); + testRequestTimeout(request); + } + + private void testRequestTimeout(ClientRequest request) { + client.send(request, time.milliseconds()); + + time.sleep(request.requestTimeoutMs() + 1); + List responses = client.poll(0, time.milliseconds()); + + assertEquals(1, responses.size()); + ClientResponse clientResponse = responses.get(0); + assertEquals(node.idString(), clientResponse.destination()); + assertTrue(clientResponse.wasDisconnected(), "Expected response to fail due to disconnection"); + } + + @Test + public void testConnectionSetupTimeout() { + // Use two nodes to ensure that the logic iterate over a set of more than one + // element. ConcurrentModificationException is not triggered otherwise. + final Cluster cluster = TestUtils.clusterWith(2); + final Node node0 = cluster.nodeById(0); + final Node node1 = cluster.nodeById(1); + + client.ready(node0, time.milliseconds()); + selector.serverConnectionBlocked(node0.idString()); + + client.ready(node1, time.milliseconds()); + selector.serverConnectionBlocked(node1.idString()); + + client.poll(0, time.milliseconds()); + assertFalse(client.connectionFailed(node), + "The connections should not fail before the socket connection setup timeout elapsed"); + + time.sleep((long) (connectionSetupTimeoutMsTest * 1.2) + 1); + client.poll(0, time.milliseconds()); + assertTrue(client.connectionFailed(node), + "Expected the connections to fail due to the socket connection setup timeout"); + } + + @Test + public void testConnectionThrottling() { + // Instrument the test to return a response with a 100ms throttle delay. 
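+ // After the throttled response is processed, ready() should return false and throttleDelayMs() should + // report the remaining delay until the full 100 ms throttle window has elapsed on the mock clock.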
+ awaitReady(client, node); + short requestVersion = PRODUCE.latestVersion(); + ProduceRequest.Builder builder = new ProduceRequest.Builder( + requestVersion, + requestVersion, + new ProduceRequestData() + .setAcks((short) 1) + .setTimeoutMs(1000)); + TestCallbackHandler handler = new TestCallbackHandler(); + ClientRequest request = client.newClientRequest(node.idString(), builder, time.milliseconds(), true, + defaultRequestTimeoutMs, handler); + client.send(request, time.milliseconds()); + client.poll(1, time.milliseconds()); + int throttleTime = 100; + ProduceResponse produceResponse = new ProduceResponse(new ProduceResponseData().setThrottleTimeMs(throttleTime)); + ByteBuffer buffer = RequestTestUtils.serializeResponseWithHeader(produceResponse, requestVersion, request.correlationId()); + selector.completeReceive(new NetworkReceive(node.idString(), buffer)); + client.poll(1, time.milliseconds()); + + // The connection is not ready due to throttling. + assertFalse(client.ready(node, time.milliseconds())); + assertEquals(100, client.throttleDelayMs(node, time.milliseconds())); + + // After 50ms, the connection is not ready yet. + time.sleep(50); + assertFalse(client.ready(node, time.milliseconds())); + assertEquals(50, client.throttleDelayMs(node, time.milliseconds())); + + // After another 50ms, the throttling is done and the connection becomes ready again. + time.sleep(50); + assertTrue(client.ready(node, time.milliseconds())); + assertEquals(0, client.throttleDelayMs(node, time.milliseconds())); + } + + // Creates expected ApiVersionsResponse from the specified node, where the max protocol version for the specified + // key is set to the specified version. + private ApiVersionsResponse createExpectedApiVersionsResponse(ApiKeys key, short maxVersion) { + ApiVersionCollection versionList = new ApiVersionCollection(); + for (ApiKeys apiKey : ApiKeys.values()) { + if (apiKey == key) { + versionList.add(new ApiVersion() + .setApiKey(apiKey.id) + .setMinVersion((short) 0) + .setMaxVersion(maxVersion)); + } else versionList.add(ApiVersionsResponse.toApiVersion(apiKey)); + } + return new ApiVersionsResponse(new ApiVersionsResponseData() + .setErrorCode(Errors.NONE.code()) + .setThrottleTimeMs(0) + .setApiKeys(versionList)); + } + + @Test + public void testThrottlingNotEnabledForConnectionToOlderBroker() { + // Instrument the test so that the max protocol version for PRODUCE returned from the node is 5 and thus + // client-side throttling is not enabled. Also, return a response with a 100ms throttle delay. + setExpectedApiVersionsResponse(createExpectedApiVersionsResponse(PRODUCE, (short) 5)); + while (!client.ready(node, time.milliseconds())) + client.poll(1, time.milliseconds()); + selector.clear(); + + int correlationId = sendEmptyProduceRequest(); + client.poll(1, time.milliseconds()); + + sendThrottledProduceResponse(correlationId, 100, (short) 5); + client.poll(1, time.milliseconds()); + + // Since client-side throttling is disabled, the connection is ready even though the response indicated a + // throttle delay. 
+ assertTrue(client.ready(node, time.milliseconds())); + assertEquals(0, client.throttleDelayMs(node, time.milliseconds())); + } + + private int sendEmptyProduceRequest() { + return sendEmptyProduceRequest(node.idString()); + } + + private int sendEmptyProduceRequest(String nodeId) { + ProduceRequest.Builder builder = ProduceRequest.forCurrentMagic(new ProduceRequestData() + .setTopicData(new ProduceRequestData.TopicProduceDataCollection()) + .setAcks((short) 1) + .setTimeoutMs(1000)); + TestCallbackHandler handler = new TestCallbackHandler(); + ClientRequest request = client.newClientRequest(nodeId, builder, time.milliseconds(), true, + defaultRequestTimeoutMs, handler); + client.send(request, time.milliseconds()); + return request.correlationId(); + } + + private void sendResponse(AbstractResponse response, short version, int correlationId) { + ByteBuffer buffer = RequestTestUtils.serializeResponseWithHeader(response, version, correlationId); + selector.completeReceive(new NetworkReceive(node.idString(), buffer)); + } + + private void sendThrottledProduceResponse(int correlationId, int throttleMs, short version) { + ProduceResponse response = new ProduceResponse(new ProduceResponseData().setThrottleTimeMs(throttleMs)); + sendResponse(response, version, correlationId); + } + + @Test + public void testLeastLoadedNode() { + client.ready(node, time.milliseconds()); + assertFalse(client.isReady(node, time.milliseconds())); + assertEquals(node, client.leastLoadedNode(time.milliseconds())); + + awaitReady(client, node); + client.poll(1, time.milliseconds()); + assertTrue(client.isReady(node, time.milliseconds()), "The client should be ready"); + + // leastloadednode should be our single node + Node leastNode = client.leastLoadedNode(time.milliseconds()); + assertEquals(leastNode.id(), node.id(), "There should be one leastloadednode"); + + // sleep for longer than reconnect backoff + time.sleep(reconnectBackoffMsTest); + + // CLOSE node + selector.serverDisconnect(node.idString()); + + client.poll(1, time.milliseconds()); + assertFalse(client.ready(node, time.milliseconds()), "After we forced the disconnection the client is no longer ready."); + leastNode = client.leastLoadedNode(time.milliseconds()); + assertNull(leastNode, "There should be NO leastloadednode"); + } + + @Test + public void testLeastLoadedNodeProvideDisconnectedNodesPrioritizedByLastConnectionTimestamp() { + int nodeNumber = 3; + NetworkClient client = createNetworkClientWithMultipleNodes(0, connectionSetupTimeoutMsTest, nodeNumber); + + Set providedNodeIds = new HashSet<>(); + for (int i = 0; i < nodeNumber * 10; i++) { + Node node = client.leastLoadedNode(time.milliseconds()); + assertNotNull(node, "Should provide a node"); + providedNodeIds.add(node); + client.ready(node, time.milliseconds()); + client.disconnect(node.idString()); + time.sleep(connectionSetupTimeoutMsTest + 1); + client.poll(0, time.milliseconds()); + // Define a round as nodeNumber of nodes have been provided + // In each round every node should be provided exactly once + if ((i + 1) % nodeNumber == 0) { + assertEquals(nodeNumber, providedNodeIds.size(), "All the nodes should be provided"); + providedNodeIds.clear(); + } + } + } + + @Test + public void testAuthenticationFailureWithInFlightMetadataRequest() { + int refreshBackoffMs = 50; + + MetadataResponse metadataResponse = RequestTestUtils.metadataUpdateWith(2, Collections.emptyMap()); + Metadata metadata = new Metadata(refreshBackoffMs, 5000, new LogContext(), new ClusterResourceListeners()); + 
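// Seed the metadata with a two-node cluster so that a refresh can be in flight to one node while the other + // node fails authentication; the in-flight update should still complete and bump the metadata update version. +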
metadata.updateWithCurrentRequestVersion(metadataResponse, false, time.milliseconds()); + + Cluster cluster = metadata.fetch(); + Node node1 = cluster.nodes().get(0); + Node node2 = cluster.nodes().get(1); + + NetworkClient client = createNetworkClientWithNoVersionDiscovery(metadata); + + awaitReady(client, node1); + + metadata.requestUpdate(); + time.sleep(refreshBackoffMs); + + client.poll(0, time.milliseconds()); + + Optional nodeWithPendingMetadataOpt = cluster.nodes().stream() + .filter(node -> client.hasInFlightRequests(node.idString())) + .findFirst(); + assertEquals(Optional.of(node1), nodeWithPendingMetadataOpt); + + assertFalse(client.ready(node2, time.milliseconds())); + selector.serverAuthenticationFailed(node2.idString()); + client.poll(0, time.milliseconds()); + assertNotNull(client.authenticationException(node2)); + + ByteBuffer requestBuffer = selector.completedSendBuffers().get(0).buffer(); + RequestHeader header = parseHeader(requestBuffer); + assertEquals(ApiKeys.METADATA, header.apiKey()); + + ByteBuffer responseBuffer = RequestTestUtils.serializeResponseWithHeader(metadataResponse, header.apiVersion(), header.correlationId()); + selector.delayedReceive(new DelayedReceive(node1.idString(), new NetworkReceive(node1.idString(), responseBuffer))); + + int initialUpdateVersion = metadata.updateVersion(); + client.poll(0, time.milliseconds()); + assertEquals(initialUpdateVersion + 1, metadata.updateVersion()); + } + + @Test + public void testLeastLoadedNodeConsidersThrottledConnections() { + client.ready(node, time.milliseconds()); + awaitReady(client, node); + client.poll(1, time.milliseconds()); + assertTrue(client.isReady(node, time.milliseconds()), "The client should be ready"); + + int correlationId = sendEmptyProduceRequest(); + client.poll(1, time.milliseconds()); + + sendThrottledProduceResponse(correlationId, 100, PRODUCE.latestVersion()); + client.poll(1, time.milliseconds()); + + // leastloadednode should return null since the node is throttled + assertNull(client.leastLoadedNode(time.milliseconds())); + } + + @Test + public void testConnectionDelayWithNoExponentialBackoff() { + long now = time.milliseconds(); + long delay = clientWithNoExponentialBackoff.connectionDelay(node, now); + + assertEquals(0, delay); + } + + @Test + public void testConnectionDelayConnectedWithNoExponentialBackoff() { + awaitReady(clientWithNoExponentialBackoff, node); + + long now = time.milliseconds(); + long delay = clientWithNoExponentialBackoff.connectionDelay(node, now); + + assertEquals(Long.MAX_VALUE, delay); + } + + @Test + public void testConnectionDelayDisconnectedWithNoExponentialBackoff() { + awaitReady(clientWithNoExponentialBackoff, node); + + selector.serverDisconnect(node.idString()); + clientWithNoExponentialBackoff.poll(defaultRequestTimeoutMs, time.milliseconds()); + long delay = clientWithNoExponentialBackoff.connectionDelay(node, time.milliseconds()); + + assertEquals(reconnectBackoffMsTest, delay); + + // Sleep until there is no connection delay + time.sleep(delay); + assertEquals(0, clientWithNoExponentialBackoff.connectionDelay(node, time.milliseconds())); + + // Start connecting and disconnect before the connection is established + client.ready(node, time.milliseconds()); + selector.serverDisconnect(node.idString()); + client.poll(defaultRequestTimeoutMs, time.milliseconds()); + + // Second attempt should have the same behaviour as exponential backoff is disabled + assertEquals(reconnectBackoffMsTest, delay); + } + + @Test + public void testConnectionDelay() { + 
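// No connection attempt has been made to the node yet, so no reconnect backoff is in effect and the + // connection delay is expected to be zero. +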
long now = time.milliseconds(); + long delay = client.connectionDelay(node, now); + + assertEquals(0, delay); + } + + @Test + public void testConnectionDelayConnected() { + awaitReady(client, node); + + long now = time.milliseconds(); + long delay = client.connectionDelay(node, now); + + assertEquals(Long.MAX_VALUE, delay); + } + + @Test + public void testConnectionDelayDisconnected() { + awaitReady(client, node); + + // First disconnection + selector.serverDisconnect(node.idString()); + client.poll(defaultRequestTimeoutMs, time.milliseconds()); + long delay = client.connectionDelay(node, time.milliseconds()); + long expectedDelay = reconnectBackoffMsTest; + double jitter = 0.3; + assertEquals(expectedDelay, delay, expectedDelay * jitter); + + // Sleep until there is no connection delay + time.sleep(delay); + assertEquals(0, client.connectionDelay(node, time.milliseconds())); + + // Start connecting and disconnect before the connection is established + client.ready(node, time.milliseconds()); + selector.serverDisconnect(node.idString()); + client.poll(defaultRequestTimeoutMs, time.milliseconds()); + + // Second attempt should take twice as long with twice the jitter + expectedDelay = Math.round(delay * 2); + delay = client.connectionDelay(node, time.milliseconds()); + jitter = 0.6; + assertEquals(expectedDelay, delay, expectedDelay * jitter); + } + + @Test + public void testDisconnectDuringUserMetadataRequest() { + // this test ensures that the default metadata updater does not intercept a user-initiated + // metadata request when the remote node disconnects with the request in-flight. + awaitReady(client, node); + + MetadataRequest.Builder builder = new MetadataRequest.Builder(Collections.emptyList(), true); + long now = time.milliseconds(); + ClientRequest request = client.newClientRequest(node.idString(), builder, now, true); + client.send(request, now); + client.poll(defaultRequestTimeoutMs, now); + assertEquals(1, client.inFlightRequestCount(node.idString())); + assertTrue(client.hasInFlightRequests(node.idString())); + assertTrue(client.hasInFlightRequests()); + + selector.close(node.idString()); + List responses = client.poll(defaultRequestTimeoutMs, time.milliseconds()); + assertEquals(1, responses.size()); + assertTrue(responses.iterator().next().wasDisconnected()); + } + + @Test + public void testServerDisconnectAfterInternalApiVersionRequest() throws Exception { + awaitInFlightApiVersionRequest(); + selector.serverDisconnect(node.idString()); + + // The failed ApiVersion request should not be forwarded to upper layers + List responses = client.poll(0, time.milliseconds()); + assertFalse(client.hasInFlightRequests(node.idString())); + assertTrue(responses.isEmpty()); + } + + @Test + public void testClientDisconnectAfterInternalApiVersionRequest() throws Exception { + awaitInFlightApiVersionRequest(); + client.disconnect(node.idString()); + assertFalse(client.hasInFlightRequests(node.idString())); + + // The failed ApiVersion request should not be forwarded to upper layers + List responses = client.poll(0, time.milliseconds()); + assertTrue(responses.isEmpty()); + } + + @Test + public void testDisconnectWithMultipleInFlights() { + NetworkClient client = this.clientWithNoVersionDiscovery; + awaitReady(client, node); + assertTrue(client.isReady(node, time.milliseconds()), + "Expected NetworkClient to be ready to send to node " + node.idString()); + + MetadataRequest.Builder builder = new MetadataRequest.Builder(Collections.emptyList(), true); + long now = time.milliseconds(); + + 
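// Queue two metadata requests to the same node and then disconnect it; both callbacks should fire with + // disconnected responses, returned in the same order in which the requests were sent. +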
final List callbackResponses = new ArrayList<>(); + RequestCompletionHandler callback = callbackResponses::add; + + ClientRequest request1 = client.newClientRequest(node.idString(), builder, now, true, defaultRequestTimeoutMs, callback); + client.send(request1, now); + client.poll(0, now); + + ClientRequest request2 = client.newClientRequest(node.idString(), builder, now, true, defaultRequestTimeoutMs, callback); + client.send(request2, now); + client.poll(0, now); + + assertNotEquals(request1.correlationId(), request2.correlationId()); + + assertEquals(2, client.inFlightRequestCount()); + assertEquals(2, client.inFlightRequestCount(node.idString())); + + client.disconnect(node.idString()); + + List responses = client.poll(0, time.milliseconds()); + assertEquals(2, responses.size()); + assertEquals(responses, callbackResponses); + assertEquals(0, client.inFlightRequestCount()); + assertEquals(0, client.inFlightRequestCount(node.idString())); + + // Ensure that the responses are returned in the order they were sent + ClientResponse response1 = responses.get(0); + assertTrue(response1.wasDisconnected()); + assertEquals(request1.correlationId(), response1.requestHeader().correlationId()); + + ClientResponse response2 = responses.get(1); + assertTrue(response2.wasDisconnected()); + assertEquals(request2.correlationId(), response2.requestHeader().correlationId()); + } + + @Test + public void testCallDisconnect() throws Exception { + awaitReady(client, node); + assertTrue(client.isReady(node, time.milliseconds()), + "Expected NetworkClient to be ready to send to node " + node.idString()); + assertFalse(client.connectionFailed(node), + "Did not expect connection to node " + node.idString() + " to be failed"); + client.disconnect(node.idString()); + assertFalse(client.isReady(node, time.milliseconds()), + "Expected node " + node.idString() + " to be disconnected."); + assertTrue(client.connectionFailed(node), + "Expected connection to node " + node.idString() + " to be failed after disconnect"); + assertFalse(client.canConnect(node, time.milliseconds())); + + // ensure disconnect does not reset backoff period if already disconnected + time.sleep(reconnectBackoffMaxMsTest); + assertTrue(client.canConnect(node, time.milliseconds())); + client.disconnect(node.idString()); + assertTrue(client.canConnect(node, time.milliseconds())); + } + + @Test + public void testCorrelationId() { + int count = 100; + Set ids = IntStream.range(0, count) + .mapToObj(i -> client.nextCorrelationId()) + .collect(Collectors.toSet()); + assertEquals(count, ids.size()); + ids.forEach(id -> assertTrue(id < SaslClientAuthenticator.MIN_RESERVED_CORRELATION_ID)); + } + + @Test + public void testReconnectAfterAddressChange() { + AddressChangeHostResolver mockHostResolver = new AddressChangeHostResolver( + initialAddresses.toArray(new InetAddress[0]), newAddresses.toArray(new InetAddress[0])); + AtomicInteger initialAddressConns = new AtomicInteger(); + AtomicInteger newAddressConns = new AtomicInteger(); + MockSelector selector = new MockSelector(this.time, inetSocketAddress -> { + InetAddress inetAddress = inetSocketAddress.getAddress(); + if (initialAddresses.contains(inetAddress)) { + initialAddressConns.incrementAndGet(); + } else if (newAddresses.contains(inetAddress)) { + newAddressConns.incrementAndGet(); + } + return (mockHostResolver.useNewAddresses() && newAddresses.contains(inetAddress)) || + (!mockHostResolver.useNewAddresses() && initialAddresses.contains(inetAddress)); + }); + NetworkClient client = new 
NetworkClient(metadataUpdater, null, selector, "mock", Integer.MAX_VALUE, + reconnectBackoffMsTest, reconnectBackoffMaxMsTest, 64 * 1024, 64 * 1024, + defaultRequestTimeoutMs, connectionSetupTimeoutMsTest, connectionSetupTimeoutMaxMsTest, + time, false, new ApiVersions(), null, new LogContext(), mockHostResolver); + + // Connect to one the initial addresses, then change the addresses and disconnect + client.ready(node, time.milliseconds()); + time.sleep(connectionSetupTimeoutMaxMsTest); + client.poll(0, time.milliseconds()); + assertTrue(client.isReady(node, time.milliseconds())); + + mockHostResolver.changeAddresses(); + selector.serverDisconnect(node.idString()); + client.poll(0, time.milliseconds()); + assertFalse(client.isReady(node, time.milliseconds())); + + time.sleep(reconnectBackoffMaxMsTest); + client.ready(node, time.milliseconds()); + time.sleep(connectionSetupTimeoutMaxMsTest); + client.poll(0, time.milliseconds()); + assertTrue(client.isReady(node, time.milliseconds())); + + // We should have tried to connect to one initial address and one new address, and resolved DNS twice + assertEquals(1, initialAddressConns.get()); + assertEquals(1, newAddressConns.get()); + assertEquals(2, mockHostResolver.resolutionCount()); + } + + @Test + public void testFailedConnectionToFirstAddress() { + AddressChangeHostResolver mockHostResolver = new AddressChangeHostResolver( + initialAddresses.toArray(new InetAddress[0]), newAddresses.toArray(new InetAddress[0])); + AtomicInteger initialAddressConns = new AtomicInteger(); + AtomicInteger newAddressConns = new AtomicInteger(); + MockSelector selector = new MockSelector(this.time, inetSocketAddress -> { + InetAddress inetAddress = inetSocketAddress.getAddress(); + if (initialAddresses.contains(inetAddress)) { + initialAddressConns.incrementAndGet(); + } else if (newAddresses.contains(inetAddress)) { + newAddressConns.incrementAndGet(); + } + // Refuse first connection attempt + return initialAddressConns.get() > 1; + }); + NetworkClient client = new NetworkClient(metadataUpdater, null, selector, "mock", Integer.MAX_VALUE, + reconnectBackoffMsTest, reconnectBackoffMaxMsTest, 64 * 1024, 64 * 1024, + defaultRequestTimeoutMs, connectionSetupTimeoutMsTest, connectionSetupTimeoutMaxMsTest, + time, false, new ApiVersions(), null, new LogContext(), mockHostResolver); + + // First connection attempt should fail + client.ready(node, time.milliseconds()); + time.sleep(connectionSetupTimeoutMaxMsTest); + client.poll(0, time.milliseconds()); + assertFalse(client.isReady(node, time.milliseconds())); + + // Second connection attempt should succeed + time.sleep(reconnectBackoffMaxMsTest); + client.ready(node, time.milliseconds()); + time.sleep(connectionSetupTimeoutMaxMsTest); + client.poll(0, time.milliseconds()); + assertTrue(client.isReady(node, time.milliseconds())); + + // We should have tried to connect to two of the initial addresses, none of the new address, and should + // only have resolved DNS once + assertEquals(2, initialAddressConns.get()); + assertEquals(0, newAddressConns.get()); + assertEquals(1, mockHostResolver.resolutionCount()); + } + + @Test + public void testFailedConnectionToFirstAddressAfterReconnect() { + AddressChangeHostResolver mockHostResolver = new AddressChangeHostResolver( + initialAddresses.toArray(new InetAddress[0]), newAddresses.toArray(new InetAddress[0])); + AtomicInteger initialAddressConns = new AtomicInteger(); + AtomicInteger newAddressConns = new AtomicInteger(); + MockSelector selector = new MockSelector(this.time, 
inetSocketAddress -> { + InetAddress inetAddress = inetSocketAddress.getAddress(); + if (initialAddresses.contains(inetAddress)) { + initialAddressConns.incrementAndGet(); + } else if (newAddresses.contains(inetAddress)) { + newAddressConns.incrementAndGet(); + } + // Refuse first connection attempt to the new addresses + return initialAddresses.contains(inetAddress) || newAddressConns.get() > 1; + }); + NetworkClient client = new NetworkClient(metadataUpdater, null, selector, "mock", Integer.MAX_VALUE, + reconnectBackoffMsTest, reconnectBackoffMaxMsTest, 64 * 1024, 64 * 1024, + defaultRequestTimeoutMs, connectionSetupTimeoutMsTest, connectionSetupTimeoutMaxMsTest, + time, false, new ApiVersions(), null, new LogContext(), mockHostResolver); + + // Connect to one the initial addresses, then change the addresses and disconnect + client.ready(node, time.milliseconds()); + time.sleep(connectionSetupTimeoutMaxMsTest); + client.poll(0, time.milliseconds()); + assertTrue(client.isReady(node, time.milliseconds())); + + mockHostResolver.changeAddresses(); + selector.serverDisconnect(node.idString()); + client.poll(0, time.milliseconds()); + assertFalse(client.isReady(node, time.milliseconds())); + + // First connection attempt to new addresses should fail + time.sleep(reconnectBackoffMaxMsTest); + client.ready(node, time.milliseconds()); + time.sleep(connectionSetupTimeoutMaxMsTest); + client.poll(0, time.milliseconds()); + assertFalse(client.isReady(node, time.milliseconds())); + + // Second connection attempt to new addresses should succeed + time.sleep(reconnectBackoffMaxMsTest); + client.ready(node, time.milliseconds()); + time.sleep(connectionSetupTimeoutMaxMsTest); + client.poll(0, time.milliseconds()); + assertTrue(client.isReady(node, time.milliseconds())); + + // We should have tried to connect to one of the initial addresses and two of the new addresses (the first one + // failed), and resolved DNS twice, once for each set of addresses + assertEquals(1, initialAddressConns.get()); + assertEquals(2, newAddressConns.get()); + assertEquals(2, mockHostResolver.resolutionCount()); + } + + @Test + public void testCloseConnectingNode() { + Cluster cluster = TestUtils.clusterWith(2); + Node node0 = cluster.nodeById(0); + Node node1 = cluster.nodeById(1); + client.ready(node0, time.milliseconds()); + selector.serverConnectionBlocked(node0.idString()); + client.poll(1, time.milliseconds()); + client.close(node0.idString()); + + // Poll without any connections should return without exceptions + client.poll(0, time.milliseconds()); + assertFalse(NetworkClientUtils.isReady(client, node0, time.milliseconds())); + assertFalse(NetworkClientUtils.isReady(client, node1, time.milliseconds())); + + // Connection to new node should work + client.ready(node1, time.milliseconds()); + ByteBuffer buffer = RequestTestUtils.serializeResponseWithHeader(defaultApiVersionsResponse(), ApiKeys.API_VERSIONS.latestVersion(), 0); + selector.delayedReceive(new DelayedReceive(node1.idString(), new NetworkReceive(node1.idString(), buffer))); + while (!client.ready(node1, time.milliseconds())) + client.poll(1, time.milliseconds()); + assertTrue(client.isReady(node1, time.milliseconds())); + selector.clear(); + + // New connection to node closed earlier should work + client.ready(node0, time.milliseconds()); + buffer = RequestTestUtils.serializeResponseWithHeader(defaultApiVersionsResponse(), ApiKeys.API_VERSIONS.latestVersion(), 1); + selector.delayedReceive(new DelayedReceive(node0.idString(), new 
NetworkReceive(node0.idString(), buffer))); + while (!client.ready(node0, time.milliseconds())) + client.poll(1, time.milliseconds()); + assertTrue(client.isReady(node0, time.milliseconds())); + } + + private RequestHeader parseHeader(ByteBuffer buffer) { + buffer.getInt(); // skip size + return RequestHeader.parse(buffer.slice()); + } + + private void awaitInFlightApiVersionRequest() throws Exception { + client.ready(node, time.milliseconds()); + TestUtils.waitForCondition(() -> { + client.poll(0, time.milliseconds()); + return client.hasInFlightRequests(node.idString()); + }, 1000, ""); + assertFalse(client.isReady(node, time.milliseconds())); + } + + private ApiVersionsResponse defaultApiVersionsResponse() { + return ApiVersionsResponse.defaultApiVersionsResponse(ApiMessageType.ListenerType.ZK_BROKER); + } + + private static class TestCallbackHandler implements RequestCompletionHandler { + public boolean executed = false; + public ClientResponse response; + + public void onComplete(ClientResponse response) { + this.executed = true; + this.response = response; + } + } + + // ManualMetadataUpdater with ability to keep track of failures + private static class TestMetadataUpdater extends ManualMetadataUpdater { + KafkaException failure; + + public TestMetadataUpdater(List nodes) { + super(nodes); + } + + @Override + public void handleServerDisconnect(long now, String destinationId, Optional maybeAuthException) { + maybeAuthException.ifPresent(exception -> { + failure = exception; + }); + super.handleServerDisconnect(now, destinationId, maybeAuthException); + } + + @Override + public void handleFailedRequest(long now, Optional maybeFatalException) { + maybeFatalException.ifPresent(exception -> { + failure = exception; + }); + } + + public KafkaException getAndClearFailure() { + KafkaException failure = this.failure; + this.failure = null; + return failure; + } + } +} diff --git a/clients/src/test/java/org/apache/kafka/clients/NodeApiVersionsTest.java b/clients/src/test/java/org/apache/kafka/clients/NodeApiVersionsTest.java new file mode 100644 index 0000000..b04d83b --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/clients/NodeApiVersionsTest.java @@ -0,0 +1,168 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.clients; + +import org.apache.kafka.common.errors.UnsupportedVersionException; +import org.apache.kafka.common.message.ApiMessageType; +import org.apache.kafka.common.message.ApiVersionsResponseData.ApiVersion; +import org.apache.kafka.common.message.ApiVersionsResponseData.ApiVersionCollection; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.requests.ApiVersionsResponse; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.EnumSource; + +import java.util.ArrayList; +import java.util.LinkedList; +import java.util.List; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class NodeApiVersionsTest { + + @Test + public void testUnsupportedVersionsToString() { + NodeApiVersions versions = new NodeApiVersions(new ApiVersionCollection()); + StringBuilder bld = new StringBuilder(); + String prefix = "("; + for (ApiKeys apiKey : ApiKeys.zkBrokerApis()) { + bld.append(prefix).append(apiKey.name). + append("(").append(apiKey.id).append("): UNSUPPORTED"); + prefix = ", "; + } + bld.append(")"); + assertEquals(bld.toString(), versions.toString()); + } + + @Test + public void testUnknownApiVersionsToString() { + NodeApiVersions versions = NodeApiVersions.create((short) 337, (short) 0, (short) 1); + assertTrue(versions.toString().endsWith("UNKNOWN(337): 0 to 1)")); + } + + @Test + public void testVersionsToString() { + List versionList = new ArrayList<>(); + for (ApiKeys apiKey : ApiKeys.values()) { + if (apiKey == ApiKeys.DELETE_TOPICS) { + versionList.add(new ApiVersion() + .setApiKey(apiKey.id) + .setMinVersion((short) 10000) + .setMaxVersion((short) 10001)); + } else versionList.add(ApiVersionsResponse.toApiVersion(apiKey)); + } + NodeApiVersions versions = new NodeApiVersions(versionList); + StringBuilder bld = new StringBuilder(); + String prefix = "("; + for (ApiKeys apiKey : ApiKeys.values()) { + bld.append(prefix); + if (apiKey == ApiKeys.DELETE_TOPICS) { + bld.append("DeleteTopics(20): 10000 to 10001 [unusable: node too new]"); + } else { + bld.append(apiKey.name).append("("). + append(apiKey.id).append("): "); + if (apiKey.oldestVersion() == + apiKey.latestVersion()) { + bld.append(apiKey.oldestVersion()); + } else { + bld.append(apiKey.oldestVersion()). + append(" to "). + append(apiKey.latestVersion()); + } + bld.append(" [usable: ").append(apiKey.latestVersion()). 
+ append("]"); + } + prefix = ", "; + } + bld.append(")"); + assertEquals(bld.toString(), versions.toString()); + } + + @Test + public void testLatestUsableVersion() { + NodeApiVersions apiVersions = NodeApiVersions.create(ApiKeys.PRODUCE.id, (short) 1, (short) 3); + assertEquals(3, apiVersions.latestUsableVersion(ApiKeys.PRODUCE)); + assertEquals(1, apiVersions.latestUsableVersion(ApiKeys.PRODUCE, (short) 0, (short) 1)); + assertEquals(1, apiVersions.latestUsableVersion(ApiKeys.PRODUCE, (short) 1, (short) 1)); + assertEquals(2, apiVersions.latestUsableVersion(ApiKeys.PRODUCE, (short) 1, (short) 2)); + assertEquals(3, apiVersions.latestUsableVersion(ApiKeys.PRODUCE, (short) 1, (short) 3)); + assertEquals(2, apiVersions.latestUsableVersion(ApiKeys.PRODUCE, (short) 2, (short) 2)); + assertEquals(3, apiVersions.latestUsableVersion(ApiKeys.PRODUCE, (short) 2, (short) 3)); + assertEquals(3, apiVersions.latestUsableVersion(ApiKeys.PRODUCE, (short) 3, (short) 3)); + assertEquals(3, apiVersions.latestUsableVersion(ApiKeys.PRODUCE, (short) 3, (short) 4)); + } + + @Test + public void testLatestUsableVersionOutOfRangeLow() { + NodeApiVersions apiVersions = NodeApiVersions.create(ApiKeys.PRODUCE.id, (short) 1, (short) 2); + assertThrows(UnsupportedVersionException.class, + () -> apiVersions.latestUsableVersion(ApiKeys.PRODUCE, (short) 3, (short) 4)); + } + + @Test + public void testLatestUsableVersionOutOfRangeHigh() { + NodeApiVersions apiVersions = NodeApiVersions.create(ApiKeys.PRODUCE.id, (short) 2, (short) 3); + assertThrows(UnsupportedVersionException.class, + () -> apiVersions.latestUsableVersion(ApiKeys.PRODUCE, (short) 0, (short) 1)); + } + + @Test + public void testUsableVersionCalculationNoKnownVersions() { + NodeApiVersions versions = new NodeApiVersions(new ApiVersionCollection()); + assertThrows(UnsupportedVersionException.class, + () -> versions.latestUsableVersion(ApiKeys.FETCH)); + } + + @Test + public void testLatestUsableVersionOutOfRange() { + NodeApiVersions apiVersions = NodeApiVersions.create(ApiKeys.PRODUCE.id, (short) 300, (short) 300); + assertThrows(UnsupportedVersionException.class, + () -> apiVersions.latestUsableVersion(ApiKeys.PRODUCE)); + } + + @ParameterizedTest + @EnumSource(ApiMessageType.ListenerType.class) + public void testUsableVersionLatestVersions(ApiMessageType.ListenerType scope) { + ApiVersionsResponse defaultResponse = ApiVersionsResponse.defaultApiVersionsResponse(scope); + List versionList = new LinkedList<>(defaultResponse.data().apiKeys()); + // Add an API key that we don't know about. 
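+ // Key 100 is not a known ApiKeys value, so it should be ignored when computing the
+ // usable versions asserted for the known APIs below.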
+ versionList.add(new ApiVersion() + .setApiKey((short) 100) + .setMinVersion((short) 0) + .setMaxVersion((short) 1)); + NodeApiVersions versions = new NodeApiVersions(versionList); + for (ApiKeys apiKey: ApiKeys.apisForListener(scope)) { + assertEquals(apiKey.latestVersion(), versions.latestUsableVersion(apiKey)); + } + } + + @ParameterizedTest + @EnumSource(ApiMessageType.ListenerType.class) + public void testConstructionFromApiVersionsResponse(ApiMessageType.ListenerType scope) { + ApiVersionsResponse apiVersionsResponse = ApiVersionsResponse.defaultApiVersionsResponse(scope); + NodeApiVersions versions = new NodeApiVersions(apiVersionsResponse.data().apiKeys()); + + for (ApiVersion apiVersionKey : apiVersionsResponse.data().apiKeys()) { + ApiVersion apiVersion = versions.apiVersion(ApiKeys.forId(apiVersionKey.apiKey())); + assertEquals(apiVersionKey.apiKey(), apiVersion.apiKey()); + assertEquals(apiVersionKey.minVersion(), apiVersion.minVersion()); + assertEquals(apiVersionKey.maxVersion(), apiVersion.maxVersion()); + } + } +} diff --git a/clients/src/test/java/org/apache/kafka/clients/admin/AdminClientTestUtils.java b/clients/src/test/java/org/apache/kafka/clients/admin/AdminClientTestUtils.java new file mode 100644 index 0000000..587434a --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/clients/admin/AdminClientTestUtils.java @@ -0,0 +1,127 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients.admin; + +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +import org.apache.kafka.clients.HostResolver; +import org.apache.kafka.clients.admin.CreateTopicsResult.TopicMetadataAndConfig; +import org.apache.kafka.clients.admin.internals.MetadataOperationContext; +import org.apache.kafka.clients.consumer.OffsetAndMetadata; +import org.apache.kafka.common.KafkaFuture; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.Uuid; +import org.apache.kafka.common.internals.KafkaFutureImpl; + +public class AdminClientTestUtils { + + /** + * Helper to create a ListPartitionReassignmentsResult instance for a given Throwable. + * ListPartitionReassignmentsResult's constructor is only accessible from within the + * admin package. + */ + public static ListPartitionReassignmentsResult listPartitionReassignmentsResult(Throwable t) { + KafkaFutureImpl> future = new KafkaFutureImpl<>(); + future.completeExceptionally(t); + return new ListPartitionReassignmentsResult(future); + } + + /** + * Helper to create a CreateTopicsResult instance for a given Throwable. + * CreateTopicsResult's constructor is only accessible from within the + * admin package. 
+ */ + public static CreateTopicsResult createTopicsResult(String topic, Throwable t) { + KafkaFutureImpl future = new KafkaFutureImpl<>(); + future.completeExceptionally(t); + return new CreateTopicsResult(Collections.singletonMap(topic, future)); + } + + /** + * Helper to create a DeleteTopicsResult instance for a given Throwable. + * DeleteTopicsResult's constructor is only accessible from within the + * admin package. + */ + public static DeleteTopicsResult deleteTopicsResult(String topic, Throwable t) { + KafkaFutureImpl future = new KafkaFutureImpl<>(); + future.completeExceptionally(t); + return DeleteTopicsResult.ofTopicNames(Collections.singletonMap(topic, future)); + } + + /** + * Helper to create a ListTopicsResult instance for a given topic. + * ListTopicsResult's constructor is only accessible from within the + * admin package. + */ + public static ListTopicsResult listTopicsResult(String topic) { + KafkaFutureImpl> future = new KafkaFutureImpl<>(); + future.complete(Collections.singletonMap(topic, new TopicListing(topic, Uuid.ZERO_UUID, false))); + return new ListTopicsResult(future); + } + + /** + * Helper to create a CreatePartitionsResult instance for a given Throwable. + * CreatePartitionsResult's constructor is only accessible from within the + * admin package. + */ + public static CreatePartitionsResult createPartitionsResult(String topic, Throwable t) { + KafkaFutureImpl future = new KafkaFutureImpl<>(); + future.completeExceptionally(t); + return new CreatePartitionsResult(Collections.singletonMap(topic, future)); + } + + /** + * Helper to create a DescribeTopicsResult instance for a given topic. + * DescribeTopicsResult's constructor is only accessible from within the + * admin package. + */ + public static DescribeTopicsResult describeTopicsResult(String topic, TopicDescription description) { + KafkaFutureImpl future = new KafkaFutureImpl<>(); + future.complete(description); + return DescribeTopicsResult.ofTopicNames(Collections.singletonMap(topic, future)); + } + + public static DescribeTopicsResult describeTopicsResult(Map topicDescriptions) { + return DescribeTopicsResult.ofTopicNames(topicDescriptions.entrySet().stream() + .collect(Collectors.toMap(Map.Entry::getKey, e -> KafkaFuture.completedFuture(e.getValue())))); + } + + public static ListConsumerGroupOffsetsResult listConsumerGroupOffsetsResult(Map offsets) { + return new ListConsumerGroupOffsetsResult(KafkaFuture.completedFuture(offsets)); + } + + /** + * Used for benchmark. KafkaAdminClient.getListOffsetsCalls is only accessible + * from within the admin package. + */ + public static List getListOffsetsCalls(KafkaAdminClient adminClient, + MetadataOperationContext context, + Map topicPartitionOffsets, + Map> futures) { + return adminClient.getListOffsetsCalls(context, topicPartitionOffsets, futures); + } + + /** + * Helper to create a KafkaAdminClient with a custom HostResolver accessible to tests outside this package. 
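+ * For example, a test can supply a resolver that changes the addresses it returns in order to
+ * exercise client behaviour across a DNS change.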
+ */ + public static Admin create(Map conf, HostResolver hostResolver) { + return KafkaAdminClient.createInternal(new AdminClientConfig(conf, true), null, hostResolver); + } +} diff --git a/clients/src/test/java/org/apache/kafka/clients/admin/AdminClientUnitTestEnv.java b/clients/src/test/java/org/apache/kafka/clients/admin/AdminClientUnitTestEnv.java new file mode 100644 index 0000000..744ec12 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/clients/admin/AdminClientUnitTestEnv.java @@ -0,0 +1,139 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients.admin; + +import org.apache.kafka.clients.MockClient; +import org.apache.kafka.clients.admin.internals.AdminMetadataManager; +import org.apache.kafka.common.Cluster; +import org.apache.kafka.common.Node; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.common.utils.Time; + +import java.time.Duration; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** + * Simple utility for setting up a mock {@link KafkaAdminClient} that uses a {@link MockClient} for a supplied + * {@link Cluster}. Create a {@link Cluster} manually or use {@link org.apache.kafka.test.TestUtils} methods to + * easily create a simple cluster. + *
+ * <p>
            + * To use in a test, create an instance and prepare its {@link #kafkaClient() MockClient} with the expected responses + * for the {@link Admin}. Then, use the {@link #adminClient() AdminClient} in the test, which will then use the MockClient + * and receive the responses you provided. + * + * Since {@link #kafkaClient() MockClient} is not thread-safe, + * users should be wary of calling its methods after the {@link #adminClient() AdminClient} is instantiated. + * + *
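+ * A minimal usage sketch (illustrative only; {@code cluster} and {@code expectedResponse} are
+ * placeholders the caller is assumed to have built):
+ * <pre>
+ * try (AdminClientUnitTestEnv env = new AdminClientUnitTestEnv(Time.SYSTEM, cluster)) {
+ *     env.kafkaClient().setNodeApiVersions(NodeApiVersions.create());
+ *     env.kafkaClient().prepareResponse(expectedResponse);
+ *     env.adminClient().listTopics().names().get();
+ * }
+ * </pre>
+ *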
+ * <p>
            + * When finished, be sure to {@link #close() close} the environment object. + */ +public class AdminClientUnitTestEnv implements AutoCloseable { + private final Time time; + private final Cluster cluster; + private final MockClient mockClient; + private final KafkaAdminClient adminClient; + + public AdminClientUnitTestEnv(Cluster cluster, String... vals) { + this(Time.SYSTEM, cluster, vals); + } + + public AdminClientUnitTestEnv(Time time, Cluster cluster, String... vals) { + this(time, cluster, clientConfigs(vals)); + } + + public AdminClientUnitTestEnv(Time time, Cluster cluster) { + this(time, cluster, clientConfigs()); + } + + public AdminClientUnitTestEnv(Time time, Cluster cluster, Map config) { + this(time, cluster, config, Collections.emptyMap()); + } + + public AdminClientUnitTestEnv(Time time, Cluster cluster, Map config, Map unreachableNodes) { + this.time = time; + this.cluster = cluster; + AdminClientConfig adminClientConfig = new AdminClientConfig(config); + + AdminMetadataManager metadataManager = new AdminMetadataManager(new LogContext(), + adminClientConfig.getLong(AdminClientConfig.RETRY_BACKOFF_MS_CONFIG), + adminClientConfig.getLong(AdminClientConfig.METADATA_MAX_AGE_CONFIG)); + this.mockClient = new MockClient(time, new MockClient.MockMetadataUpdater() { + @Override + public List fetchNodes() { + return cluster.nodes(); + } + + @Override + public boolean isUpdateNeeded() { + return false; + } + + @Override + public void update(Time time, MockClient.MetadataUpdate update) { + throw new UnsupportedOperationException(); + } + }); + + metadataManager.update(cluster, time.milliseconds()); + unreachableNodes.forEach(mockClient::setUnreachable); + this.adminClient = KafkaAdminClient.createInternal(adminClientConfig, metadataManager, mockClient, time); + } + + public Time time() { + return time; + } + + public Cluster cluster() { + return cluster; + } + + public Admin adminClient() { + return adminClient; + } + + public MockClient kafkaClient() { + return mockClient; + } + + @Override + public void close() { + // tell the admin client to close now + this.adminClient.close(Duration.ZERO); + // block for up to a minute until the internal threads shut down. + this.adminClient.close(Duration.ofMinutes(1)); + } + + static Map clientConfigs(String... overrides) { + Map map = new HashMap<>(); + map.put(AdminClientConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:8121"); + map.put(AdminClientConfig.REQUEST_TIMEOUT_MS_CONFIG, "1000"); + if (overrides.length % 2 != 0) { + throw new IllegalStateException(); + } + for (int i = 0; i < overrides.length; i += 2) { + map.put(overrides[i], overrides[i + 1]); + } + return map; + } + + public static String kafkaAdminClientNetworkThreadPrefix() { + return KafkaAdminClient.NETWORK_THREAD_PREFIX; + } +} diff --git a/clients/src/test/java/org/apache/kafka/clients/admin/ConfigTest.java b/clients/src/test/java/org/apache/kafka/clients/admin/ConfigTest.java new file mode 100644 index 0000000..59d1150 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/clients/admin/ConfigTest.java @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.clients.admin; + +import org.apache.kafka.clients.admin.ConfigEntry.ConfigType; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.util.List; + +import static java.util.Arrays.asList; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class ConfigTest { + private static final ConfigEntry E1 = new ConfigEntry("a", "b"); + private static final ConfigEntry E2 = new ConfigEntry("c", "d"); + private Config config; + + @BeforeEach + public void setUp() { + config = new Config(asList(E1, E2)); + } + + @Test + public void shouldGetEntry() { + assertEquals(E1, config.get("a")); + assertEquals(E2, config.get("c")); + } + + @Test + public void shouldReturnNullOnGetUnknownEntry() { + assertNull(config.get("unknown")); + } + + @Test + public void shouldGetAllEntries() { + assertEquals(2, config.entries().size()); + assertTrue(config.entries().contains(E1)); + assertTrue(config.entries().contains(E2)); + } + + @Test + public void shouldImplementEqualsProperly() { + assertEquals(config, config); + assertEquals(config, new Config(config.entries())); + assertNotEquals(new Config(asList(E1)), config); + assertNotEquals(config, "this"); + } + + @Test + public void shouldImplementHashCodeProperly() { + assertEquals(config.hashCode(), config.hashCode()); + assertEquals(config.hashCode(), new Config(config.entries()).hashCode()); + assertNotEquals(new Config(asList(E1)).hashCode(), config.hashCode()); + } + + @Test + public void shouldImplementToStringProperly() { + assertTrue(config.toString().contains(E1.toString())); + assertTrue(config.toString().contains(E2.toString())); + } + + public static ConfigEntry newConfigEntry(String name, String value, ConfigEntry.ConfigSource source, boolean isSensitive, + boolean isReadOnly, List synonyms) { + return new ConfigEntry(name, value, source, isSensitive, isReadOnly, synonyms, ConfigType.UNKNOWN, null); + } + + @Test + public void testHashCodeAndEqualsWithNull() { + ConfigEntry ce0 = new ConfigEntry("abc", null, null, false, false, null, null, null); + ConfigEntry ce1 = new ConfigEntry("abc", null, null, false, false, null, null, null); + assertEquals(ce0, ce1); + assertEquals(ce0.hashCode(), ce1.hashCode()); + } + + @Test + public void testEquals() { + ConfigEntry ce0 = new ConfigEntry("abc", null, ConfigEntry.ConfigSource.DEFAULT_CONFIG, false, false, null, null, null); + ConfigEntry ce1 = new ConfigEntry("abc", null, ConfigEntry.ConfigSource.DYNAMIC_BROKER_CONFIG, false, false, null, null, null); + assertNotEquals(ce0, ce1); + } +} diff --git a/clients/src/test/java/org/apache/kafka/clients/admin/DeleteConsumerGroupOffsetsResultTest.java b/clients/src/test/java/org/apache/kafka/clients/admin/DeleteConsumerGroupOffsetsResultTest.java new file mode 100644 index 0000000..82a344c --- /dev/null +++ 
b/clients/src/test/java/org/apache/kafka/clients/admin/DeleteConsumerGroupOffsetsResultTest.java @@ -0,0 +1,118 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients.admin; + +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.errors.GroupAuthorizationException; +import org.apache.kafka.common.errors.UnknownTopicOrPartitionException; +import org.apache.kafka.common.internals.KafkaFutureImpl; + +import org.apache.kafka.common.protocol.Errors; +import org.apache.kafka.test.TestUtils; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.ExecutionException; + +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; + +public class DeleteConsumerGroupOffsetsResultTest { + + private final String topic = "topic"; + private final TopicPartition tpZero = new TopicPartition(topic, 0); + private final TopicPartition tpOne = new TopicPartition(topic, 1); + private Set partitions; + private Map errorsMap; + + private KafkaFutureImpl> partitionFutures; + + @BeforeEach + public void setUp() { + partitionFutures = new KafkaFutureImpl<>(); + partitions = new HashSet<>(); + partitions.add(tpZero); + partitions.add(tpOne); + + errorsMap = new HashMap<>(); + errorsMap.put(tpZero, Errors.NONE); + errorsMap.put(tpOne, Errors.UNKNOWN_TOPIC_OR_PARTITION); + } + + @Test + public void testTopLevelErrorConstructor() throws InterruptedException { + partitionFutures.completeExceptionally(Errors.GROUP_AUTHORIZATION_FAILED.exception()); + DeleteConsumerGroupOffsetsResult topLevelErrorResult = + new DeleteConsumerGroupOffsetsResult(partitionFutures, partitions); + TestUtils.assertFutureError(topLevelErrorResult.all(), GroupAuthorizationException.class); + } + + @Test + public void testPartitionLevelErrorConstructor() throws ExecutionException, InterruptedException { + createAndVerifyPartitionLevelErrror(); + } + + @Test + public void testPartitionMissingInResponseErrorConstructor() throws InterruptedException, ExecutionException { + errorsMap.remove(tpOne); + partitionFutures.complete(errorsMap); + assertFalse(partitionFutures.isCompletedExceptionally()); + DeleteConsumerGroupOffsetsResult missingPartitionResult = + new DeleteConsumerGroupOffsetsResult(partitionFutures, partitions); + + TestUtils.assertFutureError(missingPartitionResult.all(), IllegalArgumentException.class); + assertNull(missingPartitionResult.partitionResult(tpZero).get()); + TestUtils.assertFutureError(missingPartitionResult.partitionResult(tpOne), 
IllegalArgumentException.class); + } + + @Test + public void testPartitionMissingInRequestErrorConstructor() throws InterruptedException, ExecutionException { + DeleteConsumerGroupOffsetsResult partitionLevelErrorResult = createAndVerifyPartitionLevelErrror(); + assertThrows(IllegalArgumentException.class, () -> partitionLevelErrorResult.partitionResult(new TopicPartition("invalid-topic", 0))); + } + + @Test + public void testNoErrorConstructor() throws ExecutionException, InterruptedException { + Map errorsMap = new HashMap<>(); + errorsMap.put(tpZero, Errors.NONE); + errorsMap.put(tpOne, Errors.NONE); + DeleteConsumerGroupOffsetsResult noErrorResult = + new DeleteConsumerGroupOffsetsResult(partitionFutures, partitions); + partitionFutures.complete(errorsMap); + + assertNull(noErrorResult.all().get()); + assertNull(noErrorResult.partitionResult(tpZero).get()); + assertNull(noErrorResult.partitionResult(tpOne).get()); + } + + private DeleteConsumerGroupOffsetsResult createAndVerifyPartitionLevelErrror() throws InterruptedException, ExecutionException { + partitionFutures.complete(errorsMap); + assertFalse(partitionFutures.isCompletedExceptionally()); + DeleteConsumerGroupOffsetsResult partitionLevelErrorResult = + new DeleteConsumerGroupOffsetsResult(partitionFutures, partitions); + + TestUtils.assertFutureError(partitionLevelErrorResult.all(), UnknownTopicOrPartitionException.class); + assertNull(partitionLevelErrorResult.partitionResult(tpZero).get()); + TestUtils.assertFutureError(partitionLevelErrorResult.partitionResult(tpOne), UnknownTopicOrPartitionException.class); + return partitionLevelErrorResult; + } +} diff --git a/clients/src/test/java/org/apache/kafka/clients/admin/DeleteTopicsResultTest.java b/clients/src/test/java/org/apache/kafka/clients/admin/DeleteTopicsResultTest.java new file mode 100644 index 0000000..3c2add4 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/clients/admin/DeleteTopicsResultTest.java @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.clients.admin; + +import org.apache.kafka.common.KafkaFuture; +import org.apache.kafka.common.Uuid; +import org.apache.kafka.common.internals.KafkaFutureImpl; +import org.junit.jupiter.api.Test; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.util.Collections; +import java.util.Map; + +public class DeleteTopicsResultTest { + + @Test + public void testDeleteTopicsResultWithNames() { + KafkaFutureImpl future = new KafkaFutureImpl<>(); + future.complete(null); + Map> topicNames = Collections.singletonMap("foo", future); + + DeleteTopicsResult topicNameFutures = DeleteTopicsResult.ofTopicNames(topicNames); + + assertEquals(topicNames, topicNameFutures.topicNameValues()); + assertNull(topicNameFutures.topicIdValues()); + assertTrue(topicNameFutures.all().isDone()); + } + + @Test + public void testDeleteTopicsResultWithIds() { + KafkaFutureImpl future = new KafkaFutureImpl<>(); + future.complete(null); + Map> topicIds = Collections.singletonMap(Uuid.randomUuid(), future); + + DeleteTopicsResult topicIdFutures = DeleteTopicsResult.ofTopicIds(topicIds); + + assertEquals(topicIds, topicIdFutures.topicIdValues()); + assertNull(topicIdFutures.topicNameValues()); + assertTrue(topicIdFutures.all().isDone()); + } + + @Test + public void testInvalidConfigurations() { + assertThrows(IllegalArgumentException.class, () -> new DeleteTopicsResult(null, null)); + assertThrows(IllegalArgumentException.class, () -> new DeleteTopicsResult(Collections.emptyMap(), Collections.emptyMap())); + } +} diff --git a/clients/src/test/java/org/apache/kafka/clients/admin/DescribeUserScramCredentialsResultTest.java b/clients/src/test/java/org/apache/kafka/clients/admin/DescribeUserScramCredentialsResultTest.java new file mode 100644 index 0000000..9b5e98a --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/clients/admin/DescribeUserScramCredentialsResultTest.java @@ -0,0 +1,118 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.clients.admin; + +import org.apache.kafka.common.internals.KafkaFutureImpl; +import org.apache.kafka.common.message.DescribeUserScramCredentialsResponseData; +import org.apache.kafka.common.protocol.Errors; +import org.junit.jupiter.api.Test; + +import java.util.Arrays; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.fail; + +public class DescribeUserScramCredentialsResultTest { + @Test + public void testTopLevelError() { + KafkaFutureImpl dataFuture = new KafkaFutureImpl<>(); + dataFuture.completeExceptionally(new RuntimeException()); + DescribeUserScramCredentialsResult results = new DescribeUserScramCredentialsResult(dataFuture); + try { + results.all().get(); + fail("expected all() to fail when there is a top-level error"); + } catch (Exception expected) { + // ignore, expected + } + try { + results.users().get(); + fail("expected users() to fail when there is a top-level error"); + } catch (Exception expected) { + // ignore, expected + } + try { + results.description("whatever").get(); + fail("expected description() to fail when there is a top-level error"); + } catch (Exception expected) { + // ignore, expected + } + } + + @Test + public void testUserLevelErrors() throws Exception { + String goodUser = "goodUser"; + String unknownUser = "unknownUser"; + String failedUser = "failedUser"; + KafkaFutureImpl dataFuture = new KafkaFutureImpl<>(); + ScramMechanism scramSha256 = ScramMechanism.SCRAM_SHA_256; + int iterations = 4096; + dataFuture.complete(new DescribeUserScramCredentialsResponseData().setErrorCode(Errors.NONE.code()).setResults(Arrays.asList( + new DescribeUserScramCredentialsResponseData.DescribeUserScramCredentialsResult().setUser(goodUser).setCredentialInfos( + Arrays.asList(new DescribeUserScramCredentialsResponseData.CredentialInfo().setMechanism(scramSha256.type()).setIterations(iterations))), + new DescribeUserScramCredentialsResponseData.DescribeUserScramCredentialsResult().setUser(unknownUser).setErrorCode(Errors.RESOURCE_NOT_FOUND.code()), + new DescribeUserScramCredentialsResponseData.DescribeUserScramCredentialsResult().setUser(failedUser).setErrorCode(Errors.DUPLICATE_RESOURCE.code())))); + DescribeUserScramCredentialsResult results = new DescribeUserScramCredentialsResult(dataFuture); + try { + results.all().get(); + fail("expected all() to fail when there is a user-level error"); + } catch (Exception expected) { + // ignore, expected + } + assertEquals(Arrays.asList(goodUser, failedUser), results.users().get(), "Expected 2 users with credentials"); + UserScramCredentialsDescription goodUserDescription = results.description(goodUser).get(); + assertEquals(new UserScramCredentialsDescription(goodUser, Arrays.asList(new ScramCredentialInfo(scramSha256, iterations))), goodUserDescription); + try { + results.description(failedUser).get(); + fail("expected description(failedUser) to fail when there is a user-level error"); + } catch (Exception expected) { + // ignore, expected + } + try { + results.description(unknownUser).get(); + fail("expected description(unknownUser) to fail when there is no such user"); + } catch (Exception expected) { + // ignore, expected + } + } + + @Test + public void testSuccessfulDescription() throws Exception { + String goodUser = "goodUser"; + String unknownUser = "unknownUser"; + KafkaFutureImpl dataFuture = new KafkaFutureImpl<>(); + ScramMechanism scramSha256 = ScramMechanism.SCRAM_SHA_256; + int iterations = 4096; 
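+ // Completing the future with a single well-formed user entry means users(), all() and
+ // description(goodUser) should agree, while description(unknownUser) should still fail.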
+ dataFuture.complete(new DescribeUserScramCredentialsResponseData().setErrorCode(Errors.NONE.code()).setResults(Arrays.asList( + new DescribeUserScramCredentialsResponseData.DescribeUserScramCredentialsResult().setUser(goodUser).setCredentialInfos( + Arrays.asList(new DescribeUserScramCredentialsResponseData.CredentialInfo().setMechanism(scramSha256.type()).setIterations(iterations)))))); + DescribeUserScramCredentialsResult results = new DescribeUserScramCredentialsResult(dataFuture); + assertEquals(Arrays.asList(goodUser), results.users().get(), "Expected 1 user with credentials"); + Map allResults = results.all().get(); + assertEquals(1, allResults.size()); + UserScramCredentialsDescription goodUserDescriptionViaAll = allResults.get(goodUser); + assertEquals(new UserScramCredentialsDescription(goodUser, Arrays.asList(new ScramCredentialInfo(scramSha256, iterations))), goodUserDescriptionViaAll); + assertEquals(goodUserDescriptionViaAll, results.description(goodUser).get(), "Expected same thing via all() and description()"); + try { + results.description(unknownUser).get(); + fail("expected description(unknownUser) to fail when there is no such user even when all() succeeds"); + } catch (Exception expected) { + // ignore, expected + } + } +} diff --git a/clients/src/test/java/org/apache/kafka/clients/admin/KafkaAdminClientTest.java b/clients/src/test/java/org/apache/kafka/clients/admin/KafkaAdminClientTest.java new file mode 100644 index 0000000..b648b2d --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/clients/admin/KafkaAdminClientTest.java @@ -0,0 +1,6289 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.clients.admin; + +import org.apache.kafka.clients.ClientDnsLookup; +import org.apache.kafka.clients.ClientUtils; +import org.apache.kafka.clients.MockClient; +import org.apache.kafka.clients.NodeApiVersions; +import org.apache.kafka.clients.admin.DeleteAclsResult.FilterResults; +import org.apache.kafka.clients.admin.ListOffsetsResult.ListOffsetsResultInfo; +import org.apache.kafka.clients.consumer.ConsumerPartitionAssignor; +import org.apache.kafka.clients.consumer.OffsetAndMetadata; +import org.apache.kafka.clients.consumer.internals.ConsumerProtocol; +import org.apache.kafka.common.Cluster; +import org.apache.kafka.common.ConsumerGroupState; +import org.apache.kafka.common.ElectionType; +import org.apache.kafka.common.KafkaException; +import org.apache.kafka.common.KafkaFuture; +import org.apache.kafka.common.Node; +import org.apache.kafka.common.TopicCollection; +import org.apache.kafka.common.PartitionInfo; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.TopicPartitionReplica; +import org.apache.kafka.common.Uuid; +import org.apache.kafka.common.acl.AccessControlEntry; +import org.apache.kafka.common.acl.AccessControlEntryFilter; +import org.apache.kafka.common.acl.AclBinding; +import org.apache.kafka.common.acl.AclBindingFilter; +import org.apache.kafka.common.acl.AclOperation; +import org.apache.kafka.common.acl.AclPermissionType; +import org.apache.kafka.common.config.ConfigException; +import org.apache.kafka.common.config.ConfigResource; +import org.apache.kafka.common.errors.ApiException; +import org.apache.kafka.common.errors.AuthenticationException; +import org.apache.kafka.common.errors.ClusterAuthorizationException; +import org.apache.kafka.common.errors.FencedInstanceIdException; +import org.apache.kafka.common.errors.GroupAuthorizationException; +import org.apache.kafka.common.errors.GroupSubscribedToTopicException; +import org.apache.kafka.common.errors.InvalidRequestException; +import org.apache.kafka.common.errors.InvalidTopicException; +import org.apache.kafka.common.errors.LeaderNotAvailableException; +import org.apache.kafka.common.errors.LogDirNotFoundException; +import org.apache.kafka.common.errors.NotLeaderOrFollowerException; +import org.apache.kafka.common.errors.OffsetOutOfRangeException; +import org.apache.kafka.common.errors.SaslAuthenticationException; +import org.apache.kafka.common.errors.SecurityDisabledException; +import org.apache.kafka.common.errors.ThrottlingQuotaExceededException; +import org.apache.kafka.common.errors.TimeoutException; +import org.apache.kafka.common.errors.TopicAuthorizationException; +import org.apache.kafka.common.errors.TopicDeletionDisabledException; +import org.apache.kafka.common.errors.TopicExistsException; +import org.apache.kafka.common.errors.UnknownMemberIdException; +import org.apache.kafka.common.errors.UnknownServerException; +import org.apache.kafka.common.errors.UnknownTopicIdException; +import org.apache.kafka.common.errors.UnknownTopicOrPartitionException; +import org.apache.kafka.common.errors.UnsupportedVersionException; +import org.apache.kafka.common.feature.Features; +import org.apache.kafka.common.message.AlterPartitionReassignmentsResponseData; +import org.apache.kafka.common.message.AlterReplicaLogDirsResponseData; +import org.apache.kafka.common.message.AlterReplicaLogDirsResponseData.AlterReplicaLogDirPartitionResult; +import org.apache.kafka.common.message.AlterReplicaLogDirsResponseData.AlterReplicaLogDirTopicResult; +import 
org.apache.kafka.common.message.AlterUserScramCredentialsResponseData; +import org.apache.kafka.common.message.ApiMessageType; +import org.apache.kafka.common.message.ApiVersionsResponseData; +import org.apache.kafka.common.message.ApiVersionsResponseData.ApiVersion; +import org.apache.kafka.common.message.CreateAclsResponseData; +import org.apache.kafka.common.message.CreatePartitionsResponseData; +import org.apache.kafka.common.message.CreatePartitionsResponseData.CreatePartitionsTopicResult; +import org.apache.kafka.common.message.CreateTopicsResponseData; +import org.apache.kafka.common.message.CreateTopicsResponseData.CreatableTopicResult; +import org.apache.kafka.common.message.CreateTopicsResponseData.CreatableTopicResultCollection; +import org.apache.kafka.common.message.DeleteAclsResponseData; +import org.apache.kafka.common.message.DeleteGroupsResponseData; +import org.apache.kafka.common.message.DeleteGroupsResponseData.DeletableGroupResult; +import org.apache.kafka.common.message.DeleteGroupsResponseData.DeletableGroupResultCollection; +import org.apache.kafka.common.message.DeleteRecordsResponseData; +import org.apache.kafka.common.message.DeleteTopicsResponseData; +import org.apache.kafka.common.message.DeleteTopicsResponseData.DeletableTopicResult; +import org.apache.kafka.common.message.DeleteTopicsResponseData.DeletableTopicResultCollection; +import org.apache.kafka.common.message.DescribeAclsResponseData; +import org.apache.kafka.common.message.DescribeClusterResponseData; +import org.apache.kafka.common.message.DescribeClusterResponseData.DescribeClusterBroker; +import org.apache.kafka.common.message.DescribeConfigsResponseData; +import org.apache.kafka.common.message.DescribeGroupsResponseData; +import org.apache.kafka.common.message.DescribeGroupsResponseData.DescribedGroupMember; +import org.apache.kafka.common.message.DescribeLogDirsResponseData; +import org.apache.kafka.common.message.DescribeLogDirsResponseData.DescribeLogDirsTopic; +import org.apache.kafka.common.message.DescribeProducersResponseData; +import org.apache.kafka.common.message.DescribeTransactionsResponseData; +import org.apache.kafka.common.message.DescribeUserScramCredentialsResponseData; +import org.apache.kafka.common.message.DescribeUserScramCredentialsResponseData.CredentialInfo; +import org.apache.kafka.common.message.ElectLeadersResponseData.PartitionResult; +import org.apache.kafka.common.message.ElectLeadersResponseData.ReplicaElectionResult; +import org.apache.kafka.common.message.FindCoordinatorResponseData; +import org.apache.kafka.common.message.IncrementalAlterConfigsResponseData; +import org.apache.kafka.common.message.IncrementalAlterConfigsResponseData.AlterConfigsResourceResponse; +import org.apache.kafka.common.message.LeaveGroupRequestData.MemberIdentity; +import org.apache.kafka.common.message.LeaveGroupResponseData; +import org.apache.kafka.common.message.LeaveGroupResponseData.MemberResponse; +import org.apache.kafka.common.message.ListGroupsResponseData; +import org.apache.kafka.common.message.ListOffsetsResponseData; +import org.apache.kafka.common.message.ListOffsetsResponseData.ListOffsetsTopicResponse; +import org.apache.kafka.common.message.ListPartitionReassignmentsResponseData; +import org.apache.kafka.common.message.ListTransactionsResponseData; +import org.apache.kafka.common.message.MetadataResponseData; +import org.apache.kafka.common.message.MetadataResponseData.MetadataResponsePartition; +import 
org.apache.kafka.common.message.MetadataResponseData.MetadataResponseTopic; +import org.apache.kafka.common.message.OffsetDeleteResponseData; +import org.apache.kafka.common.message.OffsetDeleteResponseData.OffsetDeleteResponsePartition; +import org.apache.kafka.common.message.OffsetDeleteResponseData.OffsetDeleteResponsePartitionCollection; +import org.apache.kafka.common.message.OffsetDeleteResponseData.OffsetDeleteResponseTopic; +import org.apache.kafka.common.message.OffsetDeleteResponseData.OffsetDeleteResponseTopicCollection; +import org.apache.kafka.common.message.UnregisterBrokerResponseData; +import org.apache.kafka.common.message.WriteTxnMarkersResponseData; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.Errors; +import org.apache.kafka.common.quota.ClientQuotaAlteration; +import org.apache.kafka.common.quota.ClientQuotaEntity; +import org.apache.kafka.common.quota.ClientQuotaFilter; +import org.apache.kafka.common.quota.ClientQuotaFilterComponent; +import org.apache.kafka.common.record.RecordVersion; +import org.apache.kafka.common.requests.AlterClientQuotasResponse; +import org.apache.kafka.common.requests.AlterPartitionReassignmentsResponse; +import org.apache.kafka.common.requests.AlterReplicaLogDirsResponse; +import org.apache.kafka.common.requests.AlterUserScramCredentialsResponse; +import org.apache.kafka.common.requests.ApiError; +import org.apache.kafka.common.requests.ApiVersionsRequest; +import org.apache.kafka.common.requests.ApiVersionsResponse; +import org.apache.kafka.common.requests.CreateAclsResponse; +import org.apache.kafka.common.requests.CreatePartitionsRequest; +import org.apache.kafka.common.requests.CreatePartitionsResponse; +import org.apache.kafka.common.requests.CreateTopicsRequest; +import org.apache.kafka.common.requests.CreateTopicsResponse; +import org.apache.kafka.common.requests.DeleteAclsResponse; +import org.apache.kafka.common.requests.DeleteGroupsResponse; +import org.apache.kafka.common.requests.DeleteRecordsResponse; +import org.apache.kafka.common.requests.DeleteTopicsRequest; +import org.apache.kafka.common.requests.DeleteTopicsResponse; +import org.apache.kafka.common.requests.DescribeAclsResponse; +import org.apache.kafka.common.requests.DescribeClientQuotasResponse; +import org.apache.kafka.common.requests.DescribeClusterRequest; +import org.apache.kafka.common.requests.DescribeClusterResponse; +import org.apache.kafka.common.requests.DescribeConfigsResponse; +import org.apache.kafka.common.requests.DescribeGroupsResponse; +import org.apache.kafka.common.requests.DescribeLogDirsResponse; +import org.apache.kafka.common.requests.DescribeProducersRequest; +import org.apache.kafka.common.requests.DescribeProducersResponse; +import org.apache.kafka.common.requests.DescribeTransactionsRequest; +import org.apache.kafka.common.requests.DescribeTransactionsResponse; +import org.apache.kafka.common.requests.DescribeUserScramCredentialsResponse; +import org.apache.kafka.common.requests.ElectLeadersResponse; +import org.apache.kafka.common.requests.FindCoordinatorRequest; +import org.apache.kafka.common.requests.FindCoordinatorResponse; +import org.apache.kafka.common.requests.IncrementalAlterConfigsResponse; +import org.apache.kafka.common.requests.JoinGroupRequest; +import org.apache.kafka.common.requests.LeaveGroupResponse; +import org.apache.kafka.common.requests.ListGroupsRequest; +import org.apache.kafka.common.requests.ListGroupsResponse; +import 
org.apache.kafka.common.requests.ListOffsetsRequest; +import org.apache.kafka.common.requests.ListOffsetsResponse; +import org.apache.kafka.common.requests.ListPartitionReassignmentsResponse; +import org.apache.kafka.common.requests.ListTransactionsRequest; +import org.apache.kafka.common.requests.ListTransactionsResponse; +import org.apache.kafka.common.requests.MetadataRequest; +import org.apache.kafka.common.requests.MetadataResponse; +import org.apache.kafka.common.requests.OffsetCommitResponse; +import org.apache.kafka.common.requests.OffsetDeleteResponse; +import org.apache.kafka.common.requests.OffsetFetchResponse; +import org.apache.kafka.common.requests.RequestTestUtils; +import org.apache.kafka.common.requests.UnregisterBrokerResponse; +import org.apache.kafka.common.requests.UpdateFeaturesRequest; +import org.apache.kafka.common.requests.UpdateFeaturesResponse; +import org.apache.kafka.common.requests.WriteTxnMarkersRequest; +import org.apache.kafka.common.requests.WriteTxnMarkersResponse; +import org.apache.kafka.common.resource.PatternType; +import org.apache.kafka.common.resource.ResourcePattern; +import org.apache.kafka.common.resource.ResourcePatternFilter; +import org.apache.kafka.common.resource.ResourceType; +import org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.test.TestUtils; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.net.InetSocketAddress; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.OptionalInt; +import java.util.OptionalLong; +import java.util.Set; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Future; +import java.util.concurrent.Semaphore; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicLong; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import static java.util.Arrays.asList; +import static java.util.Collections.emptyList; +import static java.util.Collections.emptySet; +import static java.util.Collections.singleton; +import static java.util.Collections.singletonList; +import static org.apache.kafka.common.message.AlterPartitionReassignmentsResponseData.ReassignablePartitionResponse; +import static org.apache.kafka.common.message.AlterPartitionReassignmentsResponseData.ReassignableTopicResponse; +import static org.apache.kafka.common.message.ListPartitionReassignmentsResponseData.OngoingPartitionReassignment; +import static org.apache.kafka.common.message.ListPartitionReassignmentsResponseData.OngoingTopicReassignment; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static 
org.junit.jupiter.api.Assertions.assertTrue;
+import static org.junit.jupiter.api.Assertions.fail;
+
+/**
+ * A unit test for KafkaAdminClient.
+ *
+ * See AdminClientIntegrationTest for an integration test.
+ */
+@Timeout(120)
+public class KafkaAdminClientTest {
+    private static final Logger log = LoggerFactory.getLogger(KafkaAdminClientTest.class);
+    private static final String GROUP_ID = "group-0";
+
+    @Test
+    public void testDefaultApiTimeoutAndRequestTimeoutConflicts() {
+        final AdminClientConfig config = newConfMap(AdminClientConfig.DEFAULT_API_TIMEOUT_MS_CONFIG, "500");
+        KafkaException exception = assertThrows(KafkaException.class,
+            () -> KafkaAdminClient.createInternal(config, null));
+        assertTrue(exception.getCause() instanceof ConfigException);
+    }
+
+    @Test
+    public void testGetOrCreateListValue() {
+        Map<String, List<String>> map = new HashMap<>();
+        List<String> fooList = KafkaAdminClient.getOrCreateListValue(map, "foo");
+        assertNotNull(fooList);
+        fooList.add("a");
+        fooList.add("b");
+        List<String> fooList2 = KafkaAdminClient.getOrCreateListValue(map, "foo");
+        assertEquals(fooList, fooList2);
+        assertTrue(fooList2.contains("a"));
+        assertTrue(fooList2.contains("b"));
+        List<String> barList = KafkaAdminClient.getOrCreateListValue(map, "bar");
+        assertNotNull(barList);
+        assertTrue(barList.isEmpty());
+    }
+
+    @Test
+    public void testCalcTimeoutMsRemainingAsInt() {
+        assertEquals(0, KafkaAdminClient.calcTimeoutMsRemainingAsInt(1000, 1000));
+        assertEquals(100, KafkaAdminClient.calcTimeoutMsRemainingAsInt(1000, 1100));
+        assertEquals(Integer.MAX_VALUE, KafkaAdminClient.calcTimeoutMsRemainingAsInt(0, Long.MAX_VALUE));
+        assertEquals(Integer.MIN_VALUE, KafkaAdminClient.calcTimeoutMsRemainingAsInt(Long.MAX_VALUE, 0));
+    }
+
+    @Test
+    public void testPrettyPrintException() {
+        assertEquals("Null exception.", KafkaAdminClient.prettyPrintException(null));
+        assertEquals("TimeoutException", KafkaAdminClient.prettyPrintException(new TimeoutException()));
+        assertEquals("TimeoutException: The foobar timed out.",
+            KafkaAdminClient.prettyPrintException(new TimeoutException("The foobar timed out.")));
+    }
+
+    private static Map<String, Object> newStrMap(String... vals) {
+        Map<String, Object> map = new HashMap<>();
+        map.put(AdminClientConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:8121");
+        map.put(AdminClientConfig.REQUEST_TIMEOUT_MS_CONFIG, "1000");
+        if (vals.length % 2 != 0) {
+            throw new IllegalStateException();
+        }
+        for (int i = 0; i < vals.length; i += 2) {
+            map.put(vals[i], vals[i + 1]);
+        }
+        return map;
+    }
+
+    private static AdminClientConfig newConfMap(String... vals) {
+        return new AdminClientConfig(newStrMap(vals));
+    }
+
+    @Test
+    public void testGenerateClientId() {
+        Set<String> ids = new HashSet<>();
+        for (int i = 0; i < 10; i++) {
+            String id = KafkaAdminClient.generateClientId(newConfMap(AdminClientConfig.CLIENT_ID_CONFIG, ""));
+            assertFalse(ids.contains(id), "Got duplicate id " + id);
+            ids.add(id);
+        }
+        assertEquals("myCustomId",
+            KafkaAdminClient.generateClientId(newConfMap(AdminClientConfig.CLIENT_ID_CONFIG, "myCustomId")));
+    }
+
+    private static Cluster mockCluster(int numNodes, int controllerIndex) {
+        HashMap<Integer, Node> nodes = new HashMap<>();
+        for (int i = 0; i < numNodes; i++)
+            nodes.put(i, new Node(i, "localhost", 8121 + i));
+        return new Cluster("mockClusterId", nodes.values(),
+            Collections.emptySet(), Collections.emptySet(),
+            Collections.emptySet(), nodes.get(controllerIndex));
+    }
+
+    private static Cluster mockBootstrapCluster() {
+        return Cluster.bootstrap(ClientUtils.parseAndValidateAddresses(
+            singletonList("localhost:8121"), ClientDnsLookup.USE_ALL_DNS_IPS));
+    }
+
+    private static AdminClientUnitTestEnv mockClientEnv(String... configVals) {
+        return new AdminClientUnitTestEnv(mockCluster(3, 0), configVals);
+    }
+
+    private static AdminClientUnitTestEnv mockClientEnv(Time time, String... configVals) {
+        return new AdminClientUnitTestEnv(time, mockCluster(3, 0), configVals);
+    }
+
+    @Test
+    public void testCloseAdminClient() {
+        try (AdminClientUnitTestEnv env = mockClientEnv()) {
+        }
+    }
+
+    /**
+     * Test that the admin client can be closed in the callback invoked when an API call completes.
+     * If calling {@link Admin#close()} from the callback hangs the AdminClient thread, this test fails with a timeout.
+     */
+    @Test @Timeout(10)
+    public void testCloseAdminClientInCallback() throws InterruptedException {
+        MockTime time = new MockTime();
+        AdminClientUnitTestEnv env = new AdminClientUnitTestEnv(time, mockCluster(3, 0));
+
+        final ListTopicsResult result = env.adminClient().listTopics(new ListTopicsOptions().timeoutMs(1000));
+        final KafkaFuture<Collection<TopicListing>> kafkaFuture = result.listings();
+        final Semaphore callbackCalled = new Semaphore(0);
+        kafkaFuture.whenComplete((topicListings, throwable) -> {
+            env.close();
+            callbackCalled.release();
+        });
+
+        time.sleep(2000); // Advance time to timeout and complete listTopics request
+        callbackCalled.acquire();
+    }
+
+    private static OffsetDeleteResponse prepareOffsetDeleteResponse(Errors error) {
+        return new OffsetDeleteResponse(
+            new OffsetDeleteResponseData()
+                .setErrorCode(error.code())
+                .setTopics(new OffsetDeleteResponseTopicCollection())
+        );
+    }
+
+    private static OffsetDeleteResponse prepareOffsetDeleteResponse(String topic, int partition, Errors error) {
+        return new OffsetDeleteResponse(
+            new OffsetDeleteResponseData()
+                .setErrorCode(Errors.NONE.code())
+                .setTopics(new OffsetDeleteResponseTopicCollection(Stream.of(
+                    new OffsetDeleteResponseTopic()
+                        .setName(topic)
+                        .setPartitions(new OffsetDeleteResponsePartitionCollection(Collections.singletonList(
+                            new OffsetDeleteResponsePartition()
+                                .setPartitionIndex(partition)
+                                .setErrorCode(error.code())
+                        ).iterator()))
+                ).collect(Collectors.toList()).iterator()))
+        );
+    }
+
+    private static OffsetCommitResponse prepareOffsetCommitResponse(TopicPartition tp, Errors error) {
+        Map<TopicPartition, Errors> responseData = new HashMap<>();
+        responseData.put(tp, error);
+        return new OffsetCommitResponse(0, responseData);
+    }
+
+    private static CreateTopicsResponse prepareCreateTopicsResponse(String topicName, Errors error) {
+        CreateTopicsResponseData data = new CreateTopicsResponseData();
data.topics().add(new CreatableTopicResult() + .setName(topicName) + .setErrorCode(error.code())); + return new CreateTopicsResponse(data); + } + + public static CreateTopicsResponse prepareCreateTopicsResponse(int throttleTimeMs, CreatableTopicResult... topics) { + CreateTopicsResponseData data = new CreateTopicsResponseData() + .setThrottleTimeMs(throttleTimeMs) + .setTopics(new CreatableTopicResultCollection(Arrays.stream(topics).iterator())); + return new CreateTopicsResponse(data); + } + + public static CreatableTopicResult creatableTopicResult(String name, Errors error) { + return new CreatableTopicResult() + .setName(name) + .setErrorCode(error.code()); + } + + public static DeleteTopicsResponse prepareDeleteTopicsResponse(int throttleTimeMs, DeletableTopicResult... topics) { + DeleteTopicsResponseData data = new DeleteTopicsResponseData() + .setThrottleTimeMs(throttleTimeMs) + .setResponses(new DeletableTopicResultCollection(Arrays.stream(topics).iterator())); + return new DeleteTopicsResponse(data); + } + + public static DeletableTopicResult deletableTopicResult(String topicName, Errors error) { + return new DeletableTopicResult() + .setName(topicName) + .setErrorCode(error.code()); + } + + public static DeletableTopicResult deletableTopicResultWithId(Uuid topicId, Errors error) { + return new DeletableTopicResult() + .setTopicId(topicId) + .setErrorCode(error.code()); + } + + public static CreatePartitionsResponse prepareCreatePartitionsResponse(int throttleTimeMs, CreatePartitionsTopicResult... topics) { + CreatePartitionsResponseData data = new CreatePartitionsResponseData() + .setThrottleTimeMs(throttleTimeMs) + .setResults(Arrays.asList(topics)); + return new CreatePartitionsResponse(data); + } + + public static CreatePartitionsTopicResult createPartitionsTopicResult(String name, Errors error) { + return createPartitionsTopicResult(name, error, null); + } + + public static CreatePartitionsTopicResult createPartitionsTopicResult(String name, Errors error, String errorMessage) { + return new CreatePartitionsTopicResult() + .setName(name) + .setErrorCode(error.code()) + .setErrorMessage(errorMessage); + } + + private static DeleteTopicsResponse prepareDeleteTopicsResponse(String topicName, Errors error) { + DeleteTopicsResponseData data = new DeleteTopicsResponseData(); + data.responses().add(new DeletableTopicResult() + .setName(topicName) + .setErrorCode(error.code())); + return new DeleteTopicsResponse(data); + } + + private static DeleteTopicsResponse prepareDeleteTopicsResponseWithTopicId(Uuid id, Errors error) { + DeleteTopicsResponseData data = new DeleteTopicsResponseData(); + data.responses().add(new DeletableTopicResult() + .setTopicId(id) + .setErrorCode(error.code())); + return new DeleteTopicsResponse(data); + } + + private static FindCoordinatorResponse prepareFindCoordinatorResponse(Errors error, Node node) { + return prepareFindCoordinatorResponse(error, GROUP_ID, node); + } + + private static FindCoordinatorResponse prepareFindCoordinatorResponse(Errors error, String key, Node node) { + return FindCoordinatorResponse.prepareResponse(error, key, node); + } + + private static FindCoordinatorResponse prepareOldFindCoordinatorResponse(Errors error, Node node) { + return FindCoordinatorResponse.prepareOldResponse(error, node); + } + + private static MetadataResponse prepareMetadataResponse(Cluster cluster, Errors error) { + return prepareMetadataResponse(cluster, error, error); + } + + private static MetadataResponse prepareMetadataResponse(Cluster cluster, Errors 
topicError, Errors partitionError) {
+        List<MetadataResponseTopic> metadata = new ArrayList<>();
+        for (String topic : cluster.topics()) {
+            List<MetadataResponsePartition> pms = new ArrayList<>();
+            for (PartitionInfo pInfo : cluster.availablePartitionsForTopic(topic)) {
+                MetadataResponsePartition pm = new MetadataResponsePartition()
+                    .setErrorCode(partitionError.code())
+                    .setPartitionIndex(pInfo.partition())
+                    .setLeaderId(pInfo.leader().id())
+                    .setLeaderEpoch(234)
+                    .setReplicaNodes(Arrays.stream(pInfo.replicas()).map(Node::id).collect(Collectors.toList()))
+                    .setIsrNodes(Arrays.stream(pInfo.inSyncReplicas()).map(Node::id).collect(Collectors.toList()))
+                    .setOfflineReplicas(Arrays.stream(pInfo.offlineReplicas()).map(Node::id).collect(Collectors.toList()));
+                pms.add(pm);
+            }
+            MetadataResponseTopic tm = new MetadataResponseTopic()
+                .setErrorCode(topicError.code())
+                .setName(topic)
+                .setIsInternal(false)
+                .setPartitions(pms);
+            metadata.add(tm);
+        }
+        return MetadataResponse.prepareResponse(true,
+            0,
+            cluster.nodes(),
+            cluster.clusterResource().clusterId(),
+            cluster.controller().id(),
+            metadata,
+            MetadataResponse.AUTHORIZED_OPERATIONS_OMITTED);
+    }
+
+    private static DescribeGroupsResponseData prepareDescribeGroupsResponseData(String groupId,
+                                                                                List<String> groupInstances,
+                                                                                List<TopicPartition> topicPartitions) {
+        final ByteBuffer memberAssignment = ConsumerProtocol.serializeAssignment(new ConsumerPartitionAssignor.Assignment(topicPartitions));
+        List<DescribedGroupMember> describedGroupMembers = groupInstances.stream().map(groupInstance -> DescribeGroupsResponse.groupMember(JoinGroupRequest.UNKNOWN_MEMBER_ID,
+            groupInstance, "clientId0", "clientHost", new byte[memberAssignment.remaining()], null)).collect(Collectors.toList());
+        DescribeGroupsResponseData data = new DescribeGroupsResponseData();
+        data.groups().add(DescribeGroupsResponse.groupMetadata(
+            groupId,
+            Errors.NONE,
+            "",
+            ConsumerProtocol.PROTOCOL_TYPE,
+            "",
+            describedGroupMembers,
+            Collections.emptySet()));
+        return data;
+    }
+
+    private static FeatureMetadata defaultFeatureMetadata() {
+        return new FeatureMetadata(
+            Utils.mkMap(Utils.mkEntry("test_feature_1", new FinalizedVersionRange((short) 2, (short) 3))),
+            Optional.of(1L),
+            Utils.mkMap(Utils.mkEntry("test_feature_1", new SupportedVersionRange((short) 1, (short) 5))));
+    }
+
+    private static Features<org.apache.kafka.common.feature.SupportedVersionRange> convertSupportedFeaturesMap(Map<String, SupportedVersionRange> features) {
+        final Map<String, org.apache.kafka.common.feature.SupportedVersionRange> featuresMap = new HashMap<>();
+        for (final Map.Entry<String, SupportedVersionRange> entry : features.entrySet()) {
+            final SupportedVersionRange versionRange = entry.getValue();
+            featuresMap.put(
+                entry.getKey(),
+                new org.apache.kafka.common.feature.SupportedVersionRange(versionRange.minVersion(),
+                    versionRange.maxVersion()));
+        }
+
+        return Features.supportedFeatures(featuresMap);
+    }
+
+    private static Features<org.apache.kafka.common.feature.FinalizedVersionRange> convertFinalizedFeaturesMap(Map<String, FinalizedVersionRange> features) {
+        final Map<String, org.apache.kafka.common.feature.FinalizedVersionRange> featuresMap = new HashMap<>();
+        for (final Map.Entry<String, FinalizedVersionRange> entry : features.entrySet()) {
+            final FinalizedVersionRange versionRange = entry.getValue();
+            featuresMap.put(
+                entry.getKey(),
+                new org.apache.kafka.common.feature.FinalizedVersionRange(
+                    versionRange.minVersionLevel(), versionRange.maxVersionLevel()));
+        }
+
+        return Features.finalizedFeatures(featuresMap);
+    }
+
+    private static ApiVersionsResponse prepareApiVersionsResponseForDescribeFeatures(Errors error) {
+        if (error == Errors.NONE) {
+            return ApiVersionsResponse.createApiVersionsResponse(
+                0,
+                ApiVersionsResponse.filterApis(RecordVersion.current(), ApiMessageType.ListenerType.ZK_BROKER),
+                convertSupportedFeaturesMap(defaultFeatureMetadata().supportedFeatures()),
convertFinalizedFeaturesMap(defaultFeatureMetadata().finalizedFeatures()), + defaultFeatureMetadata().finalizedFeaturesEpoch().get() + ); + } + return new ApiVersionsResponse( + new ApiVersionsResponseData() + .setThrottleTimeMs(0) + .setErrorCode(error.code())); + } + + /** + * Test that the client properly times out when we don't receive any metadata. + */ + @Test + public void testTimeoutWithoutMetadata() throws Exception { + try (final AdminClientUnitTestEnv env = new AdminClientUnitTestEnv(Time.SYSTEM, mockBootstrapCluster(), + newStrMap(AdminClientConfig.REQUEST_TIMEOUT_MS_CONFIG, "10"))) { + env.kafkaClient().setNodeApiVersions(NodeApiVersions.create()); + env.kafkaClient().prepareResponse(prepareCreateTopicsResponse("myTopic", Errors.NONE)); + KafkaFuture future = env.adminClient().createTopics( + singleton(new NewTopic("myTopic", Collections.singletonMap(0, asList(0, 1, 2)))), + new CreateTopicsOptions().timeoutMs(1000)).all(); + TestUtils.assertFutureError(future, TimeoutException.class); + } + } + + @Test + public void testConnectionFailureOnMetadataUpdate() throws Exception { + // This tests the scenario in which we successfully connect to the bootstrap server, but + // the server disconnects before sending the full response + + Cluster cluster = mockBootstrapCluster(); + try (final AdminClientUnitTestEnv env = new AdminClientUnitTestEnv(Time.SYSTEM, cluster)) { + Cluster discoveredCluster = mockCluster(3, 0); + env.kafkaClient().setNodeApiVersions(NodeApiVersions.create()); + env.kafkaClient().prepareResponse(request -> request instanceof MetadataRequest, null, true); + env.kafkaClient().prepareResponse(request -> request instanceof MetadataRequest, + RequestTestUtils.metadataResponse(discoveredCluster.nodes(), discoveredCluster.clusterResource().clusterId(), + 1, Collections.emptyList())); + env.kafkaClient().prepareResponse(body -> body instanceof CreateTopicsRequest, + prepareCreateTopicsResponse("myTopic", Errors.NONE)); + + KafkaFuture future = env.adminClient().createTopics( + singleton(new NewTopic("myTopic", Collections.singletonMap(0, asList(0, 1, 2)))), + new CreateTopicsOptions().timeoutMs(10000)).all(); + + future.get(); + } + } + + @Test + public void testUnreachableBootstrapServer() throws Exception { + // This tests the scenario in which the bootstrap server is unreachable for a short while, + // which prevents AdminClient from being able to send the initial metadata request + + Cluster cluster = Cluster.bootstrap(singletonList(new InetSocketAddress("localhost", 8121))); + Map unreachableNodes = Collections.singletonMap(cluster.nodes().get(0), 200L); + try (final AdminClientUnitTestEnv env = new AdminClientUnitTestEnv(Time.SYSTEM, cluster, + AdminClientUnitTestEnv.clientConfigs(), unreachableNodes)) { + Cluster discoveredCluster = mockCluster(3, 0); + env.kafkaClient().setNodeApiVersions(NodeApiVersions.create()); + env.kafkaClient().prepareResponse(body -> body instanceof MetadataRequest, + RequestTestUtils.metadataResponse(discoveredCluster.nodes(), discoveredCluster.clusterResource().clusterId(), + 1, Collections.emptyList())); + env.kafkaClient().prepareResponse(body -> body instanceof CreateTopicsRequest, + prepareCreateTopicsResponse("myTopic", Errors.NONE)); + + KafkaFuture future = env.adminClient().createTopics( + singleton(new NewTopic("myTopic", Collections.singletonMap(0, asList(0, 1, 2)))), + new CreateTopicsOptions().timeoutMs(10000)).all(); + + future.get(); + } + } + + /** + * Test that we propagate exceptions encountered when fetching metadata. 
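+     * In this test a pending authentication error is queued for the target node, so the returned future should fail with a SaslAuthenticationException.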
+ */ + @Test + public void testPropagatedMetadataFetchException() throws Exception { + try (final AdminClientUnitTestEnv env = new AdminClientUnitTestEnv(Time.SYSTEM, + mockCluster(3, 0), + newStrMap(AdminClientConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:8121", + AdminClientConfig.REQUEST_TIMEOUT_MS_CONFIG, "10"))) { + env.kafkaClient().setNodeApiVersions(NodeApiVersions.create()); + env.kafkaClient().createPendingAuthenticationError(env.cluster().nodeById(0), + TimeUnit.DAYS.toMillis(1)); + env.kafkaClient().prepareResponse(prepareCreateTopicsResponse("myTopic", Errors.NONE)); + KafkaFuture future = env.adminClient().createTopics( + singleton(new NewTopic("myTopic", Collections.singletonMap(0, asList(0, 1, 2)))), + new CreateTopicsOptions().timeoutMs(1000)).all(); + TestUtils.assertFutureError(future, SaslAuthenticationException.class); + } + } + + @Test + public void testCreateTopics() throws Exception { + try (AdminClientUnitTestEnv env = mockClientEnv()) { + env.kafkaClient().setNodeApiVersions(NodeApiVersions.create()); + env.kafkaClient().prepareResponse( + expectCreateTopicsRequestWithTopics("myTopic"), + prepareCreateTopicsResponse("myTopic", Errors.NONE)); + KafkaFuture future = env.adminClient().createTopics( + singleton(new NewTopic("myTopic", Collections.singletonMap(0, asList(0, 1, 2)))), + new CreateTopicsOptions().timeoutMs(10000)).all(); + future.get(); + } + } + + @Test + public void testCreateTopicsPartialResponse() throws Exception { + try (AdminClientUnitTestEnv env = mockClientEnv()) { + env.kafkaClient().setNodeApiVersions(NodeApiVersions.create()); + env.kafkaClient().prepareResponse( + expectCreateTopicsRequestWithTopics("myTopic", "myTopic2"), + prepareCreateTopicsResponse("myTopic", Errors.NONE)); + CreateTopicsResult topicsResult = env.adminClient().createTopics( + asList(new NewTopic("myTopic", Collections.singletonMap(0, asList(0, 1, 2))), + new NewTopic("myTopic2", Collections.singletonMap(0, asList(0, 1, 2)))), + new CreateTopicsOptions().timeoutMs(10000)); + topicsResult.values().get("myTopic").get(); + TestUtils.assertFutureThrows(topicsResult.values().get("myTopic2"), ApiException.class); + } + } + + @Test + public void testCreateTopicsRetryBackoff() throws Exception { + MockTime time = new MockTime(); + int retryBackoff = 100; + + try (final AdminClientUnitTestEnv env = new AdminClientUnitTestEnv(time, + mockCluster(3, 0), + newStrMap(AdminClientConfig.RETRY_BACKOFF_MS_CONFIG, "" + retryBackoff))) { + MockClient mockClient = env.kafkaClient(); + + mockClient.setNodeApiVersions(NodeApiVersions.create()); + + AtomicLong firstAttemptTime = new AtomicLong(0); + AtomicLong secondAttemptTime = new AtomicLong(0); + + mockClient.prepareResponse(body -> { + firstAttemptTime.set(time.milliseconds()); + return body instanceof CreateTopicsRequest; + }, null, true); + + mockClient.prepareResponse(body -> { + secondAttemptTime.set(time.milliseconds()); + return body instanceof CreateTopicsRequest; + }, prepareCreateTopicsResponse("myTopic", Errors.NONE)); + + KafkaFuture future = env.adminClient().createTopics( + singleton(new NewTopic("myTopic", Collections.singletonMap(0, asList(0, 1, 2)))), + new CreateTopicsOptions().timeoutMs(10000)).all(); + + // Wait until the first attempt has failed, then advance the time + TestUtils.waitForCondition(() -> mockClient.numAwaitingResponses() == 1, + "Failed awaiting CreateTopics first request failure"); + + // Wait until the retry call added to the queue in AdminClient + TestUtils.waitForCondition(() -> ((KafkaAdminClient) 
env.adminClient()).numPendingCalls() == 1, + "Failed to add retry CreateTopics call"); + + time.sleep(retryBackoff); + + future.get(); + + long actualRetryBackoff = secondAttemptTime.get() - firstAttemptTime.get(); + assertEquals(retryBackoff, actualRetryBackoff, "CreateTopics retry did not await expected backoff"); + } + } + + @Test + public void testCreateTopicsHandleNotControllerException() throws Exception { + try (AdminClientUnitTestEnv env = mockClientEnv()) { + env.kafkaClient().setNodeApiVersions(NodeApiVersions.create()); + env.kafkaClient().prepareResponseFrom( + prepareCreateTopicsResponse("myTopic", Errors.NOT_CONTROLLER), + env.cluster().nodeById(0)); + env.kafkaClient().prepareResponse(RequestTestUtils.metadataResponse(env.cluster().nodes(), + env.cluster().clusterResource().clusterId(), + 1, + Collections.emptyList())); + env.kafkaClient().prepareResponseFrom( + prepareCreateTopicsResponse("myTopic", Errors.NONE), + env.cluster().nodeById(1)); + KafkaFuture future = env.adminClient().createTopics( + singleton(new NewTopic("myTopic", Collections.singletonMap(0, asList(0, 1, 2)))), + new CreateTopicsOptions().timeoutMs(10000)).all(); + future.get(); + } + } + + @Test + public void testCreateTopicsRetryThrottlingExceptionWhenEnabled() throws Exception { + try (AdminClientUnitTestEnv env = mockClientEnv()) { + env.kafkaClient().setNodeApiVersions(NodeApiVersions.create()); + + env.kafkaClient().prepareResponse( + expectCreateTopicsRequestWithTopics("topic1", "topic2", "topic3"), + prepareCreateTopicsResponse(1000, + creatableTopicResult("topic1", Errors.NONE), + creatableTopicResult("topic2", Errors.THROTTLING_QUOTA_EXCEEDED), + creatableTopicResult("topic3", Errors.TOPIC_ALREADY_EXISTS))); + + env.kafkaClient().prepareResponse( + expectCreateTopicsRequestWithTopics("topic2"), + prepareCreateTopicsResponse(1000, + creatableTopicResult("topic2", Errors.THROTTLING_QUOTA_EXCEEDED))); + + env.kafkaClient().prepareResponse( + expectCreateTopicsRequestWithTopics("topic2"), + prepareCreateTopicsResponse(0, + creatableTopicResult("topic2", Errors.NONE))); + + CreateTopicsResult result = env.adminClient().createTopics( + asList( + new NewTopic("topic1", 1, (short) 1), + new NewTopic("topic2", 1, (short) 1), + new NewTopic("topic3", 1, (short) 1)), + new CreateTopicsOptions().retryOnQuotaViolation(true)); + + assertNull(result.values().get("topic1").get()); + assertNull(result.values().get("topic2").get()); + TestUtils.assertFutureThrows(result.values().get("topic3"), TopicExistsException.class); + } + } + + @Test + public void testCreateTopicsRetryThrottlingExceptionWhenEnabledUntilRequestTimeOut() throws Exception { + long defaultApiTimeout = 60000; + MockTime time = new MockTime(); + + try (AdminClientUnitTestEnv env = mockClientEnv(time, + AdminClientConfig.DEFAULT_API_TIMEOUT_MS_CONFIG, String.valueOf(defaultApiTimeout))) { + + env.kafkaClient().setNodeApiVersions(NodeApiVersions.create()); + + env.kafkaClient().prepareResponse( + expectCreateTopicsRequestWithTopics("topic1", "topic2", "topic3"), + prepareCreateTopicsResponse(1000, + creatableTopicResult("topic1", Errors.NONE), + creatableTopicResult("topic2", Errors.THROTTLING_QUOTA_EXCEEDED), + creatableTopicResult("topic3", Errors.TOPIC_ALREADY_EXISTS))); + + env.kafkaClient().prepareResponse( + expectCreateTopicsRequestWithTopics("topic2"), + prepareCreateTopicsResponse(1000, + creatableTopicResult("topic2", Errors.THROTTLING_QUOTA_EXCEEDED))); + + CreateTopicsResult result = env.adminClient().createTopics( + asList( + new 
NewTopic("topic1", 1, (short) 1), + new NewTopic("topic2", 1, (short) 1), + new NewTopic("topic3", 1, (short) 1)), + new CreateTopicsOptions().retryOnQuotaViolation(true)); + + // Wait until the prepared attempts have consumed + TestUtils.waitForCondition(() -> env.kafkaClient().numAwaitingResponses() == 0, + "Failed awaiting CreateTopics requests"); + + // Wait until the next request is sent out + TestUtils.waitForCondition(() -> env.kafkaClient().inFlightRequestCount() == 1, + "Failed awaiting next CreateTopics request"); + + // Advance time past the default api timeout to time out the inflight request + time.sleep(defaultApiTimeout + 1); + + assertNull(result.values().get("topic1").get()); + ThrottlingQuotaExceededException e = TestUtils.assertFutureThrows(result.values().get("topic2"), + ThrottlingQuotaExceededException.class); + assertEquals(0, e.throttleTimeMs()); + TestUtils.assertFutureThrows(result.values().get("topic3"), TopicExistsException.class); + } + } + + @Test + public void testCreateTopicsDontRetryThrottlingExceptionWhenDisabled() throws Exception { + try (AdminClientUnitTestEnv env = mockClientEnv()) { + env.kafkaClient().setNodeApiVersions(NodeApiVersions.create()); + + env.kafkaClient().prepareResponse( + expectCreateTopicsRequestWithTopics("topic1", "topic2", "topic3"), + prepareCreateTopicsResponse(1000, + creatableTopicResult("topic1", Errors.NONE), + creatableTopicResult("topic2", Errors.THROTTLING_QUOTA_EXCEEDED), + creatableTopicResult("topic3", Errors.TOPIC_ALREADY_EXISTS))); + + CreateTopicsResult result = env.adminClient().createTopics( + asList( + new NewTopic("topic1", 1, (short) 1), + new NewTopic("topic2", 1, (short) 1), + new NewTopic("topic3", 1, (short) 1)), + new CreateTopicsOptions().retryOnQuotaViolation(false)); + + assertNull(result.values().get("topic1").get()); + ThrottlingQuotaExceededException e = TestUtils.assertFutureThrows(result.values().get("topic2"), + ThrottlingQuotaExceededException.class); + assertEquals(1000, e.throttleTimeMs()); + TestUtils.assertFutureThrows(result.values().get("topic3"), TopicExistsException.class); + } + } + + private MockClient.RequestMatcher expectCreateTopicsRequestWithTopics(final String... 
topics) { + return body -> { + if (body instanceof CreateTopicsRequest) { + CreateTopicsRequest request = (CreateTopicsRequest) body; + for (String topic : topics) { + if (request.data().topics().find(topic) == null) + return false; + } + return topics.length == request.data().topics().size(); + } + return false; + }; + } + + @Test + public void testDeleteTopics() throws Exception { + try (AdminClientUnitTestEnv env = mockClientEnv()) { + env.kafkaClient().setNodeApiVersions(NodeApiVersions.create()); + + env.kafkaClient().prepareResponse( + expectDeleteTopicsRequestWithTopics("myTopic"), + prepareDeleteTopicsResponse("myTopic", Errors.NONE)); + KafkaFuture future = env.adminClient().deleteTopics(singletonList("myTopic"), + new DeleteTopicsOptions()).all(); + assertNull(future.get()); + + env.kafkaClient().prepareResponse( + expectDeleteTopicsRequestWithTopics("myTopic"), + prepareDeleteTopicsResponse("myTopic", Errors.TOPIC_DELETION_DISABLED)); + future = env.adminClient().deleteTopics(singletonList("myTopic"), + new DeleteTopicsOptions()).all(); + TestUtils.assertFutureError(future, TopicDeletionDisabledException.class); + + env.kafkaClient().prepareResponse( + expectDeleteTopicsRequestWithTopics("myTopic"), + prepareDeleteTopicsResponse("myTopic", Errors.UNKNOWN_TOPIC_OR_PARTITION)); + future = env.adminClient().deleteTopics(singletonList("myTopic"), + new DeleteTopicsOptions()).all(); + TestUtils.assertFutureError(future, UnknownTopicOrPartitionException.class); + + // With topic IDs + Uuid topicId = Uuid.randomUuid(); + + env.kafkaClient().prepareResponse( + expectDeleteTopicsRequestWithTopicIds(topicId), + prepareDeleteTopicsResponseWithTopicId(topicId, Errors.NONE)); + future = env.adminClient().deleteTopics(TopicCollection.ofTopicIds(singletonList(topicId)), + new DeleteTopicsOptions()).all(); + assertNull(future.get()); + + env.kafkaClient().prepareResponse( + expectDeleteTopicsRequestWithTopicIds(topicId), + prepareDeleteTopicsResponseWithTopicId(topicId, Errors.TOPIC_DELETION_DISABLED)); + future = env.adminClient().deleteTopics(TopicCollection.ofTopicIds(singletonList(topicId)), + new DeleteTopicsOptions()).all(); + TestUtils.assertFutureError(future, TopicDeletionDisabledException.class); + + env.kafkaClient().prepareResponse( + expectDeleteTopicsRequestWithTopicIds(topicId), + prepareDeleteTopicsResponseWithTopicId(topicId, Errors.UNKNOWN_TOPIC_ID)); + future = env.adminClient().deleteTopics(TopicCollection.ofTopicIds(singletonList(topicId)), + new DeleteTopicsOptions()).all(); + TestUtils.assertFutureError(future, UnknownTopicIdException.class); + } + } + + + @Test + public void testDeleteTopicsPartialResponse() throws Exception { + try (AdminClientUnitTestEnv env = mockClientEnv()) { + env.kafkaClient().setNodeApiVersions(NodeApiVersions.create()); + + env.kafkaClient().prepareResponse( + expectDeleteTopicsRequestWithTopics("myTopic", "myOtherTopic"), + prepareDeleteTopicsResponse(1000, + deletableTopicResult("myTopic", Errors.NONE))); + + DeleteTopicsResult result = env.adminClient().deleteTopics( + asList("myTopic", "myOtherTopic"), new DeleteTopicsOptions()); + + result.topicNameValues().get("myTopic").get(); + TestUtils.assertFutureThrows(result.topicNameValues().get("myOtherTopic"), ApiException.class); + + // With topic IDs + Uuid topicId1 = Uuid.randomUuid(); + Uuid topicId2 = Uuid.randomUuid(); + env.kafkaClient().prepareResponse( + expectDeleteTopicsRequestWithTopicIds(topicId1, topicId2), + prepareDeleteTopicsResponse(1000, + deletableTopicResultWithId(topicId1, 
Errors.NONE))); + + DeleteTopicsResult resultIds = env.adminClient().deleteTopics( + TopicCollection.ofTopicIds(asList(topicId1, topicId2)), new DeleteTopicsOptions()); + + resultIds.topicIdValues().get(topicId1).get(); + TestUtils.assertFutureThrows(resultIds.topicIdValues().get(topicId2), ApiException.class); + } + } + + @Test + public void testDeleteTopicsRetryThrottlingExceptionWhenEnabled() throws Exception { + try (AdminClientUnitTestEnv env = mockClientEnv()) { + env.kafkaClient().setNodeApiVersions(NodeApiVersions.create()); + + env.kafkaClient().prepareResponse( + expectDeleteTopicsRequestWithTopics("topic1", "topic2", "topic3"), + prepareDeleteTopicsResponse(1000, + deletableTopicResult("topic1", Errors.NONE), + deletableTopicResult("topic2", Errors.THROTTLING_QUOTA_EXCEEDED), + deletableTopicResult("topic3", Errors.TOPIC_ALREADY_EXISTS))); + + env.kafkaClient().prepareResponse( + expectDeleteTopicsRequestWithTopics("topic2"), + prepareDeleteTopicsResponse(1000, + deletableTopicResult("topic2", Errors.THROTTLING_QUOTA_EXCEEDED))); + + env.kafkaClient().prepareResponse( + expectDeleteTopicsRequestWithTopics("topic2"), + prepareDeleteTopicsResponse(0, + deletableTopicResult("topic2", Errors.NONE))); + + DeleteTopicsResult result = env.adminClient().deleteTopics( + asList("topic1", "topic2", "topic3"), + new DeleteTopicsOptions().retryOnQuotaViolation(true)); + + assertNull(result.topicNameValues().get("topic1").get()); + assertNull(result.topicNameValues().get("topic2").get()); + TestUtils.assertFutureThrows(result.topicNameValues().get("topic3"), TopicExistsException.class); + + // With topic IDs + Uuid topicId1 = Uuid.randomUuid(); + Uuid topicId2 = Uuid.randomUuid(); + Uuid topicId3 = Uuid.randomUuid(); + + env.kafkaClient().prepareResponse( + expectDeleteTopicsRequestWithTopicIds(topicId1, topicId2, topicId3), + prepareDeleteTopicsResponse(1000, + deletableTopicResultWithId(topicId1, Errors.NONE), + deletableTopicResultWithId(topicId2, Errors.THROTTLING_QUOTA_EXCEEDED), + deletableTopicResultWithId(topicId3, Errors.UNKNOWN_TOPIC_ID))); + + env.kafkaClient().prepareResponse( + expectDeleteTopicsRequestWithTopicIds(topicId2), + prepareDeleteTopicsResponse(1000, + deletableTopicResultWithId(topicId2, Errors.THROTTLING_QUOTA_EXCEEDED))); + + env.kafkaClient().prepareResponse( + expectDeleteTopicsRequestWithTopicIds(topicId2), + prepareDeleteTopicsResponse(0, + deletableTopicResultWithId(topicId2, Errors.NONE))); + + DeleteTopicsResult resultIds = env.adminClient().deleteTopics( + TopicCollection.ofTopicIds(asList(topicId1, topicId2, topicId3)), + new DeleteTopicsOptions().retryOnQuotaViolation(true)); + + assertNull(resultIds.topicIdValues().get(topicId1).get()); + assertNull(resultIds.topicIdValues().get(topicId2).get()); + TestUtils.assertFutureThrows(resultIds.topicIdValues().get(topicId3), UnknownTopicIdException.class); + } + } + + @Test + public void testDeleteTopicsRetryThrottlingExceptionWhenEnabledUntilRequestTimeOut() throws Exception { + long defaultApiTimeout = 60000; + MockTime time = new MockTime(); + + try (AdminClientUnitTestEnv env = mockClientEnv(time, + AdminClientConfig.DEFAULT_API_TIMEOUT_MS_CONFIG, String.valueOf(defaultApiTimeout))) { + + env.kafkaClient().setNodeApiVersions(NodeApiVersions.create()); + + env.kafkaClient().prepareResponse( + expectDeleteTopicsRequestWithTopics("topic1", "topic2", "topic3"), + prepareDeleteTopicsResponse(1000, + deletableTopicResult("topic1", Errors.NONE), + deletableTopicResult("topic2", Errors.THROTTLING_QUOTA_EXCEEDED), + 
deletableTopicResult("topic3", Errors.TOPIC_ALREADY_EXISTS))); + + env.kafkaClient().prepareResponse( + expectDeleteTopicsRequestWithTopics("topic2"), + prepareDeleteTopicsResponse(1000, + deletableTopicResult("topic2", Errors.THROTTLING_QUOTA_EXCEEDED))); + + DeleteTopicsResult result = env.adminClient().deleteTopics( + asList("topic1", "topic2", "topic3"), + new DeleteTopicsOptions().retryOnQuotaViolation(true)); + + // Wait until the prepared attempts have consumed + TestUtils.waitForCondition(() -> env.kafkaClient().numAwaitingResponses() == 0, + "Failed awaiting DeleteTopics requests"); + + // Wait until the next request is sent out + TestUtils.waitForCondition(() -> env.kafkaClient().inFlightRequestCount() == 1, + "Failed awaiting next DeleteTopics request"); + + // Advance time past the default api timeout to time out the inflight request + time.sleep(defaultApiTimeout + 1); + + assertNull(result.topicNameValues().get("topic1").get()); + ThrottlingQuotaExceededException e = TestUtils.assertFutureThrows(result.topicNameValues().get("topic2"), + ThrottlingQuotaExceededException.class); + assertEquals(0, e.throttleTimeMs()); + TestUtils.assertFutureThrows(result.topicNameValues().get("topic3"), TopicExistsException.class); + + // With topic IDs + Uuid topicId1 = Uuid.randomUuid(); + Uuid topicId2 = Uuid.randomUuid(); + Uuid topicId3 = Uuid.randomUuid(); + env.kafkaClient().prepareResponse( + expectDeleteTopicsRequestWithTopicIds(topicId1, topicId2, topicId3), + prepareDeleteTopicsResponse(1000, + deletableTopicResultWithId(topicId1, Errors.NONE), + deletableTopicResultWithId(topicId2, Errors.THROTTLING_QUOTA_EXCEEDED), + deletableTopicResultWithId(topicId3, Errors.UNKNOWN_TOPIC_ID))); + + env.kafkaClient().prepareResponse( + expectDeleteTopicsRequestWithTopicIds(topicId2), + prepareDeleteTopicsResponse(1000, + deletableTopicResultWithId(topicId2, Errors.THROTTLING_QUOTA_EXCEEDED))); + + DeleteTopicsResult resultIds = env.adminClient().deleteTopics( + TopicCollection.ofTopicIds(asList(topicId1, topicId2, topicId3)), + new DeleteTopicsOptions().retryOnQuotaViolation(true)); + + // Wait until the prepared attempts have consumed + TestUtils.waitForCondition(() -> env.kafkaClient().numAwaitingResponses() == 0, + "Failed awaiting DeleteTopics requests"); + + // Wait until the next request is sent out + TestUtils.waitForCondition(() -> env.kafkaClient().inFlightRequestCount() == 1, + "Failed awaiting next DeleteTopics request"); + + // Advance time past the default api timeout to time out the inflight request + time.sleep(defaultApiTimeout + 1); + + assertNull(resultIds.topicIdValues().get(topicId1).get()); + e = TestUtils.assertFutureThrows(resultIds.topicIdValues().get(topicId2), + ThrottlingQuotaExceededException.class); + assertEquals(0, e.throttleTimeMs()); + TestUtils.assertFutureThrows(resultIds.topicIdValues().get(topicId3), UnknownTopicIdException.class); + } + } + + @Test + public void testDeleteTopicsDontRetryThrottlingExceptionWhenDisabled() throws Exception { + try (AdminClientUnitTestEnv env = mockClientEnv()) { + env.kafkaClient().setNodeApiVersions(NodeApiVersions.create()); + + env.kafkaClient().prepareResponse( + expectDeleteTopicsRequestWithTopics("topic1", "topic2", "topic3"), + prepareDeleteTopicsResponse(1000, + deletableTopicResult("topic1", Errors.NONE), + deletableTopicResult("topic2", Errors.THROTTLING_QUOTA_EXCEEDED), + deletableTopicResult("topic3", Errors.TOPIC_ALREADY_EXISTS))); + + DeleteTopicsResult result = env.adminClient().deleteTopics( + asList("topic1", 
"topic2", "topic3"), + new DeleteTopicsOptions().retryOnQuotaViolation(false)); + + assertNull(result.topicNameValues().get("topic1").get()); + ThrottlingQuotaExceededException e = TestUtils.assertFutureThrows(result.topicNameValues().get("topic2"), + ThrottlingQuotaExceededException.class); + assertEquals(1000, e.throttleTimeMs()); + TestUtils.assertFutureError(result.topicNameValues().get("topic3"), TopicExistsException.class); + + // With topic IDs + Uuid topicId1 = Uuid.randomUuid(); + Uuid topicId2 = Uuid.randomUuid(); + Uuid topicId3 = Uuid.randomUuid(); + env.kafkaClient().prepareResponse( + expectDeleteTopicsRequestWithTopicIds(topicId1, topicId2, topicId3), + prepareDeleteTopicsResponse(1000, + deletableTopicResultWithId(topicId1, Errors.NONE), + deletableTopicResultWithId(topicId2, Errors.THROTTLING_QUOTA_EXCEEDED), + deletableTopicResultWithId(topicId3, Errors.UNKNOWN_TOPIC_ID))); + + DeleteTopicsResult resultIds = env.adminClient().deleteTopics( + TopicCollection.ofTopicIds(asList(topicId1, topicId2, topicId3)), + new DeleteTopicsOptions().retryOnQuotaViolation(false)); + + assertNull(resultIds.topicIdValues().get(topicId1).get()); + e = TestUtils.assertFutureThrows(resultIds.topicIdValues().get(topicId2), + ThrottlingQuotaExceededException.class); + assertEquals(1000, e.throttleTimeMs()); + TestUtils.assertFutureError(resultIds.topicIdValues().get(topicId3), UnknownTopicIdException.class); + } + } + + private MockClient.RequestMatcher expectDeleteTopicsRequestWithTopics(final String... topics) { + return body -> { + if (body instanceof DeleteTopicsRequest) { + DeleteTopicsRequest request = (DeleteTopicsRequest) body; + return request.topicNames().equals(Arrays.asList(topics)); + } + return false; + }; + } + + private MockClient.RequestMatcher expectDeleteTopicsRequestWithTopicIds(final Uuid... 
topicIds) {
+        return body -> {
+            if (body instanceof DeleteTopicsRequest) {
+                DeleteTopicsRequest request = (DeleteTopicsRequest) body;
+                return request.topicIds().equals(Arrays.asList(topicIds));
+            }
+            return false;
+        };
+    }
+
+    @Test
+    public void testInvalidTopicNames() throws Exception {
+        try (AdminClientUnitTestEnv env = mockClientEnv()) {
+            env.kafkaClient().setNodeApiVersions(NodeApiVersions.create());
+
+            List<String> sillyTopicNames = asList("", null);
+            Map<String, KafkaFuture<Void>> deleteFutures = env.adminClient().deleteTopics(sillyTopicNames).topicNameValues();
+            for (String sillyTopicName : sillyTopicNames) {
+                TestUtils.assertFutureError(deleteFutures.get(sillyTopicName), InvalidTopicException.class);
+            }
+            assertEquals(0, env.kafkaClient().inFlightRequestCount());
+
+            Map<String, KafkaFuture<TopicDescription>> describeFutures =
+                env.adminClient().describeTopics(sillyTopicNames).topicNameValues();
+            for (String sillyTopicName : sillyTopicNames) {
+                TestUtils.assertFutureError(describeFutures.get(sillyTopicName), InvalidTopicException.class);
+            }
+            assertEquals(0, env.kafkaClient().inFlightRequestCount());
+
+            List<NewTopic> newTopics = new ArrayList<>();
+            for (String sillyTopicName : sillyTopicNames) {
+                newTopics.add(new NewTopic(sillyTopicName, 1, (short) 1));
+            }
+
+            Map<String, KafkaFuture<Void>> createFutures = env.adminClient().createTopics(newTopics).values();
+            for (String sillyTopicName : sillyTopicNames) {
+                TestUtils.assertFutureError(createFutures.get(sillyTopicName), InvalidTopicException.class);
+            }
+            assertEquals(0, env.kafkaClient().inFlightRequestCount());
+        }
+    }
+
+    @Test
+    public void testMetadataRetries() throws Exception {
+        // We should continue retrying on metadata update failures in spite of retry configuration
+
+        String topic = "topic";
+        Cluster bootstrapCluster = Cluster.bootstrap(singletonList(new InetSocketAddress("localhost", 9999)));
+        Cluster initializedCluster = mockCluster(3, 0);
+
+        try (final AdminClientUnitTestEnv env = new AdminClientUnitTestEnv(Time.SYSTEM, bootstrapCluster,
+            newStrMap(AdminClientConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9999",
+                AdminClientConfig.DEFAULT_API_TIMEOUT_MS_CONFIG, "10000000",
+                AdminClientConfig.RETRIES_CONFIG, "0"))) {
+
+            // The first request fails with a disconnect
+            env.kafkaClient().prepareResponse(null, true);
+
+            // The next one succeeds and gives us the controller id
+            env.kafkaClient().prepareResponse(RequestTestUtils.metadataResponse(initializedCluster.nodes(),
+                initializedCluster.clusterResource().clusterId(),
+                initializedCluster.controller().id(),
+                Collections.emptyList()));
+
+            // Then we respond to the DescribeTopic request
+            Node leader = initializedCluster.nodes().get(0);
+            MetadataResponse.PartitionMetadata partitionMetadata = new MetadataResponse.PartitionMetadata(
+                Errors.NONE, new TopicPartition(topic, 0), Optional.of(leader.id()), Optional.of(10),
+                singletonList(leader.id()), singletonList(leader.id()), singletonList(leader.id()));
+            env.kafkaClient().prepareResponse(RequestTestUtils.metadataResponse(initializedCluster.nodes(),
+                initializedCluster.clusterResource().clusterId(), 1,
+                singletonList(new MetadataResponse.TopicMetadata(Errors.NONE, topic, Uuid.ZERO_UUID, false,
+                    singletonList(partitionMetadata), MetadataResponse.AUTHORIZED_OPERATIONS_OMITTED))));
+
+            DescribeTopicsResult result = env.adminClient().describeTopics(singleton(topic));
+            Map<String, TopicDescription> topicDescriptions = result.allTopicNames().get();
+            assertEquals(leader, topicDescriptions.get(topic).partitions().get(0).leader());
+            assertNull(topicDescriptions.get(topic).authorizedOperations());
+        }
+    }
+
+    @Test
+    public 
void testAdminClientApisAuthenticationFailure() throws Exception { + Cluster cluster = mockBootstrapCluster(); + try (final AdminClientUnitTestEnv env = new AdminClientUnitTestEnv(Time.SYSTEM, cluster, + newStrMap(AdminClientConfig.REQUEST_TIMEOUT_MS_CONFIG, "1000"))) { + env.kafkaClient().setNodeApiVersions(NodeApiVersions.create()); + env.kafkaClient().createPendingAuthenticationError(cluster.nodes().get(0), + TimeUnit.DAYS.toMillis(1)); + callAdminClientApisAndExpectAnAuthenticationError(env); + callClientQuotasApisAndExpectAnAuthenticationError(env); + } + } + + private void callAdminClientApisAndExpectAnAuthenticationError(AdminClientUnitTestEnv env) throws InterruptedException { + ExecutionException e = assertThrows(ExecutionException.class, () -> env.adminClient().createTopics( + singleton(new NewTopic("myTopic", Collections.singletonMap(0, asList(0, 1, 2)))), + new CreateTopicsOptions().timeoutMs(10000)).all().get()); + assertTrue(e.getCause() instanceof AuthenticationException, + "Expected an authentication error, but got " + Utils.stackTrace(e)); + + Map counts = new HashMap<>(); + counts.put("my_topic", NewPartitions.increaseTo(3)); + counts.put("other_topic", NewPartitions.increaseTo(3, asList(asList(2), asList(3)))); + e = assertThrows(ExecutionException.class, () -> env.adminClient().createPartitions(counts).all().get()); + assertTrue(e.getCause() instanceof AuthenticationException, + "Expected an authentication error, but got " + Utils.stackTrace(e)); + + e = assertThrows(ExecutionException.class, () -> env.adminClient().createAcls(asList(ACL1, ACL2)).all().get()); + assertTrue(e.getCause() instanceof AuthenticationException, + "Expected an authentication error, but got " + Utils.stackTrace(e)); + + e = assertThrows(ExecutionException.class, () -> env.adminClient().describeAcls(FILTER1).values().get()); + assertTrue(e.getCause() instanceof AuthenticationException, + "Expected an authentication error, but got " + Utils.stackTrace(e)); + + e = assertThrows(ExecutionException.class, () -> env.adminClient().deleteAcls(asList(FILTER1, FILTER2)).all().get()); + assertTrue(e.getCause() instanceof AuthenticationException, + "Expected an authentication error, but got " + Utils.stackTrace(e)); + + e = assertThrows(ExecutionException.class, () -> env.adminClient().describeConfigs( + singleton(new ConfigResource(ConfigResource.Type.BROKER, "0"))).all().get()); + assertTrue(e.getCause() instanceof AuthenticationException, + "Expected an authentication error, but got " + Utils.stackTrace(e)); + } + + private void callClientQuotasApisAndExpectAnAuthenticationError(AdminClientUnitTestEnv env) throws InterruptedException { + ExecutionException e = assertThrows(ExecutionException.class, + () -> env.adminClient().describeClientQuotas(ClientQuotaFilter.all()).entities().get()); + assertTrue(e.getCause() instanceof AuthenticationException, + "Expected an authentication error, but got " + Utils.stackTrace(e)); + + ClientQuotaEntity entity = new ClientQuotaEntity(Collections.singletonMap(ClientQuotaEntity.USER, "user")); + ClientQuotaAlteration alteration = new ClientQuotaAlteration(entity, asList(new ClientQuotaAlteration.Op("consumer_byte_rate", 1000.0))); + e = assertThrows(ExecutionException.class, + () -> env.adminClient().alterClientQuotas(asList(alteration)).all().get()); + + assertTrue(e.getCause() instanceof AuthenticationException, + "Expected an authentication error, but got " + Utils.stackTrace(e)); + } + + private static final AclBinding ACL1 = new AclBinding(new 
ResourcePattern(ResourceType.TOPIC, "mytopic3", PatternType.LITERAL), + new AccessControlEntry("User:ANONYMOUS", "*", AclOperation.DESCRIBE, AclPermissionType.ALLOW)); + private static final AclBinding ACL2 = new AclBinding(new ResourcePattern(ResourceType.TOPIC, "mytopic4", PatternType.LITERAL), + new AccessControlEntry("User:ANONYMOUS", "*", AclOperation.DESCRIBE, AclPermissionType.DENY)); + private static final AclBindingFilter FILTER1 = new AclBindingFilter(new ResourcePatternFilter(ResourceType.ANY, null, PatternType.LITERAL), + new AccessControlEntryFilter("User:ANONYMOUS", null, AclOperation.ANY, AclPermissionType.ANY)); + private static final AclBindingFilter FILTER2 = new AclBindingFilter(new ResourcePatternFilter(ResourceType.ANY, null, PatternType.LITERAL), + new AccessControlEntryFilter("User:bob", null, AclOperation.ANY, AclPermissionType.ANY)); + private static final AclBindingFilter UNKNOWN_FILTER = new AclBindingFilter( + new ResourcePatternFilter(ResourceType.UNKNOWN, null, PatternType.LITERAL), + new AccessControlEntryFilter("User:bob", null, AclOperation.ANY, AclPermissionType.ANY)); + + @Test + public void testDescribeAcls() throws Exception { + try (AdminClientUnitTestEnv env = mockClientEnv()) { + env.kafkaClient().setNodeApiVersions(NodeApiVersions.create()); + + // Test a call where we get back ACL1 and ACL2. + env.kafkaClient().prepareResponse(new DescribeAclsResponse(new DescribeAclsResponseData() + .setResources(DescribeAclsResponse.aclsResources(asList(ACL1, ACL2))), ApiKeys.DESCRIBE_ACLS.latestVersion())); + assertCollectionIs(env.adminClient().describeAcls(FILTER1).values().get(), ACL1, ACL2); + + // Test a call where we get back no results. + env.kafkaClient().prepareResponse(new DescribeAclsResponse(new DescribeAclsResponseData(), + ApiKeys.DESCRIBE_ACLS.latestVersion())); + assertTrue(env.adminClient().describeAcls(FILTER2).values().get().isEmpty()); + + // Test a call where we get back an error. + env.kafkaClient().prepareResponse(new DescribeAclsResponse(new DescribeAclsResponseData() + .setErrorCode(Errors.SECURITY_DISABLED.code()) + .setErrorMessage("Security is disabled"), ApiKeys.DESCRIBE_ACLS.latestVersion())); + TestUtils.assertFutureError(env.adminClient().describeAcls(FILTER2).values(), SecurityDisabledException.class); + + // Test a call where we supply an invalid filter. + TestUtils.assertFutureError(env.adminClient().describeAcls(UNKNOWN_FILTER).values(), + InvalidRequestException.class); + } + } + + @Test + public void testCreateAcls() throws Exception { + try (AdminClientUnitTestEnv env = mockClientEnv()) { + env.kafkaClient().setNodeApiVersions(NodeApiVersions.create()); + + // Test a call where we successfully create two ACLs. + env.kafkaClient().prepareResponse(new CreateAclsResponse(new CreateAclsResponseData().setResults(asList( + new CreateAclsResponseData.AclCreationResult(), + new CreateAclsResponseData.AclCreationResult())))); + CreateAclsResult results = env.adminClient().createAcls(asList(ACL1, ACL2)); + assertCollectionIs(results.values().keySet(), ACL1, ACL2); + for (KafkaFuture future : results.values().values()) + future.get(); + results.all().get(); + + // Test a call where we fail to create one ACL. 
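+            // The first creation result carries a SECURITY_DISABLED error, so ACL1's future should fail while ACL2's succeeds.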
+ env.kafkaClient().prepareResponse(new CreateAclsResponse(new CreateAclsResponseData().setResults(asList( + new CreateAclsResponseData.AclCreationResult() + .setErrorCode(Errors.SECURITY_DISABLED.code()) + .setErrorMessage("Security is disabled"), + new CreateAclsResponseData.AclCreationResult())))); + results = env.adminClient().createAcls(asList(ACL1, ACL2)); + assertCollectionIs(results.values().keySet(), ACL1, ACL2); + TestUtils.assertFutureError(results.values().get(ACL1), SecurityDisabledException.class); + results.values().get(ACL2).get(); + TestUtils.assertFutureError(results.all(), SecurityDisabledException.class); + } + } + + @Test + public void testDeleteAcls() throws Exception { + try (AdminClientUnitTestEnv env = mockClientEnv()) { + env.kafkaClient().setNodeApiVersions(NodeApiVersions.create()); + + // Test a call where one filter has an error. + env.kafkaClient().prepareResponse(new DeleteAclsResponse(new DeleteAclsResponseData() + .setThrottleTimeMs(0) + .setFilterResults(asList( + new DeleteAclsResponseData.DeleteAclsFilterResult() + .setMatchingAcls(asList( + DeleteAclsResponse.matchingAcl(ACL1, ApiError.NONE), + DeleteAclsResponse.matchingAcl(ACL2, ApiError.NONE))), + new DeleteAclsResponseData.DeleteAclsFilterResult() + .setErrorCode(Errors.SECURITY_DISABLED.code()) + .setErrorMessage("No security"))), + ApiKeys.DELETE_ACLS.latestVersion())); + DeleteAclsResult results = env.adminClient().deleteAcls(asList(FILTER1, FILTER2)); + Map> filterResults = results.values(); + FilterResults filter1Results = filterResults.get(FILTER1).get(); + assertNull(filter1Results.values().get(0).exception()); + assertEquals(ACL1, filter1Results.values().get(0).binding()); + assertNull(filter1Results.values().get(1).exception()); + assertEquals(ACL2, filter1Results.values().get(1).binding()); + TestUtils.assertFutureError(filterResults.get(FILTER2), SecurityDisabledException.class); + TestUtils.assertFutureError(results.all(), SecurityDisabledException.class); + + // Test a call where one deletion result has an error. + env.kafkaClient().prepareResponse(new DeleteAclsResponse(new DeleteAclsResponseData() + .setThrottleTimeMs(0) + .setFilterResults(asList( + new DeleteAclsResponseData.DeleteAclsFilterResult() + .setMatchingAcls(asList( + DeleteAclsResponse.matchingAcl(ACL1, ApiError.NONE), + new DeleteAclsResponseData.DeleteAclsMatchingAcl() + .setErrorCode(Errors.SECURITY_DISABLED.code()) + .setErrorMessage("No security") + .setPermissionType(AclPermissionType.ALLOW.code()) + .setOperation(AclOperation.ALTER.code()) + .setResourceType(ResourceType.CLUSTER.code()) + .setPatternType(FILTER2.patternFilter().patternType().code()))), + new DeleteAclsResponseData.DeleteAclsFilterResult())), + ApiKeys.DELETE_ACLS.latestVersion())); + results = env.adminClient().deleteAcls(asList(FILTER1, FILTER2)); + assertTrue(results.values().get(FILTER2).get().values().isEmpty()); + TestUtils.assertFutureError(results.all(), SecurityDisabledException.class); + + // Test a call where there are no errors. 
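+            // Both filters match successfully, so results.all() should return every deleted binding.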
+ env.kafkaClient().prepareResponse(new DeleteAclsResponse(new DeleteAclsResponseData() + .setThrottleTimeMs(0) + .setFilterResults(asList( + new DeleteAclsResponseData.DeleteAclsFilterResult() + .setMatchingAcls(asList(DeleteAclsResponse.matchingAcl(ACL1, ApiError.NONE))), + new DeleteAclsResponseData.DeleteAclsFilterResult() + .setMatchingAcls(asList(DeleteAclsResponse.matchingAcl(ACL2, ApiError.NONE))))), + ApiKeys.DELETE_ACLS.latestVersion())); + results = env.adminClient().deleteAcls(asList(FILTER1, FILTER2)); + Collection deleted = results.all().get(); + assertCollectionIs(deleted, ACL1, ACL2); + } + } + + @Test + public void testElectLeaders() throws Exception { + TopicPartition topic1 = new TopicPartition("topic", 0); + TopicPartition topic2 = new TopicPartition("topic", 2); + try (AdminClientUnitTestEnv env = mockClientEnv()) { + for (ElectionType electionType : ElectionType.values()) { + env.kafkaClient().setNodeApiVersions(NodeApiVersions.create()); + + // Test a call where one partition has an error. + ApiError value = ApiError.fromThrowable(new ClusterAuthorizationException(null)); + List electionResults = new ArrayList<>(); + ReplicaElectionResult electionResult = new ReplicaElectionResult(); + electionResult.setTopic(topic1.topic()); + // Add partition 1 result + PartitionResult partition1Result = new PartitionResult(); + partition1Result.setPartitionId(topic1.partition()); + partition1Result.setErrorCode(value.error().code()); + partition1Result.setErrorMessage(value.message()); + electionResult.partitionResult().add(partition1Result); + + // Add partition 2 result + PartitionResult partition2Result = new PartitionResult(); + partition2Result.setPartitionId(topic2.partition()); + partition2Result.setErrorCode(value.error().code()); + partition2Result.setErrorMessage(value.message()); + electionResult.partitionResult().add(partition2Result); + + electionResults.add(electionResult); + + env.kafkaClient().prepareResponse(new ElectLeadersResponse(0, Errors.NONE.code(), + electionResults, ApiKeys.ELECT_LEADERS.latestVersion())); + ElectLeadersResult results = env.adminClient().electLeaders( + electionType, + new HashSet<>(asList(topic1, topic2))); + assertEquals(results.partitions().get().get(topic2).get().getClass(), ClusterAuthorizationException.class); + + // Test a call where there are no errors. 
By mutating the internal of election results + partition1Result.setErrorCode(ApiError.NONE.error().code()); + partition1Result.setErrorMessage(ApiError.NONE.message()); + + partition2Result.setErrorCode(ApiError.NONE.error().code()); + partition2Result.setErrorMessage(ApiError.NONE.message()); + + env.kafkaClient().prepareResponse(new ElectLeadersResponse(0, Errors.NONE.code(), electionResults, + ApiKeys.ELECT_LEADERS.latestVersion())); + results = env.adminClient().electLeaders(electionType, new HashSet<>(asList(topic1, topic2))); + assertFalse(results.partitions().get().get(topic1).isPresent()); + assertFalse(results.partitions().get().get(topic2).isPresent()); + + // Now try a timeout + results = env.adminClient().electLeaders( + electionType, + new HashSet<>(asList(topic1, topic2)), + new ElectLeadersOptions().timeoutMs(100)); + TestUtils.assertFutureError(results.partitions(), TimeoutException.class); + } + } + } + + @Test + public void testDescribeBrokerConfigs() throws Exception { + ConfigResource broker0Resource = new ConfigResource(ConfigResource.Type.BROKER, "0"); + ConfigResource broker1Resource = new ConfigResource(ConfigResource.Type.BROKER, "1"); + try (AdminClientUnitTestEnv env = mockClientEnv()) { + env.kafkaClient().setNodeApiVersions(NodeApiVersions.create()); + env.kafkaClient().prepareResponseFrom(new DescribeConfigsResponse( + new DescribeConfigsResponseData().setResults(asList(new DescribeConfigsResponseData.DescribeConfigsResult() + .setResourceName(broker0Resource.name()).setResourceType(broker0Resource.type().id()).setErrorCode(Errors.NONE.code()) + .setConfigs(emptyList())))), env.cluster().nodeById(0)); + env.kafkaClient().prepareResponseFrom(new DescribeConfigsResponse( + new DescribeConfigsResponseData().setResults(asList(new DescribeConfigsResponseData.DescribeConfigsResult() + .setResourceName(broker1Resource.name()).setResourceType(broker1Resource.type().id()).setErrorCode(Errors.NONE.code()) + .setConfigs(emptyList())))), env.cluster().nodeById(1)); + Map> result = env.adminClient().describeConfigs(asList( + broker0Resource, + broker1Resource)).values(); + assertEquals(new HashSet<>(asList(broker0Resource, broker1Resource)), result.keySet()); + result.get(broker0Resource).get(); + result.get(broker1Resource).get(); + } + } + + @Test + public void testDescribeBrokerAndLogConfigs() throws Exception { + ConfigResource brokerResource = new ConfigResource(ConfigResource.Type.BROKER, "0"); + ConfigResource brokerLoggerResource = new ConfigResource(ConfigResource.Type.BROKER_LOGGER, "0"); + try (AdminClientUnitTestEnv env = mockClientEnv()) { + env.kafkaClient().setNodeApiVersions(NodeApiVersions.create()); + env.kafkaClient().prepareResponseFrom(new DescribeConfigsResponse( + new DescribeConfigsResponseData().setResults(asList(new DescribeConfigsResponseData.DescribeConfigsResult() + .setResourceName(brokerResource.name()).setResourceType(brokerResource.type().id()).setErrorCode(Errors.NONE.code()) + .setConfigs(emptyList()), + new DescribeConfigsResponseData.DescribeConfigsResult() + .setResourceName(brokerLoggerResource.name()).setResourceType(brokerLoggerResource.type().id()).setErrorCode(Errors.NONE.code()) + .setConfigs(emptyList())))), env.cluster().nodeById(0)); + Map> result = env.adminClient().describeConfigs(asList( + brokerResource, + brokerLoggerResource)).values(); + assertEquals(new HashSet<>(asList(brokerResource, brokerLoggerResource)), result.keySet()); + result.get(brokerResource).get(); + result.get(brokerLoggerResource).get(); + } + } + + 
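+    // When the broker omits a requested resource from the DescribeConfigs response, that resource's future should fail.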
@Test + public void testDescribeConfigsPartialResponse() throws Exception { + ConfigResource topic = new ConfigResource(ConfigResource.Type.TOPIC, "topic"); + ConfigResource topic2 = new ConfigResource(ConfigResource.Type.TOPIC, "topic2"); + try (AdminClientUnitTestEnv env = mockClientEnv()) { + env.kafkaClient().setNodeApiVersions(NodeApiVersions.create()); + env.kafkaClient().prepareResponse(new DescribeConfigsResponse( + new DescribeConfigsResponseData().setResults(asList(new DescribeConfigsResponseData.DescribeConfigsResult() + .setResourceName(topic.name()).setResourceType(topic.type().id()).setErrorCode(Errors.NONE.code()) + .setConfigs(emptyList()))))); + Map> result = env.adminClient().describeConfigs(asList( + topic, + topic2)).values(); + assertEquals(new HashSet<>(asList(topic, topic2)), result.keySet()); + result.get(topic); + TestUtils.assertFutureThrows(result.get(topic2), ApiException.class); + } + } + + @Test + public void testDescribeConfigsUnrequested() throws Exception { + ConfigResource topic = new ConfigResource(ConfigResource.Type.TOPIC, "topic"); + ConfigResource unrequested = new ConfigResource(ConfigResource.Type.TOPIC, "unrequested"); + try (AdminClientUnitTestEnv env = mockClientEnv()) { + env.kafkaClient().setNodeApiVersions(NodeApiVersions.create()); + env.kafkaClient().prepareResponse(new DescribeConfigsResponse( + new DescribeConfigsResponseData().setResults(asList(new DescribeConfigsResponseData.DescribeConfigsResult() + .setResourceName(topic.name()).setResourceType(topic.type().id()).setErrorCode(Errors.NONE.code()) + .setConfigs(emptyList()), + new DescribeConfigsResponseData.DescribeConfigsResult() + .setResourceName(unrequested.name()).setResourceType(unrequested.type().id()).setErrorCode(Errors.NONE.code()) + .setConfigs(emptyList()))))); + Map> result = env.adminClient().describeConfigs(asList( + topic)).values(); + assertEquals(new HashSet<>(asList(topic)), result.keySet()); + assertNotNull(result.get(topic).get()); + assertNull(result.get(unrequested)); + } + } + + private static DescribeLogDirsResponse prepareDescribeLogDirsResponse(Errors error, String logDir, TopicPartition tp, long partitionSize, long offsetLag) { + return prepareDescribeLogDirsResponse(error, logDir, + prepareDescribeLogDirsTopics(partitionSize, offsetLag, tp.topic(), tp.partition(), false)); + } + + private static List prepareDescribeLogDirsTopics( + long partitionSize, long offsetLag, String topic, int partition, boolean isFuture) { + return singletonList(new DescribeLogDirsTopic() + .setName(topic) + .setPartitions(singletonList(new DescribeLogDirsResponseData.DescribeLogDirsPartition() + .setPartitionIndex(partition) + .setPartitionSize(partitionSize) + .setIsFutureKey(isFuture) + .setOffsetLag(offsetLag)))); + } + + private static DescribeLogDirsResponse prepareDescribeLogDirsResponse(Errors error, String logDir, + List topics) { + return new DescribeLogDirsResponse( + new DescribeLogDirsResponseData().setResults(singletonList(new DescribeLogDirsResponseData.DescribeLogDirsResult() + .setErrorCode(error.code()) + .setLogDir(logDir) + .setTopics(topics) + ))); + } + + @Test + public void testDescribeLogDirs() throws ExecutionException, InterruptedException { + Set brokers = singleton(0); + String logDir = "/var/data/kafka"; + TopicPartition tp = new TopicPartition("topic", 12); + long partitionSize = 1234567890; + long offsetLag = 24; + + try (AdminClientUnitTestEnv env = mockClientEnv()) { + env.kafkaClient().setNodeApiVersions(NodeApiVersions.create()); + 
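+            // Broker 0 reports one log dir containing a single replica with the expected size and offset lag.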
env.kafkaClient().prepareResponseFrom( + prepareDescribeLogDirsResponse(Errors.NONE, logDir, tp, partitionSize, offsetLag), + env.cluster().nodeById(0)); + + DescribeLogDirsResult result = env.adminClient().describeLogDirs(brokers); + + Map>> descriptions = result.descriptions(); + assertEquals(brokers, descriptions.keySet()); + assertNotNull(descriptions.get(0)); + assertDescriptionContains(descriptions.get(0).get(), logDir, tp, partitionSize, offsetLag); + + Map> allDescriptions = result.allDescriptions().get(); + assertEquals(brokers, allDescriptions.keySet()); + assertDescriptionContains(allDescriptions.get(0), logDir, tp, partitionSize, offsetLag); + } + } + + private static void assertDescriptionContains(Map descriptionsMap, String logDir, + TopicPartition tp, long partitionSize, long offsetLag) { + assertNotNull(descriptionsMap); + assertEquals(singleton(logDir), descriptionsMap.keySet()); + assertNull(descriptionsMap.get(logDir).error()); + Map descriptionsReplicaInfos = descriptionsMap.get(logDir).replicaInfos(); + assertEquals(singleton(tp), descriptionsReplicaInfos.keySet()); + assertEquals(partitionSize, descriptionsReplicaInfos.get(tp).size()); + assertEquals(offsetLag, descriptionsReplicaInfos.get(tp).offsetLag()); + assertFalse(descriptionsReplicaInfos.get(tp).isFuture()); + } + + @SuppressWarnings("deprecation") + @Test + public void testDescribeLogDirsDeprecated() throws ExecutionException, InterruptedException { + Set brokers = singleton(0); + TopicPartition tp = new TopicPartition("topic", 12); + String logDir = "/var/data/kafka"; + Errors error = Errors.NONE; + int offsetLag = 24; + long partitionSize = 1234567890; + + try (AdminClientUnitTestEnv env = mockClientEnv()) { + env.kafkaClient().setNodeApiVersions(NodeApiVersions.create()); + env.kafkaClient().prepareResponseFrom( + prepareDescribeLogDirsResponse(error, logDir, tp, partitionSize, offsetLag), + env.cluster().nodeById(0)); + + DescribeLogDirsResult result = env.adminClient().describeLogDirs(brokers); + + Map>> deprecatedValues = result.values(); + assertEquals(brokers, deprecatedValues.keySet()); + assertNotNull(deprecatedValues.get(0)); + assertDescriptionContains(deprecatedValues.get(0).get(), logDir, tp, error, offsetLag, partitionSize); + + Map> deprecatedAll = result.all().get(); + assertEquals(brokers, deprecatedAll.keySet()); + assertDescriptionContains(deprecatedAll.get(0), logDir, tp, error, offsetLag, partitionSize); + } + } + + @SuppressWarnings("deprecation") + private static void assertDescriptionContains(Map descriptionsMap, + String logDir, TopicPartition tp, Errors error, + int offsetLag, long partitionSize) { + assertNotNull(descriptionsMap); + assertEquals(singleton(logDir), descriptionsMap.keySet()); + assertEquals(error, descriptionsMap.get(logDir).error); + Map allReplicaInfos = + descriptionsMap.get(logDir).replicaInfos; + assertEquals(singleton(tp), allReplicaInfos.keySet()); + assertEquals(partitionSize, allReplicaInfos.get(tp).size); + assertEquals(offsetLag, allReplicaInfos.get(tp).offsetLag); + assertFalse(allReplicaInfos.get(tp).isFuture); + } + + @Test + public void testDescribeLogDirsOfflineDir() throws ExecutionException, InterruptedException { + Set brokers = singleton(0); + String logDir = "/var/data/kafka"; + Errors error = Errors.KAFKA_STORAGE_ERROR; + + try (AdminClientUnitTestEnv env = mockClientEnv()) { + env.kafkaClient().setNodeApiVersions(NodeApiVersions.create()); + env.kafkaClient().prepareResponseFrom( + prepareDescribeLogDirsResponse(error, logDir, emptyList()), + 
env.cluster().nodeById(0)); + + DescribeLogDirsResult result = env.adminClient().describeLogDirs(brokers); + + Map>> descriptions = result.descriptions(); + assertEquals(brokers, descriptions.keySet()); + assertNotNull(descriptions.get(0)); + Map descriptionsMap = descriptions.get(0).get(); + assertEquals(singleton(logDir), descriptionsMap.keySet()); + assertEquals(error.exception().getClass(), descriptionsMap.get(logDir).error().getClass()); + assertEquals(emptySet(), descriptionsMap.get(logDir).replicaInfos().keySet()); + + Map> allDescriptions = result.allDescriptions().get(); + assertEquals(brokers, allDescriptions.keySet()); + Map allMap = allDescriptions.get(0); + assertNotNull(allMap); + assertEquals(singleton(logDir), allMap.keySet()); + assertEquals(error.exception().getClass(), allMap.get(logDir).error().getClass()); + assertEquals(emptySet(), allMap.get(logDir).replicaInfos().keySet()); + } + } + + @SuppressWarnings("deprecation") + @Test + public void testDescribeLogDirsOfflineDirDeprecated() throws ExecutionException, InterruptedException { + Set brokers = singleton(0); + String logDir = "/var/data/kafka"; + Errors error = Errors.KAFKA_STORAGE_ERROR; + + try (AdminClientUnitTestEnv env = mockClientEnv()) { + env.kafkaClient().setNodeApiVersions(NodeApiVersions.create()); + env.kafkaClient().prepareResponseFrom( + prepareDescribeLogDirsResponse(error, logDir, emptyList()), + env.cluster().nodeById(0)); + + DescribeLogDirsResult result = env.adminClient().describeLogDirs(brokers); + + Map>> deprecatedValues = result.values(); + assertEquals(brokers, deprecatedValues.keySet()); + assertNotNull(deprecatedValues.get(0)); + Map valuesMap = deprecatedValues.get(0).get(); + assertEquals(singleton(logDir), valuesMap.keySet()); + assertEquals(error, valuesMap.get(logDir).error); + assertEquals(emptySet(), valuesMap.get(logDir).replicaInfos.keySet()); + + Map> deprecatedAll = result.all().get(); + assertEquals(brokers, deprecatedAll.keySet()); + Map allMap = deprecatedAll.get(0); + assertNotNull(allMap); + assertEquals(singleton(logDir), allMap.keySet()); + assertEquals(error, allMap.get(logDir).error); + assertEquals(emptySet(), allMap.get(logDir).replicaInfos.keySet()); + } + } + + @Test + public void testDescribeReplicaLogDirs() throws ExecutionException, InterruptedException { + TopicPartitionReplica tpr1 = new TopicPartitionReplica("topic", 12, 1); + TopicPartitionReplica tpr2 = new TopicPartitionReplica("topic", 12, 2); + + try (AdminClientUnitTestEnv env = mockClientEnv()) { + env.kafkaClient().setNodeApiVersions(NodeApiVersions.create()); + String broker1log0 = "/var/data/kafka0"; + String broker1log1 = "/var/data/kafka1"; + String broker2log0 = "/var/data/kafka2"; + int broker1Log0OffsetLag = 24; + int broker1Log0PartitionSize = 987654321; + int broker1Log1PartitionSize = 123456789; + int broker1Log1OffsetLag = 4321; + env.kafkaClient().prepareResponseFrom( + new DescribeLogDirsResponse( + new DescribeLogDirsResponseData().setResults(asList( + prepareDescribeLogDirsResult(tpr1, broker1log0, broker1Log0PartitionSize, broker1Log0OffsetLag, false), + prepareDescribeLogDirsResult(tpr1, broker1log1, broker1Log1PartitionSize, broker1Log1OffsetLag, true)))), + env.cluster().nodeById(tpr1.brokerId())); + env.kafkaClient().prepareResponseFrom( + prepareDescribeLogDirsResponse(Errors.KAFKA_STORAGE_ERROR, broker2log0), + env.cluster().nodeById(tpr2.brokerId())); + + DescribeReplicaLogDirsResult result = env.adminClient().describeReplicaLogDirs(asList(tpr1, tpr2)); + + Map> values = 
result.values(); + assertEquals(TestUtils.toSet(asList(tpr1, tpr2)), values.keySet()); + + assertNotNull(values.get(tpr1)); + assertEquals(broker1log0, values.get(tpr1).get().getCurrentReplicaLogDir()); + assertEquals(broker1Log0OffsetLag, values.get(tpr1).get().getCurrentReplicaOffsetLag()); + assertEquals(broker1log1, values.get(tpr1).get().getFutureReplicaLogDir()); + assertEquals(broker1Log1OffsetLag, values.get(tpr1).get().getFutureReplicaOffsetLag()); + + assertNotNull(values.get(tpr2)); + assertNull(values.get(tpr2).get().getCurrentReplicaLogDir()); + assertEquals(-1, values.get(tpr2).get().getCurrentReplicaOffsetLag()); + assertNull(values.get(tpr2).get().getFutureReplicaLogDir()); + assertEquals(-1, values.get(tpr2).get().getFutureReplicaOffsetLag()); + } + } + + private static DescribeLogDirsResponseData.DescribeLogDirsResult prepareDescribeLogDirsResult(TopicPartitionReplica tpr, String logDir, int partitionSize, int offsetLag, boolean isFuture) { + return new DescribeLogDirsResponseData.DescribeLogDirsResult() + .setErrorCode(Errors.NONE.code()) + .setLogDir(logDir) + .setTopics(prepareDescribeLogDirsTopics(partitionSize, offsetLag, tpr.topic(), tpr.partition(), isFuture)); + } + + @Test + public void testDescribeReplicaLogDirsUnexpected() throws ExecutionException, InterruptedException { + TopicPartitionReplica expected = new TopicPartitionReplica("topic", 12, 1); + TopicPartitionReplica unexpected = new TopicPartitionReplica("topic", 12, 2); + + try (AdminClientUnitTestEnv env = mockClientEnv()) { + env.kafkaClient().setNodeApiVersions(NodeApiVersions.create()); + String broker1log0 = "/var/data/kafka0"; + String broker1log1 = "/var/data/kafka1"; + int broker1Log0PartitionSize = 987654321; + int broker1Log0OffsetLag = 24; + int broker1Log1PartitionSize = 123456789; + int broker1Log1OffsetLag = 4321; + env.kafkaClient().prepareResponseFrom( + new DescribeLogDirsResponse( + new DescribeLogDirsResponseData().setResults(asList( + prepareDescribeLogDirsResult(expected, broker1log0, broker1Log0PartitionSize, broker1Log0OffsetLag, false), + prepareDescribeLogDirsResult(unexpected, broker1log1, broker1Log1PartitionSize, broker1Log1OffsetLag, true)))), + env.cluster().nodeById(expected.brokerId())); + + DescribeReplicaLogDirsResult result = env.adminClient().describeReplicaLogDirs(asList(expected)); + + Map> values = result.values(); + assertEquals(TestUtils.toSet(asList(expected)), values.keySet()); + + assertNotNull(values.get(expected)); + assertEquals(broker1log0, values.get(expected).get().getCurrentReplicaLogDir()); + assertEquals(broker1Log0OffsetLag, values.get(expected).get().getCurrentReplicaOffsetLag()); + assertEquals(broker1log1, values.get(expected).get().getFutureReplicaLogDir()); + assertEquals(broker1Log1OffsetLag, values.get(expected).get().getFutureReplicaOffsetLag()); + } + } + + @Test + public void testCreatePartitions() throws Exception { + try (AdminClientUnitTestEnv env = mockClientEnv()) { + env.kafkaClient().setNodeApiVersions(NodeApiVersions.create()); + + // Test a call where one filter has an error. 
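+            // Here "my_topic" succeeds while "other_topic" fails with INVALID_TOPIC_EXCEPTION and a detailed message.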
+ env.kafkaClient().prepareResponse( + expectCreatePartitionsRequestWithTopics("my_topic", "other_topic"), + prepareCreatePartitionsResponse(1000, + createPartitionsTopicResult("my_topic", Errors.NONE), + createPartitionsTopicResult("other_topic", Errors.INVALID_TOPIC_EXCEPTION, + "some detailed reason"))); + + + Map counts = new HashMap<>(); + counts.put("my_topic", NewPartitions.increaseTo(3)); + counts.put("other_topic", NewPartitions.increaseTo(3, asList(asList(2), asList(3)))); + + CreatePartitionsResult results = env.adminClient().createPartitions(counts); + Map> values = results.values(); + KafkaFuture myTopicResult = values.get("my_topic"); + myTopicResult.get(); + KafkaFuture otherTopicResult = values.get("other_topic"); + try { + otherTopicResult.get(); + fail("get() should throw ExecutionException"); + } catch (ExecutionException e0) { + assertTrue(e0.getCause() instanceof InvalidTopicException); + InvalidTopicException e = (InvalidTopicException) e0.getCause(); + assertEquals("some detailed reason", e.getMessage()); + } + } + } + + @Test + public void testCreatePartitionsRetryThrottlingExceptionWhenEnabled() throws Exception { + try (AdminClientUnitTestEnv env = mockClientEnv()) { + env.kafkaClient().setNodeApiVersions(NodeApiVersions.create()); + + env.kafkaClient().prepareResponse( + expectCreatePartitionsRequestWithTopics("topic1", "topic2", "topic3"), + prepareCreatePartitionsResponse(1000, + createPartitionsTopicResult("topic1", Errors.NONE), + createPartitionsTopicResult("topic2", Errors.THROTTLING_QUOTA_EXCEEDED), + createPartitionsTopicResult("topic3", Errors.TOPIC_ALREADY_EXISTS))); + + env.kafkaClient().prepareResponse( + expectCreatePartitionsRequestWithTopics("topic2"), + prepareCreatePartitionsResponse(1000, + createPartitionsTopicResult("topic2", Errors.THROTTLING_QUOTA_EXCEEDED))); + + env.kafkaClient().prepareResponse( + expectCreatePartitionsRequestWithTopics("topic2"), + prepareCreatePartitionsResponse(0, + createPartitionsTopicResult("topic2", Errors.NONE))); + + Map counts = new HashMap<>(); + counts.put("topic1", NewPartitions.increaseTo(1)); + counts.put("topic2", NewPartitions.increaseTo(2)); + counts.put("topic3", NewPartitions.increaseTo(3)); + + CreatePartitionsResult result = env.adminClient().createPartitions( + counts, new CreatePartitionsOptions().retryOnQuotaViolation(true)); + + assertNull(result.values().get("topic1").get()); + assertNull(result.values().get("topic2").get()); + TestUtils.assertFutureThrows(result.values().get("topic3"), TopicExistsException.class); + } + } + + @Test + public void testCreatePartitionsRetryThrottlingExceptionWhenEnabledUntilRequestTimeOut() throws Exception { + long defaultApiTimeout = 60000; + MockTime time = new MockTime(); + + try (AdminClientUnitTestEnv env = mockClientEnv(time, + AdminClientConfig.DEFAULT_API_TIMEOUT_MS_CONFIG, String.valueOf(defaultApiTimeout))) { + + env.kafkaClient().setNodeApiVersions(NodeApiVersions.create()); + + env.kafkaClient().prepareResponse( + expectCreatePartitionsRequestWithTopics("topic1", "topic2", "topic3"), + prepareCreatePartitionsResponse(1000, + createPartitionsTopicResult("topic1", Errors.NONE), + createPartitionsTopicResult("topic2", Errors.THROTTLING_QUOTA_EXCEEDED), + createPartitionsTopicResult("topic3", Errors.TOPIC_ALREADY_EXISTS))); + + env.kafkaClient().prepareResponse( + expectCreatePartitionsRequestWithTopics("topic2"), + prepareCreatePartitionsResponse(1000, + createPartitionsTopicResult("topic2", Errors.THROTTLING_QUOTA_EXCEEDED))); + + Map counts = new 
HashMap<>(); + counts.put("topic1", NewPartitions.increaseTo(1)); + counts.put("topic2", NewPartitions.increaseTo(2)); + counts.put("topic3", NewPartitions.increaseTo(3)); + + CreatePartitionsResult result = env.adminClient().createPartitions( + counts, new CreatePartitionsOptions().retryOnQuotaViolation(true)); + + // Wait until the prepared attempts have consumed + TestUtils.waitForCondition(() -> env.kafkaClient().numAwaitingResponses() == 0, + "Failed awaiting CreatePartitions requests"); + + // Wait until the next request is sent out + TestUtils.waitForCondition(() -> env.kafkaClient().inFlightRequestCount() == 1, + "Failed awaiting next CreatePartitions request"); + + // Advance time past the default api timeout to time out the inflight request + time.sleep(defaultApiTimeout + 1); + + assertNull(result.values().get("topic1").get()); + ThrottlingQuotaExceededException e = TestUtils.assertFutureThrows(result.values().get("topic2"), + ThrottlingQuotaExceededException.class); + assertEquals(0, e.throttleTimeMs()); + TestUtils.assertFutureThrows(result.values().get("topic3"), TopicExistsException.class); + } + } + + @Test + public void testCreatePartitionsDontRetryThrottlingExceptionWhenDisabled() throws Exception { + try (AdminClientUnitTestEnv env = mockClientEnv()) { + env.kafkaClient().setNodeApiVersions(NodeApiVersions.create()); + + env.kafkaClient().prepareResponse( + expectCreatePartitionsRequestWithTopics("topic1", "topic2", "topic3"), + prepareCreatePartitionsResponse(1000, + createPartitionsTopicResult("topic1", Errors.NONE), + createPartitionsTopicResult("topic2", Errors.THROTTLING_QUOTA_EXCEEDED), + createPartitionsTopicResult("topic3", Errors.TOPIC_ALREADY_EXISTS))); + + Map counts = new HashMap<>(); + counts.put("topic1", NewPartitions.increaseTo(1)); + counts.put("topic2", NewPartitions.increaseTo(2)); + counts.put("topic3", NewPartitions.increaseTo(3)); + + CreatePartitionsResult result = env.adminClient().createPartitions( + counts, new CreatePartitionsOptions().retryOnQuotaViolation(false)); + + assertNull(result.values().get("topic1").get()); + ThrottlingQuotaExceededException e = TestUtils.assertFutureThrows(result.values().get("topic2"), + ThrottlingQuotaExceededException.class); + assertEquals(1000, e.throttleTimeMs()); + TestUtils.assertFutureThrows(result.values().get("topic3"), TopicExistsException.class); + } + } + + private MockClient.RequestMatcher expectCreatePartitionsRequestWithTopics(final String... 
topics) { + return body -> { + if (body instanceof CreatePartitionsRequest) { + CreatePartitionsRequest request = (CreatePartitionsRequest) body; + for (String topic : topics) { + if (request.data().topics().find(topic) == null) + return false; + } + return topics.length == request.data().topics().size(); + } + return false; + }; + } + + @Test + public void testDeleteRecordsTopicAuthorizationError() { + String topic = "foo"; + TopicPartition partition = new TopicPartition(topic, 0); + + try (AdminClientUnitTestEnv env = mockClientEnv()) { + List topics = new ArrayList<>(); + topics.add(new MetadataResponse.TopicMetadata(Errors.TOPIC_AUTHORIZATION_FAILED, topic, false, + Collections.emptyList())); + + env.kafkaClient().prepareResponse(RequestTestUtils.metadataResponse(env.cluster().nodes(), + env.cluster().clusterResource().clusterId(), env.cluster().controller().id(), topics)); + + Map recordsToDelete = new HashMap<>(); + recordsToDelete.put(partition, RecordsToDelete.beforeOffset(10L)); + DeleteRecordsResult results = env.adminClient().deleteRecords(recordsToDelete); + + TestUtils.assertFutureThrows(results.lowWatermarks().get(partition), TopicAuthorizationException.class); + } + } + + @Test + public void testDeleteRecordsMultipleSends() throws Exception { + String topic = "foo"; + TopicPartition tp0 = new TopicPartition(topic, 0); + TopicPartition tp1 = new TopicPartition(topic, 1); + + MockTime time = new MockTime(); + + try (AdminClientUnitTestEnv env = new AdminClientUnitTestEnv(time, mockCluster(3, 0))) { + List nodes = env.cluster().nodes(); + + List partitionMetadata = new ArrayList<>(); + partitionMetadata.add(new MetadataResponse.PartitionMetadata(Errors.NONE, tp0, + Optional.of(nodes.get(0).id()), Optional.of(5), singletonList(nodes.get(0).id()), + singletonList(nodes.get(0).id()), Collections.emptyList())); + partitionMetadata.add(new MetadataResponse.PartitionMetadata(Errors.NONE, tp1, + Optional.of(nodes.get(1).id()), Optional.of(5), singletonList(nodes.get(1).id()), + singletonList(nodes.get(1).id()), Collections.emptyList())); + + List topicMetadata = new ArrayList<>(); + topicMetadata.add(new MetadataResponse.TopicMetadata(Errors.NONE, topic, false, partitionMetadata)); + + env.kafkaClient().prepareResponse(RequestTestUtils.metadataResponse(env.cluster().nodes(), + env.cluster().clusterResource().clusterId(), env.cluster().controller().id(), topicMetadata)); + + env.kafkaClient().prepareResponseFrom(new DeleteRecordsResponse(new DeleteRecordsResponseData().setTopics( + new DeleteRecordsResponseData.DeleteRecordsTopicResultCollection(singletonList(new DeleteRecordsResponseData.DeleteRecordsTopicResult() + .setName(tp0.topic()) + .setPartitions(new DeleteRecordsResponseData.DeleteRecordsPartitionResultCollection(singletonList(new DeleteRecordsResponseData.DeleteRecordsPartitionResult() + .setPartitionIndex(tp0.partition()) + .setErrorCode(Errors.NONE.code()) + .setLowWatermark(3)).iterator()))).iterator()))), nodes.get(0)); + + env.kafkaClient().disconnect(nodes.get(1).idString()); + env.kafkaClient().createPendingAuthenticationError(nodes.get(1), 100); + + Map recordsToDelete = new HashMap<>(); + recordsToDelete.put(tp0, RecordsToDelete.beforeOffset(10L)); + recordsToDelete.put(tp1, RecordsToDelete.beforeOffset(10L)); + DeleteRecordsResult results = env.adminClient().deleteRecords(recordsToDelete); + + assertEquals(3L, results.lowWatermarks().get(tp0).get().lowWatermark()); + TestUtils.assertFutureThrows(results.lowWatermarks().get(tp1), AuthenticationException.class); + } 
+ } + + @Test + public void testDeleteRecords() throws Exception { + HashMap nodes = new HashMap<>(); + nodes.put(0, new Node(0, "localhost", 8121)); + List partitionInfos = new ArrayList<>(); + partitionInfos.add(new PartitionInfo("my_topic", 0, nodes.get(0), new Node[] {nodes.get(0)}, new Node[] {nodes.get(0)})); + partitionInfos.add(new PartitionInfo("my_topic", 1, nodes.get(0), new Node[] {nodes.get(0)}, new Node[] {nodes.get(0)})); + partitionInfos.add(new PartitionInfo("my_topic", 2, null, new Node[] {nodes.get(0)}, new Node[] {nodes.get(0)})); + partitionInfos.add(new PartitionInfo("my_topic", 3, nodes.get(0), new Node[] {nodes.get(0)}, new Node[] {nodes.get(0)})); + partitionInfos.add(new PartitionInfo("my_topic", 4, nodes.get(0), new Node[] {nodes.get(0)}, new Node[] {nodes.get(0)})); + Cluster cluster = new Cluster("mockClusterId", nodes.values(), + partitionInfos, Collections.emptySet(), + Collections.emptySet(), nodes.get(0)); + + TopicPartition myTopicPartition0 = new TopicPartition("my_topic", 0); + TopicPartition myTopicPartition1 = new TopicPartition("my_topic", 1); + TopicPartition myTopicPartition2 = new TopicPartition("my_topic", 2); + TopicPartition myTopicPartition3 = new TopicPartition("my_topic", 3); + TopicPartition myTopicPartition4 = new TopicPartition("my_topic", 4); + + try (AdminClientUnitTestEnv env = new AdminClientUnitTestEnv(cluster)) { + env.kafkaClient().setNodeApiVersions(NodeApiVersions.create()); + + DeleteRecordsResponseData m = new DeleteRecordsResponseData(); + m.topics().add(new DeleteRecordsResponseData.DeleteRecordsTopicResult().setName(myTopicPartition0.topic()) + .setPartitions(new DeleteRecordsResponseData.DeleteRecordsPartitionResultCollection(asList( + new DeleteRecordsResponseData.DeleteRecordsPartitionResult() + .setPartitionIndex(myTopicPartition0.partition()) + .setLowWatermark(3) + .setErrorCode(Errors.NONE.code()), + new DeleteRecordsResponseData.DeleteRecordsPartitionResult() + .setPartitionIndex(myTopicPartition1.partition()) + .setLowWatermark(DeleteRecordsResponse.INVALID_LOW_WATERMARK) + .setErrorCode(Errors.OFFSET_OUT_OF_RANGE.code()), + new DeleteRecordsResponseData.DeleteRecordsPartitionResult() + .setPartitionIndex(myTopicPartition3.partition()) + .setLowWatermark(DeleteRecordsResponse.INVALID_LOW_WATERMARK) + .setErrorCode(Errors.NOT_LEADER_OR_FOLLOWER.code()), + new DeleteRecordsResponseData.DeleteRecordsPartitionResult() + .setPartitionIndex(myTopicPartition4.partition()) + .setLowWatermark(DeleteRecordsResponse.INVALID_LOW_WATERMARK) + .setErrorCode(Errors.UNKNOWN_TOPIC_OR_PARTITION.code()) + ).iterator()))); + + List t = new ArrayList<>(); + List p = new ArrayList<>(); + p.add(new MetadataResponse.PartitionMetadata(Errors.NONE, myTopicPartition0, + Optional.of(nodes.get(0).id()), Optional.of(5), singletonList(nodes.get(0).id()), + singletonList(nodes.get(0).id()), Collections.emptyList())); + p.add(new MetadataResponse.PartitionMetadata(Errors.NONE, myTopicPartition1, + Optional.of(nodes.get(0).id()), Optional.of(5), singletonList(nodes.get(0).id()), + singletonList(nodes.get(0).id()), Collections.emptyList())); + p.add(new MetadataResponse.PartitionMetadata(Errors.LEADER_NOT_AVAILABLE, myTopicPartition2, + Optional.empty(), Optional.empty(), singletonList(nodes.get(0).id()), + singletonList(nodes.get(0).id()), Collections.emptyList())); + p.add(new MetadataResponse.PartitionMetadata(Errors.NONE, myTopicPartition3, + Optional.of(nodes.get(0).id()), Optional.of(5), singletonList(nodes.get(0).id()), + 
singletonList(nodes.get(0).id()), Collections.emptyList())); + p.add(new MetadataResponse.PartitionMetadata(Errors.NONE, myTopicPartition4, + Optional.of(nodes.get(0).id()), Optional.of(5), singletonList(nodes.get(0).id()), + singletonList(nodes.get(0).id()), Collections.emptyList())); + + t.add(new MetadataResponse.TopicMetadata(Errors.NONE, "my_topic", false, p)); + + env.kafkaClient().prepareResponse(RequestTestUtils.metadataResponse(cluster.nodes(), cluster.clusterResource().clusterId(), cluster.controller().id(), t)); + env.kafkaClient().prepareResponse(new DeleteRecordsResponse(m)); + + Map recordsToDelete = new HashMap<>(); + recordsToDelete.put(myTopicPartition0, RecordsToDelete.beforeOffset(3L)); + recordsToDelete.put(myTopicPartition1, RecordsToDelete.beforeOffset(10L)); + recordsToDelete.put(myTopicPartition2, RecordsToDelete.beforeOffset(10L)); + recordsToDelete.put(myTopicPartition3, RecordsToDelete.beforeOffset(10L)); + recordsToDelete.put(myTopicPartition4, RecordsToDelete.beforeOffset(10L)); + + DeleteRecordsResult results = env.adminClient().deleteRecords(recordsToDelete); + + // success on records deletion for partition 0 + Map> values = results.lowWatermarks(); + KafkaFuture myTopicPartition0Result = values.get(myTopicPartition0); + long lowWatermark = myTopicPartition0Result.get().lowWatermark(); + assertEquals(lowWatermark, 3); + + // "offset out of range" failure on records deletion for partition 1 + KafkaFuture myTopicPartition1Result = values.get(myTopicPartition1); + try { + myTopicPartition1Result.get(); + fail("get() should throw ExecutionException"); + } catch (ExecutionException e0) { + assertTrue(e0.getCause() instanceof OffsetOutOfRangeException); + } + + // "leader not available" failure on metadata request for partition 2 + KafkaFuture myTopicPartition2Result = values.get(myTopicPartition2); + try { + myTopicPartition2Result.get(); + fail("get() should throw ExecutionException"); + } catch (ExecutionException e1) { + assertTrue(e1.getCause() instanceof LeaderNotAvailableException); + } + + // "not leader for partition" failure on records deletion for partition 3 + KafkaFuture myTopicPartition3Result = values.get(myTopicPartition3); + try { + myTopicPartition3Result.get(); + fail("get() should throw ExecutionException"); + } catch (ExecutionException e1) { + assertTrue(e1.getCause() instanceof NotLeaderOrFollowerException); + } + + // "unknown topic or partition" failure on records deletion for partition 4 + KafkaFuture myTopicPartition4Result = values.get(myTopicPartition4); + try { + myTopicPartition4Result.get(); + fail("get() should throw ExecutionException"); + } catch (ExecutionException e1) { + assertTrue(e1.getCause() instanceof UnknownTopicOrPartitionException); + } + } + } + + @Test + public void testDescribeCluster() throws Exception { + try (AdminClientUnitTestEnv env = new AdminClientUnitTestEnv(mockCluster(4, 0), + AdminClientConfig.RETRIES_CONFIG, "2")) { + env.kafkaClient().setNodeApiVersions(NodeApiVersions.create()); + + // Prepare the describe cluster response used for the first describe cluster + env.kafkaClient().prepareResponse( + prepareDescribeClusterResponse(0, + env.cluster().nodes(), + env.cluster().clusterResource().clusterId(), + 2, + MetadataResponse.AUTHORIZED_OPERATIONS_OMITTED)); + + // Prepare the describe cluster response used for the second describe cluster + env.kafkaClient().prepareResponse( + prepareDescribeClusterResponse(0, + env.cluster().nodes(), + env.cluster().clusterResource().clusterId(), + 3, + 1 << 
AclOperation.DESCRIBE.code() | 1 << AclOperation.ALTER.code())); + + // Test DescribeCluster with the authorized operations omitted. + final DescribeClusterResult result = env.adminClient().describeCluster(); + assertEquals(env.cluster().clusterResource().clusterId(), result.clusterId().get()); + assertEquals(new HashSet<>(env.cluster().nodes()), new HashSet<>(result.nodes().get())); + assertEquals(2, result.controller().get().id()); + assertNull(result.authorizedOperations().get()); + + // Test DescribeCluster with the authorized operations included. + final DescribeClusterResult result2 = env.adminClient().describeCluster(); + assertEquals(env.cluster().clusterResource().clusterId(), result2.clusterId().get()); + assertEquals(new HashSet<>(env.cluster().nodes()), new HashSet<>(result2.nodes().get())); + assertEquals(3, result2.controller().get().id()); + assertEquals(new HashSet<>(Arrays.asList(AclOperation.DESCRIBE, AclOperation.ALTER)), + result2.authorizedOperations().get()); + } + } + + @Test + public void testDescribeClusterHandleError() { + try (AdminClientUnitTestEnv env = new AdminClientUnitTestEnv(mockCluster(4, 0), + AdminClientConfig.RETRIES_CONFIG, "2")) { + env.kafkaClient().setNodeApiVersions(NodeApiVersions.create()); + + // Prepare the describe cluster response used for the first describe cluster + String errorMessage = "my error"; + env.kafkaClient().prepareResponse( + new DescribeClusterResponse(new DescribeClusterResponseData() + .setErrorCode(Errors.INVALID_REQUEST.code()) + .setErrorMessage(errorMessage))); + + final DescribeClusterResult result = env.adminClient().describeCluster(); + TestUtils.assertFutureThrows(result.clusterId(), + InvalidRequestException.class, errorMessage); + TestUtils.assertFutureThrows(result.controller(), + InvalidRequestException.class, errorMessage); + TestUtils.assertFutureThrows(result.nodes(), + InvalidRequestException.class, errorMessage); + TestUtils.assertFutureThrows(result.authorizedOperations(), + InvalidRequestException.class, errorMessage); + } + } + + private static DescribeClusterResponse prepareDescribeClusterResponse( + int throttleTimeMs, + Collection brokers, + String clusterId, + int controllerId, + int clusterAuthorizedOperations + ) { + DescribeClusterResponseData data = new DescribeClusterResponseData() + .setErrorCode(Errors.NONE.code()) + .setThrottleTimeMs(throttleTimeMs) + .setControllerId(controllerId) + .setClusterId(clusterId) + .setClusterAuthorizedOperations(clusterAuthorizedOperations); + + brokers.forEach(broker -> + data.brokers().add(new DescribeClusterBroker() + .setHost(broker.host()) + .setPort(broker.port()) + .setBrokerId(broker.id()) + .setRack(broker.rack()))); + + return new DescribeClusterResponse(data); + } + + @Test + public void testDescribeClusterFailBack() throws Exception { + try (AdminClientUnitTestEnv env = new AdminClientUnitTestEnv(mockCluster(4, 0), + AdminClientConfig.RETRIES_CONFIG, "2")) { + env.kafkaClient().setNodeApiVersions(NodeApiVersions.create()); + + // Reject the describe cluster request with an unsupported exception + env.kafkaClient().prepareUnsupportedVersionResponse( + request -> request instanceof DescribeClusterRequest); + + // Prepare the metadata response used for the first describe cluster + env.kafkaClient().prepareResponse( + RequestTestUtils.metadataResponse( + 0, + env.cluster().nodes(), + env.cluster().clusterResource().clusterId(), + 2, + Collections.emptyList(), + MetadataResponse.AUTHORIZED_OPERATIONS_OMITTED, + ApiKeys.METADATA.latestVersion())); + + final 
DescribeClusterResult result = env.adminClient().describeCluster(); + assertEquals(env.cluster().clusterResource().clusterId(), result.clusterId().get()); + assertEquals(new HashSet<>(env.cluster().nodes()), new HashSet<>(result.nodes().get())); + assertEquals(2, result.controller().get().id()); + assertNull(result.authorizedOperations().get()); + } + } + + @Test + public void testListConsumerGroups() throws Exception { + try (AdminClientUnitTestEnv env = new AdminClientUnitTestEnv(mockCluster(4, 0), + AdminClientConfig.RETRIES_CONFIG, "2")) { + env.kafkaClient().setNodeApiVersions(NodeApiVersions.create()); + + // Empty metadata response should be retried + env.kafkaClient().prepareResponse( + RequestTestUtils.metadataResponse( + Collections.emptyList(), + env.cluster().clusterResource().clusterId(), + -1, + Collections.emptyList())); + + env.kafkaClient().prepareResponse( + RequestTestUtils.metadataResponse( + env.cluster().nodes(), + env.cluster().clusterResource().clusterId(), + env.cluster().controller().id(), + Collections.emptyList())); + + env.kafkaClient().prepareResponseFrom( + new ListGroupsResponse( + new ListGroupsResponseData() + .setErrorCode(Errors.NONE.code()) + .setGroups(Arrays.asList( + new ListGroupsResponseData.ListedGroup() + .setGroupId("group-1") + .setProtocolType(ConsumerProtocol.PROTOCOL_TYPE) + .setGroupState("Stable"), + new ListGroupsResponseData.ListedGroup() + .setGroupId("group-connect-1") + .setProtocolType("connector") + .setGroupState("Stable") + ))), + env.cluster().nodeById(0)); + + // handle retriable errors + env.kafkaClient().prepareResponseFrom( + new ListGroupsResponse( + new ListGroupsResponseData() + .setErrorCode(Errors.COORDINATOR_NOT_AVAILABLE.code()) + .setGroups(Collections.emptyList()) + ), + env.cluster().nodeById(1)); + env.kafkaClient().prepareResponseFrom( + new ListGroupsResponse( + new ListGroupsResponseData() + .setErrorCode(Errors.COORDINATOR_LOAD_IN_PROGRESS.code()) + .setGroups(Collections.emptyList()) + ), + env.cluster().nodeById(1)); + env.kafkaClient().prepareResponseFrom( + new ListGroupsResponse( + new ListGroupsResponseData() + .setErrorCode(Errors.NONE.code()) + .setGroups(Arrays.asList( + new ListGroupsResponseData.ListedGroup() + .setGroupId("group-2") + .setProtocolType(ConsumerProtocol.PROTOCOL_TYPE) + .setGroupState("Stable"), + new ListGroupsResponseData.ListedGroup() + .setGroupId("group-connect-2") + .setProtocolType("connector") + .setGroupState("Stable") + ))), + env.cluster().nodeById(1)); + + env.kafkaClient().prepareResponseFrom( + new ListGroupsResponse( + new ListGroupsResponseData() + .setErrorCode(Errors.NONE.code()) + .setGroups(Arrays.asList( + new ListGroupsResponseData.ListedGroup() + .setGroupId("group-3") + .setProtocolType(ConsumerProtocol.PROTOCOL_TYPE) + .setGroupState("Stable"), + new ListGroupsResponseData.ListedGroup() + .setGroupId("group-connect-3") + .setProtocolType("connector") + .setGroupState("Stable") + ))), + env.cluster().nodeById(2)); + + // fatal error + env.kafkaClient().prepareResponseFrom( + new ListGroupsResponse( + new ListGroupsResponseData() + .setErrorCode(Errors.UNKNOWN_SERVER_ERROR.code()) + .setGroups(Collections.emptyList())), + env.cluster().nodeById(3)); + + final ListConsumerGroupsResult result = env.adminClient().listConsumerGroups(); + TestUtils.assertFutureError(result.all(), UnknownServerException.class); + + Collection listings = result.valid().get(); + assertEquals(3, listings.size()); + + Set groupIds = new HashSet<>(); + for (ConsumerGroupListing listing : 
listings) { + groupIds.add(listing.groupId()); + assertTrue(listing.state().isPresent()); + } + + assertEquals(Utils.mkSet("group-1", "group-2", "group-3"), groupIds); + assertEquals(1, result.errors().get().size()); + } + } + + @Test + public void testListConsumerGroupsMetadataFailure() throws Exception { + final Cluster cluster = mockCluster(3, 0); + final Time time = new MockTime(); + + try (AdminClientUnitTestEnv env = new AdminClientUnitTestEnv(time, cluster, + AdminClientConfig.RETRIES_CONFIG, "0")) { + env.kafkaClient().setNodeApiVersions(NodeApiVersions.create()); + + // Empty metadata causes the request to fail since we have no list of brokers + // to send the ListGroups requests to + env.kafkaClient().prepareResponse( + RequestTestUtils.metadataResponse( + Collections.emptyList(), + env.cluster().clusterResource().clusterId(), + -1, + Collections.emptyList())); + + final ListConsumerGroupsResult result = env.adminClient().listConsumerGroups(); + TestUtils.assertFutureError(result.all(), KafkaException.class); + } + } + + @Test + public void testListConsumerGroupsWithStates() throws Exception { + try (AdminClientUnitTestEnv env = new AdminClientUnitTestEnv(mockCluster(1, 0))) { + env.kafkaClient().setNodeApiVersions(NodeApiVersions.create()); + + env.kafkaClient().prepareResponse(prepareMetadataResponse(env.cluster(), Errors.NONE)); + + env.kafkaClient().prepareResponseFrom( + new ListGroupsResponse(new ListGroupsResponseData() + .setErrorCode(Errors.NONE.code()) + .setGroups(Arrays.asList( + new ListGroupsResponseData.ListedGroup() + .setGroupId("group-1") + .setProtocolType(ConsumerProtocol.PROTOCOL_TYPE) + .setGroupState("Stable"), + new ListGroupsResponseData.ListedGroup() + .setGroupId("group-2") + .setGroupState("Empty")))), + env.cluster().nodeById(0)); + + final ListConsumerGroupsOptions options = new ListConsumerGroupsOptions(); + final ListConsumerGroupsResult result = env.adminClient().listConsumerGroups(options); + Collection listings = result.valid().get(); + + assertEquals(2, listings.size()); + List expected = new ArrayList<>(); + expected.add(new ConsumerGroupListing("group-2", true, Optional.of(ConsumerGroupState.EMPTY))); + expected.add(new ConsumerGroupListing("group-1", false, Optional.of(ConsumerGroupState.STABLE))); + assertEquals(expected, listings); + assertEquals(0, result.errors().get().size()); + } + } + + @Test + public void testListConsumerGroupsWithStatesOlderBrokerVersion() throws Exception { + ApiVersion listGroupV3 = new ApiVersion() + .setApiKey(ApiKeys.LIST_GROUPS.id) + .setMinVersion((short) 0) + .setMaxVersion((short) 3); + try (AdminClientUnitTestEnv env = new AdminClientUnitTestEnv(mockCluster(1, 0))) { + env.kafkaClient().setNodeApiVersions(NodeApiVersions.create(Collections.singletonList(listGroupV3))); + + env.kafkaClient().prepareResponse(prepareMetadataResponse(env.cluster(), Errors.NONE)); + + // Check we can list groups with older broker if we don't specify states + env.kafkaClient().prepareResponseFrom( + new ListGroupsResponse(new ListGroupsResponseData() + .setErrorCode(Errors.NONE.code()) + .setGroups(Collections.singletonList( + new ListGroupsResponseData.ListedGroup() + .setGroupId("group-1") + .setProtocolType(ConsumerProtocol.PROTOCOL_TYPE)))), + env.cluster().nodeById(0)); + ListConsumerGroupsOptions options = new ListConsumerGroupsOptions(); + ListConsumerGroupsResult result = env.adminClient().listConsumerGroups(options); + Collection listing = result.all().get(); + assertEquals(1, listing.size()); + List expected = 
Collections.singletonList(new ConsumerGroupListing("group-1", false, Optional.empty())); + assertEquals(expected, listing); + + // But we cannot set a state filter with older broker + env.kafkaClient().prepareResponse(prepareMetadataResponse(env.cluster(), Errors.NONE)); + env.kafkaClient().prepareUnsupportedVersionResponse( + body -> body instanceof ListGroupsRequest); + + options = new ListConsumerGroupsOptions().inStates(singleton(ConsumerGroupState.STABLE)); + result = env.adminClient().listConsumerGroups(options); + TestUtils.assertFutureThrows(result.all(), UnsupportedVersionException.class); + } + } + + @Test + public void testOffsetCommitNumRetries() throws Exception { + final Cluster cluster = mockCluster(3, 0); + final Time time = new MockTime(); + + try (AdminClientUnitTestEnv env = new AdminClientUnitTestEnv(time, cluster, + AdminClientConfig.RETRIES_CONFIG, "0")) { + env.kafkaClient().setNodeApiVersions(NodeApiVersions.create()); + + final TopicPartition tp1 = new TopicPartition("foo", 0); + + env.kafkaClient().prepareResponse(prepareFindCoordinatorResponse(Errors.NONE, env.cluster().controller())); + env.kafkaClient().prepareResponse(prepareOffsetCommitResponse(tp1, Errors.NOT_COORDINATOR)); + env.kafkaClient().prepareResponse(prepareFindCoordinatorResponse(Errors.NONE, env.cluster().controller())); + + Map offsets = new HashMap<>(); + offsets.put(tp1, new OffsetAndMetadata(123L)); + final AlterConsumerGroupOffsetsResult result = env.adminClient().alterConsumerGroupOffsets(GROUP_ID, offsets); + + TestUtils.assertFutureError(result.all(), TimeoutException.class); + } + } + + @Test + public void testOffsetCommitWithMultipleErrors() throws Exception { + final Cluster cluster = mockCluster(3, 0); + final Time time = new MockTime(); + + try (AdminClientUnitTestEnv env = new AdminClientUnitTestEnv(time, cluster, + AdminClientConfig.RETRIES_CONFIG, "0")) { + env.kafkaClient().setNodeApiVersions(NodeApiVersions.create()); + + final TopicPartition foo0 = new TopicPartition("foo", 0); + final TopicPartition foo1 = new TopicPartition("foo", 1); + + env.kafkaClient().prepareResponse( + prepareFindCoordinatorResponse(Errors.NONE, env.cluster().controller())); + + Map responseData = new HashMap<>(); + responseData.put(foo0, Errors.NONE); + responseData.put(foo1, Errors.UNKNOWN_TOPIC_OR_PARTITION); + env.kafkaClient().prepareResponse(new OffsetCommitResponse(0, responseData)); + + Map offsets = new HashMap<>(); + offsets.put(foo0, new OffsetAndMetadata(123L)); + offsets.put(foo1, new OffsetAndMetadata(456L)); + final AlterConsumerGroupOffsetsResult result = env.adminClient() + .alterConsumerGroupOffsets(GROUP_ID, offsets); + + assertNull(result.partitionResult(foo0).get()); + TestUtils.assertFutureError(result.partitionResult(foo1), UnknownTopicOrPartitionException.class); + + TestUtils.assertFutureError(result.all(), UnknownTopicOrPartitionException.class); + } + } + + @Test + public void testOffsetCommitRetryBackoff() throws Exception { + MockTime time = new MockTime(); + int retryBackoff = 100; + + try (final AdminClientUnitTestEnv env = new AdminClientUnitTestEnv(time, + mockCluster(3, 0), + newStrMap(AdminClientConfig.RETRY_BACKOFF_MS_CONFIG, "" + retryBackoff))) { + MockClient mockClient = env.kafkaClient(); + + mockClient.setNodeApiVersions(NodeApiVersions.create()); + + AtomicLong firstAttemptTime = new AtomicLong(0); + AtomicLong secondAttemptTime = new AtomicLong(0); + + final TopicPartition tp1 = new TopicPartition("foo", 0); + + 
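+            // The first commit attempt fails with NOT_COORDINATOR and is retried; the recorded attempt times verify the configured backoff.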
mockClient.prepareResponse(prepareFindCoordinatorResponse(Errors.NONE, env.cluster().controller())); + mockClient.prepareResponse(body -> { + firstAttemptTime.set(time.milliseconds()); + return true; + }, prepareOffsetCommitResponse(tp1, Errors.NOT_COORDINATOR)); + + + mockClient.prepareResponse(prepareFindCoordinatorResponse(Errors.NONE, env.cluster().controller())); + mockClient.prepareResponse(body -> { + secondAttemptTime.set(time.milliseconds()); + return true; + }, prepareOffsetCommitResponse(tp1, Errors.NONE)); + + + Map offsets = new HashMap<>(); + offsets.put(tp1, new OffsetAndMetadata(123L)); + final KafkaFuture future = env.adminClient().alterConsumerGroupOffsets(GROUP_ID, offsets).all(); + + TestUtils.waitForCondition(() -> mockClient.numAwaitingResponses() == 1, "Failed awaiting CommitOffsets first request failure"); + TestUtils.waitForCondition(() -> ((KafkaAdminClient) env.adminClient()).numPendingCalls() == 1, "Failed to add retry CommitOffsets call on first failure"); + time.sleep(retryBackoff); + + future.get(); + + long actualRetryBackoff = secondAttemptTime.get() - firstAttemptTime.get(); + assertEquals(retryBackoff, actualRetryBackoff, "CommitOffsets retry did not await expected backoff"); + } + } + + @Test + public void testDescribeConsumerGroupNumRetries() throws Exception { + final Cluster cluster = mockCluster(3, 0); + final Time time = new MockTime(); + + try (AdminClientUnitTestEnv env = new AdminClientUnitTestEnv(time, cluster, + AdminClientConfig.RETRIES_CONFIG, "0")) { + env.kafkaClient().setNodeApiVersions(NodeApiVersions.create()); + + env.kafkaClient().prepareResponse(prepareFindCoordinatorResponse(Errors.NONE, env.cluster().controller())); + + DescribeGroupsResponseData data = new DescribeGroupsResponseData(); + + data.groups().add(DescribeGroupsResponse.groupMetadata( + GROUP_ID, + Errors.NOT_COORDINATOR, + "", + "", + "", + Collections.emptyList(), + Collections.emptySet())); + env.kafkaClient().prepareResponse(new DescribeGroupsResponse(data)); + env.kafkaClient().prepareResponse(prepareFindCoordinatorResponse(Errors.NONE, env.cluster().controller())); + + final DescribeConsumerGroupsResult result = env.adminClient().describeConsumerGroups(singletonList(GROUP_ID)); + + TestUtils.assertFutureError(result.all(), TimeoutException.class); + } + } + + @Test + public void testDescribeConsumerGroupRetryBackoff() throws Exception { + MockTime time = new MockTime(); + int retryBackoff = 100; + + try (final AdminClientUnitTestEnv env = new AdminClientUnitTestEnv(time, + mockCluster(3, 0), + newStrMap(AdminClientConfig.RETRY_BACKOFF_MS_CONFIG, "" + retryBackoff))) { + MockClient mockClient = env.kafkaClient(); + + mockClient.setNodeApiVersions(NodeApiVersions.create()); + + AtomicLong firstAttemptTime = new AtomicLong(0); + AtomicLong secondAttemptTime = new AtomicLong(0); + + mockClient.prepareResponse(prepareFindCoordinatorResponse(Errors.NONE, env.cluster().controller())); + + DescribeGroupsResponseData data = new DescribeGroupsResponseData(); + data.groups().add(DescribeGroupsResponse.groupMetadata( + GROUP_ID, + Errors.NOT_COORDINATOR, + "", + "", + "", + Collections.emptyList(), + Collections.emptySet())); + + mockClient.prepareResponse(body -> { + firstAttemptTime.set(time.milliseconds()); + return true; + }, new DescribeGroupsResponse(data)); + + mockClient.prepareResponse(prepareFindCoordinatorResponse(Errors.NONE, env.cluster().controller())); + + data = new DescribeGroupsResponseData(); + data.groups().add(DescribeGroupsResponse.groupMetadata( + 
GROUP_ID, + Errors.NONE, + "", + ConsumerProtocol.PROTOCOL_TYPE, + "", + Collections.emptyList(), + Collections.emptySet())); + + mockClient.prepareResponse(body -> { + secondAttemptTime.set(time.milliseconds()); + return true; + }, new DescribeGroupsResponse(data)); + + final KafkaFuture> future = + env.adminClient().describeConsumerGroups(singletonList(GROUP_ID)).all(); + + TestUtils.waitForCondition(() -> mockClient.numAwaitingResponses() == 1, "Failed awaiting DescribeConsumerGroup first request failure"); + TestUtils.waitForCondition(() -> ((KafkaAdminClient) env.adminClient()).numPendingCalls() == 1, "Failed to add retry DescribeConsumerGroup call on first failure"); + time.sleep(retryBackoff); + + future.get(); + + long actualRetryBackoff = secondAttemptTime.get() - firstAttemptTime.get(); + assertEquals(retryBackoff, actualRetryBackoff, "DescribeConsumerGroup retry did not await expected backoff!"); + } + } + + + @Test + public void testDescribeConsumerGroups() throws Exception { + try (AdminClientUnitTestEnv env = new AdminClientUnitTestEnv(mockCluster(1, 0))) { + env.kafkaClient().setNodeApiVersions(NodeApiVersions.create()); + + // Retriable FindCoordinatorResponse errors should be retried + env.kafkaClient().prepareResponse(prepareFindCoordinatorResponse(Errors.COORDINATOR_NOT_AVAILABLE, Node.noNode())); + env.kafkaClient().prepareResponse(prepareFindCoordinatorResponse(Errors.COORDINATOR_LOAD_IN_PROGRESS, Node.noNode())); + + env.kafkaClient().prepareResponse(prepareFindCoordinatorResponse(Errors.NONE, env.cluster().controller())); + + DescribeGroupsResponseData data = new DescribeGroupsResponseData(); + + //Retriable errors should be retried + data.groups().add(DescribeGroupsResponse.groupMetadata( + GROUP_ID, + Errors.COORDINATOR_LOAD_IN_PROGRESS, + "", + "", + "", + Collections.emptyList(), + Collections.emptySet())); + env.kafkaClient().prepareResponse(new DescribeGroupsResponse(data)); + + /* + * We need to return two responses here, one with NOT_COORDINATOR error when calling describe consumer group + * api using coordinator that has moved. This will retry whole operation. So we need to again respond with a + * FindCoordinatorResponse. 
+ * + * And the same reason for COORDINATOR_NOT_AVAILABLE error response + */ + data = new DescribeGroupsResponseData(); + data.groups().add(DescribeGroupsResponse.groupMetadata( + GROUP_ID, + Errors.NOT_COORDINATOR, + "", + "", + "", + Collections.emptyList(), + Collections.emptySet())); + env.kafkaClient().prepareResponse(new DescribeGroupsResponse(data)); + env.kafkaClient().prepareResponse(prepareFindCoordinatorResponse(Errors.NONE, env.cluster().controller())); + + data = new DescribeGroupsResponseData(); + data.groups().add(DescribeGroupsResponse.groupMetadata( + GROUP_ID, + Errors.COORDINATOR_NOT_AVAILABLE, + "", + "", + "", + Collections.emptyList(), + Collections.emptySet())); + env.kafkaClient().prepareResponse(new DescribeGroupsResponse(data)); + env.kafkaClient().prepareResponse(prepareFindCoordinatorResponse(Errors.NONE, env.cluster().controller())); + + data = new DescribeGroupsResponseData(); + TopicPartition myTopicPartition0 = new TopicPartition("my_topic", 0); + TopicPartition myTopicPartition1 = new TopicPartition("my_topic", 1); + TopicPartition myTopicPartition2 = new TopicPartition("my_topic", 2); + + final List topicPartitions = new ArrayList<>(); + topicPartitions.add(0, myTopicPartition0); + topicPartitions.add(1, myTopicPartition1); + topicPartitions.add(2, myTopicPartition2); + + final ByteBuffer memberAssignment = ConsumerProtocol.serializeAssignment(new ConsumerPartitionAssignor.Assignment(topicPartitions)); + byte[] memberAssignmentBytes = new byte[memberAssignment.remaining()]; + memberAssignment.get(memberAssignmentBytes); + + DescribedGroupMember memberOne = DescribeGroupsResponse.groupMember("0", "instance1", "clientId0", "clientHost", memberAssignmentBytes, null); + DescribedGroupMember memberTwo = DescribeGroupsResponse.groupMember("1", "instance2", "clientId1", "clientHost", memberAssignmentBytes, null); + + List expectedMemberDescriptions = new ArrayList<>(); + expectedMemberDescriptions.add(convertToMemberDescriptions(memberOne, + new MemberAssignment(new HashSet<>(topicPartitions)))); + expectedMemberDescriptions.add(convertToMemberDescriptions(memberTwo, + new MemberAssignment(new HashSet<>(topicPartitions)))); + data.groups().add(DescribeGroupsResponse.groupMetadata( + GROUP_ID, + Errors.NONE, + "", + ConsumerProtocol.PROTOCOL_TYPE, + "", + asList(memberOne, memberTwo), + Collections.emptySet())); + + env.kafkaClient().prepareResponse(new DescribeGroupsResponse(data)); + + final DescribeConsumerGroupsResult result = env.adminClient().describeConsumerGroups(singletonList(GROUP_ID)); + final ConsumerGroupDescription groupDescription = result.describedGroups().get(GROUP_ID).get(); + + assertEquals(1, result.describedGroups().size()); + assertEquals(GROUP_ID, groupDescription.groupId()); + assertEquals(2, groupDescription.members().size()); + assertEquals(expectedMemberDescriptions, groupDescription.members()); + } + } + + @Test + public void testDescribeMultipleConsumerGroups() { + try (AdminClientUnitTestEnv env = new AdminClientUnitTestEnv(mockCluster(1, 0))) { + env.kafkaClient().setNodeApiVersions(NodeApiVersions.create()); + + env.kafkaClient().prepareResponse(prepareFindCoordinatorResponse(Errors.NONE, env.cluster().controller())); + + TopicPartition myTopicPartition0 = new TopicPartition("my_topic", 0); + TopicPartition myTopicPartition1 = new TopicPartition("my_topic", 1); + TopicPartition myTopicPartition2 = new TopicPartition("my_topic", 2); + + final List topicPartitions = new ArrayList<>(); + topicPartitions.add(0, myTopicPartition0); + 
topicPartitions.add(1, myTopicPartition1); + topicPartitions.add(2, myTopicPartition2); + + final ByteBuffer memberAssignment = ConsumerProtocol.serializeAssignment(new ConsumerPartitionAssignor.Assignment(topicPartitions)); + byte[] memberAssignmentBytes = new byte[memberAssignment.remaining()]; + memberAssignment.get(memberAssignmentBytes); + + DescribeGroupsResponseData group0Data = new DescribeGroupsResponseData(); + group0Data.groups().add(DescribeGroupsResponse.groupMetadata( + GROUP_ID, + Errors.NONE, + "", + ConsumerProtocol.PROTOCOL_TYPE, + "", + asList( + DescribeGroupsResponse.groupMember("0", null, "clientId0", "clientHost", memberAssignmentBytes, null), + DescribeGroupsResponse.groupMember("1", null, "clientId1", "clientHost", memberAssignmentBytes, null) + ), + Collections.emptySet())); + + DescribeGroupsResponseData groupConnectData = new DescribeGroupsResponseData(); + group0Data.groups().add(DescribeGroupsResponse.groupMetadata( + "group-connect-0", + Errors.NONE, + "", + "connect", + "", + asList( + DescribeGroupsResponse.groupMember("0", null, "clientId0", "clientHost", memberAssignmentBytes, null), + DescribeGroupsResponse.groupMember("1", null, "clientId1", "clientHost", memberAssignmentBytes, null) + ), + Collections.emptySet())); + + env.kafkaClient().prepareResponse(new DescribeGroupsResponse(group0Data)); + env.kafkaClient().prepareResponse(new DescribeGroupsResponse(groupConnectData)); + + Collection groups = new HashSet<>(); + groups.add(GROUP_ID); + groups.add("group-connect-0"); + final DescribeConsumerGroupsResult result = env.adminClient().describeConsumerGroups(groups); + assertEquals(2, result.describedGroups().size()); + assertEquals(groups, result.describedGroups().keySet()); + } + } + + @Test + public void testDescribeConsumerGroupsWithAuthorizedOperationsOmitted() throws Exception { + try (AdminClientUnitTestEnv env = new AdminClientUnitTestEnv(mockCluster(1, 0))) { + env.kafkaClient().setNodeApiVersions(NodeApiVersions.create()); + + env.kafkaClient().prepareResponse( + prepareFindCoordinatorResponse(Errors.NONE, env.cluster().controller())); + + DescribeGroupsResponseData data = new DescribeGroupsResponseData(); + data.groups().add(DescribeGroupsResponse.groupMetadata( + GROUP_ID, + Errors.NONE, + "", + ConsumerProtocol.PROTOCOL_TYPE, + "", + Collections.emptyList(), + MetadataResponse.AUTHORIZED_OPERATIONS_OMITTED)); + + env.kafkaClient().prepareResponse(new DescribeGroupsResponse(data)); + + final DescribeConsumerGroupsResult result = env.adminClient().describeConsumerGroups(singletonList(GROUP_ID)); + final ConsumerGroupDescription groupDescription = result.describedGroups().get(GROUP_ID).get(); + + assertNull(groupDescription.authorizedOperations()); + } + } + + @Test + public void testDescribeNonConsumerGroups() throws Exception { + try (AdminClientUnitTestEnv env = new AdminClientUnitTestEnv(mockCluster(1, 0))) { + env.kafkaClient().setNodeApiVersions(NodeApiVersions.create()); + + env.kafkaClient().prepareResponse(prepareFindCoordinatorResponse(Errors.NONE, env.cluster().controller())); + + DescribeGroupsResponseData data = new DescribeGroupsResponseData(); + + data.groups().add(DescribeGroupsResponse.groupMetadata( + GROUP_ID, + Errors.NONE, + "", + "non-consumer", + "", + asList(), + Collections.emptySet())); + + env.kafkaClient().prepareResponse(new DescribeGroupsResponse(data)); + + final DescribeConsumerGroupsResult result = env.adminClient().describeConsumerGroups(singletonList(GROUP_ID)); + + 
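+            // The group's protocolType is "non-consumer", so describeConsumerGroups should fail the future with
+            // IllegalArgumentException rather than return a group description.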
TestUtils.assertFutureError(result.describedGroups().get(GROUP_ID), IllegalArgumentException.class); + } + } + + @Test + public void testListConsumerGroupOffsetsNumRetries() throws Exception { + final Cluster cluster = mockCluster(3, 0); + final Time time = new MockTime(); + + try (AdminClientUnitTestEnv env = new AdminClientUnitTestEnv(time, cluster, + AdminClientConfig.RETRIES_CONFIG, "0")) { + env.kafkaClient().setNodeApiVersions(NodeApiVersions.create()); + + env.kafkaClient().prepareResponse(prepareFindCoordinatorResponse(Errors.NONE, env.cluster().controller())); + env.kafkaClient().prepareResponse(new OffsetFetchResponse(Errors.NOT_COORDINATOR, Collections.emptyMap())); + env.kafkaClient().prepareResponse(prepareFindCoordinatorResponse(Errors.NONE, env.cluster().controller())); + + final ListConsumerGroupOffsetsResult result = env.adminClient().listConsumerGroupOffsets(GROUP_ID); + + + TestUtils.assertFutureError(result.partitionsToOffsetAndMetadata(), TimeoutException.class); + } + } + + @Test + public void testListConsumerGroupOffsetsRetryBackoff() throws Exception { + MockTime time = new MockTime(); + int retryBackoff = 100; + + try (final AdminClientUnitTestEnv env = new AdminClientUnitTestEnv(time, + mockCluster(3, 0), + newStrMap(AdminClientConfig.RETRY_BACKOFF_MS_CONFIG, "" + retryBackoff))) { + MockClient mockClient = env.kafkaClient(); + + mockClient.setNodeApiVersions(NodeApiVersions.create()); + + AtomicLong firstAttemptTime = new AtomicLong(0); + AtomicLong secondAttemptTime = new AtomicLong(0); + + mockClient.prepareResponse(prepareFindCoordinatorResponse(Errors.NONE, env.cluster().controller())); + mockClient.prepareResponse(body -> { + firstAttemptTime.set(time.milliseconds()); + return true; + }, new OffsetFetchResponse(Errors.NOT_COORDINATOR, Collections.emptyMap())); + + mockClient.prepareResponse(prepareFindCoordinatorResponse(Errors.NONE, env.cluster().controller())); + + mockClient.prepareResponse(body -> { + secondAttemptTime.set(time.milliseconds()); + return true; + }, new OffsetFetchResponse(Errors.NONE, Collections.emptyMap())); + + final KafkaFuture> future = env.adminClient().listConsumerGroupOffsets("group-0").partitionsToOffsetAndMetadata(); + + TestUtils.waitForCondition(() -> mockClient.numAwaitingResponses() == 1, "Failed awaiting ListConsumerGroupOffsets first request failure"); + TestUtils.waitForCondition(() -> ((KafkaAdminClient) env.adminClient()).numPendingCalls() == 1, "Failed to add retry ListConsumerGroupOffsets call on first failure"); + time.sleep(retryBackoff); + + future.get(); + + long actualRetryBackoff = secondAttemptTime.get() - firstAttemptTime.get(); + assertEquals(retryBackoff, actualRetryBackoff, "ListConsumerGroupOffsets retry did not await expected backoff!"); + } + } + + @Test + public void testListConsumerGroupOffsetsRetriableErrors() throws Exception { + // Retriable errors should be retried + + try (AdminClientUnitTestEnv env = new AdminClientUnitTestEnv(mockCluster(1, 0))) { + env.kafkaClient().setNodeApiVersions(NodeApiVersions.create()); + + env.kafkaClient().prepareResponse( + prepareFindCoordinatorResponse(Errors.NONE, env.cluster().controller())); + + env.kafkaClient().prepareResponse( + new OffsetFetchResponse(Errors.COORDINATOR_LOAD_IN_PROGRESS, Collections.emptyMap())); + /* + * We need to return two responses here, one for NOT_COORDINATOR call when calling list consumer offsets + * api using coordinator that has moved. This will retry whole operation. So we need to again respond with a + * FindCoordinatorResponse. 
+ * + * And the same reason for the following COORDINATOR_NOT_AVAILABLE error response + */ + env.kafkaClient().prepareResponse( + new OffsetFetchResponse(Errors.NOT_COORDINATOR, Collections.emptyMap())); + + env.kafkaClient().prepareResponse( + prepareFindCoordinatorResponse(Errors.NONE, env.cluster().controller())); + + env.kafkaClient().prepareResponse( + new OffsetFetchResponse(Errors.COORDINATOR_NOT_AVAILABLE, Collections.emptyMap())); + + env.kafkaClient().prepareResponse( + prepareFindCoordinatorResponse(Errors.NONE, env.cluster().controller())); + + env.kafkaClient().prepareResponse( + new OffsetFetchResponse(Errors.NONE, Collections.emptyMap())); + + final ListConsumerGroupOffsetsResult errorResult1 = env.adminClient().listConsumerGroupOffsets(GROUP_ID); + + assertEquals(Collections.emptyMap(), errorResult1.partitionsToOffsetAndMetadata().get()); + } + } + + @Test + public void testListConsumerGroupOffsetsNonRetriableErrors() throws Exception { + // Non-retriable errors throw an exception + final List nonRetriableErrors = Arrays.asList( + Errors.GROUP_AUTHORIZATION_FAILED, Errors.INVALID_GROUP_ID, Errors.GROUP_ID_NOT_FOUND); + + try (AdminClientUnitTestEnv env = new AdminClientUnitTestEnv(mockCluster(1, 0))) { + env.kafkaClient().setNodeApiVersions(NodeApiVersions.create()); + + for (Errors error : nonRetriableErrors) { + env.kafkaClient().prepareResponse( + prepareFindCoordinatorResponse(Errors.NONE, env.cluster().controller())); + + env.kafkaClient().prepareResponse( + new OffsetFetchResponse(error, Collections.emptyMap())); + + ListConsumerGroupOffsetsResult errorResult = env.adminClient().listConsumerGroupOffsets(GROUP_ID); + + TestUtils.assertFutureError(errorResult.partitionsToOffsetAndMetadata(), error.exception().getClass()); + } + } + } + + @Test + public void testListConsumerGroupOffsets() throws Exception { + try (AdminClientUnitTestEnv env = new AdminClientUnitTestEnv(mockCluster(1, 0))) { + env.kafkaClient().setNodeApiVersions(NodeApiVersions.create()); + + // Retriable FindCoordinatorResponse errors should be retried + env.kafkaClient().prepareResponse(prepareFindCoordinatorResponse(Errors.COORDINATOR_NOT_AVAILABLE, Node.noNode())); + + env.kafkaClient().prepareResponse(prepareFindCoordinatorResponse(Errors.NONE, env.cluster().controller())); + + // Retriable errors should be retried + env.kafkaClient().prepareResponse(new OffsetFetchResponse(Errors.COORDINATOR_LOAD_IN_PROGRESS, Collections.emptyMap())); + + /* + * We need to return two responses here, one for NOT_COORDINATOR error when calling list consumer group offsets + * api using coordinator that has moved. This will retry whole operation. So we need to again respond with a + * FindCoordinatorResponse. 
+ * + * And the same reason for the following COORDINATOR_NOT_AVAILABLE error response + */ + env.kafkaClient().prepareResponse(new OffsetFetchResponse(Errors.NOT_COORDINATOR, Collections.emptyMap())); + env.kafkaClient().prepareResponse(prepareFindCoordinatorResponse(Errors.NONE, env.cluster().controller())); + + env.kafkaClient().prepareResponse(new OffsetFetchResponse(Errors.COORDINATOR_NOT_AVAILABLE, Collections.emptyMap())); + env.kafkaClient().prepareResponse(prepareFindCoordinatorResponse(Errors.NONE, env.cluster().controller())); + + TopicPartition myTopicPartition0 = new TopicPartition("my_topic", 0); + TopicPartition myTopicPartition1 = new TopicPartition("my_topic", 1); + TopicPartition myTopicPartition2 = new TopicPartition("my_topic", 2); + TopicPartition myTopicPartition3 = new TopicPartition("my_topic", 3); + + final Map responseData = new HashMap<>(); + responseData.put(myTopicPartition0, new OffsetFetchResponse.PartitionData(10, + Optional.empty(), "", Errors.NONE)); + responseData.put(myTopicPartition1, new OffsetFetchResponse.PartitionData(0, + Optional.empty(), "", Errors.NONE)); + responseData.put(myTopicPartition2, new OffsetFetchResponse.PartitionData(20, + Optional.empty(), "", Errors.NONE)); + responseData.put(myTopicPartition3, new OffsetFetchResponse.PartitionData(OffsetFetchResponse.INVALID_OFFSET, + Optional.empty(), "", Errors.NONE)); + env.kafkaClient().prepareResponse(new OffsetFetchResponse(Errors.NONE, responseData)); + + final ListConsumerGroupOffsetsResult result = env.adminClient().listConsumerGroupOffsets(GROUP_ID); + final Map partitionToOffsetAndMetadata = result.partitionsToOffsetAndMetadata().get(); + + assertEquals(4, partitionToOffsetAndMetadata.size()); + assertEquals(10, partitionToOffsetAndMetadata.get(myTopicPartition0).offset()); + assertEquals(0, partitionToOffsetAndMetadata.get(myTopicPartition1).offset()); + assertEquals(20, partitionToOffsetAndMetadata.get(myTopicPartition2).offset()); + assertTrue(partitionToOffsetAndMetadata.containsKey(myTopicPartition3)); + assertNull(partitionToOffsetAndMetadata.get(myTopicPartition3)); + } + } + + @Test + public void testDeleteConsumerGroupsNumRetries() throws Exception { + final Cluster cluster = mockCluster(3, 0); + final Time time = new MockTime(); + final List groupIds = singletonList("groupId"); + + try (AdminClientUnitTestEnv env = new AdminClientUnitTestEnv(time, cluster, + AdminClientConfig.RETRIES_CONFIG, "0")) { + env.kafkaClient().setNodeApiVersions(NodeApiVersions.create()); + + env.kafkaClient().prepareResponse(prepareFindCoordinatorResponse(Errors.NONE, env.cluster().controller())); + final DeletableGroupResultCollection validResponse = new DeletableGroupResultCollection(); + validResponse.add(new DeletableGroupResult() + .setGroupId("groupId") + .setErrorCode(Errors.NOT_COORDINATOR.code())); + env.kafkaClient().prepareResponse(new DeleteGroupsResponse( + new DeleteGroupsResponseData() + .setResults(validResponse) + )); + env.kafkaClient().prepareResponse(prepareFindCoordinatorResponse(Errors.NONE, env.cluster().controller())); + + final DeleteConsumerGroupsResult result = env.adminClient().deleteConsumerGroups(groupIds); + + TestUtils.assertFutureError(result.all(), TimeoutException.class); + } + } + + @Test + public void testDeleteConsumerGroupsRetryBackoff() throws Exception { + MockTime time = new MockTime(); + int retryBackoff = 100; + final List groupIds = singletonList(GROUP_ID); + + try (final AdminClientUnitTestEnv env = new AdminClientUnitTestEnv(time, + mockCluster(3, 0), + 
newStrMap(AdminClientConfig.RETRY_BACKOFF_MS_CONFIG, "" + retryBackoff))) { + MockClient mockClient = env.kafkaClient(); + + mockClient.setNodeApiVersions(NodeApiVersions.create()); + + AtomicLong firstAttemptTime = new AtomicLong(0); + AtomicLong secondAttemptTime = new AtomicLong(0); + + mockClient.prepareResponse(prepareFindCoordinatorResponse(Errors.NONE, env.cluster().controller())); + + DeletableGroupResultCollection validResponse = new DeletableGroupResultCollection(); + validResponse.add(new DeletableGroupResult() + .setGroupId(GROUP_ID) + .setErrorCode(Errors.NOT_COORDINATOR.code())); + + + mockClient.prepareResponse(body -> { + firstAttemptTime.set(time.milliseconds()); + return true; + }, new DeleteGroupsResponse(new DeleteGroupsResponseData().setResults(validResponse))); + + mockClient.prepareResponse(prepareFindCoordinatorResponse(Errors.NONE, env.cluster().controller())); + + validResponse = new DeletableGroupResultCollection(); + validResponse.add(new DeletableGroupResult() + .setGroupId(GROUP_ID) + .setErrorCode(Errors.NONE.code())); + + mockClient.prepareResponse(body -> { + secondAttemptTime.set(time.milliseconds()); + return true; + }, new DeleteGroupsResponse(new DeleteGroupsResponseData().setResults(validResponse))); + + final KafkaFuture future = env.adminClient().deleteConsumerGroups(groupIds).all(); + + TestUtils.waitForCondition(() -> mockClient.numAwaitingResponses() == 1, "Failed awaiting DeleteConsumerGroups first request failure"); + TestUtils.waitForCondition(() -> ((KafkaAdminClient) env.adminClient()).numPendingCalls() == 1, "Failed to add retry DeleteConsumerGroups call on first failure"); + time.sleep(retryBackoff); + + future.get(); + + long actualRetryBackoff = secondAttemptTime.get() - firstAttemptTime.get(); + assertEquals(retryBackoff, actualRetryBackoff, "DeleteConsumerGroups retry did not await expected backoff!"); + } + } + + @Test + public void testDeleteConsumerGroupsWithOlderBroker() throws Exception { + final List groupIds = singletonList("groupId"); + ApiVersion findCoordinatorV3 = new ApiVersion() + .setApiKey(ApiKeys.FIND_COORDINATOR.id) + .setMinVersion((short) 0) + .setMaxVersion((short) 3); + ApiVersion describeGroups = new ApiVersion() + .setApiKey(ApiKeys.DESCRIBE_GROUPS.id) + .setMinVersion((short) 0) + .setMaxVersion(ApiKeys.DELETE_GROUPS.latestVersion()); + + try (AdminClientUnitTestEnv env = new AdminClientUnitTestEnv(mockCluster(1, 0))) { + env.kafkaClient().setNodeApiVersions(NodeApiVersions.create(Arrays.asList(findCoordinatorV3, describeGroups))); + + // Retriable FindCoordinatorResponse errors should be retried + env.kafkaClient().prepareResponse(prepareOldFindCoordinatorResponse(Errors.COORDINATOR_NOT_AVAILABLE, Node.noNode())); + env.kafkaClient().prepareResponse(prepareOldFindCoordinatorResponse(Errors.COORDINATOR_LOAD_IN_PROGRESS, Node.noNode())); + + env.kafkaClient().prepareResponse(prepareOldFindCoordinatorResponse(Errors.NONE, env.cluster().controller())); + + final DeletableGroupResultCollection validResponse = new DeletableGroupResultCollection(); + validResponse.add(new DeletableGroupResult() + .setGroupId("groupId") + .setErrorCode(Errors.NONE.code())); + env.kafkaClient().prepareResponse(new DeleteGroupsResponse( + new DeleteGroupsResponseData() + .setResults(validResponse) + )); + + final DeleteConsumerGroupsResult result = env.adminClient().deleteConsumerGroups(groupIds); + + final KafkaFuture results = result.deletedGroups().get("groupId"); + assertNull(results.get()); + + // should throw error for non-retriable 
errors + env.kafkaClient().prepareResponse( + prepareOldFindCoordinatorResponse(Errors.GROUP_AUTHORIZATION_FAILED, Node.noNode())); + + DeleteConsumerGroupsResult errorResult = env.adminClient().deleteConsumerGroups(groupIds); + TestUtils.assertFutureError(errorResult.deletedGroups().get("groupId"), GroupAuthorizationException.class); + + // Retriable errors should be retried + env.kafkaClient().prepareResponse( + prepareOldFindCoordinatorResponse(Errors.NONE, env.cluster().controller())); + + final DeletableGroupResultCollection errorResponse = new DeletableGroupResultCollection(); + errorResponse.add(new DeletableGroupResult() + .setGroupId("groupId") + .setErrorCode(Errors.COORDINATOR_LOAD_IN_PROGRESS.code()) + ); + env.kafkaClient().prepareResponse(new DeleteGroupsResponse( + new DeleteGroupsResponseData() + .setResults(errorResponse))); + + /* + * We need to return two responses here, one for NOT_COORDINATOR call when calling delete a consumer group + * api using coordinator that has moved. This will retry whole operation. So we need to again respond with a + * FindCoordinatorResponse. + * + * And the same reason for the following COORDINATOR_NOT_AVAILABLE error response + */ + + DeletableGroupResultCollection coordinatorMoved = new DeletableGroupResultCollection(); + coordinatorMoved.add(new DeletableGroupResult() + .setGroupId("groupId") + .setErrorCode(Errors.NOT_COORDINATOR.code()) + ); + + env.kafkaClient().prepareResponse(new DeleteGroupsResponse( + new DeleteGroupsResponseData() + .setResults(coordinatorMoved))); + env.kafkaClient().prepareResponse(prepareOldFindCoordinatorResponse(Errors.NONE, env.cluster().controller())); + + coordinatorMoved = new DeletableGroupResultCollection(); + coordinatorMoved.add(new DeletableGroupResult() + .setGroupId("groupId") + .setErrorCode(Errors.COORDINATOR_NOT_AVAILABLE.code()) + ); + + env.kafkaClient().prepareResponse(new DeleteGroupsResponse( + new DeleteGroupsResponseData() + .setResults(coordinatorMoved))); + env.kafkaClient().prepareResponse(prepareOldFindCoordinatorResponse(Errors.NONE, env.cluster().controller())); + + env.kafkaClient().prepareResponse(new DeleteGroupsResponse( + new DeleteGroupsResponseData() + .setResults(validResponse))); + + errorResult = env.adminClient().deleteConsumerGroups(groupIds); + + final KafkaFuture errorResults = errorResult.deletedGroups().get("groupId"); + assertNull(errorResults.get()); + } + } + + @Test + public void testDeleteMultipleConsumerGroupsWithOlderBroker() throws Exception { + final List groupIds = asList("group1", "group2"); + ApiVersion findCoordinatorV3 = new ApiVersion() + .setApiKey(ApiKeys.FIND_COORDINATOR.id) + .setMinVersion((short) 0) + .setMaxVersion((short) 3); + ApiVersion describeGroups = new ApiVersion() + .setApiKey(ApiKeys.DESCRIBE_GROUPS.id) + .setMinVersion((short) 0) + .setMaxVersion(ApiKeys.DELETE_GROUPS.latestVersion()); + + try (AdminClientUnitTestEnv env = new AdminClientUnitTestEnv(mockCluster(1, 0))) { + env.kafkaClient().setNodeApiVersions( + NodeApiVersions.create(Arrays.asList(findCoordinatorV3, describeGroups))); + + // Dummy response for MockClient to handle the UnsupportedVersionException correctly to switch from batched to un-batched + env.kafkaClient().prepareResponse(null); + // Retriable FindCoordinatorResponse errors should be retried + for (int i = 0; i < groupIds.size(); i++) { + env.kafkaClient().prepareResponse( + prepareOldFindCoordinatorResponse(Errors.COORDINATOR_NOT_AVAILABLE, Node.noNode())); + } + for (int i = 0; i < groupIds.size(); i++) { + 
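+                // With FIND_COORDINATOR capped at v3 the lookup is un-batched (one request per group id),
+                // so prepare one successful old-style response for each group.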
env.kafkaClient().prepareResponse( + prepareOldFindCoordinatorResponse(Errors.NONE, env.cluster().controller())); + } + + final DeletableGroupResultCollection validResponse = new DeletableGroupResultCollection(); + validResponse.add(new DeletableGroupResult() + .setGroupId("group1") + .setErrorCode(Errors.NONE.code())); + validResponse.add(new DeletableGroupResult() + .setGroupId("group2") + .setErrorCode(Errors.NONE.code())); + env.kafkaClient().prepareResponse(new DeleteGroupsResponse( + new DeleteGroupsResponseData() + .setResults(validResponse) + )); + + final DeleteConsumerGroupsResult result = env.adminClient() + .deleteConsumerGroups(groupIds); + + final KafkaFuture results = result.deletedGroups().get("group1"); + assertNull(results.get(5, TimeUnit.SECONDS)); + } + } + + @Test + public void testDeleteConsumerGroupOffsetsNumRetries() throws Exception { + final Cluster cluster = mockCluster(3, 0); + final Time time = new MockTime(); + + try (AdminClientUnitTestEnv env = new AdminClientUnitTestEnv(time, cluster, + AdminClientConfig.RETRIES_CONFIG, "0")) { + final TopicPartition tp1 = new TopicPartition("foo", 0); + + env.kafkaClient().setNodeApiVersions(NodeApiVersions.create()); + + env.kafkaClient().prepareResponse(prepareFindCoordinatorResponse(Errors.NONE, env.cluster().controller())); + env.kafkaClient().prepareResponse(prepareOffsetDeleteResponse(Errors.NOT_COORDINATOR)); + env.kafkaClient().prepareResponse(prepareFindCoordinatorResponse(Errors.NONE, env.cluster().controller())); + + final DeleteConsumerGroupOffsetsResult result = env.adminClient() + .deleteConsumerGroupOffsets(GROUP_ID, Stream.of(tp1).collect(Collectors.toSet())); + + TestUtils.assertFutureError(result.all(), TimeoutException.class); + } + } + + @Test + public void testDeleteConsumerGroupOffsetsRetryBackoff() throws Exception { + MockTime time = new MockTime(); + int retryBackoff = 100; + + try (final AdminClientUnitTestEnv env = new AdminClientUnitTestEnv(time, + mockCluster(3, 0), + newStrMap(AdminClientConfig.RETRY_BACKOFF_MS_CONFIG, "" + retryBackoff))) { + MockClient mockClient = env.kafkaClient(); + + mockClient.setNodeApiVersions(NodeApiVersions.create()); + + AtomicLong firstAttemptTime = new AtomicLong(0); + AtomicLong secondAttemptTime = new AtomicLong(0); + + final TopicPartition tp1 = new TopicPartition("foo", 0); + + mockClient.prepareResponse(prepareFindCoordinatorResponse(Errors.NONE, env.cluster().controller())); + + mockClient.prepareResponse(body -> { + firstAttemptTime.set(time.milliseconds()); + return true; + }, prepareOffsetDeleteResponse(Errors.NOT_COORDINATOR)); + + + mockClient.prepareResponse(prepareFindCoordinatorResponse(Errors.NONE, env.cluster().controller())); + + mockClient.prepareResponse(body -> { + secondAttemptTime.set(time.milliseconds()); + return true; + }, prepareOffsetDeleteResponse("foo", 0, Errors.NONE)); + + final KafkaFuture future = env.adminClient().deleteConsumerGroupOffsets(GROUP_ID, Stream.of(tp1).collect(Collectors.toSet())).all(); + + TestUtils.waitForCondition(() -> mockClient.numAwaitingResponses() == 1, "Failed awaiting DeleteConsumerGroupOffsets first request failure"); + TestUtils.waitForCondition(() -> ((KafkaAdminClient) env.adminClient()).numPendingCalls() == 1, "Failed to add retry DeleteConsumerGroupOffsets call on first failure"); + time.sleep(retryBackoff); + + future.get(); + + long actualRetryBackoff = secondAttemptTime.get() - firstAttemptTime.get(); + assertEquals(retryBackoff, actualRetryBackoff, "DeleteConsumerGroupOffsets retry did not await 
expected backoff!"); + } + } + + @Test + public void testDeleteConsumerGroupOffsets() throws Exception { + // Happy path + + final TopicPartition tp1 = new TopicPartition("foo", 0); + final TopicPartition tp2 = new TopicPartition("bar", 0); + final TopicPartition tp3 = new TopicPartition("foobar", 0); + + try (AdminClientUnitTestEnv env = new AdminClientUnitTestEnv(mockCluster(1, 0))) { + env.kafkaClient().setNodeApiVersions(NodeApiVersions.create()); + + env.kafkaClient().prepareResponse( + prepareFindCoordinatorResponse(Errors.NONE, env.cluster().controller())); + + env.kafkaClient().prepareResponse(new OffsetDeleteResponse( + new OffsetDeleteResponseData() + .setTopics(new OffsetDeleteResponseTopicCollection(Stream.of( + new OffsetDeleteResponseTopic() + .setName("foo") + .setPartitions(new OffsetDeleteResponsePartitionCollection(Collections.singletonList( + new OffsetDeleteResponsePartition() + .setPartitionIndex(0) + .setErrorCode(Errors.NONE.code()) + ).iterator())), + new OffsetDeleteResponseTopic() + .setName("bar") + .setPartitions(new OffsetDeleteResponsePartitionCollection(Collections.singletonList( + new OffsetDeleteResponsePartition() + .setPartitionIndex(0) + .setErrorCode(Errors.GROUP_SUBSCRIBED_TO_TOPIC.code()) + ).iterator())) + ).collect(Collectors.toList()).iterator())) + ) + ); + + final DeleteConsumerGroupOffsetsResult errorResult = env.adminClient().deleteConsumerGroupOffsets( + GROUP_ID, Stream.of(tp1, tp2).collect(Collectors.toSet())); + + assertNull(errorResult.partitionResult(tp1).get()); + TestUtils.assertFutureError(errorResult.all(), GroupSubscribedToTopicException.class); + TestUtils.assertFutureError(errorResult.partitionResult(tp2), GroupSubscribedToTopicException.class); + assertThrows(IllegalArgumentException.class, () -> errorResult.partitionResult(tp3)); + } + } + + @Test + public void testDeleteConsumerGroupOffsetsRetriableErrors() throws Exception { + // Retriable errors should be retried + + final TopicPartition tp1 = new TopicPartition("foo", 0); + + try (AdminClientUnitTestEnv env = new AdminClientUnitTestEnv(mockCluster(1, 0))) { + env.kafkaClient().setNodeApiVersions(NodeApiVersions.create()); + + env.kafkaClient().prepareResponse( + prepareFindCoordinatorResponse(Errors.NONE, env.cluster().controller())); + + env.kafkaClient().prepareResponse( + prepareOffsetDeleteResponse(Errors.COORDINATOR_LOAD_IN_PROGRESS)); + + /* + * We need to return two responses here, one for NOT_COORDINATOR call when calling delete a consumer group + * api using coordinator that has moved. This will retry whole operation. So we need to again respond with a + * FindCoordinatorResponse. 
+ * + * And the same reason for the following COORDINATOR_NOT_AVAILABLE error response + */ + env.kafkaClient().prepareResponse( + prepareOffsetDeleteResponse(Errors.NOT_COORDINATOR)); + + env.kafkaClient().prepareResponse( + prepareFindCoordinatorResponse(Errors.NONE, env.cluster().controller())); + + env.kafkaClient().prepareResponse( + prepareOffsetDeleteResponse(Errors.COORDINATOR_NOT_AVAILABLE)); + + env.kafkaClient().prepareResponse( + prepareFindCoordinatorResponse(Errors.NONE, env.cluster().controller())); + + env.kafkaClient().prepareResponse( + prepareOffsetDeleteResponse("foo", 0, Errors.NONE)); + + final DeleteConsumerGroupOffsetsResult errorResult1 = env.adminClient() + .deleteConsumerGroupOffsets(GROUP_ID, Stream.of(tp1).collect(Collectors.toSet())); + + assertNull(errorResult1.all().get()); + assertNull(errorResult1.partitionResult(tp1).get()); + } + } + + @Test + public void testDeleteConsumerGroupOffsetsNonRetriableErrors() throws Exception { + // Non-retriable errors throw an exception + + final TopicPartition tp1 = new TopicPartition("foo", 0); + final List nonRetriableErrors = Arrays.asList( + Errors.GROUP_AUTHORIZATION_FAILED, Errors.INVALID_GROUP_ID, Errors.GROUP_ID_NOT_FOUND); + + try (AdminClientUnitTestEnv env = new AdminClientUnitTestEnv(mockCluster(1, 0))) { + env.kafkaClient().setNodeApiVersions(NodeApiVersions.create()); + + for (Errors error : nonRetriableErrors) { + env.kafkaClient().prepareResponse( + prepareFindCoordinatorResponse(Errors.NONE, env.cluster().controller())); + + env.kafkaClient().prepareResponse( + prepareOffsetDeleteResponse(error)); + + DeleteConsumerGroupOffsetsResult errorResult = env.adminClient() + .deleteConsumerGroupOffsets(GROUP_ID, Stream.of(tp1).collect(Collectors.toSet())); + + TestUtils.assertFutureError(errorResult.all(), error.exception().getClass()); + TestUtils.assertFutureError(errorResult.partitionResult(tp1), error.exception().getClass()); + } + } + } + + @Test + public void testDeleteConsumerGroupOffsetsFindCoordinatorRetriableErrors() throws Exception { + // Retriable FindCoordinatorResponse errors should be retried + + final TopicPartition tp1 = new TopicPartition("foo", 0); + + try (AdminClientUnitTestEnv env = new AdminClientUnitTestEnv(mockCluster(1, 0))) { + env.kafkaClient().setNodeApiVersions(NodeApiVersions.create()); + + env.kafkaClient().prepareResponse( + prepareFindCoordinatorResponse(Errors.COORDINATOR_NOT_AVAILABLE, Node.noNode())); + env.kafkaClient().prepareResponse( + prepareFindCoordinatorResponse(Errors.COORDINATOR_LOAD_IN_PROGRESS, Node.noNode())); + + env.kafkaClient().prepareResponse( + prepareFindCoordinatorResponse(Errors.NONE, env.cluster().controller())); + + env.kafkaClient().prepareResponse( + prepareOffsetDeleteResponse("foo", 0, Errors.NONE)); + + final DeleteConsumerGroupOffsetsResult result = env.adminClient() + .deleteConsumerGroupOffsets(GROUP_ID, Stream.of(tp1).collect(Collectors.toSet())); + + assertNull(result.all().get()); + assertNull(result.partitionResult(tp1).get()); + } + } + + @Test + public void testDeleteConsumerGroupOffsetsFindCoordinatorNonRetriableErrors() throws Exception { + // Non-retriable FindCoordinatorResponse errors throw an exception + + final TopicPartition tp1 = new TopicPartition("foo", 0); + + try (AdminClientUnitTestEnv env = new AdminClientUnitTestEnv(mockCluster(1, 0))) { + env.kafkaClient().setNodeApiVersions(NodeApiVersions.create()); + + env.kafkaClient().prepareResponse( + prepareFindCoordinatorResponse(Errors.GROUP_AUTHORIZATION_FAILED, 
Node.noNode())); + + final DeleteConsumerGroupOffsetsResult errorResult = env.adminClient() + .deleteConsumerGroupOffsets(GROUP_ID, Stream.of(tp1).collect(Collectors.toSet())); + + TestUtils.assertFutureError(errorResult.all(), GroupAuthorizationException.class); + TestUtils.assertFutureError(errorResult.partitionResult(tp1), GroupAuthorizationException.class); + } + } + + @Test + public void testIncrementalAlterConfigs() throws Exception { + try (AdminClientUnitTestEnv env = mockClientEnv()) { + env.kafkaClient().setNodeApiVersions(NodeApiVersions.create()); + + //test error scenarios + IncrementalAlterConfigsResponseData responseData = new IncrementalAlterConfigsResponseData(); + responseData.responses().add(new AlterConfigsResourceResponse() + .setResourceName("") + .setResourceType(ConfigResource.Type.BROKER.id()) + .setErrorCode(Errors.CLUSTER_AUTHORIZATION_FAILED.code()) + .setErrorMessage("authorization error")); + + responseData.responses().add(new AlterConfigsResourceResponse() + .setResourceName("topic1") + .setResourceType(ConfigResource.Type.TOPIC.id()) + .setErrorCode(Errors.INVALID_REQUEST.code()) + .setErrorMessage("Config value append is not allowed for config")); + + env.kafkaClient().prepareResponse(new IncrementalAlterConfigsResponse(responseData)); + + ConfigResource brokerResource = new ConfigResource(ConfigResource.Type.BROKER, ""); + ConfigResource topicResource = new ConfigResource(ConfigResource.Type.TOPIC, "topic1"); + + AlterConfigOp alterConfigOp1 = new AlterConfigOp( + new ConfigEntry("log.segment.bytes", "1073741"), + AlterConfigOp.OpType.SET); + + AlterConfigOp alterConfigOp2 = new AlterConfigOp( + new ConfigEntry("compression.type", "gzip"), + AlterConfigOp.OpType.APPEND); + + final Map> configs = new HashMap<>(); + configs.put(brokerResource, singletonList(alterConfigOp1)); + configs.put(topicResource, singletonList(alterConfigOp2)); + + AlterConfigsResult result = env.adminClient().incrementalAlterConfigs(configs); + TestUtils.assertFutureError(result.values().get(brokerResource), ClusterAuthorizationException.class); + TestUtils.assertFutureError(result.values().get(topicResource), InvalidRequestException.class); + + // Test a call where there are no errors. 
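+            // A resource response with error code NONE lets the subsequent all().get() complete without throwing.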
+ responseData = new IncrementalAlterConfigsResponseData(); + responseData.responses().add(new AlterConfigsResourceResponse() + .setResourceName("") + .setResourceType(ConfigResource.Type.BROKER.id()) + .setErrorCode(Errors.NONE.code()) + .setErrorMessage(ApiError.NONE.message())); + + env.kafkaClient().prepareResponse(new IncrementalAlterConfigsResponse(responseData)); + env.adminClient().incrementalAlterConfigs(Collections.singletonMap(brokerResource, singletonList(alterConfigOp1))).all().get(); + } + } + + @Test + public void testRemoveMembersFromGroupNumRetries() throws Exception { + final Cluster cluster = mockCluster(3, 0); + final Time time = new MockTime(); + + try (AdminClientUnitTestEnv env = new AdminClientUnitTestEnv(time, cluster, + AdminClientConfig.RETRIES_CONFIG, "0")) { + + env.kafkaClient().setNodeApiVersions(NodeApiVersions.create()); + + env.kafkaClient().prepareResponse(prepareFindCoordinatorResponse(Errors.NONE, env.cluster().controller())); + env.kafkaClient().prepareResponse(new LeaveGroupResponse(new LeaveGroupResponseData().setErrorCode(Errors.NOT_COORDINATOR.code()))); + env.kafkaClient().prepareResponse(prepareFindCoordinatorResponse(Errors.NONE, env.cluster().controller())); + + Collection membersToRemove = Arrays.asList(new MemberToRemove("instance-1"), new MemberToRemove("instance-2")); + + final RemoveMembersFromConsumerGroupResult result = env.adminClient().removeMembersFromConsumerGroup( + GROUP_ID, new RemoveMembersFromConsumerGroupOptions(membersToRemove)); + + TestUtils.assertFutureError(result.all(), TimeoutException.class); + } + } + + @Test + public void testRemoveMembersFromGroupRetryBackoff() throws Exception { + MockTime time = new MockTime(); + int retryBackoff = 100; + + try (final AdminClientUnitTestEnv env = new AdminClientUnitTestEnv(time, + mockCluster(3, 0), + newStrMap(AdminClientConfig.RETRY_BACKOFF_MS_CONFIG, "" + retryBackoff))) { + MockClient mockClient = env.kafkaClient(); + + mockClient.setNodeApiVersions(NodeApiVersions.create()); + + AtomicLong firstAttemptTime = new AtomicLong(0); + AtomicLong secondAttemptTime = new AtomicLong(0); + + mockClient.prepareResponse(prepareFindCoordinatorResponse(Errors.NONE, env.cluster().controller())); + + env.kafkaClient().prepareResponse(body -> { + firstAttemptTime.set(time.milliseconds()); + return true; + }, new LeaveGroupResponse(new LeaveGroupResponseData().setErrorCode(Errors.NOT_COORDINATOR.code()))); + + mockClient.prepareResponse(prepareFindCoordinatorResponse(Errors.NONE, env.cluster().controller())); + + MemberResponse responseOne = new MemberResponse() + .setGroupInstanceId("instance-1") + .setErrorCode(Errors.NONE.code()); + env.kafkaClient().prepareResponse(body -> { + secondAttemptTime.set(time.milliseconds()); + return true; + }, new LeaveGroupResponse(new LeaveGroupResponseData() + .setErrorCode(Errors.NONE.code()) + .setMembers(Collections.singletonList(responseOne)))); + + Collection membersToRemove = singletonList(new MemberToRemove("instance-1")); + + final KafkaFuture future = env.adminClient().removeMembersFromConsumerGroup( + GROUP_ID, new RemoveMembersFromConsumerGroupOptions(membersToRemove)).all(); + + + TestUtils.waitForCondition(() -> mockClient.numAwaitingResponses() == 1, "Failed awaiting RemoveMembersFromGroup first request failure"); + TestUtils.waitForCondition(() -> ((KafkaAdminClient) env.adminClient()).numPendingCalls() == 1, "Failed to add retry RemoveMembersFromGroup call on first failure"); + time.sleep(retryBackoff); + + future.get(); + + long 
actualRetryBackoff = secondAttemptTime.get() - firstAttemptTime.get(); + assertEquals(retryBackoff, actualRetryBackoff, "RemoveMembersFromGroup retry did not await expected backoff!"); + } + } + + @Test + public void testRemoveMembersFromGroupRetriableErrors() throws Exception { + // Retriable errors should be retried + + try (AdminClientUnitTestEnv env = new AdminClientUnitTestEnv(mockCluster(1, 0))) { + env.kafkaClient().setNodeApiVersions(NodeApiVersions.create()); + + env.kafkaClient().prepareResponse( + prepareFindCoordinatorResponse(Errors.NONE, env.cluster().controller())); + + env.kafkaClient().prepareResponse( + new LeaveGroupResponse(new LeaveGroupResponseData() + .setErrorCode(Errors.COORDINATOR_LOAD_IN_PROGRESS.code()))); + + /* + * We need to return two responses here, one for NOT_COORDINATOR call when calling remove member + * api using coordinator that has moved. This will retry whole operation. So we need to again respond with a + * FindCoordinatorResponse. + * + * And the same reason for the following COORDINATOR_NOT_AVAILABLE error response + */ + env.kafkaClient().prepareResponse( + new LeaveGroupResponse(new LeaveGroupResponseData() + .setErrorCode(Errors.NOT_COORDINATOR.code()))); + + env.kafkaClient().prepareResponse( + prepareFindCoordinatorResponse(Errors.NONE, env.cluster().controller())); + + env.kafkaClient().prepareResponse( + new LeaveGroupResponse(new LeaveGroupResponseData() + .setErrorCode(Errors.COORDINATOR_NOT_AVAILABLE.code()))); + + env.kafkaClient().prepareResponse( + prepareFindCoordinatorResponse(Errors.NONE, env.cluster().controller())); + + MemberResponse memberResponse = new MemberResponse() + .setGroupInstanceId("instance-1") + .setErrorCode(Errors.NONE.code()); + env.kafkaClient().prepareResponse( + new LeaveGroupResponse(new LeaveGroupResponseData() + .setErrorCode(Errors.NONE.code()) + .setMembers(Collections.singletonList(memberResponse)))); + + MemberToRemove memberToRemove = new MemberToRemove("instance-1"); + Collection membersToRemove = singletonList(memberToRemove); + + final RemoveMembersFromConsumerGroupResult result = env.adminClient().removeMembersFromConsumerGroup( + GROUP_ID, new RemoveMembersFromConsumerGroupOptions(membersToRemove)); + + assertNull(result.all().get()); + assertNull(result.memberResult(memberToRemove).get()); + } + } + + @Test + public void testRemoveMembersFromGroupNonRetriableErrors() throws Exception { + // Non-retriable errors throw an exception + + final List nonRetriableErrors = Arrays.asList( + Errors.GROUP_AUTHORIZATION_FAILED, Errors.INVALID_GROUP_ID, Errors.GROUP_ID_NOT_FOUND); + + try (AdminClientUnitTestEnv env = new AdminClientUnitTestEnv(mockCluster(1, 0))) { + env.kafkaClient().setNodeApiVersions(NodeApiVersions.create()); + + for (Errors error : nonRetriableErrors) { + env.kafkaClient().prepareResponse( + prepareFindCoordinatorResponse(Errors.NONE, env.cluster().controller())); + + env.kafkaClient().prepareResponse( + new LeaveGroupResponse(new LeaveGroupResponseData() + .setErrorCode(error.code()))); + + MemberToRemove memberToRemove = new MemberToRemove("instance-1"); + Collection membersToRemove = singletonList(memberToRemove); + + final RemoveMembersFromConsumerGroupResult result = env.adminClient().removeMembersFromConsumerGroup( + GROUP_ID, new RemoveMembersFromConsumerGroupOptions(membersToRemove)); + + TestUtils.assertFutureError(result.all(), error.exception().getClass()); + TestUtils.assertFutureError(result.memberResult(memberToRemove), error.exception().getClass()); + } + } + } + + @Test 
+ public void testRemoveMembersFromGroup() throws Exception { + try (AdminClientUnitTestEnv env = mockClientEnv()) { + final String instanceOne = "instance-1"; + final String instanceTwo = "instance-2"; + + env.kafkaClient().setNodeApiVersions(NodeApiVersions.create()); + + // Retriable FindCoordinatorResponse errors should be retried + env.kafkaClient().prepareResponse(prepareFindCoordinatorResponse(Errors.COORDINATOR_LOAD_IN_PROGRESS, Node.noNode())); + env.kafkaClient().prepareResponse(prepareFindCoordinatorResponse(Errors.NONE, env.cluster().controller())); + + // Retriable errors should be retried + env.kafkaClient().prepareResponse(new LeaveGroupResponse(new LeaveGroupResponseData() + .setErrorCode(Errors.COORDINATOR_LOAD_IN_PROGRESS.code()))); + + // Inject a top-level non-retriable error + env.kafkaClient().prepareResponse(new LeaveGroupResponse(new LeaveGroupResponseData() + .setErrorCode(Errors.UNKNOWN_SERVER_ERROR.code()))); + + Collection membersToRemove = Arrays.asList(new MemberToRemove(instanceOne), + new MemberToRemove(instanceTwo)); + final RemoveMembersFromConsumerGroupResult unknownErrorResult = env.adminClient().removeMembersFromConsumerGroup( + GROUP_ID, + new RemoveMembersFromConsumerGroupOptions(membersToRemove) + ); + + MemberToRemove memberOne = new MemberToRemove(instanceOne); + MemberToRemove memberTwo = new MemberToRemove(instanceTwo); + + TestUtils.assertFutureError(unknownErrorResult.memberResult(memberOne), UnknownServerException.class); + TestUtils.assertFutureError(unknownErrorResult.memberResult(memberTwo), UnknownServerException.class); + + MemberResponse responseOne = new MemberResponse() + .setGroupInstanceId(instanceOne) + .setErrorCode(Errors.UNKNOWN_MEMBER_ID.code()); + + MemberResponse responseTwo = new MemberResponse() + .setGroupInstanceId(instanceTwo) + .setErrorCode(Errors.NONE.code()); + + // Inject one member level error. + env.kafkaClient().prepareResponse(prepareFindCoordinatorResponse(Errors.NONE, env.cluster().controller())); + env.kafkaClient().prepareResponse(new LeaveGroupResponse(new LeaveGroupResponseData() + .setErrorCode(Errors.NONE.code()) + .setMembers(Arrays.asList(responseOne, responseTwo)))); + + final RemoveMembersFromConsumerGroupResult memberLevelErrorResult = env.adminClient().removeMembersFromConsumerGroup( + GROUP_ID, + new RemoveMembersFromConsumerGroupOptions(membersToRemove) + ); + + TestUtils.assertFutureError(memberLevelErrorResult.all(), UnknownMemberIdException.class); + TestUtils.assertFutureError(memberLevelErrorResult.memberResult(memberOne), UnknownMemberIdException.class); + assertNull(memberLevelErrorResult.memberResult(memberTwo).get()); + + // Return with missing member. + env.kafkaClient().prepareResponse(prepareFindCoordinatorResponse(Errors.NONE, env.cluster().controller())); + env.kafkaClient().prepareResponse(new LeaveGroupResponse(new LeaveGroupResponseData() + .setErrorCode(Errors.NONE.code()) + .setMembers(Collections.singletonList(responseTwo)))); + + final RemoveMembersFromConsumerGroupResult missingMemberResult = env.adminClient().removeMembersFromConsumerGroup( + GROUP_ID, + new RemoveMembersFromConsumerGroupOptions(membersToRemove) + ); + + TestUtils.assertFutureError(missingMemberResult.all(), IllegalArgumentException.class); + // The memberOne was not included in the response. + TestUtils.assertFutureError(missingMemberResult.memberResult(memberOne), IllegalArgumentException.class); + assertNull(missingMemberResult.memberResult(memberTwo).get()); + + + // Return with success. 
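+            // Every requested instance id comes back with error code NONE, so all() and both member-level futures complete cleanly.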
+ env.kafkaClient().prepareResponse(prepareFindCoordinatorResponse(Errors.NONE, env.cluster().controller())); + env.kafkaClient().prepareResponse(new LeaveGroupResponse( + new LeaveGroupResponseData().setErrorCode(Errors.NONE.code()).setMembers( + Arrays.asList(responseTwo, + new MemberResponse().setGroupInstanceId(instanceOne).setErrorCode(Errors.NONE.code()) + )) + )); + + final RemoveMembersFromConsumerGroupResult noErrorResult = env.adminClient().removeMembersFromConsumerGroup( + GROUP_ID, + new RemoveMembersFromConsumerGroupOptions(membersToRemove) + ); + assertNull(noErrorResult.all().get()); + assertNull(noErrorResult.memberResult(memberOne).get()); + assertNull(noErrorResult.memberResult(memberTwo).get()); + + // Test the "removeAll" scenario + final List topicPartitions = Arrays.asList(1, 2, 3).stream().map(partition -> new TopicPartition("my_topic", partition)) + .collect(Collectors.toList()); + // construct the DescribeGroupsResponse + DescribeGroupsResponseData data = prepareDescribeGroupsResponseData(GROUP_ID, Arrays.asList(instanceOne, instanceTwo), topicPartitions); + + // Return with partial failure for "removeAll" scenario + // 1 prepare response for AdminClient.describeConsumerGroups + env.kafkaClient().prepareResponse(prepareFindCoordinatorResponse(Errors.NONE, env.cluster().controller())); + env.kafkaClient().prepareResponse(new DescribeGroupsResponse(data)); + + // 2 KafkaAdminClient encounter partial failure when trying to delete all members + env.kafkaClient().prepareResponse(prepareFindCoordinatorResponse(Errors.NONE, env.cluster().controller())); + env.kafkaClient().prepareResponse(new LeaveGroupResponse( + new LeaveGroupResponseData().setErrorCode(Errors.NONE.code()).setMembers( + Arrays.asList(responseOne, responseTwo)) + )); + final RemoveMembersFromConsumerGroupResult partialFailureResults = env.adminClient().removeMembersFromConsumerGroup( + GROUP_ID, + new RemoveMembersFromConsumerGroupOptions() + ); + ExecutionException exception = assertThrows(ExecutionException.class, () -> partialFailureResults.all().get()); + assertTrue(exception.getCause() instanceof KafkaException); + assertTrue(exception.getCause().getCause() instanceof UnknownMemberIdException); + + // Return with success for "removeAll" scenario + // 1 prepare response for AdminClient.describeConsumerGroups + env.kafkaClient().prepareResponse(prepareFindCoordinatorResponse(Errors.NONE, env.cluster().controller())); + env.kafkaClient().prepareResponse(new DescribeGroupsResponse(data)); + + // 2. 
KafkaAdminClient should delete all members correctly + env.kafkaClient().prepareResponse(prepareFindCoordinatorResponse(Errors.NONE, env.cluster().controller())); + env.kafkaClient().prepareResponse(new LeaveGroupResponse( + new LeaveGroupResponseData().setErrorCode(Errors.NONE.code()).setMembers( + Arrays.asList(responseTwo, + new MemberResponse().setGroupInstanceId(instanceOne).setErrorCode(Errors.NONE.code()) + )) + )); + final RemoveMembersFromConsumerGroupResult successResult = env.adminClient().removeMembersFromConsumerGroup( + GROUP_ID, + new RemoveMembersFromConsumerGroupOptions() + ); + assertNull(successResult.all().get()); + } + } + + @Test + public void testAlterPartitionReassignments() throws Exception { + try (AdminClientUnitTestEnv env = mockClientEnv()) { + env.kafkaClient().setNodeApiVersions(NodeApiVersions.create()); + + TopicPartition tp1 = new TopicPartition("A", 0); + TopicPartition tp2 = new TopicPartition("B", 0); + Map> reassignments = new HashMap<>(); + reassignments.put(tp1, Optional.empty()); + reassignments.put(tp2, Optional.of(new NewPartitionReassignment(Arrays.asList(1, 2, 3)))); + + // 1. server returns less responses than number of partitions we sent + AlterPartitionReassignmentsResponseData responseData1 = new AlterPartitionReassignmentsResponseData(); + ReassignablePartitionResponse normalPartitionResponse = new ReassignablePartitionResponse().setPartitionIndex(0); + responseData1.setResponses(Collections.singletonList( + new ReassignableTopicResponse() + .setName("A") + .setPartitions(Collections.singletonList(normalPartitionResponse)))); + env.kafkaClient().prepareResponse(new AlterPartitionReassignmentsResponse(responseData1)); + AlterPartitionReassignmentsResult result1 = env.adminClient().alterPartitionReassignments(reassignments); + Future future1 = result1.all(); + Future future2 = result1.values().get(tp1); + TestUtils.assertFutureError(future1, UnknownServerException.class); + TestUtils.assertFutureError(future2, UnknownServerException.class); + + // 2. 
NOT_CONTROLLER error handling + AlterPartitionReassignmentsResponseData controllerErrResponseData = + new AlterPartitionReassignmentsResponseData() + .setErrorCode(Errors.NOT_CONTROLLER.code()) + .setErrorMessage(Errors.NOT_CONTROLLER.message()) + .setResponses(Arrays.asList( + new ReassignableTopicResponse() + .setName("A") + .setPartitions(Collections.singletonList(normalPartitionResponse)), + new ReassignableTopicResponse() + .setName("B") + .setPartitions(Collections.singletonList(normalPartitionResponse))) + ); + MetadataResponse controllerNodeResponse = RequestTestUtils.metadataResponse(env.cluster().nodes(), + env.cluster().clusterResource().clusterId(), 1, Collections.emptyList()); + AlterPartitionReassignmentsResponseData normalResponse = + new AlterPartitionReassignmentsResponseData() + .setResponses(Arrays.asList( + new ReassignableTopicResponse() + .setName("A") + .setPartitions(Collections.singletonList(normalPartitionResponse)), + new ReassignableTopicResponse() + .setName("B") + .setPartitions(Collections.singletonList(normalPartitionResponse))) + ); + env.kafkaClient().prepareResponse(new AlterPartitionReassignmentsResponse(controllerErrResponseData)); + env.kafkaClient().prepareResponse(controllerNodeResponse); + env.kafkaClient().prepareResponse(new AlterPartitionReassignmentsResponse(normalResponse)); + AlterPartitionReassignmentsResult controllerErrResult = env.adminClient().alterPartitionReassignments(reassignments); + controllerErrResult.all().get(); + controllerErrResult.values().get(tp1).get(); + controllerErrResult.values().get(tp2).get(); + + // 3. partition-level error + AlterPartitionReassignmentsResponseData partitionLevelErrData = + new AlterPartitionReassignmentsResponseData() + .setResponses(Arrays.asList( + new ReassignableTopicResponse() + .setName("A") + .setPartitions(Collections.singletonList(new ReassignablePartitionResponse() + .setPartitionIndex(0).setErrorMessage(Errors.INVALID_REPLICA_ASSIGNMENT.message()) + .setErrorCode(Errors.INVALID_REPLICA_ASSIGNMENT.code()) + )), + new ReassignableTopicResponse() + .setName("B") + .setPartitions(Collections.singletonList(normalPartitionResponse))) + ); + env.kafkaClient().prepareResponse(new AlterPartitionReassignmentsResponse(partitionLevelErrData)); + AlterPartitionReassignmentsResult partitionLevelErrResult = env.adminClient().alterPartitionReassignments(reassignments); + TestUtils.assertFutureError(partitionLevelErrResult.values().get(tp1), Errors.INVALID_REPLICA_ASSIGNMENT.exception().getClass()); + partitionLevelErrResult.values().get(tp2).get(); + + // 4. 
top-level error + String errorMessage = "this is custom error message"; + AlterPartitionReassignmentsResponseData topLevelErrResponseData = + new AlterPartitionReassignmentsResponseData() + .setErrorCode(Errors.CLUSTER_AUTHORIZATION_FAILED.code()) + .setErrorMessage(errorMessage) + .setResponses(Arrays.asList( + new ReassignableTopicResponse() + .setName("A") + .setPartitions(Collections.singletonList(normalPartitionResponse)), + new ReassignableTopicResponse() + .setName("B") + .setPartitions(Collections.singletonList(normalPartitionResponse))) + ); + env.kafkaClient().prepareResponse(new AlterPartitionReassignmentsResponse(topLevelErrResponseData)); + AlterPartitionReassignmentsResult topLevelErrResult = env.adminClient().alterPartitionReassignments(reassignments); + assertEquals(errorMessage, TestUtils.assertFutureThrows(topLevelErrResult.all(), Errors.CLUSTER_AUTHORIZATION_FAILED.exception().getClass()).getMessage()); + assertEquals(errorMessage, TestUtils.assertFutureThrows(topLevelErrResult.values().get(tp1), Errors.CLUSTER_AUTHORIZATION_FAILED.exception().getClass()).getMessage()); + assertEquals(errorMessage, TestUtils.assertFutureThrows(topLevelErrResult.values().get(tp2), Errors.CLUSTER_AUTHORIZATION_FAILED.exception().getClass()).getMessage()); + + // 5. unrepresentable topic name error + TopicPartition invalidTopicTP = new TopicPartition("", 0); + TopicPartition invalidPartitionTP = new TopicPartition("ABC", -1); + Map> invalidTopicReassignments = new HashMap<>(); + invalidTopicReassignments.put(invalidPartitionTP, Optional.of(new NewPartitionReassignment(Arrays.asList(1, 2, 3)))); + invalidTopicReassignments.put(invalidTopicTP, Optional.of(new NewPartitionReassignment(Arrays.asList(1, 2, 3)))); + invalidTopicReassignments.put(tp1, Optional.of(new NewPartitionReassignment(Arrays.asList(1, 2, 3)))); + + AlterPartitionReassignmentsResponseData singlePartResponseData = + new AlterPartitionReassignmentsResponseData() + .setResponses(Collections.singletonList( + new ReassignableTopicResponse() + .setName("A") + .setPartitions(Collections.singletonList(normalPartitionResponse))) + ); + env.kafkaClient().prepareResponse(new AlterPartitionReassignmentsResponse(singlePartResponseData)); + AlterPartitionReassignmentsResult unrepresentableTopicResult = env.adminClient().alterPartitionReassignments(invalidTopicReassignments); + TestUtils.assertFutureError(unrepresentableTopicResult.values().get(invalidTopicTP), InvalidTopicException.class); + TestUtils.assertFutureError(unrepresentableTopicResult.values().get(invalidPartitionTP), InvalidTopicException.class); + unrepresentableTopicResult.values().get(tp1).get(); + + // Test success scenario + AlterPartitionReassignmentsResponseData noErrResponseData = + new AlterPartitionReassignmentsResponseData() + .setErrorCode(Errors.NONE.code()) + .setErrorMessage(Errors.NONE.message()) + .setResponses(Arrays.asList( + new ReassignableTopicResponse() + .setName("A") + .setPartitions(Collections.singletonList(normalPartitionResponse)), + new ReassignableTopicResponse() + .setName("B") + .setPartitions(Collections.singletonList(normalPartitionResponse))) + ); + env.kafkaClient().prepareResponse(new AlterPartitionReassignmentsResponse(noErrResponseData)); + AlterPartitionReassignmentsResult noErrResult = env.adminClient().alterPartitionReassignments(reassignments); + noErrResult.all().get(); + noErrResult.values().get(tp1).get(); + noErrResult.values().get(tp2).get(); + } + } + + @Test + public void testListPartitionReassignments() throws Exception { + 
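+        // Exercises three cases below: NOT_CONTROLLER handled via a metadata refresh and retry,
+        // an UNKNOWN_TOPIC_OR_PARTITION error surfaced to the caller, and the successful listing.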
try (AdminClientUnitTestEnv env = mockClientEnv()) { + env.kafkaClient().setNodeApiVersions(NodeApiVersions.create()); + + TopicPartition tp1 = new TopicPartition("A", 0); + OngoingPartitionReassignment tp1PartitionReassignment = new OngoingPartitionReassignment() + .setPartitionIndex(0) + .setRemovingReplicas(Arrays.asList(1, 2, 3)) + .setAddingReplicas(Arrays.asList(4, 5, 6)) + .setReplicas(Arrays.asList(1, 2, 3, 4, 5, 6)); + OngoingTopicReassignment tp1Reassignment = new OngoingTopicReassignment().setName("A") + .setPartitions(Collections.singletonList(tp1PartitionReassignment)); + + TopicPartition tp2 = new TopicPartition("B", 0); + OngoingPartitionReassignment tp2PartitionReassignment = new OngoingPartitionReassignment() + .setPartitionIndex(0) + .setRemovingReplicas(Arrays.asList(1, 2, 3)) + .setAddingReplicas(Arrays.asList(4, 5, 6)) + .setReplicas(Arrays.asList(1, 2, 3, 4, 5, 6)); + OngoingTopicReassignment tp2Reassignment = new OngoingTopicReassignment().setName("B") + .setPartitions(Collections.singletonList(tp2PartitionReassignment)); + + // 1. NOT_CONTROLLER error handling + ListPartitionReassignmentsResponseData notControllerData = new ListPartitionReassignmentsResponseData() + .setErrorCode(Errors.NOT_CONTROLLER.code()) + .setErrorMessage(Errors.NOT_CONTROLLER.message()); + MetadataResponse controllerNodeResponse = RequestTestUtils.metadataResponse(env.cluster().nodes(), + env.cluster().clusterResource().clusterId(), 1, Collections.emptyList()); + ListPartitionReassignmentsResponseData reassignmentsData = new ListPartitionReassignmentsResponseData() + .setTopics(Arrays.asList(tp1Reassignment, tp2Reassignment)); + env.kafkaClient().prepareResponse(new ListPartitionReassignmentsResponse(notControllerData)); + env.kafkaClient().prepareResponse(controllerNodeResponse); + env.kafkaClient().prepareResponse(new ListPartitionReassignmentsResponse(reassignmentsData)); + + ListPartitionReassignmentsResult noControllerResult = env.adminClient().listPartitionReassignments(); + noControllerResult.reassignments().get(); // no error + + // 2. UNKNOWN_TOPIC_OR_EXCEPTION_ERROR + ListPartitionReassignmentsResponseData unknownTpData = new ListPartitionReassignmentsResponseData() + .setErrorCode(Errors.UNKNOWN_TOPIC_OR_PARTITION.code()) + .setErrorMessage(Errors.UNKNOWN_TOPIC_OR_PARTITION.message()); + env.kafkaClient().prepareResponse(new ListPartitionReassignmentsResponse(unknownTpData)); + + ListPartitionReassignmentsResult unknownTpResult = env.adminClient().listPartitionReassignments(new HashSet<>(Arrays.asList(tp1, tp2))); + TestUtils.assertFutureError(unknownTpResult.reassignments(), UnknownTopicOrPartitionException.class); + + // 3. 
Success + ListPartitionReassignmentsResponseData responseData = new ListPartitionReassignmentsResponseData() + .setTopics(Arrays.asList(tp1Reassignment, tp2Reassignment)); + env.kafkaClient().prepareResponse(new ListPartitionReassignmentsResponse(responseData)); + ListPartitionReassignmentsResult responseResult = env.adminClient().listPartitionReassignments(); + + Map reassignments = responseResult.reassignments().get(); + + PartitionReassignment tp1Result = reassignments.get(tp1); + assertEquals(tp1PartitionReassignment.addingReplicas(), tp1Result.addingReplicas()); + assertEquals(tp1PartitionReassignment.removingReplicas(), tp1Result.removingReplicas()); + assertEquals(tp1PartitionReassignment.replicas(), tp1Result.replicas()); + assertEquals(tp1PartitionReassignment.replicas(), tp1Result.replicas()); + PartitionReassignment tp2Result = reassignments.get(tp2); + assertEquals(tp2PartitionReassignment.addingReplicas(), tp2Result.addingReplicas()); + assertEquals(tp2PartitionReassignment.removingReplicas(), tp2Result.removingReplicas()); + assertEquals(tp2PartitionReassignment.replicas(), tp2Result.replicas()); + assertEquals(tp2PartitionReassignment.replicas(), tp2Result.replicas()); + } + } + + @Test + public void testAlterConsumerGroupOffsets() throws Exception { + // Happy path + + final TopicPartition tp1 = new TopicPartition("foo", 0); + final TopicPartition tp2 = new TopicPartition("bar", 0); + final TopicPartition tp3 = new TopicPartition("foobar", 0); + + try (AdminClientUnitTestEnv env = new AdminClientUnitTestEnv(mockCluster(1, 0))) { + env.kafkaClient().setNodeApiVersions(NodeApiVersions.create()); + + env.kafkaClient().prepareResponse( + prepareFindCoordinatorResponse(Errors.NONE, env.cluster().controller())); + + Map responseData = new HashMap<>(); + responseData.put(tp1, Errors.NONE); + responseData.put(tp2, Errors.NONE); + env.kafkaClient().prepareResponse(new OffsetCommitResponse(0, responseData)); + + Map offsets = new HashMap<>(); + offsets.put(tp1, new OffsetAndMetadata(123L)); + offsets.put(tp2, new OffsetAndMetadata(456L)); + final AlterConsumerGroupOffsetsResult result = env.adminClient().alterConsumerGroupOffsets( + GROUP_ID, offsets); + + assertNull(result.all().get()); + assertNull(result.partitionResult(tp1).get()); + assertNull(result.partitionResult(tp2).get()); + TestUtils.assertFutureError(result.partitionResult(tp3), IllegalArgumentException.class); + } + } + + @Test + public void testAlterConsumerGroupOffsetsRetriableErrors() throws Exception { + // Retriable errors should be retried + + final TopicPartition tp1 = new TopicPartition("foo", 0); + + try (AdminClientUnitTestEnv env = new AdminClientUnitTestEnv(mockCluster(1, 0))) { + env.kafkaClient().setNodeApiVersions(NodeApiVersions.create()); + + env.kafkaClient().prepareResponse( + prepareFindCoordinatorResponse(Errors.NONE, env.cluster().controller())); + + env.kafkaClient().prepareResponse( + prepareOffsetCommitResponse(tp1, Errors.COORDINATOR_NOT_AVAILABLE)); + + env.kafkaClient().prepareResponse( + prepareFindCoordinatorResponse(Errors.NONE, env.cluster().controller())); + + env.kafkaClient().prepareResponse( + prepareOffsetCommitResponse(tp1, Errors.COORDINATOR_LOAD_IN_PROGRESS)); + + env.kafkaClient().prepareResponse( + prepareOffsetCommitResponse(tp1, Errors.NOT_COORDINATOR)); + + env.kafkaClient().prepareResponse( + prepareFindCoordinatorResponse(Errors.NONE, env.cluster().controller())); + + env.kafkaClient().prepareResponse( + prepareOffsetCommitResponse(tp1, Errors.REBALANCE_IN_PROGRESS)); + + 
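+            // REBALANCE_IN_PROGRESS is retried against the same coordinator in this test,
+            // so the next prepared response is the successful commit.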
env.kafkaClient().prepareResponse( + prepareOffsetCommitResponse(tp1, Errors.NONE)); + + Map offsets = new HashMap<>(); + offsets.put(tp1, new OffsetAndMetadata(123L)); + final AlterConsumerGroupOffsetsResult result1 = env.adminClient() + .alterConsumerGroupOffsets(GROUP_ID, offsets); + + assertNull(result1.all().get()); + assertNull(result1.partitionResult(tp1).get()); + } + } + + @Test + public void testAlterConsumerGroupOffsetsNonRetriableErrors() throws Exception { + // Non-retriable errors throw an exception + + final TopicPartition tp1 = new TopicPartition("foo", 0); + final List nonRetriableErrors = Arrays.asList( + Errors.GROUP_AUTHORIZATION_FAILED, Errors.INVALID_GROUP_ID, Errors.GROUP_ID_NOT_FOUND); + + try (AdminClientUnitTestEnv env = new AdminClientUnitTestEnv(mockCluster(1, 0))) { + env.kafkaClient().setNodeApiVersions(NodeApiVersions.create()); + + for (Errors error : nonRetriableErrors) { + env.kafkaClient().prepareResponse( + prepareFindCoordinatorResponse(Errors.NONE, env.cluster().controller())); + + env.kafkaClient().prepareResponse(prepareOffsetCommitResponse(tp1, error)); + + Map offsets = new HashMap<>(); + offsets.put(tp1, new OffsetAndMetadata(123L)); + AlterConsumerGroupOffsetsResult errorResult = env.adminClient() + .alterConsumerGroupOffsets(GROUP_ID, offsets); + + TestUtils.assertFutureError(errorResult.all(), error.exception().getClass()); + TestUtils.assertFutureError(errorResult.partitionResult(tp1), error.exception().getClass()); + } + } + } + + @Test + public void testAlterConsumerGroupOffsetsFindCoordinatorRetriableErrors() throws Exception { + // Retriable FindCoordinatorResponse errors should be retried + + final TopicPartition tp1 = new TopicPartition("foo", 0); + + try (AdminClientUnitTestEnv env = new AdminClientUnitTestEnv(mockCluster(1, 0))) { + env.kafkaClient().setNodeApiVersions(NodeApiVersions.create()); + + env.kafkaClient().prepareResponse( + prepareFindCoordinatorResponse(Errors.COORDINATOR_NOT_AVAILABLE, Node.noNode())); + env.kafkaClient().prepareResponse( + prepareFindCoordinatorResponse(Errors.COORDINATOR_LOAD_IN_PROGRESS, Node.noNode())); + + env.kafkaClient().prepareResponse( + prepareFindCoordinatorResponse(Errors.NONE, env.cluster().controller())); + + env.kafkaClient().prepareResponse( + prepareOffsetCommitResponse(tp1, Errors.NONE)); + + Map offsets = new HashMap<>(); + offsets.put(tp1, new OffsetAndMetadata(123L)); + final AlterConsumerGroupOffsetsResult result = env.adminClient() + .alterConsumerGroupOffsets(GROUP_ID, offsets); + + assertNull(result.all().get()); + assertNull(result.partitionResult(tp1).get()); + } + } + + @Test + public void testAlterConsumerGroupOffsetsFindCoordinatorNonRetriableErrors() throws Exception { + // Non-retriable FindCoordinatorResponse errors throw an exception + + final TopicPartition tp1 = new TopicPartition("foo", 0); + + try (AdminClientUnitTestEnv env = new AdminClientUnitTestEnv(mockCluster(1, 0))) { + env.kafkaClient().setNodeApiVersions(NodeApiVersions.create()); + + env.kafkaClient().prepareResponse( + prepareFindCoordinatorResponse(Errors.GROUP_AUTHORIZATION_FAILED, Node.noNode())); + + Map offsets = new HashMap<>(); + offsets.put(tp1, new OffsetAndMetadata(123L)); + final AlterConsumerGroupOffsetsResult errorResult = env.adminClient() + .alterConsumerGroupOffsets(GROUP_ID, offsets); + + TestUtils.assertFutureError(errorResult.all(), GroupAuthorizationException.class); + TestUtils.assertFutureError(errorResult.partitionResult(tp1), GroupAuthorizationException.class); + } + } + + @Test + 
public void testListOffsets() throws Exception { + // Happy path + + Node node0 = new Node(0, "localhost", 8120); + List pInfos = new ArrayList<>(); + pInfos.add(new PartitionInfo("foo", 0, node0, new Node[]{node0}, new Node[]{node0})); + pInfos.add(new PartitionInfo("bar", 0, node0, new Node[]{node0}, new Node[]{node0})); + pInfos.add(new PartitionInfo("baz", 0, node0, new Node[]{node0}, new Node[]{node0})); + pInfos.add(new PartitionInfo("qux", 0, node0, new Node[]{node0}, new Node[]{node0})); + final Cluster cluster = + new Cluster( + "mockClusterId", + Arrays.asList(node0), + pInfos, + Collections.emptySet(), + Collections.emptySet(), + node0); + + final TopicPartition tp0 = new TopicPartition("foo", 0); + final TopicPartition tp1 = new TopicPartition("bar", 0); + final TopicPartition tp2 = new TopicPartition("baz", 0); + final TopicPartition tp3 = new TopicPartition("qux", 0); + + try (AdminClientUnitTestEnv env = new AdminClientUnitTestEnv(cluster)) { + env.kafkaClient().setNodeApiVersions(NodeApiVersions.create()); + + env.kafkaClient().prepareResponse(prepareMetadataResponse(cluster, Errors.NONE)); + + ListOffsetsTopicResponse t0 = ListOffsetsResponse.singletonListOffsetsTopicResponse(tp0, Errors.NONE, -1L, 123L, 321); + ListOffsetsTopicResponse t1 = ListOffsetsResponse.singletonListOffsetsTopicResponse(tp1, Errors.NONE, -1L, 234L, 432); + ListOffsetsTopicResponse t2 = ListOffsetsResponse.singletonListOffsetsTopicResponse(tp2, Errors.NONE, 123456789L, 345L, 543); + ListOffsetsTopicResponse t3 = ListOffsetsResponse.singletonListOffsetsTopicResponse(tp3, Errors.NONE, 234567890L, 456L, 654); + ListOffsetsResponseData responseData = new ListOffsetsResponseData() + .setThrottleTimeMs(0) + .setTopics(Arrays.asList(t0, t1, t2, t3)); + env.kafkaClient().prepareResponse(new ListOffsetsResponse(responseData)); + + Map partitions = new HashMap<>(); + partitions.put(tp0, OffsetSpec.latest()); + partitions.put(tp1, OffsetSpec.earliest()); + partitions.put(tp2, OffsetSpec.forTimestamp(System.currentTimeMillis())); + partitions.put(tp3, OffsetSpec.maxTimestamp()); + ListOffsetsResult result = env.adminClient().listOffsets(partitions); + + Map offsets = result.all().get(); + assertFalse(offsets.isEmpty()); + assertEquals(123L, offsets.get(tp0).offset()); + assertEquals(321, offsets.get(tp0).leaderEpoch().get().intValue()); + assertEquals(-1L, offsets.get(tp0).timestamp()); + assertEquals(234L, offsets.get(tp1).offset()); + assertEquals(432, offsets.get(tp1).leaderEpoch().get().intValue()); + assertEquals(-1L, offsets.get(tp1).timestamp()); + assertEquals(345L, offsets.get(tp2).offset()); + assertEquals(543, offsets.get(tp2).leaderEpoch().get().intValue()); + assertEquals(123456789L, offsets.get(tp2).timestamp()); + assertEquals(456L, offsets.get(tp3).offset()); + assertEquals(654, offsets.get(tp3).leaderEpoch().get().intValue()); + assertEquals(234567890L, offsets.get(tp3).timestamp()); + assertEquals(offsets.get(tp0), result.partitionResult(tp0).get()); + assertEquals(offsets.get(tp1), result.partitionResult(tp1).get()); + assertEquals(offsets.get(tp2), result.partitionResult(tp2).get()); + assertEquals(offsets.get(tp3), result.partitionResult(tp3).get()); + try { + result.partitionResult(new TopicPartition("unknown", 0)).get(); + fail("should have thrown IllegalArgumentException"); + } catch (IllegalArgumentException expected) { } + } + } + + @Test + public void testListOffsetsRetriableErrorOnMetadata() throws Exception { + Node node = new Node(0, "localhost", 8120); + List nodes = 
Collections.singletonList(node); + final Cluster cluster = new Cluster( + "mockClusterId", + nodes, + Collections.singleton(new PartitionInfo("foo", 0, node, new Node[]{node}, new Node[]{node})), + Collections.emptySet(), + Collections.emptySet(), + node); + final TopicPartition tp0 = new TopicPartition("foo", 0); + + try (AdminClientUnitTestEnv env = new AdminClientUnitTestEnv(cluster)) { + env.kafkaClient().setNodeApiVersions(NodeApiVersions.create()); + env.kafkaClient().prepareResponse(prepareMetadataResponse(cluster, Errors.UNKNOWN_TOPIC_OR_PARTITION, Errors.NONE)); + // metadata refresh because of UNKNOWN_TOPIC_OR_PARTITION + env.kafkaClient().prepareResponse(prepareMetadataResponse(cluster, Errors.NONE)); + // listoffsets response from broker 0 + ListOffsetsResponseData responseData = new ListOffsetsResponseData() + .setThrottleTimeMs(0) + .setTopics(Collections.singletonList(ListOffsetsResponse.singletonListOffsetsTopicResponse(tp0, Errors.NONE, -1L, 123L, 321))); + env.kafkaClient().prepareResponseFrom(new ListOffsetsResponse(responseData), node); + + ListOffsetsResult result = env.adminClient().listOffsets(Collections.singletonMap(tp0, OffsetSpec.latest())); + + Map offsets = result.all().get(3, TimeUnit.SECONDS); + assertEquals(1, offsets.size()); + assertEquals(123L, offsets.get(tp0).offset()); + assertEquals(321, offsets.get(tp0).leaderEpoch().get().intValue()); + assertEquals(-1L, offsets.get(tp0).timestamp()); + } + } + + @Test + public void testListOffsetsRetriableErrors() throws Exception { + + Node node0 = new Node(0, "localhost", 8120); + Node node1 = new Node(1, "localhost", 8121); + List nodes = Arrays.asList(node0, node1); + List pInfos = new ArrayList<>(); + pInfos.add(new PartitionInfo("foo", 0, node0, new Node[]{node0, node1}, new Node[]{node0, node1})); + pInfos.add(new PartitionInfo("foo", 1, node0, new Node[]{node0, node1}, new Node[]{node0, node1})); + pInfos.add(new PartitionInfo("bar", 0, node1, new Node[]{node1, node0}, new Node[]{node1, node0})); + final Cluster cluster = + new Cluster( + "mockClusterId", + nodes, + pInfos, + Collections.emptySet(), + Collections.emptySet(), + node0); + + final TopicPartition tp0 = new TopicPartition("foo", 0); + final TopicPartition tp1 = new TopicPartition("foo", 1); + final TopicPartition tp2 = new TopicPartition("bar", 0); + + try (AdminClientUnitTestEnv env = new AdminClientUnitTestEnv(cluster)) { + env.kafkaClient().setNodeApiVersions(NodeApiVersions.create()); + env.kafkaClient().prepareResponse(prepareMetadataResponse(cluster, Errors.NONE)); + // listoffsets response from broker 0 + ListOffsetsTopicResponse t0 = ListOffsetsResponse.singletonListOffsetsTopicResponse(tp0, Errors.LEADER_NOT_AVAILABLE, -1L, 123L, 321); + ListOffsetsTopicResponse t1 = ListOffsetsResponse.singletonListOffsetsTopicResponse(tp1, Errors.NONE, -1L, 987L, 789); + ListOffsetsResponseData responseData = new ListOffsetsResponseData() + .setThrottleTimeMs(0) + .setTopics(Arrays.asList(t0, t1)); + env.kafkaClient().prepareResponseFrom(new ListOffsetsResponse(responseData), node0); + // listoffsets response from broker 1 + ListOffsetsTopicResponse t2 = ListOffsetsResponse.singletonListOffsetsTopicResponse(tp2, Errors.NONE, -1L, 456L, 654); + responseData = new ListOffsetsResponseData() + .setThrottleTimeMs(0) + .setTopics(Arrays.asList(t2)); + env.kafkaClient().prepareResponseFrom(new ListOffsetsResponse(responseData), node1); + + // metadata refresh because of LEADER_NOT_AVAILABLE + env.kafkaClient().prepareResponse(prepareMetadataResponse(cluster, 
Errors.NONE)); + // listoffsets response from broker 0 + t0 = ListOffsetsResponse.singletonListOffsetsTopicResponse(tp0, Errors.NONE, -1L, 345L, 543); + responseData = new ListOffsetsResponseData() + .setThrottleTimeMs(0) + .setTopics(Arrays.asList(t0)); + env.kafkaClient().prepareResponseFrom(new ListOffsetsResponse(responseData), node0); + + Map partitions = new HashMap<>(); + partitions.put(tp0, OffsetSpec.latest()); + partitions.put(tp1, OffsetSpec.latest()); + partitions.put(tp2, OffsetSpec.latest()); + ListOffsetsResult result = env.adminClient().listOffsets(partitions); + + Map offsets = result.all().get(); + assertFalse(offsets.isEmpty()); + assertEquals(345L, offsets.get(tp0).offset()); + assertEquals(543, offsets.get(tp0).leaderEpoch().get().intValue()); + assertEquals(-1L, offsets.get(tp0).timestamp()); + assertEquals(987L, offsets.get(tp1).offset()); + assertEquals(789, offsets.get(tp1).leaderEpoch().get().intValue()); + assertEquals(-1L, offsets.get(tp1).timestamp()); + assertEquals(456L, offsets.get(tp2).offset()); + assertEquals(654, offsets.get(tp2).leaderEpoch().get().intValue()); + assertEquals(-1L, offsets.get(tp2).timestamp()); + } + } + + @Test + public void testListOffsetsNonRetriableErrors() throws Exception { + + Node node0 = new Node(0, "localhost", 8120); + Node node1 = new Node(1, "localhost", 8121); + List nodes = Arrays.asList(node0, node1); + List pInfos = new ArrayList<>(); + pInfos.add(new PartitionInfo("foo", 0, node0, new Node[]{node0, node1}, new Node[]{node0, node1})); + final Cluster cluster = + new Cluster( + "mockClusterId", + nodes, + pInfos, + Collections.emptySet(), + Collections.emptySet(), + node0); + + final TopicPartition tp0 = new TopicPartition("foo", 0); + + try (AdminClientUnitTestEnv env = new AdminClientUnitTestEnv(cluster)) { + env.kafkaClient().setNodeApiVersions(NodeApiVersions.create()); + + env.kafkaClient().prepareResponse(prepareMetadataResponse(cluster, Errors.NONE)); + + ListOffsetsTopicResponse t0 = ListOffsetsResponse.singletonListOffsetsTopicResponse(tp0, Errors.TOPIC_AUTHORIZATION_FAILED, -1L, -1L, -1); + ListOffsetsResponseData responseData = new ListOffsetsResponseData() + .setThrottleTimeMs(0) + .setTopics(Arrays.asList(t0)); + env.kafkaClient().prepareResponse(new ListOffsetsResponse(responseData)); + + Map partitions = new HashMap<>(); + partitions.put(tp0, OffsetSpec.latest()); + ListOffsetsResult result = env.adminClient().listOffsets(partitions); + + TestUtils.assertFutureError(result.all(), TopicAuthorizationException.class); + } + } + + @Test + public void testListOffsetsMaxTimestampUnsupportedSingleOffsetSpec() { + Node node = new Node(0, "localhost", 8120); + List nodes = Collections.singletonList(node); + final Cluster cluster = new Cluster( + "mockClusterId", + nodes, + Collections.singleton(new PartitionInfo("foo", 0, node, new Node[]{node}, new Node[]{node})), + Collections.emptySet(), + Collections.emptySet(), + node); + final TopicPartition tp0 = new TopicPartition("foo", 0); + + try (AdminClientUnitTestEnv env = new AdminClientUnitTestEnv(cluster, AdminClientConfig.RETRIES_CONFIG, "2")) { + env.kafkaClient().setNodeApiVersions(NodeApiVersions.create( + ApiKeys.LIST_OFFSETS.id, (short) 0, (short) 6)); + env.kafkaClient().prepareResponse(prepareMetadataResponse(cluster, Errors.NONE)); + + // listoffsets response from broker 0 + env.kafkaClient().prepareUnsupportedVersionResponse( + request -> request instanceof ListOffsetsRequest); + + ListOffsetsResult result = 
env.adminClient().listOffsets(Collections.singletonMap(tp0, OffsetSpec.maxTimestamp())); + + TestUtils.assertFutureThrows(result.all(), UnsupportedVersionException.class); + } + } + + @Test + public void testListOffsetsMaxTimestampUnsupportedMultipleOffsetSpec() throws Exception { + Node node = new Node(0, "localhost", 8120); + List nodes = Collections.singletonList(node); + List pInfos = new ArrayList<>(); + pInfos.add(new PartitionInfo("foo", 0, node, new Node[]{node}, new Node[]{node})); + pInfos.add(new PartitionInfo("foo", 1, node, new Node[]{node}, new Node[]{node})); + final Cluster cluster = new Cluster( + "mockClusterId", + nodes, + pInfos, + Collections.emptySet(), + Collections.emptySet(), + node); + final TopicPartition tp0 = new TopicPartition("foo", 0); + final TopicPartition tp1 = new TopicPartition("foo", 1); + + try (AdminClientUnitTestEnv env = new AdminClientUnitTestEnv(cluster, + AdminClientConfig.RETRIES_CONFIG, "2")) { + + env.kafkaClient().setNodeApiVersions(NodeApiVersions.create( + ApiKeys.LIST_OFFSETS.id, (short) 0, (short) 6)); + env.kafkaClient().prepareResponse(prepareMetadataResponse(cluster, Errors.NONE)); + + // listoffsets response from broker 0 + env.kafkaClient().prepareUnsupportedVersionResponse( + request -> request instanceof ListOffsetsRequest); + + ListOffsetsTopicResponse topicResponse = ListOffsetsResponse.singletonListOffsetsTopicResponse(tp1, Errors.NONE, -1L, 345L, 543); + ListOffsetsResponseData responseData = new ListOffsetsResponseData() + .setThrottleTimeMs(0) + .setTopics(Arrays.asList(topicResponse)); + env.kafkaClient().prepareResponseFrom( + // ensure that no max timestamp requests are retried + request -> request instanceof ListOffsetsRequest && ((ListOffsetsRequest) request).topics().stream() + .flatMap(t -> t.partitions().stream()) + .noneMatch(p -> p.timestamp() == ListOffsetsRequest.MAX_TIMESTAMP), + new ListOffsetsResponse(responseData), node); + + ListOffsetsResult result = env.adminClient().listOffsets(new HashMap() {{ + put(tp0, OffsetSpec.maxTimestamp()); + put(tp1, OffsetSpec.latest()); + }}); + + TestUtils.assertFutureThrows(result.partitionResult(tp0), UnsupportedVersionException.class); + + ListOffsetsResultInfo tp1Offset = result.partitionResult(tp1).get(); + assertEquals(345L, tp1Offset.offset()); + assertEquals(543, tp1Offset.leaderEpoch().get().intValue()); + assertEquals(-1L, tp1Offset.timestamp()); + } + } + + @Test + public void testListOffsetsUnsupportedNonMaxTimestamp() { + Node node = new Node(0, "localhost", 8120); + List nodes = Collections.singletonList(node); + List pInfos = new ArrayList<>(); + pInfos.add(new PartitionInfo("foo", 0, node, new Node[]{node}, new Node[]{node})); + final Cluster cluster = new Cluster( + "mockClusterId", + nodes, + pInfos, + Collections.emptySet(), + Collections.emptySet(), + node); + final TopicPartition tp0 = new TopicPartition("foo", 0); + + try (AdminClientUnitTestEnv env = new AdminClientUnitTestEnv(cluster, + AdminClientConfig.RETRIES_CONFIG, "2")) { + + env.kafkaClient().setNodeApiVersions(NodeApiVersions.create( + ApiKeys.LIST_OFFSETS.id, (short) 0, (short) 0)); + env.kafkaClient().prepareResponse(prepareMetadataResponse(cluster, Errors.NONE)); + + // listoffsets response from broker 0 + env.kafkaClient().prepareUnsupportedVersionResponse( + request -> request instanceof ListOffsetsRequest); + + ListOffsetsResult result = env.adminClient().listOffsets( + Collections.singletonMap(tp0, OffsetSpec.latest())); + + TestUtils.assertFutureThrows(result.partitionResult(tp0), 
UnsupportedVersionException.class); + } + } + + @Test + public void testListOffsetsNonMaxTimestampDowngradedImmediately() throws Exception { + Node node = new Node(0, "localhost", 8120); + List nodes = Collections.singletonList(node); + List pInfos = new ArrayList<>(); + pInfos.add(new PartitionInfo("foo", 0, node, new Node[]{node}, new Node[]{node})); + final Cluster cluster = new Cluster( + "mockClusterId", + nodes, + pInfos, + Collections.emptySet(), + Collections.emptySet(), + node); + final TopicPartition tp0 = new TopicPartition("foo", 0); + + try (AdminClientUnitTestEnv env = new AdminClientUnitTestEnv(cluster, + AdminClientConfig.RETRIES_CONFIG, "2")) { + + env.kafkaClient().setNodeApiVersions(NodeApiVersions.create( + ApiKeys.LIST_OFFSETS.id, (short) 0, (short) 6)); + + env.kafkaClient().prepareResponse(prepareMetadataResponse(cluster, Errors.NONE)); + + ListOffsetsTopicResponse t0 = ListOffsetsResponse.singletonListOffsetsTopicResponse(tp0, Errors.NONE, -1L, 123L, 321); + ListOffsetsResponseData responseData = new ListOffsetsResponseData() + .setThrottleTimeMs(0) + .setTopics(Arrays.asList(t0)); + + // listoffsets response from broker 0 + env.kafkaClient().prepareResponse( + request -> request instanceof ListOffsetsRequest, + new ListOffsetsResponse(responseData)); + + ListOffsetsResult result = env.adminClient().listOffsets( + Collections.singletonMap(tp0, OffsetSpec.latest())); + + ListOffsetsResultInfo tp0Offset = result.partitionResult(tp0).get(); + assertEquals(123L, tp0Offset.offset()); + assertEquals(321, tp0Offset.leaderEpoch().get().intValue()); + assertEquals(-1L, tp0Offset.timestamp()); + } + } + + private Map makeTestFeatureUpdates() { + return Utils.mkMap( + Utils.mkEntry("test_feature_1", new FeatureUpdate((short) 2, false)), + Utils.mkEntry("test_feature_2", new FeatureUpdate((short) 3, true))); + } + + private Map makeTestFeatureUpdateErrors(final Map updates, final Errors error) { + final Map errors = new HashMap<>(); + for (Map.Entry entry : updates.entrySet()) { + errors.put(entry.getKey(), new ApiError(error)); + } + return errors; + } + + private void testUpdateFeatures(Map featureUpdates, + ApiError topLevelError, + Map featureUpdateErrors) throws Exception { + try (final AdminClientUnitTestEnv env = mockClientEnv()) { + env.kafkaClient().prepareResponse( + body -> body instanceof UpdateFeaturesRequest, + UpdateFeaturesResponse.createWithErrors(topLevelError, featureUpdateErrors, 0)); + final Map> futures = env.adminClient().updateFeatures( + featureUpdates, + new UpdateFeaturesOptions().timeoutMs(10000)).values(); + for (final Map.Entry> entry : futures.entrySet()) { + final KafkaFuture future = entry.getValue(); + final ApiError error = featureUpdateErrors.get(entry.getKey()); + if (topLevelError.error() == Errors.NONE) { + assertNotNull(error); + if (error.error() == Errors.NONE) { + future.get(); + } else { + final ExecutionException e = assertThrows(ExecutionException.class, future::get); + assertEquals(e.getCause().getClass(), error.exception().getClass()); + } + } else { + final ExecutionException e = assertThrows(ExecutionException.class, future::get); + assertEquals(e.getCause().getClass(), topLevelError.exception().getClass()); + } + } + } + } + + @Test + public void testUpdateFeaturesDuringSuccess() throws Exception { + final Map updates = makeTestFeatureUpdates(); + testUpdateFeatures(updates, ApiError.NONE, makeTestFeatureUpdateErrors(updates, Errors.NONE)); + } + + @Test + public void testUpdateFeaturesTopLevelError() throws Exception { + final 
Map updates = makeTestFeatureUpdates(); + testUpdateFeatures(updates, new ApiError(Errors.INVALID_REQUEST), new HashMap<>()); + } + + @Test + public void testUpdateFeaturesInvalidRequestError() throws Exception { + final Map updates = makeTestFeatureUpdates(); + testUpdateFeatures(updates, ApiError.NONE, makeTestFeatureUpdateErrors(updates, Errors.INVALID_REQUEST)); + } + + @Test + public void testUpdateFeaturesUpdateFailedError() throws Exception { + final Map updates = makeTestFeatureUpdates(); + testUpdateFeatures(updates, ApiError.NONE, makeTestFeatureUpdateErrors(updates, Errors.FEATURE_UPDATE_FAILED)); + } + + @Test + public void testUpdateFeaturesPartialSuccess() throws Exception { + final Map errors = makeTestFeatureUpdateErrors(makeTestFeatureUpdates(), Errors.NONE); + errors.put("test_feature_2", new ApiError(Errors.INVALID_REQUEST)); + testUpdateFeatures(makeTestFeatureUpdates(), ApiError.NONE, errors); + } + + @Test + public void testUpdateFeaturesHandleNotControllerException() throws Exception { + try (final AdminClientUnitTestEnv env = mockClientEnv()) { + env.kafkaClient().prepareResponseFrom( + request -> request instanceof UpdateFeaturesRequest, + UpdateFeaturesResponse.createWithErrors( + new ApiError(Errors.NOT_CONTROLLER), + Utils.mkMap(), + 0), + env.cluster().nodeById(0)); + final int controllerId = 1; + env.kafkaClient().prepareResponse(RequestTestUtils.metadataResponse(env.cluster().nodes(), + env.cluster().clusterResource().clusterId(), + controllerId, + Collections.emptyList())); + env.kafkaClient().prepareResponseFrom( + request -> request instanceof UpdateFeaturesRequest, + UpdateFeaturesResponse.createWithErrors( + ApiError.NONE, + Utils.mkMap(Utils.mkEntry("test_feature_1", ApiError.NONE), + Utils.mkEntry("test_feature_2", ApiError.NONE)), + 0), + env.cluster().nodeById(controllerId)); + final KafkaFuture future = env.adminClient().updateFeatures( + Utils.mkMap( + Utils.mkEntry("test_feature_1", new FeatureUpdate((short) 2, false)), + Utils.mkEntry("test_feature_2", new FeatureUpdate((short) 3, true))), + new UpdateFeaturesOptions().timeoutMs(10000) + ).all(); + future.get(); + } + } + + @Test + public void testUpdateFeaturesShouldFailRequestForEmptyUpdates() { + try (final AdminClientUnitTestEnv env = mockClientEnv()) { + assertThrows( + IllegalArgumentException.class, + () -> env.adminClient().updateFeatures( + new HashMap<>(), new UpdateFeaturesOptions())); + } + } + + @Test + public void testUpdateFeaturesShouldFailRequestForInvalidFeatureName() { + try (final AdminClientUnitTestEnv env = mockClientEnv()) { + assertThrows( + IllegalArgumentException.class, + () -> env.adminClient().updateFeatures( + Utils.mkMap(Utils.mkEntry("feature", new FeatureUpdate((short) 2, false)), + Utils.mkEntry("", new FeatureUpdate((short) 2, false))), + new UpdateFeaturesOptions())); + } + } + + @Test + public void testUpdateFeaturesShouldFailRequestInClientWhenDowngradeFlagIsNotSetDuringDeletion() { + assertThrows( + IllegalArgumentException.class, + () -> new FeatureUpdate((short) 0, false)); + } + + @Test + public void testDescribeFeaturesSuccess() throws Exception { + try (final AdminClientUnitTestEnv env = mockClientEnv()) { + env.kafkaClient().prepareResponse( + body -> body instanceof ApiVersionsRequest, + prepareApiVersionsResponseForDescribeFeatures(Errors.NONE)); + final KafkaFuture future = env.adminClient().describeFeatures( + new DescribeFeaturesOptions().timeoutMs(10000)).featureMetadata(); + final FeatureMetadata metadata = future.get(); + 
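+            // The mocked ApiVersions response is expected to translate into the default feature metadata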
assertEquals(defaultFeatureMetadata(), metadata); + } + } + + @Test + public void testDescribeFeaturesFailure() { + try (final AdminClientUnitTestEnv env = mockClientEnv()) { + env.kafkaClient().prepareResponse( + body -> body instanceof ApiVersionsRequest, + prepareApiVersionsResponseForDescribeFeatures(Errors.INVALID_REQUEST)); + final DescribeFeaturesOptions options = new DescribeFeaturesOptions(); + options.timeoutMs(10000); + final KafkaFuture future = env.adminClient().describeFeatures(options).featureMetadata(); + final ExecutionException e = assertThrows(ExecutionException.class, future::get); + assertEquals(e.getCause().getClass(), Errors.INVALID_REQUEST.exception().getClass()); + } + } + + @Test + public void testListOffsetsMetadataRetriableErrors() throws Exception { + + Node node0 = new Node(0, "localhost", 8120); + Node node1 = new Node(1, "localhost", 8121); + List nodes = Arrays.asList(node0, node1); + List pInfos = new ArrayList<>(); + pInfos.add(new PartitionInfo("foo", 0, node0, new Node[]{node0}, new Node[]{node0})); + pInfos.add(new PartitionInfo("foo", 1, node1, new Node[]{node1}, new Node[]{node1})); + final Cluster cluster = + new Cluster( + "mockClusterId", + nodes, + pInfos, + Collections.emptySet(), + Collections.emptySet(), + node0); + + final TopicPartition tp0 = new TopicPartition("foo", 0); + final TopicPartition tp1 = new TopicPartition("foo", 1); + + try (AdminClientUnitTestEnv env = new AdminClientUnitTestEnv(cluster)) { + env.kafkaClient().setNodeApiVersions(NodeApiVersions.create()); + + env.kafkaClient().prepareResponse(prepareMetadataResponse(cluster, Errors.LEADER_NOT_AVAILABLE)); + env.kafkaClient().prepareResponse(prepareMetadataResponse(cluster, Errors.UNKNOWN_TOPIC_OR_PARTITION)); + env.kafkaClient().prepareResponse(prepareMetadataResponse(cluster, Errors.NONE)); + + // listoffsets response from broker 0 + ListOffsetsTopicResponse t0 = ListOffsetsResponse.singletonListOffsetsTopicResponse(tp0, Errors.NONE, -1L, 345L, 543); + ListOffsetsResponseData responseData = new ListOffsetsResponseData() + .setThrottleTimeMs(0) + .setTopics(Arrays.asList(t0)); + env.kafkaClient().prepareResponseFrom(new ListOffsetsResponse(responseData), node0); + // listoffsets response from broker 1 + ListOffsetsTopicResponse t1 = ListOffsetsResponse.singletonListOffsetsTopicResponse(tp1, Errors.NONE, -1L, 789L, 987); + responseData = new ListOffsetsResponseData() + .setThrottleTimeMs(0) + .setTopics(Arrays.asList(t1)); + env.kafkaClient().prepareResponseFrom(new ListOffsetsResponse(responseData), node1); + + Map partitions = new HashMap<>(); + partitions.put(tp0, OffsetSpec.latest()); + partitions.put(tp1, OffsetSpec.latest()); + ListOffsetsResult result = env.adminClient().listOffsets(partitions); + + Map offsets = result.all().get(); + assertFalse(offsets.isEmpty()); + assertEquals(345L, offsets.get(tp0).offset()); + assertEquals(543, offsets.get(tp0).leaderEpoch().get().intValue()); + assertEquals(-1L, offsets.get(tp0).timestamp()); + assertEquals(789L, offsets.get(tp1).offset()); + assertEquals(987, offsets.get(tp1).leaderEpoch().get().intValue()); + assertEquals(-1L, offsets.get(tp1).timestamp()); + } + } + + @Test + public void testListOffsetsWithMultiplePartitionsLeaderChange() throws Exception { + Node node0 = new Node(0, "localhost", 8120); + Node node1 = new Node(1, "localhost", 8121); + Node node2 = new Node(2, "localhost", 8122); + List nodes = Arrays.asList(node0, node1, node2); + + final PartitionInfo oldPInfo1 = new PartitionInfo("foo", 0, node0, + new 
Node[]{node0, node1, node2}, new Node[]{node0, node1, node2}); + final PartitionInfo oldPnfo2 = new PartitionInfo("foo", 1, node0, + new Node[]{node0, node1, node2}, new Node[]{node0, node1, node2}); + List oldPInfos = Arrays.asList(oldPInfo1, oldPnfo2); + + final Cluster oldCluster = new Cluster("mockClusterId", nodes, oldPInfos, + Collections.emptySet(), Collections.emptySet(), node0); + final TopicPartition tp0 = new TopicPartition("foo", 0); + final TopicPartition tp1 = new TopicPartition("foo", 1); + + try (AdminClientUnitTestEnv env = new AdminClientUnitTestEnv(oldCluster)) { + env.kafkaClient().setNodeApiVersions(NodeApiVersions.create()); + + env.kafkaClient().prepareResponse(prepareMetadataResponse(oldCluster, Errors.NONE)); + + ListOffsetsTopicResponse t0 = ListOffsetsResponse.singletonListOffsetsTopicResponse(tp0, Errors.NOT_LEADER_OR_FOLLOWER, -1L, 345L, 543); + ListOffsetsTopicResponse t1 = ListOffsetsResponse.singletonListOffsetsTopicResponse(tp1, Errors.LEADER_NOT_AVAILABLE, -2L, 123L, 456); + ListOffsetsResponseData responseData = new ListOffsetsResponseData() + .setThrottleTimeMs(0) + .setTopics(Arrays.asList(t0, t1)); + env.kafkaClient().prepareResponseFrom(new ListOffsetsResponse(responseData), node0); + + final PartitionInfo newPInfo1 = new PartitionInfo("foo", 0, node1, + new Node[]{node0, node1, node2}, new Node[]{node0, node1, node2}); + final PartitionInfo newPInfo2 = new PartitionInfo("foo", 1, node2, + new Node[]{node0, node1, node2}, new Node[]{node0, node1, node2}); + List newPInfos = Arrays.asList(newPInfo1, newPInfo2); + + final Cluster newCluster = new Cluster("mockClusterId", nodes, newPInfos, + Collections.emptySet(), Collections.emptySet(), node0); + + env.kafkaClient().prepareResponse(prepareMetadataResponse(newCluster, Errors.NONE)); + + t0 = ListOffsetsResponse.singletonListOffsetsTopicResponse(tp0, Errors.NONE, -1L, 345L, 543); + responseData = new ListOffsetsResponseData() + .setThrottleTimeMs(0) + .setTopics(Arrays.asList(t0)); + env.kafkaClient().prepareResponseFrom(new ListOffsetsResponse(responseData), node1); + + t1 = ListOffsetsResponse.singletonListOffsetsTopicResponse(tp1, Errors.NONE, -2L, 123L, 456); + responseData = new ListOffsetsResponseData() + .setThrottleTimeMs(0) + .setTopics(Arrays.asList(t1)); + env.kafkaClient().prepareResponseFrom(new ListOffsetsResponse(responseData), node2); + + Map partitions = new HashMap<>(); + partitions.put(tp0, OffsetSpec.latest()); + partitions.put(tp1, OffsetSpec.latest()); + ListOffsetsResult result = env.adminClient().listOffsets(partitions); + Map offsets = result.all().get(); + + assertFalse(offsets.isEmpty()); + assertEquals(345L, offsets.get(tp0).offset()); + assertEquals(543, offsets.get(tp0).leaderEpoch().get().intValue()); + assertEquals(-1L, offsets.get(tp0).timestamp()); + assertEquals(123L, offsets.get(tp1).offset()); + assertEquals(456, offsets.get(tp1).leaderEpoch().get().intValue()); + assertEquals(-2L, offsets.get(tp1).timestamp()); + } + } + + @Test + public void testListOffsetsWithLeaderChange() throws Exception { + Node node0 = new Node(0, "localhost", 8120); + Node node1 = new Node(1, "localhost", 8121); + Node node2 = new Node(2, "localhost", 8122); + List nodes = Arrays.asList(node0, node1, node2); + + final PartitionInfo oldPartitionInfo = new PartitionInfo("foo", 0, node0, + new Node[]{node0, node1, node2}, new Node[]{node0, node1, node2}); + final Cluster oldCluster = new Cluster("mockClusterId", nodes, singletonList(oldPartitionInfo), + Collections.emptySet(), 
Collections.emptySet(), node0); + final TopicPartition tp0 = new TopicPartition("foo", 0); + + try (AdminClientUnitTestEnv env = new AdminClientUnitTestEnv(oldCluster)) { + env.kafkaClient().setNodeApiVersions(NodeApiVersions.create()); + + env.kafkaClient().prepareResponse(prepareMetadataResponse(oldCluster, Errors.NONE)); + + ListOffsetsTopicResponse t0 = ListOffsetsResponse.singletonListOffsetsTopicResponse(tp0, Errors.NOT_LEADER_OR_FOLLOWER, -1L, 345L, 543); + ListOffsetsResponseData responseData = new ListOffsetsResponseData() + .setThrottleTimeMs(0) + .setTopics(Arrays.asList(t0)); + env.kafkaClient().prepareResponseFrom(new ListOffsetsResponse(responseData), node0); + + // updating leader from node0 to node1 and metadata refresh because of NOT_LEADER_OR_FOLLOWER + final PartitionInfo newPartitionInfo = new PartitionInfo("foo", 0, node1, + new Node[]{node0, node1, node2}, new Node[]{node0, node1, node2}); + final Cluster newCluster = new Cluster("mockClusterId", nodes, singletonList(newPartitionInfo), + Collections.emptySet(), Collections.emptySet(), node0); + + env.kafkaClient().prepareResponse(prepareMetadataResponse(newCluster, Errors.NONE)); + + t0 = ListOffsetsResponse.singletonListOffsetsTopicResponse(tp0, Errors.NONE, -2L, 123L, 456); + responseData = new ListOffsetsResponseData() + .setThrottleTimeMs(0) + .setTopics(Arrays.asList(t0)); + env.kafkaClient().prepareResponseFrom(new ListOffsetsResponse(responseData), node1); + + Map partitions = new HashMap<>(); + partitions.put(tp0, OffsetSpec.latest()); + ListOffsetsResult result = env.adminClient().listOffsets(partitions); + Map offsets = result.all().get(); + + assertFalse(offsets.isEmpty()); + assertEquals(123L, offsets.get(tp0).offset()); + assertEquals(456, offsets.get(tp0).leaderEpoch().get().intValue()); + assertEquals(-2L, offsets.get(tp0).timestamp()); + } + } + + @Test + public void testListOffsetsMetadataNonRetriableErrors() throws Exception { + + Node node0 = new Node(0, "localhost", 8120); + Node node1 = new Node(1, "localhost", 8121); + List nodes = Arrays.asList(node0, node1); + List pInfos = new ArrayList<>(); + pInfos.add(new PartitionInfo("foo", 0, node0, new Node[]{node0, node1}, new Node[]{node0, node1})); + final Cluster cluster = + new Cluster( + "mockClusterId", + nodes, + pInfos, + Collections.emptySet(), + Collections.emptySet(), + node0); + + final TopicPartition tp1 = new TopicPartition("foo", 0); + + try (AdminClientUnitTestEnv env = new AdminClientUnitTestEnv(cluster)) { + env.kafkaClient().setNodeApiVersions(NodeApiVersions.create()); + + env.kafkaClient().prepareResponse(prepareMetadataResponse(cluster, Errors.TOPIC_AUTHORIZATION_FAILED)); + + Map partitions = new HashMap<>(); + partitions.put(tp1, OffsetSpec.latest()); + ListOffsetsResult result = env.adminClient().listOffsets(partitions); + + TestUtils.assertFutureError(result.all(), TopicAuthorizationException.class); + } + } + + @Test + public void testListOffsetsPartialResponse() throws Exception { + Node node0 = new Node(0, "localhost", 8120); + Node node1 = new Node(1, "localhost", 8121); + List nodes = Arrays.asList(node0, node1); + List pInfos = new ArrayList<>(); + pInfos.add(new PartitionInfo("foo", 0, node0, new Node[]{node0, node1}, new Node[]{node0, node1})); + pInfos.add(new PartitionInfo("foo", 1, node0, new Node[]{node0, node1}, new Node[]{node0, node1})); + final Cluster cluster = + new Cluster( + "mockClusterId", + nodes, + pInfos, + Collections.emptySet(), + Collections.emptySet(), + node0); + + final TopicPartition tp0 = new 
TopicPartition("foo", 0); + final TopicPartition tp1 = new TopicPartition("foo", 1); + + try (AdminClientUnitTestEnv env = new AdminClientUnitTestEnv(cluster)) { + env.kafkaClient().setNodeApiVersions(NodeApiVersions.create()); + + env.kafkaClient().prepareResponse(prepareMetadataResponse(cluster, Errors.NONE)); + + ListOffsetsTopicResponse t0 = ListOffsetsResponse.singletonListOffsetsTopicResponse(tp0, Errors.NONE, -2L, 123L, 456); + ListOffsetsResponseData data = new ListOffsetsResponseData() + .setThrottleTimeMs(0) + .setTopics(Arrays.asList(t0)); + env.kafkaClient().prepareResponseFrom(new ListOffsetsResponse(data), node0); + + Map partitions = new HashMap<>(); + partitions.put(tp0, OffsetSpec.latest()); + partitions.put(tp1, OffsetSpec.latest()); + ListOffsetsResult result = env.adminClient().listOffsets(partitions); + assertNotNull(result.partitionResult(tp0).get()); + TestUtils.assertFutureThrows(result.partitionResult(tp1), ApiException.class); + TestUtils.assertFutureThrows(result.all(), ApiException.class); + } + } + + @Test + public void testGetSubLevelError() { + List memberIdentities = Arrays.asList( + new MemberIdentity().setGroupInstanceId("instance-0"), + new MemberIdentity().setGroupInstanceId("instance-1")); + Map errorsMap = new HashMap<>(); + errorsMap.put(memberIdentities.get(0), Errors.NONE); + errorsMap.put(memberIdentities.get(1), Errors.FENCED_INSTANCE_ID); + assertEquals(IllegalArgumentException.class, KafkaAdminClient.getSubLevelError(errorsMap, + new MemberIdentity().setGroupInstanceId("non-exist-id"), "For unit test").getClass()); + assertNull(KafkaAdminClient.getSubLevelError(errorsMap, memberIdentities.get(0), "For unit test")); + assertEquals(FencedInstanceIdException.class, KafkaAdminClient.getSubLevelError( + errorsMap, memberIdentities.get(1), "For unit test").getClass()); + } + + @Test + public void testSuccessfulRetryAfterRequestTimeout() throws Exception { + HashMap nodes = new HashMap<>(); + MockTime time = new MockTime(); + Node node0 = new Node(0, "localhost", 8121); + nodes.put(0, node0); + Cluster cluster = new Cluster("mockClusterId", nodes.values(), + Arrays.asList(new PartitionInfo("foo", 0, node0, new Node[]{node0}, new Node[]{node0})), + Collections.emptySet(), Collections.emptySet(), + Collections.emptySet(), nodes.get(0)); + + final int requestTimeoutMs = 1000; + final int retryBackoffMs = 100; + final int apiTimeoutMs = 3000; + + try (AdminClientUnitTestEnv env = new AdminClientUnitTestEnv(time, cluster, + AdminClientConfig.RETRY_BACKOFF_MS_CONFIG, String.valueOf(retryBackoffMs), + AdminClientConfig.REQUEST_TIMEOUT_MS_CONFIG, String.valueOf(requestTimeoutMs))) { + env.kafkaClient().setNodeApiVersions(NodeApiVersions.create()); + + final ListTopicsResult result = env.adminClient() + .listTopics(new ListTopicsOptions().timeoutMs(apiTimeoutMs)); + + // Wait until the first attempt has been sent, then advance the time + TestUtils.waitForCondition(() -> env.kafkaClient().hasInFlightRequests(), + "Timed out waiting for Metadata request to be sent"); + time.sleep(requestTimeoutMs + 1); + + // Wait for the request to be timed out before backing off + TestUtils.waitForCondition(() -> !env.kafkaClient().hasInFlightRequests(), + "Timed out waiting for inFlightRequests to be timed out"); + time.sleep(retryBackoffMs); + + // Since api timeout bound is not hit, AdminClient should retry + TestUtils.waitForCondition(() -> env.kafkaClient().hasInFlightRequests(), + "Failed to retry Metadata request"); + 
env.kafkaClient().respond(prepareMetadataResponse(cluster, Errors.NONE)); + + assertEquals(1, result.listings().get().size()); + assertEquals("foo", result.listings().get().iterator().next().name()); + } + } + + @Test + public void testDefaultApiTimeout() throws Exception { + testApiTimeout(1500, 3000, OptionalInt.empty()); + } + + @Test + public void testDefaultApiTimeoutOverride() throws Exception { + testApiTimeout(1500, 10000, OptionalInt.of(3000)); + } + + private void testApiTimeout(int requestTimeoutMs, + int defaultApiTimeoutMs, + OptionalInt overrideApiTimeoutMs) throws Exception { + HashMap nodes = new HashMap<>(); + MockTime time = new MockTime(); + Node node0 = new Node(0, "localhost", 8121); + nodes.put(0, node0); + Cluster cluster = new Cluster("mockClusterId", nodes.values(), + Arrays.asList(new PartitionInfo("foo", 0, node0, new Node[]{node0}, new Node[]{node0})), + Collections.emptySet(), Collections.emptySet(), + Collections.emptySet(), nodes.get(0)); + + final int retryBackoffMs = 100; + final int effectiveTimeoutMs = overrideApiTimeoutMs.orElse(defaultApiTimeoutMs); + assertEquals(2 * requestTimeoutMs, effectiveTimeoutMs, + "This test expects the effective timeout to be twice the request timeout"); + + try (AdminClientUnitTestEnv env = new AdminClientUnitTestEnv(time, cluster, + AdminClientConfig.RETRY_BACKOFF_MS_CONFIG, String.valueOf(retryBackoffMs), + AdminClientConfig.REQUEST_TIMEOUT_MS_CONFIG, String.valueOf(requestTimeoutMs), + AdminClientConfig.DEFAULT_API_TIMEOUT_MS_CONFIG, String.valueOf(defaultApiTimeoutMs))) { + env.kafkaClient().setNodeApiVersions(NodeApiVersions.create()); + + ListTopicsOptions options = new ListTopicsOptions(); + overrideApiTimeoutMs.ifPresent(options::timeoutMs); + + final ListTopicsResult result = env.adminClient().listTopics(options); + + // Wait until the first attempt has been sent, then advance the time + TestUtils.waitForCondition(() -> env.kafkaClient().hasInFlightRequests(), + "Timed out waiting for Metadata request to be sent"); + time.sleep(requestTimeoutMs + 1); + + // Wait for the request to be timed out before backing off + TestUtils.waitForCondition(() -> !env.kafkaClient().hasInFlightRequests(), + "Timed out waiting for inFlightRequests to be timed out"); + + // Since api timeout bound is not hit, AdminClient should retry + TestUtils.waitForCondition(() -> { + boolean hasInflightRequests = env.kafkaClient().hasInFlightRequests(); + if (!hasInflightRequests) + time.sleep(retryBackoffMs); + return hasInflightRequests; + }, "Timed out waiting for Metadata request to be sent"); + time.sleep(requestTimeoutMs + 1); + + TestUtils.assertFutureThrows(result.future, TimeoutException.class); + } + } + + @Test + public void testRequestTimeoutExceedingDefaultApiTimeout() throws Exception { + HashMap nodes = new HashMap<>(); + MockTime time = new MockTime(); + Node node0 = new Node(0, "localhost", 8121); + nodes.put(0, node0); + Cluster cluster = new Cluster("mockClusterId", nodes.values(), + Arrays.asList(new PartitionInfo("foo", 0, node0, new Node[]{node0}, new Node[]{node0})), + Collections.emptySet(), Collections.emptySet(), + Collections.emptySet(), nodes.get(0)); + + // This test assumes the default api timeout value of 60000. When the request timeout + // is set to something larger, we should adjust the api timeout accordingly for compatibility. 
+ + final int retryBackoffMs = 100; + final int requestTimeoutMs = 120000; + + try (AdminClientUnitTestEnv env = new AdminClientUnitTestEnv(time, cluster, + AdminClientConfig.RETRY_BACKOFF_MS_CONFIG, String.valueOf(retryBackoffMs), + AdminClientConfig.REQUEST_TIMEOUT_MS_CONFIG, String.valueOf(requestTimeoutMs))) { + env.kafkaClient().setNodeApiVersions(NodeApiVersions.create()); + + ListTopicsOptions options = new ListTopicsOptions(); + + final ListTopicsResult result = env.adminClient().listTopics(options); + + // Wait until the first attempt has been sent, then advance the time by the default api timeout + TestUtils.waitForCondition(() -> env.kafkaClient().hasInFlightRequests(), + "Timed out waiting for Metadata request to be sent"); + time.sleep(60001); + + // The in-flight request should not be cancelled + assertTrue(env.kafkaClient().hasInFlightRequests()); + + // Now sleep the remaining time for the request timeout to expire + time.sleep(60000); + TestUtils.assertFutureThrows(result.future, TimeoutException.class); + } + } + + private ClientQuotaEntity newClientQuotaEntity(String... args) { + assertTrue(args.length % 2 == 0); + + Map entityMap = new HashMap<>(args.length / 2); + for (int index = 0; index < args.length; index += 2) { + entityMap.put(args[index], args[index + 1]); + } + return new ClientQuotaEntity(entityMap); + } + + @Test + public void testDescribeClientQuotas() throws Exception { + try (AdminClientUnitTestEnv env = mockClientEnv()) { + env.kafkaClient().setNodeApiVersions(NodeApiVersions.create()); + + final String value = "value"; + + Map> responseData = new HashMap<>(); + ClientQuotaEntity entity1 = newClientQuotaEntity(ClientQuotaEntity.USER, "user-1", ClientQuotaEntity.CLIENT_ID, value); + ClientQuotaEntity entity2 = newClientQuotaEntity(ClientQuotaEntity.USER, "user-2", ClientQuotaEntity.CLIENT_ID, value); + responseData.put(entity1, Collections.singletonMap("consumer_byte_rate", 10000.0)); + responseData.put(entity2, Collections.singletonMap("producer_byte_rate", 20000.0)); + + env.kafkaClient().prepareResponse(DescribeClientQuotasResponse.fromQuotaEntities(responseData, 0)); + + ClientQuotaFilter filter = ClientQuotaFilter.contains(asList(ClientQuotaFilterComponent.ofEntity(ClientQuotaEntity.USER, value))); + + DescribeClientQuotasResult result = env.adminClient().describeClientQuotas(filter); + Map> resultData = result.entities().get(); + assertEquals(resultData.size(), 2); + assertTrue(resultData.containsKey(entity1)); + Map config1 = resultData.get(entity1); + assertEquals(config1.size(), 1); + assertEquals(config1.get("consumer_byte_rate"), 10000.0, 1e-6); + assertTrue(resultData.containsKey(entity2)); + Map config2 = resultData.get(entity2); + assertEquals(config2.size(), 1); + assertEquals(config2.get("producer_byte_rate"), 20000.0, 1e-6); + } + } + + @Test + public void testEqualsOfClientQuotaFilterComponent() { + assertEquals(ClientQuotaFilterComponent.ofDefaultEntity(ClientQuotaEntity.USER), + ClientQuotaFilterComponent.ofDefaultEntity(ClientQuotaEntity.USER)); + + assertEquals(ClientQuotaFilterComponent.ofEntityType(ClientQuotaEntity.USER), + ClientQuotaFilterComponent.ofEntityType(ClientQuotaEntity.USER)); + + // match = null is different from match = Empty + assertNotEquals(ClientQuotaFilterComponent.ofDefaultEntity(ClientQuotaEntity.USER), + ClientQuotaFilterComponent.ofEntityType(ClientQuotaEntity.USER)); + + assertEquals(ClientQuotaFilterComponent.ofEntity(ClientQuotaEntity.USER, "user"), + 
ClientQuotaFilterComponent.ofEntity(ClientQuotaEntity.USER, "user")); + + assertNotEquals(ClientQuotaFilterComponent.ofEntity(ClientQuotaEntity.USER, "user"), + ClientQuotaFilterComponent.ofDefaultEntity(ClientQuotaEntity.USER)); + + assertNotEquals(ClientQuotaFilterComponent.ofEntity(ClientQuotaEntity.USER, "user"), + ClientQuotaFilterComponent.ofEntityType(ClientQuotaEntity.USER)); + } + + @Test + public void testAlterClientQuotas() throws Exception { + try (AdminClientUnitTestEnv env = mockClientEnv()) { + env.kafkaClient().setNodeApiVersions(NodeApiVersions.create()); + + ClientQuotaEntity goodEntity = newClientQuotaEntity(ClientQuotaEntity.USER, "user-1"); + ClientQuotaEntity unauthorizedEntity = newClientQuotaEntity(ClientQuotaEntity.USER, "user-0"); + ClientQuotaEntity invalidEntity = newClientQuotaEntity("", "user-0"); + + Map responseData = new HashMap<>(2); + responseData.put(goodEntity, new ApiError(Errors.CLUSTER_AUTHORIZATION_FAILED, "Authorization failed")); + responseData.put(unauthorizedEntity, new ApiError(Errors.CLUSTER_AUTHORIZATION_FAILED, "Authorization failed")); + responseData.put(invalidEntity, new ApiError(Errors.INVALID_REQUEST, "Invalid quota entity")); + + env.kafkaClient().prepareResponse(AlterClientQuotasResponse.fromQuotaEntities(responseData, 0)); + + List entries = new ArrayList<>(3); + entries.add(new ClientQuotaAlteration(goodEntity, singleton(new ClientQuotaAlteration.Op("consumer_byte_rate", 10000.0)))); + entries.add(new ClientQuotaAlteration(unauthorizedEntity, singleton(new ClientQuotaAlteration.Op("producer_byte_rate", 10000.0)))); + entries.add(new ClientQuotaAlteration(invalidEntity, singleton(new ClientQuotaAlteration.Op("producer_byte_rate", 100.0)))); + + AlterClientQuotasResult result = env.adminClient().alterClientQuotas(entries); + result.values().get(goodEntity); + TestUtils.assertFutureError(result.values().get(unauthorizedEntity), ClusterAuthorizationException.class); + TestUtils.assertFutureError(result.values().get(invalidEntity), InvalidRequestException.class); + } + } + + @Test + public void testAlterReplicaLogDirsSuccess() throws Exception { + try (AdminClientUnitTestEnv env = mockClientEnv()) { + createAlterLogDirsResponse(env, env.cluster().nodeById(0), Errors.NONE, 0); + createAlterLogDirsResponse(env, env.cluster().nodeById(1), Errors.NONE, 0); + + TopicPartitionReplica tpr0 = new TopicPartitionReplica("topic", 0, 0); + TopicPartitionReplica tpr1 = new TopicPartitionReplica("topic", 0, 1); + + Map logDirs = new HashMap<>(); + logDirs.put(tpr0, "/data0"); + logDirs.put(tpr1, "/data1"); + AlterReplicaLogDirsResult result = env.adminClient().alterReplicaLogDirs(logDirs); + assertNull(result.values().get(tpr0).get()); + assertNull(result.values().get(tpr1).get()); + } + } + + @Test + public void testAlterReplicaLogDirsLogDirNotFound() throws Exception { + try (AdminClientUnitTestEnv env = mockClientEnv()) { + createAlterLogDirsResponse(env, env.cluster().nodeById(0), Errors.NONE, 0); + createAlterLogDirsResponse(env, env.cluster().nodeById(1), Errors.LOG_DIR_NOT_FOUND, 0); + + TopicPartitionReplica tpr0 = new TopicPartitionReplica("topic", 0, 0); + TopicPartitionReplica tpr1 = new TopicPartitionReplica("topic", 0, 1); + + Map logDirs = new HashMap<>(); + logDirs.put(tpr0, "/data0"); + logDirs.put(tpr1, "/data1"); + AlterReplicaLogDirsResult result = env.adminClient().alterReplicaLogDirs(logDirs); + assertNull(result.values().get(tpr0).get()); + TestUtils.assertFutureError(result.values().get(tpr1), LogDirNotFoundException.class); + } 
+ } + + @Test + public void testAlterReplicaLogDirsUnrequested() throws Exception { + try (AdminClientUnitTestEnv env = mockClientEnv()) { + createAlterLogDirsResponse(env, env.cluster().nodeById(0), Errors.NONE, 1, 2); + + TopicPartitionReplica tpr1 = new TopicPartitionReplica("topic", 1, 0); + + Map logDirs = new HashMap<>(); + logDirs.put(tpr1, "/data1"); + AlterReplicaLogDirsResult result = env.adminClient().alterReplicaLogDirs(logDirs); + assertNull(result.values().get(tpr1).get()); + } + } + + @Test + public void testAlterReplicaLogDirsPartialResponse() throws Exception { + try (AdminClientUnitTestEnv env = mockClientEnv()) { + createAlterLogDirsResponse(env, env.cluster().nodeById(0), Errors.NONE, 1); + + TopicPartitionReplica tpr1 = new TopicPartitionReplica("topic", 1, 0); + TopicPartitionReplica tpr2 = new TopicPartitionReplica("topic", 2, 0); + + Map logDirs = new HashMap<>(); + logDirs.put(tpr1, "/data1"); + logDirs.put(tpr2, "/data1"); + AlterReplicaLogDirsResult result = env.adminClient().alterReplicaLogDirs(logDirs); + assertNull(result.values().get(tpr1).get()); + TestUtils.assertFutureThrows(result.values().get(tpr2), ApiException.class); + } + } + + @Test + public void testAlterReplicaLogDirsPartialFailure() throws Exception { + long defaultApiTimeout = 60000; + MockTime time = new MockTime(); + + try (AdminClientUnitTestEnv env = mockClientEnv(time, AdminClientConfig.RETRIES_CONFIG, "0")) { + + // Provide only one prepared response from node 1 + env.kafkaClient().prepareResponseFrom( + prepareAlterLogDirsResponse(Errors.NONE, "topic", 2), + env.cluster().nodeById(1)); + + TopicPartitionReplica tpr1 = new TopicPartitionReplica("topic", 1, 0); + TopicPartitionReplica tpr2 = new TopicPartitionReplica("topic", 2, 1); + + Map logDirs = new HashMap<>(); + logDirs.put(tpr1, "/data1"); + logDirs.put(tpr2, "/data1"); + + AlterReplicaLogDirsResult result = env.adminClient().alterReplicaLogDirs(logDirs); + + // Wait until the prepared attempt has been consumed + TestUtils.waitForCondition(() -> env.kafkaClient().numAwaitingResponses() == 0, + "Failed awaiting requests"); + + // Wait until the request is sent out + TestUtils.waitForCondition(() -> env.kafkaClient().inFlightRequestCount() == 1, + "Failed awaiting request"); + + // Advance time past the default api timeout to time out the inflight request + time.sleep(defaultApiTimeout + 1); + + TestUtils.assertFutureThrows(result.values().get(tpr1), ApiException.class); + assertNull(result.values().get(tpr2).get()); + } + } + + @Test + public void testDescribeUserScramCredentials() throws Exception { + try (AdminClientUnitTestEnv env = mockClientEnv()) { + env.kafkaClient().setNodeApiVersions(NodeApiVersions.create()); + + final String user0Name = "user0"; + final ScramMechanism user0ScramMechanism0 = ScramMechanism.SCRAM_SHA_256; + final int user0Iterations0 = 4096; + final ScramMechanism user0ScramMechanism1 = ScramMechanism.SCRAM_SHA_512; + final int user0Iterations1 = 8192; + + final CredentialInfo user0CredentialInfo0 = new CredentialInfo(); + user0CredentialInfo0.setMechanism(user0ScramMechanism0.type()); + user0CredentialInfo0.setIterations(user0Iterations0); + final CredentialInfo user0CredentialInfo1 = new CredentialInfo(); + user0CredentialInfo1.setMechanism(user0ScramMechanism1.type()); + user0CredentialInfo1.setIterations(user0Iterations1); + + final String user1Name = "user1"; + final ScramMechanism user1ScramMechanism = ScramMechanism.SCRAM_SHA_256; + final int user1Iterations = 4096; + + final CredentialInfo 
user1CredentialInfo = new CredentialInfo(); + user1CredentialInfo.setMechanism(user1ScramMechanism.type()); + user1CredentialInfo.setIterations(user1Iterations); + + final DescribeUserScramCredentialsResponseData responseData = new DescribeUserScramCredentialsResponseData(); + responseData.setResults(Arrays.asList( + new DescribeUserScramCredentialsResponseData.DescribeUserScramCredentialsResult() + .setUser(user0Name) + .setCredentialInfos(Arrays.asList(user0CredentialInfo0, user0CredentialInfo1)), + new DescribeUserScramCredentialsResponseData.DescribeUserScramCredentialsResult() + .setUser(user1Name) + .setCredentialInfos(singletonList(user1CredentialInfo)))); + final DescribeUserScramCredentialsResponse response = new DescribeUserScramCredentialsResponse(responseData); + + final Set usersRequestedSet = new HashSet<>(); + usersRequestedSet.add(user0Name); + usersRequestedSet.add(user1Name); + + for (final List users : asList(null, new ArrayList(), asList(user0Name, null, user1Name))) { + env.kafkaClient().prepareResponse(response); + + final DescribeUserScramCredentialsResult result = env.adminClient().describeUserScramCredentials(users); + final Map descriptionResults = result.all().get(); + final KafkaFuture user0DescriptionFuture = result.description(user0Name); + final KafkaFuture user1DescriptionFuture = result.description(user1Name); + + final Set usersDescribedFromUsersSet = new HashSet<>(result.users().get()); + assertEquals(usersRequestedSet, usersDescribedFromUsersSet); + + final Set usersDescribedFromMapKeySet = descriptionResults.keySet(); + assertEquals(usersRequestedSet, usersDescribedFromMapKeySet); + + final UserScramCredentialsDescription userScramCredentialsDescription0 = descriptionResults.get(user0Name); + assertEquals(user0Name, userScramCredentialsDescription0.name()); + assertEquals(2, userScramCredentialsDescription0.credentialInfos().size()); + assertEquals(user0ScramMechanism0, userScramCredentialsDescription0.credentialInfos().get(0).mechanism()); + assertEquals(user0Iterations0, userScramCredentialsDescription0.credentialInfos().get(0).iterations()); + assertEquals(user0ScramMechanism1, userScramCredentialsDescription0.credentialInfos().get(1).mechanism()); + assertEquals(user0Iterations1, userScramCredentialsDescription0.credentialInfos().get(1).iterations()); + assertEquals(userScramCredentialsDescription0, user0DescriptionFuture.get()); + + final UserScramCredentialsDescription userScramCredentialsDescription1 = descriptionResults.get(user1Name); + assertEquals(user1Name, userScramCredentialsDescription1.name()); + assertEquals(1, userScramCredentialsDescription1.credentialInfos().size()); + assertEquals(user1ScramMechanism, userScramCredentialsDescription1.credentialInfos().get(0).mechanism()); + assertEquals(user1Iterations, userScramCredentialsDescription1.credentialInfos().get(0).iterations()); + assertEquals(userScramCredentialsDescription1, user1DescriptionFuture.get()); + } + } + } + + @Test + public void testAlterUserScramCredentialsUnknownMechanism() throws Exception { + try (AdminClientUnitTestEnv env = mockClientEnv()) { + env.kafkaClient().setNodeApiVersions(NodeApiVersions.create()); + + final String user0Name = "user0"; + ScramMechanism user0ScramMechanism0 = ScramMechanism.UNKNOWN; + + final String user1Name = "user1"; + ScramMechanism user1ScramMechanism0 = ScramMechanism.UNKNOWN; + + final String user2Name = "user2"; + ScramMechanism user2ScramMechanism0 = ScramMechanism.SCRAM_SHA_256; + + AlterUserScramCredentialsResponseData responseData 
= new AlterUserScramCredentialsResponseData();
+            responseData.setResults(Arrays.asList(
+                new AlterUserScramCredentialsResponseData.AlterUserScramCredentialsResult().setUser(user2Name)));
+
+            env.kafkaClient().prepareResponse(new AlterUserScramCredentialsResponse(responseData));
+
+            AlterUserScramCredentialsResult result = env.adminClient().alterUserScramCredentials(Arrays.asList(
+                new UserScramCredentialDeletion(user0Name, user0ScramMechanism0),
+                new UserScramCredentialUpsertion(user1Name, new ScramCredentialInfo(user1ScramMechanism0, 8192), "password"),
+                new UserScramCredentialUpsertion(user2Name, new ScramCredentialInfo(user2ScramMechanism0, 4096), "password")));
+            Map<String, KafkaFuture<Void>> resultData = result.values();
+            assertEquals(3, resultData.size());
+            Arrays.asList(user0Name, user1Name).stream().forEach(u -> {
+                assertTrue(resultData.containsKey(u));
+                try {
+                    resultData.get(u).get();
+                    fail("Expected request for user " + u + " to complete exceptionally, but it did not");
+                } catch (Exception expected) {
+                    // ignore
+                }
+            });
+            assertTrue(resultData.containsKey(user2Name));
+            try {
+                resultData.get(user2Name).get();
+            } catch (Exception e) {
+                fail("Expected request for user " + user2Name + " to NOT complete exceptionally, but it did");
+            }
+            try {
+                result.all().get();
+                fail("Expected 'result.all().get()' to throw an exception since at least one user failed, but it did not");
+            } catch (final Exception expected) {
+                // ignore, expected
+            }
+        }
+    }
+
+    @Test
+    public void testAlterUserScramCredentials() throws Exception {
+        try (AdminClientUnitTestEnv env = mockClientEnv()) {
+            env.kafkaClient().setNodeApiVersions(NodeApiVersions.create());
+
+            final String user0Name = "user0";
+            ScramMechanism user0ScramMechanism0 = ScramMechanism.SCRAM_SHA_256;
+            ScramMechanism user0ScramMechanism1 = ScramMechanism.SCRAM_SHA_512;
+            final String user1Name = "user1";
+            ScramMechanism user1ScramMechanism0 = ScramMechanism.SCRAM_SHA_256;
+            final String user2Name = "user2";
+            ScramMechanism user2ScramMechanism0 = ScramMechanism.SCRAM_SHA_512;
+            AlterUserScramCredentialsResponseData responseData = new AlterUserScramCredentialsResponseData();
+            responseData.setResults(Arrays.asList(user0Name, user1Name, user2Name).stream().map(u ->
+                new AlterUserScramCredentialsResponseData.AlterUserScramCredentialsResult()
+                .setUser(u).setErrorCode(Errors.NONE.code())).collect(Collectors.toList()));
+
+            env.kafkaClient().prepareResponse(new AlterUserScramCredentialsResponse(responseData));
+
+            AlterUserScramCredentialsResult result = env.adminClient().alterUserScramCredentials(Arrays.asList(
+                new UserScramCredentialDeletion(user0Name, user0ScramMechanism0),
+                new UserScramCredentialUpsertion(user0Name, new ScramCredentialInfo(user0ScramMechanism1, 8192), "password"),
+                new UserScramCredentialUpsertion(user1Name, new ScramCredentialInfo(user1ScramMechanism0, 8192), "password"),
+                new UserScramCredentialDeletion(user2Name, user2ScramMechanism0)));
+            Map<String, KafkaFuture<Void>> resultData = result.values();
+            assertEquals(3, resultData.size());
+            Arrays.asList(user0Name, user1Name, user2Name).stream().forEach(u -> {
+                assertTrue(resultData.containsKey(u));
+                assertFalse(resultData.get(u).isCompletedExceptionally());
+            });
+        }
+    }
+
+    private void createAlterLogDirsResponse(AdminClientUnitTestEnv env, Node node, Errors error, int... partitions) {
+        env.kafkaClient().prepareResponseFrom(
+            prepareAlterLogDirsResponse(error, "topic", partitions), node);
+    }
+
+    private AlterReplicaLogDirsResponse prepareAlterLogDirsResponse(Errors error, String topic, int...
partitions) { + return new AlterReplicaLogDirsResponse( + new AlterReplicaLogDirsResponseData().setResults(singletonList( + new AlterReplicaLogDirTopicResult() + .setTopicName(topic) + .setPartitions(Arrays.stream(partitions).boxed().map(partitionId -> + new AlterReplicaLogDirPartitionResult() + .setPartitionIndex(partitionId) + .setErrorCode(error.code())).collect(Collectors.toList()))))); + } + + @Test + public void testDescribeLogDirsPartialFailure() throws Exception { + long defaultApiTimeout = 60000; + MockTime time = new MockTime(); + + try (AdminClientUnitTestEnv env = mockClientEnv(time, AdminClientConfig.RETRIES_CONFIG, "0")) { + + env.kafkaClient().prepareResponseFrom( + prepareDescribeLogDirsResponse(Errors.NONE, "/data"), + env.cluster().nodeById(1)); + + DescribeLogDirsResult result = env.adminClient().describeLogDirs(Arrays.asList(0, 1)); + + // Wait until the prepared attempt has been consumed + TestUtils.waitForCondition(() -> env.kafkaClient().numAwaitingResponses() == 0, + "Failed awaiting requests"); + + // Wait until the request is sent out + TestUtils.waitForCondition(() -> env.kafkaClient().inFlightRequestCount() == 1, + "Failed awaiting request"); + + // Advance time past the default api timeout to time out the inflight request + time.sleep(defaultApiTimeout + 1); + + TestUtils.assertFutureThrows(result.descriptions().get(0), ApiException.class); + assertNotNull(result.descriptions().get(1).get()); + } + } + + @Test + public void testUnregisterBrokerSuccess() throws InterruptedException, ExecutionException { + int nodeId = 1; + try (final AdminClientUnitTestEnv env = mockClientEnv()) { + env.kafkaClient().setNodeApiVersions( + NodeApiVersions.create(ApiKeys.UNREGISTER_BROKER.id, (short) 0, (short) 0)); + env.kafkaClient().prepareResponse(prepareUnregisterBrokerResponse(Errors.NONE, 0)); + UnregisterBrokerResult result = env.adminClient().unregisterBroker(nodeId); + // Validate response + assertNotNull(result.all()); + result.all().get(); + } + } + + @Test + public void testUnregisterBrokerFailure() { + int nodeId = 1; + try (final AdminClientUnitTestEnv env = mockClientEnv()) { + env.kafkaClient().setNodeApiVersions( + NodeApiVersions.create(ApiKeys.UNREGISTER_BROKER.id, (short) 0, (short) 0)); + env.kafkaClient().prepareResponse(prepareUnregisterBrokerResponse(Errors.UNKNOWN_SERVER_ERROR, 0)); + UnregisterBrokerResult result = env.adminClient().unregisterBroker(nodeId); + // Validate response + assertNotNull(result.all()); + TestUtils.assertFutureThrows(result.all(), Errors.UNKNOWN_SERVER_ERROR.exception().getClass()); + } + } + + @Test + public void testUnregisterBrokerTimeoutAndSuccessRetry() throws ExecutionException, InterruptedException { + int nodeId = 1; + try (final AdminClientUnitTestEnv env = mockClientEnv()) { + env.kafkaClient().setNodeApiVersions( + NodeApiVersions.create(ApiKeys.UNREGISTER_BROKER.id, (short) 0, (short) 0)); + env.kafkaClient().prepareResponse(prepareUnregisterBrokerResponse(Errors.REQUEST_TIMED_OUT, 0)); + env.kafkaClient().prepareResponse(prepareUnregisterBrokerResponse(Errors.NONE, 0)); + + UnregisterBrokerResult result = env.adminClient().unregisterBroker(nodeId); + + // Validate response + assertNotNull(result.all()); + result.all().get(); + } + } + + @Test + public void testUnregisterBrokerTimeoutAndFailureRetry() { + int nodeId = 1; + try (final AdminClientUnitTestEnv env = mockClientEnv()) { + env.kafkaClient().setNodeApiVersions( + NodeApiVersions.create(ApiKeys.UNREGISTER_BROKER.id, (short) 0, (short) 0)); + 
env.kafkaClient().prepareResponse(prepareUnregisterBrokerResponse(Errors.REQUEST_TIMED_OUT, 0)); + env.kafkaClient().prepareResponse(prepareUnregisterBrokerResponse(Errors.UNKNOWN_SERVER_ERROR, 0)); + + UnregisterBrokerResult result = env.adminClient().unregisterBroker(nodeId); + + // Validate response + assertNotNull(result.all()); + TestUtils.assertFutureThrows(result.all(), Errors.UNKNOWN_SERVER_ERROR.exception().getClass()); + } + } + + @Test + public void testUnregisterBrokerTimeoutMaxRetry() { + int nodeId = 1; + try (final AdminClientUnitTestEnv env = mockClientEnv(Time.SYSTEM, AdminClientConfig.RETRIES_CONFIG, "1")) { + env.kafkaClient().setNodeApiVersions( + NodeApiVersions.create(ApiKeys.UNREGISTER_BROKER.id, (short) 0, (short) 0)); + env.kafkaClient().prepareResponse(prepareUnregisterBrokerResponse(Errors.REQUEST_TIMED_OUT, 0)); + env.kafkaClient().prepareResponse(prepareUnregisterBrokerResponse(Errors.REQUEST_TIMED_OUT, 0)); + + UnregisterBrokerResult result = env.adminClient().unregisterBroker(nodeId); + + // Validate response + assertNotNull(result.all()); + TestUtils.assertFutureThrows(result.all(), Errors.REQUEST_TIMED_OUT.exception().getClass()); + } + } + + @Test + public void testUnregisterBrokerTimeoutMaxWait() { + int nodeId = 1; + try (final AdminClientUnitTestEnv env = mockClientEnv()) { + env.kafkaClient().setNodeApiVersions( + NodeApiVersions.create(ApiKeys.UNREGISTER_BROKER.id, (short) 0, (short) 0)); + + UnregisterBrokerOptions options = new UnregisterBrokerOptions(); + options.timeoutMs = 10; + UnregisterBrokerResult result = env.adminClient().unregisterBroker(nodeId, options); + + // Validate response + assertNotNull(result.all()); + TestUtils.assertFutureThrows(result.all(), Errors.REQUEST_TIMED_OUT.exception().getClass()); + } + } + + @Test + public void testDescribeProducers() throws Exception { + try (AdminClientUnitTestEnv env = mockClientEnv()) { + TopicPartition topicPartition = new TopicPartition("foo", 0); + + Node leader = env.cluster().nodes().iterator().next(); + expectMetadataRequest(env, topicPartition, leader); + + List expected = Arrays.asList( + new ProducerState(12345L, 15, 30, env.time().milliseconds(), + OptionalInt.of(99), OptionalLong.empty()), + new ProducerState(12345L, 15, 30, env.time().milliseconds(), + OptionalInt.empty(), OptionalLong.of(23423L)) + ); + + DescribeProducersResponse response = buildDescribeProducersResponse( + topicPartition, + expected + ); + + env.kafkaClient().prepareResponseFrom( + request -> request instanceof DescribeProducersRequest, + response, + leader + ); + + DescribeProducersResult result = env.adminClient().describeProducers(singleton(topicPartition)); + KafkaFuture partitionFuture = + result.partitionResult(topicPartition); + assertEquals(new HashSet<>(expected), new HashSet<>(partitionFuture.get().activeProducers())); + } + } + + @ParameterizedTest + @ValueSource(booleans = {true, false}) + public void testDescribeProducersTimeout(boolean timeoutInMetadataLookup) throws Exception { + MockTime time = new MockTime(); + try (AdminClientUnitTestEnv env = mockClientEnv(time)) { + TopicPartition topicPartition = new TopicPartition("foo", 0); + int requestTimeoutMs = 15000; + + if (!timeoutInMetadataLookup) { + Node leader = env.cluster().nodes().iterator().next(); + expectMetadataRequest(env, topicPartition, leader); + } + + DescribeProducersOptions options = new DescribeProducersOptions().timeoutMs(requestTimeoutMs); + DescribeProducersResult result = env.adminClient().describeProducers( + 
singleton(topicPartition), options); + assertFalse(result.all().isDone()); + + time.sleep(requestTimeoutMs); + TestUtils.waitForCondition(() -> result.all().isDone(), + "Future failed to timeout after expiration of timeout"); + + assertTrue(result.all().isCompletedExceptionally()); + TestUtils.assertFutureThrows(result.all(), TimeoutException.class); + assertFalse(env.kafkaClient().hasInFlightRequests()); + } + } + + @Test + public void testDescribeProducersRetryAfterDisconnect() throws Exception { + MockTime time = new MockTime(); + int retryBackoffMs = 100; + Cluster cluster = mockCluster(3, 0); + Map configOverride = newStrMap(AdminClientConfig.RETRY_BACKOFF_MS_CONFIG, "" + retryBackoffMs); + + try (AdminClientUnitTestEnv env = new AdminClientUnitTestEnv(time, cluster, configOverride)) { + TopicPartition topicPartition = new TopicPartition("foo", 0); + Iterator nodeIterator = env.cluster().nodes().iterator(); + + Node initialLeader = nodeIterator.next(); + expectMetadataRequest(env, topicPartition, initialLeader); + + List expected = Arrays.asList( + new ProducerState(12345L, 15, 30, env.time().milliseconds(), + OptionalInt.of(99), OptionalLong.empty()), + new ProducerState(12345L, 15, 30, env.time().milliseconds(), + OptionalInt.empty(), OptionalLong.of(23423L)) + ); + + DescribeProducersResponse response = buildDescribeProducersResponse( + topicPartition, + expected + ); + + env.kafkaClient().prepareResponseFrom( + request -> { + // We need a sleep here because the client will attempt to + // backoff after the disconnect + env.time().sleep(retryBackoffMs); + return request instanceof DescribeProducersRequest; + }, + response, + initialLeader, + true + ); + + Node retryLeader = nodeIterator.next(); + expectMetadataRequest(env, topicPartition, retryLeader); + + env.kafkaClient().prepareResponseFrom( + request -> request instanceof DescribeProducersRequest, + response, + retryLeader + ); + + DescribeProducersResult result = env.adminClient().describeProducers(singleton(topicPartition)); + KafkaFuture partitionFuture = + result.partitionResult(topicPartition); + assertEquals(new HashSet<>(expected), new HashSet<>(partitionFuture.get().activeProducers())); + } + } + + @Test + public void testDescribeTransactions() throws Exception { + try (AdminClientUnitTestEnv env = mockClientEnv()) { + String transactionalId = "foo"; + Node coordinator = env.cluster().nodes().iterator().next(); + TransactionDescription expected = new TransactionDescription( + coordinator.id(), TransactionState.COMPLETE_COMMIT, 12345L, + 15, 10000L, OptionalLong.empty(), emptySet()); + + env.kafkaClient().prepareResponse( + request -> request instanceof FindCoordinatorRequest, + prepareFindCoordinatorResponse(Errors.NONE, transactionalId, coordinator) + ); + + env.kafkaClient().prepareResponseFrom( + request -> request instanceof DescribeTransactionsRequest, + new DescribeTransactionsResponse(new DescribeTransactionsResponseData().setTransactionStates( + singletonList(new DescribeTransactionsResponseData.TransactionState() + .setErrorCode(Errors.NONE.code()) + .setProducerEpoch((short) expected.producerEpoch()) + .setProducerId(expected.producerId()) + .setTransactionalId(transactionalId) + .setTransactionTimeoutMs(10000) + .setTransactionStartTimeMs(-1) + .setTransactionState(expected.state().toString()) + ) + )), + coordinator + ); + + DescribeTransactionsResult result = env.adminClient().describeTransactions(singleton(transactionalId)); + KafkaFuture future = result.description(transactionalId); + 
assertEquals(expected, future.get()); + } + } + + @Test + public void testRetryDescribeTransactionsAfterNotCoordinatorError() throws Exception { + MockTime time = new MockTime(); + int retryBackoffMs = 100; + Cluster cluster = mockCluster(3, 0); + Map configOverride = newStrMap(AdminClientConfig.RETRY_BACKOFF_MS_CONFIG, "" + retryBackoffMs); + + try (AdminClientUnitTestEnv env = new AdminClientUnitTestEnv(time, cluster, configOverride)) { + String transactionalId = "foo"; + + Iterator nodeIterator = env.cluster().nodes().iterator(); + Node coordinator1 = nodeIterator.next(); + Node coordinator2 = nodeIterator.next(); + + env.kafkaClient().prepareResponse( + request -> request instanceof FindCoordinatorRequest, + new FindCoordinatorResponse(new FindCoordinatorResponseData() + .setCoordinators(Arrays.asList(new FindCoordinatorResponseData.Coordinator() + .setKey(transactionalId) + .setErrorCode(Errors.NONE.code()) + .setNodeId(coordinator1.id()) + .setHost(coordinator1.host()) + .setPort(coordinator1.port())))) + ); + + env.kafkaClient().prepareResponseFrom( + request -> { + if (!(request instanceof DescribeTransactionsRequest)) { + return false; + } else { + // Backoff needed here for the retry of FindCoordinator + time.sleep(retryBackoffMs); + return true; + } + }, + new DescribeTransactionsResponse(new DescribeTransactionsResponseData().setTransactionStates( + singletonList(new DescribeTransactionsResponseData.TransactionState() + .setErrorCode(Errors.NOT_COORDINATOR.code()) + .setTransactionalId(transactionalId) + ) + )), + coordinator1 + ); + + env.kafkaClient().prepareResponse( + request -> request instanceof FindCoordinatorRequest, + new FindCoordinatorResponse(new FindCoordinatorResponseData() + .setCoordinators(Arrays.asList(new FindCoordinatorResponseData.Coordinator() + .setKey(transactionalId) + .setErrorCode(Errors.NONE.code()) + .setNodeId(coordinator2.id()) + .setHost(coordinator2.host()) + .setPort(coordinator2.port())))) + ); + + TransactionDescription expected = new TransactionDescription( + coordinator2.id(), TransactionState.COMPLETE_COMMIT, 12345L, + 15, 10000L, OptionalLong.empty(), emptySet()); + + env.kafkaClient().prepareResponseFrom( + request -> request instanceof DescribeTransactionsRequest, + new DescribeTransactionsResponse(new DescribeTransactionsResponseData().setTransactionStates( + singletonList(new DescribeTransactionsResponseData.TransactionState() + .setErrorCode(Errors.NONE.code()) + .setProducerEpoch((short) expected.producerEpoch()) + .setProducerId(expected.producerId()) + .setTransactionalId(transactionalId) + .setTransactionTimeoutMs(10000) + .setTransactionStartTimeMs(-1) + .setTransactionState(expected.state().toString()) + ) + )), + coordinator2 + ); + + DescribeTransactionsResult result = env.adminClient().describeTransactions(singleton(transactionalId)); + KafkaFuture future = result.description(transactionalId); + assertEquals(expected, future.get()); + } + } + + @Test + public void testAbortTransaction() throws Exception { + try (AdminClientUnitTestEnv env = mockClientEnv()) { + TopicPartition topicPartition = new TopicPartition("foo", 13); + AbortTransactionSpec abortSpec = new AbortTransactionSpec( + topicPartition, 12345L, (short) 15, 200); + Node leader = env.cluster().nodes().iterator().next(); + + expectMetadataRequest(env, topicPartition, leader); + + env.kafkaClient().prepareResponseFrom( + request -> request instanceof WriteTxnMarkersRequest, + writeTxnMarkersResponse(abortSpec, Errors.NONE), + leader + ); + + AbortTransactionResult 
result = env.adminClient().abortTransaction(abortSpec); + assertNull(result.all().get()); + } + } + + @Test + public void testAbortTransactionFindLeaderAfterDisconnect() throws Exception { + MockTime time = new MockTime(); + int retryBackoffMs = 100; + Cluster cluster = mockCluster(3, 0); + Map configOverride = newStrMap(AdminClientConfig.RETRY_BACKOFF_MS_CONFIG, "" + retryBackoffMs); + + try (AdminClientUnitTestEnv env = new AdminClientUnitTestEnv(time, cluster, configOverride)) { + TopicPartition topicPartition = new TopicPartition("foo", 13); + AbortTransactionSpec abortSpec = new AbortTransactionSpec( + topicPartition, 12345L, (short) 15, 200); + Iterator nodeIterator = env.cluster().nodes().iterator(); + Node firstLeader = nodeIterator.next(); + + expectMetadataRequest(env, topicPartition, firstLeader); + + WriteTxnMarkersResponse response = writeTxnMarkersResponse(abortSpec, Errors.NONE); + env.kafkaClient().prepareResponseFrom( + request -> { + // We need a sleep here because the client will attempt to + // backoff after the disconnect + time.sleep(retryBackoffMs); + return request instanceof WriteTxnMarkersRequest; + }, + response, + firstLeader, + true + ); + + Node retryLeader = nodeIterator.next(); + expectMetadataRequest(env, topicPartition, retryLeader); + + env.kafkaClient().prepareResponseFrom( + request -> request instanceof WriteTxnMarkersRequest, + response, + retryLeader + ); + + AbortTransactionResult result = env.adminClient().abortTransaction(abortSpec); + assertNull(result.all().get()); + } + } + + @Test + public void testListTransactions() throws Exception { + try (AdminClientUnitTestEnv env = mockClientEnv()) { + MetadataResponseData.MetadataResponseBrokerCollection brokers = + new MetadataResponseData.MetadataResponseBrokerCollection(); + + env.cluster().nodes().forEach(node -> { + brokers.add(new MetadataResponseData.MetadataResponseBroker() + .setHost(node.host()) + .setNodeId(node.id()) + .setPort(node.port()) + .setRack(node.rack()) + ); + }); + + env.kafkaClient().prepareResponse( + request -> request instanceof MetadataRequest, + new MetadataResponse(new MetadataResponseData().setBrokers(brokers), + MetadataResponseData.HIGHEST_SUPPORTED_VERSION) + ); + + List expected = Arrays.asList( + new TransactionListing("foo", 12345L, TransactionState.ONGOING), + new TransactionListing("bar", 98765L, TransactionState.PREPARE_ABORT), + new TransactionListing("baz", 13579L, TransactionState.COMPLETE_COMMIT) + ); + assertEquals(Utils.mkSet(0, 1, 2), env.cluster().nodes().stream().map(Node::id) + .collect(Collectors.toSet())); + + env.cluster().nodes().forEach(node -> { + ListTransactionsResponseData response = new ListTransactionsResponseData() + .setErrorCode(Errors.NONE.code()); + + TransactionListing listing = expected.get(node.id()); + response.transactionStates().add(new ListTransactionsResponseData.TransactionState() + .setTransactionalId(listing.transactionalId()) + .setProducerId(listing.producerId()) + .setTransactionState(listing.state().toString()) + ); + + env.kafkaClient().prepareResponseFrom( + request -> request instanceof ListTransactionsRequest, + new ListTransactionsResponse(response), + node + ); + }); + + ListTransactionsResult result = env.adminClient().listTransactions(); + assertEquals(new HashSet<>(expected), new HashSet<>(result.all().get())); + } + } + + private WriteTxnMarkersResponse writeTxnMarkersResponse( + AbortTransactionSpec abortSpec, + Errors error + ) { + WriteTxnMarkersResponseData.WritableTxnMarkerPartitionResult partitionResponse = 
+ new WriteTxnMarkersResponseData.WritableTxnMarkerPartitionResult() + .setPartitionIndex(abortSpec.topicPartition().partition()) + .setErrorCode(error.code()); + + WriteTxnMarkersResponseData.WritableTxnMarkerTopicResult topicResponse = + new WriteTxnMarkersResponseData.WritableTxnMarkerTopicResult() + .setName(abortSpec.topicPartition().topic()); + topicResponse.partitions().add(partitionResponse); + + WriteTxnMarkersResponseData.WritableTxnMarkerResult markerResponse = + new WriteTxnMarkersResponseData.WritableTxnMarkerResult() + .setProducerId(abortSpec.producerId()); + markerResponse.topics().add(topicResponse); + + WriteTxnMarkersResponseData response = new WriteTxnMarkersResponseData(); + response.markers().add(markerResponse); + + return new WriteTxnMarkersResponse(response); + } + + private DescribeProducersResponse buildDescribeProducersResponse( + TopicPartition topicPartition, + List producerStates + ) { + DescribeProducersResponseData response = new DescribeProducersResponseData(); + + DescribeProducersResponseData.TopicResponse topicResponse = + new DescribeProducersResponseData.TopicResponse() + .setName(topicPartition.topic()); + response.topics().add(topicResponse); + + DescribeProducersResponseData.PartitionResponse partitionResponse = + new DescribeProducersResponseData.PartitionResponse() + .setPartitionIndex(topicPartition.partition()) + .setErrorCode(Errors.NONE.code()); + topicResponse.partitions().add(partitionResponse); + + partitionResponse.setActiveProducers(producerStates.stream().map(producerState -> + new DescribeProducersResponseData.ProducerState() + .setProducerId(producerState.producerId()) + .setProducerEpoch(producerState.producerEpoch()) + .setCoordinatorEpoch(producerState.coordinatorEpoch().orElse(-1)) + .setLastSequence(producerState.lastSequence()) + .setLastTimestamp(producerState.lastTimestamp()) + .setCurrentTxnStartOffset(producerState.currentTransactionStartOffset().orElse(-1L)) + ).collect(Collectors.toList())); + + return new DescribeProducersResponse(response); + } + + private void expectMetadataRequest( + AdminClientUnitTestEnv env, + TopicPartition topicPartition, + Node leader + ) { + MetadataResponseData.MetadataResponseTopicCollection responseTopics = + new MetadataResponseData.MetadataResponseTopicCollection(); + + MetadataResponseTopic responseTopic = new MetadataResponseTopic() + .setName(topicPartition.topic()) + .setErrorCode(Errors.NONE.code()); + responseTopics.add(responseTopic); + + MetadataResponsePartition responsePartition = new MetadataResponsePartition() + .setErrorCode(Errors.NONE.code()) + .setPartitionIndex(topicPartition.partition()) + .setLeaderId(leader.id()) + .setReplicaNodes(singletonList(leader.id())) + .setIsrNodes(singletonList(leader.id())); + responseTopic.partitions().add(responsePartition); + + env.kafkaClient().prepareResponse( + request -> { + if (!(request instanceof MetadataRequest)) { + return false; + } + MetadataRequest metadataRequest = (MetadataRequest) request; + return metadataRequest.topics().equals(singletonList(topicPartition.topic())); + }, + new MetadataResponse(new MetadataResponseData().setTopics(responseTopics), + MetadataResponseData.HIGHEST_SUPPORTED_VERSION) + ); + } + + /** + * Test that if the client can obtain a node assignment, but can't send to the given + * node, it will disconnect and try a different node. 
+ */ + @Test + public void testClientSideTimeoutAfterFailureToSend() throws Exception { + Cluster cluster = mockCluster(3, 0); + CompletableFuture disconnectFuture = new CompletableFuture<>(); + MockTime time = new MockTime(); + try (final AdminClientUnitTestEnv env = new AdminClientUnitTestEnv(time, cluster, + newStrMap(AdminClientConfig.REQUEST_TIMEOUT_MS_CONFIG, "1", + AdminClientConfig.DEFAULT_API_TIMEOUT_MS_CONFIG, "100000", + AdminClientConfig.RETRY_BACKOFF_MS_CONFIG, "1"))) { + env.kafkaClient().setNodeApiVersions(NodeApiVersions.create()); + for (Node node : cluster.nodes()) { + env.kafkaClient().delayReady(node, 100); + } + + // We use a countdown latch to ensure that we get to the first + // call to `ready` before we increment the time below to trigger + // the disconnect. + CountDownLatch readyLatch = new CountDownLatch(2); + + env.kafkaClient().setDisconnectFuture(disconnectFuture); + env.kafkaClient().setReadyCallback(node -> readyLatch.countDown()); + env.kafkaClient().prepareResponse(prepareMetadataResponse(cluster, Errors.NONE)); + + final ListTopicsResult result = env.adminClient().listTopics(); + + readyLatch.await(TestUtils.DEFAULT_MAX_WAIT_MS, TimeUnit.MILLISECONDS); + log.debug("Advancing clock by 25 ms to trigger client-side disconnect."); + time.sleep(25); + disconnectFuture.get(); + + log.debug("Enabling nodes to send requests again."); + for (Node node : cluster.nodes()) { + env.kafkaClient().delayReady(node, 0); + } + time.sleep(5); + log.info("Waiting for result."); + assertEquals(0, result.listings().get().size()); + } + } + + /** + * Test that if the client can send to a node, but doesn't receive a response, it will + * disconnect and try a different node. + */ + @Test + public void testClientSideTimeoutAfterFailureToReceiveResponse() throws Exception { + Cluster cluster = mockCluster(3, 0); + CompletableFuture disconnectFuture = new CompletableFuture<>(); + MockTime time = new MockTime(); + try (final AdminClientUnitTestEnv env = new AdminClientUnitTestEnv(time, cluster, + newStrMap(AdminClientConfig.REQUEST_TIMEOUT_MS_CONFIG, "1", + AdminClientConfig.DEFAULT_API_TIMEOUT_MS_CONFIG, "100000", + AdminClientConfig.RETRY_BACKOFF_MS_CONFIG, "0"))) { + env.kafkaClient().setNodeApiVersions(NodeApiVersions.create()); + env.kafkaClient().setDisconnectFuture(disconnectFuture); + final ListTopicsResult result = env.adminClient().listTopics(); + TestUtils.waitForCondition(() -> { + time.sleep(1); + return disconnectFuture.isDone(); + }, 5000, 1, () -> "Timed out waiting for expected disconnect"); + assertFalse(disconnectFuture.isCompletedExceptionally()); + assertFalse(result.future.isDone()); + TestUtils.waitForCondition(env.kafkaClient()::hasInFlightRequests, + "Timed out waiting for retry"); + env.kafkaClient().respond(prepareMetadataResponse(cluster, Errors.NONE)); + assertEquals(0, result.listings().get().size()); + } + } + + private UnregisterBrokerResponse prepareUnregisterBrokerResponse(Errors error, int throttleTimeMs) { + return new UnregisterBrokerResponse(new UnregisterBrokerResponseData() + .setErrorCode(error.code()) + .setErrorMessage(error.message()) + .setThrottleTimeMs(throttleTimeMs)); + } + + private DescribeLogDirsResponse prepareDescribeLogDirsResponse(Errors error, String logDir) { + return new DescribeLogDirsResponse(new DescribeLogDirsResponseData() + .setResults(Collections.singletonList( + new DescribeLogDirsResponseData.DescribeLogDirsResult() + .setErrorCode(error.code()) + .setLogDir(logDir)))); + } + + private static MemberDescription 
convertToMemberDescriptions(DescribedGroupMember member, + MemberAssignment assignment) { + return new MemberDescription(member.memberId(), + Optional.ofNullable(member.groupInstanceId()), + member.clientId(), + member.clientHost(), + assignment); + } + + @SafeVarargs + private static void assertCollectionIs(Collection collection, T... elements) { + for (T element : elements) { + assertTrue(collection.contains(element), "Did not find " + element); + } + assertEquals(elements.length, collection.size(), "There are unexpected extra elements in the collection."); + } + + public static KafkaAdminClient createInternal(AdminClientConfig config, KafkaAdminClient.TimeoutProcessorFactory timeoutProcessorFactory) { + return KafkaAdminClient.createInternal(config, timeoutProcessorFactory); + } + + public static class FailureInjectingTimeoutProcessorFactory extends KafkaAdminClient.TimeoutProcessorFactory { + + private int numTries = 0; + + private int failuresInjected = 0; + + @Override + public KafkaAdminClient.TimeoutProcessor create(long now) { + return new FailureInjectingTimeoutProcessor(now); + } + + synchronized boolean shouldInjectFailure() { + numTries++; + if (numTries == 1) { + failuresInjected++; + return true; + } + return false; + } + + public synchronized int failuresInjected() { + return failuresInjected; + } + + public final class FailureInjectingTimeoutProcessor extends KafkaAdminClient.TimeoutProcessor { + public FailureInjectingTimeoutProcessor(long now) { + super(now); + } + + boolean callHasExpired(KafkaAdminClient.Call call) { + if ((!call.isInternal()) && shouldInjectFailure()) { + log.debug("Injecting timeout for {}.", call); + return true; + } else { + boolean ret = super.callHasExpired(call); + log.debug("callHasExpired({}) = {}", call, ret); + return ret; + } + } + } + } +} diff --git a/clients/src/test/java/org/apache/kafka/clients/admin/ListTransactionsResultTest.java b/clients/src/test/java/org/apache/kafka/clients/admin/ListTransactionsResultTest.java new file mode 100644 index 0000000..769c50b --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/clients/admin/ListTransactionsResultTest.java @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.clients.admin; + +import org.apache.kafka.common.KafkaException; +import org.apache.kafka.common.KafkaFuture; +import org.apache.kafka.common.internals.KafkaFutureImpl; +import org.apache.kafka.common.utils.Utils; +import org.junit.jupiter.api.Test; + +import java.util.Collection; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import static java.util.Arrays.asList; +import static java.util.Collections.singletonList; +import static org.apache.kafka.test.TestUtils.assertFutureThrows; +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class ListTransactionsResultTest { + private final KafkaFutureImpl<Map<Integer, KafkaFutureImpl<Collection<TransactionListing>>>> future = + new KafkaFutureImpl<>(); + private final ListTransactionsResult result = new ListTransactionsResult(future); + + @Test + public void testAllFuturesFailIfLookupFails() { + future.completeExceptionally(new KafkaException()); + assertFutureThrows(result.all(), KafkaException.class); + assertFutureThrows(result.allByBrokerId(), KafkaException.class); + assertFutureThrows(result.byBrokerId(), KafkaException.class); + } + + @Test + public void testAllFuturesSucceed() throws Exception { + KafkaFutureImpl<Collection<TransactionListing>> future1 = new KafkaFutureImpl<>(); + KafkaFutureImpl<Collection<TransactionListing>> future2 = new KafkaFutureImpl<>(); + + Map<Integer, KafkaFutureImpl<Collection<TransactionListing>>> brokerFutures = new HashMap<>(); + brokerFutures.put(1, future1); + brokerFutures.put(2, future2); + + future.complete(brokerFutures); + + List<TransactionListing> broker1Listings = asList( + new TransactionListing("foo", 12345L, TransactionState.ONGOING), + new TransactionListing("bar", 98765L, TransactionState.PREPARE_ABORT) + ); + future1.complete(broker1Listings); + + List<TransactionListing> broker2Listings = singletonList( + new TransactionListing("baz", 13579L, TransactionState.COMPLETE_COMMIT) + ); + future2.complete(broker2Listings); + + Map<Integer, KafkaFuture<Collection<TransactionListing>>> resultBrokerFutures = + result.byBrokerId().get(); + + assertEquals(Utils.mkSet(1, 2), resultBrokerFutures.keySet()); + assertEquals(broker1Listings, resultBrokerFutures.get(1).get()); + assertEquals(broker2Listings, resultBrokerFutures.get(2).get()); + assertEquals(broker1Listings, result.allByBrokerId().get().get(1)); + assertEquals(broker2Listings, result.allByBrokerId().get().get(2)); + + Set<TransactionListing> allExpected = new HashSet<>(); + allExpected.addAll(broker1Listings); + allExpected.addAll(broker2Listings); + + assertEquals(allExpected, new HashSet<>(result.all().get())); + } + + @Test + public void testPartialFailure() throws Exception { + KafkaFutureImpl<Collection<TransactionListing>> future1 = new KafkaFutureImpl<>(); + KafkaFutureImpl<Collection<TransactionListing>> future2 = new KafkaFutureImpl<>(); + + Map<Integer, KafkaFutureImpl<Collection<TransactionListing>>> brokerFutures = new HashMap<>(); + brokerFutures.put(1, future1); + brokerFutures.put(2, future2); + + future.complete(brokerFutures); + + List<TransactionListing> broker1Listings = asList( + new TransactionListing("foo", 12345L, TransactionState.ONGOING), + new TransactionListing("bar", 98765L, TransactionState.PREPARE_ABORT) + ); + future1.complete(broker1Listings); + future2.completeExceptionally(new KafkaException()); + + Map<Integer, KafkaFuture<Collection<TransactionListing>>> resultBrokerFutures = + result.byBrokerId().get(); + + // Ensure that the future for broker 1 completes successfully + assertEquals(Utils.mkSet(1, 2), resultBrokerFutures.keySet()); + assertEquals(broker1Listings, resultBrokerFutures.get(1).get()); + + // Everything else should fail + assertFutureThrows(result.all(), KafkaException.class); + assertFutureThrows(result.allByBrokerId(), KafkaException.class); + assertFutureThrows(resultBrokerFutures.get(2), KafkaException.class); + } + +} diff --git
a/clients/src/test/java/org/apache/kafka/clients/admin/MemberDescriptionTest.java b/clients/src/test/java/org/apache/kafka/clients/admin/MemberDescriptionTest.java new file mode 100644 index 0000000..f0140f7 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/clients/admin/MemberDescriptionTest.java @@ -0,0 +1,102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients.admin; + +import org.apache.kafka.common.TopicPartition; +import org.junit.jupiter.api.Test; + +import java.util.Collections; +import java.util.Optional; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotEquals; + +public class MemberDescriptionTest { + + private static final String MEMBER_ID = "member_id"; + private static final Optional INSTANCE_ID = Optional.of("instanceId"); + private static final String CLIENT_ID = "client_id"; + private static final String HOST = "host"; + private static final MemberAssignment ASSIGNMENT; + private static final MemberDescription STATIC_MEMBER_DESCRIPTION; + + static { + ASSIGNMENT = new MemberAssignment(Collections.singleton(new TopicPartition("topic", 1))); + STATIC_MEMBER_DESCRIPTION = new MemberDescription(MEMBER_ID, + INSTANCE_ID, + CLIENT_ID, + HOST, + ASSIGNMENT); + } + + @Test + public void testEqualsWithoutGroupInstanceId() { + MemberDescription dynamicMemberDescription = new MemberDescription(MEMBER_ID, + CLIENT_ID, + HOST, + ASSIGNMENT); + + MemberDescription identityDescription = new MemberDescription(MEMBER_ID, + CLIENT_ID, + HOST, + ASSIGNMENT); + + assertNotEquals(STATIC_MEMBER_DESCRIPTION, dynamicMemberDescription); + assertNotEquals(STATIC_MEMBER_DESCRIPTION.hashCode(), dynamicMemberDescription.hashCode()); + + // Check self equality. + assertEquals(dynamicMemberDescription, dynamicMemberDescription); + assertEquals(dynamicMemberDescription, identityDescription); + assertEquals(dynamicMemberDescription.hashCode(), identityDescription.hashCode()); + } + + @Test + public void testEqualsWithGroupInstanceId() { + // Check self equality. 
+ assertEquals(STATIC_MEMBER_DESCRIPTION, STATIC_MEMBER_DESCRIPTION); + + MemberDescription identityDescription = new MemberDescription(MEMBER_ID, + INSTANCE_ID, + CLIENT_ID, + HOST, + ASSIGNMENT); + + assertEquals(STATIC_MEMBER_DESCRIPTION, identityDescription); + assertEquals(STATIC_MEMBER_DESCRIPTION.hashCode(), identityDescription.hashCode()); + } + + @Test + public void testNonEqual() { + MemberDescription newMemberDescription = new MemberDescription("new_member", + INSTANCE_ID, + CLIENT_ID, + HOST, + ASSIGNMENT); + + assertNotEquals(STATIC_MEMBER_DESCRIPTION, newMemberDescription); + assertNotEquals(STATIC_MEMBER_DESCRIPTION.hashCode(), newMemberDescription.hashCode()); + + MemberDescription newInstanceDescription = new MemberDescription(MEMBER_ID, + Optional.of("new_instance"), + CLIENT_ID, + HOST, + ASSIGNMENT); + + assertNotEquals(STATIC_MEMBER_DESCRIPTION, newInstanceDescription); + assertNotEquals(STATIC_MEMBER_DESCRIPTION.hashCode(), newInstanceDescription.hashCode()); + } +} diff --git a/clients/src/test/java/org/apache/kafka/clients/admin/MockAdminClient.java b/clients/src/test/java/org/apache/kafka/clients/admin/MockAdminClient.java new file mode 100644 index 0000000..473edae --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/clients/admin/MockAdminClient.java @@ -0,0 +1,1064 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.clients.admin; + +import org.apache.kafka.clients.admin.DescribeReplicaLogDirsResult.ReplicaLogDirInfo; +import org.apache.kafka.clients.consumer.OffsetAndMetadata; +import org.apache.kafka.common.ElectionType; +import org.apache.kafka.common.KafkaFuture; +import org.apache.kafka.common.Metric; +import org.apache.kafka.common.MetricName; +import org.apache.kafka.common.Node; +import org.apache.kafka.common.TopicCollection; +import org.apache.kafka.common.TopicCollection.TopicIdCollection; +import org.apache.kafka.common.TopicCollection.TopicNameCollection; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.TopicPartitionInfo; +import org.apache.kafka.common.TopicPartitionReplica; +import org.apache.kafka.common.Uuid; +import org.apache.kafka.common.acl.AclBinding; +import org.apache.kafka.common.acl.AclBindingFilter; +import org.apache.kafka.common.acl.AclOperation; +import org.apache.kafka.common.config.ConfigResource; +import org.apache.kafka.common.errors.InvalidReplicationFactorException; +import org.apache.kafka.common.errors.InvalidRequestException; +import org.apache.kafka.common.errors.KafkaStorageException; +import org.apache.kafka.common.errors.ReplicaNotAvailableException; +import org.apache.kafka.common.errors.TimeoutException; +import org.apache.kafka.common.errors.TopicExistsException; +import org.apache.kafka.common.errors.UnknownTopicIdException; +import org.apache.kafka.common.errors.UnknownTopicOrPartitionException; +import org.apache.kafka.common.errors.UnsupportedVersionException; +import org.apache.kafka.common.internals.KafkaFutureImpl; +import org.apache.kafka.common.quota.ClientQuotaAlteration; +import org.apache.kafka.common.quota.ClientQuotaFilter; +import org.apache.kafka.common.requests.DescribeLogDirsResponse; + +import java.time.Duration; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; + +public class MockAdminClient extends AdminClient { + public static final String DEFAULT_CLUSTER_ID = "I4ZmrWqfT2e-upky_4fdPA"; + + public static final List DEFAULT_LOG_DIRS = + Collections.singletonList("/tmp/kafka-logs"); + + private final List brokers; + private final Map allTopics = new HashMap<>(); + private final Map topicIds = new HashMap<>(); + private final Map topicNames = new HashMap<>(); + private final Map reassignments = + new HashMap<>(); + private final Map replicaMoves = + new HashMap<>(); + private final Map beginningOffsets; + private final Map endOffsets; + private final boolean usingRaftController; + private final String clusterId; + private final List> brokerLogDirs; + private final List> brokerConfigs; + + private Node controller; + private int timeoutNextRequests = 0; + private final int defaultPartitions; + private final int defaultReplicationFactor; + + private Map mockMetrics = new HashMap<>(); + + public static Builder create() { + return new Builder(); + } + + public static class Builder { + private String clusterId = DEFAULT_CLUSTER_ID; + private List brokers = new ArrayList<>(); + private Node controller = null; + private List> brokerLogDirs = new ArrayList<>(); + private Short defaultPartitions; + private boolean usingRaftController = false; + private Integer defaultReplicationFactor; + + public Builder() { + numBrokers(1); + } + + public Builder clusterId(String clusterId) { + this.clusterId = clusterId; + return this; + 
} + + public Builder brokers(List brokers) { + numBrokers(brokers.size()); + this.brokers = brokers; + return this; + } + + public Builder numBrokers(int numBrokers) { + if (brokers.size() >= numBrokers) { + brokers = brokers.subList(0, numBrokers); + brokerLogDirs = brokerLogDirs.subList(0, numBrokers); + } else { + for (int id = brokers.size(); id < numBrokers; id++) { + brokers.add(new Node(id, "localhost", 1000 + id)); + brokerLogDirs.add(DEFAULT_LOG_DIRS); + } + } + return this; + } + + public Builder controller(int index) { + this.controller = brokers.get(index); + return this; + } + + public Builder brokerLogDirs(List> brokerLogDirs) { + this.brokerLogDirs = brokerLogDirs; + return this; + } + + public Builder defaultReplicationFactor(int defaultReplicationFactor) { + this.defaultReplicationFactor = defaultReplicationFactor; + return this; + } + + public Builder usingRaftController(boolean usingRaftController) { + this.usingRaftController = usingRaftController; + return this; + } + + public Builder defaultPartitions(short numPartitions) { + this.defaultPartitions = numPartitions; + return this; + } + + public MockAdminClient build() { + return new MockAdminClient(brokers, + controller == null ? brokers.get(0) : controller, + clusterId, + defaultPartitions != null ? defaultPartitions.shortValue() : 1, + defaultReplicationFactor != null ? defaultReplicationFactor.shortValue() : Math.min(brokers.size(), 3), + brokerLogDirs, + usingRaftController); + } + } + + public MockAdminClient() { + this(Collections.singletonList(Node.noNode()), Node.noNode()); + } + + public MockAdminClient(List brokers, Node controller) { + this(brokers, controller, DEFAULT_CLUSTER_ID, 1, brokers.size(), + Collections.nCopies(brokers.size(), DEFAULT_LOG_DIRS), false); + } + + private MockAdminClient(List brokers, + Node controller, + String clusterId, + int defaultPartitions, + int defaultReplicationFactor, + List> brokerLogDirs, + boolean usingRaftController) { + this.brokers = brokers; + controller(controller); + this.clusterId = clusterId; + this.defaultPartitions = defaultPartitions; + this.defaultReplicationFactor = defaultReplicationFactor; + this.brokerLogDirs = brokerLogDirs; + this.brokerConfigs = new ArrayList<>(); + for (int i = 0; i < brokers.size(); i++) { + this.brokerConfigs.add(new HashMap<>()); + } + this.beginningOffsets = new HashMap<>(); + this.endOffsets = new HashMap<>(); + this.usingRaftController = usingRaftController; + } + + synchronized public void controller(Node controller) { + if (!brokers.contains(controller)) + throw new IllegalArgumentException("The controller node must be in the list of brokers"); + this.controller = controller; + } + + public void addTopic(boolean internal, + String name, + List partitions, + Map configs) { + addTopic(internal, name, partitions, configs, true); + } + + synchronized public void addTopic(boolean internal, + String name, + List partitions, + Map configs, + boolean usesTopicId) { + if (allTopics.containsKey(name)) { + throw new IllegalArgumentException(String.format("Topic %s was already added.", name)); + } + for (TopicPartitionInfo partition : partitions) { + if (!brokers.contains(partition.leader())) { + throw new IllegalArgumentException("Leader broker unknown"); + } + if (!brokers.containsAll(partition.replicas())) { + throw new IllegalArgumentException("Unknown brokers in replica list"); + } + if (!brokers.containsAll(partition.isr())) { + throw new IllegalArgumentException("Unknown brokers in isr list"); + } + } + ArrayList logDirs = new 
ArrayList<>(); + for (TopicPartitionInfo partition : partitions) { + if (partition.leader() != null) { + logDirs.add(brokerLogDirs.get(partition.leader().id()).get(0)); + } + } + Uuid topicId; + if (usesTopicId) { + topicId = Uuid.randomUuid(); + topicIds.put(name, topicId); + topicNames.put(topicId, name); + } else { + topicId = Uuid.ZERO_UUID; + } + allTopics.put(name, new TopicMetadata(topicId, internal, partitions, logDirs, configs)); + } + + synchronized public void markTopicForDeletion(final String name) { + if (!allTopics.containsKey(name)) { + throw new IllegalArgumentException(String.format("Topic %s did not exist.", name)); + } + + allTopics.get(name).markedForDeletion = true; + } + + synchronized public void timeoutNextRequest(int numberOfRequest) { + timeoutNextRequests = numberOfRequest; + } + + @Override + synchronized public DescribeClusterResult describeCluster(DescribeClusterOptions options) { + KafkaFutureImpl> nodesFuture = new KafkaFutureImpl<>(); + KafkaFutureImpl controllerFuture = new KafkaFutureImpl<>(); + KafkaFutureImpl brokerIdFuture = new KafkaFutureImpl<>(); + KafkaFutureImpl> authorizedOperationsFuture = new KafkaFutureImpl<>(); + + if (timeoutNextRequests > 0) { + nodesFuture.completeExceptionally(new TimeoutException()); + controllerFuture.completeExceptionally(new TimeoutException()); + brokerIdFuture.completeExceptionally(new TimeoutException()); + authorizedOperationsFuture.completeExceptionally(new TimeoutException()); + --timeoutNextRequests; + } else { + nodesFuture.complete(brokers); + controllerFuture.complete(controller); + brokerIdFuture.complete(clusterId); + authorizedOperationsFuture.complete(Collections.emptySet()); + } + + return new DescribeClusterResult(nodesFuture, controllerFuture, brokerIdFuture, authorizedOperationsFuture); + } + + @Override + synchronized public CreateTopicsResult createTopics(Collection newTopics, CreateTopicsOptions options) { + Map> createTopicResult = new HashMap<>(); + + if (timeoutNextRequests > 0) { + for (final NewTopic newTopic : newTopics) { + String topicName = newTopic.name(); + + KafkaFutureImpl future = new KafkaFutureImpl<>(); + future.completeExceptionally(new TimeoutException()); + createTopicResult.put(topicName, future); + } + + --timeoutNextRequests; + return new CreateTopicsResult(createTopicResult); + } + + for (final NewTopic newTopic : newTopics) { + KafkaFutureImpl future = new KafkaFutureImpl<>(); + + String topicName = newTopic.name(); + if (allTopics.containsKey(topicName)) { + future.completeExceptionally(new TopicExistsException(String.format("Topic %s exists already.", topicName))); + createTopicResult.put(topicName, future); + continue; + } + int replicationFactor = newTopic.replicationFactor(); + if (replicationFactor == -1) { + replicationFactor = defaultReplicationFactor; + } + if (replicationFactor > brokers.size()) { + future.completeExceptionally(new InvalidReplicationFactorException( + String.format("Replication factor: %d is larger than brokers: %d", newTopic.replicationFactor(), brokers.size()))); + createTopicResult.put(topicName, future); + continue; + } + + List replicas = new ArrayList<>(replicationFactor); + for (int i = 0; i < replicationFactor; ++i) { + replicas.add(brokers.get(i)); + } + + int numberOfPartitions = newTopic.numPartitions(); + if (numberOfPartitions == -1) { + numberOfPartitions = defaultPartitions; + } + List partitions = new ArrayList<>(numberOfPartitions); + // Partitions start off on the first log directory of each broker, for now. 
+ List logDirs = new ArrayList<>(numberOfPartitions); + for (int i = 0; i < numberOfPartitions; i++) { + partitions.add(new TopicPartitionInfo(i, brokers.get(0), replicas, Collections.emptyList())); + logDirs.add(brokerLogDirs.get(partitions.get(i).leader().id()).get(0)); + } + Uuid topicId = Uuid.randomUuid(); + topicIds.put(topicName, topicId); + topicNames.put(topicId, topicName); + allTopics.put(topicName, new TopicMetadata(topicId, false, partitions, logDirs, newTopic.configs())); + future.complete(null); + createTopicResult.put(topicName, future); + } + + return new CreateTopicsResult(createTopicResult); + } + + @Override + synchronized public ListTopicsResult listTopics(ListTopicsOptions options) { + Map topicListings = new HashMap<>(); + + if (timeoutNextRequests > 0) { + KafkaFutureImpl> future = new KafkaFutureImpl<>(); + future.completeExceptionally(new TimeoutException()); + + --timeoutNextRequests; + return new ListTopicsResult(future); + } + + for (Map.Entry topicDescription : allTopics.entrySet()) { + String topicName = topicDescription.getKey(); + if (topicDescription.getValue().fetchesRemainingUntilVisible > 0) { + topicDescription.getValue().fetchesRemainingUntilVisible--; + } else { + topicListings.put(topicName, new TopicListing(topicName, topicDescription.getValue().topicId, topicDescription.getValue().isInternalTopic)); + } + } + + KafkaFutureImpl> future = new KafkaFutureImpl<>(); + future.complete(topicListings); + return new ListTopicsResult(future); + } + + @Override + synchronized public DescribeTopicsResult describeTopics(TopicCollection topics, DescribeTopicsOptions options) { + if (topics instanceof TopicIdCollection) + return DescribeTopicsResult.ofTopicIds(new HashMap<>(handleDescribeTopicsUsingIds(((TopicIdCollection) topics).topicIds(), options))); + else if (topics instanceof TopicNameCollection) + return DescribeTopicsResult.ofTopicNames(new HashMap<>(handleDescribeTopicsByNames(((TopicNameCollection) topics).topicNames(), options))); + else + throw new IllegalArgumentException("The TopicCollection provided did not match any supported classes for describeTopics."); + } + + private Map> handleDescribeTopicsByNames(Collection topicNames, DescribeTopicsOptions options) { + Map> topicDescriptions = new HashMap<>(); + + if (timeoutNextRequests > 0) { + for (String requestedTopic : topicNames) { + KafkaFutureImpl future = new KafkaFutureImpl<>(); + future.completeExceptionally(new TimeoutException()); + topicDescriptions.put(requestedTopic, future); + } + + --timeoutNextRequests; + return topicDescriptions; + } + + for (String requestedTopic : topicNames) { + for (Map.Entry topicDescription : allTopics.entrySet()) { + String topicName = topicDescription.getKey(); + Uuid topicId = topicIds.getOrDefault(topicName, Uuid.ZERO_UUID); + if (topicName.equals(requestedTopic) && !topicDescription.getValue().markedForDeletion) { + if (topicDescription.getValue().fetchesRemainingUntilVisible > 0) { + topicDescription.getValue().fetchesRemainingUntilVisible--; + } else { + TopicMetadata topicMetadata = topicDescription.getValue(); + KafkaFutureImpl future = new KafkaFutureImpl<>(); + future.complete(new TopicDescription(topicName, topicMetadata.isInternalTopic, topicMetadata.partitions, Collections.emptySet(), topicId)); + topicDescriptions.put(topicName, future); + break; + } + } + } + if (!topicDescriptions.containsKey(requestedTopic)) { + KafkaFutureImpl future = new KafkaFutureImpl<>(); + future.completeExceptionally(new UnknownTopicOrPartitionException("Topic " + 
requestedTopic + " not found.")); + topicDescriptions.put(requestedTopic, future); + } + } + + return topicDescriptions; + } + + synchronized public Map> handleDescribeTopicsUsingIds(Collection topicIds, DescribeTopicsOptions options) { + + Map> topicDescriptions = new HashMap<>(); + + if (timeoutNextRequests > 0) { + for (Uuid requestedTopicId : topicIds) { + KafkaFutureImpl future = new KafkaFutureImpl<>(); + future.completeExceptionally(new TimeoutException()); + topicDescriptions.put(requestedTopicId, future); + } + + --timeoutNextRequests; + return topicDescriptions; + } + + for (Uuid requestedTopicId : topicIds) { + for (Map.Entry topicDescription : allTopics.entrySet()) { + String topicName = topicDescription.getKey(); + Uuid topicId = this.topicIds.get(topicName); + + if (topicId != null && topicId.equals(requestedTopicId) && !topicDescription.getValue().markedForDeletion) { + if (topicDescription.getValue().fetchesRemainingUntilVisible > 0) { + topicDescription.getValue().fetchesRemainingUntilVisible--; + } else { + TopicMetadata topicMetadata = topicDescription.getValue(); + KafkaFutureImpl future = new KafkaFutureImpl<>(); + future.complete(new TopicDescription(topicName, topicMetadata.isInternalTopic, topicMetadata.partitions, Collections.emptySet(), topicId)); + topicDescriptions.put(requestedTopicId, future); + break; + } + } + } + if (!topicDescriptions.containsKey(requestedTopicId)) { + KafkaFutureImpl future = new KafkaFutureImpl<>(); + future.completeExceptionally(new UnknownTopicIdException("Topic id" + requestedTopicId + " not found.")); + topicDescriptions.put(requestedTopicId, future); + } + } + + return topicDescriptions; + } + + @Override + synchronized public DeleteTopicsResult deleteTopics(TopicCollection topics, DeleteTopicsOptions options) { + DeleteTopicsResult result; + if (topics instanceof TopicIdCollection) + result = DeleteTopicsResult.ofTopicIds(new HashMap<>(handleDeleteTopicsUsingIds(((TopicIdCollection) topics).topicIds(), options))); + else if (topics instanceof TopicNameCollection) + result = DeleteTopicsResult.ofTopicNames(new HashMap<>(handleDeleteTopicsUsingNames(((TopicNameCollection) topics).topicNames(), options))); + else + throw new IllegalArgumentException("The TopicCollection provided did not match any supported classes for deleteTopics."); + return result; + } + + private Map> handleDeleteTopicsUsingNames(Collection topicNameCollection, DeleteTopicsOptions options) { + Map> deleteTopicsResult = new HashMap<>(); + Collection topicNames = new ArrayList<>(topicNameCollection); + + if (timeoutNextRequests > 0) { + for (final String topicName : topicNames) { + KafkaFutureImpl future = new KafkaFutureImpl<>(); + future.completeExceptionally(new TimeoutException()); + deleteTopicsResult.put(topicName, future); + } + + --timeoutNextRequests; + return deleteTopicsResult; + } + + for (final String topicName : topicNames) { + KafkaFutureImpl future = new KafkaFutureImpl<>(); + + if (allTopics.remove(topicName) == null) { + future.completeExceptionally(new UnknownTopicOrPartitionException(String.format("Topic %s does not exist.", topicName))); + } else { + topicNames.remove(topicIds.remove(topicName)); + future.complete(null); + } + deleteTopicsResult.put(topicName, future); + } + return deleteTopicsResult; + } + + private Map> handleDeleteTopicsUsingIds(Collection topicIdCollection, DeleteTopicsOptions options) { + Map> deleteTopicsResult = new HashMap<>(); + Collection topicIds = new ArrayList<>(topicIdCollection); + + if (timeoutNextRequests > 0) { 
+ for (final Uuid topicId : topicIds) { + KafkaFutureImpl future = new KafkaFutureImpl<>(); + future.completeExceptionally(new TimeoutException()); + deleteTopicsResult.put(topicId, future); + } + + --timeoutNextRequests; + return deleteTopicsResult; + } + + for (final Uuid topicId : topicIds) { + KafkaFutureImpl future = new KafkaFutureImpl<>(); + + String name = topicNames.remove(topicId); + if (name == null || allTopics.remove(name) == null) { + future.completeExceptionally(new UnknownTopicOrPartitionException(String.format("Topic %s does not exist.", topicId))); + } else { + topicIds.remove(name); + future.complete(null); + } + deleteTopicsResult.put(topicId, future); + } + return deleteTopicsResult; + } + + @Override + synchronized public CreatePartitionsResult createPartitions(Map newPartitions, CreatePartitionsOptions options) { + throw new UnsupportedOperationException("Not implemented yet"); + } + + @Override + synchronized public DeleteRecordsResult deleteRecords(Map recordsToDelete, DeleteRecordsOptions options) { + Map> deletedRecordsResult = new HashMap<>(); + if (recordsToDelete.isEmpty()) { + return new DeleteRecordsResult(deletedRecordsResult); + } else { + throw new UnsupportedOperationException("Not implemented yet"); + } + } + + @Override + synchronized public CreateDelegationTokenResult createDelegationToken(CreateDelegationTokenOptions options) { + throw new UnsupportedOperationException("Not implemented yet"); + } + + @Override + synchronized public RenewDelegationTokenResult renewDelegationToken(byte[] hmac, RenewDelegationTokenOptions options) { + throw new UnsupportedOperationException("Not implemented yet"); + } + + @Override + synchronized public ExpireDelegationTokenResult expireDelegationToken(byte[] hmac, ExpireDelegationTokenOptions options) { + throw new UnsupportedOperationException("Not implemented yet"); + } + + @Override + synchronized public DescribeDelegationTokenResult describeDelegationToken(DescribeDelegationTokenOptions options) { + throw new UnsupportedOperationException("Not implemented yet"); + } + + @Override + synchronized public DescribeConsumerGroupsResult describeConsumerGroups(Collection groupIds, DescribeConsumerGroupsOptions options) { + throw new UnsupportedOperationException("Not implemented yet"); + } + + @Override + synchronized public ListConsumerGroupsResult listConsumerGroups(ListConsumerGroupsOptions options) { + throw new UnsupportedOperationException("Not implemented yet"); + } + + @Override + synchronized public ListConsumerGroupOffsetsResult listConsumerGroupOffsets(String groupId, ListConsumerGroupOffsetsOptions options) { + throw new UnsupportedOperationException("Not implemented yet"); + } + + @Override + synchronized public DeleteConsumerGroupsResult deleteConsumerGroups(Collection groupIds, DeleteConsumerGroupsOptions options) { + throw new UnsupportedOperationException("Not implemented yet"); + } + + @Override + synchronized public DeleteConsumerGroupOffsetsResult deleteConsumerGroupOffsets(String groupId, Set partitions, DeleteConsumerGroupOffsetsOptions options) { + throw new UnsupportedOperationException("Not implemented yet"); + } + + @Override + synchronized public ElectLeadersResult electLeaders( + ElectionType electionType, + Set partitions, + ElectLeadersOptions options) { + throw new UnsupportedOperationException("Not implemented yet"); + } + + @Override + synchronized public RemoveMembersFromConsumerGroupResult removeMembersFromConsumerGroup(String groupId, RemoveMembersFromConsumerGroupOptions options) { + 
throw new UnsupportedOperationException("Not implemented yet"); + } + + @Override + synchronized public CreateAclsResult createAcls(Collection acls, CreateAclsOptions options) { + throw new UnsupportedOperationException("Not implemented yet"); + } + + @Override + synchronized public DescribeAclsResult describeAcls(AclBindingFilter filter, DescribeAclsOptions options) { + throw new UnsupportedOperationException("Not implemented yet"); + } + + @Override + synchronized public DeleteAclsResult deleteAcls(Collection filters, DeleteAclsOptions options) { + throw new UnsupportedOperationException("Not implemented yet"); + } + + @Override + synchronized public DescribeConfigsResult describeConfigs(Collection resources, DescribeConfigsOptions options) { + + if (timeoutNextRequests > 0) { + Map> configs = new HashMap<>(); + for (ConfigResource requestedResource : resources) { + KafkaFutureImpl future = new KafkaFutureImpl<>(); + future.completeExceptionally(new TimeoutException()); + configs.put(requestedResource, future); + } + + --timeoutNextRequests; + return new DescribeConfigsResult(configs); + } + + Map> results = new HashMap<>(); + for (ConfigResource resource : resources) { + KafkaFutureImpl future = new KafkaFutureImpl<>(); + results.put(resource, future); + try { + future.complete(getResourceDescription(resource)); + } catch (Throwable e) { + future.completeExceptionally(e); + } + } + return new DescribeConfigsResult(results); + } + + synchronized private Config getResourceDescription(ConfigResource resource) { + switch (resource.type()) { + case BROKER: { + int brokerId = Integer.parseInt(resource.name()); + if (brokerId >= brokerConfigs.size()) { + throw new InvalidRequestException("Broker " + resource.name() + + " not found."); + } + return toConfigObject(brokerConfigs.get(brokerId)); + } + case TOPIC: { + TopicMetadata topicMetadata = allTopics.get(resource.name()); + if (topicMetadata != null && !topicMetadata.markedForDeletion) { + if (topicMetadata.fetchesRemainingUntilVisible > 0) + topicMetadata.fetchesRemainingUntilVisible = Math.max(0, topicMetadata.fetchesRemainingUntilVisible - 1); + else return toConfigObject(topicMetadata.configs); + + } + throw new UnknownTopicOrPartitionException("Resource " + resource + " not found."); + } + default: + throw new UnsupportedOperationException("Not implemented yet"); + } + } + + private static Config toConfigObject(Map map) { + List configEntries = new ArrayList<>(); + for (Map.Entry entry : map.entrySet()) { + configEntries.add(new ConfigEntry(entry.getKey(), entry.getValue())); + } + return new Config(configEntries); + } + + @Override + @Deprecated + synchronized public AlterConfigsResult alterConfigs(Map configs, AlterConfigsOptions options) { + throw new UnsupportedOperationException("Not implemented yet"); + } + + @Override + synchronized public AlterConfigsResult incrementalAlterConfigs( + Map> configs, + AlterConfigsOptions options) { + Map> futures = new HashMap<>(); + for (Map.Entry> entry : + configs.entrySet()) { + ConfigResource resource = entry.getKey(); + KafkaFutureImpl future = new KafkaFutureImpl<>(); + futures.put(resource, future); + Throwable throwable = + handleIncrementalResourceAlteration(resource, entry.getValue()); + if (throwable == null) { + future.complete(null); + } else { + future.completeExceptionally(throwable); + } + } + return new AlterConfigsResult(futures); + } + + synchronized private Throwable handleIncrementalResourceAlteration( + ConfigResource resource, Collection ops) { + switch (resource.type()) { 
+ case BROKER: { + int brokerId; + try { + brokerId = Integer.valueOf(resource.name()); + } catch (NumberFormatException e) { + return e; + } + if (brokerId >= brokerConfigs.size()) { + return new InvalidRequestException("no such broker as " + brokerId); + } + HashMap newMap = new HashMap<>(brokerConfigs.get(brokerId)); + for (AlterConfigOp op : ops) { + switch (op.opType()) { + case SET: + newMap.put(op.configEntry().name(), op.configEntry().value()); + break; + case DELETE: + newMap.remove(op.configEntry().name()); + break; + default: + return new InvalidRequestException( + "Unsupported op type " + op.opType()); + } + } + brokerConfigs.set(brokerId, newMap); + return null; + } + case TOPIC: { + TopicMetadata topicMetadata = allTopics.get(resource.name()); + if (topicMetadata == null) { + return new UnknownTopicOrPartitionException("No such topic as " + + resource.name()); + } + HashMap newMap = new HashMap<>(topicMetadata.configs); + for (AlterConfigOp op : ops) { + switch (op.opType()) { + case SET: + newMap.put(op.configEntry().name(), op.configEntry().value()); + break; + case DELETE: + newMap.remove(op.configEntry().name()); + break; + default: + return new InvalidRequestException( + "Unsupported op type " + op.opType()); + } + } + topicMetadata.configs = newMap; + return null; + } + default: + return new UnsupportedOperationException(); + } + } + + @Override + synchronized public AlterReplicaLogDirsResult alterReplicaLogDirs( + Map replicaAssignment, + AlterReplicaLogDirsOptions options) { + Map> results = new HashMap<>(); + for (Map.Entry entry : replicaAssignment.entrySet()) { + TopicPartitionReplica replica = entry.getKey(); + String newLogDir = entry.getValue(); + KafkaFutureImpl future = new KafkaFutureImpl<>(); + results.put(replica, future); + List dirs = brokerLogDirs.get(replica.brokerId()); + if (dirs == null) { + future.completeExceptionally( + new ReplicaNotAvailableException("Can't find " + replica)); + } else if (!dirs.contains(newLogDir)) { + future.completeExceptionally( + new KafkaStorageException("Log directory " + newLogDir + " is offline")); + } else { + TopicMetadata metadata = allTopics.get(replica.topic()); + if (metadata == null || metadata.partitions.size() <= replica.partition()) { + future.completeExceptionally( + new ReplicaNotAvailableException("Can't find " + replica)); + } else { + String currentLogDir = metadata.partitionLogDirs.get(replica.partition()); + replicaMoves.put(replica, + new ReplicaLogDirInfo(currentLogDir, 0, newLogDir, 0)); + future.complete(null); + } + } + } + return new AlterReplicaLogDirsResult(results); + } + + @Override + synchronized public DescribeLogDirsResult describeLogDirs(Collection brokers, + DescribeLogDirsOptions options) { + throw new UnsupportedOperationException("Not implemented yet"); + } + + @Override + synchronized public DescribeReplicaLogDirsResult describeReplicaLogDirs( + Collection replicas, DescribeReplicaLogDirsOptions options) { + Map> results = new HashMap<>(); + for (TopicPartitionReplica replica : replicas) { + TopicMetadata topicMetadata = allTopics.get(replica.topic()); + if (topicMetadata != null) { + KafkaFutureImpl future = new KafkaFutureImpl<>(); + results.put(replica, future); + String currentLogDir = currentLogDir(replica); + if (currentLogDir == null) { + future.complete(new ReplicaLogDirInfo(null, + DescribeLogDirsResponse.INVALID_OFFSET_LAG, + null, + DescribeLogDirsResponse.INVALID_OFFSET_LAG)); + } else { + ReplicaLogDirInfo info = replicaMoves.get(replica); + if (info == null) { + 
future.complete(new ReplicaLogDirInfo(currentLogDir, 0, null, 0)); + } else { + future.complete(info); + } + } + } + } + return new DescribeReplicaLogDirsResult(results); + } + + private synchronized String currentLogDir(TopicPartitionReplica replica) { + TopicMetadata topicMetadata = allTopics.get(replica.topic()); + if (topicMetadata == null) { + return null; + } + if (topicMetadata.partitionLogDirs.size() <= replica.partition()) { + return null; + } + return topicMetadata.partitionLogDirs.get(replica.partition()); + } + + @Override + synchronized public AlterPartitionReassignmentsResult alterPartitionReassignments( + Map> newReassignments, + AlterPartitionReassignmentsOptions options) { + Map> futures = new HashMap<>(); + for (Map.Entry> entry : + newReassignments.entrySet()) { + TopicPartition partition = entry.getKey(); + Optional newReassignment = entry.getValue(); + KafkaFutureImpl future = new KafkaFutureImpl(); + futures.put(partition, future); + TopicMetadata topicMetadata = allTopics.get(partition.topic()); + if (partition.partition() < 0 || + topicMetadata == null || + topicMetadata.partitions.size() <= partition.partition()) { + future.completeExceptionally(new UnknownTopicOrPartitionException()); + } else if (newReassignment.isPresent()) { + reassignments.put(partition, newReassignment.get()); + future.complete(null); + } else { + reassignments.remove(partition); + future.complete(null); + } + } + return new AlterPartitionReassignmentsResult(futures); + } + + @Override + synchronized public ListPartitionReassignmentsResult listPartitionReassignments( + Optional> partitions, + ListPartitionReassignmentsOptions options) { + Map map = new HashMap<>(); + for (TopicPartition partition : partitions.isPresent() ? + partitions.get() : reassignments.keySet()) { + PartitionReassignment reassignment = findPartitionReassignment(partition); + if (reassignment != null) { + map.put(partition, reassignment); + } + } + return new ListPartitionReassignmentsResult(KafkaFutureImpl.completedFuture(map)); + } + + synchronized private PartitionReassignment findPartitionReassignment(TopicPartition partition) { + NewPartitionReassignment reassignment = reassignments.get(partition); + if (reassignment == null) { + return null; + } + TopicMetadata metadata = allTopics.get(partition.topic()); + if (metadata == null) { + throw new RuntimeException("Internal MockAdminClient logic error: found " + + "reassignment for " + partition + ", but no TopicMetadata"); + } + TopicPartitionInfo info = metadata.partitions.get(partition.partition()); + if (info == null) { + throw new RuntimeException("Internal MockAdminClient logic error: found " + + "reassignment for " + partition + ", but no TopicPartitionInfo"); + } + List replicas = new ArrayList<>(); + List removingReplicas = new ArrayList<>(); + List addingReplicas = new ArrayList<>(reassignment.targetReplicas()); + for (Node node : info.replicas()) { + replicas.add(node.id()); + if (!reassignment.targetReplicas().contains(node.id())) { + removingReplicas.add(node.id()); + } + addingReplicas.remove(Integer.valueOf(node.id())); + } + return new PartitionReassignment(replicas, addingReplicas, removingReplicas); + } + + @Override + synchronized public AlterConsumerGroupOffsetsResult alterConsumerGroupOffsets(String groupId, Map offsets, AlterConsumerGroupOffsetsOptions options) { + throw new UnsupportedOperationException("Not implement yet"); + } + + @Override + synchronized public ListOffsetsResult listOffsets(Map topicPartitionOffsets, ListOffsetsOptions options) 
{ + Map> futures = new HashMap<>(); + + for (Map.Entry entry : topicPartitionOffsets.entrySet()) { + TopicPartition tp = entry.getKey(); + OffsetSpec spec = entry.getValue(); + KafkaFutureImpl future = new KafkaFutureImpl<>(); + + if (spec instanceof OffsetSpec.TimestampSpec) + throw new UnsupportedOperationException("Not implement yet"); + else if (spec instanceof OffsetSpec.EarliestSpec) + future.complete(new ListOffsetsResult.ListOffsetsResultInfo(beginningOffsets.get(tp), -1, Optional.empty())); + else + future.complete(new ListOffsetsResult.ListOffsetsResultInfo(endOffsets.get(tp), -1, Optional.empty())); + + futures.put(tp, future); + } + + return new ListOffsetsResult(futures); + } + + @Override + public DescribeClientQuotasResult describeClientQuotas(ClientQuotaFilter filter, DescribeClientQuotasOptions options) { + throw new UnsupportedOperationException("Not implement yet"); + } + + @Override + public AlterClientQuotasResult alterClientQuotas(Collection entries, AlterClientQuotasOptions options) { + throw new UnsupportedOperationException("Not implement yet"); + } + + @Override + public DescribeUserScramCredentialsResult describeUserScramCredentials(List users, DescribeUserScramCredentialsOptions options) { + throw new UnsupportedOperationException("Not implemented yet"); + } + + @Override + public AlterUserScramCredentialsResult alterUserScramCredentials(List alterations, AlterUserScramCredentialsOptions options) { + throw new UnsupportedOperationException("Not implemented yet"); + } + + @Override + public DescribeFeaturesResult describeFeatures(DescribeFeaturesOptions options) { + throw new UnsupportedOperationException("Not implemented yet"); + } + + @Override + public UpdateFeaturesResult updateFeatures(Map featureUpdates, UpdateFeaturesOptions options) { + throw new UnsupportedOperationException("Not implemented yet"); + } + + @Override + public UnregisterBrokerResult unregisterBroker(int brokerId, UnregisterBrokerOptions options) { + if (usingRaftController) { + return new UnregisterBrokerResult(KafkaFuture.completedFuture(null)); + } else { + KafkaFutureImpl future = new KafkaFutureImpl<>(); + future.completeExceptionally(new UnsupportedVersionException("")); + return new UnregisterBrokerResult(future); + } + } + + @Override + public DescribeProducersResult describeProducers(Collection partitions, DescribeProducersOptions options) { + throw new UnsupportedOperationException("Not implemented yet"); + } + + @Override + public DescribeTransactionsResult describeTransactions(Collection transactionalIds, DescribeTransactionsOptions options) { + throw new UnsupportedOperationException("Not implemented yet"); + } + + @Override + public AbortTransactionResult abortTransaction(AbortTransactionSpec spec, AbortTransactionOptions options) { + throw new UnsupportedOperationException("Not implemented yet"); + } + + @Override + public ListTransactionsResult listTransactions(ListTransactionsOptions options) { + throw new UnsupportedOperationException("Not implemented yet"); + } + + @Override + synchronized public void close(Duration timeout) {} + + public synchronized void updateBeginningOffsets(Map newOffsets) { + beginningOffsets.putAll(newOffsets); + } + + public synchronized void updateEndOffsets(final Map newOffsets) { + endOffsets.putAll(newOffsets); + } + + private final static class TopicMetadata { + final Uuid topicId; + final boolean isInternalTopic; + final List partitions; + final List partitionLogDirs; + Map configs; + int fetchesRemainingUntilVisible; + + public boolean 
markedForDeletion; + + TopicMetadata(Uuid topicId, + boolean isInternalTopic, + List<TopicPartitionInfo> partitions, + List<String> partitionLogDirs, + Map<String, String> configs) { + this.topicId = topicId; + this.isInternalTopic = isInternalTopic; + this.partitions = partitions; + this.partitionLogDirs = partitionLogDirs; + this.configs = configs != null ? configs : Collections.emptyMap(); + this.markedForDeletion = false; + this.fetchesRemainingUntilVisible = 0; + } + } + + synchronized public void setMockMetrics(MetricName name, Metric metric) { + mockMetrics.put(name, metric); + } + + @Override + synchronized public Map<MetricName, Metric> metrics() { + return mockMetrics; + } + + synchronized public void setFetchesRemainingUntilVisible(String topicName, int fetchesRemainingUntilVisible) { + TopicMetadata metadata = allTopics.get(topicName); + if (metadata == null) { + throw new RuntimeException("No such topic as " + topicName); + } + metadata.fetchesRemainingUntilVisible = fetchesRemainingUntilVisible; + } + + synchronized public List<Node> brokers() { + return new ArrayList<>(brokers); + } + + synchronized public Node broker(int index) { + return brokers.get(index); + } +} diff --git a/clients/src/test/java/org/apache/kafka/clients/admin/RemoveMembersFromConsumerGroupOptionsTest.java b/clients/src/test/java/org/apache/kafka/clients/admin/RemoveMembersFromConsumerGroupOptionsTest.java new file mode 100644 index 0000000..9d127d0 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/clients/admin/RemoveMembersFromConsumerGroupOptionsTest.java @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ +package org.apache.kafka.clients.admin; + +import org.junit.jupiter.api.Test; + +import java.util.Collections; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +public class RemoveMembersFromConsumerGroupOptionsTest { + + @Test + public void testConstructor() { + RemoveMembersFromConsumerGroupOptions options = new RemoveMembersFromConsumerGroupOptions( + Collections.singleton(new MemberToRemove("instance-1"))); + + assertEquals(Collections.singleton( + new MemberToRemove("instance-1")), options.members()); + + // Construct will fail if illegal empty members provided + assertThrows(IllegalArgumentException.class, () -> new RemoveMembersFromConsumerGroupOptions(Collections.emptyList())); + } +} diff --git a/clients/src/test/java/org/apache/kafka/clients/admin/RemoveMembersFromConsumerGroupResultTest.java b/clients/src/test/java/org/apache/kafka/clients/admin/RemoveMembersFromConsumerGroupResultTest.java new file mode 100644 index 0000000..b0ea46e --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/clients/admin/RemoveMembersFromConsumerGroupResultTest.java @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.clients.admin; + +import org.apache.kafka.common.errors.FencedInstanceIdException; +import org.apache.kafka.common.errors.GroupAuthorizationException; +import org.apache.kafka.common.internals.KafkaFutureImpl; +import org.apache.kafka.common.message.LeaveGroupRequestData.MemberIdentity; + +import org.apache.kafka.common.protocol.Errors; +import org.apache.kafka.test.TestUtils; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.ExecutionException; + +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; + +public class RemoveMembersFromConsumerGroupResultTest { + + private final MemberToRemove instanceOne = new MemberToRemove("instance-1"); + private final MemberToRemove instanceTwo = new MemberToRemove("instance-2"); + private Set<MemberToRemove> membersToRemove; + private Map<MemberIdentity, Errors> errorsMap; + + private KafkaFutureImpl<Map<MemberIdentity, Errors>> memberFutures; + + @BeforeEach + public void setUp() { + memberFutures = new KafkaFutureImpl<>(); + membersToRemove = new HashSet<>(); + membersToRemove.add(instanceOne); + membersToRemove.add(instanceTwo); + + errorsMap = new HashMap<>(); + errorsMap.put(instanceOne.toMemberIdentity(), Errors.NONE); + errorsMap.put(instanceTwo.toMemberIdentity(), Errors.FENCED_INSTANCE_ID); + } + + @Test + public void testTopLevelErrorConstructor() throws InterruptedException { + memberFutures.completeExceptionally(Errors.GROUP_AUTHORIZATION_FAILED.exception()); + RemoveMembersFromConsumerGroupResult topLevelErrorResult = + new RemoveMembersFromConsumerGroupResult(memberFutures, membersToRemove); + TestUtils.assertFutureError(topLevelErrorResult.all(), GroupAuthorizationException.class); + } + + @Test + public void testMemberLevelErrorConstructor() throws InterruptedException, ExecutionException { + createAndVerifyMemberLevelError(); + } + + @Test + public void testMemberMissingErrorInRequestConstructor() throws InterruptedException, ExecutionException { + errorsMap.remove(instanceTwo.toMemberIdentity()); + memberFutures.complete(errorsMap); + assertFalse(memberFutures.isCompletedExceptionally()); + RemoveMembersFromConsumerGroupResult missingMemberResult = + new RemoveMembersFromConsumerGroupResult(memberFutures, membersToRemove); + + TestUtils.assertFutureError(missingMemberResult.all(), IllegalArgumentException.class); + assertNull(missingMemberResult.memberResult(instanceOne).get()); + TestUtils.assertFutureError(missingMemberResult.memberResult(instanceTwo), IllegalArgumentException.class); + } + + @Test + public void testMemberLevelErrorInResponseConstructor() throws InterruptedException, ExecutionException { + RemoveMembersFromConsumerGroupResult memberLevelErrorResult = createAndVerifyMemberLevelError(); + assertThrows(IllegalArgumentException.class, () -> memberLevelErrorResult.memberResult( + new MemberToRemove("invalid-instance-id")) + ); + } + + @Test + public void testNoErrorConstructor() throws ExecutionException, InterruptedException { + Map<MemberIdentity, Errors> errorsMap = new HashMap<>(); + errorsMap.put(instanceOne.toMemberIdentity(), Errors.NONE); + errorsMap.put(instanceTwo.toMemberIdentity(), Errors.NONE); + RemoveMembersFromConsumerGroupResult noErrorResult = + new RemoveMembersFromConsumerGroupResult(memberFutures, membersToRemove); + memberFutures.complete(errorsMap); + +
assertNull(noErrorResult.all().get()); + assertNull(noErrorResult.memberResult(instanceOne).get()); + assertNull(noErrorResult.memberResult(instanceTwo).get()); + } + + private RemoveMembersFromConsumerGroupResult createAndVerifyMemberLevelError() throws InterruptedException, ExecutionException { + memberFutures.complete(errorsMap); + assertFalse(memberFutures.isCompletedExceptionally()); + RemoveMembersFromConsumerGroupResult memberLevelErrorResult = + new RemoveMembersFromConsumerGroupResult(memberFutures, membersToRemove); + + TestUtils.assertFutureError(memberLevelErrorResult.all(), FencedInstanceIdException.class); + assertNull(memberLevelErrorResult.memberResult(instanceOne).get()); + TestUtils.assertFutureError(memberLevelErrorResult.memberResult(instanceTwo), FencedInstanceIdException.class); + return memberLevelErrorResult; + } +} diff --git a/clients/src/test/java/org/apache/kafka/clients/admin/ScramMechanismTest.java b/clients/src/test/java/org/apache/kafka/clients/admin/ScramMechanismTest.java new file mode 100644 index 0000000..03ee051 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/clients/admin/ScramMechanismTest.java @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.clients.admin; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import org.junit.jupiter.api.Test; + +class ScramMechanismTest { + + @Test + public void testFromMechanismName() { + assertEquals(ScramMechanism.UNKNOWN, ScramMechanism.fromMechanismName("UNKNOWN")); + assertEquals(ScramMechanism.SCRAM_SHA_256, ScramMechanism.fromMechanismName("SCRAM-SHA-256")); + assertEquals(ScramMechanism.SCRAM_SHA_512, ScramMechanism.fromMechanismName("SCRAM-SHA-512")); + assertEquals(ScramMechanism.UNKNOWN, ScramMechanism.fromMechanismName("some string")); + assertEquals(ScramMechanism.UNKNOWN, ScramMechanism.fromMechanismName("scram-sha-256")); + } + +} diff --git a/clients/src/test/java/org/apache/kafka/clients/admin/TopicCollectionTest.java b/clients/src/test/java/org/apache/kafka/clients/admin/TopicCollectionTest.java new file mode 100644 index 0000000..ce37338 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/clients/admin/TopicCollectionTest.java @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients.admin; + +import org.apache.kafka.common.Uuid; +import org.apache.kafka.common.TopicCollection; +import org.apache.kafka.common.TopicCollection.TopicIdCollection; +import org.apache.kafka.common.TopicCollection.TopicNameCollection; +import org.junit.jupiter.api.Test; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.util.Arrays; +import java.util.List; + +public class TopicCollectionTest { + + @Test + public void testTopicCollection() { + + List<Uuid> topicIds = Arrays.asList(Uuid.randomUuid(), Uuid.randomUuid(), Uuid.randomUuid()); + List<String> topicNames = Arrays.asList("foo", "bar"); + + TopicCollection idCollection = TopicCollection.ofTopicIds(topicIds); + TopicCollection nameCollection = TopicCollection.ofTopicNames(topicNames); + + assertTrue(((TopicIdCollection) idCollection).topicIds().containsAll(topicIds)); + assertTrue(((TopicNameCollection) nameCollection).topicNames().containsAll(topicNames)); + } + +} diff --git a/clients/src/test/java/org/apache/kafka/clients/admin/internals/AbortTransactionHandlerTest.java b/clients/src/test/java/org/apache/kafka/clients/admin/internals/AbortTransactionHandlerTest.java new file mode 100644 index 0000000..78b33a7 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/clients/admin/internals/AbortTransactionHandlerTest.java @@ -0,0 +1,220 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ +package org.apache.kafka.clients.admin.internals; + +import org.apache.kafka.clients.admin.AbortTransactionSpec; +import org.apache.kafka.common.KafkaException; +import org.apache.kafka.common.Node; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.errors.ClusterAuthorizationException; +import org.apache.kafka.common.errors.InvalidProducerEpochException; +import org.apache.kafka.common.errors.TransactionCoordinatorFencedException; +import org.apache.kafka.common.errors.UnknownServerException; +import org.apache.kafka.common.message.WriteTxnMarkersRequestData; +import org.apache.kafka.common.message.WriteTxnMarkersResponseData; +import org.apache.kafka.common.protocol.Errors; +import org.apache.kafka.common.requests.WriteTxnMarkersRequest; +import org.apache.kafka.common.requests.WriteTxnMarkersResponse; +import org.apache.kafka.common.utils.LogContext; +import org.junit.jupiter.api.Test; + +import static java.util.Collections.emptyList; +import static java.util.Collections.emptySet; +import static java.util.Collections.singleton; +import static java.util.Collections.singletonList; +import static org.apache.kafka.common.utils.Utils.mkSet; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class AbortTransactionHandlerTest { + private final LogContext logContext = new LogContext(); + private final TopicPartition topicPartition = new TopicPartition("foo", 5); + private final AbortTransactionSpec abortSpec = new AbortTransactionSpec( + topicPartition, 12345L, (short) 15, 4321); + private final Node node = new Node(1, "host", 1234); + + @Test + public void testInvalidBuildRequestCall() { + AbortTransactionHandler handler = new AbortTransactionHandler(abortSpec, logContext); + assertThrows(IllegalArgumentException.class, () -> handler.buildRequest(1, + emptySet())); + assertThrows(IllegalArgumentException.class, () -> handler.buildRequest(1, + mkSet(new TopicPartition("foo", 1)))); + assertThrows(IllegalArgumentException.class, () -> handler.buildRequest(1, + mkSet(topicPartition, new TopicPartition("foo", 1)))); + } + + @Test + public void testValidBuildRequestCall() { + AbortTransactionHandler handler = new AbortTransactionHandler(abortSpec, logContext); + WriteTxnMarkersRequest.Builder request = handler.buildRequest(1, singleton(topicPartition)); + assertEquals(1, request.data.markers().size()); + + WriteTxnMarkersRequestData.WritableTxnMarker markerRequest = request.data.markers().get(0); + assertEquals(abortSpec.producerId(), markerRequest.producerId()); + assertEquals(abortSpec.producerEpoch(), markerRequest.producerEpoch()); + assertEquals(abortSpec.coordinatorEpoch(), markerRequest.coordinatorEpoch()); + assertEquals(1, markerRequest.topics().size()); + + WriteTxnMarkersRequestData.WritableTxnMarkerTopic topicRequest = markerRequest.topics().get(0); + assertEquals(abortSpec.topicPartition().topic(), topicRequest.name()); + assertEquals(singletonList(abortSpec.topicPartition().partition()), topicRequest.partitionIndexes()); + } + + @Test + public void testInvalidHandleResponseCall() { + AbortTransactionHandler handler = new AbortTransactionHandler(abortSpec, logContext); + WriteTxnMarkersResponseData response = new WriteTxnMarkersResponseData(); + assertThrows(IllegalArgumentException.class, () -> handler.handleResponse(node, + emptySet(), new 
WriteTxnMarkersResponse(response))); + assertThrows(IllegalArgumentException.class, () -> handler.handleResponse(node, + mkSet(new TopicPartition("foo", 1)), new WriteTxnMarkersResponse(response))); + assertThrows(IllegalArgumentException.class, () -> handler.handleResponse(node, + mkSet(topicPartition, new TopicPartition("foo", 1)), new WriteTxnMarkersResponse(response))); + } + + @Test + public void testInvalidResponse() { + AbortTransactionHandler handler = new AbortTransactionHandler(abortSpec, logContext); + + WriteTxnMarkersResponseData response = new WriteTxnMarkersResponseData(); + assertFailed(KafkaException.class, topicPartition, handler.handleResponse(node, singleton(topicPartition), + new WriteTxnMarkersResponse(response))); + + WriteTxnMarkersResponseData.WritableTxnMarkerResult markerResponse = + new WriteTxnMarkersResponseData.WritableTxnMarkerResult(); + response.markers().add(markerResponse); + assertFailed(KafkaException.class, topicPartition, handler.handleResponse(node, singleton(topicPartition), + new WriteTxnMarkersResponse(response))); + + markerResponse.setProducerId(abortSpec.producerId()); + assertFailed(KafkaException.class, topicPartition, handler.handleResponse(node, singleton(topicPartition), + new WriteTxnMarkersResponse(response))); + + WriteTxnMarkersResponseData.WritableTxnMarkerTopicResult topicResponse = + new WriteTxnMarkersResponseData.WritableTxnMarkerTopicResult(); + markerResponse.topics().add(topicResponse); + assertFailed(KafkaException.class, topicPartition, handler.handleResponse(node, singleton(topicPartition), + new WriteTxnMarkersResponse(response))); + + topicResponse.setName(abortSpec.topicPartition().topic()); + assertFailed(KafkaException.class, topicPartition, handler.handleResponse(node, singleton(topicPartition), + new WriteTxnMarkersResponse(response))); + + WriteTxnMarkersResponseData.WritableTxnMarkerPartitionResult partitionResponse = + new WriteTxnMarkersResponseData.WritableTxnMarkerPartitionResult(); + topicResponse.partitions().add(partitionResponse); + assertFailed(KafkaException.class, topicPartition, handler.handleResponse(node, singleton(topicPartition), + new WriteTxnMarkersResponse(response))); + + partitionResponse.setPartitionIndex(abortSpec.topicPartition().partition()); + topicResponse.setName(abortSpec.topicPartition().topic() + "random"); + assertFailed(KafkaException.class, topicPartition, handler.handleResponse(node, singleton(topicPartition), + new WriteTxnMarkersResponse(response))); + + topicResponse.setName(abortSpec.topicPartition().topic()); + markerResponse.setProducerId(abortSpec.producerId() + 1); + assertFailed(KafkaException.class, topicPartition, handler.handleResponse(node, singleton(topicPartition), + new WriteTxnMarkersResponse(response))); + } + + @Test + public void testSuccessfulResponse() { + assertCompleted(abortSpec.topicPartition(), handleWithError(abortSpec, Errors.NONE)); + } + + @Test + public void testRetriableErrors() { + assertUnmapped(abortSpec.topicPartition(), handleWithError(abortSpec, Errors.NOT_LEADER_OR_FOLLOWER)); + assertUnmapped(abortSpec.topicPartition(), handleWithError(abortSpec, Errors.UNKNOWN_TOPIC_OR_PARTITION)); + assertUnmapped(abortSpec.topicPartition(), handleWithError(abortSpec, Errors.REPLICA_NOT_AVAILABLE)); + assertUnmapped(abortSpec.topicPartition(), handleWithError(abortSpec, Errors.BROKER_NOT_AVAILABLE)); + } + + @Test + public void testFatalErrors() { + assertFailed(ClusterAuthorizationException.class, abortSpec.topicPartition(), + handleWithError(abortSpec, 
Errors.CLUSTER_AUTHORIZATION_FAILED)); + assertFailed(InvalidProducerEpochException.class, abortSpec.topicPartition(), + handleWithError(abortSpec, Errors.INVALID_PRODUCER_EPOCH)); + assertFailed(TransactionCoordinatorFencedException.class, abortSpec.topicPartition(), + handleWithError(abortSpec, Errors.TRANSACTION_COORDINATOR_FENCED)); + assertFailed(UnknownServerException.class, abortSpec.topicPartition(), + handleWithError(abortSpec, Errors.UNKNOWN_SERVER_ERROR)); + } + + private AdminApiHandler.ApiResult<TopicPartition, Void> handleWithError( + AbortTransactionSpec abortSpec, + Errors error + ) { + AbortTransactionHandler handler = new AbortTransactionHandler(abortSpec, logContext); + + WriteTxnMarkersResponseData.WritableTxnMarkerPartitionResult partitionResponse = + new WriteTxnMarkersResponseData.WritableTxnMarkerPartitionResult() + .setPartitionIndex(abortSpec.topicPartition().partition()) + .setErrorCode(error.code()); + + WriteTxnMarkersResponseData.WritableTxnMarkerTopicResult topicResponse = + new WriteTxnMarkersResponseData.WritableTxnMarkerTopicResult() + .setName(abortSpec.topicPartition().topic()); + topicResponse.partitions().add(partitionResponse); + + WriteTxnMarkersResponseData.WritableTxnMarkerResult markerResponse = + new WriteTxnMarkersResponseData.WritableTxnMarkerResult() + .setProducerId(abortSpec.producerId()); + markerResponse.topics().add(topicResponse); + + WriteTxnMarkersResponseData response = new WriteTxnMarkersResponseData(); + response.markers().add(markerResponse); + + return handler.handleResponse(node, singleton(abortSpec.topicPartition()), + new WriteTxnMarkersResponse(response)); + } + + private void assertUnmapped( + TopicPartition topicPartition, + AdminApiHandler.ApiResult<TopicPartition, Void> result + ) { + assertEquals(emptySet(), result.completedKeys.keySet()); + assertEquals(emptySet(), result.failedKeys.keySet()); + assertEquals(singletonList(topicPartition), result.unmappedKeys); + } + + private void assertCompleted( + TopicPartition topicPartition, + AdminApiHandler.ApiResult<TopicPartition, Void> result + ) { + assertEquals(emptySet(), result.failedKeys.keySet()); + assertEquals(emptyList(), result.unmappedKeys); + assertEquals(singleton(topicPartition), result.completedKeys.keySet()); + assertNull(result.completedKeys.get(topicPartition)); + } + + private void assertFailed( + Class<? extends Exception> expectedExceptionType, + TopicPartition topicPartition, + AdminApiHandler.ApiResult<TopicPartition, Void> result + ) { + assertEquals(emptySet(), result.completedKeys.keySet()); + assertEquals(emptyList(), result.unmappedKeys); + assertEquals(singleton(topicPartition), result.failedKeys.keySet()); + assertTrue(expectedExceptionType.isInstance(result.failedKeys.get(topicPartition))); + } + +} diff --git a/clients/src/test/java/org/apache/kafka/clients/admin/internals/AdminApiDriverTest.java b/clients/src/test/java/org/apache/kafka/clients/admin/internals/AdminApiDriverTest.java new file mode 100644 index 0000000..6ff393f --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/clients/admin/internals/AdminApiDriverTest.java @@ -0,0 +1,833 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients.admin.internals; + +import org.apache.kafka.clients.admin.internals.AdminApiDriver.RequestSpec; +import org.apache.kafka.clients.admin.internals.AdminApiHandler.ApiResult; +import org.apache.kafka.clients.admin.internals.AdminApiLookupStrategy.LookupResult; +import org.apache.kafka.common.KafkaFuture; +import org.apache.kafka.common.Node; +import org.apache.kafka.common.errors.DisconnectException; +import org.apache.kafka.common.errors.UnknownServerException; +import org.apache.kafka.common.message.MetadataResponseData; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.requests.AbstractRequest; +import org.apache.kafka.common.requests.AbstractResponse; +import org.apache.kafka.common.requests.FindCoordinatorRequest.NoBatchedFindCoordinatorsException; +import org.apache.kafka.common.requests.MetadataRequest; +import org.apache.kafka.common.requests.MetadataResponse; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.common.utils.MockTime; +import org.junit.jupiter.api.Test; + +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.OptionalInt; +import java.util.Set; +import java.util.concurrent.ExecutionException; +import java.util.stream.Collectors; + +import static java.util.Collections.emptyMap; +import static org.apache.kafka.common.utils.Utils.mkSet; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; + +class AdminApiDriverTest { + private static final int API_TIMEOUT_MS = 30000; + private static final int RETRY_BACKOFF_MS = 100; + + @Test + public void testCoalescedLookup() { + TestContext ctx = TestContext.dynamicMapped(map( + "foo", "c1", + "bar", "c1" + )); + + Map, LookupResult> lookupRequests = map( + mkSet("foo", "bar"), mapped("foo", 1, "bar", 2) + ); + + ctx.poll(lookupRequests, emptyMap()); + + Map, ApiResult> fulfillmentResults = map( + mkSet("foo"), completed("foo", 15L), + mkSet("bar"), completed("bar", 30L) + ); + + ctx.poll(emptyMap(), fulfillmentResults); + + ctx.poll(emptyMap(), emptyMap()); + } + + @Test + public void testCoalescedFulfillment() { + TestContext ctx = TestContext.dynamicMapped(map( + "foo", "c1", + "bar", "c2" + )); + + Map, LookupResult> lookupRequests = map( + mkSet("foo"), mapped("foo", 1), + mkSet("bar"), mapped("bar", 1) + ); + + ctx.poll(lookupRequests, emptyMap()); + + Map, ApiResult> fulfillmentResults = map( + mkSet("foo", "bar"), completed("foo", 15L, "bar", 30L) + ); + + ctx.poll(emptyMap(), fulfillmentResults); + + ctx.poll(emptyMap(), emptyMap()); + } + + @Test + public void testKeyLookupFailure() { + TestContext ctx = TestContext.dynamicMapped(map( + "foo", "c1", + "bar", "c2" + )); + + Map, LookupResult> 
lookupRequests = map( + mkSet("foo"), failedLookup("foo", new UnknownServerException()), + mkSet("bar"), mapped("bar", 1) + ); + + ctx.poll(lookupRequests, emptyMap()); + + Map, ApiResult> fulfillmentResults = map( + mkSet("bar"), completed("bar", 30L) + ); + + ctx.poll(emptyMap(), fulfillmentResults); + + ctx.poll(emptyMap(), emptyMap()); + } + + @Test + public void testKeyLookupRetry() { + TestContext ctx = TestContext.dynamicMapped(map( + "foo", "c1", + "bar", "c2" + )); + + Map, LookupResult> lookupRequests = map( + mkSet("foo"), emptyLookup(), + mkSet("bar"), mapped("bar", 1) + ); + + ctx.poll(lookupRequests, emptyMap()); + + Map, LookupResult> fooRetry = map( + mkSet("foo"), mapped("foo", 1) + ); + + Map, ApiResult> barFulfillment = map( + mkSet("bar"), completed("bar", 30L) + ); + + ctx.poll(fooRetry, barFulfillment); + + Map, ApiResult> fooFulfillment = map( + mkSet("foo"), completed("foo", 15L) + ); + + ctx.poll(emptyMap(), fooFulfillment); + + ctx.poll(emptyMap(), emptyMap()); + } + + @Test + public void testStaticMapping() { + TestContext ctx = TestContext.staticMapped(map( + "foo", 0, + "bar", 1, + "baz", 1 + )); + + Map, ApiResult> fulfillmentResults = map( + mkSet("foo"), completed("foo", 15L), + mkSet("bar", "baz"), completed("bar", 30L, "baz", 45L) + ); + + ctx.poll(emptyMap(), fulfillmentResults); + + ctx.poll(emptyMap(), emptyMap()); + } + + @Test + public void testFulfillmentFailure() { + TestContext ctx = TestContext.staticMapped(map( + "foo", 0, + "bar", 1, + "baz", 1 + )); + + Map, ApiResult> fulfillmentResults = map( + mkSet("foo"), failed("foo", new UnknownServerException()), + mkSet("bar", "baz"), completed("bar", 30L, "baz", 45L) + ); + + ctx.poll(emptyMap(), fulfillmentResults); + + ctx.poll(emptyMap(), emptyMap()); + } + + @Test + public void testFulfillmentRetry() { + TestContext ctx = TestContext.staticMapped(map( + "foo", 0, + "bar", 1, + "baz", 1 + )); + + Map, ApiResult> fulfillmentResults = map( + mkSet("foo"), completed("foo", 15L), + mkSet("bar", "baz"), completed("bar", 30L) + ); + + ctx.poll(emptyMap(), fulfillmentResults); + + Map, ApiResult> bazRetry = map( + mkSet("baz"), completed("baz", 45L) + ); + + ctx.poll(emptyMap(), bazRetry); + + ctx.poll(emptyMap(), emptyMap()); + } + + @Test + public void testFulfillmentUnmapping() { + TestContext ctx = TestContext.dynamicMapped(map( + "foo", "c1", + "bar", "c2" + )); + + Map, LookupResult> lookupRequests = map( + mkSet("foo"), mapped("foo", 0), + mkSet("bar"), mapped("bar", 1) + ); + + ctx.poll(lookupRequests, emptyMap()); + + Map, ApiResult> fulfillmentResults = map( + mkSet("foo"), completed("foo", 15L), + mkSet("bar"), unmapped("bar") + ); + + ctx.poll(emptyMap(), fulfillmentResults); + + Map, LookupResult> barLookupRetry = map( + mkSet("bar"), mapped("bar", 1) + ); + + ctx.poll(barLookupRetry, emptyMap()); + + Map, ApiResult> barFulfillRetry = map( + mkSet("bar"), completed("bar", 30L) + ); + + ctx.poll(emptyMap(), barFulfillRetry); + + ctx.poll(emptyMap(), emptyMap()); + } + + @Test + public void testRecoalescedLookup() { + TestContext ctx = TestContext.dynamicMapped(map( + "foo", "c1", + "bar", "c1" + )); + + Map, LookupResult> lookupRequests = map( + mkSet("foo", "bar"), mapped("foo", 1, "bar", 2) + ); + + ctx.poll(lookupRequests, emptyMap()); + + Map, ApiResult> fulfillment = map( + mkSet("foo"), unmapped("foo"), + mkSet("bar"), unmapped("bar") + ); + + ctx.poll(emptyMap(), fulfillment); + + Map, LookupResult> retryLookupRequests = map( + mkSet("foo", "bar"), mapped("foo", 3, "bar", 3) + ); + + 
ctx.poll(retryLookupRequests, emptyMap()); + + Map, ApiResult> retryFulfillment = map( + mkSet("foo", "bar"), completed("foo", 15L, "bar", 30L) + ); + + ctx.poll(emptyMap(), retryFulfillment); + + ctx.poll(emptyMap(), emptyMap()); + } + + @Test + public void testRetryLookupAfterDisconnect() { + TestContext ctx = TestContext.dynamicMapped(map( + "foo", "c1" + )); + + int initialLeaderId = 1; + + Map, LookupResult> initialLookup = map( + mkSet("foo"), mapped("foo", initialLeaderId) + ); + + ctx.poll(initialLookup, emptyMap()); + assertMappedKey(ctx, "foo", initialLeaderId); + + ctx.handler.expectRequest(mkSet("foo"), completed("foo", 15L)); + + List> requestSpecs = ctx.driver.poll(); + assertEquals(1, requestSpecs.size()); + + RequestSpec requestSpec = requestSpecs.get(0); + assertEquals(OptionalInt.of(initialLeaderId), requestSpec.scope.destinationBrokerId()); + + ctx.driver.onFailure(ctx.time.milliseconds(), requestSpec, new DisconnectException()); + assertUnmappedKey(ctx, "foo"); + + int retryLeaderId = 2; + + ctx.lookupStrategy().expectLookup(mkSet("foo"), mapped("foo", retryLeaderId)); + List> retryLookupSpecs = ctx.driver.poll(); + assertEquals(1, retryLookupSpecs.size()); + + RequestSpec retryLookupSpec = retryLookupSpecs.get(0); + assertEquals(ctx.time.milliseconds(), retryLookupSpec.nextAllowedTryMs); + assertEquals(1, retryLookupSpec.tries); + } + + @Test + public void testRetryLookupAndDisableBatchAfterNoBatchedFindCoordinatorsException() { + MockTime time = new MockTime(); + LogContext lc = new LogContext(); + Set groupIds = new HashSet<>(Arrays.asList("g1", "g2")); + DeleteConsumerGroupsHandler handler = new DeleteConsumerGroupsHandler(lc); + AdminApiFuture future = AdminApiFuture.forKeys( + groupIds.stream().map(g -> CoordinatorKey.byGroupId(g)).collect(Collectors.toSet())); + + AdminApiDriver driver = new AdminApiDriver<>( + handler, + future, + time.milliseconds() + API_TIMEOUT_MS, + RETRY_BACKOFF_MS, + new LogContext() + ); + + assertTrue(((CoordinatorStrategy) handler.lookupStrategy()).batch); + List> requestSpecs = driver.poll(); + // Expect CoordinatorStrategy to try resolving all coordinators in a single request + assertEquals(1, requestSpecs.size()); + + RequestSpec requestSpec = requestSpecs.get(0); + driver.onFailure(time.milliseconds(), requestSpec, new NoBatchedFindCoordinatorsException("message")); + assertFalse(((CoordinatorStrategy) handler.lookupStrategy()).batch); + + // Batching is now disabled, so we now have a request per groupId + List> retryLookupSpecs = driver.poll(); + assertEquals(groupIds.size(), retryLookupSpecs.size()); + // These new requests are treated a new requests and not retries + for (RequestSpec retryLookupSpec : retryLookupSpecs) { + assertEquals(0, retryLookupSpec.nextAllowedTryMs); + assertEquals(0, retryLookupSpec.tries); + } + } + + @Test + public void testCoalescedStaticAndDynamicFulfillment() { + Map dynamicMapping = map( + "foo", "c1" + ); + + Map staticMapping = map( + "bar", 1 + ); + + TestContext ctx = new TestContext( + staticMapping, + dynamicMapping + ); + + // Initially we expect a lookup for the dynamic key and a + // fulfillment request for the static key + LookupResult lookupResult = mapped("foo", 1); + ctx.lookupStrategy().expectLookup( + mkSet("foo"), lookupResult + ); + ctx.handler.expectRequest( + mkSet("bar"), completed("bar", 10L) + ); + + List> requestSpecs = ctx.driver.poll(); + assertEquals(2, requestSpecs.size()); + + RequestSpec lookupSpec = requestSpecs.get(0); + assertEquals(mkSet("foo"), lookupSpec.keys); + 
ctx.assertLookupResponse(lookupSpec, lookupResult); + + // Receive a disconnect from the fulfillment request so that + // we have an opportunity to coalesce the keys. + RequestSpec fulfillmentSpec = requestSpecs.get(1); + assertEquals(mkSet("bar"), fulfillmentSpec.keys); + ctx.driver.onFailure(ctx.time.milliseconds(), fulfillmentSpec, new DisconnectException()); + + // Now we should get two fulfillment requests. One of them will + // the coalesced dynamic and static keys for broker 1. The other + // should contain the single dynamic key for broker 0. + ctx.handler.reset(); + ctx.handler.expectRequest( + mkSet("foo", "bar"), completed("foo", 15L, "bar", 30L) + ); + + List> coalescedSpecs = ctx.driver.poll(); + assertEquals(1, coalescedSpecs.size()); + RequestSpec coalescedSpec = coalescedSpecs.get(0); + assertEquals(mkSet("foo", "bar"), coalescedSpec.keys); + + // Disconnect in order to ensure that only the dynamic key is unmapped. + // Then complete the remaining requests. + ctx.driver.onFailure(ctx.time.milliseconds(), coalescedSpec, new DisconnectException()); + + Map, LookupResult> fooLookupRetry = map( + mkSet("foo"), mapped("foo", 3) + ); + Map, ApiResult> barFulfillmentRetry = map( + mkSet("bar"), completed("bar", 30L) + ); + ctx.poll(fooLookupRetry, barFulfillmentRetry); + + Map, ApiResult> fooFulfillmentRetry = map( + mkSet("foo"), completed("foo", 15L) + ); + ctx.poll(emptyMap(), fooFulfillmentRetry); + ctx.poll(emptyMap(), emptyMap()); + } + + @Test + public void testLookupRetryBookkeeping() { + TestContext ctx = TestContext.dynamicMapped(map( + "foo", "c1" + )); + + LookupResult emptyLookup = emptyLookup(); + ctx.lookupStrategy().expectLookup(mkSet("foo"), emptyLookup); + + List> requestSpecs = ctx.driver.poll(); + assertEquals(1, requestSpecs.size()); + + RequestSpec requestSpec = requestSpecs.get(0); + assertEquals(0, requestSpec.tries); + assertEquals(0L, requestSpec.nextAllowedTryMs); + ctx.assertLookupResponse(requestSpec, emptyLookup); + + List> retrySpecs = ctx.driver.poll(); + assertEquals(1, retrySpecs.size()); + + RequestSpec retrySpec = retrySpecs.get(0); + assertEquals(1, retrySpec.tries); + assertEquals(ctx.time.milliseconds(), retrySpec.nextAllowedTryMs); + } + + @Test + public void testFulfillmentRetryBookkeeping() { + TestContext ctx = TestContext.staticMapped(map("foo", 0)); + + ApiResult emptyFulfillment = emptyFulfillment(); + ctx.handler.expectRequest(mkSet("foo"), emptyFulfillment); + + List> requestSpecs = ctx.driver.poll(); + assertEquals(1, requestSpecs.size()); + + RequestSpec requestSpec = requestSpecs.get(0); + assertEquals(0, requestSpec.tries); + assertEquals(0L, requestSpec.nextAllowedTryMs); + ctx.assertResponse(requestSpec, emptyFulfillment, Node.noNode()); + + List> retrySpecs = ctx.driver.poll(); + assertEquals(1, retrySpecs.size()); + + RequestSpec retrySpec = retrySpecs.get(0); + assertEquals(1, retrySpec.tries); + assertEquals(ctx.time.milliseconds() + RETRY_BACKOFF_MS, retrySpec.nextAllowedTryMs); + } + + private static void assertMappedKey( + TestContext context, + String key, + Integer expectedBrokerId + ) { + OptionalInt brokerIdOpt = context.driver.keyToBrokerId(key); + assertEquals(OptionalInt.of(expectedBrokerId), brokerIdOpt); + } + + private static void assertUnmappedKey( + TestContext context, + String key + ) { + OptionalInt brokerIdOpt = context.driver.keyToBrokerId(key); + assertEquals(OptionalInt.empty(), brokerIdOpt); + KafkaFuture future = context.future.all().get(key); + assertFalse(future.isDone()); + } + + private static 
void assertFailedKey( + TestContext context, + String key, + Throwable expectedException + ) { + KafkaFuture future = context.future.all().get(key); + assertTrue(future.isCompletedExceptionally()); + Throwable exception = assertThrows(ExecutionException.class, future::get); + assertEquals(expectedException, exception.getCause()); + } + + private static void assertCompletedKey( + TestContext context, + String key, + Long expected + ) { + KafkaFuture future = context.future.all().get(key); + assertTrue(future.isDone()); + try { + assertEquals(expected, future.get()); + } catch (Throwable t) { + throw new RuntimeException(t); + } + } + + private static class MockRequestScope implements ApiRequestScope { + private final OptionalInt destinationBrokerId; + private final String id; + + private MockRequestScope( + OptionalInt destinationBrokerId, + String id + ) { + this.destinationBrokerId = destinationBrokerId; + this.id = id; + } + + @Override + public OptionalInt destinationBrokerId() { + return destinationBrokerId; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + MockRequestScope that = (MockRequestScope) o; + return Objects.equals(destinationBrokerId, that.destinationBrokerId) && + Objects.equals(id, that.id); + } + + @Override + public int hashCode() { + return Objects.hash(destinationBrokerId, id); + } + } + + private static class TestContext { + private final MockTime time = new MockTime(); + private final MockAdminApiHandler handler; + private final AdminApiDriver driver; + private final AdminApiFuture.SimpleAdminApiFuture future; + + public TestContext( + Map staticKeys, + Map dynamicKeys + ) { + Map lookupScopes = new HashMap<>(); + staticKeys.forEach((key, brokerId) -> { + MockRequestScope scope = new MockRequestScope(OptionalInt.of(brokerId), null); + lookupScopes.put(key, scope); + }); + + dynamicKeys.forEach((key, context) -> { + MockRequestScope scope = new MockRequestScope(OptionalInt.empty(), context); + lookupScopes.put(key, scope); + }); + + MockLookupStrategy lookupStrategy = new MockLookupStrategy<>(lookupScopes); + this.handler = new MockAdminApiHandler<>(lookupStrategy); + this.future = AdminApiFuture.forKeys(lookupStrategy.lookupScopes.keySet()); + + this.driver = new AdminApiDriver<>( + handler, + future, + time.milliseconds() + API_TIMEOUT_MS, + RETRY_BACKOFF_MS, + new LogContext() + ); + + staticKeys.forEach((key, brokerId) -> { + assertMappedKey(this, key, brokerId); + }); + + dynamicKeys.keySet().forEach(key -> { + assertUnmappedKey(this, key); + }); + } + + public static TestContext staticMapped(Map staticKeys) { + return new TestContext(staticKeys, Collections.emptyMap()); + } + + public static TestContext dynamicMapped(Map dynamicKeys) { + return new TestContext(Collections.emptyMap(), dynamicKeys); + } + + private void assertLookupResponse( + RequestSpec requestSpec, + LookupResult result + ) { + requestSpec.keys.forEach(key -> { + assertUnmappedKey(this, key); + }); + + // The response is just a placeholder. 
The result is all we are interested in + MetadataResponse response = new MetadataResponse(new MetadataResponseData(), + ApiKeys.METADATA.latestVersion()); + driver.onResponse(time.milliseconds(), requestSpec, response, Node.noNode()); + + result.mappedKeys.forEach((key, brokerId) -> { + assertMappedKey(this, key, brokerId); + }); + + result.failedKeys.forEach((key, exception) -> { + assertFailedKey(this, key, exception); + }); + } + + private void assertResponse( + RequestSpec requestSpec, + ApiResult result, + Node node + ) { + int brokerId = requestSpec.scope.destinationBrokerId().orElseThrow(() -> + new AssertionError("Fulfillment requests must specify a target brokerId")); + + requestSpec.keys.forEach(key -> { + assertMappedKey(this, key, brokerId); + }); + + // The response is just a placeholder. The result is all we are interested in + MetadataResponse response = new MetadataResponse(new MetadataResponseData(), + ApiKeys.METADATA.latestVersion()); + + driver.onResponse(time.milliseconds(), requestSpec, response, node); + + result.unmappedKeys.forEach(key -> { + assertUnmappedKey(this, key); + }); + + result.failedKeys.forEach((key, exception) -> { + assertFailedKey(this, key, exception); + }); + + result.completedKeys.forEach((key, value) -> { + assertCompletedKey(this, key, value); + }); + } + + private MockLookupStrategy lookupStrategy() { + return handler.lookupStrategy; + } + + public void poll( + Map, LookupResult> expectedLookups, + Map, ApiResult> expectedRequests + ) { + if (!expectedLookups.isEmpty()) { + MockLookupStrategy lookupStrategy = lookupStrategy(); + lookupStrategy.reset(); + expectedLookups.forEach(lookupStrategy::expectLookup); + } + + handler.reset(); + expectedRequests.forEach(handler::expectRequest); + + List> requestSpecs = driver.poll(); + assertEquals(expectedLookups.size() + expectedRequests.size(), requestSpecs.size(), + "Driver generated an unexpected number of requests"); + + for (RequestSpec requestSpec : requestSpecs) { + Set keys = requestSpec.keys; + if (expectedLookups.containsKey(keys)) { + LookupResult result = expectedLookups.get(keys); + assertLookupResponse(requestSpec, result); + } else if (expectedRequests.containsKey(keys)) { + ApiResult result = expectedRequests.get(keys); + assertResponse(requestSpec, result, Node.noNode()); + } else { + fail("Unexpected request for keys " + keys); + } + } + } + } + + private static class MockLookupStrategy implements AdminApiLookupStrategy { + private final Map, LookupResult> expectedLookups = new HashMap<>(); + private final Map lookupScopes; + + private MockLookupStrategy(Map lookupScopes) { + this.lookupScopes = lookupScopes; + } + + @Override + public ApiRequestScope lookupScope(K key) { + return lookupScopes.get(key); + } + + public void expectLookup(Set keys, LookupResult result) { + expectedLookups.put(keys, result); + } + + @Override + public AbstractRequest.Builder buildRequest(Set keys) { + // The request is just a placeholder in these tests + assertTrue(expectedLookups.containsKey(keys), "Unexpected lookup request for keys " + keys); + return new MetadataRequest.Builder(Collections.emptyList(), false); + } + + @Override + public LookupResult handleResponse(Set keys, AbstractResponse response) { + return Optional.ofNullable(expectedLookups.get(keys)).orElseThrow(() -> + new AssertionError("Unexpected fulfillment request for keys " + keys) + ); + } + + public void reset() { + expectedLookups.clear(); + } + } + + private static class MockAdminApiHandler implements AdminApiHandler { + private 
final Map, ApiResult> expectedRequests = new HashMap<>(); + private final MockLookupStrategy lookupStrategy; + + private MockAdminApiHandler(MockLookupStrategy lookupStrategy) { + this.lookupStrategy = lookupStrategy; + } + + @Override + public String apiName() { + return "mock-api"; + } + + @Override + public AdminApiLookupStrategy lookupStrategy() { + return lookupStrategy; + } + + public void expectRequest(Set keys, ApiResult result) { + expectedRequests.put(keys, result); + } + + @Override + public AbstractRequest.Builder buildRequest(int brokerId, Set keys) { + // The request is just a placeholder in these tests + assertTrue(expectedRequests.containsKey(keys), "Unexpected fulfillment request for keys " + keys); + return new MetadataRequest.Builder(Collections.emptyList(), false); + } + + @Override + public ApiResult handleResponse(Node broker, Set keys, AbstractResponse response) { + return Optional.ofNullable(expectedRequests.get(keys)).orElseThrow(() -> + new AssertionError("Unexpected fulfillment request for keys " + keys) + ); + } + + public void reset() { + expectedRequests.clear(); + } + } + + private static Map map(K key, V value) { + return Collections.singletonMap(key, value); + } + + private static Map map(K k1, V v1, K k2, V v2) { + HashMap map = new HashMap<>(2); + map.put(k1, v1); + map.put(k2, v2); + return map; + } + + private static Map map(K k1, V v1, K k2, V v2, K k3, V v3) { + HashMap map = new HashMap<>(3); + map.put(k1, v1); + map.put(k2, v2); + map.put(k3, v3); + return map; + } + + private static ApiResult completed(String key, Long value) { + return new ApiResult<>(map(key, value), emptyMap(), Collections.emptyList()); + } + + private static ApiResult failed(String key, Throwable exception) { + return new ApiResult<>(emptyMap(), map(key, exception), Collections.emptyList()); + } + + private static ApiResult unmapped(String... keys) { + return new ApiResult<>(emptyMap(), emptyMap(), Arrays.asList(keys)); + } + + private static ApiResult completed(String k1, Long v1, String k2, Long v2) { + return new ApiResult<>(map(k1, v1, k2, v2), emptyMap(), Collections.emptyList()); + } + + private static ApiResult emptyFulfillment() { + return new ApiResult<>(emptyMap(), emptyMap(), Collections.emptyList()); + } + + private static LookupResult failedLookup(String key, Throwable exception) { + return new LookupResult<>(map(key, exception), emptyMap()); + } + + private static LookupResult emptyLookup() { + return new LookupResult<>(emptyMap(), emptyMap()); + } + + private static LookupResult mapped(String key, Integer brokerId) { + return new LookupResult<>(emptyMap(), map(key, brokerId)); + } + + private static LookupResult mapped(String k1, Integer broker1, String k2, Integer broker2) { + return new LookupResult<>(emptyMap(), map(k1, broker1, k2, broker2)); + } + +} diff --git a/clients/src/test/java/org/apache/kafka/clients/admin/internals/AdminMetadataManagerTest.java b/clients/src/test/java/org/apache/kafka/clients/admin/internals/AdminMetadataManagerTest.java new file mode 100644 index 0000000..05587aa --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/clients/admin/internals/AdminMetadataManagerTest.java @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.clients.admin.internals; + +import org.apache.kafka.common.Cluster; +import org.apache.kafka.common.Node; +import org.apache.kafka.common.PartitionInfo; +import org.apache.kafka.common.errors.AuthenticationException; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.common.utils.MockTime; +import org.junit.jupiter.api.Test; + +import java.net.InetSocketAddress; +import java.util.Collections; +import java.util.HashMap; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class AdminMetadataManagerTest { + private final MockTime time = new MockTime(); + private final LogContext logContext = new LogContext(); + private final long refreshBackoffMs = 100; + private final long metadataExpireMs = 60000; + private final AdminMetadataManager mgr = new AdminMetadataManager( + logContext, refreshBackoffMs, metadataExpireMs); + + @Test + public void testMetadataReady() { + // Metadata is not ready on initialization + assertFalse(mgr.isReady()); + assertEquals(0, mgr.metadataFetchDelayMs(time.milliseconds())); + + // Metadata is not ready when bootstrap servers are set + mgr.update(Cluster.bootstrap(Collections.singletonList(new InetSocketAddress("localhost", 9999))), + time.milliseconds()); + assertFalse(mgr.isReady()); + assertEquals(0, mgr.metadataFetchDelayMs(time.milliseconds())); + + mgr.update(mockCluster(), time.milliseconds()); + assertTrue(mgr.isReady()); + assertEquals(metadataExpireMs, mgr.metadataFetchDelayMs(time.milliseconds())); + + time.sleep(metadataExpireMs); + assertEquals(0, mgr.metadataFetchDelayMs(time.milliseconds())); + } + + @Test + public void testMetadataRefreshBackoff() { + mgr.transitionToUpdatePending(time.milliseconds()); + assertEquals(Long.MAX_VALUE, mgr.metadataFetchDelayMs(time.milliseconds())); + + mgr.updateFailed(new RuntimeException()); + assertEquals(refreshBackoffMs, mgr.metadataFetchDelayMs(time.milliseconds())); + + // Even if we explicitly request an update, the backoff should be respected + mgr.requestUpdate(); + assertEquals(refreshBackoffMs, mgr.metadataFetchDelayMs(time.milliseconds())); + + time.sleep(refreshBackoffMs); + assertEquals(0, mgr.metadataFetchDelayMs(time.milliseconds())); + } + + @Test + public void testAuthenticationFailure() { + mgr.transitionToUpdatePending(time.milliseconds()); + mgr.updateFailed(new AuthenticationException("Authentication failed")); + assertEquals(refreshBackoffMs, mgr.metadataFetchDelayMs(time.milliseconds())); + assertThrows(AuthenticationException.class, mgr::isReady); + mgr.update(mockCluster(), time.milliseconds()); + assertTrue(mgr.isReady()); + } + + private static Cluster mockCluster() { + HashMap nodes = new HashMap<>(); + nodes.put(0, new Node(0, "localhost", 8121)); + 
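The tests above pin down AdminMetadataManager's refresh gating: metadataFetchDelayMs() returns 0 when a fetch is due, the remaining backoff after a failed update, and Long.MAX_VALUE while an update is already pending. A minimal sketch of the refresh loop those assertions model, using only the constructor arguments and methods exercised in this test; the wrapper class name and the null-cluster-signals-failure convention below are illustrative assumptions, not anything defined in the patch.

import org.apache.kafka.clients.admin.internals.AdminMetadataManager;
import org.apache.kafka.common.Cluster;

public class MetadataRefreshSketch {
    // One iteration of the refresh loop modeled by the tests above. A null cluster
    // stands in for a failed metadata fetch (an assumption made only for this sketch).
    static void maybeRefresh(AdminMetadataManager manager, Cluster latestOrNull, long now) {
        if (manager.metadataFetchDelayMs(now) == 0) {          // due: metadata expired or backoff elapsed
            manager.transitionToUpdatePending(now);            // delay becomes Long.MAX_VALUE while in flight
            if (latestOrNull != null) {
                manager.update(latestOrNull, now);             // ready once given a real (non-bootstrap) cluster
            } else {
                manager.updateFailed(new RuntimeException("metadata fetch failed")); // re-arms the refresh backoff
            }
        }
    }
}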
nodes.put(1, new Node(1, "localhost", 8122)); + nodes.put(2, new Node(2, "localhost", 8123)); + return new Cluster("mockClusterId", nodes.values(), + Collections.emptySet(), Collections.emptySet(), + Collections.emptySet(), nodes.get(0)); + } + +} diff --git a/clients/src/test/java/org/apache/kafka/clients/admin/internals/AllBrokersStrategyIntegrationTest.java b/clients/src/test/java/org/apache/kafka/clients/admin/internals/AllBrokersStrategyIntegrationTest.java new file mode 100644 index 0000000..2b98905 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/clients/admin/internals/AllBrokersStrategyIntegrationTest.java @@ -0,0 +1,248 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients.admin.internals; + +import org.apache.kafka.common.Node; +import org.apache.kafka.common.errors.DisconnectException; +import org.apache.kafka.common.errors.UnknownServerException; +import org.apache.kafka.common.internals.KafkaFutureImpl; +import org.apache.kafka.common.message.MetadataRequestData; +import org.apache.kafka.common.message.MetadataResponseData; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.requests.AbstractRequest; +import org.apache.kafka.common.requests.AbstractResponse; +import org.apache.kafka.common.requests.MetadataRequest; +import org.apache.kafka.common.requests.MetadataResponse; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.test.TestUtils; +import org.junit.jupiter.api.Test; + +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.OptionalInt; +import java.util.Set; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class AllBrokersStrategyIntegrationTest { + private static final long TIMEOUT_MS = 5000; + private static final long RETRY_BACKOFF_MS = 100; + + private final LogContext logContext = new LogContext(); + private final MockTime time = new MockTime(); + + private AdminApiDriver buildDriver( + AllBrokersStrategy.AllBrokersFuture result + ) { + return new AdminApiDriver<>( + new MockApiHandler(), + result, + time.milliseconds() + TIMEOUT_MS, + RETRY_BACKOFF_MS, + logContext + ); + } + + @Test + public void testFatalLookupError() { + AllBrokersStrategy.AllBrokersFuture result = new AllBrokersStrategy.AllBrokersFuture<>(); + AdminApiDriver driver = buildDriver(result); + + List> requestSpecs = driver.poll(); + assertEquals(1, requestSpecs.size()); + + AdminApiDriver.RequestSpec spec = 
requestSpecs.get(0); + assertEquals(AllBrokersStrategy.LOOKUP_KEYS, spec.keys); + + driver.onFailure(time.milliseconds(), spec, new UnknownServerException()); + assertTrue(result.all().isDone()); + TestUtils.assertFutureThrows(result.all(), UnknownServerException.class); + assertEquals(Collections.emptyList(), driver.poll()); + } + + @Test + public void testRetryLookupAfterDisconnect() { + AllBrokersStrategy.AllBrokersFuture result = new AllBrokersStrategy.AllBrokersFuture<>(); + AdminApiDriver driver = buildDriver(result); + + List> requestSpecs = driver.poll(); + assertEquals(1, requestSpecs.size()); + + AdminApiDriver.RequestSpec spec = requestSpecs.get(0); + assertEquals(AllBrokersStrategy.LOOKUP_KEYS, spec.keys); + + driver.onFailure(time.milliseconds(), spec, new DisconnectException()); + List> retrySpecs = driver.poll(); + assertEquals(1, retrySpecs.size()); + + AdminApiDriver.RequestSpec retrySpec = retrySpecs.get(0); + assertEquals(AllBrokersStrategy.LOOKUP_KEYS, retrySpec.keys); + assertEquals(time.milliseconds(), retrySpec.nextAllowedTryMs); + assertEquals(Collections.emptyList(), driver.poll()); + } + + @Test + public void testMultiBrokerCompletion() throws Exception { + AllBrokersStrategy.AllBrokersFuture result = new AllBrokersStrategy.AllBrokersFuture<>(); + AdminApiDriver driver = buildDriver(result); + + List> lookupSpecs = driver.poll(); + assertEquals(1, lookupSpecs.size()); + AdminApiDriver.RequestSpec lookupSpec = lookupSpecs.get(0); + + Set brokerIds = Utils.mkSet(1, 2); + driver.onResponse(time.milliseconds(), lookupSpec, responseWithBrokers(brokerIds), Node.noNode()); + assertTrue(result.all().isDone()); + + Map> brokerFutures = result.all().get(); + + List> requestSpecs = driver.poll(); + assertEquals(2, requestSpecs.size()); + + AdminApiDriver.RequestSpec requestSpec1 = requestSpecs.get(0); + assertTrue(requestSpec1.scope.destinationBrokerId().isPresent()); + int brokerId1 = requestSpec1.scope.destinationBrokerId().getAsInt(); + assertTrue(brokerIds.contains(brokerId1)); + + driver.onResponse(time.milliseconds(), requestSpec1, null, Node.noNode()); + KafkaFutureImpl future1 = brokerFutures.get(brokerId1); + assertTrue(future1.isDone()); + + AdminApiDriver.RequestSpec requestSpec2 = requestSpecs.get(1); + assertTrue(requestSpec2.scope.destinationBrokerId().isPresent()); + int brokerId2 = requestSpec2.scope.destinationBrokerId().getAsInt(); + assertNotEquals(brokerId1, brokerId2); + assertTrue(brokerIds.contains(brokerId2)); + + driver.onResponse(time.milliseconds(), requestSpec2, null, Node.noNode()); + KafkaFutureImpl future2 = brokerFutures.get(brokerId2); + assertTrue(future2.isDone()); + assertEquals(Collections.emptyList(), driver.poll()); + } + + @Test + public void testRetryFulfillmentAfterDisconnect() throws Exception { + AllBrokersStrategy.AllBrokersFuture result = new AllBrokersStrategy.AllBrokersFuture<>(); + AdminApiDriver driver = buildDriver(result); + + List> lookupSpecs = driver.poll(); + assertEquals(1, lookupSpecs.size()); + AdminApiDriver.RequestSpec lookupSpec = lookupSpecs.get(0); + + int brokerId = 1; + driver.onResponse(time.milliseconds(), lookupSpec, responseWithBrokers(Collections.singleton(brokerId)), Node.noNode()); + assertTrue(result.all().isDone()); + + Map> brokerFutures = result.all().get(); + KafkaFutureImpl future = brokerFutures.get(brokerId); + assertFalse(future.isDone()); + + List> requestSpecs = driver.poll(); + assertEquals(1, requestSpecs.size()); + AdminApiDriver.RequestSpec requestSpec = requestSpecs.get(0); + + 
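testMultiBrokerCompletion and testRetryFulfillmentAfterDisconnect exercise the two-level completion of AllBrokersStrategy.AllBrokersFuture: the outer future completes once the metadata lookup returns the broker set, and each per-broker future then completes or fails independently as that broker's request is fulfilled. A small consumer-side sketch of that contract, assuming an Integer-valued result as in these tests; the class and method names below are illustrative.

import org.apache.kafka.clients.admin.internals.AllBrokersStrategy;
import org.apache.kafka.common.internals.KafkaFutureImpl;

import java.util.Map;

public class PerBrokerFuturesSketch {
    // Waits on the outer future (completes once the broker set is known), then on each
    // broker's own future, mirroring the assertions in the tests above.
    static void awaitAll(AllBrokersStrategy.AllBrokersFuture<Integer> result) throws Exception {
        Map<Integer, KafkaFutureImpl<Integer>> byBroker = result.all().get();
        for (Map.Entry<Integer, KafkaFutureImpl<Integer>> entry : byBroker.entrySet()) {
            System.out.println("broker " + entry.getKey() + " returned " + entry.getValue().get());
        }
    }
}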
driver.onFailure(time.milliseconds(), requestSpec, new DisconnectException()); + assertFalse(future.isDone()); + List> retrySpecs = driver.poll(); + assertEquals(1, retrySpecs.size()); + + AdminApiDriver.RequestSpec retrySpec = retrySpecs.get(0); + assertEquals(time.milliseconds() + RETRY_BACKOFF_MS, retrySpec.nextAllowedTryMs); + assertEquals(OptionalInt.of(brokerId), retrySpec.scope.destinationBrokerId()); + + driver.onResponse(time.milliseconds(), retrySpec, null, new Node(brokerId, "host", 1234)); + assertTrue(future.isDone()); + assertEquals(brokerId, future.get()); + assertEquals(Collections.emptyList(), driver.poll()); + } + + @Test + public void testFatalFulfillmentError() throws Exception { + AllBrokersStrategy.AllBrokersFuture result = new AllBrokersStrategy.AllBrokersFuture<>(); + AdminApiDriver driver = buildDriver(result); + + List> lookupSpecs = driver.poll(); + assertEquals(1, lookupSpecs.size()); + AdminApiDriver.RequestSpec lookupSpec = lookupSpecs.get(0); + + int brokerId = 1; + driver.onResponse(time.milliseconds(), lookupSpec, responseWithBrokers(Collections.singleton(brokerId)), Node.noNode()); + assertTrue(result.all().isDone()); + + Map> brokerFutures = result.all().get(); + KafkaFutureImpl future = brokerFutures.get(brokerId); + assertFalse(future.isDone()); + + List> requestSpecs = driver.poll(); + assertEquals(1, requestSpecs.size()); + AdminApiDriver.RequestSpec requestSpec = requestSpecs.get(0); + + driver.onFailure(time.milliseconds(), requestSpec, new UnknownServerException()); + assertTrue(future.isDone()); + TestUtils.assertFutureThrows(future, UnknownServerException.class); + assertEquals(Collections.emptyList(), driver.poll()); + } + + private MetadataResponse responseWithBrokers(Set brokerIds) { + MetadataResponseData response = new MetadataResponseData(); + for (Integer brokerId : brokerIds) { + response.brokers().add(new MetadataResponseData.MetadataResponseBroker() + .setNodeId(brokerId) + .setHost("host" + brokerId) + .setPort(9092) + ); + } + return new MetadataResponse(response, ApiKeys.METADATA.latestVersion()); + } + + private class MockApiHandler implements AdminApiHandler { + private final AllBrokersStrategy allBrokersStrategy = new AllBrokersStrategy(logContext); + + @Override + public String apiName() { + return "mock-api"; + } + + @Override + public AbstractRequest.Builder buildRequest( + int brokerId, + Set keys + ) { + return new MetadataRequest.Builder(new MetadataRequestData()); + } + + @Override + public ApiResult handleResponse( + Node broker, + Set keys, + AbstractResponse response + ) { + return ApiResult.completed(keys.iterator().next(), broker.id()); + } + + @Override + public AdminApiLookupStrategy lookupStrategy() { + return allBrokersStrategy; + } + } +} diff --git a/clients/src/test/java/org/apache/kafka/clients/admin/internals/AllBrokersStrategyTest.java b/clients/src/test/java/org/apache/kafka/clients/admin/internals/AllBrokersStrategyTest.java new file mode 100644 index 0000000..8e4b961 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/clients/admin/internals/AllBrokersStrategyTest.java @@ -0,0 +1,124 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients.admin.internals; + +import org.apache.kafka.common.message.MetadataResponseData; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.requests.MetadataRequest; +import org.apache.kafka.common.requests.MetadataResponse; +import org.apache.kafka.common.utils.LogContext; +import org.junit.jupiter.api.Test; + +import java.util.Collections; +import java.util.HashSet; +import java.util.OptionalInt; +import java.util.Set; + +import static org.apache.kafka.common.utils.Utils.mkSet; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +class AllBrokersStrategyTest { + private final LogContext logContext = new LogContext(); + + @Test + public void testBuildRequest() { + AllBrokersStrategy strategy = new AllBrokersStrategy(logContext); + MetadataRequest.Builder builder = strategy.buildRequest(AllBrokersStrategy.LOOKUP_KEYS); + assertEquals(Collections.emptyList(), builder.topics()); + } + + @Test + public void testBuildRequestWithInvalidLookupKeys() { + AllBrokersStrategy strategy = new AllBrokersStrategy(logContext); + AllBrokersStrategy.BrokerKey key1 = new AllBrokersStrategy.BrokerKey(OptionalInt.empty()); + AllBrokersStrategy.BrokerKey key2 = new AllBrokersStrategy.BrokerKey(OptionalInt.of(1)); + assertThrows(IllegalArgumentException.class, () -> strategy.buildRequest(mkSet(key1))); + assertThrows(IllegalArgumentException.class, () -> strategy.buildRequest(mkSet(key2))); + assertThrows(IllegalArgumentException.class, () -> strategy.buildRequest(mkSet(key1, key2))); + + Set keys = new HashSet<>(AllBrokersStrategy.LOOKUP_KEYS); + keys.add(key2); + assertThrows(IllegalArgumentException.class, () -> strategy.buildRequest(keys)); + } + + @Test + public void testHandleResponse() { + AllBrokersStrategy strategy = new AllBrokersStrategy(logContext); + + MetadataResponseData response = new MetadataResponseData(); + response.brokers().add(new MetadataResponseData.MetadataResponseBroker() + .setNodeId(1) + .setHost("host1") + .setPort(9092) + ); + response.brokers().add(new MetadataResponseData.MetadataResponseBroker() + .setNodeId(2) + .setHost("host2") + .setPort(9092) + ); + + AdminApiLookupStrategy.LookupResult lookupResult = strategy.handleResponse( + AllBrokersStrategy.LOOKUP_KEYS, + new MetadataResponse(response, ApiKeys.METADATA.latestVersion()) + ); + + assertEquals(Collections.emptyMap(), lookupResult.failedKeys); + + Set expectedMappedKeys = mkSet( + new AllBrokersStrategy.BrokerKey(OptionalInt.of(1)), + new AllBrokersStrategy.BrokerKey(OptionalInt.of(2)) + ); + + assertEquals(expectedMappedKeys, lookupResult.mappedKeys.keySet()); + lookupResult.mappedKeys.forEach((brokerKey, brokerId) -> { + assertEquals(OptionalInt.of(brokerId), brokerKey.brokerId); + }); + } + + @Test + public void testHandleResponseWithNoBrokers() { + AllBrokersStrategy strategy = new AllBrokersStrategy(logContext); + + MetadataResponseData response = new MetadataResponseData(); + + AdminApiLookupStrategy.LookupResult lookupResult = strategy.handleResponse( + 
AllBrokersStrategy.LOOKUP_KEYS, + new MetadataResponse(response, ApiKeys.METADATA.latestVersion()) + ); + + assertEquals(Collections.emptyMap(), lookupResult.failedKeys); + assertEquals(Collections.emptyMap(), lookupResult.mappedKeys); + } + + @Test + public void testHandleResponseWithInvalidLookupKeys() { + AllBrokersStrategy strategy = new AllBrokersStrategy(logContext); + AllBrokersStrategy.BrokerKey key1 = new AllBrokersStrategy.BrokerKey(OptionalInt.empty()); + AllBrokersStrategy.BrokerKey key2 = new AllBrokersStrategy.BrokerKey(OptionalInt.of(1)); + MetadataResponse response = new MetadataResponse(new MetadataResponseData(), ApiKeys.METADATA.latestVersion()); + + assertThrows(IllegalArgumentException.class, () -> strategy.handleResponse(mkSet(key1), response)); + assertThrows(IllegalArgumentException.class, () -> strategy.handleResponse(mkSet(key2), response)); + assertThrows(IllegalArgumentException.class, () -> strategy.handleResponse(mkSet(key1, key2), response)); + + Set keys = new HashSet<>(AllBrokersStrategy.LOOKUP_KEYS); + keys.add(key2); + assertThrows(IllegalArgumentException.class, () -> strategy.handleResponse(keys, response)); + } + +} diff --git a/clients/src/test/java/org/apache/kafka/clients/admin/internals/AlterConsumerGroupOffsetsHandlerTest.java b/clients/src/test/java/org/apache/kafka/clients/admin/internals/AlterConsumerGroupOffsetsHandlerTest.java new file mode 100644 index 0000000..c0ea2ba --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/clients/admin/internals/AlterConsumerGroupOffsetsHandlerTest.java @@ -0,0 +1,189 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.clients.admin.internals; + +import static java.util.Collections.emptyMap; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static java.util.Collections.singleton; +import static java.util.Collections.singletonList; +import static java.util.Collections.emptyList; +import static java.util.Collections.emptySet; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +import org.apache.kafka.clients.admin.internals.AdminApiHandler.ApiResult; +import org.apache.kafka.clients.consumer.OffsetAndMetadata; +import org.apache.kafka.common.Node; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.protocol.Errors; +import org.apache.kafka.common.requests.OffsetCommitRequest; +import org.apache.kafka.common.requests.OffsetCommitResponse; +import org.apache.kafka.common.utils.LogContext; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +public class AlterConsumerGroupOffsetsHandlerTest { + + private final LogContext logContext = new LogContext(); + private final String groupId = "group-id"; + private final TopicPartition t0p0 = new TopicPartition("t0", 0); + private final TopicPartition t0p1 = new TopicPartition("t0", 1); + private final TopicPartition t1p0 = new TopicPartition("t1", 0); + private final TopicPartition t1p1 = new TopicPartition("t1", 1); + private final Map partitions = new HashMap<>(); + private final long offset = 1L; + private final Node node = new Node(1, "host", 1234); + + @BeforeEach + public void setUp() { + partitions.put(t0p0, new OffsetAndMetadata(offset)); + partitions.put(t0p1, new OffsetAndMetadata(offset)); + partitions.put(t1p0, new OffsetAndMetadata(offset)); + partitions.put(t1p1, new OffsetAndMetadata(offset)); + } + + @Test + public void testBuildRequest() { + AlterConsumerGroupOffsetsHandler handler = new AlterConsumerGroupOffsetsHandler(groupId, partitions, logContext); + OffsetCommitRequest request = handler.buildRequest(-1, singleton(CoordinatorKey.byGroupId(groupId))).build(); + assertEquals(groupId, request.data().groupId()); + assertEquals(2, request.data().topics().size()); + assertEquals(2, request.data().topics().get(0).partitions().size()); + assertEquals(offset, request.data().topics().get(0).partitions().get(0).committedOffset()); + } + + @Test + public void testHandleSuccessfulResponse() { + AlterConsumerGroupOffsetsHandler handler = new AlterConsumerGroupOffsetsHandler(groupId, partitions, logContext); + Map responseData = Collections.singletonMap(t0p0, Errors.NONE); + OffsetCommitResponse response = new OffsetCommitResponse(0, responseData); + ApiResult> result = handler.handleResponse(node, singleton(CoordinatorKey.byGroupId(groupId)), response); + assertCompleted(result, responseData); + } + + @Test + public void testHandleRetriableResponse() { + assertUnmappedKey(partitionErrors(Errors.NOT_COORDINATOR)); + assertUnmappedKey(partitionErrors(Errors.COORDINATOR_NOT_AVAILABLE)); + assertRetriableError(partitionErrors(Errors.COORDINATOR_LOAD_IN_PROGRESS)); + assertRetriableError(partitionErrors(Errors.REBALANCE_IN_PROGRESS)); + } + + @Test + public void testHandleErrorResponse() { + assertFatalError(partitionErrors(Errors.TOPIC_AUTHORIZATION_FAILED)); + assertFatalError(partitionErrors(Errors.GROUP_AUTHORIZATION_FAILED)); + assertFatalError(partitionErrors(Errors.INVALID_GROUP_ID)); + assertFatalError(partitionErrors(Errors.UNKNOWN_TOPIC_OR_PARTITION)); + 
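AlterConsumerGroupOffsetsHandler is the internal handler behind the admin client's alterConsumerGroupOffsets call, and the tests that follow check how group- and partition-level OffsetCommit errors are surfaced through it. A short usage sketch of the public entry point, assuming a locally reachable broker; the bootstrap address, group id, topic, and offset below are placeholders.

import org.apache.kafka.clients.admin.Admin;
import org.apache.kafka.clients.admin.AdminClientConfig;
import org.apache.kafka.clients.consumer.OffsetAndMetadata;
import org.apache.kafka.common.TopicPartition;

import java.util.Collections;
import java.util.Map;
import java.util.Properties;

public class AlterGroupOffsetsSketch {
    public static void main(String[] args) throws Exception {
        Properties props = new Properties();
        props.put(AdminClientConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9092"); // placeholder address
        try (Admin admin = Admin.create(props)) {
            Map<TopicPartition, OffsetAndMetadata> offsets =
                Collections.singletonMap(new TopicPartition("t0", 0), new OffsetAndMetadata(1L));
            // Commit errors surface as exceptions when the returned futures are inspected.
            admin.alterConsumerGroupOffsets("group-id", offsets).all().get();
        }
    }
}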
assertFatalError(partitionErrors(Errors.OFFSET_METADATA_TOO_LARGE)); + assertFatalError(partitionErrors(Errors.ILLEGAL_GENERATION)); + assertFatalError(partitionErrors(Errors.UNKNOWN_MEMBER_ID)); + assertFatalError(partitionErrors(Errors.INVALID_COMMIT_OFFSET_SIZE)); + assertFatalError(partitionErrors(Errors.UNKNOWN_SERVER_ERROR)); + } + + @Test + public void testHandleMultipleErrorsResponse() { + Map partitionErrors = new HashMap<>(); + partitionErrors.put(t0p0, Errors.UNKNOWN_TOPIC_OR_PARTITION); + partitionErrors.put(t0p1, Errors.INVALID_COMMIT_OFFSET_SIZE); + partitionErrors.put(t1p0, Errors.TOPIC_AUTHORIZATION_FAILED); + partitionErrors.put(t1p1, Errors.OFFSET_METADATA_TOO_LARGE); + assertFatalError(partitionErrors); + } + + private AdminApiHandler.ApiResult> handleResponse( + CoordinatorKey groupKey, + Map partitions, + Map partitionResults + ) { + AlterConsumerGroupOffsetsHandler handler = + new AlterConsumerGroupOffsetsHandler(groupKey.idValue, partitions, logContext); + OffsetCommitResponse response = new OffsetCommitResponse(0, partitionResults); + return handler.handleResponse(node, singleton(groupKey), response); + } + + private Map partitionErrors( + Errors error + ) { + Map partitionErrors = new HashMap<>(); + partitions.keySet().forEach(partition -> + partitionErrors.put(partition, error) + ); + return partitionErrors; + } + + private void assertFatalError( + Map partitionResults + ) { + CoordinatorKey groupKey = CoordinatorKey.byGroupId(groupId); + AdminApiHandler.ApiResult> result = handleResponse( + groupKey, + partitions, + partitionResults + ); + + assertEquals(singleton(groupKey), result.completedKeys.keySet()); + assertEquals(partitionResults, result.completedKeys.get(groupKey)); + assertEquals(emptyList(), result.unmappedKeys); + assertEquals(emptyMap(), result.failedKeys); + } + + private void assertRetriableError( + Map partitionResults + ) { + CoordinatorKey groupKey = CoordinatorKey.byGroupId(groupId); + AdminApiHandler.ApiResult> result = handleResponse( + groupKey, + partitions, + partitionResults + ); + + assertEquals(emptySet(), result.completedKeys.keySet()); + assertEquals(emptyList(), result.unmappedKeys); + assertEquals(emptyMap(), result.failedKeys); + } + + private void assertUnmappedKey( + Map partitionResults + ) { + CoordinatorKey groupKey = CoordinatorKey.byGroupId(groupId); + AdminApiHandler.ApiResult> result = handleResponse( + groupKey, + partitions, + partitionResults + ); + + assertEquals(emptySet(), result.completedKeys.keySet()); + assertEquals(emptySet(), result.failedKeys.keySet()); + assertEquals(singletonList(CoordinatorKey.byGroupId(groupId)), result.unmappedKeys); + } + + private void assertCompleted( + AdminApiHandler.ApiResult> result, + Map expected + ) { + CoordinatorKey key = CoordinatorKey.byGroupId(groupId); + assertEquals(emptySet(), result.failedKeys.keySet()); + assertEquals(emptyList(), result.unmappedKeys); + assertEquals(singleton(key), result.completedKeys.keySet()); + assertEquals(expected, result.completedKeys.get(key)); + } +} diff --git a/clients/src/test/java/org/apache/kafka/clients/admin/internals/CoordinatorStrategyTest.java b/clients/src/test/java/org/apache/kafka/clients/admin/internals/CoordinatorStrategyTest.java new file mode 100644 index 0000000..dd83c6b --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/clients/admin/internals/CoordinatorStrategyTest.java @@ -0,0 +1,291 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients.admin.internals; + +import org.apache.kafka.common.errors.GroupAuthorizationException; +import org.apache.kafka.common.message.FindCoordinatorResponseData; +import org.apache.kafka.common.protocol.Errors; +import org.apache.kafka.common.requests.FindCoordinatorRequest; +import org.apache.kafka.common.requests.FindCoordinatorRequest.CoordinatorType; +import org.apache.kafka.common.requests.FindCoordinatorResponse; +import org.apache.kafka.common.utils.LogContext; +import org.junit.jupiter.api.Test; + +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; + +import static java.util.Collections.emptyMap; +import static java.util.Collections.singleton; +import static java.util.Collections.singletonMap; +import static org.apache.kafka.common.utils.Utils.mkSet; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class CoordinatorStrategyTest { + + @Test + public void testBuildOldLookupRequest() { + CoordinatorStrategy strategy = new CoordinatorStrategy(CoordinatorType.GROUP, new LogContext()); + strategy.disableBatch(); + FindCoordinatorRequest.Builder request = strategy.buildRequest(singleton( + CoordinatorKey.byGroupId("foo"))); + assertEquals("foo", request.data().key()); + assertEquals(CoordinatorType.GROUP, CoordinatorType.forId(request.data().keyType())); + } + + @Test + public void testBuildLookupRequest() { + CoordinatorStrategy strategy = new CoordinatorStrategy(CoordinatorType.GROUP, new LogContext()); + FindCoordinatorRequest.Builder request = strategy.buildRequest(new HashSet<>(Arrays.asList( + CoordinatorKey.byGroupId("foo"), + CoordinatorKey.byGroupId("bar")))); + assertEquals("", request.data().key()); + assertEquals(2, request.data().coordinatorKeys().size()); + assertEquals(CoordinatorType.GROUP, CoordinatorType.forId(request.data().keyType())); + } + + @Test + public void testBuildLookupRequestNonRepresentable() { + CoordinatorStrategy strategy = new CoordinatorStrategy(CoordinatorType.GROUP, new LogContext()); + FindCoordinatorRequest.Builder request = strategy.buildRequest(new HashSet<>(Arrays.asList( + CoordinatorKey.byGroupId("foo"), + null))); + assertEquals("", request.data().key()); + assertEquals(1, request.data().coordinatorKeys().size()); + } + + @Test + public void testBuildOldLookupRequestRequiresOneKey() { + CoordinatorStrategy strategy = new CoordinatorStrategy(CoordinatorType.GROUP, new LogContext()); + strategy.disableBatch(); + assertThrows(IllegalArgumentException.class, () -> strategy.buildRequest(Collections.emptySet())); + + CoordinatorKey group1 = CoordinatorKey.byGroupId("foo"); + CoordinatorKey 
group2 = CoordinatorKey.byGroupId("bar"); + assertThrows(IllegalArgumentException.class, () -> strategy.buildRequest(mkSet(group1, group2))); + } + + @Test + public void testBuildOldLookupRequestRequiresAtLeastOneKey() { + CoordinatorStrategy strategy = new CoordinatorStrategy(CoordinatorType.GROUP, new LogContext()); + strategy.disableBatch(); + + assertThrows(IllegalArgumentException.class, () -> strategy.buildRequest( + new HashSet<>(Arrays.asList(CoordinatorKey.byTransactionalId("txnid"))))); + } + + @Test + public void testBuildLookupRequestRequiresAtLeastOneKey() { + CoordinatorStrategy strategy = new CoordinatorStrategy(CoordinatorType.GROUP, new LogContext()); + + assertThrows(IllegalArgumentException.class, () -> strategy.buildRequest(Collections.emptySet())); + } + + @Test + public void testBuildLookupRequestRequiresKeySameType() { + CoordinatorStrategy strategy = new CoordinatorStrategy(CoordinatorType.GROUP, new LogContext()); + + assertThrows(IllegalArgumentException.class, () -> strategy.buildRequest( + new HashSet<>(Arrays.asList( + CoordinatorKey.byGroupId("group"), + CoordinatorKey.byTransactionalId("txnid"))))); + } + + @Test + public void testHandleOldResponseRequiresOneKey() { + FindCoordinatorResponseData responseData = new FindCoordinatorResponseData().setErrorCode(Errors.NONE.code()); + FindCoordinatorResponse response = new FindCoordinatorResponse(responseData); + + CoordinatorStrategy strategy = new CoordinatorStrategy(CoordinatorType.GROUP, new LogContext()); + strategy.disableBatch(); + assertThrows(IllegalArgumentException.class, () -> + strategy.handleResponse(Collections.emptySet(), response)); + + CoordinatorKey group1 = CoordinatorKey.byGroupId("foo"); + CoordinatorKey group2 = CoordinatorKey.byGroupId("bar"); + assertThrows(IllegalArgumentException.class, () -> + strategy.handleResponse(mkSet(group1, group2), response)); + } + + @Test + public void testSuccessfulOldCoordinatorLookup() { + CoordinatorKey group = CoordinatorKey.byGroupId("foo"); + + FindCoordinatorResponseData responseData = new FindCoordinatorResponseData() + .setErrorCode(Errors.NONE.code()) + .setHost("localhost") + .setPort(9092) + .setNodeId(1); + + AdminApiLookupStrategy.LookupResult result = runOldLookup(group, responseData); + assertEquals(singletonMap(group, 1), result.mappedKeys); + assertEquals(emptyMap(), result.failedKeys); + } + + @Test + public void testSuccessfulCoordinatorLookup() { + CoordinatorKey group1 = CoordinatorKey.byGroupId("foo"); + CoordinatorKey group2 = CoordinatorKey.byGroupId("bar"); + + FindCoordinatorResponseData responseData = new FindCoordinatorResponseData() + .setCoordinators(Arrays.asList( + new FindCoordinatorResponseData.Coordinator() + .setKey("foo") + .setErrorCode(Errors.NONE.code()) + .setHost("localhost") + .setPort(9092) + .setNodeId(1), + new FindCoordinatorResponseData.Coordinator() + .setKey("bar") + .setErrorCode(Errors.NONE.code()) + .setHost("localhost") + .setPort(9092) + .setNodeId(2))); + + AdminApiLookupStrategy.LookupResult result = runLookup(new HashSet<>(Arrays.asList(group1, group2)), responseData); + Map expectedResult = new HashMap<>(); + expectedResult.put(group1, 1); + expectedResult.put(group2, 2); + assertEquals(expectedResult, result.mappedKeys); + assertEquals(emptyMap(), result.failedKeys); + } + + @Test + public void testRetriableOldCoordinatorLookup() { + testRetriableOldCoordinatorLookup(Errors.COORDINATOR_LOAD_IN_PROGRESS); + testRetriableOldCoordinatorLookup(Errors.COORDINATOR_NOT_AVAILABLE); + } + + private void 
testRetriableOldCoordinatorLookup(Errors error) { + CoordinatorKey group = CoordinatorKey.byGroupId("foo"); + FindCoordinatorResponseData responseData = new FindCoordinatorResponseData().setErrorCode(error.code()); + AdminApiLookupStrategy.LookupResult result = runOldLookup(group, responseData); + + assertEquals(emptyMap(), result.failedKeys); + assertEquals(emptyMap(), result.mappedKeys); + } + + @Test + public void testRetriableCoordinatorLookup() { + testRetriableCoordinatorLookup(Errors.COORDINATOR_LOAD_IN_PROGRESS); + testRetriableCoordinatorLookup(Errors.COORDINATOR_NOT_AVAILABLE); + } + + private void testRetriableCoordinatorLookup(Errors error) { + CoordinatorKey group1 = CoordinatorKey.byGroupId("foo"); + CoordinatorKey group2 = CoordinatorKey.byGroupId("bar"); + FindCoordinatorResponseData responseData = new FindCoordinatorResponseData() + .setCoordinators(Arrays.asList( + new FindCoordinatorResponseData.Coordinator() + .setKey("foo") + .setErrorCode(error.code()), + new FindCoordinatorResponseData.Coordinator() + .setKey("bar") + .setErrorCode(Errors.NONE.code()) + .setHost("localhost") + .setPort(9092) + .setNodeId(2))); + AdminApiLookupStrategy.LookupResult result = runLookup(new HashSet<>(Arrays.asList(group1, group2)), responseData); + + assertEquals(emptyMap(), result.failedKeys); + assertEquals(singletonMap(group2, 2), result.mappedKeys); + } + + @Test + public void testFatalErrorOldLookupResponses() { + CoordinatorKey group = CoordinatorKey.byTransactionalId("foo"); + assertFatalOldLookup(group, Errors.TRANSACTIONAL_ID_AUTHORIZATION_FAILED); + assertFatalOldLookup(group, Errors.UNKNOWN_SERVER_ERROR); + + Throwable throwable = assertFatalOldLookup(group, Errors.GROUP_AUTHORIZATION_FAILED); + assertTrue(throwable instanceof GroupAuthorizationException); + GroupAuthorizationException exception = (GroupAuthorizationException) throwable; + assertEquals("foo", exception.groupId()); + } + + public Throwable assertFatalOldLookup( + CoordinatorKey key, + Errors error + ) { + FindCoordinatorResponseData responseData = new FindCoordinatorResponseData().setErrorCode(error.code()); + AdminApiLookupStrategy.LookupResult result = runOldLookup(key, responseData); + + assertEquals(emptyMap(), result.mappedKeys); + assertEquals(singleton(key), result.failedKeys.keySet()); + + Throwable throwable = result.failedKeys.get(key); + assertTrue(error.exception().getClass().isInstance(throwable)); + return throwable; + } + + @Test + public void testFatalErrorLookupResponses() { + CoordinatorKey group = CoordinatorKey.byTransactionalId("foo"); + assertFatalLookup(group, Errors.TRANSACTIONAL_ID_AUTHORIZATION_FAILED); + assertFatalLookup(group, Errors.UNKNOWN_SERVER_ERROR); + + Throwable throwable = assertFatalLookup(group, Errors.GROUP_AUTHORIZATION_FAILED); + assertTrue(throwable instanceof GroupAuthorizationException); + GroupAuthorizationException exception = (GroupAuthorizationException) throwable; + assertEquals("foo", exception.groupId()); + } + + public Throwable assertFatalLookup( + CoordinatorKey key, + Errors error + ) { + FindCoordinatorResponseData responseData = new FindCoordinatorResponseData() + .setCoordinators(Collections.singletonList( + new FindCoordinatorResponseData.Coordinator() + .setKey(key.idValue) + .setErrorCode(error.code()))); + AdminApiLookupStrategy.LookupResult result = runLookup(singleton(key), responseData); + + assertEquals(emptyMap(), result.mappedKeys); + assertEquals(singleton(key), result.failedKeys.keySet()); + + Throwable throwable = 
result.failedKeys.get(key); + assertTrue(error.exception().getClass().isInstance(throwable)); + return throwable; + } + + private AdminApiLookupStrategy.LookupResult runOldLookup( + CoordinatorKey key, + FindCoordinatorResponseData responseData + ) { + CoordinatorStrategy strategy = new CoordinatorStrategy(key.type, new LogContext()); + strategy.disableBatch(); + FindCoordinatorResponse response = new FindCoordinatorResponse(responseData); + return strategy.handleResponse(singleton(key), response); + } + + private AdminApiLookupStrategy.LookupResult runLookup( + Set keys, + FindCoordinatorResponseData responseData + ) { + CoordinatorStrategy strategy = new CoordinatorStrategy(keys.iterator().next().type, new LogContext()); + strategy.buildRequest(keys); + FindCoordinatorResponse response = new FindCoordinatorResponse(responseData); + return strategy.handleResponse(keys, response); + } + +} diff --git a/clients/src/test/java/org/apache/kafka/clients/admin/internals/DeleteConsumerGroupOffsetsHandlerTest.java b/clients/src/test/java/org/apache/kafka/clients/admin/internals/DeleteConsumerGroupOffsetsHandlerTest.java new file mode 100644 index 0000000..b4aea93 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/clients/admin/internals/DeleteConsumerGroupOffsetsHandlerTest.java @@ -0,0 +1,210 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.clients.admin.internals; + +import static java.util.Collections.emptyList; +import static java.util.Collections.emptySet; +import static java.util.Collections.singleton; +import static java.util.Collections.singletonList; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; + +import org.apache.kafka.common.Node; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.errors.GroupAuthorizationException; +import org.apache.kafka.common.errors.GroupIdNotFoundException; +import org.apache.kafka.common.errors.GroupNotEmptyException; +import org.apache.kafka.common.errors.InvalidGroupIdException; +import org.apache.kafka.common.message.OffsetDeleteResponseData; +import org.apache.kafka.common.message.OffsetDeleteResponseData.OffsetDeleteResponsePartition; +import org.apache.kafka.common.message.OffsetDeleteResponseData.OffsetDeleteResponsePartitionCollection; +import org.apache.kafka.common.message.OffsetDeleteResponseData.OffsetDeleteResponseTopic; +import org.apache.kafka.common.message.OffsetDeleteResponseData.OffsetDeleteResponseTopicCollection; +import org.apache.kafka.common.protocol.Errors; +import org.apache.kafka.common.requests.OffsetDeleteRequest; +import org.apache.kafka.common.requests.OffsetDeleteResponse; +import org.apache.kafka.common.utils.LogContext; +import org.junit.jupiter.api.Test; + +public class DeleteConsumerGroupOffsetsHandlerTest { + + private final LogContext logContext = new LogContext(); + private final String groupId = "group-id"; + private final TopicPartition t0p0 = new TopicPartition("t0", 0); + private final TopicPartition t0p1 = new TopicPartition("t0", 1); + private final TopicPartition t1p0 = new TopicPartition("t1", 0); + private final Set tps = new HashSet<>(Arrays.asList(t0p0, t0p1, t1p0)); + + @Test + public void testBuildRequest() { + DeleteConsumerGroupOffsetsHandler handler = new DeleteConsumerGroupOffsetsHandler(groupId, tps, logContext); + OffsetDeleteRequest request = handler.buildRequest(1, singleton(CoordinatorKey.byGroupId(groupId))).build(); + assertEquals(groupId, request.data().groupId()); + assertEquals(2, request.data().topics().size()); + assertEquals(2, request.data().topics().find("t0").partitions().size()); + assertEquals(1, request.data().topics().find("t1").partitions().size()); + } + + @Test + public void testSuccessfulHandleResponse() { + Map responseData = Collections.singletonMap(t0p0, Errors.NONE); + assertCompleted(handleWithGroupError(Errors.NONE), responseData); + } + + @Test + public void testUnmappedHandleResponse() { + assertUnmapped(handleWithGroupError(Errors.NOT_COORDINATOR)); + assertUnmapped(handleWithGroupError(Errors.COORDINATOR_NOT_AVAILABLE)); + } + + @Test + public void testRetriableHandleResponse() { + assertRetriable(handleWithGroupError(Errors.COORDINATOR_LOAD_IN_PROGRESS)); + } + + @Test + public void testFailedHandleResponseWithGroupError() { + assertGroupFailed(GroupAuthorizationException.class, handleWithGroupError(Errors.GROUP_AUTHORIZATION_FAILED)); + assertGroupFailed(GroupIdNotFoundException.class, handleWithGroupError(Errors.GROUP_ID_NOT_FOUND)); + assertGroupFailed(InvalidGroupIdException.class, handleWithGroupError(Errors.INVALID_GROUP_ID)); + assertGroupFailed(GroupNotEmptyException.class, 
handleWithGroupError(Errors.NON_EMPTY_GROUP)); + } + + @Test + public void testFailedHandleResponseWithPartitionError() { + assertPartitionFailed(Collections.singletonMap(t0p0, Errors.GROUP_SUBSCRIBED_TO_TOPIC), + handleWithPartitionError(Errors.GROUP_SUBSCRIBED_TO_TOPIC)); + assertPartitionFailed(Collections.singletonMap(t0p0, Errors.TOPIC_AUTHORIZATION_FAILED), + handleWithPartitionError(Errors.TOPIC_AUTHORIZATION_FAILED)); + assertPartitionFailed(Collections.singletonMap(t0p0, Errors.UNKNOWN_TOPIC_OR_PARTITION), + handleWithPartitionError(Errors.UNKNOWN_TOPIC_OR_PARTITION)); + } + + private OffsetDeleteResponse buildGroupErrorResponse(Errors error) { + OffsetDeleteResponse response = new OffsetDeleteResponse( + new OffsetDeleteResponseData() + .setErrorCode(error.code())); + if (error == Errors.NONE) { + response.data() + .setThrottleTimeMs(0) + .setTopics(new OffsetDeleteResponseTopicCollection(singletonList( + new OffsetDeleteResponseTopic() + .setName(t0p0.topic()) + .setPartitions(new OffsetDeleteResponsePartitionCollection(singletonList( + new OffsetDeleteResponsePartition() + .setPartitionIndex(t0p0.partition()) + .setErrorCode(error.code()) + ).iterator())) + ).iterator())); + } + return response; + } + + private OffsetDeleteResponse buildPartitionErrorResponse(Errors error) { + OffsetDeleteResponse response = new OffsetDeleteResponse( + new OffsetDeleteResponseData() + .setThrottleTimeMs(0) + .setTopics(new OffsetDeleteResponseTopicCollection(singletonList( + new OffsetDeleteResponseTopic() + .setName(t0p0.topic()) + .setPartitions(new OffsetDeleteResponsePartitionCollection(singletonList( + new OffsetDeleteResponsePartition() + .setPartitionIndex(t0p0.partition()) + .setErrorCode(error.code()) + ).iterator())) + ).iterator())) + ); + return response; + } + + private AdminApiHandler.ApiResult> handleWithGroupError( + Errors error + ) { + DeleteConsumerGroupOffsetsHandler handler = new DeleteConsumerGroupOffsetsHandler(groupId, tps, logContext); + OffsetDeleteResponse response = buildGroupErrorResponse(error); + return handler.handleResponse(new Node(1, "host", 1234), singleton(CoordinatorKey.byGroupId(groupId)), response); + } + + private AdminApiHandler.ApiResult> handleWithPartitionError( + Errors error + ) { + DeleteConsumerGroupOffsetsHandler handler = new DeleteConsumerGroupOffsetsHandler(groupId, tps, logContext); + OffsetDeleteResponse response = buildPartitionErrorResponse(error); + return handler.handleResponse(new Node(1, "host", 1234), singleton(CoordinatorKey.byGroupId(groupId)), response); + } + + private void assertUnmapped( + AdminApiHandler.ApiResult> result + ) { + assertEquals(emptySet(), result.completedKeys.keySet()); + assertEquals(emptySet(), result.failedKeys.keySet()); + assertEquals(singletonList(CoordinatorKey.byGroupId(groupId)), result.unmappedKeys); + } + + private void assertRetriable( + AdminApiHandler.ApiResult> result + ) { + assertEquals(emptySet(), result.completedKeys.keySet()); + assertEquals(emptySet(), result.failedKeys.keySet()); + assertEquals(emptyList(), result.unmappedKeys); + } + + private void assertCompleted( + AdminApiHandler.ApiResult> result, + Map expected + ) { + CoordinatorKey key = CoordinatorKey.byGroupId(groupId); + assertEquals(emptySet(), result.failedKeys.keySet()); + assertEquals(emptyList(), result.unmappedKeys); + assertEquals(singleton(key), result.completedKeys.keySet()); + assertEquals(expected, result.completedKeys.get(key)); + } + + private void assertGroupFailed( + Class expectedExceptionType, + 
AdminApiHandler.ApiResult> result + ) { + CoordinatorKey key = CoordinatorKey.byGroupId(groupId); + assertEquals(emptySet(), result.completedKeys.keySet()); + assertEquals(emptyList(), result.unmappedKeys); + assertEquals(singleton(key), result.failedKeys.keySet()); + assertTrue(expectedExceptionType.isInstance(result.failedKeys.get(key))); + } + + private void assertPartitionFailed( + Map expectedResult, + AdminApiHandler.ApiResult> result + ) { + CoordinatorKey key = CoordinatorKey.byGroupId(groupId); + assertEquals(singleton(key), result.completedKeys.keySet()); + + // verify the completed value is expected result + Collection> completeCollection = result.completedKeys.values(); + assertEquals(1, completeCollection.size()); + assertEquals(expectedResult, result.completedKeys.get(key)); + + assertEquals(emptyList(), result.unmappedKeys); + assertEquals(emptySet(), result.failedKeys.keySet()); + } +} diff --git a/clients/src/test/java/org/apache/kafka/clients/admin/internals/DeleteConsumerGroupsHandlerTest.java b/clients/src/test/java/org/apache/kafka/clients/admin/internals/DeleteConsumerGroupsHandlerTest.java new file mode 100644 index 0000000..8d3a237 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/clients/admin/internals/DeleteConsumerGroupsHandlerTest.java @@ -0,0 +1,130 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.clients.admin.internals; + +import static java.util.Collections.emptyList; +import static java.util.Collections.emptySet; +import static java.util.Collections.singleton; +import static java.util.Collections.singletonList; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import org.apache.kafka.common.Node; +import org.apache.kafka.common.errors.GroupAuthorizationException; +import org.apache.kafka.common.errors.GroupIdNotFoundException; +import org.apache.kafka.common.errors.GroupNotEmptyException; +import org.apache.kafka.common.errors.InvalidGroupIdException; +import org.apache.kafka.common.message.DeleteGroupsResponseData; +import org.apache.kafka.common.message.DeleteGroupsResponseData.DeletableGroupResult; +import org.apache.kafka.common.message.DeleteGroupsResponseData.DeletableGroupResultCollection; +import org.apache.kafka.common.protocol.Errors; +import org.apache.kafka.common.requests.DeleteGroupsRequest; +import org.apache.kafka.common.requests.DeleteGroupsResponse; +import org.apache.kafka.common.utils.LogContext; +import org.junit.jupiter.api.Test; + +public class DeleteConsumerGroupsHandlerTest { + + private final LogContext logContext = new LogContext(); + private final String groupId1 = "group-id1"; + + @Test + public void testBuildRequest() { + DeleteConsumerGroupsHandler handler = new DeleteConsumerGroupsHandler(logContext); + DeleteGroupsRequest request = handler.buildRequest(1, singleton(CoordinatorKey.byGroupId(groupId1))).build(); + assertEquals(1, request.data().groupsNames().size()); + assertEquals(groupId1, request.data().groupsNames().get(0)); + } + + @Test + public void testSuccessfulHandleResponse() { + assertCompleted(handleWithError(Errors.NONE)); + } + + @Test + public void testUnmappedHandleResponse() { + assertUnmapped(handleWithError(Errors.NOT_COORDINATOR)); + assertUnmapped(handleWithError(Errors.COORDINATOR_NOT_AVAILABLE)); + } + + @Test + public void testRetriableHandleResponse() { + assertRetriable(handleWithError(Errors.COORDINATOR_LOAD_IN_PROGRESS)); + } + + @Test + public void testFailedHandleResponse() { + assertFailed(GroupAuthorizationException.class, handleWithError(Errors.GROUP_AUTHORIZATION_FAILED)); + assertFailed(GroupIdNotFoundException.class, handleWithError(Errors.GROUP_ID_NOT_FOUND)); + assertFailed(InvalidGroupIdException.class, handleWithError(Errors.INVALID_GROUP_ID)); + assertFailed(GroupNotEmptyException.class, handleWithError(Errors.NON_EMPTY_GROUP)); + } + + private DeleteGroupsResponse buildResponse(Errors error) { + DeleteGroupsResponse response = new DeleteGroupsResponse( + new DeleteGroupsResponseData() + .setResults(new DeletableGroupResultCollection(singletonList( + new DeletableGroupResult() + .setErrorCode(error.code()) + .setGroupId(groupId1)).iterator()))); + return response; + } + + private AdminApiHandler.ApiResult handleWithError( + Errors error + ) { + DeleteConsumerGroupsHandler handler = new DeleteConsumerGroupsHandler(logContext); + DeleteGroupsResponse response = buildResponse(error); + return handler.handleResponse(new Node(1, "host", 1234), singleton(CoordinatorKey.byGroupId(groupId1)), response); + } + + private void assertUnmapped( + AdminApiHandler.ApiResult result + ) { + assertEquals(emptySet(), result.completedKeys.keySet()); + assertEquals(emptySet(), result.failedKeys.keySet()); + assertEquals(singletonList(CoordinatorKey.byGroupId(groupId1)), result.unmappedKeys); + } + + private void 
assertRetriable( + AdminApiHandler.ApiResult result + ) { + assertEquals(emptySet(), result.completedKeys.keySet()); + assertEquals(emptySet(), result.failedKeys.keySet()); + assertEquals(emptyList(), result.unmappedKeys); + } + + private void assertCompleted( + AdminApiHandler.ApiResult result + ) { + CoordinatorKey key = CoordinatorKey.byGroupId(groupId1); + assertEquals(emptySet(), result.failedKeys.keySet()); + assertEquals(emptyList(), result.unmappedKeys); + assertEquals(singleton(key), result.completedKeys.keySet()); + } + + private void assertFailed( + Class expectedExceptionType, + AdminApiHandler.ApiResult result + ) { + CoordinatorKey key = CoordinatorKey.byGroupId(groupId1); + assertEquals(emptySet(), result.completedKeys.keySet()); + assertEquals(emptyList(), result.unmappedKeys); + assertEquals(singleton(key), result.failedKeys.keySet()); + assertTrue(expectedExceptionType.isInstance(result.failedKeys.get(key))); + } +} diff --git a/clients/src/test/java/org/apache/kafka/clients/admin/internals/DescribeConsumerGroupsHandlerTest.java b/clients/src/test/java/org/apache/kafka/clients/admin/internals/DescribeConsumerGroupsHandlerTest.java new file mode 100644 index 0000000..aef207a --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/clients/admin/internals/DescribeConsumerGroupsHandlerTest.java @@ -0,0 +1,192 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.clients.admin.internals; + +import static java.util.Collections.emptyList; +import static java.util.Collections.emptySet; +import static java.util.Collections.singleton; +import static java.util.Collections.singletonList; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.HashSet; +import java.util.Set; +import java.util.stream.Collectors; + +import org.apache.kafka.clients.admin.ConsumerGroupDescription; +import org.apache.kafka.clients.admin.MemberAssignment; +import org.apache.kafka.clients.admin.MemberDescription; +import org.apache.kafka.clients.consumer.ConsumerPartitionAssignor.Assignment; +import org.apache.kafka.clients.consumer.internals.ConsumerProtocol; +import org.apache.kafka.common.ConsumerGroupState; +import org.apache.kafka.common.Node; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.errors.GroupAuthorizationException; +import org.apache.kafka.common.errors.GroupIdNotFoundException; +import org.apache.kafka.common.errors.InvalidGroupIdException; +import org.apache.kafka.common.message.DescribeGroupsResponseData; +import org.apache.kafka.common.message.DescribeGroupsResponseData.DescribedGroup; +import org.apache.kafka.common.message.DescribeGroupsResponseData.DescribedGroupMember; +import org.apache.kafka.common.protocol.Errors; +import org.apache.kafka.common.requests.DescribeGroupsRequest; +import org.apache.kafka.common.requests.DescribeGroupsResponse; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.common.utils.Utils; +import org.junit.jupiter.api.Test; + +public class DescribeConsumerGroupsHandlerTest { + + private final LogContext logContext = new LogContext(); + private final String groupId1 = "group-id1"; + private final String groupId2 = "group-id2"; + private final Set groupIds = new HashSet<>(Arrays.asList(groupId1, groupId2)); + private final Set keys = groupIds.stream() + .map(CoordinatorKey::byGroupId) + .collect(Collectors.toSet()); + private final Node coordinator = new Node(1, "host", 1234); + private final Set tps = new HashSet<>(Arrays.asList( + new TopicPartition("foo", 0), new TopicPartition("bar", 1))); + + @Test + public void testBuildRequest() { + DescribeConsumerGroupsHandler handler = new DescribeConsumerGroupsHandler(false, logContext); + DescribeGroupsRequest request = handler.buildRequest(1, keys).build(); + assertEquals(2, request.data().groups().size()); + assertFalse(request.data().includeAuthorizedOperations()); + + handler = new DescribeConsumerGroupsHandler(true, logContext); + request = handler.buildRequest(1, keys).build(); + assertEquals(2, request.data().groups().size()); + assertTrue(request.data().includeAuthorizedOperations()); + } + + @Test + public void testInvalidBuildRequest() { + DescribeConsumerGroupsHandler handler = new DescribeConsumerGroupsHandler(false, logContext); + assertThrows(IllegalArgumentException.class, () -> handler.buildRequest(1, singleton(CoordinatorKey.byTransactionalId("tId")))); + } + + @Test + public void testSuccessfulHandleResponse() { + Collection members = singletonList(new MemberDescription( + "memberId", + "clientId", + "host", + new MemberAssignment(tps))); + ConsumerGroupDescription expected = new 
ConsumerGroupDescription( + groupId1, + true, + members, + "assignor", + ConsumerGroupState.STABLE, + coordinator); + assertCompleted(handleWithError(Errors.NONE, ""), expected); + } + + @Test + public void testUnmappedHandleResponse() { + assertUnmapped(handleWithError(Errors.COORDINATOR_NOT_AVAILABLE, "")); + assertUnmapped(handleWithError(Errors.NOT_COORDINATOR, "")); + } + + @Test + public void testRetriableHandleResponse() { + assertRetriable(handleWithError(Errors.COORDINATOR_LOAD_IN_PROGRESS, "")); + } + + @Test + public void testFailedHandleResponse() { + assertFailed(GroupAuthorizationException.class, handleWithError(Errors.GROUP_AUTHORIZATION_FAILED, "")); + assertFailed(GroupIdNotFoundException.class, handleWithError(Errors.GROUP_ID_NOT_FOUND, "")); + assertFailed(InvalidGroupIdException.class, handleWithError(Errors.INVALID_GROUP_ID, "")); + assertFailed(IllegalArgumentException.class, handleWithError(Errors.NONE, "custom-protocol")); + } + + private DescribeGroupsResponse buildResponse(Errors error, String protocolType) { + DescribeGroupsResponse response = new DescribeGroupsResponse( + new DescribeGroupsResponseData() + .setGroups(singletonList( + new DescribedGroup() + .setErrorCode(error.code()) + .setGroupId(groupId1) + .setGroupState(ConsumerGroupState.STABLE.toString()) + .setProtocolType(protocolType) + .setProtocolData("assignor") + .setAuthorizedOperations(Utils.to32BitField(emptySet())) + .setMembers(singletonList( + new DescribedGroupMember() + .setClientHost("host") + .setClientId("clientId") + .setMemberId("memberId") + .setMemberAssignment(ConsumerProtocol.serializeAssignment( + new Assignment(new ArrayList<>(tps))).array()) + ))))); + return response; + } + + private AdminApiHandler.ApiResult handleWithError( + Errors error, + String protocolType + ) { + DescribeConsumerGroupsHandler handler = new DescribeConsumerGroupsHandler(true, logContext); + DescribeGroupsResponse response = buildResponse(error, protocolType); + return handler.handleResponse(coordinator, singleton(CoordinatorKey.byGroupId(groupId1)), response); + } + + private void assertUnmapped( + AdminApiHandler.ApiResult result + ) { + assertEquals(emptySet(), result.completedKeys.keySet()); + assertEquals(emptySet(), result.failedKeys.keySet()); + assertEquals(singletonList(CoordinatorKey.byGroupId(groupId1)), result.unmappedKeys); + } + + private void assertRetriable( + AdminApiHandler.ApiResult result + ) { + assertEquals(emptySet(), result.completedKeys.keySet()); + assertEquals(emptySet(), result.failedKeys.keySet()); + assertEquals(emptyList(), result.unmappedKeys); + } + + private void assertCompleted( + AdminApiHandler.ApiResult result, + ConsumerGroupDescription expected + ) { + CoordinatorKey key = CoordinatorKey.byGroupId(groupId1); + assertEquals(emptySet(), result.failedKeys.keySet()); + assertEquals(emptyList(), result.unmappedKeys); + assertEquals(singleton(key), result.completedKeys.keySet()); + assertEquals(expected, result.completedKeys.get(CoordinatorKey.byGroupId(groupId1))); + } + + private void assertFailed( + Class expectedExceptionType, + AdminApiHandler.ApiResult result + ) { + CoordinatorKey key = CoordinatorKey.byGroupId(groupId1); + assertEquals(emptySet(), result.completedKeys.keySet()); + assertEquals(emptyList(), result.unmappedKeys); + assertEquals(singleton(key), result.failedKeys.keySet()); + assertTrue(expectedExceptionType.isInstance(result.failedKeys.get(key))); + } +} diff --git 
a/clients/src/test/java/org/apache/kafka/clients/admin/internals/DescribeProducersHandlerTest.java b/clients/src/test/java/org/apache/kafka/clients/admin/internals/DescribeProducersHandlerTest.java new file mode 100644 index 0000000..8daed06 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/clients/admin/internals/DescribeProducersHandlerTest.java @@ -0,0 +1,330 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients.admin.internals; + +import org.apache.kafka.clients.admin.DescribeProducersOptions; +import org.apache.kafka.clients.admin.DescribeProducersResult.PartitionProducerState; +import org.apache.kafka.clients.admin.internals.AdminApiHandler.ApiResult; +import org.apache.kafka.common.Node; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.errors.InvalidTopicException; +import org.apache.kafka.common.errors.NotLeaderOrFollowerException; +import org.apache.kafka.common.errors.TopicAuthorizationException; +import org.apache.kafka.common.errors.UnknownServerException; +import org.apache.kafka.common.message.DescribeProducersRequestData; +import org.apache.kafka.common.message.DescribeProducersResponseData; +import org.apache.kafka.common.message.DescribeProducersResponseData.PartitionResponse; +import org.apache.kafka.common.message.DescribeProducersResponseData.ProducerState; +import org.apache.kafka.common.message.DescribeProducersResponseData.TopicResponse; +import org.apache.kafka.common.protocol.Errors; +import org.apache.kafka.common.requests.DescribeProducersRequest; +import org.apache.kafka.common.requests.DescribeProducersResponse; +import org.apache.kafka.common.utils.CollectionUtils; +import org.apache.kafka.common.utils.LogContext; +import org.junit.jupiter.api.Test; + +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.OptionalInt; +import java.util.Set; +import java.util.function.Function; +import java.util.stream.Collectors; + +import static java.util.Arrays.asList; +import static java.util.Collections.emptyList; +import static java.util.Collections.emptyMap; +import static java.util.Collections.singletonList; +import static java.util.Collections.singletonMap; +import static org.apache.kafka.common.utils.Utils.mkSet; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class DescribeProducersHandlerTest { + private DescribeProducersHandler newHandler( + DescribeProducersOptions options + ) { + return new DescribeProducersHandler( + options, + new LogContext() + ); + } + + + @Test + public void testBrokerIdSetInOptions() { + int brokerId = 3; + Set topicPartitions = mkSet( + new 
TopicPartition("foo", 5), + new TopicPartition("bar", 3), + new TopicPartition("foo", 4) + ); + + DescribeProducersHandler handler = newHandler( + new DescribeProducersOptions().brokerId(brokerId) + ); + + topicPartitions.forEach(topicPartition -> { + ApiRequestScope scope = handler.lookupStrategy().lookupScope(topicPartition); + assertEquals(OptionalInt.of(brokerId), scope.destinationBrokerId(), + "Unexpected brokerId for " + topicPartition); + }); + } + + @Test + public void testBrokerIdNotSetInOptions() { + Set topicPartitions = mkSet( + new TopicPartition("foo", 5), + new TopicPartition("bar", 3), + new TopicPartition("foo", 4) + ); + + DescribeProducersHandler handler = newHandler( + new DescribeProducersOptions() + ); + + topicPartitions.forEach(topicPartition -> { + ApiRequestScope scope = handler.lookupStrategy().lookupScope(topicPartition); + assertEquals(OptionalInt.empty(), scope.destinationBrokerId(), + "Unexpected brokerId for " + topicPartition); + }); + } + + @Test + public void testBuildRequest() { + Set topicPartitions = mkSet( + new TopicPartition("foo", 5), + new TopicPartition("bar", 3), + new TopicPartition("foo", 4) + ); + + DescribeProducersHandler handler = newHandler( + new DescribeProducersOptions() + ); + + int brokerId = 3; + DescribeProducersRequest.Builder request = handler.buildRequest(brokerId, topicPartitions); + + List topics = request.data.topics(); + + assertEquals(mkSet("foo", "bar"), topics.stream() + .map(DescribeProducersRequestData.TopicRequest::name) + .collect(Collectors.toSet())); + + topics.forEach(topic -> { + Set expectedTopicPartitions = "foo".equals(topic.name()) ? + mkSet(4, 5) : mkSet(3); + assertEquals(expectedTopicPartitions, new HashSet<>(topic.partitionIndexes())); + }); + } + + @Test + public void testAuthorizationFailure() { + TopicPartition topicPartition = new TopicPartition("foo", 5); + Throwable exception = assertFatalError(topicPartition, Errors.TOPIC_AUTHORIZATION_FAILED); + assertTrue(exception instanceof TopicAuthorizationException); + TopicAuthorizationException authException = (TopicAuthorizationException) exception; + assertEquals(mkSet("foo"), authException.unauthorizedTopics()); + } + + @Test + public void testInvalidTopic() { + TopicPartition topicPartition = new TopicPartition("foo", 5); + Throwable exception = assertFatalError(topicPartition, Errors.INVALID_TOPIC_EXCEPTION); + assertTrue(exception instanceof InvalidTopicException); + InvalidTopicException invalidTopicException = (InvalidTopicException) exception; + assertEquals(mkSet("foo"), invalidTopicException.invalidTopics()); + } + + @Test + public void testUnexpectedError() { + TopicPartition topicPartition = new TopicPartition("foo", 5); + Throwable exception = assertFatalError(topicPartition, Errors.UNKNOWN_SERVER_ERROR); + assertTrue(exception instanceof UnknownServerException); + } + + @Test + public void testRetriableErrors() { + TopicPartition topicPartition = new TopicPartition("foo", 5); + assertRetriableError(topicPartition, Errors.UNKNOWN_TOPIC_OR_PARTITION); + } + + @Test + public void testUnmappedAfterNotLeaderError() { + TopicPartition topicPartition = new TopicPartition("foo", 5); + ApiResult result = + handleResponseWithError(new DescribeProducersOptions(), topicPartition, Errors.NOT_LEADER_OR_FOLLOWER); + assertEquals(emptyMap(), result.failedKeys); + assertEquals(emptyMap(), result.completedKeys); + assertEquals(singletonList(topicPartition), result.unmappedKeys); + } + + @Test + public void testFatalNotLeaderErrorIfStaticMapped() { + 
TopicPartition topicPartition = new TopicPartition("foo", 5); + DescribeProducersOptions options = new DescribeProducersOptions().brokerId(1); + + ApiResult result = + handleResponseWithError(options, topicPartition, Errors.NOT_LEADER_OR_FOLLOWER); + assertEquals(emptyMap(), result.completedKeys); + assertEquals(emptyList(), result.unmappedKeys); + assertEquals(mkSet(topicPartition), result.failedKeys.keySet()); + Throwable exception = result.failedKeys.get(topicPartition); + assertTrue(exception instanceof NotLeaderOrFollowerException); + } + + @Test + public void testCompletedResult() { + TopicPartition topicPartition = new TopicPartition("foo", 5); + DescribeProducersOptions options = new DescribeProducersOptions().brokerId(1); + DescribeProducersHandler handler = newHandler(options); + + PartitionResponse partitionResponse = sampleProducerState(topicPartition); + DescribeProducersResponse response = describeProducersResponse( + singletonMap(topicPartition, partitionResponse) + ); + Node node = new Node(3, "host", 1); + + ApiResult result = + handler.handleResponse(node, mkSet(topicPartition), response); + + assertEquals(mkSet(topicPartition), result.completedKeys.keySet()); + assertEquals(emptyMap(), result.failedKeys); + assertEquals(emptyList(), result.unmappedKeys); + + PartitionProducerState producerState = result.completedKeys.get(topicPartition); + assertMatchingProducers(partitionResponse, producerState); + } + + private void assertRetriableError( + TopicPartition topicPartition, + Errors error + ) { + ApiResult result = + handleResponseWithError(new DescribeProducersOptions(), topicPartition, error); + assertEquals(emptyMap(), result.failedKeys); + assertEquals(emptyMap(), result.completedKeys); + assertEquals(emptyList(), result.unmappedKeys); + } + + private Throwable assertFatalError( + TopicPartition topicPartition, + Errors error + ) { + ApiResult result = handleResponseWithError( + new DescribeProducersOptions(), topicPartition, error); + assertEquals(emptyMap(), result.completedKeys); + assertEquals(emptyList(), result.unmappedKeys); + assertEquals(mkSet(topicPartition), result.failedKeys.keySet()); + return result.failedKeys.get(topicPartition); + } + + private ApiResult handleResponseWithError( + DescribeProducersOptions options, + TopicPartition topicPartition, + Errors error + ) { + DescribeProducersHandler handler = newHandler(options); + DescribeProducersResponse response = buildResponseWithError(topicPartition, error); + Node node = new Node(options.brokerId().orElse(3), "host", 1); + return handler.handleResponse(node, mkSet(topicPartition), response); + } + + private DescribeProducersResponse buildResponseWithError( + TopicPartition topicPartition, + Errors error + ) { + PartitionResponse partitionResponse = new PartitionResponse() + .setPartitionIndex(topicPartition.partition()) + .setErrorCode(error.code()); + return describeProducersResponse(singletonMap(topicPartition, partitionResponse)); + } + + private PartitionResponse sampleProducerState(TopicPartition topicPartition) { + PartitionResponse partitionResponse = new PartitionResponse() + .setPartitionIndex(topicPartition.partition()) + .setErrorCode(Errors.NONE.code()); + + partitionResponse.setActiveProducers(asList( + new ProducerState() + .setProducerId(12345L) + .setProducerEpoch(15) + .setLastSequence(75) + .setLastTimestamp(System.currentTimeMillis()) + .setCurrentTxnStartOffset(-1L), + new ProducerState() + .setProducerId(98765L) + .setProducerEpoch(30) + .setLastSequence(150) + 
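/* Second sample producer: an older timestamp and an open transaction (currentTxnStartOffset is set below). */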
.setLastTimestamp(System.currentTimeMillis() - 5000) + .setCurrentTxnStartOffset(5000) + )); + + return partitionResponse; + } + + private void assertMatchingProducers( + PartitionResponse expected, + PartitionProducerState actual + ) { + List expectedProducers = expected.activeProducers(); + List actualProducers = actual.activeProducers(); + + assertEquals(expectedProducers.size(), actualProducers.size()); + + Map expectedByProducerId = expectedProducers.stream().collect(Collectors.toMap( + ProducerState::producerId, + Function.identity() + )); + + for (org.apache.kafka.clients.admin.ProducerState actualProducerState : actualProducers) { + ProducerState expectedProducerState = expectedByProducerId.get(actualProducerState.producerId()); + assertNotNull(expectedProducerState); + assertEquals(expectedProducerState.producerEpoch(), actualProducerState.producerEpoch()); + assertEquals(expectedProducerState.lastSequence(), actualProducerState.lastSequence()); + assertEquals(expectedProducerState.lastTimestamp(), actualProducerState.lastTimestamp()); + assertEquals(expectedProducerState.currentTxnStartOffset(), + actualProducerState.currentTransactionStartOffset().orElse(-1L)); + } + } + + private DescribeProducersResponse describeProducersResponse( + Map partitionResponses + ) { + DescribeProducersResponseData response = new DescribeProducersResponseData(); + Map> partitionResponsesByTopic = + CollectionUtils.groupPartitionDataByTopic(partitionResponses); + + for (Map.Entry> topicEntry : partitionResponsesByTopic.entrySet()) { + String topic = topicEntry.getKey(); + Map topicPartitionResponses = topicEntry.getValue(); + + TopicResponse topicResponse = new TopicResponse().setName(topic); + response.topics().add(topicResponse); + + for (Map.Entry partitionEntry : topicPartitionResponses.entrySet()) { + Integer partitionId = partitionEntry.getKey(); + PartitionResponse partitionResponse = partitionEntry.getValue(); + topicResponse.partitions().add(partitionResponse.setPartitionIndex(partitionId)); + } + } + + return new DescribeProducersResponse(response); + } + +} diff --git a/clients/src/test/java/org/apache/kafka/clients/admin/internals/DescribeTransactionsHandlerTest.java b/clients/src/test/java/org/apache/kafka/clients/admin/internals/DescribeTransactionsHandlerTest.java new file mode 100644 index 0000000..04eac89 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/clients/admin/internals/DescribeTransactionsHandlerTest.java @@ -0,0 +1,227 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.clients.admin.internals; + +import org.apache.kafka.clients.admin.TransactionDescription; +import org.apache.kafka.clients.admin.internals.AdminApiHandler.ApiResult; +import org.apache.kafka.common.Node; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.message.DescribeTransactionsResponseData; +import org.apache.kafka.common.protocol.Errors; +import org.apache.kafka.common.requests.DescribeTransactionsRequest; +import org.apache.kafka.common.requests.DescribeTransactionsResponse; +import org.apache.kafka.common.utils.LogContext; +import org.junit.jupiter.api.Test; + +import java.util.HashSet; +import java.util.Set; +import java.util.stream.Collectors; + +import static java.util.Arrays.asList; +import static java.util.Collections.emptyList; +import static java.util.Collections.emptyMap; +import static java.util.Collections.singletonList; +import static org.apache.kafka.common.utils.Utils.mkSet; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class DescribeTransactionsHandlerTest { + private final LogContext logContext = new LogContext(); + private final Node node = new Node(1, "host", 1234); + + @Test + public void testBuildRequest() { + String transactionalId1 = "foo"; + String transactionalId2 = "bar"; + String transactionalId3 = "baz"; + + Set transactionalIds = mkSet(transactionalId1, transactionalId2, transactionalId3); + DescribeTransactionsHandler handler = new DescribeTransactionsHandler(logContext); + + assertLookup(handler, transactionalIds); + assertLookup(handler, mkSet(transactionalId1)); + assertLookup(handler, mkSet(transactionalId2, transactionalId3)); + } + + @Test + public void testHandleSuccessfulResponse() { + String transactionalId1 = "foo"; + String transactionalId2 = "bar"; + + Set transactionalIds = mkSet(transactionalId1, transactionalId2); + DescribeTransactionsHandler handler = new DescribeTransactionsHandler(logContext); + + DescribeTransactionsResponseData.TransactionState transactionState1 = + sampleTransactionState1(transactionalId1); + DescribeTransactionsResponseData.TransactionState transactionState2 = + sampleTransactionState2(transactionalId2); + + Set keys = coordinatorKeys(transactionalIds); + DescribeTransactionsResponse response = new DescribeTransactionsResponse(new DescribeTransactionsResponseData() + .setTransactionStates(asList(transactionState1, transactionState2))); + + ApiResult result = handler.handleResponse( + node, keys, response); + + assertEquals(keys, result.completedKeys.keySet()); + assertMatchingTransactionState(node.id(), transactionState1, + result.completedKeys.get(CoordinatorKey.byTransactionalId(transactionalId1))); + assertMatchingTransactionState(node.id(), transactionState2, + result.completedKeys.get(CoordinatorKey.byTransactionalId(transactionalId2))); + } + + @Test + public void testHandleErrorResponse() { + String transactionalId = "foo"; + DescribeTransactionsHandler handler = new DescribeTransactionsHandler(logContext); + assertFatalError(handler, transactionalId, Errors.TRANSACTIONAL_ID_AUTHORIZATION_FAILED); + assertFatalError(handler, transactionalId, Errors.TRANSACTIONAL_ID_NOT_FOUND); + assertFatalError(handler, transactionalId, Errors.UNKNOWN_SERVER_ERROR); + assertRetriableError(handler, transactionalId, Errors.COORDINATOR_LOAD_IN_PROGRESS); + assertUnmappedKey(handler, transactionalId, Errors.NOT_COORDINATOR); + assertUnmappedKey(handler, transactionalId, 
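/* Like NOT_COORDINATOR above, this error should leave the key unmapped rather than failed. */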
Errors.COORDINATOR_NOT_AVAILABLE); + } + + private void assertFatalError( + DescribeTransactionsHandler handler, + String transactionalId, + Errors error + ) { + CoordinatorKey key = CoordinatorKey.byTransactionalId(transactionalId); + ApiResult result = handleResponseError(handler, transactionalId, error); + assertEquals(emptyList(), result.unmappedKeys); + assertEquals(mkSet(key), result.failedKeys.keySet()); + + Throwable throwable = result.failedKeys.get(key); + assertTrue(error.exception().getClass().isInstance(throwable)); + } + + private void assertRetriableError( + DescribeTransactionsHandler handler, + String transactionalId, + Errors error + ) { + ApiResult result = handleResponseError(handler, transactionalId, error); + assertEquals(emptyList(), result.unmappedKeys); + assertEquals(emptyMap(), result.failedKeys); + } + + private void assertUnmappedKey( + DescribeTransactionsHandler handler, + String transactionalId, + Errors error + ) { + CoordinatorKey key = CoordinatorKey.byTransactionalId(transactionalId); + ApiResult result = handleResponseError(handler, transactionalId, error); + assertEquals(emptyMap(), result.failedKeys); + assertEquals(singletonList(key), result.unmappedKeys); + } + + private ApiResult handleResponseError( + DescribeTransactionsHandler handler, + String transactionalId, + Errors error + ) { + CoordinatorKey key = CoordinatorKey.byTransactionalId(transactionalId); + Set keys = mkSet(key); + + DescribeTransactionsResponseData.TransactionState transactionState = new DescribeTransactionsResponseData.TransactionState() + .setErrorCode(error.code()) + .setTransactionalId(transactionalId); + + DescribeTransactionsResponse response = new DescribeTransactionsResponse(new DescribeTransactionsResponseData() + .setTransactionStates(singletonList(transactionState))); + + ApiResult result = handler.handleResponse(node, keys, response); + assertEquals(emptyMap(), result.completedKeys); + return result; + } + + private void assertLookup( + DescribeTransactionsHandler handler, + Set transactionalIds + ) { + Set keys = coordinatorKeys(transactionalIds); + DescribeTransactionsRequest.Builder request = handler.buildRequest(1, keys); + assertEquals(transactionalIds, new HashSet<>(request.data.transactionalIds())); + } + + private static Set coordinatorKeys(Set transactionalIds) { + return transactionalIds.stream() + .map(CoordinatorKey::byTransactionalId) + .collect(Collectors.toSet()); + } + + private DescribeTransactionsResponseData.TransactionState sampleTransactionState1( + String transactionalId + ) { + return new DescribeTransactionsResponseData.TransactionState() + .setErrorCode(Errors.NONE.code()) + .setTransactionState("Ongoing") + .setTransactionalId(transactionalId) + .setProducerId(12345L) + .setProducerEpoch((short) 15) + .setTransactionStartTimeMs(1599151791L) + .setTransactionTimeoutMs(10000) + .setTopics(new DescribeTransactionsResponseData.TopicDataCollection(asList( + new DescribeTransactionsResponseData.TopicData() + .setTopic("foo") + .setPartitions(asList(1, 3, 5)), + new DescribeTransactionsResponseData.TopicData() + .setTopic("bar") + .setPartitions(asList(1, 3, 5)) + ).iterator())); + } + + private DescribeTransactionsResponseData.TransactionState sampleTransactionState2( + String transactionalId + ) { + return new DescribeTransactionsResponseData.TransactionState() + .setErrorCode(Errors.NONE.code()) + .setTransactionState("Empty") + .setTransactionalId(transactionalId) + .setProducerId(98765L) + .setProducerEpoch((short) 30) + 
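/* The 'Empty' transaction has no start time; assertMatchingTransactionState compares it via orElse(-1). */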
.setTransactionStartTimeMs(-1); + } + + private void assertMatchingTransactionState( + int expectedCoordinatorId, + DescribeTransactionsResponseData.TransactionState expected, + TransactionDescription actual + ) { + assertEquals(expectedCoordinatorId, actual.coordinatorId()); + assertEquals(expected.producerId(), actual.producerId()); + assertEquals(expected.producerEpoch(), actual.producerEpoch()); + assertEquals(expected.transactionTimeoutMs(), actual.transactionTimeoutMs()); + assertEquals(expected.transactionStartTimeMs(), actual.transactionStartTimeMs().orElse(-1)); + assertEquals(collectTransactionPartitions(expected), actual.topicPartitions()); + } + + private Set collectTransactionPartitions( + DescribeTransactionsResponseData.TransactionState transactionState + ) { + Set topicPartitions = new HashSet<>(); + for (DescribeTransactionsResponseData.TopicData topicData : transactionState.topics()) { + for (Integer partitionId : topicData.partitions()) { + topicPartitions.add(new TopicPartition(topicData.topic(), partitionId)); + } + } + return topicPartitions; + } + +} diff --git a/clients/src/test/java/org/apache/kafka/clients/admin/internals/ListConsumerGroupOffsetsHandlerTest.java b/clients/src/test/java/org/apache/kafka/clients/admin/internals/ListConsumerGroupOffsetsHandlerTest.java new file mode 100644 index 0000000..9c9bb1e --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/clients/admin/internals/ListConsumerGroupOffsetsHandlerTest.java @@ -0,0 +1,170 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.clients.admin.internals; + +import static java.util.Collections.emptyList; +import static java.util.Collections.emptySet; +import static java.util.Collections.singleton; +import static java.util.Collections.singletonList; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +import org.apache.kafka.clients.consumer.OffsetAndMetadata; +import org.apache.kafka.common.Node; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.errors.GroupAuthorizationException; +import org.apache.kafka.common.errors.GroupIdNotFoundException; +import org.apache.kafka.common.errors.InvalidGroupIdException; +import org.apache.kafka.common.protocol.Errors; +import org.apache.kafka.common.requests.OffsetFetchRequest; +import org.apache.kafka.common.requests.OffsetFetchResponse; +import org.apache.kafka.common.requests.OffsetFetchResponse.PartitionData; +import org.apache.kafka.common.utils.LogContext; +import org.junit.jupiter.api.Test; + +public class ListConsumerGroupOffsetsHandlerTest { + + private final LogContext logContext = new LogContext(); + private final String groupId = "group-id"; + private final TopicPartition t0p0 = new TopicPartition("t0", 0); + private final TopicPartition t0p1 = new TopicPartition("t0", 1); + private final TopicPartition t1p0 = new TopicPartition("t1", 0); + private final TopicPartition t1p1 = new TopicPartition("t1", 1); + private final List tps = Arrays.asList(t0p0, t0p1, t1p0, t1p1); + + @Test + public void testBuildRequest() { + ListConsumerGroupOffsetsHandler handler = new ListConsumerGroupOffsetsHandler(groupId, tps, logContext); + OffsetFetchRequest request = handler.buildRequest(1, singleton(CoordinatorKey.byGroupId(groupId))).build(); + assertEquals(groupId, request.data().groups().get(0).groupId()); + assertEquals(2, request.data().groups().get(0).topics().size()); + assertEquals(2, request.data().groups().get(0).topics().get(0).partitionIndexes().size()); + assertEquals(2, request.data().groups().get(0).topics().get(1).partitionIndexes().size()); + } + + @Test + public void testSuccessfulHandleResponse() { + Map expected = new HashMap<>(); + assertCompleted(handleWithError(Errors.NONE), expected); + } + + + @Test + public void testSuccessfulHandleResponseWithOnePartitionError() { + Map expectedResult = Collections.singletonMap(t0p0, new OffsetAndMetadata(10L)); + + // Only one partition result is expected because the partition that returned an error is skipped + assertCompleted(handleWithPartitionError(Errors.UNKNOWN_TOPIC_OR_PARTITION), expectedResult); + assertCompleted(handleWithPartitionError(Errors.TOPIC_AUTHORIZATION_FAILED), expectedResult); + assertCompleted(handleWithPartitionError(Errors.UNSTABLE_OFFSET_COMMIT), expectedResult); + } + + @Test + public void testUnmappedHandleResponse() { + assertUnmapped(handleWithError(Errors.COORDINATOR_NOT_AVAILABLE)); + assertUnmapped(handleWithError(Errors.NOT_COORDINATOR)); + } + + @Test + public void testRetriableHandleResponse() { + assertRetriable(handleWithError(Errors.COORDINATOR_LOAD_IN_PROGRESS)); + } + + @Test + public void testFailedHandleResponse() { + assertFailed(GroupAuthorizationException.class, handleWithError(Errors.GROUP_AUTHORIZATION_FAILED)); + assertFailed(GroupIdNotFoundException.class, 
handleWithError(Errors.GROUP_ID_NOT_FOUND)); + assertFailed(InvalidGroupIdException.class, handleWithError(Errors.INVALID_GROUP_ID)); + } + + private OffsetFetchResponse buildResponse(Errors error) { + Map responseData = new HashMap<>(); + OffsetFetchResponse response = new OffsetFetchResponse(error, responseData); + return response; + } + + private OffsetFetchResponse buildResponseWithPartitionError(Errors error) { + + Map responseData = new HashMap<>(); + responseData.put(t0p0, new OffsetFetchResponse.PartitionData(10, Optional.empty(), "", Errors.NONE)); + responseData.put(t0p1, new OffsetFetchResponse.PartitionData(10, Optional.empty(), "", error)); + + OffsetFetchResponse response = new OffsetFetchResponse(Errors.NONE, responseData); + return response; + } + + private AdminApiHandler.ApiResult> handleWithPartitionError( + Errors error + ) { + ListConsumerGroupOffsetsHandler handler = new ListConsumerGroupOffsetsHandler(groupId, tps, logContext); + OffsetFetchResponse response = buildResponseWithPartitionError(error); + return handler.handleResponse(new Node(1, "host", 1234), singleton(CoordinatorKey.byGroupId(groupId)), response); + } + + private AdminApiHandler.ApiResult> handleWithError( + Errors error + ) { + ListConsumerGroupOffsetsHandler handler = new ListConsumerGroupOffsetsHandler(groupId, tps, logContext); + OffsetFetchResponse response = buildResponse(error); + return handler.handleResponse(new Node(1, "host", 1234), singleton(CoordinatorKey.byGroupId(groupId)), response); + } + + private void assertUnmapped( + AdminApiHandler.ApiResult> result + ) { + assertEquals(emptySet(), result.completedKeys.keySet()); + assertEquals(emptySet(), result.failedKeys.keySet()); + assertEquals(singletonList(CoordinatorKey.byGroupId(groupId)), result.unmappedKeys); + } + + private void assertRetriable( + AdminApiHandler.ApiResult> result + ) { + assertEquals(emptySet(), result.completedKeys.keySet()); + assertEquals(emptySet(), result.failedKeys.keySet()); + assertEquals(emptyList(), result.unmappedKeys); + } + + private void assertCompleted( + AdminApiHandler.ApiResult> result, + Map expected + ) { + CoordinatorKey key = CoordinatorKey.byGroupId(groupId); + assertEquals(emptySet(), result.failedKeys.keySet()); + assertEquals(emptyList(), result.unmappedKeys); + assertEquals(singleton(key), result.completedKeys.keySet()); + assertEquals(expected, result.completedKeys.get(CoordinatorKey.byGroupId(groupId))); + } + + private void assertFailed( + Class expectedExceptionType, + AdminApiHandler.ApiResult> result + ) { + CoordinatorKey key = CoordinatorKey.byGroupId(groupId); + assertEquals(emptySet(), result.completedKeys.keySet()); + assertEquals(emptyList(), result.unmappedKeys); + assertEquals(singleton(key), result.failedKeys.keySet()); + assertTrue(expectedExceptionType.isInstance(result.failedKeys.get(key))); + } +} diff --git a/clients/src/test/java/org/apache/kafka/clients/admin/internals/ListTransactionsHandlerTest.java b/clients/src/test/java/org/apache/kafka/clients/admin/internals/ListTransactionsHandlerTest.java new file mode 100644 index 0000000..a8923d1 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/clients/admin/internals/ListTransactionsHandlerTest.java @@ -0,0 +1,187 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients.admin.internals; + +import org.apache.kafka.clients.admin.ListTransactionsOptions; +import org.apache.kafka.clients.admin.TransactionListing; +import org.apache.kafka.clients.admin.TransactionState; +import org.apache.kafka.clients.admin.internals.AdminApiHandler.ApiResult; +import org.apache.kafka.clients.admin.internals.AllBrokersStrategy.BrokerKey; +import org.apache.kafka.common.Node; +import org.apache.kafka.common.message.ListTransactionsResponseData; +import org.apache.kafka.common.protocol.Errors; +import org.apache.kafka.common.requests.ListTransactionsRequest; +import org.apache.kafka.common.requests.ListTransactionsResponse; +import org.apache.kafka.common.utils.LogContext; +import org.junit.jupiter.api.Test; + +import java.util.Collection; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.OptionalInt; +import java.util.function.Function; +import java.util.stream.Collectors; + +import static java.util.Arrays.asList; +import static java.util.Collections.singleton; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; + +public class ListTransactionsHandlerTest { + private final LogContext logContext = new LogContext(); + private final Node node = new Node(1, "host", 1234); + + @Test + public void testBuildRequestWithoutFilters() { + int brokerId = 1; + BrokerKey brokerKey = new BrokerKey(OptionalInt.of(brokerId)); + ListTransactionsOptions options = new ListTransactionsOptions(); + ListTransactionsHandler handler = new ListTransactionsHandler(options, logContext); + ListTransactionsRequest request = handler.buildRequest(brokerId, singleton(brokerKey)).build(); + assertEquals(Collections.emptyList(), request.data().producerIdFilters()); + assertEquals(Collections.emptyList(), request.data().stateFilters()); + } + + @Test + public void testBuildRequestWithFilteredProducerId() { + int brokerId = 1; + BrokerKey brokerKey = new BrokerKey(OptionalInt.of(brokerId)); + long filteredProducerId = 23423L; + ListTransactionsOptions options = new ListTransactionsOptions() + .filterProducerIds(singleton(filteredProducerId)); + ListTransactionsHandler handler = new ListTransactionsHandler(options, logContext); + ListTransactionsRequest request = handler.buildRequest(brokerId, singleton(brokerKey)).build(); + assertEquals(Collections.singletonList(filteredProducerId), request.data().producerIdFilters()); + assertEquals(Collections.emptyList(), request.data().stateFilters()); + } + + @Test + public void testBuildRequestWithFilteredState() { + int brokerId = 1; + BrokerKey brokerKey = new BrokerKey(OptionalInt.of(brokerId)); + TransactionState filteredState = TransactionState.ONGOING; + ListTransactionsOptions options = new ListTransactionsOptions() + .filterStates(singleton(filteredState)); + ListTransactionsHandler handler = new ListTransactionsHandler(options, 
logContext); + ListTransactionsRequest request = handler.buildRequest(brokerId, singleton(brokerKey)).build(); + assertEquals(Collections.singletonList(filteredState.toString()), request.data().stateFilters()); + assertEquals(Collections.emptyList(), request.data().producerIdFilters()); + } + + @Test + public void testHandleSuccessfulResponse() { + int brokerId = 1; + BrokerKey brokerKey = new BrokerKey(OptionalInt.of(brokerId)); + ListTransactionsOptions options = new ListTransactionsOptions(); + ListTransactionsHandler handler = new ListTransactionsHandler(options, logContext); + ListTransactionsResponse response = sampleListTransactionsResponse1(); + ApiResult> result = handler.handleResponse( + node, singleton(brokerKey), response); + assertEquals(singleton(brokerKey), result.completedKeys.keySet()); + assertExpectedTransactions(response.data().transactionStates(), result.completedKeys.get(brokerKey)); + } + + @Test + public void testCoordinatorLoadingErrorIsRetriable() { + int brokerId = 1; + ApiResult> result = + handleResponseWithError(brokerId, Errors.COORDINATOR_LOAD_IN_PROGRESS); + assertEquals(Collections.emptyMap(), result.completedKeys); + assertEquals(Collections.emptyMap(), result.failedKeys); + assertEquals(Collections.emptyList(), result.unmappedKeys); + } + + @Test + public void testHandleResponseWithFatalErrors() { + assertFatalError(Errors.COORDINATOR_NOT_AVAILABLE); + assertFatalError(Errors.UNKNOWN_SERVER_ERROR); + } + + private void assertFatalError( + Errors error + ) { + int brokerId = 1; + BrokerKey brokerKey = new BrokerKey(OptionalInt.of(brokerId)); + ApiResult> result = handleResponseWithError(brokerId, error); + assertEquals(Collections.emptyMap(), result.completedKeys); + assertEquals(Collections.emptyList(), result.unmappedKeys); + assertEquals(Collections.singleton(brokerKey), result.failedKeys.keySet()); + Throwable throwable = result.failedKeys.get(brokerKey); + assertEquals(error, Errors.forException(throwable)); + } + + private ApiResult> handleResponseWithError( + int brokerId, + Errors error + ) { + BrokerKey brokerKey = new BrokerKey(OptionalInt.of(brokerId)); + ListTransactionsOptions options = new ListTransactionsOptions(); + ListTransactionsHandler handler = new ListTransactionsHandler(options, logContext); + + ListTransactionsResponse response = new ListTransactionsResponse( + new ListTransactionsResponseData().setErrorCode(error.code()) + ); + return handler.handleResponse(node, singleton(brokerKey), response); + } + + private ListTransactionsResponse sampleListTransactionsResponse1() { + return new ListTransactionsResponse( + new ListTransactionsResponseData() + .setErrorCode(Errors.NONE.code()) + .setTransactionStates(asList( + new ListTransactionsResponseData.TransactionState() + .setTransactionalId("foo") + .setProducerId(12345L) + .setTransactionState("Ongoing"), + new ListTransactionsResponseData.TransactionState() + .setTransactionalId("bar") + .setProducerId(98765L) + .setTransactionState("PrepareAbort") + )) + ); + } + + private void assertExpectedTransactions( + List expected, + Collection actual + ) { + assertEquals(expected.size(), actual.size()); + + Map expectedMap = expected.stream().collect(Collectors.toMap( + ListTransactionsResponseData.TransactionState::transactionalId, + Function.identity() + )); + + for (TransactionListing actualListing : actual) { + ListTransactionsResponseData.TransactionState expectedState = + expectedMap.get(actualListing.transactionalId()); + assertNotNull(expectedState); + 
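/* Compare the listing field by field against the response entry with the same transactionalId. */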
assertExpectedTransactionState(expectedState, actualListing); + } + } + + private void assertExpectedTransactionState( + ListTransactionsResponseData.TransactionState expected, + TransactionListing actual + ) { + assertEquals(expected.transactionalId(), actual.transactionalId()); + assertEquals(expected.producerId(), actual.producerId()); + assertEquals(expected.transactionState(), actual.state().toString()); + } + +} diff --git a/clients/src/test/java/org/apache/kafka/clients/admin/internals/PartitionLeaderStrategyTest.java b/clients/src/test/java/org/apache/kafka/clients/admin/internals/PartitionLeaderStrategyTest.java new file mode 100644 index 0000000..e600df4 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/clients/admin/internals/PartitionLeaderStrategyTest.java @@ -0,0 +1,302 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients.admin.internals; + +import org.apache.kafka.clients.admin.internals.AdminApiLookupStrategy.LookupResult; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.errors.InvalidTopicException; +import org.apache.kafka.common.errors.TopicAuthorizationException; +import org.apache.kafka.common.errors.UnknownServerException; +import org.apache.kafka.common.message.MetadataResponseData; +import org.apache.kafka.common.message.MetadataResponseData.MetadataResponsePartition; +import org.apache.kafka.common.message.MetadataResponseData.MetadataResponseTopic; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.Errors; +import org.apache.kafka.common.requests.MetadataRequest; +import org.apache.kafka.common.requests.MetadataResponse; +import org.apache.kafka.common.utils.LogContext; +import org.junit.jupiter.api.Test; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; + +import static java.util.Collections.emptyMap; +import static java.util.Collections.singletonMap; +import static org.apache.kafka.common.utils.Utils.mkSet; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class PartitionLeaderStrategyTest { + + private PartitionLeaderStrategy newStrategy() { + return new PartitionLeaderStrategy(new LogContext()); + } + + @Test + public void testBuildLookupRequest() { + Set topicPartitions = mkSet( + new TopicPartition("foo", 0), + new TopicPartition("bar", 0), + new TopicPartition("foo", 1), + new TopicPartition("baz", 0) + ); + + PartitionLeaderStrategy strategy = newStrategy(); + + MetadataRequest allRequest = 
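/* A single metadata lookup request should cover every topic of the requested partitions. */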
strategy.buildRequest(topicPartitions).build(); + assertEquals(mkSet("foo", "bar", "baz"), new HashSet<>(allRequest.topics())); + assertFalse(allRequest.allowAutoTopicCreation()); + + MetadataRequest partialRequest = strategy.buildRequest( + topicPartitions.stream().filter(tp -> tp.topic().equals("foo")).collect(Collectors.toSet()) + ).build(); + assertEquals(mkSet("foo"), new HashSet<>(partialRequest.topics())); + assertFalse(partialRequest.allowAutoTopicCreation()); + } + + @Test + public void testTopicAuthorizationFailure() { + TopicPartition topicPartition = new TopicPartition("foo", 0); + Throwable exception = assertFatalTopicError(topicPartition, Errors.TOPIC_AUTHORIZATION_FAILED); + assertTrue(exception instanceof TopicAuthorizationException); + TopicAuthorizationException authException = (TopicAuthorizationException) exception; + assertEquals(mkSet("foo"), authException.unauthorizedTopics()); + } + + @Test + public void testInvalidTopicError() { + TopicPartition topicPartition = new TopicPartition("foo", 0); + Throwable exception = assertFatalTopicError(topicPartition, Errors.INVALID_TOPIC_EXCEPTION); + assertTrue(exception instanceof InvalidTopicException); + InvalidTopicException invalidTopicException = (InvalidTopicException) exception; + assertEquals(mkSet("foo"), invalidTopicException.invalidTopics()); + } + + @Test + public void testUnexpectedTopicError() { + TopicPartition topicPartition = new TopicPartition("foo", 0); + Throwable exception = assertFatalTopicError(topicPartition, Errors.UNKNOWN_SERVER_ERROR); + assertTrue(exception instanceof UnknownServerException); + } + + @Test + public void testRetriableTopicErrors() { + TopicPartition topicPartition = new TopicPartition("foo", 0); + assertRetriableTopicError(topicPartition, Errors.UNKNOWN_TOPIC_OR_PARTITION); + assertRetriableTopicError(topicPartition, Errors.LEADER_NOT_AVAILABLE); + assertRetriableTopicError(topicPartition, Errors.BROKER_NOT_AVAILABLE); + } + + @Test + public void testRetriablePartitionErrors() { + TopicPartition topicPartition = new TopicPartition("foo", 0); + assertRetriablePartitionError(topicPartition, Errors.NOT_LEADER_OR_FOLLOWER); + assertRetriablePartitionError(topicPartition, Errors.REPLICA_NOT_AVAILABLE); + assertRetriablePartitionError(topicPartition, Errors.LEADER_NOT_AVAILABLE); + assertRetriablePartitionError(topicPartition, Errors.BROKER_NOT_AVAILABLE); + assertRetriablePartitionError(topicPartition, Errors.KAFKA_STORAGE_ERROR); + } + + @Test + public void testUnexpectedPartitionError() { + TopicPartition topicPartition = new TopicPartition("foo", 0); + Throwable exception = assertFatalPartitionError(topicPartition, Errors.UNKNOWN_SERVER_ERROR); + assertTrue(exception instanceof UnknownServerException); + } + + @Test + public void testPartitionSuccessfullyMapped() { + TopicPartition topicPartition1 = new TopicPartition("foo", 0); + TopicPartition topicPartition2 = new TopicPartition("bar", 1); + + Map responsePartitions = new HashMap<>(2); + responsePartitions.put(topicPartition1, partitionResponseDataWithLeader( + topicPartition1, 5, Arrays.asList(5, 6, 7))); + responsePartitions.put(topicPartition2, partitionResponseDataWithLeader( + topicPartition2, 1, Arrays.asList(2, 1, 3))); + + LookupResult result = handleLookupResponse( + mkSet(topicPartition1, topicPartition2), + responseWithPartitionData(responsePartitions) + ); + + assertEquals(emptyMap(), result.failedKeys); + assertEquals(mkSet(topicPartition1, topicPartition2), result.mappedKeys.keySet()); + assertEquals(5, 
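/* 5 is the leader id assigned to topicPartition1 via partitionResponseDataWithLeader. */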
result.mappedKeys.get(topicPartition1)); + assertEquals(1, result.mappedKeys.get(topicPartition2)); + } + + @Test + public void testIgnoreUnrequestedPartitions() { + TopicPartition requestedTopicPartition = new TopicPartition("foo", 0); + TopicPartition unrequestedTopicPartition = new TopicPartition("foo", 1); + + Map responsePartitions = new HashMap<>(2); + responsePartitions.put(requestedTopicPartition, partitionResponseDataWithLeader( + requestedTopicPartition, 5, Arrays.asList(5, 6, 7))); + responsePartitions.put(unrequestedTopicPartition, partitionResponseDataWithError( + unrequestedTopicPartition, Errors.UNKNOWN_SERVER_ERROR)); + + LookupResult result = handleLookupResponse( + mkSet(requestedTopicPartition), + responseWithPartitionData(responsePartitions) + ); + + assertEquals(emptyMap(), result.failedKeys); + assertEquals(mkSet(requestedTopicPartition), result.mappedKeys.keySet()); + assertEquals(5, result.mappedKeys.get(requestedTopicPartition)); + } + + @Test + public void testRetryIfLeaderUnknown() { + TopicPartition topicPartition = new TopicPartition("foo", 0); + + Map responsePartitions = singletonMap( + topicPartition, + partitionResponseDataWithLeader(topicPartition, -1, Arrays.asList(5, 6, 7)) + ); + + LookupResult result = handleLookupResponse( + mkSet(topicPartition), + responseWithPartitionData(responsePartitions) + ); + + assertEquals(emptyMap(), result.failedKeys); + assertEquals(emptyMap(), result.mappedKeys); + } + + private void assertRetriableTopicError( + TopicPartition topicPartition, + Errors error + ) { + assertRetriableError( + topicPartition, + responseWithTopicError(topicPartition.topic(), error) + ); + } + + private void assertRetriablePartitionError( + TopicPartition topicPartition, + Errors error + ) { + MetadataResponse response = responseWithPartitionData(singletonMap( + topicPartition, + partitionResponseDataWithError(topicPartition, error) + )); + assertRetriableError(topicPartition, response); + } + + private Throwable assertFatalTopicError( + TopicPartition topicPartition, + Errors error + ) { + return assertFatalError( + topicPartition, + responseWithTopicError(topicPartition.topic(), error) + ); + } + + private Throwable assertFatalPartitionError( + TopicPartition topicPartition, + Errors error + ) { + MetadataResponse response = responseWithPartitionData(singletonMap( + topicPartition, + partitionResponseDataWithError(topicPartition, error) + )); + return assertFatalError(topicPartition, response); + } + + private void assertRetriableError( + TopicPartition topicPartition, + MetadataResponse response + ) { + LookupResult result = handleLookupResponse(mkSet(topicPartition), response); + assertEquals(emptyMap(), result.failedKeys); + assertEquals(emptyMap(), result.mappedKeys); + } + + private Throwable assertFatalError( + TopicPartition topicPartition, + MetadataResponse response + ) { + LookupResult result = handleLookupResponse(mkSet(topicPartition), response); + assertEquals(mkSet(topicPartition), result.failedKeys.keySet()); + return result.failedKeys.get(topicPartition); + } + + private LookupResult handleLookupResponse( + Set topicPartitions, + MetadataResponse response + ) { + PartitionLeaderStrategy strategy = newStrategy(); + return strategy.handleResponse(topicPartitions, response); + } + + private MetadataResponse responseWithTopicError(String topic, Errors error) { + MetadataResponseTopic responseTopic = new MetadataResponseTopic() + .setName(topic) + .setErrorCode(error.code()); + MetadataResponseData responseData = new 
MetadataResponseData(); + responseData.topics().add(responseTopic); + return new MetadataResponse(responseData, ApiKeys.METADATA.latestVersion()); + } + + private MetadataResponsePartition partitionResponseDataWithError(TopicPartition topicPartition, Errors error) { + return new MetadataResponsePartition() + .setPartitionIndex(topicPartition.partition()) + .setErrorCode(error.code()); + } + + private MetadataResponsePartition partitionResponseDataWithLeader( + TopicPartition topicPartition, + Integer leaderId, + List replicas + ) { + return new MetadataResponsePartition() + .setPartitionIndex(topicPartition.partition()) + .setErrorCode(Errors.NONE.code()) + .setLeaderId(leaderId) + .setReplicaNodes(replicas) + .setIsrNodes(replicas); + } + + private MetadataResponse responseWithPartitionData( + Map responsePartitions + ) { + MetadataResponseData responseData = new MetadataResponseData(); + for (Map.Entry entry : responsePartitions.entrySet()) { + TopicPartition topicPartition = entry.getKey(); + MetadataResponseTopic responseTopic = responseData.topics().find(topicPartition.topic()); + if (responseTopic == null) { + responseTopic = new MetadataResponseTopic() + .setName(topicPartition.topic()) + .setErrorCode(Errors.NONE.code()); + responseData.topics().add(responseTopic); + } + responseTopic.partitions().add(entry.getValue()); + } + return new MetadataResponse(responseData, ApiKeys.METADATA.latestVersion()); + } + +} diff --git a/clients/src/test/java/org/apache/kafka/clients/admin/internals/RemoveMembersFromConsumerGroupHandlerTest.java b/clients/src/test/java/org/apache/kafka/clients/admin/internals/RemoveMembersFromConsumerGroupHandlerTest.java new file mode 100644 index 0000000..6f5dfda --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/clients/admin/internals/RemoveMembersFromConsumerGroupHandlerTest.java @@ -0,0 +1,181 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.clients.admin.internals; + +import static java.util.Collections.emptyList; +import static java.util.Collections.emptySet; +import static java.util.Collections.singleton; +import static java.util.Collections.singletonList; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Map; + +import org.apache.kafka.common.Node; +import org.apache.kafka.common.errors.GroupAuthorizationException; +import org.apache.kafka.common.errors.UnknownServerException; +import org.apache.kafka.common.message.LeaveGroupRequestData.MemberIdentity; +import org.apache.kafka.common.message.LeaveGroupResponseData; +import org.apache.kafka.common.message.LeaveGroupResponseData.MemberResponse; +import org.apache.kafka.common.protocol.Errors; +import org.apache.kafka.common.requests.LeaveGroupRequest; +import org.apache.kafka.common.requests.LeaveGroupResponse; +import org.apache.kafka.common.utils.LogContext; +import org.junit.jupiter.api.Test; + +public class RemoveMembersFromConsumerGroupHandlerTest { + + private final LogContext logContext = new LogContext(); + private final String groupId = "group-id"; + private final MemberIdentity m1 = new MemberIdentity() + .setMemberId("m1") + .setGroupInstanceId("m1-gii"); + private final MemberIdentity m2 = new MemberIdentity() + .setMemberId("m2") + .setGroupInstanceId("m2-gii"); + private final List members = Arrays.asList(m1, m2); + + @Test + public void testBuildRequest() { + RemoveMembersFromConsumerGroupHandler handler = new RemoveMembersFromConsumerGroupHandler(groupId, members, logContext); + LeaveGroupRequest request = handler.buildRequest(1, singleton(CoordinatorKey.byGroupId(groupId))).build(); + assertEquals(groupId, request.data().groupId()); + assertEquals(2, request.data().members().size()); + } + + @Test + public void testSuccessfulHandleResponse() { + Map responseData = Collections.singletonMap(m1, Errors.NONE); + assertCompleted(handleWithGroupError(Errors.NONE), responseData); + } + + @Test + public void testUnmappedHandleResponse() { + assertUnmapped(handleWithGroupError(Errors.COORDINATOR_NOT_AVAILABLE)); + assertUnmapped(handleWithGroupError(Errors.NOT_COORDINATOR)); + } + + @Test + public void testRetriableHandleResponse() { + assertRetriable(handleWithGroupError(Errors.COORDINATOR_LOAD_IN_PROGRESS)); + } + + @Test + public void testFailedHandleResponse() { + assertFailed(GroupAuthorizationException.class, handleWithGroupError(Errors.GROUP_AUTHORIZATION_FAILED)); + assertFailed(UnknownServerException.class, handleWithGroupError(Errors.UNKNOWN_SERVER_ERROR)); + } + + @Test + public void testFailedHandleResponseInMemberLevel() { + assertMemberFailed(Errors.FENCED_INSTANCE_ID, handleWithMemberError(Errors.FENCED_INSTANCE_ID)); + assertMemberFailed(Errors.UNKNOWN_MEMBER_ID, handleWithMemberError(Errors.UNKNOWN_MEMBER_ID)); + } + + private LeaveGroupResponse buildResponse(Errors error) { + LeaveGroupResponse response = new LeaveGroupResponse( + new LeaveGroupResponseData() + .setErrorCode(error.code()) + .setMembers(singletonList( + new MemberResponse() + .setErrorCode(Errors.NONE.code()) + .setMemberId("m1") + .setGroupInstanceId("m1-gii")))); + return response; + } + + private LeaveGroupResponse buildResponseWithMemberError(Errors error) { + LeaveGroupResponse response = new LeaveGroupResponse( + new LeaveGroupResponseData() + .setErrorCode(Errors.NONE.code()) 
+ .setMembers(singletonList( + new MemberResponse() + .setErrorCode(error.code()) + .setMemberId("m1") + .setGroupInstanceId("m1-gii")))); + return response; + } + + private AdminApiHandler.ApiResult> handleWithGroupError( + Errors error + ) { + RemoveMembersFromConsumerGroupHandler handler = new RemoveMembersFromConsumerGroupHandler(groupId, members, logContext); + LeaveGroupResponse response = buildResponse(error); + return handler.handleResponse(new Node(1, "host", 1234), singleton(CoordinatorKey.byGroupId(groupId)), response); + } + + private AdminApiHandler.ApiResult> handleWithMemberError( + Errors error + ) { + RemoveMembersFromConsumerGroupHandler handler = new RemoveMembersFromConsumerGroupHandler(groupId, members, logContext); + LeaveGroupResponse response = buildResponseWithMemberError(error); + return handler.handleResponse(new Node(1, "host", 1234), singleton(CoordinatorKey.byGroupId(groupId)), response); + } + + private void assertUnmapped( + AdminApiHandler.ApiResult> result + ) { + assertEquals(emptySet(), result.completedKeys.keySet()); + assertEquals(emptySet(), result.failedKeys.keySet()); + assertEquals(singletonList(CoordinatorKey.byGroupId(groupId)), result.unmappedKeys); + } + + private void assertRetriable( + AdminApiHandler.ApiResult> result + ) { + assertEquals(emptySet(), result.completedKeys.keySet()); + assertEquals(emptySet(), result.failedKeys.keySet()); + assertEquals(emptyList(), result.unmappedKeys); + } + + private void assertCompleted( + AdminApiHandler.ApiResult> result, + Map expected + ) { + CoordinatorKey key = CoordinatorKey.byGroupId(groupId); + assertEquals(emptySet(), result.failedKeys.keySet()); + assertEquals(emptyList(), result.unmappedKeys); + assertEquals(singleton(key), result.completedKeys.keySet()); + assertEquals(expected, result.completedKeys.get(key)); + } + + private void assertFailed( + Class expectedExceptionType, + AdminApiHandler.ApiResult> result + ) { + CoordinatorKey key = CoordinatorKey.byGroupId(groupId); + assertEquals(emptySet(), result.completedKeys.keySet()); + assertEquals(emptyList(), result.unmappedKeys); + assertEquals(singleton(key), result.failedKeys.keySet()); + assertTrue(expectedExceptionType.isInstance(result.failedKeys.get(key))); + } + + private void assertMemberFailed( + Errors expectedError, + AdminApiHandler.ApiResult> result + ) { + Map expectedResponseData = Collections.singletonMap(m1, expectedError); + CoordinatorKey key = CoordinatorKey.byGroupId(groupId); + assertEquals(emptySet(), result.failedKeys.keySet()); + assertEquals(emptyList(), result.unmappedKeys); + assertEquals(singleton(key), result.completedKeys.keySet()); + assertEquals(expectedResponseData, result.completedKeys.get(key)); + } +} diff --git a/clients/src/test/java/org/apache/kafka/clients/consumer/ConsumerConfigTest.java b/clients/src/test/java/org/apache/kafka/clients/consumer/ConsumerConfigTest.java new file mode 100644 index 0000000..dc1eeac --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/clients/consumer/ConsumerConfigTest.java @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients.consumer; + +import org.apache.kafka.common.errors.InvalidConfigurationException; +import org.apache.kafka.common.serialization.ByteArrayDeserializer; +import org.apache.kafka.common.serialization.Deserializer; +import org.apache.kafka.common.serialization.StringDeserializer; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.Map; +import java.util.Properties; + +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.fail; + +public class ConsumerConfigTest { + + private final Deserializer keyDeserializer = new ByteArrayDeserializer(); + private final Deserializer valueDeserializer = new StringDeserializer(); + private final String keyDeserializerClassName = keyDeserializer.getClass().getName(); + private final String valueDeserializerClassName = valueDeserializer.getClass().getName(); + private final Object keyDeserializerClass = keyDeserializer.getClass(); + private final Object valueDeserializerClass = valueDeserializer.getClass(); + private final Properties properties = new Properties(); + + @BeforeEach + public void setUp() { + properties.setProperty(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, keyDeserializerClassName); + properties.setProperty(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, valueDeserializerClassName); + } + + @Test + public void testOverrideClientId() { + properties.setProperty(ConsumerConfig.GROUP_ID_CONFIG, "test-group"); + ConsumerConfig config = new ConsumerConfig(properties); + assertFalse(config.getString(ConsumerConfig.CLIENT_ID_CONFIG).isEmpty()); + } + + @Test + public void testOverrideEnableAutoCommit() { + ConsumerConfig config = new ConsumerConfig(properties); + boolean overrideEnableAutoCommit = config.maybeOverrideEnableAutoCommit(); + assertFalse(overrideEnableAutoCommit); + + properties.setProperty(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG, "true"); + config = new ConsumerConfig(properties); + try { + config.maybeOverrideEnableAutoCommit(); + fail("Should have thrown an exception"); + } catch (InvalidConfigurationException e) { + // expected + } + } + + @Test + public void testAppendDeserializerToConfig() { + Map configs = new HashMap<>(); + configs.put(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, keyDeserializerClass); + configs.put(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, valueDeserializerClass); + Map newConfigs = ConsumerConfig.appendDeserializerToConfig(configs, null, null); + assertEquals(newConfigs.get(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG), keyDeserializerClass); + assertEquals(newConfigs.get(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG), valueDeserializerClass); + + configs.clear(); + configs.put(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, valueDeserializerClass); + newConfigs = ConsumerConfig.appendDeserializerToConfig(configs, keyDeserializer, null); + assertEquals(newConfigs.get(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG), keyDeserializerClass); + 
assertEquals(newConfigs.get(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG), valueDeserializerClass); + + configs.clear(); + configs.put(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, keyDeserializerClass); + newConfigs = ConsumerConfig.appendDeserializerToConfig(configs, null, valueDeserializer); + assertEquals(newConfigs.get(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG), keyDeserializerClass); + assertEquals(newConfigs.get(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG), valueDeserializerClass); + + configs.clear(); + newConfigs = ConsumerConfig.appendDeserializerToConfig(configs, keyDeserializer, valueDeserializer); + assertEquals(newConfigs.get(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG), keyDeserializerClass); + assertEquals(newConfigs.get(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG), valueDeserializerClass); + } + + @Test + public void ensureDefaultThrowOnUnsupportedStableFlagToFalse() { + assertFalse(new ConsumerConfig(properties).getBoolean(ConsumerConfig.THROW_ON_FETCH_STABLE_OFFSET_UNSUPPORTED)); + } + + @Test + public void testDefaultPartitionAssignor() { + assertEquals(Arrays.asList(RangeAssignor.class, CooperativeStickyAssignor.class), + new ConsumerConfig(properties).getList(ConsumerConfig.PARTITION_ASSIGNMENT_STRATEGY_CONFIG)); + } +} diff --git a/clients/src/test/java/org/apache/kafka/clients/consumer/ConsumerGroupMetadataTest.java b/clients/src/test/java/org/apache/kafka/clients/consumer/ConsumerGroupMetadataTest.java new file mode 100644 index 0000000..b32b49c --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/clients/consumer/ConsumerGroupMetadataTest.java @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.clients.consumer; + +import org.apache.kafka.common.requests.JoinGroupRequest; +import org.junit.jupiter.api.Test; + +import java.util.Optional; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class ConsumerGroupMetadataTest { + + private String groupId = "group"; + + @Test + public void testAssignmentConstructor() { + String memberId = "member"; + int generationId = 2; + String groupInstanceId = "instance"; + + ConsumerGroupMetadata groupMetadata = new ConsumerGroupMetadata(groupId, + generationId, memberId, Optional.of(groupInstanceId)); + + assertEquals(groupId, groupMetadata.groupId()); + assertEquals(generationId, groupMetadata.generationId()); + assertEquals(memberId, groupMetadata.memberId()); + assertTrue(groupMetadata.groupInstanceId().isPresent()); + assertEquals(groupInstanceId, groupMetadata.groupInstanceId().get()); + } + + @Test + public void testGroupIdConstructor() { + ConsumerGroupMetadata groupMetadata = new ConsumerGroupMetadata(groupId); + + assertEquals(groupId, groupMetadata.groupId()); + assertEquals(JoinGroupRequest.UNKNOWN_GENERATION_ID, groupMetadata.generationId()); + assertEquals(JoinGroupRequest.UNKNOWN_MEMBER_ID, groupMetadata.memberId()); + assertFalse(groupMetadata.groupInstanceId().isPresent()); + } + + @Test + public void testInvalidGroupId() { + String memberId = "member"; + int generationId = 2; + + assertThrows(NullPointerException.class, () -> new ConsumerGroupMetadata( + null, generationId, memberId, Optional.empty()) + ); + } + + @Test + public void testInvalidMemberId() { + int generationId = 2; + + assertThrows(NullPointerException.class, () -> new ConsumerGroupMetadata( + groupId, generationId, null, Optional.empty()) + ); + } + + @Test + public void testInvalidInstanceId() { + String memberId = "member"; + int generationId = 2; + + assertThrows(NullPointerException.class, () -> new ConsumerGroupMetadata( + groupId, generationId, memberId, null) + ); + } +} diff --git a/clients/src/test/java/org/apache/kafka/clients/consumer/ConsumerPartitionAssignorTest.java b/clients/src/test/java/org/apache/kafka/clients/consumer/ConsumerPartitionAssignorTest.java new file mode 100644 index 0000000..1298f8c --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/clients/consumer/ConsumerPartitionAssignorTest.java @@ -0,0 +1,154 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.clients.consumer; + + +import org.apache.kafka.common.Cluster; +import org.apache.kafka.common.KafkaException; +import org.apache.kafka.common.serialization.StringDeserializer; +import org.junit.jupiter.api.Test; + +import java.nio.ByteBuffer; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Properties; +import java.util.Set; + +import static org.apache.kafka.clients.consumer.ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG; +import static org.apache.kafka.clients.consumer.ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG; +import static org.apache.kafka.clients.consumer.ConsumerPartitionAssignor.getAssignorInstances; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class ConsumerPartitionAssignorTest { + + @Test + public void shouldInstantiateAssignor() { + List assignors = getAssignorInstances( + Collections.singletonList(StickyAssignor.class.getName()), + Collections.emptyMap() + ); + assertTrue(assignors.get(0) instanceof StickyAssignor); + } + + @Test + public void shouldInstantiateListOfAssignors() { + List assignors = getAssignorInstances( + Arrays.asList(StickyAssignor.class.getName(), CooperativeStickyAssignor.class.getName()), + Collections.emptyMap() + ); + assertTrue(assignors.get(0) instanceof StickyAssignor); + assertTrue(assignors.get(1) instanceof CooperativeStickyAssignor); + } + + @Test + public void shouldThrowKafkaExceptionOnNonAssignor() { + assertThrows(KafkaException.class, () -> getAssignorInstances( + Collections.singletonList(String.class.getName()), + Collections.emptyMap()) + ); + } + + @Test + public void shouldThrowKafkaExceptionOnAssignorNotFound() { + assertThrows(KafkaException.class, () -> getAssignorInstances( + Collections.singletonList("Non-existent assignor"), + Collections.emptyMap()) + ); + } + + @Test + public void shouldInstantiateFromClassType() { + List classTypes = + initConsumerConfigWithClassTypes(Collections.singletonList(StickyAssignor.class)) + .getList(ConsumerConfig.PARTITION_ASSIGNMENT_STRATEGY_CONFIG); + List assignors = getAssignorInstances(classTypes, Collections.emptyMap()); + assertTrue(assignors.get(0) instanceof StickyAssignor); + } + + @Test + public void shouldInstantiateFromListOfClassTypes() { + List classTypes = initConsumerConfigWithClassTypes( + Arrays.asList(StickyAssignor.class, CooperativeStickyAssignor.class) + ).getList(ConsumerConfig.PARTITION_ASSIGNMENT_STRATEGY_CONFIG); + + List assignors = getAssignorInstances(classTypes, Collections.emptyMap()); + + assertTrue(assignors.get(0) instanceof StickyAssignor); + assertTrue(assignors.get(1) instanceof CooperativeStickyAssignor); + } + + @Test + public void shouldThrowKafkaExceptionOnListWithNonAssignorClassType() { + List classTypes = + initConsumerConfigWithClassTypes(Arrays.asList(StickyAssignor.class, String.class)) + .getList(ConsumerConfig.PARTITION_ASSIGNMENT_STRATEGY_CONFIG); + + assertThrows(KafkaException.class, () -> getAssignorInstances(classTypes, Collections.emptyMap())); + } + + @Test + public void shouldThrowKafkaExceptionOnAssignorsWithSameName() { + assertThrows(KafkaException.class, () -> getAssignorInstances( + Arrays.asList(RangeAssignor.class.getName(), TestConsumerPartitionAssignor.class.getName()), + Collections.emptyMap() + )); + } + + public static class TestConsumerPartitionAssignor implements ConsumerPartitionAssignor { + + @Override + public ByteBuffer subscriptionUserData(Set topics) { + 
return ConsumerPartitionAssignor.super.subscriptionUserData(topics); + } + + @Override + public GroupAssignment assign(Cluster metadata, GroupSubscription groupSubscription) { + return null; + } + + @Override + public void onAssignment(Assignment assignment, ConsumerGroupMetadata metadata) { + ConsumerPartitionAssignor.super.onAssignment(assignment, metadata); + } + + @Override + public List supportedProtocols() { + return ConsumerPartitionAssignor.super.supportedProtocols(); + } + + @Override + public short version() { + return ConsumerPartitionAssignor.super.version(); + } + + @Override + public String name() { + // use the RangeAssignor's name to cause naming conflict + return new RangeAssignor().name(); + } + } + + private ConsumerConfig initConsumerConfigWithClassTypes(List classTypes) { + Properties props = new Properties(); + props.put(KEY_DESERIALIZER_CLASS_CONFIG, StringDeserializer.class); + props.put(VALUE_DESERIALIZER_CLASS_CONFIG, StringDeserializer.class); + props.put(ConsumerConfig.PARTITION_ASSIGNMENT_STRATEGY_CONFIG, classTypes); + return new ConsumerConfig(props); + } +} diff --git a/clients/src/test/java/org/apache/kafka/clients/consumer/ConsumerRecordTest.java b/clients/src/test/java/org/apache/kafka/clients/consumer/ConsumerRecordTest.java new file mode 100644 index 0000000..848f75c --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/clients/consumer/ConsumerRecordTest.java @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.clients.consumer; + +import org.apache.kafka.common.header.internals.RecordHeader; +import org.apache.kafka.common.header.internals.RecordHeaders; +import org.apache.kafka.common.record.TimestampType; +import org.junit.jupiter.api.Test; + +import java.nio.charset.StandardCharsets; +import java.util.Optional; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class ConsumerRecordTest { + + @Test + public void testShortConstructor() { + String topic = "topic"; + int partition = 0; + long offset = 23; + String key = "key"; + String value = "value"; + + ConsumerRecord record = new ConsumerRecord<>(topic, partition, offset, key, value); + assertEquals(topic, record.topic()); + assertEquals(partition, record.partition()); + assertEquals(offset, record.offset()); + assertEquals(key, record.key()); + assertEquals(value, record.value()); + assertEquals(TimestampType.NO_TIMESTAMP_TYPE, record.timestampType()); + assertEquals(ConsumerRecord.NO_TIMESTAMP, record.timestamp()); + assertEquals(ConsumerRecord.NULL_SIZE, record.serializedKeySize()); + assertEquals(ConsumerRecord.NULL_SIZE, record.serializedValueSize()); + assertEquals(Optional.empty(), record.leaderEpoch()); + assertEquals(new RecordHeaders(), record.headers()); + } + + @Test + @Deprecated + public void testConstructorsWithChecksum() { + String topic = "topic"; + int partition = 0; + long offset = 23; + long timestamp = 23434217432432L; + TimestampType timestampType = TimestampType.CREATE_TIME; + String key = "key"; + String value = "value"; + long checksum = 50L; + int serializedKeySize = 100; + int serializedValueSize = 1142; + + ConsumerRecord record = new ConsumerRecord<>(topic, partition, offset, timestamp, timestampType, + checksum, serializedKeySize, serializedValueSize, key, value); + assertEquals(topic, record.topic()); + assertEquals(partition, record.partition()); + assertEquals(offset, record.offset()); + assertEquals(key, record.key()); + assertEquals(value, record.value()); + assertEquals(timestampType, record.timestampType()); + assertEquals(timestamp, record.timestamp()); + assertEquals(serializedKeySize, record.serializedKeySize()); + assertEquals(serializedValueSize, record.serializedValueSize()); + assertEquals(Optional.empty(), record.leaderEpoch()); + assertEquals(new RecordHeaders(), record.headers()); + + RecordHeaders headers = new RecordHeaders(); + headers.add(new RecordHeader("header key", "header value".getBytes(StandardCharsets.UTF_8))); + record = new ConsumerRecord<>(topic, partition, offset, timestamp, timestampType, + checksum, serializedKeySize, serializedValueSize, key, value, headers); + assertEquals(topic, record.topic()); + assertEquals(partition, record.partition()); + assertEquals(offset, record.offset()); + assertEquals(key, record.key()); + assertEquals(value, record.value()); + assertEquals(timestampType, record.timestampType()); + assertEquals(timestamp, record.timestamp()); + assertEquals(serializedKeySize, record.serializedKeySize()); + assertEquals(serializedValueSize, record.serializedValueSize()); + assertEquals(Optional.empty(), record.leaderEpoch()); + assertEquals(headers, record.headers()); + + Optional leaderEpoch = Optional.of(10); + record = new ConsumerRecord<>(topic, partition, offset, timestamp, timestampType, + checksum, serializedKeySize, serializedValueSize, key, value, headers, leaderEpoch); + assertEquals(topic, record.topic()); + assertEquals(partition, record.partition()); + assertEquals(offset, record.offset()); + 
assertEquals(key, record.key()); + assertEquals(value, record.value()); + assertEquals(timestampType, record.timestampType()); + assertEquals(timestamp, record.timestamp()); + assertEquals(serializedKeySize, record.serializedKeySize()); + assertEquals(serializedValueSize, record.serializedValueSize()); + assertEquals(leaderEpoch, record.leaderEpoch()); + assertEquals(headers, record.headers()); + } + +} diff --git a/clients/src/test/java/org/apache/kafka/clients/consumer/ConsumerRecordsTest.java b/clients/src/test/java/org/apache/kafka/clients/consumer/ConsumerRecordsTest.java new file mode 100644 index 0000000..d414450 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/clients/consumer/ConsumerRecordsTest.java @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients.consumer; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Iterator; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.header.internals.RecordHeaders; +import org.apache.kafka.common.record.TimestampType; +import org.junit.jupiter.api.Test; + +public class ConsumerRecordsTest { + + @Test + public void iterator() throws Exception { + + Map>> records = new LinkedHashMap<>(); + + String topic = "topic"; + records.put(new TopicPartition(topic, 0), new ArrayList>()); + ConsumerRecord record1 = new ConsumerRecord<>(topic, 1, 0, 0L, TimestampType.CREATE_TIME, + 0, 0, 1, "value1", new RecordHeaders(), Optional.empty()); + ConsumerRecord record2 = new ConsumerRecord<>(topic, 1, 1, 0L, TimestampType.CREATE_TIME, + 0, 0, 2, "value2", new RecordHeaders(), Optional.empty()); + records.put(new TopicPartition(topic, 1), Arrays.asList(record1, record2)); + records.put(new TopicPartition(topic, 2), new ArrayList<>()); + + ConsumerRecords consumerRecords = new ConsumerRecords<>(records); + Iterator> iter = consumerRecords.iterator(); + + int c = 0; + for (; iter.hasNext(); c++) { + ConsumerRecord record = iter.next(); + assertEquals(1, record.partition()); + assertEquals(topic, record.topic()); + assertEquals(c, record.offset()); + } + assertEquals(2, c); + } +} diff --git a/clients/src/test/java/org/apache/kafka/clients/consumer/CooperativeStickyAssignorTest.java b/clients/src/test/java/org/apache/kafka/clients/consumer/CooperativeStickyAssignorTest.java new file mode 100644 index 0000000..f94aa23 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/clients/consumer/CooperativeStickyAssignorTest.java @@ -0,0 +1,181 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor 
license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients.consumer; + +import org.apache.kafka.clients.consumer.ConsumerPartitionAssignor.Subscription; +import org.apache.kafka.clients.consumer.internals.AbstractStickyAssignor; +import org.apache.kafka.clients.consumer.internals.AbstractStickyAssignorTest; +import org.apache.kafka.common.TopicPartition; + +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import org.junit.jupiter.api.Test; + +import static org.apache.kafka.clients.consumer.internals.AbstractStickyAssignor.DEFAULT_GENERATION; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static java.util.Collections.emptyList; + +public class CooperativeStickyAssignorTest extends AbstractStickyAssignorTest { + + @Override + public AbstractStickyAssignor createAssignor() { + return new CooperativeStickyAssignor(); + } + + @Override + public Subscription buildSubscription(List topics, List partitions) { + return new Subscription(topics, assignor.subscriptionUserData(new HashSet<>(topics)), partitions); + } + + @Override + public Subscription buildSubscriptionWithGeneration(List topics, List partitions, int generation) { + assignor.onAssignment(null, new ConsumerGroupMetadata("dummy-group-id", generation, "dummy-member-id", Optional.empty())); + return new Subscription(topics, assignor.subscriptionUserData(new HashSet<>(topics)), partitions); + } + + @Test + public void testEncodeAndDecodeGeneration() { + Subscription subscription = new Subscription(topics(topic), assignor.subscriptionUserData(new HashSet<>(topics(topic)))); + + Optional encodedGeneration = ((CooperativeStickyAssignor) assignor).memberData(subscription).generation; + assertTrue(encodedGeneration.isPresent()); + assertEquals(encodedGeneration.get(), DEFAULT_GENERATION); + + int generation = 10; + assignor.onAssignment(null, new ConsumerGroupMetadata("dummy-group-id", generation, "dummy-member-id", Optional.empty())); + + subscription = new Subscription(topics(topic), assignor.subscriptionUserData(new HashSet<>(topics(topic)))); + encodedGeneration = ((CooperativeStickyAssignor) assignor).memberData(subscription).generation; + + assertTrue(encodedGeneration.isPresent()); + assertEquals(encodedGeneration.get(), generation); + } + + @Test + public void testDecodeGeneration() { + Subscription subscription = new Subscription(topics(topic)); + assertFalse(((CooperativeStickyAssignor) assignor).memberData(subscription).generation.isPresent()); + } + + @Test + public void testAllConsumersHaveOwnedPartitionInvalidatedWhenClaimedByMultipleConsumersInSameGenerationWithEqualPartitionsPerConsumer() { + Map 
partitionsPerTopic = new HashMap<>(); + partitionsPerTopic.put(topic, 3); + + subscriptions.put(consumer1, buildSubscription(topics(topic), partitions(tp(topic, 0), tp(topic, 1)))); + subscriptions.put(consumer2, buildSubscription(topics(topic), partitions(tp(topic, 0), tp(topic, 2)))); + subscriptions.put(consumer3, buildSubscription(topics(topic), emptyList())); + + Map> assignment = assignor.assign(partitionsPerTopic, subscriptions); + assertEquals(partitions(tp(topic, 1)), assignment.get(consumer1)); + assertEquals(partitions(tp(topic, 2)), assignment.get(consumer2)); + // In the cooperative assignor, topic-0 has to be considered "owned" and so it cant be assigned until both have "revoked" it + assertTrue(assignment.get(consumer3).isEmpty()); + + verifyValidityAndBalance(subscriptions, assignment, partitionsPerTopic); + assertTrue(isFullyBalanced(assignment)); + } + + @Test + public void testAllConsumersHaveOwnedPartitionInvalidatedWhenClaimedByMultipleConsumersInSameGenerationWithUnequalPartitionsPerConsumer() { + Map partitionsPerTopic = new HashMap<>(); + partitionsPerTopic.put(topic, 4); + + subscriptions.put(consumer1, buildSubscription(topics(topic), partitions(tp(topic, 0), tp(topic, 1)))); + subscriptions.put(consumer2, buildSubscription(topics(topic), partitions(tp(topic, 0), tp(topic, 2)))); + subscriptions.put(consumer3, buildSubscription(topics(topic), emptyList())); + + Map> assignment = assignor.assign(partitionsPerTopic, subscriptions); + assertEquals(partitions(tp(topic, 1), tp(topic, 3)), assignment.get(consumer1)); + assertEquals(partitions(tp(topic, 2)), assignment.get(consumer2)); + // In the cooperative assignor, topic-0 has to be considered "owned" and so it cant be assigned until both have "revoked" it + assertTrue(assignment.get(consumer3).isEmpty()); + + verifyValidityAndBalance(subscriptions, assignment, partitionsPerTopic); + assertTrue(isFullyBalanced(assignment)); + } + + /** + * The cooperative assignor must do some additional work and verification of some assignments relative to the eager + * assignor, since it may or may not need to trigger a second follow-up rebalance. + *
* <p>
            + * In addition to the validity requirements described in + * {@link org.apache.kafka.clients.consumer.internals.AbstractStickyAssignorTest#verifyValidityAndBalance(Map, Map, Map)}, + * we must verify that no partition is being revoked and reassigned during the same rebalance. This means the initial + * assignment may be unbalanced, so if we do detect partitions being revoked we should trigger a second "rebalance" + * to get the final assignment and then verify that it is both valid and balanced. + */ + @Override + public void verifyValidityAndBalance(Map subscriptions, + Map> assignments, + Map partitionsPerTopic) { + int rebalances = 0; + // partitions are being revoked, we must go through another assignment to get the final state + while (verifyCooperativeValidity(subscriptions, assignments)) { + + // update the subscriptions with the now owned partitions + for (Map.Entry> entry : assignments.entrySet()) { + String consumer = entry.getKey(); + Subscription oldSubscription = subscriptions.get(consumer); + subscriptions.put(consumer, buildSubscription(oldSubscription.topics(), entry.getValue())); + } + + assignments.clear(); + assignments.putAll(assignor.assign(partitionsPerTopic, subscriptions)); + ++rebalances; + + assertTrue(rebalances <= 4); + } + + // Check the validity and balance of the final assignment + super.verifyValidityAndBalance(subscriptions, assignments, partitionsPerTopic); + } + + // Returns true if partitions are being revoked, indicating a second rebalance will be triggered + private boolean verifyCooperativeValidity(Map subscriptions, Map> assignments) { + Set allAddedPartitions = new HashSet<>(); + Set allRevokedPartitions = new HashSet<>(); + for (Map.Entry> entry : assignments.entrySet()) { + List ownedPartitions = subscriptions.get(entry.getKey()).ownedPartitions(); + List assignedPartitions = entry.getValue(); + + Set revokedPartitions = new HashSet<>(ownedPartitions); + revokedPartitions.removeAll(assignedPartitions); + + Set addedPartitions = new HashSet<>(assignedPartitions); + addedPartitions.removeAll(ownedPartitions); + + allAddedPartitions.addAll(addedPartitions); + allRevokedPartitions.addAll(revokedPartitions); + } + + Set intersection = new HashSet<>(allAddedPartitions); + intersection.retainAll(allRevokedPartitions); + assertTrue(intersection.isEmpty(), + "Error: Some partitions were assigned to a new consumer during the same rebalance they are being " + + "revoked from their previous owner. Partitions: " + intersection); + + return !allRevokedPartitions.isEmpty(); + } +} diff --git a/clients/src/test/java/org/apache/kafka/clients/consumer/KafkaConsumerTest.java b/clients/src/test/java/org/apache/kafka/clients/consumer/KafkaConsumerTest.java new file mode 100644 index 0000000..2872983 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/clients/consumer/KafkaConsumerTest.java @@ -0,0 +1,2986 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients.consumer; + +import org.apache.kafka.clients.ApiVersions; +import org.apache.kafka.clients.ClientRequest; +import org.apache.kafka.clients.GroupRebalanceConfig; +import org.apache.kafka.clients.KafkaClient; +import org.apache.kafka.clients.MockClient; +import org.apache.kafka.clients.NodeApiVersions; +import org.apache.kafka.clients.consumer.internals.ConsumerCoordinator; +import org.apache.kafka.clients.consumer.internals.ConsumerInterceptors; +import org.apache.kafka.clients.consumer.internals.ConsumerMetadata; +import org.apache.kafka.clients.consumer.internals.ConsumerMetrics; +import org.apache.kafka.clients.consumer.internals.ConsumerNetworkClient; +import org.apache.kafka.clients.consumer.internals.ConsumerProtocol; +import org.apache.kafka.clients.consumer.internals.Fetcher; +import org.apache.kafka.clients.consumer.internals.MockRebalanceListener; +import org.apache.kafka.clients.consumer.internals.SubscriptionState; +import org.apache.kafka.common.Cluster; +import org.apache.kafka.common.IsolationLevel; +import org.apache.kafka.common.KafkaException; +import org.apache.kafka.common.Metric; +import org.apache.kafka.common.MetricName; +import org.apache.kafka.common.Node; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.TopicIdPartition; +import org.apache.kafka.common.Uuid; +import org.apache.kafka.common.config.SslConfigs; +import org.apache.kafka.common.errors.AuthenticationException; +import org.apache.kafka.common.errors.InterruptException; +import org.apache.kafka.common.errors.InvalidConfigurationException; +import org.apache.kafka.common.errors.InvalidGroupIdException; +import org.apache.kafka.common.errors.InvalidTopicException; +import org.apache.kafka.common.errors.RecordDeserializationException; +import org.apache.kafka.common.errors.SerializationException; +import org.apache.kafka.common.errors.UnsupportedVersionException; +import org.apache.kafka.common.errors.WakeupException; +import org.apache.kafka.common.internals.ClusterResourceListeners; +import org.apache.kafka.common.message.FetchResponseData; +import org.apache.kafka.common.message.HeartbeatResponseData; +import org.apache.kafka.common.message.JoinGroupRequestData; +import org.apache.kafka.common.message.JoinGroupResponseData; +import org.apache.kafka.common.message.LeaveGroupResponseData; +import org.apache.kafka.common.message.ListOffsetsRequestData.ListOffsetsPartition; +import org.apache.kafka.common.message.ListOffsetsResponseData; +import org.apache.kafka.common.message.ListOffsetsResponseData.ListOffsetsPartitionResponse; +import org.apache.kafka.common.message.ListOffsetsResponseData.ListOffsetsTopicResponse; +import org.apache.kafka.common.message.SyncGroupResponseData; +import org.apache.kafka.common.metrics.Metrics; +import org.apache.kafka.common.metrics.Sensor; +import org.apache.kafka.common.metrics.stats.Avg; +import org.apache.kafka.common.network.Selectable; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.Errors; +import org.apache.kafka.common.record.CompressionType; 
+import org.apache.kafka.common.record.MemoryRecords; +import org.apache.kafka.common.record.MemoryRecordsBuilder; +import org.apache.kafka.common.record.TimestampType; +import org.apache.kafka.common.requests.AbstractRequest; +import org.apache.kafka.common.requests.AbstractResponse; +import org.apache.kafka.common.requests.FetchRequest; +import org.apache.kafka.common.requests.FetchResponse; +import org.apache.kafka.common.requests.FindCoordinatorResponse; +import org.apache.kafka.common.requests.HeartbeatResponse; +import org.apache.kafka.common.requests.JoinGroupRequest; +import org.apache.kafka.common.requests.JoinGroupResponse; +import org.apache.kafka.common.requests.LeaveGroupResponse; +import org.apache.kafka.common.requests.ListOffsetsRequest; +import org.apache.kafka.common.requests.ListOffsetsResponse; +import org.apache.kafka.common.requests.MetadataResponse; +import org.apache.kafka.common.requests.OffsetCommitRequest; +import org.apache.kafka.common.requests.OffsetCommitResponse; +import org.apache.kafka.common.requests.OffsetFetchResponse; +import org.apache.kafka.common.requests.RequestTestUtils; +import org.apache.kafka.common.requests.SyncGroupResponse; +import org.apache.kafka.common.serialization.ByteArrayDeserializer; +import org.apache.kafka.common.serialization.Deserializer; +import org.apache.kafka.common.serialization.StringDeserializer; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.test.MockConsumerInterceptor; +import org.apache.kafka.test.MockMetricsReporter; +import org.apache.kafka.test.TestUtils; +import org.junit.jupiter.api.Test; + +import javax.management.MBeanServer; +import javax.management.ObjectName; +import java.lang.management.ManagementFactory; +import java.nio.ByteBuffer; +import java.time.Duration; +import java.util.AbstractMap; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.ConcurrentModificationException; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Iterator; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.OptionalLong; +import java.util.Properties; +import java.util.Queue; +import java.util.Set; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; +import java.util.regex.Pattern; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import static java.util.Collections.singleton; +import static java.util.Collections.singletonList; +import static java.util.Collections.singletonMap; +import static org.apache.kafka.common.requests.FetchMetadata.INVALID_SESSION_ID; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNotSame; +import static 
org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; + +public class KafkaConsumerTest { + private final String topic = "test"; + private final Uuid topicId = Uuid.randomUuid(); + private final TopicPartition tp0 = new TopicPartition(topic, 0); + private final TopicPartition tp1 = new TopicPartition(topic, 1); + + private final String topic2 = "test2"; + private final Uuid topicId2 = Uuid.randomUuid(); + private final TopicPartition t2p0 = new TopicPartition(topic2, 0); + + private final String topic3 = "test3"; + private final Uuid topicId3 = Uuid.randomUuid(); + private final TopicPartition t3p0 = new TopicPartition(topic3, 0); + + private final int sessionTimeoutMs = 10000; + private final int heartbeatIntervalMs = 1000; + + // Set auto commit interval lower than heartbeat so we don't need to deal with + // a concurrent heartbeat request + private final int autoCommitIntervalMs = 500; + + private final String groupId = "mock-group"; + private final String memberId = "memberId"; + private final String leaderId = "leaderId"; + private final Optional groupInstanceId = Optional.of("mock-instance"); + private Map topicIds = Stream.of( + new AbstractMap.SimpleEntry<>(topic, topicId), + new AbstractMap.SimpleEntry<>(topic2, topicId2), + new AbstractMap.SimpleEntry<>(topic3, topicId3)) + .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); + private Map topicNames = Stream.of( + new AbstractMap.SimpleEntry<>(topicId, topic), + new AbstractMap.SimpleEntry<>(topicId2, topic2), + new AbstractMap.SimpleEntry<>(topicId3, topic3)) + .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); + + private final String partitionRevoked = "Hit partition revoke "; + private final String partitionAssigned = "Hit partition assign "; + private final String partitionLost = "Hit partition lost "; + + private final Collection singleTopicPartition = Collections.singleton(new TopicPartition(topic, 0)); + + @Test + public void testMetricsReporterAutoGeneratedClientId() { + Properties props = new Properties(); + props.setProperty(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9999"); + props.setProperty(ConsumerConfig.METRIC_REPORTER_CLASSES_CONFIG, MockMetricsReporter.class.getName()); + KafkaConsumer consumer = new KafkaConsumer<>( + props, new StringDeserializer(), new StringDeserializer()); + + MockMetricsReporter mockMetricsReporter = (MockMetricsReporter) consumer.metrics.reporters().get(0); + + assertEquals(consumer.getClientId(), mockMetricsReporter.clientId); + consumer.close(); + } + + @Test + public void testPollReturnsRecords() { + KafkaConsumer consumer = setUpConsumerWithRecordsToPoll(tp0, 5); + + ConsumerRecords records = consumer.poll(Duration.ZERO); + + assertEquals(records.count(), 5); + assertEquals(records.partitions(), Collections.singleton(tp0)); + assertEquals(records.records(tp0).size(), 5); + + consumer.close(Duration.ofMillis(0)); + } + + @Test + public void testSecondPollWithDeserializationErrorThrowsRecordDeserializationException() { + int invalidRecordNumber = 4; + int invalidRecordOffset = 3; + StringDeserializer deserializer = mockErrorDeserializer(invalidRecordNumber); + + KafkaConsumer consumer = setUpConsumerWithRecordsToPoll(tp0, 5, deserializer); + ConsumerRecords records = consumer.poll(Duration.ZERO); + + assertEquals(invalidRecordNumber - 1, records.count()); + 
assertEquals(Collections.singleton(tp0), records.partitions()); + assertEquals(invalidRecordNumber - 1, records.records(tp0).size()); + long lastOffset = records.records(tp0).get(records.records(tp0).size() - 1).offset(); + assertEquals(invalidRecordNumber - 2, lastOffset); + + RecordDeserializationException rde = assertThrows(RecordDeserializationException.class, () -> consumer.poll(Duration.ZERO)); + assertEquals(invalidRecordOffset, rde.offset()); + assertEquals(tp0, rde.topicPartition()); + assertEquals(rde.offset(), consumer.position(tp0)); + consumer.close(Duration.ofMillis(0)); + } + + /* + Create a mock deserializer which throws a SerializationException on the Nth record's value deserialization + */ + private StringDeserializer mockErrorDeserializer(int recordNumber) { + int recordIndex = recordNumber - 1; + return new StringDeserializer() { + int i = 0; + @Override + public String deserialize(String topic, byte[] data) { + if (i == recordIndex) { + throw new SerializationException(); + } else { + i++; + return super.deserialize(topic, data); + } + } + }; + } + + private KafkaConsumer setUpConsumerWithRecordsToPoll(TopicPartition tp, int recordCount) { + return setUpConsumerWithRecordsToPoll(tp, recordCount, new StringDeserializer()); + } + + private KafkaConsumer setUpConsumerWithRecordsToPoll(TopicPartition tp, int recordCount, Deserializer deserializer) { + Time time = new MockTime(); + Cluster cluster = TestUtils.singletonCluster(tp.topic(), 1); + Node node = cluster.nodes().get(0); + + ConsumerPartitionAssignor assignor = new RoundRobinAssignor(); + SubscriptionState subscription = new SubscriptionState(new LogContext(), OffsetResetStrategy.EARLIEST); + ConsumerMetadata metadata = createMetadata(subscription); + MockClient client = new MockClient(time, metadata); + initMetadata(client, Collections.singletonMap(topic, 1)); + KafkaConsumer consumer = newConsumer(time, client, subscription, metadata, assignor, + true, groupId, groupInstanceId, Optional.of(deserializer), false); + consumer.subscribe(singleton(topic), getConsumerRebalanceListener(consumer)); + prepareRebalance(client, node, assignor, singletonList(tp), null); + consumer.updateAssignmentMetadataIfNeeded(time.timer(Long.MAX_VALUE)); + client.prepareResponseFrom(fetchResponse(tp, 0, recordCount), node); + return consumer; + } + + @Test + public void testConstructorClose() { + Properties props = new Properties(); + props.setProperty(ConsumerConfig.CLIENT_ID_CONFIG, "testConstructorClose"); + props.setProperty(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, "invalid-23-8409-adsfsdj"); + props.setProperty(ConsumerConfig.METRIC_REPORTER_CLASSES_CONFIG, MockMetricsReporter.class.getName()); + + final int oldInitCount = MockMetricsReporter.INIT_COUNT.get(); + final int oldCloseCount = MockMetricsReporter.CLOSE_COUNT.get(); + try { + new KafkaConsumer<>(props, new ByteArrayDeserializer(), new ByteArrayDeserializer()); + fail("should have caught an exception and returned"); + } catch (KafkaException e) { + assertEquals(oldInitCount + 1, MockMetricsReporter.INIT_COUNT.get()); + assertEquals(oldCloseCount + 1, MockMetricsReporter.CLOSE_COUNT.get()); + assertEquals("Failed to construct kafka consumer", e.getMessage()); + } + } + + @Test + public void testOsDefaultSocketBufferSizes() { + Map config = new HashMap<>(); + config.put(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9999"); + config.put(ConsumerConfig.SEND_BUFFER_CONFIG, Selectable.USE_DEFAULT_BUFFER_SIZE); + config.put(ConsumerConfig.RECEIVE_BUFFER_CONFIG, 
Selectable.USE_DEFAULT_BUFFER_SIZE); + KafkaConsumer consumer = new KafkaConsumer<>( + config, new ByteArrayDeserializer(), new ByteArrayDeserializer()); + consumer.close(); + } + + @Test + public void testInvalidSocketSendBufferSize() { + Map config = new HashMap<>(); + config.put(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9999"); + config.put(ConsumerConfig.SEND_BUFFER_CONFIG, -2); + assertThrows(KafkaException.class, + () -> new KafkaConsumer<>(config, new ByteArrayDeserializer(), new ByteArrayDeserializer())); + } + + @Test + public void testInvalidSocketReceiveBufferSize() { + Map config = new HashMap<>(); + config.put(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9999"); + config.put(ConsumerConfig.RECEIVE_BUFFER_CONFIG, -2); + assertThrows(KafkaException.class, + () -> new KafkaConsumer<>(config, new ByteArrayDeserializer(), new ByteArrayDeserializer())); + } + + @Test + public void shouldIgnoreGroupInstanceIdForEmptyGroupId() { + Map config = new HashMap<>(); + config.put(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9999"); + config.put(ConsumerConfig.GROUP_INSTANCE_ID_CONFIG, "instance_id"); + KafkaConsumer consumer = new KafkaConsumer<>( + config, new ByteArrayDeserializer(), new ByteArrayDeserializer()); + consumer.close(); + } + + @Test + public void testSubscription() { + KafkaConsumer consumer = newConsumer(groupId); + + consumer.subscribe(singletonList(topic)); + assertEquals(singleton(topic), consumer.subscription()); + assertTrue(consumer.assignment().isEmpty()); + + consumer.subscribe(Collections.emptyList()); + assertTrue(consumer.subscription().isEmpty()); + assertTrue(consumer.assignment().isEmpty()); + + consumer.assign(singletonList(tp0)); + assertTrue(consumer.subscription().isEmpty()); + assertEquals(singleton(tp0), consumer.assignment()); + + consumer.unsubscribe(); + assertTrue(consumer.subscription().isEmpty()); + assertTrue(consumer.assignment().isEmpty()); + + consumer.close(); + } + + @Test + public void testSubscriptionOnNullTopicCollection() { + try (KafkaConsumer consumer = newConsumer(groupId)) { + assertThrows(IllegalArgumentException.class, () -> consumer.subscribe((List) null)); + } + } + + @Test + public void testSubscriptionOnNullTopic() { + try (KafkaConsumer consumer = newConsumer(groupId)) { + assertThrows(IllegalArgumentException.class, () -> consumer.subscribe(singletonList(null))); + } + } + + @Test + public void testSubscriptionOnEmptyTopic() { + try (KafkaConsumer consumer = newConsumer(groupId)) { + String emptyTopic = " "; + assertThrows(IllegalArgumentException.class, () -> consumer.subscribe(singletonList(emptyTopic))); + } + } + + @Test + public void testSubscriptionOnNullPattern() { + try (KafkaConsumer consumer = newConsumer(groupId)) { + assertThrows(IllegalArgumentException.class, + () -> consumer.subscribe((Pattern) null)); + } + } + + @Test + public void testSubscriptionOnEmptyPattern() { + try (KafkaConsumer consumer = newConsumer(groupId)) { + assertThrows(IllegalArgumentException.class, + () -> consumer.subscribe(Pattern.compile(""))); + } + } + + @Test + public void testSubscriptionWithEmptyPartitionAssignment() { + Properties props = new Properties(); + props.setProperty(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9999"); + props.setProperty(ConsumerConfig.PARTITION_ASSIGNMENT_STRATEGY_CONFIG, ""); + props.setProperty(ConsumerConfig.GROUP_ID_CONFIG, groupId); + try (KafkaConsumer consumer = newConsumer(props)) { + assertThrows(IllegalStateException.class, + () -> 
consumer.subscribe(singletonList(topic))); + } + } + + @Test + public void testSeekNegative() { + try (KafkaConsumer consumer = newConsumer((String) null)) { + consumer.assign(singleton(new TopicPartition("nonExistTopic", 0))); + assertThrows(IllegalArgumentException.class, + () -> consumer.seek(new TopicPartition("nonExistTopic", 0), -1)); + } + } + + @Test + public void testAssignOnNullTopicPartition() { + try (KafkaConsumer consumer = newConsumer((String) null)) { + assertThrows(IllegalArgumentException.class, () -> consumer.assign(null)); + } + } + + @Test + public void testAssignOnEmptyTopicPartition() { + try (KafkaConsumer consumer = newConsumer(groupId)) { + consumer.assign(Collections.emptyList()); + assertTrue(consumer.subscription().isEmpty()); + assertTrue(consumer.assignment().isEmpty()); + } + } + + @Test + public void testAssignOnNullTopicInPartition() { + try (KafkaConsumer consumer = newConsumer((String) null)) { + assertThrows(IllegalArgumentException.class, () -> consumer.assign(singleton(new TopicPartition(null, 0)))); + } + } + + @Test + public void testAssignOnEmptyTopicInPartition() { + try (KafkaConsumer consumer = newConsumer((String) null)) { + assertThrows(IllegalArgumentException.class, () -> consumer.assign(singleton(new TopicPartition(" ", 0)))); + } + } + + @Test + public void testInterceptorConstructorClose() { + try { + Properties props = new Properties(); + // test with client ID assigned by KafkaConsumer + props.setProperty(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9999"); + props.setProperty(ConsumerConfig.INTERCEPTOR_CLASSES_CONFIG, MockConsumerInterceptor.class.getName()); + + KafkaConsumer consumer = new KafkaConsumer<>( + props, new StringDeserializer(), new StringDeserializer()); + assertEquals(1, MockConsumerInterceptor.INIT_COUNT.get()); + assertEquals(0, MockConsumerInterceptor.CLOSE_COUNT.get()); + + consumer.close(); + assertEquals(1, MockConsumerInterceptor.INIT_COUNT.get()); + assertEquals(1, MockConsumerInterceptor.CLOSE_COUNT.get()); + // Cluster metadata will only be updated on calling poll. 
+ assertNull(MockConsumerInterceptor.CLUSTER_META.get()); + + } finally { + // cleanup since we are using mutable static variables in MockConsumerInterceptor + MockConsumerInterceptor.resetCounters(); + } + } + + @Test + public void testPause() { + KafkaConsumer consumer = newConsumer(groupId); + + consumer.assign(singletonList(tp0)); + assertEquals(singleton(tp0), consumer.assignment()); + assertTrue(consumer.paused().isEmpty()); + + consumer.pause(singleton(tp0)); + assertEquals(singleton(tp0), consumer.paused()); + + consumer.resume(singleton(tp0)); + assertTrue(consumer.paused().isEmpty()); + + consumer.unsubscribe(); + assertTrue(consumer.paused().isEmpty()); + + consumer.close(); + } + + @Test + public void testConsumerJmxPrefix() throws Exception { + Map config = new HashMap<>(); + config.put(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9999"); + config.put(ConsumerConfig.SEND_BUFFER_CONFIG, Selectable.USE_DEFAULT_BUFFER_SIZE); + config.put(ConsumerConfig.RECEIVE_BUFFER_CONFIG, Selectable.USE_DEFAULT_BUFFER_SIZE); + config.put("client.id", "client-1"); + KafkaConsumer consumer = new KafkaConsumer<>( + config, new ByteArrayDeserializer(), new ByteArrayDeserializer()); + MBeanServer server = ManagementFactory.getPlatformMBeanServer(); + MetricName testMetricName = consumer.metrics.metricName("test-metric", + "grp1", "test metric"); + consumer.metrics.addMetric(testMetricName, new Avg()); + assertNotNull(server.getObjectInstance(new ObjectName("kafka.consumer:type=grp1,client-id=client-1"))); + consumer.close(); + } + + private KafkaConsumer newConsumer(String groupId) { + return newConsumer(groupId, Optional.empty()); + } + + private KafkaConsumer newConsumer(String groupId, Optional enableAutoCommit) { + Properties props = new Properties(); + props.setProperty(ConsumerConfig.CLIENT_ID_CONFIG, "my.consumer"); + props.setProperty(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9999"); + props.setProperty(ConsumerConfig.METRIC_REPORTER_CLASSES_CONFIG, MockMetricsReporter.class.getName()); + if (groupId != null) + props.setProperty(ConsumerConfig.GROUP_ID_CONFIG, groupId); + enableAutoCommit.ifPresent( + autoCommit -> props.setProperty(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG, autoCommit.toString())); + return newConsumer(props); + } + + private KafkaConsumer newConsumer(Properties props) { + return new KafkaConsumer<>(props, new ByteArrayDeserializer(), new ByteArrayDeserializer()); + } + + @Test + public void verifyHeartbeatSent() throws Exception { + Time time = new MockTime(); + SubscriptionState subscription = new SubscriptionState(new LogContext(), OffsetResetStrategy.EARLIEST); + ConsumerMetadata metadata = createMetadata(subscription); + MockClient client = new MockClient(time, metadata); + + initMetadata(client, Collections.singletonMap(topic, 1)); + Node node = metadata.fetch().nodes().get(0); + + ConsumerPartitionAssignor assignor = new RoundRobinAssignor(); + + KafkaConsumer consumer = newConsumer(time, client, subscription, metadata, assignor, true, groupInstanceId); + + consumer.subscribe(singleton(topic), getConsumerRebalanceListener(consumer)); + Node coordinator = prepareRebalance(client, node, assignor, singletonList(tp0), null); + + // initial fetch + client.prepareResponseFrom(fetchResponse(tp0, 0, 0), node); + consumer.updateAssignmentMetadataIfNeeded(time.timer(Long.MAX_VALUE)); + + assertEquals(singleton(tp0), consumer.assignment()); + + AtomicBoolean heartbeatReceived = prepareHeartbeatResponse(client, coordinator, Errors.NONE); + + // heartbeat 
interval is 2 seconds + time.sleep(heartbeatIntervalMs); + Thread.sleep(heartbeatIntervalMs); + + consumer.updateAssignmentMetadataIfNeeded(time.timer(Long.MAX_VALUE)); + + assertTrue(heartbeatReceived.get()); + consumer.close(Duration.ofMillis(0)); + } + + @Test + public void verifyHeartbeatSentWhenFetchedDataReady() throws Exception { + Time time = new MockTime(); + SubscriptionState subscription = new SubscriptionState(new LogContext(), OffsetResetStrategy.EARLIEST); + ConsumerMetadata metadata = createMetadata(subscription); + MockClient client = new MockClient(time, metadata); + + initMetadata(client, Collections.singletonMap(topic, 1)); + Node node = metadata.fetch().nodes().get(0); + ConsumerPartitionAssignor assignor = new RoundRobinAssignor(); + + KafkaConsumer consumer = newConsumer(time, client, subscription, metadata, assignor, true, groupInstanceId); + consumer.subscribe(singleton(topic), getConsumerRebalanceListener(consumer)); + Node coordinator = prepareRebalance(client, node, assignor, singletonList(tp0), null); + + consumer.updateAssignmentMetadataIfNeeded(time.timer(Long.MAX_VALUE)); + consumer.poll(Duration.ZERO); + + // respond to the outstanding fetch so that we have data available on the next poll + client.respondFrom(fetchResponse(tp0, 0, 5), node); + client.poll(0, time.milliseconds()); + + client.prepareResponseFrom(fetchResponse(tp0, 5, 0), node); + AtomicBoolean heartbeatReceived = prepareHeartbeatResponse(client, coordinator, Errors.NONE); + + time.sleep(heartbeatIntervalMs); + Thread.sleep(heartbeatIntervalMs); + + consumer.poll(Duration.ZERO); + + assertTrue(heartbeatReceived.get()); + consumer.close(Duration.ofMillis(0)); + } + + @Test + public void verifyPollTimesOutDuringMetadataUpdate() { + final Time time = new MockTime(); + final SubscriptionState subscription = new SubscriptionState(new LogContext(), OffsetResetStrategy.EARLIEST); + final ConsumerMetadata metadata = createMetadata(subscription); + final MockClient client = new MockClient(time, metadata); + + initMetadata(client, Collections.singletonMap(topic, 1)); + Node node = metadata.fetch().nodes().get(0); + + final ConsumerPartitionAssignor assignor = new RoundRobinAssignor(); + + final KafkaConsumer consumer = newConsumer(time, client, subscription, metadata, assignor, true, groupInstanceId); + consumer.subscribe(singleton(topic), getConsumerRebalanceListener(consumer)); + // Since we would enable the heartbeat thread after received join-response which could + // send the sync-group on behalf of the consumer if it is enqueued, we may still complete + // the rebalance and send out the fetch; in order to avoid it we do not prepare sync response here. 
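+        // Only the find-coordinator and join-group responses are prepared below; without a sync-group
+        // response the rebalance cannot complete, so poll(Duration.ZERO) should return without ever
+        // sending a fetch request.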
+ client.prepareResponseFrom(FindCoordinatorResponse.prepareResponse(Errors.NONE, groupId, node), node); + Node coordinator = new Node(Integer.MAX_VALUE - node.id(), node.host(), node.port()); + client.prepareResponseFrom(joinGroupFollowerResponse(assignor, 1, memberId, leaderId, Errors.NONE), coordinator); + + consumer.poll(Duration.ZERO); + + final Queue requests = client.requests(); + assertEquals(0, requests.stream().filter(request -> request.apiKey().equals(ApiKeys.FETCH)).count()); + } + + @SuppressWarnings("deprecation") + @Test + public void verifyDeprecatedPollDoesNotTimeOutDuringMetadataUpdate() { + final Time time = new MockTime(); + final SubscriptionState subscription = new SubscriptionState(new LogContext(), OffsetResetStrategy.EARLIEST); + final ConsumerMetadata metadata = createMetadata(subscription); + final MockClient client = new MockClient(time, metadata); + + initMetadata(client, Collections.singletonMap(topic, 1)); + Node node = metadata.fetch().nodes().get(0); + + final ConsumerPartitionAssignor assignor = new RoundRobinAssignor(); + + final KafkaConsumer consumer = newConsumer(time, client, subscription, metadata, assignor, true, groupInstanceId); + consumer.subscribe(singleton(topic), getConsumerRebalanceListener(consumer)); + prepareRebalance(client, node, assignor, singletonList(tp0), null); + + consumer.poll(0L); + + // The underlying client SHOULD get a fetch request + final Queue requests = client.requests(); + assertEquals(1, requests.size()); + final Class aClass = requests.peek().requestBuilder().getClass(); + assertEquals(FetchRequest.Builder.class, aClass); + } + + @Test + public void verifyNoCoordinatorLookupForManualAssignmentWithSeek() { + Time time = new MockTime(); + SubscriptionState subscription = new SubscriptionState(new LogContext(), OffsetResetStrategy.EARLIEST); + ConsumerMetadata metadata = createMetadata(subscription); + MockClient client = new MockClient(time, metadata); + + initMetadata(client, Collections.singletonMap(topic, 1)); + ConsumerPartitionAssignor assignor = new RoundRobinAssignor(); + + KafkaConsumer consumer = newConsumer(time, client, subscription, metadata, assignor, true, groupInstanceId); + consumer.assign(singleton(tp0)); + consumer.seekToBeginning(singleton(tp0)); + + // there shouldn't be any need to lookup the coordinator or fetch committed offsets. + // we just lookup the starting position and send the record fetch. + client.prepareResponse(listOffsetsResponse(Collections.singletonMap(tp0, 50L))); + client.prepareResponse(fetchResponse(tp0, 50L, 5)); + + ConsumerRecords records = consumer.poll(Duration.ofMillis(1)); + assertEquals(5, records.count()); + assertEquals(55L, consumer.position(tp0)); + consumer.close(Duration.ofMillis(0)); + } + + @Test + public void testFetchProgressWithMissingPartitionPosition() { + // Verifies that we can make progress on one partition while we are awaiting + // a reset on another partition. 
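+        // tp0's reset to the latest offset (50) succeeds while tp1's reset fails with
+        // NOT_LEADER_OR_FOLLOWER, so the poll below should return records for tp0 only.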
+ + Time time = new MockTime(); + SubscriptionState subscription = new SubscriptionState(new LogContext(), OffsetResetStrategy.EARLIEST); + ConsumerMetadata metadata = createMetadata(subscription); + MockClient client = new MockClient(time, metadata); + initMetadata(client, Collections.singletonMap(topic, 2)); + + KafkaConsumer consumer = newConsumerNoAutoCommit(time, client, subscription, metadata); + consumer.assign(Arrays.asList(tp0, tp1)); + consumer.seekToEnd(singleton(tp0)); + consumer.seekToBeginning(singleton(tp1)); + + client.prepareResponse(body -> { + ListOffsetsRequest request = (ListOffsetsRequest) body; + List partitions = request.topics().stream().flatMap(t -> { + if (t.name().equals(topic)) + return Stream.of(t.partitions()); + else + return Stream.empty(); + }).flatMap(List::stream).collect(Collectors.toList()); + ListOffsetsPartition expectedTp0 = new ListOffsetsPartition() + .setPartitionIndex(tp0.partition()) + .setTimestamp(ListOffsetsRequest.LATEST_TIMESTAMP); + ListOffsetsPartition expectedTp1 = new ListOffsetsPartition() + .setPartitionIndex(tp1.partition()) + .setTimestamp(ListOffsetsRequest.EARLIEST_TIMESTAMP); + return partitions.contains(expectedTp0) && partitions.contains(expectedTp1); + }, listOffsetsResponse(Collections.singletonMap(tp0, 50L), Collections.singletonMap(tp1, Errors.NOT_LEADER_OR_FOLLOWER))); + client.prepareResponse( + body -> { + FetchRequest request = (FetchRequest) body; + Map fetchData = request.fetchData(topicNames); + TopicIdPartition tidp0 = new TopicIdPartition(topicIds.get(tp0.topic()), tp0); + return fetchData.keySet().equals(singleton(tidp0)) && + fetchData.get(tidp0).fetchOffset == 50L; + + }, fetchResponse(tp0, 50L, 5)); + + ConsumerRecords records = consumer.poll(Duration.ofMillis(1)); + assertEquals(5, records.count()); + assertEquals(singleton(tp0), records.partitions()); + } + + private void initMetadata(MockClient mockClient, Map partitionCounts) { + Map metadataIds = new HashMap<>(); + for (String name : partitionCounts.keySet()) { + metadataIds.put(name, topicIds.get(name)); + } + MetadataResponse initialMetadata = RequestTestUtils.metadataUpdateWithIds(1, partitionCounts, metadataIds); + + mockClient.updateMetadata(initialMetadata); + } + + @Test + public void testMissingOffsetNoResetPolicy() { + Time time = new MockTime(); + SubscriptionState subscription = new SubscriptionState(new LogContext(), OffsetResetStrategy.NONE); + ConsumerMetadata metadata = createMetadata(subscription); + MockClient client = new MockClient(time, metadata); + + initMetadata(client, Collections.singletonMap(topic, 1)); + Node node = metadata.fetch().nodes().get(0); + + ConsumerPartitionAssignor assignor = new RoundRobinAssignor(); + + KafkaConsumer consumer = newConsumer(time, client, subscription, metadata, assignor, + true, groupId, groupInstanceId, false); + consumer.assign(singletonList(tp0)); + + client.prepareResponseFrom(FindCoordinatorResponse.prepareResponse(Errors.NONE, groupId, node), node); + Node coordinator = new Node(Integer.MAX_VALUE - node.id(), node.host(), node.port()); + + // lookup committed offset and find nothing + client.prepareResponseFrom(offsetResponse(Collections.singletonMap(tp0, -1L), Errors.NONE), coordinator); + assertThrows(NoOffsetForPartitionException.class, () -> consumer.poll(Duration.ZERO)); + } + + @Test + public void testResetToCommittedOffset() { + Time time = new MockTime(); + SubscriptionState subscription = new SubscriptionState(new LogContext(), OffsetResetStrategy.NONE); + ConsumerMetadata metadata = 
createMetadata(subscription); + MockClient client = new MockClient(time, metadata); + + initMetadata(client, Collections.singletonMap(topic, 1)); + Node node = metadata.fetch().nodes().get(0); + + ConsumerPartitionAssignor assignor = new RoundRobinAssignor(); + + KafkaConsumer consumer = newConsumer(time, client, subscription, metadata, assignor, + true, groupId, groupInstanceId, false); + consumer.assign(singletonList(tp0)); + + client.prepareResponseFrom(FindCoordinatorResponse.prepareResponse(Errors.NONE, groupId, node), node); + Node coordinator = new Node(Integer.MAX_VALUE - node.id(), node.host(), node.port()); + + client.prepareResponseFrom(offsetResponse(Collections.singletonMap(tp0, 539L), Errors.NONE), coordinator); + consumer.poll(Duration.ZERO); + + assertEquals(539L, consumer.position(tp0)); + } + + @Test + public void testResetUsingAutoResetPolicy() { + Time time = new MockTime(); + SubscriptionState subscription = new SubscriptionState(new LogContext(), OffsetResetStrategy.LATEST); + ConsumerMetadata metadata = createMetadata(subscription); + MockClient client = new MockClient(time, metadata); + + initMetadata(client, Collections.singletonMap(topic, 1)); + Node node = metadata.fetch().nodes().get(0); + + ConsumerPartitionAssignor assignor = new RoundRobinAssignor(); + + KafkaConsumer consumer = newConsumer(time, client, subscription, metadata, assignor, + true, groupId, groupInstanceId, false); + consumer.assign(singletonList(tp0)); + + client.prepareResponseFrom(FindCoordinatorResponse.prepareResponse(Errors.NONE, groupId, node), node); + Node coordinator = new Node(Integer.MAX_VALUE - node.id(), node.host(), node.port()); + + client.prepareResponseFrom(offsetResponse(Collections.singletonMap(tp0, -1L), Errors.NONE), coordinator); + client.prepareResponse(listOffsetsResponse(Collections.singletonMap(tp0, 50L))); + + consumer.poll(Duration.ZERO); + + assertEquals(50L, consumer.position(tp0)); + } + + @Test + public void testOffsetIsValidAfterSeek() { + Time time = new MockTime(); + SubscriptionState subscription = new SubscriptionState(new LogContext(), OffsetResetStrategy.LATEST); + ConsumerMetadata metadata = createMetadata(subscription); + MockClient client = new MockClient(time, metadata); + + initMetadata(client, Collections.singletonMap(topic, 1)); + + ConsumerPartitionAssignor assignor = new RoundRobinAssignor(); + + KafkaConsumer consumer = newConsumer(time, client, subscription, metadata, assignor, + true, groupId, Optional.empty(), false); + consumer.assign(singletonList(tp0)); + consumer.seek(tp0, 20L); + consumer.poll(Duration.ZERO); + assertEquals(subscription.validPosition(tp0).offset, 20L); + } + + @Test + public void testCommitsFetchedDuringAssign() { + long offset1 = 10000; + long offset2 = 20000; + + Time time = new MockTime(); + SubscriptionState subscription = new SubscriptionState(new LogContext(), OffsetResetStrategy.EARLIEST); + ConsumerMetadata metadata = createMetadata(subscription); + MockClient client = new MockClient(time, metadata); + + initMetadata(client, Collections.singletonMap(topic, 2)); + Node node = metadata.fetch().nodes().get(0); + + ConsumerPartitionAssignor assignor = new RoundRobinAssignor(); + + KafkaConsumer consumer = newConsumer(time, client, subscription, metadata, assignor, true, groupInstanceId); + consumer.assign(singletonList(tp0)); + + // lookup coordinator + client.prepareResponseFrom(FindCoordinatorResponse.prepareResponse(Errors.NONE, groupId, node), node); + Node coordinator = new Node(Integer.MAX_VALUE - node.id(), 
node.host(), node.port()); + + // fetch offset for one topic + client.prepareResponseFrom(offsetResponse(Collections.singletonMap(tp0, offset1), Errors.NONE), coordinator); + assertEquals(offset1, consumer.committed(Collections.singleton(tp0)).get(tp0).offset()); + + consumer.assign(Arrays.asList(tp0, tp1)); + + // fetch offset for two topics + Map offsets = new HashMap<>(); + offsets.put(tp0, offset1); + client.prepareResponseFrom(offsetResponse(offsets, Errors.NONE), coordinator); + assertEquals(offset1, consumer.committed(Collections.singleton(tp0)).get(tp0).offset()); + + offsets.remove(tp0); + offsets.put(tp1, offset2); + client.prepareResponseFrom(offsetResponse(offsets, Errors.NONE), coordinator); + assertEquals(offset2, consumer.committed(Collections.singleton(tp1)).get(tp1).offset()); + consumer.close(Duration.ofMillis(0)); + } + + @Test + public void testFetchStableOffsetThrowInCommitted() { + assertThrows(UnsupportedVersionException.class, () -> setupThrowableConsumer().committed(Collections.singleton(tp0))); + } + + @Test + public void testFetchStableOffsetThrowInPoll() { + assertThrows(UnsupportedVersionException.class, () -> setupThrowableConsumer().poll(Duration.ZERO)); + } + + @Test + public void testFetchStableOffsetThrowInPosition() { + assertThrows(UnsupportedVersionException.class, () -> setupThrowableConsumer().position(tp0)); + } + + private KafkaConsumer setupThrowableConsumer() { + long offset1 = 10000; + + Time time = new MockTime(); + SubscriptionState subscription = new SubscriptionState(new LogContext(), OffsetResetStrategy.EARLIEST); + ConsumerMetadata metadata = createMetadata(subscription); + MockClient client = new MockClient(time, metadata); + + initMetadata(client, Collections.singletonMap(topic, 2)); + client.setNodeApiVersions(NodeApiVersions.create(ApiKeys.OFFSET_FETCH.id, (short) 0, (short) 6)); + + Node node = metadata.fetch().nodes().get(0); + + ConsumerPartitionAssignor assignor = new RoundRobinAssignor(); + + KafkaConsumer consumer = newConsumer( + time, client, subscription, metadata, assignor, true, groupId, groupInstanceId, true); + consumer.assign(singletonList(tp0)); + + client.prepareResponseFrom(FindCoordinatorResponse.prepareResponse(Errors.NONE, groupId, node), node); + Node coordinator = new Node(Integer.MAX_VALUE - node.id(), node.host(), node.port()); + + client.prepareResponseFrom(offsetResponse( + Collections.singletonMap(tp0, offset1), Errors.NONE), coordinator); + return consumer; + } + + @Test + public void testNoCommittedOffsets() { + long offset1 = 10000; + + Time time = new MockTime(); + SubscriptionState subscription = new SubscriptionState(new LogContext(), OffsetResetStrategy.EARLIEST); + ConsumerMetadata metadata = createMetadata(subscription); + MockClient client = new MockClient(time, metadata); + + initMetadata(client, Collections.singletonMap(topic, 2)); + Node node = metadata.fetch().nodes().get(0); + + ConsumerPartitionAssignor assignor = new RoundRobinAssignor(); + + KafkaConsumer consumer = newConsumer(time, client, subscription, metadata, assignor, true, groupInstanceId); + consumer.assign(Arrays.asList(tp0, tp1)); + + // lookup coordinator + client.prepareResponseFrom(FindCoordinatorResponse.prepareResponse(Errors.NONE, groupId, node), node); + Node coordinator = new Node(Integer.MAX_VALUE - node.id(), node.host(), node.port()); + + // fetch offset for one topic + client.prepareResponseFrom(offsetResponse(Utils.mkMap(Utils.mkEntry(tp0, offset1), Utils.mkEntry(tp1, -1L)), Errors.NONE), coordinator); + final Map 
committed = consumer.committed(Utils.mkSet(tp0, tp1)); + assertEquals(2, committed.size()); + assertEquals(offset1, committed.get(tp0).offset()); + assertNull(committed.get(tp1)); + + consumer.close(Duration.ofMillis(0)); + } + + @Test + public void testAutoCommitSentBeforePositionUpdate() { + Time time = new MockTime(); + SubscriptionState subscription = new SubscriptionState(new LogContext(), OffsetResetStrategy.EARLIEST); + ConsumerMetadata metadata = createMetadata(subscription); + MockClient client = new MockClient(time, metadata); + + initMetadata(client, Collections.singletonMap(topic, 1)); + Node node = metadata.fetch().nodes().get(0); + + ConsumerPartitionAssignor assignor = new RoundRobinAssignor(); + + KafkaConsumer consumer = newConsumer(time, client, subscription, metadata, assignor, true, groupInstanceId); + consumer.subscribe(singleton(topic), getConsumerRebalanceListener(consumer)); + Node coordinator = prepareRebalance(client, node, assignor, singletonList(tp0), null); + + consumer.updateAssignmentMetadataIfNeeded(time.timer(Long.MAX_VALUE)); + consumer.poll(Duration.ZERO); + + // respond to the outstanding fetch so that we have data available on the next poll + client.respondFrom(fetchResponse(tp0, 0, 5), node); + client.poll(0, time.milliseconds()); + + time.sleep(autoCommitIntervalMs); + + client.prepareResponseFrom(fetchResponse(tp0, 5, 0), node); + + // no data has been returned to the user yet, so the committed offset should be 0 + AtomicBoolean commitReceived = prepareOffsetCommitResponse(client, coordinator, tp0, 0); + + consumer.poll(Duration.ZERO); + + assertTrue(commitReceived.get()); + consumer.close(Duration.ofMillis(0)); + } + + @Test + public void testRegexSubscription() { + String unmatchedTopic = "unmatched"; + Time time = new MockTime(); + SubscriptionState subscription = new SubscriptionState(new LogContext(), OffsetResetStrategy.EARLIEST); + ConsumerMetadata metadata = createMetadata(subscription); + MockClient client = new MockClient(time, metadata); + + Map partitionCounts = new HashMap<>(); + partitionCounts.put(topic, 1); + partitionCounts.put(unmatchedTopic, 1); + topicIds.put(unmatchedTopic, Uuid.randomUuid()); + initMetadata(client, partitionCounts); + Node node = metadata.fetch().nodes().get(0); + + ConsumerPartitionAssignor assignor = new RoundRobinAssignor(); + + KafkaConsumer consumer = newConsumer(time, client, subscription, metadata, assignor, true, groupInstanceId); + prepareRebalance(client, node, singleton(topic), assignor, singletonList(tp0), null); + + consumer.subscribe(Pattern.compile(topic), getConsumerRebalanceListener(consumer)); + + client.prepareMetadataUpdate(RequestTestUtils.metadataUpdateWithIds(1, partitionCounts, topicIds)); + + consumer.updateAssignmentMetadataIfNeeded(time.timer(Long.MAX_VALUE)); + + assertEquals(singleton(topic), consumer.subscription()); + assertEquals(singleton(tp0), consumer.assignment()); + consumer.close(Duration.ofMillis(0)); + } + + @Test + public void testChangingRegexSubscription() { + ConsumerPartitionAssignor assignor = new RoundRobinAssignor(); + + String otherTopic = "other"; + TopicPartition otherTopicPartition = new TopicPartition(otherTopic, 0); + + Time time = new MockTime(); + SubscriptionState subscription = new SubscriptionState(new LogContext(), OffsetResetStrategy.EARLIEST); + ConsumerMetadata metadata = createMetadata(subscription); + MockClient client = new MockClient(time, metadata); + + Map partitionCounts = new HashMap<>(); + partitionCounts.put(topic, 1); + 
partitionCounts.put(otherTopic, 1); + topicIds.put(otherTopic, Uuid.randomUuid()); + initMetadata(client, partitionCounts); + Node node = metadata.fetch().nodes().get(0); + + KafkaConsumer consumer = newConsumer(time, client, subscription, metadata, assignor, false, groupInstanceId); + + Node coordinator = prepareRebalance(client, node, singleton(topic), assignor, singletonList(tp0), null); + consumer.subscribe(Pattern.compile(topic), getConsumerRebalanceListener(consumer)); + + consumer.updateAssignmentMetadataIfNeeded(time.timer(Long.MAX_VALUE)); + consumer.poll(Duration.ZERO); + + assertEquals(singleton(topic), consumer.subscription()); + + consumer.subscribe(Pattern.compile(otherTopic), getConsumerRebalanceListener(consumer)); + + client.prepareMetadataUpdate(RequestTestUtils.metadataUpdateWithIds(1, partitionCounts, topicIds)); + prepareRebalance(client, node, singleton(otherTopic), assignor, singletonList(otherTopicPartition), coordinator); + consumer.poll(Duration.ZERO); + + assertEquals(singleton(otherTopic), consumer.subscription()); + consumer.close(Duration.ofMillis(0)); + } + + @Test + public void testWakeupWithFetchDataAvailable() throws Exception { + final Time time = new MockTime(); + SubscriptionState subscription = new SubscriptionState(new LogContext(), OffsetResetStrategy.EARLIEST); + ConsumerMetadata metadata = createMetadata(subscription); + MockClient client = new MockClient(time, metadata); + + initMetadata(client, Collections.singletonMap(topic, 1)); + Node node = metadata.fetch().nodes().get(0); + + ConsumerPartitionAssignor assignor = new RoundRobinAssignor(); + + KafkaConsumer consumer = newConsumer(time, client, subscription, metadata, assignor, true, groupInstanceId); + consumer.subscribe(singleton(topic), getConsumerRebalanceListener(consumer)); + prepareRebalance(client, node, assignor, singletonList(tp0), null); + + consumer.updateAssignmentMetadataIfNeeded(time.timer(Long.MAX_VALUE)); + consumer.poll(Duration.ZERO); + + // respond to the outstanding fetch so that we have data available on the next poll + client.respondFrom(fetchResponse(tp0, 0, 5), node); + client.poll(0, time.milliseconds()); + + consumer.wakeup(); + + assertThrows(WakeupException.class, () -> consumer.poll(Duration.ZERO)); + + // make sure the position hasn't been updated + assertEquals(0, consumer.position(tp0)); + + // the next poll should return the completed fetch + ConsumerRecords records = consumer.poll(Duration.ZERO); + assertEquals(5, records.count()); + // Increment time asynchronously to clear timeouts in closing the consumer + final ScheduledExecutorService exec = Executors.newSingleThreadScheduledExecutor(); + exec.scheduleAtFixedRate(() -> time.sleep(sessionTimeoutMs), 0L, 10L, TimeUnit.MILLISECONDS); + consumer.close(); + exec.shutdownNow(); + exec.awaitTermination(5L, TimeUnit.SECONDS); + } + + @Test + public void testPollThrowsInterruptExceptionIfInterrupted() { + final Time time = new MockTime(); + final SubscriptionState subscription = new SubscriptionState(new LogContext(), OffsetResetStrategy.EARLIEST); + final ConsumerMetadata metadata = createMetadata(subscription); + final MockClient client = new MockClient(time, metadata); + + initMetadata(client, Collections.singletonMap(topic, 1)); + Node node = metadata.fetch().nodes().get(0); + + final ConsumerPartitionAssignor assignor = new RoundRobinAssignor(); + + KafkaConsumer consumer = newConsumer(time, client, subscription, metadata, assignor, false, groupInstanceId); + consumer.subscribe(singleton(topic), 
getConsumerRebalanceListener(consumer)); + prepareRebalance(client, node, assignor, singletonList(tp0), null); + + consumer.updateAssignmentMetadataIfNeeded(time.timer(Long.MAX_VALUE)); + consumer.poll(Duration.ZERO); + + // interrupt the thread and call poll + try { + Thread.currentThread().interrupt(); + assertThrows(InterruptException.class, () -> consumer.poll(Duration.ZERO)); + } finally { + // clear interrupted state again since this thread may be reused by JUnit + Thread.interrupted(); + consumer.close(Duration.ofMillis(0)); + } + } + + @Test + public void fetchResponseWithUnexpectedPartitionIsIgnored() { + Time time = new MockTime(); + SubscriptionState subscription = new SubscriptionState(new LogContext(), OffsetResetStrategy.EARLIEST); + ConsumerMetadata metadata = createMetadata(subscription); + MockClient client = new MockClient(time, metadata); + + initMetadata(client, Collections.singletonMap(topic, 1)); + Node node = metadata.fetch().nodes().get(0); + + ConsumerPartitionAssignor assignor = new RangeAssignor(); + + KafkaConsumer consumer = newConsumer(time, client, subscription, metadata, assignor, true, groupInstanceId); + consumer.subscribe(singletonList(topic), getConsumerRebalanceListener(consumer)); + + prepareRebalance(client, node, assignor, singletonList(tp0), null); + + Map fetches1 = new HashMap<>(); + fetches1.put(tp0, new FetchInfo(0, 1)); + fetches1.put(t2p0, new FetchInfo(0, 10)); // not assigned and not fetched + client.prepareResponseFrom(fetchResponse(fetches1), node); + + consumer.updateAssignmentMetadataIfNeeded(time.timer(Long.MAX_VALUE)); + + ConsumerRecords records = consumer.poll(Duration.ZERO); + assertEquals(0, records.count()); + consumer.close(Duration.ofMillis(0)); + } + + /** + * Verify that when a consumer changes its topic subscription its assigned partitions + * do not immediately change, and the latest consumed offsets of its to-be-revoked + * partitions are properly committed (when auto-commit is enabled). + * Upon unsubscribing from subscribed topics the consumer subscription and assignment + * are both updated right away but its consumed offsets are not auto committed. 
+ */ + @Test + public void testSubscriptionChangesWithAutoCommitEnabled() { + Time time = new MockTime(); + SubscriptionState subscription = new SubscriptionState(new LogContext(), OffsetResetStrategy.EARLIEST); + ConsumerMetadata metadata = createMetadata(subscription); + MockClient client = new MockClient(time, metadata); + + Map tpCounts = new HashMap<>(); + tpCounts.put(topic, 1); + tpCounts.put(topic2, 1); + tpCounts.put(topic3, 1); + initMetadata(client, tpCounts); + Node node = metadata.fetch().nodes().get(0); + + ConsumerPartitionAssignor assignor = new RangeAssignor(); + + KafkaConsumer consumer = newConsumer(time, client, subscription, metadata, assignor, true, groupInstanceId); + + // initial subscription + consumer.subscribe(Arrays.asList(topic, topic2), getConsumerRebalanceListener(consumer)); + + // verify that subscription has changed but assignment is still unchanged + assertEquals(2, consumer.subscription().size()); + assertTrue(consumer.subscription().contains(topic) && consumer.subscription().contains(topic2)); + assertTrue(consumer.assignment().isEmpty()); + + // mock rebalance responses + Node coordinator = prepareRebalance(client, node, assignor, Arrays.asList(tp0, t2p0), null); + + consumer.updateAssignmentMetadataIfNeeded(time.timer(Long.MAX_VALUE)); + consumer.poll(Duration.ZERO); + + // verify that subscription is still the same, and now assignment has caught up + assertEquals(2, consumer.subscription().size()); + assertTrue(consumer.subscription().contains(topic) && consumer.subscription().contains(topic2)); + assertEquals(2, consumer.assignment().size()); + assertTrue(consumer.assignment().contains(tp0) && consumer.assignment().contains(t2p0)); + + // mock a response to the outstanding fetch so that we have data available on the next poll + Map fetches1 = new HashMap<>(); + fetches1.put(tp0, new FetchInfo(0, 1)); + fetches1.put(t2p0, new FetchInfo(0, 10)); + client.respondFrom(fetchResponse(fetches1), node); + client.poll(0, time.milliseconds()); + + ConsumerRecords records = consumer.poll(Duration.ofMillis(1)); + + // clear out the prefetch so it doesn't interfere with the rest of the test + fetches1.put(tp0, new FetchInfo(1, 0)); + fetches1.put(t2p0, new FetchInfo(10, 0)); + client.respondFrom(fetchResponse(fetches1), node); + client.poll(0, time.milliseconds()); + + // verify that the fetch occurred as expected + assertEquals(11, records.count()); + assertEquals(1L, consumer.position(tp0)); + assertEquals(10L, consumer.position(t2p0)); + + // subscription change + consumer.subscribe(Arrays.asList(topic, topic3), getConsumerRebalanceListener(consumer)); + + // verify that subscription has changed but assignment is still unchanged + assertEquals(2, consumer.subscription().size()); + assertTrue(consumer.subscription().contains(topic) && consumer.subscription().contains(topic3)); + assertEquals(2, consumer.assignment().size()); + assertTrue(consumer.assignment().contains(tp0) && consumer.assignment().contains(t2p0)); + + // mock the offset commit response for to be revoked partitions + Map partitionOffsets1 = new HashMap<>(); + partitionOffsets1.put(tp0, 1L); + partitionOffsets1.put(t2p0, 10L); + AtomicBoolean commitReceived = prepareOffsetCommitResponse(client, coordinator, partitionOffsets1); + + // mock rebalance responses + prepareRebalance(client, node, assignor, Arrays.asList(tp0, t3p0), coordinator); + + // mock a response to the next fetch from the new assignment + Map fetches2 = new HashMap<>(); + fetches2.put(tp0, new FetchInfo(1, 1)); + 
fetches2.put(t3p0, new FetchInfo(0, 100)); + client.prepareResponse(fetchResponse(fetches2)); + + records = consumer.poll(Duration.ofMillis(1)); + + // verify that the fetch occurred as expected + assertEquals(101, records.count()); + assertEquals(2L, consumer.position(tp0)); + assertEquals(100L, consumer.position(t3p0)); + + // verify that the offset commits occurred as expected + assertTrue(commitReceived.get()); + + // verify that subscription is still the same, and now assignment has caught up + assertEquals(2, consumer.subscription().size()); + assertTrue(consumer.subscription().contains(topic) && consumer.subscription().contains(topic3)); + assertEquals(2, consumer.assignment().size()); + assertTrue(consumer.assignment().contains(tp0) && consumer.assignment().contains(t3p0)); + + consumer.unsubscribe(); + + // verify that subscription and assignment are both cleared + assertTrue(consumer.subscription().isEmpty()); + assertTrue(consumer.assignment().isEmpty()); + + client.requests().clear(); + consumer.close(); + } + + /** + * Verify that when a consumer changes its topic subscription its assigned partitions + * do not immediately change, and the consumed offsets of its to-be-revoked partitions + * are not committed (when auto-commit is disabled). + * Upon unsubscribing from subscribed topics, the assigned partitions immediately + * change but if auto-commit is disabled the consumer offsets are not committed. + */ + @Test + public void testSubscriptionChangesWithAutoCommitDisabled() { + Time time = new MockTime(); + SubscriptionState subscription = new SubscriptionState(new LogContext(), OffsetResetStrategy.EARLIEST); + ConsumerMetadata metadata = createMetadata(subscription); + MockClient client = new MockClient(time, metadata); + + Map tpCounts = new HashMap<>(); + tpCounts.put(topic, 1); + tpCounts.put(topic2, 1); + initMetadata(client, tpCounts); + Node node = metadata.fetch().nodes().get(0); + + ConsumerPartitionAssignor assignor = new RangeAssignor(); + + KafkaConsumer consumer = newConsumer(time, client, subscription, metadata, assignor, false, groupInstanceId); + + initializeSubscriptionWithSingleTopic(consumer, getConsumerRebalanceListener(consumer)); + + // mock rebalance responses + prepareRebalance(client, node, assignor, singletonList(tp0), null); + + consumer.updateAssignmentMetadataIfNeeded(time.timer(Long.MAX_VALUE)); + consumer.poll(Duration.ZERO); + + // verify that subscription is still the same, and now assignment has caught up + assertEquals(singleton(topic), consumer.subscription()); + assertEquals(singleton(tp0), consumer.assignment()); + + consumer.poll(Duration.ZERO); + + // subscription change + consumer.subscribe(singleton(topic2), getConsumerRebalanceListener(consumer)); + + // verify that subscription has changed but assignment is still unchanged + assertEquals(singleton(topic2), consumer.subscription()); + assertEquals(singleton(tp0), consumer.assignment()); + + // the auto commit is disabled, so no offset commit request should be sent + for (ClientRequest req: client.requests()) + assertNotSame(ApiKeys.OFFSET_COMMIT, req.requestBuilder().apiKey()); + + // subscription change + consumer.unsubscribe(); + + // verify that subscription and assignment are both updated + assertEquals(Collections.emptySet(), consumer.subscription()); + assertEquals(Collections.emptySet(), consumer.assignment()); + + // the auto commit is disabled, so no offset commit request should be sent + for (ClientRequest req: client.requests()) + assertNotSame(ApiKeys.OFFSET_COMMIT, 
req.requestBuilder().apiKey()); + + client.requests().clear(); + consumer.close(); + } + + @Test + public void testUnsubscribeShouldTriggerPartitionsRevokedWithValidGeneration() { + Time time = new MockTime(); + SubscriptionState subscription = new SubscriptionState(new LogContext(), OffsetResetStrategy.EARLIEST); + ConsumerMetadata metadata = createMetadata(subscription); + MockClient client = new MockClient(time, metadata); + + initMetadata(client, Collections.singletonMap(topic, 1)); + Node node = metadata.fetch().nodes().get(0); + + CooperativeStickyAssignor assignor = new CooperativeStickyAssignor(); + KafkaConsumer consumer = newConsumer(time, client, subscription, metadata, assignor, false, groupInstanceId); + + initializeSubscriptionWithSingleTopic(consumer, getExceptionConsumerRebalanceListener()); + + prepareRebalance(client, node, assignor, singletonList(tp0), null); + + RuntimeException assignmentException = assertThrows(RuntimeException.class, + () -> consumer.updateAssignmentMetadataIfNeeded(time.timer(Long.MAX_VALUE))); + assertEquals(partitionAssigned + singleTopicPartition, assignmentException.getCause().getMessage()); + + RuntimeException unsubscribeException = assertThrows(RuntimeException.class, consumer::unsubscribe); + assertEquals(partitionRevoked + singleTopicPartition, unsubscribeException.getCause().getMessage()); + } + + @Test + public void testUnsubscribeShouldTriggerPartitionsLostWithNoGeneration() throws Exception { + Time time = new MockTime(); + SubscriptionState subscription = new SubscriptionState(new LogContext(), OffsetResetStrategy.EARLIEST); + ConsumerMetadata metadata = createMetadata(subscription); + MockClient client = new MockClient(time, metadata); + + initMetadata(client, Collections.singletonMap(topic, 1)); + Node node = metadata.fetch().nodes().get(0); + + CooperativeStickyAssignor assignor = new CooperativeStickyAssignor(); + KafkaConsumer consumer = newConsumer(time, client, subscription, metadata, assignor, false, groupInstanceId); + + initializeSubscriptionWithSingleTopic(consumer, getExceptionConsumerRebalanceListener()); + Node coordinator = prepareRebalance(client, node, assignor, singletonList(tp0), null); + + RuntimeException assignException = assertThrows(RuntimeException.class, + () -> consumer.updateAssignmentMetadataIfNeeded(time.timer(Long.MAX_VALUE))); + assertEquals(partitionAssigned + singleTopicPartition, assignException.getCause().getMessage()); + + AtomicBoolean heartbeatReceived = prepareHeartbeatResponse(client, coordinator, Errors.UNKNOWN_MEMBER_ID); + + time.sleep(heartbeatIntervalMs); + TestUtils.waitForCondition(heartbeatReceived::get, "Heartbeat response did not occur within timeout."); + + RuntimeException unsubscribeException = assertThrows(RuntimeException.class, consumer::unsubscribe); + assertEquals(partitionLost + singleTopicPartition, unsubscribeException.getCause().getMessage()); + } + + private void initializeSubscriptionWithSingleTopic(KafkaConsumer consumer, + ConsumerRebalanceListener consumerRebalanceListener) { + consumer.subscribe(singleton(topic), consumerRebalanceListener); + // verify that subscription has changed but assignment is still unchanged + assertEquals(singleton(topic), consumer.subscription()); + assertEquals(Collections.emptySet(), consumer.assignment()); + } + + @Test + public void testManualAssignmentChangeWithAutoCommitEnabled() { + Time time = new MockTime(); + SubscriptionState subscription = new SubscriptionState(new LogContext(), OffsetResetStrategy.EARLIEST); + ConsumerMetadata 
metadata = createMetadata(subscription); + MockClient client = new MockClient(time, metadata); + + Map tpCounts = new HashMap<>(); + tpCounts.put(topic, 1); + tpCounts.put(topic2, 1); + initMetadata(client, tpCounts); + Node node = metadata.fetch().nodes().get(0); + + ConsumerPartitionAssignor assignor = new RangeAssignor(); + + KafkaConsumer consumer = newConsumer(time, client, subscription, metadata, assignor, true, groupInstanceId); + + // lookup coordinator + client.prepareResponseFrom(FindCoordinatorResponse.prepareResponse(Errors.NONE, groupId, node), node); + Node coordinator = new Node(Integer.MAX_VALUE - node.id(), node.host(), node.port()); + + // manual assignment + consumer.assign(singleton(tp0)); + consumer.seekToBeginning(singleton(tp0)); + + // fetch offset for one topic + client.prepareResponseFrom(offsetResponse(Collections.singletonMap(tp0, 0L), Errors.NONE), coordinator); + assertEquals(0, consumer.committed(Collections.singleton(tp0)).get(tp0).offset()); + + // verify that assignment immediately changes + assertEquals(consumer.assignment(), singleton(tp0)); + + // there shouldn't be any need to lookup the coordinator or fetch committed offsets. + // we just lookup the starting position and send the record fetch. + client.prepareResponse(listOffsetsResponse(Collections.singletonMap(tp0, 10L))); + client.prepareResponse(fetchResponse(tp0, 10L, 1)); + + ConsumerRecords records = consumer.poll(Duration.ofMillis(1)); + + assertEquals(1, records.count()); + assertEquals(11L, consumer.position(tp0)); + + // mock the offset commit response for to be revoked partitions + AtomicBoolean commitReceived = prepareOffsetCommitResponse(client, coordinator, tp0, 11); + + // new manual assignment + consumer.assign(singleton(t2p0)); + + // verify that assignment immediately changes + assertEquals(consumer.assignment(), singleton(t2p0)); + // verify that the offset commits occurred as expected + assertTrue(commitReceived.get()); + + client.requests().clear(); + consumer.close(); + } + + @Test + public void testManualAssignmentChangeWithAutoCommitDisabled() { + Time time = new MockTime(); + SubscriptionState subscription = new SubscriptionState(new LogContext(), OffsetResetStrategy.EARLIEST); + ConsumerMetadata metadata = createMetadata(subscription); + MockClient client = new MockClient(time, metadata); + + Map tpCounts = new HashMap<>(); + tpCounts.put(topic, 1); + tpCounts.put(topic2, 1); + initMetadata(client, tpCounts); + Node node = metadata.fetch().nodes().get(0); + + ConsumerPartitionAssignor assignor = new RangeAssignor(); + + KafkaConsumer consumer = newConsumer(time, client, subscription, metadata, assignor, false, groupInstanceId); + + // lookup coordinator + client.prepareResponseFrom(FindCoordinatorResponse.prepareResponse(Errors.NONE, groupId, node), node); + Node coordinator = new Node(Integer.MAX_VALUE - node.id(), node.host(), node.port()); + + // manual assignment + consumer.assign(singleton(tp0)); + consumer.seekToBeginning(singleton(tp0)); + + // fetch offset for one topic + client.prepareResponseFrom( + offsetResponse(Collections.singletonMap(tp0, 0L), Errors.NONE), + coordinator); + assertEquals(0, consumer.committed(Collections.singleton(tp0)).get(tp0).offset()); + + // verify that assignment immediately changes + assertEquals(consumer.assignment(), singleton(tp0)); + + // there shouldn't be any need to lookup the coordinator or fetch committed offsets. + // we just lookup the starting position and send the record fetch. 
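+        // The prepared list-offsets response places tp0 at offset 10 and the single-record fetch
+        // advances the position to 11.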
+ client.prepareResponse(listOffsetsResponse(Collections.singletonMap(tp0, 10L))); + client.prepareResponse(fetchResponse(tp0, 10L, 1)); + + ConsumerRecords records = consumer.poll(Duration.ofMillis(1)); + assertEquals(1, records.count()); + assertEquals(11L, consumer.position(tp0)); + + // new manual assignment + consumer.assign(singleton(t2p0)); + + // verify that assignment immediately changes + assertEquals(consumer.assignment(), singleton(t2p0)); + + // the auto commit is disabled, so no offset commit request should be sent + for (ClientRequest req : client.requests()) + assertNotSame(req.requestBuilder().apiKey(), ApiKeys.OFFSET_COMMIT); + + client.requests().clear(); + consumer.close(); + } + + @Test + public void testOffsetOfPausedPartitions() { + Time time = new MockTime(); + SubscriptionState subscription = new SubscriptionState(new LogContext(), OffsetResetStrategy.EARLIEST); + ConsumerMetadata metadata = createMetadata(subscription); + MockClient client = new MockClient(time, metadata); + + initMetadata(client, Collections.singletonMap(topic, 2)); + Node node = metadata.fetch().nodes().get(0); + + ConsumerPartitionAssignor assignor = new RangeAssignor(); + + KafkaConsumer consumer = newConsumer(time, client, subscription, metadata, assignor, true, groupInstanceId); + + // lookup coordinator + client.prepareResponseFrom(FindCoordinatorResponse.prepareResponse(Errors.NONE, groupId, node), node); + Node coordinator = new Node(Integer.MAX_VALUE - node.id(), node.host(), node.port()); + + // manual assignment + Set partitions = Utils.mkSet(tp0, tp1); + consumer.assign(partitions); + // verify consumer's assignment + assertEquals(partitions, consumer.assignment()); + + consumer.pause(partitions); + consumer.seekToEnd(partitions); + + // fetch and verify committed offset of two partitions + Map offsets = new HashMap<>(); + offsets.put(tp0, 0L); + offsets.put(tp1, 0L); + + client.prepareResponseFrom(offsetResponse(offsets, Errors.NONE), coordinator); + assertEquals(0, consumer.committed(Collections.singleton(tp0)).get(tp0).offset()); + + offsets.remove(tp0); + offsets.put(tp1, 0L); + client.prepareResponseFrom(offsetResponse(offsets, Errors.NONE), coordinator); + assertEquals(0, consumer.committed(Collections.singleton(tp1)).get(tp1).offset()); + + // fetch and verify consumer's position in the two partitions + final Map offsetResponse = new HashMap<>(); + offsetResponse.put(tp0, 3L); + offsetResponse.put(tp1, 3L); + client.prepareResponse(listOffsetsResponse(offsetResponse)); + assertEquals(3L, consumer.position(tp0)); + assertEquals(3L, consumer.position(tp1)); + + client.requests().clear(); + consumer.unsubscribe(); + consumer.close(); + } + + @Test + public void testPollWithNoSubscription() { + try (KafkaConsumer consumer = newConsumer((String) null)) { + assertThrows(IllegalStateException.class, () -> consumer.poll(Duration.ZERO)); + } + } + + @Test + public void testPollWithEmptySubscription() { + try (KafkaConsumer consumer = newConsumer(groupId)) { + consumer.subscribe(Collections.emptyList()); + assertThrows(IllegalStateException.class, () -> consumer.poll(Duration.ZERO)); + } + } + + @Test + public void testPollWithEmptyUserAssignment() { + try (KafkaConsumer consumer = newConsumer(groupId)) { + consumer.assign(Collections.emptySet()); + assertThrows(IllegalStateException.class, () -> consumer.poll(Duration.ZERO)); + } + } + + @Test + public void testGracefulClose() throws Exception { + Map response = new HashMap<>(); + response.put(tp0, Errors.NONE); + OffsetCommitResponse 
commitResponse = offsetCommitResponse(response); + LeaveGroupResponse leaveGroupResponse = new LeaveGroupResponse(new LeaveGroupResponseData().setErrorCode(Errors.NONE.code())); + consumerCloseTest(5000, Arrays.asList(commitResponse, leaveGroupResponse), 0, false); + } + + @Test + public void testCloseTimeout() throws Exception { + consumerCloseTest(5000, Collections.emptyList(), 5000, false); + } + + @Test + public void testLeaveGroupTimeout() throws Exception { + Map response = new HashMap<>(); + response.put(tp0, Errors.NONE); + OffsetCommitResponse commitResponse = offsetCommitResponse(response); + consumerCloseTest(5000, singletonList(commitResponse), 5000, false); + } + + @Test + public void testCloseNoWait() throws Exception { + consumerCloseTest(0, Collections.emptyList(), 0, false); + } + + @Test + public void testCloseInterrupt() throws Exception { + consumerCloseTest(Long.MAX_VALUE, Collections.emptyList(), 0, true); + } + + @Test + public void testCloseShouldBeIdempotent() { + KafkaConsumer consumer = newConsumer((String) null); + consumer.close(); + consumer.close(); + consumer.close(); + } + + @Test + public void testOperationsBySubscribingConsumerWithDefaultGroupId() { + try { + newConsumer(null, Optional.of(Boolean.TRUE)); + fail("Expected an InvalidConfigurationException"); + } catch (KafkaException e) { + assertEquals(InvalidConfigurationException.class, e.getCause().getClass()); + } + + try { + newConsumer((String) null).subscribe(Collections.singleton(topic)); + fail("Expected an InvalidGroupIdException"); + } catch (InvalidGroupIdException e) { + // OK, expected + } + + try { + newConsumer((String) null).committed(Collections.singleton(tp0)).get(tp0); + fail("Expected an InvalidGroupIdException"); + } catch (InvalidGroupIdException e) { + // OK, expected + } + + try { + newConsumer((String) null).commitAsync(); + fail("Expected an InvalidGroupIdException"); + } catch (InvalidGroupIdException e) { + // OK, expected + } + + try { + newConsumer((String) null).commitSync(); + fail("Expected an InvalidGroupIdException"); + } catch (InvalidGroupIdException e) { + // OK, expected + } + } + + @Test + public void testOperationsByAssigningConsumerWithDefaultGroupId() { + KafkaConsumer consumer = newConsumer((String) null); + consumer.assign(singleton(tp0)); + + try { + consumer.committed(Collections.singleton(tp0)).get(tp0); + fail("Expected an InvalidGroupIdException"); + } catch (InvalidGroupIdException e) { + // OK, expected + } + + try { + consumer.commitAsync(); + fail("Expected an InvalidGroupIdException"); + } catch (InvalidGroupIdException e) { + // OK, expected + } + + try { + consumer.commitSync(); + fail("Expected an InvalidGroupIdException"); + } catch (InvalidGroupIdException e) { + // OK, expected + } + } + + @Test + public void testMetricConfigRecordingLevel() { + Properties props = new Properties(); + props.put(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9000"); + try (KafkaConsumer consumer = new KafkaConsumer<>(props, new ByteArrayDeserializer(), new ByteArrayDeserializer())) { + assertEquals(Sensor.RecordingLevel.INFO, consumer.metrics.config().recordLevel()); + } + + props.put(ConsumerConfig.METRICS_RECORDING_LEVEL_CONFIG, "DEBUG"); + try (KafkaConsumer consumer = new KafkaConsumer<>(props, new ByteArrayDeserializer(), new ByteArrayDeserializer())) { + assertEquals(Sensor.RecordingLevel.DEBUG, consumer.metrics.config().recordLevel()); + } + } + + @Test + public void testShouldAttemptToRejoinGroupAfterSyncGroupFailed() throws Exception { + Time time = 
new MockTime(); + SubscriptionState subscription = new SubscriptionState(new LogContext(), OffsetResetStrategy.EARLIEST); + ConsumerMetadata metadata = createMetadata(subscription); + MockClient client = new MockClient(time, metadata); + + initMetadata(client, Collections.singletonMap(topic, 1)); + Node node = metadata.fetch().nodes().get(0); + + ConsumerPartitionAssignor assignor = new RoundRobinAssignor(); + + KafkaConsumer consumer = newConsumer(time, client, subscription, metadata, assignor, false, groupInstanceId); + consumer.subscribe(singleton(topic), getConsumerRebalanceListener(consumer)); + client.prepareResponseFrom(FindCoordinatorResponse.prepareResponse(Errors.NONE, groupId, node), node); + Node coordinator = new Node(Integer.MAX_VALUE - node.id(), node.host(), node.port()); + + + client.prepareResponseFrom(joinGroupFollowerResponse(assignor, 1, memberId, leaderId, Errors.NONE), coordinator); + client.prepareResponseFrom(syncGroupResponse(singletonList(tp0), Errors.NONE), coordinator); + + client.prepareResponseFrom(fetchResponse(tp0, 0, 1), node); + client.prepareResponseFrom(fetchResponse(tp0, 1, 0), node); + + consumer.updateAssignmentMetadataIfNeeded(time.timer(Long.MAX_VALUE)); + consumer.poll(Duration.ZERO); + + // heartbeat fails due to rebalance in progress + client.prepareResponseFrom(body -> true, new HeartbeatResponse( + new HeartbeatResponseData().setErrorCode(Errors.REBALANCE_IN_PROGRESS.code())), coordinator); + + // join group + final ByteBuffer byteBuffer = ConsumerProtocol.serializeSubscription(new ConsumerPartitionAssignor.Subscription(singletonList(topic))); + + // This member becomes the leader + final JoinGroupResponse leaderResponse = new JoinGroupResponse( + new JoinGroupResponseData() + .setErrorCode(Errors.NONE.code()) + .setGenerationId(1).setProtocolName(assignor.name()) + .setLeader(memberId).setMemberId(memberId) + .setMembers(Collections.singletonList( + new JoinGroupResponseData.JoinGroupResponseMember() + .setMemberId(memberId) + .setMetadata(byteBuffer.array()) + ) + ) + ); + + client.prepareResponseFrom(leaderResponse, coordinator); + + // sync group fails due to disconnect + client.prepareResponseFrom(syncGroupResponse(singletonList(tp0), Errors.NONE), coordinator, true); + + // should try and find the new coordinator + client.prepareResponseFrom(FindCoordinatorResponse.prepareResponse(Errors.NONE, groupId, node), node); + + // rejoin group + client.prepareResponseFrom(joinGroupFollowerResponse(assignor, 1, memberId, leaderId, Errors.NONE), coordinator); + client.prepareResponseFrom(syncGroupResponse(singletonList(tp0), Errors.NONE), coordinator); + + client.prepareResponseFrom(body -> body instanceof FetchRequest + && ((FetchRequest) body).fetchData(topicNames).containsKey(new TopicIdPartition(topicId, tp0)), fetchResponse(tp0, 1, 1), node); + time.sleep(heartbeatIntervalMs); + Thread.sleep(heartbeatIntervalMs); + consumer.updateAssignmentMetadataIfNeeded(time.timer(Long.MAX_VALUE)); + final ConsumerRecords records = consumer.poll(Duration.ZERO); + assertFalse(records.isEmpty()); + consumer.close(Duration.ofMillis(0)); + } + + private void consumerCloseTest(final long closeTimeoutMs, + List responses, + long waitMs, + boolean interrupt) throws Exception { + Time time = new MockTime(); + SubscriptionState subscription = new SubscriptionState(new LogContext(), OffsetResetStrategy.EARLIEST); + ConsumerMetadata metadata = createMetadata(subscription); + MockClient client = new MockClient(time, metadata); + + initMetadata(client, 
Collections.singletonMap(topic, 1)); + Node node = metadata.fetch().nodes().get(0); + + ConsumerPartitionAssignor assignor = new RoundRobinAssignor(); + + final KafkaConsumer consumer = newConsumer(time, client, subscription, metadata, assignor, false, Optional.empty()); + consumer.subscribe(singleton(topic), getConsumerRebalanceListener(consumer)); + Node coordinator = prepareRebalance(client, node, assignor, singletonList(tp0), null); + + client.prepareMetadataUpdate(RequestTestUtils.metadataUpdateWithIds(1, Collections.singletonMap(topic, 1), topicIds)); + + consumer.updateAssignmentMetadataIfNeeded(time.timer(Long.MAX_VALUE)); + + // Poll with responses + client.prepareResponseFrom(fetchResponse(tp0, 0, 1), node); + client.prepareResponseFrom(fetchResponse(tp0, 1, 0), node); + consumer.poll(Duration.ZERO); + + // Initiate close() after a commit request on another thread. + // Kafka consumer is single-threaded, but the implementation allows calls on a + // different thread as long as the calls are not executed concurrently. So this is safe. + ExecutorService executor = Executors.newSingleThreadExecutor(); + final AtomicReference closeException = new AtomicReference<>(); + try { + Future future = executor.submit(() -> { + consumer.commitAsync(); + try { + consumer.close(Duration.ofMillis(closeTimeoutMs)); + } catch (Exception e) { + closeException.set(e); + } + }); + + // Close task should not complete until commit succeeds or close times out + // if close timeout is not zero. + try { + future.get(100, TimeUnit.MILLISECONDS); + if (closeTimeoutMs != 0) + fail("Close completed without waiting for commit or leave response"); + } catch (TimeoutException e) { + // Expected exception + } + + // Ensure close has started and queued at least one more request after commitAsync + client.waitForRequests(2, 1000); + + // In graceful mode, commit response results in close() completing immediately without a timeout + // In non-graceful mode, close() times out without an exception even though commit response is pending + for (int i = 0; i < responses.size(); i++) { + client.waitForRequests(1, 1000); + client.respondFrom(responses.get(i), coordinator); + if (i != responses.size() - 1) { + try { + future.get(100, TimeUnit.MILLISECONDS); + fail("Close completed without waiting for response"); + } catch (TimeoutException e) { + // Expected exception + } + } + } + + if (waitMs > 0) + time.sleep(waitMs); + if (interrupt) { + assertTrue(future.cancel(true), "Close terminated prematurely"); + + TestUtils.waitForCondition( + () -> closeException.get() != null, "InterruptException did not occur within timeout."); + + assertTrue(closeException.get() instanceof InterruptException, "Expected exception not thrown " + closeException); + } else { + future.get(500, TimeUnit.MILLISECONDS); // Should succeed without TimeoutException or ExecutionException + assertNull(closeException.get(), "Unexpected exception during close"); + } + } finally { + executor.shutdownNow(); + } + } + + @Test + public void testPartitionsForNonExistingTopic() { + Time time = new MockTime(); + SubscriptionState subscription = new SubscriptionState(new LogContext(), OffsetResetStrategy.EARLIEST); + ConsumerMetadata metadata = createMetadata(subscription); + MockClient client = new MockClient(time, metadata); + + initMetadata(client, Collections.singletonMap(topic, 1)); + Cluster cluster = metadata.fetch(); + + MetadataResponse updateResponse = RequestTestUtils.metadataResponse(cluster.nodes(), + cluster.clusterResource().clusterId(), + 
cluster.controller().id(), + Collections.emptyList()); + client.prepareResponse(updateResponse); + + ConsumerPartitionAssignor assignor = new RoundRobinAssignor(); + + KafkaConsumer consumer = newConsumer(time, client, subscription, metadata, assignor, true, groupInstanceId); + assertEquals(Collections.emptyList(), consumer.partitionsFor("non-exist-topic")); + } + + @Test + public void testPartitionsForAuthenticationFailure() { + final KafkaConsumer consumer = consumerWithPendingAuthenticationError(); + assertThrows(AuthenticationException.class, () -> consumer.partitionsFor("some other topic")); + } + + @Test + public void testBeginningOffsetsAuthenticationFailure() { + final KafkaConsumer consumer = consumerWithPendingAuthenticationError(); + assertThrows(AuthenticationException.class, () -> consumer.beginningOffsets(Collections.singleton(tp0))); + } + + @Test + public void testEndOffsetsAuthenticationFailure() { + final KafkaConsumer consumer = consumerWithPendingAuthenticationError(); + assertThrows(AuthenticationException.class, () -> consumer.endOffsets(Collections.singleton(tp0))); + } + + @Test + public void testPollAuthenticationFailure() { + final KafkaConsumer consumer = consumerWithPendingAuthenticationError(); + consumer.subscribe(singleton(topic)); + assertThrows(AuthenticationException.class, () -> consumer.poll(Duration.ZERO)); + } + + @Test + public void testOffsetsForTimesAuthenticationFailure() { + final KafkaConsumer consumer = consumerWithPendingAuthenticationError(); + assertThrows(AuthenticationException.class, () -> consumer.offsetsForTimes(singletonMap(tp0, 0L))); + } + + @Test + public void testCommitSyncAuthenticationFailure() { + final KafkaConsumer consumer = consumerWithPendingAuthenticationError(); + Map offsets = new HashMap<>(); + offsets.put(tp0, new OffsetAndMetadata(10L)); + assertThrows(AuthenticationException.class, () -> consumer.commitSync(offsets)); + } + + @Test + public void testCommittedAuthenticationFailure() { + final KafkaConsumer consumer = consumerWithPendingAuthenticationError(); + assertThrows(AuthenticationException.class, () -> consumer.committed(Collections.singleton(tp0)).get(tp0)); + } + + @Test + public void testMeasureCommitSyncDurationOnFailure() { + final KafkaConsumer consumer + = consumerWithPendingError(new MockTime(Duration.ofSeconds(1).toMillis())); + + try { + consumer.commitSync(Collections.singletonMap(tp0, new OffsetAndMetadata(10L))); + } catch (final RuntimeException e) { + } + + final Metric metric = consumer.metrics() + .get(consumer.metrics.metricName("commit-sync-time-ns-total", "consumer-metrics")); + assertTrue((Double) metric.metricValue() >= Duration.ofMillis(999).toNanos()); + } + + @Test + public void testMeasureCommitSyncDuration() { + Time time = new MockTime(Duration.ofSeconds(1).toMillis()); + SubscriptionState subscription = new SubscriptionState(new LogContext(), + OffsetResetStrategy.EARLIEST); + ConsumerMetadata metadata = createMetadata(subscription); + MockClient client = new MockClient(time, metadata); + initMetadata(client, Collections.singletonMap(topic, 2)); + Node node = metadata.fetch().nodes().get(0); + ConsumerPartitionAssignor assignor = new RoundRobinAssignor(); + KafkaConsumer consumer = newConsumer(time, client, subscription, metadata, + assignor, true, groupInstanceId); + consumer.assign(singletonList(tp0)); + + client.prepareResponseFrom( + FindCoordinatorResponse.prepareResponse(Errors.NONE, groupId, node), node); + Node coordinator = new Node(Integer.MAX_VALUE - node.id(), 
node.host(), node.port()); + client.prepareResponseFrom( + offsetCommitResponse(Collections.singletonMap(tp0, Errors.NONE)), + coordinator + ); + + consumer.commitSync(Collections.singletonMap(tp0, new OffsetAndMetadata(10L))); + + final Metric metric = consumer.metrics() + .get(consumer.metrics.metricName("commit-sync-time-ns-total", "consumer-metrics")); + assertTrue((Double) metric.metricValue() >= Duration.ofMillis(999).toNanos()); + } + + @Test + public void testMeasureCommittedDurationOnFailure() { + final KafkaConsumer consumer + = consumerWithPendingError(new MockTime(Duration.ofSeconds(1).toMillis())); + + try { + consumer.committed(Collections.singleton(tp0)); + } catch (final RuntimeException e) { + } + + final Metric metric = consumer.metrics() + .get(consumer.metrics.metricName("committed-time-ns-total", "consumer-metrics")); + assertTrue((Double) metric.metricValue() >= Duration.ofMillis(999).toNanos()); + } + + @Test + public void testMeasureCommittedDuration() { + long offset1 = 10000; + Time time = new MockTime(Duration.ofSeconds(1).toMillis()); + SubscriptionState subscription = new SubscriptionState(new LogContext(), + OffsetResetStrategy.EARLIEST); + ConsumerMetadata metadata = createMetadata(subscription); + MockClient client = new MockClient(time, metadata); + initMetadata(client, Collections.singletonMap(topic, 2)); + Node node = metadata.fetch().nodes().get(0); + ConsumerPartitionAssignor assignor = new RoundRobinAssignor(); + KafkaConsumer consumer = newConsumer(time, client, subscription, metadata, + assignor, true, groupInstanceId); + consumer.assign(singletonList(tp0)); + + // lookup coordinator + client.prepareResponseFrom( + FindCoordinatorResponse.prepareResponse(Errors.NONE, groupId, node), node); + Node coordinator = new Node(Integer.MAX_VALUE - node.id(), node.host(), node.port()); + + // fetch offset for one topic + client.prepareResponseFrom( + offsetResponse(Collections.singletonMap(tp0, offset1), Errors.NONE), coordinator); + + consumer.committed(Collections.singleton(tp0)).get(tp0).offset(); + + final Metric metric = consumer.metrics() + .get(consumer.metrics.metricName("committed-time-ns-total", "consumer-metrics")); + assertTrue((Double) metric.metricValue() >= Duration.ofMillis(999).toNanos()); + } + + @Test + public void testRebalanceException() { + Time time = new MockTime(); + SubscriptionState subscription = new SubscriptionState(new LogContext(), OffsetResetStrategy.EARLIEST); + ConsumerMetadata metadata = createMetadata(subscription); + MockClient client = new MockClient(time, metadata); + + initMetadata(client, Collections.singletonMap(topic, 1)); + Node node = metadata.fetch().nodes().get(0); + + ConsumerPartitionAssignor assignor = new RoundRobinAssignor(); + + KafkaConsumer consumer = newConsumer(time, client, subscription, metadata, assignor, true, groupInstanceId); + + consumer.subscribe(singleton(topic), getExceptionConsumerRebalanceListener()); + Node coordinator = new Node(Integer.MAX_VALUE - node.id(), node.host(), node.port()); + + client.prepareResponseFrom(FindCoordinatorResponse.prepareResponse(Errors.NONE, groupId, node), node); + client.prepareResponseFrom(joinGroupFollowerResponse(assignor, 1, memberId, leaderId, Errors.NONE), coordinator); + client.prepareResponseFrom(syncGroupResponse(singletonList(tp0), Errors.NONE), coordinator); + + // assign throws + try { + consumer.updateAssignmentMetadataIfNeeded(time.timer(Long.MAX_VALUE)); + fail("Should throw exception"); + } catch (Throwable e) { + assertEquals(partitionAssigned 
+ singleTopicPartition, e.getCause().getMessage()); + } + + // the assignment is still updated regardless of the exception + assertEquals(singleton(tp0), subscription.assignedPartitions()); + + // close's revoke throws + try { + consumer.close(Duration.ofMillis(0)); + fail("Should throw exception"); + } catch (Throwable e) { + assertEquals(partitionRevoked + singleTopicPartition, e.getCause().getCause().getMessage()); + } + + consumer.close(Duration.ofMillis(0)); + + // the assignment is still updated regardless of the exception + assertTrue(subscription.assignedPartitions().isEmpty()); + } + + @Test + public void testReturnRecordsDuringRebalance() throws InterruptedException { + Time time = new MockTime(1L); + SubscriptionState subscription = new SubscriptionState(new LogContext(), OffsetResetStrategy.EARLIEST); + ConsumerMetadata metadata = createMetadata(subscription); + MockClient client = new MockClient(time, metadata); + ConsumerPartitionAssignor assignor = new CooperativeStickyAssignor(); + KafkaConsumer consumer = newConsumer(time, client, subscription, metadata, assignor, true, groupInstanceId); + + initMetadata(client, Utils.mkMap(Utils.mkEntry(topic, 1), Utils.mkEntry(topic2, 1), Utils.mkEntry(topic3, 1))); + + consumer.subscribe(Arrays.asList(topic, topic2), getConsumerRebalanceListener(consumer)); + + Node node = metadata.fetch().nodes().get(0); + Node coordinator = prepareRebalance(client, node, assignor, Arrays.asList(tp0, t2p0), null); + + // a poll with non-zero milliseconds would complete three round-trips (discover, join, sync) + TestUtils.waitForCondition(() -> { + consumer.poll(Duration.ofMillis(100L)); + return consumer.assignment().equals(Utils.mkSet(tp0, t2p0)); + }, "Does not complete rebalance in time"); + + assertEquals(Utils.mkSet(topic, topic2), consumer.subscription()); + assertEquals(Utils.mkSet(tp0, t2p0), consumer.assignment()); + + // prepare a response of the outstanding fetch so that we have data available on the next poll + Map fetches1 = new HashMap<>(); + fetches1.put(tp0, new FetchInfo(0, 1)); + fetches1.put(t2p0, new FetchInfo(0, 10)); + client.respondFrom(fetchResponse(fetches1), node); + + ConsumerRecords records = consumer.poll(Duration.ZERO); + + // verify that the fetch occurred as expected + assertEquals(11, records.count()); + assertEquals(1L, consumer.position(tp0)); + assertEquals(10L, consumer.position(t2p0)); + + // prepare the next response of the prefetch + fetches1.clear(); + fetches1.put(tp0, new FetchInfo(1, 1)); + fetches1.put(t2p0, new FetchInfo(10, 20)); + client.respondFrom(fetchResponse(fetches1), node); + + // subscription change + consumer.subscribe(Arrays.asList(topic, topic3), getConsumerRebalanceListener(consumer)); + + // verify that subscription has changed but assignment is still unchanged + assertEquals(Utils.mkSet(topic, topic3), consumer.subscription()); + assertEquals(Utils.mkSet(tp0, t2p0), consumer.assignment()); + + // mock the offset commit response for to be revoked partitions + Map partitionOffsets1 = new HashMap<>(); + partitionOffsets1.put(t2p0, 10L); + AtomicBoolean commitReceived = prepareOffsetCommitResponse(client, coordinator, partitionOffsets1); + + // poll once which would not complete the rebalance + records = consumer.poll(Duration.ZERO); + + // clear out the prefetch so it doesn't interfere with the rest of the test + fetches1.clear(); + fetches1.put(tp0, new FetchInfo(2, 1)); + client.respondFrom(fetchResponse(fetches1), node); + + // verify that the fetch still occurred as expected + 
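+        // with the cooperative protocol only the partition being revoked (t2p0) stops returning
+        // data; tp0 stays owned, so its fetched records are still handed back to the caller even
+        // though the rebalance has not completed yet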
assertEquals(Utils.mkSet(topic, topic3), consumer.subscription()); + assertEquals(Collections.singleton(tp0), consumer.assignment()); + assertEquals(1, records.count()); + assertEquals(2L, consumer.position(tp0)); + + // verify that the offset commits occurred as expected + assertTrue(commitReceived.get()); + + // mock rebalance responses + client.respondFrom(joinGroupFollowerResponse(assignor, 2, "memberId", "leaderId", Errors.NONE), coordinator); + + // we need to poll 1) for getting the join response, and then send the sync request; + // 2) for getting the sync response + records = consumer.poll(Duration.ZERO); + + // should not finish the response yet + assertEquals(Utils.mkSet(topic, topic3), consumer.subscription()); + assertEquals(Collections.singleton(tp0), consumer.assignment()); + assertEquals(1, records.count()); + assertEquals(3L, consumer.position(tp0)); + + fetches1.clear(); + fetches1.put(tp0, new FetchInfo(3, 1)); + client.respondFrom(fetchResponse(fetches1), node); + + // now complete the rebalance + client.respondFrom(syncGroupResponse(Arrays.asList(tp0, t3p0), Errors.NONE), coordinator); + + AtomicInteger count = new AtomicInteger(0); + TestUtils.waitForCondition(() -> { + ConsumerRecords recs = consumer.poll(Duration.ofMillis(100L)); + return consumer.assignment().equals(Utils.mkSet(tp0, t3p0)) && count.addAndGet(recs.count()) == 1; + + }, "Does not complete rebalance in time"); + + // should have t3 but not sent yet the t3 records + assertEquals(Utils.mkSet(topic, topic3), consumer.subscription()); + assertEquals(Utils.mkSet(tp0, t3p0), consumer.assignment()); + assertEquals(4L, consumer.position(tp0)); + assertEquals(0L, consumer.position(t3p0)); + + fetches1.clear(); + fetches1.put(tp0, new FetchInfo(4, 1)); + fetches1.put(t3p0, new FetchInfo(0, 100)); + client.respondFrom(fetchResponse(fetches1), node); + + count.set(0); + TestUtils.waitForCondition(() -> { + ConsumerRecords recs = consumer.poll(Duration.ofMillis(100L)); + return count.addAndGet(recs.count()) == 101; + + }, "Does not complete rebalance in time"); + + assertEquals(5L, consumer.position(tp0)); + assertEquals(100L, consumer.position(t3p0)); + + client.requests().clear(); + consumer.unsubscribe(); + consumer.close(); + } + + @Test + public void testGetGroupMetadata() { + final Time time = new MockTime(); + final SubscriptionState subscription = new SubscriptionState(new LogContext(), OffsetResetStrategy.EARLIEST); + final ConsumerMetadata metadata = createMetadata(subscription); + final MockClient client = new MockClient(time, metadata); + + initMetadata(client, Collections.singletonMap(topic, 1)); + final Node node = metadata.fetch().nodes().get(0); + + final ConsumerPartitionAssignor assignor = new RoundRobinAssignor(); + + final KafkaConsumer consumer = newConsumer(time, client, subscription, metadata, assignor, true, groupInstanceId); + + final ConsumerGroupMetadata groupMetadataOnStart = consumer.groupMetadata(); + assertEquals(groupId, groupMetadataOnStart.groupId()); + assertEquals(JoinGroupRequest.UNKNOWN_MEMBER_ID, groupMetadataOnStart.memberId()); + assertEquals(JoinGroupRequest.UNKNOWN_GENERATION_ID, groupMetadataOnStart.generationId()); + assertEquals(groupInstanceId, groupMetadataOnStart.groupInstanceId()); + + consumer.subscribe(singleton(topic), getConsumerRebalanceListener(consumer)); + prepareRebalance(client, node, assignor, singletonList(tp0), null); + + // initial fetch + client.prepareResponseFrom(fetchResponse(tp0, 0, 0), node); + 
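+        // completing the join/sync round trips below is what replaces the UNKNOWN member id and
+        // generation observed above with the coordinator-assigned member id and generation 1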
consumer.updateAssignmentMetadataIfNeeded(time.timer(Long.MAX_VALUE)); + + final ConsumerGroupMetadata groupMetadataAfterPoll = consumer.groupMetadata(); + assertEquals(groupId, groupMetadataAfterPoll.groupId()); + assertEquals(memberId, groupMetadataAfterPoll.memberId()); + assertEquals(1, groupMetadataAfterPoll.generationId()); + assertEquals(groupInstanceId, groupMetadataAfterPoll.groupInstanceId()); + } + + @Test + public void testInvalidGroupMetadata() throws InterruptedException { + Time time = new MockTime(); + SubscriptionState subscription = new SubscriptionState(new LogContext(), OffsetResetStrategy.EARLIEST); + ConsumerMetadata metadata = createMetadata(subscription); + MockClient client = new MockClient(time, metadata); + initMetadata(client, Collections.singletonMap(topic, 1)); + KafkaConsumer consumer = newConsumer(time, client, subscription, metadata, + new RoundRobinAssignor(), true, groupInstanceId); + consumer.subscribe(singletonList(topic)); + // concurrent access is illegal + client.enableBlockingUntilWakeup(1); + ExecutorService service = Executors.newSingleThreadExecutor(); + service.execute(() -> consumer.poll(Duration.ofSeconds(5))); + try { + TimeUnit.SECONDS.sleep(1); + assertThrows(ConcurrentModificationException.class, consumer::groupMetadata); + client.wakeup(); + consumer.wakeup(); + } finally { + service.shutdown(); + assertTrue(service.awaitTermination(10, TimeUnit.SECONDS)); + } + + // accessing closed consumer is illegal + consumer.close(Duration.ofSeconds(5)); + assertThrows(IllegalStateException.class, consumer::groupMetadata); + } + + @Test + public void testCurrentLag() { + final Time time = new MockTime(); + final SubscriptionState subscription = new SubscriptionState(new LogContext(), OffsetResetStrategy.EARLIEST); + final ConsumerMetadata metadata = createMetadata(subscription); + final MockClient client = new MockClient(time, metadata); + + initMetadata(client, singletonMap(topic, 1)); + final ConsumerPartitionAssignor assignor = new RoundRobinAssignor(); + + final KafkaConsumer consumer = + newConsumer(time, client, subscription, metadata, assignor, true, groupInstanceId); + + // throws for unassigned partition + assertThrows(IllegalStateException.class, () -> consumer.currentLag(tp0)); + + consumer.assign(singleton(tp0)); + + // poll once to update with the current metadata + consumer.poll(Duration.ofMillis(0)); + client.respond(FindCoordinatorResponse.prepareResponse(Errors.NONE, groupId, metadata.fetch().nodes().get(0))); + + // no error for no current position + assertEquals(OptionalLong.empty(), consumer.currentLag(tp0)); + assertEquals(0, client.inFlightRequestCount()); + + // poll once again, which should send the list-offset request + consumer.seek(tp0, 50L); + consumer.poll(Duration.ofMillis(0)); + // requests: list-offset, fetch + assertEquals(2, client.inFlightRequestCount()); + + // no error for no end offset (so unknown lag) + assertEquals(OptionalLong.empty(), consumer.currentLag(tp0)); + + // poll once again, which should return the list-offset response + // and hence next call would return correct lag result + client.respond(listOffsetsResponse(singletonMap(tp0, 90L))); + consumer.poll(Duration.ofMillis(0)); + + assertEquals(OptionalLong.of(40L), consumer.currentLag(tp0)); + // requests: fetch + assertEquals(1, client.inFlightRequestCount()); + + // one successful fetch should update the log end offset and the position + final FetchInfo fetchInfo = new FetchInfo(1L, 99L, 50L, 5); + client.respond(fetchResponse(singletonMap(tp0, 
fetchInfo))); + + final ConsumerRecords records = consumer.poll(Duration.ofMillis(1)); + assertEquals(5, records.count()); + assertEquals(55L, consumer.position(tp0)); + + // correct lag result + assertEquals(OptionalLong.of(45L), consumer.currentLag(tp0)); + + consumer.close(Duration.ZERO); + } + + @Test + public void testListOffsetShouldUpateSubscriptions() { + final Time time = new MockTime(); + final SubscriptionState subscription = new SubscriptionState(new LogContext(), OffsetResetStrategy.EARLIEST); + final ConsumerMetadata metadata = createMetadata(subscription); + final MockClient client = new MockClient(time, metadata); + + initMetadata(client, singletonMap(topic, 1)); + final ConsumerPartitionAssignor assignor = new RoundRobinAssignor(); + + final KafkaConsumer consumer = + newConsumer(time, client, subscription, metadata, assignor, true, groupInstanceId); + + consumer.assign(singleton(tp0)); + + // poll once to update with the current metadata + consumer.poll(Duration.ofMillis(0)); + client.respond(FindCoordinatorResponse.prepareResponse(Errors.NONE, groupId, metadata.fetch().nodes().get(0))); + + consumer.seek(tp0, 50L); + client.prepareResponse(listOffsetsResponse(singletonMap(tp0, 90L))); + + assertEquals(singletonMap(tp0, 90L), consumer.endOffsets(Collections.singleton(tp0))); + // correct lag result should be returned as well + assertEquals(OptionalLong.of(40L), consumer.currentLag(tp0)); + + consumer.close(Duration.ZERO); + } + + private KafkaConsumer consumerWithPendingAuthenticationError(final Time time) { + SubscriptionState subscription = new SubscriptionState(new LogContext(), OffsetResetStrategy.EARLIEST); + ConsumerMetadata metadata = createMetadata(subscription); + MockClient client = new MockClient(time, metadata); + + initMetadata(client, singletonMap(topic, 1)); + Node node = metadata.fetch().nodes().get(0); + + ConsumerPartitionAssignor assignor = new RangeAssignor(); + + client.createPendingAuthenticationError(node, 0); + return newConsumer(time, client, subscription, metadata, assignor, false, groupInstanceId); + } + + private KafkaConsumer consumerWithPendingAuthenticationError() { + return consumerWithPendingAuthenticationError(new MockTime()); + } + + private KafkaConsumer consumerWithPendingError(final Time time) { + return consumerWithPendingAuthenticationError(time); + } + + private ConsumerRebalanceListener getConsumerRebalanceListener(final KafkaConsumer consumer) { + return new ConsumerRebalanceListener() { + @Override + public void onPartitionsRevoked(Collection partitions) { + } + + @Override + public void onPartitionsAssigned(Collection partitions) { + // set initial position so we don't need a lookup + for (TopicPartition partition : partitions) + consumer.seek(partition, 0); + } + }; + } + + private ConsumerRebalanceListener getExceptionConsumerRebalanceListener() { + return new ConsumerRebalanceListener() { + @Override + public void onPartitionsRevoked(Collection partitions) { + throw new RuntimeException(partitionRevoked + partitions); + } + + @Override + public void onPartitionsAssigned(Collection partitions) { + throw new RuntimeException(partitionAssigned + partitions); + } + + @Override + public void onPartitionsLost(Collection partitions) { + throw new RuntimeException(partitionLost + partitions); + } + }; + } + + private ConsumerMetadata createMetadata(SubscriptionState subscription) { + return new ConsumerMetadata(0, Long.MAX_VALUE, false, false, + subscription, new LogContext(), new ClusterResourceListeners()); + } + + private Node 
prepareRebalance(MockClient client, Node node, final Set subscribedTopics, ConsumerPartitionAssignor assignor, List partitions, Node coordinator) { + if (coordinator == null) { + // lookup coordinator + client.prepareResponseFrom(FindCoordinatorResponse.prepareResponse(Errors.NONE, groupId, node), node); + coordinator = new Node(Integer.MAX_VALUE - node.id(), node.host(), node.port()); + } + + // join group + client.prepareResponseFrom(body -> { + JoinGroupRequest joinGroupRequest = (JoinGroupRequest) body; + Iterator protocolIterator = + joinGroupRequest.data().protocols().iterator(); + assertTrue(protocolIterator.hasNext()); + + ByteBuffer protocolMetadata = ByteBuffer.wrap(protocolIterator.next().metadata()); + ConsumerPartitionAssignor.Subscription subscription = ConsumerProtocol.deserializeSubscription(protocolMetadata); + return subscribedTopics.equals(new HashSet<>(subscription.topics())); + }, joinGroupFollowerResponse(assignor, 1, memberId, leaderId, Errors.NONE), coordinator); + + // sync group + client.prepareResponseFrom(syncGroupResponse(partitions, Errors.NONE), coordinator); + + return coordinator; + } + + private Node prepareRebalance(MockClient client, Node node, ConsumerPartitionAssignor assignor, List partitions, Node coordinator) { + if (coordinator == null) { + // lookup coordinator + client.prepareResponseFrom(FindCoordinatorResponse.prepareResponse(Errors.NONE, groupId, node), node); + coordinator = new Node(Integer.MAX_VALUE - node.id(), node.host(), node.port()); + } + + // join group + client.prepareResponseFrom(joinGroupFollowerResponse(assignor, 1, memberId, leaderId, Errors.NONE), coordinator); + + // sync group + client.prepareResponseFrom(syncGroupResponse(partitions, Errors.NONE), coordinator); + + return coordinator; + } + + private AtomicBoolean prepareHeartbeatResponse(MockClient client, Node coordinator, Errors error) { + final AtomicBoolean heartbeatReceived = new AtomicBoolean(false); + client.prepareResponseFrom(body -> { + heartbeatReceived.set(true); + return true; + }, new HeartbeatResponse(new HeartbeatResponseData().setErrorCode(error.code())), coordinator); + return heartbeatReceived; + } + + private AtomicBoolean prepareOffsetCommitResponse(MockClient client, Node coordinator, final Map partitionOffsets) { + final AtomicBoolean commitReceived = new AtomicBoolean(true); + Map response = new HashMap<>(); + for (TopicPartition partition : partitionOffsets.keySet()) + response.put(partition, Errors.NONE); + + client.prepareResponseFrom(body -> { + OffsetCommitRequest commitRequest = (OffsetCommitRequest) body; + Map commitErrors = commitRequest.offsets(); + + for (Map.Entry partitionOffset : partitionOffsets.entrySet()) { + // verify that the expected offset has been committed + if (!commitErrors.get(partitionOffset.getKey()).equals(partitionOffset.getValue())) { + commitReceived.set(false); + return false; + } + } + return true; + }, offsetCommitResponse(response), coordinator); + return commitReceived; + } + + private AtomicBoolean prepareOffsetCommitResponse(MockClient client, Node coordinator, final TopicPartition partition, final long offset) { + return prepareOffsetCommitResponse(client, coordinator, Collections.singletonMap(partition, offset)); + } + + private OffsetCommitResponse offsetCommitResponse(Map responseData) { + return new OffsetCommitResponse(responseData); + } + + private JoinGroupResponse joinGroupFollowerResponse(ConsumerPartitionAssignor assignor, int generationId, String memberId, String leaderId, Errors error) { + return 
new JoinGroupResponse( + new JoinGroupResponseData() + .setErrorCode(error.code()) + .setGenerationId(generationId) + .setProtocolName(assignor.name()) + .setLeader(leaderId) + .setMemberId(memberId) + .setMembers(Collections.emptyList()) + ); + } + + private SyncGroupResponse syncGroupResponse(List partitions, Errors error) { + ByteBuffer buf = ConsumerProtocol.serializeAssignment(new ConsumerPartitionAssignor.Assignment(partitions)); + return new SyncGroupResponse( + new SyncGroupResponseData() + .setErrorCode(error.code()) + .setAssignment(Utils.toArray(buf)) + ); + } + + private OffsetFetchResponse offsetResponse(Map offsets, Errors error) { + Map partitionData = new HashMap<>(); + for (Map.Entry entry : offsets.entrySet()) { + partitionData.put(entry.getKey(), new OffsetFetchResponse.PartitionData(entry.getValue(), + Optional.empty(), "", error)); + } + return new OffsetFetchResponse(Errors.NONE, partitionData); + } + + private ListOffsetsResponse listOffsetsResponse(Map offsets) { + return listOffsetsResponse(offsets, Collections.emptyMap()); + } + + private ListOffsetsResponse listOffsetsResponse(Map partitionOffsets, + Map partitionErrors) { + Map responses = new HashMap<>(); + for (Map.Entry partitionOffset : partitionOffsets.entrySet()) { + TopicPartition tp = partitionOffset.getKey(); + ListOffsetsTopicResponse topic = responses.computeIfAbsent(tp.topic(), k -> new ListOffsetsTopicResponse().setName(tp.topic())); + topic.partitions().add(new ListOffsetsPartitionResponse() + .setPartitionIndex(tp.partition()) + .setErrorCode(Errors.NONE.code()) + .setTimestamp(ListOffsetsResponse.UNKNOWN_TIMESTAMP) + .setOffset(partitionOffset.getValue())); + } + + for (Map.Entry partitionError : partitionErrors.entrySet()) { + TopicPartition tp = partitionError.getKey(); + ListOffsetsTopicResponse topic = responses.computeIfAbsent(tp.topic(), k -> new ListOffsetsTopicResponse().setName(tp.topic())); + topic.partitions().add(new ListOffsetsPartitionResponse() + .setPartitionIndex(tp.partition()) + .setErrorCode(partitionError.getValue().code()) + .setTimestamp(ListOffsetsResponse.UNKNOWN_TIMESTAMP) + .setOffset(ListOffsetsResponse.UNKNOWN_OFFSET)); + } + ListOffsetsResponseData data = new ListOffsetsResponseData() + .setTopics(new ArrayList<>(responses.values())); + return new ListOffsetsResponse(data); + } + + private FetchResponse fetchResponse(Map fetches) { + LinkedHashMap tpResponses = new LinkedHashMap<>(); + for (Map.Entry fetchEntry : fetches.entrySet()) { + TopicPartition partition = fetchEntry.getKey(); + long fetchOffset = fetchEntry.getValue().offset; + int fetchCount = fetchEntry.getValue().count; + final long highWatermark = fetchEntry.getValue().logLastOffset + 1; + final long logStartOffset = fetchEntry.getValue().logFirstOffset; + final MemoryRecords records; + if (fetchCount == 0) { + records = MemoryRecords.EMPTY; + } else { + MemoryRecordsBuilder builder = MemoryRecords.builder(ByteBuffer.allocate(1024), CompressionType.NONE, + TimestampType.CREATE_TIME, fetchOffset); + for (int i = 0; i < fetchCount; i++) + builder.append(0L, ("key-" + i).getBytes(), ("value-" + i).getBytes()); + records = builder.build(); + } + tpResponses.put(new TopicIdPartition(topicIds.get(partition.topic()), partition), + new FetchResponseData.PartitionData() + .setPartitionIndex(partition.partition()) + .setHighWatermark(highWatermark) + .setLogStartOffset(logStartOffset) + .setRecords(records)); + } + return FetchResponse.of(Errors.NONE, 0, INVALID_SESSION_ID, tpResponses); + } + + private 
FetchResponse fetchResponse(TopicPartition partition, long fetchOffset, int count) { + FetchInfo fetchInfo = new FetchInfo(fetchOffset, count); + return fetchResponse(Collections.singletonMap(partition, fetchInfo)); + } + + private KafkaConsumer newConsumer(Time time, + KafkaClient client, + SubscriptionState subscription, + ConsumerMetadata metadata, + ConsumerPartitionAssignor assignor, + boolean autoCommitEnabled, + Optional groupInstanceId) { + return newConsumer(time, client, subscription, metadata, assignor, autoCommitEnabled, groupId, groupInstanceId, false); + } + + private KafkaConsumer newConsumerNoAutoCommit(Time time, + KafkaClient client, + SubscriptionState subscription, + ConsumerMetadata metadata) { + return newConsumer(time, client, subscription, metadata, new RangeAssignor(), false, groupId, groupInstanceId, false); + } + + private KafkaConsumer newConsumer(Time time, + KafkaClient client, + SubscriptionState subscription, + ConsumerMetadata metadata, + ConsumerPartitionAssignor assignor, + boolean autoCommitEnabled, + String groupId, + Optional groupInstanceId, + boolean throwOnStableOffsetNotSupported) { + return newConsumer(time, client, subscription, metadata, assignor, autoCommitEnabled, groupId, groupInstanceId, + Optional.of(new StringDeserializer()), throwOnStableOffsetNotSupported); + } + + private KafkaConsumer newConsumer(Time time, + KafkaClient client, + SubscriptionState subscription, + ConsumerMetadata metadata, + ConsumerPartitionAssignor assignor, + boolean autoCommitEnabled, + String groupId, + Optional groupInstanceId, + Optional> valueDeserializer, + boolean throwOnStableOffsetNotSupported) { + String clientId = "mock-consumer"; + String metricGroupPrefix = "consumer"; + long retryBackoffMs = 100; + int requestTimeoutMs = 30000; + int defaultApiTimeoutMs = 30000; + int minBytes = 1; + int maxBytes = Integer.MAX_VALUE; + int maxWaitMs = 500; + int fetchSize = 1024 * 1024; + int maxPollRecords = Integer.MAX_VALUE; + boolean checkCrcs = true; + int rebalanceTimeoutMs = 60000; + + Deserializer keyDeserializer = new StringDeserializer(); + Deserializer deserializer = valueDeserializer.orElse(new StringDeserializer()); + + List assignors = singletonList(assignor); + ConsumerInterceptors interceptors = new ConsumerInterceptors<>(Collections.emptyList()); + + Metrics metrics = new Metrics(time); + ConsumerMetrics metricsRegistry = new ConsumerMetrics(metricGroupPrefix); + + LogContext loggerFactory = new LogContext(); + ConsumerNetworkClient consumerClient = new ConsumerNetworkClient(loggerFactory, client, metadata, time, + retryBackoffMs, requestTimeoutMs, heartbeatIntervalMs); + + GroupRebalanceConfig rebalanceConfig = new GroupRebalanceConfig(sessionTimeoutMs, + rebalanceTimeoutMs, + heartbeatIntervalMs, + groupId, + groupInstanceId, + retryBackoffMs, + true); + ConsumerCoordinator consumerCoordinator = new ConsumerCoordinator(rebalanceConfig, + loggerFactory, + consumerClient, + assignors, + metadata, + subscription, + metrics, + metricGroupPrefix, + time, + autoCommitEnabled, + autoCommitIntervalMs, + interceptors, + throwOnStableOffsetNotSupported); + Fetcher fetcher = new Fetcher<>( + loggerFactory, + consumerClient, + minBytes, + maxBytes, + maxWaitMs, + fetchSize, + maxPollRecords, + checkCrcs, + "", + keyDeserializer, + deserializer, + metadata, + subscription, + metrics, + metricsRegistry.fetcherMetrics, + time, + retryBackoffMs, + requestTimeoutMs, + IsolationLevel.READ_UNCOMMITTED, + new ApiVersions()); + + return new KafkaConsumer<>( + 
loggerFactory, + clientId, + consumerCoordinator, + keyDeserializer, + deserializer, + fetcher, + interceptors, + time, + consumerClient, + metrics, + subscription, + metadata, + retryBackoffMs, + requestTimeoutMs, + defaultApiTimeoutMs, + assignors, + groupId); + } + + private static class FetchInfo { + long logFirstOffset; + long logLastOffset; + long offset; + int count; + + FetchInfo(long offset, int count) { + this(0L, offset + count, offset, count); + } + + FetchInfo(long logFirstOffset, long logLastOffset, long offset, int count) { + this.logFirstOffset = logFirstOffset; + this.logLastOffset = logLastOffset; + this.offset = offset; + this.count = count; + } + } + + @Test + public void testSubscriptionOnInvalidTopic() { + Time time = new MockTime(); + SubscriptionState subscription = new SubscriptionState(new LogContext(), OffsetResetStrategy.EARLIEST); + ConsumerMetadata metadata = createMetadata(subscription); + MockClient client = new MockClient(time, metadata); + + initMetadata(client, Collections.singletonMap(topic, 1)); + Cluster cluster = metadata.fetch(); + + ConsumerPartitionAssignor assignor = new RoundRobinAssignor(); + + String invalidTopicName = "topic abc"; // Invalid topic name due to space + + List topicMetadata = new ArrayList<>(); + topicMetadata.add(new MetadataResponse.TopicMetadata(Errors.INVALID_TOPIC_EXCEPTION, + invalidTopicName, false, Collections.emptyList())); + MetadataResponse updateResponse = RequestTestUtils.metadataResponse(cluster.nodes(), + cluster.clusterResource().clusterId(), + cluster.controller().id(), + topicMetadata); + client.prepareMetadataUpdate(updateResponse); + + KafkaConsumer consumer = newConsumer(time, client, subscription, metadata, assignor, true, groupInstanceId); + consumer.subscribe(singleton(invalidTopicName), getConsumerRebalanceListener(consumer)); + + assertThrows(InvalidTopicException.class, () -> consumer.poll(Duration.ZERO)); + } + + @Test + public void testPollTimeMetrics() { + Time time = new MockTime(); + SubscriptionState subscription = new SubscriptionState(new LogContext(), OffsetResetStrategy.EARLIEST); + ConsumerMetadata metadata = createMetadata(subscription); + MockClient client = new MockClient(time, metadata); + initMetadata(client, Collections.singletonMap(topic, 1)); + + ConsumerPartitionAssignor assignor = new RoundRobinAssignor(); + + KafkaConsumer consumer = newConsumer(time, client, subscription, metadata, assignor, true, groupInstanceId); + consumer.subscribe(singletonList(topic)); + // MetricName objects to check + Metrics metrics = consumer.metrics; + MetricName lastPollSecondsAgoName = metrics.metricName("last-poll-seconds-ago", "consumer-metrics"); + MetricName timeBetweenPollAvgName = metrics.metricName("time-between-poll-avg", "consumer-metrics"); + MetricName timeBetweenPollMaxName = metrics.metricName("time-between-poll-max", "consumer-metrics"); + // Test default values + assertEquals(-1.0d, consumer.metrics().get(lastPollSecondsAgoName).metricValue()); + assertEquals(Double.NaN, consumer.metrics().get(timeBetweenPollAvgName).metricValue()); + assertEquals(Double.NaN, consumer.metrics().get(timeBetweenPollMaxName).metricValue()); + // Call first poll + consumer.poll(Duration.ZERO); + assertEquals(0.0d, consumer.metrics().get(lastPollSecondsAgoName).metricValue()); + assertEquals(0.0d, consumer.metrics().get(timeBetweenPollAvgName).metricValue()); + assertEquals(0.0d, consumer.metrics().get(timeBetweenPollMaxName).metricValue()); + // Advance time by 5,000 (total time = 5,000) + time.sleep(5 * 
1000L); + assertEquals(5.0d, consumer.metrics().get(lastPollSecondsAgoName).metricValue()); + // Call second poll + consumer.poll(Duration.ZERO); + assertEquals(2.5 * 1000d, consumer.metrics().get(timeBetweenPollAvgName).metricValue()); + assertEquals(5 * 1000d, consumer.metrics().get(timeBetweenPollMaxName).metricValue()); + // Advance time by 10,000 (total time = 15,000) + time.sleep(10 * 1000L); + assertEquals(10.0d, consumer.metrics().get(lastPollSecondsAgoName).metricValue()); + // Call third poll + consumer.poll(Duration.ZERO); + assertEquals(5 * 1000d, consumer.metrics().get(timeBetweenPollAvgName).metricValue()); + assertEquals(10 * 1000d, consumer.metrics().get(timeBetweenPollMaxName).metricValue()); + // Advance time by 5,000 (total time = 20,000) + time.sleep(5 * 1000L); + assertEquals(5.0d, consumer.metrics().get(lastPollSecondsAgoName).metricValue()); + // Call fourth poll + consumer.poll(Duration.ZERO); + assertEquals(5 * 1000d, consumer.metrics().get(timeBetweenPollAvgName).metricValue()); + assertEquals(10 * 1000d, consumer.metrics().get(timeBetweenPollMaxName).metricValue()); + } + + @Test + public void testPollIdleRatio() { + Time time = new MockTime(); + SubscriptionState subscription = new SubscriptionState(new LogContext(), OffsetResetStrategy.EARLIEST); + ConsumerMetadata metadata = createMetadata(subscription); + MockClient client = new MockClient(time, metadata); + initMetadata(client, Collections.singletonMap(topic, 1)); + + ConsumerPartitionAssignor assignor = new RoundRobinAssignor(); + + KafkaConsumer consumer = newConsumer(time, client, subscription, metadata, assignor, true, groupInstanceId); + // MetricName object to check + Metrics metrics = consumer.metrics; + MetricName pollIdleRatio = metrics.metricName("poll-idle-ratio-avg", "consumer-metrics"); + // Test default value + assertEquals(Double.NaN, consumer.metrics().get(pollIdleRatio).metricValue()); + + // 1st poll + // Spend 50ms in poll so value = 1.0 + consumer.kafkaConsumerMetrics.recordPollStart(time.milliseconds()); + time.sleep(50); + consumer.kafkaConsumerMetrics.recordPollEnd(time.milliseconds()); + + assertEquals(1.0d, consumer.metrics().get(pollIdleRatio).metricValue()); + + // 2nd poll + // Spend 50m outside poll and 0ms in poll so value = 0.0 + time.sleep(50); + consumer.kafkaConsumerMetrics.recordPollStart(time.milliseconds()); + consumer.kafkaConsumerMetrics.recordPollEnd(time.milliseconds()); + + // Avg of first two data points + assertEquals((1.0d + 0.0d) / 2, consumer.metrics().get(pollIdleRatio).metricValue()); + + // 3rd poll + // Spend 25ms outside poll and 25ms in poll so value = 0.5 + time.sleep(25); + consumer.kafkaConsumerMetrics.recordPollStart(time.milliseconds()); + time.sleep(25); + consumer.kafkaConsumerMetrics.recordPollEnd(time.milliseconds()); + + // Avg of three data points + assertEquals((1.0d + 0.0d + 0.5d) / 3, consumer.metrics().get(pollIdleRatio).metricValue()); + } + + private static boolean consumerMetricPresent(KafkaConsumer consumer, String name) { + MetricName metricName = new MetricName(name, "consumer-metrics", "", Collections.emptyMap()); + return consumer.metrics.metrics().containsKey(metricName); + } + + @Test + public void testClosingConsumerUnregistersConsumerMetrics() { + Time time = new MockTime(); + SubscriptionState subscription = new SubscriptionState(new LogContext(), OffsetResetStrategy.EARLIEST); + ConsumerMetadata metadata = createMetadata(subscription); + MockClient client = new MockClient(time, metadata); + initMetadata(client, 
Collections.singletonMap(topic, 1)); + KafkaConsumer consumer = newConsumer(time, client, subscription, metadata, + new RoundRobinAssignor(), true, groupInstanceId); + consumer.subscribe(singletonList(topic)); + assertTrue(consumerMetricPresent(consumer, "last-poll-seconds-ago")); + assertTrue(consumerMetricPresent(consumer, "time-between-poll-avg")); + assertTrue(consumerMetricPresent(consumer, "time-between-poll-max")); + consumer.close(); + assertFalse(consumerMetricPresent(consumer, "last-poll-seconds-ago")); + assertFalse(consumerMetricPresent(consumer, "time-between-poll-avg")); + assertFalse(consumerMetricPresent(consumer, "time-between-poll-max")); + } + + @Test + public void testEnforceRebalanceWithManualAssignment() { + try (KafkaConsumer consumer = newConsumer((String) null)) { + consumer.assign(singleton(new TopicPartition("topic", 0))); + assertThrows(IllegalStateException.class, consumer::enforceRebalance); + } + } + + @Test + public void testEnforceRebalanceTriggersRebalanceOnNextPoll() { + Time time = new MockTime(1L); + SubscriptionState subscription = new SubscriptionState(new LogContext(), OffsetResetStrategy.EARLIEST); + ConsumerMetadata metadata = createMetadata(subscription); + MockClient client = new MockClient(time, metadata); + ConsumerPartitionAssignor assignor = new RoundRobinAssignor(); + KafkaConsumer consumer = newConsumer(time, client, subscription, metadata, assignor, true, groupInstanceId); + MockRebalanceListener countingRebalanceListener = new MockRebalanceListener(); + initMetadata(client, Utils.mkMap(Utils.mkEntry(topic, 1), Utils.mkEntry(topic2, 1), Utils.mkEntry(topic3, 1))); + + consumer.subscribe(Arrays.asList(topic, topic2), countingRebalanceListener); + Node node = metadata.fetch().nodes().get(0); + prepareRebalance(client, node, assignor, Arrays.asList(tp0, t2p0), null); + + // a first rebalance to get the assignment, we need two poll calls since we need two round trips to finish join / sync-group + consumer.poll(Duration.ZERO); + consumer.poll(Duration.ZERO); + + // onPartitionsRevoked is not invoked when first joining the group + assertEquals(countingRebalanceListener.revokedCount, 0); + assertEquals(countingRebalanceListener.assignedCount, 1); + + consumer.enforceRebalance(); + + // the next poll should trigger a rebalance + consumer.poll(Duration.ZERO); + + assertEquals(countingRebalanceListener.revokedCount, 1); + } + + @Test + public void configurableObjectsShouldSeeGeneratedClientId() { + Properties props = new Properties(); + props.put(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9999"); + props.put(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, DeserializerForClientId.class.getName()); + props.put(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, DeserializerForClientId.class.getName()); + props.put(ConsumerConfig.INTERCEPTOR_CLASSES_CONFIG, ConsumerInterceptorForClientId.class.getName()); + + KafkaConsumer consumer = new KafkaConsumer<>(props); + assertNotNull(consumer.getClientId()); + assertNotEquals(0, consumer.getClientId().length()); + assertEquals(3, CLIENT_IDS.size()); + CLIENT_IDS.forEach(id -> assertEquals(id, consumer.getClientId())); + consumer.close(); + } + + @Test + public void testUnusedConfigs() { + Map props = new HashMap<>(); + props.put(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9999"); + props.put(SslConfigs.SSL_PROTOCOL_CONFIG, "TLS"); + ConsumerConfig config = new ConsumerConfig(ConsumerConfig.appendDeserializerToConfig(props, new StringDeserializer(), new StringDeserializer())); + + 
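+        // ssl.protocol is never accessed by this (plaintext) consumer, so it should be reported
+        // as unused both before the KafkaConsumer is created and after it has read the config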
assertTrue(config.unused().contains(SslConfigs.SSL_PROTOCOL_CONFIG)); + + try (KafkaConsumer consumer = new KafkaConsumer<>(config, null, null)) { + assertTrue(config.unused().contains(SslConfigs.SSL_PROTOCOL_CONFIG)); + } + } + + @Test + public void testAssignorNameConflict() { + Map configs = new HashMap<>(); + configs.put(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9999"); + configs.put(ConsumerConfig.PARTITION_ASSIGNMENT_STRATEGY_CONFIG, + Arrays.asList(RangeAssignor.class.getName(), ConsumerPartitionAssignorTest.TestConsumerPartitionAssignor.class.getName())); + + assertThrows(KafkaException.class, + () -> new KafkaConsumer<>(configs, new StringDeserializer(), new StringDeserializer())); + } + + private static final List CLIENT_IDS = new ArrayList<>(); + public static class DeserializerForClientId implements Deserializer { + @Override + public void configure(Map configs, boolean isKey) { + CLIENT_IDS.add(configs.get(ConsumerConfig.CLIENT_ID_CONFIG).toString()); + } + + @Override + public byte[] deserialize(String topic, byte[] data) { + return data; + } + } + + public static class ConsumerInterceptorForClientId implements ConsumerInterceptor { + + @Override + public ConsumerRecords onConsume(ConsumerRecords records) { + return records; + } + + @Override + public void onCommit(Map offsets) { + + } + + @Override + public void close() { + + } + + @Override + public void configure(Map configs) { + CLIENT_IDS.add(configs.get(ConsumerConfig.CLIENT_ID_CONFIG).toString()); + } + } +} diff --git a/clients/src/test/java/org/apache/kafka/clients/consumer/MockConsumerTest.java b/clients/src/test/java/org/apache/kafka/clients/consumer/MockConsumerTest.java new file mode 100644 index 0000000..be2ea0e --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/clients/consumer/MockConsumerTest.java @@ -0,0 +1,142 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+package org.apache.kafka.clients.consumer;
+
+import org.apache.kafka.common.TopicPartition;
+import org.apache.kafka.common.header.internals.RecordHeaders;
+import org.apache.kafka.common.record.TimestampType;
+import org.junit.jupiter.api.Test;
+
+import java.time.Duration;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.Optional;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertFalse;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+public class MockConsumerTest {
+
+    private final MockConsumer<String, String> consumer = new MockConsumer<>(OffsetResetStrategy.EARLIEST);
+
+    @Test
+    public void testSimpleMock() {
+        consumer.subscribe(Collections.singleton("test"));
+        assertEquals(0, consumer.poll(Duration.ZERO).count());
+        consumer.rebalance(Arrays.asList(new TopicPartition("test", 0), new TopicPartition("test", 1)));
+        // Mock consumers need to seek manually since they cannot automatically reset offsets
+        HashMap<TopicPartition, Long> beginningOffsets = new HashMap<>();
+        beginningOffsets.put(new TopicPartition("test", 0), 0L);
+        beginningOffsets.put(new TopicPartition("test", 1), 0L);
+        consumer.updateBeginningOffsets(beginningOffsets);
+        consumer.seek(new TopicPartition("test", 0), 0);
+        ConsumerRecord<String, String> rec1 = new ConsumerRecord<>("test", 0, 0, 0L, TimestampType.CREATE_TIME,
+            0, 0, "key1", "value1", new RecordHeaders(), Optional.empty());
+        ConsumerRecord<String, String> rec2 = new ConsumerRecord<>("test", 0, 1, 0L, TimestampType.CREATE_TIME,
+            0, 0, "key2", "value2", new RecordHeaders(), Optional.empty());
+        consumer.addRecord(rec1);
+        consumer.addRecord(rec2);
+        ConsumerRecords<String, String> recs = consumer.poll(Duration.ofMillis(1));
+        Iterator<ConsumerRecord<String, String>> iter = recs.iterator();
+        assertEquals(rec1, iter.next());
+        assertEquals(rec2, iter.next());
+        assertFalse(iter.hasNext());
+        final TopicPartition tp = new TopicPartition("test", 0);
+        assertEquals(2L, consumer.position(tp));
+        consumer.commitSync();
+        assertEquals(2L, consumer.committed(Collections.singleton(tp)).get(tp).offset());
+    }
+
+    @SuppressWarnings("deprecation")
+    @Test
+    public void testSimpleMockDeprecated() {
+        consumer.subscribe(Collections.singleton("test"));
+        assertEquals(0, consumer.poll(1000).count());
+        consumer.rebalance(Arrays.asList(new TopicPartition("test", 0), new TopicPartition("test", 1)));
+        // Mock consumers need to seek manually since they cannot automatically reset offsets
+        HashMap<TopicPartition, Long> beginningOffsets = new HashMap<>();
+        beginningOffsets.put(new TopicPartition("test", 0), 0L);
+        beginningOffsets.put(new TopicPartition("test", 1), 0L);
+        consumer.updateBeginningOffsets(beginningOffsets);
+        consumer.seek(new TopicPartition("test", 0), 0);
+        ConsumerRecord<String, String> rec1 = new ConsumerRecord<>("test", 0, 0, 0L, TimestampType.CREATE_TIME,
+            0, 0, "key1", "value1", new RecordHeaders(), Optional.empty());
+        ConsumerRecord<String, String> rec2 = new ConsumerRecord<>("test", 0, 1, 0L, TimestampType.CREATE_TIME,
+            0, 0, "key2", "value2", new RecordHeaders(), Optional.empty());
+        consumer.addRecord(rec1);
+        consumer.addRecord(rec2);
+        ConsumerRecords<String, String> recs = consumer.poll(1);
+        Iterator<ConsumerRecord<String, String>> iter = recs.iterator();
+        assertEquals(rec1, iter.next());
+        assertEquals(rec2, iter.next());
+        assertFalse(iter.hasNext());
+        final TopicPartition tp = new TopicPartition("test", 0);
+        assertEquals(2L, consumer.position(tp));
+        consumer.commitSync();
+        assertEquals(2L, consumer.committed(Collections.singleton(tp)).get(tp).offset());
+        assertEquals(new ConsumerGroupMetadata("dummy.group.id", 1, "1", Optional.empty()),
+            consumer.groupMetadata());
+    }
+
+    @Test
+    public void testConsumerRecordsIsEmptyWhenReturningNoRecords() {
+        TopicPartition partition = new TopicPartition("test", 0);
+        consumer.assign(Collections.singleton(partition));
+        consumer.addRecord(new ConsumerRecord<>("test", 0, 0, null, null));
+        consumer.updateEndOffsets(Collections.singletonMap(partition, 1L));
+        consumer.seekToEnd(Collections.singleton(partition));
+        ConsumerRecords<String, String> records = consumer.poll(Duration.ofMillis(1));
+        assertEquals(0, records.count());
+        assertTrue(records.isEmpty());
+    }
+
+    @Test
+    public void shouldNotClearRecordsForPausedPartitions() {
+        TopicPartition partition0 = new TopicPartition("test", 0);
+        Collection<TopicPartition> testPartitionList = Collections.singletonList(partition0);
+        consumer.assign(testPartitionList);
+        consumer.addRecord(new ConsumerRecord<>("test", 0, 0, null, null));
+        consumer.updateBeginningOffsets(Collections.singletonMap(partition0, 0L));
+        consumer.seekToBeginning(testPartitionList);
+
+        consumer.pause(testPartitionList);
+        consumer.poll(Duration.ofMillis(1));
+        consumer.resume(testPartitionList);
+        ConsumerRecords<String, String> recordsSecondPoll = consumer.poll(Duration.ofMillis(1));
+        assertEquals(1, recordsSecondPoll.count());
+    }
+
+    @Test
+    public void endOffsetsShouldBeIdempotent() {
+        TopicPartition partition = new TopicPartition("test", 0);
+        consumer.updateEndOffsets(Collections.singletonMap(partition, 10L));
+        // consumer.endOffsets should NOT change the value of end offsets
+        assertEquals(10L, (long) consumer.endOffsets(Collections.singleton(partition)).get(partition));
+        assertEquals(10L, (long) consumer.endOffsets(Collections.singleton(partition)).get(partition));
+        assertEquals(10L, (long) consumer.endOffsets(Collections.singleton(partition)).get(partition));
+        consumer.updateEndOffsets(Collections.singletonMap(partition, 11L));
+        // consumer.endOffsets should NOT change the value of end offsets
+        assertEquals(11L, (long) consumer.endOffsets(Collections.singleton(partition)).get(partition));
+        assertEquals(11L, (long) consumer.endOffsets(Collections.singleton(partition)).get(partition));
+        assertEquals(11L, (long) consumer.endOffsets(Collections.singleton(partition)).get(partition));
+    }
+
+}
diff --git a/clients/src/test/java/org/apache/kafka/clients/consumer/OffsetAndMetadataTest.java b/clients/src/test/java/org/apache/kafka/clients/consumer/OffsetAndMetadataTest.java
new file mode 100644
index 0000000..5e4ef68
--- /dev/null
+++ b/clients/src/test/java/org/apache/kafka/clients/consumer/OffsetAndMetadataTest.java
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +package org.apache.kafka.clients.consumer; + +import org.apache.kafka.common.utils.Serializer; +import org.junit.jupiter.api.Test; + +import java.io.IOException; +import java.util.Optional; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +/** + * This test case ensures OffsetAndMetadata class is serializable and is serialization compatible. + * Note: this ensures that the current code can deserialize data serialized with older versions of the code, but not the reverse. + * That is, older code won't necessarily be able to deserialize data serialized with newer code. + */ +public class OffsetAndMetadataTest { + + @Test + public void testInvalidNegativeOffset() { + assertThrows(IllegalArgumentException.class, () -> new OffsetAndMetadata(-239L, Optional.of(15), "")); + } + + @Test + public void testSerializationRoundtrip() throws IOException, ClassNotFoundException { + checkSerde(new OffsetAndMetadata(239L, Optional.of(15), "blah")); + checkSerde(new OffsetAndMetadata(239L, "blah")); + checkSerde(new OffsetAndMetadata(239L)); + } + + private void checkSerde(OffsetAndMetadata offsetAndMetadata) throws IOException, ClassNotFoundException { + byte[] bytes = Serializer.serialize(offsetAndMetadata); + OffsetAndMetadata deserialized = (OffsetAndMetadata) Serializer.deserialize(bytes); + assertEquals(offsetAndMetadata, deserialized); + } + + @Test + public void testDeserializationCompatibilityBeforeLeaderEpoch() throws IOException, ClassNotFoundException { + String fileName = "serializedData/offsetAndMetadataBeforeLeaderEpoch"; + Object deserializedObject = Serializer.deserialize(fileName); + assertEquals(new OffsetAndMetadata(10, "test commit metadata"), deserializedObject); + } + + @Test + public void testDeserializationCompatibilityWithLeaderEpoch() throws IOException, ClassNotFoundException { + String fileName = "serializedData/offsetAndMetadataWithLeaderEpoch"; + Object deserializedObject = Serializer.deserialize(fileName); + assertEquals(new OffsetAndMetadata(10, Optional.of(235), "test commit metadata"), deserializedObject); + } + +} diff --git a/clients/src/test/java/org/apache/kafka/clients/consumer/RangeAssignorTest.java b/clients/src/test/java/org/apache/kafka/clients/consumer/RangeAssignorTest.java new file mode 100644 index 0000000..e067e6f --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/clients/consumer/RangeAssignorTest.java @@ -0,0 +1,341 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.clients.consumer; + +import org.apache.kafka.clients.consumer.ConsumerPartitionAssignor.Subscription; +import org.apache.kafka.clients.consumer.internals.AbstractPartitionAssignor; +import org.apache.kafka.clients.consumer.internals.AbstractPartitionAssignor.MemberInfo; +import org.apache.kafka.common.TopicPartition; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class RangeAssignorTest { + + private RangeAssignor assignor = new RangeAssignor(); + + // For plural tests + private String topic1 = "topic1"; + private String topic2 = "topic2"; + private final String consumer1 = "consumer1"; + private final String instance1 = "instance1"; + private final String consumer2 = "consumer2"; + private final String instance2 = "instance2"; + private final String consumer3 = "consumer3"; + private final String instance3 = "instance3"; + + private List staticMemberInfos; + + @BeforeEach + public void setUp() { + staticMemberInfos = new ArrayList<>(); + staticMemberInfos.add(new MemberInfo(consumer1, Optional.of(instance1))); + staticMemberInfos.add(new MemberInfo(consumer2, Optional.of(instance2))); + staticMemberInfos.add(new MemberInfo(consumer3, Optional.of(instance3))); + } + + @Test + public void testOneConsumerNoTopic() { + Map partitionsPerTopic = new HashMap<>(); + + Map> assignment = assignor.assign(partitionsPerTopic, + Collections.singletonMap(consumer1, new Subscription(Collections.emptyList()))); + + assertEquals(Collections.singleton(consumer1), assignment.keySet()); + assertTrue(assignment.get(consumer1).isEmpty()); + } + + @Test + public void testOneConsumerNonexistentTopic() { + Map partitionsPerTopic = new HashMap<>(); + Map> assignment = assignor.assign(partitionsPerTopic, + Collections.singletonMap(consumer1, new Subscription(topics(topic1)))); + assertEquals(Collections.singleton(consumer1), assignment.keySet()); + assertTrue(assignment.get(consumer1).isEmpty()); + } + + @Test + public void testOneConsumerOneTopic() { + Map partitionsPerTopic = new HashMap<>(); + partitionsPerTopic.put(topic1, 3); + + Map> assignment = assignor.assign(partitionsPerTopic, + Collections.singletonMap(consumer1, new Subscription(topics(topic1)))); + + assertEquals(Collections.singleton(consumer1), assignment.keySet()); + assertAssignment(partitions(tp(topic1, 0), tp(topic1, 1), tp(topic1, 2)), assignment.get(consumer1)); + } + + @Test + public void testOnlyAssignsPartitionsFromSubscribedTopics() { + String otherTopic = "other"; + + Map partitionsPerTopic = new HashMap<>(); + partitionsPerTopic.put(topic1, 3); + partitionsPerTopic.put(otherTopic, 3); + + Map> assignment = assignor.assign(partitionsPerTopic, + Collections.singletonMap(consumer1, new Subscription(topics(topic1)))); + assertEquals(Collections.singleton(consumer1), assignment.keySet()); + assertAssignment(partitions(tp(topic1, 0), tp(topic1, 1), tp(topic1, 2)), assignment.get(consumer1)); + } + + @Test + public void testOneConsumerMultipleTopics() { + Map partitionsPerTopic = setupPartitionsPerTopicWithTwoTopics(1, 2); + + Map> assignment = assignor.assign(partitionsPerTopic, + Collections.singletonMap(consumer1, new 
Subscription(topics(topic1, topic2))));
+
+        assertEquals(Collections.singleton(consumer1), assignment.keySet());
+        assertAssignment(partitions(tp(topic1, 0), tp(topic2, 0), tp(topic2, 1)), assignment.get(consumer1));
+    }
+
+    @Test
+    public void testTwoConsumersOneTopicOnePartition() {
+        Map<String, Integer> partitionsPerTopic = new HashMap<>();
+        partitionsPerTopic.put(topic1, 1);
+
+        Map<String, Subscription> consumers = new HashMap<>();
+        consumers.put(consumer1, new Subscription(topics(topic1)));
+        consumers.put(consumer2, new Subscription(topics(topic1)));
+
+        Map<String, List<TopicPartition>> assignment = assignor.assign(partitionsPerTopic, consumers);
+        assertAssignment(partitions(tp(topic1, 0)), assignment.get(consumer1));
+        assertAssignment(Collections.emptyList(), assignment.get(consumer2));
+    }
+
+    @Test
+    public void testTwoConsumersOneTopicTwoPartitions() {
+        Map<String, Integer> partitionsPerTopic = new HashMap<>();
+        partitionsPerTopic.put(topic1, 2);
+
+        Map<String, Subscription> consumers = new HashMap<>();
+        consumers.put(consumer1, new Subscription(topics(topic1)));
+        consumers.put(consumer2, new Subscription(topics(topic1)));
+
+        Map<String, List<TopicPartition>> assignment = assignor.assign(partitionsPerTopic, consumers);
+        assertAssignment(partitions(tp(topic1, 0)), assignment.get(consumer1));
+        assertAssignment(partitions(tp(topic1, 1)), assignment.get(consumer2));
+    }
+
+    @Test
+    public void testMultipleConsumersMixedTopics() {
+        Map<String, Integer> partitionsPerTopic = setupPartitionsPerTopicWithTwoTopics(3, 2);
+
+        Map<String, Subscription> consumers = new HashMap<>();
+        consumers.put(consumer1, new Subscription(topics(topic1)));
+        consumers.put(consumer2, new Subscription(topics(topic1, topic2)));
+        consumers.put(consumer3, new Subscription(topics(topic1)));
+
+        Map<String, List<TopicPartition>> assignment = assignor.assign(partitionsPerTopic, consumers);
+        assertAssignment(partitions(tp(topic1, 0)), assignment.get(consumer1));
+        assertAssignment(partitions(tp(topic1, 1), tp(topic2, 0), tp(topic2, 1)), assignment.get(consumer2));
+        assertAssignment(partitions(tp(topic1, 2)), assignment.get(consumer3));
+    }
+
+    @Test
+    public void testTwoConsumersTwoTopicsSixPartitions() {
+        String topic1 = "topic1";
+        String topic2 = "topic2";
+        String consumer1 = "consumer1";
+        String consumer2 = "consumer2";
+
+        Map<String, Integer> partitionsPerTopic = setupPartitionsPerTopicWithTwoTopics(3, 3);
+
+        Map<String, Subscription> consumers = new HashMap<>();
+        consumers.put(consumer1, new Subscription(topics(topic1, topic2)));
+        consumers.put(consumer2, new Subscription(topics(topic1, topic2)));
+
+        Map<String, List<TopicPartition>> assignment = assignor.assign(partitionsPerTopic, consumers);
+        assertAssignment(partitions(tp(topic1, 0), tp(topic1, 1), tp(topic2, 0), tp(topic2, 1)), assignment.get(consumer1));
+        assertAssignment(partitions(tp(topic1, 2), tp(topic2, 2)), assignment.get(consumer2));
+    }
+
+    @Test
+    public void testTwoStaticConsumersTwoTopicsSixPartitions() {
+        // Although consumerIdHigh ("consumer-a") sorts ahead of consumerIdLow ("consumer-b") by
+        // member id, static members are compared by their group instance id, so instance1 ranks first.
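+        // Illustrative check (assumption: MemberInfo is Comparable and, when both members carry a
+        // group instance id, orders by that id rather than by member id); the locals below exist
+        // only for this illustration:
+        MemberInfo staticLow = new MemberInfo("consumer-b", Optional.of(instance1));
+        MemberInfo staticHigh = new MemberInfo("consumer-a", Optional.of(instance2));
+        assertTrue(staticLow.compareTo(staticHigh) < 0);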
+        String consumerIdLow = "consumer-b";
+        String consumerIdHigh = "consumer-a";
+
+        Map<String, Integer> partitionsPerTopic = setupPartitionsPerTopicWithTwoTopics(3, 3);
+
+        Map<String, Subscription> consumers = new HashMap<>();
+        Subscription consumerLowSubscription = new Subscription(topics(topic1, topic2),
+            null,
+            Collections.emptyList());
+        consumerLowSubscription.setGroupInstanceId(Optional.of(instance1));
+        consumers.put(consumerIdLow, consumerLowSubscription);
+        Subscription consumerHighSubscription = new Subscription(topics(topic1, topic2),
+            null,
+            Collections.emptyList());
+        consumerHighSubscription.setGroupInstanceId(Optional.of(instance2));
+        consumers.put(consumerIdHigh, consumerHighSubscription);
+        Map<String, List<TopicPartition>> assignment = assignor.assign(partitionsPerTopic, consumers);
+        assertAssignment(partitions(tp(topic1, 0), tp(topic1, 1), tp(topic2, 0), tp(topic2, 1)), assignment.get(consumerIdLow));
+        assertAssignment(partitions(tp(topic1, 2), tp(topic2, 2)), assignment.get(consumerIdHigh));
+    }
+
+    @Test
+    public void testOneStaticConsumerAndOneDynamicConsumerTwoTopicsSixPartitions() {
+        // Although consumerIdHigh sorts ahead of consumerIdLow by member id, consumerIdLow wins the
+        // comparison because it has a group instance id while consumerIdHigh does not.
+        String consumerIdLow = "consumer-b";
+        String consumerIdHigh = "consumer-a";
+
+        Map<String, Integer> partitionsPerTopic = setupPartitionsPerTopicWithTwoTopics(3, 3);
+
+        Map<String, Subscription> consumers = new HashMap<>();
+
+        Subscription consumerLowSubscription = new Subscription(topics(topic1, topic2),
+            null,
+            Collections.emptyList());
+        consumerLowSubscription.setGroupInstanceId(Optional.of(instance1));
+        consumers.put(consumerIdLow, consumerLowSubscription);
+        consumers.put(consumerIdHigh, new Subscription(topics(topic1, topic2)));
+
+        Map<String, List<TopicPartition>> assignment = assignor.assign(partitionsPerTopic, consumers);
+        assertAssignment(partitions(tp(topic1, 0), tp(topic1, 1), tp(topic2, 0), tp(topic2, 1)), assignment.get(consumerIdLow));
+        assertAssignment(partitions(tp(topic1, 2), tp(topic2, 2)), assignment.get(consumerIdHigh));
+    }
+
+    @Test
+    public void testStaticMemberRangeAssignmentPersistent() {
+        Map<String, Integer> partitionsPerTopic = setupPartitionsPerTopicWithTwoTopics(5, 4);
+
+        Map<String, Subscription> consumers = new HashMap<>();
+        for (MemberInfo m : staticMemberInfos) {
+            Subscription subscription = new Subscription(topics(topic1, topic2),
+                null,
+                Collections.emptyList());
+            subscription.setGroupInstanceId(m.groupInstanceId);
+            consumers.put(m.memberId, subscription);
+        }
+        // Consumer 4 is a dynamic member.
+        String consumer4 = "consumer4";
+        consumers.put(consumer4, new Subscription(topics(topic1, topic2)));
+
+        Map<String, List<TopicPartition>> expectedAssignment = new HashMap<>();
+        // The three static members (instance1, instance2, instance3) are persistent across
+        // generations, so their assignments stay the same.
+        expectedAssignment.put(consumer1, partitions(tp(topic1, 0), tp(topic1, 1), tp(topic2, 0)));
+        expectedAssignment.put(consumer2, partitions(tp(topic1, 2), tp(topic2, 1)));
+        expectedAssignment.put(consumer3, partitions(tp(topic1, 3), tp(topic2, 2)));
+        expectedAssignment.put(consumer4, partitions(tp(topic1, 4), tp(topic2, 3)));
+
+        Map<String, List<TopicPartition>> assignment = assignor.assign(partitionsPerTopic, consumers);
+        assertEquals(expectedAssignment, assignment);
+
+        // Replace dynamic member 4 with a new dynamic member 5.
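+        // Rough sketch of why the static members keep their ranges (assuming the range assignor
+        // sorts members with a group instance id ahead of dynamic members and hands out contiguous
+        // partition blocks per topic in that order):
+        //   topic1: 5 partitions / 4 members -> blocks of 2, 1, 1, 1 -> instance1 gets p0-p1,
+        //           instance2 gets p2, instance3 gets p3, the dynamic member gets p4
+        //   topic2: 4 partitions / 4 members -> 1 partition each, in the same order
+        // Swapping consumer4 for consumer5 below only renames the last (dynamic) slot, so the
+        // static members' assignment is unchanged.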
+ consumers.remove(consumer4); + String consumer5 = "consumer5"; + consumers.put(consumer5, new Subscription(topics(topic1, topic2))); + + expectedAssignment.remove(consumer4); + expectedAssignment.put(consumer5, partitions(tp(topic1, 4), tp(topic2, 3))); + assignment = assignor.assign(partitionsPerTopic, consumers); + assertEquals(expectedAssignment, assignment); + } + + @Test + public void testStaticMemberRangeAssignmentPersistentAfterMemberIdChanges() { + Map partitionsPerTopic = setupPartitionsPerTopicWithTwoTopics(5, 5); + + Map consumers = new HashMap<>(); + for (MemberInfo m : staticMemberInfos) { + Subscription subscription = new Subscription(topics(topic1, topic2), + null, + Collections.emptyList()); + subscription.setGroupInstanceId(m.groupInstanceId); + consumers.put(m.memberId, subscription); + } + Map> expectedInstanceAssignment = new HashMap<>(); + expectedInstanceAssignment.put(instance1, + partitions(tp(topic1, 0), tp(topic1, 1), tp(topic2, 0), tp(topic2, 1))); + expectedInstanceAssignment.put(instance2, + partitions(tp(topic1, 2), tp(topic1, 3), tp(topic2, 2), tp(topic2, 3))); + expectedInstanceAssignment.put(instance3, + partitions(tp(topic1, 4), tp(topic2, 4))); + + Map> staticAssignment = + checkStaticAssignment(assignor, partitionsPerTopic, consumers); + assertEquals(expectedInstanceAssignment, staticAssignment); + + // Now switch the member.id fields for each member info, the assignment should + // stay the same as last time. + String consumer4 = "consumer4"; + String consumer5 = "consumer5"; + consumers.put(consumer4, consumers.get(consumer3)); + consumers.remove(consumer3); + consumers.put(consumer5, consumers.get(consumer2)); + consumers.remove(consumer2); + + Map> newStaticAssignment = + checkStaticAssignment(assignor, partitionsPerTopic, consumers); + assertEquals(staticAssignment, newStaticAssignment); + } + + static Map> checkStaticAssignment(AbstractPartitionAssignor assignor, + Map partitionsPerTopic, + Map consumers) { + Map> assignmentByMemberId = assignor.assign(partitionsPerTopic, consumers); + Map> assignmentByInstanceId = new HashMap<>(); + for (Map.Entry entry : consumers.entrySet()) { + String memberId = entry.getKey(); + Optional instanceId = entry.getValue().groupInstanceId(); + instanceId.ifPresent(id -> assignmentByInstanceId.put(id, assignmentByMemberId.get(memberId))); + } + return assignmentByInstanceId; + } + + private void assertAssignment(List expected, List actual) { + // order doesn't matter for assignment, so convert to a set + assertEquals(new HashSet<>(expected), new HashSet<>(actual)); + } + + private Map setupPartitionsPerTopicWithTwoTopics(int numberOfPartitions1, int numberOfPartitions2) { + Map partitionsPerTopic = new HashMap<>(); + partitionsPerTopic.put(topic1, numberOfPartitions1); + partitionsPerTopic.put(topic2, numberOfPartitions2); + return partitionsPerTopic; + } + + private static List topics(String... topics) { + return Arrays.asList(topics); + } + + private static List partitions(TopicPartition... 
partitions) { + return Arrays.asList(partitions); + } + + private static TopicPartition tp(String topic, int partition) { + return new TopicPartition(topic, partition); + } +} diff --git a/clients/src/test/java/org/apache/kafka/clients/consumer/RoundRobinAssignorTest.java b/clients/src/test/java/org/apache/kafka/clients/consumer/RoundRobinAssignorTest.java new file mode 100644 index 0000000..19cd68c --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/clients/consumer/RoundRobinAssignorTest.java @@ -0,0 +1,338 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients.consumer; + +import org.apache.kafka.clients.consumer.ConsumerPartitionAssignor.Subscription; +import org.apache.kafka.clients.consumer.internals.AbstractPartitionAssignor; +import org.apache.kafka.clients.consumer.internals.AbstractPartitionAssignor.MemberInfo; +import org.apache.kafka.common.TopicPartition; +import org.junit.jupiter.api.Test; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +import static org.apache.kafka.clients.consumer.RangeAssignorTest.checkStaticAssignment; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class RoundRobinAssignorTest { + + private RoundRobinAssignor assignor = new RoundRobinAssignor(); + private String topic = "topic"; + private String consumerId = "consumer"; + + private String topic1 = "topic1"; + private String topic2 = "topic2"; + + @Test + public void testOneConsumerNoTopic() { + Map partitionsPerTopic = new HashMap<>(); + + Map> assignment = assignor.assign(partitionsPerTopic, + Collections.singletonMap(consumerId, new Subscription(Collections.emptyList()))); + assertEquals(Collections.singleton(consumerId), assignment.keySet()); + assertTrue(assignment.get(consumerId).isEmpty()); + } + + @Test + public void testOneConsumerNonexistentTopic() { + Map partitionsPerTopic = new HashMap<>(); + Map> assignment = assignor.assign(partitionsPerTopic, + Collections.singletonMap(consumerId, new Subscription(topics(topic)))); + + assertEquals(Collections.singleton(consumerId), assignment.keySet()); + assertTrue(assignment.get(consumerId).isEmpty()); + } + + @Test + public void testOneConsumerOneTopic() { + Map partitionsPerTopic = new HashMap<>(); + partitionsPerTopic.put(topic, 3); + + Map> assignment = assignor.assign(partitionsPerTopic, + Collections.singletonMap(consumerId, new Subscription(topics(topic)))); + assertEquals(partitions(tp(topic, 0), tp(topic, 1), tp(topic, 2)), assignment.get(consumerId)); + } + + @Test + public void testOnlyAssignsPartitionsFromSubscribedTopics() { + 
String otherTopic = "other"; + + Map partitionsPerTopic = new HashMap<>(); + partitionsPerTopic.put(topic, 3); + partitionsPerTopic.put(otherTopic, 3); + + Map> assignment = assignor.assign(partitionsPerTopic, + Collections.singletonMap(consumerId, new Subscription(topics(topic)))); + assertEquals(partitions(tp(topic, 0), tp(topic, 1), tp(topic, 2)), assignment.get(consumerId)); + } + + @Test + public void testOneConsumerMultipleTopics() { + Map partitionsPerTopic = setupPartitionsPerTopicWithTwoTopics(1, 2); + + Map> assignment = assignor.assign(partitionsPerTopic, + Collections.singletonMap(consumerId, new Subscription(topics(topic1, topic2)))); + assertEquals(partitions(tp(topic1, 0), tp(topic2, 0), tp(topic2, 1)), assignment.get(consumerId)); + } + + @Test + public void testTwoConsumersOneTopicOnePartition() { + String consumer1 = "consumer1"; + String consumer2 = "consumer2"; + + Map partitionsPerTopic = new HashMap<>(); + partitionsPerTopic.put(topic, 1); + + Map consumers = new HashMap<>(); + consumers.put(consumer1, new Subscription(topics(topic))); + consumers.put(consumer2, new Subscription(topics(topic))); + + Map> assignment = assignor.assign(partitionsPerTopic, consumers); + assertEquals(partitions(tp(topic, 0)), assignment.get(consumer1)); + assertEquals(Collections.emptyList(), assignment.get(consumer2)); + } + + @Test + public void testTwoConsumersOneTopicTwoPartitions() { + String consumer1 = "consumer1"; + String consumer2 = "consumer2"; + + Map partitionsPerTopic = new HashMap<>(); + partitionsPerTopic.put(topic, 2); + + Map consumers = new HashMap<>(); + consumers.put(consumer1, new Subscription(topics(topic))); + consumers.put(consumer2, new Subscription(topics(topic))); + + Map> assignment = assignor.assign(partitionsPerTopic, consumers); + assertEquals(partitions(tp(topic, 0)), assignment.get(consumer1)); + assertEquals(partitions(tp(topic, 1)), assignment.get(consumer2)); + } + + @Test + public void testMultipleConsumersMixedTopics() { + String topic1 = "topic1"; + String topic2 = "topic2"; + String consumer1 = "consumer1"; + String consumer2 = "consumer2"; + String consumer3 = "consumer3"; + + Map partitionsPerTopic = setupPartitionsPerTopicWithTwoTopics(3, 2); + + Map consumers = new HashMap<>(); + consumers.put(consumer1, new Subscription(topics(topic1))); + consumers.put(consumer2, new Subscription(topics(topic1, topic2))); + consumers.put(consumer3, new Subscription(topics(topic1))); + + Map> assignment = assignor.assign(partitionsPerTopic, consumers); + assertEquals(partitions(tp(topic1, 0)), assignment.get(consumer1)); + assertEquals(partitions(tp(topic1, 1), tp(topic2, 0), tp(topic2, 1)), assignment.get(consumer2)); + assertEquals(partitions(tp(topic1, 2)), assignment.get(consumer3)); + } + + @Test + public void testTwoDynamicConsumersTwoTopicsSixPartitions() { + String topic1 = "topic1"; + String topic2 = "topic2"; + String consumer1 = "consumer1"; + String consumer2 = "consumer2"; + + Map partitionsPerTopic = setupPartitionsPerTopicWithTwoTopics(3, 3); + + Map consumers = new HashMap<>(); + consumers.put(consumer1, new Subscription(topics(topic1, topic2))); + consumers.put(consumer2, new Subscription(topics(topic1, topic2))); + + Map> assignment = assignor.assign(partitionsPerTopic, consumers); + assertEquals(partitions(tp(topic1, 0), tp(topic1, 2), tp(topic2, 1)), assignment.get(consumer1)); + assertEquals(partitions(tp(topic1, 1), tp(topic2, 0), tp(topic2, 2)), assignment.get(consumer2)); + } + + @Test + public void 
testTwoStaticConsumersTwoTopicsSixPartitions() {
+        // Although consumer2 ("consumer-a") sorts ahead of consumer1 ("consumer-b") by member id,
+        // static members are compared by their group instance id, so instance1 ranks first.
+        String topic1 = "topic1";
+        String topic2 = "topic2";
+        String consumer1 = "consumer-b";
+        String instance1 = "instance1";
+        String consumer2 = "consumer-a";
+        String instance2 = "instance2";
+
+        Map<String, Integer> partitionsPerTopic = setupPartitionsPerTopicWithTwoTopics(3, 3);
+
+        Map<String, Subscription> consumers = new HashMap<>();
+        Subscription consumer1Subscription = new Subscription(topics(topic1, topic2), null);
+        consumer1Subscription.setGroupInstanceId(Optional.of(instance1));
+        consumers.put(consumer1, consumer1Subscription);
+        Subscription consumer2Subscription = new Subscription(topics(topic1, topic2), null);
+        consumer2Subscription.setGroupInstanceId(Optional.of(instance2));
+        consumers.put(consumer2, consumer2Subscription);
+        Map<String, List<TopicPartition>> assignment = assignor.assign(partitionsPerTopic, consumers);
+        assertEquals(partitions(tp(topic1, 0), tp(topic1, 2), tp(topic2, 1)), assignment.get(consumer1));
+        assertEquals(partitions(tp(topic1, 1), tp(topic2, 0), tp(topic2, 2)), assignment.get(consumer2));
+    }
+
+    @Test
+    public void testOneStaticConsumerAndOneDynamicConsumerTwoTopicsSixPartitions() {
+        // Although consumer2 sorts ahead of consumer1 by member id, consumer1 wins the comparison
+        // because it has a group instance id while consumer2 does not.
+        String consumer1 = "consumer-b";
+        String instance1 = "instance1";
+        String consumer2 = "consumer-a";
+
+        Map<String, Integer> partitionsPerTopic = setupPartitionsPerTopicWithTwoTopics(3, 3);
+
+        Map<String, Subscription> consumers = new HashMap<>();
+
+        Subscription consumer1Subscription = new Subscription(topics(topic1, topic2), null);
+        consumer1Subscription.setGroupInstanceId(Optional.of(instance1));
+        consumers.put(consumer1, consumer1Subscription);
+        consumers.put(consumer2, new Subscription(topics(topic1, topic2)));
+
+        Map<String, List<TopicPartition>> assignment = assignor.assign(partitionsPerTopic, consumers);
+        assertEquals(partitions(tp(topic1, 0), tp(topic1, 2), tp(topic2, 1)), assignment.get(consumer1));
+        assertEquals(partitions(tp(topic1, 1), tp(topic2, 0), tp(topic2, 2)), assignment.get(consumer2));
+    }
+
+    @Test
+    public void testStaticMemberRoundRobinAssignmentPersistent() {
+        // The three static members (instance1, instance2, instance3) are persistent across
+        // generations, so their assignments stay the same.
+        String consumer1 = "consumer1";
+        String instance1 = "instance1";
+        String consumer2 = "consumer2";
+        String instance2 = "instance2";
+        String consumer3 = "consumer3";
+        String instance3 = "instance3";
+
+        List<MemberInfo> staticMemberInfos = new ArrayList<>();
+        staticMemberInfos.add(new MemberInfo(consumer1, Optional.of(instance1)));
+        staticMemberInfos.add(new MemberInfo(consumer2, Optional.of(instance2)));
+        staticMemberInfos.add(new MemberInfo(consumer3, Optional.of(instance3)));
+
+        // Consumer 4 is a dynamic member.
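+        // Rough sketch of the expected layout (assuming members are ordered as instance1, instance2,
+        // instance3, then the dynamic member, and that topic1 p0-p2 followed by topic2 p0-p2 are
+        // dealt out one partition per member in a circle):
+        //   t1p0 -> consumer1, t1p1 -> consumer2, t1p2 -> consumer3,
+        //   t2p0 -> consumer4, t2p1 -> consumer1, t2p2 -> consumer2
+        // Replacing consumer4 with consumer5 later in the test only renames the dynamic slot.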
+ String consumer4 = "consumer4"; + + Map partitionsPerTopic = setupPartitionsPerTopicWithTwoTopics(3, 3); + + Map consumers = new HashMap<>(); + for (MemberInfo m : staticMemberInfos) { + Subscription subscription = new Subscription(topics(topic1, topic2), null); + subscription.setGroupInstanceId(m.groupInstanceId); + consumers.put(m.memberId, subscription); + } + consumers.put(consumer4, new Subscription(topics(topic1, topic2))); + + Map> expectedAssignment = new HashMap<>(); + expectedAssignment.put(consumer1, partitions(tp(topic1, 0), tp(topic2, 1))); + expectedAssignment.put(consumer2, partitions(tp(topic1, 1), tp(topic2, 2))); + expectedAssignment.put(consumer3, partitions(tp(topic1, 2))); + expectedAssignment.put(consumer4, partitions(tp(topic2, 0))); + + Map> assignment = assignor.assign(partitionsPerTopic, consumers); + assertEquals(expectedAssignment, assignment); + + // Replace dynamic member 4 with a new dynamic member 5. + consumers.remove(consumer4); + String consumer5 = "consumer5"; + consumers.put(consumer5, new Subscription(topics(topic1, topic2))); + + expectedAssignment.remove(consumer4); + expectedAssignment.put(consumer5, partitions(tp(topic2, 0))); + assignment = assignor.assign(partitionsPerTopic, consumers); + assertEquals(expectedAssignment, assignment); + } + + @Test + public void testStaticMemberRoundRobinAssignmentPersistentAfterMemberIdChanges() { + String consumer1 = "consumer1"; + String instance1 = "instance1"; + String consumer2 = "consumer2"; + String instance2 = "instance2"; + String consumer3 = "consumer3"; + String instance3 = "instance3"; + Map memberIdToInstanceId = new HashMap<>(); + memberIdToInstanceId.put(consumer1, instance1); + memberIdToInstanceId.put(consumer2, instance2); + memberIdToInstanceId.put(consumer3, instance3); + + Map partitionsPerTopic = setupPartitionsPerTopicWithTwoTopics(5, 5); + + Map> expectedInstanceAssignment = new HashMap<>(); + expectedInstanceAssignment.put(instance1, + partitions(tp(topic1, 0), tp(topic1, 3), tp(topic2, 1), tp(topic2, 4))); + expectedInstanceAssignment.put(instance2, + partitions(tp(topic1, 1), tp(topic1, 4), tp(topic2, 2))); + expectedInstanceAssignment.put(instance3, + partitions(tp(topic1, 2), tp(topic2, 0), tp(topic2, 3))); + + List staticMemberInfos = new ArrayList<>(); + for (Map.Entry entry : memberIdToInstanceId.entrySet()) { + staticMemberInfos.add(new AbstractPartitionAssignor.MemberInfo(entry.getKey(), Optional.of(entry.getValue()))); + } + Map consumers = new HashMap<>(); + for (MemberInfo m : staticMemberInfos) { + Subscription subscription = new Subscription(topics(topic1, topic2), null); + subscription.setGroupInstanceId(m.groupInstanceId); + consumers.put(m.memberId, subscription); + } + + Map> staticAssignment = + checkStaticAssignment(assignor, partitionsPerTopic, consumers); + assertEquals(expectedInstanceAssignment, staticAssignment); + + memberIdToInstanceId.clear(); + + // Now switch the member.id fields for each member info, the assignment should + // stay the same as last time. + String consumer4 = "consumer4"; + String consumer5 = "consumer5"; + consumers.put(consumer4, consumers.get(consumer3)); + consumers.remove(consumer3); + consumers.put(consumer5, consumers.get(consumer2)); + consumers.remove(consumer2); + Map> newStaticAssignment = + checkStaticAssignment(assignor, partitionsPerTopic, consumers); + assertEquals(staticAssignment, newStaticAssignment); + } + + private static List topics(String... 
topics) { + return Arrays.asList(topics); + } + + private static List partitions(TopicPartition... partitions) { + return Arrays.asList(partitions); + } + + private static TopicPartition tp(String topic, int partition) { + return new TopicPartition(topic, partition); + } + + private Map setupPartitionsPerTopicWithTwoTopics(int numberOfPartitions1, int numberOfPartitions2) { + Map partitionsPerTopic = new HashMap<>(); + partitionsPerTopic.put(topic1, numberOfPartitions1); + partitionsPerTopic.put(topic2, numberOfPartitions2); + return partitionsPerTopic; + } +} diff --git a/clients/src/test/java/org/apache/kafka/clients/consumer/StickyAssignorTest.java b/clients/src/test/java/org/apache/kafka/clients/consumer/StickyAssignorTest.java new file mode 100644 index 0000000..bb03de2 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/clients/consumer/StickyAssignorTest.java @@ -0,0 +1,291 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients.consumer; + +import static org.apache.kafka.clients.consumer.StickyAssignor.serializeTopicPartitionAssignment; +import static org.apache.kafka.clients.consumer.internals.AbstractStickyAssignor.DEFAULT_GENERATION; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import org.apache.kafka.clients.consumer.ConsumerPartitionAssignor.Subscription; +import org.apache.kafka.clients.consumer.internals.AbstractStickyAssignor; +import org.apache.kafka.clients.consumer.internals.AbstractStickyAssignor.MemberData; +import org.apache.kafka.clients.consumer.internals.AbstractStickyAssignorTest; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.protocol.types.Struct; +import org.apache.kafka.common.utils.CollectionUtils; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; + +import static java.util.Collections.emptyList; + +public class StickyAssignorTest extends AbstractStickyAssignorTest { + + @Override + public AbstractStickyAssignor createAssignor() { + return new StickyAssignor(); + } + + @Override + public Subscription buildSubscription(List topics, List partitions) { + return new Subscription(topics, + serializeTopicPartitionAssignment(new MemberData(partitions, Optional.of(DEFAULT_GENERATION)))); + } + + @Override + public Subscription buildSubscriptionWithGeneration(List topics, List partitions, int generation) { + return new 
Subscription(topics, + serializeTopicPartitionAssignment(new MemberData(partitions, Optional.of(generation)))); + } + + @Test + public void testAllConsumersHaveOwnedPartitionInvalidatedWhenClaimedByMultipleConsumersInSameGenerationWithEqualPartitionsPerConsumer() { + Map partitionsPerTopic = new HashMap<>(); + partitionsPerTopic.put(topic, 3); + + subscriptions.put(consumer1, buildSubscription(topics(topic), partitions(tp(topic, 0), tp(topic, 1)))); + subscriptions.put(consumer2, buildSubscription(topics(topic), partitions(tp(topic, 0), tp(topic, 2)))); + subscriptions.put(consumer3, buildSubscription(topics(topic), emptyList())); + + Map> assignment = assignor.assign(partitionsPerTopic, subscriptions); + assertEquals(partitions(tp(topic, 1)), assignment.get(consumer1)); + assertEquals(partitions(tp(topic, 2)), assignment.get(consumer2)); + assertEquals(partitions(tp(topic, 0)), assignment.get(consumer3)); + + verifyValidityAndBalance(subscriptions, assignment, partitionsPerTopic); + assertTrue(isFullyBalanced(assignment)); + } + + @Test + public void testAllConsumersHaveOwnedPartitionInvalidatedWhenClaimedByMultipleConsumersInSameGenerationWithUnequalPartitionsPerConsumer() { + Map partitionsPerTopic = new HashMap<>(); + partitionsPerTopic.put(topic, 4); + + subscriptions.put(consumer1, buildSubscription(topics(topic), partitions(tp(topic, 0), tp(topic, 1)))); + subscriptions.put(consumer2, buildSubscription(topics(topic), partitions(tp(topic, 0), tp(topic, 2)))); + subscriptions.put(consumer3, buildSubscription(topics(topic), emptyList())); + + Map> assignment = assignor.assign(partitionsPerTopic, subscriptions); + assertEquals(partitions(tp(topic, 1), tp(topic, 3)), assignment.get(consumer1)); + assertEquals(partitions(tp(topic, 2)), assignment.get(consumer2)); + assertEquals(partitions(tp(topic, 0)), assignment.get(consumer3)); + + verifyValidityAndBalance(subscriptions, assignment, partitionsPerTopic); + assertTrue(isFullyBalanced(assignment)); + } + + @ParameterizedTest(name = "testAssignmentWithMultipleGenerations1 with isAllSubscriptionsEqual: {0}") + @ValueSource(booleans = {true, false}) + public void testAssignmentWithMultipleGenerations1(boolean isAllSubscriptionsEqual) { + List allTopics = topics(topic, topic2); + List consumer2SubscribedTopics = isAllSubscriptionsEqual ? 
allTopics : topics(topic); + + Map partitionsPerTopic = new HashMap<>(); + partitionsPerTopic.put(topic, 6); + partitionsPerTopic.put(topic2, 6); + subscriptions.put(consumer1, new Subscription(allTopics)); + subscriptions.put(consumer2, new Subscription(consumer2SubscribedTopics)); + subscriptions.put(consumer3, new Subscription(allTopics)); + + Map> assignment = assignor.assign(partitionsPerTopic, subscriptions); + List r1partitions1 = assignment.get(consumer1); + List r1partitions2 = assignment.get(consumer2); + List r1partitions3 = assignment.get(consumer3); + assertTrue(r1partitions1.size() == 4 && r1partitions2.size() == 4 && r1partitions3.size() == 4); + verifyValidityAndBalance(subscriptions, assignment, partitionsPerTopic); + assertTrue(isFullyBalanced(assignment)); + + subscriptions.put(consumer1, buildSubscription(allTopics, r1partitions1)); + subscriptions.put(consumer2, buildSubscription(consumer2SubscribedTopics, r1partitions2)); + subscriptions.remove(consumer3); + + assignment = assignor.assign(partitionsPerTopic, subscriptions); + List r2partitions1 = assignment.get(consumer1); + List r2partitions2 = assignment.get(consumer2); + assertTrue(r2partitions1.size() == 6 && r2partitions2.size() == 6); + if (isAllSubscriptionsEqual) { + // only true in all subscription equal case + assertTrue(r2partitions1.containsAll(r1partitions1)); + } + assertTrue(r2partitions2.containsAll(r1partitions2)); + verifyValidityAndBalance(subscriptions, assignment, partitionsPerTopic); + assertTrue(isFullyBalanced(assignment)); + assertFalse(Collections.disjoint(r2partitions2, r1partitions3)); + + subscriptions.remove(consumer1); + subscriptions.put(consumer2, buildSubscriptionWithGeneration(consumer2SubscribedTopics, r2partitions2, 2)); + subscriptions.put(consumer3, buildSubscriptionWithGeneration(allTopics, r1partitions3, 1)); + + assignment = assignor.assign(partitionsPerTopic, subscriptions); + List r3partitions2 = assignment.get(consumer2); + List r3partitions3 = assignment.get(consumer3); + assertTrue(r3partitions2.size() == 6 && r3partitions3.size() == 6); + assertTrue(Collections.disjoint(r3partitions2, r3partitions3)); + verifyValidityAndBalance(subscriptions, assignment, partitionsPerTopic); + assertTrue(isFullyBalanced(assignment)); + } + + @ParameterizedTest(name = "testAssignmentWithMultipleGenerations2 with isAllSubscriptionsEqual: {0}") + @ValueSource(booleans = {true, false}) + public void testAssignmentWithMultipleGenerations2(boolean isAllSubscriptionsEqual) { + List allTopics = topics(topic, topic2, topic3); + List consumer1SubscribedTopics = isAllSubscriptionsEqual ? allTopics : topics(topic); + List consumer3SubscribedTopics = isAllSubscriptionsEqual ? 
allTopics : topics(topic, topic2); + + Map partitionsPerTopic = new HashMap<>(); + partitionsPerTopic.put(topic, 4); + partitionsPerTopic.put(topic2, 4); + partitionsPerTopic.put(topic3, 4); + subscriptions.put(consumer1, new Subscription(consumer1SubscribedTopics)); + subscriptions.put(consumer2, new Subscription(allTopics)); + subscriptions.put(consumer3, new Subscription(consumer3SubscribedTopics)); + + Map> assignment = assignor.assign(partitionsPerTopic, subscriptions); + List r1partitions1 = assignment.get(consumer1); + List r1partitions2 = assignment.get(consumer2); + List r1partitions3 = assignment.get(consumer3); + assertTrue(r1partitions1.size() == 4 && r1partitions2.size() == 4 && r1partitions3.size() == 4); + verifyValidityAndBalance(subscriptions, assignment, partitionsPerTopic); + assertTrue(isFullyBalanced(assignment)); + + subscriptions.remove(consumer1); + subscriptions.put(consumer2, buildSubscriptionWithGeneration(allTopics, r1partitions2, 1)); + subscriptions.remove(consumer3); + + assignment = assignor.assign(partitionsPerTopic, subscriptions); + List r2partitions2 = assignment.get(consumer2); + assertEquals(12, r2partitions2.size()); + assertTrue(r2partitions2.containsAll(r1partitions2)); + verifyValidityAndBalance(subscriptions, assignment, partitionsPerTopic); + assertTrue(isFullyBalanced(assignment)); + + subscriptions.put(consumer1, buildSubscriptionWithGeneration(consumer1SubscribedTopics, r1partitions1, 1)); + subscriptions.put(consumer2, buildSubscriptionWithGeneration(allTopics, r2partitions2, 2)); + subscriptions.put(consumer3, buildSubscriptionWithGeneration(consumer3SubscribedTopics, r1partitions3, 1)); + + assignment = assignor.assign(partitionsPerTopic, subscriptions); + List r3partitions1 = assignment.get(consumer1); + List r3partitions2 = assignment.get(consumer2); + List r3partitions3 = assignment.get(consumer3); + assertTrue(r3partitions1.size() == 4 && r3partitions2.size() == 4 && r3partitions3.size() == 4); + verifyValidityAndBalance(subscriptions, assignment, partitionsPerTopic); + assertTrue(isFullyBalanced(assignment)); + } + + @ParameterizedTest(name = "testAssignmentWithConflictingPreviousGenerations with isAllSubscriptionsEqual: {0}") + @ValueSource(booleans = {true, false}) + public void testAssignmentWithConflictingPreviousGenerations(boolean isAllSubscriptionsEqual) { + Map partitionsPerTopic = new HashMap<>(); + partitionsPerTopic.put(topic, 4); + partitionsPerTopic.put(topic2, 4); + partitionsPerTopic.put(topic3, 4); + + List allTopics = topics(topic, topic2, topic3); + List consumer1SubscribedTopics = isAllSubscriptionsEqual ? allTopics : topics(topic); + List consumer2SubscribedTopics = isAllSubscriptionsEqual ? 
allTopics : topics(topic, topic2); + + subscriptions.put(consumer1, new Subscription(consumer1SubscribedTopics)); + subscriptions.put(consumer2, new Subscription(consumer2SubscribedTopics)); + subscriptions.put(consumer3, new Subscription(allTopics)); + + TopicPartition tp0 = new TopicPartition(topic, 0); + TopicPartition tp1 = new TopicPartition(topic, 1); + TopicPartition tp2 = new TopicPartition(topic, 2); + TopicPartition tp3 = new TopicPartition(topic, 3); + TopicPartition t2p0 = new TopicPartition(topic2, 0); + TopicPartition t2p1 = new TopicPartition(topic2, 1); + TopicPartition t2p2 = new TopicPartition(topic2, 2); + TopicPartition t2p3 = new TopicPartition(topic2, 3); + TopicPartition t3p0 = new TopicPartition(topic3, 0); + TopicPartition t3p1 = new TopicPartition(topic3, 1); + TopicPartition t3p2 = new TopicPartition(topic3, 2); + TopicPartition t3p3 = new TopicPartition(topic3, 3); + + List c1partitions0 = isAllSubscriptionsEqual ? partitions(tp0, tp1, tp2, t2p2, t2p3, t3p0) : + partitions(tp0, tp1, tp2, tp3); + List c2partitions0 = partitions(tp0, tp1, t2p0, t2p1, t2p2, t2p3); + List c3partitions0 = partitions(tp2, tp3, t3p0, t3p1, t3p2, t3p3); + subscriptions.put(consumer1, buildSubscriptionWithGeneration(consumer1SubscribedTopics, c1partitions0, 1)); + subscriptions.put(consumer2, buildSubscriptionWithGeneration(consumer2SubscribedTopics, c2partitions0, 2)); + subscriptions.put(consumer3, buildSubscriptionWithGeneration(allTopics, c3partitions0, 2)); + + Map> assignment = assignor.assign(partitionsPerTopic, subscriptions); + List c1partitions = assignment.get(consumer1); + List c2partitions = assignment.get(consumer2); + List c3partitions = assignment.get(consumer3); + + assertTrue(c1partitions.size() == 4 && c2partitions.size() == 4 && c3partitions.size() == 4); + assertTrue(c2partitions0.containsAll(c2partitions)); + assertTrue(c3partitions0.containsAll(c3partitions)); + verifyValidityAndBalance(subscriptions, assignment, partitionsPerTopic); + assertTrue(isFullyBalanced(assignment)); + } + + @Test + public void testSchemaBackwardCompatibility() { + Map partitionsPerTopic = new HashMap<>(); + partitionsPerTopic.put(topic, 3); + subscriptions.put(consumer1, new Subscription(topics(topic))); + subscriptions.put(consumer2, new Subscription(topics(topic))); + subscriptions.put(consumer3, new Subscription(topics(topic))); + + TopicPartition tp0 = new TopicPartition(topic, 0); + TopicPartition tp1 = new TopicPartition(topic, 1); + TopicPartition tp2 = new TopicPartition(topic, 2); + + List c1partitions0 = partitions(tp0, tp2); + List c2partitions0 = partitions(tp1); + subscriptions.put(consumer1, buildSubscriptionWithGeneration(topics(topic), c1partitions0, 1)); + subscriptions.put(consumer2, buildSubscriptionWithOldSchema(topics(topic), c2partitions0)); + Map> assignment = assignor.assign(partitionsPerTopic, subscriptions); + List c1partitions = assignment.get(consumer1); + List c2partitions = assignment.get(consumer2); + List c3partitions = assignment.get(consumer3); + + assertTrue(c1partitions.size() == 1 && c2partitions.size() == 1 && c3partitions.size() == 1); + assertTrue(c1partitions0.containsAll(c1partitions)); + assertTrue(c2partitions0.containsAll(c2partitions)); + verifyValidityAndBalance(subscriptions, assignment, partitionsPerTopic); + assertTrue(isFullyBalanced(assignment)); + } + + private static Subscription buildSubscriptionWithOldSchema(List topics, List partitions) { + Struct struct = new Struct(StickyAssignor.STICKY_ASSIGNOR_USER_DATA_V0); + List 
topicAssignments = new ArrayList<>(); + for (Map.Entry> topicEntry : CollectionUtils.groupPartitionsByTopic(partitions).entrySet()) { + Struct topicAssignment = new Struct(StickyAssignor.TOPIC_ASSIGNMENT); + topicAssignment.set(StickyAssignor.TOPIC_KEY_NAME, topicEntry.getKey()); + topicAssignment.set(StickyAssignor.PARTITIONS_KEY_NAME, topicEntry.getValue().toArray()); + topicAssignments.add(topicAssignment); + } + struct.set(StickyAssignor.TOPIC_PARTITIONS_KEY_NAME, topicAssignments.toArray()); + ByteBuffer buffer = ByteBuffer.allocate(StickyAssignor.STICKY_ASSIGNOR_USER_DATA_V0.sizeOf(struct)); + StickyAssignor.STICKY_ASSIGNOR_USER_DATA_V0.write(buffer, struct); + buffer.flip(); + + return new Subscription(topics, buffer); + } +} diff --git a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/AbstractCoordinatorTest.java b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/AbstractCoordinatorTest.java new file mode 100644 index 0000000..384ba91 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/AbstractCoordinatorTest.java @@ -0,0 +1,1589 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.clients.consumer.internals; + +import org.apache.kafka.clients.GroupRebalanceConfig; +import org.apache.kafka.clients.MockClient; +import org.apache.kafka.clients.consumer.OffsetResetStrategy; +import org.apache.kafka.common.Node; +import org.apache.kafka.common.errors.AuthenticationException; +import org.apache.kafka.common.errors.DisconnectException; +import org.apache.kafka.common.errors.FencedInstanceIdException; +import org.apache.kafka.common.errors.InconsistentGroupProtocolException; +import org.apache.kafka.common.errors.UnknownMemberIdException; +import org.apache.kafka.common.errors.WakeupException; +import org.apache.kafka.common.internals.ClusterResourceListeners; +import org.apache.kafka.common.message.HeartbeatResponseData; +import org.apache.kafka.common.message.JoinGroupRequestData; +import org.apache.kafka.common.message.JoinGroupResponseData; +import org.apache.kafka.common.message.LeaveGroupResponseData; +import org.apache.kafka.common.message.LeaveGroupResponseData.MemberResponse; +import org.apache.kafka.common.message.SyncGroupRequestData; +import org.apache.kafka.common.message.SyncGroupResponseData; +import org.apache.kafka.common.metrics.KafkaMetric; +import org.apache.kafka.common.metrics.Metrics; +import org.apache.kafka.common.protocol.Errors; +import org.apache.kafka.common.requests.AbstractRequest; +import org.apache.kafka.common.requests.FindCoordinatorResponse; +import org.apache.kafka.common.requests.HeartbeatRequest; +import org.apache.kafka.common.requests.HeartbeatResponse; +import org.apache.kafka.common.requests.JoinGroupRequest; +import org.apache.kafka.common.requests.JoinGroupResponse; +import org.apache.kafka.common.requests.LeaveGroupRequest; +import org.apache.kafka.common.requests.LeaveGroupResponse; +import org.apache.kafka.common.requests.RequestTestUtils; +import org.apache.kafka.common.requests.SyncGroupRequest; +import org.apache.kafka.common.requests.SyncGroupResponse; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.common.utils.Timer; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.test.TestUtils; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Test; + +import java.nio.ByteBuffer; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; + +import static java.util.Collections.emptyMap; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNotSame; +import static org.junit.jupiter.api.Assertions.assertSame; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; + +public class AbstractCoordinatorTest { + private static final ByteBuffer EMPTY_DATA = ByteBuffer.wrap(new byte[0]); + private static final int REBALANCE_TIMEOUT_MS = 60000; + private static final int SESSION_TIMEOUT_MS = 10000; + 
private static final int HEARTBEAT_INTERVAL_MS = 3000; + private static final int RETRY_BACKOFF_MS = 100; + private static final int REQUEST_TIMEOUT_MS = 40000; + private static final String GROUP_ID = "dummy-group"; + private static final String METRIC_GROUP_PREFIX = "consumer"; + private static final String PROTOCOL_TYPE = "dummy"; + private static final String PROTOCOL_NAME = "dummy-subprotocol"; + + private Node node; + private Metrics metrics; + private MockTime mockTime; + private Node coordinatorNode; + private MockClient mockClient; + private DummyCoordinator coordinator; + private ConsumerNetworkClient consumerClient; + + private final String memberId = "memberId"; + private final String leaderId = "leaderId"; + private final int defaultGeneration = -1; + + @AfterEach + public void closeCoordinator() { + Utils.closeQuietly(coordinator, "close coordinator"); + Utils.closeQuietly(consumerClient, "close consumer client"); + } + + private void setupCoordinator() { + setupCoordinator(RETRY_BACKOFF_MS, REBALANCE_TIMEOUT_MS, + Optional.empty()); + } + + private void setupCoordinator(int retryBackoffMs) { + setupCoordinator(retryBackoffMs, REBALANCE_TIMEOUT_MS, + Optional.empty()); + } + + private void setupCoordinator(int retryBackoffMs, int rebalanceTimeoutMs, Optional groupInstanceId) { + LogContext logContext = new LogContext(); + this.mockTime = new MockTime(); + ConsumerMetadata metadata = new ConsumerMetadata(retryBackoffMs, 60 * 60 * 1000L, + false, false, new SubscriptionState(logContext, OffsetResetStrategy.EARLIEST), + logContext, new ClusterResourceListeners()); + + this.mockClient = new MockClient(mockTime, metadata); + this.consumerClient = new ConsumerNetworkClient(logContext, + mockClient, + metadata, + mockTime, + retryBackoffMs, + REQUEST_TIMEOUT_MS, + HEARTBEAT_INTERVAL_MS); + metrics = new Metrics(mockTime); + + mockClient.updateMetadata(RequestTestUtils.metadataUpdateWith(1, emptyMap())); + this.node = metadata.fetch().nodes().get(0); + this.coordinatorNode = new Node(Integer.MAX_VALUE - node.id(), node.host(), node.port()); + + GroupRebalanceConfig rebalanceConfig = new GroupRebalanceConfig(SESSION_TIMEOUT_MS, + rebalanceTimeoutMs, + HEARTBEAT_INTERVAL_MS, + GROUP_ID, + groupInstanceId, + retryBackoffMs, + !groupInstanceId.isPresent()); + this.coordinator = new DummyCoordinator(rebalanceConfig, + consumerClient, + metrics, + mockTime); + } + + private void joinGroup() { + mockClient.prepareResponse(groupCoordinatorResponse(node, Errors.NONE)); + + final int generation = 1; + + mockClient.prepareResponse(joinGroupFollowerResponse(generation, memberId, JoinGroupRequest.UNKNOWN_MEMBER_ID, Errors.NONE)); + mockClient.prepareResponse(syncGroupResponse(Errors.NONE)); + + coordinator.ensureActiveGroup(); + } + + @Test + public void testMetrics() { + setupCoordinator(); + + assertNotNull(getMetric("heartbeat-response-time-max")); + assertNotNull(getMetric("heartbeat-rate")); + assertNotNull(getMetric("heartbeat-total")); + assertNotNull(getMetric("last-heartbeat-seconds-ago")); + assertNotNull(getMetric("join-time-avg")); + assertNotNull(getMetric("join-time-max")); + assertNotNull(getMetric("join-rate")); + assertNotNull(getMetric("join-total")); + assertNotNull(getMetric("sync-time-avg")); + assertNotNull(getMetric("sync-time-max")); + assertNotNull(getMetric("sync-rate")); + assertNotNull(getMetric("sync-total")); + assertNotNull(getMetric("rebalance-latency-avg")); + assertNotNull(getMetric("rebalance-latency-max")); + assertNotNull(getMetric("rebalance-latency-total")); 
+ assertNotNull(getMetric("rebalance-rate-per-hour")); + assertNotNull(getMetric("rebalance-total")); + assertNotNull(getMetric("last-rebalance-seconds-ago")); + assertNotNull(getMetric("failed-rebalance-rate-per-hour")); + assertNotNull(getMetric("failed-rebalance-total")); + + metrics.sensor("heartbeat-latency").record(1.0d); + metrics.sensor("heartbeat-latency").record(6.0d); + metrics.sensor("heartbeat-latency").record(2.0d); + + assertEquals(6.0d, getMetric("heartbeat-response-time-max").metricValue()); + assertEquals(0.1d, getMetric("heartbeat-rate").metricValue()); + assertEquals(3.0d, getMetric("heartbeat-total").metricValue()); + + assertEquals(-1.0d, getMetric("last-heartbeat-seconds-ago").metricValue()); + coordinator.heartbeat().sentHeartbeat(mockTime.milliseconds()); + assertEquals(0.0d, getMetric("last-heartbeat-seconds-ago").metricValue()); + mockTime.sleep(10 * 1000L); + assertEquals(10.0d, getMetric("last-heartbeat-seconds-ago").metricValue()); + + metrics.sensor("join-latency").record(1.0d); + metrics.sensor("join-latency").record(6.0d); + metrics.sensor("join-latency").record(2.0d); + + assertEquals(3.0d, getMetric("join-time-avg").metricValue()); + assertEquals(6.0d, getMetric("join-time-max").metricValue()); + assertEquals(0.1d, getMetric("join-rate").metricValue()); + assertEquals(3.0d, getMetric("join-total").metricValue()); + + metrics.sensor("sync-latency").record(1.0d); + metrics.sensor("sync-latency").record(6.0d); + metrics.sensor("sync-latency").record(2.0d); + + assertEquals(3.0d, getMetric("sync-time-avg").metricValue()); + assertEquals(6.0d, getMetric("sync-time-max").metricValue()); + assertEquals(0.1d, getMetric("sync-rate").metricValue()); + assertEquals(3.0d, getMetric("sync-total").metricValue()); + + metrics.sensor("rebalance-latency").record(1.0d); + metrics.sensor("rebalance-latency").record(6.0d); + metrics.sensor("rebalance-latency").record(2.0d); + + assertEquals(3.0d, getMetric("rebalance-latency-avg").metricValue()); + assertEquals(6.0d, getMetric("rebalance-latency-max").metricValue()); + assertEquals(9.0d, getMetric("rebalance-latency-total").metricValue()); + assertEquals(360.0d, getMetric("rebalance-rate-per-hour").metricValue()); + assertEquals(3.0d, getMetric("rebalance-total").metricValue()); + + metrics.sensor("failed-rebalance").record(1.0d); + metrics.sensor("failed-rebalance").record(6.0d); + metrics.sensor("failed-rebalance").record(2.0d); + + assertEquals(360.0d, getMetric("failed-rebalance-rate-per-hour").metricValue()); + assertEquals(3.0d, getMetric("failed-rebalance-total").metricValue()); + + assertEquals(-1.0d, getMetric("last-rebalance-seconds-ago").metricValue()); + coordinator.setLastRebalanceTime(mockTime.milliseconds()); + assertEquals(0.0d, getMetric("last-rebalance-seconds-ago").metricValue()); + mockTime.sleep(10 * 1000L); + assertEquals(10.0d, getMetric("last-rebalance-seconds-ago").metricValue()); + } + + private KafkaMetric getMetric(final String name) { + return metrics.metrics().get(metrics.metricName(name, "consumer-coordinator-metrics")); + } + + @Test + public void testCoordinatorDiscoveryBackoff() { + setupCoordinator(); + + mockClient.prepareResponse(groupCoordinatorResponse(node, Errors.NONE)); + mockClient.prepareResponse(groupCoordinatorResponse(node, Errors.NONE)); + + // cut out the coordinator for 10 milliseconds to simulate a disconnect. + // after backing off, we should be able to connect. 
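+        // Expected timing, roughly: while the coordinator node is backed off, the FindCoordinator
+        // attempt cannot complete, so ensureCoordinatorReady waits RETRY_BACKOFF_MS (100 ms of
+        // MockTime) before retrying; the elapsed mock time measured below should therefore be at
+        // least one retry backoff.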
+ mockClient.backoff(coordinatorNode, 10L); + + long initialTime = mockTime.milliseconds(); + coordinator.ensureCoordinatorReady(mockTime.timer(Long.MAX_VALUE)); + long endTime = mockTime.milliseconds(); + + assertTrue(endTime - initialTime >= RETRY_BACKOFF_MS); + } + + @Test + public void testTimeoutAndRetryJoinGroupIfNeeded() throws Exception { + setupCoordinator(); + mockClient.prepareResponse(groupCoordinatorResponse(node, Errors.NONE)); + coordinator.ensureCoordinatorReady(mockTime.timer(0)); + + ExecutorService executor = Executors.newFixedThreadPool(1); + try { + Timer firstAttemptTimer = mockTime.timer(REQUEST_TIMEOUT_MS); + Future firstAttempt = executor.submit(() -> coordinator.joinGroupIfNeeded(firstAttemptTimer)); + + mockTime.sleep(REQUEST_TIMEOUT_MS); + assertFalse(firstAttempt.get()); + assertTrue(consumerClient.hasPendingRequests(coordinatorNode)); + + mockClient.respond(joinGroupFollowerResponse(1, memberId, leaderId, Errors.NONE)); + mockClient.prepareResponse(syncGroupResponse(Errors.NONE)); + + Timer secondAttemptTimer = mockTime.timer(REQUEST_TIMEOUT_MS); + Future secondAttempt = executor.submit(() -> coordinator.joinGroupIfNeeded(secondAttemptTimer)); + + assertTrue(secondAttempt.get()); + } finally { + executor.shutdownNow(); + executor.awaitTermination(1000, TimeUnit.MILLISECONDS); + } + } + + @Test + public void testGroupMaxSizeExceptionIsFatal() { + setupCoordinator(); + mockClient.prepareResponse(groupCoordinatorResponse(node, Errors.NONE)); + coordinator.ensureCoordinatorReady(mockTime.timer(0)); + + mockClient.prepareResponse(joinGroupFollowerResponse(defaultGeneration, memberId, JoinGroupRequest.UNKNOWN_MEMBER_ID, Errors.GROUP_MAX_SIZE_REACHED)); + + RequestFuture future = coordinator.sendJoinGroupRequest(); + assertTrue(consumerClient.poll(future, mockTime.timer(REQUEST_TIMEOUT_MS))); + assertTrue(future.exception().getClass().isInstance(Errors.GROUP_MAX_SIZE_REACHED.exception())); + assertFalse(future.isRetriable()); + } + + @Test + public void testJoinGroupRequestTimeout() { + setupCoordinator(RETRY_BACKOFF_MS, REBALANCE_TIMEOUT_MS, + Optional.empty()); + mockClient.prepareResponse(groupCoordinatorResponse(node, Errors.NONE)); + coordinator.ensureCoordinatorReady(mockTime.timer(0)); + + RequestFuture future = coordinator.sendJoinGroupRequest(); + + mockTime.sleep(REQUEST_TIMEOUT_MS + 1); + assertFalse(consumerClient.poll(future, mockTime.timer(0))); + + mockTime.sleep(REBALANCE_TIMEOUT_MS - REQUEST_TIMEOUT_MS + AbstractCoordinator.JOIN_GROUP_TIMEOUT_LAPSE); + assertTrue(consumerClient.poll(future, mockTime.timer(0))); + assertTrue(future.exception() instanceof DisconnectException); + } + + @Test + public void testJoinGroupRequestTimeoutLowerBoundedByDefaultRequestTimeout() { + int rebalanceTimeoutMs = REQUEST_TIMEOUT_MS - 10000; + setupCoordinator(RETRY_BACKOFF_MS, rebalanceTimeoutMs, Optional.empty()); + mockClient.prepareResponse(groupCoordinatorResponse(node, Errors.NONE)); + coordinator.ensureCoordinatorReady(mockTime.timer(0)); + + RequestFuture future = coordinator.sendJoinGroupRequest(); + + long expectedRequestDeadline = mockTime.milliseconds() + REQUEST_TIMEOUT_MS; + mockTime.sleep(rebalanceTimeoutMs + AbstractCoordinator.JOIN_GROUP_TIMEOUT_LAPSE + 1); + assertFalse(consumerClient.poll(future, mockTime.timer(0))); + + mockTime.sleep(expectedRequestDeadline - mockTime.milliseconds() + 1); + assertTrue(consumerClient.poll(future, mockTime.timer(0))); + assertTrue(future.exception() instanceof DisconnectException); + } + + @Test + public void 
testJoinGroupRequestMaxTimeout() {
+        // Ensure we can handle the maximum allowed rebalance timeout
+
+        setupCoordinator(RETRY_BACKOFF_MS, Integer.MAX_VALUE,
+            Optional.empty());
+        mockClient.prepareResponse(groupCoordinatorResponse(node, Errors.NONE));
+        coordinator.ensureCoordinatorReady(mockTime.timer(0));
+
+        RequestFuture<ByteBuffer> future = coordinator.sendJoinGroupRequest();
+        assertFalse(consumerClient.poll(future, mockTime.timer(0)));
+
+        mockTime.sleep(Integer.MAX_VALUE + 1L);
+        assertTrue(consumerClient.poll(future, mockTime.timer(0)));
+    }
+
+    @Test
+    public void testJoinGroupRequestWithMemberIdRequired() {
+        setupCoordinator();
+        mockClient.prepareResponse(groupCoordinatorResponse(node, Errors.NONE));
+        coordinator.ensureCoordinatorReady(mockTime.timer(0));
+
+        mockClient.prepareResponse(joinGroupFollowerResponse(defaultGeneration, memberId, JoinGroupRequest.UNKNOWN_MEMBER_ID, Errors.MEMBER_ID_REQUIRED));
+
+        mockClient.prepareResponse(body -> {
+            if (!(body instanceof JoinGroupRequest)) {
+                return false;
+            }
+            JoinGroupRequest joinGroupRequest = (JoinGroupRequest) body;
+            return joinGroupRequest.data().memberId().equals(memberId);
+        }, joinGroupResponse(Errors.UNKNOWN_MEMBER_ID));
+
+        RequestFuture<ByteBuffer> future = coordinator.sendJoinGroupRequest();
+        assertTrue(consumerClient.poll(future, mockTime.timer(REQUEST_TIMEOUT_MS)));
+        assertEquals(Errors.MEMBER_ID_REQUIRED.message(), future.exception().getMessage());
+        assertTrue(coordinator.rejoinNeededOrPending());
+        assertTrue(coordinator.hasValidMemberId());
+        assertTrue(coordinator.hasMatchingGenerationId(defaultGeneration));
+        future = coordinator.sendJoinGroupRequest();
+        assertTrue(consumerClient.poll(future, mockTime.timer(REBALANCE_TIMEOUT_MS)));
+    }
+
+    @Test
+    public void testJoinGroupRequestWithFencedInstanceIdException() {
+        setupCoordinator();
+        mockClient.prepareResponse(groupCoordinatorResponse(node, Errors.NONE));
+        coordinator.ensureCoordinatorReady(mockTime.timer(0));
+
+        mockClient.prepareResponse(joinGroupFollowerResponse(defaultGeneration, memberId, JoinGroupRequest.UNKNOWN_MEMBER_ID, Errors.FENCED_INSTANCE_ID));
+
+        RequestFuture<ByteBuffer> future = coordinator.sendJoinGroupRequest();
+        assertTrue(consumerClient.poll(future, mockTime.timer(REQUEST_TIMEOUT_MS)));
+        assertEquals(Errors.FENCED_INSTANCE_ID.message(), future.exception().getMessage());
+        // Make sure the exception is fatal.
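+        // FENCED_INSTANCE_ID indicates another member joined with the same group.instance.id, so
+        // retrying under this instance id cannot succeed; the future should be non-retriable.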
+ assertFalse(future.isRetriable()); + } + + @Test + public void testJoinGroupProtocolTypeAndName() { + final String wrongProtocolType = "wrong-type"; + final String wrongProtocolName = "wrong-name"; + + // No Protocol Type in both JoinGroup and SyncGroup responses + assertTrue(joinGroupWithProtocolTypeAndName(null, null, null)); + + // Protocol Type in both JoinGroup and SyncGroup responses + assertTrue(joinGroupWithProtocolTypeAndName(PROTOCOL_TYPE, PROTOCOL_TYPE, PROTOCOL_NAME)); + + // Wrong protocol type in the JoinGroupResponse + assertThrows(InconsistentGroupProtocolException.class, + () -> joinGroupWithProtocolTypeAndName("wrong", null, null)); + + // Correct protocol type in the JoinGroupResponse + // Wrong protocol type in the SyncGroupResponse + // Correct protocol name in the SyncGroupResponse + assertThrows(InconsistentGroupProtocolException.class, + () -> joinGroupWithProtocolTypeAndName(PROTOCOL_TYPE, wrongProtocolType, PROTOCOL_NAME)); + + // Correct protocol type in the JoinGroupResponse + // Correct protocol type in the SyncGroupResponse + // Wrong protocol name in the SyncGroupResponse + assertThrows(InconsistentGroupProtocolException.class, + () -> joinGroupWithProtocolTypeAndName(PROTOCOL_TYPE, PROTOCOL_TYPE, wrongProtocolName)); + } + + @Test + public void testRetainMemberIdAfterJoinGroupDisconnect() { + setupCoordinator(); + + String memberId = "memberId"; + int generation = 5; + + // Rebalance once to initialize the generation and memberId + mockClient.prepareResponse(groupCoordinatorResponse(node, Errors.NONE)); + expectJoinGroup("", generation, memberId); + expectSyncGroup(generation, memberId); + ensureActiveGroup(generation, memberId); + + // Force a rebalance + coordinator.requestRejoin("Manual test trigger"); + assertTrue(coordinator.rejoinNeededOrPending()); + + // Disconnect during the JoinGroup and ensure that the retry preserves the memberId + int rejoinedGeneration = 10; + expectDisconnectInJoinGroup(memberId); + mockClient.prepareResponse(groupCoordinatorResponse(node, Errors.NONE)); + expectJoinGroup(memberId, rejoinedGeneration, memberId); + expectSyncGroup(rejoinedGeneration, memberId); + ensureActiveGroup(rejoinedGeneration, memberId); + } + + @Test + public void testRetainMemberIdAfterSyncGroupDisconnect() { + setupCoordinator(); + + String memberId = "memberId"; + int generation = 5; + + // Rebalance once to initialize the generation and memberId + mockClient.prepareResponse(groupCoordinatorResponse(node, Errors.NONE)); + expectJoinGroup("", generation, memberId); + expectSyncGroup(generation, memberId); + ensureActiveGroup(generation, memberId); + + // Force a rebalance + coordinator.requestRejoin("Manual test trigger"); + assertTrue(coordinator.rejoinNeededOrPending()); + + // Disconnect during the SyncGroup and ensure that the retry preserves the memberId + int rejoinedGeneration = 10; + expectJoinGroup(memberId, rejoinedGeneration, memberId); + expectDisconnectInSyncGroup(rejoinedGeneration, memberId); + mockClient.prepareResponse(groupCoordinatorResponse(node, Errors.NONE)); + + // Note that the consumer always starts from JoinGroup after a failed rebalance + expectJoinGroup(memberId, rejoinedGeneration, memberId); + expectSyncGroup(rejoinedGeneration, memberId); + ensureActiveGroup(rejoinedGeneration, memberId); + } + + private void ensureActiveGroup( + int generation, + String memberId + ) { + coordinator.ensureActiveGroup(); + assertEquals(generation, coordinator.generation().generationId); + assertEquals(memberId, 
coordinator.generation().memberId); + assertFalse(coordinator.rejoinNeededOrPending()); + } + + private void expectSyncGroup( + int expectedGeneration, + String expectedMemberId + ) { + mockClient.prepareResponse(body -> { + if (!(body instanceof SyncGroupRequest)) { + return false; + } + SyncGroupRequestData syncGroupRequest = ((SyncGroupRequest) body).data(); + return syncGroupRequest.generationId() == expectedGeneration + && syncGroupRequest.memberId().equals(expectedMemberId) + && syncGroupRequest.protocolType().equals(PROTOCOL_TYPE) + && syncGroupRequest.protocolName().equals(PROTOCOL_NAME); + }, syncGroupResponse(Errors.NONE, PROTOCOL_TYPE, PROTOCOL_NAME)); + } + + private void expectDisconnectInSyncGroup( + int expectedGeneration, + String expectedMemberId + ) { + mockClient.prepareResponse(body -> { + if (!(body instanceof SyncGroupRequest)) { + return false; + } + SyncGroupRequestData syncGroupRequest = ((SyncGroupRequest) body).data(); + return syncGroupRequest.generationId() == expectedGeneration + && syncGroupRequest.memberId().equals(expectedMemberId) + && syncGroupRequest.protocolType().equals(PROTOCOL_TYPE) + && syncGroupRequest.protocolName().equals(PROTOCOL_NAME); + }, null, true); + } + + private void expectDisconnectInJoinGroup( + String expectedMemberId + ) { + mockClient.prepareResponse(body -> { + if (!(body instanceof JoinGroupRequest)) { + return false; + } + JoinGroupRequestData joinGroupRequest = ((JoinGroupRequest) body).data(); + return joinGroupRequest.memberId().equals(expectedMemberId) + && joinGroupRequest.protocolType().equals(PROTOCOL_TYPE); + }, null, true); + } + + private void expectJoinGroup( + String expectedMemberId, + int responseGeneration, + String responseMemberId + ) { + JoinGroupResponse response = joinGroupFollowerResponse( + responseGeneration, + responseMemberId, + "leaderId", + Errors.NONE, + PROTOCOL_TYPE + ); + + mockClient.prepareResponse(body -> { + if (!(body instanceof JoinGroupRequest)) { + return false; + } + JoinGroupRequestData joinGroupRequest = ((JoinGroupRequest) body).data(); + return joinGroupRequest.memberId().equals(expectedMemberId) + && joinGroupRequest.protocolType().equals(PROTOCOL_TYPE); + }, response); + } + + @Test + public void testNoGenerationWillNotTriggerProtocolNameCheck() { + final String wrongProtocolName = "wrong-name"; + + setupCoordinator(); + mockClient.reset(); + mockClient.prepareResponse(groupCoordinatorResponse(node, Errors.NONE)); + coordinator.ensureCoordinatorReady(mockTime.timer(0)); + + mockClient.prepareResponse(body -> { + if (!(body instanceof JoinGroupRequest)) { + return false; + } + JoinGroupRequest joinGroupRequest = (JoinGroupRequest) body; + return joinGroupRequest.data().protocolType().equals(PROTOCOL_TYPE); + }, joinGroupFollowerResponse(defaultGeneration, memberId, + "memberid", Errors.NONE, PROTOCOL_TYPE)); + + mockClient.prepareResponse(body -> { + if (!(body instanceof SyncGroupRequest)) { + return false; + } + coordinator.resetGenerationOnLeaveGroup(); + + SyncGroupRequest syncGroupRequest = (SyncGroupRequest) body; + return syncGroupRequest.data().protocolType().equals(PROTOCOL_TYPE) + && syncGroupRequest.data().protocolName().equals(PROTOCOL_NAME); + }, syncGroupResponse(Errors.NONE, PROTOCOL_TYPE, wrongProtocolName)); + + // let the retry to complete successfully to break out of the while loop + mockClient.prepareResponse(body -> { + if (!(body instanceof JoinGroupRequest)) { + return false; + } + JoinGroupRequest joinGroupRequest = (JoinGroupRequest) body; + return 
joinGroupRequest.data().protocolType().equals(PROTOCOL_TYPE); + }, joinGroupFollowerResponse(1, memberId, + "memberid", Errors.NONE, PROTOCOL_TYPE)); + + mockClient.prepareResponse(body -> { + if (!(body instanceof SyncGroupRequest)) { + return false; + } + + SyncGroupRequest syncGroupRequest = (SyncGroupRequest) body; + return syncGroupRequest.data().protocolType().equals(PROTOCOL_TYPE) + && syncGroupRequest.data().protocolName().equals(PROTOCOL_NAME); + }, syncGroupResponse(Errors.NONE, PROTOCOL_TYPE, PROTOCOL_NAME)); + + // No exception shall be thrown as the generation is reset. + coordinator.joinGroupIfNeeded(mockTime.timer(100L)); + } + + private boolean joinGroupWithProtocolTypeAndName(String joinGroupResponseProtocolType, + String syncGroupResponseProtocolType, + String syncGroupResponseProtocolName) { + setupCoordinator(); + mockClient.reset(); + mockClient.prepareResponse(groupCoordinatorResponse(node, Errors.NONE)); + coordinator.ensureCoordinatorReady(mockTime.timer(0)); + + mockClient.prepareResponse(body -> { + if (!(body instanceof JoinGroupRequest)) { + return false; + } + JoinGroupRequest joinGroupRequest = (JoinGroupRequest) body; + return joinGroupRequest.data().protocolType().equals(PROTOCOL_TYPE); + }, joinGroupFollowerResponse(defaultGeneration, memberId, + "memberid", Errors.NONE, joinGroupResponseProtocolType)); + + mockClient.prepareResponse(body -> { + if (!(body instanceof SyncGroupRequest)) { + return false; + } + SyncGroupRequest syncGroupRequest = (SyncGroupRequest) body; + return syncGroupRequest.data().protocolType().equals(PROTOCOL_TYPE) + && syncGroupRequest.data().protocolName().equals(PROTOCOL_NAME); + }, syncGroupResponse(Errors.NONE, syncGroupResponseProtocolType, syncGroupResponseProtocolName)); + + return coordinator.joinGroupIfNeeded(mockTime.timer(5000L)); + } + + @Test + public void testSyncGroupRequestWithFencedInstanceIdException() { + setupCoordinator(); + mockClient.prepareResponse(groupCoordinatorResponse(node, Errors.NONE)); + + final int generation = -1; + + mockClient.prepareResponse(joinGroupFollowerResponse(generation, memberId, JoinGroupRequest.UNKNOWN_MEMBER_ID, Errors.NONE)); + mockClient.prepareResponse(syncGroupResponse(Errors.FENCED_INSTANCE_ID)); + + assertThrows(FencedInstanceIdException.class, () -> coordinator.ensureActiveGroup()); + } + + @Test + public void testJoinGroupUnknownMemberResponseWithOldGeneration() throws InterruptedException { + setupCoordinator(); + joinGroup(); + + final AbstractCoordinator.Generation currGen = coordinator.generation(); + + RequestFuture future = coordinator.sendJoinGroupRequest(); + + TestUtils.waitForCondition(() -> !mockClient.requests().isEmpty(), 2000, + "The join-group request was not sent"); + + // change the generation after the join-group request + final AbstractCoordinator.Generation newGen = new AbstractCoordinator.Generation( + currGen.generationId, + currGen.memberId + "-new", + currGen.protocolName); + coordinator.setNewGeneration(newGen); + + mockClient.respond(joinGroupFollowerResponse(currGen.generationId + 1, memberId, JoinGroupRequest.UNKNOWN_MEMBER_ID, Errors.UNKNOWN_MEMBER_ID)); + + assertTrue(consumerClient.poll(future, mockTime.timer(REQUEST_TIMEOUT_MS))); + assertTrue(future.exception().getClass().isInstance(Errors.UNKNOWN_MEMBER_ID.exception())); + + // the generation should not be reset + assertEquals(newGen, coordinator.generation()); + } + + @Test + public void testSyncGroupUnknownMemberResponseWithOldGeneration() throws InterruptedException { + setupCoordinator(); + 
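+        // Complete an initial rebalance so the coordinator has a valid generation and member id before sending the next join-group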
joinGroup(); + + final AbstractCoordinator.Generation currGen = coordinator.generation(); + + coordinator.setNewState(AbstractCoordinator.MemberState.PREPARING_REBALANCE); + RequestFuture future = coordinator.sendJoinGroupRequest(); + + TestUtils.waitForCondition(() -> { + consumerClient.poll(mockTime.timer(REQUEST_TIMEOUT_MS)); + return !mockClient.requests().isEmpty(); + }, 2000, + "The join-group request was not sent"); + + mockClient.respond(joinGroupFollowerResponse(currGen.generationId, memberId, JoinGroupRequest.UNKNOWN_MEMBER_ID, Errors.NONE)); + assertTrue(mockClient.requests().isEmpty()); + + TestUtils.waitForCondition(() -> { + consumerClient.poll(mockTime.timer(REQUEST_TIMEOUT_MS)); + return !mockClient.requests().isEmpty(); + }, 2000, + "The sync-group request was not sent"); + + // change the generation after the sync-group request + final AbstractCoordinator.Generation newGen = new AbstractCoordinator.Generation( + currGen.generationId, + currGen.memberId + "-new", + currGen.protocolName); + coordinator.setNewGeneration(newGen); + + mockClient.respond(syncGroupResponse(Errors.UNKNOWN_MEMBER_ID)); + assertTrue(consumerClient.poll(future, mockTime.timer(REQUEST_TIMEOUT_MS))); + assertTrue(future.exception().getClass().isInstance(Errors.UNKNOWN_MEMBER_ID.exception())); + + // the generation should not be reset + assertEquals(newGen, coordinator.generation()); + } + + @Test + public void testSyncGroupIllegalGenerationResponseWithOldGeneration() throws InterruptedException { + setupCoordinator(); + joinGroup(); + + final AbstractCoordinator.Generation currGen = coordinator.generation(); + + coordinator.setNewState(AbstractCoordinator.MemberState.PREPARING_REBALANCE); + RequestFuture future = coordinator.sendJoinGroupRequest(); + + TestUtils.waitForCondition(() -> { + consumerClient.poll(mockTime.timer(REQUEST_TIMEOUT_MS)); + return !mockClient.requests().isEmpty(); + }, 2000, + "The join-group request was not sent"); + + mockClient.respond(joinGroupFollowerResponse(currGen.generationId, memberId, JoinGroupRequest.UNKNOWN_MEMBER_ID, Errors.NONE)); + assertTrue(mockClient.requests().isEmpty()); + + TestUtils.waitForCondition(() -> { + consumerClient.poll(mockTime.timer(REQUEST_TIMEOUT_MS)); + return !mockClient.requests().isEmpty(); + }, 2000, + "The sync-group request was not sent"); + + // change the generation after the sync-group request + final AbstractCoordinator.Generation newGen = new AbstractCoordinator.Generation( + currGen.generationId, + currGen.memberId + "-new", + currGen.protocolName); + coordinator.setNewGeneration(newGen); + + mockClient.respond(syncGroupResponse(Errors.ILLEGAL_GENERATION)); + assertTrue(consumerClient.poll(future, mockTime.timer(REQUEST_TIMEOUT_MS))); + assertTrue(future.exception().getClass().isInstance(Errors.ILLEGAL_GENERATION.exception())); + + // the generation should not be reset + assertEquals(newGen, coordinator.generation()); + } + + @Test + public void testHeartbeatSentWhenCompletingRebalance() throws Exception { + setupCoordinator(); + joinGroup(); + + final AbstractCoordinator.Generation currGen = coordinator.generation(); + + coordinator.setNewState(AbstractCoordinator.MemberState.COMPLETING_REBALANCE); + + // the heartbeat should be sent out during a rebalance + mockTime.sleep(HEARTBEAT_INTERVAL_MS); + TestUtils.waitForCondition(() -> !mockClient.requests().isEmpty(), 2000, + "The heartbeat request was not sent"); + assertTrue(coordinator.heartbeat().hasInflight()); + + 
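+        // Answering the in-flight heartbeat with REBALANCE_IN_PROGRESS must not reset the current generation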
mockClient.respond(heartbeatResponse(Errors.REBALANCE_IN_PROGRESS)); + assertEquals(currGen, coordinator.generation()); + } + + @Test + public void testHeartbeatIllegalGenerationResponseWithOldGeneration() throws InterruptedException { + setupCoordinator(); + joinGroup(); + + final AbstractCoordinator.Generation currGen = coordinator.generation(); + + // let the heartbeat thread send out a request + mockTime.sleep(HEARTBEAT_INTERVAL_MS); + + TestUtils.waitForCondition(() -> !mockClient.requests().isEmpty(), 2000, + "The heartbeat request was not sent"); + assertTrue(coordinator.heartbeat().hasInflight()); + + // change the generation + final AbstractCoordinator.Generation newGen = new AbstractCoordinator.Generation( + currGen.generationId + 1, + currGen.memberId, + currGen.protocolName); + coordinator.setNewGeneration(newGen); + + mockClient.respond(heartbeatResponse(Errors.ILLEGAL_GENERATION)); + + // the heartbeat error code should be ignored + TestUtils.waitForCondition(() -> { + coordinator.pollHeartbeat(mockTime.milliseconds()); + return !coordinator.heartbeat().hasInflight(); + }, 2000, + "The heartbeat response was not received"); + + // the generation should not be reset + assertEquals(newGen, coordinator.generation()); + } + + @Test + public void testHeartbeatUnknownMemberResponseWithOldGeneration() throws InterruptedException { + setupCoordinator(); + joinGroup(); + + final AbstractCoordinator.Generation currGen = coordinator.generation(); + + // let the heartbeat request to send out a request + mockTime.sleep(HEARTBEAT_INTERVAL_MS); + + TestUtils.waitForCondition(() -> !mockClient.requests().isEmpty(), 2000, + "The heartbeat request was not sent"); + assertTrue(coordinator.heartbeat().hasInflight()); + + // change the generation + final AbstractCoordinator.Generation newGen = new AbstractCoordinator.Generation( + currGen.generationId, + currGen.memberId + "-new", + currGen.protocolName); + coordinator.setNewGeneration(newGen); + + mockClient.respond(heartbeatResponse(Errors.UNKNOWN_MEMBER_ID)); + + // the heartbeat error code should be ignored + TestUtils.waitForCondition(() -> { + coordinator.pollHeartbeat(mockTime.milliseconds()); + return !coordinator.heartbeat().hasInflight(); + }, 2000, + "The heartbeat response was not received"); + + // the generation should not be reset + assertEquals(newGen, coordinator.generation()); + } + + @Test + public void testHeartbeatRebalanceInProgressResponseDuringRebalancing() throws InterruptedException { + setupCoordinator(); + joinGroup(); + + final AbstractCoordinator.Generation currGen = coordinator.generation(); + + // let the heartbeat request to send out a request + mockTime.sleep(HEARTBEAT_INTERVAL_MS); + + TestUtils.waitForCondition(() -> !mockClient.requests().isEmpty(), 2000, + "The heartbeat request was not sent"); + + assertTrue(coordinator.heartbeat().hasInflight()); + + mockClient.respond(heartbeatResponse(Errors.REBALANCE_IN_PROGRESS)); + + coordinator.requestRejoin("test"); + + TestUtils.waitForCondition(() -> { + coordinator.ensureActiveGroup(new MockTime(1L).timer(100L)); + return !coordinator.heartbeat().hasInflight(); + }, + 2000, + "The heartbeat response was not received"); + + // the generation would not be reset while the rebalance is in progress + assertEquals(currGen, coordinator.generation()); + + mockClient.respond(joinGroupFollowerResponse(currGen.generationId, memberId, JoinGroupRequest.UNKNOWN_MEMBER_ID, Errors.NONE)); + mockClient.prepareResponse(syncGroupResponse(Errors.NONE)); + + 
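+        // Rejoining with the same generation id should complete the rebalance without changing the generation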
coordinator.ensureActiveGroup(); + assertEquals(currGen, coordinator.generation()); + } + + @Test + public void testHeartbeatInstanceFencedResponseWithOldGeneration() throws InterruptedException { + setupCoordinator(); + joinGroup(); + + final AbstractCoordinator.Generation currGen = coordinator.generation(); + + // let the heartbeat request to send out a request + mockTime.sleep(HEARTBEAT_INTERVAL_MS); + + TestUtils.waitForCondition(() -> !mockClient.requests().isEmpty(), 2000, + "The heartbeat request was not sent"); + assertTrue(coordinator.heartbeat().hasInflight()); + + // change the generation + final AbstractCoordinator.Generation newGen = new AbstractCoordinator.Generation( + currGen.generationId, + currGen.memberId + "-new", + currGen.protocolName); + coordinator.setNewGeneration(newGen); + + mockClient.respond(heartbeatResponse(Errors.FENCED_INSTANCE_ID)); + + // the heartbeat error code should be ignored + TestUtils.waitForCondition(() -> { + coordinator.pollHeartbeat(mockTime.milliseconds()); + return !coordinator.heartbeat().hasInflight(); + }, 2000, + "The heartbeat response was not received"); + + // the generation should not be reset + assertEquals(newGen, coordinator.generation()); + } + + @Test + public void testHeartbeatRequestWithFencedInstanceIdException() throws InterruptedException { + setupCoordinator(); + mockClient.prepareResponse(groupCoordinatorResponse(node, Errors.NONE)); + + final int generation = -1; + + mockClient.prepareResponse(joinGroupFollowerResponse(generation, memberId, JoinGroupRequest.UNKNOWN_MEMBER_ID, Errors.NONE)); + mockClient.prepareResponse(syncGroupResponse(Errors.NONE)); + mockClient.prepareResponse(heartbeatResponse(Errors.FENCED_INSTANCE_ID)); + + try { + coordinator.ensureActiveGroup(); + mockTime.sleep(HEARTBEAT_INTERVAL_MS); + long startMs = System.currentTimeMillis(); + while (System.currentTimeMillis() - startMs < 1000) { + Thread.sleep(10); + coordinator.pollHeartbeat(mockTime.milliseconds()); + } + fail("Expected pollHeartbeat to raise fenced instance id exception in 1 second"); + } catch (RuntimeException exception) { + assertTrue(exception instanceof FencedInstanceIdException); + } + } + + @Test + public void testJoinGroupRequestWithGroupInstanceIdNotFound() { + setupCoordinator(); + mockClient.prepareResponse(groupCoordinatorResponse(node, Errors.NONE)); + coordinator.ensureCoordinatorReady(mockTime.timer(0)); + + mockClient.prepareResponse(joinGroupFollowerResponse(defaultGeneration, memberId, JoinGroupRequest.UNKNOWN_MEMBER_ID, Errors.UNKNOWN_MEMBER_ID)); + + RequestFuture future = coordinator.sendJoinGroupRequest(); + + assertTrue(consumerClient.poll(future, mockTime.timer(REQUEST_TIMEOUT_MS))); + assertEquals(Errors.UNKNOWN_MEMBER_ID.message(), future.exception().getMessage()); + assertTrue(coordinator.rejoinNeededOrPending()); + assertTrue(coordinator.hasUnknownGeneration()); + } + + @Test + public void testJoinGroupRequestWithRebalanceInProgress() { + setupCoordinator(); + mockClient.prepareResponse(groupCoordinatorResponse(node, Errors.NONE)); + coordinator.ensureCoordinatorReady(mockTime.timer(0)); + + mockClient.prepareResponse( + joinGroupFollowerResponse(defaultGeneration, memberId, JoinGroupRequest.UNKNOWN_MEMBER_ID, Errors.REBALANCE_IN_PROGRESS)); + + RequestFuture future = coordinator.sendJoinGroupRequest(); + + assertTrue(consumerClient.poll(future, mockTime.timer(REQUEST_TIMEOUT_MS))); + assertTrue(future.exception().getClass().isInstance(Errors.REBALANCE_IN_PROGRESS.exception())); + 
assertEquals(Errors.REBALANCE_IN_PROGRESS.message(), future.exception().getMessage()); + assertTrue(coordinator.rejoinNeededOrPending()); + + // make sure we'll retry on next poll + assertEquals(0, coordinator.onJoinPrepareInvokes); + assertEquals(0, coordinator.onJoinCompleteInvokes); + + mockClient.prepareResponse(joinGroupFollowerResponse(defaultGeneration, memberId, JoinGroupRequest.UNKNOWN_MEMBER_ID, Errors.NONE)); + mockClient.prepareResponse(syncGroupResponse(Errors.NONE)); + + coordinator.ensureActiveGroup(); + // make sure both onJoinPrepare and onJoinComplete got called + assertEquals(1, coordinator.onJoinPrepareInvokes); + assertEquals(1, coordinator.onJoinCompleteInvokes); + } + + @Test + public void testLeaveGroupSentWithGroupInstanceIdUnSet() { + checkLeaveGroupRequestSent(Optional.empty()); + checkLeaveGroupRequestSent(Optional.of("groupInstanceId")); + } + + private void checkLeaveGroupRequestSent(Optional groupInstanceId) { + setupCoordinator(RETRY_BACKOFF_MS, Integer.MAX_VALUE, groupInstanceId); + + mockClient.prepareResponse(groupCoordinatorResponse(node, Errors.NONE)); + mockClient.prepareResponse(joinGroupFollowerResponse(1, memberId, leaderId, Errors.NONE)); + mockClient.prepareResponse(syncGroupResponse(Errors.NONE)); + + final RuntimeException e = new RuntimeException(); + + // raise the error when the coordinator tries to send leave group request. + mockClient.prepareResponse(body -> { + if (body instanceof LeaveGroupRequest) + throw e; + return false; + }, heartbeatResponse(Errors.UNKNOWN_SERVER_ERROR)); + + try { + coordinator.ensureActiveGroup(); + coordinator.close(); + if (coordinator.isDynamicMember()) { + fail("Expected leavegroup to raise an error."); + } + } catch (RuntimeException exception) { + if (coordinator.isDynamicMember()) { + assertEquals(exception, e); + } else { + fail("Coordinator with group.instance.id set shouldn't send leave group request."); + } + } + } + + @Test + public void testHandleNormalLeaveGroupResponse() { + MemberResponse memberResponse = new MemberResponse() + .setMemberId(memberId) + .setErrorCode(Errors.NONE.code()); + LeaveGroupResponse response = + leaveGroupResponse(Collections.singletonList(memberResponse)); + RequestFuture leaveGroupFuture = setupLeaveGroup(response); + assertNotNull(leaveGroupFuture); + assertTrue(leaveGroupFuture.succeeded()); + } + + @Test + public void testHandleMultipleMembersLeaveGroupResponse() { + MemberResponse memberResponse = new MemberResponse() + .setMemberId(memberId) + .setErrorCode(Errors.NONE.code()); + LeaveGroupResponse response = + leaveGroupResponse(Arrays.asList(memberResponse, memberResponse)); + RequestFuture leaveGroupFuture = setupLeaveGroup(response); + assertNotNull(leaveGroupFuture); + assertTrue(leaveGroupFuture.exception() instanceof IllegalStateException); + } + + @Test + public void testHandleLeaveGroupResponseWithEmptyMemberResponse() { + LeaveGroupResponse response = + leaveGroupResponse(Collections.emptyList()); + RequestFuture leaveGroupFuture = setupLeaveGroup(response); + assertNotNull(leaveGroupFuture); + assertTrue(leaveGroupFuture.succeeded()); + } + + @Test + public void testHandleLeaveGroupResponseWithException() { + MemberResponse memberResponse = new MemberResponse() + .setMemberId(memberId) + .setErrorCode(Errors.UNKNOWN_MEMBER_ID.code()); + LeaveGroupResponse response = + leaveGroupResponse(Collections.singletonList(memberResponse)); + RequestFuture leaveGroupFuture = setupLeaveGroup(response); + assertNotNull(leaveGroupFuture); + 
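+        // The member-level UNKNOWN_MEMBER_ID error should surface as an UnknownMemberIdException on the leave-group future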
assertTrue(leaveGroupFuture.exception() instanceof UnknownMemberIdException); + } + + private RequestFuture setupLeaveGroup(LeaveGroupResponse leaveGroupResponse) { + setupCoordinator(RETRY_BACKOFF_MS, Integer.MAX_VALUE, Optional.empty()); + + mockClient.prepareResponse(groupCoordinatorResponse(node, Errors.NONE)); + mockClient.prepareResponse(joinGroupFollowerResponse(1, memberId, leaderId, Errors.NONE)); + mockClient.prepareResponse(syncGroupResponse(Errors.NONE)); + mockClient.prepareResponse(leaveGroupResponse); + + coordinator.ensureActiveGroup(); + return coordinator.maybeLeaveGroup("test maybe leave group"); + } + + @Test + public void testUncaughtExceptionInHeartbeatThread() throws Exception { + setupCoordinator(); + + mockClient.prepareResponse(groupCoordinatorResponse(node, Errors.NONE)); + mockClient.prepareResponse(joinGroupFollowerResponse(1, memberId, leaderId, Errors.NONE)); + mockClient.prepareResponse(syncGroupResponse(Errors.NONE)); + + final RuntimeException e = new RuntimeException(); + + // raise the error when the background thread tries to send a heartbeat + mockClient.prepareResponse(body -> { + if (body instanceof HeartbeatRequest) + throw e; + return false; + }, heartbeatResponse(Errors.UNKNOWN_SERVER_ERROR)); + + try { + coordinator.ensureActiveGroup(); + mockTime.sleep(HEARTBEAT_INTERVAL_MS); + long startMs = System.currentTimeMillis(); + while (System.currentTimeMillis() - startMs < 1000) { + Thread.sleep(10); + coordinator.pollHeartbeat(mockTime.milliseconds()); + } + fail("Expected pollHeartbeat to raise an error in 1 second"); + } catch (RuntimeException exception) { + assertEquals(exception, e); + } + } + + @Test + public void testPollHeartbeatAwakesHeartbeatThread() throws Exception { + final int longRetryBackoffMs = 10000; + setupCoordinator(longRetryBackoffMs); + + mockClient.prepareResponse(groupCoordinatorResponse(node, Errors.NONE)); + mockClient.prepareResponse(joinGroupFollowerResponse(1, memberId, leaderId, Errors.NONE)); + mockClient.prepareResponse(syncGroupResponse(Errors.NONE)); + + coordinator.ensureActiveGroup(); + + final CountDownLatch heartbeatDone = new CountDownLatch(1); + mockClient.prepareResponse(body -> { + heartbeatDone.countDown(); + return body instanceof HeartbeatRequest; + }, heartbeatResponse(Errors.NONE)); + + mockTime.sleep(HEARTBEAT_INTERVAL_MS); + coordinator.pollHeartbeat(mockTime.milliseconds()); + + if (!heartbeatDone.await(1, TimeUnit.SECONDS)) { + fail("Should have received a heartbeat request after calling pollHeartbeat"); + } + } + + @Test + public void testLookupCoordinator() { + setupCoordinator(); + + mockClient.backoff(node, 50); + RequestFuture noBrokersAvailableFuture = coordinator.lookupCoordinator(); + assertTrue(noBrokersAvailableFuture.failed(), "Failed future expected"); + mockTime.sleep(50); + + RequestFuture future = coordinator.lookupCoordinator(); + assertFalse(future.isDone(), "Request not sent"); + assertSame(future, coordinator.lookupCoordinator(), "New request sent while one is in progress"); + + mockClient.prepareResponse(groupCoordinatorResponse(node, Errors.NONE)); + coordinator.ensureCoordinatorReady(mockTime.timer(Long.MAX_VALUE)); + assertNotSame(future, coordinator.lookupCoordinator(), "New request not sent after previous completed"); + } + + @Test + public void testWakeupAfterJoinGroupSent() throws Exception { + setupCoordinator(); + + mockClient.prepareResponse(groupCoordinatorResponse(node, Errors.NONE)); + mockClient.prepareResponse(new MockClient.RequestMatcher() { + private int 
invocations = 0; + @Override + public boolean matches(AbstractRequest body) { + invocations++; + boolean isJoinGroupRequest = body instanceof JoinGroupRequest; + if (isJoinGroupRequest && invocations == 1) + // simulate wakeup before the request returns + throw new WakeupException(); + return isJoinGroupRequest; + } + }, joinGroupFollowerResponse(1, memberId, leaderId, Errors.NONE)); + mockClient.prepareResponse(syncGroupResponse(Errors.NONE)); + AtomicBoolean heartbeatReceived = prepareFirstHeartbeat(); + + try { + coordinator.ensureActiveGroup(); + fail("Should have woken up from ensureActiveGroup()"); + } catch (WakeupException ignored) { + } + + assertEquals(1, coordinator.onJoinPrepareInvokes); + assertEquals(0, coordinator.onJoinCompleteInvokes); + assertFalse(heartbeatReceived.get()); + + coordinator.ensureActiveGroup(); + + assertEquals(1, coordinator.onJoinPrepareInvokes); + assertEquals(1, coordinator.onJoinCompleteInvokes); + + awaitFirstHeartbeat(heartbeatReceived); + } + + @Test + public void testWakeupAfterJoinGroupSentExternalCompletion() throws Exception { + setupCoordinator(); + + mockClient.prepareResponse(groupCoordinatorResponse(node, Errors.NONE)); + mockClient.prepareResponse(new MockClient.RequestMatcher() { + private int invocations = 0; + @Override + public boolean matches(AbstractRequest body) { + invocations++; + boolean isJoinGroupRequest = body instanceof JoinGroupRequest; + if (isJoinGroupRequest && invocations == 1) + // simulate wakeup before the request returns + throw new WakeupException(); + return isJoinGroupRequest; + } + }, joinGroupFollowerResponse(1, memberId, leaderId, Errors.NONE)); + mockClient.prepareResponse(syncGroupResponse(Errors.NONE)); + AtomicBoolean heartbeatReceived = prepareFirstHeartbeat(); + + try { + coordinator.ensureActiveGroup(); + fail("Should have woken up from ensureActiveGroup()"); + } catch (WakeupException ignored) { + } + + assertEquals(1, coordinator.onJoinPrepareInvokes); + assertEquals(0, coordinator.onJoinCompleteInvokes); + assertFalse(heartbeatReceived.get()); + + // the join group completes in this poll() + consumerClient.poll(mockTime.timer(0)); + coordinator.ensureActiveGroup(); + + assertEquals(1, coordinator.onJoinPrepareInvokes); + assertEquals(1, coordinator.onJoinCompleteInvokes); + + awaitFirstHeartbeat(heartbeatReceived); + } + + @Test + public void testWakeupAfterJoinGroupReceived() throws Exception { + setupCoordinator(); + + mockClient.prepareResponse(groupCoordinatorResponse(node, Errors.NONE)); + mockClient.prepareResponse(body -> { + boolean isJoinGroupRequest = body instanceof JoinGroupRequest; + if (isJoinGroupRequest) + // wakeup after the request returns + consumerClient.wakeup(); + return isJoinGroupRequest; + }, joinGroupFollowerResponse(1, memberId, leaderId, Errors.NONE)); + mockClient.prepareResponse(syncGroupResponse(Errors.NONE)); + AtomicBoolean heartbeatReceived = prepareFirstHeartbeat(); + + try { + coordinator.ensureActiveGroup(); + fail("Should have woken up from ensureActiveGroup()"); + } catch (WakeupException ignored) { + } + + assertEquals(1, coordinator.onJoinPrepareInvokes); + assertEquals(0, coordinator.onJoinCompleteInvokes); + assertFalse(heartbeatReceived.get()); + + coordinator.ensureActiveGroup(); + + assertEquals(1, coordinator.onJoinPrepareInvokes); + assertEquals(1, coordinator.onJoinCompleteInvokes); + + awaitFirstHeartbeat(heartbeatReceived); + } + + @Test + public void testWakeupAfterJoinGroupReceivedExternalCompletion() throws Exception { + setupCoordinator(); + + 
mockClient.prepareResponse(groupCoordinatorResponse(node, Errors.NONE)); + mockClient.prepareResponse(body -> { + boolean isJoinGroupRequest = body instanceof JoinGroupRequest; + if (isJoinGroupRequest) + // wakeup after the request returns + consumerClient.wakeup(); + return isJoinGroupRequest; + }, joinGroupFollowerResponse(1, memberId, leaderId, Errors.NONE)); + mockClient.prepareResponse(syncGroupResponse(Errors.NONE)); + AtomicBoolean heartbeatReceived = prepareFirstHeartbeat(); + + try { + coordinator.ensureActiveGroup(); + fail("Should have woken up from ensureActiveGroup()"); + } catch (WakeupException e) { + } + + assertEquals(1, coordinator.onJoinPrepareInvokes); + assertEquals(0, coordinator.onJoinCompleteInvokes); + assertFalse(heartbeatReceived.get()); + + // the join group completes in this poll() + consumerClient.poll(mockTime.timer(0)); + coordinator.ensureActiveGroup(); + + assertEquals(1, coordinator.onJoinPrepareInvokes); + assertEquals(1, coordinator.onJoinCompleteInvokes); + + awaitFirstHeartbeat(heartbeatReceived); + } + + @Test + public void testWakeupAfterSyncGroupSentExternalCompletion() throws Exception { + setupCoordinator(); + + mockClient.prepareResponse(groupCoordinatorResponse(node, Errors.NONE)); + mockClient.prepareResponse(joinGroupFollowerResponse(1, memberId, leaderId, Errors.NONE)); + mockClient.prepareResponse(new MockClient.RequestMatcher() { + private int invocations = 0; + @Override + public boolean matches(AbstractRequest body) { + invocations++; + boolean isSyncGroupRequest = body instanceof SyncGroupRequest; + if (isSyncGroupRequest && invocations == 1) + // wakeup after the request returns + consumerClient.wakeup(); + return isSyncGroupRequest; + } + }, syncGroupResponse(Errors.NONE)); + AtomicBoolean heartbeatReceived = prepareFirstHeartbeat(); + + try { + coordinator.ensureActiveGroup(); + fail("Should have woken up from ensureActiveGroup()"); + } catch (WakeupException e) { + } + + assertEquals(1, coordinator.onJoinPrepareInvokes); + assertEquals(0, coordinator.onJoinCompleteInvokes); + assertFalse(heartbeatReceived.get()); + + // the join group completes in this poll() + consumerClient.poll(mockTime.timer(0)); + coordinator.ensureActiveGroup(); + + assertEquals(1, coordinator.onJoinPrepareInvokes); + assertEquals(1, coordinator.onJoinCompleteInvokes); + + awaitFirstHeartbeat(heartbeatReceived); + } + + @Test + public void testWakeupAfterSyncGroupReceived() throws Exception { + setupCoordinator(); + + mockClient.prepareResponse(groupCoordinatorResponse(node, Errors.NONE)); + mockClient.prepareResponse(joinGroupFollowerResponse(1, memberId, leaderId, Errors.NONE)); + mockClient.prepareResponse(body -> { + boolean isSyncGroupRequest = body instanceof SyncGroupRequest; + if (isSyncGroupRequest) + // wakeup after the request returns + consumerClient.wakeup(); + return isSyncGroupRequest; + }, syncGroupResponse(Errors.NONE)); + AtomicBoolean heartbeatReceived = prepareFirstHeartbeat(); + + try { + coordinator.ensureActiveGroup(); + fail("Should have woken up from ensureActiveGroup()"); + } catch (WakeupException ignored) { + } + + assertEquals(1, coordinator.onJoinPrepareInvokes); + assertEquals(0, coordinator.onJoinCompleteInvokes); + assertFalse(heartbeatReceived.get()); + + coordinator.ensureActiveGroup(); + + assertEquals(1, coordinator.onJoinPrepareInvokes); + assertEquals(1, coordinator.onJoinCompleteInvokes); + + awaitFirstHeartbeat(heartbeatReceived); + } + + @Test + public void testWakeupAfterSyncGroupReceivedExternalCompletion() throws 
Exception { + setupCoordinator(); + + mockClient.prepareResponse(groupCoordinatorResponse(node, Errors.NONE)); + mockClient.prepareResponse(joinGroupFollowerResponse(1, memberId, leaderId, Errors.NONE)); + mockClient.prepareResponse(body -> { + boolean isSyncGroupRequest = body instanceof SyncGroupRequest; + if (isSyncGroupRequest) + // wakeup after the request returns + consumerClient.wakeup(); + return isSyncGroupRequest; + }, syncGroupResponse(Errors.NONE)); + AtomicBoolean heartbeatReceived = prepareFirstHeartbeat(); + + try { + coordinator.ensureActiveGroup(); + fail("Should have woken up from ensureActiveGroup()"); + } catch (WakeupException e) { + } + + assertEquals(1, coordinator.onJoinPrepareInvokes); + assertEquals(0, coordinator.onJoinCompleteInvokes); + assertFalse(heartbeatReceived.get()); + + coordinator.ensureActiveGroup(); + + assertEquals(1, coordinator.onJoinPrepareInvokes); + assertEquals(1, coordinator.onJoinCompleteInvokes); + + awaitFirstHeartbeat(heartbeatReceived); + } + + @Test + public void testWakeupInOnJoinComplete() throws Exception { + setupCoordinator(); + + coordinator.wakeupOnJoinComplete = true; + mockClient.prepareResponse(groupCoordinatorResponse(node, Errors.NONE)); + mockClient.prepareResponse(joinGroupFollowerResponse(1, memberId, leaderId, Errors.NONE)); + mockClient.prepareResponse(syncGroupResponse(Errors.NONE)); + AtomicBoolean heartbeatReceived = prepareFirstHeartbeat(); + + try { + coordinator.ensureActiveGroup(); + fail("Should have woken up from ensureActiveGroup()"); + } catch (WakeupException ignored) { + } + + assertEquals(1, coordinator.onJoinPrepareInvokes); + assertEquals(0, coordinator.onJoinCompleteInvokes); + assertFalse(heartbeatReceived.get()); + + // the join group completes in this poll() + coordinator.wakeupOnJoinComplete = false; + consumerClient.poll(mockTime.timer(0)); + coordinator.ensureActiveGroup(); + + assertEquals(1, coordinator.onJoinPrepareInvokes); + assertEquals(1, coordinator.onJoinCompleteInvokes); + + awaitFirstHeartbeat(heartbeatReceived); + } + + @Test + public void testAuthenticationErrorInEnsureCoordinatorReady() { + setupCoordinator(); + + mockClient.createPendingAuthenticationError(node, 300); + + try { + coordinator.ensureCoordinatorReady(mockTime.timer(Long.MAX_VALUE)); + fail("Expected an authentication error."); + } catch (AuthenticationException e) { + // OK + } + } + + private AtomicBoolean prepareFirstHeartbeat() { + final AtomicBoolean heartbeatReceived = new AtomicBoolean(false); + mockClient.prepareResponse(body -> { + boolean isHeartbeatRequest = body instanceof HeartbeatRequest; + if (isHeartbeatRequest) + heartbeatReceived.set(true); + return isHeartbeatRequest; + }, heartbeatResponse(Errors.UNKNOWN_SERVER_ERROR)); + return heartbeatReceived; + } + + private void awaitFirstHeartbeat(final AtomicBoolean heartbeatReceived) throws Exception { + mockTime.sleep(HEARTBEAT_INTERVAL_MS); + TestUtils.waitForCondition(heartbeatReceived::get, + 3000, "Should have received a heartbeat request after joining the group"); + } + + private FindCoordinatorResponse groupCoordinatorResponse(Node node, Errors error) { + return FindCoordinatorResponse.prepareResponse(error, GROUP_ID, node); + } + + private HeartbeatResponse heartbeatResponse(Errors error) { + return new HeartbeatResponse(new HeartbeatResponseData().setErrorCode(error.code())); + } + + private JoinGroupResponse joinGroupFollowerResponse(int generationId, + String memberId, + String leaderId, + Errors error) { + return 
joinGroupFollowerResponse(generationId, memberId, leaderId, error, null); + } + + private JoinGroupResponse joinGroupFollowerResponse(int generationId, + String memberId, + String leaderId, + Errors error, + String protocolType) { + return new JoinGroupResponse( + new JoinGroupResponseData() + .setErrorCode(error.code()) + .setGenerationId(generationId) + .setProtocolType(protocolType) + .setProtocolName(PROTOCOL_NAME) + .setMemberId(memberId) + .setLeader(leaderId) + .setMembers(Collections.emptyList()) + ); + } + + private JoinGroupResponse joinGroupResponse(Errors error) { + return joinGroupFollowerResponse(JoinGroupRequest.UNKNOWN_GENERATION_ID, + JoinGroupRequest.UNKNOWN_MEMBER_ID, JoinGroupRequest.UNKNOWN_MEMBER_ID, error); + } + + private SyncGroupResponse syncGroupResponse(Errors error) { + return syncGroupResponse(error, null, null); + } + + private SyncGroupResponse syncGroupResponse(Errors error, + String protocolType, + String protocolName) { + return new SyncGroupResponse( + new SyncGroupResponseData() + .setErrorCode(error.code()) + .setProtocolType(protocolType) + .setProtocolName(protocolName) + .setAssignment(new byte[0]) + ); + } + + private LeaveGroupResponse leaveGroupResponse(List members) { + return new LeaveGroupResponse(new LeaveGroupResponseData() + .setErrorCode(Errors.NONE.code()) + .setMembers(members)); + } + + public static class DummyCoordinator extends AbstractCoordinator { + + private int onJoinPrepareInvokes = 0; + private int onJoinCompleteInvokes = 0; + private boolean wakeupOnJoinComplete = false; + + DummyCoordinator(GroupRebalanceConfig rebalanceConfig, + ConsumerNetworkClient client, + Metrics metrics, + Time time) { + super(rebalanceConfig, new LogContext(), client, metrics, METRIC_GROUP_PREFIX, time); + } + + @Override + protected String protocolType() { + return PROTOCOL_TYPE; + } + + @Override + protected JoinGroupRequestData.JoinGroupRequestProtocolCollection metadata() { + return new JoinGroupRequestData.JoinGroupRequestProtocolCollection( + Collections.singleton(new JoinGroupRequestData.JoinGroupRequestProtocol() + .setName(PROTOCOL_NAME) + .setMetadata(EMPTY_DATA.array())).iterator() + ); + } + + @Override + protected Map performAssignment(String leaderId, + String protocol, + List allMemberMetadata) { + Map assignment = new HashMap<>(); + for (JoinGroupResponseData.JoinGroupResponseMember member : allMemberMetadata) { + assignment.put(member.memberId(), EMPTY_DATA); + } + return assignment; + } + + @Override + protected void onJoinPrepare(int generation, String memberId) { + onJoinPrepareInvokes++; + } + + @Override + protected void onJoinComplete(int generation, String memberId, String protocol, ByteBuffer memberAssignment) { + if (wakeupOnJoinComplete) + throw new WakeupException(); + onJoinCompleteInvokes++; + } + } + +} diff --git a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/AbstractPartitionAssignorTest.java b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/AbstractPartitionAssignorTest.java new file mode 100644 index 0000000..9d0423d --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/AbstractPartitionAssignorTest.java @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients.consumer.internals; + +import org.apache.kafka.clients.consumer.internals.AbstractPartitionAssignor.MemberInfo; +import org.apache.kafka.common.utils.Utils; +import org.junit.jupiter.api.Test; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Optional; +import java.util.Random; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class AbstractPartitionAssignorTest { + + @Test + public void testMemberInfoSortingWithoutGroupInstanceId() { + MemberInfo m1 = new MemberInfo("a", Optional.empty()); + MemberInfo m2 = new MemberInfo("b", Optional.empty()); + MemberInfo m3 = new MemberInfo("c", Optional.empty()); + + List memberInfoList = Arrays.asList(m1, m2, m3); + assertEquals(memberInfoList, Utils.sorted(memberInfoList)); + } + + @Test + public void testMemberInfoSortingWithAllGroupInstanceId() { + MemberInfo m1 = new MemberInfo("a", Optional.of("y")); + MemberInfo m2 = new MemberInfo("b", Optional.of("z")); + MemberInfo m3 = new MemberInfo("c", Optional.of("x")); + + List memberInfoList = Arrays.asList(m1, m2, m3); + assertEquals(Arrays.asList(m3, m1, m2), Utils.sorted(memberInfoList)); + } + + @Test + public void testMemberInfoSortingSomeGroupInstanceId() { + MemberInfo m1 = new MemberInfo("a", Optional.empty()); + MemberInfo m2 = new MemberInfo("b", Optional.of("y")); + MemberInfo m3 = new MemberInfo("c", Optional.of("x")); + + List memberInfoList = Arrays.asList(m1, m2, m3); + assertEquals(Arrays.asList(m3, m2, m1), Utils.sorted(memberInfoList)); + } + + @Test + public void testMergeSortManyMemberInfo() { + Random rand = new Random(); + int bound = 2; + List memberInfoList = new ArrayList<>(); + List staticMemberList = new ArrayList<>(); + List dynamicMemberList = new ArrayList<>(); + for (int i = 0; i < 100; i++) { + // Need to make sure all the ids are defined as 3-digits otherwise + // the comparison result will break. + String id = Integer.toString(i + 100); + Optional groupInstanceId = rand.nextInt(bound) < bound / 2 ? 
+ Optional.of(id) : Optional.empty(); + MemberInfo m = new MemberInfo(id, groupInstanceId); + memberInfoList.add(m); + if (m.groupInstanceId.isPresent()) { + staticMemberList.add(m); + } else { + dynamicMemberList.add(m); + } + } + staticMemberList.addAll(dynamicMemberList); + Collections.shuffle(memberInfoList); + assertEquals(staticMemberList, Utils.sorted(memberInfoList)); + } +} diff --git a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/AbstractStickyAssignorTest.java b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/AbstractStickyAssignorTest.java new file mode 100644 index 0000000..5eb4351 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/AbstractStickyAssignorTest.java @@ -0,0 +1,1039 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients.consumer.internals; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Random; +import java.util.Set; + +import org.apache.kafka.clients.consumer.ConsumerPartitionAssignor.Subscription; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.utils.CollectionUtils; +import org.apache.kafka.common.utils.Utils; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; + +import static java.util.Collections.emptyList; +import static org.apache.kafka.clients.consumer.internals.AbstractStickyAssignor.DEFAULT_GENERATION; +import static org.apache.kafka.common.utils.Utils.mkEntry; +import static org.apache.kafka.common.utils.Utils.mkMap; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public abstract class AbstractStickyAssignorTest { + protected AbstractStickyAssignor assignor; + protected String consumerId = "consumer"; + protected String consumer1 = "consumer1"; + protected String consumer2 = "consumer2"; + protected String consumer3 = "consumer3"; + protected String consumer4 = "consumer4"; + protected Map subscriptions; + protected String topic = "topic"; + protected String topic1 = "topic1"; + protected String topic2 = "topic2"; + protected String topic3 = "topic3"; + + protected abstract AbstractStickyAssignor createAssignor(); + + protected abstract Subscription buildSubscription(List topics, List partitions); + + protected abstract Subscription buildSubscriptionWithGeneration(List topics, List partitions, int generation); + + @BeforeEach + public void setUp() { + 
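+        // Recreate the assignor and reset the shared subscriptions map before each test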
assignor = createAssignor(); + + if (subscriptions != null) { + subscriptions.clear(); + } else { + subscriptions = new HashMap<>(); + } + } + + @Test + public void testOneConsumerNoTopic() { + Map partitionsPerTopic = new HashMap<>(); + subscriptions = Collections.singletonMap(consumerId, new Subscription(Collections.emptyList())); + + Map> assignment = assignor.assign(partitionsPerTopic, subscriptions); + assertEquals(Collections.singleton(consumerId), assignment.keySet()); + assertTrue(assignment.get(consumerId).isEmpty()); + assertTrue(assignor.partitionsTransferringOwnership.isEmpty()); + + verifyValidityAndBalance(subscriptions, assignment, partitionsPerTopic); + assertTrue(isFullyBalanced(assignment)); + } + + @Test + public void testOneConsumerNonexistentTopic() { + Map partitionsPerTopic = new HashMap<>(); + partitionsPerTopic.put(topic, 0); + subscriptions = Collections.singletonMap(consumerId, new Subscription(topics(topic))); + + Map> assignment = assignor.assign(partitionsPerTopic, subscriptions); + + assertEquals(Collections.singleton(consumerId), assignment.keySet()); + assertTrue(assignment.get(consumerId).isEmpty()); + assertTrue(assignor.partitionsTransferringOwnership.isEmpty()); + + verifyValidityAndBalance(subscriptions, assignment, partitionsPerTopic); + assertTrue(isFullyBalanced(assignment)); + } + + @Test + public void testOneConsumerOneTopic() { + Map partitionsPerTopic = new HashMap<>(); + partitionsPerTopic.put(topic, 3); + subscriptions = Collections.singletonMap(consumerId, new Subscription(topics(topic))); + + Map> assignment = assignor.assign(partitionsPerTopic, subscriptions); + assertEquals(partitions(tp(topic, 0), tp(topic, 1), tp(topic, 2)), assignment.get(consumerId)); + assertTrue(assignor.partitionsTransferringOwnership.isEmpty()); + + verifyValidityAndBalance(subscriptions, assignment, partitionsPerTopic); + assertTrue(isFullyBalanced(assignment)); + } + + @Test + public void testOnlyAssignsPartitionsFromSubscribedTopics() { + String otherTopic = "other"; + + Map partitionsPerTopic = new HashMap<>(); + partitionsPerTopic.put(topic, 2); + subscriptions = mkMap( + mkEntry(consumerId, buildSubscription( + topics(topic), + Arrays.asList(tp(topic, 0), tp(topic, 1), tp(otherTopic, 0), tp(otherTopic, 1))) + ) + ); + + Map> assignment = assignor.assign(partitionsPerTopic, subscriptions); + assertEquals(partitions(tp(topic, 0), tp(topic, 1)), assignment.get(consumerId)); + assertTrue(assignor.partitionsTransferringOwnership.isEmpty()); + + verifyValidityAndBalance(subscriptions, assignment, partitionsPerTopic); + assertTrue(isFullyBalanced(assignment)); + } + + @Test + public void testOneConsumerMultipleTopics() { + Map partitionsPerTopic = new HashMap<>(); + partitionsPerTopic.put(topic1, 1); + partitionsPerTopic.put(topic2, 2); + subscriptions = Collections.singletonMap(consumerId, new Subscription(topics(topic1, topic2))); + + Map> assignment = assignor.assign(partitionsPerTopic, subscriptions); + assertEquals(partitions(tp(topic1, 0), tp(topic2, 0), tp(topic2, 1)), assignment.get(consumerId)); + assertTrue(assignor.partitionsTransferringOwnership.isEmpty()); + + verifyValidityAndBalance(subscriptions, assignment, partitionsPerTopic); + assertTrue(isFullyBalanced(assignment)); + } + + @Test + public void testTwoConsumersOneTopicOnePartition() { + Map partitionsPerTopic = new HashMap<>(); + partitionsPerTopic.put(topic, 1); + + subscriptions.put(consumer1, new Subscription(topics(topic))); + subscriptions.put(consumer2, new Subscription(topics(topic))); + + 
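+        // With a single partition, only one of the two consumers can own it; the other is left with an empty assignment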
Map> assignment = assignor.assign(partitionsPerTopic, subscriptions); + assertTrue(assignor.partitionsTransferringOwnership.isEmpty()); + + verifyValidityAndBalance(subscriptions, assignment, partitionsPerTopic); + assertTrue(isFullyBalanced(assignment)); + } + + @Test + public void testTwoConsumersOneTopicTwoPartitions() { + Map partitionsPerTopic = new HashMap<>(); + partitionsPerTopic.put(topic, 2); + + subscriptions.put(consumer1, new Subscription(topics(topic))); + subscriptions.put(consumer2, new Subscription(topics(topic))); + + Map> assignment = assignor.assign(partitionsPerTopic, subscriptions); + assertEquals(partitions(tp(topic, 0)), assignment.get(consumer1)); + assertEquals(partitions(tp(topic, 1)), assignment.get(consumer2)); + assertTrue(assignor.partitionsTransferringOwnership.isEmpty()); + + verifyValidityAndBalance(subscriptions, assignment, partitionsPerTopic); + assertTrue(isFullyBalanced(assignment)); + } + + @Test + public void testMultipleConsumersMixedTopicSubscriptions() { + Map partitionsPerTopic = new HashMap<>(); + partitionsPerTopic.put(topic1, 3); + partitionsPerTopic.put(topic2, 2); + + subscriptions.put(consumer1, new Subscription(topics(topic1))); + subscriptions.put(consumer2, new Subscription(topics(topic1, topic2))); + subscriptions.put(consumer3, new Subscription(topics(topic1))); + + Map> assignment = assignor.assign(partitionsPerTopic, subscriptions); + assertEquals(partitions(tp(topic1, 0), tp(topic1, 2)), assignment.get(consumer1)); + assertEquals(partitions(tp(topic2, 0), tp(topic2, 1)), assignment.get(consumer2)); + assertEquals(partitions(tp(topic1, 1)), assignment.get(consumer3)); + assertNull(assignor.partitionsTransferringOwnership); + + verifyValidityAndBalance(subscriptions, assignment, partitionsPerTopic); + assertTrue(isFullyBalanced(assignment)); + } + + @Test + public void testTwoConsumersTwoTopicsSixPartitions() { + Map partitionsPerTopic = new HashMap<>(); + partitionsPerTopic.put(topic1, 3); + partitionsPerTopic.put(topic2, 3); + + subscriptions.put(consumer1, new Subscription(topics(topic1, topic2))); + subscriptions.put(consumer2, new Subscription(topics(topic1, topic2))); + + Map> assignment = assignor.assign(partitionsPerTopic, subscriptions); + assertEquals(partitions(tp(topic1, 0), tp(topic1, 2), tp(topic2, 1)), assignment.get(consumer1)); + assertEquals(partitions(tp(topic1, 1), tp(topic2, 0), tp(topic2, 2)), assignment.get(consumer2)); + assertTrue(assignor.partitionsTransferringOwnership.isEmpty()); + + verifyValidityAndBalance(subscriptions, assignment, partitionsPerTopic); + assertTrue(isFullyBalanced(assignment)); + } + + /** + * This unit test is testing consumer owned minQuota partitions, and expected to have maxQuota partitions situation + */ + @Test + public void testConsumerOwningMinQuotaExpectedMaxQuota() { + Map partitionsPerTopic = new HashMap<>(); + partitionsPerTopic.put(topic1, 2); + partitionsPerTopic.put(topic2, 3); + + List subscribedTopics = topics(topic1, topic2); + + subscriptions.put(consumer1, + buildSubscription(subscribedTopics, partitions(tp(topic1, 0), tp(topic2, 1)))); + subscriptions.put(consumer2, + buildSubscription(subscribedTopics, partitions(tp(topic1, 1), tp(topic2, 2)))); + + Map> assignment = assignor.assign(partitionsPerTopic, subscriptions); + assertEquals(partitions(tp(topic1, 0), tp(topic2, 1), tp(topic2, 0)), assignment.get(consumer1)); + assertEquals(partitions(tp(topic1, 1), tp(topic2, 2)), assignment.get(consumer2)); + assertTrue(assignor.partitionsTransferringOwnership.isEmpty()); 
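+        // The overall assignment should remain valid and balanced after the sticky reassignment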
+ + verifyValidityAndBalance(subscriptions, assignment, partitionsPerTopic); + assertTrue(isFullyBalanced(assignment)); + } + + /** + * This unit test is testing consumers owned maxQuota partitions are more than numExpectedMaxCapacityMembers situation + */ + @Test + public void testMaxQuotaConsumerMoreThanNumExpectedMaxCapacityMembers() { + Map partitionsPerTopic = new HashMap<>(); + partitionsPerTopic.put(topic1, 2); + partitionsPerTopic.put(topic2, 2); + + List subscribedTopics = topics(topic1, topic2); + + subscriptions.put(consumer1, + buildSubscription(subscribedTopics, partitions(tp(topic1, 0), tp(topic2, 0)))); + subscriptions.put(consumer2, + buildSubscription(subscribedTopics, partitions(tp(topic1, 1), tp(topic2, 1)))); + subscriptions.put(consumer3, buildSubscription(subscribedTopics, Collections.emptyList())); + + Map> assignment = assignor.assign(partitionsPerTopic, subscriptions); + assertEquals(Collections.singletonMap(tp(topic2, 0), consumer3), assignor.partitionsTransferringOwnership); + + verifyValidityAndBalance(subscriptions, assignment, partitionsPerTopic); + assertEquals(partitions(tp(topic1, 0)), assignment.get(consumer1)); + assertEquals(partitions(tp(topic1, 1), tp(topic2, 1)), assignment.get(consumer2)); + assertEquals(partitions(tp(topic2, 0)), assignment.get(consumer3)); + + assertTrue(isFullyBalanced(assignment)); + } + + /** + * This unit test is testing all consumers owned less than minQuota partitions situation + */ + @Test + public void testAllConsumersAreUnderMinQuota() { + Map partitionsPerTopic = new HashMap<>(); + partitionsPerTopic.put(topic1, 2); + partitionsPerTopic.put(topic2, 3); + + List subscribedTopics = topics(topic1, topic2); + + subscriptions.put(consumer1, + buildSubscription(subscribedTopics, partitions(tp(topic1, 0)))); + subscriptions.put(consumer2, + buildSubscription(subscribedTopics, partitions(tp(topic1, 1)))); + subscriptions.put(consumer3, buildSubscription(subscribedTopics, Collections.emptyList())); + + Map> assignment = assignor.assign(partitionsPerTopic, subscriptions); + assertTrue(assignor.partitionsTransferringOwnership.isEmpty()); + + verifyValidityAndBalance(subscriptions, assignment, partitionsPerTopic); + assertEquals(partitions(tp(topic1, 0), tp(topic2, 1)), assignment.get(consumer1)); + assertEquals(partitions(tp(topic1, 1), tp(topic2, 2)), assignment.get(consumer2)); + assertEquals(partitions(tp(topic2, 0)), assignment.get(consumer3)); + + assertTrue(isFullyBalanced(assignment)); + } + + @Test + public void testAddRemoveConsumerOneTopic() { + Map partitionsPerTopic = new HashMap<>(); + partitionsPerTopic.put(topic, 3); + subscriptions.put(consumer1, new Subscription(topics(topic))); + + Map> assignment = assignor.assign(partitionsPerTopic, subscriptions); + assertEquals(partitions(tp(topic, 0), tp(topic, 1), tp(topic, 2)), assignment.get(consumer1)); + + verifyValidityAndBalance(subscriptions, assignment, partitionsPerTopic); + assertTrue(isFullyBalanced(assignment)); + + subscriptions.put(consumer1, buildSubscription(topics(topic), assignment.get(consumer1))); + subscriptions.put(consumer2, buildSubscription(topics(topic), Collections.emptyList())); + assignment = assignor.assign(partitionsPerTopic, subscriptions); + assertEquals(Collections.singletonMap(tp(topic, 2), consumer2), assignor.partitionsTransferringOwnership); + + verifyValidityAndBalance(subscriptions, assignment, partitionsPerTopic); + assertEquals(partitions(tp(topic, 0), tp(topic, 1)), assignment.get(consumer1)); + assertEquals(partitions(tp(topic, 2)), 
assignment.get(consumer2)); + assertTrue(isFullyBalanced(assignment)); + + subscriptions.remove(consumer1); + subscriptions.put(consumer2, buildSubscription(topics(topic), assignment.get(consumer2))); + assignment = assignor.assign(partitionsPerTopic, subscriptions); + assertEquals(new HashSet<>(partitions(tp(topic, 2), tp(topic, 1), tp(topic, 0))), + new HashSet<>(assignment.get(consumer2))); + assertTrue(assignor.partitionsTransferringOwnership.isEmpty()); + + verifyValidityAndBalance(subscriptions, assignment, partitionsPerTopic); + assertTrue(isFullyBalanced(assignment)); + } + + @Test + public void testAddRemoveTwoConsumersTwoTopics() { + List allTopics = topics(topic1, topic2); + + Map partitionsPerTopic = new HashMap<>(); + partitionsPerTopic.put(topic1, 3); + partitionsPerTopic.put(topic2, 4); + subscriptions.put(consumer1, new Subscription(allTopics)); + subscriptions.put(consumer2, new Subscription(allTopics)); + + Map> assignment = assignor.assign(partitionsPerTopic, subscriptions); + assertEquals(partitions(tp(topic1, 0), tp(topic1, 2), tp(topic2, 1), tp(topic2, 3)), assignment.get(consumer1)); + assertEquals(partitions(tp(topic1, 1), tp(topic2, 0), tp(topic2, 2)), assignment.get(consumer2)); + assertTrue(assignor.partitionsTransferringOwnership.isEmpty()); + + verifyValidityAndBalance(subscriptions, assignment, partitionsPerTopic); + assertTrue(isFullyBalanced(assignment)); + + // add 2 consumers + subscriptions.put(consumer1, buildSubscription(allTopics, assignment.get(consumer1))); + subscriptions.put(consumer2, buildSubscription(allTopics, assignment.get(consumer2))); + subscriptions.put(consumer3, buildSubscription(allTopics, Collections.emptyList())); + subscriptions.put(consumer4, buildSubscription(allTopics, Collections.emptyList())); + assignment = assignor.assign(partitionsPerTopic, subscriptions); + + Map expectedPartitionsTransferringOwnership = new HashMap<>(); + expectedPartitionsTransferringOwnership.put(tp(topic2, 1), consumer3); + expectedPartitionsTransferringOwnership.put(tp(topic2, 3), consumer3); + expectedPartitionsTransferringOwnership.put(tp(topic2, 2), consumer4); + assertEquals(expectedPartitionsTransferringOwnership, assignor.partitionsTransferringOwnership); + + verifyValidityAndBalance(subscriptions, assignment, partitionsPerTopic); + assertEquals(partitions(tp(topic1, 0), tp(topic1, 2)), assignment.get(consumer1)); + assertEquals(partitions(tp(topic1, 1), tp(topic2, 0)), assignment.get(consumer2)); + assertEquals(partitions(tp(topic2, 1), tp(topic2, 3)), assignment.get(consumer3)); + assertEquals(partitions(tp(topic2, 2)), assignment.get(consumer4)); + assertTrue(isFullyBalanced(assignment)); + + // remove 2 consumers + subscriptions.remove(consumer1); + subscriptions.remove(consumer2); + subscriptions.put(consumer3, buildSubscription(allTopics, assignment.get(consumer3))); + subscriptions.put(consumer4, buildSubscription(allTopics, assignment.get(consumer4))); + assignment = assignor.assign(partitionsPerTopic, subscriptions); + assertEquals(partitions(tp(topic2, 1), tp(topic2, 3), tp(topic1, 0), tp(topic2, 0)), assignment.get(consumer3)); + assertEquals(partitions(tp(topic2, 2), tp(topic1, 1), tp(topic1, 2)), assignment.get(consumer4)); + + assertTrue(assignor.partitionsTransferringOwnership.isEmpty()); + + verifyValidityAndBalance(subscriptions, assignment, partitionsPerTopic); + assertTrue(isFullyBalanced(assignment)); + } + + /** + * This unit test performs sticky assignment for a scenario that round robin assignor handles poorly. 
+ * Topics (partitions per topic): topic1 (2), topic2 (1), topic3 (2), topic4 (1), topic5 (2) + * Subscriptions: + * - consumer1: topic1, topic2, topic3, topic4, topic5 + * - consumer2: topic1, topic3, topic5 + * - consumer3: topic1, topic3, topic5 + * - consumer4: topic1, topic2, topic3, topic4, topic5 + * Round Robin Assignment Result: + * - consumer1: topic1-0, topic3-0, topic5-0 + * - consumer2: topic1-1, topic3-1, topic5-1 + * - consumer3: + * - consumer4: topic2-0, topic4-0 + * Sticky Assignment Result: + * - consumer1: topic2-0, topic3-0 + * - consumer2: topic1-0, topic3-1 + * - consumer3: topic1-1, topic5-0 + * - consumer4: topic4-0, topic5-1 + */ + @Test + public void testPoorRoundRobinAssignmentScenario() { + Map partitionsPerTopic = new HashMap<>(); + for (int i = 1; i <= 5; i++) + partitionsPerTopic.put(String.format("topic%d", i), (i % 2) + 1); + + subscriptions.put("consumer1", + new Subscription(topics("topic1", "topic2", "topic3", "topic4", "topic5"))); + subscriptions.put("consumer2", + new Subscription(topics("topic1", "topic3", "topic5"))); + subscriptions.put("consumer3", + new Subscription(topics("topic1", "topic3", "topic5"))); + subscriptions.put("consumer4", + new Subscription(topics("topic1", "topic2", "topic3", "topic4", "topic5"))); + + Map> assignment = assignor.assign(partitionsPerTopic, subscriptions); + verifyValidityAndBalance(subscriptions, assignment, partitionsPerTopic); + } + + @Test + public void testAddRemoveTopicTwoConsumers() { + Map partitionsPerTopic = new HashMap<>(); + partitionsPerTopic.put(topic, 3); + subscriptions.put(consumer1, new Subscription(topics(topic))); + subscriptions.put(consumer2, new Subscription(topics(topic))); + + Map> assignment = assignor.assign(partitionsPerTopic, subscriptions); + assertTrue(assignor.partitionsTransferringOwnership.isEmpty()); + // verify balance + assertTrue(isFullyBalanced(assignment)); + verifyValidityAndBalance(subscriptions, assignment, partitionsPerTopic); + // verify stickiness + List consumer1Assignment1 = assignment.get(consumer1); + List consumer2Assignment1 = assignment.get(consumer2); + assertTrue((consumer1Assignment1.size() == 1 && consumer2Assignment1.size() == 2) || + (consumer1Assignment1.size() == 2 && consumer2Assignment1.size() == 1)); + + partitionsPerTopic.put(topic2, 3); + subscriptions.put(consumer1, buildSubscription(topics(topic, topic2), assignment.get(consumer1))); + subscriptions.put(consumer2, buildSubscription(topics(topic, topic2), assignment.get(consumer2))); + + assignment = assignor.assign(partitionsPerTopic, subscriptions); + assertTrue(assignor.partitionsTransferringOwnership.isEmpty()); + // verify balance + verifyValidityAndBalance(subscriptions, assignment, partitionsPerTopic); + assertTrue(isFullyBalanced(assignment)); + // verify stickiness + List consumer1assignment = assignment.get(consumer1); + List consumer2assignment = assignment.get(consumer2); + assertTrue(consumer1assignment.size() == 3 && consumer2assignment.size() == 3); + assertTrue(consumer1assignment.containsAll(consumer1Assignment1)); + assertTrue(consumer2assignment.containsAll(consumer2Assignment1)); + + partitionsPerTopic.remove(topic); + subscriptions.put(consumer1, buildSubscription(topics(topic2), assignment.get(consumer1))); + subscriptions.put(consumer2, buildSubscription(topics(topic2), assignment.get(consumer2))); + + assignment = assignor.assign(partitionsPerTopic, subscriptions); + assertTrue(assignor.partitionsTransferringOwnership.isEmpty()); + // verify balance + 
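+        // only topic2's three partitions remain after the topic deletion, so a 2/1 split between the two
+        // consumers is expected, and each consumer should retain only partitions it already owned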
verifyValidityAndBalance(subscriptions, assignment, partitionsPerTopic); + assertTrue(isFullyBalanced(assignment)); + // verify stickiness + List consumer1Assignment3 = assignment.get(consumer1); + List consumer2Assignment3 = assignment.get(consumer2); + assertTrue((consumer1Assignment3.size() == 1 && consumer2Assignment3.size() == 2) || + (consumer1Assignment3.size() == 2 && consumer2Assignment3.size() == 1)); + assertTrue(consumer1assignment.containsAll(consumer1Assignment3)); + assertTrue(consumer2assignment.containsAll(consumer2Assignment3)); + } + + @Test + public void testReassignmentAfterOneConsumerLeaves() { + Map partitionsPerTopic = new HashMap<>(); + for (int i = 1; i < 20; i++) + partitionsPerTopic.put(getTopicName(i, 20), i); + + for (int i = 1; i < 20; i++) { + List topics = new ArrayList<>(); + for (int j = 1; j <= i; j++) + topics.add(getTopicName(j, 20)); + subscriptions.put(getConsumerName(i, 20), new Subscription(topics)); + } + + Map> assignment = assignor.assign(partitionsPerTopic, subscriptions); + verifyValidityAndBalance(subscriptions, assignment, partitionsPerTopic); + + for (int i = 1; i < 20; i++) { + String consumer = getConsumerName(i, 20); + subscriptions.put(consumer, + buildSubscription(subscriptions.get(consumer).topics(), assignment.get(consumer))); + } + subscriptions.remove("consumer10"); + + assignment = assignor.assign(partitionsPerTopic, subscriptions); + verifyValidityAndBalance(subscriptions, assignment, partitionsPerTopic); + assertTrue(assignor.isSticky()); + } + + + @Test + public void testReassignmentAfterOneConsumerAdded() { + Map partitionsPerTopic = new HashMap<>(); + partitionsPerTopic.put("topic", 20); + + for (int i = 1; i < 10; i++) + subscriptions.put(getConsumerName(i, 10), + new Subscription(topics("topic"))); + + Map> assignment = assignor.assign(partitionsPerTopic, subscriptions); + assertTrue(assignor.partitionsTransferringOwnership.isEmpty()); + verifyValidityAndBalance(subscriptions, assignment, partitionsPerTopic); + + // add a new consumer + subscriptions.put(getConsumerName(10, 10), new Subscription(topics("topic"))); + + assignment = assignor.assign(partitionsPerTopic, subscriptions); + assertTrue(assignor.partitionsTransferringOwnership.isEmpty()); + verifyValidityAndBalance(subscriptions, assignment, partitionsPerTopic); + } + + @Test + public void testSameSubscriptions() { + Map partitionsPerTopic = new HashMap<>(); + for (int i = 1; i < 15; i++) + partitionsPerTopic.put(getTopicName(i, 15), i); + + for (int i = 1; i < 9; i++) { + List topics = new ArrayList<>(); + for (int j = 1; j <= partitionsPerTopic.size(); j++) + topics.add(getTopicName(j, 15)); + subscriptions.put(getConsumerName(i, 9), new Subscription(topics)); + } + + Map> assignment = assignor.assign(partitionsPerTopic, subscriptions); + assertTrue(assignor.partitionsTransferringOwnership.isEmpty()); + verifyValidityAndBalance(subscriptions, assignment, partitionsPerTopic); + + for (int i = 1; i < 9; i++) { + String consumer = getConsumerName(i, 9); + subscriptions.put(consumer, + buildSubscription(subscriptions.get(consumer).topics(), assignment.get(consumer))); + } + subscriptions.remove(getConsumerName(5, 9)); + + assignment = assignor.assign(partitionsPerTopic, subscriptions); + assertTrue(assignor.partitionsTransferringOwnership.isEmpty()); + verifyValidityAndBalance(subscriptions, assignment, partitionsPerTopic); + } + + @Timeout(30) + @Test + public void testLargeAssignmentAndGroupWithUniformSubscription() { + // 1 million partitions! 
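+        // 500 topics x 2,000 partitions each = 1,000,000 partitions, with all 2,000 consumers subscribed to
+        // every topic; the @Timeout(30) above presumably acts as a coarse performance guard for the assignor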
+ int topicCount = 500; + int partitionCount = 2_000; + int consumerCount = 2_000; + + List topics = new ArrayList<>(); + Map partitionsPerTopic = new HashMap<>(); + for (int i = 0; i < topicCount; i++) { + String topicName = getTopicName(i, topicCount); + topics.add(topicName); + partitionsPerTopic.put(topicName, partitionCount); + } + + for (int i = 0; i < consumerCount; i++) { + subscriptions.put(getConsumerName(i, consumerCount), new Subscription(topics)); + } + + Map> assignment = assignor.assign(partitionsPerTopic, subscriptions); + + for (int i = 1; i < consumerCount; i++) { + String consumer = getConsumerName(i, consumerCount); + subscriptions.put(consumer, buildSubscription(topics, assignment.get(consumer))); + } + + assignor.assign(partitionsPerTopic, subscriptions); + } + + @Timeout(60) + @Test + public void testLargeAssignmentAndGroupWithNonEqualSubscription() { + // 1 million partitions! + int topicCount = 500; + int partitionCount = 2_000; + int consumerCount = 2_000; + + List topics = new ArrayList<>(); + Map partitionsPerTopic = new HashMap<>(); + for (int i = 0; i < topicCount; i++) { + String topicName = getTopicName(i, topicCount); + topics.add(topicName); + partitionsPerTopic.put(topicName, partitionCount); + } + for (int i = 0; i < consumerCount; i++) { + if (i == consumerCount - 1) { + subscriptions.put(getConsumerName(i, consumerCount), new Subscription(topics.subList(0, 1))); + } else { + subscriptions.put(getConsumerName(i, consumerCount), new Subscription(topics)); + } + } + + Map> assignment = assignor.assign(partitionsPerTopic, subscriptions); + + for (int i = 1; i < consumerCount; i++) { + String consumer = getConsumerName(i, consumerCount); + if (i == consumerCount - 1) { + subscriptions.put(consumer, buildSubscription(topics.subList(0, 1), assignment.get(consumer))); + } else { + subscriptions.put(consumer, buildSubscription(topics, assignment.get(consumer))); + } + } + + assignor.assign(partitionsPerTopic, subscriptions); + } + + @Test + public void testLargeAssignmentWithMultipleConsumersLeavingAndRandomSubscription() { + Random rand = new Random(); + int topicCount = 40; + int consumerCount = 200; + + Map partitionsPerTopic = new HashMap<>(); + for (int i = 0; i < topicCount; i++) + partitionsPerTopic.put(getTopicName(i, topicCount), rand.nextInt(10) + 1); + + for (int i = 0; i < consumerCount; i++) { + List topics = new ArrayList<>(); + for (int j = 0; j < rand.nextInt(20); j++) + topics.add(getTopicName(rand.nextInt(topicCount), topicCount)); + subscriptions.put(getConsumerName(i, consumerCount), new Subscription(topics)); + } + + Map> assignment = assignor.assign(partitionsPerTopic, subscriptions); + verifyValidityAndBalance(subscriptions, assignment, partitionsPerTopic); + + for (int i = 1; i < consumerCount; i++) { + String consumer = getConsumerName(i, consumerCount); + subscriptions.put(consumer, + buildSubscription(subscriptions.get(consumer).topics(), assignment.get(consumer))); + } + for (int i = 0; i < 50; ++i) { + String c = getConsumerName(rand.nextInt(consumerCount), consumerCount); + subscriptions.remove(c); + } + + assignment = assignor.assign(partitionsPerTopic, subscriptions); + verifyValidityAndBalance(subscriptions, assignment, partitionsPerTopic); + assertTrue(assignor.isSticky()); + } + + @Test + public void testNewSubscription() { + Map partitionsPerTopic = new HashMap<>(); + for (int i = 1; i < 5; i++) + partitionsPerTopic.put(getTopicName(i, 5), 1); + + for (int i = 0; i < 3; i++) { + List topics = new ArrayList<>(); + for (int j = 
i; j <= 3 * i - 2; j++) + topics.add(getTopicName(j, 5)); + subscriptions.put(getConsumerName(i, 3), new Subscription(topics)); + } + + Map> assignment = assignor.assign(partitionsPerTopic, subscriptions); + verifyValidityAndBalance(subscriptions, assignment, partitionsPerTopic); + + subscriptions.get(getConsumerName(0, 3)).topics().add(getTopicName(1, 5)); + + assignment = assignor.assign(partitionsPerTopic, subscriptions); + verifyValidityAndBalance(subscriptions, assignment, partitionsPerTopic); + assertTrue(assignor.isSticky()); + } + + @Test + public void testMoveExistingAssignments() { + String topic4 = "topic4"; + String topic5 = "topic5"; + String topic6 = "topic6"; + + Map partitionsPerTopic = new HashMap<>(); + for (int i = 1; i <= 6; i++) + partitionsPerTopic.put(String.format("topic%d", i), 1); + + subscriptions.put(consumer1, + buildSubscription(topics(topic1, topic2), + partitions(tp(topic1, 0)))); + subscriptions.put(consumer2, + buildSubscription(topics(topic1, topic2, topic3, topic4), + partitions(tp(topic2, 0), tp(topic3, 0)))); + subscriptions.put(consumer3, + buildSubscription(topics(topic2, topic3, topic4, topic5, topic6), + partitions(tp(topic4, 0), tp(topic5, 0), tp(topic6, 0)))); + + Map> assignment = assignor.assign(partitionsPerTopic, subscriptions); + assertNull(assignor.partitionsTransferringOwnership); + verifyValidityAndBalance(subscriptions, assignment, partitionsPerTopic); + } + + @Test + public void testStickiness() { + Map partitionsPerTopic = new HashMap<>(); + partitionsPerTopic.put(topic1, 3); + + subscriptions.put(consumer1, new Subscription(topics(topic1))); + subscriptions.put(consumer2, new Subscription(topics(topic1))); + subscriptions.put(consumer3, new Subscription(topics(topic1))); + subscriptions.put(consumer4, new Subscription(topics(topic1))); + + Map> assignment = assignor.assign(partitionsPerTopic, subscriptions); + assertTrue(assignor.partitionsTransferringOwnership.isEmpty()); + verifyValidityAndBalance(subscriptions, assignment, partitionsPerTopic); + Map partitionsAssigned = new HashMap<>(); + + Set>> assignments = assignment.entrySet(); + for (Map.Entry> entry: assignments) { + String consumer = entry.getKey(); + List topicPartitions = entry.getValue(); + int size = topicPartitions.size(); + assertTrue(size <= 1, "Consumer " + consumer + " is assigned more topic partitions than expected."); + if (size == 1) + partitionsAssigned.put(consumer, topicPartitions.get(0)); + } + + // removing the potential group leader + subscriptions.remove(consumer1); + subscriptions.put(consumer2, + buildSubscription(topics(topic1), assignment.get(consumer2))); + subscriptions.put(consumer3, + buildSubscription(topics(topic1), assignment.get(consumer3))); + subscriptions.put(consumer4, + buildSubscription(topics(topic1), assignment.get(consumer4))); + + + assignment = assignor.assign(partitionsPerTopic, subscriptions); + assertTrue(assignor.partitionsTransferringOwnership.isEmpty()); + verifyValidityAndBalance(subscriptions, assignment, partitionsPerTopic); + + assignments = assignment.entrySet(); + for (Map.Entry> entry: assignments) { + String consumer = entry.getKey(); + List topicPartitions = entry.getValue(); + assertEquals(1, topicPartitions.size(), "Consumer " + consumer + " is assigned more topic partitions than expected."); + assertTrue((!partitionsAssigned.containsKey(consumer)) || (assignment.get(consumer).contains(partitionsAssigned.get(consumer))), + "Stickiness was not honored for consumer " + consumer); + } + } + + @Test + public void 
testAssignmentUpdatedForDeletedTopic() { + Map partitionsPerTopic = new HashMap<>(); + partitionsPerTopic.put(topic1, 1); + partitionsPerTopic.put(topic3, 100); + subscriptions = Collections.singletonMap(consumerId, new Subscription(topics(topic1, topic2, topic3))); + + Map> assignment = assignor.assign(partitionsPerTopic, subscriptions); + assertTrue(assignor.partitionsTransferringOwnership.isEmpty()); + assertEquals(assignment.values().stream().mapToInt(List::size).sum(), 1 + 100); + assertEquals(Collections.singleton(consumerId), assignment.keySet()); + assertTrue(isFullyBalanced(assignment)); + } + + @Test + public void testNoExceptionThrownWhenOnlySubscribedTopicDeleted() { + Map partitionsPerTopic = new HashMap<>(); + partitionsPerTopic.put(topic, 3); + subscriptions.put(consumerId, new Subscription(topics(topic))); + Map> assignment = assignor.assign(partitionsPerTopic, subscriptions); + assertTrue(assignor.partitionsTransferringOwnership.isEmpty()); + subscriptions.put(consumerId, buildSubscription(topics(topic), assignment.get(consumerId))); + + assignment = assignor.assign(Collections.emptyMap(), subscriptions); + assertTrue(assignor.partitionsTransferringOwnership.isEmpty()); + assertEquals(assignment.size(), 1); + assertTrue(assignment.get(consumerId).isEmpty()); + } + + @Test + public void testReassignmentWithRandomSubscriptionsAndChanges() { + final int minNumConsumers = 20; + final int maxNumConsumers = 40; + final int minNumTopics = 10; + final int maxNumTopics = 20; + + for (int round = 1; round <= 100; ++round) { + int numTopics = minNumTopics + new Random().nextInt(maxNumTopics - minNumTopics); + + ArrayList topics = new ArrayList<>(); + + Map partitionsPerTopic = new HashMap<>(); + for (int i = 0; i < numTopics; ++i) { + topics.add(getTopicName(i, maxNumTopics)); + partitionsPerTopic.put(getTopicName(i, maxNumTopics), i + 1); + } + + int numConsumers = minNumConsumers + new Random().nextInt(maxNumConsumers - minNumConsumers); + + for (int i = 0; i < numConsumers; ++i) { + List sub = Utils.sorted(getRandomSublist(topics)); + subscriptions.put(getConsumerName(i, maxNumConsumers), new Subscription(sub)); + } + + assignor = createAssignor(); + + Map> assignment = assignor.assign(partitionsPerTopic, subscriptions); + verifyValidityAndBalance(subscriptions, assignment, partitionsPerTopic); + + subscriptions.clear(); + for (int i = 0; i < numConsumers; ++i) { + List sub = Utils.sorted(getRandomSublist(topics)); + String consumer = getConsumerName(i, maxNumConsumers); + subscriptions.put(consumer, buildSubscription(sub, assignment.get(consumer))); + } + + assignment = assignor.assign(partitionsPerTopic, subscriptions); + verifyValidityAndBalance(subscriptions, assignment, partitionsPerTopic); + assertTrue(assignor.isSticky()); + } + } + + @Test + public void testAllConsumersReachExpectedQuotaAndAreConsideredFilled() { + Map partitionsPerTopic = new HashMap<>(); + partitionsPerTopic.put(topic, 4); + + subscriptions.put(consumer1, buildSubscription(topics(topic), partitions(tp(topic, 0), tp(topic, 1)))); + subscriptions.put(consumer2, buildSubscription(topics(topic), partitions(tp(topic, 2)))); + subscriptions.put(consumer3, buildSubscription(topics(topic), Collections.emptyList())); + + Map> assignment = assignor.assign(partitionsPerTopic, subscriptions); + assertEquals(partitions(tp(topic, 0), tp(topic, 1)), assignment.get(consumer1)); + assertEquals(partitions(tp(topic, 2)), assignment.get(consumer2)); + assertEquals(partitions(tp(topic, 3)), assignment.get(consumer3)); + 
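+        // with 4 partitions and 3 consumers, minQuota is 1 and maxQuota is 2, and only one member may be
+        // filled to maxQuota; consumer1 takes that slot with its two owned partitions, consumer2 is
+        // considered filled at minQuota, so the single unowned partition goes to consumer3 and no
+        // ownership transfer is recorded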
assertTrue(assignor.partitionsTransferringOwnership.isEmpty()); + + verifyValidityAndBalance(subscriptions, assignment, partitionsPerTopic); + assertTrue(isFullyBalanced(assignment)); + } + + @Test + public void testOwnedPartitionsAreInvalidatedForConsumerWithStaleGeneration() { + Map partitionsPerTopic = new HashMap<>(); + partitionsPerTopic.put(topic, 3); + partitionsPerTopic.put(topic2, 3); + + int currentGeneration = 10; + + subscriptions.put(consumer1, buildSubscriptionWithGeneration(topics(topic, topic2), partitions(tp(topic, 0), tp(topic, 2), tp(topic2, 1)), currentGeneration)); + subscriptions.put(consumer2, buildSubscriptionWithGeneration(topics(topic, topic2), partitions(tp(topic, 0), tp(topic, 2), tp(topic2, 1)), currentGeneration - 1)); + + Map> assignment = assignor.assign(partitionsPerTopic, subscriptions); + assertEquals(new HashSet<>(partitions(tp(topic, 0), tp(topic, 2), tp(topic2, 1))), new HashSet<>(assignment.get(consumer1))); + assertEquals(new HashSet<>(partitions(tp(topic, 1), tp(topic2, 0), tp(topic2, 2))), new HashSet<>(assignment.get(consumer2))); + assertTrue(assignor.partitionsTransferringOwnership.isEmpty()); + + verifyValidityAndBalance(subscriptions, assignment, partitionsPerTopic); + assertTrue(isFullyBalanced(assignment)); + } + + @Test + public void testOwnedPartitionsAreInvalidatedForConsumerWithNoGeneration() { + Map partitionsPerTopic = new HashMap<>(); + partitionsPerTopic.put(topic, 3); + partitionsPerTopic.put(topic2, 3); + + int currentGeneration = 10; + + subscriptions.put(consumer1, buildSubscriptionWithGeneration(topics(topic, topic2), partitions(tp(topic, 0), tp(topic, 2), tp(topic2, 1)), currentGeneration)); + subscriptions.put(consumer2, buildSubscriptionWithGeneration(topics(topic, topic2), partitions(tp(topic, 0), tp(topic, 2), tp(topic2, 1)), DEFAULT_GENERATION)); + + Map> assignment = assignor.assign(partitionsPerTopic, subscriptions); + assertEquals(new HashSet<>(partitions(tp(topic, 0), tp(topic, 2), tp(topic2, 1))), new HashSet<>(assignment.get(consumer1))); + assertEquals(new HashSet<>(partitions(tp(topic, 1), tp(topic2, 0), tp(topic2, 2))), new HashSet<>(assignment.get(consumer2))); + assertTrue(assignor.partitionsTransferringOwnership.isEmpty()); + + verifyValidityAndBalance(subscriptions, assignment, partitionsPerTopic); + assertTrue(isFullyBalanced(assignment)); + } + + @Test + public void testPartitionsTransferringOwnershipIncludeThePartitionClaimedByMultipleConsumersInSameGeneration() { + Map partitionsPerTopic = new HashMap<>(); + partitionsPerTopic.put(topic, 3); + + // partition topic-0 is owned by multiple consumer + subscriptions.put(consumer1, buildSubscription(topics(topic), partitions(tp(topic, 0), tp(topic, 1)))); + subscriptions.put(consumer2, buildSubscription(topics(topic), partitions(tp(topic, 0), tp(topic, 2)))); + subscriptions.put(consumer3, buildSubscription(topics(topic), emptyList())); + + Map> assignment = assignor.assign(partitionsPerTopic, subscriptions); + // we should include the partitions claimed by multiple consumers in partitionsTransferringOwnership + assertEquals(Collections.singletonMap(tp(topic, 0), consumer3), assignor.partitionsTransferringOwnership); + + verifyValidityAndBalance(subscriptions, assignment, partitionsPerTopic); + assertEquals(partitions(tp(topic, 1)), assignment.get(consumer1)); + assertEquals(partitions(tp(topic, 2)), assignment.get(consumer2)); + assertEquals(partitions(tp(topic, 0)), assignment.get(consumer3)); + assertTrue(isFullyBalanced(assignment)); + } + + private String 
getTopicName(int i, int maxNum) { + return getCanonicalName("t", i, maxNum); + } + + private String getConsumerName(int i, int maxNum) { + return getCanonicalName("c", i, maxNum); + } + + private String getCanonicalName(String str, int i, int maxNum) { + return str + pad(i, Integer.toString(maxNum).length()); + } + + private String pad(int num, int digits) { + StringBuilder sb = new StringBuilder(); + int iDigits = Integer.toString(num).length(); + + for (int i = 1; i <= digits - iDigits; ++i) + sb.append("0"); + + sb.append(num); + return sb.toString(); + } + + protected static List topics(String... topics) { + return Arrays.asList(topics); + } + + protected static List partitions(TopicPartition... partitions) { + return Arrays.asList(partitions); + } + + protected static TopicPartition tp(String topic, int partition) { + return new TopicPartition(topic, partition); + } + + protected static boolean isFullyBalanced(Map> assignment) { + int min = Integer.MAX_VALUE; + int max = Integer.MIN_VALUE; + for (List topicPartitions: assignment.values()) { + int size = topicPartitions.size(); + if (size < min) + min = size; + if (size > max) + max = size; + } + return max - min <= 1; + } + + protected static List getRandomSublist(ArrayList list) { + List selectedItems = new ArrayList<>(list); + int len = list.size(); + Random random = new Random(); + int howManyToRemove = random.nextInt(len); + + for (int i = 1; i <= howManyToRemove; ++i) + selectedItems.remove(random.nextInt(selectedItems.size())); + + return selectedItems; + } + + /** + * Verifies that the given assignment is valid with respect to the given subscriptions + * Validity requirements: + * - each consumer is subscribed to topics of all partitions assigned to it, and + * - each partition is assigned to no more than one consumer + * Balance requirements: + * - the assignment is fully balanced (the numbers of topic partitions assigned to consumers differ by at most one), or + * - there is no topic partition that can be moved from one consumer to another with 2+ fewer topic partitions + * + * @param subscriptions: topic subscriptions of each consumer + * @param assignments: given assignment for balance check + * @param partitionsPerTopic: number of partitions per topic + */ + protected void verifyValidityAndBalance(Map subscriptions, + Map> assignments, + Map partitionsPerTopic) { + int size = subscriptions.size(); + assert size == assignments.size(); + + List consumers = Utils.sorted(assignments.keySet()); + + for (int i = 0; i < size; ++i) { + String consumer = consumers.get(i); + List partitions = assignments.get(consumer); + for (TopicPartition partition: partitions) + assertTrue(subscriptions.get(consumer).topics().contains(partition.topic()), + "Error: Partition " + partition + "is assigned to c" + i + ", but it is not subscribed to Topic t" + + partition.topic() + "\nSubscriptions: " + subscriptions + "\nAssignments: " + assignments); + + if (i == size - 1) + continue; + + for (int j = i + 1; j < size; ++j) { + String otherConsumer = consumers.get(j); + List otherPartitions = assignments.get(otherConsumer); + + Set intersection = new HashSet<>(partitions); + intersection.retainAll(otherPartitions); + assertTrue(intersection.isEmpty(), + "Error: Consumers c" + i + " and c" + j + " have common partitions assigned to them: " + intersection + + "\nSubscriptions: " + subscriptions + "\nAssignments: " + assignments); + + int len = partitions.size(); + int otherLen = otherPartitions.size(); + + if (Math.abs(len - otherLen) <= 1) + continue; 
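+                // the two assignments differ by more than one partition, which is only acceptable when the
+                // two consumers share no assigned topics; otherwise a partition could have been moved over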
+ + Map> map = CollectionUtils.groupPartitionsByTopic(partitions); + Map> otherMap = CollectionUtils.groupPartitionsByTopic(otherPartitions); + + int moreLoaded = len > otherLen ? i : j; + int lessLoaded = len > otherLen ? j : i; + + // If there's any overlap in the subscribed topics, we should have been able to balance partitions + for (String topic: map.keySet()) { + assertFalse(otherMap.containsKey(topic), + "Error: Some partitions can be moved from c" + moreLoaded + " to c" + lessLoaded + " to achieve a better balance" + + "\nc" + i + " has " + len + " partitions, and c" + j + " has " + otherLen + " partitions." + + "\nSubscriptions: " + subscriptions + + "\nAssignments: " + assignments); + } + } + } + } + +} diff --git a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/ConsumerCoordinatorTest.java b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/ConsumerCoordinatorTest.java new file mode 100644 index 0000000..96aaf8b --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/ConsumerCoordinatorTest.java @@ -0,0 +1,3426 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.clients.consumer.internals; + +import org.apache.kafka.clients.ClientResponse; +import org.apache.kafka.clients.GroupRebalanceConfig; +import org.apache.kafka.clients.MockClient; +import org.apache.kafka.clients.NodeApiVersions; +import org.apache.kafka.clients.consumer.CommitFailedException; +import org.apache.kafka.clients.consumer.ConsumerGroupMetadata; +import org.apache.kafka.clients.consumer.ConsumerPartitionAssignor; +import org.apache.kafka.clients.consumer.OffsetAndMetadata; +import org.apache.kafka.clients.consumer.OffsetCommitCallback; +import org.apache.kafka.clients.consumer.OffsetResetStrategy; +import org.apache.kafka.clients.consumer.RetriableCommitFailedException; +import org.apache.kafka.common.KafkaException; +import org.apache.kafka.common.Metric; +import org.apache.kafka.common.MetricName; +import org.apache.kafka.common.Node; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.errors.ApiException; +import org.apache.kafka.common.errors.AuthenticationException; +import org.apache.kafka.common.errors.DisconnectException; +import org.apache.kafka.common.errors.FencedInstanceIdException; +import org.apache.kafka.common.errors.GroupAuthorizationException; +import org.apache.kafka.common.errors.OffsetMetadataTooLarge; +import org.apache.kafka.common.errors.RebalanceInProgressException; +import org.apache.kafka.common.errors.TopicAuthorizationException; +import org.apache.kafka.common.errors.UnsupportedVersionException; +import org.apache.kafka.common.errors.WakeupException; +import org.apache.kafka.common.internals.ClusterResourceListeners; +import org.apache.kafka.common.internals.Topic; +import org.apache.kafka.common.message.HeartbeatResponseData; +import org.apache.kafka.common.message.JoinGroupRequestData; +import org.apache.kafka.common.message.JoinGroupResponseData; +import org.apache.kafka.common.message.LeaveGroupRequestData.MemberIdentity; +import org.apache.kafka.common.message.LeaveGroupResponseData; +import org.apache.kafka.common.message.OffsetCommitRequestData; +import org.apache.kafka.common.message.OffsetCommitResponseData; +import org.apache.kafka.common.message.SyncGroupResponseData; +import org.apache.kafka.common.metrics.KafkaMetric; +import org.apache.kafka.common.metrics.Metrics; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.Errors; +import org.apache.kafka.common.protocol.types.Field; +import org.apache.kafka.common.protocol.types.Schema; +import org.apache.kafka.common.protocol.types.Struct; +import org.apache.kafka.common.protocol.types.Type; +import org.apache.kafka.common.record.RecordBatch; +import org.apache.kafka.common.requests.FindCoordinatorResponse; +import org.apache.kafka.common.requests.HeartbeatResponse; +import org.apache.kafka.common.requests.JoinGroupRequest; +import org.apache.kafka.common.requests.JoinGroupResponse; +import org.apache.kafka.common.requests.LeaveGroupRequest; +import org.apache.kafka.common.requests.LeaveGroupResponse; +import org.apache.kafka.common.requests.MetadataResponse; +import org.apache.kafka.common.requests.OffsetCommitRequest; +import org.apache.kafka.common.requests.OffsetCommitResponse; +import org.apache.kafka.common.requests.OffsetFetchResponse; +import org.apache.kafka.common.requests.RequestTestUtils; +import org.apache.kafka.common.requests.SyncGroupRequest; +import org.apache.kafka.common.requests.SyncGroupResponse; +import org.apache.kafka.common.utils.LogContext; +import 
org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.common.utils.SystemTime; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.test.TestUtils; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; +import org.mockito.Mockito; + +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; +import java.util.regex.Pattern; + +import static java.util.Collections.emptyList; +import static java.util.Collections.emptySet; +import static java.util.Collections.singleton; +import static java.util.Collections.singletonList; +import static java.util.Collections.singletonMap; +import static org.apache.kafka.clients.consumer.ConsumerPartitionAssignor.RebalanceProtocol.COOPERATIVE; +import static org.apache.kafka.clients.consumer.ConsumerPartitionAssignor.RebalanceProtocol.EAGER; +import static org.apache.kafka.clients.consumer.CooperativeStickyAssignor.COOPERATIVE_STICKY_ASSIGNOR_NAME; +import static org.apache.kafka.common.utils.Utils.mkEntry; +import static org.apache.kafka.common.utils.Utils.mkMap; +import static org.apache.kafka.test.TestUtils.toSet; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; + +public abstract class ConsumerCoordinatorTest { + private final String topic1 = "test1"; + private final String topic2 = "test2"; + private final TopicPartition t1p = new TopicPartition(topic1, 0); + private final TopicPartition t2p = new TopicPartition(topic2, 0); + private final String groupId = "test-group"; + private final Optional groupInstanceId = Optional.of("test-instance"); + private final int rebalanceTimeoutMs = 60000; + private final int sessionTimeoutMs = 10000; + private final int heartbeatIntervalMs = 5000; + private final long retryBackoffMs = 100; + private final int autoCommitIntervalMs = 2000; + private final int requestTimeoutMs = 30000; + private final MockTime time = new MockTime(); + private GroupRebalanceConfig rebalanceConfig; + + private final ConsumerPartitionAssignor.RebalanceProtocol protocol; + private final MockPartitionAssignor partitionAssignor; + private final ThrowOnAssignmentAssignor throwOnAssignmentAssignor; + private final ThrowOnAssignmentAssignor throwFatalErrorOnAssignmentAssignor; + private final List assignors; + private final Map assignorMap; + private final String consumerId = "consumer"; + private final String consumerId2 = "consumer2"; + + private MockClient client; + private MetadataResponse metadataResponse = 
RequestTestUtils.metadataUpdateWith(1, new HashMap() { + { + put(topic1, 1); + put(topic2, 1); + } + }); + private Node node = metadataResponse.brokers().iterator().next(); + private SubscriptionState subscriptions; + private ConsumerMetadata metadata; + private Metrics metrics; + private ConsumerNetworkClient consumerClient; + private MockRebalanceListener rebalanceListener; + private MockCommitCallback mockOffsetCommitCallback; + private ConsumerCoordinator coordinator; + + public ConsumerCoordinatorTest(final ConsumerPartitionAssignor.RebalanceProtocol protocol) { + this.protocol = protocol; + + this.partitionAssignor = new MockPartitionAssignor(Collections.singletonList(protocol)); + this.throwOnAssignmentAssignor = new ThrowOnAssignmentAssignor(Collections.singletonList(protocol), + new KafkaException("Kaboom for assignment!"), + "throw-on-assignment-assignor"); + this.throwFatalErrorOnAssignmentAssignor = new ThrowOnAssignmentAssignor(Collections.singletonList(protocol), + new IllegalStateException("Illegal state for assignment!"), + "throw-fatal-error-on-assignment-assignor"); + this.assignors = Arrays.asList(partitionAssignor, throwOnAssignmentAssignor, throwFatalErrorOnAssignmentAssignor); + this.assignorMap = mkMap(mkEntry(partitionAssignor.name(), partitionAssignor), + mkEntry(throwOnAssignmentAssignor.name(), throwOnAssignmentAssignor), + mkEntry(throwFatalErrorOnAssignmentAssignor.name(), throwFatalErrorOnAssignmentAssignor)); + } + + @BeforeEach + public void setup() { + LogContext logContext = new LogContext(); + this.subscriptions = new SubscriptionState(logContext, OffsetResetStrategy.EARLIEST); + this.metadata = new ConsumerMetadata(0, Long.MAX_VALUE, false, + false, subscriptions, logContext, new ClusterResourceListeners()); + this.client = new MockClient(time, metadata); + this.client.updateMetadata(metadataResponse); + this.consumerClient = new ConsumerNetworkClient(logContext, client, metadata, time, 100, + requestTimeoutMs, Integer.MAX_VALUE); + this.metrics = new Metrics(time); + this.rebalanceListener = new MockRebalanceListener(); + this.mockOffsetCommitCallback = new MockCommitCallback(); + this.partitionAssignor.clear(); + this.rebalanceConfig = buildRebalanceConfig(Optional.empty()); + this.coordinator = buildCoordinator(rebalanceConfig, + metrics, + assignors, + false, + subscriptions); + } + + private GroupRebalanceConfig buildRebalanceConfig(Optional groupInstanceId) { + return new GroupRebalanceConfig(sessionTimeoutMs, + rebalanceTimeoutMs, + heartbeatIntervalMs, + groupId, + groupInstanceId, + retryBackoffMs, + !groupInstanceId.isPresent()); + } + + @AfterEach + public void teardown() { + this.metrics.close(); + this.coordinator.close(time.timer(0)); + } + + @Test + public void testMetrics() { + assertNotNull(getMetric("commit-latency-avg")); + assertNotNull(getMetric("commit-latency-max")); + assertNotNull(getMetric("commit-rate")); + assertNotNull(getMetric("commit-total")); + assertNotNull(getMetric("partition-revoked-latency-avg")); + assertNotNull(getMetric("partition-revoked-latency-max")); + assertNotNull(getMetric("partition-assigned-latency-avg")); + assertNotNull(getMetric("partition-assigned-latency-max")); + assertNotNull(getMetric("partition-lost-latency-avg")); + assertNotNull(getMetric("partition-lost-latency-max")); + assertNotNull(getMetric("assigned-partitions")); + + metrics.sensor("commit-latency").record(1.0d); + metrics.sensor("commit-latency").record(6.0d); + metrics.sensor("commit-latency").record(2.0d); + + assertEquals(3.0d, 
getMetric("commit-latency-avg").metricValue()); + assertEquals(6.0d, getMetric("commit-latency-max").metricValue()); + assertEquals(0.1d, getMetric("commit-rate").metricValue()); + assertEquals(3.0d, getMetric("commit-total").metricValue()); + + metrics.sensor("partition-revoked-latency").record(1.0d); + metrics.sensor("partition-revoked-latency").record(2.0d); + metrics.sensor("partition-assigned-latency").record(1.0d); + metrics.sensor("partition-assigned-latency").record(2.0d); + metrics.sensor("partition-lost-latency").record(1.0d); + metrics.sensor("partition-lost-latency").record(2.0d); + + assertEquals(1.5d, getMetric("partition-revoked-latency-avg").metricValue()); + assertEquals(2.0d, getMetric("partition-revoked-latency-max").metricValue()); + assertEquals(1.5d, getMetric("partition-assigned-latency-avg").metricValue()); + assertEquals(2.0d, getMetric("partition-assigned-latency-max").metricValue()); + assertEquals(1.5d, getMetric("partition-lost-latency-avg").metricValue()); + assertEquals(2.0d, getMetric("partition-lost-latency-max").metricValue()); + + assertEquals(0.0d, getMetric("assigned-partitions").metricValue()); + subscriptions.assignFromUser(Collections.singleton(t1p)); + assertEquals(1.0d, getMetric("assigned-partitions").metricValue()); + subscriptions.assignFromUser(Utils.mkSet(t1p, t2p)); + assertEquals(2.0d, getMetric("assigned-partitions").metricValue()); + } + + private KafkaMetric getMetric(final String name) { + return metrics.metrics().get(metrics.metricName(name, consumerId + groupId + "-coordinator-metrics")); + } + + @SuppressWarnings("unchecked") + @Test + public void testPerformAssignmentShouldUpdateGroupSubscriptionAfterAssignmentIfNeeded() { + SubscriptionState mockSubscriptionState = Mockito.mock(SubscriptionState.class); + + // the consumer only subscribed to "topic1" + Map> memberSubscriptions = singletonMap(consumerId, singletonList(topic1)); + + List metadata = new ArrayList<>(); + for (Map.Entry> subscriptionEntry : memberSubscriptions.entrySet()) { + ConsumerPartitionAssignor.Subscription subscription = new ConsumerPartitionAssignor.Subscription(subscriptionEntry.getValue()); + ByteBuffer buf = ConsumerProtocol.serializeSubscription(subscription); + metadata.add(new JoinGroupResponseData.JoinGroupResponseMember() + .setMemberId(subscriptionEntry.getKey()) + .setMetadata(buf.array())); + } + + // normal case: the assignment result will have partitions for only the subscribed topic: "topic1" + partitionAssignor.prepare(Collections.singletonMap(consumerId, singletonList(t1p))); + + try (ConsumerCoordinator coordinator = buildCoordinator(rebalanceConfig, new Metrics(), assignors, false, mockSubscriptionState)) { + coordinator.performAssignment("1", partitionAssignor.name(), metadata); + + ArgumentCaptor> topicsCaptor = ArgumentCaptor.forClass(Collection.class); + // groupSubscribe should be only called 1 time, which is before assignment, + // because the assigned topics are the same as the subscribed topics + Mockito.verify(mockSubscriptionState, Mockito.times(1)).groupSubscribe(topicsCaptor.capture()); + + List> capturedTopics = topicsCaptor.getAllValues(); + + // expected the final group subscribed topics to be updated to "topic1" + Set expectedTopicsGotCalled = new HashSet<>(Arrays.asList(topic1)); + assertEquals(expectedTopicsGotCalled, capturedTopics.get(0)); + } + + Mockito.clearInvocations(mockSubscriptionState); + + // unsubscribed topic partition assigned case: the assignment result will have partitions for (1) subscribed topic: "topic1" + 
// and (2) the additional unsubscribed topic: "topic2". We should add "topic2" into group subscription list + partitionAssignor.prepare(Collections.singletonMap(consumerId, Arrays.asList(t1p, t2p))); + + try (ConsumerCoordinator coordinator = buildCoordinator(rebalanceConfig, new Metrics(), assignors, false, mockSubscriptionState)) { + coordinator.performAssignment("1", partitionAssignor.name(), metadata); + + ArgumentCaptor> topicsCaptor = ArgumentCaptor.forClass(Collection.class); + // groupSubscribe should be called 2 times, once before assignment, once after assignment + // (because the assigned topics are not the same as the subscribed topics) + Mockito.verify(mockSubscriptionState, Mockito.times(2)).groupSubscribe(topicsCaptor.capture()); + + List> capturedTopics = topicsCaptor.getAllValues(); + + // expected the final group subscribed topics to be updated to "topic1" and "topic2" + Set expectedTopicsGotCalled = new HashSet<>(Arrays.asList(topic1, topic2)); + assertEquals(expectedTopicsGotCalled, capturedTopics.get(1)); + } + } + + public ByteBuffer subscriptionUserData(int generation) { + final String generationKeyName = "generation"; + final Schema cooperativeStickyAssignorUserDataV0 = new Schema( + new Field(generationKeyName, Type.INT32)); + Struct struct = new Struct(cooperativeStickyAssignorUserDataV0); + + struct.set(generationKeyName, generation); + ByteBuffer buffer = ByteBuffer.allocate(cooperativeStickyAssignorUserDataV0.sizeOf(struct)); + cooperativeStickyAssignorUserDataV0.write(buffer, struct); + buffer.flip(); + return buffer; + } + + private List validateCooperativeAssignmentTestSetup() { + // consumer1 and consumer2 subscribed to "topic1" with 2 partitions: t1p, t2p + Map> memberSubscriptions = new HashMap<>(); + List subscribedTopics = singletonList(topic1); + memberSubscriptions.put(consumerId, subscribedTopics); + memberSubscriptions.put(consumerId2, subscribedTopics); + + // the ownedPartition for consumer1 is t1p, t2p + ConsumerPartitionAssignor.Subscription subscriptionConsumer1 = new ConsumerPartitionAssignor.Subscription( + subscribedTopics, subscriptionUserData(1), Arrays.asList(t1p, t2p)); + + // the ownedPartition for consumer2 is empty + ConsumerPartitionAssignor.Subscription subscriptionConsumer2 = new ConsumerPartitionAssignor.Subscription( + subscribedTopics, subscriptionUserData(1), emptyList()); + + List metadata = new ArrayList<>(); + for (Map.Entry> subscriptionEntry : memberSubscriptions.entrySet()) { + ByteBuffer buf = null; + if (subscriptionEntry.getKey().equals(consumerId)) { + buf = ConsumerProtocol.serializeSubscription(subscriptionConsumer1); + } else { + buf = ConsumerProtocol.serializeSubscription(subscriptionConsumer2); + } + + metadata.add(new JoinGroupResponseData.JoinGroupResponseMember() + .setMemberId(subscriptionEntry.getKey()) + .setMetadata(buf.array())); + } + + return metadata; + } + + @Test + public void testPerformAssignmentShouldValidateCooperativeAssignment() { + SubscriptionState mockSubscriptionState = Mockito.mock(SubscriptionState.class); + List metadata = validateCooperativeAssignmentTestSetup(); + + // simulate the custom cooperative assignor didn't revoke the partition first before assign to other consumer + Map> assignment = new HashMap<>(); + assignment.put(consumerId, Arrays.asList(t1p)); + assignment.put(consumerId2, Arrays.asList(t2p)); + partitionAssignor.prepare(assignment); + + try (ConsumerCoordinator coordinator = buildCoordinator(rebalanceConfig, new Metrics(), assignors, false, mockSubscriptionState)) { + 
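+            // consumer1's subscription reports t1p and t2p as owned, yet the prepared assignment above hands
+            // t2p directly to consumer2 without a prior revocation, which is exactly the violation the
+            // COOPERATIVE validation below is expected to reject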
if (protocol == COOPERATIVE) { + // in cooperative protocol, we should throw exception when validating cooperative assignment + Exception e = assertThrows(IllegalStateException.class, + () -> coordinator.performAssignment("1", partitionAssignor.name(), metadata)); + assertTrue(e.getMessage().contains("Assignor supporting the COOPERATIVE protocol violates its requirements")); + } else { + // in eager protocol, we should not validate assignment + coordinator.performAssignment("1", partitionAssignor.name(), metadata); + } + } + } + + @Test + public void testPerformAssignmentShouldSkipValidateCooperativeAssignmentForBuiltInCooperativeStickyAssignor() { + SubscriptionState mockSubscriptionState = Mockito.mock(SubscriptionState.class); + List metadata = validateCooperativeAssignmentTestSetup(); + + List assignorsWithCooperativeStickyAssignor = new ArrayList<>(assignors); + // create a mockPartitionAssignor with the same name as cooperative sticky assignor + MockPartitionAssignor mockCooperativeStickyAssignor = new MockPartitionAssignor(Collections.singletonList(protocol)) { + @Override + public String name() { + return COOPERATIVE_STICKY_ASSIGNOR_NAME; + } + }; + assignorsWithCooperativeStickyAssignor.add(mockCooperativeStickyAssignor); + + // simulate the cooperative sticky assignor do the assignment with out-of-date ownedPartition + Map> assignment = new HashMap<>(); + assignment.put(consumerId, Arrays.asList(t1p)); + assignment.put(consumerId2, Arrays.asList(t2p)); + mockCooperativeStickyAssignor.prepare(assignment); + + try (ConsumerCoordinator coordinator = buildCoordinator(rebalanceConfig, new Metrics(), assignorsWithCooperativeStickyAssignor, false, mockSubscriptionState)) { + // should not validate assignment for built-in cooperative sticky assignor + coordinator.performAssignment("1", mockCooperativeStickyAssignor.name(), metadata); + } + } + + @Test + public void testSelectRebalanceProtcol() { + List assignors = new ArrayList<>(); + assignors.add(new MockPartitionAssignor(Collections.singletonList(ConsumerPartitionAssignor.RebalanceProtocol.EAGER))); + assignors.add(new MockPartitionAssignor(Collections.singletonList(COOPERATIVE))); + + // no commonly supported protocols + assertThrows(IllegalArgumentException.class, () -> buildCoordinator(rebalanceConfig, new Metrics(), assignors, false, subscriptions)); + + assignors.clear(); + assignors.add(new MockPartitionAssignor(Arrays.asList(ConsumerPartitionAssignor.RebalanceProtocol.EAGER, COOPERATIVE))); + assignors.add(new MockPartitionAssignor(Arrays.asList(ConsumerPartitionAssignor.RebalanceProtocol.EAGER, COOPERATIVE))); + + // select higher indexed (more advanced) protocols + try (ConsumerCoordinator coordinator = buildCoordinator(rebalanceConfig, new Metrics(), assignors, false, subscriptions)) { + assertEquals(COOPERATIVE, coordinator.getProtocol()); + } + } + + @Test + public void testNormalHeartbeat() { + client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE)); + coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE)); + + // normal heartbeat + time.sleep(sessionTimeoutMs); + RequestFuture future = coordinator.sendHeartbeatRequest(); // should send out the heartbeat + assertEquals(1, consumerClient.pendingRequestCount()); + assertFalse(future.isDone()); + + client.prepareResponse(heartbeatResponse(Errors.NONE)); + consumerClient.poll(time.timer(0)); + + assertTrue(future.isDone()); + assertTrue(future.succeeded()); + } + + @Test + public void testGroupDescribeUnauthorized() { + 
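+        // a GROUP_AUTHORIZATION_FAILED error while looking up the coordinator should surface to the caller
+        // as a GroupAuthorizationException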
client.prepareResponse(groupCoordinatorResponse(node, Errors.GROUP_AUTHORIZATION_FAILED)); + assertThrows(GroupAuthorizationException.class, () -> coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE))); + } + + @Test + public void testGroupReadUnauthorized() { + subscriptions.subscribe(singleton(topic1), rebalanceListener); + + client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE)); + coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE)); + + client.prepareResponse(joinGroupLeaderResponse(0, "memberId", Collections.emptyMap(), + Errors.GROUP_AUTHORIZATION_FAILED)); + assertThrows(GroupAuthorizationException.class, () -> coordinator.poll(time.timer(Long.MAX_VALUE))); + } + + @Test + public void testCoordinatorNotAvailable() { + client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE)); + coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE)); + + // COORDINATOR_NOT_AVAILABLE will mark coordinator as unknown + time.sleep(sessionTimeoutMs); + RequestFuture future = coordinator.sendHeartbeatRequest(); // should send out the heartbeat + assertEquals(1, consumerClient.pendingRequestCount()); + assertFalse(future.isDone()); + + client.prepareResponse(heartbeatResponse(Errors.COORDINATOR_NOT_AVAILABLE)); + time.sleep(sessionTimeoutMs); + consumerClient.poll(time.timer(0)); + + assertTrue(future.isDone()); + assertTrue(future.failed()); + assertEquals(Errors.COORDINATOR_NOT_AVAILABLE.exception(), future.exception()); + assertTrue(coordinator.coordinatorUnknown()); + } + + @Test + public void testManyInFlightAsyncCommitsWithCoordinatorDisconnect() { + client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE)); + coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE)); + + int numRequests = 1000; + TopicPartition tp = new TopicPartition("foo", 0); + final AtomicInteger responses = new AtomicInteger(0); + + for (int i = 0; i < numRequests; i++) { + Map offsets = singletonMap(tp, new OffsetAndMetadata(i)); + coordinator.commitOffsetsAsync(offsets, (offsets1, exception) -> { + responses.incrementAndGet(); + Throwable cause = exception.getCause(); + assertTrue(cause instanceof DisconnectException, + "Unexpected exception cause type: " + (cause == null ? null : cause.getClass())); + }); + } + + coordinator.markCoordinatorUnknown("test cause"); + consumerClient.pollNoWakeup(); + coordinator.invokeCompletedOffsetCommitCallbacks(); + assertEquals(numRequests, responses.get()); + } + + @Test + public void testCoordinatorUnknownInUnsentCallbacksAfterCoordinatorDead() { + // When the coordinator is marked dead, all unsent or in-flight requests are cancelled + // with a disconnect error. This test case ensures that the corresponding callbacks see + // the coordinator as unknown which prevents additional retries to the same coordinator. 
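+        // markCoordinatorUnknown("test cause") below simulates that condition after the offset commit
+        // request has already been handed to the network client.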
+ + client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE)); + coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE)); + + final AtomicBoolean asyncCallbackInvoked = new AtomicBoolean(false); + + OffsetCommitRequestData offsetCommitRequestData = new OffsetCommitRequestData() + .setGroupId(groupId) + .setTopics(Collections.singletonList(new + OffsetCommitRequestData.OffsetCommitRequestTopic() + .setName("foo") + .setPartitions(Collections.singletonList( + new OffsetCommitRequestData.OffsetCommitRequestPartition() + .setPartitionIndex(0) + .setCommittedLeaderEpoch(RecordBatch.NO_PARTITION_LEADER_EPOCH) + .setCommittedMetadata("") + .setCommittedOffset(13L) + .setCommitTimestamp(0) + )) + ) + ); + + consumerClient.send(coordinator.checkAndGetCoordinator(), new OffsetCommitRequest.Builder(offsetCommitRequestData)) + .compose(new RequestFutureAdapter() { + @Override + public void onSuccess(ClientResponse value, RequestFuture future) {} + + @Override + public void onFailure(RuntimeException e, RequestFuture future) { + assertTrue(e instanceof DisconnectException, "Unexpected exception type: " + e.getClass()); + assertTrue(coordinator.coordinatorUnknown()); + asyncCallbackInvoked.set(true); + } + }); + + coordinator.markCoordinatorUnknown("test cause"); + consumerClient.pollNoWakeup(); + assertTrue(asyncCallbackInvoked.get()); + } + + @Test + public void testNotCoordinator() { + client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE)); + coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE)); + + // not_coordinator will mark coordinator as unknown + time.sleep(sessionTimeoutMs); + RequestFuture future = coordinator.sendHeartbeatRequest(); // should send out the heartbeat + assertEquals(1, consumerClient.pendingRequestCount()); + assertFalse(future.isDone()); + + client.prepareResponse(heartbeatResponse(Errors.NOT_COORDINATOR)); + time.sleep(sessionTimeoutMs); + consumerClient.poll(time.timer(0)); + + assertTrue(future.isDone()); + assertTrue(future.failed()); + assertEquals(Errors.NOT_COORDINATOR.exception(), future.exception()); + assertTrue(coordinator.coordinatorUnknown()); + } + + @Test + public void testIllegalGeneration() { + client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE)); + coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE)); + + // illegal_generation will cause re-partition + subscriptions.subscribe(singleton(topic1), rebalanceListener); + subscriptions.assignFromSubscribed(Collections.singletonList(t1p)); + + time.sleep(sessionTimeoutMs); + RequestFuture future = coordinator.sendHeartbeatRequest(); // should send out the heartbeat + assertEquals(1, consumerClient.pendingRequestCount()); + assertFalse(future.isDone()); + + client.prepareResponse(heartbeatResponse(Errors.ILLEGAL_GENERATION)); + time.sleep(sessionTimeoutMs); + consumerClient.poll(time.timer(0)); + + assertTrue(future.isDone()); + assertTrue(future.failed()); + assertEquals(Errors.ILLEGAL_GENERATION.exception(), future.exception()); + assertTrue(coordinator.rejoinNeededOrPending()); + + coordinator.poll(time.timer(0)); + + assertEquals(1, rebalanceListener.lostCount); + assertEquals(Collections.singleton(t1p), rebalanceListener.lost); + } + + @Test + public void testUnsubscribeWithValidGeneration() { + client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE)); + coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE)); + + subscriptions.subscribe(singleton(topic1), rebalanceListener); + ByteBuffer buffer = 
ConsumerProtocol.serializeAssignment( + new ConsumerPartitionAssignor.Assignment(Collections.singletonList(t1p), ByteBuffer.wrap(new byte[0]))); + coordinator.onJoinComplete(1, "memberId", partitionAssignor.name(), buffer); + + coordinator.onLeavePrepare(); + assertEquals(1, rebalanceListener.lostCount); + assertEquals(0, rebalanceListener.revokedCount); + } + + @Test + public void testRevokeExceptionThrownFirstNonBlockingSubCallbacks() { + MockRebalanceListener throwOnRevokeListener = new MockRebalanceListener() { + @Override + public void onPartitionsRevoked(Collection partitions) { + super.onPartitionsRevoked(partitions); + throw new KafkaException("Kaboom on revoke!"); + } + }; + + if (protocol == COOPERATIVE) { + verifyOnCallbackExceptions(throwOnRevokeListener, + throwOnAssignmentAssignor.name(), "Kaboom on revoke!", null); + } else { + // Eager protocol doesn't revoke partitions. + verifyOnCallbackExceptions(throwOnRevokeListener, + throwOnAssignmentAssignor.name(), "Kaboom for assignment!", null); + } + } + + @Test + public void testOnAssignmentExceptionThrownFirstNonBlockingSubCallbacks() { + MockRebalanceListener throwOnAssignListener = new MockRebalanceListener() { + @Override + public void onPartitionsAssigned(Collection partitions) { + super.onPartitionsAssigned(partitions); + throw new KafkaException("Kaboom on partition assign!"); + } + }; + + verifyOnCallbackExceptions(throwOnAssignListener, + throwOnAssignmentAssignor.name(), "Kaboom for assignment!", null); + } + + @Test + public void testOnPartitionsAssignExceptionThrownWhenNoPreviousThrownCallbacks() { + MockRebalanceListener throwOnAssignListener = new MockRebalanceListener() { + @Override + public void onPartitionsAssigned(Collection partitions) { + super.onPartitionsAssigned(partitions); + throw new KafkaException("Kaboom on partition assign!"); + } + }; + + verifyOnCallbackExceptions(throwOnAssignListener, + partitionAssignor.name(), "Kaboom on partition assign!", null); + } + + @Test + public void testOnRevokeExceptionShouldBeRenderedIfNotKafkaException() { + MockRebalanceListener throwOnRevokeListener = new MockRebalanceListener() { + @Override + public void onPartitionsRevoked(Collection partitions) { + super.onPartitionsRevoked(partitions); + throw new IllegalStateException("Illegal state on partition revoke!"); + } + }; + + if (protocol == COOPERATIVE) { + verifyOnCallbackExceptions(throwOnRevokeListener, + throwOnAssignmentAssignor.name(), + "User rebalance callback throws an error", "Illegal state on partition revoke!"); + } else { + // Eager protocol doesn't revoke partitions. 
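+            // so the revocation callback never runs here and the ThrowOnAssignmentAssignor's own failure
+            // ("Kaboom for assignment!") is what surfaces instead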
+ verifyOnCallbackExceptions(throwOnRevokeListener, + throwOnAssignmentAssignor.name(), "Kaboom for assignment!", null); + } + } + + @Test + public void testOnAssignmentExceptionShouldBeRenderedIfNotKafkaException() { + MockRebalanceListener throwOnAssignListener = new MockRebalanceListener() { + @Override + public void onPartitionsAssigned(Collection partitions) { + super.onPartitionsAssigned(partitions); + throw new KafkaException("Kaboom on partition assign!"); + } + }; + verifyOnCallbackExceptions(throwOnAssignListener, + throwFatalErrorOnAssignmentAssignor.name(), + "User rebalance callback throws an error", "Illegal state for assignment!"); + } + + @Test + public void testOnPartitionsAssignExceptionShouldBeRenderedIfNotKafkaException() { + MockRebalanceListener throwOnAssignListener = new MockRebalanceListener() { + @Override + public void onPartitionsAssigned(Collection partitions) { + super.onPartitionsAssigned(partitions); + throw new IllegalStateException("Illegal state on partition assign!"); + } + }; + + verifyOnCallbackExceptions(throwOnAssignListener, + partitionAssignor.name(), "User rebalance callback throws an error", + "Illegal state on partition assign!"); + } + + private void verifyOnCallbackExceptions(final MockRebalanceListener rebalanceListener, + final String assignorName, + final String exceptionMessage, + final String causeMessage) { + client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE)); + coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE)); + + subscriptions.subscribe(singleton(topic1), rebalanceListener); + ByteBuffer buffer = ConsumerProtocol.serializeAssignment( + new ConsumerPartitionAssignor.Assignment(Collections.singletonList(t1p), ByteBuffer.wrap(new byte[0]))); + subscriptions.assignFromSubscribed(singleton(t2p)); + + if (exceptionMessage != null) { + final Exception exception = assertThrows(KafkaException.class, + () -> coordinator.onJoinComplete(1, "memberId", assignorName, buffer)); + assertEquals(exceptionMessage, exception.getMessage()); + if (causeMessage != null) { + assertEquals(causeMessage, exception.getCause().getMessage()); + } + } + + // Eager doesn't trigger on partition revoke. + assertEquals(protocol == COOPERATIVE ? 
1 : 0, rebalanceListener.revokedCount); + assertEquals(0, rebalanceListener.lostCount); + assertEquals(1, rebalanceListener.assignedCount); + assertTrue(assignorMap.containsKey(assignorName), "Unknown assignor name: " + assignorName); + assertEquals(1, assignorMap.get(assignorName).numAssignment()); + } + + @Test + public void testUnsubscribeWithInvalidGeneration() { + client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE)); + coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE)); + + subscriptions.subscribe(singleton(topic1), rebalanceListener); + subscriptions.assignFromSubscribed(Collections.singletonList(t1p)); + + coordinator.onLeavePrepare(); + assertEquals(1, rebalanceListener.lostCount); + assertEquals(0, rebalanceListener.revokedCount); + } + + @Test + public void testUnknownMemberId() { + client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE)); + coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE)); + + // illegal_generation will cause re-partition + subscriptions.subscribe(singleton(topic1), rebalanceListener); + subscriptions.assignFromSubscribed(Collections.singletonList(t1p)); + + time.sleep(sessionTimeoutMs); + RequestFuture future = coordinator.sendHeartbeatRequest(); // should send out the heartbeat + assertEquals(1, consumerClient.pendingRequestCount()); + assertFalse(future.isDone()); + + client.prepareResponse(heartbeatResponse(Errors.UNKNOWN_MEMBER_ID)); + time.sleep(sessionTimeoutMs); + consumerClient.poll(time.timer(0)); + + assertTrue(future.isDone()); + assertTrue(future.failed()); + assertEquals(Errors.UNKNOWN_MEMBER_ID.exception(), future.exception()); + assertTrue(coordinator.rejoinNeededOrPending()); + + coordinator.poll(time.timer(0)); + + assertEquals(1, rebalanceListener.lostCount); + assertEquals(Collections.singleton(t1p), rebalanceListener.lost); + } + + @Test + public void testCoordinatorDisconnect() { + client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE)); + coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE)); + + // coordinator disconnect will mark coordinator as unknown + time.sleep(sessionTimeoutMs); + RequestFuture future = coordinator.sendHeartbeatRequest(); // should send out the heartbeat + assertEquals(1, consumerClient.pendingRequestCount()); + assertFalse(future.isDone()); + + client.prepareResponse(heartbeatResponse(Errors.NONE), true); // return disconnected + time.sleep(sessionTimeoutMs); + consumerClient.poll(time.timer(0)); + + assertTrue(future.isDone()); + assertTrue(future.failed()); + assertTrue(future.exception() instanceof DisconnectException); + assertTrue(coordinator.coordinatorUnknown()); + } + + @Test + public void testJoinGroupInvalidGroupId() { + final String consumerId = "leader"; + + subscriptions.subscribe(singleton(topic1), rebalanceListener); + + // ensure metadata is up-to-date for leader + client.updateMetadata(metadataResponse); + + client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE)); + coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE)); + + client.prepareResponse(joinGroupLeaderResponse(0, consumerId, Collections.emptyMap(), + Errors.INVALID_GROUP_ID)); + assertThrows(ApiException.class, () -> coordinator.poll(time.timer(Long.MAX_VALUE))); + } + + @Test + public void testNormalJoinGroupLeader() { + final String consumerId = "leader"; + final Set subscription = singleton(topic1); + final List owned = Collections.emptyList(); + final List assigned = Arrays.asList(t1p); + + subscriptions.subscribe(singleton(topic1), 
rebalanceListener); + + // ensure metadata is up-to-date for leader + client.updateMetadata(metadataResponse); + + client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE)); + coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE)); + + // normal join group + Map> memberSubscriptions = singletonMap(consumerId, singletonList(topic1)); + partitionAssignor.prepare(singletonMap(consumerId, assigned)); + + client.prepareResponse(joinGroupLeaderResponse(1, consumerId, memberSubscriptions, Errors.NONE)); + client.prepareResponse(body -> { + SyncGroupRequest sync = (SyncGroupRequest) body; + return sync.data().memberId().equals(consumerId) && + sync.data().generationId() == 1 && + sync.groupAssignments().containsKey(consumerId); + }, syncGroupResponse(assigned, Errors.NONE)); + coordinator.poll(time.timer(Long.MAX_VALUE)); + + assertFalse(coordinator.rejoinNeededOrPending()); + assertEquals(toSet(assigned), subscriptions.assignedPartitions()); + assertEquals(subscription, subscriptions.metadataTopics()); + assertEquals(0, rebalanceListener.revokedCount); + assertNull(rebalanceListener.revoked); + assertEquals(1, rebalanceListener.assignedCount); + assertEquals(getAdded(owned, assigned), rebalanceListener.assigned); + } + + @Test + public void testOutdatedCoordinatorAssignment() { + final String consumerId = "outdated_assignment"; + final List owned = Collections.emptyList(); + final List oldSubscription = singletonList(topic2); + final List oldAssignment = Arrays.asList(t2p); + final List newSubscription = singletonList(topic1); + final List newAssignment = Arrays.asList(t1p); + + subscriptions.subscribe(toSet(oldSubscription), rebalanceListener); + + client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE)); + coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE)); + + // Test coordinator returning unsubscribed partitions + partitionAssignor.prepare(singletonMap(consumerId, newAssignment)); + + // First incorrect assignment for subscription + client.prepareResponse( + joinGroupLeaderResponse( + 1, consumerId, singletonMap(consumerId, oldSubscription), Errors.NONE)); + client.prepareResponse(body -> { + SyncGroupRequest sync = (SyncGroupRequest) body; + return sync.data().memberId().equals(consumerId) && + sync.data().generationId() == 1 && + sync.groupAssignments().containsKey(consumerId); + }, syncGroupResponse(oldAssignment, Errors.NONE)); + + // Second correct assignment for subscription + client.prepareResponse( + joinGroupLeaderResponse( + 1, consumerId, singletonMap(consumerId, newSubscription), Errors.NONE)); + client.prepareResponse(body -> { + SyncGroupRequest sync = (SyncGroupRequest) body; + return sync.data().memberId().equals(consumerId) && + sync.data().generationId() == 1 && + sync.groupAssignments().containsKey(consumerId); + }, syncGroupResponse(newAssignment, Errors.NONE)); + + // Poll once so that the join group future gets created and complete + coordinator.poll(time.timer(0)); + + // Before the sync group response gets completed change the subscription + subscriptions.subscribe(toSet(newSubscription), rebalanceListener); + coordinator.poll(time.timer(0)); + + coordinator.poll(time.timer(Long.MAX_VALUE)); + + final Collection assigned = getAdded(owned, newAssignment); + + assertFalse(coordinator.rejoinNeededOrPending()); + assertEquals(toSet(newAssignment), subscriptions.assignedPartitions()); + assertEquals(toSet(newSubscription), subscriptions.metadataTopics()); + assertEquals(protocol == EAGER ? 
1 : 0, rebalanceListener.revokedCount); + assertEquals(1, rebalanceListener.assignedCount); + assertEquals(assigned, rebalanceListener.assigned); + } + + @Test + public void testMetadataTopicsDuringSubscriptionChange() { + final String consumerId = "subscription_change"; + final List oldSubscription = singletonList(topic1); + final List oldAssignment = Collections.singletonList(t1p); + final List newSubscription = singletonList(topic2); + final List newAssignment = Collections.singletonList(t2p); + + subscriptions.subscribe(toSet(oldSubscription), rebalanceListener); + assertEquals(toSet(oldSubscription), subscriptions.metadataTopics()); + + client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE)); + coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE)); + + prepareJoinAndSyncResponse(consumerId, 1, oldSubscription, oldAssignment); + + coordinator.poll(time.timer(0)); + assertEquals(toSet(oldSubscription), subscriptions.metadataTopics()); + + subscriptions.subscribe(toSet(newSubscription), rebalanceListener); + assertEquals(Utils.mkSet(topic1, topic2), subscriptions.metadataTopics()); + + prepareJoinAndSyncResponse(consumerId, 2, newSubscription, newAssignment); + coordinator.poll(time.timer(Long.MAX_VALUE)); + assertFalse(coordinator.rejoinNeededOrPending()); + assertEquals(toSet(newAssignment), subscriptions.assignedPartitions()); + assertEquals(toSet(newSubscription), subscriptions.metadataTopics()); + } + + @Test + public void testPatternJoinGroupLeader() { + final String consumerId = "leader"; + final List assigned = Arrays.asList(t1p, t2p); + final List owned = Collections.emptyList(); + + subscriptions.subscribe(Pattern.compile("test.*"), rebalanceListener); + + // partially update the metadata with one topic first, + // let the leader to refresh metadata during assignment + client.updateMetadata(RequestTestUtils.metadataUpdateWith(1, singletonMap(topic1, 1))); + + client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE)); + coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE)); + + // normal join group + Map> memberSubscriptions = singletonMap(consumerId, singletonList(topic1)); + partitionAssignor.prepare(singletonMap(consumerId, assigned)); + + client.prepareResponse(joinGroupLeaderResponse(1, consumerId, memberSubscriptions, Errors.NONE)); + client.prepareResponse(body -> { + SyncGroupRequest sync = (SyncGroupRequest) body; + return sync.data().memberId().equals(consumerId) && + sync.data().generationId() == 1 && + sync.groupAssignments().containsKey(consumerId); + }, syncGroupResponse(assigned, Errors.NONE)); + // expect client to force updating the metadata, if yes gives it both topics + client.prepareMetadataUpdate(metadataResponse); + + coordinator.poll(time.timer(Long.MAX_VALUE)); + + assertFalse(coordinator.rejoinNeededOrPending()); + assertEquals(2, subscriptions.numAssignedPartitions()); + assertEquals(2, subscriptions.metadataTopics().size()); + assertEquals(2, subscriptions.subscription().size()); + // callback not triggered at all since there's nothing to be revoked + assertEquals(0, rebalanceListener.revokedCount); + assertNull(rebalanceListener.revoked); + assertEquals(1, rebalanceListener.assignedCount); + assertEquals(getAdded(owned, assigned), rebalanceListener.assigned); + } + + @Test + public void testMetadataRefreshDuringRebalance() { + final String consumerId = "leader"; + final List owned = Collections.emptyList(); + final List oldAssigned = singletonList(t1p); + subscriptions.subscribe(Pattern.compile(".*"), 
rebalanceListener); + client.updateMetadata(RequestTestUtils.metadataUpdateWith(1, singletonMap(topic1, 1))); + coordinator.maybeUpdateSubscriptionMetadata(); + + assertEquals(singleton(topic1), subscriptions.subscription()); + + client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE)); + coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE)); + + Map> initialSubscription = singletonMap(consumerId, singletonList(topic1)); + partitionAssignor.prepare(singletonMap(consumerId, oldAssigned)); + + // the metadata will be updated in flight with a new topic added + final List updatedSubscription = Arrays.asList(topic1, topic2); + + client.prepareResponse(joinGroupLeaderResponse(1, consumerId, initialSubscription, Errors.NONE)); + client.prepareResponse(body -> { + final Map updatedPartitions = new HashMap<>(); + for (String topic : updatedSubscription) + updatedPartitions.put(topic, 1); + client.updateMetadata(RequestTestUtils.metadataUpdateWith(1, updatedPartitions)); + return true; + }, syncGroupResponse(oldAssigned, Errors.NONE)); + coordinator.poll(time.timer(Long.MAX_VALUE)); + + // rejoin will only be set in the next poll call + assertFalse(coordinator.rejoinNeededOrPending()); + assertEquals(singleton(topic1), subscriptions.subscription()); + assertEquals(toSet(oldAssigned), subscriptions.assignedPartitions()); + // nothing to be revoked and hence no callback triggered + assertEquals(0, rebalanceListener.revokedCount); + assertNull(rebalanceListener.revoked); + assertEquals(1, rebalanceListener.assignedCount); + assertEquals(getAdded(owned, oldAssigned), rebalanceListener.assigned); + + List newAssigned = Arrays.asList(t1p, t2p); + + final Map> updatedSubscriptions = singletonMap(consumerId, Arrays.asList(topic1, topic2)); + partitionAssignor.prepare(singletonMap(consumerId, newAssigned)); + + // we expect to see a second rebalance with the new-found topics + client.prepareResponse(body -> { + JoinGroupRequest join = (JoinGroupRequest) body; + Iterator protocolIterator = + join.data().protocols().iterator(); + assertTrue(protocolIterator.hasNext()); + JoinGroupRequestData.JoinGroupRequestProtocol protocolMetadata = protocolIterator.next(); + + ByteBuffer metadata = ByteBuffer.wrap(protocolMetadata.metadata()); + ConsumerPartitionAssignor.Subscription subscription = ConsumerProtocol.deserializeSubscription(metadata); + metadata.rewind(); + return subscription.topics().containsAll(updatedSubscription); + }, joinGroupLeaderResponse(2, consumerId, updatedSubscriptions, Errors.NONE)); + // update the metadata again back to topic1 + client.prepareResponse(body -> { + client.updateMetadata(RequestTestUtils.metadataUpdateWith(1, singletonMap(topic1, 1))); + return true; + }, syncGroupResponse(newAssigned, Errors.NONE)); + + coordinator.poll(time.timer(Long.MAX_VALUE)); + + Collection revoked = getRevoked(oldAssigned, newAssigned); + int revokedCount = revoked.isEmpty() ? 0 : 1; + + assertFalse(coordinator.rejoinNeededOrPending()); + assertEquals(toSet(updatedSubscription), subscriptions.subscription()); + assertEquals(toSet(newAssigned), subscriptions.assignedPartitions()); + assertEquals(revokedCount, rebalanceListener.revokedCount); + assertEquals(revoked.isEmpty() ? 
null : revoked, rebalanceListener.revoked); + assertEquals(2, rebalanceListener.assignedCount); + assertEquals(getAdded(oldAssigned, newAssigned), rebalanceListener.assigned); + + // we expect to see a third rebalance with the new-found topics + partitionAssignor.prepare(singletonMap(consumerId, oldAssigned)); + + client.prepareResponse(body -> { + JoinGroupRequest join = (JoinGroupRequest) body; + Iterator protocolIterator = + join.data().protocols().iterator(); + assertTrue(protocolIterator.hasNext()); + JoinGroupRequestData.JoinGroupRequestProtocol protocolMetadata = protocolIterator.next(); + + ByteBuffer metadata = ByteBuffer.wrap(protocolMetadata.metadata()); + ConsumerPartitionAssignor.Subscription subscription = ConsumerProtocol.deserializeSubscription(metadata); + metadata.rewind(); + return subscription.topics().contains(topic1); + }, joinGroupLeaderResponse(3, consumerId, initialSubscription, Errors.NONE)); + client.prepareResponse(syncGroupResponse(oldAssigned, Errors.NONE)); + + coordinator.poll(time.timer(Long.MAX_VALUE)); + + revoked = getRevoked(newAssigned, oldAssigned); + assertFalse(revoked.isEmpty()); + revokedCount += 1; + Collection added = getAdded(newAssigned, oldAssigned); + + assertFalse(coordinator.rejoinNeededOrPending()); + assertEquals(singleton(topic1), subscriptions.subscription()); + assertEquals(toSet(oldAssigned), subscriptions.assignedPartitions()); + assertEquals(revokedCount, rebalanceListener.revokedCount); + assertEquals(revoked.isEmpty() ? null : revoked, rebalanceListener.revoked); + assertEquals(3, rebalanceListener.assignedCount); + assertEquals(added, rebalanceListener.assigned); + assertEquals(0, rebalanceListener.lostCount); + } + + @Test + public void testForceMetadataRefreshForPatternSubscriptionDuringRebalance() { + // Set up a non-leader consumer with pattern subscription and a cluster containing one topic matching the + // pattern. + subscriptions.subscribe(Pattern.compile(".*"), rebalanceListener); + client.updateMetadata(RequestTestUtils.metadataUpdateWith(1, singletonMap(topic1, 1))); + coordinator.maybeUpdateSubscriptionMetadata(); + assertEquals(singleton(topic1), subscriptions.subscription()); + + client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE)); + coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE)); + + // Instrument the test so that metadata will contain two topics after next refresh. + client.prepareMetadataUpdate(metadataResponse); + + client.prepareResponse(joinGroupFollowerResponse(1, consumerId, "leader", Errors.NONE)); + client.prepareResponse(body -> { + SyncGroupRequest sync = (SyncGroupRequest) body; + return sync.data().memberId().equals(consumerId) && + sync.data().generationId() == 1 && + sync.groupAssignments().isEmpty(); + }, syncGroupResponse(singletonList(t1p), Errors.NONE)); + + partitionAssignor.prepare(singletonMap(consumerId, singletonList(t1p))); + + // This will trigger rebalance. + coordinator.poll(time.timer(Long.MAX_VALUE)); + + // Make sure that the metadata was refreshed during the rebalance and thus subscriptions now contain two topics. + final Set updatedSubscriptionSet = new HashSet<>(Arrays.asList(topic1, topic2)); + assertEquals(updatedSubscriptionSet, subscriptions.subscription()); + + // Refresh the metadata again. Since there have been no changes since the last refresh, it won't trigger + // rebalance again. 
+ metadata.requestUpdate(); + consumerClient.poll(time.timer(Long.MAX_VALUE)); + assertFalse(coordinator.rejoinNeededOrPending()); + } + + /** + * Verifies that the consumer re-joins after a metadata change. If JoinGroup fails + * and metadata reverts to its original value, the consumer should still retry JoinGroup. + */ + @Test + public void testRebalanceWithMetadataChange() { + final String consumerId = "leader"; + final List topics = Arrays.asList(topic1, topic2); + final List partitions = Arrays.asList(t1p, t2p); + subscriptions.subscribe(toSet(topics), rebalanceListener); + client.updateMetadata(RequestTestUtils.metadataUpdateWith(1, + Utils.mkMap(Utils.mkEntry(topic1, 1), Utils.mkEntry(topic2, 1)))); + coordinator.maybeUpdateSubscriptionMetadata(); + + client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE)); + coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE)); + + Map> initialSubscription = singletonMap(consumerId, topics); + partitionAssignor.prepare(singletonMap(consumerId, partitions)); + + client.prepareResponse(joinGroupLeaderResponse(1, consumerId, initialSubscription, Errors.NONE)); + client.prepareResponse(syncGroupResponse(partitions, Errors.NONE)); + coordinator.poll(time.timer(Long.MAX_VALUE)); + + // rejoin will only be set in the next poll call + assertFalse(coordinator.rejoinNeededOrPending()); + assertEquals(toSet(topics), subscriptions.subscription()); + assertEquals(toSet(partitions), subscriptions.assignedPartitions()); + assertEquals(0, rebalanceListener.revokedCount); + assertNull(rebalanceListener.revoked); + assertEquals(1, rebalanceListener.assignedCount); + + // Change metadata to trigger rebalance. + client.updateMetadata(RequestTestUtils.metadataUpdateWith(1, singletonMap(topic1, 1))); + coordinator.poll(time.timer(0)); + + // Revert metadata to original value. Fail pending JoinGroup. Another + // JoinGroup should be sent, which will be completed successfully. + client.updateMetadata(RequestTestUtils.metadataUpdateWith(1, + Utils.mkMap(Utils.mkEntry(topic1, 1), Utils.mkEntry(topic2, 1)))); + client.respond(joinGroupFollowerResponse(1, consumerId, "leader", Errors.NOT_COORDINATOR)); + client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE)); + coordinator.poll(time.timer(0)); + assertTrue(coordinator.rejoinNeededOrPending()); + + client.respond(request -> { + if (!(request instanceof JoinGroupRequest)) { + return false; + } else { + JoinGroupRequest joinRequest = (JoinGroupRequest) request; + return consumerId.equals(joinRequest.data().memberId()); + } + }, joinGroupLeaderResponse(2, consumerId, initialSubscription, Errors.NONE)); + client.prepareResponse(syncGroupResponse(partitions, Errors.NONE)); + coordinator.poll(time.timer(Long.MAX_VALUE)); + + assertFalse(coordinator.rejoinNeededOrPending()); + Collection revoked = getRevoked(partitions, partitions); + assertEquals(revoked.isEmpty() ? 0 : 1, rebalanceListener.revokedCount); + assertEquals(revoked.isEmpty() ? null : revoked, rebalanceListener.revoked); + // No partitions have been lost since the rebalance failure was not fatal + assertEquals(0, rebalanceListener.lostCount); + assertNull(rebalanceListener.lost); + + Collection added = getAdded(partitions, partitions); + assertEquals(2, rebalanceListener.assignedCount); + assertEquals(added.isEmpty() ? 
Collections.emptySet() : toSet(partitions), rebalanceListener.assigned); + assertEquals(toSet(partitions), subscriptions.assignedPartitions()); + } + + @Test + public void testWakeupDuringJoin() { + final String consumerId = "leader"; + final List owned = Collections.emptyList(); + final List assigned = singletonList(t1p); + + subscriptions.subscribe(singleton(topic1), rebalanceListener); + + // ensure metadata is up-to-date for leader + client.updateMetadata(metadataResponse); + + client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE)); + coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE)); + + Map> memberSubscriptions = singletonMap(consumerId, singletonList(topic1)); + partitionAssignor.prepare(singletonMap(consumerId, assigned)); + + // prepare only the first half of the join and then trigger the wakeup + client.prepareResponse(joinGroupLeaderResponse(1, consumerId, memberSubscriptions, Errors.NONE)); + consumerClient.wakeup(); + + try { + coordinator.poll(time.timer(Long.MAX_VALUE)); + } catch (WakeupException e) { + // ignore + } + + // now complete the second half + client.prepareResponse(syncGroupResponse(assigned, Errors.NONE)); + coordinator.poll(time.timer(Long.MAX_VALUE)); + + assertFalse(coordinator.rejoinNeededOrPending()); + assertEquals(toSet(assigned), subscriptions.assignedPartitions()); + assertEquals(0, rebalanceListener.revokedCount); + assertNull(rebalanceListener.revoked); + assertEquals(1, rebalanceListener.assignedCount); + assertEquals(getAdded(owned, assigned), rebalanceListener.assigned); + } + + @Test + public void testNormalJoinGroupFollower() { + final Set subscription = singleton(topic1); + final List owned = Collections.emptyList(); + final List assigned = singletonList(t1p); + + subscriptions.subscribe(subscription, rebalanceListener); + + client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE)); + coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE)); + + // normal join group + client.prepareResponse(joinGroupFollowerResponse(1, consumerId, "leader", Errors.NONE)); + client.prepareResponse(body -> { + SyncGroupRequest sync = (SyncGroupRequest) body; + return sync.data().memberId().equals(consumerId) && + sync.data().generationId() == 1 && + sync.groupAssignments().isEmpty(); + }, syncGroupResponse(assigned, Errors.NONE)); + + coordinator.joinGroupIfNeeded(time.timer(Long.MAX_VALUE)); + + assertFalse(coordinator.rejoinNeededOrPending()); + assertEquals(toSet(assigned), subscriptions.assignedPartitions()); + assertEquals(subscription, subscriptions.metadataTopics()); + assertEquals(0, rebalanceListener.revokedCount); + assertNull(rebalanceListener.revoked); + assertEquals(1, rebalanceListener.assignedCount); + assertEquals(getAdded(owned, assigned), rebalanceListener.assigned); + } + + @Test + public void testUpdateLastHeartbeatPollWhenCoordinatorUnknown() throws Exception { + // If we are part of an active group and we cannot find the coordinator, we should nevertheless + // continue to update the last poll time so that we do not expire the consumer + subscriptions.subscribe(singleton(topic1), rebalanceListener); + + client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE)); + coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE)); + + // Join the group, but signal a coordinator change after the first heartbeat + client.prepareResponse(joinGroupFollowerResponse(1, consumerId, "leader", Errors.NONE)); + client.prepareResponse(syncGroupResponse(singletonList(t1p), Errors.NONE)); + 
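+        // the NOT_COORDINATOR error on the heartbeat marks the coordinator unknown and forces rediscovery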
client.prepareResponse(heartbeatResponse(Errors.NOT_COORDINATOR)); + + coordinator.poll(time.timer(Long.MAX_VALUE)); + time.sleep(heartbeatIntervalMs); + + // Await the first heartbeat which forces us to find a new coordinator + TestUtils.waitForCondition(() -> !client.hasPendingResponses(), + "Failed to observe expected heartbeat from background thread"); + + assertTrue(coordinator.coordinatorUnknown()); + assertFalse(coordinator.poll(time.timer(0))); + assertEquals(time.milliseconds(), coordinator.heartbeat().lastPollTime()); + + time.sleep(rebalanceTimeoutMs - 1); + assertFalse(coordinator.heartbeat().pollTimeoutExpired(time.milliseconds())); + } + + @Test + public void testPatternJoinGroupFollower() { + final Set subscription = Utils.mkSet(topic1, topic2); + final List owned = Collections.emptyList(); + final List assigned = Arrays.asList(t1p, t2p); + + subscriptions.subscribe(Pattern.compile("test.*"), rebalanceListener); + + // partially update the metadata with one topic first, + // let the leader to refresh metadata during assignment + client.updateMetadata(RequestTestUtils.metadataUpdateWith(1, singletonMap(topic1, 1))); + + client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE)); + coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE)); + + // normal join group + client.prepareResponse(joinGroupFollowerResponse(1, consumerId, "leader", Errors.NONE)); + client.prepareResponse(body -> { + SyncGroupRequest sync = (SyncGroupRequest) body; + return sync.data().memberId().equals(consumerId) && + sync.data().generationId() == 1 && + sync.groupAssignments().isEmpty(); + }, syncGroupResponse(assigned, Errors.NONE)); + // expect client to force updating the metadata, if yes gives it both topics + client.prepareMetadataUpdate(metadataResponse); + + coordinator.joinGroupIfNeeded(time.timer(Long.MAX_VALUE)); + + assertFalse(coordinator.rejoinNeededOrPending()); + assertEquals(assigned.size(), subscriptions.numAssignedPartitions()); + assertEquals(subscription, subscriptions.subscription()); + assertEquals(0, rebalanceListener.revokedCount); + assertNull(rebalanceListener.revoked); + assertEquals(1, rebalanceListener.assignedCount); + assertEquals(getAdded(owned, assigned), rebalanceListener.assigned); + } + + @Test + public void testLeaveGroupOnClose() { + + subscriptions.subscribe(singleton(topic1), rebalanceListener); + joinAsFollowerAndReceiveAssignment(coordinator, singletonList(t1p)); + + final AtomicBoolean received = new AtomicBoolean(false); + client.prepareResponse(body -> { + received.set(true); + LeaveGroupRequest leaveRequest = (LeaveGroupRequest) body; + return validateLeaveGroup(groupId, consumerId, leaveRequest); + }, new LeaveGroupResponse( + new LeaveGroupResponseData().setErrorCode(Errors.NONE.code()))); + coordinator.close(time.timer(0)); + assertTrue(received.get()); + } + + @Test + public void testMaybeLeaveGroup() { + subscriptions.subscribe(singleton(topic1), rebalanceListener); + joinAsFollowerAndReceiveAssignment(coordinator, singletonList(t1p)); + + final AtomicBoolean received = new AtomicBoolean(false); + client.prepareResponse(body -> { + received.set(true); + LeaveGroupRequest leaveRequest = (LeaveGroupRequest) body; + return validateLeaveGroup(groupId, consumerId, leaveRequest); + }, new LeaveGroupResponse(new LeaveGroupResponseData().setErrorCode(Errors.NONE.code()))); + coordinator.maybeLeaveGroup("test maybe leave group"); + assertTrue(received.get()); + + AbstractCoordinator.Generation generation = coordinator.generationIfStable(); + 
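+        // leaving the group resets the member's generation, so no stable generation should remain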
assertNull(generation); + } + + private boolean validateLeaveGroup(String groupId, + String consumerId, + LeaveGroupRequest leaveRequest) { + List members = leaveRequest.data().members(); + return leaveRequest.data().groupId().equals(groupId) && + members.size() == 1 && + members.get(0).memberId().equals(consumerId); + } + + /** + * This test checks if a consumer that has a valid member ID but an invalid generation + * ({@link org.apache.kafka.clients.consumer.internals.AbstractCoordinator.Generation#NO_GENERATION}) + * can still execute a leave group request. Such a situation may arise when a consumer has initiated a JoinGroup + * request without a memberId, but is shutdown or restarted before it has a chance to initiate and complete the + * second request. + */ + @Test + public void testPendingMemberShouldLeaveGroup() { + final String consumerId = "consumer-id"; + subscriptions.subscribe(singleton(topic1), rebalanceListener); + + client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE)); + coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE)); + + // here we return a DEFAULT_GENERATION_ID, but valid member id and leader id. + client.prepareResponse(joinGroupFollowerResponse(-1, consumerId, "leader-id", Errors.MEMBER_ID_REQUIRED)); + + // execute join group + coordinator.joinGroupIfNeeded(time.timer(0)); + + final AtomicBoolean received = new AtomicBoolean(false); + client.prepareResponse(body -> { + received.set(true); + LeaveGroupRequest leaveRequest = (LeaveGroupRequest) body; + return validateLeaveGroup(groupId, consumerId, leaveRequest); + }, new LeaveGroupResponse(new LeaveGroupResponseData().setErrorCode(Errors.NONE.code()))); + + coordinator.maybeLeaveGroup("pending member leaves"); + assertTrue(received.get()); + } + + @Test + public void testUnexpectedErrorOnSyncGroup() { + subscriptions.subscribe(singleton(topic1), rebalanceListener); + + client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE)); + coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE)); + + // join initially, but let coordinator rebalance on sync + client.prepareResponse(joinGroupFollowerResponse(1, consumerId, "leader", Errors.NONE)); + client.prepareResponse(syncGroupResponse(Collections.emptyList(), Errors.UNKNOWN_SERVER_ERROR)); + assertThrows(KafkaException.class, () -> coordinator.joinGroupIfNeeded(time.timer(Long.MAX_VALUE))); + } + + @Test + public void testUnknownMemberIdOnSyncGroup() { + subscriptions.subscribe(singleton(topic1), rebalanceListener); + + client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE)); + coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE)); + + // join initially, but let coordinator returns unknown member id + client.prepareResponse(joinGroupFollowerResponse(1, consumerId, "leader", Errors.NONE)); + client.prepareResponse(syncGroupResponse(Collections.emptyList(), Errors.UNKNOWN_MEMBER_ID)); + + // now we should see a new join with the empty UNKNOWN_MEMBER_ID + client.prepareResponse(body -> { + JoinGroupRequest joinRequest = (JoinGroupRequest) body; + return joinRequest.data().memberId().equals(JoinGroupRequest.UNKNOWN_MEMBER_ID); + }, joinGroupFollowerResponse(2, consumerId, "leader", Errors.NONE)); + client.prepareResponse(syncGroupResponse(singletonList(t1p), Errors.NONE)); + + coordinator.joinGroupIfNeeded(time.timer(Long.MAX_VALUE)); + + assertFalse(coordinator.rejoinNeededOrPending()); + assertEquals(singleton(t1p), subscriptions.assignedPartitions()); + } + + @Test + public void 
testRebalanceInProgressOnSyncGroup() { + subscriptions.subscribe(singleton(topic1), rebalanceListener); + + client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE)); + coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE)); + + // join initially, but let coordinator rebalance on sync + client.prepareResponse(joinGroupFollowerResponse(1, consumerId, "leader", Errors.NONE)); + client.prepareResponse(syncGroupResponse(Collections.emptyList(), Errors.REBALANCE_IN_PROGRESS)); + + // then let the full join/sync finish successfully + client.prepareResponse(joinGroupFollowerResponse(2, consumerId, "leader", Errors.NONE)); + client.prepareResponse(syncGroupResponse(singletonList(t1p), Errors.NONE)); + + coordinator.joinGroupIfNeeded(time.timer(Long.MAX_VALUE)); + + assertFalse(coordinator.rejoinNeededOrPending()); + assertEquals(singleton(t1p), subscriptions.assignedPartitions()); + } + + @Test + public void testIllegalGenerationOnSyncGroup() { + subscriptions.subscribe(singleton(topic1), rebalanceListener); + + client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE)); + coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE)); + + // join initially, but let coordinator rebalance on sync + client.prepareResponse(joinGroupFollowerResponse(1, consumerId, "leader", Errors.NONE)); + client.prepareResponse(syncGroupResponse(Collections.emptyList(), Errors.ILLEGAL_GENERATION)); + + // then let the full join/sync finish successfully + client.prepareResponse(body -> { + JoinGroupRequest joinRequest = (JoinGroupRequest) body; + return joinRequest.data().memberId().equals(JoinGroupRequest.UNKNOWN_MEMBER_ID); + }, joinGroupFollowerResponse(2, consumerId, "leader", Errors.NONE)); + client.prepareResponse(syncGroupResponse(singletonList(t1p), Errors.NONE)); + + coordinator.joinGroupIfNeeded(time.timer(Long.MAX_VALUE)); + + assertFalse(coordinator.rejoinNeededOrPending()); + assertEquals(singleton(t1p), subscriptions.assignedPartitions()); + } + + @Test + public void testMetadataChangeTriggersRebalance() { + + // ensure metadata is up-to-date for leader + subscriptions.subscribe(singleton(topic1), rebalanceListener); + client.updateMetadata(metadataResponse); + + client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE)); + coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE)); + + Map> memberSubscriptions = singletonMap(consumerId, singletonList(topic1)); + partitionAssignor.prepare(singletonMap(consumerId, singletonList(t1p))); + + // the leader is responsible for picking up metadata changes and forcing a group rebalance + client.prepareResponse(joinGroupLeaderResponse(1, consumerId, memberSubscriptions, Errors.NONE)); + client.prepareResponse(syncGroupResponse(singletonList(t1p), Errors.NONE)); + + coordinator.poll(time.timer(Long.MAX_VALUE)); + + assertFalse(coordinator.rejoinNeededOrPending()); + + // a new partition is added to the topic + metadata.updateWithCurrentRequestVersion(RequestTestUtils.metadataUpdateWith(1, singletonMap(topic1, 2)), false, time.milliseconds()); + coordinator.maybeUpdateSubscriptionMetadata(); + + // we should detect the change and ask for reassignment + assertTrue(coordinator.rejoinNeededOrPending()); + } + + @Test + public void testUpdateMetadataDuringRebalance() { + final String topic1 = "topic1"; + final String topic2 = "topic2"; + TopicPartition tp1 = new TopicPartition(topic1, 0); + TopicPartition tp2 = new TopicPartition(topic2, 0); + final String consumerId = "leader"; + + List topics = Arrays.asList(topic1, topic2); + 
+ subscriptions.subscribe(new HashSet<>(topics), rebalanceListener); + + // we only have metadata for one topic initially + client.updateMetadata(RequestTestUtils.metadataUpdateWith(1, singletonMap(topic1, 1))); + + client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE)); + coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE)); + + // prepare initial rebalance + Map> memberSubscriptions = singletonMap(consumerId, topics); + partitionAssignor.prepare(singletonMap(consumerId, Arrays.asList(tp1))); + + client.prepareResponse(joinGroupLeaderResponse(1, consumerId, memberSubscriptions, Errors.NONE)); + client.prepareResponse(body -> { + SyncGroupRequest sync = (SyncGroupRequest) body; + if (sync.data().memberId().equals(consumerId) && + sync.data().generationId() == 1 && + sync.groupAssignments().containsKey(consumerId)) { + // trigger the metadata update including both topics after the sync group request has been sent + Map topicPartitionCounts = new HashMap<>(); + topicPartitionCounts.put(topic1, 1); + topicPartitionCounts.put(topic2, 1); + client.updateMetadata(RequestTestUtils.metadataUpdateWith(1, topicPartitionCounts)); + return true; + } + return false; + }, syncGroupResponse(Collections.singletonList(tp1), Errors.NONE)); + coordinator.poll(time.timer(Long.MAX_VALUE)); + + // the metadata update should trigger a second rebalance + client.prepareResponse(joinGroupLeaderResponse(2, consumerId, memberSubscriptions, Errors.NONE)); + client.prepareResponse(syncGroupResponse(Arrays.asList(tp1, tp2), Errors.NONE)); + + coordinator.poll(time.timer(Long.MAX_VALUE)); + + assertFalse(coordinator.rejoinNeededOrPending()); + assertEquals(new HashSet<>(Arrays.asList(tp1, tp2)), subscriptions.assignedPartitions()); + } + + /** + * Verifies that subscription change updates SubscriptionState correctly even after JoinGroup failures + * that don't re-invoke onJoinPrepare. + */ + @Test + public void testSubscriptionChangeWithAuthorizationFailure() { + // Subscribe to two topics of which only one is authorized and verify that metadata failure is propagated. + subscriptions.subscribe(Utils.mkSet(topic1, topic2), rebalanceListener); + client.prepareMetadataUpdate(RequestTestUtils.metadataUpdateWith("kafka-cluster", 1, + Collections.singletonMap(topic2, Errors.TOPIC_AUTHORIZATION_FAILED), singletonMap(topic1, 1))); + assertThrows(TopicAuthorizationException.class, () -> coordinator.poll(time.timer(Long.MAX_VALUE))); + + client.respond(groupCoordinatorResponse(node, Errors.NONE)); + coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE)); + + // Fail the first JoinGroup request + client.prepareResponse(joinGroupLeaderResponse(0, consumerId, Collections.emptyMap(), + Errors.GROUP_AUTHORIZATION_FAILED)); + assertThrows(GroupAuthorizationException.class, () -> coordinator.poll(time.timer(Long.MAX_VALUE))); + + // Change subscription to include only the authorized topic. Complete rebalance and check that + // references to topic2 have been removed from SubscriptionState. 
+ subscriptions.subscribe(Utils.mkSet(topic1), rebalanceListener); + assertEquals(Collections.singleton(topic1), subscriptions.metadataTopics()); + client.prepareMetadataUpdate(RequestTestUtils.metadataUpdateWith("kafka-cluster", 1, + Collections.emptyMap(), singletonMap(topic1, 1))); + + Map> memberSubscriptions = singletonMap(consumerId, singletonList(topic1)); + partitionAssignor.prepare(singletonMap(consumerId, singletonList(t1p))); + client.prepareResponse(joinGroupLeaderResponse(1, consumerId, memberSubscriptions, Errors.NONE)); + client.prepareResponse(syncGroupResponse(singletonList(t1p), Errors.NONE)); + coordinator.poll(time.timer(Long.MAX_VALUE)); + + assertEquals(singleton(topic1), subscriptions.subscription()); + assertEquals(singleton(topic1), subscriptions.metadataTopics()); + } + + @Test + public void testWakeupFromAssignmentCallback() { + final String topic = "topic1"; + TopicPartition partition = new TopicPartition(topic, 0); + final String consumerId = "follower"; + Set topics = Collections.singleton(topic); + MockRebalanceListener rebalanceListener = new MockRebalanceListener() { + @Override + public void onPartitionsAssigned(Collection partitions) { + boolean raiseWakeup = this.assignedCount == 0; + super.onPartitionsAssigned(partitions); + + if (raiseWakeup) + throw new WakeupException(); + } + }; + + subscriptions.subscribe(topics, rebalanceListener); + + // we only have metadata for one topic initially + client.updateMetadata(RequestTestUtils.metadataUpdateWith(1, singletonMap(topic1, 1))); + + client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE)); + coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE)); + + // prepare initial rebalance + partitionAssignor.prepare(singletonMap(consumerId, Collections.singletonList(partition))); + + client.prepareResponse(joinGroupFollowerResponse(1, consumerId, "leader", Errors.NONE)); + client.prepareResponse(syncGroupResponse(Collections.singletonList(partition), Errors.NONE)); + + // The first call to poll should raise the exception from the rebalance listener + try { + coordinator.poll(time.timer(Long.MAX_VALUE)); + fail("Expected exception thrown from assignment callback"); + } catch (WakeupException e) { + } + + // The second call should retry the assignment callback and succeed + coordinator.poll(time.timer(Long.MAX_VALUE)); + + assertFalse(coordinator.rejoinNeededOrPending()); + assertEquals(0, rebalanceListener.revokedCount); + assertEquals(2, rebalanceListener.assignedCount); + } + + @Test + public void testRebalanceAfterTopicUnavailableWithSubscribe() { + unavailableTopicTest(false, Collections.emptySet()); + } + + @Test + public void testRebalanceAfterTopicUnavailableWithPatternSubscribe() { + unavailableTopicTest(true, Collections.emptySet()); + } + + @Test + public void testRebalanceAfterNotMatchingTopicUnavailableWithPatternSubscribe() { + unavailableTopicTest(true, Collections.singleton("notmatching")); + } + + private void unavailableTopicTest(boolean patternSubscribe, Set unavailableTopicsInLastMetadata) { + if (patternSubscribe) + subscriptions.subscribe(Pattern.compile("test.*"), rebalanceListener); + else + subscriptions.subscribe(singleton(topic1), rebalanceListener); + + client.prepareMetadataUpdate(RequestTestUtils.metadataUpdateWith("kafka-cluster", 1, + Collections.singletonMap(topic1, Errors.UNKNOWN_TOPIC_OR_PARTITION), Collections.emptyMap())); + + client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE)); + coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE)); + 
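+        // the leader's assignment stays empty while the subscribed topic is unavailable in metadata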
+ Map> memberSubscriptions = singletonMap(consumerId, singletonList(topic1)); + partitionAssignor.prepare(Collections.emptyMap()); + + client.prepareResponse(joinGroupLeaderResponse(1, consumerId, memberSubscriptions, Errors.NONE)); + client.prepareResponse(syncGroupResponse(Collections.emptyList(), Errors.NONE)); + coordinator.poll(time.timer(Long.MAX_VALUE)); + assertFalse(coordinator.rejoinNeededOrPending()); + // callback not triggered since there's nothing to be assigned + assertEquals(Collections.emptySet(), rebalanceListener.assigned); + assertTrue(metadata.updateRequested(), "Metadata refresh not requested for unavailable partitions"); + + Map topicErrors = new HashMap<>(); + for (String topic : unavailableTopicsInLastMetadata) + topicErrors.put(topic, Errors.UNKNOWN_TOPIC_OR_PARTITION); + + client.prepareMetadataUpdate(RequestTestUtils.metadataUpdateWith("kafka-cluster", 1, + topicErrors, singletonMap(topic1, 1))); + + consumerClient.poll(time.timer(0)); + client.prepareResponse(joinGroupLeaderResponse(2, consumerId, memberSubscriptions, Errors.NONE)); + client.prepareResponse(syncGroupResponse(singletonList(t1p), Errors.NONE)); + coordinator.poll(time.timer(Long.MAX_VALUE)); + + assertFalse(metadata.updateRequested(), "Metadata refresh requested unnecessarily"); + assertFalse(coordinator.rejoinNeededOrPending()); + assertEquals(singleton(t1p), rebalanceListener.assigned); + } + + @Test + public void testExcludeInternalTopicsConfigOption() { + testInternalTopicInclusion(false); + } + + @Test + public void testIncludeInternalTopicsConfigOption() { + testInternalTopicInclusion(true); + } + + private void testInternalTopicInclusion(boolean includeInternalTopics) { + metadata = new ConsumerMetadata(0, Long.MAX_VALUE, includeInternalTopics, + false, subscriptions, new LogContext(), new ClusterResourceListeners()); + client = new MockClient(time, metadata); + try (ConsumerCoordinator coordinator = buildCoordinator(rebalanceConfig, new Metrics(), assignors, false, subscriptions)) { + subscriptions.subscribe(Pattern.compile(".*"), rebalanceListener); + Node node = new Node(0, "localhost", 9999); + MetadataResponse.PartitionMetadata partitionMetadata = + new MetadataResponse.PartitionMetadata(Errors.NONE, new TopicPartition(Topic.GROUP_METADATA_TOPIC_NAME, 0), + Optional.of(node.id()), Optional.empty(), singletonList(node.id()), singletonList(node.id()), + singletonList(node.id())); + MetadataResponse.TopicMetadata topicMetadata = new MetadataResponse.TopicMetadata(Errors.NONE, + Topic.GROUP_METADATA_TOPIC_NAME, true, singletonList(partitionMetadata)); + + client.updateMetadata(RequestTestUtils.metadataResponse(singletonList(node), "clusterId", node.id(), + singletonList(topicMetadata))); + coordinator.maybeUpdateSubscriptionMetadata(); + + assertEquals(includeInternalTopics, subscriptions.subscription().contains(Topic.GROUP_METADATA_TOPIC_NAME)); + } + } + + @Test + public void testRejoinGroup() { + String otherTopic = "otherTopic"; + final List owned = Collections.emptyList(); + final List assigned = Arrays.asList(t1p); + + subscriptions.subscribe(singleton(topic1), rebalanceListener); + + // join the group once + joinAsFollowerAndReceiveAssignment(coordinator, assigned); + + assertEquals(0, rebalanceListener.revokedCount); + assertNull(rebalanceListener.revoked); + assertEquals(1, rebalanceListener.assignedCount); + assertEquals(getAdded(owned, assigned), rebalanceListener.assigned); + + // and join the group again + rebalanceListener.revoked = null; + rebalanceListener.assigned = null; 
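+        // subscribing to an additional topic forces the consumer to rejoin the group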
+ subscriptions.subscribe(new HashSet<>(Arrays.asList(topic1, otherTopic)), rebalanceListener); + client.prepareResponse(joinGroupFollowerResponse(2, consumerId, "leader", Errors.NONE)); + client.prepareResponse(syncGroupResponse(assigned, Errors.NONE)); + coordinator.joinGroupIfNeeded(time.timer(Long.MAX_VALUE)); + + Collection revoked = getRevoked(assigned, assigned); + Collection added = getAdded(assigned, assigned); + assertEquals(revoked.isEmpty() ? 0 : 1, rebalanceListener.revokedCount); + assertEquals(revoked.isEmpty() ? null : revoked, rebalanceListener.revoked); + assertEquals(2, rebalanceListener.assignedCount); + assertEquals(added, rebalanceListener.assigned); + } + + @Test + public void testDisconnectInJoin() { + subscriptions.subscribe(singleton(topic1), rebalanceListener); + final List owned = Collections.emptyList(); + final List assigned = Arrays.asList(t1p); + + client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE)); + coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE)); + + // disconnected from original coordinator will cause re-discover and join again + client.prepareResponse(joinGroupFollowerResponse(1, consumerId, "leader", Errors.NONE), true); + client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE)); + client.prepareResponse(joinGroupFollowerResponse(1, consumerId, "leader", Errors.NONE)); + client.prepareResponse(syncGroupResponse(assigned, Errors.NONE)); + coordinator.joinGroupIfNeeded(time.timer(Long.MAX_VALUE)); + + assertFalse(coordinator.rejoinNeededOrPending()); + assertEquals(toSet(assigned), subscriptions.assignedPartitions()); + // nothing to be revoked hence callback not triggered + assertEquals(0, rebalanceListener.revokedCount); + assertNull(rebalanceListener.revoked); + assertEquals(1, rebalanceListener.assignedCount); + assertEquals(getAdded(owned, assigned), rebalanceListener.assigned); + } + + @Test + public void testInvalidSessionTimeout() { + subscriptions.subscribe(singleton(topic1), rebalanceListener); + + client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE)); + coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE)); + + // coordinator doesn't like the session timeout + client.prepareResponse(joinGroupFollowerResponse(0, consumerId, "", Errors.INVALID_SESSION_TIMEOUT)); + assertThrows(ApiException.class, () -> coordinator.joinGroupIfNeeded(time.timer(Long.MAX_VALUE))); + } + + @Test + public void testCommitOffsetOnly() { + subscriptions.assignFromUser(singleton(t1p)); + + client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE)); + coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE)); + + prepareOffsetCommitRequest(singletonMap(t1p, 100L), Errors.NONE); + + AtomicBoolean success = new AtomicBoolean(false); + coordinator.commitOffsetsAsync(singletonMap(t1p, new OffsetAndMetadata(100L)), callback(success)); + coordinator.invokeCompletedOffsetCommitCallbacks(); + assertTrue(success.get()); + } + + @Test + public void testCoordinatorDisconnectAfterNotCoordinatorError() { + testInFlightRequestsFailedAfterCoordinatorMarkedDead(Errors.NOT_COORDINATOR); + } + + @Test + public void testCoordinatorDisconnectAfterCoordinatorNotAvailableError() { + testInFlightRequestsFailedAfterCoordinatorMarkedDead(Errors.COORDINATOR_NOT_AVAILABLE); + } + + private void testInFlightRequestsFailedAfterCoordinatorMarkedDead(Errors error) { + client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE)); + coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE)); + + // Send two async 
commits and fail the first one with an error. + // This should cause a coordinator disconnect which will cancel the second request. + + MockCommitCallback firstCommitCallback = new MockCommitCallback(); + MockCommitCallback secondCommitCallback = new MockCommitCallback(); + coordinator.commitOffsetsAsync(singletonMap(t1p, new OffsetAndMetadata(100L)), firstCommitCallback); + coordinator.commitOffsetsAsync(singletonMap(t1p, new OffsetAndMetadata(100L)), secondCommitCallback); + + respondToOffsetCommitRequest(singletonMap(t1p, 100L), error); + consumerClient.pollNoWakeup(); + consumerClient.pollNoWakeup(); // second poll since coordinator disconnect is async + coordinator.invokeCompletedOffsetCommitCallbacks(); + + assertTrue(coordinator.coordinatorUnknown()); + assertTrue(firstCommitCallback.exception instanceof RetriableCommitFailedException); + assertTrue(secondCommitCallback.exception instanceof RetriableCommitFailedException); + } + + @Test + public void testAutoCommitDynamicAssignment() { + try (ConsumerCoordinator coordinator = buildCoordinator(rebalanceConfig, new Metrics(), assignors, true, subscriptions) + ) { + subscriptions.subscribe(singleton(topic1), rebalanceListener); + joinAsFollowerAndReceiveAssignment(coordinator, singletonList(t1p)); + subscriptions.seek(t1p, 100); + prepareOffsetCommitRequest(singletonMap(t1p, 100L), Errors.NONE); + time.sleep(autoCommitIntervalMs); + coordinator.poll(time.timer(Long.MAX_VALUE)); + assertFalse(client.hasPendingResponses()); + } + } + + @Test + public void testAutoCommitRetryBackoff() { + try (ConsumerCoordinator coordinator = buildCoordinator(rebalanceConfig, new Metrics(), assignors, true, subscriptions)) { + subscriptions.subscribe(singleton(topic1), rebalanceListener); + joinAsFollowerAndReceiveAssignment(coordinator, singletonList(t1p)); + + subscriptions.seek(t1p, 100); + time.sleep(autoCommitIntervalMs); + + // Send an offset commit, but let it fail with a retriable error + prepareOffsetCommitRequest(singletonMap(t1p, 100L), Errors.NOT_COORDINATOR); + coordinator.poll(time.timer(Long.MAX_VALUE)); + assertTrue(coordinator.coordinatorUnknown()); + + // After the disconnect, we should rediscover the coordinator + client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE)); + coordinator.poll(time.timer(Long.MAX_VALUE)); + + subscriptions.seek(t1p, 200); + + // Until the retry backoff has expired, we should not retry the offset commit + time.sleep(retryBackoffMs / 2); + coordinator.poll(time.timer(Long.MAX_VALUE)); + assertEquals(0, client.inFlightRequestCount()); + + // Once the backoff expires, we should retry + time.sleep(retryBackoffMs / 2); + coordinator.poll(time.timer(Long.MAX_VALUE)); + assertEquals(1, client.inFlightRequestCount()); + respondToOffsetCommitRequest(singletonMap(t1p, 200L), Errors.NONE); + } + } + + @Test + public void testAutoCommitAwaitsInterval() { + try (ConsumerCoordinator coordinator = buildCoordinator(rebalanceConfig, new Metrics(), assignors, true, subscriptions)) { + subscriptions.subscribe(singleton(topic1), rebalanceListener); + joinAsFollowerAndReceiveAssignment(coordinator, singletonList(t1p)); + + subscriptions.seek(t1p, 100); + time.sleep(autoCommitIntervalMs); + + // Send the offset commit request, but do not respond + coordinator.poll(time.timer(Long.MAX_VALUE)); + assertEquals(1, client.inFlightRequestCount()); + + time.sleep(autoCommitIntervalMs / 2); + + // Ensure that no additional offset commit is sent + coordinator.poll(time.timer(Long.MAX_VALUE)); + assertEquals(1, 
client.inFlightRequestCount()); + + respondToOffsetCommitRequest(singletonMap(t1p, 100L), Errors.NONE); + coordinator.poll(time.timer(Long.MAX_VALUE)); + assertEquals(0, client.inFlightRequestCount()); + + subscriptions.seek(t1p, 200); + + // If we poll again before the auto-commit interval, there should be no new sends + coordinator.poll(time.timer(Long.MAX_VALUE)); + assertEquals(0, client.inFlightRequestCount()); + + // After the remainder of the interval passes, we send a new offset commit + time.sleep(autoCommitIntervalMs / 2); + coordinator.poll(time.timer(Long.MAX_VALUE)); + assertEquals(1, client.inFlightRequestCount()); + respondToOffsetCommitRequest(singletonMap(t1p, 200L), Errors.NONE); + } + } + + @Test + public void testAutoCommitDynamicAssignmentRebalance() { + try (ConsumerCoordinator coordinator = buildCoordinator(rebalanceConfig, new Metrics(), assignors, true, subscriptions)) { + subscriptions.subscribe(singleton(topic1), rebalanceListener); + client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE)); + coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE)); + + // haven't joined, so should not cause a commit + time.sleep(autoCommitIntervalMs); + consumerClient.poll(time.timer(0)); + + client.prepareResponse(joinGroupFollowerResponse(1, consumerId, "leader", Errors.NONE)); + client.prepareResponse(syncGroupResponse(singletonList(t1p), Errors.NONE)); + coordinator.joinGroupIfNeeded(time.timer(Long.MAX_VALUE)); + + subscriptions.seek(t1p, 100); + + prepareOffsetCommitRequest(singletonMap(t1p, 100L), Errors.NONE); + time.sleep(autoCommitIntervalMs); + coordinator.poll(time.timer(Long.MAX_VALUE)); + assertFalse(client.hasPendingResponses()); + } + } + + @Test + public void testAutoCommitManualAssignment() { + try (ConsumerCoordinator coordinator = buildCoordinator(rebalanceConfig, new Metrics(), assignors, true, subscriptions)) { + subscriptions.assignFromUser(singleton(t1p)); + subscriptions.seek(t1p, 100); + client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE)); + coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE)); + + prepareOffsetCommitRequest(singletonMap(t1p, 100L), Errors.NONE); + time.sleep(autoCommitIntervalMs); + coordinator.poll(time.timer(Long.MAX_VALUE)); + assertFalse(client.hasPendingResponses()); + } + } + + @Test + public void testAutoCommitManualAssignmentCoordinatorUnknown() { + try (ConsumerCoordinator coordinator = buildCoordinator(rebalanceConfig, new Metrics(), assignors, true, subscriptions)) { + subscriptions.assignFromUser(singleton(t1p)); + subscriptions.seek(t1p, 100); + + // no commit initially since coordinator is unknown + consumerClient.poll(time.timer(0)); + time.sleep(autoCommitIntervalMs); + consumerClient.poll(time.timer(0)); + + // now find the coordinator + client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE)); + coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE)); + + // sleep only for the retry backoff + time.sleep(retryBackoffMs); + prepareOffsetCommitRequest(singletonMap(t1p, 100L), Errors.NONE); + coordinator.poll(time.timer(Long.MAX_VALUE)); + assertFalse(client.hasPendingResponses()); + } + } + + @Test + public void testCommitOffsetMetadata() { + subscriptions.assignFromUser(singleton(t1p)); + + client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE)); + coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE)); + + prepareOffsetCommitRequest(singletonMap(t1p, 100L), Errors.NONE); + + AtomicBoolean success = new AtomicBoolean(false); + + Map 
offsets = singletonMap(t1p, new OffsetAndMetadata(100L, "hello")); + coordinator.commitOffsetsAsync(offsets, callback(offsets, success)); + coordinator.invokeCompletedOffsetCommitCallbacks(); + assertTrue(success.get()); + } + + @Test + public void testCommitOffsetAsyncWithDefaultCallback() { + int invokedBeforeTest = mockOffsetCommitCallback.invoked; + client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE)); + coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE)); + prepareOffsetCommitRequest(singletonMap(t1p, 100L), Errors.NONE); + coordinator.commitOffsetsAsync(singletonMap(t1p, new OffsetAndMetadata(100L)), mockOffsetCommitCallback); + coordinator.invokeCompletedOffsetCommitCallbacks(); + assertEquals(invokedBeforeTest + 1, mockOffsetCommitCallback.invoked); + assertNull(mockOffsetCommitCallback.exception); + } + + @Test + public void testCommitAfterLeaveGroup() { + // enable auto-assignment + subscriptions.subscribe(singleton(topic1), rebalanceListener); + + joinAsFollowerAndReceiveAssignment(coordinator, singletonList(t1p)); + + // now switch to manual assignment + client.prepareResponse(new LeaveGroupResponse(new LeaveGroupResponseData() + .setErrorCode(Errors.NONE.code()))); + subscriptions.unsubscribe(); + coordinator.maybeLeaveGroup("test commit after leave"); + subscriptions.assignFromUser(singleton(t1p)); + + // the client should not reuse generation/memberId from auto-subscribed generation + client.prepareResponse(body -> { + OffsetCommitRequest commitRequest = (OffsetCommitRequest) body; + return commitRequest.data().memberId().equals(OffsetCommitRequest.DEFAULT_MEMBER_ID) && + commitRequest.data().generationId() == OffsetCommitRequest.DEFAULT_GENERATION_ID; + }, offsetCommitResponse(singletonMap(t1p, Errors.NONE))); + + AtomicBoolean success = new AtomicBoolean(false); + coordinator.commitOffsetsAsync(singletonMap(t1p, new OffsetAndMetadata(100L)), callback(success)); + coordinator.invokeCompletedOffsetCommitCallbacks(); + assertTrue(success.get()); + } + + @Test + public void testCommitOffsetAsyncFailedWithDefaultCallback() { + int invokedBeforeTest = mockOffsetCommitCallback.invoked; + client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE)); + coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE)); + prepareOffsetCommitRequest(singletonMap(t1p, 100L), Errors.COORDINATOR_NOT_AVAILABLE); + coordinator.commitOffsetsAsync(singletonMap(t1p, new OffsetAndMetadata(100L)), mockOffsetCommitCallback); + coordinator.invokeCompletedOffsetCommitCallbacks(); + assertEquals(invokedBeforeTest + 1, mockOffsetCommitCallback.invoked); + assertTrue(mockOffsetCommitCallback.exception instanceof RetriableCommitFailedException); + } + + @Test + public void testCommitOffsetAsyncCoordinatorNotAvailable() { + client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE)); + coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE)); + + // async commit with coordinator not available + MockCommitCallback cb = new MockCommitCallback(); + prepareOffsetCommitRequest(singletonMap(t1p, 100L), Errors.COORDINATOR_NOT_AVAILABLE); + coordinator.commitOffsetsAsync(singletonMap(t1p, new OffsetAndMetadata(100L)), cb); + coordinator.invokeCompletedOffsetCommitCallbacks(); + + assertTrue(coordinator.coordinatorUnknown()); + assertEquals(1, cb.invoked); + assertTrue(cb.exception instanceof RetriableCommitFailedException); + } + + @Test + public void testCommitOffsetAsyncNotCoordinator() { + client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE)); + 
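+        // The prepared FindCoordinator response above is consumed by ensureCoordinatorReady() below,
+        // so the coordinator is already known before the async commit is attempted.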
coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE)); + + // async commit with not coordinator + MockCommitCallback cb = new MockCommitCallback(); + prepareOffsetCommitRequest(singletonMap(t1p, 100L), Errors.COORDINATOR_NOT_AVAILABLE); + coordinator.commitOffsetsAsync(singletonMap(t1p, new OffsetAndMetadata(100L)), cb); + coordinator.invokeCompletedOffsetCommitCallbacks(); + + assertTrue(coordinator.coordinatorUnknown()); + assertEquals(1, cb.invoked); + assertTrue(cb.exception instanceof RetriableCommitFailedException); + } + + @Test + public void testCommitOffsetAsyncDisconnected() { + client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE)); + coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE)); + + // async commit with coordinator disconnected + MockCommitCallback cb = new MockCommitCallback(); + prepareOffsetCommitRequestDisconnect(singletonMap(t1p, 100L)); + coordinator.commitOffsetsAsync(singletonMap(t1p, new OffsetAndMetadata(100L)), cb); + coordinator.invokeCompletedOffsetCommitCallbacks(); + + assertTrue(coordinator.coordinatorUnknown()); + assertEquals(1, cb.invoked); + assertTrue(cb.exception instanceof RetriableCommitFailedException); + } + + @Test + public void testCommitOffsetSyncNotCoordinator() { + client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE)); + coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE)); + + // sync commit with coordinator disconnected (should connect, get metadata, and then submit the commit request) + prepareOffsetCommitRequest(singletonMap(t1p, 100L), Errors.NOT_COORDINATOR); + client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE)); + prepareOffsetCommitRequest(singletonMap(t1p, 100L), Errors.NONE); + coordinator.commitOffsetsSync(singletonMap(t1p, new OffsetAndMetadata(100L)), time.timer(Long.MAX_VALUE)); + } + + @Test + public void testCommitOffsetSyncCoordinatorNotAvailable() { + client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE)); + coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE)); + + // sync commit with coordinator disconnected (should connect, get metadata, and then submit the commit request) + prepareOffsetCommitRequest(singletonMap(t1p, 100L), Errors.COORDINATOR_NOT_AVAILABLE); + client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE)); + prepareOffsetCommitRequest(singletonMap(t1p, 100L), Errors.NONE); + coordinator.commitOffsetsSync(singletonMap(t1p, new OffsetAndMetadata(100L)), time.timer(Long.MAX_VALUE)); + } + + @Test + public void testCommitOffsetSyncCoordinatorDisconnected() { + client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE)); + coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE)); + + // sync commit with coordinator disconnected (should connect, get metadata, and then submit the commit request) + prepareOffsetCommitRequestDisconnect(singletonMap(t1p, 100L)); + client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE)); + prepareOffsetCommitRequest(singletonMap(t1p, 100L), Errors.NONE); + coordinator.commitOffsetsSync(singletonMap(t1p, new OffsetAndMetadata(100L)), time.timer(Long.MAX_VALUE)); + } + + @Test + public void testAsyncCommitCallbacksInvokedPriorToSyncCommitCompletion() throws Exception { + client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE)); + coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE)); + + final List committedOffsets = Collections.synchronizedList(new ArrayList<>()); + final OffsetAndMetadata firstOffset = new OffsetAndMetadata(0L); + 
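+        // Two distinct offsets are used: the first is committed asynchronously and the second
+        // synchronously from a background thread, so the ordering of the completed callbacks can be asserted.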
final OffsetAndMetadata secondOffset = new OffsetAndMetadata(1L); + + coordinator.commitOffsetsAsync(singletonMap(t1p, firstOffset), new OffsetCommitCallback() { + @Override + public void onComplete(Map offsets, Exception exception) { + committedOffsets.add(firstOffset); + } + }); + + // Do a synchronous commit in the background so that we can send both responses at the same time + Thread thread = new Thread() { + @Override + public void run() { + coordinator.commitOffsetsSync(singletonMap(t1p, secondOffset), time.timer(10000)); + committedOffsets.add(secondOffset); + } + }; + + thread.start(); + + client.waitForRequests(2, 5000); + respondToOffsetCommitRequest(singletonMap(t1p, firstOffset.offset()), Errors.NONE); + respondToOffsetCommitRequest(singletonMap(t1p, secondOffset.offset()), Errors.NONE); + + thread.join(); + + assertEquals(Arrays.asList(firstOffset, secondOffset), committedOffsets); + } + + @Test + public void testRetryCommitUnknownTopicOrPartition() { + client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE)); + coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE)); + + client.prepareResponse(offsetCommitResponse(singletonMap(t1p, Errors.UNKNOWN_TOPIC_OR_PARTITION))); + client.prepareResponse(offsetCommitResponse(singletonMap(t1p, Errors.NONE))); + + assertTrue(coordinator.commitOffsetsSync(singletonMap(t1p, + new OffsetAndMetadata(100L, "metadata")), time.timer(10000))); + } + + @Test + public void testCommitOffsetMetadataTooLarge() { + // since offset metadata is provided by the user, we have to propagate the exception so they can handle it + client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE)); + coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE)); + + prepareOffsetCommitRequest(singletonMap(t1p, 100L), Errors.OFFSET_METADATA_TOO_LARGE); + assertThrows(OffsetMetadataTooLarge.class, () -> coordinator.commitOffsetsSync(singletonMap(t1p, + new OffsetAndMetadata(100L, "metadata")), time.timer(Long.MAX_VALUE))); + } + + @Test + public void testCommitOffsetIllegalGeneration() { + // we cannot retry if a rebalance occurs before the commit completed + client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE)); + coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE)); + + prepareOffsetCommitRequest(singletonMap(t1p, 100L), Errors.ILLEGAL_GENERATION); + assertThrows(CommitFailedException.class, () -> coordinator.commitOffsetsSync(singletonMap(t1p, + new OffsetAndMetadata(100L, "metadata")), time.timer(Long.MAX_VALUE))); + } + + @Test + public void testCommitOffsetUnknownMemberId() { + // we cannot retry if a rebalance occurs before the commit completed + client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE)); + coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE)); + + prepareOffsetCommitRequest(singletonMap(t1p, 100L), Errors.UNKNOWN_MEMBER_ID); + assertThrows(CommitFailedException.class, () -> coordinator.commitOffsetsSync(singletonMap(t1p, + new OffsetAndMetadata(100L, "metadata")), time.timer(Long.MAX_VALUE))); + } + + @Test + public void testCommitOffsetIllegalGenerationWithNewGeneration() { + client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE)); + coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE)); + + final AbstractCoordinator.Generation currGen = new AbstractCoordinator.Generation( + 1, + "memberId", + null); + coordinator.setNewGeneration(currGen); + + prepareOffsetCommitRequest(singletonMap(t1p, 100L), Errors.ILLEGAL_GENERATION); + RequestFuture future = 
coordinator.sendOffsetCommitRequest(singletonMap(t1p, + new OffsetAndMetadata(100L, "metadata"))); + + // change the generation + final AbstractCoordinator.Generation newGen = new AbstractCoordinator.Generation( + 2, + "memberId-new", + null); + coordinator.setNewGeneration(newGen); + coordinator.setNewState(AbstractCoordinator.MemberState.PREPARING_REBALANCE); + + assertTrue(consumerClient.poll(future, time.timer(30000))); + assertTrue(future.exception().getClass().isInstance(Errors.REBALANCE_IN_PROGRESS.exception())); + + // the generation should not be reset + assertEquals(newGen, coordinator.generation()); + } + + @Test + public void testCommitOffsetIllegalGenerationWithResetGenearion() { + client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE)); + coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE)); + + final AbstractCoordinator.Generation currGen = new AbstractCoordinator.Generation( + 1, + "memberId", + null); + coordinator.setNewGeneration(currGen); + + prepareOffsetCommitRequest(singletonMap(t1p, 100L), Errors.ILLEGAL_GENERATION); + RequestFuture future = coordinator.sendOffsetCommitRequest(singletonMap(t1p, + new OffsetAndMetadata(100L, "metadata"))); + + // reset the generation + coordinator.setNewGeneration(AbstractCoordinator.Generation.NO_GENERATION); + + assertTrue(consumerClient.poll(future, time.timer(30000))); + assertTrue(future.exception().getClass().isInstance(new CommitFailedException())); + + // the generation should not be reset + assertEquals(AbstractCoordinator.Generation.NO_GENERATION, coordinator.generation()); + } + + @Test + public void testCommitOffsetUnknownMemberWithNewGenearion() { + client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE)); + coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE)); + + final AbstractCoordinator.Generation currGen = new AbstractCoordinator.Generation( + 1, + "memberId", + null); + coordinator.setNewGeneration(currGen); + + prepareOffsetCommitRequest(singletonMap(t1p, 100L), Errors.UNKNOWN_MEMBER_ID); + RequestFuture future = coordinator.sendOffsetCommitRequest(singletonMap(t1p, + new OffsetAndMetadata(100L, "metadata"))); + + // change the generation + final AbstractCoordinator.Generation newGen = new AbstractCoordinator.Generation( + 2, + "memberId-new", + null); + coordinator.setNewGeneration(newGen); + coordinator.setNewState(AbstractCoordinator.MemberState.PREPARING_REBALANCE); + + assertTrue(consumerClient.poll(future, time.timer(30000))); + assertTrue(future.exception().getClass().isInstance(Errors.REBALANCE_IN_PROGRESS.exception())); + + // the generation should not be reset + assertEquals(newGen, coordinator.generation()); + } + + @Test + public void testCommitOffsetUnknownMemberWithResetGenearion() { + client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE)); + coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE)); + + final AbstractCoordinator.Generation currGen = new AbstractCoordinator.Generation( + 1, + "memberId", + null); + coordinator.setNewGeneration(currGen); + + prepareOffsetCommitRequest(singletonMap(t1p, 100L), Errors.UNKNOWN_MEMBER_ID); + RequestFuture future = coordinator.sendOffsetCommitRequest(singletonMap(t1p, + new OffsetAndMetadata(100L, "metadata"))); + + // reset the generation + coordinator.setNewGeneration(AbstractCoordinator.Generation.NO_GENERATION); + + assertTrue(consumerClient.poll(future, time.timer(30000))); + assertTrue(future.exception().getClass().isInstance(new CommitFailedException())); + + // the generation should not 
be reset + assertEquals(AbstractCoordinator.Generation.NO_GENERATION, coordinator.generation()); + } + + @Test + public void testCommitOffsetFencedInstanceWithRebalancingGenearion() { + client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE)); + coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE)); + + final AbstractCoordinator.Generation currGen = new AbstractCoordinator.Generation( + 1, + "memberId", + null); + coordinator.setNewGeneration(currGen); + coordinator.setNewState(AbstractCoordinator.MemberState.PREPARING_REBALANCE); + + prepareOffsetCommitRequest(singletonMap(t1p, 100L), Errors.FENCED_INSTANCE_ID); + RequestFuture future = coordinator.sendOffsetCommitRequest(singletonMap(t1p, + new OffsetAndMetadata(100L, "metadata"))); + + // change the generation + final AbstractCoordinator.Generation newGen = new AbstractCoordinator.Generation( + 2, + "memberId-new", + null); + coordinator.setNewGeneration(newGen); + + assertTrue(consumerClient.poll(future, time.timer(30000))); + assertTrue(future.exception().getClass().isInstance(Errors.REBALANCE_IN_PROGRESS.exception())); + + // the generation should not be reset + assertEquals(newGen, coordinator.generation()); + } + + @Test + public void testCommitOffsetFencedInstanceWithNewGenearion() { + client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE)); + coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE)); + + final AbstractCoordinator.Generation currGen = new AbstractCoordinator.Generation( + 1, + "memberId", + null); + coordinator.setNewGeneration(currGen); + + prepareOffsetCommitRequest(singletonMap(t1p, 100L), Errors.FENCED_INSTANCE_ID); + RequestFuture future = coordinator.sendOffsetCommitRequest(singletonMap(t1p, + new OffsetAndMetadata(100L, "metadata"))); + + // change the generation + final AbstractCoordinator.Generation newGen = new AbstractCoordinator.Generation( + 2, + "memberId-new", + null); + coordinator.setNewGeneration(newGen); + + assertTrue(consumerClient.poll(future, time.timer(30000))); + assertTrue(future.exception().getClass().isInstance(new CommitFailedException())); + + // the generation should not be reset + assertEquals(newGen, coordinator.generation()); + } + + @Test + public void testCommitOffsetRebalanceInProgress() { + // we cannot retry if a rebalance occurs before the commit completed + final String consumerId = "leader"; + + subscriptions.subscribe(singleton(topic1), rebalanceListener); + + // ensure metadata is up-to-date for leader + client.updateMetadata(metadataResponse); + + client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE)); + coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE)); + + // normal join group + Map> memberSubscriptions = singletonMap(consumerId, singletonList(topic1)); + partitionAssignor.prepare(singletonMap(consumerId, singletonList(t1p))); + + coordinator.ensureActiveGroup(time.timer(0L)); + + assertTrue(coordinator.rejoinNeededOrPending()); + assertNull(coordinator.generationIfStable()); + + // when the state is REBALANCING, we would not even send out the request but fail immediately + assertThrows(RebalanceInProgressException.class, () -> coordinator.commitOffsetsSync(singletonMap(t1p, + new OffsetAndMetadata(100L, "metadata")), time.timer(Long.MAX_VALUE))); + + final Node coordinatorNode = new Node(Integer.MAX_VALUE - node.id(), node.host(), node.port()); + client.respondFrom(joinGroupLeaderResponse(1, consumerId, memberSubscriptions, Errors.NONE), coordinatorNode); + + client.prepareResponse(body -> { + 
SyncGroupRequest sync = (SyncGroupRequest) body; + return sync.data().memberId().equals(consumerId) && + sync.data().generationId() == 1 && + sync.groupAssignments().containsKey(consumerId); + }, syncGroupResponse(singletonList(t1p), Errors.NONE)); + coordinator.poll(time.timer(Long.MAX_VALUE)); + + AbstractCoordinator.Generation expectedGeneration = new AbstractCoordinator.Generation(1, consumerId, partitionAssignor.name()); + assertFalse(coordinator.rejoinNeededOrPending()); + assertEquals(expectedGeneration, coordinator.generationIfStable()); + + prepareOffsetCommitRequest(singletonMap(t1p, 100L), Errors.REBALANCE_IN_PROGRESS); + assertThrows(RebalanceInProgressException.class, () -> coordinator.commitOffsetsSync(singletonMap(t1p, + new OffsetAndMetadata(100L, "metadata")), time.timer(Long.MAX_VALUE))); + + assertTrue(coordinator.rejoinNeededOrPending()); + assertEquals(expectedGeneration, coordinator.generationIfStable()); + } + + @Test + public void testCommitOffsetSyncCallbackWithNonRetriableException() { + client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE)); + coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE)); + + // sync commit with invalid partitions should throw if we have no callback + prepareOffsetCommitRequest(singletonMap(t1p, 100L), Errors.UNKNOWN_SERVER_ERROR); + assertThrows(KafkaException.class, () -> coordinator.commitOffsetsSync(singletonMap(t1p, + new OffsetAndMetadata(100L)), time.timer(Long.MAX_VALUE))); + } + + @Test + public void testCommitOffsetSyncWithoutFutureGetsCompleted() { + client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE)); + coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE)); + assertFalse(coordinator.commitOffsetsSync(singletonMap(t1p, new OffsetAndMetadata(100L)), time.timer(0))); + } + + @Test + public void testRefreshOffset() { + client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE)); + coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE)); + + subscriptions.assignFromUser(singleton(t1p)); + client.prepareResponse(offsetFetchResponse(t1p, Errors.NONE, "", 100L)); + coordinator.refreshCommittedOffsetsIfNeeded(time.timer(Long.MAX_VALUE)); + + assertEquals(Collections.emptySet(), subscriptions.initializingPartitions()); + assertTrue(subscriptions.hasAllFetchPositions()); + assertEquals(100L, subscriptions.position(t1p).offset); + } + + @Test + public void testRefreshOffsetWithValidation() { + client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE)); + coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE)); + + subscriptions.assignFromUser(singleton(t1p)); + + // Initial leader epoch of 4 + MetadataResponse metadataResponse = RequestTestUtils.metadataUpdateWith("kafka-cluster", 1, + Collections.emptyMap(), singletonMap(topic1, 1), tp -> 4); + client.updateMetadata(metadataResponse); + + // Load offsets from previous epoch + client.prepareResponse(offsetFetchResponse(t1p, Errors.NONE, "", 100L, Optional.of(3))); + coordinator.refreshCommittedOffsetsIfNeeded(time.timer(Long.MAX_VALUE)); + + // Offset gets loaded, but requires validation + assertEquals(Collections.emptySet(), subscriptions.initializingPartitions()); + assertFalse(subscriptions.hasAllFetchPositions()); + assertTrue(subscriptions.awaitingValidation(t1p)); + assertEquals(subscriptions.position(t1p).offset, 100L); + assertNull(subscriptions.validPosition(t1p)); + } + + @Test + public void testFetchCommittedOffsets() { + client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE)); + 
coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE)); + + long offset = 500L; + String metadata = "blahblah"; + Optional leaderEpoch = Optional.of(15); + OffsetFetchResponse.PartitionData data = new OffsetFetchResponse.PartitionData(offset, leaderEpoch, + metadata, Errors.NONE); + + client.prepareResponse(new OffsetFetchResponse(Errors.NONE, singletonMap(t1p, data))); + Map fetchedOffsets = coordinator.fetchCommittedOffsets(singleton(t1p), + time.timer(Long.MAX_VALUE)); + + assertNotNull(fetchedOffsets); + assertEquals(new OffsetAndMetadata(offset, leaderEpoch, metadata), fetchedOffsets.get(t1p)); + } + + @Test + public void testTopicAuthorizationFailedInOffsetFetch() { + client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE)); + coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE)); + + OffsetFetchResponse.PartitionData data = new OffsetFetchResponse.PartitionData(-1, Optional.empty(), + "", Errors.TOPIC_AUTHORIZATION_FAILED); + + client.prepareResponse(new OffsetFetchResponse(Errors.NONE, singletonMap(t1p, data))); + TopicAuthorizationException exception = assertThrows(TopicAuthorizationException.class, () -> + coordinator.fetchCommittedOffsets(singleton(t1p), time.timer(Long.MAX_VALUE))); + + assertEquals(singleton(topic1), exception.unauthorizedTopics()); + } + + @Test + public void testRefreshOffsetLoadInProgress() { + client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE)); + coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE)); + + subscriptions.assignFromUser(singleton(t1p)); + client.prepareResponse(offsetFetchResponse(Errors.COORDINATOR_LOAD_IN_PROGRESS)); + client.prepareResponse(offsetFetchResponse(t1p, Errors.NONE, "", 100L)); + coordinator.refreshCommittedOffsetsIfNeeded(time.timer(Long.MAX_VALUE)); + + assertEquals(Collections.emptySet(), subscriptions.initializingPartitions()); + assertTrue(subscriptions.hasAllFetchPositions()); + assertEquals(100L, subscriptions.position(t1p).offset); + } + + @Test + public void testRefreshOffsetsGroupNotAuthorized() { + client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE)); + coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE)); + + subscriptions.assignFromUser(singleton(t1p)); + client.prepareResponse(offsetFetchResponse(Errors.GROUP_AUTHORIZATION_FAILED)); + try { + coordinator.refreshCommittedOffsetsIfNeeded(time.timer(Long.MAX_VALUE)); + fail("Expected group authorization error"); + } catch (GroupAuthorizationException e) { + assertEquals(groupId, e.groupId()); + } + } + + @Test + public void testRefreshOffsetWithPendingTransactions() { + client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE)); + coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE)); + + subscriptions.assignFromUser(singleton(t1p)); + client.prepareResponse(offsetFetchResponse(t1p, Errors.UNSTABLE_OFFSET_COMMIT, "", -1L)); + client.prepareResponse(offsetFetchResponse(t1p, Errors.NONE, "", 100L)); + assertEquals(Collections.singleton(t1p), subscriptions.initializingPartitions()); + coordinator.refreshCommittedOffsetsIfNeeded(time.timer(0L)); + assertEquals(Collections.singleton(t1p), subscriptions.initializingPartitions()); + coordinator.refreshCommittedOffsetsIfNeeded(time.timer(0L)); + + assertEquals(Collections.emptySet(), subscriptions.initializingPartitions()); + assertTrue(subscriptions.hasAllFetchPositions()); + assertEquals(100L, subscriptions.position(t1p).offset); + } + + @Test + public void testRefreshOffsetUnknownTopicOrPartition() { + 
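+        // An UNKNOWN_TOPIC_OR_PARTITION error on the offset fetch is not retried and is
+        // surfaced from refreshCommittedOffsetsIfNeeded as a KafkaException.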
client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE)); + coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE)); + + subscriptions.assignFromUser(singleton(t1p)); + client.prepareResponse(offsetFetchResponse(t1p, Errors.UNKNOWN_TOPIC_OR_PARTITION, "", 100L)); + assertThrows(KafkaException.class, () -> coordinator.refreshCommittedOffsetsIfNeeded(time.timer(Long.MAX_VALUE))); + } + + @Test + public void testRefreshOffsetNotCoordinatorForConsumer() { + client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE)); + coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE)); + + subscriptions.assignFromUser(singleton(t1p)); + client.prepareResponse(offsetFetchResponse(Errors.NOT_COORDINATOR)); + client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE)); + client.prepareResponse(offsetFetchResponse(t1p, Errors.NONE, "", 100L)); + coordinator.refreshCommittedOffsetsIfNeeded(time.timer(Long.MAX_VALUE)); + + assertEquals(Collections.emptySet(), subscriptions.initializingPartitions()); + assertTrue(subscriptions.hasAllFetchPositions()); + assertEquals(100L, subscriptions.position(t1p).offset); + } + + @Test + public void testRefreshOffsetWithNoFetchableOffsets() { + client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE)); + coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE)); + + subscriptions.assignFromUser(singleton(t1p)); + client.prepareResponse(offsetFetchResponse(t1p, Errors.NONE, "", -1L)); + coordinator.refreshCommittedOffsetsIfNeeded(time.timer(Long.MAX_VALUE)); + + assertEquals(Collections.singleton(t1p), subscriptions.initializingPartitions()); + assertEquals(Collections.emptySet(), subscriptions.partitionsNeedingReset(time.milliseconds())); + assertFalse(subscriptions.hasAllFetchPositions()); + assertNull(subscriptions.position(t1p)); + } + + @Test + public void testNoCoordinatorDiscoveryIfPositionsKnown() { + assertTrue(coordinator.coordinatorUnknown()); + + subscriptions.assignFromUser(singleton(t1p)); + subscriptions.seek(t1p, 500L); + coordinator.refreshCommittedOffsetsIfNeeded(time.timer(Long.MAX_VALUE)); + + assertEquals(Collections.emptySet(), subscriptions.initializingPartitions()); + assertTrue(subscriptions.hasAllFetchPositions()); + assertEquals(500L, subscriptions.position(t1p).offset); + assertTrue(coordinator.coordinatorUnknown()); + } + + @Test + public void testNoCoordinatorDiscoveryIfPartitionAwaitingReset() { + assertTrue(coordinator.coordinatorUnknown()); + + subscriptions.assignFromUser(singleton(t1p)); + subscriptions.requestOffsetReset(t1p, OffsetResetStrategy.EARLIEST); + coordinator.refreshCommittedOffsetsIfNeeded(time.timer(Long.MAX_VALUE)); + + assertEquals(Collections.emptySet(), subscriptions.initializingPartitions()); + assertFalse(subscriptions.hasAllFetchPositions()); + assertEquals(Collections.singleton(t1p), subscriptions.partitionsNeedingReset(time.milliseconds())); + assertEquals(OffsetResetStrategy.EARLIEST, subscriptions.resetStrategy(t1p)); + assertTrue(coordinator.coordinatorUnknown()); + } + + @Test + public void testAuthenticationFailureInEnsureActiveGroup() { + client.createPendingAuthenticationError(node, 300); + + try { + coordinator.ensureActiveGroup(); + fail("Expected an authentication error."); + } catch (AuthenticationException e) { + // OK + } + } + + @Test + public void testThreadSafeAssignedPartitionsMetric() throws Exception { + // Get the assigned-partitions metric + final Metric metric = metrics.metric(new MetricName("assigned-partitions", consumerId + groupId + 
"-coordinator-metrics", + "", Collections.emptyMap())); + + // Start polling the metric in the background + final AtomicBoolean doStop = new AtomicBoolean(); + final AtomicReference exceptionHolder = new AtomicReference<>(); + final AtomicInteger observedSize = new AtomicInteger(); + + Thread poller = new Thread() { + @Override + public void run() { + // Poll as fast as possible to reproduce ConcurrentModificationException + while (!doStop.get()) { + try { + int size = ((Double) metric.metricValue()).intValue(); + observedSize.set(size); + } catch (Exception e) { + exceptionHolder.set(e); + return; + } + } + } + }; + poller.start(); + + // Assign two partitions to trigger a metric change that can lead to ConcurrentModificationException + client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE)); + coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE)); + + // Change the assignment several times to increase likelihood of concurrent updates + Set partitions = new HashSet<>(); + int totalPartitions = 10; + for (int partition = 0; partition < totalPartitions; partition++) { + partitions.add(new TopicPartition(topic1, partition)); + subscriptions.assignFromUser(partitions); + } + + // Wait for the metric poller to observe the final assignment change or raise an error + TestUtils.waitForCondition( + () -> observedSize.get() == totalPartitions || + exceptionHolder.get() != null, "Failed to observe expected assignment change"); + + doStop.set(true); + poller.join(); + + assertNull(exceptionHolder.get(), "Failed fetching the metric at least once"); + } + + @Test + public void testCloseDynamicAssignment() { + try (ConsumerCoordinator coordinator = prepareCoordinatorForCloseTest(true, true, Optional.empty())) { + gracefulCloseTest(coordinator, true); + } + } + + @Test + public void testCloseManualAssignment() { + try (ConsumerCoordinator coordinator = prepareCoordinatorForCloseTest(false, true, Optional.empty())) { + gracefulCloseTest(coordinator, false); + } + } + + @Test + public void testCloseCoordinatorNotKnownManualAssignment() throws Exception { + try (ConsumerCoordinator coordinator = prepareCoordinatorForCloseTest(false, true, Optional.empty())) { + makeCoordinatorUnknown(coordinator, Errors.NOT_COORDINATOR); + time.sleep(autoCommitIntervalMs); + closeVerifyTimeout(coordinator, 1000, 1000, 1000); + } + } + + @Test + public void testCloseCoordinatorNotKnownNoCommits() throws Exception { + try (ConsumerCoordinator coordinator = prepareCoordinatorForCloseTest(true, false, Optional.empty())) { + makeCoordinatorUnknown(coordinator, Errors.NOT_COORDINATOR); + closeVerifyTimeout(coordinator, 1000, 0, 0); + } + } + + @Test + public void testCloseCoordinatorNotKnownWithCommits() throws Exception { + try (ConsumerCoordinator coordinator = prepareCoordinatorForCloseTest(true, true, Optional.empty())) { + makeCoordinatorUnknown(coordinator, Errors.NOT_COORDINATOR); + time.sleep(autoCommitIntervalMs); + closeVerifyTimeout(coordinator, 1000, 1000, 1000); + } + } + + @Test + public void testCloseCoordinatorUnavailableNoCommits() throws Exception { + try (ConsumerCoordinator coordinator = prepareCoordinatorForCloseTest(true, false, Optional.empty())) { + makeCoordinatorUnknown(coordinator, Errors.COORDINATOR_NOT_AVAILABLE); + closeVerifyTimeout(coordinator, 1000, 0, 0); + } + } + + @Test + public void testCloseTimeoutCoordinatorUnavailableForCommit() throws Exception { + try (ConsumerCoordinator coordinator = prepareCoordinatorForCloseTest(true, true, groupInstanceId)) { + 
makeCoordinatorUnknown(coordinator, Errors.COORDINATOR_NOT_AVAILABLE); + time.sleep(autoCommitIntervalMs); + closeVerifyTimeout(coordinator, 1000, 1000, 1000); + } + } + + @Test + public void testCloseMaxWaitCoordinatorUnavailableForCommit() throws Exception { + try (ConsumerCoordinator coordinator = prepareCoordinatorForCloseTest(true, true, groupInstanceId)) { + makeCoordinatorUnknown(coordinator, Errors.COORDINATOR_NOT_AVAILABLE); + time.sleep(autoCommitIntervalMs); + closeVerifyTimeout(coordinator, Long.MAX_VALUE, requestTimeoutMs, requestTimeoutMs); + } + } + + @Test + public void testCloseNoResponseForCommit() throws Exception { + try (ConsumerCoordinator coordinator = prepareCoordinatorForCloseTest(true, true, groupInstanceId)) { + time.sleep(autoCommitIntervalMs); + closeVerifyTimeout(coordinator, Long.MAX_VALUE, requestTimeoutMs, requestTimeoutMs); + } + } + + @Test + public void testCloseNoResponseForLeaveGroup() throws Exception { + try (ConsumerCoordinator coordinator = prepareCoordinatorForCloseTest(true, false, Optional.empty())) { + closeVerifyTimeout(coordinator, Long.MAX_VALUE, requestTimeoutMs, requestTimeoutMs); + } + } + + @Test + public void testCloseNoWait() throws Exception { + try (ConsumerCoordinator coordinator = prepareCoordinatorForCloseTest(true, true, groupInstanceId)) { + time.sleep(autoCommitIntervalMs); + closeVerifyTimeout(coordinator, 0, 0, 0); + } + } + + @Test + public void testHeartbeatThreadClose() throws Exception { + try (ConsumerCoordinator coordinator = prepareCoordinatorForCloseTest(true, true, groupInstanceId)) { + coordinator.ensureActiveGroup(); + time.sleep(heartbeatIntervalMs + 100); + Thread.yield(); // Give heartbeat thread a chance to attempt heartbeat + closeVerifyTimeout(coordinator, Long.MAX_VALUE, requestTimeoutMs, requestTimeoutMs); + Thread[] threads = new Thread[Thread.activeCount()]; + int threadCount = Thread.enumerate(threads); + for (int i = 0; i < threadCount; i++) { + assertFalse(threads[i].getName().contains(groupId), "Heartbeat thread active after close"); + } + } + } + + @Test + public void testAutoCommitAfterCoordinatorBackToService() { + try (ConsumerCoordinator coordinator = buildCoordinator(rebalanceConfig, new Metrics(), assignors, true, subscriptions)) { + subscriptions.assignFromUser(Collections.singleton(t1p)); + subscriptions.seek(t1p, 100L); + + coordinator.markCoordinatorUnknown("test cause"); + assertTrue(coordinator.coordinatorUnknown()); + client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE)); + prepareOffsetCommitRequest(singletonMap(t1p, 100L), Errors.NONE); + + // async commit offset should find coordinator + time.sleep(autoCommitIntervalMs); // sleep for a while to ensure auto commit does happen + coordinator.maybeAutoCommitOffsetsAsync(time.milliseconds()); + assertFalse(coordinator.coordinatorUnknown()); + assertEquals(100L, subscriptions.position(t1p).offset); + } + } + + @Test + public void testCommitOffsetRequestSyncWithFencedInstanceIdException() { + client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE)); + coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE)); + + // sync commit with invalid partitions should throw if we have no callback + prepareOffsetCommitRequest(singletonMap(t1p, 100L), Errors.FENCED_INSTANCE_ID); + assertThrows(FencedInstanceIdException.class, () -> coordinator.commitOffsetsSync(singletonMap(t1p, + new OffsetAndMetadata(100L)), time.timer(Long.MAX_VALUE))); + } + + @Test + public void 
testCommitOffsetRequestAsyncWithFencedInstanceIdException() { + assertThrows(FencedInstanceIdException.class, this::receiveFencedInstanceIdException); + } + + @Test + public void testCommitOffsetRequestAsyncAlwaysReceiveFencedException() { + // Once we get fenced exception once, we should always hit fencing case. + assertThrows(FencedInstanceIdException.class, this::receiveFencedInstanceIdException); + assertThrows(FencedInstanceIdException.class, () -> + coordinator.commitOffsetsAsync(singletonMap(t1p, new OffsetAndMetadata(100L)), new MockCommitCallback())); + assertThrows(FencedInstanceIdException.class, () -> + coordinator.commitOffsetsSync(singletonMap(t1p, new OffsetAndMetadata(100L)), time.timer(Long.MAX_VALUE))); + } + + @Test + public void testGetGroupMetadata() { + final ConsumerGroupMetadata groupMetadata = coordinator.groupMetadata(); + assertNotNull(groupMetadata); + assertEquals(groupId, groupMetadata.groupId()); + assertEquals(JoinGroupRequest.UNKNOWN_GENERATION_ID, groupMetadata.generationId()); + assertEquals(JoinGroupRequest.UNKNOWN_MEMBER_ID, groupMetadata.memberId()); + assertFalse(groupMetadata.groupInstanceId().isPresent()); + + try (final ConsumerCoordinator coordinator = prepareCoordinatorForCloseTest(true, true, groupInstanceId)) { + coordinator.ensureActiveGroup(); + + final ConsumerGroupMetadata joinedGroupMetadata = coordinator.groupMetadata(); + assertNotNull(joinedGroupMetadata); + assertEquals(groupId, joinedGroupMetadata.groupId()); + assertEquals(1, joinedGroupMetadata.generationId()); + assertEquals(consumerId, joinedGroupMetadata.memberId()); + assertEquals(groupInstanceId, joinedGroupMetadata.groupInstanceId()); + } + } + + @Test + public void shouldUpdateConsumerGroupMetadataBeforeCallbacks() { + final MockRebalanceListener rebalanceListener = new MockRebalanceListener() { + @Override + public void onPartitionsRevoked(Collection partitions) { + assertEquals(2, coordinator.groupMetadata().generationId()); + } + }; + + subscriptions.subscribe(singleton(topic1), rebalanceListener); + { + ByteBuffer buffer = ConsumerProtocol.serializeAssignment( + new ConsumerPartitionAssignor.Assignment(Collections.singletonList(t1p), ByteBuffer.wrap(new byte[0]))); + coordinator.onJoinComplete(1, "memberId", partitionAssignor.name(), buffer); + } + + ByteBuffer buffer = ConsumerProtocol.serializeAssignment( + new ConsumerPartitionAssignor.Assignment(Collections.emptyList(), ByteBuffer.wrap(new byte[0]))); + coordinator.onJoinComplete(2, "memberId", partitionAssignor.name(), buffer); + } + + @Test + public void testPrepareJoinAndRejoinAfterFailedRebalance() { + final List partitions = singletonList(t1p); + try (ConsumerCoordinator coordinator = prepareCoordinatorForCloseTest(true, false, Optional.of("group-id"))) { + coordinator.ensureActiveGroup(); + + prepareOffsetCommitRequest(singletonMap(t1p, 100L), Errors.REBALANCE_IN_PROGRESS); + + assertThrows(RebalanceInProgressException.class, () -> coordinator.commitOffsetsSync( + singletonMap(t1p, new OffsetAndMetadata(100L)), + time.timer(Long.MAX_VALUE))); + + assertFalse(client.hasPendingResponses()); + assertFalse(client.hasInFlightRequests()); + + int generationId = 42; + String memberId = "consumer-42"; + + client.prepareResponse(joinGroupFollowerResponse(generationId, memberId, "leader", Errors.NONE)); + + MockTime time = new MockTime(1); + + // onJoinPrepare will be executed and onJoinComplete will not. 
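+            // The short timer expires before the SyncGroup response arrives, so the join attempt
+            // below returns false even though the generation from the JoinGroup response is already recorded.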
+ boolean res = coordinator.joinGroupIfNeeded(time.timer(2)); + + assertFalse(res); + assertFalse(client.hasPendingResponses()); + // SynGroupRequest not responded. + assertEquals(1, client.inFlightRequestCount()); + assertEquals(generationId, coordinator.generation().generationId); + assertEquals(memberId, coordinator.generation().memberId); + + // Imitating heartbeat thread that clears generation data. + coordinator.maybeLeaveGroup("Clear generation data."); + + assertEquals(AbstractCoordinator.Generation.NO_GENERATION, coordinator.generation()); + + client.respond(syncGroupResponse(partitions, Errors.NONE)); + + // Join future should succeed but generation already cleared so result of join is false. + res = coordinator.joinGroupIfNeeded(time.timer(1)); + + assertFalse(res); + + // should have retried sending a join group request already + assertFalse(client.hasPendingResponses()); + assertEquals(1, client.inFlightRequestCount()); + + System.out.println(client.requests()); + + // Retry join should then succeed + client.respond(joinGroupFollowerResponse(generationId, memberId, "leader", Errors.NONE)); + client.prepareResponse(syncGroupResponse(partitions, Errors.NONE)); + + res = coordinator.joinGroupIfNeeded(time.timer(3000)); + + assertTrue(res); + assertFalse(client.hasPendingResponses()); + assertFalse(client.hasInFlightRequests()); + } + Collection lost = getLost(partitions); + assertEquals(lost.isEmpty() ? null : lost, rebalanceListener.lost); + assertEquals(lost.size(), rebalanceListener.lostCount); + } + + @Test + public void shouldLoseAllOwnedPartitionsBeforeRejoiningAfterDroppingOutOfTheGroup() { + final List partitions = singletonList(t1p); + try (ConsumerCoordinator coordinator = prepareCoordinatorForCloseTest(true, false, Optional.of("group-id"))) { + final SystemTime realTime = new SystemTime(); + coordinator.ensureActiveGroup(); + + prepareOffsetCommitRequest(singletonMap(t1p, 100L), Errors.REBALANCE_IN_PROGRESS); + + assertThrows(RebalanceInProgressException.class, () -> coordinator.commitOffsetsSync( + singletonMap(t1p, new OffsetAndMetadata(100L)), + time.timer(Long.MAX_VALUE))); + + int generationId = 42; + String memberId = "consumer-42"; + + client.prepareResponse(joinGroupFollowerResponse(generationId, memberId, "leader", Errors.NONE)); + client.prepareResponse(syncGroupResponse(Collections.emptyList(), Errors.UNKNOWN_MEMBER_ID)); + + boolean res = coordinator.joinGroupIfNeeded(realTime.timer(1000)); + + assertFalse(res); + assertEquals(AbstractCoordinator.Generation.NO_GENERATION, coordinator.generation()); + assertEquals("", coordinator.generation().memberId); + + res = coordinator.joinGroupIfNeeded(realTime.timer(1000)); + assertFalse(res); + } + Collection lost = getLost(partitions); + assertEquals(lost.isEmpty() ? 0 : 1, rebalanceListener.lostCount); + assertEquals(lost.isEmpty() ? 
null : lost, rebalanceListener.lost); + } + + + @Test + public void testThrowOnUnsupportedStableFlag() { + supportStableFlag((short) 6, true); + } + + @Test + public void testNoThrowWhenStableFlagIsSupported() { + supportStableFlag((short) 7, false); + } + + private void supportStableFlag(final short upperVersion, final boolean expectThrows) { + ConsumerCoordinator coordinator = new ConsumerCoordinator( + rebalanceConfig, + new LogContext(), + consumerClient, + assignors, + metadata, + subscriptions, + new Metrics(time), + consumerId + groupId, + time, + false, + autoCommitIntervalMs, + null, + true); + + client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE)); + client.setNodeApiVersions(NodeApiVersions.create(ApiKeys.OFFSET_FETCH.id, (short) 0, upperVersion)); + + long offset = 500L; + String metadata = "blahblah"; + Optional leaderEpoch = Optional.of(15); + OffsetFetchResponse.PartitionData data = new OffsetFetchResponse.PartitionData(offset, leaderEpoch, + metadata, Errors.NONE); + + client.prepareResponse(new OffsetFetchResponse(Errors.NONE, singletonMap(t1p, data))); + if (expectThrows) { + assertThrows(UnsupportedVersionException.class, + () -> coordinator.fetchCommittedOffsets(singleton(t1p), time.timer(Long.MAX_VALUE))); + } else { + Map fetchedOffsets = coordinator.fetchCommittedOffsets(singleton(t1p), + time.timer(Long.MAX_VALUE)); + + assertNotNull(fetchedOffsets); + assertEquals(new OffsetAndMetadata(offset, leaderEpoch, metadata), fetchedOffsets.get(t1p)); + } + } + + private void receiveFencedInstanceIdException() { + subscriptions.assignFromUser(singleton(t1p)); + + client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE)); + coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE)); + + prepareOffsetCommitRequest(singletonMap(t1p, 100L), Errors.FENCED_INSTANCE_ID); + + coordinator.commitOffsetsAsync(singletonMap(t1p, new OffsetAndMetadata(100L)), new MockCommitCallback()); + coordinator.invokeCompletedOffsetCommitCallbacks(); + } + + private ConsumerCoordinator prepareCoordinatorForCloseTest(final boolean useGroupManagement, + final boolean autoCommit, + final Optional groupInstanceId) { + rebalanceConfig = buildRebalanceConfig(groupInstanceId); + ConsumerCoordinator coordinator = buildCoordinator(rebalanceConfig, + new Metrics(), + assignors, + autoCommit, + subscriptions); + client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE)); + coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE)); + if (useGroupManagement) { + subscriptions.subscribe(singleton(topic1), rebalanceListener); + client.prepareResponse(joinGroupFollowerResponse(1, consumerId, "leader", Errors.NONE)); + client.prepareResponse(syncGroupResponse(singletonList(t1p), Errors.NONE)); + coordinator.joinGroupIfNeeded(time.timer(Long.MAX_VALUE)); + } else { + subscriptions.assignFromUser(singleton(t1p)); + } + + subscriptions.seek(t1p, 100); + coordinator.poll(time.timer(Long.MAX_VALUE)); + + return coordinator; + } + + private void makeCoordinatorUnknown(ConsumerCoordinator coordinator, Errors error) { + time.sleep(sessionTimeoutMs); + coordinator.sendHeartbeatRequest(); + client.prepareResponse(heartbeatResponse(error)); + time.sleep(sessionTimeoutMs); + consumerClient.poll(time.timer(0)); + assertTrue(coordinator.coordinatorUnknown()); + } + + private void closeVerifyTimeout(final ConsumerCoordinator coordinator, + final long closeTimeoutMs, + final long expectedMinTimeMs, + final long expectedMaxTimeMs) throws Exception { + ExecutorService executor = 
Executors.newSingleThreadExecutor(); + try { + boolean coordinatorUnknown = coordinator.coordinatorUnknown(); + // Run close on a different thread. Coordinator is locked by this thread, so it is + // not safe to use the coordinator from the main thread until the task completes. + Future future = executor.submit( + () -> coordinator.close(time.timer(Math.min(closeTimeoutMs, requestTimeoutMs)))); + // Wait for close to start. If coordinator is known, wait for close to queue + // at least one request. Otherwise, sleep for a short time. + if (!coordinatorUnknown) + client.waitForRequests(1, 1000); + else + Thread.sleep(200); + if (expectedMinTimeMs > 0) { + time.sleep(expectedMinTimeMs - 1); + try { + future.get(500, TimeUnit.MILLISECONDS); + fail("Close completed ungracefully without waiting for timeout"); + } catch (TimeoutException e) { + // Expected timeout + } + } + if (expectedMaxTimeMs >= 0) + time.sleep(expectedMaxTimeMs - expectedMinTimeMs + 2); + future.get(2000, TimeUnit.MILLISECONDS); + } finally { + executor.shutdownNow(); + } + } + + private void gracefulCloseTest(ConsumerCoordinator coordinator, boolean shouldLeaveGroup) { + final AtomicBoolean commitRequested = new AtomicBoolean(); + final AtomicBoolean leaveGroupRequested = new AtomicBoolean(); + client.prepareResponse(body -> { + commitRequested.set(true); + OffsetCommitRequest commitRequest = (OffsetCommitRequest) body; + return commitRequest.data().groupId().equals(groupId); + }, new OffsetCommitResponse(new OffsetCommitResponseData())); + client.prepareResponse(body -> { + leaveGroupRequested.set(true); + LeaveGroupRequest leaveRequest = (LeaveGroupRequest) body; + return leaveRequest.data().groupId().equals(groupId); + }, new LeaveGroupResponse(new LeaveGroupResponseData() + .setErrorCode(Errors.NONE.code()))); + + coordinator.close(); + assertTrue(commitRequested.get(), "Commit not requested"); + assertEquals(shouldLeaveGroup, leaveGroupRequested.get(), "leaveGroupRequested should be " + shouldLeaveGroup); + + if (shouldLeaveGroup) { + assertEquals(1, rebalanceListener.revokedCount); + assertEquals(singleton(t1p), rebalanceListener.revoked); + } + } + + private ConsumerCoordinator buildCoordinator(final GroupRebalanceConfig rebalanceConfig, + final Metrics metrics, + final List assignors, + final boolean autoCommitEnabled, + final SubscriptionState subscriptionState) { + return new ConsumerCoordinator( + rebalanceConfig, + new LogContext(), + consumerClient, + assignors, + metadata, + subscriptionState, + metrics, + consumerId + groupId, + time, + autoCommitEnabled, + autoCommitIntervalMs, + null, + false); + } + + private Collection getRevoked(final List owned, + final List assigned) { + switch (protocol) { + case EAGER: + return toSet(owned); + case COOPERATIVE: + final List revoked = new ArrayList<>(owned); + revoked.removeAll(assigned); + return toSet(revoked); + default: + throw new IllegalStateException("This should not happen"); + } + } + + private Collection getLost(final List owned) { + switch (protocol) { + case EAGER: + return emptySet(); + case COOPERATIVE: + return toSet(owned); + default: + throw new IllegalStateException("This should not happen"); + } + } + + private Collection getAdded(final List owned, + final List assigned) { + switch (protocol) { + case EAGER: + return toSet(assigned); + case COOPERATIVE: + final List added = new ArrayList<>(assigned); + added.removeAll(owned); + return toSet(added); + default: + throw new IllegalStateException("This should not happen"); + } + } + + private 
FindCoordinatorResponse groupCoordinatorResponse(Node node, Errors error) { + return FindCoordinatorResponse.prepareResponse(error, groupId, node); + } + + private HeartbeatResponse heartbeatResponse(Errors error) { + return new HeartbeatResponse(new HeartbeatResponseData().setErrorCode(error.code())); + } + + private JoinGroupResponse joinGroupLeaderResponse(int generationId, + String memberId, + Map> subscriptions, + Errors error) { + List metadata = new ArrayList<>(); + for (Map.Entry> subscriptionEntry : subscriptions.entrySet()) { + ConsumerPartitionAssignor.Subscription subscription = new ConsumerPartitionAssignor.Subscription(subscriptionEntry.getValue()); + ByteBuffer buf = ConsumerProtocol.serializeSubscription(subscription); + metadata.add(new JoinGroupResponseData.JoinGroupResponseMember() + .setMemberId(subscriptionEntry.getKey()) + .setMetadata(buf.array())); + } + + return new JoinGroupResponse( + new JoinGroupResponseData() + .setErrorCode(error.code()) + .setGenerationId(generationId) + .setProtocolName(partitionAssignor.name()) + .setLeader(memberId) + .setMemberId(memberId) + .setMembers(metadata) + ); + } + + private JoinGroupResponse joinGroupFollowerResponse(int generationId, String memberId, String leaderId, Errors error) { + return new JoinGroupResponse( + new JoinGroupResponseData() + .setErrorCode(error.code()) + .setGenerationId(generationId) + .setProtocolName(partitionAssignor.name()) + .setLeader(leaderId) + .setMemberId(memberId) + .setMembers(Collections.emptyList()) + ); + } + + private SyncGroupResponse syncGroupResponse(List partitions, Errors error) { + ByteBuffer buf = ConsumerProtocol.serializeAssignment(new ConsumerPartitionAssignor.Assignment(partitions)); + return new SyncGroupResponse( + new SyncGroupResponseData() + .setErrorCode(error.code()) + .setAssignment(Utils.toArray(buf)) + ); + } + + private OffsetCommitResponse offsetCommitResponse(Map responseData) { + return new OffsetCommitResponse(responseData); + } + + private OffsetFetchResponse offsetFetchResponse(Errors topLevelError) { + return new OffsetFetchResponse(topLevelError, Collections.emptyMap()); + } + + private OffsetFetchResponse offsetFetchResponse(TopicPartition tp, Errors partitionLevelError, String metadata, long offset) { + return offsetFetchResponse(tp, partitionLevelError, metadata, offset, Optional.empty()); + } + + private OffsetFetchResponse offsetFetchResponse(TopicPartition tp, Errors partitionLevelError, String metadata, long offset, Optional epoch) { + OffsetFetchResponse.PartitionData data = new OffsetFetchResponse.PartitionData(offset, + epoch, metadata, partitionLevelError); + return new OffsetFetchResponse(Errors.NONE, singletonMap(tp, data)); + } + + private OffsetCommitCallback callback(final AtomicBoolean success) { + return (offsets, exception) -> { + if (exception == null) + success.set(true); + }; + } + + private void joinAsFollowerAndReceiveAssignment(ConsumerCoordinator coordinator, + List assignment) { + client.prepareResponse(groupCoordinatorResponse(node, Errors.NONE)); + coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE)); + client.prepareResponse(joinGroupFollowerResponse(1, consumerId, "leader", Errors.NONE)); + client.prepareResponse(syncGroupResponse(assignment, Errors.NONE)); + coordinator.joinGroupIfNeeded(time.timer(Long.MAX_VALUE)); + } + + private void prepareOffsetCommitRequest(Map expectedOffsets, Errors error) { + prepareOffsetCommitRequest(expectedOffsets, error, false); + } + + private void 
prepareOffsetCommitRequestDisconnect(Map expectedOffsets) { + prepareOffsetCommitRequest(expectedOffsets, Errors.NONE, true); + } + + private void prepareOffsetCommitRequest(final Map expectedOffsets, + Errors error, + boolean disconnected) { + Map errors = partitionErrors(expectedOffsets.keySet(), error); + client.prepareResponse(offsetCommitRequestMatcher(expectedOffsets), offsetCommitResponse(errors), disconnected); + } + + private void prepareJoinAndSyncResponse(String consumerId, int generation, List subscription, List assignment) { + partitionAssignor.prepare(singletonMap(consumerId, assignment)); + client.prepareResponse( + joinGroupLeaderResponse( + generation, consumerId, singletonMap(consumerId, subscription), Errors.NONE)); + client.prepareResponse(body -> { + SyncGroupRequest sync = (SyncGroupRequest) body; + return sync.data().memberId().equals(consumerId) && + sync.data().generationId() == generation && + sync.groupAssignments().containsKey(consumerId); + }, syncGroupResponse(assignment, Errors.NONE)); + } + + private Map partitionErrors(Collection partitions, Errors error) { + final Map errors = new HashMap<>(); + for (TopicPartition partition : partitions) { + errors.put(partition, error); + } + return errors; + } + + private void respondToOffsetCommitRequest(final Map expectedOffsets, Errors error) { + Map errors = partitionErrors(expectedOffsets.keySet(), error); + client.respond(offsetCommitRequestMatcher(expectedOffsets), offsetCommitResponse(errors)); + } + + private MockClient.RequestMatcher offsetCommitRequestMatcher(final Map expectedOffsets) { + return body -> { + OffsetCommitRequest req = (OffsetCommitRequest) body; + Map offsets = req.offsets(); + if (offsets.size() != expectedOffsets.size()) + return false; + + for (Map.Entry expectedOffset : expectedOffsets.entrySet()) { + if (!offsets.containsKey(expectedOffset.getKey())) { + return false; + } else { + Long actualOffset = offsets.get(expectedOffset.getKey()); + if (!actualOffset.equals(expectedOffset.getValue())) { + return false; + } + } + } + return true; + }; + } + + private OffsetCommitCallback callback(final Map expectedOffsets, + final AtomicBoolean success) { + return (offsets, exception) -> { + if (expectedOffsets.equals(offsets) && exception == null) + success.set(true); + }; + } + + private static class MockCommitCallback implements OffsetCommitCallback { + public int invoked = 0; + public Exception exception = null; + + @Override + public void onComplete(Map offsets, Exception exception) { + invoked++; + this.exception = exception; + } + } +} diff --git a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/ConsumerInterceptorsTest.java b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/ConsumerInterceptorsTest.java new file mode 100644 index 0000000..19ac256 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/ConsumerInterceptorsTest.java @@ -0,0 +1,180 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients.consumer.internals; + + +import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.clients.consumer.ConsumerInterceptor; +import org.apache.kafka.clients.consumer.ConsumerRecords; +import org.apache.kafka.clients.consumer.OffsetAndMetadata; +import org.apache.kafka.common.KafkaException; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.header.internals.RecordHeaders; +import org.apache.kafka.common.record.TimestampType; +import org.junit.jupiter.api.Test; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.HashMap; +import java.util.Optional; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class ConsumerInterceptorsTest { + private final int filterPartition1 = 5; + private final int filterPartition2 = 6; + private final String topic = "test"; + private final int partition = 1; + private final TopicPartition tp = new TopicPartition(topic, partition); + private final TopicPartition filterTopicPart1 = new TopicPartition("test5", filterPartition1); + private final TopicPartition filterTopicPart2 = new TopicPartition("test6", filterPartition2); + private final ConsumerRecord consumerRecord = new ConsumerRecord<>(topic, partition, 0, 0L, + TimestampType.CREATE_TIME, 0, 0, 1, 1, new RecordHeaders(), Optional.empty()); + private int onCommitCount = 0; + private int onConsumeCount = 0; + + /** + * Test consumer interceptor that filters records in onConsume() intercept + */ + private class FilterConsumerInterceptor implements ConsumerInterceptor { + private int filterPartition; + private boolean throwExceptionOnConsume = false; + private boolean throwExceptionOnCommit = false; + + FilterConsumerInterceptor(int filterPartition) { + this.filterPartition = filterPartition; + } + + @Override + public void configure(Map configs) { + } + + @Override + public ConsumerRecords onConsume(ConsumerRecords records) { + onConsumeCount++; + if (throwExceptionOnConsume) + throw new KafkaException("Injected exception in FilterConsumerInterceptor.onConsume."); + + // filters out topic/partitions with partition == FILTER_PARTITION + Map>> recordMap = new HashMap<>(); + for (TopicPartition tp : records.partitions()) { + if (tp.partition() != filterPartition) + recordMap.put(tp, records.records(tp)); + } + return new ConsumerRecords(recordMap); + } + + @Override + public void onCommit(Map offsets) { + onCommitCount++; + if (throwExceptionOnCommit) + throw new KafkaException("Injected exception in FilterConsumerInterceptor.onCommit."); + } + + @Override + public void close() { + } + + // if 'on' is true, onConsume will always throw an exception + public void injectOnConsumeError(boolean on) { + throwExceptionOnConsume = on; + } + + // if 'on' is true, onConsume will always throw an exception + public void injectOnCommitError(boolean on) { + throwExceptionOnCommit = on; + } + } + + @Test + public void testOnConsumeChain() { + 
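+        // Chain two filtering interceptors configured for different partitions and verify that
+        // onConsume() is invoked on each of them in order.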
List> interceptorList = new ArrayList<>(); + // we are testing two different interceptors by configuring the same interceptor differently, which is not + // how it would be done in KafkaConsumer, but ok for testing interceptor callbacks + FilterConsumerInterceptor interceptor1 = new FilterConsumerInterceptor<>(filterPartition1); + FilterConsumerInterceptor interceptor2 = new FilterConsumerInterceptor<>(filterPartition2); + interceptorList.add(interceptor1); + interceptorList.add(interceptor2); + ConsumerInterceptors interceptors = new ConsumerInterceptors<>(interceptorList); + + // verify that onConsumer modifies ConsumerRecords + Map>> records = new HashMap<>(); + List> list1 = new ArrayList<>(); + list1.add(consumerRecord); + List> list2 = new ArrayList<>(); + list2.add(new ConsumerRecord<>(filterTopicPart1.topic(), filterTopicPart1.partition(), 0, 0L, + TimestampType.CREATE_TIME, 0, 0, 1, 1, new RecordHeaders(), Optional.empty())); + List> list3 = new ArrayList<>(); + list3.add(new ConsumerRecord<>(filterTopicPart2.topic(), filterTopicPart2.partition(), 0, 0L, TimestampType.CREATE_TIME, + 0, 0, 1, 1, new RecordHeaders(), Optional.empty())); + records.put(tp, list1); + records.put(filterTopicPart1, list2); + records.put(filterTopicPart2, list3); + ConsumerRecords consumerRecords = new ConsumerRecords<>(records); + ConsumerRecords interceptedRecords = interceptors.onConsume(consumerRecords); + assertEquals(1, interceptedRecords.count()); + assertTrue(interceptedRecords.partitions().contains(tp)); + assertFalse(interceptedRecords.partitions().contains(filterTopicPart1)); + assertFalse(interceptedRecords.partitions().contains(filterTopicPart2)); + assertEquals(2, onConsumeCount); + + // verify that even if one of the intermediate interceptors throws an exception, all interceptors' onConsume are called + interceptor1.injectOnConsumeError(true); + ConsumerRecords partInterceptedRecs = interceptors.onConsume(consumerRecords); + assertEquals(2, partInterceptedRecs.count()); + assertTrue(partInterceptedRecs.partitions().contains(filterTopicPart1)); // since interceptor1 threw exception + assertFalse(partInterceptedRecs.partitions().contains(filterTopicPart2)); // interceptor2 should still be called + assertEquals(4, onConsumeCount); + + // if all interceptors throw an exception, records should be unmodified + interceptor2.injectOnConsumeError(true); + ConsumerRecords noneInterceptedRecs = interceptors.onConsume(consumerRecords); + assertEquals(noneInterceptedRecs, consumerRecords); + assertEquals(3, noneInterceptedRecs.count()); + assertEquals(6, onConsumeCount); + + interceptors.close(); + } + + @Test + public void testOnCommitChain() { + List> interceptorList = new ArrayList<>(); + // we are testing two different interceptors by configuring the same interceptor differently, which is not + // how it would be done in KafkaConsumer, but ok for testing interceptor callbacks + FilterConsumerInterceptor interceptor1 = new FilterConsumerInterceptor<>(filterPartition1); + FilterConsumerInterceptor interceptor2 = new FilterConsumerInterceptor<>(filterPartition2); + interceptorList.add(interceptor1); + interceptorList.add(interceptor2); + ConsumerInterceptors interceptors = new ConsumerInterceptors<>(interceptorList); + + // verify that onCommit is called for all interceptors in the chain + Map offsets = new HashMap<>(); + offsets.put(tp, new OffsetAndMetadata(0)); + interceptors.onCommit(offsets); + assertEquals(2, onCommitCount); + + // verify that even if one of the interceptors throws an exception, 
all interceptors' onCommit are called + interceptor1.injectOnCommitError(true); + interceptors.onCommit(offsets); + assertEquals(4, onCommitCount); + + interceptors.close(); + } +} diff --git a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/ConsumerMetadataTest.java b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/ConsumerMetadataTest.java new file mode 100644 index 0000000..02ab81c --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/ConsumerMetadataTest.java @@ -0,0 +1,185 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients.consumer.internals; + +import org.apache.kafka.clients.consumer.OffsetResetStrategy; +import org.apache.kafka.common.Node; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.Uuid; +import org.apache.kafka.common.internals.ClusterResourceListeners; +import org.apache.kafka.common.protocol.Errors; +import org.apache.kafka.common.requests.MetadataRequest; +import org.apache.kafka.common.requests.MetadataResponse; +import org.apache.kafka.common.requests.RequestTestUtils; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.common.utils.Utils; +import org.junit.jupiter.api.Test; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.regex.Pattern; + +import static java.util.Collections.singleton; +import static java.util.Collections.singletonList; +import static java.util.Collections.singletonMap; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class ConsumerMetadataTest { + + private final Node node = new Node(1, "localhost", 9092); + private final SubscriptionState subscription = new SubscriptionState(new LogContext(), OffsetResetStrategy.EARLIEST); + private final Time time = new MockTime(); + + @Test + public void testPatternSubscriptionNoInternalTopics() { + testPatternSubscription(false); + } + + @Test + public void testPatternSubscriptionIncludeInternalTopics() { + testPatternSubscription(true); + } + + private void testPatternSubscription(boolean includeInternalTopics) { + subscription.subscribe(Pattern.compile("__.*"), new NoOpConsumerRebalanceListener()); + ConsumerMetadata metadata = newConsumerMetadata(includeInternalTopics); + + MetadataRequest.Builder builder = metadata.newMetadataRequestBuilder(); + assertTrue(builder.isAllTopics()); + + List topics 
= new ArrayList<>(); + topics.add(topicMetadata("__consumer_offsets", true)); + topics.add(topicMetadata("__matching_topic", false)); + topics.add(topicMetadata("non_matching_topic", false)); + + MetadataResponse response = RequestTestUtils.metadataResponse(singletonList(node), + "clusterId", node.id(), topics); + metadata.updateWithCurrentRequestVersion(response, false, time.milliseconds()); + + if (includeInternalTopics) + assertEquals(Utils.mkSet("__matching_topic", "__consumer_offsets"), metadata.fetch().topics()); + else + assertEquals(Collections.singleton("__matching_topic"), metadata.fetch().topics()); + } + + @Test + public void testUserAssignment() { + subscription.assignFromUser(Utils.mkSet( + new TopicPartition("foo", 0), + new TopicPartition("bar", 0), + new TopicPartition("__consumer_offsets", 0))); + testBasicSubscription(Utils.mkSet("foo", "bar"), Utils.mkSet("__consumer_offsets")); + + subscription.assignFromUser(Utils.mkSet( + new TopicPartition("baz", 0), + new TopicPartition("__consumer_offsets", 0))); + testBasicSubscription(Utils.mkSet("baz"), Utils.mkSet("__consumer_offsets")); + } + + @Test + public void testNormalSubscription() { + subscription.subscribe(Utils.mkSet("foo", "bar", "__consumer_offsets"), new NoOpConsumerRebalanceListener()); + subscription.groupSubscribe(Utils.mkSet("baz", "foo", "bar", "__consumer_offsets")); + testBasicSubscription(Utils.mkSet("foo", "bar", "baz"), Utils.mkSet("__consumer_offsets")); + + subscription.resetGroupSubscription(); + testBasicSubscription(Utils.mkSet("foo", "bar"), Utils.mkSet("__consumer_offsets")); + } + + @Test + public void testTransientTopics() { + Map topicIds = new HashMap<>(); + topicIds.put("foo", Uuid.randomUuid()); + subscription.subscribe(singleton("foo"), new NoOpConsumerRebalanceListener()); + ConsumerMetadata metadata = newConsumerMetadata(false); + metadata.updateWithCurrentRequestVersion(RequestTestUtils.metadataUpdateWithIds(1, singletonMap("foo", 1), topicIds), false, time.milliseconds()); + assertEquals(topicIds.get("foo"), metadata.topicIds().get("foo")); + assertFalse(metadata.updateRequested()); + + metadata.addTransientTopics(singleton("foo")); + assertFalse(metadata.updateRequested()); + + metadata.addTransientTopics(singleton("bar")); + assertTrue(metadata.updateRequested()); + + Map topicPartitionCounts = new HashMap<>(); + topicPartitionCounts.put("foo", 1); + topicPartitionCounts.put("bar", 1); + topicIds.put("bar", Uuid.randomUuid()); + metadata.updateWithCurrentRequestVersion(RequestTestUtils.metadataUpdateWithIds(1, topicPartitionCounts, topicIds), false, time.milliseconds()); + Map metadataTopicIds = metadata.topicIds(); + topicIds.forEach((topicName, topicId) -> assertEquals(topicId, metadataTopicIds.get(topicName))); + assertFalse(metadata.updateRequested()); + + assertEquals(Utils.mkSet("foo", "bar"), new HashSet<>(metadata.fetch().topics())); + + metadata.clearTransientTopics(); + topicIds.remove("bar"); + metadata.updateWithCurrentRequestVersion(RequestTestUtils.metadataUpdateWithIds(1, topicPartitionCounts, topicIds), false, time.milliseconds()); + assertEquals(singleton("foo"), new HashSet<>(metadata.fetch().topics())); + assertEquals(topicIds.get("foo"), metadata.topicIds().get("foo")); + assertEquals(topicIds.get("bar"), null); + } + + private void testBasicSubscription(Set expectedTopics, Set expectedInternalTopics) { + Set allTopics = new HashSet<>(); + allTopics.addAll(expectedTopics); + allTopics.addAll(expectedInternalTopics); + + ConsumerMetadata metadata = 
newConsumerMetadata(false); + + MetadataRequest.Builder builder = metadata.newMetadataRequestBuilder(); + assertEquals(allTopics, new HashSet<>(builder.topics())); + + List topics = new ArrayList<>(); + for (String expectedTopic : expectedTopics) + topics.add(topicMetadata(expectedTopic, false)); + for (String expectedInternalTopic : expectedInternalTopics) + topics.add(topicMetadata(expectedInternalTopic, true)); + + MetadataResponse response = RequestTestUtils.metadataResponse(singletonList(node), + "clusterId", node.id(), topics); + metadata.updateWithCurrentRequestVersion(response, false, time.milliseconds()); + + assertEquals(allTopics, metadata.fetch().topics()); + } + + private MetadataResponse.TopicMetadata topicMetadata(String topic, boolean isInternal) { + MetadataResponse.PartitionMetadata partitionMetadata = new MetadataResponse.PartitionMetadata(Errors.NONE, + new TopicPartition(topic, 0), Optional.of(node.id()), Optional.of(5), + singletonList(node.id()), singletonList(node.id()), singletonList(node.id())); + return new MetadataResponse.TopicMetadata(Errors.NONE, topic, isInternal, singletonList(partitionMetadata)); + } + + private ConsumerMetadata newConsumerMetadata(boolean includeInternalTopics) { + long refreshBackoffMs = 50; + long expireMs = 50000; + return new ConsumerMetadata(refreshBackoffMs, expireMs, includeInternalTopics, false, + subscription, new LogContext(), new ClusterResourceListeners()); + } + +} diff --git a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/ConsumerNetworkClientTest.java b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/ConsumerNetworkClientTest.java new file mode 100644 index 0000000..df74f7e --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/ConsumerNetworkClientTest.java @@ -0,0 +1,429 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.clients.consumer.internals; + +import org.apache.kafka.clients.ClientResponse; +import org.apache.kafka.clients.Metadata; +import org.apache.kafka.clients.MockClient; +import org.apache.kafka.clients.NetworkClient; +import org.apache.kafka.common.Cluster; +import org.apache.kafka.common.KafkaException; +import org.apache.kafka.common.Node; +import org.apache.kafka.common.errors.AuthenticationException; +import org.apache.kafka.common.errors.DisconnectException; +import org.apache.kafka.common.errors.InvalidTopicException; +import org.apache.kafka.common.errors.TimeoutException; +import org.apache.kafka.common.errors.TopicAuthorizationException; +import org.apache.kafka.common.errors.WakeupException; +import org.apache.kafka.common.internals.ClusterResourceListeners; +import org.apache.kafka.common.message.HeartbeatRequestData; +import org.apache.kafka.common.message.HeartbeatResponseData; +import org.apache.kafka.common.protocol.Errors; +import org.apache.kafka.common.requests.HeartbeatRequest; +import org.apache.kafka.common.requests.HeartbeatResponse; +import org.apache.kafka.common.requests.MetadataResponse; +import org.apache.kafka.common.requests.RequestTestUtils; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.test.TestUtils; +import org.junit.jupiter.api.Test; + +import java.time.Duration; +import java.util.Collections; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; +import static org.mockito.ArgumentMatchers.anyLong; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +public class ConsumerNetworkClientTest { + + private String topicName = "test"; + private MockTime time = new MockTime(1); + private Cluster cluster = TestUtils.singletonCluster(topicName, 1); + private Node node = cluster.nodes().get(0); + private Metadata metadata = new Metadata(100, 50000, new LogContext(), + new ClusterResourceListeners()); + private MockClient client = new MockClient(time, metadata); + private ConsumerNetworkClient consumerClient = new ConsumerNetworkClient(new LogContext(), + client, metadata, time, 100, 1000, Integer.MAX_VALUE); + + @Test + public void send() { + client.prepareResponse(heartbeatResponse(Errors.NONE)); + RequestFuture future = consumerClient.send(node, heartbeat()); + assertEquals(1, consumerClient.pendingRequestCount()); + assertEquals(1, consumerClient.pendingRequestCount(node)); + assertFalse(future.isDone()); + + consumerClient.poll(future); + assertTrue(future.isDone()); + assertTrue(future.succeeded()); + + ClientResponse clientResponse = future.value(); + HeartbeatResponse response = (HeartbeatResponse) clientResponse.responseBody(); + assertEquals(Errors.NONE, response.error()); + } + + @Test + public void sendWithinBackoffPeriodAfterAuthenticationFailure() { + client.authenticationFailed(node, 300); + client.prepareResponse(heartbeatResponse(Errors.NONE)); + final RequestFuture future = consumerClient.send(node, heartbeat()); + consumerClient.poll(future); + assertTrue(future.failed()); + 
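        /*
         * What this test exercises, condensed (the 300 ms figure mirrors the backoff
         * passed to client.authenticationFailed above):
         *
         *   RequestFuture<ClientResponse> f = consumerClient.send(node, heartbeat());
         *   consumerClient.poll(f);                        // drive I/O until f completes
         *   if (f.failed() && f.exception() instanceof AuthenticationException) {
         *       // the node stays quarantined for the backoff window (300 ms here),
         *       // so further sends fail fast with the same exception until it elapses
         *   }
         */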
assertTrue(future.exception() instanceof AuthenticationException, "Expected only an authentication error."); + + time.sleep(30); // wait less than the backoff period + assertTrue(client.connectionFailed(node)); + + final RequestFuture future2 = consumerClient.send(node, heartbeat()); + consumerClient.poll(future2); + assertTrue(future2.failed()); + assertTrue(future2.exception() instanceof AuthenticationException, "Expected only an authentication error."); + } + + @Test + public void multiSend() { + client.prepareResponse(heartbeatResponse(Errors.NONE)); + client.prepareResponse(heartbeatResponse(Errors.NONE)); + RequestFuture future1 = consumerClient.send(node, heartbeat()); + RequestFuture future2 = consumerClient.send(node, heartbeat()); + assertEquals(2, consumerClient.pendingRequestCount()); + assertEquals(2, consumerClient.pendingRequestCount(node)); + + consumerClient.awaitPendingRequests(node, time.timer(Long.MAX_VALUE)); + assertTrue(future1.succeeded()); + assertTrue(future2.succeeded()); + } + + @Test + public void testDisconnectWithUnsentRequests() { + RequestFuture future = consumerClient.send(node, heartbeat()); + assertTrue(consumerClient.hasPendingRequests(node)); + assertFalse(client.hasInFlightRequests(node.idString())); + consumerClient.disconnectAsync(node); + consumerClient.pollNoWakeup(); + assertTrue(future.failed()); + assertTrue(future.exception() instanceof DisconnectException); + } + + @Test + public void testDisconnectWithInFlightRequests() { + RequestFuture future = consumerClient.send(node, heartbeat()); + consumerClient.pollNoWakeup(); + assertTrue(consumerClient.hasPendingRequests(node)); + assertTrue(client.hasInFlightRequests(node.idString())); + consumerClient.disconnectAsync(node); + consumerClient.pollNoWakeup(); + assertTrue(future.failed()); + assertTrue(future.exception() instanceof DisconnectException); + } + + @Test + public void testTimeoutUnsentRequest() { + // Delay connection to the node so that the request remains unsent + client.delayReady(node, 1000); + + RequestFuture future = consumerClient.send(node, heartbeat(), 500); + consumerClient.pollNoWakeup(); + + // Ensure the request is pending, but hasn't been sent + assertTrue(consumerClient.hasPendingRequests()); + assertFalse(client.hasInFlightRequests()); + + time.sleep(501); + consumerClient.pollNoWakeup(); + + assertFalse(consumerClient.hasPendingRequests()); + assertTrue(future.failed()); + assertTrue(future.exception() instanceof TimeoutException); + } + + @Test + public void doNotBlockIfPollConditionIsSatisfied() { + NetworkClient mockNetworkClient = mock(NetworkClient.class); + ConsumerNetworkClient consumerClient = new ConsumerNetworkClient(new LogContext(), + mockNetworkClient, metadata, time, 100, 1000, Integer.MAX_VALUE); + + // expect poll, but with no timeout + consumerClient.poll(time.timer(Long.MAX_VALUE), () -> false); + verify(mockNetworkClient).poll(eq(0L), anyLong()); + } + + @Test + public void blockWhenPollConditionNotSatisfied() { + long timeout = 4000L; + + NetworkClient mockNetworkClient = mock(NetworkClient.class); + ConsumerNetworkClient consumerClient = new ConsumerNetworkClient(new LogContext(), + mockNetworkClient, metadata, time, 100, 1000, Integer.MAX_VALUE); + + when(mockNetworkClient.inFlightRequestCount()).thenReturn(1); + consumerClient.poll(time.timer(timeout), () -> true); + verify(mockNetworkClient).poll(eq(timeout), anyLong()); + } + + @Test + public void blockOnlyForRetryBackoffIfNoInflightRequests() { + long retryBackoffMs = 100L; + + NetworkClient 
mockNetworkClient = mock(NetworkClient.class); + ConsumerNetworkClient consumerClient = new ConsumerNetworkClient(new LogContext(), + mockNetworkClient, metadata, time, retryBackoffMs, 1000, Integer.MAX_VALUE); + + when(mockNetworkClient.inFlightRequestCount()).thenReturn(0); + + consumerClient.poll(time.timer(Long.MAX_VALUE), () -> true); + + verify(mockNetworkClient).poll(eq(retryBackoffMs), anyLong()); + } + + @Test + public void wakeup() { + RequestFuture future = consumerClient.send(node, heartbeat()); + consumerClient.wakeup(); + try { + consumerClient.poll(time.timer(0)); + fail(); + } catch (WakeupException e) { + } + + client.respond(heartbeatResponse(Errors.NONE)); + consumerClient.poll(future); + assertTrue(future.isDone()); + } + + @Test + public void testDisconnectWakesUpPoll() throws Exception { + final RequestFuture future = consumerClient.send(node, heartbeat()); + + client.enableBlockingUntilWakeup(1); + Thread t = new Thread() { + @Override + public void run() { + consumerClient.poll(future); + } + }; + t.start(); + + consumerClient.disconnectAsync(node); + t.join(); + assertTrue(future.failed()); + assertTrue(future.exception() instanceof DisconnectException); + } + + @Test + public void testAuthenticationExceptionPropagatedFromMetadata() { + metadata.fatalError(new AuthenticationException("Authentication failed")); + try { + consumerClient.poll(time.timer(Duration.ZERO)); + fail("Expected authentication error thrown"); + } catch (AuthenticationException e) { + // After the exception is raised, it should have been cleared + metadata.maybeThrowAnyException(); + } + } + + @Test + public void testInvalidTopicExceptionPropagatedFromMetadata() { + MetadataResponse metadataResponse = RequestTestUtils.metadataUpdateWith("clusterId", 1, + Collections.singletonMap("topic", Errors.INVALID_TOPIC_EXCEPTION), Collections.emptyMap()); + metadata.updateWithCurrentRequestVersion(metadataResponse, false, time.milliseconds()); + assertThrows(InvalidTopicException.class, () -> consumerClient.poll(time.timer(Duration.ZERO))); + } + + @Test + public void testTopicAuthorizationExceptionPropagatedFromMetadata() { + MetadataResponse metadataResponse = RequestTestUtils.metadataUpdateWith("clusterId", 1, + Collections.singletonMap("topic", Errors.TOPIC_AUTHORIZATION_FAILED), Collections.emptyMap()); + metadata.updateWithCurrentRequestVersion(metadataResponse, false, time.milliseconds()); + assertThrows(TopicAuthorizationException.class, () -> consumerClient.poll(time.timer(Duration.ZERO))); + } + + @Test + public void testMetadataFailurePropagated() { + KafkaException metadataException = new KafkaException(); + metadata.fatalError(metadataException); + try { + consumerClient.poll(time.timer(Duration.ZERO)); + fail("Expected poll to throw exception"); + } catch (Exception e) { + assertEquals(metadataException, e); + } + } + + @Test + public void testFutureCompletionOutsidePoll() throws Exception { + // Tests the scenario in which the request that is being awaited in one thread + // is received and completed in another thread. 
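        /*
         * Blocking in poll(future) is one way to observe completion; RequestFuture
         * also accepts a listener that fires on whichever thread completes it.
         * Sketch only (the "done" flag is an assumed AtomicBoolean, not part of this test):
         *
         *   final AtomicBoolean done = new AtomicBoolean(false);
         *   future.addListener(new RequestFutureListener<ClientResponse>() {
         *       @Override
         *       public void onSuccess(ClientResponse value) { done.set(true); }
         *       @Override
         *       public void onFailure(RuntimeException e) { done.set(true); }
         *   });
         */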
+ + final RequestFuture future = consumerClient.send(node, heartbeat()); + consumerClient.pollNoWakeup(); // dequeue and send the request + + client.enableBlockingUntilWakeup(2); + Thread t1 = new Thread() { + @Override + public void run() { + consumerClient.pollNoWakeup(); + } + }; + t1.start(); + + // Sleep a little so that t1 is blocking in poll + Thread.sleep(50); + + Thread t2 = new Thread() { + @Override + public void run() { + consumerClient.poll(future); + } + }; + t2.start(); + + // Sleep a little so that t2 is awaiting the network client lock + Thread.sleep(50); + + // Simulate a network response and return from the poll in t1 + client.respond(heartbeatResponse(Errors.NONE)); + client.wakeup(); + + // Both threads should complete since t1 should wakeup t2 + t1.join(); + t2.join(); + assertTrue(future.succeeded()); + } + + @Test + public void testAwaitForMetadataUpdateWithTimeout() { + assertFalse(consumerClient.awaitMetadataUpdate(time.timer(10L))); + } + + @Test + public void sendExpiry() { + int requestTimeoutMs = 10; + final AtomicBoolean isReady = new AtomicBoolean(); + final AtomicBoolean disconnected = new AtomicBoolean(); + client = new MockClient(time, metadata) { + @Override + public boolean ready(Node node, long now) { + if (isReady.get()) + return super.ready(node, now); + else + return false; + } + @Override + public boolean connectionFailed(Node node) { + return disconnected.get(); + } + }; + // Queue first send, sleep long enough for this to expire and then queue second send + consumerClient = new ConsumerNetworkClient(new LogContext(), client, metadata, time, 100, requestTimeoutMs, Integer.MAX_VALUE); + RequestFuture future1 = consumerClient.send(node, heartbeat()); + assertEquals(1, consumerClient.pendingRequestCount()); + assertEquals(1, consumerClient.pendingRequestCount(node)); + assertFalse(future1.isDone()); + + time.sleep(requestTimeoutMs + 1); + RequestFuture future2 = consumerClient.send(node, heartbeat()); + assertEquals(2, consumerClient.pendingRequestCount()); + assertEquals(2, consumerClient.pendingRequestCount(node)); + assertFalse(future2.isDone()); + + // First send should have expired and second send still pending + consumerClient.poll(time.timer(0)); + assertTrue(future1.isDone()); + assertFalse(future1.succeeded()); + assertEquals(1, consumerClient.pendingRequestCount()); + assertEquals(1, consumerClient.pendingRequestCount(node)); + assertFalse(future2.isDone()); + + // Enable send, the un-expired send should succeed on poll + isReady.set(true); + client.prepareResponse(heartbeatResponse(Errors.NONE)); + consumerClient.poll(future2); + ClientResponse clientResponse = future2.value(); + HeartbeatResponse response = (HeartbeatResponse) clientResponse.responseBody(); + assertEquals(Errors.NONE, response.error()); + + // Disable ready flag to delay send and queue another send. 
Disconnection should remove pending send + isReady.set(false); + RequestFuture future3 = consumerClient.send(node, heartbeat()); + assertEquals(1, consumerClient.pendingRequestCount()); + assertEquals(1, consumerClient.pendingRequestCount(node)); + disconnected.set(true); + consumerClient.poll(time.timer(0)); + assertTrue(future3.isDone()); + assertFalse(future3.succeeded()); + assertEquals(0, consumerClient.pendingRequestCount()); + assertEquals(0, consumerClient.pendingRequestCount(node)); + } + + @Test + public void testTrySend() { + final AtomicBoolean isReady = new AtomicBoolean(); + final AtomicInteger checkCount = new AtomicInteger(); + client = new MockClient(time, metadata) { + @Override + public boolean ready(Node node, long now) { + checkCount.incrementAndGet(); + if (isReady.get()) + return super.ready(node, now); + else + return false; + } + }; + consumerClient = new ConsumerNetworkClient(new LogContext(), client, metadata, time, 100, 10, Integer.MAX_VALUE); + consumerClient.send(node, heartbeat()); + consumerClient.send(node, heartbeat()); + assertEquals(2, consumerClient.pendingRequestCount(node)); + assertEquals(0, client.inFlightRequestCount(node.idString())); + + consumerClient.trySend(time.milliseconds()); + // only check one time when the node doesn't ready + assertEquals(1, checkCount.getAndSet(0)); + assertEquals(2, consumerClient.pendingRequestCount(node)); + assertEquals(0, client.inFlightRequestCount(node.idString())); + + isReady.set(true); + consumerClient.trySend(time.milliseconds()); + // check node ready or not for every request + assertEquals(2, checkCount.getAndSet(0)); + assertEquals(2, consumerClient.pendingRequestCount(node)); + assertEquals(2, client.inFlightRequestCount(node.idString())); + } + + private HeartbeatRequest.Builder heartbeat() { + return new HeartbeatRequest.Builder(new HeartbeatRequestData() + .setGroupId("group") + .setGenerationId(1) + .setMemberId("memberId")); + } + + private HeartbeatResponse heartbeatResponse(Errors error) { + return new HeartbeatResponse(new HeartbeatResponseData().setErrorCode(error.code())); + } + +} diff --git a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/ConsumerProtocolTest.java b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/ConsumerProtocolTest.java new file mode 100644 index 0000000..a2d5120 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/ConsumerProtocolTest.java @@ -0,0 +1,250 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.clients.consumer.internals; + +import org.apache.kafka.clients.consumer.ConsumerPartitionAssignor.Assignment; +import org.apache.kafka.clients.consumer.ConsumerPartitionAssignor.Subscription; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.message.ConsumerProtocolAssignment; +import org.apache.kafka.common.message.ConsumerProtocolSubscription; +import org.apache.kafka.common.protocol.types.ArrayOf; +import org.apache.kafka.common.protocol.types.Field; +import org.apache.kafka.common.protocol.types.Schema; +import org.apache.kafka.common.protocol.types.Struct; +import org.apache.kafka.common.protocol.types.Type; +import org.junit.jupiter.api.Test; + +import java.nio.ByteBuffer; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Optional; + +import static org.apache.kafka.test.TestUtils.toSet; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class ConsumerProtocolTest { + + private final TopicPartition tp1 = new TopicPartition("foo", 1); + private final TopicPartition tp2 = new TopicPartition("bar", 2); + private final Optional groupInstanceId = Optional.of("instance.id"); + + @Test + public void serializeDeserializeSubscriptionAllVersions() { + List ownedPartitions = Arrays.asList( + new TopicPartition("foo", 0), + new TopicPartition("bar", 0)); + Subscription subscription = new Subscription(Arrays.asList("foo", "bar"), + ByteBuffer.wrap("hello".getBytes()), ownedPartitions); + + for (short version = ConsumerProtocolSubscription.LOWEST_SUPPORTED_VERSION; version <= ConsumerProtocolSubscription.HIGHEST_SUPPORTED_VERSION; version++) { + ByteBuffer buffer = ConsumerProtocol.serializeSubscription(subscription, version); + Subscription parsedSubscription = ConsumerProtocol.deserializeSubscription(buffer); + + assertEquals(toSet(subscription.topics()), toSet(parsedSubscription.topics())); + assertEquals(subscription.userData(), parsedSubscription.userData()); + assertFalse(parsedSubscription.groupInstanceId().isPresent()); + + if (version >= 1) { + assertEquals(toSet(subscription.ownedPartitions()), toSet(parsedSubscription.ownedPartitions())); + } else { + assertEquals(Collections.emptyList(), parsedSubscription.ownedPartitions()); + } + } + } + + @Test + public void serializeDeserializeMetadata() { + Subscription subscription = new Subscription(Arrays.asList("foo", "bar"), ByteBuffer.wrap(new byte[0])); + ByteBuffer buffer = ConsumerProtocol.serializeSubscription(subscription); + Subscription parsedSubscription = ConsumerProtocol.deserializeSubscription(buffer); + assertEquals(toSet(subscription.topics()), toSet(parsedSubscription.topics())); + assertEquals(0, parsedSubscription.userData().limit()); + assertFalse(parsedSubscription.groupInstanceId().isPresent()); + } + + @Test + public void serializeDeserializeMetadataAndGroupInstanceId() { + Subscription subscription = new Subscription(Arrays.asList("foo", "bar"), ByteBuffer.wrap(new byte[0])); + ByteBuffer buffer = ConsumerProtocol.serializeSubscription(subscription); + + Subscription parsedSubscription = ConsumerProtocol.deserializeSubscription(buffer); + parsedSubscription.setGroupInstanceId(groupInstanceId); + assertEquals(toSet(subscription.topics()), toSet(parsedSubscription.topics())); + assertEquals(0, 
parsedSubscription.userData().limit()); + assertEquals(groupInstanceId, parsedSubscription.groupInstanceId()); + } + + @Test + public void serializeDeserializeNullSubscriptionUserData() { + Subscription subscription = new Subscription(Arrays.asList("foo", "bar"), null); + ByteBuffer buffer = ConsumerProtocol.serializeSubscription(subscription); + Subscription parsedSubscription = ConsumerProtocol.deserializeSubscription(buffer); + assertEquals(toSet(subscription.topics()), toSet(parsedSubscription.topics())); + assertNull(parsedSubscription.userData()); + } + + @Test + public void serializeSubscriptionShouldOrderTopics() { + assertEquals( + ConsumerProtocol.serializeSubscription( + new Subscription(Arrays.asList("foo", "bar"), null, Arrays.asList(tp1, tp2)) + ), + ConsumerProtocol.serializeSubscription( + new Subscription(Arrays.asList("bar", "foo"), null, Arrays.asList(tp1, tp2)) + ) + ); + } + + @Test + public void serializeSubscriptionShouldOrderOwnedPartitions() { + assertEquals( + ConsumerProtocol.serializeSubscription( + new Subscription(Arrays.asList("foo", "bar"), null, Arrays.asList(tp1, tp2)) + ), + ConsumerProtocol.serializeSubscription( + new Subscription(Arrays.asList("foo", "bar"), null, Arrays.asList(tp2, tp1)) + ) + ); + } + + @Test + public void deserializeOldSubscriptionVersion() { + Subscription subscription = new Subscription(Arrays.asList("foo", "bar"), null); + ByteBuffer buffer = ConsumerProtocol.serializeSubscription(subscription, (short) 0); + Subscription parsedSubscription = ConsumerProtocol.deserializeSubscription(buffer); + assertEquals(toSet(parsedSubscription.topics()), toSet(parsedSubscription.topics())); + assertNull(parsedSubscription.userData()); + assertTrue(parsedSubscription.ownedPartitions().isEmpty()); + } + + @Test + public void deserializeNewSubscriptionWithOldVersion() { + Subscription subscription = new Subscription(Arrays.asList("foo", "bar"), null, Collections.singletonList(tp2)); + ByteBuffer buffer = ConsumerProtocol.serializeSubscription(subscription); + // ignore the version assuming it is the old byte code, as it will blindly deserialize as V0 + ConsumerProtocol.deserializeVersion(buffer); + Subscription parsedSubscription = ConsumerProtocol.deserializeSubscription(buffer, (short) 0); + assertEquals(toSet(subscription.topics()), toSet(parsedSubscription.topics())); + assertNull(parsedSubscription.userData()); + assertTrue(parsedSubscription.ownedPartitions().isEmpty()); + assertFalse(parsedSubscription.groupInstanceId().isPresent()); + } + + @Test + public void deserializeFutureSubscriptionVersion() { + // verify that a new version which adds a field is still parseable + short version = 100; + + Schema subscriptionSchemaV100 = new Schema( + new Field("topics", new ArrayOf(Type.STRING)), + new Field("user_data", Type.NULLABLE_BYTES), + new Field("owned_partitions", new ArrayOf( + ConsumerProtocolSubscription.TopicPartition.SCHEMA_1)), + new Field("foo", Type.STRING)); + + Struct subscriptionV100 = new Struct(subscriptionSchemaV100); + subscriptionV100.set("topics", new Object[]{"topic"}); + subscriptionV100.set("user_data", ByteBuffer.wrap(new byte[0])); + subscriptionV100.set("owned_partitions", new Object[]{new Struct( + ConsumerProtocolSubscription.TopicPartition.SCHEMA_1) + .set("topic", tp2.topic()) + .set("partitions", new Object[]{tp2.partition()})}); + subscriptionV100.set("foo", "bar"); + + Struct headerV100 = new Struct(new Schema(new Field("version", Type.INT16))); + headerV100.set("version", version); + + ByteBuffer buffer = 
ByteBuffer.allocate(subscriptionV100.sizeOf() + headerV100.sizeOf()); + headerV100.writeTo(buffer); + subscriptionV100.writeTo(buffer); + + buffer.flip(); + + Subscription subscription = ConsumerProtocol.deserializeSubscription(buffer); + subscription.setGroupInstanceId(groupInstanceId); + assertEquals(Collections.singleton("topic"), toSet(subscription.topics())); + assertEquals(Collections.singleton(tp2), toSet(subscription.ownedPartitions())); + assertEquals(groupInstanceId, subscription.groupInstanceId()); + } + + @Test + public void serializeDeserializeAssignmentAllVersions() { + List partitions = Arrays.asList(tp1, tp2); + Assignment assignment = new Assignment(partitions, ByteBuffer.wrap("hello".getBytes())); + + for (short version = ConsumerProtocolAssignment.LOWEST_SUPPORTED_VERSION; version <= ConsumerProtocolAssignment.HIGHEST_SUPPORTED_VERSION; version++) { + ByteBuffer buffer = ConsumerProtocol.serializeAssignment(assignment, version); + Assignment parsedAssignment = ConsumerProtocol.deserializeAssignment(buffer); + assertEquals(toSet(partitions), toSet(parsedAssignment.partitions())); + assertEquals(assignment.userData(), parsedAssignment.userData()); + } + } + + @Test + public void serializeDeserializeAssignment() { + List partitions = Arrays.asList(tp1, tp2); + ByteBuffer buffer = ConsumerProtocol.serializeAssignment(new Assignment(partitions, ByteBuffer.wrap(new byte[0]))); + Assignment parsedAssignment = ConsumerProtocol.deserializeAssignment(buffer); + assertEquals(toSet(partitions), toSet(parsedAssignment.partitions())); + assertEquals(0, parsedAssignment.userData().limit()); + } + + @Test + public void deserializeNullAssignmentUserData() { + List partitions = Arrays.asList(tp1, tp2); + ByteBuffer buffer = ConsumerProtocol.serializeAssignment(new Assignment(partitions, null)); + Assignment parsedAssignment = ConsumerProtocol.deserializeAssignment(buffer); + assertEquals(toSet(partitions), toSet(parsedAssignment.partitions())); + assertNull(parsedAssignment.userData()); + } + + @Test + public void deserializeFutureAssignmentVersion() { + // verify that a new version which adds a field is still parseable + short version = 100; + + Schema assignmentSchemaV100 = new Schema( + new Field("assigned_partitions", new ArrayOf( + ConsumerProtocolAssignment.TopicPartition.SCHEMA_0)), + new Field("user_data", Type.BYTES), + new Field("foo", Type.STRING)); + + Struct assignmentV100 = new Struct(assignmentSchemaV100); + assignmentV100.set("assigned_partitions", + new Object[]{new Struct(ConsumerProtocolAssignment.TopicPartition.SCHEMA_0) + .set("topic", tp1.topic()) + .set("partitions", new Object[]{tp1.partition()})}); + assignmentV100.set("user_data", ByteBuffer.wrap(new byte[0])); + assignmentV100.set("foo", "bar"); + + Struct headerV100 = new Struct(new Schema(new Field("version", Type.INT16))); + headerV100.set("version", version); + + ByteBuffer buffer = ByteBuffer.allocate(assignmentV100.sizeOf() + headerV100.sizeOf()); + headerV100.writeTo(buffer); + assignmentV100.writeTo(buffer); + + buffer.flip(); + + Assignment assignment = ConsumerProtocol.deserializeAssignment(buffer); + assertEquals(toSet(Collections.singletonList(tp1)), toSet(assignment.partitions())); + } +} diff --git a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/CooperativeConsumerCoordinatorTest.java b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/CooperativeConsumerCoordinatorTest.java new file mode 100644 index 0000000..60e530b --- /dev/null +++ 
b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/CooperativeConsumerCoordinatorTest.java @@ -0,0 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients.consumer.internals; + +import org.apache.kafka.clients.consumer.ConsumerPartitionAssignor; + +public class CooperativeConsumerCoordinatorTest extends ConsumerCoordinatorTest { + public CooperativeConsumerCoordinatorTest() { + super(ConsumerPartitionAssignor.RebalanceProtocol.COOPERATIVE); + } +} diff --git a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/EagerConsumerCoordinatorTest.java b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/EagerConsumerCoordinatorTest.java new file mode 100644 index 0000000..5bcfc8c --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/EagerConsumerCoordinatorTest.java @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients.consumer.internals; + +import org.apache.kafka.clients.consumer.ConsumerPartitionAssignor; + +public class EagerConsumerCoordinatorTest extends ConsumerCoordinatorTest { + public EagerConsumerCoordinatorTest() { + super(ConsumerPartitionAssignor.RebalanceProtocol.EAGER); + } +} + diff --git a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/FetcherTest.java b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/FetcherTest.java new file mode 100644 index 0000000..b3dee9e --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/FetcherTest.java @@ -0,0 +1,5130 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients.consumer.internals; + +import org.apache.kafka.clients.ApiVersions; +import org.apache.kafka.clients.ClientDnsLookup; +import org.apache.kafka.clients.ClientRequest; +import org.apache.kafka.clients.ClientUtils; +import org.apache.kafka.clients.FetchSessionHandler; +import org.apache.kafka.clients.Metadata; +import org.apache.kafka.clients.MockClient; +import org.apache.kafka.clients.NetworkClient; +import org.apache.kafka.clients.NodeApiVersions; +import org.apache.kafka.clients.consumer.ConsumerRebalanceListener; +import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.clients.consumer.LogTruncationException; +import org.apache.kafka.clients.consumer.OffsetAndMetadata; +import org.apache.kafka.clients.consumer.OffsetAndTimestamp; +import org.apache.kafka.clients.consumer.OffsetOutOfRangeException; +import org.apache.kafka.clients.consumer.OffsetResetStrategy; +import org.apache.kafka.common.Cluster; +import org.apache.kafka.common.IsolationLevel; +import org.apache.kafka.common.KafkaException; +import org.apache.kafka.common.MetricName; +import org.apache.kafka.common.MetricNameTemplate; +import org.apache.kafka.common.Node; +import org.apache.kafka.common.PartitionInfo; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.TopicIdPartition; +import org.apache.kafka.common.Uuid; +import org.apache.kafka.common.errors.InvalidTopicException; +import org.apache.kafka.common.errors.RecordTooLargeException; +import org.apache.kafka.common.errors.SerializationException; +import org.apache.kafka.common.errors.TimeoutException; +import org.apache.kafka.common.errors.TopicAuthorizationException; +import org.apache.kafka.common.header.Header; +import org.apache.kafka.common.header.internals.RecordHeader; +import org.apache.kafka.common.internals.ClusterResourceListeners; +import org.apache.kafka.common.message.FetchResponseData; +import org.apache.kafka.common.message.ApiMessageType; +import org.apache.kafka.common.message.ListOffsetsRequestData.ListOffsetsPartition; +import org.apache.kafka.common.message.ListOffsetsRequestData.ListOffsetsTopic; +import org.apache.kafka.common.message.ListOffsetsResponseData; +import org.apache.kafka.common.message.ListOffsetsResponseData.ListOffsetsPartitionResponse; +import org.apache.kafka.common.message.ListOffsetsResponseData.ListOffsetsTopicResponse; +import org.apache.kafka.common.message.OffsetForLeaderEpochRequestData; +import org.apache.kafka.common.message.OffsetForLeaderEpochRequestData.OffsetForLeaderPartition; +import org.apache.kafka.common.message.OffsetForLeaderEpochResponseData; +import org.apache.kafka.common.message.OffsetForLeaderEpochResponseData.EpochEndOffset; +import org.apache.kafka.common.message.OffsetForLeaderEpochResponseData.OffsetForLeaderTopicResult; +import org.apache.kafka.common.metrics.KafkaMetric; +import org.apache.kafka.common.metrics.MetricConfig; +import org.apache.kafka.common.metrics.Metrics; +import org.apache.kafka.common.metrics.Sensor; +import org.apache.kafka.common.network.NetworkReceive; +import 
org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.Errors; +import org.apache.kafka.common.requests.FetchRequest.PartitionData; +import org.apache.kafka.common.utils.BufferSupplier; +import org.apache.kafka.common.record.CompressionType; +import org.apache.kafka.common.record.ControlRecordType; +import org.apache.kafka.common.record.DefaultRecordBatch; +import org.apache.kafka.common.record.EndTransactionMarker; +import org.apache.kafka.common.record.LegacyRecord; +import org.apache.kafka.common.record.MemoryRecords; +import org.apache.kafka.common.record.MemoryRecordsBuilder; +import org.apache.kafka.common.record.Record; +import org.apache.kafka.common.record.RecordBatch; +import org.apache.kafka.common.record.Records; +import org.apache.kafka.common.record.SimpleRecord; +import org.apache.kafka.common.record.TimestampType; +import org.apache.kafka.common.requests.AbstractRequest; +import org.apache.kafka.common.requests.ApiVersionsResponse; +import org.apache.kafka.common.requests.FetchRequest; +import org.apache.kafka.common.requests.FetchResponse; +import org.apache.kafka.common.requests.ListOffsetsRequest; +import org.apache.kafka.common.requests.ListOffsetsResponse; +import org.apache.kafka.common.requests.MetadataRequest; +import org.apache.kafka.common.requests.MetadataResponse; +import org.apache.kafka.common.requests.OffsetsForLeaderEpochRequest; +import org.apache.kafka.common.requests.OffsetsForLeaderEpochResponse; +import org.apache.kafka.common.requests.RequestTestUtils; +import org.apache.kafka.common.serialization.ByteArrayDeserializer; +import org.apache.kafka.common.serialization.BytesDeserializer; +import org.apache.kafka.common.serialization.Deserializer; +import org.apache.kafka.common.serialization.StringDeserializer; +import org.apache.kafka.common.utils.ByteBufferOutputStream; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.test.DelayedReceive; +import org.apache.kafka.test.MockSelector; +import org.apache.kafka.test.TestUtils; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.io.DataOutputStream; +import java.lang.reflect.Field; +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; +import java.time.Duration; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Iterator; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Function; +import java.util.stream.Collectors; + +import static java.util.Arrays.asList; +import static java.util.Collections.emptyList; +import static java.util.Collections.emptyMap; +import static java.util.Collections.singleton; +import static java.util.Collections.singletonList; +import static java.util.Collections.singletonMap; +import static org.apache.kafka.common.requests.FetchMetadata.INVALID_SESSION_ID; +import static org.apache.kafka.common.requests.OffsetsForLeaderEpochResponse.UNDEFINED_EPOCH; +import static 
org.apache.kafka.common.requests.OffsetsForLeaderEpochResponse.UNDEFINED_EPOCH_OFFSET; +import static org.apache.kafka.common.utils.Utils.mkSet; +import static org.apache.kafka.test.TestUtils.assertOptional; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; + +public class FetcherTest { + private static final double EPSILON = 0.0001; + + private ConsumerRebalanceListener listener = new NoOpConsumerRebalanceListener(); + private String topicName = "test"; + private String groupId = "test-group"; + private Uuid topicId = Uuid.randomUuid(); + private Map topicIds = new HashMap() { + { + put(topicName, topicId); + } + }; + private Map topicNames = singletonMap(topicId, topicName); + private final String metricGroup = "consumer" + groupId + "-fetch-manager-metrics"; + private TopicPartition tp0 = new TopicPartition(topicName, 0); + private TopicPartition tp1 = new TopicPartition(topicName, 1); + private TopicPartition tp2 = new TopicPartition(topicName, 2); + private TopicPartition tp3 = new TopicPartition(topicName, 3); + private TopicIdPartition tidp0 = new TopicIdPartition(topicId, tp0); + private TopicIdPartition tidp1 = new TopicIdPartition(topicId, tp1); + private TopicIdPartition tidp2 = new TopicIdPartition(topicId, tp2); + private TopicIdPartition tidp3 = new TopicIdPartition(topicId, tp3); + private int validLeaderEpoch = 0; + private MetadataResponse initialUpdateResponse = + RequestTestUtils.metadataUpdateWithIds(1, singletonMap(topicName, 4), topicIds); + + private int minBytes = 1; + private int maxBytes = Integer.MAX_VALUE; + private int maxWaitMs = 0; + private int fetchSize = 1000; + private long retryBackoffMs = 100; + private long requestTimeoutMs = 30000; + private MockTime time = new MockTime(1); + private SubscriptionState subscriptions; + private ConsumerMetadata metadata; + private FetcherMetricsRegistry metricsRegistry; + private MockClient client; + private Metrics metrics; + private ApiVersions apiVersions = new ApiVersions(); + private ConsumerNetworkClient consumerClient; + private Fetcher fetcher; + + private MemoryRecords records; + private MemoryRecords nextRecords; + private MemoryRecords emptyRecords; + private MemoryRecords partialRecords; + private ExecutorService executorService; + + @BeforeEach + public void setup() { + records = buildRecords(1L, 3, 1); + nextRecords = buildRecords(4L, 2, 4); + emptyRecords = buildRecords(0L, 0, 0); + partialRecords = buildRecords(4L, 1, 0); + partialRecords.buffer().putInt(Records.SIZE_OFFSET, 10000); + } + + private void assignFromUser(Set partitions) { + subscriptions.assignFromUser(partitions); + client.updateMetadata(initialUpdateResponse); + + // A dummy metadata update to ensure valid leader epoch. 
+ metadata.updateWithCurrentRequestVersion(RequestTestUtils.metadataUpdateWithIds("dummy", 1, + Collections.emptyMap(), singletonMap(topicName, 4), + tp -> validLeaderEpoch, topicIds), false, 0L); + } + + private void assignFromUserNoId(Set partitions) { + subscriptions.assignFromUser(partitions); + client.updateMetadata(RequestTestUtils.metadataUpdateWithIds(1, singletonMap("noId", 1), Collections.emptyMap())); + + // A dummy metadata update to ensure valid leader epoch. + metadata.update(9, RequestTestUtils.metadataUpdateWithIds("dummy", 1, + Collections.emptyMap(), singletonMap("noId", 1), + tp -> validLeaderEpoch, topicIds), false, 0L); + } + + @AfterEach + public void teardown() throws Exception { + if (metrics != null) + this.metrics.close(); + if (fetcher != null) + this.fetcher.close(); + if (executorService != null) { + executorService.shutdownNow(); + assertTrue(executorService.awaitTermination(5, TimeUnit.SECONDS)); + } + } + + @Test + public void testFetchNormal() { + buildFetcher(); + + assignFromUser(singleton(tp0)); + subscriptions.seek(tp0, 0); + + // normal fetch + assertEquals(1, fetcher.sendFetches()); + assertFalse(fetcher.hasCompletedFetches()); + + client.prepareResponse(fullFetchResponse(tidp0, this.records, Errors.NONE, 100L, 0)); + consumerClient.poll(time.timer(0)); + assertTrue(fetcher.hasCompletedFetches()); + + Map>> partitionRecords = fetchedRecords(); + assertTrue(partitionRecords.containsKey(tp0)); + + List> records = partitionRecords.get(tp0); + assertEquals(3, records.size()); + assertEquals(4L, subscriptions.position(tp0).offset); // this is the next fetching position + long offset = 1; + for (ConsumerRecord record : records) { + assertEquals(offset, record.offset()); + offset += 1; + } + } + + @Test + public void testFetchWithNoTopicId() { + // Should work and default to using old request type. 
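        /*
         * The drive cycle used throughout this class, for reference:
         *
         *   fetcher.sendFetches();                // build and enqueue fetch requests
         *   consumerClient.poll(time.timer(0));   // exchange them with the MockClient
         *   Map<TopicPartition, List<ConsumerRecord<byte[], byte[]>>> recs = fetchedRecords();
         *
         * fetchedRecords() is a helper in this class that drains the fetcher's
         * completed fetches; the byte[] key/value types assume the default
         * deserializers used by buildFetcher().
         */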
+ buildFetcher(); + + TopicIdPartition noId = new TopicIdPartition(Uuid.ZERO_UUID, new TopicPartition("noId", 0)); + assignFromUserNoId(singleton(noId.topicPartition())); + subscriptions.seek(noId.topicPartition(), 0); + + // Fetch should use request version 12 + assertEquals(1, fetcher.sendFetches()); + assertFalse(fetcher.hasCompletedFetches()); + + client.prepareResponse( + fetchRequestMatcher((short) 12, noId, 0, Optional.of(validLeaderEpoch)), + fullFetchResponse(noId, this.records, Errors.NONE, 100L, 0) + ); + consumerClient.poll(time.timer(0)); + assertTrue(fetcher.hasCompletedFetches()); + + Map>> partitionRecords = fetchedRecords(); + assertTrue(partitionRecords.containsKey(noId.topicPartition())); + + List> records = partitionRecords.get(noId.topicPartition()); + assertEquals(3, records.size()); + assertEquals(4L, subscriptions.position(noId.topicPartition()).offset); // this is the next fetching position + long offset = 1; + for (ConsumerRecord record : records) { + assertEquals(offset, record.offset()); + offset += 1; + } + } + + @Test + public void testFetchWithTopicId() { + buildFetcher(); + + TopicIdPartition tp = new TopicIdPartition(topicId, new TopicPartition(topicName, 0)); + assignFromUser(singleton(tp.topicPartition())); + subscriptions.seek(tp.topicPartition(), 0); + + // Fetch should use latest version + assertEquals(1, fetcher.sendFetches()); + assertFalse(fetcher.hasCompletedFetches()); + + client.prepareResponse( + fetchRequestMatcher(ApiKeys.FETCH.latestVersion(), tp, 0, Optional.of(validLeaderEpoch)), + fullFetchResponse(tp, this.records, Errors.NONE, 100L, 0) + ); + consumerClient.poll(time.timer(0)); + assertTrue(fetcher.hasCompletedFetches()); + + Map>> partitionRecords = fetchedRecords(); + assertTrue(partitionRecords.containsKey(tp.topicPartition())); + + List> records = partitionRecords.get(tp.topicPartition()); + assertEquals(3, records.size()); + assertEquals(4L, subscriptions.position(tp.topicPartition()).offset); // this is the next fetching position + long offset = 1; + for (ConsumerRecord record : records) { + assertEquals(offset, record.offset()); + offset += 1; + } + } + + @Test + public void testFetchForgetTopicIdWhenUnassigned() { + buildFetcher(); + + TopicIdPartition foo = new TopicIdPartition(Uuid.randomUuid(), new TopicPartition("foo", 0)); + TopicIdPartition bar = new TopicIdPartition(Uuid.randomUuid(), new TopicPartition("bar", 0)); + + // Assign foo and bar. + subscriptions.assignFromUser(singleton(foo.topicPartition())); + client.updateMetadata(RequestTestUtils.metadataUpdateWithIds(1, singleton(foo), tp -> validLeaderEpoch)); + subscriptions.seek(foo.topicPartition(), 0); + + // Fetch should use latest version. + assertEquals(1, fetcher.sendFetches()); + + client.prepareResponse( + fetchRequestMatcher(ApiKeys.FETCH.latestVersion(), + singletonMap(foo, new PartitionData( + foo.topicId(), + 0, + FetchRequest.INVALID_LOG_START_OFFSET, + fetchSize, + Optional.of(validLeaderEpoch)) + ), + emptyList() + ), + fullFetchResponse(1, foo, this.records, Errors.NONE, 100L, 0) + ); + consumerClient.poll(time.timer(0)); + assertTrue(fetcher.hasCompletedFetches()); + fetchedRecords(); + + // Assign bar and unassign foo. + subscriptions.assignFromUser(singleton(bar.topicPartition())); + client.updateMetadata(RequestTestUtils.metadataUpdateWithIds(1, singleton(bar), tp -> validLeaderEpoch)); + subscriptions.seek(bar.topicPartition(), 0); + + // Fetch should use latest version. 
+ assertEquals(1, fetcher.sendFetches()); + assertFalse(fetcher.hasCompletedFetches()); + + client.prepareResponse( + fetchRequestMatcher(ApiKeys.FETCH.latestVersion(), + singletonMap(bar, new PartitionData( + bar.topicId(), + 0, + FetchRequest.INVALID_LOG_START_OFFSET, + fetchSize, + Optional.of(validLeaderEpoch)) + ), + singletonList(foo) + ), + fullFetchResponse(1, bar, this.records, Errors.NONE, 100L, 0) + ); + consumerClient.poll(time.timer(0)); + assertTrue(fetcher.hasCompletedFetches()); + fetchedRecords(); + } + + @Test + public void testFetchForgetTopicIdWhenReplaced() { + buildFetcher(); + + TopicIdPartition fooWithOldTopicId = new TopicIdPartition(Uuid.randomUuid(), new TopicPartition("foo", 0)); + TopicIdPartition fooWithNewTopicId = new TopicIdPartition(Uuid.randomUuid(), new TopicPartition("foo", 0)); + + // Assign foo with old topic id. + subscriptions.assignFromUser(singleton(fooWithOldTopicId.topicPartition())); + client.updateMetadata(RequestTestUtils.metadataUpdateWithIds(1, singleton(fooWithOldTopicId), tp -> validLeaderEpoch)); + subscriptions.seek(fooWithOldTopicId.topicPartition(), 0); + + // Fetch should use latest version. + assertEquals(1, fetcher.sendFetches()); + + client.prepareResponse( + fetchRequestMatcher(ApiKeys.FETCH.latestVersion(), + singletonMap(fooWithOldTopicId, new PartitionData( + fooWithOldTopicId.topicId(), + 0, + FetchRequest.INVALID_LOG_START_OFFSET, + fetchSize, + Optional.of(validLeaderEpoch)) + ), + emptyList() + ), + fullFetchResponse(1, fooWithOldTopicId, this.records, Errors.NONE, 100L, 0) + ); + consumerClient.poll(time.timer(0)); + assertTrue(fetcher.hasCompletedFetches()); + fetchedRecords(); + + // Replace foo with old topic id with foo with new topic id. + subscriptions.assignFromUser(singleton(fooWithNewTopicId.topicPartition())); + client.updateMetadata(RequestTestUtils.metadataUpdateWithIds(1, singleton(fooWithNewTopicId), tp -> validLeaderEpoch)); + subscriptions.seek(fooWithNewTopicId.topicPartition(), 0); + + // Fetch should use latest version. + assertEquals(1, fetcher.sendFetches()); + assertFalse(fetcher.hasCompletedFetches()); + + // foo with old topic id should be removed from the session. + client.prepareResponse( + fetchRequestMatcher(ApiKeys.FETCH.latestVersion(), + singletonMap(fooWithNewTopicId, new PartitionData( + fooWithNewTopicId.topicId(), + 0, + FetchRequest.INVALID_LOG_START_OFFSET, + fetchSize, + Optional.of(validLeaderEpoch)) + ), + singletonList(fooWithOldTopicId) + ), + fullFetchResponse(1, fooWithNewTopicId, this.records, Errors.NONE, 100L, 0) + ); + consumerClient.poll(time.timer(0)); + assertTrue(fetcher.hasCompletedFetches()); + fetchedRecords(); + } + + @Test + public void testFetchTopicIdUpgradeDowngrade() { + buildFetcher(); + + TopicIdPartition fooWithoutId = new TopicIdPartition(Uuid.ZERO_UUID, new TopicPartition("foo", 0)); + TopicIdPartition fooWithId = new TopicIdPartition(Uuid.randomUuid(), new TopicPartition("foo", 0)); + + // Assign foo without a topic id. + subscriptions.assignFromUser(singleton(fooWithoutId.topicPartition())); + client.updateMetadata(RequestTestUtils.metadataUpdateWithIds(1, singleton(fooWithoutId), tp -> validLeaderEpoch)); + subscriptions.seek(fooWithoutId.topicPartition(), 0); + + // Fetch should use version 12. 
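+        // fooWithoutId carries Uuid.ZERO_UUID, so the request cannot include a topic ID and stays on version 12.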
+ assertEquals(1, fetcher.sendFetches()); + + client.prepareResponse( + fetchRequestMatcher((short) 12, + singletonMap(fooWithoutId, new PartitionData( + fooWithoutId.topicId(), + 0, + FetchRequest.INVALID_LOG_START_OFFSET, + fetchSize, + Optional.of(validLeaderEpoch)) + ), + emptyList() + ), + fullFetchResponse(1, fooWithoutId, this.records, Errors.NONE, 100L, 0) + ); + consumerClient.poll(time.timer(0)); + assertTrue(fetcher.hasCompletedFetches()); + fetchedRecords(); + + // Upgrade. + subscriptions.assignFromUser(singleton(fooWithId.topicPartition())); + client.updateMetadata(RequestTestUtils.metadataUpdateWithIds(1, singleton(fooWithId), tp -> validLeaderEpoch)); + subscriptions.seek(fooWithId.topicPartition(), 0); + + // Fetch should use latest version. + assertEquals(1, fetcher.sendFetches()); + assertFalse(fetcher.hasCompletedFetches()); + + // foo with old topic id should be removed from the session. + client.prepareResponse( + fetchRequestMatcher(ApiKeys.FETCH.latestVersion(), + singletonMap(fooWithId, new PartitionData( + fooWithId.topicId(), + 0, + FetchRequest.INVALID_LOG_START_OFFSET, + fetchSize, + Optional.of(validLeaderEpoch)) + ), + emptyList() + ), + fullFetchResponse(1, fooWithId, this.records, Errors.NONE, 100L, 0) + ); + consumerClient.poll(time.timer(0)); + assertTrue(fetcher.hasCompletedFetches()); + fetchedRecords(); + + // Downgrade. + subscriptions.assignFromUser(singleton(fooWithoutId.topicPartition())); + client.updateMetadata(RequestTestUtils.metadataUpdateWithIds(1, singleton(fooWithoutId), tp -> validLeaderEpoch)); + subscriptions.seek(fooWithoutId.topicPartition(), 0); + + // Fetch should use version 12. + assertEquals(1, fetcher.sendFetches()); + assertFalse(fetcher.hasCompletedFetches()); + + // foo with old topic id should be removed from the session. 
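+        // The matcher below verifies that the downgraded request is sent as version 12 with an empty forgotten-topics list.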
+        client.prepareResponse(
+            fetchRequestMatcher((short) 12,
+                singletonMap(fooWithoutId, new PartitionData(
+                    fooWithoutId.topicId(),
+                    0,
+                    FetchRequest.INVALID_LOG_START_OFFSET,
+                    fetchSize,
+                    Optional.of(validLeaderEpoch))
+                ),
+                emptyList()
+            ),
+            fullFetchResponse(1, fooWithoutId, this.records, Errors.NONE, 100L, 0)
+        );
+        consumerClient.poll(time.timer(0));
+        assertTrue(fetcher.hasCompletedFetches());
+        fetchedRecords();
+    }
+
+    private MockClient.RequestMatcher fetchRequestMatcher(
+        short expectedVersion,
+        TopicIdPartition tp,
+        long expectedFetchOffset,
+        Optional<Integer> expectedCurrentLeaderEpoch
+    ) {
+        return fetchRequestMatcher(
+            expectedVersion,
+            singletonMap(tp, new PartitionData(
+                tp.topicId(),
+                expectedFetchOffset,
+                FetchRequest.INVALID_LOG_START_OFFSET,
+                fetchSize,
+                expectedCurrentLeaderEpoch
+            )),
+            emptyList()
+        );
+    }
+
+    private MockClient.RequestMatcher fetchRequestMatcher(
+        short expectedVersion,
+        Map<TopicIdPartition, PartitionData> fetch,
+        List<TopicIdPartition> forgotten
+    ) {
+        return body -> {
+            if (body instanceof FetchRequest) {
+                FetchRequest fetchRequest = (FetchRequest) body;
+                assertEquals(expectedVersion, fetchRequest.version());
+                assertEquals(fetch, fetchRequest.fetchData(topicNames(new ArrayList<>(fetch.keySet()))));
+                assertEquals(forgotten, fetchRequest.forgottenTopics(topicNames(forgotten)));
+                return true;
+            } else {
+                fail("Should have seen FetchRequest");
+                return false;
+            }
+        };
+    }
+
+    private Map<Uuid, String> topicNames(List<TopicIdPartition> partitions) {
+        Map<Uuid, String> topicNames = new HashMap<>();
+        partitions.forEach(partition -> topicNames.putIfAbsent(partition.topicId(), partition.topic()));
+        return topicNames;
+    }
+
+    @Test
+    public void testMissingLeaderEpochInRecords() {
+        buildFetcher();
+
+        assignFromUser(singleton(tp0));
+        subscriptions.seek(tp0, 0);
+
+        ByteBuffer buffer = ByteBuffer.allocate(1024);
+        MemoryRecordsBuilder builder = MemoryRecords.builder(buffer, RecordBatch.MAGIC_VALUE_V0,
+            CompressionType.NONE, TimestampType.CREATE_TIME, 0L, System.currentTimeMillis(),
+            RecordBatch.NO_PARTITION_LEADER_EPOCH);
+        builder.append(0L, "key".getBytes(), "1".getBytes());
+        builder.append(0L, "key".getBytes(), "2".getBytes());
+        MemoryRecords records = builder.build();
+
+        assertEquals(1, fetcher.sendFetches());
+        assertFalse(fetcher.hasCompletedFetches());
+
+        client.prepareResponse(fullFetchResponse(tidp0, records, Errors.NONE, 100L, 0));
+        consumerClient.poll(time.timer(0));
+        assertTrue(fetcher.hasCompletedFetches());
+
+        Map<TopicPartition, List<ConsumerRecord<byte[], byte[]>>> partitionRecords = fetchedRecords();
+        assertTrue(partitionRecords.containsKey(tp0));
+        assertEquals(2, partitionRecords.get(tp0).size());
+
+        for (ConsumerRecord<byte[], byte[]> record : partitionRecords.get(tp0)) {
+            assertEquals(Optional.empty(), record.leaderEpoch());
+        }
+    }
+
+    @Test
+    public void testLeaderEpochInConsumerRecord() {
+        buildFetcher();
+
+        assignFromUser(singleton(tp0));
+        subscriptions.seek(tp0, 0);
+
+        Integer partitionLeaderEpoch = 1;
+
+        ByteBuffer buffer = ByteBuffer.allocate(1024);
+        MemoryRecordsBuilder builder = MemoryRecords.builder(buffer, RecordBatch.CURRENT_MAGIC_VALUE,
+            CompressionType.NONE, TimestampType.CREATE_TIME, 0L, System.currentTimeMillis(),
+            partitionLeaderEpoch);
+        builder.append(0L, "key".getBytes(), partitionLeaderEpoch.toString().getBytes());
+        builder.append(0L, "key".getBytes(), partitionLeaderEpoch.toString().getBytes());
+        builder.close();
+
+        partitionLeaderEpoch += 7;
+
+        builder = MemoryRecords.builder(buffer, RecordBatch.CURRENT_MAGIC_VALUE, CompressionType.NONE,
+            TimestampType.CREATE_TIME, 2L, System.currentTimeMillis(), partitionLeaderEpoch);
+
builder.append(0L, "key".getBytes(), partitionLeaderEpoch.toString().getBytes()); + builder.close(); + + partitionLeaderEpoch += 5; + builder = MemoryRecords.builder(buffer, RecordBatch.CURRENT_MAGIC_VALUE, CompressionType.NONE, + TimestampType.CREATE_TIME, 3L, System.currentTimeMillis(), partitionLeaderEpoch); + builder.append(0L, "key".getBytes(), partitionLeaderEpoch.toString().getBytes()); + builder.append(0L, "key".getBytes(), partitionLeaderEpoch.toString().getBytes()); + builder.append(0L, "key".getBytes(), partitionLeaderEpoch.toString().getBytes()); + builder.close(); + + buffer.flip(); + MemoryRecords records = MemoryRecords.readableRecords(buffer); + + assertEquals(1, fetcher.sendFetches()); + assertFalse(fetcher.hasCompletedFetches()); + + client.prepareResponse(fullFetchResponse(tidp0, records, Errors.NONE, 100L, 0)); + consumerClient.poll(time.timer(0)); + assertTrue(fetcher.hasCompletedFetches()); + + Map>> partitionRecords = fetchedRecords(); + assertTrue(partitionRecords.containsKey(tp0)); + assertEquals(6, partitionRecords.get(tp0).size()); + + for (ConsumerRecord record : partitionRecords.get(tp0)) { + int expectedLeaderEpoch = Integer.parseInt(Utils.utf8(record.value())); + assertEquals(Optional.of(expectedLeaderEpoch), record.leaderEpoch()); + } + } + + @Test + public void testClearBufferedDataForTopicPartitions() { + buildFetcher(); + + assignFromUser(singleton(tp0)); + subscriptions.seek(tp0, 0); + + // normal fetch + assertEquals(1, fetcher.sendFetches()); + assertFalse(fetcher.hasCompletedFetches()); + + client.prepareResponse(fullFetchResponse(tidp0, this.records, Errors.NONE, 100L, 0)); + consumerClient.poll(time.timer(0)); + assertTrue(fetcher.hasCompletedFetches()); + Set newAssignedTopicPartitions = new HashSet<>(); + newAssignedTopicPartitions.add(tp1); + + fetcher.clearBufferedDataForUnassignedPartitions(newAssignedTopicPartitions); + assertFalse(fetcher.hasCompletedFetches()); + } + + @Test + public void testFetchSkipsBlackedOutNodes() { + buildFetcher(); + + assignFromUser(singleton(tp0)); + subscriptions.seek(tp0, 0); + Node node = initialUpdateResponse.brokers().iterator().next(); + + client.backoff(node, 500); + assertEquals(0, fetcher.sendFetches()); + + time.sleep(500); + assertEquals(1, fetcher.sendFetches()); + } + + @Test + public void testFetcherIgnoresControlRecords() { + buildFetcher(); + + assignFromUser(singleton(tp0)); + subscriptions.seek(tp0, 0); + + // normal fetch + assertEquals(1, fetcher.sendFetches()); + assertFalse(fetcher.hasCompletedFetches()); + + long producerId = 1; + short producerEpoch = 0; + int baseSequence = 0; + int partitionLeaderEpoch = 0; + + ByteBuffer buffer = ByteBuffer.allocate(1024); + MemoryRecordsBuilder builder = MemoryRecords.idempotentBuilder(buffer, CompressionType.NONE, 0L, producerId, + producerEpoch, baseSequence); + builder.append(0L, "key".getBytes(), null); + builder.close(); + + MemoryRecords.writeEndTransactionalMarker(buffer, 1L, time.milliseconds(), partitionLeaderEpoch, producerId, producerEpoch, + new EndTransactionMarker(ControlRecordType.ABORT, 0)); + + buffer.flip(); + + client.prepareResponse(fullFetchResponse(tidp0, MemoryRecords.readableRecords(buffer), Errors.NONE, 100L, 0)); + consumerClient.poll(time.timer(0)); + assertTrue(fetcher.hasCompletedFetches()); + + Map>> partitionRecords = fetchedRecords(); + assertTrue(partitionRecords.containsKey(tp0)); + + List> records = partitionRecords.get(tp0); + assertEquals(1, records.size()); + assertEquals(2L, subscriptions.position(tp0).offset); + + 
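+        // Only the data record is exposed to the application; the abort marker is skipped but still advances the position to 2.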
ConsumerRecord record = records.get(0); + assertArrayEquals("key".getBytes(), record.key()); + } + + @Test + public void testFetchError() { + buildFetcher(); + + assignFromUser(singleton(tp0)); + subscriptions.seek(tp0, 0); + + assertEquals(1, fetcher.sendFetches()); + assertFalse(fetcher.hasCompletedFetches()); + + client.prepareResponse(fullFetchResponse(tidp0, this.records, Errors.NOT_LEADER_OR_FOLLOWER, 100L, 0)); + consumerClient.poll(time.timer(0)); + assertTrue(fetcher.hasCompletedFetches()); + + Map>> partitionRecords = fetchedRecords(); + assertFalse(partitionRecords.containsKey(tp0)); + } + + private MockClient.RequestMatcher matchesOffset(final TopicIdPartition tp, final long offset) { + return body -> { + FetchRequest fetch = (FetchRequest) body; + Map fetchData = fetch.fetchData(topicNames); + return fetchData.containsKey(tp) && + fetchData.get(tp).fetchOffset == offset; + }; + } + + @Test + public void testFetchedRecordsRaisesOnSerializationErrors() { + // raise an exception from somewhere in the middle of the fetch response + // so that we can verify that our position does not advance after raising + ByteArrayDeserializer deserializer = new ByteArrayDeserializer() { + int i = 0; + @Override + public byte[] deserialize(String topic, byte[] data) { + if (i++ % 2 == 1) { + // Should be blocked on the value deserialization of the first record. + assertEquals("value-1", new String(data, StandardCharsets.UTF_8)); + throw new SerializationException(); + } + return data; + } + }; + + buildFetcher(deserializer, deserializer); + + assignFromUser(singleton(tp0)); + subscriptions.seek(tp0, 1); + + client.prepareResponse(matchesOffset(tidp0, 1), fullFetchResponse(tidp0, this.records, Errors.NONE, 100L, 0)); + + assertEquals(1, fetcher.sendFetches()); + consumerClient.poll(time.timer(0)); + // The fetcher should block on Deserialization error + for (int i = 0; i < 2; i++) { + try { + fetcher.fetchedRecords(); + fail("fetchedRecords should have raised"); + } catch (SerializationException e) { + // the position should not advance since no data has been returned + assertEquals(1, subscriptions.position(tp0).offset); + } + } + } + + @Test + public void testParseCorruptedRecord() throws Exception { + buildFetcher(); + assignFromUser(singleton(tp0)); + + ByteBuffer buffer = ByteBuffer.allocate(1024); + DataOutputStream out = new DataOutputStream(new ByteBufferOutputStream(buffer)); + + byte magic = RecordBatch.MAGIC_VALUE_V1; + byte[] key = "foo".getBytes(); + byte[] value = "baz".getBytes(); + long offset = 0; + long timestamp = 500L; + + int size = LegacyRecord.recordSize(magic, key.length, value.length); + byte attributes = LegacyRecord.computeAttributes(magic, CompressionType.NONE, TimestampType.CREATE_TIME); + long crc = LegacyRecord.computeChecksum(magic, attributes, timestamp, key, value); + + // write one valid record + out.writeLong(offset); + out.writeInt(size); + LegacyRecord.write(out, magic, crc, LegacyRecord.computeAttributes(magic, CompressionType.NONE, TimestampType.CREATE_TIME), timestamp, key, value); + + // and one invalid record (note the crc) + out.writeLong(offset + 1); + out.writeInt(size); + LegacyRecord.write(out, magic, crc + 1, LegacyRecord.computeAttributes(magic, CompressionType.NONE, TimestampType.CREATE_TIME), timestamp, key, value); + + // write one valid record + out.writeLong(offset + 2); + out.writeInt(size); + LegacyRecord.write(out, magic, crc, LegacyRecord.computeAttributes(magic, CompressionType.NONE, TimestampType.CREATE_TIME), timestamp, key, value); + + 
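+        // At this point the buffer contains a valid record (offset 0), a record with a corrupted CRC (offset 1), and a valid record (offset 2).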
// Write a record whose size field is invalid. + out.writeLong(offset + 3); + out.writeInt(1); + + // write one valid record + out.writeLong(offset + 4); + out.writeInt(size); + LegacyRecord.write(out, magic, crc, LegacyRecord.computeAttributes(magic, CompressionType.NONE, TimestampType.CREATE_TIME), timestamp, key, value); + + buffer.flip(); + + subscriptions.seekUnvalidated(tp0, new SubscriptionState.FetchPosition(0, Optional.empty(), metadata.currentLeader(tp0))); + + // normal fetch + assertEquals(1, fetcher.sendFetches()); + client.prepareResponse(fullFetchResponse(tidp0, MemoryRecords.readableRecords(buffer), Errors.NONE, 100L, 0)); + consumerClient.poll(time.timer(0)); + + // the first fetchedRecords() should return the first valid message + assertEquals(1, fetcher.fetchedRecords().get(tp0).size()); + assertEquals(1, subscriptions.position(tp0).offset); + + ensureBlockOnRecord(1L); + seekAndConsumeRecord(buffer, 2L); + ensureBlockOnRecord(3L); + try { + // For a record that cannot be retrieved from the iterator, we cannot seek over it within the batch. + seekAndConsumeRecord(buffer, 4L); + fail("Should have thrown exception when fail to retrieve a record from iterator."); + } catch (KafkaException ke) { + // let it go + } + ensureBlockOnRecord(4L); + } + + private void ensureBlockOnRecord(long blockedOffset) { + // the fetchedRecords() should always throw exception due to the invalid message at the starting offset. + for (int i = 0; i < 2; i++) { + try { + fetcher.fetchedRecords(); + fail("fetchedRecords should have raised KafkaException"); + } catch (KafkaException e) { + assertEquals(blockedOffset, subscriptions.position(tp0).offset); + } + } + } + + private void seekAndConsumeRecord(ByteBuffer responseBuffer, long toOffset) { + // Seek to skip the bad record and fetch again. + subscriptions.seekUnvalidated(tp0, new SubscriptionState.FetchPosition(toOffset, Optional.empty(), metadata.currentLeader(tp0))); + // Should not throw exception after the seek. + fetcher.fetchedRecords(); + assertEquals(1, fetcher.sendFetches()); + client.prepareResponse(fullFetchResponse(tidp0, MemoryRecords.readableRecords(responseBuffer), Errors.NONE, 100L, 0)); + consumerClient.poll(time.timer(0)); + + Map>> recordsByPartition = fetchedRecords(); + List> records = recordsByPartition.get(tp0); + assertEquals(1, records.size()); + assertEquals(toOffset, records.get(0).offset()); + assertEquals(toOffset + 1, subscriptions.position(tp0).offset); + } + + @Test + public void testInvalidDefaultRecordBatch() { + buildFetcher(); + + ByteBuffer buffer = ByteBuffer.allocate(1024); + ByteBufferOutputStream out = new ByteBufferOutputStream(buffer); + + MemoryRecordsBuilder builder = new MemoryRecordsBuilder(out, + DefaultRecordBatch.CURRENT_MAGIC_VALUE, + CompressionType.NONE, + TimestampType.CREATE_TIME, + 0L, 10L, 0L, (short) 0, 0, false, false, 0, 1024); + builder.append(10L, "key".getBytes(), "value".getBytes()); + builder.close(); + buffer.flip(); + + // Garble the CRC + buffer.position(17); + buffer.put("beef".getBytes()); + buffer.position(0); + + assignFromUser(singleton(tp0)); + subscriptions.seek(tp0, 0); + + // normal fetch + assertEquals(1, fetcher.sendFetches()); + client.prepareResponse(fullFetchResponse(tidp0, MemoryRecords.readableRecords(buffer), Errors.NONE, 100L, 0)); + consumerClient.poll(time.timer(0)); + + // the fetchedRecords() should always throw exception due to the bad batch. 
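+        // Every attempt should fail and the position should stay at 0, since the corrupted batch never yields records.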
+ for (int i = 0; i < 2; i++) { + try { + fetcher.fetchedRecords(); + fail("fetchedRecords should have raised KafkaException"); + } catch (KafkaException e) { + assertEquals(0, subscriptions.position(tp0).offset); + } + } + } + + @Test + public void testParseInvalidRecordBatch() { + buildFetcher(); + MemoryRecords records = MemoryRecords.withRecords(RecordBatch.MAGIC_VALUE_V2, 0L, + CompressionType.NONE, TimestampType.CREATE_TIME, + new SimpleRecord(1L, "a".getBytes(), "1".getBytes()), + new SimpleRecord(2L, "b".getBytes(), "2".getBytes()), + new SimpleRecord(3L, "c".getBytes(), "3".getBytes())); + ByteBuffer buffer = records.buffer(); + + // flip some bits to fail the crc + buffer.putInt(32, buffer.get(32) ^ 87238423); + + assignFromUser(singleton(tp0)); + subscriptions.seek(tp0, 0); + + // normal fetch + assertEquals(1, fetcher.sendFetches()); + client.prepareResponse(fullFetchResponse(tidp0, MemoryRecords.readableRecords(buffer), Errors.NONE, 100L, 0)); + consumerClient.poll(time.timer(0)); + try { + fetcher.fetchedRecords(); + fail("fetchedRecords should have raised"); + } catch (KafkaException e) { + // the position should not advance since no data has been returned + assertEquals(0, subscriptions.position(tp0).offset); + } + } + + @Test + public void testHeaders() { + buildFetcher(); + + MemoryRecordsBuilder builder = MemoryRecords.builder(ByteBuffer.allocate(1024), CompressionType.NONE, TimestampType.CREATE_TIME, 1L); + builder.append(0L, "key".getBytes(), "value-1".getBytes()); + + Header[] headersArray = new Header[1]; + headersArray[0] = new RecordHeader("headerKey", "headerValue".getBytes(StandardCharsets.UTF_8)); + builder.append(0L, "key".getBytes(), "value-2".getBytes(), headersArray); + + Header[] headersArray2 = new Header[2]; + headersArray2[0] = new RecordHeader("headerKey", "headerValue".getBytes(StandardCharsets.UTF_8)); + headersArray2[1] = new RecordHeader("headerKey", "headerValue2".getBytes(StandardCharsets.UTF_8)); + builder.append(0L, "key".getBytes(), "value-3".getBytes(), headersArray2); + + MemoryRecords memoryRecords = builder.build(); + + List> records; + assignFromUser(singleton(tp0)); + subscriptions.seek(tp0, 1); + + client.prepareResponse(matchesOffset(tidp0, 1), fullFetchResponse(tidp0, memoryRecords, Errors.NONE, 100L, 0)); + + assertEquals(1, fetcher.sendFetches()); + consumerClient.poll(time.timer(0)); + Map>> recordsByPartition = fetchedRecords(); + records = recordsByPartition.get(tp0); + + assertEquals(3, records.size()); + + Iterator> recordIterator = records.iterator(); + + ConsumerRecord record = recordIterator.next(); + assertNull(record.headers().lastHeader("headerKey")); + + record = recordIterator.next(); + assertEquals("headerValue", new String(record.headers().lastHeader("headerKey").value(), StandardCharsets.UTF_8)); + assertEquals("headerKey", record.headers().lastHeader("headerKey").key()); + + record = recordIterator.next(); + assertEquals("headerValue2", new String(record.headers().lastHeader("headerKey").value(), StandardCharsets.UTF_8)); + assertEquals("headerKey", record.headers().lastHeader("headerKey").key()); + } + + @Test + public void testFetchMaxPollRecords() { + buildFetcher(2); + + List> records; + assignFromUser(singleton(tp0)); + subscriptions.seek(tp0, 1); + + client.prepareResponse(matchesOffset(tidp0, 1), fullFetchResponse(tidp0, this.records, Errors.NONE, 100L, 0)); + client.prepareResponse(matchesOffset(tidp0, 4), fullFetchResponse(tidp0, this.nextRecords, Errors.NONE, 100L, 0)); + + assertEquals(1, 
fetcher.sendFetches()); + consumerClient.poll(time.timer(0)); + Map>> recordsByPartition = fetchedRecords(); + records = recordsByPartition.get(tp0); + assertEquals(2, records.size()); + assertEquals(3L, subscriptions.position(tp0).offset); + assertEquals(1, records.get(0).offset()); + assertEquals(2, records.get(1).offset()); + + assertEquals(0, fetcher.sendFetches()); + consumerClient.poll(time.timer(0)); + recordsByPartition = fetchedRecords(); + records = recordsByPartition.get(tp0); + assertEquals(1, records.size()); + assertEquals(4L, subscriptions.position(tp0).offset); + assertEquals(3, records.get(0).offset()); + + assertTrue(fetcher.sendFetches() > 0); + consumerClient.poll(time.timer(0)); + recordsByPartition = fetchedRecords(); + records = recordsByPartition.get(tp0); + assertEquals(2, records.size()); + assertEquals(6L, subscriptions.position(tp0).offset); + assertEquals(4, records.get(0).offset()); + assertEquals(5, records.get(1).offset()); + } + + /** + * Test the scenario where a partition with fetched but not consumed records (i.e. max.poll.records is + * less than the number of fetched records) is unassigned and a different partition is assigned. This is a + * pattern used by Streams state restoration and KAFKA-5097 would have been caught by this test. + */ + @Test + public void testFetchAfterPartitionWithFetchedRecordsIsUnassigned() { + buildFetcher(2); + + List> records; + assignFromUser(singleton(tp0)); + subscriptions.seek(tp0, 1); + + // Returns 3 records while `max.poll.records` is configured to 2 + client.prepareResponse(matchesOffset(tidp0, 1), fullFetchResponse(tidp0, this.records, Errors.NONE, 100L, 0)); + + assertEquals(1, fetcher.sendFetches()); + consumerClient.poll(time.timer(0)); + Map>> recordsByPartition = fetchedRecords(); + records = recordsByPartition.get(tp0); + assertEquals(2, records.size()); + assertEquals(3L, subscriptions.position(tp0).offset); + assertEquals(1, records.get(0).offset()); + assertEquals(2, records.get(1).offset()); + + assignFromUser(singleton(tp1)); + client.prepareResponse(matchesOffset(tidp1, 4), fullFetchResponse(tidp1, this.nextRecords, Errors.NONE, 100L, 0)); + subscriptions.seek(tp1, 4); + + assertEquals(1, fetcher.sendFetches()); + consumerClient.poll(time.timer(0)); + Map>> fetchedRecords = fetchedRecords(); + assertNull(fetchedRecords.get(tp0)); + records = fetchedRecords.get(tp1); + assertEquals(2, records.size()); + assertEquals(6L, subscriptions.position(tp1).offset); + assertEquals(4, records.get(0).offset()); + assertEquals(5, records.get(1).offset()); + } + + @Test + public void testFetchNonContinuousRecords() { + // if we are fetching from a compacted topic, there may be gaps in the returned records + // this test verifies the fetcher updates the current fetched/consumed positions correctly for this case + buildFetcher(); + + MemoryRecordsBuilder builder = MemoryRecords.builder(ByteBuffer.allocate(1024), CompressionType.NONE, + TimestampType.CREATE_TIME, 0L); + builder.appendWithOffset(15L, 0L, "key".getBytes(), "value-1".getBytes()); + builder.appendWithOffset(20L, 0L, "key".getBytes(), "value-2".getBytes()); + builder.appendWithOffset(30L, 0L, "key".getBytes(), "value-3".getBytes()); + MemoryRecords records = builder.build(); + + List> consumerRecords; + assignFromUser(singleton(tp0)); + subscriptions.seek(tp0, 0); + + // normal fetch + assertEquals(1, fetcher.sendFetches()); + client.prepareResponse(fullFetchResponse(tidp0, records, Errors.NONE, 100L, 0)); + consumerClient.poll(time.timer(0)); + Map>> 
recordsByPartition = fetchedRecords(); + consumerRecords = recordsByPartition.get(tp0); + assertEquals(3, consumerRecords.size()); + assertEquals(31L, subscriptions.position(tp0).offset); // this is the next fetching position + + assertEquals(15L, consumerRecords.get(0).offset()); + assertEquals(20L, consumerRecords.get(1).offset()); + assertEquals(30L, consumerRecords.get(2).offset()); + } + + /** + * Test the case where the client makes a pre-v3 FetchRequest, but the server replies with only a partial + * request. This happens when a single message is larger than the per-partition limit. + */ + @Test + public void testFetchRequestWhenRecordTooLarge() { + try { + buildFetcher(); + + client.setNodeApiVersions(NodeApiVersions.create(ApiKeys.FETCH.id, (short) 2, (short) 2)); + makeFetchRequestWithIncompleteRecord(); + try { + fetcher.fetchedRecords(); + fail("RecordTooLargeException should have been raised"); + } catch (RecordTooLargeException e) { + assertTrue(e.getMessage().startsWith("There are some messages at [Partition=Offset]: ")); + // the position should not advance since no data has been returned + assertEquals(0, subscriptions.position(tp0).offset); + } + } finally { + client.setNodeApiVersions(NodeApiVersions.create()); + } + } + + /** + * Test the case where the client makes a post KIP-74 FetchRequest, but the server replies with only a + * partial request. For v3 and later FetchRequests, the implementation of KIP-74 changed the behavior + * so that at least one message is always returned. Therefore, this case should not happen, and it indicates + * that an internal error has taken place. + */ + @Test + public void testFetchRequestInternalError() { + buildFetcher(); + makeFetchRequestWithIncompleteRecord(); + try { + fetcher.fetchedRecords(); + fail("RecordTooLargeException should have been raised"); + } catch (KafkaException e) { + assertTrue(e.getMessage().startsWith("Failed to make progress reading messages")); + // the position should not advance since no data has been returned + assertEquals(0, subscriptions.position(tp0).offset); + } + } + + private void makeFetchRequestWithIncompleteRecord() { + assignFromUser(singleton(tp0)); + subscriptions.seek(tp0, 0); + assertEquals(1, fetcher.sendFetches()); + assertFalse(fetcher.hasCompletedFetches()); + MemoryRecords partialRecord = MemoryRecords.readableRecords( + ByteBuffer.wrap(new byte[]{0, 0, 0, 0, 0, 0, 0, 0})); + client.prepareResponse(fullFetchResponse(tidp0, partialRecord, Errors.NONE, 100L, 0)); + consumerClient.poll(time.timer(0)); + assertTrue(fetcher.hasCompletedFetches()); + } + + @Test + public void testUnauthorizedTopic() { + buildFetcher(); + + assignFromUser(singleton(tp0)); + subscriptions.seek(tp0, 0); + + // resize the limit of the buffer to pretend it is only fetch-size large + assertEquals(1, fetcher.sendFetches()); + client.prepareResponse(fullFetchResponse(tidp0, this.records, Errors.TOPIC_AUTHORIZATION_FAILED, 100L, 0)); + consumerClient.poll(time.timer(0)); + try { + fetcher.fetchedRecords(); + fail("fetchedRecords should have thrown"); + } catch (TopicAuthorizationException e) { + assertEquals(singleton(topicName), e.unauthorizedTopics()); + } + } + + @Test + public void testFetchDuringEagerRebalance() { + buildFetcher(); + + subscriptions.subscribe(singleton(topicName), listener); + subscriptions.assignFromSubscribed(singleton(tp0)); + subscriptions.seek(tp0, 0); + + client.updateMetadata(RequestTestUtils.metadataUpdateWithIds( + 1, singletonMap(topicName, 4), tp -> validLeaderEpoch, topicIds)); + + 
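+        // Send a fetch while tp0 is still assigned; the eager rebalance below will clear its fetch position.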
assertEquals(1, fetcher.sendFetches()); + + // Now the eager rebalance happens and fetch positions are cleared + subscriptions.assignFromSubscribed(Collections.emptyList()); + + subscriptions.assignFromSubscribed(singleton(tp0)); + client.prepareResponse(fullFetchResponse(tidp0, this.records, Errors.NONE, 100L, 0)); + consumerClient.poll(time.timer(0)); + + // The active fetch should be ignored since its position is no longer valid + assertTrue(fetcher.fetchedRecords().isEmpty()); + } + + @Test + public void testFetchDuringCooperativeRebalance() { + buildFetcher(); + + subscriptions.subscribe(singleton(topicName), listener); + subscriptions.assignFromSubscribed(singleton(tp0)); + subscriptions.seek(tp0, 0); + + client.updateMetadata(RequestTestUtils.metadataUpdateWithIds( + 1, singletonMap(topicName, 4), tp -> validLeaderEpoch, topicIds)); + + assertEquals(1, fetcher.sendFetches()); + + // Now the cooperative rebalance happens and fetch positions are NOT cleared for unrevoked partitions + subscriptions.assignFromSubscribed(singleton(tp0)); + + client.prepareResponse(fullFetchResponse(tidp0, this.records, Errors.NONE, 100L, 0)); + consumerClient.poll(time.timer(0)); + + Map>> fetchedRecords = fetchedRecords(); + + // The active fetch should NOT be ignored since the position for tp0 is still valid + assertEquals(1, fetchedRecords.size()); + assertEquals(3, fetchedRecords.get(tp0).size()); + } + + @Test + public void testInFlightFetchOnPausedPartition() { + buildFetcher(); + + assignFromUser(singleton(tp0)); + subscriptions.seek(tp0, 0); + + assertEquals(1, fetcher.sendFetches()); + subscriptions.pause(tp0); + + client.prepareResponse(fullFetchResponse(tidp0, this.records, Errors.NONE, 100L, 0)); + consumerClient.poll(time.timer(0)); + assertNull(fetcher.fetchedRecords().get(tp0)); + } + + @Test + public void testFetchOnPausedPartition() { + buildFetcher(); + + assignFromUser(singleton(tp0)); + subscriptions.seek(tp0, 0); + + subscriptions.pause(tp0); + assertFalse(fetcher.sendFetches() > 0); + assertTrue(client.requests().isEmpty()); + } + + @Test + public void testFetchOnCompletedFetchesForPausedAndResumedPartitions() { + buildFetcher(); + + assignFromUser(singleton(tp0)); + subscriptions.seek(tp0, 0); + + assertEquals(1, fetcher.sendFetches()); + + subscriptions.pause(tp0); + + client.prepareResponse(fullFetchResponse(tidp0, this.records, Errors.NONE, 100L, 0)); + consumerClient.poll(time.timer(0)); + + Map>> fetchedRecords = fetchedRecords(); + assertEquals(0, fetchedRecords.size(), "Should not return any records when partition is paused"); + assertTrue(fetcher.hasCompletedFetches(), "Should still contain completed fetches"); + assertFalse(fetcher.hasAvailableFetches(), "Should not have any available (non-paused) completed fetches"); + assertNull(fetchedRecords.get(tp0)); + assertEquals(0, fetcher.sendFetches()); + + subscriptions.resume(tp0); + + assertTrue(fetcher.hasAvailableFetches(), "Should have available (non-paused) completed fetches"); + + consumerClient.poll(time.timer(0)); + fetchedRecords = fetchedRecords(); + assertEquals(1, fetchedRecords.size(), "Should return records when partition is resumed"); + assertNotNull(fetchedRecords.get(tp0)); + assertEquals(3, fetchedRecords.get(tp0).size()); + + consumerClient.poll(time.timer(0)); + fetchedRecords = fetchedRecords(); + assertEquals(0, fetchedRecords.size(), "Should not return records after previously paused partitions are fetched"); + assertFalse(fetcher.hasCompletedFetches(), "Should no longer contain completed fetches"); + } 
+ + @Test + public void testFetchOnCompletedFetchesForSomePausedPartitions() { + buildFetcher(); + + Map>> fetchedRecords; + + assignFromUser(mkSet(tp0, tp1)); + + // seek to tp0 and tp1 in two polls to generate 2 complete requests and responses + + // #1 seek, request, poll, response + subscriptions.seekUnvalidated(tp0, new SubscriptionState.FetchPosition(1, Optional.empty(), metadata.currentLeader(tp0))); + assertEquals(1, fetcher.sendFetches()); + client.prepareResponse(fullFetchResponse(tidp0, this.records, Errors.NONE, 100L, 0)); + consumerClient.poll(time.timer(0)); + + // #2 seek, request, poll, response + subscriptions.seekUnvalidated(tp1, new SubscriptionState.FetchPosition(1, Optional.empty(), metadata.currentLeader(tp1))); + assertEquals(1, fetcher.sendFetches()); + client.prepareResponse(fullFetchResponse(tidp1, this.nextRecords, Errors.NONE, 100L, 0)); + + subscriptions.pause(tp0); + consumerClient.poll(time.timer(0)); + + fetchedRecords = fetchedRecords(); + assertEquals(1, fetchedRecords.size(), "Should return completed fetch for unpaused partitions"); + assertTrue(fetcher.hasCompletedFetches(), "Should still contain completed fetches"); + assertNotNull(fetchedRecords.get(tp1)); + assertNull(fetchedRecords.get(tp0)); + + fetchedRecords = fetchedRecords(); + assertEquals(0, fetchedRecords.size(), "Should return no records for remaining paused partition"); + assertTrue(fetcher.hasCompletedFetches(), "Should still contain completed fetches"); + } + + @Test + public void testFetchOnCompletedFetchesForAllPausedPartitions() { + buildFetcher(); + + Map>> fetchedRecords; + + assignFromUser(mkSet(tp0, tp1)); + + // seek to tp0 and tp1 in two polls to generate 2 complete requests and responses + + // #1 seek, request, poll, response + subscriptions.seekUnvalidated(tp0, new SubscriptionState.FetchPosition(1, Optional.empty(), metadata.currentLeader(tp0))); + assertEquals(1, fetcher.sendFetches()); + client.prepareResponse(fullFetchResponse(tidp0, this.records, Errors.NONE, 100L, 0)); + consumerClient.poll(time.timer(0)); + + // #2 seek, request, poll, response + subscriptions.seekUnvalidated(tp1, new SubscriptionState.FetchPosition(1, Optional.empty(), metadata.currentLeader(tp1))); + assertEquals(1, fetcher.sendFetches()); + client.prepareResponse(fullFetchResponse(tidp1, this.nextRecords, Errors.NONE, 100L, 0)); + + subscriptions.pause(tp0); + subscriptions.pause(tp1); + + consumerClient.poll(time.timer(0)); + + fetchedRecords = fetchedRecords(); + assertEquals(0, fetchedRecords.size(), "Should return no records for all paused partitions"); + assertTrue(fetcher.hasCompletedFetches(), "Should still contain completed fetches"); + assertFalse(fetcher.hasAvailableFetches(), "Should not have any available (non-paused) completed fetches"); + } + + @Test + public void testPartialFetchWithPausedPartitions() { + // this test sends creates a completed fetch with 3 records and a max poll of 2 records to assert + // that a fetch that must be returned over at least 2 polls can be cached successfully when its partition is + // paused, then returned successfully after its been resumed again later + buildFetcher(2); + + Map>> fetchedRecords; + + assignFromUser(mkSet(tp0, tp1)); + + subscriptions.seek(tp0, 1); + assertEquals(1, fetcher.sendFetches()); + client.prepareResponse(fullFetchResponse(tidp0, this.records, Errors.NONE, 100L, 0)); + consumerClient.poll(time.timer(0)); + + fetchedRecords = fetchedRecords(); + + assertEquals(2, fetchedRecords.get(tp0).size(), "Should return 2 records from fetch 
with 3 records"); + assertFalse(fetcher.hasCompletedFetches(), "Should have no completed fetches"); + + subscriptions.pause(tp0); + consumerClient.poll(time.timer(0)); + + fetchedRecords = fetchedRecords(); + + assertEquals(0, fetchedRecords.size(), "Should return no records for paused partitions"); + assertTrue(fetcher.hasCompletedFetches(), "Should have 1 entry in completed fetches"); + assertFalse(fetcher.hasAvailableFetches(), "Should not have any available (non-paused) completed fetches"); + + subscriptions.resume(tp0); + + consumerClient.poll(time.timer(0)); + + fetchedRecords = fetchedRecords(); + + assertEquals(1, fetchedRecords.get(tp0).size(), "Should return last remaining record"); + assertFalse(fetcher.hasCompletedFetches(), "Should have no completed fetches"); + } + + @Test + public void testFetchDiscardedAfterPausedPartitionResumedAndSeekedToNewOffset() { + buildFetcher(); + assignFromUser(singleton(tp0)); + subscriptions.seek(tp0, 0); + + assertEquals(1, fetcher.sendFetches()); + subscriptions.pause(tp0); + client.prepareResponse(fullFetchResponse(tidp0, this.records, Errors.NONE, 100L, 0)); + + subscriptions.seek(tp0, 3); + subscriptions.resume(tp0); + consumerClient.poll(time.timer(0)); + + assertTrue(fetcher.hasCompletedFetches(), "Should have 1 entry in completed fetches"); + Map>> fetchedRecords = fetchedRecords(); + assertEquals(0, fetchedRecords.size(), "Should not return any records because we seeked to a new offset"); + assertNull(fetchedRecords.get(tp0)); + assertFalse(fetcher.hasCompletedFetches(), "Should have no completed fetches"); + } + + @Test + public void testFetchNotLeaderOrFollower() { + buildFetcher(); + assignFromUser(singleton(tp0)); + subscriptions.seek(tp0, 0); + + assertEquals(1, fetcher.sendFetches()); + client.prepareResponse(fullFetchResponse(tidp0, this.records, Errors.NOT_LEADER_OR_FOLLOWER, 100L, 0)); + consumerClient.poll(time.timer(0)); + assertEquals(0, fetcher.fetchedRecords().size()); + assertEquals(0L, metadata.timeToNextUpdate(time.milliseconds())); + } + + @Test + public void testFetchUnknownTopicOrPartition() { + buildFetcher(); + assignFromUser(singleton(tp0)); + subscriptions.seek(tp0, 0); + + assertEquals(1, fetcher.sendFetches()); + client.prepareResponse(fullFetchResponse(tidp0, this.records, Errors.UNKNOWN_TOPIC_OR_PARTITION, 100L, 0)); + consumerClient.poll(time.timer(0)); + assertEquals(0, fetcher.fetchedRecords().size()); + assertEquals(0L, metadata.timeToNextUpdate(time.milliseconds())); + } + + @Test + public void testFetchUnknownTopicId() { + buildFetcher(); + assignFromUser(singleton(tp0)); + subscriptions.seek(tp0, 0); + + assertEquals(1, fetcher.sendFetches()); + client.prepareResponse(fullFetchResponse(tidp0, this.records, Errors.UNKNOWN_TOPIC_ID, -1L, 0)); + consumerClient.poll(time.timer(0)); + assertEquals(0, fetcher.fetchedRecords().size()); + assertEquals(0L, metadata.timeToNextUpdate(time.milliseconds())); + } + + @Test + public void testFetchSessionIdError() { + buildFetcher(); + assignFromUser(singleton(tp0)); + subscriptions.seek(tp0, 0); + + assertEquals(1, fetcher.sendFetches()); + client.prepareResponse(fetchResponseWithTopLevelError(tidp0, Errors.FETCH_SESSION_TOPIC_ID_ERROR, 0)); + consumerClient.poll(time.timer(0)); + assertEquals(0, fetcher.fetchedRecords().size()); + assertEquals(0L, metadata.timeToNextUpdate(time.milliseconds())); + } + + @Test + public void testFetchInconsistentTopicId() { + buildFetcher(); + assignFromUser(singleton(tp0)); + subscriptions.seek(tp0, 0); + + assertEquals(1, 
fetcher.sendFetches()); + client.prepareResponse(fullFetchResponse(tidp0, this.records, Errors.INCONSISTENT_TOPIC_ID, -1L, 0)); + consumerClient.poll(time.timer(0)); + assertEquals(0, fetcher.fetchedRecords().size()); + assertEquals(0L, metadata.timeToNextUpdate(time.milliseconds())); + } + + @Test + public void testFetchFencedLeaderEpoch() { + buildFetcher(); + assignFromUser(singleton(tp0)); + subscriptions.seek(tp0, 0); + + assertEquals(1, fetcher.sendFetches()); + client.prepareResponse(fullFetchResponse(tidp0, this.records, Errors.FENCED_LEADER_EPOCH, 100L, 0)); + consumerClient.poll(time.timer(0)); + + assertEquals(0, fetcher.fetchedRecords().size(), "Should not return any records"); + assertEquals(0L, metadata.timeToNextUpdate(time.milliseconds()), "Should have requested metadata update"); + } + + @Test + public void testFetchUnknownLeaderEpoch() { + buildFetcher(); + assignFromUser(singleton(tp0)); + subscriptions.seek(tp0, 0); + + assertEquals(1, fetcher.sendFetches()); + client.prepareResponse(fullFetchResponse(tidp0, this.records, Errors.UNKNOWN_LEADER_EPOCH, 100L, 0)); + consumerClient.poll(time.timer(0)); + + assertEquals(0, fetcher.fetchedRecords().size(), "Should not return any records"); + assertNotEquals(0L, metadata.timeToNextUpdate(time.milliseconds()), "Should not have requested metadata update"); + } + + @Test + public void testEpochSetInFetchRequest() { + buildFetcher(); + subscriptions.assignFromUser(singleton(tp0)); + MetadataResponse metadataResponse = RequestTestUtils.metadataUpdateWithIds("dummy", 1, + Collections.emptyMap(), Collections.singletonMap(topicName, 4), tp -> 99, topicIds); + client.updateMetadata(metadataResponse); + + subscriptions.seek(tp0, 10); + assertEquals(1, fetcher.sendFetches()); + + // Check for epoch in outgoing request + MockClient.RequestMatcher matcher = body -> { + if (body instanceof FetchRequest) { + FetchRequest fetchRequest = (FetchRequest) body; + fetchRequest.fetchData(topicNames).values().forEach(partitionData -> { + assertTrue(partitionData.currentLeaderEpoch.isPresent(), "Expected Fetcher to set leader epoch in request"); + assertEquals(99, partitionData.currentLeaderEpoch.get().longValue(), "Expected leader epoch to match epoch from metadata update"); + }); + return true; + } else { + fail("Should have seen FetchRequest"); + return false; + } + }; + client.prepareResponse(matcher, fullFetchResponse(tidp0, this.records, Errors.NONE, 100L, 0)); + consumerClient.pollNoWakeup(); + } + + @Test + public void testFetchOffsetOutOfRange() { + buildFetcher(); + assignFromUser(singleton(tp0)); + subscriptions.seek(tp0, 0); + + assertEquals(1, fetcher.sendFetches()); + client.prepareResponse(fullFetchResponse(tidp0, this.records, Errors.OFFSET_OUT_OF_RANGE, 100L, 0)); + consumerClient.poll(time.timer(0)); + assertEquals(0, fetcher.fetchedRecords().size()); + assertTrue(subscriptions.isOffsetResetNeeded(tp0)); + assertNull(subscriptions.validPosition(tp0)); + assertNull(subscriptions.position(tp0)); + } + + @Test + public void testStaleOutOfRangeError() { + // verify that an out of range error which arrives after a seek + // does not cause us to reset our position or throw an exception + buildFetcher(); + assignFromUser(singleton(tp0)); + subscriptions.seek(tp0, 0); + + assertEquals(1, fetcher.sendFetches()); + client.prepareResponse(fullFetchResponse(tidp0, this.records, Errors.OFFSET_OUT_OF_RANGE, 100L, 0)); + subscriptions.seek(tp0, 1); + consumerClient.poll(time.timer(0)); + assertEquals(0, fetcher.fetchedRecords().size()); + 
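+        // The seek to offset 1 happened before the response was processed, so the out-of-range error is stale and must not trigger a reset.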
assertFalse(subscriptions.isOffsetResetNeeded(tp0)); + assertEquals(1, subscriptions.position(tp0).offset); + } + + @Test + public void testFetchedRecordsAfterSeek() { + buildFetcher(OffsetResetStrategy.NONE, new ByteArrayDeserializer(), + new ByteArrayDeserializer(), 2, IsolationLevel.READ_UNCOMMITTED); + + assignFromUser(singleton(tp0)); + subscriptions.seek(tp0, 0); + + assertTrue(fetcher.sendFetches() > 0); + client.prepareResponse(fullFetchResponse(tidp0, this.records, Errors.OFFSET_OUT_OF_RANGE, 100L, 0)); + consumerClient.poll(time.timer(0)); + assertFalse(subscriptions.isOffsetResetNeeded(tp0)); + subscriptions.seek(tp0, 2); + assertEquals(0, fetcher.fetchedRecords().size()); + } + + @Test + public void testFetchOffsetOutOfRangeException() { + buildFetcher(OffsetResetStrategy.NONE, new ByteArrayDeserializer(), + new ByteArrayDeserializer(), 2, IsolationLevel.READ_UNCOMMITTED); + + assignFromUser(singleton(tp0)); + subscriptions.seek(tp0, 0); + + fetcher.sendFetches(); + client.prepareResponse(fullFetchResponse(tidp0, this.records, Errors.OFFSET_OUT_OF_RANGE, 100L, 0)); + consumerClient.poll(time.timer(0)); + + assertFalse(subscriptions.isOffsetResetNeeded(tp0)); + for (int i = 0; i < 2; i++) { + OffsetOutOfRangeException e = assertThrows(OffsetOutOfRangeException.class, () -> + fetcher.fetchedRecords()); + assertEquals(singleton(tp0), e.offsetOutOfRangePartitions().keySet()); + assertEquals(0L, e.offsetOutOfRangePartitions().get(tp0).longValue()); + } + } + + @Test + public void testFetchPositionAfterException() { + // verify the advancement in the next fetch offset equals to the number of fetched records when + // some fetched partitions cause Exception. This ensures that consumer won't lose record upon exception + buildFetcher(OffsetResetStrategy.NONE, new ByteArrayDeserializer(), + new ByteArrayDeserializer(), Integer.MAX_VALUE, IsolationLevel.READ_UNCOMMITTED); + assignFromUser(mkSet(tp0, tp1)); + subscriptions.seek(tp0, 1); + subscriptions.seek(tp1, 1); + + assertEquals(1, fetcher.sendFetches()); + + Map partitions = new LinkedHashMap<>(); + partitions.put(tidp1, new FetchResponseData.PartitionData() + .setPartitionIndex(tp1.partition()) + .setHighWatermark(100) + .setRecords(records)); + partitions.put(tidp0, new FetchResponseData.PartitionData() + .setPartitionIndex(tp0.partition()) + .setErrorCode(Errors.OFFSET_OUT_OF_RANGE.code()) + .setHighWatermark(100)); + client.prepareResponse(FetchResponse.of(Errors.NONE, 0, INVALID_SESSION_ID, new LinkedHashMap<>(partitions))); + consumerClient.poll(time.timer(0)); + + List> allFetchedRecords = new ArrayList<>(); + fetchRecordsInto(allFetchedRecords); + + assertEquals(1, subscriptions.position(tp0).offset); + assertEquals(4, subscriptions.position(tp1).offset); + assertEquals(3, allFetchedRecords.size()); + + OffsetOutOfRangeException e = assertThrows(OffsetOutOfRangeException.class, () -> + fetchRecordsInto(allFetchedRecords)); + + assertEquals(singleton(tp0), e.offsetOutOfRangePartitions().keySet()); + assertEquals(1L, e.offsetOutOfRangePartitions().get(tp0).longValue()); + + assertEquals(1, subscriptions.position(tp0).offset); + assertEquals(4, subscriptions.position(tp1).offset); + assertEquals(3, allFetchedRecords.size()); + } + + private void fetchRecordsInto(List> allFetchedRecords) { + Map>> fetchedRecords = fetchedRecords(); + fetchedRecords.values().forEach(allFetchedRecords::addAll); + } + + @Test + public void testCompletedFetchRemoval() { + // Ensure the removal of completed fetches that cause an Exception if and only 
if they contain empty records.
+        buildFetcher(OffsetResetStrategy.NONE, new ByteArrayDeserializer(),
+            new ByteArrayDeserializer(), Integer.MAX_VALUE, IsolationLevel.READ_UNCOMMITTED);
+        assignFromUser(mkSet(tp0, tp1, tp2, tp3));
+
+        subscriptions.seek(tp0, 1);
+        subscriptions.seek(tp1, 1);
+        subscriptions.seek(tp2, 1);
+        subscriptions.seek(tp3, 1);
+
+        assertEquals(1, fetcher.sendFetches());
+
+        Map<TopicIdPartition, FetchResponseData.PartitionData> partitions = new LinkedHashMap<>();
+        partitions.put(tidp1, new FetchResponseData.PartitionData()
+            .setPartitionIndex(tp1.partition())
+            .setHighWatermark(100)
+            .setRecords(records));
+        partitions.put(tidp0, new FetchResponseData.PartitionData()
+            .setPartitionIndex(tp0.partition())
+            .setErrorCode(Errors.OFFSET_OUT_OF_RANGE.code())
+            .setHighWatermark(100));
+        partitions.put(tidp2, new FetchResponseData.PartitionData()
+            .setPartitionIndex(tp2.partition())
+            .setHighWatermark(100)
+            .setLastStableOffset(4)
+            .setLogStartOffset(0)
+            .setRecords(nextRecords));
+        partitions.put(tidp3, new FetchResponseData.PartitionData()
+            .setPartitionIndex(tp3.partition())
+            .setHighWatermark(100)
+            .setLastStableOffset(4)
+            .setLogStartOffset(0)
+            .setRecords(partialRecords));
+        client.prepareResponse(FetchResponse.of(Errors.NONE, 0, INVALID_SESSION_ID, new LinkedHashMap<>(partitions)));
+        consumerClient.poll(time.timer(0));
+
+        List<ConsumerRecord<byte[], byte[]>> fetchedRecords = new ArrayList<>();
+        Map<TopicPartition, List<ConsumerRecord<byte[], byte[]>>> recordsByPartition = fetchedRecords();
+        for (List<ConsumerRecord<byte[], byte[]>> records : recordsByPartition.values())
+            fetchedRecords.addAll(records);
+
+        assertEquals(fetchedRecords.size(), subscriptions.position(tp1).offset - 1);
+        assertEquals(4, subscriptions.position(tp1).offset);
+        assertEquals(3, fetchedRecords.size());
+
+        List<OffsetOutOfRangeException> oorExceptions = new ArrayList<>();
+        try {
+            recordsByPartition = fetchedRecords();
+            for (List<ConsumerRecord<byte[], byte[]>> records : recordsByPartition.values())
+                fetchedRecords.addAll(records);
+        } catch (OffsetOutOfRangeException oor) {
+            oorExceptions.add(oor);
+        }
+
+        // Should have received one OffsetOutOfRangeException for partition tp0
+        assertEquals(1, oorExceptions.size());
+        OffsetOutOfRangeException oor = oorExceptions.get(0);
+        assertTrue(oor.offsetOutOfRangePartitions().containsKey(tp0));
+        assertEquals(oor.offsetOutOfRangePartitions().size(), 1);
+
+        recordsByPartition = fetchedRecords();
+        for (List<ConsumerRecord<byte[], byte[]>> records : recordsByPartition.values())
+            fetchedRecords.addAll(records);
+
+        // Should not have received an Exception for tp2.
+        assertEquals(6, subscriptions.position(tp2).offset);
+        assertEquals(5, fetchedRecords.size());
+
+        int numExceptionsExpected = 3;
+        List<KafkaException> kafkaExceptions = new ArrayList<>();
+        for (int i = 1; i <= numExceptionsExpected; i++) {
+            try {
+                recordsByPartition = fetchedRecords();
+                for (List<ConsumerRecord<byte[], byte[]>> records : recordsByPartition.values())
+                    fetchedRecords.addAll(records);
+            } catch (KafkaException e) {
+                kafkaExceptions.add(e);
+            }
+        }
+        // Should have received as many as numExceptionsExpected KafkaExceptions for tp3.
+ assertEquals(numExceptionsExpected, kafkaExceptions.size()); + } + + @Test + public void testSeekBeforeException() { + buildFetcher(OffsetResetStrategy.NONE, new ByteArrayDeserializer(), + new ByteArrayDeserializer(), 2, IsolationLevel.READ_UNCOMMITTED); + + assignFromUser(mkSet(tp0)); + subscriptions.seek(tp0, 1); + assertEquals(1, fetcher.sendFetches()); + Map partitions = new HashMap<>(); + partitions.put(tidp0, new FetchResponseData.PartitionData() + .setPartitionIndex(tp0.partition()) + .setHighWatermark(100) + .setRecords(records)); + client.prepareResponse(fullFetchResponse(tidp0, this.records, Errors.NONE, 100L, 0)); + consumerClient.poll(time.timer(0)); + + assertEquals(2, fetcher.fetchedRecords().get(tp0).size()); + + subscriptions.assignFromUser(mkSet(tp0, tp1)); + subscriptions.seekUnvalidated(tp1, new SubscriptionState.FetchPosition(1, Optional.empty(), metadata.currentLeader(tp1))); + + assertEquals(1, fetcher.sendFetches()); + partitions = new HashMap<>(); + partitions.put(tidp1, new FetchResponseData.PartitionData() + .setPartitionIndex(tp1.partition()) + .setErrorCode(Errors.OFFSET_OUT_OF_RANGE.code()) + .setHighWatermark(100)); + client.prepareResponse(FetchResponse.of(Errors.NONE, 0, INVALID_SESSION_ID, new LinkedHashMap<>(partitions))); + consumerClient.poll(time.timer(0)); + assertEquals(1, fetcher.fetchedRecords().get(tp0).size()); + + subscriptions.seek(tp1, 10); + // Should not throw OffsetOutOfRangeException after the seek + assertEquals(0, fetcher.fetchedRecords().size()); + } + + @Test + public void testFetchDisconnected() { + buildFetcher(); + + assignFromUser(singleton(tp0)); + subscriptions.seek(tp0, 0); + + assertEquals(1, fetcher.sendFetches()); + client.prepareResponse(fullFetchResponse(tidp0, this.records, Errors.NONE, 100L, 0), true); + consumerClient.poll(time.timer(0)); + assertEquals(0, fetcher.fetchedRecords().size()); + + // disconnects should have no affect on subscription state + assertFalse(subscriptions.isOffsetResetNeeded(tp0)); + assertTrue(subscriptions.isFetchable(tp0)); + assertEquals(0, subscriptions.position(tp0).offset); + } + + @Test + public void testUpdateFetchPositionNoOpWithPositionSet() { + buildFetcher(); + assignFromUser(singleton(tp0)); + subscriptions.seek(tp0, 5L); + + fetcher.resetOffsetsIfNeeded(); + assertFalse(client.hasInFlightRequests()); + assertTrue(subscriptions.isFetchable(tp0)); + assertEquals(5, subscriptions.position(tp0).offset); + } + + @Test + public void testUpdateFetchPositionResetToDefaultOffset() { + buildFetcher(); + assignFromUser(singleton(tp0)); + subscriptions.requestOffsetReset(tp0); + + client.prepareResponse(listOffsetRequestMatcher(ListOffsetsRequest.EARLIEST_TIMESTAMP, + validLeaderEpoch), listOffsetResponse(Errors.NONE, 1L, 5L)); + fetcher.resetOffsetsIfNeeded(); + consumerClient.pollNoWakeup(); + assertFalse(subscriptions.isOffsetResetNeeded(tp0)); + assertTrue(subscriptions.isFetchable(tp0)); + assertEquals(5, subscriptions.position(tp0).offset); + } + + @Test + public void testUpdateFetchPositionResetToLatestOffset() { + buildFetcher(); + assignFromUser(singleton(tp0)); + subscriptions.requestOffsetReset(tp0, OffsetResetStrategy.LATEST); + + client.updateMetadata(initialUpdateResponse); + + client.prepareResponse(listOffsetRequestMatcher(ListOffsetsRequest.LATEST_TIMESTAMP), + listOffsetResponse(Errors.NONE, 1L, 5L)); + fetcher.resetOffsetsIfNeeded(); + consumerClient.pollNoWakeup(); + assertFalse(subscriptions.isOffsetResetNeeded(tp0)); + assertTrue(subscriptions.isFetchable(tp0)); + 
assertEquals(5, subscriptions.position(tp0).offset); + } + + /** + * Make sure the client behaves appropriately when receiving an exception for unavailable offsets + */ + @Test + public void testFetchOffsetErrors() { + buildFetcher(); + assignFromUser(singleton(tp0)); + subscriptions.requestOffsetReset(tp0, OffsetResetStrategy.LATEST); + + // Fail with OFFSET_NOT_AVAILABLE + client.prepareResponse(listOffsetRequestMatcher(ListOffsetsRequest.LATEST_TIMESTAMP, + validLeaderEpoch), listOffsetResponse(Errors.OFFSET_NOT_AVAILABLE, 1L, 5L), false); + fetcher.resetOffsetsIfNeeded(); + consumerClient.pollNoWakeup(); + assertFalse(subscriptions.hasValidPosition(tp0)); + assertTrue(subscriptions.isOffsetResetNeeded(tp0)); + assertFalse(subscriptions.isFetchable(tp0)); + + // Fail with LEADER_NOT_AVAILABLE + time.sleep(retryBackoffMs); + client.prepareResponse(listOffsetRequestMatcher(ListOffsetsRequest.LATEST_TIMESTAMP, + validLeaderEpoch), listOffsetResponse(Errors.LEADER_NOT_AVAILABLE, 1L, 5L), false); + fetcher.resetOffsetsIfNeeded(); + consumerClient.pollNoWakeup(); + assertFalse(subscriptions.hasValidPosition(tp0)); + assertTrue(subscriptions.isOffsetResetNeeded(tp0)); + assertFalse(subscriptions.isFetchable(tp0)); + + // Back to normal + time.sleep(retryBackoffMs); + client.prepareResponse(listOffsetRequestMatcher(ListOffsetsRequest.LATEST_TIMESTAMP), + listOffsetResponse(Errors.NONE, 1L, 5L), false); + fetcher.resetOffsetsIfNeeded(); + consumerClient.pollNoWakeup(); + assertTrue(subscriptions.hasValidPosition(tp0)); + assertFalse(subscriptions.isOffsetResetNeeded(tp0)); + assertTrue(subscriptions.isFetchable(tp0)); + assertEquals(subscriptions.position(tp0).offset, 5L); + } + + @Test + public void testListOffsetSendsReadUncommitted() { + testListOffsetsSendsIsolationLevel(IsolationLevel.READ_UNCOMMITTED); + } + + @Test + public void testListOffsetSendsReadCommitted() { + testListOffsetsSendsIsolationLevel(IsolationLevel.READ_COMMITTED); + } + + private void testListOffsetsSendsIsolationLevel(IsolationLevel isolationLevel) { + buildFetcher(OffsetResetStrategy.EARLIEST, new ByteArrayDeserializer(), new ByteArrayDeserializer(), + Integer.MAX_VALUE, isolationLevel); + + assignFromUser(singleton(tp0)); + subscriptions.requestOffsetReset(tp0, OffsetResetStrategy.LATEST); + + client.prepareResponse(body -> { + ListOffsetsRequest request = (ListOffsetsRequest) body; + return request.isolationLevel() == isolationLevel; + }, listOffsetResponse(Errors.NONE, 1L, 5L)); + fetcher.resetOffsetsIfNeeded(); + consumerClient.pollNoWakeup(); + + assertFalse(subscriptions.isOffsetResetNeeded(tp0)); + assertTrue(subscriptions.isFetchable(tp0)); + assertEquals(5, subscriptions.position(tp0).offset); + } + + @Test + public void testResetOffsetsSkipsBlackedOutConnections() { + buildFetcher(); + assignFromUser(singleton(tp0)); + subscriptions.requestOffsetReset(tp0, OffsetResetStrategy.EARLIEST); + + // Check that we skip sending the ListOffset request when the node is blacked out + client.updateMetadata(initialUpdateResponse); + Node node = initialUpdateResponse.brokers().iterator().next(); + client.backoff(node, 500); + fetcher.resetOffsetsIfNeeded(); + assertEquals(0, consumerClient.pendingRequestCount()); + consumerClient.pollNoWakeup(); + assertTrue(subscriptions.isOffsetResetNeeded(tp0)); + assertEquals(OffsetResetStrategy.EARLIEST, subscriptions.resetStrategy(tp0)); + + time.sleep(500); + client.prepareResponse(listOffsetRequestMatcher(ListOffsetsRequest.EARLIEST_TIMESTAMP), + listOffsetResponse(Errors.NONE, 
1L, 5L)); + fetcher.resetOffsetsIfNeeded(); + consumerClient.pollNoWakeup(); + + assertFalse(subscriptions.isOffsetResetNeeded(tp0)); + assertTrue(subscriptions.isFetchable(tp0)); + assertEquals(5, subscriptions.position(tp0).offset); + } + + @Test + public void testUpdateFetchPositionResetToEarliestOffset() { + buildFetcher(); + assignFromUser(singleton(tp0)); + subscriptions.requestOffsetReset(tp0, OffsetResetStrategy.EARLIEST); + + client.prepareResponse(listOffsetRequestMatcher(ListOffsetsRequest.EARLIEST_TIMESTAMP, + validLeaderEpoch), listOffsetResponse(Errors.NONE, 1L, 5L)); + fetcher.resetOffsetsIfNeeded(); + consumerClient.pollNoWakeup(); + + assertFalse(subscriptions.isOffsetResetNeeded(tp0)); + assertTrue(subscriptions.isFetchable(tp0)); + assertEquals(5, subscriptions.position(tp0).offset); + } + + @Test + public void testResetOffsetsMetadataRefresh() { + buildFetcher(); + assignFromUser(singleton(tp0)); + subscriptions.requestOffsetReset(tp0, OffsetResetStrategy.LATEST); + + // First fetch fails with stale metadata + client.prepareResponse(listOffsetRequestMatcher(ListOffsetsRequest.LATEST_TIMESTAMP, + validLeaderEpoch), listOffsetResponse(Errors.NOT_LEADER_OR_FOLLOWER, 1L, 5L), false); + fetcher.resetOffsetsIfNeeded(); + consumerClient.pollNoWakeup(); + assertFalse(subscriptions.hasValidPosition(tp0)); + + // Expect a metadata refresh + client.prepareMetadataUpdate(initialUpdateResponse); + consumerClient.pollNoWakeup(); + assertFalse(client.hasPendingMetadataUpdates()); + + // Next fetch succeeds + time.sleep(retryBackoffMs); + client.prepareResponse(listOffsetRequestMatcher(ListOffsetsRequest.LATEST_TIMESTAMP), + listOffsetResponse(Errors.NONE, 1L, 5L)); + fetcher.resetOffsetsIfNeeded(); + consumerClient.pollNoWakeup(); + + assertFalse(subscriptions.isOffsetResetNeeded(tp0)); + assertTrue(subscriptions.isFetchable(tp0)); + assertEquals(5, subscriptions.position(tp0).offset); + } + + @Test + public void testListOffsetNoUpdateMissingEpoch() { + buildFetcher(); + + // Set up metadata with no leader epoch + subscriptions.assignFromUser(singleton(tp0)); + MetadataResponse metadataWithNoLeaderEpochs = RequestTestUtils.metadataUpdateWithIds( + "kafka-cluster", 1, Collections.emptyMap(), singletonMap(topicName, 4), tp -> null, topicIds); + client.updateMetadata(metadataWithNoLeaderEpochs); + + // Return a ListOffsets response with leaderEpoch=1, we should ignore it + subscriptions.requestOffsetReset(tp0, OffsetResetStrategy.LATEST); + client.prepareResponse(listOffsetRequestMatcher(ListOffsetsRequest.LATEST_TIMESTAMP), + listOffsetResponse(tp0, Errors.NONE, 1L, 5L, 1)); + fetcher.resetOffsetsIfNeeded(); + consumerClient.pollNoWakeup(); + + // Reset should be satisfied and no metadata update requested + assertFalse(subscriptions.isOffsetResetNeeded(tp0)); + assertFalse(metadata.updateRequested()); + assertFalse(metadata.lastSeenLeaderEpoch(tp0).isPresent()); + } + + @Test + public void testListOffsetUpdateEpoch() { + buildFetcher(); + + // Set up metadata with leaderEpoch=1 + subscriptions.assignFromUser(singleton(tp0)); + MetadataResponse metadataWithLeaderEpochs = RequestTestUtils.metadataUpdateWithIds( + "kafka-cluster", 1, Collections.emptyMap(), singletonMap(topicName, 4), tp -> 1, topicIds); + client.updateMetadata(metadataWithLeaderEpochs); + + // Reset offsets to trigger ListOffsets call + subscriptions.requestOffsetReset(tp0, OffsetResetStrategy.LATEST); + + // Now we see a ListOffsets with leaderEpoch=2 epoch, we trigger a metadata update + 
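// The epoch in the response (2) is newer than the cached epoch (1), so the fetcher should record it and request a metadata refresh. +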
client.prepareResponse(listOffsetRequestMatcher(ListOffsetsRequest.LATEST_TIMESTAMP, 1), + listOffsetResponse(tp0, Errors.NONE, 1L, 5L, 2)); + fetcher.resetOffsetsIfNeeded(); + consumerClient.pollNoWakeup(); + + assertFalse(subscriptions.isOffsetResetNeeded(tp0)); + assertTrue(metadata.updateRequested()); + assertOptional(metadata.lastSeenLeaderEpoch(tp0), epoch -> assertEquals((long) epoch, 2)); + } + + @Test + public void testUpdateFetchPositionDisconnect() { + buildFetcher(); + assignFromUser(singleton(tp0)); + subscriptions.requestOffsetReset(tp0, OffsetResetStrategy.LATEST); + + // First request gets a disconnect + client.prepareResponse(listOffsetRequestMatcher(ListOffsetsRequest.LATEST_TIMESTAMP, + validLeaderEpoch), listOffsetResponse(Errors.NONE, 1L, 5L), true); + fetcher.resetOffsetsIfNeeded(); + consumerClient.pollNoWakeup(); + assertFalse(subscriptions.hasValidPosition(tp0)); + + // Expect a metadata refresh + client.prepareMetadataUpdate(initialUpdateResponse); + consumerClient.pollNoWakeup(); + assertFalse(client.hasPendingMetadataUpdates()); + + // No retry until the backoff passes + fetcher.resetOffsetsIfNeeded(); + consumerClient.pollNoWakeup(); + assertFalse(client.hasInFlightRequests()); + assertFalse(subscriptions.hasValidPosition(tp0)); + + // Next one succeeds + time.sleep(retryBackoffMs); + client.prepareResponse(listOffsetRequestMatcher(ListOffsetsRequest.LATEST_TIMESTAMP), + listOffsetResponse(Errors.NONE, 1L, 5L)); + fetcher.resetOffsetsIfNeeded(); + consumerClient.pollNoWakeup(); + + assertFalse(subscriptions.isOffsetResetNeeded(tp0)); + assertTrue(subscriptions.isFetchable(tp0)); + assertEquals(5, subscriptions.position(tp0).offset); + } + + @Test + public void testAssignmentChangeWithInFlightReset() { + buildFetcher(); + assignFromUser(singleton(tp0)); + subscriptions.requestOffsetReset(tp0, OffsetResetStrategy.LATEST); + + // Send the ListOffsets request to reset the position + fetcher.resetOffsetsIfNeeded(); + consumerClient.pollNoWakeup(); + assertFalse(subscriptions.hasValidPosition(tp0)); + assertTrue(client.hasInFlightRequests()); + + // Now we have an assignment change + assignFromUser(singleton(tp1)); + + // The response returns and is discarded + client.respond(listOffsetResponse(Errors.NONE, 1L, 5L)); + consumerClient.pollNoWakeup(); + + assertFalse(client.hasPendingResponses()); + assertFalse(client.hasInFlightRequests()); + assertFalse(subscriptions.isAssigned(tp0)); + } + + @Test + public void testSeekWithInFlightReset() { + buildFetcher(); + assignFromUser(singleton(tp0)); + subscriptions.requestOffsetReset(tp0, OffsetResetStrategy.LATEST); + + // Send the ListOffsets request to reset the position + fetcher.resetOffsetsIfNeeded(); + consumerClient.pollNoWakeup(); + assertFalse(subscriptions.hasValidPosition(tp0)); + assertTrue(client.hasInFlightRequests()); + + // Now we get a seek from the user + subscriptions.seek(tp0, 237); + + // The response returns and is discarded + client.respond(listOffsetResponse(Errors.NONE, 1L, 5L)); + consumerClient.pollNoWakeup(); + + assertFalse(client.hasPendingResponses()); + assertFalse(client.hasInFlightRequests()); + assertEquals(237L, subscriptions.position(tp0).offset); + } + + private boolean listOffsetMatchesExpectedReset( + TopicPartition tp, + OffsetResetStrategy strategy, + AbstractRequest request + ) { + assertTrue(request instanceof ListOffsetsRequest); + + ListOffsetsRequest req = (ListOffsetsRequest) request; + assertEquals(singleton(tp.topic()), req.data().topics().stream() + 
.map(ListOffsetsTopic::name).collect(Collectors.toSet())); + + ListOffsetsTopic listTopic = req.data().topics().get(0); + assertEquals(singleton(tp.partition()), listTopic.partitions().stream() + .map(ListOffsetsPartition::partitionIndex).collect(Collectors.toSet())); + + ListOffsetsPartition listPartition = listTopic.partitions().get(0); + if (strategy == OffsetResetStrategy.EARLIEST) { + assertEquals(ListOffsetsRequest.EARLIEST_TIMESTAMP, listPartition.timestamp()); + } else if (strategy == OffsetResetStrategy.LATEST) { + assertEquals(ListOffsetsRequest.LATEST_TIMESTAMP, listPartition.timestamp()); + } + return true; + } + + @Test + public void testEarlierOffsetResetArrivesLate() { + buildFetcher(); + assignFromUser(singleton(tp0)); + + subscriptions.requestOffsetReset(tp0, OffsetResetStrategy.EARLIEST); + fetcher.resetOffsetsIfNeeded(); + + client.prepareResponse(req -> { + if (listOffsetMatchesExpectedReset(tp0, OffsetResetStrategy.EARLIEST, req)) { + // Before the response is handled, we get a request to reset to the latest offset + subscriptions.requestOffsetReset(tp0, OffsetResetStrategy.LATEST); + return true; + } else { + return false; + } + }, listOffsetResponse(Errors.NONE, 1L, 0L)); + consumerClient.pollNoWakeup(); + + // The list offset result should be ignored + assertTrue(subscriptions.isOffsetResetNeeded(tp0)); + assertEquals(OffsetResetStrategy.LATEST, subscriptions.resetStrategy(tp0)); + + fetcher.resetOffsetsIfNeeded(); + client.prepareResponse( + req -> listOffsetMatchesExpectedReset(tp0, OffsetResetStrategy.LATEST, req), + listOffsetResponse(Errors.NONE, 1L, 10L) + ); + consumerClient.pollNoWakeup(); + + assertFalse(subscriptions.isOffsetResetNeeded(tp0)); + assertEquals(10, subscriptions.position(tp0).offset); + } + + @Test + public void testChangeResetWithInFlightReset() { + buildFetcher(); + assignFromUser(singleton(tp0)); + subscriptions.requestOffsetReset(tp0, OffsetResetStrategy.LATEST); + + // Send the ListOffsets request to reset the position + fetcher.resetOffsetsIfNeeded(); + consumerClient.pollNoWakeup(); + assertFalse(subscriptions.hasValidPosition(tp0)); + assertTrue(client.hasInFlightRequests()); + + // Now we get a seek from the user + subscriptions.requestOffsetReset(tp0, OffsetResetStrategy.EARLIEST); + + // The response returns and is discarded + client.respond(listOffsetResponse(Errors.NONE, 1L, 5L)); + consumerClient.pollNoWakeup(); + + assertFalse(client.hasPendingResponses()); + assertFalse(client.hasInFlightRequests()); + assertTrue(subscriptions.isOffsetResetNeeded(tp0)); + assertEquals(OffsetResetStrategy.EARLIEST, subscriptions.resetStrategy(tp0)); + } + + @Test + public void testIdempotentResetWithInFlightReset() { + buildFetcher(); + assignFromUser(singleton(tp0)); + subscriptions.requestOffsetReset(tp0, OffsetResetStrategy.LATEST); + + // Send the ListOffsets request to reset the position + fetcher.resetOffsetsIfNeeded(); + consumerClient.pollNoWakeup(); + assertFalse(subscriptions.hasValidPosition(tp0)); + assertTrue(client.hasInFlightRequests()); + + // Now we get a seek from the user + subscriptions.requestOffsetReset(tp0, OffsetResetStrategy.LATEST); + + client.respond(listOffsetResponse(Errors.NONE, 1L, 5L)); + consumerClient.pollNoWakeup(); + + assertFalse(client.hasInFlightRequests()); + assertFalse(subscriptions.isOffsetResetNeeded(tp0)); + assertEquals(5L, subscriptions.position(tp0).offset); + } + + @Test + public void testRestOffsetsAuthorizationFailure() { + buildFetcher(); + assignFromUser(singleton(tp0)); + 
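// The prepared ListOffsets response fails with TOPIC_AUTHORIZATION_FAILED; the resulting TopicAuthorizationException should be raised once, then cleared, and the reset retried only after the backoff elapses. +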
subscriptions.requestOffsetReset(tp0, OffsetResetStrategy.LATEST); + + // First request gets a disconnect + client.prepareResponse(listOffsetRequestMatcher(ListOffsetsRequest.LATEST_TIMESTAMP, + validLeaderEpoch), listOffsetResponse(Errors.TOPIC_AUTHORIZATION_FAILED, -1, -1), false); + fetcher.resetOffsetsIfNeeded(); + consumerClient.pollNoWakeup(); + assertFalse(subscriptions.hasValidPosition(tp0)); + + try { + fetcher.resetOffsetsIfNeeded(); + fail("Expected authorization error to be raised"); + } catch (TopicAuthorizationException e) { + assertEquals(singleton(tp0.topic()), e.unauthorizedTopics()); + } + + // The exception should clear after being raised, but no retry until the backoff + fetcher.resetOffsetsIfNeeded(); + consumerClient.pollNoWakeup(); + assertFalse(client.hasInFlightRequests()); + assertFalse(subscriptions.hasValidPosition(tp0)); + + // Next one succeeds + time.sleep(retryBackoffMs); + client.prepareResponse(listOffsetRequestMatcher(ListOffsetsRequest.LATEST_TIMESTAMP), + listOffsetResponse(Errors.NONE, 1L, 5L)); + fetcher.resetOffsetsIfNeeded(); + consumerClient.pollNoWakeup(); + + assertFalse(subscriptions.isOffsetResetNeeded(tp0)); + assertTrue(subscriptions.isFetchable(tp0)); + assertEquals(5, subscriptions.position(tp0).offset); + } + + @Test + public void testUpdateFetchPositionOfPausedPartitionsRequiringOffsetReset() { + buildFetcher(); + assignFromUser(singleton(tp0)); + subscriptions.pause(tp0); // paused partition does not have a valid position + subscriptions.requestOffsetReset(tp0, OffsetResetStrategy.LATEST); + + client.prepareResponse(listOffsetRequestMatcher(ListOffsetsRequest.LATEST_TIMESTAMP, + validLeaderEpoch), listOffsetResponse(Errors.NONE, 1L, 10L)); + fetcher.resetOffsetsIfNeeded(); + consumerClient.pollNoWakeup(); + + assertFalse(subscriptions.isOffsetResetNeeded(tp0)); + assertFalse(subscriptions.isFetchable(tp0)); // because tp is paused + assertTrue(subscriptions.hasValidPosition(tp0)); + assertEquals(10, subscriptions.position(tp0).offset); + } + + @Test + public void testUpdateFetchPositionOfPausedPartitionsWithoutAValidPosition() { + buildFetcher(); + assignFromUser(singleton(tp0)); + subscriptions.requestOffsetReset(tp0); + subscriptions.pause(tp0); // paused partition does not have a valid position + + fetcher.resetOffsetsIfNeeded(); + consumerClient.pollNoWakeup(); + + assertTrue(subscriptions.isOffsetResetNeeded(tp0)); + assertFalse(subscriptions.isFetchable(tp0)); // because tp is paused + assertFalse(subscriptions.hasValidPosition(tp0)); + } + + @Test + public void testUpdateFetchPositionOfPausedPartitionsWithAValidPosition() { + buildFetcher(); + assignFromUser(singleton(tp0)); + subscriptions.seek(tp0, 10); + subscriptions.pause(tp0); // paused partition already has a valid position + + fetcher.resetOffsetsIfNeeded(); + + assertFalse(subscriptions.isOffsetResetNeeded(tp0)); + assertFalse(subscriptions.isFetchable(tp0)); // because tp is paused + assertTrue(subscriptions.hasValidPosition(tp0)); + assertEquals(10, subscriptions.position(tp0).offset); + } + + @Test + public void testGetAllTopics() { + // sending response before request, as getTopicMetadata is a blocking call + buildFetcher(); + assignFromUser(singleton(tp0)); + client.prepareResponse(newMetadataResponse(topicName, Errors.NONE)); + + Map> allTopics = fetcher.getAllTopicMetadata(time.timer(5000L)); + + assertEquals(initialUpdateResponse.topicMetadata().size(), allTopics.size()); + } + + @Test + public void testGetAllTopicsDisconnect() { + // first try gets a disconnect, 
next succeeds + buildFetcher(); + assignFromUser(singleton(tp0)); + client.prepareResponse(null, true); + client.prepareResponse(newMetadataResponse(topicName, Errors.NONE)); + Map> allTopics = fetcher.getAllTopicMetadata(time.timer(5000L)); + assertEquals(initialUpdateResponse.topicMetadata().size(), allTopics.size()); + } + + @Test + public void testGetAllTopicsTimeout() { + // since no response is prepared, the request should timeout + buildFetcher(); + assignFromUser(singleton(tp0)); + assertThrows(TimeoutException.class, () -> fetcher.getAllTopicMetadata(time.timer(50L))); + } + + @Test + public void testGetAllTopicsUnauthorized() { + buildFetcher(); + assignFromUser(singleton(tp0)); + client.prepareResponse(newMetadataResponse(topicName, Errors.TOPIC_AUTHORIZATION_FAILED)); + try { + fetcher.getAllTopicMetadata(time.timer(10L)); + fail(); + } catch (TopicAuthorizationException e) { + assertEquals(singleton(topicName), e.unauthorizedTopics()); + } + } + + @Test + public void testGetTopicMetadataInvalidTopic() { + buildFetcher(); + assignFromUser(singleton(tp0)); + client.prepareResponse(newMetadataResponse(topicName, Errors.INVALID_TOPIC_EXCEPTION)); + assertThrows(InvalidTopicException.class, () -> fetcher.getTopicMetadata( + new MetadataRequest.Builder(Collections.singletonList(topicName), true), time.timer(5000L))); + } + + @Test + public void testGetTopicMetadataUnknownTopic() { + buildFetcher(); + assignFromUser(singleton(tp0)); + client.prepareResponse(newMetadataResponse(topicName, Errors.UNKNOWN_TOPIC_OR_PARTITION)); + + Map> topicMetadata = fetcher.getTopicMetadata( + new MetadataRequest.Builder(Collections.singletonList(topicName), true), time.timer(5000L)); + assertNull(topicMetadata.get(topicName)); + } + + @Test + public void testGetTopicMetadataLeaderNotAvailable() { + buildFetcher(); + assignFromUser(singleton(tp0)); + client.prepareResponse(newMetadataResponse(topicName, Errors.LEADER_NOT_AVAILABLE)); + client.prepareResponse(newMetadataResponse(topicName, Errors.NONE)); + + Map> topicMetadata = fetcher.getTopicMetadata( + new MetadataRequest.Builder(Collections.singletonList(topicName), true), time.timer(5000L)); + assertTrue(topicMetadata.containsKey(topicName)); + } + + @Test + public void testGetTopicMetadataOfflinePartitions() { + buildFetcher(); + assignFromUser(singleton(tp0)); + MetadataResponse originalResponse = newMetadataResponse(topicName, Errors.NONE); //baseline ok response + + //create a response based on the above one with all partitions being leaderless + List altTopics = new ArrayList<>(); + for (MetadataResponse.TopicMetadata item : originalResponse.topicMetadata()) { + List partitions = item.partitionMetadata(); + List altPartitions = new ArrayList<>(); + for (MetadataResponse.PartitionMetadata p : partitions) { + altPartitions.add(new MetadataResponse.PartitionMetadata( + p.error, + p.topicPartition, + Optional.empty(), //no leader + Optional.empty(), + p.replicaIds, + p.inSyncReplicaIds, + p.offlineReplicaIds + )); + } + MetadataResponse.TopicMetadata alteredTopic = new MetadataResponse.TopicMetadata( + item.error(), + item.topic(), + item.isInternal(), + altPartitions + ); + altTopics.add(alteredTopic); + } + Node controller = originalResponse.controller(); + MetadataResponse altered = RequestTestUtils.metadataResponse( + originalResponse.brokers(), + originalResponse.clusterId(), + controller != null ? 
controller.id() : MetadataResponse.NO_CONTROLLER_ID, + altTopics); + + client.prepareResponse(altered); + + Map> topicMetadata = + fetcher.getTopicMetadata(new MetadataRequest.Builder(Collections.singletonList(topicName), false), + time.timer(5000L)); + + assertNotNull(topicMetadata); + assertNotNull(topicMetadata.get(topicName)); + //noinspection ConstantConditions + assertEquals(metadata.fetch().partitionCountForTopic(topicName).longValue(), topicMetadata.get(topicName).size()); + } + + /* + * Send multiple requests. Verify that the client side quota metrics have the right values + */ + @Test + public void testQuotaMetrics() { + buildFetcher(); + + MockSelector selector = new MockSelector(time); + Sensor throttleTimeSensor = Fetcher.throttleTimeSensor(metrics, metricsRegistry); + Cluster cluster = TestUtils.singletonCluster("test", 1); + Node node = cluster.nodes().get(0); + NetworkClient client = new NetworkClient(selector, metadata, "mock", Integer.MAX_VALUE, + 1000, 1000, 64 * 1024, 64 * 1024, 1000, 10 * 1000, 127 * 1000, + time, true, new ApiVersions(), throttleTimeSensor, new LogContext()); + + ApiVersionsResponse apiVersionsResponse = ApiVersionsResponse.defaultApiVersionsResponse( + 400, ApiMessageType.ListenerType.ZK_BROKER); + ByteBuffer buffer = RequestTestUtils.serializeResponseWithHeader(apiVersionsResponse, ApiKeys.API_VERSIONS.latestVersion(), 0); + + selector.delayedReceive(new DelayedReceive(node.idString(), new NetworkReceive(node.idString(), buffer))); + while (!client.ready(node, time.milliseconds())) { + client.poll(1, time.milliseconds()); + // If a throttled response is received, advance the time to ensure progress. + time.sleep(client.throttleDelayMs(node, time.milliseconds())); + } + selector.clear(); + + for (int i = 1; i <= 3; i++) { + int throttleTimeMs = 100 * i; + FetchRequest.Builder builder = FetchRequest.Builder.forConsumer(ApiKeys.FETCH.latestVersion(), 100, 100, new LinkedHashMap<>()); + builder.rackId(""); + ClientRequest request = client.newClientRequest(node.idString(), builder, time.milliseconds(), true); + client.send(request, time.milliseconds()); + client.poll(1, time.milliseconds()); + FetchResponse response = fullFetchResponse(tidp0, nextRecords, Errors.NONE, i, throttleTimeMs); + buffer = RequestTestUtils.serializeResponseWithHeader(response, ApiKeys.FETCH.latestVersion(), request.correlationId()); + selector.completeReceive(new NetworkReceive(node.idString(), buffer)); + client.poll(1, time.milliseconds()); + // If a throttled response is received, advance the time to ensure progress. + time.sleep(client.throttleDelayMs(node, time.milliseconds())); + selector.clear(); + } + Map allMetrics = metrics.metrics(); + KafkaMetric avgMetric = allMetrics.get(metrics.metricInstance(metricsRegistry.fetchThrottleTimeAvg)); + KafkaMetric maxMetric = allMetrics.get(metrics.metricInstance(metricsRegistry.fetchThrottleTimeMax)); + // Throttle times are ApiVersions=400, Fetch=(100, 200, 300) + assertEquals(250, (Double) avgMetric.metricValue(), EPSILON); + assertEquals(400, (Double) maxMetric.metricValue(), EPSILON); + client.close(); + } + + /* + * Send multiple requests. 
Verify that the client side quota metrics have the right values + */ + @Test + public void testFetcherMetrics() { + buildFetcher(); + assignFromUser(singleton(tp0)); + subscriptions.seek(tp0, 0); + + MetricName maxLagMetric = metrics.metricInstance(metricsRegistry.recordsLagMax); + Map tags = new HashMap<>(); + tags.put("topic", tp0.topic()); + tags.put("partition", String.valueOf(tp0.partition())); + MetricName partitionLagMetric = metrics.metricName("records-lag", metricGroup, tags); + + Map allMetrics = metrics.metrics(); + KafkaMetric recordsFetchLagMax = allMetrics.get(maxLagMetric); + + // recordsFetchLagMax should be initialized to NaN + assertEquals(Double.NaN, (Double) recordsFetchLagMax.metricValue(), EPSILON); + + // recordsFetchLagMax should be hw - fetchOffset after receiving an empty FetchResponse + fetchRecords(tidp0, MemoryRecords.EMPTY, Errors.NONE, 100L, 0); + assertEquals(100, (Double) recordsFetchLagMax.metricValue(), EPSILON); + + KafkaMetric partitionLag = allMetrics.get(partitionLagMetric); + assertEquals(100, (Double) partitionLag.metricValue(), EPSILON); + + // recordsFetchLagMax should be hw - offset of the last message after receiving a non-empty FetchResponse + MemoryRecordsBuilder builder = MemoryRecords.builder(ByteBuffer.allocate(1024), CompressionType.NONE, + TimestampType.CREATE_TIME, 0L); + for (int v = 0; v < 3; v++) + builder.appendWithOffset(v, RecordBatch.NO_TIMESTAMP, "key".getBytes(), ("value-" + v).getBytes()); + fetchRecords(tidp0, builder.build(), Errors.NONE, 200L, 0); + assertEquals(197, (Double) recordsFetchLagMax.metricValue(), EPSILON); + assertEquals(197, (Double) partitionLag.metricValue(), EPSILON); + + // verify de-registration of partition lag + subscriptions.unsubscribe(); + fetcher.sendFetches(); + assertFalse(allMetrics.containsKey(partitionLagMetric)); + } + + @Test + public void testFetcherLeadMetric() { + buildFetcher(); + + assignFromUser(singleton(tp0)); + subscriptions.seek(tp0, 0); + + MetricName minLeadMetric = metrics.metricInstance(metricsRegistry.recordsLeadMin); + Map tags = new HashMap<>(2); + tags.put("topic", tp0.topic()); + tags.put("partition", String.valueOf(tp0.partition())); + MetricName partitionLeadMetric = metrics.metricName("records-lead", metricGroup, "", tags); + + Map allMetrics = metrics.metrics(); + KafkaMetric recordsFetchLeadMin = allMetrics.get(minLeadMetric); + + // recordsFetchLeadMin should be initialized to NaN + assertEquals(Double.NaN, (Double) recordsFetchLeadMin.metricValue(), EPSILON); + + // recordsFetchLeadMin should be position - logStartOffset after receiving an empty FetchResponse + fetchRecords(tidp0, MemoryRecords.EMPTY, Errors.NONE, 100L, -1L, 0L, 0); + assertEquals(0L, (Double) recordsFetchLeadMin.metricValue(), EPSILON); + + KafkaMetric partitionLead = allMetrics.get(partitionLeadMetric); + assertEquals(0L, (Double) partitionLead.metricValue(), EPSILON); + + // recordsFetchLeadMin should be position - logStartOffset after receiving a non-empty FetchResponse + MemoryRecordsBuilder builder = MemoryRecords.builder(ByteBuffer.allocate(1024), CompressionType.NONE, + TimestampType.CREATE_TIME, 0L); + for (int v = 0; v < 3; v++) { + builder.appendWithOffset(v, RecordBatch.NO_TIMESTAMP, "key".getBytes(), ("value-" + v).getBytes()); + } + fetchRecords(tidp0, builder.build(), Errors.NONE, 200L, -1L, 0L, 0); + assertEquals(0L, (Double) recordsFetchLeadMin.metricValue(), EPSILON); + assertEquals(3L, (Double) partitionLead.metricValue(), EPSILON); + + // verify de-registration of partition lag + 
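// (for this test the metric being removed is the per-partition records-lead metric) +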
subscriptions.unsubscribe(); + fetcher.sendFetches(); + assertFalse(allMetrics.containsKey(partitionLeadMetric)); + } + + @Test + public void testReadCommittedLagMetric() { + buildFetcher(OffsetResetStrategy.EARLIEST, new ByteArrayDeserializer(), + new ByteArrayDeserializer(), Integer.MAX_VALUE, IsolationLevel.READ_COMMITTED); + + assignFromUser(singleton(tp0)); + subscriptions.seek(tp0, 0); + + MetricName maxLagMetric = metrics.metricInstance(metricsRegistry.recordsLagMax); + + Map tags = new HashMap<>(); + tags.put("topic", tp0.topic()); + tags.put("partition", String.valueOf(tp0.partition())); + MetricName partitionLagMetric = metrics.metricName("records-lag", metricGroup, tags); + + Map allMetrics = metrics.metrics(); + KafkaMetric recordsFetchLagMax = allMetrics.get(maxLagMetric); + + // recordsFetchLagMax should be initialized to NaN + assertEquals(Double.NaN, (Double) recordsFetchLagMax.metricValue(), EPSILON); + + // recordsFetchLagMax should be lso - fetchOffset after receiving an empty FetchResponse + fetchRecords(tidp0, MemoryRecords.EMPTY, Errors.NONE, 100L, 50L, 0); + assertEquals(50, (Double) recordsFetchLagMax.metricValue(), EPSILON); + + KafkaMetric partitionLag = allMetrics.get(partitionLagMetric); + assertEquals(50, (Double) partitionLag.metricValue(), EPSILON); + + // recordsFetchLagMax should be lso - offset of the last message after receiving a non-empty FetchResponse + MemoryRecordsBuilder builder = MemoryRecords.builder(ByteBuffer.allocate(1024), CompressionType.NONE, + TimestampType.CREATE_TIME, 0L); + for (int v = 0; v < 3; v++) + builder.appendWithOffset(v, RecordBatch.NO_TIMESTAMP, "key".getBytes(), ("value-" + v).getBytes()); + fetchRecords(tidp0, builder.build(), Errors.NONE, 200L, 150L, 0); + assertEquals(147, (Double) recordsFetchLagMax.metricValue(), EPSILON); + assertEquals(147, (Double) partitionLag.metricValue(), EPSILON); + + // verify de-registration of partition lag + subscriptions.unsubscribe(); + fetcher.sendFetches(); + assertFalse(allMetrics.containsKey(partitionLagMetric)); + } + + @Test + public void testFetchResponseMetrics() { + buildFetcher(); + + String topic1 = "foo"; + String topic2 = "bar"; + TopicPartition tp1 = new TopicPartition(topic1, 0); + TopicPartition tp2 = new TopicPartition(topic2, 0); + + subscriptions.assignFromUser(mkSet(tp1, tp2)); + + Map partitionCounts = new HashMap<>(); + partitionCounts.put(topic1, 1); + partitionCounts.put(topic2, 1); + topicIds.put(topic1, Uuid.randomUuid()); + topicIds.put(topic2, Uuid.randomUuid()); + TopicIdPartition tidp1 = new TopicIdPartition(topicIds.get(topic1), tp1); + TopicIdPartition tidp2 = new TopicIdPartition(topicIds.get(topic2), tp2); + client.updateMetadata(RequestTestUtils.metadataUpdateWithIds(1, partitionCounts, tp -> validLeaderEpoch, topicIds)); + + int expectedBytes = 0; + LinkedHashMap fetchPartitionData = new LinkedHashMap<>(); + + for (TopicIdPartition tp : mkSet(tidp1, tidp2)) { + subscriptions.seek(tp.topicPartition(), 0); + + MemoryRecordsBuilder builder = MemoryRecords.builder(ByteBuffer.allocate(1024), CompressionType.NONE, + TimestampType.CREATE_TIME, 0L); + for (int v = 0; v < 3; v++) + builder.appendWithOffset(v, RecordBatch.NO_TIMESTAMP, "key".getBytes(), ("value-" + v).getBytes()); + MemoryRecords records = builder.build(); + for (Record record : records.records()) + expectedBytes += record.sizeInBytes(); + + fetchPartitionData.put(tp, new FetchResponseData.PartitionData() + .setPartitionIndex(tp.topicPartition().partition()) + .setHighWatermark(15) + 
.setLogStartOffset(0) + .setRecords(records)); + } + + assertEquals(1, fetcher.sendFetches()); + client.prepareResponse(FetchResponse.of(Errors.NONE, 0, INVALID_SESSION_ID, fetchPartitionData)); + consumerClient.poll(time.timer(0)); + + Map>> fetchedRecords = fetchedRecords(); + assertEquals(3, fetchedRecords.get(tp1).size()); + assertEquals(3, fetchedRecords.get(tp2).size()); + + Map allMetrics = metrics.metrics(); + KafkaMetric fetchSizeAverage = allMetrics.get(metrics.metricInstance(metricsRegistry.fetchSizeAvg)); + KafkaMetric recordsCountAverage = allMetrics.get(metrics.metricInstance(metricsRegistry.recordsPerRequestAvg)); + assertEquals(expectedBytes, (Double) fetchSizeAverage.metricValue(), EPSILON); + assertEquals(6, (Double) recordsCountAverage.metricValue(), EPSILON); + } + + @Test + public void testFetchResponseMetricsPartialResponse() { + buildFetcher(); + + assignFromUser(singleton(tp0)); + subscriptions.seek(tp0, 1); + + Map allMetrics = metrics.metrics(); + KafkaMetric fetchSizeAverage = allMetrics.get(metrics.metricInstance(metricsRegistry.fetchSizeAvg)); + KafkaMetric recordsCountAverage = allMetrics.get(metrics.metricInstance(metricsRegistry.recordsPerRequestAvg)); + + MemoryRecordsBuilder builder = MemoryRecords.builder(ByteBuffer.allocate(1024), CompressionType.NONE, + TimestampType.CREATE_TIME, 0L); + for (int v = 0; v < 3; v++) + builder.appendWithOffset(v, RecordBatch.NO_TIMESTAMP, "key".getBytes(), ("value-" + v).getBytes()); + MemoryRecords records = builder.build(); + + int expectedBytes = 0; + for (Record record : records.records()) { + if (record.offset() >= 1) + expectedBytes += record.sizeInBytes(); + } + + fetchRecords(tidp0, records, Errors.NONE, 100L, 0); + assertEquals(expectedBytes, (Double) fetchSizeAverage.metricValue(), EPSILON); + assertEquals(2, (Double) recordsCountAverage.metricValue(), EPSILON); + } + + @Test + public void testFetchResponseMetricsWithOnePartitionError() { + buildFetcher(); + assignFromUser(mkSet(tp0, tp1)); + subscriptions.seek(tp0, 0); + subscriptions.seek(tp1, 0); + + Map allMetrics = metrics.metrics(); + KafkaMetric fetchSizeAverage = allMetrics.get(metrics.metricInstance(metricsRegistry.fetchSizeAvg)); + KafkaMetric recordsCountAverage = allMetrics.get(metrics.metricInstance(metricsRegistry.recordsPerRequestAvg)); + + MemoryRecordsBuilder builder = MemoryRecords.builder(ByteBuffer.allocate(1024), CompressionType.NONE, + TimestampType.CREATE_TIME, 0L); + for (int v = 0; v < 3; v++) + builder.appendWithOffset(v, RecordBatch.NO_TIMESTAMP, "key".getBytes(), ("value-" + v).getBytes()); + MemoryRecords records = builder.build(); + + Map partitions = new HashMap<>(); + partitions.put(tidp0, new FetchResponseData.PartitionData() + .setPartitionIndex(tp0.partition()) + .setHighWatermark(100) + .setLogStartOffset(0) + .setRecords(records)); + partitions.put(tidp1, new FetchResponseData.PartitionData() + .setPartitionIndex(tp1.partition()) + .setErrorCode(Errors.OFFSET_OUT_OF_RANGE.code()) + .setHighWatermark(100) + .setLogStartOffset(0)); + + assertEquals(1, fetcher.sendFetches()); + client.prepareResponse(FetchResponse.of(Errors.NONE, 0, INVALID_SESSION_ID, new LinkedHashMap<>(partitions))); + consumerClient.poll(time.timer(0)); + fetcher.fetchedRecords(); + + int expectedBytes = 0; + for (Record record : records.records()) + expectedBytes += record.sizeInBytes(); + + assertEquals(expectedBytes, (Double) fetchSizeAverage.metricValue(), EPSILON); + assertEquals(3, (Double) recordsCountAverage.metricValue(), EPSILON); + } + + @Test + 
public void testFetchResponseMetricsWithOnePartitionAtTheWrongOffset() { + buildFetcher(); + + assignFromUser(mkSet(tp0, tp1)); + subscriptions.seek(tp0, 0); + subscriptions.seek(tp1, 0); + + Map allMetrics = metrics.metrics(); + KafkaMetric fetchSizeAverage = allMetrics.get(metrics.metricInstance(metricsRegistry.fetchSizeAvg)); + KafkaMetric recordsCountAverage = allMetrics.get(metrics.metricInstance(metricsRegistry.recordsPerRequestAvg)); + + // send the fetch and then seek to a new offset + assertEquals(1, fetcher.sendFetches()); + subscriptions.seek(tp1, 5); + + MemoryRecordsBuilder builder = MemoryRecords.builder(ByteBuffer.allocate(1024), CompressionType.NONE, + TimestampType.CREATE_TIME, 0L); + for (int v = 0; v < 3; v++) + builder.appendWithOffset(v, RecordBatch.NO_TIMESTAMP, "key".getBytes(), ("value-" + v).getBytes()); + MemoryRecords records = builder.build(); + + Map partitions = new HashMap<>(); + partitions.put(tidp0, new FetchResponseData.PartitionData() + .setPartitionIndex(tp0.partition()) + .setHighWatermark(100) + .setLogStartOffset(0) + .setRecords(records)); + partitions.put(tidp1, new FetchResponseData.PartitionData() + .setPartitionIndex(tp1.partition()) + .setHighWatermark(100) + .setLogStartOffset(0) + .setRecords(MemoryRecords.withRecords(CompressionType.NONE, new SimpleRecord("val".getBytes())))); + + client.prepareResponse(FetchResponse.of(Errors.NONE, 0, INVALID_SESSION_ID, new LinkedHashMap<>(partitions))); + consumerClient.poll(time.timer(0)); + fetcher.fetchedRecords(); + + // we should have ignored the record at the wrong offset + int expectedBytes = 0; + for (Record record : records.records()) + expectedBytes += record.sizeInBytes(); + + assertEquals(expectedBytes, (Double) fetchSizeAverage.metricValue(), EPSILON); + assertEquals(3, (Double) recordsCountAverage.metricValue(), EPSILON); + } + + @Test + public void testFetcherMetricsTemplates() { + Map clientTags = Collections.singletonMap("client-id", "clientA"); + buildFetcher(new MetricConfig().tags(clientTags), OffsetResetStrategy.EARLIEST, new ByteArrayDeserializer(), + new ByteArrayDeserializer(), Integer.MAX_VALUE, IsolationLevel.READ_UNCOMMITTED); + + // Fetch from topic to generate topic metrics + assignFromUser(singleton(tp0)); + subscriptions.seek(tp0, 0); + assertEquals(1, fetcher.sendFetches()); + client.prepareResponse(fullFetchResponse(tidp0, this.records, Errors.NONE, 100L, 0)); + consumerClient.poll(time.timer(0)); + assertTrue(fetcher.hasCompletedFetches()); + Map>> partitionRecords = fetchedRecords(); + assertTrue(partitionRecords.containsKey(tp0)); + + // Create throttle metrics + Fetcher.throttleTimeSensor(metrics, metricsRegistry); + + // Verify that all metrics except metrics-count have registered templates + Set allMetrics = new HashSet<>(); + for (MetricName n : metrics.metrics().keySet()) { + String name = n.name().replaceAll(tp0.toString(), "{topic}-{partition}"); + if (!n.group().equals("kafka-metrics-count")) + allMetrics.add(new MetricNameTemplate(name, n.group(), "", n.tags().keySet())); + } + TestUtils.checkEquals(allMetrics, new HashSet<>(metricsRegistry.getAllTemplates()), "metrics", "templates"); + } + + private Map>> fetchRecords( + TopicIdPartition tp, MemoryRecords records, Errors error, long hw, int throttleTime) { + return fetchRecords(tp, records, error, hw, FetchResponse.INVALID_LAST_STABLE_OFFSET, throttleTime); + } + + private Map>> fetchRecords( + TopicIdPartition tp, MemoryRecords records, Errors error, long hw, long lastStableOffset, int throttleTime) { + 
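// Send a single fetch, inject the prepared full fetch response through the mock client, poll once, and return the drained records. +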
assertEquals(1, fetcher.sendFetches()); + client.prepareResponse(fullFetchResponse(tp, records, error, hw, lastStableOffset, throttleTime)); + consumerClient.poll(time.timer(0)); + return fetchedRecords(); + } + + private Map>> fetchRecords( + TopicIdPartition tp, MemoryRecords records, Errors error, long hw, long lastStableOffset, long logStartOffset, int throttleTime) { + assertEquals(1, fetcher.sendFetches()); + client.prepareResponse(fetchResponse(tp, records, error, hw, lastStableOffset, logStartOffset, throttleTime)); + consumerClient.poll(time.timer(0)); + return fetchedRecords(); + } + + @Test + public void testGetOffsetsForTimesTimeout() { + buildFetcher(); + assertThrows(TimeoutException.class, () -> fetcher.offsetsForTimes( + Collections.singletonMap(new TopicPartition(topicName, 2), 1000L), time.timer(100L))); + } + + @Test + public void testGetOffsetsForTimes() { + buildFetcher(); + + // Empty map + assertTrue(fetcher.offsetsForTimes(new HashMap<>(), time.timer(100L)).isEmpty()); + // Unknown Offset + testGetOffsetsForTimesWithUnknownOffset(); + // Error code none with unknown offset + testGetOffsetsForTimesWithError(Errors.NONE, Errors.NONE, -1L, 100L, null, 100L); + // Error code none with known offset + testGetOffsetsForTimesWithError(Errors.NONE, Errors.NONE, 10L, 100L, 10L, 100L); + // Test both of partition has error. + testGetOffsetsForTimesWithError(Errors.NOT_LEADER_OR_FOLLOWER, Errors.INVALID_REQUEST, 10L, 100L, 10L, 100L); + // Test the second partition has error. + testGetOffsetsForTimesWithError(Errors.NONE, Errors.NOT_LEADER_OR_FOLLOWER, 10L, 100L, 10L, 100L); + // Test different errors. + testGetOffsetsForTimesWithError(Errors.NOT_LEADER_OR_FOLLOWER, Errors.NONE, 10L, 100L, 10L, 100L); + testGetOffsetsForTimesWithError(Errors.UNKNOWN_TOPIC_OR_PARTITION, Errors.NONE, 10L, 100L, 10L, 100L); + testGetOffsetsForTimesWithError(Errors.UNSUPPORTED_FOR_MESSAGE_FORMAT, Errors.NONE, 10L, 100L, null, 100L); + testGetOffsetsForTimesWithError(Errors.BROKER_NOT_AVAILABLE, Errors.NONE, 10L, 100L, 10L, 100L); + } + + @Test + public void testGetOffsetsFencedLeaderEpoch() { + buildFetcher(); + subscriptions.assignFromUser(singleton(tp0)); + client.updateMetadata(initialUpdateResponse); + + subscriptions.requestOffsetReset(tp0, OffsetResetStrategy.LATEST); + + client.prepareResponse(listOffsetResponse(Errors.FENCED_LEADER_EPOCH, 1L, 5L)); + fetcher.resetOffsetsIfNeeded(); + consumerClient.pollNoWakeup(); + + assertTrue(subscriptions.isOffsetResetNeeded(tp0)); + assertFalse(subscriptions.isFetchable(tp0)); + assertFalse(subscriptions.hasValidPosition(tp0)); + assertEquals(0L, metadata.timeToNextUpdate(time.milliseconds())); + } + + @Test + public void testGetOffsetByTimeWithPartitionsRetryCouldTriggerMetadataUpdate() { + List retriableErrors = Arrays.asList(Errors.NOT_LEADER_OR_FOLLOWER, + Errors.REPLICA_NOT_AVAILABLE, Errors.KAFKA_STORAGE_ERROR, Errors.OFFSET_NOT_AVAILABLE, + Errors.LEADER_NOT_AVAILABLE, Errors.FENCED_LEADER_EPOCH, Errors.UNKNOWN_LEADER_EPOCH); + + final int newLeaderEpoch = 3; + MetadataResponse updatedMetadata = RequestTestUtils.metadataUpdateWithIds("dummy", 3, + singletonMap(topicName, Errors.NONE), singletonMap(topicName, 4), tp -> newLeaderEpoch, topicIds); + + Node originalLeader = initialUpdateResponse.buildCluster().leaderFor(tp1); + Node newLeader = updatedMetadata.buildCluster().leaderFor(tp1); + assertNotEquals(originalLeader, newLeader); + + for (Errors retriableError : retriableErrors) { + buildFetcher(); + + subscriptions.assignFromUser(mkSet(tp0, 
tp1)); + client.updateMetadata(initialUpdateResponse); + + final long fetchTimestamp = 10L; + ListOffsetsPartitionResponse tp0NoError = new ListOffsetsPartitionResponse() + .setPartitionIndex(tp0.partition()) + .setErrorCode(Errors.NONE.code()) + .setTimestamp(fetchTimestamp) + .setOffset(4L); + List topics = Collections.singletonList( + new ListOffsetsTopicResponse() + .setName(tp0.topic()) + .setPartitions(Arrays.asList( + tp0NoError, + new ListOffsetsPartitionResponse() + .setPartitionIndex(tp1.partition()) + .setErrorCode(retriableError.code()) + .setTimestamp(ListOffsetsRequest.LATEST_TIMESTAMP) + .setOffset(-1L)))); + ListOffsetsResponseData data = new ListOffsetsResponseData() + .setThrottleTimeMs(0) + .setTopics(topics); + + client.prepareResponseFrom(body -> { + boolean isListOffsetRequest = body instanceof ListOffsetsRequest; + if (isListOffsetRequest) { + ListOffsetsRequest request = (ListOffsetsRequest) body; + List expectedTopics = Collections.singletonList( + new ListOffsetsTopic() + .setName(tp0.topic()) + .setPartitions(Arrays.asList( + new ListOffsetsPartition() + .setPartitionIndex(tp1.partition()) + .setTimestamp(fetchTimestamp) + .setCurrentLeaderEpoch(ListOffsetsResponse.UNKNOWN_EPOCH), + new ListOffsetsPartition() + .setPartitionIndex(tp0.partition()) + .setTimestamp(fetchTimestamp) + .setCurrentLeaderEpoch(ListOffsetsResponse.UNKNOWN_EPOCH)))); + return request.topics().equals(expectedTopics); + } else { + return false; + } + }, new ListOffsetsResponse(data), originalLeader); + + client.prepareMetadataUpdate(updatedMetadata); + + // If the metadata wasn't updated before retrying, the fetcher would consult the original leader and hit a NOT_LEADER exception. + // We will count the answered future response in the end to verify if this is the case. + List topicsWithFatalError = Collections.singletonList( + new ListOffsetsTopicResponse() + .setName(tp0.topic()) + .setPartitions(Arrays.asList( + tp0NoError, + new ListOffsetsPartitionResponse() + .setPartitionIndex(tp1.partition()) + .setErrorCode(Errors.NOT_LEADER_OR_FOLLOWER.code()) + .setTimestamp(ListOffsetsRequest.LATEST_TIMESTAMP) + .setOffset(-1L)))); + ListOffsetsResponseData dataWithFatalError = new ListOffsetsResponseData() + .setThrottleTimeMs(0) + .setTopics(topicsWithFatalError); + client.prepareResponseFrom(new ListOffsetsResponse(dataWithFatalError), originalLeader); + + // The request to new leader must only contain one partition tp1 with error. 
+ client.prepareResponseFrom(body -> { + boolean isListOffsetRequest = body instanceof ListOffsetsRequest; + if (isListOffsetRequest) { + ListOffsetsRequest request = (ListOffsetsRequest) body; + + ListOffsetsTopic requestTopic = request.topics().get(0); + ListOffsetsPartition expectedPartition = new ListOffsetsPartition() + .setPartitionIndex(tp1.partition()) + .setTimestamp(fetchTimestamp) + .setCurrentLeaderEpoch(newLeaderEpoch); + return expectedPartition.equals(requestTopic.partitions().get(0)); + } else { + return false; + } + }, listOffsetResponse(tp1, Errors.NONE, fetchTimestamp, 5L), newLeader); + + Map offsetAndTimestampMap = + fetcher.offsetsForTimes( + Utils.mkMap(Utils.mkEntry(tp0, fetchTimestamp), + Utils.mkEntry(tp1, fetchTimestamp)), time.timer(Integer.MAX_VALUE)); + + assertEquals(Utils.mkMap( + Utils.mkEntry(tp0, new OffsetAndTimestamp(4L, fetchTimestamp)), + Utils.mkEntry(tp1, new OffsetAndTimestamp(5L, fetchTimestamp))), offsetAndTimestampMap); + + // The NOT_LEADER exception future should not be cleared as we already refreshed the metadata before + // first retry, thus never hitting. + assertEquals(1, client.numAwaitingResponses()); + + fetcher.close(); + } + } + + @Test + public void testGetOffsetsUnknownLeaderEpoch() { + buildFetcher(); + subscriptions.assignFromUser(singleton(tp0)); + subscriptions.requestOffsetReset(tp0, OffsetResetStrategy.LATEST); + + client.prepareResponse(listOffsetResponse(Errors.UNKNOWN_LEADER_EPOCH, 1L, 5L)); + fetcher.resetOffsetsIfNeeded(); + consumerClient.pollNoWakeup(); + + assertTrue(subscriptions.isOffsetResetNeeded(tp0)); + assertFalse(subscriptions.isFetchable(tp0)); + assertFalse(subscriptions.hasValidPosition(tp0)); + assertEquals(0L, metadata.timeToNextUpdate(time.milliseconds())); + } + + @Test + public void testGetOffsetsIncludesLeaderEpoch() { + buildFetcher(); + subscriptions.assignFromUser(singleton(tp0)); + + client.updateMetadata(initialUpdateResponse); + + // Metadata update with leader epochs + MetadataResponse metadataResponse = RequestTestUtils.metadataUpdateWithIds("dummy", 1, + Collections.emptyMap(), Collections.singletonMap(topicName, 4), tp -> 99, topicIds); + client.updateMetadata(metadataResponse); + + // Request latest offset + subscriptions.requestOffsetReset(tp0); + fetcher.resetOffsetsIfNeeded(); + + // Check for epoch in outgoing request + MockClient.RequestMatcher matcher = body -> { + if (body instanceof ListOffsetsRequest) { + ListOffsetsRequest offsetRequest = (ListOffsetsRequest) body; + int epoch = offsetRequest.topics().get(0).partitions().get(0).currentLeaderEpoch(); + assertTrue(epoch != ListOffsetsResponse.UNKNOWN_EPOCH, "Expected Fetcher to set leader epoch in request"); + assertEquals(epoch, 99, "Expected leader epoch to match epoch from metadata update"); + return true; + } else { + fail("Should have seen ListOffsetRequest"); + return false; + } + }; + + client.prepareResponse(matcher, listOffsetResponse(Errors.NONE, 1L, 5L)); + consumerClient.pollNoWakeup(); + } + + @Test + public void testGetOffsetsForTimesWhenSomeTopicPartitionLeadersNotKnownInitially() { + buildFetcher(); + + subscriptions.assignFromUser(mkSet(tp0, tp1)); + final String anotherTopic = "another-topic"; + final TopicPartition t2p0 = new TopicPartition(anotherTopic, 0); + + client.reset(); + + // Metadata initially has one topic + MetadataResponse initialMetadata = RequestTestUtils.metadataUpdateWithIds(3, singletonMap(topicName, 2), topicIds); + client.updateMetadata(initialMetadata); + + // The first metadata refresh should 
contain one topic + client.prepareMetadataUpdate(initialMetadata); + client.prepareResponseFrom(listOffsetResponse(tp0, Errors.NONE, 1000L, 11L), + metadata.fetch().leaderFor(tp0)); + client.prepareResponseFrom(listOffsetResponse(tp1, Errors.NONE, 1000L, 32L), + metadata.fetch().leaderFor(tp1)); + + // Second metadata refresh should contain two topics + Map partitionNumByTopic = new HashMap<>(); + partitionNumByTopic.put(topicName, 2); + partitionNumByTopic.put(anotherTopic, 1); + topicIds.put("another-topic", Uuid.randomUuid()); + MetadataResponse updatedMetadata = RequestTestUtils.metadataUpdateWithIds(3, partitionNumByTopic, topicIds); + client.prepareMetadataUpdate(updatedMetadata); + client.prepareResponseFrom(listOffsetResponse(t2p0, Errors.NONE, 1000L, 54L), + metadata.fetch().leaderFor(t2p0)); + + Map timestampToSearch = new HashMap<>(); + timestampToSearch.put(tp0, ListOffsetsRequest.LATEST_TIMESTAMP); + timestampToSearch.put(tp1, ListOffsetsRequest.LATEST_TIMESTAMP); + timestampToSearch.put(t2p0, ListOffsetsRequest.LATEST_TIMESTAMP); + Map offsetAndTimestampMap = + fetcher.offsetsForTimes(timestampToSearch, time.timer(Long.MAX_VALUE)); + + assertNotNull(offsetAndTimestampMap.get(tp0), "Expect Fetcher.offsetsForTimes() to return non-null result for " + tp0); + assertNotNull(offsetAndTimestampMap.get(tp1), "Expect Fetcher.offsetsForTimes() to return non-null result for " + tp1); + assertNotNull(offsetAndTimestampMap.get(t2p0), "Expect Fetcher.offsetsForTimes() to return non-null result for " + t2p0); + assertEquals(11L, offsetAndTimestampMap.get(tp0).offset()); + assertEquals(32L, offsetAndTimestampMap.get(tp1).offset()); + assertEquals(54L, offsetAndTimestampMap.get(t2p0).offset()); + } + + @Test + public void testGetOffsetsForTimesWhenSomeTopicPartitionLeadersDisconnectException() { + buildFetcher(); + final String anotherTopic = "another-topic"; + final TopicPartition t2p0 = new TopicPartition(anotherTopic, 0); + subscriptions.assignFromUser(mkSet(tp0, t2p0)); + + client.reset(); + + MetadataResponse initialMetadata = RequestTestUtils.metadataUpdateWithIds(1, singletonMap(topicName, 1), topicIds); + client.updateMetadata(initialMetadata); + + Map partitionNumByTopic = new HashMap<>(); + partitionNumByTopic.put(topicName, 1); + partitionNumByTopic.put(anotherTopic, 1); + topicIds.put("another-topic", Uuid.randomUuid()); + MetadataResponse updatedMetadata = RequestTestUtils.metadataUpdateWithIds(1, partitionNumByTopic, topicIds); + client.prepareMetadataUpdate(updatedMetadata); + + client.prepareResponse(listOffsetRequestMatcher(ListOffsetsRequest.LATEST_TIMESTAMP), + listOffsetResponse(tp0, Errors.NONE, 1000L, 11L), true); + client.prepareResponseFrom(listOffsetResponse(tp0, Errors.NONE, 1000L, 11L), metadata.fetch().leaderFor(tp0)); + + Map timestampToSearch = new HashMap<>(); + timestampToSearch.put(tp0, ListOffsetsRequest.LATEST_TIMESTAMP); + Map offsetAndTimestampMap = fetcher.offsetsForTimes(timestampToSearch, time.timer(Long.MAX_VALUE)); + + assertNotNull(offsetAndTimestampMap.get(tp0), "Expect Fetcher.offsetsForTimes() to return non-null result for " + tp0); + assertEquals(11L, offsetAndTimestampMap.get(tp0).offset()); + assertNotNull(metadata.fetch().partitionCountForTopic(anotherTopic)); + } + + @Test + public void testListOffsetsWithZeroTimeout() { + buildFetcher(); + + Map offsetsToSearch = new HashMap<>(); + offsetsToSearch.put(tp0, ListOffsetsRequest.EARLIEST_TIMESTAMP); + offsetsToSearch.put(tp1, ListOffsetsRequest.EARLIEST_TIMESTAMP); + + Map offsetsToExpect = new 
HashMap<>(); + offsetsToExpect.put(tp0, null); + offsetsToExpect.put(tp1, null); + + assertEquals(offsetsToExpect, fetcher.offsetsForTimes(offsetsToSearch, time.timer(0))); + } + + @Test + public void testBatchedListOffsetsMetadataErrors() { + buildFetcher(); + + ListOffsetsResponseData data = new ListOffsetsResponseData() + .setThrottleTimeMs(0) + .setTopics(Collections.singletonList(new ListOffsetsTopicResponse() + .setName(tp0.topic()) + .setPartitions(Arrays.asList( + new ListOffsetsPartitionResponse() + .setPartitionIndex(tp0.partition()) + .setErrorCode(Errors.NOT_LEADER_OR_FOLLOWER.code()) + .setTimestamp(ListOffsetsResponse.UNKNOWN_TIMESTAMP) + .setOffset(ListOffsetsResponse.UNKNOWN_OFFSET), + new ListOffsetsPartitionResponse() + .setPartitionIndex(tp1.partition()) + .setErrorCode(Errors.UNKNOWN_TOPIC_OR_PARTITION.code()) + .setTimestamp(ListOffsetsResponse.UNKNOWN_TIMESTAMP) + .setOffset(ListOffsetsResponse.UNKNOWN_OFFSET))))); + client.prepareResponse(new ListOffsetsResponse(data)); + + Map offsetsToSearch = new HashMap<>(); + offsetsToSearch.put(tp0, ListOffsetsRequest.EARLIEST_TIMESTAMP); + offsetsToSearch.put(tp1, ListOffsetsRequest.EARLIEST_TIMESTAMP); + + assertThrows(TimeoutException.class, () -> fetcher.offsetsForTimes(offsetsToSearch, time.timer(1))); + } + + @Test + public void testSkippingAbortedTransactions() { + buildFetcher(OffsetResetStrategy.EARLIEST, new ByteArrayDeserializer(), + new ByteArrayDeserializer(), Integer.MAX_VALUE, IsolationLevel.READ_COMMITTED); + ByteBuffer buffer = ByteBuffer.allocate(1024); + int currentOffset = 0; + + currentOffset += appendTransactionalRecords(buffer, 1L, currentOffset, + new SimpleRecord(time.milliseconds(), "key".getBytes(), "value".getBytes()), + new SimpleRecord(time.milliseconds(), "key".getBytes(), "value".getBytes())); + + abortTransaction(buffer, 1L, currentOffset); + + buffer.flip(); + + List abortedTransactions = Collections.singletonList( + new FetchResponseData.AbortedTransaction().setProducerId(1).setFirstOffset(0)); + MemoryRecords records = MemoryRecords.readableRecords(buffer); + assignFromUser(singleton(tp0)); + + subscriptions.seek(tp0, 0); + + // normal fetch + assertEquals(1, fetcher.sendFetches()); + assertFalse(fetcher.hasCompletedFetches()); + + client.prepareResponse(fullFetchResponseWithAbortedTransactions(records, abortedTransactions, Errors.NONE, 100L, 100L, 0)); + consumerClient.poll(time.timer(0)); + assertTrue(fetcher.hasCompletedFetches()); + + Map>> fetchedRecords = fetchedRecords(); + assertFalse(fetchedRecords.containsKey(tp0)); + } + + @Test + public void testReturnCommittedTransactions() { + buildFetcher(OffsetResetStrategy.EARLIEST, new ByteArrayDeserializer(), + new ByteArrayDeserializer(), Integer.MAX_VALUE, IsolationLevel.READ_COMMITTED); + ByteBuffer buffer = ByteBuffer.allocate(1024); + int currentOffset = 0; + + currentOffset += appendTransactionalRecords(buffer, 1L, currentOffset, + new SimpleRecord(time.milliseconds(), "key".getBytes(), "value".getBytes()), + new SimpleRecord(time.milliseconds(), "key".getBytes(), "value".getBytes())); + + commitTransaction(buffer, 1L, currentOffset); + buffer.flip(); + + MemoryRecords records = MemoryRecords.readableRecords(buffer); + assignFromUser(singleton(tp0)); + + subscriptions.seek(tp0, 0); + + // normal fetch + assertEquals(1, fetcher.sendFetches()); + assertFalse(fetcher.hasCompletedFetches()); + client.prepareResponse(body -> { + FetchRequest request = (FetchRequest) body; + assertEquals(IsolationLevel.READ_COMMITTED, 
request.isolationLevel()); + return true; + }, fullFetchResponseWithAbortedTransactions(records, Collections.emptyList(), Errors.NONE, 100L, 100L, 0)); + + consumerClient.poll(time.timer(0)); + assertTrue(fetcher.hasCompletedFetches()); + + Map>> fetchedRecords = fetchedRecords(); + assertTrue(fetchedRecords.containsKey(tp0)); + assertEquals(fetchedRecords.get(tp0).size(), 2); + } + + @Test + public void testReadCommittedWithCommittedAndAbortedTransactions() { + buildFetcher(OffsetResetStrategy.EARLIEST, new ByteArrayDeserializer(), + new ByteArrayDeserializer(), Integer.MAX_VALUE, IsolationLevel.READ_COMMITTED); + ByteBuffer buffer = ByteBuffer.allocate(1024); + + List abortedTransactions = new ArrayList<>(); + + long pid1 = 1L; + long pid2 = 2L; + + // Appends for producer 1 (eventually committed) + appendTransactionalRecords(buffer, pid1, 0L, + new SimpleRecord("commit1-1".getBytes(), "value".getBytes()), + new SimpleRecord("commit1-2".getBytes(), "value".getBytes())); + + // Appends for producer 2 (eventually aborted) + appendTransactionalRecords(buffer, pid2, 2L, + new SimpleRecord("abort2-1".getBytes(), "value".getBytes())); + + // commit producer 1 + commitTransaction(buffer, pid1, 3L); + + // append more for producer 2 (eventually aborted) + appendTransactionalRecords(buffer, pid2, 4L, + new SimpleRecord("abort2-2".getBytes(), "value".getBytes())); + + // abort producer 2 + abortTransaction(buffer, pid2, 5L); + abortedTransactions.add(new FetchResponseData.AbortedTransaction().setProducerId(pid2).setFirstOffset(2L)); + + // New transaction for producer 1 (eventually aborted) + appendTransactionalRecords(buffer, pid1, 6L, + new SimpleRecord("abort1-1".getBytes(), "value".getBytes())); + + // New transaction for producer 2 (eventually committed) + appendTransactionalRecords(buffer, pid2, 7L, + new SimpleRecord("commit2-1".getBytes(), "value".getBytes())); + + // Add messages for producer 1 (eventually aborted) + appendTransactionalRecords(buffer, pid1, 8L, + new SimpleRecord("abort1-2".getBytes(), "value".getBytes())); + + // abort producer 1 + abortTransaction(buffer, pid1, 9L); + abortedTransactions.add(new FetchResponseData.AbortedTransaction().setProducerId(1).setFirstOffset(6)); + + // commit producer 2 + commitTransaction(buffer, pid2, 10L); + + buffer.flip(); + + MemoryRecords records = MemoryRecords.readableRecords(buffer); + assignFromUser(singleton(tp0)); + + subscriptions.seek(tp0, 0); + + // normal fetch + assertEquals(1, fetcher.sendFetches()); + assertFalse(fetcher.hasCompletedFetches()); + + client.prepareResponse(fullFetchResponseWithAbortedTransactions(records, abortedTransactions, Errors.NONE, 100L, 100L, 0)); + consumerClient.poll(time.timer(0)); + assertTrue(fetcher.hasCompletedFetches()); + + Map>> fetchedRecords = fetchedRecords(); + assertTrue(fetchedRecords.containsKey(tp0)); + // There are only 3 committed records + List> fetchedConsumerRecords = fetchedRecords.get(tp0); + Set fetchedKeys = new HashSet<>(); + for (ConsumerRecord consumerRecord : fetchedConsumerRecords) { + fetchedKeys.add(new String(consumerRecord.key(), StandardCharsets.UTF_8)); + } + assertEquals(mkSet("commit1-1", "commit1-2", "commit2-1"), fetchedKeys); + } + + @Test + public void testMultipleAbortMarkers() { + buildFetcher(OffsetResetStrategy.EARLIEST, new ByteArrayDeserializer(), + new ByteArrayDeserializer(), Integer.MAX_VALUE, IsolationLevel.READ_COMMITTED); + ByteBuffer buffer = ByteBuffer.allocate(1024); + int currentOffset = 0; + + currentOffset += 
appendTransactionalRecords(buffer, 1L, currentOffset, + new SimpleRecord(time.milliseconds(), "abort1-1".getBytes(), "value".getBytes()), + new SimpleRecord(time.milliseconds(), "abort1-2".getBytes(), "value".getBytes())); + + currentOffset += abortTransaction(buffer, 1L, currentOffset); + // Duplicate abort -- should be ignored. + currentOffset += abortTransaction(buffer, 1L, currentOffset); + // Now commit a transaction. + currentOffset += appendTransactionalRecords(buffer, 1L, currentOffset, + new SimpleRecord(time.milliseconds(), "commit1-1".getBytes(), "value".getBytes()), + new SimpleRecord(time.milliseconds(), "commit1-2".getBytes(), "value".getBytes())); + commitTransaction(buffer, 1L, currentOffset); + buffer.flip(); + + List abortedTransactions = Collections.singletonList( + new FetchResponseData.AbortedTransaction().setProducerId(1).setFirstOffset(0) + ); + MemoryRecords records = MemoryRecords.readableRecords(buffer); + assignFromUser(singleton(tp0)); + + subscriptions.seek(tp0, 0); + + // normal fetch + assertEquals(1, fetcher.sendFetches()); + assertFalse(fetcher.hasCompletedFetches()); + + client.prepareResponse(fullFetchResponseWithAbortedTransactions(records, abortedTransactions, Errors.NONE, 100L, 100L, 0)); + consumerClient.poll(time.timer(0)); + assertTrue(fetcher.hasCompletedFetches()); + + Map>> fetchedRecords = fetchedRecords(); + assertTrue(fetchedRecords.containsKey(tp0)); + assertEquals(fetchedRecords.get(tp0).size(), 2); + List> fetchedConsumerRecords = fetchedRecords.get(tp0); + Set committedKeys = new HashSet<>(Arrays.asList("commit1-1", "commit1-2")); + Set actuallyCommittedKeys = new HashSet<>(); + for (ConsumerRecord consumerRecord : fetchedConsumerRecords) { + actuallyCommittedKeys.add(new String(consumerRecord.key(), StandardCharsets.UTF_8)); + } + assertEquals(actuallyCommittedKeys, committedKeys); + } + + @Test + public void testReadCommittedAbortMarkerWithNoData() { + buildFetcher(OffsetResetStrategy.EARLIEST, new StringDeserializer(), + new StringDeserializer(), Integer.MAX_VALUE, IsolationLevel.READ_COMMITTED); + ByteBuffer buffer = ByteBuffer.allocate(1024); + + long producerId = 1L; + + abortTransaction(buffer, producerId, 5L); + + appendTransactionalRecords(buffer, producerId, 6L, + new SimpleRecord("6".getBytes(), null), + new SimpleRecord("7".getBytes(), null), + new SimpleRecord("8".getBytes(), null)); + + commitTransaction(buffer, producerId, 9L); + + buffer.flip(); + + // send the fetch + assignFromUser(singleton(tp0)); + subscriptions.seek(tp0, 0); + assertEquals(1, fetcher.sendFetches()); + + // prepare the response. 
the aborted transactions begin at offsets which are no longer in the log + List abortedTransactions = Collections.singletonList( + new FetchResponseData.AbortedTransaction().setProducerId(producerId).setFirstOffset(0L)); + + client.prepareResponse(fullFetchResponseWithAbortedTransactions(MemoryRecords.readableRecords(buffer), + abortedTransactions, Errors.NONE, 100L, 100L, 0)); + consumerClient.poll(time.timer(0)); + assertTrue(fetcher.hasCompletedFetches()); + + Map>> allFetchedRecords = fetchedRecords(); + assertTrue(allFetchedRecords.containsKey(tp0)); + List> fetchedRecords = allFetchedRecords.get(tp0); + assertEquals(3, fetchedRecords.size()); + assertEquals(Arrays.asList(6L, 7L, 8L), collectRecordOffsets(fetchedRecords)); + } + + @Test + public void testUpdatePositionWithLastRecordMissingFromBatch() { + buildFetcher(); + + MemoryRecords records = MemoryRecords.withRecords(CompressionType.NONE, + new SimpleRecord("0".getBytes(), "v".getBytes()), + new SimpleRecord("1".getBytes(), "v".getBytes()), + new SimpleRecord("2".getBytes(), "v".getBytes()), + new SimpleRecord(null, "value".getBytes())); + + // Remove the last record to simulate compaction + MemoryRecords.FilterResult result = records.filterTo(tp0, new MemoryRecords.RecordFilter(0, 0) { + @Override + protected BatchRetentionResult checkBatchRetention(RecordBatch batch) { + return new BatchRetentionResult(BatchRetention.DELETE_EMPTY, false); + } + + @Override + protected boolean shouldRetainRecord(RecordBatch recordBatch, Record record) { + return record.key() != null; + } + }, ByteBuffer.allocate(1024), Integer.MAX_VALUE, BufferSupplier.NO_CACHING); + result.outputBuffer().flip(); + MemoryRecords compactedRecords = MemoryRecords.readableRecords(result.outputBuffer()); + + assignFromUser(singleton(tp0)); + subscriptions.seek(tp0, 0); + assertEquals(1, fetcher.sendFetches()); + client.prepareResponse(fullFetchResponse(tidp0, compactedRecords, Errors.NONE, 100L, 0)); + consumerClient.poll(time.timer(0)); + assertTrue(fetcher.hasCompletedFetches()); + + Map>> allFetchedRecords = fetchedRecords(); + assertTrue(allFetchedRecords.containsKey(tp0)); + List> fetchedRecords = allFetchedRecords.get(tp0); + assertEquals(3, fetchedRecords.size()); + + for (int i = 0; i < 3; i++) { + assertEquals(Integer.toString(i), new String(fetchedRecords.get(i).key())); + } + + // The next offset should point to the next batch + assertEquals(4L, subscriptions.position(tp0).offset); + } + + @Test + public void testUpdatePositionOnEmptyBatch() { + buildFetcher(); + + long producerId = 1; + short producerEpoch = 0; + int sequence = 1; + long baseOffset = 37; + long lastOffset = 54; + int partitionLeaderEpoch = 7; + ByteBuffer buffer = ByteBuffer.allocate(DefaultRecordBatch.RECORD_BATCH_OVERHEAD); + DefaultRecordBatch.writeEmptyHeader(buffer, RecordBatch.CURRENT_MAGIC_VALUE, producerId, producerEpoch, + sequence, baseOffset, lastOffset, partitionLeaderEpoch, TimestampType.CREATE_TIME, + System.currentTimeMillis(), false, false); + buffer.flip(); + MemoryRecords recordsWithEmptyBatch = MemoryRecords.readableRecords(buffer); + + assignFromUser(singleton(tp0)); + subscriptions.seek(tp0, 0); + assertEquals(1, fetcher.sendFetches()); + client.prepareResponse(fullFetchResponse(tidp0, recordsWithEmptyBatch, Errors.NONE, 100L, 0)); + consumerClient.poll(time.timer(0)); + assertTrue(fetcher.hasCompletedFetches()); + + Map>> allFetchedRecords = fetchedRecords(); + assertTrue(allFetchedRecords.isEmpty()); + + // The next offset should point to the next batch + 
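+ // Even though the batch carries no records, its header still spans [baseOffset, lastOffset],
+ // so the position should advance past the empty batch instead of staying at the seek offset.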
assertEquals(lastOffset + 1, subscriptions.position(tp0).offset); + } + + @Test + public void testReadCommittedWithCompactedTopic() { + buildFetcher(OffsetResetStrategy.EARLIEST, new StringDeserializer(), + new StringDeserializer(), Integer.MAX_VALUE, IsolationLevel.READ_COMMITTED); + ByteBuffer buffer = ByteBuffer.allocate(1024); + + long pid1 = 1L; + long pid2 = 2L; + long pid3 = 3L; + + appendTransactionalRecords(buffer, pid3, 3L, + new SimpleRecord("3".getBytes(), "value".getBytes()), + new SimpleRecord("4".getBytes(), "value".getBytes())); + + appendTransactionalRecords(buffer, pid2, 15L, + new SimpleRecord("15".getBytes(), "value".getBytes()), + new SimpleRecord("16".getBytes(), "value".getBytes()), + new SimpleRecord("17".getBytes(), "value".getBytes())); + + appendTransactionalRecords(buffer, pid1, 22L, + new SimpleRecord("22".getBytes(), "value".getBytes()), + new SimpleRecord("23".getBytes(), "value".getBytes())); + + abortTransaction(buffer, pid2, 28L); + + appendTransactionalRecords(buffer, pid3, 30L, + new SimpleRecord("30".getBytes(), "value".getBytes()), + new SimpleRecord("31".getBytes(), "value".getBytes()), + new SimpleRecord("32".getBytes(), "value".getBytes())); + + commitTransaction(buffer, pid3, 35L); + + appendTransactionalRecords(buffer, pid1, 39L, + new SimpleRecord("39".getBytes(), "value".getBytes()), + new SimpleRecord("40".getBytes(), "value".getBytes())); + + // transaction from pid1 is aborted, but the marker is not included in the fetch + + buffer.flip(); + + // send the fetch + assignFromUser(singleton(tp0)); + subscriptions.seek(tp0, 0); + assertEquals(1, fetcher.sendFetches()); + + // prepare the response. the aborted transactions begin at offsets which are no longer in the log + List abortedTransactions = Arrays.asList( + new FetchResponseData.AbortedTransaction().setProducerId(pid2).setFirstOffset(6), + new FetchResponseData.AbortedTransaction().setProducerId(pid1).setFirstOffset(0) + ); + + client.prepareResponse(fullFetchResponseWithAbortedTransactions(MemoryRecords.readableRecords(buffer), + abortedTransactions, Errors.NONE, 100L, 100L, 0)); + consumerClient.poll(time.timer(0)); + assertTrue(fetcher.hasCompletedFetches()); + + Map>> allFetchedRecords = fetchedRecords(); + assertTrue(allFetchedRecords.containsKey(tp0)); + List> fetchedRecords = allFetchedRecords.get(tp0); + assertEquals(5, fetchedRecords.size()); + assertEquals(Arrays.asList(3L, 4L, 30L, 31L, 32L), collectRecordOffsets(fetchedRecords)); + } + + @Test + public void testReturnAbortedTransactionsinUncommittedMode() { + buildFetcher(OffsetResetStrategy.EARLIEST, new ByteArrayDeserializer(), + new ByteArrayDeserializer(), Integer.MAX_VALUE, IsolationLevel.READ_UNCOMMITTED); + ByteBuffer buffer = ByteBuffer.allocate(1024); + int currentOffset = 0; + + currentOffset += appendTransactionalRecords(buffer, 1L, currentOffset, + new SimpleRecord(time.milliseconds(), "key".getBytes(), "value".getBytes()), + new SimpleRecord(time.milliseconds(), "key".getBytes(), "value".getBytes())); + + abortTransaction(buffer, 1L, currentOffset); + + buffer.flip(); + + List abortedTransactions = Collections.singletonList( + new FetchResponseData.AbortedTransaction().setProducerId(1).setFirstOffset(0)); + MemoryRecords records = MemoryRecords.readableRecords(buffer); + assignFromUser(singleton(tp0)); + + subscriptions.seek(tp0, 0); + + // normal fetch + assertEquals(1, fetcher.sendFetches()); + assertFalse(fetcher.hasCompletedFetches()); + + client.prepareResponse(fullFetchResponseWithAbortedTransactions(records, 
abortedTransactions, Errors.NONE, 100L, 100L, 0)); + consumerClient.poll(time.timer(0)); + assertTrue(fetcher.hasCompletedFetches()); + + Map>> fetchedRecords = fetchedRecords(); + assertTrue(fetchedRecords.containsKey(tp0)); + } + + @Test + public void testConsumerPositionUpdatedWhenSkippingAbortedTransactions() { + buildFetcher(OffsetResetStrategy.EARLIEST, new ByteArrayDeserializer(), + new ByteArrayDeserializer(), Integer.MAX_VALUE, IsolationLevel.READ_COMMITTED); + ByteBuffer buffer = ByteBuffer.allocate(1024); + long currentOffset = 0; + + currentOffset += appendTransactionalRecords(buffer, 1L, currentOffset, + new SimpleRecord(time.milliseconds(), "abort1-1".getBytes(), "value".getBytes()), + new SimpleRecord(time.milliseconds(), "abort1-2".getBytes(), "value".getBytes())); + + currentOffset += abortTransaction(buffer, 1L, currentOffset); + buffer.flip(); + + List abortedTransactions = Collections.singletonList( + new FetchResponseData.AbortedTransaction().setProducerId(1).setFirstOffset(0)); + MemoryRecords records = MemoryRecords.readableRecords(buffer); + assignFromUser(singleton(tp0)); + + subscriptions.seek(tp0, 0); + + // normal fetch + assertEquals(1, fetcher.sendFetches()); + assertFalse(fetcher.hasCompletedFetches()); + + client.prepareResponse(fullFetchResponseWithAbortedTransactions(records, abortedTransactions, Errors.NONE, 100L, 100L, 0)); + consumerClient.poll(time.timer(0)); + assertTrue(fetcher.hasCompletedFetches()); + + Map>> fetchedRecords = fetchedRecords(); + + // Ensure that we don't return any of the aborted records, but yet advance the consumer position. + assertFalse(fetchedRecords.containsKey(tp0)); + assertEquals(currentOffset, subscriptions.position(tp0).offset); + } + + @Test + public void testConsumingViaIncrementalFetchRequests() { + buildFetcher(2); + + List> records; + assignFromUser(new HashSet<>(Arrays.asList(tp0, tp1))); + subscriptions.seekValidated(tp0, new SubscriptionState.FetchPosition(0, Optional.empty(), metadata.currentLeader(tp0))); + subscriptions.seekValidated(tp1, new SubscriptionState.FetchPosition(1, Optional.empty(), metadata.currentLeader(tp1))); + + // Fetch some records and establish an incremental fetch session. + LinkedHashMap partitions1 = new LinkedHashMap<>(); + partitions1.put(tidp0, new FetchResponseData.PartitionData() + .setPartitionIndex(tp0.partition()) + .setHighWatermark(2) + .setLastStableOffset(2) + .setLogStartOffset(0) + .setRecords(this.records)); + partitions1.put(tidp1, new FetchResponseData.PartitionData() + .setPartitionIndex(tp1.partition()) + .setHighWatermark(100) + .setLogStartOffset(0) + .setRecords(emptyRecords)); + FetchResponse resp1 = FetchResponse.of(Errors.NONE, 0, 123, partitions1); + client.prepareResponse(resp1); + assertEquals(1, fetcher.sendFetches()); + assertFalse(fetcher.hasCompletedFetches()); + consumerClient.poll(time.timer(0)); + assertTrue(fetcher.hasCompletedFetches()); + Map>> fetchedRecords = fetchedRecords(); + assertFalse(fetchedRecords.containsKey(tp1)); + records = fetchedRecords.get(tp0); + assertEquals(2, records.size()); + assertEquals(3L, subscriptions.position(tp0).offset); + assertEquals(1L, subscriptions.position(tp1).offset); + assertEquals(1, records.get(0).offset()); + assertEquals(2, records.get(1).offset()); + + // There is still a buffered record. 
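+ // With max.poll.records = 2, the third record from the first response is still buffered,
+ // so no new fetch request should be sent for this node until the buffered data has been returned.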
+ assertEquals(0, fetcher.sendFetches()); + fetchedRecords = fetchedRecords(); + assertFalse(fetchedRecords.containsKey(tp1)); + records = fetchedRecords.get(tp0); + assertEquals(1, records.size()); + assertEquals(3, records.get(0).offset()); + assertEquals(4L, subscriptions.position(tp0).offset); + + // The second response contains no new records. + LinkedHashMap partitions2 = new LinkedHashMap<>(); + FetchResponse resp2 = FetchResponse.of(Errors.NONE, 0, 123, partitions2); + client.prepareResponse(resp2); + assertEquals(1, fetcher.sendFetches()); + consumerClient.poll(time.timer(0)); + fetchedRecords = fetchedRecords(); + assertTrue(fetchedRecords.isEmpty()); + assertEquals(4L, subscriptions.position(tp0).offset); + assertEquals(1L, subscriptions.position(tp1).offset); + + // The third response contains some new records for tp0. + LinkedHashMap partitions3 = new LinkedHashMap<>(); + partitions3.put(tidp0, new FetchResponseData.PartitionData() + .setPartitionIndex(tp0.partition()) + .setHighWatermark(100) + .setLastStableOffset(4) + .setLogStartOffset(0) + .setRecords(this.nextRecords)); + FetchResponse resp3 = FetchResponse.of(Errors.NONE, 0, 123, partitions3); + client.prepareResponse(resp3); + assertEquals(1, fetcher.sendFetches()); + consumerClient.poll(time.timer(0)); + fetchedRecords = fetchedRecords(); + assertFalse(fetchedRecords.containsKey(tp1)); + records = fetchedRecords.get(tp0); + assertEquals(2, records.size()); + assertEquals(6L, subscriptions.position(tp0).offset); + assertEquals(1L, subscriptions.position(tp1).offset); + assertEquals(4, records.get(0).offset()); + assertEquals(5, records.get(1).offset()); + } + + @Test + public void testFetcherConcurrency() throws Exception { + int numPartitions = 20; + Set topicPartitions = new HashSet<>(); + for (int i = 0; i < numPartitions; i++) + topicPartitions.add(new TopicPartition(topicName, i)); + + LogContext logContext = new LogContext(); + buildDependencies(new MetricConfig(), Long.MAX_VALUE, new SubscriptionState(logContext, OffsetResetStrategy.EARLIEST), logContext); + + fetcher = new Fetcher( + new LogContext(), + consumerClient, + minBytes, + maxBytes, + maxWaitMs, + fetchSize, + 2 * numPartitions, + true, + "", + new ByteArrayDeserializer(), + new ByteArrayDeserializer(), + metadata, + subscriptions, + metrics, + metricsRegistry, + time, + retryBackoffMs, + requestTimeoutMs, + IsolationLevel.READ_UNCOMMITTED, + apiVersions) { + @Override + protected FetchSessionHandler sessionHandler(int id) { + final FetchSessionHandler handler = super.sessionHandler(id); + if (handler == null) + return null; + else { + return new FetchSessionHandler(new LogContext(), id) { + @Override + public Builder newBuilder() { + verifySessionPartitions(); + return handler.newBuilder(); + } + + @Override + public boolean handleResponse(FetchResponse response, short version) { + verifySessionPartitions(); + return handler.handleResponse(response, version); + } + + @Override + public void handleError(Throwable t) { + verifySessionPartitions(); + handler.handleError(t); + } + + // Verify that session partitions can be traversed safely. 
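+ // Uses reflection since sessionPartitions is a private field of FetchSessionHandler;
+ // iterating it here while fetches complete on another thread is what would surface a
+ // ConcurrentModificationException if the map were not safely published.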
+ private void verifySessionPartitions() { + try { + Field field = FetchSessionHandler.class.getDeclaredField("sessionPartitions"); + field.setAccessible(true); + LinkedHashMap sessionPartitions = + (LinkedHashMap) field.get(handler); + for (Map.Entry entry : sessionPartitions.entrySet()) { + // If `sessionPartitions` are modified on another thread, Thread.yield will increase the + // possibility of ConcurrentModificationException if appropriate synchronization is not used. + Thread.yield(); + } + } catch (Exception e) { + throw new RuntimeException(e); + } + } + }; + } + } + }; + + MetadataResponse initialMetadataResponse = RequestTestUtils.metadataUpdateWithIds(1, + singletonMap(topicName, numPartitions), tp -> validLeaderEpoch, topicIds); + client.updateMetadata(initialMetadataResponse); + fetchSize = 10000; + + assignFromUser(topicPartitions); + topicPartitions.forEach(tp -> subscriptions.seek(tp, 0L)); + + AtomicInteger fetchesRemaining = new AtomicInteger(1000); + executorService = Executors.newSingleThreadExecutor(); + Future future = executorService.submit(() -> { + while (fetchesRemaining.get() > 0) { + synchronized (consumerClient) { + if (!client.requests().isEmpty()) { + ClientRequest request = client.requests().peek(); + FetchRequest fetchRequest = (FetchRequest) request.requestBuilder().build(); + LinkedHashMap responseMap = new LinkedHashMap<>(); + for (Map.Entry entry : fetchRequest.fetchData(topicNames).entrySet()) { + TopicIdPartition tp = entry.getKey(); + long offset = entry.getValue().fetchOffset; + responseMap.put(tp, new FetchResponseData.PartitionData() + .setPartitionIndex(tp.topicPartition().partition()) + .setHighWatermark(offset + 2) + .setLastStableOffset(offset + 2) + .setLogStartOffset(0) + .setRecords(buildRecords(offset, 2, offset))); + } + client.respondToRequest(request, FetchResponse.of(Errors.NONE, 0, 123, responseMap)); + consumerClient.poll(time.timer(0)); + } + } + } + return fetchesRemaining.get(); + }); + Map nextFetchOffsets = topicPartitions.stream() + .collect(Collectors.toMap(Function.identity(), t -> 0L)); + while (fetchesRemaining.get() > 0 && !future.isDone()) { + if (fetcher.sendFetches() == 1) { + synchronized (consumerClient) { + consumerClient.poll(time.timer(0)); + } + } + if (fetcher.hasCompletedFetches()) { + Map>> fetchedRecords = fetchedRecords(); + if (!fetchedRecords.isEmpty()) { + fetchesRemaining.decrementAndGet(); + fetchedRecords.forEach((tp, records) -> { + assertEquals(2, records.size()); + long nextOffset = nextFetchOffsets.get(tp); + assertEquals(nextOffset, records.get(0).offset()); + assertEquals(nextOffset + 1, records.get(1).offset()); + nextFetchOffsets.put(tp, nextOffset + 2); + }); + } + } + } + assertEquals(0, future.get()); + } + + @Test + public void testFetcherSessionEpochUpdate() throws Exception { + buildFetcher(2); + + MetadataResponse initialMetadataResponse = RequestTestUtils.metadataUpdateWithIds(1, singletonMap(topicName, 1), topicIds); + client.updateMetadata(initialMetadataResponse); + assignFromUser(Collections.singleton(tp0)); + subscriptions.seek(tp0, 0L); + + AtomicInteger fetchesRemaining = new AtomicInteger(1000); + executorService = Executors.newSingleThreadExecutor(); + Future future = executorService.submit(() -> { + long nextOffset = 0; + long nextEpoch = 0; + while (fetchesRemaining.get() > 0) { + synchronized (consumerClient) { + if (!client.requests().isEmpty()) { + ClientRequest request = client.requests().peek(); + FetchRequest fetchRequest = (FetchRequest) 
request.requestBuilder().build(); + int epoch = fetchRequest.metadata().epoch(); + assertTrue(epoch == 0 || epoch == nextEpoch, + String.format("Unexpected epoch expected %d got %d", nextEpoch, epoch)); + nextEpoch++; + LinkedHashMap responseMap = new LinkedHashMap<>(); + responseMap.put(tidp0, new FetchResponseData.PartitionData() + .setPartitionIndex(tp0.partition()) + .setHighWatermark(nextOffset + 2) + .setLastStableOffset(nextOffset + 2) + .setLogStartOffset(0) + .setRecords(buildRecords(nextOffset, 2, nextOffset))); + nextOffset += 2; + client.respondToRequest(request, FetchResponse.of(Errors.NONE, 0, 123, responseMap)); + consumerClient.poll(time.timer(0)); + } + } + } + return fetchesRemaining.get(); + }); + long nextFetchOffset = 0; + while (fetchesRemaining.get() > 0 && !future.isDone()) { + if (fetcher.sendFetches() == 1) { + synchronized (consumerClient) { + consumerClient.poll(time.timer(0)); + } + } + if (fetcher.hasCompletedFetches()) { + Map>> fetchedRecords = fetchedRecords(); + if (!fetchedRecords.isEmpty()) { + fetchesRemaining.decrementAndGet(); + List> records = fetchedRecords.get(tp0); + assertEquals(2, records.size()); + assertEquals(nextFetchOffset, records.get(0).offset()); + assertEquals(nextFetchOffset + 1, records.get(1).offset()); + nextFetchOffset += 2; + } + assertTrue(fetchedRecords().isEmpty()); + } + } + assertEquals(0, future.get()); + } + + @Test + public void testEmptyControlBatch() { + buildFetcher(OffsetResetStrategy.EARLIEST, new ByteArrayDeserializer(), + new ByteArrayDeserializer(), Integer.MAX_VALUE, IsolationLevel.READ_COMMITTED); + ByteBuffer buffer = ByteBuffer.allocate(1024); + int currentOffset = 1; + + // Empty control batch should not cause an exception + DefaultRecordBatch.writeEmptyHeader(buffer, RecordBatch.MAGIC_VALUE_V2, 1L, + (short) 0, -1, 0, 0, + RecordBatch.NO_PARTITION_LEADER_EPOCH, TimestampType.CREATE_TIME, time.milliseconds(), + true, true); + + currentOffset += appendTransactionalRecords(buffer, 1L, currentOffset, + new SimpleRecord(time.milliseconds(), "key".getBytes(), "value".getBytes()), + new SimpleRecord(time.milliseconds(), "key".getBytes(), "value".getBytes())); + + commitTransaction(buffer, 1L, currentOffset); + buffer.flip(); + + MemoryRecords records = MemoryRecords.readableRecords(buffer); + assignFromUser(singleton(tp0)); + + subscriptions.seek(tp0, 0); + + // normal fetch + assertEquals(1, fetcher.sendFetches()); + assertFalse(fetcher.hasCompletedFetches()); + client.prepareResponse(body -> { + FetchRequest request = (FetchRequest) body; + assertEquals(IsolationLevel.READ_COMMITTED, request.isolationLevel()); + return true; + }, fullFetchResponseWithAbortedTransactions(records, Collections.emptyList(), Errors.NONE, 100L, 100L, 0)); + + consumerClient.poll(time.timer(0)); + assertTrue(fetcher.hasCompletedFetches()); + + Map>> fetchedRecords = fetchedRecords(); + assertTrue(fetchedRecords.containsKey(tp0)); + assertEquals(fetchedRecords.get(tp0).size(), 2); + } + + private MemoryRecords buildRecords(long baseOffset, int count, long firstMessageId) { + MemoryRecordsBuilder builder = MemoryRecords.builder(ByteBuffer.allocate(1024), CompressionType.NONE, TimestampType.CREATE_TIME, baseOffset); + for (int i = 0; i < count; i++) + builder.append(0L, "key".getBytes(), ("value-" + (firstMessageId + i)).getBytes()); + return builder.build(); + } + + private int appendTransactionalRecords(ByteBuffer buffer, long pid, long baseOffset, int baseSequence, SimpleRecord... 
records) { + MemoryRecordsBuilder builder = MemoryRecords.builder(buffer, RecordBatch.CURRENT_MAGIC_VALUE, CompressionType.NONE, + TimestampType.CREATE_TIME, baseOffset, time.milliseconds(), pid, (short) 0, baseSequence, true, + RecordBatch.NO_PARTITION_LEADER_EPOCH); + + for (SimpleRecord record : records) { + builder.append(record); + } + builder.build(); + return records.length; + } + + private int appendTransactionalRecords(ByteBuffer buffer, long pid, long baseOffset, SimpleRecord... records) { + return appendTransactionalRecords(buffer, pid, baseOffset, (int) baseOffset, records); + } + + private void commitTransaction(ByteBuffer buffer, long producerId, long baseOffset) { + short producerEpoch = 0; + int partitionLeaderEpoch = 0; + MemoryRecords.writeEndTransactionalMarker(buffer, baseOffset, time.milliseconds(), partitionLeaderEpoch, producerId, producerEpoch, + new EndTransactionMarker(ControlRecordType.COMMIT, 0)); + } + + private int abortTransaction(ByteBuffer buffer, long producerId, long baseOffset) { + short producerEpoch = 0; + int partitionLeaderEpoch = 0; + MemoryRecords.writeEndTransactionalMarker(buffer, baseOffset, time.milliseconds(), partitionLeaderEpoch, producerId, producerEpoch, + new EndTransactionMarker(ControlRecordType.ABORT, 0)); + return 1; + } + + private void testGetOffsetsForTimesWithError(Errors errorForP0, + Errors errorForP1, + long offsetForP0, + long offsetForP1, + Long expectedOffsetForP0, + Long expectedOffsetForP1) { + client.reset(); + String topicName2 = "topic2"; + TopicPartition t2p0 = new TopicPartition(topicName2, 0); + // Expect a metadata refresh. + metadata.bootstrap(ClientUtils.parseAndValidateAddresses(Collections.singletonList("1.1.1.1:1111"), + ClientDnsLookup.USE_ALL_DNS_IPS)); + + Map partitionNumByTopic = new HashMap<>(); + partitionNumByTopic.put(topicName, 2); + partitionNumByTopic.put(topicName2, 1); + MetadataResponse updateMetadataResponse = RequestTestUtils.metadataUpdateWithIds(2, partitionNumByTopic, topicIds); + Cluster updatedCluster = updateMetadataResponse.buildCluster(); + + // The metadata refresh should contain all the topics. + client.prepareMetadataUpdate(updateMetadataResponse, true); + + // First try should fail due to metadata error. + client.prepareResponseFrom(listOffsetResponse(t2p0, errorForP0, offsetForP0, offsetForP0), + updatedCluster.leaderFor(t2p0)); + client.prepareResponseFrom(listOffsetResponse(tp1, errorForP1, offsetForP1, offsetForP1), + updatedCluster.leaderFor(tp1)); + // Second try should succeed. 
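+ // The retried ListOffsets requests return NONE, so offsetsForTimes can complete for both partitions.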
+ client.prepareResponseFrom(listOffsetResponse(t2p0, Errors.NONE, offsetForP0, offsetForP0), + updatedCluster.leaderFor(t2p0)); + client.prepareResponseFrom(listOffsetResponse(tp1, Errors.NONE, offsetForP1, offsetForP1), + updatedCluster.leaderFor(tp1)); + + Map timestampToSearch = new HashMap<>(); + timestampToSearch.put(t2p0, 0L); + timestampToSearch.put(tp1, 0L); + Map offsetAndTimestampMap = + fetcher.offsetsForTimes(timestampToSearch, time.timer(Long.MAX_VALUE)); + + if (expectedOffsetForP0 == null) + assertNull(offsetAndTimestampMap.get(t2p0)); + else { + assertEquals(expectedOffsetForP0.longValue(), offsetAndTimestampMap.get(t2p0).timestamp()); + assertEquals(expectedOffsetForP0.longValue(), offsetAndTimestampMap.get(t2p0).offset()); + } + + if (expectedOffsetForP1 == null) + assertNull(offsetAndTimestampMap.get(tp1)); + else { + assertEquals(expectedOffsetForP1.longValue(), offsetAndTimestampMap.get(tp1).timestamp()); + assertEquals(expectedOffsetForP1.longValue(), offsetAndTimestampMap.get(tp1).offset()); + } + } + + private void testGetOffsetsForTimesWithUnknownOffset() { + client.reset(); + // Ensure metadata has both partitions. + MetadataResponse initialMetadataUpdate = RequestTestUtils.metadataUpdateWithIds(1, singletonMap(topicName, 1), topicIds); + client.updateMetadata(initialMetadataUpdate); + + ListOffsetsResponseData data = new ListOffsetsResponseData() + .setThrottleTimeMs(0) + .setTopics(Collections.singletonList(new ListOffsetsTopicResponse() + .setName(tp0.topic()) + .setPartitions(Collections.singletonList(new ListOffsetsPartitionResponse() + .setPartitionIndex(tp0.partition()) + .setErrorCode(Errors.NONE.code()) + .setTimestamp(ListOffsetsResponse.UNKNOWN_TIMESTAMP) + .setOffset(ListOffsetsResponse.UNKNOWN_OFFSET))))); + + client.prepareResponseFrom(new ListOffsetsResponse(data), + metadata.fetch().leaderFor(tp0)); + + Map timestampToSearch = new HashMap<>(); + timestampToSearch.put(tp0, 0L); + Map offsetAndTimestampMap = + fetcher.offsetsForTimes(timestampToSearch, time.timer(Long.MAX_VALUE)); + + assertTrue(offsetAndTimestampMap.containsKey(tp0)); + assertNull(offsetAndTimestampMap.get(tp0)); + } + + @Test + public void testGetOffsetsForTimesWithUnknownOffsetV0() { + buildFetcher(); + // Empty map + assertTrue(fetcher.offsetsForTimes(new HashMap<>(), time.timer(100L)).isEmpty()); + // Unknown Offset + client.reset(); + // Ensure metadata has both partition. 
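+ // With LIST_OFFSETS v0 the response carries old-style offsets; an empty offset list should be
+ // treated like an unknown offset and surface as a null entry in the returned map.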
+ MetadataResponse initialMetadataUpdate = RequestTestUtils.metadataUpdateWithIds(1, singletonMap(topicName, 1), topicIds); + client.updateMetadata(initialMetadataUpdate); + // Force LIST_OFFSETS version 0 + Node node = metadata.fetch().nodes().get(0); + apiVersions.update(node.idString(), NodeApiVersions.create( + ApiKeys.LIST_OFFSETS.id, (short) 0, (short) 0)); + + ListOffsetsResponseData data = new ListOffsetsResponseData() + .setThrottleTimeMs(0) + .setTopics(Collections.singletonList(new ListOffsetsTopicResponse() + .setName(tp0.topic()) + .setPartitions(Collections.singletonList(new ListOffsetsPartitionResponse() + .setPartitionIndex(tp0.partition()) + .setErrorCode(Errors.NONE.code()) + .setTimestamp(ListOffsetsResponse.UNKNOWN_TIMESTAMP) + .setOldStyleOffsets(Collections.emptyList()))))); + + client.prepareResponseFrom(new ListOffsetsResponse(data), + metadata.fetch().leaderFor(tp0)); + + Map timestampToSearch = new HashMap<>(); + timestampToSearch.put(tp0, 0L); + Map offsetAndTimestampMap = + fetcher.offsetsForTimes(timestampToSearch, time.timer(Long.MAX_VALUE)); + + assertTrue(offsetAndTimestampMap.containsKey(tp0)); + assertNull(offsetAndTimestampMap.get(tp0)); + } + + @Test + public void testSubscriptionPositionUpdatedWithEpoch() { + // Create some records that include a leader epoch (1) + MemoryRecordsBuilder builder = MemoryRecords.builder( + ByteBuffer.allocate(1024), + RecordBatch.CURRENT_MAGIC_VALUE, + CompressionType.NONE, + TimestampType.CREATE_TIME, + 0L, + RecordBatch.NO_TIMESTAMP, + RecordBatch.NO_PRODUCER_ID, + RecordBatch.NO_PRODUCER_EPOCH, + RecordBatch.NO_SEQUENCE, + false, + 1 + ); + builder.appendWithOffset(0L, 0L, "key".getBytes(), "value-1".getBytes()); + builder.appendWithOffset(1L, 0L, "key".getBytes(), "value-2".getBytes()); + builder.appendWithOffset(2L, 0L, "key".getBytes(), "value-3".getBytes()); + MemoryRecords records = builder.build(); + + buildFetcher(); + assignFromUser(singleton(tp0)); + + // Initialize the epoch=1 + Map partitionCounts = new HashMap<>(); + partitionCounts.put(tp0.topic(), 4); + MetadataResponse metadataResponse = RequestTestUtils.metadataUpdateWithIds("dummy", 1, Collections.emptyMap(), partitionCounts, tp -> 1, topicIds); + metadata.updateWithCurrentRequestVersion(metadataResponse, false, 0L); + + // Seek + subscriptions.seek(tp0, 0); + + // Do a normal fetch + assertEquals(1, fetcher.sendFetches()); + assertFalse(fetcher.hasCompletedFetches()); + + client.prepareResponse(fullFetchResponse(tidp0, records, Errors.NONE, 100L, 0)); + consumerClient.pollNoWakeup(); + assertTrue(fetcher.hasCompletedFetches()); + + Map>> partitionRecords = fetchedRecords(); + assertTrue(partitionRecords.containsKey(tp0)); + + assertEquals(subscriptions.position(tp0).offset, 3L); + assertOptional(subscriptions.position(tp0).offsetEpoch, value -> assertEquals(value.intValue(), 1)); + } + + @Test + public void testOffsetValidationRequestGrouping() { + buildFetcher(); + assignFromUser(mkSet(tp0, tp1, tp2, tp3)); + + metadata.updateWithCurrentRequestVersion(RequestTestUtils.metadataUpdateWithIds("dummy", 3, + Collections.emptyMap(), singletonMap(topicName, 4), + tp -> 5, topicIds), false, 0L); + + for (TopicPartition tp : subscriptions.assignedPartitions()) { + Metadata.LeaderAndEpoch leaderAndEpoch = new Metadata.LeaderAndEpoch( + metadata.currentLeader(tp).leader, Optional.of(4)); + subscriptions.seekUnvalidated(tp, + new SubscriptionState.FetchPosition(0, Optional.of(4), leaderAndEpoch)); + } + + Set allRequestedPartitions = new HashSet<>(); + + for 
(Node node : metadata.fetch().nodes()) { + apiVersions.update(node.idString(), NodeApiVersions.create()); + + Set expectedPartitions = subscriptions.assignedPartitions().stream() + .filter(tp -> + metadata.currentLeader(tp).leader.equals(Optional.of(node))) + .collect(Collectors.toSet()); + + assertTrue(expectedPartitions.stream().noneMatch(allRequestedPartitions::contains)); + assertTrue(expectedPartitions.size() > 0); + allRequestedPartitions.addAll(expectedPartitions); + + OffsetForLeaderEpochResponseData data = new OffsetForLeaderEpochResponseData(); + expectedPartitions.forEach(tp -> { + OffsetForLeaderTopicResult topic = data.topics().find(tp.topic()); + if (topic == null) { + topic = new OffsetForLeaderTopicResult().setTopic(tp.topic()); + data.topics().add(topic); + } + topic.partitions().add(new EpochEndOffset() + .setPartition(tp.partition()) + .setErrorCode(Errors.NONE.code()) + .setLeaderEpoch(4) + .setEndOffset(0)); + }); + + OffsetsForLeaderEpochResponse response = new OffsetsForLeaderEpochResponse(data); + client.prepareResponseFrom(body -> { + OffsetsForLeaderEpochRequest request = (OffsetsForLeaderEpochRequest) body; + return expectedPartitions.equals(offsetForLeaderPartitionMap(request.data()).keySet()); + }, response, node); + } + + assertEquals(subscriptions.assignedPartitions(), allRequestedPartitions); + + fetcher.validateOffsetsIfNeeded(); + consumerClient.pollNoWakeup(); + + assertTrue(subscriptions.assignedPartitions() + .stream().noneMatch(subscriptions::awaitingValidation)); + } + + @Test + public void testOffsetValidationAwaitsNodeApiVersion() { + buildFetcher(); + assignFromUser(singleton(tp0)); + + Map partitionCounts = new HashMap<>(); + partitionCounts.put(tp0.topic(), 4); + + final int epochOne = 1; + + metadata.updateWithCurrentRequestVersion(RequestTestUtils.metadataUpdateWithIds("dummy", 1, + Collections.emptyMap(), partitionCounts, tp -> epochOne, topicIds), false, 0L); + + Node node = metadata.fetch().nodes().get(0); + assertFalse(client.isConnected(node.idString())); + + // Seek with a position and leader+epoch + Metadata.LeaderAndEpoch leaderAndEpoch = new Metadata.LeaderAndEpoch( + metadata.currentLeader(tp0).leader, Optional.of(epochOne)); + subscriptions.seekUnvalidated(tp0, new SubscriptionState.FetchPosition(20L, Optional.of(epochOne), leaderAndEpoch)); + assertFalse(client.isConnected(node.idString())); + assertTrue(subscriptions.awaitingValidation(tp0)); + + // No version information is initially available, but the node is now connected + fetcher.validateOffsetsIfNeeded(); + assertTrue(subscriptions.awaitingValidation(tp0)); + assertTrue(client.isConnected(node.idString())); + apiVersions.update(node.idString(), NodeApiVersions.create()); + + // On the next call, the OffsetForLeaderEpoch request is sent and validation completes + client.prepareResponseFrom( + prepareOffsetsForLeaderEpochResponse(tp0, Errors.NONE, epochOne, 30L), + node); + + fetcher.validateOffsetsIfNeeded(); + consumerClient.pollNoWakeup(); + + assertFalse(subscriptions.awaitingValidation(tp0)); + assertEquals(20L, subscriptions.position(tp0).offset); + } + + @Test + public void testOffsetValidationSkippedForOldBroker() { + // Old brokers may require CLUSTER permission to use the OffsetForLeaderEpoch API, + // so we should skip offset validation and not send the request. 
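+ // (validation is only attempted when the node supports OffsetForLeaderEpoch v3 or higher;
+ // the node below advertises at most v2, so the fetcher should simply clear the validation state)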
+ + buildFetcher(); + assignFromUser(singleton(tp0)); + + Map partitionCounts = new HashMap<>(); + partitionCounts.put(tp0.topic(), 4); + + final int epochOne = 1; + final int epochTwo = 2; + + // Start with metadata, epoch=1 + metadata.updateWithCurrentRequestVersion(RequestTestUtils.metadataUpdateWithIds("dummy", 1, + Collections.emptyMap(), partitionCounts, tp -> epochOne, topicIds), false, 0L); + + // Offset validation requires OffsetForLeaderEpoch request v3 or higher + Node node = metadata.fetch().nodes().get(0); + apiVersions.update(node.idString(), NodeApiVersions.create( + ApiKeys.OFFSET_FOR_LEADER_EPOCH.id, (short) 0, (short) 2)); + + { + // Seek with a position and leader+epoch + Metadata.LeaderAndEpoch leaderAndEpoch = new Metadata.LeaderAndEpoch( + metadata.currentLeader(tp0).leader, Optional.of(epochOne)); + subscriptions.seekUnvalidated(tp0, new SubscriptionState.FetchPosition(0, Optional.of(epochOne), leaderAndEpoch)); + + // Update metadata to epoch=2, enter validation + metadata.updateWithCurrentRequestVersion(RequestTestUtils.metadataUpdateWithIds("dummy", 1, + Collections.emptyMap(), partitionCounts, tp -> epochTwo, topicIds), false, 0L); + fetcher.validateOffsetsIfNeeded(); + + // Offset validation is skipped + assertFalse(subscriptions.awaitingValidation(tp0)); + } + + { + // Seek with a position and leader+epoch + Metadata.LeaderAndEpoch leaderAndEpoch = new Metadata.LeaderAndEpoch( + metadata.currentLeader(tp0).leader, Optional.of(epochOne)); + subscriptions.seekUnvalidated(tp0, new SubscriptionState.FetchPosition(0, Optional.of(epochOne), leaderAndEpoch)); + + // Update metadata to epoch=2, enter validation + metadata.updateWithCurrentRequestVersion(RequestTestUtils.metadataUpdateWithIds("dummy", 1, + Collections.emptyMap(), partitionCounts, tp -> epochTwo, topicIds), false, 0L); + + // Subscription should not stay in AWAITING_VALIDATION in prepareFetchRequest + assertEquals(1, fetcher.sendFetches()); + assertFalse(subscriptions.awaitingValidation(tp0)); + } + } + + @Test + public void testOffsetValidationSkippedForOldResponse() { + // Old responses may provide unreliable leader epoch, + // so we should skip offset validation and not send the request. 
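+ // (the injected metadata response below supplies no leader epoch, so there is nothing to validate against)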
+ buildFetcher(); + assignFromUser(singleton(tp0)); + + Map partitionCounts = new HashMap<>(); + partitionCounts.put(tp0.topic(), 4); + + final int epochOne = 1; + + metadata.updateWithCurrentRequestVersion(RequestTestUtils.metadataUpdateWithIds("dummy", 1, + Collections.emptyMap(), partitionCounts, tp -> epochOne, topicIds), false, 0L); + + Node node = metadata.fetch().nodes().get(0); + assertFalse(client.isConnected(node.idString())); + + // Seek with a position and leader+epoch + Metadata.LeaderAndEpoch leaderAndEpoch = new Metadata.LeaderAndEpoch( + metadata.currentLeader(tp0).leader, Optional.of(epochOne)); + subscriptions.seekUnvalidated(tp0, new SubscriptionState.FetchPosition(20L, Optional.of(epochOne), leaderAndEpoch)); + assertFalse(client.isConnected(node.idString())); + assertTrue(subscriptions.awaitingValidation(tp0)); + + // Inject an older version of the metadata response + final short responseVersion = 8; + metadata.updateWithCurrentRequestVersion(RequestTestUtils.metadataUpdateWith("dummy", 1, + Collections.emptyMap(), partitionCounts, tp -> null, MetadataResponse.PartitionMetadata::new, responseVersion, topicIds), false, 0L); + fetcher.validateOffsetsIfNeeded(); + // Offset validation is skipped + assertFalse(subscriptions.awaitingValidation(tp0)); + } + + @Test + public void testOffsetValidationResetOffsetForUndefinedEpochWithDefinedResetPolicy() { + testOffsetValidationWithGivenEpochOffset( + UNDEFINED_EPOCH, 0L, OffsetResetStrategy.EARLIEST); + } + + @Test + public void testOffsetValidationResetOffsetForUndefinedOffsetWithDefinedResetPolicy() { + testOffsetValidationWithGivenEpochOffset( + 2, UNDEFINED_EPOCH_OFFSET, OffsetResetStrategy.EARLIEST); + } + + @Test + public void testOffsetValidationResetOffsetForUndefinedEpochWithUndefinedResetPolicy() { + testOffsetValidationWithGivenEpochOffset( + UNDEFINED_EPOCH, 0L, OffsetResetStrategy.NONE); + } + + @Test + public void testOffsetValidationResetOffsetForUndefinedOffsetWithUndefinedResetPolicy() { + testOffsetValidationWithGivenEpochOffset( + 2, UNDEFINED_EPOCH_OFFSET, OffsetResetStrategy.NONE); + } + + @Test + public void testOffsetValidationTriggerLogTruncationForBadOffsetWithUndefinedResetPolicy() { + testOffsetValidationWithGivenEpochOffset( + 1, 1L, OffsetResetStrategy.NONE); + } + + private void testOffsetValidationWithGivenEpochOffset(int leaderEpoch, + long endOffset, + OffsetResetStrategy offsetResetStrategy) { + buildFetcher(offsetResetStrategy); + assignFromUser(singleton(tp0)); + + Map partitionCounts = new HashMap<>(); + partitionCounts.put(tp0.topic(), 4); + + final int epochOne = 1; + final long initialOffset = 5; + + metadata.updateWithCurrentRequestVersion(RequestTestUtils.metadataUpdateWithIds("dummy", 1, + Collections.emptyMap(), partitionCounts, tp -> epochOne, topicIds), false, 0L); + + // Offset validation requires OffsetForLeaderEpoch request v3 or higher + Node node = metadata.fetch().nodes().get(0); + apiVersions.update(node.idString(), NodeApiVersions.create()); + + Metadata.LeaderAndEpoch leaderAndEpoch = new Metadata.LeaderAndEpoch(metadata.currentLeader(tp0).leader, Optional.of(epochOne)); + subscriptions.seekUnvalidated(tp0, new SubscriptionState.FetchPosition(initialOffset, Optional.of(epochOne), leaderAndEpoch)); + + fetcher.validateOffsetsIfNeeded(); + + consumerClient.poll(time.timer(Duration.ZERO)); + assertTrue(subscriptions.awaitingValidation(tp0)); + assertTrue(client.hasInFlightRequests()); + + client.respond( + offsetsForLeaderEpochRequestMatcher(tp0, epochOne, epochOne), + 
prepareOffsetsForLeaderEpochResponse(tp0, Errors.NONE, leaderEpoch, endOffset)); + consumerClient.poll(time.timer(Duration.ZERO)); + + if (offsetResetStrategy == OffsetResetStrategy.NONE) { + LogTruncationException thrown = + assertThrows(LogTruncationException.class, () -> fetcher.validateOffsetsIfNeeded()); + assertEquals(singletonMap(tp0, initialOffset), thrown.offsetOutOfRangePartitions()); + + if (endOffset == UNDEFINED_EPOCH_OFFSET || leaderEpoch == UNDEFINED_EPOCH) { + assertEquals(Collections.emptyMap(), thrown.divergentOffsets()); + } else { + OffsetAndMetadata expectedDivergentOffset = new OffsetAndMetadata( + endOffset, Optional.of(leaderEpoch), ""); + assertEquals(singletonMap(tp0, expectedDivergentOffset), thrown.divergentOffsets()); + } + assertTrue(subscriptions.awaitingValidation(tp0)); + } else { + fetcher.validateOffsetsIfNeeded(); + assertFalse(subscriptions.awaitingValidation(tp0)); + } + } + + @Test + public void testOffsetValidationHandlesSeekWithInflightOffsetForLeaderRequest() { + buildFetcher(); + assignFromUser(singleton(tp0)); + + Map partitionCounts = new HashMap<>(); + partitionCounts.put(tp0.topic(), 4); + + final int epochOne = 1; + + metadata.updateWithCurrentRequestVersion(RequestTestUtils.metadataUpdateWithIds("dummy", 1, + Collections.emptyMap(), partitionCounts, tp -> epochOne, topicIds), false, 0L); + + // Offset validation requires OffsetForLeaderEpoch request v3 or higher + Node node = metadata.fetch().nodes().get(0); + apiVersions.update(node.idString(), NodeApiVersions.create()); + + Metadata.LeaderAndEpoch leaderAndEpoch = new Metadata.LeaderAndEpoch(metadata.currentLeader(tp0).leader, Optional.of(epochOne)); + subscriptions.seekUnvalidated(tp0, new SubscriptionState.FetchPosition(0, Optional.of(epochOne), leaderAndEpoch)); + + fetcher.validateOffsetsIfNeeded(); + consumerClient.poll(time.timer(Duration.ZERO)); + assertTrue(subscriptions.awaitingValidation(tp0)); + assertTrue(client.hasInFlightRequests()); + + // While the OffsetForLeaderEpoch request is in-flight, we seek to a different offset. + subscriptions.seekUnvalidated(tp0, new SubscriptionState.FetchPosition(5, Optional.of(epochOne), leaderAndEpoch)); + assertTrue(subscriptions.awaitingValidation(tp0)); + + client.respond( + offsetsForLeaderEpochRequestMatcher(tp0, epochOne, epochOne), + prepareOffsetsForLeaderEpochResponse(tp0, Errors.NONE, 0, 0L)); + consumerClient.poll(time.timer(Duration.ZERO)); + + // The response should be ignored since we were validating a different position. 
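+ // The partition therefore stays in AWAITING_VALIDATION until a response for the new position arrives.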
+ assertTrue(subscriptions.awaitingValidation(tp0)); + } + + @Test + public void testOffsetValidationFencing() { + buildFetcher(); + assignFromUser(singleton(tp0)); + + Map partitionCounts = new HashMap<>(); + partitionCounts.put(tp0.topic(), 4); + + final int epochOne = 1; + final int epochTwo = 2; + final int epochThree = 3; + + // Start with metadata, epoch=1 + metadata.updateWithCurrentRequestVersion(RequestTestUtils.metadataUpdateWithIds("dummy", 1, + Collections.emptyMap(), partitionCounts, tp -> epochOne, topicIds), false, 0L); + + // Offset validation requires OffsetForLeaderEpoch request v3 or higher + Node node = metadata.fetch().nodes().get(0); + apiVersions.update(node.idString(), NodeApiVersions.create()); + + // Seek with a position and leader+epoch + Metadata.LeaderAndEpoch leaderAndEpoch = new Metadata.LeaderAndEpoch(metadata.currentLeader(tp0).leader, Optional.of(epochOne)); + subscriptions.seekValidated(tp0, new SubscriptionState.FetchPosition(0, Optional.of(epochOne), leaderAndEpoch)); + + // Update metadata to epoch=2, enter validation + metadata.updateWithCurrentRequestVersion(RequestTestUtils.metadataUpdateWithIds("dummy", 1, + Collections.emptyMap(), partitionCounts, tp -> epochTwo, topicIds), false, 0L); + fetcher.validateOffsetsIfNeeded(); + assertTrue(subscriptions.awaitingValidation(tp0)); + + // Update the position to epoch=3, as we would from a fetch + subscriptions.completeValidation(tp0); + SubscriptionState.FetchPosition nextPosition = new SubscriptionState.FetchPosition( + 10, + Optional.of(epochTwo), + new Metadata.LeaderAndEpoch(leaderAndEpoch.leader, Optional.of(epochTwo))); + subscriptions.position(tp0, nextPosition); + subscriptions.maybeValidatePositionForCurrentLeader(apiVersions, tp0, new Metadata.LeaderAndEpoch(leaderAndEpoch.leader, Optional.of(epochThree))); + + // Prepare offset list response from async validation with epoch=2 + client.prepareResponse(prepareOffsetsForLeaderEpochResponse(tp0, Errors.NONE, epochTwo, 10L)); + consumerClient.pollNoWakeup(); + assertTrue(subscriptions.awaitingValidation(tp0), "Expected validation to fail since leader epoch changed"); + + // Next round of validation, should succeed in validating the position + fetcher.validateOffsetsIfNeeded(); + client.prepareResponse(prepareOffsetsForLeaderEpochResponse(tp0, Errors.NONE, epochThree, 10L)); + consumerClient.pollNoWakeup(); + assertFalse(subscriptions.awaitingValidation(tp0), "Expected validation to succeed with latest epoch"); + } + + @Test + public void testSkipValidationForOlderApiVersion() { + buildFetcher(); + assignFromUser(singleton(tp0)); + + Map partitionCounts = new HashMap<>(); + partitionCounts.put(tp0.topic(), 4); + + apiVersions.update("0", NodeApiVersions.create(ApiKeys.OFFSET_FOR_LEADER_EPOCH.id, (short) 0, (short) 2)); + + // Start with metadata, epoch=1 + metadata.updateWithCurrentRequestVersion(RequestTestUtils.metadataUpdateWithIds("dummy", 1, + Collections.emptyMap(), partitionCounts, tp -> 1, topicIds), false, 0L); + + // Request offset reset + subscriptions.requestOffsetReset(tp0, OffsetResetStrategy.LATEST); + + // Since we have no position due to reset, no fetch is sent + assertEquals(0, fetcher.sendFetches()); + + // Still no position, ensure offset validation logic did not transition us to FETCHING state + assertEquals(0, fetcher.sendFetches()); + + // Complete reset and now we can fetch + fetcher.resetOffsetIfNeeded(tp0, OffsetResetStrategy.LATEST, + new Fetcher.ListOffsetData(100, 1L, Optional.empty())); + assertEquals(1, 
fetcher.sendFetches()); + } + + @Test + public void testTruncationDetected() { + // Create some records that include a leader epoch (1) + MemoryRecordsBuilder builder = MemoryRecords.builder( + ByteBuffer.allocate(1024), + RecordBatch.CURRENT_MAGIC_VALUE, + CompressionType.NONE, + TimestampType.CREATE_TIME, + 0L, + RecordBatch.NO_TIMESTAMP, + RecordBatch.NO_PRODUCER_ID, + RecordBatch.NO_PRODUCER_EPOCH, + RecordBatch.NO_SEQUENCE, + false, + 1 // record epoch is earlier than the leader epoch on the client + ); + builder.appendWithOffset(0L, 0L, "key".getBytes(), "value-1".getBytes()); + builder.appendWithOffset(1L, 0L, "key".getBytes(), "value-2".getBytes()); + builder.appendWithOffset(2L, 0L, "key".getBytes(), "value-3".getBytes()); + MemoryRecords records = builder.build(); + + buildFetcher(); + assignFromUser(singleton(tp0)); + + // Initialize the epoch=2 + Map partitionCounts = new HashMap<>(); + partitionCounts.put(tp0.topic(), 4); + MetadataResponse metadataResponse = RequestTestUtils.metadataUpdateWithIds("dummy", 1, Collections.emptyMap(), + partitionCounts, tp -> 2, topicIds); + metadata.updateWithCurrentRequestVersion(metadataResponse, false, 0L); + + // Offset validation requires OffsetForLeaderEpoch request v3 or higher + Node node = metadata.fetch().nodes().get(0); + apiVersions.update(node.idString(), NodeApiVersions.create()); + + // Seek + Metadata.LeaderAndEpoch leaderAndEpoch = new Metadata.LeaderAndEpoch(metadata.currentLeader(tp0).leader, Optional.of(1)); + subscriptions.seekValidated(tp0, new SubscriptionState.FetchPosition(0, Optional.of(1), leaderAndEpoch)); + + // Check for truncation, this should cause tp0 to go into validation + fetcher.validateOffsetsIfNeeded(); + + // No fetches sent since we entered validation + assertEquals(0, fetcher.sendFetches()); + assertFalse(fetcher.hasCompletedFetches()); + assertTrue(subscriptions.awaitingValidation(tp0)); + + // Prepare OffsetForEpoch response then check that we update the subscription position correctly. 
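+ // The returned end offset (10) is ahead of the fetch position (0) and the epochs match,
+ // so no divergence should be detected and the position is considered validated.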
+ client.prepareResponse(prepareOffsetsForLeaderEpochResponse(tp0, Errors.NONE, 1, 10L)); + consumerClient.pollNoWakeup(); + + assertFalse(subscriptions.awaitingValidation(tp0)); + + // Fetch again, now it works + assertEquals(1, fetcher.sendFetches()); + assertFalse(fetcher.hasCompletedFetches()); + + client.prepareResponse(fullFetchResponse(tidp0, records, Errors.NONE, 100L, 0)); + consumerClient.pollNoWakeup(); + assertTrue(fetcher.hasCompletedFetches()); + + Map>> partitionRecords = fetchedRecords(); + assertTrue(partitionRecords.containsKey(tp0)); + + assertEquals(subscriptions.position(tp0).offset, 3L); + assertOptional(subscriptions.position(tp0).offsetEpoch, value -> assertEquals(value.intValue(), 1)); + } + + @Test + public void testPreferredReadReplica() { + buildFetcher(new MetricConfig(), OffsetResetStrategy.EARLIEST, new BytesDeserializer(), new BytesDeserializer(), + Integer.MAX_VALUE, IsolationLevel.READ_COMMITTED, Duration.ofMinutes(5).toMillis()); + + subscriptions.assignFromUser(singleton(tp0)); + client.updateMetadata(RequestTestUtils.metadataUpdateWithIds(2, singletonMap(topicName, 4), tp -> validLeaderEpoch, topicIds)); + subscriptions.seek(tp0, 0); + + // Node preferred replica before first fetch response + Node selected = fetcher.selectReadReplica(tp0, Node.noNode(), time.milliseconds()); + assertEquals(selected.id(), -1); + + assertEquals(1, fetcher.sendFetches()); + assertFalse(fetcher.hasCompletedFetches()); + + // Set preferred read replica to node=1 + client.prepareResponse(fullFetchResponse(tidp0, this.records, Errors.NONE, 100L, + FetchResponse.INVALID_LAST_STABLE_OFFSET, 0, Optional.of(1))); + consumerClient.poll(time.timer(0)); + assertTrue(fetcher.hasCompletedFetches()); + + Map>> partitionRecords = fetchedRecords(); + assertTrue(partitionRecords.containsKey(tp0)); + + // verify + selected = fetcher.selectReadReplica(tp0, Node.noNode(), time.milliseconds()); + assertEquals(selected.id(), 1); + + + assertEquals(1, fetcher.sendFetches()); + assertFalse(fetcher.hasCompletedFetches()); + + // Set preferred read replica to node=2, which isn't in our metadata, should revert to leader + client.prepareResponse(fullFetchResponse(tidp0, this.records, Errors.NONE, 100L, + FetchResponse.INVALID_LAST_STABLE_OFFSET, 0, Optional.of(2))); + consumerClient.poll(time.timer(0)); + assertTrue(fetcher.hasCompletedFetches()); + fetchedRecords(); + selected = fetcher.selectReadReplica(tp0, Node.noNode(), time.milliseconds()); + assertEquals(selected.id(), -1); + } + + @Test + public void testPreferredReadReplicaOffsetError() { + buildFetcher(new MetricConfig(), OffsetResetStrategy.EARLIEST, new BytesDeserializer(), new BytesDeserializer(), + Integer.MAX_VALUE, IsolationLevel.READ_COMMITTED, Duration.ofMinutes(5).toMillis()); + + subscriptions.assignFromUser(singleton(tp0)); + client.updateMetadata(RequestTestUtils.metadataUpdateWithIds(2, singletonMap(topicName, 4), tp -> validLeaderEpoch, topicIds)); + + subscriptions.seek(tp0, 0); + + assertEquals(1, fetcher.sendFetches()); + assertFalse(fetcher.hasCompletedFetches()); + + client.prepareResponse(fullFetchResponse(tidp0, this.records, Errors.NONE, 100L, + FetchResponse.INVALID_LAST_STABLE_OFFSET, 0, Optional.of(1))); + consumerClient.poll(time.timer(0)); + assertTrue(fetcher.hasCompletedFetches()); + + fetchedRecords(); + + Node selected = fetcher.selectReadReplica(tp0, Node.noNode(), time.milliseconds()); + assertEquals(selected.id(), 1); + + // Return an error, should unset the preferred read replica + assertEquals(1, 
fetcher.sendFetches()); + assertFalse(fetcher.hasCompletedFetches()); + + client.prepareResponse(fullFetchResponse(tidp0, this.records, Errors.OFFSET_OUT_OF_RANGE, 100L, + FetchResponse.INVALID_LAST_STABLE_OFFSET, 0, Optional.empty())); + consumerClient.poll(time.timer(0)); + assertTrue(fetcher.hasCompletedFetches()); + + fetchedRecords(); + + selected = fetcher.selectReadReplica(tp0, Node.noNode(), time.milliseconds()); + assertEquals(selected.id(), -1); + } + + @Test + public void testFetchCompletedBeforeHandlerAdded() { + buildFetcher(); + assignFromUser(singleton(tp0)); + subscriptions.seek(tp0, 0); + fetcher.sendFetches(); + client.prepareResponse(fullFetchResponse(tidp0, buildRecords(1L, 1, 1), Errors.NONE, 100L, 0)); + consumerClient.poll(time.timer(0)); + fetchedRecords(); + + Metadata.LeaderAndEpoch leaderAndEpoch = subscriptions.position(tp0).currentLeader; + assertTrue(leaderAndEpoch.leader.isPresent()); + Node readReplica = fetcher.selectReadReplica(tp0, leaderAndEpoch.leader.get(), time.milliseconds()); + + AtomicBoolean wokenUp = new AtomicBoolean(false); + client.setWakeupHook(() -> { + if (!wokenUp.getAndSet(true)) { + consumerClient.disconnectAsync(readReplica); + consumerClient.poll(time.timer(0)); + } + }); + + assertEquals(1, fetcher.sendFetches()); + + consumerClient.disconnectAsync(readReplica); + consumerClient.poll(time.timer(0)); + + assertEquals(1, fetcher.sendFetches()); + } + + @Test + public void testCorruptMessageError() { + buildFetcher(); + assignFromUser(singleton(tp0)); + subscriptions.seek(tp0, 0); + + assertEquals(1, fetcher.sendFetches()); + assertFalse(fetcher.hasCompletedFetches()); + + // Prepare a response with the CORRUPT_MESSAGE error. + client.prepareResponse(fullFetchResponse( + tidp0, + buildRecords(1L, 1, 1), + Errors.CORRUPT_MESSAGE, + 100L, 0)); + consumerClient.poll(time.timer(0)); + assertTrue(fetcher.hasCompletedFetches()); + + // Trigger the exception. 
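+ // CORRUPT_MESSAGE is not treated as retriable here; draining the completed fetch is expected
+ // to surface it to the caller as a KafkaException.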
+ assertThrows(KafkaException.class, this::fetchedRecords); + } + + @Test + public void testBeginningOffsets() { + buildFetcher(); + assignFromUser(singleton(tp0)); + client.prepareResponse(listOffsetResponse(tp0, Errors.NONE, ListOffsetsRequest.EARLIEST_TIMESTAMP, 2L)); + assertEquals(singletonMap(tp0, 2L), fetcher.beginningOffsets(singleton(tp0), time.timer(5000L))); + } + + @Test + public void testBeginningOffsetsDuplicateTopicPartition() { + buildFetcher(); + assignFromUser(singleton(tp0)); + client.prepareResponse(listOffsetResponse(tp0, Errors.NONE, ListOffsetsRequest.EARLIEST_TIMESTAMP, 2L)); + assertEquals(singletonMap(tp0, 2L), fetcher.beginningOffsets(asList(tp0, tp0), time.timer(5000L))); + } + + @Test + public void testBeginningOffsetsMultipleTopicPartitions() { + buildFetcher(); + Map expectedOffsets = new HashMap<>(); + expectedOffsets.put(tp0, 2L); + expectedOffsets.put(tp1, 4L); + expectedOffsets.put(tp2, 6L); + assignFromUser(expectedOffsets.keySet()); + client.prepareResponse(listOffsetResponse(expectedOffsets, Errors.NONE, ListOffsetsRequest.EARLIEST_TIMESTAMP, ListOffsetsResponse.UNKNOWN_EPOCH)); + assertEquals(expectedOffsets, fetcher.beginningOffsets(asList(tp0, tp1, tp2), time.timer(5000L))); + } + + @Test + public void testBeginningOffsetsEmpty() { + buildFetcher(); + assertEquals(emptyMap(), fetcher.beginningOffsets(emptyList(), time.timer(5000L))); + } + + @Test + public void testEndOffsets() { + buildFetcher(); + assignFromUser(singleton(tp0)); + client.prepareResponse(listOffsetResponse(tp0, Errors.NONE, ListOffsetsRequest.LATEST_TIMESTAMP, 5L)); + assertEquals(singletonMap(tp0, 5L), fetcher.endOffsets(singleton(tp0), time.timer(5000L))); + } + + @Test + public void testEndOffsetsDuplicateTopicPartition() { + buildFetcher(); + assignFromUser(singleton(tp0)); + client.prepareResponse(listOffsetResponse(tp0, Errors.NONE, ListOffsetsRequest.LATEST_TIMESTAMP, 5L)); + assertEquals(singletonMap(tp0, 5L), fetcher.endOffsets(asList(tp0, tp0), time.timer(5000L))); + } + + @Test + public void testEndOffsetsMultipleTopicPartitions() { + buildFetcher(); + Map expectedOffsets = new HashMap<>(); + expectedOffsets.put(tp0, 5L); + expectedOffsets.put(tp1, 7L); + expectedOffsets.put(tp2, 9L); + assignFromUser(expectedOffsets.keySet()); + client.prepareResponse(listOffsetResponse(expectedOffsets, Errors.NONE, ListOffsetsRequest.LATEST_TIMESTAMP, ListOffsetsResponse.UNKNOWN_EPOCH)); + assertEquals(expectedOffsets, fetcher.endOffsets(asList(tp0, tp1, tp2), time.timer(5000L))); + } + + @Test + public void testEndOffsetsEmpty() { + buildFetcher(); + assertEquals(emptyMap(), fetcher.endOffsets(emptyList(), time.timer(5000L))); + } + + private MockClient.RequestMatcher offsetsForLeaderEpochRequestMatcher( + TopicPartition topicPartition, + int currentLeaderEpoch, + int leaderEpoch + ) { + return request -> { + OffsetsForLeaderEpochRequest epochRequest = (OffsetsForLeaderEpochRequest) request; + OffsetForLeaderPartition partition = offsetForLeaderPartitionMap(epochRequest.data()) + .get(topicPartition); + return partition != null + && partition.currentLeaderEpoch() == currentLeaderEpoch + && partition.leaderEpoch() == leaderEpoch; + }; + } + + private OffsetsForLeaderEpochResponse prepareOffsetsForLeaderEpochResponse( + TopicPartition topicPartition, + Errors error, + int leaderEpoch, + long endOffset + ) { + OffsetForLeaderEpochResponseData data = new OffsetForLeaderEpochResponseData(); + data.topics().add(new OffsetForLeaderTopicResult() + .setTopic(topicPartition.topic()) + 
.setPartitions(Collections.singletonList(new EpochEndOffset() + .setPartition(topicPartition.partition()) + .setErrorCode(error.code()) + .setLeaderEpoch(leaderEpoch) + .setEndOffset(endOffset)))); + return new OffsetsForLeaderEpochResponse(data); + } + + private Map offsetForLeaderPartitionMap( + OffsetForLeaderEpochRequestData data + ) { + Map result = new HashMap<>(); + data.topics().forEach(topic -> + topic.partitions().forEach(partition -> + result.put(new TopicPartition(topic.topic(), partition.partition()), partition))); + return result; + } + + private MockClient.RequestMatcher listOffsetRequestMatcher(final long timestamp) { + return listOffsetRequestMatcher(timestamp, ListOffsetsResponse.UNKNOWN_EPOCH); + } + + private MockClient.RequestMatcher listOffsetRequestMatcher(final long timestamp, final int leaderEpoch) { + // matches any list offset request with the provided timestamp + return body -> { + ListOffsetsRequest req = (ListOffsetsRequest) body; + ListOffsetsTopic topic = req.topics().get(0); + ListOffsetsPartition partition = topic.partitions().get(0); + return tp0.topic().equals(topic.name()) + && tp0.partition() == partition.partitionIndex() + && timestamp == partition.timestamp() + && leaderEpoch == partition.currentLeaderEpoch(); + }; + } + + private ListOffsetsResponse listOffsetResponse(Errors error, long timestamp, long offset) { + return listOffsetResponse(tp0, error, timestamp, offset); + } + + private ListOffsetsResponse listOffsetResponse(TopicPartition tp, Errors error, long timestamp, long offset) { + return listOffsetResponse(tp, error, timestamp, offset, ListOffsetsResponse.UNKNOWN_EPOCH); + } + + private ListOffsetsResponse listOffsetResponse(TopicPartition tp, Errors error, long timestamp, long offset, int leaderEpoch) { + Map offsets = new HashMap<>(); + offsets.put(tp, offset); + return listOffsetResponse(offsets, error, timestamp, leaderEpoch); + } + + private ListOffsetsResponse listOffsetResponse(Map offsets, Errors error, long timestamp, int leaderEpoch) { + Map> responses = new HashMap<>(); + for (Map.Entry entry : offsets.entrySet()) { + TopicPartition tp = entry.getKey(); + responses.putIfAbsent(tp.topic(), new ArrayList<>()); + responses.get(tp.topic()).add(new ListOffsetsPartitionResponse() + .setPartitionIndex(tp.partition()) + .setErrorCode(error.code()) + .setOffset(entry.getValue()) + .setTimestamp(timestamp) + .setLeaderEpoch(leaderEpoch)); + } + List topics = new ArrayList<>(); + for (Map.Entry> response : responses.entrySet()) { + topics.add(new ListOffsetsTopicResponse() + .setName(response.getKey()) + .setPartitions(response.getValue())); + } + ListOffsetsResponseData data = new ListOffsetsResponseData().setTopics(topics); + return new ListOffsetsResponse(data); + } + + private FetchResponse fetchResponseWithTopLevelError(TopicIdPartition tp, Errors error, int throttleTime) { + Map partitions = Collections.singletonMap(tp, + new FetchResponseData.PartitionData() + .setPartitionIndex(tp.topicPartition().partition()) + .setErrorCode(error.code()) + .setHighWatermark(FetchResponse.INVALID_HIGH_WATERMARK)); + return FetchResponse.of(error, throttleTime, INVALID_SESSION_ID, new LinkedHashMap<>(partitions)); + } + + private FetchResponse fullFetchResponseWithAbortedTransactions(MemoryRecords records, + List abortedTransactions, + Errors error, + long lastStableOffset, + long hw, + int throttleTime) { + Map partitions = Collections.singletonMap(tidp0, + new FetchResponseData.PartitionData() + .setPartitionIndex(tp0.partition()) + 
.setErrorCode(error.code()) + .setHighWatermark(hw) + .setLastStableOffset(lastStableOffset) + .setLogStartOffset(0) + .setAbortedTransactions(abortedTransactions) + .setRecords(records)); + return FetchResponse.of(Errors.NONE, throttleTime, INVALID_SESSION_ID, new LinkedHashMap<>(partitions)); + } + + private FetchResponse fullFetchResponse(int sessionId, TopicIdPartition tp, MemoryRecords records, Errors error, long hw, int throttleTime) { + return fullFetchResponse(sessionId, tp, records, error, hw, FetchResponse.INVALID_LAST_STABLE_OFFSET, throttleTime); + } + + private FetchResponse fullFetchResponse(TopicIdPartition tp, MemoryRecords records, Errors error, long hw, int throttleTime) { + return fullFetchResponse(tp, records, error, hw, FetchResponse.INVALID_LAST_STABLE_OFFSET, throttleTime); + } + + private FetchResponse fullFetchResponse(TopicIdPartition tp, MemoryRecords records, Errors error, long hw, + long lastStableOffset, int throttleTime) { + return fullFetchResponse(INVALID_SESSION_ID, tp, records, error, hw, lastStableOffset, throttleTime); + } + + private FetchResponse fullFetchResponse(int sessionId, TopicIdPartition tp, MemoryRecords records, Errors error, long hw, + long lastStableOffset, int throttleTime) { + Map partitions = Collections.singletonMap(tp, + new FetchResponseData.PartitionData() + .setPartitionIndex(tp.topicPartition().partition()) + .setErrorCode(error.code()) + .setHighWatermark(hw) + .setLastStableOffset(lastStableOffset) + .setLogStartOffset(0) + .setRecords(records)); + return FetchResponse.of(Errors.NONE, throttleTime, sessionId, new LinkedHashMap<>(partitions)); + } + + private FetchResponse fullFetchResponse(TopicIdPartition tp, MemoryRecords records, Errors error, long hw, + long lastStableOffset, int throttleTime, Optional preferredReplicaId) { + Map partitions = Collections.singletonMap(tp, + new FetchResponseData.PartitionData() + .setPartitionIndex(tp.topicPartition().partition()) + .setErrorCode(error.code()) + .setHighWatermark(hw) + .setLastStableOffset(lastStableOffset) + .setLogStartOffset(0) + .setRecords(records) + .setPreferredReadReplica(preferredReplicaId.orElse(FetchResponse.INVALID_PREFERRED_REPLICA_ID))); + return FetchResponse.of(Errors.NONE, throttleTime, INVALID_SESSION_ID, new LinkedHashMap<>(partitions)); + } + + private FetchResponse fetchResponse(TopicIdPartition tp, MemoryRecords records, Errors error, long hw, + long lastStableOffset, long logStartOffset, int throttleTime) { + Map partitions = Collections.singletonMap(tp, + new FetchResponseData.PartitionData() + .setPartitionIndex(tp.topicPartition().partition()) + .setErrorCode(error.code()) + .setHighWatermark(hw) + .setLastStableOffset(lastStableOffset) + .setLogStartOffset(logStartOffset) + .setRecords(records)); + return FetchResponse.of(Errors.NONE, throttleTime, INVALID_SESSION_ID, new LinkedHashMap<>(partitions)); + } + + private MetadataResponse newMetadataResponse(String topic, Errors error) { + List partitionsMetadata = new ArrayList<>(); + if (error == Errors.NONE) { + Optional foundMetadata = initialUpdateResponse.topicMetadata() + .stream() + .filter(topicMetadata -> topicMetadata.topic().equals(topic)) + .findFirst(); + foundMetadata.ifPresent(topicMetadata -> { + partitionsMetadata.addAll(topicMetadata.partitionMetadata()); + }); + } + + MetadataResponse.TopicMetadata topicMetadata = new MetadataResponse.TopicMetadata(error, topic, false, + partitionsMetadata); + List brokers = new ArrayList<>(initialUpdateResponse.brokers()); + return 
RequestTestUtils.metadataResponse(brokers, initialUpdateResponse.clusterId(), + initialUpdateResponse.controller().id(), Collections.singletonList(topicMetadata)); + } + + @SuppressWarnings("unchecked") + private Map>> fetchedRecords() { + return (Map) fetcher.fetchedRecords(); + } + + private void buildFetcher(int maxPollRecords) { + buildFetcher(OffsetResetStrategy.EARLIEST, new ByteArrayDeserializer(), new ByteArrayDeserializer(), + maxPollRecords, IsolationLevel.READ_UNCOMMITTED); + } + + private void buildFetcher() { + buildFetcher(Integer.MAX_VALUE); + } + + private void buildFetcher(Deserializer keyDeserializer, + Deserializer valueDeserializer) { + buildFetcher(OffsetResetStrategy.EARLIEST, keyDeserializer, valueDeserializer, + Integer.MAX_VALUE, IsolationLevel.READ_UNCOMMITTED); + } + + private void buildFetcher(OffsetResetStrategy offsetResetStrategy) { + buildFetcher(new MetricConfig(), offsetResetStrategy, + new ByteArrayDeserializer(), new ByteArrayDeserializer(), + Integer.MAX_VALUE, IsolationLevel.READ_UNCOMMITTED); + } + + private void buildFetcher(OffsetResetStrategy offsetResetStrategy, + Deserializer keyDeserializer, + Deserializer valueDeserializer, + int maxPollRecords, + IsolationLevel isolationLevel) { + buildFetcher(new MetricConfig(), offsetResetStrategy, keyDeserializer, valueDeserializer, + maxPollRecords, isolationLevel); + } + + private void buildFetcher(MetricConfig metricConfig, + OffsetResetStrategy offsetResetStrategy, + Deserializer keyDeserializer, + Deserializer valueDeserializer, + int maxPollRecords, + IsolationLevel isolationLevel) { + buildFetcher(metricConfig, offsetResetStrategy, keyDeserializer, valueDeserializer, maxPollRecords, isolationLevel, Long.MAX_VALUE); + } + + private void buildFetcher(MetricConfig metricConfig, + OffsetResetStrategy offsetResetStrategy, + Deserializer keyDeserializer, + Deserializer valueDeserializer, + int maxPollRecords, + IsolationLevel isolationLevel, + long metadataExpireMs) { + LogContext logContext = new LogContext(); + SubscriptionState subscriptionState = new SubscriptionState(logContext, offsetResetStrategy); + buildFetcher(metricConfig, keyDeserializer, valueDeserializer, maxPollRecords, isolationLevel, metadataExpireMs, + subscriptionState, logContext); + } + + private void buildFetcher(SubscriptionState subscriptionState, LogContext logContext) { + buildFetcher(new MetricConfig(), new ByteArrayDeserializer(), new ByteArrayDeserializer(), Integer.MAX_VALUE, + IsolationLevel.READ_UNCOMMITTED, Long.MAX_VALUE, subscriptionState, logContext); + } + + private void buildFetcher(MetricConfig metricConfig, + Deserializer keyDeserializer, + Deserializer valueDeserializer, + int maxPollRecords, + IsolationLevel isolationLevel, + long metadataExpireMs, + SubscriptionState subscriptionState, + LogContext logContext) { + buildDependencies(metricConfig, metadataExpireMs, subscriptionState, logContext); + fetcher = new Fetcher<>( + new LogContext(), + consumerClient, + minBytes, + maxBytes, + maxWaitMs, + fetchSize, + maxPollRecords, + true, // check crc + "", + keyDeserializer, + valueDeserializer, + metadata, + subscriptions, + metrics, + metricsRegistry, + time, + retryBackoffMs, + requestTimeoutMs, + isolationLevel, + apiVersions); + } + + private void buildDependencies(MetricConfig metricConfig, + long metadataExpireMs, + SubscriptionState subscriptionState, + LogContext logContext) { + time = new MockTime(1); + subscriptions = subscriptionState; + metadata = new ConsumerMetadata(0, metadataExpireMs, false, false, + 
subscriptions, logContext, new ClusterResourceListeners()); + client = new MockClient(time, metadata); + metrics = new Metrics(metricConfig, time); + consumerClient = new ConsumerNetworkClient(logContext, client, metadata, time, + 100, 1000, Integer.MAX_VALUE); + metricsRegistry = new FetcherMetricsRegistry(metricConfig.tags().keySet(), "consumer" + groupId); + } + + private List collectRecordOffsets(List> records) { + return records.stream().map(ConsumerRecord::offset).collect(Collectors.toList()); + } +} diff --git a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/HeartbeatTest.java b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/HeartbeatTest.java new file mode 100644 index 0000000..186546f --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/HeartbeatTest.java @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients.consumer.internals; + +import org.apache.kafka.clients.GroupRebalanceConfig; +import org.apache.kafka.common.utils.MockTime; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.util.Optional; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class HeartbeatTest { + private int sessionTimeoutMs = 300; + private int heartbeatIntervalMs = 100; + private int maxPollIntervalMs = 900; + private long retryBackoffMs = 10L; + private MockTime time = new MockTime(); + + private Heartbeat heartbeat; + + @BeforeEach + public void setUp() { + GroupRebalanceConfig rebalanceConfig = new GroupRebalanceConfig(sessionTimeoutMs, + maxPollIntervalMs, + heartbeatIntervalMs, + "group_id", + Optional.empty(), + retryBackoffMs, + true); + heartbeat = new Heartbeat(rebalanceConfig, time); + } + + @Test + public void testShouldHeartbeat() { + heartbeat.sentHeartbeat(time.milliseconds()); + time.sleep((long) ((float) heartbeatIntervalMs * 1.1)); + assertTrue(heartbeat.shouldHeartbeat(time.milliseconds())); + } + + @Test + public void testShouldNotHeartbeat() { + heartbeat.sentHeartbeat(time.milliseconds()); + time.sleep(heartbeatIntervalMs / 2); + assertFalse(heartbeat.shouldHeartbeat(time.milliseconds())); + } + + @Test + public void testTimeToNextHeartbeat() { + heartbeat.sentHeartbeat(time.milliseconds()); + assertEquals(heartbeatIntervalMs, heartbeat.timeToNextHeartbeat(time.milliseconds())); + + time.sleep(heartbeatIntervalMs); + assertEquals(0, heartbeat.timeToNextHeartbeat(time.milliseconds())); + + time.sleep(heartbeatIntervalMs); + assertEquals(0, heartbeat.timeToNextHeartbeat(time.milliseconds())); + } + + @Test + public void 
testSessionTimeoutExpired() { + heartbeat.sentHeartbeat(time.milliseconds()); + time.sleep(sessionTimeoutMs + 5); + assertTrue(heartbeat.sessionTimeoutExpired(time.milliseconds())); + } + + @Test + public void testResetSession() { + heartbeat.sentHeartbeat(time.milliseconds()); + time.sleep(sessionTimeoutMs + 5); + heartbeat.resetSessionTimeout(); + assertFalse(heartbeat.sessionTimeoutExpired(time.milliseconds())); + + // Resetting the session timeout should not reset the poll timeout + time.sleep(maxPollIntervalMs + 1); + heartbeat.resetSessionTimeout(); + assertTrue(heartbeat.pollTimeoutExpired(time.milliseconds())); + } + + @Test + public void testResetTimeouts() { + time.sleep(maxPollIntervalMs); + assertTrue(heartbeat.sessionTimeoutExpired(time.milliseconds())); + assertEquals(0, heartbeat.timeToNextHeartbeat(time.milliseconds())); + assertTrue(heartbeat.pollTimeoutExpired(time.milliseconds())); + + heartbeat.resetTimeouts(); + assertFalse(heartbeat.sessionTimeoutExpired(time.milliseconds())); + assertEquals(heartbeatIntervalMs, heartbeat.timeToNextHeartbeat(time.milliseconds())); + assertFalse(heartbeat.pollTimeoutExpired(time.milliseconds())); + } + + @Test + public void testPollTimeout() { + assertFalse(heartbeat.pollTimeoutExpired(time.milliseconds())); + time.sleep(maxPollIntervalMs / 2); + + assertFalse(heartbeat.pollTimeoutExpired(time.milliseconds())); + time.sleep(maxPollIntervalMs / 2 + 1); + + assertTrue(heartbeat.pollTimeoutExpired(time.milliseconds())); + } + +} diff --git a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/KafkaConsumerMetricsTest.java b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/KafkaConsumerMetricsTest.java new file mode 100644 index 0000000..087f90b --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/KafkaConsumerMetricsTest.java @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.clients.consumer.internals; + +import org.apache.kafka.common.metrics.Metrics; + +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNull; + +class KafkaConsumerMetricsTest { + private static final long METRIC_VALUE = 123L; + private static final String CONSUMER_GROUP_PREFIX = "consumer"; + private static final String CONSUMER_METRIC_GROUP = "consumer-metrics"; + private static final String COMMIT_SYNC_TIME_TOTAL = "commit-sync-time-ns-total"; + private static final String COMMITTED_TIME_TOTAL = "committed-time-ns-total"; + + private final Metrics metrics = new Metrics(); + private final KafkaConsumerMetrics consumerMetrics + = new KafkaConsumerMetrics(metrics, CONSUMER_GROUP_PREFIX); + + @Test + public void shouldRecordCommitSyncTime() { + // When: + consumerMetrics.recordCommitSync(METRIC_VALUE); + + // Then: + assertMetricValue(COMMIT_SYNC_TIME_TOTAL); + } + + @Test + public void shouldRecordCommittedTime() { + // When: + consumerMetrics.recordCommitted(METRIC_VALUE); + + // Then: + assertMetricValue(COMMITTED_TIME_TOTAL); + } + + @Test + public void shouldRemoveMetricsOnClose() { + // When: + consumerMetrics.close(); + + // Then: + assertMetricRemoved(COMMIT_SYNC_TIME_TOTAL); + assertMetricRemoved(COMMITTED_TIME_TOTAL); + } + + private void assertMetricRemoved(final String name) { + assertNull(metrics.metric(metrics.metricName(name, CONSUMER_METRIC_GROUP))); + } + + private void assertMetricValue(final String name) { + assertEquals( + metrics.metric(metrics.metricName(name, CONSUMER_METRIC_GROUP)).metricValue(), + (double) METRIC_VALUE + ); + } +} \ No newline at end of file diff --git a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/MockPartitionAssignor.java b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/MockPartitionAssignor.java new file mode 100644 index 0000000..ef95f2f --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/MockPartitionAssignor.java @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.clients.consumer.internals; + +import org.apache.kafka.clients.consumer.ConsumerGroupMetadata; +import org.apache.kafka.common.TopicPartition; + +import java.util.List; +import java.util.Map; + +public class MockPartitionAssignor extends AbstractPartitionAssignor { + + private final List supportedProtocols; + + private int numAssignment; + + private Map> result = null; + + MockPartitionAssignor(final List supportedProtocols) { + this.supportedProtocols = supportedProtocols; + numAssignment = 0; + } + + @Override + public Map> assign(Map partitionsPerTopic, + Map subscriptions) { + if (result == null) + throw new IllegalStateException("Call to assign with no result prepared"); + return result; + } + + @Override + public String name() { + return "consumer-mock-assignor"; + } + + @Override + public List supportedProtocols() { + return supportedProtocols; + } + + public void clear() { + this.result = null; + } + + public void prepare(Map> result) { + this.result = result; + } + + @Override + public void onAssignment(Assignment assignment, ConsumerGroupMetadata metadata) { + numAssignment += 1; + } + + int numAssignment() { + return numAssignment; + } +} diff --git a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/MockRebalanceListener.java b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/MockRebalanceListener.java new file mode 100644 index 0000000..be80254 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/MockRebalanceListener.java @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.clients.consumer.internals; + +import java.util.Collection; +import org.apache.kafka.clients.consumer.ConsumerRebalanceListener; +import org.apache.kafka.common.TopicPartition; + +public class MockRebalanceListener implements ConsumerRebalanceListener { + public Collection lost; + public Collection revoked; + public Collection assigned; + public int lostCount = 0; + public int revokedCount = 0; + public int assignedCount = 0; + + @Override + public void onPartitionsAssigned(Collection partitions) { + this.assigned = partitions; + assignedCount++; + } + + @Override + public void onPartitionsRevoked(Collection partitions) { + this.revoked = partitions; + revokedCount++; + } + + @Override + public void onPartitionsLost(Collection partitions) { + this.lost = partitions; + lostCount++; + } +} diff --git a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/OffsetForLeaderEpochClientTest.java b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/OffsetForLeaderEpochClientTest.java new file mode 100644 index 0000000..432bd44 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/OffsetForLeaderEpochClientTest.java @@ -0,0 +1,181 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.clients.consumer.internals; + +import org.apache.kafka.clients.Metadata; +import org.apache.kafka.clients.MockClient; +import org.apache.kafka.clients.consumer.OffsetResetStrategy; +import org.apache.kafka.common.Node; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.errors.TopicAuthorizationException; +import org.apache.kafka.common.internals.ClusterResourceListeners; +import org.apache.kafka.common.message.OffsetForLeaderEpochResponseData; +import org.apache.kafka.common.message.OffsetForLeaderEpochResponseData.EpochEndOffset; +import org.apache.kafka.common.message.OffsetForLeaderEpochResponseData.OffsetForLeaderTopicResult; +import org.apache.kafka.common.protocol.Errors; +import org.apache.kafka.common.requests.OffsetsForLeaderEpochResponse; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.common.utils.Time; +import org.junit.jupiter.api.Test; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.Optional; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class OffsetForLeaderEpochClientTest { + + private ConsumerNetworkClient consumerClient; + private SubscriptionState subscriptions; + private Metadata metadata; + private MockClient client; + private Time time; + + private TopicPartition tp0 = new TopicPartition("topic", 0); + + @Test + public void testEmptyResponse() { + OffsetsForLeaderEpochClient offsetClient = newOffsetClient(); + RequestFuture future = + offsetClient.sendAsyncRequest(Node.noNode(), Collections.emptyMap()); + + OffsetsForLeaderEpochResponse resp = new OffsetsForLeaderEpochResponse( + new OffsetForLeaderEpochResponseData()); + client.prepareResponse(resp); + consumerClient.pollNoWakeup(); + + OffsetsForLeaderEpochClient.OffsetForEpochResult result = future.value(); + assertTrue(result.partitionsToRetry().isEmpty()); + assertTrue(result.endOffsets().isEmpty()); + } + + @Test + public void testUnexpectedEmptyResponse() { + Map positionMap = new HashMap<>(); + positionMap.put(tp0, new SubscriptionState.FetchPosition(0, Optional.of(1), + new Metadata.LeaderAndEpoch(Optional.empty(), Optional.of(1)))); + + OffsetsForLeaderEpochClient offsetClient = newOffsetClient(); + RequestFuture future = + offsetClient.sendAsyncRequest(Node.noNode(), positionMap); + + OffsetsForLeaderEpochResponse resp = new OffsetsForLeaderEpochResponse( + new OffsetForLeaderEpochResponseData()); + client.prepareResponse(resp); + consumerClient.pollNoWakeup(); + + OffsetsForLeaderEpochClient.OffsetForEpochResult result = future.value(); + assertFalse(result.partitionsToRetry().isEmpty()); + assertTrue(result.endOffsets().isEmpty()); + } + + @Test + public void testOkResponse() { + Map positionMap = new HashMap<>(); + positionMap.put(tp0, new SubscriptionState.FetchPosition(0, Optional.of(1), + new Metadata.LeaderAndEpoch(Optional.empty(), Optional.of(1)))); + + OffsetsForLeaderEpochClient offsetClient = newOffsetClient(); + RequestFuture future = + offsetClient.sendAsyncRequest(Node.noNode(), positionMap); + + client.prepareResponse(prepareOffsetForLeaderEpochResponse( + tp0, Errors.NONE, 1, 10L)); + consumerClient.pollNoWakeup(); + + OffsetsForLeaderEpochClient.OffsetForEpochResult result = future.value(); + assertTrue(result.partitionsToRetry().isEmpty()); + 
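+ // The prepared response carried leader epoch 1 and end offset 10 for tp0, so those values should come back with no error.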
assertTrue(result.endOffsets().containsKey(tp0)); + assertEquals(result.endOffsets().get(tp0).errorCode(), Errors.NONE.code()); + assertEquals(result.endOffsets().get(tp0).leaderEpoch(), 1); + assertEquals(result.endOffsets().get(tp0).endOffset(), 10L); + } + + @Test + public void testUnauthorizedTopic() { + Map positionMap = new HashMap<>(); + positionMap.put(tp0, new SubscriptionState.FetchPosition(0, Optional.of(1), + new Metadata.LeaderAndEpoch(Optional.empty(), Optional.of(1)))); + + OffsetsForLeaderEpochClient offsetClient = newOffsetClient(); + RequestFuture future = + offsetClient.sendAsyncRequest(Node.noNode(), positionMap); + + client.prepareResponse(prepareOffsetForLeaderEpochResponse( + tp0, Errors.TOPIC_AUTHORIZATION_FAILED, -1, -1)); + consumerClient.pollNoWakeup(); + + assertTrue(future.failed()); + assertEquals(future.exception().getClass(), TopicAuthorizationException.class); + assertTrue(((TopicAuthorizationException) future.exception()).unauthorizedTopics().contains(tp0.topic())); + } + + @Test + public void testRetriableError() { + Map positionMap = new HashMap<>(); + positionMap.put(tp0, new SubscriptionState.FetchPosition(0, Optional.of(1), + new Metadata.LeaderAndEpoch(Optional.empty(), Optional.of(1)))); + + OffsetsForLeaderEpochClient offsetClient = newOffsetClient(); + RequestFuture future = + offsetClient.sendAsyncRequest(Node.noNode(), positionMap); + + client.prepareResponse(prepareOffsetForLeaderEpochResponse( + tp0, Errors.LEADER_NOT_AVAILABLE, -1, -1)); + consumerClient.pollNoWakeup(); + + assertFalse(future.failed()); + OffsetsForLeaderEpochClient.OffsetForEpochResult result = future.value(); + assertTrue(result.partitionsToRetry().contains(tp0)); + assertFalse(result.endOffsets().containsKey(tp0)); + } + + private OffsetsForLeaderEpochClient newOffsetClient() { + buildDependencies(OffsetResetStrategy.EARLIEST); + return new OffsetsForLeaderEpochClient(consumerClient, new LogContext()); + } + + private void buildDependencies(OffsetResetStrategy offsetResetStrategy) { + LogContext logContext = new LogContext(); + time = new MockTime(1); + subscriptions = new SubscriptionState(logContext, offsetResetStrategy); + metadata = new ConsumerMetadata(0, Long.MAX_VALUE, false, false, + subscriptions, logContext, new ClusterResourceListeners()); + client = new MockClient(time, metadata); + consumerClient = new ConsumerNetworkClient(logContext, client, metadata, time, + 100, 1000, Integer.MAX_VALUE); + } + + private static OffsetsForLeaderEpochResponse prepareOffsetForLeaderEpochResponse( + TopicPartition tp, Errors error, int leaderEpoch, long endOffset) { + OffsetForLeaderEpochResponseData data = new OffsetForLeaderEpochResponseData(); + OffsetForLeaderTopicResult topic = new OffsetForLeaderTopicResult() + .setTopic(tp.topic()); + data.topics().add(topic); + topic.partitions().add(new EpochEndOffset() + .setPartition(tp.partition()) + .setErrorCode(error.code()) + .setLeaderEpoch(leaderEpoch) + .setEndOffset(endOffset)); + return new OffsetsForLeaderEpochResponse(data); + } +} diff --git a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/RequestFutureTest.java b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/RequestFutureTest.java new file mode 100644 index 0000000..e218f81 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/RequestFutureTest.java @@ -0,0 +1,242 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients.consumer.internals; + +import org.junit.jupiter.api.Test; + +import java.util.concurrent.atomic.AtomicInteger; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class RequestFutureTest { + + @Test + public void testBasicCompletion() { + RequestFuture future = new RequestFuture<>(); + String value = "foo"; + future.complete(value); + assertTrue(future.isDone()); + assertEquals(value, future.value()); + } + + @Test + public void testBasicFailure() { + RequestFuture future = new RequestFuture<>(); + RuntimeException exception = new RuntimeException(); + future.raise(exception); + assertTrue(future.isDone()); + assertEquals(exception, future.exception()); + } + + @Test + public void testVoidFuture() { + RequestFuture future = new RequestFuture<>(); + future.complete(null); + assertTrue(future.isDone()); + assertNull(future.value()); + } + + @Test + public void testRuntimeExceptionInComplete() { + RequestFuture future = new RequestFuture<>(); + assertThrows(IllegalArgumentException.class, () -> future.complete(new RuntimeException())); + } + + @Test + public void invokeCompleteAfterAlreadyComplete() { + RequestFuture future = new RequestFuture<>(); + future.complete(null); + assertThrows(IllegalStateException.class, () -> future.complete(null)); + } + + @Test + public void invokeCompleteAfterAlreadyFailed() { + RequestFuture future = new RequestFuture<>(); + future.raise(new RuntimeException()); + assertThrows(IllegalStateException.class, () -> future.complete(null)); + } + + @Test + public void invokeRaiseAfterAlreadyFailed() { + RequestFuture future = new RequestFuture<>(); + future.raise(new RuntimeException()); + assertThrows(IllegalStateException.class, () -> future.raise(new RuntimeException())); + } + + @Test + public void invokeRaiseAfterAlreadyCompleted() { + RequestFuture future = new RequestFuture<>(); + future.complete(null); + assertThrows(IllegalStateException.class, () -> future.raise(new RuntimeException())); + } + + @Test + public void invokeExceptionAfterSuccess() { + RequestFuture future = new RequestFuture<>(); + future.complete(null); + assertThrows(IllegalStateException.class, future::exception); + } + + @Test + public void invokeValueAfterFailure() { + RequestFuture future = new RequestFuture<>(); + future.raise(new RuntimeException()); + assertThrows(IllegalStateException.class, future::value); + } + + @Test + public void listenerInvokedIfAddedBeforeFutureCompletion() { + RequestFuture future = new RequestFuture<>(); + + MockRequestFutureListener listener = new MockRequestFutureListener<>(); + future.addListener(listener); + + 
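+ // Completing the future should invoke the listener that was registered beforehand exactly once.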
future.complete(null); + + assertOnSuccessInvoked(listener); + } + + @Test + public void listenerInvokedIfAddedBeforeFutureFailure() { + RequestFuture future = new RequestFuture<>(); + + MockRequestFutureListener listener = new MockRequestFutureListener<>(); + future.addListener(listener); + + future.raise(new RuntimeException()); + + assertOnFailureInvoked(listener); + } + + @Test + public void listenerInvokedIfAddedAfterFutureCompletion() { + RequestFuture future = new RequestFuture<>(); + future.complete(null); + + MockRequestFutureListener listener = new MockRequestFutureListener<>(); + future.addListener(listener); + + assertOnSuccessInvoked(listener); + } + + @Test + public void listenerInvokedIfAddedAfterFutureFailure() { + RequestFuture future = new RequestFuture<>(); + future.raise(new RuntimeException()); + + MockRequestFutureListener listener = new MockRequestFutureListener<>(); + future.addListener(listener); + + assertOnFailureInvoked(listener); + } + + @Test + public void listenersInvokedIfAddedBeforeAndAfterFailure() { + RequestFuture future = new RequestFuture<>(); + + MockRequestFutureListener beforeListener = new MockRequestFutureListener<>(); + future.addListener(beforeListener); + + future.raise(new RuntimeException()); + + MockRequestFutureListener afterListener = new MockRequestFutureListener<>(); + future.addListener(afterListener); + + assertOnFailureInvoked(beforeListener); + assertOnFailureInvoked(afterListener); + } + + @Test + public void listenersInvokedIfAddedBeforeAndAfterCompletion() { + RequestFuture future = new RequestFuture<>(); + + MockRequestFutureListener beforeListener = new MockRequestFutureListener<>(); + future.addListener(beforeListener); + + future.complete(null); + + MockRequestFutureListener afterListener = new MockRequestFutureListener<>(); + future.addListener(afterListener); + + assertOnSuccessInvoked(beforeListener); + assertOnSuccessInvoked(afterListener); + } + + @Test + public void testComposeSuccessCase() { + RequestFuture future = new RequestFuture<>(); + RequestFuture composed = future.compose(new RequestFutureAdapter() { + @Override + public void onSuccess(String value, RequestFuture future) { + future.complete(value.length()); + } + }); + + future.complete("hello"); + + assertTrue(composed.isDone()); + assertTrue(composed.succeeded()); + assertEquals(5, (int) composed.value()); + } + + @Test + public void testComposeFailureCase() { + RequestFuture future = new RequestFuture<>(); + RequestFuture composed = future.compose(new RequestFutureAdapter() { + @Override + public void onSuccess(String value, RequestFuture future) { + future.complete(value.length()); + } + }); + + RuntimeException e = new RuntimeException(); + future.raise(e); + + assertTrue(composed.isDone()); + assertTrue(composed.failed()); + assertEquals(e, composed.exception()); + } + + private static void assertOnSuccessInvoked(MockRequestFutureListener listener) { + assertEquals(1, listener.numOnSuccessCalls.get()); + assertEquals(0, listener.numOnFailureCalls.get()); + } + + private static void assertOnFailureInvoked(MockRequestFutureListener listener) { + assertEquals(0, listener.numOnSuccessCalls.get()); + assertEquals(1, listener.numOnFailureCalls.get()); + } + + private static class MockRequestFutureListener implements RequestFutureListener { + private final AtomicInteger numOnSuccessCalls = new AtomicInteger(0); + private final AtomicInteger numOnFailureCalls = new AtomicInteger(0); + + @Override + public void onSuccess(T value) { + 
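+ // Count each invocation so the assert helpers above can verify exactly-once delivery.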
numOnSuccessCalls.incrementAndGet(); + } + + @Override + public void onFailure(RuntimeException e) { + numOnFailureCalls.incrementAndGet(); + } + } + +} diff --git a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/SubscriptionStateTest.java b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/SubscriptionStateTest.java new file mode 100644 index 0000000..d19234f --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/SubscriptionStateTest.java @@ -0,0 +1,812 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients.consumer.internals; + +import org.apache.kafka.clients.ApiVersions; +import org.apache.kafka.clients.Metadata; +import org.apache.kafka.clients.NodeApiVersions; +import org.apache.kafka.clients.consumer.ConsumerRebalanceListener; +import org.apache.kafka.clients.consumer.OffsetAndMetadata; +import org.apache.kafka.clients.consumer.OffsetResetStrategy; +import org.apache.kafka.clients.consumer.internals.SubscriptionState.LogTruncation; +import org.apache.kafka.common.IsolationLevel; +import org.apache.kafka.common.Node; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.message.OffsetForLeaderEpochResponseData.EpochEndOffset; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.test.TestUtils; +import org.junit.jupiter.api.Test; + +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.HashSet; +import java.util.Optional; +import java.util.Set; +import java.util.regex.Pattern; + +import static java.util.Collections.singleton; +import static org.apache.kafka.common.requests.OffsetsForLeaderEpochResponse.UNDEFINED_EPOCH; +import static org.apache.kafka.common.requests.OffsetsForLeaderEpochResponse.UNDEFINED_EPOCH_OFFSET; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class SubscriptionStateTest { + + private SubscriptionState state = new SubscriptionState(new LogContext(), OffsetResetStrategy.EARLIEST); + private final String topic = "test"; + private final String topic1 = "test1"; + private final TopicPartition tp0 = new TopicPartition(topic, 0); + private final TopicPartition tp1 = new TopicPartition(topic, 1); + private final TopicPartition t1p0 = new TopicPartition(topic1, 0); + private final MockRebalanceListener rebalanceListener = new MockRebalanceListener(); + private final 
Metadata.LeaderAndEpoch leaderAndEpoch = Metadata.LeaderAndEpoch.noLeaderOrEpoch(); + + @Test + public void partitionAssignment() { + state.assignFromUser(singleton(tp0)); + assertEquals(singleton(tp0), state.assignedPartitions()); + assertEquals(1, state.numAssignedPartitions()); + assertFalse(state.hasAllFetchPositions()); + state.seek(tp0, 1); + assertTrue(state.isFetchable(tp0)); + assertEquals(1L, state.position(tp0).offset); + state.assignFromUser(Collections.emptySet()); + assertTrue(state.assignedPartitions().isEmpty()); + assertEquals(0, state.numAssignedPartitions()); + assertFalse(state.isAssigned(tp0)); + assertFalse(state.isFetchable(tp0)); + } + + @Test + public void partitionAssignmentChangeOnTopicSubscription() { + state.assignFromUser(new HashSet<>(Arrays.asList(tp0, tp1))); + // assigned partitions should immediately change + assertEquals(2, state.assignedPartitions().size()); + assertEquals(2, state.numAssignedPartitions()); + assertTrue(state.assignedPartitions().contains(tp0)); + assertTrue(state.assignedPartitions().contains(tp1)); + + state.unsubscribe(); + // assigned partitions should immediately change + assertTrue(state.assignedPartitions().isEmpty()); + assertEquals(0, state.numAssignedPartitions()); + + state.subscribe(singleton(topic1), rebalanceListener); + // assigned partitions should remain unchanged + assertTrue(state.assignedPartitions().isEmpty()); + assertEquals(0, state.numAssignedPartitions()); + + assertTrue(state.checkAssignmentMatchedSubscription(singleton(t1p0))); + state.assignFromSubscribed(singleton(t1p0)); + // assigned partitions should immediately change + assertEquals(singleton(t1p0), state.assignedPartitions()); + assertEquals(1, state.numAssignedPartitions()); + + state.subscribe(singleton(topic), rebalanceListener); + // assigned partitions should remain unchanged + assertEquals(singleton(t1p0), state.assignedPartitions()); + assertEquals(1, state.numAssignedPartitions()); + + state.unsubscribe(); + // assigned partitions should immediately change + assertTrue(state.assignedPartitions().isEmpty()); + assertEquals(0, state.numAssignedPartitions()); + } + + @Test + public void testGroupSubscribe() { + state.subscribe(singleton(topic1), rebalanceListener); + assertEquals(singleton(topic1), state.metadataTopics()); + + assertFalse(state.groupSubscribe(singleton(topic1))); + assertEquals(singleton(topic1), state.metadataTopics()); + + assertTrue(state.groupSubscribe(Utils.mkSet(topic, topic1))); + assertEquals(Utils.mkSet(topic, topic1), state.metadataTopics()); + + // `groupSubscribe` does not accumulate + assertFalse(state.groupSubscribe(singleton(topic1))); + assertEquals(singleton(topic1), state.metadataTopics()); + + state.subscribe(singleton("anotherTopic"), rebalanceListener); + assertEquals(Utils.mkSet(topic1, "anotherTopic"), state.metadataTopics()); + + assertFalse(state.groupSubscribe(singleton("anotherTopic"))); + assertEquals(singleton("anotherTopic"), state.metadataTopics()); + } + + @Test + public void partitionAssignmentChangeOnPatternSubscription() { + state.subscribe(Pattern.compile(".*"), rebalanceListener); + // assigned partitions should remain unchanged + assertTrue(state.assignedPartitions().isEmpty()); + assertEquals(0, state.numAssignedPartitions()); + + state.subscribeFromPattern(Collections.singleton(topic)); + // assigned partitions should remain unchanged + assertTrue(state.assignedPartitions().isEmpty()); + assertEquals(0, state.numAssignedPartitions()); + + 
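+ // Pattern subscription alone does not assign partitions; only the assignFromSubscribed call below changes the assignment.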
assertTrue(state.checkAssignmentMatchedSubscription(singleton(tp1))); + state.assignFromSubscribed(singleton(tp1)); + + // assigned partitions should immediately change + assertEquals(singleton(tp1), state.assignedPartitions()); + assertEquals(1, state.numAssignedPartitions()); + assertEquals(singleton(topic), state.subscription()); + + assertTrue(state.checkAssignmentMatchedSubscription(singleton(t1p0))); + state.assignFromSubscribed(singleton(t1p0)); + + // assigned partitions should immediately change + assertEquals(singleton(t1p0), state.assignedPartitions()); + assertEquals(1, state.numAssignedPartitions()); + assertEquals(singleton(topic), state.subscription()); + + state.subscribe(Pattern.compile(".*t"), rebalanceListener); + // assigned partitions should remain unchanged + assertEquals(singleton(t1p0), state.assignedPartitions()); + assertEquals(1, state.numAssignedPartitions()); + + state.subscribeFromPattern(singleton(topic)); + // assigned partitions should remain unchanged + assertEquals(singleton(t1p0), state.assignedPartitions()); + assertEquals(1, state.numAssignedPartitions()); + + assertTrue(state.checkAssignmentMatchedSubscription(singleton(tp0))); + state.assignFromSubscribed(singleton(tp0)); + + // assigned partitions should immediately change + assertEquals(singleton(tp0), state.assignedPartitions()); + assertEquals(1, state.numAssignedPartitions()); + assertEquals(singleton(topic), state.subscription()); + + state.unsubscribe(); + // assigned partitions should immediately change + assertTrue(state.assignedPartitions().isEmpty()); + assertEquals(0, state.numAssignedPartitions()); + } + + @Test + public void verifyAssignmentId() { + assertEquals(0, state.assignmentId()); + Set userAssignment = Utils.mkSet(tp0, tp1); + state.assignFromUser(userAssignment); + assertEquals(1, state.assignmentId()); + assertEquals(userAssignment, state.assignedPartitions()); + + state.unsubscribe(); + assertEquals(2, state.assignmentId()); + assertEquals(Collections.emptySet(), state.assignedPartitions()); + + Set autoAssignment = Utils.mkSet(t1p0); + state.subscribe(singleton(topic1), rebalanceListener); + assertTrue(state.checkAssignmentMatchedSubscription(autoAssignment)); + state.assignFromSubscribed(autoAssignment); + assertEquals(3, state.assignmentId()); + assertEquals(autoAssignment, state.assignedPartitions()); + } + + @Test + public void partitionReset() { + state.assignFromUser(singleton(tp0)); + state.seek(tp0, 5); + assertEquals(5L, state.position(tp0).offset); + state.requestOffsetReset(tp0); + assertFalse(state.isFetchable(tp0)); + assertTrue(state.isOffsetResetNeeded(tp0)); + assertNull(state.position(tp0)); + + // seek should clear the reset and make the partition fetchable + state.seek(tp0, 0); + assertTrue(state.isFetchable(tp0)); + assertFalse(state.isOffsetResetNeeded(tp0)); + } + + @Test + public void topicSubscription() { + state.subscribe(singleton(topic), rebalanceListener); + assertEquals(1, state.subscription().size()); + assertTrue(state.assignedPartitions().isEmpty()); + assertEquals(0, state.numAssignedPartitions()); + assertTrue(state.hasAutoAssignedPartitions()); + assertTrue(state.checkAssignmentMatchedSubscription(singleton(tp0))); + state.assignFromSubscribed(singleton(tp0)); + + state.seek(tp0, 1); + assertEquals(1L, state.position(tp0).offset); + assertTrue(state.checkAssignmentMatchedSubscription(singleton(tp1))); + state.assignFromSubscribed(singleton(tp1)); + + assertTrue(state.isAssigned(tp1)); + assertFalse(state.isAssigned(tp0)); + 
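+ // tp1 replaces tp0 in the assignment, but it is not fetchable until a position has been set for it.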
assertFalse(state.isFetchable(tp1)); + assertEquals(singleton(tp1), state.assignedPartitions()); + assertEquals(1, state.numAssignedPartitions()); + } + + @Test + public void partitionPause() { + state.assignFromUser(singleton(tp0)); + state.seek(tp0, 100); + assertTrue(state.isFetchable(tp0)); + state.pause(tp0); + assertFalse(state.isFetchable(tp0)); + state.resume(tp0); + assertTrue(state.isFetchable(tp0)); + } + + @Test + public void invalidPositionUpdate() { + state.subscribe(singleton(topic), rebalanceListener); + assertTrue(state.checkAssignmentMatchedSubscription(singleton(tp0))); + state.assignFromSubscribed(singleton(tp0)); + + assertThrows(IllegalStateException.class, () -> state.position(tp0, + new SubscriptionState.FetchPosition(0, Optional.empty(), leaderAndEpoch))); + } + + @Test + public void cantAssignPartitionForUnsubscribedTopics() { + state.subscribe(singleton(topic), rebalanceListener); + assertFalse(state.checkAssignmentMatchedSubscription(Collections.singletonList(t1p0))); + } + + @Test + public void cantAssignPartitionForUnmatchedPattern() { + state.subscribe(Pattern.compile(".*t"), rebalanceListener); + state.subscribeFromPattern(Collections.singleton(topic)); + assertFalse(state.checkAssignmentMatchedSubscription(Collections.singletonList(t1p0))); + } + + @Test + public void cantChangePositionForNonAssignedPartition() { + assertThrows(IllegalStateException.class, () -> state.position(tp0, + new SubscriptionState.FetchPosition(1, Optional.empty(), leaderAndEpoch))); + } + + @Test + public void cantSubscribeTopicAndPattern() { + state.subscribe(singleton(topic), rebalanceListener); + assertThrows(IllegalStateException.class, () -> state.subscribe(Pattern.compile(".*"), rebalanceListener)); + } + + @Test + public void cantSubscribePartitionAndPattern() { + state.assignFromUser(singleton(tp0)); + assertThrows(IllegalStateException.class, () -> state.subscribe(Pattern.compile(".*"), rebalanceListener)); + } + + @Test + public void cantSubscribePatternAndTopic() { + state.subscribe(Pattern.compile(".*"), rebalanceListener); + assertThrows(IllegalStateException.class, () -> state.subscribe(singleton(topic), rebalanceListener)); + } + + @Test + public void cantSubscribePatternAndPartition() { + state.subscribe(Pattern.compile(".*"), rebalanceListener); + assertThrows(IllegalStateException.class, () -> state.assignFromUser(singleton(tp0))); + } + + @Test + public void patternSubscription() { + state.subscribe(Pattern.compile(".*"), rebalanceListener); + state.subscribeFromPattern(new HashSet<>(Arrays.asList(topic, topic1))); + assertEquals(2, state.subscription().size(), "Expected subscribed topics count is incorrect"); + } + + @Test + public void unsubscribeUserAssignment() { + state.assignFromUser(new HashSet<>(Arrays.asList(tp0, tp1))); + state.unsubscribe(); + state.subscribe(singleton(topic), rebalanceListener); + assertEquals(singleton(topic), state.subscription()); + } + + @Test + public void unsubscribeUserSubscribe() { + state.subscribe(singleton(topic), rebalanceListener); + state.unsubscribe(); + state.assignFromUser(singleton(tp0)); + assertEquals(singleton(tp0), state.assignedPartitions()); + assertEquals(1, state.numAssignedPartitions()); + } + + @Test + public void unsubscription() { + state.subscribe(Pattern.compile(".*"), rebalanceListener); + state.subscribeFromPattern(new HashSet<>(Arrays.asList(topic, topic1))); + assertTrue(state.checkAssignmentMatchedSubscription(singleton(tp1))); + state.assignFromSubscribed(singleton(tp1)); + + 
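+ // After assignment from the pattern subscription, tp1 should be the only assigned partition.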
assertEquals(singleton(tp1), state.assignedPartitions()); + assertEquals(1, state.numAssignedPartitions()); + + state.unsubscribe(); + assertEquals(0, state.subscription().size()); + assertTrue(state.assignedPartitions().isEmpty()); + assertEquals(0, state.numAssignedPartitions()); + + state.assignFromUser(singleton(tp0)); + assertEquals(singleton(tp0), state.assignedPartitions()); + assertEquals(1, state.numAssignedPartitions()); + + state.unsubscribe(); + assertEquals(0, state.subscription().size()); + assertTrue(state.assignedPartitions().isEmpty()); + assertEquals(0, state.numAssignedPartitions()); + } + + @Test + public void testPreferredReadReplicaLease() { + state.assignFromUser(Collections.singleton(tp0)); + + // Default state + assertFalse(state.preferredReadReplica(tp0, 0L).isPresent()); + + // Set the preferred replica with lease + state.updatePreferredReadReplica(tp0, 42, () -> 10L); + TestUtils.assertOptional(state.preferredReadReplica(tp0, 9L), value -> assertEquals(value.intValue(), 42)); + TestUtils.assertOptional(state.preferredReadReplica(tp0, 10L), value -> assertEquals(value.intValue(), 42)); + assertFalse(state.preferredReadReplica(tp0, 11L).isPresent()); + + // Unset the preferred replica + state.clearPreferredReadReplica(tp0); + assertFalse(state.preferredReadReplica(tp0, 9L).isPresent()); + assertFalse(state.preferredReadReplica(tp0, 11L).isPresent()); + + // Set to new preferred replica with lease + state.updatePreferredReadReplica(tp0, 43, () -> 20L); + TestUtils.assertOptional(state.preferredReadReplica(tp0, 11L), value -> assertEquals(value.intValue(), 43)); + TestUtils.assertOptional(state.preferredReadReplica(tp0, 20L), value -> assertEquals(value.intValue(), 43)); + assertFalse(state.preferredReadReplica(tp0, 21L).isPresent()); + + // Set to new preferred replica without clearing first + state.updatePreferredReadReplica(tp0, 44, () -> 30L); + TestUtils.assertOptional(state.preferredReadReplica(tp0, 30L), value -> assertEquals(value.intValue(), 44)); + assertFalse(state.preferredReadReplica(tp0, 31L).isPresent()); + } + + @Test + public void testSeekUnvalidatedWithNoOffsetEpoch() { + Node broker1 = new Node(1, "localhost", 9092); + state.assignFromUser(Collections.singleton(tp0)); + + // Seek with no offset epoch requires no validation no matter what the current leader is + state.seekUnvalidated(tp0, new SubscriptionState.FetchPosition(0L, Optional.empty(), + new Metadata.LeaderAndEpoch(Optional.of(broker1), Optional.of(5)))); + assertTrue(state.hasValidPosition(tp0)); + assertFalse(state.awaitingValidation(tp0)); + ApiVersions apiVersions = new ApiVersions(); + apiVersions.update(broker1.idString(), NodeApiVersions.create()); + + assertFalse(state.maybeValidatePositionForCurrentLeader(apiVersions, tp0, new Metadata.LeaderAndEpoch( + Optional.of(broker1), Optional.empty()))); + assertTrue(state.hasValidPosition(tp0)); + assertFalse(state.awaitingValidation(tp0)); + + assertFalse(state.maybeValidatePositionForCurrentLeader(apiVersions, tp0, new Metadata.LeaderAndEpoch( + Optional.of(broker1), Optional.of(10)))); + assertTrue(state.hasValidPosition(tp0)); + assertFalse(state.awaitingValidation(tp0)); + } + + @Test + public void testSeekUnvalidatedWithNoEpochClearsAwaitingValidation() { + Node broker1 = new Node(1, "localhost", 9092); + state.assignFromUser(Collections.singleton(tp0)); + + // Seek with no offset epoch requires no validation no matter what the current leader is + state.seekUnvalidated(tp0, new SubscriptionState.FetchPosition(0L, Optional.of(2), + 
new Metadata.LeaderAndEpoch(Optional.of(broker1), Optional.of(5)))); + assertFalse(state.hasValidPosition(tp0)); + assertTrue(state.awaitingValidation(tp0)); + + state.seekUnvalidated(tp0, new SubscriptionState.FetchPosition(0L, Optional.empty(), + new Metadata.LeaderAndEpoch(Optional.of(broker1), Optional.of(5)))); + assertTrue(state.hasValidPosition(tp0)); + assertFalse(state.awaitingValidation(tp0)); + } + + @Test + public void testSeekUnvalidatedWithOffsetEpoch() { + Node broker1 = new Node(1, "localhost", 9092); + ApiVersions apiVersions = new ApiVersions(); + apiVersions.update(broker1.idString(), NodeApiVersions.create()); + + state.assignFromUser(Collections.singleton(tp0)); + + state.seekUnvalidated(tp0, new SubscriptionState.FetchPosition(0L, Optional.of(2), + new Metadata.LeaderAndEpoch(Optional.of(broker1), Optional.of(5)))); + assertFalse(state.hasValidPosition(tp0)); + assertTrue(state.awaitingValidation(tp0)); + + // Update using the current leader and epoch + assertTrue(state.maybeValidatePositionForCurrentLeader(apiVersions, tp0, new Metadata.LeaderAndEpoch( + Optional.of(broker1), Optional.of(5)))); + assertFalse(state.hasValidPosition(tp0)); + assertTrue(state.awaitingValidation(tp0)); + + // Update with a newer leader and epoch + assertTrue(state.maybeValidatePositionForCurrentLeader(apiVersions, tp0, new Metadata.LeaderAndEpoch( + Optional.of(broker1), Optional.of(15)))); + assertFalse(state.hasValidPosition(tp0)); + assertTrue(state.awaitingValidation(tp0)); + + // If the updated leader has no epoch information, then skip validation and begin fetching + assertFalse(state.maybeValidatePositionForCurrentLeader(apiVersions, tp0, new Metadata.LeaderAndEpoch( + Optional.of(broker1), Optional.empty()))); + assertTrue(state.hasValidPosition(tp0)); + assertFalse(state.awaitingValidation(tp0)); + } + + @Test + public void testSeekValidatedShouldClearAwaitingValidation() { + Node broker1 = new Node(1, "localhost", 9092); + state.assignFromUser(Collections.singleton(tp0)); + + state.seekUnvalidated(tp0, new SubscriptionState.FetchPosition(10L, Optional.of(5), + new Metadata.LeaderAndEpoch(Optional.of(broker1), Optional.of(10)))); + assertFalse(state.hasValidPosition(tp0)); + assertTrue(state.awaitingValidation(tp0)); + assertEquals(10L, state.position(tp0).offset); + + state.seekValidated(tp0, new SubscriptionState.FetchPosition(8L, Optional.of(4), + new Metadata.LeaderAndEpoch(Optional.of(broker1), Optional.of(10)))); + assertTrue(state.hasValidPosition(tp0)); + assertFalse(state.awaitingValidation(tp0)); + assertEquals(8L, state.position(tp0).offset); + } + + @Test + public void testCompleteValidationShouldClearAwaitingValidation() { + Node broker1 = new Node(1, "localhost", 9092); + state.assignFromUser(Collections.singleton(tp0)); + + state.seekUnvalidated(tp0, new SubscriptionState.FetchPosition(10L, Optional.of(5), + new Metadata.LeaderAndEpoch(Optional.of(broker1), Optional.of(10)))); + assertFalse(state.hasValidPosition(tp0)); + assertTrue(state.awaitingValidation(tp0)); + assertEquals(10L, state.position(tp0).offset); + + state.completeValidation(tp0); + assertTrue(state.hasValidPosition(tp0)); + assertFalse(state.awaitingValidation(tp0)); + assertEquals(10L, state.position(tp0).offset); + } + + @Test + public void testOffsetResetWhileAwaitingValidation() { + Node broker1 = new Node(1, "localhost", 9092); + state.assignFromUser(Collections.singleton(tp0)); + + state.seekUnvalidated(tp0, new SubscriptionState.FetchPosition(10L, Optional.of(5), + new 
Metadata.LeaderAndEpoch(Optional.of(broker1), Optional.of(10)))); + assertTrue(state.awaitingValidation(tp0)); + + state.requestOffsetReset(tp0, OffsetResetStrategy.EARLIEST); + assertFalse(state.awaitingValidation(tp0)); + assertTrue(state.isOffsetResetNeeded(tp0)); + } + + @Test + public void testMaybeCompleteValidation() { + Node broker1 = new Node(1, "localhost", 9092); + state.assignFromUser(Collections.singleton(tp0)); + + int currentEpoch = 10; + long initialOffset = 10L; + int initialOffsetEpoch = 5; + + SubscriptionState.FetchPosition initialPosition = new SubscriptionState.FetchPosition(initialOffset, + Optional.of(initialOffsetEpoch), new Metadata.LeaderAndEpoch(Optional.of(broker1), Optional.of(currentEpoch))); + state.seekUnvalidated(tp0, initialPosition); + assertTrue(state.awaitingValidation(tp0)); + + Optional truncationOpt = state.maybeCompleteValidation(tp0, initialPosition, + new EpochEndOffset() + .setLeaderEpoch(initialOffsetEpoch) + .setEndOffset(initialOffset + 5)); + assertEquals(Optional.empty(), truncationOpt); + assertFalse(state.awaitingValidation(tp0)); + assertEquals(initialPosition, state.position(tp0)); + } + + @Test + public void testMaybeValidatePositionForCurrentLeader() { + NodeApiVersions oldApis = NodeApiVersions.create(ApiKeys.OFFSET_FOR_LEADER_EPOCH.id, (short) 0, (short) 2); + ApiVersions apiVersions = new ApiVersions(); + apiVersions.update("1", oldApis); + + Node broker1 = new Node(1, "localhost", 9092); + state.assignFromUser(Collections.singleton(tp0)); + + state.seekUnvalidated(tp0, new SubscriptionState.FetchPosition(10L, Optional.of(5), + new Metadata.LeaderAndEpoch(Optional.of(broker1), Optional.of(10)))); + + // if API is too old to be usable, we just skip validation + assertFalse(state.maybeValidatePositionForCurrentLeader(apiVersions, tp0, new Metadata.LeaderAndEpoch( + Optional.of(broker1), Optional.of(10)))); + assertTrue(state.hasValidPosition(tp0)); + + // New API + apiVersions.update("1", NodeApiVersions.create()); + state.seekUnvalidated(tp0, new SubscriptionState.FetchPosition(10L, Optional.of(5), + new Metadata.LeaderAndEpoch(Optional.of(broker1), Optional.of(10)))); + + // API is too old to be usable, we just skip validation + assertTrue(state.maybeValidatePositionForCurrentLeader(apiVersions, tp0, new Metadata.LeaderAndEpoch( + Optional.of(broker1), Optional.of(10)))); + assertFalse(state.hasValidPosition(tp0)); + } + + @Test + public void testMaybeCompleteValidationAfterPositionChange() { + Node broker1 = new Node(1, "localhost", 9092); + state.assignFromUser(Collections.singleton(tp0)); + + int currentEpoch = 10; + long initialOffset = 10L; + int initialOffsetEpoch = 5; + long updateOffset = 20L; + int updateOffsetEpoch = 8; + + SubscriptionState.FetchPosition initialPosition = new SubscriptionState.FetchPosition(initialOffset, + Optional.of(initialOffsetEpoch), new Metadata.LeaderAndEpoch(Optional.of(broker1), Optional.of(currentEpoch))); + state.seekUnvalidated(tp0, initialPosition); + assertTrue(state.awaitingValidation(tp0)); + + SubscriptionState.FetchPosition updatePosition = new SubscriptionState.FetchPosition(updateOffset, + Optional.of(updateOffsetEpoch), new Metadata.LeaderAndEpoch(Optional.of(broker1), Optional.of(currentEpoch))); + state.seekUnvalidated(tp0, updatePosition); + + Optional truncationOpt = state.maybeCompleteValidation(tp0, initialPosition, + new EpochEndOffset() + .setLeaderEpoch(initialOffsetEpoch) + .setEndOffset(initialOffset + 5)); + assertEquals(Optional.empty(), truncationOpt); + 
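+ // The validation result applied to the stale initial position, so the updated position remains awaiting validation.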
assertTrue(state.awaitingValidation(tp0)); + assertEquals(updatePosition, state.position(tp0)); + } + + @Test + public void testMaybeCompleteValidationAfterOffsetReset() { + Node broker1 = new Node(1, "localhost", 9092); + state.assignFromUser(Collections.singleton(tp0)); + + int currentEpoch = 10; + long initialOffset = 10L; + int initialOffsetEpoch = 5; + + SubscriptionState.FetchPosition initialPosition = new SubscriptionState.FetchPosition(initialOffset, + Optional.of(initialOffsetEpoch), new Metadata.LeaderAndEpoch(Optional.of(broker1), Optional.of(currentEpoch))); + state.seekUnvalidated(tp0, initialPosition); + assertTrue(state.awaitingValidation(tp0)); + + state.requestOffsetReset(tp0); + + Optional truncationOpt = state.maybeCompleteValidation(tp0, initialPosition, + new EpochEndOffset() + .setLeaderEpoch(initialOffsetEpoch) + .setEndOffset(initialOffset + 5)); + assertEquals(Optional.empty(), truncationOpt); + assertFalse(state.awaitingValidation(tp0)); + assertTrue(state.isOffsetResetNeeded(tp0)); + assertNull(state.position(tp0)); + } + + @Test + public void testTruncationDetectionWithResetPolicy() { + Node broker1 = new Node(1, "localhost", 9092); + state.assignFromUser(Collections.singleton(tp0)); + + int currentEpoch = 10; + long initialOffset = 10L; + int initialOffsetEpoch = 5; + long divergentOffset = 5L; + int divergentOffsetEpoch = 7; + + SubscriptionState.FetchPosition initialPosition = new SubscriptionState.FetchPosition(initialOffset, + Optional.of(initialOffsetEpoch), new Metadata.LeaderAndEpoch(Optional.of(broker1), Optional.of(currentEpoch))); + state.seekUnvalidated(tp0, initialPosition); + assertTrue(state.awaitingValidation(tp0)); + + Optional truncationOpt = state.maybeCompleteValidation(tp0, initialPosition, + new EpochEndOffset() + .setLeaderEpoch(divergentOffsetEpoch) + .setEndOffset(divergentOffset)); + assertEquals(Optional.empty(), truncationOpt); + assertFalse(state.awaitingValidation(tp0)); + + SubscriptionState.FetchPosition updatedPosition = new SubscriptionState.FetchPosition(divergentOffset, + Optional.of(divergentOffsetEpoch), new Metadata.LeaderAndEpoch(Optional.of(broker1), Optional.of(currentEpoch))); + assertEquals(updatedPosition, state.position(tp0)); + } + + @Test + public void testTruncationDetectionWithoutResetPolicy() { + Node broker1 = new Node(1, "localhost", 9092); + state = new SubscriptionState(new LogContext(), OffsetResetStrategy.NONE); + state.assignFromUser(Collections.singleton(tp0)); + + int currentEpoch = 10; + long initialOffset = 10L; + int initialOffsetEpoch = 5; + long divergentOffset = 5L; + int divergentOffsetEpoch = 7; + + SubscriptionState.FetchPosition initialPosition = new SubscriptionState.FetchPosition(initialOffset, + Optional.of(initialOffsetEpoch), new Metadata.LeaderAndEpoch(Optional.of(broker1), Optional.of(currentEpoch))); + state.seekUnvalidated(tp0, initialPosition); + assertTrue(state.awaitingValidation(tp0)); + + Optional truncationOpt = state.maybeCompleteValidation(tp0, initialPosition, + new EpochEndOffset() + .setLeaderEpoch(divergentOffsetEpoch) + .setEndOffset(divergentOffset)); + assertTrue(truncationOpt.isPresent()); + LogTruncation truncation = truncationOpt.get(); + + assertEquals(Optional.of(new OffsetAndMetadata(divergentOffset, Optional.of(divergentOffsetEpoch), "")), + truncation.divergentOffsetOpt); + assertEquals(initialPosition, truncation.fetchPosition); + assertTrue(state.awaitingValidation(tp0)); + } + + @Test + public void 
testTruncationDetectionUnknownDivergentOffsetWithResetPolicy() { + Node broker1 = new Node(1, "localhost", 9092); + state = new SubscriptionState(new LogContext(), OffsetResetStrategy.EARLIEST); + state.assignFromUser(Collections.singleton(tp0)); + + int currentEpoch = 10; + long initialOffset = 10L; + int initialOffsetEpoch = 5; + + SubscriptionState.FetchPosition initialPosition = new SubscriptionState.FetchPosition(initialOffset, + Optional.of(initialOffsetEpoch), new Metadata.LeaderAndEpoch(Optional.of(broker1), Optional.of(currentEpoch))); + state.seekUnvalidated(tp0, initialPosition); + assertTrue(state.awaitingValidation(tp0)); + + Optional truncationOpt = state.maybeCompleteValidation(tp0, initialPosition, + new EpochEndOffset() + .setLeaderEpoch(UNDEFINED_EPOCH) + .setEndOffset(UNDEFINED_EPOCH_OFFSET)); + assertEquals(Optional.empty(), truncationOpt); + assertFalse(state.awaitingValidation(tp0)); + assertTrue(state.isOffsetResetNeeded(tp0)); + assertEquals(OffsetResetStrategy.EARLIEST, state.resetStrategy(tp0)); + } + + @Test + public void testTruncationDetectionUnknownDivergentOffsetWithoutResetPolicy() { + Node broker1 = new Node(1, "localhost", 9092); + state = new SubscriptionState(new LogContext(), OffsetResetStrategy.NONE); + state.assignFromUser(Collections.singleton(tp0)); + + int currentEpoch = 10; + long initialOffset = 10L; + int initialOffsetEpoch = 5; + + SubscriptionState.FetchPosition initialPosition = new SubscriptionState.FetchPosition(initialOffset, + Optional.of(initialOffsetEpoch), new Metadata.LeaderAndEpoch(Optional.of(broker1), Optional.of(currentEpoch))); + state.seekUnvalidated(tp0, initialPosition); + assertTrue(state.awaitingValidation(tp0)); + + Optional truncationOpt = state.maybeCompleteValidation(tp0, initialPosition, + new EpochEndOffset() + .setLeaderEpoch(UNDEFINED_EPOCH) + .setEndOffset(UNDEFINED_EPOCH_OFFSET)); + assertTrue(truncationOpt.isPresent()); + LogTruncation truncation = truncationOpt.get(); + + assertEquals(Optional.empty(), truncation.divergentOffsetOpt); + assertEquals(initialPosition, truncation.fetchPosition); + assertTrue(state.awaitingValidation(tp0)); + } + + private static class MockRebalanceListener implements ConsumerRebalanceListener { + Collection revoked; + public Collection assigned; + int revokedCount = 0; + int assignedCount = 0; + + @Override + public void onPartitionsAssigned(Collection partitions) { + this.assigned = partitions; + assignedCount++; + } + + @Override + public void onPartitionsRevoked(Collection partitions) { + this.revoked = partitions; + revokedCount++; + } + + } + + @Test + public void resetOffsetNoValidation() { + // Check that offset reset works when we can't validate offsets (older brokers) + + Node broker1 = new Node(1, "localhost", 9092); + state.assignFromUser(Collections.singleton(tp0)); + + // Reset offsets + state.requestOffsetReset(tp0, OffsetResetStrategy.EARLIEST); + + // Attempt to validate with older API version, should do nothing + ApiVersions oldApis = new ApiVersions(); + oldApis.update("1", NodeApiVersions.create(ApiKeys.OFFSET_FOR_LEADER_EPOCH.id, (short) 0, (short) 2)); + assertFalse(state.maybeValidatePositionForCurrentLeader(oldApis, tp0, new Metadata.LeaderAndEpoch( + Optional.of(broker1), Optional.empty()))); + assertFalse(state.hasValidPosition(tp0)); + assertFalse(state.awaitingValidation(tp0)); + assertTrue(state.isOffsetResetNeeded(tp0)); + + // Complete the reset via unvalidated seek + state.seekUnvalidated(tp0, new SubscriptionState.FetchPosition(10L)); + 
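// A position seeked without any leader epoch information cannot be validated, so it becomes fetchable immediately and clears the pending reset. +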
assertTrue(state.hasValidPosition(tp0)); + assertFalse(state.awaitingValidation(tp0)); + assertFalse(state.isOffsetResetNeeded(tp0)); + + // Next call to validate offsets does nothing + assertFalse(state.maybeValidatePositionForCurrentLeader(oldApis, tp0, new Metadata.LeaderAndEpoch( + Optional.of(broker1), Optional.empty()))); + assertTrue(state.hasValidPosition(tp0)); + assertFalse(state.awaitingValidation(tp0)); + assertFalse(state.isOffsetResetNeeded(tp0)); + + // Reset again, and complete it with a seek that would normally require validation + state.requestOffsetReset(tp0, OffsetResetStrategy.EARLIEST); + state.seekUnvalidated(tp0, new SubscriptionState.FetchPosition(10L, Optional.of(10), new Metadata.LeaderAndEpoch( + Optional.of(broker1), Optional.of(2)))); + // We are now in AWAIT_VALIDATION + assertFalse(state.hasValidPosition(tp0)); + assertTrue(state.awaitingValidation(tp0)); + assertFalse(state.isOffsetResetNeeded(tp0)); + + // Now ensure next call to validate clears the validation state + assertFalse(state.maybeValidatePositionForCurrentLeader(oldApis, tp0, new Metadata.LeaderAndEpoch( + Optional.of(broker1), Optional.of(2)))); + assertTrue(state.hasValidPosition(tp0)); + assertFalse(state.awaitingValidation(tp0)); + assertFalse(state.isOffsetResetNeeded(tp0)); + } + + @Test + public void nullPositionLagOnNoPosition() { + state.assignFromUser(Collections.singleton(tp0)); + + assertNull(state.partitionLag(tp0, IsolationLevel.READ_UNCOMMITTED)); + assertNull(state.partitionLag(tp0, IsolationLevel.READ_COMMITTED)); + + state.updateHighWatermark(tp0, 1L); + state.updateLastStableOffset(tp0, 1L); + + assertNull(state.partitionLag(tp0, IsolationLevel.READ_UNCOMMITTED)); + assertNull(state.partitionLag(tp0, IsolationLevel.READ_COMMITTED)); + } + +} diff --git a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/ThrowOnAssignmentAssignor.java b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/ThrowOnAssignmentAssignor.java new file mode 100644 index 0000000..782ca26 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/ThrowOnAssignmentAssignor.java @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients.consumer.internals; + +import org.apache.kafka.clients.consumer.ConsumerGroupMetadata; + +import java.util.List; + +/** + * A mock assignor which throws for {@link org.apache.kafka.clients.consumer.ConsumerPartitionAssignor#onAssignment}. 
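+ * The supplied exception is rethrown from {@code onAssignment}, letting tests exercise failure handling in the assignment callback.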
+ */ +public class ThrowOnAssignmentAssignor extends MockPartitionAssignor { + + private final RuntimeException bookeepedException; + private final String name; + + ThrowOnAssignmentAssignor(final List supportedProtocols, + final RuntimeException bookeepedException, + final String name) { + super(supportedProtocols); + this.bookeepedException = bookeepedException; + this.name = name; + } + + @Override + public void onAssignment(Assignment assignment, ConsumerGroupMetadata metadata) { + super.onAssignment(assignment, metadata); + throw bookeepedException; + } + + @Override + public String name() { + return name; + } +} diff --git a/clients/src/test/java/org/apache/kafka/clients/producer/KafkaProducerTest.java b/clients/src/test/java/org/apache/kafka/clients/producer/KafkaProducerTest.java new file mode 100644 index 0000000..1e45a58 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/clients/producer/KafkaProducerTest.java @@ -0,0 +1,1561 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients.producer; + +import org.apache.kafka.clients.CommonClientConfigs; +import org.apache.kafka.clients.KafkaClient; +import org.apache.kafka.clients.MockClient; +import org.apache.kafka.clients.NodeApiVersions; +import org.apache.kafka.clients.consumer.ConsumerConfig; +import org.apache.kafka.clients.consumer.ConsumerGroupMetadata; +import org.apache.kafka.clients.consumer.OffsetAndMetadata; +import org.apache.kafka.clients.producer.internals.ProducerInterceptors; +import org.apache.kafka.clients.producer.internals.ProducerMetadata; +import org.apache.kafka.clients.producer.internals.Sender; +import org.apache.kafka.common.Cluster; +import org.apache.kafka.common.KafkaException; +import org.apache.kafka.common.Metric; +import org.apache.kafka.common.MetricName; +import org.apache.kafka.common.Node; +import org.apache.kafka.common.PartitionInfo; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.config.ConfigException; +import org.apache.kafka.common.config.SslConfigs; +import org.apache.kafka.common.errors.InterruptException; +import org.apache.kafka.common.errors.InvalidTopicException; +import org.apache.kafka.common.errors.TimeoutException; +import org.apache.kafka.common.header.internals.RecordHeader; +import org.apache.kafka.common.internals.ClusterResourceListeners; +import org.apache.kafka.common.message.AddOffsetsToTxnResponseData; +import org.apache.kafka.common.message.EndTxnResponseData; +import org.apache.kafka.common.message.InitProducerIdResponseData; +import org.apache.kafka.common.message.TxnOffsetCommitRequestData; +import org.apache.kafka.common.metrics.Metrics; +import org.apache.kafka.common.metrics.Sensor; +import org.apache.kafka.common.metrics.stats.Avg; +import 
org.apache.kafka.common.network.Selectable; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.Errors; +import org.apache.kafka.common.requests.AddOffsetsToTxnResponse; +import org.apache.kafka.common.requests.EndTxnResponse; +import org.apache.kafka.common.requests.FindCoordinatorRequest; +import org.apache.kafka.common.requests.FindCoordinatorResponse; +import org.apache.kafka.common.requests.InitProducerIdResponse; +import org.apache.kafka.common.requests.JoinGroupRequest; +import org.apache.kafka.common.requests.MetadataResponse; +import org.apache.kafka.common.requests.RequestTestUtils; +import org.apache.kafka.common.requests.TxnOffsetCommitRequest; +import org.apache.kafka.common.requests.TxnOffsetCommitResponse; +import org.apache.kafka.common.serialization.ByteArraySerializer; +import org.apache.kafka.common.serialization.Serializer; +import org.apache.kafka.common.serialization.StringSerializer; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.test.MockMetricsReporter; +import org.apache.kafka.test.MockPartitioner; +import org.apache.kafka.test.MockProducerInterceptor; +import org.apache.kafka.test.MockSerializer; +import org.apache.kafka.test.TestUtils; +import org.junit.jupiter.api.Test; + +import javax.management.MBeanServer; +import javax.management.ObjectName; +import java.lang.management.ManagementFactory; +import java.nio.charset.StandardCharsets; +import java.time.Duration; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Properties; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.Exchanger; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; +import java.util.stream.Stream; + +import static java.util.Collections.emptyMap; +import static java.util.Collections.singletonMap; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.ArgumentMatchers.anyLong; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.ArgumentMatchers.notNull; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +public class KafkaProducerTest { + private final String topic = "topic"; + private final Node host1 = new Node(0, "host1", 1000); + private final Collection nodes = 
Collections.singletonList(host1); + private final Cluster emptyCluster = new Cluster( + null, + nodes, + Collections.emptySet(), + Collections.emptySet(), + Collections.emptySet()); + private final Cluster onePartitionCluster = new Cluster( + "dummy", + nodes, + Collections.singletonList(new PartitionInfo(topic, 0, null, null, null)), + Collections.emptySet(), + Collections.emptySet()); + private final Cluster threePartitionCluster = new Cluster( + "dummy", + nodes, + Arrays.asList( + new PartitionInfo(topic, 0, null, null, null), + new PartitionInfo(topic, 1, null, null, null), + new PartitionInfo(topic, 2, null, null, null)), + Collections.emptySet(), + Collections.emptySet()); + private static final int DEFAULT_METADATA_IDLE_MS = 5 * 60 * 1000; + + + private static KafkaProducer kafkaProducer(Map configs, + Serializer keySerializer, + Serializer valueSerializer, + ProducerMetadata metadata, + KafkaClient kafkaClient, + ProducerInterceptors interceptors, + Time time) { + return new KafkaProducer<>(new ProducerConfig(ProducerConfig.appendSerializerToConfig(configs, keySerializer, valueSerializer)), + keySerializer, valueSerializer, metadata, kafkaClient, interceptors, time); + } + + @Test + public void testOverwriteAcksAndRetriesForIdempotentProducers() { + Properties props = new Properties(); + props.setProperty(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9999"); + props.setProperty(ProducerConfig.TRANSACTIONAL_ID_CONFIG, "transactionalId"); + props.setProperty(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG, StringSerializer.class.getName()); + props.setProperty(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG, StringSerializer.class.getName()); + + ProducerConfig config = new ProducerConfig(props); + assertTrue(config.getBoolean(ProducerConfig.ENABLE_IDEMPOTENCE_CONFIG)); + assertTrue(Stream.of("-1", "all").anyMatch(each -> each.equalsIgnoreCase(config.getString(ProducerConfig.ACKS_CONFIG)))); + assertEquals((int) config.getInt(ProducerConfig.RETRIES_CONFIG), Integer.MAX_VALUE); + assertTrue(config.getString(ProducerConfig.CLIENT_ID_CONFIG).equalsIgnoreCase("producer-" + + config.getString(ProducerConfig.TRANSACTIONAL_ID_CONFIG))); + } + + @Test + public void testAcksAndIdempotenceForIdempotentProducers() { + Properties baseProps = new Properties() {{ + setProperty( + ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9999"); + setProperty( + ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG, StringSerializer.class.getName()); + setProperty( + ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG, StringSerializer.class.getName()); + }}; + + Properties validProps = new Properties() {{ + putAll(baseProps); + setProperty(ProducerConfig.ACKS_CONFIG, "0"); + setProperty(ProducerConfig.ENABLE_IDEMPOTENCE_CONFIG, "false"); + }}; + ProducerConfig config = new ProducerConfig(validProps); + assertFalse( + config.getBoolean(ProducerConfig.ENABLE_IDEMPOTENCE_CONFIG), + "idempotence should be overwritten"); + assertEquals( + "0", + config.getString(ProducerConfig.ACKS_CONFIG), + "acks should be overwritten"); + + Properties validProps2 = new Properties() {{ + putAll(baseProps); + setProperty(ProducerConfig.TRANSACTIONAL_ID_CONFIG, "transactionalId"); + }}; + config = new ProducerConfig(validProps2); + assertTrue( + config.getBoolean(ProducerConfig.ENABLE_IDEMPOTENCE_CONFIG), + "idempotence should be set with the default value"); + assertEquals( + "-1", + config.getString(ProducerConfig.ACKS_CONFIG), + "acks should be set with the default value"); + + Properties validProps3 = new Properties() {{ + 
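// acks=all together with explicitly disabled idempotence is still a valid combination +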
putAll(baseProps); + setProperty(ProducerConfig.ACKS_CONFIG, "all"); + setProperty(ProducerConfig.ENABLE_IDEMPOTENCE_CONFIG, "false"); + }}; + config = new ProducerConfig(validProps3); + assertFalse(config.getBoolean(ProducerConfig.ENABLE_IDEMPOTENCE_CONFIG), + "idempotence should be overwritten"); + assertEquals( + "-1", + config.getString(ProducerConfig.ACKS_CONFIG), + "acks should be overwritten"); + + Properties invalidProps = new Properties() {{ + putAll(baseProps); + setProperty(ProducerConfig.ACKS_CONFIG, "0"); + setProperty(ProducerConfig.ENABLE_IDEMPOTENCE_CONFIG, "false"); + setProperty(ProducerConfig.TRANSACTIONAL_ID_CONFIG, "transactionalId"); + }}; + assertThrows( + ConfigException.class, + () -> new ProducerConfig(invalidProps), + "Cannot set a transactional.id without also enabling idempotence"); + + Properties invalidProps2 = new Properties() {{ + putAll(baseProps); + setProperty(ProducerConfig.ACKS_CONFIG, "1"); + setProperty(ProducerConfig.ENABLE_IDEMPOTENCE_CONFIG, "true"); + }}; + assertThrows( + ConfigException.class, + () -> new ProducerConfig(invalidProps2), + "Must set acks to all in order to use the idempotent producer"); + + Properties invalidProps3 = new Properties() {{ + putAll(baseProps); + setProperty(ProducerConfig.ACKS_CONFIG, "0"); + setProperty(ProducerConfig.ENABLE_IDEMPOTENCE_CONFIG, "true"); + }}; + assertThrows( + ConfigException.class, + () -> new ProducerConfig(invalidProps3), + "Must set acks to all in order to use the idempotent producer"); + } + + @Test + public void testMetricsReporterAutoGeneratedClientId() { + Properties props = new Properties(); + props.setProperty(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9999"); + props.setProperty(ProducerConfig.METRIC_REPORTER_CLASSES_CONFIG, MockMetricsReporter.class.getName()); + KafkaProducer producer = new KafkaProducer<>( + props, new StringSerializer(), new StringSerializer()); + + MockMetricsReporter mockMetricsReporter = (MockMetricsReporter) producer.metrics.reporters().get(0); + + assertEquals(producer.getClientId(), mockMetricsReporter.clientId); + producer.close(); + } + + @Test + public void testConstructorWithSerializers() { + Properties producerProps = new Properties(); + producerProps.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9000"); + new KafkaProducer<>(producerProps, new ByteArraySerializer(), new ByteArraySerializer()).close(); + } + + @Test + public void testNoSerializerProvided() { + Properties producerProps = new Properties(); + producerProps.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9000"); + assertThrows(ConfigException.class, () -> new KafkaProducer(producerProps)); + } + + @Test + public void testConstructorFailureCloseResource() { + Properties props = new Properties(); + props.setProperty(ProducerConfig.CLIENT_ID_CONFIG, "testConstructorClose"); + props.setProperty(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, "some.invalid.hostname.foo.bar.local:9999"); + props.setProperty(ProducerConfig.METRIC_REPORTER_CLASSES_CONFIG, MockMetricsReporter.class.getName()); + + final int oldInitCount = MockMetricsReporter.INIT_COUNT.get(); + final int oldCloseCount = MockMetricsReporter.CLOSE_COUNT.get(); + try (KafkaProducer ignored = new KafkaProducer<>(props, new ByteArraySerializer(), new ByteArraySerializer())) { + fail("should have caught an exception and returned"); + } catch (KafkaException e) { + assertEquals(oldInitCount + 1, MockMetricsReporter.INIT_COUNT.get()); + assertEquals(oldCloseCount + 1, MockMetricsReporter.CLOSE_COUNT.get()); + 
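// The reporter opened during the failed construction has been closed again +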
assertEquals("Failed to construct kafka producer", e.getMessage()); + } + } + + @Test + public void testConstructorWithNotStringKey() { + Properties props = new Properties(); + props.setProperty(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9999"); + props.put(1, "not string key"); + try (KafkaProducer ff = new KafkaProducer<>(props, new StringSerializer(), new StringSerializer())) { + fail("Constructor should throw exception"); + } catch (ConfigException e) { + assertTrue(e.getMessage().contains("not string key"), "Unexpected exception message: " + e.getMessage()); + } + } + + @Test + public void testSerializerClose() { + Map configs = new HashMap<>(); + configs.put(ProducerConfig.CLIENT_ID_CONFIG, "testConstructorClose"); + configs.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9999"); + configs.put(ProducerConfig.METRIC_REPORTER_CLASSES_CONFIG, MockMetricsReporter.class.getName()); + configs.put(CommonClientConfigs.SECURITY_PROTOCOL_CONFIG, CommonClientConfigs.DEFAULT_SECURITY_PROTOCOL); + final int oldInitCount = MockSerializer.INIT_COUNT.get(); + final int oldCloseCount = MockSerializer.CLOSE_COUNT.get(); + + KafkaProducer producer = new KafkaProducer<>( + configs, new MockSerializer(), new MockSerializer()); + assertEquals(oldInitCount + 2, MockSerializer.INIT_COUNT.get()); + assertEquals(oldCloseCount, MockSerializer.CLOSE_COUNT.get()); + + producer.close(); + assertEquals(oldInitCount + 2, MockSerializer.INIT_COUNT.get()); + assertEquals(oldCloseCount + 2, MockSerializer.CLOSE_COUNT.get()); + } + + @Test + public void testInterceptorConstructClose() { + try { + Properties props = new Properties(); + // test with client ID assigned by KafkaProducer + props.setProperty(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9999"); + props.setProperty(ProducerConfig.INTERCEPTOR_CLASSES_CONFIG, MockProducerInterceptor.class.getName()); + props.setProperty(MockProducerInterceptor.APPEND_STRING_PROP, "something"); + + KafkaProducer producer = new KafkaProducer<>( + props, new StringSerializer(), new StringSerializer()); + assertEquals(1, MockProducerInterceptor.INIT_COUNT.get()); + assertEquals(0, MockProducerInterceptor.CLOSE_COUNT.get()); + + // Cluster metadata will only be updated on calling onSend. 
+ assertNull(MockProducerInterceptor.CLUSTER_META.get()); + + producer.close(); + assertEquals(1, MockProducerInterceptor.INIT_COUNT.get()); + assertEquals(1, MockProducerInterceptor.CLOSE_COUNT.get()); + } finally { + // cleanup since we are using mutable static variables in MockProducerInterceptor + MockProducerInterceptor.resetCounters(); + } + } + + @Test + public void testPartitionerClose() { + try { + Properties props = new Properties(); + props.setProperty(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9999"); + MockPartitioner.resetCounters(); + props.setProperty(ProducerConfig.PARTITIONER_CLASS_CONFIG, MockPartitioner.class.getName()); + + KafkaProducer producer = new KafkaProducer<>( + props, new StringSerializer(), new StringSerializer()); + assertEquals(1, MockPartitioner.INIT_COUNT.get()); + assertEquals(0, MockPartitioner.CLOSE_COUNT.get()); + + producer.close(); + assertEquals(1, MockPartitioner.INIT_COUNT.get()); + assertEquals(1, MockPartitioner.CLOSE_COUNT.get()); + } finally { + // cleanup since we are using mutable static variables in MockPartitioner + MockPartitioner.resetCounters(); + } + } + + @Test + public void shouldCloseProperlyAndThrowIfInterrupted() throws Exception { + Map configs = new HashMap<>(); + configs.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9999"); + configs.put(ProducerConfig.PARTITIONER_CLASS_CONFIG, MockPartitioner.class.getName()); + configs.put(ProducerConfig.BATCH_SIZE_CONFIG, "1"); + + Time time = new MockTime(); + MetadataResponse initialUpdateResponse = RequestTestUtils.metadataUpdateWith(1, singletonMap("topic", 1)); + ProducerMetadata metadata = newMetadata(0, Long.MAX_VALUE); + MockClient client = new MockClient(time, metadata); + client.updateMetadata(initialUpdateResponse); + + final Producer producer = kafkaProducer(configs, new StringSerializer(), + new StringSerializer(), metadata, client, null, time); + + ExecutorService executor = Executors.newSingleThreadExecutor(); + final AtomicReference closeException = new AtomicReference<>(); + try { + Future future = executor.submit(() -> { + producer.send(new ProducerRecord<>("topic", "key", "value")); + try { + producer.close(); + fail("Close should block and throw."); + } catch (Exception e) { + closeException.set(e); + } + }); + + // Close producer should not complete until send succeeds + try { + future.get(100, TimeUnit.MILLISECONDS); + fail("Close completed without waiting for send"); + } catch (java.util.concurrent.TimeoutException expected) { /* ignore */ } + + // Ensure send has started + client.waitForRequests(1, 1000); + + assertTrue(future.cancel(true), "Close terminated prematurely"); + + TestUtils.waitForCondition(() -> closeException.get() != null, + "InterruptException did not occur within timeout."); + + assertTrue(closeException.get() instanceof InterruptException, "Expected exception not thrown " + closeException); + } finally { + executor.shutdownNow(); + } + + } + + @Test + public void testOsDefaultSocketBufferSizes() { + Map config = new HashMap<>(); + config.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9999"); + config.put(ProducerConfig.SEND_BUFFER_CONFIG, Selectable.USE_DEFAULT_BUFFER_SIZE); + config.put(ProducerConfig.RECEIVE_BUFFER_CONFIG, Selectable.USE_DEFAULT_BUFFER_SIZE); + new KafkaProducer<>(config, new ByteArraySerializer(), new ByteArraySerializer()).close(); + } + + @Test + public void testInvalidSocketSendBufferSize() { + Map config = new HashMap<>(); + config.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9999"); + 
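// -1 selects the OS default buffer size; any other negative value should be rejected during construction +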
config.put(ProducerConfig.SEND_BUFFER_CONFIG, -2); + assertThrows(KafkaException.class, () -> new KafkaProducer<>(config, new ByteArraySerializer(), new ByteArraySerializer())); + } + + @Test + public void testInvalidSocketReceiveBufferSize() { + Map config = new HashMap<>(); + config.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9999"); + config.put(ProducerConfig.RECEIVE_BUFFER_CONFIG, -2); + assertThrows(KafkaException.class, () -> new KafkaProducer<>(config, new ByteArraySerializer(), new ByteArraySerializer())); + } + + private static KafkaProducer producerWithOverrideNewSender(Map configs, + ProducerMetadata metadata) { + return producerWithOverrideNewSender(configs, metadata, Time.SYSTEM); + } + + private static KafkaProducer producerWithOverrideNewSender(Map configs, + ProducerMetadata metadata, + Time timer) { + return new KafkaProducer( + new ProducerConfig(ProducerConfig.appendSerializerToConfig(configs, new StringSerializer(), new StringSerializer())), + new StringSerializer(), new StringSerializer(), metadata, new MockClient(Time.SYSTEM, metadata), null, timer) { + @Override + Sender newSender(LogContext logContext, KafkaClient kafkaClient, ProducerMetadata metadata) { + // give Sender its own Metadata instance so that we can isolate Metadata calls from KafkaProducer + return super.newSender(logContext, kafkaClient, newMetadata(0, 100_000)); + } + }; + } + + @Test + public void testMetadataFetch() throws InterruptedException { + Map configs = new HashMap<>(); + configs.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9999"); + ProducerMetadata metadata = mock(ProducerMetadata.class); + + // Return empty cluster 4 times and cluster from then on + when(metadata.fetch()).thenReturn(emptyCluster, emptyCluster, emptyCluster, emptyCluster, onePartitionCluster); + + KafkaProducer producer = producerWithOverrideNewSender(configs, metadata); + ProducerRecord record = new ProducerRecord<>(topic, "value"); + producer.send(record); + + // One request update for each empty cluster returned + verify(metadata, times(4)).requestUpdateForTopic(topic); + verify(metadata, times(4)).awaitUpdate(anyInt(), anyLong()); + verify(metadata, times(5)).fetch(); + + // Should not request update for subsequent `send` + producer.send(record, null); + verify(metadata, times(4)).requestUpdateForTopic(topic); + verify(metadata, times(4)).awaitUpdate(anyInt(), anyLong()); + verify(metadata, times(6)).fetch(); + + // Should not request update for subsequent `partitionsFor` + producer.partitionsFor(topic); + verify(metadata, times(4)).requestUpdateForTopic(topic); + verify(metadata, times(4)).awaitUpdate(anyInt(), anyLong()); + verify(metadata, times(7)).fetch(); + + producer.close(Duration.ofMillis(0)); + } + + @Test + public void testMetadataExpiry() throws InterruptedException { + Map configs = new HashMap<>(); + configs.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9999"); + ProducerMetadata metadata = mock(ProducerMetadata.class); + + Cluster emptyCluster = new Cluster( + "dummy", + Collections.singletonList(host1), + Collections.emptySet(), + Collections.emptySet(), + Collections.emptySet()); + when(metadata.fetch()).thenReturn(onePartitionCluster, emptyCluster, onePartitionCluster); + + KafkaProducer producer = producerWithOverrideNewSender(configs, metadata); + ProducerRecord record = new ProducerRecord<>(topic, "value"); + producer.send(record); + + // Verify the topic's metadata isn't requested since it's already present. 
+ verify(metadata, times(0)).requestUpdateForTopic(topic); + verify(metadata, times(0)).awaitUpdate(anyInt(), anyLong()); + verify(metadata, times(1)).fetch(); + + // The metadata has been expired. Verify the producer requests the topic's metadata. + producer.send(record, null); + verify(metadata, times(1)).requestUpdateForTopic(topic); + verify(metadata, times(1)).awaitUpdate(anyInt(), anyLong()); + verify(metadata, times(3)).fetch(); + + producer.close(Duration.ofMillis(0)); + } + + @Test + public void testMetadataTimeoutWithMissingTopic() throws Exception { + Map configs = new HashMap<>(); + configs.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9999"); + configs.put(ProducerConfig.MAX_BLOCK_MS_CONFIG, 60000); + + // Create a record with a partition higher than the initial (outdated) partition range + ProducerRecord record = new ProducerRecord<>(topic, 2, null, "value"); + ProducerMetadata metadata = mock(ProducerMetadata.class); + + MockTime mockTime = new MockTime(); + AtomicInteger invocationCount = new AtomicInteger(0); + when(metadata.fetch()).then(invocation -> { + invocationCount.incrementAndGet(); + if (invocationCount.get() == 5) { + mockTime.setCurrentTimeMs(mockTime.milliseconds() + 70000); + } + + return emptyCluster; + }); + + KafkaProducer producer = producerWithOverrideNewSender(configs, metadata, mockTime); + + // Four request updates where the topic isn't present, at which point the timeout expires and a + // TimeoutException is thrown + Future future = producer.send(record); + verify(metadata, times(4)).requestUpdateForTopic(topic); + verify(metadata, times(4)).awaitUpdate(anyInt(), anyLong()); + verify(metadata, times(5)).fetch(); + try { + future.get(); + } catch (ExecutionException e) { + assertTrue(e.getCause() instanceof TimeoutException); + } finally { + producer.close(Duration.ofMillis(0)); + } + } + + @Test + public void testMetadataWithPartitionOutOfRange() throws Exception { + Map configs = new HashMap<>(); + configs.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9999"); + configs.put(ProducerConfig.MAX_BLOCK_MS_CONFIG, 60000); + + // Create a record with a partition higher than the initial (outdated) partition range + ProducerRecord record = new ProducerRecord<>(topic, 2, null, "value"); + ProducerMetadata metadata = mock(ProducerMetadata.class); + + MockTime mockTime = new MockTime(); + + when(metadata.fetch()).thenReturn(onePartitionCluster, onePartitionCluster, threePartitionCluster); + + KafkaProducer producer = producerWithOverrideNewSender(configs, metadata, mockTime); + // One request update if metadata is available but outdated for the given record + producer.send(record); + verify(metadata, times(2)).requestUpdateForTopic(topic); + verify(metadata, times(2)).awaitUpdate(anyInt(), anyLong()); + verify(metadata, times(3)).fetch(); + + producer.close(Duration.ofMillis(0)); + } + + @Test + public void testMetadataTimeoutWithPartitionOutOfRange() throws Exception { + Map configs = new HashMap<>(); + configs.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9999"); + configs.put(ProducerConfig.MAX_BLOCK_MS_CONFIG, 60000); + + // Create a record with a partition higher than the initial (outdated) partition range + ProducerRecord record = new ProducerRecord<>(topic, 2, null, "value"); + ProducerMetadata metadata = mock(ProducerMetadata.class); + + MockTime mockTime = new MockTime(); + AtomicInteger invocationCount = new AtomicInteger(0); + when(metadata.fetch()).then(invocation -> { + invocationCount.incrementAndGet(); + if 
(invocationCount.get() == 5) { + mockTime.setCurrentTimeMs(mockTime.milliseconds() + 70000); + } + + return onePartitionCluster; + }); + + KafkaProducer producer = producerWithOverrideNewSender(configs, metadata, mockTime); + + // Four request updates where the requested partition is out of range, at which point the timeout expires + // and a TimeoutException is thrown + Future future = producer.send(record); + verify(metadata, times(4)).requestUpdateForTopic(topic); + verify(metadata, times(4)).awaitUpdate(anyInt(), anyLong()); + verify(metadata, times(5)).fetch(); + try { + future.get(); + } catch (ExecutionException e) { + assertTrue(e.getCause() instanceof TimeoutException); + } finally { + producer.close(Duration.ofMillis(0)); + } + } + + @Test + public void testTopicRefreshInMetadata() throws InterruptedException { + Map configs = new HashMap<>(); + configs.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9999"); + configs.put(ProducerConfig.MAX_BLOCK_MS_CONFIG, "600000"); + long refreshBackoffMs = 500L; + long metadataExpireMs = 60000L; + long metadataIdleMs = 60000L; + final Time time = new MockTime(); + final ProducerMetadata metadata = new ProducerMetadata(refreshBackoffMs, metadataExpireMs, metadataIdleMs, + new LogContext(), new ClusterResourceListeners(), time); + final String topic = "topic"; + try (KafkaProducer producer = kafkaProducer(configs, + new StringSerializer(), new StringSerializer(), metadata, new MockClient(time, metadata), null, time)) { + + AtomicBoolean running = new AtomicBoolean(true); + Thread t = new Thread(() -> { + long startTimeMs = System.currentTimeMillis(); + while (running.get()) { + while (!metadata.updateRequested() && System.currentTimeMillis() - startTimeMs < 100) + Thread.yield(); + MetadataResponse updateResponse = RequestTestUtils.metadataUpdateWith("kafka-cluster", 1, + singletonMap(topic, Errors.UNKNOWN_TOPIC_OR_PARTITION), emptyMap()); + metadata.updateWithCurrentRequestVersion(updateResponse, false, time.milliseconds()); + time.sleep(60 * 1000L); + } + }); + t.start(); + assertThrows(TimeoutException.class, () -> producer.partitionsFor(topic)); + running.set(false); + t.join(); + } + } + + @Test + public void testTopicExpiryInMetadata() throws InterruptedException { + Map configs = new HashMap<>(); + configs.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9999"); + configs.put(ProducerConfig.MAX_BLOCK_MS_CONFIG, "30000"); + long refreshBackoffMs = 500L; + long metadataExpireMs = 60000L; + long metadataIdleMs = 60000L; + final Time time = new MockTime(); + final ProducerMetadata metadata = new ProducerMetadata(refreshBackoffMs, metadataExpireMs, metadataIdleMs, + new LogContext(), new ClusterResourceListeners(), time); + final String topic = "topic"; + try (KafkaProducer producer = kafkaProducer(configs, new StringSerializer(), + new StringSerializer(), metadata, new MockClient(time, metadata), null, time)) { + + Exchanger exchanger = new Exchanger<>(); + + Thread t = new Thread(() -> { + try { + exchanger.exchange(null); // 1 + while (!metadata.updateRequested()) + Thread.sleep(100); + MetadataResponse updateResponse = RequestTestUtils.metadataUpdateWith(1, singletonMap(topic, 1)); + metadata.updateWithCurrentRequestVersion(updateResponse, false, time.milliseconds()); + exchanger.exchange(null); // 2 + time.sleep(120 * 1000L); + + // Update the metadata again, but it should be expired at this point. 
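+ // The 120s sleep exceeds the 60s metadata idle timeout, so this update no longer retains the expired topic and the later partitionsFor call times out.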
+ updateResponse = RequestTestUtils.metadataUpdateWith(1, singletonMap(topic, 1)); + metadata.updateWithCurrentRequestVersion(updateResponse, false, time.milliseconds()); + exchanger.exchange(null); // 3 + while (!metadata.updateRequested()) + Thread.sleep(100); + time.sleep(30 * 1000L); + } catch (Exception e) { + throw new RuntimeException(e); + } + }); + t.start(); + exchanger.exchange(null); // 1 + assertNotNull(producer.partitionsFor(topic)); + exchanger.exchange(null); // 2 + exchanger.exchange(null); // 3 + assertThrows(TimeoutException.class, () -> producer.partitionsFor(topic)); + t.join(); + } + } + + @SuppressWarnings("unchecked") + @Test + public void testHeaders() { + doTestHeaders(Serializer.class); + } + + private > void doTestHeaders(Class serializerClassToMock) { + Map configs = new HashMap<>(); + configs.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9999"); + Serializer keySerializer = mock(serializerClassToMock); + Serializer valueSerializer = mock(serializerClassToMock); + + long nowMs = Time.SYSTEM.milliseconds(); + String topic = "topic"; + ProducerMetadata metadata = newMetadata(0, 90000); + metadata.add(topic, nowMs); + + MetadataResponse initialUpdateResponse = RequestTestUtils.metadataUpdateWith(1, singletonMap(topic, 1)); + metadata.updateWithCurrentRequestVersion(initialUpdateResponse, false, nowMs); + + KafkaProducer producer = kafkaProducer(configs, keySerializer, valueSerializer, metadata, + null, null, Time.SYSTEM); + + when(keySerializer.serialize(any(), any(), any())).then(invocation -> + invocation.getArgument(2).getBytes()); + when(valueSerializer.serialize(any(), any(), any())).then(invocation -> + invocation.getArgument(2).getBytes()); + + String value = "value"; + String key = "key"; + ProducerRecord record = new ProducerRecord<>(topic, key, value); + + //ensure headers can be mutated pre send. 
+ record.headers().add(new RecordHeader("test", "header2".getBytes())); + producer.send(record, null); + + //ensure headers are closed and cannot be mutated post send + assertThrows(IllegalStateException.class, () -> record.headers().add(new RecordHeader("test", "test".getBytes()))); + + //ensure existing headers are not changed, and last header for key is still original value + assertArrayEquals(record.headers().lastHeader("test").value(), "header2".getBytes()); + + verify(valueSerializer).serialize(topic, record.headers(), value); + verify(keySerializer).serialize(topic, record.headers(), key); + + producer.close(Duration.ofMillis(0)); + } + + @Test + public void closeShouldBeIdempotent() { + Properties producerProps = new Properties(); + producerProps.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9000"); + Producer producer = new KafkaProducer<>(producerProps, new ByteArraySerializer(), new ByteArraySerializer()); + producer.close(); + producer.close(); + } + + @Test + public void closeWithNegativeTimestampShouldThrow() { + Properties producerProps = new Properties(); + producerProps.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9000"); + try (Producer producer = new KafkaProducer<>(producerProps, new ByteArraySerializer(), new ByteArraySerializer())) { + assertThrows(IllegalArgumentException.class, () -> producer.close(Duration.ofMillis(-100))); + } + } + + @Test + public void testFlushCompleteSendOfInflightBatches() { + Map configs = new HashMap<>(); + configs.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9000"); + + Time time = new MockTime(1); + MetadataResponse initialUpdateResponse = RequestTestUtils.metadataUpdateWith(1, singletonMap("topic", 1)); + ProducerMetadata metadata = newMetadata(0, Long.MAX_VALUE); + + MockClient client = new MockClient(time, metadata); + client.updateMetadata(initialUpdateResponse); + + try (Producer producer = kafkaProducer(configs, new StringSerializer(), + new StringSerializer(), metadata, client, null, time)) { + ArrayList> futureResponses = new ArrayList<>(); + for (int i = 0; i < 50; i++) { + Future response = producer.send(new ProducerRecord<>("topic", "value" + i)); + futureResponses.add(response); + } + futureResponses.forEach(res -> assertFalse(res.isDone())); + producer.flush(); + futureResponses.forEach(res -> assertTrue(res.isDone())); + } + } + + private static Double getMetricValue(final KafkaProducer producer, final String name) { + Metrics metrics = producer.metrics; + Metric metric = metrics.metric(metrics.metricName(name, "producer-metrics")); + return (Double) metric.metricValue(); + } + + @Test + public void testFlushMeasureLatency() { + Map configs = new HashMap<>(); + configs.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9000"); + + Time time = new MockTime(1); + MetadataResponse initialUpdateResponse = RequestTestUtils.metadataUpdateWith(1, singletonMap("topic", 1)); + ProducerMetadata metadata = newMetadata(0, Long.MAX_VALUE); + + MockClient client = new MockClient(time, metadata); + client.updateMetadata(initialUpdateResponse); + + try (KafkaProducer producer = kafkaProducer( + configs, + new StringSerializer(), + new StringSerializer(), + metadata, + client, + null, + time + )) { + producer.flush(); + double first = getMetricValue(producer, "flush-time-ns-total"); + assertTrue(first > 0); + producer.flush(); + assertTrue(getMetricValue(producer, "flush-time-ns-total") > first); + } + } + + @Test + public void testMetricConfigRecordingLevel() { + Properties props = new Properties(); + 
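// The recording level defaults to INFO and follows metrics.recording.level when set explicitly +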
props.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9000"); + try (KafkaProducer producer = new KafkaProducer<>(props, new ByteArraySerializer(), new ByteArraySerializer())) { + assertEquals(Sensor.RecordingLevel.INFO, producer.metrics.config().recordLevel()); + } + + props.put(ProducerConfig.METRICS_RECORDING_LEVEL_CONFIG, "DEBUG"); + try (KafkaProducer producer = new KafkaProducer<>(props, new ByteArraySerializer(), new ByteArraySerializer())) { + assertEquals(Sensor.RecordingLevel.DEBUG, producer.metrics.config().recordLevel()); + } + } + + @Test + public void testInterceptorPartitionSetOnTooLargeRecord() { + Map configs = new HashMap<>(); + configs.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9999"); + configs.put(ProducerConfig.MAX_REQUEST_SIZE_CONFIG, "1"); + String topic = "topic"; + ProducerRecord record = new ProducerRecord<>(topic, "value"); + + long nowMs = Time.SYSTEM.milliseconds(); + ProducerMetadata metadata = newMetadata(0, 90000); + metadata.add(topic, nowMs); + MetadataResponse initialUpdateResponse = RequestTestUtils.metadataUpdateWith(1, singletonMap(topic, 1)); + metadata.updateWithCurrentRequestVersion(initialUpdateResponse, false, nowMs); + + @SuppressWarnings("unchecked") // it is safe to suppress, since this is a mock class + ProducerInterceptors interceptors = mock(ProducerInterceptors.class); + KafkaProducer producer = kafkaProducer(configs, new StringSerializer(), + new StringSerializer(), metadata, null, interceptors, Time.SYSTEM); + + when(interceptors.onSend(any())).then(invocation -> invocation.getArgument(0)); + + producer.send(record); + + verify(interceptors).onSend(record); + verify(interceptors).onSendError(eq(record), notNull(), notNull()); + + producer.close(Duration.ofMillis(0)); + } + + @Test + public void testPartitionsForWithNullTopic() { + Properties props = new Properties(); + props.setProperty(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9000"); + try (KafkaProducer producer = new KafkaProducer<>(props, new ByteArraySerializer(), new ByteArraySerializer())) { + assertThrows(NullPointerException.class, () -> producer.partitionsFor(null)); + } + } + + @Test + public void testInitTransactionTimeout() { + Map configs = new HashMap<>(); + configs.put(ProducerConfig.TRANSACTIONAL_ID_CONFIG, "bad-transaction"); + configs.put(ProducerConfig.MAX_BLOCK_MS_CONFIG, 500); + configs.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9000"); + + Time time = new MockTime(1); + MetadataResponse initialUpdateResponse = RequestTestUtils.metadataUpdateWith(1, singletonMap("topic", 1)); + ProducerMetadata metadata = newMetadata(0, Long.MAX_VALUE); + metadata.updateWithCurrentRequestVersion(initialUpdateResponse, false, time.milliseconds()); + + MockClient client = new MockClient(time, metadata); + + try (Producer producer = kafkaProducer(configs, new StringSerializer(), + new StringSerializer(), metadata, client, null, time)) { + client.prepareResponse( + request -> request instanceof FindCoordinatorRequest && + ((FindCoordinatorRequest) request).data().keyType() == FindCoordinatorRequest.CoordinatorType.TRANSACTION.id(), + FindCoordinatorResponse.prepareResponse(Errors.NONE, "bad-transaction", host1)); + + assertThrows(TimeoutException.class, producer::initTransactions); + + client.prepareResponse( + request -> request instanceof FindCoordinatorRequest && + ((FindCoordinatorRequest) request).data().keyType() == FindCoordinatorRequest.CoordinatorType.TRANSACTION.id(), + FindCoordinatorResponse.prepareResponse(Errors.NONE, 
"bad-transaction", host1)); + + client.prepareResponse(initProducerIdResponse(1L, (short) 5, Errors.NONE)); + + // retry initialization should work + producer.initTransactions(); + } + } + + @Test + public void testInitTransactionWhileThrottled() { + Map configs = new HashMap<>(); + configs.put(ProducerConfig.TRANSACTIONAL_ID_CONFIG, "some.id"); + configs.put(ProducerConfig.MAX_BLOCK_MS_CONFIG, 10000); + configs.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9000"); + + Time time = new MockTime(1); + MetadataResponse initialUpdateResponse = RequestTestUtils.metadataUpdateWith(1, singletonMap("topic", 1)); + ProducerMetadata metadata = newMetadata(0, Long.MAX_VALUE); + + MockClient client = new MockClient(time, metadata); + client.updateMetadata(initialUpdateResponse); + + Node node = metadata.fetch().nodes().get(0); + client.throttle(node, 5000); + + client.prepareResponse(FindCoordinatorResponse.prepareResponse(Errors.NONE, "some.id", host1)); + client.prepareResponse(initProducerIdResponse(1L, (short) 5, Errors.NONE)); + + try (Producer producer = kafkaProducer(configs, new StringSerializer(), + new StringSerializer(), metadata, client, null, time)) { + producer.initTransactions(); + } + } + + @Test + public void testAbortTransaction() { + Map configs = new HashMap<>(); + configs.put(ProducerConfig.TRANSACTIONAL_ID_CONFIG, "some.id"); + configs.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9000"); + + Time time = new MockTime(1); + MetadataResponse initialUpdateResponse = RequestTestUtils.metadataUpdateWith(1, singletonMap("topic", 1)); + ProducerMetadata metadata = newMetadata(0, Long.MAX_VALUE); + + MockClient client = new MockClient(time, metadata); + client.updateMetadata(initialUpdateResponse); + + client.prepareResponse(FindCoordinatorResponse.prepareResponse(Errors.NONE, "some.id", host1)); + client.prepareResponse(initProducerIdResponse(1L, (short) 5, Errors.NONE)); + client.prepareResponse(endTxnResponse(Errors.NONE)); + + try (Producer producer = kafkaProducer(configs, new StringSerializer(), + new StringSerializer(), metadata, client, null, time)) { + producer.initTransactions(); + producer.beginTransaction(); + producer.abortTransaction(); + } + } + + @Test + public void testMeasureAbortTransactionDuration() { + Map configs = new HashMap<>(); + configs.put(ProducerConfig.TRANSACTIONAL_ID_CONFIG, "some.id"); + configs.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9000"); + Time time = new MockTime(1); + MetadataResponse initialUpdateResponse = RequestTestUtils.metadataUpdateWith(1, singletonMap("topic", 1)); + ProducerMetadata metadata = newMetadata(0, Long.MAX_VALUE); + MockClient client = new MockClient(time, metadata); + client.updateMetadata(initialUpdateResponse); + client.prepareResponse(FindCoordinatorResponse.prepareResponse(Errors.NONE, "some.id", host1)); + client.prepareResponse(initProducerIdResponse(1L, (short) 5, Errors.NONE)); + + try (KafkaProducer producer = kafkaProducer(configs, new StringSerializer(), + new StringSerializer(), metadata, client, null, time)) { + producer.initTransactions(); + + client.prepareResponse(endTxnResponse(Errors.NONE)); + producer.beginTransaction(); + producer.abortTransaction(); + double first = getMetricValue(producer, "txn-abort-time-ns-total"); + assertTrue(first > 0); + + client.prepareResponse(endTxnResponse(Errors.NONE)); + producer.beginTransaction(); + producer.abortTransaction(); + assertTrue(getMetricValue(producer, "txn-abort-time-ns-total") > first); + } + } + + @Test + public void 
testSendTxnOffsetsWithGroupId() { + Map configs = new HashMap<>(); + configs.put(ProducerConfig.TRANSACTIONAL_ID_CONFIG, "some.id"); + configs.put(ProducerConfig.MAX_BLOCK_MS_CONFIG, 10000); + configs.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9000"); + + Time time = new MockTime(1); + MetadataResponse initialUpdateResponse = RequestTestUtils.metadataUpdateWith(1, singletonMap("topic", 1)); + ProducerMetadata metadata = newMetadata(0, Long.MAX_VALUE); + + MockClient client = new MockClient(time, metadata); + client.updateMetadata(initialUpdateResponse); + + Node node = metadata.fetch().nodes().get(0); + client.throttle(node, 5000); + + client.prepareResponse(FindCoordinatorResponse.prepareResponse(Errors.NONE, "some.id", host1)); + client.prepareResponse(initProducerIdResponse(1L, (short) 5, Errors.NONE)); + client.prepareResponse(addOffsetsToTxnResponse(Errors.NONE)); + client.prepareResponse(FindCoordinatorResponse.prepareResponse(Errors.NONE, "some.id", host1)); + String groupId = "group"; + client.prepareResponse(request -> + ((TxnOffsetCommitRequest) request).data().groupId().equals(groupId), + txnOffsetsCommitResponse(Collections.singletonMap( + new TopicPartition("topic", 0), Errors.NONE))); + client.prepareResponse(endTxnResponse(Errors.NONE)); + + try (Producer producer = kafkaProducer(configs, new StringSerializer(), + new StringSerializer(), metadata, client, null, time)) { + producer.initTransactions(); + producer.beginTransaction(); + producer.sendOffsetsToTransaction(Collections.emptyMap(), new ConsumerGroupMetadata(groupId)); + producer.commitTransaction(); + } + } + + private void assertDurationAtLeast(KafkaProducer producer, String name, double floor) { + getAndAssertDurationAtLeast(producer, name, floor); + } + + private double getAndAssertDurationAtLeast(KafkaProducer producer, String name, double floor) { + double value = getMetricValue(producer, name); + assertTrue(value >= floor); + return value; + } + + @Test + public void testMeasureTransactionDurations() { + Map configs = new HashMap<>(); + configs.put(ProducerConfig.TRANSACTIONAL_ID_CONFIG, "some.id"); + configs.put(ProducerConfig.MAX_BLOCK_MS_CONFIG, 10000); + configs.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9000"); + Duration tick = Duration.ofSeconds(1); + Time time = new MockTime(tick.toMillis()); + MetadataResponse initialUpdateResponse = RequestTestUtils.metadataUpdateWith(1, singletonMap("topic", 1)); + ProducerMetadata metadata = newMetadata(0, Long.MAX_VALUE); + + MockClient client = new MockClient(time, metadata); + client.updateMetadata(initialUpdateResponse); + client.prepareResponse(FindCoordinatorResponse.prepareResponse(Errors.NONE, "some.id", host1)); + client.prepareResponse(initProducerIdResponse(1L, (short) 5, Errors.NONE)); + + try (KafkaProducer producer = kafkaProducer(configs, new StringSerializer(), + new StringSerializer(), metadata, client, null, time)) { + producer.initTransactions(); + assertDurationAtLeast(producer, "txn-init-time-ns-total", tick.toNanos()); + + client.prepareResponse(addOffsetsToTxnResponse(Errors.NONE)); + client.prepareResponse(FindCoordinatorResponse.prepareResponse(Errors.NONE, "some.id", host1)); + client.prepareResponse(txnOffsetsCommitResponse(Collections.singletonMap( + new TopicPartition("topic", 0), Errors.NONE))); + client.prepareResponse(endTxnResponse(Errors.NONE)); + producer.beginTransaction(); + double beginFirst = getAndAssertDurationAtLeast(producer, "txn-begin-time-ns-total", tick.toNanos()); + 
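// MockTime auto-advances by one tick per time lookup, so each blocking transactional call accrues at least tick.toNanos() in its duration metric +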
producer.sendOffsetsToTransaction(Collections.singletonMap( + new TopicPartition("topic", 0), + new OffsetAndMetadata(5L)), + new ConsumerGroupMetadata("group")); + double sendOffFirst = getAndAssertDurationAtLeast(producer, "txn-send-offsets-time-ns-total", tick.toNanos()); + producer.commitTransaction(); + double commitFirst = getAndAssertDurationAtLeast(producer, "txn-commit-time-ns-total", tick.toNanos()); + + client.prepareResponse(addOffsetsToTxnResponse(Errors.NONE)); + client.prepareResponse(txnOffsetsCommitResponse(Collections.singletonMap( + new TopicPartition("topic", 0), Errors.NONE))); + client.prepareResponse(endTxnResponse(Errors.NONE)); + producer.beginTransaction(); + assertDurationAtLeast(producer, "txn-begin-time-ns-total", beginFirst + tick.toNanos()); + producer.sendOffsetsToTransaction(Collections.singletonMap( + new TopicPartition("topic", 0), + new OffsetAndMetadata(10L)), + new ConsumerGroupMetadata("group")); + assertDurationAtLeast(producer, "txn-send-offsets-time-ns-total", sendOffFirst + tick.toNanos()); + producer.commitTransaction(); + assertDurationAtLeast(producer, "txn-commit-time-ns-total", commitFirst + tick.toNanos()); + } + } + + @Test + public void testSendTxnOffsetsWithGroupMetadata() { + final short maxVersion = (short) 3; + Map configs = new HashMap<>(); + configs.put(ProducerConfig.TRANSACTIONAL_ID_CONFIG, "some.id"); + configs.put(ProducerConfig.MAX_BLOCK_MS_CONFIG, 10000); + configs.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9000"); + + Time time = new MockTime(1); + MetadataResponse initialUpdateResponse = RequestTestUtils.metadataUpdateWith(1, singletonMap("topic", 1)); + ProducerMetadata metadata = newMetadata(0, Long.MAX_VALUE); + + MockClient client = new MockClient(time, metadata); + client.updateMetadata(initialUpdateResponse); + client.setNodeApiVersions(NodeApiVersions.create(ApiKeys.TXN_OFFSET_COMMIT.id, (short) 0, maxVersion)); + + Node node = metadata.fetch().nodes().get(0); + client.throttle(node, 5000); + + client.prepareResponse(FindCoordinatorResponse.prepareResponse(Errors.NONE, "some.id", host1)); + client.prepareResponse(initProducerIdResponse(1L, (short) 5, Errors.NONE)); + client.prepareResponse(addOffsetsToTxnResponse(Errors.NONE)); + client.prepareResponse(FindCoordinatorResponse.prepareResponse(Errors.NONE, "some.id", host1)); + String groupId = "group"; + String memberId = "member"; + int generationId = 5; + String groupInstanceId = "instance"; + client.prepareResponse(request -> { + TxnOffsetCommitRequestData data = ((TxnOffsetCommitRequest) request).data(); + return data.groupId().equals(groupId) && + data.memberId().equals(memberId) && + data.generationId() == generationId && + data.groupInstanceId().equals(groupInstanceId); + }, txnOffsetsCommitResponse(Collections.singletonMap( + new TopicPartition("topic", 0), Errors.NONE))); + client.prepareResponse(endTxnResponse(Errors.NONE)); + + try (Producer producer = kafkaProducer(configs, new StringSerializer(), + new StringSerializer(), metadata, client, null, time)) { + producer.initTransactions(); + producer.beginTransaction(); + ConsumerGroupMetadata groupMetadata = new ConsumerGroupMetadata(groupId, + generationId, memberId, Optional.of(groupInstanceId)); + + producer.sendOffsetsToTransaction(Collections.emptyMap(), groupMetadata); + producer.commitTransaction(); + } + } + + @Test + public void testNullGroupMetadataInSendOffsets() { + verifyInvalidGroupMetadata(null); + } + + @Test + public void testInvalidGenerationIdAndMemberIdCombinedInSendOffsets() { 
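+ // A concrete generation id combined with UNKNOWN_MEMBER_ID is inconsistent group metadata and must be rejected.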
+ verifyInvalidGroupMetadata(new ConsumerGroupMetadata("group", 2, JoinGroupRequest.UNKNOWN_MEMBER_ID, Optional.empty())); + } + + private void verifyInvalidGroupMetadata(ConsumerGroupMetadata groupMetadata) { + Map configs = new HashMap<>(); + configs.put(ProducerConfig.TRANSACTIONAL_ID_CONFIG, "some.id"); + configs.put(ProducerConfig.MAX_BLOCK_MS_CONFIG, 10000); + configs.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9000"); + + Time time = new MockTime(1); + MetadataResponse initialUpdateResponse = RequestTestUtils.metadataUpdateWith(1, singletonMap("topic", 1)); + ProducerMetadata metadata = newMetadata(0, Long.MAX_VALUE); + + MockClient client = new MockClient(time, metadata); + client.updateMetadata(initialUpdateResponse); + + Node node = metadata.fetch().nodes().get(0); + client.throttle(node, 5000); + + client.prepareResponse(FindCoordinatorResponse.prepareResponse(Errors.NONE, "some.id", host1)); + client.prepareResponse(initProducerIdResponse(1L, (short) 5, Errors.NONE)); + + try (Producer producer = kafkaProducer(configs, new StringSerializer(), + new StringSerializer(), metadata, client, null, time)) { + producer.initTransactions(); + producer.beginTransaction(); + assertThrows(IllegalArgumentException.class, + () -> producer.sendOffsetsToTransaction(Collections.emptyMap(), groupMetadata)); + } + } + + private InitProducerIdResponse initProducerIdResponse(long producerId, short producerEpoch, Errors error) { + InitProducerIdResponseData responseData = new InitProducerIdResponseData() + .setErrorCode(error.code()) + .setProducerEpoch(producerEpoch) + .setProducerId(producerId) + .setThrottleTimeMs(0); + return new InitProducerIdResponse(responseData); + } + + private AddOffsetsToTxnResponse addOffsetsToTxnResponse(Errors error) { + return new AddOffsetsToTxnResponse(new AddOffsetsToTxnResponseData() + .setErrorCode(error.code()) + .setThrottleTimeMs(10)); + } + + private TxnOffsetCommitResponse txnOffsetsCommitResponse(Map errorMap) { + return new TxnOffsetCommitResponse(10, errorMap); + } + + private EndTxnResponse endTxnResponse(Errors error) { + return new EndTxnResponse(new EndTxnResponseData() + .setErrorCode(error.code()) + .setThrottleTimeMs(0)); + } + + @Test + public void testOnlyCanExecuteCloseAfterInitTransactionsTimeout() { + Map configs = new HashMap<>(); + configs.put(ProducerConfig.TRANSACTIONAL_ID_CONFIG, "bad-transaction"); + configs.put(ProducerConfig.MAX_BLOCK_MS_CONFIG, 5); + configs.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9000"); + + Time time = new MockTime(); + MetadataResponse initialUpdateResponse = RequestTestUtils.metadataUpdateWith(1, singletonMap("topic", 1)); + ProducerMetadata metadata = newMetadata(0, Long.MAX_VALUE); + metadata.updateWithCurrentRequestVersion(initialUpdateResponse, false, time.milliseconds()); + + MockClient client = new MockClient(time, metadata); + + Producer producer = kafkaProducer(configs, new StringSerializer(), new StringSerializer(), + metadata, client, null, time); + assertThrows(TimeoutException.class, producer::initTransactions); + // other transactional operations should not be allowed if we catch the error after initTransactions failed + try { + assertThrows(KafkaException.class, producer::beginTransaction); + } finally { + producer.close(Duration.ofMillis(0)); + } + } + + @Test + public void testSendToInvalidTopic() throws Exception { + Map configs = new HashMap<>(); + configs.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9000"); + configs.put(ProducerConfig.MAX_BLOCK_MS_CONFIG, 
"15000"); + + Time time = new MockTime(); + MetadataResponse initialUpdateResponse = RequestTestUtils.metadataUpdateWith(1, emptyMap()); + ProducerMetadata metadata = newMetadata(0, Long.MAX_VALUE); + metadata.updateWithCurrentRequestVersion(initialUpdateResponse, false, time.milliseconds()); + + MockClient client = new MockClient(time, metadata); + + Producer producer = kafkaProducer(configs, new StringSerializer(), new StringSerializer(), + metadata, client, null, time); + + String invalidTopicName = "topic abc"; // Invalid topic name due to space + ProducerRecord record = new ProducerRecord<>(invalidTopicName, "HelloKafka"); + + List topicMetadata = new ArrayList<>(); + topicMetadata.add(new MetadataResponse.TopicMetadata(Errors.INVALID_TOPIC_EXCEPTION, + invalidTopicName, false, Collections.emptyList())); + MetadataResponse updateResponse = RequestTestUtils.metadataResponse( + new ArrayList<>(initialUpdateResponse.brokers()), + initialUpdateResponse.clusterId(), + initialUpdateResponse.controller().id(), + topicMetadata); + client.prepareMetadataUpdate(updateResponse); + + Future future = producer.send(record); + + assertEquals(Collections.singleton(invalidTopicName), + metadata.fetch().invalidTopics(), "Cluster has incorrect invalid topic list."); + TestUtils.assertFutureError(future, InvalidTopicException.class); + + producer.close(Duration.ofMillis(0)); + } + + @Test + public void testCloseWhenWaitingForMetadataUpdate() throws InterruptedException { + Map configs = new HashMap<>(); + configs.put(ProducerConfig.MAX_BLOCK_MS_CONFIG, Long.MAX_VALUE); + configs.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9000"); + + // Simulate a case where metadata for a particular topic is not available. This will cause KafkaProducer#send to + // block in Metadata#awaitUpdate for the configured max.block.ms. When close() is invoked, KafkaProducer#send should + // return with a KafkaException. + String topicName = "test"; + Time time = Time.SYSTEM; + MetadataResponse initialUpdateResponse = RequestTestUtils.metadataUpdateWith(1, emptyMap()); + ProducerMetadata metadata = new ProducerMetadata(0, Long.MAX_VALUE, Long.MAX_VALUE, + new LogContext(), new ClusterResourceListeners(), time); + metadata.updateWithCurrentRequestVersion(initialUpdateResponse, false, time.milliseconds()); + MockClient client = new MockClient(time, metadata); + + Producer producer = kafkaProducer(configs, new StringSerializer(), new StringSerializer(), + metadata, client, null, time); + + ExecutorService executor = Executors.newSingleThreadExecutor(); + final AtomicReference sendException = new AtomicReference<>(); + + try { + executor.submit(() -> { + try { + // Metadata for topic "test" will not be available which will cause us to block indefinitely until + // KafkaProducer#close is invoked. 
+ producer.send(new ProducerRecord<>(topicName, "key", "value")); + fail(); + } catch (Exception e) { + sendException.set(e); + } + }); + + // Wait until metadata update for the topic has been requested + TestUtils.waitForCondition(() -> metadata.containsTopic(topicName), + "Timeout when waiting for topic to be added to metadata"); + producer.close(Duration.ofMillis(0)); + TestUtils.waitForCondition(() -> sendException.get() != null, "No producer exception within timeout"); + assertEquals(KafkaException.class, sendException.get().getClass()); + } finally { + executor.shutdownNow(); + } + } + + @Test + public void testTransactionalMethodThrowsWhenSenderClosed() { + Map configs = new HashMap<>(); + configs.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9000"); + configs.put(ProducerConfig.TRANSACTIONAL_ID_CONFIG, "this-is-a-transactional-id"); + + Time time = new MockTime(); + MetadataResponse initialUpdateResponse = RequestTestUtils.metadataUpdateWith(1, emptyMap()); + ProducerMetadata metadata = newMetadata(0, Long.MAX_VALUE); + metadata.updateWithCurrentRequestVersion(initialUpdateResponse, false, time.milliseconds()); + + MockClient client = new MockClient(time, metadata); + + Producer producer = kafkaProducer(configs, new StringSerializer(), new StringSerializer(), + metadata, client, null, time); + producer.close(); + assertThrows(IllegalStateException.class, producer::initTransactions); + } + + @Test + public void testCloseIsForcedOnPendingFindCoordinator() throws InterruptedException { + Map configs = new HashMap<>(); + configs.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9000"); + configs.put(ProducerConfig.TRANSACTIONAL_ID_CONFIG, "this-is-a-transactional-id"); + + Time time = new MockTime(); + MetadataResponse initialUpdateResponse = RequestTestUtils.metadataUpdateWith(1, singletonMap("testTopic", 1)); + ProducerMetadata metadata = newMetadata(0, Long.MAX_VALUE); + metadata.updateWithCurrentRequestVersion(initialUpdateResponse, false, time.milliseconds()); + + MockClient client = new MockClient(time, metadata); + + Producer producer = kafkaProducer(configs, new StringSerializer(), new StringSerializer(), + metadata, client, null, time); + + ExecutorService executorService = Executors.newSingleThreadExecutor(); + CountDownLatch assertionDoneLatch = new CountDownLatch(1); + executorService.submit(() -> { + assertThrows(KafkaException.class, producer::initTransactions); + assertionDoneLatch.countDown(); + }); + + client.waitForRequests(1, 2000); + producer.close(Duration.ofMillis(1000)); + assertionDoneLatch.await(5000, TimeUnit.MILLISECONDS); + } + + @Test + public void testCloseIsForcedOnPendingInitProducerId() throws InterruptedException { + Map configs = new HashMap<>(); + configs.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9000"); + configs.put(ProducerConfig.TRANSACTIONAL_ID_CONFIG, "this-is-a-transactional-id"); + + Time time = new MockTime(); + MetadataResponse initialUpdateResponse = RequestTestUtils.metadataUpdateWith(1, singletonMap("testTopic", 1)); + ProducerMetadata metadata = newMetadata(0, Long.MAX_VALUE); + metadata.updateWithCurrentRequestVersion(initialUpdateResponse, false, time.milliseconds()); + + MockClient client = new MockClient(time, metadata); + + Producer producer = kafkaProducer(configs, new StringSerializer(), new StringSerializer(), + metadata, client, null, time); + + ExecutorService executorService = Executors.newSingleThreadExecutor(); + CountDownLatch assertionDoneLatch = new CountDownLatch(1); + 
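+ // Only the FindCoordinator response is prepared here, so the subsequent InitProducerId
+ // request stays pending and close(Duration.ofMillis(1000)) has to force-complete it,
+ // surfacing a KafkaException from initTransactions() on the executor thread.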
client.prepareResponse(FindCoordinatorResponse.prepareResponse(Errors.NONE, "this-is-a-transactional-id", host1)); + executorService.submit(() -> { + assertThrows(KafkaException.class, producer::initTransactions); + assertionDoneLatch.countDown(); + }); + + client.waitForRequests(1, 2000); + producer.close(Duration.ofMillis(1000)); + assertionDoneLatch.await(5000, TimeUnit.MILLISECONDS); + } + + @Test + public void testCloseIsForcedOnPendingAddOffsetRequest() throws InterruptedException { + Map configs = new HashMap<>(); + configs.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9000"); + configs.put(ProducerConfig.TRANSACTIONAL_ID_CONFIG, "this-is-a-transactional-id"); + + Time time = new MockTime(); + MetadataResponse initialUpdateResponse = RequestTestUtils.metadataUpdateWith(1, singletonMap("testTopic", 1)); + ProducerMetadata metadata = newMetadata(0, Long.MAX_VALUE); + metadata.updateWithCurrentRequestVersion(initialUpdateResponse, false, time.milliseconds()); + + MockClient client = new MockClient(time, metadata); + + Producer producer = kafkaProducer(configs, new StringSerializer(), new StringSerializer(), + metadata, client, null, time); + + ExecutorService executorService = Executors.newSingleThreadExecutor(); + CountDownLatch assertionDoneLatch = new CountDownLatch(1); + client.prepareResponse(FindCoordinatorResponse.prepareResponse(Errors.NONE, "this-is-a-transactional-id", host1)); + executorService.submit(() -> { + assertThrows(KafkaException.class, producer::initTransactions); + assertionDoneLatch.countDown(); + }); + + client.waitForRequests(1, 2000); + producer.close(Duration.ofMillis(1000)); + assertionDoneLatch.await(5000, TimeUnit.MILLISECONDS); + } + + @Test + public void testProducerJmxPrefix() throws Exception { + Map props = new HashMap<>(); + props.put(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9999"); + props.put("client.id", "client-1"); + + KafkaProducer producer = new KafkaProducer<>( + props, new StringSerializer(), new StringSerializer()); + + MBeanServer server = ManagementFactory.getPlatformMBeanServer(); + MetricName testMetricName = producer.metrics.metricName("test-metric", + "grp1", "test metric"); + producer.metrics.addMetric(testMetricName, new Avg()); + assertNotNull(server.getObjectInstance(new ObjectName("kafka.producer:type=grp1,client-id=client-1"))); + producer.close(); + } + + private static ProducerMetadata newMetadata(long refreshBackoffMs, long expirationMs) { + return new ProducerMetadata(refreshBackoffMs, expirationMs, DEFAULT_METADATA_IDLE_MS, + new LogContext(), new ClusterResourceListeners(), Time.SYSTEM); + } + + @Test + public void configurableObjectsShouldSeeGeneratedClientId() { + Properties props = new Properties(); + props.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9999"); + props.put(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG, SerializerForClientId.class.getName()); + props.put(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG, SerializerForClientId.class.getName()); + props.put(ProducerConfig.PARTITIONER_CLASS_CONFIG, PartitionerForClientId.class.getName()); + props.put(ProducerConfig.INTERCEPTOR_CLASSES_CONFIG, ProducerInterceptorForClientId.class.getName()); + + KafkaProducer producer = new KafkaProducer<>(props); + assertNotNull(producer.getClientId()); + assertNotEquals(0, producer.getClientId().length()); + assertEquals(4, CLIENT_IDS.size()); + CLIENT_IDS.forEach(id -> assertEquals(id, producer.getClientId())); + producer.close(); + } + + @Test + public void testUnusedConfigs() { + Map props = new 
HashMap<>(); + props.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9999"); + props.put(SslConfigs.SSL_PROTOCOL_CONFIG, "TLS"); + ProducerConfig config = new ProducerConfig(ProducerConfig.appendSerializerToConfig(props, + new StringSerializer(), new StringSerializer())); + + assertTrue(config.unused().contains(SslConfigs.SSL_PROTOCOL_CONFIG)); + + try (KafkaProducer producer = new KafkaProducer<>(config, null, null, + null, null, null, Time.SYSTEM)) { + assertTrue(config.unused().contains(SslConfigs.SSL_PROTOCOL_CONFIG)); + } + } + + @Test + public void testNullTopicName() { + // send a record with null topic should fail + assertThrows(IllegalArgumentException.class, () -> new ProducerRecord<>(null, 1, + "key".getBytes(StandardCharsets.UTF_8), "value".getBytes(StandardCharsets.UTF_8))); + } + + private static final List CLIENT_IDS = new ArrayList<>(); + + public static class SerializerForClientId implements Serializer { + @Override + public void configure(Map configs, boolean isKey) { + CLIENT_IDS.add(configs.get(ProducerConfig.CLIENT_ID_CONFIG).toString()); + } + + @Override + public byte[] serialize(String topic, byte[] data) { + return data; + } + } + + public static class PartitionerForClientId implements Partitioner { + + @Override + public int partition(String topic, Object key, byte[] keyBytes, Object value, byte[] valueBytes, Cluster cluster) { + return 0; + } + + @Override + public void close() { + + } + + @Override + public void configure(Map configs) { + CLIENT_IDS.add(configs.get(ProducerConfig.CLIENT_ID_CONFIG).toString()); + } + } + + public static class ProducerInterceptorForClientId implements ProducerInterceptor { + + @Override + public ProducerRecord onSend(ProducerRecord record) { + return record; + } + + @Override + public void onAcknowledgement(RecordMetadata metadata, Exception exception) { + } + + @Override + public void close() { + } + + @Override + public void configure(Map configs) { + CLIENT_IDS.add(configs.get(ProducerConfig.CLIENT_ID_CONFIG).toString()); + } + } +} diff --git a/clients/src/test/java/org/apache/kafka/clients/producer/MockProducerTest.java b/clients/src/test/java/org/apache/kafka/clients/producer/MockProducerTest.java new file mode 100644 index 0000000..ca14ab0 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/clients/producer/MockProducerTest.java @@ -0,0 +1,771 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.clients.producer; + +import org.apache.kafka.clients.consumer.ConsumerGroupMetadata; +import org.apache.kafka.clients.consumer.OffsetAndMetadata; +import org.apache.kafka.clients.producer.internals.DefaultPartitioner; +import org.apache.kafka.common.Cluster; +import org.apache.kafka.common.KafkaException; +import org.apache.kafka.common.PartitionInfo; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.errors.ProducerFencedException; +import org.apache.kafka.common.record.RecordBatch; +import org.apache.kafka.common.serialization.IntegerSerializer; +import org.apache.kafka.common.serialization.StringSerializer; +import org.apache.kafka.test.MockSerializer; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Test; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Future; + +import static java.util.Arrays.asList; +import static java.util.Collections.singletonList; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; + +public class MockProducerTest { + + private final String topic = "topic"; + private MockProducer producer; + private final ProducerRecord record1 = new ProducerRecord<>(topic, "key1".getBytes(), "value1".getBytes()); + private final ProducerRecord record2 = new ProducerRecord<>(topic, "key2".getBytes(), "value2".getBytes()); + private final String groupId = "group"; + + private void buildMockProducer(boolean autoComplete) { + this.producer = new MockProducer<>(autoComplete, new MockSerializer(), new MockSerializer()); + } + + @AfterEach + public void cleanup() { + if (this.producer != null && !this.producer.closed()) + this.producer.close(); + } + + @Test + public void testAutoCompleteMock() throws Exception { + buildMockProducer(true); + Future metadata = producer.send(record1); + assertTrue(metadata.isDone(), "Send should be immediately complete"); + assertFalse(isError(metadata), "Send should be successful"); + assertEquals(0L, metadata.get().offset(), "Offset should be 0"); + assertEquals(topic, metadata.get().topic()); + assertEquals(singletonList(record1), producer.history(), "We should have the record in our history"); + producer.clear(); + assertEquals(0, producer.history().size(), "Clear should erase our history"); + } + + @Test + public void testPartitioner() throws Exception { + PartitionInfo partitionInfo0 = new PartitionInfo(topic, 0, null, null, null); + PartitionInfo partitionInfo1 = new PartitionInfo(topic, 1, null, null, null); + Cluster cluster = new Cluster(null, new ArrayList<>(0), asList(partitionInfo0, partitionInfo1), + Collections.emptySet(), Collections.emptySet()); + MockProducer producer = new MockProducer<>(cluster, true, new DefaultPartitioner(), new StringSerializer(), new StringSerializer()); + ProducerRecord record = new ProducerRecord<>(topic, "key", "value"); + Future metadata = producer.send(record); + assertEquals(1, metadata.get().partition(), "Partition should be correct"); + producer.clear(); + assertEquals(0, producer.history().size(), "Clear should erase our history"); + producer.close(); + } + + 
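+ // Illustrative sketch: with autoComplete enabled, MockProducer completes each send synchronously,
+ // so a Callback passed to send() should be invoked before send() returns, mirroring what
+ // testAutoCompleteMock verifies through the returned Future.
+ @Test
+ public void testAutoCompleteInvokesCallbackSynchronously() {
+ buildMockProducer(true);
+ final RecordMetadata[] observed = new RecordMetadata[1];
+ producer.send(record1, (metadata, exception) -> observed[0] = metadata);
+ assertNotNull(observed[0], "Callback should have been invoked synchronously");
+ assertEquals(0L, observed[0].offset(), "First record in the partition should have offset 0");
+ }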
@Test + public void testManualCompletion() throws Exception { + buildMockProducer(false); + Future md1 = producer.send(record1); + assertFalse(md1.isDone(), "Send shouldn't have completed"); + Future md2 = producer.send(record2); + assertFalse(md2.isDone(), "Send shouldn't have completed"); + assertTrue(producer.completeNext(), "Complete the first request"); + assertFalse(isError(md1), "Requst should be successful"); + assertFalse(md2.isDone(), "Second request still incomplete"); + IllegalArgumentException e = new IllegalArgumentException("blah"); + assertTrue(producer.errorNext(e), "Complete the second request with an error"); + try { + md2.get(); + fail("Expected error to be thrown"); + } catch (ExecutionException err) { + assertEquals(e, err.getCause()); + } + assertFalse(producer.completeNext(), "No more requests to complete"); + + Future md3 = producer.send(record1); + Future md4 = producer.send(record2); + assertTrue(!md3.isDone() && !md4.isDone(), "Requests should not be completed."); + producer.flush(); + assertTrue(md3.isDone() && md4.isDone(), "Requests should be completed."); + } + + @Test + public void shouldInitTransactions() { + buildMockProducer(true); + producer.initTransactions(); + assertTrue(producer.transactionInitialized()); + } + + @Test + public void shouldThrowOnInitTransactionIfProducerAlreadyInitializedForTransactions() { + buildMockProducer(true); + producer.initTransactions(); + assertThrows(IllegalStateException.class, producer::initTransactions); + } + + @Test + public void shouldThrowOnBeginTransactionIfTransactionsNotInitialized() { + buildMockProducer(true); + assertThrows(IllegalStateException.class, producer::beginTransaction); + } + + @Test + public void shouldBeginTransactions() { + buildMockProducer(true); + producer.initTransactions(); + producer.beginTransaction(); + assertTrue(producer.transactionInFlight()); + } + + @Test + public void shouldThrowOnBeginTransactionsIfTransactionInflight() { + buildMockProducer(true); + producer.initTransactions(); + producer.beginTransaction(); + assertThrows(IllegalStateException.class, () -> producer.beginTransaction()); + } + + @Test + public void shouldThrowOnSendOffsetsToTransactionIfTransactionsNotInitialized() { + buildMockProducer(true); + assertThrows(IllegalStateException.class, () -> producer.sendOffsetsToTransaction(null, new ConsumerGroupMetadata(groupId))); + } + + @Test + public void shouldThrowOnSendOffsetsToTransactionTransactionIfNoTransactionGotStarted() { + buildMockProducer(true); + producer.initTransactions(); + assertThrows(IllegalStateException.class, () -> producer.sendOffsetsToTransaction(null, new ConsumerGroupMetadata(groupId))); + } + + @Test + public void shouldThrowOnCommitIfTransactionsNotInitialized() { + buildMockProducer(true); + assertThrows(IllegalStateException.class, producer::commitTransaction); + } + + @Test + public void shouldThrowOnCommitTransactionIfNoTransactionGotStarted() { + buildMockProducer(true); + producer.initTransactions(); + assertThrows(IllegalStateException.class, producer::commitTransaction); + } + + @Test + public void shouldCommitEmptyTransaction() { + buildMockProducer(true); + producer.initTransactions(); + producer.beginTransaction(); + producer.commitTransaction(); + assertFalse(producer.transactionInFlight()); + assertTrue(producer.transactionCommitted()); + assertFalse(producer.transactionAborted()); + } + + @Test + public void shouldCountCommittedTransaction() { + buildMockProducer(true); + producer.initTransactions(); + 
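+ // commitCount() increments only on commitTransaction(); the next test,
+ // shouldNotCountAbortedTransaction, checks that aborted transactions are not counted.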
producer.beginTransaction(); + + assertEquals(0L, producer.commitCount()); + producer.commitTransaction(); + assertEquals(1L, producer.commitCount()); + } + + @Test + public void shouldNotCountAbortedTransaction() { + buildMockProducer(true); + producer.initTransactions(); + + producer.beginTransaction(); + producer.abortTransaction(); + + producer.beginTransaction(); + producer.commitTransaction(); + assertEquals(1L, producer.commitCount()); + } + + @Test + public void shouldThrowOnAbortIfTransactionsNotInitialized() { + buildMockProducer(true); + assertThrows(IllegalStateException.class, () -> producer.abortTransaction()); + } + + @Test + public void shouldThrowOnAbortTransactionIfNoTransactionGotStarted() { + buildMockProducer(true); + producer.initTransactions(); + assertThrows(IllegalStateException.class, producer::abortTransaction); + } + + @Test + public void shouldAbortEmptyTransaction() { + buildMockProducer(true); + producer.initTransactions(); + producer.beginTransaction(); + producer.abortTransaction(); + assertFalse(producer.transactionInFlight()); + assertTrue(producer.transactionAborted()); + assertFalse(producer.transactionCommitted()); + } + + @Test + public void shouldThrowFenceProducerIfTransactionsNotInitialized() { + buildMockProducer(true); + assertThrows(IllegalStateException.class, () -> producer.fenceProducer()); + } + + @Test + public void shouldThrowOnBeginTransactionsIfProducerGotFenced() { + buildMockProducer(true); + producer.initTransactions(); + producer.fenceProducer(); + assertThrows(ProducerFencedException.class, producer::beginTransaction); + } + + @Test + public void shouldThrowOnSendIfProducerGotFenced() { + buildMockProducer(true); + producer.initTransactions(); + producer.fenceProducer(); + Throwable e = assertThrows(KafkaException.class, () -> producer.send(null)); + assertTrue(e.getCause() instanceof ProducerFencedException, "The root cause of the exception should be ProducerFenced"); + } + + @Test + public void shouldThrowOnSendOffsetsToTransactionByGroupIdIfProducerGotFenced() { + buildMockProducer(true); + producer.initTransactions(); + producer.fenceProducer(); + assertThrows(ProducerFencedException.class, () -> producer.sendOffsetsToTransaction(null, new ConsumerGroupMetadata(groupId))); + } + + @Test + public void shouldThrowOnSendOffsetsToTransactionByGroupMetadataIfProducerGotFenced() { + buildMockProducer(true); + producer.initTransactions(); + producer.fenceProducer(); + assertThrows(ProducerFencedException.class, () -> producer.sendOffsetsToTransaction(null, new ConsumerGroupMetadata(groupId))); + } + + @Test + public void shouldThrowOnCommitTransactionIfProducerGotFenced() { + buildMockProducer(true); + producer.initTransactions(); + producer.fenceProducer(); + assertThrows(ProducerFencedException.class, producer::commitTransaction); + } + + @Test + public void shouldThrowOnAbortTransactionIfProducerGotFenced() { + buildMockProducer(true); + producer.initTransactions(); + producer.fenceProducer(); + assertThrows(ProducerFencedException.class, producer::abortTransaction); + } + + @Test + public void shouldPublishMessagesOnlyAfterCommitIfTransactionsAreEnabled() { + buildMockProducer(true); + producer.initTransactions(); + producer.beginTransaction(); + + producer.send(record1); + producer.send(record2); + + assertTrue(producer.history().isEmpty()); + + producer.commitTransaction(); + + List> expectedResult = new ArrayList<>(); + expectedResult.add(record1); + expectedResult.add(record2); + + assertEquals(expectedResult, 
producer.history()); + } + + @Test + public void shouldFlushOnCommitForNonAutoCompleteIfTransactionsAreEnabled() { + buildMockProducer(false); + producer.initTransactions(); + producer.beginTransaction(); + + Future md1 = producer.send(record1); + Future md2 = producer.send(record2); + + assertFalse(md1.isDone()); + assertFalse(md2.isDone()); + + producer.commitTransaction(); + + assertTrue(md1.isDone()); + assertTrue(md2.isDone()); + } + + @Test + public void shouldDropMessagesOnAbortIfTransactionsAreEnabled() { + buildMockProducer(true); + producer.initTransactions(); + + producer.beginTransaction(); + producer.send(record1); + producer.send(record2); + producer.abortTransaction(); + assertTrue(producer.history().isEmpty()); + + producer.beginTransaction(); + producer.commitTransaction(); + assertTrue(producer.history().isEmpty()); + } + + @Test + public void shouldThrowOnAbortForNonAutoCompleteIfTransactionsAreEnabled() { + buildMockProducer(false); + producer.initTransactions(); + producer.beginTransaction(); + + Future md1 = producer.send(record1); + assertFalse(md1.isDone()); + + producer.abortTransaction(); + assertTrue(md1.isDone()); + } + + @Test + public void shouldPreserveCommittedMessagesOnAbortIfTransactionsAreEnabled() { + buildMockProducer(true); + producer.initTransactions(); + + producer.beginTransaction(); + producer.send(record1); + producer.send(record2); + producer.commitTransaction(); + + producer.beginTransaction(); + producer.abortTransaction(); + + List> expectedResult = new ArrayList<>(); + expectedResult.add(record1); + expectedResult.add(record2); + + assertEquals(expectedResult, producer.history()); + } + + @Test + public void shouldPublishConsumerGroupOffsetsOnlyAfterCommitIfTransactionsAreEnabled() { + buildMockProducer(true); + producer.initTransactions(); + producer.beginTransaction(); + + String group1 = "g1"; + Map group1Commit = new HashMap() { + { + put(new TopicPartition(topic, 0), new OffsetAndMetadata(42L, null)); + put(new TopicPartition(topic, 1), new OffsetAndMetadata(73L, null)); + } + }; + String group2 = "g2"; + Map group2Commit = new HashMap() { + { + put(new TopicPartition(topic, 0), new OffsetAndMetadata(101L, null)); + put(new TopicPartition(topic, 1), new OffsetAndMetadata(21L, null)); + } + }; + producer.sendOffsetsToTransaction(group1Commit, new ConsumerGroupMetadata(group1)); + producer.sendOffsetsToTransaction(group2Commit, new ConsumerGroupMetadata(group2)); + + assertTrue(producer.consumerGroupOffsetsHistory().isEmpty()); + + Map> expectedResult = new HashMap<>(); + expectedResult.put(group1, group1Commit); + expectedResult.put(group2, group2Commit); + + producer.commitTransaction(); + assertEquals(Collections.singletonList(expectedResult), producer.consumerGroupOffsetsHistory()); + } + + @Deprecated + @Test + public void shouldThrowOnNullConsumerGroupIdWhenSendOffsetsToTransaction() { + buildMockProducer(true); + producer.initTransactions(); + producer.beginTransaction(); + assertThrows(NullPointerException.class, () -> producer.sendOffsetsToTransaction(Collections.emptyMap(), (String) null)); + } + + @Test + public void shouldThrowOnNullConsumerGroupMetadataWhenSendOffsetsToTransaction() { + buildMockProducer(true); + producer.initTransactions(); + producer.beginTransaction(); + assertThrows(NullPointerException.class, () -> producer.sendOffsetsToTransaction(Collections.emptyMap(), new ConsumerGroupMetadata(null))); + } + + @Deprecated + @Test + public void shouldIgnoreEmptyOffsetsWhenSendOffsetsToTransactionByGroupId() { + 
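+ // An empty offsets map is treated as a no-op: sentOffsets() must stay false after the call,
+ // both for the deprecated group-id overload here and for the group-metadata overload below.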
buildMockProducer(true); + producer.initTransactions(); + producer.beginTransaction(); + producer.sendOffsetsToTransaction(Collections.emptyMap(), "groupId"); + assertFalse(producer.sentOffsets()); + } + + @Test + public void shouldIgnoreEmptyOffsetsWhenSendOffsetsToTransactionByGroupMetadata() { + buildMockProducer(true); + producer.initTransactions(); + producer.beginTransaction(); + producer.sendOffsetsToTransaction(Collections.emptyMap(), new ConsumerGroupMetadata("groupId")); + assertFalse(producer.sentOffsets()); + } + + @Deprecated + @Test + public void shouldAddOffsetsWhenSendOffsetsToTransactionByGroupId() { + buildMockProducer(true); + producer.initTransactions(); + producer.beginTransaction(); + + assertFalse(producer.sentOffsets()); + + Map groupCommit = new HashMap() { + { + put(new TopicPartition(topic, 0), new OffsetAndMetadata(42L, null)); + } + }; + producer.sendOffsetsToTransaction(groupCommit, "groupId"); + assertTrue(producer.sentOffsets()); + } + + @Test + public void shouldAddOffsetsWhenSendOffsetsToTransactionByGroupMetadata() { + buildMockProducer(true); + producer.initTransactions(); + producer.beginTransaction(); + + assertFalse(producer.sentOffsets()); + + Map groupCommit = new HashMap() { + { + put(new TopicPartition(topic, 0), new OffsetAndMetadata(42L, null)); + } + }; + producer.sendOffsetsToTransaction(groupCommit, new ConsumerGroupMetadata("groupId")); + assertTrue(producer.sentOffsets()); + } + + @Test + public void shouldResetSentOffsetsFlagOnlyWhenBeginningNewTransaction() { + buildMockProducer(true); + producer.initTransactions(); + producer.beginTransaction(); + + assertFalse(producer.sentOffsets()); + + Map groupCommit = new HashMap() { + { + put(new TopicPartition(topic, 0), new OffsetAndMetadata(42L, null)); + } + }; + producer.sendOffsetsToTransaction(groupCommit, new ConsumerGroupMetadata("groupId")); + producer.commitTransaction(); // commit should not reset "sentOffsets" flag + assertTrue(producer.sentOffsets()); + + producer.beginTransaction(); + assertFalse(producer.sentOffsets()); + + producer.sendOffsetsToTransaction(groupCommit, new ConsumerGroupMetadata("groupId")); + producer.commitTransaction(); // commit should not reset "sentOffsets" flag + assertTrue(producer.sentOffsets()); + + producer.beginTransaction(); + assertFalse(producer.sentOffsets()); + } + + @Test + public void shouldPublishLatestAndCumulativeConsumerGroupOffsetsOnlyAfterCommitIfTransactionsAreEnabled() { + buildMockProducer(true); + producer.initTransactions(); + producer.beginTransaction(); + + String group = "g"; + Map groupCommit1 = new HashMap() { + { + put(new TopicPartition(topic, 0), new OffsetAndMetadata(42L, null)); + put(new TopicPartition(topic, 1), new OffsetAndMetadata(73L, null)); + } + }; + Map groupCommit2 = new HashMap() { + { + put(new TopicPartition(topic, 1), new OffsetAndMetadata(101L, null)); + put(new TopicPartition(topic, 2), new OffsetAndMetadata(21L, null)); + } + }; + producer.sendOffsetsToTransaction(groupCommit1, new ConsumerGroupMetadata(group)); + producer.sendOffsetsToTransaction(groupCommit2, new ConsumerGroupMetadata(group)); + + assertTrue(producer.consumerGroupOffsetsHistory().isEmpty()); + + Map> expectedResult = new HashMap<>(); + expectedResult.put(group, new HashMap() { + { + put(new TopicPartition(topic, 0), new OffsetAndMetadata(42L, null)); + put(new TopicPartition(topic, 1), new OffsetAndMetadata(101L, null)); + put(new TopicPartition(topic, 2), new OffsetAndMetadata(21L, null)); + } + }); + + producer.commitTransaction(); + 
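+ // Offsets sent twice for the same group within one transaction are merged: partition 1 keeps
+ // the later value (101L) while partitions 0 and 2 are carried over, so the history holds a
+ // single cumulative map for the group.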
assertEquals(Collections.singletonList(expectedResult), producer.consumerGroupOffsetsHistory()); + } + + @Test + public void shouldDropConsumerGroupOffsetsOnAbortIfTransactionsAreEnabled() { + buildMockProducer(true); + producer.initTransactions(); + producer.beginTransaction(); + + String group = "g"; + Map groupCommit = new HashMap() { + { + put(new TopicPartition(topic, 0), new OffsetAndMetadata(42L, null)); + put(new TopicPartition(topic, 1), new OffsetAndMetadata(73L, null)); + } + }; + producer.sendOffsetsToTransaction(groupCommit, new ConsumerGroupMetadata(group)); + producer.abortTransaction(); + + producer.beginTransaction(); + producer.commitTransaction(); + assertTrue(producer.consumerGroupOffsetsHistory().isEmpty()); + + producer.beginTransaction(); + producer.sendOffsetsToTransaction(groupCommit, new ConsumerGroupMetadata(group)); + producer.abortTransaction(); + + producer.beginTransaction(); + producer.commitTransaction(); + assertTrue(producer.consumerGroupOffsetsHistory().isEmpty()); + } + + @Test + public void shouldPreserveOffsetsFromCommitByGroupIdOnAbortIfTransactionsAreEnabled() { + buildMockProducer(true); + producer.initTransactions(); + producer.beginTransaction(); + + String group = "g"; + Map groupCommit = new HashMap() { + { + put(new TopicPartition(topic, 0), new OffsetAndMetadata(42L, null)); + put(new TopicPartition(topic, 1), new OffsetAndMetadata(73L, null)); + } + }; + producer.sendOffsetsToTransaction(groupCommit, new ConsumerGroupMetadata(group)); + producer.commitTransaction(); + + producer.beginTransaction(); + producer.abortTransaction(); + + Map> expectedResult = new HashMap<>(); + expectedResult.put(group, groupCommit); + + assertEquals(Collections.singletonList(expectedResult), producer.consumerGroupOffsetsHistory()); + } + + @Test + public void shouldPreserveOffsetsFromCommitByGroupMetadataOnAbortIfTransactionsAreEnabled() { + buildMockProducer(true); + producer.initTransactions(); + producer.beginTransaction(); + + String group = "g"; + Map groupCommit = new HashMap() { + { + put(new TopicPartition(topic, 0), new OffsetAndMetadata(42L, null)); + put(new TopicPartition(topic, 1), new OffsetAndMetadata(73L, null)); + } + }; + producer.sendOffsetsToTransaction(groupCommit, new ConsumerGroupMetadata(group)); + producer.commitTransaction(); + + producer.beginTransaction(); + + String group2 = "g2"; + Map groupCommit2 = new HashMap() { + { + put(new TopicPartition(topic, 2), new OffsetAndMetadata(53L, null)); + put(new TopicPartition(topic, 3), new OffsetAndMetadata(84L, null)); + } + }; + producer.sendOffsetsToTransaction(groupCommit2, new ConsumerGroupMetadata(group2)); + producer.abortTransaction(); + + Map> expectedResult = new HashMap<>(); + expectedResult.put(group, groupCommit); + + assertEquals(Collections.singletonList(expectedResult), producer.consumerGroupOffsetsHistory()); + } + + @Test + public void shouldThrowOnInitTransactionIfProducerIsClosed() { + buildMockProducer(true); + producer.close(); + assertThrows(IllegalStateException.class, producer::initTransactions); + } + + @Test + public void shouldThrowOnSendIfProducerIsClosed() { + buildMockProducer(true); + producer.close(); + assertThrows(IllegalStateException.class, () -> producer.send(null)); + } + + @Test + public void shouldThrowOnBeginTransactionIfProducerIsClosed() { + buildMockProducer(true); + producer.close(); + assertThrows(IllegalStateException.class, producer::beginTransaction); + } + + @Test + public void shouldThrowSendOffsetsToTransactionByGroupIdIfProducerIsClosed() { 
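+ // As with the other closed-producer cases in this block, any transactional call after close()
+ // is expected to fail fast with an IllegalStateException.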
+ buildMockProducer(true); + producer.close(); + assertThrows(IllegalStateException.class, () -> producer.sendOffsetsToTransaction(null, new ConsumerGroupMetadata(groupId))); + } + + @Test + public void shouldThrowSendOffsetsToTransactionByGroupMetadataIfProducerIsClosed() { + buildMockProducer(true); + producer.close(); + assertThrows(IllegalStateException.class, () -> producer.sendOffsetsToTransaction(null, new ConsumerGroupMetadata(groupId))); + } + + @Test + public void shouldThrowOnCommitTransactionIfProducerIsClosed() { + buildMockProducer(true); + producer.close(); + assertThrows(IllegalStateException.class, producer::commitTransaction); + } + + @Test + public void shouldThrowOnAbortTransactionIfProducerIsClosed() { + buildMockProducer(true); + producer.close(); + assertThrows(IllegalStateException.class, producer::abortTransaction); + } + + @Test + public void shouldThrowOnFenceProducerIfProducerIsClosed() { + buildMockProducer(true); + producer.close(); + assertThrows(IllegalStateException.class, producer::fenceProducer); + } + + @Test + public void shouldThrowOnFlushProducerIfProducerIsClosed() { + buildMockProducer(true); + producer.close(); + assertThrows(IllegalStateException.class, producer::flush); + } + + @Test + @SuppressWarnings("unchecked") + public void shouldThrowClassCastException() { + try (MockProducer customProducer = new MockProducer<>(true, new IntegerSerializer(), new StringSerializer())) { + assertThrows(ClassCastException.class, () -> customProducer.send(new ProducerRecord(topic, "key1", "value1"))); + } + } + + @Test + public void shouldBeFlushedIfNoBufferedRecords() { + buildMockProducer(true); + assertTrue(producer.flushed()); + } + + @Test + public void shouldBeFlushedWithAutoCompleteIfBufferedRecords() { + buildMockProducer(true); + producer.send(record1); + assertTrue(producer.flushed()); + } + + @Test + public void shouldNotBeFlushedWithNoAutoCompleteIfBufferedRecords() { + buildMockProducer(false); + producer.send(record1); + assertFalse(producer.flushed()); + } + + @Test + public void shouldNotBeFlushedAfterFlush() { + buildMockProducer(false); + producer.send(record1); + producer.flush(); + assertTrue(producer.flushed()); + } + + @Test + public void testMetadataOnException() throws InterruptedException { + buildMockProducer(false); + Future metadata = producer.send(record2, (md, exception) -> { + assertNotNull(md); + assertEquals(md.offset(), -1L, "Invalid offset"); + assertEquals(md.timestamp(), RecordBatch.NO_TIMESTAMP, "Invalid timestamp"); + assertEquals(md.serializedKeySize(), -1L, "Invalid Serialized Key size"); + assertEquals(md.serializedValueSize(), -1L, "Invalid Serialized value size"); + }); + IllegalArgumentException e = new IllegalArgumentException("dummy exception"); + assertTrue(producer.errorNext(e), "Complete the second request with an error"); + try { + metadata.get(); + fail("Something went wrong, expected an error"); + } catch (ExecutionException err) { + assertEquals(e, err.getCause()); + } + } + + private boolean isError(Future future) { + try { + future.get(); + return false; + } catch (Exception e) { + return true; + } + } +} diff --git a/clients/src/test/java/org/apache/kafka/clients/producer/ProducerConfigTest.java b/clients/src/test/java/org/apache/kafka/clients/producer/ProducerConfigTest.java new file mode 100644 index 0000000..a2f318b --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/clients/producer/ProducerConfigTest.java @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or 
more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients.producer; + +import org.apache.kafka.common.serialization.ByteArraySerializer; +import org.apache.kafka.common.serialization.Serializer; +import org.apache.kafka.common.serialization.StringSerializer; +import org.junit.jupiter.api.Test; + +import java.util.HashMap; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class ProducerConfigTest { + + private final Serializer keySerializer = new ByteArraySerializer(); + private final Serializer valueSerializer = new StringSerializer(); + private final Object keySerializerClass = keySerializer.getClass(); + private final Object valueSerializerClass = valueSerializer.getClass(); + + @Test + public void testAppendSerializerToConfig() { + Map configs = new HashMap<>(); + configs.put(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG, keySerializerClass); + configs.put(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG, valueSerializerClass); + Map newConfigs = ProducerConfig.appendSerializerToConfig(configs, null, null); + assertEquals(newConfigs.get(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG), keySerializerClass); + assertEquals(newConfigs.get(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG), valueSerializerClass); + + configs.clear(); + configs.put(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG, valueSerializerClass); + newConfigs = ProducerConfig.appendSerializerToConfig(configs, keySerializer, null); + assertEquals(newConfigs.get(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG), keySerializerClass); + assertEquals(newConfigs.get(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG), valueSerializerClass); + + configs.clear(); + configs.put(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG, keySerializerClass); + newConfigs = ProducerConfig.appendSerializerToConfig(configs, null, valueSerializer); + assertEquals(newConfigs.get(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG), keySerializerClass); + assertEquals(newConfigs.get(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG), valueSerializerClass); + + configs.clear(); + newConfigs = ProducerConfig.appendSerializerToConfig(configs, keySerializer, valueSerializer); + assertEquals(newConfigs.get(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG), keySerializerClass); + assertEquals(newConfigs.get(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG), valueSerializerClass); + } +} diff --git a/clients/src/test/java/org/apache/kafka/clients/producer/ProducerRecordTest.java b/clients/src/test/java/org/apache/kafka/clients/producer/ProducerRecordTest.java new file mode 100644 index 0000000..8cb18d6 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/clients/producer/ProducerRecordTest.java @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients.producer; + +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.fail; + +public class ProducerRecordTest { + + @Test + public void testEqualsAndHashCode() { + ProducerRecord producerRecord = new ProducerRecord<>("test", 1, "key", 1); + assertEquals(producerRecord, producerRecord); + assertEquals(producerRecord.hashCode(), producerRecord.hashCode()); + + ProducerRecord equalRecord = new ProducerRecord<>("test", 1, "key", 1); + assertEquals(producerRecord, equalRecord); + assertEquals(producerRecord.hashCode(), equalRecord.hashCode()); + + ProducerRecord topicMisMatch = new ProducerRecord<>("test-1", 1, "key", 1); + assertFalse(producerRecord.equals(topicMisMatch)); + + ProducerRecord partitionMismatch = new ProducerRecord<>("test", 2, "key", 1); + assertFalse(producerRecord.equals(partitionMismatch)); + + ProducerRecord keyMisMatch = new ProducerRecord<>("test", 1, "key-1", 1); + assertFalse(producerRecord.equals(keyMisMatch)); + + ProducerRecord valueMisMatch = new ProducerRecord<>("test", 1, "key", 2); + assertFalse(producerRecord.equals(valueMisMatch)); + + ProducerRecord nullFieldsRecord = new ProducerRecord<>("topic", null, null, null, null, null); + assertEquals(nullFieldsRecord, nullFieldsRecord); + assertEquals(nullFieldsRecord.hashCode(), nullFieldsRecord.hashCode()); + } + + @Test + public void testInvalidRecords() { + try { + new ProducerRecord<>(null, 0, "key", 1); + fail("Expected IllegalArgumentException to be raised because topic is null"); + } catch (IllegalArgumentException e) { + //expected + } + + try { + new ProducerRecord<>("test", 0, -1L, "key", 1); + fail("Expected IllegalArgumentException to be raised because of negative timestamp"); + } catch (IllegalArgumentException e) { + //expected + } + + try { + new ProducerRecord<>("test", -1, "key", 1); + fail("Expected IllegalArgumentException to be raised because of negative partition"); + } catch (IllegalArgumentException e) { + //expected + } + } + +} diff --git a/clients/src/test/java/org/apache/kafka/clients/producer/RecordMetadataTest.java b/clients/src/test/java/org/apache/kafka/clients/producer/RecordMetadataTest.java new file mode 100644 index 0000000..b6207fb --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/clients/producer/RecordMetadataTest.java @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients.producer; + +import org.apache.kafka.common.TopicPartition; +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; + +public class RecordMetadataTest { + + @Test + public void testConstructionWithMissingBatchIndex() { + TopicPartition tp = new TopicPartition("foo", 0); + long timestamp = 2340234L; + int keySize = 3; + int valueSize = 5; + + RecordMetadata metadata = new RecordMetadata(tp, -1L, -1, timestamp, keySize, valueSize); + assertEquals(tp.topic(), metadata.topic()); + assertEquals(tp.partition(), metadata.partition()); + assertEquals(timestamp, metadata.timestamp()); + assertFalse(metadata.hasOffset()); + assertEquals(-1L, metadata.offset()); + assertEquals(keySize, metadata.serializedKeySize()); + assertEquals(valueSize, metadata.serializedValueSize()); + } + + @Test + public void testConstructionWithBatchIndexOffset() { + TopicPartition tp = new TopicPartition("foo", 0); + long timestamp = 2340234L; + int keySize = 3; + int valueSize = 5; + long baseOffset = 15L; + int batchIndex = 3; + + RecordMetadata metadata = new RecordMetadata(tp, baseOffset, batchIndex, timestamp, keySize, valueSize); + assertEquals(tp.topic(), metadata.topic()); + assertEquals(tp.partition(), metadata.partition()); + assertEquals(timestamp, metadata.timestamp()); + assertEquals(baseOffset + batchIndex, metadata.offset()); + assertEquals(keySize, metadata.serializedKeySize()); + assertEquals(valueSize, metadata.serializedValueSize()); + } + + @Test + @Deprecated + public void testConstructionWithChecksum() { + TopicPartition tp = new TopicPartition("foo", 0); + long timestamp = 2340234L; + long baseOffset = 15L; + long batchIndex = 3L; + int keySize = 3; + int valueSize = 5; + + RecordMetadata metadata = new RecordMetadata(tp, baseOffset, batchIndex, timestamp, null, keySize, valueSize); + assertEquals(tp.topic(), metadata.topic()); + assertEquals(tp.partition(), metadata.partition()); + assertEquals(timestamp, metadata.timestamp()); + assertEquals(baseOffset + batchIndex, metadata.offset()); + assertEquals(keySize, metadata.serializedKeySize()); + assertEquals(valueSize, metadata.serializedValueSize()); + + long checksum = 133424L; + metadata = new RecordMetadata(tp, baseOffset, batchIndex, timestamp, checksum, keySize, valueSize); + assertEquals(tp.topic(), metadata.topic()); + assertEquals(tp.partition(), metadata.partition()); + assertEquals(timestamp, metadata.timestamp()); + assertEquals(baseOffset + batchIndex, metadata.offset()); + assertEquals(keySize, metadata.serializedKeySize()); + assertEquals(valueSize, metadata.serializedValueSize()); + } + +} diff --git a/clients/src/test/java/org/apache/kafka/clients/producer/RecordSendTest.java b/clients/src/test/java/org/apache/kafka/clients/producer/RecordSendTest.java new file mode 100644 index 0000000..86632ee --- /dev/null +++ 
b/clients/src/test/java/org/apache/kafka/clients/producer/RecordSendTest.java @@ -0,0 +1,106 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients.producer; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; + +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + + +import org.apache.kafka.clients.producer.internals.FutureRecordMetadata; +import org.apache.kafka.clients.producer.internals.ProduceRequestResult; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.errors.CorruptRecordException; +import org.apache.kafka.common.record.RecordBatch; +import org.apache.kafka.common.utils.Time; +import org.junit.jupiter.api.Test; + +public class RecordSendTest { + + private final TopicPartition topicPartition = new TopicPartition("test", 0); + private final long baseOffset = 45; + private final int relOffset = 5; + + /** + * Test that waiting on a request that never completes times out + */ + @Test + public void testTimeout() throws Exception { + ProduceRequestResult request = new ProduceRequestResult(topicPartition); + FutureRecordMetadata future = new FutureRecordMetadata(request, relOffset, + RecordBatch.NO_TIMESTAMP, 0, 0, Time.SYSTEM); + assertFalse(future.isDone(), "Request is not completed"); + try { + future.get(5, TimeUnit.MILLISECONDS); + fail("Should have thrown exception."); + } catch (TimeoutException e) { /* this is good */ + } + + request.set(baseOffset, RecordBatch.NO_TIMESTAMP, null); + request.done(); + assertTrue(future.isDone()); + assertEquals(baseOffset + relOffset, future.get().offset()); + } + + /** + * Test that an asynchronous request will eventually throw the right exception + */ + @Test + public void testError() throws Exception { + FutureRecordMetadata future = new FutureRecordMetadata(asyncRequest(baseOffset, new CorruptRecordException(), 50L), + relOffset, RecordBatch.NO_TIMESTAMP, 0, 0, Time.SYSTEM); + assertThrows(ExecutionException.class, future::get); + } + + /** + * Test that an asynchronous request will eventually return the right offset + */ + @Test + public void testBlocking() throws Exception { + FutureRecordMetadata future = new FutureRecordMetadata(asyncRequest(baseOffset, null, 50L), + relOffset, RecordBatch.NO_TIMESTAMP, 0, 0, Time.SYSTEM); + assertEquals(baseOffset + relOffset, future.get().offset()); + } + + /* create a new request result that will be completed after the given timeout */ + public ProduceRequestResult 
asyncRequest(final long baseOffset, final RuntimeException error, final long timeout) { + final ProduceRequestResult request = new ProduceRequestResult(topicPartition); + Thread thread = new Thread() { + public void run() { + try { + sleep(timeout); + if (error == null) { + request.set(baseOffset, RecordBatch.NO_TIMESTAMP, null); + } else { + request.set(-1L, RecordBatch.NO_TIMESTAMP, index -> error); + } + + request.done(); + } catch (InterruptedException e) { } + } + }; + thread.start(); + return request; + } + +} diff --git a/clients/src/test/java/org/apache/kafka/clients/producer/RoundRobinPartitionerTest.java b/clients/src/test/java/org/apache/kafka/clients/producer/RoundRobinPartitionerTest.java new file mode 100644 index 0000000..dfb98e9 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/clients/producer/RoundRobinPartitionerTest.java @@ -0,0 +1,128 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients.producer; + +import org.apache.kafka.common.Cluster; +import org.apache.kafka.common.Node; +import org.apache.kafka.common.PartitionInfo; +import org.junit.jupiter.api.Test; + +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import static java.util.Arrays.asList; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class RoundRobinPartitionerTest { + private final static Node[] NODES = new Node[] { + new Node(0, "localhost", 99), + new Node(1, "localhost", 100), + new Node(2, "localhost", 101) + }; + + @Test + public void testRoundRobinWithUnavailablePartitions() { + // Intentionally make the partition list not in partition order to test the edge + // cases. + List partitions = asList( + new PartitionInfo("test", 1, null, NODES, NODES), + new PartitionInfo("test", 2, NODES[1], NODES, NODES), + new PartitionInfo("test", 0, NODES[0], NODES, NODES)); + // When there are some unavailable partitions, we want to make sure that (1) we + // always pick an available partition, + // and (2) the available partitions are selected in a round robin way. 
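+ // Partition 1 has a null leader, so only partitions 0 and 2 may ever be returned; over 100
+ // sends the round-robin alternation should split evenly between them (50/50).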
+ int countForPart0 = 0; + int countForPart2 = 0; + Partitioner partitioner = new RoundRobinPartitioner(); + Cluster cluster = new Cluster("clusterId", asList(NODES[0], NODES[1], NODES[2]), partitions, + Collections.emptySet(), Collections.emptySet()); + for (int i = 1; i <= 100; i++) { + int part = partitioner.partition("test", null, null, null, null, cluster); + assertTrue(part == 0 || part == 2, "We should never choose a leader-less node in round robin"); + if (part == 0) + countForPart0++; + else + countForPart2++; + } + assertEquals(countForPart0, countForPart2, "The distribution between two available partitions should be even"); + } + + @Test + public void testRoundRobinWithKeyBytes() throws InterruptedException { + final String topicA = "topicA"; + final String topicB = "topicB"; + + List allPartitions = asList(new PartitionInfo(topicA, 0, NODES[0], NODES, NODES), + new PartitionInfo(topicA, 1, NODES[1], NODES, NODES), new PartitionInfo(topicA, 2, NODES[2], NODES, NODES), + new PartitionInfo(topicB, 0, NODES[0], NODES, NODES)); + Cluster testCluster = new Cluster("clusterId", asList(NODES[0], NODES[1], NODES[2]), allPartitions, + Collections.emptySet(), Collections.emptySet()); + + final Map partitionCount = new HashMap<>(); + + final byte[] keyBytes = "key".getBytes(); + Partitioner partitioner = new RoundRobinPartitioner(); + for (int i = 0; i < 30; ++i) { + int partition = partitioner.partition(topicA, null, keyBytes, null, null, testCluster); + Integer count = partitionCount.get(partition); + if (null == count) + count = 0; + partitionCount.put(partition, count + 1); + + if (i % 5 == 0) { + partitioner.partition(topicB, null, keyBytes, null, null, testCluster); + } + } + + assertEquals(10, partitionCount.get(0).intValue()); + assertEquals(10, partitionCount.get(1).intValue()); + assertEquals(10, partitionCount.get(2).intValue()); + } + + @Test + public void testRoundRobinWithNullKeyBytes() throws InterruptedException { + final String topicA = "topicA"; + final String topicB = "topicB"; + + List allPartitions = asList(new PartitionInfo(topicA, 0, NODES[0], NODES, NODES), + new PartitionInfo(topicA, 1, NODES[1], NODES, NODES), new PartitionInfo(topicA, 2, NODES[2], NODES, NODES), + new PartitionInfo(topicB, 0, NODES[0], NODES, NODES)); + Cluster testCluster = new Cluster("clusterId", asList(NODES[0], NODES[1], NODES[2]), allPartitions, + Collections.emptySet(), Collections.emptySet()); + + final Map partitionCount = new HashMap<>(); + + Partitioner partitioner = new RoundRobinPartitioner(); + for (int i = 0; i < 30; ++i) { + int partition = partitioner.partition(topicA, null, null, null, null, testCluster); + Integer count = partitionCount.get(partition); + if (null == count) + count = 0; + partitionCount.put(partition, count + 1); + + if (i % 5 == 0) { + partitioner.partition(topicB, null, null, null, null, testCluster); + } + } + + assertEquals(10, partitionCount.get(0).intValue()); + assertEquals(10, partitionCount.get(1).intValue()); + assertEquals(10, partitionCount.get(2).intValue()); + } +} diff --git a/clients/src/test/java/org/apache/kafka/clients/producer/UniformStickyPartitionerTest.java b/clients/src/test/java/org/apache/kafka/clients/producer/UniformStickyPartitionerTest.java new file mode 100644 index 0000000..0014bf8 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/clients/producer/UniformStickyPartitionerTest.java @@ -0,0 +1,207 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients.producer; + +import org.apache.kafka.common.Cluster; +import org.apache.kafka.common.Node; +import org.apache.kafka.common.PartitionInfo; +import org.junit.jupiter.api.Test; + +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import static java.util.Arrays.asList; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class UniformStickyPartitionerTest { + private final static Node[] NODES = new Node[] { + new Node(0, "localhost", 99), + new Node(1, "localhost", 100), + new Node(2, "localhost", 101) + }; + + private final static String TOPIC_A = "TOPIC_A"; + private final static String TOPIC_B = "TOPIC_B"; + + @Test + public void testRoundRobinWithUnavailablePartitions() { + // Intentionally make the partition list not in partition order to test the edge + // cases. + List partitions = asList( + new PartitionInfo("test", 1, null, NODES, NODES), + new PartitionInfo("test", 2, NODES[1], NODES, NODES), + new PartitionInfo("test", 0, NODES[0], NODES, NODES)); + // When there are some unavailable partitions, we want to make sure that (1) we + // always pick an available partition, + // and (2) the available partitions are selected in a sticky way. + int countForPart0 = 0; + int countForPart2 = 0; + int part = 0; + Partitioner partitioner = new UniformStickyPartitioner(); + Cluster cluster = new Cluster("clusterId", asList(NODES[0], NODES[1], NODES[2]), partitions, + Collections.emptySet(), Collections.emptySet()); + for (int i = 0; i < 50; i++) { + part = partitioner.partition("test", null, null, null, null, cluster); + assertTrue(part == 0 || part == 2, "We should never choose a leader-less node in round robin"); + if (part == 0) + countForPart0++; + else + countForPart2++; + } + // Simulates switching the sticky partition on a new batch. 
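+ // UniformStickyPartitioner keeps returning the same partition until onNewBatch() is called with the
+ // partition that just filled up, at which point it picks a new randomly chosen available partition.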
+ partitioner.onNewBatch("test", cluster, part); + for (int i = 1; i <= 50; i++) { + part = partitioner.partition("test", null, null, null, null, cluster); + assertTrue(part == 0 || part == 2, "We should never choose a leader-less node in round robin"); + if (part == 0) + countForPart0++; + else + countForPart2++; + } + assertEquals(countForPart0, countForPart2, "The distribution between two available partitions should be even"); + } + + @Test + public void testRoundRobinWithKeyBytes() throws InterruptedException { + List allPartitions = asList(new PartitionInfo(TOPIC_A, 0, NODES[0], NODES, NODES), + new PartitionInfo(TOPIC_A, 1, NODES[1], NODES, NODES), new PartitionInfo(TOPIC_A, 2, NODES[1], NODES, NODES), + new PartitionInfo(TOPIC_B, 0, NODES[0], NODES, NODES)); + Cluster testCluster = new Cluster("clusterId", asList(NODES[0], NODES[1], NODES[2]), allPartitions, + Collections.emptySet(), Collections.emptySet()); + + final Map partitionCount = new HashMap<>(); + + final byte[] keyBytes = "key".getBytes(); + int partition = 0; + Partitioner partitioner = new UniformStickyPartitioner(); + for (int i = 0; i < 30; ++i) { + partition = partitioner.partition(TOPIC_A, null, keyBytes, null, null, testCluster); + Integer count = partitionCount.get(partition); + if (null == count) + count = 0; + partitionCount.put(partition, count + 1); + + if (i % 5 == 0) { + partitioner.partition(TOPIC_B, null, keyBytes, null, null, testCluster); + } + } + // Simulate a batch filling up and switching the sticky partition. + partitioner.onNewBatch(TOPIC_A, testCluster, partition); + partitioner.onNewBatch(TOPIC_B, testCluster, 0); + + // Save old partition to ensure that the wrong partition does not trigger a new batch. + int oldPart = partition; + + for (int i = 0; i < 30; ++i) { + partition = partitioner.partition(TOPIC_A, null, keyBytes, null, null, testCluster); + Integer count = partitionCount.get(partition); + if (null == count) + count = 0; + partitionCount.put(partition, count + 1); + + if (i % 5 == 0) { + partitioner.partition(TOPIC_B, null, keyBytes, null, null, testCluster); + } + } + + int newPart = partition; + + // Attempt to switch the partition with the wrong previous partition. Sticky partition should not change. 
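+ // onNewBatch() only rotates the sticky partition when the supplied previous partition matches the one
+ // currently cached for the topic, so passing the stale oldPart should leave newPart in effect.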
+ partitioner.onNewBatch(TOPIC_A, testCluster, oldPart); + + for (int i = 0; i < 30; ++i) { + partition = partitioner.partition(TOPIC_A, null, keyBytes, null, null, testCluster); + Integer count = partitionCount.get(partition); + if (null == count) + count = 0; + partitionCount.put(partition, count + 1); + + if (i % 5 == 0) { + partitioner.partition(TOPIC_B, null, keyBytes, null, null, testCluster); + } + } + + assertEquals(30, partitionCount.get(oldPart).intValue()); + assertEquals(60, partitionCount.get(newPart).intValue()); + } + + @Test + public void testRoundRobinWithNullKeyBytes() throws InterruptedException { + List allPartitions = asList(new PartitionInfo(TOPIC_A, 0, NODES[0], NODES, NODES), + new PartitionInfo(TOPIC_A, 1, NODES[1], NODES, NODES), new PartitionInfo(TOPIC_A, 2, NODES[1], NODES, NODES), + new PartitionInfo(TOPIC_B, 0, NODES[0], NODES, NODES)); + Cluster testCluster = new Cluster("clusterId", asList(NODES[0], NODES[1], NODES[2]), allPartitions, + Collections.emptySet(), Collections.emptySet()); + + final Map partitionCount = new HashMap<>(); + + int partition = 0; + Partitioner partitioner = new UniformStickyPartitioner(); + for (int i = 0; i < 30; ++i) { + partition = partitioner.partition(TOPIC_A, null, null, null, null, testCluster); + Integer count = partitionCount.get(partition); + if (null == count) + count = 0; + partitionCount.put(partition, count + 1); + + if (i % 5 == 0) { + partitioner.partition(TOPIC_B, null, null, null, null, testCluster); + } + } + // Simulate a batch filling up and switching the sticky partition. + partitioner.onNewBatch(TOPIC_A, testCluster, partition); + partitioner.onNewBatch(TOPIC_B, testCluster, 0); + + // Save old partition to ensure that the wrong partition does not trigger a new batch. + int oldPart = partition; + + for (int i = 0; i < 30; ++i) { + partition = partitioner.partition(TOPIC_A, null, null, null, null, testCluster); + Integer count = partitionCount.get(partition); + if (null == count) + count = 0; + partitionCount.put(partition, count + 1); + + if (i % 5 == 0) { + partitioner.partition(TOPIC_B, null, null, null, null, testCluster); + } + } + + int newPart = partition; + + // Attempt to switch the partition with the wrong previous partition. Sticky partition should not change. + partitioner.onNewBatch(TOPIC_A, testCluster, oldPart); + + for (int i = 0; i < 30; ++i) { + partition = partitioner.partition(TOPIC_A, null, null, null, null, testCluster); + Integer count = partitionCount.get(partition); + if (null == count) + count = 0; + partitionCount.put(partition, count + 1); + + if (i % 5 == 0) { + partitioner.partition(TOPIC_B, null, null, null, null, testCluster); + } + } + + assertEquals(30, partitionCount.get(oldPart).intValue()); + assertEquals(60, partitionCount.get(newPart).intValue()); + } +} + diff --git a/clients/src/test/java/org/apache/kafka/clients/producer/internals/BufferPoolTest.java b/clients/src/test/java/org/apache/kafka/clients/producer/internals/BufferPoolTest.java new file mode 100644 index 0000000..eaf35f9 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/clients/producer/internals/BufferPoolTest.java @@ -0,0 +1,432 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients.producer.internals; + +import org.apache.kafka.clients.producer.BufferExhaustedException; +import org.apache.kafka.common.KafkaException; +import org.apache.kafka.common.metrics.Metrics; +import org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.test.TestUtils; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Test; + +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Deque; +import java.util.List; +import java.util.concurrent.Callable; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.Executors; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.locks.Condition; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; +import static org.mockito.ArgumentMatchers.anyLong; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.verify; + +public class BufferPoolTest { + private final MockTime time = new MockTime(); + private final Metrics metrics = new Metrics(time); + private final long maxBlockTimeMs = 10; + private final String metricGroup = "TestMetrics"; + + @AfterEach + public void teardown() { + this.metrics.close(); + } + + /** + * Test the simple non-blocking allocation paths + */ + @Test + public void testSimple() throws Exception { + long totalMemory = 64 * 1024; + int size = 1024; + BufferPool pool = new BufferPool(totalMemory, size, metrics, time, metricGroup); + ByteBuffer buffer = pool.allocate(size, maxBlockTimeMs); + assertEquals(size, buffer.limit(), "Buffer size should equal requested size."); + assertEquals(totalMemory - size, pool.unallocatedMemory(), "Unallocated memory should have shrunk"); + assertEquals(totalMemory - size, pool.availableMemory(), "Available memory should have shrunk"); + buffer.putInt(1); + buffer.flip(); + pool.deallocate(buffer); + assertEquals(totalMemory, pool.availableMemory(), "All memory should be available"); + assertEquals(totalMemory - size, pool.unallocatedMemory(), "But now some is on the free list"); + buffer = pool.allocate(size, maxBlockTimeMs); + assertEquals(0, buffer.position(), "Recycled buffer should be cleared."); + assertEquals(buffer.capacity(), buffer.limit(), "Recycled buffer should be cleared."); + pool.deallocate(buffer); + assertEquals(totalMemory, pool.availableMemory(), "All memory should be available"); + assertEquals(totalMemory - size, pool.unallocatedMemory(), "Still a single buffer on the free list"); + buffer 
= pool.allocate(2 * size, maxBlockTimeMs); + pool.deallocate(buffer); + assertEquals(totalMemory, pool.availableMemory(), "All memory should be available"); + assertEquals(totalMemory - size, pool.unallocatedMemory(), "Non-standard size didn't go to the free list."); + } + + /** + * Test that we cannot try to allocate more memory then we have in the whole pool + */ + @Test + public void testCantAllocateMoreMemoryThanWeHave() throws Exception { + BufferPool pool = new BufferPool(1024, 512, metrics, time, metricGroup); + ByteBuffer buffer = pool.allocate(1024, maxBlockTimeMs); + assertEquals(1024, buffer.limit()); + pool.deallocate(buffer); + assertThrows(IllegalArgumentException.class, () -> pool.allocate(1025, maxBlockTimeMs)); + } + + /** + * Test that delayed allocation blocks + */ + @Test + public void testDelayedAllocation() throws Exception { + BufferPool pool = new BufferPool(5 * 1024, 1024, metrics, time, metricGroup); + ByteBuffer buffer = pool.allocate(1024, maxBlockTimeMs); + CountDownLatch doDealloc = asyncDeallocate(pool, buffer); + CountDownLatch allocation = asyncAllocate(pool, 5 * 1024); + assertEquals(1L, allocation.getCount(), "Allocation shouldn't have happened yet, waiting on memory."); + doDealloc.countDown(); // return the memory + assertTrue(allocation.await(1, TimeUnit.SECONDS), "Allocation should succeed soon after de-allocation"); + } + + private CountDownLatch asyncDeallocate(final BufferPool pool, final ByteBuffer buffer) { + final CountDownLatch latch = new CountDownLatch(1); + Thread thread = new Thread(() -> { + try { + latch.await(); + } catch (InterruptedException e) { + e.printStackTrace(); + } + pool.deallocate(buffer); + }); + thread.start(); + return latch; + } + + private void delayedDeallocate(final BufferPool pool, final ByteBuffer buffer, final long delayMs) { + Thread thread = new Thread(() -> { + Time.SYSTEM.sleep(delayMs); + pool.deallocate(buffer); + }); + thread.start(); + } + + private CountDownLatch asyncAllocate(final BufferPool pool, final int size) { + final CountDownLatch completed = new CountDownLatch(1); + Thread thread = new Thread(() -> { + try { + pool.allocate(size, maxBlockTimeMs); + } catch (InterruptedException e) { + e.printStackTrace(); + } finally { + completed.countDown(); + } + }); + thread.start(); + return completed; + } + + /** + * Test if BufferExhausted exception is thrown when there is not enough memory to allocate and the elapsed + * time is greater than the max specified block time. + */ + @Test + public void testBufferExhaustedExceptionIsThrown() throws Exception { + BufferPool pool = new BufferPool(2, 1, metrics, time, metricGroup); + pool.allocate(1, maxBlockTimeMs); + assertThrows(BufferExhaustedException.class, () -> pool.allocate(2, maxBlockTimeMs)); + } + + /** + * Verify that a failed allocation attempt due to not enough memory finishes soon after the maxBlockTimeMs. 
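+ * The pool is deliberately sized at 10 bytes with three 1-byte buffers outstanding, so the 10-byte
+ * allocation attempted below can only succeed after all three delayed deallocations have completed.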
+ */ + @Test + public void testBlockTimeout() throws Exception { + BufferPool pool = new BufferPool(10, 1, metrics, Time.SYSTEM, metricGroup); + ByteBuffer buffer1 = pool.allocate(1, maxBlockTimeMs); + ByteBuffer buffer2 = pool.allocate(1, maxBlockTimeMs); + ByteBuffer buffer3 = pool.allocate(1, maxBlockTimeMs); + // The first two buffers will be de-allocated within maxBlockTimeMs since the most recent allocation + delayedDeallocate(pool, buffer1, maxBlockTimeMs / 2); + delayedDeallocate(pool, buffer2, maxBlockTimeMs); + // The third buffer will be de-allocated after maxBlockTimeMs since the most recent allocation + delayedDeallocate(pool, buffer3, maxBlockTimeMs / 2 * 5); + + long beginTimeMs = Time.SYSTEM.milliseconds(); + try { + pool.allocate(10, maxBlockTimeMs); + fail("The buffer allocated more memory than its maximum value 10"); + } catch (BufferExhaustedException e) { + // this is good + } + // Thread scheduling sometimes means that deallocation varies by this point + assertTrue(pool.availableMemory() >= 7 && pool.availableMemory() <= 10, "available memory " + pool.availableMemory()); + long durationMs = Time.SYSTEM.milliseconds() - beginTimeMs; + assertTrue(durationMs >= maxBlockTimeMs, "BufferExhaustedException should not throw before maxBlockTimeMs"); + assertTrue(durationMs < maxBlockTimeMs + 1000, "BufferExhaustedException should throw soon after maxBlockTimeMs"); + } + + /** + * Test if the waiter that is waiting on availability of more memory is cleaned up when a timeout occurs + */ + @Test + public void testCleanupMemoryAvailabilityWaiterOnBlockTimeout() throws Exception { + BufferPool pool = new BufferPool(2, 1, metrics, time, metricGroup); + pool.allocate(1, maxBlockTimeMs); + try { + pool.allocate(2, maxBlockTimeMs); + fail("The buffer allocated more memory than its maximum value 2"); + } catch (BufferExhaustedException e) { + // this is good + } + assertEquals(0, pool.queued()); + assertEquals(1, pool.availableMemory()); + } + + /** + * Test if the waiter that is waiting on availability of more memory is cleaned up when an interruption occurs + */ + @Test + public void testCleanupMemoryAvailabilityWaiterOnInterruption() throws Exception { + BufferPool pool = new BufferPool(2, 1, metrics, time, metricGroup); + long blockTime = 5000; + pool.allocate(1, maxBlockTimeMs); + Thread t1 = new Thread(new BufferPoolAllocator(pool, blockTime)); + Thread t2 = new Thread(new BufferPoolAllocator(pool, blockTime)); + // start thread t1 which will try to allocate more memory on to the Buffer pool + t1.start(); + // sleep for 500ms. Condition variable c1 associated with pool.allocate() by thread t1 will be inserted in the waiters queue. + Thread.sleep(500); + Deque waiters = pool.waiters(); + // get the condition object associated with pool.allocate() by thread t1 + Condition c1 = waiters.getFirst(); + // start thread t2 which will try to allocate more memory on to the Buffer pool + t2.start(); + // sleep for 500ms. Condition variable c2 associated with pool.allocate() by thread t2 will be inserted in the waiters queue. The waiters queue will have 2 entries c1 and c2. + Thread.sleep(500); + t1.interrupt(); + // sleep for 500ms. 
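+ // The pause gives thread t1 time to observe the interrupt, abandon its allocation, and have its
+ // waiter (c1) removed from the queue before the remaining waiter is inspected.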
+ Thread.sleep(500); + // get the condition object associated with allocate() by thread t2 + Condition c2 = waiters.getLast(); + t2.interrupt(); + assertNotEquals(c1, c2); + t1.join(); + t2.join(); + // both the allocate() called by threads t1 and t2 should have been interrupted and the waiters queue should be empty + assertEquals(pool.queued(), 0); + } + + @Test + public void testCleanupMemoryAvailabilityOnMetricsException() throws Exception { + BufferPool bufferPool = spy(new BufferPool(2, 1, new Metrics(), time, metricGroup)); + doThrow(new OutOfMemoryError()).when(bufferPool).recordWaitTime(anyLong()); + + bufferPool.allocate(1, 0); + try { + bufferPool.allocate(2, 1000); + fail("Expected oom."); + } catch (OutOfMemoryError expected) { + } + assertEquals(1, bufferPool.availableMemory()); + assertEquals(0, bufferPool.queued()); + assertEquals(1, bufferPool.unallocatedMemory()); + //This shouldn't timeout + bufferPool.allocate(1, 0); + + verify(bufferPool).recordWaitTime(anyLong()); + } + + private static class BufferPoolAllocator implements Runnable { + BufferPool pool; + long maxBlockTimeMs; + + BufferPoolAllocator(BufferPool pool, long maxBlockTimeMs) { + this.pool = pool; + this.maxBlockTimeMs = maxBlockTimeMs; + } + + @Override + public void run() { + try { + pool.allocate(2, maxBlockTimeMs); + fail("The buffer allocated more memory than its maximum value 2"); + } catch (BufferExhaustedException e) { + // this is good + } catch (InterruptedException e) { + // this can be neglected + } + } + } + + /** + * This test creates lots of threads that hammer on the pool + */ + @Test + public void testStressfulSituation() throws Exception { + int numThreads = 10; + final int iterations = 50000; + final int poolableSize = 1024; + final long totalMemory = numThreads / 2 * poolableSize; + final BufferPool pool = new BufferPool(totalMemory, poolableSize, metrics, time, metricGroup); + List threads = new ArrayList(); + for (int i = 0; i < numThreads; i++) + threads.add(new StressTestThread(pool, iterations)); + for (StressTestThread thread : threads) + thread.start(); + for (StressTestThread thread : threads) + thread.join(); + for (StressTestThread thread : threads) + assertTrue(thread.success.get(), "Thread should have completed all iterations successfully."); + assertEquals(totalMemory, pool.availableMemory()); + } + + @Test + public void testLargeAvailableMemory() throws Exception { + long memory = 20_000_000_000L; + int poolableSize = 2_000_000_000; + final AtomicInteger freeSize = new AtomicInteger(0); + BufferPool pool = new BufferPool(memory, poolableSize, metrics, time, metricGroup) { + @Override + protected ByteBuffer allocateByteBuffer(int size) { + // Ignore size to avoid OOM due to large buffers + return ByteBuffer.allocate(0); + } + + @Override + protected int freeSize() { + return freeSize.get(); + } + }; + pool.allocate(poolableSize, 0); + assertEquals(18_000_000_000L, pool.availableMemory()); + pool.allocate(poolableSize, 0); + assertEquals(16_000_000_000L, pool.availableMemory()); + + // Emulate `deallocate` by increasing `freeSize` + freeSize.incrementAndGet(); + assertEquals(18_000_000_000L, pool.availableMemory()); + freeSize.incrementAndGet(); + assertEquals(20_000_000_000L, pool.availableMemory()); + } + + @Test + public void outOfMemoryOnAllocation() { + BufferPool bufferPool = new BufferPool(1024, 1024, metrics, time, metricGroup) { + @Override + protected ByteBuffer allocateByteBuffer(int size) { + throw new OutOfMemoryError(); + } + }; + + try { + 
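+ // allocateByteBuffer is overridden above to always throw OutOfMemoryError, so this direct call is
+ // expected to fail without reserving any of the pool's memory.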
bufferPool.allocateByteBuffer(1024); + // should not reach here + fail("Should have thrown OutOfMemoryError"); + } catch (OutOfMemoryError ignored) { + + } + + assertEquals(bufferPool.availableMemory(), 1024); + } + + public static class StressTestThread extends Thread { + private final int iterations; + private final BufferPool pool; + private final long maxBlockTimeMs = 20_000; + public final AtomicBoolean success = new AtomicBoolean(false); + + public StressTestThread(BufferPool pool, int iterations) { + this.iterations = iterations; + this.pool = pool; + } + + @Override + public void run() { + try { + for (int i = 0; i < iterations; i++) { + int size; + if (TestUtils.RANDOM.nextBoolean()) + // allocate poolable size + size = pool.poolableSize(); + else + // allocate a random size + size = TestUtils.RANDOM.nextInt((int) pool.totalMemory()); + ByteBuffer buffer = pool.allocate(size, maxBlockTimeMs); + pool.deallocate(buffer); + } + success.set(true); + } catch (Exception e) { + e.printStackTrace(); + } + } + } + + @Test + public void testCloseAllocations() throws Exception { + BufferPool pool = new BufferPool(10, 1, metrics, Time.SYSTEM, metricGroup); + ByteBuffer buffer = pool.allocate(1, maxBlockTimeMs); + + // Close the buffer pool. This should prevent any further allocations. + pool.close(); + + assertThrows(KafkaException.class, () -> pool.allocate(1, maxBlockTimeMs)); + + // Ensure deallocation still works. + pool.deallocate(buffer); + } + + @Test + public void testCloseNotifyWaiters() throws Exception { + final int numWorkers = 2; + + BufferPool pool = new BufferPool(1, 1, metrics, Time.SYSTEM, metricGroup); + ByteBuffer buffer = pool.allocate(1, Long.MAX_VALUE); + + ExecutorService executor = Executors.newFixedThreadPool(numWorkers); + Callable work = new Callable() { + public Void call() throws Exception { + assertThrows(KafkaException.class, () -> pool.allocate(1, Long.MAX_VALUE)); + return null; + } + }; + for (int i = 0; i < numWorkers; ++i) { + executor.submit(work); + } + + TestUtils.waitForCondition(() -> pool.queued() == numWorkers, "Awaiting " + numWorkers + " workers to be blocked on allocation"); + + // Close the buffer pool. This should notify all waiters. + pool.close(); + + TestUtils.waitForCondition(() -> pool.queued() == 0, "Awaiting " + numWorkers + " workers to be interrupted from allocation"); + + pool.deallocate(buffer); + } + +} diff --git a/clients/src/test/java/org/apache/kafka/clients/producer/internals/DefaultPartitionerTest.java b/clients/src/test/java/org/apache/kafka/clients/producer/internals/DefaultPartitionerTest.java new file mode 100644 index 0000000..a55e5d2 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/clients/producer/internals/DefaultPartitionerTest.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.clients.producer.internals;
+
+import org.apache.kafka.clients.producer.Partitioner;
+import org.apache.kafka.common.Cluster;
+import org.apache.kafka.common.Node;
+import org.apache.kafka.common.PartitionInfo;
+import org.junit.jupiter.api.Test;
+
+import java.util.Collections;
+import java.util.List;
+import static java.util.Arrays.asList;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+
+public class DefaultPartitionerTest {
+    private final static byte[] KEY_BYTES = "key".getBytes();
+    private final static Node[] NODES = new Node[] {
+        new Node(0, "localhost", 99),
+        new Node(1, "localhost", 100),
+        new Node(12, "localhost", 101)
+    };
+    private final static String TOPIC = "test";
+    // Intentionally make the partition list not in partition order to test the edge cases.
+    private final static List<PartitionInfo> PARTITIONS = asList(new PartitionInfo(TOPIC, 1, null, NODES, NODES),
+            new PartitionInfo(TOPIC, 2, NODES[1], NODES, NODES),
+            new PartitionInfo(TOPIC, 0, NODES[0], NODES, NODES));
+
+    @Test
+    public void testKeyPartitionIsStable() {
+        final Partitioner partitioner = new DefaultPartitioner();
+        final Cluster cluster = new Cluster("clusterId", asList(NODES), PARTITIONS,
+            Collections.emptySet(), Collections.emptySet());
+        int partition = partitioner.partition("test", null, KEY_BYTES, null, null, cluster);
+        assertEquals(partition, partitioner.partition("test", null, KEY_BYTES, null, null, cluster), "Same key should yield same partition");
+    }
+}
diff --git a/clients/src/test/java/org/apache/kafka/clients/producer/internals/FutureRecordMetadataTest.java b/clients/src/test/java/org/apache/kafka/clients/producer/internals/FutureRecordMetadataTest.java
new file mode 100644
index 0000000..1fd7a43
--- /dev/null
+++ b/clients/src/test/java/org/apache/kafka/clients/producer/internals/FutureRecordMetadataTest.java
@@ -0,0 +1,81 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */ +package org.apache.kafka.clients.producer.internals; + +import org.apache.kafka.common.record.RecordBatch; +import org.apache.kafka.common.utils.MockTime; +import org.junit.jupiter.api.Test; + +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyLong; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +public class FutureRecordMetadataTest { + + private final MockTime time = new MockTime(); + + @Test + public void testFutureGetWithSeconds() throws ExecutionException, InterruptedException, TimeoutException { + ProduceRequestResult produceRequestResult = mockProduceRequestResult(); + FutureRecordMetadata future = futureRecordMetadata(produceRequestResult); + + ProduceRequestResult chainedProduceRequestResult = mockProduceRequestResult(); + future.chain(futureRecordMetadata(chainedProduceRequestResult)); + + future.get(1L, TimeUnit.SECONDS); + + verify(produceRequestResult).await(1L, TimeUnit.SECONDS); + verify(chainedProduceRequestResult).await(1000L, TimeUnit.MILLISECONDS); + } + + @Test + public void testFutureGetWithMilliSeconds() throws ExecutionException, InterruptedException, TimeoutException { + ProduceRequestResult produceRequestResult = mockProduceRequestResult(); + FutureRecordMetadata future = futureRecordMetadata(produceRequestResult); + + ProduceRequestResult chainedProduceRequestResult = mockProduceRequestResult(); + future.chain(futureRecordMetadata(chainedProduceRequestResult)); + + future.get(1000L, TimeUnit.MILLISECONDS); + + verify(produceRequestResult).await(1000L, TimeUnit.MILLISECONDS); + verify(chainedProduceRequestResult).await(1000L, TimeUnit.MILLISECONDS); + } + + private FutureRecordMetadata futureRecordMetadata(ProduceRequestResult produceRequestResult) { + return new FutureRecordMetadata( + produceRequestResult, + 0, + RecordBatch.NO_TIMESTAMP, + 0, + 0, + time + ); + } + + private ProduceRequestResult mockProduceRequestResult() throws InterruptedException { + ProduceRequestResult mockProduceRequestResult = mock(ProduceRequestResult.class); + when(mockProduceRequestResult.await(anyLong(), any(TimeUnit.class))).thenReturn(true); + return mockProduceRequestResult; + } +} diff --git a/clients/src/test/java/org/apache/kafka/clients/producer/internals/KafkaProducerMetricsTest.java b/clients/src/test/java/org/apache/kafka/clients/producer/internals/KafkaProducerMetricsTest.java new file mode 100644 index 0000000..e068861 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/clients/producer/internals/KafkaProducerMetricsTest.java @@ -0,0 +1,117 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.clients.producer.internals; + +import org.apache.kafka.common.metrics.Metrics; + +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNull; + +class KafkaProducerMetricsTest { + private static final long METRIC_VALUE = 123L; + private static final String FLUSH_TIME_TOTAL = "flush-time-ns-total"; + private static final String TXN_INIT_TIME_TOTAL = "txn-init-time-ns-total"; + private static final String TXN_BEGIN_TIME_TOTAL = "txn-begin-time-ns-total"; + private static final String TXN_COMMIT_TIME_TOTAL = "txn-commit-time-ns-total"; + private static final String TXN_ABORT_TIME_TOTAL = "txn-abort-time-ns-total"; + private static final String TXN_SEND_OFFSETS_TIME_TOTAL = "txn-send-offsets-time-ns-total"; + + private final Metrics metrics = new Metrics(); + private final KafkaProducerMetrics producerMetrics = new KafkaProducerMetrics(metrics); + + @Test + public void shouldRecordFlushTime() { + // When: + producerMetrics.recordFlush(METRIC_VALUE); + + // Then: + assertMetricValue(FLUSH_TIME_TOTAL); + } + + @Test + public void shouldRecordInitTime() { + // When: + producerMetrics.recordInit(METRIC_VALUE); + + // Then: + assertMetricValue(TXN_INIT_TIME_TOTAL); + } + + @Test + public void shouldRecordTxBeginTime() { + // When: + producerMetrics.recordBeginTxn(METRIC_VALUE); + + // Then: + assertMetricValue(TXN_BEGIN_TIME_TOTAL); + } + + @Test + public void shouldRecordTxCommitTime() { + // When: + producerMetrics.recordCommitTxn(METRIC_VALUE); + + // Then: + assertMetricValue(TXN_COMMIT_TIME_TOTAL); + } + + @Test + public void shouldRecordTxAbortTime() { + // When: + producerMetrics.recordAbortTxn(METRIC_VALUE); + + // Then: + assertMetricValue(TXN_ABORT_TIME_TOTAL); + } + + @Test + public void shouldRecordSendOffsetsTime() { + // When: + producerMetrics.recordSendOffsets(METRIC_VALUE); + + // Then: + assertMetricValue(TXN_SEND_OFFSETS_TIME_TOTAL); + } + + @Test + public void shouldRemoveMetricsOnClose() { + // When: + producerMetrics.close(); + + // Then: + assertMetricRemoved(FLUSH_TIME_TOTAL); + assertMetricRemoved(TXN_INIT_TIME_TOTAL); + assertMetricRemoved(TXN_BEGIN_TIME_TOTAL); + assertMetricRemoved(TXN_COMMIT_TIME_TOTAL); + assertMetricRemoved(TXN_ABORT_TIME_TOTAL); + assertMetricRemoved(TXN_SEND_OFFSETS_TIME_TOTAL); + } + + private void assertMetricRemoved(final String name) { + assertNull(metrics.metric(metrics.metricName(name, KafkaProducerMetrics.GROUP))); + } + + private void assertMetricValue(final String name) { + assertEquals( + metrics.metric(metrics.metricName(name, KafkaProducerMetrics.GROUP)).metricValue(), + (double) METRIC_VALUE + ); + } +} \ No newline at end of file diff --git a/clients/src/test/java/org/apache/kafka/clients/producer/internals/ProducerBatchTest.java b/clients/src/test/java/org/apache/kafka/clients/producer/internals/ProducerBatchTest.java new file mode 100644 index 0000000..6af8289 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/clients/producer/internals/ProducerBatchTest.java @@ -0,0 +1,307 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients.producer.internals; + +import org.apache.kafka.clients.producer.Callback; +import org.apache.kafka.clients.producer.RecordMetadata; +import org.apache.kafka.common.KafkaException; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.header.Header; +import org.apache.kafka.common.header.internals.RecordHeader; +import org.apache.kafka.common.record.CompressionType; +import org.apache.kafka.common.record.MemoryRecords; +import org.apache.kafka.common.record.MemoryRecordsBuilder; +import org.apache.kafka.common.record.Record; +import org.apache.kafka.common.record.RecordBatch; +import org.apache.kafka.common.record.TimestampType; +import org.apache.kafka.test.TestUtils; +import org.junit.jupiter.api.Test; + +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Deque; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ExecutionException; +import java.util.function.Function; + +import static org.apache.kafka.common.record.RecordBatch.MAGIC_VALUE_V0; +import static org.apache.kafka.common.record.RecordBatch.MAGIC_VALUE_V1; +import static org.apache.kafka.common.record.RecordBatch.MAGIC_VALUE_V2; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; + +public class ProducerBatchTest { + + private final long now = 1488748346917L; + + private final MemoryRecordsBuilder memoryRecordsBuilder = MemoryRecords.builder(ByteBuffer.allocate(512), + CompressionType.NONE, TimestampType.CREATE_TIME, 128); + + @Test + public void testBatchAbort() throws Exception { + ProducerBatch batch = new ProducerBatch(new TopicPartition("topic", 1), memoryRecordsBuilder, now); + MockCallback callback = new MockCallback(); + FutureRecordMetadata future = batch.tryAppend(now, null, new byte[10], Record.EMPTY_HEADERS, callback, now); + + KafkaException exception = new KafkaException(); + batch.abort(exception); + assertTrue(future.isDone()); + assertEquals(1, callback.invocations); + assertEquals(exception, callback.exception); + assertNull(callback.metadata); + + // subsequent completion should be ignored + assertFalse(batch.complete(500L, 2342342341L)); + assertFalse(batch.completeExceptionally(new KafkaException(), index -> new KafkaException())); + assertEquals(1, callback.invocations); + + assertTrue(future.isDone()); + try { + future.get(); + fail("Future should have thrown"); + } catch (ExecutionException e) { + assertEquals(exception, e.getCause()); + } + } + + @Test + public void testBatchCannotAbortTwice() throws 
Exception { + ProducerBatch batch = new ProducerBatch(new TopicPartition("topic", 1), memoryRecordsBuilder, now); + MockCallback callback = new MockCallback(); + FutureRecordMetadata future = batch.tryAppend(now, null, new byte[10], Record.EMPTY_HEADERS, callback, now); + KafkaException exception = new KafkaException(); + batch.abort(exception); + assertEquals(1, callback.invocations); + assertEquals(exception, callback.exception); + assertNull(callback.metadata); + + try { + batch.abort(new KafkaException()); + fail("Expected exception from abort"); + } catch (IllegalStateException e) { + // expected + } + + assertEquals(1, callback.invocations); + assertTrue(future.isDone()); + try { + future.get(); + fail("Future should have thrown"); + } catch (ExecutionException e) { + assertEquals(exception, e.getCause()); + } + } + + @Test + public void testBatchCannotCompleteTwice() throws Exception { + ProducerBatch batch = new ProducerBatch(new TopicPartition("topic", 1), memoryRecordsBuilder, now); + MockCallback callback = new MockCallback(); + FutureRecordMetadata future = batch.tryAppend(now, null, new byte[10], Record.EMPTY_HEADERS, callback, now); + batch.complete(500L, 10L); + assertEquals(1, callback.invocations); + assertNull(callback.exception); + assertNotNull(callback.metadata); + assertThrows(IllegalStateException.class, () -> batch.complete(1000L, 20L)); + RecordMetadata recordMetadata = future.get(); + assertEquals(500L, recordMetadata.offset()); + assertEquals(10L, recordMetadata.timestamp()); + } + + @Test + public void testSplitPreservesHeaders() { + for (CompressionType compressionType : CompressionType.values()) { + MemoryRecordsBuilder builder = MemoryRecords.builder( + ByteBuffer.allocate(1024), + MAGIC_VALUE_V2, + compressionType, + TimestampType.CREATE_TIME, + 0L); + ProducerBatch batch = new ProducerBatch(new TopicPartition("topic", 1), builder, now); + Header header = new RecordHeader("header-key", "header-value".getBytes()); + + while (true) { + FutureRecordMetadata future = batch.tryAppend( + now, "hi".getBytes(), "there".getBytes(), + new Header[]{header}, null, now); + if (future == null) { + break; + } + } + Deque batches = batch.split(200); + assertTrue(batches.size() >= 2, "This batch should be split to multiple small batches."); + + for (ProducerBatch splitProducerBatch : batches) { + for (RecordBatch splitBatch : splitProducerBatch.records().batches()) { + for (Record record : splitBatch) { + assertTrue(record.headers().length == 1, "Header size should be 1."); + assertTrue(record.headers()[0].key().equals("header-key"), "Header key should be 'header-key'."); + assertTrue(new String(record.headers()[0].value()).equals("header-value"), "Header value should be 'header-value'."); + } + } + } + } + } + + @Test + public void testSplitPreservesMagicAndCompressionType() { + for (byte magic : Arrays.asList(MAGIC_VALUE_V0, MAGIC_VALUE_V1, MAGIC_VALUE_V2)) { + for (CompressionType compressionType : CompressionType.values()) { + if (compressionType == CompressionType.NONE && magic < MAGIC_VALUE_V2) + continue; + + if (compressionType == CompressionType.ZSTD && magic < MAGIC_VALUE_V2) + continue; + + MemoryRecordsBuilder builder = MemoryRecords.builder(ByteBuffer.allocate(1024), magic, + compressionType, TimestampType.CREATE_TIME, 0L); + + ProducerBatch batch = new ProducerBatch(new TopicPartition("topic", 1), builder, now); + while (true) { + FutureRecordMetadata future = batch.tryAppend(now, "hi".getBytes(), "there".getBytes(), + Record.EMPTY_HEADERS, null, now); + if 
(future == null) + break; + } + + Deque batches = batch.split(512); + assertTrue(batches.size() >= 2); + + for (ProducerBatch splitProducerBatch : batches) { + assertEquals(magic, splitProducerBatch.magic()); + assertTrue(splitProducerBatch.isSplitBatch()); + + for (RecordBatch splitBatch : splitProducerBatch.records().batches()) { + assertEquals(magic, splitBatch.magic()); + assertEquals(0L, splitBatch.baseOffset()); + assertEquals(compressionType, splitBatch.compressionType()); + } + } + } + } + } + + /** + * A {@link ProducerBatch} configured using a timestamp preceding its create time is interpreted correctly + * as not expired by {@link ProducerBatch#hasReachedDeliveryTimeout(long, long)}. + */ + @Test + public void testBatchExpiration() { + long deliveryTimeoutMs = 10240; + ProducerBatch batch = new ProducerBatch(new TopicPartition("topic", 1), memoryRecordsBuilder, now); + // Set `now` to 2ms before the create time. + assertFalse(batch.hasReachedDeliveryTimeout(deliveryTimeoutMs, now - 2)); + // Set `now` to deliveryTimeoutMs. + assertTrue(batch.hasReachedDeliveryTimeout(deliveryTimeoutMs, now + deliveryTimeoutMs)); + } + + /** + * A {@link ProducerBatch} configured using a timestamp preceding its create time is interpreted correctly + * * as not expired by {@link ProducerBatch#hasReachedDeliveryTimeout(long, long)}. + */ + @Test + public void testBatchExpirationAfterReenqueue() { + ProducerBatch batch = new ProducerBatch(new TopicPartition("topic", 1), memoryRecordsBuilder, now); + // Set batch.retry = true + batch.reenqueued(now); + // Set `now` to 2ms before the create time. + assertFalse(batch.hasReachedDeliveryTimeout(10240, now - 2L)); + } + + @Test + public void testShouldNotAttemptAppendOnceRecordsBuilderIsClosedForAppends() { + ProducerBatch batch = new ProducerBatch(new TopicPartition("topic", 1), memoryRecordsBuilder, now); + FutureRecordMetadata result0 = batch.tryAppend(now, null, new byte[10], Record.EMPTY_HEADERS, null, now); + assertNotNull(result0); + assertTrue(memoryRecordsBuilder.hasRoomFor(now, null, new byte[10], Record.EMPTY_HEADERS)); + memoryRecordsBuilder.closeForRecordAppends(); + assertFalse(memoryRecordsBuilder.hasRoomFor(now, null, new byte[10], Record.EMPTY_HEADERS)); + assertNull(batch.tryAppend(now + 1, null, new byte[10], Record.EMPTY_HEADERS, null, now + 1)); + } + + @Test + public void testCompleteExceptionallyWithRecordErrors() { + int recordCount = 5; + RuntimeException topLevelException = new RuntimeException(); + + Map recordExceptionMap = new HashMap<>(); + recordExceptionMap.put(0, new RuntimeException()); + recordExceptionMap.put(3, new RuntimeException()); + + Function recordExceptions = batchIndex -> + recordExceptionMap.getOrDefault(batchIndex, topLevelException); + + testCompleteExceptionally(recordCount, topLevelException, recordExceptions); + } + + @Test + public void testCompleteExceptionallyWithNullRecordErrors() { + int recordCount = 5; + RuntimeException topLevelException = new RuntimeException(); + assertThrows(NullPointerException.class, () -> + testCompleteExceptionally(recordCount, topLevelException, null)); + } + + private void testCompleteExceptionally( + int recordCount, + RuntimeException topLevelException, + Function recordExceptions + ) { + ProducerBatch batch = new ProducerBatch( + new TopicPartition("topic", 1), + memoryRecordsBuilder, + now + ); + + List futures = new ArrayList<>(recordCount); + for (int i = 0; i < recordCount; i++) { + futures.add(batch.tryAppend(now, null, new byte[10], Record.EMPTY_HEADERS, null, 
now)); + } + assertEquals(recordCount, batch.recordCount); + + batch.completeExceptionally(topLevelException, recordExceptions); + assertTrue(batch.isDone()); + + for (int i = 0; i < futures.size(); i++) { + FutureRecordMetadata future = futures.get(i); + RuntimeException caughtException = TestUtils.assertFutureThrows(future, RuntimeException.class); + RuntimeException expectedException = recordExceptions.apply(i); + assertEquals(expectedException, caughtException); + } + } + + private static class MockCallback implements Callback { + private int invocations = 0; + private RecordMetadata metadata; + private Exception exception; + + @Override + public void onCompletion(RecordMetadata metadata, Exception exception) { + invocations++; + this.metadata = metadata; + this.exception = exception; + } + } + +} diff --git a/clients/src/test/java/org/apache/kafka/clients/producer/internals/ProducerInterceptorsTest.java b/clients/src/test/java/org/apache/kafka/clients/producer/internals/ProducerInterceptorsTest.java new file mode 100644 index 0000000..cd15a3e --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/clients/producer/internals/ProducerInterceptorsTest.java @@ -0,0 +1,208 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.clients.producer.internals; + + +import org.apache.kafka.clients.producer.ProducerInterceptor; +import org.apache.kafka.clients.producer.ProducerRecord; +import org.apache.kafka.clients.producer.RecordMetadata; +import org.apache.kafka.common.KafkaException; +import org.apache.kafka.common.TopicPartition; +import org.junit.jupiter.api.Test; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class ProducerInterceptorsTest { + private final TopicPartition tp = new TopicPartition("test", 0); + private final ProducerRecord producerRecord = new ProducerRecord<>("test", 0, 1, "value"); + private int onAckCount = 0; + private int onErrorAckCount = 0; + private int onErrorAckWithTopicSetCount = 0; + private int onErrorAckWithTopicPartitionSetCount = 0; + private int onSendCount = 0; + + private class AppendProducerInterceptor implements ProducerInterceptor { + private String appendStr = ""; + private boolean throwExceptionOnSend = false; + private boolean throwExceptionOnAck = false; + + public AppendProducerInterceptor(String appendStr) { + this.appendStr = appendStr; + } + + @Override + public void configure(Map configs) { + } + + @Override + public ProducerRecord onSend(ProducerRecord record) { + onSendCount++; + if (throwExceptionOnSend) + throw new KafkaException("Injected exception in AppendProducerInterceptor.onSend"); + + return new ProducerRecord<>( + record.topic(), record.partition(), record.key(), record.value().concat(appendStr)); + } + + @Override + public void onAcknowledgement(RecordMetadata metadata, Exception exception) { + onAckCount++; + if (exception != null) { + onErrorAckCount++; + // the length check is just to call topic() method and let it throw an exception + // if RecordMetadata.TopicPartition is null + if (metadata != null && metadata.topic().length() >= 0) { + onErrorAckWithTopicSetCount++; + if (metadata.partition() >= 0) + onErrorAckWithTopicPartitionSetCount++; + } + } + if (throwExceptionOnAck) + throw new KafkaException("Injected exception in AppendProducerInterceptor.onAcknowledgement"); + } + + @Override + public void close() { + } + + // if 'on' is true, onSend will always throw an exception + public void injectOnSendError(boolean on) { + throwExceptionOnSend = on; + } + + // if 'on' is true, onAcknowledgement will always throw an exception + public void injectOnAcknowledgementError(boolean on) { + throwExceptionOnAck = on; + } + } + + @Test + public void testOnSendChain() { + List> interceptorList = new ArrayList<>(); + // we are testing two different interceptors by configuring the same interceptor differently, which is not + // how it would be done in KafkaProducer, but ok for testing interceptor callbacks + AppendProducerInterceptor interceptor1 = new AppendProducerInterceptor("One"); + AppendProducerInterceptor interceptor2 = new AppendProducerInterceptor("Two"); + interceptorList.add(interceptor1); + interceptorList.add(interceptor2); + ProducerInterceptors interceptors = new ProducerInterceptors<>(interceptorList); + + // verify that onSend() mutates the record as expected + ProducerRecord interceptedRecord = interceptors.onSend(producerRecord); + assertEquals(2, onSendCount); + assertEquals(producerRecord.topic(), interceptedRecord.topic()); + assertEquals(producerRecord.partition(), interceptedRecord.partition()); + assertEquals(producerRecord.key(), interceptedRecord.key()); + assertEquals(interceptedRecord.value(), 
producerRecord.value().concat("One").concat("Two")); + + // onSend() mutates the same record the same way + ProducerRecord anotherRecord = interceptors.onSend(producerRecord); + assertEquals(4, onSendCount); + assertEquals(interceptedRecord, anotherRecord); + + // verify that if one of the interceptors throws an exception, other interceptors' callbacks are still called + interceptor1.injectOnSendError(true); + ProducerRecord partInterceptRecord = interceptors.onSend(producerRecord); + assertEquals(6, onSendCount); + assertEquals(partInterceptRecord.value(), producerRecord.value().concat("Two")); + + // verify the record remains valid if all onSend throws an exception + interceptor2.injectOnSendError(true); + ProducerRecord noInterceptRecord = interceptors.onSend(producerRecord); + assertEquals(producerRecord, noInterceptRecord); + + interceptors.close(); + } + + @Test + public void testOnAcknowledgementChain() { + List> interceptorList = new ArrayList<>(); + // we are testing two different interceptors by configuring the same interceptor differently, which is not + // how it would be done in KafkaProducer, but ok for testing interceptor callbacks + AppendProducerInterceptor interceptor1 = new AppendProducerInterceptor("One"); + AppendProducerInterceptor interceptor2 = new AppendProducerInterceptor("Two"); + interceptorList.add(interceptor1); + interceptorList.add(interceptor2); + ProducerInterceptors interceptors = new ProducerInterceptors<>(interceptorList); + + // verify onAck is called on all interceptors + RecordMetadata meta = new RecordMetadata(tp, 0, 0, 0, 0, 0); + interceptors.onAcknowledgement(meta, null); + assertEquals(2, onAckCount); + + // verify that onAcknowledgement exceptions do not propagate + interceptor1.injectOnAcknowledgementError(true); + interceptors.onAcknowledgement(meta, null); + assertEquals(4, onAckCount); + + interceptor2.injectOnAcknowledgementError(true); + interceptors.onAcknowledgement(meta, null); + assertEquals(6, onAckCount); + + interceptors.close(); + } + + @Test + public void testOnAcknowledgementWithErrorChain() { + List> interceptorList = new ArrayList<>(); + AppendProducerInterceptor interceptor1 = new AppendProducerInterceptor("One"); + interceptorList.add(interceptor1); + ProducerInterceptors interceptors = new ProducerInterceptors<>(interceptorList); + + // verify that metadata contains both topic and partition + interceptors.onSendError(producerRecord, + new TopicPartition(producerRecord.topic(), producerRecord.partition()), + new KafkaException("Test")); + assertEquals(1, onErrorAckCount); + assertEquals(1, onErrorAckWithTopicPartitionSetCount); + + // verify that metadata contains both topic and partition (because record already contains partition) + interceptors.onSendError(producerRecord, null, new KafkaException("Test")); + assertEquals(2, onErrorAckCount); + assertEquals(2, onErrorAckWithTopicPartitionSetCount); + + // if producer record does not contain partition, interceptor should get partition == -1 + ProducerRecord record2 = new ProducerRecord<>("test2", null, 1, "value"); + interceptors.onSendError(record2, null, new KafkaException("Test")); + assertEquals(3, onErrorAckCount); + assertEquals(3, onErrorAckWithTopicSetCount); + assertEquals(2, onErrorAckWithTopicPartitionSetCount); + + // if producer record does not contain partition, but topic/partition is passed to + // onSendError, then interceptor should get valid partition + int reassignedPartition = producerRecord.partition() + 1; + interceptors.onSendError(record2, + new 
TopicPartition(record2.topic(), reassignedPartition), + new KafkaException("Test")); + assertEquals(4, onErrorAckCount); + assertEquals(4, onErrorAckWithTopicSetCount); + assertEquals(3, onErrorAckWithTopicPartitionSetCount); + + // if both record and topic/partition are null, interceptor should not receive metadata + interceptors.onSendError(null, null, new KafkaException("Test")); + assertEquals(5, onErrorAckCount); + assertEquals(4, onErrorAckWithTopicSetCount); + assertEquals(3, onErrorAckWithTopicPartitionSetCount); + + interceptors.close(); + } +} + diff --git a/clients/src/test/java/org/apache/kafka/clients/producer/internals/ProducerMetadataTest.java b/clients/src/test/java/org/apache/kafka/clients/producer/internals/ProducerMetadataTest.java new file mode 100644 index 0000000..3b024a0 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/clients/producer/internals/ProducerMetadataTest.java @@ -0,0 +1,310 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients.producer.internals; + +import org.apache.kafka.common.KafkaException; +import org.apache.kafka.common.errors.AuthenticationException; +import org.apache.kafka.common.errors.TimeoutException; +import org.apache.kafka.common.internals.ClusterResourceListeners; +import org.apache.kafka.common.requests.MetadataResponse; +import org.apache.kafka.common.requests.RequestTestUtils; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.common.utils.Time; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Test; + +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.atomic.AtomicReference; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; + +public class ProducerMetadataTest { + private static final long METADATA_IDLE_MS = 60 * 1000; + private long refreshBackoffMs = 100; + private long metadataExpireMs = 1000; + private ProducerMetadata metadata = new ProducerMetadata(refreshBackoffMs, metadataExpireMs, METADATA_IDLE_MS, + new LogContext(), new ClusterResourceListeners(), Time.SYSTEM); + private AtomicReference backgroundError = new AtomicReference<>(); + + @AfterEach + public void tearDown() { + assertNull(backgroundError.get(), "Exception in background thread : " + backgroundError.get()); + } + + @Test + public void testMetadata() throws Exception { + long time = 
Time.SYSTEM.milliseconds(); + String topic = "my-topic"; + metadata.add(topic, time); + + metadata.updateWithCurrentRequestVersion(responseWithTopics(Collections.emptySet()), false, time); + assertTrue(metadata.timeToNextUpdate(time) > 0, "No update needed."); + metadata.requestUpdate(); + assertTrue(metadata.timeToNextUpdate(time) > 0, "Still no updated needed due to backoff"); + time += refreshBackoffMs; + assertEquals(0, metadata.timeToNextUpdate(time), "Update needed now that backoff time expired"); + Thread t1 = asyncFetch(topic, 500); + Thread t2 = asyncFetch(topic, 500); + assertTrue(t1.isAlive(), "Awaiting update"); + assertTrue(t2.isAlive(), "Awaiting update"); + // Perform metadata update when an update is requested on the async fetch thread + // This simulates the metadata update sequence in KafkaProducer + while (t1.isAlive() || t2.isAlive()) { + if (metadata.timeToNextUpdate(time) == 0) { + metadata.updateWithCurrentRequestVersion(responseWithCurrentTopics(), false, time); + time += refreshBackoffMs; + } + Thread.sleep(1); + } + t1.join(); + t2.join(); + assertTrue(metadata.timeToNextUpdate(time) > 0, "No update needed."); + time += metadataExpireMs; + assertEquals(0, metadata.timeToNextUpdate(time), "Update needed due to stale metadata."); + } + + @Test + public void testMetadataAwaitAfterClose() throws InterruptedException { + long time = 0; + metadata.updateWithCurrentRequestVersion(responseWithCurrentTopics(), false, time); + assertTrue(metadata.timeToNextUpdate(time) > 0, "No update needed."); + metadata.requestUpdate(); + assertTrue(metadata.timeToNextUpdate(time) > 0, "Still no updated needed due to backoff"); + time += refreshBackoffMs; + assertEquals(0, metadata.timeToNextUpdate(time), "Update needed now that backoff time expired"); + String topic = "my-topic"; + metadata.close(); + Thread t1 = asyncFetch(topic, 500); + t1.join(); + assertEquals(KafkaException.class, backgroundError.get().getClass()); + assertTrue(backgroundError.get().toString().contains("Requested metadata update after close")); + clearBackgroundError(); + } + + /** + * Tests that {@link org.apache.kafka.clients.producer.internals.ProducerMetadata#awaitUpdate(int, long)} doesn't + * wait forever with a max timeout value of 0 + * + * @throws Exception + * @see KAFKA-1836 + */ + @Test + public void testMetadataUpdateWaitTime() throws Exception { + long time = 0; + metadata.updateWithCurrentRequestVersion(responseWithCurrentTopics(), false, time); + assertTrue(metadata.timeToNextUpdate(time) > 0, "No update needed."); + // first try with a max wait time of 0 and ensure that this returns back without waiting forever + try { + metadata.awaitUpdate(metadata.requestUpdate(), 0); + fail("Wait on metadata update was expected to timeout, but it didn't"); + } catch (TimeoutException te) { + // expected + } + // now try with a higher timeout value once + final long twoSecondWait = 2000; + try { + metadata.awaitUpdate(metadata.requestUpdate(), twoSecondWait); + fail("Wait on metadata update was expected to timeout, but it didn't"); + } catch (TimeoutException te) { + // expected + } + } + + @Test + public void testTimeToNextUpdateOverwriteBackoff() { + long now = 10000; + + // New topic added to fetch set and update requested. It should allow immediate update. 
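+        // Adding a topic the metadata has not seen before marks it as "new", which is expected to
+        // bypass the refresh backoff, so timeToNextUpdate() should drop to 0 here.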
+ metadata.updateWithCurrentRequestVersion(responseWithCurrentTopics(), false, now); + metadata.add("new-topic", now); + assertEquals(0, metadata.timeToNextUpdate(now)); + + // Even though add is called, immediate update isn't necessary if the new topic set isn't + // containing a new topic, + metadata.updateWithCurrentRequestVersion(responseWithCurrentTopics(), false, now); + metadata.add("new-topic", now); + assertEquals(metadataExpireMs, metadata.timeToNextUpdate(now)); + + // If the new set of topics containing a new topic then it should allow immediate update. + metadata.add("another-new-topic", now); + assertEquals(0, metadata.timeToNextUpdate(now)); + } + + @Test + public void testTopicExpiry() { + // Test that topic is expired if not used within the expiry interval + long time = 0; + final String topic1 = "topic1"; + metadata.add(topic1, time); + metadata.updateWithCurrentRequestVersion(responseWithCurrentTopics(), false, time); + assertTrue(metadata.containsTopic(topic1)); + + time += METADATA_IDLE_MS; + metadata.updateWithCurrentRequestVersion(responseWithCurrentTopics(), false, time); + assertFalse(metadata.containsTopic(topic1), "Unused topic not expired"); + + // Test that topic is not expired if used within the expiry interval + final String topic2 = "topic2"; + metadata.add(topic2, time); + metadata.updateWithCurrentRequestVersion(responseWithCurrentTopics(), false, time); + for (int i = 0; i < 3; i++) { + time += METADATA_IDLE_MS / 2; + metadata.updateWithCurrentRequestVersion(responseWithCurrentTopics(), false, time); + assertTrue(metadata.containsTopic(topic2), "Topic expired even though in use"); + metadata.add(topic2, time); + } + + // Add a new topic, but update its metadata after the expiry would have occurred. + // The topic should still be retained. + final String topic3 = "topic3"; + metadata.add(topic3, time); + time += METADATA_IDLE_MS * 2; + metadata.updateWithCurrentRequestVersion(responseWithCurrentTopics(), false, time); + assertTrue(metadata.containsTopic(topic3), "Topic expired while awaiting metadata"); + } + + @Test + public void testMetadataWaitAbortedOnFatalException() { + metadata.fatalError(new AuthenticationException("Fatal exception from test")); + assertThrows(AuthenticationException.class, () -> metadata.awaitUpdate(0, 1000)); + } + + @Test + public void testMetadataPartialUpdate() { + long now = 10000; + + // Add a new topic and fetch its metadata in a partial update. + final String topic1 = "topic-one"; + metadata.add(topic1, now); + assertTrue(metadata.updateRequested()); + assertEquals(0, metadata.timeToNextUpdate(now)); + assertEquals(metadata.topics(), Collections.singleton(topic1)); + assertEquals(metadata.newTopics(), Collections.singleton(topic1)); + + // Perform the partial update. Verify the topic is no longer considered "new". + now += 1000; + metadata.updateWithCurrentRequestVersion(responseWithTopics(Collections.singleton(topic1)), true, now); + assertFalse(metadata.updateRequested()); + assertEquals(metadata.topics(), Collections.singleton(topic1)); + assertEquals(metadata.newTopics(), Collections.emptySet()); + + // Add the topic again. It should not be considered "new". + metadata.add(topic1, now); + assertFalse(metadata.updateRequested()); + assertTrue(metadata.timeToNextUpdate(now) > 0); + assertEquals(metadata.topics(), Collections.singleton(topic1)); + assertEquals(metadata.newTopics(), Collections.emptySet()); + + // Add two new topics. However, we'll only apply a partial update for one of them. 
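+        // Both topics should remain in newTopics() until each of them has appeared in a metadata response.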
+ now += 1000; + final String topic2 = "topic-two"; + metadata.add(topic2, now); + + now += 1000; + final String topic3 = "topic-three"; + metadata.add(topic3, now); + + assertTrue(metadata.updateRequested()); + assertEquals(0, metadata.timeToNextUpdate(now)); + assertEquals(metadata.topics(), new HashSet<>(Arrays.asList(topic1, topic2, topic3))); + assertEquals(metadata.newTopics(), new HashSet<>(Arrays.asList(topic2, topic3))); + + // Perform the partial update for a subset of the new topics. + now += 1000; + assertTrue(metadata.updateRequested()); + metadata.updateWithCurrentRequestVersion(responseWithTopics(Collections.singleton(topic2)), true, now); + assertEquals(metadata.topics(), new HashSet<>(Arrays.asList(topic1, topic2, topic3))); + assertEquals(metadata.newTopics(), Collections.singleton(topic3)); + } + + @Test + public void testRequestUpdateForTopic() { + long now = 10000; + + final String topic1 = "topic-1"; + final String topic2 = "topic-2"; + + // Add the topics to the metadata. + metadata.add(topic1, now); + metadata.add(topic2, now); + assertTrue(metadata.updateRequested()); + + // Request an update for topic1. Since the topic is considered new, it should not trigger + // the metadata to require a full update. + metadata.requestUpdateForTopic(topic1); + assertTrue(metadata.updateRequested()); + + // Perform the partial update. Verify no additional (full) updates are requested. + now += 1000; + metadata.updateWithCurrentRequestVersion(responseWithTopics(Collections.singleton(topic1)), true, now); + assertFalse(metadata.updateRequested()); + + // Request an update for topic1 again. Such a request may occur when the leader + // changes, which may affect many topics, and should therefore request a full update. + metadata.requestUpdateForTopic(topic1); + assertTrue(metadata.updateRequested()); + + // Perform a partial update for the topic. This should not clear the full update. + now += 1000; + metadata.updateWithCurrentRequestVersion(responseWithTopics(Collections.singleton(topic1)), true, now); + assertTrue(metadata.updateRequested()); + + // Perform the full update. This should clear the update request. 
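+        // A full (non-partial) update covers every tracked topic, so updateRequested() should return false afterwards.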
+ now += 1000; + metadata.updateWithCurrentRequestVersion(responseWithTopics(new HashSet<>(Arrays.asList(topic1, topic2))), false, now); + assertFalse(metadata.updateRequested()); + } + + private MetadataResponse responseWithCurrentTopics() { + return responseWithTopics(metadata.topics()); + } + + private MetadataResponse responseWithTopics(Set topics) { + Map partitionCounts = new HashMap<>(); + for (String topic : topics) + partitionCounts.put(topic, 1); + return RequestTestUtils.metadataUpdateWith(1, partitionCounts); + } + + private void clearBackgroundError() { + backgroundError.set(null); + } + + private Thread asyncFetch(final String topic, final long maxWaitMs) { + Thread thread = new Thread(() -> { + try { + while (metadata.fetch().partitionsForTopic(topic).isEmpty()) + metadata.awaitUpdate(metadata.requestUpdate(), maxWaitMs); + } catch (Exception e) { + backgroundError.set(e); + } + }); + thread.start(); + return thread; + } + +} diff --git a/clients/src/test/java/org/apache/kafka/clients/producer/internals/ProducerTestUtils.java b/clients/src/test/java/org/apache/kafka/clients/producer/internals/ProducerTestUtils.java new file mode 100644 index 0000000..a841033 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/clients/producer/internals/ProducerTestUtils.java @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients.producer.internals; + +import java.util.function.Supplier; + +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class ProducerTestUtils { + private static final int MAX_TRIES = 10; + + static void runUntil( + Sender sender, + Supplier condition + ) { + runUntil(sender, condition, MAX_TRIES); + } + + static void runUntil( + Sender sender, + Supplier condition, + int maxTries + ) { + int tries = 0; + while (!condition.get() && tries < maxTries) { + tries++; + sender.runOnce(); + } + assertTrue(condition.get(), "Condition not satisfied after " + maxTries + " tries"); + } +} diff --git a/clients/src/test/java/org/apache/kafka/clients/producer/internals/RecordAccumulatorTest.java b/clients/src/test/java/org/apache/kafka/clients/producer/internals/RecordAccumulatorTest.java new file mode 100644 index 0000000..06ed1ce --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/clients/producer/internals/RecordAccumulatorTest.java @@ -0,0 +1,1168 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.clients.producer.internals; + +import org.apache.kafka.clients.ApiVersions; +import org.apache.kafka.clients.NodeApiVersions; +import org.apache.kafka.clients.producer.Callback; +import org.apache.kafka.clients.producer.Partitioner; +import org.apache.kafka.clients.producer.RecordMetadata; +import org.apache.kafka.common.Cluster; +import org.apache.kafka.common.KafkaException; +import org.apache.kafka.common.Node; +import org.apache.kafka.common.PartitionInfo; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.errors.UnsupportedVersionException; +import org.apache.kafka.common.metrics.Metrics; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.record.CompressionRatioEstimator; +import org.apache.kafka.common.record.CompressionType; +import org.apache.kafka.common.record.DefaultRecord; +import org.apache.kafka.common.record.DefaultRecordBatch; +import org.apache.kafka.common.record.MemoryRecords; +import org.apache.kafka.common.record.MemoryRecordsBuilder; +import org.apache.kafka.common.record.MutableRecordBatch; +import org.apache.kafka.common.record.Record; +import org.apache.kafka.common.record.TimestampType; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.common.utils.ProducerIdAndEpoch; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.test.TestUtils; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Test; +import org.mockito.Mockito; + +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.Deque; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Random; +import java.util.Set; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Future; +import java.util.concurrent.atomic.AtomicInteger; + +import static java.util.Arrays.asList; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; + +public class RecordAccumulatorTest { + + private String topic = "test"; + private int partition1 = 0; + private int partition2 = 1; + private int partition3 = 2; + private Node node1 = new Node(0, "localhost", 1111); + private Node node2 = new Node(1, "localhost", 1112); + private TopicPartition tp1 = new TopicPartition(topic, partition1); + private TopicPartition tp2 = new TopicPartition(topic, partition2); + private TopicPartition tp3 = new TopicPartition(topic, partition3); + private PartitionInfo part1 = new PartitionInfo(topic, partition1, node1, null, null); + private PartitionInfo part2 = new PartitionInfo(topic, partition2, node1, null, null); + private PartitionInfo part3 = new PartitionInfo(topic, partition3, node2, null, null); + 
private MockTime time = new MockTime(); + private byte[] key = "key".getBytes(); + private byte[] value = "value".getBytes(); + private int msgSize = DefaultRecord.sizeInBytes(0, 0, key.length, value.length, Record.EMPTY_HEADERS); + private Cluster cluster = new Cluster(null, Arrays.asList(node1, node2), Arrays.asList(part1, part2, part3), + Collections.emptySet(), Collections.emptySet()); + private Metrics metrics = new Metrics(time); + private final long maxBlockTimeMs = 1000; + private final LogContext logContext = new LogContext(); + + @AfterEach + public void teardown() { + this.metrics.close(); + } + + @Test + public void testFull() throws Exception { + long now = time.milliseconds(); + + // test case assumes that the records do not fill the batch completely + int batchSize = 1025; + + RecordAccumulator accum = createTestRecordAccumulator( + batchSize + DefaultRecordBatch.RECORD_BATCH_OVERHEAD, 10L * batchSize, CompressionType.NONE, 10); + int appends = expectedNumAppends(batchSize); + for (int i = 0; i < appends; i++) { + // append to the first batch + accum.append(tp1, 0L, key, value, Record.EMPTY_HEADERS, null, maxBlockTimeMs, false, time.milliseconds()); + Deque partitionBatches = accum.batches().get(tp1); + assertEquals(1, partitionBatches.size()); + + ProducerBatch batch = partitionBatches.peekFirst(); + assertTrue(batch.isWritable()); + assertEquals(0, accum.ready(cluster, now).readyNodes.size(), "No partitions should be ready."); + } + + // this append doesn't fit in the first batch, so a new batch is created and the first batch is closed + + accum.append(tp1, 0L, key, value, Record.EMPTY_HEADERS, null, maxBlockTimeMs, false, time.milliseconds()); + Deque partitionBatches = accum.batches().get(tp1); + assertEquals(2, partitionBatches.size()); + Iterator partitionBatchesIterator = partitionBatches.iterator(); + assertTrue(partitionBatchesIterator.next().isWritable()); + assertEquals(Collections.singleton(node1), accum.ready(cluster, time.milliseconds()).readyNodes, "Our partition's leader should be ready"); + + List batches = accum.drain(cluster, Collections.singleton(node1), Integer.MAX_VALUE, 0).get(node1.id()); + assertEquals(1, batches.size()); + ProducerBatch batch = batches.get(0); + + Iterator iter = batch.records().records().iterator(); + for (int i = 0; i < appends; i++) { + Record record = iter.next(); + assertEquals(ByteBuffer.wrap(key), record.key(), "Keys should match"); + assertEquals(ByteBuffer.wrap(value), record.value(), "Values should match"); + } + assertFalse(iter.hasNext(), "No more records"); + } + + @Test + public void testAppendLargeCompressed() throws Exception { + testAppendLarge(CompressionType.GZIP); + } + + @Test + public void testAppendLargeNonCompressed() throws Exception { + testAppendLarge(CompressionType.NONE); + } + + private void testAppendLarge(CompressionType compressionType) throws Exception { + int batchSize = 512; + byte[] value = new byte[2 * batchSize]; + RecordAccumulator accum = createTestRecordAccumulator( + batchSize + DefaultRecordBatch.RECORD_BATCH_OVERHEAD, 10 * 1024, compressionType, 0); + accum.append(tp1, 0L, key, value, Record.EMPTY_HEADERS, null, maxBlockTimeMs, false, time.milliseconds()); + assertEquals(Collections.singleton(node1), accum.ready(cluster, time.milliseconds()).readyNodes, "Our partition's leader should be ready"); + + Deque batches = accum.batches().get(tp1); + assertEquals(1, batches.size()); + ProducerBatch producerBatch = batches.peek(); + List recordBatches = 
TestUtils.toList(producerBatch.records().batches()); + assertEquals(1, recordBatches.size()); + MutableRecordBatch recordBatch = recordBatches.get(0); + assertEquals(0L, recordBatch.baseOffset()); + List records = TestUtils.toList(recordBatch); + assertEquals(1, records.size()); + Record record = records.get(0); + assertEquals(0L, record.offset()); + assertEquals(ByteBuffer.wrap(key), record.key()); + assertEquals(ByteBuffer.wrap(value), record.value()); + assertEquals(0L, record.timestamp()); + } + + @Test + public void testAppendLargeOldMessageFormatCompressed() throws Exception { + testAppendLargeOldMessageFormat(CompressionType.GZIP); + } + + @Test + public void testAppendLargeOldMessageFormatNonCompressed() throws Exception { + testAppendLargeOldMessageFormat(CompressionType.NONE); + } + + private void testAppendLargeOldMessageFormat(CompressionType compressionType) throws Exception { + int batchSize = 512; + byte[] value = new byte[2 * batchSize]; + + ApiVersions apiVersions = new ApiVersions(); + apiVersions.update(node1.idString(), NodeApiVersions.create(ApiKeys.PRODUCE.id, (short) 0, (short) 2)); + + RecordAccumulator accum = createTestRecordAccumulator( + batchSize + DefaultRecordBatch.RECORD_BATCH_OVERHEAD, 10 * 1024, compressionType, 0); + accum.append(tp1, 0L, key, value, Record.EMPTY_HEADERS, null, maxBlockTimeMs, false, time.milliseconds()); + assertEquals(Collections.singleton(node1), accum.ready(cluster, time.milliseconds()).readyNodes, "Our partition's leader should be ready"); + + Deque batches = accum.batches().get(tp1); + assertEquals(1, batches.size()); + ProducerBatch producerBatch = batches.peek(); + List recordBatches = TestUtils.toList(producerBatch.records().batches()); + assertEquals(1, recordBatches.size()); + MutableRecordBatch recordBatch = recordBatches.get(0); + assertEquals(0L, recordBatch.baseOffset()); + List records = TestUtils.toList(recordBatch); + assertEquals(1, records.size()); + Record record = records.get(0); + assertEquals(0L, record.offset()); + assertEquals(ByteBuffer.wrap(key), record.key()); + assertEquals(ByteBuffer.wrap(value), record.value()); + assertEquals(0L, record.timestamp()); + } + + @Test + public void testLinger() throws Exception { + int lingerMs = 10; + RecordAccumulator accum = createTestRecordAccumulator( + 1024 + DefaultRecordBatch.RECORD_BATCH_OVERHEAD, 10 * 1024, CompressionType.NONE, lingerMs); + accum.append(tp1, 0L, key, value, Record.EMPTY_HEADERS, null, maxBlockTimeMs, false, time.milliseconds()); + assertEquals(0, accum.ready(cluster, time.milliseconds()).readyNodes.size(), "No partitions should be ready"); + time.sleep(10); + assertEquals(Collections.singleton(node1), accum.ready(cluster, time.milliseconds()).readyNodes, "Our partition's leader should be ready"); + List batches = accum.drain(cluster, Collections.singleton(node1), Integer.MAX_VALUE, 0).get(node1.id()); + assertEquals(1, batches.size()); + ProducerBatch batch = batches.get(0); + + Iterator iter = batch.records().records().iterator(); + Record record = iter.next(); + assertEquals(ByteBuffer.wrap(key), record.key(), "Keys should match"); + assertEquals(ByteBuffer.wrap(value), record.value(), "Values should match"); + assertFalse(iter.hasNext(), "No more records"); + } + + @Test + public void testPartialDrain() throws Exception { + RecordAccumulator accum = createTestRecordAccumulator( + 1024 + DefaultRecordBatch.RECORD_BATCH_OVERHEAD, 10 * 1024, CompressionType.NONE, 10); + int appends = 1024 / msgSize + 1; + List partitions = asList(tp1, tp2); + for 
(TopicPartition tp : partitions) { + for (int i = 0; i < appends; i++) + accum.append(tp, 0L, key, value, Record.EMPTY_HEADERS, null, maxBlockTimeMs, false, time.milliseconds()); + } + assertEquals(Collections.singleton(node1), accum.ready(cluster, time.milliseconds()).readyNodes, "Partition's leader should be ready"); + + List batches = accum.drain(cluster, Collections.singleton(node1), 1024, 0).get(node1.id()); + assertEquals(1, batches.size(), "But due to size bound only one partition should have been retrieved"); + } + + @SuppressWarnings("unused") + @Test + public void testStressfulSituation() throws Exception { + final int numThreads = 5; + final int msgs = 10000; + final int numParts = 2; + final RecordAccumulator accum = createTestRecordAccumulator( + 1024 + DefaultRecordBatch.RECORD_BATCH_OVERHEAD, 10 * 1024, CompressionType.NONE, 0); + List threads = new ArrayList<>(); + for (int i = 0; i < numThreads; i++) { + threads.add(new Thread() { + public void run() { + for (int i = 0; i < msgs; i++) { + try { + accum.append(new TopicPartition(topic, i % numParts), 0L, key, value, Record.EMPTY_HEADERS, null, maxBlockTimeMs, false, time.milliseconds()); + } catch (Exception e) { + e.printStackTrace(); + } + } + } + }); + } + for (Thread t : threads) + t.start(); + int read = 0; + long now = time.milliseconds(); + while (read < numThreads * msgs) { + Set nodes = accum.ready(cluster, now).readyNodes; + List batches = accum.drain(cluster, nodes, 5 * 1024, 0).get(node1.id()); + if (batches != null) { + for (ProducerBatch batch : batches) { + for (Record record : batch.records().records()) + read++; + accum.deallocate(batch); + } + } + } + + for (Thread t : threads) + t.join(); + } + + + @Test + public void testNextReadyCheckDelay() throws Exception { + // Next check time will use lingerMs since this test won't trigger any retries/backoff + int lingerMs = 10; + + // test case assumes that the records do not fill the batch completely + int batchSize = 1025; + + RecordAccumulator accum = createTestRecordAccumulator(batchSize + DefaultRecordBatch.RECORD_BATCH_OVERHEAD, + 10 * batchSize, CompressionType.NONE, lingerMs); + // Just short of going over the limit so we trigger linger time + int appends = expectedNumAppends(batchSize); + + // Partition on node1 only + for (int i = 0; i < appends; i++) + accum.append(tp1, 0L, key, value, Record.EMPTY_HEADERS, null, maxBlockTimeMs, false, time.milliseconds()); + RecordAccumulator.ReadyCheckResult result = accum.ready(cluster, time.milliseconds()); + assertEquals(0, result.readyNodes.size(), "No nodes should be ready."); + assertEquals(lingerMs, result.nextReadyCheckDelayMs, "Next check time should be the linger time"); + + time.sleep(lingerMs / 2); + + // Add partition on node2 only + for (int i = 0; i < appends; i++) + accum.append(tp3, 0L, key, value, Record.EMPTY_HEADERS, null, maxBlockTimeMs, false, time.milliseconds()); + result = accum.ready(cluster, time.milliseconds()); + assertEquals(0, result.readyNodes.size(), "No nodes should be ready."); + assertEquals(lingerMs / 2, result.nextReadyCheckDelayMs, "Next check time should be defined by node1, half remaining linger time"); + + // Add data for another partition on node1, enough to make data sendable immediately + for (int i = 0; i < appends + 1; i++) + accum.append(tp2, 0L, key, value, Record.EMPTY_HEADERS, null, maxBlockTimeMs, false, time.milliseconds()); + result = accum.ready(cluster, time.milliseconds()); + assertEquals(Collections.singleton(node1), result.readyNodes, "Node1 should be 
ready"); + // Note this can actually be < linger time because it may use delays from partitions that aren't sendable + // but have leaders with other sendable data. + assertTrue(result.nextReadyCheckDelayMs <= lingerMs, "Next check time should be defined by node2, at most linger time"); + } + + @Test + public void testRetryBackoff() throws Exception { + int lingerMs = Integer.MAX_VALUE / 16; + long retryBackoffMs = Integer.MAX_VALUE / 8; + int deliveryTimeoutMs = Integer.MAX_VALUE; + long totalSize = 10 * 1024; + int batchSize = 1024 + DefaultRecordBatch.RECORD_BATCH_OVERHEAD; + String metricGrpName = "producer-metrics"; + + final RecordAccumulator accum = new RecordAccumulator(logContext, batchSize, + CompressionType.NONE, lingerMs, retryBackoffMs, deliveryTimeoutMs, metrics, metricGrpName, time, new ApiVersions(), null, + new BufferPool(totalSize, batchSize, metrics, time, metricGrpName)); + + long now = time.milliseconds(); + accum.append(tp1, 0L, key, value, Record.EMPTY_HEADERS, null, maxBlockTimeMs, false, time.milliseconds()); + RecordAccumulator.ReadyCheckResult result = accum.ready(cluster, now + lingerMs + 1); + assertEquals(Collections.singleton(node1), result.readyNodes, "Node1 should be ready"); + Map> batches = accum.drain(cluster, result.readyNodes, Integer.MAX_VALUE, now + lingerMs + 1); + assertEquals(1, batches.size(), "Node1 should be the only ready node."); + assertEquals(1, batches.get(0).size(), "Partition 0 should only have one batch drained."); + + // Reenqueue the batch + now = time.milliseconds(); + accum.reenqueue(batches.get(0).get(0), now); + + // Put message for partition 1 into accumulator + accum.append(tp2, 0L, key, value, Record.EMPTY_HEADERS, null, maxBlockTimeMs, false, time.milliseconds()); + result = accum.ready(cluster, now + lingerMs + 1); + assertEquals(Collections.singleton(node1), result.readyNodes, "Node1 should be ready"); + + // tp1 should backoff while tp2 should not + batches = accum.drain(cluster, result.readyNodes, Integer.MAX_VALUE, now + lingerMs + 1); + assertEquals(1, batches.size(), "Node1 should be the only ready node."); + assertEquals(1, batches.get(0).size(), "Node1 should only have one batch drained."); + assertEquals(tp2, batches.get(0).get(0).topicPartition, "Node1 should only have one batch for partition 1."); + + // Partition 0 can be drained after retry backoff + result = accum.ready(cluster, now + retryBackoffMs + 1); + assertEquals(Collections.singleton(node1), result.readyNodes, "Node1 should be ready"); + batches = accum.drain(cluster, result.readyNodes, Integer.MAX_VALUE, now + retryBackoffMs + 1); + assertEquals(1, batches.size(), "Node1 should be the only ready node."); + assertEquals(1, batches.get(0).size(), "Node1 should only have one batch drained."); + assertEquals(tp1, batches.get(0).get(0).topicPartition, "Node1 should only have one batch for partition 0."); + } + + @Test + public void testFlush() throws Exception { + int lingerMs = Integer.MAX_VALUE; + final RecordAccumulator accum = createTestRecordAccumulator( + 4 * 1024 + DefaultRecordBatch.RECORD_BATCH_OVERHEAD, 64 * 1024, CompressionType.NONE, lingerMs); + + for (int i = 0; i < 100; i++) { + accum.append(new TopicPartition(topic, i % 3), 0L, key, value, Record.EMPTY_HEADERS, null, maxBlockTimeMs, false, time.milliseconds()); + assertTrue(accum.hasIncomplete()); + } + RecordAccumulator.ReadyCheckResult result = accum.ready(cluster, time.milliseconds()); + assertEquals(0, result.readyNodes.size(), "No nodes should be ready."); + + accum.beginFlush(); + 
result = accum.ready(cluster, time.milliseconds()); + + // drain and deallocate all batches + Map> results = accum.drain(cluster, result.readyNodes, Integer.MAX_VALUE, time.milliseconds()); + assertTrue(accum.hasIncomplete()); + + for (List batches: results.values()) + for (ProducerBatch batch: batches) + accum.deallocate(batch); + + // should be complete with no unsent records. + accum.awaitFlushCompletion(); + assertFalse(accum.hasUndrained()); + assertFalse(accum.hasIncomplete()); + } + + + private void delayedInterrupt(final Thread thread, final long delayMs) { + Thread t = new Thread() { + public void run() { + Time.SYSTEM.sleep(delayMs); + thread.interrupt(); + } + }; + t.start(); + } + + @Test + public void testAwaitFlushComplete() throws Exception { + RecordAccumulator accum = createTestRecordAccumulator( + 4 * 1024 + DefaultRecordBatch.RECORD_BATCH_OVERHEAD, 64 * 1024, CompressionType.NONE, Integer.MAX_VALUE); + accum.append(new TopicPartition(topic, 0), 0L, key, value, Record.EMPTY_HEADERS, null, maxBlockTimeMs, false, time.milliseconds()); + + accum.beginFlush(); + assertTrue(accum.flushInProgress()); + delayedInterrupt(Thread.currentThread(), 1000L); + try { + accum.awaitFlushCompletion(); + fail("awaitFlushCompletion should throw InterruptException"); + } catch (InterruptedException e) { + assertFalse(accum.flushInProgress(), "flushInProgress count should be decremented even if thread is interrupted"); + } + } + + @Test + public void testAbortIncompleteBatches() throws Exception { + int lingerMs = Integer.MAX_VALUE; + int numRecords = 100; + + final AtomicInteger numExceptionReceivedInCallback = new AtomicInteger(0); + final RecordAccumulator accum = createTestRecordAccumulator( + 128 + DefaultRecordBatch.RECORD_BATCH_OVERHEAD, 64 * 1024, CompressionType.NONE, lingerMs); + class TestCallback implements Callback { + @Override + public void onCompletion(RecordMetadata metadata, Exception exception) { + assertTrue(exception.getMessage().equals("Producer is closed forcefully.")); + numExceptionReceivedInCallback.incrementAndGet(); + } + } + for (int i = 0; i < numRecords; i++) + accum.append(new TopicPartition(topic, i % 3), 0L, key, value, null, new TestCallback(), maxBlockTimeMs, false, time.milliseconds()); + RecordAccumulator.ReadyCheckResult result = accum.ready(cluster, time.milliseconds()); + assertFalse(result.readyNodes.isEmpty()); + Map> drained = accum.drain(cluster, result.readyNodes, Integer.MAX_VALUE, time.milliseconds()); + assertTrue(accum.hasUndrained()); + assertTrue(accum.hasIncomplete()); + + int numDrainedRecords = 0; + for (Map.Entry> drainedEntry : drained.entrySet()) { + for (ProducerBatch batch : drainedEntry.getValue()) { + assertTrue(batch.isClosed()); + assertFalse(batch.produceFuture.completed()); + numDrainedRecords += batch.recordCount; + } + } + + assertTrue(numDrainedRecords > 0 && numDrainedRecords < numRecords); + accum.abortIncompleteBatches(); + assertEquals(numRecords, numExceptionReceivedInCallback.get()); + assertFalse(accum.hasUndrained()); + assertFalse(accum.hasIncomplete()); + } + + @Test + public void testAbortUnsentBatches() throws Exception { + int lingerMs = Integer.MAX_VALUE; + int numRecords = 100; + + final AtomicInteger numExceptionReceivedInCallback = new AtomicInteger(0); + final RecordAccumulator accum = createTestRecordAccumulator( + 128 + DefaultRecordBatch.RECORD_BATCH_OVERHEAD, 64 * 1024, CompressionType.NONE, lingerMs); + final KafkaException cause = new KafkaException(); + + class TestCallback implements Callback { + 
@Override + public void onCompletion(RecordMetadata metadata, Exception exception) { + assertEquals(cause, exception); + numExceptionReceivedInCallback.incrementAndGet(); + } + } + for (int i = 0; i < numRecords; i++) + accum.append(new TopicPartition(topic, i % 3), 0L, key, value, null, new TestCallback(), maxBlockTimeMs, false, time.milliseconds()); + RecordAccumulator.ReadyCheckResult result = accum.ready(cluster, time.milliseconds()); + assertFalse(result.readyNodes.isEmpty()); + Map> drained = accum.drain(cluster, result.readyNodes, Integer.MAX_VALUE, + time.milliseconds()); + assertTrue(accum.hasUndrained()); + assertTrue(accum.hasIncomplete()); + + accum.abortUndrainedBatches(cause); + int numDrainedRecords = 0; + for (Map.Entry> drainedEntry : drained.entrySet()) { + for (ProducerBatch batch : drainedEntry.getValue()) { + assertTrue(batch.isClosed()); + assertFalse(batch.produceFuture.completed()); + numDrainedRecords += batch.recordCount; + } + } + + assertTrue(numDrainedRecords > 0); + assertTrue(numExceptionReceivedInCallback.get() > 0); + assertEquals(numRecords, numExceptionReceivedInCallback.get() + numDrainedRecords); + assertFalse(accum.hasUndrained()); + assertTrue(accum.hasIncomplete()); + } + + private void doExpireBatchSingle(int deliveryTimeoutMs) throws InterruptedException { + int lingerMs = 300; + List muteStates = Arrays.asList(false, true); + Set readyNodes = null; + List expiredBatches = new ArrayList<>(); + // test case assumes that the records do not fill the batch completely + int batchSize = 1025; + RecordAccumulator accum = createTestRecordAccumulator(deliveryTimeoutMs, + batchSize + DefaultRecordBatch.RECORD_BATCH_OVERHEAD, 10 * batchSize, CompressionType.NONE, lingerMs); + + // Make the batches ready due to linger. These batches are not in retry + for (Boolean mute: muteStates) { + if (time.milliseconds() < System.currentTimeMillis()) + time.setCurrentTimeMs(System.currentTimeMillis()); + accum.append(tp1, 0L, key, value, Record.EMPTY_HEADERS, null, maxBlockTimeMs, false, time.milliseconds()); + assertEquals(0, accum.ready(cluster, time.milliseconds()).readyNodes.size(), "No partition should be ready."); + + time.sleep(lingerMs); + readyNodes = accum.ready(cluster, time.milliseconds()).readyNodes; + assertEquals(Collections.singleton(node1), readyNodes, "Our partition's leader should be ready"); + + expiredBatches = accum.expiredBatches(time.milliseconds()); + assertEquals(0, expiredBatches.size(), "The batch should not expire when just linger has passed"); + + if (mute) + accum.mutePartition(tp1); + else + accum.unmutePartition(tp1); + + // Advance the clock to expire the batch. 
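+            // The delivery timeout is measured from when the record was appended, and lingerMs of it has
+            // already elapsed above, so sleeping the remainder moves the clock past the batch's expiry deadline.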
+ time.sleep(deliveryTimeoutMs - lingerMs); + expiredBatches = accum.expiredBatches(time.milliseconds()); + assertEquals(1, expiredBatches.size(), "The batch may expire when the partition is muted"); + assertEquals(0, accum.ready(cluster, time.milliseconds()).readyNodes.size(), "No partitions should be ready."); + } + } + + @Test + public void testExpiredBatchSingle() throws InterruptedException { + doExpireBatchSingle(3200); + } + + @Test + public void testExpiredBatchSingleMaxValue() throws InterruptedException { + doExpireBatchSingle(Integer.MAX_VALUE); + } + + @Test + public void testExpiredBatches() throws InterruptedException { + long retryBackoffMs = 100L; + int lingerMs = 30; + int requestTimeout = 60; + int deliveryTimeoutMs = 3200; + + // test case assumes that the records do not fill the batch completely + int batchSize = 1025; + + RecordAccumulator accum = createTestRecordAccumulator( + deliveryTimeoutMs, batchSize + DefaultRecordBatch.RECORD_BATCH_OVERHEAD, 10 * batchSize, CompressionType.NONE, lingerMs); + int appends = expectedNumAppends(batchSize); + + // Test batches not in retry + for (int i = 0; i < appends; i++) { + accum.append(tp1, 0L, key, value, Record.EMPTY_HEADERS, null, maxBlockTimeMs, false, time.milliseconds()); + assertEquals(0, accum.ready(cluster, time.milliseconds()).readyNodes.size(), "No partitions should be ready."); + } + // Make the batches ready due to batch full + accum.append(tp1, 0L, key, value, Record.EMPTY_HEADERS, null, 0, false, time.milliseconds()); + Set readyNodes = accum.ready(cluster, time.milliseconds()).readyNodes; + assertEquals(Collections.singleton(node1), readyNodes, "Our partition's leader should be ready"); + // Advance the clock to expire the batch. + time.sleep(deliveryTimeoutMs + 1); + accum.mutePartition(tp1); + List expiredBatches = accum.expiredBatches(time.milliseconds()); + assertEquals(2, expiredBatches.size(), "The batches will be muted no matter if the partition is muted or not"); + + accum.unmutePartition(tp1); + expiredBatches = accum.expiredBatches(time.milliseconds()); + assertEquals(0, expiredBatches.size(), "All batches should have been expired earlier"); + assertEquals(0, accum.ready(cluster, time.milliseconds()).readyNodes.size(), "No partitions should be ready."); + + // Advance the clock to make the next batch ready due to linger.ms + time.sleep(lingerMs); + assertEquals(Collections.singleton(node1), readyNodes, "Our partition's leader should be ready"); + time.sleep(requestTimeout + 1); + + accum.mutePartition(tp1); + expiredBatches = accum.expiredBatches(time.milliseconds()); + assertEquals(0, expiredBatches.size(), "The batch should not be expired when metadata is still available and partition is muted"); + + accum.unmutePartition(tp1); + expiredBatches = accum.expiredBatches(time.milliseconds()); + assertEquals(0, expiredBatches.size(), "All batches should have been expired"); + assertEquals(0, accum.ready(cluster, time.milliseconds()).readyNodes.size(), "No partitions should be ready."); + + // Test batches in retry. 
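+        // A batch is considered "in retry" only after it has been drained once and then re-enqueued via reenqueue().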
+ // Create a retried batch + accum.append(tp1, 0L, key, value, Record.EMPTY_HEADERS, null, 0, false, time.milliseconds()); + time.sleep(lingerMs); + readyNodes = accum.ready(cluster, time.milliseconds()).readyNodes; + assertEquals(Collections.singleton(node1), readyNodes, "Our partition's leader should be ready"); + Map> drained = accum.drain(cluster, readyNodes, Integer.MAX_VALUE, time.milliseconds()); + assertEquals(drained.get(node1.id()).size(), 1, "There should be only one batch."); + time.sleep(1000L); + accum.reenqueue(drained.get(node1.id()).get(0), time.milliseconds()); + + // test expiration. + time.sleep(requestTimeout + retryBackoffMs); + expiredBatches = accum.expiredBatches(time.milliseconds()); + assertEquals(0, expiredBatches.size(), "The batch should not be expired."); + time.sleep(1L); + + accum.mutePartition(tp1); + expiredBatches = accum.expiredBatches(time.milliseconds()); + assertEquals(0, expiredBatches.size(), "The batch should not be expired when the partition is muted"); + + accum.unmutePartition(tp1); + expiredBatches = accum.expiredBatches(time.milliseconds()); + assertEquals(0, expiredBatches.size(), "All batches should have been expired."); + + // Test that when being throttled muted batches are expired before the throttle time is over. + accum.append(tp1, 0L, key, value, Record.EMPTY_HEADERS, null, 0, false, time.milliseconds()); + time.sleep(lingerMs); + readyNodes = accum.ready(cluster, time.milliseconds()).readyNodes; + assertEquals(Collections.singleton(node1), readyNodes, "Our partition's leader should be ready"); + // Advance the clock to expire the batch. + time.sleep(requestTimeout + 1); + accum.mutePartition(tp1); + expiredBatches = accum.expiredBatches(time.milliseconds()); + assertEquals(0, expiredBatches.size(), "The batch should not be expired when the partition is muted"); + + long throttleTimeMs = 100L; + accum.unmutePartition(tp1); + // The batch shouldn't be expired yet. + expiredBatches = accum.expiredBatches(time.milliseconds()); + assertEquals(0, expiredBatches.size(), "The batch should not be expired when the partition is muted"); + + // Once the throttle time is over, the batch can be expired. 
+ time.sleep(throttleTimeMs); + expiredBatches = accum.expiredBatches(time.milliseconds()); + assertEquals(0, expiredBatches.size(), "All batches should have been expired earlier"); + assertEquals(1, accum.ready(cluster, time.milliseconds()).readyNodes.size(), "No partitions should be ready."); + } + + @Test + public void testMutedPartitions() throws InterruptedException { + long now = time.milliseconds(); + // test case assumes that the records do not fill the batch completely + int batchSize = 1025; + + RecordAccumulator accum = createTestRecordAccumulator( + batchSize + DefaultRecordBatch.RECORD_BATCH_OVERHEAD, 10 * batchSize, CompressionType.NONE, 10); + int appends = expectedNumAppends(batchSize); + for (int i = 0; i < appends; i++) { + accum.append(tp1, 0L, key, value, Record.EMPTY_HEADERS, null, maxBlockTimeMs, false, time.milliseconds()); + assertEquals(0, accum.ready(cluster, now).readyNodes.size(), "No partitions should be ready."); + } + time.sleep(2000); + + // Test ready with muted partition + accum.mutePartition(tp1); + RecordAccumulator.ReadyCheckResult result = accum.ready(cluster, time.milliseconds()); + assertEquals(0, result.readyNodes.size(), "No node should be ready"); + + // Test ready without muted partition + accum.unmutePartition(tp1); + result = accum.ready(cluster, time.milliseconds()); + assertTrue(result.readyNodes.size() > 0, "The batch should be ready"); + + // Test drain with muted partition + accum.mutePartition(tp1); + Map> drained = accum.drain(cluster, result.readyNodes, Integer.MAX_VALUE, time.milliseconds()); + assertEquals(0, drained.get(node1.id()).size(), "No batch should have been drained"); + + // Test drain without muted partition. + accum.unmutePartition(tp1); + drained = accum.drain(cluster, result.readyNodes, Integer.MAX_VALUE, time.milliseconds()); + assertTrue(drained.get(node1.id()).size() > 0, "The batch should have been drained."); + } + + @Test + public void testIdempotenceWithOldMagic() { + // Simulate talking to an older broker, ie. one which supports a lower magic. 
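+        // Advertising only PRODUCE versions 0-2 implies the old v0/v1 message format, which has no fields for
+        // producer id/epoch or sequence numbers, so an idempotent append should fail with UnsupportedVersionException.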
+ ApiVersions apiVersions = new ApiVersions(); + int batchSize = 1025; + int deliveryTimeoutMs = 3200; + int lingerMs = 10; + long retryBackoffMs = 100L; + long totalSize = 10 * batchSize; + String metricGrpName = "producer-metrics"; + + apiVersions.update("foobar", NodeApiVersions.create(ApiKeys.PRODUCE.id, (short) 0, (short) 2)); + TransactionManager transactionManager = new TransactionManager(new LogContext(), null, 0, retryBackoffMs, apiVersions); + RecordAccumulator accum = new RecordAccumulator(logContext, batchSize + DefaultRecordBatch.RECORD_BATCH_OVERHEAD, + CompressionType.NONE, lingerMs, retryBackoffMs, deliveryTimeoutMs, metrics, metricGrpName, time, apiVersions, transactionManager, + new BufferPool(totalSize, batchSize, metrics, time, metricGrpName)); + assertThrows(UnsupportedVersionException.class, + () -> accum.append(tp1, 0L, key, value, Record.EMPTY_HEADERS, null, 0, false, time.milliseconds())); + } + + @Test + public void testRecordsDrainedWhenTransactionCompleting() throws Exception { + int batchSize = 1025; + int deliveryTimeoutMs = 3200; + int lingerMs = 10; + long totalSize = 10 * batchSize; + + TransactionManager transactionManager = Mockito.mock(TransactionManager.class); + RecordAccumulator accumulator = createTestRecordAccumulator(transactionManager, deliveryTimeoutMs, + batchSize, totalSize, CompressionType.NONE, lingerMs); + + ProducerIdAndEpoch producerIdAndEpoch = new ProducerIdAndEpoch(12345L, (short) 5); + Mockito.when(transactionManager.producerIdAndEpoch()).thenReturn(producerIdAndEpoch); + Mockito.when(transactionManager.isSendToPartitionAllowed(tp1)).thenReturn(true); + Mockito.when(transactionManager.isPartitionAdded(tp1)).thenReturn(true); + Mockito.when(transactionManager.firstInFlightSequence(tp1)).thenReturn(0); + + // Initially, the transaction is still in progress, so we should respect the linger. + Mockito.when(transactionManager.isCompleting()).thenReturn(false); + + accumulator.append(tp1, 0L, key, value, Record.EMPTY_HEADERS, null, maxBlockTimeMs, + false, time.milliseconds()); + accumulator.append(tp1, 0L, key, value, Record.EMPTY_HEADERS, null, maxBlockTimeMs, + false, time.milliseconds()); + assertTrue(accumulator.hasUndrained()); + + RecordAccumulator.ReadyCheckResult firstResult = accumulator.ready(cluster, time.milliseconds()); + assertEquals(0, firstResult.readyNodes.size()); + Map> firstDrained = accumulator.drain(cluster, firstResult.readyNodes, + Integer.MAX_VALUE, time.milliseconds()); + assertEquals(0, firstDrained.size()); + + // Once the transaction begins completion, then the batch should be drained immediately. 
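+        // i.e. while the commit/abort is pending, the accumulator is expected to ignore lingerMs so that
+        // any buffered records are flushed out before the transaction finishes.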
+ Mockito.when(transactionManager.isCompleting()).thenReturn(true); + + RecordAccumulator.ReadyCheckResult secondResult = accumulator.ready(cluster, time.milliseconds()); + assertEquals(1, secondResult.readyNodes.size()); + Node readyNode = secondResult.readyNodes.iterator().next(); + + Map> secondDrained = accumulator.drain(cluster, secondResult.readyNodes, + Integer.MAX_VALUE, time.milliseconds()); + assertEquals(Collections.singleton(readyNode.id()), secondDrained.keySet()); + List batches = secondDrained.get(readyNode.id()); + assertEquals(1, batches.size()); + } + + @Test + public void testSplitAndReenqueue() throws ExecutionException, InterruptedException { + long now = time.milliseconds(); + RecordAccumulator accum = createTestRecordAccumulator(1024, 10 * 1024, CompressionType.GZIP, 10); + + // Create a big batch + ByteBuffer buffer = ByteBuffer.allocate(4096); + MemoryRecordsBuilder builder = MemoryRecords.builder(buffer, CompressionType.NONE, TimestampType.CREATE_TIME, 0L); + ProducerBatch batch = new ProducerBatch(tp1, builder, now, true); + + byte[] value = new byte[1024]; + final AtomicInteger acked = new AtomicInteger(0); + Callback cb = new Callback() { + @Override + public void onCompletion(RecordMetadata metadata, Exception exception) { + acked.incrementAndGet(); + } + }; + // Append two messages so the batch is too big. + Future future1 = batch.tryAppend(now, null, value, Record.EMPTY_HEADERS, cb, now); + Future future2 = batch.tryAppend(now, null, value, Record.EMPTY_HEADERS, cb, now); + assertNotNull(future1); + assertNotNull(future2); + batch.close(); + // Enqueue the batch to the accumulator as if the batch was created by the accumulator. + accum.reenqueue(batch, now); + time.sleep(101L); + // Drain the batch. + RecordAccumulator.ReadyCheckResult result = accum.ready(cluster, time.milliseconds()); + assertTrue(result.readyNodes.size() > 0, "The batch should be ready"); + Map> drained = accum.drain(cluster, result.readyNodes, Integer.MAX_VALUE, time.milliseconds()); + assertEquals(1, drained.size(), "Only node1 should be drained"); + assertEquals(1, drained.get(node1.id()).size(), "Only one batch should be drained"); + // Split and reenqueue the batch. + accum.splitAndReenqueue(drained.get(node1.id()).get(0)); + time.sleep(101L); + + drained = accum.drain(cluster, result.readyNodes, Integer.MAX_VALUE, time.milliseconds()); + assertFalse(drained.isEmpty()); + assertFalse(drained.get(node1.id()).isEmpty()); + drained.get(node1.id()).get(0).complete(acked.get(), 100L); + assertEquals(1, acked.get(), "The first message should have been acked."); + assertTrue(future1.isDone()); + assertEquals(0, future1.get().offset()); + + drained = accum.drain(cluster, result.readyNodes, Integer.MAX_VALUE, time.milliseconds()); + assertFalse(drained.isEmpty()); + assertFalse(drained.get(node1.id()).isEmpty()); + drained.get(node1.id()).get(0).complete(acked.get(), 100L); + assertEquals(2, acked.get(), "Both message should have been acked."); + assertTrue(future2.isDone()); + assertEquals(1, future2.get().offset()); + } + + @Test + public void testSplitBatchOffAccumulator() throws InterruptedException { + long seed = System.currentTimeMillis(); + final int batchSize = 1024; + final int bufferCapacity = 3 * 1024; + + // First set the compression ratio estimation to be good. 
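+        // An optimistic (low) ratio estimate lets many records be packed into one batch; the poorly
+        // compressible payloads appended below then overshoot batchSize and force the batch to be split.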
+ CompressionRatioEstimator.setEstimation(tp1.topic(), CompressionType.GZIP, 0.1f); + RecordAccumulator accum = createTestRecordAccumulator(batchSize, bufferCapacity, CompressionType.GZIP, 0); + int numSplitBatches = prepareSplitBatches(accum, seed, 100, 20); + assertTrue(numSplitBatches > 0, "There should be some split batches"); + // Drain all the split batches. + RecordAccumulator.ReadyCheckResult result = accum.ready(cluster, time.milliseconds()); + for (int i = 0; i < numSplitBatches; i++) { + Map> drained = + accum.drain(cluster, result.readyNodes, Integer.MAX_VALUE, time.milliseconds()); + assertFalse(drained.isEmpty()); + assertFalse(drained.get(node1.id()).isEmpty()); + } + assertTrue(accum.ready(cluster, time.milliseconds()).readyNodes.isEmpty(), "All the batches should have been drained."); + assertEquals(bufferCapacity, accum.bufferPoolAvailableMemory(), + "The split batches should be allocated off the accumulator"); + } + + @Test + public void testSplitFrequency() throws InterruptedException { + long seed = System.currentTimeMillis(); + Random random = new Random(); + random.setSeed(seed); + final int batchSize = 1024; + final int numMessages = 1000; + + RecordAccumulator accum = createTestRecordAccumulator(batchSize, 3 * 1024, CompressionType.GZIP, 10); + // Adjust the high and low compression ratio message percentage + for (int goodCompRatioPercentage = 1; goodCompRatioPercentage < 100; goodCompRatioPercentage++) { + int numSplit = 0; + int numBatches = 0; + CompressionRatioEstimator.resetEstimation(topic); + for (int i = 0; i < numMessages; i++) { + int dice = random.nextInt(100); + byte[] value = (dice < goodCompRatioPercentage) ? + bytesWithGoodCompression(random) : bytesWithPoorCompression(random, 100); + accum.append(tp1, 0L, null, value, Record.EMPTY_HEADERS, null, 0, false, time.milliseconds()); + BatchDrainedResult result = completeOrSplitBatches(accum, batchSize); + numSplit += result.numSplit; + numBatches += result.numBatches; + } + time.sleep(10); + BatchDrainedResult result = completeOrSplitBatches(accum, batchSize); + numSplit += result.numSplit; + numBatches += result.numBatches; + assertTrue((double) numSplit / numBatches < 0.1f, String.format("Total num batches = %d, split batches = %d, more than 10%% of the batch splits. " + + "Random seed is " + seed, + numBatches, numSplit)); + } + } + + @Test + public void testSoonToExpireBatchesArePickedUpForExpiry() throws InterruptedException { + int lingerMs = 500; + int batchSize = 1025; + + RecordAccumulator accum = createTestRecordAccumulator( + batchSize + DefaultRecordBatch.RECORD_BATCH_OVERHEAD, 10 * batchSize, CompressionType.NONE, lingerMs); + + accum.append(tp1, 0L, key, value, Record.EMPTY_HEADERS, null, maxBlockTimeMs, false, time.milliseconds()); + Set readyNodes = accum.ready(cluster, time.milliseconds()).readyNodes; + Map> drained = accum.drain(cluster, readyNodes, Integer.MAX_VALUE, time.milliseconds()); + assertTrue(drained.isEmpty()); + //assertTrue(accum.soonToExpireInFlightBatches().isEmpty()); + + // advanced clock and send one batch out but it should not be included in soon to expire inflight + // batches because batch's expiry is quite far. 
+        time.sleep(lingerMs + 1);
+        readyNodes = accum.ready(cluster, time.milliseconds()).readyNodes;
+        drained = accum.drain(cluster, readyNodes, Integer.MAX_VALUE, time.milliseconds());
+        assertEquals(1, drained.size(), "A batch did not drain after linger");
+        //assertTrue(accum.soonToExpireInFlightBatches().isEmpty());
+
+        // Queue another batch and advance clock such that batch expiry time is earlier than request timeout.
+        accum.append(tp2, 0L, key, value, Record.EMPTY_HEADERS, null, maxBlockTimeMs, false, time.milliseconds());
+        time.sleep(lingerMs * 4);
+
+        // Now drain and check that accumulator picked up the drained batch because its expiry is soon.
+        readyNodes = accum.ready(cluster, time.milliseconds()).readyNodes;
+        drained = accum.drain(cluster, readyNodes, Integer.MAX_VALUE, time.milliseconds());
+        assertEquals(1, drained.size(), "A batch did not drain after linger");
+    }
+
+    @Test
+    public void testExpiredBatchesRetry() throws InterruptedException {
+        int lingerMs = 3000;
+        int rtt = 1000;
+        int deliveryTimeoutMs = 3200;
+        Set<Node> readyNodes;
+        List<ProducerBatch> expiredBatches;
+        List<Boolean> muteStates = Arrays.asList(false, true);
+
+        // test case assumes that the records do not fill the batch completely
+        int batchSize = 1025;
+        RecordAccumulator accum = createTestRecordAccumulator(
+            batchSize + DefaultRecordBatch.RECORD_BATCH_OVERHEAD, 10 * batchSize, CompressionType.NONE, lingerMs);
+
+        // Test batches in retry.
+        for (Boolean mute : muteStates) {
+            accum.append(tp1, 0L, key, value, Record.EMPTY_HEADERS, null, 0, false, time.milliseconds());
+            time.sleep(lingerMs);
+            readyNodes = accum.ready(cluster, time.milliseconds()).readyNodes;
+            assertEquals(Collections.singleton(node1), readyNodes, "Our partition's leader should be ready");
+            Map<Integer, List<ProducerBatch>> drained = accum.drain(cluster, readyNodes, Integer.MAX_VALUE, time.milliseconds());
+            assertEquals(1, drained.get(node1.id()).size(), "There should be only one batch.");
+            time.sleep(rtt);
+            accum.reenqueue(drained.get(node1.id()).get(0), time.milliseconds());
+
+            if (mute)
+                accum.mutePartition(tp1);
+            else
+                accum.unmutePartition(tp1);
+
+            // test expiration
+            time.sleep(deliveryTimeoutMs - rtt);
+            accum.drain(cluster, Collections.singleton(node1), Integer.MAX_VALUE, time.milliseconds());
+            expiredBatches = accum.expiredBatches(time.milliseconds());
+            assertEquals(mute ? 1 : 0, expiredBatches.size(), "RecordAccumulator has expired batches if the partition is not muted");
+        }
+    }
+
+    @Test
+    public void testStickyBatches() throws Exception {
+        long now = time.milliseconds();
+
+        // Test case assumes that the records do not fill the batch completely
+        int batchSize = 1025;
+
+        Partitioner partitioner = new DefaultPartitioner();
+        RecordAccumulator accum = createTestRecordAccumulator(3200,
+            batchSize + DefaultRecordBatch.RECORD_BATCH_OVERHEAD, 10L * batchSize, CompressionType.NONE, 10);
+        int expectedAppends = expectedNumAppendsNoKey(batchSize);
+
+        // Create first batch
+        int partition = partitioner.partition(topic, null, null, "value", value, cluster);
+        TopicPartition tp = new TopicPartition(topic, partition);
+        accum.append(tp, 0L, null, value, Record.EMPTY_HEADERS, null, maxBlockTimeMs, false, time.milliseconds());
+        int appends = 1;
+
+        boolean switchPartition = false;
+        while (!switchPartition) {
+            // Append to the first batch
+            partition = partitioner.partition(topic, null, null, "value", value, cluster);
+            tp = new TopicPartition(topic, partition);
+            RecordAccumulator.RecordAppendResult result = accum.append(tp, 0L, null, value, Record.EMPTY_HEADERS, null, maxBlockTimeMs, true, time.milliseconds());
+            Deque<ProducerBatch> partitionBatches1 = accum.batches().get(tp1);
+            Deque<ProducerBatch> partitionBatches2 = accum.batches().get(tp2);
+            Deque<ProducerBatch> partitionBatches3 = accum.batches().get(tp3);
+            int numBatches = (partitionBatches1 == null ? 0 : partitionBatches1.size()) + (partitionBatches2 == null ? 0 : partitionBatches2.size()) + (partitionBatches3 == null ? 0 : partitionBatches3.size());
+            // Only one batch is created because the partition is sticky.
+            assertEquals(1, numBatches);
+
+            switchPartition = result.abortForNewBatch;
+            // We only appended if we do not retry.
+            if (!switchPartition) {
+                appends++;
+                assertEquals(0, accum.ready(cluster, now).readyNodes.size(), "No partitions should be ready.");
+            }
+        }
+
+        // Batch should be full.
+        assertEquals(1, accum.ready(cluster, time.milliseconds()).readyNodes.size());
+        assertEquals(appends, expectedAppends);
+        switchPartition = false;
+
+        // KafkaProducer would call this method in this case, make second batch
+        partitioner.onNewBatch(topic, cluster, partition);
+        partition = partitioner.partition(topic, null, null, "value", value, cluster);
+        tp = new TopicPartition(topic, partition);
+        accum.append(tp, 0L, null, value, Record.EMPTY_HEADERS, null, maxBlockTimeMs, false, time.milliseconds());
+        appends++;
+
+        // These appends all go into the second batch
+        while (!switchPartition) {
+            partition = partitioner.partition(topic, null, null, "value", value, cluster);
+            tp = new TopicPartition(topic, partition);
+            RecordAccumulator.RecordAppendResult result = accum.append(tp, 0L, null, value, Record.EMPTY_HEADERS, null, maxBlockTimeMs, true, time.milliseconds());
+            Deque<ProducerBatch> partitionBatches1 = accum.batches().get(tp1);
+            Deque<ProducerBatch> partitionBatches2 = accum.batches().get(tp2);
+            Deque<ProducerBatch> partitionBatches3 = accum.batches().get(tp3);
+            int numBatches = (partitionBatches1 == null ? 0 : partitionBatches1.size()) + (partitionBatches2 == null ? 0 : partitionBatches2.size()) + (partitionBatches3 == null ? 0 : partitionBatches3.size());
+            // Only two batches because the new partition is also sticky.
+            assertEquals(2, numBatches);
+
+            switchPartition = result.abortForNewBatch;
+            // We only appended if we do not retry.
+            if (!switchPartition) {
+                appends++;
+            }
+        }
+
+        // There should be two full batches now.
+ assertEquals(appends, 2 * expectedAppends); + } + + private int prepareSplitBatches(RecordAccumulator accum, long seed, int recordSize, int numRecords) + throws InterruptedException { + Random random = new Random(); + random.setSeed(seed); + + // First set the compression ratio estimation to be good. + CompressionRatioEstimator.setEstimation(tp1.topic(), CompressionType.GZIP, 0.1f); + // Append 20 records of 100 bytes size with poor compression ratio should make the batch too big. + for (int i = 0; i < numRecords; i++) { + accum.append(tp1, 0L, null, bytesWithPoorCompression(random, recordSize), Record.EMPTY_HEADERS, null, 0, false, time.milliseconds()); + } + + RecordAccumulator.ReadyCheckResult result = accum.ready(cluster, time.milliseconds()); + assertFalse(result.readyNodes.isEmpty()); + Map> batches = accum.drain(cluster, result.readyNodes, Integer.MAX_VALUE, time.milliseconds()); + assertEquals(1, batches.size()); + assertEquals(1, batches.values().iterator().next().size()); + ProducerBatch batch = batches.values().iterator().next().get(0); + int numSplitBatches = accum.splitAndReenqueue(batch); + accum.deallocate(batch); + + return numSplitBatches; + } + + private BatchDrainedResult completeOrSplitBatches(RecordAccumulator accum, int batchSize) { + int numSplit = 0; + int numBatches = 0; + boolean batchDrained; + do { + batchDrained = false; + RecordAccumulator.ReadyCheckResult result = accum.ready(cluster, time.milliseconds()); + Map> batches = accum.drain(cluster, result.readyNodes, Integer.MAX_VALUE, time.milliseconds()); + for (List batchList : batches.values()) { + for (ProducerBatch batch : batchList) { + batchDrained = true; + numBatches++; + if (batch.estimatedSizeInBytes() > batchSize + DefaultRecordBatch.RECORD_BATCH_OVERHEAD) { + accum.splitAndReenqueue(batch); + // release the resource of the original big batch. + numSplit++; + } else { + batch.complete(0L, 0L); + } + accum.deallocate(batch); + } + } + } while (batchDrained); + return new BatchDrainedResult(numSplit, numBatches); + } + + /** + * Generates the compression ratio at about 0.6 + */ + private byte[] bytesWithGoodCompression(Random random) { + byte[] value = new byte[100]; + ByteBuffer buffer = ByteBuffer.wrap(value); + while (buffer.remaining() > 0) + buffer.putInt(random.nextInt(1000)); + return value; + } + + /** + * Generates the compression ratio at about 0.9 + */ + private byte[] bytesWithPoorCompression(Random random, int size) { + byte[] value = new byte[size]; + random.nextBytes(value); + return value; + } + + private class BatchDrainedResult { + final int numSplit; + final int numBatches; + BatchDrainedResult(int numSplit, int numBatches) { + this.numBatches = numBatches; + this.numSplit = numSplit; + } + } + + /** + * Return the offset delta. + */ + private int expectedNumAppends(int batchSize) { + int size = 0; + int offsetDelta = 0; + while (true) { + int recordSize = DefaultRecord.sizeInBytes(offsetDelta, 0, key.length, value.length, + Record.EMPTY_HEADERS); + if (size + recordSize > batchSize) + return offsetDelta; + offsetDelta += 1; + size += recordSize; + } + } + + /** + * Return the offset delta when there is no key. 
+     */
+    private int expectedNumAppendsNoKey(int batchSize) {
+        int size = 0;
+        int offsetDelta = 0;
+        while (true) {
+            int recordSize = DefaultRecord.sizeInBytes(offsetDelta, 0, 0, value.length,
+                Record.EMPTY_HEADERS);
+            if (size + recordSize > batchSize)
+                return offsetDelta;
+            offsetDelta += 1;
+            size += recordSize;
+        }
+    }
+
+    private RecordAccumulator createTestRecordAccumulator(int batchSize, long totalSize, CompressionType type, int lingerMs) {
+        int deliveryTimeoutMs = 3200;
+        return createTestRecordAccumulator(deliveryTimeoutMs, batchSize, totalSize, type, lingerMs);
+    }
+
+    private RecordAccumulator createTestRecordAccumulator(int deliveryTimeoutMs, int batchSize, long totalSize, CompressionType type, int lingerMs) {
+        return createTestRecordAccumulator(null, deliveryTimeoutMs, batchSize, totalSize, type, lingerMs);
+    }
+
+    /**
+     * Return a test RecordAccumulator instance
+     */
+    private RecordAccumulator createTestRecordAccumulator(
+        TransactionManager txnManager,
+        int deliveryTimeoutMs,
+        int batchSize,
+        long totalSize,
+        CompressionType type,
+        int lingerMs
+    ) {
+        long retryBackoffMs = 100L;
+        String metricGrpName = "producer-metrics";
+
+        return new RecordAccumulator(
+            logContext,
+            batchSize,
+            type,
+            lingerMs,
+            retryBackoffMs,
+            deliveryTimeoutMs,
+            metrics,
+            metricGrpName,
+            time,
+            new ApiVersions(),
+            txnManager,
+            new BufferPool(totalSize, batchSize, metrics, time, metricGrpName));
+    }
+}
diff --git a/clients/src/test/java/org/apache/kafka/clients/producer/internals/SenderTest.java b/clients/src/test/java/org/apache/kafka/clients/producer/internals/SenderTest.java
new file mode 100644
index 0000000..34e4f18
--- /dev/null
+++ b/clients/src/test/java/org/apache/kafka/clients/producer/internals/SenderTest.java
@@ -0,0 +1,3264 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +package org.apache.kafka.clients.producer.internals; + +import org.apache.kafka.clients.ApiVersions; +import org.apache.kafka.clients.ClientRequest; +import org.apache.kafka.clients.ClientResponse; +import org.apache.kafka.clients.MockClient; +import org.apache.kafka.clients.NetworkClient; +import org.apache.kafka.clients.NodeApiVersions; +import org.apache.kafka.clients.producer.Callback; +import org.apache.kafka.clients.producer.RecordMetadata; +import org.apache.kafka.common.Cluster; +import org.apache.kafka.common.InvalidRecordException; +import org.apache.kafka.common.KafkaException; +import org.apache.kafka.common.MetricName; +import org.apache.kafka.common.MetricNameTemplate; +import org.apache.kafka.common.Node; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.errors.ClusterAuthorizationException; +import org.apache.kafka.common.errors.InvalidRequestException; +import org.apache.kafka.common.errors.NetworkException; +import org.apache.kafka.common.errors.RecordTooLargeException; +import org.apache.kafka.common.errors.TimeoutException; +import org.apache.kafka.common.errors.TopicAuthorizationException; +import org.apache.kafka.common.errors.TransactionAbortedException; +import org.apache.kafka.common.errors.UnsupportedForMessageFormatException; +import org.apache.kafka.common.errors.UnsupportedVersionException; +import org.apache.kafka.common.internals.ClusterResourceListeners; +import org.apache.kafka.common.message.ApiMessageType; +import org.apache.kafka.common.message.EndTxnResponseData; +import org.apache.kafka.common.message.InitProducerIdResponseData; +import org.apache.kafka.common.message.ProduceRequestData; +import org.apache.kafka.common.message.ProduceResponseData; +import org.apache.kafka.common.message.ProduceResponseData.BatchIndexAndErrorMessage; +import org.apache.kafka.common.metrics.KafkaMetric; +import org.apache.kafka.common.metrics.MetricConfig; +import org.apache.kafka.common.metrics.Metrics; +import org.apache.kafka.common.metrics.Sensor; +import org.apache.kafka.common.network.NetworkReceive; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.Errors; +import org.apache.kafka.common.record.CompressionRatioEstimator; +import org.apache.kafka.common.record.CompressionType; +import org.apache.kafka.common.record.MemoryRecords; +import org.apache.kafka.common.record.MutableRecordBatch; +import org.apache.kafka.common.record.Record; +import org.apache.kafka.common.record.RecordBatch; +import org.apache.kafka.common.requests.AbstractRequest; +import org.apache.kafka.common.requests.AddPartitionsToTxnResponse; +import org.apache.kafka.common.requests.ApiVersionsResponse; +import org.apache.kafka.common.requests.EndTxnRequest; +import org.apache.kafka.common.requests.EndTxnResponse; +import org.apache.kafka.common.requests.FindCoordinatorRequest.CoordinatorType; +import org.apache.kafka.common.requests.FindCoordinatorResponse; +import org.apache.kafka.common.requests.InitProducerIdRequest; +import org.apache.kafka.common.requests.InitProducerIdResponse; +import org.apache.kafka.common.requests.MetadataRequest; +import org.apache.kafka.common.requests.MetadataResponse; +import org.apache.kafka.common.requests.ProduceRequest; +import org.apache.kafka.common.requests.ProduceResponse; +import org.apache.kafka.common.requests.RequestTestUtils; +import org.apache.kafka.common.requests.TransactionResult; +import org.apache.kafka.common.utils.LogContext; +import 
org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.common.utils.ProducerIdAndEpoch; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.test.DelayedReceive; +import org.apache.kafka.test.MockSelector; +import org.apache.kafka.test.TestUtils; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; +import org.mockito.InOrder; + +import java.nio.ByteBuffer; +import java.util.Arrays; +import java.util.Collections; +import java.util.Deque; +import java.util.HashMap; +import java.util.HashSet; +import java.util.IdentityHashMap; +import java.util.Iterator; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.OptionalInt; +import java.util.OptionalLong; +import java.util.Set; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; + +import static org.apache.kafka.clients.producer.internals.ProducerTestUtils.runUntil; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertSame; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; +import static org.mockito.AdditionalMatchers.geq; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyBoolean; +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.ArgumentMatchers.anyLong; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.atLeastOnce; +import static org.mockito.Mockito.inOrder; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +public class SenderTest { + private static final int MAX_REQUEST_SIZE = 1024 * 1024; + private static final short ACKS_ALL = -1; + private static final String CLIENT_ID = "clientId"; + private static final double EPS = 0.0001; + private static final int MAX_BLOCK_TIMEOUT = 1000; + private static final int REQUEST_TIMEOUT = 5000; + private static final long RETRY_BACKOFF_MS = 50; + private static final int DELIVERY_TIMEOUT_MS = 1500; + private static final long TOPIC_IDLE_MS = 60 * 1000; + + private TopicPartition tp0 = new TopicPartition("test", 0); + private TopicPartition tp1 = new TopicPartition("test", 1); + private MockTime time = new MockTime(); + private int batchSize = 16 * 1024; + private ProducerMetadata metadata = new ProducerMetadata(0, Long.MAX_VALUE, TOPIC_IDLE_MS, + new LogContext(), new ClusterResourceListeners(), time); + private MockClient client = new MockClient(time, metadata); + private ApiVersions apiVersions = new ApiVersions(); + private Metrics metrics = null; + private RecordAccumulator accumulator = null; + private Sender sender = null; + private SenderMetricsRegistry senderMetricsRegistry = null; + private final LogContext logContext = new LogContext(); + + @BeforeEach + public void setup() { + setupWithTransactionState(null); + } + + @AfterEach + public void 
tearDown() { + this.metrics.close(); + } + + private static Map partitionRecords(ProduceRequest request) { + Map partitionRecords = new HashMap<>(); + request.data().topicData().forEach(tpData -> tpData.partitionData().forEach(p -> { + TopicPartition tp = new TopicPartition(tpData.name(), p.index()); + partitionRecords.put(tp, (MemoryRecords) p.records()); + })); + return Collections.unmodifiableMap(partitionRecords); + } + + @Test + public void testSimple() throws Exception { + long offset = 0; + Future future = appendToAccumulator(tp0, 0L, "key", "value"); + sender.runOnce(); // connect + sender.runOnce(); // send produce request + assertEquals(1, client.inFlightRequestCount(), "We should have a single produce request in flight."); + assertEquals(1, sender.inFlightBatches(tp0).size()); + assertTrue(client.hasInFlightRequests()); + client.respond(produceResponse(tp0, offset, Errors.NONE, 0)); + sender.runOnce(); + assertEquals(0, client.inFlightRequestCount(), "All requests completed."); + assertEquals(0, sender.inFlightBatches(tp0).size()); + assertFalse(client.hasInFlightRequests()); + sender.runOnce(); + assertTrue(future.isDone(), "Request should be completed"); + assertEquals(offset, future.get().offset()); + } + + @Test + public void testMessageFormatDownConversion() throws Exception { + // this test case verifies the behavior when the version of the produce request supported by the + // broker changes after the record set is created + + long offset = 0; + + // start off support produce request v3 + apiVersions.update("0", NodeApiVersions.create()); + + Future future = appendToAccumulator(tp0, 0L, "key", "value"); + + // now the partition leader supports only v2 + apiVersions.update("0", NodeApiVersions.create(ApiKeys.PRODUCE.id, (short) 0, (short) 2)); + + client.prepareResponse(body -> { + ProduceRequest request = (ProduceRequest) body; + if (request.version() != 2) + return false; + + MemoryRecords records = partitionRecords(request).get(tp0); + return records != null && + records.sizeInBytes() > 0 && + records.hasMatchingMagic(RecordBatch.MAGIC_VALUE_V1); + }, produceResponse(tp0, offset, Errors.NONE, 0)); + + sender.runOnce(); // connect + sender.runOnce(); // send produce request + + assertTrue(future.isDone(), "Request should be completed"); + assertEquals(offset, future.get().offset()); + } + + @SuppressWarnings("deprecation") + @Test + public void testDownConversionForMismatchedMagicValues() throws Exception { + // it can happen that we construct a record set with mismatching magic values (perhaps + // because the partition leader changed after the record set was initially constructed) + // in this case, we down-convert record sets with newer magic values to match the oldest + // created record set + + long offset = 0; + + // start off support produce request v3 + apiVersions.update("0", NodeApiVersions.create()); + + Future future1 = appendToAccumulator(tp0, 0L, "key", "value"); + + // now the partition leader supports only v2 + apiVersions.update("0", NodeApiVersions.create(ApiKeys.PRODUCE.id, (short) 0, (short) 2)); + + Future future2 = appendToAccumulator(tp1, 0L, "key", "value"); + + // start off support produce request v3 + apiVersions.update("0", NodeApiVersions.create()); + + ProduceResponse.PartitionResponse resp = new ProduceResponse.PartitionResponse(Errors.NONE, offset, RecordBatch.NO_TIMESTAMP, 100); + Map partResp = new HashMap<>(); + partResp.put(tp0, resp); + partResp.put(tp1, resp); + ProduceResponse produceResponse = new ProduceResponse(partResp, 0); + + 
client.prepareResponse(body -> { + ProduceRequest request = (ProduceRequest) body; + if (request.version() != 2) + return false; + + Map recordsMap = partitionRecords(request); + if (recordsMap.size() != 2) + return false; + + for (MemoryRecords records : recordsMap.values()) { + if (records == null || records.sizeInBytes() == 0 || !records.hasMatchingMagic(RecordBatch.MAGIC_VALUE_V1)) + return false; + } + return true; + }, produceResponse); + + sender.runOnce(); // connect + sender.runOnce(); // send produce request + + assertTrue(future1.isDone(), "Request should be completed"); + assertTrue(future2.isDone(), "Request should be completed"); + } + + /* + * Send multiple requests. Verify that the client side quota metrics have the right values + */ + @SuppressWarnings("deprecation") + @Test + public void testQuotaMetrics() { + MockSelector selector = new MockSelector(time); + Sensor throttleTimeSensor = Sender.throttleTimeSensor(this.senderMetricsRegistry); + Cluster cluster = TestUtils.singletonCluster("test", 1); + Node node = cluster.nodes().get(0); + NetworkClient client = new NetworkClient(selector, metadata, "mock", Integer.MAX_VALUE, + 1000, 1000, 64 * 1024, 64 * 1024, 1000, 10 * 1000, 127 * 1000, + time, true, new ApiVersions(), throttleTimeSensor, logContext); + + ApiVersionsResponse apiVersionsResponse = ApiVersionsResponse.defaultApiVersionsResponse( + 400, ApiMessageType.ListenerType.ZK_BROKER); + ByteBuffer buffer = RequestTestUtils.serializeResponseWithHeader(apiVersionsResponse, ApiKeys.API_VERSIONS.latestVersion(), 0); + + selector.delayedReceive(new DelayedReceive(node.idString(), new NetworkReceive(node.idString(), buffer))); + while (!client.ready(node, time.milliseconds())) { + client.poll(1, time.milliseconds()); + // If a throttled response is received, advance the time to ensure progress. + time.sleep(client.throttleDelayMs(node, time.milliseconds())); + } + selector.clear(); + + for (int i = 1; i <= 3; i++) { + int throttleTimeMs = 100 * i; + ProduceRequest.Builder builder = ProduceRequest.forCurrentMagic(new ProduceRequestData() + .setTopicData(new ProduceRequestData.TopicProduceDataCollection()) + .setAcks((short) 1) + .setTimeoutMs(1000)); + ClientRequest request = client.newClientRequest(node.idString(), builder, time.milliseconds(), true); + client.send(request, time.milliseconds()); + client.poll(1, time.milliseconds()); + ProduceResponse response = produceResponse(tp0, i, Errors.NONE, throttleTimeMs); + buffer = RequestTestUtils.serializeResponseWithHeader(response, ApiKeys.PRODUCE.latestVersion(), request.correlationId()); + selector.completeReceive(new NetworkReceive(node.idString(), buffer)); + client.poll(1, time.milliseconds()); + // If a throttled response is received, advance the time to ensure progress. 
+ time.sleep(client.throttleDelayMs(node, time.milliseconds())); + selector.clear(); + } + Map allMetrics = metrics.metrics(); + KafkaMetric avgMetric = allMetrics.get(this.senderMetricsRegistry.produceThrottleTimeAvg); + KafkaMetric maxMetric = allMetrics.get(this.senderMetricsRegistry.produceThrottleTimeMax); + // Throttle times are ApiVersions=400, Produce=(100, 200, 300) + assertEquals(250, (Double) avgMetric.metricValue(), EPS); + assertEquals(400, (Double) maxMetric.metricValue(), EPS); + client.close(); + } + + @Test + public void testSenderMetricsTemplates() throws Exception { + metrics.close(); + Map clientTags = Collections.singletonMap("client-id", "clientA"); + metrics = new Metrics(new MetricConfig().tags(clientTags)); + SenderMetricsRegistry metricsRegistry = new SenderMetricsRegistry(metrics); + Sender sender = new Sender(logContext, client, metadata, this.accumulator, false, MAX_REQUEST_SIZE, ACKS_ALL, + 1, metricsRegistry, time, REQUEST_TIMEOUT, RETRY_BACKOFF_MS, null, apiVersions); + + // Append a message so that topic metrics are created + appendToAccumulator(tp0, 0L, "key", "value"); + sender.runOnce(); // connect + sender.runOnce(); // send produce request + client.respond(produceResponse(tp0, 0, Errors.NONE, 0)); + sender.runOnce(); + // Create throttle time metrics + Sender.throttleTimeSensor(metricsRegistry); + + // Verify that all metrics except metrics-count have registered templates + Set allMetrics = new HashSet<>(); + for (MetricName n : metrics.metrics().keySet()) { + if (!n.group().equals("kafka-metrics-count")) + allMetrics.add(new MetricNameTemplate(n.name(), n.group(), "", n.tags().keySet())); + } + TestUtils.checkEquals(allMetrics, new HashSet<>(metricsRegistry.allTemplates()), "metrics", "templates"); + } + + @Test + public void testRetries() throws Exception { + // create a sender with retries = 1 + int maxRetries = 1; + Metrics m = new Metrics(); + SenderMetricsRegistry senderMetrics = new SenderMetricsRegistry(m); + try { + Sender sender = new Sender(logContext, client, metadata, this.accumulator, false, MAX_REQUEST_SIZE, ACKS_ALL, + maxRetries, senderMetrics, time, REQUEST_TIMEOUT, RETRY_BACKOFF_MS, null, apiVersions); + // do a successful retry + Future future = appendToAccumulator(tp0, 0L, "key", "value"); + sender.runOnce(); // connect + sender.runOnce(); // send produce request + String id = client.requests().peek().destination(); + Node node = new Node(Integer.parseInt(id), "localhost", 0); + assertEquals(1, client.inFlightRequestCount()); + assertTrue(client.hasInFlightRequests()); + assertEquals(1, sender.inFlightBatches(tp0).size()); + assertTrue(client.isReady(node, time.milliseconds()), "Client ready status should be true"); + client.disconnect(id); + assertEquals(0, client.inFlightRequestCount()); + assertFalse(client.hasInFlightRequests()); + assertFalse(client.isReady(node, time.milliseconds()), "Client ready status should be false"); + // the batch is in accumulator.inFlightBatches until it expires + assertEquals(1, sender.inFlightBatches(tp0).size()); + sender.runOnce(); // receive error + sender.runOnce(); // reconnect + sender.runOnce(); // resend + assertEquals(1, client.inFlightRequestCount()); + assertTrue(client.hasInFlightRequests()); + assertEquals(1, sender.inFlightBatches(tp0).size()); + long offset = 0; + client.respond(produceResponse(tp0, offset, Errors.NONE, 0)); + sender.runOnce(); + assertTrue(future.isDone(), "Request should have retried and completed"); + assertEquals(offset, future.get().offset()); + assertEquals(0, 
sender.inFlightBatches(tp0).size()); + + // do an unsuccessful retry + future = appendToAccumulator(tp0, 0L, "key", "value"); + sender.runOnce(); // send produce request + assertEquals(1, sender.inFlightBatches(tp0).size()); + for (int i = 0; i < maxRetries + 1; i++) { + client.disconnect(client.requests().peek().destination()); + sender.runOnce(); // receive error + assertEquals(0, sender.inFlightBatches(tp0).size()); + sender.runOnce(); // reconnect + sender.runOnce(); // resend + assertEquals(i > 0 ? 0 : 1, sender.inFlightBatches(tp0).size()); + } + sender.runOnce(); + assertFutureFailure(future, NetworkException.class); + assertEquals(0, sender.inFlightBatches(tp0).size()); + } finally { + m.close(); + } + } + + @Test + public void testSendInOrder() throws Exception { + int maxRetries = 1; + Metrics m = new Metrics(); + SenderMetricsRegistry senderMetrics = new SenderMetricsRegistry(m); + + try { + Sender sender = new Sender(logContext, client, metadata, this.accumulator, true, MAX_REQUEST_SIZE, ACKS_ALL, maxRetries, + senderMetrics, time, REQUEST_TIMEOUT, RETRY_BACKOFF_MS, null, apiVersions); + // Create a two broker cluster, with partition 0 on broker 0 and partition 1 on broker 1 + MetadataResponse metadataUpdate1 = RequestTestUtils.metadataUpdateWith(2, Collections.singletonMap("test", 2)); + client.prepareMetadataUpdate(metadataUpdate1); + + // Send the first message. + TopicPartition tp2 = new TopicPartition("test", 1); + appendToAccumulator(tp2, 0L, "key1", "value1"); + sender.runOnce(); // connect + sender.runOnce(); // send produce request + String id = client.requests().peek().destination(); + assertEquals(ApiKeys.PRODUCE, client.requests().peek().requestBuilder().apiKey()); + Node node = new Node(Integer.parseInt(id), "localhost", 0); + assertEquals(1, client.inFlightRequestCount()); + assertTrue(client.hasInFlightRequests()); + assertTrue(client.isReady(node, time.milliseconds()), "Client ready status should be true"); + assertEquals(1, sender.inFlightBatches(tp2).size()); + + time.sleep(900); + // Now send another message to tp2 + appendToAccumulator(tp2, 0L, "key2", "value2"); + + // Update metadata before sender receives response from broker 0. Now partition 2 moves to broker 0 + MetadataResponse metadataUpdate2 = RequestTestUtils.metadataUpdateWith(1, Collections.singletonMap("test", 2)); + client.prepareMetadataUpdate(metadataUpdate2); + // Sender should not send the second message to node 0. 
+            assertEquals(1, sender.inFlightBatches(tp2).size());
+            sender.runOnce(); // receive the response for the previous send, and send the new batch
+            assertEquals(1, client.inFlightRequestCount());
+            assertTrue(client.hasInFlightRequests());
+            assertEquals(1, sender.inFlightBatches(tp2).size());
+        } finally {
+            m.close();
+        }
+    }
+
+    @Test
+    public void testAppendInExpiryCallback() throws InterruptedException {
+        int messagesPerBatch = 10;
+        final AtomicInteger expiryCallbackCount = new AtomicInteger(0);
+        final AtomicReference<Exception> unexpectedException = new AtomicReference<>();
+        final byte[] key = "key".getBytes();
+        final byte[] value = "value".getBytes();
+        final long maxBlockTimeMs = 1000;
+        Callback callback = (metadata, exception) -> {
+            if (exception instanceof TimeoutException) {
+                expiryCallbackCount.incrementAndGet();
+                try {
+                    accumulator.append(tp1, 0L, key, value,
+                        Record.EMPTY_HEADERS, null, maxBlockTimeMs, false, time.milliseconds());
+                } catch (InterruptedException e) {
+                    throw new RuntimeException("Unexpected interruption", e);
+                }
+            } else if (exception != null)
+                unexpectedException.compareAndSet(null, exception);
+        };
+
+        final long nowMs = time.milliseconds();
+        for (int i = 0; i < messagesPerBatch; i++)
+            accumulator.append(tp1, 0L, key, value, null, callback, maxBlockTimeMs, false, nowMs);
+
+        // Advance the clock to expire the first batch.
+        time.sleep(10000);
+
+        Node clusterNode = metadata.fetch().nodes().get(0);
+        Map<Integer, List<ProducerBatch>> drainedBatches =
+            accumulator.drain(metadata.fetch(), Collections.singleton(clusterNode), Integer.MAX_VALUE, time.milliseconds());
+        sender.addToInflightBatches(drainedBatches);
+
+        // Disconnect the target node for the pending produce request. This will ensure that the sender will try to
+        // expire the batch.
+        client.disconnect(clusterNode.idString());
+        client.backoff(clusterNode, 100);
+
+        sender.runOnce(); // We should try to flush the batch, but we expire it instead without sending anything.
+        assertEquals(messagesPerBatch, expiryCallbackCount.get(), "Callbacks not invoked for expiry");
+        assertNull(unexpectedException.get(), "Unexpected exception");
+        // Make sure that the records were appended back to the batch.
+        assertTrue(accumulator.batches().containsKey(tp1));
+        assertEquals(1, accumulator.batches().get(tp1).size());
+        assertEquals(messagesPerBatch, accumulator.batches().get(tp1).peekFirst().recordCount);
+    }
+
+    /**
+     * Tests that topics are added to the metadata list when messages are available to send
+     * and expired if not used during a metadata refresh interval.
+ */ + @Test + public void testMetadataTopicExpiry() throws Exception { + long offset = 0; + client.updateMetadata(RequestTestUtils.metadataUpdateWith(1, Collections.singletonMap("test", 2))); + + Future future = appendToAccumulator(tp0); + sender.runOnce(); + assertTrue(metadata.containsTopic(tp0.topic()), "Topic not added to metadata"); + client.updateMetadata(RequestTestUtils.metadataUpdateWith(1, Collections.singletonMap("test", 2))); + sender.runOnce(); // send produce request + client.respond(produceResponse(tp0, offset, Errors.NONE, 0)); + sender.runOnce(); + assertEquals(0, client.inFlightRequestCount(), "Request completed."); + assertFalse(client.hasInFlightRequests()); + assertEquals(0, sender.inFlightBatches(tp0).size()); + sender.runOnce(); + assertTrue(future.isDone(), "Request should be completed"); + + assertTrue(metadata.containsTopic(tp0.topic()), "Topic not retained in metadata list"); + time.sleep(TOPIC_IDLE_MS); + client.updateMetadata(RequestTestUtils.metadataUpdateWith(1, Collections.singletonMap("test", 2))); + assertFalse(metadata.containsTopic(tp0.topic()), "Unused topic has not been expired"); + future = appendToAccumulator(tp0); + sender.runOnce(); + assertTrue(metadata.containsTopic(tp0.topic()), "Topic not added to metadata"); + client.updateMetadata(RequestTestUtils.metadataUpdateWith(1, Collections.singletonMap("test", 2))); + sender.runOnce(); // send produce request + client.respond(produceResponse(tp0, offset + 1, Errors.NONE, 0)); + sender.runOnce(); + assertEquals(0, client.inFlightRequestCount(), "Request completed."); + assertFalse(client.hasInFlightRequests()); + assertEquals(0, sender.inFlightBatches(tp0).size()); + sender.runOnce(); + assertTrue(future.isDone(), "Request should be completed"); + } + + @Test + public void testInitProducerIdRequest() { + final long producerId = 343434L; + TransactionManager transactionManager = createTransactionManager(); + setupWithTransactionState(transactionManager); + prepareAndReceiveInitProducerId(producerId, Errors.NONE); + assertTrue(transactionManager.hasProducerId()); + assertEquals(producerId, transactionManager.producerIdAndEpoch().producerId); + assertEquals((short) 0, transactionManager.producerIdAndEpoch().epoch); + } + + /** + * Verifies that InitProducerId of transactional producer succeeds even if metadata requests + * are pending with only one bootstrap node available and maxInFlight=1, where multiple + * polls are necessary to send requests. + */ + @Test + public void testInitProducerIdWithMaxInFlightOne() throws Exception { + final long producerId = 123456L; + createMockClientWithMaxFlightOneMetadataPending(); + + // Initialize transaction manager. InitProducerId will be queued up until metadata response + // is processed and FindCoordinator can be sent to `leastLoadedNode`. + TransactionManager transactionManager = new TransactionManager(new LogContext(), "testInitProducerIdWithPendingMetadataRequest", + 60000, 100L, new ApiVersions()); + setupWithTransactionState(transactionManager, false, null, false); + ProducerIdAndEpoch producerIdAndEpoch = new ProducerIdAndEpoch(producerId, (short) 0); + transactionManager.initializeTransactions(); + sender.runOnce(); + + // Process metadata response, prepare FindCoordinator and InitProducerId responses. + // Verify producerId after the sender is run to process responses. 
+ MetadataResponse metadataUpdate = RequestTestUtils.metadataUpdateWith(1, Collections.emptyMap()); + client.respond(metadataUpdate); + prepareFindCoordinatorResponse(Errors.NONE, "testInitProducerIdWithPendingMetadataRequest"); + prepareInitProducerResponse(Errors.NONE, producerIdAndEpoch.producerId, producerIdAndEpoch.epoch); + waitForProducerId(transactionManager, producerIdAndEpoch); + } + + /** + * Verifies that InitProducerId of idempotent producer succeeds even if metadata requests + * are pending with only one bootstrap node available and maxInFlight=1, where multiple + * polls are necessary to send requests. + */ + @Test + public void testIdempotentInitProducerIdWithMaxInFlightOne() throws Exception { + final long producerId = 123456L; + createMockClientWithMaxFlightOneMetadataPending(); + + // Initialize transaction manager. InitProducerId will be queued up until metadata response + // is processed. + TransactionManager transactionManager = createTransactionManager(); + setupWithTransactionState(transactionManager, false, null, false); + ProducerIdAndEpoch producerIdAndEpoch = new ProducerIdAndEpoch(producerId, (short) 0); + + // Process metadata and InitProducerId responses. + // Verify producerId after the sender is run to process responses. + MetadataResponse metadataUpdate = RequestTestUtils.metadataUpdateWith(1, Collections.emptyMap()); + client.respond(metadataUpdate); + sender.runOnce(); + sender.runOnce(); + client.respond(initProducerIdResponse(producerIdAndEpoch.producerId, producerIdAndEpoch.epoch, Errors.NONE)); + waitForProducerId(transactionManager, producerIdAndEpoch); + } + + /** + * Tests the code path where the target node to send FindCoordinator or InitProducerId + * is not ready. + */ + @Test + public void testNodeNotReady() { + final long producerId = 123456L; + time = new MockTime(10); + client = new MockClient(time, metadata); + + TransactionManager transactionManager = new TransactionManager(new LogContext(), "testNodeNotReady", + 60000, 100L, new ApiVersions()); + setupWithTransactionState(transactionManager, false, null, true); + ProducerIdAndEpoch producerIdAndEpoch = new ProducerIdAndEpoch(producerId, (short) 0); + transactionManager.initializeTransactions(); + sender.runOnce(); + + Node node = metadata.fetch().nodes().get(0); + client.delayReady(node, REQUEST_TIMEOUT + 20); + prepareFindCoordinatorResponse(Errors.NONE, "testNodeNotReady"); + sender.runOnce(); + sender.runOnce(); + assertNotNull(transactionManager.coordinator(CoordinatorType.TRANSACTION), "Coordinator not found"); + + client.throttle(node, REQUEST_TIMEOUT + 20); + prepareFindCoordinatorResponse(Errors.NONE, "Coordinator not found"); + prepareInitProducerResponse(Errors.NONE, producerIdAndEpoch.producerId, producerIdAndEpoch.epoch); + waitForProducerId(transactionManager, producerIdAndEpoch); + } + + @Test + public void testClusterAuthorizationExceptionInInitProducerIdRequest() throws Exception { + final long producerId = 343434L; + TransactionManager transactionManager = createTransactionManager(); + setupWithTransactionState(transactionManager); + prepareAndReceiveInitProducerId(producerId, Errors.CLUSTER_AUTHORIZATION_FAILED); + assertFalse(transactionManager.hasProducerId()); + assertTrue(transactionManager.hasError()); + assertTrue(transactionManager.lastError() instanceof ClusterAuthorizationException); + + // cluster authorization is a fatal error for the producer + assertSendFailure(ClusterAuthorizationException.class); + } + + @Test + public void 
testCanRetryWithoutIdempotence() throws Exception { + // do a successful retry + Future future = appendToAccumulator(tp0, 0L, "key", "value"); + sender.runOnce(); // connect + sender.runOnce(); // send produce request + String id = client.requests().peek().destination(); + Node node = new Node(Integer.parseInt(id), "localhost", 0); + assertEquals(1, client.inFlightRequestCount()); + assertTrue(client.hasInFlightRequests()); + assertEquals(1, sender.inFlightBatches(tp0).size()); + assertTrue(client.isReady(node, time.milliseconds()), "Client ready status should be true"); + assertFalse(future.isDone()); + + client.respond(body -> { + ProduceRequest request = (ProduceRequest) body; + assertFalse(RequestTestUtils.hasIdempotentRecords(request)); + return true; + }, produceResponse(tp0, -1L, Errors.TOPIC_AUTHORIZATION_FAILED, 0)); + sender.runOnce(); + assertTrue(future.isDone()); + try { + future.get(); + } catch (Exception e) { + assertTrue(e.getCause() instanceof TopicAuthorizationException); + } + } + + @Test + public void testIdempotenceWithMultipleInflights() throws Exception { + final long producerId = 343434L; + TransactionManager transactionManager = createTransactionManager(); + setupWithTransactionState(transactionManager); + prepareAndReceiveInitProducerId(producerId, Errors.NONE); + assertTrue(transactionManager.hasProducerId()); + + assertEquals(0, transactionManager.sequenceNumber(tp0).longValue()); + + // Send first ProduceRequest + Future request1 = appendToAccumulator(tp0); + sender.runOnce(); + String nodeId = client.requests().peek().destination(); + Node node = new Node(Integer.valueOf(nodeId), "localhost", 0); + assertEquals(1, client.inFlightRequestCount()); + assertEquals(1, transactionManager.sequenceNumber(tp0).longValue()); + assertEquals(OptionalInt.empty(), transactionManager.lastAckedSequence(tp0)); + + // Send second ProduceRequest + Future request2 = appendToAccumulator(tp0); + sender.runOnce(); + assertEquals(2, client.inFlightRequestCount()); + assertEquals(2, transactionManager.sequenceNumber(tp0).longValue()); + assertEquals(OptionalInt.empty(), transactionManager.lastAckedSequence(tp0)); + assertFalse(request1.isDone()); + assertFalse(request2.isDone()); + assertTrue(client.isReady(node, time.milliseconds())); + + sendIdempotentProducerResponse(0, tp0, Errors.NONE, 0L); + + sender.runOnce(); // receive response 0 + + assertEquals(1, client.inFlightRequestCount()); + assertEquals(OptionalInt.of(0), transactionManager.lastAckedSequence(tp0)); + assertTrue(request1.isDone()); + assertEquals(0, request1.get().offset()); + assertFalse(request2.isDone()); + + sendIdempotentProducerResponse(1, tp0, Errors.NONE, 1L); + sender.runOnce(); // receive response 1 + assertEquals(OptionalInt.of(1), transactionManager.lastAckedSequence(tp0)); + assertFalse(client.hasInFlightRequests()); + assertEquals(0, sender.inFlightBatches(tp0).size()); + assertTrue(request2.isDone()); + assertEquals(1, request2.get().offset()); + } + + + @Test + public void testIdempotenceWithMultipleInflightsRetriedInOrder() throws Exception { + // Send multiple in flight requests, retry them all one at a time, in the correct order. 
+ final long producerId = 343434L; + TransactionManager transactionManager = createTransactionManager(); + setupWithTransactionState(transactionManager); + prepareAndReceiveInitProducerId(producerId, Errors.NONE); + assertTrue(transactionManager.hasProducerId()); + + assertEquals(0, transactionManager.sequenceNumber(tp0).longValue()); + + // Send first ProduceRequest + Future request1 = appendToAccumulator(tp0); + sender.runOnce(); + String nodeId = client.requests().peek().destination(); + Node node = new Node(Integer.valueOf(nodeId), "localhost", 0); + assertEquals(1, client.inFlightRequestCount()); + assertEquals(1, transactionManager.sequenceNumber(tp0).longValue()); + assertEquals(OptionalInt.empty(), transactionManager.lastAckedSequence(tp0)); + + // Send second ProduceRequest + Future request2 = appendToAccumulator(tp0); + sender.runOnce(); + + // Send third ProduceRequest + Future request3 = appendToAccumulator(tp0); + sender.runOnce(); + + assertEquals(3, client.inFlightRequestCount()); + assertEquals(3, transactionManager.sequenceNumber(tp0).longValue()); + assertEquals(OptionalInt.empty(), transactionManager.lastAckedSequence(tp0)); + assertFalse(request1.isDone()); + assertFalse(request2.isDone()); + assertFalse(request3.isDone()); + assertTrue(client.isReady(node, time.milliseconds())); + + sendIdempotentProducerResponse(0, tp0, Errors.LEADER_NOT_AVAILABLE, -1L); + sender.runOnce(); // receive response 0 + + // Queue the fourth request, it shouldn't be sent until the first 3 complete. + Future request4 = appendToAccumulator(tp0); + + assertEquals(2, client.inFlightRequestCount()); + assertEquals(OptionalInt.empty(), transactionManager.lastAckedSequence(tp0)); + + sendIdempotentProducerResponse(1, tp0, Errors.OUT_OF_ORDER_SEQUENCE_NUMBER, -1L); + sender.runOnce(); // re send request 1, receive response 2 + + sendIdempotentProducerResponse(2, tp0, Errors.OUT_OF_ORDER_SEQUENCE_NUMBER, -1L); + sender.runOnce(); // receive response 3 + + assertEquals(OptionalInt.empty(), transactionManager.lastAckedSequence(tp0)); + assertEquals(1, client.inFlightRequestCount()); + + sender.runOnce(); // Do nothing, we are reduced to one in flight request during retries. + + assertEquals(3, transactionManager.sequenceNumber(tp0).longValue()); // the batch for request 4 shouldn't have been drained, and hence the sequence should not have been incremented. 
+ assertEquals(1, client.inFlightRequestCount()); + + assertEquals(OptionalInt.empty(), transactionManager.lastAckedSequence(tp0)); + + sendIdempotentProducerResponse(0, tp0, Errors.NONE, 0L); + sender.runOnce(); // receive response 1 + assertEquals(OptionalInt.of(0), transactionManager.lastAckedSequence(tp0)); + assertTrue(request1.isDone()); + assertEquals(0, request1.get().offset()); + assertFalse(client.hasInFlightRequests()); + assertEquals(0, sender.inFlightBatches(tp0).size()); + + sender.runOnce(); // send request 2; + assertEquals(1, client.inFlightRequestCount()); + assertEquals(1, sender.inFlightBatches(tp0).size()); + + sendIdempotentProducerResponse(1, tp0, Errors.NONE, 1L); + sender.runOnce(); // receive response 2 + assertEquals(OptionalInt.of(1), transactionManager.lastAckedSequence(tp0)); + assertTrue(request2.isDone()); + assertEquals(1, request2.get().offset()); + + assertFalse(client.hasInFlightRequests()); + assertEquals(0, sender.inFlightBatches(tp0).size()); + + sender.runOnce(); // send request 3 + assertEquals(1, client.inFlightRequestCount()); + assertEquals(1, sender.inFlightBatches(tp0).size()); + + sendIdempotentProducerResponse(2, tp0, Errors.NONE, 2L); + sender.runOnce(); // receive response 3, send request 4 since we are out of 'retry' mode. + assertEquals(OptionalInt.of(2), transactionManager.lastAckedSequence(tp0)); + assertTrue(request3.isDone()); + assertEquals(2, request3.get().offset()); + assertEquals(1, client.inFlightRequestCount()); + assertEquals(1, sender.inFlightBatches(tp0).size()); + + sendIdempotentProducerResponse(3, tp0, Errors.NONE, 3L); + sender.runOnce(); // receive response 4 + assertEquals(OptionalInt.of(3), transactionManager.lastAckedSequence(tp0)); + assertTrue(request4.isDone()); + assertEquals(3, request4.get().offset()); + } + + @Test + public void testIdempotenceWithMultipleInflightsWhereFirstFailsFatallyAndSequenceOfFutureBatchesIsAdjusted() throws Exception { + final long producerId = 343434L; + TransactionManager transactionManager = createTransactionManager(); + setupWithTransactionState(transactionManager); + prepareAndReceiveInitProducerId(producerId, Errors.NONE); + assertTrue(transactionManager.hasProducerId()); + + assertEquals(0, transactionManager.sequenceNumber(tp0).longValue()); + + // Send first ProduceRequest + Future request1 = appendToAccumulator(tp0); + sender.runOnce(); + String nodeId = client.requests().peek().destination(); + Node node = new Node(Integer.valueOf(nodeId), "localhost", 0); + assertEquals(1, client.inFlightRequestCount()); + assertEquals(1, transactionManager.sequenceNumber(tp0).longValue()); + assertEquals(OptionalInt.empty(), transactionManager.lastAckedSequence(tp0)); + + // Send second ProduceRequest + Future request2 = appendToAccumulator(tp0); + sender.runOnce(); + assertEquals(2, client.inFlightRequestCount()); + assertEquals(2, transactionManager.sequenceNumber(tp0).longValue()); + assertEquals(OptionalInt.empty(), transactionManager.lastAckedSequence(tp0)); + assertFalse(request1.isDone()); + assertFalse(request2.isDone()); + assertTrue(client.isReady(node, time.milliseconds())); + + sendIdempotentProducerResponse(0, tp0, Errors.MESSAGE_TOO_LARGE, -1L); + + sender.runOnce(); // receive response 0, should adjust sequences of future batches. 
+ assertFutureFailure(request1, RecordTooLargeException.class); + + assertEquals(1, client.inFlightRequestCount()); + assertEquals(OptionalInt.empty(), transactionManager.lastAckedSequence(tp0)); + + sendIdempotentProducerResponse(1, tp0, Errors.OUT_OF_ORDER_SEQUENCE_NUMBER, -1L); + + sender.runOnce(); // receive response 1 + + assertEquals(OptionalInt.empty(), transactionManager.lastAckedSequence(tp0)); + assertEquals(0, client.inFlightRequestCount()); + + sender.runOnce(); // resend request 1 + + assertEquals(1, client.inFlightRequestCount()); + + assertEquals(OptionalInt.empty(), transactionManager.lastAckedSequence(tp0)); + + sendIdempotentProducerResponse(0, tp0, Errors.NONE, 0L); + sender.runOnce(); // receive response 1 + assertEquals(OptionalInt.of(0), transactionManager.lastAckedSequence(tp0)); + assertEquals(0, client.inFlightRequestCount()); + + assertTrue(request1.isDone()); + assertEquals(0, request2.get().offset()); + } + + @Test + public void testEpochBumpOnOutOfOrderSequenceForNextBatch() throws Exception { + final long producerId = 343434L; + TransactionManager transactionManager = createTransactionManager(); + setupWithTransactionState(transactionManager); + prepareAndReceiveInitProducerId(producerId, Errors.NONE); + assertTrue(transactionManager.hasProducerId()); + + assertEquals(0, transactionManager.sequenceNumber(tp0).longValue()); + + // Send first ProduceRequest with multiple messages. + Future request1 = appendToAccumulator(tp0); + appendToAccumulator(tp0); + sender.runOnce(); + String nodeId = client.requests().peek().destination(); + Node node = new Node(Integer.valueOf(nodeId), "localhost", 0); + assertEquals(1, client.inFlightRequestCount()); + + // make sure the next sequence number accounts for multi-message batches. + assertEquals(2, transactionManager.sequenceNumber(tp0).longValue()); + assertEquals(OptionalInt.empty(), transactionManager.lastAckedSequence(tp0)); + sendIdempotentProducerResponse(0, tp0, Errors.NONE, 0); + + sender.runOnce(); + + // Send second ProduceRequest + Future request2 = appendToAccumulator(tp0); + sender.runOnce(); + assertEquals(1, client.inFlightRequestCount()); + assertEquals(3, transactionManager.sequenceNumber(tp0).longValue()); + assertEquals(OptionalInt.of(1), transactionManager.lastAckedSequence(tp0)); + assertTrue(request1.isDone()); + assertEquals(0, request1.get().offset()); + assertFalse(request2.isDone()); + assertTrue(client.isReady(node, time.milliseconds())); + + // This OutOfOrderSequence triggers an epoch bump since it is returned for the batch succeeding the last acknowledged batch. + sendIdempotentProducerResponse(2, tp0, Errors.OUT_OF_ORDER_SEQUENCE_NUMBER, -1L); + + sender.runOnce(); + sender.runOnce(); + + // epoch should be bumped and sequence numbers reset + assertEquals(1, transactionManager.producerIdAndEpoch().epoch); + assertEquals(1, transactionManager.sequenceNumber(tp0).intValue()); + assertEquals(0, transactionManager.firstInFlightSequence(tp0)); + } + + @Test + public void testEpochBumpOnOutOfOrderSequenceForNextBatchWhenThereIsNoBatchInFlight() throws Exception { + // Verify that partitions without in-flight batches when the producer epoch + // is bumped get their sequence number reset correctly. 
+ final long producerId = 343434L; + TransactionManager transactionManager = createTransactionManager(); + setupWithTransactionState(transactionManager); + + // Init producer id/epoch + prepareAndReceiveInitProducerId(producerId, Errors.NONE); + assertEquals(producerId, transactionManager.producerIdAndEpoch().producerId); + assertEquals(0, transactionManager.producerIdAndEpoch().epoch); + + // Partition 0 - Send first batch + appendToAccumulator(tp0); + sender.runOnce(); + + // Partition 0 - State is lazily initialized + assertPartitionState(transactionManager, tp0, producerId, (short) 0, 1, OptionalInt.empty()); + + // Partition 0 - Successful response + sendIdempotentProducerResponse(0, 0, tp0, Errors.NONE, 0, -1); + sender.runOnce(); + + // Partition 0 - Last ack is updated + assertPartitionState(transactionManager, tp0, producerId, (short) 0, 1, OptionalInt.of(0)); + + // Partition 1 - Send first batch + appendToAccumulator(tp1); + sender.runOnce(); + + // Partition 1 - State is lazily initialized + assertPartitionState(transactionManager, tp1, producerId, (short) 0, 1, OptionalInt.empty()); + + // Partition 1 - Successful response + sendIdempotentProducerResponse(0, 0, tp1, Errors.NONE, 0, -1); + sender.runOnce(); + + // Partition 1 - Last ack is updated + assertPartitionState(transactionManager, tp1, producerId, (short) 0, 1, OptionalInt.of(0)); + + // Partition 0 - Send second batch + appendToAccumulator(tp0); + sender.runOnce(); + + // Partition 0 - Sequence is incremented + assertPartitionState(transactionManager, tp0, producerId, (short) 0, 2, OptionalInt.of(0)); + + // Partition 0 - Failed response with OUT_OF_ORDER_SEQUENCE_NUMBER + sendIdempotentProducerResponse(0, 1, tp0, Errors.OUT_OF_ORDER_SEQUENCE_NUMBER, -1, -1); + sender.runOnce(); // Receive + sender.runOnce(); // Bump epoch & Retry + + // Producer epoch is bumped + assertEquals(1, transactionManager.producerIdAndEpoch().epoch); + + // Partition 0 - State is reset to current producer epoch + assertPartitionState(transactionManager, tp0, producerId, (short) 1, 1, OptionalInt.empty()); + + // Partition 1 - State is not changed + assertPartitionState(transactionManager, tp1, producerId, (short) 0, 1, OptionalInt.of(0)); + assertTrue(transactionManager.hasStaleProducerIdAndEpoch(tp1)); + + // Partition 0 - Successful Response + sendIdempotentProducerResponse(1, 0, tp0, Errors.NONE, 1, -1); + sender.runOnce(); + + // Partition 0 - Last ack is updated + assertPartitionState(transactionManager, tp0, producerId, (short) 1, 1, OptionalInt.of(0)); + + // Partition 1 - Send second batch + appendToAccumulator(tp1); + sender.runOnce(); + + // Partition 1 - Epoch is bumped, sequence is reset and incremented + assertPartitionState(transactionManager, tp1, producerId, (short) 1, 1, OptionalInt.empty()); + assertFalse(transactionManager.hasStaleProducerIdAndEpoch(tp1)); + + // Partition 1 - Successful Response + sendIdempotentProducerResponse(1, 0, tp1, Errors.NONE, 1, -1); + sender.runOnce(); + + // Partition 1 - Last ack is updated + assertPartitionState(transactionManager, tp1, producerId, (short) 1, 1, OptionalInt.of(0)); + } + + @Test + public void testEpochBumpOnOutOfOrderSequenceForNextBatchWhenBatchInFlightFails() throws Exception { + // When a batch failed after the producer epoch is bumped, the sequence number of + // that partition must be reset for any subsequent batches sent. 
+ final long producerId = 343434L; + TransactionManager transactionManager = createTransactionManager(); + + // Retries once + setupWithTransactionState(transactionManager, false, null, true, 1, 0); + + // Init producer id/epoch + prepareAndReceiveInitProducerId(producerId, Errors.NONE); + assertEquals(producerId, transactionManager.producerIdAndEpoch().producerId); + assertEquals(0, transactionManager.producerIdAndEpoch().epoch); + + // Partition 0 - Send first batch + appendToAccumulator(tp0); + sender.runOnce(); + + // Partition 0 - State is lazily initialized + assertPartitionState(transactionManager, tp0, producerId, (short) 0, 1, OptionalInt.empty()); + + // Partition 0 - Successful response + sendIdempotentProducerResponse(0, 0, tp0, Errors.NONE, 0, -1); + sender.runOnce(); + + // Partition 0 - Last ack is updated + assertPartitionState(transactionManager, tp0, producerId, (short) 0, 1, OptionalInt.of(0)); + + // Partition 1 - Send first batch + appendToAccumulator(tp1); + sender.runOnce(); + + // Partition 1 - State is lazily initialized + assertPartitionState(transactionManager, tp1, producerId, (short) 0, 1, OptionalInt.empty()); + + // Partition 1 - Successful response + sendIdempotentProducerResponse(0, 0, tp1, Errors.NONE, 0, -1); + sender.runOnce(); + + // Partition 1 - Last ack is updated + assertPartitionState(transactionManager, tp1, producerId, (short) 0, 1, OptionalInt.of(0)); + + // Partition 0 - Send second batch + appendToAccumulator(tp0); + sender.runOnce(); + + // Partition 0 - Sequence is incremented + assertPartitionState(transactionManager, tp0, producerId, (short) 0, 2, OptionalInt.of(0)); + + // Partition 1 - Send second batch + appendToAccumulator(tp1); + sender.runOnce(); + + // Partition 1 - Sequence is incremented + assertPartitionState(transactionManager, tp1, producerId, (short) 0, 2, OptionalInt.of(0)); + + // Partition 0 - Failed response with OUT_OF_ORDER_SEQUENCE_NUMBER + sendIdempotentProducerResponse(0, 1, tp0, Errors.OUT_OF_ORDER_SEQUENCE_NUMBER, -1, -1); + sender.runOnce(); // Receive + sender.runOnce(); // Bump epoch & Retry + + // Producer epoch is bumped + assertEquals(1, transactionManager.producerIdAndEpoch().epoch); + + // Partition 0 - State is reset to current producer epoch + assertPartitionState(transactionManager, tp0, producerId, (short) 1, 1, OptionalInt.empty()); + + // Partition 1 - State is not changed. The epoch will be lazily bumped when all in-flight + // batches are completed + assertPartitionState(transactionManager, tp1, producerId, (short) 0, 2, OptionalInt.of(0)); + assertTrue(transactionManager.hasStaleProducerIdAndEpoch(tp1)); + + // Partition 1 - Failed response with NOT_LEADER_OR_FOLLOWER + sendIdempotentProducerResponse(0, 1, tp1, Errors.NOT_LEADER_OR_FOLLOWER, -1, -1); + sender.runOnce(); // Receive & Retry + + // Partition 1 - State is not changed. + assertPartitionState(transactionManager, tp1, producerId, (short) 0, 2, OptionalInt.of(0)); + assertTrue(transactionManager.hasStaleProducerIdAndEpoch(tp1)); + + // Partition 0 - Successful Response + sendIdempotentProducerResponse(1, 0, tp0, Errors.NONE, 1, -1); + sender.runOnce(); + + // Partition 0 - Last ack is updated + assertPartitionState(transactionManager, tp0, producerId, (short) 1, 1, OptionalInt.of(0)); + + // Partition 1 - Failed response with NOT_LEADER_OR_FOLLOWER + sendIdempotentProducerResponse(0, 1, tp1, Errors.NOT_LEADER_OR_FOLLOWER, -1, -1); + sender.runOnce(); // Receive & Fail the batch (retries exhausted) + + // Partition 1 - State is not changed. 
It will be lazily updated when the next batch is sent. + assertPartitionState(transactionManager, tp1, producerId, (short) 0, 2, OptionalInt.of(0)); + assertTrue(transactionManager.hasStaleProducerIdAndEpoch(tp1)); + + // Partition 1 - Send third batch + appendToAccumulator(tp1); + sender.runOnce(); + + // Partition 1 - Epoch is bumped, sequence is reset + assertPartitionState(transactionManager, tp1, producerId, (short) 1, 1, OptionalInt.empty()); + assertFalse(transactionManager.hasStaleProducerIdAndEpoch(tp1)); + + // Partition 1 - Successful Response + sendIdempotentProducerResponse(1, 0, tp1, Errors.NONE, 0, -1); + sender.runOnce(); + + // Partition 1 - Last ack is updated + assertPartitionState(transactionManager, tp1, producerId, (short) 1, 1, OptionalInt.of(0)); + + // Partition 0 - Send third batch + appendToAccumulator(tp0); + sender.runOnce(); + + // Partition 0 - Sequence is incremented + assertPartitionState(transactionManager, tp0, producerId, (short) 1, 2, OptionalInt.of(0)); + + // Partition 0 - Successful Response + sendIdempotentProducerResponse(1, 1, tp0, Errors.NONE, 0, -1); + sender.runOnce(); + + // Partition 0 - Last ack is updated + assertPartitionState(transactionManager, tp0, producerId, (short) 1, 2, OptionalInt.of(1)); + } + + private void assertPartitionState( + TransactionManager transactionManager, + TopicPartition tp, + long expectedProducerId, + short expectedProducerEpoch, + long expectedSequenceValue, + OptionalInt expectedLastAckedSequence + ) { + assertEquals(expectedProducerId, transactionManager.producerIdAndEpoch(tp).producerId, "Producer Id:"); + assertEquals(expectedProducerEpoch, transactionManager.producerIdAndEpoch(tp).epoch, "Producer Epoch:"); + assertEquals(expectedSequenceValue, transactionManager.sequenceNumber(tp).longValue(), "Seq Number:"); + assertEquals(expectedLastAckedSequence, transactionManager.lastAckedSequence(tp), "Last Acked Seq Number:"); + } + + @Test + public void testCorrectHandlingOfOutOfOrderResponses() throws Exception { + final long producerId = 343434L; + TransactionManager transactionManager = createTransactionManager(); + setupWithTransactionState(transactionManager); + prepareAndReceiveInitProducerId(producerId, Errors.NONE); + assertTrue(transactionManager.hasProducerId()); + assertEquals(0, transactionManager.sequenceNumber(tp0).longValue()); + + // Send first ProduceRequest + Future request1 = appendToAccumulator(tp0); + sender.runOnce(); + String nodeId = client.requests().peek().destination(); + Node node = new Node(Integer.valueOf(nodeId), "localhost", 0); + assertEquals(1, client.inFlightRequestCount()); + assertEquals(1, transactionManager.sequenceNumber(tp0).longValue()); + assertEquals(OptionalInt.empty(), transactionManager.lastAckedSequence(tp0)); + + // Send second ProduceRequest + Future request2 = appendToAccumulator(tp0); + sender.runOnce(); + assertEquals(2, client.inFlightRequestCount()); + assertEquals(2, transactionManager.sequenceNumber(tp0).longValue()); + assertEquals(OptionalInt.empty(), transactionManager.lastAckedSequence(tp0)); + assertFalse(request1.isDone()); + assertFalse(request2.isDone()); + assertTrue(client.isReady(node, time.milliseconds())); + + ClientRequest firstClientRequest = client.requests().peek(); + ClientRequest secondClientRequest = (ClientRequest) client.requests().toArray()[1]; + + client.respondToRequest(secondClientRequest, produceResponse(tp0, -1, Errors.OUT_OF_ORDER_SEQUENCE_NUMBER, -1)); + + sender.runOnce(); // receive response 1 + Deque queuedBatches = 
accumulator.batches().get(tp0); + + // Make sure that we are queueing the second batch first. + assertEquals(1, queuedBatches.size()); + assertEquals(1, queuedBatches.peekFirst().baseSequence()); + assertEquals(1, client.inFlightRequestCount()); + assertEquals(OptionalInt.empty(), transactionManager.lastAckedSequence(tp0)); + + client.respondToRequest(firstClientRequest, produceResponse(tp0, -1, Errors.NOT_LEADER_OR_FOLLOWER, -1)); + + sender.runOnce(); // receive response 0 + + // Make sure we requeued both batches in the correct order. + assertEquals(2, queuedBatches.size()); + assertEquals(0, queuedBatches.peekFirst().baseSequence()); + assertEquals(1, queuedBatches.peekLast().baseSequence()); + assertEquals(OptionalInt.empty(), transactionManager.lastAckedSequence(tp0)); + assertEquals(0, client.inFlightRequestCount()); + assertFalse(request1.isDone()); + assertFalse(request2.isDone()); + + sender.runOnce(); // send request 0 + assertEquals(1, client.inFlightRequestCount()); + sender.runOnce(); // don't do anything, only one inflight allowed once we are retrying. + + assertEquals(1, client.inFlightRequestCount()); + assertEquals(OptionalInt.empty(), transactionManager.lastAckedSequence(tp0)); + + // Make sure that the requests are sent in order, even though the previous responses were not in order. + sendIdempotentProducerResponse(0, tp0, Errors.NONE, 0L); + sender.runOnce(); // receive response 0 + assertEquals(OptionalInt.of(0), transactionManager.lastAckedSequence(tp0)); + assertEquals(0, client.inFlightRequestCount()); + assertTrue(request1.isDone()); + assertEquals(0, request1.get().offset()); + + sender.runOnce(); // send request 1 + assertEquals(1, client.inFlightRequestCount()); + sendIdempotentProducerResponse(1, tp0, Errors.NONE, 1L); + sender.runOnce(); // receive response 1 + + assertFalse(client.hasInFlightRequests()); + assertEquals(OptionalInt.of(1), transactionManager.lastAckedSequence(tp0)); + assertTrue(request2.isDone()); + assertEquals(1, request2.get().offset()); + } + + @Test + public void testCorrectHandlingOfOutOfOrderResponsesWhenSecondSucceeds() throws Exception { + final long producerId = 343434L; + TransactionManager transactionManager = createTransactionManager(); + setupWithTransactionState(transactionManager); + prepareAndReceiveInitProducerId(producerId, Errors.NONE); + assertTrue(transactionManager.hasProducerId()); + + assertEquals(0, transactionManager.sequenceNumber(tp0).longValue()); + + // Send first ProduceRequest + Future request1 = appendToAccumulator(tp0); + sender.runOnce(); + String nodeId = client.requests().peek().destination(); + Node node = new Node(Integer.valueOf(nodeId), "localhost", 0); + assertEquals(1, client.inFlightRequestCount()); + + // Send second ProduceRequest + Future request2 = appendToAccumulator(tp0); + sender.runOnce(); + assertEquals(2, client.inFlightRequestCount()); + assertFalse(request1.isDone()); + assertFalse(request2.isDone()); + assertTrue(client.isReady(node, time.milliseconds())); + + ClientRequest firstClientRequest = client.requests().peek(); + ClientRequest secondClientRequest = (ClientRequest) client.requests().toArray()[1]; + + client.respondToRequest(secondClientRequest, produceResponse(tp0, 1, Errors.NONE, 1)); + + sender.runOnce(); // receive response 1 + assertTrue(request2.isDone()); + assertEquals(1, request2.get().offset()); + assertFalse(request1.isDone()); + Deque queuedBatches = accumulator.batches().get(tp0); + + assertEquals(0, queuedBatches.size()); + assertEquals(1, 
client.inFlightRequestCount()); + assertEquals(OptionalInt.of(1), transactionManager.lastAckedSequence(tp0)); + + client.respondToRequest(firstClientRequest, produceResponse(tp0, -1, Errors.REQUEST_TIMED_OUT, -1)); + + sender.runOnce(); // receive response 0 + + // Make sure we requeued both batches in the correct order. + assertEquals(1, queuedBatches.size()); + assertEquals(0, queuedBatches.peekFirst().baseSequence()); + assertEquals(OptionalInt.of(1), transactionManager.lastAckedSequence(tp0)); + assertEquals(0, client.inFlightRequestCount()); + + sender.runOnce(); // resend request 0 + assertEquals(1, client.inFlightRequestCount()); + + assertEquals(1, client.inFlightRequestCount()); + assertEquals(OptionalInt.of(1), transactionManager.lastAckedSequence(tp0)); + + // Make sure we handle the out of order successful responses correctly. + sendIdempotentProducerResponse(0, tp0, Errors.NONE, 0L); + sender.runOnce(); // receive response 0 + assertEquals(0, queuedBatches.size()); + assertEquals(OptionalInt.of(1), transactionManager.lastAckedSequence(tp0)); + assertEquals(0, client.inFlightRequestCount()); + + assertFalse(client.hasInFlightRequests()); + assertTrue(request1.isDone()); + assertEquals(0, request1.get().offset()); + } + + @Test + public void testExpiryOfUnsentBatchesShouldNotCauseUnresolvedSequences() throws Exception { + final long producerId = 343434L; + TransactionManager transactionManager = createTransactionManager(); + setupWithTransactionState(transactionManager); + prepareAndReceiveInitProducerId(producerId, Errors.NONE); + assertTrue(transactionManager.hasProducerId()); + + assertEquals(0, transactionManager.sequenceNumber(tp0).longValue()); + + // Send first ProduceRequest + Future request1 = appendToAccumulator(tp0, 0L, "key", "value"); + Node node = metadata.fetch().nodes().get(0); + time.sleep(10000L); + client.disconnect(node.idString()); + client.backoff(node, 10); + + sender.runOnce(); + + assertFutureFailure(request1, TimeoutException.class); + assertFalse(transactionManager.hasUnresolvedSequence(tp0)); + } + + @Test + public void testExpiryOfFirstBatchShouldNotCauseUnresolvedSequencesIfFutureBatchesSucceed() throws Exception { + final long producerId = 343434L; + TransactionManager transactionManager = createTransactionManager(); + setupWithTransactionState(transactionManager, false, null); + prepareAndReceiveInitProducerId(producerId, Errors.NONE); + assertTrue(transactionManager.hasProducerId()); + assertEquals(0, transactionManager.sequenceNumber(tp0).longValue()); + + // Send first ProduceRequest + Future request1 = appendToAccumulator(tp0); + sender.runOnce(); // send request + // We separate the two appends by 1 second so that the two batches + // don't expire at the same time. + time.sleep(1000L); + + Future request2 = appendToAccumulator(tp0); + sender.runOnce(); // send request + assertEquals(2, client.inFlightRequestCount()); + assertEquals(2, sender.inFlightBatches(tp0).size()); + + sendIdempotentProducerResponse(0, tp0, Errors.REQUEST_TIMED_OUT, -1); + sender.runOnce(); // receive first response + assertEquals(1, sender.inFlightBatches(tp0).size()); + + Node node = metadata.fetch().nodes().get(0); + // We add 600 millis to expire the first batch but not the second. + // Note deliveryTimeoutMs is 1500. + time.sleep(600L); + client.disconnect(node.idString()); + client.backoff(node, 10); + + sender.runOnce(); // now expire the first batch. 
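+ // The expired batch fails with a TimeoutException and leaves an unresolved sequence on tp0; it can only be cleared once the remaining batch for this partition completes.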
+ assertFutureFailure(request1, TimeoutException.class); + assertTrue(transactionManager.hasUnresolvedSequence(tp0)); + assertEquals(0, sender.inFlightBatches(tp0).size()); + + // let's enqueue another batch, which should not be dequeued until the unresolved state is clear. + Future request3 = appendToAccumulator(tp0); + time.sleep(20); + assertFalse(request2.isDone()); + + sender.runOnce(); // send second request + sendIdempotentProducerResponse(1, tp0, Errors.NONE, 1); + assertEquals(1, sender.inFlightBatches(tp0).size()); + + sender.runOnce(); // receive second response, the third request shouldn't be sent since we are in an unresolved state. + assertTrue(request2.isDone()); + assertEquals(1, request2.get().offset()); + assertEquals(0, sender.inFlightBatches(tp0).size()); + + Deque batches = accumulator.batches().get(tp0); + assertEquals(1, batches.size()); + assertFalse(batches.peekFirst().hasSequence()); + assertFalse(client.hasInFlightRequests()); + assertEquals(2L, transactionManager.sequenceNumber(tp0).longValue()); + assertTrue(transactionManager.hasUnresolvedSequence(tp0)); + + sender.runOnce(); // clear the unresolved state, send the pending request. + assertFalse(transactionManager.hasUnresolvedSequence(tp0)); + assertTrue(transactionManager.hasProducerId()); + assertEquals(0, batches.size()); + assertEquals(1, client.inFlightRequestCount()); + assertFalse(request3.isDone()); + assertEquals(1, sender.inFlightBatches(tp0).size()); + } + + @Test + public void testExpiryOfFirstBatchShouldCauseEpochBumpIfFutureBatchesFail() throws Exception { + final long producerId = 343434L; + TransactionManager transactionManager = createTransactionManager(); + setupWithTransactionState(transactionManager); + prepareAndReceiveInitProducerId(producerId, Errors.NONE); + assertTrue(transactionManager.hasProducerId()); + assertEquals(0, transactionManager.sequenceNumber(tp0).longValue()); + + // Send first ProduceRequest + Future request1 = appendToAccumulator(tp0); + sender.runOnce(); // send request + + time.sleep(1000L); + Future request2 = appendToAccumulator(tp0); + sender.runOnce(); // send request + + assertEquals(2, client.inFlightRequestCount()); + + sendIdempotentProducerResponse(0, tp0, Errors.NOT_LEADER_OR_FOLLOWER, -1); + sender.runOnce(); // receive first response + + Node node = metadata.fetch().nodes().get(0); + time.sleep(1000L); + client.disconnect(node.idString()); + client.backoff(node, 10); + + sender.runOnce(); // now expire the first batch. + assertFutureFailure(request1, TimeoutException.class); + assertTrue(transactionManager.hasUnresolvedSequence(tp0)); + // let's enqueue another batch, which should not be dequeued until the unresolved state is clear. + appendToAccumulator(tp0); + + time.sleep(20); + assertFalse(request2.isDone()); + sender.runOnce(); // send second request + sendIdempotentProducerResponse(1, tp0, Errors.OUT_OF_ORDER_SEQUENCE_NUMBER, 1); + sender.runOnce(); // receive second response, the third request shouldn't be sent since we are in an unresolved state. 
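+ // With the remaining in-flight batch also failing, no successful response can resolve the sequence, so the producer falls back to bumping the epoch.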
+ + Deque batches = accumulator.batches().get(tp0); + + // The epoch should be bumped and the second request should be requeued + assertEquals(2, batches.size()); + + sender.runOnce(); + assertEquals((short) 1, transactionManager.producerIdAndEpoch().epoch); + assertEquals(1, transactionManager.sequenceNumber(tp0).longValue()); + assertFalse(transactionManager.hasUnresolvedSequence(tp0)); + } + + @Test + public void testUnresolvedSequencesAreNotFatal() throws Exception { + ProducerIdAndEpoch producerIdAndEpoch = new ProducerIdAndEpoch(123456L, (short) 0); + apiVersions.update("0", NodeApiVersions.create(ApiKeys.INIT_PRODUCER_ID.id, (short) 0, (short) 3)); + TransactionManager txnManager = new TransactionManager(logContext, "testUnresolvedSeq", 60000, 100, apiVersions); + + setupWithTransactionState(txnManager); + doInitTransactions(txnManager, producerIdAndEpoch); + + txnManager.beginTransaction(); + txnManager.failIfNotReadyForSend(); + txnManager.maybeAddPartitionToTransaction(tp0); + client.prepareResponse(new AddPartitionsToTxnResponse(0, Collections.singletonMap(tp0, Errors.NONE))); + sender.runOnce(); + + // Send first ProduceRequest + Future request1 = appendToAccumulator(tp0); + sender.runOnce(); // send request + + time.sleep(1000L); + appendToAccumulator(tp0); + sender.runOnce(); // send request + + assertEquals(2, client.inFlightRequestCount()); + + sendIdempotentProducerResponse(0, tp0, Errors.NOT_LEADER_OR_FOLLOWER, -1); + sender.runOnce(); // receive first response + + Node node = metadata.fetch().nodes().get(0); + time.sleep(1000L); + client.disconnect(node.idString()); + client.backoff(node, 10); + + sender.runOnce(); // now expire the first batch. + assertFutureFailure(request1, TimeoutException.class); + assertTrue(txnManager.hasUnresolvedSequence(tp0)); + + // Loop once and confirm that the transaction manager does not enter a fatal error state + sender.runOnce(); + assertTrue(txnManager.hasAbortableError()); + } + + @Test + public void testExpiryOfAllSentBatchesShouldCauseUnresolvedSequences() throws Exception { + final long producerId = 343434L; + TransactionManager transactionManager = createTransactionManager(); + setupWithTransactionState(transactionManager); + prepareAndReceiveInitProducerId(producerId, Errors.NONE); + assertTrue(transactionManager.hasProducerId()); + + assertEquals(0, transactionManager.sequenceNumber(tp0).longValue()); + + // Send first ProduceRequest + Future request1 = appendToAccumulator(tp0, 0L, "key", "value"); + sender.runOnce(); // send request + sendIdempotentProducerResponse(0, tp0, Errors.NOT_LEADER_OR_FOLLOWER, -1); + + sender.runOnce(); // receive response + assertEquals(1L, transactionManager.sequenceNumber(tp0).longValue()); + + Node node = metadata.fetch().nodes().get(0); + time.sleep(15000L); + client.disconnect(node.idString()); + client.backoff(node, 10); + + sender.runOnce(); // now expire the batch. 
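+ // Every batch sent for this partition has now expired, so no outstanding request can resolve the sequence.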
+ + assertFutureFailure(request1, TimeoutException.class); + assertTrue(transactionManager.hasUnresolvedSequence(tp0)); + assertFalse(client.hasInFlightRequests()); + Deque batches = accumulator.batches().get(tp0); + assertEquals(0, batches.size()); + assertEquals(producerId, transactionManager.producerIdAndEpoch().producerId); + + // In the next run loop, we bump the epoch and clear the unresolved sequences + sender.runOnce(); + assertEquals(1, transactionManager.producerIdAndEpoch().epoch); + assertFalse(transactionManager.hasUnresolvedSequence(tp0)); + } + + @Test + public void testResetOfProducerStateShouldAllowQueuedBatchesToDrain() throws Exception { + final long producerId = 343434L; + TransactionManager transactionManager = createTransactionManager(); + setupWithTransactionState(transactionManager); + prepareAndReceiveInitProducerId(producerId, Short.MAX_VALUE, Errors.NONE); + assertTrue(transactionManager.hasProducerId()); + + int maxRetries = 10; + Metrics m = new Metrics(); + SenderMetricsRegistry senderMetrics = new SenderMetricsRegistry(m); + + Sender sender = new Sender(logContext, client, metadata, this.accumulator, true, MAX_REQUEST_SIZE, ACKS_ALL, maxRetries, + senderMetrics, time, REQUEST_TIMEOUT, RETRY_BACKOFF_MS, transactionManager, apiVersions); + + appendToAccumulator(tp0); // failed response + Future successfulResponse = appendToAccumulator(tp1); + sender.runOnce(); // connect and send. + + assertEquals(1, client.inFlightRequestCount()); + + Map responses = new LinkedHashMap<>(); + responses.put(tp1, new OffsetAndError(-1, Errors.NOT_LEADER_OR_FOLLOWER)); + responses.put(tp0, new OffsetAndError(-1, Errors.OUT_OF_ORDER_SEQUENCE_NUMBER)); + client.respond(produceResponse(responses)); + + sender.runOnce(); // trigger epoch bump + prepareAndReceiveInitProducerId(producerId + 1, Errors.NONE); // also send request to tp1 + sender.runOnce(); // reset producer ID because epoch is maxed out + assertEquals(producerId + 1, transactionManager.producerIdAndEpoch().producerId); + + assertFalse(successfulResponse.isDone()); + client.respond(produceResponse(tp1, 10, Errors.NONE, -1)); + sender.runOnce(); + + assertTrue(successfulResponse.isDone()); + assertEquals(10, successfulResponse.get().offset()); + + // The epoch and the sequence are updated when the next batch is sent. + assertEquals(1, transactionManager.sequenceNumber(tp1).longValue()); + } + + @Test + public void testCloseWithProducerIdReset() throws Exception { + final long producerId = 343434L; + TransactionManager transactionManager = createTransactionManager(); + setupWithTransactionState(transactionManager); + prepareAndReceiveInitProducerId(producerId, Short.MAX_VALUE, Errors.NONE); + assertTrue(transactionManager.hasProducerId()); + + Metrics m = new Metrics(); + SenderMetricsRegistry senderMetrics = new SenderMetricsRegistry(m); + + Sender sender = new Sender(logContext, client, metadata, this.accumulator, true, MAX_REQUEST_SIZE, ACKS_ALL, 10, + senderMetrics, time, REQUEST_TIMEOUT, RETRY_BACKOFF_MS, transactionManager, apiVersions); + + appendToAccumulator(tp0); // failed response + appendToAccumulator(tp1); // success response + sender.runOnce(); // connect and send. 
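+ // Both batches were drained into a single request; since the epoch is already at Short.MAX_VALUE, the upcoming OUT_OF_ORDER_SEQUENCE_NUMBER error forces a producer id reset instead of an epoch bump.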
+ + assertEquals(1, client.inFlightRequestCount()); + + Map responses = new LinkedHashMap<>(); + responses.put(tp1, new OffsetAndError(-1, Errors.NOT_LEADER_OR_FOLLOWER)); + responses.put(tp0, new OffsetAndError(-1, Errors.OUT_OF_ORDER_SEQUENCE_NUMBER)); + client.respond(produceResponse(responses)); + sender.initiateClose(); // initiate close + sender.runOnce(); // out of order sequence error triggers producer ID reset because epoch is maxed out + + TestUtils.waitForCondition(() -> { + prepareInitProducerResponse(Errors.NONE, producerId + 1, (short) 1); + sender.runOnce(); + return !accumulator.hasUndrained(); + }, 5000, "Failed to drain batches"); + } + + @Test + public void testForceCloseWithProducerIdReset() throws Exception { + TransactionManager transactionManager = createTransactionManager(); + setupWithTransactionState(transactionManager); + prepareAndReceiveInitProducerId(1L, Short.MAX_VALUE, Errors.NONE); + assertTrue(transactionManager.hasProducerId()); + + Metrics m = new Metrics(); + SenderMetricsRegistry senderMetrics = new SenderMetricsRegistry(m); + + Sender sender = new Sender(logContext, client, metadata, this.accumulator, true, MAX_REQUEST_SIZE, ACKS_ALL, 10, + senderMetrics, time, REQUEST_TIMEOUT, RETRY_BACKOFF_MS, transactionManager, apiVersions); + + Future failedResponse = appendToAccumulator(tp0); + Future successfulResponse = appendToAccumulator(tp1); + sender.runOnce(); // connect and send. + + assertEquals(1, client.inFlightRequestCount()); + + Map responses = new LinkedHashMap<>(); + responses.put(tp1, new OffsetAndError(-1, Errors.NOT_LEADER_OR_FOLLOWER)); + responses.put(tp0, new OffsetAndError(-1, Errors.OUT_OF_ORDER_SEQUENCE_NUMBER)); + client.respond(produceResponse(responses)); + sender.runOnce(); // out of order sequence error triggers producer ID reset because epoch is maxed out + sender.forceClose(); // initiate force close + sender.runOnce(); // this should not block + sender.run(); // run main loop to test forceClose flag + assertFalse(accumulator.hasUndrained(), "Pending batches are not aborted."); + assertTrue(successfulResponse.isDone()); + } + + @Test + public void testBatchesDrainedWithOldProducerIdShouldSucceedOnSubsequentRetry() throws Exception { + final long producerId = 343434L; + TransactionManager transactionManager = createTransactionManager(); + setupWithTransactionState(transactionManager); + prepareAndReceiveInitProducerId(producerId, Errors.NONE); + assertTrue(transactionManager.hasProducerId()); + + int maxRetries = 10; + Metrics m = new Metrics(); + SenderMetricsRegistry senderMetrics = new SenderMetricsRegistry(m); + + Sender sender = new Sender(logContext, client, metadata, this.accumulator, true, MAX_REQUEST_SIZE, ACKS_ALL, maxRetries, + senderMetrics, time, REQUEST_TIMEOUT, RETRY_BACKOFF_MS, transactionManager, apiVersions); + + Future outOfOrderResponse = appendToAccumulator(tp0); + Future successfulResponse = appendToAccumulator(tp1); + sender.runOnce(); // connect. + sender.runOnce(); // send. 
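+ // Both batches are drained into one request under the original producer id and epoch before any error comes back.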
+ + assertEquals(1, client.inFlightRequestCount()); + + Map responses = new LinkedHashMap<>(); + responses.put(tp1, new OffsetAndError(-1, Errors.NOT_LEADER_OR_FOLLOWER)); + responses.put(tp0, new OffsetAndError(-1, Errors.OUT_OF_ORDER_SEQUENCE_NUMBER)); + client.respond(produceResponse(responses)); + sender.runOnce(); + assertFalse(outOfOrderResponse.isDone()); + + sender.runOnce(); // bump epoch and send request to tp1 with the old producerId + assertEquals(1, transactionManager.producerIdAndEpoch().epoch); + + assertFalse(successfulResponse.isDone()); + // The response comes back with a retriable error. + client.respond(produceResponse(tp1, 0, Errors.NOT_LEADER_OR_FOLLOWER, -1)); + sender.runOnce(); + + // The batch is re-enqueued for retry, so the future is still not complete. + assertFalse(successfulResponse.isDone()); + sender.runOnce(); // retry one more time + client.respond(produceResponse(tp1, 0, Errors.NONE, -1)); + sender.runOnce(); + assertTrue(successfulResponse.isDone()); + // epoch of partition is bumped and sequence is reset when the next batch is sent + assertEquals(1, transactionManager.sequenceNumber(tp1).intValue()); + } + + @Test + public void testCorrectHandlingOfDuplicateSequenceError() throws Exception { + final long producerId = 343434L; + TransactionManager transactionManager = createTransactionManager(); + setupWithTransactionState(transactionManager); + prepareAndReceiveInitProducerId(producerId, Errors.NONE); + assertTrue(transactionManager.hasProducerId()); + + assertEquals(0, transactionManager.sequenceNumber(tp0).longValue()); + + // Send first ProduceRequest + Future request1 = appendToAccumulator(tp0); + sender.runOnce(); + String nodeId = client.requests().peek().destination(); + Node node = new Node(Integer.valueOf(nodeId), "localhost", 0); + assertEquals(1, client.inFlightRequestCount()); + assertEquals(1, transactionManager.sequenceNumber(tp0).longValue()); + assertEquals(OptionalInt.empty(), transactionManager.lastAckedSequence(tp0)); + + // Send second ProduceRequest + Future request2 = appendToAccumulator(tp0); + sender.runOnce(); + assertEquals(2, client.inFlightRequestCount()); + assertEquals(2, transactionManager.sequenceNumber(tp0).longValue()); + assertEquals(OptionalInt.empty(), transactionManager.lastAckedSequence(tp0)); + assertFalse(request1.isDone()); + assertFalse(request2.isDone()); + assertTrue(client.isReady(node, time.milliseconds())); + + ClientRequest firstClientRequest = client.requests().peek(); + ClientRequest secondClientRequest = (ClientRequest) client.requests().toArray()[1]; + + client.respondToRequest(secondClientRequest, produceResponse(tp0, 1000, Errors.NONE, 0)); + + sender.runOnce(); // receive response 1 + + assertEquals(OptionalLong.of(1000), transactionManager.lastAckedOffset(tp0)); + assertEquals(OptionalInt.of(1), transactionManager.lastAckedSequence(tp0)); + + client.respondToRequest(firstClientRequest, produceResponse(tp0, ProduceResponse.INVALID_OFFSET, Errors.DUPLICATE_SEQUENCE_NUMBER, 0)); + + sender.runOnce(); // receive response 0 + + // Make sure that the last ack'd sequence doesn't change.
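+ // The duplicate is treated as already acknowledged: the batch's future completes, but without a valid offset.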
+ assertEquals(OptionalInt.of(1), transactionManager.lastAckedSequence(tp0)); + assertEquals(OptionalLong.of(1000), transactionManager.lastAckedOffset(tp0)); + assertFalse(client.hasInFlightRequests()); + + RecordMetadata unknownMetadata = request1.get(); + assertFalse(unknownMetadata.hasOffset()); + assertEquals(-1L, unknownMetadata.offset()); + } + + @Test + public void testTransactionalUnknownProducerHandlingWhenRetentionLimitReached() throws Exception { + final long producerId = 343434L; + TransactionManager transactionManager = new TransactionManager(logContext, "testUnresolvedSeq", 60000, 100, apiVersions); + + setupWithTransactionState(transactionManager); + doInitTransactions(transactionManager, new ProducerIdAndEpoch(producerId, (short) 0)); + assertTrue(transactionManager.hasProducerId()); + + transactionManager.maybeAddPartitionToTransaction(tp0); + client.prepareResponse(new AddPartitionsToTxnResponse(0, Collections.singletonMap(tp0, Errors.NONE))); + sender.runOnce(); // Receive AddPartitions response + + assertEquals(0, transactionManager.sequenceNumber(tp0).longValue()); + + // Send first ProduceRequest + Future request1 = appendToAccumulator(tp0); + sender.runOnce(); + + assertEquals(1, client.inFlightRequestCount()); + assertEquals(1, transactionManager.sequenceNumber(tp0).longValue()); + assertEquals(OptionalInt.empty(), transactionManager.lastAckedSequence(tp0)); + + sendIdempotentProducerResponse(0, tp0, Errors.NONE, 1000L, 10L); + + sender.runOnce(); // receive the response. + + assertTrue(request1.isDone()); + assertEquals(1000L, request1.get().offset()); + assertEquals(OptionalInt.of(0), transactionManager.lastAckedSequence(tp0)); + assertEquals(OptionalLong.of(1000L), transactionManager.lastAckedOffset(tp0)); + + // Send second ProduceRequest, a single batch with 2 records. + appendToAccumulator(tp0); + Future request2 = appendToAccumulator(tp0); + sender.runOnce(); + assertEquals(3, transactionManager.sequenceNumber(tp0).longValue()); + assertEquals(OptionalInt.of(0), transactionManager.lastAckedSequence(tp0)); + + assertFalse(request2.isDone()); + + sendIdempotentProducerResponse(1, tp0, Errors.UNKNOWN_PRODUCER_ID, -1L, 1010L); + sender.runOnce(); // receive response 0, should be retried since the logStartOffset > lastAckedOffset. + + // We should have reset the sequence number state of the partition because the state was lost on the broker. + assertEquals(OptionalInt.empty(), transactionManager.lastAckedSequence(tp0)); + assertEquals(2, transactionManager.sequenceNumber(tp0).longValue()); + assertFalse(request2.isDone()); + assertFalse(client.hasInFlightRequests()); + + sender.runOnce(); // should retry request 1 + + // resend the request. Note that the expected sequence is 0, since we have lost producer state on the broker. 
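+ // Once the retried batch is accepted, the last acked sequence and offset advance to cover both records in the batch.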
+ sendIdempotentProducerResponse(0, tp0, Errors.NONE, 1011L, 1010L); + sender.runOnce(); // receive response 1 + assertEquals(OptionalInt.of(1), transactionManager.lastAckedSequence(tp0)); + assertEquals(2, transactionManager.sequenceNumber(tp0).longValue()); + assertFalse(client.hasInFlightRequests()); + assertTrue(request2.isDone()); + assertEquals(1012L, request2.get().offset()); + assertEquals(OptionalLong.of(1012L), transactionManager.lastAckedOffset(tp0)); + } + + @Test + public void testIdempotentUnknownProducerHandlingWhenRetentionLimitReached() throws Exception { + final long producerId = 343434L; + TransactionManager transactionManager = createTransactionManager(); + setupWithTransactionState(transactionManager); + prepareAndReceiveInitProducerId(producerId, Errors.NONE); + assertTrue(transactionManager.hasProducerId()); + + assertEquals(0, transactionManager.sequenceNumber(tp0).longValue()); + + // Send first ProduceRequest + Future request1 = appendToAccumulator(tp0); + sender.runOnce(); + + assertEquals(1, client.inFlightRequestCount()); + assertEquals(1, transactionManager.sequenceNumber(tp0).longValue()); + assertEquals(OptionalInt.empty(), transactionManager.lastAckedSequence(tp0)); + + sendIdempotentProducerResponse(0, tp0, Errors.NONE, 1000L, 10L); + + sender.runOnce(); // receive the response. + + assertTrue(request1.isDone()); + assertEquals(1000L, request1.get().offset()); + assertEquals(OptionalInt.of(0), transactionManager.lastAckedSequence(tp0)); + assertEquals(OptionalLong.of(1000L), transactionManager.lastAckedOffset(tp0)); + + // Send second ProduceRequest, a single batch with 2 records. + appendToAccumulator(tp0); + Future request2 = appendToAccumulator(tp0); + sender.runOnce(); + assertEquals(3, transactionManager.sequenceNumber(tp0).longValue()); + assertEquals(OptionalInt.of(0), transactionManager.lastAckedSequence(tp0)); + + assertFalse(request2.isDone()); + + sendIdempotentProducerResponse(1, tp0, Errors.UNKNOWN_PRODUCER_ID, -1L, 1010L); + sender.runOnce(); // receive response 0, should be retried since the logStartOffset > lastAckedOffset. + sender.runOnce(); // bump epoch and retry request + + // We should have reset the sequence number state of the partition because the state was lost on the broker. + assertEquals(OptionalInt.empty(), transactionManager.lastAckedSequence(tp0)); + assertEquals(2, transactionManager.sequenceNumber(tp0).longValue()); + assertFalse(request2.isDone()); + assertTrue(client.hasInFlightRequests()); + assertEquals((short) 1, transactionManager.producerIdAndEpoch().epoch); + + // resend the request. Note that the expected sequence is 0, since we have lost producer state on the broker. 
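+ // Unlike the transactional case above, the idempotent producer has also bumped its epoch (now 1) before resending.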
+ sendIdempotentProducerResponse(0, tp0, Errors.NONE, 1011L, 1010L); + sender.runOnce(); // receive response 1 + assertEquals(OptionalInt.of(1), transactionManager.lastAckedSequence(tp0)); + assertEquals(2, transactionManager.sequenceNumber(tp0).longValue()); + assertFalse(client.hasInFlightRequests()); + assertTrue(request2.isDone()); + assertEquals(1012L, request2.get().offset()); + assertEquals(OptionalLong.of(1012L), transactionManager.lastAckedOffset(tp0)); + } + + @Test + public void testUnknownProducerErrorShouldBeRetriedWhenLogStartOffsetIsUnknown() throws Exception { + final long producerId = 343434L; + TransactionManager transactionManager = createTransactionManager(); + setupWithTransactionState(transactionManager); + prepareAndReceiveInitProducerId(producerId, Errors.NONE); + assertTrue(transactionManager.hasProducerId()); + + assertEquals(0, transactionManager.sequenceNumber(tp0).longValue()); + + // Send first ProduceRequest + Future request1 = appendToAccumulator(tp0); + sender.runOnce(); + + assertEquals(1, client.inFlightRequestCount()); + assertEquals(1, transactionManager.sequenceNumber(tp0).longValue()); + assertEquals(OptionalInt.empty(), transactionManager.lastAckedSequence(tp0)); + + sendIdempotentProducerResponse(0, tp0, Errors.NONE, 1000L, 10L); + + sender.runOnce(); // receive the response. + + assertTrue(request1.isDone()); + assertEquals(1000L, request1.get().offset()); + assertEquals(OptionalInt.of(0), transactionManager.lastAckedSequence(tp0)); + assertEquals(OptionalLong.of(1000L), transactionManager.lastAckedOffset(tp0)); + + // Send second ProduceRequest + Future request2 = appendToAccumulator(tp0); + sender.runOnce(); + assertEquals(2, transactionManager.sequenceNumber(tp0).longValue()); + assertEquals(OptionalInt.of(0), transactionManager.lastAckedSequence(tp0)); + + assertFalse(request2.isDone()); + + sendIdempotentProducerResponse(1, tp0, Errors.UNKNOWN_PRODUCER_ID, -1L, -1L); + sender.runOnce(); // receive response 0, should be retried without resetting the sequence numbers since the log start offset is unknown. + + // We should have reset the sequence number state of the partition because the state was lost on the broker. + assertEquals(OptionalInt.of(0), transactionManager.lastAckedSequence(tp0)); + assertEquals(2, transactionManager.sequenceNumber(tp0).longValue()); + assertFalse(request2.isDone()); + assertFalse(client.hasInFlightRequests()); + + sender.runOnce(); // should retry request 1 + + // resend the request. Note that the expected sequence is 1, since we never got the logStartOffset in the previous + // response and hence we didn't reset the sequence numbers. 
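+ // Without a log start offset from the broker, the error cannot be attributed to retention, so the batch keeps its originally assigned sequence of 1.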
+ sendIdempotentProducerResponse(1, tp0, Errors.NONE, 1011L, 1010L); + sender.runOnce(); // receive response 1 + assertEquals(OptionalInt.of(1), transactionManager.lastAckedSequence(tp0)); + assertEquals(2, transactionManager.sequenceNumber(tp0).longValue()); + assertFalse(client.hasInFlightRequests()); + assertTrue(request2.isDone()); + assertEquals(1011L, request2.get().offset()); + assertEquals(OptionalLong.of(1011L), transactionManager.lastAckedOffset(tp0)); + } + + @Test + public void testUnknownProducerErrorShouldBeRetriedForFutureBatchesWhenFirstFails() throws Exception { + final long producerId = 343434L; + TransactionManager transactionManager = createTransactionManager(); + setupWithTransactionState(transactionManager); + prepareAndReceiveInitProducerId(producerId, Errors.NONE); + assertTrue(transactionManager.hasProducerId()); + + assertEquals(0, transactionManager.sequenceNumber(tp0).longValue()); + + // Send first ProduceRequest + Future request1 = appendToAccumulator(tp0); + sender.runOnce(); + + assertEquals(1, client.inFlightRequestCount()); + assertEquals(1, transactionManager.sequenceNumber(tp0).longValue()); + assertEquals(OptionalInt.empty(), transactionManager.lastAckedSequence(tp0)); + + sendIdempotentProducerResponse(0, tp0, Errors.NONE, 1000L, 10L); + + sender.runOnce(); // receive the response. + + assertTrue(request1.isDone()); + assertEquals(1000L, request1.get().offset()); + assertEquals(OptionalInt.of(0), transactionManager.lastAckedSequence(tp0)); + assertEquals(OptionalLong.of(1000L), transactionManager.lastAckedOffset(tp0)); + + // Send second ProduceRequest + Future request2 = appendToAccumulator(tp0); + sender.runOnce(); + assertEquals(2, transactionManager.sequenceNumber(tp0).longValue()); + assertEquals(OptionalInt.of(0), transactionManager.lastAckedSequence(tp0)); + + // Send the third ProduceRequest, in parallel with the second. It should be retried even though the + // lastAckedOffset > logStartOffset when its UnknownProducerResponse comes back. + Future request3 = appendToAccumulator(tp0); + sender.runOnce(); + assertEquals(3, transactionManager.sequenceNumber(tp0).longValue()); + assertEquals(OptionalInt.of(0), transactionManager.lastAckedSequence(tp0)); + + assertFalse(request2.isDone()); + assertFalse(request3.isDone()); + assertEquals(2, client.inFlightRequestCount()); + + sendIdempotentProducerResponse(1, tp0, Errors.UNKNOWN_PRODUCER_ID, -1L, 1010L); + sender.runOnce(); // receive response 2, should reset the sequence numbers and be retried. + sender.runOnce(); // bump epoch and retry request 2 + + // We should have reset the sequence number state of the partition because the state was lost on the broker. + assertEquals(OptionalInt.empty(), transactionManager.lastAckedSequence(tp0)); + assertEquals(2, transactionManager.sequenceNumber(tp0).longValue()); + assertFalse(request2.isDone()); + assertFalse(request3.isDone()); + assertEquals(2, client.inFlightRequestCount()); + assertEquals((short) 1, transactionManager.producerIdAndEpoch().epoch); + + // receive the original response 3. note the expected sequence is still the originally assigned sequence. 
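+ // Batch 3 was already in flight with sequence 2 when the epoch was bumped, so it fails with the same error and is only retried after batch 2 completes.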
+ sendIdempotentProducerResponse(2, tp0, Errors.UNKNOWN_PRODUCER_ID, -1, 1010L); + sender.runOnce(); // receive response 3 + + assertEquals(1, client.inFlightRequestCount()); + assertEquals(OptionalInt.empty(), transactionManager.lastAckedSequence(tp0)); + assertEquals(2, transactionManager.sequenceNumber(tp0).longValue()); + + sendIdempotentProducerResponse(0, tp0, Errors.NONE, 1011L, 1010L); + sender.runOnce(); // receive response 2, don't send request 3 since we can have at most 1 in flight when retrying + + assertTrue(request2.isDone()); + assertFalse(request3.isDone()); + assertFalse(client.hasInFlightRequests()); + assertEquals(OptionalInt.of(0), transactionManager.lastAckedSequence(tp0)); + assertEquals(1011L, request2.get().offset()); + assertEquals(OptionalLong.of(1011L), transactionManager.lastAckedOffset(tp0)); + + sender.runOnce(); // resend request 3. + assertEquals(1, client.inFlightRequestCount()); + + sendIdempotentProducerResponse(1, tp0, Errors.NONE, 1012L, 1010L); + sender.runOnce(); // receive response 3. + + assertFalse(client.hasInFlightRequests()); + assertTrue(request3.isDone()); + assertEquals(1012L, request3.get().offset()); + assertEquals(OptionalLong.of(1012L), transactionManager.lastAckedOffset(tp0)); + } + + @Test + public void testShouldRaiseOutOfOrderSequenceExceptionToUserIfLogWasNotTruncated() throws Exception { + final long producerId = 343434L; + TransactionManager transactionManager = createTransactionManager(); + setupWithTransactionState(transactionManager); + prepareAndReceiveInitProducerId(producerId, Errors.NONE); + assertTrue(transactionManager.hasProducerId()); + + assertEquals(0, transactionManager.sequenceNumber(tp0).longValue()); + + // Send first ProduceRequest + Future request1 = appendToAccumulator(tp0); + sender.runOnce(); + + assertEquals(1, client.inFlightRequestCount()); + assertEquals(1, transactionManager.sequenceNumber(tp0).longValue()); + assertEquals(OptionalInt.empty(), transactionManager.lastAckedSequence(tp0)); + + sendIdempotentProducerResponse(0, tp0, Errors.NONE, 1000L, 10L); + + sender.runOnce(); // receive the response. 
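+ // The first batch is acknowledged at offset 1000 while the log start offset is only 10, so a later UNKNOWN_PRODUCER_ID with the same log start offset cannot be explained by retention.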
+ + assertTrue(request1.isDone()); + assertEquals(1000L, request1.get().offset()); + assertEquals(OptionalInt.of(0), transactionManager.lastAckedSequence(tp0)); + assertEquals(OptionalLong.of(1000L), transactionManager.lastAckedOffset(tp0)); + + // Send second ProduceRequest, + Future request2 = appendToAccumulator(tp0); + sender.runOnce(); + assertEquals(2, transactionManager.sequenceNumber(tp0).longValue()); + assertEquals(OptionalInt.of(0), transactionManager.lastAckedSequence(tp0)); + + assertFalse(request2.isDone()); + + sendIdempotentProducerResponse(1, tp0, Errors.UNKNOWN_PRODUCER_ID, -1L, 10L); + sender.runOnce(); // receive response 0, should request an epoch bump + sender.runOnce(); // bump epoch + assertEquals(1, transactionManager.producerIdAndEpoch().epoch); + assertEquals(OptionalInt.empty(), transactionManager.lastAckedSequence(tp0)); + assertFalse(request2.isDone()); + } + + void sendIdempotentProducerResponse(int expectedSequence, TopicPartition tp, Errors responseError, long responseOffset) { + sendIdempotentProducerResponse(expectedSequence, tp, responseError, responseOffset, -1L); + } + + void sendIdempotentProducerResponse(int expectedSequence, TopicPartition tp, Errors responseError, long responseOffset, long logStartOffset) { + sendIdempotentProducerResponse(-1, expectedSequence, tp, responseError, responseOffset, logStartOffset); + } + + void sendIdempotentProducerResponse( + int expectedEpoch, + int expectedSequence, + TopicPartition tp, + Errors responseError, + long responseOffset, + long logStartOffset + ) { + client.respond(body -> { + ProduceRequest produceRequest = (ProduceRequest) body; + assertTrue(RequestTestUtils.hasIdempotentRecords(produceRequest)); + MemoryRecords records = partitionRecords(produceRequest).get(tp); + Iterator batchIterator = records.batches().iterator(); + RecordBatch firstBatch = batchIterator.next(); + assertFalse(batchIterator.hasNext()); + if (expectedEpoch > -1) + assertEquals((short) expectedEpoch, firstBatch.producerEpoch()); + assertEquals(expectedSequence, firstBatch.baseSequence()); + return true; + }, produceResponse(tp, responseOffset, responseError, 0, logStartOffset, null)); + } + + @Test + public void testClusterAuthorizationExceptionInProduceRequest() throws Exception { + final long producerId = 343434L; + TransactionManager transactionManager = createTransactionManager(); + setupWithTransactionState(transactionManager); + + prepareAndReceiveInitProducerId(producerId, Errors.NONE); + assertTrue(transactionManager.hasProducerId()); + + // cluster authorization is a fatal error for the producer + Future future = appendToAccumulator(tp0); + client.prepareResponse( + body -> body instanceof ProduceRequest && RequestTestUtils.hasIdempotentRecords((ProduceRequest) body), + produceResponse(tp0, -1, Errors.CLUSTER_AUTHORIZATION_FAILED, 0)); + + sender.runOnce(); + assertFutureFailure(future, ClusterAuthorizationException.class); + + // cluster authorization errors are fatal, so we should continue seeing it on future sends + assertTrue(transactionManager.hasFatalError()); + assertSendFailure(ClusterAuthorizationException.class); + } + + @Test + public void testCancelInFlightRequestAfterFatalError() throws Exception { + final long producerId = 343434L; + TransactionManager transactionManager = createTransactionManager(); + setupWithTransactionState(transactionManager); + + prepareAndReceiveInitProducerId(producerId, Errors.NONE); + assertTrue(transactionManager.hasProducerId()); + + // cluster authorization is a fatal error for 
the producer + Future future1 = appendToAccumulator(tp0); + sender.runOnce(); + + Future future2 = appendToAccumulator(tp1); + sender.runOnce(); + + client.respond( + body -> body instanceof ProduceRequest && RequestTestUtils.hasIdempotentRecords((ProduceRequest) body), + produceResponse(tp0, -1, Errors.CLUSTER_AUTHORIZATION_FAILED, 0)); + + sender.runOnce(); + assertTrue(transactionManager.hasFatalError()); + assertFutureFailure(future1, ClusterAuthorizationException.class); + + sender.runOnce(); + assertFutureFailure(future2, ClusterAuthorizationException.class); + + // Should be fine if the second response eventually returns + client.respond( + body -> body instanceof ProduceRequest && RequestTestUtils.hasIdempotentRecords((ProduceRequest) body), + produceResponse(tp1, 0, Errors.NONE, 0)); + sender.runOnce(); + } + + @Test + public void testUnsupportedForMessageFormatInProduceRequest() throws Exception { + final long producerId = 343434L; + TransactionManager transactionManager = createTransactionManager(); + setupWithTransactionState(transactionManager); + + prepareAndReceiveInitProducerId(producerId, Errors.NONE); + assertTrue(transactionManager.hasProducerId()); + + Future future = appendToAccumulator(tp0); + client.prepareResponse( + body -> body instanceof ProduceRequest && RequestTestUtils.hasIdempotentRecords((ProduceRequest) body), + produceResponse(tp0, -1, Errors.UNSUPPORTED_FOR_MESSAGE_FORMAT, 0)); + + sender.runOnce(); + assertFutureFailure(future, UnsupportedForMessageFormatException.class); + + // unsupported for message format is not a fatal error + assertFalse(transactionManager.hasError()); + } + + @Test + public void testUnsupportedVersionInProduceRequest() throws Exception { + final long producerId = 343434L; + TransactionManager transactionManager = createTransactionManager(); + setupWithTransactionState(transactionManager); + + prepareAndReceiveInitProducerId(producerId, Errors.NONE); + assertTrue(transactionManager.hasProducerId()); + + Future future = appendToAccumulator(tp0); + client.prepareUnsupportedVersionResponse( + body -> body instanceof ProduceRequest && RequestTestUtils.hasIdempotentRecords((ProduceRequest) body)); + + sender.runOnce(); + assertFutureFailure(future, UnsupportedVersionException.class); + + // unsupported version errors are fatal, so we should continue seeing it on future sends + assertTrue(transactionManager.hasFatalError()); + assertSendFailure(UnsupportedVersionException.class); + } + + @Test + public void testSequenceNumberIncrement() throws InterruptedException { + final long producerId = 343434L; + TransactionManager transactionManager = createTransactionManager(); + setupWithTransactionState(transactionManager); + prepareAndReceiveInitProducerId(producerId, Errors.NONE); + assertTrue(transactionManager.hasProducerId()); + + int maxRetries = 10; + Metrics m = new Metrics(); + SenderMetricsRegistry senderMetrics = new SenderMetricsRegistry(m); + + Sender sender = new Sender(logContext, client, metadata, this.accumulator, true, MAX_REQUEST_SIZE, ACKS_ALL, maxRetries, + senderMetrics, time, REQUEST_TIMEOUT, RETRY_BACKOFF_MS, transactionManager, apiVersions); + + Future responseFuture = appendToAccumulator(tp0); + client.prepareResponse(body -> { + if (body instanceof ProduceRequest) { + ProduceRequest request = (ProduceRequest) body; + MemoryRecords records = partitionRecords(request).get(tp0); + Iterator batchIterator = records.batches().iterator(); + assertTrue(batchIterator.hasNext()); + RecordBatch batch = batchIterator.next(); + 
assertFalse(batchIterator.hasNext()); + assertEquals(0, batch.baseSequence()); + assertEquals(producerId, batch.producerId()); + assertEquals(0, batch.producerEpoch()); + return true; + } + return false; + }, produceResponse(tp0, 0, Errors.NONE, 0)); + + sender.runOnce(); // connect. + sender.runOnce(); // send. + + sender.runOnce(); // receive response + assertTrue(responseFuture.isDone()); + assertEquals(OptionalInt.of(0), transactionManager.lastAckedSequence(tp0)); + assertEquals(1L, (long) transactionManager.sequenceNumber(tp0)); + } + + @Test + public void testRetryWhenProducerIdChanges() throws InterruptedException { + final long producerId = 343434L; + TransactionManager transactionManager = createTransactionManager(); + setupWithTransactionState(transactionManager); + prepareAndReceiveInitProducerId(producerId, Short.MAX_VALUE, Errors.NONE); + assertTrue(transactionManager.hasProducerId()); + + int maxRetries = 10; + Metrics m = new Metrics(); + SenderMetricsRegistry senderMetrics = new SenderMetricsRegistry(m); + Sender sender = new Sender(logContext, client, metadata, this.accumulator, true, MAX_REQUEST_SIZE, ACKS_ALL, maxRetries, + senderMetrics, time, REQUEST_TIMEOUT, RETRY_BACKOFF_MS, transactionManager, apiVersions); + + Future responseFuture = appendToAccumulator(tp0); + sender.runOnce(); // connect. + sender.runOnce(); // send. + String id = client.requests().peek().destination(); + Node node = new Node(Integer.valueOf(id), "localhost", 0); + assertEquals(1, client.inFlightRequestCount()); + assertTrue(client.isReady(node, time.milliseconds()), "Client ready status should be true"); + client.disconnect(id); + assertEquals(0, client.inFlightRequestCount()); + assertFalse(client.isReady(node, time.milliseconds()), "Client ready status should be false"); + sender.runOnce(); // receive error + sender.runOnce(); // reset producer ID because epoch is maxed out + + prepareAndReceiveInitProducerId(producerId + 1, Errors.NONE); + sender.runOnce(); // nothing to do, since the pid has changed. We should check the metrics for errors. + assertEquals(1, client.inFlightRequestCount(), "Expected requests to be retried after pid change"); + + assertFalse(responseFuture.isDone()); + assertEquals(1, (long) transactionManager.sequenceNumber(tp0)); + } + + @Test + public void testBumpEpochWhenOutOfOrderSequenceReceived() throws InterruptedException { + final long producerId = 343434L; + TransactionManager transactionManager = createTransactionManager(); + setupWithTransactionState(transactionManager); + prepareAndReceiveInitProducerId(producerId, Errors.NONE); + assertTrue(transactionManager.hasProducerId()); + + int maxRetries = 10; + Metrics m = new Metrics(); + SenderMetricsRegistry senderMetrics = new SenderMetricsRegistry(m); + + Sender sender = new Sender(logContext, client, metadata, this.accumulator, true, MAX_REQUEST_SIZE, ACKS_ALL, maxRetries, + senderMetrics, time, REQUEST_TIMEOUT, RETRY_BACKOFF_MS, transactionManager, apiVersions); + + Future responseFuture = appendToAccumulator(tp0); + sender.runOnce(); // connect. + sender.runOnce(); // send. 
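+ // A single batch is in flight; the OUT_OF_ORDER_SEQUENCE_NUMBER response below should trigger an epoch bump rather than failing the batch.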
+ + assertEquals(1, client.inFlightRequestCount()); + assertEquals(1, sender.inFlightBatches(tp0).size()); + + client.respond(produceResponse(tp0, 0, Errors.OUT_OF_ORDER_SEQUENCE_NUMBER, 0)); + + sender.runOnce(); // receive the out of order sequence error + sender.runOnce(); // bump the epoch + assertFalse(responseFuture.isDone()); + assertEquals(1, sender.inFlightBatches(tp0).size()); + assertEquals(1, transactionManager.producerIdAndEpoch().epoch); + } + + @Test + public void testIdempotentSplitBatchAndSend() throws Exception { + TopicPartition tp = new TopicPartition("testSplitBatchAndSend", 1); + TransactionManager txnManager = createTransactionManager(); + ProducerIdAndEpoch producerIdAndEpoch = new ProducerIdAndEpoch(123456L, (short) 0); + setupWithTransactionState(txnManager); + prepareAndReceiveInitProducerId(123456L, Errors.NONE); + assertTrue(txnManager.hasProducerId()); + testSplitBatchAndSend(txnManager, producerIdAndEpoch, tp); + } + + @Test + public void testTransactionalSplitBatchAndSend() throws Exception { + ProducerIdAndEpoch producerIdAndEpoch = new ProducerIdAndEpoch(123456L, (short) 0); + TopicPartition tp = new TopicPartition("testSplitBatchAndSend", 1); + TransactionManager txnManager = new TransactionManager(logContext, "testSplitBatchAndSend", 60000, 100, apiVersions); + + setupWithTransactionState(txnManager); + doInitTransactions(txnManager, producerIdAndEpoch); + + txnManager.beginTransaction(); + txnManager.failIfNotReadyForSend(); + txnManager.maybeAddPartitionToTransaction(tp); + client.prepareResponse(new AddPartitionsToTxnResponse(0, Collections.singletonMap(tp, Errors.NONE))); + sender.runOnce(); + + testSplitBatchAndSend(txnManager, producerIdAndEpoch, tp); + } + + @SuppressWarnings("deprecation") + private void testSplitBatchAndSend(TransactionManager txnManager, + ProducerIdAndEpoch producerIdAndEpoch, + TopicPartition tp) throws Exception { + int maxRetries = 1; + String topic = tp.topic(); + int deliveryTimeoutMs = 3000; + long totalSize = 1024 * 1024; + String metricGrpName = "producer-metrics"; + // Set a good compression ratio. + CompressionRatioEstimator.setEstimation(topic, CompressionType.GZIP, 0.2f); + try (Metrics m = new Metrics()) { + accumulator = new RecordAccumulator(logContext, batchSize, CompressionType.GZIP, + 0, 0L, deliveryTimeoutMs, m, metricGrpName, time, new ApiVersions(), txnManager, + new BufferPool(totalSize, batchSize, metrics, time, "producer-internal-metrics")); + SenderMetricsRegistry senderMetrics = new SenderMetricsRegistry(m); + Sender sender = new Sender(logContext, client, metadata, this.accumulator, true, MAX_REQUEST_SIZE, ACKS_ALL, maxRetries, + senderMetrics, time, REQUEST_TIMEOUT, 1000L, txnManager, new ApiVersions()); + // Create a two broker cluster, with partition 0 on broker 0 and partition 1 on broker 1 + MetadataResponse metadataUpdate1 = RequestTestUtils.metadataUpdateWith(2, Collections.singletonMap(topic, 2)); + client.prepareMetadataUpdate(metadataUpdate1); + // Send the first message. 
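+ // Two records of half the batch size are appended so the resulting batch can be split in two after the broker rejects it with MESSAGE_TOO_LARGE.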
+ long nowMs = time.milliseconds(); + Future f1 = + accumulator.append(tp, 0L, "key1".getBytes(), new byte[batchSize / 2], null, null, MAX_BLOCK_TIMEOUT, false, nowMs).future; + Future f2 = + accumulator.append(tp, 0L, "key2".getBytes(), new byte[batchSize / 2], null, null, MAX_BLOCK_TIMEOUT, false, nowMs).future; + sender.runOnce(); // connect + sender.runOnce(); // send produce request + + assertEquals(2, txnManager.sequenceNumber(tp).longValue(), "The next sequence should be 2"); + String id = client.requests().peek().destination(); + assertEquals(ApiKeys.PRODUCE, client.requests().peek().requestBuilder().apiKey()); + Node node = new Node(Integer.valueOf(id), "localhost", 0); + assertEquals(1, client.inFlightRequestCount()); + assertTrue(client.isReady(node, time.milliseconds()), "Client ready status should be true"); + + Map responseMap = new HashMap<>(); + responseMap.put(tp, new ProduceResponse.PartitionResponse(Errors.MESSAGE_TOO_LARGE)); + client.respond(new ProduceResponse(responseMap)); + sender.runOnce(); // split and reenqueue + assertEquals(2, txnManager.sequenceNumber(tp).longValue(), "The next sequence should be 2"); + // The compression ratio should have been improved once. + assertEquals(CompressionType.GZIP.rate - CompressionRatioEstimator.COMPRESSION_RATIO_IMPROVING_STEP, + CompressionRatioEstimator.estimation(topic, CompressionType.GZIP), 0.01); + sender.runOnce(); // send the first produce request + assertEquals(2, txnManager.sequenceNumber(tp).longValue(), "The next sequence number should be 2"); + assertFalse(f1.isDone(), "The future shouldn't have been done."); + assertFalse(f2.isDone(), "The future shouldn't have been done."); + id = client.requests().peek().destination(); + assertEquals(ApiKeys.PRODUCE, client.requests().peek().requestBuilder().apiKey()); + node = new Node(Integer.valueOf(id), "localhost", 0); + assertEquals(1, client.inFlightRequestCount()); + assertTrue(client.isReady(node, time.milliseconds()), "Client ready status should be true"); + + responseMap.put(tp, new ProduceResponse.PartitionResponse(Errors.NONE, 0L, 0L, 0L)); + client.respond(produceRequestMatcher(tp, producerIdAndEpoch, 0, txnManager.isTransactional()), + new ProduceResponse(responseMap)); + + sender.runOnce(); // receive + assertTrue(f1.isDone(), "The future should have been done."); + assertEquals(2, txnManager.sequenceNumber(tp).longValue(), "The next sequence number should still be 2"); + assertEquals(OptionalInt.of(0), txnManager.lastAckedSequence(tp), "The last ack'd sequence number should be 0"); + assertFalse(f2.isDone(), "The future shouldn't have been done."); + assertEquals(0L, f1.get().offset(), "Offset of the first message should be 0"); + sender.runOnce(); // send the second produce request + id = client.requests().peek().destination(); + assertEquals(ApiKeys.PRODUCE, client.requests().peek().requestBuilder().apiKey()); + node = new Node(Integer.valueOf(id), "localhost", 0); + assertEquals(1, client.inFlightRequestCount()); + assertTrue(client.isReady(node, time.milliseconds()), "Client ready status should be true"); + + responseMap.put(tp, new ProduceResponse.PartitionResponse(Errors.NONE, 1L, 0L, 0L)); + client.respond(produceRequestMatcher(tp, producerIdAndEpoch, 1, txnManager.isTransactional()), + new ProduceResponse(responseMap)); + + sender.runOnce(); // receive + assertTrue(f2.isDone(), "The future should have been done."); + assertEquals(2, txnManager.sequenceNumber(tp).longValue(), "The next sequence number should be 2"); + assertEquals(OptionalInt.of(1),
txnManager.lastAckedSequence(tp), "The last ack'd sequence number should be 1"); + assertEquals(1L, f2.get().offset(), "Offset of the first message should be 1"); + assertTrue(accumulator.batches().get(tp).isEmpty(), "There should be no batch in the accumulator"); + assertTrue((Double) (m.metrics().get(senderMetrics.batchSplitRate).metricValue()) > 0, "There should be a split"); + } + } + + @Test + public void testNoDoubleDeallocation() throws Exception { + long totalSize = 1024 * 1024; + String metricGrpName = "producer-custom-metrics"; + MatchingBufferPool pool = new MatchingBufferPool(totalSize, batchSize, metrics, time, metricGrpName); + setupWithTransactionState(null, false, pool); + + // Send first ProduceRequest + Future request1 = appendToAccumulator(tp0); + sender.runOnce(); // send request + assertEquals(1, client.inFlightRequestCount()); + assertEquals(1, sender.inFlightBatches(tp0).size()); + + time.sleep(REQUEST_TIMEOUT); + assertFalse(pool.allMatch()); + + sender.runOnce(); // expire the batch + assertTrue(request1.isDone()); + assertTrue(pool.allMatch(), "The batch should have been de-allocated"); + assertTrue(pool.allMatch()); + + sender.runOnce(); + assertTrue(pool.allMatch(), "The batch should have been de-allocated"); + assertEquals(0, client.inFlightRequestCount()); + assertEquals(0, sender.inFlightBatches(tp0).size()); + } + + @SuppressWarnings("deprecation") + @Test + public void testInflightBatchesExpireOnDeliveryTimeout() throws InterruptedException { + long deliveryTimeoutMs = 1500L; + setupWithTransactionState(null, true, null); + + // Send first ProduceRequest + Future request = appendToAccumulator(tp0); + sender.runOnce(); // send request + assertEquals(1, client.inFlightRequestCount()); + assertEquals(1, sender.inFlightBatches(tp0).size(), "Expect one in-flight batch in accumulator"); + + Map responseMap = new HashMap<>(); + responseMap.put(tp0, new ProduceResponse.PartitionResponse(Errors.NONE, 0L, 0L, 0L)); + client.respond(new ProduceResponse(responseMap)); + + time.sleep(deliveryTimeoutMs); + sender.runOnce(); // receive first response + assertEquals(0, sender.inFlightBatches(tp0).size(), "Expect zero in-flight batch in accumulator"); + try { + request.get(); + fail("The expired batch should throw a TimeoutException"); + } catch (ExecutionException e) { + assertTrue(e.getCause() instanceof TimeoutException); + } + } + + @Test + public void testRecordErrorPropagatedToApplication() throws InterruptedException { + int recordCount = 5; + + setup(); + + Map futures = new HashMap<>(recordCount); + for (int i = 0; i < recordCount; i++) { + futures.put(i, appendToAccumulator(tp0)); + } + + sender.runOnce(); // send request + assertEquals(1, client.inFlightRequestCount()); + assertEquals(1, sender.inFlightBatches(tp0).size()); + + OffsetAndError offsetAndError = new OffsetAndError(-1L, Errors.INVALID_RECORD, Arrays.asList( + new BatchIndexAndErrorMessage().setBatchIndex(0).setBatchIndexErrorMessage("0"), + new BatchIndexAndErrorMessage().setBatchIndex(2).setBatchIndexErrorMessage("2"), + new BatchIndexAndErrorMessage().setBatchIndex(3) + )); + + client.respond(produceResponse(Collections.singletonMap(tp0, offsetAndError))); + sender.runOnce(); + + for (Map.Entry futureEntry : futures.entrySet()) { + FutureRecordMetadata future = futureEntry.getValue(); + assertTrue(future.isDone()); + + KafkaException exception = TestUtils.assertFutureThrows(future, KafkaException.class); + Integer index = futureEntry.getKey(); + if (index == 0 || index == 2) { + 
assertTrue(exception instanceof InvalidRecordException); + assertEquals(index.toString(), exception.getMessage()); + } else if (index == 3) { + assertTrue(exception instanceof InvalidRecordException); + assertEquals(Errors.INVALID_RECORD.message(), exception.getMessage()); + } else { + assertEquals(KafkaException.class, exception.getClass()); + } + } + } + + @Test + public void testWhenFirstBatchExpireNoSendSecondBatchIfGuaranteeOrder() throws InterruptedException { + long deliveryTimeoutMs = 1500L; + setupWithTransactionState(null, true, null); + + // Send first ProduceRequest + appendToAccumulator(tp0); + sender.runOnce(); // send request + assertEquals(1, client.inFlightRequestCount()); + assertEquals(1, sender.inFlightBatches(tp0).size()); + + time.sleep(deliveryTimeoutMs / 2); + + // Send second ProduceRequest + appendToAccumulator(tp0); + sender.runOnce(); // must not send request because the partition is muted + assertEquals(1, client.inFlightRequestCount()); + assertEquals(1, sender.inFlightBatches(tp0).size()); + + time.sleep(deliveryTimeoutMs / 2); // expire the first batch only + + client.respond(produceResponse(tp0, 0L, Errors.NONE, 0, 0L, null)); + sender.runOnce(); // receive response (offset=0) + assertEquals(0, client.inFlightRequestCount()); + assertEquals(0, sender.inFlightBatches(tp0).size()); + + sender.runOnce(); // Drain the second request only this time + assertEquals(1, client.inFlightRequestCount()); + assertEquals(1, sender.inFlightBatches(tp0).size()); + } + + @Test + public void testExpiredBatchDoesNotRetry() throws Exception { + long deliverTimeoutMs = 1500L; + setupWithTransactionState(null, false, null); + + // Send first ProduceRequest + Future request1 = appendToAccumulator(tp0); + sender.runOnce(); // send request + assertEquals(1, client.inFlightRequestCount()); + time.sleep(deliverTimeoutMs); + + client.respond(produceResponse(tp0, -1, Errors.NOT_LEADER_OR_FOLLOWER, -1)); // return a retriable error + + sender.runOnce(); // expire the batch + assertTrue(request1.isDone()); + assertEquals(0, client.inFlightRequestCount()); + assertEquals(0, sender.inFlightBatches(tp0).size()); + + sender.runOnce(); // receive first response and do not reenqueue. + assertEquals(0, client.inFlightRequestCount()); + assertEquals(0, sender.inFlightBatches(tp0).size()); + + sender.runOnce(); // run again and must not send anything. + assertEquals(0, client.inFlightRequestCount()); + assertEquals(0, sender.inFlightBatches(tp0).size()); + } + + @Test + public void testExpiredBatchDoesNotSplitOnMessageTooLargeError() throws Exception { + long deliverTimeoutMs = 1500L; + // create a producer batch with more than one record so it is eligible for splitting + Future request1 = appendToAccumulator(tp0); + Future request2 = appendToAccumulator(tp0); + + // send request + sender.runOnce(); + assertEquals(1, client.inFlightRequestCount()); + // return a MESSAGE_TOO_LARGE error + client.respond(produceResponse(tp0, -1, Errors.MESSAGE_TOO_LARGE, -1)); + + time.sleep(deliverTimeoutMs); + // expire the batch and process the response + sender.runOnce(); + assertTrue(request1.isDone()); + assertTrue(request2.isDone()); + assertEquals(0, client.inFlightRequestCount()); + assertEquals(0, sender.inFlightBatches(tp0).size()); + + // run again and must not split big batch and resend anything. 
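+ // Because the batch already hit its delivery timeout before the MESSAGE_TOO_LARGE response was handled,
+ // nothing is left in the accumulator to split, so this extra run must send nothing.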
+ sender.runOnce(); + assertEquals(0, client.inFlightRequestCount()); + assertEquals(0, sender.inFlightBatches(tp0).size()); + } + + @Test + public void testResetNextBatchExpiry() throws Exception { + client = spy(new MockClient(time, metadata)); + + setupWithTransactionState(null); + + appendToAccumulator(tp0, 0L, "key", "value"); + + sender.runOnce(); + sender.runOnce(); + time.setCurrentTimeMs(time.milliseconds() + accumulator.getDeliveryTimeoutMs() + 1); + sender.runOnce(); + + InOrder inOrder = inOrder(client); + inOrder.verify(client, atLeastOnce()).ready(any(), anyLong()); + inOrder.verify(client, atLeastOnce()).newClientRequest(anyString(), any(), anyLong(), anyBoolean(), anyInt(), any()); + inOrder.verify(client, atLeastOnce()).send(any(), anyLong()); + inOrder.verify(client).poll(eq(0L), anyLong()); + inOrder.verify(client).poll(eq(accumulator.getDeliveryTimeoutMs()), anyLong()); + inOrder.verify(client).poll(geq(1L), anyLong()); + + } + + @SuppressWarnings("deprecation") + @Test + public void testExpiredBatchesInMultiplePartitions() throws Exception { + long deliveryTimeoutMs = 1500L; + setupWithTransactionState(null, true, null); + + // Send multiple ProduceRequest across multiple partitions. + Future request1 = appendToAccumulator(tp0, time.milliseconds(), "k1", "v1"); + Future request2 = appendToAccumulator(tp1, time.milliseconds(), "k2", "v2"); + + // Send request. + sender.runOnce(); + assertEquals(1, client.inFlightRequestCount()); + assertEquals(1, sender.inFlightBatches(tp0).size(), "Expect one in-flight batch in accumulator"); + + Map responseMap = new HashMap<>(); + responseMap.put(tp0, new ProduceResponse.PartitionResponse(Errors.NONE, 0L, 0L, 0L)); + client.respond(new ProduceResponse(responseMap)); + + // Successfully expire both batches. 
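+ // Only tp0 has a response queued; advancing the clock past the delivery timeout expires the batches
+ // for both tp0 and tp1, so both futures below fail with a TimeoutException.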
+ time.sleep(deliveryTimeoutMs); + sender.runOnce(); + assertEquals(0, sender.inFlightBatches(tp0).size(), "Expect zero in-flight batch in accumulator"); + + ExecutionException e = assertThrows(ExecutionException.class, request1::get); + assertTrue(e.getCause() instanceof TimeoutException); + + e = assertThrows(ExecutionException.class, request2::get); + assertTrue(e.getCause() instanceof TimeoutException); + } + + @Test + public void testTransactionalRequestsSentOnShutdown() { + // create a sender with retries = 1 + int maxRetries = 1; + Metrics m = new Metrics(); + SenderMetricsRegistry senderMetrics = new SenderMetricsRegistry(m); + try { + TransactionManager txnManager = new TransactionManager(logContext, "testTransactionalRequestsSentOnShutdown", 6000, 100, apiVersions); + Sender sender = new Sender(logContext, client, metadata, this.accumulator, false, MAX_REQUEST_SIZE, ACKS_ALL, + maxRetries, senderMetrics, time, REQUEST_TIMEOUT, RETRY_BACKOFF_MS, txnManager, apiVersions); + + ProducerIdAndEpoch producerIdAndEpoch = new ProducerIdAndEpoch(123456L, (short) 0); + TopicPartition tp = new TopicPartition("testTransactionalRequestsSentOnShutdown", 1); + + setupWithTransactionState(txnManager); + doInitTransactions(txnManager, producerIdAndEpoch); + + txnManager.beginTransaction(); + txnManager.failIfNotReadyForSend(); + txnManager.maybeAddPartitionToTransaction(tp); + client.prepareResponse(new AddPartitionsToTxnResponse(0, Collections.singletonMap(tp, Errors.NONE))); + sender.runOnce(); + sender.initiateClose(); + txnManager.beginCommit(); + AssertEndTxnRequestMatcher endTxnMatcher = new AssertEndTxnRequestMatcher(TransactionResult.COMMIT); + client.prepareResponse(endTxnMatcher, new EndTxnResponse(new EndTxnResponseData() + .setErrorCode(Errors.NONE.code()) + .setThrottleTimeMs(0))); + sender.run(); + assertTrue(endTxnMatcher.matched, "Response didn't match in test"); + } finally { + m.close(); + } + } + + @Test + public void testRecordsFlushedImmediatelyOnTransactionCompletion() throws Exception { + try (Metrics m = new Metrics()) { + int lingerMs = 50; + SenderMetricsRegistry senderMetrics = new SenderMetricsRegistry(m); + + TransactionManager txnManager = new TransactionManager(logContext, "txnId", 6000, 100, apiVersions); + setupWithTransactionState(txnManager, lingerMs); + + Sender sender = new Sender(logContext, client, metadata, this.accumulator, false, MAX_REQUEST_SIZE, ACKS_ALL, + 1, senderMetrics, time, REQUEST_TIMEOUT, RETRY_BACKOFF_MS, txnManager, apiVersions); + + // Begin a transaction and successfully add one partition to it. + ProducerIdAndEpoch producerIdAndEpoch = new ProducerIdAndEpoch(123456L, (short) 0); + doInitTransactions(txnManager, producerIdAndEpoch); + txnManager.beginTransaction(); + addPartitionToTxn(sender, txnManager, tp0); + + // Send a couple records and assert that they are not sent immediately (due to linger). + appendToAccumulator(tp0); + appendToAccumulator(tp0); + sender.runOnce(); + assertFalse(client.hasInFlightRequests()); + + // Now begin the commit and assert that the Produce request is sent immediately + // without waiting for the linger. + txnManager.beginCommit(); + runUntil(sender, client::hasInFlightRequests); + + // Respond to the produce request and wait for the EndTxn request to be sent. + respondToProduce(tp0, Errors.NONE, 1L); + runUntil(sender, txnManager::hasInFlightRequest); + + // Respond to the expected EndTxn request. 
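+ // respondToEndTxn (see the helper further down in this class) answers the in-flight EndTxnRequest with the
+ // given error; NONE lets the commit finish and returns the transaction manager to the ready state.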
+ respondToEndTxn(Errors.NONE); + runUntil(sender, txnManager::isReady); + + // Finally, we want to assert that the linger time is still effective + // when the new transaction begins. + txnManager.beginTransaction(); + addPartitionToTxn(sender, txnManager, tp0); + + appendToAccumulator(tp0); + appendToAccumulator(tp0); + time.sleep(lingerMs - 1); + sender.runOnce(); + assertFalse(client.hasInFlightRequests()); + assertTrue(accumulator.hasUndrained()); + + time.sleep(1); + runUntil(sender, client::hasInFlightRequests); + assertFalse(accumulator.hasUndrained()); + } + } + + @Test + public void testAwaitPendingRecordsBeforeCommittingTransaction() throws Exception { + try (Metrics m = new Metrics()) { + SenderMetricsRegistry senderMetrics = new SenderMetricsRegistry(m); + + TransactionManager txnManager = new TransactionManager(logContext, "txnId", 6000, 100, apiVersions); + setupWithTransactionState(txnManager); + + Sender sender = new Sender(logContext, client, metadata, this.accumulator, false, MAX_REQUEST_SIZE, ACKS_ALL, + 1, senderMetrics, time, REQUEST_TIMEOUT, RETRY_BACKOFF_MS, txnManager, apiVersions); + + // Begin a transaction and successfully add one partition to it. + ProducerIdAndEpoch producerIdAndEpoch = new ProducerIdAndEpoch(123456L, (short) 0); + doInitTransactions(txnManager, producerIdAndEpoch); + txnManager.beginTransaction(); + addPartitionToTxn(sender, txnManager, tp0); + + // Send one Produce request. + appendToAccumulator(tp0); + runUntil(sender, () -> client.requests().size() == 1); + assertFalse(accumulator.hasUndrained()); + assertTrue(client.hasInFlightRequests()); + assertTrue(txnManager.hasInflightBatches(tp0)); + + // Enqueue another record and then commit the transaction. We expect the unsent record to + // get sent before the transaction can be completed. + appendToAccumulator(tp0); + txnManager.beginCommit(); + runUntil(sender, () -> client.requests().size() == 2); + + assertTrue(txnManager.isCompleting()); + assertFalse(txnManager.hasInFlightRequest()); + assertTrue(txnManager.hasInflightBatches(tp0)); + + // Now respond to the pending Produce requests. + respondToProduce(tp0, Errors.NONE, 0L); + respondToProduce(tp0, Errors.NONE, 1L); + runUntil(sender, txnManager::hasInFlightRequest); + + // Finally, respond to the expected EndTxn request. 
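+ // The EndTxn request only became in-flight after both produce responses above were handled,
+ // confirming that pending records are fully flushed before the commit is allowed to complete.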
+ respondToEndTxn(Errors.NONE); + runUntil(sender, txnManager::isReady); + } + } + + private void addPartitionToTxn(Sender sender, TransactionManager txnManager, TopicPartition tp) { + txnManager.maybeAddPartitionToTransaction(tp); + client.prepareResponse(new AddPartitionsToTxnResponse(0, Collections.singletonMap(tp, Errors.NONE))); + runUntil(sender, () -> txnManager.isPartitionAdded(tp)); + assertFalse(txnManager.hasInFlightRequest()); + } + + private void respondToProduce(TopicPartition tp, Errors error, long offset) { + client.respond( + request -> request instanceof ProduceRequest, + produceResponse(tp, offset, error, 0) + ); + + } + + private void respondToEndTxn(Errors error) { + client.respond( + request -> request instanceof EndTxnRequest, + new EndTxnResponse(new EndTxnResponseData() + .setErrorCode(error.code()) + .setThrottleTimeMs(0)) + ); + } + + @Test + public void testIncompleteTransactionAbortOnShutdown() { + // create a sender with retries = 1 + int maxRetries = 1; + Metrics m = new Metrics(); + SenderMetricsRegistry senderMetrics = new SenderMetricsRegistry(m); + try { + TransactionManager txnManager = new TransactionManager(logContext, "testIncompleteTransactionAbortOnShutdown", 6000, 100, apiVersions); + Sender sender = new Sender(logContext, client, metadata, this.accumulator, false, MAX_REQUEST_SIZE, ACKS_ALL, + maxRetries, senderMetrics, time, REQUEST_TIMEOUT, RETRY_BACKOFF_MS, txnManager, apiVersions); + + ProducerIdAndEpoch producerIdAndEpoch = new ProducerIdAndEpoch(123456L, (short) 0); + TopicPartition tp = new TopicPartition("testIncompleteTransactionAbortOnShutdown", 1); + + setupWithTransactionState(txnManager); + doInitTransactions(txnManager, producerIdAndEpoch); + + txnManager.beginTransaction(); + txnManager.failIfNotReadyForSend(); + txnManager.maybeAddPartitionToTransaction(tp); + client.prepareResponse(new AddPartitionsToTxnResponse(0, Collections.singletonMap(tp, Errors.NONE))); + sender.runOnce(); + sender.initiateClose(); + AssertEndTxnRequestMatcher endTxnMatcher = new AssertEndTxnRequestMatcher(TransactionResult.ABORT); + client.prepareResponse(endTxnMatcher, new EndTxnResponse(new EndTxnResponseData() + .setErrorCode(Errors.NONE.code()) + .setThrottleTimeMs(0))); + sender.run(); + assertTrue(endTxnMatcher.matched, "Response didn't match in test"); + } finally { + m.close(); + } + } + + @Timeout(10L) + @Test + public void testForceShutdownWithIncompleteTransaction() { + // create a sender with retries = 1 + int maxRetries = 1; + Metrics m = new Metrics(); + SenderMetricsRegistry senderMetrics = new SenderMetricsRegistry(m); + try { + TransactionManager txnManager = new TransactionManager(logContext, "testForceShutdownWithIncompleteTransaction", 6000, 100, apiVersions); + Sender sender = new Sender(logContext, client, metadata, this.accumulator, false, MAX_REQUEST_SIZE, ACKS_ALL, + maxRetries, senderMetrics, time, REQUEST_TIMEOUT, RETRY_BACKOFF_MS, txnManager, apiVersions); + + ProducerIdAndEpoch producerIdAndEpoch = new ProducerIdAndEpoch(123456L, (short) 0); + TopicPartition tp = new TopicPartition("testForceShutdownWithIncompleteTransaction", 1); + + setupWithTransactionState(txnManager); + doInitTransactions(txnManager, producerIdAndEpoch); + + txnManager.beginTransaction(); + txnManager.failIfNotReadyForSend(); + txnManager.maybeAddPartitionToTransaction(tp); + client.prepareResponse(new AddPartitionsToTxnResponse(0, Collections.singletonMap(tp, Errors.NONE))); + sender.runOnce(); + + // Try to commit the transaction but it won't happen as 
we'll forcefully close the sender + TransactionalRequestResult commitResult = txnManager.beginCommit(); + + sender.forceClose(); + sender.run(); + assertThrows(KafkaException.class, commitResult::await, + "The test expected to throw a KafkaException for forcefully closing the sender"); + } finally { + m.close(); + } + } + + @Test + public void testTransactionAbortedExceptionOnAbortWithoutError() throws InterruptedException, ExecutionException { + ProducerIdAndEpoch producerIdAndEpoch = new ProducerIdAndEpoch(123456L, (short) 0); + TransactionManager txnManager = new TransactionManager(logContext, "testTransactionAbortedExceptionOnAbortWithoutError", 60000, 100, apiVersions); + + setupWithTransactionState(txnManager, false, null); + doInitTransactions(txnManager, producerIdAndEpoch); + // Begin the transaction + txnManager.beginTransaction(); + txnManager.maybeAddPartitionToTransaction(tp0); + client.prepareResponse(new AddPartitionsToTxnResponse(0, Collections.singletonMap(tp0, Errors.NONE))); + // Run it once so that the partition is added to the transaction. + sender.runOnce(); + // Append a record to the accumulator. + FutureRecordMetadata metadata = appendToAccumulator(tp0, time.milliseconds(), "key", "value"); + // Now abort the transaction manually. + txnManager.beginAbort(); + // Try to send. + // This should abort the existing transaction and + // drain all the unsent batches with a TransactionAbortedException. + sender.runOnce(); + // Now attempt to fetch the result for the record. + TestUtils.assertFutureThrows(metadata, TransactionAbortedException.class); + } + + @Test + public void testDoNotPollWhenNoRequestSent() { + client = spy(new MockClient(time, metadata)); + + TransactionManager txnManager = new TransactionManager(logContext, "testDoNotPollWhenNoRequestSent", 6000, 100, apiVersions); + ProducerIdAndEpoch producerIdAndEpoch = new ProducerIdAndEpoch(123456L, (short) 0); + setupWithTransactionState(txnManager); + doInitTransactions(txnManager, producerIdAndEpoch); + + // doInitTransactions calls sender.doOnce three times, only two requests are sent, so we should only poll twice + verify(client, times(2)).poll(eq(RETRY_BACKOFF_MS), anyLong()); + } + + @Test + public void testTooLargeBatchesAreSafelyRemoved() throws InterruptedException { + ProducerIdAndEpoch producerIdAndEpoch = new ProducerIdAndEpoch(123456L, (short) 0); + TransactionManager txnManager = new TransactionManager(logContext, "testSplitBatchAndSend", 60000, 100, apiVersions); + + setupWithTransactionState(txnManager, false, null); + doInitTransactions(txnManager, producerIdAndEpoch); + + txnManager.beginTransaction(); + txnManager.maybeAddPartitionToTransaction(tp0); + client.prepareResponse(new AddPartitionsToTxnResponse(0, Collections.singletonMap(tp0, Errors.NONE))); + sender.runOnce(); + + // create a producer batch with more than one record so it is eligible for splitting + appendToAccumulator(tp0, time.milliseconds(), "key1", "value1"); + appendToAccumulator(tp0, time.milliseconds(), "key2", "value2"); + + // send request + sender.runOnce(); + assertEquals(1, sender.inFlightBatches(tp0).size()); + // return a MESSAGE_TOO_LARGE error + client.respond(produceResponse(tp0, -1, Errors.MESSAGE_TOO_LARGE, -1)); + sender.runOnce(); + + // process retried response + sender.runOnce(); + client.respond(produceResponse(tp0, 0, Errors.NONE, 0)); + sender.runOnce(); + + // In-flight batches should be empty. 
Sleep past the expiration time of the batch and run once, no error should be thrown + assertEquals(0, sender.inFlightBatches(tp0).size()); + time.sleep(2000); + sender.runOnce(); + } + + @Test + public void testDefaultErrorMessage() throws Exception { + verifyErrorMessage(produceResponse(tp0, 0L, Errors.INVALID_REQUEST, 0), Errors.INVALID_REQUEST.message()); + } + + @Test + public void testCustomErrorMessage() throws Exception { + String errorMessage = "testCustomErrorMessage"; + verifyErrorMessage(produceResponse(tp0, 0L, Errors.INVALID_REQUEST, 0, -1, errorMessage), errorMessage); + } + + private void verifyErrorMessage(ProduceResponse response, String expectedMessage) throws Exception { + Future future = appendToAccumulator(tp0, 0L, "key", "value"); + sender.runOnce(); // connect + sender.runOnce(); // send produce request + client.respond(response); + sender.runOnce(); + sender.runOnce(); + ExecutionException e1 = assertThrows(ExecutionException.class, () -> future.get(5, TimeUnit.SECONDS)); + assertEquals(InvalidRequestException.class, e1.getCause().getClass()); + assertEquals(expectedMessage, e1.getCause().getMessage()); + } + + class AssertEndTxnRequestMatcher implements MockClient.RequestMatcher { + + private TransactionResult requiredResult; + private boolean matched = false; + + AssertEndTxnRequestMatcher(TransactionResult requiredResult) { + this.requiredResult = requiredResult; + } + + @Override + public boolean matches(AbstractRequest body) { + if (body instanceof EndTxnRequest) { + assertSame(requiredResult, ((EndTxnRequest) body).result()); + matched = true; + return true; + } else { + return false; + } + } + } + + private class MatchingBufferPool extends BufferPool { + IdentityHashMap allocatedBuffers; + + MatchingBufferPool(long totalSize, int batchSize, Metrics metrics, Time time, String metricGrpName) { + super(totalSize, batchSize, metrics, time, metricGrpName); + allocatedBuffers = new IdentityHashMap<>(); + } + + @Override + public ByteBuffer allocate(int size, long maxTimeToBlockMs) throws InterruptedException { + ByteBuffer buffer = super.allocate(size, maxTimeToBlockMs); + allocatedBuffers.put(buffer, Boolean.TRUE); + return buffer; + } + + @Override + public void deallocate(ByteBuffer buffer, int size) { + if (!allocatedBuffers.containsKey(buffer)) { + throw new IllegalStateException("Deallocating a buffer that is not allocated"); + } + allocatedBuffers.remove(buffer); + super.deallocate(buffer, size); + } + + public boolean allMatch() { + return allocatedBuffers.isEmpty(); + } + } + + private MockClient.RequestMatcher produceRequestMatcher(final TopicPartition tp, + final ProducerIdAndEpoch producerIdAndEpoch, + final int sequence, + final boolean isTransactional) { + return body -> { + if (!(body instanceof ProduceRequest)) + return false; + + ProduceRequest request = (ProduceRequest) body; + Map recordsMap = partitionRecords(request); + MemoryRecords records = recordsMap.get(tp); + if (records == null) + return false; + + List batches = TestUtils.toList(records.batches()); + if (batches.size() != 1) + return false; + + MutableRecordBatch batch = batches.get(0); + return batch.baseOffset() == 0L && + batch.baseSequence() == sequence && + batch.producerId() == producerIdAndEpoch.producerId && + batch.producerEpoch() == producerIdAndEpoch.epoch && + batch.isTransactional() == isTransactional; + }; + } + + private static class OffsetAndError { + final long offset; + final Errors error; + final List recordErrors; + + OffsetAndError( + long offset, + Errors error, + 
List recordErrors + ) { + this.offset = offset; + this.error = error; + this.recordErrors = recordErrors; + } + + OffsetAndError(long offset, Errors error) { + this(offset, error, Collections.emptyList()); + } + + } + + private FutureRecordMetadata appendToAccumulator(TopicPartition tp) throws InterruptedException { + return appendToAccumulator(tp, time.milliseconds(), "key", "value"); + } + + private FutureRecordMetadata appendToAccumulator(TopicPartition tp, long timestamp, String key, String value) throws InterruptedException { + return accumulator.append(tp, timestamp, key.getBytes(), value.getBytes(), Record.EMPTY_HEADERS, + null, MAX_BLOCK_TIMEOUT, false, time.milliseconds()).future; + } + + @SuppressWarnings("deprecation") + private ProduceResponse produceResponse(TopicPartition tp, long offset, Errors error, int throttleTimeMs, long logStartOffset, String errorMessage) { + ProduceResponse.PartitionResponse resp = new ProduceResponse.PartitionResponse(error, offset, + RecordBatch.NO_TIMESTAMP, logStartOffset, Collections.emptyList(), errorMessage); + Map partResp = Collections.singletonMap(tp, resp); + return new ProduceResponse(partResp, throttleTimeMs); + } + + private ProduceResponse produceResponse(Map responses) { + ProduceResponseData data = new ProduceResponseData(); + + for (Map.Entry entry : responses.entrySet()) { + TopicPartition topicPartition = entry.getKey(); + ProduceResponseData.TopicProduceResponse topicData = data.responses().find(topicPartition.topic()); + if (topicData == null) { + topicData = new ProduceResponseData.TopicProduceResponse().setName(topicPartition.topic()); + data.responses().add(topicData); + } + + OffsetAndError offsetAndError = entry.getValue(); + ProduceResponseData.PartitionProduceResponse partitionData = + new ProduceResponseData.PartitionProduceResponse() + .setIndex(topicPartition.partition()) + .setBaseOffset(offsetAndError.offset) + .setErrorCode(offsetAndError.error.code()) + .setRecordErrors(offsetAndError.recordErrors); + + topicData.partitionResponses().add(partitionData); + } + + return new ProduceResponse(data); + } + private ProduceResponse produceResponse(TopicPartition tp, long offset, Errors error, int throttleTimeMs) { + return produceResponse(tp, offset, error, throttleTimeMs, -1L, null); + } + + private TransactionManager createTransactionManager() { + return new TransactionManager(new LogContext(), null, 0, 100L, new ApiVersions()); + } + + private void setupWithTransactionState(TransactionManager transactionManager) { + setupWithTransactionState(transactionManager, false, null, true, Integer.MAX_VALUE, 0); + } + + private void setupWithTransactionState(TransactionManager transactionManager, int lingerMs) { + setupWithTransactionState(transactionManager, false, null, true, Integer.MAX_VALUE, lingerMs); + } + + private void setupWithTransactionState(TransactionManager transactionManager, boolean guaranteeOrder, BufferPool customPool) { + setupWithTransactionState(transactionManager, guaranteeOrder, customPool, true, Integer.MAX_VALUE, 0); + } + + private void setupWithTransactionState( + TransactionManager transactionManager, + boolean guaranteeOrder, + BufferPool customPool, + boolean updateMetadata + ) { + setupWithTransactionState(transactionManager, guaranteeOrder, customPool, updateMetadata, Integer.MAX_VALUE, 0); + } + + private void setupWithTransactionState( + TransactionManager transactionManager, + boolean guaranteeOrder, + BufferPool customPool, + boolean updateMetadata, + int retries, + int lingerMs + ) { + long 
totalSize = 1024 * 1024; + String metricGrpName = "producer-metrics"; + MetricConfig metricConfig = new MetricConfig().tags(Collections.singletonMap("client-id", CLIENT_ID)); + this.metrics = new Metrics(metricConfig, time); + BufferPool pool = (customPool == null) ? new BufferPool(totalSize, batchSize, metrics, time, metricGrpName) : customPool; + + this.accumulator = new RecordAccumulator(logContext, batchSize, CompressionType.NONE, lingerMs, 0L, + DELIVERY_TIMEOUT_MS, metrics, metricGrpName, time, apiVersions, transactionManager, pool); + this.senderMetricsRegistry = new SenderMetricsRegistry(this.metrics); + this.sender = new Sender(logContext, this.client, this.metadata, this.accumulator, guaranteeOrder, MAX_REQUEST_SIZE, ACKS_ALL, + retries, this.senderMetricsRegistry, this.time, REQUEST_TIMEOUT, RETRY_BACKOFF_MS, transactionManager, apiVersions); + + metadata.add("test", time.milliseconds()); + if (updateMetadata) + this.client.updateMetadata(RequestTestUtils.metadataUpdateWith(1, Collections.singletonMap("test", 2))); + } + + private void assertSendFailure(Class expectedError) throws Exception { + Future future = appendToAccumulator(tp0); + sender.runOnce(); + assertTrue(future.isDone()); + try { + future.get(); + fail("Future should have raised " + expectedError.getSimpleName()); + } catch (ExecutionException e) { + assertTrue(expectedError.isAssignableFrom(e.getCause().getClass())); + } + } + + private void prepareAndReceiveInitProducerId(long producerId, Errors error) { + prepareAndReceiveInitProducerId(producerId, (short) 0, error); + } + + private void prepareAndReceiveInitProducerId(long producerId, short producerEpoch, Errors error) { + if (error != Errors.NONE) + producerEpoch = RecordBatch.NO_PRODUCER_EPOCH; + + client.prepareResponse( + body -> body instanceof InitProducerIdRequest && + ((InitProducerIdRequest) body).data().transactionalId() == null, + initProducerIdResponse(producerId, producerEpoch, error)); + sender.runOnce(); + } + + private InitProducerIdResponse initProducerIdResponse(long producerId, short producerEpoch, Errors error) { + InitProducerIdResponseData responseData = new InitProducerIdResponseData() + .setErrorCode(error.code()) + .setProducerEpoch(producerEpoch) + .setProducerId(producerId) + .setThrottleTimeMs(0); + return new InitProducerIdResponse(responseData); + } + + private void doInitTransactions(TransactionManager transactionManager, ProducerIdAndEpoch producerIdAndEpoch) { + transactionManager.initializeTransactions(); + prepareFindCoordinatorResponse(Errors.NONE, transactionManager.transactionalId()); + sender.runOnce(); + sender.runOnce(); + + prepareInitProducerResponse(Errors.NONE, producerIdAndEpoch.producerId, producerIdAndEpoch.epoch); + sender.runOnce(); + assertTrue(transactionManager.hasProducerId()); + } + + private void prepareFindCoordinatorResponse(Errors error, String txnid) { + Node node = metadata.fetch().nodes().get(0); + client.prepareResponse(FindCoordinatorResponse.prepareResponse(error, txnid, node)); + } + + private void prepareInitProducerResponse(Errors error, long producerId, short producerEpoch) { + client.prepareResponse(initProducerIdResponse(producerId, producerEpoch, error)); + } + + private void assertFutureFailure(Future future, Class expectedExceptionType) + throws InterruptedException { + assertTrue(future.isDone()); + try { + future.get(); + fail("Future should have raised " + expectedExceptionType.getName()); + } catch (ExecutionException e) { + Class causeType = e.getCause().getClass(); + 
assertTrue(expectedExceptionType.isAssignableFrom(causeType), "Unexpected cause " + causeType.getName()); + } + } + + private void createMockClientWithMaxFlightOneMetadataPending() { + client = new MockClient(time, metadata) { + volatile boolean canSendMore = true; + @Override + public Node leastLoadedNode(long now) { + for (Node node : metadata.fetch().nodes()) { + if (isReady(node, now) && canSendMore) + return node; + } + return null; + } + + @Override + public List poll(long timeoutMs, long now) { + canSendMore = inFlightRequestCount() < 1; + return super.poll(timeoutMs, now); + } + }; + + // Send metadata request and wait until request is sent. `leastLoadedNode` will be null once + // request is in progress since no more requests can be sent to the node. Node will be ready + // on the next poll() after response is processed later on in tests which use this method. + MetadataRequest.Builder builder = new MetadataRequest.Builder(Collections.emptyList(), false); + Node node = metadata.fetch().nodes().get(0); + ClientRequest request = client.newClientRequest(node.idString(), builder, time.milliseconds(), true); + while (!client.ready(node, time.milliseconds())) + client.poll(0, time.milliseconds()); + client.send(request, time.milliseconds()); + while (client.leastLoadedNode(time.milliseconds()) != null) + client.poll(0, time.milliseconds()); + } + + private void waitForProducerId(TransactionManager transactionManager, ProducerIdAndEpoch producerIdAndEpoch) { + for (int i = 0; i < 5 && !transactionManager.hasProducerId(); i++) + sender.runOnce(); + + assertTrue(transactionManager.hasProducerId()); + assertEquals(producerIdAndEpoch, transactionManager.producerIdAndEpoch()); + } +} diff --git a/clients/src/test/java/org/apache/kafka/clients/producer/internals/StickyPartitionCacheTest.java b/clients/src/test/java/org/apache/kafka/clients/producer/internals/StickyPartitionCacheTest.java new file mode 100644 index 0000000..b1af9a3 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/clients/producer/internals/StickyPartitionCacheTest.java @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+package org.apache.kafka.clients.producer.internals;
+
+import org.apache.kafka.common.Cluster;
+import org.apache.kafka.common.Node;
+import org.apache.kafka.common.PartitionInfo;
+import org.junit.jupiter.api.Test;
+
+import java.util.Collections;
+import java.util.List;
+
+import static java.util.Arrays.asList;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertNotEquals;
+
+public class StickyPartitionCacheTest {
+ private final static Node[] NODES = new Node[] {
+ new Node(0, "localhost", 99),
+ new Node(1, "localhost", 100),
+ new Node(2, "localhost", 101),
+ new Node(11, "localhost", 102)
+ };
+ final static String TOPIC_A = "topicA";
+ final static String TOPIC_B = "topicB";
+ final static String TOPIC_C = "topicC";
+
+ @Test
+ public void testStickyPartitionCache() {
+ List<PartitionInfo> allPartitions = asList(new PartitionInfo(TOPIC_A, 0, NODES[0], NODES, NODES),
+ new PartitionInfo(TOPIC_A, 1, NODES[1], NODES, NODES),
+ new PartitionInfo(TOPIC_A, 2, NODES[2], NODES, NODES),
+ new PartitionInfo(TOPIC_B, 0, NODES[0], NODES, NODES)
+ );
+ Cluster testCluster = new Cluster("clusterId", asList(NODES), allPartitions,
+ Collections.emptySet(), Collections.emptySet());
+ StickyPartitionCache stickyPartitionCache = new StickyPartitionCache();
+
+ int partA = stickyPartitionCache.partition(TOPIC_A, testCluster);
+ assertEquals(partA, stickyPartitionCache.partition(TOPIC_A, testCluster));
+
+ int partB = stickyPartitionCache.partition(TOPIC_B, testCluster);
+ assertEquals(partB, stickyPartitionCache.partition(TOPIC_B, testCluster));
+
+ int changedPartA = stickyPartitionCache.nextPartition(TOPIC_A, testCluster, partA);
+ assertEquals(changedPartA, stickyPartitionCache.partition(TOPIC_A, testCluster));
+ assertNotEquals(partA, changedPartA);
+ int changedPartA2 = stickyPartitionCache.partition(TOPIC_A, testCluster);
+ assertEquals(changedPartA2, changedPartA);
+
+ // We do not want to change partitions because the previous partition does not match the current sticky one.
+ int changedPartA3 = stickyPartitionCache.nextPartition(TOPIC_A, testCluster, partA);
+ assertEquals(changedPartA3, changedPartA2);
+
+ // Check that we can still use the partitioner when there is only one partition
+ int changedPartB = stickyPartitionCache.nextPartition(TOPIC_B, testCluster, partB);
+ assertEquals(changedPartB, stickyPartitionCache.partition(TOPIC_B, testCluster));
+ }
+
+ @Test
+ public void unavailablePartitionsTest() {
+ // Partition 1 in topic A and partition 0 in topic B are unavailable partitions.
+ List<PartitionInfo> allPartitions = asList(new PartitionInfo(TOPIC_A, 0, NODES[0], NODES, NODES),
+ new PartitionInfo(TOPIC_A, 1, null, NODES, NODES),
+ new PartitionInfo(TOPIC_A, 2, NODES[2], NODES, NODES),
+ new PartitionInfo(TOPIC_B, 0, null, NODES, NODES),
+ new PartitionInfo(TOPIC_B, 1, NODES[0], NODES, NODES),
+ new PartitionInfo(TOPIC_C, 0, null, NODES, NODES)
+ );
+
+ Cluster testCluster = new Cluster("clusterId", asList(NODES[0], NODES[1], NODES[2]), allPartitions,
+ Collections.emptySet(), Collections.emptySet());
+ StickyPartitionCache stickyPartitionCache = new StickyPartitionCache();
+
+ // Assure we never choose partition 1 because it is unavailable.
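+ // Partition 1 of TOPIC_A was created above with a null leader, so the cache must only ever
+ // hand out the available partitions 0 and 2 for this topic.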
+ int partA = stickyPartitionCache.partition(TOPIC_A, testCluster); + assertNotEquals(1, partA); + for (int aPartitions = 0; aPartitions < 100; aPartitions++) { + partA = stickyPartitionCache.nextPartition(TOPIC_A, testCluster, partA); + assertNotEquals(1, stickyPartitionCache.partition(TOPIC_A, testCluster)); + } + + // Assure we always choose partition 1 for topic B. + int partB = stickyPartitionCache.partition(TOPIC_B, testCluster); + assertEquals(1, partB); + for (int bPartitions = 0; bPartitions < 100; bPartitions++) { + partB = stickyPartitionCache.nextPartition(TOPIC_B, testCluster, partB); + assertEquals(1, stickyPartitionCache.partition(TOPIC_B, testCluster)); + } + + // Assure that we still choose the partition when there are no partitions available. + int partC = stickyPartitionCache.partition(TOPIC_C, testCluster); + assertEquals(0, partC); + partC = stickyPartitionCache.nextPartition(TOPIC_C, testCluster, partC); + assertEquals(0, partC); + } +} diff --git a/clients/src/test/java/org/apache/kafka/clients/producer/internals/TransactionManagerTest.java b/clients/src/test/java/org/apache/kafka/clients/producer/internals/TransactionManagerTest.java new file mode 100644 index 0000000..6c1e2fd --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/clients/producer/internals/TransactionManagerTest.java @@ -0,0 +1,3584 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.clients.producer.internals; + +import org.apache.kafka.clients.ApiVersions; +import org.apache.kafka.clients.MockClient; +import org.apache.kafka.clients.NodeApiVersions; +import org.apache.kafka.clients.consumer.CommitFailedException; +import org.apache.kafka.clients.consumer.ConsumerGroupMetadata; +import org.apache.kafka.clients.consumer.OffsetAndMetadata; +import org.apache.kafka.clients.producer.RecordMetadata; +import org.apache.kafka.common.Cluster; +import org.apache.kafka.common.KafkaException; +import org.apache.kafka.common.Node; +import org.apache.kafka.common.PartitionInfo; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.errors.FencedInstanceIdException; +import org.apache.kafka.common.errors.GroupAuthorizationException; +import org.apache.kafka.common.errors.OutOfOrderSequenceException; +import org.apache.kafka.common.errors.ProducerFencedException; +import org.apache.kafka.common.errors.TimeoutException; +import org.apache.kafka.common.errors.TopicAuthorizationException; +import org.apache.kafka.common.errors.TransactionalIdAuthorizationException; +import org.apache.kafka.common.errors.UnsupportedForMessageFormatException; +import org.apache.kafka.common.errors.UnsupportedVersionException; +import org.apache.kafka.common.header.Header; +import org.apache.kafka.common.internals.ClusterResourceListeners; +import org.apache.kafka.common.message.AddOffsetsToTxnResponseData; +import org.apache.kafka.common.message.ApiVersionsResponseData.ApiVersion; +import org.apache.kafka.common.message.EndTxnResponseData; +import org.apache.kafka.common.message.InitProducerIdResponseData; +import org.apache.kafka.common.metrics.Metrics; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.Errors; +import org.apache.kafka.common.record.CompressionType; +import org.apache.kafka.common.record.MemoryRecords; +import org.apache.kafka.common.record.MemoryRecordsBuilder; +import org.apache.kafka.common.record.MutableRecordBatch; +import org.apache.kafka.common.record.Record; +import org.apache.kafka.common.record.RecordBatch; +import org.apache.kafka.common.record.TimestampType; +import org.apache.kafka.common.requests.AddOffsetsToTxnRequest; +import org.apache.kafka.common.requests.AddOffsetsToTxnResponse; +import org.apache.kafka.common.requests.AddPartitionsToTxnRequest; +import org.apache.kafka.common.requests.AddPartitionsToTxnResponse; +import org.apache.kafka.common.requests.EndTxnRequest; +import org.apache.kafka.common.requests.EndTxnResponse; +import org.apache.kafka.common.requests.FindCoordinatorRequest; +import org.apache.kafka.common.requests.FindCoordinatorRequest.CoordinatorType; +import org.apache.kafka.common.requests.FindCoordinatorResponse; +import org.apache.kafka.common.requests.InitProducerIdRequest; +import org.apache.kafka.common.requests.InitProducerIdResponse; +import org.apache.kafka.common.requests.JoinGroupRequest; +import org.apache.kafka.common.requests.ProduceRequest; +import org.apache.kafka.common.requests.ProduceResponse; +import org.apache.kafka.common.requests.RequestTestUtils; +import org.apache.kafka.common.requests.TransactionResult; +import org.apache.kafka.common.requests.TxnOffsetCommitRequest; +import org.apache.kafka.common.requests.TxnOffsetCommitResponse; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.common.utils.ProducerIdAndEpoch; +import org.apache.kafka.test.TestUtils; 
+import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.nio.ByteBuffer; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.OptionalInt; +import java.util.Set; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Supplier; + +import static java.util.Collections.singleton; +import static java.util.Collections.singletonList; +import static java.util.Collections.singletonMap; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; + +public class TransactionManagerTest { + private static final int MAX_REQUEST_SIZE = 1024 * 1024; + private static final short ACKS_ALL = -1; + private static final int MAX_RETRIES = Integer.MAX_VALUE; + private static final int MAX_BLOCK_TIMEOUT = 1000; + private static final int REQUEST_TIMEOUT = 1000; + private static final long DEFAULT_RETRY_BACKOFF_MS = 100L; + + private final String transactionalId = "foobar"; + private final int transactionTimeoutMs = 1121; + + private final String topic = "test"; + private final TopicPartition tp0 = new TopicPartition(topic, 0); + private final TopicPartition tp1 = new TopicPartition(topic, 1); + private final long producerId = 13131L; + private final short epoch = 1; + private final String consumerGroupId = "myConsumerGroup"; + private final String memberId = "member"; + private final int generationId = 5; + private final String groupInstanceId = "instance"; + + private final LogContext logContext = new LogContext(); + private final MockTime time = new MockTime(); + private final ProducerMetadata metadata = new ProducerMetadata(0, Long.MAX_VALUE, Long.MAX_VALUE, + logContext, new ClusterResourceListeners(), time); + private final MockClient client = new MockClient(time, metadata); + private final ApiVersions apiVersions = new ApiVersions(); + + private RecordAccumulator accumulator = null; + private Sender sender = null; + private TransactionManager transactionManager = null; + private Node brokerNode = null; + + @BeforeEach + public void setup() { + this.metadata.add("test", time.milliseconds()); + this.client.updateMetadata(RequestTestUtils.metadataUpdateWith(1, singletonMap("test", 2))); + this.brokerNode = new Node(0, "localhost", 2211); + + initializeTransactionManager(Optional.of(transactionalId)); + } + + private void initializeTransactionManager(Optional transactionalId) { + Metrics metrics = new Metrics(time); + + apiVersions.update("0", new NodeApiVersions(Arrays.asList( + new ApiVersion() + .setApiKey(ApiKeys.INIT_PRODUCER_ID.id) + .setMinVersion((short) 0) + .setMaxVersion((short) 3), + new ApiVersion() + .setApiKey(ApiKeys.PRODUCE.id) + .setMinVersion((short) 0) + .setMaxVersion((short) 7)))); + this.transactionManager = new TransactionManager(logContext, transactionalId.orElse(null), + transactionTimeoutMs, 
DEFAULT_RETRY_BACKOFF_MS, apiVersions); + + int batchSize = 16 * 1024; + int deliveryTimeoutMs = 3000; + long totalSize = 1024 * 1024; + String metricGrpName = "producer-metrics"; + + this.brokerNode = new Node(0, "localhost", 2211); + this.accumulator = new RecordAccumulator(logContext, batchSize, CompressionType.NONE, 0, 0L, + deliveryTimeoutMs, metrics, metricGrpName, time, apiVersions, transactionManager, + new BufferPool(totalSize, batchSize, metrics, time, metricGrpName)); + + this.sender = new Sender(logContext, this.client, this.metadata, this.accumulator, true, + MAX_REQUEST_SIZE, ACKS_ALL, MAX_RETRIES, new SenderMetricsRegistry(metrics), this.time, REQUEST_TIMEOUT, + 50, transactionManager, apiVersions); + } + + @Test + public void testSenderShutdownWithPendingTransactions() throws Exception { + doInitTransactions(); + transactionManager.beginTransaction(); + + transactionManager.failIfNotReadyForSend(); + transactionManager.maybeAddPartitionToTransaction(tp0); + FutureRecordMetadata sendFuture = appendToAccumulator(tp0); + + prepareAddPartitionsToTxn(tp0, Errors.NONE); + prepareProduceResponse(Errors.NONE, producerId, epoch); + runUntil(() -> !client.hasPendingResponses()); + + sender.initiateClose(); + sender.runOnce(); + + TransactionalRequestResult result = transactionManager.beginCommit(); + prepareEndTxnResponse(Errors.NONE, TransactionResult.COMMIT, producerId, epoch); + runUntil(result::isCompleted); + runUntil(sendFuture::isDone); + } + + @Test + public void testEndTxnNotSentIfIncompleteBatches() { + doInitTransactions(); + transactionManager.beginTransaction(); + + transactionManager.failIfNotReadyForSend(); + transactionManager.maybeAddPartitionToTransaction(tp0); + prepareAddPartitionsToTxn(tp0, Errors.NONE); + runUntil(() -> transactionManager.isPartitionAdded(tp0)); + + transactionManager.beginCommit(); + assertNull(transactionManager.nextRequest(true)); + assertTrue(transactionManager.nextRequest(false).isEndTxn()); + } + + @Test + public void testFailIfNotReadyForSendNoProducerId() { + assertThrows(IllegalStateException.class, () -> transactionManager.failIfNotReadyForSend()); + } + + @Test + public void testFailIfNotReadyForSendIdempotentProducer() { + initializeTransactionManager(Optional.empty()); + transactionManager.failIfNotReadyForSend(); + } + + @Test + public void testFailIfNotReadyForSendIdempotentProducerFatalError() { + initializeTransactionManager(Optional.empty()); + transactionManager.transitionToFatalError(new KafkaException()); + assertThrows(KafkaException.class, () -> transactionManager.failIfNotReadyForSend()); + } + + @Test + public void testFailIfNotReadyForSendNoOngoingTransaction() { + doInitTransactions(); + assertThrows(IllegalStateException.class, () -> transactionManager.failIfNotReadyForSend()); + } + + @Test + public void testFailIfNotReadyForSendAfterAbortableError() { + doInitTransactions(); + transactionManager.beginTransaction(); + transactionManager.transitionToAbortableError(new KafkaException()); + assertThrows(KafkaException.class, transactionManager::failIfNotReadyForSend); + } + + @Test + public void testFailIfNotReadyForSendAfterFatalError() { + doInitTransactions(); + transactionManager.transitionToFatalError(new KafkaException()); + assertThrows(KafkaException.class, transactionManager::failIfNotReadyForSend); + } + + @Test + public void testHasOngoingTransactionSuccessfulAbort() { + TopicPartition partition = new TopicPartition("foo", 0); + + assertFalse(transactionManager.hasOngoingTransaction()); + 
doInitTransactions(); + assertFalse(transactionManager.hasOngoingTransaction()); + + transactionManager.beginTransaction(); + assertTrue(transactionManager.hasOngoingTransaction()); + + transactionManager.failIfNotReadyForSend(); + transactionManager.maybeAddPartitionToTransaction(partition); + runUntil(transactionManager::hasOngoingTransaction); + + prepareAddPartitionsToTxn(partition, Errors.NONE); + runUntil(() -> transactionManager.isPartitionAdded(partition)); + + transactionManager.beginAbort(); + assertTrue(transactionManager.hasOngoingTransaction()); + + prepareEndTxnResponse(Errors.NONE, TransactionResult.ABORT, producerId, epoch); + runUntil(() -> !transactionManager.hasOngoingTransaction()); + } + + @Test + public void testHasOngoingTransactionSuccessfulCommit() { + TopicPartition partition = new TopicPartition("foo", 0); + + assertFalse(transactionManager.hasOngoingTransaction()); + doInitTransactions(); + assertFalse(transactionManager.hasOngoingTransaction()); + + transactionManager.beginTransaction(); + assertTrue(transactionManager.hasOngoingTransaction()); + + transactionManager.failIfNotReadyForSend(); + transactionManager.maybeAddPartitionToTransaction(partition); + assertTrue(transactionManager.hasOngoingTransaction()); + + prepareAddPartitionsToTxn(partition, Errors.NONE); + runUntil(() -> transactionManager.isPartitionAdded(partition)); + + transactionManager.beginCommit(); + assertTrue(transactionManager.hasOngoingTransaction()); + + prepareEndTxnResponse(Errors.NONE, TransactionResult.COMMIT, producerId, epoch); + runUntil(() -> !transactionManager.hasOngoingTransaction()); + } + + @Test + public void testHasOngoingTransactionAbortableError() { + TopicPartition partition = new TopicPartition("foo", 0); + + assertFalse(transactionManager.hasOngoingTransaction()); + doInitTransactions(); + assertFalse(transactionManager.hasOngoingTransaction()); + + transactionManager.beginTransaction(); + assertTrue(transactionManager.hasOngoingTransaction()); + + transactionManager.failIfNotReadyForSend(); + transactionManager.maybeAddPartitionToTransaction(partition); + assertTrue(transactionManager.hasOngoingTransaction()); + + prepareAddPartitionsToTxn(partition, Errors.NONE); + runUntil(() -> transactionManager.isPartitionAdded(partition)); + + transactionManager.transitionToAbortableError(new KafkaException()); + assertTrue(transactionManager.hasOngoingTransaction()); + + transactionManager.beginAbort(); + assertTrue(transactionManager.hasOngoingTransaction()); + + prepareEndTxnResponse(Errors.NONE, TransactionResult.ABORT, producerId, epoch); + runUntil(() -> !transactionManager.hasOngoingTransaction()); + } + + @Test + public void testHasOngoingTransactionFatalError() { + TopicPartition partition = new TopicPartition("foo", 0); + + assertFalse(transactionManager.hasOngoingTransaction()); + doInitTransactions(); + assertFalse(transactionManager.hasOngoingTransaction()); + + transactionManager.beginTransaction(); + assertTrue(transactionManager.hasOngoingTransaction()); + + transactionManager.failIfNotReadyForSend(); + transactionManager.maybeAddPartitionToTransaction(partition); + assertTrue(transactionManager.hasOngoingTransaction()); + + prepareAddPartitionsToTxn(partition, Errors.NONE); + runUntil(() -> transactionManager.isPartitionAdded(partition)); + + transactionManager.transitionToFatalError(new KafkaException()); + assertFalse(transactionManager.hasOngoingTransaction()); + } + + @Test + public void testMaybeAddPartitionToTransaction() { + TopicPartition partition = 
new TopicPartition("foo", 0); + doInitTransactions(); + transactionManager.beginTransaction(); + + transactionManager.failIfNotReadyForSend(); + transactionManager.maybeAddPartitionToTransaction(partition); + assertTrue(transactionManager.hasPartitionsToAdd()); + assertFalse(transactionManager.isPartitionAdded(partition)); + assertTrue(transactionManager.isPartitionPendingAdd(partition)); + + prepareAddPartitionsToTxn(partition, Errors.NONE); + assertTrue(transactionManager.hasPartitionsToAdd()); + + runUntil(() -> transactionManager.isPartitionAdded(partition)); + assertFalse(transactionManager.hasPartitionsToAdd()); + assertFalse(transactionManager.isPartitionPendingAdd(partition)); + + // adding the partition again should not have any effect + transactionManager.failIfNotReadyForSend(); + transactionManager.maybeAddPartitionToTransaction(partition); + assertFalse(transactionManager.hasPartitionsToAdd()); + assertTrue(transactionManager.isPartitionAdded(partition)); + assertFalse(transactionManager.isPartitionPendingAdd(partition)); + } + + @Test + public void testAddPartitionToTransactionOverridesRetryBackoffForConcurrentTransactions() { + TopicPartition partition = new TopicPartition("foo", 0); + doInitTransactions(); + transactionManager.beginTransaction(); + + transactionManager.failIfNotReadyForSend(); + transactionManager.maybeAddPartitionToTransaction(partition); + assertTrue(transactionManager.hasPartitionsToAdd()); + assertFalse(transactionManager.isPartitionAdded(partition)); + assertTrue(transactionManager.isPartitionPendingAdd(partition)); + + prepareAddPartitionsToTxn(partition, Errors.CONCURRENT_TRANSACTIONS); + runUntil(() -> !client.hasPendingResponses()); + + TransactionManager.TxnRequestHandler handler = transactionManager.nextRequest(false); + assertNotNull(handler); + assertEquals(20, handler.retryBackoffMs()); + } + + @Test + public void testAddPartitionToTransactionRetainsRetryBackoffForRegularRetriableError() { + TopicPartition partition = new TopicPartition("foo", 0); + doInitTransactions(); + transactionManager.beginTransaction(); + + transactionManager.failIfNotReadyForSend(); + transactionManager.maybeAddPartitionToTransaction(partition); + assertTrue(transactionManager.hasPartitionsToAdd()); + assertFalse(transactionManager.isPartitionAdded(partition)); + assertTrue(transactionManager.isPartitionPendingAdd(partition)); + + prepareAddPartitionsToTxn(partition, Errors.COORDINATOR_NOT_AVAILABLE); + runUntil(() -> !client.hasPendingResponses()); + + TransactionManager.TxnRequestHandler handler = transactionManager.nextRequest(false); + assertNotNull(handler); + assertEquals(DEFAULT_RETRY_BACKOFF_MS, handler.retryBackoffMs()); + } + + @Test + public void testAddPartitionToTransactionRetainsRetryBackoffWhenPartitionsAlreadyAdded() { + TopicPartition partition = new TopicPartition("foo", 0); + doInitTransactions(); + transactionManager.beginTransaction(); + + transactionManager.failIfNotReadyForSend(); + transactionManager.maybeAddPartitionToTransaction(partition); + assertTrue(transactionManager.hasPartitionsToAdd()); + assertFalse(transactionManager.isPartitionAdded(partition)); + assertTrue(transactionManager.isPartitionPendingAdd(partition)); + + prepareAddPartitionsToTxn(partition, Errors.NONE); + runUntil(() -> transactionManager.isPartitionAdded(partition)); + + TopicPartition otherPartition = new TopicPartition("foo", 1); + transactionManager.failIfNotReadyForSend(); + transactionManager.maybeAddPartitionToTransaction(otherPartition); + + 
prepareAddPartitionsToTxn(otherPartition, Errors.CONCURRENT_TRANSACTIONS); + TransactionManager.TxnRequestHandler handler = transactionManager.nextRequest(false); + assertNotNull(handler); + assertEquals(DEFAULT_RETRY_BACKOFF_MS, handler.retryBackoffMs()); + } + + @Test + public void testNotReadyForSendBeforeInitTransactions() { + assertThrows(IllegalStateException.class, () -> transactionManager.failIfNotReadyForSend()); + } + + @Test + public void testNotReadyForSendBeforeBeginTransaction() { + doInitTransactions(); + assertThrows(IllegalStateException.class, () -> transactionManager.failIfNotReadyForSend()); + } + + @Test + public void testNotReadyForSendAfterAbortableError() { + doInitTransactions(); + transactionManager.beginTransaction(); + transactionManager.transitionToAbortableError(new KafkaException()); + assertThrows(KafkaException.class, () -> transactionManager.failIfNotReadyForSend()); + } + + @Test + public void testNotReadyForSendAfterFatalError() { + doInitTransactions(); + transactionManager.transitionToFatalError(new KafkaException()); + assertThrows(KafkaException.class, () -> transactionManager.failIfNotReadyForSend()); + } + + @Test + public void testIsSendToPartitionAllowedWithPendingPartitionAfterAbortableError() { + doInitTransactions(); + + transactionManager.beginTransaction(); + transactionManager.failIfNotReadyForSend(); + transactionManager.maybeAddPartitionToTransaction(tp0); + transactionManager.transitionToAbortableError(new KafkaException()); + + assertFalse(transactionManager.isSendToPartitionAllowed(tp0)); + assertTrue(transactionManager.hasAbortableError()); + } + + @Test + public void testIsSendToPartitionAllowedWithInFlightPartitionAddAfterAbortableError() { + doInitTransactions(); + + transactionManager.beginTransaction(); + transactionManager.failIfNotReadyForSend(); + transactionManager.maybeAddPartitionToTransaction(tp0); + + // Send the AddPartitionsToTxn request and leave it in-flight + runUntil(transactionManager::hasInFlightRequest); + transactionManager.transitionToAbortableError(new KafkaException()); + + assertFalse(transactionManager.isSendToPartitionAllowed(tp0)); + assertTrue(transactionManager.hasAbortableError()); + } + + @Test + public void testIsSendToPartitionAllowedWithPendingPartitionAfterFatalError() { + doInitTransactions(); + + transactionManager.beginTransaction(); + transactionManager.failIfNotReadyForSend(); + transactionManager.maybeAddPartitionToTransaction(tp0); + transactionManager.transitionToFatalError(new KafkaException()); + + assertFalse(transactionManager.isSendToPartitionAllowed(tp0)); + assertTrue(transactionManager.hasFatalError()); + } + + @Test + public void testIsSendToPartitionAllowedWithInFlightPartitionAddAfterFatalError() { + doInitTransactions(); + + transactionManager.beginTransaction(); + transactionManager.failIfNotReadyForSend(); + transactionManager.maybeAddPartitionToTransaction(tp0); + + // Send the AddPartitionsToTxn request and leave it in-flight + runUntil(transactionManager::hasInFlightRequest); + transactionManager.transitionToFatalError(new KafkaException()); + + assertFalse(transactionManager.isSendToPartitionAllowed(tp0)); + assertTrue(transactionManager.hasFatalError()); + } + + @Test + public void testIsSendToPartitionAllowedWithAddedPartitionAfterAbortableError() { + doInitTransactions(); + + transactionManager.beginTransaction(); + + transactionManager.failIfNotReadyForSend(); + transactionManager.maybeAddPartitionToTransaction(tp0); + prepareAddPartitionsToTxnResponse(Errors.NONE, 
tp0, epoch, producerId); + + runUntil(() -> !transactionManager.hasPartitionsToAdd()); + transactionManager.transitionToAbortableError(new KafkaException()); + + assertTrue(transactionManager.isSendToPartitionAllowed(tp0)); + assertTrue(transactionManager.hasAbortableError()); + } + + @Test + public void testIsSendToPartitionAllowedWithAddedPartitionAfterFatalError() { + doInitTransactions(); + + transactionManager.beginTransaction(); + transactionManager.failIfNotReadyForSend(); + transactionManager.maybeAddPartitionToTransaction(tp0); + prepareAddPartitionsToTxnResponse(Errors.NONE, tp0, epoch, producerId); + + runUntil(() -> !transactionManager.hasPartitionsToAdd()); + transactionManager.transitionToFatalError(new KafkaException()); + + assertFalse(transactionManager.isSendToPartitionAllowed(tp0)); + assertTrue(transactionManager.hasFatalError()); + } + + @Test + public void testIsSendToPartitionAllowedWithPartitionNotAdded() { + doInitTransactions(); + transactionManager.beginTransaction(); + assertFalse(transactionManager.isSendToPartitionAllowed(tp0)); + } + + @Test + public void testDefaultSequenceNumber() { + initializeTransactionManager(Optional.empty()); + assertEquals((int) transactionManager.sequenceNumber(tp0), 0); + transactionManager.incrementSequenceNumber(tp0, 3); + assertEquals((int) transactionManager.sequenceNumber(tp0), 3); + } + + @Test + public void testBumpEpochAndResetSequenceNumbersAfterUnknownProducerId() { + initializeTransactionManager(Optional.empty()); + initializeIdempotentProducerId(producerId, epoch); + + ProducerBatch b1 = writeIdempotentBatchWithValue(transactionManager, tp0, "1"); + ProducerBatch b2 = writeIdempotentBatchWithValue(transactionManager, tp0, "2"); + ProducerBatch b3 = writeIdempotentBatchWithValue(transactionManager, tp0, "3"); + ProducerBatch b4 = writeIdempotentBatchWithValue(transactionManager, tp0, "4"); + ProducerBatch b5 = writeIdempotentBatchWithValue(transactionManager, tp0, "5"); + assertEquals(5, transactionManager.sequenceNumber(tp0).intValue()); + + // First batch succeeds + long b1AppendTime = time.milliseconds(); + ProduceResponse.PartitionResponse b1Response = new ProduceResponse.PartitionResponse( + Errors.NONE, 500L, b1AppendTime, 0L); + b1.complete(500L, b1AppendTime); + transactionManager.handleCompletedBatch(b1, b1Response); + + // We get an UNKNOWN_PRODUCER_ID, so bump the epoch and set sequence numbers back to 0 + ProduceResponse.PartitionResponse b2Response = new ProduceResponse.PartitionResponse( + Errors.UNKNOWN_PRODUCER_ID, -1, -1, 500L); + assertTrue(transactionManager.canRetry(b2Response, b2)); + + // Run sender loop to trigger epoch bump + runUntil(() -> transactionManager.producerIdAndEpoch().epoch == 2); + assertEquals(2, b2.producerEpoch()); + assertEquals(0, b2.baseSequence()); + assertEquals(1, b3.baseSequence()); + assertEquals(2, b4.baseSequence()); + assertEquals(3, b5.baseSequence()); + } + + @Test + public void testBatchFailureAfterProducerReset() { + // This tests a scenario where the producerId is reset while pending requests are still inflight. 
+ // The partition(s) that triggered the reset will have their sequence number reset, while any others will not + final short epoch = Short.MAX_VALUE; + + initializeTransactionManager(Optional.empty()); + initializeIdempotentProducerId(producerId, epoch); + + ProducerBatch tp0b1 = writeIdempotentBatchWithValue(transactionManager, tp0, "1"); + ProducerBatch tp1b1 = writeIdempotentBatchWithValue(transactionManager, tp1, "1"); + + ProduceResponse.PartitionResponse tp0b1Response = new ProduceResponse.PartitionResponse( + Errors.NONE, -1, -1, 400L); + transactionManager.handleCompletedBatch(tp0b1, tp0b1Response); + + ProduceResponse.PartitionResponse tp1b1Response = new ProduceResponse.PartitionResponse( + Errors.NONE, -1, -1, 400L); + transactionManager.handleCompletedBatch(tp1b1, tp1b1Response); + + ProducerBatch tp0b2 = writeIdempotentBatchWithValue(transactionManager, tp0, "2"); + ProducerBatch tp1b2 = writeIdempotentBatchWithValue(transactionManager, tp1, "2"); + assertEquals(2, transactionManager.sequenceNumber(tp0).intValue()); + assertEquals(2, transactionManager.sequenceNumber(tp1).intValue()); + + ProduceResponse.PartitionResponse b1Response = new ProduceResponse.PartitionResponse( + Errors.UNKNOWN_PRODUCER_ID, -1, -1, 400L); + assertTrue(transactionManager.canRetry(b1Response, tp0b1)); + + ProduceResponse.PartitionResponse b2Response = new ProduceResponse.PartitionResponse( + Errors.NONE, -1, -1, 400L); + transactionManager.handleCompletedBatch(tp1b1, b2Response); + + transactionManager.bumpIdempotentEpochAndResetIdIfNeeded(); + + assertEquals(1, transactionManager.sequenceNumber(tp0).intValue()); + assertEquals(tp0b2, transactionManager.nextBatchBySequence(tp0)); + assertEquals(2, transactionManager.sequenceNumber(tp1).intValue()); + assertEquals(tp1b2, transactionManager.nextBatchBySequence(tp1)); + } + + @Test + public void testBatchCompletedAfterProducerReset() { + final short epoch = Short.MAX_VALUE; + + initializeTransactionManager(Optional.empty()); + initializeIdempotentProducerId(producerId, epoch); + + ProducerBatch b1 = writeIdempotentBatchWithValue(transactionManager, tp0, "1"); + writeIdempotentBatchWithValue(transactionManager, tp1, "1"); + + ProducerBatch b2 = writeIdempotentBatchWithValue(transactionManager, tp0, "2"); + assertEquals(2, transactionManager.sequenceNumber(tp0).intValue()); + + // The producerId might be reset due to a failure on another partition + transactionManager.requestEpochBumpForPartition(tp1); + transactionManager.bumpIdempotentEpochAndResetIdIfNeeded(); + initializeIdempotentProducerId(producerId + 1, (short) 0); + + // We continue to track the state of tp0 until in-flight requests complete + ProduceResponse.PartitionResponse b1Response = new ProduceResponse.PartitionResponse( + Errors.NONE, 500L, time.milliseconds(), 0L); + transactionManager.handleCompletedBatch(b1, b1Response); + + assertEquals(2, transactionManager.sequenceNumber(tp0).intValue()); + assertEquals(0, transactionManager.lastAckedSequence(tp0).getAsInt()); + assertEquals(b2, transactionManager.nextBatchBySequence(tp0)); + assertEquals(epoch, transactionManager.nextBatchBySequence(tp0).producerEpoch()); + + ProduceResponse.PartitionResponse b2Response = new ProduceResponse.PartitionResponse( + Errors.NONE, 500L, time.milliseconds(), 0L); + transactionManager.handleCompletedBatch(b2, b2Response); + + transactionManager.maybeUpdateProducerIdAndEpoch(tp0); + assertEquals(0, transactionManager.sequenceNumber(tp0).intValue()); + 
assertFalse(transactionManager.lastAckedSequence(tp0).isPresent()); + assertNull(transactionManager.nextBatchBySequence(tp0)); + } + + private ProducerBatch writeIdempotentBatchWithValue(TransactionManager manager, + TopicPartition tp, + String value) { + manager.maybeUpdateProducerIdAndEpoch(tp); + int seq = manager.sequenceNumber(tp); + manager.incrementSequenceNumber(tp, 1); + ProducerBatch batch = batchWithValue(tp, value); + batch.setProducerState(manager.producerIdAndEpoch(), seq, false); + manager.addInFlightBatch(batch); + batch.close(); + return batch; + } + + private ProducerBatch batchWithValue(TopicPartition tp, String value) { + MemoryRecordsBuilder builder = MemoryRecords.builder(ByteBuffer.allocate(64), + CompressionType.NONE, TimestampType.CREATE_TIME, 0L); + long currentTimeMs = time.milliseconds(); + ProducerBatch batch = new ProducerBatch(tp, builder, currentTimeMs); + batch.tryAppend(currentTimeMs, new byte[0], value.getBytes(), new Header[0], null, currentTimeMs); + return batch; + } + + @Test + public void testSequenceNumberOverflow() { + initializeTransactionManager(Optional.empty()); + assertEquals((int) transactionManager.sequenceNumber(tp0), 0); + transactionManager.incrementSequenceNumber(tp0, Integer.MAX_VALUE); + assertEquals((int) transactionManager.sequenceNumber(tp0), Integer.MAX_VALUE); + transactionManager.incrementSequenceNumber(tp0, 100); + assertEquals((int) transactionManager.sequenceNumber(tp0), 99); + transactionManager.incrementSequenceNumber(tp0, Integer.MAX_VALUE); + assertEquals((int) transactionManager.sequenceNumber(tp0), 98); + } + + @Test + public void testProducerIdReset() { + initializeTransactionManager(Optional.empty()); + initializeIdempotentProducerId(15L, Short.MAX_VALUE); + assertEquals((int) transactionManager.sequenceNumber(tp0), 0); + assertEquals((int) transactionManager.sequenceNumber(tp1), 0); + transactionManager.incrementSequenceNumber(tp0, 3); + assertEquals((int) transactionManager.sequenceNumber(tp0), 3); + transactionManager.incrementSequenceNumber(tp1, 3); + assertEquals((int) transactionManager.sequenceNumber(tp1), 3); + + transactionManager.requestEpochBumpForPartition(tp0); + transactionManager.bumpIdempotentEpochAndResetIdIfNeeded(); + assertEquals((int) transactionManager.sequenceNumber(tp0), 0); + assertEquals((int) transactionManager.sequenceNumber(tp1), 3); + } + + @Test + public void testBasicTransaction() throws InterruptedException { + doInitTransactions(); + + transactionManager.beginTransaction(); + transactionManager.failIfNotReadyForSend(); + transactionManager.maybeAddPartitionToTransaction(tp0); + + Future responseFuture = appendToAccumulator(tp0); + + assertFalse(responseFuture.isDone()); + prepareAddPartitionsToTxnResponse(Errors.NONE, tp0, epoch, producerId); + + prepareProduceResponse(Errors.NONE, producerId, epoch); + assertFalse(transactionManager.transactionContainsPartition(tp0)); + assertFalse(transactionManager.isSendToPartitionAllowed(tp0)); + runUntil(() -> transactionManager.transactionContainsPartition(tp0)); + assertTrue(transactionManager.isSendToPartitionAllowed(tp0)); + assertFalse(responseFuture.isDone()); + runUntil(responseFuture::isDone); + + Map offsets = new HashMap<>(); + offsets.put(tp1, new OffsetAndMetadata(1)); + + TransactionalRequestResult addOffsetsResult = transactionManager.sendOffsetsToTransaction( + offsets, new ConsumerGroupMetadata(consumerGroupId)); + + assertFalse(transactionManager.hasPendingOffsetCommits()); + + prepareAddOffsetsToTxnResponse(Errors.NONE, 
consumerGroupId, producerId, epoch); + + runUntil(transactionManager::hasPendingOffsetCommits); + assertFalse(addOffsetsResult.isCompleted()); // the result doesn't complete until TxnOffsetCommit returns + + Map txnOffsetCommitResponse = new HashMap<>(); + txnOffsetCommitResponse.put(tp1, Errors.NONE); + + prepareFindCoordinatorResponse(Errors.NONE, false, CoordinatorType.GROUP, consumerGroupId); + prepareTxnOffsetCommitResponse(consumerGroupId, producerId, epoch, txnOffsetCommitResponse); + + assertNull(transactionManager.coordinator(CoordinatorType.GROUP)); + runUntil(() -> transactionManager.coordinator(CoordinatorType.GROUP) != null); + assertTrue(transactionManager.hasPendingOffsetCommits()); + + runUntil(() -> !transactionManager.hasPendingOffsetCommits()); + assertTrue(addOffsetsResult.isCompleted()); // We should only be done after both RPCs complete. + + transactionManager.beginCommit(); + prepareEndTxnResponse(Errors.NONE, TransactionResult.COMMIT, producerId, epoch); + runUntil(() -> !transactionManager.hasOngoingTransaction()); + assertFalse(transactionManager.isCompleting()); + assertFalse(transactionManager.transactionContainsPartition(tp0)); + } + + @Test + public void testDisconnectAndRetry() { + // This is called from the initTransactions method in the producer as the first order of business. + // It finds the coordinator and then gets a PID. + transactionManager.initializeTransactions(); + prepareFindCoordinatorResponse(Errors.NONE, true, CoordinatorType.TRANSACTION, transactionalId); + runUntil(() -> transactionManager.coordinator(CoordinatorType.TRANSACTION) == null); + + prepareFindCoordinatorResponse(Errors.NONE, false, CoordinatorType.TRANSACTION, transactionalId); + runUntil(() -> transactionManager.coordinator(CoordinatorType.TRANSACTION) != null); + assertEquals(brokerNode, transactionManager.coordinator(CoordinatorType.TRANSACTION)); + } + + @Test + public void testInitializeTransactionsTwiceRaisesError() { + doInitTransactions(producerId, epoch); + assertTrue(transactionManager.hasProducerId()); + assertThrows(KafkaException.class, () -> transactionManager.initializeTransactions()); + } + + @Test + public void testUnsupportedFindCoordinator() { + transactionManager.initializeTransactions(); + client.prepareUnsupportedVersionResponse(body -> { + FindCoordinatorRequest findCoordinatorRequest = (FindCoordinatorRequest) body; + assertEquals(CoordinatorType.forId(findCoordinatorRequest.data().keyType()), CoordinatorType.TRANSACTION); + assertEquals(findCoordinatorRequest.data().key(), transactionalId); + return true; + }); + + runUntil(transactionManager::hasFatalError); + assertTrue(transactionManager.hasFatalError()); + assertTrue(transactionManager.lastError() instanceof UnsupportedVersionException); + } + + @Test + public void testUnsupportedInitTransactions() { + transactionManager.initializeTransactions(); + prepareFindCoordinatorResponse(Errors.NONE, false, CoordinatorType.TRANSACTION, transactionalId); + runUntil(() -> transactionManager.coordinator(CoordinatorType.TRANSACTION) != null); + assertFalse(transactionManager.hasError()); + + client.prepareUnsupportedVersionResponse(body -> { + InitProducerIdRequest initProducerIdRequest = (InitProducerIdRequest) body; + assertEquals(initProducerIdRequest.data().transactionalId(), transactionalId); + assertEquals(initProducerIdRequest.data().transactionTimeoutMs(), transactionTimeoutMs); + return true; + }); + + runUntil(transactionManager::hasFatalError); + assertTrue(transactionManager.hasFatalError()); + 
assertTrue(transactionManager.lastError() instanceof UnsupportedVersionException); + } + + @Test + public void testUnsupportedForMessageFormatInTxnOffsetCommit() { + final TopicPartition tp = new TopicPartition("foo", 0); + + doInitTransactions(); + + transactionManager.beginTransaction(); + TransactionalRequestResult sendOffsetsResult = transactionManager.sendOffsetsToTransaction( + singletonMap(tp, new OffsetAndMetadata(39L)), new ConsumerGroupMetadata(consumerGroupId)); + + prepareAddOffsetsToTxnResponse(Errors.NONE, consumerGroupId, producerId, epoch); + prepareFindCoordinatorResponse(Errors.NONE, false, CoordinatorType.GROUP, consumerGroupId); + prepareTxnOffsetCommitResponse(consumerGroupId, producerId, epoch, singletonMap(tp, Errors.UNSUPPORTED_FOR_MESSAGE_FORMAT)); + runUntil(transactionManager::hasError); + + assertTrue(transactionManager.lastError() instanceof UnsupportedForMessageFormatException); + assertTrue(sendOffsetsResult.isCompleted()); + assertFalse(sendOffsetsResult.isSuccessful()); + assertTrue(sendOffsetsResult.error() instanceof UnsupportedForMessageFormatException); + assertFatalError(UnsupportedForMessageFormatException.class); + } + + @Test + public void testFencedInstanceIdInTxnOffsetCommitByGroupMetadata() { + final TopicPartition tp = new TopicPartition("foo", 0); + final String fencedMemberId = "fenced_member"; + + doInitTransactions(); + + transactionManager.beginTransaction(); + + TransactionalRequestResult sendOffsetsResult = transactionManager.sendOffsetsToTransaction( + singletonMap(tp, new OffsetAndMetadata(39L)), + new ConsumerGroupMetadata(consumerGroupId, 5, fencedMemberId, Optional.of(groupInstanceId))); + + prepareAddOffsetsToTxnResponse(Errors.NONE, consumerGroupId, producerId, epoch); + prepareFindCoordinatorResponse(Errors.NONE, false, CoordinatorType.GROUP, consumerGroupId); + runUntil(() -> transactionManager.coordinator(CoordinatorType.GROUP) != null); + + client.prepareResponse(request -> { + TxnOffsetCommitRequest txnOffsetCommitRequest = (TxnOffsetCommitRequest) request; + assertEquals(consumerGroupId, txnOffsetCommitRequest.data().groupId()); + assertEquals(producerId, txnOffsetCommitRequest.data().producerId()); + assertEquals(epoch, txnOffsetCommitRequest.data().producerEpoch()); + return txnOffsetCommitRequest.data().groupInstanceId().equals(groupInstanceId) + && !txnOffsetCommitRequest.data().memberId().equals(memberId); + }, new TxnOffsetCommitResponse(0, singletonMap(tp, Errors.FENCED_INSTANCE_ID))); + + runUntil(transactionManager::hasError); + assertTrue(transactionManager.lastError() instanceof FencedInstanceIdException); + assertTrue(sendOffsetsResult.isCompleted()); + assertFalse(sendOffsetsResult.isSuccessful()); + assertTrue(sendOffsetsResult.error() instanceof FencedInstanceIdException); + assertAbortableError(FencedInstanceIdException.class); + } + + @Test + public void testUnknownMemberIdInTxnOffsetCommitByGroupMetadata() { + final TopicPartition tp = new TopicPartition("foo", 0); + final String unknownMemberId = "unknownMember"; + + doInitTransactions(); + + transactionManager.beginTransaction(); + + TransactionalRequestResult sendOffsetsResult = transactionManager.sendOffsetsToTransaction( + singletonMap(tp, new OffsetAndMetadata(39L)), + new ConsumerGroupMetadata(consumerGroupId, 5, unknownMemberId, Optional.empty())); + + prepareAddOffsetsToTxnResponse(Errors.NONE, consumerGroupId, producerId, epoch); + prepareFindCoordinatorResponse(Errors.NONE, false, CoordinatorType.GROUP, consumerGroupId); + runUntil(() -> 
transactionManager.coordinator(CoordinatorType.GROUP) != null); + + client.prepareResponse(request -> { + TxnOffsetCommitRequest txnOffsetCommitRequest = (TxnOffsetCommitRequest) request; + assertEquals(consumerGroupId, txnOffsetCommitRequest.data().groupId()); + assertEquals(producerId, txnOffsetCommitRequest.data().producerId()); + assertEquals(epoch, txnOffsetCommitRequest.data().producerEpoch()); + return !txnOffsetCommitRequest.data().memberId().equals(memberId); + }, new TxnOffsetCommitResponse(0, singletonMap(tp, Errors.UNKNOWN_MEMBER_ID))); + + runUntil(transactionManager::hasError); + assertTrue(transactionManager.lastError() instanceof CommitFailedException); + assertTrue(sendOffsetsResult.isCompleted()); + assertFalse(sendOffsetsResult.isSuccessful()); + assertTrue(sendOffsetsResult.error() instanceof CommitFailedException); + assertAbortableError(CommitFailedException.class); + } + + @Test + public void testIllegalGenerationInTxnOffsetCommitByGroupMetadata() { + final TopicPartition tp = new TopicPartition("foo", 0); + final int illegalGenerationId = 1; + + doInitTransactions(); + + transactionManager.beginTransaction(); + + TransactionalRequestResult sendOffsetsResult = transactionManager.sendOffsetsToTransaction( + singletonMap(tp, new OffsetAndMetadata(39L)), + new ConsumerGroupMetadata(consumerGroupId, illegalGenerationId, JoinGroupRequest.UNKNOWN_MEMBER_ID, + Optional.empty())); + + prepareAddOffsetsToTxnResponse(Errors.NONE, consumerGroupId, producerId, epoch); + prepareFindCoordinatorResponse(Errors.NONE, false, CoordinatorType.GROUP, consumerGroupId); + runUntil(() -> transactionManager.coordinator(CoordinatorType.GROUP) != null); + + prepareTxnOffsetCommitResponse(consumerGroupId, producerId, epoch, singletonMap(tp, Errors.ILLEGAL_GENERATION)); + client.prepareResponse(request -> { + TxnOffsetCommitRequest txnOffsetCommitRequest = (TxnOffsetCommitRequest) request; + assertEquals(consumerGroupId, txnOffsetCommitRequest.data().groupId()); + assertEquals(producerId, txnOffsetCommitRequest.data().producerId()); + assertEquals(epoch, txnOffsetCommitRequest.data().producerEpoch()); + return txnOffsetCommitRequest.data().generationId() != generationId; + }, new TxnOffsetCommitResponse(0, singletonMap(tp, Errors.ILLEGAL_GENERATION))); + + runUntil(transactionManager::hasError); + assertTrue(transactionManager.lastError() instanceof CommitFailedException); + assertTrue(sendOffsetsResult.isCompleted()); + assertFalse(sendOffsetsResult.isSuccessful()); + assertTrue(sendOffsetsResult.error() instanceof CommitFailedException); + assertAbortableError(CommitFailedException.class); + } + + @Test + public void testLookupCoordinatorOnDisconnectAfterSend() { + // This is called from the initTransactions method in the producer as the first order of business. + // It finds the coordinator and then gets a PID. + TransactionalRequestResult initPidResult = transactionManager.initializeTransactions(); + prepareFindCoordinatorResponse(Errors.NONE, false, CoordinatorType.TRANSACTION, transactionalId); + runUntil(() -> transactionManager.coordinator(CoordinatorType.TRANSACTION) != null); + assertEquals(brokerNode, transactionManager.coordinator(CoordinatorType.TRANSACTION)); + + prepareInitPidResponse(Errors.NONE, true, producerId, epoch); + // send pid to coordinator, should get disconnected before receiving the response, and resend the + // FindCoordinator and InitPid requests. 
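+        // The disconnect clears the cached transaction coordinator, so another FindCoordinator
+        // round trip is required before the InitProducerId request can be retried.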
+ runUntil(() -> transactionManager.coordinator(CoordinatorType.TRANSACTION) == null); + + assertNull(transactionManager.coordinator(CoordinatorType.TRANSACTION)); + assertFalse(initPidResult.isCompleted()); + assertFalse(transactionManager.hasProducerId()); + + prepareFindCoordinatorResponse(Errors.NONE, false, CoordinatorType.TRANSACTION, transactionalId); + runUntil(() -> transactionManager.coordinator(CoordinatorType.TRANSACTION) != null); + + assertEquals(brokerNode, transactionManager.coordinator(CoordinatorType.TRANSACTION)); + assertFalse(initPidResult.isCompleted()); + prepareInitPidResponse(Errors.NONE, false, producerId, epoch); + runUntil(initPidResult::isCompleted); + + assertTrue(initPidResult.isCompleted()); // The future should only return after the second round of retries succeed. + assertTrue(transactionManager.hasProducerId()); + assertEquals(producerId, transactionManager.producerIdAndEpoch().producerId); + assertEquals(epoch, transactionManager.producerIdAndEpoch().epoch); + } + + @Test + public void testLookupCoordinatorOnDisconnectBeforeSend() { + // This is called from the initTransactions method in the producer as the first order of business. + // It finds the coordinator and then gets a PID. + TransactionalRequestResult initPidResult = transactionManager.initializeTransactions(); + prepareFindCoordinatorResponse(Errors.NONE, false, CoordinatorType.TRANSACTION, transactionalId); + runUntil(() -> transactionManager.coordinator(CoordinatorType.TRANSACTION) != null); + assertEquals(brokerNode, transactionManager.coordinator(CoordinatorType.TRANSACTION)); + + client.disconnect(brokerNode.idString()); + client.backoff(brokerNode, 100); + // send pid to coordinator. Should get disconnected before the send and resend the FindCoordinator + // and InitPid requests. + runUntil(() -> transactionManager.coordinator(CoordinatorType.TRANSACTION) == null); + time.sleep(110); // waiting for the backoff period for the node to expire. + + assertFalse(initPidResult.isCompleted()); + assertFalse(transactionManager.hasProducerId()); + + prepareFindCoordinatorResponse(Errors.NONE, false, CoordinatorType.TRANSACTION, transactionalId); + runUntil(() -> transactionManager.coordinator(CoordinatorType.TRANSACTION) != null); + assertEquals(brokerNode, transactionManager.coordinator(CoordinatorType.TRANSACTION)); + assertFalse(initPidResult.isCompleted()); + prepareInitPidResponse(Errors.NONE, false, producerId, epoch); + + runUntil(initPidResult::isCompleted); + assertTrue(transactionManager.hasProducerId()); + assertEquals(producerId, transactionManager.producerIdAndEpoch().producerId); + assertEquals(epoch, transactionManager.producerIdAndEpoch().epoch); + } + + @Test + public void testLookupCoordinatorOnNotCoordinatorError() { + // This is called from the initTransactions method in the producer as the first order of business. + // It finds the coordinator and then gets a PID. 
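+        // In this case the InitProducerId request is answered with NOT_COORDINATOR, which should
+        // also invalidate the cached coordinator and force a new FindCoordinator lookup.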
+ TransactionalRequestResult initPidResult = transactionManager.initializeTransactions(); + prepareFindCoordinatorResponse(Errors.NONE, false, CoordinatorType.TRANSACTION, transactionalId); + runUntil(() -> transactionManager.coordinator(CoordinatorType.TRANSACTION) != null); + assertEquals(brokerNode, transactionManager.coordinator(CoordinatorType.TRANSACTION)); + + prepareInitPidResponse(Errors.NOT_COORDINATOR, false, producerId, epoch); + runUntil(() -> transactionManager.coordinator(CoordinatorType.TRANSACTION) == null); + + assertFalse(initPidResult.isCompleted()); + assertFalse(transactionManager.hasProducerId()); + + prepareFindCoordinatorResponse(Errors.NONE, false, CoordinatorType.TRANSACTION, transactionalId); + runUntil(() -> transactionManager.coordinator(CoordinatorType.TRANSACTION) != null); + assertEquals(brokerNode, transactionManager.coordinator(CoordinatorType.TRANSACTION)); + assertFalse(initPidResult.isCompleted()); + prepareInitPidResponse(Errors.NONE, false, producerId, epoch); + + runUntil(initPidResult::isCompleted); + assertTrue(transactionManager.hasProducerId()); + assertEquals(producerId, transactionManager.producerIdAndEpoch().producerId); + assertEquals(epoch, transactionManager.producerIdAndEpoch().epoch); + } + + @Test + public void testTransactionalIdAuthorizationFailureInFindCoordinator() { + TransactionalRequestResult initPidResult = transactionManager.initializeTransactions(); + prepareFindCoordinatorResponse(Errors.TRANSACTIONAL_ID_AUTHORIZATION_FAILED, false, + CoordinatorType.TRANSACTION, transactionalId); + + runUntil(transactionManager::hasError); + + assertTrue(transactionManager.hasFatalError()); + assertTrue(transactionManager.lastError() instanceof TransactionalIdAuthorizationException); + assertFalse(initPidResult.isSuccessful()); + assertTrue(initPidResult.error() instanceof TransactionalIdAuthorizationException); + assertFatalError(TransactionalIdAuthorizationException.class); + } + + @Test + public void testTransactionalIdAuthorizationFailureInInitProducerId() { + TransactionalRequestResult initPidResult = transactionManager.initializeTransactions(); + prepareFindCoordinatorResponse(Errors.NONE, false, CoordinatorType.TRANSACTION, transactionalId); + runUntil(() -> transactionManager.coordinator(CoordinatorType.TRANSACTION) != null); + assertEquals(brokerNode, transactionManager.coordinator(CoordinatorType.TRANSACTION)); + + prepareInitPidResponse(Errors.TRANSACTIONAL_ID_AUTHORIZATION_FAILED, false, producerId, RecordBatch.NO_PRODUCER_EPOCH); + runUntil(transactionManager::hasError); + assertTrue(initPidResult.isCompleted()); + assertFalse(initPidResult.isSuccessful()); + assertTrue(initPidResult.error() instanceof TransactionalIdAuthorizationException); + + assertFatalError(TransactionalIdAuthorizationException.class); + } + + @Test + public void testGroupAuthorizationFailureInFindCoordinator() { + doInitTransactions(); + + transactionManager.beginTransaction(); + TransactionalRequestResult sendOffsetsResult = transactionManager.sendOffsetsToTransaction( + singletonMap(new TopicPartition("foo", 0), new OffsetAndMetadata(39L)), new ConsumerGroupMetadata(consumerGroupId)); + + prepareAddOffsetsToTxnResponse(Errors.NONE, consumerGroupId, producerId, epoch); + runUntil(() -> !transactionManager.hasPartitionsToAdd()); + + prepareFindCoordinatorResponse(Errors.GROUP_AUTHORIZATION_FAILED, false, CoordinatorType.GROUP, consumerGroupId); + runUntil(transactionManager::hasError); + assertTrue(transactionManager.lastError() instanceof 
GroupAuthorizationException); + + runUntil(sendOffsetsResult::isCompleted); + assertFalse(sendOffsetsResult.isSuccessful()); + assertTrue(sendOffsetsResult.error() instanceof GroupAuthorizationException); + + GroupAuthorizationException exception = (GroupAuthorizationException) sendOffsetsResult.error(); + assertEquals(consumerGroupId, exception.groupId()); + + assertAbortableError(GroupAuthorizationException.class); + } + + @Test + public void testGroupAuthorizationFailureInTxnOffsetCommit() { + final TopicPartition tp1 = new TopicPartition("foo", 0); + + doInitTransactions(); + + transactionManager.beginTransaction(); + TransactionalRequestResult sendOffsetsResult = transactionManager.sendOffsetsToTransaction( + singletonMap(tp1, new OffsetAndMetadata(39L)), new ConsumerGroupMetadata(consumerGroupId)); + + prepareAddOffsetsToTxnResponse(Errors.NONE, consumerGroupId, producerId, epoch); + runUntil(() -> !transactionManager.hasPartitionsToAdd()); + + prepareFindCoordinatorResponse(Errors.NONE, false, CoordinatorType.GROUP, consumerGroupId); + prepareTxnOffsetCommitResponse(consumerGroupId, producerId, epoch, singletonMap(tp1, Errors.GROUP_AUTHORIZATION_FAILED)); + + runUntil(transactionManager::hasError); + assertTrue(transactionManager.lastError() instanceof GroupAuthorizationException); + assertTrue(sendOffsetsResult.isCompleted()); + assertFalse(sendOffsetsResult.isSuccessful()); + assertTrue(sendOffsetsResult.error() instanceof GroupAuthorizationException); + assertFalse(transactionManager.hasPendingOffsetCommits()); + + GroupAuthorizationException exception = (GroupAuthorizationException) sendOffsetsResult.error(); + assertEquals(consumerGroupId, exception.groupId()); + + assertAbortableError(GroupAuthorizationException.class); + } + + @Test + public void testTransactionalIdAuthorizationFailureInAddOffsetsToTxn() { + final TopicPartition tp = new TopicPartition("foo", 0); + + doInitTransactions(); + + transactionManager.beginTransaction(); + TransactionalRequestResult sendOffsetsResult = transactionManager.sendOffsetsToTransaction( + singletonMap(tp, new OffsetAndMetadata(39L)), new ConsumerGroupMetadata(consumerGroupId)); + + prepareAddOffsetsToTxnResponse(Errors.TRANSACTIONAL_ID_AUTHORIZATION_FAILED, consumerGroupId, producerId, epoch); + runUntil(transactionManager::hasError); + assertTrue(transactionManager.lastError() instanceof TransactionalIdAuthorizationException); + assertTrue(sendOffsetsResult.isCompleted()); + assertFalse(sendOffsetsResult.isSuccessful()); + assertTrue(sendOffsetsResult.error() instanceof TransactionalIdAuthorizationException); + + assertFatalError(TransactionalIdAuthorizationException.class); + } + + @Test + public void testTransactionalIdAuthorizationFailureInTxnOffsetCommit() { + final TopicPartition tp = new TopicPartition("foo", 0); + + doInitTransactions(); + + transactionManager.beginTransaction(); + TransactionalRequestResult sendOffsetsResult = transactionManager.sendOffsetsToTransaction( + singletonMap(tp, new OffsetAndMetadata(39L)), new ConsumerGroupMetadata(consumerGroupId)); + + prepareAddOffsetsToTxnResponse(Errors.NONE, consumerGroupId, producerId, epoch); + runUntil(() -> !transactionManager.hasPartitionsToAdd()); + + prepareFindCoordinatorResponse(Errors.NONE, false, CoordinatorType.GROUP, consumerGroupId); + prepareTxnOffsetCommitResponse(consumerGroupId, producerId, epoch, singletonMap(tp, Errors.TRANSACTIONAL_ID_AUTHORIZATION_FAILED)); + runUntil(transactionManager::hasError); + + assertTrue(transactionManager.lastError() instanceof 
TransactionalIdAuthorizationException); + assertTrue(sendOffsetsResult.isCompleted()); + assertFalse(sendOffsetsResult.isSuccessful()); + assertTrue(sendOffsetsResult.error() instanceof TransactionalIdAuthorizationException); + + assertFatalError(TransactionalIdAuthorizationException.class); + } + + @Test + public void testTopicAuthorizationFailureInAddPartitions() throws InterruptedException { + final TopicPartition tp0 = new TopicPartition("foo", 0); + final TopicPartition tp1 = new TopicPartition("bar", 0); + + doInitTransactions(); + + transactionManager.beginTransaction(); + transactionManager.failIfNotReadyForSend(); + transactionManager.maybeAddPartitionToTransaction(tp0); + transactionManager.failIfNotReadyForSend(); + transactionManager.maybeAddPartitionToTransaction(tp1); + + FutureRecordMetadata firstPartitionAppend = appendToAccumulator(tp0); + FutureRecordMetadata secondPartitionAppend = appendToAccumulator(tp1); + + Map errors = new HashMap<>(); + errors.put(tp0, Errors.TOPIC_AUTHORIZATION_FAILED); + errors.put(tp1, Errors.OPERATION_NOT_ATTEMPTED); + + prepareAddPartitionsToTxn(errors); + runUntil(transactionManager::hasError); + + assertTrue(transactionManager.lastError() instanceof TopicAuthorizationException); + assertFalse(transactionManager.isPartitionPendingAdd(tp0)); + assertFalse(transactionManager.isPartitionPendingAdd(tp1)); + assertFalse(transactionManager.isPartitionAdded(tp0)); + assertFalse(transactionManager.isPartitionAdded(tp1)); + assertFalse(transactionManager.hasPartitionsToAdd()); + + TopicAuthorizationException exception = (TopicAuthorizationException) transactionManager.lastError(); + assertEquals(singleton(tp0.topic()), exception.unauthorizedTopics()); + assertAbortableError(TopicAuthorizationException.class); + sender.runOnce(); + + TestUtils.assertFutureThrows(firstPartitionAppend, KafkaException.class); + TestUtils.assertFutureThrows(secondPartitionAppend, KafkaException.class); + } + + @Test + public void testCommitWithTopicAuthorizationFailureInAddPartitionsInFlight() throws InterruptedException { + final TopicPartition tp0 = new TopicPartition("foo", 0); + final TopicPartition tp1 = new TopicPartition("bar", 0); + + doInitTransactions(); + + // Begin a transaction, send two records, and begin commit + transactionManager.beginTransaction(); + transactionManager.maybeAddPartitionToTransaction(tp0); + transactionManager.maybeAddPartitionToTransaction(tp1); + FutureRecordMetadata firstPartitionAppend = appendToAccumulator(tp0); + FutureRecordMetadata secondPartitionAppend = appendToAccumulator(tp1); + TransactionalRequestResult commitResult = transactionManager.beginCommit(); + + // We send the AddPartitionsToTxn request in the first sender call + sender.runOnce(); + assertFalse(transactionManager.hasError()); + assertFalse(commitResult.isCompleted()); + assertFalse(firstPartitionAppend.isDone()); + + // The AddPartitionsToTxn response returns in the next call with the error + Map errors = new HashMap<>(); + errors.put(tp0, Errors.TOPIC_AUTHORIZATION_FAILED); + errors.put(tp1, Errors.OPERATION_NOT_ATTEMPTED); + client.respond(body -> { + AddPartitionsToTxnRequest request = (AddPartitionsToTxnRequest) body; + assertEquals(new HashSet<>(request.partitions()), new HashSet<>(errors.keySet())); + return true; + }, new AddPartitionsToTxnResponse(0, errors)); + + sender.runOnce(); + assertTrue(transactionManager.hasError()); + assertFalse(commitResult.isCompleted()); + assertFalse(firstPartitionAppend.isDone()); + 
assertFalse(secondPartitionAppend.isDone()); + + // The next call aborts the records, which have not yet been sent. It should + // not block because there are no requests pending and we still need to cancel + // the pending transaction commit. + sender.runOnce(); + assertTrue(commitResult.isCompleted()); + TestUtils.assertFutureThrows(firstPartitionAppend, KafkaException.class); + TestUtils.assertFutureThrows(secondPartitionAppend, KafkaException.class); + assertTrue(commitResult.error() instanceof TopicAuthorizationException); + } + + @Test + public void testRecoveryFromAbortableErrorTransactionNotStarted() throws Exception { + final TopicPartition unauthorizedPartition = new TopicPartition("foo", 0); + + doInitTransactions(); + + transactionManager.beginTransaction(); + transactionManager.failIfNotReadyForSend(); + transactionManager.maybeAddPartitionToTransaction(unauthorizedPartition); + + Future responseFuture = appendToAccumulator(unauthorizedPartition); + + prepareAddPartitionsToTxn(singletonMap(unauthorizedPartition, Errors.TOPIC_AUTHORIZATION_FAILED)); + runUntil(() -> !client.hasPendingResponses()); + + assertTrue(transactionManager.hasAbortableError()); + transactionManager.beginAbort(); + runUntil(responseFuture::isDone); + assertProduceFutureFailed(responseFuture); + + // No partitions added, so no need to prepare EndTxn response + runUntil(transactionManager::isReady); + assertFalse(transactionManager.hasPartitionsToAdd()); + assertFalse(accumulator.hasIncomplete()); + + // ensure we can now start a new transaction + + transactionManager.beginTransaction(); + transactionManager.failIfNotReadyForSend(); + transactionManager.maybeAddPartitionToTransaction(tp0); + + responseFuture = appendToAccumulator(tp0); + + prepareAddPartitionsToTxn(singletonMap(tp0, Errors.NONE)); + runUntil(() -> transactionManager.isPartitionAdded(tp0)); + assertFalse(transactionManager.hasPartitionsToAdd()); + + transactionManager.beginCommit(); + prepareProduceResponse(Errors.NONE, producerId, epoch); + runUntil(responseFuture::isDone); + assertNotNull(responseFuture.get()); + + prepareEndTxnResponse(Errors.NONE, TransactionResult.COMMIT, producerId, epoch); + runUntil(transactionManager::isReady); + } + + @Test + public void testRecoveryFromAbortableErrorTransactionStarted() throws Exception { + final TopicPartition unauthorizedPartition = new TopicPartition("foo", 0); + + doInitTransactions(); + + transactionManager.beginTransaction(); + transactionManager.failIfNotReadyForSend(); + transactionManager.maybeAddPartitionToTransaction(tp0); + prepareAddPartitionsToTxn(tp0, Errors.NONE); + + Future authorizedTopicProduceFuture = appendToAccumulator(unauthorizedPartition); + runUntil(() -> transactionManager.isPartitionAdded(tp0)); + + transactionManager.failIfNotReadyForSend(); + transactionManager.maybeAddPartitionToTransaction(unauthorizedPartition); + Future unauthorizedTopicProduceFuture = appendToAccumulator(unauthorizedPartition); + prepareAddPartitionsToTxn(singletonMap(unauthorizedPartition, Errors.TOPIC_AUTHORIZATION_FAILED)); + runUntil(transactionManager::hasAbortableError); + assertTrue(transactionManager.isPartitionAdded(tp0)); + assertFalse(transactionManager.isPartitionAdded(unauthorizedPartition)); + assertFalse(authorizedTopicProduceFuture.isDone()); + assertFalse(unauthorizedTopicProduceFuture.isDone()); + + prepareEndTxnResponse(Errors.NONE, TransactionResult.ABORT, producerId, epoch); + transactionManager.beginAbort(); + runUntil(transactionManager::isReady); + // neither produce 
request has been sent, so they should both be failed immediately + assertProduceFutureFailed(authorizedTopicProduceFuture); + assertProduceFutureFailed(unauthorizedTopicProduceFuture); + assertFalse(transactionManager.hasPartitionsToAdd()); + assertFalse(accumulator.hasIncomplete()); + + // ensure we can now start a new transaction + + transactionManager.beginTransaction(); + transactionManager.failIfNotReadyForSend(); + transactionManager.maybeAddPartitionToTransaction(tp0); + + FutureRecordMetadata nextTransactionFuture = appendToAccumulator(tp0); + + prepareAddPartitionsToTxn(singletonMap(tp0, Errors.NONE)); + runUntil(() -> transactionManager.isPartitionAdded(tp0)); + assertFalse(transactionManager.hasPartitionsToAdd()); + + transactionManager.beginCommit(); + prepareProduceResponse(Errors.NONE, producerId, epoch); + runUntil(nextTransactionFuture::isDone); + assertNotNull(nextTransactionFuture.get()); + + prepareEndTxnResponse(Errors.NONE, TransactionResult.COMMIT, producerId, epoch); + runUntil(transactionManager::isReady); + } + + @Test + public void testRecoveryFromAbortableErrorProduceRequestInRetry() throws Exception { + final TopicPartition unauthorizedPartition = new TopicPartition("foo", 0); + + doInitTransactions(); + + transactionManager.beginTransaction(); + transactionManager.failIfNotReadyForSend(); + transactionManager.maybeAddPartitionToTransaction(tp0); + prepareAddPartitionsToTxn(tp0, Errors.NONE); + + Future authorizedTopicProduceFuture = appendToAccumulator(tp0); + runUntil(() -> transactionManager.isPartitionAdded(tp0)); + + accumulator.beginFlush(); + prepareProduceResponse(Errors.REQUEST_TIMED_OUT, producerId, epoch); + runUntil(() -> !client.hasPendingResponses()); + assertFalse(authorizedTopicProduceFuture.isDone()); + assertTrue(accumulator.hasIncomplete()); + + transactionManager.failIfNotReadyForSend(); + transactionManager.maybeAddPartitionToTransaction(unauthorizedPartition); + Future unauthorizedTopicProduceFuture = appendToAccumulator(unauthorizedPartition); + prepareAddPartitionsToTxn(singletonMap(unauthorizedPartition, Errors.TOPIC_AUTHORIZATION_FAILED)); + runUntil(transactionManager::hasAbortableError); + assertTrue(transactionManager.isPartitionAdded(tp0)); + assertFalse(transactionManager.isPartitionAdded(unauthorizedPartition)); + assertFalse(authorizedTopicProduceFuture.isDone()); + + prepareProduceResponse(Errors.NONE, producerId, epoch); + runUntil(authorizedTopicProduceFuture::isDone); + + assertProduceFutureFailed(unauthorizedTopicProduceFuture); + assertNotNull(authorizedTopicProduceFuture.get()); + assertTrue(authorizedTopicProduceFuture.isDone()); + + prepareEndTxnResponse(Errors.NONE, TransactionResult.ABORT, producerId, epoch); + transactionManager.beginAbort(); + runUntil(transactionManager::isReady); + // neither produce request has been sent, so they should both be failed immediately + assertTrue(transactionManager.isReady()); + assertFalse(transactionManager.hasPartitionsToAdd()); + assertFalse(accumulator.hasIncomplete()); + + // ensure we can now start a new transaction + + transactionManager.beginTransaction(); + transactionManager.failIfNotReadyForSend(); + transactionManager.maybeAddPartitionToTransaction(tp0); + + FutureRecordMetadata nextTransactionFuture = appendToAccumulator(tp0); + + prepareAddPartitionsToTxn(singletonMap(tp0, Errors.NONE)); + runUntil(() -> transactionManager.isPartitionAdded(tp0)); + assertFalse(transactionManager.hasPartitionsToAdd()); + + transactionManager.beginCommit(); + 
prepareProduceResponse(Errors.NONE, producerId, epoch); + runUntil(nextTransactionFuture::isDone); + assertNotNull(nextTransactionFuture.get()); + + prepareEndTxnResponse(Errors.NONE, TransactionResult.COMMIT, producerId, epoch); + runUntil(transactionManager::isReady); + } + + @Test + public void testTransactionalIdAuthorizationFailureInAddPartitions() { + final TopicPartition tp = new TopicPartition("foo", 0); + + doInitTransactions(); + + transactionManager.beginTransaction(); + transactionManager.failIfNotReadyForSend(); + transactionManager.maybeAddPartitionToTransaction(tp); + + prepareAddPartitionsToTxn(tp, Errors.TRANSACTIONAL_ID_AUTHORIZATION_FAILED); + runUntil(transactionManager::hasError); + assertTrue(transactionManager.lastError() instanceof TransactionalIdAuthorizationException); + + assertFatalError(TransactionalIdAuthorizationException.class); + } + + @Test + public void testFlushPendingPartitionsOnCommit() throws InterruptedException { + doInitTransactions(); + + transactionManager.beginTransaction(); + transactionManager.failIfNotReadyForSend(); + transactionManager.maybeAddPartitionToTransaction(tp0); + + Future responseFuture = appendToAccumulator(tp0); + + assertFalse(responseFuture.isDone()); + + TransactionalRequestResult commitResult = transactionManager.beginCommit(); + + // we have an append, an add partitions request, and now also an endtxn. + // The order should be: + // 1. Add Partitions + // 2. Produce + // 3. EndTxn. + assertFalse(transactionManager.transactionContainsPartition(tp0)); + prepareAddPartitionsToTxnResponse(Errors.NONE, tp0, epoch, producerId); + + runUntil(() -> transactionManager.transactionContainsPartition(tp0)); + assertFalse(responseFuture.isDone()); + assertFalse(commitResult.isCompleted()); + + prepareProduceResponse(Errors.NONE, producerId, epoch); + runUntil(responseFuture::isDone); + + prepareEndTxnResponse(Errors.NONE, TransactionResult.COMMIT, producerId, epoch); + assertFalse(commitResult.isCompleted()); + assertTrue(transactionManager.hasOngoingTransaction()); + assertTrue(transactionManager.isCompleting()); + + runUntil(commitResult::isCompleted); + assertFalse(transactionManager.hasOngoingTransaction()); + } + + @Test + public void testMultipleAddPartitionsPerForOneProduce() throws InterruptedException { + doInitTransactions(); + + transactionManager.beginTransaction(); + // User does one producer.send + transactionManager.failIfNotReadyForSend(); + transactionManager.maybeAddPartitionToTransaction(tp0); + + Future responseFuture = appendToAccumulator(tp0); + + assertFalse(responseFuture.isDone()); + prepareAddPartitionsToTxnResponse(Errors.NONE, tp0, epoch, producerId); + + assertFalse(transactionManager.transactionContainsPartition(tp0)); + + // Sender flushes one add partitions. The produce goes next. + runUntil(() -> transactionManager.transactionContainsPartition(tp0)); + + // In the mean time, the user does a second produce to a different partition + transactionManager.failIfNotReadyForSend(); + transactionManager.maybeAddPartitionToTransaction(tp1); + Future secondResponseFuture = appendToAccumulator(tp0); + + prepareAddPartitionsToTxnResponse(Errors.NONE, tp1, epoch, producerId); + prepareProduceResponse(Errors.NONE, producerId, epoch); + + assertFalse(transactionManager.transactionContainsPartition(tp1)); + + assertFalse(responseFuture.isDone()); + assertFalse(secondResponseFuture.isDone()); + + // The second add partitions should go out here. 
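+        // Neither produce future completes until the produce request itself has been sent and answered.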
+        runUntil(() -> transactionManager.transactionContainsPartition(tp1));
+
+        assertFalse(responseFuture.isDone());
+        assertFalse(secondResponseFuture.isDone());
+
+        // Finally we get to the produce.
+        runUntil(responseFuture::isDone);
+        assertTrue(secondResponseFuture.isDone());
+    }
+
+    @Test
+    public void testProducerFencedExceptionInInitProducerId() {
+        verifyProducerFencedForInitProducerId(Errors.PRODUCER_FENCED);
+    }
+
+    @Test
+    public void testInvalidProducerEpochConvertToProducerFencedInInitProducerId() {
+        verifyProducerFencedForInitProducerId(Errors.INVALID_PRODUCER_EPOCH);
+    }
+
+    private void verifyProducerFencedForInitProducerId(Errors error) {
+        TransactionalRequestResult result = transactionManager.initializeTransactions();
+        prepareFindCoordinatorResponse(Errors.NONE, false, CoordinatorType.TRANSACTION, transactionalId);
+        runUntil(() -> transactionManager.coordinator(CoordinatorType.TRANSACTION) != null);
+        assertEquals(brokerNode, transactionManager.coordinator(CoordinatorType.TRANSACTION));
+
+        prepareInitPidResponse(error, false, producerId, epoch);
+
+        runUntil(transactionManager::hasError);
+
+        assertEquals(ProducerFencedException.class, result.error().getClass());
+
+        assertThrows(ProducerFencedException.class, () -> transactionManager.beginTransaction());
+        assertThrows(ProducerFencedException.class, () -> transactionManager.beginCommit());
+        assertThrows(ProducerFencedException.class, () -> transactionManager.beginAbort());
+        assertThrows(ProducerFencedException.class, () -> transactionManager.sendOffsetsToTransaction(
+            Collections.emptyMap(), new ConsumerGroupMetadata("dummyId")));
+    }
+
+    @Test
+    public void testProducerFencedInAddPartitionToTxn() throws InterruptedException {
+        verifyProducerFencedForAddPartitionsToTxn(Errors.PRODUCER_FENCED);
+    }
+
+    @Test
+    public void testInvalidProducerEpochConvertToProducerFencedInAddPartitionToTxn() throws InterruptedException {
+        verifyProducerFencedForAddPartitionsToTxn(Errors.INVALID_PRODUCER_EPOCH);
+    }
+
+    private void verifyProducerFencedForAddPartitionsToTxn(Errors error) throws InterruptedException {
+        doInitTransactions();
+
+        transactionManager.beginTransaction();
+        transactionManager.failIfNotReadyForSend();
+        transactionManager.maybeAddPartitionToTransaction(tp0);
+
+        Future<RecordMetadata> responseFuture = appendToAccumulator(tp0);
+
+        assertFalse(responseFuture.isDone());
+        prepareAddPartitionsToTxnResponse(error, tp0, epoch, producerId);
+
+        verifyProducerFenced(responseFuture);
+    }
+
+    @Test
+    public void testProducerFencedInAddOffSetsToTxn() throws InterruptedException {
+        verifyProducerFencedForAddOffsetsToTxn(Errors.PRODUCER_FENCED);
+    }
+
+    @Test
+    public void testInvalidProducerEpochConvertToProducerFencedInAddOffSetsToTxn() throws InterruptedException {
+        verifyProducerFencedForAddOffsetsToTxn(Errors.INVALID_PRODUCER_EPOCH);
+    }
+
+    private void verifyProducerFencedForAddOffsetsToTxn(Errors error) throws InterruptedException {
+        doInitTransactions();
+
+        transactionManager.beginTransaction();
+        transactionManager.failIfNotReadyForSend();
+        transactionManager.sendOffsetsToTransaction(Collections.emptyMap(), new ConsumerGroupMetadata(consumerGroupId));
+
+        Future<RecordMetadata> responseFuture = appendToAccumulator(tp0);
+
+        assertFalse(responseFuture.isDone());
+        prepareAddOffsetsToTxnResponse(error, consumerGroupId, producerId, epoch);
+
+        verifyProducerFenced(responseFuture);
+    }
+
+    private void verifyProducerFenced(Future<RecordMetadata> responseFuture) throws InterruptedException {
+        runUntil(responseFuture::isDone);
+
assertTrue(transactionManager.hasError()); + + try { + // make sure the produce was expired. + responseFuture.get(); + fail("Expected to get a ExecutionException from the response"); + } catch (ExecutionException e) { + assertTrue(e.getCause() instanceof ProducerFencedException); + } + + // make sure the exception was thrown directly from the follow-up calls. + assertThrows(ProducerFencedException.class, () -> transactionManager.beginTransaction()); + assertThrows(ProducerFencedException.class, () -> transactionManager.beginCommit()); + assertThrows(ProducerFencedException.class, () -> transactionManager.beginAbort()); + assertThrows(ProducerFencedException.class, () -> transactionManager.sendOffsetsToTransaction( + Collections.emptyMap(), new ConsumerGroupMetadata("dummyId"))); + } + + @Test + public void testInvalidProducerEpochConvertToProducerFencedInEndTxn() throws InterruptedException { + doInitTransactions(); + + transactionManager.beginTransaction(); + transactionManager.failIfNotReadyForSend(); + transactionManager.maybeAddPartitionToTransaction(tp0); + TransactionalRequestResult commitResult = transactionManager.beginCommit(); + + Future responseFuture = appendToAccumulator(tp0); + + assertFalse(responseFuture.isDone()); + prepareAddPartitionsToTxnResponse(Errors.NONE, tp0, epoch, producerId); + prepareProduceResponse(Errors.NONE, producerId, epoch); + prepareEndTxnResponse(Errors.INVALID_PRODUCER_EPOCH, TransactionResult.COMMIT, producerId, epoch); + + runUntil(commitResult::isCompleted); + runUntil(responseFuture::isDone); + + // make sure the exception was thrown directly from the follow-up calls. + assertThrows(KafkaException.class, () -> transactionManager.beginTransaction()); + assertThrows(KafkaException.class, () -> transactionManager.beginCommit()); + assertThrows(KafkaException.class, () -> transactionManager.beginAbort()); + assertThrows(KafkaException.class, () -> transactionManager.sendOffsetsToTransaction( + Collections.emptyMap(), new ConsumerGroupMetadata("dummyId"))); + } + + @Test + public void testInvalidProducerEpochFromProduce() throws InterruptedException { + doInitTransactions(); + + transactionManager.beginTransaction(); + transactionManager.failIfNotReadyForSend(); + transactionManager.maybeAddPartitionToTransaction(tp0); + + Future responseFuture = appendToAccumulator(tp0); + + assertFalse(responseFuture.isDone()); + prepareAddPartitionsToTxnResponse(Errors.NONE, tp0, epoch, producerId); + prepareProduceResponse(Errors.INVALID_PRODUCER_EPOCH, producerId, epoch); + prepareProduceResponse(Errors.NONE, producerId, epoch); + + sender.runOnce(); + + runUntil(responseFuture::isDone); + assertTrue(transactionManager.hasError()); + + transactionManager.beginAbort(); + + TransactionManager.TxnRequestHandler handler = transactionManager.nextRequest(false); + + // First we will get an EndTxn for abort. + assertNotNull(handler); + assertTrue(handler.requestBuilder() instanceof EndTxnRequest.Builder); + + handler = transactionManager.nextRequest(false); + + // Second we will see an InitPid for handling InvalidProducerEpoch. 
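+        // The InitProducerId request re-initializes the producer (bumping the epoch) so that it
+        // can be used again once the abort completes.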
+ assertNotNull(handler); + assertTrue(handler.requestBuilder() instanceof InitProducerIdRequest.Builder); + } + + @Test + public void testDisallowCommitOnProduceFailure() throws InterruptedException { + doInitTransactions(); + + transactionManager.beginTransaction(); + transactionManager.failIfNotReadyForSend(); + transactionManager.maybeAddPartitionToTransaction(tp0); + + Future responseFuture = appendToAccumulator(tp0); + + TransactionalRequestResult commitResult = transactionManager.beginCommit(); + assertFalse(responseFuture.isDone()); + prepareAddPartitionsToTxnResponse(Errors.NONE, tp0, epoch, producerId); + prepareProduceResponse(Errors.OUT_OF_ORDER_SEQUENCE_NUMBER, producerId, epoch); + + runUntil(commitResult::isCompleted); // commit should be cancelled with exception without being sent. + + try { + commitResult.await(); + fail(); // the get() must throw an exception. + } catch (KafkaException e) { + // Expected + } + + try { + responseFuture.get(); + fail("Expected produce future to raise an exception"); + } catch (ExecutionException e) { + assertTrue(e.getCause() instanceof OutOfOrderSequenceException); + } + + // Commit is not allowed, so let's abort and try again. + TransactionalRequestResult abortResult = transactionManager.beginAbort(); + prepareEndTxnResponse(Errors.NONE, TransactionResult.ABORT, producerId, epoch); + prepareInitPidResponse(Errors.NONE, false, producerId, (short) (epoch + 1)); + runUntil(abortResult::isCompleted); + assertTrue(abortResult.isSuccessful()); + assertTrue(transactionManager.isReady()); // make sure we are ready for a transaction now. + } + + @Test + public void testAllowAbortOnProduceFailure() throws InterruptedException { + doInitTransactions(); + + transactionManager.beginTransaction(); + transactionManager.failIfNotReadyForSend(); + transactionManager.maybeAddPartitionToTransaction(tp0); + + Future responseFuture = appendToAccumulator(tp0); + + assertFalse(responseFuture.isDone()); + prepareAddPartitionsToTxnResponse(Errors.NONE, tp0, epoch, producerId); + prepareProduceResponse(Errors.OUT_OF_ORDER_SEQUENCE_NUMBER, producerId, epoch); + + // Because this is a failure that triggers an epoch bump, the abort will trigger an InitProducerId call + runUntil(transactionManager::hasAbortableError); + TransactionalRequestResult abortResult = transactionManager.beginAbort(); + prepareEndTxnResponse(Errors.NONE, TransactionResult.ABORT, producerId, epoch); + prepareInitPidResponse(Errors.NONE, false, producerId, (short) (epoch + 1)); + runUntil(abortResult::isCompleted); + assertTrue(abortResult.isSuccessful()); + assertTrue(transactionManager.isReady()); // make sure we are ready for a transaction now. 
+ } + + @Test + public void testAbortableErrorWhileAbortInProgress() throws InterruptedException { + doInitTransactions(); + + transactionManager.beginTransaction(); + transactionManager.failIfNotReadyForSend(); + transactionManager.maybeAddPartitionToTransaction(tp0); + + Future responseFuture = appendToAccumulator(tp0); + + assertFalse(responseFuture.isDone()); + prepareAddPartitionsToTxnResponse(Errors.NONE, tp0, epoch, producerId); + runUntil(() -> !accumulator.hasUndrained()); + + TransactionalRequestResult abortResult = transactionManager.beginAbort(); + assertTrue(transactionManager.isAborting()); + assertFalse(transactionManager.hasError()); + + sendProduceResponse(Errors.OUT_OF_ORDER_SEQUENCE_NUMBER, producerId, epoch); + prepareEndTxnResponse(Errors.NONE, TransactionResult.ABORT, producerId, epoch); + runUntil(responseFuture::isDone); + + // we do not transition to ABORTABLE_ERROR since we were already aborting + assertTrue(transactionManager.isAborting()); + assertFalse(transactionManager.hasError()); + + runUntil(abortResult::isCompleted); + assertTrue(abortResult.isSuccessful()); + assertTrue(transactionManager.isReady()); // make sure we are ready for a transaction now. + } + + @Test + public void testCommitTransactionWithUnsentProduceRequest() throws Exception { + doInitTransactions(); + + transactionManager.beginTransaction(); + transactionManager.failIfNotReadyForSend(); + transactionManager.maybeAddPartitionToTransaction(tp0); + + Future responseFuture = appendToAccumulator(tp0); + + prepareAddPartitionsToTxn(tp0, Errors.NONE); + runUntil(() -> !client.hasPendingResponses()); + assertTrue(accumulator.hasUndrained()); + + // committing the transaction should cause the unsent batch to be flushed + transactionManager.beginCommit(); + runUntil(() -> !accumulator.hasUndrained()); + assertTrue(accumulator.hasIncomplete()); + assertFalse(transactionManager.hasInFlightRequest()); + assertFalse(responseFuture.isDone()); + + // until the produce future returns, we will not send EndTxn + AtomicInteger numRuns = new AtomicInteger(0); + runUntil(() -> numRuns.incrementAndGet() >= 4); + assertFalse(accumulator.hasUndrained()); + assertTrue(accumulator.hasIncomplete()); + assertFalse(transactionManager.hasInFlightRequest()); + assertFalse(responseFuture.isDone()); + + // now the produce response returns + sendProduceResponse(Errors.NONE, producerId, epoch); + runUntil(responseFuture::isDone); + assertFalse(accumulator.hasUndrained()); + assertFalse(accumulator.hasIncomplete()); + assertFalse(transactionManager.hasInFlightRequest()); + + // now we send EndTxn + runUntil(transactionManager::hasInFlightRequest); + sendEndTxnResponse(Errors.NONE, TransactionResult.COMMIT, producerId, epoch); + + runUntil(transactionManager::isReady); + assertFalse(transactionManager.hasInFlightRequest()); + } + + @Test + public void testCommitTransactionWithInFlightProduceRequest() throws Exception { + doInitTransactions(); + + transactionManager.beginTransaction(); + transactionManager.failIfNotReadyForSend(); + transactionManager.maybeAddPartitionToTransaction(tp0); + + Future responseFuture = appendToAccumulator(tp0); + + prepareAddPartitionsToTxn(tp0, Errors.NONE); + runUntil(() -> !transactionManager.hasPartitionsToAdd()); + assertTrue(accumulator.hasUndrained()); + + accumulator.beginFlush(); + runUntil(() -> !accumulator.hasUndrained()); + assertFalse(accumulator.hasUndrained()); + assertTrue(accumulator.hasIncomplete()); + assertFalse(transactionManager.hasInFlightRequest()); + + // now we begin the 
commit with the produce request still pending + transactionManager.beginCommit(); + AtomicInteger numRuns = new AtomicInteger(0); + runUntil(() -> numRuns.incrementAndGet() >= 4); + assertFalse(accumulator.hasUndrained()); + assertTrue(accumulator.hasIncomplete()); + assertFalse(transactionManager.hasInFlightRequest()); + assertFalse(responseFuture.isDone()); + + // now the produce response returns + sendProduceResponse(Errors.NONE, producerId, epoch); + runUntil(responseFuture::isDone); + assertFalse(accumulator.hasUndrained()); + assertFalse(accumulator.hasIncomplete()); + assertFalse(transactionManager.hasInFlightRequest()); + + // now we send EndTxn + runUntil(transactionManager::hasInFlightRequest); + sendEndTxnResponse(Errors.NONE, TransactionResult.COMMIT, producerId, epoch); + runUntil(transactionManager::isReady); + assertFalse(transactionManager.hasInFlightRequest()); + } + + @Test + public void testFindCoordinatorAllowedInAbortableErrorState() throws InterruptedException { + doInitTransactions(); + + transactionManager.beginTransaction(); + transactionManager.failIfNotReadyForSend(); + transactionManager.maybeAddPartitionToTransaction(tp0); + + Future responseFuture = appendToAccumulator(tp0); + + assertFalse(responseFuture.isDone()); + runUntil(transactionManager::hasInFlightRequest); + + transactionManager.transitionToAbortableError(new KafkaException()); + sendAddPartitionsToTxnResponse(Errors.NOT_COORDINATOR, tp0, epoch, producerId); + runUntil(() -> transactionManager.coordinator(CoordinatorType.TRANSACTION) == null); + + prepareFindCoordinatorResponse(Errors.NONE, false, CoordinatorType.TRANSACTION, transactionalId); + runUntil(() -> transactionManager.coordinator(CoordinatorType.TRANSACTION) != null); + assertEquals(brokerNode, transactionManager.coordinator(CoordinatorType.TRANSACTION)); + assertTrue(transactionManager.hasAbortableError()); + } + + @Test + public void testCancelUnsentAddPartitionsAndProduceOnAbort() throws InterruptedException { + doInitTransactions(); + + transactionManager.beginTransaction(); + transactionManager.failIfNotReadyForSend(); + transactionManager.maybeAddPartitionToTransaction(tp0); + + Future responseFuture = appendToAccumulator(tp0); + + assertFalse(responseFuture.isDone()); + + TransactionalRequestResult abortResult = transactionManager.beginAbort(); + // note since no partitions were added to the transaction, no EndTxn will be sent + + runUntil(abortResult::isCompleted); + assertTrue(abortResult.isSuccessful()); + assertTrue(transactionManager.isReady()); // make sure we are ready for a transaction now. 
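+        // No broker responses are prepared here: because the partition was never actually added to the
+        // transaction, the abort completes locally once the Sender loop runs. runUntil(condition) is a
+        // helper defined elsewhere in this class; it is assumed to be roughly "call sender.runOnce()
+        // until the condition holds or a bounded number of iterations is exceeded".
+        // The queued produce future is expected to fail with a KafkaException, which is verified below.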
+
+        try {
+            responseFuture.get();
+            fail("Expected produce future to raise an exception");
+        } catch (ExecutionException e) {
+            assertTrue(e.getCause() instanceof KafkaException);
+        }
+    }
+
+    @Test
+    public void testAbortResendsAddPartitionErrorIfRetried() throws InterruptedException {
+        doInitTransactions(producerId, epoch);
+
+        transactionManager.beginTransaction();
+        transactionManager.failIfNotReadyForSend();
+        transactionManager.maybeAddPartitionToTransaction(tp0);
+        prepareAddPartitionsToTxnResponse(Errors.UNKNOWN_TOPIC_OR_PARTITION, tp0, epoch, producerId);
+
+        Future<RecordMetadata> responseFuture = appendToAccumulator(tp0);
+
+        runUntil(() -> !client.hasPendingResponses());
+        assertFalse(responseFuture.isDone());
+
+        TransactionalRequestResult abortResult = transactionManager.beginAbort();
+
+        // we should resend the AddPartitions
+        prepareAddPartitionsToTxnResponse(Errors.NONE, tp0, epoch, producerId);
+        prepareEndTxnResponse(Errors.NONE, TransactionResult.ABORT, producerId, epoch);
+
+        runUntil(abortResult::isCompleted);
+        assertTrue(abortResult.isSuccessful());
+        assertTrue(transactionManager.isReady());  // make sure we are ready for a transaction now.
+
+        try {
+            responseFuture.get();
+            fail("Expected produce future to raise an exception");
+        } catch (ExecutionException e) {
+            assertTrue(e.getCause() instanceof KafkaException);
+        }
+    }
+
+    @Test
+    public void testAbortResendsProduceRequestIfRetried() throws Exception {
+        doInitTransactions(producerId, epoch);
+
+        transactionManager.beginTransaction();
+        transactionManager.failIfNotReadyForSend();
+        transactionManager.maybeAddPartitionToTransaction(tp0);
+        prepareAddPartitionsToTxnResponse(Errors.NONE, tp0, epoch, producerId);
+        prepareProduceResponse(Errors.REQUEST_TIMED_OUT, producerId, epoch);
+
+        Future<RecordMetadata> responseFuture = appendToAccumulator(tp0);
+
+        runUntil(() -> !client.hasPendingResponses());
+        assertFalse(responseFuture.isDone());
+
+        TransactionalRequestResult abortResult = transactionManager.beginAbort();
+
+        // we should resend the ProduceRequest before aborting
+        prepareProduceResponse(Errors.NONE, producerId, epoch);
+        prepareEndTxnResponse(Errors.NONE, TransactionResult.ABORT, producerId, epoch);
+
+        runUntil(abortResult::isCompleted);
+        assertTrue(abortResult.isSuccessful());
+        assertTrue(transactionManager.isReady());  // make sure we are ready for a transaction now.
+
+        RecordMetadata recordMetadata = responseFuture.get();
+        assertEquals(tp0.topic(), recordMetadata.topic());
+    }
+
+    @Test
+    public void testHandlingOfUnknownTopicPartitionErrorOnAddPartitions() throws InterruptedException {
+        doInitTransactions();
+
+        transactionManager.beginTransaction();
+        transactionManager.failIfNotReadyForSend();
+        transactionManager.maybeAddPartitionToTransaction(tp0);
+
+        Future<RecordMetadata> responseFuture = appendToAccumulator(tp0);
+
+        assertFalse(responseFuture.isDone());
+        prepareAddPartitionsToTxnResponse(Errors.UNKNOWN_TOPIC_OR_PARTITION, tp0, epoch, producerId);
+
+        runUntil(() -> !client.hasPendingResponses());
+        assertFalse(transactionManager.transactionContainsPartition(tp0));  // The partition should not yet be added.
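+        // UNKNOWN_TOPIC_OR_PARTITION is treated as retriable for AddPartitionsToTxn, so the manager
+        // retries the request instead of failing the transaction; the retry below succeeds.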
+ + prepareAddPartitionsToTxnResponse(Errors.NONE, tp0, epoch, producerId); + prepareProduceResponse(Errors.NONE, producerId, epoch); + runUntil(() -> transactionManager.transactionContainsPartition(tp0)); + runUntil(responseFuture::isDone); + } + + @Test + public void testHandlingOfUnknownTopicPartitionErrorOnTxnOffsetCommit() { + testRetriableErrorInTxnOffsetCommit(Errors.UNKNOWN_TOPIC_OR_PARTITION); + } + + @Test + public void testHandlingOfCoordinatorLoadingErrorOnTxnOffsetCommit() { + testRetriableErrorInTxnOffsetCommit(Errors.COORDINATOR_LOAD_IN_PROGRESS); + } + + private void testRetriableErrorInTxnOffsetCommit(Errors error) { + doInitTransactions(); + + transactionManager.beginTransaction(); + + Map offsets = new HashMap<>(); + offsets.put(tp0, new OffsetAndMetadata(1)); + offsets.put(tp1, new OffsetAndMetadata(1)); + + TransactionalRequestResult addOffsetsResult = transactionManager.sendOffsetsToTransaction( + offsets, new ConsumerGroupMetadata(consumerGroupId)); + prepareAddOffsetsToTxnResponse(Errors.NONE, consumerGroupId, producerId, epoch); + runUntil(() -> !client.hasPendingResponses()); + assertFalse(addOffsetsResult.isCompleted()); // The request should complete only after the TxnOffsetCommit completes. + + Map txnOffsetCommitResponse = new HashMap<>(); + txnOffsetCommitResponse.put(tp0, Errors.NONE); + txnOffsetCommitResponse.put(tp1, error); + + prepareFindCoordinatorResponse(Errors.NONE, false, CoordinatorType.GROUP, consumerGroupId); + prepareTxnOffsetCommitResponse(consumerGroupId, producerId, epoch, txnOffsetCommitResponse); + + assertNull(transactionManager.coordinator(CoordinatorType.GROUP)); + runUntil(() -> transactionManager.coordinator(CoordinatorType.GROUP) != null); + assertTrue(transactionManager.hasPendingOffsetCommits()); + + runUntil(transactionManager::hasPendingOffsetCommits); // The TxnOffsetCommit failed. + assertFalse(addOffsetsResult.isCompleted()); // We should only be done after both RPCs complete successfully. + + txnOffsetCommitResponse.put(tp1, Errors.NONE); + prepareTxnOffsetCommitResponse(consumerGroupId, producerId, epoch, txnOffsetCommitResponse); + runUntil(addOffsetsResult::isCompleted); + assertTrue(addOffsetsResult.isSuccessful()); + } + + @Test + public void testHandlingOfProducerFencedErrorOnTxnOffsetCommit() { + testFatalErrorInTxnOffsetCommit(Errors.PRODUCER_FENCED); + } + + @Test + public void testHandlingOfTransactionalIdAuthorizationFailedErrorOnTxnOffsetCommit() { + testFatalErrorInTxnOffsetCommit(Errors.TRANSACTIONAL_ID_AUTHORIZATION_FAILED); + } + + @Test + public void testHandlingOfInvalidProducerEpochErrorOnTxnOffsetCommit() { + testFatalErrorInTxnOffsetCommit(Errors.INVALID_PRODUCER_EPOCH); + } + + @Test + public void testHandlingOfUnsupportedForMessageFormatErrorOnTxnOffsetCommit() { + testFatalErrorInTxnOffsetCommit(Errors.UNSUPPORTED_FOR_MESSAGE_FORMAT); + } + + private void testFatalErrorInTxnOffsetCommit(final Errors error) { + doInitTransactions(); + + transactionManager.beginTransaction(); + + Map offsets = new HashMap<>(); + offsets.put(tp0, new OffsetAndMetadata(1)); + offsets.put(tp1, new OffsetAndMetadata(1)); + + TransactionalRequestResult addOffsetsResult = transactionManager.sendOffsetsToTransaction( + offsets, new ConsumerGroupMetadata(consumerGroupId)); + prepareAddOffsetsToTxnResponse(Errors.NONE, consumerGroupId, producerId, epoch); + runUntil(() -> !client.hasPendingResponses()); + assertFalse(addOffsetsResult.isCompleted()); // The request should complete only after the TxnOffsetCommit completes. 
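+        // A fatal error on any partition in the TxnOffsetCommit response should fail the overall
+        // sendOffsetsToTransaction result with that same error, as asserted below.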
+ + Map txnOffsetCommitResponse = new HashMap<>(); + txnOffsetCommitResponse.put(tp0, Errors.NONE); + txnOffsetCommitResponse.put(tp1, error); + + prepareFindCoordinatorResponse(Errors.NONE, false, CoordinatorType.GROUP, consumerGroupId); + prepareTxnOffsetCommitResponse(consumerGroupId, producerId, epoch, txnOffsetCommitResponse); + + runUntil(addOffsetsResult::isCompleted); + assertFalse(addOffsetsResult.isSuccessful()); + assertEquals(error.exception().getClass(), addOffsetsResult.error().getClass()); + } + + @Test + public void shouldNotAddPartitionsToTransactionWhenTopicAuthorizationFailed() throws Exception { + doInitTransactions(); + + transactionManager.beginTransaction(); + transactionManager.failIfNotReadyForSend(); + transactionManager.maybeAddPartitionToTransaction(tp0); + + Future responseFuture = appendToAccumulator(tp0); + assertFalse(responseFuture.isDone()); + prepareAddPartitionsToTxn(tp0, Errors.TOPIC_AUTHORIZATION_FAILED); + runUntil(transactionManager::hasError); + assertFalse(transactionManager.transactionContainsPartition(tp0)); + } + + @Test + public void shouldNotSendAbortTxnRequestWhenOnlyAddPartitionsRequestFailed() { + doInitTransactions(); + + transactionManager.beginTransaction(); + transactionManager.failIfNotReadyForSend(); + transactionManager.maybeAddPartitionToTransaction(tp0); + + prepareAddPartitionsToTxnResponse(Errors.TOPIC_AUTHORIZATION_FAILED, tp0, epoch, producerId); + runUntil(() -> !client.hasPendingResponses()); + + TransactionalRequestResult abortResult = transactionManager.beginAbort(); + assertFalse(abortResult.isCompleted()); + + runUntil(abortResult::isCompleted); + assertTrue(abortResult.isSuccessful()); + } + + @Test + public void shouldNotSendAbortTxnRequestWhenOnlyAddOffsetsRequestFailed() { + doInitTransactions(); + + transactionManager.beginTransaction(); + Map offsets = new HashMap<>(); + offsets.put(tp1, new OffsetAndMetadata(1)); + + transactionManager.sendOffsetsToTransaction(offsets, new ConsumerGroupMetadata(consumerGroupId)); + + TransactionalRequestResult abortResult = transactionManager.beginAbort(); + + prepareAddOffsetsToTxnResponse(Errors.GROUP_AUTHORIZATION_FAILED, consumerGroupId, producerId, epoch); + runUntil(abortResult::isCompleted); + assertTrue(transactionManager.isReady()); + assertTrue(abortResult.isCompleted()); + assertTrue(abortResult.isSuccessful()); + } + + @Test + public void shouldFailAbortIfAddOffsetsFailsWithFatalError() { + doInitTransactions(); + + transactionManager.beginTransaction(); + Map offsets = new HashMap<>(); + offsets.put(tp1, new OffsetAndMetadata(1)); + + transactionManager.sendOffsetsToTransaction(offsets, new ConsumerGroupMetadata(consumerGroupId)); + + TransactionalRequestResult abortResult = transactionManager.beginAbort(); + + prepareAddOffsetsToTxnResponse(Errors.UNKNOWN_SERVER_ERROR, consumerGroupId, producerId, epoch); + + runUntil(abortResult::isCompleted); + assertFalse(abortResult.isSuccessful()); + assertTrue(transactionManager.hasFatalError()); + } + + @Test + public void testSendOffsetsWithGroupMetadata() { + Map txnOffsetCommitResponse = new HashMap<>(); + txnOffsetCommitResponse.put(tp0, Errors.NONE); + txnOffsetCommitResponse.put(tp1, Errors.COORDINATOR_LOAD_IN_PROGRESS); + + TransactionalRequestResult addOffsetsResult = prepareGroupMetadataCommit( + () -> prepareTxnOffsetCommitResponse(consumerGroupId, producerId, + epoch, groupInstanceId, memberId, generationId, txnOffsetCommitResponse)); + + sender.runOnce(); // Send TxnOffsetCommitRequest request. 
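+        // tp1 was given COORDINATOR_LOAD_IN_PROGRESS, a retriable error, so its offset commit stays
+        // pending and the TxnOffsetCommit is retried below with the same group metadata.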
+ + assertTrue(transactionManager.hasPendingOffsetCommits()); // The TxnOffsetCommit failed. + assertFalse(addOffsetsResult.isCompleted()); // We should only be done after both RPCs complete successfully. + + txnOffsetCommitResponse.put(tp1, Errors.NONE); + prepareTxnOffsetCommitResponse(consumerGroupId, producerId, epoch, groupInstanceId, memberId, generationId, txnOffsetCommitResponse); + sender.runOnce(); // Send TxnOffsetCommitRequest again. + + assertTrue(addOffsetsResult.isCompleted()); + assertTrue(addOffsetsResult.isSuccessful()); + } + + @Test + public void testSendOffsetWithGroupMetadataFailAsAutoDowngradeTxnCommitNotEnabled() { + client.setNodeApiVersions(NodeApiVersions.create(ApiKeys.TXN_OFFSET_COMMIT.id, (short) 0, (short) 2)); + + Map txnOffsetCommitResponse = new HashMap<>(); + txnOffsetCommitResponse.put(tp0, Errors.NONE); + txnOffsetCommitResponse.put(tp1, Errors.COORDINATOR_LOAD_IN_PROGRESS); + + TransactionalRequestResult addOffsetsResult = prepareGroupMetadataCommit( + () -> prepareTxnOffsetCommitResponse(consumerGroupId, producerId, epoch, txnOffsetCommitResponse)); + + sender.runOnce(); + + assertTrue(addOffsetsResult.isCompleted()); + assertFalse(addOffsetsResult.isSuccessful()); + assertTrue(addOffsetsResult.error() instanceof UnsupportedVersionException); + assertFatalError(UnsupportedVersionException.class); + } + + private TransactionalRequestResult prepareGroupMetadataCommit(Runnable prepareTxnCommitResponse) { + doInitTransactions(); + + transactionManager.beginTransaction(); + Map offsets = new HashMap<>(); + offsets.put(tp0, new OffsetAndMetadata(1)); + offsets.put(tp1, new OffsetAndMetadata(1)); + + TransactionalRequestResult addOffsetsResult = transactionManager.sendOffsetsToTransaction( + offsets, new ConsumerGroupMetadata(consumerGroupId, generationId, memberId, Optional.of(groupInstanceId))); + prepareAddOffsetsToTxnResponse(Errors.NONE, consumerGroupId, producerId, epoch); + + sender.runOnce(); // send AddOffsetsToTxnResult + + assertFalse(addOffsetsResult.isCompleted()); // The request should complete only after the TxnOffsetCommit completes + + prepareFindCoordinatorResponse(Errors.NONE, false, CoordinatorType.GROUP, consumerGroupId); +// prepareTxnOffsetCommitResponse(consumerGroupId, producerId, epoch, groupInstanceId, memberId, generationId, txnOffsetCommitResponse); + prepareTxnCommitResponse.run(); + + assertNull(transactionManager.coordinator(CoordinatorType.GROUP)); + sender.runOnce(); // try to send TxnOffsetCommitRequest, but find we don't have a group coordinator + sender.runOnce(); // send find coordinator for group request + assertNotNull(transactionManager.coordinator(CoordinatorType.GROUP)); + assertTrue(transactionManager.hasPendingOffsetCommits()); + return addOffsetsResult; + } + + @Test + public void testNoDrainWhenPartitionsPending() throws InterruptedException { + doInitTransactions(); + transactionManager.beginTransaction(); + transactionManager.failIfNotReadyForSend(); + transactionManager.maybeAddPartitionToTransaction(tp0); + appendToAccumulator(tp0); + transactionManager.failIfNotReadyForSend(); + transactionManager.maybeAddPartitionToTransaction(tp1); + appendToAccumulator(tp1); + + assertFalse(transactionManager.isSendToPartitionAllowed(tp0)); + assertFalse(transactionManager.isSendToPartitionAllowed(tp1)); + + Node node1 = new Node(0, "localhost", 1111); + Node node2 = new Node(1, "localhost", 1112); + PartitionInfo part1 = new PartitionInfo(topic, 0, node1, null, null); + PartitionInfo part2 = new PartitionInfo(topic, 1, 
node2, null, null);
+
+        Cluster cluster = new Cluster(null, Arrays.asList(node1, node2), Arrays.asList(part1, part2),
+                Collections.emptySet(), Collections.emptySet());
+        Set<Node> nodes = new HashSet<>();
+        nodes.add(node1);
+        nodes.add(node2);
+        Map<Integer, List<ProducerBatch>> drainedBatches = accumulator.drain(cluster, nodes, Integer.MAX_VALUE,
+                time.milliseconds());
+
+        // We shouldn't drain batches which haven't been added to the transaction yet.
+        assertTrue(drainedBatches.containsKey(node1.id()));
+        assertTrue(drainedBatches.get(node1.id()).isEmpty());
+        assertTrue(drainedBatches.containsKey(node2.id()));
+        assertTrue(drainedBatches.get(node2.id()).isEmpty());
+        assertFalse(transactionManager.hasError());
+    }
+
+    @Test
+    public void testAllowDrainInAbortableErrorState() throws InterruptedException {
+        doInitTransactions();
+        transactionManager.beginTransaction();
+        transactionManager.failIfNotReadyForSend();
+        transactionManager.maybeAddPartitionToTransaction(tp1);
+        prepareAddPartitionsToTxn(tp1, Errors.NONE);
+        runUntil(() -> transactionManager.transactionContainsPartition(tp1));
+
+        transactionManager.failIfNotReadyForSend();
+        transactionManager.maybeAddPartitionToTransaction(tp0);
+        prepareAddPartitionsToTxn(tp0, Errors.TOPIC_AUTHORIZATION_FAILED);
+        runUntil(transactionManager::hasAbortableError);
+        assertTrue(transactionManager.isSendToPartitionAllowed(tp1));
+
+        // Try to drain a message destined for tp1, it should get drained.
+        Node node1 = new Node(1, "localhost", 1112);
+        PartitionInfo part1 = new PartitionInfo(topic, 1, node1, null, null);
+        Cluster cluster = new Cluster(null, Collections.singletonList(node1), Collections.singletonList(part1),
+                Collections.emptySet(), Collections.emptySet());
+        appendToAccumulator(tp1);
+        Map<Integer, List<ProducerBatch>> drainedBatches = accumulator.drain(cluster, Collections.singleton(node1),
+                Integer.MAX_VALUE,
+                time.milliseconds());
+
+        // We should drain the appended record since we are in abortable state and the partition has already been
+        // added to the transaction.
+        assertTrue(drainedBatches.containsKey(node1.id()));
+        assertEquals(1, drainedBatches.get(node1.id()).size());
+        assertTrue(transactionManager.hasAbortableError());
+    }
+
+    @Test
+    public void testRaiseErrorWhenNoPartitionsPendingOnDrain() throws InterruptedException {
+        doInitTransactions();
+        transactionManager.beginTransaction();
+        // Don't execute transactionManager.maybeAddPartitionToTransaction(tp0). This should result in an error on drain.
+        appendToAccumulator(tp0);
+        Node node1 = new Node(0, "localhost", 1111);
+        PartitionInfo part1 = new PartitionInfo(topic, 0, node1, null, null);
+
+        Cluster cluster = new Cluster(null, Collections.singletonList(node1), Collections.singletonList(part1),
+                Collections.emptySet(), Collections.emptySet());
+        Set<Node> nodes = new HashSet<>();
+        nodes.add(node1);
+        Map<Integer, List<ProducerBatch>> drainedBatches = accumulator.drain(cluster, nodes, Integer.MAX_VALUE,
+                time.milliseconds());
+
+        // We shouldn't drain batches which haven't been added to the transaction yet.
+ assertTrue(drainedBatches.containsKey(node1.id())); + assertTrue(drainedBatches.get(node1.id()).isEmpty()); + } + + @Test + public void resendFailedProduceRequestAfterAbortableError() throws Exception { + doInitTransactions(); + transactionManager.beginTransaction(); + + transactionManager.failIfNotReadyForSend(); + transactionManager.maybeAddPartitionToTransaction(tp0); + + Future responseFuture = appendToAccumulator(tp0); + + prepareAddPartitionsToTxnResponse(Errors.NONE, tp0, epoch, producerId); + prepareProduceResponse(Errors.NOT_LEADER_OR_FOLLOWER, producerId, epoch); + runUntil(() -> !client.hasPendingResponses()); + + assertFalse(responseFuture.isDone()); + + transactionManager.transitionToAbortableError(new KafkaException()); + prepareProduceResponse(Errors.NONE, producerId, epoch); + runUntil(responseFuture::isDone); + assertNotNull(responseFuture.get()); // should throw the exception which caused the transaction to be aborted. + } + + @Test + public void testTransitionToAbortableErrorOnBatchExpiry() throws InterruptedException { + doInitTransactions(); + + transactionManager.beginTransaction(); + transactionManager.failIfNotReadyForSend(); + transactionManager.maybeAddPartitionToTransaction(tp0); + + Future responseFuture = appendToAccumulator(tp0); + + assertFalse(responseFuture.isDone()); + + prepareAddPartitionsToTxnResponse(Errors.NONE, tp0, epoch, producerId); + + assertFalse(transactionManager.transactionContainsPartition(tp0)); + assertFalse(transactionManager.isSendToPartitionAllowed(tp0)); + // Check that only addPartitions was sent. + runUntil(() -> transactionManager.transactionContainsPartition(tp0)); + assertTrue(transactionManager.isSendToPartitionAllowed(tp0)); + assertFalse(responseFuture.isDone()); + + // Sleep 10 seconds to make sure that the batches in the queue would be expired if they can't be drained. + time.sleep(10000); + // Disconnect the target node for the pending produce request. This will ensure that sender will try to + // expire the batch. + Node clusterNode = metadata.fetch().nodes().get(0); + client.disconnect(clusterNode.idString()); + client.backoff(clusterNode, 100); + + runUntil(responseFuture::isDone); + + try { + // make sure the produce was expired. + responseFuture.get(); + fail("Expected to get a TimeoutException since the queued ProducerBatch should have been expired"); + } catch (ExecutionException e) { + assertTrue(e.getCause() instanceof TimeoutException); + } + assertTrue(transactionManager.hasAbortableError()); + } + + @Test + public void testTransitionToAbortableErrorOnMultipleBatchExpiry() throws InterruptedException { + doInitTransactions(); + + transactionManager.beginTransaction(); + transactionManager.failIfNotReadyForSend(); + transactionManager.maybeAddPartitionToTransaction(tp0); + transactionManager.failIfNotReadyForSend(); + transactionManager.maybeAddPartitionToTransaction(tp1); + + Future firstBatchResponse = appendToAccumulator(tp0); + Future secondBatchResponse = appendToAccumulator(tp1); + + assertFalse(firstBatchResponse.isDone()); + assertFalse(secondBatchResponse.isDone()); + + Map partitionErrors = new HashMap<>(); + partitionErrors.put(tp0, Errors.NONE); + partitionErrors.put(tp1, Errors.NONE); + prepareAddPartitionsToTxn(partitionErrors); + + assertFalse(transactionManager.transactionContainsPartition(tp0)); + assertFalse(transactionManager.isSendToPartitionAllowed(tp0)); + // Check that only addPartitions was sent. 
+ runUntil(() -> transactionManager.transactionContainsPartition(tp0)); + assertTrue(transactionManager.transactionContainsPartition(tp1)); + assertTrue(transactionManager.isSendToPartitionAllowed(tp1)); + assertTrue(transactionManager.isSendToPartitionAllowed(tp1)); + assertFalse(firstBatchResponse.isDone()); + assertFalse(secondBatchResponse.isDone()); + + // Sleep 10 seconds to make sure that the batches in the queue would be expired if they can't be drained. + time.sleep(10000); + // Disconnect the target node for the pending produce request. This will ensure that sender will try to + // expire the batch. + Node clusterNode = metadata.fetch().nodes().get(0); + client.disconnect(clusterNode.idString()); + client.backoff(clusterNode, 100); + + runUntil(firstBatchResponse::isDone); + runUntil(secondBatchResponse::isDone); + + try { + // make sure the produce was expired. + firstBatchResponse.get(); + fail("Expected to get a TimeoutException since the queued ProducerBatch should have been expired"); + } catch (ExecutionException e) { + assertTrue(e.getCause() instanceof TimeoutException); + } + + try { + // make sure the produce was expired. + secondBatchResponse.get(); + fail("Expected to get a TimeoutException since the queued ProducerBatch should have been expired"); + } catch (ExecutionException e) { + assertTrue(e.getCause() instanceof TimeoutException); + } + assertTrue(transactionManager.hasAbortableError()); + } + + @Test + public void testDropCommitOnBatchExpiry() throws InterruptedException { + doInitTransactions(); + + transactionManager.beginTransaction(); + transactionManager.failIfNotReadyForSend(); + transactionManager.maybeAddPartitionToTransaction(tp0); + + Future responseFuture = appendToAccumulator(tp0); + + assertFalse(responseFuture.isDone()); + + prepareAddPartitionsToTxnResponse(Errors.NONE, tp0, epoch, producerId); + + assertFalse(transactionManager.transactionContainsPartition(tp0)); + assertFalse(transactionManager.isSendToPartitionAllowed(tp0)); + // Check that only addPartitions was sent. + runUntil(() -> transactionManager.transactionContainsPartition(tp0)); + assertTrue(transactionManager.isSendToPartitionAllowed(tp0)); + assertFalse(responseFuture.isDone()); + + TransactionalRequestResult commitResult = transactionManager.beginCommit(); + + // Sleep 10 seconds to make sure that the batches in the queue would be expired if they can't be drained. + time.sleep(10000); + // Disconnect the target node for the pending produce request. This will ensure that sender will try to + // expire the batch. + Node clusterNode = metadata.fetch().nodes().get(0); + client.disconnect(clusterNode.idString()); + + runUntil(responseFuture::isDone); // We should try to flush the produce, but expire it instead without sending anything. + + try { + // make sure the produce was expired. + responseFuture.get(); + fail("Expected to get a TimeoutException since the queued ProducerBatch should have been expired"); + } catch (ExecutionException e) { + assertTrue(e.getCause() instanceof TimeoutException); + } + runUntil(commitResult::isCompleted); // the commit shouldn't be completed without being sent since the produce request failed. + assertFalse(commitResult.isSuccessful()); // the commit shouldn't succeed since the produce request failed. 
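+        // The expired batch leaves the manager in ABORTABLE_ERROR while the transaction is still open,
+        // so the only way forward is an abort, which also bumps the producer epoch below.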
+ + assertTrue(transactionManager.hasAbortableError()); + assertTrue(transactionManager.hasOngoingTransaction()); + assertFalse(transactionManager.isCompleting()); + assertTrue(transactionManager.transactionContainsPartition(tp0)); + + TransactionalRequestResult abortResult = transactionManager.beginAbort(); + + prepareEndTxnResponse(Errors.NONE, TransactionResult.ABORT, producerId, epoch); + prepareInitPidResponse(Errors.NONE, false, producerId, (short) (epoch + 1)); + runUntil(abortResult::isCompleted); + assertTrue(abortResult.isSuccessful()); + assertFalse(transactionManager.hasOngoingTransaction()); + assertFalse(transactionManager.transactionContainsPartition(tp0)); + } + + @Test + public void testTransitionToFatalErrorWhenRetriedBatchIsExpired() throws InterruptedException { + apiVersions.update("0", new NodeApiVersions(Arrays.asList( + new ApiVersion() + .setApiKey(ApiKeys.INIT_PRODUCER_ID.id) + .setMinVersion((short) 0) + .setMaxVersion((short) 1), + new ApiVersion() + .setApiKey(ApiKeys.PRODUCE.id) + .setMinVersion((short) 0) + .setMaxVersion((short) 7)))); + + doInitTransactions(); + + transactionManager.beginTransaction(); + transactionManager.failIfNotReadyForSend(); + transactionManager.maybeAddPartitionToTransaction(tp0); + + Future responseFuture = appendToAccumulator(tp0); + + assertFalse(responseFuture.isDone()); + + prepareAddPartitionsToTxnResponse(Errors.NONE, tp0, epoch, producerId); + + assertFalse(transactionManager.transactionContainsPartition(tp0)); + assertFalse(transactionManager.isSendToPartitionAllowed(tp0)); + // Check that only addPartitions was sent. + runUntil(() -> transactionManager.transactionContainsPartition(tp0)); + assertTrue(transactionManager.isSendToPartitionAllowed(tp0)); + + prepareProduceResponse(Errors.NOT_LEADER_OR_FOLLOWER, producerId, epoch); + runUntil(() -> !client.hasPendingResponses()); + assertFalse(responseFuture.isDone()); + + TransactionalRequestResult commitResult = transactionManager.beginCommit(); + + // Sleep 10 seconds to make sure that the batches in the queue would be expired if they can't be drained. + time.sleep(10000); + // Disconnect the target node for the pending produce request. This will ensure that sender will try to + // expire the batch. + Node clusterNode = metadata.fetch().nodes().get(0); + client.disconnect(clusterNode.idString()); + client.backoff(clusterNode, 100); + + runUntil(responseFuture::isDone); // We should try to flush the produce, but expire it instead without sending anything. + + try { + // make sure the produce was expired. + responseFuture.get(); + fail("Expected to get a TimeoutException since the queued ProducerBatch should have been expired"); + } catch (ExecutionException e) { + assertTrue(e.getCause() instanceof TimeoutException); + } + runUntil(commitResult::isCompleted); + assertFalse(commitResult.isSuccessful()); // the commit should have been dropped. 
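+        // Because the expired batch had already been sent and was being retried, its fate on the broker
+        // is unknown; with epoch bumping unavailable in this test, the manager treats this as fatal
+        // rather than abortable, as asserted below.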
+ + assertTrue(transactionManager.hasFatalError()); + assertFalse(transactionManager.hasOngoingTransaction()); + } + + @Test + public void testBumpEpochAfterTimeoutWithoutPendingInflightRequests() { + initializeTransactionManager(Optional.empty()); + long producerId = 15L; + short epoch = 5; + ProducerIdAndEpoch producerIdAndEpoch = new ProducerIdAndEpoch(producerId, epoch); + initializeIdempotentProducerId(producerId, epoch); + + // Nothing to resolve, so no reset is needed + transactionManager.bumpIdempotentEpochAndResetIdIfNeeded(); + assertEquals(producerIdAndEpoch, transactionManager.producerIdAndEpoch()); + + TopicPartition tp0 = new TopicPartition("foo", 0); + assertEquals(Integer.valueOf(0), transactionManager.sequenceNumber(tp0)); + + ProducerBatch b1 = writeIdempotentBatchWithValue(transactionManager, tp0, "1"); + assertEquals(Integer.valueOf(1), transactionManager.sequenceNumber(tp0)); + transactionManager.handleCompletedBatch(b1, new ProduceResponse.PartitionResponse( + Errors.NONE, 500L, time.milliseconds(), 0L)); + assertEquals(OptionalInt.of(0), transactionManager.lastAckedSequence(tp0)); + + // Marking sequence numbers unresolved without inflight requests is basically a no-op. + transactionManager.markSequenceUnresolved(b1); + transactionManager.maybeResolveSequences(); + assertEquals(producerIdAndEpoch, transactionManager.producerIdAndEpoch()); + assertFalse(transactionManager.hasUnresolvedSequences()); + + // We have a new batch which fails with a timeout + ProducerBatch b2 = writeIdempotentBatchWithValue(transactionManager, tp0, "2"); + assertEquals(Integer.valueOf(2), transactionManager.sequenceNumber(tp0)); + transactionManager.markSequenceUnresolved(b2); + transactionManager.handleFailedBatch(b2, new TimeoutException(), false); + assertTrue(transactionManager.hasUnresolvedSequences()); + + // We only had one inflight batch, so we should be able to clear the unresolved status + // and bump the epoch + transactionManager.maybeResolveSequences(); + assertFalse(transactionManager.hasUnresolvedSequences()); + + // Run sender loop to trigger epoch bump + runUntil(() -> transactionManager.producerIdAndEpoch().epoch == 6); + } + + @Test + public void testNoProducerIdResetAfterLastInFlightBatchSucceeds() { + initializeTransactionManager(Optional.empty()); + long producerId = 15L; + short epoch = 5; + ProducerIdAndEpoch producerIdAndEpoch = new ProducerIdAndEpoch(producerId, epoch); + initializeIdempotentProducerId(producerId, epoch); + + TopicPartition tp0 = new TopicPartition("foo", 0); + ProducerBatch b1 = writeIdempotentBatchWithValue(transactionManager, tp0, "1"); + ProducerBatch b2 = writeIdempotentBatchWithValue(transactionManager, tp0, "2"); + ProducerBatch b3 = writeIdempotentBatchWithValue(transactionManager, tp0, "3"); + assertEquals(3, transactionManager.sequenceNumber(tp0).intValue()); + + // The first batch fails with a timeout + transactionManager.markSequenceUnresolved(b1); + transactionManager.handleFailedBatch(b1, new TimeoutException(), false); + assertTrue(transactionManager.hasUnresolvedSequences()); + + // The reset should not occur until sequence numbers have been resolved + transactionManager.bumpIdempotentEpochAndResetIdIfNeeded(); + assertEquals(producerIdAndEpoch, transactionManager.producerIdAndEpoch()); + assertTrue(transactionManager.hasUnresolvedSequences()); + + // The second batch fails as well with a timeout + transactionManager.handleFailedBatch(b2, new TimeoutException(), false); + transactionManager.bumpIdempotentEpochAndResetIdIfNeeded(); + 
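+        // b3 is still in flight, so the sequences remain unresolved and neither the producer id nor
+        // the epoch should change yet.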
assertEquals(producerIdAndEpoch, transactionManager.producerIdAndEpoch()); + assertTrue(transactionManager.hasUnresolvedSequences()); + + // The third batch succeeds, which should resolve the sequence number without + // requiring a producerId reset. + transactionManager.handleCompletedBatch(b3, new ProduceResponse.PartitionResponse( + Errors.NONE, 500L, time.milliseconds(), 0L)); + transactionManager.maybeResolveSequences(); + assertEquals(producerIdAndEpoch, transactionManager.producerIdAndEpoch()); + assertFalse(transactionManager.hasUnresolvedSequences()); + assertEquals(3, transactionManager.sequenceNumber(tp0).intValue()); + } + + @Test + public void testEpochBumpAfterLastInflightBatchFails() { + initializeTransactionManager(Optional.empty()); + ProducerIdAndEpoch producerIdAndEpoch = new ProducerIdAndEpoch(producerId, epoch); + initializeIdempotentProducerId(producerId, epoch); + + TopicPartition tp0 = new TopicPartition("foo", 0); + ProducerBatch b1 = writeIdempotentBatchWithValue(transactionManager, tp0, "1"); + ProducerBatch b2 = writeIdempotentBatchWithValue(transactionManager, tp0, "2"); + ProducerBatch b3 = writeIdempotentBatchWithValue(transactionManager, tp0, "3"); + assertEquals(Integer.valueOf(3), transactionManager.sequenceNumber(tp0)); + + // The first batch fails with a timeout + transactionManager.markSequenceUnresolved(b1); + transactionManager.handleFailedBatch(b1, new TimeoutException(), false); + assertTrue(transactionManager.hasUnresolvedSequences()); + + // The second batch succeeds, but sequence numbers are still not resolved + transactionManager.handleCompletedBatch(b2, new ProduceResponse.PartitionResponse( + Errors.NONE, 500L, time.milliseconds(), 0L)); + transactionManager.bumpIdempotentEpochAndResetIdIfNeeded(); + assertEquals(producerIdAndEpoch, transactionManager.producerIdAndEpoch()); + assertTrue(transactionManager.hasUnresolvedSequences()); + + // When the last inflight batch fails, we have to bump the epoch + transactionManager.handleFailedBatch(b3, new TimeoutException(), false); + + // Run sender loop to trigger epoch bump + runUntil(() -> transactionManager.producerIdAndEpoch().epoch == 2); + assertFalse(transactionManager.hasUnresolvedSequences()); + assertEquals(0, transactionManager.sequenceNumber(tp0).intValue()); + } + + @Test + public void testNoFailedBatchHandlingWhenTxnManagerIsInFatalError() { + initializeTransactionManager(Optional.empty()); + long producerId = 15L; + short epoch = 5; + initializeIdempotentProducerId(producerId, epoch); + + TopicPartition tp0 = new TopicPartition("foo", 0); + ProducerBatch b1 = writeIdempotentBatchWithValue(transactionManager, tp0, "1"); + // Handling b1 should bump the epoch after OutOfOrderSequenceException + transactionManager.handleFailedBatch(b1, new OutOfOrderSequenceException("out of sequence"), false); + transactionManager.bumpIdempotentEpochAndResetIdIfNeeded(); + ProducerIdAndEpoch idAndEpochAfterFirstBatch = new ProducerIdAndEpoch(producerId, (short) (epoch + 1)); + assertEquals(idAndEpochAfterFirstBatch, transactionManager.producerIdAndEpoch()); + + transactionManager.transitionToFatalError(new KafkaException()); + + // The second batch should not bump the epoch as txn manager is already in fatal error state + ProducerBatch b2 = writeIdempotentBatchWithValue(transactionManager, tp0, "2"); + transactionManager.handleFailedBatch(b2, new TimeoutException(), true); + transactionManager.bumpIdempotentEpochAndResetIdIfNeeded(); + assertEquals(idAndEpochAfterFirstBatch, 
transactionManager.producerIdAndEpoch()); + } + + @Test + public void testAbortTransactionAndReuseSequenceNumberOnError() throws InterruptedException { + apiVersions.update("0", new NodeApiVersions(Arrays.asList( + new ApiVersion() + .setApiKey(ApiKeys.INIT_PRODUCER_ID.id) + .setMinVersion((short) 0) + .setMaxVersion((short) 1), + new ApiVersion() + .setApiKey(ApiKeys.PRODUCE.id) + .setMinVersion((short) 0) + .setMaxVersion((short) 7)))); + + doInitTransactions(); + + transactionManager.beginTransaction(); + transactionManager.failIfNotReadyForSend(); + transactionManager.maybeAddPartitionToTransaction(tp0); + + Future responseFuture0 = appendToAccumulator(tp0); + prepareAddPartitionsToTxnResponse(Errors.NONE, tp0, epoch, producerId); + prepareProduceResponse(Errors.NONE, producerId, epoch); + runUntil(() -> transactionManager.isPartitionAdded(tp0)); // Send AddPartitionsRequest + runUntil(responseFuture0::isDone); + + Future responseFuture1 = appendToAccumulator(tp0); + prepareProduceResponse(Errors.NONE, producerId, epoch); + runUntil(responseFuture1::isDone); + + Future responseFuture2 = appendToAccumulator(tp0); + prepareProduceResponse(Errors.TOPIC_AUTHORIZATION_FAILED, producerId, epoch); + runUntil(responseFuture2::isDone); // Receive abortable error + + assertTrue(transactionManager.hasAbortableError()); + + TransactionalRequestResult abortResult = transactionManager.beginAbort(); + prepareEndTxnResponse(Errors.NONE, TransactionResult.ABORT, producerId, epoch); + runUntil(abortResult::isCompleted); + assertTrue(abortResult.isSuccessful()); + assertTrue(transactionManager.isReady()); // make sure we are ready for a transaction now. + + transactionManager.beginTransaction(); + transactionManager.failIfNotReadyForSend(); + transactionManager.maybeAddPartitionToTransaction(tp0); + + prepareAddPartitionsToTxnResponse(Errors.NONE, tp0, epoch, producerId); + runUntil(() -> transactionManager.isPartitionAdded(tp0)); // Send AddPartitionsRequest + + assertEquals(2, transactionManager.sequenceNumber(tp0).intValue()); + } + + @Test + public void testAbortTransactionAndResetSequenceNumberOnUnknownProducerId() throws InterruptedException { + // Set the InitProducerId version such that bumping the epoch number is not supported. 
This will test the case + // where the sequence number is reset on an UnknownProducerId error, allowing subsequent transactions to + // append to the log successfully + apiVersions.update("0", new NodeApiVersions(Arrays.asList( + new ApiVersion() + .setApiKey(ApiKeys.INIT_PRODUCER_ID.id) + .setMinVersion((short) 0) + .setMaxVersion((short) 1), + new ApiVersion() + .setApiKey(ApiKeys.PRODUCE.id) + .setMinVersion((short) 0) + .setMaxVersion((short) 7)))); + + doInitTransactions(); + + transactionManager.beginTransaction(); + transactionManager.failIfNotReadyForSend(); + + transactionManager.maybeAddPartitionToTransaction(tp1); + Future successPartitionResponseFuture = appendToAccumulator(tp1); + prepareAddPartitionsToTxnResponse(Errors.NONE, tp1, epoch, producerId); + prepareProduceResponse(Errors.NONE, producerId, epoch, tp1); + runUntil(successPartitionResponseFuture::isDone); + assertTrue(transactionManager.isPartitionAdded(tp1)); + + transactionManager.maybeAddPartitionToTransaction(tp0); + Future responseFuture0 = appendToAccumulator(tp0); + prepareAddPartitionsToTxnResponse(Errors.NONE, tp0, epoch, producerId); + prepareProduceResponse(Errors.NONE, producerId, epoch); + runUntil(responseFuture0::isDone); + assertTrue(transactionManager.isPartitionAdded(tp0)); + + Future responseFuture1 = appendToAccumulator(tp0); + prepareProduceResponse(Errors.NONE, producerId, epoch); + runUntil(responseFuture1::isDone); + + Future responseFuture2 = appendToAccumulator(tp0); + client.prepareResponse(produceRequestMatcher(producerId, epoch, tp0), + produceResponse(tp0, 0, Errors.UNKNOWN_PRODUCER_ID, 0, 0)); + runUntil(responseFuture2::isDone); + + assertTrue(transactionManager.hasAbortableError()); + + TransactionalRequestResult abortResult = transactionManager.beginAbort(); + prepareEndTxnResponse(Errors.NONE, TransactionResult.ABORT, producerId, epoch); + runUntil(abortResult::isCompleted); + assertTrue(abortResult.isSuccessful()); + assertTrue(transactionManager.isReady()); // make sure we are ready for a transaction now. 
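+        // With epoch bumping unavailable (InitProducerId capped at v1 above), the UNKNOWN_PRODUCER_ID error
+        // is handled by resetting only the affected partition's sequence: tp0 restarts at 0 below while
+        // tp1 keeps its previous progress.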
+ + transactionManager.beginTransaction(); + transactionManager.failIfNotReadyForSend(); + transactionManager.maybeAddPartitionToTransaction(tp0); + + prepareAddPartitionsToTxnResponse(Errors.NONE, tp0, epoch, producerId); + runUntil(() -> transactionManager.isPartitionAdded(tp0)); + + assertEquals(0, transactionManager.sequenceNumber(tp0).intValue()); + assertEquals(1, transactionManager.sequenceNumber(tp1).intValue()); + } + + @Test + public void testBumpTransactionalEpochOnAbortableError() throws InterruptedException { + final short initialEpoch = 1; + final short bumpedEpoch = initialEpoch + 1; + + doInitTransactions(producerId, initialEpoch); + + transactionManager.beginTransaction(); + transactionManager.failIfNotReadyForSend(); + transactionManager.maybeAddPartitionToTransaction(tp0); + + prepareAddPartitionsToTxnResponse(Errors.NONE, tp0, initialEpoch, producerId); + runUntil(() -> transactionManager.isPartitionAdded(tp0)); + + Future responseFuture0 = appendToAccumulator(tp0); + prepareProduceResponse(Errors.NONE, producerId, initialEpoch); + runUntil(responseFuture0::isDone); + + Future responseFuture1 = appendToAccumulator(tp0); + prepareProduceResponse(Errors.NONE, producerId, initialEpoch); + runUntil(responseFuture1::isDone); + + Future responseFuture2 = appendToAccumulator(tp0); + prepareProduceResponse(Errors.TOPIC_AUTHORIZATION_FAILED, producerId, initialEpoch); + runUntil(responseFuture2::isDone); + + assertTrue(transactionManager.hasAbortableError()); + TransactionalRequestResult abortResult = transactionManager.beginAbort(); + + prepareEndTxnResponse(Errors.NONE, TransactionResult.ABORT, producerId, initialEpoch); + prepareInitPidResponse(Errors.NONE, false, producerId, bumpedEpoch); + runUntil(() -> transactionManager.producerIdAndEpoch().epoch == bumpedEpoch); + + assertTrue(abortResult.isCompleted()); + assertTrue(abortResult.isSuccessful()); + assertTrue(transactionManager.isReady()); // make sure we are ready for a transaction now. 
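+        // After the epoch bump, the next transaction adds tp0 with the new epoch and its sequence
+        // numbering starts over from zero.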
+ + transactionManager.beginTransaction(); + transactionManager.failIfNotReadyForSend(); + transactionManager.maybeAddPartitionToTransaction(tp0); + + prepareAddPartitionsToTxnResponse(Errors.NONE, tp0, bumpedEpoch, producerId); + runUntil(() -> transactionManager.isPartitionAdded(tp0)); + + assertEquals(0, transactionManager.sequenceNumber(tp0).intValue()); + } + + @Test + public void testBumpTransactionalEpochOnUnknownProducerIdError() throws InterruptedException { + final short initialEpoch = 1; + final short bumpedEpoch = 2; + + doInitTransactions(producerId, initialEpoch); + + transactionManager.beginTransaction(); + transactionManager.failIfNotReadyForSend(); + transactionManager.maybeAddPartitionToTransaction(tp0); + + prepareAddPartitionsToTxnResponse(Errors.NONE, tp0, initialEpoch, producerId); + runUntil(() -> transactionManager.isPartitionAdded(tp0)); + + Future responseFuture0 = appendToAccumulator(tp0); + prepareProduceResponse(Errors.NONE, producerId, initialEpoch); + runUntil(responseFuture0::isDone); + + Future responseFuture1 = appendToAccumulator(tp0); + prepareProduceResponse(Errors.NONE, producerId, initialEpoch); + runUntil(responseFuture1::isDone); + + Future responseFuture2 = appendToAccumulator(tp0); + client.prepareResponse(produceRequestMatcher(producerId, initialEpoch, tp0), + produceResponse(tp0, 0, Errors.UNKNOWN_PRODUCER_ID, 0, 0)); + runUntil(responseFuture2::isDone); + + assertTrue(transactionManager.hasAbortableError()); + TransactionalRequestResult abortResult = transactionManager.beginAbort(); + + prepareEndTxnResponse(Errors.NONE, TransactionResult.ABORT, producerId, initialEpoch); + prepareInitPidResponse(Errors.NONE, false, producerId, bumpedEpoch); + runUntil(() -> transactionManager.producerIdAndEpoch().epoch == bumpedEpoch); + + assertTrue(abortResult.isCompleted()); + assertTrue(abortResult.isSuccessful()); + assertTrue(transactionManager.isReady()); // make sure we are ready for a transaction now. + + transactionManager.beginTransaction(); + transactionManager.failIfNotReadyForSend(); + transactionManager.maybeAddPartitionToTransaction(tp0); + + prepareAddPartitionsToTxnResponse(Errors.NONE, tp0, bumpedEpoch, producerId); + runUntil(() -> transactionManager.isPartitionAdded(tp0)); + + assertEquals(0, transactionManager.sequenceNumber(tp0).intValue()); + } + + @Test + public void testBumpTransactionalEpochOnTimeout() throws InterruptedException { + final short initialEpoch = 1; + final short bumpedEpoch = 2; + + doInitTransactions(producerId, initialEpoch); + + transactionManager.beginTransaction(); + transactionManager.failIfNotReadyForSend(); + transactionManager.maybeAddPartitionToTransaction(tp0); + + prepareAddPartitionsToTxnResponse(Errors.NONE, tp0, initialEpoch, producerId); + runUntil(() -> transactionManager.isPartitionAdded(tp0)); + + Future responseFuture0 = appendToAccumulator(tp0); + prepareProduceResponse(Errors.NONE, producerId, initialEpoch); + runUntil(responseFuture0::isDone); + + Future responseFuture1 = appendToAccumulator(tp0); + prepareProduceResponse(Errors.NONE, producerId, initialEpoch); + runUntil(responseFuture1::isDone); + + Future responseFuture2 = appendToAccumulator(tp0); + runUntil(client::hasInFlightRequests); // Send Produce Request + + // Sleep 10 seconds to make sure that the batches in the queue would be expired if they can't be drained. + time.sleep(10000); + // Disconnect the target node for the pending produce request. This will ensure that sender will try to + // expire the batch. 
+ Node clusterNode = metadata.fetch().nodes().get(0); + client.disconnect(clusterNode.idString()); + client.backoff(clusterNode, 100); + + runUntil(responseFuture2::isDone); // We should try to flush the produce, but expire it instead without sending anything. + + assertTrue(transactionManager.hasAbortableError()); + TransactionalRequestResult abortResult = transactionManager.beginAbort(); + + sender.runOnce(); // handle the abort + time.sleep(110); // Sleep to make sure the node backoff period has passed + + prepareFindCoordinatorResponse(Errors.NONE, false, CoordinatorType.TRANSACTION, transactionalId); + prepareEndTxnResponse(Errors.NONE, TransactionResult.ABORT, producerId, initialEpoch); + prepareInitPidResponse(Errors.NONE, false, producerId, bumpedEpoch); + runUntil(() -> transactionManager.producerIdAndEpoch().epoch == bumpedEpoch); + + assertTrue(abortResult.isCompleted()); + assertTrue(abortResult.isSuccessful()); + assertTrue(transactionManager.isReady()); // make sure we are ready for a transaction now. + + transactionManager.beginTransaction(); + transactionManager.failIfNotReadyForSend(); + transactionManager.maybeAddPartitionToTransaction(tp0); + + prepareAddPartitionsToTxnResponse(Errors.NONE, tp0, bumpedEpoch, producerId); + runUntil(() -> transactionManager.isPartitionAdded(tp0)); + + assertEquals(0, transactionManager.sequenceNumber(tp0).intValue()); + } + + @Test + public void testBumpTransactionalEpochOnRecoverableAddPartitionRequestError() { + final short initialEpoch = 1; + final short bumpedEpoch = 2; + + doInitTransactions(producerId, initialEpoch); + + transactionManager.beginTransaction(); + transactionManager.failIfNotReadyForSend(); + transactionManager.maybeAddPartitionToTransaction(tp0); + prepareAddPartitionsToTxnResponse(Errors.INVALID_PRODUCER_ID_MAPPING, tp0, initialEpoch, producerId); + runUntil(transactionManager::hasAbortableError); + TransactionalRequestResult abortResult = transactionManager.beginAbort(); + + prepareInitPidResponse(Errors.NONE, false, producerId, bumpedEpoch); + runUntil(abortResult::isCompleted); + assertEquals(bumpedEpoch, transactionManager.producerIdAndEpoch().epoch); + assertTrue(abortResult.isSuccessful()); + assertTrue(transactionManager.isReady()); // make sure we are ready for a transaction now. 
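+        // Note that no EndTxn response is prepared here: since the INVALID_PRODUCER_ID_MAPPING error arrived
+        // before any partition was added, the abort path is expected to skip EndTxn and go straight to the
+        // epoch bump via InitProducerId.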
+ } + + @Test + public void testBumpTransactionalEpochOnRecoverableAddOffsetsRequestError() throws InterruptedException { + final short initialEpoch = 1; + final short bumpedEpoch = 2; + + doInitTransactions(producerId, initialEpoch); + + transactionManager.beginTransaction(); + transactionManager.failIfNotReadyForSend(); + transactionManager.maybeAddPartitionToTransaction(tp0); + + Future responseFuture = appendToAccumulator(tp0); + + assertFalse(responseFuture.isDone()); + prepareAddPartitionsToTxnResponse(Errors.NONE, tp0, initialEpoch, producerId); + prepareProduceResponse(Errors.NONE, producerId, initialEpoch); + runUntil(responseFuture::isDone); + + Map offsets = new HashMap<>(); + offsets.put(tp0, new OffsetAndMetadata(1)); + transactionManager.sendOffsetsToTransaction(offsets, new ConsumerGroupMetadata(consumerGroupId)); + assertFalse(transactionManager.hasPendingOffsetCommits()); + prepareAddOffsetsToTxnResponse(Errors.INVALID_PRODUCER_ID_MAPPING, consumerGroupId, producerId, initialEpoch); + runUntil(transactionManager::hasAbortableError); // Send AddOffsetsRequest + TransactionalRequestResult abortResult = transactionManager.beginAbort(); + + prepareEndTxnResponse(Errors.NONE, TransactionResult.ABORT, producerId, initialEpoch); + prepareInitPidResponse(Errors.NONE, false, producerId, bumpedEpoch); + runUntil(abortResult::isCompleted); + assertEquals(bumpedEpoch, transactionManager.producerIdAndEpoch().epoch); + assertTrue(abortResult.isSuccessful()); + assertTrue(transactionManager.isReady()); // make sure we are ready for a transaction now. + } + + @Test + public void testHealthyPartitionRetriesDuringEpochBump() throws InterruptedException { + // Use a custom Sender to allow multiple inflight requests + initializeTransactionManager(Optional.empty()); + Sender sender = new Sender(logContext, this.client, this.metadata, this.accumulator, false, + MAX_REQUEST_SIZE, ACKS_ALL, MAX_RETRIES, new SenderMetricsRegistry(new Metrics(time)), this.time, + REQUEST_TIMEOUT, 50, transactionManager, apiVersions); + initializeIdempotentProducerId(producerId, epoch); + + ProducerBatch tp0b1 = writeIdempotentBatchWithValue(transactionManager, tp0, "1"); + ProducerBatch tp0b2 = writeIdempotentBatchWithValue(transactionManager, tp0, "2"); + writeIdempotentBatchWithValue(transactionManager, tp0, "3"); + ProducerBatch tp1b1 = writeIdempotentBatchWithValue(transactionManager, tp1, "4"); + ProducerBatch tp1b2 = writeIdempotentBatchWithValue(transactionManager, tp1, "5"); + assertEquals(3, transactionManager.sequenceNumber(tp0).intValue()); + assertEquals(2, transactionManager.sequenceNumber(tp1).intValue()); + + // First batch of each partition succeeds + long b1AppendTime = time.milliseconds(); + ProduceResponse.PartitionResponse t0b1Response = new ProduceResponse.PartitionResponse( + Errors.NONE, 500L, b1AppendTime, 0L); + tp0b1.complete(500L, b1AppendTime); + transactionManager.handleCompletedBatch(tp0b1, t0b1Response); + + ProduceResponse.PartitionResponse t1b1Response = new ProduceResponse.PartitionResponse( + Errors.NONE, 500L, b1AppendTime, 0L); + tp1b1.complete(500L, b1AppendTime); + transactionManager.handleCompletedBatch(tp1b1, t1b1Response); + + // We bump the epoch and set sequence numbers back to 0 + ProduceResponse.PartitionResponse t0b2Response = new ProduceResponse.PartitionResponse( + Errors.UNKNOWN_PRODUCER_ID, -1, -1, 500L); + assertTrue(transactionManager.canRetry(t0b2Response, tp0b2)); + + // Run sender loop to trigger epoch bump + runUntil(() -> 
transactionManager.producerIdAndEpoch().epoch == 2); + + // tp0 batches should have had sequence and epoch rewritten, but tp1 batches should not + assertEquals(tp0b2, transactionManager.nextBatchBySequence(tp0)); + assertEquals(0, transactionManager.firstInFlightSequence(tp0)); + assertEquals(0, tp0b2.baseSequence()); + assertTrue(tp0b2.sequenceHasBeenReset()); + assertEquals(2, tp0b2.producerEpoch()); + + assertEquals(tp1b2, transactionManager.nextBatchBySequence(tp1)); + assertEquals(1, transactionManager.firstInFlightSequence(tp1)); + assertEquals(1, tp1b2.baseSequence()); + assertFalse(tp1b2.sequenceHasBeenReset()); + assertEquals(1, tp1b2.producerEpoch()); + + // New tp1 batches should not be drained from the accumulator while tp1 has in-flight requests using the old epoch + appendToAccumulator(tp1); + sender.runOnce(); + assertEquals(1, accumulator.batches().get(tp1).size()); + + // Partition failover occurs and tp1 returns a NOT_LEADER_OR_FOLLOWER error + // Despite having the old epoch, the batch should retry + ProduceResponse.PartitionResponse t1b2Response = new ProduceResponse.PartitionResponse( + Errors.NOT_LEADER_OR_FOLLOWER, -1, -1, 600L); + assertTrue(transactionManager.canRetry(t1b2Response, tp1b2)); + accumulator.reenqueue(tp1b2, time.milliseconds()); + + // The batch with the old epoch should be successfully drained, leaving the new one in the queue + sender.runOnce(); + assertEquals(1, accumulator.batches().get(tp1).size()); + assertNotEquals(tp1b2, accumulator.batches().get(tp1).peek()); + assertEquals(epoch, tp1b2.producerEpoch()); + + // After successfully retrying, there should be no in-flight batches for tp1 and the sequence should be 0 + t1b2Response = new ProduceResponse.PartitionResponse( + Errors.NONE, 500L, b1AppendTime, 0L); + tp1b2.complete(500L, b1AppendTime); + transactionManager.handleCompletedBatch(tp1b2, t1b2Response); + + transactionManager.maybeUpdateProducerIdAndEpoch(tp1); + assertFalse(transactionManager.hasInflightBatches(tp1)); + assertEquals(0, transactionManager.sequenceNumber(tp1).intValue()); + + // The last batch should now be drained and sent + runUntil(() -> transactionManager.hasInflightBatches(tp1)); + assertTrue(accumulator.batches().get(tp1).isEmpty()); + ProducerBatch tp1b3 = transactionManager.nextBatchBySequence(tp1); + assertEquals(epoch + 1, tp1b3.producerEpoch()); + + ProduceResponse.PartitionResponse t1b3Response = new ProduceResponse.PartitionResponse( + Errors.NONE, 500L, b1AppendTime, 0L); + tp1b3.complete(500L, b1AppendTime); + transactionManager.handleCompletedBatch(tp1b3, t1b3Response); + + transactionManager.maybeUpdateProducerIdAndEpoch(tp1); + assertFalse(transactionManager.hasInflightBatches(tp1)); + assertEquals(1, transactionManager.sequenceNumber(tp1).intValue()); + } + + @Test + public void testRetryAbortTransaction() throws InterruptedException { + verifyCommitOrAbortTransactionRetriable(TransactionResult.ABORT, TransactionResult.ABORT); + } + + @Test + public void testRetryCommitTransaction() throws InterruptedException { + verifyCommitOrAbortTransactionRetriable(TransactionResult.COMMIT, TransactionResult.COMMIT); + } + + @Test + public void testRetryAbortTransactionAfterCommitTimeout() { + assertThrows(KafkaException.class, () -> verifyCommitOrAbortTransactionRetriable(TransactionResult.COMMIT, TransactionResult.ABORT)); + } + + @Test + public void testRetryCommitTransactionAfterAbortTimeout() { + assertThrows(KafkaException.class, () -> verifyCommitOrAbortTransactionRetriable(TransactionResult.ABORT, 
TransactionResult.COMMIT)); + } + + @Test + public void testCanBumpEpochDuringCoordinatorDisconnect() { + doInitTransactions(0, (short) 0); + runUntil(() -> transactionManager.coordinator(CoordinatorType.TRANSACTION) != null); + assertTrue(transactionManager.canBumpEpoch()); + + apiVersions.remove(transactionManager.coordinator(CoordinatorType.TRANSACTION).idString()); + assertTrue(transactionManager.canBumpEpoch()); + } + + @Test + public void testFailedInflightBatchAfterEpochBump() throws InterruptedException { + // Use a custom Sender to allow multiple inflight requests + initializeTransactionManager(Optional.empty()); + Sender sender = new Sender(logContext, this.client, this.metadata, this.accumulator, false, + MAX_REQUEST_SIZE, ACKS_ALL, MAX_RETRIES, new SenderMetricsRegistry(new Metrics(time)), this.time, + REQUEST_TIMEOUT, 50, transactionManager, apiVersions); + initializeIdempotentProducerId(producerId, epoch); + + ProducerBatch tp0b1 = writeIdempotentBatchWithValue(transactionManager, tp0, "1"); + ProducerBatch tp0b2 = writeIdempotentBatchWithValue(transactionManager, tp0, "2"); + writeIdempotentBatchWithValue(transactionManager, tp0, "3"); + ProducerBatch tp1b1 = writeIdempotentBatchWithValue(transactionManager, tp1, "4"); + ProducerBatch tp1b2 = writeIdempotentBatchWithValue(transactionManager, tp1, "5"); + assertEquals(3, transactionManager.sequenceNumber(tp0).intValue()); + assertEquals(2, transactionManager.sequenceNumber(tp1).intValue()); + + // First batch of each partition succeeds + long b1AppendTime = time.milliseconds(); + ProduceResponse.PartitionResponse t0b1Response = new ProduceResponse.PartitionResponse( + Errors.NONE, 500L, b1AppendTime, 0L); + tp0b1.complete(500L, b1AppendTime); + transactionManager.handleCompletedBatch(tp0b1, t0b1Response); + + ProduceResponse.PartitionResponse t1b1Response = new ProduceResponse.PartitionResponse( + Errors.NONE, 500L, b1AppendTime, 0L); + tp1b1.complete(500L, b1AppendTime); + transactionManager.handleCompletedBatch(tp1b1, t1b1Response); + + // We bump the epoch and set sequence numbers back to 0 + ProduceResponse.PartitionResponse t0b2Response = new ProduceResponse.PartitionResponse( + Errors.UNKNOWN_PRODUCER_ID, -1, -1, 500L); + assertTrue(transactionManager.canRetry(t0b2Response, tp0b2)); + + // Run sender loop to trigger epoch bump + runUntil(() -> transactionManager.producerIdAndEpoch().epoch == 2); + + // tp0 batches should have had sequence and epoch rewritten, but tp1 batches should not + assertEquals(tp0b2, transactionManager.nextBatchBySequence(tp0)); + assertEquals(0, transactionManager.firstInFlightSequence(tp0)); + assertEquals(0, tp0b2.baseSequence()); + assertTrue(tp0b2.sequenceHasBeenReset()); + assertEquals(2, tp0b2.producerEpoch()); + + assertEquals(tp1b2, transactionManager.nextBatchBySequence(tp1)); + assertEquals(1, transactionManager.firstInFlightSequence(tp1)); + assertEquals(1, tp1b2.baseSequence()); + assertFalse(tp1b2.sequenceHasBeenReset()); + assertEquals(1, tp1b2.producerEpoch()); + + // New tp1 batches should not be drained from the accumulator while tp1 has in-flight requests using the old epoch + appendToAccumulator(tp1); + sender.runOnce(); + assertEquals(1, accumulator.batches().get(tp1).size()); + + // Partition failover occurs and tp1 returns a NOT_LEADER_OR_FOLLOWER error + // Despite having the old epoch, the batch should retry + ProduceResponse.PartitionResponse t1b2Response = new ProduceResponse.PartitionResponse( + Errors.NOT_LEADER_OR_FOLLOWER, -1, -1, 600L); + 
assertTrue(transactionManager.canRetry(t1b2Response, tp1b2)); + accumulator.reenqueue(tp1b2, time.milliseconds()); + + // The batch with the old epoch should be successfully drained, leaving the new one in the queue + sender.runOnce(); + assertEquals(1, accumulator.batches().get(tp1).size()); + assertNotEquals(tp1b2, accumulator.batches().get(tp1).peek()); + assertEquals(epoch, tp1b2.producerEpoch()); + + // After successfully retrying, there should be no in-flight batches for tp1 and the sequence should be 0 + t1b2Response = new ProduceResponse.PartitionResponse( + Errors.NONE, 500L, b1AppendTime, 0L); + tp1b2.complete(500L, b1AppendTime); + transactionManager.handleCompletedBatch(tp1b2, t1b2Response); + + transactionManager.maybeUpdateProducerIdAndEpoch(tp1); + assertFalse(transactionManager.hasInflightBatches(tp1)); + assertEquals(0, transactionManager.sequenceNumber(tp1).intValue()); + + // The last batch should now be drained and sent + runUntil(() -> transactionManager.hasInflightBatches(tp1)); + assertTrue(accumulator.batches().get(tp1).isEmpty()); + ProducerBatch tp1b3 = transactionManager.nextBatchBySequence(tp1); + assertEquals(epoch + 1, tp1b3.producerEpoch()); + + ProduceResponse.PartitionResponse t1b3Response = new ProduceResponse.PartitionResponse( + Errors.NONE, 500L, b1AppendTime, 0L); + tp1b3.complete(500L, b1AppendTime); + transactionManager.handleCompletedBatch(tp1b3, t1b3Response); + + assertFalse(transactionManager.hasInflightBatches(tp1)); + assertEquals(1, transactionManager.sequenceNumber(tp1).intValue()); + } + + private FutureRecordMetadata appendToAccumulator(TopicPartition tp) throws InterruptedException { + final long nowMs = time.milliseconds(); + return accumulator.append(tp, nowMs, "key".getBytes(), "value".getBytes(), Record.EMPTY_HEADERS, + null, MAX_BLOCK_TIMEOUT, false, nowMs).future; + } + + private void verifyCommitOrAbortTransactionRetriable(TransactionResult firstTransactionResult, + TransactionResult retryTransactionResult) throws InterruptedException { + doInitTransactions(); + + transactionManager.beginTransaction(); + transactionManager.failIfNotReadyForSend(); + transactionManager.maybeAddPartitionToTransaction(tp0); + + appendToAccumulator(tp0); + + prepareAddPartitionsToTxnResponse(Errors.NONE, tp0, epoch, producerId); + prepareProduceResponse(Errors.NONE, producerId, epoch); + runUntil(() -> !client.hasPendingResponses()); + + TransactionalRequestResult result = firstTransactionResult == TransactionResult.COMMIT ? + transactionManager.beginCommit() : transactionManager.beginAbort(); + prepareEndTxnResponse(Errors.NONE, firstTransactionResult, producerId, epoch, true); + runUntil(() -> !client.hasPendingResponses()); + assertFalse(result.isCompleted()); + + try { + result.await(MAX_BLOCK_TIMEOUT, TimeUnit.MILLISECONDS); + fail("Should have raised TimeoutException"); + } catch (TimeoutException ignored) { + } + + prepareFindCoordinatorResponse(Errors.NONE, false, CoordinatorType.TRANSACTION, transactionalId); + runUntil(() -> !client.hasPendingResponses()); + TransactionalRequestResult retryResult = retryTransactionResult == TransactionResult.COMMIT ? + transactionManager.beginCommit() : transactionManager.beginAbort(); + assertEquals(retryResult, result); // check if cached result is reused. 
+ + prepareEndTxnResponse(Errors.NONE, retryTransactionResult, producerId, epoch, false); + runUntil(retryResult::isCompleted); + assertFalse(transactionManager.hasOngoingTransaction()); + } + + private void prepareAddPartitionsToTxn(final Map errors) { + client.prepareResponse(body -> { + AddPartitionsToTxnRequest request = (AddPartitionsToTxnRequest) body; + assertEquals(new HashSet<>(request.partitions()), new HashSet<>(errors.keySet())); + return true; + }, new AddPartitionsToTxnResponse(0, errors)); + } + + private void prepareAddPartitionsToTxn(final TopicPartition tp, final Errors error) { + prepareAddPartitionsToTxn(Collections.singletonMap(tp, error)); + } + + private void prepareFindCoordinatorResponse(Errors error, boolean shouldDisconnect, + final CoordinatorType coordinatorType, + final String coordinatorKey) { + client.prepareResponse(body -> { + FindCoordinatorRequest findCoordinatorRequest = (FindCoordinatorRequest) body; + assertEquals(coordinatorType, CoordinatorType.forId(findCoordinatorRequest.data().keyType())); + String key = findCoordinatorRequest.data().coordinatorKeys().isEmpty() + ? findCoordinatorRequest.data().key() + : findCoordinatorRequest.data().coordinatorKeys().get(0); + assertEquals(coordinatorKey, key); + return true; + }, FindCoordinatorResponse.prepareResponse(error, coordinatorKey, brokerNode), shouldDisconnect); + } + + private void prepareInitPidResponse(Errors error, boolean shouldDisconnect, long producerId, short producerEpoch) { + InitProducerIdResponseData responseData = new InitProducerIdResponseData() + .setErrorCode(error.code()) + .setProducerEpoch(producerEpoch) + .setProducerId(producerId) + .setThrottleTimeMs(0); + client.prepareResponse(body -> { + InitProducerIdRequest initProducerIdRequest = (InitProducerIdRequest) body; + assertEquals(transactionalId, initProducerIdRequest.data().transactionalId()); + assertEquals(transactionTimeoutMs, initProducerIdRequest.data().transactionTimeoutMs()); + return true; + }, new InitProducerIdResponse(responseData), shouldDisconnect); + } + + private void sendProduceResponse(Errors error, final long producerId, final short producerEpoch) { + sendProduceResponse(error, producerId, producerEpoch, tp0); + } + + private void sendProduceResponse(Errors error, final long producerId, final short producerEpoch, TopicPartition tp) { + client.respond(produceRequestMatcher(producerId, producerEpoch, tp), produceResponse(tp, 0, error, 0)); + } + + private void prepareProduceResponse(Errors error, final long producerId, final short producerEpoch) { + prepareProduceResponse(error, producerId, producerEpoch, tp0); + } + + private void prepareProduceResponse(Errors error, final long producerId, final short producerEpoch, TopicPartition tp) { + client.prepareResponse(produceRequestMatcher(producerId, producerEpoch, tp), produceResponse(tp, 0, error, 0)); + } + + private MockClient.RequestMatcher produceRequestMatcher(final long producerId, final short epoch, TopicPartition tp) { + return body -> { + ProduceRequest produceRequest = (ProduceRequest) body; + MemoryRecords records = produceRequest.data().topicData() + .stream() + .filter(t -> t.name().equals(tp.topic())) + .findAny() + .get() + .partitionData() + .stream() + .filter(p -> p.index() == tp.partition()) + .map(p -> (MemoryRecords) p.records()) + .findAny().get(); + assertNotNull(records); + Iterator batchIterator = records.batches().iterator(); + assertTrue(batchIterator.hasNext()); + MutableRecordBatch batch = batchIterator.next(); + 
assertFalse(batchIterator.hasNext()); + assertTrue(batch.isTransactional()); + assertEquals(producerId, batch.producerId()); + assertEquals(epoch, batch.producerEpoch()); + assertEquals(transactionalId, produceRequest.transactionalId()); + return true; + }; + } + + private void prepareAddPartitionsToTxnResponse(Errors error, final TopicPartition topicPartition, + final short epoch, final long producerId) { + client.prepareResponse(addPartitionsRequestMatcher(topicPartition, epoch, producerId), + new AddPartitionsToTxnResponse(0, singletonMap(topicPartition, error))); + } + + private void sendAddPartitionsToTxnResponse(Errors error, final TopicPartition topicPartition, + final short epoch, final long producerId) { + client.respond(addPartitionsRequestMatcher(topicPartition, epoch, producerId), + new AddPartitionsToTxnResponse(0, singletonMap(topicPartition, error))); + } + + private MockClient.RequestMatcher addPartitionsRequestMatcher(final TopicPartition topicPartition, + final short epoch, final long producerId) { + return body -> { + AddPartitionsToTxnRequest addPartitionsToTxnRequest = (AddPartitionsToTxnRequest) body; + assertEquals(producerId, addPartitionsToTxnRequest.data().producerId()); + assertEquals(epoch, addPartitionsToTxnRequest.data().producerEpoch()); + assertEquals(singletonList(topicPartition), addPartitionsToTxnRequest.partitions()); + assertEquals(transactionalId, addPartitionsToTxnRequest.data().transactionalId()); + return true; + }; + } + + private void prepareEndTxnResponse(Errors error, final TransactionResult result, final long producerId, final short epoch) { + this.prepareEndTxnResponse(error, result, producerId, epoch, false); + } + + private void prepareEndTxnResponse(Errors error, + final TransactionResult result, + final long producerId, + final short epoch, + final boolean shouldDisconnect) { + client.prepareResponse(endTxnMatcher(result, producerId, epoch), + new EndTxnResponse(new EndTxnResponseData() + .setErrorCode(error.code()) + .setThrottleTimeMs(0)), shouldDisconnect); + } + + private void sendEndTxnResponse(Errors error, final TransactionResult result, final long producerId, final short epoch) { + client.respond(endTxnMatcher(result, producerId, epoch), new EndTxnResponse( + new EndTxnResponseData() + .setErrorCode(error.code()) + .setThrottleTimeMs(0) + )); + } + + private MockClient.RequestMatcher endTxnMatcher(final TransactionResult result, final long producerId, final short epoch) { + return body -> { + EndTxnRequest endTxnRequest = (EndTxnRequest) body; + assertEquals(transactionalId, endTxnRequest.data().transactionalId()); + assertEquals(producerId, endTxnRequest.data().producerId()); + assertEquals(epoch, endTxnRequest.data().producerEpoch()); + assertEquals(result, endTxnRequest.result()); + return true; + }; + } + + private void prepareAddOffsetsToTxnResponse(final Errors error, + final String consumerGroupId, + final long producerId, + final short producerEpoch) { + client.prepareResponse(body -> { + AddOffsetsToTxnRequest addOffsetsToTxnRequest = (AddOffsetsToTxnRequest) body; + assertEquals(consumerGroupId, addOffsetsToTxnRequest.data().groupId()); + assertEquals(transactionalId, addOffsetsToTxnRequest.data().transactionalId()); + assertEquals(producerId, addOffsetsToTxnRequest.data().producerId()); + assertEquals(producerEpoch, addOffsetsToTxnRequest.data().producerEpoch()); + return true; + }, new AddOffsetsToTxnResponse( + new AddOffsetsToTxnResponseData() + .setErrorCode(error.code())) + ); + } + + private void 
prepareTxnOffsetCommitResponse(final String consumerGroupId, + final long producerId, + final short producerEpoch, + Map txnOffsetCommitResponse) { + client.prepareResponse(request -> { + TxnOffsetCommitRequest txnOffsetCommitRequest = (TxnOffsetCommitRequest) request; + assertEquals(consumerGroupId, txnOffsetCommitRequest.data().groupId()); + assertEquals(producerId, txnOffsetCommitRequest.data().producerId()); + assertEquals(producerEpoch, txnOffsetCommitRequest.data().producerEpoch()); + return true; + }, new TxnOffsetCommitResponse(0, txnOffsetCommitResponse)); + } + + private void prepareTxnOffsetCommitResponse(final String consumerGroupId, + final long producerId, + final short producerEpoch, + final String groupInstanceId, + final String memberId, + final int generationId, + Map txnOffsetCommitResponse) { + client.prepareResponse(request -> { + TxnOffsetCommitRequest txnOffsetCommitRequest = (TxnOffsetCommitRequest) request; + assertEquals(consumerGroupId, txnOffsetCommitRequest.data().groupId()); + assertEquals(producerId, txnOffsetCommitRequest.data().producerId()); + assertEquals(producerEpoch, txnOffsetCommitRequest.data().producerEpoch()); + assertEquals(groupInstanceId, txnOffsetCommitRequest.data().groupInstanceId()); + assertEquals(memberId, txnOffsetCommitRequest.data().memberId()); + assertEquals(generationId, txnOffsetCommitRequest.data().generationId()); + return true; + }, new TxnOffsetCommitResponse(0, txnOffsetCommitResponse)); + } + + private ProduceResponse produceResponse(TopicPartition tp, long offset, Errors error, int throttleTimeMs) { + return produceResponse(tp, offset, error, throttleTimeMs, 10); + } + + @SuppressWarnings("deprecation") + private ProduceResponse produceResponse(TopicPartition tp, long offset, Errors error, int throttleTimeMs, int logStartOffset) { + ProduceResponse.PartitionResponse resp = new ProduceResponse.PartitionResponse(error, offset, RecordBatch.NO_TIMESTAMP, logStartOffset); + Map partResp = singletonMap(tp, resp); + return new ProduceResponse(partResp, throttleTimeMs); + } + + private void initializeIdempotentProducerId(long producerId, short epoch) { + InitProducerIdResponseData responseData = new InitProducerIdResponseData() + .setErrorCode(Errors.NONE.code()) + .setProducerEpoch(epoch) + .setProducerId(producerId) + .setThrottleTimeMs(0); + client.prepareResponse(body -> { + InitProducerIdRequest initProducerIdRequest = (InitProducerIdRequest) body; + assertNull(initProducerIdRequest.data().transactionalId()); + return true; + }, new InitProducerIdResponse(responseData), false); + + runUntil(transactionManager::hasProducerId); + } + + private void doInitTransactions() { + doInitTransactions(producerId, epoch); + } + + private void doInitTransactions(long producerId, short epoch) { + transactionManager.initializeTransactions(); + prepareFindCoordinatorResponse(Errors.NONE, false, CoordinatorType.TRANSACTION, transactionalId); + runUntil(() -> transactionManager.coordinator(CoordinatorType.TRANSACTION) != null); + assertEquals(brokerNode, transactionManager.coordinator(CoordinatorType.TRANSACTION)); + + prepareInitPidResponse(Errors.NONE, false, producerId, epoch); + runUntil(transactionManager::hasProducerId); + } + + private void assertAbortableError(Class cause) { + try { + transactionManager.beginCommit(); + fail("Should have raised " + cause.getSimpleName()); + } catch (KafkaException e) { + assertTrue(cause.isAssignableFrom(e.getCause().getClass())); + assertTrue(transactionManager.hasError()); + } + + 
assertTrue(transactionManager.hasError()); + transactionManager.beginAbort(); + assertFalse(transactionManager.hasError()); + } + + private void assertFatalError(Class cause) { + assertTrue(transactionManager.hasError()); + + try { + transactionManager.beginAbort(); + fail("Should have raised " + cause.getSimpleName()); + } catch (KafkaException e) { + assertTrue(cause.isAssignableFrom(e.getCause().getClass())); + assertTrue(transactionManager.hasError()); + } + + // Transaction abort cannot clear fatal error state + try { + transactionManager.beginAbort(); + fail("Should have raised " + cause.getSimpleName()); + } catch (KafkaException e) { + assertTrue(cause.isAssignableFrom(e.getCause().getClass())); + assertTrue(transactionManager.hasError()); + } + } + + private void assertProduceFutureFailed(Future future) throws InterruptedException { + assertTrue(future.isDone()); + + try { + future.get(); + fail("Expected produce future to throw"); + } catch (ExecutionException e) { + // expected + } + } + + private void runUntil(Supplier condition) { + ProducerTestUtils.runUntil(sender, condition); + } + +} diff --git a/clients/src/test/java/org/apache/kafka/common/ClusterTest.java b/clients/src/test/java/org/apache/kafka/common/ClusterTest.java new file mode 100644 index 0000000..4c6db86 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/ClusterTest.java @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+package org.apache.kafka.common;
+
+import org.apache.kafka.common.utils.Utils;
+import org.junit.jupiter.api.Test;
+
+import java.net.InetSocketAddress;
+import java.util.Arrays;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+
+import static java.util.Arrays.asList;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+
+public class ClusterTest {
+
+    private final static Node[] NODES = new Node[] {
+        new Node(0, "localhost", 99),
+        new Node(1, "localhost", 100),
+        new Node(2, "localhost", 101),
+        new Node(11, "localhost", 102)
+    };
+
+    private final static String TOPIC_A = "topicA";
+    private final static String TOPIC_B = "topicB";
+    private final static String TOPIC_C = "topicC";
+    private final static String TOPIC_D = "topicD";
+    private final static String TOPIC_E = "topicE";
+
+    @Test
+    public void testBootstrap() {
+        String ipAddress = "140.211.11.105";
+        String hostName = "www.example.com";
+        Cluster cluster = Cluster.bootstrap(Arrays.asList(
+            new InetSocketAddress(ipAddress, 9002),
+            new InetSocketAddress(hostName, 9002)
+        ));
+        Set<String> expectedHosts = Utils.mkSet(ipAddress, hostName);
+        Set<String> actualHosts = new HashSet<>();
+        for (Node node : cluster.nodes())
+            actualHosts.add(node.host());
+        assertEquals(expectedHosts, actualHosts);
+    }
+
+    @Test
+    public void testReturnUnmodifiableCollections() {
+        List<PartitionInfo> allPartitions = asList(new PartitionInfo(TOPIC_A, 0, NODES[0], NODES, NODES),
+            new PartitionInfo(TOPIC_A, 1, null, NODES, NODES),
+            new PartitionInfo(TOPIC_A, 2, NODES[2], NODES, NODES),
+            new PartitionInfo(TOPIC_B, 0, null, NODES, NODES),
+            new PartitionInfo(TOPIC_B, 1, NODES[0], NODES, NODES),
+            new PartitionInfo(TOPIC_C, 0, null, NODES, NODES),
+            new PartitionInfo(TOPIC_D, 0, NODES[1], NODES, NODES),
+            new PartitionInfo(TOPIC_E, 0, NODES[0], NODES, NODES)
+        );
+        Set<String> unauthorizedTopics = Utils.mkSet(TOPIC_C);
+        Set<String> invalidTopics = Utils.mkSet(TOPIC_D);
+        Set<String> internalTopics = Utils.mkSet(TOPIC_E);
+        Cluster cluster = new Cluster("clusterId", asList(NODES), allPartitions, unauthorizedTopics,
+            invalidTopics, internalTopics, NODES[1]);
+
+        assertThrows(UnsupportedOperationException.class, () -> cluster.invalidTopics().add("foo"));
+        assertThrows(UnsupportedOperationException.class, () -> cluster.internalTopics().add("foo"));
+        assertThrows(UnsupportedOperationException.class, () -> cluster.unauthorizedTopics().add("foo"));
+        assertThrows(UnsupportedOperationException.class, () -> cluster.topics().add("foo"));
+        assertThrows(UnsupportedOperationException.class, () -> cluster.nodes().add(NODES[3]));
+        assertThrows(UnsupportedOperationException.class, () -> cluster.partitionsForTopic(TOPIC_A).add(
+            new PartitionInfo(TOPIC_A, 3, NODES[0], NODES, NODES)));
+        assertThrows(UnsupportedOperationException.class, () -> cluster.availablePartitionsForTopic(TOPIC_B).add(
+            new PartitionInfo(TOPIC_B, 2, NODES[0], NODES, NODES)));
+        assertThrows(UnsupportedOperationException.class, () -> cluster.partitionsForNode(NODES[1].id()).add(
+            new PartitionInfo(TOPIC_B, 2, NODES[1], NODES, NODES)));
+    }
+
+}
diff --git a/clients/src/test/java/org/apache/kafka/common/KafkaFutureTest.java b/clients/src/test/java/org/apache/kafka/common/KafkaFutureTest.java
new file mode 100644
index 0000000..0218ce1
--- /dev/null
+++ b/clients/src/test/java/org/apache/kafka/common/KafkaFutureTest.java
@@ -0,0 +1,648 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.common;
+
+import org.apache.kafka.common.internals.KafkaFutureImpl;
+import org.apache.kafka.common.utils.Java;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.Timeout;
+
+import java.lang.reflect.InvocationTargetException;
+import java.lang.reflect.Method;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.concurrent.CancellationException;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.CompletionException;
+import java.util.concurrent.CompletionStage;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
+import java.util.function.Supplier;
+
+import static org.junit.jupiter.api.Assertions.assertNotEquals;
+import static org.junit.jupiter.api.Assertions.assertNull;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+import static org.junit.jupiter.api.Assertions.assertFalse;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+
+/**
+ * A unit test for KafkaFuture.
+ */
+@Timeout(120)
+public class KafkaFutureTest {
+
+    /** Asserts that the given future is done, didn't fail and wasn't cancelled. */
+    private void assertIsSuccessful(KafkaFuture<?> future) {
+        assertTrue(future.isDone());
+        assertFalse(future.isCompletedExceptionally());
+        assertFalse(future.isCancelled());
+    }
+
+    /** Asserts that the given future is done, failed and wasn't cancelled. */
+    private void assertIsFailed(KafkaFuture<?> future) {
+        assertTrue(future.isDone());
+        assertFalse(future.isCancelled());
+        assertTrue(future.isCompletedExceptionally());
+    }
+
+    /** Asserts that the given future is done, didn't fail and was cancelled. */
+    private void assertIsCancelled(KafkaFuture<?> future) {
+        assertTrue(future.isDone());
+        assertTrue(future.isCancelled());
+        assertTrue(future.isCompletedExceptionally());
+        assertThrows(CancellationException.class, () -> future.getNow(null));
+        assertThrows(CancellationException.class, () -> future.get(0, TimeUnit.MILLISECONDS));
+    }
+
+    private <T> void awaitAndAssertResult(KafkaFuture<T> future,
+                                          T expectedResult,
+                                          T alternativeValue) {
+        assertNotEquals(expectedResult, alternativeValue);
+        try {
+            assertEquals(expectedResult, future.get(5, TimeUnit.MINUTES));
+        } catch (Exception e) {
+            throw new AssertionError("Unexpected exception", e);
+        }
+        try {
+            assertEquals(expectedResult, future.get());
+        } catch (Exception e) {
+            throw new AssertionError("Unexpected exception", e);
+        }
+        try {
+            assertEquals(expectedResult, future.getNow(alternativeValue));
+        } catch (Exception e) {
+            throw new AssertionError("Unexpected exception", e);
+        }
+    }
+
+    private Throwable awaitAndAssertFailure(KafkaFuture<?> future,
+                                            Class<? extends Throwable> expectedException,
+                                            String expectedMessage) {
+        ExecutionException executionException = assertThrows(ExecutionException.class, () -> future.get(5, TimeUnit.MINUTES));
+        assertEquals(expectedException, executionException.getCause().getClass());
+        assertEquals(expectedMessage, executionException.getCause().getMessage());
+
+        executionException = assertThrows(ExecutionException.class, () -> future.get());
+        assertEquals(expectedException, executionException.getCause().getClass());
+        assertEquals(expectedMessage, executionException.getCause().getMessage());
+
+        executionException = assertThrows(ExecutionException.class, () -> future.getNow(null));
+        assertEquals(expectedException, executionException.getCause().getClass());
+        assertEquals(expectedMessage, executionException.getCause().getMessage());
+        return executionException.getCause();
+    }
+
+    private void awaitAndAssertCancelled(KafkaFuture<?> future, String expectedMessage) {
+        CancellationException cancellationException = assertThrows(CancellationException.class, () -> future.get(5, TimeUnit.MINUTES));
+        assertEquals(expectedMessage, cancellationException.getMessage());
+        assertEquals(CancellationException.class, cancellationException.getClass());
+
+        cancellationException = assertThrows(CancellationException.class, () -> future.get());
+        assertEquals(expectedMessage, cancellationException.getMessage());
+        assertEquals(CancellationException.class, cancellationException.getClass());
+
+        cancellationException = assertThrows(CancellationException.class, () -> future.getNow(null));
+        assertEquals(expectedMessage, cancellationException.getMessage());
+        assertEquals(CancellationException.class, cancellationException.getClass());
+    }
+
+    @Test
+    public void testCompleteFutures() throws Exception {
+        KafkaFutureImpl<Integer> future123 = new KafkaFutureImpl<>();
+        assertTrue(future123.complete(123));
+        assertFalse(future123.complete(456));
+        assertFalse(future123.cancel(true));
+        assertEquals(Integer.valueOf(123), future123.get());
+        assertIsSuccessful(future123);
+
+        KafkaFuture<Integer> future456 = KafkaFuture.completedFuture(456);
+        assertFalse(future456.complete(789));
+        assertFalse(future456.cancel(true));
+        assertEquals(Integer.valueOf(456), future456.get());
+        assertIsSuccessful(future456);
+    }
+
+    @Test
+    public void testCompleteFuturesExceptionally() throws Exception {
+        KafkaFutureImpl<Integer> futureFail = new KafkaFutureImpl<>();
+        assertTrue(futureFail.completeExceptionally(new RuntimeException("We require more vespene gas")));
+        assertIsFailed(futureFail);
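+        // The future is now completed exceptionally; further attempts to complete, fail again, or cancel it should be rejected and return false.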
+ assertFalse(futureFail.completeExceptionally(new RuntimeException("We require more minerals"))); + assertFalse(futureFail.cancel(true)); + + ExecutionException executionException = assertThrows(ExecutionException.class, () -> futureFail.get()); + assertEquals(RuntimeException.class, executionException.getCause().getClass()); + assertEquals("We require more vespene gas", executionException.getCause().getMessage()); + + KafkaFutureImpl tricky1 = new KafkaFutureImpl<>(); + assertTrue(tricky1.completeExceptionally(new CompletionException(new CancellationException()))); + assertIsFailed(tricky1); + awaitAndAssertFailure(tricky1, CompletionException.class, "java.util.concurrent.CancellationException"); + } + + @Test + public void testCompleteFuturesViaCancellation() { + KafkaFutureImpl viaCancel = new KafkaFutureImpl<>(); + assertTrue(viaCancel.cancel(true)); + assertIsCancelled(viaCancel); + awaitAndAssertCancelled(viaCancel, null); + + KafkaFutureImpl viaCancellationException = new KafkaFutureImpl<>(); + assertTrue(viaCancellationException.completeExceptionally(new CancellationException("We require more vespene gas"))); + assertIsCancelled(viaCancellationException); + awaitAndAssertCancelled(viaCancellationException, "We require more vespene gas"); + } + + @Test + public void testToString() { + KafkaFutureImpl success = new KafkaFutureImpl<>(); + assertEquals("KafkaFuture{value=null,exception=null,done=false}", success.toString()); + success.complete(12); + assertEquals("KafkaFuture{value=12,exception=null,done=true}", success.toString()); + + KafkaFutureImpl failure = new KafkaFutureImpl<>(); + failure.completeExceptionally(new RuntimeException("foo")); + assertEquals("KafkaFuture{value=null,exception=java.lang.RuntimeException: foo,done=true}", failure.toString()); + + KafkaFutureImpl tricky1 = new KafkaFutureImpl<>(); + tricky1.completeExceptionally(new CompletionException(new CancellationException())); + assertEquals("KafkaFuture{value=null,exception=java.util.concurrent.CompletionException: java.util.concurrent.CancellationException,done=true}", tricky1.toString()); + + KafkaFutureImpl cancelled = new KafkaFutureImpl<>(); + cancelled.cancel(true); + assertEquals("KafkaFuture{value=null,exception=java.util.concurrent.CancellationException,done=true}", cancelled.toString()); + } + + @Test + public void testCompletingFutures() throws Exception { + final KafkaFutureImpl future = new KafkaFutureImpl<>(); + CompleterThread myThread = new CompleterThread<>(future, "You must construct additional pylons."); + assertIsNotCompleted(future); + assertEquals("I am ready", future.getNow("I am ready")); + myThread.start(); + awaitAndAssertResult(future, "You must construct additional pylons.", "I am ready"); + assertIsSuccessful(future); + myThread.join(); + assertNull(myThread.testException); + } + + @Test + public void testCompletingFuturesExceptionally() throws Exception { + final KafkaFutureImpl future = new KafkaFutureImpl<>(); + CompleterThread myThread = new CompleterThread<>(future, null, + new RuntimeException("Ultimate efficiency achieved.")); + assertIsNotCompleted(future); + assertEquals("I am ready", future.getNow("I am ready")); + myThread.start(); + awaitAndAssertFailure(future, RuntimeException.class, "Ultimate efficiency achieved."); + assertIsFailed(future); + myThread.join(); + assertNull(myThread.testException); + } + + @Test + public void testCompletingFuturesViaCancellation() throws Exception { + final KafkaFutureImpl future = new KafkaFutureImpl<>(); + CompleterThread myThread = 
new CompleterThread<>(future, null, + new CancellationException("Ultimate efficiency achieved.")); + assertIsNotCompleted(future); + assertEquals("I am ready", future.getNow("I am ready")); + myThread.start(); + awaitAndAssertCancelled(future, "Ultimate efficiency achieved."); + assertIsCancelled(future); + myThread.join(); + assertNull(myThread.testException); + } + + private void assertIsNotCompleted(KafkaFutureImpl future) { + assertFalse(future.isDone()); + assertFalse(future.isCompletedExceptionally()); + assertFalse(future.isCancelled()); + } + + @Test + public void testThenApplyOnSucceededFuture() throws Exception { + KafkaFutureImpl future = new KafkaFutureImpl<>(); + KafkaFuture doubledFuture = future.thenApply(integer -> 2 * integer); + assertFalse(doubledFuture.isDone()); + KafkaFuture tripledFuture = future.thenApply(integer -> 3 * integer); + assertFalse(tripledFuture.isDone()); + future.complete(21); + assertEquals(Integer.valueOf(21), future.getNow(-1)); + assertEquals(Integer.valueOf(42), doubledFuture.getNow(-1)); + assertEquals(Integer.valueOf(63), tripledFuture.getNow(-1)); + KafkaFuture quadrupledFuture = future.thenApply(integer -> 4 * integer); + assertEquals(Integer.valueOf(84), quadrupledFuture.getNow(-1)); + } + + @Test + public void testThenApplyOnFailedFuture() { + KafkaFutureImpl future = new KafkaFutureImpl<>(); + KafkaFuture dependantFuture = future.thenApply(integer -> 2 * integer); + future.completeExceptionally(new RuntimeException("We require more vespene gas")); + assertIsFailed(future); + assertIsFailed(dependantFuture); + awaitAndAssertFailure(future, RuntimeException.class, "We require more vespene gas"); + awaitAndAssertFailure(dependantFuture, RuntimeException.class, "We require more vespene gas"); + } + + @Test + public void testThenApplyOnFailedFutureTricky() { + KafkaFutureImpl future = new KafkaFutureImpl<>(); + KafkaFuture dependantFuture = future.thenApply(integer -> 2 * integer); + future.completeExceptionally(new CompletionException(new RuntimeException("We require more vespene gas"))); + assertIsFailed(future); + assertIsFailed(dependantFuture); + awaitAndAssertFailure(future, CompletionException.class, "java.lang.RuntimeException: We require more vespene gas"); + awaitAndAssertFailure(dependantFuture, CompletionException.class, "java.lang.RuntimeException: We require more vespene gas"); + } + + @Test + public void testThenApplyOnFailedFutureTricky2() { + KafkaFutureImpl future = new KafkaFutureImpl<>(); + KafkaFuture dependantFuture = future.thenApply(integer -> 2 * integer); + future.completeExceptionally(new CompletionException(new CancellationException())); + assertIsFailed(future); + assertIsFailed(dependantFuture); + awaitAndAssertFailure(future, CompletionException.class, "java.util.concurrent.CancellationException"); + awaitAndAssertFailure(dependantFuture, CompletionException.class, "java.util.concurrent.CancellationException"); + } + + @Test + public void testThenApplyOnSucceededFutureAndFunctionThrows() { + KafkaFutureImpl future = new KafkaFutureImpl<>(); + KafkaFuture dependantFuture = future.thenApply(integer -> { + throw new RuntimeException("We require more vespene gas"); + }); + future.complete(21); + assertIsSuccessful(future); + assertIsFailed(dependantFuture); + awaitAndAssertResult(future, 21, null); + awaitAndAssertFailure(dependantFuture, RuntimeException.class, "We require more vespene gas"); + } + + @Test + public void testThenApplyOnSucceededFutureAndFunctionThrowsCompletionException() { + KafkaFutureImpl future = 
new KafkaFutureImpl<>(); + KafkaFuture dependantFuture = future.thenApply(integer -> { + throw new CompletionException(new RuntimeException("We require more vespene gas")); + }); + future.complete(21); + assertIsSuccessful(future); + assertIsFailed(dependantFuture); + awaitAndAssertResult(future, 21, null); + Throwable cause = awaitAndAssertFailure(dependantFuture, CompletionException.class, "java.lang.RuntimeException: We require more vespene gas"); + assertTrue(cause.getCause() instanceof RuntimeException); + assertEquals(cause.getCause().getMessage(), "We require more vespene gas"); + } + + @Test + public void testThenApplyOnFailedFutureFunctionNotCalled() { + KafkaFutureImpl future = new KafkaFutureImpl<>(); + boolean[] ran = {false}; + KafkaFuture dependantFuture = future.thenApply(integer -> { + // Because the top level future failed, this should never be called. + ran[0] = true; + return null; + }); + future.completeExceptionally(new RuntimeException("We require more minerals")); + assertIsFailed(future); + assertIsFailed(dependantFuture); + awaitAndAssertFailure(future, RuntimeException.class, "We require more minerals"); + awaitAndAssertFailure(dependantFuture, RuntimeException.class, "We require more minerals"); + assertFalse(ran[0]); + } + + @Test + public void testThenApplyOnCancelledFuture() { + KafkaFutureImpl future = new KafkaFutureImpl<>(); + KafkaFuture dependantFuture = future.thenApply(integer -> 2 * integer); + future.cancel(true); + assertIsCancelled(future); + assertIsCancelled(dependantFuture); + awaitAndAssertCancelled(future, null); + awaitAndAssertCancelled(dependantFuture, null); + } + + @Test + public void testWhenCompleteOnSucceededFuture() throws Throwable { + KafkaFutureImpl future = new KafkaFutureImpl<>(); + Throwable[] err = new Throwable[1]; + boolean[] ran = {false}; + KafkaFuture dependantFuture = future.whenComplete((integer, ex) -> { + ran[0] = true; + try { + assertEquals(Integer.valueOf(21), integer); + if (ex != null) { + throw ex; + } + } catch (Throwable e) { + err[0] = e; + } + }); + assertFalse(dependantFuture.isDone()); + assertTrue(future.complete(21)); + assertTrue(ran[0]); + if (err[0] != null) { + throw err[0]; + } + } + + @Test + public void testWhenCompleteOnFailedFuture() { + KafkaFutureImpl future = new KafkaFutureImpl<>(); + Throwable[] err = new Throwable[1]; + boolean[] ran = {false}; + KafkaFuture dependantFuture = future.whenComplete((integer, ex) -> { + ran[0] = true; + err[0] = ex; + if (integer != null) { + err[0] = new AssertionError(); + } + }); + assertFalse(dependantFuture.isDone()); + RuntimeException ex = new RuntimeException("We require more vespene gas"); + assertTrue(future.completeExceptionally(ex)); + assertTrue(ran[0]); + assertEquals(err[0], ex); + } + + @Test + public void testWhenCompleteOnSucceededFutureAndConsumerThrows() { + KafkaFutureImpl future = new KafkaFutureImpl<>(); + boolean[] ran = {false}; + KafkaFuture dependantFuture = future.whenComplete((integer, ex) -> { + ran[0] = true; + throw new RuntimeException("We require more minerals"); + }); + assertFalse(dependantFuture.isDone()); + assertTrue(future.complete(21)); + assertIsSuccessful(future); + assertTrue(ran[0]); + assertIsFailed(dependantFuture); + awaitAndAssertFailure(dependantFuture, RuntimeException.class, "We require more minerals"); + } + + @Test + public void testWhenCompleteOnFailedFutureAndConsumerThrows() { + KafkaFutureImpl future = new KafkaFutureImpl<>(); + boolean[] ran = {false}; + KafkaFuture dependantFuture = 
future.whenComplete((integer, ex) -> { + ran[0] = true; + throw new RuntimeException("We require more minerals"); + }); + assertFalse(dependantFuture.isDone()); + assertTrue(future.completeExceptionally(new RuntimeException("We require more vespene gas"))); + assertIsFailed(future); + assertTrue(ran[0]); + assertIsFailed(dependantFuture); + awaitAndAssertFailure(dependantFuture, RuntimeException.class, "We require more vespene gas"); + } + + @Test + public void testWhenCompleteOnCancelledFuture() { + KafkaFutureImpl future = new KafkaFutureImpl<>(); + Throwable[] err = new Throwable[1]; + boolean[] ran = {false}; + KafkaFuture dependantFuture = future.whenComplete((integer, ex) -> { + ran[0] = true; + err[0] = ex; + if (integer != null) { + err[0] = new AssertionError(); + } + }); + assertFalse(dependantFuture.isDone()); + assertTrue(future.cancel(true)); + assertTrue(ran[0]); + assertTrue(err[0] instanceof CancellationException); + } + + private static class CompleterThread extends Thread { + + private final KafkaFutureImpl future; + private final T value; + private final Throwable exception; + Throwable testException = null; + + CompleterThread(KafkaFutureImpl future, T value) { + this.future = future; + this.value = value; + this.exception = null; + } + + CompleterThread(KafkaFutureImpl future, T value, Exception exception) { + this.future = future; + this.value = value; + this.exception = exception; + } + + @Override + public void run() { + try { + try { + Thread.sleep(0, 200); + } catch (InterruptedException e) { + } + if (exception == null) { + future.complete(value); + } else { + future.completeExceptionally(exception); + } + } catch (Throwable testException) { + this.testException = testException; + } + } + } + + private static class WaiterThread extends Thread { + + private final KafkaFutureImpl future; + private final T expected; + Throwable testException = null; + + WaiterThread(KafkaFutureImpl future, T expected) { + this.future = future; + this.expected = expected; + } + + @Override + public void run() { + try { + T value = future.get(); + assertEquals(expected, value); + } catch (Throwable testException) { + this.testException = testException; + } + } + } + + @Test + public void testAllOfFutures() throws Exception { + final int numThreads = 5; + final List> futures = new ArrayList<>(); + for (int i = 0; i < numThreads; i++) { + futures.add(new KafkaFutureImpl<>()); + } + KafkaFuture allFuture = KafkaFuture.allOf(futures.toArray(new KafkaFuture[0])); + final List> completerThreads = new ArrayList<>(); + final List> waiterThreads = new ArrayList<>(); + for (int i = 0; i < numThreads; i++) { + completerThreads.add(new CompleterThread<>(futures.get(i), i)); + waiterThreads.add(new WaiterThread<>(futures.get(i), i)); + } + assertFalse(allFuture.isDone()); + for (int i = 0; i < numThreads; i++) { + waiterThreads.get(i).start(); + } + for (int i = 0; i < numThreads - 1; i++) { + completerThreads.get(i).start(); + } + assertFalse(allFuture.isDone()); + completerThreads.get(numThreads - 1).start(); + allFuture.get(); + assertIsSuccessful(allFuture); + for (int i = 0; i < numThreads; i++) { + assertEquals(Integer.valueOf(i), futures.get(i).get()); + } + for (int i = 0; i < numThreads; i++) { + completerThreads.get(i).join(); + waiterThreads.get(i).join(); + assertNull(completerThreads.get(i).testException); + assertNull(waiterThreads.get(i).testException); + } + } + + @Test + public void testAllOfFuturesWithFailure() throws Exception { + final int numThreads = 5; + final List> futures 
= new ArrayList<>(); + for (int i = 0; i < numThreads; i++) { + futures.add(new KafkaFutureImpl<>()); + } + KafkaFuture allFuture = KafkaFuture.allOf(futures.toArray(new KafkaFuture[0])); + final List> completerThreads = new ArrayList<>(); + final List> waiterThreads = new ArrayList<>(); + int lastIndex = numThreads - 1; + for (int i = 0; i < lastIndex; i++) { + completerThreads.add(new CompleterThread<>(futures.get(i), i)); + waiterThreads.add(new WaiterThread<>(futures.get(i), i)); + } + completerThreads.add(new CompleterThread<>(futures.get(lastIndex), null, new RuntimeException("Last one failed"))); + waiterThreads.add(new WaiterThread<>(futures.get(lastIndex), lastIndex)); + assertFalse(allFuture.isDone()); + for (int i = 0; i < numThreads; i++) { + waiterThreads.get(i).start(); + } + for (int i = 0; i < lastIndex; i++) { + completerThreads.get(i).start(); + } + assertFalse(allFuture.isDone()); + completerThreads.get(lastIndex).start(); + awaitAndAssertFailure(allFuture, RuntimeException.class, "Last one failed"); + assertIsFailed(allFuture); + for (int i = 0; i < lastIndex; i++) { + assertEquals(Integer.valueOf(i), futures.get(i).get()); + } + assertIsFailed(futures.get(lastIndex)); + for (int i = 0; i < numThreads; i++) { + completerThreads.get(i).join(); + waiterThreads.get(i).join(); + assertNull(completerThreads.get(i).testException); + if (i == lastIndex) { + assertEquals(ExecutionException.class, waiterThreads.get(i).testException.getClass()); + assertEquals(RuntimeException.class, waiterThreads.get(i).testException.getCause().getClass()); + assertEquals("Last one failed", waiterThreads.get(i).testException.getCause().getMessage()); + } else { + assertNull(waiterThreads.get(i).testException); + } + } + } + + @Test + public void testAllOfFuturesHandlesZeroFutures() throws Exception { + KafkaFuture allFuture = KafkaFuture.allOf(); + assertTrue(allFuture.isDone()); + assertFalse(allFuture.isCancelled()); + assertFalse(allFuture.isCompletedExceptionally()); + allFuture.get(); + } + + @Test + public void testFutureTimeoutWithZeroWait() { + final KafkaFutureImpl future = new KafkaFutureImpl<>(); + assertThrows(TimeoutException.class, () -> future.get(0, TimeUnit.MILLISECONDS)); + } + + @Test + @SuppressWarnings("unchecked") + public void testLeakCompletableFuture() throws NoSuchMethodException, InvocationTargetException, IllegalAccessException { + final KafkaFutureImpl kfut = new KafkaFutureImpl<>(); + CompletableFuture comfut = kfut.toCompletionStage().toCompletableFuture(); + assertThrows(UnsupportedOperationException.class, () -> comfut.complete("")); + assertThrows(UnsupportedOperationException.class, () -> comfut.completeExceptionally(new RuntimeException())); + // Annoyingly CompletableFuture added some more methods in Java 9, but the tests need to run on Java 8 + // so test reflectively + if (Java.IS_JAVA9_COMPATIBLE) { + Method completeOnTimeout = CompletableFuture.class.getDeclaredMethod("completeOnTimeout", Object.class, Long.TYPE, TimeUnit.class); + assertThrows(UnsupportedOperationException.class, () -> { + try { + completeOnTimeout.invoke(comfut, "", 1L, TimeUnit.MILLISECONDS); + } catch (InvocationTargetException e) { + throw e.getCause(); + } + }); + + Method completeAsync = CompletableFuture.class.getDeclaredMethod("completeAsync", Supplier.class); + assertThrows(UnsupportedOperationException.class, () -> { + try { + completeAsync.invoke(comfut, (Supplier) () -> ""); + } catch (InvocationTargetException e) { + throw e.getCause(); + } + }); + + Method obtrudeValue = 
CompletableFuture.class.getDeclaredMethod("obtrudeValue", Object.class); + assertThrows(UnsupportedOperationException.class, () -> { + try { + obtrudeValue.invoke(comfut, ""); + } catch (InvocationTargetException e) { + throw e.getCause(); + } + }); + + Method obtrudeException = CompletableFuture.class.getDeclaredMethod("obtrudeException", Throwable.class); + assertThrows(UnsupportedOperationException.class, () -> { + try { + obtrudeException.invoke(comfut, new RuntimeException()); + } catch (InvocationTargetException e) { + throw e.getCause(); + } + }); + + // Check the CF from a minimal CompletionStage doesn't cause completion of the original KafkaFuture + Method minimal = CompletableFuture.class.getDeclaredMethod("minimalCompletionStage"); + CompletionStage cs = (CompletionStage) minimal.invoke(comfut); + cs.toCompletableFuture().complete(""); + + assertFalse(kfut.isDone()); + assertFalse(comfut.isDone()); + } + } + +} diff --git a/clients/src/test/java/org/apache/kafka/common/PartitionInfoTest.java b/clients/src/test/java/org/apache/kafka/common/PartitionInfoTest.java new file mode 100644 index 0000000..ed970d6 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/PartitionInfoTest.java @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common; + +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class PartitionInfoTest { + + @Test + public void testToString() { + String topic = "sample"; + int partition = 0; + Node leader = new Node(0, "localhost", 9092); + Node r1 = new Node(1, "localhost", 9093); + Node r2 = new Node(2, "localhost", 9094); + Node[] replicas = new Node[] {leader, r1, r2}; + Node[] inSyncReplicas = new Node[] {leader, r1}; + Node[] offlineReplicas = new Node[] {r2}; + PartitionInfo partitionInfo = new PartitionInfo(topic, partition, leader, replicas, inSyncReplicas, offlineReplicas); + + String expected = String.format("Partition(topic = %s, partition = %d, leader = %s, replicas = %s, isr = %s, offlineReplicas = %s)", + topic, partition, leader.idString(), "[0,1,2]", "[0,1]", "[2]"); + assertEquals(expected, partitionInfo.toString()); + } + +} diff --git a/clients/src/test/java/org/apache/kafka/common/TopicIdPartitionTest.java b/clients/src/test/java/org/apache/kafka/common/TopicIdPartitionTest.java new file mode 100644 index 0000000..d39481c --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/TopicIdPartitionTest.java @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.common; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotEquals; + +import java.util.Objects; +import org.junit.jupiter.api.Test; + +class TopicIdPartitionTest { + + private final Uuid topicId0 = new Uuid(-4883993789924556279L, -5960309683534398572L); + private final String topicName0 = "a_topic_name"; + private final int partition1 = 1; + private final TopicPartition topicPartition0 = new TopicPartition(topicName0, partition1); + private final TopicIdPartition topicIdPartition0 = new TopicIdPartition(topicId0, topicPartition0); + private final TopicIdPartition topicIdPartition1 = new TopicIdPartition(topicId0, + partition1, topicName0); + + private final TopicIdPartition topicIdPartitionWithNullTopic0 = new TopicIdPartition(topicId0, + partition1, null); + private final TopicIdPartition topicIdPartitionWithNullTopic1 = new TopicIdPartition(topicId0, + new TopicPartition(null, partition1)); + + private final Uuid topicId1 = new Uuid(7759286116672424028L, -5081215629859775948L); + private final String topicName1 = "another_topic_name"; + private final TopicIdPartition topicIdPartition2 = new TopicIdPartition(topicId1, + partition1, topicName1); + private final TopicIdPartition topicIdPartitionWithNullTopic2 = new TopicIdPartition(topicId1, + new TopicPartition(null, partition1)); + + @Test + public void testEquals() { + assertEquals(topicIdPartition0, topicIdPartition1); + assertEquals(topicIdPartition1, topicIdPartition0); + assertEquals(topicIdPartitionWithNullTopic0, topicIdPartitionWithNullTopic1); + + assertNotEquals(topicIdPartition0, topicIdPartition2); + assertNotEquals(topicIdPartition2, topicIdPartition0); + assertNotEquals(topicIdPartition0, topicIdPartitionWithNullTopic0); + assertNotEquals(topicIdPartitionWithNullTopic0, topicIdPartitionWithNullTopic2); + } + + @Test + public void testHashCode() { + assertEquals(Objects.hash(topicIdPartition0.topicId(), topicIdPartition0.topicPartition()), + topicIdPartition0.hashCode()); + assertEquals(topicIdPartition0.hashCode(), topicIdPartition1.hashCode()); + + assertEquals(Objects.hash(topicIdPartitionWithNullTopic0.topicId(), + new TopicPartition(null, partition1)), topicIdPartitionWithNullTopic0.hashCode()); + assertEquals(topicIdPartitionWithNullTopic0.hashCode(), topicIdPartitionWithNullTopic1.hashCode()); + + assertNotEquals(topicIdPartition0.hashCode(), topicIdPartition2.hashCode()); + assertNotEquals(topicIdPartition0.hashCode(), topicIdPartitionWithNullTopic0.hashCode()); + assertNotEquals(topicIdPartitionWithNullTopic0.hashCode(), topicIdPartitionWithNullTopic2.hashCode()); + } + + @Test + public void testToString() { + assertEquals("vDiRhkpVQgmtSLnsAZx7lA:a_topic_name-1", topicIdPartition0.toString()); + assertEquals("vDiRhkpVQgmtSLnsAZx7lA:null-1", topicIdPartitionWithNullTopic0.toString()); + } + +} diff --git 
a/clients/src/test/java/org/apache/kafka/common/TopicPartitionTest.java b/clients/src/test/java/org/apache/kafka/common/TopicPartitionTest.java
new file mode 100644
index 0000000..fae7ce7
--- /dev/null
+++ b/clients/src/test/java/org/apache/kafka/common/TopicPartitionTest.java
@@ -0,0 +1,63 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.common;
+
+import org.apache.kafka.common.utils.Serializer;
+import org.junit.jupiter.api.Test;
+
+import java.io.IOException;
+
+import static org.junit.jupiter.api.Assertions.assertTrue;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+
+/**
+ * This test ensures TopicPartition class is serializable and is serialization compatible.
+ * Note: this ensures that the current code can deserialize data serialized with older versions of the code, but not the reverse.
+ * That is, older code won't necessarily be able to deserialize data serialized with newer code.
+ */
+public class TopicPartitionTest {
+    private String topicName = "mytopic";
+    private String fileName = "serializedData/topicPartitionSerializedfile";
+    private int partNum = 5;
+
+    private void checkValues(TopicPartition deSerTP) {
+        //assert deserialized values are same as original
+        assertEquals(partNum, deSerTP.partition(), "partition number should be " + partNum + " but got " + deSerTP.partition());
+        assertEquals(topicName, deSerTP.topic(), "topic should be " + topicName + " but got " + deSerTP.topic());
+    }
+
+    @Test
+    public void testSerializationRoundtrip() throws IOException, ClassNotFoundException {
+        //assert TopicPartition is serializable and deserialization renders the clone of original properly
+        TopicPartition origTp = new TopicPartition(topicName, partNum);
+        byte[] byteArray = Serializer.serialize(origTp);
+
+        //deserialize the byteArray and check if the values are same as original
+        Object deserializedObject = Serializer.deserialize(byteArray);
+        assertTrue(deserializedObject instanceof TopicPartition);
+        checkValues((TopicPartition) deserializedObject);
+    }
+
+    @Test
+    public void testTopiPartitionSerializationCompatibility() throws IOException, ClassNotFoundException {
+        // assert serialized TopicPartition object in file (serializedData/topicPartitionSerializedfile) is
+        // deserializable into TopicPartition and is compatible
+        Object deserializedObject = Serializer.deserialize(fileName);
+        assertTrue(deserializedObject instanceof TopicPartition);
+        checkValues((TopicPartition) deserializedObject);
+    }
+}
diff --git a/clients/src/test/java/org/apache/kafka/common/UuidTest.java b/clients/src/test/java/org/apache/kafka/common/UuidTest.java
new file mode 100644
index 0000000..232b992
--- /dev/null
+++ b/clients/src/test/java/org/apache/kafka/common/UuidTest.java
@@ -0,0 +1,108 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.common;
+
+import org.junit.jupiter.api.Test;
+
+import java.util.Base64;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertNotEquals;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+
+public class UuidTest {
+
+    @Test
+    public void testSignificantBits() {
+        Uuid id = new Uuid(34L, 98L);
+
+        assertEquals(id.getMostSignificantBits(), 34L);
+        assertEquals(id.getLeastSignificantBits(), 98L);
+    }
+
+    @Test
+    public void testUuidEquality() {
+        Uuid id1 = new Uuid(12L, 13L);
+        Uuid id2 = new Uuid(12L, 13L);
+        Uuid id3 = new Uuid(24L, 38L);
+
+        assertEquals(Uuid.ZERO_UUID, Uuid.ZERO_UUID);
+        assertEquals(id1, id2);
+        assertNotEquals(id1, id3);
+
+        assertEquals(Uuid.ZERO_UUID.hashCode(), Uuid.ZERO_UUID.hashCode());
+        assertEquals(id1.hashCode(), id2.hashCode());
+        assertNotEquals(id1.hashCode(), id3.hashCode());
+    }
+
+    @Test
+    public void testHashCode() {
+        Uuid id1 = new Uuid(16L, 7L);
+        Uuid id2 = new Uuid(1043L, 20075L);
+        Uuid id3 = new Uuid(104312423523523L, 200732425676585L);
+
+        assertEquals(23, id1.hashCode());
+        assertEquals(19064, id2.hashCode());
+        assertEquals(-2011255899, id3.hashCode());
+    }
+
+    @Test
+    public void testStringConversion() {
+        Uuid id = Uuid.randomUuid();
+        String idString = id.toString();
+
+        assertEquals(Uuid.fromString(idString), id);
+
+        String zeroIdString = Uuid.ZERO_UUID.toString();
+
+        assertEquals(Uuid.fromString(zeroIdString), Uuid.ZERO_UUID);
+    }
+
+    @Test
+    public void testRandomUuid() {
+        Uuid randomID = Uuid.randomUuid();
+
+        assertNotEquals(randomID, Uuid.ZERO_UUID);
+        assertNotEquals(randomID, Uuid.METADATA_TOPIC_ID);
+    }
+
+    @Test
+    public void testCompareUuids() {
+        Uuid id00 = new Uuid(0L, 0L);
+        Uuid id01 = new Uuid(0L, 1L);
+        Uuid id10 = new Uuid(1L, 0L);
+        assertEquals(0, id00.compareTo(id00));
+        assertEquals(0, id01.compareTo(id01));
+        assertEquals(0, id10.compareTo(id10));
+        assertEquals(-1, id00.compareTo(id01));
+        assertEquals(-1, id00.compareTo(id10));
+        assertEquals(1, id01.compareTo(id00));
+        assertEquals(1, id10.compareTo(id00));
+        assertEquals(-1, id01.compareTo(id10));
+        assertEquals(1, id10.compareTo(id01));
+    }
+
+    @Test
+    public void testFromStringWithInvalidInput() {
+        String oversizeString = Base64.getUrlEncoder().withoutPadding().encodeToString(new byte[32]);
+        assertThrows(IllegalArgumentException.class, () -> Uuid.fromString(oversizeString));
+
+        String undersizeString = Base64.getUrlEncoder().withoutPadding().encodeToString(new byte[4]);
+        assertThrows(IllegalArgumentException.class, () -> Uuid.fromString(undersizeString));
+    }
+
+}
diff --git a/clients/src/test/java/org/apache/kafka/common/acl/AclBindingTest.java
b/clients/src/test/java/org/apache/kafka/common/acl/AclBindingTest.java new file mode 100644 index 0000000..5330395 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/acl/AclBindingTest.java @@ -0,0 +1,145 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.acl; + +import org.apache.kafka.common.resource.PatternType; +import org.apache.kafka.common.resource.ResourcePattern; +import org.apache.kafka.common.resource.ResourcePatternFilter; +import org.apache.kafka.common.resource.ResourceType; +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class AclBindingTest { + private static final AclBinding ACL1 = new AclBinding( + new ResourcePattern(ResourceType.TOPIC, "mytopic", PatternType.LITERAL), + new AccessControlEntry("User:ANONYMOUS", "", AclOperation.ALL, AclPermissionType.ALLOW)); + + private static final AclBinding ACL2 = new AclBinding( + new ResourcePattern(ResourceType.TOPIC, "mytopic", PatternType.LITERAL), + new AccessControlEntry("User:*", "", AclOperation.READ, AclPermissionType.ALLOW)); + + private static final AclBinding ACL3 = new AclBinding( + new ResourcePattern(ResourceType.TOPIC, "mytopic2", PatternType.LITERAL), + new AccessControlEntry("User:ANONYMOUS", "127.0.0.1", AclOperation.READ, AclPermissionType.DENY)); + + private static final AclBinding UNKNOWN_ACL = new AclBinding( + new ResourcePattern(ResourceType.TOPIC, "mytopic2", PatternType.LITERAL), + new AccessControlEntry("User:ANONYMOUS", "127.0.0.1", AclOperation.UNKNOWN, AclPermissionType.DENY)); + + private static final AclBindingFilter ANY_ANONYMOUS = new AclBindingFilter( + ResourcePatternFilter.ANY, + new AccessControlEntryFilter("User:ANONYMOUS", null, AclOperation.ANY, AclPermissionType.ANY)); + + private static final AclBindingFilter ANY_DENY = new AclBindingFilter( + ResourcePatternFilter.ANY, + new AccessControlEntryFilter(null, null, AclOperation.ANY, AclPermissionType.DENY)); + + private static final AclBindingFilter ANY_MYTOPIC = new AclBindingFilter( + new ResourcePatternFilter(ResourceType.TOPIC, "mytopic", PatternType.LITERAL), + new AccessControlEntryFilter(null, null, AclOperation.ANY, AclPermissionType.ANY)); + + @Test + public void testMatching() { + assertEquals(ACL1, ACL1); + final AclBinding acl1Copy = new AclBinding( + new ResourcePattern(ResourceType.TOPIC, "mytopic", PatternType.LITERAL), + new AccessControlEntry("User:ANONYMOUS", "", 
AclOperation.ALL, AclPermissionType.ALLOW)); + assertEquals(ACL1, acl1Copy); + assertEquals(acl1Copy, ACL1); + assertEquals(ACL2, ACL2); + assertNotEquals(ACL1, ACL2); + assertNotEquals(ACL2, ACL1); + assertTrue(AclBindingFilter.ANY.matches(ACL1)); + assertNotEquals(AclBindingFilter.ANY, ACL1); + assertTrue(AclBindingFilter.ANY.matches(ACL2)); + assertNotEquals(AclBindingFilter.ANY, ACL2); + assertTrue(AclBindingFilter.ANY.matches(ACL3)); + assertNotEquals(AclBindingFilter.ANY, ACL3); + assertEquals(AclBindingFilter.ANY, AclBindingFilter.ANY); + assertTrue(ANY_ANONYMOUS.matches(ACL1)); + assertNotEquals(ANY_ANONYMOUS, ACL1); + assertFalse(ANY_ANONYMOUS.matches(ACL2)); + assertNotEquals(ANY_ANONYMOUS, ACL2); + assertTrue(ANY_ANONYMOUS.matches(ACL3)); + assertNotEquals(ANY_ANONYMOUS, ACL3); + assertFalse(ANY_DENY.matches(ACL1)); + assertFalse(ANY_DENY.matches(ACL2)); + assertTrue(ANY_DENY.matches(ACL3)); + assertTrue(ANY_MYTOPIC.matches(ACL1)); + assertTrue(ANY_MYTOPIC.matches(ACL2)); + assertFalse(ANY_MYTOPIC.matches(ACL3)); + assertTrue(ANY_ANONYMOUS.matches(UNKNOWN_ACL)); + assertTrue(ANY_DENY.matches(UNKNOWN_ACL)); + assertEquals(UNKNOWN_ACL, UNKNOWN_ACL); + assertFalse(ANY_MYTOPIC.matches(UNKNOWN_ACL)); + } + + @Test + public void testUnknowns() { + assertFalse(ACL1.isUnknown()); + assertFalse(ACL2.isUnknown()); + assertFalse(ACL3.isUnknown()); + assertFalse(ANY_ANONYMOUS.isUnknown()); + assertFalse(ANY_DENY.isUnknown()); + assertFalse(ANY_MYTOPIC.isUnknown()); + assertTrue(UNKNOWN_ACL.isUnknown()); + } + + @Test + public void testMatchesAtMostOne() { + assertNull(ACL1.toFilter().findIndefiniteField()); + assertNull(ACL2.toFilter().findIndefiniteField()); + assertNull(ACL3.toFilter().findIndefiniteField()); + assertFalse(ANY_ANONYMOUS.matchesAtMostOne()); + assertFalse(ANY_DENY.matchesAtMostOne()); + assertFalse(ANY_MYTOPIC.matchesAtMostOne()); + } + + @Test + public void shouldNotThrowOnUnknownPatternType() { + new AclBinding(new ResourcePattern(ResourceType.TOPIC, "foo", PatternType.UNKNOWN), ACL1.entry()); + } + + @Test + public void shouldNotThrowOnUnknownResourceType() { + new AclBinding(new ResourcePattern(ResourceType.UNKNOWN, "foo", PatternType.LITERAL), ACL1.entry()); + } + + @Test + public void shouldThrowOnMatchPatternType() { + assertThrows(IllegalArgumentException.class, + () -> new AclBinding(new ResourcePattern(ResourceType.TOPIC, "foo", PatternType.MATCH), ACL1.entry())); + } + + @Test + public void shouldThrowOnAnyPatternType() { + assertThrows(IllegalArgumentException.class, + () -> new AclBinding(new ResourcePattern(ResourceType.TOPIC, "foo", PatternType.ANY), ACL1.entry())); + } + + @Test + public void shouldThrowOnAnyResourceType() { + assertThrows(IllegalArgumentException.class, + () -> new AclBinding(new ResourcePattern(ResourceType.ANY, "foo", PatternType.LITERAL), ACL1.entry())); + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/acl/AclOperationTest.java b/clients/src/test/java/org/apache/kafka/common/acl/AclOperationTest.java new file mode 100644 index 0000000..c807e2b --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/acl/AclOperationTest.java @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.acl; + +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class AclOperationTest { + private static class AclOperationTestInfo { + private final AclOperation operation; + private final int code; + private final String name; + private final boolean unknown; + + AclOperationTestInfo(AclOperation operation, int code, String name, boolean unknown) { + this.operation = operation; + this.code = code; + this.name = name; + this.unknown = unknown; + } + } + + private static final AclOperationTestInfo[] INFOS = { + new AclOperationTestInfo(AclOperation.UNKNOWN, 0, "unknown", true), + new AclOperationTestInfo(AclOperation.ANY, 1, "any", false), + new AclOperationTestInfo(AclOperation.ALL, 2, "all", false), + new AclOperationTestInfo(AclOperation.READ, 3, "read", false), + new AclOperationTestInfo(AclOperation.WRITE, 4, "write", false), + new AclOperationTestInfo(AclOperation.CREATE, 5, "create", false), + new AclOperationTestInfo(AclOperation.DELETE, 6, "delete", false), + new AclOperationTestInfo(AclOperation.ALTER, 7, "alter", false), + new AclOperationTestInfo(AclOperation.DESCRIBE, 8, "describe", false), + new AclOperationTestInfo(AclOperation.CLUSTER_ACTION, 9, "cluster_action", false), + new AclOperationTestInfo(AclOperation.DESCRIBE_CONFIGS, 10, "describe_configs", false), + new AclOperationTestInfo(AclOperation.ALTER_CONFIGS, 11, "alter_configs", false), + new AclOperationTestInfo(AclOperation.IDEMPOTENT_WRITE, 12, "idempotent_write", false) + }; + + @Test + public void testIsUnknown() throws Exception { + for (AclOperationTestInfo info : INFOS) { + assertEquals(info.unknown, info.operation.isUnknown(), + info.operation + " was supposed to have unknown == " + info.unknown); + } + } + + @Test + public void testCode() throws Exception { + assertEquals(AclOperation.values().length, INFOS.length); + for (AclOperationTestInfo info : INFOS) { + assertEquals(info.code, info.operation.code(), info.operation + " was supposed to have code == " + info.code); + assertEquals(info.operation, AclOperation.fromCode((byte) info.code), + "AclOperation.fromCode(" + info.code + ") was supposed to be " + info.operation); + } + assertEquals(AclOperation.UNKNOWN, AclOperation.fromCode((byte) 120)); + } + + @Test + public void testName() throws Exception { + for (AclOperationTestInfo info : INFOS) { + assertEquals(info.operation, AclOperation.fromString(info.name), + "AclOperation.fromString(" + info.name + ") was supposed to be " + info.operation); + } + assertEquals(AclOperation.UNKNOWN, AclOperation.fromString("something")); + } + + @Test + public void testExhaustive() { + assertEquals(INFOS.length, AclOperation.values().length); + for (int i = 0; i < INFOS.length; i++) { + assertEquals(INFOS[i].operation, AclOperation.values()[i]); + } + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/acl/AclPermissionTypeTest.java 
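[Illustrative aside] The fromCode/fromString round-trips asserted in AclOperationTest above, including the fallback for unknown inputs such as code 120 or the string "something", follow a common byte-code-to-enum mapping pattern. A generic sketch of that pattern (not Kafka's actual AclOperation implementation, and with only a subset of the constants):

import java.util.HashMap;
import java.util.Locale;
import java.util.Map;

// Unmapped codes and names fall back to UNKNOWN rather than throwing.
enum OpSketch {
    UNKNOWN((byte) 0), ANY((byte) 1), ALL((byte) 2), READ((byte) 3), WRITE((byte) 4);

    private static final Map<Byte, OpSketch> BY_CODE = new HashMap<>();
    static {
        for (OpSketch op : values())
            BY_CODE.put(op.code, op);
    }

    private final byte code;

    OpSketch(byte code) {
        this.code = code;
    }

    byte code() {
        return code;
    }

    static OpSketch fromCode(byte code) {
        return BY_CODE.getOrDefault(code, UNKNOWN); // e.g. fromCode((byte) 120) == UNKNOWN
    }

    static OpSketch fromString(String name) {
        try {
            return valueOf(name.toUpperCase(Locale.ROOT));
        } catch (IllegalArgumentException e) {
            return UNKNOWN; // e.g. fromString("something") == UNKNOWN
        }
    }
}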
b/clients/src/test/java/org/apache/kafka/common/acl/AclPermissionTypeTest.java new file mode 100644 index 0000000..8da6d2a --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/acl/AclPermissionTypeTest.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.common.acl; + +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class AclPermissionTypeTest { + private static class AclPermissionTypeTestInfo { + private final AclPermissionType ty; + private final int code; + private final String name; + private final boolean unknown; + + AclPermissionTypeTestInfo(AclPermissionType ty, int code, String name, boolean unknown) { + this.ty = ty; + this.code = code; + this.name = name; + this.unknown = unknown; + } + } + + private static final AclPermissionTypeTestInfo[] INFOS = { + new AclPermissionTypeTestInfo(AclPermissionType.UNKNOWN, 0, "unknown", true), + new AclPermissionTypeTestInfo(AclPermissionType.ANY, 1, "any", false), + new AclPermissionTypeTestInfo(AclPermissionType.DENY, 2, "deny", false), + new AclPermissionTypeTestInfo(AclPermissionType.ALLOW, 3, "allow", false) + }; + + @Test + public void testIsUnknown() throws Exception { + for (AclPermissionTypeTestInfo info : INFOS) { + assertEquals(info.unknown, info.ty.isUnknown(), info.ty + " was supposed to have unknown == " + info.unknown); + } + } + + @Test + public void testCode() throws Exception { + assertEquals(AclPermissionType.values().length, INFOS.length); + for (AclPermissionTypeTestInfo info : INFOS) { + assertEquals(info.code, info.ty.code(), info.ty + " was supposed to have code == " + info.code); + assertEquals(info.ty, AclPermissionType.fromCode((byte) info.code), + "AclPermissionType.fromCode(" + info.code + ") was supposed to be " + info.ty); + } + assertEquals(AclPermissionType.UNKNOWN, AclPermissionType.fromCode((byte) 120)); + } + + @Test + public void testName() throws Exception { + for (AclPermissionTypeTestInfo info : INFOS) { + assertEquals(info.ty, AclPermissionType.fromString(info.name), + "AclPermissionType.fromString(" + info.name + ") was supposed to be " + info.ty); + } + assertEquals(AclPermissionType.UNKNOWN, AclPermissionType.fromString("something")); + } + + @Test + public void testExhaustive() throws Exception { + assertEquals(INFOS.length, AclPermissionType.values().length); + for (int i = 0; i < INFOS.length; i++) { + assertEquals(INFOS[i].ty, AclPermissionType.values()[i]); + } + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/acl/ResourcePatternFilterTest.java b/clients/src/test/java/org/apache/kafka/common/acl/ResourcePatternFilterTest.java new file mode 100644 index 0000000..98b004e --- /dev/null +++ 
b/clients/src/test/java/org/apache/kafka/common/acl/ResourcePatternFilterTest.java @@ -0,0 +1,176 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.common.acl; + +import org.apache.kafka.common.resource.PatternType; +import org.apache.kafka.common.resource.ResourcePattern; +import org.apache.kafka.common.resource.ResourcePatternFilter; +import org.junit.jupiter.api.Test; + +import static org.apache.kafka.common.resource.PatternType.LITERAL; +import static org.apache.kafka.common.resource.PatternType.PREFIXED; +import static org.apache.kafka.common.resource.ResourceType.ANY; +import static org.apache.kafka.common.resource.ResourceType.GROUP; +import static org.apache.kafka.common.resource.ResourceType.TOPIC; +import static org.apache.kafka.common.resource.ResourceType.UNKNOWN; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class ResourcePatternFilterTest { + @Test + public void shouldBeUnknownIfResourceTypeUnknown() { + assertTrue(new ResourcePatternFilter(UNKNOWN, null, PatternType.LITERAL).isUnknown()); + } + + @Test + public void shouldBeUnknownIfPatternTypeUnknown() { + assertTrue(new ResourcePatternFilter(GROUP, null, PatternType.UNKNOWN).isUnknown()); + } + + @Test + public void shouldNotMatchIfDifferentResourceType() { + assertFalse(new ResourcePatternFilter(TOPIC, "Name", LITERAL) + .matches(new ResourcePattern(GROUP, "Name", LITERAL))); + } + + @Test + public void shouldNotMatchIfDifferentName() { + assertFalse(new ResourcePatternFilter(TOPIC, "Different", PREFIXED) + .matches(new ResourcePattern(TOPIC, "Name", PREFIXED))); + } + + @Test + public void shouldNotMatchIfDifferentNameCase() { + assertFalse(new ResourcePatternFilter(TOPIC, "NAME", LITERAL) + .matches(new ResourcePattern(TOPIC, "Name", LITERAL))); + } + + @Test + public void shouldNotMatchIfDifferentPatternType() { + assertFalse(new ResourcePatternFilter(TOPIC, "Name", LITERAL) + .matches(new ResourcePattern(TOPIC, "Name", PREFIXED))); + } + + @Test + public void shouldMatchWhereResourceTypeIsAny() { + assertTrue(new ResourcePatternFilter(ANY, "Name", PREFIXED) + .matches(new ResourcePattern(TOPIC, "Name", PREFIXED))); + } + + @Test + public void shouldMatchWhereResourceNameIsAny() { + assertTrue(new ResourcePatternFilter(TOPIC, null, PREFIXED) + .matches(new ResourcePattern(TOPIC, "Name", PREFIXED))); + } + + @Test + public void shouldMatchWherePatternTypeIsAny() { + assertTrue(new ResourcePatternFilter(TOPIC, null, PatternType.ANY) + .matches(new ResourcePattern(TOPIC, "Name", PREFIXED))); + } + + @Test + public void shouldMatchWherePatternTypeIsMatch() { + assertTrue(new ResourcePatternFilter(TOPIC, null, PatternType.MATCH) + .matches(new 
ResourcePattern(TOPIC, "Name", PREFIXED))); + } + + @Test + public void shouldMatchLiteralIfExactMatch() { + assertTrue(new ResourcePatternFilter(TOPIC, "Name", LITERAL) + .matches(new ResourcePattern(TOPIC, "Name", LITERAL))); + } + + @Test + public void shouldMatchLiteralIfNameMatchesAndFilterIsOnPatternTypeAny() { + assertTrue(new ResourcePatternFilter(TOPIC, "Name", PatternType.ANY) + .matches(new ResourcePattern(TOPIC, "Name", LITERAL))); + } + + @Test + public void shouldMatchLiteralIfNameMatchesAndFilterIsOnPatternTypeMatch() { + assertTrue(new ResourcePatternFilter(TOPIC, "Name", PatternType.MATCH) + .matches(new ResourcePattern(TOPIC, "Name", LITERAL))); + } + + @Test + public void shouldNotMatchLiteralIfNamePrefixed() { + assertFalse(new ResourcePatternFilter(TOPIC, "Name-something", PatternType.MATCH) + .matches(new ResourcePattern(TOPIC, "Name", LITERAL))); + } + + @Test + public void shouldMatchLiteralWildcardIfExactMatch() { + assertTrue(new ResourcePatternFilter(TOPIC, "*", LITERAL) + .matches(new ResourcePattern(TOPIC, "*", LITERAL))); + } + + @Test + public void shouldNotMatchLiteralWildcardAgainstOtherName() { + assertFalse(new ResourcePatternFilter(TOPIC, "Name", LITERAL) + .matches(new ResourcePattern(TOPIC, "*", LITERAL))); + } + + @Test + public void shouldNotMatchLiteralWildcardTheWayAround() { + assertFalse(new ResourcePatternFilter(TOPIC, "*", LITERAL) + .matches(new ResourcePattern(TOPIC, "Name", LITERAL))); + } + + @Test + public void shouldNotMatchLiteralWildcardIfFilterHasPatternTypeOfAny() { + assertFalse(new ResourcePatternFilter(TOPIC, "Name", PatternType.ANY) + .matches(new ResourcePattern(TOPIC, "*", LITERAL))); + } + + @Test + public void shouldMatchLiteralWildcardIfFilterHasPatternTypeOfMatch() { + assertTrue(new ResourcePatternFilter(TOPIC, "Name", PatternType.MATCH) + .matches(new ResourcePattern(TOPIC, "*", LITERAL))); + } + + @Test + public void shouldMatchPrefixedIfExactMatch() { + assertTrue(new ResourcePatternFilter(TOPIC, "Name", PREFIXED) + .matches(new ResourcePattern(TOPIC, "Name", PREFIXED))); + } + + @Test + public void shouldNotMatchIfBothPrefixedAndFilterIsPrefixOfResource() { + assertFalse(new ResourcePatternFilter(TOPIC, "Name", PREFIXED) + .matches(new ResourcePattern(TOPIC, "Name-something", PREFIXED))); + } + + @Test + public void shouldNotMatchIfBothPrefixedAndResourceIsPrefixOfFilter() { + assertFalse(new ResourcePatternFilter(TOPIC, "Name-something", PREFIXED) + .matches(new ResourcePattern(TOPIC, "Name", PREFIXED))); + } + + @Test + public void shouldNotMatchPrefixedIfNamePrefixedAnyFilterTypeIsAny() { + assertFalse(new ResourcePatternFilter(TOPIC, "Name-something", PatternType.ANY) + .matches(new ResourcePattern(TOPIC, "Name", PREFIXED))); + } + + @Test + public void shouldMatchPrefixedIfNamePrefixedAnyFilterTypeIsMatch() { + assertTrue(new ResourcePatternFilter(TOPIC, "Name-something", PatternType.MATCH) + .matches(new ResourcePattern(TOPIC, "Name", PREFIXED))); + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/acl/ResourcePatternTest.java b/clients/src/test/java/org/apache/kafka/common/acl/ResourcePatternTest.java new file mode 100644 index 0000000..9ffa7b7 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/acl/ResourcePatternTest.java @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.common.acl; + +import org.apache.kafka.common.resource.PatternType; +import org.apache.kafka.common.resource.ResourcePattern; +import org.apache.kafka.common.resource.ResourceType; +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertThrows; + +public class ResourcePatternTest { + + @Test + public void shouldThrowIfResourceTypeIsAny() { + assertThrows(IllegalArgumentException.class, + () -> new ResourcePattern(ResourceType.ANY, "name", PatternType.LITERAL)); + } + + @Test + public void shouldThrowIfPatternTypeIsMatch() { + assertThrows(IllegalArgumentException.class, () -> new ResourcePattern(ResourceType.TOPIC, "name", PatternType.MATCH)); + } + + @Test + public void shouldThrowIfPatternTypeIsAny() { + assertThrows(IllegalArgumentException.class, () -> new ResourcePattern(ResourceType.TOPIC, "name", PatternType.ANY)); + } + + @Test + public void shouldThrowIfResourceNameIsNull() { + assertThrows(NullPointerException.class, () -> new ResourcePattern(ResourceType.TOPIC, null, PatternType.ANY)); + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/cache/LRUCacheTest.java b/clients/src/test/java/org/apache/kafka/common/cache/LRUCacheTest.java new file mode 100644 index 0000000..d7cd2ec --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/cache/LRUCacheTest.java @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.cache; + +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNull; + +public class LRUCacheTest { + + @Test + public void testPutGet() { + Cache<String, String> cache = new LRUCache<>(4); + + cache.put("a", "b"); + cache.put("c", "d"); + cache.put("e", "f"); + cache.put("g", "h"); + + assertEquals(4, cache.size()); + + assertEquals("b", cache.get("a")); + assertEquals("d", cache.get("c")); + assertEquals("f", cache.get("e")); + assertEquals("h", cache.get("g")); + } + + @Test + public void testRemove() { + Cache<String, String> cache = new LRUCache<>(4); + + cache.put("a", "b"); + cache.put("c", "d"); + cache.put("e", "f"); + assertEquals(3, cache.size()); + + assertEquals(true, cache.remove("a")); + assertEquals(2, cache.size()); + assertNull(cache.get("a")); + assertEquals("d", cache.get("c")); + assertEquals("f", cache.get("e")); + + assertEquals(false, cache.remove("key-does-not-exist")); + + assertEquals(true, cache.remove("c")); + assertEquals(1, cache.size()); + assertNull(cache.get("c")); + assertEquals("f", cache.get("e")); + + assertEquals(true, cache.remove("e")); + assertEquals(0, cache.size()); + assertNull(cache.get("e")); + } + + @Test + public void testEviction() { + Cache<String, String> cache = new LRUCache<>(2); + + cache.put("a", "b"); + cache.put("c", "d"); + assertEquals(2, cache.size()); + + cache.put("e", "f"); + assertEquals(2, cache.size()); + assertNull(cache.get("a")); + assertEquals("d", cache.get("c")); + assertEquals("f", cache.get("e")); + + // Validate correct access order eviction + cache.get("c"); + cache.put("g", "h"); + assertEquals(2, cache.size()); + assertNull(cache.get("e")); + assertEquals("d", cache.get("c")); + assertEquals("h", cache.get("g")); + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/compress/KafkaLZ4Test.java b/clients/src/test/java/org/apache/kafka/common/compress/KafkaLZ4Test.java new file mode 100644 index 0000000..a03c830 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/compress/KafkaLZ4Test.java @@ -0,0 +1,375 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
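[Illustrative aside] The eviction expectations in LRUCacheTest.testEviction above ("a" evicted first, then "e" once "c" has been re-read) describe classic access-order LRU behaviour. A minimal sketch of that behaviour using an access-ordered LinkedHashMap; this illustrates the technique the test pins down, not necessarily how LRUCache itself is written:

import java.util.LinkedHashMap;
import java.util.Map;

// Hypothetical stand-in, not Kafka's LRUCache.
class SimpleLruCache<K, V> extends LinkedHashMap<K, V> {
    private final int maxSize;

    SimpleLruCache(int maxSize) {
        super(16, 0.75f, true); // accessOrder=true, so get() refreshes an entry's recency
        this.maxSize = maxSize;
    }

    @Override
    protected boolean removeEldestEntry(Map.Entry<K, V> eldest) {
        return size() > maxSize; // evict the least-recently-accessed entry on overflow
    }
}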
+ */ +package org.apache.kafka.common.compress; + +import net.jpountz.xxhash.XXHashFactory; + +import org.apache.kafka.common.utils.BufferSupplier; +import org.junit.jupiter.api.extension.ExtensionContext; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.ArgumentsProvider; +import org.junit.jupiter.params.provider.ArgumentsSource; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Random; +import java.util.stream.Stream; + +import static org.apache.kafka.common.compress.KafkaLZ4BlockOutputStream.LZ4_FRAME_INCOMPRESSIBLE_MASK; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class KafkaLZ4Test { + + private final static Random RANDOM = new Random(0); + + private static class Payload { + String name; + byte[] payload; + + Payload(String name, byte[] payload) { + this.name = name; + this.payload = payload; + } + + @Override + public String toString() { + return "Payload{" + + "size=" + payload.length + + ", name='" + name + '\'' + + '}'; + } + } + + private static class Args { + final boolean useBrokenFlagDescriptorChecksum; + final boolean ignoreFlagDescriptorChecksum; + final byte[] payload; + final boolean close; + final boolean blockChecksum; + + Args(boolean useBrokenFlagDescriptorChecksum, boolean ignoreFlagDescriptorChecksum, + boolean blockChecksum, boolean close, Payload payload) { + this.useBrokenFlagDescriptorChecksum = useBrokenFlagDescriptorChecksum; + this.ignoreFlagDescriptorChecksum = ignoreFlagDescriptorChecksum; + this.blockChecksum = blockChecksum; + this.close = close; + this.payload = payload.payload; + } + + @Override + public String toString() { + return "useBrokenFlagDescriptorChecksum=" + useBrokenFlagDescriptorChecksum + + ", ignoreFlagDescriptorChecksum=" + ignoreFlagDescriptorChecksum + + ", blockChecksum=" + blockChecksum + + ", close=" + close + + ", payload=" + Arrays.toString(payload); + } + } + + private static class Lz4ArgumentsProvider implements ArgumentsProvider { + + @Override + public Stream provideArguments(ExtensionContext context) throws Exception { + List payloads = new ArrayList<>(); + + payloads.add(new Payload("empty", new byte[0])); + payloads.add(new Payload("onebyte", new byte[]{1})); + + for (int size : Arrays.asList(1000, 1 << 16, (1 << 10) * 96)) { + byte[] random = new byte[size]; + RANDOM.nextBytes(random); + payloads.add(new Payload("random", random)); + + byte[] ones = new byte[size]; + Arrays.fill(ones, (byte) 1); + payloads.add(new Payload("ones", ones)); + } + + List arguments = new ArrayList<>(); + for (Payload payload : payloads) + for (boolean broken : Arrays.asList(false, true)) + for (boolean ignore : Arrays.asList(false, true)) + for (boolean blockChecksum : Arrays.asList(false, true)) + for (boolean close : Arrays.asList(false, true)) + arguments.add(Arguments.of(new Args(broken, ignore, blockChecksum, close, payload))); + + return arguments.stream(); + } + } + + @ParameterizedTest + @ArgumentsSource(Lz4ArgumentsProvider.class) + public void testHeaderPrematureEnd(Args args) { + ByteBuffer 
buffer = ByteBuffer.allocate(2); + IOException e = assertThrows(IOException.class, () -> makeInputStream(buffer, args.ignoreFlagDescriptorChecksum)); + assertEquals(KafkaLZ4BlockInputStream.PREMATURE_EOS, e.getMessage()); + } + + private KafkaLZ4BlockInputStream makeInputStream(ByteBuffer buffer, boolean ignoreFlagDescriptorChecksum) throws IOException { + return new KafkaLZ4BlockInputStream(buffer, BufferSupplier.create(), ignoreFlagDescriptorChecksum); + } + + @ParameterizedTest + @ArgumentsSource(Lz4ArgumentsProvider.class) + public void testNotSupported(Args args) throws Exception { + byte[] compressed = compressedBytes(args); + compressed[0] = 0x00; + ByteBuffer buffer = ByteBuffer.wrap(compressed); + IOException e = assertThrows(IOException.class, () -> makeInputStream(buffer, args.ignoreFlagDescriptorChecksum)); + assertEquals(KafkaLZ4BlockInputStream.NOT_SUPPORTED, e.getMessage()); + } + + @ParameterizedTest + @ArgumentsSource(Lz4ArgumentsProvider.class) + public void testBadFrameChecksum(Args args) throws Exception { + byte[] compressed = compressedBytes(args); + compressed[6] = (byte) 0xFF; + ByteBuffer buffer = ByteBuffer.wrap(compressed); + + if (args.ignoreFlagDescriptorChecksum) { + makeInputStream(buffer, args.ignoreFlagDescriptorChecksum); + } else { + IOException e = assertThrows(IOException.class, () -> makeInputStream(buffer, args.ignoreFlagDescriptorChecksum)); + assertEquals(KafkaLZ4BlockInputStream.DESCRIPTOR_HASH_MISMATCH, e.getMessage()); + } + } + + @ParameterizedTest + @ArgumentsSource(Lz4ArgumentsProvider.class) + public void testBadBlockSize(Args args) throws Exception { + if (!args.close || (args.useBrokenFlagDescriptorChecksum && !args.ignoreFlagDescriptorChecksum)) + return; + + byte[] compressed = compressedBytes(args); + ByteBuffer buffer = ByteBuffer.wrap(compressed).order(ByteOrder.LITTLE_ENDIAN); + + int blockSize = buffer.getInt(7); + blockSize = (blockSize & LZ4_FRAME_INCOMPRESSIBLE_MASK) | (1 << 24 & ~LZ4_FRAME_INCOMPRESSIBLE_MASK); + buffer.putInt(7, blockSize); + + IOException e = assertThrows(IOException.class, () -> testDecompression(buffer, args)); + assertTrue(e.getMessage().contains("exceeded max")); + } + + + + @ParameterizedTest + @ArgumentsSource(Lz4ArgumentsProvider.class) + public void testCompression(Args args) throws Exception { + byte[] compressed = compressedBytes(args); + + // Check magic bytes stored as little-endian + int offset = 0; + assertEquals(0x04, compressed[offset++]); + assertEquals(0x22, compressed[offset++]); + assertEquals(0x4D, compressed[offset++]); + assertEquals(0x18, compressed[offset++]); + + // Check flg descriptor + byte flg = compressed[offset++]; + + // 2-bit version must be 01 + int version = (flg >>> 6) & 3; + assertEquals(1, version); + + // Reserved bits should always be 0 + int reserved = flg & 3; + assertEquals(0, reserved); + + // Check block descriptor + byte bd = compressed[offset++]; + + // Block max-size + int blockMaxSize = (bd >>> 4) & 7; + // Only supported values are 4 (64KB), 5 (256KB), 6 (1MB), 7 (4MB) + assertTrue(blockMaxSize >= 4); + assertTrue(blockMaxSize <= 7); + + // Multiple reserved bit ranges in block descriptor + reserved = bd & 15; + assertEquals(0, reserved); + reserved = (bd >>> 7) & 1; + assertEquals(0, reserved); + + // If flg descriptor sets content size flag + // there are 8 additional bytes before checksum + boolean contentSize = ((flg >>> 3) & 1) != 0; + if (contentSize) + offset += 8; + + // Checksum applies to frame descriptor: flg, bd, and optional contentsize + // so 
initial offset should be 4 (for magic bytes) + int off = 4; + int len = offset - 4; + + // Initial implementation of checksum incorrectly applied to full header + // including magic bytes + if (args.useBrokenFlagDescriptorChecksum) { + off = 0; + len = offset; + } + + int hash = XXHashFactory.fastestInstance().hash32().hash(compressed, off, len, 0); + + byte hc = compressed[offset++]; + assertEquals((byte) ((hash >> 8) & 0xFF), hc); + + // Check EndMark, data block with size `0` expressed as a 32-bits value + if (args.close) { + offset = compressed.length - 4; + assertEquals(0, compressed[offset++]); + assertEquals(0, compressed[offset++]); + assertEquals(0, compressed[offset++]); + assertEquals(0, compressed[offset++]); + } + } + + @ParameterizedTest + @ArgumentsSource(Lz4ArgumentsProvider.class) + public void testArrayBackedBuffer(Args args) throws IOException { + byte[] compressed = compressedBytes(args); + testDecompression(ByteBuffer.wrap(compressed), args); + } + + @ParameterizedTest + @ArgumentsSource(Lz4ArgumentsProvider.class) + public void testArrayBackedBufferSlice(Args args) throws IOException { + byte[] compressed = compressedBytes(args); + + int sliceOffset = 12; + + ByteBuffer buffer = ByteBuffer.allocate(compressed.length + sliceOffset + 123); + buffer.position(sliceOffset); + buffer.put(compressed).flip(); + buffer.position(sliceOffset); + + ByteBuffer slice = buffer.slice(); + testDecompression(slice, args); + + int offset = 42; + buffer = ByteBuffer.allocate(compressed.length + sliceOffset + offset); + buffer.position(sliceOffset + offset); + buffer.put(compressed).flip(); + buffer.position(sliceOffset); + + slice = buffer.slice(); + slice.position(offset); + testDecompression(slice, args); + } + + @ParameterizedTest + @ArgumentsSource(Lz4ArgumentsProvider.class) + public void testDirectBuffer(Args args) throws IOException { + byte[] compressed = compressedBytes(args); + ByteBuffer buffer; + + buffer = ByteBuffer.allocateDirect(compressed.length); + buffer.put(compressed).flip(); + testDecompression(buffer, args); + + int offset = 42; + buffer = ByteBuffer.allocateDirect(compressed.length + offset + 123); + buffer.position(offset); + buffer.put(compressed).flip(); + buffer.position(offset); + testDecompression(buffer, args); + } + + @ParameterizedTest + @ArgumentsSource(Lz4ArgumentsProvider.class) + public void testSkip(Args args) throws Exception { + if (!args.close || (args.useBrokenFlagDescriptorChecksum && !args.ignoreFlagDescriptorChecksum)) return; + + final KafkaLZ4BlockInputStream in = makeInputStream(ByteBuffer.wrap(compressedBytes(args)), + args.ignoreFlagDescriptorChecksum); + + int n = 100; + int remaining = args.payload.length; + long skipped = in.skip(n); + assertEquals(Math.min(n, remaining), skipped); + + n = 10000; + remaining -= skipped; + skipped = in.skip(n); + assertEquals(Math.min(n, remaining), skipped); + } + + private void testDecompression(ByteBuffer buffer, Args args) throws IOException { + IOException error = null; + try { + KafkaLZ4BlockInputStream decompressed = makeInputStream(buffer, args.ignoreFlagDescriptorChecksum); + + byte[] testPayload = new byte[args.payload.length]; + + byte[] tmp = new byte[1024]; + int n, pos = 0, i = 0; + while ((n = decompressed.read(tmp, i, tmp.length - i)) != -1) { + i += n; + if (i == tmp.length) { + System.arraycopy(tmp, 0, testPayload, pos, i); + pos += i; + i = 0; + } + } + System.arraycopy(tmp, 0, testPayload, pos, i); + pos += i; + + assertEquals(-1, decompressed.read(tmp, 0, tmp.length)); + 
assertEquals(args.payload.length, pos); + assertArrayEquals(args.payload, testPayload); + } catch (IOException e) { + if (!args.ignoreFlagDescriptorChecksum && args.useBrokenFlagDescriptorChecksum) { + assertEquals(KafkaLZ4BlockInputStream.DESCRIPTOR_HASH_MISMATCH, e.getMessage()); + error = e; + } else if (!args.close) { + assertEquals(KafkaLZ4BlockInputStream.PREMATURE_EOS, e.getMessage()); + error = e; + } else { + throw e; + } + } + if (!args.ignoreFlagDescriptorChecksum && args.useBrokenFlagDescriptorChecksum) assertNotNull(error); + if (!args.close) assertNotNull(error); + } + + private byte[] compressedBytes(Args args) throws IOException { + ByteArrayOutputStream output = new ByteArrayOutputStream(); + KafkaLZ4BlockOutputStream lz4 = new KafkaLZ4BlockOutputStream( + output, + KafkaLZ4BlockOutputStream.BLOCKSIZE_64KB, + args.blockChecksum, + args.useBrokenFlagDescriptorChecksum + ); + lz4.write(args.payload, 0, args.payload.length); + if (args.close) { + lz4.close(); + } else { + lz4.flush(); + } + return output.toByteArray(); + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/config/AbstractConfigTest.java b/clients/src/test/java/org/apache/kafka/common/config/AbstractConfigTest.java new file mode 100644 index 0000000..2d7247a --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/config/AbstractConfigTest.java @@ -0,0 +1,611 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
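[Illustrative aside] The header checks in KafkaLZ4Test.testCompression above verify the frame-descriptor checksum byte (HC). A small sketch of that computation, mirroring the test's own use of XXHashFactory; it assumes a well-formed frame whose content-size flag is not set (otherwise the descriptor is 8 bytes longer) and is illustrative rather than a reusable utility:

import net.jpountz.xxhash.XXHashFactory;

// Recompute the LZ4 frame header checksum byte the same way testCompression
// validates it: xxhash32 over FLG + BD with seed 0, then take the
// second-lowest byte of the hash.
final class Lz4HeaderChecksumSketch {
    static byte headerChecksum(byte[] frame) {
        int off = 4; // skip the magic bytes 0x04 0x22 0x4D 0x18
        int len = 2; // FLG and BD only; assumes no content-size field
        int hash = XXHashFactory.fastestInstance().hash32().hash(frame, off, len, 0);
        return (byte) ((hash >> 8) & 0xFF);
    }
}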
+ */ +package org.apache.kafka.common.config; + +import org.apache.kafka.common.KafkaException; +import org.apache.kafka.common.config.ConfigDef.Importance; +import org.apache.kafka.common.config.ConfigDef.Type; +import org.apache.kafka.common.config.types.Password; +import org.apache.kafka.common.metrics.FakeMetricsReporter; +import org.apache.kafka.common.metrics.JmxReporter; +import org.apache.kafka.common.metrics.MetricsReporter; +import org.apache.kafka.common.security.TestSecurityConfig; +import org.apache.kafka.common.config.provider.MockVaultConfigProvider; +import org.apache.kafka.common.config.provider.MockFileConfigProvider; +import org.junit.jupiter.api.Test; + +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Properties; +import java.util.UUID; + +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class AbstractConfigTest { + + @Test + public void testConfiguredInstances() { + testValidInputs(""); + testValidInputs("org.apache.kafka.common.metrics.FakeMetricsReporter"); + testValidInputs("org.apache.kafka.common.metrics.FakeMetricsReporter, org.apache.kafka.common.metrics.FakeMetricsReporter"); + testInvalidInputs(","); + testInvalidInputs("org.apache.kafka.clients.producer.unknown-metrics-reporter"); + testInvalidInputs("test1,test2"); + testInvalidInputs("org.apache.kafka.common.metrics.FakeMetricsReporter,"); + } + + @Test + public void testEmptyList() { + AbstractConfig conf; + ConfigDef configDef = new ConfigDef().define("a", Type.LIST, "", new ConfigDef.NonNullValidator(), Importance.HIGH, "doc"); + + conf = new AbstractConfig(configDef, Collections.emptyMap()); + assertEquals(Collections.emptyList(), conf.getList("a")); + + conf = new AbstractConfig(configDef, Collections.singletonMap("a", "")); + assertEquals(Collections.emptyList(), conf.getList("a")); + + conf = new AbstractConfig(configDef, Collections.singletonMap("a", "b,c,d")); + assertEquals(Arrays.asList("b", "c", "d"), conf.getList("a")); + } + + @Test + public void testOriginalsWithPrefix() { + Properties props = new Properties(); + props.put("foo.bar", "abc"); + props.put("setting", "def"); + TestConfig config = new TestConfig(props); + Map originalsWithPrefix = config.originalsWithPrefix("foo."); + + assertTrue(config.unused().contains("foo.bar")); + originalsWithPrefix.get("bar"); + assertFalse(config.unused().contains("foo.bar")); + + Map expected = new HashMap<>(); + expected.put("bar", "abc"); + assertEquals(expected, originalsWithPrefix); + } + + @Test + public void testValuesWithPrefixOverride() { + String prefix = "prefix."; + Properties props = new Properties(); + props.put("sasl.mechanism", "PLAIN"); + props.put("prefix.sasl.mechanism", "GSSAPI"); + props.put("prefix.sasl.kerberos.kinit.cmd", "/usr/bin/kinit2"); + props.put("prefix.ssl.truststore.location", "my location"); + props.put("sasl.kerberos.service.name", "service name"); + props.put("ssl.keymanager.algorithm", "algorithm"); + TestSecurityConfig config = new TestSecurityConfig(props); + Map valuesWithPrefixOverride = config.valuesWithPrefixOverride(prefix); + + // prefix overrides global + 
assertTrue(config.unused().contains("prefix.sasl.mechanism")); + assertTrue(config.unused().contains("sasl.mechanism")); + assertEquals("GSSAPI", valuesWithPrefixOverride.get("sasl.mechanism")); + assertFalse(config.unused().contains("sasl.mechanism")); + assertFalse(config.unused().contains("prefix.sasl.mechanism")); + + // prefix overrides default + assertTrue(config.unused().contains("prefix.sasl.kerberos.kinit.cmd")); + assertFalse(config.unused().contains("sasl.kerberos.kinit.cmd")); + assertEquals("/usr/bin/kinit2", valuesWithPrefixOverride.get("sasl.kerberos.kinit.cmd")); + assertFalse(config.unused().contains("sasl.kerberos.kinit.cmd")); + assertFalse(config.unused().contains("prefix.sasl.kerberos.kinit.cmd")); + + // prefix override with no default + assertTrue(config.unused().contains("prefix.ssl.truststore.location")); + assertFalse(config.unused().contains("ssl.truststore.location")); + assertEquals("my location", valuesWithPrefixOverride.get("ssl.truststore.location")); + assertFalse(config.unused().contains("ssl.truststore.location")); + assertFalse(config.unused().contains("prefix.ssl.truststore.location")); + + // global overrides default + assertTrue(config.unused().contains("ssl.keymanager.algorithm")); + assertEquals("algorithm", valuesWithPrefixOverride.get("ssl.keymanager.algorithm")); + assertFalse(config.unused().contains("ssl.keymanager.algorithm")); + + // global with no default + assertTrue(config.unused().contains("sasl.kerberos.service.name")); + assertEquals("service name", valuesWithPrefixOverride.get("sasl.kerberos.service.name")); + assertFalse(config.unused().contains("sasl.kerberos.service.name")); + + // unset with default + assertFalse(config.unused().contains("sasl.kerberos.min.time.before.relogin")); + assertEquals(SaslConfigs.DEFAULT_KERBEROS_MIN_TIME_BEFORE_RELOGIN, + valuesWithPrefixOverride.get("sasl.kerberos.min.time.before.relogin")); + assertFalse(config.unused().contains("sasl.kerberos.min.time.before.relogin")); + + // unset with no default + assertFalse(config.unused().contains("ssl.key.password")); + assertNull(valuesWithPrefixOverride.get("ssl.key.password")); + assertFalse(config.unused().contains("ssl.key.password")); + } + + @Test + public void testValuesWithSecondaryPrefix() { + String prefix = "listener.name.listener1."; + Password saslJaasConfig1 = new Password("test.myLoginModule1 required;"); + Password saslJaasConfig2 = new Password("test.myLoginModule2 required;"); + Password saslJaasConfig3 = new Password("test.myLoginModule3 required;"); + Properties props = new Properties(); + props.put("listener.name.listener1.test-mechanism.sasl.jaas.config", saslJaasConfig1.value()); + props.put("test-mechanism.sasl.jaas.config", saslJaasConfig2.value()); + props.put("sasl.jaas.config", saslJaasConfig3.value()); + props.put("listener.name.listener1.gssapi.sasl.kerberos.kinit.cmd", "/usr/bin/kinit2"); + props.put("listener.name.listener1.gssapi.sasl.kerberos.service.name", "testkafka"); + props.put("listener.name.listener1.gssapi.sasl.kerberos.min.time.before.relogin", "60000"); + props.put("ssl.provider", "TEST"); + TestSecurityConfig config = new TestSecurityConfig(props); + Map valuesWithPrefixOverride = config.valuesWithPrefixOverride(prefix); + + // prefix with mechanism overrides global + assertTrue(config.unused().contains("listener.name.listener1.test-mechanism.sasl.jaas.config")); + assertTrue(config.unused().contains("test-mechanism.sasl.jaas.config")); + assertEquals(saslJaasConfig1, 
valuesWithPrefixOverride.get("test-mechanism.sasl.jaas.config")); + assertEquals(saslJaasConfig3, valuesWithPrefixOverride.get("sasl.jaas.config")); + assertFalse(config.unused().contains("listener.name.listener1.test-mechanism.sasl.jaas.config")); + assertFalse(config.unused().contains("test-mechanism.sasl.jaas.config")); + assertFalse(config.unused().contains("sasl.jaas.config")); + + // prefix with mechanism overrides default + assertFalse(config.unused().contains("sasl.kerberos.kinit.cmd")); + assertTrue(config.unused().contains("listener.name.listener1.gssapi.sasl.kerberos.kinit.cmd")); + assertFalse(config.unused().contains("gssapi.sasl.kerberos.kinit.cmd")); + assertFalse(config.unused().contains("sasl.kerberos.kinit.cmd")); + assertEquals("/usr/bin/kinit2", valuesWithPrefixOverride.get("gssapi.sasl.kerberos.kinit.cmd")); + assertFalse(config.unused().contains("listener.name.listener1.sasl.kerberos.kinit.cmd")); + + // prefix override for mechanism with no default + assertFalse(config.unused().contains("sasl.kerberos.service.name")); + assertTrue(config.unused().contains("listener.name.listener1.gssapi.sasl.kerberos.service.name")); + assertFalse(config.unused().contains("gssapi.sasl.kerberos.service.name")); + assertFalse(config.unused().contains("sasl.kerberos.service.name")); + assertEquals("testkafka", valuesWithPrefixOverride.get("gssapi.sasl.kerberos.service.name")); + assertFalse(config.unused().contains("listener.name.listener1.gssapi.sasl.kerberos.service.name")); + + // unset with no default + assertTrue(config.unused().contains("ssl.provider")); + assertNull(valuesWithPrefixOverride.get("gssapi.ssl.provider")); + assertTrue(config.unused().contains("ssl.provider")); + } + + @Test + public void testValuesWithPrefixAllOrNothing() { + String prefix1 = "prefix1."; + String prefix2 = "prefix2."; + Properties props = new Properties(); + props.put("sasl.mechanism", "PLAIN"); + props.put("prefix1.sasl.mechanism", "GSSAPI"); + props.put("prefix1.sasl.kerberos.kinit.cmd", "/usr/bin/kinit2"); + props.put("prefix1.ssl.truststore.location", "my location"); + props.put("sasl.kerberos.service.name", "service name"); + props.put("ssl.keymanager.algorithm", "algorithm"); + TestSecurityConfig config = new TestSecurityConfig(props); + Map valuesWithPrefixAllOrNothing1 = config.valuesWithPrefixAllOrNothing(prefix1); + + // All prefixed values are there + assertEquals("GSSAPI", valuesWithPrefixAllOrNothing1.get("sasl.mechanism")); + assertEquals("/usr/bin/kinit2", valuesWithPrefixAllOrNothing1.get("sasl.kerberos.kinit.cmd")); + assertEquals("my location", valuesWithPrefixAllOrNothing1.get("ssl.truststore.location")); + + // Non-prefixed values are missing + assertFalse(valuesWithPrefixAllOrNothing1.containsKey("sasl.kerberos.service.name")); + assertFalse(valuesWithPrefixAllOrNothing1.containsKey("ssl.keymanager.algorithm")); + + Map valuesWithPrefixAllOrNothing2 = config.valuesWithPrefixAllOrNothing(prefix2); + assertTrue(valuesWithPrefixAllOrNothing2.containsKey("sasl.kerberos.service.name")); + assertTrue(valuesWithPrefixAllOrNothing2.containsKey("ssl.keymanager.algorithm")); + } + + @Test + public void testUnusedConfigs() { + Properties props = new Properties(); + String configValue = "org.apache.kafka.common.config.AbstractConfigTest$ConfiguredFakeMetricsReporter"; + props.put(TestConfig.METRIC_REPORTER_CLASSES_CONFIG, configValue); + props.put(ConfiguredFakeMetricsReporter.EXTRA_CONFIG, "my_value"); + TestConfig config = new TestConfig(props); + + 
assertTrue(config.unused().contains(ConfiguredFakeMetricsReporter.EXTRA_CONFIG), + ConfiguredFakeMetricsReporter.EXTRA_CONFIG + " should be marked unused before getConfiguredInstances is called"); + + config.getConfiguredInstances(TestConfig.METRIC_REPORTER_CLASSES_CONFIG, MetricsReporter.class); + assertFalse(config.unused().contains(ConfiguredFakeMetricsReporter.EXTRA_CONFIG), + ConfiguredFakeMetricsReporter.EXTRA_CONFIG + " should be marked as used"); + } + + private void testValidInputs(String configValue) { + Properties props = new Properties(); + props.put(TestConfig.METRIC_REPORTER_CLASSES_CONFIG, configValue); + TestConfig config = new TestConfig(props); + try { + config.getConfiguredInstances(TestConfig.METRIC_REPORTER_CLASSES_CONFIG, MetricsReporter.class); + } catch (ConfigException e) { + fail("No exceptions are expected here, valid props are :" + props); + } + } + + private void testInvalidInputs(String configValue) { + Properties props = new Properties(); + props.put(TestConfig.METRIC_REPORTER_CLASSES_CONFIG, configValue); + TestConfig config = new TestConfig(props); + try { + config.getConfiguredInstances(TestConfig.METRIC_REPORTER_CLASSES_CONFIG, MetricsReporter.class); + fail("Expected a config exception due to invalid props :" + props); + } catch (KafkaException e) { + // this is good + } + } + + @Test + public void testClassConfigs() { + class RestrictedClassLoader extends ClassLoader { + public RestrictedClassLoader() { + super(null); + } + @Override + protected Class findClass(String name) throws ClassNotFoundException { + if (name.equals(ClassTestConfig.DEFAULT_CLASS.getName()) || name.equals(ClassTestConfig.RESTRICTED_CLASS.getName())) + throw new ClassNotFoundException(); + else + return ClassTestConfig.class.getClassLoader().loadClass(name); + } + } + + ClassLoader restrictedClassLoader = new RestrictedClassLoader(); + ClassLoader defaultClassLoader = AbstractConfig.class.getClassLoader(); + + ClassLoader originClassLoader = Thread.currentThread().getContextClassLoader(); + try { + // Test default classloading where all classes are visible to thread context classloader + Thread.currentThread().setContextClassLoader(defaultClassLoader); + ClassTestConfig testConfig = new ClassTestConfig(); + testConfig.checkInstances(ClassTestConfig.DEFAULT_CLASS, ClassTestConfig.DEFAULT_CLASS); + + // Test default classloading where default classes are not visible to thread context classloader + // Static classloading is used for default classes, so instance creation should succeed. 
+ Thread.currentThread().setContextClassLoader(restrictedClassLoader); + testConfig = new ClassTestConfig(); + testConfig.checkInstances(ClassTestConfig.DEFAULT_CLASS, ClassTestConfig.DEFAULT_CLASS); + + // Test class overrides with names or classes where all classes are visible to thread context classloader + Thread.currentThread().setContextClassLoader(defaultClassLoader); + ClassTestConfig.testOverrides(); + + // Test class overrides with names or classes where all classes are visible to Kafka classloader, context classloader is null + Thread.currentThread().setContextClassLoader(null); + ClassTestConfig.testOverrides(); + + // Test class overrides where some classes are not visible to thread context classloader + Thread.currentThread().setContextClassLoader(restrictedClassLoader); + // Properties specified as classes should succeed + testConfig = new ClassTestConfig(ClassTestConfig.RESTRICTED_CLASS, Collections.singletonList(ClassTestConfig.RESTRICTED_CLASS)); + testConfig.checkInstances(ClassTestConfig.RESTRICTED_CLASS, ClassTestConfig.RESTRICTED_CLASS); + testConfig = new ClassTestConfig(ClassTestConfig.RESTRICTED_CLASS, Arrays.asList(ClassTestConfig.VISIBLE_CLASS, ClassTestConfig.RESTRICTED_CLASS)); + testConfig.checkInstances(ClassTestConfig.RESTRICTED_CLASS, ClassTestConfig.VISIBLE_CLASS, ClassTestConfig.RESTRICTED_CLASS); + + // Properties specified as classNames should fail to load classes + assertThrows(ConfigException.class, () -> new ClassTestConfig(ClassTestConfig.RESTRICTED_CLASS.getName(), null), + "Config created with class property that cannot be loaded"); + + ClassTestConfig config = new ClassTestConfig(null, Arrays.asList(ClassTestConfig.VISIBLE_CLASS.getName(), ClassTestConfig.RESTRICTED_CLASS.getName())); + assertThrows(KafkaException.class, () -> config.getConfiguredInstances("list.prop", MetricsReporter.class), + "Should have failed to load class"); + + ClassTestConfig config2 = new ClassTestConfig(null, ClassTestConfig.VISIBLE_CLASS.getName() + "," + ClassTestConfig.RESTRICTED_CLASS.getName()); + assertThrows(KafkaException.class, () -> config2.getConfiguredInstances("list.prop", MetricsReporter.class), + "Should have failed to load class"); + } finally { + Thread.currentThread().setContextClassLoader(originClassLoader); + } + } + + @SuppressWarnings("unchecked") + public Map convertPropertiesToMap(Map props) { + for (Map.Entry entry : props.entrySet()) { + if (!(entry.getKey() instanceof String)) + throw new ConfigException(entry.getKey().toString(), entry.getValue(), + "Key must be a string."); + } + return (Map) props; + } + + @Test + public void testOriginalWithOverrides() { + Properties props = new Properties(); + props.put("config.providers", "file"); + TestIndirectConfigResolution config = new TestIndirectConfigResolution(props); + assertEquals(config.originals().get("config.providers"), "file"); + assertEquals(config.originals(Collections.singletonMap("config.providers", "file2")).get("config.providers"), "file2"); + } + + @Test + public void testOriginalsWithConfigProvidersProps() { + Properties props = new Properties(); + + // Test Case: Valid Test Case for ConfigProviders as part of config.properties + props.put("config.providers", "file"); + props.put("config.providers.file.class", MockFileConfigProvider.class.getName()); + String id = UUID.randomUUID().toString(); + props.put("config.providers.file.param.testId", id); + props.put("prefix.ssl.truststore.location.number", 5); + props.put("sasl.kerberos.service.name", "service name"); + 
props.put("sasl.kerberos.key", "${file:/usr/kerberos:key}"); + props.put("sasl.kerberos.password", "${file:/usr/kerberos:password}"); + TestIndirectConfigResolution config = new TestIndirectConfigResolution(props); + assertEquals("testKey", config.originals().get("sasl.kerberos.key")); + assertEquals("randomPassword", config.originals().get("sasl.kerberos.password")); + assertEquals(5, config.originals().get("prefix.ssl.truststore.location.number")); + assertEquals("service name", config.originals().get("sasl.kerberos.service.name")); + MockFileConfigProvider.assertClosed(id); + } + + @Test + public void testConfigProvidersPropsAsParam() { + // Test Case: Valid Test Case for ConfigProviders as a separate variable + Properties providers = new Properties(); + providers.put("config.providers", "file"); + providers.put("config.providers.file.class", MockFileConfigProvider.class.getName()); + String id = UUID.randomUUID().toString(); + providers.put("config.providers.file.param.testId", id); + Properties props = new Properties(); + props.put("sasl.kerberos.key", "${file:/usr/kerberos:key}"); + props.put("sasl.kerberos.password", "${file:/usr/kerberos:password}"); + TestIndirectConfigResolution config = new TestIndirectConfigResolution(props, convertPropertiesToMap(providers)); + assertEquals("testKey", config.originals().get("sasl.kerberos.key")); + assertEquals("randomPassword", config.originals().get("sasl.kerberos.password")); + MockFileConfigProvider.assertClosed(id); + } + + @Test + public void testImmutableOriginalsWithConfigProvidersProps() { + // Test Case: Valid Test Case for ConfigProviders as a separate variable + Properties providers = new Properties(); + providers.put("config.providers", "file"); + providers.put("config.providers.file.class", MockFileConfigProvider.class.getName()); + String id = UUID.randomUUID().toString(); + providers.put("config.providers.file.param.testId", id); + Properties props = new Properties(); + props.put("sasl.kerberos.key", "${file:/usr/kerberos:key}"); + Map immutableMap = Collections.unmodifiableMap(props); + Map provMap = convertPropertiesToMap(providers); + TestIndirectConfigResolution config = new TestIndirectConfigResolution(immutableMap, provMap); + assertEquals("testKey", config.originals().get("sasl.kerberos.key")); + MockFileConfigProvider.assertClosed(id); + } + + @Test + public void testAutoConfigResolutionWithMultipleConfigProviders() { + // Test Case: Valid Test Case With Multiple ConfigProviders as a separate variable + Properties providers = new Properties(); + providers.put("config.providers", "file,vault"); + providers.put("config.providers.file.class", MockFileConfigProvider.class.getName()); + String id = UUID.randomUUID().toString(); + providers.put("config.providers.file.param.testId", id); + providers.put("config.providers.vault.class", MockVaultConfigProvider.class.getName()); + Properties props = new Properties(); + props.put("sasl.kerberos.key", "${file:/usr/kerberos:key}"); + props.put("sasl.kerberos.password", "${file:/usr/kerberos:password}"); + props.put("sasl.truststore.key", "${vault:/usr/truststore:truststoreKey}"); + props.put("sasl.truststore.password", "${vault:/usr/truststore:truststorePassword}"); + TestIndirectConfigResolution config = new TestIndirectConfigResolution(props, convertPropertiesToMap(providers)); + assertEquals("testKey", config.originals().get("sasl.kerberos.key")); + assertEquals("randomPassword", config.originals().get("sasl.kerberos.password")); + assertEquals("testTruststoreKey", 
config.originals().get("sasl.truststore.key")); + assertEquals("randomtruststorePassword", config.originals().get("sasl.truststore.password")); + MockFileConfigProvider.assertClosed(id); + } + + @Test + public void testAutoConfigResolutionWithInvalidConfigProviderClass() { + // Test Case: Invalid class for Config Provider + Properties props = new Properties(); + props.put("config.providers", "file"); + props.put("config.providers.file.class", + "org.apache.kafka.common.config.provider.InvalidConfigProvider"); + props.put("testKey", "${test:/foo/bar/testpath:testKey}"); + try { + new TestIndirectConfigResolution(props); + fail("Expected a config exception due to invalid props :" + props); + } catch (KafkaException e) { + // this is good + } + } + + @Test + public void testAutoConfigResolutionWithMissingConfigProvider() { + // Test Case: Config Provider for a variable missing in config file. + Properties props = new Properties(); + props.put("testKey", "${test:/foo/bar/testpath:testKey}"); + TestIndirectConfigResolution config = new TestIndirectConfigResolution(props); + assertEquals("${test:/foo/bar/testpath:testKey}", config.originals().get("testKey")); + } + + @Test + public void testAutoConfigResolutionWithMissingConfigKey() { + // Test Case: Config Provider fails to resolve the config (key not present) + Properties props = new Properties(); + props.put("config.providers", "test"); + props.put("config.providers.test.class", MockFileConfigProvider.class.getName()); + String id = UUID.randomUUID().toString(); + props.put("config.providers.test.param.testId", id); + props.put("random", "${test:/foo/bar/testpath:random}"); + TestIndirectConfigResolution config = new TestIndirectConfigResolution(props); + assertEquals("${test:/foo/bar/testpath:random}", config.originals().get("random")); + MockFileConfigProvider.assertClosed(id); + } + + @Test + public void testAutoConfigResolutionWithDuplicateConfigProvider() { + // Test Case: If ConfigProvider is provided in both originals and provider. Only the ones in provider should be used. 
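+        // In other words, the providers map passed to the constructor takes precedence; since it only
+        // registers a "test" provider, the "${file:...}" placeholder below is left unresolved.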
+ Properties providers = new Properties(); + providers.put("config.providers", "test"); + providers.put("config.providers.test.class", MockVaultConfigProvider.class.getName()); + + Properties props = new Properties(); + props.put("sasl.kerberos.key", "${file:/usr/kerberos:key}"); + props.put("config.providers", "file"); + props.put("config.providers.file.class", MockVaultConfigProvider.class.getName()); + + TestIndirectConfigResolution config = new TestIndirectConfigResolution(props, convertPropertiesToMap(providers)); + assertEquals("${file:/usr/kerberos:key}", config.originals().get("sasl.kerberos.key")); + } + + @Test + public void testConfigProviderConfigurationWithConfigParams() { + // Test Case: Valid Test Case With Multiple ConfigProviders as a separate variable + Properties providers = new Properties(); + providers.put("config.providers", "vault"); + providers.put("config.providers.vault.class", MockVaultConfigProvider.class.getName()); + providers.put("config.providers.vault.param.key", "randomKey"); + providers.put("config.providers.vault.param.location", "/usr/vault"); + Properties props = new Properties(); + props.put("sasl.truststore.key", "${vault:/usr/truststore:truststoreKey}"); + props.put("sasl.truststore.password", "${vault:/usr/truststore:truststorePassword}"); + props.put("sasl.truststore.location", "${vault:/usr/truststore:truststoreLocation}"); + TestIndirectConfigResolution config = new TestIndirectConfigResolution(props, convertPropertiesToMap(providers)); + assertEquals("/usr/vault", config.originals().get("sasl.truststore.location")); + } + + @Test + public void testDocumentationOf() { + Properties props = new Properties(); + TestIndirectConfigResolution config = new TestIndirectConfigResolution(props); + + assertEquals( + TestIndirectConfigResolution.INDIRECT_CONFIGS_DOC, + config.documentationOf(TestIndirectConfigResolution.INDIRECT_CONFIGS) + ); + } + + @Test + public void testDocumentationOfExpectNull() { + Properties props = new Properties(); + TestIndirectConfigResolution config = new TestIndirectConfigResolution(props); + + assertNull(config.documentationOf("xyz")); + } + + private static class TestIndirectConfigResolution extends AbstractConfig { + + private static final ConfigDef CONFIG; + + public static final String INDIRECT_CONFIGS = "indirect.variables"; + private static final String INDIRECT_CONFIGS_DOC = "Variables whose values can be obtained from ConfigProviders"; + + static { + CONFIG = new ConfigDef().define(INDIRECT_CONFIGS, + Type.LIST, + "", + Importance.LOW, + INDIRECT_CONFIGS_DOC); + } + + public TestIndirectConfigResolution(Map props) { + super(CONFIG, props, true); + } + + public TestIndirectConfigResolution(Map props, Map providers) { + super(CONFIG, props, providers, true); + } + } + + private static class ClassTestConfig extends AbstractConfig { + static final Class DEFAULT_CLASS = FakeMetricsReporter.class; + static final Class VISIBLE_CLASS = JmxReporter.class; + static final Class RESTRICTED_CLASS = ConfiguredFakeMetricsReporter.class; + + private static final ConfigDef CONFIG; + static { + CONFIG = new ConfigDef().define("class.prop", Type.CLASS, DEFAULT_CLASS, Importance.HIGH, "docs") + .define("list.prop", Type.LIST, Collections.singletonList(DEFAULT_CLASS), Importance.HIGH, "docs"); + } + + public ClassTestConfig() { + super(CONFIG, new Properties()); + } + + public ClassTestConfig(Object classPropOverride, Object listPropOverride) { + super(CONFIG, overrideProps(classPropOverride, listPropOverride)); + } + + void 
checkInstances(Class expectedClassPropClass, Class... expectedListPropClasses) { + assertEquals(expectedClassPropClass, getConfiguredInstance("class.prop", MetricsReporter.class).getClass()); + List list = getConfiguredInstances("list.prop", MetricsReporter.class); + for (int i = 0; i < list.size(); i++) + assertEquals(expectedListPropClasses[i], list.get(i).getClass()); + } + + static void testOverrides() { + ClassTestConfig testConfig1 = new ClassTestConfig(RESTRICTED_CLASS, Arrays.asList(VISIBLE_CLASS, RESTRICTED_CLASS)); + testConfig1.checkInstances(RESTRICTED_CLASS, VISIBLE_CLASS, RESTRICTED_CLASS); + + ClassTestConfig testConfig2 = new ClassTestConfig(RESTRICTED_CLASS.getName(), Arrays.asList(VISIBLE_CLASS.getName(), RESTRICTED_CLASS.getName())); + testConfig2.checkInstances(RESTRICTED_CLASS, VISIBLE_CLASS, RESTRICTED_CLASS); + + ClassTestConfig testConfig3 = new ClassTestConfig(RESTRICTED_CLASS.getName(), VISIBLE_CLASS.getName() + "," + RESTRICTED_CLASS.getName()); + testConfig3.checkInstances(RESTRICTED_CLASS, VISIBLE_CLASS, RESTRICTED_CLASS); + } + + private static Map overrideProps(Object classProp, Object listProp) { + Map props = new HashMap<>(); + if (classProp != null) + props.put("class.prop", classProp); + if (listProp != null) + props.put("list.prop", listProp); + return props; + } + } + + private static class TestConfig extends AbstractConfig { + + private static final ConfigDef CONFIG; + + public static final String METRIC_REPORTER_CLASSES_CONFIG = "metric.reporters"; + private static final String METRIC_REPORTER_CLASSES_DOC = "A list of classes to use as metrics reporters."; + + static { + CONFIG = new ConfigDef().define(METRIC_REPORTER_CLASSES_CONFIG, + Type.LIST, + "", + Importance.LOW, + METRIC_REPORTER_CLASSES_DOC); + } + + public TestConfig(Map props) { + super(CONFIG, props); + } + } + + public static class ConfiguredFakeMetricsReporter extends FakeMetricsReporter { + public static final String EXTRA_CONFIG = "metric.extra_config"; + @Override + public void configure(Map configs) { + // Calling get() should have the side effect of marking that config as used. + // this is required by testUnusedConfigs + configs.get(EXTRA_CONFIG); + } + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/config/ConfigDefTest.java b/clients/src/test/java/org/apache/kafka/common/config/ConfigDefTest.java new file mode 100644 index 0000000..893f68b --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/config/ConfigDefTest.java @@ -0,0 +1,725 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.config; + +import org.apache.kafka.common.config.ConfigDef.CaseInsensitiveValidString; +import org.apache.kafka.common.config.ConfigDef.Importance; +import org.apache.kafka.common.config.ConfigDef.Range; +import org.apache.kafka.common.config.ConfigDef.Type; +import org.apache.kafka.common.config.ConfigDef.ValidString; +import org.apache.kafka.common.config.ConfigDef.Validator; +import org.apache.kafka.common.config.ConfigDef.Width; +import org.apache.kafka.common.config.types.Password; +import org.junit.jupiter.api.Test; + +import java.time.Duration; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.Properties; +import java.util.Set; + +import static java.util.Arrays.asList; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; + +public class ConfigDefTest { + + @Test + public void testBasicTypes() { + ConfigDef def = new ConfigDef().define("a", Type.INT, 5, Range.between(0, 14), Importance.HIGH, "docs") + .define("b", Type.LONG, Importance.HIGH, "docs") + .define("c", Type.STRING, "hello", Importance.HIGH, "docs") + .define("d", Type.LIST, Importance.HIGH, "docs") + .define("e", Type.DOUBLE, Importance.HIGH, "docs") + .define("f", Type.CLASS, Importance.HIGH, "docs") + .define("g", Type.BOOLEAN, Importance.HIGH, "docs") + .define("h", Type.BOOLEAN, Importance.HIGH, "docs") + .define("i", Type.BOOLEAN, Importance.HIGH, "docs") + .define("j", Type.PASSWORD, Importance.HIGH, "docs"); + + Properties props = new Properties(); + props.put("a", "1 "); + props.put("b", 2); + props.put("d", " a , b, c"); + props.put("e", 42.5d); + props.put("f", String.class.getName()); + props.put("g", "true"); + props.put("h", "FalSE"); + props.put("i", "TRUE"); + props.put("j", "password"); + + Map vals = def.parse(props); + assertEquals(1, vals.get("a")); + assertEquals(2L, vals.get("b")); + assertEquals("hello", vals.get("c")); + assertEquals(asList("a", "b", "c"), vals.get("d")); + assertEquals(42.5d, vals.get("e")); + assertEquals(String.class, vals.get("f")); + assertEquals(true, vals.get("g")); + assertEquals(false, vals.get("h")); + assertEquals(true, vals.get("i")); + assertEquals(new Password("password"), vals.get("j")); + assertEquals(Password.HIDDEN, vals.get("j").toString()); + } + + @Test + public void testInvalidDefault() { + assertThrows(ConfigException.class, () -> new ConfigDef().define("a", Type.INT, "hello", Importance.HIGH, "docs")); + } + + @Test + public void testNullDefault() { + ConfigDef def = new ConfigDef().define("a", Type.INT, null, null, null, "docs"); + Map vals = def.parse(new Properties()); + + assertNull(vals.get("a")); + } + + @Test + public void testMissingRequired() { + assertThrows(ConfigException.class, () -> new ConfigDef().define("a", Type.INT, Importance.HIGH, "docs").parse(new HashMap())); + } + + @Test + public void testParsingEmptyDefaultValueForStringFieldShouldSucceed() { + new ConfigDef().define("a", Type.STRING, "", ConfigDef.Importance.HIGH, "docs") + .parse(new HashMap()); + } + + @Test + public void testDefinedTwice() { + assertThrows(ConfigException.class, () -> new 
ConfigDef().define("a", Type.STRING, + Importance.HIGH, "docs").define("a", Type.INT, Importance.HIGH, "docs")); + } + + @Test + public void testBadInputs() { + testBadInputs(Type.INT, "hello", "42.5", 42.5, Long.MAX_VALUE, Long.toString(Long.MAX_VALUE), new Object()); + testBadInputs(Type.LONG, "hello", "42.5", Long.toString(Long.MAX_VALUE) + "00", new Object()); + testBadInputs(Type.DOUBLE, "hello", new Object()); + testBadInputs(Type.STRING, new Object()); + testBadInputs(Type.LIST, 53, new Object()); + testBadInputs(Type.BOOLEAN, "hello", "truee", "fals"); + testBadInputs(Type.CLASS, "ClassDoesNotExist"); + } + + private void testBadInputs(Type type, Object... values) { + for (Object value : values) { + Map m = new HashMap(); + m.put("name", value); + ConfigDef def = new ConfigDef().define("name", type, Importance.HIGH, "docs"); + try { + def.parse(m); + fail("Expected a config exception on bad input for value " + value); + } catch (ConfigException e) { + // this is good + } + } + } + + @Test + public void testInvalidDefaultRange() { + assertThrows(ConfigException.class, () -> new ConfigDef().define("name", Type.INT, -1, + Range.between(0, 10), Importance.HIGH, "docs")); + } + + @Test + public void testInvalidDefaultString() { + assertThrows(ConfigException.class, () -> new ConfigDef().define("name", Type.STRING, "bad", + ValidString.in("valid", "values"), Importance.HIGH, "docs")); + } + + @Test + public void testNestedClass() { + // getName(), not getSimpleName() or getCanonicalName(), is the version that should be able to locate the class + Map props = Collections.singletonMap("name", NestedClass.class.getName()); + new ConfigDef().define("name", Type.CLASS, Importance.HIGH, "docs").parse(props); + } + + @Test + public void testValidators() { + testValidators(Type.INT, Range.between(0, 10), 5, new Object[]{1, 5, 9}, new Object[]{-1, 11, null}); + testValidators(Type.STRING, ValidString.in("good", "values", "default"), "default", + new Object[]{"good", "values", "default"}, new Object[]{"bad", "inputs", "DEFAULT", null}); + testValidators(Type.STRING, CaseInsensitiveValidString.in("good", "values", "default"), "default", + new Object[]{"gOOd", "VALUES", "default"}, new Object[]{"Bad", "iNPUts", null}); + testValidators(Type.LIST, ConfigDef.ValidList.in("1", "2", "3"), "1", new Object[]{"1", "2", "3"}, new Object[]{"4", "5", "6"}); + testValidators(Type.STRING, new ConfigDef.NonNullValidator(), "a", new Object[]{"abb"}, new Object[] {null}); + testValidators(Type.STRING, ConfigDef.CompositeValidator.of(new ConfigDef.NonNullValidator(), ValidString.in("a", "b")), "a", new Object[]{"a", "b"}, new Object[] {null, -1, "c"}); + testValidators(Type.STRING, new ConfigDef.NonEmptyStringWithoutControlChars(), "defaultname", + new Object[]{"test", "name", "test/test", "test\u1234", "\u1324name\\", "/+%>&):??<&()?-", "+1", "\uD83D\uDE01", "\uF3B1", " test \n\r", "\n hello \t"}, + new Object[]{"nontrailing\nnotallowed", "as\u0001cii control char", "tes\rt", "test\btest", "1\t2", ""}); + } + + @Test + public void testSslPasswords() { + ConfigDef def = new ConfigDef(); + SslConfigs.addClientSslSupport(def); + + Properties props = new Properties(); + props.put(SslConfigs.SSL_KEY_PASSWORD_CONFIG, "key_password"); + props.put(SslConfigs.SSL_KEYSTORE_PASSWORD_CONFIG, "keystore_password"); + props.put(SslConfigs.SSL_TRUSTSTORE_PASSWORD_CONFIG, "truststore_password"); + + Map vals = def.parse(props); + assertEquals(new Password("key_password"), vals.get(SslConfigs.SSL_KEY_PASSWORD_CONFIG)); + 
assertEquals(Password.HIDDEN, vals.get(SslConfigs.SSL_KEY_PASSWORD_CONFIG).toString()); + assertEquals(new Password("keystore_password"), vals.get(SslConfigs.SSL_KEYSTORE_PASSWORD_CONFIG)); + assertEquals(Password.HIDDEN, vals.get(SslConfigs.SSL_KEYSTORE_PASSWORD_CONFIG).toString()); + assertEquals(new Password("truststore_password"), vals.get(SslConfigs.SSL_TRUSTSTORE_PASSWORD_CONFIG)); + assertEquals(Password.HIDDEN, vals.get(SslConfigs.SSL_TRUSTSTORE_PASSWORD_CONFIG).toString()); + } + + @Test + public void testNullDefaultWithValidator() { + final String key = "enum_test"; + + ConfigDef def = new ConfigDef(); + def.define(key, Type.STRING, ConfigDef.NO_DEFAULT_VALUE, + ValidString.in("ONE", "TWO", "THREE"), Importance.HIGH, "docs"); + + Properties props = new Properties(); + props.put(key, "ONE"); + Map vals = def.parse(props); + assertEquals("ONE", vals.get(key)); + } + + @Test + public void testGroupInference() { + List expected1 = Arrays.asList("group1", "group2"); + ConfigDef def1 = new ConfigDef() + .define("a", Type.INT, Importance.HIGH, "docs", "group1", 1, Width.SHORT, "a") + .define("b", Type.INT, Importance.HIGH, "docs", "group2", 1, Width.SHORT, "b") + .define("c", Type.INT, Importance.HIGH, "docs", "group1", 2, Width.SHORT, "c"); + + assertEquals(expected1, def1.groups()); + + List expected2 = Arrays.asList("group2", "group1"); + ConfigDef def2 = new ConfigDef() + .define("a", Type.INT, Importance.HIGH, "docs", "group2", 1, Width.SHORT, "a") + .define("b", Type.INT, Importance.HIGH, "docs", "group2", 2, Width.SHORT, "b") + .define("c", Type.INT, Importance.HIGH, "docs", "group1", 2, Width.SHORT, "c"); + + assertEquals(expected2, def2.groups()); + } + + @Test + public void testParseForValidate() { + Map expectedParsed = new HashMap<>(); + expectedParsed.put("a", 1); + expectedParsed.put("b", null); + expectedParsed.put("c", null); + expectedParsed.put("d", 10); + + Map expected = new HashMap<>(); + String errorMessageB = "Missing required configuration \"b\" which has no default value."; + String errorMessageC = "Missing required configuration \"c\" which has no default value."; + ConfigValue configA = new ConfigValue("a", 1, Collections.emptyList(), Collections.emptyList()); + ConfigValue configB = new ConfigValue("b", null, Collections.emptyList(), Arrays.asList(errorMessageB, errorMessageB)); + ConfigValue configC = new ConfigValue("c", null, Collections.emptyList(), Arrays.asList(errorMessageC)); + ConfigValue configD = new ConfigValue("d", 10, Collections.emptyList(), Collections.emptyList()); + expected.put("a", configA); + expected.put("b", configB); + expected.put("c", configC); + expected.put("d", configD); + + ConfigDef def = new ConfigDef() + .define("a", Type.INT, Importance.HIGH, "docs", "group", 1, Width.SHORT, "a", Arrays.asList("b", "c"), new IntegerRecommender(false)) + .define("b", Type.INT, Importance.HIGH, "docs", "group", 2, Width.SHORT, "b", new IntegerRecommender(true)) + .define("c", Type.INT, Importance.HIGH, "docs", "group", 3, Width.SHORT, "c", new IntegerRecommender(true)) + .define("d", Type.INT, Importance.HIGH, "docs", "group", 4, Width.SHORT, "d", Arrays.asList("b"), new IntegerRecommender(false)); + + Map props = new HashMap<>(); + props.put("a", "1"); + props.put("d", "10"); + + Map configValues = new HashMap<>(); + + for (String name : def.configKeys().keySet()) { + configValues.put(name, new ConfigValue(name)); + } + + Map parsed = def.parseForValidate(props, configValues); + + assertEquals(expectedParsed, parsed); + assertEquals(expected, 
configValues); + } + + @Test + public void testValidate() { + Map expected = new HashMap<>(); + String errorMessageB = "Missing required configuration \"b\" which has no default value."; + String errorMessageC = "Missing required configuration \"c\" which has no default value."; + + ConfigValue configA = new ConfigValue("a", 1, Arrays.asList(1, 2, 3), Collections.emptyList()); + ConfigValue configB = new ConfigValue("b", null, Arrays.asList(4, 5), Arrays.asList(errorMessageB, errorMessageB)); + ConfigValue configC = new ConfigValue("c", null, Arrays.asList(4, 5), Arrays.asList(errorMessageC)); + ConfigValue configD = new ConfigValue("d", 10, Arrays.asList(1, 2, 3), Collections.emptyList()); + + expected.put("a", configA); + expected.put("b", configB); + expected.put("c", configC); + expected.put("d", configD); + + ConfigDef def = new ConfigDef() + .define("a", Type.INT, Importance.HIGH, "docs", "group", 1, Width.SHORT, "a", Arrays.asList("b", "c"), new IntegerRecommender(false)) + .define("b", Type.INT, Importance.HIGH, "docs", "group", 2, Width.SHORT, "b", new IntegerRecommender(true)) + .define("c", Type.INT, Importance.HIGH, "docs", "group", 3, Width.SHORT, "c", new IntegerRecommender(true)) + .define("d", Type.INT, Importance.HIGH, "docs", "group", 4, Width.SHORT, "d", Arrays.asList("b"), new IntegerRecommender(false)); + + Map props = new HashMap<>(); + props.put("a", "1"); + props.put("d", "10"); + + List configs = def.validate(props); + for (ConfigValue config : configs) { + String name = config.name(); + ConfigValue expectedConfig = expected.get(name); + assertEquals(expectedConfig, config); + } + } + + @Test + public void testValidateMissingConfigKey() { + Map expected = new HashMap<>(); + String errorMessageB = "Missing required configuration \"b\" which has no default value."; + String errorMessageC = "Missing required configuration \"c\" which has no default value."; + String errorMessageD = "d is referred in the dependents, but not defined."; + + ConfigValue configA = new ConfigValue("a", 1, Arrays.asList(1, 2, 3), Collections.emptyList()); + ConfigValue configB = new ConfigValue("b", null, Arrays.asList(4, 5), Arrays.asList(errorMessageB)); + ConfigValue configC = new ConfigValue("c", null, Arrays.asList(4, 5), Arrays.asList(errorMessageC)); + ConfigValue configD = new ConfigValue("d", null, Collections.emptyList(), Arrays.asList(errorMessageD)); + configD.visible(false); + + expected.put("a", configA); + expected.put("b", configB); + expected.put("c", configC); + expected.put("d", configD); + + ConfigDef def = new ConfigDef() + .define("a", Type.INT, Importance.HIGH, "docs", "group", 1, Width.SHORT, "a", Arrays.asList("b", "c", "d"), new IntegerRecommender(false)) + .define("b", Type.INT, Importance.HIGH, "docs", "group", 2, Width.SHORT, "b", new IntegerRecommender(true)) + .define("c", Type.INT, Importance.HIGH, "docs", "group", 3, Width.SHORT, "c", new IntegerRecommender(true)); + + Map props = new HashMap<>(); + props.put("a", "1"); + + List configs = def.validate(props); + for (ConfigValue config: configs) { + String name = config.name(); + ConfigValue expectedConfig = expected.get(name); + assertEquals(expectedConfig, config); + } + } + + @Test + public void testValidateCannotParse() { + Map expected = new HashMap<>(); + String errorMessageB = "Invalid value non_integer for configuration a: Not a number of type INT"; + ConfigValue configA = new ConfigValue("a", null, Collections.emptyList(), Arrays.asList(errorMessageB)); + expected.put("a", configA); + + ConfigDef def 
= new ConfigDef().define("a", Type.INT, Importance.HIGH, "docs"); + Map props = new HashMap<>(); + props.put("a", "non_integer"); + + List configs = def.validate(props); + for (ConfigValue config: configs) { + String name = config.name(); + ConfigValue expectedConfig = expected.get(name); + assertEquals(expectedConfig, config); + } + } + + @Test + public void testCanAddInternalConfig() throws Exception { + final String configName = "internal.config"; + final ConfigDef configDef = new ConfigDef().defineInternal(configName, Type.STRING, "", Importance.LOW); + final HashMap properties = new HashMap<>(); + properties.put(configName, "value"); + final List results = configDef.validate(properties); + final ConfigValue configValue = results.get(0); + assertEquals("value", configValue.value()); + assertEquals(configName, configValue.name()); + } + + @Test + public void testInternalConfigDoesntShowUpInDocs() { + final String name = "my.config"; + final ConfigDef configDef = new ConfigDef().defineInternal(name, Type.STRING, "", Importance.LOW); + configDef.defineInternal("my.other.config", Type.STRING, "", null, Importance.LOW, null); + assertFalse(configDef.toHtmlTable().contains("my.config")); + assertFalse(configDef.toEnrichedRst().contains("my.config")); + assertFalse(configDef.toRst().contains("my.config")); + assertFalse(configDef.toHtmlTable().contains("my.other.config")); + assertFalse(configDef.toEnrichedRst().contains("my.other.config")); + assertFalse(configDef.toRst().contains("my.other.config")); + } + + @Test + public void testDynamicUpdateModeInDocs() throws Exception { + final ConfigDef configDef = new ConfigDef() + .define("my.broker.config", Type.LONG, Importance.HIGH, "docs") + .define("my.cluster.config", Type.LONG, Importance.HIGH, "docs") + .define("my.readonly.config", Type.LONG, Importance.HIGH, "docs"); + final Map updateModes = new HashMap<>(); + updateModes.put("my.broker.config", "per-broker"); + updateModes.put("my.cluster.config", "cluster-wide"); + final String html = configDef.toHtmlTable(updateModes); + Set configsInHtml = new HashSet<>(); + for (String line : html.split("\n")) { + if (line.contains("my.broker.config")) { + assertTrue(line.contains("per-broker")); + configsInHtml.add("my.broker.config"); + } else if (line.contains("my.cluster.config")) { + assertTrue(line.contains("cluster-wide")); + configsInHtml.add("my.cluster.config"); + } else if (line.contains("my.readonly.config")) { + assertTrue(line.contains("read-only")); + configsInHtml.add("my.readonly.config"); + } + } + assertEquals(configDef.names(), configsInHtml); + } + + @Test + public void testNames() { + final ConfigDef configDef = new ConfigDef() + .define("a", Type.STRING, Importance.LOW, "docs") + .define("b", Type.STRING, Importance.LOW, "docs"); + Set names = configDef.names(); + assertEquals(new HashSet<>(Arrays.asList("a", "b")), names); + // should be unmodifiable + try { + names.add("new"); + fail(); + } catch (UnsupportedOperationException e) { + // expected + } + } + + @Test + public void testMissingDependentConfigs() { + // Should not be possible to parse a config if a dependent config has not been defined + final ConfigDef configDef = new ConfigDef() + .define("parent", Type.STRING, Importance.HIGH, "parent docs", "group", 1, Width.LONG, "Parent", Collections.singletonList("child")); + assertThrows(ConfigException.class, () -> configDef.parse(Collections.emptyMap())); + } + + @Test + public void testBaseConfigDefDependents() { + // Creating a ConfigDef based on another should compute 
the correct number of configs with no parent, even + // if the base ConfigDef has already computed its parentless configs + final ConfigDef baseConfigDef = new ConfigDef().define("a", Type.STRING, Importance.LOW, "docs"); + assertEquals(new HashSet<>(Arrays.asList("a")), baseConfigDef.getConfigsWithNoParent()); + + final ConfigDef configDef = new ConfigDef(baseConfigDef) + .define("parent", Type.STRING, Importance.HIGH, "parent docs", "group", 1, Width.LONG, "Parent", Collections.singletonList("child")) + .define("child", Type.STRING, Importance.HIGH, "docs"); + + assertEquals(new HashSet<>(Arrays.asList("a", "parent")), configDef.getConfigsWithNoParent()); + } + + + private static class IntegerRecommender implements ConfigDef.Recommender { + + private boolean hasParent; + + public IntegerRecommender(boolean hasParent) { + this.hasParent = hasParent; + } + + @Override + public List validValues(String name, Map parsedConfig) { + List values = new LinkedList<>(); + if (!hasParent) { + values.addAll(Arrays.asList(1, 2, 3)); + } else { + values.addAll(Arrays.asList(4, 5)); + } + return values; + } + + @Override + public boolean visible(String name, Map parsedConfig) { + return true; + } + } + + private void testValidators(Type type, Validator validator, Object defaultVal, Object[] okValues, Object[] badValues) { + ConfigDef def = new ConfigDef().define("name", type, defaultVal, validator, Importance.HIGH, "docs"); + + for (Object value : okValues) { + Map m = new HashMap(); + m.put("name", value); + def.parse(m); + } + + for (Object value : badValues) { + Map m = new HashMap(); + m.put("name", value); + try { + def.parse(m); + fail("Expected a config exception due to invalid value " + value); + } catch (ConfigException e) { + // this is good + } + } + } + + @Test + public void toRst() { + final ConfigDef def = new ConfigDef() + .define("opt1", Type.STRING, "a", ValidString.in("a", "b", "c"), Importance.HIGH, "docs1") + .define("opt2", Type.INT, Importance.MEDIUM, "docs2") + .define("opt3", Type.LIST, Arrays.asList("a", "b"), Importance.LOW, "docs3") + .define("opt4", Type.BOOLEAN, false, Importance.LOW, null); + + final String expectedRst = "" + + "``opt2``\n" + + " docs2\n" + + "\n" + + " * Type: int\n" + + " * Importance: medium\n" + + "\n" + + "``opt1``\n" + + " docs1\n" + + "\n" + + " * Type: string\n" + + " * Default: a\n" + + " * Valid Values: [a, b, c]\n" + + " * Importance: high\n" + + "\n" + + "``opt3``\n" + + " docs3\n" + + "\n" + + " * Type: list\n" + + " * Default: a,b\n" + + " * Importance: low\n" + + "\n" + + "``opt4``\n" + + "\n" + + " * Type: boolean\n" + + " * Default: false\n" + + " * Importance: low\n" + + "\n"; + + assertEquals(expectedRst, def.toRst()); + } + + @Test + public void toEnrichedRst() { + final ConfigDef def = new ConfigDef() + .define("opt1.of.group1", Type.STRING, "a", ValidString.in("a", "b", "c"), Importance.HIGH, "Doc doc.", + "Group One", 0, Width.NONE, "..", Collections.emptyList()) + .define("opt2.of.group1", Type.INT, ConfigDef.NO_DEFAULT_VALUE, Importance.MEDIUM, "Doc doc doc.", + "Group One", 1, Width.NONE, "..", Arrays.asList("some.option1", "some.option2")) + .define("opt2.of.group2", Type.BOOLEAN, false, Importance.HIGH, "Doc doc doc doc.", + "Group Two", 1, Width.NONE, "..", Collections.emptyList()) + .define("opt1.of.group2", Type.BOOLEAN, false, Importance.HIGH, "Doc doc doc doc doc.", + "Group Two", 0, Width.NONE, "..", Collections.singletonList("some.option")) + .define("poor.opt", Type.STRING, "foo", Importance.HIGH, "Doc doc doc doc."); + + 
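+        // The expected output lists the ungrouped option first, then one RST section per group,
+        // with options inside each group ordered by the orderInGroup argument passed to define().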
final String expectedRst = "" + + "``poor.opt``\n" + + " Doc doc doc doc.\n" + + "\n" + + " * Type: string\n" + + " * Default: foo\n" + + " * Importance: high\n" + + "\n" + + "Group One\n" + + "^^^^^^^^^\n" + + "\n" + + "``opt1.of.group1``\n" + + " Doc doc.\n" + + "\n" + + " * Type: string\n" + + " * Default: a\n" + + " * Valid Values: [a, b, c]\n" + + " * Importance: high\n" + + "\n" + + "``opt2.of.group1``\n" + + " Doc doc doc.\n" + + "\n" + + " * Type: int\n" + + " * Importance: medium\n" + + " * Dependents: ``some.option1``, ``some.option2``\n" + + "\n" + + "Group Two\n" + + "^^^^^^^^^\n" + + "\n" + + "``opt1.of.group2``\n" + + " Doc doc doc doc doc.\n" + + "\n" + + " * Type: boolean\n" + + " * Default: false\n" + + " * Importance: high\n" + + " * Dependents: ``some.option``\n" + + "\n" + + "``opt2.of.group2``\n" + + " Doc doc doc doc.\n" + + "\n" + + " * Type: boolean\n" + + " * Default: false\n" + + " * Importance: high\n" + + "\n"; + + assertEquals(expectedRst, def.toEnrichedRst()); + } + + @Test + public void testConvertValueToStringBoolean() { + assertEquals("true", ConfigDef.convertToString(true, Type.BOOLEAN)); + assertNull(ConfigDef.convertToString(null, Type.BOOLEAN)); + } + + @Test + public void testConvertValueToStringShort() { + assertEquals("32767", ConfigDef.convertToString(Short.MAX_VALUE, Type.SHORT)); + assertNull(ConfigDef.convertToString(null, Type.SHORT)); + } + + @Test + public void testConvertValueToStringInt() { + assertEquals("2147483647", ConfigDef.convertToString(Integer.MAX_VALUE, Type.INT)); + assertNull(ConfigDef.convertToString(null, Type.INT)); + } + + @Test + public void testConvertValueToStringLong() { + assertEquals("9223372036854775807", ConfigDef.convertToString(Long.MAX_VALUE, Type.LONG)); + assertNull(ConfigDef.convertToString(null, Type.LONG)); + } + + @Test + public void testConvertValueToStringDouble() { + assertEquals("3.125", ConfigDef.convertToString(3.125, Type.DOUBLE)); + assertNull(ConfigDef.convertToString(null, Type.DOUBLE)); + } + + @Test + public void testConvertValueToStringString() { + assertEquals("foobar", ConfigDef.convertToString("foobar", Type.STRING)); + assertNull(ConfigDef.convertToString(null, Type.STRING)); + } + + @Test + public void testConvertValueToStringPassword() { + assertEquals(Password.HIDDEN, ConfigDef.convertToString(new Password("foobar"), Type.PASSWORD)); + assertEquals("foobar", ConfigDef.convertToString("foobar", Type.PASSWORD)); + assertNull(ConfigDef.convertToString(null, Type.PASSWORD)); + } + + @Test + public void testConvertValueToStringList() { + assertEquals("a,bc,d", ConfigDef.convertToString(Arrays.asList("a", "bc", "d"), Type.LIST)); + assertNull(ConfigDef.convertToString(null, Type.LIST)); + } + + @Test + public void testConvertValueToStringClass() throws ClassNotFoundException { + String actual = ConfigDef.convertToString(ConfigDefTest.class, Type.CLASS); + assertEquals("org.apache.kafka.common.config.ConfigDefTest", actual); + // Additionally validate that we can look up this class by this name + assertEquals(ConfigDefTest.class, Class.forName(actual)); + assertNull(ConfigDef.convertToString(null, Type.CLASS)); + } + + @Test + public void testConvertValueToStringNestedClass() throws ClassNotFoundException { + String actual = ConfigDef.convertToString(NestedClass.class, Type.CLASS); + assertEquals("org.apache.kafka.common.config.ConfigDefTest$NestedClass", actual); + // Additionally validate that we can look up this class by this name + assertEquals(NestedClass.class, Class.forName(actual)); + } + 
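+    // parseType() should resolve a class alias via the thread context ClassLoader, mirroring the
+    // aliasing logic of Connect's Plugins class; the custom loader below maps the alias to NestedClass.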
+ @Test + public void testClassWithAlias() { + final String alias = "PluginAlias"; + ClassLoader originalClassLoader = Thread.currentThread().getContextClassLoader(); + try { + // Could try to use the Plugins class from Connect here, but this should simulate enough + // of the aliasing logic to suffice for this test. + Thread.currentThread().setContextClassLoader(new ClassLoader(originalClassLoader) { + @Override + public Class loadClass(String name, boolean resolve) throws ClassNotFoundException { + if (alias.equals(name)) { + return NestedClass.class; + } else { + return super.loadClass(name, resolve); + } + } + }); + ConfigDef.parseType("Test config", alias, Type.CLASS); + } finally { + Thread.currentThread().setContextClassLoader(originalClassLoader); + } + } + + private class NestedClass { + } + + @Test + public void testNiceMemoryUnits() { + assertEquals("", ConfigDef.niceMemoryUnits(0L)); + assertEquals("", ConfigDef.niceMemoryUnits(1023)); + assertEquals(" (1 kibibyte)", ConfigDef.niceMemoryUnits(1024)); + assertEquals("", ConfigDef.niceMemoryUnits(1025)); + assertEquals(" (2 kibibytes)", ConfigDef.niceMemoryUnits(2 * 1024)); + assertEquals(" (1 mebibyte)", ConfigDef.niceMemoryUnits(1024 * 1024)); + assertEquals(" (2 mebibytes)", ConfigDef.niceMemoryUnits(2 * 1024 * 1024)); + assertEquals(" (1 gibibyte)", ConfigDef.niceMemoryUnits(1024 * 1024 * 1024)); + assertEquals(" (2 gibibytes)", ConfigDef.niceMemoryUnits(2L * 1024 * 1024 * 1024)); + assertEquals(" (1 tebibyte)", ConfigDef.niceMemoryUnits(1024L * 1024 * 1024 * 1024)); + assertEquals(" (2 tebibytes)", ConfigDef.niceMemoryUnits(2L * 1024 * 1024 * 1024 * 1024)); + assertEquals(" (1024 tebibytes)", ConfigDef.niceMemoryUnits(1024L * 1024 * 1024 * 1024 * 1024)); + assertEquals(" (2048 tebibytes)", ConfigDef.niceMemoryUnits(2L * 1024 * 1024 * 1024 * 1024 * 1024)); + } + + @Test + public void testNiceTimeUnits() { + assertEquals("", ConfigDef.niceTimeUnits(0)); + assertEquals("", ConfigDef.niceTimeUnits(Duration.ofSeconds(1).toMillis() - 1)); + assertEquals(" (1 second)", ConfigDef.niceTimeUnits(Duration.ofSeconds(1).toMillis())); + assertEquals("", ConfigDef.niceTimeUnits(Duration.ofSeconds(1).toMillis() + 1)); + assertEquals(" (2 seconds)", ConfigDef.niceTimeUnits(Duration.ofSeconds(2).toMillis())); + + assertEquals(" (1 minute)", ConfigDef.niceTimeUnits(Duration.ofMinutes(1).toMillis())); + assertEquals(" (2 minutes)", ConfigDef.niceTimeUnits(Duration.ofMinutes(2).toMillis())); + + assertEquals(" (1 hour)", ConfigDef.niceTimeUnits(Duration.ofHours(1).toMillis())); + assertEquals(" (2 hours)", ConfigDef.niceTimeUnits(Duration.ofHours(2).toMillis())); + + assertEquals(" (1 day)", ConfigDef.niceTimeUnits(Duration.ofDays(1).toMillis())); + assertEquals(" (2 days)", ConfigDef.niceTimeUnits(Duration.ofDays(2).toMillis())); + + assertEquals(" (7 days)", ConfigDef.niceTimeUnits(Duration.ofDays(7).toMillis())); + assertEquals(" (365 days)", ConfigDef.niceTimeUnits(Duration.ofDays(365).toMillis())); + } + +} diff --git a/clients/src/test/java/org/apache/kafka/common/config/ConfigResourceTest.java b/clients/src/test/java/org/apache/kafka/common/config/ConfigResourceTest.java new file mode 100644 index 0000000..9247db4 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/config/ConfigResourceTest.java @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.common.config; + +import org.junit.jupiter.api.Test; + +import java.util.Arrays; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class ConfigResourceTest { + @Test + public void shouldGetTypeFromId() { + assertEquals(ConfigResource.Type.TOPIC, ConfigResource.Type.forId((byte) 2)); + assertEquals(ConfigResource.Type.BROKER, ConfigResource.Type.forId((byte) 4)); + } + + @Test + public void shouldReturnUnknownForUnknownCode() { + assertEquals(ConfigResource.Type.UNKNOWN, ConfigResource.Type.forId((byte) -1)); + assertEquals(ConfigResource.Type.UNKNOWN, ConfigResource.Type.forId((byte) 0)); + assertEquals(ConfigResource.Type.UNKNOWN, ConfigResource.Type.forId((byte) 1)); + } + + @Test + public void shouldRoundTripEveryType() { + Arrays.stream(ConfigResource.Type.values()).forEach(type -> + assertEquals(type, ConfigResource.Type.forId(type.id()), type.toString())); + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/config/ConfigTransformerTest.java b/clients/src/test/java/org/apache/kafka/common/config/ConfigTransformerTest.java new file mode 100644 index 0000000..93296d9 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/config/ConfigTransformerTest.java @@ -0,0 +1,149 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.config; + +import org.apache.kafka.common.config.provider.ConfigProvider; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.Set; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class ConfigTransformerTest { + + public static final String MY_KEY = "myKey"; + public static final String TEST_INDIRECTION = "testIndirection"; + public static final String TEST_KEY = "testKey"; + public static final String TEST_KEY_WITH_TTL = "testKeyWithTTL"; + public static final String TEST_PATH = "testPath"; + public static final String TEST_RESULT = "testResult"; + public static final String TEST_RESULT_WITH_TTL = "testResultWithTTL"; + public static final String TEST_RESULT_NO_PATH = "testResultNoPath"; + + private ConfigTransformer configTransformer; + + @BeforeEach + public void setup() { + configTransformer = new ConfigTransformer(Collections.singletonMap("test", new TestConfigProvider())); + } + + @Test + public void testReplaceVariable() throws Exception { + ConfigTransformerResult result = configTransformer.transform(Collections.singletonMap(MY_KEY, "${test:testPath:testKey}")); + Map data = result.data(); + Map ttls = result.ttls(); + assertEquals(TEST_RESULT, data.get(MY_KEY)); + assertTrue(ttls.isEmpty()); + } + + @Test + public void testReplaceVariableWithTTL() throws Exception { + ConfigTransformerResult result = configTransformer.transform(Collections.singletonMap(MY_KEY, "${test:testPath:testKeyWithTTL}")); + Map data = result.data(); + Map ttls = result.ttls(); + assertEquals(TEST_RESULT_WITH_TTL, data.get(MY_KEY)); + assertEquals(1L, ttls.get(TEST_PATH).longValue()); + } + + @Test + public void testReplaceMultipleVariablesInValue() throws Exception { + ConfigTransformerResult result = configTransformer.transform(Collections.singletonMap(MY_KEY, "hello, ${test:testPath:testKey}; goodbye, ${test:testPath:testKeyWithTTL}!!!")); + Map data = result.data(); + assertEquals("hello, testResult; goodbye, testResultWithTTL!!!", data.get(MY_KEY)); + } + + @Test + public void testNoReplacement() throws Exception { + ConfigTransformerResult result = configTransformer.transform(Collections.singletonMap(MY_KEY, "${test:testPath:missingKey}")); + Map data = result.data(); + assertEquals("${test:testPath:missingKey}", data.get(MY_KEY)); + } + + @Test + public void testSingleLevelOfIndirection() throws Exception { + ConfigTransformerResult result = configTransformer.transform(Collections.singletonMap(MY_KEY, "${test:testPath:testIndirection}")); + Map data = result.data(); + assertEquals("${test:testPath:testResult}", data.get(MY_KEY)); + } + + @Test + public void testReplaceVariableNoPath() throws Exception { + ConfigTransformerResult result = configTransformer.transform(Collections.singletonMap(MY_KEY, "${test:testKey}")); + Map data = result.data(); + Map ttls = result.ttls(); + assertEquals(TEST_RESULT_NO_PATH, data.get(MY_KEY)); + assertTrue(ttls.isEmpty()); + } + + @Test + public void testReplaceMultipleVariablesWithoutPathInValue() throws Exception { + ConfigTransformerResult result = configTransformer.transform(Collections.singletonMap(MY_KEY, "first ${test:testKey}; second ${test:testKey}")); + Map data = result.data(); + assertEquals("first testResultNoPath; second testResultNoPath", 
data.get(MY_KEY)); + } + + @Test + public void testNullConfigValue() throws Exception { + ConfigTransformerResult result = configTransformer.transform(Collections.singletonMap(MY_KEY, null)); + Map data = result.data(); + Map ttls = result.ttls(); + assertNull(data.get(MY_KEY)); + assertTrue(ttls.isEmpty()); + } + + public static class TestConfigProvider implements ConfigProvider { + + public void configure(Map configs) { + } + + public ConfigData get(String path) { + return null; + } + + public ConfigData get(String path, Set keys) { + Map data = new HashMap<>(); + Long ttl = null; + if (TEST_PATH.equals(path)) { + if (keys.contains(TEST_KEY)) { + data.put(TEST_KEY, TEST_RESULT); + } + if (keys.contains(TEST_KEY_WITH_TTL)) { + data.put(TEST_KEY_WITH_TTL, TEST_RESULT_WITH_TTL); + ttl = 1L; + } + if (keys.contains(TEST_INDIRECTION)) { + data.put(TEST_INDIRECTION, "${test:testPath:testResult}"); + } + } else { + if (keys.contains(TEST_KEY)) { + data.put(TEST_KEY, TEST_RESULT_NO_PATH); + } + } + return new ConfigData(data, ttl); + } + + public void close() { + } + } + +} diff --git a/clients/src/test/java/org/apache/kafka/common/config/SaslConfigsTest.java b/clients/src/test/java/org/apache/kafka/common/config/SaslConfigsTest.java new file mode 100644 index 0000000..7e8b63c --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/config/SaslConfigsTest.java @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.config; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +import org.junit.jupiter.api.Test; + +public class SaslConfigsTest { + @Test + public void testSaslLoginRefreshDefaults() { + Map vals = new ConfigDef().withClientSaslSupport().parse(Collections.emptyMap()); + assertEquals(SaslConfigs.DEFAULT_LOGIN_REFRESH_WINDOW_FACTOR, + vals.get(SaslConfigs.SASL_LOGIN_REFRESH_WINDOW_FACTOR)); + assertEquals(SaslConfigs.DEFAULT_LOGIN_REFRESH_WINDOW_JITTER, + vals.get(SaslConfigs.SASL_LOGIN_REFRESH_WINDOW_JITTER)); + assertEquals(SaslConfigs.DEFAULT_LOGIN_REFRESH_MIN_PERIOD_SECONDS, + vals.get(SaslConfigs.SASL_LOGIN_REFRESH_MIN_PERIOD_SECONDS)); + assertEquals(SaslConfigs.DEFAULT_LOGIN_REFRESH_BUFFER_SECONDS, + vals.get(SaslConfigs.SASL_LOGIN_REFRESH_BUFFER_SECONDS)); + } + + @Test + public void testSaslLoginRefreshMinValuesAreValid() { + Map props = new HashMap<>(); + props.put(SaslConfigs.SASL_LOGIN_REFRESH_WINDOW_FACTOR, "0.5"); + props.put(SaslConfigs.SASL_LOGIN_REFRESH_WINDOW_JITTER, "0.0"); + props.put(SaslConfigs.SASL_LOGIN_REFRESH_MIN_PERIOD_SECONDS, "0"); + props.put(SaslConfigs.SASL_LOGIN_REFRESH_BUFFER_SECONDS, "0"); + Map vals = new ConfigDef().withClientSaslSupport().parse(props); + assertEquals(Double.valueOf("0.5"), vals.get(SaslConfigs.SASL_LOGIN_REFRESH_WINDOW_FACTOR)); + assertEquals(Double.valueOf("0.0"), vals.get(SaslConfigs.SASL_LOGIN_REFRESH_WINDOW_JITTER)); + assertEquals(Short.valueOf("0"), vals.get(SaslConfigs.SASL_LOGIN_REFRESH_MIN_PERIOD_SECONDS)); + assertEquals(Short.valueOf("0"), vals.get(SaslConfigs.SASL_LOGIN_REFRESH_BUFFER_SECONDS)); + } + + @Test + public void testSaslLoginRefreshMaxValuesAreValid() { + Map props = new HashMap<>(); + props.put(SaslConfigs.SASL_LOGIN_REFRESH_WINDOW_FACTOR, "1.0"); + props.put(SaslConfigs.SASL_LOGIN_REFRESH_WINDOW_JITTER, "0.25"); + props.put(SaslConfigs.SASL_LOGIN_REFRESH_MIN_PERIOD_SECONDS, "900"); + props.put(SaslConfigs.SASL_LOGIN_REFRESH_BUFFER_SECONDS, "3600"); + Map vals = new ConfigDef().withClientSaslSupport().parse(props); + assertEquals(Double.valueOf("1.0"), vals.get(SaslConfigs.SASL_LOGIN_REFRESH_WINDOW_FACTOR)); + assertEquals(Double.valueOf("0.25"), vals.get(SaslConfigs.SASL_LOGIN_REFRESH_WINDOW_JITTER)); + assertEquals(Short.valueOf("900"), vals.get(SaslConfigs.SASL_LOGIN_REFRESH_MIN_PERIOD_SECONDS)); + assertEquals(Short.valueOf("3600"), vals.get(SaslConfigs.SASL_LOGIN_REFRESH_BUFFER_SECONDS)); + } + + @Test + public void testSaslLoginRefreshWindowFactorMinValueIsReallyMinimum() { + Map props = new HashMap<>(); + props.put(SaslConfigs.SASL_LOGIN_REFRESH_WINDOW_FACTOR, "0.499999"); + assertThrows(ConfigException.class, () -> new ConfigDef().withClientSaslSupport().parse(props)); + } + + @Test + public void testSaslLoginRefreshWindowFactorMaxValueIsReallyMaximum() { + Map props = new HashMap<>(); + props.put(SaslConfigs.SASL_LOGIN_REFRESH_WINDOW_FACTOR, "1.0001"); + assertThrows(ConfigException.class, () -> new ConfigDef().withClientSaslSupport().parse(props)); + } + + @Test + public void testSaslLoginRefreshWindowJitterMinValueIsReallyMinimum() { + Map props = new HashMap<>(); + props.put(SaslConfigs.SASL_LOGIN_REFRESH_WINDOW_JITTER, "-0.000001"); + assertThrows(ConfigException.class, () -> new ConfigDef().withClientSaslSupport().parse(props)); + } + + @Test + public void 
testSaslLoginRefreshWindowJitterMaxValueIsReallyMaximum() { + Map props = new HashMap<>(); + props.put(SaslConfigs.SASL_LOGIN_REFRESH_WINDOW_JITTER, "0.251"); + assertThrows(ConfigException.class, () -> new ConfigDef().withClientSaslSupport().parse(props)); + } + + @Test + public void testSaslLoginRefreshMinPeriodSecondsMinValueIsReallyMinimum() { + Map props = new HashMap<>(); + props.put(SaslConfigs.SASL_LOGIN_REFRESH_MIN_PERIOD_SECONDS, "-1"); + assertThrows(ConfigException.class, () -> new ConfigDef().withClientSaslSupport().parse(props)); + } + + @Test + public void testSaslLoginRefreshMinPeriodSecondsMaxValueIsReallyMaximum() { + Map props = new HashMap<>(); + props.put(SaslConfigs.SASL_LOGIN_REFRESH_MIN_PERIOD_SECONDS, "901"); + assertThrows(ConfigException.class, () -> new ConfigDef().withClientSaslSupport().parse(props)); + } + + @Test + public void testSaslLoginRefreshBufferSecondsMinValueIsReallyMinimum() { + Map props = new HashMap<>(); + props.put(SaslConfigs.SASL_LOGIN_REFRESH_BUFFER_SECONDS, "-1"); + assertThrows(ConfigException.class, () -> new ConfigDef().withClientSaslSupport().parse(props)); + } + + @Test + public void testSaslLoginRefreshBufferSecondsMaxValueIsReallyMaximum() { + Map props = new HashMap<>(); + props.put(SaslConfigs.SASL_LOGIN_REFRESH_BUFFER_SECONDS, "3601"); + assertThrows(ConfigException.class, () -> new ConfigDef().withClientSaslSupport().parse(props)); + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/config/provider/DirectoryConfigProviderTest.java b/clients/src/test/java/org/apache/kafka/common/config/provider/DirectoryConfigProviderTest.java new file mode 100644 index 0000000..7cf5422 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/config/provider/DirectoryConfigProviderTest.java @@ -0,0 +1,157 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.config.provider; + +import org.apache.kafka.common.config.ConfigData; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.test.TestUtils; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.io.File; +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.util.Collections; +import java.util.Locale; +import java.util.ServiceLoader; +import java.util.Set; +import java.util.stream.StreamSupport; + +import static java.util.Arrays.asList; +import static org.apache.kafka.test.TestUtils.toSet; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class DirectoryConfigProviderTest { + + private DirectoryConfigProvider provider; + private File parent; + private File dir; + private File bar; + private File foo; + private File subdir; + private File subdirFile; + private File siblingDir; + private File siblingDirFile; + private File siblingFile; + + private static File writeFile(File file) throws IOException { + Files.write(file.toPath(), file.getName().toUpperCase(Locale.ENGLISH).getBytes(StandardCharsets.UTF_8)); + return file; + } + + @BeforeEach + public void setup() throws IOException { + provider = new DirectoryConfigProvider(); + provider.configure(Collections.emptyMap()); + parent = TestUtils.tempDirectory(); + dir = new File(parent, "dir"); + dir.mkdir(); + foo = writeFile(new File(dir, "foo")); + bar = writeFile(new File(dir, "bar")); + subdir = new File(dir, "subdir"); + subdir.mkdir(); + subdirFile = writeFile(new File(subdir, "subdirFile")); + siblingDir = new File(parent, "siblingdir"); + siblingDir.mkdir(); + siblingDirFile = writeFile(new File(siblingDir, "siblingdirFile")); + siblingFile = writeFile(new File(parent, "siblingFile")); + } + + @AfterEach + public void close() throws IOException { + provider.close(); + Utils.delete(parent); + } + + @Test + public void testGetAllKeysAtPath() throws IOException { + ConfigData configData = provider.get(dir.getAbsolutePath()); + assertEquals(toSet(asList(foo.getName(), bar.getName())), configData.data().keySet()); + assertEquals("FOO", configData.data().get(foo.getName())); + assertEquals("BAR", configData.data().get(bar.getName())); + assertNull(configData.ttl()); + } + + @Test + public void testGetSetOfKeysAtPath() { + Set keys = toSet(asList(foo.getName(), "baz")); + ConfigData configData = provider.get(dir.getAbsolutePath(), keys); + assertEquals(Collections.singleton(foo.getName()), configData.data().keySet()); + assertEquals("FOO", configData.data().get(foo.getName())); + assertNull(configData.ttl()); + } + + @Test + public void testNoSubdirs() { + // Only regular files directly in the path directory are allowed, not in subdirs + Set keys = toSet(asList(subdir.getName(), String.join(File.separator, subdir.getName(), subdirFile.getName()))); + ConfigData configData = provider.get(dir.getAbsolutePath(), keys); + assertTrue(configData.data().isEmpty()); + assertNull(configData.ttl()); + } + + @Test + public void testNoTraversal() { + // Check we can't escape outside the path directory + Set keys = toSet(asList( + String.join(File.separator, "..", siblingFile.getName()), + String.join(File.separator, "..", siblingDir.getName()), + String.join(File.separator, "..", siblingDir.getName(), 
siblingDirFile.getName()))); + ConfigData configData = provider.get(dir.getAbsolutePath(), keys); + assertTrue(configData.data().isEmpty()); + assertNull(configData.ttl()); + } + + @Test + public void testEmptyPath() { + ConfigData configData = provider.get(""); + assertTrue(configData.data().isEmpty()); + assertNull(configData.ttl()); + } + + @Test + public void testEmptyPathWithKey() { + ConfigData configData = provider.get("", Collections.singleton("foo")); + assertTrue(configData.data().isEmpty()); + assertNull(configData.ttl()); + } + + @Test + public void testNullPath() { + ConfigData configData = provider.get(null); + assertTrue(configData.data().isEmpty()); + assertNull(configData.ttl()); + } + + @Test + public void testNullPathWithKey() { + ConfigData configData = provider.get(null, Collections.singleton("foo")); + assertTrue(configData.data().isEmpty()); + assertNull(configData.ttl()); + } + + @Test + public void testServiceLoaderDiscovery() { + ServiceLoader serviceLoader = ServiceLoader.load(ConfigProvider.class); + assertTrue(StreamSupport.stream(serviceLoader.spliterator(), false).anyMatch(configProvider -> configProvider instanceof DirectoryConfigProvider)); + } +} + diff --git a/clients/src/test/java/org/apache/kafka/common/config/provider/FileConfigProviderTest.java b/clients/src/test/java/org/apache/kafka/common/config/provider/FileConfigProviderTest.java new file mode 100644 index 0000000..431f382 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/config/provider/FileConfigProviderTest.java @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.config.provider; + +import org.apache.kafka.common.config.ConfigData; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.io.IOException; +import java.io.Reader; +import java.io.StringReader; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.ServiceLoader; +import java.util.stream.StreamSupport; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class FileConfigProviderTest { + + private FileConfigProvider configProvider; + + @BeforeEach + public void setup() { + configProvider = new TestFileConfigProvider(); + } + + @Test + public void testGetAllKeysAtPath() throws Exception { + ConfigData configData = configProvider.get("dummy"); + Map result = new HashMap<>(); + result.put("testKey", "testResult"); + result.put("testKey2", "testResult2"); + assertEquals(result, configData.data()); + assertNull(configData.ttl()); + } + + @Test + public void testGetOneKeyAtPath() throws Exception { + ConfigData configData = configProvider.get("dummy", Collections.singleton("testKey")); + Map result = new HashMap<>(); + result.put("testKey", "testResult"); + assertEquals(result, configData.data()); + assertNull(configData.ttl()); + } + + @Test + public void testEmptyPath() throws Exception { + ConfigData configData = configProvider.get("", Collections.singleton("testKey")); + assertTrue(configData.data().isEmpty()); + assertNull(configData.ttl()); + } + + @Test + public void testEmptyPathWithKey() throws Exception { + ConfigData configData = configProvider.get(""); + assertTrue(configData.data().isEmpty()); + assertNull(configData.ttl()); + } + + @Test + public void testNullPath() throws Exception { + ConfigData configData = configProvider.get(null); + assertTrue(configData.data().isEmpty()); + assertNull(configData.ttl()); + } + + @Test + public void testNullPathWithKey() throws Exception { + ConfigData configData = configProvider.get(null, Collections.singleton("testKey")); + assertTrue(configData.data().isEmpty()); + assertNull(configData.ttl()); + } + + @Test + public void testServiceLoaderDiscovery() { + ServiceLoader serviceLoader = ServiceLoader.load(ConfigProvider.class); + assertTrue(StreamSupport.stream(serviceLoader.spliterator(), false).anyMatch(configProvider -> configProvider instanceof FileConfigProvider)); + } + + public static class TestFileConfigProvider extends FileConfigProvider { + + @Override + protected Reader reader(String path) throws IOException { + return new StringReader("testKey=testResult\ntestKey2=testResult2"); + } + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/config/provider/MockFileConfigProvider.java b/clients/src/test/java/org/apache/kafka/common/config/provider/MockFileConfigProvider.java new file mode 100644 index 0000000..50f60e2 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/config/provider/MockFileConfigProvider.java @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.config.provider; + +import java.io.IOException; +import java.io.Reader; +import java.io.StringReader; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class MockFileConfigProvider extends FileConfigProvider { + + private static final Map INSTANCES = Collections.synchronizedMap(new HashMap<>()); + private String id; + private boolean closed = false; + + public void configure(Map configs) { + Object id = configs.get("testId"); + if (id == null) { + throw new RuntimeException(getClass().getName() + " missing 'testId' config"); + } + if (this.id != null) { + throw new RuntimeException(getClass().getName() + " instance was configured twice"); + } + this.id = id.toString(); + INSTANCES.put(id.toString(), this); + } + + @Override + protected Reader reader(String path) throws IOException { + return new StringReader("key=testKey\npassword=randomPassword"); + } + + @Override + public synchronized void close() { + closed = true; + } + + public static void assertClosed(String id) { + MockFileConfigProvider instance = INSTANCES.remove(id); + assertNotNull(instance); + synchronized (instance) { + assertTrue(instance.closed); + } + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/config/provider/MockVaultConfigProvider.java b/clients/src/test/java/org/apache/kafka/common/config/provider/MockVaultConfigProvider.java new file mode 100644 index 0000000..c741798 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/config/provider/MockVaultConfigProvider.java @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.config.provider; + +import java.io.IOException; +import java.io.Reader; +import java.io.StringReader; +import java.util.Map; + +public class MockVaultConfigProvider extends FileConfigProvider { + + Map vaultConfigs; + private boolean configured = false; + private static final String LOCATION = "location"; + + @Override + protected Reader reader(String path) throws IOException { + String vaultLocation = (String) vaultConfigs.get(LOCATION); + return new StringReader("truststoreKey=testTruststoreKey\ntruststorePassword=randomtruststorePassword\n" + "truststoreLocation=" + vaultLocation + "\n"); + } + + @Override + public void configure(Map configs) { + this.vaultConfigs = configs; + configured = true; + } + + public boolean configured() { + return configured; + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/feature/FeaturesTest.java b/clients/src/test/java/org/apache/kafka/common/feature/FeaturesTest.java new file mode 100644 index 0000000..88b3471 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/feature/FeaturesTest.java @@ -0,0 +1,173 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.common.feature; + +import java.util.HashMap; +import java.util.Map; + +import org.junit.jupiter.api.Test; + +import static org.apache.kafka.common.utils.Utils.mkEntry; +import static org.apache.kafka.common.utils.Utils.mkMap; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; + +public class FeaturesTest { + + @Test + public void testEmptyFeatures() { + Map> emptyMap = new HashMap<>(); + + Features emptyFinalizedFeatures = Features.emptyFinalizedFeatures(); + assertTrue(emptyFinalizedFeatures.features().isEmpty()); + assertTrue(emptyFinalizedFeatures.toMap().isEmpty()); + assertEquals(emptyFinalizedFeatures, Features.fromFinalizedFeaturesMap(emptyMap)); + + Features emptySupportedFeatures = Features.emptySupportedFeatures(); + assertTrue(emptySupportedFeatures.features().isEmpty()); + assertTrue(emptySupportedFeatures.toMap().isEmpty()); + assertEquals(emptySupportedFeatures, Features.fromSupportedFeaturesMap(emptyMap)); + } + + @Test + public void testNullFeatures() { + assertThrows( + NullPointerException.class, + () -> Features.finalizedFeatures(null)); + assertThrows( + NullPointerException.class, + () -> Features.supportedFeatures(null)); + } + + @Test + public void testGetAllFeaturesAPI() { + SupportedVersionRange v1 = new SupportedVersionRange((short) 1, (short) 2); + SupportedVersionRange v2 = new SupportedVersionRange((short) 3, (short) 4); + Map allFeatures = + mkMap(mkEntry("feature_1", v1), mkEntry("feature_2", v2)); + Features features = Features.supportedFeatures(allFeatures); + assertEquals(allFeatures, features.features()); + } + + @Test + public void testGetAPI() { + SupportedVersionRange v1 = new SupportedVersionRange((short) 1, (short) 2); + SupportedVersionRange v2 = new SupportedVersionRange((short) 3, (short) 4); + Map allFeatures = mkMap(mkEntry("feature_1", v1), mkEntry("feature_2", v2)); + Features features = Features.supportedFeatures(allFeatures); + assertEquals(v1, features.get("feature_1")); + assertEquals(v2, features.get("feature_2")); + assertNull(features.get("nonexistent_feature")); + } + + @Test + public void testFromFeaturesMapToFeaturesMap() { + SupportedVersionRange v1 = new SupportedVersionRange((short) 1, (short) 2); + SupportedVersionRange v2 = new SupportedVersionRange((short) 3, (short) 4); + Map allFeatures = mkMap(mkEntry("feature_1", v1), mkEntry("feature_2", v2)); + + Features features = Features.supportedFeatures(allFeatures); + + Map> expected = mkMap( + mkEntry("feature_1", mkMap(mkEntry("min_version", (short) 1), mkEntry("max_version", (short) 2))), + mkEntry("feature_2", mkMap(mkEntry("min_version", (short) 3), mkEntry("max_version", (short) 4)))); + assertEquals(expected, features.toMap()); + assertEquals(features, Features.fromSupportedFeaturesMap(expected)); + } + + @Test + public void testFromToFinalizedFeaturesMap() { + FinalizedVersionRange v1 = new FinalizedVersionRange((short) 1, (short) 2); + FinalizedVersionRange v2 = new FinalizedVersionRange((short) 3, (short) 4); + Map allFeatures = mkMap(mkEntry("feature_1", v1), mkEntry("feature_2", v2)); + + Features features = Features.finalizedFeatures(allFeatures); + + Map> expected = mkMap( + mkEntry("feature_1", mkMap(mkEntry("min_version_level", (short) 1), mkEntry("max_version_level", (short) 2))), + 
mkEntry("feature_2", mkMap(mkEntry("min_version_level", (short) 3), mkEntry("max_version_level", (short) 4)))); + assertEquals(expected, features.toMap()); + assertEquals(features, Features.fromFinalizedFeaturesMap(expected)); + } + + @Test + public void testToStringFinalizedFeatures() { + FinalizedVersionRange v1 = new FinalizedVersionRange((short) 1, (short) 2); + FinalizedVersionRange v2 = new FinalizedVersionRange((short) 3, (short) 4); + Map allFeatures = mkMap(mkEntry("feature_1", v1), mkEntry("feature_2", v2)); + + Features features = Features.finalizedFeatures(allFeatures); + + assertEquals( + "Features{(feature_1 -> FinalizedVersionRange[min_version_level:1, max_version_level:2]), (feature_2 -> FinalizedVersionRange[min_version_level:3, max_version_level:4])}", + features.toString()); + } + + @Test + public void testToStringSupportedFeatures() { + SupportedVersionRange v1 = new SupportedVersionRange((short) 1, (short) 2); + SupportedVersionRange v2 = new SupportedVersionRange((short) 3, (short) 4); + Map allFeatures + = mkMap(mkEntry("feature_1", v1), mkEntry("feature_2", v2)); + + Features features = Features.supportedFeatures(allFeatures); + + assertEquals( + "Features{(feature_1 -> SupportedVersionRange[min_version:1, max_version:2]), (feature_2 -> SupportedVersionRange[min_version:3, max_version:4])}", + features.toString()); + } + + @Test + public void testSuppportedFeaturesFromMapFailureWithInvalidMissingMaxVersion() { + // This is invalid because 'max_version' key is missing. + Map> invalidFeatures = mkMap( + mkEntry("feature_1", mkMap(mkEntry("min_version", (short) 1)))); + assertThrows( + IllegalArgumentException.class, + () -> Features.fromSupportedFeaturesMap(invalidFeatures)); + } + + @Test + public void testFinalizedFeaturesFromMapFailureWithInvalidMissingMaxVersionLevel() { + // This is invalid because 'max_version_level' key is missing. + Map> invalidFeatures = mkMap( + mkEntry("feature_1", mkMap(mkEntry("min_version_level", (short) 1)))); + assertThrows( + IllegalArgumentException.class, + () -> Features.fromFinalizedFeaturesMap(invalidFeatures)); + } + + @Test + public void testEquals() { + SupportedVersionRange v1 = new SupportedVersionRange((short) 1, (short) 2); + Map allFeatures = mkMap(mkEntry("feature_1", v1)); + Features features = Features.supportedFeatures(allFeatures); + Features featuresClone = Features.supportedFeatures(allFeatures); + assertTrue(features.equals(featuresClone)); + + SupportedVersionRange v2 = new SupportedVersionRange((short) 1, (short) 3); + Map allFeaturesDifferent = mkMap(mkEntry("feature_1", v2)); + Features featuresDifferent = Features.supportedFeatures(allFeaturesDifferent); + assertFalse(features.equals(featuresDifferent)); + + assertFalse(features.equals(null)); + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/feature/FinalizedVersionRangeTest.java b/clients/src/test/java/org/apache/kafka/common/feature/FinalizedVersionRangeTest.java new file mode 100644 index 0000000..989c4bd --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/feature/FinalizedVersionRangeTest.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.common.feature; + +import java.util.Map; + +import org.junit.jupiter.api.Test; + +import static org.apache.kafka.common.utils.Utils.mkEntry; +import static org.apache.kafka.common.utils.Utils.mkMap; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +/** + * Unit tests for the FinalizedVersionRange class. + * + * Most of the unit tests required for BaseVersionRange are part of the SupportedVersionRangeTest + * suite. This suite only tests behavior very specific to FinalizedVersionRange. + */ +public class FinalizedVersionRangeTest { + + @Test + public void testFromToMap() { + FinalizedVersionRange versionRange = new FinalizedVersionRange((short) 1, (short) 2); + assertEquals(1, versionRange.min()); + assertEquals(2, versionRange.max()); + + Map versionRangeMap = versionRange.toMap(); + assertEquals( + mkMap( + mkEntry("min_version_level", versionRange.min()), + mkEntry("max_version_level", versionRange.max())), + versionRangeMap); + + FinalizedVersionRange newVersionRange = FinalizedVersionRange.fromMap(versionRangeMap); + assertEquals(1, newVersionRange.min()); + assertEquals(2, newVersionRange.max()); + assertEquals(versionRange, newVersionRange); + } + + @Test + public void testToString() { + assertEquals("FinalizedVersionRange[min_version_level:1, max_version_level:1]", new FinalizedVersionRange((short) 1, (short) 1).toString()); + assertEquals("FinalizedVersionRange[min_version_level:1, max_version_level:2]", new FinalizedVersionRange((short) 1, (short) 2).toString()); + } + + @Test + public void testIsCompatibleWith() { + assertFalse(new FinalizedVersionRange((short) 1, (short) 1).isIncompatibleWith(new SupportedVersionRange((short) 1, (short) 1))); + assertFalse(new FinalizedVersionRange((short) 2, (short) 3).isIncompatibleWith(new SupportedVersionRange((short) 1, (short) 4))); + assertFalse(new FinalizedVersionRange((short) 1, (short) 4).isIncompatibleWith(new SupportedVersionRange((short) 1, (short) 4))); + + assertTrue(new FinalizedVersionRange((short) 1, (short) 4).isIncompatibleWith(new SupportedVersionRange((short) 2, (short) 3))); + assertTrue(new FinalizedVersionRange((short) 1, (short) 4).isIncompatibleWith(new SupportedVersionRange((short) 2, (short) 4))); + assertTrue(new FinalizedVersionRange((short) 2, (short) 4).isIncompatibleWith(new SupportedVersionRange((short) 2, (short) 3))); + } + + @Test + public void testMinMax() { + FinalizedVersionRange versionRange = new FinalizedVersionRange((short) 1, (short) 2); + assertEquals(1, versionRange.min()); + assertEquals(2, versionRange.max()); + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/feature/SupportedVersionRangeTest.java b/clients/src/test/java/org/apache/kafka/common/feature/SupportedVersionRangeTest.java new file mode 100644 index 0000000..acf452d --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/feature/SupportedVersionRangeTest.java @@ -0,0 +1,142 @@ +/* + * Licensed to the Apache Software Foundation 
(ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.common.feature; + +import java.util.Map; + +import org.junit.jupiter.api.Test; + +import static org.apache.kafka.common.utils.Utils.mkEntry; +import static org.apache.kafka.common.utils.Utils.mkMap; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +/** + * Unit tests for the SupportedVersionRange class. + * Along the way, this suite also includes extensive tests for the base class BaseVersionRange. + */ +public class SupportedVersionRangeTest { + @Test + public void testFailDueToInvalidParams() { + // min and max can't be < 1. + assertThrows( + IllegalArgumentException.class, + () -> new SupportedVersionRange((short) 0, (short) 0)); + // min can't be < 1. + assertThrows( + IllegalArgumentException.class, + () -> new SupportedVersionRange((short) 0, (short) 1)); + // max can't be < 1. + assertThrows( + IllegalArgumentException.class, + () -> new SupportedVersionRange((short) 1, (short) 0)); + // min can't be > max. + assertThrows( + IllegalArgumentException.class, + () -> new SupportedVersionRange((short) 2, (short) 1)); + } + + @Test + public void testFromToMap() { + SupportedVersionRange versionRange = new SupportedVersionRange((short) 1, (short) 2); + assertEquals(1, versionRange.min()); + assertEquals(2, versionRange.max()); + + Map versionRangeMap = versionRange.toMap(); + assertEquals( + mkMap(mkEntry("min_version", versionRange.min()), mkEntry("max_version", versionRange.max())), + versionRangeMap); + + SupportedVersionRange newVersionRange = SupportedVersionRange.fromMap(versionRangeMap); + assertEquals(1, newVersionRange.min()); + assertEquals(2, newVersionRange.max()); + assertEquals(versionRange, newVersionRange); + } + + @Test + public void testFromMapFailure() { + // min_version can't be < 1. + Map invalidWithBadMinVersion = + mkMap(mkEntry("min_version", (short) 0), mkEntry("max_version", (short) 1)); + assertThrows( + IllegalArgumentException.class, + () -> SupportedVersionRange.fromMap(invalidWithBadMinVersion)); + + // max_version can't be < 1. + Map invalidWithBadMaxVersion = + mkMap(mkEntry("min_version", (short) 1), mkEntry("max_version", (short) 0)); + assertThrows( + IllegalArgumentException.class, + () -> SupportedVersionRange.fromMap(invalidWithBadMaxVersion)); + + // min_version and max_version can't be < 1. + Map invalidWithBadMinMaxVersion = + mkMap(mkEntry("min_version", (short) 0), mkEntry("max_version", (short) 0)); + assertThrows( + IllegalArgumentException.class, + () -> SupportedVersionRange.fromMap(invalidWithBadMinMaxVersion)); + + // min_version can't be > max_version. 
+ Map invalidWithLowerMaxVersion = + mkMap(mkEntry("min_version", (short) 2), mkEntry("max_version", (short) 1)); + assertThrows( + IllegalArgumentException.class, + () -> SupportedVersionRange.fromMap(invalidWithLowerMaxVersion)); + + // min_version key missing. + Map invalidWithMinKeyMissing = + mkMap(mkEntry("max_version", (short) 1)); + assertThrows( + IllegalArgumentException.class, + () -> SupportedVersionRange.fromMap(invalidWithMinKeyMissing)); + + // max_version key missing. + Map invalidWithMaxKeyMissing = + mkMap(mkEntry("min_version", (short) 1)); + assertThrows( + IllegalArgumentException.class, + () -> SupportedVersionRange.fromMap(invalidWithMaxKeyMissing)); + } + + @Test + public void testToString() { + assertEquals( + "SupportedVersionRange[min_version:1, max_version:1]", + new SupportedVersionRange((short) 1, (short) 1).toString()); + assertEquals( + "SupportedVersionRange[min_version:1, max_version:2]", + new SupportedVersionRange((short) 1, (short) 2).toString()); + } + + @Test + public void testEquals() { + SupportedVersionRange tested = new SupportedVersionRange((short) 1, (short) 1); + assertTrue(tested.equals(tested)); + assertFalse(tested.equals(new SupportedVersionRange((short) 1, (short) 2))); + assertFalse(tested.equals(null)); + } + + @Test + public void testMinMax() { + SupportedVersionRange versionRange = new SupportedVersionRange((short) 1, (short) 2); + assertEquals(1, versionRange.min()); + assertEquals(2, versionRange.max()); + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/header/internals/RecordHeadersTest.java b/clients/src/test/java/org/apache/kafka/common/header/internals/RecordHeadersTest.java new file mode 100644 index 0000000..f4813fd --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/header/internals/RecordHeadersTest.java @@ -0,0 +1,236 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.header.internals; + +import org.apache.kafka.common.header.Header; +import org.apache.kafka.common.header.Headers; +import org.junit.jupiter.api.Test; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Iterator; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; + +public class RecordHeadersTest { + + @Test + public void testAdd() { + Headers headers = new RecordHeaders(); + headers.add(new RecordHeader("key", "value".getBytes())); + + Header header = headers.iterator().next(); + assertHeader("key", "value", header); + + headers.add(new RecordHeader("key2", "value2".getBytes())); + + assertHeader("key2", "value2", headers.lastHeader("key2")); + assertEquals(2, getCount(headers)); + } + + @Test + public void testRemove() { + Headers headers = new RecordHeaders(); + headers.add(new RecordHeader("key", "value".getBytes())); + + assertTrue(headers.iterator().hasNext()); + + headers.remove("key"); + + assertFalse(headers.iterator().hasNext()); + } + + @Test + public void testAddRemoveInterleaved() { + Headers headers = new RecordHeaders(); + headers.add(new RecordHeader("key", "value".getBytes())); + headers.add(new RecordHeader("key2", "value2".getBytes())); + + assertTrue(headers.iterator().hasNext()); + + headers.remove("key"); + + assertEquals(1, getCount(headers)); + + headers.add(new RecordHeader("key3", "value3".getBytes())); + + assertNull(headers.lastHeader("key")); + + assertHeader("key2", "value2", headers.lastHeader("key2")); + + assertHeader("key3", "value3", headers.lastHeader("key3")); + + assertEquals(2, getCount(headers)); + + headers.remove("key2"); + + assertNull(headers.lastHeader("key")); + + assertNull(headers.lastHeader("key2")); + + assertHeader("key3", "value3", headers.lastHeader("key3")); + + assertEquals(1, getCount(headers)); + + headers.add(new RecordHeader("key3", "value4".getBytes())); + + assertHeader("key3", "value4", headers.lastHeader("key3")); + + assertEquals(2, getCount(headers)); + + headers.add(new RecordHeader("key", "valueNew".getBytes())); + + assertEquals(3, getCount(headers)); + + + assertHeader("key", "valueNew", headers.lastHeader("key")); + + headers.remove("key3"); + + assertEquals(1, getCount(headers)); + + assertNull(headers.lastHeader("key2")); + + headers.remove("key"); + + assertFalse(headers.iterator().hasNext()); + } + + @Test + public void testLastHeader() { + Headers headers = new RecordHeaders(); + headers.add(new RecordHeader("key", "value".getBytes())); + headers.add(new RecordHeader("key", "value2".getBytes())); + headers.add(new RecordHeader("key", "value3".getBytes())); + + assertHeader("key", "value3", headers.lastHeader("key")); + assertEquals(3, getCount(headers)); + + } + + @Test + public void testReadOnly() throws IOException { + RecordHeaders headers = new RecordHeaders(); + headers.add(new RecordHeader("key", "value".getBytes())); + Iterator
<Header> headerIteratorBeforeClose = headers.iterator(); + headers.setReadOnly(); + try { + headers.add(new RecordHeader("key", "value".getBytes())); + fail("IllegalStateException expected as headers are closed"); + } catch (IllegalStateException ise) { + //expected + } + + try { + headers.remove("key"); + fail("IllegalStateException expected as headers are closed"); + } catch (IllegalStateException ise) { + //expected + } + + try { + Iterator<Header>
headerIterator = headers.iterator(); + headerIterator.next(); + headerIterator.remove(); + fail("IllegalStateException expected as headers are closed"); + } catch (IllegalStateException ise) { + //expected + } + + try { + headerIteratorBeforeClose.next(); + headerIteratorBeforeClose.remove(); + fail("IllegalStateException expected as headers are closed"); + } catch (IllegalStateException ise) { + //expected + } + } + + @Test + public void testHeaders() throws IOException { + RecordHeaders headers = new RecordHeaders(); + headers.add(new RecordHeader("key", "value".getBytes())); + headers.add(new RecordHeader("key1", "key1value".getBytes())); + headers.add(new RecordHeader("key", "value2".getBytes())); + headers.add(new RecordHeader("key2", "key2value".getBytes())); + + + Iterator<Header>
keyHeaders = headers.headers("key").iterator(); + assertHeader("key", "value", keyHeaders.next()); + assertHeader("key", "value2", keyHeaders.next()); + assertFalse(keyHeaders.hasNext()); + + keyHeaders = headers.headers("key1").iterator(); + assertHeader("key1", "key1value", keyHeaders.next()); + assertFalse(keyHeaders.hasNext()); + + keyHeaders = headers.headers("key2").iterator(); + assertHeader("key2", "key2value", keyHeaders.next()); + assertFalse(keyHeaders.hasNext()); + + } + + @Test + public void testNew() throws IOException { + RecordHeaders headers = new RecordHeaders(); + headers.add(new RecordHeader("key", "value".getBytes())); + headers.setReadOnly(); + + RecordHeaders newHeaders = new RecordHeaders(headers); + newHeaders.add(new RecordHeader("key", "value2".getBytes())); + + //Ensure existing headers are not modified + assertHeader("key", "value", headers.lastHeader("key")); + assertEquals(1, getCount(headers)); + + //Ensure new headers are modified + assertHeader("key", "value2", newHeaders.lastHeader("key")); + assertEquals(2, getCount(newHeaders)); + } + + @Test + public void shouldThrowNpeWhenAddingNullHeader() { + final RecordHeaders recordHeaders = new RecordHeaders(); + assertThrows(NullPointerException.class, () -> recordHeaders.add(null)); + } + + @Test + public void shouldThrowNpeWhenAddingCollectionWithNullHeader() { + assertThrows(NullPointerException.class, () -> new RecordHeaders(new Header[1])); + } + + private int getCount(Headers headers) { + int count = 0; + Iterator<Header>
            headerIterator = headers.iterator(); + while (headerIterator.hasNext()) { + headerIterator.next(); + count++; + } + return count; + } + + static void assertHeader(String key, String value, Header actual) { + assertEquals(key, actual.key()); + assertTrue(Arrays.equals(value.getBytes(), actual.value())); + } + +} diff --git a/clients/src/test/java/org/apache/kafka/common/internals/PartitionStatesTest.java b/clients/src/test/java/org/apache/kafka/common/internals/PartitionStatesTest.java new file mode 100644 index 0000000..52f0175 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/internals/PartitionStatesTest.java @@ -0,0 +1,211 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.internals; + +import org.apache.kafka.common.TopicPartition; +import org.junit.jupiter.api.Test; + +import java.util.ArrayList; +import java.util.LinkedHashMap; +import java.util.List; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class PartitionStatesTest { + + @Test + public void testSet() { + PartitionStates states = new PartitionStates<>(); + LinkedHashMap map = createMap(); + states.set(map); + LinkedHashMap expected = new LinkedHashMap<>(); + expected.put(new TopicPartition("foo", 2), "foo 2"); + expected.put(new TopicPartition("foo", 0), "foo 0"); + expected.put(new TopicPartition("blah", 2), "blah 2"); + expected.put(new TopicPartition("blah", 1), "blah 1"); + expected.put(new TopicPartition("baz", 2), "baz 2"); + expected.put(new TopicPartition("baz", 3), "baz 3"); + checkState(states, expected); + + states.set(new LinkedHashMap<>()); + checkState(states, new LinkedHashMap<>()); + } + + private LinkedHashMap createMap() { + LinkedHashMap map = new LinkedHashMap<>(); + map.put(new TopicPartition("foo", 2), "foo 2"); + map.put(new TopicPartition("blah", 2), "blah 2"); + map.put(new TopicPartition("blah", 1), "blah 1"); + map.put(new TopicPartition("baz", 2), "baz 2"); + map.put(new TopicPartition("foo", 0), "foo 0"); + map.put(new TopicPartition("baz", 3), "baz 3"); + return map; + } + + private void checkState(PartitionStates states, LinkedHashMap expected) { + assertEquals(expected.keySet(), states.partitionSet()); + assertEquals(expected.size(), states.size()); + assertEquals(expected, states.partitionStateMap()); + } + + @Test + public void testMoveToEnd() { + PartitionStates states = new PartitionStates<>(); + LinkedHashMap map = createMap(); + states.set(map); + + states.moveToEnd(new TopicPartition("baz", 2)); + LinkedHashMap expected = new LinkedHashMap<>(); + expected.put(new TopicPartition("foo", 2), "foo 2"); + expected.put(new TopicPartition("foo", 0), "foo 0"); + expected.put(new TopicPartition("blah", 2), "blah 2"); + expected.put(new TopicPartition("blah", 1), 
"blah 1"); + expected.put(new TopicPartition("baz", 3), "baz 3"); + expected.put(new TopicPartition("baz", 2), "baz 2"); + checkState(states, expected); + + states.moveToEnd(new TopicPartition("foo", 2)); + expected = new LinkedHashMap<>(); + expected.put(new TopicPartition("foo", 0), "foo 0"); + expected.put(new TopicPartition("blah", 2), "blah 2"); + expected.put(new TopicPartition("blah", 1), "blah 1"); + expected.put(new TopicPartition("baz", 3), "baz 3"); + expected.put(new TopicPartition("baz", 2), "baz 2"); + expected.put(new TopicPartition("foo", 2), "foo 2"); + checkState(states, expected); + + // no-op + states.moveToEnd(new TopicPartition("foo", 2)); + checkState(states, expected); + + // partition doesn't exist + states.moveToEnd(new TopicPartition("baz", 5)); + checkState(states, expected); + + // topic doesn't exist + states.moveToEnd(new TopicPartition("aaa", 2)); + checkState(states, expected); + } + + @Test + public void testUpdateAndMoveToEnd() { + PartitionStates states = new PartitionStates<>(); + LinkedHashMap map = createMap(); + states.set(map); + + states.updateAndMoveToEnd(new TopicPartition("foo", 0), "foo 0 updated"); + LinkedHashMap expected = new LinkedHashMap<>(); + expected.put(new TopicPartition("foo", 2), "foo 2"); + expected.put(new TopicPartition("blah", 2), "blah 2"); + expected.put(new TopicPartition("blah", 1), "blah 1"); + expected.put(new TopicPartition("baz", 2), "baz 2"); + expected.put(new TopicPartition("baz", 3), "baz 3"); + expected.put(new TopicPartition("foo", 0), "foo 0 updated"); + checkState(states, expected); + + states.updateAndMoveToEnd(new TopicPartition("baz", 2), "baz 2 updated"); + expected = new LinkedHashMap<>(); + expected.put(new TopicPartition("foo", 2), "foo 2"); + expected.put(new TopicPartition("blah", 2), "blah 2"); + expected.put(new TopicPartition("blah", 1), "blah 1"); + expected.put(new TopicPartition("baz", 3), "baz 3"); + expected.put(new TopicPartition("foo", 0), "foo 0 updated"); + expected.put(new TopicPartition("baz", 2), "baz 2 updated"); + checkState(states, expected); + + // partition doesn't exist + states.updateAndMoveToEnd(new TopicPartition("baz", 5), "baz 5 new"); + expected = new LinkedHashMap<>(); + expected.put(new TopicPartition("foo", 2), "foo 2"); + expected.put(new TopicPartition("blah", 2), "blah 2"); + expected.put(new TopicPartition("blah", 1), "blah 1"); + expected.put(new TopicPartition("baz", 3), "baz 3"); + expected.put(new TopicPartition("foo", 0), "foo 0 updated"); + expected.put(new TopicPartition("baz", 2), "baz 2 updated"); + expected.put(new TopicPartition("baz", 5), "baz 5 new"); + checkState(states, expected); + + // topic doesn't exist + states.updateAndMoveToEnd(new TopicPartition("aaa", 2), "aaa 2 new"); + expected = new LinkedHashMap<>(); + expected.put(new TopicPartition("foo", 2), "foo 2"); + expected.put(new TopicPartition("blah", 2), "blah 2"); + expected.put(new TopicPartition("blah", 1), "blah 1"); + expected.put(new TopicPartition("baz", 3), "baz 3"); + expected.put(new TopicPartition("foo", 0), "foo 0 updated"); + expected.put(new TopicPartition("baz", 2), "baz 2 updated"); + expected.put(new TopicPartition("baz", 5), "baz 5 new"); + expected.put(new TopicPartition("aaa", 2), "aaa 2 new"); + checkState(states, expected); + } + + @Test + public void testPartitionValues() { + PartitionStates states = new PartitionStates<>(); + LinkedHashMap map = createMap(); + states.set(map); + List expected = new ArrayList<>(); + expected.add("foo 2"); + expected.add("foo 0"); + 
expected.add("blah 2"); + expected.add("blah 1"); + expected.add("baz 2"); + expected.add("baz 3"); + assertEquals(expected, states.partitionStateValues()); + } + + @Test + public void testClear() { + PartitionStates states = new PartitionStates<>(); + LinkedHashMap map = createMap(); + states.set(map); + states.clear(); + checkState(states, new LinkedHashMap()); + } + + @Test + public void testRemove() { + PartitionStates states = new PartitionStates<>(); + LinkedHashMap map = createMap(); + states.set(map); + + states.remove(new TopicPartition("foo", 2)); + LinkedHashMap expected = new LinkedHashMap<>(); + expected.put(new TopicPartition("foo", 0), "foo 0"); + expected.put(new TopicPartition("blah", 2), "blah 2"); + expected.put(new TopicPartition("blah", 1), "blah 1"); + expected.put(new TopicPartition("baz", 2), "baz 2"); + expected.put(new TopicPartition("baz", 3), "baz 3"); + checkState(states, expected); + + states.remove(new TopicPartition("blah", 1)); + expected = new LinkedHashMap<>(); + expected.put(new TopicPartition("foo", 0), "foo 0"); + expected.put(new TopicPartition("blah", 2), "blah 2"); + expected.put(new TopicPartition("baz", 2), "baz 2"); + expected.put(new TopicPartition("baz", 3), "baz 3"); + checkState(states, expected); + + states.remove(new TopicPartition("baz", 3)); + expected = new LinkedHashMap<>(); + expected.put(new TopicPartition("foo", 0), "foo 0"); + expected.put(new TopicPartition("blah", 2), "blah 2"); + expected.put(new TopicPartition("baz", 2), "baz 2"); + checkState(states, expected); + } + +} diff --git a/clients/src/test/java/org/apache/kafka/common/internals/TopicTest.java b/clients/src/test/java/org/apache/kafka/common/internals/TopicTest.java new file mode 100644 index 0000000..9bf237f --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/internals/TopicTest.java @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.internals; + +import org.apache.kafka.common.errors.InvalidTopicException; +import org.apache.kafka.test.TestUtils; +import org.junit.jupiter.api.Test; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; + +public class TopicTest { + + @Test + public void shouldAcceptValidTopicNames() { + String maxLengthString = TestUtils.randomString(249); + String[] validTopicNames = {"valid", "TOPIC", "nAmEs", "ar6", "VaL1d", "_0-9_.", "...", maxLengthString}; + + for (String topicName : validTopicNames) { + Topic.validate(topicName); + } + } + + @Test + public void shouldThrowOnInvalidTopicNames() { + char[] longString = new char[250]; + Arrays.fill(longString, 'a'); + String[] invalidTopicNames = {"", "foo bar", "..", "foo:bar", "foo=bar", ".", new String(longString)}; + + for (String topicName : invalidTopicNames) { + try { + Topic.validate(topicName); + fail("No exception was thrown for topic with invalid name: " + topicName); + } catch (InvalidTopicException e) { + // Good + } + } + } + + @Test + public void shouldRecognizeInvalidCharactersInTopicNames() { + char[] invalidChars = {'/', '\\', ',', '\u0000', ':', '"', '\'', ';', '*', '?', ' ', '\t', '\r', '\n', '='}; + + for (char c : invalidChars) { + String topicName = "Is " + c + "illegal"; + assertFalse(Topic.containsValidPattern(topicName)); + } + } + + @Test + public void testTopicHasCollisionChars() { + List falseTopics = Arrays.asList("start", "end", "middle", "many"); + List trueTopics = Arrays.asList( + ".start", "end.", "mid.dle", ".ma.ny.", + "_start", "end_", "mid_dle", "_ma_ny." + ); + + for (String topic : falseTopics) + assertFalse(Topic.hasCollisionChars(topic)); + + for (String topic : trueTopics) + assertTrue(Topic.hasCollisionChars(topic)); + } + + @Test + public void testTopicHasCollision() { + List periodFirstMiddleLastNone = Arrays.asList(".topic", "to.pic", "topic.", "topic"); + List underscoreFirstMiddleLastNone = Arrays.asList("_topic", "to_pic", "topic_", "topic"); + + // Self + for (String topic : periodFirstMiddleLastNone) + assertTrue(Topic.hasCollision(topic, topic)); + + for (String topic : underscoreFirstMiddleLastNone) + assertTrue(Topic.hasCollision(topic, topic)); + + // Same Position + for (int i = 0; i < periodFirstMiddleLastNone.size(); ++i) + assertTrue(Topic.hasCollision(periodFirstMiddleLastNone.get(i), underscoreFirstMiddleLastNone.get(i))); + + // Different Position + Collections.reverse(underscoreFirstMiddleLastNone); + for (int i = 0; i < periodFirstMiddleLastNone.size(); ++i) + assertFalse(Topic.hasCollision(periodFirstMiddleLastNone.get(i), underscoreFirstMiddleLastNone.get(i))); + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/memory/GarbageCollectedMemoryPoolTest.java b/clients/src/test/java/org/apache/kafka/common/memory/GarbageCollectedMemoryPoolTest.java new file mode 100644 index 0000000..0bf90bf --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/memory/GarbageCollectedMemoryPoolTest.java @@ -0,0 +1,175 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.memory; + +import java.nio.ByteBuffer; +import java.util.concurrent.TimeUnit; + +import org.apache.kafka.common.utils.Utils; + +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + + +public class GarbageCollectedMemoryPoolTest { + + private GarbageCollectedMemoryPool pool; + + @AfterEach + public void releasePool() { + if (pool != null) pool.close(); + } + + @Test + public void testZeroSize() { + assertThrows(IllegalArgumentException.class, + () -> new GarbageCollectedMemoryPool(0, 7, true, null)); + } + + @Test + public void testNegativeSize() { + assertThrows(IllegalArgumentException.class, + () -> new GarbageCollectedMemoryPool(-1, 7, false, null)); + } + + @Test + public void testZeroMaxAllocation() { + assertThrows(IllegalArgumentException.class, + () -> new GarbageCollectedMemoryPool(100, 0, true, null)); + } + + @Test + public void testNegativeMaxAllocation() { + assertThrows(IllegalArgumentException.class, + () -> new GarbageCollectedMemoryPool(100, -1, false, null)); + } + + @Test + public void testMaxAllocationLargerThanSize() { + assertThrows(IllegalArgumentException.class, + () -> new GarbageCollectedMemoryPool(100, 101, true, null)); + } + + @Test + public void testAllocationOverMaxAllocation() { + pool = new GarbageCollectedMemoryPool(1000, 10, false, null); + assertThrows(IllegalArgumentException.class, () -> pool.tryAllocate(11)); + } + + @Test + public void testAllocationZero() { + pool = new GarbageCollectedMemoryPool(1000, 10, true, null); + assertThrows(IllegalArgumentException.class, () -> pool.tryAllocate(0)); + } + + @Test + public void testAllocationNegative() { + pool = new GarbageCollectedMemoryPool(1000, 10, false, null); + assertThrows(IllegalArgumentException.class, () -> pool.tryAllocate(-1)); + } + + @Test + public void testReleaseNull() { + pool = new GarbageCollectedMemoryPool(1000, 10, true, null); + assertThrows(IllegalArgumentException.class, () -> pool.release(null)); + } + + @Test + public void testReleaseForeignBuffer() { + pool = new GarbageCollectedMemoryPool(1000, 10, true, null); + ByteBuffer fellOffATruck = ByteBuffer.allocate(1); + assertThrows(IllegalArgumentException.class, () -> pool.release(fellOffATruck)); + } + + @Test + public void testDoubleFree() { + pool = new GarbageCollectedMemoryPool(1000, 10, false, null); + ByteBuffer buffer = pool.tryAllocate(5); + assertNotNull(buffer); + pool.release(buffer); + assertThrows(IllegalArgumentException.class, () -> pool.release(buffer)); + } + + @Test + public void testAllocationBound() { + pool = new GarbageCollectedMemoryPool(21, 10, false, null); + 
ByteBuffer buf1 = pool.tryAllocate(10); + assertNotNull(buf1); + assertEquals(10, buf1.capacity()); + ByteBuffer buf2 = pool.tryAllocate(10); + assertNotNull(buf2); + assertEquals(10, buf2.capacity()); + ByteBuffer buf3 = pool.tryAllocate(10); + assertNotNull(buf3); + assertEquals(10, buf3.capacity()); + //no more allocations + assertNull(pool.tryAllocate(1)); + //release a buffer + pool.release(buf3); + //now we can have more + ByteBuffer buf4 = pool.tryAllocate(10); + assertNotNull(buf4); + assertEquals(10, buf4.capacity()); + //no more allocations + assertNull(pool.tryAllocate(1)); + } + + @Test + public void testBuffersGarbageCollected() throws Exception { + Runtime runtime = Runtime.getRuntime(); + long maxHeap = runtime.maxMemory(); //in bytes + long maxPool = maxHeap / 2; + long maxSingleAllocation = maxPool / 10; + assertTrue(maxSingleAllocation < Integer.MAX_VALUE / 2); //test JVM running with too much memory for this test logic (?) + pool = new GarbageCollectedMemoryPool(maxPool, (int) maxSingleAllocation, false, null); + + //we will allocate 30 buffers from this pool, which is sized such that at-most + //11 should coexist and 30 do not fit in the JVM memory, proving that: + // 1. buffers were reclaimed and + // 2. the pool registered the reclamation. + + int timeoutSeconds = 30; + long giveUp = System.currentTimeMillis() + TimeUnit.SECONDS.toMillis(timeoutSeconds); + boolean success = false; + + int buffersAllocated = 0; + while (System.currentTimeMillis() < giveUp) { + ByteBuffer buffer = pool.tryAllocate((int) maxSingleAllocation); + if (buffer == null) { + System.gc(); + Thread.sleep(10); + continue; + } + buffersAllocated++; + if (buffersAllocated >= 30) { + success = true; + break; + } + } + + assertTrue(success, "failed to allocate 30 buffers in " + timeoutSeconds + " seconds." + + " buffers allocated: " + buffersAllocated + " heap " + Utils.formatBytes(maxHeap) + + " pool " + Utils.formatBytes(maxPool) + " single allocation " + + Utils.formatBytes(maxSingleAllocation)); + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/message/ApiMessageTypeTest.java b/clients/src/test/java/org/apache/kafka/common/message/ApiMessageTypeTest.java new file mode 100644 index 0000000..7dc6147 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/message/ApiMessageTypeTest.java @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.common.message; + +import org.apache.kafka.common.errors.UnsupportedVersionException; +import org.apache.kafka.common.protocol.types.Schema; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; + +import java.util.HashSet; +import java.util.Set; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.fail; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +@Timeout(120) +public class ApiMessageTypeTest { + + @Test + public void testFromApiKey() { + for (ApiMessageType type : ApiMessageType.values()) { + ApiMessageType type2 = ApiMessageType.fromApiKey(type.apiKey()); + assertEquals(type2, type); + } + } + + @Test + public void testInvalidFromApiKey() { + try { + ApiMessageType.fromApiKey((short) -1); + fail("expected to get an UnsupportedVersionException"); + } catch (UnsupportedVersionException uve) { + // expected + } + } + + @Test + public void testUniqueness() { + Set ids = new HashSet<>(); + Set requestNames = new HashSet<>(); + Set responseNames = new HashSet<>(); + for (ApiMessageType type : ApiMessageType.values()) { + assertFalse(ids.contains(type.apiKey()), + "found two ApiMessageType objects with id " + type.apiKey()); + ids.add(type.apiKey()); + String requestName = type.newRequest().getClass().getSimpleName(); + assertFalse(requestNames.contains(requestName), + "found two ApiMessageType objects with requestName " + requestName); + requestNames.add(requestName); + String responseName = type.newResponse().getClass().getSimpleName(); + assertFalse(responseNames.contains(responseName), + "found two ApiMessageType objects with responseName " + responseName); + responseNames.add(responseName); + } + assertEquals(ApiMessageType.values().length, ids.size()); + assertEquals(ApiMessageType.values().length, requestNames.size()); + assertEquals(ApiMessageType.values().length, responseNames.size()); + } + + @Test + public void testHeaderVersion() { + assertEquals((short) 1, ApiMessageType.PRODUCE.requestHeaderVersion((short) 0)); + assertEquals((short) 0, ApiMessageType.PRODUCE.responseHeaderVersion((short) 0)); + + assertEquals((short) 1, ApiMessageType.PRODUCE.requestHeaderVersion((short) 1)); + assertEquals((short) 0, ApiMessageType.PRODUCE.responseHeaderVersion((short) 1)); + + assertEquals((short) 0, ApiMessageType.CONTROLLED_SHUTDOWN.requestHeaderVersion((short) 0)); + assertEquals((short) 0, ApiMessageType.CONTROLLED_SHUTDOWN.responseHeaderVersion((short) 0)); + + assertEquals((short) 1, ApiMessageType.CONTROLLED_SHUTDOWN.requestHeaderVersion((short) 1)); + assertEquals((short) 0, ApiMessageType.CONTROLLED_SHUTDOWN.responseHeaderVersion((short) 1)); + + assertEquals((short) 1, ApiMessageType.CREATE_TOPICS.requestHeaderVersion((short) 4)); + assertEquals((short) 0, ApiMessageType.CREATE_TOPICS.responseHeaderVersion((short) 4)); + + assertEquals((short) 2, ApiMessageType.CREATE_TOPICS.requestHeaderVersion((short) 5)); + assertEquals((short) 1, ApiMessageType.CREATE_TOPICS.responseHeaderVersion((short) 5)); + } + + /** + * Kafka currently supports direct upgrades from 0.8 to the latest version. As such, it has to support all apis + * starting from version 0 and we must have schemas from the oldest version to the latest. 
+ */ + @Test + public void testAllVersionsHaveSchemas() { + for (ApiMessageType type : ApiMessageType.values()) { + assertEquals(0, type.lowestSupportedVersion()); + + assertEquals(type.requestSchemas().length, type.responseSchemas().length); + for (Schema schema : type.requestSchemas()) + assertNotNull(schema); + for (Schema schema : type.responseSchemas()) + assertNotNull(schema); + + assertEquals(type.highestSupportedVersion() + 1, type.requestSchemas().length); + } + } + + @Test + public void testApiIdsArePositive() { + for (ApiMessageType type : ApiMessageType.values()) + assertTrue(type.apiKey() >= 0); + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/message/MessageTest.java b/clients/src/test/java/org/apache/kafka/common/message/MessageTest.java new file mode 100644 index 0000000..3fcd007 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/message/MessageTest.java @@ -0,0 +1,1234 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.common.message; + +import com.fasterxml.jackson.databind.JsonNode; + +import org.apache.kafka.common.IsolationLevel; +import org.apache.kafka.common.Uuid; +import org.apache.kafka.common.errors.UnsupportedVersionException; +import org.apache.kafka.common.message.AddPartitionsToTxnRequestData.AddPartitionsToTxnTopic; +import org.apache.kafka.common.message.AddPartitionsToTxnRequestData.AddPartitionsToTxnTopicCollection; +import org.apache.kafka.common.message.DescribeClusterResponseData.DescribeClusterBroker; +import org.apache.kafka.common.message.DescribeClusterResponseData.DescribeClusterBrokerCollection; +import org.apache.kafka.common.message.DescribeGroupsResponseData.DescribedGroup; +import org.apache.kafka.common.message.DescribeGroupsResponseData.DescribedGroupMember; +import org.apache.kafka.common.message.JoinGroupResponseData.JoinGroupResponseMember; +import org.apache.kafka.common.message.LeaveGroupResponseData.MemberResponse; +import org.apache.kafka.common.message.ListOffsetsRequestData.ListOffsetsPartition; +import org.apache.kafka.common.message.ListOffsetsRequestData.ListOffsetsTopic; +import org.apache.kafka.common.message.ListOffsetsResponseData.ListOffsetsPartitionResponse; +import org.apache.kafka.common.message.ListOffsetsResponseData.ListOffsetsTopicResponse; +import org.apache.kafka.common.message.OffsetCommitRequestData.OffsetCommitRequestPartition; +import org.apache.kafka.common.message.OffsetCommitRequestData.OffsetCommitRequestTopic; +import org.apache.kafka.common.message.OffsetCommitResponseData.OffsetCommitResponsePartition; +import org.apache.kafka.common.message.OffsetCommitResponseData.OffsetCommitResponseTopic; +import org.apache.kafka.common.message.OffsetFetchRequestData.OffsetFetchRequestGroup; +import 
org.apache.kafka.common.message.OffsetFetchRequestData.OffsetFetchRequestTopic;
+import org.apache.kafka.common.message.OffsetFetchRequestData.OffsetFetchRequestTopics;
+import org.apache.kafka.common.message.OffsetFetchResponseData.OffsetFetchResponseGroup;
+import org.apache.kafka.common.message.OffsetFetchResponseData.OffsetFetchResponsePartition;
+import org.apache.kafka.common.message.OffsetFetchResponseData.OffsetFetchResponsePartitions;
+import org.apache.kafka.common.message.OffsetFetchResponseData.OffsetFetchResponseTopic;
+import org.apache.kafka.common.message.OffsetFetchResponseData.OffsetFetchResponseTopics;
+import org.apache.kafka.common.message.TxnOffsetCommitRequestData.TxnOffsetCommitRequestPartition;
+import org.apache.kafka.common.message.TxnOffsetCommitRequestData.TxnOffsetCommitRequestTopic;
+import org.apache.kafka.common.message.TxnOffsetCommitResponseData.TxnOffsetCommitResponsePartition;
+import org.apache.kafka.common.message.TxnOffsetCommitResponseData.TxnOffsetCommitResponseTopic;
+import org.apache.kafka.common.protocol.ApiKeys;
+import org.apache.kafka.common.protocol.ByteBufferAccessor;
+import org.apache.kafka.common.protocol.Errors;
+import org.apache.kafka.common.protocol.Message;
+import org.apache.kafka.common.protocol.MessageUtil;
+import org.apache.kafka.common.protocol.ObjectSerializationCache;
+import org.apache.kafka.common.protocol.types.RawTaggedField;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.Timeout;
+
+import java.lang.reflect.Method;
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+import java.util.function.Supplier;
+
+import static java.util.Collections.singletonList;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertNotEquals;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+import static org.junit.jupiter.api.Assertions.fail;
+
+@Timeout(120)
+public final class MessageTest {
+
+    private final String memberId = "memberId";
+    private final String instanceId = "instanceId";
+    private final List<Integer> listOfVersionsNonBatchOffsetFetch = Arrays.asList(0, 1, 2, 3, 4, 5, 6, 7);
+
+    @Test
+    public void testAddOffsetsToTxnVersions() throws Exception {
+        testAllMessageRoundTrips(new AddOffsetsToTxnRequestData().
+            setTransactionalId("foobar").
+            setProducerId(0xbadcafebadcafeL).
+            setProducerEpoch((short) 123).
+            setGroupId("baaz"));
+        testAllMessageRoundTrips(new AddOffsetsToTxnResponseData().
+            setThrottleTimeMs(42).
+            setErrorCode((short) 0));
+    }
+
+    @Test
+    public void testAddPartitionsToTxnVersions() throws Exception {
+        testAllMessageRoundTrips(new AddPartitionsToTxnRequestData().
+            setTransactionalId("blah").
+            setProducerId(0xbadcafebadcafeL).
+            setProducerEpoch((short) 30000).
+            setTopics(new AddPartitionsToTxnTopicCollection(singletonList(
+                new AddPartitionsToTxnTopic().
+                    setName("Topic").
+                    setPartitions(singletonList(1))).iterator())));
+    }
+
+    @Test
+    public void testCreateTopicsVersions() throws Exception {
+        testAllMessageRoundTrips(new CreateTopicsRequestData().
+            setTimeoutMs(1000).setTopics(new CreateTopicsRequestData.CreatableTopicCollection()));
+    }
+
+    @Test
+    public void testDescribeAclsRequest() throws Exception {
+        testAllMessageRoundTrips(new DescribeAclsRequestData().
+            setResourceTypeFilter((byte) 42).
+            setResourceNameFilter(null).
+            setPatternTypeFilter((byte) 3).
+            setPrincipalFilter("abc").
+            setHostFilter(null).
+            setOperation((byte) 0).
+            setPermissionType((byte) 0));
+    }
+
+    @Test
+    public void testMetadataVersions() throws Exception {
+        testAllMessageRoundTrips(new MetadataRequestData().setTopics(
+            Arrays.asList(new MetadataRequestData.MetadataRequestTopic().setName("foo"),
+                new MetadataRequestData.MetadataRequestTopic().setName("bar")
+            )));
+        testAllMessageRoundTripsFromVersion((short) 1, new MetadataRequestData().
+            setTopics(null).
+            setAllowAutoTopicCreation(true).
+            setIncludeClusterAuthorizedOperations(false).
+            setIncludeTopicAuthorizedOperations(false));
+        testAllMessageRoundTripsFromVersion((short) 4, new MetadataRequestData().
+            setTopics(null).
+            setAllowAutoTopicCreation(false).
+            setIncludeClusterAuthorizedOperations(false).
+            setIncludeTopicAuthorizedOperations(false));
+    }
+
+    @Test
+    public void testHeartbeatVersions() throws Exception {
+        Supplier<HeartbeatRequestData> newRequest = () -> new HeartbeatRequestData()
+            .setGroupId("groupId")
+            .setMemberId(memberId)
+            .setGenerationId(15);
+        testAllMessageRoundTrips(newRequest.get());
+        testAllMessageRoundTrips(newRequest.get().setGroupInstanceId(null));
+        testAllMessageRoundTripsFromVersion((short) 3, newRequest.get().setGroupInstanceId("instanceId"));
+    }
+
+    @Test
+    public void testJoinGroupRequestVersions() throws Exception {
+        Supplier<JoinGroupRequestData> newRequest = () -> new JoinGroupRequestData()
+            .setGroupId("groupId")
+            .setMemberId(memberId)
+            .setProtocolType("consumer")
+            .setProtocols(new JoinGroupRequestData.JoinGroupRequestProtocolCollection())
+            .setSessionTimeoutMs(10000);
+        testAllMessageRoundTrips(newRequest.get());
+        testAllMessageRoundTripsFromVersion((short) 1, newRequest.get().setRebalanceTimeoutMs(20000));
+        testAllMessageRoundTrips(newRequest.get().setGroupInstanceId(null));
+        testAllMessageRoundTripsFromVersion((short) 5, newRequest.get().setGroupInstanceId("instanceId"));
+    }
+
+    @Test
+    public void testListOffsetsRequestVersions() throws Exception {
+        List<ListOffsetsTopic> v = Collections.singletonList(new ListOffsetsTopic()
+            .setName("topic")
+            .setPartitions(Collections.singletonList(new ListOffsetsPartition()
+                .setPartitionIndex(0)
+                .setTimestamp(123L))));
+        Supplier<ListOffsetsRequestData> newRequest = () -> new ListOffsetsRequestData()
+            .setTopics(v)
+            .setReplicaId(0);
+        testAllMessageRoundTrips(newRequest.get());
+        testAllMessageRoundTripsFromVersion((short) 2, newRequest.get().setIsolationLevel(IsolationLevel.READ_COMMITTED.id()));
+    }
+
+    @Test
+    public void testListOffsetsResponseVersions() throws Exception {
+        ListOffsetsPartitionResponse partition = new ListOffsetsPartitionResponse()
+            .setErrorCode(Errors.NONE.code())
+            .setPartitionIndex(0)
+            .setOldStyleOffsets(Collections.singletonList(321L));
+        List<ListOffsetsTopicResponse> topics = Collections.singletonList(new ListOffsetsTopicResponse()
+            .setName("topic")
+            .setPartitions(Collections.singletonList(partition)));
+        Supplier<ListOffsetsResponseData> response = () -> new ListOffsetsResponseData()
+            .setTopics(topics);
+        for (short version : ApiKeys.LIST_OFFSETS.allVersions()) {
+            ListOffsetsResponseData responseData = response.get();
+            if (version > 0) {
+                responseData.topics().get(0).partitions().get(0)
+                    .setOldStyleOffsets(Collections.emptyList())
+                    .setOffset(456L)
+                    .setTimestamp(123L);
+            }
+            if (version > 1) {
+                responseData.setThrottleTimeMs(1000);
+            }
+            if (version > 3) {
+                partition.setLeaderEpoch(1);
+            }
+            testEquivalentMessageRoundTrip(version, responseData);
+        }
+    }
+
+    @Test
+    public void testJoinGroupResponseVersions() throws Exception {
+        Supplier<JoinGroupResponseData> newResponse = () -> new JoinGroupResponseData()
+            .setMemberId(memberId)
+            .setLeader(memberId)
+            .setGenerationId(1)
+            .setMembers(Collections.singletonList(
+                new JoinGroupResponseMember()
+                    .setMemberId(memberId)
+            ));
+        testAllMessageRoundTrips(newResponse.get());
+        testAllMessageRoundTripsFromVersion((short) 2, newResponse.get().setThrottleTimeMs(1000));
+        testAllMessageRoundTrips(newResponse.get().members().get(0).setGroupInstanceId(null));
+        testAllMessageRoundTripsFromVersion((short) 5, newResponse.get().members().get(0).setGroupInstanceId("instanceId"));
+    }
+
+    @Test
+    public void testLeaveGroupResponseVersions() throws Exception {
+        Supplier<LeaveGroupResponseData> newResponse = () -> new LeaveGroupResponseData()
+            .setErrorCode(Errors.NOT_COORDINATOR.code());
+
+        testAllMessageRoundTrips(newResponse.get());
+        testAllMessageRoundTripsFromVersion((short) 1, newResponse.get().setThrottleTimeMs(1000));
+
+        testAllMessageRoundTripsFromVersion((short) 3, newResponse.get().setMembers(
+            Collections.singletonList(new MemberResponse()
+                .setMemberId(memberId)
+                .setGroupInstanceId(instanceId))
+        ));
+    }
+
+    @Test
+    public void testSyncGroupDefaultGroupInstanceId() throws Exception {
+        Supplier<SyncGroupRequestData> request = () -> new SyncGroupRequestData()
+            .setGroupId("groupId")
+            .setMemberId(memberId)
+            .setGenerationId(15)
+            .setAssignments(new ArrayList<>());
+        testAllMessageRoundTrips(request.get());
+        testAllMessageRoundTrips(request.get().setGroupInstanceId(null));
+        testAllMessageRoundTripsFromVersion((short) 3, request.get().setGroupInstanceId(instanceId));
+    }
+
+    @Test
+    public void testOffsetCommitDefaultGroupInstanceId() throws Exception {
+        testAllMessageRoundTrips(new OffsetCommitRequestData()
+            .setTopics(new ArrayList<>())
+            .setGroupId("groupId"));
+
+        Supplier<OffsetCommitRequestData> request = () -> new OffsetCommitRequestData()
+            .setGroupId("groupId")
+            .setMemberId(memberId)
+            .setTopics(new ArrayList<>())
+            .setGenerationId(15);
+        testAllMessageRoundTripsFromVersion((short) 1, request.get());
+        testAllMessageRoundTripsFromVersion((short) 1, request.get().setGroupInstanceId(null));
+        testAllMessageRoundTripsFromVersion((short) 7, request.get().setGroupInstanceId(instanceId));
+    }
+
+    @Test
+    public void testDescribeGroupsRequestVersions() throws Exception {
+        testAllMessageRoundTrips(new DescribeGroupsRequestData()
+            .setGroups(Collections.singletonList("group"))
+            .setIncludeAuthorizedOperations(false));
+    }
+
+    @Test
+    public void testDescribeGroupsResponseVersions() throws Exception {
+        DescribedGroupMember baseMember = new DescribedGroupMember()
+            .setMemberId(memberId);
+
+        DescribedGroup baseGroup = new DescribedGroup()
+            .setGroupId("group")
+            .setGroupState("Stable").setErrorCode(Errors.NONE.code())
+            .setMembers(Collections.singletonList(baseMember))
+            .setProtocolType("consumer");
+        DescribeGroupsResponseData baseResponse = new DescribeGroupsResponseData()
+            .setGroups(Collections.singletonList(baseGroup));
+        testAllMessageRoundTrips(baseResponse);
+
+        testAllMessageRoundTripsFromVersion((short) 1, baseResponse.setThrottleTimeMs(10));
+
+        baseGroup.setAuthorizedOperations(1);
+        testAllMessageRoundTripsFromVersion((short) 3, baseResponse);
+
+        baseMember.setGroupInstanceId(instanceId);
+        testAllMessageRoundTripsFromVersion((short) 4, baseResponse);
+    }
+
+    @Test
+    public void testDescribeClusterRequestVersions() throws Exception {
+        testAllMessageRoundTrips(new DescribeClusterRequestData()
+            .setIncludeClusterAuthorizedOperations(true));
+    }
+
+    @Test
+    public void testDescribeClusterResponseVersions() throws Exception {
+
DescribeClusterResponseData data = new DescribeClusterResponseData() + .setBrokers(new DescribeClusterBrokerCollection( + Collections.singletonList(new DescribeClusterBroker() + .setBrokerId(1) + .setHost("localhost") + .setPort(9092) + .setRack("rack1")).iterator())) + .setClusterId("clusterId") + .setControllerId(1) + .setClusterAuthorizedOperations(10); + + testAllMessageRoundTrips(data); + } + + @Test + public void testGroupInstanceIdIgnorableInDescribeGroupsResponse() throws Exception { + DescribeGroupsResponseData responseWithGroupInstanceId = + new DescribeGroupsResponseData() + .setGroups(Collections.singletonList( + new DescribedGroup() + .setGroupId("group") + .setGroupState("Stable") + .setErrorCode(Errors.NONE.code()) + .setMembers(Collections.singletonList( + new DescribedGroupMember() + .setMemberId(memberId) + .setGroupInstanceId(instanceId))) + .setProtocolType("consumer") + )); + + DescribeGroupsResponseData expectedResponse = responseWithGroupInstanceId.duplicate(); + // Unset GroupInstanceId + expectedResponse.groups().get(0).members().get(0).setGroupInstanceId(null); + + testAllMessageRoundTripsBeforeVersion((short) 4, responseWithGroupInstanceId, expectedResponse); + } + + @Test + public void testThrottleTimeIgnorableInDescribeGroupsResponse() throws Exception { + DescribeGroupsResponseData responseWithGroupInstanceId = + new DescribeGroupsResponseData() + .setGroups(Collections.singletonList( + new DescribedGroup() + .setGroupId("group") + .setGroupState("Stable") + .setErrorCode(Errors.NONE.code()) + .setMembers(Collections.singletonList( + new DescribedGroupMember() + .setMemberId(memberId))) + .setProtocolType("consumer") + )) + .setThrottleTimeMs(10); + + DescribeGroupsResponseData expectedResponse = responseWithGroupInstanceId.duplicate(); + // Unset throttle time + expectedResponse.setThrottleTimeMs(0); + + testAllMessageRoundTripsBeforeVersion((short) 1, responseWithGroupInstanceId, expectedResponse); + } + + @Test + public void testOffsetForLeaderEpochVersions() throws Exception { + // Version 2 adds optional current leader epoch + OffsetForLeaderEpochRequestData.OffsetForLeaderPartition partitionDataNoCurrentEpoch = + new OffsetForLeaderEpochRequestData.OffsetForLeaderPartition() + .setPartition(0) + .setLeaderEpoch(3); + OffsetForLeaderEpochRequestData.OffsetForLeaderPartition partitionDataWithCurrentEpoch = + new OffsetForLeaderEpochRequestData.OffsetForLeaderPartition() + .setPartition(0) + .setLeaderEpoch(3) + .setCurrentLeaderEpoch(5); + OffsetForLeaderEpochRequestData data = new OffsetForLeaderEpochRequestData(); + data.topics().add(new OffsetForLeaderEpochRequestData.OffsetForLeaderTopic() + .setTopic("foo") + .setPartitions(singletonList(partitionDataNoCurrentEpoch))); + + testAllMessageRoundTrips(data); + testAllMessageRoundTripsBeforeVersion((short) 2, partitionDataWithCurrentEpoch, partitionDataNoCurrentEpoch); + testAllMessageRoundTripsFromVersion((short) 2, partitionDataWithCurrentEpoch); + + // Version 3 adds the optional replica Id field + testAllMessageRoundTripsFromVersion((short) 3, new OffsetForLeaderEpochRequestData().setReplicaId(5)); + testAllMessageRoundTripsBeforeVersion((short) 3, + new OffsetForLeaderEpochRequestData().setReplicaId(5), + new OffsetForLeaderEpochRequestData()); + testAllMessageRoundTripsBeforeVersion((short) 3, + new OffsetForLeaderEpochRequestData().setReplicaId(5), + new OffsetForLeaderEpochRequestData().setReplicaId(-2)); + } + + @Test + public void testLeaderAndIsrVersions() throws Exception { + // Version 3 adds 
two new fields - AddingReplicas and RemovingReplicas + LeaderAndIsrRequestData.LeaderAndIsrTopicState partitionStateNoAddingRemovingReplicas = + new LeaderAndIsrRequestData.LeaderAndIsrTopicState() + .setTopicName("topic") + .setPartitionStates(Collections.singletonList( + new LeaderAndIsrRequestData.LeaderAndIsrPartitionState() + .setPartitionIndex(0) + .setReplicas(Collections.singletonList(0)) + )); + LeaderAndIsrRequestData.LeaderAndIsrTopicState partitionStateWithAddingRemovingReplicas = + new LeaderAndIsrRequestData.LeaderAndIsrTopicState() + .setTopicName("topic") + .setPartitionStates(Collections.singletonList( + new LeaderAndIsrRequestData.LeaderAndIsrPartitionState() + .setPartitionIndex(0) + .setReplicas(Collections.singletonList(0)) + .setAddingReplicas(Collections.singletonList(1)) + .setRemovingReplicas(Collections.singletonList(1)) + )); + testAllMessageRoundTripsBetweenVersions( + (short) 2, + (short) 3, + new LeaderAndIsrRequestData().setTopicStates(Collections.singletonList(partitionStateWithAddingRemovingReplicas)), + new LeaderAndIsrRequestData().setTopicStates(Collections.singletonList(partitionStateNoAddingRemovingReplicas))); + testAllMessageRoundTripsFromVersion((short) 3, new LeaderAndIsrRequestData().setTopicStates(Collections.singletonList(partitionStateWithAddingRemovingReplicas))); + } + + @Test + public void testOffsetCommitRequestVersions() throws Exception { + String groupId = "groupId"; + String topicName = "topic"; + String metadata = "metadata"; + int partition = 2; + int offset = 100; + + testAllMessageRoundTrips(new OffsetCommitRequestData() + .setGroupId(groupId) + .setTopics(Collections.singletonList( + new OffsetCommitRequestTopic() + .setName(topicName) + .setPartitions(Collections.singletonList( + new OffsetCommitRequestPartition() + .setPartitionIndex(partition) + .setCommittedMetadata(metadata) + .setCommittedOffset(offset) + ))))); + + Supplier request = + () -> new OffsetCommitRequestData() + .setGroupId(groupId) + .setMemberId("memberId") + .setGroupInstanceId("instanceId") + .setTopics(Collections.singletonList( + new OffsetCommitRequestTopic() + .setName(topicName) + .setPartitions(Collections.singletonList( + new OffsetCommitRequestPartition() + .setPartitionIndex(partition) + .setCommittedLeaderEpoch(10) + .setCommittedMetadata(metadata) + .setCommittedOffset(offset) + .setCommitTimestamp(20) + )))) + .setRetentionTimeMs(20); + + for (short version : ApiKeys.OFFSET_COMMIT.allVersions()) { + OffsetCommitRequestData requestData = request.get(); + if (version < 1) { + requestData.setMemberId(""); + requestData.setGenerationId(-1); + } + + if (version != 1) { + requestData.topics().get(0).partitions().get(0).setCommitTimestamp(-1); + } + + if (version < 2 || version > 4) { + requestData.setRetentionTimeMs(-1); + } + + if (version < 6) { + requestData.topics().get(0).partitions().get(0).setCommittedLeaderEpoch(-1); + } + + if (version < 7) { + requestData.setGroupInstanceId(null); + } + + if (version == 1) { + testEquivalentMessageRoundTrip(version, requestData); + } else if (version >= 2 && version <= 4) { + testAllMessageRoundTripsBetweenVersions(version, (short) 5, requestData, requestData); + } else { + testAllMessageRoundTripsFromVersion(version, requestData); + } + } + } + + @Test + public void testOffsetCommitResponseVersions() throws Exception { + Supplier response = + () -> new OffsetCommitResponseData() + .setTopics( + singletonList( + new OffsetCommitResponseTopic() + .setName("topic") + .setPartitions(singletonList( + new 
OffsetCommitResponsePartition() + .setPartitionIndex(1) + .setErrorCode(Errors.UNKNOWN_MEMBER_ID.code()) + )) + ) + ) + .setThrottleTimeMs(20); + + for (short version : ApiKeys.OFFSET_COMMIT.allVersions()) { + OffsetCommitResponseData responseData = response.get(); + if (version < 3) { + responseData.setThrottleTimeMs(0); + } + testAllMessageRoundTripsFromVersion(version, responseData); + } + } + + @Test + public void testTxnOffsetCommitRequestVersions() throws Exception { + String groupId = "groupId"; + String topicName = "topic"; + String metadata = "metadata"; + String txnId = "transactionalId"; + int producerId = 25; + short producerEpoch = 10; + String instanceId = "instance"; + String memberId = "member"; + int generationId = 1; + + int partition = 2; + int offset = 100; + + testAllMessageRoundTrips(new TxnOffsetCommitRequestData() + .setGroupId(groupId) + .setTransactionalId(txnId) + .setProducerId(producerId) + .setProducerEpoch(producerEpoch) + .setTopics(Collections.singletonList( + new TxnOffsetCommitRequestTopic() + .setName(topicName) + .setPartitions(Collections.singletonList( + new TxnOffsetCommitRequestPartition() + .setPartitionIndex(partition) + .setCommittedMetadata(metadata) + .setCommittedOffset(offset) + ))))); + + Supplier request = + () -> new TxnOffsetCommitRequestData() + .setGroupId(groupId) + .setTransactionalId(txnId) + .setProducerId(producerId) + .setProducerEpoch(producerEpoch) + .setGroupInstanceId(instanceId) + .setMemberId(memberId) + .setGenerationId(generationId) + .setTopics(Collections.singletonList( + new TxnOffsetCommitRequestTopic() + .setName(topicName) + .setPartitions(Collections.singletonList( + new TxnOffsetCommitRequestPartition() + .setPartitionIndex(partition) + .setCommittedLeaderEpoch(10) + .setCommittedMetadata(metadata) + .setCommittedOffset(offset) + )))); + + for (short version : ApiKeys.TXN_OFFSET_COMMIT.allVersions()) { + TxnOffsetCommitRequestData requestData = request.get(); + if (version < 2) { + requestData.topics().get(0).partitions().get(0).setCommittedLeaderEpoch(-1); + } + + if (version < 3) { + final short finalVersion = version; + assertThrows(UnsupportedVersionException.class, () -> testEquivalentMessageRoundTrip(finalVersion, requestData)); + requestData.setGroupInstanceId(null); + assertThrows(UnsupportedVersionException.class, () -> testEquivalentMessageRoundTrip(finalVersion, requestData)); + requestData.setMemberId(""); + assertThrows(UnsupportedVersionException.class, () -> testEquivalentMessageRoundTrip(finalVersion, requestData)); + requestData.setGenerationId(-1); + } + + testAllMessageRoundTripsFromVersion(version, requestData); + } + } + + @Test + public void testTxnOffsetCommitResponseVersions() throws Exception { + testAllMessageRoundTrips( + new TxnOffsetCommitResponseData() + .setTopics( + singletonList( + new TxnOffsetCommitResponseTopic() + .setName("topic") + .setPartitions(singletonList( + new TxnOffsetCommitResponsePartition() + .setPartitionIndex(1) + .setErrorCode(Errors.UNKNOWN_MEMBER_ID.code()) + )) + ) + ) + .setThrottleTimeMs(20)); + } + + @Test + public void testOffsetFetchV0ToV7() throws Exception { + String groupId = "groupId"; + String topicName = "topic"; + + List topics = Collections.singletonList( + new OffsetFetchRequestTopic() + .setName(topicName) + .setPartitionIndexes(Collections.singletonList(5))); + testAllMessageRoundTripsOffsetFetchV0ToV7(new OffsetFetchRequestData() + .setTopics(new ArrayList<>()) + .setGroupId(groupId)); + + testAllMessageRoundTripsOffsetFetchV0ToV7(new 
OffsetFetchRequestData() + .setGroupId(groupId) + .setTopics(topics)); + + OffsetFetchRequestData allPartitionData = new OffsetFetchRequestData() + .setGroupId(groupId) + .setTopics(null); + + OffsetFetchRequestData requireStableData = new OffsetFetchRequestData() + .setGroupId(groupId) + .setTopics(topics) + .setRequireStable(true); + + for (int version : listOfVersionsNonBatchOffsetFetch) { + final short finalVersion = (short) version; + if (version < 2) { + assertThrows(NullPointerException.class, () -> testAllMessageRoundTripsOffsetFetchFromVersionV0ToV7(finalVersion, allPartitionData)); + } else { + testAllMessageRoundTripsOffsetFetchFromVersionV0ToV7((short) version, allPartitionData); + } + + if (version < 7) { + assertThrows(UnsupportedVersionException.class, () -> testAllMessageRoundTripsOffsetFetchFromVersionV0ToV7(finalVersion, requireStableData)); + } else { + testAllMessageRoundTripsOffsetFetchFromVersionV0ToV7(finalVersion, requireStableData); + } + } + + Supplier response = + () -> new OffsetFetchResponseData() + .setTopics(Collections.singletonList( + new OffsetFetchResponseTopic() + .setName(topicName) + .setPartitions(Collections.singletonList( + new OffsetFetchResponsePartition() + .setPartitionIndex(5) + .setMetadata(null) + .setCommittedOffset(100) + .setCommittedLeaderEpoch(3) + .setErrorCode(Errors.UNKNOWN_TOPIC_OR_PARTITION.code()))))) + .setErrorCode(Errors.NOT_COORDINATOR.code()) + .setThrottleTimeMs(10); + for (int version : listOfVersionsNonBatchOffsetFetch) { + OffsetFetchResponseData responseData = response.get(); + if (version <= 1) { + responseData.setErrorCode(Errors.NONE.code()); + } + + if (version <= 2) { + responseData.setThrottleTimeMs(0); + } + + if (version <= 4) { + responseData.topics().get(0).partitions().get(0).setCommittedLeaderEpoch(-1); + } + + testAllMessageRoundTripsOffsetFetchFromVersionV0ToV7((short) version, responseData); + } + } + + private void testAllMessageRoundTripsOffsetFetchV0ToV7(Message message) throws Exception { + testDuplication(message); + testAllMessageRoundTripsOffsetFetchFromVersionV0ToV7(message.lowestSupportedVersion(), message); + } + + private void testAllMessageRoundTripsOffsetFetchFromVersionV0ToV7(short fromVersion, + Message message) throws Exception { + for (short version = fromVersion; version <= 7; version++) { + testEquivalentMessageRoundTrip(version, message); + } + } + + @Test + public void testOffsetFetchV8AndAboveSingleGroup() throws Exception { + String groupId = "groupId"; + String topicName = "topic"; + + List topic = Collections.singletonList( + new OffsetFetchRequestTopics() + .setName(topicName) + .setPartitionIndexes(Collections.singletonList(5))); + + OffsetFetchRequestData allPartitionData = new OffsetFetchRequestData() + .setGroups(Collections.singletonList( + new OffsetFetchRequestGroup() + .setGroupId(groupId) + .setTopics(null))); + + OffsetFetchRequestData specifiedPartitionData = new OffsetFetchRequestData() + .setGroups(Collections.singletonList( + new OffsetFetchRequestGroup() + .setGroupId(groupId) + .setTopics(topic))) + .setRequireStable(true); + + testAllMessageRoundTripsOffsetFetchV8AndAbove(allPartitionData); + testAllMessageRoundTripsOffsetFetchV8AndAbove(specifiedPartitionData); + + for (short version : ApiKeys.OFFSET_FETCH.allVersions()) { + if (version >= 8) { + testAllMessageRoundTripsOffsetFetchFromVersionV8AndAbove(version, specifiedPartitionData); + testAllMessageRoundTripsOffsetFetchFromVersionV8AndAbove(version, allPartitionData); + } + } + + Supplier response = + () -> new 
OffsetFetchResponseData() + .setGroups(Collections.singletonList( + new OffsetFetchResponseGroup() + .setGroupId(groupId) + .setTopics(Collections.singletonList( + new OffsetFetchResponseTopics() + .setPartitions(Collections.singletonList( + new OffsetFetchResponsePartitions() + .setPartitionIndex(5) + .setMetadata(null) + .setCommittedOffset(100) + .setCommittedLeaderEpoch(3) + .setErrorCode(Errors.UNKNOWN_TOPIC_OR_PARTITION.code()))))) + .setErrorCode(Errors.NOT_COORDINATOR.code()))) + .setThrottleTimeMs(10); + for (short version : ApiKeys.OFFSET_FETCH.allVersions()) { + if (version >= 8) { + OffsetFetchResponseData responseData = response.get(); + testAllMessageRoundTripsOffsetFetchFromVersionV8AndAbove(version, responseData); + } + } + } + + @Test + public void testOffsetFetchV8AndAbove() throws Exception { + String groupOne = "group1"; + String groupTwo = "group2"; + String groupThree = "group3"; + String groupFour = "group4"; + String groupFive = "group5"; + String topic1 = "topic1"; + String topic2 = "topic2"; + String topic3 = "topic3"; + + OffsetFetchRequestTopics topicOne = new OffsetFetchRequestTopics() + .setName(topic1) + .setPartitionIndexes(Collections.singletonList(5)); + OffsetFetchRequestTopics topicTwo = new OffsetFetchRequestTopics() + .setName(topic2) + .setPartitionIndexes(Collections.singletonList(10)); + OffsetFetchRequestTopics topicThree = new OffsetFetchRequestTopics() + .setName(topic3) + .setPartitionIndexes(Collections.singletonList(15)); + + List groupOneTopics = singletonList(topicOne); + OffsetFetchRequestGroup group1 = + new OffsetFetchRequestGroup() + .setGroupId(groupOne) + .setTopics(groupOneTopics); + + List groupTwoTopics = Arrays.asList(topicOne, topicTwo); + OffsetFetchRequestGroup group2 = + new OffsetFetchRequestGroup() + .setGroupId(groupTwo) + .setTopics(groupTwoTopics); + + List groupThreeTopics = Arrays.asList(topicOne, topicTwo, topicThree); + OffsetFetchRequestGroup group3 = + new OffsetFetchRequestGroup() + .setGroupId(groupThree) + .setTopics(groupThreeTopics); + + OffsetFetchRequestGroup group4 = + new OffsetFetchRequestGroup() + .setGroupId(groupFour) + .setTopics(null); + + OffsetFetchRequestGroup group5 = + new OffsetFetchRequestGroup() + .setGroupId(groupFive) + .setTopics(null); + + OffsetFetchRequestData requestData = new OffsetFetchRequestData() + .setGroups(Arrays.asList(group1, group2, group3, group4, group5)) + .setRequireStable(true); + + testAllMessageRoundTripsOffsetFetchV8AndAbove(requestData); + + testAllMessageRoundTripsOffsetFetchV8AndAbove(requestData.setRequireStable(false)); + + + for (short version : ApiKeys.OFFSET_FETCH.allVersions()) { + if (version >= 8) { + testAllMessageRoundTripsOffsetFetchFromVersionV8AndAbove(version, requestData); + } + } + + OffsetFetchResponseTopics responseTopic1 = + new OffsetFetchResponseTopics() + .setName(topic1) + .setPartitions(Collections.singletonList( + new OffsetFetchResponsePartitions() + .setPartitionIndex(5) + .setMetadata(null) + .setCommittedOffset(100) + .setCommittedLeaderEpoch(3) + .setErrorCode(Errors.UNKNOWN_TOPIC_OR_PARTITION.code()))); + OffsetFetchResponseTopics responseTopic2 = + new OffsetFetchResponseTopics() + .setName(topic2) + .setPartitions(Collections.singletonList( + new OffsetFetchResponsePartitions() + .setPartitionIndex(10) + .setMetadata("foo") + .setCommittedOffset(200) + .setCommittedLeaderEpoch(2) + .setErrorCode(Errors.TOPIC_AUTHORIZATION_FAILED.code()))); + OffsetFetchResponseTopics responseTopic3 = + new OffsetFetchResponseTopics() + 
.setName(topic3) + .setPartitions(Collections.singletonList( + new OffsetFetchResponsePartitions() + .setPartitionIndex(15) + .setMetadata("bar") + .setCommittedOffset(300) + .setCommittedLeaderEpoch(1) + .setErrorCode(Errors.GROUP_AUTHORIZATION_FAILED.code()))); + + OffsetFetchResponseGroup responseGroup1 = + new OffsetFetchResponseGroup() + .setGroupId(groupOne) + .setTopics(Collections.singletonList(responseTopic1)) + .setErrorCode(Errors.NOT_COORDINATOR.code()); + OffsetFetchResponseGroup responseGroup2 = + new OffsetFetchResponseGroup() + .setGroupId(groupTwo) + .setTopics(Arrays.asList(responseTopic1, responseTopic2)) + .setErrorCode(Errors.COORDINATOR_LOAD_IN_PROGRESS.code()); + OffsetFetchResponseGroup responseGroup3 = + new OffsetFetchResponseGroup() + .setGroupId(groupThree) + .setTopics(Arrays.asList(responseTopic1, responseTopic2, responseTopic3)) + .setErrorCode(Errors.NONE.code()); + OffsetFetchResponseGroup responseGroup4 = + new OffsetFetchResponseGroup() + .setGroupId(groupFour) + .setTopics(Arrays.asList(responseTopic1, responseTopic2, responseTopic3)) + .setErrorCode(Errors.NONE.code()); + OffsetFetchResponseGroup responseGroup5 = + new OffsetFetchResponseGroup() + .setGroupId(groupFive) + .setTopics(Arrays.asList(responseTopic1, responseTopic2, responseTopic3)) + .setErrorCode(Errors.NONE.code()); + + Supplier response = + () -> new OffsetFetchResponseData() + .setGroups(Arrays.asList(responseGroup1, responseGroup2, responseGroup3, + responseGroup4, responseGroup5)) + .setThrottleTimeMs(10); + for (short version : ApiKeys.OFFSET_FETCH.allVersions()) { + if (version >= 8) { + OffsetFetchResponseData responseData = response.get(); + testAllMessageRoundTripsOffsetFetchFromVersionV8AndAbove(version, responseData); + } + } + } + + private void testAllMessageRoundTripsOffsetFetchV8AndAbove(Message message) throws Exception { + testDuplication(message); + testAllMessageRoundTripsOffsetFetchFromVersionV8AndAbove((short) 8, message); + } + + private void testAllMessageRoundTripsOffsetFetchFromVersionV8AndAbove(short fromVersion, Message message) throws Exception { + for (short version = fromVersion; version <= message.highestSupportedVersion(); version++) { + testEquivalentMessageRoundTrip(version, message); + } + } + + @Test + public void testProduceResponseVersions() throws Exception { + String topicName = "topic"; + int partitionIndex = 0; + short errorCode = Errors.INVALID_TOPIC_EXCEPTION.code(); + long baseOffset = 12L; + int throttleTimeMs = 1234; + long logAppendTimeMs = 1234L; + long logStartOffset = 1234L; + int batchIndex = 0; + String batchIndexErrorMessage = "error message"; + String errorMessage = "global error message"; + + testAllMessageRoundTrips(new ProduceResponseData() + .setResponses(new ProduceResponseData.TopicProduceResponseCollection(singletonList( + new ProduceResponseData.TopicProduceResponse() + .setName(topicName) + .setPartitionResponses(singletonList( + new ProduceResponseData.PartitionProduceResponse() + .setIndex(partitionIndex) + .setErrorCode(errorCode) + .setBaseOffset(baseOffset)))).iterator()))); + + Supplier response = () -> new ProduceResponseData() + .setResponses(new ProduceResponseData.TopicProduceResponseCollection(singletonList( + new ProduceResponseData.TopicProduceResponse() + .setName(topicName) + .setPartitionResponses(singletonList( + new ProduceResponseData.PartitionProduceResponse() + .setIndex(partitionIndex) + .setErrorCode(errorCode) + .setBaseOffset(baseOffset) + .setLogAppendTimeMs(logAppendTimeMs) + 
.setLogStartOffset(logStartOffset) + .setRecordErrors(singletonList( + new ProduceResponseData.BatchIndexAndErrorMessage() + .setBatchIndex(batchIndex) + .setBatchIndexErrorMessage(batchIndexErrorMessage))) + .setErrorMessage(errorMessage)))).iterator())) + .setThrottleTimeMs(throttleTimeMs); + + for (short version : ApiKeys.PRODUCE.allVersions()) { + ProduceResponseData responseData = response.get(); + + if (version < 8) { + responseData.responses().iterator().next().partitionResponses().get(0).setRecordErrors(Collections.emptyList()); + responseData.responses().iterator().next().partitionResponses().get(0).setErrorMessage(null); + } + + if (version < 5) { + responseData.responses().iterator().next().partitionResponses().get(0).setLogStartOffset(-1); + } + + if (version < 2) { + responseData.responses().iterator().next().partitionResponses().get(0).setLogAppendTimeMs(-1); + } + + if (version < 1) { + responseData.setThrottleTimeMs(0); + } + + if (version >= 3 && version <= 4) { + testAllMessageRoundTripsBetweenVersions(version, (short) 5, responseData, responseData); + } else if (version >= 6 && version <= 7) { + testAllMessageRoundTripsBetweenVersions(version, (short) 8, responseData, responseData); + } else { + testEquivalentMessageRoundTrip(version, responseData); + } + } + } + + @Test + public void defaultValueShouldBeWritable() { + for (short version = SimpleExampleMessageData.LOWEST_SUPPORTED_VERSION; version <= SimpleExampleMessageData.HIGHEST_SUPPORTED_VERSION; ++version) { + MessageUtil.toByteBuffer(new SimpleExampleMessageData(), version); + } + } + + @Test + public void testSimpleMessage() throws Exception { + final SimpleExampleMessageData message = new SimpleExampleMessageData(); + message.setMyStruct(new SimpleExampleMessageData.MyStruct().setStructId(25).setArrayInStruct( + Collections.singletonList(new SimpleExampleMessageData.StructArray().setArrayFieldId(20)) + )); + message.setMyTaggedStruct(new SimpleExampleMessageData.TaggedStruct().setStructId("abc")); + + message.setProcessId(Uuid.randomUuid()); + message.setMyNullableString("notNull"); + message.setMyInt16((short) 3); + message.setMyString("test string"); + SimpleExampleMessageData duplicate = message.duplicate(); + assertEquals(duplicate, message); + assertEquals(message, duplicate); + duplicate.setMyTaggedIntArray(Collections.singletonList(123)); + assertNotEquals(duplicate, message); + assertNotEquals(message, duplicate); + + testAllMessageRoundTripsFromVersion((short) 2, message); + } + + private void testAllMessageRoundTrips(Message message) throws Exception { + testDuplication(message); + testAllMessageRoundTripsFromVersion(message.lowestSupportedVersion(), message); + } + + private void testDuplication(Message message) { + Message duplicate = message.duplicate(); + assertEquals(duplicate, message); + assertEquals(message, duplicate); + assertEquals(duplicate.hashCode(), message.hashCode()); + assertEquals(message.hashCode(), duplicate.hashCode()); + } + + private void testAllMessageRoundTripsBeforeVersion(short beforeVersion, Message message, Message expected) throws Exception { + testAllMessageRoundTripsBetweenVersions((short) 0, beforeVersion, message, expected); + } + + /** + * @param startVersion - the version we want to start at, inclusive + * @param endVersion - the version we want to end at, exclusive + */ + private void testAllMessageRoundTripsBetweenVersions(short startVersion, short endVersion, Message message, Message expected) throws Exception { + for (short version = startVersion; version < 
endVersion; version++) { + testMessageRoundTrip(version, message, expected); + } + } + + private void testAllMessageRoundTripsFromVersion(short fromVersion, Message message) throws Exception { + for (short version = fromVersion; version <= message.highestSupportedVersion(); version++) { + testEquivalentMessageRoundTrip(version, message); + } + } + + private void testMessageRoundTrip(short version, Message message, Message expected) throws Exception { + testByteBufferRoundTrip(version, message, expected); + } + + private void testEquivalentMessageRoundTrip(short version, Message message) throws Exception { + testByteBufferRoundTrip(version, message, message); + testJsonRoundTrip(version, message, message); + } + + private void testByteBufferRoundTrip(short version, Message message, Message expected) throws Exception { + ObjectSerializationCache cache = new ObjectSerializationCache(); + int size = message.size(cache, version); + ByteBuffer buf = ByteBuffer.allocate(size); + ByteBufferAccessor byteBufferAccessor = new ByteBufferAccessor(buf); + message.write(byteBufferAccessor, cache, version); + assertEquals(size, buf.position(), "The result of the size function does not match the number of bytes " + + "written for version " + version); + Message message2 = message.getClass().getConstructor().newInstance(); + buf.flip(); + message2.read(byteBufferAccessor, version); + assertEquals(size, buf.position(), "The result of the size function does not match the number of bytes " + + "read back in for version " + version); + assertEquals(expected, message2, "The message object created after a round trip did not match for " + + "version " + version); + assertEquals(expected.hashCode(), message2.hashCode()); + assertEquals(expected.toString(), message2.toString()); + } + + private void testJsonRoundTrip(short version, Message message, Message expected) throws Exception { + String jsonConverter = jsonConverterTypeName(message.getClass().getTypeName()); + Class converter = Class.forName(jsonConverter); + Method writeMethod = converter.getMethod("write", message.getClass(), short.class); + JsonNode jsonNode = (JsonNode) writeMethod.invoke(null, message, version); + Method readMethod = converter.getMethod("read", JsonNode.class, short.class); + Message message2 = (Message) readMethod.invoke(null, jsonNode, version); + assertEquals(expected, message2); + assertEquals(expected.hashCode(), message2.hashCode()); + assertEquals(expected.toString(), message2.toString()); + } + + private static String jsonConverterTypeName(String source) { + int outerClassIndex = source.lastIndexOf('$'); + if (outerClassIndex == -1) { + return source + "JsonConverter"; + } else { + return source.substring(0, outerClassIndex) + "JsonConverter$" + + source.substring(outerClassIndex + 1) + "JsonConverter"; + } + } + + /** + * Verify that the JSON files support the same message versions as the + * schemas accessible through the ApiKey class. 
+ */ + @Test + public void testMessageVersions() { + for (ApiKeys apiKey : ApiKeys.values()) { + Message message = null; + try { + message = ApiMessageType.fromApiKey(apiKey.id).newRequest(); + } catch (UnsupportedVersionException e) { + fail("No request message spec found for API " + apiKey); + } + assertTrue(apiKey.latestVersion() <= message.highestSupportedVersion(), + "Request message spec for " + apiKey + " only " + "supports versions up to " + + message.highestSupportedVersion()); + try { + message = ApiMessageType.fromApiKey(apiKey.id).newResponse(); + } catch (UnsupportedVersionException e) { + fail("No response message spec found for API " + apiKey); + } + assertTrue(apiKey.latestVersion() <= message.highestSupportedVersion(), + "Response message spec for " + apiKey + " only " + "supports versions up to " + + message.highestSupportedVersion()); + } + } + + @Test + public void testDefaultValues() { + verifyWriteRaisesUve((short) 0, "validateOnly", + new CreateTopicsRequestData().setValidateOnly(true)); + verifyWriteSucceeds((short) 0, + new CreateTopicsRequestData().setValidateOnly(false)); + verifyWriteSucceeds((short) 0, + new OffsetCommitRequestData().setRetentionTimeMs(123)); + verifyWriteRaisesUve((short) 5, "forgotten", + new FetchRequestData().setForgottenTopicsData(singletonList( + new FetchRequestData.ForgottenTopic().setTopic("foo")))); + } + + @Test + public void testNonIgnorableFieldWithDefaultNull() { + // Test non-ignorable string field `groupInstanceId` with default null + verifyWriteRaisesUve((short) 0, "groupInstanceId", new HeartbeatRequestData() + .setGroupId("groupId") + .setGenerationId(15) + .setMemberId(memberId) + .setGroupInstanceId(instanceId)); + verifyWriteSucceeds((short) 0, new HeartbeatRequestData() + .setGroupId("groupId") + .setGenerationId(15) + .setMemberId(memberId) + .setGroupInstanceId(null)); + verifyWriteSucceeds((short) 0, new HeartbeatRequestData() + .setGroupId("groupId") + .setGenerationId(15) + .setMemberId(memberId)); + } + + @Test + public void testWriteNullForNonNullableFieldRaisesException() { + CreateTopicsRequestData createTopics = new CreateTopicsRequestData().setTopics(null); + for (short version : ApiKeys.CREATE_TOPICS.allVersions()) { + verifyWriteRaisesNpe(version, createTopics); + } + MetadataRequestData metadata = new MetadataRequestData().setTopics(null); + verifyWriteRaisesNpe((short) 0, metadata); + } + + @Test + public void testUnknownTaggedFields() { + CreateTopicsRequestData createTopics = new CreateTopicsRequestData(); + verifyWriteSucceeds((short) 6, createTopics); + RawTaggedField field1000 = new RawTaggedField(1000, new byte[] {0x1, 0x2, 0x3}); + createTopics.unknownTaggedFields().add(field1000); + verifyWriteRaisesUve((short) 0, "Tagged fields were set", createTopics); + verifyWriteSucceeds((short) 6, createTopics); + } + + @Test + public void testLongTaggedString() throws Exception { + char[] chars = new char[1024]; + Arrays.fill(chars, 'a'); + String longString = new String(chars); + SimpleExampleMessageData message = new SimpleExampleMessageData() + .setMyString(longString); + ObjectSerializationCache cache = new ObjectSerializationCache(); + short version = 1; + int size = message.size(cache, version); + ByteBuffer buf = ByteBuffer.allocate(size); + ByteBufferAccessor byteBufferAccessor = new ByteBufferAccessor(buf); + message.write(byteBufferAccessor, cache, version); + assertEquals(size, buf.position()); + } + + private void verifyWriteRaisesNpe(short version, Message message) { + ObjectSerializationCache 
cache = new ObjectSerializationCache(); + assertThrows(NullPointerException.class, () -> { + int size = message.size(cache, version); + ByteBuffer buf = ByteBuffer.allocate(size); + ByteBufferAccessor byteBufferAccessor = new ByteBufferAccessor(buf); + message.write(byteBufferAccessor, cache, version); + }); + } + + private void verifyWriteRaisesUve(short version, + String problemText, + Message message) { + ObjectSerializationCache cache = new ObjectSerializationCache(); + UnsupportedVersionException e = + assertThrows(UnsupportedVersionException.class, () -> { + int size = message.size(cache, version); + ByteBuffer buf = ByteBuffer.allocate(size); + ByteBufferAccessor byteBufferAccessor = new ByteBufferAccessor(buf); + message.write(byteBufferAccessor, cache, version); + }); + assertTrue(e.getMessage().contains(problemText), "Expected to get an error message about " + problemText + + ", but got: " + e.getMessage()); + } + + private void verifyWriteSucceeds(short version, Message message) { + ObjectSerializationCache cache = new ObjectSerializationCache(); + int size = message.size(cache, version); + ByteBuffer buf = ByteBuffer.allocate(size * 2); + ByteBufferAccessor byteBufferAccessor = new ByteBufferAccessor(buf); + message.write(byteBufferAccessor, cache, version); + assertEquals(size, buf.position(), "Expected the serialized size to be " + size + ", but it was " + buf.position()); + } + + @Test + public void testCompareWithUnknownTaggedFields() { + CreateTopicsRequestData createTopics = new CreateTopicsRequestData(); + createTopics.setTimeoutMs(123); + CreateTopicsRequestData createTopics2 = new CreateTopicsRequestData(); + createTopics2.setTimeoutMs(123); + assertEquals(createTopics, createTopics2); + assertEquals(createTopics2, createTopics); + // Call the accessor, which will create a new empty list. + createTopics.unknownTaggedFields(); + // Verify that the equalities still hold after the new empty list has been created. + assertEquals(createTopics, createTopics2); + assertEquals(createTopics2, createTopics); + createTopics.unknownTaggedFields().add(new RawTaggedField(0, new byte[] {0})); + assertNotEquals(createTopics, createTopics2); + assertNotEquals(createTopics2, createTopics); + createTopics2.unknownTaggedFields().add(new RawTaggedField(0, new byte[] {0})); + assertEquals(createTopics, createTopics2); + assertEquals(createTopics2, createTopics); + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/message/RecordsSerdeTest.java b/clients/src/test/java/org/apache/kafka/common/message/RecordsSerdeTest.java new file mode 100644 index 0000000..739bed9 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/message/RecordsSerdeTest.java @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.message; + +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.protocol.MessageUtil; +import org.apache.kafka.common.record.CompressionType; +import org.apache.kafka.common.record.MemoryRecords; +import org.apache.kafka.common.record.SimpleRecord; +import org.junit.jupiter.api.Test; + +import java.nio.ByteBuffer; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNull; + +public class RecordsSerdeTest { + + @Test + public void testSerdeRecords() throws Exception { + MemoryRecords records = MemoryRecords.withRecords(CompressionType.NONE, + new SimpleRecord("foo".getBytes()), + new SimpleRecord("bar".getBytes())); + + SimpleRecordsMessageData message = new SimpleRecordsMessageData() + .setTopic("foo") + .setRecordSet(records); + + testAllRoundTrips(message); + } + + @Test + public void testSerdeNullRecords() throws Exception { + SimpleRecordsMessageData message = new SimpleRecordsMessageData() + .setTopic("foo"); + assertNull(message.recordSet()); + + testAllRoundTrips(message); + } + + @Test + public void testSerdeEmptyRecords() throws Exception { + SimpleRecordsMessageData message = new SimpleRecordsMessageData() + .setTopic("foo") + .setRecordSet(MemoryRecords.EMPTY); + testAllRoundTrips(message); + } + + private void testAllRoundTrips(SimpleRecordsMessageData message) throws Exception { + for (short version = SimpleRecordsMessageData.LOWEST_SUPPORTED_VERSION; + version <= SimpleRecordsMessageData.HIGHEST_SUPPORTED_VERSION; + version++) { + testRoundTrip(message, version); + } + } + + private void testRoundTrip(SimpleRecordsMessageData message, short version) { + ByteBuffer buf = MessageUtil.toByteBuffer(message, version); + SimpleRecordsMessageData message2 = deserialize(buf.duplicate(), version); + assertEquals(message, message2); + assertEquals(message.hashCode(), message2.hashCode()); + } + + private SimpleRecordsMessageData deserialize(ByteBuffer buffer, short version) { + ByteBufferAccessor readable = new ByteBufferAccessor(buffer); + return new SimpleRecordsMessageData(readable, version); + } + +} diff --git a/clients/src/test/java/org/apache/kafka/common/message/SimpleExampleMessageTest.java b/clients/src/test/java/org/apache/kafka/common/message/SimpleExampleMessageTest.java new file mode 100644 index 0000000..1cdafcd --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/message/SimpleExampleMessageTest.java @@ -0,0 +1,360 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.message; + +import com.fasterxml.jackson.databind.JsonNode; +import org.apache.kafka.common.Uuid; +import org.apache.kafka.common.errors.UnsupportedVersionException; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.protocol.MessageUtil; +import org.apache.kafka.common.protocol.ObjectSerializationCache; +import org.apache.kafka.common.utils.ByteUtils; +import org.junit.jupiter.api.Test; + +import java.nio.ByteBuffer; +import java.util.Arrays; +import java.util.Collections; +import java.util.function.Consumer; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; + +public class SimpleExampleMessageTest { + + @Test + public void shouldStoreField() { + final Uuid uuid = Uuid.randomUuid(); + final ByteBuffer buf = ByteBuffer.wrap(new byte[] {1, 2, 3}); + + final SimpleExampleMessageData out = new SimpleExampleMessageData(); + out.setProcessId(uuid); + out.setZeroCopyByteBuffer(buf); + + assertEquals(uuid, out.processId()); + assertEquals(buf, out.zeroCopyByteBuffer()); + + out.setNullableZeroCopyByteBuffer(null); + assertNull(out.nullableZeroCopyByteBuffer()); + out.setNullableZeroCopyByteBuffer(buf); + assertEquals(buf, out.nullableZeroCopyByteBuffer()); + } + + @Test + public void shouldThrowIfCannotWriteNonIgnorableField() { + // processId is not supported in v0 and is not marked as ignorable + + final SimpleExampleMessageData out = new SimpleExampleMessageData().setProcessId(Uuid.randomUuid()); + assertThrows(UnsupportedVersionException.class, () -> + out.write(new ByteBufferAccessor(ByteBuffer.allocate(64)), new ObjectSerializationCache(), (short) 0)); + } + + @Test + public void shouldDefaultField() { + final SimpleExampleMessageData out = new SimpleExampleMessageData(); + assertEquals(Uuid.fromString("AAAAAAAAAAAAAAAAAAAAAA"), out.processId()); + assertEquals(ByteUtils.EMPTY_BUF, out.zeroCopyByteBuffer()); + assertEquals(ByteUtils.EMPTY_BUF, out.nullableZeroCopyByteBuffer()); + } + + @Test + public void shouldRoundTripFieldThroughBuffer() { + final Uuid uuid = Uuid.randomUuid(); + final ByteBuffer buf = ByteBuffer.wrap(new byte[] {1, 2, 3}); + final SimpleExampleMessageData out = new SimpleExampleMessageData(); + out.setProcessId(uuid); + out.setZeroCopyByteBuffer(buf); + + final ByteBuffer buffer = MessageUtil.toByteBuffer(out, (short) 1); + + final SimpleExampleMessageData in = new SimpleExampleMessageData(); + in.read(new ByteBufferAccessor(buffer), (short) 1); + + buf.rewind(); + + assertEquals(uuid, in.processId()); + assertEquals(buf, in.zeroCopyByteBuffer()); + assertEquals(ByteUtils.EMPTY_BUF, in.nullableZeroCopyByteBuffer()); + } + + @Test + public void shouldRoundTripFieldThroughBufferWithNullable() { + final Uuid uuid = Uuid.randomUuid(); + final ByteBuffer buf1 = ByteBuffer.wrap(new byte[] {1, 2, 3}); + final ByteBuffer buf2 = ByteBuffer.wrap(new byte[] {4, 5, 6}); + final SimpleExampleMessageData out = new SimpleExampleMessageData(); + out.setProcessId(uuid); + out.setZeroCopyByteBuffer(buf1); + out.setNullableZeroCopyByteBuffer(buf2); + + final ByteBuffer buffer = MessageUtil.toByteBuffer(out, (short) 1); + + final SimpleExampleMessageData in = new SimpleExampleMessageData(); + in.read(new ByteBufferAccessor(buffer), (short) 1); + + buf1.rewind(); + buf2.rewind(); + + assertEquals(uuid, 
in.processId()); + assertEquals(buf1, in.zeroCopyByteBuffer()); + assertEquals(buf2, in.nullableZeroCopyByteBuffer()); + } + + @Test + public void shouldImplementEqualsAndHashCode() { + final Uuid uuid = Uuid.randomUuid(); + final ByteBuffer buf = ByteBuffer.wrap(new byte[] {1, 2, 3}); + final SimpleExampleMessageData a = new SimpleExampleMessageData(); + a.setProcessId(uuid); + a.setZeroCopyByteBuffer(buf); + + final SimpleExampleMessageData b = new SimpleExampleMessageData(); + b.setProcessId(uuid); + b.setZeroCopyByteBuffer(buf); + + assertEquals(a, b); + assertEquals(a.hashCode(), b.hashCode()); + // just tagging this on here + assertEquals(a.toString(), b.toString()); + + a.setNullableZeroCopyByteBuffer(buf); + b.setNullableZeroCopyByteBuffer(buf); + + assertEquals(a, b); + assertEquals(a.hashCode(), b.hashCode()); + assertEquals(a.toString(), b.toString()); + + a.setNullableZeroCopyByteBuffer(null); + b.setNullableZeroCopyByteBuffer(null); + + assertEquals(a, b); + assertEquals(a.hashCode(), b.hashCode()); + assertEquals(a.toString(), b.toString()); + } + + @Test + public void testMyTaggedIntArray() { + // Verify that the tagged int array reads as empty when not set. + testRoundTrip(new SimpleExampleMessageData(), + message -> assertEquals(Collections.emptyList(), message.myTaggedIntArray())); + + // Verify that we can set a tagged array of ints. + testRoundTrip(new SimpleExampleMessageData(). + setMyTaggedIntArray(Arrays.asList(1, 2, 3)), + message -> assertEquals(Arrays.asList(1, 2, 3), message.myTaggedIntArray())); + } + + @Test + public void testMyNullableString() { + // Verify that the tagged field reads as null when not set. + testRoundTrip(new SimpleExampleMessageData(), message -> assertNull(message.myNullableString())); + + // Verify that we can set and retrieve a string for the tagged field. + testRoundTrip(new SimpleExampleMessageData().setMyNullableString("foobar"), + message -> assertEquals("foobar", message.myNullableString())); + } + + @Test + public void testMyInt16() { + // Verify that the tagged field reads as 123 when not set. + testRoundTrip(new SimpleExampleMessageData(), + message -> assertEquals((short) 123, message.myInt16())); + + testRoundTrip(new SimpleExampleMessageData().setMyInt16((short) 456), + message -> assertEquals((short) 456, message.myInt16())); + } + + @Test + public void testMyUint16() { + // Verify that the uint16 field reads as 33000 when not set. + testRoundTrip(new SimpleExampleMessageData(), + message -> assertEquals(33000, message.myUint16())); + + testRoundTrip(new SimpleExampleMessageData().setMyUint16(123), + message -> assertEquals(123, message.myUint16())); + testRoundTrip(new SimpleExampleMessageData().setMyUint16(60000), + message -> assertEquals(60000, message.myUint16())); + } + + @Test + public void testMyString() { + // Verify that the tagged field reads as empty when not set. + testRoundTrip(new SimpleExampleMessageData(), + message -> assertEquals("", message.myString())); + + testRoundTrip(new SimpleExampleMessageData().setMyString("abc"), + message -> assertEquals("abc", message.myString())); + } + + @Test + public void testMyBytes() { + assertThrows(RuntimeException.class, + () -> new SimpleExampleMessageData().setMyUint16(-1)); + assertThrows(RuntimeException.class, + () -> new SimpleExampleMessageData().setMyUint16(65536)); + + // Verify that the tagged field reads as empty when not set. 
+ testRoundTrip(new SimpleExampleMessageData(), + message -> assertArrayEquals(new byte[0], message.myBytes())); + + testRoundTrip(new SimpleExampleMessageData(). + setMyBytes(new byte[] {0x43, 0x66}), + message -> assertArrayEquals(new byte[] {0x43, 0x66}, + message.myBytes())); + + testRoundTrip(new SimpleExampleMessageData().setMyBytes(null), message -> assertNull(message.myBytes())); + } + + @Test + public void testTaggedUuid() { + testRoundTrip(new SimpleExampleMessageData(), + message -> assertEquals( + Uuid.fromString("H3KKO4NTRPaCWtEmm3vW7A"), + message.taggedUuid())); + + Uuid randomUuid = Uuid.randomUuid(); + testRoundTrip(new SimpleExampleMessageData(). + setTaggedUuid(randomUuid), + message -> assertEquals( + randomUuid, + message.taggedUuid())); + } + + @Test + public void testTaggedLong() { + testRoundTrip(new SimpleExampleMessageData(), + message -> assertEquals(0xcafcacafcacafcaL, + message.taggedLong())); + + testRoundTrip(new SimpleExampleMessageData(). + setMyString("blah"). + setMyTaggedIntArray(Collections.singletonList(4)). + setTaggedLong(0x123443211234432L), + message -> assertEquals(0x123443211234432L, + message.taggedLong())); + } + + @Test + public void testMyStruct() { + // Verify that we can set and retrieve a nullable struct object. + SimpleExampleMessageData.MyStruct myStruct = + new SimpleExampleMessageData.MyStruct().setStructId(10).setArrayInStruct( + Collections.singletonList(new SimpleExampleMessageData.StructArray().setArrayFieldId(20)) + ); + testRoundTrip(new SimpleExampleMessageData().setMyStruct(myStruct), + message -> assertEquals(myStruct, message.myStruct()), (short) 2); + } + + @Test + public void testMyStructUnsupportedVersion() { + SimpleExampleMessageData.MyStruct myStruct = + new SimpleExampleMessageData.MyStruct().setStructId(10); + // Check serialization throws exception for unsupported version + assertThrows(UnsupportedVersionException.class, + () -> testRoundTrip(new SimpleExampleMessageData().setMyStruct(myStruct), (short) 1)); + } + + /** + * Check following cases: + * 1. Tagged struct can be serialized/deserialized for version it is supported + * 2. Tagged struct doesn't matter for versions it is not declared. + */ + @Test + public void testMyTaggedStruct() { + // Verify that we can set and retrieve a nullable struct object. 
+ SimpleExampleMessageData.TaggedStruct myStruct = + new SimpleExampleMessageData.TaggedStruct().setStructId("abc"); + testRoundTrip(new SimpleExampleMessageData().setMyTaggedStruct(myStruct), + message -> assertEquals(myStruct, message.myTaggedStruct()), (short) 2); + + // Not setting field works for both version 1 and version 2 protocol + testRoundTrip(new SimpleExampleMessageData().setMyString("abc"), + message -> assertEquals("abc", message.myString()), (short) 1); + testRoundTrip(new SimpleExampleMessageData().setMyString("abc"), + message -> assertEquals("abc", message.myString()), (short) 2); + } + + @Test + public void testCommonStruct() { + SimpleExampleMessageData message = new SimpleExampleMessageData(); + message.setMyCommonStruct(new SimpleExampleMessageData.TestCommonStruct() + .setFoo(1) + .setBar(2)); + message.setMyOtherCommonStruct(new SimpleExampleMessageData.TestCommonStruct() + .setFoo(3) + .setBar(4)); + testRoundTrip(message, (short) 2); + } + + private SimpleExampleMessageData deserialize(ByteBuffer buf, short version) { + SimpleExampleMessageData message = new SimpleExampleMessageData(); + message.read(new ByteBufferAccessor(buf.duplicate()), version); + return message; + } + + private void testRoundTrip(SimpleExampleMessageData message, short version) { + testRoundTrip(message, m -> { }, version); + } + + private void testRoundTrip(SimpleExampleMessageData message, + Consumer<SimpleExampleMessageData> validator) { + testRoundTrip(message, validator, (short) 1); + } + + private void testRoundTrip(SimpleExampleMessageData message, + Consumer<SimpleExampleMessageData> validator, + short version) { + validator.accept(message); + ByteBuffer buf = MessageUtil.toByteBuffer(message, version); + + SimpleExampleMessageData message2 = deserialize(buf.duplicate(), version); + validator.accept(message2); + assertEquals(message, message2); + assertEquals(message.hashCode(), message2.hashCode()); + + // Check JSON serialization + JsonNode serializedJson = SimpleExampleMessageDataJsonConverter.write(message, version); + SimpleExampleMessageData messageFromJson = SimpleExampleMessageDataJsonConverter.read(serializedJson, version); + validator.accept(messageFromJson); + assertEquals(message, messageFromJson); + assertEquals(message.hashCode(), messageFromJson.hashCode()); + } + + @Test + public void testToString() { + SimpleExampleMessageData message = new SimpleExampleMessageData(); + message.setMyUint16(65535); + message.setTaggedUuid(Uuid.fromString("x7D3Ck_ZRA22-dzIvu_pnQ")); + message.setMyFloat64(1.0); + assertEquals("SimpleExampleMessageData(processId=AAAAAAAAAAAAAAAAAAAAAA, " + + "myTaggedIntArray=[], " + + "myNullableString=null, " + + "myInt16=123, myFloat64=1.0, " + + "myString='', " + + "myBytes=[], " + + "taggedUuid=x7D3Ck_ZRA22-dzIvu_pnQ, " + + "taggedLong=914172222550880202, " + + "zeroCopyByteBuffer=java.nio.HeapByteBuffer[pos=0 lim=0 cap=0], " + + "nullableZeroCopyByteBuffer=java.nio.HeapByteBuffer[pos=0 lim=0 cap=0], " + + "myStruct=MyStruct(structId=0, arrayInStruct=[]), " + + "myTaggedStruct=TaggedStruct(structId=''), " + + "myCommonStruct=TestCommonStruct(foo=123, bar=123), " + + "myOtherCommonStruct=TestCommonStruct(foo=123, bar=123), " + + "myUint16=65535)", message.toString()); + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/metrics/FakeMetricsReporter.java b/clients/src/test/java/org/apache/kafka/common/metrics/FakeMetricsReporter.java new file mode 100644 index 0000000..99dfc30 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/metrics/FakeMetricsReporter.java @@ -0,0 +1,39
@@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.metrics; + +import java.util.List; +import java.util.Map; + +public class FakeMetricsReporter implements MetricsReporter { + + @Override + public void configure(Map<String, ?> configs) {} + + @Override + public void init(List<KafkaMetric> metrics) {} + + @Override + public void metricChange(KafkaMetric metric) {} + + @Override + public void metricRemoval(KafkaMetric metric) {} + + @Override + public void close() {} + +} diff --git a/clients/src/test/java/org/apache/kafka/common/metrics/JmxReporterTest.java b/clients/src/test/java/org/apache/kafka/common/metrics/JmxReporterTest.java new file mode 100644 index 0000000..a6b2e7f --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/metrics/JmxReporterTest.java @@ -0,0 +1,197 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ +package org.apache.kafka.common.metrics; + +import org.apache.kafka.common.MetricName; +import org.apache.kafka.common.metrics.stats.Avg; +import org.apache.kafka.common.metrics.stats.CumulativeSum; +import org.apache.kafka.common.utils.Time; +import org.junit.jupiter.api.Test; + +import javax.management.MBeanServer; +import javax.management.ObjectName; +import java.lang.management.ManagementFactory; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class JmxReporterTest { + + @Test + public void testJmxRegistration() throws Exception { + Metrics metrics = new Metrics(); + MBeanServer server = ManagementFactory.getPlatformMBeanServer(); + try { + JmxReporter reporter = new JmxReporter(); + metrics.addReporter(reporter); + + assertFalse(server.isRegistered(new ObjectName(":type=grp1"))); + + Sensor sensor = metrics.sensor("kafka.requests"); + sensor.add(metrics.metricName("pack.bean1.avg", "grp1"), new Avg()); + sensor.add(metrics.metricName("pack.bean2.total", "grp2"), new CumulativeSum()); + + assertTrue(server.isRegistered(new ObjectName(":type=grp1"))); + assertEquals(Double.NaN, server.getAttribute(new ObjectName(":type=grp1"), "pack.bean1.avg")); + assertTrue(server.isRegistered(new ObjectName(":type=grp2"))); + assertEquals(0.0, server.getAttribute(new ObjectName(":type=grp2"), "pack.bean2.total")); + + MetricName metricName = metrics.metricName("pack.bean1.avg", "grp1"); + String mBeanName = JmxReporter.getMBeanName("", metricName); + assertTrue(reporter.containsMbean(mBeanName)); + metrics.removeMetric(metricName); + assertFalse(reporter.containsMbean(mBeanName)); + + assertFalse(server.isRegistered(new ObjectName(":type=grp1"))); + assertTrue(server.isRegistered(new ObjectName(":type=grp2"))); + assertEquals(0.0, server.getAttribute(new ObjectName(":type=grp2"), "pack.bean2.total")); + + metricName = metrics.metricName("pack.bean2.total", "grp2"); + metrics.removeMetric(metricName); + assertFalse(reporter.containsMbean(mBeanName)); + + assertFalse(server.isRegistered(new ObjectName(":type=grp1"))); + assertFalse(server.isRegistered(new ObjectName(":type=grp2"))); + } finally { + metrics.close(); + } + } + + @Test + public void testJmxRegistrationSanitization() throws Exception { + Metrics metrics = new Metrics(); + MBeanServer server = ManagementFactory.getPlatformMBeanServer(); + try { + metrics.addReporter(new JmxReporter()); + + Sensor sensor = metrics.sensor("kafka.requests"); + sensor.add(metrics.metricName("name", "group", "desc", "id", "foo*"), new CumulativeSum()); + sensor.add(metrics.metricName("name", "group", "desc", "id", "foo+"), new CumulativeSum()); + sensor.add(metrics.metricName("name", "group", "desc", "id", "foo?"), new CumulativeSum()); + sensor.add(metrics.metricName("name", "group", "desc", "id", "foo:"), new CumulativeSum()); + sensor.add(metrics.metricName("name", "group", "desc", "id", "foo%"), new CumulativeSum()); + + assertTrue(server.isRegistered(new ObjectName(":type=group,id=\"foo\\*\""))); + assertEquals(0.0, server.getAttribute(new ObjectName(":type=group,id=\"foo\\*\""), "name")); + assertTrue(server.isRegistered(new ObjectName(":type=group,id=\"foo+\""))); + assertEquals(0.0, server.getAttribute(new ObjectName(":type=group,id=\"foo+\""), "name")); + assertTrue(server.isRegistered(new 
ObjectName(":type=group,id=\"foo\\?\""))); + assertEquals(0.0, server.getAttribute(new ObjectName(":type=group,id=\"foo\\?\""), "name")); + assertTrue(server.isRegistered(new ObjectName(":type=group,id=\"foo:\""))); + assertEquals(0.0, server.getAttribute(new ObjectName(":type=group,id=\"foo:\""), "name")); + assertTrue(server.isRegistered(new ObjectName(":type=group,id=foo%"))); + assertEquals(0.0, server.getAttribute(new ObjectName(":type=group,id=foo%"), "name")); + + metrics.removeMetric(metrics.metricName("name", "group", "desc", "id", "foo*")); + metrics.removeMetric(metrics.metricName("name", "group", "desc", "id", "foo+")); + metrics.removeMetric(metrics.metricName("name", "group", "desc", "id", "foo?")); + metrics.removeMetric(metrics.metricName("name", "group", "desc", "id", "foo:")); + metrics.removeMetric(metrics.metricName("name", "group", "desc", "id", "foo%")); + + assertFalse(server.isRegistered(new ObjectName(":type=group,id=\"foo\\*\""))); + assertFalse(server.isRegistered(new ObjectName(":type=group,id=foo+"))); + assertFalse(server.isRegistered(new ObjectName(":type=group,id=\"foo\\?\""))); + assertFalse(server.isRegistered(new ObjectName(":type=group,id=\"foo:\""))); + assertFalse(server.isRegistered(new ObjectName(":type=group,id=foo%"))); + } finally { + metrics.close(); + } + } + + @Test + public void testPredicateAndDynamicReload() throws Exception { + Metrics metrics = new Metrics(); + MBeanServer server = ManagementFactory.getPlatformMBeanServer(); + + Map configs = new HashMap<>(); + + configs.put(JmxReporter.EXCLUDE_CONFIG, + JmxReporter.getMBeanName("", metrics.metricName("pack.bean2.total", "grp2"))); + + try { + JmxReporter reporter = new JmxReporter(); + reporter.configure(configs); + metrics.addReporter(reporter); + + Sensor sensor = metrics.sensor("kafka.requests"); + sensor.add(metrics.metricName("pack.bean2.avg", "grp1"), new Avg()); + sensor.add(metrics.metricName("pack.bean2.total", "grp2"), new CumulativeSum()); + sensor.record(); + + assertTrue(server.isRegistered(new ObjectName(":type=grp1"))); + assertEquals(1.0, server.getAttribute(new ObjectName(":type=grp1"), "pack.bean2.avg")); + assertFalse(server.isRegistered(new ObjectName(":type=grp2"))); + + sensor.record(); + + configs.put(JmxReporter.EXCLUDE_CONFIG, + JmxReporter.getMBeanName("", metrics.metricName("pack.bean2.avg", "grp1"))); + + reporter.reconfigure(configs); + + assertFalse(server.isRegistered(new ObjectName(":type=grp1"))); + assertTrue(server.isRegistered(new ObjectName(":type=grp2"))); + assertEquals(2.0, server.getAttribute(new ObjectName(":type=grp2"), "pack.bean2.total")); + + metrics.removeMetric(metrics.metricName("pack.bean2.total", "grp2")); + assertFalse(server.isRegistered(new ObjectName(":type=grp2"))); + } finally { + metrics.close(); + } + } + + @Test + public void testJmxPrefix() throws Exception { + JmxReporter reporter = new JmxReporter(); + MetricsContext metricsContext = new KafkaMetricsContext("kafka.server"); + MetricConfig metricConfig = new MetricConfig(); + Metrics metrics = new Metrics(metricConfig, new ArrayList<>(Arrays.asList(reporter)), Time.SYSTEM, metricsContext); + + MBeanServer server = ManagementFactory.getPlatformMBeanServer(); + try { + Sensor sensor = metrics.sensor("kafka.requests"); + sensor.add(metrics.metricName("pack.bean1.avg", "grp1"), new Avg()); + assertEquals("kafka.server", server.getObjectInstance(new ObjectName("kafka.server:type=grp1")).getObjectName().getDomain()); + } finally { + metrics.close(); + } + } + + @Test + public void 
testDeprecatedJmxPrefixWithDefaultMetrics() throws Exception { + @SuppressWarnings("deprecation") + JmxReporter reporter = new JmxReporter("my-prefix"); + + // for backwards compatibility, ensure prefix does not get overridden by the default empty namespace in metricscontext + MetricConfig metricConfig = new MetricConfig(); + Metrics metrics = new Metrics(metricConfig, new ArrayList<>(Arrays.asList(reporter)), Time.SYSTEM); + + MBeanServer server = ManagementFactory.getPlatformMBeanServer(); + try { + Sensor sensor = metrics.sensor("my-sensor"); + sensor.add(metrics.metricName("pack.bean1.avg", "grp1"), new Avg()); + assertEquals("my-prefix", server.getObjectInstance(new ObjectName("my-prefix:type=grp1")).getObjectName().getDomain()); + } finally { + metrics.close(); + } + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/metrics/KafkaMbeanTest.java b/clients/src/test/java/org/apache/kafka/common/metrics/KafkaMbeanTest.java new file mode 100644 index 0000000..5df66db --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/metrics/KafkaMbeanTest.java @@ -0,0 +1,155 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.metrics; + +import org.apache.kafka.common.MetricName; +import org.apache.kafka.common.metrics.stats.WindowedCount; +import org.apache.kafka.common.metrics.stats.WindowedSum; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import javax.management.Attribute; +import javax.management.AttributeList; +import javax.management.AttributeNotFoundException; +import javax.management.MBeanServer; +import javax.management.ObjectName; +import javax.management.RuntimeMBeanException; +import java.lang.management.ManagementFactory; +import java.util.List; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.fail; + +public class KafkaMbeanTest { + + private final MBeanServer mBeanServer = ManagementFactory.getPlatformMBeanServer(); + private Sensor sensor; + private MetricName countMetricName; + private MetricName sumMetricName; + private Metrics metrics; + + @BeforeEach + public void setup() throws Exception { + metrics = new Metrics(); + metrics.addReporter(new JmxReporter()); + sensor = metrics.sensor("kafka.requests"); + countMetricName = metrics.metricName("pack.bean1.count", "grp1"); + sensor.add(countMetricName, new WindowedCount()); + sumMetricName = metrics.metricName("pack.bean1.sum", "grp1"); + sensor.add(sumMetricName, new WindowedSum()); + } + + @AfterEach + public void tearDown() { + metrics.close(); + } + + @Test + public void testGetAttribute() throws Exception { + sensor.record(2.5); + Object counterAttribute = getAttribute(countMetricName); + assertEquals(1.0, counterAttribute); + Object sumAttribute = getAttribute(sumMetricName); + assertEquals(2.5, sumAttribute); + } + + @Test + public void testGetAttributeUnknown() throws Exception { + sensor.record(2.5); + try { + getAttribute(sumMetricName, "name"); + fail("Should have gotten attribute not found"); + } catch (AttributeNotFoundException e) { + // Expected + } + } + + @Test + public void testGetAttributes() throws Exception { + sensor.record(3.5); + sensor.record(4.0); + AttributeList attributeList = getAttributes(countMetricName, countMetricName.name(), sumMetricName.name()); + List<Attribute> attributes = attributeList.asList(); + assertEquals(2, attributes.size()); + for (Attribute attribute : attributes) { + if (countMetricName.name().equals(attribute.getName())) + assertEquals(2.0, attribute.getValue()); + else if (sumMetricName.name().equals(attribute.getName())) + assertEquals(7.5, attribute.getValue()); + else + fail("Unexpected attribute returned: " + attribute.getName()); + } + } + + @Test + public void testGetAttributesWithUnknown() throws Exception { + sensor.record(3.5); + sensor.record(4.0); + AttributeList attributeList = getAttributes(countMetricName, countMetricName.name(), + sumMetricName.name(), "name"); + List<Attribute> attributes = attributeList.asList(); + assertEquals(2, attributes.size()); + for (Attribute attribute : attributes) { + if (countMetricName.name().equals(attribute.getName())) + assertEquals(2.0, attribute.getValue()); + else if (sumMetricName.name().equals(attribute.getName())) + assertEquals(7.5, attribute.getValue()); + else + fail("Unexpected attribute returned: " + attribute.getName()); + } + } + + @Test + public void testInvoke() throws Exception { + RuntimeMBeanException e = assertThrows(RuntimeMBeanException.class, + () -> mBeanServer.invoke(objectName(countMetricName), "something",
null, null)); + assertEquals(UnsupportedOperationException.class, e.getCause().getClass()); + } + + @Test + public void testSetAttribute() throws Exception { + RuntimeMBeanException e = assertThrows(RuntimeMBeanException.class, + () -> mBeanServer.setAttribute(objectName(countMetricName), new Attribute("anything", 1))); + assertEquals(UnsupportedOperationException.class, e.getCause().getClass()); + } + + @Test + public void testSetAttributes() throws Exception { + RuntimeMBeanException e = assertThrows(RuntimeMBeanException.class, + () -> mBeanServer.setAttributes(objectName(countMetricName), new AttributeList(1))); + assertEquals(UnsupportedOperationException.class, e.getCause().getClass()); + } + + private ObjectName objectName(MetricName metricName) throws Exception { + return new ObjectName(JmxReporter.getMBeanName("", metricName)); + } + + private Object getAttribute(MetricName metricName, String attribute) throws Exception { + return mBeanServer.getAttribute(objectName(metricName), attribute); + } + + private Object getAttribute(MetricName metricName) throws Exception { + return getAttribute(metricName, metricName.name()); + } + + private AttributeList getAttributes(MetricName metricName, String... attributes) throws Exception { + return mBeanServer.getAttributes(objectName(metricName), attributes); + } + +} diff --git a/clients/src/test/java/org/apache/kafka/common/metrics/KafkaMetricsContextTest.java b/clients/src/test/java/org/apache/kafka/common/metrics/KafkaMetricsContextTest.java new file mode 100644 index 0000000..8a5a8a8 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/metrics/KafkaMetricsContextTest.java @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.metrics; + +import java.util.HashMap; +import java.util.Map; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; + +public class KafkaMetricsContextTest { + + private static final String SAMPLE_NAMESPACE = "sample-ns"; + + private static final String LABEL_A_KEY = "label-a"; + private static final String LABEL_A_VALUE = "label-a-value"; + + private String namespace; + private Map<String, String> labels; + private KafkaMetricsContext context; + + @BeforeEach + public void beforeEach() { + namespace = SAMPLE_NAMESPACE; + labels = new HashMap<>(); + labels.put(LABEL_A_KEY, LABEL_A_VALUE); + } + + @Test + public void testCreationWithValidNamespaceAndNoLabels() { + labels.clear(); + context = new KafkaMetricsContext(namespace, labels); + + assertEquals(1, context.contextLabels().size()); + assertEquals(namespace, context.contextLabels().get(MetricsContext.NAMESPACE)); + } + + @Test + public void testCreationWithValidNamespaceAndLabels() { + context = new KafkaMetricsContext(namespace, labels); + + assertEquals(2, context.contextLabels().size()); + assertEquals(namespace, context.contextLabels().get(MetricsContext.NAMESPACE)); + assertEquals(LABEL_A_VALUE, context.contextLabels().get(LABEL_A_KEY)); + } + + @Test + public void testCreationWithValidNamespaceAndNullLabelValues() { + labels.put(LABEL_A_KEY, null); + context = new KafkaMetricsContext(namespace, labels); + + assertEquals(2, context.contextLabels().size()); + assertEquals(namespace, context.contextLabels().get(MetricsContext.NAMESPACE)); + assertNull(context.contextLabels().get(LABEL_A_KEY)); + } + + @Test + public void testCreationWithNullNamespaceAndLabels() { + context = new KafkaMetricsContext(null, labels); + + assertEquals(2, context.contextLabels().size()); + assertNull(context.contextLabels().get(MetricsContext.NAMESPACE)); + assertEquals(LABEL_A_VALUE, context.contextLabels().get(LABEL_A_KEY)); + } + + @Test + public void testKafkaMetricsContextLabelsAreImmutable() { + context = new KafkaMetricsContext(namespace, labels); + assertThrows(UnsupportedOperationException.class, () -> context.contextLabels().clear()); + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/metrics/MetricsTest.java b/clients/src/test/java/org/apache/kafka/common/metrics/MetricsTest.java new file mode 100644 index 0000000..23d13e3 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/metrics/MetricsTest.java @@ -0,0 +1,955 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ +package org.apache.kafka.common.metrics; + +import static java.util.Collections.emptyList; +import static java.util.Collections.singletonList; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.Deque; +import java.util.List; +import java.util.HashMap; +import java.util.Map; +import java.util.Random; +import java.util.concurrent.ConcurrentLinkedDeque; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; + +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.Function; + +import org.apache.kafka.common.Metric; +import org.apache.kafka.common.MetricName; +import org.apache.kafka.common.metrics.stats.Avg; +import org.apache.kafka.common.metrics.stats.CumulativeSum; +import org.apache.kafka.common.metrics.stats.Max; +import org.apache.kafka.common.metrics.stats.Meter; +import org.apache.kafka.common.metrics.stats.Min; +import org.apache.kafka.common.metrics.stats.Percentile; +import org.apache.kafka.common.metrics.stats.Percentiles; +import org.apache.kafka.common.metrics.stats.Percentiles.BucketSizing; +import org.apache.kafka.common.metrics.stats.Rate; +import org.apache.kafka.common.metrics.stats.WindowedCount; +import org.apache.kafka.common.metrics.stats.WindowedSum; +import org.apache.kafka.common.metrics.stats.SimpleRate; +import org.apache.kafka.common.metrics.stats.Value; +import org.apache.kafka.common.utils.MockTime; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class MetricsTest { + private static final Logger log = LoggerFactory.getLogger(MetricsTest.class); + + private static final double EPS = 0.000001; + private MockTime time = new MockTime(); + private MetricConfig config = new MetricConfig(); + private Metrics metrics; + private ExecutorService executorService; + + @BeforeEach + public void setup() { + this.metrics = new Metrics(config, Arrays.asList(new JmxReporter()), time, true); + } + + @AfterEach + public void tearDown() throws Exception { + if (executorService != null) { + executorService.shutdownNow(); + executorService.awaitTermination(5, TimeUnit.SECONDS); + } + this.metrics.close(); + } + + @Test + public void testMetricName() { + MetricName n1 = metrics.metricName("name", "group", "description", "key1", "value1", "key2", "value2"); + Map<String, String> tags = new HashMap<String, String>(); + tags.put("key1", "value1"); + tags.put("key2", "value2"); + MetricName n2 = metrics.metricName("name", "group", "description", tags); + assertEquals(n1, n2, "metric names created in two different ways should be equal"); + + try { + metrics.metricName("name", "group", "description", "key1"); + fail("Creating MetricName with an odd number of keyValue should fail"); + } catch (IllegalArgumentException e) { + // this is expected + } + } + + @Test + public void testSimpleStats() throws Exception { + verifyStats(m -> (double) m.metricValue()); + } + + private void
verifyStats(Function<KafkaMetric, Double> metricValueFunc) { + ConstantMeasurable measurable = new ConstantMeasurable(); + + metrics.addMetric(metrics.metricName("direct.measurable", "grp1", "The fraction of time an appender waits for space allocation."), measurable); + Sensor s = metrics.sensor("test.sensor"); + s.add(metrics.metricName("test.avg", "grp1"), new Avg()); + s.add(metrics.metricName("test.max", "grp1"), new Max()); + s.add(metrics.metricName("test.min", "grp1"), new Min()); + s.add(new Meter(TimeUnit.SECONDS, metrics.metricName("test.rate", "grp1"), + metrics.metricName("test.total", "grp1"))); + s.add(new Meter(TimeUnit.SECONDS, new WindowedCount(), metrics.metricName("test.occurences", "grp1"), + metrics.metricName("test.occurences.total", "grp1"))); + s.add(metrics.metricName("test.count", "grp1"), new WindowedCount()); + s.add(new Percentiles(100, -100, 100, BucketSizing.CONSTANT, + new Percentile(metrics.metricName("test.median", "grp1"), 50.0), + new Percentile(metrics.metricName("test.perc99_9", "grp1"), 99.9))); + + Sensor s2 = metrics.sensor("test.sensor2"); + s2.add(metrics.metricName("s2.total", "grp1"), new CumulativeSum()); + s2.record(5.0); + + int sum = 0; + int count = 10; + for (int i = 0; i < count; i++) { + s.record(i); + sum += i; + } + // prior to any time passing + double elapsedSecs = (config.timeWindowMs() * (config.samples() - 1)) / 1000.0; + assertEquals(count / elapsedSecs, metricValueFunc.apply(metrics.metrics().get(metrics.metricName("test.occurences", "grp1"))), EPS, + String.format("Occurrences(0...%d) = %f", count, count / elapsedSecs)); + + // pretend 2 seconds passed... + long sleepTimeMs = 2; + time.sleep(sleepTimeMs * 1000); + elapsedSecs += sleepTimeMs; + + assertEquals(5.0, metricValueFunc.apply(metrics.metric(metrics.metricName("s2.total", "grp1"))), EPS, + "s2 reflects the constant value"); + assertEquals(4.5, metricValueFunc.apply(metrics.metric(metrics.metricName("test.avg", "grp1"))), EPS, + "Avg(0...9) = 4.5"); + assertEquals(count - 1, metricValueFunc.apply(metrics.metric(metrics.metricName("test.max", "grp1"))), EPS, + "Max(0...9) = 9"); + assertEquals(0.0, metricValueFunc.apply(metrics.metric(metrics.metricName("test.min", "grp1"))), EPS, + "Min(0...9) = 0"); + assertEquals(sum / elapsedSecs, metricValueFunc.apply(metrics.metric(metrics.metricName("test.rate", "grp1"))), EPS, + "Rate(0...9) = 1.40625"); + assertEquals(count / elapsedSecs, metricValueFunc.apply(metrics.metric(metrics.metricName("test.occurences", "grp1"))), EPS, + String.format("Occurrences(0...%d) = %f", count, count / elapsedSecs)); + assertEquals(count, metricValueFunc.apply(metrics.metric(metrics.metricName("test.count", "grp1"))), EPS, + "Count(0...9) = 10"); + } + + @Test + public void testHierarchicalSensors() { + Sensor parent1 = metrics.sensor("test.parent1"); + parent1.add(metrics.metricName("test.parent1.count", "grp1"), new WindowedCount()); + Sensor parent2 = metrics.sensor("test.parent2"); + parent2.add(metrics.metricName("test.parent2.count", "grp1"), new WindowedCount()); + Sensor child1 = metrics.sensor("test.child1", parent1, parent2); + child1.add(metrics.metricName("test.child1.count", "grp1"), new WindowedCount()); + Sensor child2 = metrics.sensor("test.child2", parent1); + child2.add(metrics.metricName("test.child2.count", "grp1"), new WindowedCount()); + Sensor grandchild = metrics.sensor("test.grandchild", child1); + grandchild.add(metrics.metricName("test.grandchild.count", "grp1"), new WindowedCount()); + + /* increment each sensor one time */
parent1.record(); + parent2.record(); + child1.record(); + child2.record(); + grandchild.record(); + + double p1 = (double) parent1.metrics().get(0).metricValue(); + double p2 = (double) parent2.metrics().get(0).metricValue(); + double c1 = (double) child1.metrics().get(0).metricValue(); + double c2 = (double) child2.metrics().get(0).metricValue(); + double gc = (double) grandchild.metrics().get(0).metricValue(); + + /* each metric should have a count equal to one + its children's count */ + assertEquals(1.0, gc, EPS); + assertEquals(1.0 + gc, c1, EPS); + assertEquals(1.0, c2, EPS); + assertEquals(1.0 + c1, p2, EPS); + assertEquals(1.0 + c1 + c2, p1, EPS); + assertEquals(Arrays.asList(child1, child2), metrics.childrenSensors().get(parent1)); + assertEquals(Arrays.asList(child1), metrics.childrenSensors().get(parent2)); + assertNull(metrics.childrenSensors().get(grandchild)); + } + + @Test + public void testBadSensorHierarchy() { + Sensor p = metrics.sensor("parent"); + Sensor c1 = metrics.sensor("child1", p); + Sensor c2 = metrics.sensor("child2", p); + assertThrows(IllegalArgumentException.class, () -> metrics.sensor("gc", c1, c2)); + } + + @Test + public void testRemoveChildSensor() { + final Metrics metrics = new Metrics(); + + final Sensor parent = metrics.sensor("parent"); + final Sensor child = metrics.sensor("child", parent); + + assertEquals(singletonList(child), metrics.childrenSensors().get(parent)); + + metrics.removeSensor("child"); + + assertEquals(emptyList(), metrics.childrenSensors().get(parent)); + } + + @Test + public void testRemoveSensor() { + int size = metrics.metrics().size(); + Sensor parent1 = metrics.sensor("test.parent1"); + parent1.add(metrics.metricName("test.parent1.count", "grp1"), new WindowedCount()); + Sensor parent2 = metrics.sensor("test.parent2"); + parent2.add(metrics.metricName("test.parent2.count", "grp1"), new WindowedCount()); + Sensor child1 = metrics.sensor("test.child1", parent1, parent2); + child1.add(metrics.metricName("test.child1.count", "grp1"), new WindowedCount()); + Sensor child2 = metrics.sensor("test.child2", parent2); + child2.add(metrics.metricName("test.child2.count", "grp1"), new WindowedCount()); + Sensor grandChild1 = metrics.sensor("test.gchild2", child2); + grandChild1.add(metrics.metricName("test.gchild2.count", "grp1"), new WindowedCount()); + + Sensor sensor = metrics.getSensor("test.parent1"); + assertNotNull(sensor); + metrics.removeSensor("test.parent1"); + assertNull(metrics.getSensor("test.parent1")); + assertNull(metrics.metrics().get(metrics.metricName("test.parent1.count", "grp1"))); + assertNull(metrics.getSensor("test.child1")); + assertNull(metrics.childrenSensors().get(sensor)); + assertNull(metrics.metrics().get(metrics.metricName("test.child1.count", "grp1"))); + + sensor = metrics.getSensor("test.gchild2"); + assertNotNull(sensor); + metrics.removeSensor("test.gchild2"); + assertNull(metrics.getSensor("test.gchild2")); + assertNull(metrics.childrenSensors().get(sensor)); + assertNull(metrics.metrics().get(metrics.metricName("test.gchild2.count", "grp1"))); + + sensor = metrics.getSensor("test.child2"); + assertNotNull(sensor); + metrics.removeSensor("test.child2"); + assertNull(metrics.getSensor("test.child2")); + assertNull(metrics.childrenSensors().get(sensor)); + assertNull(metrics.metrics().get(metrics.metricName("test.child2.count", "grp1"))); + + sensor = metrics.getSensor("test.parent2"); + assertNotNull(sensor); + metrics.removeSensor("test.parent2"); + assertNull(metrics.getSensor("test.parent2")); + 
assertNull(metrics.childrenSensors().get(sensor)); + assertNull(metrics.metrics().get(metrics.metricName("test.parent2.count", "grp1"))); + + assertEquals(size, metrics.metrics().size()); + } + + @Test + public void testRemoveInactiveMetrics() { + Sensor s1 = metrics.sensor("test.s1", null, 1); + s1.add(metrics.metricName("test.s1.count", "grp1"), new WindowedCount()); + + Sensor s2 = metrics.sensor("test.s2", null, 3); + s2.add(metrics.metricName("test.s2.count", "grp1"), new WindowedCount()); + + Metrics.ExpireSensorTask purger = metrics.new ExpireSensorTask(); + purger.run(); + assertNotNull(metrics.getSensor("test.s1"), "Sensor test.s1 must be present"); + assertNotNull( + metrics.metrics().get(metrics.metricName("test.s1.count", "grp1")), "MetricName test.s1.count must be present"); + assertNotNull(metrics.getSensor("test.s2"), "Sensor test.s2 must be present"); + assertNotNull( + metrics.metrics().get(metrics.metricName("test.s2.count", "grp1")), "MetricName test.s2.count must be present"); + + time.sleep(1001); + purger.run(); + assertNull(metrics.getSensor("test.s1"), "Sensor test.s1 should have been purged"); + assertNull( + metrics.metrics().get(metrics.metricName("test.s1.count", "grp1")), "MetricName test.s1.count should have been purged"); + assertNotNull(metrics.getSensor("test.s2"), "Sensor test.s2 must be present"); + assertNotNull( + metrics.metrics().get(metrics.metricName("test.s2.count", "grp1")), "MetricName test.s2.count must be present"); + + // record a value in sensor s2. This should reset the clock for that sensor. + // It should not get purged at the 3 second mark after creation + s2.record(); + time.sleep(2000); + purger.run(); + assertNotNull(metrics.getSensor("test.s2"), "Sensor test.s2 must be present"); + assertNotNull( + metrics.metrics().get(metrics.metricName("test.s2.count", "grp1")), "MetricName test.s2.count must be present"); + + // After another 1 second sleep, the metric should be purged + time.sleep(1000); + purger.run(); + assertNull(metrics.getSensor("test.s1"), "Sensor test.s2 should have been purged"); + assertNull( + metrics.metrics().get(metrics.metricName("test.s1.count", "grp1")), "MetricName test.s2.count should have been purged"); + + // After purging, it should be possible to recreate a metric + s1 = metrics.sensor("test.s1", null, 1); + s1.add(metrics.metricName("test.s1.count", "grp1"), new WindowedCount()); + assertNotNull(metrics.getSensor("test.s1"), "Sensor test.s1 must be present"); + assertNotNull( + metrics.metrics().get(metrics.metricName("test.s1.count", "grp1")), "MetricName test.s1.count must be present"); + } + + @Test + public void testRemoveMetric() { + int size = metrics.metrics().size(); + metrics.addMetric(metrics.metricName("test1", "grp1"), new WindowedCount()); + metrics.addMetric(metrics.metricName("test2", "grp1"), new WindowedCount()); + + assertNotNull(metrics.removeMetric(metrics.metricName("test1", "grp1"))); + assertNull(metrics.metrics().get(metrics.metricName("test1", "grp1"))); + assertNotNull(metrics.metrics().get(metrics.metricName("test2", "grp1"))); + + assertNotNull(metrics.removeMetric(metrics.metricName("test2", "grp1"))); + assertNull(metrics.metrics().get(metrics.metricName("test2", "grp1"))); + + assertEquals(size, metrics.metrics().size()); + } + + @Test + public void testEventWindowing() { + WindowedCount count = new WindowedCount(); + MetricConfig config = new MetricConfig().eventWindow(1).samples(2); + count.record(config, 1.0, time.milliseconds()); + count.record(config, 1.0, 
time.milliseconds()); + assertEquals(2.0, count.measure(config, time.milliseconds()), EPS); + count.record(config, 1.0, time.milliseconds()); // first event times out + assertEquals(2.0, count.measure(config, time.milliseconds()), EPS); + } + + @Test + public void testTimeWindowing() { + WindowedCount count = new WindowedCount(); + MetricConfig config = new MetricConfig().timeWindow(1, TimeUnit.MILLISECONDS).samples(2); + count.record(config, 1.0, time.milliseconds()); + time.sleep(1); + count.record(config, 1.0, time.milliseconds()); + assertEquals(2.0, count.measure(config, time.milliseconds()), EPS); + time.sleep(1); + count.record(config, 1.0, time.milliseconds()); // oldest event times out + assertEquals(2.0, count.measure(config, time.milliseconds()), EPS); + } + + @Test + public void testOldDataHasNoEffect() { + Max max = new Max(); + long windowMs = 100; + int samples = 2; + MetricConfig config = new MetricConfig().timeWindow(windowMs, TimeUnit.MILLISECONDS).samples(samples); + max.record(config, 50, time.milliseconds()); + time.sleep(samples * windowMs); + assertEquals(Double.NaN, max.measure(config, time.milliseconds()), EPS); + } + + /** + * Some implementations of SampledStat make sense to return NaN + * when there are no values set rather than the initial value + */ + @Test + public void testSampledStatReturnsNaNWhenNoValuesExist() { + // This is tested by having a SampledStat with expired Stats, + // because their values get reset to the initial values. + Max max = new Max(); + Min min = new Min(); + Avg avg = new Avg(); + long windowMs = 100; + int samples = 2; + MetricConfig config = new MetricConfig().timeWindow(windowMs, TimeUnit.MILLISECONDS).samples(samples); + max.record(config, 50, time.milliseconds()); + min.record(config, 50, time.milliseconds()); + avg.record(config, 50, time.milliseconds()); + + time.sleep(samples * windowMs); + + assertEquals(Double.NaN, max.measure(config, time.milliseconds()), EPS); + assertEquals(Double.NaN, min.measure(config, time.milliseconds()), EPS); + assertEquals(Double.NaN, avg.measure(config, time.milliseconds()), EPS); + } + + /** + * Some implementations of SampledStat make sense to return the initial value + * when there are no values set + */ + @Test + public void testSampledStatReturnsInitialValueWhenNoValuesExist() { + WindowedCount count = new WindowedCount(); + WindowedSum sampledTotal = new WindowedSum(); + long windowMs = 100; + int samples = 2; + MetricConfig config = new MetricConfig().timeWindow(windowMs, TimeUnit.MILLISECONDS).samples(samples); + + count.record(config, 50, time.milliseconds()); + sampledTotal.record(config, 50, time.milliseconds()); + + time.sleep(samples * windowMs); + + assertEquals(0, count.measure(config, time.milliseconds()), EPS); + assertEquals(0.0, sampledTotal.measure(config, time.milliseconds()), EPS); + } + + @Test + public void testDuplicateMetricName() { + metrics.sensor("test").add(metrics.metricName("test", "grp1"), new Avg()); + assertThrows(IllegalArgumentException.class, () -> + metrics.sensor("test2").add(metrics.metricName("test", "grp1"), new CumulativeSum())); + } + + @Test + public void testQuotas() { + Sensor sensor = metrics.sensor("test"); + sensor.add(metrics.metricName("test1.total", "grp1"), new CumulativeSum(), new MetricConfig().quota(Quota.upperBound(5.0))); + sensor.add(metrics.metricName("test2.total", "grp1"), new CumulativeSum(), new MetricConfig().quota(Quota.lowerBound(0.0))); + sensor.record(5.0); + try { + sensor.record(1.0); + fail("Should have gotten a quota 
violation."); + } catch (QuotaViolationException e) { + // this is good + } + assertEquals(6.0, (Double) metrics.metrics().get(metrics.metricName("test1.total", "grp1")).metricValue(), EPS); + sensor.record(-6.0); + try { + sensor.record(-1.0); + fail("Should have gotten a quota violation."); + } catch (QuotaViolationException e) { + // this is good + } + } + + @Test + public void testQuotasEquality() { + final Quota quota1 = Quota.upperBound(10.5); + final Quota quota2 = Quota.lowerBound(10.5); + + assertFalse(quota1.equals(quota2), "Quota with different upper values shouldn't be equal"); + + final Quota quota3 = Quota.lowerBound(10.5); + + assertTrue(quota2.equals(quota3), "Quota with same upper and bound values should be equal"); + } + + @Test + public void testPercentiles() { + int buckets = 100; + Percentiles percs = new Percentiles(4 * buckets, + 0.0, + 100.0, + BucketSizing.CONSTANT, + new Percentile(metrics.metricName("test.p25", "grp1"), 25), + new Percentile(metrics.metricName("test.p50", "grp1"), 50), + new Percentile(metrics.metricName("test.p75", "grp1"), 75)); + MetricConfig config = new MetricConfig().eventWindow(50).samples(2); + Sensor sensor = metrics.sensor("test", config); + sensor.add(percs); + Metric p25 = this.metrics.metrics().get(metrics.metricName("test.p25", "grp1")); + Metric p50 = this.metrics.metrics().get(metrics.metricName("test.p50", "grp1")); + Metric p75 = this.metrics.metrics().get(metrics.metricName("test.p75", "grp1")); + + // record two windows worth of sequential values + for (int i = 0; i < buckets; i++) + sensor.record(i); + + assertEquals(25, (Double) p25.metricValue(), 1.0); + assertEquals(50, (Double) p50.metricValue(), 1.0); + assertEquals(75, (Double) p75.metricValue(), 1.0); + + for (int i = 0; i < buckets; i++) + sensor.record(0.0); + + assertEquals(0.0, (Double) p25.metricValue(), 1.0); + assertEquals(0.0, (Double) p50.metricValue(), 1.0); + assertEquals(0.0, (Double) p75.metricValue(), 1.0); + + // record two more windows worth of sequential values + for (int i = 0; i < buckets; i++) + sensor.record(i); + + assertEquals(25, (Double) p25.metricValue(), 1.0); + assertEquals(50, (Double) p50.metricValue(), 1.0); + assertEquals(75, (Double) p75.metricValue(), 1.0); + } + + @Test + public void shouldPinSmallerValuesToMin() { + final double min = 0.0d; + final double max = 100d; + Percentiles percs = new Percentiles(1000, + min, + max, + BucketSizing.LINEAR, + new Percentile(metrics.metricName("test.p50", "grp1"), 50)); + MetricConfig config = new MetricConfig().eventWindow(50).samples(2); + Sensor sensor = metrics.sensor("test", config); + sensor.add(percs); + Metric p50 = this.metrics.metrics().get(metrics.metricName("test.p50", "grp1")); + + sensor.record(min - 100); + sensor.record(min - 100); + assertEquals(min, (double) p50.metricValue(), 0d); + } + + @Test + public void shouldPinLargerValuesToMax() { + final double min = 0.0d; + final double max = 100d; + Percentiles percs = new Percentiles(1000, + min, + max, + BucketSizing.LINEAR, + new Percentile(metrics.metricName("test.p50", "grp1"), 50)); + MetricConfig config = new MetricConfig().eventWindow(50).samples(2); + Sensor sensor = metrics.sensor("test", config); + sensor.add(percs); + Metric p50 = this.metrics.metrics().get(metrics.metricName("test.p50", "grp1")); + + sensor.record(max + 100); + sensor.record(max + 100); + assertEquals(max, (double) p50.metricValue(), 0d); + } + + @Test + public void testPercentilesWithRandomNumbersAndLinearBucketing() { + long seed = new 
Random().nextLong(); + int sizeInBytes = 100 * 1000; // 100kB + long maximumValue = 1000 * 24 * 60 * 60 * 1000L; // if values are ms, max is 1000 days + + try { + Random prng = new Random(seed); + int numberOfValues = 5000 + prng.nextInt(10_000); // range is [5000, 15000] + + Percentiles percs = new Percentiles(sizeInBytes, + maximumValue, + BucketSizing.LINEAR, + new Percentile(metrics.metricName("test.p90", "grp1"), 90), + new Percentile(metrics.metricName("test.p99", "grp1"), 99)); + MetricConfig config = new MetricConfig().eventWindow(50).samples(2); + Sensor sensor = metrics.sensor("test", config); + sensor.add(percs); + Metric p90 = this.metrics.metrics().get(metrics.metricName("test.p90", "grp1")); + Metric p99 = this.metrics.metrics().get(metrics.metricName("test.p99", "grp1")); + + final List<Long> values = new ArrayList<>(numberOfValues); + // record two windows worth of sequential values + for (int i = 0; i < numberOfValues; ++i) { + long value = (Math.abs(prng.nextLong()) - 1) % maximumValue; + values.add(value); + sensor.record(value); + } + + Collections.sort(values); + + int p90Index = (int) Math.ceil(((double) (90 * numberOfValues)) / 100); + int p99Index = (int) Math.ceil(((double) (99 * numberOfValues)) / 100); + + double expectedP90 = values.get(p90Index - 1); + double expectedP99 = values.get(p99Index - 1); + + assertEquals(expectedP90, (Double) p90.metricValue(), expectedP90 / 5); + assertEquals(expectedP99, (Double) p99.metricValue(), expectedP99 / 5); + } catch (AssertionError e) { + throw new AssertionError("Assertion failed in randomized test. Reproduce with seed = " + seed + " .", e); + } + } + + @Test + public void testRateWindowing() throws Exception { + // Use the default time window. Set 3 samples + MetricConfig cfg = new MetricConfig().samples(3); + Sensor s = metrics.sensor("test.sensor", cfg); + MetricName rateMetricName = metrics.metricName("test.rate", "grp1"); + MetricName totalMetricName = metrics.metricName("test.total", "grp1"); + MetricName countRateMetricName = metrics.metricName("test.count.rate", "grp1"); + MetricName countTotalMetricName = metrics.metricName("test.count.total", "grp1"); + s.add(new Meter(TimeUnit.SECONDS, rateMetricName, totalMetricName)); + s.add(new Meter(TimeUnit.SECONDS, new WindowedCount(), countRateMetricName, countTotalMetricName)); + KafkaMetric totalMetric = metrics.metrics().get(totalMetricName); + KafkaMetric countTotalMetric = metrics.metrics().get(countTotalMetricName); + + int sum = 0; + int count = cfg.samples() - 1; + // Advance 1 window after every record + for (int i = 0; i < count; i++) { + s.record(100); + sum += 100; + time.sleep(cfg.timeWindowMs()); + assertEquals(sum, (Double) totalMetric.metricValue(), EPS); + } + + // Sleep for half the window.
+ time.sleep(cfg.timeWindowMs() / 2); + + // prior to any time passing + double elapsedSecs = (cfg.timeWindowMs() * (cfg.samples() - 1) + cfg.timeWindowMs() / 2) / 1000.0; + + KafkaMetric rateMetric = metrics.metrics().get(rateMetricName); + KafkaMetric countRateMetric = metrics.metrics().get(countRateMetricName); + assertEquals(sum / elapsedSecs, (Double) rateMetric.metricValue(), EPS, "Rate(0...2) = 2.666"); + assertEquals(count / elapsedSecs, (Double) countRateMetric.metricValue(), EPS, "Count rate(0...2) = 0.02666"); + assertEquals(elapsedSecs, + ((Rate) rateMetric.measurable()).windowSize(cfg, time.milliseconds()) / 1000, EPS, "Elapsed Time = 75 seconds"); + assertEquals(sum, (Double) totalMetric.metricValue(), EPS); + assertEquals(count, (Double) countTotalMetric.metricValue(), EPS); + + // Verify that rates are expired, but total is cumulative + time.sleep(cfg.timeWindowMs() * cfg.samples()); + assertEquals(0, (Double) rateMetric.metricValue(), EPS); + assertEquals(0, (Double) countRateMetric.metricValue(), EPS); + assertEquals(sum, (Double) totalMetric.metricValue(), EPS); + assertEquals(count, (Double) countTotalMetric.metricValue(), EPS); + } + + public static class ConstantMeasurable implements Measurable { + public double value = 0.0; + + @Override + public double measure(MetricConfig config, long now) { + return value; + } + + } + + @Test + public void testSimpleRate() { + SimpleRate rate = new SimpleRate(); + + //Given + MetricConfig config = new MetricConfig().timeWindow(1, TimeUnit.SECONDS).samples(10); + + //In the first window the rate is a fraction of the whole (1s) window + //So when we record 1000 at t0, the rate should be 1000 until the window completes, or more data is recorded. + record(rate, config, 1000); + assertEquals(1000, measure(rate, config), 0); + time.sleep(100); + assertEquals(1000, measure(rate, config), 0); // 1000B / 0.1s + time.sleep(100); + assertEquals(1000, measure(rate, config), 0); // 1000B / 0.2s + time.sleep(200); + assertEquals(1000, measure(rate, config), 0); // 1000B / 0.4s + + //In the second (and subsequent) window(s), the rate will be in proportion to the elapsed time + //So the rate will degrade over time, as the time between measurement and the initial recording grows. + time.sleep(600); + assertEquals(1000, measure(rate, config), 0); // 1000B / 1.0s + time.sleep(200); + assertEquals(1000 / 1.2, measure(rate, config), 0); // 1000B / 1.2s + time.sleep(200); + assertEquals(1000 / 1.4, measure(rate, config), 0); // 1000B / 1.4s + + //Adding another value, inside the same window should double the rate + record(rate, config, 1000); + assertEquals(2000 / 1.4, measure(rate, config), 0); // 2000B / 1.4s + + //Going over the next window, should not change behaviour + time.sleep(1100); + assertEquals(2000 / 2.5, measure(rate, config), 0); // 2000B / 2.5s + record(rate, config, 1000); + assertEquals(3000 / 2.5, measure(rate, config), 0); // 3000B / 2.5s + + //Sleeping for another 6.5 windows also should be the same + time.sleep(6500); + assertEquals(3000 / 9, measure(rate, config), 1); // 3000B / 9s + record(rate, config, 1000); + assertEquals(4000 / 9, measure(rate, config), 1); // 4000B / 9s + + //Going over the 10 window boundary should cause the first window's values (1000) will be purged. 
+ //So the rate is calculated based on the oldest reading, which is inside the second window, at 1.4s + time.sleep(1500); + assertEquals((4000 - 1000) / (10.5 - 1.4), measure(rate, config), 1); + record(rate, config, 1000); + assertEquals((5000 - 1000) / (10.5 - 1.4), measure(rate, config), 1); + } + + private void record(Rate rate, MetricConfig config, int value) { + rate.record(config, value, time.milliseconds()); + } + + private Double measure(Measurable rate, MetricConfig config) { + return rate.measure(config, time.milliseconds()); + } + + @Test + public void testMetricInstances() { + MetricName n1 = metrics.metricInstance(SampleMetrics.METRIC1, "key1", "value1", "key2", "value2"); + Map<String, String> tags = new HashMap<String, String>(); + tags.put("key1", "value1"); + tags.put("key2", "value2"); + MetricName n2 = metrics.metricInstance(SampleMetrics.METRIC2, tags); + assertEquals(n1, n2, "metric names created in two different ways should be equal"); + + try { + metrics.metricInstance(SampleMetrics.METRIC1, "key1"); + fail("Creating MetricName with an odd number of keyValue should fail"); + } catch (IllegalArgumentException e) { + // this is expected + } + + Map<String, String> parentTagsWithValues = new HashMap<>(); + parentTagsWithValues.put("parent-tag", "parent-tag-value"); + + Map<String, String> childTagsWithValues = new HashMap<>(); + childTagsWithValues.put("child-tag", "child-tag-value"); + + try (Metrics inherited = new Metrics(new MetricConfig().tags(parentTagsWithValues), Arrays.asList(new JmxReporter()), time, true)) { + MetricName inheritedMetric = inherited.metricInstance(SampleMetrics.METRIC_WITH_INHERITED_TAGS, childTagsWithValues); + + Map<String, String> filledOutTags = inheritedMetric.tags(); + assertEquals(filledOutTags.get("parent-tag"), "parent-tag-value", "parent-tag should be set properly"); + assertEquals(filledOutTags.get("child-tag"), "child-tag-value", "child-tag should be set properly"); + + try { + inherited.metricInstance(SampleMetrics.METRIC_WITH_INHERITED_TAGS, parentTagsWithValues); + fail("Creating MetricName should fail if the child metrics are not defined at runtime"); + } catch (IllegalArgumentException e) { + // this is expected + } + + try { + + Map<String, String> runtimeTags = new HashMap<>(); + runtimeTags.put("child-tag", "child-tag-value"); + runtimeTags.put("tag-not-in-template", "unexpected-value"); + + inherited.metricInstance(SampleMetrics.METRIC_WITH_INHERITED_TAGS, runtimeTags); + fail("Creating MetricName should fail if there is a tag at runtime that is not in the template"); + } catch (IllegalArgumentException e) { + // this is expected + } + } + } + + /** + * Verifies that concurrent sensor add, remove, updates and read don't result + * in errors or deadlock. + */ + @Test + public void testConcurrentReadUpdate() throws Exception { + final Random random = new Random(); + final Deque<Sensor> sensors = new ConcurrentLinkedDeque<>(); + metrics = new Metrics(new MockTime(10)); + SensorCreator sensorCreator = new SensorCreator(metrics); + + final AtomicBoolean alive = new AtomicBoolean(true); + executorService = Executors.newSingleThreadExecutor(); + executorService.submit(new ConcurrentMetricOperation(alive, "record", + () -> sensors.forEach(sensor -> sensor.record(random.nextInt(10000))))); + + for (int i = 0; i < 10000; i++) { + if (sensors.size() > 5) { + Sensor sensor = random.nextBoolean() ?
sensors.removeFirst() : sensors.removeLast(); + metrics.removeSensor(sensor.name()); + } + StatType statType = StatType.forId(random.nextInt(StatType.values().length)); + sensors.add(sensorCreator.createSensor(statType, i)); + for (Sensor sensor : sensors) { + for (KafkaMetric metric : sensor.metrics()) { + assertNotNull(metric.metricValue(), "Invalid metric value"); + } + } + } + alive.set(false); + } + + /** + * Verifies that concurrent sensor add, remove, updates and read with a metrics reporter + * that synchronizes on every reporter method doesn't result in errors or deadlock. + */ + @Test + public void testConcurrentReadUpdateReport() throws Exception { + + class LockingReporter implements MetricsReporter { + Map<MetricName, KafkaMetric> activeMetrics = new HashMap<>(); + @Override + public synchronized void init(List<KafkaMetric> metrics) { + } + + @Override + public synchronized void metricChange(KafkaMetric metric) { + activeMetrics.put(metric.metricName(), metric); + } + + @Override + public synchronized void metricRemoval(KafkaMetric metric) { + activeMetrics.remove(metric.metricName(), metric); + } + + @Override + public synchronized void close() { + } + + @Override + public void configure(Map<String, ?> configs) { + } + + synchronized void processMetrics() { + for (KafkaMetric metric : activeMetrics.values()) { + assertNotNull(metric.metricValue(), "Invalid metric value"); + } + } + } + + final LockingReporter reporter = new LockingReporter(); + this.metrics.close(); + this.metrics = new Metrics(config, Arrays.asList(reporter), new MockTime(10), true); + final Deque<Sensor> sensors = new ConcurrentLinkedDeque<>(); + SensorCreator sensorCreator = new SensorCreator(metrics); + + final Random random = new Random(); + final AtomicBoolean alive = new AtomicBoolean(true); + executorService = Executors.newFixedThreadPool(3); + + Future<?> writeFuture = executorService.submit(new ConcurrentMetricOperation(alive, "record", + () -> sensors.forEach(sensor -> sensor.record(random.nextInt(10000))))); + Future<?> readFuture = executorService.submit(new ConcurrentMetricOperation(alive, "read", + () -> sensors.forEach(sensor -> sensor.metrics().forEach(metric -> + assertNotNull(metric.metricValue(), "Invalid metric value"))))); + Future<?> reportFuture = executorService.submit(new ConcurrentMetricOperation(alive, "report", + reporter::processMetrics)); + + for (int i = 0; i < 10000; i++) { + if (sensors.size() > 10) { + Sensor sensor = random.nextBoolean() ?
sensors.removeFirst() : sensors.removeLast(); + metrics.removeSensor(sensor.name()); + } + StatType statType = StatType.forId(random.nextInt(StatType.values().length)); + sensors.add(sensorCreator.createSensor(statType, i)); + } + assertFalse(readFuture.isDone(), "Read failed"); + assertFalse(writeFuture.isDone(), "Write failed"); + assertFalse(reportFuture.isDone(), "Report failed"); + + alive.set(false); + } + + private class ConcurrentMetricOperation implements Runnable { + private final AtomicBoolean alive; + private final String opName; + private final Runnable op; + ConcurrentMetricOperation(AtomicBoolean alive, String opName, Runnable op) { + this.alive = alive; + this.opName = opName; + this.op = op; + } + @Override + public void run() { + try { + while (alive.get()) { + op.run(); + } + } catch (Throwable t) { + log.error("Metric {} failed with exception", opName, t); + } + } + } + + enum StatType { + AVG(0), + TOTAL(1), + COUNT(2), + MAX(3), + MIN(4), + RATE(5), + SIMPLE_RATE(6), + SUM(7), + VALUE(8), + PERCENTILES(9), + METER(10); + + int id; + StatType(int id) { + this.id = id; + } + + static StatType forId(int id) { + for (StatType statType : StatType.values()) { + if (statType.id == id) + return statType; + } + return null; + } + } + + private static class SensorCreator { + + private final Metrics metrics; + + SensorCreator(Metrics metrics) { + this.metrics = metrics; + } + + private Sensor createSensor(StatType statType, int index) { + Sensor sensor = metrics.sensor("kafka.requests." + index); + Map tags = Collections.singletonMap("tag", "tag" + index); + switch (statType) { + case AVG: + sensor.add(metrics.metricName("test.metric.avg", "avg", tags), new Avg()); + break; + case TOTAL: + sensor.add(metrics.metricName("test.metric.total", "total", tags), new CumulativeSum()); + break; + case COUNT: + sensor.add(metrics.metricName("test.metric.count", "count", tags), new WindowedCount()); + break; + case MAX: + sensor.add(metrics.metricName("test.metric.max", "max", tags), new Max()); + break; + case MIN: + sensor.add(metrics.metricName("test.metric.min", "min", tags), new Min()); + break; + case RATE: + sensor.add(metrics.metricName("test.metric.rate", "rate", tags), new Rate()); + break; + case SIMPLE_RATE: + sensor.add(metrics.metricName("test.metric.simpleRate", "simpleRate", tags), new SimpleRate()); + break; + case SUM: + sensor.add(metrics.metricName("test.metric.sum", "sum", tags), new WindowedSum()); + break; + case VALUE: + sensor.add(metrics.metricName("test.metric.value", "value", tags), new Value()); + break; + case PERCENTILES: + sensor.add(metrics.metricName("test.metric.percentiles", "percentiles", tags), + new Percentiles(100, -100, 100, Percentiles.BucketSizing.CONSTANT, + new Percentile(metrics.metricName("test.median", "percentiles"), 50.0), + new Percentile(metrics.metricName("test.perc99_9", "percentiles"), 99.9))); + break; + case METER: + sensor.add(new Meter(metrics.metricName("test.metric.meter.rate", "meter", tags), + metrics.metricName("test.metric.meter.total", "meter", tags))); + break; + default: + throw new IllegalStateException("Invalid stat type " + statType); + } + return sensor; + } + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/metrics/SampleMetrics.java b/clients/src/test/java/org/apache/kafka/common/metrics/SampleMetrics.java new file mode 100644 index 0000000..1e3e817 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/metrics/SampleMetrics.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software 
Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.metrics; + +import org.apache.kafka.common.MetricNameTemplate; + +/** + * A registry of predefined Metrics for the MetricsTest.java class. + */ +public class SampleMetrics { + + public static final MetricNameTemplate METRIC1 = new MetricNameTemplate("name", "group", "The first metric used in testMetricName()", "key1", "key2"); + public static final MetricNameTemplate METRIC2 = new MetricNameTemplate("name", "group", "The second metric used in testMetricName()", "key1", "key2"); + + public static final MetricNameTemplate METRIC_WITH_INHERITED_TAGS = new MetricNameTemplate("inherited.tags", "group", "inherited.tags in testMetricName", "parent-tag", "child-tag"); +} + diff --git a/clients/src/test/java/org/apache/kafka/common/metrics/SensorTest.java b/clients/src/test/java/org/apache/kafka/common/metrics/SensorTest.java new file mode 100644 index 0000000..77176e1 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/metrics/SensorTest.java @@ -0,0 +1,373 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.metrics; + +import org.apache.kafka.common.MetricName; +import org.apache.kafka.common.metrics.stats.Avg; +import org.apache.kafka.common.metrics.stats.CumulativeCount; +import org.apache.kafka.common.metrics.stats.Meter; +import org.apache.kafka.common.metrics.stats.Rate; +import org.apache.kafka.common.metrics.stats.TokenBucket; +import org.apache.kafka.common.metrics.stats.WindowedSum; +import org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.common.utils.SystemTime; +import org.apache.kafka.common.utils.Time; +import org.junit.jupiter.api.Test; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.concurrent.Callable; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import org.mockito.Mockito; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; + +public class SensorTest { + + private static final MetricConfig INFO_CONFIG = new MetricConfig().recordLevel(Sensor.RecordingLevel.INFO); + private static final MetricConfig DEBUG_CONFIG = new MetricConfig().recordLevel(Sensor.RecordingLevel.DEBUG); + private static final MetricConfig TRACE_CONFIG = new MetricConfig().recordLevel(Sensor.RecordingLevel.TRACE); + + @Test + public void testRecordLevelEnum() { + Sensor.RecordingLevel configLevel = Sensor.RecordingLevel.INFO; + assertTrue(Sensor.RecordingLevel.INFO.shouldRecord(configLevel.id)); + assertFalse(Sensor.RecordingLevel.DEBUG.shouldRecord(configLevel.id)); + assertFalse(Sensor.RecordingLevel.TRACE.shouldRecord(configLevel.id)); + + configLevel = Sensor.RecordingLevel.DEBUG; + assertTrue(Sensor.RecordingLevel.INFO.shouldRecord(configLevel.id)); + assertTrue(Sensor.RecordingLevel.DEBUG.shouldRecord(configLevel.id)); + assertFalse(Sensor.RecordingLevel.TRACE.shouldRecord(configLevel.id)); + + configLevel = Sensor.RecordingLevel.TRACE; + assertTrue(Sensor.RecordingLevel.INFO.shouldRecord(configLevel.id)); + assertTrue(Sensor.RecordingLevel.DEBUG.shouldRecord(configLevel.id)); + assertTrue(Sensor.RecordingLevel.TRACE.shouldRecord(configLevel.id)); + + assertEquals(Sensor.RecordingLevel.valueOf(Sensor.RecordingLevel.DEBUG.toString()), + Sensor.RecordingLevel.DEBUG); + assertEquals(Sensor.RecordingLevel.valueOf(Sensor.RecordingLevel.INFO.toString()), + Sensor.RecordingLevel.INFO); + assertEquals(Sensor.RecordingLevel.valueOf(Sensor.RecordingLevel.TRACE.toString()), + Sensor.RecordingLevel.TRACE); + } + + @Test + public void testShouldRecordForInfoLevelSensor() { + Sensor infoSensor = new Sensor(null, "infoSensor", null, INFO_CONFIG, new SystemTime(), + 0, Sensor.RecordingLevel.INFO); + assertTrue(infoSensor.shouldRecord()); + + infoSensor = new Sensor(null, "infoSensor", null, DEBUG_CONFIG, new SystemTime(), + 0, Sensor.RecordingLevel.INFO); + assertTrue(infoSensor.shouldRecord()); + + infoSensor = new Sensor(null, "infoSensor", null, TRACE_CONFIG, new SystemTime(), + 0, Sensor.RecordingLevel.INFO); + 
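+ // An INFO-level sensor should record under any configured recording level, including TRACE.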
assertTrue(infoSensor.shouldRecord()); + } + + @Test + public void testShouldRecordForDebugLevelSensor() { + Sensor debugSensor = new Sensor(null, "debugSensor", null, INFO_CONFIG, new SystemTime(), + 0, Sensor.RecordingLevel.DEBUG); + assertFalse(debugSensor.shouldRecord()); + + debugSensor = new Sensor(null, "debugSensor", null, DEBUG_CONFIG, new SystemTime(), + 0, Sensor.RecordingLevel.DEBUG); + assertTrue(debugSensor.shouldRecord()); + + debugSensor = new Sensor(null, "debugSensor", null, TRACE_CONFIG, new SystemTime(), + 0, Sensor.RecordingLevel.DEBUG); + assertTrue(debugSensor.shouldRecord()); + } + + @Test + public void testShouldRecordForTraceLevelSensor() { + Sensor traceSensor = new Sensor(null, "traceSensor", null, INFO_CONFIG, new SystemTime(), + 0, Sensor.RecordingLevel.TRACE); + assertFalse(traceSensor.shouldRecord()); + + traceSensor = new Sensor(null, "traceSensor", null, DEBUG_CONFIG, new SystemTime(), + 0, Sensor.RecordingLevel.TRACE); + assertFalse(traceSensor.shouldRecord()); + + traceSensor = new Sensor(null, "traceSensor", null, TRACE_CONFIG, new SystemTime(), + 0, Sensor.RecordingLevel.TRACE); + assertTrue(traceSensor.shouldRecord()); + } + + @Test + public void testExpiredSensor() { + MetricConfig config = new MetricConfig(); + Time mockTime = new MockTime(); + try (Metrics metrics = new Metrics(config, Arrays.asList(new JmxReporter()), mockTime, true)) { + long inactiveSensorExpirationTimeSeconds = 60L; + Sensor sensor = new Sensor(metrics, "sensor", null, config, mockTime, + inactiveSensorExpirationTimeSeconds, Sensor.RecordingLevel.INFO); + + assertTrue(sensor.add(metrics.metricName("test1", "grp1"), new Avg())); + + Map emptyTags = Collections.emptyMap(); + MetricName rateMetricName = new MetricName("rate", "test", "", emptyTags); + MetricName totalMetricName = new MetricName("total", "test", "", emptyTags); + Meter meter = new Meter(rateMetricName, totalMetricName); + assertTrue(sensor.add(meter)); + + mockTime.sleep(TimeUnit.SECONDS.toMillis(inactiveSensorExpirationTimeSeconds + 1)); + assertFalse(sensor.add(metrics.metricName("test3", "grp1"), new Avg())); + assertFalse(sensor.add(meter)); + } + } + + @Test + public void testIdempotentAdd() { + final Metrics metrics = new Metrics(); + final Sensor sensor = metrics.sensor("sensor"); + + assertTrue(sensor.add(metrics.metricName("test-metric", "test-group"), new Avg())); + + // adding the same metric to the same sensor is a no-op + assertTrue(sensor.add(metrics.metricName("test-metric", "test-group"), new Avg())); + + + // but adding the same metric to a DIFFERENT sensor is an error + final Sensor anotherSensor = metrics.sensor("another-sensor"); + try { + anotherSensor.add(metrics.metricName("test-metric", "test-group"), new Avg()); + fail("should have thrown"); + } catch (final IllegalArgumentException ignored) { + // pass + } + + // note that adding a different metric with the same name is also a no-op + assertTrue(sensor.add(metrics.metricName("test-metric", "test-group"), new WindowedSum())); + + // so after all this, we still just have the original metric registered + assertEquals(1, sensor.metrics().size()); + assertEquals(org.apache.kafka.common.metrics.stats.Avg.class, sensor.metrics().get(0).measurable().getClass()); + } + + /** + * The Sensor#checkQuotas should be thread-safe since the method may be used by many ReplicaFetcherThreads. 
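+ * Each worker thread below records values and calls checkQuotas concurrently; any throwable it hits is captured and asserted to be null.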
+ */ + @Test + public void testCheckQuotasInMultiThreads() throws InterruptedException, ExecutionException { + final Metrics metrics = new Metrics(new MetricConfig().quota(Quota.upperBound(Double.MAX_VALUE)) + // decreasing the value of time window make SampledStat always record the given value + .timeWindow(1, TimeUnit.MILLISECONDS) + // increasing the value of samples make SampledStat store more samples + .samples(100)); + final Sensor sensor = metrics.sensor("sensor"); + + assertTrue(sensor.add(metrics.metricName("test-metric", "test-group"), new Rate())); + final int threadCount = 10; + final CountDownLatch latch = new CountDownLatch(1); + ExecutorService service = Executors.newFixedThreadPool(threadCount); + List> workers = new ArrayList<>(threadCount); + boolean needShutdown = true; + try { + for (int i = 0; i != threadCount; ++i) { + final int index = i; + workers.add(service.submit(new Callable() { + @Override + public Throwable call() { + try { + assertTrue(latch.await(5, TimeUnit.SECONDS)); + for (int j = 0; j != 20; ++j) { + sensor.record(j * index, System.currentTimeMillis() + j, false); + sensor.checkQuotas(); + } + return null; + } catch (Throwable e) { + return e; + } + } + })); + } + latch.countDown(); + service.shutdown(); + assertTrue(service.awaitTermination(10, TimeUnit.SECONDS)); + needShutdown = false; + for (Future callable : workers) { + assertTrue(callable.isDone(), "If this failure happen frequently, we can try to increase the wait time"); + assertNull(callable.get(), "Sensor#checkQuotas SHOULD be thread-safe!"); + } + } finally { + if (needShutdown) { + service.shutdownNow(); + } + } + } + + @Test + public void shouldReturnPresenceOfMetrics() { + final Metrics metrics = new Metrics(); + final Sensor sensor = metrics.sensor("sensor"); + + assertFalse(sensor.hasMetrics()); + + sensor.add( + new MetricName("name1", "group1", "description1", Collections.emptyMap()), + new WindowedSum() + ); + + assertTrue(sensor.hasMetrics()); + + sensor.add( + new MetricName("name2", "group2", "description2", Collections.emptyMap()), + new CumulativeCount() + ); + + assertTrue(sensor.hasMetrics()); + } + + @Test + public void testStrictQuotaEnforcementWithRate() { + final Time time = new MockTime(0, System.currentTimeMillis(), 0); + final Metrics metrics = new Metrics(time); + final Sensor sensor = metrics.sensor("sensor", new MetricConfig() + .quota(Quota.upperBound(2)) + .timeWindow(1, TimeUnit.SECONDS) + .samples(11)); + final MetricName metricName = metrics.metricName("rate", "test-group"); + assertTrue(sensor.add(metricName, new Rate())); + final KafkaMetric rateMetric = metrics.metric(metricName); + + // Recording a first value at T+0 to bring the avg rate to 3 which is already + // above the quota. + strictRecord(sensor, 30, time.milliseconds()); + assertEquals(3, rateMetric.measurableValue(time.milliseconds()), 0.1); + + // Theoretically, we should wait 5s to bring back the avg rate to the define quota: + // ((30 / 10) - 2) / 2 * 10 = 5s + time.sleep(5000); + + // But, recording a second value is rejected because the avg rate is still equal + // to 3 after 5s. 
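+ // Presumably the windowed Rate still measures the original 30 over its full ~10 s sample span (30 / 10 = 3), which stays above the quota of 2; contrast with the token bucket test below.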
+ assertEquals(3, rateMetric.measurableValue(time.milliseconds()), 0.1); + assertThrows(QuotaViolationException.class, () -> strictRecord(sensor, 30, time.milliseconds())); + + metrics.close(); + } + + @Test + public void testStrictQuotaEnforcementWithTokenBucket() { + final Time time = new MockTime(0, System.currentTimeMillis(), 0); + final Metrics metrics = new Metrics(time); + final Sensor sensor = metrics.sensor("sensor", new MetricConfig() + .quota(Quota.upperBound(2)) + .timeWindow(1, TimeUnit.SECONDS) + .samples(10)); + final MetricName metricName = metrics.metricName("credits", "test-group"); + assertTrue(sensor.add(metricName, new TokenBucket())); + final KafkaMetric tkMetric = metrics.metric(metricName); + + // Recording a first value at T+0 to bring the remaining credits below zero + strictRecord(sensor, 30, time.milliseconds()); + assertEquals(-10, tkMetric.measurableValue(time.milliseconds()), 0.1); + + // Theoretically, we should wait 5s to bring back the avg rate to the define quota: + // 10 / 2 = 5s + time.sleep(5000); + + // Unlike the default rate based on a windowed sum, it works as expected. + assertEquals(0, tkMetric.measurableValue(time.milliseconds()), 0.1); + strictRecord(sensor, 30, time.milliseconds()); + assertEquals(-30, tkMetric.measurableValue(time.milliseconds()), 0.1); + + metrics.close(); + } + + private void strictRecord(Sensor sensor, double value, long timeMs) { + synchronized (sensor) { + sensor.checkQuotas(timeMs); + sensor.record(value, timeMs, false); + } + } + + @Test + public void testRecordAndCheckQuotaUseMetricConfigOfEachStat() { + final Time time = new MockTime(0, System.currentTimeMillis(), 0); + final Metrics metrics = new Metrics(time); + final Sensor sensor = metrics.sensor("sensor"); + + final MeasurableStat stat1 = Mockito.mock(MeasurableStat.class); + final MetricName stat1Name = metrics.metricName("stat1", "test-group"); + final MetricConfig stat1Config = new MetricConfig().quota(Quota.upperBound(5)); + sensor.add(stat1Name, stat1, stat1Config); + + final MeasurableStat stat2 = Mockito.mock(MeasurableStat.class); + final MetricName stat2Name = metrics.metricName("stat2", "test-group"); + final MetricConfig stat2Config = new MetricConfig().quota(Quota.upperBound(10)); + sensor.add(stat2Name, stat2, stat2Config); + + sensor.record(10, 1); + Mockito.verify(stat1).record(stat1Config, 10, 1); + Mockito.verify(stat2).record(stat2Config, 10, 1); + + sensor.checkQuotas(2); + Mockito.verify(stat1).measure(stat1Config, 2); + Mockito.verify(stat2).measure(stat2Config, 2); + + metrics.close(); + } + + @Test + public void testUpdatingMetricConfigIsReflectedInTheSensor() { + final Time time = new MockTime(0, System.currentTimeMillis(), 0); + final Metrics metrics = new Metrics(time); + final Sensor sensor = metrics.sensor("sensor"); + + final MeasurableStat stat = Mockito.mock(MeasurableStat.class); + final MetricName statName = metrics.metricName("stat", "test-group"); + final MetricConfig statConfig = new MetricConfig().quota(Quota.upperBound(5)); + sensor.add(statName, stat, statConfig); + + sensor.record(10, 1); + Mockito.verify(stat).record(statConfig, 10, 1); + + sensor.checkQuotas(2); + Mockito.verify(stat).measure(statConfig, 2); + + // Update the config of the KafkaMetric + final MetricConfig newConfig = new MetricConfig().quota(Quota.upperBound(10)); + metrics.metric(statName).config(newConfig); + + sensor.record(10, 3); + Mockito.verify(stat).record(newConfig, 10, 3); + + sensor.checkQuotas(4); + Mockito.verify(stat).measure(newConfig, 4); 
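+ // As verified above, updating the KafkaMetric's config is picked up by subsequent record and checkQuotas calls on the sensor.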
+ + metrics.close(); + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/metrics/TokenBucketTest.java b/clients/src/test/java/org/apache/kafka/common/metrics/TokenBucketTest.java new file mode 100644 index 0000000..f0e58a4 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/metrics/TokenBucketTest.java @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.metrics; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import java.util.concurrent.TimeUnit; +import org.apache.kafka.common.metrics.stats.TokenBucket; +import org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.common.utils.Time; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +public class TokenBucketTest { + Time time; + + @BeforeEach + public void setup() { + time = new MockTime(0, System.currentTimeMillis(), System.nanoTime()); + } + + @Test + public void testRecord() { + // Rate = 5 unit / sec + // Burst = 2 * 10 = 20 units + MetricConfig config = new MetricConfig() + .quota(Quota.upperBound(5)) + .timeWindow(2, TimeUnit.SECONDS) + .samples(10); + + TokenBucket tk = new TokenBucket(); + + // Expect 100 credits at T + assertEquals(100, tk.measure(config, time.milliseconds()), 0.1); + + // Record 60 at T, expect 13 credits + tk.record(config, 60, time.milliseconds()); + assertEquals(40, tk.measure(config, time.milliseconds()), 0.1); + + // Advance by 2s, record 5, expect 45 credits + time.sleep(2000); + tk.record(config, 5, time.milliseconds()); + assertEquals(45, tk.measure(config, time.milliseconds()), 0.1); + + // Advance by 2s, record 60, expect -5 credits + time.sleep(2000); + tk.record(config, 60, time.milliseconds()); + assertEquals(-5, tk.measure(config, time.milliseconds()), 0.1); + } + + @Test + public void testUnrecord() { + // Rate = 5 unit / sec + // Burst = 2 * 10 = 20 units + MetricConfig config = new MetricConfig() + .quota(Quota.upperBound(5)) + .timeWindow(2, TimeUnit.SECONDS) + .samples(10); + + TokenBucket tk = new TokenBucket(); + + // Expect 100 credits at T + assertEquals(100, tk.measure(config, time.milliseconds()), 0.1); + + // Record -60 at T, expect 100 credits + tk.record(config, -60, time.milliseconds()); + assertEquals(100, tk.measure(config, time.milliseconds()), 0.1); + + // Advance by 2s, record 60, expect 40 credits + time.sleep(2000); + tk.record(config, 60, time.milliseconds()); + assertEquals(40, tk.measure(config, time.milliseconds()), 0.1); + + // Advance by 2s, record -60, expect 100 credits + time.sleep(2000); + tk.record(config, -60, time.milliseconds()); + assertEquals(100, tk.measure(config, time.milliseconds()), 0.1); + } +} diff --git 
a/clients/src/test/java/org/apache/kafka/common/metrics/internals/IntGaugeSuiteTest.java b/clients/src/test/java/org/apache/kafka/common/metrics/internals/IntGaugeSuiteTest.java new file mode 100644 index 0000000..0e6a913 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/metrics/internals/IntGaugeSuiteTest.java @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.common.metrics.internals; + +import org.apache.kafka.common.MetricName; +import org.apache.kafka.common.metrics.MetricConfig; +import org.apache.kafka.common.metrics.Metrics; +import org.junit.jupiter.api.Test; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Collections; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; + +public class IntGaugeSuiteTest { + private static final Logger log = LoggerFactory.getLogger(IntGaugeSuiteTest.class); + + private static IntGaugeSuite createIntGaugeSuite() { + MetricConfig config = new MetricConfig(); + Metrics metrics = new Metrics(config); + IntGaugeSuite suite = new IntGaugeSuite<>(log, + "mySuite", + metrics, + name -> new MetricName(name, "group", "myMetric", Collections.emptyMap()), + 3); + return suite; + } + + @Test + public void testCreateAndClose() { + IntGaugeSuite suite = createIntGaugeSuite(); + assertEquals(3, suite.maxEntries()); + suite.close(); + suite.close(); + suite.metrics().close(); + } + + @Test + public void testCreateMetrics() { + IntGaugeSuite suite = createIntGaugeSuite(); + suite.increment("foo"); + Map values = suite.values(); + assertEquals(Integer.valueOf(1), values.get("foo")); + assertEquals(1, values.size()); + suite.increment("foo"); + suite.increment("bar"); + suite.increment("baz"); + suite.increment("quux"); + values = suite.values(); + assertEquals(Integer.valueOf(2), values.get("foo")); + assertEquals(Integer.valueOf(1), values.get("bar")); + assertEquals(Integer.valueOf(1), values.get("baz")); + assertEquals(3, values.size()); + assertFalse(values.containsKey("quux")); + suite.close(); + suite.metrics().close(); + } + + @Test + public void testCreateAndRemoveMetrics() { + IntGaugeSuite suite = createIntGaugeSuite(); + suite.increment("foo"); + suite.decrement("foo"); + suite.increment("foo"); + suite.increment("foo"); + suite.increment("bar"); + suite.decrement("bar"); + suite.increment("baz"); + suite.increment("quux"); + Map values = suite.values(); + assertEquals(Integer.valueOf(2), values.get("foo")); + assertFalse(values.containsKey("bar")); + assertEquals(Integer.valueOf(1), values.get("baz")); + assertEquals(Integer.valueOf(1), values.get("quux")); + assertEquals(3, values.size()); + suite.close(); + 
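+ // The suite's close() does not appear to close the underlying Metrics registry, so it is closed explicitly on the next line.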
suite.metrics().close(); + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/metrics/internals/MetricsUtilsTest.java b/clients/src/test/java/org/apache/kafka/common/metrics/internals/MetricsUtilsTest.java new file mode 100644 index 0000000..06268c5 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/metrics/internals/MetricsUtilsTest.java @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.metrics.internals; + +import org.junit.jupiter.api.Test; + +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +public class MetricsUtilsTest { + + @Test + public void testCreatingTags() { + Map tags = MetricsUtils.getTags("k1", "v1", "k2", "v2"); + assertEquals("v1", tags.get("k1")); + assertEquals("v2", tags.get("k2")); + assertEquals(2, tags.size()); + } + + @Test + public void testCreatingTagsWithOddNumberOfTags() { + assertThrows(IllegalArgumentException.class, () -> MetricsUtils.getTags("k1", "v1", "k2", "v2", "extra")); + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/metrics/stats/FrequenciesTest.java b/clients/src/test/java/org/apache/kafka/common/metrics/stats/FrequenciesTest.java new file mode 100644 index 0000000..b44306b --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/metrics/stats/FrequenciesTest.java @@ -0,0 +1,159 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.metrics.stats; + +import org.apache.kafka.common.Metric; +import org.apache.kafka.common.MetricName; +import org.apache.kafka.common.metrics.CompoundStat.NamedMeasurable; +import org.apache.kafka.common.metrics.JmxReporter; +import org.apache.kafka.common.metrics.MetricConfig; +import org.apache.kafka.common.metrics.Metrics; +import org.apache.kafka.common.metrics.Sensor; +import org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.common.utils.Time; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.util.Arrays; +import java.util.Collections; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +public class FrequenciesTest { + + private static final double DELTA = 0.0001d; + private MetricConfig config; + private Time time; + private Metrics metrics; + + @BeforeEach + public void setup() { + config = new MetricConfig().eventWindow(50).samples(2); + time = new MockTime(); + metrics = new Metrics(config, Arrays.asList(new JmxReporter()), time, true); + } + + @AfterEach + public void tearDown() { + metrics.close(); + } + + @Test + public void testFrequencyCenterValueAboveMax() { + assertThrows(IllegalArgumentException.class, + () -> new Frequencies(4, 1.0, 4.0, freq("1", 1.0), freq("2", 20.0))); + } + + @Test + public void testFrequencyCenterValueBelowMin() { + assertThrows(IllegalArgumentException.class, + () -> new Frequencies(4, 1.0, 4.0, freq("1", 1.0), freq("2", -20.0))); + } + + @Test + public void testMoreFrequencyParametersThanBuckets() { + assertThrows(IllegalArgumentException.class, + () -> new Frequencies(1, 1.0, 4.0, freq("1", 1.0), freq("2", -20.0))); + } + + @Test + public void testBooleanFrequencies() { + MetricName metricTrue = name("true"); + MetricName metricFalse = name("false"); + Frequencies frequencies = Frequencies.forBooleanValues(metricFalse, metricTrue); + final NamedMeasurable falseMetric = frequencies.stats().get(0); + final NamedMeasurable trueMetric = frequencies.stats().get(1); + + // Record 2 windows worth of values + for (int i = 0; i != 25; ++i) { + frequencies.record(config, 0.0, time.milliseconds()); + } + for (int i = 0; i != 75; ++i) { + frequencies.record(config, 1.0, time.milliseconds()); + } + assertEquals(0.25, falseMetric.stat().measure(config, time.milliseconds()), DELTA); + assertEquals(0.75, trueMetric.stat().measure(config, time.milliseconds()), DELTA); + + // Record 2 more windows worth of values + for (int i = 0; i != 40; ++i) { + frequencies.record(config, 0.0, time.milliseconds()); + } + for (int i = 0; i != 60; ++i) { + frequencies.record(config, 1.0, time.milliseconds()); + } + assertEquals(0.40, falseMetric.stat().measure(config, time.milliseconds()), DELTA); + assertEquals(0.60, trueMetric.stat().measure(config, time.milliseconds()), DELTA); + } + + @Test + public void testUseWithMetrics() { + MetricName name1 = name("1"); + MetricName name2 = name("2"); + MetricName name3 = name("3"); + MetricName name4 = name("4"); + Frequencies frequencies = new Frequencies(4, 1.0, 4.0, + new Frequency(name1, 1.0), + new Frequency(name2, 2.0), + new Frequency(name3, 3.0), + new Frequency(name4, 4.0)); + Sensor sensor = metrics.sensor("test", config); + sensor.add(frequencies); + Metric metric1 = this.metrics.metrics().get(name1); + Metric metric2 = this.metrics.metrics().get(name2); + Metric metric3 = this.metrics.metrics().get(name3); + Metric 
metric4 = this.metrics.metrics().get(name4); + + // Record 2 windows worth of values + for (int i = 0; i != 100; ++i) { + frequencies.record(config, i % 4 + 1, time.milliseconds()); + } + assertEquals(0.25, (Double) metric1.metricValue(), DELTA); + assertEquals(0.25, (Double) metric2.metricValue(), DELTA); + assertEquals(0.25, (Double) metric3.metricValue(), DELTA); + assertEquals(0.25, (Double) metric4.metricValue(), DELTA); + + // Record 2 windows worth of values + for (int i = 0; i != 100; ++i) { + frequencies.record(config, i % 2 + 1, time.milliseconds()); + } + assertEquals(0.50, (Double) metric1.metricValue(), DELTA); + assertEquals(0.50, (Double) metric2.metricValue(), DELTA); + assertEquals(0.00, (Double) metric3.metricValue(), DELTA); + assertEquals(0.00, (Double) metric4.metricValue(), DELTA); + + // Record 1 window worth of values to overlap with the last window + // that is half 1.0 and half 2.0 + for (int i = 0; i != 50; ++i) { + frequencies.record(config, 4.0, time.milliseconds()); + } + assertEquals(0.25, (Double) metric1.metricValue(), DELTA); + assertEquals(0.25, (Double) metric2.metricValue(), DELTA); + assertEquals(0.00, (Double) metric3.metricValue(), DELTA); + assertEquals(0.50, (Double) metric4.metricValue(), DELTA); + } + + protected MetricName name(String metricName) { + return new MetricName(metricName, "group-id", "desc", Collections.emptyMap()); + } + + protected Frequency freq(String name, double value) { + return new Frequency(name(name), value); + } + +} diff --git a/clients/src/test/java/org/apache/kafka/common/metrics/stats/HistogramTest.java b/clients/src/test/java/org/apache/kafka/common/metrics/stats/HistogramTest.java new file mode 100644 index 0000000..fe2760d --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/metrics/stats/HistogramTest.java @@ -0,0 +1,176 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.metrics.stats; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import java.util.Arrays; +import java.util.Random; + + +import org.apache.kafka.common.metrics.stats.Histogram.BinScheme; +import org.apache.kafka.common.metrics.stats.Histogram.ConstantBinScheme; +import org.apache.kafka.common.metrics.stats.Histogram.LinearBinScheme; +import org.junit.jupiter.api.Test; + +public class HistogramTest { + + private static final double EPS = 0.0000001d; + + @Test + public void testHistogram() { + BinScheme scheme = new ConstantBinScheme(10, -5, 5); + Histogram hist = new Histogram(scheme); + for (int i = -5; i < 5; i++) + hist.record(i); + for (int i = 0; i < 10; i++) + assertEquals(scheme.fromBin(i), hist.value(i / 10.0 + EPS), EPS); + } + + @Test + public void testConstantBinScheme() { + ConstantBinScheme scheme = new ConstantBinScheme(5, -5, 5); + assertEquals(0, scheme.toBin(-5.01), "A value below the lower bound should map to the first bin"); + assertEquals(4, scheme.toBin(5.01), "A value above the upper bound should map to the last bin"); + assertEquals(0, scheme.toBin(-5.0001), "Check boundary of bucket 0"); + assertEquals(0, scheme.toBin(-5.0000), "Check boundary of bucket 0"); + assertEquals(0, scheme.toBin(-4.99999), "Check boundary of bucket 0"); + assertEquals(0, scheme.toBin(-3.00001), "Check boundary of bucket 0"); + assertEquals(1, scheme.toBin(-3), "Check boundary of bucket 1"); + assertEquals(1, scheme.toBin(-1.00001), "Check boundary of bucket 1"); + assertEquals(2, scheme.toBin(-1), "Check boundary of bucket 2"); + assertEquals(2, scheme.toBin(0.99999), "Check boundary of bucket 2"); + assertEquals(3, scheme.toBin(1), "Check boundary of bucket 3"); + assertEquals(3, scheme.toBin(2.99999), "Check boundary of bucket 3"); + assertEquals(4, scheme.toBin(3), "Check boundary of bucket 4"); + assertEquals(4, scheme.toBin(4.9999), "Check boundary of bucket 4"); + assertEquals(4, scheme.toBin(5.000), "Check boundary of bucket 4"); + assertEquals(4, scheme.toBin(5.001), "Check boundary of bucket 4"); + assertEquals(Float.NEGATIVE_INFINITY, scheme.fromBin(-1), 0.001d); + assertEquals(Float.POSITIVE_INFINITY, scheme.fromBin(5), 0.001d); + assertEquals(-5.0, scheme.fromBin(0), 0.001d); + assertEquals(-3.0, scheme.fromBin(1), 0.001d); + assertEquals(-1.0, scheme.fromBin(2), 0.001d); + assertEquals(1.0, scheme.fromBin(3), 0.001d); + assertEquals(3.0, scheme.fromBin(4), 0.001d); + checkBinningConsistency(scheme); + } + + @Test + public void testConstantBinSchemeWithPositiveRange() { + ConstantBinScheme scheme = new ConstantBinScheme(5, 0, 5); + assertEquals(0, scheme.toBin(-1.0), "A value below the lower bound should map to the first bin"); + assertEquals(4, scheme.toBin(5.01), "A value above the upper bound should map to the last bin"); + assertEquals(0, scheme.toBin(-0.0001), "Check boundary of bucket 0"); + assertEquals(0, scheme.toBin(0.0000), "Check boundary of bucket 0"); + assertEquals(0, scheme.toBin(0.0001), "Check boundary of bucket 0"); + assertEquals(0, scheme.toBin(0.9999), "Check boundary of bucket 0"); + assertEquals(1, scheme.toBin(1.0000), "Check boundary of bucket 1"); + assertEquals(1, scheme.toBin(1.0001), "Check boundary of bucket 1"); + assertEquals(1, scheme.toBin(1.9999), "Check boundary of bucket 1"); + assertEquals(2, scheme.toBin(2.0000), "Check boundary of bucket 2"); + assertEquals(2, scheme.toBin(2.0001), "Check boundary of bucket 2"); + assertEquals(2, scheme.toBin(2.9999), "Check boundary of bucket 2"); + 
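+ // With 5 bins spanning [0, 5], each bucket is 1.0 wide; values outside the range clamp to the first or last bin.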
assertEquals(3, scheme.toBin(3.0000), "Check boundary of bucket 3"); + assertEquals(3, scheme.toBin(3.0001), "Check boundary of bucket 3"); + assertEquals(3, scheme.toBin(3.9999), "Check boundary of bucket 3"); + assertEquals(4, scheme.toBin(4.0000), "Check boundary of bucket 4"); + assertEquals(4, scheme.toBin(4.9999), "Check boundary of bucket 4"); + assertEquals(4, scheme.toBin(5.0000), "Check boundary of bucket 4"); + assertEquals(4, scheme.toBin(5.0001), "Check boundary of bucket 4"); + assertEquals(Float.NEGATIVE_INFINITY, scheme.fromBin(-1), 0.001d); + assertEquals(Float.POSITIVE_INFINITY, scheme.fromBin(5), 0.001d); + assertEquals(0.0, scheme.fromBin(0), 0.001d); + assertEquals(1.0, scheme.fromBin(1), 0.001d); + assertEquals(2.0, scheme.fromBin(2), 0.001d); + assertEquals(3.0, scheme.fromBin(3), 0.001d); + assertEquals(4.0, scheme.fromBin(4), 0.001d); + checkBinningConsistency(scheme); + } + + @Test + public void testLinearBinScheme() { + LinearBinScheme scheme = new LinearBinScheme(10, 10); + assertEquals(Float.NEGATIVE_INFINITY, scheme.fromBin(-1), 0.001d); + assertEquals(Float.POSITIVE_INFINITY, scheme.fromBin(11), 0.001d); + assertEquals(0.0, scheme.fromBin(0), 0.001d); + assertEquals(0.2222, scheme.fromBin(1), 0.001d); + assertEquals(0.6666, scheme.fromBin(2), 0.001d); + assertEquals(1.3333, scheme.fromBin(3), 0.001d); + assertEquals(2.2222, scheme.fromBin(4), 0.001d); + assertEquals(3.3333, scheme.fromBin(5), 0.001d); + assertEquals(4.6667, scheme.fromBin(6), 0.001d); + assertEquals(6.2222, scheme.fromBin(7), 0.001d); + assertEquals(8.0000, scheme.fromBin(8), 0.001d); + assertEquals(10.000, scheme.fromBin(9), 0.001d); + assertEquals(0, scheme.toBin(0.0000)); + assertEquals(0, scheme.toBin(0.2221)); + assertEquals(1, scheme.toBin(0.2223)); + assertEquals(2, scheme.toBin(0.6667)); + assertEquals(3, scheme.toBin(1.3334)); + assertEquals(4, scheme.toBin(2.2223)); + assertEquals(5, scheme.toBin(3.3334)); + assertEquals(6, scheme.toBin(4.6667)); + assertEquals(7, scheme.toBin(6.2223)); + assertEquals(8, scheme.toBin(8.0000)); + assertEquals(9, scheme.toBin(10.000)); + assertEquals(9, scheme.toBin(10.001)); + assertEquals(Float.POSITIVE_INFINITY, scheme.fromBin(10), 0.001d); + checkBinningConsistency(scheme); + } + + private void checkBinningConsistency(BinScheme scheme) { + for (int bin = 0; bin < scheme.bins(); bin++) { + double fromBin = scheme.fromBin(bin); + int binAgain = scheme.toBin(fromBin + EPS); + assertEquals(bin, binAgain, "unbinning and rebinning the bin " + bin + + " gave a different result (" + + fromBin + + " was placed in bin " + + binAgain + + " )"); + } + } + + public static void main(String[] args) { + Random random = new Random(); + System.out.println("[-100, 100]:"); + for (BinScheme scheme : Arrays.asList(new ConstantBinScheme(1000, -100, 100), + new ConstantBinScheme(100, -100, 100), + new ConstantBinScheme(10, -100, 100))) { + Histogram h = new Histogram(scheme); + for (int i = 0; i < 10000; i++) + h.record(200.0 * random.nextDouble() - 100.0); + for (double quantile = 0.0; quantile < 1.0; quantile += 0.05) + System.out.printf("%5.2f: %.1f, ", quantile, h.value(quantile)); + System.out.println(); + } + + System.out.println("[0, 1000]"); + for (BinScheme scheme : Arrays.asList(new LinearBinScheme(1000, 1000), + new LinearBinScheme(100, 1000), + new LinearBinScheme(10, 1000))) { + Histogram h = new Histogram(scheme); + for (int i = 0; i < 10000; i++) + h.record(1000.0 * random.nextDouble()); + for (double quantile = 0.0; quantile < 1.0; quantile += 0.05) + 
System.out.printf("%5.2f: %.1f, ", quantile, h.value(quantile)); + System.out.println(); + } + } + +} diff --git a/clients/src/test/java/org/apache/kafka/common/metrics/stats/MeterTest.java b/clients/src/test/java/org/apache/kafka/common/metrics/stats/MeterTest.java new file mode 100644 index 0000000..8d33e61 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/metrics/stats/MeterTest.java @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.metrics.stats; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import java.util.Collections; +import java.util.List; +import java.util.Map; + +import org.apache.kafka.common.MetricName; +import org.apache.kafka.common.metrics.CompoundStat.NamedMeasurable; +import org.apache.kafka.common.metrics.MetricConfig; +import org.junit.jupiter.api.Test; + +public class MeterTest { + + private static final double EPS = 0.0000001d; + + @Test + public void testMeter() { + Map emptyTags = Collections.emptyMap(); + MetricName rateMetricName = new MetricName("rate", "test", "", emptyTags); + MetricName totalMetricName = new MetricName("total", "test", "", emptyTags); + Meter meter = new Meter(rateMetricName, totalMetricName); + List stats = meter.stats(); + assertEquals(2, stats.size()); + NamedMeasurable total = stats.get(0); + NamedMeasurable rate = stats.get(1); + assertEquals(rateMetricName, rate.name()); + assertEquals(totalMetricName, total.name()); + Rate rateStat = (Rate) rate.stat(); + CumulativeSum totalStat = (CumulativeSum) total.stat(); + + MetricConfig config = new MetricConfig(); + double nextValue = 0.0; + double expectedTotal = 0.0; + long now = 0; + double intervalMs = 100; + double delta = 5.0; + + // Record values in multiple windows and verify that rates are reported + // for time windows and that the total is cumulative. 
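+ // Each outer iteration advances one second in 100 ms steps, recording a value that grows by 5.0 per step.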
+ for (int i = 1; i <= 100; i++) { + for (; now < i * 1000; now += intervalMs, nextValue += delta) { + expectedTotal += nextValue; + meter.record(config, nextValue, now); + } + assertEquals(expectedTotal, totalStat.measure(config, now), EPS); + long windowSizeMs = rateStat.windowSize(config, now); + long windowStartMs = Math.max(now - windowSizeMs, 0); + double sampledTotal = 0.0; + double prevValue = nextValue - delta; + for (long timeMs = now - 100; timeMs >= windowStartMs; timeMs -= intervalMs, prevValue -= delta) + sampledTotal += prevValue; + assertEquals(sampledTotal * 1000 / windowSizeMs, rateStat.measure(config, now), EPS); + } + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/network/CertStores.java b/clients/src/test/java/org/apache/kafka/common/network/CertStores.java new file mode 100644 index 0000000..3230da2 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/network/CertStores.java @@ -0,0 +1,157 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.network; + +import java.util.ArrayList; +import java.util.List; +import org.apache.kafka.common.config.SslConfigs; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.test.TestSslUtils; + +import java.io.File; +import java.net.InetAddress; +import java.util.HashMap; +import java.util.Map; +import java.util.Set; +import org.apache.kafka.test.TestSslUtils.SslConfigsBuilder; + +public class CertStores { + + public static final Set KEYSTORE_PROPS = Utils.mkSet( + SslConfigs.SSL_KEYSTORE_LOCATION_CONFIG, + SslConfigs.SSL_KEYSTORE_TYPE_CONFIG, + SslConfigs.SSL_KEYSTORE_PASSWORD_CONFIG, + SslConfigs.SSL_KEY_PASSWORD_CONFIG, + SslConfigs.SSL_KEYSTORE_KEY_CONFIG, + SslConfigs.SSL_KEYSTORE_CERTIFICATE_CHAIN_CONFIG); + + public static final Set TRUSTSTORE_PROPS = Utils.mkSet( + SslConfigs.SSL_TRUSTSTORE_LOCATION_CONFIG, + SslConfigs.SSL_TRUSTSTORE_TYPE_CONFIG, + SslConfigs.SSL_TRUSTSTORE_PASSWORD_CONFIG, + SslConfigs.SSL_TRUSTSTORE_CERTIFICATES_CONFIG); + + private final Map sslConfig; + + public CertStores(boolean server, String hostName) throws Exception { + this(server, hostName, new TestSslUtils.CertificateBuilder()); + } + + public CertStores(boolean server, String commonName, String sanHostName) throws Exception { + this(server, commonName, new TestSslUtils.CertificateBuilder().sanDnsNames(sanHostName)); + } + + public CertStores(boolean server, String commonName, InetAddress hostAddress) throws Exception { + this(server, commonName, new TestSslUtils.CertificateBuilder().sanIpAddress(hostAddress)); + } + + private CertStores(boolean server, String commonName, TestSslUtils.CertificateBuilder certBuilder) throws Exception { + this(server, commonName, "RSA", certBuilder, false); + } + + private 
CertStores(boolean server, String commonName, String keyAlgorithm, TestSslUtils.CertificateBuilder certBuilder, boolean usePem) throws Exception { + String name = server ? "server" : "client"; + Mode mode = server ? Mode.SERVER : Mode.CLIENT; + File truststoreFile = usePem ? null : File.createTempFile(name + "TS", ".jks"); + sslConfig = new SslConfigsBuilder(mode) + .useClientCert(!server) + .certAlias(name) + .cn(commonName) + .createNewTrustStore(truststoreFile) + .certBuilder(certBuilder) + .algorithm(keyAlgorithm) + .usePem(usePem) + .build(); + } + + + public Map getTrustingConfig(CertStores truststoreConfig) { + Map config = new HashMap<>(sslConfig); + for (String propName : TRUSTSTORE_PROPS) { + config.put(propName, truststoreConfig.sslConfig.get(propName)); + } + return config; + } + + public Map getUntrustingConfig() { + return sslConfig; + } + + public Map keyStoreProps() { + Map props = new HashMap<>(); + for (String propName : KEYSTORE_PROPS) { + props.put(propName, sslConfig.get(propName)); + } + return props; + } + + public Map trustStoreProps() { + Map props = new HashMap<>(); + for (String propName : TRUSTSTORE_PROPS) { + props.put(propName, sslConfig.get(propName)); + } + return props; + } + + public static class Builder { + private final boolean isServer; + private String cn; + private List sanDns; + private InetAddress sanIp; + private String keyAlgorithm; + private boolean usePem; + + public Builder(boolean isServer) { + this.isServer = isServer; + this.sanDns = new ArrayList<>(); + this.keyAlgorithm = "RSA"; + } + + public Builder cn(String cn) { + this.cn = cn; + return this; + } + + public Builder addHostName(String hostname) { + this.sanDns.add(hostname); + return this; + } + + public Builder hostAddress(InetAddress hostAddress) { + this.sanIp = hostAddress; + return this; + } + + public Builder keyAlgorithm(String keyAlgorithm) { + this.keyAlgorithm = keyAlgorithm; + return this; + } + + public Builder usePem(boolean usePem) { + this.usePem = usePem; + return this; + } + + public CertStores build() throws Exception { + TestSslUtils.CertificateBuilder certBuilder = new TestSslUtils.CertificateBuilder() + .sanDnsNames(sanDns.toArray(new String[0])); + if (sanIp != null) + certBuilder = certBuilder.sanIpAddress(sanIp); + return new CertStores(isServer, cn, keyAlgorithm, certBuilder, usePem); + } + } +} \ No newline at end of file diff --git a/clients/src/test/java/org/apache/kafka/common/network/ChannelBuildersTest.java b/clients/src/test/java/org/apache/kafka/common/network/ChannelBuildersTest.java new file mode 100644 index 0000000..f1d367b --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/network/ChannelBuildersTest.java @@ -0,0 +1,120 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.network; + +import org.apache.kafka.common.Configurable; +import org.apache.kafka.common.config.internals.BrokerSecurityConfigs; +import org.apache.kafka.common.security.TestSecurityConfig; +import org.apache.kafka.common.security.auth.AuthenticationContext; +import org.apache.kafka.common.security.auth.KafkaPrincipal; +import org.apache.kafka.common.security.auth.KafkaPrincipalBuilder; +import org.junit.jupiter.api.Test; + +import java.util.HashMap; +import java.util.Map; +import java.util.Properties; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class ChannelBuildersTest { + + @Test + public void testCreateConfigurableKafkaPrincipalBuilder() { + Map configs = new HashMap<>(); + configs.put(BrokerSecurityConfigs.PRINCIPAL_BUILDER_CLASS_CONFIG, ConfigurableKafkaPrincipalBuilder.class); + KafkaPrincipalBuilder builder = ChannelBuilders.createPrincipalBuilder(configs, null, null); + assertTrue(builder instanceof ConfigurableKafkaPrincipalBuilder); + assertTrue(((ConfigurableKafkaPrincipalBuilder) builder).configured); + } + + @Test + public void testChannelBuilderConfigs() { + Properties props = new Properties(); + props.put("listener.name.listener1.gssapi.sasl.kerberos.service.name", "testkafka"); + props.put("listener.name.listener1.sasl.kerberos.service.name", "testkafkaglobal"); + props.put("plain.sasl.server.callback.handler.class", "callback"); + props.put("listener.name.listener1.gssapi.config1.key", "custom.config1"); + props.put("custom.config2.key", "custom.config2"); + TestSecurityConfig securityConfig = new TestSecurityConfig(props); + + // test configs with listener prefix + Map configs = ChannelBuilders.channelBuilderConfigs(securityConfig, new ListenerName("listener1")); + + assertNull(configs.get("listener.name.listener1.gssapi.sasl.kerberos.service.name")); + assertFalse(securityConfig.unused().contains("listener.name.listener1.gssapi.sasl.kerberos.service.name")); + + assertEquals(configs.get("gssapi.sasl.kerberos.service.name"), "testkafka"); + assertFalse(securityConfig.unused().contains("gssapi.sasl.kerberos.service.name")); + + assertEquals(configs.get("sasl.kerberos.service.name"), "testkafkaglobal"); + assertFalse(securityConfig.unused().contains("sasl.kerberos.service.name")); + + assertNull(configs.get("listener.name.listener1.sasl.kerberos.service.name")); + assertFalse(securityConfig.unused().contains("listener.name.listener1.sasl.kerberos.service.name")); + + assertNull(configs.get("plain.sasl.server.callback.handler.class")); + assertFalse(securityConfig.unused().contains("plain.sasl.server.callback.handler.class")); + + assertEquals(configs.get("listener.name.listener1.gssapi.config1.key"), "custom.config1"); + assertFalse(securityConfig.unused().contains("listener.name.listener1.gssapi.config1.key")); + + assertEquals(configs.get("custom.config2.key"), "custom.config2"); + assertFalse(securityConfig.unused().contains("custom.config2.key")); + + // test configs without listener prefix + securityConfig = new TestSecurityConfig(props); + configs = ChannelBuilders.channelBuilderConfigs(securityConfig, null); + + assertEquals(configs.get("listener.name.listener1.gssapi.sasl.kerberos.service.name"), "testkafka"); + assertFalse(securityConfig.unused().contains("listener.name.listener1.gssapi.sasl.kerberos.service.name")); + 
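+ // Without a listener name, the prefixed keys are returned as-is, so the stripped variants below are not populated.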
+ assertNull(configs.get("gssapi.sasl.kerberos.service.name")); + assertFalse(securityConfig.unused().contains("gssapi.sasl.kerberos.service.name")); + + assertEquals(configs.get("listener.name.listener1.sasl.kerberos.service.name"), "testkafkaglobal"); + assertFalse(securityConfig.unused().contains("listener.name.listener1.sasl.kerberos.service.name")); + + assertNull(configs.get("sasl.kerberos.service.name")); + assertFalse(securityConfig.unused().contains("sasl.kerberos.service.name")); + + assertEquals(configs.get("plain.sasl.server.callback.handler.class"), "callback"); + assertFalse(securityConfig.unused().contains("plain.sasl.server.callback.handler.class")); + + assertEquals(configs.get("listener.name.listener1.gssapi.config1.key"), "custom.config1"); + assertFalse(securityConfig.unused().contains("listener.name.listener1.gssapi.config1.key")); + + assertEquals(configs.get("custom.config2.key"), "custom.config2"); + assertFalse(securityConfig.unused().contains("custom.config2.key")); + } + + public static class ConfigurableKafkaPrincipalBuilder implements KafkaPrincipalBuilder, Configurable { + private boolean configured = false; + + @Override + public void configure(Map configs) { + configured = true; + } + + @Override + public KafkaPrincipal build(AuthenticationContext context) { + return null; + } + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/network/EchoServer.java b/clients/src/test/java/org/apache/kafka/common/network/EchoServer.java new file mode 100644 index 0000000..d0cc059 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/network/EchoServer.java @@ -0,0 +1,136 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.network; + +import org.apache.kafka.common.security.auth.SecurityProtocol; +import org.apache.kafka.common.security.ssl.DefaultSslEngineFactory; +import org.apache.kafka.common.security.ssl.SslFactory; + +import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLSocket; +import java.io.DataInputStream; +import java.io.DataOutputStream; +import java.io.IOException; +import java.net.ServerSocket; +import java.net.Socket; +import java.util.ArrayList; +import java.util.Collections; +import java.util.Map; +import java.util.List; +import java.util.concurrent.atomic.AtomicBoolean; + + +/** + * A simple server that takes size delimited byte arrays and just echos them back to the sender. 
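+ * For SSL connections, the server can also be asked to renegotiate the TLS handshake mid-stream via renegotiate().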
+ */ +class EchoServer extends Thread { + public final int port; + private final ServerSocket serverSocket; + private final List threads; + private final List sockets; + private volatile boolean closing = false; + private final SslFactory sslFactory; + private final AtomicBoolean renegotiate = new AtomicBoolean(); + + public EchoServer(SecurityProtocol securityProtocol, Map configs) throws Exception { + switch (securityProtocol) { + case SSL: + this.sslFactory = new SslFactory(Mode.SERVER); + this.sslFactory.configure(configs); + SSLContext sslContext = ((DefaultSslEngineFactory) this.sslFactory.sslEngineFactory()).sslContext(); + this.serverSocket = sslContext.getServerSocketFactory().createServerSocket(0); + break; + case PLAINTEXT: + this.serverSocket = new ServerSocket(0); + this.sslFactory = null; + break; + default: + throw new IllegalArgumentException("Unsupported securityProtocol " + securityProtocol); + } + this.port = this.serverSocket.getLocalPort(); + this.threads = Collections.synchronizedList(new ArrayList()); + this.sockets = Collections.synchronizedList(new ArrayList()); + } + + public void renegotiate() { + renegotiate.set(true); + } + + @Override + public void run() { + try { + while (!closing) { + final Socket socket = serverSocket.accept(); + synchronized (sockets) { + if (closing) { + break; + } + sockets.add(socket); + Thread thread = new Thread() { + @Override + public void run() { + try { + DataInputStream input = new DataInputStream(socket.getInputStream()); + DataOutputStream output = new DataOutputStream(socket.getOutputStream()); + while (socket.isConnected() && !socket.isClosed()) { + int size = input.readInt(); + if (renegotiate.get()) { + renegotiate.set(false); + ((SSLSocket) socket).startHandshake(); + } + byte[] bytes = new byte[size]; + input.readFully(bytes); + output.writeInt(size); + output.write(bytes); + output.flush(); + } + } catch (IOException e) { + // ignore + } finally { + try { + socket.close(); + } catch (IOException e) { + // ignore + } + } + } + }; + thread.start(); + threads.add(thread); + } + } + } catch (IOException e) { + // ignore + } + } + + public void closeConnections() throws IOException { + synchronized (sockets) { + for (Socket socket : sockets) + socket.close(); + } + } + + public void close() throws IOException, InterruptedException { + closing = true; + this.serverSocket.close(); + closeConnections(); + for (Thread t : threads) + t.join(); + join(); + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/network/KafkaChannelTest.java b/clients/src/test/java/org/apache/kafka/common/network/KafkaChannelTest.java new file mode 100644 index 0000000..f83ea7d --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/network/KafkaChannelTest.java @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.network; + +import org.apache.kafka.common.memory.MemoryPool; +import org.apache.kafka.test.TestUtils; +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; +import org.mockito.Mockito; + +import java.io.IOException; +import java.nio.ByteBuffer; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class KafkaChannelTest { + + @Test + public void testSending() throws IOException { + Authenticator authenticator = Mockito.mock(Authenticator.class); + TransportLayer transport = Mockito.mock(TransportLayer.class); + MemoryPool pool = Mockito.mock(MemoryPool.class); + ChannelMetadataRegistry metadataRegistry = Mockito.mock(ChannelMetadataRegistry.class); + + KafkaChannel channel = new KafkaChannel("0", transport, () -> authenticator, + 1024, pool, metadataRegistry); + ByteBufferSend send = ByteBufferSend.sizePrefixed(ByteBuffer.wrap(TestUtils.randomBytes(128))); + NetworkSend networkSend = new NetworkSend("0", send); + + channel.setSend(networkSend); + assertTrue(channel.hasSend()); + assertThrows(IllegalStateException.class, () -> channel.setSend(networkSend)); + + Mockito.when(transport.write(Mockito.any(ByteBuffer[].class))).thenReturn(4L); + assertEquals(4L, channel.write()); + assertEquals(128, send.remaining()); + assertNull(channel.maybeCompleteSend()); + + Mockito.when(transport.write(Mockito.any(ByteBuffer[].class))).thenReturn(64L); + assertEquals(64, channel.write()); + assertEquals(64, send.remaining()); + assertNull(channel.maybeCompleteSend()); + + Mockito.when(transport.write(Mockito.any(ByteBuffer[].class))).thenReturn(64L); + assertEquals(64, channel.write()); + assertEquals(0, send.remaining()); + assertEquals(networkSend, channel.maybeCompleteSend()); + } + + @Test + public void testReceiving() throws IOException { + Authenticator authenticator = Mockito.mock(Authenticator.class); + TransportLayer transport = Mockito.mock(TransportLayer.class); + MemoryPool pool = Mockito.mock(MemoryPool.class); + ChannelMetadataRegistry metadataRegistry = Mockito.mock(ChannelMetadataRegistry.class); + + ArgumentCaptor sizeCaptor = ArgumentCaptor.forClass(Integer.class); + Mockito.when(pool.tryAllocate(sizeCaptor.capture())).thenAnswer(invocation -> { + return ByteBuffer.allocate(sizeCaptor.getValue()); + }); + + KafkaChannel channel = new KafkaChannel("0", transport, () -> authenticator, + 1024, pool, metadataRegistry); + + ArgumentCaptor bufferCaptor = ArgumentCaptor.forClass(ByteBuffer.class); + Mockito.when(transport.read(bufferCaptor.capture())).thenAnswer(invocation -> { + bufferCaptor.getValue().putInt(128); + return 4; + }).thenReturn(0); + assertEquals(4, channel.read()); + assertEquals(4, channel.currentReceive().bytesRead()); + assertNull(channel.maybeCompleteReceive()); + + Mockito.reset(transport); + Mockito.when(transport.read(bufferCaptor.capture())).thenAnswer(invocation -> { + bufferCaptor.getValue().put(TestUtils.randomBytes(64)); + return 64; + }); + assertEquals(64, channel.read()); + assertEquals(68, channel.currentReceive().bytesRead()); + assertNull(channel.maybeCompleteReceive()); + + Mockito.reset(transport); + 
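+ // the last 64 bytes complete the 128-byte payload (4-byte size prefix + 128 bytes = 132 bytes read in total)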
Mockito.when(transport.read(bufferCaptor.capture())).thenAnswer(invocation -> { + bufferCaptor.getValue().put(TestUtils.randomBytes(64)); + return 64; + }); + assertEquals(64, channel.read()); + assertEquals(132, channel.currentReceive().bytesRead()); + assertNotNull(channel.maybeCompleteReceive()); + assertNull(channel.currentReceive()); + } + +} diff --git a/clients/src/test/java/org/apache/kafka/common/network/NetworkReceiveTest.java b/clients/src/test/java/org/apache/kafka/common/network/NetworkReceiveTest.java new file mode 100644 index 0000000..ec18c26 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/network/NetworkReceiveTest.java @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.network; + +import org.apache.kafka.test.TestUtils; +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; +import org.mockito.Mockito; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.channels.ScatteringByteChannel; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class NetworkReceiveTest { + + @Test + public void testBytesRead() throws IOException { + NetworkReceive receive = new NetworkReceive(128, "0"); + assertEquals(0, receive.bytesRead()); + + ScatteringByteChannel channel = Mockito.mock(ScatteringByteChannel.class); + + ArgumentCaptor bufferCaptor = ArgumentCaptor.forClass(ByteBuffer.class); + Mockito.when(channel.read(bufferCaptor.capture())).thenAnswer(invocation -> { + bufferCaptor.getValue().putInt(128); + return 4; + }).thenReturn(0); + + assertEquals(4, receive.readFrom(channel)); + assertEquals(4, receive.bytesRead()); + assertFalse(receive.complete()); + + Mockito.reset(channel); + Mockito.when(channel.read(bufferCaptor.capture())).thenAnswer(invocation -> { + bufferCaptor.getValue().put(TestUtils.randomBytes(64)); + return 64; + }); + + assertEquals(64, receive.readFrom(channel)); + assertEquals(68, receive.bytesRead()); + assertFalse(receive.complete()); + + Mockito.reset(channel); + Mockito.when(channel.read(bufferCaptor.capture())).thenAnswer(invocation -> { + bufferCaptor.getValue().put(TestUtils.randomBytes(64)); + return 64; + }); + + assertEquals(64, receive.readFrom(channel)); + assertEquals(132, receive.bytesRead()); + assertTrue(receive.complete()); + } + +} diff --git a/clients/src/test/java/org/apache/kafka/common/network/NetworkTestUtils.java b/clients/src/test/java/org/apache/kafka/common/network/NetworkTestUtils.java new file mode 100644 index 0000000..4f53845 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/network/NetworkTestUtils.java @@ -0,0 +1,122 @@ +/* 
+ * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.network; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import org.apache.kafka.common.config.AbstractConfig; +import org.apache.kafka.common.metrics.Metrics; +import org.apache.kafka.common.security.auth.SecurityProtocol; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.common.security.authenticator.CredentialCache; +import org.apache.kafka.common.security.token.delegation.internals.DelegationTokenCache; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.test.TestUtils; + +/** + * Common utility functions used by transport layer and authenticator tests. + */ +public class NetworkTestUtils { + public static NioEchoServer createEchoServer(ListenerName listenerName, SecurityProtocol securityProtocol, + AbstractConfig serverConfig, CredentialCache credentialCache, Time time) throws Exception { + return createEchoServer(listenerName, securityProtocol, serverConfig, credentialCache, 100, time); + } + + public static NioEchoServer createEchoServer(ListenerName listenerName, SecurityProtocol securityProtocol, + AbstractConfig serverConfig, CredentialCache credentialCache, + int failedAuthenticationDelayMs, Time time) throws Exception { + NioEchoServer server = new NioEchoServer(listenerName, securityProtocol, serverConfig, "localhost", + null, credentialCache, failedAuthenticationDelayMs, time); + server.start(); + return server; + } + + public static NioEchoServer createEchoServer(ListenerName listenerName, SecurityProtocol securityProtocol, + AbstractConfig serverConfig, CredentialCache credentialCache, + int failedAuthenticationDelayMs, Time time, DelegationTokenCache tokenCache) throws Exception { + NioEchoServer server = new NioEchoServer(listenerName, securityProtocol, serverConfig, "localhost", + null, credentialCache, failedAuthenticationDelayMs, time, tokenCache); + server.start(); + return server; + } + + public static Selector createSelector(ChannelBuilder channelBuilder, Time time) { + return new Selector(5000, new Metrics(), time, "MetricGroup", channelBuilder, new LogContext()); + } + + public static void checkClientConnection(Selector selector, String node, int minMessageSize, int messageCount) throws Exception { + waitForChannelReady(selector, node); + String prefix = TestUtils.randomString(minMessageSize); + int requests = 0; + int responses = 0; + selector.send(new NetworkSend(node, ByteBufferSend.sizePrefixed(ByteBuffer.wrap((prefix + 
"-0").getBytes(StandardCharsets.UTF_8))))); + requests++; + while (responses < messageCount) { + selector.poll(0L); + assertEquals(0, selector.disconnected().size(), "No disconnects should have occurred ." + selector.disconnected()); + + for (NetworkReceive receive : selector.completedReceives()) { + assertEquals(prefix + "-" + responses, new String(Utils.toArray(receive.payload()), StandardCharsets.UTF_8)); + responses++; + } + + for (int i = 0; i < selector.completedSends().size() && requests < messageCount && selector.isChannelReady(node); i++, requests++) { + selector.send(new NetworkSend(node, ByteBufferSend.sizePrefixed(ByteBuffer.wrap((prefix + "-" + requests).getBytes())))); + } + } + } + + public static void waitForChannelReady(Selector selector, String node) throws IOException { + // wait for handshake to finish + int secondsLeft = 30; + while (!selector.isChannelReady(node) && secondsLeft-- > 0) { + selector.poll(1000L); + } + assertTrue(selector.isChannelReady(node)); + } + + public static ChannelState waitForChannelClose(Selector selector, String node, ChannelState.State channelState) + throws IOException { + boolean closed = false; + for (int i = 0; i < 300; i++) { + selector.poll(100L); + if (selector.channel(node) == null && selector.closingChannel(node) == null) { + closed = true; + break; + } + } + assertTrue(closed, "Channel was not closed by timeout"); + ChannelState finalState = selector.disconnected().get(node); + assertEquals(channelState, finalState.state()); + return finalState; + } + + public static void completeDelayedChannelClose(Selector selector, long currentTimeNanos) { + selector.completeDelayedChannelClose(currentTimeNanos); + } + + public static Map delayedClosingChannels(Selector selector) { + return selector.delayedClosingChannels(); + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/network/NioEchoServer.java b/clients/src/test/java/org/apache/kafka/common/network/NioEchoServer.java new file mode 100644 index 0000000..53b46e7 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/network/NioEchoServer.java @@ -0,0 +1,384 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.network; + +import org.apache.kafka.common.MetricName; +import org.apache.kafka.common.config.AbstractConfig; +import org.apache.kafka.common.message.ApiMessageType; +import org.apache.kafka.common.metrics.KafkaMetric; +import org.apache.kafka.common.metrics.Metrics; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.requests.ApiVersionsResponse; +import org.apache.kafka.common.security.auth.SecurityProtocol; +import org.apache.kafka.common.security.authenticator.CredentialCache; +import org.apache.kafka.common.security.scram.ScramCredential; +import org.apache.kafka.common.security.scram.internals.ScramMechanism; +import org.apache.kafka.common.security.token.delegation.internals.DelegationTokenCache; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.test.TestUtils; + +import java.io.IOException; +import java.net.InetSocketAddress; +import java.nio.ByteBuffer; +import java.nio.channels.FileChannel; +import java.nio.channels.SelectionKey; +import java.nio.channels.ServerSocketChannel; +import java.nio.channels.SocketChannel; +import java.nio.channels.WritableByteChannel; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.EnumSet; +import java.util.Iterator; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Set; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +/** + * Non-blocking EchoServer implementation that uses ChannelBuilder to create channels + * with the configured security protocol. + * + */ +public class NioEchoServer extends Thread { + public enum MetricType { + TOTAL, RATE, AVG, MAX; + + private final String metricNameSuffix = "-" + name().toLowerCase(Locale.ROOT); + + public String metricNameSuffix() { + return metricNameSuffix; + } + } + + private static final double EPS = 0.0001; + + private final int port; + private final ServerSocketChannel serverSocketChannel; + private final List newChannels; + private final List socketChannels; + private final AcceptorThread acceptorThread; + private final Selector selector; + private volatile TransferableChannel outputChannel; + private final CredentialCache credentialCache; + private final Metrics metrics; + private volatile int numSent = 0; + private volatile boolean closeKafkaChannels; + private final DelegationTokenCache tokenCache; + private final Time time; + + public NioEchoServer(ListenerName listenerName, SecurityProtocol securityProtocol, AbstractConfig config, + String serverHost, ChannelBuilder channelBuilder, CredentialCache credentialCache, Time time) throws Exception { + this(listenerName, securityProtocol, config, serverHost, channelBuilder, credentialCache, 100, time); + } + + public NioEchoServer(ListenerName listenerName, SecurityProtocol securityProtocol, AbstractConfig config, + String serverHost, ChannelBuilder channelBuilder, CredentialCache credentialCache, + int failedAuthenticationDelayMs, Time time) throws Exception { + this(listenerName, securityProtocol, config, serverHost, channelBuilder, credentialCache, failedAuthenticationDelayMs, time, + new DelegationTokenCache(ScramMechanism.mechanismNames())); + } + + public NioEchoServer(ListenerName listenerName, SecurityProtocol securityProtocol, AbstractConfig config, + String serverHost, ChannelBuilder channelBuilder, CredentialCache credentialCache, + int 
failedAuthenticationDelayMs, Time time, DelegationTokenCache tokenCache) throws Exception { + super("echoserver"); + setDaemon(true); + serverSocketChannel = ServerSocketChannel.open(); + serverSocketChannel.configureBlocking(false); + serverSocketChannel.socket().bind(new InetSocketAddress(serverHost, 0)); + this.port = serverSocketChannel.socket().getLocalPort(); + this.socketChannels = Collections.synchronizedList(new ArrayList()); + this.newChannels = Collections.synchronizedList(new ArrayList()); + this.credentialCache = credentialCache; + this.tokenCache = tokenCache; + if (securityProtocol == SecurityProtocol.SASL_PLAINTEXT || securityProtocol == SecurityProtocol.SASL_SSL) { + for (String mechanism : ScramMechanism.mechanismNames()) { + if (credentialCache.cache(mechanism, ScramCredential.class) == null) + credentialCache.createCache(mechanism, ScramCredential.class); + } + } + LogContext logContext = new LogContext(); + if (channelBuilder == null) + channelBuilder = ChannelBuilders.serverChannelBuilder(listenerName, false, + securityProtocol, config, credentialCache, tokenCache, time, logContext, + () -> ApiVersionsResponse.defaultApiVersionsResponse(ApiMessageType.ListenerType.ZK_BROKER)); + this.metrics = new Metrics(); + this.selector = new Selector(10000, failedAuthenticationDelayMs, metrics, time, + "MetricGroup", channelBuilder, logContext); + acceptorThread = new AcceptorThread(); + this.time = time; + } + + public int port() { + return port; + } + + public CredentialCache credentialCache() { + return credentialCache; + } + + public DelegationTokenCache tokenCache() { + return tokenCache; + } + + public double metricValue(String name) { + for (Map.Entry entry : metrics.metrics().entrySet()) { + if (entry.getKey().name().equals(name)) + return (double) entry.getValue().metricValue(); + } + throw new IllegalStateException("Metric not found, " + name + ", found=" + metrics.metrics().keySet()); + } + + public void verifyAuthenticationMetrics(int successfulAuthentications, final int failedAuthentications) + throws InterruptedException { + waitForMetrics("successful-authentication", successfulAuthentications, + EnumSet.of(MetricType.TOTAL, MetricType.RATE)); + waitForMetrics("failed-authentication", failedAuthentications, EnumSet.of(MetricType.TOTAL, MetricType.RATE)); + } + + public void verifyReauthenticationMetrics(int successfulReauthentications, final int failedReauthentications) + throws InterruptedException { + waitForMetrics("successful-reauthentication", successfulReauthentications, + EnumSet.of(MetricType.TOTAL, MetricType.RATE)); + waitForMetrics("failed-reauthentication", failedReauthentications, + EnumSet.of(MetricType.TOTAL, MetricType.RATE)); + waitForMetrics("successful-authentication-no-reauth", 0, EnumSet.of(MetricType.TOTAL)); + if (!(time instanceof MockTime)) { + waitForMetrics("reauthentication-latency", Math.signum(successfulReauthentications), + EnumSet.of(MetricType.MAX, MetricType.AVG)); + } + } + + public void verifyAuthenticationNoReauthMetric(int successfulAuthenticationNoReauths) throws InterruptedException { + waitForMetrics("successful-authentication-no-reauth", successfulAuthenticationNoReauths, + EnumSet.of(MetricType.TOTAL)); + } + + public void waitForMetric(String name, final double expectedValue) throws InterruptedException { + waitForMetrics(name, expectedValue, EnumSet.of(MetricType.TOTAL, MetricType.RATE)); + } + + public void waitForMetrics(String namePrefix, final double expectedValue, Set metricTypes) + throws InterruptedException { + 
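+ // the wait budget (15 seconds) is shared across all requested metric types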
long maxAggregateWaitMs = 15000; + long startMs = time.milliseconds(); + for (MetricType metricType : metricTypes) { + long currentElapsedMs = time.milliseconds() - startMs; + long thisMaxWaitMs = maxAggregateWaitMs - currentElapsedMs; + String metricName = namePrefix + metricType.metricNameSuffix(); + if (expectedValue == 0.0) { + Double expected = expectedValue; + if (metricType == MetricType.MAX || metricType == MetricType.AVG) + expected = Double.NaN; + + assertEquals(expected, metricValue(metricName), EPS, "Metric not updated " + metricName + + " expected:<" + expectedValue + "> but was:<" + metricValue(metricName) + ">"); + } else if (metricType == MetricType.TOTAL) + TestUtils.waitForCondition(() -> Math.abs(metricValue(metricName) - expectedValue) <= EPS, + thisMaxWaitMs, () -> "Metric not updated " + metricName + " expected:<" + expectedValue + + "> but was:<" + metricValue(metricName) + ">"); + else + TestUtils.waitForCondition(() -> metricValue(metricName) > 0.0, thisMaxWaitMs, + () -> "Metric not updated " + metricName + " expected: but was:<" + + metricValue(metricName) + ">"); + } + } + + @Override + public void run() { + try { + acceptorThread.start(); + while (serverSocketChannel.isOpen()) { + selector.poll(100); + synchronized (newChannels) { + for (SocketChannel socketChannel : newChannels) { + String id = id(socketChannel); + selector.register(id, socketChannel); + socketChannels.add(socketChannel); + } + newChannels.clear(); + } + if (closeKafkaChannels) { + for (KafkaChannel channel : selector.channels()) + selector.close(channel.id()); + } + + Collection completedReceives = selector.completedReceives(); + for (NetworkReceive rcv : completedReceives) { + KafkaChannel channel = channel(rcv.source()); + if (!maybeBeginServerReauthentication(channel, rcv, time)) { + String channelId = channel.id(); + selector.mute(channelId); + NetworkSend send = new NetworkSend(rcv.source(), ByteBufferSend.sizePrefixed(rcv.payload())); + if (outputChannel == null) + selector.send(send); + else { + send.writeTo(outputChannel); + selector.unmute(channelId); + } + } + } + for (NetworkSend send : selector.completedSends()) { + selector.unmute(send.destinationId()); + numSent += 1; + } + } + } catch (IOException e) { + // ignore + } + } + + public int numSent() { + return numSent; + } + + private static boolean maybeBeginServerReauthentication(KafkaChannel channel, NetworkReceive networkReceive, Time time) { + try { + if (TestUtils.apiKeyFrom(networkReceive) == ApiKeys.SASL_HANDSHAKE) { + return channel.maybeBeginServerReauthentication(networkReceive, time::nanoseconds); + } + } catch (Exception e) { + // ignore + } + return false; + } + + private String id(SocketChannel channel) { + return channel.socket().getLocalAddress().getHostAddress() + ":" + channel.socket().getLocalPort() + "-" + + channel.socket().getInetAddress().getHostAddress() + ":" + channel.socket().getPort(); + } + + private KafkaChannel channel(String id) { + KafkaChannel channel = selector.channel(id); + return channel == null ? selector.closingChannel(id) : channel; + } + + /** + * Sets the output channel to which messages received on this server are echoed. + * This is useful in tests where the clients sending the messages don't receive + * the responses (eg. testing graceful close). 
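+ * While an output channel is set, received payloads are written to it instead of being sent back to the client.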
+ */ + public void outputChannel(WritableByteChannel channel) { + this.outputChannel = new TransferableChannel() { + + @Override + public boolean hasPendingWrites() { + return false; + } + + @Override + public long transferFrom(FileChannel fileChannel, long position, long count) throws IOException { + return fileChannel.transferTo(position, count, channel); + } + + @Override + public boolean isOpen() { + return channel.isOpen(); + } + + @Override + public void close() throws IOException { + channel.close(); + } + + @Override + public int write(ByteBuffer src) throws IOException { + return channel.write(src); + } + + @Override + public long write(ByteBuffer[] srcs, int offset, int length) throws IOException { + long result = 0; + for (int i = offset; i < offset + length; ++i) + result += write(srcs[i]); + return result; + } + + @Override + public long write(ByteBuffer[] srcs) throws IOException { + return write(srcs, 0, srcs.length); + } + }; + } + + public Selector selector() { + return selector; + } + + public void closeKafkaChannels() { + closeKafkaChannels = true; + selector.wakeup(); + try { + TestUtils.waitForCondition(() -> selector.channels().isEmpty(), "Channels not closed"); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } finally { + closeKafkaChannels = false; + } + } + + public void closeSocketChannels() throws IOException { + for (SocketChannel channel : socketChannels) { + channel.close(); + } + socketChannels.clear(); + } + + public void close() throws IOException, InterruptedException { + this.serverSocketChannel.close(); + closeSocketChannels(); + acceptorThread.interrupt(); + acceptorThread.join(); + interrupt(); + join(); + } + + private class AcceptorThread extends Thread { + public AcceptorThread() { + setName("acceptor"); + } + @Override + public void run() { + try { + java.nio.channels.Selector acceptSelector = java.nio.channels.Selector.open(); + serverSocketChannel.register(acceptSelector, SelectionKey.OP_ACCEPT); + while (serverSocketChannel.isOpen()) { + if (acceptSelector.select(1000) > 0) { + Iterator it = acceptSelector.selectedKeys().iterator(); + while (it.hasNext()) { + SelectionKey key = it.next(); + if (key.isAcceptable()) { + SocketChannel socketChannel = ((ServerSocketChannel) key.channel()).accept(); + socketChannel.configureBlocking(false); + newChannels.add(socketChannel); + selector.wakeup(); + } + it.remove(); + } + } + } + } catch (IOException e) { + // ignore + } + } + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/network/PlaintextSender.java b/clients/src/test/java/org/apache/kafka/common/network/PlaintextSender.java new file mode 100644 index 0000000..3338d03 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/network/PlaintextSender.java @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.network; + +import java.io.OutputStream; +import java.net.InetSocketAddress; +import java.net.Socket; + +/** + * test helper class that will connect to a given server address, write out the given payload and disconnect + */ +public class PlaintextSender extends Thread { + + public PlaintextSender(final InetSocketAddress serverAddress, final byte[] payload) { + super(new Runnable() { + @Override + public void run() { + try (Socket connection = new Socket(serverAddress.getAddress(), serverAddress.getPort()); + OutputStream os = connection.getOutputStream()) { + os.write(payload); + os.flush(); + } catch (Exception e) { + e.printStackTrace(System.err); + } + } + }); + setDaemon(true); + setName("PlaintextSender - " + payload.length + " bytes @ " + serverAddress); + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/network/SaslChannelBuilderTest.java b/clients/src/test/java/org/apache/kafka/common/network/SaslChannelBuilderTest.java new file mode 100644 index 0000000..1697c62 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/network/SaslChannelBuilderTest.java @@ -0,0 +1,239 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.network; + +import org.apache.kafka.common.KafkaException; +import org.apache.kafka.common.config.ConfigDef; +import org.apache.kafka.common.config.SaslConfigs; +import org.apache.kafka.common.config.SslConfigs; +import org.apache.kafka.common.config.internals.BrokerSecurityConfigs; +import org.apache.kafka.common.message.ApiMessageType; +import org.apache.kafka.common.requests.ApiVersionsResponse; +import org.apache.kafka.common.security.TestSecurityConfig; +import org.apache.kafka.common.security.auth.KafkaPrincipal; +import org.apache.kafka.common.security.auth.SecurityProtocol; +import org.apache.kafka.common.security.JaasContext; +import org.apache.kafka.common.security.authenticator.TestJaasConfig; +import org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginModule; +import org.apache.kafka.common.security.plain.PlainLoginModule; +import org.apache.kafka.common.security.scram.ScramLoginModule; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.common.utils.Time; +import org.ietf.jgss.GSSContext; +import org.ietf.jgss.GSSCredential; +import org.ietf.jgss.GSSManager; +import org.ietf.jgss.GSSName; +import org.ietf.jgss.Oid; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Test; +import org.mockito.Mockito; + +import javax.security.auth.Subject; +import javax.security.auth.callback.CallbackHandler; +import javax.security.auth.login.LoginException; +import javax.security.auth.spi.LoginModule; +import java.lang.reflect.Field; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.function.Supplier; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertSame; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; + + +public class SaslChannelBuilderTest { + + @AfterEach + public void tearDown() { + System.clearProperty(SaslChannelBuilder.GSS_NATIVE_PROP); + } + + @Test + public void testCloseBeforeConfigureIsIdempotent() { + SaslChannelBuilder builder = createChannelBuilder(SecurityProtocol.SASL_PLAINTEXT, "PLAIN"); + builder.close(); + assertTrue(builder.loginManagers().isEmpty()); + builder.close(); + assertTrue(builder.loginManagers().isEmpty()); + } + + @Test + public void testCloseAfterConfigIsIdempotent() { + SaslChannelBuilder builder = createChannelBuilder(SecurityProtocol.SASL_PLAINTEXT, "PLAIN"); + builder.configure(new HashMap<>()); + assertNotNull(builder.loginManagers().get("PLAIN")); + builder.close(); + assertTrue(builder.loginManagers().isEmpty()); + builder.close(); + assertTrue(builder.loginManagers().isEmpty()); + } + + @Test + public void testLoginManagerReleasedIfConfigureThrowsException() { + SaslChannelBuilder builder = createChannelBuilder(SecurityProtocol.SASL_SSL, "PLAIN"); + try { + // Use invalid config so that an exception is thrown + builder.configure(Collections.singletonMap(SslConfigs.SSL_ENABLED_PROTOCOLS_CONFIG, "1")); + fail("Exception should have been thrown"); + } catch (KafkaException e) { + assertTrue(builder.loginManagers().isEmpty()); + } + builder.close(); + assertTrue(builder.loginManagers().isEmpty()); + } + + @Test + public void testNativeGssapiCredentials() throws Exception { + System.setProperty(SaslChannelBuilder.GSS_NATIVE_PROP, "true"); + + TestJaasConfig jaasConfig = new TestJaasConfig(); + jaasConfig.addEntry("jaasContext", 
TestGssapiLoginModule.class.getName(), new HashMap<>()); + JaasContext jaasContext = new JaasContext("jaasContext", JaasContext.Type.SERVER, jaasConfig, null); + Map jaasContexts = Collections.singletonMap("GSSAPI", jaasContext); + GSSManager gssManager = Mockito.mock(GSSManager.class); + GSSName gssName = Mockito.mock(GSSName.class); + Mockito.when(gssManager.createName(Mockito.anyString(), Mockito.any())) + .thenAnswer(unused -> gssName); + Oid oid = new Oid("1.2.840.113554.1.2.2"); + Mockito.when(gssManager.createCredential(gssName, GSSContext.INDEFINITE_LIFETIME, oid, GSSCredential.ACCEPT_ONLY)) + .thenAnswer(unused -> Mockito.mock(GSSCredential.class)); + + SaslChannelBuilder channelBuilder1 = createGssapiChannelBuilder(jaasContexts, gssManager); + assertEquals(1, channelBuilder1.subject("GSSAPI").getPrincipals().size()); + assertEquals(1, channelBuilder1.subject("GSSAPI").getPrivateCredentials().size()); + + SaslChannelBuilder channelBuilder2 = createGssapiChannelBuilder(jaasContexts, gssManager); + assertEquals(1, channelBuilder2.subject("GSSAPI").getPrincipals().size()); + assertEquals(1, channelBuilder2.subject("GSSAPI").getPrivateCredentials().size()); + assertSame(channelBuilder1.subject("GSSAPI"), channelBuilder2.subject("GSSAPI")); + + Mockito.verify(gssManager, Mockito.times(1)) + .createCredential(gssName, GSSContext.INDEFINITE_LIFETIME, oid, GSSCredential.ACCEPT_ONLY); + } + + /** + * Verify that unparsed broker configs don't break clients. This is to ensure that clients + * created by brokers are not broken if broker configs are passed to clients. + */ + @Test + public void testClientChannelBuilderWithBrokerConfigs() throws Exception { + Map configs = new HashMap<>(); + CertStores certStores = new CertStores(false, "client", "localhost"); + configs.putAll(certStores.getTrustingConfig(certStores)); + configs.put(SaslConfigs.SASL_KERBEROS_SERVICE_NAME, "kafka"); + configs.putAll(new ConfigDef().withClientSaslSupport().parse(configs)); + for (Field field : BrokerSecurityConfigs.class.getFields()) { + if (field.getName().endsWith("_CONFIG")) + configs.put(field.get(BrokerSecurityConfigs.class).toString(), "somevalue"); + } + + SaslChannelBuilder plainBuilder = createChannelBuilder(SecurityProtocol.SASL_PLAINTEXT, "PLAIN"); + plainBuilder.configure(configs); + + SaslChannelBuilder gssapiBuilder = createChannelBuilder(SecurityProtocol.SASL_PLAINTEXT, "GSSAPI"); + gssapiBuilder.configure(configs); + + SaslChannelBuilder oauthBearerBuilder = createChannelBuilder(SecurityProtocol.SASL_PLAINTEXT, "OAUTHBEARER"); + oauthBearerBuilder.configure(configs); + + SaslChannelBuilder scramBuilder = createChannelBuilder(SecurityProtocol.SASL_PLAINTEXT, "SCRAM-SHA-256"); + scramBuilder.configure(configs); + + SaslChannelBuilder saslSslBuilder = createChannelBuilder(SecurityProtocol.SASL_SSL, "PLAIN"); + saslSslBuilder.configure(configs); + } + + private SaslChannelBuilder createGssapiChannelBuilder(Map jaasContexts, GSSManager gssManager) { + SaslChannelBuilder channelBuilder = new SaslChannelBuilder(Mode.SERVER, jaasContexts, + SecurityProtocol.SASL_PLAINTEXT, new ListenerName("GSSAPI"), false, "GSSAPI", + true, null, null, null, Time.SYSTEM, new LogContext(), defaultApiVersionsSupplier()) { + + @Override + protected GSSManager gssManager() { + return gssManager; + } + }; + Map props = Collections.singletonMap(SaslConfigs.SASL_KERBEROS_SERVICE_NAME, "kafka"); + channelBuilder.configure(new TestSecurityConfig(props).values()); + return channelBuilder; + } + + private Supplier 
defaultApiVersionsSupplier() { + return () -> ApiVersionsResponse.defaultApiVersionsResponse(ApiMessageType.ListenerType.ZK_BROKER); + } + + private SaslChannelBuilder createChannelBuilder(SecurityProtocol securityProtocol, String saslMechanism) { + Class loginModule = null; + switch (saslMechanism) { + case "PLAIN": + loginModule = PlainLoginModule.class; + break; + case "SCRAM-SHA-256": + loginModule = ScramLoginModule.class; + break; + case "OAUTHBEARER": + loginModule = OAuthBearerLoginModule.class; + break; + case "GSSAPI": + loginModule = TestGssapiLoginModule.class; + break; + default: + throw new IllegalArgumentException("Unsupported SASL mechanism " + saslMechanism); + } + TestJaasConfig jaasConfig = new TestJaasConfig(); + jaasConfig.addEntry("jaasContext", loginModule.getName(), new HashMap<>()); + JaasContext jaasContext = new JaasContext("jaasContext", JaasContext.Type.SERVER, jaasConfig, null); + Map jaasContexts = Collections.singletonMap(saslMechanism, jaasContext); + return new SaslChannelBuilder(Mode.CLIENT, jaasContexts, securityProtocol, new ListenerName(saslMechanism), + false, saslMechanism, true, null, + null, null, Time.SYSTEM, new LogContext(), defaultApiVersionsSupplier()); + } + + public static final class TestGssapiLoginModule implements LoginModule { + private Subject subject; + + @Override + public void initialize(Subject subject, CallbackHandler callbackHandler, Map sharedState, Map options) { + this.subject = subject; + } + + @Override + public boolean login() throws LoginException { + subject.getPrincipals().add(new KafkaPrincipal("User", "kafka@kafka1.example.com")); + return true; + } + + @Override + public boolean commit() throws LoginException { + return true; + } + + @Override + public boolean abort() throws LoginException { + return true; + } + + @Override + public boolean logout() throws LoginException { + return true; + } + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/network/SelectorTest.java b/clients/src/test/java/org/apache/kafka/common/network/SelectorTest.java new file mode 100644 index 0000000..f276cd4 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/network/SelectorTest.java @@ -0,0 +1,1173 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.network; + +import org.apache.kafka.common.KafkaException; +import org.apache.kafka.common.MetricName; +import org.apache.kafka.common.memory.MemoryPool; +import org.apache.kafka.common.memory.SimpleMemoryPool; +import org.apache.kafka.common.metrics.KafkaMetric; +import org.apache.kafka.common.metrics.Metrics; +import org.apache.kafka.common.security.auth.SecurityProtocol; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.test.TestCondition; +import org.apache.kafka.test.TestUtils; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; + +import java.io.ByteArrayOutputStream; +import java.io.DataOutputStream; +import java.io.IOException; +import java.lang.reflect.Field; +import java.net.InetSocketAddress; +import java.net.ServerSocket; +import java.nio.ByteBuffer; +import java.nio.channels.SelectionKey; +import java.nio.channels.ServerSocketChannel; +import java.nio.channels.SocketChannel; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Random; +import java.util.Set; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Supplier; +import java.util.stream.Collectors; + +import static java.util.Arrays.asList; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; +import static org.mockito.Mockito.atLeastOnce; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + + +/** + * A set of tests for the selector. These use a test harness that runs a simple socket server that echos back responses. 
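+ * A fresh EchoServer, Selector and Metrics instance is created for every test and verified and closed in tearDown.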
+ */ +@Timeout(240) +public class SelectorTest { + protected static final int BUFFER_SIZE = 4 * 1024; + private static final String METRIC_GROUP = "MetricGroup"; + + protected EchoServer server; + protected Time time; + protected Selector selector; + protected ChannelBuilder channelBuilder; + protected Metrics metrics; + + @BeforeEach + public void setUp() throws Exception { + Map configs = new HashMap<>(); + this.server = new EchoServer(SecurityProtocol.PLAINTEXT, configs); + this.server.start(); + this.time = new MockTime(); + this.channelBuilder = new PlaintextChannelBuilder(ListenerName.forSecurityProtocol(SecurityProtocol.PLAINTEXT)); + this.channelBuilder.configure(clientConfigs()); + this.metrics = new Metrics(); + this.selector = new Selector(5000, this.metrics, time, METRIC_GROUP, channelBuilder, new LogContext()); + } + + @AfterEach + public void tearDown() throws Exception { + try { + verifySelectorEmpty(); + } finally { + this.selector.close(); + this.server.close(); + this.metrics.close(); + } + } + + public SecurityProtocol securityProtocol() { + return SecurityProtocol.PLAINTEXT; + } + + protected Map clientConfigs() { + return new HashMap<>(); + } + + /** + * Validate that when the server disconnects, a client send ends up with that node in the disconnected list. + */ + @Test + public void testServerDisconnect() throws Exception { + final String node = "0"; + + // connect and do a simple request + blockingConnect(node); + assertEquals("hello", blockingRequest(node, "hello")); + + KafkaChannel channel = selector.channel(node); + + // disconnect + this.server.closeConnections(); + TestUtils.waitForCondition(new TestCondition() { + @Override + public boolean conditionMet() { + try { + selector.poll(1000L); + return selector.disconnected().containsKey(node); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + }, 5000, "Failed to observe disconnected node in disconnected set"); + + assertNull(channel.selectionKey().attachment()); + + // reconnect and do another request + blockingConnect(node); + assertEquals("hello", blockingRequest(node, "hello")); + } + + /** + * Sending a request with one already in flight should result in an exception + */ + @Test + public void testCantSendWithInProgress() throws Exception { + String node = "0"; + blockingConnect(node); + selector.send(createSend(node, "test1")); + try { + selector.send(createSend(node, "test2")); + fail("IllegalStateException not thrown when sending a request with one in flight"); + } catch (IllegalStateException e) { + // Expected exception + } + selector.poll(0); + assertTrue(selector.disconnected().containsKey(node), "Channel not closed"); + assertEquals(ChannelState.FAILED_SEND, selector.disconnected().get(node)); + } + + /** + * Sending a request to a node without an existing connection should result in an exception + */ + @Test + public void testSendWithoutConnecting() { + assertThrows(IllegalStateException.class, () -> selector.send(createSend("0", "test"))); + } + + /** + * Sending a request to a node with a bad hostname should result in an exception during connect + */ + @Test + public void testNoRouteToHost() { + assertThrows(IOException.class, + () -> selector.connect("0", new InetSocketAddress("some.invalid.hostname.foo.bar.local", server.port), BUFFER_SIZE, BUFFER_SIZE)); + } + + /** + * Sending a request to a node not listening on that port should result in disconnection + */ + @Test + public void testConnectionRefused() throws Exception { + String node = "0"; + ServerSocket 
nonListeningSocket = new ServerSocket(0); + int nonListeningPort = nonListeningSocket.getLocalPort(); + selector.connect(node, new InetSocketAddress("localhost", nonListeningPort), BUFFER_SIZE, BUFFER_SIZE); + while (selector.disconnected().containsKey(node)) { + assertEquals(ChannelState.NOT_CONNECTED, selector.disconnected().get(node)); + selector.poll(1000L); + } + nonListeningSocket.close(); + } + + /** + * Send multiple requests to several connections in parallel. Validate that responses are received in the order that + * requests were sent. + */ + @Test + public void testNormalOperation() throws Exception { + int conns = 5; + int reqs = 500; + + // create connections + InetSocketAddress addr = new InetSocketAddress("localhost", server.port); + for (int i = 0; i < conns; i++) + connect(Integer.toString(i), addr); + // send echo requests and receive responses + Map requests = new HashMap<>(); + Map responses = new HashMap<>(); + int responseCount = 0; + for (int i = 0; i < conns; i++) { + String node = Integer.toString(i); + selector.send(createSend(node, node + "-0")); + } + + // loop until we complete all requests + while (responseCount < conns * reqs) { + // do the i/o + selector.poll(0L); + + assertEquals(0, selector.disconnected().size(), "No disconnects should have occurred."); + + // handle any responses we may have gotten + for (NetworkReceive receive : selector.completedReceives()) { + String[] pieces = asString(receive).split("-"); + assertEquals(2, pieces.length, "Should be in the form 'conn-counter'"); + assertEquals(receive.source(), pieces[0], "Check the source"); + assertEquals(0, receive.payload().position(), "Check that the receive has kindly been rewound"); + if (responses.containsKey(receive.source())) { + assertEquals((int) responses.get(receive.source()), Integer.parseInt(pieces[1]), "Check the request counter"); + responses.put(receive.source(), responses.get(receive.source()) + 1); + } else { + assertEquals(0, Integer.parseInt(pieces[1]), "Check the request counter"); + responses.put(receive.source(), 1); + } + responseCount++; + } + + // prepare new sends for the next round + for (NetworkSend send : selector.completedSends()) { + String dest = send.destinationId(); + if (requests.containsKey(dest)) + requests.put(dest, requests.get(dest) + 1); + else + requests.put(dest, 1); + if (requests.get(dest) < reqs) + selector.send(createSend(dest, dest + "-" + requests.get(dest))); + } + } + if (channelBuilder instanceof PlaintextChannelBuilder) { + assertEquals(0, cipherMetrics(metrics).size()); + } else { + TestUtils.waitForCondition(() -> cipherMetrics(metrics).size() == 1, + "Waiting for cipher metrics to be created."); + assertEquals(Integer.valueOf(5), cipherMetrics(metrics).get(0).metricValue()); + } + } + + static List cipherMetrics(Metrics metrics) { + return metrics.metrics().entrySet().stream(). + filter(e -> e.getKey().description(). + contains("The number of connections with this SSL cipher and protocol.")). + map(e -> e.getValue()). 
+ collect(Collectors.toList()); + } + + /** + * Validate that we can send and receive a message larger than the receive and send buffer size + */ + @Test + public void testSendLargeRequest() throws Exception { + String node = "0"; + blockingConnect(node); + String big = TestUtils.randomString(10 * BUFFER_SIZE); + assertEquals(big, blockingRequest(node, big)); + } + + @Test + public void testPartialSendAndReceiveReflectedInMetrics() throws Exception { + // We use a large payload to attempt to trigger the partial send and receive logic. + int payloadSize = 20 * BUFFER_SIZE; + String payload = TestUtils.randomString(payloadSize); + String nodeId = "0"; + blockingConnect(nodeId); + ByteBufferSend send = ByteBufferSend.sizePrefixed(ByteBuffer.wrap(payload.getBytes())); + NetworkSend networkSend = new NetworkSend(nodeId, send); + + selector.send(networkSend); + KafkaChannel channel = selector.channel(nodeId); + + KafkaMetric outgoingByteTotal = findUntaggedMetricByName("outgoing-byte-total"); + KafkaMetric incomingByteTotal = findUntaggedMetricByName("incoming-byte-total"); + + TestUtils.waitForCondition(() -> { + long bytesSent = send.size() - send.remaining(); + assertEquals(bytesSent, ((Double) outgoingByteTotal.metricValue()).longValue()); + + NetworkReceive currentReceive = channel.currentReceive(); + if (currentReceive != null) { + assertEquals(currentReceive.bytesRead(), ((Double) incomingByteTotal.metricValue()).intValue()); + } + + selector.poll(50); + return !selector.completedReceives().isEmpty(); + }, "Failed to receive expected response"); + + KafkaMetric requestTotal = findUntaggedMetricByName("request-total"); + assertEquals(1, ((Double) requestTotal.metricValue()).intValue()); + + KafkaMetric responseTotal = findUntaggedMetricByName("response-total"); + assertEquals(1, ((Double) responseTotal.metricValue()).intValue()); + } + + @Test + public void testLargeMessageSequence() throws Exception { + int bufferSize = 512 * 1024; + String node = "0"; + int reqs = 50; + InetSocketAddress addr = new InetSocketAddress("localhost", server.port); + connect(node, addr); + String requestPrefix = TestUtils.randomString(bufferSize); + sendAndReceive(node, requestPrefix, 0, reqs); + } + + @Test + public void testEmptyRequest() throws Exception { + String node = "0"; + blockingConnect(node); + assertEquals("", blockingRequest(node, "")); + } + + @Test + public void testClearCompletedSendsAndReceives() throws Exception { + int bufferSize = 1024; + String node = "0"; + InetSocketAddress addr = new InetSocketAddress("localhost", server.port); + connect(node, addr); + String request = TestUtils.randomString(bufferSize); + selector.send(createSend(node, request)); + boolean sent = false; + boolean received = false; + while (!sent || !received) { + selector.poll(1000L); + assertEquals(0, selector.disconnected().size(), "No disconnects should have occurred."); + if (!selector.completedSends().isEmpty()) { + assertEquals(1, selector.completedSends().size()); + selector.clearCompletedSends(); + assertEquals(0, selector.completedSends().size()); + sent = true; + } + + if (!selector.completedReceives().isEmpty()) { + assertEquals(1, selector.completedReceives().size()); + assertEquals(request, asString(selector.completedReceives().iterator().next())); + selector.clearCompletedReceives(); + assertEquals(0, selector.completedReceives().size()); + received = true; + } + } + } + + @Test + public void testExistingConnectionId() throws IOException { + blockingConnect("0"); + 
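+ // connecting again with the same id should fail with IllegalStateException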
assertThrows(IllegalStateException.class, () -> blockingConnect("0")); + } + + @Test + public void testMute() throws Exception { + blockingConnect("0"); + blockingConnect("1"); + + selector.send(createSend("0", "hello")); + selector.send(createSend("1", "hi")); + + selector.mute("1"); + + while (selector.completedReceives().isEmpty()) + selector.poll(5); + assertEquals(1, selector.completedReceives().size(), "We should have only one response"); + assertEquals("0", selector.completedReceives().iterator().next().source(), + "The response should not be from the muted node"); + + selector.unmute("1"); + do { + selector.poll(5); + } while (selector.completedReceives().isEmpty()); + assertEquals(1, selector.completedReceives().size(), "We should have only one response"); + assertEquals("1", selector.completedReceives().iterator().next().source(), "The response should be from the previously muted node"); + } + + @Test + public void testCloseAllChannels() throws Exception { + AtomicInteger closedChannelsCount = new AtomicInteger(0); + ChannelBuilder channelBuilder = new PlaintextChannelBuilder(null) { + private int channelIndex = 0; + @Override + KafkaChannel buildChannel(String id, TransportLayer transportLayer, Supplier authenticatorCreator, + int maxReceiveSize, MemoryPool memoryPool, ChannelMetadataRegistry metadataRegistry) { + return new KafkaChannel(id, transportLayer, authenticatorCreator, maxReceiveSize, memoryPool, metadataRegistry) { + private final int index = channelIndex++; + @Override + public void close() throws IOException { + closedChannelsCount.getAndIncrement(); + if (index == 0) throw new RuntimeException("you should fail"); + else super.close(); + } + }; + } + }; + channelBuilder.configure(clientConfigs()); + Selector selector = new Selector(5000, new Metrics(), new MockTime(), "MetricGroup", channelBuilder, new LogContext()); + selector.connect("0", new InetSocketAddress("localhost", server.port), BUFFER_SIZE, BUFFER_SIZE); + selector.connect("1", new InetSocketAddress("localhost", server.port), BUFFER_SIZE, BUFFER_SIZE); + assertThrows(RuntimeException.class, selector::close); + assertEquals(2, closedChannelsCount.get()); + } + + @Test + public void registerFailure() throws Exception { + ChannelBuilder channelBuilder = new PlaintextChannelBuilder(null) { + @Override + public KafkaChannel buildChannel(String id, SelectionKey key, int maxReceiveSize, + MemoryPool memoryPool, ChannelMetadataRegistry metadataRegistry) throws KafkaException { + throw new RuntimeException("Test exception"); + } + @Override + public void close() { + } + }; + Selector selector = new Selector(5000, new Metrics(), new MockTime(), "MetricGroup", channelBuilder, new LogContext()); + SocketChannel socketChannel = SocketChannel.open(); + socketChannel.configureBlocking(false); + IOException e = assertThrows(IOException.class, () -> selector.register("1", socketChannel)); + assertTrue(e.getCause().getMessage().contains("Test exception"), "Unexpected exception: " + e); + assertFalse(socketChannel.isOpen(), "Socket not closed"); + selector.close(); + } + + @Test + public void testCloseOldestConnection() throws Exception { + String id = "0"; + blockingConnect(id); + + time.sleep(6000); // The max idle time is 5000ms + selector.poll(0); + + assertTrue(selector.disconnected().containsKey(id), "The idle connection should have been closed"); + assertEquals(ChannelState.EXPIRED, selector.disconnected().get(id)); + } + + @Test + public void testIdleExpiryWithoutReadyKeys() throws IOException { + String id = "0"; + 
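+ // connect and clear interest ops so the key never becomes ready; the idle connection should still expire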
selector.connect(id, new InetSocketAddress("localhost", server.port), BUFFER_SIZE, BUFFER_SIZE); + KafkaChannel channel = selector.channel(id); + channel.selectionKey().interestOps(0); + + time.sleep(6000); // The max idle time is 5000ms + selector.poll(0); + assertTrue(selector.disconnected().containsKey(id), "The idle connection should have been closed"); + assertEquals(ChannelState.EXPIRED, selector.disconnected().get(id)); + } + + @Test + public void testImmediatelyConnectedCleaned() throws Exception { + Metrics metrics = new Metrics(); // new metrics object to avoid metric registration conflicts + Selector selector = new ImmediatelyConnectingSelector(5000, metrics, time, "MetricGroup", channelBuilder, new LogContext()); + + try { + testImmediatelyConnectedCleaned(selector, true); + testImmediatelyConnectedCleaned(selector, false); + } finally { + selector.close(); + metrics.close(); + } + } + + private static class ImmediatelyConnectingSelector extends Selector { + public ImmediatelyConnectingSelector(long connectionMaxIdleMS, + Metrics metrics, + Time time, + String metricGrpPrefix, + ChannelBuilder channelBuilder, + LogContext logContext) { + super(connectionMaxIdleMS, metrics, time, metricGrpPrefix, channelBuilder, logContext); + } + + @Override + protected boolean doConnect(SocketChannel channel, InetSocketAddress address) throws IOException { + // Use a blocking connect to trigger the immediately connected path + channel.configureBlocking(true); + boolean connected = super.doConnect(channel, address); + channel.configureBlocking(false); + return connected; + } + } + + private void testImmediatelyConnectedCleaned(Selector selector, boolean closeAfterFirstPoll) throws Exception { + String id = "0"; + selector.connect(id, new InetSocketAddress("localhost", server.port), BUFFER_SIZE, BUFFER_SIZE); + verifyNonEmptyImmediatelyConnectedKeys(selector); + if (closeAfterFirstPoll) { + selector.poll(0); + verifyEmptyImmediatelyConnectedKeys(selector); + } + selector.close(id); + verifySelectorEmpty(selector); + } + + /** + * Verify that if Selector#connect fails and throws an Exception, all related objects + * are cleared immediately before the exception is propagated. 
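+     * Two failure paths are exercised: one where the selection key is simply cancelled after
+     * registration, and one where {@code registerChannel} additionally throws an
+     * {@code IOException}.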
+ */ + @Test + public void testConnectException() throws Exception { + Metrics metrics = new Metrics(); + AtomicBoolean throwIOException = new AtomicBoolean(); + Selector selector = new ImmediatelyConnectingSelector(5000, metrics, time, "MetricGroup", channelBuilder, new LogContext()) { + @Override + protected SelectionKey registerChannel(String id, SocketChannel socketChannel, int interestedOps) throws IOException { + SelectionKey key = super.registerChannel(id, socketChannel, interestedOps); + key.cancel(); + if (throwIOException.get()) + throw new IOException("Test exception"); + return key; + } + }; + + try { + verifyImmediatelyConnectedException(selector, "0"); + throwIOException.set(true); + verifyImmediatelyConnectedException(selector, "1"); + } finally { + selector.close(); + metrics.close(); + } + } + + private void verifyImmediatelyConnectedException(Selector selector, String id) throws Exception { + try { + selector.connect(id, new InetSocketAddress("localhost", server.port), BUFFER_SIZE, BUFFER_SIZE); + fail("Expected exception not thrown"); + } catch (Exception e) { + verifyEmptyImmediatelyConnectedKeys(selector); + assertNull(selector.channel(id), "Channel not removed"); + ensureEmptySelectorFields(selector); + } + } + + /* + * Verifies that a muted connection is expired on idle timeout even if there are pending + * receives on the socket. + */ + @Test + public void testExpireConnectionWithPendingReceives() throws Exception { + KafkaChannel channel = createConnectionWithPendingReceives(5); + verifyChannelExpiry(channel); + } + + /** + * Verifies that a muted connection closed by peer is expired on idle timeout even if there are pending + * receives on the socket. + */ + @Test + public void testExpireClosedConnectionWithPendingReceives() throws Exception { + KafkaChannel channel = createConnectionWithPendingReceives(5); + server.closeConnections(); + verifyChannelExpiry(channel); + } + + private void verifyChannelExpiry(KafkaChannel channel) throws Exception { + String id = channel.id(); + selector.mute(id); // Mute to allow channel to be expired even if more data is available for read + time.sleep(6000); // The max idle time is 5000ms + selector.poll(0); + assertNull(selector.channel(id), "Channel not expired"); + assertNull(selector.closingChannel(id), "Channel not removed from closingChannels"); + assertEquals(ChannelState.EXPIRED, channel.state()); + assertNull(channel.selectionKey().attachment()); + assertTrue(selector.disconnected().containsKey(id), "Disconnect not notified"); + assertEquals(ChannelState.EXPIRED, selector.disconnected().get(id)); + verifySelectorEmpty(); + } + + /** + * Verifies that sockets with incoming data available are not expired. + * For PLAINTEXT, pending receives are always read from socket without any buffering, so this + * test is only verifying that channels are not expired while there is data to read from socket. + * For SSL, pending receives may also be in SSL netReadBuffer or appReadBuffer. So the test verifies + * that connection is not expired when data is available from buffers or network. + */ + @Test + public void testCloseOldestConnectionWithMultiplePendingReceives() throws Exception { + int expectedReceives = 5; + KafkaChannel channel = createConnectionWithPendingReceives(expectedReceives); + String id = channel.id(); + int completedReceives = 0; + while (selector.disconnected().isEmpty()) { + time.sleep(6000); // The max idle time is 5000ms + selector.poll(completedReceives == expectedReceives ? 
0 : 1000); + completedReceives += selector.completedReceives().size(); + if (!selector.completedReceives().isEmpty()) { + assertEquals(1, selector.completedReceives().size()); + assertNotNull(selector.channel(id), "Channel should not have been expired"); + assertTrue(selector.closingChannel(id) != null || selector.channel(id) != null, "Channel not found"); + assertFalse(selector.disconnected().containsKey(id), "Disconnect notified too early"); + } + } + assertEquals(expectedReceives, completedReceives); + assertNull(selector.channel(id), "Channel not removed"); + assertNull(selector.closingChannel(id), "Channel not removed"); + assertTrue(selector.disconnected().containsKey(id), "Disconnect not notified"); + assertTrue(selector.completedReceives().isEmpty(), "Unexpected receive"); + } + + /** + * Tests that graceful close of channel processes remaining data from socket read buffers. + * Since we cannot determine how much data is available in the buffers, this test verifies that + * multiple receives are completed after server shuts down connections, with retries to tolerate + * cases where data may not be available in the socket buffer. + */ + @Test + public void testGracefulClose() throws Exception { + int maxReceiveCountAfterClose = 0; + for (int i = 6; i <= 100 && maxReceiveCountAfterClose < 5; i++) { + int receiveCount = 0; + KafkaChannel channel = createConnectionWithPendingReceives(i); + // Poll until one or more receives complete and then close the server-side connection + TestUtils.waitForCondition(() -> { + selector.poll(1000); + return selector.completedReceives().size() > 0; + }, 5000, "Receive not completed"); + server.closeConnections(); + while (selector.disconnected().isEmpty()) { + selector.poll(1); + receiveCount += selector.completedReceives().size(); + assertTrue(selector.completedReceives().size() <= 1, "Too many completed receives in one poll"); + } + assertEquals(channel.id(), selector.disconnected().keySet().iterator().next()); + maxReceiveCountAfterClose = Math.max(maxReceiveCountAfterClose, receiveCount); + } + assertTrue(maxReceiveCountAfterClose >= 5, "Too few receives after close: " + maxReceiveCountAfterClose); + } + + /** + * Tests that graceful close is not delayed if only part of an incoming receive is + * available in the socket buffer. 
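+     * The injected receive advertises a size of 100000 bytes, so the small amount of data that
+     * is actually sent can never complete it; the disconnect should still be reported promptly
+     * once the server closes the connection.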
+ */ + @Test + public void testPartialReceiveGracefulClose() throws Exception { + String id = "0"; + blockingConnect(id); + KafkaChannel channel = selector.channel(id); + // Inject a NetworkReceive into Kafka channel with a large size + injectNetworkReceive(channel, 100000); + sendNoReceive(channel, 2); // Send some data that gets received as part of injected receive + selector.poll(1000); // Wait until some data arrives, but not a completed receive + assertEquals(0, selector.completedReceives().size()); + server.closeConnections(); + TestUtils.waitForCondition(() -> { + try { + selector.poll(100); + return !selector.disconnected().isEmpty(); + } catch (IOException e) { + throw new RuntimeException(e); + } + }, 10000, "Channel not disconnected"); + assertEquals(1, selector.disconnected().size()); + assertEquals(channel.id(), selector.disconnected().keySet().iterator().next()); + assertEquals(0, selector.completedReceives().size()); + } + + @Test + public void testMuteOnOOM() throws Exception { + //clean up default selector, replace it with one that uses a finite mem pool + selector.close(); + MemoryPool pool = new SimpleMemoryPool(900, 900, false, null); + selector = new Selector(NetworkReceive.UNLIMITED, 5000, metrics, time, "MetricGroup", + new HashMap(), true, false, channelBuilder, pool, new LogContext()); + + try (ServerSocketChannel ss = ServerSocketChannel.open()) { + ss.bind(new InetSocketAddress(0)); + + InetSocketAddress serverAddress = (InetSocketAddress) ss.getLocalAddress(); + + Thread sender1 = createSender(serverAddress, randomPayload(900)); + Thread sender2 = createSender(serverAddress, randomPayload(900)); + sender1.start(); + sender2.start(); + + //wait until everything has been flushed out to network (assuming payload size is smaller than OS buffer size) + //this is important because we assume both requests' prefixes (1st 4 bytes) have made it. 
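+            //each sender writes a 4-byte big-endian size prefix followed by a 900-byte payload
+            //(see randomPayload), i.e. 0x00 0x00 0x03 0x84 followed by 900 random bytes.
+            //The pool holds exactly 900 bytes, so only one receive can be buffered at a time and
+            //the other channel stays muted until the first receive is closed and its memory is
+            //returned to the pool.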
+ sender1.join(5000); + sender2.join(5000); + + SocketChannel channelX = ss.accept(); //not defined if its 1 or 2 + channelX.configureBlocking(false); + SocketChannel channelY = ss.accept(); + channelY.configureBlocking(false); + selector.register("clientX", channelX); + selector.register("clientY", channelY); + + Collection completed = Collections.emptyList(); + long deadline = System.currentTimeMillis() + 5000; + while (System.currentTimeMillis() < deadline && completed.isEmpty()) { + selector.poll(1000); + completed = selector.completedReceives(); + } + assertEquals(1, completed.size(), "could not read a single request within timeout"); + NetworkReceive firstReceive = completed.iterator().next(); + assertEquals(0, pool.availableMemory()); + assertTrue(selector.isOutOfMemory()); + + selector.poll(10); + assertTrue(selector.completedReceives().isEmpty()); + assertEquals(0, pool.availableMemory()); + assertTrue(selector.isOutOfMemory()); + + firstReceive.close(); + assertEquals(900, pool.availableMemory()); //memory has been released back to pool + + completed = Collections.emptyList(); + deadline = System.currentTimeMillis() + 5000; + while (System.currentTimeMillis() < deadline && completed.isEmpty()) { + selector.poll(1000); + completed = selector.completedReceives(); + } + assertEquals(1, selector.completedReceives().size(), "could not read a single request within timeout"); + assertEquals(0, pool.availableMemory()); + assertFalse(selector.isOutOfMemory()); + } + } + + private Thread createSender(InetSocketAddress serverAddress, byte[] payload) { + return new PlaintextSender(serverAddress, payload); + } + + protected byte[] randomPayload(int sizeBytes) throws Exception { + Random random = new Random(); + byte[] payload = new byte[sizeBytes + 4]; + random.nextBytes(payload); + ByteArrayOutputStream prefixOs = new ByteArrayOutputStream(); + DataOutputStream prefixDos = new DataOutputStream(prefixOs); + prefixDos.writeInt(sizeBytes); + prefixDos.flush(); + prefixDos.close(); + prefixOs.flush(); + prefixOs.close(); + byte[] prefix = prefixOs.toByteArray(); + System.arraycopy(prefix, 0, payload, 0, prefix.length); + return payload; + } + + /** + * Tests that a connect and disconnect in a single poll invocation results in the channel id being + * in `disconnected`, but not `connected`. 
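+     * The channel is mocked so that {@code finishConnect()} succeeds but {@code prepare()}
+     * throws an {@code IOException}, so the connect and the resulting disconnect both occur
+     * within a single {@code pollSelectionKeys} invocation.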
+ */ + @Test + public void testConnectDisconnectDuringInSinglePoll() throws Exception { + // channel is connected, not ready and it throws an exception during prepare + KafkaChannel kafkaChannel = mock(KafkaChannel.class); + when(kafkaChannel.id()).thenReturn("1"); + when(kafkaChannel.socketDescription()).thenReturn(""); + when(kafkaChannel.state()).thenReturn(ChannelState.NOT_CONNECTED); + when(kafkaChannel.finishConnect()).thenReturn(true); + when(kafkaChannel.isConnected()).thenReturn(true); + when(kafkaChannel.ready()).thenReturn(false); + doThrow(new IOException()).when(kafkaChannel).prepare(); + + SelectionKey selectionKey = mock(SelectionKey.class); + when(kafkaChannel.selectionKey()).thenReturn(selectionKey); + when(selectionKey.channel()).thenReturn(SocketChannel.open()); + when(selectionKey.readyOps()).thenReturn(SelectionKey.OP_CONNECT); + + selectionKey.attach(kafkaChannel); + Set selectionKeys = Utils.mkSet(selectionKey); + selector.pollSelectionKeys(selectionKeys, false, System.nanoTime()); + + assertFalse(selector.connected().contains(kafkaChannel.id())); + assertTrue(selector.disconnected().containsKey(kafkaChannel.id())); + assertNull(selectionKey.attachment()); + + verify(kafkaChannel, atLeastOnce()).ready(); + verify(kafkaChannel).disconnect(); + verify(kafkaChannel).close(); + verify(selectionKey).cancel(); + } + + @Test + public void testOutboundConnectionsCountInConnectionCreationMetric() throws Exception { + // create connections + int expectedConnections = 5; + InetSocketAddress addr = new InetSocketAddress("localhost", server.port); + for (int i = 0; i < expectedConnections; i++) + connect(Integer.toString(i), addr); + + // Poll continuously, as we cannot guarantee that the first call will see all connections + int seenConnections = 0; + for (int i = 0; i < 10; i++) { + selector.poll(100L); + seenConnections += selector.connected().size(); + if (seenConnections == expectedConnections) + break; + } + + assertEquals((double) expectedConnections, getMetric("connection-creation-total").metricValue()); + assertEquals((double) expectedConnections, getMetric("connection-count").metricValue()); + } + + @Test + public void testInboundConnectionsCountInConnectionCreationMetric() throws Exception { + int conns = 5; + + try (ServerSocketChannel ss = ServerSocketChannel.open()) { + ss.bind(new InetSocketAddress(0)); + InetSocketAddress serverAddress = (InetSocketAddress) ss.getLocalAddress(); + + for (int i = 0; i < conns; i++) { + Thread sender = createSender(serverAddress, randomPayload(1)); + sender.start(); + SocketChannel channel = ss.accept(); + channel.configureBlocking(false); + + selector.register(Integer.toString(i), channel); + } + } + + assertEquals((double) conns, getMetric("connection-creation-total").metricValue()); + assertEquals((double) conns, getMetric("connection-count").metricValue()); + } + + @Test + public void testConnectionsByClientMetric() throws Exception { + String node = "0"; + Map unknownNameAndVersion = softwareNameAndVersionTags( + ClientInformation.UNKNOWN_NAME_OR_VERSION, ClientInformation.UNKNOWN_NAME_OR_VERSION); + Map knownNameAndVersion = softwareNameAndVersionTags("A", "B"); + + try (ServerSocketChannel ss = ServerSocketChannel.open()) { + ss.bind(new InetSocketAddress(0)); + InetSocketAddress serverAddress = (InetSocketAddress) ss.getLocalAddress(); + + Thread sender = createSender(serverAddress, randomPayload(1)); + sender.start(); + SocketChannel channel = ss.accept(); + channel.configureBlocking(false); + + // Metric with unknown / 
unknown should be there + selector.register(node, channel); + assertEquals(1, + getMetric("connections", unknownNameAndVersion).metricValue()); + assertEquals(ClientInformation.EMPTY, + selector.channel(node).channelMetadataRegistry().clientInformation()); + + // Metric with unknown / unknown should not be there, metric with A / B should be there + ClientInformation clientInformation = new ClientInformation("A", "B"); + selector.channel(node).channelMetadataRegistry() + .registerClientInformation(clientInformation); + assertEquals(clientInformation, + selector.channel(node).channelMetadataRegistry().clientInformation()); + assertEquals(0, getMetric("connections", unknownNameAndVersion).metricValue()); + assertEquals(1, getMetric("connections", knownNameAndVersion).metricValue()); + + // Metric with A / B should not be there, + selector.close(node); + assertEquals(0, getMetric("connections", knownNameAndVersion).metricValue()); + } + } + + private Map softwareNameAndVersionTags(String clientSoftwareName, String clientSoftwareVersion) { + Map tags = new HashMap<>(2); + tags.put("clientSoftwareName", clientSoftwareName); + tags.put("clientSoftwareVersion", clientSoftwareVersion); + return tags; + } + + private KafkaMetric getMetric(String name, Map tags) throws Exception { + Optional> metric = metrics.metrics().entrySet().stream() + .filter(entry -> + entry.getKey().name().equals(name) && entry.getKey().tags().equals(tags)) + .findFirst(); + if (!metric.isPresent()) + throw new Exception(String.format("Could not find metric called %s with tags %s", name, tags.toString())); + + return metric.get().getValue(); + } + + @SuppressWarnings("unchecked") + @Test + public void testLowestPriorityChannel() throws Exception { + int conns = 5; + InetSocketAddress addr = new InetSocketAddress("localhost", server.port); + for (int i = 0; i < conns; i++) { + connect(String.valueOf(i), addr); + } + assertNotNull(selector.lowestPriorityChannel()); + for (int i = conns - 1; i >= 0; i--) { + if (i != 2) + assertEquals("", blockingRequest(String.valueOf(i), "")); + time.sleep(10); + } + assertEquals("2", selector.lowestPriorityChannel().id()); + + Field field = Selector.class.getDeclaredField("closingChannels"); + field.setAccessible(true); + Map closingChannels = (Map) field.get(selector); + closingChannels.put("3", selector.channel("3")); + assertEquals("3", selector.lowestPriorityChannel().id()); + closingChannels.remove("3"); + + for (int i = 0; i < conns; i++) { + selector.close(String.valueOf(i)); + } + assertNull(selector.lowestPriorityChannel()); + } + + @Test + public void testMetricsCleanupOnSelectorClose() throws Exception { + Metrics metrics = new Metrics(); + Selector selector = new ImmediatelyConnectingSelector(5000, metrics, time, "MetricGroup", channelBuilder, new LogContext()) { + @Override + public void close(String id) { + throw new RuntimeException(); + } + }; + assertTrue(metrics.metrics().size() > 1); + String id = "0"; + selector.connect(id, new InetSocketAddress("localhost", server.port), BUFFER_SIZE, BUFFER_SIZE); + + // Close the selector and ensure a RuntimeException has been throw + assertThrows(RuntimeException.class, selector::close); + + // We should only have one remaining metric for kafka-metrics-count, which is a global metric + assertEquals(1, metrics.metrics().size()); + } + + @Test + public void testWriteCompletesSendWithNoBytesWritten() throws IOException { + KafkaChannel channel = mock(KafkaChannel.class); + when(channel.id()).thenReturn("1"); + 
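// Simulate a write() that flushes zero bytes while the channel still reports the (empty)
+        // send as complete; the send should nevertheless appear in completedSends().
+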
when(channel.write()).thenReturn(0L); + NetworkSend send = new NetworkSend("destination", new ByteBufferSend(ByteBuffer.allocate(0))); + when(channel.maybeCompleteSend()).thenReturn(send); + selector.write(channel); + assertEquals(asList(send), selector.completedSends()); + } + + /** + * Ensure that no errors are thrown if channels are closed while processing multiple completed receives + */ + @Test + public void testChannelCloseWhileProcessingReceives() throws Exception { + int numChannels = 4; + Map channels = TestUtils.fieldValue(selector, Selector.class, "channels"); + Set selectionKeys = new HashSet<>(); + for (int i = 0; i < numChannels; i++) { + String id = String.valueOf(i); + KafkaChannel channel = mock(KafkaChannel.class); + channels.put(id, channel); + when(channel.id()).thenReturn(id); + when(channel.state()).thenReturn(ChannelState.READY); + when(channel.isConnected()).thenReturn(true); + when(channel.ready()).thenReturn(true); + when(channel.read()).thenReturn(1L); + + SelectionKey selectionKey = mock(SelectionKey.class); + when(channel.selectionKey()).thenReturn(selectionKey); + when(selectionKey.isValid()).thenReturn(true); + when(selectionKey.readyOps()).thenReturn(SelectionKey.OP_READ); + selectionKey.attach(channel); + selectionKeys.add(selectionKey); + + NetworkReceive receive = mock(NetworkReceive.class); + when(receive.source()).thenReturn(id); + when(receive.size()).thenReturn(10); + when(receive.bytesRead()).thenReturn(1); + when(receive.payload()).thenReturn(ByteBuffer.allocate(10)); + when(channel.maybeCompleteReceive()).thenReturn(receive); + } + + selector.pollSelectionKeys(selectionKeys, false, System.nanoTime()); + assertEquals(numChannels, selector.completedReceives().size()); + Set closed = new HashSet<>(); + Set notClosed = new HashSet<>(); + for (NetworkReceive receive : selector.completedReceives()) { + KafkaChannel channel = selector.channel(receive.source()); + assertNotNull(channel); + if (closed.size() < 2) { + selector.close(channel.id()); + closed.add(channel); + } else + notClosed.add(channel); + } + assertEquals(notClosed, new HashSet<>(selector.channels())); + closed.forEach(channel -> assertNull(selector.channel(channel.id()))); + + selector.poll(0); + assertEquals(0, selector.completedReceives().size()); + } + + + private String blockingRequest(String node, String s) throws IOException { + selector.send(createSend(node, s)); + selector.poll(1000L); + while (true) { + selector.poll(1000L); + for (NetworkReceive receive : selector.completedReceives()) + if (receive.source().equals(node)) + return asString(receive); + } + } + + protected void connect(String node, InetSocketAddress serverAddr) throws IOException { + selector.connect(node, serverAddr, BUFFER_SIZE, BUFFER_SIZE); + } + + /* connect and wait for the connection to complete */ + private void blockingConnect(String node) throws IOException { + blockingConnect(node, new InetSocketAddress("localhost", server.port)); + } + + protected void blockingConnect(String node, InetSocketAddress serverAddr) throws IOException { + selector.connect(node, serverAddr, BUFFER_SIZE, BUFFER_SIZE); + while (!selector.connected().contains(node)) + selector.poll(10000L); + while (!selector.isChannelReady(node)) + selector.poll(10000L); + } + + protected final NetworkSend createSend(String node, String payload) { + return new NetworkSend(node, ByteBufferSend.sizePrefixed(ByteBuffer.wrap(payload.getBytes()))); + } + + protected String asString(NetworkReceive receive) { + return new 
String(Utils.toArray(receive.payload())); + } + + private void sendAndReceive(String node, String requestPrefix, int startIndex, int endIndex) throws Exception { + int requests = startIndex; + int responses = startIndex; + selector.send(createSend(node, requestPrefix + "-" + startIndex)); + requests++; + while (responses < endIndex) { + // do the i/o + selector.poll(0L); + assertEquals(0, selector.disconnected().size(), "No disconnects should have occurred."); + // handle requests and responses of the fast node + for (NetworkReceive receive : selector.completedReceives()) { + assertEquals(requestPrefix + "-" + responses, asString(receive)); + responses++; + } + + for (int i = 0; i < selector.completedSends().size() && requests < endIndex; i++, requests++) { + selector.send(createSend(node, requestPrefix + "-" + requests)); + } + } + } + + private void verifyNonEmptyImmediatelyConnectedKeys(Selector selector) throws Exception { + Field field = Selector.class.getDeclaredField("immediatelyConnectedKeys"); + field.setAccessible(true); + Collection immediatelyConnectedKeys = (Collection) field.get(selector); + assertFalse(immediatelyConnectedKeys.isEmpty()); + } + + private void verifyEmptyImmediatelyConnectedKeys(Selector selector) throws Exception { + Field field = Selector.class.getDeclaredField("immediatelyConnectedKeys"); + ensureEmptySelectorField(selector, field); + } + + protected void verifySelectorEmpty() throws Exception { + verifySelectorEmpty(this.selector); + } + + public void verifySelectorEmpty(Selector selector) throws Exception { + for (KafkaChannel channel : selector.channels()) { + selector.close(channel.id()); + assertNull(channel.selectionKey().attachment()); + } + selector.poll(0); + selector.poll(0); // Poll a second time to clear everything + ensureEmptySelectorFields(selector); + } + + private void ensureEmptySelectorFields(Selector selector) throws Exception { + for (Field field : Selector.class.getDeclaredFields()) { + ensureEmptySelectorField(selector, field); + } + } + + private void ensureEmptySelectorField(Selector selector, Field field) throws Exception { + field.setAccessible(true); + Object obj = field.get(selector); + if (obj instanceof Collection) + assertTrue(((Collection) obj).isEmpty(), "Field not empty: " + field + " " + obj); + else if (obj instanceof Map) + assertTrue(((Map) obj).isEmpty(), "Field not empty: " + field + " " + obj); + } + + private KafkaMetric getMetric(String name) throws Exception { + Optional> metric = metrics.metrics().entrySet().stream() + .filter(entry -> entry.getKey().name().equals(name)) + .findFirst(); + if (!metric.isPresent()) + throw new Exception(String.format("Could not find metric called %s", name)); + + return metric.get().getValue(); + } + + private KafkaMetric findUntaggedMetricByName(String name) { + MetricName metricName = new MetricName(name, METRIC_GROUP + "-metrics", "", new HashMap<>()); + KafkaMetric metric = metrics.metrics().get(metricName); + assertNotNull(metric); + return metric; + } + + /** + * Creates a connection, sends the specified number of requests and returns without reading + * any incoming data. Some of the incoming data may be in the socket buffers when this method + * returns, but there is no guarantee that all the data from the server will be available + * immediately. 
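+     * The channel is muted while the requests are written (see sendNoReceive), so the echoed
+     * responses accumulate as pending receives on the socket instead of being read during the
+     * sends.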
+ */ + private KafkaChannel createConnectionWithPendingReceives(int pendingReceives) throws Exception { + String id = "0"; + blockingConnect(id); + KafkaChannel channel = selector.channel(id); + sendNoReceive(channel, pendingReceives); + return channel; + } + + /** + * Sends the specified number of requests and waits for the requests to be sent. The channel + * is muted during polling to ensure that incoming data is not received. + */ + private KafkaChannel sendNoReceive(KafkaChannel channel, int numRequests) throws Exception { + channel.mute(); + for (int i = 0; i < numRequests; i++) { + selector.send(createSend(channel.id(), String.valueOf(i))); + do { + selector.poll(10); + } while (selector.completedSends().isEmpty()); + } + channel.maybeUnmute(); + + return channel; + } + + /** + * Injects a NetworkReceive for channel with size buffer filled in with the provided size + * and a payload buffer allocated with that size, but no data in the payload buffer. + */ + private void injectNetworkReceive(KafkaChannel channel, int size) throws Exception { + NetworkReceive receive = new NetworkReceive(); + TestUtils.setFieldValue(channel, "receive", receive); + ByteBuffer sizeBuffer = TestUtils.fieldValue(receive, NetworkReceive.class, "size"); + sizeBuffer.putInt(size); + TestUtils.setFieldValue(receive, "buffer", ByteBuffer.allocate(size)); + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/network/SslSelectorTest.java b/clients/src/test/java/org/apache/kafka/common/network/SslSelectorTest.java new file mode 100644 index 0000000..7f95566 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/network/SslSelectorTest.java @@ -0,0 +1,427 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.network; + +import java.nio.channels.SelectionKey; +import javax.net.ssl.SSLEngine; + +import org.apache.kafka.common.config.SecurityConfig; +import org.apache.kafka.common.memory.MemoryPool; +import org.apache.kafka.common.memory.SimpleMemoryPool; +import org.apache.kafka.common.metrics.Metrics; +import org.apache.kafka.common.security.auth.SecurityProtocol; +import org.apache.kafka.common.security.ssl.SslFactory; +import org.apache.kafka.common.security.ssl.mock.TestKeyManagerFactory; +import org.apache.kafka.common.security.ssl.mock.TestProviderCreator; +import org.apache.kafka.common.security.ssl.mock.TestTrustManagerFactory; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.test.TestSslUtils; +import org.apache.kafka.test.TestUtils; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.io.File; +import java.io.IOException; +import java.net.InetSocketAddress; +import java.nio.channels.ServerSocketChannel; +import java.nio.channels.SocketChannel; +import java.security.Security; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Consumer; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +/** + * A set of tests for the selector. These use a test harness that runs a simple socket server that echos back responses. 
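+ * Since this class extends SelectorTest, all of the plaintext selector tests are re-run here
+ * over SSL; setUp swaps in an SSL echo server and a selector built from an SslChannelBuilder.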
+ */ +public class SslSelectorTest extends SelectorTest { + + private Map sslClientConfigs; + + @BeforeEach + public void setUp() throws Exception { + File trustStoreFile = File.createTempFile("truststore", ".jks"); + + Map sslServerConfigs = TestSslUtils.createSslConfig(false, true, Mode.SERVER, trustStoreFile, "server"); + this.server = new EchoServer(SecurityProtocol.SSL, sslServerConfigs); + this.server.start(); + this.time = new MockTime(); + sslClientConfigs = TestSslUtils.createSslConfig(false, false, Mode.CLIENT, trustStoreFile, "client"); + LogContext logContext = new LogContext(); + this.channelBuilder = new SslChannelBuilder(Mode.CLIENT, null, false, logContext); + this.channelBuilder.configure(sslClientConfigs); + this.metrics = new Metrics(); + this.selector = new Selector(5000, metrics, time, "MetricGroup", channelBuilder, logContext); + } + + @AfterEach + public void tearDown() throws Exception { + this.selector.close(); + this.server.close(); + this.metrics.close(); + } + + @Override + public SecurityProtocol securityProtocol() { + return SecurityProtocol.PLAINTEXT; + } + + @Override + protected Map clientConfigs() { + return sslClientConfigs; + } + + @Test + public void testConnectionWithCustomKeyManager() throws Exception { + + TestProviderCreator testProviderCreator = new TestProviderCreator(); + + int requestSize = 100 * 1024; + final String node = "0"; + String request = TestUtils.randomString(requestSize); + + Map sslServerConfigs = TestSslUtils.createSslConfig( + TestKeyManagerFactory.ALGORITHM, + TestTrustManagerFactory.ALGORITHM, + TestSslUtils.DEFAULT_TLS_PROTOCOL_FOR_TESTS + ); + sslServerConfigs.put(SecurityConfig.SECURITY_PROVIDERS_CONFIG, testProviderCreator.getClass().getName()); + EchoServer server = new EchoServer(SecurityProtocol.SSL, sslServerConfigs); + server.start(); + Time time = new MockTime(); + File trustStoreFile = new File(TestKeyManagerFactory.TestKeyManager.mockTrustStoreFile); + Map sslClientConfigs = TestSslUtils.createSslConfig(true, true, Mode.CLIENT, trustStoreFile, "client"); + + ChannelBuilder channelBuilder = new TestSslChannelBuilder(Mode.CLIENT); + channelBuilder.configure(sslClientConfigs); + Metrics metrics = new Metrics(); + Selector selector = new Selector(5000, metrics, time, "MetricGroup", channelBuilder, new LogContext()); + + selector.connect(node, new InetSocketAddress("localhost", server.port), BUFFER_SIZE, BUFFER_SIZE); + while (!selector.connected().contains(node)) + selector.poll(10000L); + while (!selector.isChannelReady(node)) + selector.poll(10000L); + + selector.send(createSend(node, request)); + + waitForBytesBuffered(selector, node); + + TestUtils.waitForCondition(() -> cipherMetrics(metrics).size() == 1, + "Waiting for cipher metrics to be created."); + assertEquals(Integer.valueOf(1), cipherMetrics(metrics).get(0).metricValue()); + assertNotNull(selector.channel(node).channelMetadataRegistry().cipherInformation()); + + selector.close(node); + super.verifySelectorEmpty(selector); + + assertEquals(1, cipherMetrics(metrics).size()); + assertEquals(Integer.valueOf(0), cipherMetrics(metrics).get(0).metricValue()); + + Security.removeProvider(testProviderCreator.getProvider().getName()); + selector.close(); + server.close(); + metrics.close(); + } + + @Test + public void testDisconnectWithIntermediateBufferedBytes() throws Exception { + int requestSize = 100 * 1024; + final String node = "0"; + String request = TestUtils.randomString(requestSize); + + this.selector.close(); + + this.channelBuilder = new 
TestSslChannelBuilder(Mode.CLIENT); + this.channelBuilder.configure(sslClientConfigs); + this.selector = new Selector(5000, metrics, time, "MetricGroup", channelBuilder, new LogContext()); + connect(node, new InetSocketAddress("localhost", server.port)); + selector.send(createSend(node, request)); + + waitForBytesBuffered(selector, node); + + selector.close(node); + verifySelectorEmpty(); + } + + private void waitForBytesBuffered(Selector selector, String node) throws Exception { + TestUtils.waitForCondition(() -> { + try { + selector.poll(0L); + return selector.channel(node).hasBytesBuffered(); + } catch (IOException e) { + throw new RuntimeException(e); + } + }, 2000L, "Failed to reach socket state with bytes buffered"); + } + + @Test + public void testBytesBufferedChannelWithNoIncomingBytes() throws Exception { + verifyNoUnnecessaryPollWithBytesBuffered(key -> + key.interestOps(key.interestOps() & ~SelectionKey.OP_READ)); + } + + @Test + public void testBytesBufferedChannelAfterMute() throws Exception { + verifyNoUnnecessaryPollWithBytesBuffered(key -> ((KafkaChannel) key.attachment()).mute()); + } + + private void verifyNoUnnecessaryPollWithBytesBuffered(Consumer disableRead) + throws Exception { + this.selector.close(); + + String node1 = "1"; + String node2 = "2"; + final AtomicInteger node1Polls = new AtomicInteger(); + + this.channelBuilder = new TestSslChannelBuilder(Mode.CLIENT); + this.channelBuilder.configure(sslClientConfigs); + this.selector = new Selector(5000, metrics, time, "MetricGroup", channelBuilder, new LogContext()) { + @Override + void pollSelectionKeys(Set selectionKeys, boolean isImmediatelyConnected, long currentTimeNanos) { + for (SelectionKey key : selectionKeys) { + KafkaChannel channel = (KafkaChannel) key.attachment(); + if (channel != null && channel.id().equals(node1)) + node1Polls.incrementAndGet(); + } + super.pollSelectionKeys(selectionKeys, isImmediatelyConnected, currentTimeNanos); + } + }; + + // Get node1 into bytes buffered state and then disable read on the socket. + // Truncate the read buffers to ensure that there is buffered data, but not enough to make progress. + int largeRequestSize = 100 * 1024; + connect(node1, new InetSocketAddress("localhost", server.port)); + selector.send(createSend(node1, TestUtils.randomString(largeRequestSize))); + waitForBytesBuffered(selector, node1); + TestSslChannelBuilder.TestSslTransportLayer.transportLayers.get(node1).truncateReadBuffer(); + disableRead.accept(selector.channel(node1).selectionKey()); + + // Clear poll count and count the polls from now on + node1Polls.set(0); + + // Process sends and receives on node2. Test verifies that we don't process node1 + // unnecessarily on each of these polls. + connect(node2, new InetSocketAddress("localhost", server.port)); + int received = 0; + String request = TestUtils.randomString(10); + selector.send(createSend(node2, request)); + while (received < 100) { + received += selector.completedReceives().size(); + if (!selector.completedSends().isEmpty()) { + selector.send(createSend(node2, request)); + } + selector.poll(5); + } + + // Verify that pollSelectionKeys was invoked once to process buffered data + // but not again since there isn't sufficient data to process. 
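+        // (The single recorded pass comes from processing the already-buffered data; after that
+        // the selector has no reason to revisit node1 because no further read progress can be made.)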
+ assertEquals(1, node1Polls.get()); + selector.close(node1); + selector.close(node2); + verifySelectorEmpty(); + } + + /** + * Renegotiation is not supported since it is potentially unsafe and it has been removed in TLS 1.3 + */ + @Test + public void testRenegotiationFails() throws Exception { + String node = "0"; + // create connections + InetSocketAddress addr = new InetSocketAddress("localhost", server.port); + selector.connect(node, addr, BUFFER_SIZE, BUFFER_SIZE); + + // send echo requests and receive responses + while (!selector.isChannelReady(node)) { + selector.poll(1000L); + } + selector.send(createSend(node, node + "-" + 0)); + selector.poll(0L); + server.renegotiate(); + selector.send(createSend(node, node + "-" + 1)); + long expiryTime = System.currentTimeMillis() + 2000; + + List disconnected = new ArrayList<>(); + while (!disconnected.contains(node) && System.currentTimeMillis() < expiryTime) { + selector.poll(10); + disconnected.addAll(selector.disconnected().keySet()); + } + assertTrue(disconnected.contains(node), "Renegotiation should cause disconnection"); + + } + + @Override + @Test + public void testMuteOnOOM() throws Exception { + //clean up default selector, replace it with one that uses a finite mem pool + selector.close(); + MemoryPool pool = new SimpleMemoryPool(900, 900, false, null); + //the initial channel builder is for clients, we need a server one + String tlsProtocol = "TLSv1.2"; + File trustStoreFile = File.createTempFile("truststore", ".jks"); + Map sslServerConfigs = new TestSslUtils.SslConfigsBuilder(Mode.SERVER) + .tlsProtocol(tlsProtocol) + .createNewTrustStore(trustStoreFile) + .build(); + channelBuilder = new SslChannelBuilder(Mode.SERVER, null, false, new LogContext()); + channelBuilder.configure(sslServerConfigs); + selector = new Selector(NetworkReceive.UNLIMITED, 5000, metrics, time, "MetricGroup", + new HashMap(), true, false, channelBuilder, pool, new LogContext()); + + try (ServerSocketChannel ss = ServerSocketChannel.open()) { + ss.bind(new InetSocketAddress(0)); + + InetSocketAddress serverAddress = (InetSocketAddress) ss.getLocalAddress(); + + SslSender sender1 = createSender(tlsProtocol, serverAddress, randomPayload(900)); + SslSender sender2 = createSender(tlsProtocol, serverAddress, randomPayload(900)); + sender1.start(); + sender2.start(); + + SocketChannel channelX = ss.accept(); //not defined if its 1 or 2 + channelX.configureBlocking(false); + SocketChannel channelY = ss.accept(); + channelY.configureBlocking(false); + selector.register("clientX", channelX); + selector.register("clientY", channelY); + + boolean handshaked = false; + NetworkReceive firstReceive = null; + long deadline = System.currentTimeMillis() + 5000; + //keep calling poll until: + //1. both senders have completed the handshakes (so server selector has tried reading both payloads) + //2. 
a single payload is actually read out completely (the other is too big to fit) + while (System.currentTimeMillis() < deadline) { + selector.poll(10); + + Collection completed = selector.completedReceives(); + if (firstReceive == null) { + if (!completed.isEmpty()) { + assertEquals(1, completed.size(), "expecting a single request"); + firstReceive = completed.iterator().next(); + assertTrue(selector.isMadeReadProgressLastPoll()); + assertEquals(0, pool.availableMemory()); + } + } else { + assertTrue(completed.isEmpty(), "only expecting single request"); + } + + handshaked = sender1.waitForHandshake(1) && sender2.waitForHandshake(1); + + if (handshaked && firstReceive != null && selector.isOutOfMemory()) + break; + } + assertTrue(handshaked, "could not initiate connections within timeout"); + + selector.poll(10); + assertTrue(selector.completedReceives().isEmpty()); + assertEquals(0, pool.availableMemory()); + assertNotNull(firstReceive, "First receive not complete"); + assertTrue(selector.isOutOfMemory(), "Selector not out of memory"); + + firstReceive.close(); + assertEquals(900, pool.availableMemory()); //memory has been released back to pool + + Collection completed = Collections.emptyList(); + deadline = System.currentTimeMillis() + 5000; + while (System.currentTimeMillis() < deadline && completed.isEmpty()) { + selector.poll(1000); + completed = selector.completedReceives(); + } + assertEquals(1, completed.size(), "could not read remaining request within timeout"); + assertEquals(0, pool.availableMemory()); + assertFalse(selector.isOutOfMemory()); + } + } + + /** + * Connects and waits for handshake to complete. This is required since SslTransportLayer + * implementation requires the channel to be ready before send is invoked (unlike plaintext + * where send can be invoked straight after connect) + */ + protected void connect(String node, InetSocketAddress serverAddr) throws IOException { + blockingConnect(node, serverAddr); + } + + private SslSender createSender(String tlsProtocol, InetSocketAddress serverAddress, byte[] payload) { + return new SslSender(tlsProtocol, serverAddress, payload); + } + + private static class TestSslChannelBuilder extends SslChannelBuilder { + + public TestSslChannelBuilder(Mode mode) { + super(mode, null, false, new LogContext()); + } + + @Override + protected SslTransportLayer buildTransportLayer(SslFactory sslFactory, String id, SelectionKey key, + ChannelMetadataRegistry metadataRegistry) throws IOException { + SocketChannel socketChannel = (SocketChannel) key.channel(); + SSLEngine sslEngine = sslFactory.createSslEngine(socketChannel.socket()); + TestSslTransportLayer transportLayer = new TestSslTransportLayer(id, key, sslEngine, metadataRegistry); + return transportLayer; + } + + /* + * TestSslTransportLayer will read from socket once every two tries. This increases + * the chance that there will be bytes buffered in the transport layer after read(). 
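+         * Each instance also registers itself in a static map keyed by channel id so that tests
+         * can reach into a specific channel's buffers (see truncateReadBuffer).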
+ */ + static class TestSslTransportLayer extends SslTransportLayer { + static Map transportLayers = new HashMap<>(); + boolean muteSocket = false; + + public TestSslTransportLayer(String channelId, SelectionKey key, SSLEngine sslEngine, + ChannelMetadataRegistry metadataRegistry) throws IOException { + super(channelId, key, sslEngine, metadataRegistry); + transportLayers.put(channelId, this); + } + + @Override + protected int readFromSocketChannel() throws IOException { + if (muteSocket) { + if ((selectionKey().interestOps() & SelectionKey.OP_READ) != 0) + muteSocket = false; + return 0; + } + muteSocket = true; + return super.readFromSocketChannel(); + } + + // Leave one byte in network read buffer so that some buffered bytes are present, + // but not enough to make progress on a read. + void truncateReadBuffer() throws Exception { + netReadBuffer().position(1); + appReadBuffer().position(0); + muteSocket = true; + } + } + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/network/SslSender.java b/clients/src/test/java/org/apache/kafka/common/network/SslSender.java new file mode 100644 index 0000000..22196dd --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/network/SslSender.java @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.network; + +import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLSocket; +import javax.net.ssl.TrustManager; +import javax.net.ssl.X509TrustManager; +import java.io.OutputStream; +import java.net.InetSocketAddress; +import java.security.cert.CertificateException; +import java.security.cert.X509Certificate; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; + +public class SslSender extends Thread { + + private final String tlsProtocol; + private final InetSocketAddress serverAddress; + private final byte[] payload; + private final CountDownLatch handshaked = new CountDownLatch(1); + + public SslSender(String tlsProtocol, InetSocketAddress serverAddress, byte[] payload) { + this.tlsProtocol = tlsProtocol; + this.serverAddress = serverAddress; + this.payload = payload; + setDaemon(true); + setName("SslSender - " + payload.length + " bytes @ " + serverAddress); + } + + @Override + public void run() { + try { + SSLContext sc = SSLContext.getInstance(tlsProtocol); + sc.init(null, new TrustManager[]{new NaiveTrustManager()}, new java.security.SecureRandom()); + try (SSLSocket connection = (SSLSocket) sc.getSocketFactory().createSocket(serverAddress.getAddress(), serverAddress.getPort())) { + OutputStream os = connection.getOutputStream(); + connection.startHandshake(); + handshaked.countDown(); + os.write(payload); + os.flush(); + } + } catch (Exception e) { + e.printStackTrace(System.err); + } + } + + public boolean waitForHandshake(long timeoutMillis) throws InterruptedException { + return handshaked.await(timeoutMillis, TimeUnit.MILLISECONDS); + } + + /** + * blindly trust any certificate presented to it + */ + private static class NaiveTrustManager implements X509TrustManager { + @Override + public void checkClientTrusted(X509Certificate[] x509Certificates, String s) throws CertificateException { + //nop + } + + @Override + public void checkServerTrusted(X509Certificate[] x509Certificates, String s) throws CertificateException { + //nop + } + + @Override + public X509Certificate[] getAcceptedIssuers() { + return new X509Certificate[0]; + } + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/network/SslTransportLayerTest.java b/clients/src/test/java/org/apache/kafka/common/network/SslTransportLayerTest.java new file mode 100644 index 0000000..17a8d79 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/network/SslTransportLayerTest.java @@ -0,0 +1,1456 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.network; + +import org.apache.kafka.common.KafkaException; +import org.apache.kafka.common.config.SslConfigs; +import org.apache.kafka.common.config.internals.BrokerSecurityConfigs; +import org.apache.kafka.common.config.types.Password; +import org.apache.kafka.common.errors.InvalidConfigurationException; +import org.apache.kafka.common.memory.MemoryPool; +import org.apache.kafka.common.message.ApiMessageType; +import org.apache.kafka.common.metrics.Metrics; +import org.apache.kafka.common.requests.ApiVersionsResponse; +import org.apache.kafka.common.security.TestSecurityConfig; +import org.apache.kafka.common.security.auth.SecurityProtocol; +import org.apache.kafka.common.security.ssl.DefaultSslEngineFactory; +import org.apache.kafka.common.security.ssl.SslFactory; +import org.apache.kafka.common.utils.Java; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.test.TestSslUtils; +import org.apache.kafka.test.TestUtils; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.extension.ExtensionContext; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.ArgumentsProvider; +import org.junit.jupiter.params.provider.ArgumentsSource; + +import java.io.ByteArrayOutputStream; +import java.io.File; +import java.io.IOException; +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.nio.ByteBuffer; +import java.nio.channels.Channels; +import java.nio.channels.SelectionKey; +import java.nio.channels.SocketChannel; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicLong; +import java.util.function.Supplier; +import java.util.stream.Stream; + +import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLEngine; +import javax.net.ssl.SSLParameters; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assumptions.assumeTrue; + +/** + * Tests for the SSL transport layer. These use a test harness that runs a simple socket server that echos back responses. 
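+ * Most of the tests below follow roughly the same shape:
+ * <pre>{@code
+ *   server = createEchoServer(args, SecurityProtocol.SSL);
+ *   createSelector(args.sslClientConfigs);
+ *   selector.connect(node, new InetSocketAddress("localhost", server.port()), BUFFER_SIZE, BUFFER_SIZE);
+ *   NetworkTestUtils.checkClientConnection(selector, node, 100, 10);
+ * }</pre>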
+ */ +public class SslTransportLayerTest { + + private static final int BUFFER_SIZE = 4 * 1024; + private static Time time = Time.SYSTEM; + + private static class Args { + private final String tlsProtocol; + private final boolean useInlinePem; + private CertStores serverCertStores; + private CertStores clientCertStores; + private Map sslClientConfigs; + private Map sslServerConfigs; + private Map sslConfigOverrides; + + public Args(String tlsProtocol, boolean useInlinePem) throws Exception { + this.tlsProtocol = tlsProtocol; + this.useInlinePem = useInlinePem; + sslConfigOverrides = new HashMap<>(); + sslConfigOverrides.put(SslConfigs.SSL_PROTOCOL_CONFIG, tlsProtocol); + sslConfigOverrides.put(SslConfigs.SSL_ENABLED_PROTOCOLS_CONFIG, Collections.singletonList(tlsProtocol)); + init(); + } + + Map getTrustingConfig(CertStores certStores, CertStores peerCertStores) { + Map configs = certStores.getTrustingConfig(peerCertStores); + configs.putAll(sslConfigOverrides); + return configs; + } + + private void init() throws Exception { + // Create certificates for use by client and server. Add server cert to client truststore and vice versa. + serverCertStores = certBuilder(true, "server", useInlinePem).addHostName("localhost").build(); + clientCertStores = certBuilder(false, "client", useInlinePem).addHostName("localhost").build(); + sslServerConfigs = getTrustingConfig(serverCertStores, clientCertStores); + sslClientConfigs = getTrustingConfig(clientCertStores, serverCertStores); + sslServerConfigs.put(SslConfigs.SSL_ENGINE_FACTORY_CLASS_CONFIG, DefaultSslEngineFactory.class); + sslClientConfigs.put(SslConfigs.SSL_ENGINE_FACTORY_CLASS_CONFIG, DefaultSslEngineFactory.class); + } + + @Override + public String toString() { + return "tlsProtocol=" + tlsProtocol + + ", useInlinePem=" + useInlinePem; + } + } + + private static class SslTransportLayerArgumentsProvider implements ArgumentsProvider { + + @Override + public Stream provideArguments(ExtensionContext context) throws Exception { + List parameters = new ArrayList<>(); + parameters.add(Arguments.of(new Args("TLSv1.2", false))); + parameters.add(Arguments.of(new Args("TLSv1.2", true))); + if (Java.IS_JAVA11_COMPATIBLE) { + parameters.add(Arguments.of(new Args("TLSv1.3", false))); + } + return parameters.stream(); + } + } + + private NioEchoServer server; + private Selector selector; + + @AfterEach + public void teardown() throws Exception { + if (selector != null) + this.selector.close(); + if (server != null) + this.server.close(); + } + + /** + * Tests that server certificate with SubjectAltName containing the valid hostname + * is accepted by a client that connects using the hostname and validates server endpoint. + */ + @ParameterizedTest + @ArgumentsSource(SslTransportLayerArgumentsProvider.class) + public void testValidEndpointIdentificationSanDns(Args args) throws Exception { + createSelector(args); + String node = "0"; + server = createEchoServer(args, SecurityProtocol.SSL); + args.sslClientConfigs.put(SslConfigs.SSL_ENDPOINT_IDENTIFICATION_ALGORITHM_CONFIG, "HTTPS"); + createSelector(args.sslClientConfigs); + InetSocketAddress addr = new InetSocketAddress("localhost", server.port()); + selector.connect(node, addr, BUFFER_SIZE, BUFFER_SIZE); + + NetworkTestUtils.checkClientConnection(selector, node, 100, 10); + server.verifyAuthenticationMetrics(1, 0); + } + + /** + * Tests that server certificate with SubjectAltName containing valid IP address + * is accepted by a client that connects using IP address and validates server endpoint. 
+ */ + @ParameterizedTest + @ArgumentsSource(SslTransportLayerArgumentsProvider.class) + public void testValidEndpointIdentificationSanIp(Args args) throws Exception { + String node = "0"; + args.serverCertStores = certBuilder(true, "server", args.useInlinePem).hostAddress(InetAddress.getByName("127.0.0.1")).build(); + args.clientCertStores = certBuilder(false, "client", args.useInlinePem).hostAddress(InetAddress.getByName("127.0.0.1")).build(); + args.sslServerConfigs = args.getTrustingConfig(args.serverCertStores, args.clientCertStores); + args.sslClientConfigs = args.getTrustingConfig(args.clientCertStores, args.serverCertStores); + server = createEchoServer(args, SecurityProtocol.SSL); + args.sslClientConfigs.put(SslConfigs.SSL_ENDPOINT_IDENTIFICATION_ALGORITHM_CONFIG, "HTTPS"); + createSelector(args.sslClientConfigs); + InetSocketAddress addr = new InetSocketAddress("127.0.0.1", server.port()); + selector.connect(node, addr, BUFFER_SIZE, BUFFER_SIZE); + + NetworkTestUtils.checkClientConnection(selector, node, 100, 10); + } + + /** + * Tests that server certificate with CN containing valid hostname + * is accepted by a client that connects using hostname and validates server endpoint. + */ + @ParameterizedTest + @ArgumentsSource(SslTransportLayerArgumentsProvider.class) + public void testValidEndpointIdentificationCN(Args args) throws Exception { + args.serverCertStores = certBuilder(true, "localhost", args.useInlinePem).build(); + args.clientCertStores = certBuilder(false, "localhost", args.useInlinePem).build(); + args.sslServerConfigs = args.getTrustingConfig(args.serverCertStores, args.clientCertStores); + args.sslClientConfigs = args.getTrustingConfig(args.clientCertStores, args.serverCertStores); + args.sslClientConfigs.put(SslConfigs.SSL_ENDPOINT_IDENTIFICATION_ALGORITHM_CONFIG, "HTTPS"); + verifySslConfigs(args); + } + + /** + * Tests that hostname verification is performed on the host name or address + * specified by the client without using reverse DNS lookup. Certificate is + * created with hostname, client connection uses IP address. Endpoint validation + * must fail. + */ + @ParameterizedTest + @ArgumentsSource(SslTransportLayerArgumentsProvider.class) + public void testEndpointIdentificationNoReverseLookup(Args args) throws Exception { + String node = "0"; + server = createEchoServer(args, SecurityProtocol.SSL); + args.sslClientConfigs.put(SslConfigs.SSL_ENDPOINT_IDENTIFICATION_ALGORITHM_CONFIG, "HTTPS"); + createSelector(args.sslClientConfigs); + InetSocketAddress addr = new InetSocketAddress("127.0.0.1", server.port()); + selector.connect(node, addr, BUFFER_SIZE, BUFFER_SIZE); + + NetworkTestUtils.waitForChannelClose(selector, node, ChannelState.State.AUTHENTICATION_FAILED); + } + + /** + * According to RFC 2818: + *
+     * <blockquote>Typically, the server has no external knowledge of what the client's
+     * identity ought to be and so checks (other than that the client has a
+     * certificate chain rooted in an appropriate CA) are not possible. If a
+     * server has such knowledge (typically from some source external to
+     * HTTP or TLS) it SHOULD check the identity as described above.</blockquote>
            + * + * However, Java SSL engine does not perform any endpoint validation for client IP address. + * Hence it is safe to avoid reverse DNS lookup while creating the SSL engine. This test checks + * that client validation does not fail even if the client certificate has an invalid hostname. + * This test is to ensure that if client endpoint validation is added to Java in future, we can detect + * and update Kafka SSL code to enable validation on the server-side and provide hostname if required. + */ + @ParameterizedTest + @ArgumentsSource(SslTransportLayerArgumentsProvider.class) + public void testClientEndpointNotValidated(Args args) throws Exception { + String node = "0"; + + // Create client certificate with an invalid hostname + args.clientCertStores = certBuilder(false, "non-existent.com", args.useInlinePem).build(); + args.serverCertStores = certBuilder(true, "localhost", args.useInlinePem).build(); + args.sslServerConfigs = args.getTrustingConfig(args.serverCertStores, args.clientCertStores); + args.sslClientConfigs = args.getTrustingConfig(args.clientCertStores, args.serverCertStores); + + // Create a server with endpoint validation enabled on the server SSL engine + SslChannelBuilder serverChannelBuilder = new TestSslChannelBuilder(Mode.SERVER) { + @Override + protected TestSslTransportLayer newTransportLayer(String id, SelectionKey key, SSLEngine sslEngine) throws IOException { + SSLParameters sslParams = sslEngine.getSSLParameters(); + sslParams.setEndpointIdentificationAlgorithm("HTTPS"); + sslEngine.setSSLParameters(sslParams); + return super.newTransportLayer(id, key, sslEngine); + } + }; + serverChannelBuilder.configure(args.sslServerConfigs); + server = new NioEchoServer(ListenerName.forSecurityProtocol(SecurityProtocol.SSL), SecurityProtocol.SSL, + new TestSecurityConfig(args.sslServerConfigs), "localhost", serverChannelBuilder, null, time); + server.start(); + + createSelector(args.sslClientConfigs); + InetSocketAddress addr = new InetSocketAddress("localhost", server.port()); + selector.connect(node, addr, BUFFER_SIZE, BUFFER_SIZE); + + NetworkTestUtils.checkClientConnection(selector, node, 100, 10); + } + + /** + * Tests that server certificate with invalid host name is not accepted by + * a client that validates server endpoint. Server certificate uses + * wrong hostname as common name to trigger endpoint validation failure. 
+ */ + @ParameterizedTest + @ArgumentsSource(SslTransportLayerArgumentsProvider.class) + public void testInvalidEndpointIdentification(Args args) throws Exception { + args.serverCertStores = certBuilder(true, "server", args.useInlinePem).addHostName("notahost").build(); + args.clientCertStores = certBuilder(false, "client", args.useInlinePem).addHostName("localhost").build(); + args.sslServerConfigs = args.getTrustingConfig(args.serverCertStores, args.clientCertStores); + args.sslClientConfigs = args.getTrustingConfig(args.clientCertStores, args.serverCertStores); + args.sslClientConfigs.put(SslConfigs.SSL_ENDPOINT_IDENTIFICATION_ALGORITHM_CONFIG, "HTTPS"); + verifySslConfigsWithHandshakeFailure(args); + } + + /** + * Tests that server certificate with invalid host name is accepted by + * a client that has disabled endpoint validation + */ + @ParameterizedTest + @ArgumentsSource(SslTransportLayerArgumentsProvider.class) + public void testEndpointIdentificationDisabled(Args args) throws Exception { + args.serverCertStores = certBuilder(true, "server", args.useInlinePem).addHostName("notahost").build(); + args.clientCertStores = certBuilder(false, "client", args.useInlinePem).addHostName("localhost").build(); + args.sslServerConfigs = args.getTrustingConfig(args.serverCertStores, args.clientCertStores); + args.sslClientConfigs = args.getTrustingConfig(args.clientCertStores, args.serverCertStores); + + server = createEchoServer(args, SecurityProtocol.SSL); + InetSocketAddress addr = new InetSocketAddress("localhost", server.port()); + + // Disable endpoint validation, connection should succeed + String node = "1"; + args.sslClientConfigs.put(SslConfigs.SSL_ENDPOINT_IDENTIFICATION_ALGORITHM_CONFIG, ""); + createSelector(args.sslClientConfigs); + selector.connect(node, addr, BUFFER_SIZE, BUFFER_SIZE); + NetworkTestUtils.checkClientConnection(selector, node, 100, 10); + + // Disable endpoint validation using null value, connection should succeed + String node2 = "2"; + args.sslClientConfigs.put(SslConfigs.SSL_ENDPOINT_IDENTIFICATION_ALGORITHM_CONFIG, null); + createSelector(args.sslClientConfigs); + selector.connect(node2, addr, BUFFER_SIZE, BUFFER_SIZE); + NetworkTestUtils.checkClientConnection(selector, node2, 100, 10); + + // Connection should fail with endpoint validation enabled + String node3 = "3"; + args.sslClientConfigs.put(SslConfigs.SSL_ENDPOINT_IDENTIFICATION_ALGORITHM_CONFIG, "HTTPS"); + createSelector(args.sslClientConfigs); + selector.connect(node3, addr, BUFFER_SIZE, BUFFER_SIZE); + NetworkTestUtils.waitForChannelClose(selector, node3, ChannelState.State.AUTHENTICATION_FAILED); + selector.close(); + } + + /** + * Tests that server accepts connections from clients with a trusted certificate + * when client authentication is required. + */ + @ParameterizedTest + @ArgumentsSource(SslTransportLayerArgumentsProvider.class) + public void testClientAuthenticationRequiredValidProvided(Args args) throws Exception { + args.sslServerConfigs.put(BrokerSecurityConfigs.SSL_CLIENT_AUTH_CONFIG, "required"); + verifySslConfigs(args); + } + + /** + * Tests that disabling client authentication as a listener override has the desired effect. 
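+     * <p>Illustrative broker configuration sketch for such an override (mirroring the {@code client}
+     * listener name used below):
+     * <pre>{@code
+     * ssl.client.auth=required
+     * listener.name.client.ssl.client.auth=none
+     * }</pre>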
+ */ + @ParameterizedTest + @ArgumentsSource(SslTransportLayerArgumentsProvider.class) + public void testListenerConfigOverride(Args args) throws Exception { + String node = "0"; + ListenerName clientListenerName = new ListenerName("client"); + args.sslServerConfigs.put(BrokerSecurityConfigs.SSL_CLIENT_AUTH_CONFIG, "required"); + args.sslServerConfigs.put(clientListenerName.configPrefix() + BrokerSecurityConfigs.SSL_CLIENT_AUTH_CONFIG, "none"); + + // `client` listener is not configured at this point, so client auth should be required + server = createEchoServer(args, SecurityProtocol.SSL); + InetSocketAddress addr = new InetSocketAddress("localhost", server.port()); + + // Connect with client auth should work fine + createSelector(args.sslClientConfigs); + selector.connect(node, addr, BUFFER_SIZE, BUFFER_SIZE); + NetworkTestUtils.checkClientConnection(selector, node, 100, 10); + selector.close(); + + // Remove client auth, so connection should fail + CertStores.KEYSTORE_PROPS.forEach(args.sslClientConfigs::remove); + createSelector(args.sslClientConfigs); + selector.connect(node, addr, BUFFER_SIZE, BUFFER_SIZE); + NetworkTestUtils.waitForChannelClose(selector, node, ChannelState.State.AUTHENTICATION_FAILED); + selector.close(); + server.close(); + + // Listener-specific config should be used and client auth should be disabled + server = createEchoServer(args, clientListenerName, SecurityProtocol.SSL); + addr = new InetSocketAddress("localhost", server.port()); + + // Connect without client auth should work fine now + createSelector(args.sslClientConfigs); + selector.connect(node, addr, BUFFER_SIZE, BUFFER_SIZE); + NetworkTestUtils.checkClientConnection(selector, node, 100, 10); + } + + /** + * Tests that server does not accept connections from clients with an untrusted certificate + * when client authentication is required. + */ + @ParameterizedTest + @ArgumentsSource(SslTransportLayerArgumentsProvider.class) + public void testClientAuthenticationRequiredUntrustedProvided(Args args) throws Exception { + args.sslServerConfigs = args.serverCertStores.getUntrustingConfig(); + args.sslServerConfigs.putAll(args.sslConfigOverrides); + args.sslServerConfigs.put(BrokerSecurityConfigs.SSL_CLIENT_AUTH_CONFIG, "required"); + verifySslConfigsWithHandshakeFailure(args); + } + + /** + * Tests that server does not accept connections from clients which don't + * provide a certificate when client authentication is required. 
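+     * <p>For reference, a sketch of the client key store settings that would satisfy a broker
+     * configured with {@code ssl.client.auth=required} (paths and passwords are placeholders):
+     * <pre>{@code
+     * clientConfigs.put(SslConfigs.SSL_KEYSTORE_LOCATION_CONFIG, "/path/to/client.keystore.jks");
+     * clientConfigs.put(SslConfigs.SSL_KEYSTORE_PASSWORD_CONFIG, new Password("keystore-password"));
+     * clientConfigs.put(SslConfigs.SSL_KEY_PASSWORD_CONFIG, new Password("key-password"));
+     * }</pre>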
+ */ + @ParameterizedTest + @ArgumentsSource(SslTransportLayerArgumentsProvider.class) + public void testClientAuthenticationRequiredNotProvided(Args args) throws Exception { + args.sslServerConfigs.put(BrokerSecurityConfigs.SSL_CLIENT_AUTH_CONFIG, "required"); + CertStores.KEYSTORE_PROPS.forEach(args.sslClientConfigs::remove); + verifySslConfigsWithHandshakeFailure(args); + } + + /** + * Tests that server accepts connections from a client configured + * with an untrusted certificate if client authentication is disabled + */ + @ParameterizedTest + @ArgumentsSource(SslTransportLayerArgumentsProvider.class) + public void testClientAuthenticationDisabledUntrustedProvided(Args args) throws Exception { + args.sslServerConfigs = args.serverCertStores.getUntrustingConfig(); + args.sslServerConfigs.putAll(args.sslConfigOverrides); + args.sslServerConfigs.put(BrokerSecurityConfigs.SSL_CLIENT_AUTH_CONFIG, "none"); + verifySslConfigs(args); + } + + /** + * Tests that server accepts connections from a client that does not provide + * a certificate if client authentication is disabled + */ + @ParameterizedTest + @ArgumentsSource(SslTransportLayerArgumentsProvider.class) + public void testClientAuthenticationDisabledNotProvided(Args args) throws Exception { + args.sslServerConfigs.put(BrokerSecurityConfigs.SSL_CLIENT_AUTH_CONFIG, "none"); + + CertStores.KEYSTORE_PROPS.forEach(args.sslClientConfigs::remove); + verifySslConfigs(args); + } + + /** + * Tests that server accepts connections from a client configured + * with a valid certificate if client authentication is requested + */ + @ParameterizedTest + @ArgumentsSource(SslTransportLayerArgumentsProvider.class) + public void testClientAuthenticationRequestedValidProvided(Args args) throws Exception { + args.sslServerConfigs.put(BrokerSecurityConfigs.SSL_CLIENT_AUTH_CONFIG, "requested"); + verifySslConfigs(args); + } + + /** + * Tests that server accepts connections from a client that does not provide + * a certificate if client authentication is requested but not required + */ + @ParameterizedTest + @ArgumentsSource(SslTransportLayerArgumentsProvider.class) + public void testClientAuthenticationRequestedNotProvided(Args args) throws Exception { + args.sslServerConfigs.put(BrokerSecurityConfigs.SSL_CLIENT_AUTH_CONFIG, "requested"); + + CertStores.KEYSTORE_PROPS.forEach(args.sslClientConfigs::remove); + verifySslConfigs(args); + } + + /** + * Tests key-pair created using DSA. + */ + @ParameterizedTest + @ArgumentsSource(SslTransportLayerArgumentsProvider.class) + public void testDsaKeyPair(Args args) throws Exception { + // DSA algorithms are not supported for TLSv1.3. + assumeTrue(args.tlsProtocol.equals("TLSv1.2")); + args.serverCertStores = certBuilder(true, "server", args.useInlinePem).keyAlgorithm("DSA").build(); + args.clientCertStores = certBuilder(false, "client", args.useInlinePem).keyAlgorithm("DSA").build(); + args.sslServerConfigs = args.getTrustingConfig(args.serverCertStores, args.clientCertStores); + args.sslClientConfigs = args.getTrustingConfig(args.clientCertStores, args.serverCertStores); + args.sslServerConfigs.put(BrokerSecurityConfigs.SSL_CLIENT_AUTH_CONFIG, "required"); + verifySslConfigs(args); + } + + /** + * Tests key-pair created using EC. 
+ */ + @ParameterizedTest + @ArgumentsSource(SslTransportLayerArgumentsProvider.class) + public void testECKeyPair(Args args) throws Exception { + args.serverCertStores = certBuilder(true, "server", args.useInlinePem).keyAlgorithm("EC").build(); + args.clientCertStores = certBuilder(false, "client", args.useInlinePem).keyAlgorithm("EC").build(); + args.sslServerConfigs = args.getTrustingConfig(args.serverCertStores, args.clientCertStores); + args.sslClientConfigs = args.getTrustingConfig(args.clientCertStores, args.serverCertStores); + args.sslServerConfigs.put(BrokerSecurityConfigs.SSL_CLIENT_AUTH_CONFIG, "required"); + verifySslConfigs(args); + } + + /** + * Tests PEM key store and trust store files which don't have store passwords. + */ + @ParameterizedTest + @ArgumentsSource(SslTransportLayerArgumentsProvider.class) + public void testPemFiles(Args args) throws Exception { + TestSslUtils.convertToPem(args.sslServerConfigs, true, true); + TestSslUtils.convertToPem(args.sslClientConfigs, true, true); + args.sslServerConfigs.put(BrokerSecurityConfigs.SSL_CLIENT_AUTH_CONFIG, "required"); + verifySslConfigs(args); + } + + /** + * Test with PEM key store files without key password for client key store. We don't allow this + * with PEM files since unprotected private key on disk is not safe. We do allow with inline + * PEM config since key config can be encrypted or externalized similar to other password configs. + */ + @ParameterizedTest + @ArgumentsSource(SslTransportLayerArgumentsProvider.class) + public void testPemFilesWithoutClientKeyPassword(Args args) throws Exception { + boolean useInlinePem = args.useInlinePem; + TestSslUtils.convertToPem(args.sslServerConfigs, !useInlinePem, true); + TestSslUtils.convertToPem(args.sslClientConfigs, !useInlinePem, false); + args.sslServerConfigs.put(BrokerSecurityConfigs.SSL_CLIENT_AUTH_CONFIG, "required"); + server = createEchoServer(args, SecurityProtocol.SSL); + if (useInlinePem) + verifySslConfigs(args); + else + assertThrows(KafkaException.class, () -> createSelector(args.sslClientConfigs)); + } + + /** + * Test with PEM key store files without key password for server key store.We don't allow this + * with PEM files since unprotected private key on disk is not safe. We do allow with inline + * PEM config since key config can be encrypted or externalized similar to other password configs. 
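+     * <p>A sketch of the inline PEM form (values are placeholders): key material is supplied directly
+     * as password-typed configs rather than as file locations, so it can be encrypted or externalized.
+     * <pre>{@code
+     * configs.put(SslConfigs.SSL_KEYSTORE_KEY_CONFIG, new Password("-----BEGIN ENCRYPTED PRIVATE KEY-----..."));
+     * configs.put(SslConfigs.SSL_KEYSTORE_CERTIFICATE_CHAIN_CONFIG, new Password("-----BEGIN CERTIFICATE-----..."));
+     * configs.put(SslConfigs.SSL_TRUSTSTORE_CERTIFICATES_CONFIG, new Password("-----BEGIN CERTIFICATE-----..."));
+     * }</pre>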
+ */ + @ParameterizedTest + @ArgumentsSource(SslTransportLayerArgumentsProvider.class) + public void testPemFilesWithoutServerKeyPassword(Args args) throws Exception { + TestSslUtils.convertToPem(args.sslServerConfigs, !args.useInlinePem, false); + TestSslUtils.convertToPem(args.sslClientConfigs, !args.useInlinePem, true); + + if (args.useInlinePem) + verifySslConfigs(args); + else + assertThrows(KafkaException.class, () -> createEchoServer(args, SecurityProtocol.SSL)); + } + + /** + * Tests that an invalid SecureRandom implementation cannot be configured + */ + @ParameterizedTest + @ArgumentsSource(SslTransportLayerArgumentsProvider.class) + public void testInvalidSecureRandomImplementation(Args args) { + try (SslChannelBuilder channelBuilder = newClientChannelBuilder()) { + args.sslClientConfigs.put(SslConfigs.SSL_SECURE_RANDOM_IMPLEMENTATION_CONFIG, "invalid"); + assertThrows(KafkaException.class, () -> channelBuilder.configure(args.sslClientConfigs)); + } + } + + /** + * Tests that channels cannot be created if truststore cannot be loaded + */ + @ParameterizedTest + @ArgumentsSource(SslTransportLayerArgumentsProvider.class) + public void testInvalidTruststorePassword(Args args) { + try (SslChannelBuilder channelBuilder = newClientChannelBuilder()) { + args.sslClientConfigs.put(SslConfigs.SSL_TRUSTSTORE_PASSWORD_CONFIG, "invalid"); + assertThrows(KafkaException.class, () -> channelBuilder.configure(args.sslClientConfigs)); + } + } + + /** + * Tests that channels cannot be created if keystore cannot be loaded + */ + @ParameterizedTest + @ArgumentsSource(SslTransportLayerArgumentsProvider.class) + public void testInvalidKeystorePassword(Args args) { + try (SslChannelBuilder channelBuilder = newClientChannelBuilder()) { + args.sslClientConfigs.put(SslConfigs.SSL_KEYSTORE_PASSWORD_CONFIG, "invalid"); + assertThrows(KafkaException.class, () -> channelBuilder.configure(args.sslClientConfigs)); + } + } + + /** + * Tests that client connections can be created to a server + * if null truststore password is used + */ + @ParameterizedTest + @ArgumentsSource(SslTransportLayerArgumentsProvider.class) + public void testNullTruststorePassword(Args args) throws Exception { + args.sslClientConfigs.remove(SslConfigs.SSL_TRUSTSTORE_PASSWORD_CONFIG); + args.sslServerConfigs.remove(SslConfigs.SSL_TRUSTSTORE_PASSWORD_CONFIG); + + verifySslConfigs(args); + } + + /** + * Tests that client connections cannot be created to a server + * if key password is invalid + */ + @ParameterizedTest + @ArgumentsSource(SslTransportLayerArgumentsProvider.class) + public void testInvalidKeyPassword(Args args) throws Exception { + args.sslServerConfigs.put(SslConfigs.SSL_KEY_PASSWORD_CONFIG, new Password("invalid")); + if (args.useInlinePem) { + // We fail fast for PEM + assertThrows(InvalidConfigurationException.class, () -> createEchoServer(args, SecurityProtocol.SSL)); + return; + } + verifySslConfigsWithHandshakeFailure(args); + } + + /** + * Tests that connection succeeds with the default TLS version. 
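+     * <p>A minimal sketch (beyond what this test asserts): when {@code ssl.protocol} is not set
+     * explicitly, {@link SslConfigs#DEFAULT_SSL_PROTOCOL} applies; the accepted versions can also be
+     * restricted explicitly, for example:
+     * <pre>{@code
+     * configs.put(SslConfigs.SSL_ENABLED_PROTOCOLS_CONFIG, Arrays.asList("TLSv1.2", "TLSv1.3"));
+     * }</pre>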
+ */ + @ParameterizedTest + @ArgumentsSource(SslTransportLayerArgumentsProvider.class) + public void testTlsDefaults(Args args) throws Exception { + args.sslServerConfigs = args.serverCertStores.getTrustingConfig(args.clientCertStores); + args.sslClientConfigs = args.clientCertStores.getTrustingConfig(args.serverCertStores); + + assertEquals(SslConfigs.DEFAULT_SSL_PROTOCOL, args.sslServerConfigs.get(SslConfigs.SSL_PROTOCOL_CONFIG)); + assertEquals(SslConfigs.DEFAULT_SSL_PROTOCOL, args.sslClientConfigs.get(SslConfigs.SSL_PROTOCOL_CONFIG)); + + server = createEchoServer(args, SecurityProtocol.SSL); + createSelector(args.sslClientConfigs); + + InetSocketAddress addr = new InetSocketAddress("localhost", server.port()); + selector.connect("0", addr, BUFFER_SIZE, BUFFER_SIZE); + + NetworkTestUtils.checkClientConnection(selector, "0", 10, 100); + server.verifyAuthenticationMetrics(1, 0); + selector.close(); + } + + /** Checks connection failed using the specified {@code tlsVersion}. */ + private void checkAuthenticationFailed(Args args, String node, String tlsVersion) throws IOException { + args.sslClientConfigs.put(SslConfigs.SSL_ENABLED_PROTOCOLS_CONFIG, Arrays.asList(tlsVersion)); + createSelector(args.sslClientConfigs); + InetSocketAddress addr = new InetSocketAddress("localhost", server.port()); + selector.connect(node, addr, BUFFER_SIZE, BUFFER_SIZE); + + NetworkTestUtils.waitForChannelClose(selector, node, ChannelState.State.AUTHENTICATION_FAILED); + + selector.close(); + } + + /** + * Tests that connections cannot be made with unsupported TLS cipher suites + */ + @ParameterizedTest + @ArgumentsSource(SslTransportLayerArgumentsProvider.class) + public void testUnsupportedCiphers(Args args) throws Exception { + SSLContext context = SSLContext.getInstance(args.tlsProtocol); + context.init(null, null, null); + String[] cipherSuites = context.getDefaultSSLParameters().getCipherSuites(); + args.sslServerConfigs.put(SslConfigs.SSL_CIPHER_SUITES_CONFIG, Arrays.asList(cipherSuites[0])); + server = createEchoServer(args, SecurityProtocol.SSL); + + args.sslClientConfigs.put(SslConfigs.SSL_CIPHER_SUITES_CONFIG, Arrays.asList(cipherSuites[1])); + createSelector(args.sslClientConfigs); + + checkAuthenticationFailed(args, "1", args.tlsProtocol); + server.verifyAuthenticationMetrics(0, 1); + } + + @ParameterizedTest + @ArgumentsSource(SslTransportLayerArgumentsProvider.class) + public void testServerRequestMetrics(Args args) throws Exception { + String node = "0"; + server = createEchoServer(args, SecurityProtocol.SSL); + createSelector(args.sslClientConfigs, 16384, 16384, 16384); + InetSocketAddress addr = new InetSocketAddress("localhost", server.port()); + selector.connect(node, addr, 102400, 102400); + NetworkTestUtils.waitForChannelReady(selector, node); + int messageSize = 1024 * 1024; + String message = TestUtils.randomString(messageSize); + selector.send(new NetworkSend(node, ByteBufferSend.sizePrefixed(ByteBuffer.wrap(message.getBytes())))); + while (selector.completedReceives().isEmpty()) { + selector.poll(100L); + } + int totalBytes = messageSize + 4; // including 4-byte size + server.waitForMetric("incoming-byte", totalBytes); + server.waitForMetric("outgoing-byte", totalBytes); + server.waitForMetric("request", 1); + server.waitForMetric("response", 1); + } + + /** + * selector.poll() should be able to fetch more data than netReadBuffer from the socket. 
+ */ + @ParameterizedTest + @ArgumentsSource(SslTransportLayerArgumentsProvider.class) + public void testSelectorPollReadSize(Args args) throws Exception { + String node = "0"; + server = createEchoServer(args, SecurityProtocol.SSL); + createSelector(args.sslClientConfigs, 16384, 16384, 16384); + InetSocketAddress addr = new InetSocketAddress("localhost", server.port()); + selector.connect(node, addr, 102400, 102400); + NetworkTestUtils.checkClientConnection(selector, node, 81920, 1); + + // Send a message of 80K. This is 5X as large as the socket buffer. It should take at least three selector.poll() + // to read this message from socket if the SslTransportLayer.read() does not read all data from socket buffer. + String message = TestUtils.randomString(81920); + selector.send(new NetworkSend(node, ByteBufferSend.sizePrefixed(ByteBuffer.wrap(message.getBytes())))); + + // Send the message to echo server + TestUtils.waitForCondition(() -> { + try { + selector.poll(100L); + } catch (IOException e) { + return false; + } + return selector.completedSends().size() > 0; + }, "Timed out waiting for message to be sent"); + + // Wait for echo server to send the message back + TestUtils.waitForCondition(() -> + server.numSent() >= 2, "Timed out waiting for echo server to send message"); + + // Read the message from socket with only one poll() + selector.poll(1000L); + + Collection receiveList = selector.completedReceives(); + assertEquals(1, receiveList.size()); + assertEquals(message, new String(Utils.toArray(receiveList.iterator().next().payload()))); + } + + /** + * Tests handling of BUFFER_UNDERFLOW during unwrap when network read buffer is smaller than SSL session packet buffer size. + */ + @ParameterizedTest + @ArgumentsSource(SslTransportLayerArgumentsProvider.class) + public void testNetReadBufferResize(Args args) throws Exception { + String node = "0"; + server = createEchoServer(args, SecurityProtocol.SSL); + createSelector(args.sslClientConfigs, 10, null, null); + InetSocketAddress addr = new InetSocketAddress("localhost", server.port()); + selector.connect(node, addr, BUFFER_SIZE, BUFFER_SIZE); + + NetworkTestUtils.checkClientConnection(selector, node, 64000, 10); + } + + /** + * Tests handling of BUFFER_OVERFLOW during wrap when network write buffer is smaller than SSL session packet buffer size. + */ + @ParameterizedTest + @ArgumentsSource(SslTransportLayerArgumentsProvider.class) + public void testNetWriteBufferResize(Args args) throws Exception { + String node = "0"; + server = createEchoServer(args, SecurityProtocol.SSL); + createSelector(args.sslClientConfigs, null, 10, null); + InetSocketAddress addr = new InetSocketAddress("localhost", server.port()); + selector.connect(node, addr, BUFFER_SIZE, BUFFER_SIZE); + + NetworkTestUtils.checkClientConnection(selector, node, 64000, 10); + } + + /** + * Tests handling of BUFFER_OVERFLOW during unwrap when application read buffer is smaller than SSL session application buffer size. 
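+     * <p>A minimal sketch of the standard JDK pattern exercised here (variable names are assumptions,
+     * this is not the code under test): on {@code BUFFER_OVERFLOW} the destination buffer is grown to
+     * the session's application buffer size before the unwrap is retried.
+     * <pre>{@code
+     * if (result.getStatus() == SSLEngineResult.Status.BUFFER_OVERFLOW) {
+     *     ByteBuffer larger = ByteBuffer.allocate(sslEngine.getSession().getApplicationBufferSize());
+     *     appReadBuffer.flip();
+     *     larger.put(appReadBuffer);
+     *     appReadBuffer = larger;
+     * }
+     * }</pre>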
+ */ + @ParameterizedTest + @ArgumentsSource(SslTransportLayerArgumentsProvider.class) + public void testApplicationBufferResize(Args args) throws Exception { + String node = "0"; + server = createEchoServer(args, SecurityProtocol.SSL); + createSelector(args.sslClientConfigs, null, null, 10); + InetSocketAddress addr = new InetSocketAddress("localhost", server.port()); + selector.connect(node, addr, BUFFER_SIZE, BUFFER_SIZE); + + NetworkTestUtils.checkClientConnection(selector, node, 64000, 10); + } + + /** + * Tests that time spent on the network thread is accumulated on each channel + */ + @ParameterizedTest + @ArgumentsSource(SslTransportLayerArgumentsProvider.class) + public void testNetworkThreadTimeRecorded(Args args) throws Exception { + LogContext logContext = new LogContext(); + ChannelBuilder channelBuilder = new SslChannelBuilder(Mode.CLIENT, null, false, logContext); + channelBuilder.configure(args.sslClientConfigs); + try (Selector selector = new Selector(NetworkReceive.UNLIMITED, Selector.NO_IDLE_TIMEOUT_MS, new Metrics(), Time.SYSTEM, + "MetricGroup", new HashMap<>(), false, true, channelBuilder, MemoryPool.NONE, logContext)) { + + String node = "0"; + server = createEchoServer(args, SecurityProtocol.SSL); + InetSocketAddress addr = new InetSocketAddress("localhost", server.port()); + selector.connect(node, addr, BUFFER_SIZE, BUFFER_SIZE); + + String message = TestUtils.randomString(1024 * 1024); + NetworkTestUtils.waitForChannelReady(selector, node); + final KafkaChannel channel = selector.channel(node); + assertTrue(channel.getAndResetNetworkThreadTimeNanos() > 0, "SSL handshake time not recorded"); + assertEquals(0, channel.getAndResetNetworkThreadTimeNanos(), "Time not reset"); + + selector.mute(node); + selector.send(new NetworkSend(node, ByteBufferSend.sizePrefixed(ByteBuffer.wrap(message.getBytes())))); + while (selector.completedSends().isEmpty()) { + selector.poll(100L); + } + long sendTimeNanos = channel.getAndResetNetworkThreadTimeNanos(); + assertTrue(sendTimeNanos > 0, "Send time not recorded: " + sendTimeNanos); + assertEquals(0, channel.getAndResetNetworkThreadTimeNanos(), "Time not reset"); + assertFalse(channel.hasBytesBuffered(), "Unexpected bytes buffered"); + assertEquals(0, selector.completedReceives().size()); + + selector.unmute(node); + // Wait for echo server to send the message back + TestUtils.waitForCondition(() -> { + try { + selector.poll(100L); + } catch (IOException e) { + return false; + } + return !selector.completedReceives().isEmpty(); + }, "Timed out waiting for a message to receive from echo server"); + + long receiveTimeNanos = channel.getAndResetNetworkThreadTimeNanos(); + assertTrue(receiveTimeNanos > 0, "Receive time not recorded: " + receiveTimeNanos); + } + } + + /** + * Tests that IOExceptions from read during SSL handshake are not treated as authentication failures. + */ + @ParameterizedTest + @ArgumentsSource(SslTransportLayerArgumentsProvider.class) + public void testIOExceptionsDuringHandshakeRead(Args args) throws Exception { + server = createEchoServer(args, SecurityProtocol.SSL); + testIOExceptionsDuringHandshake(args, FailureAction.THROW_IO_EXCEPTION, FailureAction.NO_OP); + } + + /** + * Tests that IOExceptions from write during SSL handshake are not treated as authentication failures. 
+ */ + @ParameterizedTest + @ArgumentsSource(SslTransportLayerArgumentsProvider.class) + public void testIOExceptionsDuringHandshakeWrite(Args args) throws Exception { + server = createEchoServer(args, SecurityProtocol.SSL); + testIOExceptionsDuringHandshake(args, FailureAction.NO_OP, FailureAction.THROW_IO_EXCEPTION); + } + + /** + * Tests that if the remote end closes connection ungracefully during SSL handshake while reading data, + * the disconnection is not treated as an authentication failure. + */ + @ParameterizedTest + @ArgumentsSource(SslTransportLayerArgumentsProvider.class) + public void testUngracefulRemoteCloseDuringHandshakeRead(Args args) throws Exception { + server = createEchoServer(args, SecurityProtocol.SSL); + testIOExceptionsDuringHandshake(args, server::closeSocketChannels, FailureAction.NO_OP); + } + + /** + * Tests that if the remote end closes connection ungracefully during SSL handshake while writing data, + * the disconnection is not treated as an authentication failure. + */ + @ParameterizedTest + @ArgumentsSource(SslTransportLayerArgumentsProvider.class) + public void testUngracefulRemoteCloseDuringHandshakeWrite(Args args) throws Exception { + server = createEchoServer(args, SecurityProtocol.SSL); + testIOExceptionsDuringHandshake(args, FailureAction.NO_OP, server::closeSocketChannels); + } + + /** + * Tests that if the remote end closes the connection during SSL handshake while reading data, + * the disconnection is not treated as an authentication failure. + */ + @ParameterizedTest + @ArgumentsSource(SslTransportLayerArgumentsProvider.class) + public void testGracefulRemoteCloseDuringHandshakeRead(Args args) throws Exception { + server = createEchoServer(args, SecurityProtocol.SSL); + testIOExceptionsDuringHandshake(args, FailureAction.NO_OP, server::closeKafkaChannels); + } + + /** + * Tests that if the remote end closes the connection during SSL handshake while writing data, + * the disconnection is not treated as an authentication failure. 
+ */ + @ParameterizedTest + @ArgumentsSource(SslTransportLayerArgumentsProvider.class) + public void testGracefulRemoteCloseDuringHandshakeWrite(Args args) throws Exception { + server = createEchoServer(args, SecurityProtocol.SSL); + testIOExceptionsDuringHandshake(args, server::closeKafkaChannels, FailureAction.NO_OP); + } + + private void testIOExceptionsDuringHandshake(Args args, + FailureAction readFailureAction, + FailureAction flushFailureAction) throws Exception { + TestSslChannelBuilder channelBuilder = new TestSslChannelBuilder(Mode.CLIENT); + boolean done = false; + for (int i = 1; i <= 100; i++) { + String node = String.valueOf(i); + + channelBuilder.readFailureAction = readFailureAction; + channelBuilder.flushFailureAction = flushFailureAction; + channelBuilder.failureIndex = i; + channelBuilder.configure(args.sslClientConfigs); + this.selector = new Selector(5000, new Metrics(), time, "MetricGroup", channelBuilder, new LogContext()); + + InetSocketAddress addr = new InetSocketAddress("localhost", server.port()); + selector.connect(node, addr, BUFFER_SIZE, BUFFER_SIZE); + for (int j = 0; j < 30; j++) { + selector.poll(1000L); + KafkaChannel channel = selector.channel(node); + if (channel != null && channel.ready()) { + done = true; + break; + } + if (selector.disconnected().containsKey(node)) { + ChannelState.State state = selector.disconnected().get(node).state(); + assertTrue(state == ChannelState.State.AUTHENTICATE || state == ChannelState.State.READY, + "Unexpected channel state " + state); + break; + } + } + KafkaChannel channel = selector.channel(node); + if (channel != null) + assertTrue(channel.ready(), "Channel not ready or disconnected:" + channel.state().state()); + selector.close(); + } + assertTrue(done, "Too many invocations of read/write during SslTransportLayer.handshake()"); + } + + /** + * Tests that handshake failures are propagated only after writes complete, even when + * there are delays in writes to ensure that clients see an authentication exception + * rather than a connection failure. 
+ */ + @ParameterizedTest + @ArgumentsSource(SslTransportLayerArgumentsProvider.class) + public void testPeerNotifiedOfHandshakeFailure(Args args) throws Exception { + args.sslServerConfigs = args.serverCertStores.getUntrustingConfig(); + args.sslServerConfigs.putAll(args.sslConfigOverrides); + args.sslServerConfigs.put(BrokerSecurityConfigs.SSL_CLIENT_AUTH_CONFIG, "required"); + + // Test without delay and a couple of delay counts to ensure delay applies to handshake failure + for (int i = 0; i < 3; i++) { + String node = "0"; + TestSslChannelBuilder serverChannelBuilder = new TestSslChannelBuilder(Mode.SERVER); + serverChannelBuilder.configure(args.sslServerConfigs); + serverChannelBuilder.flushDelayCount = i; + server = new NioEchoServer(ListenerName.forSecurityProtocol(SecurityProtocol.SSL), + SecurityProtocol.SSL, new TestSecurityConfig(args.sslServerConfigs), + "localhost", serverChannelBuilder, null, time); + server.start(); + createSelector(args.sslClientConfigs); + InetSocketAddress addr = new InetSocketAddress("localhost", server.port()); + selector.connect(node, addr, BUFFER_SIZE, BUFFER_SIZE); + + NetworkTestUtils.waitForChannelClose(selector, node, ChannelState.State.AUTHENTICATION_FAILED); + server.close(); + selector.close(); + serverChannelBuilder.close(); + } + } + + @ParameterizedTest + @ArgumentsSource(SslTransportLayerArgumentsProvider.class) + public void testCloseSsl(Args args) throws Exception { + testClose(args, SecurityProtocol.SSL, newClientChannelBuilder()); + } + + @ParameterizedTest + @ArgumentsSource(SslTransportLayerArgumentsProvider.class) + public void testClosePlaintext(Args args) throws Exception { + testClose(args, SecurityProtocol.PLAINTEXT, new PlaintextChannelBuilder(null)); + } + + private SslChannelBuilder newClientChannelBuilder() { + return new SslChannelBuilder(Mode.CLIENT, null, false, new LogContext()); + } + + private void testClose(Args args, SecurityProtocol securityProtocol, ChannelBuilder clientChannelBuilder) throws Exception { + String node = "0"; + server = createEchoServer(args, securityProtocol); + clientChannelBuilder.configure(args.sslClientConfigs); + this.selector = new Selector(5000, new Metrics(), time, "MetricGroup", clientChannelBuilder, new LogContext()); + InetSocketAddress addr = new InetSocketAddress("localhost", server.port()); + selector.connect(node, addr, BUFFER_SIZE, BUFFER_SIZE); + + NetworkTestUtils.waitForChannelReady(selector, node); + // `waitForChannelReady` waits for client-side channel to be ready. This is sufficient for other tests + // operating on the client-side channel. But here, we are muting the server-side channel below, so we + // need to wait for the server-side channel to be ready as well. 
+ TestUtils.waitForCondition(() -> server.selector().channels().stream().allMatch(KafkaChannel::ready), + "Channel not ready"); + + final ByteArrayOutputStream bytesOut = new ByteArrayOutputStream(); + server.outputChannel(Channels.newChannel(bytesOut)); + server.selector().muteAll(); + byte[] message = TestUtils.randomString(100).getBytes(); + int count = 20; + final int totalSendSize = count * (message.length + 4); + for (int i = 0; i < count; i++) { + selector.send(new NetworkSend(node, ByteBufferSend.sizePrefixed(ByteBuffer.wrap(message)))); + do { + selector.poll(0L); + } while (selector.completedSends().isEmpty()); + } + server.selector().unmuteAll(); + selector.close(node); + TestUtils.waitForCondition(() -> + bytesOut.toByteArray().length == totalSendSize, 5000, "All requests sent were not processed"); + } + + /** + * Verifies that inter-broker listener with validation of truststore against keystore works + * with configs including mutual authentication and hostname verification. + */ + @ParameterizedTest + @ArgumentsSource(SslTransportLayerArgumentsProvider.class) + public void testInterBrokerSslConfigValidation(Args args) throws Exception { + SecurityProtocol securityProtocol = SecurityProtocol.SSL; + args.sslServerConfigs.put(BrokerSecurityConfigs.SSL_CLIENT_AUTH_CONFIG, "required"); + args.sslServerConfigs.put(SslConfigs.SSL_ENDPOINT_IDENTIFICATION_ALGORITHM_CONFIG, "HTTPS"); + args.sslServerConfigs.putAll(args.serverCertStores.keyStoreProps()); + args.sslServerConfigs.putAll(args.serverCertStores.trustStoreProps()); + args.sslClientConfigs.putAll(args.serverCertStores.keyStoreProps()); + args.sslClientConfigs.putAll(args.serverCertStores.trustStoreProps()); + TestSecurityConfig config = new TestSecurityConfig(args.sslServerConfigs); + ListenerName listenerName = ListenerName.forSecurityProtocol(securityProtocol); + ChannelBuilder serverChannelBuilder = ChannelBuilders.serverChannelBuilder(listenerName, + true, securityProtocol, config, null, null, time, new LogContext(), + defaultApiVersionsSupplier()); + server = new NioEchoServer(listenerName, securityProtocol, config, + "localhost", serverChannelBuilder, null, time); + server.start(); + + this.selector = createSelector(args.sslClientConfigs, null, null, null); + InetSocketAddress addr = new InetSocketAddress("localhost", server.port()); + selector.connect("0", addr, BUFFER_SIZE, BUFFER_SIZE); + NetworkTestUtils.checkClientConnection(selector, "0", 100, 10); + } + + /** + * Verifies that inter-broker listener with validation of truststore against keystore + * fails if certs from keystore are not trusted. + */ + @ParameterizedTest + @ArgumentsSource(SslTransportLayerArgumentsProvider.class) + public void testInterBrokerSslConfigValidationFailure(Args args) { + SecurityProtocol securityProtocol = SecurityProtocol.SSL; + args.sslServerConfigs.put(BrokerSecurityConfigs.SSL_CLIENT_AUTH_CONFIG, "required"); + TestSecurityConfig config = new TestSecurityConfig(args.sslServerConfigs); + ListenerName listenerName = ListenerName.forSecurityProtocol(securityProtocol); + assertThrows(KafkaException.class, () -> ChannelBuilders.serverChannelBuilder( + listenerName, true, securityProtocol, config, + null, null, time, new LogContext(), defaultApiVersionsSupplier())); + } + + /** + * Tests reconfiguration of server keystore. Verifies that existing connections continue + * to work with old keystore and new connections work with new keystore. 
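+     * <p>Operationally, a comparable update can be applied to a running broker through the Admin client;
+     * a sketch, where the listener name, broker id and keystore path are assumptions:
+     * <pre>{@code
+     * ConfigResource broker = new ConfigResource(ConfigResource.Type.BROKER, "0");
+     * AlterConfigOp setKeystore = new AlterConfigOp(
+     *     new ConfigEntry("listener.name.external.ssl.keystore.location", "/path/to/new.keystore.jks"),
+     *     AlterConfigOp.OpType.SET);
+     * admin.incrementalAlterConfigs(Collections.singletonMap(broker, Collections.singletonList(setKeystore)))
+     *     .all().get();
+     * }</pre>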
+ */ + @ParameterizedTest + @ArgumentsSource(SslTransportLayerArgumentsProvider.class) + public void testServerKeystoreDynamicUpdate(Args args) throws Exception { + SecurityProtocol securityProtocol = SecurityProtocol.SSL; + TestSecurityConfig config = new TestSecurityConfig(args.sslServerConfigs); + ListenerName listenerName = ListenerName.forSecurityProtocol(securityProtocol); + ChannelBuilder serverChannelBuilder = ChannelBuilders.serverChannelBuilder(listenerName, + false, securityProtocol, config, null, null, time, new LogContext(), + defaultApiVersionsSupplier()); + server = new NioEchoServer(listenerName, securityProtocol, config, + "localhost", serverChannelBuilder, null, time); + server.start(); + InetSocketAddress addr = new InetSocketAddress("localhost", server.port()); + + // Verify that client with matching truststore can authenticate, send and receive + String oldNode = "0"; + Selector oldClientSelector = createSelector(args.sslClientConfigs); + oldClientSelector.connect(oldNode, addr, BUFFER_SIZE, BUFFER_SIZE); + NetworkTestUtils.checkClientConnection(selector, oldNode, 100, 10); + + CertStores newServerCertStores = certBuilder(true, "server", args.useInlinePem).addHostName("localhost").build(); + Map newKeystoreConfigs = newServerCertStores.keyStoreProps(); + assertTrue(serverChannelBuilder instanceof ListenerReconfigurable, "SslChannelBuilder not reconfigurable"); + ListenerReconfigurable reconfigurableBuilder = (ListenerReconfigurable) serverChannelBuilder; + assertEquals(listenerName, reconfigurableBuilder.listenerName()); + reconfigurableBuilder.validateReconfiguration(newKeystoreConfigs); + reconfigurableBuilder.reconfigure(newKeystoreConfigs); + + // Verify that new client with old truststore fails + oldClientSelector.connect("1", addr, BUFFER_SIZE, BUFFER_SIZE); + NetworkTestUtils.waitForChannelClose(oldClientSelector, "1", ChannelState.State.AUTHENTICATION_FAILED); + + // Verify that new client with new truststore can authenticate, send and receive + args.sslClientConfigs = args.getTrustingConfig(args.clientCertStores, newServerCertStores); + Selector newClientSelector = createSelector(args.sslClientConfigs); + newClientSelector.connect("2", addr, BUFFER_SIZE, BUFFER_SIZE); + NetworkTestUtils.checkClientConnection(newClientSelector, "2", 100, 10); + + // Verify that old client continues to work + NetworkTestUtils.checkClientConnection(oldClientSelector, oldNode, 100, 10); + + CertStores invalidCertStores = certBuilder(true, "server", args.useInlinePem).addHostName("127.0.0.1").build(); + Map invalidConfigs = args.getTrustingConfig(invalidCertStores, args.clientCertStores); + verifyInvalidReconfigure(reconfigurableBuilder, invalidConfigs, "keystore with different SubjectAltName"); + + Map missingStoreConfigs = new HashMap<>(); + missingStoreConfigs.put(SslConfigs.SSL_KEYSTORE_TYPE_CONFIG, "PKCS12"); + missingStoreConfigs.put(SslConfigs.SSL_KEYSTORE_LOCATION_CONFIG, "some.keystore.path"); + missingStoreConfigs.put(SslConfigs.SSL_KEYSTORE_PASSWORD_CONFIG, new Password("some.keystore.password")); + missingStoreConfigs.put(SslConfigs.SSL_KEY_PASSWORD_CONFIG, new Password("some.key.password")); + verifyInvalidReconfigure(reconfigurableBuilder, missingStoreConfigs, "keystore not found"); + + // Verify that new connections continue to work with the server with previously configured keystore after failed reconfiguration + newClientSelector.connect("3", addr, BUFFER_SIZE, BUFFER_SIZE); + NetworkTestUtils.checkClientConnection(newClientSelector, "3", 100, 10); + } + + 
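+    /**
+     * Tests keystore reconfiguration with a new certificate that adds a SubjectAltName entry while
+     * retaining the existing one, verifies that clients trusting the new certificate can connect, and
+     * that a subsequent keystore which drops an existing SubjectAltName entry is rejected.
+     */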
@ParameterizedTest + @ArgumentsSource(SslTransportLayerArgumentsProvider.class) + public void testServerKeystoreDynamicUpdateWithNewSubjectAltName(Args args) throws Exception { + SecurityProtocol securityProtocol = SecurityProtocol.SSL; + TestSecurityConfig config = new TestSecurityConfig(args.sslServerConfigs); + ListenerName listenerName = ListenerName.forSecurityProtocol(securityProtocol); + ChannelBuilder serverChannelBuilder = ChannelBuilders.serverChannelBuilder(listenerName, + false, securityProtocol, config, null, null, time, new LogContext(), + defaultApiVersionsSupplier()); + server = new NioEchoServer(listenerName, securityProtocol, config, + "localhost", serverChannelBuilder, null, time); + server.start(); + InetSocketAddress addr = new InetSocketAddress("localhost", server.port()); + + Selector selector = createSelector(args.sslClientConfigs); + String node1 = "1"; + selector.connect(node1, addr, BUFFER_SIZE, BUFFER_SIZE); + NetworkTestUtils.checkClientConnection(selector, node1, 100, 10); + selector.close(); + + TestSslUtils.CertificateBuilder certBuilder = new TestSslUtils.CertificateBuilder().sanDnsNames("localhost", "*.example.com"); + String truststorePath = (String) args.sslClientConfigs.get(SslConfigs.SSL_TRUSTSTORE_LOCATION_CONFIG); + File truststoreFile = truststorePath != null ? new File(truststorePath) : null; + TestSslUtils.SslConfigsBuilder builder = new TestSslUtils.SslConfigsBuilder(Mode.SERVER) + .useClientCert(false) + .certAlias("server") + .cn("server") + .certBuilder(certBuilder) + .createNewTrustStore(truststoreFile) + .usePem(args.useInlinePem); + Map newConfigs = builder.build(); + Map newKeystoreConfigs = new HashMap<>(); + for (String propName : CertStores.KEYSTORE_PROPS) { + newKeystoreConfigs.put(propName, newConfigs.get(propName)); + } + ListenerReconfigurable reconfigurableBuilder = (ListenerReconfigurable) serverChannelBuilder; + reconfigurableBuilder.validateReconfiguration(newKeystoreConfigs); + reconfigurableBuilder.reconfigure(newKeystoreConfigs); + + for (String propName : CertStores.TRUSTSTORE_PROPS) { + args.sslClientConfigs.put(propName, newConfigs.get(propName)); + } + selector = createSelector(args.sslClientConfigs); + String node2 = "2"; + selector.connect(node2, addr, BUFFER_SIZE, BUFFER_SIZE); + NetworkTestUtils.checkClientConnection(selector, node2, 100, 10); + + TestSslUtils.CertificateBuilder invalidBuilder = new TestSslUtils.CertificateBuilder().sanDnsNames("localhost"); + if (!args.useInlinePem) + builder.useExistingTrustStore(truststoreFile); + Map invalidConfig = builder.certBuilder(invalidBuilder).build(); + Map invalidKeystoreConfigs = new HashMap<>(); + for (String propName : CertStores.KEYSTORE_PROPS) { + invalidKeystoreConfigs.put(propName, invalidConfig.get(propName)); + } + verifyInvalidReconfigure(reconfigurableBuilder, invalidKeystoreConfigs, "keystore without existing SubjectAltName"); + String node3 = "3"; + selector.connect(node3, addr, BUFFER_SIZE, BUFFER_SIZE); + NetworkTestUtils.checkClientConnection(selector, node3, 100, 10); + } + + /** + * Tests reconfiguration of server truststore. Verifies that existing connections continue + * to work with old truststore and new connections work with new truststore. 
+ */ + @ParameterizedTest + @ArgumentsSource(SslTransportLayerArgumentsProvider.class) + public void testServerTruststoreDynamicUpdate(Args args) throws Exception { + SecurityProtocol securityProtocol = SecurityProtocol.SSL; + args.sslServerConfigs.put(BrokerSecurityConfigs.SSL_CLIENT_AUTH_CONFIG, "required"); + TestSecurityConfig config = new TestSecurityConfig(args.sslServerConfigs); + ListenerName listenerName = ListenerName.forSecurityProtocol(securityProtocol); + ChannelBuilder serverChannelBuilder = ChannelBuilders.serverChannelBuilder(listenerName, + false, securityProtocol, config, null, null, time, new LogContext(), + defaultApiVersionsSupplier()); + server = new NioEchoServer(listenerName, securityProtocol, config, + "localhost", serverChannelBuilder, null, time); + server.start(); + InetSocketAddress addr = new InetSocketAddress("localhost", server.port()); + + // Verify that client with matching keystore can authenticate, send and receive + String oldNode = "0"; + Selector oldClientSelector = createSelector(args.sslClientConfigs); + oldClientSelector.connect(oldNode, addr, BUFFER_SIZE, BUFFER_SIZE); + NetworkTestUtils.checkClientConnection(selector, oldNode, 100, 10); + + CertStores newClientCertStores = certBuilder(true, "client", args.useInlinePem).addHostName("localhost").build(); + args.sslClientConfigs = args.getTrustingConfig(newClientCertStores, args.serverCertStores); + Map newTruststoreConfigs = newClientCertStores.trustStoreProps(); + assertTrue(serverChannelBuilder instanceof ListenerReconfigurable, "SslChannelBuilder not reconfigurable"); + ListenerReconfigurable reconfigurableBuilder = (ListenerReconfigurable) serverChannelBuilder; + assertEquals(listenerName, reconfigurableBuilder.listenerName()); + reconfigurableBuilder.validateReconfiguration(newTruststoreConfigs); + reconfigurableBuilder.reconfigure(newTruststoreConfigs); + + // Verify that new client with old truststore fails + oldClientSelector.connect("1", addr, BUFFER_SIZE, BUFFER_SIZE); + NetworkTestUtils.waitForChannelClose(oldClientSelector, "1", ChannelState.State.AUTHENTICATION_FAILED); + + // Verify that new client with new truststore can authenticate, send and receive + Selector newClientSelector = createSelector(args.sslClientConfigs); + newClientSelector.connect("2", addr, BUFFER_SIZE, BUFFER_SIZE); + NetworkTestUtils.checkClientConnection(newClientSelector, "2", 100, 10); + + // Verify that old client continues to work + NetworkTestUtils.checkClientConnection(oldClientSelector, oldNode, 100, 10); + + Map invalidConfigs = new HashMap<>(newTruststoreConfigs); + invalidConfigs.put(SslConfigs.SSL_TRUSTSTORE_TYPE_CONFIG, "INVALID_TYPE"); + verifyInvalidReconfigure(reconfigurableBuilder, invalidConfigs, "invalid truststore type"); + + Map missingStoreConfigs = new HashMap<>(); + missingStoreConfigs.put(SslConfigs.SSL_TRUSTSTORE_TYPE_CONFIG, "PKCS12"); + missingStoreConfigs.put(SslConfigs.SSL_TRUSTSTORE_LOCATION_CONFIG, "some.truststore.path"); + missingStoreConfigs.put(SslConfigs.SSL_TRUSTSTORE_PASSWORD_CONFIG, new Password("some.truststore.password")); + verifyInvalidReconfigure(reconfigurableBuilder, missingStoreConfigs, "truststore not found"); + + // Verify that new connections continue to work with the server with previously configured keystore after failed reconfiguration + newClientSelector.connect("3", addr, BUFFER_SIZE, BUFFER_SIZE); + NetworkTestUtils.checkClientConnection(newClientSelector, "3", 100, 10); + } + + /** + * Tests if client can plugin customize ssl.engine.factory + */ + 
@ParameterizedTest + @ArgumentsSource(SslTransportLayerArgumentsProvider.class) + public void testCustomClientSslEngineFactory(Args args) throws Exception { + args.sslClientConfigs.put(SslConfigs.SSL_ENGINE_FACTORY_CLASS_CONFIG, TestSslUtils.TestSslEngineFactory.class); + verifySslConfigs(args); + } + + /** + * Tests if server can plugin customize ssl.engine.factory + */ + @ParameterizedTest + @ArgumentsSource(SslTransportLayerArgumentsProvider.class) + public void testCustomServerSslEngineFactory(Args args) throws Exception { + args.sslServerConfigs.put(SslConfigs.SSL_ENGINE_FACTORY_CLASS_CONFIG, TestSslUtils.TestSslEngineFactory.class); + verifySslConfigs(args); + } + + /** + * Tests if client and server both can plugin customize ssl.engine.factory and talk to each other! + */ + @ParameterizedTest + @ArgumentsSource(SslTransportLayerArgumentsProvider.class) + public void testCustomClientAndServerSslEngineFactory(Args args) throws Exception { + args.sslClientConfigs.put(SslConfigs.SSL_ENGINE_FACTORY_CLASS_CONFIG, TestSslUtils.TestSslEngineFactory.class); + args.sslServerConfigs.put(SslConfigs.SSL_ENGINE_FACTORY_CLASS_CONFIG, TestSslUtils.TestSslEngineFactory.class); + verifySslConfigs(args); + } + + /** + * Tests invalid ssl.engine.factory plugin class + */ + @ParameterizedTest + @ArgumentsSource(SslTransportLayerArgumentsProvider.class) + public void testInvalidSslEngineFactory(Args args) { + args.sslClientConfigs.put(SslConfigs.SSL_ENGINE_FACTORY_CLASS_CONFIG, String.class); + assertThrows(KafkaException.class, () -> createSelector(args.sslClientConfigs)); + } + + private void verifyInvalidReconfigure(ListenerReconfigurable reconfigurable, + Map invalidConfigs, String errorMessage) { + assertThrows(KafkaException.class, () -> reconfigurable.validateReconfiguration(invalidConfigs)); + assertThrows(KafkaException.class, () -> reconfigurable.reconfigure(invalidConfigs)); + } + + private Selector createSelector(Map sslClientConfigs) { + return createSelector(sslClientConfigs, null, null, null); + } + + private Selector createSelector(Map sslClientConfigs, final Integer netReadBufSize, + final Integer netWriteBufSize, final Integer appBufSize) { + TestSslChannelBuilder channelBuilder = new TestSslChannelBuilder(Mode.CLIENT); + channelBuilder.configureBufferSizes(netReadBufSize, netWriteBufSize, appBufSize); + channelBuilder.configure(sslClientConfigs); + this.selector = new Selector(100 * 5000, new Metrics(), time, "MetricGroup", channelBuilder, new LogContext()); + return selector; + } + + private NioEchoServer createEchoServer(Args args, ListenerName listenerName, SecurityProtocol securityProtocol) throws Exception { + return NetworkTestUtils.createEchoServer(listenerName, securityProtocol, new TestSecurityConfig(args.sslServerConfigs), null, time); + } + + private NioEchoServer createEchoServer(Args args, SecurityProtocol securityProtocol) throws Exception { + return createEchoServer(args, ListenerName.forSecurityProtocol(securityProtocol), securityProtocol); + } + + private Selector createSelector(Args args) { + LogContext logContext = new LogContext(); + ChannelBuilder channelBuilder = new SslChannelBuilder(Mode.CLIENT, null, false, logContext); + channelBuilder.configure(args.sslClientConfigs); + selector = new Selector(5000, new Metrics(), time, "MetricGroup", channelBuilder, logContext); + return selector; + } + + private void verifySslConfigs(Args args) throws Exception { + server = createEchoServer(args, SecurityProtocol.SSL); + createSelector(args.sslClientConfigs); + 
InetSocketAddress addr = new InetSocketAddress("localhost", server.port()); + String node = "0"; + selector.connect(node, addr, BUFFER_SIZE, BUFFER_SIZE); + NetworkTestUtils.checkClientConnection(selector, node, 100, 10); + } + + private void verifySslConfigsWithHandshakeFailure(Args args) throws Exception { + server = createEchoServer(args, SecurityProtocol.SSL); + createSelector(args.sslClientConfigs); + InetSocketAddress addr = new InetSocketAddress("localhost", server.port()); + String node = "0"; + selector.connect(node, addr, BUFFER_SIZE, BUFFER_SIZE); + NetworkTestUtils.waitForChannelClose(selector, node, ChannelState.State.AUTHENTICATION_FAILED); + server.verifyAuthenticationMetrics(0, 1); + } + + private static CertStores.Builder certBuilder(boolean isServer, String cn, boolean useInlinePem) { + return new CertStores.Builder(isServer) + .cn(cn) + .usePem(useInlinePem); + } + + @FunctionalInterface + private interface FailureAction { + FailureAction NO_OP = () -> { }; + FailureAction THROW_IO_EXCEPTION = () -> { + throw new IOException("Test IO exception"); + }; + void run() throws IOException; + } + + private Supplier defaultApiVersionsSupplier() { + return () -> ApiVersionsResponse.defaultApiVersionsResponse(ApiMessageType.ListenerType.ZK_BROKER); + } + + static class TestSslChannelBuilder extends SslChannelBuilder { + + private Integer netReadBufSizeOverride; + private Integer netWriteBufSizeOverride; + private Integer appBufSizeOverride; + private long failureIndex = Long.MAX_VALUE; + FailureAction readFailureAction = FailureAction.NO_OP; + FailureAction flushFailureAction = FailureAction.NO_OP; + int flushDelayCount = 0; + + public TestSslChannelBuilder(Mode mode) { + super(mode, null, false, new LogContext()); + } + + public void configureBufferSizes(Integer netReadBufSize, Integer netWriteBufSize, Integer appBufSize) { + this.netReadBufSizeOverride = netReadBufSize; + this.netWriteBufSizeOverride = netWriteBufSize; + this.appBufSizeOverride = appBufSize; + } + + @Override + protected SslTransportLayer buildTransportLayer(SslFactory sslFactory, String id, SelectionKey key, + ChannelMetadataRegistry metadataRegistry) throws IOException { + SocketChannel socketChannel = (SocketChannel) key.channel(); + SSLEngine sslEngine = sslFactory.createSslEngine(socketChannel.socket()); + return newTransportLayer(id, key, sslEngine); + } + + protected TestSslTransportLayer newTransportLayer(String id, SelectionKey key, SSLEngine sslEngine) throws IOException { + return new TestSslTransportLayer(id, key, sslEngine); + } + + /** + * SSLTransportLayer with overrides for testing including: + *
+         * <ul>
+         * <li>Overrides for packet and application buffer size to test buffer resize code path.
+         * The overridden buffer size starts with a small value and increases in size when the buffer size
+         * is retrieved to handle overflow/underflow, until the actual session buffer size is reached.</li>
+         * <li>IOException injection for reads and writes for testing exception handling during handshakes.</li>
+         * <li>Delayed writes to test handshake failure notifications to peer</li>
+         * </ul>
            + */ + class TestSslTransportLayer extends SslTransportLayer { + + private final ResizeableBufferSize netReadBufSize; + private final ResizeableBufferSize netWriteBufSize; + private final ResizeableBufferSize appBufSize; + private final AtomicLong numReadsRemaining; + private final AtomicLong numFlushesRemaining; + private final AtomicInteger numDelayedFlushesRemaining; + + public TestSslTransportLayer(String channelId, SelectionKey key, SSLEngine sslEngine) { + super(channelId, key, sslEngine, new DefaultChannelMetadataRegistry()); + this.netReadBufSize = new ResizeableBufferSize(netReadBufSizeOverride); + this.netWriteBufSize = new ResizeableBufferSize(netWriteBufSizeOverride); + this.appBufSize = new ResizeableBufferSize(appBufSizeOverride); + numReadsRemaining = new AtomicLong(failureIndex); + numFlushesRemaining = new AtomicLong(failureIndex); + numDelayedFlushesRemaining = new AtomicInteger(flushDelayCount); + } + + @Override + protected int netReadBufferSize() { + ByteBuffer netReadBuffer = netReadBuffer(); + // netReadBufferSize() is invoked in SSLTransportLayer.read() prior to the read + // operation. To avoid the read buffer being expanded too early, increase buffer size + // only when read buffer is full. This ensures that BUFFER_UNDERFLOW is always + // triggered in testNetReadBufferResize(). + boolean updateBufSize = netReadBuffer != null && !netReadBuffer().hasRemaining(); + return netReadBufSize.updateAndGet(super.netReadBufferSize(), updateBufSize); + } + + @Override + protected int netWriteBufferSize() { + return netWriteBufSize.updateAndGet(super.netWriteBufferSize(), true); + } + + @Override + protected int applicationBufferSize() { + return appBufSize.updateAndGet(super.applicationBufferSize(), true); + } + + @Override + protected int readFromSocketChannel() throws IOException { + if (numReadsRemaining.decrementAndGet() == 0 && !ready()) + readFailureAction.run(); + return super.readFromSocketChannel(); + } + + @Override + protected boolean flush(ByteBuffer buf) throws IOException { + if (numFlushesRemaining.decrementAndGet() == 0 && !ready()) + flushFailureAction.run(); + else if (numDelayedFlushesRemaining.getAndDecrement() != 0) + return false; + resetDelayedFlush(); + return super.flush(buf); + } + + @Override + protected void startHandshake() throws IOException { + assertTrue(socketChannel().isConnected(), "SSL handshake initialized too early"); + super.startHandshake(); + } + + private void resetDelayedFlush() { + numDelayedFlushesRemaining.set(flushDelayCount); + } + } + + static class ResizeableBufferSize { + private Integer bufSizeOverride; + ResizeableBufferSize(Integer bufSizeOverride) { + this.bufSizeOverride = bufSizeOverride; + } + int updateAndGet(int actualSize, boolean update) { + int size = actualSize; + if (bufSizeOverride != null) { + if (update) + bufSizeOverride = Math.min(bufSizeOverride * 2, size); + size = bufSizeOverride; + } + return size; + } + } + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/network/SslTransportTls12Tls13Test.java b/clients/src/test/java/org/apache/kafka/common/network/SslTransportTls12Tls13Test.java new file mode 100644 index 0000000..f0fae56 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/network/SslTransportTls12Tls13Test.java @@ -0,0 +1,165 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.network; + +import java.io.IOException; +import java.net.InetSocketAddress; +import java.util.Arrays; +import java.util.Collections; +import java.util.Map; +import org.apache.kafka.common.config.SslConfigs; +import org.apache.kafka.common.metrics.Metrics; +import org.apache.kafka.common.security.TestSecurityConfig; +import org.apache.kafka.common.security.auth.SecurityProtocol; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.common.utils.Time; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledForJreRange; +import org.junit.jupiter.api.condition.JRE; + +public class SslTransportTls12Tls13Test { + private static final int BUFFER_SIZE = 4 * 1024; + private static final Time TIME = Time.SYSTEM; + + private NioEchoServer server; + private Selector selector; + private Map sslClientConfigs; + private Map sslServerConfigs; + + @BeforeEach + public void setup() throws Exception { + // Create certificates for use by client and server. Add server cert to client truststore and vice versa. + CertStores serverCertStores = new CertStores(true, "server", "localhost"); + CertStores clientCertStores = new CertStores(false, "client", "localhost"); + sslServerConfigs = serverCertStores.getTrustingConfig(clientCertStores); + sslClientConfigs = clientCertStores.getTrustingConfig(serverCertStores); + + LogContext logContext = new LogContext(); + ChannelBuilder channelBuilder = new SslChannelBuilder(Mode.CLIENT, null, false, logContext); + channelBuilder.configure(sslClientConfigs); + this.selector = new Selector(5000, new Metrics(), TIME, "MetricGroup", channelBuilder, logContext); + } + + @AfterEach + public void teardown() throws Exception { + if (selector != null) + this.selector.close(); + if (server != null) + this.server.close(); + } + + /** + * Tests that connections fails if TLSv1.3 enabled but cipher suite suitable only for TLSv1.2 used. 
+ */ + @Test + @EnabledForJreRange(min = JRE.JAVA_11) + public void testCiphersSuiteForTls12FailsForTls13() throws Exception { + String cipherSuite = "TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384"; + + sslServerConfigs.put(SslConfigs.SSL_ENABLED_PROTOCOLS_CONFIG, Collections.singletonList("TLSv1.3")); + sslServerConfigs.put(SslConfigs.SSL_CIPHER_SUITES_CONFIG, Collections.singletonList(cipherSuite)); + server = NetworkTestUtils.createEchoServer(ListenerName.forSecurityProtocol(SecurityProtocol.SSL), + SecurityProtocol.SSL, new TestSecurityConfig(sslServerConfigs), null, TIME); + + sslClientConfigs.put(SslConfigs.SSL_ENABLED_PROTOCOLS_CONFIG, Collections.singletonList("TLSv1.3")); + sslClientConfigs.put(SslConfigs.SSL_CIPHER_SUITES_CONFIG, Collections.singletonList(cipherSuite)); + + checkAuthentiationFailed(); + } + + /** + * Tests that connections can't be made if server uses TLSv1.2 with custom cipher suite and client uses TLSv1.3. + */ + @Test + @EnabledForJreRange(min = JRE.JAVA_11) + public void testCiphersSuiteFailForServerTls12ClientTls13() throws Exception { + String tls12CipherSuite = "TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384"; + String tls13CipherSuite = "TLS_AES_128_GCM_SHA256"; + + sslServerConfigs.put(SslConfigs.SSL_PROTOCOL_CONFIG, "TLSv1.2"); + sslServerConfigs.put(SslConfigs.SSL_ENABLED_PROTOCOLS_CONFIG, Collections.singletonList("TLSv1.2")); + sslServerConfigs.put(SslConfigs.SSL_CIPHER_SUITES_CONFIG, Collections.singletonList(tls12CipherSuite)); + server = NetworkTestUtils.createEchoServer(ListenerName.forSecurityProtocol(SecurityProtocol.SSL), + SecurityProtocol.SSL, new TestSecurityConfig(sslServerConfigs), null, TIME); + + sslClientConfigs.put(SslConfigs.SSL_PROTOCOL_CONFIG, "TLSv1.3"); + sslClientConfigs.put(SslConfigs.SSL_CIPHER_SUITES_CONFIG, Collections.singletonList(tls13CipherSuite)); + + checkAuthentiationFailed(); + } + + /** + * Tests that connections can be made with TLSv1.3 cipher suite. + */ + @Test + @EnabledForJreRange(min = JRE.JAVA_11) + public void testCiphersSuiteForTls13() throws Exception { + String cipherSuite = "TLS_AES_128_GCM_SHA256"; + + sslServerConfigs.put(SslConfigs.SSL_CIPHER_SUITES_CONFIG, Collections.singletonList(cipherSuite)); + server = NetworkTestUtils.createEchoServer(ListenerName.forSecurityProtocol(SecurityProtocol.SSL), + SecurityProtocol.SSL, new TestSecurityConfig(sslServerConfigs), null, TIME); + + sslClientConfigs.put(SslConfigs.SSL_CIPHER_SUITES_CONFIG, Collections.singletonList(cipherSuite)); + checkAuthenticationSucceed(); + } + + /** + * Tests that connections can be made with TLSv1.2 cipher suite. + */ + @Test + public void testCiphersSuiteForTls12() throws Exception { + String cipherSuite = "TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384"; + + sslServerConfigs.put(SslConfigs.SSL_ENABLED_PROTOCOLS_CONFIG, Arrays.asList(SslConfigs.DEFAULT_SSL_ENABLED_PROTOCOLS.split(","))); + sslServerConfigs.put(SslConfigs.SSL_CIPHER_SUITES_CONFIG, Collections.singletonList(cipherSuite)); + server = NetworkTestUtils.createEchoServer(ListenerName.forSecurityProtocol(SecurityProtocol.SSL), + SecurityProtocol.SSL, new TestSecurityConfig(sslServerConfigs), null, TIME); + + sslClientConfigs.put(SslConfigs.SSL_ENABLED_PROTOCOLS_CONFIG, Arrays.asList(SslConfigs.DEFAULT_SSL_ENABLED_PROTOCOLS.split(","))); + sslClientConfigs.put(SslConfigs.SSL_CIPHER_SUITES_CONFIG, Collections.singletonList(cipherSuite)); + checkAuthenticationSucceed(); + } + + /** Checks connection failed using the specified {@code tlsVersion}. 
*/ + private void checkAuthentiationFailed() throws IOException, InterruptedException { + sslClientConfigs.put(SslConfigs.SSL_ENABLED_PROTOCOLS_CONFIG, Arrays.asList("TLSv1.3")); + createSelector(sslClientConfigs); + InetSocketAddress addr = new InetSocketAddress("localhost", server.port()); + selector.connect("0", addr, BUFFER_SIZE, BUFFER_SIZE); + + NetworkTestUtils.waitForChannelClose(selector, "0", ChannelState.State.AUTHENTICATION_FAILED); + server.verifyAuthenticationMetrics(0, 1); + } + + private void checkAuthenticationSucceed() throws IOException, InterruptedException { + createSelector(sslClientConfigs); + InetSocketAddress addr = new InetSocketAddress("localhost", server.port()); + selector.connect("0", addr, BUFFER_SIZE, BUFFER_SIZE); + NetworkTestUtils.waitForChannelReady(selector, "0"); + server.verifyAuthenticationMetrics(1, 0); + } + + private void createSelector(Map sslClientConfigs) { + SslTransportLayerTest.TestSslChannelBuilder channelBuilder = new SslTransportLayerTest.TestSslChannelBuilder(Mode.CLIENT); + channelBuilder.configureBufferSizes(null, null, null); + channelBuilder.configure(sslClientConfigs); + this.selector = new Selector(100 * 5000, new Metrics(), TIME, "MetricGroup", channelBuilder, new LogContext()); + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/network/SslVersionsTransportLayerTest.java b/clients/src/test/java/org/apache/kafka/common/network/SslVersionsTransportLayerTest.java new file mode 100644 index 0000000..584d48f --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/network/SslVersionsTransportLayerTest.java @@ -0,0 +1,172 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.network; + +import java.net.InetSocketAddress; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Stream; + +import org.apache.kafka.common.config.SslConfigs; +import org.apache.kafka.common.metrics.Metrics; +import org.apache.kafka.common.security.TestSecurityConfig; +import org.apache.kafka.common.security.auth.SecurityProtocol; +import org.apache.kafka.common.utils.Java; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.test.TestUtils; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; + +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; + +/** + * Tests for the SSL transport layer. 
+ * Checks different versions of the protocol usage on the server and client. + */ +public class SslVersionsTransportLayerTest { + private static final int BUFFER_SIZE = 4 * 1024; + private static final Time TIME = Time.SYSTEM; + + public static Stream parameters() { + List parameters = new ArrayList<>(); + + parameters.add(Arguments.of(Collections.singletonList("TLSv1.2"), Collections.singletonList("TLSv1.2"))); + + if (Java.IS_JAVA11_COMPATIBLE) { + parameters.add(Arguments.of(Collections.singletonList("TLSv1.2"), Collections.singletonList("TLSv1.3"))); + parameters.add(Arguments.of(Collections.singletonList("TLSv1.3"), Collections.singletonList("TLSv1.2"))); + parameters.add(Arguments.of(Collections.singletonList("TLSv1.3"), Collections.singletonList("TLSv1.3"))); + parameters.add(Arguments.of(Collections.singletonList("TLSv1.2"), Arrays.asList("TLSv1.2", "TLSv1.3"))); + parameters.add(Arguments.of(Collections.singletonList("TLSv1.2"), Arrays.asList("TLSv1.3", "TLSv1.2"))); + parameters.add(Arguments.of(Collections.singletonList("TLSv1.3"), Arrays.asList("TLSv1.2", "TLSv1.3"))); + parameters.add(Arguments.of(Collections.singletonList("TLSv1.3"), Arrays.asList("TLSv1.3", "TLSv1.2"))); + parameters.add(Arguments.of(Arrays.asList("TLSv1.3", "TLSv1.2"), Collections.singletonList("TLSv1.3"))); + parameters.add(Arguments.of(Arrays.asList("TLSv1.3", "TLSv1.2"), Collections.singletonList("TLSv1.2"))); + parameters.add(Arguments.of(Arrays.asList("TLSv1.3", "TLSv1.2"), Arrays.asList("TLSv1.2", "TLSv1.3"))); + parameters.add(Arguments.of(Arrays.asList("TLSv1.3", "TLSv1.2"), Arrays.asList("TLSv1.3", "TLSv1.2"))); + parameters.add(Arguments.of(Arrays.asList("TLSv1.2", "TLSv1.3"), Collections.singletonList("TLSv1.3"))); + parameters.add(Arguments.of(Arrays.asList("TLSv1.2", "TLSv1.3"), Collections.singletonList("TLSv1.2"))); + parameters.add(Arguments.of(Arrays.asList("TLSv1.2", "TLSv1.3"), Arrays.asList("TLSv1.2", "TLSv1.3"))); + parameters.add(Arguments.of(Arrays.asList("TLSv1.2", "TLSv1.3"), Arrays.asList("TLSv1.3", "TLSv1.2"))); + } + + return parameters.stream(); + } + + /** + * Tests that connection success with the default TLS version. + * Note that debug mode for javax.net.ssl can be enabled via {@code System.setProperty("javax.net.debug", "ssl:handshake");} + */ + @ParameterizedTest(name = "tlsServerProtocol = {0}, tlsClientProtocol = {1}") + @MethodSource("parameters") + public void testTlsDefaults(List serverProtocols, List clientProtocols) throws Exception { + // Create certificates for use by client and server. Add server cert to client truststore and vice versa. 
+ CertStores serverCertStores = new CertStores(true, "server", "localhost"); + CertStores clientCertStores = new CertStores(false, "client", "localhost"); + + Map sslClientConfigs = getTrustingConfig(clientCertStores, serverCertStores, clientProtocols); + Map sslServerConfigs = getTrustingConfig(serverCertStores, clientCertStores, serverProtocols); + + NioEchoServer server = NetworkTestUtils.createEchoServer(ListenerName.forSecurityProtocol(SecurityProtocol.SSL), + SecurityProtocol.SSL, + new TestSecurityConfig(sslServerConfigs), + null, + TIME); + Selector selector = createClientSelector(sslClientConfigs); + + String node = "0"; + selector.connect(node, new InetSocketAddress("localhost", server.port()), BUFFER_SIZE, BUFFER_SIZE); + + if (isCompatible(serverProtocols, clientProtocols)) { + NetworkTestUtils.waitForChannelReady(selector, node); + + int msgSz = 1024 * 1024; + String message = TestUtils.randomString(msgSz); + selector.send(new NetworkSend(node, ByteBufferSend.sizePrefixed(ByteBuffer.wrap(message.getBytes())))); + while (selector.completedReceives().isEmpty()) { + selector.poll(100L); + } + int totalBytes = msgSz + 4; // including 4-byte size + server.waitForMetric("incoming-byte", totalBytes); + server.waitForMetric("outgoing-byte", totalBytes); + server.waitForMetric("request", 1); + server.waitForMetric("response", 1); + } else { + NetworkTestUtils.waitForChannelClose(selector, node, ChannelState.State.AUTHENTICATION_FAILED); + server.verifyAuthenticationMetrics(0, 1); + } + } + + /** + *

+ * The explanation of this check lies in the structure of the ClientHello SSL message.
+ * See the "Send ClientHello Message" section of the Guide.
+ *
+ * > Client version: For TLS 1.3, this has a fixed value, TLSv1.2; TLS 1.3 uses the extension
+ * > supported_versions and not this field to negotiate protocol version
+ * ...
+ * > supported_versions: Lists which versions of TLS the client supports. In particular, if the client
+ * > requests TLS 1.3, then the client version field has the value TLSv1.2 and this extension
+ * > contains the value TLSv1.3; if the client requests TLS 1.2, then the client version field has the
+ * > value TLSv1.2 and this extension either doesn't exist or contains the value TLSv1.2 but not the value TLSv1.3.

+ *
+ * This means that a TLSv1.3 client can fall back to TLSv1.2, but a TLSv1.2 client cannot switch to TLSv1.3.
+ *
+ * @param serverProtocols Server protocols. Expected to be non-empty.
+ * @param clientProtocols Client protocols. Expected to be non-empty.
+ * @return {@code true} if the client should be able to connect to the server.
+ */
+    private boolean isCompatible(List<String> serverProtocols, List<String> clientProtocols) {
+        assertNotNull(serverProtocols);
+        assertFalse(serverProtocols.isEmpty());
+        assertNotNull(clientProtocols);
+        assertFalse(clientProtocols.isEmpty());
+
+        return serverProtocols.contains(clientProtocols.get(0)) ||
+            (clientProtocols.get(0).equals("TLSv1.3") && !Collections.disjoint(serverProtocols, clientProtocols));
+    }
+
+    private static Map<String, Object> getTrustingConfig(CertStores certStores, CertStores peerCertStores, List<String> tlsProtocols) {
+        Map<String, Object> configs = certStores.getTrustingConfig(peerCertStores);
+        configs.putAll(sslConfig(tlsProtocols));
+        return configs;
+    }
+
+    private static Map<String, Object> sslConfig(List<String> tlsProtocols) {
+        Map<String, Object> sslConfig = new HashMap<>();
+        sslConfig.put(SslConfigs.SSL_PROTOCOL_CONFIG, tlsProtocols.get(0));
+        sslConfig.put(SslConfigs.SSL_ENABLED_PROTOCOLS_CONFIG, tlsProtocols);
+        return sslConfig;
+    }
+
+    private Selector createClientSelector(Map<String, Object> sslClientConfigs) {
+        SslTransportLayerTest.TestSslChannelBuilder channelBuilder =
+            new SslTransportLayerTest.TestSslChannelBuilder(Mode.CLIENT);
+        channelBuilder.configureBufferSizes(null, null, null);
+        channelBuilder.configure(sslClientConfigs);
+        return new Selector(100 * 5000, new Metrics(), TIME, "MetricGroup", channelBuilder, new LogContext());
+    }
+}
diff --git a/clients/src/test/java/org/apache/kafka/common/protocol/ApiKeysTest.java b/clients/src/test/java/org/apache/kafka/common/protocol/ApiKeysTest.java
new file mode 100644
index 0000000..3c66b21
--- /dev/null
+++ b/clients/src/test/java/org/apache/kafka/common/protocol/ApiKeysTest.java
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
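For illustration only (not part of the patch), the protocol-compatibility rule encoded by the isCompatible() helper in SslVersionsTransportLayerTest above can be exercised on its own: a client connects when the server enables the client's preferred protocol, and a client that prefers TLSv1.3 may also fall back to any version both sides share, while a TLSv1.2-first client cannot upgrade to TLSv1.3. The protocol lists below are example values only.

import java.util.Arrays;
import java.util.Collections;
import java.util.List;

// Editorial sketch: the handshake-compatibility predicate used by the test's isCompatible() helper.
public class TlsCompatibilitySketch {

    static boolean compatible(List<String> serverProtocols, List<String> clientProtocols) {
        return serverProtocols.contains(clientProtocols.get(0))
            || (clientProtocols.get(0).equals("TLSv1.3")
                && !Collections.disjoint(serverProtocols, clientProtocols));
    }

    public static void main(String[] args) {
        // Client prefers TLSv1.3 but also offers TLSv1.2; server only enables TLSv1.2 -> true (fallback works)
        System.out.println(compatible(Collections.singletonList("TLSv1.2"), Arrays.asList("TLSv1.3", "TLSv1.2")));
        // Client only offers TLSv1.2; server only enables TLSv1.3 -> false (no upgrade path)
        System.out.println(compatible(Collections.singletonList("TLSv1.3"), Collections.singletonList("TLSv1.2")));
    }
}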
+ */ +package org.apache.kafka.common.protocol; + +import org.apache.kafka.common.protocol.types.BoundField; +import org.apache.kafka.common.protocol.types.Schema; +import org.junit.jupiter.api.Test; + +import java.util.Collections; +import java.util.EnumSet; +import java.util.HashSet; +import java.util.Set; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class ApiKeysTest { + + @Test + public void testForIdWithInvalidIdLow() { + assertThrows(IllegalArgumentException.class, () -> ApiKeys.forId(-1)); + } + + @Test + public void testForIdWithInvalidIdHigh() { + assertThrows(IllegalArgumentException.class, () -> ApiKeys.forId(10000)); + } + + @Test + public void testAlterIsrIsClusterAction() { + assertTrue(ApiKeys.ALTER_ISR.clusterAction); + } + + /** + * All valid client responses which may be throttled should have a field named + * 'throttle_time_ms' to return the throttle time to the client. Exclusions are + *

+ * <ul>
+ *   <li> Cluster actions used only for inter-broker communication are throttled only if unauthorized
+ *   <li> SASL_HANDSHAKE and SASL_AUTHENTICATE are not throttled when used for authentication
+ *        when a connection is established or for re-authentication thereafter; these requests
+ *        return an error response that may be throttled if they are sent otherwise.
+ * </ul>
            + */ + @Test + public void testResponseThrottleTime() { + Set authenticationKeys = EnumSet.of(ApiKeys.SASL_HANDSHAKE, ApiKeys.SASL_AUTHENTICATE); + // Newer protocol apis include throttle time ms even for cluster actions + Set clusterActionsWithThrottleTimeMs = EnumSet.of(ApiKeys.ALTER_ISR, ApiKeys.ALLOCATE_PRODUCER_IDS); + for (ApiKeys apiKey: ApiKeys.zkBrokerApis()) { + Schema responseSchema = apiKey.messageType.responseSchemas()[apiKey.latestVersion()]; + BoundField throttleTimeField = responseSchema.get("throttle_time_ms"); + if ((apiKey.clusterAction && !clusterActionsWithThrottleTimeMs.contains(apiKey)) + || authenticationKeys.contains(apiKey)) + assertNull(throttleTimeField, "Unexpected throttle time field: " + apiKey); + else + assertNotNull(throttleTimeField, "Throttle time field missing: " + apiKey); + } + } + + @Test + public void testApiScope() { + Set apisMissingScope = new HashSet<>(); + for (ApiKeys apiKey : ApiKeys.values()) { + if (apiKey.messageType.listeners().isEmpty()) { + apisMissingScope.add(apiKey); + } + } + assertEquals(Collections.emptySet(), apisMissingScope, + "Found some APIs missing scope definition"); + } + +} diff --git a/clients/src/test/java/org/apache/kafka/common/protocol/ErrorsTest.java b/clients/src/test/java/org/apache/kafka/common/protocol/ErrorsTest.java new file mode 100644 index 0000000..44523f4 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/protocol/ErrorsTest.java @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+package org.apache.kafka.common.protocol;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertNotEquals;
+import static org.junit.jupiter.api.Assertions.assertNull;
+
+import java.util.HashSet;
+import java.util.Set;
+
+import org.apache.kafka.common.errors.ApiException;
+import org.apache.kafka.common.errors.TimeoutException;
+import org.junit.jupiter.api.Test;
+
+public class ErrorsTest {
+
+    @Test
+    public void testUniqueErrorCodes() {
+        Set<Short> codeSet = new HashSet<>();
+        for (Errors error : Errors.values()) {
+            codeSet.add(error.code());
+        }
+        assertEquals(codeSet.size(), Errors.values().length, "Error codes must be unique");
+    }
+
+    @Test
+    public void testUniqueExceptions() {
+        Set<Class<?>> exceptionSet = new HashSet<>();
+        for (Errors error : Errors.values()) {
+            if (error != Errors.NONE)
+                exceptionSet.add(error.exception().getClass());
+        }
+        assertEquals(exceptionSet.size(), Errors.values().length - 1, "Exceptions must be unique"); // Ignore NONE
+    }
+
+    @Test
+    public void testExceptionsAreNotGeneric() {
+        for (Errors error : Errors.values()) {
+            if (error != Errors.NONE)
+                assertNotEquals(error.exception().getClass(), ApiException.class, "Generic ApiException should not be used");
+        }
+    }
+
+    @Test
+    public void testNoneException() {
+        assertNull(Errors.NONE.exception(), "The NONE error should not have an exception");
+    }
+
+    @Test
+    public void testForExceptionInheritance() {
+        class ExtendedTimeoutException extends TimeoutException { }
+
+        Errors expectedError = Errors.forException(new TimeoutException());
+        Errors actualError = Errors.forException(new ExtendedTimeoutException());
+
+        assertEquals(expectedError, actualError, "forException should match super classes");
+    }
+
+    @Test
+    public void testForExceptionDefault() {
+        Errors error = Errors.forException(new ApiException());
+        assertEquals(Errors.UNKNOWN_SERVER_ERROR, error, "forException should default to unknown");
+    }
+
+    @Test
+    public void testExceptionName() {
+        String exceptionName = Errors.UNKNOWN_SERVER_ERROR.exceptionName();
+        assertEquals("org.apache.kafka.common.errors.UnknownServerException", exceptionName);
+        exceptionName = Errors.NONE.exceptionName();
+        assertNull(exceptionName);
+        exceptionName = Errors.INVALID_TOPIC_EXCEPTION.exceptionName();
+        assertEquals("org.apache.kafka.common.errors.InvalidTopicException", exceptionName);
+    }
+
+}
diff --git a/clients/src/test/java/org/apache/kafka/common/protocol/MessageUtilTest.java b/clients/src/test/java/org/apache/kafka/common/protocol/MessageUtilTest.java
new file mode 100755
index 0000000..33dcabb
--- /dev/null
+++ b/clients/src/test/java/org/apache/kafka/common/protocol/MessageUtilTest.java
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.common.protocol; + +import org.apache.kafka.common.protocol.types.RawTaggedField; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; + +import java.nio.ByteBuffer; +import java.util.Arrays; +import java.util.Collections; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +@Timeout(120) +public final class MessageUtilTest { + + @Test + public void testDeepToString() { + assertEquals("[1, 2, 3]", + MessageUtil.deepToString(Arrays.asList(1, 2, 3).iterator())); + assertEquals("[foo]", + MessageUtil.deepToString(Arrays.asList("foo").iterator())); + } + + @Test + public void testByteBufferToArray() { + assertArrayEquals(new byte[]{1, 2, 3}, + MessageUtil.byteBufferToArray(ByteBuffer.wrap(new byte[]{1, 2, 3}))); + assertArrayEquals(new byte[]{}, + MessageUtil.byteBufferToArray(ByteBuffer.wrap(new byte[]{}))); + } + + @Test + public void testDuplicate() { + assertNull(MessageUtil.duplicate(null)); + assertArrayEquals(new byte[] {}, + MessageUtil.duplicate(new byte[] {})); + assertArrayEquals(new byte[] {1, 2, 3}, + MessageUtil.duplicate(new byte[] {1, 2, 3})); + } + + @Test + public void testCompareRawTaggedFields() { + assertTrue(MessageUtil.compareRawTaggedFields(null, null)); + assertTrue(MessageUtil.compareRawTaggedFields(null, Collections.emptyList())); + assertTrue(MessageUtil.compareRawTaggedFields(Collections.emptyList(), null)); + assertFalse(MessageUtil.compareRawTaggedFields(Collections.emptyList(), + Collections.singletonList(new RawTaggedField(1, new byte[] {1})))); + assertFalse(MessageUtil.compareRawTaggedFields(null, + Collections.singletonList(new RawTaggedField(1, new byte[] {1})))); + assertFalse(MessageUtil.compareRawTaggedFields( + Collections.singletonList(new RawTaggedField(1, new byte[] {1})), + Collections.emptyList())); + assertTrue(MessageUtil.compareRawTaggedFields( + Arrays.asList(new RawTaggedField(1, new byte[] {1}), + new RawTaggedField(2, new byte[] {})), + Arrays.asList(new RawTaggedField(1, new byte[] {1}), + new RawTaggedField(2, new byte[] {})))); + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/protocol/ProtoUtilsTest.java b/clients/src/test/java/org/apache/kafka/common/protocol/ProtoUtilsTest.java new file mode 100644 index 0000000..712c611 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/protocol/ProtoUtilsTest.java @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.protocol; + +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class ProtoUtilsTest { + @Test + public void testDelayedAllocationSchemaDetection() { + //verifies that schemas known to retain a reference to the underlying byte buffer are correctly detected. + for (ApiKeys key : ApiKeys.values()) { + switch (key) { + case PRODUCE: + case JOIN_GROUP: + case SYNC_GROUP: + case SASL_AUTHENTICATE: + case EXPIRE_DELEGATION_TOKEN: + case RENEW_DELEGATION_TOKEN: + case ALTER_USER_SCRAM_CREDENTIALS: + case ENVELOPE: + assertTrue(key.requiresDelayedAllocation, key + " should require delayed allocation"); + break; + default: + if (key.forwardable) + assertTrue(key.requiresDelayedAllocation, + key + " should require delayed allocation since it is forwardable"); + else + assertFalse(key.requiresDelayedAllocation, key + " should not require delayed allocation"); + break; + } + } + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/protocol/SendBuilderTest.java b/clients/src/test/java/org/apache/kafka/common/protocol/SendBuilderTest.java new file mode 100644 index 0000000..6d36395 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/protocol/SendBuilderTest.java @@ -0,0 +1,166 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.protocol; + +import org.apache.kafka.common.network.Send; +import org.apache.kafka.common.record.CompressionType; +import org.apache.kafka.common.record.MemoryRecords; +import org.apache.kafka.common.record.MemoryRecordsBuilder; +import org.apache.kafka.common.record.SimpleRecord; +import org.apache.kafka.common.record.TimestampType; +import org.apache.kafka.common.record.UnalignedMemoryRecords; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.test.TestUtils; +import org.junit.jupiter.api.Test; + +import java.nio.ByteBuffer; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class SendBuilderTest { + + @Test + public void testZeroCopyByteBuffer() { + byte[] data = Utils.utf8("foo"); + ByteBuffer zeroCopyBuffer = ByteBuffer.wrap(data); + SendBuilder builder = new SendBuilder(8); + + builder.writeInt(5); + builder.writeByteBuffer(zeroCopyBuffer); + builder.writeInt(15); + Send send = builder.build(); + + // Overwrite the original buffer in order to prove the data was not copied + byte[] overwrittenData = Utils.utf8("bar"); + assertEquals(data.length, overwrittenData.length); + zeroCopyBuffer.rewind(); + zeroCopyBuffer.put(overwrittenData); + zeroCopyBuffer.rewind(); + + ByteBuffer buffer = TestUtils.toBuffer(send); + assertEquals(8 + data.length, buffer.remaining()); + assertEquals(5, buffer.getInt()); + assertEquals("bar", getString(buffer, data.length)); + assertEquals(15, buffer.getInt()); + } + + @Test + public void testWriteByteBufferRespectsPosition() { + byte[] data = Utils.utf8("yolo"); + assertEquals(4, data.length); + + ByteBuffer buffer = ByteBuffer.wrap(data); + SendBuilder builder = new SendBuilder(0); + + buffer.limit(2); + builder.writeByteBuffer(buffer); + assertEquals(0, buffer.position()); + + buffer.position(2); + buffer.limit(4); + builder.writeByteBuffer(buffer); + assertEquals(2, buffer.position()); + + Send send = builder.build(); + ByteBuffer readBuffer = TestUtils.toBuffer(send); + assertEquals("yolo", getString(readBuffer, 4)); + } + + @Test + public void testZeroCopyRecords() { + ByteBuffer buffer = ByteBuffer.allocate(128); + MemoryRecords records = createRecords(buffer, "foo"); + + SendBuilder builder = new SendBuilder(8); + builder.writeInt(5); + builder.writeRecords(records); + builder.writeInt(15); + Send send = builder.build(); + + // Overwrite the original buffer in order to prove the data was not copied + buffer.rewind(); + MemoryRecords overwrittenRecords = createRecords(buffer, "bar"); + + ByteBuffer readBuffer = TestUtils.toBuffer(send); + assertEquals(5, readBuffer.getInt()); + assertEquals(overwrittenRecords, getRecords(readBuffer, records.sizeInBytes())); + assertEquals(15, readBuffer.getInt()); + } + + @Test + public void testZeroCopyUnalignedRecords() { + ByteBuffer buffer = ByteBuffer.allocate(128); + MemoryRecords records = createRecords(buffer, "foo"); + + ByteBuffer buffer1 = records.buffer().duplicate(); + buffer1.limit(buffer1.limit() / 2); + + ByteBuffer buffer2 = records.buffer().duplicate(); + buffer2.position(buffer2.limit() / 2); + + UnalignedMemoryRecords records1 = new UnalignedMemoryRecords(buffer1); + UnalignedMemoryRecords records2 = new UnalignedMemoryRecords(buffer2); + + SendBuilder builder = new SendBuilder(8); + builder.writeInt(5); + builder.writeRecords(records1); + builder.writeRecords(records2); + builder.writeInt(15); + Send send = builder.build(); + + // Overwrite the original buffer in order to prove the data was not copied + 
buffer.rewind(); + MemoryRecords overwrittenRecords = createRecords(buffer, "bar"); + + ByteBuffer readBuffer = TestUtils.toBuffer(send); + assertEquals(5, readBuffer.getInt()); + assertEquals(overwrittenRecords, getRecords(readBuffer, records.sizeInBytes())); + assertEquals(15, readBuffer.getInt()); + } + + + private String getString(ByteBuffer buffer, int size) { + byte[] readData = new byte[size]; + buffer.get(readData); + return Utils.utf8(readData); + } + + private MemoryRecords getRecords(ByteBuffer buffer, int size) { + int initialPosition = buffer.position(); + int initialLimit = buffer.limit(); + int recordsLimit = initialPosition + size; + + buffer.limit(recordsLimit); + MemoryRecords records = MemoryRecords.readableRecords(buffer.slice()); + + buffer.position(recordsLimit); + buffer.limit(initialLimit); + return records; + } + + private MemoryRecords createRecords(ByteBuffer buffer, String value) { + MemoryRecordsBuilder recordsBuilder = MemoryRecords.builder( + buffer, + CompressionType.NONE, + TimestampType.CREATE_TIME, + 0L + ); + recordsBuilder.append(new SimpleRecord(Utils.utf8(value))); + return recordsBuilder.build(); + } + +} diff --git a/clients/src/test/java/org/apache/kafka/common/protocol/types/ProtocolSerializationTest.java b/clients/src/test/java/org/apache/kafka/common/protocol/types/ProtocolSerializationTest.java new file mode 100644 index 0000000..811bcf4 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/protocol/types/ProtocolSerializationTest.java @@ -0,0 +1,437 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
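For illustration only (not part of the patch), the "overwrite the original buffer" technique used in the zero-copy SendBuilder tests above generalizes to any API that is expected to retain a reference rather than take a copy: mutate the caller's bytes after handing them over and check that the retained view observes the change. The sketch below uses only plain java.nio, and the variable names are hypothetical.

import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;

// Editorial sketch: proving zero-copy by mutating the backing array after handing it over.
// If the consumer had copied the bytes, the later mutation would not be visible through it.
public class ZeroCopyProofSketch {

    public static void main(String[] args) {
        byte[] backing = "foo".getBytes(StandardCharsets.UTF_8);
        ByteBuffer handedOver = ByteBuffer.wrap(backing);   // wrap() shares the array; no copy is made

        // Stand-in for whatever the consumer retained, e.g. a duplicate view of the buffer.
        ByteBuffer retainedView = handedOver.duplicate();

        // Mutate the original bytes after the "build" step.
        byte[] overwrite = "bar".getBytes(StandardCharsets.UTF_8);
        System.arraycopy(overwrite, 0, backing, 0, overwrite.length);

        byte[] observed = new byte[backing.length];
        retainedView.get(observed);
        // Prints "bar": the retained view reflects the mutation, proving no copy was taken.
        System.out.println(new String(observed, StandardCharsets.UTF_8));
    }
}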
+ */ +package org.apache.kafka.common.protocol.types; + +import org.apache.kafka.common.utils.ByteUtils; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.nio.ByteBuffer; +import java.util.Arrays; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; + +public class ProtocolSerializationTest { + + private Schema schema; + private Struct struct; + + @BeforeEach + public void setup() { + this.schema = new Schema(new Field("boolean", Type.BOOLEAN), + new Field("int8", Type.INT8), + new Field("int16", Type.INT16), + new Field("int32", Type.INT32), + new Field("int64", Type.INT64), + new Field("varint", Type.VARINT), + new Field("varlong", Type.VARLONG), + new Field("float64", Type.FLOAT64), + new Field("string", Type.STRING), + new Field("compact_string", Type.COMPACT_STRING), + new Field("nullable_string", Type.NULLABLE_STRING), + new Field("compact_nullable_string", Type.COMPACT_NULLABLE_STRING), + new Field("bytes", Type.BYTES), + new Field("compact_bytes", Type.COMPACT_BYTES), + new Field("nullable_bytes", Type.NULLABLE_BYTES), + new Field("compact_nullable_bytes", Type.COMPACT_NULLABLE_BYTES), + new Field("array", new ArrayOf(Type.INT32)), + new Field("compact_array", new CompactArrayOf(Type.INT32)), + new Field("null_array", ArrayOf.nullable(Type.INT32)), + new Field("compact_null_array", CompactArrayOf.nullable(Type.INT32)), + new Field("struct", new Schema(new Field("field", new ArrayOf(Type.INT32))))); + this.struct = new Struct(this.schema).set("boolean", true) + .set("int8", (byte) 1) + .set("int16", (short) 1) + .set("int32", 1) + .set("int64", 1L) + .set("varint", 300) + .set("varlong", 500L) + .set("float64", 0.5D) + .set("string", "1") + .set("compact_string", "1") + .set("nullable_string", null) + .set("compact_nullable_string", null) + .set("bytes", ByteBuffer.wrap("1".getBytes())) + .set("compact_bytes", ByteBuffer.wrap("1".getBytes())) + .set("nullable_bytes", null) + .set("compact_nullable_bytes", null) + .set("array", new Object[] {1}) + .set("compact_array", new Object[] {1}) + .set("null_array", null) + .set("compact_null_array", null); + this.struct.set("struct", this.struct.instance("struct").set("field", new Object[] {1, 2, 3})); + } + + @Test + public void testSimple() { + check(Type.BOOLEAN, false, "BOOLEAN"); + check(Type.BOOLEAN, true, "BOOLEAN"); + check(Type.INT8, (byte) -111, "INT8"); + check(Type.INT16, (short) -11111, "INT16"); + check(Type.INT32, -11111111, "INT32"); + check(Type.INT64, -11111111111L, "INT64"); + check(Type.FLOAT64, 2.5, "FLOAT64"); + check(Type.FLOAT64, -0.5, "FLOAT64"); + check(Type.FLOAT64, 1e300, "FLOAT64"); + check(Type.FLOAT64, 0.0, "FLOAT64"); + check(Type.FLOAT64, -0.0, "FLOAT64"); + check(Type.FLOAT64, Double.MAX_VALUE, "FLOAT64"); + check(Type.FLOAT64, Double.MIN_VALUE, "FLOAT64"); + check(Type.FLOAT64, Double.NaN, "FLOAT64"); + check(Type.FLOAT64, Double.NEGATIVE_INFINITY, "FLOAT64"); + check(Type.FLOAT64, Double.POSITIVE_INFINITY, "FLOAT64"); + check(Type.STRING, "", "STRING"); + check(Type.STRING, "hello", "STRING"); + check(Type.STRING, 
"A\u00ea\u00f1\u00fcC", "STRING"); + check(Type.COMPACT_STRING, "", "COMPACT_STRING"); + check(Type.COMPACT_STRING, "hello", "COMPACT_STRING"); + check(Type.COMPACT_STRING, "A\u00ea\u00f1\u00fcC", "COMPACT_STRING"); + check(Type.NULLABLE_STRING, null, "NULLABLE_STRING"); + check(Type.NULLABLE_STRING, "", "NULLABLE_STRING"); + check(Type.NULLABLE_STRING, "hello", "NULLABLE_STRING"); + check(Type.COMPACT_NULLABLE_STRING, null, "COMPACT_NULLABLE_STRING"); + check(Type.COMPACT_NULLABLE_STRING, "", "COMPACT_NULLABLE_STRING"); + check(Type.COMPACT_NULLABLE_STRING, "hello", "COMPACT_NULLABLE_STRING"); + check(Type.BYTES, ByteBuffer.allocate(0), "BYTES"); + check(Type.BYTES, ByteBuffer.wrap("abcd".getBytes()), "BYTES"); + check(Type.COMPACT_BYTES, ByteBuffer.allocate(0), "COMPACT_BYTES"); + check(Type.COMPACT_BYTES, ByteBuffer.wrap("abcd".getBytes()), "COMPACT_BYTES"); + check(Type.NULLABLE_BYTES, null, "NULLABLE_BYTES"); + check(Type.NULLABLE_BYTES, ByteBuffer.allocate(0), "NULLABLE_BYTES"); + check(Type.NULLABLE_BYTES, ByteBuffer.wrap("abcd".getBytes()), "NULLABLE_BYTES"); + check(Type.COMPACT_NULLABLE_BYTES, null, "COMPACT_NULLABLE_BYTES"); + check(Type.COMPACT_NULLABLE_BYTES, ByteBuffer.allocate(0), "COMPACT_NULLABLE_BYTES"); + check(Type.COMPACT_NULLABLE_BYTES, ByteBuffer.wrap("abcd".getBytes()), + "COMPACT_NULLABLE_BYTES"); + check(Type.VARINT, Integer.MAX_VALUE, "VARINT"); + check(Type.VARINT, Integer.MIN_VALUE, "VARINT"); + check(Type.VARLONG, Long.MAX_VALUE, "VARLONG"); + check(Type.VARLONG, Long.MIN_VALUE, "VARLONG"); + check(new ArrayOf(Type.INT32), new Object[] {1, 2, 3, 4}, "ARRAY(INT32)"); + check(new ArrayOf(Type.STRING), new Object[] {}, "ARRAY(STRING)"); + check(new ArrayOf(Type.STRING), new Object[] {"hello", "there", "beautiful"}, + "ARRAY(STRING)"); + check(new CompactArrayOf(Type.INT32), new Object[] {1, 2, 3, 4}, + "COMPACT_ARRAY(INT32)"); + check(new CompactArrayOf(Type.COMPACT_STRING), new Object[] {}, + "COMPACT_ARRAY(COMPACT_STRING)"); + check(new CompactArrayOf(Type.COMPACT_STRING), + new Object[] {"hello", "there", "beautiful"}, + "COMPACT_ARRAY(COMPACT_STRING)"); + check(ArrayOf.nullable(Type.STRING), null, "ARRAY(STRING)"); + check(CompactArrayOf.nullable(Type.COMPACT_STRING), null, + "COMPACT_ARRAY(COMPACT_STRING)"); + } + + @Test + public void testNulls() { + for (BoundField f : this.schema.fields()) { + Object o = this.struct.get(f); + try { + this.struct.set(f, null); + this.struct.validate(); + if (!f.def.type.isNullable()) + fail("Should not allow serialization of null value."); + } catch (SchemaException e) { + assertFalse(f.def.type.isNullable(), f.toString() + " should not be nullable"); + } finally { + this.struct.set(f, o); + } + } + } + + @Test + public void testDefault() { + Schema schema = new Schema(new Field("field", Type.INT32, "doc", 42)); + Struct struct = new Struct(schema); + assertEquals(42, struct.get("field"), "Should get the default value"); + struct.validate(); // should be valid even with missing value + } + + @Test + public void testNullableDefault() { + checkNullableDefault(Type.NULLABLE_BYTES, ByteBuffer.allocate(0)); + checkNullableDefault(Type.COMPACT_NULLABLE_BYTES, ByteBuffer.allocate(0)); + checkNullableDefault(Type.NULLABLE_STRING, "default"); + checkNullableDefault(Type.COMPACT_NULLABLE_STRING, "default"); + } + + private void checkNullableDefault(Type type, Object defaultValue) { + // Should use default even if the field allows null values + Schema schema = new Schema(new Field("field", type, "doc", defaultValue)); + Struct struct 
= new Struct(schema); + assertEquals(defaultValue, struct.get("field"), "Should get the default value"); + struct.validate(); // should be valid even with missing value + } + + @Test + public void testReadArraySizeTooLarge() { + Type type = new ArrayOf(Type.INT8); + int size = 10; + ByteBuffer invalidBuffer = ByteBuffer.allocate(4 + size); + invalidBuffer.putInt(Integer.MAX_VALUE); + for (int i = 0; i < size; i++) + invalidBuffer.put((byte) i); + invalidBuffer.rewind(); + try { + type.read(invalidBuffer); + fail("Array size not validated"); + } catch (SchemaException e) { + // Expected exception + } + } + + @Test + public void testReadCompactArraySizeTooLarge() { + Type type = new CompactArrayOf(Type.INT8); + int size = 10; + ByteBuffer invalidBuffer = ByteBuffer.allocate( + ByteUtils.sizeOfUnsignedVarint(Integer.MAX_VALUE) + size); + ByteUtils.writeUnsignedVarint(Integer.MAX_VALUE, invalidBuffer); + for (int i = 0; i < size; i++) + invalidBuffer.put((byte) i); + invalidBuffer.rewind(); + try { + type.read(invalidBuffer); + fail("Array size not validated"); + } catch (SchemaException e) { + // Expected exception + } + } + + @Test + public void testReadNegativeArraySize() { + Type type = new ArrayOf(Type.INT8); + int size = 10; + ByteBuffer invalidBuffer = ByteBuffer.allocate(4 + size); + invalidBuffer.putInt(-1); + for (int i = 0; i < size; i++) + invalidBuffer.put((byte) i); + invalidBuffer.rewind(); + try { + type.read(invalidBuffer); + fail("Array size not validated"); + } catch (SchemaException e) { + // Expected exception + } + } + + @Test + public void testReadZeroCompactArraySize() { + Type type = new CompactArrayOf(Type.INT8); + int size = 10; + ByteBuffer invalidBuffer = ByteBuffer.allocate( + ByteUtils.sizeOfUnsignedVarint(0) + size); + ByteUtils.writeUnsignedVarint(0, invalidBuffer); + for (int i = 0; i < size; i++) + invalidBuffer.put((byte) i); + invalidBuffer.rewind(); + try { + type.read(invalidBuffer); + fail("Array size not validated"); + } catch (SchemaException e) { + // Expected exception + } + } + + @Test + public void testReadStringSizeTooLarge() { + byte[] stringBytes = "foo".getBytes(); + ByteBuffer invalidBuffer = ByteBuffer.allocate(2 + stringBytes.length); + invalidBuffer.putShort((short) (stringBytes.length * 5)); + invalidBuffer.put(stringBytes); + invalidBuffer.rewind(); + try { + Type.STRING.read(invalidBuffer); + fail("String size not validated"); + } catch (SchemaException e) { + // Expected exception + } + invalidBuffer.rewind(); + try { + Type.NULLABLE_STRING.read(invalidBuffer); + fail("String size not validated"); + } catch (SchemaException e) { + // Expected exception + } + } + + @Test + public void testReadNegativeStringSize() { + byte[] stringBytes = "foo".getBytes(); + ByteBuffer invalidBuffer = ByteBuffer.allocate(2 + stringBytes.length); + invalidBuffer.putShort((short) -1); + invalidBuffer.put(stringBytes); + invalidBuffer.rewind(); + try { + Type.STRING.read(invalidBuffer); + fail("String size not validated"); + } catch (SchemaException e) { + // Expected exception + } + } + + @Test + public void testReadBytesSizeTooLarge() { + byte[] stringBytes = "foo".getBytes(); + ByteBuffer invalidBuffer = ByteBuffer.allocate(4 + stringBytes.length); + invalidBuffer.putInt(stringBytes.length * 5); + invalidBuffer.put(stringBytes); + invalidBuffer.rewind(); + try { + Type.BYTES.read(invalidBuffer); + fail("Bytes size not validated"); + } catch (SchemaException e) { + // Expected exception + } + invalidBuffer.rewind(); + try { + 
Type.NULLABLE_BYTES.read(invalidBuffer); + fail("Bytes size not validated"); + } catch (SchemaException e) { + // Expected exception + } + } + + @Test + public void testReadNegativeBytesSize() { + byte[] stringBytes = "foo".getBytes(); + ByteBuffer invalidBuffer = ByteBuffer.allocate(4 + stringBytes.length); + invalidBuffer.putInt(-20); + invalidBuffer.put(stringBytes); + invalidBuffer.rewind(); + try { + Type.BYTES.read(invalidBuffer); + fail("Bytes size not validated"); + } catch (SchemaException e) { + // Expected exception + } + } + + @Test + public void testToString() { + String structStr = this.struct.toString(); + assertNotNull(structStr, "Struct string should not be null."); + assertFalse(structStr.isEmpty(), "Struct string should not be empty."); + } + + private Object roundtrip(Type type, Object obj) { + ByteBuffer buffer = ByteBuffer.allocate(type.sizeOf(obj)); + type.write(buffer, obj); + assertFalse(buffer.hasRemaining(), "The buffer should now be full."); + buffer.rewind(); + Object read = type.read(buffer); + assertFalse(buffer.hasRemaining(), "All bytes should have been read."); + return read; + } + + private void check(Type type, Object obj, String expectedTypeName) { + Object result = roundtrip(type, obj); + if (obj instanceof Object[]) { + obj = Arrays.asList((Object[]) obj); + result = Arrays.asList((Object[]) result); + } + assertEquals(expectedTypeName, type.toString()); + assertEquals(obj, result, "The object read back should be the same as what was written."); + } + + @Test + public void testStructEquals() { + Schema schema = new Schema(new Field("field1", Type.NULLABLE_STRING), new Field("field2", Type.NULLABLE_STRING)); + Struct emptyStruct1 = new Struct(schema); + Struct emptyStruct2 = new Struct(schema); + assertEquals(emptyStruct1, emptyStruct2); + + Struct mostlyEmptyStruct = new Struct(schema).set("field1", "foo"); + assertNotEquals(emptyStruct1, mostlyEmptyStruct); + assertNotEquals(mostlyEmptyStruct, emptyStruct1); + } + + @Test + public void testReadIgnoringExtraDataAtTheEnd() { + Schema oldSchema = new Schema(new Field("field1", Type.NULLABLE_STRING), new Field("field2", Type.NULLABLE_STRING)); + Schema newSchema = new Schema(new Field("field1", Type.NULLABLE_STRING)); + String value = "foo bar baz"; + Struct oldFormat = new Struct(oldSchema).set("field1", value).set("field2", "fine to ignore"); + ByteBuffer buffer = ByteBuffer.allocate(oldSchema.sizeOf(oldFormat)); + oldFormat.writeTo(buffer); + buffer.flip(); + Struct newFormat = newSchema.read(buffer); + assertEquals(value, newFormat.get("field1")); + } + + @Test + public void testReadWhenOptionalDataMissingAtTheEndIsTolerated() { + Schema oldSchema = new Schema(new Field("field1", Type.NULLABLE_STRING)); + Schema newSchema = new Schema( + true, + new Field("field1", Type.NULLABLE_STRING), + new Field("field2", Type.NULLABLE_STRING, "", true, "default"), + new Field("field3", Type.NULLABLE_STRING, "", true, null), + new Field("field4", Type.NULLABLE_BYTES, "", true, ByteBuffer.allocate(0)), + new Field("field5", Type.INT64, "doc", true, Long.MAX_VALUE)); + String value = "foo bar baz"; + Struct oldFormat = new Struct(oldSchema).set("field1", value); + ByteBuffer buffer = ByteBuffer.allocate(oldSchema.sizeOf(oldFormat)); + oldFormat.writeTo(buffer); + buffer.flip(); + Struct newFormat = newSchema.read(buffer); + assertEquals(value, newFormat.get("field1")); + assertEquals("default", newFormat.get("field2")); + assertNull(newFormat.get("field3")); + assertEquals(ByteBuffer.allocate(0), 
newFormat.get("field4")); + assertEquals(Long.MAX_VALUE, newFormat.get("field5")); + } + + @Test + public void testReadWhenOptionalDataMissingAtTheEndIsNotTolerated() { + Schema oldSchema = new Schema(new Field("field1", Type.NULLABLE_STRING)); + Schema newSchema = new Schema( + new Field("field1", Type.NULLABLE_STRING), + new Field("field2", Type.NULLABLE_STRING, "", true, "default")); + String value = "foo bar baz"; + Struct oldFormat = new Struct(oldSchema).set("field1", value); + ByteBuffer buffer = ByteBuffer.allocate(oldSchema.sizeOf(oldFormat)); + oldFormat.writeTo(buffer); + buffer.flip(); + SchemaException e = assertThrows(SchemaException.class, () -> newSchema.read(buffer)); + assertTrue(e.getMessage().contains("Error reading field 'field2':")); + } + + @Test + public void testReadWithMissingNonOptionalExtraDataAtTheEnd() { + Schema oldSchema = new Schema(new Field("field1", Type.NULLABLE_STRING)); + Schema newSchema = new Schema( + true, + new Field("field1", Type.NULLABLE_STRING), + new Field("field2", Type.NULLABLE_STRING)); + String value = "foo bar baz"; + Struct oldFormat = new Struct(oldSchema).set("field1", value); + ByteBuffer buffer = ByteBuffer.allocate(oldSchema.sizeOf(oldFormat)); + oldFormat.writeTo(buffer); + buffer.flip(); + SchemaException e = assertThrows(SchemaException.class, () -> newSchema.read(buffer)); + assertTrue(e.getMessage().contains("Missing value for field 'field2' which has no default value")); + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/protocol/types/RawTaggedFieldWriterTest.java b/clients/src/test/java/org/apache/kafka/common/protocol/types/RawTaggedFieldWriterTest.java new file mode 100644 index 0000000..bd7f208 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/protocol/types/RawTaggedFieldWriterTest.java @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.common.protocol.types; + +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; + +import java.nio.ByteBuffer; +import java.util.Arrays; +import java.util.List; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.fail; + +@Timeout(120) +public class RawTaggedFieldWriterTest { + + @Test + public void testWritingZeroRawTaggedFields() { + RawTaggedFieldWriter writer = RawTaggedFieldWriter.forFields(null); + assertEquals(0, writer.numFields()); + ByteBufferAccessor accessor = new ByteBufferAccessor(ByteBuffer.allocate(0)); + writer.writeRawTags(accessor, Integer.MAX_VALUE); + } + + @Test + public void testWritingSeveralRawTaggedFields() { + List tags = Arrays.asList( + new RawTaggedField(2, new byte[] {0x1, 0x2, 0x3}), + new RawTaggedField(5, new byte[] {0x4, 0x5}) + ); + RawTaggedFieldWriter writer = RawTaggedFieldWriter.forFields(tags); + assertEquals(2, writer.numFields()); + byte[] arr = new byte[9]; + ByteBufferAccessor accessor = new ByteBufferAccessor(ByteBuffer.wrap(arr)); + writer.writeRawTags(accessor, 1); + assertArrayEquals(new byte[] {0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0}, arr); + writer.writeRawTags(accessor, 3); + assertArrayEquals(new byte[] {0x2, 0x3, 0x1, 0x2, 0x3, 0x0, 0x0, 0x0, 0x0}, arr); + writer.writeRawTags(accessor, 7); + assertArrayEquals(new byte[] {0x2, 0x3, 0x1, 0x2, 0x3, 0x5, 0x2, 0x4, 0x5}, arr); + writer.writeRawTags(accessor, Integer.MAX_VALUE); + assertArrayEquals(new byte[] {0x2, 0x3, 0x1, 0x2, 0x3, 0x5, 0x2, 0x4, 0x5}, arr); + } + + @Test + public void testInvalidNextDefinedTag() { + List tags = Arrays.asList( + new RawTaggedField(2, new byte[] {0x1, 0x2, 0x3}), + new RawTaggedField(5, new byte[] {0x4, 0x5, 0x6}), + new RawTaggedField(7, new byte[] {0x0}) + ); + RawTaggedFieldWriter writer = RawTaggedFieldWriter.forFields(tags); + assertEquals(3, writer.numFields()); + try { + writer.writeRawTags(new ByteBufferAccessor(ByteBuffer.allocate(1024)), 2); + fail("expected to get RuntimeException"); + } catch (RuntimeException e) { + assertEquals("Attempted to use tag 2 as an undefined tag.", e.getMessage()); + } + } + + @Test + public void testOutOfOrderTags() { + List tags = Arrays.asList( + new RawTaggedField(5, new byte[] {0x4, 0x5, 0x6}), + new RawTaggedField(2, new byte[] {0x1, 0x2, 0x3}), + new RawTaggedField(7, new byte[] {0x0 }) + ); + RawTaggedFieldWriter writer = RawTaggedFieldWriter.forFields(tags); + assertEquals(3, writer.numFields()); + try { + writer.writeRawTags(new ByteBufferAccessor(ByteBuffer.allocate(1024)), 8); + fail("expected to get RuntimeException"); + } catch (RuntimeException e) { + assertEquals("Invalid raw tag field list: tag 2 comes after tag 5, but is " + + "not higher than it.", e.getMessage()); + } + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/protocol/types/StructTest.java b/clients/src/test/java/org/apache/kafka/common/protocol/types/StructTest.java new file mode 100644 index 0000000..e76022a --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/protocol/types/StructTest.java @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.common.protocol.types; + +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotEquals; + +public class StructTest { + private static final Schema FLAT_STRUCT_SCHEMA = new Schema( + new Field.Int8("int8", ""), + new Field.Int16("int16", ""), + new Field.Int32("int32", ""), + new Field.Int64("int64", ""), + new Field.Bool("boolean", ""), + new Field.Float64("float64", ""), + new Field.Str("string", "")); + + private static final Schema ARRAY_SCHEMA = new Schema(new Field.Array("array", new ArrayOf(Type.INT8), "")); + private static final Schema NESTED_CHILD_SCHEMA = new Schema( + new Field.Int8("int8", "")); + private static final Schema NESTED_SCHEMA = new Schema( + new Field.Array("array", ARRAY_SCHEMA, ""), + new Field("nested", NESTED_CHILD_SCHEMA, "")); + + @Test + public void testEquals() { + Struct struct1 = new Struct(FLAT_STRUCT_SCHEMA) + .set("int8", (byte) 12) + .set("int16", (short) 12) + .set("int32", 12) + .set("int64", (long) 12) + .set("boolean", true) + .set("float64", 0.5) + .set("string", "foobar"); + Struct struct2 = new Struct(FLAT_STRUCT_SCHEMA) + .set("int8", (byte) 12) + .set("int16", (short) 12) + .set("int32", 12) + .set("int64", (long) 12) + .set("boolean", true) + .set("float64", 0.5) + .set("string", "foobar"); + Struct struct3 = new Struct(FLAT_STRUCT_SCHEMA) + .set("int8", (byte) 12) + .set("int16", (short) 12) + .set("int32", 12) + .set("int64", (long) 12) + .set("boolean", true) + .set("float64", 0.5) + .set("string", "mismatching string"); + + assertEquals(struct1, struct2); + assertNotEquals(struct1, struct3); + + Object[] array = {(byte) 1, (byte) 2}; + struct1 = new Struct(NESTED_SCHEMA) + .set("array", array) + .set("nested", new Struct(NESTED_CHILD_SCHEMA).set("int8", (byte) 12)); + Object[] array2 = {(byte) 1, (byte) 2}; + struct2 = new Struct(NESTED_SCHEMA) + .set("array", array2) + .set("nested", new Struct(NESTED_CHILD_SCHEMA).set("int8", (byte) 12)); + Object[] array3 = {(byte) 1, (byte) 2, (byte) 3}; + struct3 = new Struct(NESTED_SCHEMA) + .set("array", array3) + .set("nested", new Struct(NESTED_CHILD_SCHEMA).set("int8", (byte) 13)); + + assertEquals(struct1, struct2); + assertNotEquals(struct1, struct3); + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/protocol/types/TypeTest.java b/clients/src/test/java/org/apache/kafka/common/protocol/types/TypeTest.java new file mode 100644 index 0000000..3a4b651 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/protocol/types/TypeTest.java @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.protocol.types; + +import org.apache.kafka.common.record.CompressionType; +import org.apache.kafka.common.record.MemoryRecords; +import org.apache.kafka.common.record.SimpleRecord; +import org.junit.jupiter.api.Test; + +import java.nio.ByteBuffer; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNull; + +public class TypeTest { + + @Test + public void testEmptyRecordsSerde() { + ByteBuffer buffer = ByteBuffer.allocate(4); + Type.RECORDS.write(buffer, MemoryRecords.EMPTY); + buffer.flip(); + assertEquals(4, Type.RECORDS.sizeOf(MemoryRecords.EMPTY)); + assertEquals(4, buffer.limit()); + assertEquals(MemoryRecords.EMPTY, Type.RECORDS.read(buffer)); + } + + @Test + public void testNullRecordsSerde() { + ByteBuffer buffer = ByteBuffer.allocate(4); + Type.RECORDS.write(buffer, null); + buffer.flip(); + assertEquals(4, Type.RECORDS.sizeOf(MemoryRecords.EMPTY)); + assertEquals(4, buffer.limit()); + assertNull(Type.RECORDS.read(buffer)); + } + + @Test + public void testRecordsSerde() { + MemoryRecords records = MemoryRecords.withRecords(CompressionType.NONE, + new SimpleRecord("foo".getBytes()), + new SimpleRecord("bar".getBytes())); + ByteBuffer buffer = ByteBuffer.allocate(Type.RECORDS.sizeOf(records)); + Type.RECORDS.write(buffer, records); + buffer.flip(); + assertEquals(records, Type.RECORDS.read(buffer)); + } + + @Test + public void testEmptyCompactRecordsSerde() { + ByteBuffer buffer = ByteBuffer.allocate(4); + Type.COMPACT_RECORDS.write(buffer, MemoryRecords.EMPTY); + buffer.flip(); + assertEquals(1, Type.COMPACT_RECORDS.sizeOf(MemoryRecords.EMPTY)); + assertEquals(1, buffer.limit()); + assertEquals(MemoryRecords.EMPTY, Type.COMPACT_RECORDS.read(buffer)); + } + + @Test + public void testNullCompactRecordsSerde() { + ByteBuffer buffer = ByteBuffer.allocate(4); + Type.COMPACT_RECORDS.write(buffer, null); + buffer.flip(); + assertEquals(1, Type.COMPACT_RECORDS.sizeOf(MemoryRecords.EMPTY)); + assertEquals(1, buffer.limit()); + assertNull(Type.COMPACT_RECORDS.read(buffer)); + } + + @Test + public void testCompactRecordsSerde() { + MemoryRecords records = MemoryRecords.withRecords(CompressionType.NONE, + new SimpleRecord("foo".getBytes()), + new SimpleRecord("bar".getBytes())); + ByteBuffer buffer = ByteBuffer.allocate(Type.COMPACT_RECORDS.sizeOf(records)); + Type.COMPACT_RECORDS.write(buffer, records); + buffer.flip(); + assertEquals(records, Type.COMPACT_RECORDS.read(buffer)); + } + +} diff --git a/clients/src/test/java/org/apache/kafka/common/record/AbstractLegacyRecordBatchTest.java b/clients/src/test/java/org/apache/kafka/common/record/AbstractLegacyRecordBatchTest.java new file mode 100644 index 0000000..a6a64fc --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/record/AbstractLegacyRecordBatchTest.java @@ -0,0 +1,251 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more 
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.common.record;
+
+import org.apache.kafka.common.InvalidRecordException;
+import org.apache.kafka.common.record.AbstractLegacyRecordBatch.ByteBufferLegacyRecordBatch;
+import org.apache.kafka.common.utils.Utils;
+import org.junit.jupiter.api.Test;
+
+import java.util.Arrays;
+import java.util.List;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+import static org.junit.jupiter.api.Assertions.fail;
+
+public class AbstractLegacyRecordBatchTest {
+
+    @Test
+    public void testSetLastOffsetCompressed() {
+        SimpleRecord[] simpleRecords = new SimpleRecord[] {
+            new SimpleRecord(1L, "a".getBytes(), "1".getBytes()),
+            new SimpleRecord(2L, "b".getBytes(), "2".getBytes()),
+            new SimpleRecord(3L, "c".getBytes(), "3".getBytes())
+        };
+
+        MemoryRecords records = MemoryRecords.withRecords(RecordBatch.MAGIC_VALUE_V1, 0L,
+            CompressionType.GZIP, TimestampType.CREATE_TIME, simpleRecords);
+
+        long lastOffset = 500L;
+        long firstOffset = lastOffset - simpleRecords.length + 1;
+
+        ByteBufferLegacyRecordBatch batch = new ByteBufferLegacyRecordBatch(records.buffer());
+        batch.setLastOffset(lastOffset);
+        assertEquals(lastOffset, batch.lastOffset());
+        assertEquals(firstOffset, batch.baseOffset());
+        assertTrue(batch.isValid());
+
+        List<MutableRecordBatch> recordBatches = Utils.toList(records.batches().iterator());
+        assertEquals(1, recordBatches.size());
+        assertEquals(lastOffset, recordBatches.get(0).lastOffset());
+
+        long offset = firstOffset;
+        for (Record record : records.records())
+            assertEquals(offset++, record.offset());
+    }
+
+    /**
+     * The wrapper offset should be 0 in v0, but not in v1. However, the latter worked by accident and some versions of
+     * librdkafka now depend on it. So we support 0 for compatibility reasons, but the recommendation is to set the
+     * wrapper offset to the relative offset of the last record in the batch.
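+     * For example, a compressed wrapper holding three inner records with relative offsets 0, 1 and 2 would use 2 as
+     * the wrapper offset.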
+ */ + @Test + public void testIterateCompressedRecordWithWrapperOffsetZero() { + for (byte magic : Arrays.asList(RecordBatch.MAGIC_VALUE_V0, RecordBatch.MAGIC_VALUE_V1)) { + SimpleRecord[] simpleRecords = new SimpleRecord[] { + new SimpleRecord(1L, "a".getBytes(), "1".getBytes()), + new SimpleRecord(2L, "b".getBytes(), "2".getBytes()), + new SimpleRecord(3L, "c".getBytes(), "3".getBytes()) + }; + + MemoryRecords records = MemoryRecords.withRecords(magic, 0L, + CompressionType.GZIP, TimestampType.CREATE_TIME, simpleRecords); + + ByteBufferLegacyRecordBatch batch = new ByteBufferLegacyRecordBatch(records.buffer()); + batch.setLastOffset(0L); + + long offset = 0L; + for (Record record : batch) + assertEquals(offset++, record.offset()); + } + } + + @Test + public void testInvalidWrapperOffsetV1() { + SimpleRecord[] simpleRecords = new SimpleRecord[] { + new SimpleRecord(1L, "a".getBytes(), "1".getBytes()), + new SimpleRecord(2L, "b".getBytes(), "2".getBytes()), + new SimpleRecord(3L, "c".getBytes(), "3".getBytes()) + }; + + MemoryRecords records = MemoryRecords.withRecords(RecordBatch.MAGIC_VALUE_V1, 0L, + CompressionType.GZIP, TimestampType.CREATE_TIME, simpleRecords); + + ByteBufferLegacyRecordBatch batch = new ByteBufferLegacyRecordBatch(records.buffer()); + batch.setLastOffset(1L); + + assertThrows(InvalidRecordException.class, batch::iterator); + } + + @Test + public void testSetNoTimestampTypeNotAllowed() { + MemoryRecords records = MemoryRecords.withRecords(RecordBatch.MAGIC_VALUE_V1, 0L, + CompressionType.GZIP, TimestampType.CREATE_TIME, + new SimpleRecord(1L, "a".getBytes(), "1".getBytes()), + new SimpleRecord(2L, "b".getBytes(), "2".getBytes()), + new SimpleRecord(3L, "c".getBytes(), "3".getBytes())); + ByteBufferLegacyRecordBatch batch = new ByteBufferLegacyRecordBatch(records.buffer()); + assertThrows(IllegalArgumentException.class, () -> batch.setMaxTimestamp(TimestampType.NO_TIMESTAMP_TYPE, RecordBatch.NO_TIMESTAMP)); + } + + @Test + public void testSetLogAppendTimeNotAllowedV0() { + MemoryRecords records = MemoryRecords.withRecords(RecordBatch.MAGIC_VALUE_V0, 0L, + CompressionType.GZIP, TimestampType.CREATE_TIME, + new SimpleRecord(1L, "a".getBytes(), "1".getBytes()), + new SimpleRecord(2L, "b".getBytes(), "2".getBytes()), + new SimpleRecord(3L, "c".getBytes(), "3".getBytes())); + long logAppendTime = 15L; + ByteBufferLegacyRecordBatch batch = new ByteBufferLegacyRecordBatch(records.buffer()); + assertThrows(UnsupportedOperationException.class, () -> batch.setMaxTimestamp(TimestampType.LOG_APPEND_TIME, logAppendTime)); + } + + @Test + public void testSetCreateTimeNotAllowedV0() { + MemoryRecords records = MemoryRecords.withRecords(RecordBatch.MAGIC_VALUE_V0, 0L, + CompressionType.GZIP, TimestampType.CREATE_TIME, + new SimpleRecord(1L, "a".getBytes(), "1".getBytes()), + new SimpleRecord(2L, "b".getBytes(), "2".getBytes()), + new SimpleRecord(3L, "c".getBytes(), "3".getBytes())); + long createTime = 15L; + ByteBufferLegacyRecordBatch batch = new ByteBufferLegacyRecordBatch(records.buffer()); + assertThrows(UnsupportedOperationException.class, () -> batch.setMaxTimestamp(TimestampType.CREATE_TIME, createTime)); + } + + @Test + public void testSetPartitionLeaderEpochNotAllowedV0() { + MemoryRecords records = MemoryRecords.withRecords(RecordBatch.MAGIC_VALUE_V0, 0L, + CompressionType.GZIP, TimestampType.CREATE_TIME, + new SimpleRecord(1L, "a".getBytes(), "1".getBytes()), + new SimpleRecord(2L, "b".getBytes(), "2".getBytes()), + new SimpleRecord(3L, "c".getBytes(), "3".getBytes())); 
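+        // the partition leader epoch field was only added in message format v2, so legacy (v0/v1) batches cannot set it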
+ ByteBufferLegacyRecordBatch batch = new ByteBufferLegacyRecordBatch(records.buffer()); + assertThrows(UnsupportedOperationException.class, () -> batch.setPartitionLeaderEpoch(15)); + } + + @Test + public void testSetPartitionLeaderEpochNotAllowedV1() { + MemoryRecords records = MemoryRecords.withRecords(RecordBatch.MAGIC_VALUE_V1, 0L, + CompressionType.GZIP, TimestampType.CREATE_TIME, + new SimpleRecord(1L, "a".getBytes(), "1".getBytes()), + new SimpleRecord(2L, "b".getBytes(), "2".getBytes()), + new SimpleRecord(3L, "c".getBytes(), "3".getBytes())); + ByteBufferLegacyRecordBatch batch = new ByteBufferLegacyRecordBatch(records.buffer()); + assertThrows(UnsupportedOperationException.class, () -> batch.setPartitionLeaderEpoch(15)); + } + + @Test + public void testSetLogAppendTimeV1() { + MemoryRecords records = MemoryRecords.withRecords(RecordBatch.MAGIC_VALUE_V1, 0L, + CompressionType.GZIP, TimestampType.CREATE_TIME, + new SimpleRecord(1L, "a".getBytes(), "1".getBytes()), + new SimpleRecord(2L, "b".getBytes(), "2".getBytes()), + new SimpleRecord(3L, "c".getBytes(), "3".getBytes())); + + long logAppendTime = 15L; + + ByteBufferLegacyRecordBatch batch = new ByteBufferLegacyRecordBatch(records.buffer()); + batch.setMaxTimestamp(TimestampType.LOG_APPEND_TIME, logAppendTime); + assertEquals(TimestampType.LOG_APPEND_TIME, batch.timestampType()); + assertEquals(logAppendTime, batch.maxTimestamp()); + assertTrue(batch.isValid()); + + List recordBatches = Utils.toList(records.batches().iterator()); + assertEquals(1, recordBatches.size()); + assertEquals(TimestampType.LOG_APPEND_TIME, recordBatches.get(0).timestampType()); + assertEquals(logAppendTime, recordBatches.get(0).maxTimestamp()); + + for (Record record : records.records()) + assertEquals(logAppendTime, record.timestamp()); + } + + @Test + public void testSetCreateTimeV1() { + MemoryRecords records = MemoryRecords.withRecords(RecordBatch.MAGIC_VALUE_V1, 0L, + CompressionType.GZIP, TimestampType.CREATE_TIME, + new SimpleRecord(1L, "a".getBytes(), "1".getBytes()), + new SimpleRecord(2L, "b".getBytes(), "2".getBytes()), + new SimpleRecord(3L, "c".getBytes(), "3".getBytes())); + + long createTime = 15L; + + ByteBufferLegacyRecordBatch batch = new ByteBufferLegacyRecordBatch(records.buffer()); + batch.setMaxTimestamp(TimestampType.CREATE_TIME, createTime); + assertEquals(TimestampType.CREATE_TIME, batch.timestampType()); + assertEquals(createTime, batch.maxTimestamp()); + assertTrue(batch.isValid()); + + List recordBatches = Utils.toList(records.batches().iterator()); + assertEquals(1, recordBatches.size()); + assertEquals(TimestampType.CREATE_TIME, recordBatches.get(0).timestampType()); + assertEquals(createTime, recordBatches.get(0).maxTimestamp()); + + long expectedTimestamp = 1L; + for (Record record : records.records()) + assertEquals(expectedTimestamp++, record.timestamp()); + } + + @Test + public void testZStdCompressionTypeWithV0OrV1() { + SimpleRecord[] simpleRecords = new SimpleRecord[] { + new SimpleRecord(1L, "a".getBytes(), "1".getBytes()), + new SimpleRecord(2L, "b".getBytes(), "2".getBytes()), + new SimpleRecord(3L, "c".getBytes(), "3".getBytes()) + }; + + // Check V0 + try { + MemoryRecords records = MemoryRecords.withRecords(RecordBatch.MAGIC_VALUE_V0, 0L, + CompressionType.ZSTD, TimestampType.CREATE_TIME, simpleRecords); + + ByteBufferLegacyRecordBatch batch = new ByteBufferLegacyRecordBatch(records.buffer()); + batch.setLastOffset(1L); + + batch.iterator(); + fail("Can't reach here"); + } catch (IllegalArgumentException e) { 
+ assertEquals("ZStandard compression is not supported for magic 0", e.getMessage()); + } + + // Check V1 + try { + MemoryRecords records = MemoryRecords.withRecords(RecordBatch.MAGIC_VALUE_V1, 0L, + CompressionType.ZSTD, TimestampType.CREATE_TIME, simpleRecords); + + ByteBufferLegacyRecordBatch batch = new ByteBufferLegacyRecordBatch(records.buffer()); + batch.setLastOffset(1L); + + batch.iterator(); + fail("Can't reach here"); + } catch (IllegalArgumentException e) { + assertEquals("ZStandard compression is not supported for magic 1", e.getMessage()); + } + } + +} diff --git a/clients/src/test/java/org/apache/kafka/common/record/BufferSupplierTest.java b/clients/src/test/java/org/apache/kafka/common/record/BufferSupplierTest.java new file mode 100644 index 0000000..e580be5 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/record/BufferSupplierTest.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.common.record; + +import org.apache.kafka.common.utils.BufferSupplier; +import org.junit.jupiter.api.Test; + +import java.nio.ByteBuffer; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertSame; + +public class BufferSupplierTest { + + @Test + public void testGrowableBuffer() { + BufferSupplier.GrowableBufferSupplier supplier = new BufferSupplier.GrowableBufferSupplier(); + ByteBuffer buffer = supplier.get(1024); + assertEquals(0, buffer.position()); + assertEquals(1024, buffer.capacity()); + supplier.release(buffer); + + ByteBuffer cached = supplier.get(512); + assertEquals(0, cached.position()); + assertSame(buffer, cached); + + ByteBuffer increased = supplier.get(2048); + assertEquals(2048, increased.capacity()); + assertEquals(0, increased.position()); + } + +} diff --git a/clients/src/test/java/org/apache/kafka/common/record/ByteBufferLogInputStreamTest.java b/clients/src/test/java/org/apache/kafka/common/record/ByteBufferLogInputStreamTest.java new file mode 100644 index 0000000..0840169 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/record/ByteBufferLogInputStreamTest.java @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.common.record;
+
+import org.apache.kafka.common.errors.CorruptRecordException;
+import org.junit.jupiter.api.Test;
+
+import java.nio.ByteBuffer;
+import java.util.Iterator;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertFalse;
+import static org.junit.jupiter.api.Assertions.assertNotNull;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+public class ByteBufferLogInputStreamTest {
+
+    @Test
+    public void iteratorIgnoresIncompleteEntries() {
+        ByteBuffer buffer = ByteBuffer.allocate(1024);
+        MemoryRecordsBuilder builder = MemoryRecords.builder(buffer, CompressionType.NONE, TimestampType.CREATE_TIME, 0L);
+        builder.append(15L, "a".getBytes(), "1".getBytes());
+        builder.append(20L, "b".getBytes(), "2".getBytes());
+        builder.close();
+
+        builder = MemoryRecords.builder(buffer, CompressionType.NONE, TimestampType.CREATE_TIME, 2L);
+        builder.append(30L, "c".getBytes(), "3".getBytes());
+        builder.append(40L, "d".getBytes(), "4".getBytes());
+        builder.close();
+
+        buffer.flip();
+        buffer.limit(buffer.limit() - 5);
+
+        MemoryRecords records = MemoryRecords.readableRecords(buffer);
+        Iterator<MutableRecordBatch> iterator = records.batches().iterator();
+        assertTrue(iterator.hasNext());
+        MutableRecordBatch first = iterator.next();
+        assertEquals(1L, first.lastOffset());
+
+        assertFalse(iterator.hasNext());
+    }
+
+    @Test
+    public void iteratorRaisesOnTooSmallRecords() {
+        ByteBuffer buffer = ByteBuffer.allocate(1024);
+        MemoryRecordsBuilder builder = MemoryRecords.builder(buffer, CompressionType.NONE, TimestampType.CREATE_TIME, 0L);
+        builder.append(15L, "a".getBytes(), "1".getBytes());
+        builder.append(20L, "b".getBytes(), "2".getBytes());
+        builder.close();
+
+        int position = buffer.position();
+
+        builder = MemoryRecords.builder(buffer, CompressionType.NONE, TimestampType.CREATE_TIME, 2L);
+        builder.append(30L, "c".getBytes(), "3".getBytes());
+        builder.append(40L, "d".getBytes(), "4".getBytes());
+        builder.close();
+
+        buffer.flip();
+        buffer.putInt(position + DefaultRecordBatch.LENGTH_OFFSET, 9);
+
+        ByteBufferLogInputStream logInputStream = new ByteBufferLogInputStream(buffer, Integer.MAX_VALUE);
+        assertNotNull(logInputStream.nextBatch());
+        assertThrows(CorruptRecordException.class, logInputStream::nextBatch);
+    }
+
+    @Test
+    public void iteratorRaisesOnInvalidMagic() {
+        ByteBuffer buffer = ByteBuffer.allocate(1024);
+        MemoryRecordsBuilder builder = MemoryRecords.builder(buffer, CompressionType.NONE, TimestampType.CREATE_TIME, 0L);
+        builder.append(15L, "a".getBytes(), "1".getBytes());
+        builder.append(20L, "b".getBytes(), "2".getBytes());
+        builder.close();
+
+        int position = buffer.position();
+
+        builder = MemoryRecords.builder(buffer, CompressionType.NONE, TimestampType.CREATE_TIME, 2L);
+        builder.append(30L, "c".getBytes(), "3".getBytes());
+        builder.append(40L, "d".getBytes(), "4".getBytes());
+        builder.close();
+
+        buffer.flip();
+        buffer.put(position + DefaultRecordBatch.MAGIC_OFFSET, (byte) 37);
+
+        ByteBufferLogInputStream logInputStream = new ByteBufferLogInputStream(buffer, Integer.MAX_VALUE);
+        assertNotNull(logInputStream.nextBatch());
+        assertThrows(CorruptRecordException.class, logInputStream::nextBatch);
+    }
+
+    @Test
+    public void iteratorRaisesOnTooLargeRecords() {
+        ByteBuffer buffer = ByteBuffer.allocate(1024);
+        MemoryRecordsBuilder builder = MemoryRecords.builder(buffer, CompressionType.NONE, TimestampType.CREATE_TIME, 0L);
+        builder.append(15L, "a".getBytes(), "1".getBytes());
+        builder.close();
+
+        builder = MemoryRecords.builder(buffer, CompressionType.NONE, TimestampType.CREATE_TIME, 2L);
+        builder.append(30L, "c".getBytes(), "3".getBytes());
+        builder.append(40L, "d".getBytes(), "4".getBytes());
+        builder.close();
+        buffer.flip();
+
+        ByteBufferLogInputStream logInputStream = new ByteBufferLogInputStream(buffer, 60);
+        assertNotNull(logInputStream.nextBatch());
+        assertThrows(CorruptRecordException.class, logInputStream::nextBatch);
+    }
+
+}
diff --git a/clients/src/test/java/org/apache/kafka/common/record/CompressionRatioEstimatorTest.java b/clients/src/test/java/org/apache/kafka/common/record/CompressionRatioEstimatorTest.java
new file mode 100644
index 0000000..7ba51db
--- /dev/null
+++ b/clients/src/test/java/org/apache/kafka/common/record/CompressionRatioEstimatorTest.java
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.common.record;
+
+import org.junit.jupiter.api.Test;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+import java.util.Arrays;
+import java.util.List;
+
+public class CompressionRatioEstimatorTest {
+
+    @Test
+    public void testUpdateEstimation() {
+        class EstimationsObservedRatios {
+            float currentEstimation;
+            float observedRatio;
+            EstimationsObservedRatios(float currentEstimation, float observedRatio) {
+                this.currentEstimation = currentEstimation;
+                this.observedRatio = observedRatio;
+            }
+        }
+
+        // If currentEstimation is smaller than observedRatio, the updatedCompressionRatio is currentEstimation plus
+        // COMPRESSION_RATIO_DETERIORATE_STEP 0.05, otherwise currentEstimation minus COMPRESSION_RATIO_IMPROVING_STEP
+        // 0.005. There are four cases, and in every case the updatedCompressionRatio should not be smaller than the
+        // observedRatio. Refer to the non-test code for more details.
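+        // For example, using the step sizes quoted above (illustrative numbers only): currentEstimation = 0.8 with
+        // observedRatio = 0.84 deteriorates the estimation to 0.8 + 0.05 = 0.85, which stays >= 0.84, while
+        // currentEstimation = 0.6 with observedRatio = 0.4 improves it to 0.6 - 0.005 = 0.595, still >= 0.4.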
+        List<EstimationsObservedRatios> estimationsObservedRatios = Arrays.asList(
+            new EstimationsObservedRatios(0.8f, 0.84f),
+            new EstimationsObservedRatios(0.6f, 0.7f),
+            new EstimationsObservedRatios(0.6f, 0.4f),
+            new EstimationsObservedRatios(0.004f, 0.001f));
+        for (EstimationsObservedRatios estimationsObservedRatio : estimationsObservedRatios) {
+            String topic = "tp";
+            CompressionRatioEstimator.setEstimation(topic, CompressionType.ZSTD, estimationsObservedRatio.currentEstimation);
+            float updatedCompressionRatio = CompressionRatioEstimator.updateEstimation(topic, CompressionType.ZSTD, estimationsObservedRatio.observedRatio);
+            assertTrue(updatedCompressionRatio >= estimationsObservedRatio.observedRatio);
+        }
+    }
+}
diff --git a/clients/src/test/java/org/apache/kafka/common/record/CompressionTypeTest.java b/clients/src/test/java/org/apache/kafka/common/record/CompressionTypeTest.java
new file mode 100644
index 0000000..16b560d
--- /dev/null
+++ b/clients/src/test/java/org/apache/kafka/common/record/CompressionTypeTest.java
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +package org.apache.kafka.common.record; + +import org.apache.kafka.common.compress.KafkaLZ4BlockInputStream; +import org.apache.kafka.common.compress.KafkaLZ4BlockOutputStream; +import org.apache.kafka.common.utils.BufferSupplier; +import org.apache.kafka.common.utils.ByteBufferOutputStream; +import org.junit.jupiter.api.Test; + +import java.nio.ByteBuffer; + +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class CompressionTypeTest { + + @Test + public void testLZ4FramingMagicV0() { + ByteBuffer buffer = ByteBuffer.allocate(256); + KafkaLZ4BlockOutputStream out = (KafkaLZ4BlockOutputStream) CompressionType.LZ4.wrapForOutput( + new ByteBufferOutputStream(buffer), RecordBatch.MAGIC_VALUE_V0); + assertTrue(out.useBrokenFlagDescriptorChecksum()); + + buffer.rewind(); + + KafkaLZ4BlockInputStream in = (KafkaLZ4BlockInputStream) CompressionType.LZ4.wrapForInput( + buffer, RecordBatch.MAGIC_VALUE_V0, BufferSupplier.NO_CACHING); + assertTrue(in.ignoreFlagDescriptorChecksum()); + } + + @Test + public void testLZ4FramingMagicV1() { + ByteBuffer buffer = ByteBuffer.allocate(256); + KafkaLZ4BlockOutputStream out = (KafkaLZ4BlockOutputStream) CompressionType.LZ4.wrapForOutput( + new ByteBufferOutputStream(buffer), RecordBatch.MAGIC_VALUE_V1); + assertFalse(out.useBrokenFlagDescriptorChecksum()); + + buffer.rewind(); + + KafkaLZ4BlockInputStream in = (KafkaLZ4BlockInputStream) CompressionType.LZ4.wrapForInput( + buffer, RecordBatch.MAGIC_VALUE_V1, BufferSupplier.create()); + assertFalse(in.ignoreFlagDescriptorChecksum()); + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/record/ControlRecordTypeTest.java b/clients/src/test/java/org/apache/kafka/common/record/ControlRecordTypeTest.java new file mode 100644 index 0000000..3245f31 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/record/ControlRecordTypeTest.java @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.record; + +import org.junit.jupiter.api.Test; + +import java.nio.ByteBuffer; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class ControlRecordTypeTest { + + @Test + public void testParseUnknownType() { + ByteBuffer buffer = ByteBuffer.allocate(32); + buffer.putShort(ControlRecordType.CURRENT_CONTROL_RECORD_KEY_VERSION); + buffer.putShort((short) 337); + buffer.flip(); + ControlRecordType type = ControlRecordType.parse(buffer); + assertEquals(ControlRecordType.UNKNOWN, type); + } + + @Test + public void testParseUnknownVersion() { + ByteBuffer buffer = ByteBuffer.allocate(32); + buffer.putShort((short) 5); + buffer.putShort(ControlRecordType.ABORT.type); + buffer.putInt(23432); // some field added in version 5 + buffer.flip(); + ControlRecordType type = ControlRecordType.parse(buffer); + assertEquals(ControlRecordType.ABORT, type); + } + +} diff --git a/clients/src/test/java/org/apache/kafka/common/record/ControlRecordUtilsTest.java b/clients/src/test/java/org/apache/kafka/common/record/ControlRecordUtilsTest.java new file mode 100644 index 0000000..657dc80 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/record/ControlRecordUtilsTest.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.record; + +import org.apache.kafka.common.message.LeaderChangeMessage; +import org.apache.kafka.common.message.LeaderChangeMessage.Voter; + +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.protocol.ObjectSerializationCache; +import org.junit.jupiter.api.Test; + +import java.nio.ByteBuffer; +import java.util.Collections; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +public class ControlRecordUtilsTest { + + @Test + public void testInvalidControlRecordType() { + IllegalArgumentException thrown = assertThrows( + IllegalArgumentException.class, () -> testDeserializeRecord(ControlRecordType.COMMIT)); + assertEquals("Expected LEADER_CHANGE control record type(2), but found COMMIT", thrown.getMessage()); + } + + @Test + public void testDeserializeByteData() { + testDeserializeRecord(ControlRecordType.LEADER_CHANGE); + } + + private void testDeserializeRecord(ControlRecordType controlRecordType) { + final int leaderId = 1; + final int voterId = 2; + LeaderChangeMessage data = new LeaderChangeMessage() + .setLeaderId(leaderId) + .setVoters(Collections.singletonList( + new Voter().setVoterId(voterId))); + + ByteBuffer valueBuffer = ByteBuffer.allocate(256); + data.write(new ByteBufferAccessor(valueBuffer), new ObjectSerializationCache(), data.highestSupportedVersion()); + valueBuffer.flip(); + + byte[] keyData = new byte[]{0, 0, 0, (byte) controlRecordType.type}; + + DefaultRecord record = new DefaultRecord( + 256, (byte) 0, 0, 0L, 0, ByteBuffer.wrap(keyData), valueBuffer, null + ); + + LeaderChangeMessage deserializedData = ControlRecordUtils.deserializeLeaderChangeMessage(record); + + assertEquals(leaderId, deserializedData.leaderId()); + assertEquals(Collections.singletonList( + new Voter().setVoterId(voterId)), deserializedData.voters()); + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/record/DefaultRecordBatchTest.java b/clients/src/test/java/org/apache/kafka/common/record/DefaultRecordBatchTest.java new file mode 100644 index 0000000..0864a2d --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/record/DefaultRecordBatchTest.java @@ -0,0 +1,423 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.record; + +import org.apache.kafka.common.InvalidRecordException; +import org.apache.kafka.common.errors.CorruptRecordException; +import org.apache.kafka.common.header.Header; +import org.apache.kafka.common.header.internals.RecordHeader; +import org.apache.kafka.common.utils.BufferSupplier; +import org.apache.kafka.common.utils.CloseableIterator; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.test.TestUtils; +import org.junit.jupiter.api.Test; + +import java.nio.ByteBuffer; +import java.util.Arrays; +import java.util.List; + +import static org.apache.kafka.common.record.DefaultRecordBatch.RECORDS_COUNT_OFFSET; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class DefaultRecordBatchTest { + + @Test + public void testWriteEmptyHeader() { + long producerId = 23423L; + short producerEpoch = 145; + int baseSequence = 983; + long baseOffset = 15L; + long lastOffset = 37; + int partitionLeaderEpoch = 15; + long timestamp = System.currentTimeMillis(); + + for (TimestampType timestampType : Arrays.asList(TimestampType.CREATE_TIME, TimestampType.LOG_APPEND_TIME)) { + for (boolean isTransactional : Arrays.asList(true, false)) { + for (boolean isControlBatch : Arrays.asList(true, false)) { + ByteBuffer buffer = ByteBuffer.allocate(2048); + DefaultRecordBatch.writeEmptyHeader(buffer, RecordBatch.CURRENT_MAGIC_VALUE, producerId, + producerEpoch, baseSequence, baseOffset, lastOffset, partitionLeaderEpoch, timestampType, + timestamp, isTransactional, isControlBatch); + buffer.flip(); + DefaultRecordBatch batch = new DefaultRecordBatch(buffer); + assertEquals(producerId, batch.producerId()); + assertEquals(producerEpoch, batch.producerEpoch()); + assertEquals(baseSequence, batch.baseSequence()); + assertEquals(baseSequence + ((int) (lastOffset - baseOffset)), batch.lastSequence()); + assertEquals(baseOffset, batch.baseOffset()); + assertEquals(lastOffset, batch.lastOffset()); + assertEquals(partitionLeaderEpoch, batch.partitionLeaderEpoch()); + assertEquals(isTransactional, batch.isTransactional()); + assertEquals(timestampType, batch.timestampType()); + assertEquals(timestamp, batch.maxTimestamp()); + assertEquals(RecordBatch.NO_TIMESTAMP, batch.baseTimestamp()); + assertEquals(isControlBatch, batch.isControlBatch()); + } + } + } + } + + @Test + public void buildDefaultRecordBatch() { + ByteBuffer buffer = ByteBuffer.allocate(2048); + + MemoryRecordsBuilder builder = MemoryRecords.builder(buffer, RecordBatch.MAGIC_VALUE_V2, CompressionType.NONE, + TimestampType.CREATE_TIME, 1234567L); + builder.appendWithOffset(1234567, 1L, "a".getBytes(), "v".getBytes()); + builder.appendWithOffset(1234568, 2L, "b".getBytes(), "v".getBytes()); + + MemoryRecords records = builder.build(); + for (MutableRecordBatch batch : records.batches()) { + assertTrue(batch.isValid()); + assertEquals(1234567, batch.baseOffset()); + assertEquals(1234568, batch.lastOffset()); + assertEquals(2L, batch.maxTimestamp()); + assertEquals(RecordBatch.NO_PRODUCER_ID, batch.producerId()); + assertEquals(RecordBatch.NO_PRODUCER_EPOCH, batch.producerEpoch()); + assertEquals(RecordBatch.NO_SEQUENCE, batch.baseSequence()); + assertEquals(RecordBatch.NO_SEQUENCE, batch.lastSequence()); + + for (Record record : batch) record.ensureValid(); + } + } + + @Test + public void 
buildDefaultRecordBatchWithProducerId() { + long pid = 23423L; + short epoch = 145; + int baseSequence = 983; + + ByteBuffer buffer = ByteBuffer.allocate(2048); + + MemoryRecordsBuilder builder = MemoryRecords.builder(buffer, RecordBatch.MAGIC_VALUE_V2, CompressionType.NONE, + TimestampType.CREATE_TIME, 1234567L, RecordBatch.NO_TIMESTAMP, pid, epoch, baseSequence); + builder.appendWithOffset(1234567, 1L, "a".getBytes(), "v".getBytes()); + builder.appendWithOffset(1234568, 2L, "b".getBytes(), "v".getBytes()); + + MemoryRecords records = builder.build(); + for (MutableRecordBatch batch : records.batches()) { + assertTrue(batch.isValid()); + assertEquals(1234567, batch.baseOffset()); + assertEquals(1234568, batch.lastOffset()); + assertEquals(2L, batch.maxTimestamp()); + assertEquals(pid, batch.producerId()); + assertEquals(epoch, batch.producerEpoch()); + assertEquals(baseSequence, batch.baseSequence()); + assertEquals(baseSequence + 1, batch.lastSequence()); + + for (Record record : batch) record.ensureValid(); + } + } + + @Test + public void buildDefaultRecordBatchWithSequenceWrapAround() { + long pid = 23423L; + short epoch = 145; + int baseSequence = Integer.MAX_VALUE - 1; + ByteBuffer buffer = ByteBuffer.allocate(2048); + + MemoryRecordsBuilder builder = MemoryRecords.builder(buffer, RecordBatch.MAGIC_VALUE_V2, CompressionType.NONE, + TimestampType.CREATE_TIME, 1234567L, RecordBatch.NO_TIMESTAMP, pid, epoch, baseSequence); + builder.appendWithOffset(1234567, 1L, "a".getBytes(), "v".getBytes()); + builder.appendWithOffset(1234568, 2L, "b".getBytes(), "v".getBytes()); + builder.appendWithOffset(1234569, 3L, "c".getBytes(), "v".getBytes()); + + MemoryRecords records = builder.build(); + List batches = TestUtils.toList(records.batches()); + assertEquals(1, batches.size()); + RecordBatch batch = batches.get(0); + + assertEquals(pid, batch.producerId()); + assertEquals(epoch, batch.producerEpoch()); + assertEquals(baseSequence, batch.baseSequence()); + assertEquals(0, batch.lastSequence()); + List allRecords = TestUtils.toList(batch); + assertEquals(3, allRecords.size()); + assertEquals(Integer.MAX_VALUE - 1, allRecords.get(0).sequence()); + assertEquals(Integer.MAX_VALUE, allRecords.get(1).sequence()); + assertEquals(0, allRecords.get(2).sequence()); + } + + @Test + public void testSizeInBytes() { + Header[] headers = new Header[] { + new RecordHeader("foo", "value".getBytes()), + new RecordHeader("bar", (byte[]) null) + }; + + long timestamp = System.currentTimeMillis(); + SimpleRecord[] records = new SimpleRecord[] { + new SimpleRecord(timestamp, "key".getBytes(), "value".getBytes()), + new SimpleRecord(timestamp + 30000, null, "value".getBytes()), + new SimpleRecord(timestamp + 60000, "key".getBytes(), null), + new SimpleRecord(timestamp + 60000, "key".getBytes(), "value".getBytes(), headers) + }; + int actualSize = MemoryRecords.withRecords(CompressionType.NONE, records).sizeInBytes(); + assertEquals(actualSize, DefaultRecordBatch.sizeInBytes(Arrays.asList(records))); + } + + @Test + public void testInvalidRecordSize() { + MemoryRecords records = MemoryRecords.withRecords(RecordBatch.MAGIC_VALUE_V2, 0L, + CompressionType.NONE, TimestampType.CREATE_TIME, + new SimpleRecord(1L, "a".getBytes(), "1".getBytes()), + new SimpleRecord(2L, "b".getBytes(), "2".getBytes()), + new SimpleRecord(3L, "c".getBytes(), "3".getBytes())); + + ByteBuffer buffer = records.buffer(); + buffer.putInt(DefaultRecordBatch.LENGTH_OFFSET, 10); + + DefaultRecordBatch batch = new DefaultRecordBatch(buffer); + 
assertFalse(batch.isValid()); + assertThrows(CorruptRecordException.class, batch::ensureValid); + } + + @Test + public void testInvalidRecordCountTooManyNonCompressedV2() { + long now = System.currentTimeMillis(); + DefaultRecordBatch batch = recordsWithInvalidRecordCount(RecordBatch.MAGIC_VALUE_V2, now, CompressionType.NONE, 5); + // force iteration through the batch to execute validation + // batch validation is a part of normal workflow for LogValidator.validateMessagesAndAssignOffsets + assertThrows(InvalidRecordException.class, () -> batch.forEach(Record::ensureValid)); + } + + @Test + public void testInvalidRecordCountTooLittleNonCompressedV2() { + long now = System.currentTimeMillis(); + DefaultRecordBatch batch = recordsWithInvalidRecordCount(RecordBatch.MAGIC_VALUE_V2, now, CompressionType.NONE, 2); + // force iteration through the batch to execute validation + // batch validation is a part of normal workflow for LogValidator.validateMessagesAndAssignOffsets + assertThrows(InvalidRecordException.class, () -> batch.forEach(Record::ensureValid)); + } + + @Test + public void testInvalidRecordCountTooManyCompressedV2() { + long now = System.currentTimeMillis(); + DefaultRecordBatch batch = recordsWithInvalidRecordCount(RecordBatch.MAGIC_VALUE_V2, now, CompressionType.GZIP, 5); + // force iteration through the batch to execute validation + // batch validation is a part of normal workflow for LogValidator.validateMessagesAndAssignOffsets + assertThrows(InvalidRecordException.class, () -> batch.forEach(Record::ensureValid)); + } + + @Test + public void testInvalidRecordCountTooLittleCompressedV2() { + long now = System.currentTimeMillis(); + DefaultRecordBatch batch = recordsWithInvalidRecordCount(RecordBatch.MAGIC_VALUE_V2, now, CompressionType.GZIP, 2); + // force iteration through the batch to execute validation + // batch validation is a part of normal workflow for LogValidator.validateMessagesAndAssignOffsets + assertThrows(InvalidRecordException.class, () -> batch.forEach(Record::ensureValid)); + } + + @Test + public void testInvalidCrc() { + MemoryRecords records = MemoryRecords.withRecords(RecordBatch.MAGIC_VALUE_V2, 0L, + CompressionType.NONE, TimestampType.CREATE_TIME, + new SimpleRecord(1L, "a".getBytes(), "1".getBytes()), + new SimpleRecord(2L, "b".getBytes(), "2".getBytes()), + new SimpleRecord(3L, "c".getBytes(), "3".getBytes())); + + ByteBuffer buffer = records.buffer(); + buffer.putInt(DefaultRecordBatch.LAST_OFFSET_DELTA_OFFSET, 23); + + DefaultRecordBatch batch = new DefaultRecordBatch(buffer); + assertFalse(batch.isValid()); + assertThrows(CorruptRecordException.class, batch::ensureValid); + } + + @Test + public void testSetLastOffset() { + SimpleRecord[] simpleRecords = new SimpleRecord[] { + new SimpleRecord(1L, "a".getBytes(), "1".getBytes()), + new SimpleRecord(2L, "b".getBytes(), "2".getBytes()), + new SimpleRecord(3L, "c".getBytes(), "3".getBytes()) + }; + MemoryRecords records = MemoryRecords.withRecords(RecordBatch.MAGIC_VALUE_V2, 0L, + CompressionType.NONE, TimestampType.CREATE_TIME, simpleRecords); + + long lastOffset = 500L; + long firstOffset = lastOffset - simpleRecords.length + 1; + + DefaultRecordBatch batch = new DefaultRecordBatch(records.buffer()); + batch.setLastOffset(lastOffset); + assertEquals(lastOffset, batch.lastOffset()); + assertEquals(firstOffset, batch.baseOffset()); + assertTrue(batch.isValid()); + + List recordBatches = Utils.toList(records.batches().iterator()); + assertEquals(1, recordBatches.size()); + assertEquals(lastOffset, 
recordBatches.get(0).lastOffset()); + + long offset = firstOffset; + for (Record record : records.records()) + assertEquals(offset++, record.offset()); + } + + @Test + public void testSetPartitionLeaderEpoch() { + MemoryRecords records = MemoryRecords.withRecords(RecordBatch.MAGIC_VALUE_V2, 0L, + CompressionType.NONE, TimestampType.CREATE_TIME, + new SimpleRecord(1L, "a".getBytes(), "1".getBytes()), + new SimpleRecord(2L, "b".getBytes(), "2".getBytes()), + new SimpleRecord(3L, "c".getBytes(), "3".getBytes())); + + int leaderEpoch = 500; + + DefaultRecordBatch batch = new DefaultRecordBatch(records.buffer()); + batch.setPartitionLeaderEpoch(leaderEpoch); + assertEquals(leaderEpoch, batch.partitionLeaderEpoch()); + assertTrue(batch.isValid()); + + List recordBatches = Utils.toList(records.batches().iterator()); + assertEquals(1, recordBatches.size()); + assertEquals(leaderEpoch, recordBatches.get(0).partitionLeaderEpoch()); + } + + @Test + public void testSetLogAppendTime() { + MemoryRecords records = MemoryRecords.withRecords(RecordBatch.MAGIC_VALUE_V2, 0L, + CompressionType.NONE, TimestampType.CREATE_TIME, + new SimpleRecord(1L, "a".getBytes(), "1".getBytes()), + new SimpleRecord(2L, "b".getBytes(), "2".getBytes()), + new SimpleRecord(3L, "c".getBytes(), "3".getBytes())); + + long logAppendTime = 15L; + + DefaultRecordBatch batch = new DefaultRecordBatch(records.buffer()); + batch.setMaxTimestamp(TimestampType.LOG_APPEND_TIME, logAppendTime); + assertEquals(TimestampType.LOG_APPEND_TIME, batch.timestampType()); + assertEquals(logAppendTime, batch.maxTimestamp()); + assertTrue(batch.isValid()); + + List recordBatches = Utils.toList(records.batches().iterator()); + assertEquals(1, recordBatches.size()); + assertEquals(logAppendTime, recordBatches.get(0).maxTimestamp()); + assertEquals(TimestampType.LOG_APPEND_TIME, recordBatches.get(0).timestampType()); + + for (Record record : records.records()) + assertEquals(logAppendTime, record.timestamp()); + } + + @Test + public void testSetNoTimestampTypeNotAllowed() { + MemoryRecords records = MemoryRecords.withRecords(RecordBatch.MAGIC_VALUE_V2, 0L, + CompressionType.NONE, TimestampType.CREATE_TIME, + new SimpleRecord(1L, "a".getBytes(), "1".getBytes()), + new SimpleRecord(2L, "b".getBytes(), "2".getBytes()), + new SimpleRecord(3L, "c".getBytes(), "3".getBytes())); + DefaultRecordBatch batch = new DefaultRecordBatch(records.buffer()); + assertThrows(IllegalArgumentException.class, () -> batch.setMaxTimestamp(TimestampType.NO_TIMESTAMP_TYPE, RecordBatch.NO_TIMESTAMP)); + } + + @Test + public void testReadAndWriteControlBatch() { + long producerId = 1L; + short producerEpoch = 0; + int coordinatorEpoch = 15; + + ByteBuffer buffer = ByteBuffer.allocate(128); + MemoryRecordsBuilder builder = new MemoryRecordsBuilder(buffer, RecordBatch.CURRENT_MAGIC_VALUE, + CompressionType.NONE, TimestampType.CREATE_TIME, 0L, RecordBatch.NO_TIMESTAMP, producerId, + producerEpoch, RecordBatch.NO_SEQUENCE, true, true, RecordBatch.NO_PARTITION_LEADER_EPOCH, + buffer.remaining()); + + EndTransactionMarker marker = new EndTransactionMarker(ControlRecordType.COMMIT, coordinatorEpoch); + builder.appendEndTxnMarker(System.currentTimeMillis(), marker); + MemoryRecords records = builder.build(); + + List batches = TestUtils.toList(records.batches()); + assertEquals(1, batches.size()); + + MutableRecordBatch batch = batches.get(0); + assertTrue(batch.isControlBatch()); + + List logRecords = TestUtils.toList(records.records()); + assertEquals(1, logRecords.size()); + + Record 
commitRecord = logRecords.get(0); + assertEquals(marker, EndTransactionMarker.deserialize(commitRecord)); + } + + @Test + public void testStreamingIteratorConsistency() { + MemoryRecords records = MemoryRecords.withRecords(RecordBatch.MAGIC_VALUE_V2, 0L, + CompressionType.GZIP, TimestampType.CREATE_TIME, + new SimpleRecord(1L, "a".getBytes(), "1".getBytes()), + new SimpleRecord(2L, "b".getBytes(), "2".getBytes()), + new SimpleRecord(3L, "c".getBytes(), "3".getBytes())); + DefaultRecordBatch batch = new DefaultRecordBatch(records.buffer()); + try (CloseableIterator streamingIterator = batch.streamingIterator(BufferSupplier.create())) { + TestUtils.checkEquals(streamingIterator, batch.iterator()); + } + } + + @Test + public void testSkipKeyValueIteratorCorrectness() { + Header[] headers = {new RecordHeader("k1", "v1".getBytes()), new RecordHeader("k2", "v2".getBytes())}; + + MemoryRecords records = MemoryRecords.withRecords(RecordBatch.MAGIC_VALUE_V2, 0L, + CompressionType.LZ4, TimestampType.CREATE_TIME, + new SimpleRecord(1L, "a".getBytes(), "1".getBytes()), + new SimpleRecord(2L, "b".getBytes(), "2".getBytes()), + new SimpleRecord(3L, "c".getBytes(), "3".getBytes()), + new SimpleRecord(1000L, "abc".getBytes(), "0".getBytes()), + new SimpleRecord(9999L, "abc".getBytes(), "0".getBytes(), headers) + ); + DefaultRecordBatch batch = new DefaultRecordBatch(records.buffer()); + try (CloseableIterator streamingIterator = batch.skipKeyValueIterator(BufferSupplier.NO_CACHING)) { + assertEquals(Arrays.asList( + new PartialDefaultRecord(9, (byte) 0, 0L, 1L, -1, 1, 1), + new PartialDefaultRecord(9, (byte) 0, 1L, 2L, -1, 1, 1), + new PartialDefaultRecord(9, (byte) 0, 2L, 3L, -1, 1, 1), + new PartialDefaultRecord(12, (byte) 0, 3L, 1000L, -1, 3, 1), + new PartialDefaultRecord(25, (byte) 0, 4L, 9999L, -1, 3, 1) + ), + Utils.toList(streamingIterator) + ); + } + } + + @Test + public void testIncrementSequence() { + assertEquals(10, DefaultRecordBatch.incrementSequence(5, 5)); + assertEquals(0, DefaultRecordBatch.incrementSequence(Integer.MAX_VALUE, 1)); + assertEquals(4, DefaultRecordBatch.incrementSequence(Integer.MAX_VALUE - 5, 10)); + } + + @Test + public void testDecrementSequence() { + assertEquals(0, DefaultRecordBatch.decrementSequence(5, 5)); + assertEquals(Integer.MAX_VALUE, DefaultRecordBatch.decrementSequence(0, 1)); + } + + private static DefaultRecordBatch recordsWithInvalidRecordCount(Byte magicValue, long timestamp, + CompressionType codec, int invalidCount) { + ByteBuffer buf = ByteBuffer.allocate(512); + MemoryRecordsBuilder builder = MemoryRecords.builder(buf, magicValue, codec, TimestampType.CREATE_TIME, 0L); + builder.appendWithOffset(0, timestamp, null, "hello".getBytes()); + builder.appendWithOffset(1, timestamp, null, "there".getBytes()); + builder.appendWithOffset(2, timestamp, null, "beautiful".getBytes()); + MemoryRecords records = builder.build(); + ByteBuffer buffer = records.buffer(); + buffer.position(0); + buffer.putInt(RECORDS_COUNT_OFFSET, invalidCount); + buffer.position(0); + return new DefaultRecordBatch(buffer); + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/record/DefaultRecordTest.java b/clients/src/test/java/org/apache/kafka/common/record/DefaultRecordTest.java new file mode 100644 index 0000000..49743d2 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/record/DefaultRecordTest.java @@ -0,0 +1,494 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.record; + +import org.apache.kafka.common.InvalidRecordException; +import org.apache.kafka.common.header.Header; +import org.apache.kafka.common.header.internals.RecordHeader; +import org.apache.kafka.common.utils.ByteBufferInputStream; +import org.apache.kafka.common.utils.ByteBufferOutputStream; +import org.apache.kafka.common.utils.ByteUtils; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.io.DataInputStream; +import java.io.DataOutputStream; +import java.io.IOException; +import java.nio.ByteBuffer; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertThrows; + +public class DefaultRecordTest { + + private byte[] skipArray; + + @BeforeEach + public void setUp() { + skipArray = new byte[64]; + } + + @Test + public void testBasicSerde() throws IOException { + Header[] headers = new Header[] { + new RecordHeader("foo", "value".getBytes()), + new RecordHeader("bar", (byte[]) null), + new RecordHeader("\"A\\u00ea\\u00f1\\u00fcC\"", "value".getBytes()) + }; + + SimpleRecord[] records = new SimpleRecord[] { + new SimpleRecord("hi".getBytes(), "there".getBytes()), + new SimpleRecord(null, "there".getBytes()), + new SimpleRecord("hi".getBytes(), null), + new SimpleRecord(null, null), + new SimpleRecord(15L, "hi".getBytes(), "there".getBytes(), headers) + }; + + for (SimpleRecord record : records) { + int baseSequence = 723; + long baseOffset = 37; + int offsetDelta = 10; + long baseTimestamp = System.currentTimeMillis(); + long timestampDelta = 323; + + ByteBufferOutputStream out = new ByteBufferOutputStream(1024); + DefaultRecord.writeTo(new DataOutputStream(out), offsetDelta, timestampDelta, record.key(), record.value(), + record.headers()); + ByteBuffer buffer = out.buffer(); + buffer.flip(); + + DefaultRecord logRecord = DefaultRecord.readFrom(buffer, baseOffset, baseTimestamp, baseSequence, null); + assertNotNull(logRecord); + assertEquals(baseOffset + offsetDelta, logRecord.offset()); + assertEquals(baseSequence + offsetDelta, logRecord.sequence()); + assertEquals(baseTimestamp + timestampDelta, logRecord.timestamp()); + assertEquals(record.key(), logRecord.key()); + assertEquals(record.value(), logRecord.value()); + assertArrayEquals(record.headers(), logRecord.headers()); + assertEquals(DefaultRecord.sizeInBytes(offsetDelta, timestampDelta, record.key(), record.value(), + record.headers()), logRecord.sizeInBytes()); + } + } + + @Test + public void testBasicSerdeInvalidHeaderCountTooHigh() throws IOException { + Header[] headers = new Header[] { + new RecordHeader("foo", "value".getBytes()), + new RecordHeader("bar", null), + new 
RecordHeader("\"A\\u00ea\\u00f1\\u00fcC\"", "value".getBytes()) + }; + + SimpleRecord record = new SimpleRecord(15L, "hi".getBytes(), "there".getBytes(), headers); + + int baseSequence = 723; + long baseOffset = 37; + int offsetDelta = 10; + long baseTimestamp = System.currentTimeMillis(); + long timestampDelta = 323; + + ByteBufferOutputStream out = new ByteBufferOutputStream(1024); + DefaultRecord.writeTo(new DataOutputStream(out), offsetDelta, timestampDelta, record.key(), record.value(), + record.headers()); + ByteBuffer buffer = out.buffer(); + buffer.flip(); + buffer.put(14, (byte) 8); + assertThrows(InvalidRecordException.class, + () -> DefaultRecord.readFrom(buffer, baseOffset, baseTimestamp, baseSequence, null)); + } + + @Test + public void testBasicSerdeInvalidHeaderCountTooLow() throws IOException { + Header[] headers = new Header[] { + new RecordHeader("foo", "value".getBytes()), + new RecordHeader("bar", null), + new RecordHeader("\"A\\u00ea\\u00f1\\u00fcC\"", "value".getBytes()) + }; + + SimpleRecord record = new SimpleRecord(15L, "hi".getBytes(), "there".getBytes(), headers); + + int baseSequence = 723; + long baseOffset = 37; + int offsetDelta = 10; + long baseTimestamp = System.currentTimeMillis(); + long timestampDelta = 323; + + ByteBufferOutputStream out = new ByteBufferOutputStream(1024); + DefaultRecord.writeTo(new DataOutputStream(out), offsetDelta, timestampDelta, record.key(), record.value(), + record.headers()); + ByteBuffer buffer = out.buffer(); + buffer.flip(); + buffer.put(14, (byte) 4); + + assertThrows(InvalidRecordException.class, + () -> DefaultRecord.readFrom(buffer, baseOffset, baseTimestamp, baseSequence, null)); + } + + @Test + public void testInvalidKeySize() { + byte attributes = 0; + long timestampDelta = 2; + int offsetDelta = 1; + int sizeOfBodyInBytes = 100; + int keySize = 105; // use a key size larger than the full message + + ByteBuffer buf = ByteBuffer.allocate(sizeOfBodyInBytes + ByteUtils.sizeOfVarint(sizeOfBodyInBytes)); + ByteUtils.writeVarint(sizeOfBodyInBytes, buf); + buf.put(attributes); + ByteUtils.writeVarlong(timestampDelta, buf); + ByteUtils.writeVarint(offsetDelta, buf); + ByteUtils.writeVarint(keySize, buf); + buf.position(buf.limit()); + + buf.flip(); + assertThrows(InvalidRecordException.class, + () -> DefaultRecord.readFrom(buf, 0L, 0L, RecordBatch.NO_SEQUENCE, null)); + } + + @Test + public void testInvalidKeySizePartial() { + byte attributes = 0; + long timestampDelta = 2; + int offsetDelta = 1; + int sizeOfBodyInBytes = 100; + int keySize = 105; // use a key size larger than the full message + + ByteBuffer buf = ByteBuffer.allocate(sizeOfBodyInBytes + ByteUtils.sizeOfVarint(sizeOfBodyInBytes)); + ByteUtils.writeVarint(sizeOfBodyInBytes, buf); + buf.put(attributes); + ByteUtils.writeVarlong(timestampDelta, buf); + ByteUtils.writeVarint(offsetDelta, buf); + ByteUtils.writeVarint(keySize, buf); + buf.position(buf.limit()); + + buf.flip(); + DataInputStream inputStream = new DataInputStream(new ByteBufferInputStream(buf)); + assertThrows(InvalidRecordException.class, + () -> DefaultRecord.readPartiallyFrom(inputStream, skipArray, 0L, 0L, RecordBatch.NO_SEQUENCE, null)); + } + + @Test + public void testInvalidValueSize() { + byte attributes = 0; + long timestampDelta = 2; + int offsetDelta = 1; + int sizeOfBodyInBytes = 100; + int valueSize = 105; // use a value size larger than the full message + + ByteBuffer buf = ByteBuffer.allocate(sizeOfBodyInBytes + ByteUtils.sizeOfVarint(sizeOfBodyInBytes)); + 
ByteUtils.writeVarint(sizeOfBodyInBytes, buf); + buf.put(attributes); + ByteUtils.writeVarlong(timestampDelta, buf); + ByteUtils.writeVarint(offsetDelta, buf); + ByteUtils.writeVarint(-1, buf); // null key + ByteUtils.writeVarint(valueSize, buf); + buf.position(buf.limit()); + + buf.flip(); + assertThrows(InvalidRecordException.class, + () -> DefaultRecord.readFrom(buf, 0L, 0L, RecordBatch.NO_SEQUENCE, null)); + } + + @Test + public void testInvalidValueSizePartial() throws IOException { + byte attributes = 0; + long timestampDelta = 2; + int offsetDelta = 1; + int sizeOfBodyInBytes = 100; + int valueSize = 105; // use a value size larger than the full message + + ByteBuffer buf = ByteBuffer.allocate(sizeOfBodyInBytes + ByteUtils.sizeOfVarint(sizeOfBodyInBytes)); + ByteUtils.writeVarint(sizeOfBodyInBytes, buf); + buf.put(attributes); + ByteUtils.writeVarlong(timestampDelta, buf); + ByteUtils.writeVarint(offsetDelta, buf); + ByteUtils.writeVarint(-1, buf); // null key + ByteUtils.writeVarint(valueSize, buf); + buf.position(buf.limit()); + + buf.flip(); + DataInputStream inputStream = new DataInputStream(new ByteBufferInputStream(buf)); + assertThrows(InvalidRecordException.class, + () -> DefaultRecord.readPartiallyFrom(inputStream, skipArray, 0L, 0L, RecordBatch.NO_SEQUENCE, null)); + } + + @Test + public void testInvalidNumHeaders() { + byte attributes = 0; + long timestampDelta = 2; + int offsetDelta = 1; + int sizeOfBodyInBytes = 100; + + ByteBuffer buf = ByteBuffer.allocate(sizeOfBodyInBytes + ByteUtils.sizeOfVarint(sizeOfBodyInBytes)); + ByteUtils.writeVarint(sizeOfBodyInBytes, buf); + buf.put(attributes); + ByteUtils.writeVarlong(timestampDelta, buf); + ByteUtils.writeVarint(offsetDelta, buf); + ByteUtils.writeVarint(-1, buf); // null key + ByteUtils.writeVarint(-1, buf); // null value + ByteUtils.writeVarint(-1, buf); // -1 num.headers, not allowed + buf.position(buf.limit()); + + buf.flip(); + assertThrows(InvalidRecordException.class, + () -> DefaultRecord.readFrom(buf, 0L, 0L, RecordBatch.NO_SEQUENCE, null)); + } + + @Test + public void testInvalidNumHeadersPartial() { + byte attributes = 0; + long timestampDelta = 2; + int offsetDelta = 1; + int sizeOfBodyInBytes = 100; + + ByteBuffer buf = ByteBuffer.allocate(sizeOfBodyInBytes + ByteUtils.sizeOfVarint(sizeOfBodyInBytes)); + ByteUtils.writeVarint(sizeOfBodyInBytes, buf); + buf.put(attributes); + ByteUtils.writeVarlong(timestampDelta, buf); + ByteUtils.writeVarint(offsetDelta, buf); + ByteUtils.writeVarint(-1, buf); // null key + ByteUtils.writeVarint(-1, buf); // null value + ByteUtils.writeVarint(-1, buf); // -1 num.headers, not allowed + buf.position(buf.limit()); + + buf.flip(); + DataInputStream inputStream = new DataInputStream(new ByteBufferInputStream(buf)); + assertThrows(InvalidRecordException.class, + () -> DefaultRecord.readPartiallyFrom(inputStream, skipArray, 0L, 0L, RecordBatch.NO_SEQUENCE, null)); + } + + @Test + public void testInvalidHeaderKey() { + byte attributes = 0; + long timestampDelta = 2; + int offsetDelta = 1; + int sizeOfBodyInBytes = 100; + + ByteBuffer buf = ByteBuffer.allocate(sizeOfBodyInBytes + ByteUtils.sizeOfVarint(sizeOfBodyInBytes)); + ByteUtils.writeVarint(sizeOfBodyInBytes, buf); + buf.put(attributes); + ByteUtils.writeVarlong(timestampDelta, buf); + ByteUtils.writeVarint(offsetDelta, buf); + ByteUtils.writeVarint(-1, buf); // null key + ByteUtils.writeVarint(-1, buf); // null value + ByteUtils.writeVarint(1, buf); + ByteUtils.writeVarint(105, buf); // header key too long + 
buf.position(buf.limit()); + + buf.flip(); + assertThrows(InvalidRecordException.class, + () -> DefaultRecord.readFrom(buf, 0L, 0L, RecordBatch.NO_SEQUENCE, null)); + } + + @Test + public void testInvalidHeaderKeyPartial() { + byte attributes = 0; + long timestampDelta = 2; + int offsetDelta = 1; + int sizeOfBodyInBytes = 100; + + ByteBuffer buf = ByteBuffer.allocate(sizeOfBodyInBytes + ByteUtils.sizeOfVarint(sizeOfBodyInBytes)); + ByteUtils.writeVarint(sizeOfBodyInBytes, buf); + buf.put(attributes); + ByteUtils.writeVarlong(timestampDelta, buf); + ByteUtils.writeVarint(offsetDelta, buf); + ByteUtils.writeVarint(-1, buf); // null key + ByteUtils.writeVarint(-1, buf); // null value + ByteUtils.writeVarint(1, buf); + ByteUtils.writeVarint(105, buf); // header key too long + buf.position(buf.limit()); + + buf.flip(); + DataInputStream inputStream = new DataInputStream(new ByteBufferInputStream(buf)); + assertThrows(InvalidRecordException.class, + () -> DefaultRecord.readPartiallyFrom(inputStream, skipArray, 0L, 0L, RecordBatch.NO_SEQUENCE, null)); + } + + @Test + public void testNullHeaderKey() { + byte attributes = 0; + long timestampDelta = 2; + int offsetDelta = 1; + int sizeOfBodyInBytes = 100; + + ByteBuffer buf = ByteBuffer.allocate(sizeOfBodyInBytes + ByteUtils.sizeOfVarint(sizeOfBodyInBytes)); + ByteUtils.writeVarint(sizeOfBodyInBytes, buf); + buf.put(attributes); + ByteUtils.writeVarlong(timestampDelta, buf); + ByteUtils.writeVarint(offsetDelta, buf); + ByteUtils.writeVarint(-1, buf); // null key + ByteUtils.writeVarint(-1, buf); // null value + ByteUtils.writeVarint(1, buf); + ByteUtils.writeVarint(-1, buf); // null header key not allowed + buf.position(buf.limit()); + + buf.flip(); + assertThrows(InvalidRecordException.class, + () -> DefaultRecord.readFrom(buf, 0L, 0L, RecordBatch.NO_SEQUENCE, null)); + } + + @Test + public void testNullHeaderKeyPartial() { + byte attributes = 0; + long timestampDelta = 2; + int offsetDelta = 1; + int sizeOfBodyInBytes = 100; + + ByteBuffer buf = ByteBuffer.allocate(sizeOfBodyInBytes + ByteUtils.sizeOfVarint(sizeOfBodyInBytes)); + ByteUtils.writeVarint(sizeOfBodyInBytes, buf); + buf.put(attributes); + ByteUtils.writeVarlong(timestampDelta, buf); + ByteUtils.writeVarint(offsetDelta, buf); + ByteUtils.writeVarint(-1, buf); // null key + ByteUtils.writeVarint(-1, buf); // null value + ByteUtils.writeVarint(1, buf); + ByteUtils.writeVarint(-1, buf); // null header key not allowed + buf.position(buf.limit()); + + buf.flip(); + DataInputStream inputStream = new DataInputStream(new ByteBufferInputStream(buf)); + assertThrows(InvalidRecordException.class, + () -> DefaultRecord.readPartiallyFrom(inputStream, skipArray, 0L, 0L, RecordBatch.NO_SEQUENCE, null)); + } + + @Test + public void testInvalidHeaderValue() { + byte attributes = 0; + long timestampDelta = 2; + int offsetDelta = 1; + int sizeOfBodyInBytes = 100; + + ByteBuffer buf = ByteBuffer.allocate(sizeOfBodyInBytes + ByteUtils.sizeOfVarint(sizeOfBodyInBytes)); + ByteUtils.writeVarint(sizeOfBodyInBytes, buf); + buf.put(attributes); + ByteUtils.writeVarlong(timestampDelta, buf); + ByteUtils.writeVarint(offsetDelta, buf); + ByteUtils.writeVarint(-1, buf); // null key + ByteUtils.writeVarint(-1, buf); // null value + ByteUtils.writeVarint(1, buf); + ByteUtils.writeVarint(1, buf); + buf.put((byte) 1); + ByteUtils.writeVarint(105, buf); // header value too long + buf.position(buf.limit()); + + buf.flip(); + assertThrows(InvalidRecordException.class, + () -> DefaultRecord.readFrom(buf, 0L, 0L, 
RecordBatch.NO_SEQUENCE, null)); + } + + @Test + public void testInvalidHeaderValuePartial() { + byte attributes = 0; + long timestampDelta = 2; + int offsetDelta = 1; + int sizeOfBodyInBytes = 100; + + ByteBuffer buf = ByteBuffer.allocate(sizeOfBodyInBytes + ByteUtils.sizeOfVarint(sizeOfBodyInBytes)); + ByteUtils.writeVarint(sizeOfBodyInBytes, buf); + buf.put(attributes); + ByteUtils.writeVarlong(timestampDelta, buf); + ByteUtils.writeVarint(offsetDelta, buf); + ByteUtils.writeVarint(-1, buf); // null key + ByteUtils.writeVarint(-1, buf); // null value + ByteUtils.writeVarint(1, buf); + ByteUtils.writeVarint(1, buf); + buf.put((byte) 1); + ByteUtils.writeVarint(105, buf); // header value too long + buf.position(buf.limit()); + + buf.flip(); + DataInputStream inputStream = new DataInputStream(new ByteBufferInputStream(buf)); + assertThrows(InvalidRecordException.class, + () -> DefaultRecord.readPartiallyFrom(inputStream, skipArray, 0L, 0L, RecordBatch.NO_SEQUENCE, null)); + } + + @Test + public void testUnderflowReadingTimestamp() { + byte attributes = 0; + int sizeOfBodyInBytes = 1; + ByteBuffer buf = ByteBuffer.allocate(sizeOfBodyInBytes + ByteUtils.sizeOfVarint(sizeOfBodyInBytes)); + ByteUtils.writeVarint(sizeOfBodyInBytes, buf); + buf.put(attributes); + + buf.flip(); + assertThrows(InvalidRecordException.class, + () -> DefaultRecord.readFrom(buf, 0L, 0L, RecordBatch.NO_SEQUENCE, null)); + } + + @Test + public void testUnderflowReadingVarlong() { + byte attributes = 0; + int sizeOfBodyInBytes = 2; // one byte for attributes, one byte for partial timestamp + ByteBuffer buf = ByteBuffer.allocate(sizeOfBodyInBytes + ByteUtils.sizeOfVarint(sizeOfBodyInBytes) + 1); + ByteUtils.writeVarint(sizeOfBodyInBytes, buf); + buf.put(attributes); + ByteUtils.writeVarlong(156, buf); // needs 2 bytes to represent + buf.position(buf.limit() - 1); + + buf.flip(); + assertThrows(InvalidRecordException.class, + () -> DefaultRecord.readFrom(buf, 0L, 0L, RecordBatch.NO_SEQUENCE, null)); + } + + @Test + public void testInvalidVarlong() { + byte attributes = 0; + int sizeOfBodyInBytes = 11; // one byte for attributes, 10 bytes for max timestamp + ByteBuffer buf = ByteBuffer.allocate(sizeOfBodyInBytes + ByteUtils.sizeOfVarint(sizeOfBodyInBytes) + 1); + ByteUtils.writeVarint(sizeOfBodyInBytes, buf); + int recordStartPosition = buf.position(); + + buf.put(attributes); + ByteUtils.writeVarlong(Long.MAX_VALUE, buf); // takes 10 bytes + buf.put(recordStartPosition + 10, Byte.MIN_VALUE); // use an invalid final byte + + buf.flip(); + assertThrows(InvalidRecordException.class, + () -> DefaultRecord.readFrom(buf, 0L, 0L, RecordBatch.NO_SEQUENCE, null)); + } + + @Test + public void testSerdeNoSequence() throws IOException { + ByteBuffer key = ByteBuffer.wrap("hi".getBytes()); + ByteBuffer value = ByteBuffer.wrap("there".getBytes()); + long baseOffset = 37; + int offsetDelta = 10; + long baseTimestamp = System.currentTimeMillis(); + long timestampDelta = 323; + + ByteBufferOutputStream out = new ByteBufferOutputStream(1024); + DefaultRecord.writeTo(new DataOutputStream(out), offsetDelta, timestampDelta, key, value, new Header[0]); + ByteBuffer buffer = out.buffer(); + buffer.flip(); + + DefaultRecord record = DefaultRecord.readFrom(buffer, baseOffset, baseTimestamp, RecordBatch.NO_SEQUENCE, null); + assertNotNull(record); + assertEquals(RecordBatch.NO_SEQUENCE, record.sequence()); + } + + @Test + public void testInvalidSizeOfBodyInBytes() { + int sizeOfBodyInBytes = 10; + ByteBuffer buf = ByteBuffer.allocate(5); + 
ByteUtils.writeVarint(sizeOfBodyInBytes, buf); + + buf.flip(); + assertThrows(InvalidRecordException.class, + () -> DefaultRecord.readFrom(buf, 0L, 0L, RecordBatch.NO_SEQUENCE, null)); + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/record/EndTransactionMarkerTest.java b/clients/src/test/java/org/apache/kafka/common/record/EndTransactionMarkerTest.java new file mode 100644 index 0000000..4ac1417 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/record/EndTransactionMarkerTest.java @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.record; + +import org.apache.kafka.common.InvalidRecordException; +import org.junit.jupiter.api.Test; + +import java.nio.ByteBuffer; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +public class EndTransactionMarkerTest { + + @Test + public void testUnknownControlTypeNotAllowed() { + assertThrows(IllegalArgumentException.class, + () -> new EndTransactionMarker(ControlRecordType.UNKNOWN, 24)); + } + + @Test + public void testCannotDeserializeUnknownControlType() { + assertThrows(IllegalArgumentException.class, + () -> EndTransactionMarker.deserializeValue(ControlRecordType.UNKNOWN, ByteBuffer.wrap(new byte[0]))); + } + + @Test + public void testIllegalNegativeVersion() { + ByteBuffer buffer = ByteBuffer.allocate(2); + buffer.putShort((short) -1); + buffer.flip(); + assertThrows(InvalidRecordException.class, () -> EndTransactionMarker.deserializeValue(ControlRecordType.ABORT, buffer)); + } + + @Test + public void testNotEnoughBytes() { + assertThrows(InvalidRecordException.class, + () -> EndTransactionMarker.deserializeValue(ControlRecordType.COMMIT, ByteBuffer.wrap(new byte[0]))); + } + + @Test + public void testSerde() { + int coordinatorEpoch = 79; + EndTransactionMarker marker = new EndTransactionMarker(ControlRecordType.COMMIT, coordinatorEpoch); + ByteBuffer buffer = marker.serializeValue(); + EndTransactionMarker deserialized = EndTransactionMarker.deserializeValue(ControlRecordType.COMMIT, buffer); + assertEquals(coordinatorEpoch, deserialized.coordinatorEpoch()); + } + + @Test + public void testDeserializeNewerVersion() { + int coordinatorEpoch = 79; + ByteBuffer buffer = ByteBuffer.allocate(8); + buffer.putShort((short) 5); + buffer.putInt(coordinatorEpoch); + buffer.putShort((short) 0); // unexpected data + buffer.flip(); + EndTransactionMarker deserialized = EndTransactionMarker.deserializeValue(ControlRecordType.COMMIT, buffer); + assertEquals(coordinatorEpoch, deserialized.coordinatorEpoch()); + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/record/FileLogInputStreamTest.java 
b/clients/src/test/java/org/apache/kafka/common/record/FileLogInputStreamTest.java new file mode 100644 index 0000000..8e204a9 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/record/FileLogInputStreamTest.java @@ -0,0 +1,321 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.record; + +import org.apache.kafka.common.record.FileLogInputStream.FileChannelRecordBatch; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.test.TestUtils; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtensionContext; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.ArgumentsProvider; +import org.junit.jupiter.params.provider.ArgumentsSource; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.List; +import java.util.stream.Stream; + +import static java.util.Arrays.asList; +import static org.apache.kafka.common.record.RecordBatch.MAGIC_VALUE_V0; +import static org.apache.kafka.common.record.RecordBatch.MAGIC_VALUE_V1; +import static org.apache.kafka.common.record.RecordBatch.MAGIC_VALUE_V2; +import static org.apache.kafka.common.record.RecordBatch.NO_TIMESTAMP; +import static org.apache.kafka.common.record.TimestampType.CREATE_TIME; +import static org.apache.kafka.common.record.TimestampType.NO_TIMESTAMP_TYPE; +import static org.apache.kafka.test.TestUtils.tempFile; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class FileLogInputStreamTest { + + private static class Args { + final byte magic; + final CompressionType compression; + + public Args(byte magic, CompressionType compression) { + this.magic = magic; + this.compression = compression; + } + + @Override + public String toString() { + return "magic=" + magic + + ", compression=" + compression; + } + } + + private static class FileLogInputStreamArgumentsProvider implements ArgumentsProvider { + + @Override + public Stream provideArguments(ExtensionContext context) { + List arguments = new ArrayList<>(); + for (byte magic : asList(MAGIC_VALUE_V0, MAGIC_VALUE_V1, MAGIC_VALUE_V2)) + for (CompressionType type: CompressionType.values()) + arguments.add(Arguments.of(new Args(magic, type))); + return arguments.stream(); + } + } + + @ParameterizedTest + @ArgumentsSource(FileLogInputStreamArgumentsProvider.class) + public void testWriteTo(Args args) throws IOException { + CompressionType compression = 
args.compression; + byte magic = args.magic; + if (compression == CompressionType.ZSTD && magic < MAGIC_VALUE_V2) + return; + + try (FileRecords fileRecords = FileRecords.open(tempFile())) { + fileRecords.append(MemoryRecords.withRecords(magic, compression, new SimpleRecord("foo".getBytes()))); + fileRecords.flush(); + + FileLogInputStream logInputStream = new FileLogInputStream(fileRecords, 0, fileRecords.sizeInBytes()); + + FileChannelRecordBatch batch = logInputStream.nextBatch(); + assertNotNull(batch); + assertEquals(magic, batch.magic()); + + ByteBuffer buffer = ByteBuffer.allocate(128); + batch.writeTo(buffer); + buffer.flip(); + + MemoryRecords memRecords = MemoryRecords.readableRecords(buffer); + List records = Utils.toList(memRecords.records().iterator()); + assertEquals(1, records.size()); + Record record0 = records.get(0); + assertTrue(record0.hasMagic(magic)); + assertEquals("foo", Utils.utf8(record0.value(), record0.valueSize())); + } + } + + @ParameterizedTest + @ArgumentsSource(FileLogInputStreamArgumentsProvider.class) + public void testSimpleBatchIteration(Args args) throws IOException { + CompressionType compression = args.compression; + byte magic = args.magic; + if (compression == CompressionType.ZSTD && magic < MAGIC_VALUE_V2) + return; + + try (FileRecords fileRecords = FileRecords.open(tempFile())) { + SimpleRecord firstBatchRecord = new SimpleRecord(3241324L, "a".getBytes(), "foo".getBytes()); + SimpleRecord secondBatchRecord = new SimpleRecord(234280L, "b".getBytes(), "bar".getBytes()); + + fileRecords.append(MemoryRecords.withRecords(magic, 0L, compression, CREATE_TIME, firstBatchRecord)); + fileRecords.append(MemoryRecords.withRecords(magic, 1L, compression, CREATE_TIME, secondBatchRecord)); + fileRecords.flush(); + + FileLogInputStream logInputStream = new FileLogInputStream(fileRecords, 0, fileRecords.sizeInBytes()); + + FileChannelRecordBatch firstBatch = logInputStream.nextBatch(); + assertGenericRecordBatchData(args, firstBatch, 0L, 3241324L, firstBatchRecord); + assertNoProducerData(firstBatch); + + FileChannelRecordBatch secondBatch = logInputStream.nextBatch(); + assertGenericRecordBatchData(args, secondBatch, 1L, 234280L, secondBatchRecord); + assertNoProducerData(secondBatch); + + assertNull(logInputStream.nextBatch()); + } + } + + @ParameterizedTest + @ArgumentsSource(FileLogInputStreamArgumentsProvider.class) + public void testBatchIterationWithMultipleRecordsPerBatch(Args args) throws IOException { + CompressionType compression = args.compression; + byte magic = args.magic; + if (magic < MAGIC_VALUE_V2 && compression == CompressionType.NONE) + return; + + if (compression == CompressionType.ZSTD && magic < MAGIC_VALUE_V2) + return; + + try (FileRecords fileRecords = FileRecords.open(tempFile())) { + SimpleRecord[] firstBatchRecords = new SimpleRecord[]{ + new SimpleRecord(3241324L, "a".getBytes(), "1".getBytes()), + new SimpleRecord(234280L, "b".getBytes(), "2".getBytes()) + }; + + SimpleRecord[] secondBatchRecords = new SimpleRecord[]{ + new SimpleRecord(238423489L, "c".getBytes(), "3".getBytes()), + new SimpleRecord(897839L, null, "4".getBytes()), + new SimpleRecord(8234020L, "e".getBytes(), null) + }; + + fileRecords.append(MemoryRecords.withRecords(magic, 0L, compression, CREATE_TIME, firstBatchRecords)); + fileRecords.append(MemoryRecords.withRecords(magic, 1L, compression, CREATE_TIME, secondBatchRecords)); + fileRecords.flush(); + + FileLogInputStream logInputStream = new FileLogInputStream(fileRecords, 0, fileRecords.sizeInBytes()); + + 
FileChannelRecordBatch firstBatch = logInputStream.nextBatch(); + assertNoProducerData(firstBatch); + assertGenericRecordBatchData(args, firstBatch, 0L, 3241324L, firstBatchRecords); + + FileChannelRecordBatch secondBatch = logInputStream.nextBatch(); + assertNoProducerData(secondBatch); + assertGenericRecordBatchData(args, secondBatch, 1L, 238423489L, secondBatchRecords); + + assertNull(logInputStream.nextBatch()); + } + } + + @ParameterizedTest + @ArgumentsSource(FileLogInputStreamArgumentsProvider.class) + public void testBatchIterationV2(Args args) throws IOException { + CompressionType compression = args.compression; + byte magic = args.magic; + if (magic != MAGIC_VALUE_V2) + return; + + try (FileRecords fileRecords = FileRecords.open(tempFile())) { + long producerId = 83843L; + short producerEpoch = 15; + int baseSequence = 234; + int partitionLeaderEpoch = 9832; + + SimpleRecord[] firstBatchRecords = new SimpleRecord[]{ + new SimpleRecord(3241324L, "a".getBytes(), "1".getBytes()), + new SimpleRecord(234280L, "b".getBytes(), "2".getBytes()) + }; + + SimpleRecord[] secondBatchRecords = new SimpleRecord[]{ + new SimpleRecord(238423489L, "c".getBytes(), "3".getBytes()), + new SimpleRecord(897839L, null, "4".getBytes()), + new SimpleRecord(8234020L, "e".getBytes(), null) + }; + + fileRecords.append(MemoryRecords.withIdempotentRecords(magic, 15L, compression, producerId, + producerEpoch, baseSequence, partitionLeaderEpoch, firstBatchRecords)); + fileRecords.append(MemoryRecords.withTransactionalRecords(magic, 27L, compression, producerId, + producerEpoch, baseSequence + firstBatchRecords.length, partitionLeaderEpoch, secondBatchRecords)); + fileRecords.flush(); + + FileLogInputStream logInputStream = new FileLogInputStream(fileRecords, 0, fileRecords.sizeInBytes()); + + FileChannelRecordBatch firstBatch = logInputStream.nextBatch(); + assertProducerData(firstBatch, producerId, producerEpoch, baseSequence, false, firstBatchRecords); + assertGenericRecordBatchData(args, firstBatch, 15L, 3241324L, firstBatchRecords); + assertEquals(partitionLeaderEpoch, firstBatch.partitionLeaderEpoch()); + + FileChannelRecordBatch secondBatch = logInputStream.nextBatch(); + assertProducerData(secondBatch, producerId, producerEpoch, baseSequence + firstBatchRecords.length, + true, secondBatchRecords); + assertGenericRecordBatchData(args, secondBatch, 27L, 238423489L, secondBatchRecords); + assertEquals(partitionLeaderEpoch, secondBatch.partitionLeaderEpoch()); + + assertNull(logInputStream.nextBatch()); + } + } + + @ParameterizedTest + @ArgumentsSource(FileLogInputStreamArgumentsProvider.class) + public void testBatchIterationIncompleteBatch(Args args) throws IOException { + CompressionType compression = args.compression; + byte magic = args.magic; + if (compression == CompressionType.ZSTD && magic < MAGIC_VALUE_V2) + return; + + try (FileRecords fileRecords = FileRecords.open(tempFile())) { + SimpleRecord firstBatchRecord = new SimpleRecord(100L, "foo".getBytes()); + SimpleRecord secondBatchRecord = new SimpleRecord(200L, "bar".getBytes()); + + fileRecords.append(MemoryRecords.withRecords(magic, 0L, compression, CREATE_TIME, firstBatchRecord)); + fileRecords.append(MemoryRecords.withRecords(magic, 1L, compression, CREATE_TIME, secondBatchRecord)); + fileRecords.flush(); + fileRecords.truncateTo(fileRecords.sizeInBytes() - 13); + + FileLogInputStream logInputStream = new FileLogInputStream(fileRecords, 0, fileRecords.sizeInBytes()); + + FileChannelRecordBatch firstBatch = logInputStream.nextBatch(); + 
assertNoProducerData(firstBatch); + assertGenericRecordBatchData(args, firstBatch, 0L, 100L, firstBatchRecord); + + assertNull(logInputStream.nextBatch()); + } + } + + @Test + public void testNextBatchSelectionWithMaxedParams() throws IOException { + try (FileRecords fileRecords = FileRecords.open(tempFile())) { + FileLogInputStream logInputStream = new FileLogInputStream(fileRecords, Integer.MAX_VALUE, Integer.MAX_VALUE); + assertNull(logInputStream.nextBatch()); + } + } + + @Test + public void testNextBatchSelectionWithZeroedParams() throws IOException { + try (FileRecords fileRecords = FileRecords.open(tempFile())) { + FileLogInputStream logInputStream = new FileLogInputStream(fileRecords, 0, 0); + assertNull(logInputStream.nextBatch()); + } + } + + private void assertProducerData(RecordBatch batch, long producerId, short producerEpoch, int baseSequence, + boolean isTransactional, SimpleRecord... records) { + assertEquals(producerId, batch.producerId()); + assertEquals(producerEpoch, batch.producerEpoch()); + assertEquals(baseSequence, batch.baseSequence()); + assertEquals(baseSequence + records.length - 1, batch.lastSequence()); + assertEquals(isTransactional, batch.isTransactional()); + } + + private void assertNoProducerData(RecordBatch batch) { + assertEquals(RecordBatch.NO_PRODUCER_ID, batch.producerId()); + assertEquals(RecordBatch.NO_PRODUCER_EPOCH, batch.producerEpoch()); + assertEquals(RecordBatch.NO_SEQUENCE, batch.baseSequence()); + assertEquals(RecordBatch.NO_SEQUENCE, batch.lastSequence()); + assertFalse(batch.isTransactional()); + } + + private void assertGenericRecordBatchData(Args args, RecordBatch batch, long baseOffset, long maxTimestamp, + SimpleRecord... records) { + CompressionType compression = args.compression; + byte magic = args.magic; + assertEquals(magic, batch.magic()); + assertEquals(compression, batch.compressionType()); + + if (magic == MAGIC_VALUE_V0) { + assertEquals(NO_TIMESTAMP_TYPE, batch.timestampType()); + } else { + assertEquals(CREATE_TIME, batch.timestampType()); + assertEquals(maxTimestamp, batch.maxTimestamp()); + } + + assertEquals(baseOffset + records.length - 1, batch.lastOffset()); + if (magic >= MAGIC_VALUE_V2) + assertEquals(Integer.valueOf(records.length), batch.countOrNull()); + + assertEquals(baseOffset, batch.baseOffset()); + assertTrue(batch.isValid()); + + List batchRecords = TestUtils.toList(batch); + for (int i = 0; i < records.length; i++) { + assertEquals(baseOffset + i, batchRecords.get(i).offset()); + assertEquals(records[i].key(), batchRecords.get(i).key()); + assertEquals(records[i].value(), batchRecords.get(i).value()); + if (magic == MAGIC_VALUE_V0) + assertEquals(NO_TIMESTAMP, batchRecords.get(i).timestamp()); + else + assertEquals(records[i].timestamp(), batchRecords.get(i).timestamp()); + } + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/record/FileRecordsTest.java b/clients/src/test/java/org/apache/kafka/common/record/FileRecordsTest.java new file mode 100644 index 0000000..2e9ff33 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/record/FileRecordsTest.java @@ -0,0 +1,719 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.record; + +import org.apache.kafka.common.KafkaException; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.header.Header; +import org.apache.kafka.common.header.internals.RecordHeader; +import org.apache.kafka.common.network.TransferableChannel; +import org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.test.TestUtils; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.Mockito; + +import java.io.File; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.channels.FileChannel; +import java.util.ArrayList; +import java.util.Collections; +import java.util.Iterator; +import java.util.List; +import java.util.Optional; +import java.util.Random; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; + +import static java.util.Arrays.asList; +import static org.apache.kafka.common.utils.Utils.utf8; +import static org.apache.kafka.test.TestUtils.tempFile; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyLong; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.atLeastOnce; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +public class FileRecordsTest { + + private byte[][] values = new byte[][] { + "abcd".getBytes(), + "efgh".getBytes(), + "ijkl".getBytes() + }; + private FileRecords fileRecords; + private Time time; + + @BeforeEach + public void setup() throws IOException { + this.fileRecords = createFileRecords(values); + this.time = new MockTime(); + } + + @AfterEach + public void cleanup() throws IOException { + this.fileRecords.close(); + } + + @Test + public void testAppendProtectsFromOverflow() throws Exception { + File fileMock = mock(File.class); + FileChannel fileChannelMock = mock(FileChannel.class); + when(fileChannelMock.size()).thenReturn((long) Integer.MAX_VALUE); + + FileRecords records = new FileRecords(fileMock, fileChannelMock, 0, Integer.MAX_VALUE, false); + assertThrows(IllegalArgumentException.class, () -> append(records, values)); + } + + @Test + public void testOpenOversizeFile() throws Exception { + File fileMock = mock(File.class); + FileChannel fileChannelMock = mock(FileChannel.class); + when(fileChannelMock.size()).thenReturn(Integer.MAX_VALUE + 5L); + + 
assertThrows(KafkaException.class, () -> new FileRecords(fileMock, fileChannelMock, 0, Integer.MAX_VALUE, false));
+    }
+
+    @Test
+    public void testOutOfRangeSlice() {
+        assertThrows(IllegalArgumentException.class,
+            () -> this.fileRecords.slice(fileRecords.sizeInBytes() + 1, 15).sizeInBytes());
+    }
+
+    /**
+     * Test that the cached size variable matches the actual file size as we append messages
+     */
+    @Test
+    public void testFileSize() throws IOException {
+        assertEquals(fileRecords.channel().size(), fileRecords.sizeInBytes());
+        for (int i = 0; i < 20; i++) {
+            fileRecords.append(MemoryRecords.withRecords(CompressionType.NONE, new SimpleRecord("abcd".getBytes())));
+            assertEquals(fileRecords.channel().size(), fileRecords.sizeInBytes());
+        }
+    }
+
+    /**
+     * Test that adding invalid bytes to the end of the log doesn't break iteration
+     */
+    @Test
+    public void testIterationOverPartialAndTruncation() throws IOException {
+        testPartialWrite(0, fileRecords);
+        testPartialWrite(2, fileRecords);
+        testPartialWrite(4, fileRecords);
+        testPartialWrite(5, fileRecords);
+        testPartialWrite(6, fileRecords);
+    }
+
+    @Test
+    public void testSliceSizeLimitWithConcurrentWrite() throws Exception {
+        FileRecords log = FileRecords.open(tempFile());
+        ExecutorService executor = Executors.newFixedThreadPool(2);
+        int maxSizeInBytes = 16384;
+
+        try {
+            Future<Object> readerCompletion = executor.submit(() -> {
+                while (log.sizeInBytes() < maxSizeInBytes) {
+                    int currentSize = log.sizeInBytes();
+                    FileRecords slice = log.slice(0, currentSize);
+                    assertEquals(currentSize, slice.sizeInBytes());
+                }
+                return null;
+            });
+
+            Future<Object> writerCompletion = executor.submit(() -> {
+                while (log.sizeInBytes() < maxSizeInBytes) {
+                    append(log, values);
+                }
+                return null;
+            });
+
+            writerCompletion.get();
+            readerCompletion.get();
+        } finally {
+            executor.shutdownNow();
+        }
+    }
+
+    private void testPartialWrite(int size, FileRecords fileRecords) throws IOException {
+        ByteBuffer buffer = ByteBuffer.allocate(size);
+        for (int i = 0; i < size; i++)
+            buffer.put((byte) 0);
+
+        buffer.rewind();
+
+        fileRecords.channel().write(buffer);
+
+        // appending those bytes should not change the contents
+        Iterator<Record> records = fileRecords.records().iterator();
+        for (byte[] value : values) {
+            assertTrue(records.hasNext());
+            assertEquals(records.next().value(), ByteBuffer.wrap(value));
+        }
+    }
+
+    /**
+     * Iterating over the file does file reads but shouldn't change the position of the underlying FileChannel.
+     */
+    @Test
+    public void testIterationDoesntChangePosition() throws IOException {
+        long position = fileRecords.channel().position();
+        Iterator<Record> records = fileRecords.records().iterator();
+        for (byte[] value : values) {
+            assertTrue(records.hasNext());
+            assertEquals(records.next().value(), ByteBuffer.wrap(value));
+        }
+        assertEquals(position, fileRecords.channel().position());
+    }
+
+    /**
+     * Test a simple append and read.
+ */ + @Test + public void testRead() throws IOException { + FileRecords read = fileRecords.slice(0, fileRecords.sizeInBytes()); + assertEquals(fileRecords.sizeInBytes(), read.sizeInBytes()); + TestUtils.checkEquals(fileRecords.batches(), read.batches()); + + List items = batches(read); + RecordBatch first = items.get(0); + + // read from second message until the end + read = fileRecords.slice(first.sizeInBytes(), fileRecords.sizeInBytes() - first.sizeInBytes()); + assertEquals(fileRecords.sizeInBytes() - first.sizeInBytes(), read.sizeInBytes()); + assertEquals(items.subList(1, items.size()), batches(read), "Read starting from the second message"); + + // read from second message and size is past the end of the file + read = fileRecords.slice(first.sizeInBytes(), fileRecords.sizeInBytes()); + assertEquals(fileRecords.sizeInBytes() - first.sizeInBytes(), read.sizeInBytes()); + assertEquals(items.subList(1, items.size()), batches(read), "Read starting from the second message"); + + // read from second message and position + size overflows + read = fileRecords.slice(first.sizeInBytes(), Integer.MAX_VALUE); + assertEquals(fileRecords.sizeInBytes() - first.sizeInBytes(), read.sizeInBytes()); + assertEquals(items.subList(1, items.size()), batches(read), "Read starting from the second message"); + + // read from second message and size is past the end of the file on a view/slice + read = fileRecords.slice(1, fileRecords.sizeInBytes() - 1) + .slice(first.sizeInBytes() - 1, fileRecords.sizeInBytes()); + assertEquals(fileRecords.sizeInBytes() - first.sizeInBytes(), read.sizeInBytes()); + assertEquals(items.subList(1, items.size()), batches(read), "Read starting from the second message"); + + // read from second message and position + size overflows on a view/slice + read = fileRecords.slice(1, fileRecords.sizeInBytes() - 1) + .slice(first.sizeInBytes() - 1, Integer.MAX_VALUE); + assertEquals(fileRecords.sizeInBytes() - first.sizeInBytes(), read.sizeInBytes()); + assertEquals(items.subList(1, items.size()), batches(read), "Read starting from the second message"); + + // read a single message starting from second message + RecordBatch second = items.get(1); + read = fileRecords.slice(first.sizeInBytes(), second.sizeInBytes()); + assertEquals(second.sizeInBytes(), read.sizeInBytes()); + assertEquals(Collections.singletonList(second), batches(read), "Read a single message starting from the second message"); + } + + /** + * Test the MessageSet.searchFor API. 
+ */ + @Test + public void testSearch() throws IOException { + // append a new message with a high offset + SimpleRecord lastMessage = new SimpleRecord("test".getBytes()); + fileRecords.append(MemoryRecords.withRecords(50L, CompressionType.NONE, lastMessage)); + + List batches = batches(fileRecords); + int position = 0; + + int message1Size = batches.get(0).sizeInBytes(); + assertEquals(new FileRecords.LogOffsetPosition(0L, position, message1Size), + fileRecords.searchForOffsetWithSize(0, 0), + "Should be able to find the first message by its offset"); + position += message1Size; + + int message2Size = batches.get(1).sizeInBytes(); + assertEquals(new FileRecords.LogOffsetPosition(1L, position, message2Size), + fileRecords.searchForOffsetWithSize(1, 0), + "Should be able to find second message when starting from 0"); + assertEquals(new FileRecords.LogOffsetPosition(1L, position, message2Size), + fileRecords.searchForOffsetWithSize(1, position), + "Should be able to find second message starting from its offset"); + position += message2Size + batches.get(2).sizeInBytes(); + + int message4Size = batches.get(3).sizeInBytes(); + assertEquals(new FileRecords.LogOffsetPosition(50L, position, message4Size), + fileRecords.searchForOffsetWithSize(3, position), + "Should be able to find fourth message from a non-existent offset"); + assertEquals(new FileRecords.LogOffsetPosition(50L, position, message4Size), + fileRecords.searchForOffsetWithSize(50, position), + "Should be able to find fourth message by correct offset"); + } + + /** + * Test that the message set iterator obeys start and end slicing + */ + @Test + public void testIteratorWithLimits() throws IOException { + RecordBatch batch = batches(fileRecords).get(1); + int start = fileRecords.searchForOffsetWithSize(1, 0).position; + int size = batch.sizeInBytes(); + FileRecords slice = fileRecords.slice(start, size); + assertEquals(Collections.singletonList(batch), batches(slice)); + FileRecords slice2 = fileRecords.slice(start, size - 1); + assertEquals(Collections.emptyList(), batches(slice2)); + } + + /** + * Test the truncateTo method lops off messages and appropriately updates the size + */ + @Test + public void testTruncate() throws IOException { + RecordBatch batch = batches(fileRecords).get(0); + int end = fileRecords.searchForOffsetWithSize(1, 0).position; + fileRecords.truncateTo(end); + assertEquals(Collections.singletonList(batch), batches(fileRecords)); + assertEquals(batch.sizeInBytes(), fileRecords.sizeInBytes()); + } + + /** + * Test that truncateTo only calls truncate on the FileChannel if the size of the + * FileChannel is bigger than the target size. This is important because some JVMs + * change the mtime of the file, even if truncate should do nothing. + */ + @Test + public void testTruncateNotCalledIfSizeIsSameAsTargetSize() throws IOException { + FileChannel channelMock = mock(FileChannel.class); + + when(channelMock.size()).thenReturn(42L); + when(channelMock.position(42L)).thenReturn(null); + + FileRecords fileRecords = new FileRecords(tempFile(), channelMock, 0, Integer.MAX_VALUE, false); + fileRecords.truncateTo(42); + + verify(channelMock, atLeastOnce()).size(); + verify(channelMock, times(0)).truncate(anyLong()); + } + + /** + * Expect a KafkaException if targetSize is bigger than the size of + * the FileRecords. 
+ */ + @Test + public void testTruncateNotCalledIfSizeIsBiggerThanTargetSize() throws IOException { + FileChannel channelMock = mock(FileChannel.class); + + when(channelMock.size()).thenReturn(42L); + + FileRecords fileRecords = new FileRecords(tempFile(), channelMock, 0, Integer.MAX_VALUE, false); + + try { + fileRecords.truncateTo(43); + fail("Should throw KafkaException"); + } catch (KafkaException e) { + // expected + } + + verify(channelMock, atLeastOnce()).size(); + } + + /** + * see #testTruncateNotCalledIfSizeIsSameAsTargetSize + */ + @Test + public void testTruncateIfSizeIsDifferentToTargetSize() throws IOException { + FileChannel channelMock = mock(FileChannel.class); + + when(channelMock.size()).thenReturn(42L); + when(channelMock.truncate(anyLong())).thenReturn(channelMock); + + FileRecords fileRecords = new FileRecords(tempFile(), channelMock, 0, Integer.MAX_VALUE, false); + fileRecords.truncateTo(23); + + verify(channelMock, atLeastOnce()).size(); + verify(channelMock).truncate(23); + } + + /** + * Test the new FileRecords with pre allocate as true + */ + @Test + public void testPreallocateTrue() throws IOException { + File temp = tempFile(); + FileRecords fileRecords = FileRecords.open(temp, false, 1024 * 1024, true); + long position = fileRecords.channel().position(); + int size = fileRecords.sizeInBytes(); + assertEquals(0, position); + assertEquals(0, size); + assertEquals(1024 * 1024, temp.length()); + } + + /** + * Test the new FileRecords with pre allocate as false + */ + @Test + public void testPreallocateFalse() throws IOException { + File temp = tempFile(); + FileRecords set = FileRecords.open(temp, false, 1024 * 1024, false); + long position = set.channel().position(); + int size = set.sizeInBytes(); + assertEquals(0, position); + assertEquals(0, size); + assertEquals(0, temp.length()); + } + + /** + * Test the new FileRecords with pre allocate as true and file has been clearly shut down, the file will be truncate to end of valid data. 
+     */
+    @Test
+    public void testPreallocateClearShutdown() throws IOException {
+        File temp = tempFile();
+        FileRecords fileRecords = FileRecords.open(temp, false, 1024 * 1024, true);
+        append(fileRecords, values);
+
+        int oldPosition = (int) fileRecords.channel().position();
+        int oldSize = fileRecords.sizeInBytes();
+        assertEquals(this.fileRecords.sizeInBytes(), oldPosition);
+        assertEquals(this.fileRecords.sizeInBytes(), oldSize);
+        fileRecords.close();
+
+        File tempReopen = new File(temp.getAbsolutePath());
+        FileRecords setReopen = FileRecords.open(tempReopen, true, 1024 * 1024, true);
+        int position = (int) setReopen.channel().position();
+        int size = setReopen.sizeInBytes();
+
+        assertEquals(oldPosition, position);
+        assertEquals(oldPosition, size);
+        assertEquals(oldPosition, tempReopen.length());
+    }
+
+    @Test
+    public void testFormatConversionWithPartialMessage() throws IOException {
+        RecordBatch batch = batches(fileRecords).get(1);
+        int start = fileRecords.searchForOffsetWithSize(1, 0).position;
+        int size = batch.sizeInBytes();
+        FileRecords slice = fileRecords.slice(start, size - 1);
+        Records messageV0 = slice.downConvert(RecordBatch.MAGIC_VALUE_V0, 0, time).records();
+        assertTrue(batches(messageV0).isEmpty(), "No message should be there");
+        assertEquals(size - 1, messageV0.sizeInBytes(), "There should be " + (size - 1) + " bytes");
+
+        // Lazy down-conversion will not return any messages for a partial input batch
+        TopicPartition tp = new TopicPartition("topic-1", 0);
+        LazyDownConversionRecords lazyRecords = new LazyDownConversionRecords(tp, slice, RecordBatch.MAGIC_VALUE_V0, 0, Time.SYSTEM);
+        Iterator<ConvertedRecords<?>> it = lazyRecords.iterator(16 * 1024L);
+        assertFalse(it.hasNext(), "No messages should be returned");
+    }
+
+    @Test
+    public void testFormatConversionWithNoMessages() throws IOException {
+        TopicPartition tp = new TopicPartition("topic-1", 0);
+        LazyDownConversionRecords lazyRecords = new LazyDownConversionRecords(tp, MemoryRecords.EMPTY, RecordBatch.MAGIC_VALUE_V0,
+            0, Time.SYSTEM);
+        assertEquals(0, lazyRecords.sizeInBytes());
+        Iterator<ConvertedRecords<?>> it = lazyRecords.iterator(16 * 1024L);
+        assertFalse(it.hasNext(), "No messages should be returned");
+    }
+
+    @Test
+    public void testSearchForTimestamp() throws IOException {
+        for (RecordVersion version : RecordVersion.values()) {
+            testSearchForTimestamp(version);
+        }
+    }
+
+    private void testSearchForTimestamp(RecordVersion version) throws IOException {
+        File temp = tempFile();
+        FileRecords fileRecords = FileRecords.open(temp, false, 1024 * 1024, true);
+        appendWithOffsetAndTimestamp(fileRecords, version, 10L, 5, 0);
+        appendWithOffsetAndTimestamp(fileRecords, version, 11L, 6, 1);
+
+        assertFoundTimestamp(new FileRecords.TimestampAndOffset(10L, 5, Optional.of(0)),
+            fileRecords.searchForTimestamp(9L, 0, 0L), version);
+        assertFoundTimestamp(new FileRecords.TimestampAndOffset(10L, 5, Optional.of(0)),
+            fileRecords.searchForTimestamp(10L, 0, 0L), version);
+        assertFoundTimestamp(new FileRecords.TimestampAndOffset(11L, 6, Optional.of(1)),
+            fileRecords.searchForTimestamp(11L, 0, 0L), version);
+        assertNull(fileRecords.searchForTimestamp(12L, 0, 0L));
+    }
+
+    private void assertFoundTimestamp(FileRecords.TimestampAndOffset expected,
+                                      FileRecords.TimestampAndOffset actual,
+                                      RecordVersion version) {
+        if (version == RecordVersion.V0) {
+            assertNull(actual, "Expected no match for message format v0");
+        } else {
+            assertNotNull(actual, "Expected to find timestamp for message format " + version);
+            assertEquals(expected.timestamp, actual.timestamp, "Expected matching timestamps for message format " + version);
+            assertEquals(expected.offset, actual.offset, "Expected matching offsets for message format " + version);
+            Optional<Integer> expectedLeaderEpoch = version.value >= RecordVersion.V2.value ?
+                expected.leaderEpoch : Optional.empty();
+            assertEquals(expectedLeaderEpoch, actual.leaderEpoch, "Non-matching leader epoch for version " + version);
+        }
+    }
+
+    private void appendWithOffsetAndTimestamp(FileRecords fileRecords,
+                                              RecordVersion recordVersion,
+                                              long timestamp,
+                                              long offset,
+                                              int leaderEpoch) throws IOException {
+        ByteBuffer buffer = ByteBuffer.allocate(128);
+        MemoryRecordsBuilder builder = MemoryRecords.builder(buffer, recordVersion.value,
+            CompressionType.NONE, TimestampType.CREATE_TIME, offset, timestamp, leaderEpoch);
+        builder.append(new SimpleRecord(timestamp, new byte[0], new byte[0]));
+        fileRecords.append(builder.build());
+    }
+
+    @Test
+    public void testDownconversionAfterMessageFormatDowngrade() throws IOException {
+        // random bytes
+        Random random = new Random();
+        byte[] bytes = new byte[3000];
+        random.nextBytes(bytes);
+
+        // records
+        CompressionType compressionType = CompressionType.GZIP;
+        List<Long> offsets = asList(0L, 1L);
+        List<Byte> magic = asList(RecordBatch.MAGIC_VALUE_V2, RecordBatch.MAGIC_VALUE_V1); // downgrade message format from v2 to v1
+        List<SimpleRecord> records = asList(
+            new SimpleRecord(1L, "k1".getBytes(), bytes),
+            new SimpleRecord(2L, "k2".getBytes(), bytes));
+        byte toMagic = 1;
+
+        // create MemoryRecords
+        ByteBuffer buffer = ByteBuffer.allocate(8000);
+        for (int i = 0; i < records.size(); i++) {
+            MemoryRecordsBuilder builder = MemoryRecords.builder(buffer, magic.get(i), compressionType, TimestampType.CREATE_TIME, 0L);
+            builder.appendWithOffset(offsets.get(i), records.get(i));
+            builder.close();
+        }
+        buffer.flip();
+
+        // create FileRecords, down-convert and verify
+        try (FileRecords fileRecords = FileRecords.open(tempFile())) {
+            fileRecords.append(MemoryRecords.readableRecords(buffer));
+            fileRecords.flush();
+            downConvertAndVerifyRecords(records, offsets, fileRecords, compressionType, toMagic, 0L, time);
+        }
+    }
+
+    @Test
+    public void testConversion() throws IOException {
+        doTestConversion(CompressionType.NONE, RecordBatch.MAGIC_VALUE_V0);
+        doTestConversion(CompressionType.GZIP, RecordBatch.MAGIC_VALUE_V0);
+        doTestConversion(CompressionType.NONE, RecordBatch.MAGIC_VALUE_V1);
+        doTestConversion(CompressionType.GZIP, RecordBatch.MAGIC_VALUE_V1);
+        doTestConversion(CompressionType.NONE, RecordBatch.MAGIC_VALUE_V2);
+        doTestConversion(CompressionType.GZIP, RecordBatch.MAGIC_VALUE_V2);
+    }
+
+    @Test
+    public void testBytesLengthOfWriteTo() throws IOException {
+
+        int size = fileRecords.sizeInBytes();
+        long firstWritten = size / 3;
+
+        TransferableChannel channel = Mockito.mock(TransferableChannel.class);
+
+        // First we write some of the data
+        fileRecords.writeTo(channel, 0, (int) firstWritten);
+        verify(channel).transferFrom(any(), anyLong(), eq(firstWritten));
+
+        // Ensure (length > size - firstWritten)
+        int secondWrittenLength = size - (int) firstWritten + 1;
+        fileRecords.writeTo(channel, firstWritten, secondWrittenLength);
+        // But we still only write (size - firstWritten), which the old version did not enforce
+        verify(channel).transferFrom(any(), anyLong(), eq(size - firstWritten));
+    }
+
+    private void doTestConversion(CompressionType compressionType, byte toMagic) throws IOException {
+        List<Long> offsets = asList(0L, 2L, 3L, 9L, 11L, 15L, 16L, 17L, 22L, 24L);
+
+        Header[] headers = {new RecordHeader("headerKey1", "headerValue1".getBytes()),
+            new RecordHeader("headerKey2", "headerValue2".getBytes()),
+            new RecordHeader("headerKey3", "headerValue3".getBytes())};
+
+        List<SimpleRecord> records = asList(
+            new SimpleRecord(1L, "k1".getBytes(), "hello".getBytes()),
+            new SimpleRecord(2L, "k2".getBytes(), "goodbye".getBytes()),
+            new SimpleRecord(3L, "k3".getBytes(), "hello again".getBytes()),
+            new SimpleRecord(4L, "k4".getBytes(), "goodbye for now".getBytes()),
+            new SimpleRecord(5L, "k5".getBytes(), "hello again".getBytes()),
+            new SimpleRecord(6L, "k6".getBytes(), "I sense indecision".getBytes()),
+            new SimpleRecord(7L, "k7".getBytes(), "what now".getBytes()),
+            new SimpleRecord(8L, "k8".getBytes(), "running out".getBytes(), headers),
+            new SimpleRecord(9L, "k9".getBytes(), "ok, almost done".getBytes()),
+            new SimpleRecord(10L, "k10".getBytes(), "finally".getBytes(), headers));
+        assertEquals(offsets.size(), records.size(), "incorrect test setup");
+
+        ByteBuffer buffer = ByteBuffer.allocate(1024);
+        MemoryRecordsBuilder builder = MemoryRecords.builder(buffer, RecordBatch.MAGIC_VALUE_V0, compressionType,
+            TimestampType.CREATE_TIME, 0L);
+        for (int i = 0; i < 3; i++)
+            builder.appendWithOffset(offsets.get(i), records.get(i));
+        builder.close();
+
+        builder = MemoryRecords.builder(buffer, RecordBatch.MAGIC_VALUE_V1, compressionType, TimestampType.CREATE_TIME,
+            0L);
+        for (int i = 3; i < 6; i++)
+            builder.appendWithOffset(offsets.get(i), records.get(i));
+        builder.close();
+
+        builder = MemoryRecords.builder(buffer, RecordBatch.MAGIC_VALUE_V2, compressionType, TimestampType.CREATE_TIME, 0L);
+        for (int i = 6; i < 10; i++)
+            builder.appendWithOffset(offsets.get(i), records.get(i));
+        builder.close();
+
+        buffer.flip();
+
+        try (FileRecords fileRecords = FileRecords.open(tempFile())) {
+            fileRecords.append(MemoryRecords.readableRecords(buffer));
+            fileRecords.flush();
+            downConvertAndVerifyRecords(records, offsets, fileRecords, compressionType, toMagic, 0L, time);
+
+            if (toMagic <= RecordBatch.MAGIC_VALUE_V1 && compressionType == CompressionType.NONE) {
+                long firstOffset;
+                if (toMagic == RecordBatch.MAGIC_VALUE_V0)
+                    firstOffset = 11L; // v1 record
+                else
+                    firstOffset = 17; // v2 record
+                List<Long> filteredOffsets = new ArrayList<>(offsets);
+                List<SimpleRecord> filteredRecords = new ArrayList<>(records);
+                int index = filteredOffsets.indexOf(firstOffset) - 1;
+                filteredRecords.remove(index);
+                filteredOffsets.remove(index);
+                downConvertAndVerifyRecords(filteredRecords, filteredOffsets, fileRecords, compressionType, toMagic, firstOffset, time);
+            } else {
+                // firstOffset doesn't have any effect in this case
+                downConvertAndVerifyRecords(records, offsets, fileRecords, compressionType, toMagic, 10L, time);
+            }
+        }
+    }
+
+    private void downConvertAndVerifyRecords(List<SimpleRecord> initialRecords,
+                                             List<Long> initialOffsets,
+                                             FileRecords fileRecords,
+                                             CompressionType compressionType,
+                                             byte toMagic,
+                                             long firstOffset,
+                                             Time time) {
+        long minBatchSize = Long.MAX_VALUE;
+        long maxBatchSize = Long.MIN_VALUE;
+        for (RecordBatch batch : fileRecords.batches()) {
+            minBatchSize = Math.min(minBatchSize, batch.sizeInBytes());
+            maxBatchSize = Math.max(maxBatchSize, batch.sizeInBytes());
+        }
+
+        // Test the normal down-conversion path
+        List<Records> convertedRecords = new ArrayList<>();
+        convertedRecords.add(fileRecords.downConvert(toMagic, firstOffset, time).records());
+        verifyConvertedRecords(initialRecords, initialOffsets, convertedRecords, compressionType, toMagic);
+        convertedRecords.clear();
+
+        // Test the lazy down-conversion path
+        List<Long> maximumReadSize = asList(16L * 1024L,
+            (long) fileRecords.sizeInBytes(),
+            (long) fileRecords.sizeInBytes() - 1,
+            (long) fileRecords.sizeInBytes() / 4,
+            maxBatchSize + 1,
+            1L);
+        for (long readSize : maximumReadSize) {
+            TopicPartition tp = new TopicPartition("topic-1", 0);
+            LazyDownConversionRecords lazyRecords = new LazyDownConversionRecords(tp, fileRecords, toMagic, firstOffset, Time.SYSTEM);
+            Iterator<ConvertedRecords<?>> it = lazyRecords.iterator(readSize);
+            while (it.hasNext())
+                convertedRecords.add(it.next().records());
+            verifyConvertedRecords(initialRecords, initialOffsets, convertedRecords, compressionType, toMagic);
+            convertedRecords.clear();
+        }
+    }
+
+    private void verifyConvertedRecords(List<SimpleRecord> initialRecords,
+                                        List<Long> initialOffsets,
+                                        List<Records> convertedRecordsList,
+                                        CompressionType compressionType,
+                                        byte magicByte) {
+        int i = 0;
+
+        for (Records convertedRecords : convertedRecordsList) {
+            for (RecordBatch batch : convertedRecords.batches()) {
+                assertTrue(batch.magic() <= magicByte, "Magic byte should be lower than or equal to " + magicByte);
+                if (batch.magic() == RecordBatch.MAGIC_VALUE_V0)
+                    assertEquals(TimestampType.NO_TIMESTAMP_TYPE, batch.timestampType());
+                else
+                    assertEquals(TimestampType.CREATE_TIME, batch.timestampType());
+                assertEquals(compressionType, batch.compressionType(), "Compression type should not be affected by conversion");
+                for (Record record : batch) {
+                    assertTrue(record.hasMagic(batch.magic()), "Inner record should have magic " + magicByte);
+                    assertEquals(initialOffsets.get(i).longValue(), record.offset(), "Offset should not change");
+                    assertEquals(utf8(initialRecords.get(i).key()), utf8(record.key()), "Key should not change");
+                    assertEquals(utf8(initialRecords.get(i).value()), utf8(record.value()), "Value should not change");
+                    assertFalse(record.hasTimestampType(TimestampType.LOG_APPEND_TIME));
+                    if (batch.magic() == RecordBatch.MAGIC_VALUE_V0) {
+                        assertEquals(RecordBatch.NO_TIMESTAMP, record.timestamp());
+                        assertFalse(record.hasTimestampType(TimestampType.CREATE_TIME));
+                        assertTrue(record.hasTimestampType(TimestampType.NO_TIMESTAMP_TYPE));
+                    } else if (batch.magic() == RecordBatch.MAGIC_VALUE_V1) {
+                        assertEquals(initialRecords.get(i).timestamp(), record.timestamp(), "Timestamp should not change");
+                        assertTrue(record.hasTimestampType(TimestampType.CREATE_TIME));
+                        assertFalse(record.hasTimestampType(TimestampType.NO_TIMESTAMP_TYPE));
+                    } else {
+                        assertEquals(initialRecords.get(i).timestamp(), record.timestamp(), "Timestamp should not change");
+                        assertFalse(record.hasTimestampType(TimestampType.CREATE_TIME));
+                        assertFalse(record.hasTimestampType(TimestampType.NO_TIMESTAMP_TYPE));
+                        assertArrayEquals(initialRecords.get(i).headers(), record.headers(), "Headers should not change");
+                    }
+                    i += 1;
+                }
+            }
+        }
+        assertEquals(initialOffsets.size(), i);
+    }
+
+    private static List<RecordBatch> batches(Records buffer) {
+        return TestUtils.toList(buffer.batches());
+    }
+
+    private FileRecords createFileRecords(byte[][] values) throws IOException {
+        FileRecords fileRecords = FileRecords.open(tempFile());
+        append(fileRecords, values);
+        return fileRecords;
+    }
+
+    private void append(FileRecords fileRecords, byte[][] values) throws IOException {
+        long offset = 0L;
+        for (byte[] value : values) {
+            ByteBuffer buffer = ByteBuffer.allocate(128);
+            MemoryRecordsBuilder builder = MemoryRecords.builder(buffer, RecordBatch.CURRENT_MAGIC_VALUE,
+                CompressionType.NONE, TimestampType.CREATE_TIME, offset);
+            builder.appendWithOffset(offset++,
System.currentTimeMillis(), null, value); + fileRecords.append(builder.build()); + } + fileRecords.flush(); + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/record/LazyDownConversionRecordsTest.java b/clients/src/test/java/org/apache/kafka/common/record/LazyDownConversionRecordsTest.java new file mode 100644 index 0000000..ebc2982 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/record/LazyDownConversionRecordsTest.java @@ -0,0 +1,243 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.record; + +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.header.Header; +import org.apache.kafka.common.header.internals.RecordHeader; +import org.apache.kafka.common.network.TransferableChannel; +import org.apache.kafka.common.utils.Time; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; + +import java.io.File; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.channels.FileChannel; +import java.nio.file.StandardOpenOption; +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; + +import static java.util.Arrays.asList; +import static org.apache.kafka.common.utils.Utils.utf8; +import static org.apache.kafka.test.TestUtils.tempFile; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class LazyDownConversionRecordsTest { + + /** + * Test the lazy down-conversion path in the presence of commit markers. When converting to V0 or V1, these batches + * are dropped. If there happen to be no more batches left to convert, we must get an overflow message batch after + * conversion. 
+ */ + @Test + public void testConversionOfCommitMarker() throws IOException { + MemoryRecords recordsToConvert = MemoryRecords.withEndTransactionMarker(0, Time.SYSTEM.milliseconds(), RecordBatch.NO_PARTITION_LEADER_EPOCH, + 1, (short) 1, new EndTransactionMarker(ControlRecordType.COMMIT, 0)); + MemoryRecords convertedRecords = convertRecords(recordsToConvert, (byte) 1, recordsToConvert.sizeInBytes()); + ByteBuffer buffer = convertedRecords.buffer(); + + // read the offset and the batch length + buffer.getLong(); + int sizeOfConvertedRecords = buffer.getInt(); + + // assert we got an overflow message batch + assertTrue(sizeOfConvertedRecords > buffer.limit()); + assertFalse(convertedRecords.batchIterator().hasNext()); + } + + private static Collection parameters() { + List arguments = new ArrayList<>(); + for (byte toMagic = RecordBatch.MAGIC_VALUE_V0; toMagic <= RecordBatch.CURRENT_MAGIC_VALUE; toMagic++) { + for (boolean overflow : asList(true, false)) { + arguments.add(Arguments.of(CompressionType.NONE, toMagic, overflow)); + arguments.add(Arguments.of(CompressionType.GZIP, toMagic, overflow)); + } + } + return arguments; + } + + /** + * Test the lazy down-conversion path. + * + * If `overflow` is true, the number of bytes we want to convert is much larger + * than the number of bytes we get after conversion. This causes overflow message batch(es) to be appended towards the + * end of the converted output. + */ + @ParameterizedTest(name = "compressionType={0}, toMagic={1}, overflow={2}") + @MethodSource("parameters") + public void testConversion(CompressionType compressionType, byte toMagic, boolean overflow) throws IOException { + doTestConversion(compressionType, toMagic, overflow); + } + + private void doTestConversion(CompressionType compressionType, byte toMagic, boolean testConversionOverflow) throws IOException { + List offsets = asList(0L, 2L, 3L, 9L, 11L, 15L, 16L, 17L, 22L, 24L); + + Header[] headers = {new RecordHeader("headerKey1", "headerValue1".getBytes()), + new RecordHeader("headerKey2", "headerValue2".getBytes()), + new RecordHeader("headerKey3", "headerValue3".getBytes())}; + + List records = asList( + new SimpleRecord(1L, "k1".getBytes(), "hello".getBytes()), + new SimpleRecord(2L, "k2".getBytes(), "goodbye".getBytes()), + new SimpleRecord(3L, "k3".getBytes(), "hello again".getBytes()), + new SimpleRecord(4L, "k4".getBytes(), "goodbye for now".getBytes()), + new SimpleRecord(5L, "k5".getBytes(), "hello again".getBytes()), + new SimpleRecord(6L, "k6".getBytes(), "I sense indecision".getBytes()), + new SimpleRecord(7L, "k7".getBytes(), "what now".getBytes()), + new SimpleRecord(8L, "k8".getBytes(), "running out".getBytes(), headers), + new SimpleRecord(9L, "k9".getBytes(), "ok, almost done".getBytes()), + new SimpleRecord(10L, "k10".getBytes(), "finally".getBytes(), headers)); + assertEquals(offsets.size(), records.size(), "incorrect test setup"); + + ByteBuffer buffer = ByteBuffer.allocate(1024); + MemoryRecordsBuilder builder = MemoryRecords.builder(buffer, RecordBatch.CURRENT_MAGIC_VALUE, compressionType, + TimestampType.CREATE_TIME, 0L); + for (int i = 0; i < 3; i++) + builder.appendWithOffset(offsets.get(i), records.get(i)); + builder.close(); + + builder = MemoryRecords.builder(buffer, RecordBatch.CURRENT_MAGIC_VALUE, compressionType, TimestampType.CREATE_TIME, + 0L); + for (int i = 3; i < 6; i++) + builder.appendWithOffset(offsets.get(i), records.get(i)); + builder.close(); + + builder = MemoryRecords.builder(buffer, RecordBatch.CURRENT_MAGIC_VALUE, 
compressionType, TimestampType.CREATE_TIME, + 0L); + for (int i = 6; i < 10; i++) + builder.appendWithOffset(offsets.get(i), records.get(i)); + builder.close(); + buffer.flip(); + + MemoryRecords recordsToConvert = MemoryRecords.readableRecords(buffer); + int numBytesToConvert = recordsToConvert.sizeInBytes(); + if (testConversionOverflow) + numBytesToConvert *= 2; + + MemoryRecords convertedRecords = convertRecords(recordsToConvert, toMagic, numBytesToConvert); + verifyDownConvertedRecords(records, offsets, convertedRecords, compressionType, toMagic); + } + + private static MemoryRecords convertRecords(MemoryRecords recordsToConvert, byte toMagic, int bytesToConvert) throws IOException { + try (FileRecords inputRecords = FileRecords.open(tempFile())) { + inputRecords.append(recordsToConvert); + inputRecords.flush(); + + LazyDownConversionRecords lazyRecords = new LazyDownConversionRecords(new TopicPartition("test", 1), + inputRecords, toMagic, 0L, Time.SYSTEM); + LazyDownConversionRecordsSend lazySend = lazyRecords.toSend(); + File outputFile = tempFile(); + ByteBuffer convertedRecordsBuffer; + try (TransferableChannel channel = toTransferableChannel(FileChannel.open(outputFile.toPath(), StandardOpenOption.READ, StandardOpenOption.WRITE))) { + int written = 0; + while (written < bytesToConvert) written += lazySend.writeTo(channel, written, bytesToConvert - written); + try (FileRecords convertedRecords = FileRecords.open(outputFile, true, written, false)) { + convertedRecordsBuffer = ByteBuffer.allocate(convertedRecords.sizeInBytes()); + convertedRecords.readInto(convertedRecordsBuffer, 0); + } + } + return MemoryRecords.readableRecords(convertedRecordsBuffer); + } + } + + private static TransferableChannel toTransferableChannel(FileChannel channel) { + return new TransferableChannel() { + + @Override + public boolean hasPendingWrites() { + return false; + } + + @Override + public long transferFrom(FileChannel fileChannel, long position, long count) throws IOException { + return fileChannel.transferTo(position, count, channel); + } + + @Override + public boolean isOpen() { + return channel.isOpen(); + } + + @Override + public void close() throws IOException { + channel.close(); + } + + @Override + public int write(ByteBuffer src) throws IOException { + return channel.write(src); + } + + @Override + public long write(ByteBuffer[] srcs, int offset, int length) throws IOException { + return channel.write(srcs, offset, length); + } + + @Override + public long write(ByteBuffer[] srcs) throws IOException { + return channel.write(srcs); + } + }; + } + + private static void verifyDownConvertedRecords(List initialRecords, + List initialOffsets, + MemoryRecords downConvertedRecords, + CompressionType compressionType, + byte toMagic) { + int i = 0; + for (RecordBatch batch : downConvertedRecords.batches()) { + assertTrue(batch.magic() <= toMagic, "Magic byte should be lower than or equal to " + toMagic); + if (batch.magic() == RecordBatch.MAGIC_VALUE_V0) + assertEquals(TimestampType.NO_TIMESTAMP_TYPE, batch.timestampType()); + else + assertEquals(TimestampType.CREATE_TIME, batch.timestampType()); + assertEquals(compressionType, batch.compressionType(), "Compression type should not be affected by conversion"); + for (Record record : batch) { + assertTrue(record.hasMagic(batch.magic()), "Inner record should have magic " + toMagic); + assertEquals(initialOffsets.get(i).longValue(), record.offset(), "Offset should not change"); + assertEquals(utf8(initialRecords.get(i).key()), utf8(record.key()), "Key 
should not change"); + assertEquals(utf8(initialRecords.get(i).value()), utf8(record.value()), "Value should not change"); + assertFalse(record.hasTimestampType(TimestampType.LOG_APPEND_TIME)); + if (batch.magic() == RecordBatch.MAGIC_VALUE_V0) { + assertEquals(RecordBatch.NO_TIMESTAMP, record.timestamp()); + assertFalse(record.hasTimestampType(TimestampType.CREATE_TIME)); + assertTrue(record.hasTimestampType(TimestampType.NO_TIMESTAMP_TYPE)); + } else if (batch.magic() == RecordBatch.MAGIC_VALUE_V1) { + assertEquals(initialRecords.get(i).timestamp(), record.timestamp(), "Timestamp should not change"); + assertTrue(record.hasTimestampType(TimestampType.CREATE_TIME)); + assertFalse(record.hasTimestampType(TimestampType.NO_TIMESTAMP_TYPE)); + } else { + assertEquals(initialRecords.get(i).timestamp(), record.timestamp(), "Timestamp should not change"); + assertFalse(record.hasTimestampType(TimestampType.CREATE_TIME)); + assertFalse(record.hasTimestampType(TimestampType.NO_TIMESTAMP_TYPE)); + assertArrayEquals(initialRecords.get(i).headers(), record.headers(), "Headers should not change"); + } + i += 1; + } + } + assertEquals(initialOffsets.size(), i); + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/record/LegacyRecordTest.java b/clients/src/test/java/org/apache/kafka/common/record/LegacyRecordTest.java new file mode 100644 index 0000000..ffa49e6 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/record/LegacyRecordTest.java @@ -0,0 +1,142 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.record; + +import org.apache.kafka.common.errors.CorruptRecordException; +import org.junit.jupiter.api.extension.ExtensionContext; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.ArgumentsProvider; +import org.junit.jupiter.params.provider.ArgumentsSource; + +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.stream.Stream; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class LegacyRecordTest { + + private static class Args { + final byte magic; + final long timestamp; + final ByteBuffer key; + final ByteBuffer value; + final CompressionType compression; + final TimestampType timestampType; + final LegacyRecord record; + + public Args(byte magic, long timestamp, byte[] key, byte[] value, CompressionType compression) { + this.magic = magic; + this.timestamp = timestamp; + this.timestampType = TimestampType.CREATE_TIME; + this.key = key == null ? null : ByteBuffer.wrap(key); + this.value = value == null ? null : ByteBuffer.wrap(value); + this.compression = compression; + this.record = LegacyRecord.create(magic, timestamp, key, value, compression, timestampType); + } + + @Override + public String toString() { + return "magic=" + magic + + ", compression=" + compression + + ", timestamp=" + timestamp; + } + } + + private static class LegacyRecordArgumentsProvider implements ArgumentsProvider { + @Override + public Stream provideArguments(ExtensionContext context) { + byte[] payload = new byte[1000]; + Arrays.fill(payload, (byte) 1); + List arguments = new ArrayList<>(); + for (byte magic : Arrays.asList(RecordBatch.MAGIC_VALUE_V0, RecordBatch.MAGIC_VALUE_V1)) + for (long timestamp : Arrays.asList(RecordBatch.NO_TIMESTAMP, 0L, 1L)) + for (byte[] key : Arrays.asList(null, "".getBytes(), "key".getBytes(), payload)) + for (byte[] value : Arrays.asList(null, "".getBytes(), "value".getBytes(), payload)) + for (CompressionType compression : CompressionType.values()) + arguments.add(Arguments.of(new Args(magic, timestamp, key, value, compression))); + return arguments.stream(); + } + } + + @ParameterizedTest + @ArgumentsSource(LegacyRecordArgumentsProvider.class) + public void testFields(Args args) { + LegacyRecord record = args.record; + ByteBuffer key = args.key; + assertEquals(args.compression, record.compressionType()); + assertEquals(key != null, record.hasKey()); + assertEquals(key, record.key()); + if (key != null) + assertEquals(key.limit(), record.keySize()); + assertEquals(args.magic, record.magic()); + assertEquals(args.value, record.value()); + if (args.value != null) + assertEquals(args.value.limit(), record.valueSize()); + if (args.magic > 0) { + assertEquals(args.timestamp, record.timestamp()); + assertEquals(args.timestampType, record.timestampType()); + } else { + assertEquals(RecordBatch.NO_TIMESTAMP, record.timestamp()); + assertEquals(TimestampType.NO_TIMESTAMP_TYPE, record.timestampType()); + } + } + + @ParameterizedTest + @ArgumentsSource(LegacyRecordArgumentsProvider.class) + public void testChecksum(Args args) { + LegacyRecord record = args.record; + assertEquals(record.checksum(), record.computeChecksum()); + + byte attributes = 
LegacyRecord.computeAttributes(args.magic, args.compression, TimestampType.CREATE_TIME); + assertEquals(record.checksum(), LegacyRecord.computeChecksum( + args.magic, + attributes, + args.timestamp, + args.key == null ? null : args.key.array(), + args.value == null ? null : args.value.array() + )); + assertTrue(record.isValid()); + for (int i = LegacyRecord.CRC_OFFSET + LegacyRecord.CRC_LENGTH; i < record.sizeInBytes(); i++) { + LegacyRecord copy = copyOf(record); + copy.buffer().put(i, (byte) 69); + assertFalse(copy.isValid()); + assertThrows(CorruptRecordException.class, copy::ensureValid); + } + } + + private LegacyRecord copyOf(LegacyRecord record) { + ByteBuffer buffer = ByteBuffer.allocate(record.sizeInBytes()); + record.buffer().put(buffer); + buffer.rewind(); + record.buffer().rewind(); + return new LegacyRecord(buffer); + } + + @ParameterizedTest + @ArgumentsSource(LegacyRecordArgumentsProvider.class) + public void testEquality(Args args) { + assertEquals(args.record, copyOf(args.record)); + } + +} diff --git a/clients/src/test/java/org/apache/kafka/common/record/MemoryRecordsBuilderTest.java b/clients/src/test/java/org/apache/kafka/common/record/MemoryRecordsBuilderTest.java new file mode 100644 index 0000000..4f3f03c --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/record/MemoryRecordsBuilderTest.java @@ -0,0 +1,834 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.record; + +import org.apache.kafka.common.errors.UnsupportedCompressionTypeException; +import org.apache.kafka.common.message.LeaderChangeMessage; +import org.apache.kafka.common.message.LeaderChangeMessage.Voter; +import org.apache.kafka.common.utils.BufferSupplier; +import org.apache.kafka.common.utils.ByteBufferOutputStream; +import org.apache.kafka.common.utils.CloseableIterator; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.test.TestUtils; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtensionContext; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.ArgumentsProvider; +import org.junit.jupiter.params.provider.ArgumentsSource; +import org.junit.jupiter.params.provider.EnumSource; + +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.OptionalLong; +import java.util.Random; +import java.util.function.BiFunction; +import java.util.function.Supplier; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import static java.util.Arrays.asList; +import static org.apache.kafka.common.record.RecordBatch.MAGIC_VALUE_V0; +import static org.apache.kafka.common.record.RecordBatch.MAGIC_VALUE_V1; +import static org.apache.kafka.common.record.RecordBatch.MAGIC_VALUE_V2; +import static org.apache.kafka.common.utils.Utils.utf8; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class MemoryRecordsBuilderTest { + + private static class Args { + final int bufferOffset; + final CompressionType compressionType; + final byte magic; + + public Args(int bufferOffset, CompressionType compressionType, byte magic) { + this.bufferOffset = bufferOffset; + this.compressionType = compressionType; + this.magic = magic; + } + + @Override + public String toString() { + return "magic=" + magic + + ", bufferOffset=" + bufferOffset + + ", compressionType=" + compressionType; + } + } + + private static class MemoryRecordsBuilderArgumentsProvider implements ArgumentsProvider { + @Override + public Stream provideArguments(ExtensionContext context) { + List values = new ArrayList<>(); + for (int bufferOffset : Arrays.asList(0, 15)) + for (CompressionType type: CompressionType.values()) { + List magics = type == CompressionType.ZSTD + ? 
Collections.singletonList(RecordBatch.MAGIC_VALUE_V2) + : asList(RecordBatch.MAGIC_VALUE_V0, MAGIC_VALUE_V1, RecordBatch.MAGIC_VALUE_V2); + for (byte magic : magics) + values.add(Arguments.of(new Args(bufferOffset, type, magic))); + } + return values.stream(); + } + } + + private final Time time = Time.SYSTEM; + + @Test + public void testUnsupportedCompress() { + BiFunction builderBiFunction = (magic, compressionType) -> + new MemoryRecordsBuilder(ByteBuffer.allocate(128), magic, compressionType, TimestampType.CREATE_TIME, 0L, 0L, + RecordBatch.NO_PRODUCER_ID, RecordBatch.NO_PRODUCER_EPOCH, RecordBatch.NO_SEQUENCE, + false, false, RecordBatch.NO_PARTITION_LEADER_EPOCH, 128); + + Arrays.asList(MAGIC_VALUE_V0, MAGIC_VALUE_V1).forEach(magic -> { + Exception e = assertThrows(IllegalArgumentException.class, () -> builderBiFunction.apply(magic, CompressionType.ZSTD)); + assertEquals(e.getMessage(), "ZStandard compression is not supported for magic " + magic); + }); + } + + @ParameterizedTest + @ArgumentsSource(MemoryRecordsBuilderArgumentsProvider.class) + public void testWriteEmptyRecordSet(Args args) { + byte magic = args.magic; + ByteBuffer buffer = allocateBuffer(128, args); + + MemoryRecords records = new MemoryRecordsBuilder(buffer, magic, + args.compressionType, TimestampType.CREATE_TIME, 0L, 0L, + RecordBatch.NO_PRODUCER_ID, RecordBatch.NO_PRODUCER_EPOCH, RecordBatch.NO_SEQUENCE, + false, false, RecordBatch.NO_PARTITION_LEADER_EPOCH, buffer.capacity()).build(); + + assertEquals(0, records.sizeInBytes()); + assertEquals(args.bufferOffset, buffer.position()); + } + + @ParameterizedTest + @ArgumentsSource(MemoryRecordsBuilderArgumentsProvider.class) + public void testWriteTransactionalRecordSet(Args args) { + ByteBuffer buffer = allocateBuffer(128, args); + long pid = 9809; + short epoch = 15; + int sequence = 2342; + + Supplier supplier = () -> new MemoryRecordsBuilder(buffer, args.magic, args.compressionType, + TimestampType.CREATE_TIME, 0L, 0L, pid, epoch, sequence, true, false, + RecordBatch.NO_PARTITION_LEADER_EPOCH, buffer.capacity()); + + if (args.magic < MAGIC_VALUE_V2) { + assertThrows(IllegalArgumentException.class, supplier::get); + } else { + MemoryRecordsBuilder builder = supplier.get(); + builder.append(System.currentTimeMillis(), "foo".getBytes(), "bar".getBytes()); + MemoryRecords records = builder.build(); + + List batches = Utils.toList(records.batches().iterator()); + assertEquals(1, batches.size()); + assertTrue(batches.get(0).isTransactional()); + } + } + + @ParameterizedTest + @ArgumentsSource(MemoryRecordsBuilderArgumentsProvider.class) + public void testWriteTransactionalWithInvalidPID(Args args) { + ByteBuffer buffer = allocateBuffer(128, args); + long pid = RecordBatch.NO_PRODUCER_ID; + short epoch = 15; + int sequence = 2342; + + Supplier supplier = () -> new MemoryRecordsBuilder(buffer, args.magic, args.compressionType, TimestampType.CREATE_TIME, + 0L, 0L, pid, epoch, sequence, true, false, RecordBatch.NO_PARTITION_LEADER_EPOCH, buffer.capacity()); + if (args.magic < MAGIC_VALUE_V2) { + assertThrows(IllegalArgumentException.class, supplier::get); + } else { + MemoryRecordsBuilder builder = supplier.get(); + assertThrows(IllegalArgumentException.class, builder::close); + } + } + + @ParameterizedTest + @ArgumentsSource(MemoryRecordsBuilderArgumentsProvider.class) + public void testWriteIdempotentWithInvalidEpoch(Args args) { + ByteBuffer buffer = allocateBuffer(128, args); + long pid = 9809; + short epoch = RecordBatch.NO_PRODUCER_EPOCH; + int sequence = 2342; + + 
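+        // Sketch of the expectation encoded below: idempotent/transactional writes require magic v2, so builders for
+        // older magic versions are rejected in the constructor, while a v2 builder accepts construction and only
+        // validates the missing producer epoch when it is closed.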
Supplier supplier = () -> new MemoryRecordsBuilder(buffer, args.magic, args.compressionType, TimestampType.CREATE_TIME, + 0L, 0L, pid, epoch, sequence, true, false, RecordBatch.NO_PARTITION_LEADER_EPOCH, buffer.capacity()); + + if (args.magic < MAGIC_VALUE_V2) { + assertThrows(IllegalArgumentException.class, supplier::get); + } else { + MemoryRecordsBuilder builder = supplier.get(); + assertThrows(IllegalArgumentException.class, builder::close); + } + } + + @ParameterizedTest + @ArgumentsSource(MemoryRecordsBuilderArgumentsProvider.class) + public void testWriteIdempotentWithInvalidBaseSequence(Args args) { + ByteBuffer buffer = allocateBuffer(128, args); + long pid = 9809; + short epoch = 15; + int sequence = RecordBatch.NO_SEQUENCE; + + Supplier supplier = () -> new MemoryRecordsBuilder(buffer, args.magic, args.compressionType, TimestampType.CREATE_TIME, + 0L, 0L, pid, epoch, sequence, true, false, RecordBatch.NO_PARTITION_LEADER_EPOCH, buffer.capacity()); + + if (args.magic < MAGIC_VALUE_V2) { + assertThrows(IllegalArgumentException.class, supplier::get); + } else { + MemoryRecordsBuilder builder = supplier.get(); + assertThrows(IllegalArgumentException.class, builder::close); + } + } + + @ParameterizedTest + @ArgumentsSource(MemoryRecordsBuilderArgumentsProvider.class) + public void testWriteEndTxnMarkerNonTransactionalBatch(Args args) { + ByteBuffer buffer = allocateBuffer(128, args); + long pid = 9809; + short epoch = 15; + int sequence = RecordBatch.NO_SEQUENCE; + + Supplier supplier = () -> new MemoryRecordsBuilder(buffer, args.magic, args.compressionType, + TimestampType.CREATE_TIME, 0L, 0L, pid, epoch, sequence, false, true, + RecordBatch.NO_PARTITION_LEADER_EPOCH, buffer.capacity()); + + if (args.magic < MAGIC_VALUE_V2) { + assertThrows(IllegalArgumentException.class, supplier::get); + } else { + MemoryRecordsBuilder builder = supplier.get(); + assertThrows(IllegalArgumentException.class, () -> builder.appendEndTxnMarker(RecordBatch.NO_TIMESTAMP, + new EndTransactionMarker(ControlRecordType.ABORT, 0))); + } + } + + @ParameterizedTest + @ArgumentsSource(MemoryRecordsBuilderArgumentsProvider.class) + public void testWriteEndTxnMarkerNonControlBatch(Args args) { + ByteBuffer buffer = allocateBuffer(128, args); + long pid = 9809; + short epoch = 15; + int sequence = RecordBatch.NO_SEQUENCE; + + Supplier supplier = () -> new MemoryRecordsBuilder(buffer, args.magic, args.compressionType, TimestampType.CREATE_TIME, + 0L, 0L, pid, epoch, sequence, true, false, RecordBatch.NO_PARTITION_LEADER_EPOCH, buffer.capacity()); + + if (args.magic < MAGIC_VALUE_V2) { + assertThrows(IllegalArgumentException.class, supplier::get); + } else { + MemoryRecordsBuilder builder = supplier.get(); + assertThrows(IllegalArgumentException.class, () -> builder.appendEndTxnMarker(RecordBatch.NO_TIMESTAMP, + new EndTransactionMarker(ControlRecordType.ABORT, 0))); + } + } + + @ParameterizedTest + @ArgumentsSource(MemoryRecordsBuilderArgumentsProvider.class) + public void testWriteLeaderChangeControlBatchWithoutLeaderEpoch(Args args) { + ByteBuffer buffer = allocateBuffer(128, args); + Supplier supplier = () -> new MemoryRecordsBuilder(buffer, args.magic, args.compressionType, + TimestampType.CREATE_TIME, 0L, 0L, + RecordBatch.NO_PRODUCER_ID, RecordBatch.NO_PRODUCER_EPOCH, RecordBatch.NO_SEQUENCE, + false, true, RecordBatch.NO_PARTITION_LEADER_EPOCH, buffer.capacity()); + + if (args.magic < MAGIC_VALUE_V2) { + assertThrows(IllegalArgumentException.class, supplier::get); + } else { + final int leaderId = 1; + 
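+            // The builder was created with NO_PARTITION_LEADER_EPOCH, and (as the append below asserts) a
+            // leader-change control record is only accepted when a real partition leader epoch has been supplied.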
MemoryRecordsBuilder builder = supplier.get(); + assertThrows(IllegalArgumentException.class, () -> builder.appendLeaderChangeMessage(RecordBatch.NO_TIMESTAMP, + new LeaderChangeMessage().setLeaderId(leaderId).setVoters(Collections.emptyList()))); + } + } + + @ParameterizedTest + @ArgumentsSource(MemoryRecordsBuilderArgumentsProvider.class) + public void testWriteLeaderChangeControlBatch(Args args) { + ByteBuffer buffer = allocateBuffer(128, args); + final int leaderId = 1; + final int leaderEpoch = 5; + final List voters = Arrays.asList(2, 3); + + Supplier supplier = () -> new MemoryRecordsBuilder(buffer, args.magic, args.compressionType, + TimestampType.CREATE_TIME, 0L, 0L, RecordBatch.NO_PRODUCER_ID, RecordBatch.NO_PRODUCER_EPOCH, + RecordBatch.NO_SEQUENCE, false, true, leaderEpoch, buffer.capacity()); + + if (args.magic < MAGIC_VALUE_V2) { + assertThrows(IllegalArgumentException.class, supplier::get); + } else { + MemoryRecordsBuilder builder = supplier.get(); + builder.appendLeaderChangeMessage(RecordBatch.NO_TIMESTAMP, + new LeaderChangeMessage() + .setLeaderId(leaderId) + .setVoters(voters.stream().map( + voterId -> new Voter().setVoterId(voterId)).collect(Collectors.toList()))); + + MemoryRecords built = builder.build(); + List records = TestUtils.toList(built.records()); + assertEquals(1, records.size()); + LeaderChangeMessage leaderChangeMessage = ControlRecordUtils.deserializeLeaderChangeMessage(records.get(0)); + + assertEquals(leaderId, leaderChangeMessage.leaderId()); + assertEquals(voters, leaderChangeMessage.voters().stream().map(Voter::voterId).collect(Collectors.toList())); + } + } + + @ParameterizedTest + @ArgumentsSource(MemoryRecordsBuilderArgumentsProvider.class) + public void testLegacyCompressionRate(Args args) { + byte magic = args.magic; + ByteBuffer buffer = allocateBuffer(1024, args); + + Supplier supplier = () -> new LegacyRecord[]{ + LegacyRecord.create(magic, 0L, "a".getBytes(), "1".getBytes()), + LegacyRecord.create(magic, 1L, "b".getBytes(), "2".getBytes()), + LegacyRecord.create(magic, 2L, "c".getBytes(), "3".getBytes()), + }; + + if (magic >= MAGIC_VALUE_V2) { + assertThrows(IllegalArgumentException.class, supplier::get); + } else { + LegacyRecord[] records = supplier.get(); + + MemoryRecordsBuilder builder = new MemoryRecordsBuilder(buffer, magic, args.compressionType, + TimestampType.CREATE_TIME, 0L, 0L, RecordBatch.NO_PRODUCER_ID, RecordBatch.NO_PRODUCER_EPOCH, RecordBatch.NO_SEQUENCE, + false, false, RecordBatch.NO_PARTITION_LEADER_EPOCH, buffer.capacity()); + + int uncompressedSize = 0; + for (LegacyRecord record : records) { + uncompressedSize += record.sizeInBytes() + Records.LOG_OVERHEAD; + builder.append(record); + } + + MemoryRecords built = builder.build(); + if (args.compressionType == CompressionType.NONE) { + assertEquals(1.0, builder.compressionRatio(), 0.00001); + } else { + int recordHeaad = magic == MAGIC_VALUE_V0 ? 
LegacyRecord.RECORD_OVERHEAD_V0 : LegacyRecord.RECORD_OVERHEAD_V1; + int compressedSize = built.sizeInBytes() - Records.LOG_OVERHEAD - recordHeaad; + double computedCompressionRate = (double) compressedSize / uncompressedSize; + assertEquals(computedCompressionRate, builder.compressionRatio(), 0.00001); + } + } + } + + @ParameterizedTest + @ArgumentsSource(MemoryRecordsBuilderArgumentsProvider.class) + public void testEstimatedSizeInBytes(Args args) { + ByteBuffer buffer = allocateBuffer(1024, args); + + MemoryRecordsBuilder builder = new MemoryRecordsBuilder(buffer, args.magic, args.compressionType, + TimestampType.CREATE_TIME, 0L, 0L, RecordBatch.NO_PRODUCER_ID, RecordBatch.NO_PRODUCER_EPOCH, RecordBatch.NO_SEQUENCE, + false, false, RecordBatch.NO_PARTITION_LEADER_EPOCH, buffer.capacity()); + + int previousEstimate = 0; + for (int i = 0; i < 10; i++) { + builder.append(new SimpleRecord(i, ("" + i).getBytes())); + int currentEstimate = builder.estimatedSizeInBytes(); + assertTrue(currentEstimate > previousEstimate); + previousEstimate = currentEstimate; + } + + int bytesWrittenBeforeClose = builder.estimatedSizeInBytes(); + MemoryRecords records = builder.build(); + assertEquals(records.sizeInBytes(), builder.estimatedSizeInBytes()); + if (args.compressionType == CompressionType.NONE) + assertEquals(records.sizeInBytes(), bytesWrittenBeforeClose); + } + + + @ParameterizedTest + @ArgumentsSource(MemoryRecordsBuilderArgumentsProvider.class) + public void buildUsingLogAppendTime(Args args) { + byte magic = args.magic; + ByteBuffer buffer = allocateBuffer(1024, args); + long logAppendTime = System.currentTimeMillis(); + + MemoryRecordsBuilder builder = new MemoryRecordsBuilder(buffer, magic, args.compressionType, + TimestampType.LOG_APPEND_TIME, 0L, logAppendTime, RecordBatch.NO_PRODUCER_ID, RecordBatch.NO_PRODUCER_EPOCH, + RecordBatch.NO_SEQUENCE, false, false, RecordBatch.NO_PARTITION_LEADER_EPOCH, buffer.capacity()); + builder.append(0L, "a".getBytes(), "1".getBytes()); + builder.append(0L, "b".getBytes(), "2".getBytes()); + builder.append(0L, "c".getBytes(), "3".getBytes()); + MemoryRecords records = builder.build(); + + MemoryRecordsBuilder.RecordsInfo info = builder.info(); + assertEquals(logAppendTime, info.maxTimestamp); + + if (args.compressionType == CompressionType.NONE && magic <= MAGIC_VALUE_V1) + assertEquals(0L, info.shallowOffsetOfMaxTimestamp); + else + assertEquals(2L, info.shallowOffsetOfMaxTimestamp); + + for (RecordBatch batch : records.batches()) { + if (magic == MAGIC_VALUE_V0) { + assertEquals(TimestampType.NO_TIMESTAMP_TYPE, batch.timestampType()); + } else { + assertEquals(TimestampType.LOG_APPEND_TIME, batch.timestampType()); + for (Record record : batch) + assertEquals(logAppendTime, record.timestamp()); + } + } + } + @ParameterizedTest + @ArgumentsSource(MemoryRecordsBuilderArgumentsProvider.class) + public void buildUsingCreateTime(Args args) { + byte magic = args.magic; + ByteBuffer buffer = allocateBuffer(1024, args); + + long logAppendTime = System.currentTimeMillis(); + MemoryRecordsBuilder builder = new MemoryRecordsBuilder(buffer, magic, args.compressionType, + TimestampType.CREATE_TIME, 0L, logAppendTime, RecordBatch.NO_PRODUCER_ID, RecordBatch.NO_PRODUCER_EPOCH, RecordBatch.NO_SEQUENCE, + false, false, RecordBatch.NO_PARTITION_LEADER_EPOCH, buffer.capacity()); + builder.append(0L, "a".getBytes(), "1".getBytes()); + builder.append(2L, "b".getBytes(), "2".getBytes()); + builder.append(1L, "c".getBytes(), "3".getBytes()); + MemoryRecords records = 
builder.build(); + + MemoryRecordsBuilder.RecordsInfo info = builder.info(); + if (magic == MAGIC_VALUE_V0) { + assertEquals(-1, info.maxTimestamp); + } else { + assertEquals(2L, info.maxTimestamp); + } + + if (args.compressionType == CompressionType.NONE && magic == MAGIC_VALUE_V1) + assertEquals(1L, info.shallowOffsetOfMaxTimestamp); + else + assertEquals(2L, info.shallowOffsetOfMaxTimestamp); + + int i = 0; + long[] expectedTimestamps = new long[] {0L, 2L, 1L}; + for (RecordBatch batch : records.batches()) { + if (magic == MAGIC_VALUE_V0) { + assertEquals(TimestampType.NO_TIMESTAMP_TYPE, batch.timestampType()); + } else { + assertEquals(TimestampType.CREATE_TIME, batch.timestampType()); + for (Record record : batch) + assertEquals(expectedTimestamps[i++], record.timestamp()); + } + } + } + + @ParameterizedTest + @ArgumentsSource(MemoryRecordsBuilderArgumentsProvider.class) + public void testAppendedChecksumConsistency(Args args) { + ByteBuffer buffer = ByteBuffer.allocate(512); + MemoryRecordsBuilder builder = new MemoryRecordsBuilder(buffer, args.magic, args.compressionType, + TimestampType.CREATE_TIME, 0L, LegacyRecord.NO_TIMESTAMP, RecordBatch.NO_PRODUCER_ID, + RecordBatch.NO_PRODUCER_EPOCH, RecordBatch.NO_SEQUENCE, false, false, + RecordBatch.NO_PARTITION_LEADER_EPOCH, buffer.capacity()); + builder.append(1L, "key".getBytes(), "value".getBytes()); + MemoryRecords memoryRecords = builder.build(); + List records = TestUtils.toList(memoryRecords.records()); + assertEquals(1, records.size()); + } + + @ParameterizedTest + @ArgumentsSource(MemoryRecordsBuilderArgumentsProvider.class) + public void testSmallWriteLimit(Args args) { + // with a small write limit, we always allow at least one record to be added + + byte[] key = "foo".getBytes(); + byte[] value = "bar".getBytes(); + int writeLimit = 0; + ByteBuffer buffer = ByteBuffer.allocate(512); + MemoryRecordsBuilder builder = new MemoryRecordsBuilder(buffer, args.magic, args.compressionType, + TimestampType.CREATE_TIME, 0L, LegacyRecord.NO_TIMESTAMP, RecordBatch.NO_PRODUCER_ID, RecordBatch.NO_PRODUCER_EPOCH, + RecordBatch.NO_SEQUENCE, false, false, RecordBatch.NO_PARTITION_LEADER_EPOCH, writeLimit); + + assertFalse(builder.isFull()); + assertTrue(builder.hasRoomFor(0L, key, value, Record.EMPTY_HEADERS)); + builder.append(0L, key, value); + + assertTrue(builder.isFull()); + assertFalse(builder.hasRoomFor(0L, key, value, Record.EMPTY_HEADERS)); + + MemoryRecords memRecords = builder.build(); + List records = TestUtils.toList(memRecords.records()); + assertEquals(1, records.size()); + + Record record = records.get(0); + assertEquals(ByteBuffer.wrap(key), record.key()); + assertEquals(ByteBuffer.wrap(value), record.value()); + } + + @ParameterizedTest + @ArgumentsSource(MemoryRecordsBuilderArgumentsProvider.class) + public void writePastLimit(Args args) { + byte magic = args.magic; + ByteBuffer buffer = allocateBuffer(64, args); + + long logAppendTime = System.currentTimeMillis(); + MemoryRecordsBuilder builder = new MemoryRecordsBuilder(buffer, magic, args.compressionType, + TimestampType.CREATE_TIME, 0L, logAppendTime, RecordBatch.NO_PRODUCER_ID, RecordBatch.NO_PRODUCER_EPOCH, RecordBatch.NO_SEQUENCE, + false, false, RecordBatch.NO_PARTITION_LEADER_EPOCH, buffer.capacity()); + builder.setEstimatedCompressionRatio(0.5f); + builder.append(0L, "a".getBytes(), "1".getBytes()); + builder.append(1L, "b".getBytes(), "2".getBytes()); + + assertFalse(builder.hasRoomFor(2L, "c".getBytes(), "3".getBytes(), Record.EMPTY_HEADERS)); + builder.append(2L, 
"c".getBytes(), "3".getBytes()); + MemoryRecords records = builder.build(); + + MemoryRecordsBuilder.RecordsInfo info = builder.info(); + if (magic == MAGIC_VALUE_V0) + assertEquals(-1, info.maxTimestamp); + else + assertEquals(2L, info.maxTimestamp); + + assertEquals(2L, info.shallowOffsetOfMaxTimestamp); + + long i = 0L; + for (RecordBatch batch : records.batches()) { + if (magic == MAGIC_VALUE_V0) { + assertEquals(TimestampType.NO_TIMESTAMP_TYPE, batch.timestampType()); + } else { + assertEquals(TimestampType.CREATE_TIME, batch.timestampType()); + for (Record record : batch) + assertEquals(i++, record.timestamp()); + } + } + } + + @ParameterizedTest + @ArgumentsSource(MemoryRecordsBuilderArgumentsProvider.class) + public void testAppendAtInvalidOffset(Args args) { + ByteBuffer buffer = allocateBuffer(1024, args); + + long logAppendTime = System.currentTimeMillis(); + MemoryRecordsBuilder builder = new MemoryRecordsBuilder(buffer, args.magic, args.compressionType, + TimestampType.CREATE_TIME, 0L, logAppendTime, RecordBatch.NO_PRODUCER_ID, RecordBatch.NO_PRODUCER_EPOCH, RecordBatch.NO_SEQUENCE, + false, false, RecordBatch.NO_PARTITION_LEADER_EPOCH, buffer.capacity()); + + builder.appendWithOffset(0L, System.currentTimeMillis(), "a".getBytes(), null); + + // offsets must increase monotonically + assertThrows(IllegalArgumentException.class, () -> builder.appendWithOffset(0L, System.currentTimeMillis(), + "b".getBytes(), null)); + } + + @ParameterizedTest + @EnumSource(CompressionType.class) + public void convertV2ToV1UsingMixedCreateAndLogAppendTime(CompressionType compressionType) { + ByteBuffer buffer = ByteBuffer.allocate(512); + MemoryRecordsBuilder builder = MemoryRecords.builder(buffer, RecordBatch.MAGIC_VALUE_V2, + compressionType, TimestampType.LOG_APPEND_TIME, 0L); + builder.append(10L, "1".getBytes(), "a".getBytes()); + builder.close(); + + int sizeExcludingTxnMarkers = buffer.position(); + + MemoryRecords.writeEndTransactionalMarker(buffer, 1L, System.currentTimeMillis(), 0, 15L, (short) 0, + new EndTransactionMarker(ControlRecordType.ABORT, 0)); + + int position = buffer.position(); + + builder = MemoryRecords.builder(buffer, RecordBatch.MAGIC_VALUE_V2, compressionType, + TimestampType.CREATE_TIME, 1L); + builder.append(12L, "2".getBytes(), "b".getBytes()); + builder.append(13L, "3".getBytes(), "c".getBytes()); + builder.close(); + + sizeExcludingTxnMarkers += buffer.position() - position; + + MemoryRecords.writeEndTransactionalMarker(buffer, 14L, System.currentTimeMillis(), 0, 1L, (short) 0, + new EndTransactionMarker(ControlRecordType.COMMIT, 0)); + + buffer.flip(); + + Supplier> convertedRecordsSupplier = () -> + MemoryRecords.readableRecords(buffer).downConvert(MAGIC_VALUE_V1, 0, time); + + if (compressionType != CompressionType.ZSTD) { + ConvertedRecords convertedRecords = convertedRecordsSupplier.get(); + MemoryRecords records = convertedRecords.records(); + + // Transactional markers are skipped when down converting to V1, so exclude them from size + verifyRecordsProcessingStats(compressionType, convertedRecords.recordConversionStats(), + 3, 3, records.sizeInBytes(), sizeExcludingTxnMarkers); + + List batches = Utils.toList(records.batches().iterator()); + if (compressionType != CompressionType.NONE) { + assertEquals(2, batches.size()); + assertEquals(TimestampType.LOG_APPEND_TIME, batches.get(0).timestampType()); + assertEquals(TimestampType.CREATE_TIME, batches.get(1).timestampType()); + } else { + assertEquals(3, batches.size()); + 
assertEquals(TimestampType.LOG_APPEND_TIME, batches.get(0).timestampType()); + assertEquals(TimestampType.CREATE_TIME, batches.get(1).timestampType()); + assertEquals(TimestampType.CREATE_TIME, batches.get(2).timestampType()); + } + + List logRecords = Utils.toList(records.records().iterator()); + assertEquals(3, logRecords.size()); + assertEquals(ByteBuffer.wrap("1".getBytes()), logRecords.get(0).key()); + assertEquals(ByteBuffer.wrap("2".getBytes()), logRecords.get(1).key()); + assertEquals(ByteBuffer.wrap("3".getBytes()), logRecords.get(2).key()); + } else { + Exception e = assertThrows(UnsupportedCompressionTypeException.class, convertedRecordsSupplier::get); + assertEquals("Down-conversion of zstandard-compressed batches is not supported", e.getMessage()); + } + } + + @ParameterizedTest + @EnumSource(CompressionType.class) + public void convertToV1WithMixedV0AndV2Data(CompressionType compressionType) { + ByteBuffer buffer = ByteBuffer.allocate(512); + + Supplier supplier = () -> MemoryRecords.builder(buffer, RecordBatch.MAGIC_VALUE_V0, + compressionType, TimestampType.NO_TIMESTAMP_TYPE, 0L); + + if (compressionType == CompressionType.ZSTD) { + assertThrows(IllegalArgumentException.class, supplier::get); + } else { + MemoryRecordsBuilder builder = supplier.get(); + builder.append(RecordBatch.NO_TIMESTAMP, "1".getBytes(), "a".getBytes()); + builder.close(); + + builder = MemoryRecords.builder(buffer, RecordBatch.MAGIC_VALUE_V2, compressionType, + TimestampType.CREATE_TIME, 1L); + builder.append(11L, "2".getBytes(), "b".getBytes()); + builder.append(12L, "3".getBytes(), "c".getBytes()); + builder.close(); + + buffer.flip(); + + ConvertedRecords convertedRecords = MemoryRecords.readableRecords(buffer) + .downConvert(MAGIC_VALUE_V1, 0, time); + MemoryRecords records = convertedRecords.records(); + verifyRecordsProcessingStats(compressionType, convertedRecords.recordConversionStats(), 3, 2, + records.sizeInBytes(), buffer.limit()); + + List batches = Utils.toList(records.batches().iterator()); + if (compressionType != CompressionType.NONE) { + assertEquals(2, batches.size()); + assertEquals(RecordBatch.MAGIC_VALUE_V0, batches.get(0).magic()); + assertEquals(0, batches.get(0).baseOffset()); + assertEquals(MAGIC_VALUE_V1, batches.get(1).magic()); + assertEquals(1, batches.get(1).baseOffset()); + } else { + assertEquals(3, batches.size()); + assertEquals(RecordBatch.MAGIC_VALUE_V0, batches.get(0).magic()); + assertEquals(0, batches.get(0).baseOffset()); + assertEquals(MAGIC_VALUE_V1, batches.get(1).magic()); + assertEquals(1, batches.get(1).baseOffset()); + assertEquals(MAGIC_VALUE_V1, batches.get(2).magic()); + assertEquals(2, batches.get(2).baseOffset()); + } + + List logRecords = Utils.toList(records.records().iterator()); + assertEquals("1", utf8(logRecords.get(0).key())); + assertEquals("2", utf8(logRecords.get(1).key())); + assertEquals("3", utf8(logRecords.get(2).key())); + + convertedRecords = MemoryRecords.readableRecords(buffer).downConvert(MAGIC_VALUE_V1, 2L, time); + records = convertedRecords.records(); + + batches = Utils.toList(records.batches().iterator()); + logRecords = Utils.toList(records.records().iterator()); + + if (compressionType != CompressionType.NONE) { + assertEquals(2, batches.size()); + assertEquals(RecordBatch.MAGIC_VALUE_V0, batches.get(0).magic()); + assertEquals(0, batches.get(0).baseOffset()); + assertEquals(MAGIC_VALUE_V1, batches.get(1).magic()); + assertEquals(1, batches.get(1).baseOffset()); + assertEquals("1", utf8(logRecords.get(0).key())); + 
assertEquals("2", utf8(logRecords.get(1).key())); + assertEquals("3", utf8(logRecords.get(2).key())); + verifyRecordsProcessingStats(compressionType, convertedRecords.recordConversionStats(), 3, 2, + records.sizeInBytes(), buffer.limit()); + } else { + assertEquals(2, batches.size()); + assertEquals(RecordBatch.MAGIC_VALUE_V0, batches.get(0).magic()); + assertEquals(0, batches.get(0).baseOffset()); + assertEquals(MAGIC_VALUE_V1, batches.get(1).magic()); + assertEquals(2, batches.get(1).baseOffset()); + assertEquals("1", utf8(logRecords.get(0).key())); + assertEquals("3", utf8(logRecords.get(1).key())); + verifyRecordsProcessingStats(compressionType, convertedRecords.recordConversionStats(), 3, 1, + records.sizeInBytes(), buffer.limit()); + } + } + } + + @ParameterizedTest + @ArgumentsSource(MemoryRecordsBuilderArgumentsProvider.class) + public void shouldThrowIllegalStateExceptionOnBuildWhenAborted(Args args) { + ByteBuffer buffer = allocateBuffer(128, args); + + MemoryRecordsBuilder builder = new MemoryRecordsBuilder(buffer, args.magic, args.compressionType, + TimestampType.CREATE_TIME, 0L, 0L, RecordBatch.NO_PRODUCER_ID, RecordBatch.NO_PRODUCER_EPOCH, + RecordBatch.NO_SEQUENCE, false, false, RecordBatch.NO_PARTITION_LEADER_EPOCH, buffer.capacity()); + builder.abort(); + assertThrows(IllegalStateException.class, builder::build); + } + + @ParameterizedTest + @ArgumentsSource(MemoryRecordsBuilderArgumentsProvider.class) + public void shouldResetBufferToInitialPositionOnAbort(Args args) { + ByteBuffer buffer = allocateBuffer(128, args); + + MemoryRecordsBuilder builder = new MemoryRecordsBuilder(buffer, args.magic, args.compressionType, + TimestampType.CREATE_TIME, 0L, 0L, RecordBatch.NO_PRODUCER_ID, RecordBatch.NO_PRODUCER_EPOCH, RecordBatch.NO_SEQUENCE, + false, false, RecordBatch.NO_PARTITION_LEADER_EPOCH, buffer.capacity()); + builder.append(0L, "a".getBytes(), "1".getBytes()); + builder.abort(); + assertEquals(args.bufferOffset, builder.buffer().position()); + } + + @ParameterizedTest + @ArgumentsSource(MemoryRecordsBuilderArgumentsProvider.class) + public void shouldThrowIllegalStateExceptionOnCloseWhenAborted(Args args) { + ByteBuffer buffer = allocateBuffer(128, args); + + MemoryRecordsBuilder builder = new MemoryRecordsBuilder(buffer, args.magic, args.compressionType, + TimestampType.CREATE_TIME, 0L, 0L, RecordBatch.NO_PRODUCER_ID, RecordBatch.NO_PRODUCER_EPOCH, RecordBatch.NO_SEQUENCE, + false, false, RecordBatch.NO_PARTITION_LEADER_EPOCH, buffer.capacity()); + builder.abort(); + assertThrows(IllegalStateException.class, builder::close, "Should have thrown IllegalStateException"); + } + + @ParameterizedTest + @ArgumentsSource(MemoryRecordsBuilderArgumentsProvider.class) + public void shouldThrowIllegalStateExceptionOnAppendWhenAborted(Args args) { + ByteBuffer buffer = allocateBuffer(128, args); + + MemoryRecordsBuilder builder = new MemoryRecordsBuilder(buffer, args.magic, args.compressionType, + TimestampType.CREATE_TIME, 0L, 0L, RecordBatch.NO_PRODUCER_ID, RecordBatch.NO_PRODUCER_EPOCH, RecordBatch.NO_SEQUENCE, + false, false, RecordBatch.NO_PARTITION_LEADER_EPOCH, buffer.capacity()); + builder.abort(); + assertThrows(IllegalStateException.class, () -> builder.append(0L, "a".getBytes(), "1".getBytes()), "Should have thrown IllegalStateException"); + } + + @ParameterizedTest + @ArgumentsSource(MemoryRecordsBuilderArgumentsProvider.class) + public void testBuffersDereferencedOnClose(Args args) { + Runtime runtime = Runtime.getRuntime(); + int payloadLen = 1024 * 1024; + ByteBuffer 
buffer = ByteBuffer.allocate(payloadLen * 2); + byte[] key = new byte[0]; + byte[] value = new byte[payloadLen]; + new Random().nextBytes(value); // Use random payload so that compressed buffer is large + List builders = new ArrayList<>(100); + long startMem = 0; + long memUsed = 0; + int iterations = 0; + while (iterations++ < 100) { + buffer.rewind(); + MemoryRecordsBuilder builder = new MemoryRecordsBuilder(buffer, args.magic, args.compressionType, + TimestampType.CREATE_TIME, 0L, 0L, RecordBatch.NO_PRODUCER_ID, + RecordBatch.NO_PRODUCER_EPOCH, RecordBatch.NO_SEQUENCE, false, false, + RecordBatch.NO_PARTITION_LEADER_EPOCH, 0); + builder.append(1L, key, value); + builder.build(); + builders.add(builder); + + System.gc(); + memUsed = runtime.totalMemory() - runtime.freeMemory() - startMem; + // Ignore memory usage during initialization + if (iterations == 2) + startMem = memUsed; + else if (iterations > 2 && memUsed < (iterations - 2) * 1024) + break; + } + assertTrue(iterations < 100, "Memory usage too high: " + memUsed); + } + + @ParameterizedTest + @ArgumentsSource(V2MemoryRecordsBuilderArgumentsProvider.class) + public void testRecordTimestampsWithDeleteHorizon(Args args) { + long deleteHorizon = 100; + int payloadLen = 1024 * 1024; + ByteBuffer buffer = ByteBuffer.allocate(payloadLen * 2); + ByteBufferOutputStream byteBufferOutputStream = new ByteBufferOutputStream(buffer); + MemoryRecordsBuilder builder = new MemoryRecordsBuilder(byteBufferOutputStream, args.magic, args.compressionType, + TimestampType.CREATE_TIME, 0L, 0L, RecordBatch.NO_PRODUCER_ID, + RecordBatch.NO_PRODUCER_EPOCH, RecordBatch.NO_SEQUENCE, false, false, + RecordBatch.NO_PARTITION_LEADER_EPOCH, 0, deleteHorizon); + + builder.append(50L, "0".getBytes(), "0".getBytes()); + builder.append(100L, "1".getBytes(), null); + builder.append(150L, "2".getBytes(), "2".getBytes()); + + MemoryRecords records = builder.build(); + List batches = TestUtils.toList(records.batches()); + assertEquals(OptionalLong.of(deleteHorizon), batches.get(0).deleteHorizonMs()); + + CloseableIterator recordIterator = batches.get(0).streamingIterator(BufferSupplier.create()); + Record record = recordIterator.next(); + assertEquals(50L, record.timestamp()); + record = recordIterator.next(); + assertEquals(100L, record.timestamp()); + record = recordIterator.next(); + assertEquals(150L, record.timestamp()); + recordIterator.close(); + } + + private static class V2MemoryRecordsBuilderArgumentsProvider implements ArgumentsProvider { + @Override + public Stream provideArguments(ExtensionContext context) { + List values = new ArrayList<>(); + for (int bufferOffset : Arrays.asList(0, 15)) + for (CompressionType type: CompressionType.values()) { + values.add(Arguments.of(new Args(bufferOffset, type, MAGIC_VALUE_V2))); + } + return values.stream(); + } + } + + private void verifyRecordsProcessingStats(CompressionType compressionType, RecordConversionStats processingStats, + int numRecords, int numRecordsConverted, long finalBytes, + long preConvertedBytes) { + assertNotNull(processingStats, "Records processing info is null"); + assertEquals(numRecordsConverted, processingStats.numRecordsConverted()); + // Since nanoTime accuracy on build machines may not be sufficient to measure small conversion times, + // only check if the value >= 0. Default is -1, so this checks if time has been recorded. 
+ assertTrue(processingStats.conversionTimeNanos() >= 0, "Processing time not recorded: " + processingStats); + long tempBytes = processingStats.temporaryMemoryBytes(); + if (compressionType == CompressionType.NONE) { + if (numRecordsConverted == 0) + assertEquals(finalBytes, tempBytes); + else if (numRecordsConverted == numRecords) + assertEquals(preConvertedBytes + finalBytes, tempBytes); + else { + assertTrue(tempBytes > finalBytes && tempBytes < finalBytes + preConvertedBytes, + String.format("Unexpected temp bytes %d final %d pre %d", tempBytes, finalBytes, preConvertedBytes)); + } + } else { + long compressedBytes = finalBytes - Records.LOG_OVERHEAD - LegacyRecord.RECORD_OVERHEAD_V0; + assertTrue(tempBytes > compressedBytes, + String.format("Uncompressed size expected temp=%d, compressed=%d", tempBytes, compressedBytes)); + } + } + + private ByteBuffer allocateBuffer(int size, Args args) { + ByteBuffer buffer = ByteBuffer.allocate(size); + buffer.position(args.bufferOffset); + return buffer; + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/record/MemoryRecordsTest.java b/clients/src/test/java/org/apache/kafka/common/record/MemoryRecordsTest.java new file mode 100644 index 0000000..3f0195b --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/record/MemoryRecordsTest.java @@ -0,0 +1,1091 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.record; + +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.errors.CorruptRecordException; +import org.apache.kafka.common.header.internals.RecordHeaders; +import org.apache.kafka.common.message.LeaderChangeMessage; +import org.apache.kafka.common.message.LeaderChangeMessage.Voter; +import org.apache.kafka.common.record.MemoryRecords.RecordFilter; +import org.apache.kafka.common.record.MemoryRecords.RecordFilter.BatchRetention; +import org.apache.kafka.common.utils.BufferSupplier; +import org.apache.kafka.common.utils.CloseableIterator; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.test.TestUtils; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtensionContext; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.ArgumentsProvider; +import org.junit.jupiter.params.provider.ArgumentsSource; + +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.OptionalLong; +import java.util.function.BiFunction; +import java.util.function.Supplier; +import java.util.stream.Stream; + +import static java.util.Arrays.asList; +import static org.apache.kafka.common.record.RecordBatch.MAGIC_VALUE_V0; +import static org.apache.kafka.common.record.RecordBatch.MAGIC_VALUE_V1; +import static org.apache.kafka.common.record.RecordBatch.MAGIC_VALUE_V2; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class MemoryRecordsTest { + + private static class Args { + final CompressionType compression; + final byte magic; + final long firstOffset; + final long pid; + final short epoch; + final int firstSequence; + + public Args(byte magic, long firstOffset, CompressionType compression) { + this.magic = magic; + this.compression = compression; + this.firstOffset = firstOffset; + if (magic >= RecordBatch.MAGIC_VALUE_V2) { + pid = 134234L; + epoch = 28; + firstSequence = 777; + } else { + pid = RecordBatch.NO_PRODUCER_ID; + epoch = RecordBatch.NO_PRODUCER_EPOCH; + firstSequence = RecordBatch.NO_SEQUENCE; + } + } + + @Override + public String toString() { + return "magic=" + magic + + ", firstOffset=" + firstOffset + + ", compressionType=" + compression; + } + } + + private static class MemoryRecordsArgumentsProvider implements ArgumentsProvider { + @Override + public Stream provideArguments(ExtensionContext context) { + List arguments = new ArrayList<>(); + for (long firstOffset : asList(0L, 57L)) + for (CompressionType type: CompressionType.values()) { + List magics = type == CompressionType.ZSTD + ? 
Collections.singletonList(RecordBatch.MAGIC_VALUE_V2) + : asList(RecordBatch.MAGIC_VALUE_V0, RecordBatch.MAGIC_VALUE_V1, RecordBatch.MAGIC_VALUE_V2); + for (byte magic : magics) + arguments.add(Arguments.of(new Args(magic, firstOffset, type))); + } + return arguments.stream(); + } + } + + private static class V2MemoryRecordsArgumentsProvider implements ArgumentsProvider { + @Override + public Stream provideArguments(ExtensionContext context) { + List arguments = new ArrayList<>(); + for (long firstOffset : asList(0L, 57L)) + for (CompressionType type: CompressionType.values()) { + arguments.add(Arguments.of(new Args(RecordBatch.MAGIC_VALUE_V2, firstOffset, type))); + } + return arguments.stream(); + } + } + + private final long logAppendTime = System.currentTimeMillis(); + + @ParameterizedTest + @ArgumentsSource(MemoryRecordsArgumentsProvider.class) + public void testIterator(Args args) { + CompressionType compression = args.compression; + byte magic = args.magic; + long pid = args.pid; + short epoch = args.epoch; + int firstSequence = args.firstSequence; + long firstOffset = args.firstOffset; + ByteBuffer buffer = ByteBuffer.allocate(1024); + + int partitionLeaderEpoch = 998; + MemoryRecordsBuilder builder = new MemoryRecordsBuilder(buffer, magic, compression, + TimestampType.CREATE_TIME, firstOffset, logAppendTime, pid, epoch, firstSequence, false, false, + partitionLeaderEpoch, buffer.limit()); + + SimpleRecord[] records = new SimpleRecord[] { + new SimpleRecord(1L, "a".getBytes(), "1".getBytes()), + new SimpleRecord(2L, "b".getBytes(), "2".getBytes()), + new SimpleRecord(3L, "c".getBytes(), "3".getBytes()), + new SimpleRecord(4L, null, "4".getBytes()), + new SimpleRecord(5L, "d".getBytes(), null), + new SimpleRecord(6L, (byte[]) null, null) + }; + + for (SimpleRecord record : records) + builder.append(record); + + MemoryRecords memoryRecords = builder.build(); + for (int iteration = 0; iteration < 2; iteration++) { + int total = 0; + for (RecordBatch batch : memoryRecords.batches()) { + assertTrue(batch.isValid()); + assertEquals(compression, batch.compressionType()); + assertEquals(firstOffset + total, batch.baseOffset()); + + if (magic >= RecordBatch.MAGIC_VALUE_V2) { + assertEquals(pid, batch.producerId()); + assertEquals(epoch, batch.producerEpoch()); + assertEquals(firstSequence + total, batch.baseSequence()); + assertEquals(partitionLeaderEpoch, batch.partitionLeaderEpoch()); + assertEquals(records.length, batch.countOrNull().intValue()); + assertEquals(TimestampType.CREATE_TIME, batch.timestampType()); + assertEquals(records[records.length - 1].timestamp(), batch.maxTimestamp()); + } else { + assertEquals(RecordBatch.NO_PRODUCER_ID, batch.producerId()); + assertEquals(RecordBatch.NO_PRODUCER_EPOCH, batch.producerEpoch()); + assertEquals(RecordBatch.NO_SEQUENCE, batch.baseSequence()); + assertEquals(RecordBatch.NO_PARTITION_LEADER_EPOCH, batch.partitionLeaderEpoch()); + assertNull(batch.countOrNull()); + if (magic == RecordBatch.MAGIC_VALUE_V0) + assertEquals(TimestampType.NO_TIMESTAMP_TYPE, batch.timestampType()); + else + assertEquals(TimestampType.CREATE_TIME, batch.timestampType()); + } + + int recordCount = 0; + for (Record record : batch) { + record.ensureValid(); + assertTrue(record.hasMagic(batch.magic())); + assertFalse(record.isCompressed()); + assertEquals(firstOffset + total, record.offset()); + assertEquals(records[total].key(), record.key()); + assertEquals(records[total].value(), record.value()); + + if (magic >= RecordBatch.MAGIC_VALUE_V2) + 
assertEquals(firstSequence + total, record.sequence()); + + assertFalse(record.hasTimestampType(TimestampType.LOG_APPEND_TIME)); + if (magic == RecordBatch.MAGIC_VALUE_V0) { + assertEquals(RecordBatch.NO_TIMESTAMP, record.timestamp()); + assertFalse(record.hasTimestampType(TimestampType.CREATE_TIME)); + assertTrue(record.hasTimestampType(TimestampType.NO_TIMESTAMP_TYPE)); + } else { + assertEquals(records[total].timestamp(), record.timestamp()); + assertFalse(record.hasTimestampType(TimestampType.NO_TIMESTAMP_TYPE)); + if (magic < RecordBatch.MAGIC_VALUE_V2) + assertTrue(record.hasTimestampType(TimestampType.CREATE_TIME)); + else + assertFalse(record.hasTimestampType(TimestampType.CREATE_TIME)); + } + + total++; + recordCount++; + } + + assertEquals(batch.baseOffset() + recordCount - 1, batch.lastOffset()); + } + } + } + + @ParameterizedTest + @ArgumentsSource(MemoryRecordsArgumentsProvider.class) + public void testHasRoomForMethod(Args args) { + MemoryRecordsBuilder builder = MemoryRecords.builder(ByteBuffer.allocate(1024), args.magic, args.compression, + TimestampType.CREATE_TIME, 0L); + builder.append(0L, "a".getBytes(), "1".getBytes()); + assertTrue(builder.hasRoomFor(1L, "b".getBytes(), "2".getBytes(), Record.EMPTY_HEADERS)); + builder.close(); + assertFalse(builder.hasRoomFor(1L, "b".getBytes(), "2".getBytes(), Record.EMPTY_HEADERS)); + } + + @ParameterizedTest + @ArgumentsSource(MemoryRecordsArgumentsProvider.class) + public void testHasRoomForMethodWithHeaders(Args args) { + byte magic = args.magic; + MemoryRecordsBuilder builder = MemoryRecords.builder(ByteBuffer.allocate(120), magic, args.compression, + TimestampType.CREATE_TIME, 0L); + builder.append(logAppendTime, "key".getBytes(), "value".getBytes()); + RecordHeaders headers = new RecordHeaders(); + for (int i = 0; i < 10; ++i) headers.add("hello", "world.world".getBytes()); + // Make sure that hasRoomFor accounts for header sizes by letting a record without headers pass, but stopping + // a record with a large number of headers. + assertTrue(builder.hasRoomFor(logAppendTime, "key".getBytes(), "value".getBytes(), Record.EMPTY_HEADERS)); + if (magic < MAGIC_VALUE_V2) assertTrue(builder.hasRoomFor(logAppendTime, "key".getBytes(), "value".getBytes(), headers.toArray())); + else assertFalse(builder.hasRoomFor(logAppendTime, "key".getBytes(), "value".getBytes(), headers.toArray())); + } + + /** + * This test verifies that the checksum returned for various versions matches hardcoded values to catch unintentional + * changes to how the checksum is computed. 
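The pinned values in the test below differ by magic because, to the best of my knowledge, the legacy v0/v1 formats carry a per-record CRC-32 while the v2 format stores a single CRC-32C over the batch; the test itself only asserts the concrete numbers. A minimal sketch of the pattern it relies on, reusing this file's imports and only calls that appear in the test (the pinned constant is the v2/uncompressed value from the test):

```java
// Sketch: build a batch and read its checksum through the public API.
MemoryRecords records = MemoryRecords.withRecords(RecordBatch.MAGIC_VALUE_V2, CompressionType.NONE,
    new SimpleRecord(283843L, "key1".getBytes(), "value1".getBytes()),
    new SimpleRecord(1234L, "key2".getBytes(), "value2".getBytes()));
long checksum = records.batches().iterator().next().checksum();
// Any unintentional change to the wire layout or checksum algorithm shows up here.
assertEquals(3851219455L, checksum);
```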
+ */ + @ParameterizedTest + @ArgumentsSource(MemoryRecordsArgumentsProvider.class) + public void testChecksum(Args args) { + CompressionType compression = args.compression; + byte magic = args.magic; + // we get reasonable coverage with uncompressed and one compression type + if (compression != CompressionType.NONE && compression != CompressionType.LZ4) + return; + + SimpleRecord[] records = { + new SimpleRecord(283843L, "key1".getBytes(), "value1".getBytes()), + new SimpleRecord(1234L, "key2".getBytes(), "value2".getBytes()) + }; + RecordBatch batch = MemoryRecords.withRecords(magic, compression, records).batches().iterator().next(); + long expectedChecksum; + if (magic == RecordBatch.MAGIC_VALUE_V0) { + if (compression == CompressionType.NONE) + expectedChecksum = 1978725405L; + else + expectedChecksum = 66944826L; + } else if (magic == RecordBatch.MAGIC_VALUE_V1) { + if (compression == CompressionType.NONE) + expectedChecksum = 109425508L; + else + expectedChecksum = 1407303399L; + } else { + if (compression == CompressionType.NONE) + expectedChecksum = 3851219455L; + else + expectedChecksum = 2745969314L; + } + assertEquals(expectedChecksum, batch.checksum(), "Unexpected checksum for magic " + magic + + " and compression type " + compression); + } + + @ParameterizedTest + @ArgumentsSource(MemoryRecordsArgumentsProvider.class) + public void testFilterToPreservesPartitionLeaderEpoch(Args args) { + byte magic = args.magic; + int partitionLeaderEpoch = 67; + + ByteBuffer buffer = ByteBuffer.allocate(2048); + MemoryRecordsBuilder builder = MemoryRecords.builder(buffer, magic, args.compression, TimestampType.CREATE_TIME, + 0L, RecordBatch.NO_TIMESTAMP, partitionLeaderEpoch); + builder.append(10L, null, "a".getBytes()); + builder.append(11L, "1".getBytes(), "b".getBytes()); + builder.append(12L, null, "c".getBytes()); + + ByteBuffer filtered = ByteBuffer.allocate(2048); + builder.build().filterTo(new TopicPartition("foo", 0), new RetainNonNullKeysFilter(), filtered, + Integer.MAX_VALUE, BufferSupplier.NO_CACHING); + + filtered.flip(); + MemoryRecords filteredRecords = MemoryRecords.readableRecords(filtered); + + List batches = TestUtils.toList(filteredRecords.batches()); + assertEquals(1, batches.size()); + + MutableRecordBatch firstBatch = batches.get(0); + if (magic < MAGIC_VALUE_V2) assertEquals(RecordBatch.NO_PARTITION_LEADER_EPOCH, firstBatch.partitionLeaderEpoch()); + else assertEquals(partitionLeaderEpoch, firstBatch.partitionLeaderEpoch()); + } + + @ParameterizedTest + @ArgumentsSource(MemoryRecordsArgumentsProvider.class) + public void testFilterToEmptyBatchRetention(Args args) { + byte magic = args.magic; + for (boolean isTransactional : Arrays.asList(true, false)) { + ByteBuffer buffer = ByteBuffer.allocate(2048); + long producerId = 23L; + short producerEpoch = 5; + long baseOffset = 3L; + int baseSequence = 10; + int partitionLeaderEpoch = 293; + int numRecords = 2; + + Supplier supplier = () -> MemoryRecords.builder(buffer, magic, args.compression, TimestampType.CREATE_TIME, + baseOffset, RecordBatch.NO_TIMESTAMP, producerId, producerEpoch, baseSequence, isTransactional, + partitionLeaderEpoch); + + if (isTransactional && magic < RecordBatch.MAGIC_VALUE_V2) assertThrows(IllegalArgumentException.class, supplier::get); + else { + MemoryRecordsBuilder builder = supplier.get(); + builder.append(11L, "2".getBytes(), "b".getBytes()); + builder.append(12L, "3".getBytes(), "c".getBytes()); + if (magic < MAGIC_VALUE_V2) assertThrows(IllegalArgumentException.class, builder::close); + else 
{ + builder.close(); + MemoryRecords records = builder.build(); + ByteBuffer filtered = ByteBuffer.allocate(2048); + MemoryRecords.FilterResult filterResult = records.filterTo(new TopicPartition("foo", 0), + new MemoryRecords.RecordFilter(0, 0) { + @Override + protected BatchRetentionResult checkBatchRetention(RecordBatch batch) { + // retain all batches + return new BatchRetentionResult(BatchRetention.RETAIN_EMPTY, false); + } + + @Override + protected boolean shouldRetainRecord(RecordBatch recordBatch, Record record) { + // delete the records + return false; + } + }, filtered, Integer.MAX_VALUE, BufferSupplier.NO_CACHING); + + // Verify filter result + assertEquals(numRecords, filterResult.messagesRead()); + assertEquals(records.sizeInBytes(), filterResult.bytesRead()); + assertEquals(baseOffset + 1, filterResult.maxOffset()); + assertEquals(0, filterResult.messagesRetained()); + assertEquals(DefaultRecordBatch.RECORD_BATCH_OVERHEAD, filterResult.bytesRetained()); + assertEquals(12, filterResult.maxTimestamp()); + assertEquals(baseOffset + 1, filterResult.shallowOffsetOfMaxTimestamp()); + + // Verify filtered records + filtered.flip(); + MemoryRecords filteredRecords = MemoryRecords.readableRecords(filtered); + + List batches = TestUtils.toList(filteredRecords.batches()); + assertEquals(1, batches.size()); + + MutableRecordBatch batch = batches.get(0); + assertEquals(0, batch.countOrNull().intValue()); + assertEquals(12L, batch.maxTimestamp()); + assertEquals(TimestampType.CREATE_TIME, batch.timestampType()); + assertEquals(baseOffset, batch.baseOffset()); + assertEquals(baseOffset + 1, batch.lastOffset()); + assertEquals(baseSequence, batch.baseSequence()); + assertEquals(baseSequence + 1, batch.lastSequence()); + assertEquals(isTransactional, batch.isTransactional()); + } + } + } + } + + @Test + public void testEmptyBatchRetention() { + ByteBuffer buffer = ByteBuffer.allocate(DefaultRecordBatch.RECORD_BATCH_OVERHEAD); + long producerId = 23L; + short producerEpoch = 5; + long baseOffset = 3L; + int baseSequence = 10; + int partitionLeaderEpoch = 293; + long timestamp = System.currentTimeMillis(); + + DefaultRecordBatch.writeEmptyHeader(buffer, RecordBatch.MAGIC_VALUE_V2, producerId, producerEpoch, + baseSequence, baseOffset, baseOffset, partitionLeaderEpoch, TimestampType.CREATE_TIME, + timestamp, false, false); + buffer.flip(); + + ByteBuffer filtered = ByteBuffer.allocate(2048); + MemoryRecords records = MemoryRecords.readableRecords(buffer); + MemoryRecords.FilterResult filterResult = records.filterTo(new TopicPartition("foo", 0), + new MemoryRecords.RecordFilter(0, 0) { + @Override + protected BatchRetentionResult checkBatchRetention(RecordBatch batch) { + // retain all batches + return new BatchRetentionResult(BatchRetention.RETAIN_EMPTY, false); + } + + @Override + protected boolean shouldRetainRecord(RecordBatch recordBatch, Record record) { + return false; + } + }, filtered, Integer.MAX_VALUE, BufferSupplier.NO_CACHING); + + // Verify filter result + assertEquals(0, filterResult.messagesRead()); + assertEquals(records.sizeInBytes(), filterResult.bytesRead()); + assertEquals(baseOffset, filterResult.maxOffset()); + assertEquals(0, filterResult.messagesRetained()); + assertEquals(DefaultRecordBatch.RECORD_BATCH_OVERHEAD, filterResult.bytesRetained()); + assertEquals(timestamp, filterResult.maxTimestamp()); + assertEquals(baseOffset, filterResult.shallowOffsetOfMaxTimestamp()); + assertTrue(filterResult.outputBuffer().position() > 0); + + // Verify filtered records + 
filtered.flip(); + MemoryRecords filteredRecords = MemoryRecords.readableRecords(filtered); + assertEquals(DefaultRecordBatch.RECORD_BATCH_OVERHEAD, filteredRecords.sizeInBytes()); + } + + @Test + public void testEmptyBatchDeletion() { + for (final BatchRetention deleteRetention : Arrays.asList(BatchRetention.DELETE, BatchRetention.DELETE_EMPTY)) { + ByteBuffer buffer = ByteBuffer.allocate(DefaultRecordBatch.RECORD_BATCH_OVERHEAD); + long producerId = 23L; + short producerEpoch = 5; + long baseOffset = 3L; + int baseSequence = 10; + int partitionLeaderEpoch = 293; + long timestamp = System.currentTimeMillis(); + + DefaultRecordBatch.writeEmptyHeader(buffer, RecordBatch.MAGIC_VALUE_V2, producerId, producerEpoch, + baseSequence, baseOffset, baseOffset, partitionLeaderEpoch, TimestampType.CREATE_TIME, + timestamp, false, false); + buffer.flip(); + + ByteBuffer filtered = ByteBuffer.allocate(2048); + MemoryRecords records = MemoryRecords.readableRecords(buffer); + MemoryRecords.FilterResult filterResult = records.filterTo(new TopicPartition("foo", 0), + new MemoryRecords.RecordFilter(0, 0) { + @Override + protected BatchRetentionResult checkBatchRetention(RecordBatch batch) { + return new BatchRetentionResult(deleteRetention, false); + } + + @Override + protected boolean shouldRetainRecord(RecordBatch recordBatch, Record record) { + return false; + } + }, filtered, Integer.MAX_VALUE, BufferSupplier.NO_CACHING); + + // Verify filter result + assertEquals(0, filterResult.outputBuffer().position()); + + // Verify filtered records + filtered.flip(); + MemoryRecords filteredRecords = MemoryRecords.readableRecords(filtered); + assertEquals(0, filteredRecords.sizeInBytes()); + } + } + + @Test + public void testBuildEndTxnMarker() { + long producerId = 73; + short producerEpoch = 13; + long initialOffset = 983L; + int coordinatorEpoch = 347; + int partitionLeaderEpoch = 29; + + EndTransactionMarker marker = new EndTransactionMarker(ControlRecordType.COMMIT, coordinatorEpoch); + MemoryRecords records = MemoryRecords.withEndTransactionMarker(initialOffset, System.currentTimeMillis(), + partitionLeaderEpoch, producerId, producerEpoch, marker); + // verify that buffer allocation was precise + assertEquals(records.buffer().remaining(), records.buffer().capacity()); + + List batches = TestUtils.toList(records.batches()); + assertEquals(1, batches.size()); + + RecordBatch batch = batches.get(0); + assertTrue(batch.isControlBatch()); + assertEquals(producerId, batch.producerId()); + assertEquals(producerEpoch, batch.producerEpoch()); + assertEquals(initialOffset, batch.baseOffset()); + assertEquals(partitionLeaderEpoch, batch.partitionLeaderEpoch()); + assertTrue(batch.isValid()); + + List createdRecords = TestUtils.toList(batch); + assertEquals(1, createdRecords.size()); + + Record record = createdRecords.get(0); + record.ensureValid(); + EndTransactionMarker deserializedMarker = EndTransactionMarker.deserialize(record); + assertEquals(ControlRecordType.COMMIT, deserializedMarker.controlType()); + assertEquals(coordinatorEpoch, deserializedMarker.coordinatorEpoch()); + } + + /** + * This test is used to see if the base timestamp of the batch has been successfully + * converted to a delete horizon for the tombstones / transaction markers of the batch. + * It also verifies that the record timestamps remain correct as a delta relative to the delete horizon. 
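In the test that follows, the two arguments passed to the RecordFilter constructor are assumed to be the filter's current time and the delete-retention period; the asserted horizon is their sum, which is all the test actually checks. A short sketch of the read path for the converted horizon, where filteredRecords stands for the filtered output built in that test:

```java
// Sketch: a batch rewritten by filterTo() advertises its delete horizon as an OptionalLong.
MutableRecordBatch firstBatch = TestUtils.toList(filteredRecords.batches()).get(0);
OptionalLong deleteHorizonMs = firstBatch.deleteHorizonMs();
// Tombstones and markers in this batch become eligible for removal once this time has passed.
deleteHorizonMs.ifPresent(ms -> System.out.println("delete horizon: " + ms));
```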
+ */ + @ParameterizedTest + @ArgumentsSource(V2MemoryRecordsArgumentsProvider.class) + public void testBaseTimestampToDeleteHorizonConversion(Args args) { + int partitionLeaderEpoch = 998; + ByteBuffer buffer = ByteBuffer.allocate(2048); + MemoryRecordsBuilder builder = MemoryRecords.builder(buffer, args.magic, args.compression, TimestampType.CREATE_TIME, + 0L, RecordBatch.NO_TIMESTAMP, partitionLeaderEpoch); + builder.append(5L, "0".getBytes(), "0".getBytes()); + builder.append(10L, "1".getBytes(), null); + builder.append(15L, "2".getBytes(), "2".getBytes()); + + ByteBuffer filtered = ByteBuffer.allocate(2048); + final long deleteHorizon = Integer.MAX_VALUE / 2; + final RecordFilter recordFilter = new MemoryRecords.RecordFilter(deleteHorizon - 1, 1) { + @Override + protected boolean shouldRetainRecord(RecordBatch recordBatch, Record record) { + return true; + } + + @Override + protected BatchRetentionResult checkBatchRetention(RecordBatch batch) { + return new BatchRetentionResult(BatchRetention.RETAIN_EMPTY, false); + } + }; + builder.build().filterTo(new TopicPartition("random", 0), recordFilter, filtered, Integer.MAX_VALUE, BufferSupplier.NO_CACHING); + filtered.flip(); + MemoryRecords filteredRecords = MemoryRecords.readableRecords(filtered); + + List batches = TestUtils.toList(filteredRecords.batches()); + assertEquals(1, batches.size()); + assertEquals(OptionalLong.of(deleteHorizon), batches.get(0).deleteHorizonMs()); + + CloseableIterator recordIterator = batches.get(0).streamingIterator(BufferSupplier.create()); + Record record = recordIterator.next(); + assertEquals(5L, record.timestamp()); + record = recordIterator.next(); + assertEquals(10L, record.timestamp()); + record = recordIterator.next(); + assertEquals(15L, record.timestamp()); + recordIterator.close(); + } + + @Test + public void testBuildLeaderChangeMessage() { + final int leaderId = 5; + final int leaderEpoch = 20; + final int voterId = 6; + long initialOffset = 983L; + + LeaderChangeMessage leaderChangeMessage = new LeaderChangeMessage() + .setLeaderId(leaderId) + .setVoters(Collections.singletonList( + new Voter().setVoterId(voterId))); + ByteBuffer buffer = ByteBuffer.allocate(256); + MemoryRecords records = MemoryRecords.withLeaderChangeMessage( + initialOffset, + System.currentTimeMillis(), + leaderEpoch, + buffer, + leaderChangeMessage + ); + + List batches = TestUtils.toList(records.batches()); + assertEquals(1, batches.size()); + + RecordBatch batch = batches.get(0); + assertTrue(batch.isControlBatch()); + assertEquals(initialOffset, batch.baseOffset()); + assertEquals(leaderEpoch, batch.partitionLeaderEpoch()); + assertTrue(batch.isValid()); + + List createdRecords = TestUtils.toList(batch); + assertEquals(1, createdRecords.size()); + + Record record = createdRecords.get(0); + record.ensureValid(); + assertEquals(ControlRecordType.LEADER_CHANGE, ControlRecordType.parse(record.key())); + + LeaderChangeMessage deserializedMessage = ControlRecordUtils.deserializeLeaderChangeMessage(record); + assertEquals(leaderId, deserializedMessage.leaderId()); + assertEquals(1, deserializedMessage.voters().size()); + assertEquals(voterId, deserializedMessage.voters().get(0).voterId()); + } + + @ParameterizedTest + @ArgumentsSource(MemoryRecordsArgumentsProvider.class) + public void testFilterToBatchDiscard(Args args) { + CompressionType compression = args.compression; + byte magic = args.magic; + + ByteBuffer buffer = ByteBuffer.allocate(2048); + MemoryRecordsBuilder builder = MemoryRecords.builder(buffer, magic, 
compression, TimestampType.CREATE_TIME, 0L); + builder.append(10L, "1".getBytes(), "a".getBytes()); + builder.close(); + + builder = MemoryRecords.builder(buffer, magic, compression, TimestampType.CREATE_TIME, 1L); + builder.append(11L, "2".getBytes(), "b".getBytes()); + builder.append(12L, "3".getBytes(), "c".getBytes()); + builder.close(); + + builder = MemoryRecords.builder(buffer, magic, compression, TimestampType.CREATE_TIME, 3L); + builder.append(13L, "4".getBytes(), "d".getBytes()); + builder.append(20L, "5".getBytes(), "e".getBytes()); + builder.append(15L, "6".getBytes(), "f".getBytes()); + builder.close(); + + builder = MemoryRecords.builder(buffer, magic, compression, TimestampType.CREATE_TIME, 6L); + builder.append(16L, "7".getBytes(), "g".getBytes()); + builder.close(); + + buffer.flip(); + + ByteBuffer filtered = ByteBuffer.allocate(2048); + MemoryRecords.readableRecords(buffer).filterTo(new TopicPartition("foo", 0), new MemoryRecords.RecordFilter(0, 0) { + @Override + protected BatchRetentionResult checkBatchRetention(RecordBatch batch) { + // discard the second and fourth batches + if (batch.lastOffset() == 2L || batch.lastOffset() == 6L) + return new BatchRetentionResult(BatchRetention.DELETE, false); + return new BatchRetentionResult(BatchRetention.DELETE_EMPTY, false); + } + + @Override + protected boolean shouldRetainRecord(RecordBatch recordBatch, Record record) { + return true; + } + }, filtered, Integer.MAX_VALUE, BufferSupplier.NO_CACHING); + + filtered.flip(); + MemoryRecords filteredRecords = MemoryRecords.readableRecords(filtered); + + List batches = TestUtils.toList(filteredRecords.batches()); + if (compression != CompressionType.NONE || magic >= MAGIC_VALUE_V2) { + assertEquals(2, batches.size()); + assertEquals(0, batches.get(0).lastOffset()); + assertEquals(5, batches.get(1).lastOffset()); + } else { + assertEquals(5, batches.size()); + assertEquals(0, batches.get(0).lastOffset()); + assertEquals(1, batches.get(1).lastOffset()); + } + } + + @ParameterizedTest + @ArgumentsSource(MemoryRecordsArgumentsProvider.class) + public void testFilterToAlreadyCompactedLog(Args args) { + byte magic = args.magic; + CompressionType compression = args.compression; + + ByteBuffer buffer = ByteBuffer.allocate(2048); + + // create a batch with some offset gaps to simulate a compacted batch + MemoryRecordsBuilder builder = MemoryRecords.builder(buffer, magic, compression, + TimestampType.CREATE_TIME, 0L); + builder.appendWithOffset(5L, 10L, null, "a".getBytes()); + builder.appendWithOffset(8L, 11L, "1".getBytes(), "b".getBytes()); + builder.appendWithOffset(10L, 12L, null, "c".getBytes()); + + builder.close(); + buffer.flip(); + + ByteBuffer filtered = ByteBuffer.allocate(2048); + MemoryRecords.readableRecords(buffer).filterTo(new TopicPartition("foo", 0), new RetainNonNullKeysFilter(), + filtered, Integer.MAX_VALUE, BufferSupplier.NO_CACHING); + filtered.flip(); + MemoryRecords filteredRecords = MemoryRecords.readableRecords(filtered); + + List batches = TestUtils.toList(filteredRecords.batches()); + assertEquals(1, batches.size()); + + MutableRecordBatch batch = batches.get(0); + List records = TestUtils.toList(batch); + assertEquals(1, records.size()); + assertEquals(8L, records.get(0).offset()); + + + if (magic >= RecordBatch.MAGIC_VALUE_V1) + assertEquals(new SimpleRecord(11L, "1".getBytes(), "b".getBytes()), new SimpleRecord(records.get(0))); + else + assertEquals(new SimpleRecord(RecordBatch.NO_TIMESTAMP, "1".getBytes(), "b".getBytes()), + new 
SimpleRecord(records.get(0))); + + if (magic >= RecordBatch.MAGIC_VALUE_V2) { + // the new format preserves first and last offsets from the original batch + assertEquals(0L, batch.baseOffset()); + assertEquals(10L, batch.lastOffset()); + } else { + assertEquals(8L, batch.baseOffset()); + assertEquals(8L, batch.lastOffset()); + } + } + + @ParameterizedTest + @ArgumentsSource(MemoryRecordsArgumentsProvider.class) + public void testFilterToPreservesProducerInfo(Args args) { + byte magic = args.magic; + CompressionType compression = args.compression; + ByteBuffer buffer = ByteBuffer.allocate(2048); + + // non-idempotent, non-transactional + MemoryRecordsBuilder builder = MemoryRecords.builder(buffer, magic, compression, TimestampType.CREATE_TIME, 0L); + builder.append(10L, null, "a".getBytes()); + builder.append(11L, "1".getBytes(), "b".getBytes()); + builder.append(12L, null, "c".getBytes()); + + builder.close(); + + // idempotent + long pid1 = 23L; + short epoch1 = 5; + int baseSequence1 = 10; + MemoryRecordsBuilder idempotentBuilder = MemoryRecords.builder(buffer, magic, compression, TimestampType.CREATE_TIME, 3L, + RecordBatch.NO_TIMESTAMP, pid1, epoch1, baseSequence1); + idempotentBuilder.append(13L, null, "d".getBytes()); + idempotentBuilder.append(14L, "4".getBytes(), "e".getBytes()); + idempotentBuilder.append(15L, "5".getBytes(), "f".getBytes()); + if (magic < MAGIC_VALUE_V2) assertThrows(IllegalArgumentException.class, idempotentBuilder::close); + else idempotentBuilder.close(); + + + // transactional + long pid2 = 99384L; + short epoch2 = 234; + int baseSequence2 = 15; + Supplier transactionSupplier = () -> MemoryRecords.builder(buffer, magic, compression, TimestampType.CREATE_TIME, 3L, + RecordBatch.NO_TIMESTAMP, pid2, epoch2, baseSequence2, true, RecordBatch.NO_PARTITION_LEADER_EPOCH); + + if (magic < MAGIC_VALUE_V2) assertThrows(IllegalArgumentException.class, transactionSupplier::get); + else { + builder = transactionSupplier.get(); + builder.append(16L, "6".getBytes(), "g".getBytes()); + builder.append(17L, "7".getBytes(), "h".getBytes()); + builder.append(18L, null, "i".getBytes()); + builder.close(); + + buffer.flip(); + + ByteBuffer filtered = ByteBuffer.allocate(2048); + MemoryRecords.readableRecords(buffer).filterTo(new TopicPartition("foo", 0), new RetainNonNullKeysFilter(), + filtered, Integer.MAX_VALUE, BufferSupplier.NO_CACHING); + + filtered.flip(); + MemoryRecords filteredRecords = MemoryRecords.readableRecords(filtered); + + List batches = TestUtils.toList(filteredRecords.batches()); + assertEquals(3, batches.size()); + + MutableRecordBatch firstBatch = batches.get(0); + assertEquals(1, firstBatch.countOrNull().intValue()); + assertEquals(0L, firstBatch.baseOffset()); + assertEquals(2L, firstBatch.lastOffset()); + assertEquals(RecordBatch.NO_PRODUCER_ID, firstBatch.producerId()); + assertEquals(RecordBatch.NO_PRODUCER_EPOCH, firstBatch.producerEpoch()); + assertEquals(RecordBatch.NO_SEQUENCE, firstBatch.baseSequence()); + assertEquals(RecordBatch.NO_SEQUENCE, firstBatch.lastSequence()); + assertFalse(firstBatch.isTransactional()); + List firstBatchRecords = TestUtils.toList(firstBatch); + assertEquals(1, firstBatchRecords.size()); + assertEquals(RecordBatch.NO_SEQUENCE, firstBatchRecords.get(0).sequence()); + assertEquals(new SimpleRecord(11L, "1".getBytes(), "b".getBytes()), new SimpleRecord(firstBatchRecords.get(0))); + + MutableRecordBatch secondBatch = batches.get(1); + assertEquals(2, secondBatch.countOrNull().intValue()); + assertEquals(3L, 
secondBatch.baseOffset()); + assertEquals(5L, secondBatch.lastOffset()); + assertEquals(pid1, secondBatch.producerId()); + assertEquals(epoch1, secondBatch.producerEpoch()); + assertEquals(baseSequence1, secondBatch.baseSequence()); + assertEquals(baseSequence1 + 2, secondBatch.lastSequence()); + assertFalse(secondBatch.isTransactional()); + List secondBatchRecords = TestUtils.toList(secondBatch); + assertEquals(2, secondBatchRecords.size()); + assertEquals(baseSequence1 + 1, secondBatchRecords.get(0).sequence()); + assertEquals(new SimpleRecord(14L, "4".getBytes(), "e".getBytes()), new SimpleRecord(secondBatchRecords.get(0))); + assertEquals(baseSequence1 + 2, secondBatchRecords.get(1).sequence()); + assertEquals(new SimpleRecord(15L, "5".getBytes(), "f".getBytes()), new SimpleRecord(secondBatchRecords.get(1))); + + MutableRecordBatch thirdBatch = batches.get(2); + assertEquals(2, thirdBatch.countOrNull().intValue()); + assertEquals(3L, thirdBatch.baseOffset()); + assertEquals(5L, thirdBatch.lastOffset()); + assertEquals(pid2, thirdBatch.producerId()); + assertEquals(epoch2, thirdBatch.producerEpoch()); + assertEquals(baseSequence2, thirdBatch.baseSequence()); + assertEquals(baseSequence2 + 2, thirdBatch.lastSequence()); + assertTrue(thirdBatch.isTransactional()); + List thirdBatchRecords = TestUtils.toList(thirdBatch); + assertEquals(2, thirdBatchRecords.size()); + assertEquals(baseSequence2, thirdBatchRecords.get(0).sequence()); + assertEquals(new SimpleRecord(16L, "6".getBytes(), "g".getBytes()), new SimpleRecord(thirdBatchRecords.get(0))); + assertEquals(baseSequence2 + 1, thirdBatchRecords.get(1).sequence()); + assertEquals(new SimpleRecord(17L, "7".getBytes(), "h".getBytes()), new SimpleRecord(thirdBatchRecords.get(1))); + } + } + + @ParameterizedTest + @ArgumentsSource(MemoryRecordsArgumentsProvider.class) + public void testFilterToWithUndersizedBuffer(Args args) { + byte magic = args.magic; + CompressionType compression = args.compression; + + ByteBuffer buffer = ByteBuffer.allocate(1024); + MemoryRecordsBuilder builder = MemoryRecords.builder(buffer, magic, compression, TimestampType.CREATE_TIME, 0L); + builder.append(10L, null, "a".getBytes()); + builder.close(); + + builder = MemoryRecords.builder(buffer, magic, compression, TimestampType.CREATE_TIME, 1L); + builder.append(11L, "1".getBytes(), new byte[128]); + builder.append(12L, "2".getBytes(), "c".getBytes()); + builder.append(13L, null, "d".getBytes()); + builder.close(); + + builder = MemoryRecords.builder(buffer, magic, compression, TimestampType.CREATE_TIME, 4L); + builder.append(14L, null, "e".getBytes()); + builder.append(15L, "5".getBytes(), "f".getBytes()); + builder.append(16L, "6".getBytes(), "g".getBytes()); + builder.close(); + + builder = MemoryRecords.builder(buffer, magic, compression, TimestampType.CREATE_TIME, 7L); + builder.append(17L, "7".getBytes(), new byte[128]); + builder.close(); + + buffer.flip(); + + ByteBuffer output = ByteBuffer.allocate(64); + + List records = new ArrayList<>(); + while (buffer.hasRemaining()) { + output.rewind(); + + MemoryRecords.FilterResult result = MemoryRecords.readableRecords(buffer) + .filterTo(new TopicPartition("foo", 0), new RetainNonNullKeysFilter(), output, Integer.MAX_VALUE, + BufferSupplier.NO_CACHING); + + buffer.position(buffer.position() + result.bytesRead()); + result.outputBuffer().flip(); + + if (output != result.outputBuffer()) + assertEquals(0, output.position()); + + MemoryRecords filtered = MemoryRecords.readableRecords(result.outputBuffer()); + 
records.addAll(TestUtils.toList(filtered.records())); + } + + assertEquals(5, records.size()); + for (Record record : records) + assertNotNull(record.key()); + } + + @ParameterizedTest + @ArgumentsSource(MemoryRecordsArgumentsProvider.class) + public void testFilterTo(Args args) { + byte magic = args.magic; + CompressionType compression = args.compression; + + ByteBuffer buffer = ByteBuffer.allocate(2048); + MemoryRecordsBuilder builder = MemoryRecords.builder(buffer, magic, compression, TimestampType.CREATE_TIME, 0L); + builder.append(10L, null, "a".getBytes()); + builder.close(); + + builder = MemoryRecords.builder(buffer, magic, compression, TimestampType.CREATE_TIME, 1L); + builder.append(11L, "1".getBytes(), "b".getBytes()); + builder.append(12L, null, "c".getBytes()); + builder.close(); + + builder = MemoryRecords.builder(buffer, magic, compression, TimestampType.CREATE_TIME, 3L); + builder.append(13L, null, "d".getBytes()); + builder.append(20L, "4".getBytes(), "e".getBytes()); + builder.append(15L, "5".getBytes(), "f".getBytes()); + builder.close(); + + builder = MemoryRecords.builder(buffer, magic, compression, TimestampType.CREATE_TIME, 6L); + builder.append(16L, "6".getBytes(), "g".getBytes()); + builder.close(); + + buffer.flip(); + + ByteBuffer filtered = ByteBuffer.allocate(2048); + MemoryRecords.FilterResult result = MemoryRecords.readableRecords(buffer).filterTo( + new TopicPartition("foo", 0), new RetainNonNullKeysFilter(), filtered, Integer.MAX_VALUE, + BufferSupplier.NO_CACHING); + + filtered.flip(); + + assertEquals(7, result.messagesRead()); + assertEquals(4, result.messagesRetained()); + assertEquals(buffer.limit(), result.bytesRead()); + assertEquals(filtered.limit(), result.bytesRetained()); + if (magic > RecordBatch.MAGIC_VALUE_V0) { + assertEquals(20L, result.maxTimestamp()); + if (compression == CompressionType.NONE && magic < RecordBatch.MAGIC_VALUE_V2) + assertEquals(4L, result.shallowOffsetOfMaxTimestamp()); + else + assertEquals(5L, result.shallowOffsetOfMaxTimestamp()); + } + + MemoryRecords filteredRecords = MemoryRecords.readableRecords(filtered); + + List batches = TestUtils.toList(filteredRecords.batches()); + final List expectedEndOffsets; + final List expectedStartOffsets; + final List expectedMaxTimestamps; + + if (magic < RecordBatch.MAGIC_VALUE_V2 && compression == CompressionType.NONE) { + expectedEndOffsets = asList(1L, 4L, 5L, 6L); + expectedStartOffsets = asList(1L, 4L, 5L, 6L); + expectedMaxTimestamps = asList(11L, 20L, 15L, 16L); + } else if (magic < RecordBatch.MAGIC_VALUE_V2) { + expectedEndOffsets = asList(1L, 5L, 6L); + expectedStartOffsets = asList(1L, 4L, 6L); + expectedMaxTimestamps = asList(11L, 20L, 16L); + } else { + expectedEndOffsets = asList(2L, 5L, 6L); + expectedStartOffsets = asList(1L, 3L, 6L); + expectedMaxTimestamps = asList(11L, 20L, 16L); + } + + assertEquals(expectedEndOffsets.size(), batches.size()); + + for (int i = 0; i < expectedEndOffsets.size(); i++) { + RecordBatch batch = batches.get(i); + assertEquals(expectedStartOffsets.get(i).longValue(), batch.baseOffset()); + assertEquals(expectedEndOffsets.get(i).longValue(), batch.lastOffset()); + assertEquals(magic, batch.magic()); + assertEquals(compression, batch.compressionType()); + if (magic >= RecordBatch.MAGIC_VALUE_V1) { + assertEquals(expectedMaxTimestamps.get(i).longValue(), batch.maxTimestamp()); + assertEquals(TimestampType.CREATE_TIME, batch.timestampType()); + } else { + assertEquals(RecordBatch.NO_TIMESTAMP, batch.maxTimestamp()); + 
assertEquals(TimestampType.NO_TIMESTAMP_TYPE, batch.timestampType()); + } + } + + List records = TestUtils.toList(filteredRecords.records()); + assertEquals(4, records.size()); + + Record first = records.get(0); + assertEquals(1L, first.offset()); + if (magic > RecordBatch.MAGIC_VALUE_V0) + assertEquals(11L, first.timestamp()); + assertEquals("1", Utils.utf8(first.key(), first.keySize())); + assertEquals("b", Utils.utf8(first.value(), first.valueSize())); + + Record second = records.get(1); + assertEquals(4L, second.offset()); + if (magic > RecordBatch.MAGIC_VALUE_V0) + assertEquals(20L, second.timestamp()); + assertEquals("4", Utils.utf8(second.key(), second.keySize())); + assertEquals("e", Utils.utf8(second.value(), second.valueSize())); + + Record third = records.get(2); + assertEquals(5L, third.offset()); + if (magic > RecordBatch.MAGIC_VALUE_V0) + assertEquals(15L, third.timestamp()); + assertEquals("5", Utils.utf8(third.key(), third.keySize())); + assertEquals("f", Utils.utf8(third.value(), third.valueSize())); + + Record fourth = records.get(3); + assertEquals(6L, fourth.offset()); + if (magic > RecordBatch.MAGIC_VALUE_V0) + assertEquals(16L, fourth.timestamp()); + assertEquals("6", Utils.utf8(fourth.key(), fourth.keySize())); + assertEquals("g", Utils.utf8(fourth.value(), fourth.valueSize())); + } + + @ParameterizedTest + @ArgumentsSource(MemoryRecordsArgumentsProvider.class) + public void testFilterToPreservesLogAppendTime(Args args) { + byte magic = args.magic; + CompressionType compression = args.compression; + long pid = args.pid; + short epoch = args.epoch; + int firstSequence = args.firstSequence; + long logAppendTime = System.currentTimeMillis(); + + ByteBuffer buffer = ByteBuffer.allocate(2048); + MemoryRecordsBuilder builder = MemoryRecords.builder(buffer, magic, compression, + TimestampType.LOG_APPEND_TIME, 0L, logAppendTime, pid, epoch, firstSequence); + builder.append(10L, null, "a".getBytes()); + builder.close(); + + builder = MemoryRecords.builder(buffer, magic, compression, TimestampType.LOG_APPEND_TIME, 1L, logAppendTime, + pid, epoch, firstSequence); + builder.append(11L, "1".getBytes(), "b".getBytes()); + builder.append(12L, null, "c".getBytes()); + builder.close(); + + builder = MemoryRecords.builder(buffer, magic, compression, TimestampType.LOG_APPEND_TIME, 3L, logAppendTime, + pid, epoch, firstSequence); + builder.append(13L, null, "d".getBytes()); + builder.append(14L, "4".getBytes(), "e".getBytes()); + builder.append(15L, "5".getBytes(), "f".getBytes()); + builder.close(); + + buffer.flip(); + + ByteBuffer filtered = ByteBuffer.allocate(2048); + MemoryRecords.readableRecords(buffer).filterTo(new TopicPartition("foo", 0), new RetainNonNullKeysFilter(), + filtered, Integer.MAX_VALUE, BufferSupplier.NO_CACHING); + + filtered.flip(); + MemoryRecords filteredRecords = MemoryRecords.readableRecords(filtered); + + List batches = TestUtils.toList(filteredRecords.batches()); + assertEquals(magic < RecordBatch.MAGIC_VALUE_V2 && compression == CompressionType.NONE ? 
3 : 2, batches.size()); + + for (RecordBatch batch : batches) { + assertEquals(compression, batch.compressionType()); + if (magic > RecordBatch.MAGIC_VALUE_V0) { + assertEquals(TimestampType.LOG_APPEND_TIME, batch.timestampType()); + assertEquals(logAppendTime, batch.maxTimestamp()); + } + } + } + + @ParameterizedTest + @ArgumentsSource(MemoryRecordsArgumentsProvider.class) + public void testNextBatchSize(Args args) { + ByteBuffer buffer = ByteBuffer.allocate(2048); + MemoryRecordsBuilder builder = MemoryRecords.builder(buffer, args.magic, args.compression, + TimestampType.LOG_APPEND_TIME, 0L, logAppendTime, args.pid, args.epoch, args.firstSequence); + builder.append(10L, null, "abc".getBytes()); + builder.close(); + + buffer.flip(); + int size = buffer.remaining(); + MemoryRecords records = MemoryRecords.readableRecords(buffer); + assertEquals(size, records.firstBatchSize().intValue()); + assertEquals(0, buffer.position()); + + buffer.limit(1); // size not in buffer + assertNull(records.firstBatchSize()); + buffer.limit(Records.LOG_OVERHEAD); // magic not in buffer + assertNull(records.firstBatchSize()); + buffer.limit(Records.HEADER_SIZE_UP_TO_MAGIC); // payload not in buffer + assertEquals(size, records.firstBatchSize().intValue()); + + buffer.limit(size); + byte magic = buffer.get(Records.MAGIC_OFFSET); + buffer.put(Records.MAGIC_OFFSET, (byte) 10); + assertThrows(CorruptRecordException.class, records::firstBatchSize); + buffer.put(Records.MAGIC_OFFSET, magic); + + buffer.put(Records.SIZE_OFFSET + 3, (byte) 0); + assertThrows(CorruptRecordException.class, records::firstBatchSize); + } + + @ParameterizedTest + @ArgumentsSource(MemoryRecordsArgumentsProvider.class) + public void testWithRecords(Args args) { + CompressionType compression = args.compression; + byte magic = args.magic; + MemoryRecords memoryRecords = MemoryRecords.withRecords(magic, compression, + new SimpleRecord(10L, "key1".getBytes(), "value1".getBytes())); + String key = Utils.utf8(memoryRecords.batches().iterator().next().iterator().next().key()); + assertEquals("key1", key); + } + + @Test + public void testUnsupportedCompress() { + BiFunction builderBiFunction = (magic, compressionType) -> + MemoryRecords.withRecords(magic, compressionType, new SimpleRecord(10L, "key1".getBytes(), "value1".getBytes())); + + Arrays.asList(MAGIC_VALUE_V0, MAGIC_VALUE_V1).forEach(magic -> { + Exception e = assertThrows(IllegalArgumentException.class, () -> builderBiFunction.apply(magic, CompressionType.ZSTD)); + assertEquals(e.getMessage(), "ZStandard compression is not supported for magic " + magic); + }); + } + + private static class RetainNonNullKeysFilter extends MemoryRecords.RecordFilter { + public RetainNonNullKeysFilter() { + super(0, 0); + } + + @Override + protected BatchRetentionResult checkBatchRetention(RecordBatch batch) { + return new BatchRetentionResult(BatchRetention.DELETE_EMPTY, false); + } + + @Override + public boolean shouldRetainRecord(RecordBatch batch, Record record) { + return record.hasKey(); + } + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/record/MultiRecordsSendTest.java b/clients/src/test/java/org/apache/kafka/common/record/MultiRecordsSendTest.java new file mode 100644 index 0000000..d307ba6 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/record/MultiRecordsSendTest.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
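The filterTo tests in MemoryRecordsTest above all plug a RecordFilter into MemoryRecords.filterTo, with RetainNonNullKeysFilter as the shared example. Below is a sketch of the same pattern for a different, hypothetical policy that drops tombstones instead of keyless records; it reuses the imports of MemoryRecordsTest, and Record#hasValue() is assumed to be the value-side counterpart of the hasKey() call used above:

```java
// Sketch, not part of the test suite: drop records with null values, keep empty batches.
ByteBuffer destination = ByteBuffer.allocate(1024);
MemoryRecords records = MemoryRecords.withRecords(RecordBatch.MAGIC_VALUE_V2, CompressionType.NONE,
    new SimpleRecord(1L, "k1".getBytes(), "v1".getBytes()),
    new SimpleRecord(2L, "k2".getBytes(), null));          // tombstone

RecordFilter dropTombstones = new MemoryRecords.RecordFilter(0, 0) {
    @Override
    protected BatchRetentionResult checkBatchRetention(RecordBatch batch) {
        return new BatchRetentionResult(BatchRetention.DELETE_EMPTY, false);
    }

    @Override
    protected boolean shouldRetainRecord(RecordBatch batch, Record record) {
        return record.hasValue();   // assumption: hasValue() mirrors hasKey()
    }
};

records.filterTo(new TopicPartition("foo", 0), dropTombstones, destination,
    Integer.MAX_VALUE, BufferSupplier.NO_CACHING);
destination.flip();
MemoryRecords filtered = MemoryRecords.readableRecords(destination);
```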
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.record; + +import org.apache.kafka.common.network.ByteBufferSend; +import org.apache.kafka.common.network.Send; +import org.apache.kafka.test.TestUtils; +import org.junit.jupiter.api.Test; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.LinkedList; +import java.util.Queue; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class MultiRecordsSendTest { + + @Test + public void testSendsFreedAfterWriting() throws IOException { + int numChunks = 4; + int chunkSize = 32; + int totalSize = numChunks * chunkSize; + + Queue sends = new LinkedList<>(); + ByteBuffer[] chunks = new ByteBuffer[numChunks]; + + for (int i = 0; i < numChunks; i++) { + ByteBuffer buffer = ByteBuffer.wrap(TestUtils.randomBytes(chunkSize)); + chunks[i] = buffer; + sends.add(new ByteBufferSend(buffer)); + } + + MultiRecordsSend send = new MultiRecordsSend(sends); + assertEquals(totalSize, send.size()); + + for (int i = 0; i < numChunks; i++) { + assertEquals(numChunks - i, send.numResidentSends()); + NonOverflowingByteBufferChannel out = new NonOverflowingByteBufferChannel(chunkSize); + send.writeTo(out); + out.close(); + assertEquals(chunks[i], out.buffer()); + } + + assertEquals(0, send.numResidentSends()); + assertTrue(send.completed()); + } + + private static class NonOverflowingByteBufferChannel extends org.apache.kafka.common.requests.ByteBufferChannel { + + private NonOverflowingByteBufferChannel(long size) { + super(size); + } + + @Override + public long write(ByteBuffer[] srcs) { + // Instead of overflowing, this channel refuses additional writes once the buffer is full, + // which allows us to test the MultiRecordsSend behavior on a per-send basis. + if (!buffer().hasRemaining()) + return 0; + return super.write(srcs); + } + } + +} diff --git a/clients/src/test/java/org/apache/kafka/common/record/SimpleLegacyRecordTest.java b/clients/src/test/java/org/apache/kafka/common/record/SimpleLegacyRecordTest.java new file mode 100644 index 0000000..7204709 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/record/SimpleLegacyRecordTest.java @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
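MultiRecordsSendTest above verifies that each wrapped send is released as soon as it has been fully written. A sketch of the ordinary draining loop those assertions rely on, using only constructors and methods exercised in this section (the payload bytes are made up, and ByteBufferChannel is referenced by its fully qualified name as elsewhere in these tests):

```java
// Sketch: wrap a few buffers as sends and drain the composite send into a channel.
Queue<Send> sends = new LinkedList<>();
sends.add(new ByteBufferSend(ByteBuffer.wrap("foo".getBytes())));
sends.add(new ByteBufferSend(ByteBuffer.wrap("bar".getBytes())));

MultiRecordsSend send = new MultiRecordsSend(sends);
org.apache.kafka.common.requests.ByteBufferChannel channel =
    new org.apache.kafka.common.requests.ByteBufferChannel(send.size());

// Completed child sends are dropped as writing progresses, which is what numResidentSends() observes.
while (!send.completed())
    send.writeTo(channel);
channel.close();
```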
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.record; + +import org.apache.kafka.common.InvalidRecordException; +import org.apache.kafka.common.errors.CorruptRecordException; +import org.apache.kafka.common.utils.ByteBufferOutputStream; +import org.apache.kafka.common.utils.Utils; +import org.junit.jupiter.api.Test; + +import java.io.DataOutputStream; +import java.io.OutputStream; +import java.nio.ByteBuffer; + +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertThrows; + +public class SimpleLegacyRecordTest { + + @Test + public void testCompressedIterationWithNullValue() throws Exception { + ByteBuffer buffer = ByteBuffer.allocate(128); + DataOutputStream out = new DataOutputStream(new ByteBufferOutputStream(buffer)); + AbstractLegacyRecordBatch.writeHeader(out, 0L, LegacyRecord.RECORD_OVERHEAD_V1); + LegacyRecord.write(out, RecordBatch.MAGIC_VALUE_V1, 1L, (byte[]) null, null, + CompressionType.GZIP, TimestampType.CREATE_TIME); + + buffer.flip(); + MemoryRecords records = MemoryRecords.readableRecords(buffer); + assertThrows(InvalidRecordException.class, () -> records.records().iterator().hasNext()); + } + + @Test + public void testCompressedIterationWithEmptyRecords() throws Exception { + ByteBuffer emptyCompressedValue = ByteBuffer.allocate(64); + OutputStream gzipOutput = CompressionType.GZIP.wrapForOutput(new ByteBufferOutputStream(emptyCompressedValue), + RecordBatch.MAGIC_VALUE_V1); + gzipOutput.close(); + emptyCompressedValue.flip(); + + ByteBuffer buffer = ByteBuffer.allocate(128); + DataOutputStream out = new DataOutputStream(new ByteBufferOutputStream(buffer)); + AbstractLegacyRecordBatch.writeHeader(out, 0L, LegacyRecord.RECORD_OVERHEAD_V1 + emptyCompressedValue.remaining()); + LegacyRecord.write(out, RecordBatch.MAGIC_VALUE_V1, 1L, null, Utils.toArray(emptyCompressedValue), + CompressionType.GZIP, TimestampType.CREATE_TIME); + + buffer.flip(); + + MemoryRecords records = MemoryRecords.readableRecords(buffer); + assertThrows(InvalidRecordException.class, () -> records.records().iterator().hasNext()); + } + + /* This scenario can happen if the record size field is corrupt and we end up allocating a buffer that is too small */ + @Test + public void testIsValidWithTooSmallBuffer() { + ByteBuffer buffer = ByteBuffer.allocate(2); + LegacyRecord record = new LegacyRecord(buffer); + assertFalse(record.isValid()); + assertThrows(CorruptRecordException.class, record::ensureValid); + + } + + @Test + public void testIsValidWithChecksumMismatch() { + ByteBuffer buffer = ByteBuffer.allocate(4); + // set checksum + buffer.putInt(2); + LegacyRecord record = new LegacyRecord(buffer); + assertFalse(record.isValid()); + assertThrows(CorruptRecordException.class, record::ensureValid); + } + +} diff --git a/clients/src/test/java/org/apache/kafka/common/record/UnalignedFileRecordsTest.java b/clients/src/test/java/org/apache/kafka/common/record/UnalignedFileRecordsTest.java new file mode 100644 index 0000000..9a05a22 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/record/UnalignedFileRecordsTest.java @@ -0,0 +1,79 
@@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.record; + + +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.Iterator; + +import static org.apache.kafka.test.TestUtils.tempFile; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class UnalignedFileRecordsTest { + + private byte[][] values = new byte[][] { + "foo".getBytes(), + "bar".getBytes() + }; + private FileRecords fileRecords; + + @BeforeEach + public void setup() throws IOException { + this.fileRecords = createFileRecords(values); + } + + @AfterEach + public void cleanup() throws IOException { + this.fileRecords.close(); + } + + @Test + public void testWriteTo() throws IOException { + + org.apache.kafka.common.requests.ByteBufferChannel channel = new org.apache.kafka.common.requests.ByteBufferChannel(fileRecords.sizeInBytes()); + int size = fileRecords.sizeInBytes(); + + UnalignedFileRecords records1 = fileRecords.sliceUnaligned(0, size / 2); + UnalignedFileRecords records2 = fileRecords.sliceUnaligned(size / 2, size - size / 2); + + records1.writeTo(channel, 0, records1.sizeInBytes()); + records2.writeTo(channel, 0, records2.sizeInBytes()); + + channel.close(); + Iterator records = MemoryRecords.readableRecords(channel.buffer()).records().iterator(); + for (byte[] value : values) { + assertTrue(records.hasNext()); + assertEquals(records.next().value(), ByteBuffer.wrap(value)); + } + } + + private FileRecords createFileRecords(byte[][] values) throws IOException { + FileRecords fileRecords = FileRecords.open(tempFile()); + + for (byte[] value : values) { + fileRecords.append(MemoryRecords.withRecords(CompressionType.NONE, new SimpleRecord(value))); + } + + return fileRecords; + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/replica/ReplicaSelectorTest.java b/clients/src/test/java/org/apache/kafka/common/replica/ReplicaSelectorTest.java new file mode 100644 index 0000000..15ddef6 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/replica/ReplicaSelectorTest.java @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
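UnalignedFileRecordsTest above slices a file-backed log at an arbitrary byte position, so a single slice need not contain whole record batches; the records only become readable again after both slices have been written back to back. A sketch of draining one slice incrementally, assuming writeTo(channel, previouslyWritten, remaining) returns the number of bytes written by that call (the test writes each slice in one call); fileRecords and channel stand for the test's fixtures:

```java
// Sketch: write an unaligned slice in several passes instead of a single shot.
UnalignedFileRecords slice = fileRecords.sliceUnaligned(0, fileRecords.sizeInBytes() / 2);
long written = 0;
while (written < slice.sizeInBytes())
    written += slice.writeTo(channel, written, slice.sizeInBytes() - (int) written);
```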
You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.common.replica;
+
+import org.apache.kafka.common.Node;
+import org.apache.kafka.common.TopicPartition;
+import org.apache.kafka.common.security.auth.KafkaPrincipal;
+import org.junit.jupiter.api.Test;
+
+import java.net.InetAddress;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Optional;
+import java.util.Set;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+
+import static org.apache.kafka.test.TestUtils.assertOptional;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+
+public class ReplicaSelectorTest {
+
+    @Test
+    public void testSameRackSelector() {
+        TopicPartition tp = new TopicPartition("test", 0);
+
+        List<ReplicaView> replicaViewSet = replicaInfoSet();
+        ReplicaView leader = replicaViewSet.get(0);
+        PartitionView partitionView = partitionInfo(new HashSet<>(replicaViewSet), leader);
+
+        ReplicaSelector selector = new RackAwareReplicaSelector();
+        Optional<ReplicaView> selected = selector.select(tp, metadata("rack-b"), partitionView);
+        assertOptional(selected, replicaInfo -> {
+            assertEquals(replicaInfo.endpoint().rack(), "rack-b", "Expect replica to be in rack-b");
+            assertEquals(replicaInfo.endpoint().id(), 3, "Expected replica 3 since it is more caught-up");
+        });
+
+        selected = selector.select(tp, metadata("not-a-rack"), partitionView);
+        assertOptional(selected, replicaInfo -> {
+            assertEquals(replicaInfo, leader, "Expect leader when we can't find any nodes in given rack");
+        });
+
+        selected = selector.select(tp, metadata("rack-a"), partitionView);
+        assertOptional(selected, replicaInfo -> {
+            assertEquals(replicaInfo.endpoint().rack(), "rack-a", "Expect replica to be in rack-a");
+            assertEquals(replicaInfo, leader, "Expect the leader since it's in rack-a");
+        });
+    }
+
+    static List<ReplicaView> replicaInfoSet() {
+        return Stream.of(
+            replicaInfo(new Node(0, "host0", 1234, "rack-a"), 4, 0),
+            replicaInfo(new Node(1, "host1", 1234, "rack-a"), 2, 5),
+            replicaInfo(new Node(2, "host2", 1234, "rack-b"), 3, 3),
+            replicaInfo(new Node(3, "host3", 1234, "rack-b"), 4, 2)
+        ).collect(Collectors.toList());
+    }
+
+    static ReplicaView replicaInfo(Node node, long logOffset, long timeSinceLastCaughtUpMs) {
+        return new ReplicaView.DefaultReplicaView(node, logOffset, timeSinceLastCaughtUpMs);
+    }
+
+    static PartitionView partitionInfo(Set<ReplicaView> replicaViewSet, ReplicaView leader) {
+        return new PartitionView.DefaultPartitionView(replicaViewSet, leader);
+    }
+
+    static ClientMetadata metadata(String rack) {
+        return new ClientMetadata.DefaultClientMetadata(rack, "test-client",
+            InetAddress.getLoopbackAddress(), KafkaPrincipal.ANONYMOUS, "TEST");
+    }
+}
diff --git a/clients/src/test/java/org/apache/kafka/common/requests/AddPartitionsToTxnRequestTest.java b/clients/src/test/java/org/apache/kafka/common/requests/AddPartitionsToTxnRequestTest.java
new file mode 100644
index 0000000..04bde4a
--- /dev/null
+++ b/clients/src/test/java/org/apache/kafka/common/requests/AddPartitionsToTxnRequestTest.java
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. 
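ReplicaSelectorTest above exercises the rack-aware selector from KIP-392: within the consumer's rack it prefers the most caught-up replica and otherwise falls back to the leader. Brokers opt in through the replica.selector.class configuration, with consumers advertising their location via client.rack. A custom policy only needs to implement select(); here is a minimal sketch, assuming ReplicaSelector's close() and configure() have default implementations:

```java
// Sketch, not part of the test: the simplest possible policy, always read from the leader.
// Kafka ships an equivalent built-in selector; this only illustrates the shape of the interface.
public class LeaderOnlySelector implements ReplicaSelector {
    @Override
    public Optional<ReplicaView> select(TopicPartition topicPartition,
                                        ClientMetadata clientMetadata,
                                        PartitionView partitionView) {
        return Optional.of(partitionView.leader());
    }
}
```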
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.utils.annotation.ApiKeyVersionsSource; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.Errors; + +import java.util.ArrayList; + +import java.util.Collections; +import java.util.List; +import org.junit.jupiter.params.ParameterizedTest; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class AddPartitionsToTxnRequestTest { + + private static String transactionalId = "transactionalId"; + private static int producerId = 10; + private static short producerEpoch = 1; + private static int throttleTimeMs = 10; + + @ParameterizedTest + @ApiKeyVersionsSource(apiKey = ApiKeys.ADD_PARTITIONS_TO_TXN) + public void testConstructor(short version) { + List partitions = new ArrayList<>(); + partitions.add(new TopicPartition("topic", 0)); + partitions.add(new TopicPartition("topic", 1)); + + AddPartitionsToTxnRequest.Builder builder = new AddPartitionsToTxnRequest.Builder(transactionalId, producerId, producerEpoch, partitions); + AddPartitionsToTxnRequest request = builder.build(version); + + assertEquals(transactionalId, request.data().transactionalId()); + assertEquals(producerId, request.data().producerId()); + assertEquals(producerEpoch, request.data().producerEpoch()); + assertEquals(partitions, request.partitions()); + + AddPartitionsToTxnResponse response = request.getErrorResponse(throttleTimeMs, Errors.UNKNOWN_TOPIC_OR_PARTITION.exception()); + + assertEquals(Collections.singletonMap(Errors.UNKNOWN_TOPIC_OR_PARTITION, 2), response.errorCounts()); + assertEquals(throttleTimeMs, response.throttleTimeMs()); + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/requests/AddPartitionsToTxnResponseTest.java b/clients/src/test/java/org/apache/kafka/common/requests/AddPartitionsToTxnResponseTest.java new file mode 100644 index 0000000..5b67bd4 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/requests/AddPartitionsToTxnResponseTest.java @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.message.AddPartitionsToTxnResponseData; +import org.apache.kafka.common.message.AddPartitionsToTxnResponseData.AddPartitionsToTxnPartitionResult; +import org.apache.kafka.common.message.AddPartitionsToTxnResponseData.AddPartitionsToTxnTopicResult; +import org.apache.kafka.common.message.AddPartitionsToTxnResponseData.AddPartitionsToTxnTopicResultCollection; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.Errors; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.util.HashMap; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class AddPartitionsToTxnResponseTest { + + protected final int throttleTimeMs = 10; + + protected final String topicOne = "topic1"; + protected final int partitionOne = 1; + protected final Errors errorOne = Errors.COORDINATOR_NOT_AVAILABLE; + protected final Errors errorTwo = Errors.NOT_COORDINATOR; + protected final String topicTwo = "topic2"; + protected final int partitionTwo = 2; + + protected TopicPartition tp1 = new TopicPartition(topicOne, partitionOne); + protected TopicPartition tp2 = new TopicPartition(topicTwo, partitionTwo); + protected Map expectedErrorCounts; + protected Map errorsMap; + + @BeforeEach + public void setUp() { + expectedErrorCounts = new HashMap<>(); + expectedErrorCounts.put(errorOne, 1); + expectedErrorCounts.put(errorTwo, 1); + + errorsMap = new HashMap<>(); + errorsMap.put(tp1, errorOne); + errorsMap.put(tp2, errorTwo); + } + + @Test + public void testConstructorWithErrorResponse() { + AddPartitionsToTxnResponse response = new AddPartitionsToTxnResponse(throttleTimeMs, errorsMap); + + assertEquals(expectedErrorCounts, response.errorCounts()); + assertEquals(throttleTimeMs, response.throttleTimeMs()); + } + + @Test + public void testParse() { + + AddPartitionsToTxnTopicResultCollection topicCollection = new AddPartitionsToTxnTopicResultCollection(); + + AddPartitionsToTxnTopicResult topicResult = new AddPartitionsToTxnTopicResult(); + topicResult.setName(topicOne); + + topicResult.results().add(new AddPartitionsToTxnPartitionResult() + .setErrorCode(errorOne.code()) + .setPartitionIndex(partitionOne)); + + topicResult.results().add(new AddPartitionsToTxnPartitionResult() + .setErrorCode(errorTwo.code()) + .setPartitionIndex(partitionTwo)); + + topicCollection.add(topicResult); + + AddPartitionsToTxnResponseData data = new AddPartitionsToTxnResponseData() + .setResults(topicCollection) + .setThrottleTimeMs(throttleTimeMs); + AddPartitionsToTxnResponse response = new AddPartitionsToTxnResponse(data); + + for (short version : ApiKeys.ADD_PARTITIONS_TO_TXN.allVersions()) { + AddPartitionsToTxnResponse parsedResponse = AddPartitionsToTxnResponse.parse(response.serialize(version), version); + assertEquals(expectedErrorCounts, parsedResponse.errorCounts()); + assertEquals(throttleTimeMs, parsedResponse.throttleTimeMs()); + assertEquals(version >= 1, parsedResponse.shouldClientThrottle(version)); + } + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/requests/AlterReplicaLogDirsRequestTest.java b/clients/src/test/java/org/apache/kafka/common/requests/AlterReplicaLogDirsRequestTest.java new file mode 100644 index 0000000..c18926d --- /dev/null +++ 
b/clients/src/test/java/org/apache/kafka/common/requests/AlterReplicaLogDirsRequestTest.java @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.requests; + +import java.util.HashMap; +import java.util.Map; + +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.errors.LogDirNotFoundException; +import org.apache.kafka.common.message.AlterReplicaLogDirsRequestData; +import org.apache.kafka.common.message.AlterReplicaLogDirsRequestData.AlterReplicaLogDir; +import org.apache.kafka.common.message.AlterReplicaLogDirsRequestData.AlterReplicaLogDirCollection; +import org.apache.kafka.common.message.AlterReplicaLogDirsRequestData.AlterReplicaLogDirTopic; +import org.apache.kafka.common.message.AlterReplicaLogDirsRequestData.AlterReplicaLogDirTopicCollection; +import org.apache.kafka.common.message.AlterReplicaLogDirsResponseData.AlterReplicaLogDirTopicResult; +import org.apache.kafka.common.protocol.Errors; +import org.junit.jupiter.api.Test; + +import static java.util.Arrays.asList; +import static java.util.Collections.singletonList; +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class AlterReplicaLogDirsRequestTest { + + @Test + public void testErrorResponse() { + AlterReplicaLogDirsRequestData data = new AlterReplicaLogDirsRequestData() + .setDirs(new AlterReplicaLogDirCollection( + singletonList(new AlterReplicaLogDir() + .setPath("/data0") + .setTopics(new AlterReplicaLogDirTopicCollection( + singletonList(new AlterReplicaLogDirTopic() + .setName("topic") + .setPartitions(asList(0, 1, 2))).iterator()))).iterator())); + AlterReplicaLogDirsResponse errorResponse = new AlterReplicaLogDirsRequest.Builder(data).build() + .getErrorResponse(123, new LogDirNotFoundException("/data0")); + assertEquals(1, errorResponse.data().results().size()); + AlterReplicaLogDirTopicResult topicResponse = errorResponse.data().results().get(0); + assertEquals("topic", topicResponse.topicName()); + assertEquals(3, topicResponse.partitions().size()); + for (int i = 0; i < 3; i++) { + assertEquals(i, topicResponse.partitions().get(i).partitionIndex()); + assertEquals(Errors.LOG_DIR_NOT_FOUND.code(), topicResponse.partitions().get(i).errorCode()); + } + } + + @Test + public void testPartitionDir() { + AlterReplicaLogDirsRequestData data = new AlterReplicaLogDirsRequestData() + .setDirs(new AlterReplicaLogDirCollection( + asList(new AlterReplicaLogDir() + .setPath("/data0") + .setTopics(new AlterReplicaLogDirTopicCollection( + asList(new AlterReplicaLogDirTopic() + .setName("topic") + .setPartitions(asList(0, 1)), + new AlterReplicaLogDirTopic() + .setName("topic2") + .setPartitions(asList(7))).iterator())), + new AlterReplicaLogDir() + .setPath("/data1") + .setTopics(new 
AlterReplicaLogDirTopicCollection( + asList(new AlterReplicaLogDirTopic() + .setName("topic3") + .setPartitions(asList(12))).iterator()))).iterator())); + AlterReplicaLogDirsRequest request = new AlterReplicaLogDirsRequest.Builder(data).build(); + Map expect = new HashMap<>(); + expect.put(new TopicPartition("topic", 0), "/data0"); + expect.put(new TopicPartition("topic", 1), "/data0"); + expect.put(new TopicPartition("topic2", 7), "/data0"); + expect.put(new TopicPartition("topic3", 12), "/data1"); + assertEquals(expect, request.partitionDirs()); + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/requests/AlterReplicaLogDirsResponseTest.java b/clients/src/test/java/org/apache/kafka/common/requests/AlterReplicaLogDirsResponseTest.java new file mode 100644 index 0000000..edc441c --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/requests/AlterReplicaLogDirsResponseTest.java @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.requests; + +import java.util.Map; + +import org.apache.kafka.common.message.AlterReplicaLogDirsResponseData; +import org.apache.kafka.common.message.AlterReplicaLogDirsResponseData.AlterReplicaLogDirPartitionResult; +import org.apache.kafka.common.message.AlterReplicaLogDirsResponseData.AlterReplicaLogDirTopicResult; +import org.apache.kafka.common.protocol.Errors; +import org.junit.jupiter.api.Test; + +import static java.util.Arrays.asList; +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class AlterReplicaLogDirsResponseTest { + + @Test + public void testErrorCounts() { + AlterReplicaLogDirsResponseData data = new AlterReplicaLogDirsResponseData() + .setResults(asList( + new AlterReplicaLogDirTopicResult() + .setTopicName("t0") + .setPartitions(asList( + new AlterReplicaLogDirPartitionResult() + .setPartitionIndex(0) + .setErrorCode(Errors.LOG_DIR_NOT_FOUND.code()), + new AlterReplicaLogDirPartitionResult() + .setPartitionIndex(1) + .setErrorCode(Errors.NONE.code()))), + new AlterReplicaLogDirTopicResult() + .setTopicName("t1") + .setPartitions(asList( + new AlterReplicaLogDirPartitionResult() + .setPartitionIndex(0) + .setErrorCode(Errors.LOG_DIR_NOT_FOUND.code()))))); + Map counts = new AlterReplicaLogDirsResponse(data).errorCounts(); + assertEquals(2, counts.size()); + assertEquals(Integer.valueOf(2), counts.get(Errors.LOG_DIR_NOT_FOUND)); + assertEquals(Integer.valueOf(1), counts.get(Errors.NONE)); + + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/requests/ApiErrorTest.java b/clients/src/test/java/org/apache/kafka/common/requests/ApiErrorTest.java new file mode 100644 index 0000000..8b0aa47 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/requests/ApiErrorTest.java @@ -0,0 +1,82 @@ 
+/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.requests; + + +import org.apache.kafka.common.errors.NotControllerException; +import org.apache.kafka.common.errors.NotCoordinatorException; +import org.apache.kafka.common.errors.NotEnoughReplicasException; +import org.apache.kafka.common.errors.TimeoutException; +import org.apache.kafka.common.errors.UnknownServerException; +import org.apache.kafka.common.errors.UnknownTopicOrPartitionException; +import org.apache.kafka.common.protocol.Errors; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; +import java.util.concurrent.CompletionException; +import java.util.concurrent.ExecutionException; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class ApiErrorTest { + + @ParameterizedTest + @MethodSource("parameters") + public void fromThrowableShouldReturnCorrectError(Throwable t, Errors expectedErrors, String expectedMsg) { + ApiError apiError = ApiError.fromThrowable(t); + assertEquals(apiError.error(), expectedErrors); + assertEquals(apiError.message(), expectedMsg); + } + + private static Collection parameters() { + List arguments = new ArrayList<>(); + + arguments.add(Arguments.of( + new UnknownServerException("Don't leak sensitive information "), Errors.UNKNOWN_SERVER_ERROR, null)); + + arguments.add(Arguments.of( + new NotEnoughReplicasException(), Errors.NOT_ENOUGH_REPLICAS, null)); + + // avoid populating the error message if it's a generic one + arguments.add(Arguments.of( + new UnknownTopicOrPartitionException(Errors.UNKNOWN_TOPIC_OR_PARTITION.message()), Errors.UNKNOWN_TOPIC_OR_PARTITION, null)); + + String notCoordinatorErrorMsg = "Not coordinator"; + arguments.add(Arguments.of( + new NotCoordinatorException(notCoordinatorErrorMsg), Errors.NOT_COORDINATOR, notCoordinatorErrorMsg)); + + String notControllerErrorMsg = "Not controller"; + // test the NotControllerException is wrapped in the CompletionException, should return correct error + arguments.add(Arguments.of( + new CompletionException(new NotControllerException(notControllerErrorMsg)), Errors.NOT_CONTROLLER, notControllerErrorMsg)); + + String requestTimeoutErrorMsg = "request time out"; + // test the TimeoutException is wrapped in the ExecutionException, should return correct error + arguments.add(Arguments.of( + new ExecutionException(new TimeoutException(requestTimeoutErrorMsg)), Errors.REQUEST_TIMED_OUT, requestTimeoutErrorMsg)); + + // test the exception not in the Errors list, should return UNKNOWN_SERVER_ERROR + arguments.add(Arguments.of(new 
IOException(), Errors.UNKNOWN_SERVER_ERROR, null)); + + return arguments; + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/requests/ApiVersionsResponseTest.java b/clients/src/test/java/org/apache/kafka/common/requests/ApiVersionsResponseTest.java new file mode 100644 index 0000000..2c9b1e8 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/requests/ApiVersionsResponseTest.java @@ -0,0 +1,148 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.message.ApiMessageType; +import org.apache.kafka.common.message.ApiVersionsResponseData.ApiVersion; +import org.apache.kafka.common.message.ApiVersionsResponseData.ApiVersionCollection; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.record.RecordVersion; +import org.apache.kafka.common.utils.Utils; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.EnumSource; + +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class ApiVersionsResponseTest { + + @ParameterizedTest + @EnumSource(ApiMessageType.ListenerType.class) + public void shouldHaveCorrectDefaultApiVersionsResponse(ApiMessageType.ListenerType scope) { + ApiVersionsResponse defaultResponse = ApiVersionsResponse.defaultApiVersionsResponse(scope); + assertEquals(ApiKeys.apisForListener(scope).size(), defaultResponse.data().apiKeys().size(), + "API versions for all API keys must be maintained."); + + for (ApiKeys key : ApiKeys.apisForListener(scope)) { + ApiVersion version = defaultResponse.apiVersion(key.id); + assertNotNull(version, "Could not find ApiVersion for API " + key.name); + assertEquals(version.minVersion(), key.oldestVersion(), "Incorrect min version for Api " + key.name); + assertEquals(version.maxVersion(), key.latestVersion(), "Incorrect max version for Api " + key.name); + + // Check if versions less than min version are indeed set as null, i.e., deprecated. + for (int i = 0; i < version.minVersion(); ++i) { + assertNull(key.messageType.requestSchemas()[i], + "Request version " + i + " for API " + version.apiKey() + " must be null"); + assertNull(key.messageType.responseSchemas()[i], + "Response version " + i + " for API " + version.apiKey() + " must be null"); + } + + // Check if versions between min and max versions are non null, i.e., valid. 
+ for (int i = version.minVersion(); i <= version.maxVersion(); ++i) { + assertNotNull(key.messageType.requestSchemas()[i], + "Request version " + i + " for API " + version.apiKey() + " must not be null"); + assertNotNull(key.messageType.responseSchemas()[i], + "Response version " + i + " for API " + version.apiKey() + " must not be null"); + } + } + + assertTrue(defaultResponse.data().supportedFeatures().isEmpty()); + assertTrue(defaultResponse.data().finalizedFeatures().isEmpty()); + assertEquals(ApiVersionsResponse.UNKNOWN_FINALIZED_FEATURES_EPOCH, defaultResponse.data().finalizedFeaturesEpoch()); + } + + @Test + public void shouldHaveCommonlyAgreedApiVersionResponseWithControllerOnForwardableAPIs() { + final ApiKeys forwardableAPIKey = ApiKeys.CREATE_ACLS; + final ApiKeys nonForwardableAPIKey = ApiKeys.JOIN_GROUP; + final short minVersion = 0; + final short maxVersion = 1; + Map activeControllerApiVersions = Utils.mkMap( + Utils.mkEntry(forwardableAPIKey, new ApiVersion() + .setApiKey(forwardableAPIKey.id) + .setMinVersion(minVersion) + .setMaxVersion(maxVersion)), + Utils.mkEntry(nonForwardableAPIKey, new ApiVersion() + .setApiKey(nonForwardableAPIKey.id) + .setMinVersion(minVersion) + .setMaxVersion(maxVersion)) + ); + + ApiVersionCollection commonResponse = ApiVersionsResponse.intersectForwardableApis( + ApiMessageType.ListenerType.ZK_BROKER, + RecordVersion.current(), + activeControllerApiVersions + ); + + verifyVersions(forwardableAPIKey.id, minVersion, maxVersion, commonResponse); + + verifyVersions(nonForwardableAPIKey.id, ApiKeys.JOIN_GROUP.oldestVersion(), + ApiKeys.JOIN_GROUP.latestVersion(), commonResponse); + } + + @Test + public void testIntersect() { + assertFalse(ApiVersionsResponse.intersect(null, null).isPresent()); + assertThrows(IllegalArgumentException.class, + () -> ApiVersionsResponse.intersect(new ApiVersion().setApiKey((short) 10), new ApiVersion().setApiKey((short) 3))); + + short min = 0; + short max = 10; + ApiVersion thisVersion = new ApiVersion() + .setApiKey(ApiKeys.FETCH.id) + .setMinVersion(min) + .setMaxVersion(Short.MAX_VALUE); + + ApiVersion other = new ApiVersion() + .setApiKey(ApiKeys.FETCH.id) + .setMinVersion(Short.MIN_VALUE) + .setMaxVersion(max); + + ApiVersion expected = new ApiVersion() + .setApiKey(ApiKeys.FETCH.id) + .setMinVersion(min) + .setMaxVersion(max); + + assertFalse(ApiVersionsResponse.intersect(thisVersion, null).isPresent()); + assertFalse(ApiVersionsResponse.intersect(null, other).isPresent()); + + assertEquals(expected, ApiVersionsResponse.intersect(thisVersion, other).get()); + // test for symmetric + assertEquals(expected, ApiVersionsResponse.intersect(other, thisVersion).get()); + } + + private void verifyVersions(short forwardableAPIKey, + short minVersion, + short maxVersion, + ApiVersionCollection commonResponse) { + ApiVersion expectedVersionsForForwardableAPI = + new ApiVersion() + .setApiKey(forwardableAPIKey) + .setMinVersion(minVersion) + .setMaxVersion(maxVersion); + assertEquals(expectedVersionsForForwardableAPI, commonResponse.find(forwardableAPIKey)); + } + +} diff --git a/clients/src/test/java/org/apache/kafka/common/requests/ByteBufferChannel.java b/clients/src/test/java/org/apache/kafka/common/requests/ByteBufferChannel.java new file mode 100644 index 0000000..7370f50 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/requests/ByteBufferChannel.java @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.common.requests;
+
+import org.apache.kafka.common.network.TransferableChannel;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.nio.channels.FileChannel;
+
+public class ByteBufferChannel implements TransferableChannel {
+    private final ByteBuffer buf;
+    private boolean closed = false;
+
+    public ByteBufferChannel(long size) {
+        if (size > Integer.MAX_VALUE)
+            throw new IllegalArgumentException("size should not be greater than Integer.MAX_VALUE");
+        this.buf = ByteBuffer.allocate((int) size);
+    }
+
+    @Override
+    public long write(ByteBuffer[] srcs, int offset, int length) {
+        if ((offset < 0) || (length < 0) || (offset > srcs.length - length))
+            throw new IndexOutOfBoundsException();
+        int position = buf.position();
+        int count = offset + length;
+        for (int i = offset; i < count; i++) buf.put(srcs[i].duplicate());
+        return buf.position() - position;
+    }
+
+    @Override
+    public long write(ByteBuffer[] srcs) {
+        return write(srcs, 0, srcs.length);
+    }
+
+    @Override
+    public int write(ByteBuffer src) {
+        return (int) write(new ByteBuffer[]{src});
+    }
+
+    @Override
+    public boolean isOpen() {
+        return !closed;
+    }
+
+    @Override
+    public void close() {
+        buf.flip();
+        closed = true;
+    }
+
+    public ByteBuffer buffer() {
+        return buf;
+    }
+
+    @Override
+    public boolean hasPendingWrites() {
+        return false;
+    }
+
+    @Override
+    public long transferFrom(FileChannel fileChannel, long position, long count) throws IOException {
+        return fileChannel.transferTo(position, count, this);
+    }
+}
diff --git a/clients/src/test/java/org/apache/kafka/common/requests/ByteBufferChannelTest.java b/clients/src/test/java/org/apache/kafka/common/requests/ByteBufferChannelTest.java
new file mode 100644
index 0000000..58c913a
--- /dev/null
+++ b/clients/src/test/java/org/apache/kafka/common/requests/ByteBufferChannelTest.java
@@ -0,0 +1,87 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
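Illustrative sketch (not part of the generated patch): the ByteBufferChannel helper defined above gathers vectored writes into one in-memory buffer, close() flips it, and buffer() exposes it for assertions, which is how the test file that follows uses it. The class name ByteBufferChannelSketch and the "hello" payload are made-up example values, and the snippet assumes kafka-clients plus its test sources are on the classpath.

// Sketch: gather a vectored write into ByteBufferChannel and read it back.
import org.apache.kafka.common.requests.ByteBufferChannel;
import org.apache.kafka.common.utils.Utils;

import java.nio.ByteBuffer;

public class ByteBufferChannelSketch {
    public static void main(String[] args) {
        byte[] payload = Utils.utf8("hello");                        // String -> UTF-8 bytes
        ByteBufferChannel channel = new ByteBufferChannel(payload.length);
        channel.write(new ByteBuffer[]{ByteBuffer.wrap(payload)});   // vectored write into the backing buffer
        channel.close();                                             // flips the buffer for reading
        System.out.println(Utils.utf8(channel.buffer()));            // prints "hello"
    }
}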
+ */ +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.utils.Utils; +import org.junit.jupiter.api.Test; + +import java.nio.ByteBuffer; +import java.util.Arrays; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +public class ByteBufferChannelTest { + + @Test + public void testWriteBufferArrayWithNonZeroPosition() { + byte[] data = Utils.utf8("hello"); + ByteBuffer buffer = ByteBuffer.allocate(32); + buffer.position(10); + buffer.put(data); + + int limit = buffer.position(); + buffer.position(10); + buffer.limit(limit); + + ByteBufferChannel channel = new ByteBufferChannel(buffer.remaining()); + ByteBuffer[] buffers = new ByteBuffer[] {buffer}; + channel.write(buffers); + channel.close(); + ByteBuffer channelBuffer = channel.buffer(); + assertEquals(data.length, channelBuffer.remaining()); + assertEquals("hello", Utils.utf8(channelBuffer)); + } + + @Test + public void testWriteMultiplesByteBuffers() { + ByteBuffer[] buffers = new ByteBuffer[] { + ByteBuffer.wrap(Utils.utf8("hello")), + ByteBuffer.wrap(Utils.utf8("world")) + }; + int size = Arrays.stream(buffers).mapToInt(ByteBuffer::remaining).sum(); + ByteBuffer buf; + try (ByteBufferChannel channel = new ByteBufferChannel(size)) { + channel.write(buffers, 1, 1); + buf = channel.buffer(); + } + assertEquals("world", Utils.utf8(buf)); + + try (ByteBufferChannel channel = new ByteBufferChannel(size)) { + channel.write(buffers, 0, 1); + buf = channel.buffer(); + } + assertEquals("hello", Utils.utf8(buf)); + + try (ByteBufferChannel channel = new ByteBufferChannel(size)) { + channel.write(buffers, 0, 2); + buf = channel.buffer(); + } + assertEquals("helloworld", Utils.utf8(buf)); + } + + @Test + public void testInvalidArgumentsInWritsMultiplesByteBuffers() { + try (ByteBufferChannel channel = new ByteBufferChannel(10)) { + assertThrows(IndexOutOfBoundsException.class, () -> channel.write(new ByteBuffer[0], 1, 1)); + assertThrows(IndexOutOfBoundsException.class, () -> channel.write(new ByteBuffer[0], -1, 1)); + assertThrows(IndexOutOfBoundsException.class, () -> channel.write(new ByteBuffer[0], 0, -1)); + assertThrows(IndexOutOfBoundsException.class, () -> channel.write(new ByteBuffer[0], 0, 1)); + assertEquals(0, channel.write(new ByteBuffer[0], 0, 0)); + } + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/requests/ControlledShutdownRequestTest.java b/clients/src/test/java/org/apache/kafka/common/requests/ControlledShutdownRequestTest.java new file mode 100644 index 0000000..867be71 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/requests/ControlledShutdownRequestTest.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.errors.ClusterAuthorizationException; +import org.apache.kafka.common.errors.UnsupportedVersionException; +import org.apache.kafka.common.message.ControlledShutdownRequestData; +import org.apache.kafka.common.protocol.Errors; +import org.junit.jupiter.api.Test; + +import static org.apache.kafka.common.protocol.ApiKeys.CONTROLLED_SHUTDOWN; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +public class ControlledShutdownRequestTest { + + @Test + public void testUnsupportedVersion() { + ControlledShutdownRequest.Builder builder = new ControlledShutdownRequest.Builder( + new ControlledShutdownRequestData().setBrokerId(1), + (short) (CONTROLLED_SHUTDOWN.latestVersion() + 1)); + assertThrows(UnsupportedVersionException.class, builder::build); + } + + @Test + public void testGetErrorResponse() { + for (short version : CONTROLLED_SHUTDOWN.allVersions()) { + ControlledShutdownRequest.Builder builder = new ControlledShutdownRequest.Builder( + new ControlledShutdownRequestData().setBrokerId(1), version); + ControlledShutdownRequest request = builder.build(); + ControlledShutdownResponse response = request.getErrorResponse(0, + new ClusterAuthorizationException("Not authorized")); + assertEquals(Errors.CLUSTER_AUTHORIZATION_FAILED, response.error()); + } + } + +} diff --git a/clients/src/test/java/org/apache/kafka/common/requests/CreateAclsRequestTest.java b/clients/src/test/java/org/apache/kafka/common/requests/CreateAclsRequestTest.java new file mode 100644 index 0000000..97da6d7 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/requests/CreateAclsRequestTest.java @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
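Several request tests in this patch follow the pattern that ControlledShutdownRequestTest above shows most compactly: build a versioned request through its Builder, then let getErrorResponse() translate a Java exception into the protocol error for that version. A minimal sketch of that flow outside JUnit, where the class name ControlledShutdownSketch and broker id 1 are arbitrary example values:

// Sketch: build a versioned ControlledShutdownRequest and derive its error response.
import org.apache.kafka.common.errors.ClusterAuthorizationException;
import org.apache.kafka.common.message.ControlledShutdownRequestData;
import org.apache.kafka.common.protocol.ApiKeys;
import org.apache.kafka.common.requests.ControlledShutdownRequest;
import org.apache.kafka.common.requests.ControlledShutdownResponse;

public class ControlledShutdownSketch {
    public static void main(String[] args) {
        short version = ApiKeys.CONTROLLED_SHUTDOWN.latestVersion();
        ControlledShutdownRequest request = new ControlledShutdownRequest.Builder(
                new ControlledShutdownRequestData().setBrokerId(1), version).build(version);
        // Map an exception onto the wire-level error for this request version.
        ControlledShutdownResponse response =
                request.getErrorResponse(0, new ClusterAuthorizationException("Not authorized"));
        System.out.println(response.error()); // CLUSTER_AUTHORIZATION_FAILED
    }
}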
+ */ + +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.acl.AccessControlEntry; +import org.apache.kafka.common.acl.AclBinding; +import org.apache.kafka.common.acl.AclOperation; +import org.apache.kafka.common.acl.AclPermissionType; +import org.apache.kafka.common.errors.UnsupportedVersionException; +import org.apache.kafka.common.message.CreateAclsRequestData; +import org.apache.kafka.common.resource.PatternType; +import org.apache.kafka.common.resource.ResourcePattern; +import org.apache.kafka.common.resource.ResourceType; +import org.junit.jupiter.api.Test; + +import java.nio.ByteBuffer; +import java.util.Arrays; +import java.util.List; +import java.util.stream.Collectors; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +public class CreateAclsRequestTest { + private static final short V0 = 0; + private static final short V1 = 1; + + private static final AclBinding LITERAL_ACL1 = new AclBinding(new ResourcePattern(ResourceType.TOPIC, "foo", PatternType.LITERAL), + new AccessControlEntry("User:ANONYMOUS", "127.0.0.1", AclOperation.READ, AclPermissionType.DENY)); + + private static final AclBinding LITERAL_ACL2 = new AclBinding(new ResourcePattern(ResourceType.GROUP, "group", PatternType.LITERAL), + new AccessControlEntry("User:*", "127.0.0.1", AclOperation.WRITE, AclPermissionType.ALLOW)); + + private static final AclBinding PREFIXED_ACL1 = new AclBinding(new ResourcePattern(ResourceType.GROUP, "prefix", PatternType.PREFIXED), + new AccessControlEntry("User:*", "127.0.0.1", AclOperation.CREATE, AclPermissionType.ALLOW)); + + private static final AclBinding UNKNOWN_ACL1 = new AclBinding(new ResourcePattern(ResourceType.UNKNOWN, "unknown", PatternType.LITERAL), + new AccessControlEntry("User:*", "127.0.0.1", AclOperation.CREATE, AclPermissionType.ALLOW)); + + @Test + public void shouldThrowOnV0IfNotLiteral() { + assertThrows(UnsupportedVersionException.class, () -> new CreateAclsRequest(data(PREFIXED_ACL1), V0)); + } + + @Test + public void shouldThrowOnIfUnknown() { + assertThrows(IllegalArgumentException.class, () -> new CreateAclsRequest(data(UNKNOWN_ACL1), V0)); + } + + @Test + public void shouldRoundTripV0() { + final CreateAclsRequest original = new CreateAclsRequest(data(LITERAL_ACL1, LITERAL_ACL2), V0); + final ByteBuffer buffer = original.serialize(); + + final CreateAclsRequest result = CreateAclsRequest.parse(buffer, V0); + + assertRequestEquals(original, result); + } + + @Test + public void shouldRoundTripV1() { + final CreateAclsRequest original = new CreateAclsRequest(data(LITERAL_ACL1, PREFIXED_ACL1), V1); + final ByteBuffer buffer = original.serialize(); + + final CreateAclsRequest result = CreateAclsRequest.parse(buffer, V1); + + assertRequestEquals(original, result); + } + + private static void assertRequestEquals(final CreateAclsRequest original, final CreateAclsRequest actual) { + assertEquals(original.aclCreations().size(), actual.aclCreations().size(), "Number of Acls wrong"); + + for (int idx = 0; idx != original.aclCreations().size(); ++idx) { + final AclBinding originalBinding = CreateAclsRequest.aclBinding(original.aclCreations().get(idx)); + final AclBinding actualBinding = CreateAclsRequest.aclBinding(actual.aclCreations().get(idx)); + assertEquals(originalBinding, actualBinding); + } + } + + private static CreateAclsRequestData data(final AclBinding... 
acls) { + List aclCreations = Arrays.stream(acls) + .map(CreateAclsRequest::aclCreation) + .collect(Collectors.toList()); + return new CreateAclsRequestData().setCreations(aclCreations); + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/requests/DeleteAclsRequestTest.java b/clients/src/test/java/org/apache/kafka/common/requests/DeleteAclsRequestTest.java new file mode 100644 index 0000000..e88a5b9 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/requests/DeleteAclsRequestTest.java @@ -0,0 +1,117 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.acl.AccessControlEntryFilter; +import org.apache.kafka.common.acl.AclBindingFilter; +import org.apache.kafka.common.acl.AclOperation; +import org.apache.kafka.common.acl.AclPermissionType; +import org.apache.kafka.common.errors.UnsupportedVersionException; +import org.apache.kafka.common.message.DeleteAclsRequestData; +import org.apache.kafka.common.resource.PatternType; +import org.apache.kafka.common.resource.ResourcePatternFilter; +import org.apache.kafka.common.resource.ResourceType; +import org.junit.jupiter.api.Test; + +import java.nio.ByteBuffer; +import java.util.stream.Collectors; + +import static java.util.Arrays.asList; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +public class DeleteAclsRequestTest { + private static final short V0 = 0; + private static final short V1 = 1; + + private static final AclBindingFilter LITERAL_FILTER = new AclBindingFilter(new ResourcePatternFilter(ResourceType.TOPIC, "foo", PatternType.LITERAL), + new AccessControlEntryFilter("User:ANONYMOUS", "127.0.0.1", AclOperation.READ, AclPermissionType.DENY)); + + private static final AclBindingFilter PREFIXED_FILTER = new AclBindingFilter(new ResourcePatternFilter(ResourceType.GROUP, "prefix", PatternType.PREFIXED), + new AccessControlEntryFilter("User:*", "127.0.0.1", AclOperation.CREATE, AclPermissionType.ALLOW)); + + private static final AclBindingFilter ANY_FILTER = new AclBindingFilter(new ResourcePatternFilter(ResourceType.GROUP, "bar", PatternType.ANY), + new AccessControlEntryFilter("User:*", "127.0.0.1", AclOperation.CREATE, AclPermissionType.ALLOW)); + + private static final AclBindingFilter UNKNOWN_FILTER = new AclBindingFilter(new ResourcePatternFilter(ResourceType.UNKNOWN, "prefix", PatternType.PREFIXED), + new AccessControlEntryFilter("User:*", "127.0.0.1", AclOperation.CREATE, AclPermissionType.ALLOW)); + + @Test + public void shouldThrowOnV0IfPrefixed() { + assertThrows(UnsupportedVersionException.class, () -> new DeleteAclsRequest.Builder(requestData(PREFIXED_FILTER)).build(V0)); + } + + @Test + public void 
shouldThrowOnUnknownElements() { + assertThrows(IllegalArgumentException.class, () -> new DeleteAclsRequest.Builder(requestData(UNKNOWN_FILTER)).build(V1)); + } + + @Test + public void shouldRoundTripLiteralV0() { + final DeleteAclsRequest original = new DeleteAclsRequest.Builder(requestData(LITERAL_FILTER)).build(V0); + final ByteBuffer buffer = original.serialize(); + + final DeleteAclsRequest result = DeleteAclsRequest.parse(buffer, V0); + + assertRequestEquals(original, result); + } + + @Test + public void shouldRoundTripAnyV0AsLiteral() { + final DeleteAclsRequest original = new DeleteAclsRequest.Builder(requestData(ANY_FILTER)).build(V0); + final DeleteAclsRequest expected = new DeleteAclsRequest.Builder(requestData( + new AclBindingFilter(new ResourcePatternFilter( + ANY_FILTER.patternFilter().resourceType(), + ANY_FILTER.patternFilter().name(), + PatternType.LITERAL), + ANY_FILTER.entryFilter())) + ).build(V0); + + final DeleteAclsRequest result = DeleteAclsRequest.parse(original.serialize(), V0); + + assertRequestEquals(expected, result); + } + + @Test + public void shouldRoundTripV1() { + final DeleteAclsRequest original = new DeleteAclsRequest.Builder( + requestData(LITERAL_FILTER, PREFIXED_FILTER, ANY_FILTER) + ).build(V1); + final ByteBuffer buffer = original.serialize(); + + final DeleteAclsRequest result = DeleteAclsRequest.parse(buffer, V1); + + assertRequestEquals(original, result); + } + + private static void assertRequestEquals(final DeleteAclsRequest original, final DeleteAclsRequest actual) { + assertEquals(original.filters().size(), actual.filters().size(), "Number of filters wrong"); + + for (int idx = 0; idx != original.filters().size(); ++idx) { + final AclBindingFilter originalFilter = original.filters().get(idx); + final AclBindingFilter actualFilter = actual.filters().get(idx); + assertEquals(originalFilter, actualFilter); + } + } + + private static DeleteAclsRequestData requestData(AclBindingFilter... acls) { + return new DeleteAclsRequestData().setFilters(asList(acls).stream() + .map(DeleteAclsRequest::deleteAclsFilter) + .collect(Collectors.toList())); + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/requests/DeleteAclsResponseTest.java b/clients/src/test/java/org/apache/kafka/common/requests/DeleteAclsResponseTest.java new file mode 100644 index 0000000..3baf3af --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/requests/DeleteAclsResponseTest.java @@ -0,0 +1,130 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.acl.AclOperation; +import org.apache.kafka.common.acl.AclPermissionType; +import org.apache.kafka.common.errors.UnsupportedVersionException; +import org.apache.kafka.common.message.DeleteAclsResponseData; +import org.apache.kafka.common.message.DeleteAclsResponseData.DeleteAclsFilterResult; +import org.apache.kafka.common.message.DeleteAclsResponseData.DeleteAclsMatchingAcl; +import org.apache.kafka.common.resource.PatternType; +import org.apache.kafka.common.resource.ResourceType; +import org.junit.jupiter.api.Test; + +import java.nio.ByteBuffer; + +import static java.util.Arrays.asList; +import static java.util.Collections.singletonList; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +public class DeleteAclsResponseTest { + private static final short V0 = 0; + private static final short V1 = 1; + + private static final DeleteAclsMatchingAcl LITERAL_ACL1 = new DeleteAclsMatchingAcl() + .setResourceType(ResourceType.TOPIC.code()) + .setResourceName("foo") + .setPatternType(PatternType.LITERAL.code()) + .setPrincipal("User:ANONYMOUS") + .setHost("127.0.0.1") + .setOperation(AclOperation.READ.code()) + .setPermissionType(AclPermissionType.DENY.code()); + + private static final DeleteAclsMatchingAcl LITERAL_ACL2 = new DeleteAclsMatchingAcl() + .setResourceType(ResourceType.GROUP.code()) + .setResourceName("group") + .setPatternType(PatternType.LITERAL.code()) + .setPrincipal("User:*") + .setHost("127.0.0.1") + .setOperation(AclOperation.WRITE.code()) + .setPermissionType(AclPermissionType.ALLOW.code()); + + private static final DeleteAclsMatchingAcl PREFIXED_ACL1 = new DeleteAclsMatchingAcl() + .setResourceType(ResourceType.GROUP.code()) + .setResourceName("prefix") + .setPatternType(PatternType.PREFIXED.code()) + .setPrincipal("User:*") + .setHost("127.0.0.1") + .setOperation(AclOperation.CREATE.code()) + .setPermissionType(AclPermissionType.ALLOW.code()); + + private static final DeleteAclsMatchingAcl UNKNOWN_ACL = new DeleteAclsMatchingAcl() + .setResourceType(ResourceType.UNKNOWN.code()) + .setResourceName("group") + .setPatternType(PatternType.LITERAL.code()) + .setPrincipal("User:*") + .setHost("127.0.0.1") + .setOperation(AclOperation.WRITE.code()) + .setPermissionType(AclPermissionType.ALLOW.code()); + + private static final DeleteAclsFilterResult LITERAL_RESPONSE = new DeleteAclsFilterResult().setMatchingAcls(asList( + LITERAL_ACL1, LITERAL_ACL2)); + + private static final DeleteAclsFilterResult PREFIXED_RESPONSE = new DeleteAclsFilterResult().setMatchingAcls(asList( + LITERAL_ACL1, PREFIXED_ACL1)); + + private static final DeleteAclsFilterResult UNKNOWN_RESPONSE = new DeleteAclsFilterResult().setMatchingAcls(asList( + UNKNOWN_ACL)); + + @Test + public void shouldThrowOnV0IfNotLiteral() { + assertThrows(UnsupportedVersionException.class, () -> new DeleteAclsResponse( + new DeleteAclsResponseData() + .setThrottleTimeMs(10) + .setFilterResults(singletonList(PREFIXED_RESPONSE)), + V0)); + } + + @Test + public void shouldThrowOnIfUnknown() { + assertThrows(IllegalArgumentException.class, () -> new DeleteAclsResponse( + new DeleteAclsResponseData() + .setThrottleTimeMs(10) + .setFilterResults(singletonList(UNKNOWN_RESPONSE)), + V1)); + } + + @Test + public void shouldRoundTripV0() { + final DeleteAclsResponse original = new DeleteAclsResponse( + new DeleteAclsResponseData() + .setThrottleTimeMs(10) + 
.setFilterResults(singletonList(LITERAL_RESPONSE)), + V0); + final ByteBuffer buffer = original.serialize(V0); + + final DeleteAclsResponse result = DeleteAclsResponse.parse(buffer, V0); + assertEquals(original.filterResults(), result.filterResults()); + } + + @Test + public void shouldRoundTripV1() { + final DeleteAclsResponse original = new DeleteAclsResponse( + new DeleteAclsResponseData() + .setThrottleTimeMs(10) + .setFilterResults(asList(LITERAL_RESPONSE, PREFIXED_RESPONSE)), + V1); + final ByteBuffer buffer = original.serialize(V1); + + final DeleteAclsResponse result = DeleteAclsResponse.parse(buffer, V1); + assertEquals(original.filterResults(), result.filterResults()); + } + +} diff --git a/clients/src/test/java/org/apache/kafka/common/requests/DeleteGroupsResponseTest.java b/clients/src/test/java/org/apache/kafka/common/requests/DeleteGroupsResponseTest.java new file mode 100644 index 0000000..ff352f1 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/requests/DeleteGroupsResponseTest.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
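The round-trip tests above follow a pattern that recurs throughout these response tests: build the response from its generated *ResponseData object, serialize it for a specific API version, parse it back, and compare the fields. A standalone sketch of that flow, using example values (version 1, an empty filter-result list), none of which come from the patch itself:

// Sketch: version-specific serialize/parse round trip for DeleteAclsResponse.
import org.apache.kafka.common.message.DeleteAclsResponseData;
import org.apache.kafka.common.requests.DeleteAclsResponse;

import java.nio.ByteBuffer;
import java.util.Collections;

public class DeleteAclsRoundTripSketch {
    public static void main(String[] args) {
        short version = 1;
        DeleteAclsResponse original = new DeleteAclsResponse(
                new DeleteAclsResponseData()
                        .setThrottleTimeMs(10)
                        .setFilterResults(Collections.emptyList()),
                version);
        ByteBuffer buffer = original.serialize(version);                 // wire format for this version
        DeleteAclsResponse parsed = DeleteAclsResponse.parse(buffer, version);
        System.out.println(parsed.filterResults().equals(original.filterResults())); // true
    }
}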
+ */ +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.message.DeleteGroupsResponseData; +import org.apache.kafka.common.message.DeleteGroupsResponseData.DeletableGroupResult; +import org.apache.kafka.common.message.DeleteGroupsResponseData.DeletableGroupResultCollection; +import org.apache.kafka.common.protocol.Errors; +import org.junit.jupiter.api.Test; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +public class DeleteGroupsResponseTest { + + private static final String GROUP_ID_1 = "groupId1"; + private static final String GROUP_ID_2 = "groupId2"; + private static final int THROTTLE_TIME_MS = 10; + private static DeleteGroupsResponse deleteGroupsResponse; + + static { + deleteGroupsResponse = new DeleteGroupsResponse( + new DeleteGroupsResponseData() + .setResults( + new DeletableGroupResultCollection(Arrays.asList( + new DeletableGroupResult() + .setGroupId(GROUP_ID_1) + .setErrorCode(Errors.NONE.code()), + new DeletableGroupResult() + .setGroupId(GROUP_ID_2) + .setErrorCode(Errors.GROUP_AUTHORIZATION_FAILED.code())).iterator() + ) + ) + .setThrottleTimeMs(THROTTLE_TIME_MS)); + } + + @Test + public void testGetErrorWithExistingGroupIds() { + assertEquals(Errors.NONE, deleteGroupsResponse.get(GROUP_ID_1)); + assertEquals(Errors.GROUP_AUTHORIZATION_FAILED, deleteGroupsResponse.get(GROUP_ID_2)); + + Map expectedErrors = new HashMap<>(); + expectedErrors.put(GROUP_ID_1, Errors.NONE); + expectedErrors.put(GROUP_ID_2, Errors.GROUP_AUTHORIZATION_FAILED); + assertEquals(expectedErrors, deleteGroupsResponse.errors()); + + Map expectedErrorCounts = new HashMap<>(); + expectedErrorCounts.put(Errors.NONE, 1); + expectedErrorCounts.put(Errors.GROUP_AUTHORIZATION_FAILED, 1); + assertEquals(expectedErrorCounts, deleteGroupsResponse.errorCounts()); + } + + @Test + public void testGetErrorWithInvalidGroupId() { + assertThrows(IllegalArgumentException.class, () -> deleteGroupsResponse.get("invalid-group-id")); + } + + @Test + public void testGetThrottleTimeMs() { + assertEquals(THROTTLE_TIME_MS, deleteGroupsResponse.throttleTimeMs()); + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/requests/DeleteTopicsRequestTest.java b/clients/src/test/java/org/apache/kafka/common/requests/DeleteTopicsRequestTest.java new file mode 100644 index 0000000..897797a --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/requests/DeleteTopicsRequestTest.java @@ -0,0 +1,136 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.Uuid; +import org.apache.kafka.common.errors.UnsupportedVersionException; +import org.apache.kafka.common.message.DeleteTopicsRequestData; +import org.apache.kafka.common.message.DeleteTopicsRequestData.DeleteTopicState; +import org.junit.jupiter.api.Test; + +import java.util.Arrays; +import java.util.List; +import java.util.stream.Collectors; + +import static org.apache.kafka.common.protocol.ApiKeys.DELETE_TOPICS; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; + +public class DeleteTopicsRequestTest { + + @Test + public void testTopicNormalization() { + for (short version : DELETE_TOPICS.allVersions()) { + // Check topic names are in the correct place when using topicNames. + String topic1 = "topic1"; + String topic2 = "topic2"; + List topics = Arrays.asList(topic1, topic2); + DeleteTopicsRequest requestWithNames = new DeleteTopicsRequest.Builder( + new DeleteTopicsRequestData().setTopicNames(topics)).build(version); + DeleteTopicsRequest requestWithNamesSerialized = DeleteTopicsRequest.parse(requestWithNames.serialize(), version); + + assertEquals(topics, requestWithNames.topicNames()); + assertEquals(topics, requestWithNamesSerialized.topicNames()); + + if (version < 6) { + assertEquals(topics, requestWithNames.data().topicNames()); + assertEquals(topics, requestWithNamesSerialized.data().topicNames()); + } else { + // topics in TopicNames are moved to new topics field + assertEquals(topics, requestWithNames.data().topics().stream().map(DeleteTopicState::name).collect(Collectors.toList())); + assertEquals(topics, requestWithNamesSerialized.data().topics().stream().map(DeleteTopicState::name).collect(Collectors.toList())); + } + } + } + + @Test + public void testNewTopicsField() { + for (short version : DELETE_TOPICS.allVersions()) { + String topic1 = "topic1"; + String topic2 = "topic2"; + List topics = Arrays.asList(topic1, topic2); + DeleteTopicsRequest requestWithNames = new DeleteTopicsRequest.Builder( + new DeleteTopicsRequestData().setTopics(Arrays.asList( + new DeleteTopicsRequestData.DeleteTopicState().setName(topic1), + new DeleteTopicsRequestData.DeleteTopicState().setName(topic2)))).build(version); + // Ensure we only use new topics field on versions 6+. + if (version >= 6) { + DeleteTopicsRequest requestWithNamesSerialized = DeleteTopicsRequest.parse(requestWithNames.serialize(), version); + + assertEquals(topics, requestWithNames.topicNames()); + assertEquals(topics, requestWithNamesSerialized.topicNames()); + + } else { + // We should fail if version is less than 6. + assertThrows(UnsupportedVersionException.class, () -> requestWithNames.serialize()); + } + } + } + + @Test + public void testTopicIdsField() { + for (short version : DELETE_TOPICS.allVersions()) { + // Check topic IDs are handled correctly. We should only use this field on versions 6+. 
+ Uuid topicId1 = Uuid.randomUuid(); + Uuid topicId2 = Uuid.randomUuid(); + List topicIds = Arrays.asList(topicId1, topicId2); + DeleteTopicsRequest requestWithIds = new DeleteTopicsRequest.Builder( + new DeleteTopicsRequestData().setTopics(Arrays.asList( + new DeleteTopicsRequestData.DeleteTopicState().setTopicId(topicId1), + new DeleteTopicsRequestData.DeleteTopicState().setTopicId(topicId2)))).build(version); + + if (version >= 6) { + DeleteTopicsRequest requestWithIdsSerialized = DeleteTopicsRequest.parse(requestWithIds.serialize(), version); + + assertEquals(topicIds, requestWithIds.topicIds()); + assertEquals(topicIds, requestWithIdsSerialized.topicIds()); + + // All topic names should be replaced with null + requestWithIds.data().topics().forEach(topic -> assertNull(topic.name())); + requestWithIdsSerialized.data().topics().forEach(topic -> assertNull(topic.name())); + } else { + // We should fail if version is less than 6. + assertThrows(UnsupportedVersionException.class, () -> requestWithIds.serialize()); + } + } + } + + @Test + public void testDeleteTopicsRequestNumTopics() { + for (short version : DELETE_TOPICS.allVersions()) { + DeleteTopicsRequest request = new DeleteTopicsRequest.Builder(new DeleteTopicsRequestData() + .setTopicNames(Arrays.asList("topic1", "topic2")) + .setTimeoutMs(1000)).build(version); + DeleteTopicsRequest serializedRequest = DeleteTopicsRequest.parse(request.serialize(), version); + // createDeleteTopicsRequest sets 2 topics + assertEquals(2, request.numberOfTopics()); + assertEquals(2, serializedRequest.numberOfTopics()); + + // Test using IDs + if (version >= 6) { + DeleteTopicsRequest requestWithIds = new DeleteTopicsRequest.Builder( + new DeleteTopicsRequestData().setTopics(Arrays.asList( + new DeleteTopicsRequestData.DeleteTopicState().setTopicId(Uuid.randomUuid()), + new DeleteTopicsRequestData.DeleteTopicState().setTopicId(Uuid.randomUuid())))).build(version); + DeleteTopicsRequest serializedRequestWithIds = DeleteTopicsRequest.parse(requestWithIds.serialize(), version); + assertEquals(2, requestWithIds.numberOfTopics()); + assertEquals(2, serializedRequestWithIds.numberOfTopics()); + } + } + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/requests/DescribeAclsRequestTest.java b/clients/src/test/java/org/apache/kafka/common/requests/DescribeAclsRequestTest.java new file mode 100644 index 0000000..00ce57a --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/requests/DescribeAclsRequestTest.java @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.acl.AccessControlEntryFilter; +import org.apache.kafka.common.acl.AclBindingFilter; +import org.apache.kafka.common.acl.AclOperation; +import org.apache.kafka.common.acl.AclPermissionType; +import org.apache.kafka.common.errors.UnsupportedVersionException; +import org.apache.kafka.common.resource.PatternType; +import org.apache.kafka.common.resource.ResourcePatternFilter; +import org.apache.kafka.common.resource.ResourceType; +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +public class DescribeAclsRequestTest { + private static final short V0 = 0; + private static final short V1 = 1; + + private static final AclBindingFilter LITERAL_FILTER = new AclBindingFilter(new ResourcePatternFilter(ResourceType.TOPIC, "foo", PatternType.LITERAL), + new AccessControlEntryFilter("User:ANONYMOUS", "127.0.0.1", AclOperation.READ, AclPermissionType.DENY)); + + private static final AclBindingFilter PREFIXED_FILTER = new AclBindingFilter(new ResourcePatternFilter(ResourceType.GROUP, "prefix", PatternType.PREFIXED), + new AccessControlEntryFilter("User:*", "127.0.0.1", AclOperation.CREATE, AclPermissionType.ALLOW)); + + private static final AclBindingFilter ANY_FILTER = new AclBindingFilter(new ResourcePatternFilter(ResourceType.GROUP, "bar", PatternType.ANY), + new AccessControlEntryFilter("User:*", "127.0.0.1", AclOperation.CREATE, AclPermissionType.ALLOW)); + + private static final AclBindingFilter UNKNOWN_FILTER = new AclBindingFilter(new ResourcePatternFilter(ResourceType.UNKNOWN, "foo", PatternType.LITERAL), + new AccessControlEntryFilter("User:ANONYMOUS", "127.0.0.1", AclOperation.READ, AclPermissionType.DENY)); + + @Test + public void shouldThrowOnV0IfPrefixed() { + assertThrows(UnsupportedVersionException.class, () -> new DescribeAclsRequest.Builder(PREFIXED_FILTER).build(V0)); + } + + @Test + public void shouldThrowIfUnknown() { + assertThrows(IllegalArgumentException.class, () -> new DescribeAclsRequest.Builder(UNKNOWN_FILTER).build(V0)); + } + + @Test + public void shouldRoundTripLiteralV0() { + final DescribeAclsRequest original = new DescribeAclsRequest.Builder(LITERAL_FILTER).build(V0); + final DescribeAclsRequest result = DescribeAclsRequest.parse(original.serialize(), V0); + + assertRequestEquals(original, result); + } + + @Test + public void shouldRoundTripAnyV0AsLiteral() { + final DescribeAclsRequest original = new DescribeAclsRequest.Builder(ANY_FILTER).build(V0); + final DescribeAclsRequest expected = new DescribeAclsRequest.Builder( + new AclBindingFilter(new ResourcePatternFilter( + ANY_FILTER.patternFilter().resourceType(), + ANY_FILTER.patternFilter().name(), + PatternType.LITERAL), + ANY_FILTER.entryFilter())).build(V0); + + final DescribeAclsRequest result = DescribeAclsRequest.parse(original.serialize(), V0); + assertRequestEquals(expected, result); + } + + @Test + public void shouldRoundTripLiteralV1() { + final DescribeAclsRequest original = new DescribeAclsRequest.Builder(LITERAL_FILTER).build(V1); + final DescribeAclsRequest result = DescribeAclsRequest.parse(original.serialize(), V1); + assertRequestEquals(original, result); + } + + @Test + public void shouldRoundTripPrefixedV1() { + final DescribeAclsRequest original = new DescribeAclsRequest.Builder(PREFIXED_FILTER).build(V1); + final DescribeAclsRequest result = DescribeAclsRequest.parse(original.serialize(), V1); + 
assertRequestEquals(original, result); + } + + @Test + public void shouldRoundTripAnyV1() { + final DescribeAclsRequest original = new DescribeAclsRequest.Builder(ANY_FILTER).build(V1); + final DescribeAclsRequest result = DescribeAclsRequest.parse(original.serialize(), V1); + assertRequestEquals(original, result); + } + + private static void assertRequestEquals(final DescribeAclsRequest original, final DescribeAclsRequest actual) { + final AclBindingFilter originalFilter = original.filter(); + final AclBindingFilter acttualFilter = actual.filter(); + assertEquals(originalFilter, acttualFilter); + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/requests/DescribeAclsResponseTest.java b/clients/src/test/java/org/apache/kafka/common/requests/DescribeAclsResponseTest.java new file mode 100644 index 0000000..d036a33 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/requests/DescribeAclsResponseTest.java @@ -0,0 +1,164 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.acl.AccessControlEntry; +import org.apache.kafka.common.acl.AclBinding; +import org.apache.kafka.common.acl.AclOperation; +import org.apache.kafka.common.acl.AclPermissionType; +import org.apache.kafka.common.errors.UnsupportedVersionException; +import org.apache.kafka.common.message.DescribeAclsResponseData; +import org.apache.kafka.common.message.DescribeAclsResponseData.AclDescription; +import org.apache.kafka.common.message.DescribeAclsResponseData.DescribeAclsResource; +import org.apache.kafka.common.protocol.Errors; +import org.apache.kafka.common.resource.PatternType; +import org.apache.kafka.common.resource.ResourcePattern; +import org.apache.kafka.common.resource.ResourceType; +import org.junit.jupiter.api.Test; + +import java.nio.ByteBuffer; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Set; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +public class DescribeAclsResponseTest { + private static final short V0 = 0; + private static final short V1 = 1; + + private static final AclDescription ALLOW_CREATE_ACL = buildAclDescription( + "127.0.0.1", + "User:ANONYMOUS", + AclOperation.CREATE, + AclPermissionType.ALLOW); + + private static final AclDescription DENY_READ_ACL = buildAclDescription( + "127.0.0.1", + "User:ANONYMOUS", + AclOperation.READ, + AclPermissionType.DENY); + + private static final DescribeAclsResource UNKNOWN_ACL = buildResource( + "foo", + ResourceType.UNKNOWN, + PatternType.LITERAL, + Collections.singletonList(DENY_READ_ACL)); + + private static final DescribeAclsResource PREFIXED_ACL1 = 
buildResource( + "prefix", + ResourceType.GROUP, + PatternType.PREFIXED, + Collections.singletonList(ALLOW_CREATE_ACL)); + + private static final DescribeAclsResource LITERAL_ACL1 = buildResource( + "foo", + ResourceType.TOPIC, + PatternType.LITERAL, + Collections.singletonList(ALLOW_CREATE_ACL)); + + private static final DescribeAclsResource LITERAL_ACL2 = buildResource( + "group", + ResourceType.GROUP, + PatternType.LITERAL, + Collections.singletonList(DENY_READ_ACL)); + + @Test + public void shouldThrowOnV0IfNotLiteral() { + assertThrows(UnsupportedVersionException.class, + () -> buildResponse(10, Errors.NONE, Collections.singletonList(PREFIXED_ACL1)).serialize(V0)); + } + + @Test + public void shouldThrowIfUnknown() { + assertThrows(IllegalArgumentException.class, + () -> buildResponse(10, Errors.NONE, Collections.singletonList(UNKNOWN_ACL)).serialize(V0)); + } + + @Test + public void shouldRoundTripV0() { + List resources = Arrays.asList(LITERAL_ACL1, LITERAL_ACL2); + final DescribeAclsResponse original = buildResponse(10, Errors.NONE, resources); + final ByteBuffer buffer = original.serialize(V0); + + final DescribeAclsResponse result = DescribeAclsResponse.parse(buffer, V0); + assertResponseEquals(original, result); + + final DescribeAclsResponse result2 = buildResponse(10, Errors.NONE, DescribeAclsResponse.aclsResources( + DescribeAclsResponse.aclBindings(resources))); + assertResponseEquals(original, result2); + } + + @Test + public void shouldRoundTripV1() { + List resources = Arrays.asList(LITERAL_ACL1, PREFIXED_ACL1); + final DescribeAclsResponse original = buildResponse(100, Errors.NONE, resources); + final ByteBuffer buffer = original.serialize(V1); + + final DescribeAclsResponse result = DescribeAclsResponse.parse(buffer, V1); + assertResponseEquals(original, result); + + final DescribeAclsResponse result2 = buildResponse(100, Errors.NONE, DescribeAclsResponse.aclsResources( + DescribeAclsResponse.aclBindings(resources))); + assertResponseEquals(original, result2); + } + + @Test + public void testAclBindings() { + final AclBinding original = new AclBinding(new ResourcePattern(ResourceType.TOPIC, "foo", PatternType.LITERAL), + new AccessControlEntry("User:ANONYMOUS", "127.0.0.1", AclOperation.CREATE, AclPermissionType.ALLOW)); + + final List result = DescribeAclsResponse.aclBindings(Collections.singletonList(LITERAL_ACL1)); + assertEquals(1, result.size()); + assertEquals(original, result.get(0)); + } + + private static void assertResponseEquals(final DescribeAclsResponse original, final DescribeAclsResponse actual) { + final Set originalBindings = new HashSet<>(original.acls()); + final Set actualBindings = new HashSet<>(actual.acls()); + + assertEquals(originalBindings, actualBindings); + } + + private static DescribeAclsResponse buildResponse(int throttleTimeMs, Errors error, List resources) { + return new DescribeAclsResponse(new DescribeAclsResponseData() + .setThrottleTimeMs(throttleTimeMs) + .setErrorCode(error.code()) + .setErrorMessage(error.message()) + .setResources(resources)); + } + + private static DescribeAclsResource buildResource(String name, ResourceType type, PatternType patternType, List acls) { + return new DescribeAclsResource() + .setResourceName(name) + .setResourceType(type.code()) + .setPatternType(patternType.code()) + .setAcls(acls); + } + + private static AclDescription buildAclDescription(String host, String principal, AclOperation operation, AclPermissionType permission) { + return new AclDescription() + .setHost(host) + .setPrincipal(principal) 
+ .setOperation(operation.code()) + .setPermissionType(permission.code()); + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/requests/EndTxnRequestTest.java b/clients/src/test/java/org/apache/kafka/common/requests/EndTxnRequestTest.java new file mode 100644 index 0000000..f14bf66 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/requests/EndTxnRequestTest.java @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.message.EndTxnRequestData; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.Errors; +import org.junit.jupiter.api.Test; + +import java.util.Collections; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class EndTxnRequestTest { + + @Test + public void testConstructor() { + short producerEpoch = 0; + int producerId = 1; + String transactionId = "txn_id"; + int throttleTimeMs = 10; + + EndTxnRequest.Builder builder = new EndTxnRequest.Builder( + new EndTxnRequestData() + .setCommitted(true) + .setProducerEpoch(producerEpoch) + .setProducerId(producerId) + .setTransactionalId(transactionId)); + + for (short version : ApiKeys.END_TXN.allVersions()) { + EndTxnRequest request = builder.build(version); + + EndTxnResponse response = request.getErrorResponse(throttleTimeMs, Errors.NOT_COORDINATOR.exception()); + + assertEquals(Collections.singletonMap(Errors.NOT_COORDINATOR, 1), response.errorCounts()); + + assertEquals(TransactionResult.COMMIT, request.result()); + + assertEquals(throttleTimeMs, response.throttleTimeMs()); + } + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/requests/EndTxnResponseTest.java b/clients/src/test/java/org/apache/kafka/common/requests/EndTxnResponseTest.java new file mode 100644 index 0000000..39c4bc0 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/requests/EndTxnResponseTest.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.message.EndTxnResponseData; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.Errors; +import org.junit.jupiter.api.Test; + +import java.util.Collections; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class EndTxnResponseTest { + + @Test + public void testConstructor() { + int throttleTimeMs = 10; + + EndTxnResponseData data = new EndTxnResponseData() + .setErrorCode(Errors.NOT_COORDINATOR.code()) + .setThrottleTimeMs(throttleTimeMs); + + Map expectedErrorCounts = Collections.singletonMap(Errors.NOT_COORDINATOR, 1); + + for (short version : ApiKeys.END_TXN.allVersions()) { + EndTxnResponse response = new EndTxnResponse(data); + assertEquals(expectedErrorCounts, response.errorCounts()); + assertEquals(throttleTimeMs, response.throttleTimeMs()); + assertEquals(version >= 1, response.shouldClientThrottle(version)); + + response = EndTxnResponse.parse(response.serialize(version), version); + assertEquals(expectedErrorCounts, response.errorCounts()); + assertEquals(throttleTimeMs, response.throttleTimeMs()); + assertEquals(version >= 1, response.shouldClientThrottle(version)); + } + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/requests/EnvelopeRequestTest.java b/clients/src/test/java/org/apache/kafka/common/requests/EnvelopeRequestTest.java new file mode 100644 index 0000000..36b8618 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/requests/EnvelopeRequestTest.java @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.message.EnvelopeRequestData; +import org.apache.kafka.common.network.Send; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.security.auth.KafkaPrincipal; +import org.apache.kafka.common.security.authenticator.DefaultKafkaPrincipalBuilder; +import org.apache.kafka.test.TestUtils; +import org.junit.jupiter.api.Test; + +import java.io.IOException; +import java.net.InetAddress; +import java.nio.ByteBuffer; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class EnvelopeRequestTest { + + @Test + public void testGetPrincipal() { + KafkaPrincipal kafkaPrincipal = new KafkaPrincipal(KafkaPrincipal.USER_TYPE, "principal", true); + DefaultKafkaPrincipalBuilder kafkaPrincipalBuilder = new DefaultKafkaPrincipalBuilder(null, null); + + EnvelopeRequest.Builder requestBuilder = new EnvelopeRequest.Builder(ByteBuffer.allocate(0), + kafkaPrincipalBuilder.serialize(kafkaPrincipal), "client-address".getBytes()); + EnvelopeRequest request = requestBuilder.build(EnvelopeRequestData.HIGHEST_SUPPORTED_VERSION); + assertEquals(kafkaPrincipal, kafkaPrincipalBuilder.deserialize(request.requestPrincipal())); + } + + @Test + public void testToSend() throws IOException { + for (short version : ApiKeys.ENVELOPE.allVersions()) { + ByteBuffer requestData = ByteBuffer.wrap("foobar".getBytes()); + RequestHeader header = new RequestHeader(ApiKeys.ENVELOPE, version, "clientId", 15); + EnvelopeRequest request = new EnvelopeRequest.Builder( + requestData, + "principal".getBytes(), + InetAddress.getLocalHost().getAddress() + ).build(version); + + Send send = request.toSend(header); + ByteBuffer buffer = TestUtils.toBuffer(send); + assertEquals(send.size() - 4, buffer.getInt()); + assertEquals(header, RequestHeader.parse(buffer)); + + EnvelopeRequestData parsedRequestData = new EnvelopeRequestData(); + parsedRequestData.read(new ByteBufferAccessor(buffer), version); + assertEquals(request.data(), parsedRequestData); + } + } + +} diff --git a/clients/src/test/java/org/apache/kafka/common/requests/EnvelopeResponseTest.java b/clients/src/test/java/org/apache/kafka/common/requests/EnvelopeResponseTest.java new file mode 100644 index 0000000..e0fa2fd --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/requests/EnvelopeResponseTest.java @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.message.EnvelopeResponseData; +import org.apache.kafka.common.network.Send; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.protocol.Errors; +import org.apache.kafka.test.TestUtils; +import org.junit.jupiter.api.Test; + +import java.nio.ByteBuffer; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +class EnvelopeResponseTest { + + @Test + public void testToSend() { + for (short version : ApiKeys.ENVELOPE.allVersions()) { + ByteBuffer responseData = ByteBuffer.wrap("foobar".getBytes()); + EnvelopeResponse response = new EnvelopeResponse(responseData, Errors.NONE); + short headerVersion = ApiKeys.ENVELOPE.responseHeaderVersion(version); + ResponseHeader header = new ResponseHeader(15, headerVersion); + + Send send = response.toSend(header, version); + ByteBuffer buffer = TestUtils.toBuffer(send); + assertEquals(send.size() - 4, buffer.getInt()); + assertEquals(header, ResponseHeader.parse(buffer, headerVersion)); + + EnvelopeResponseData parsedResponseData = new EnvelopeResponseData(); + parsedResponseData.read(new ByteBufferAccessor(buffer), version); + assertEquals(response.data(), parsedResponseData); + } + } + +} diff --git a/clients/src/test/java/org/apache/kafka/common/requests/FetchRequestTest.java b/clients/src/test/java/org/apache/kafka/common/requests/FetchRequestTest.java new file mode 100644 index 0000000..a567d43 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/requests/FetchRequestTest.java @@ -0,0 +1,210 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.TopicIdPartition; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.Uuid; +import org.apache.kafka.common.protocol.ApiKeys; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; + +import java.util.Collections; +import java.util.Map; +import java.util.LinkedHashMap; +import java.util.LinkedList; +import java.util.List; +import java.util.Optional; +import java.util.stream.Stream; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotEquals; + +public class FetchRequestTest { + + private static Stream fetchVersions() { + return ApiKeys.FETCH.allVersions().stream().map(version -> Arguments.of(version)); + } + + @ParameterizedTest + @MethodSource("fetchVersions") + public void testToReplaceWithDifferentVersions(short version) { + boolean fetchRequestUsesTopicIds = version >= 13; + Uuid topicId = Uuid.randomUuid(); + TopicIdPartition tp = new TopicIdPartition(topicId, 0, "topic"); + + Map partitionData = Collections.singletonMap(tp.topicPartition(), + new FetchRequest.PartitionData(topicId, 0, 0, 0, Optional.empty())); + List toReplace = Collections.singletonList(tp); + + FetchRequest fetchRequest = FetchRequest.Builder + .forReplica(version, 0, 1, 1, partitionData) + .removed(Collections.emptyList()) + .replaced(toReplace) + .metadata(FetchMetadata.newIncremental(123)).build(version); + + // If version < 13, we should not see any partitions in forgottenTopics. This is because we can not + // distinguish different topic IDs on versions earlier than 13. + assertEquals(fetchRequestUsesTopicIds, fetchRequest.data().forgottenTopicsData().size() > 0); + fetchRequest.data().forgottenTopicsData().forEach(forgottenTopic -> { + // Since we didn't serialize, we should see the topic name and ID regardless of the version. + assertEquals(tp.topic(), forgottenTopic.topic()); + assertEquals(topicId, forgottenTopic.topicId()); + }); + + assertEquals(1, fetchRequest.data().topics().size()); + fetchRequest.data().topics().forEach(topic -> { + // Since we didn't serialize, we should see the topic name and ID regardless of the version. + assertEquals(tp.topic(), topic.topic()); + assertEquals(topicId, topic.topicId()); + }); + } + + @ParameterizedTest + @MethodSource("fetchVersions") + public void testFetchData(short version) { + TopicPartition topicPartition0 = new TopicPartition("topic", 0); + TopicPartition topicPartition1 = new TopicPartition("unknownIdTopic", 0); + Uuid topicId0 = Uuid.randomUuid(); + Uuid topicId1 = Uuid.randomUuid(); + + // Only include topic IDs for the first topic partition. + Map topicNames = Collections.singletonMap(topicId0, topicPartition0.topic()); + List topicIdPartitions = new LinkedList<>(); + topicIdPartitions.add(new TopicIdPartition(topicId0, topicPartition0)); + topicIdPartitions.add(new TopicIdPartition(topicId1, topicPartition1)); + + // Include one topic with topic IDs in the topic names map and one without. 
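+        // topicId1 is deliberately missing from topicNames, so versions that rely solely on topic IDs (13+) cannot resolve its name.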
+ Map partitionData = new LinkedHashMap<>(); + partitionData.put(topicPartition0, new FetchRequest.PartitionData(topicId0, 0, 0, 0, Optional.empty())); + partitionData.put(topicPartition1, new FetchRequest.PartitionData(topicId1, 0, 0, 0, Optional.empty())); + boolean fetchRequestUsesTopicIds = version >= 13; + + FetchRequest fetchRequest = FetchRequest.parse(FetchRequest.Builder + .forReplica(version, 0, 1, 1, partitionData) + .removed(Collections.emptyList()) + .replaced(Collections.emptyList()) + .metadata(FetchMetadata.newIncremental(123)).build(version).serialize(), version); + + // For versions < 13, we will be provided a topic name and a zero UUID in FetchRequestData. + // Versions 13+ will contain a valid topic ID but an empty topic name. + List expectedData = new LinkedList<>(); + topicIdPartitions.forEach(tidp -> { + String expectedName = fetchRequestUsesTopicIds ? "" : tidp.topic(); + Uuid expectedTopicId = fetchRequestUsesTopicIds ? tidp.topicId() : Uuid.ZERO_UUID; + expectedData.add(new TopicIdPartition(expectedTopicId, tidp.partition(), expectedName)); + }); + + // Build the list of TopicIdPartitions based on the FetchRequestData that was serialized and parsed. + List convertedFetchData = new LinkedList<>(); + fetchRequest.data().topics().forEach(topic -> + topic.partitions().forEach(partition -> + convertedFetchData.add(new TopicIdPartition(topic.topicId(), partition.partition(), topic.topic())) + ) + ); + // The TopicIdPartitions built from the request data should match what we expect. + assertEquals(expectedData, convertedFetchData); + + // For fetch request version 13+ we expect topic names to be filled in for all topics in the topicNames map. + // Otherwise, the topic name should be null. + // For earlier request versions, we expect topic names and zero Uuids. + Map expectedFetchData = new LinkedHashMap<>(); + // Build the expected map based on fetchRequestUsesTopicIds. + expectedData.forEach(tidp -> { + String expectedName = fetchRequestUsesTopicIds ? topicNames.get(tidp.topicId()) : tidp.topic(); + TopicIdPartition tpKey = new TopicIdPartition(tidp.topicId(), new TopicPartition(expectedName, tidp.partition())); + // logStartOffset was not a valid field in versions 4 and earlier. + int logStartOffset = version > 4 ? 0 : -1; + expectedFetchData.put(tpKey, new FetchRequest.PartitionData(tidp.topicId(), 0, logStartOffset, 0, Optional.empty())); + }); + assertEquals(expectedFetchData, fetchRequest.fetchData(topicNames)); + } + + @ParameterizedTest + @MethodSource("fetchVersions") + public void testForgottenTopics(short version) { + // Forgotten topics are not allowed prior to version 7 + if (version >= 7) { + TopicPartition topicPartition0 = new TopicPartition("topic", 0); + TopicPartition topicPartition1 = new TopicPartition("unknownIdTopic", 0); + Uuid topicId0 = Uuid.randomUuid(); + Uuid topicId1 = Uuid.randomUuid(); + // Only include topic IDs for the first topic partition. + Map topicNames = Collections.singletonMap(topicId0, topicPartition0.topic()); + + // Include one topic with topic IDs in the topic names map and one without. 
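+            // As in testFetchData, only topicId0 resolves back to a name; topicId1 exercises the unresolved-ID path on versions 13+.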
+ List toForgetTopics = new LinkedList<>(); + toForgetTopics.add(new TopicIdPartition(topicId0, topicPartition0)); + toForgetTopics.add(new TopicIdPartition(topicId1, topicPartition1)); + + boolean fetchRequestUsesTopicIds = version >= 13; + + FetchRequest fetchRequest = FetchRequest.parse(FetchRequest.Builder + .forReplica(version, 0, 1, 1, Collections.emptyMap()) + .removed(toForgetTopics) + .replaced(Collections.emptyList()) + .metadata(FetchMetadata.newIncremental(123)).build(version).serialize(), version); + + // For versions < 13, we will be provided a topic name and a zero Uuid in FetchRequestData. + // Versions 13+ will contain a valid topic ID but an empty topic name. + List expectedForgottenTopicData = new LinkedList<>(); + toForgetTopics.forEach(tidp -> { + String expectedName = fetchRequestUsesTopicIds ? "" : tidp.topic(); + Uuid expectedTopicId = fetchRequestUsesTopicIds ? tidp.topicId() : Uuid.ZERO_UUID; + expectedForgottenTopicData.add(new TopicIdPartition(expectedTopicId, tidp.partition(), expectedName)); + }); + + // Build the list of TopicIdPartitions based on the FetchRequestData that was serialized and parsed. + List convertedForgottenTopicData = new LinkedList<>(); + fetchRequest.data().forgottenTopicsData().forEach(forgottenTopic -> + forgottenTopic.partitions().forEach(partition -> + convertedForgottenTopicData.add(new TopicIdPartition(forgottenTopic.topicId(), partition, forgottenTopic.topic())) + ) + ); + // The TopicIdPartitions built from the request data should match what we expect. + assertEquals(expectedForgottenTopicData, convertedForgottenTopicData); + + // Get the forgottenTopics from the request data. + List forgottenTopics = fetchRequest.forgottenTopics(topicNames); + + // For fetch request version 13+ we expect topic names to be filled in for all topics in the topicNames map. + // Otherwise, the topic name should be null. + // For earlier request versions, we expect topic names and zero Uuids. + // Build the list of expected TopicIdPartitions. These are different from the earlier expected topicIdPartitions + // as empty strings are converted to nulls. + assertEquals(expectedForgottenTopicData.size(), forgottenTopics.size()); + List expectedForgottenTopics = new LinkedList<>(); + expectedForgottenTopicData.forEach(tidp -> { + String expectedName = fetchRequestUsesTopicIds ? topicNames.get(tidp.topicId()) : tidp.topic(); + expectedForgottenTopics.add(new TopicIdPartition(tidp.topicId(), new TopicPartition(expectedName, tidp.partition()))); + }); + assertEquals(expectedForgottenTopics, forgottenTopics); + } + } + + @Test + public void testPartitionDataEquals() { + assertEquals(new FetchRequest.PartitionData(Uuid.ZERO_UUID, 300, 0L, 300, Optional.of(300)), + new FetchRequest.PartitionData(Uuid.ZERO_UUID, 300, 0L, 300, Optional.of(300))); + + assertNotEquals(new FetchRequest.PartitionData(Uuid.randomUuid(), 300, 0L, 300, Optional.of(300)), + new FetchRequest.PartitionData(Uuid.randomUuid(), 300, 0L, 300, Optional.of(300))); + } + +} diff --git a/clients/src/test/java/org/apache/kafka/common/requests/FindCoordinatorRequestTest.java b/clients/src/test/java/org/apache/kafka/common/requests/FindCoordinatorRequestTest.java new file mode 100644 index 0000000..9393358 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/requests/FindCoordinatorRequestTest.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.errors.InvalidRequestException; +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertThrows; + +class FindCoordinatorRequestTest { + + @Test + public void getInvalidCoordinatorTypeId() { + assertThrows(InvalidRequestException.class, + () -> FindCoordinatorRequest.CoordinatorType.forId((byte) 10)); + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/requests/HeartbeatRequestTest.java b/clients/src/test/java/org/apache/kafka/common/requests/HeartbeatRequestTest.java new file mode 100644 index 0000000..4478c4e --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/requests/HeartbeatRequestTest.java @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.errors.UnsupportedVersionException; +import org.apache.kafka.common.message.HeartbeatRequestData; +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertThrows; + +public class HeartbeatRequestTest { + + @Test + public void testRequestVersionCompatibilityFailBuild() { + assertThrows(UnsupportedVersionException.class, () -> new HeartbeatRequest.Builder( + new HeartbeatRequestData() + .setGroupId("groupId") + .setMemberId("consumerId") + .setGroupInstanceId("groupInstanceId") + ).build((short) 2)); + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/requests/JoinGroupRequestTest.java b/clients/src/test/java/org/apache/kafka/common/requests/JoinGroupRequestTest.java new file mode 100644 index 0000000..ebf6ef2 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/requests/JoinGroupRequestTest.java @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.errors.InvalidConfigurationException; +import org.apache.kafka.common.errors.UnsupportedVersionException; +import org.apache.kafka.common.message.JoinGroupRequestData; +import org.apache.kafka.common.protocol.MessageUtil; +import org.apache.kafka.test.TestUtils; +import org.junit.jupiter.api.Test; + +import java.nio.ByteBuffer; +import java.util.Arrays; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.fail; + +public class JoinGroupRequestTest { + + @Test + public void shouldAcceptValidGroupInstanceIds() { + String maxLengthString = TestUtils.randomString(249); + String[] validGroupInstanceIds = {"valid", "INSTANCE", "gRoUp", "ar6", "VaL1d", "_0-9_.", "...", maxLengthString}; + + for (String instanceId : validGroupInstanceIds) { + JoinGroupRequest.validateGroupInstanceId(instanceId); + } + } + + @Test + public void shouldThrowOnInvalidGroupInstanceIds() { + char[] longString = new char[250]; + Arrays.fill(longString, 'a'); + String[] invalidGroupInstanceIds = {"", "foo bar", "..", "foo:bar", "foo=bar", ".", new String(longString)}; + + for (String instanceId : invalidGroupInstanceIds) { + try { + JoinGroupRequest.validateGroupInstanceId(instanceId); + fail("No exception was thrown for invalid instance id: " + instanceId); + } catch (InvalidConfigurationException e) { + // Good + } + } + } + @Test + public void testRequestVersionCompatibilityFailBuild() { + assertThrows(UnsupportedVersionException.class, () -> new JoinGroupRequest.Builder( + new JoinGroupRequestData() + .setGroupId("groupId") + .setMemberId("consumerId") + .setGroupInstanceId("groupInstanceId") + .setProtocolType("consumer") + ).build((short) 4)); + } + + @Test + public void testRebalanceTimeoutDefaultsToSessionTimeoutV0() { + int sessionTimeoutMs = 30000; + short version = 0; + + ByteBuffer buffer = MessageUtil.toByteBuffer(new JoinGroupRequestData() + .setGroupId("groupId") + .setMemberId("consumerId") + .setProtocolType("consumer") + .setSessionTimeoutMs(sessionTimeoutMs), version); + + JoinGroupRequest request = JoinGroupRequest.parse(buffer, version); + assertEquals(sessionTimeoutMs, request.data().sessionTimeoutMs()); + assertEquals(sessionTimeoutMs, request.data().rebalanceTimeoutMs()); + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/requests/LeaderAndIsrRequestTest.java b/clients/src/test/java/org/apache/kafka/common/requests/LeaderAndIsrRequestTest.java new file mode 100644 index 0000000..de9914c --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/requests/LeaderAndIsrRequestTest.java @@ -0,0 +1,225 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.Node; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.Uuid; +import org.apache.kafka.common.errors.ClusterAuthorizationException; +import org.apache.kafka.common.errors.UnsupportedVersionException; +import org.apache.kafka.common.message.LeaderAndIsrRequestData; +import org.apache.kafka.common.message.LeaderAndIsrRequestData.LeaderAndIsrLiveLeader; +import org.apache.kafka.common.message.LeaderAndIsrRequestData.LeaderAndIsrPartitionState; +import org.apache.kafka.common.message.LeaderAndIsrResponseData.LeaderAndIsrPartitionError; +import org.apache.kafka.common.message.LeaderAndIsrResponseData.LeaderAndIsrTopicError; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.protocol.Errors; +import org.apache.kafka.test.TestUtils; +import org.junit.jupiter.api.Test; + +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; +import java.util.stream.StreamSupport; + +import static java.util.Arrays.asList; +import static java.util.Collections.emptyList; +import static org.apache.kafka.common.protocol.ApiKeys.LEADER_AND_ISR; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class LeaderAndIsrRequestTest { + + @Test + public void testUnsupportedVersion() { + LeaderAndIsrRequest.Builder builder = new LeaderAndIsrRequest.Builder( + (short) (LEADER_AND_ISR.latestVersion() + 1), 0, 0, 0, + Collections.emptyList(), Collections.emptyMap(), Collections.emptySet()); + assertThrows(UnsupportedVersionException.class, builder::build); + } + + @Test + public void testGetErrorResponse() { + Uuid topicId = Uuid.randomUuid(); + String topicName = "topic"; + int partition = 0; + for (short version : LEADER_AND_ISR.allVersions()) { + LeaderAndIsrRequest request = new LeaderAndIsrRequest.Builder(version, 0, 0, 0, + Collections.singletonList(new LeaderAndIsrPartitionState() + .setTopicName(topicName) + .setPartitionIndex(partition)), + Collections.singletonMap(topicName, topicId), + Collections.emptySet() + ).build(version); + + LeaderAndIsrResponse response = request.getErrorResponse(0, + new ClusterAuthorizationException("Not authorized")); + + assertEquals(Errors.CLUSTER_AUTHORIZATION_FAILED, response.error()); + + if (version < 5) { + assertEquals( + Collections.singletonList(new LeaderAndIsrPartitionError() + .setTopicName(topicName) + .setPartitionIndex(partition) + .setErrorCode(Errors.CLUSTER_AUTHORIZATION_FAILED.code())), + 
                    response.data().partitionErrors());
+                assertEquals(0, response.data().topics().size());
+            } else {
+                LeaderAndIsrTopicError topicState = response.topics().find(topicId);
+                assertEquals(topicId, topicState.topicId());
+                assertEquals(
+                    Collections.singletonList(new LeaderAndIsrPartitionError()
+                        .setPartitionIndex(partition)
+                        .setErrorCode(Errors.CLUSTER_AUTHORIZATION_FAILED.code())),
+                    topicState.partitionErrors());
+                assertEquals(0, response.data().partitionErrors().size());
+            }
+        }
+    }
+
+    /**
+     * Verifies that the logic we have in LeaderAndIsrRequest to present a unified interface across the various
+     * versions works correctly. For example, `LeaderAndIsrPartitionState.topicName` is not serialized/deserialized in
+     * recent versions, but we set it manually so that we can always present the ungrouped partition states
+     * independently of the version.
+     */
+    @Test
+    public void testVersionLogic() {
+        for (short version : LEADER_AND_ISR.allVersions()) {
+            List<LeaderAndIsrPartitionState> partitionStates = asList(
+                new LeaderAndIsrPartitionState()
+                    .setTopicName("topic0")
+                    .setPartitionIndex(0)
+                    .setControllerEpoch(2)
+                    .setLeader(0)
+                    .setLeaderEpoch(10)
+                    .setIsr(asList(0, 1))
+                    .setZkVersion(10)
+                    .setReplicas(asList(0, 1, 2))
+                    .setAddingReplicas(asList(3))
+                    .setRemovingReplicas(asList(2)),
+                new LeaderAndIsrPartitionState()
+                    .setTopicName("topic0")
+                    .setPartitionIndex(1)
+                    .setControllerEpoch(2)
+                    .setLeader(1)
+                    .setLeaderEpoch(11)
+                    .setIsr(asList(1, 2, 3))
+                    .setZkVersion(11)
+                    .setReplicas(asList(1, 2, 3))
+                    .setAddingReplicas(emptyList())
+                    .setRemovingReplicas(emptyList()),
+                new LeaderAndIsrPartitionState()
+                    .setTopicName("topic1")
+                    .setPartitionIndex(0)
+                    .setControllerEpoch(2)
+                    .setLeader(2)
+                    .setLeaderEpoch(11)
+                    .setIsr(asList(2, 3, 4))
+                    .setZkVersion(11)
+                    .setReplicas(asList(2, 3, 4))
+                    .setAddingReplicas(emptyList())
+                    .setRemovingReplicas(emptyList())
+            );
+
+            List<Node> liveNodes = asList(
+                new Node(0, "host0", 9090),
+                new Node(1, "host1", 9091)
+            );
+
+            Map<String, Uuid> topicIds = new HashMap<>();
+            topicIds.put("topic0", Uuid.randomUuid());
+            topicIds.put("topic1", Uuid.randomUuid());
+
+            LeaderAndIsrRequest request = new LeaderAndIsrRequest.Builder(version, 1, 2, 3, partitionStates,
+                topicIds, liveNodes).build();
+
+            List<LeaderAndIsrLiveLeader> liveLeaders = liveNodes.stream().map(n -> new LeaderAndIsrLiveLeader()
+                .setBrokerId(n.id())
+                .setHostName(n.host())
+                .setPort(n.port())).collect(Collectors.toList());
+            assertEquals(new HashSet<>(partitionStates), iterableToSet(request.partitionStates()));
+            assertEquals(liveLeaders, request.liveLeaders());
+            assertEquals(1, request.controllerId());
+            assertEquals(2, request.controllerEpoch());
+            assertEquals(3, request.brokerEpoch());
+
+            ByteBuffer byteBuffer = request.serialize();
+            LeaderAndIsrRequest deserializedRequest = new LeaderAndIsrRequest(new LeaderAndIsrRequestData(
+                new ByteBufferAccessor(byteBuffer), version), version);
+
+            // Adding/removing replicas is only supported from version 3, so the deserialized request won't have
+            // them for earlier versions.
+            if (version < 3) {
+                partitionStates.get(0)
+                    .setAddingReplicas(emptyList())
+                    .setRemovingReplicas(emptyList());
+            }
+
+            // Prior to version 2, there were no TopicStates, so a map of Topic Ids from a list of
+            // TopicStates is an empty map.
+            if (version < 2) {
+                topicIds = new HashMap<>();
+            }
+
+            // In versions 2-4 there are TopicStates, but no topicIds, so deserialized requests will have
+            // Zero Uuids in place.
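+            // (From version 5 onwards the randomly generated topic IDs round-trip unchanged, so no adjustment is needed.)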
+ if (version > 1 && version < 5) { + topicIds.put("topic0", Uuid.ZERO_UUID); + topicIds.put("topic1", Uuid.ZERO_UUID); + } + + assertEquals(new HashSet<>(partitionStates), iterableToSet(deserializedRequest.partitionStates())); + assertEquals(topicIds, deserializedRequest.topicIds()); + assertEquals(liveLeaders, deserializedRequest.liveLeaders()); + assertEquals(1, request.controllerId()); + assertEquals(2, request.controllerEpoch()); + assertEquals(3, request.brokerEpoch()); + } + } + + @Test + public void testTopicPartitionGroupingSizeReduction() { + Set tps = TestUtils.generateRandomTopicPartitions(10, 10); + List partitionStates = new ArrayList<>(); + Map topicIds = new HashMap<>(); + for (TopicPartition tp : tps) { + partitionStates.add(new LeaderAndIsrPartitionState() + .setTopicName(tp.topic()) + .setPartitionIndex(tp.partition())); + topicIds.put(tp.topic(), Uuid.randomUuid()); + } + LeaderAndIsrRequest.Builder builder = new LeaderAndIsrRequest.Builder((short) 2, 0, 0, 0, + partitionStates, topicIds, Collections.emptySet()); + + LeaderAndIsrRequest v2 = builder.build((short) 2); + LeaderAndIsrRequest v1 = builder.build((short) 1); + assertTrue(v2.sizeInBytes() < v1.sizeInBytes(), "Expected v2 < v1: v2=" + v2.sizeInBytes() + ", v1=" + v1.sizeInBytes()); + } + + private Set iterableToSet(Iterable iterable) { + return StreamSupport.stream(iterable.spliterator(), false).collect(Collectors.toSet()); + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/requests/LeaderAndIsrResponseTest.java b/clients/src/test/java/org/apache/kafka/common/requests/LeaderAndIsrResponseTest.java new file mode 100644 index 0000000..9f46304 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/requests/LeaderAndIsrResponseTest.java @@ -0,0 +1,175 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.Uuid; +import org.apache.kafka.common.message.LeaderAndIsrRequestData.LeaderAndIsrPartitionState; +import org.apache.kafka.common.message.LeaderAndIsrResponseData; +import org.apache.kafka.common.message.LeaderAndIsrResponseData.LeaderAndIsrTopicError; +import org.apache.kafka.common.message.LeaderAndIsrResponseData.LeaderAndIsrPartitionError; +import org.apache.kafka.common.message.LeaderAndIsrResponseData.LeaderAndIsrTopicErrorCollection; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.Errors; +import org.junit.jupiter.api.Test; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; + +import static java.util.Arrays.asList; +import static org.apache.kafka.common.protocol.ApiKeys.LEADER_AND_ISR; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class LeaderAndIsrResponseTest { + + @Test + public void testErrorCountsFromGetErrorResponse() { + List partitionStates = new ArrayList<>(); + partitionStates.add(new LeaderAndIsrPartitionState() + .setTopicName("foo") + .setPartitionIndex(0) + .setControllerEpoch(15) + .setLeader(1) + .setLeaderEpoch(10) + .setIsr(Collections.singletonList(10)) + .setZkVersion(20) + .setReplicas(Collections.singletonList(10)) + .setIsNew(false)); + partitionStates.add(new LeaderAndIsrPartitionState() + .setTopicName("foo") + .setPartitionIndex(1) + .setControllerEpoch(15) + .setLeader(1) + .setLeaderEpoch(10) + .setIsr(Collections.singletonList(10)) + .setZkVersion(20) + .setReplicas(Collections.singletonList(10)) + .setIsNew(false)); + Map topicIds = Collections.singletonMap("foo", Uuid.randomUuid()); + + LeaderAndIsrRequest request = new LeaderAndIsrRequest.Builder(ApiKeys.LEADER_AND_ISR.latestVersion(), + 15, 20, 0, partitionStates, topicIds, Collections.emptySet()).build(); + LeaderAndIsrResponse response = request.getErrorResponse(0, Errors.CLUSTER_AUTHORIZATION_FAILED.exception()); + assertEquals(Collections.singletonMap(Errors.CLUSTER_AUTHORIZATION_FAILED, 3), response.errorCounts()); + } + + @Test + public void testErrorCountsWithTopLevelError() { + for (short version : LEADER_AND_ISR.allVersions()) { + LeaderAndIsrResponse response; + if (version < 5) { + List partitions = createPartitions("foo", + asList(Errors.NONE, Errors.NOT_LEADER_OR_FOLLOWER)); + response = new LeaderAndIsrResponse(new LeaderAndIsrResponseData() + .setErrorCode(Errors.UNKNOWN_SERVER_ERROR.code()) + .setPartitionErrors(partitions), version); + } else { + Uuid id = Uuid.randomUuid(); + LeaderAndIsrTopicErrorCollection topics = createTopic(id, asList(Errors.NONE, Errors.NOT_LEADER_OR_FOLLOWER)); + response = new LeaderAndIsrResponse(new LeaderAndIsrResponseData() + .setErrorCode(Errors.UNKNOWN_SERVER_ERROR.code()) + .setTopics(topics), version); + } + assertEquals(Collections.singletonMap(Errors.UNKNOWN_SERVER_ERROR, 3), response.errorCounts()); + } + } + + @Test + public void testErrorCountsNoTopLevelError() { + for (short version : LEADER_AND_ISR.allVersions()) { + LeaderAndIsrResponse response; + if (version < 5) { + List partitions = createPartitions("foo", + asList(Errors.NONE, Errors.CLUSTER_AUTHORIZATION_FAILED)); + response = new LeaderAndIsrResponse(new LeaderAndIsrResponseData() + .setErrorCode(Errors.NONE.code()) + .setPartitionErrors(partitions), version); + } else { + Uuid id = Uuid.randomUuid(); + 
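                // Version 5+ responses group partition errors under a topic ID instead of returning a flat partition list.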
LeaderAndIsrTopicErrorCollection topics = createTopic(id, asList(Errors.NONE, Errors.CLUSTER_AUTHORIZATION_FAILED)); + response = new LeaderAndIsrResponse(new LeaderAndIsrResponseData() + .setErrorCode(Errors.NONE.code()) + .setTopics(topics), version); + } + Map errorCounts = response.errorCounts(); + assertEquals(2, errorCounts.size()); + assertEquals(2, errorCounts.get(Errors.NONE).intValue()); + assertEquals(1, errorCounts.get(Errors.CLUSTER_AUTHORIZATION_FAILED).intValue()); + } + } + + @Test + public void testToString() { + for (short version : LEADER_AND_ISR.allVersions()) { + LeaderAndIsrResponse response; + if (version < 5) { + List partitions = createPartitions("foo", + asList(Errors.NONE, Errors.CLUSTER_AUTHORIZATION_FAILED)); + response = new LeaderAndIsrResponse(new LeaderAndIsrResponseData() + .setErrorCode(Errors.NONE.code()) + .setPartitionErrors(partitions), version); + String responseStr = response.toString(); + assertTrue(responseStr.contains(LeaderAndIsrResponse.class.getSimpleName())); + assertTrue(responseStr.contains(partitions.toString())); + assertTrue(responseStr.contains("errorCode=" + Errors.NONE.code())); + + } else { + Uuid id = Uuid.randomUuid(); + LeaderAndIsrTopicErrorCollection topics = createTopic(id, asList(Errors.NONE, Errors.CLUSTER_AUTHORIZATION_FAILED)); + response = new LeaderAndIsrResponse(new LeaderAndIsrResponseData() + .setErrorCode(Errors.NONE.code()) + .setTopics(topics), version); + String responseStr = response.toString(); + assertTrue(responseStr.contains(LeaderAndIsrResponse.class.getSimpleName())); + assertTrue(responseStr.contains(topics.toString())); + assertTrue(responseStr.contains(id.toString())); + assertTrue(responseStr.contains("errorCode=" + Errors.NONE.code())); + } + } + } + + private List createPartitions(String topicName, List errors) { + List partitions = new ArrayList<>(); + int partitionIndex = 0; + for (Errors error : errors) { + partitions.add(new LeaderAndIsrPartitionError() + .setTopicName(topicName) + .setPartitionIndex(partitionIndex++) + .setErrorCode(error.code())); + } + return partitions; + } + + private LeaderAndIsrTopicErrorCollection createTopic(Uuid id, List errors) { + LeaderAndIsrTopicErrorCollection topics = new LeaderAndIsrTopicErrorCollection(); + LeaderAndIsrTopicError topic = new LeaderAndIsrTopicError(); + topic.setTopicId(id); + List partitions = new ArrayList<>(); + int partitionIndex = 0; + for (Errors error : errors) { + partitions.add(new LeaderAndIsrPartitionError() + .setPartitionIndex(partitionIndex++) + .setErrorCode(error.code())); + } + topic.setPartitionErrors(partitions); + topics.add(topic); + return topics; + } + +} diff --git a/clients/src/test/java/org/apache/kafka/common/requests/LeaveGroupRequestTest.java b/clients/src/test/java/org/apache/kafka/common/requests/LeaveGroupRequestTest.java new file mode 100644 index 0000000..1694ef5 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/requests/LeaveGroupRequestTest.java @@ -0,0 +1,130 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.errors.UnsupportedVersionException; +import org.apache.kafka.common.message.LeaveGroupRequestData; +import org.apache.kafka.common.message.LeaveGroupRequestData.MemberIdentity; +import org.apache.kafka.common.message.LeaveGroupResponseData; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.Errors; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; + +public class LeaveGroupRequestTest { + + private final String groupId = "group_id"; + private final String memberIdOne = "member_1"; + private final String instanceIdOne = "instance_1"; + private final String memberIdTwo = "member_2"; + private final String instanceIdTwo = "instance_2"; + + private final int throttleTimeMs = 10; + + private LeaveGroupRequest.Builder builder; + private List members; + + @BeforeEach + public void setUp() { + members = Arrays.asList(new MemberIdentity() + .setMemberId(memberIdOne) + .setGroupInstanceId(instanceIdOne), + new MemberIdentity() + .setMemberId(memberIdTwo) + .setGroupInstanceId(instanceIdTwo)); + builder = new LeaveGroupRequest.Builder( + groupId, + members + ); + } + + @Test + public void testMultiLeaveConstructor() { + final LeaveGroupRequestData expectedData = new LeaveGroupRequestData() + .setGroupId(groupId) + .setMembers(members); + + for (short version : ApiKeys.LEAVE_GROUP.allVersions()) { + try { + LeaveGroupRequest request = builder.build(version); + if (version <= 2) { + fail("Older version " + version + + " request data should not be created due to non-single members"); + } + assertEquals(expectedData, request.data()); + assertEquals(members, request.members()); + + LeaveGroupResponse expectedResponse = new LeaveGroupResponse( + Collections.emptyList(), + Errors.COORDINATOR_LOAD_IN_PROGRESS, + throttleTimeMs, + version + ); + + assertEquals(expectedResponse, request.getErrorResponse(throttleTimeMs, + Errors.COORDINATOR_LOAD_IN_PROGRESS.exception())); + } catch (UnsupportedVersionException e) { + assertTrue(e.getMessage().contains("leave group request only supports single member instance")); + } + } + + } + + @Test + public void testSingleLeaveConstructor() { + final LeaveGroupRequestData expectedData = new LeaveGroupRequestData() + .setGroupId(groupId) + .setMemberId(memberIdOne); + List singleMember = Collections.singletonList( + new MemberIdentity() + .setMemberId(memberIdOne)); + + builder = new LeaveGroupRequest.Builder(groupId, singleMember); + + for (short version = 0; version <= 2; version++) { + LeaveGroupRequest request = builder.build(version); + assertEquals(expectedData, request.data()); + assertEquals(singleMember, request.members()); + + int expectedThrottleTime = version >= 1 ? 
throttleTimeMs + : AbstractResponse.DEFAULT_THROTTLE_TIME; + LeaveGroupResponse expectedResponse = new LeaveGroupResponse( + new LeaveGroupResponseData() + .setErrorCode(Errors.NOT_CONTROLLER.code()) + .setThrottleTimeMs(expectedThrottleTime) + ); + + assertEquals(expectedResponse, request.getErrorResponse(throttleTimeMs, + Errors.NOT_CONTROLLER.exception())); + } + } + + @Test + public void testBuildEmptyMembers() { + assertThrows(IllegalArgumentException.class, + () -> new LeaveGroupRequest.Builder(groupId, Collections.emptyList())); + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/requests/LeaveGroupResponseTest.java b/clients/src/test/java/org/apache/kafka/common/requests/LeaveGroupResponseTest.java new file mode 100644 index 0000000..d513218 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/requests/LeaveGroupResponseTest.java @@ -0,0 +1,168 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.message.LeaveGroupResponseData; +import org.apache.kafka.common.message.LeaveGroupResponseData.MemberResponse; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.Errors; +import org.apache.kafka.common.protocol.MessageUtil; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.nio.ByteBuffer; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.apache.kafka.common.requests.AbstractResponse.DEFAULT_THROTTLE_TIME; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class LeaveGroupResponseTest { + + private final String memberIdOne = "member_1"; + private final String instanceIdOne = "instance_1"; + private final String memberIdTwo = "member_2"; + private final String instanceIdTwo = "instance_2"; + + private final int throttleTimeMs = 10; + + private List memberResponses; + + @BeforeEach + public void setUp() { + memberResponses = Arrays.asList(new MemberResponse() + .setMemberId(memberIdOne) + .setGroupInstanceId(instanceIdOne) + .setErrorCode(Errors.UNKNOWN_MEMBER_ID.code()), + new MemberResponse() + .setMemberId(memberIdTwo) + .setGroupInstanceId(instanceIdTwo) + .setErrorCode(Errors.FENCED_INSTANCE_ID.code()) + ); + } + + @Test + public void testConstructorWithMemberResponses() { + Map expectedErrorCounts = new HashMap<>(); + expectedErrorCounts.put(Errors.NONE, 1); // top level + expectedErrorCounts.put(Errors.UNKNOWN_MEMBER_ID, 1); + expectedErrorCounts.put(Errors.FENCED_INSTANCE_ID, 1); + + for (short version : 
ApiKeys.LEAVE_GROUP.allVersions()) { + LeaveGroupResponse leaveGroupResponse = new LeaveGroupResponse(memberResponses, + Errors.NONE, + throttleTimeMs, + version); + + if (version >= 3) { + assertEquals(expectedErrorCounts, leaveGroupResponse.errorCounts()); + assertEquals(memberResponses, leaveGroupResponse.memberResponses()); + } else { + assertEquals(Collections.singletonMap(Errors.UNKNOWN_MEMBER_ID, 1), + leaveGroupResponse.errorCounts()); + assertEquals(Collections.emptyList(), leaveGroupResponse.memberResponses()); + } + + if (version >= 1) { + assertEquals(throttleTimeMs, leaveGroupResponse.throttleTimeMs()); + } else { + assertEquals(DEFAULT_THROTTLE_TIME, leaveGroupResponse.throttleTimeMs()); + } + + assertEquals(Errors.UNKNOWN_MEMBER_ID, leaveGroupResponse.error()); + } + } + + @Test + public void testShouldThrottle() { + LeaveGroupResponse response = new LeaveGroupResponse(new LeaveGroupResponseData()); + for (short version : ApiKeys.LEAVE_GROUP.allVersions()) { + if (version >= 2) { + assertTrue(response.shouldClientThrottle(version)); + } else { + assertFalse(response.shouldClientThrottle(version)); + } + } + } + + @Test + public void testEqualityWithSerialization() { + LeaveGroupResponseData responseData = new LeaveGroupResponseData() + .setErrorCode(Errors.NONE.code()) + .setThrottleTimeMs(throttleTimeMs); + for (short version : ApiKeys.LEAVE_GROUP.allVersions()) { + LeaveGroupResponse primaryResponse = LeaveGroupResponse.parse( + MessageUtil.toByteBuffer(responseData, version), version); + LeaveGroupResponse secondaryResponse = LeaveGroupResponse.parse( + MessageUtil.toByteBuffer(responseData, version), version); + + assertEquals(primaryResponse, primaryResponse); + assertEquals(primaryResponse, secondaryResponse); + assertEquals(primaryResponse.hashCode(), secondaryResponse.hashCode()); + } + } + + @Test + public void testParse() { + Map expectedErrorCounts = Collections.singletonMap(Errors.NOT_COORDINATOR, 1); + + LeaveGroupResponseData data = new LeaveGroupResponseData() + .setErrorCode(Errors.NOT_COORDINATOR.code()) + .setThrottleTimeMs(throttleTimeMs); + + for (short version : ApiKeys.LEAVE_GROUP.allVersions()) { + ByteBuffer buffer = MessageUtil.toByteBuffer(data, version); + LeaveGroupResponse leaveGroupResponse = LeaveGroupResponse.parse(buffer, version); + assertEquals(expectedErrorCounts, leaveGroupResponse.errorCounts()); + + if (version >= 1) { + assertEquals(throttleTimeMs, leaveGroupResponse.throttleTimeMs()); + } else { + assertEquals(DEFAULT_THROTTLE_TIME, leaveGroupResponse.throttleTimeMs()); + } + + assertEquals(Errors.NOT_COORDINATOR, leaveGroupResponse.error()); + } + } + + @Test + public void testEqualityWithMemberResponses() { + for (short version : ApiKeys.LEAVE_GROUP.allVersions()) { + List localResponses = version > 2 ? memberResponses : memberResponses.subList(0, 1); + LeaveGroupResponse primaryResponse = new LeaveGroupResponse(localResponses, + Errors.NONE, + throttleTimeMs, + version); + + // The order of members should not alter result data. 
+ Collections.reverse(localResponses); + LeaveGroupResponse reversedResponse = new LeaveGroupResponse(localResponses, + Errors.NONE, + throttleTimeMs, + version); + + assertEquals(primaryResponse, primaryResponse); + assertEquals(primaryResponse, reversedResponse); + assertEquals(primaryResponse.hashCode(), reversedResponse.hashCode()); + } + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/requests/ListOffsetsRequestTest.java b/clients/src/test/java/org/apache/kafka/common/requests/ListOffsetsRequestTest.java new file mode 100644 index 0000000..83c4b10 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/requests/ListOffsetsRequestTest.java @@ -0,0 +1,148 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.requests; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.apache.kafka.common.IsolationLevel; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.message.ListOffsetsRequestData; +import org.apache.kafka.common.message.ListOffsetsRequestData.ListOffsetsPartition; +import org.apache.kafka.common.message.ListOffsetsRequestData.ListOffsetsTopic; +import org.apache.kafka.common.message.ListOffsetsResponseData; +import org.apache.kafka.common.message.ListOffsetsResponseData.ListOffsetsPartitionResponse; +import org.apache.kafka.common.message.ListOffsetsResponseData.ListOffsetsTopicResponse; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.Errors; +import org.apache.kafka.common.protocol.MessageUtil; +import org.junit.jupiter.api.Test; + +public class ListOffsetsRequestTest { + + @Test + public void testDuplicatePartitions() { + List topics = Collections.singletonList( + new ListOffsetsTopic() + .setName("topic") + .setPartitions(Arrays.asList( + new ListOffsetsPartition() + .setPartitionIndex(0), + new ListOffsetsPartition() + .setPartitionIndex(0)))); + ListOffsetsRequestData data = new ListOffsetsRequestData() + .setTopics(topics) + .setReplicaId(-1); + ListOffsetsRequest request = ListOffsetsRequest.parse(MessageUtil.toByteBuffer(data, (short) 0), (short) 0); + assertEquals(Collections.singleton(new TopicPartition("topic", 0)), request.duplicatePartitions()); + } + + @Test + public void testGetErrorResponse() { + for (short version = 1; version <= ApiKeys.LIST_OFFSETS.latestVersion(); version++) { + List topics = Arrays.asList( + new ListOffsetsTopic() + .setName("topic") + .setPartitions(Collections.singletonList( + new ListOffsetsPartition() + .setPartitionIndex(0)))); + ListOffsetsRequest 
request = ListOffsetsRequest.Builder + .forConsumer(true, IsolationLevel.READ_COMMITTED, false) + .setTargetTimes(topics) + .build(version); + ListOffsetsResponse response = (ListOffsetsResponse) request.getErrorResponse(0, Errors.NOT_LEADER_OR_FOLLOWER.exception()); + + List v = Collections.singletonList( + new ListOffsetsTopicResponse() + .setName("topic") + .setPartitions(Collections.singletonList( + new ListOffsetsPartitionResponse() + .setErrorCode(Errors.NOT_LEADER_OR_FOLLOWER.code()) + .setLeaderEpoch(ListOffsetsResponse.UNKNOWN_EPOCH) + .setOffset(ListOffsetsResponse.UNKNOWN_OFFSET) + .setPartitionIndex(0) + .setTimestamp(ListOffsetsResponse.UNKNOWN_TIMESTAMP)))); + ListOffsetsResponseData data = new ListOffsetsResponseData() + .setThrottleTimeMs(0) + .setTopics(v); + ListOffsetsResponse expectedResponse = new ListOffsetsResponse(data); + assertEquals(expectedResponse.data().topics(), response.data().topics()); + assertEquals(expectedResponse.throttleTimeMs(), response.throttleTimeMs()); + } + } + + @Test + public void testGetErrorResponseV0() { + List topics = Arrays.asList( + new ListOffsetsTopic() + .setName("topic") + .setPartitions(Collections.singletonList( + new ListOffsetsPartition() + .setPartitionIndex(0)))); + ListOffsetsRequest request = ListOffsetsRequest.Builder + .forConsumer(true, IsolationLevel.READ_UNCOMMITTED, false) + .setTargetTimes(topics) + .build((short) 0); + ListOffsetsResponse response = (ListOffsetsResponse) request.getErrorResponse(0, Errors.NOT_LEADER_OR_FOLLOWER.exception()); + + List v = Collections.singletonList( + new ListOffsetsTopicResponse() + .setName("topic") + .setPartitions(Collections.singletonList( + new ListOffsetsPartitionResponse() + .setErrorCode(Errors.NOT_LEADER_OR_FOLLOWER.code()) + .setOldStyleOffsets(Collections.emptyList()) + .setPartitionIndex(0)))); + ListOffsetsResponseData data = new ListOffsetsResponseData() + .setThrottleTimeMs(0) + .setTopics(v); + ListOffsetsResponse expectedResponse = new ListOffsetsResponse(data); + assertEquals(expectedResponse.data().topics(), response.data().topics()); + assertEquals(expectedResponse.throttleTimeMs(), response.throttleTimeMs()); + } + + @Test + public void testToListOffsetsTopics() { + ListOffsetsPartition lop0 = new ListOffsetsPartition() + .setPartitionIndex(0) + .setCurrentLeaderEpoch(1) + .setMaxNumOffsets(2) + .setTimestamp(123L); + ListOffsetsPartition lop1 = new ListOffsetsPartition() + .setPartitionIndex(1) + .setCurrentLeaderEpoch(3) + .setMaxNumOffsets(4) + .setTimestamp(567L); + Map timestampsToSearch = new HashMap<>(); + timestampsToSearch.put(new TopicPartition("topic", 0), lop0); + timestampsToSearch.put(new TopicPartition("topic", 1), lop1); + List listOffsetTopics = ListOffsetsRequest.toListOffsetsTopics(timestampsToSearch); + assertEquals(1, listOffsetTopics.size()); + ListOffsetsTopic topic = listOffsetTopics.get(0); + assertEquals("topic", topic.name()); + assertEquals(2, topic.partitions().size()); + assertTrue(topic.partitions().contains(lop0)); + assertTrue(topic.partitions().contains(lop1)); + } + +} diff --git a/clients/src/test/java/org/apache/kafka/common/requests/MetadataRequestTest.java b/clients/src/test/java/org/apache/kafka/common/requests/MetadataRequestTest.java new file mode 100644 index 0000000..74c217d --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/requests/MetadataRequestTest.java @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.Uuid; +import org.apache.kafka.common.errors.UnsupportedVersionException; +import org.apache.kafka.common.message.MetadataRequestData; +import org.apache.kafka.common.protocol.ApiKeys; +import org.junit.jupiter.api.Test; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.assertThrows; + +public class MetadataRequestTest { + + @Test + public void testEmptyMeansAllTopicsV0() { + MetadataRequestData data = new MetadataRequestData(); + MetadataRequest parsedRequest = new MetadataRequest(data, (short) 0); + assertTrue(parsedRequest.isAllTopics()); + assertNull(parsedRequest.topics()); + } + + @Test + public void testEmptyMeansEmptyForVersionsAboveV0() { + for (int i = 1; i < MetadataRequestData.SCHEMAS.length; i++) { + MetadataRequestData data = new MetadataRequestData(); + data.setAllowAutoTopicCreation(true); + MetadataRequest parsedRequest = new MetadataRequest(data, (short) i); + assertFalse(parsedRequest.isAllTopics()); + assertEquals(Collections.emptyList(), parsedRequest.topics()); + } + } + + @Test + public void testMetadataRequestVersion() { + MetadataRequest.Builder builder = new MetadataRequest.Builder(Collections.singletonList("topic"), false); + assertEquals(ApiKeys.METADATA.oldestVersion(), builder.oldestAllowedVersion()); + assertEquals(ApiKeys.METADATA.latestVersion(), builder.latestAllowedVersion()); + + short version = 5; + MetadataRequest.Builder builder2 = new MetadataRequest.Builder(Collections.singletonList("topic"), false, version); + assertEquals(version, builder2.oldestAllowedVersion()); + assertEquals(version, builder2.latestAllowedVersion()); + + short minVersion = 1; + short maxVersion = 6; + MetadataRequest.Builder builder3 = new MetadataRequest.Builder(Collections.singletonList("topic"), false, minVersion, maxVersion); + assertEquals(minVersion, builder3.oldestAllowedVersion()); + assertEquals(maxVersion, builder3.latestAllowedVersion()); + } + + @Test + public void testTopicIdAndNullTopicNameRequests() { + // Construct invalid MetadataRequestTopics. We will build each one separately and ensure the error is thrown. 
+ List topics = Arrays.asList( + new MetadataRequestData.MetadataRequestTopic().setName(null).setTopicId(Uuid.randomUuid()), + new MetadataRequestData.MetadataRequestTopic().setName(null), + new MetadataRequestData.MetadataRequestTopic().setTopicId(Uuid.randomUuid()), + new MetadataRequestData.MetadataRequestTopic().setName("topic").setTopicId(Uuid.randomUuid())); + + // if version is 10 or 11, the invalid topic metadata should return an error + List invalidVersions = Arrays.asList((short) 10, (short) 11); + invalidVersions.forEach(version -> + topics.forEach(topic -> { + MetadataRequestData metadataRequestData = new MetadataRequestData().setTopics(Collections.singletonList(topic)); + MetadataRequest.Builder builder = new MetadataRequest.Builder(metadataRequestData); + assertThrows(UnsupportedVersionException.class, () -> builder.build(version)); + }) + ); + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/requests/OffsetCommitRequestTest.java b/clients/src/test/java/org/apache/kafka/common/requests/OffsetCommitRequestTest.java new file mode 100644 index 0000000..08ae7a3 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/requests/OffsetCommitRequestTest.java @@ -0,0 +1,142 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.errors.UnsupportedVersionException; +import org.apache.kafka.common.message.OffsetCommitRequestData; +import org.apache.kafka.common.message.OffsetCommitRequestData.OffsetCommitRequestPartition; +import org.apache.kafka.common.message.OffsetCommitRequestData.OffsetCommitRequestTopic; +import org.apache.kafka.common.message.OffsetCommitResponseData.OffsetCommitResponsePartition; +import org.apache.kafka.common.message.OffsetCommitResponseData.OffsetCommitResponseTopic; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.Errors; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.apache.kafka.common.requests.OffsetCommitRequest.getErrorResponseTopics; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +public class OffsetCommitRequestTest { + + protected static String groupId = "groupId"; + protected static String memberId = "consumerId"; + protected static String groupInstanceId = "groupInstanceId"; + protected static String topicOne = "topicOne"; + protected static String topicTwo = "topicTwo"; + protected static int partitionOne = 1; + protected static int partitionTwo = 2; + protected static long offset = 100L; + protected static short leaderEpoch = 20; + protected static String metadata = "metadata"; + + protected static int throttleTimeMs = 10; + + private static OffsetCommitRequestData data; + private static List topics; + + @BeforeEach + public void setUp() { + topics = Arrays.asList( + new OffsetCommitRequestTopic() + .setName(topicOne) + .setPartitions(Collections.singletonList( + new OffsetCommitRequestPartition() + .setPartitionIndex(partitionOne) + .setCommittedOffset(offset) + .setCommittedLeaderEpoch(leaderEpoch) + .setCommittedMetadata(metadata) + )), + new OffsetCommitRequestTopic() + .setName(topicTwo) + .setPartitions(Collections.singletonList( + new OffsetCommitRequestPartition() + .setPartitionIndex(partitionTwo) + .setCommittedOffset(offset) + .setCommittedLeaderEpoch(leaderEpoch) + .setCommittedMetadata(metadata) + )) + ); + data = new OffsetCommitRequestData() + .setGroupId(groupId) + .setTopics(topics); + } + + @Test + public void testConstructor() { + Map expectedOffsets = new HashMap<>(); + expectedOffsets.put(new TopicPartition(topicOne, partitionOne), offset); + expectedOffsets.put(new TopicPartition(topicTwo, partitionTwo), offset); + + OffsetCommitRequest.Builder builder = new OffsetCommitRequest.Builder(data); + + for (short version : ApiKeys.TXN_OFFSET_COMMIT.allVersions()) { + OffsetCommitRequest request = builder.build(version); + assertEquals(expectedOffsets, request.offsets()); + + OffsetCommitResponse response = request.getErrorResponse(throttleTimeMs, Errors.NOT_COORDINATOR.exception()); + + assertEquals(Collections.singletonMap(Errors.NOT_COORDINATOR, 2), response.errorCounts()); + assertEquals(throttleTimeMs, response.throttleTimeMs()); + } + } + + @Test + public void testGetErrorResponseTopics() { + List expectedTopics = Arrays.asList( + new OffsetCommitResponseTopic() + .setName(topicOne) + .setPartitions(Collections.singletonList( + new OffsetCommitResponsePartition() + .setErrorCode(Errors.UNKNOWN_MEMBER_ID.code()) + 
.setPartitionIndex(partitionOne))), + new OffsetCommitResponseTopic() + .setName(topicTwo) + .setPartitions(Collections.singletonList( + new OffsetCommitResponsePartition() + .setErrorCode(Errors.UNKNOWN_MEMBER_ID.code()) + .setPartitionIndex(partitionTwo))) + ); + assertEquals(expectedTopics, getErrorResponseTopics(topics, Errors.UNKNOWN_MEMBER_ID)); + } + + @Test + public void testVersionSupportForGroupInstanceId() { + OffsetCommitRequest.Builder builder = new OffsetCommitRequest.Builder( + new OffsetCommitRequestData() + .setGroupId(groupId) + .setMemberId(memberId) + .setGroupInstanceId(groupInstanceId) + ); + + for (short version : ApiKeys.OFFSET_COMMIT.allVersions()) { + if (version >= 7) { + builder.build(version); + } else { + final short finalVersion = version; + assertThrows(UnsupportedVersionException.class, () -> builder.build(finalVersion)); + } + } + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/requests/OffsetCommitResponseTest.java b/clients/src/test/java/org/apache/kafka/common/requests/OffsetCommitResponseTest.java new file mode 100644 index 0000000..b9ce03f --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/requests/OffsetCommitResponseTest.java @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.message.OffsetCommitResponseData; +import org.apache.kafka.common.message.OffsetCommitResponseData.OffsetCommitResponsePartition; +import org.apache.kafka.common.message.OffsetCommitResponseData.OffsetCommitResponseTopic; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.Errors; +import org.apache.kafka.common.protocol.MessageUtil; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.nio.ByteBuffer; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +import static org.apache.kafka.common.requests.AbstractResponse.DEFAULT_THROTTLE_TIME; +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class OffsetCommitResponseTest { + + protected final int throttleTimeMs = 10; + + protected final String topicOne = "topic1"; + protected final int partitionOne = 1; + protected final Errors errorOne = Errors.COORDINATOR_NOT_AVAILABLE; + protected final Errors errorTwo = Errors.NOT_COORDINATOR; + protected final String topicTwo = "topic2"; + protected final int partitionTwo = 2; + + protected TopicPartition tp1 = new TopicPartition(topicOne, partitionOne); + protected TopicPartition tp2 = new TopicPartition(topicTwo, partitionTwo); + protected Map expectedErrorCounts; + protected Map errorsMap; + + @BeforeEach + public void setUp() { + expectedErrorCounts = new HashMap<>(); + expectedErrorCounts.put(errorOne, 1); + expectedErrorCounts.put(errorTwo, 1); + + errorsMap = new HashMap<>(); + errorsMap.put(tp1, errorOne); + errorsMap.put(tp2, errorTwo); + } + + @Test + public void testConstructorWithErrorResponse() { + OffsetCommitResponse response = new OffsetCommitResponse(throttleTimeMs, errorsMap); + + assertEquals(expectedErrorCounts, response.errorCounts()); + assertEquals(throttleTimeMs, response.throttleTimeMs()); + } + + @Test + public void testParse() { + OffsetCommitResponseData data = new OffsetCommitResponseData() + .setTopics(Arrays.asList( + new OffsetCommitResponseTopic().setPartitions( + Collections.singletonList(new OffsetCommitResponsePartition() + .setPartitionIndex(partitionOne) + .setErrorCode(errorOne.code()))), + new OffsetCommitResponseTopic().setPartitions( + Collections.singletonList(new OffsetCommitResponsePartition() + .setPartitionIndex(partitionTwo) + .setErrorCode(errorTwo.code()))) + )) + .setThrottleTimeMs(throttleTimeMs); + + for (short version : ApiKeys.OFFSET_COMMIT.allVersions()) { + ByteBuffer buffer = MessageUtil.toByteBuffer(data, version); + OffsetCommitResponse response = OffsetCommitResponse.parse(buffer, version); + assertEquals(expectedErrorCounts, response.errorCounts()); + + if (version >= 3) { + assertEquals(throttleTimeMs, response.throttleTimeMs()); + } else { + assertEquals(DEFAULT_THROTTLE_TIME, response.throttleTimeMs()); + } + + assertEquals(version >= 4, response.shouldClientThrottle(version)); + } + } + +} diff --git a/clients/src/test/java/org/apache/kafka/common/requests/OffsetFetchRequestTest.java b/clients/src/test/java/org/apache/kafka/common/requests/OffsetFetchRequestTest.java new file mode 100644 index 0000000..37076d0 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/requests/OffsetFetchRequestTest.java @@ -0,0 +1,235 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.errors.UnsupportedVersionException; +import org.apache.kafka.common.message.OffsetFetchRequestData.OffsetFetchRequestTopics; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.Errors; +import org.apache.kafka.common.requests.OffsetFetchRequest.Builder; +import org.apache.kafka.common.requests.OffsetFetchRequest.NoBatchedOffsetFetchRequestException; +import org.apache.kafka.common.requests.OffsetFetchResponse.PartitionData; +import org.junit.jupiter.api.Test; + +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +import static org.apache.kafka.common.requests.AbstractResponse.DEFAULT_THROTTLE_TIME; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class OffsetFetchRequestTest { + + private final String topicOne = "topic1"; + private final int partitionOne = 1; + private final String topicTwo = "topic2"; + private final int partitionTwo = 2; + private final String topicThree = "topic3"; + private final String group1 = "group1"; + private final String group2 = "group2"; + private final String group3 = "group3"; + private final String group4 = "group4"; + private final String group5 = "group5"; + private List groups = Arrays.asList(group1, group2, group3, group4, group5); + + private final List listOfVersionsNonBatchOffsetFetch = Arrays.asList(0, 1, 2, 3, 4, 5, 6, 7); + + + private OffsetFetchRequest.Builder builder; + + @Test + public void testConstructor() { + List partitions = Arrays.asList( + new TopicPartition(topicOne, partitionOne), + new TopicPartition(topicTwo, partitionTwo)); + int throttleTimeMs = 10; + + Map expectedData = new HashMap<>(); + for (TopicPartition partition : partitions) { + expectedData.put(partition, new PartitionData( + OffsetFetchResponse.INVALID_OFFSET, + Optional.empty(), + OffsetFetchResponse.NO_METADATA, + Errors.NONE + )); + } + + for (short version : ApiKeys.OFFSET_FETCH.allVersions()) { + if (version < 8) { + builder = new OffsetFetchRequest.Builder( + group1, + false, + partitions, + false); + assertFalse(builder.isAllTopicPartitions()); + OffsetFetchRequest request = builder.build(version); + assertFalse(request.isAllPartitions()); + assertEquals(group1, request.groupId()); + assertEquals(partitions, request.partitions()); + + OffsetFetchResponse response = request.getErrorResponse(throttleTimeMs, Errors.NONE); + assertEquals(Errors.NONE, 
response.error()); + assertFalse(response.hasError()); + assertEquals(Collections.singletonMap(Errors.NONE, version <= (short) 1 ? 3 : 1), response.errorCounts(), + "Incorrect error count for version " + version); + + if (version <= 1) { + assertEquals(expectedData, response.responseDataV0ToV7()); + } + + if (version >= 3) { + assertEquals(throttleTimeMs, response.throttleTimeMs()); + } else { + assertEquals(DEFAULT_THROTTLE_TIME, response.throttleTimeMs()); + } + } else { + builder = new Builder(Collections.singletonMap(group1, partitions), false, false); + OffsetFetchRequest request = builder.build(version); + Map> groupToPartitionMap = + request.groupIdsToPartitions(); + Map> groupToTopicMap = + request.groupIdsToTopics(); + assertFalse(request.isAllPartitionsForGroup(group1)); + assertTrue(groupToPartitionMap.containsKey(group1) && groupToTopicMap.containsKey( + group1)); + assertEquals(partitions, groupToPartitionMap.get(group1)); + OffsetFetchResponse response = request.getErrorResponse(throttleTimeMs, Errors.NONE); + assertEquals(Errors.NONE, response.groupLevelError(group1)); + assertFalse(response.groupHasError(group1)); + assertEquals(Collections.singletonMap(Errors.NONE, 1), response.errorCounts(), + "Incorrect error count for version " + version); + assertEquals(throttleTimeMs, response.throttleTimeMs()); + } + } + } + + @Test + public void testConstructorWithMultipleGroups() { + List topic1Partitions = Arrays.asList( + new TopicPartition(topicOne, partitionOne), + new TopicPartition(topicOne, partitionTwo)); + List topic2Partitions = Arrays.asList( + new TopicPartition(topicTwo, partitionOne), + new TopicPartition(topicTwo, partitionTwo)); + List topic3Partitions = Arrays.asList( + new TopicPartition(topicThree, partitionOne), + new TopicPartition(topicThree, partitionTwo)); + Map> groupToTp = new HashMap<>(); + groupToTp.put(group1, topic1Partitions); + groupToTp.put(group2, topic2Partitions); + groupToTp.put(group3, topic3Partitions); + groupToTp.put(group4, null); + groupToTp.put(group5, null); + int throttleTimeMs = 10; + + for (short version : ApiKeys.OFFSET_FETCH.allVersions()) { + if (version >= 8) { + builder = new Builder(groupToTp, false, false); + OffsetFetchRequest request = builder.build(version); + Map> groupToPartitionMap = + request.groupIdsToPartitions(); + Map> groupToTopicMap = + request.groupIdsToTopics(); + assertEquals(groupToTp.keySet(), groupToTopicMap.keySet()); + assertEquals(groupToTp.keySet(), groupToPartitionMap.keySet()); + assertFalse(request.isAllPartitionsForGroup(group1)); + assertFalse(request.isAllPartitionsForGroup(group2)); + assertFalse(request.isAllPartitionsForGroup(group3)); + assertTrue(request.isAllPartitionsForGroup(group4)); + assertTrue(request.isAllPartitionsForGroup(group5)); + OffsetFetchResponse response = request.getErrorResponse(throttleTimeMs, Errors.NONE); + for (String group : groups) { + assertEquals(Errors.NONE, response.groupLevelError(group)); + assertFalse(response.groupHasError(group)); + } + assertEquals(Collections.singletonMap(Errors.NONE, 5), response.errorCounts(), + "Incorrect error count for version " + version); + assertEquals(throttleTimeMs, response.throttleTimeMs()); + } + } + } + + @Test + public void testBuildThrowForUnsupportedBatchRequest() { + for (int version : listOfVersionsNonBatchOffsetFetch) { + Map> groupPartitionMap = new HashMap<>(); + groupPartitionMap.put(group1, null); + groupPartitionMap.put(group2, null); + builder = new Builder(groupPartitionMap, true, false); + final short finalVersion 
= (short) version; + assertThrows(NoBatchedOffsetFetchRequestException.class, () -> builder.build(finalVersion)); + } + } + + @Test + public void testConstructorFailForUnsupportedRequireStable() { + for (short version : ApiKeys.OFFSET_FETCH.allVersions()) { + if (version < 8) { + // The builder needs to be initialized every cycle as the internal data `requireStable` flag is flipped. + builder = new OffsetFetchRequest.Builder(group1, true, null, false); + final short finalVersion = version; + if (version < 2) { + assertThrows(UnsupportedVersionException.class, () -> builder.build(finalVersion)); + } else { + OffsetFetchRequest request = builder.build(finalVersion); + assertEquals(group1, request.groupId()); + assertNull(request.partitions()); + assertTrue(request.isAllPartitions()); + if (version < 7) { + assertFalse(request.requireStable()); + } else { + assertTrue(request.requireStable()); + } + } + } else { + builder = new Builder(Collections.singletonMap(group1, null), true, false); + OffsetFetchRequest request = builder.build(version); + Map> groupToPartitionMap = + request.groupIdsToPartitions(); + Map> groupToTopicMap = + request.groupIdsToTopics(); + assertTrue(groupToPartitionMap.containsKey(group1) && groupToTopicMap.containsKey( + group1)); + assertNull(groupToPartitionMap.get(group1)); + assertTrue(request.isAllPartitionsForGroup(group1)); + assertTrue(request.requireStable()); + } + } + } + + @Test + public void testBuildThrowForUnsupportedRequireStable() { + for (int version : listOfVersionsNonBatchOffsetFetch) { + builder = new OffsetFetchRequest.Builder(group1, true, null, true); + if (version < 7) { + final short finalVersion = (short) version; + assertThrows(UnsupportedVersionException.class, () -> builder.build(finalVersion)); + } else { + OffsetFetchRequest request = builder.build((short) version); + assertTrue(request.requireStable()); + } + } + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/requests/OffsetFetchResponseTest.java b/clients/src/test/java/org/apache/kafka/common/requests/OffsetFetchResponseTest.java new file mode 100644 index 0000000..c73ea2a --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/requests/OffsetFetchResponseTest.java @@ -0,0 +1,441 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.message.OffsetFetchResponseData; +import org.apache.kafka.common.message.OffsetFetchResponseData.OffsetFetchResponseGroup; +import org.apache.kafka.common.message.OffsetFetchResponseData.OffsetFetchResponsePartition; +import org.apache.kafka.common.message.OffsetFetchResponseData.OffsetFetchResponsePartitions; +import org.apache.kafka.common.message.OffsetFetchResponseData.OffsetFetchResponseTopic; +import org.apache.kafka.common.message.OffsetFetchResponseData.OffsetFetchResponseTopics; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.protocol.Errors; +import org.apache.kafka.common.record.RecordBatch; +import org.apache.kafka.common.requests.OffsetFetchResponse.PartitionData; +import org.apache.kafka.common.utils.Utils; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.Optional; + +import static org.apache.kafka.common.requests.AbstractResponse.DEFAULT_THROTTLE_TIME; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class OffsetFetchResponseTest { + private final int throttleTimeMs = 10; + private final int offset = 100; + private final String metadata = "metadata"; + + private final String groupOne = "group1"; + private final String groupTwo = "group2"; + private final String groupThree = "group3"; + private final String topicOne = "topic1"; + private final int partitionOne = 1; + private final Optional leaderEpochOne = Optional.of(1); + private final String topicTwo = "topic2"; + private final int partitionTwo = 2; + private final Optional leaderEpochTwo = Optional.of(2); + private final String topicThree = "topic3"; + private final int partitionThree = 3; + private final Optional leaderEpochThree = Optional.of(3); + + + private Map partitionDataMap; + + @BeforeEach + public void setUp() { + partitionDataMap = new HashMap<>(); + partitionDataMap.put(new TopicPartition(topicOne, partitionOne), new PartitionData( + offset, + leaderEpochOne, + metadata, + Errors.TOPIC_AUTHORIZATION_FAILED + )); + partitionDataMap.put(new TopicPartition(topicTwo, partitionTwo), new PartitionData( + offset, + leaderEpochTwo, + metadata, + Errors.UNKNOWN_TOPIC_OR_PARTITION + )); + } + + @Test + public void testConstructor() { + for (short version : ApiKeys.OFFSET_FETCH.allVersions()) { + if (version < 8) { + OffsetFetchResponse response = new OffsetFetchResponse(throttleTimeMs, Errors.NOT_COORDINATOR, partitionDataMap); + assertEquals(Errors.NOT_COORDINATOR, response.error()); + assertEquals(3, response.errorCounts().size()); + assertEquals(Utils.mkMap(Utils.mkEntry(Errors.NOT_COORDINATOR, 1), + Utils.mkEntry(Errors.TOPIC_AUTHORIZATION_FAILED, 1), + Utils.mkEntry(Errors.UNKNOWN_TOPIC_OR_PARTITION, 1)), + response.errorCounts()); + + assertEquals(throttleTimeMs, response.throttleTimeMs()); + + Map responseData = response.responseDataV0ToV7(); + assertEquals(partitionDataMap, responseData); + responseData.forEach((tp, data) -> assertTrue(data.hasError())); + } else { + OffsetFetchResponse response = new OffsetFetchResponse( + throttleTimeMs, + Collections.singletonMap(groupOne, Errors.NOT_COORDINATOR), + 
Collections.singletonMap(groupOne, partitionDataMap)); + assertEquals(Errors.NOT_COORDINATOR, response.groupLevelError(groupOne)); + assertEquals(3, response.errorCounts().size()); + assertEquals(Utils.mkMap(Utils.mkEntry(Errors.NOT_COORDINATOR, 1), + Utils.mkEntry(Errors.TOPIC_AUTHORIZATION_FAILED, 1), + Utils.mkEntry(Errors.UNKNOWN_TOPIC_OR_PARTITION, 1)), + response.errorCounts()); + + assertEquals(throttleTimeMs, response.throttleTimeMs()); + + Map responseData = response.partitionDataMap(groupOne); + assertEquals(partitionDataMap, responseData); + responseData.forEach((tp, data) -> assertTrue(data.hasError())); + } + } + } + + @Test + public void testConstructorWithMultipleGroups() { + Map> responseData = new HashMap<>(); + Map errorMap = new HashMap<>(); + Map pd1 = new HashMap<>(); + Map pd2 = new HashMap<>(); + Map pd3 = new HashMap<>(); + pd1.put(new TopicPartition(topicOne, partitionOne), new PartitionData( + offset, + leaderEpochOne, + metadata, + Errors.TOPIC_AUTHORIZATION_FAILED)); + pd2.put(new TopicPartition(topicTwo, partitionTwo), new PartitionData( + offset, + leaderEpochTwo, + metadata, + Errors.UNKNOWN_TOPIC_OR_PARTITION)); + pd3.put(new TopicPartition(topicThree, partitionThree), new PartitionData( + offset, + leaderEpochThree, + metadata, + Errors.NONE)); + responseData.put(groupOne, pd1); + responseData.put(groupTwo, pd2); + responseData.put(groupThree, pd3); + errorMap.put(groupOne, Errors.NOT_COORDINATOR); + errorMap.put(groupTwo, Errors.COORDINATOR_LOAD_IN_PROGRESS); + errorMap.put(groupThree, Errors.NONE); + for (short version : ApiKeys.OFFSET_FETCH.allVersions()) { + if (version >= 8) { + OffsetFetchResponse response = new OffsetFetchResponse( + throttleTimeMs, errorMap, responseData); + + assertEquals(Errors.NOT_COORDINATOR, response.groupLevelError(groupOne)); + assertEquals(Errors.COORDINATOR_LOAD_IN_PROGRESS, response.groupLevelError(groupTwo)); + assertEquals(Errors.NONE, response.groupLevelError(groupThree)); + assertTrue(response.groupHasError(groupOne)); + assertTrue(response.groupHasError(groupTwo)); + assertFalse(response.groupHasError(groupThree)); + assertEquals(5, response.errorCounts().size()); + assertEquals(Utils.mkMap(Utils.mkEntry(Errors.NOT_COORDINATOR, 1), + Utils.mkEntry(Errors.TOPIC_AUTHORIZATION_FAILED, 1), + Utils.mkEntry(Errors.UNKNOWN_TOPIC_OR_PARTITION, 1), + Utils.mkEntry(Errors.COORDINATOR_LOAD_IN_PROGRESS, 1), + Utils.mkEntry(Errors.NONE, 2)), + response.errorCounts()); + + assertEquals(throttleTimeMs, response.throttleTimeMs()); + + Map responseData1 = response.partitionDataMap(groupOne); + assertEquals(pd1, responseData1); + responseData1.forEach((tp, data) -> assertTrue(data.hasError())); + Map responseData2 = response.partitionDataMap(groupTwo); + assertEquals(pd2, responseData2); + responseData2.forEach((tp, data) -> assertTrue(data.hasError())); + Map responseData3 = response.partitionDataMap(groupThree); + assertEquals(pd3, responseData3); + responseData3.forEach((tp, data) -> assertFalse(data.hasError())); + } + } + } + + /** + * Test behavior changes over the versions. 
Refer to resources.common.messages.OffsetFetchResponse.json + */ + @Test + public void testStructBuild() { + for (short version : ApiKeys.OFFSET_FETCH.allVersions()) { + if (version < 8) { + partitionDataMap.put(new TopicPartition(topicTwo, partitionTwo), new PartitionData( + offset, + leaderEpochTwo, + metadata, + Errors.GROUP_AUTHORIZATION_FAILED + )); + + OffsetFetchResponse latestResponse = new OffsetFetchResponse(throttleTimeMs, Errors.NONE, partitionDataMap); + OffsetFetchResponseData data = new OffsetFetchResponseData( + new ByteBufferAccessor(latestResponse.serialize(version)), version); + + OffsetFetchResponse oldResponse = new OffsetFetchResponse(data, version); + + if (version <= 1) { + assertEquals(Errors.NONE.code(), data.errorCode()); + + // Partition level error populated in older versions. + assertEquals(Errors.GROUP_AUTHORIZATION_FAILED, oldResponse.error()); + assertEquals(Utils.mkMap(Utils.mkEntry(Errors.GROUP_AUTHORIZATION_FAILED, 2), + Utils.mkEntry(Errors.TOPIC_AUTHORIZATION_FAILED, 1)), + oldResponse.errorCounts()); + } else { + assertEquals(Errors.NONE.code(), data.errorCode()); + + assertEquals(Errors.NONE, oldResponse.error()); + assertEquals(Utils.mkMap( + Utils.mkEntry(Errors.NONE, 1), + Utils.mkEntry(Errors.GROUP_AUTHORIZATION_FAILED, 1), + Utils.mkEntry(Errors.TOPIC_AUTHORIZATION_FAILED, 1)), + oldResponse.errorCounts()); + } + + if (version <= 2) { + assertEquals(DEFAULT_THROTTLE_TIME, oldResponse.throttleTimeMs()); + } else { + assertEquals(throttleTimeMs, oldResponse.throttleTimeMs()); + } + + Map expectedDataMap = new HashMap<>(); + for (Map.Entry entry : partitionDataMap.entrySet()) { + PartitionData partitionData = entry.getValue(); + expectedDataMap.put(entry.getKey(), new PartitionData( + partitionData.offset, + version <= 4 ? 
Optional.empty() : partitionData.leaderEpoch, + partitionData.metadata, + partitionData.error + )); + } + + Map responseData = oldResponse.responseDataV0ToV7(); + assertEquals(expectedDataMap, responseData); + + responseData.forEach((tp, rdata) -> assertTrue(rdata.hasError())); + } else { + partitionDataMap.put(new TopicPartition(topicTwo, partitionTwo), new PartitionData( + offset, + leaderEpochTwo, + metadata, + Errors.GROUP_AUTHORIZATION_FAILED)); + OffsetFetchResponse latestResponse = new OffsetFetchResponse( + throttleTimeMs, + Collections.singletonMap(groupOne, Errors.NONE), + Collections.singletonMap(groupOne, partitionDataMap)); + OffsetFetchResponseData data = new OffsetFetchResponseData( + new ByteBufferAccessor(latestResponse.serialize(version)), version); + OffsetFetchResponse oldResponse = new OffsetFetchResponse(data, version); + assertEquals(Errors.NONE.code(), data.groups().get(0).errorCode()); + + assertEquals(Errors.NONE, oldResponse.groupLevelError(groupOne)); + assertEquals(Utils.mkMap( + Utils.mkEntry(Errors.NONE, 1), + Utils.mkEntry(Errors.GROUP_AUTHORIZATION_FAILED, 1), + Utils.mkEntry(Errors.TOPIC_AUTHORIZATION_FAILED, 1)), + oldResponse.errorCounts()); + assertEquals(throttleTimeMs, oldResponse.throttleTimeMs()); + + Map expectedDataMap = new HashMap<>(); + for (Map.Entry entry : partitionDataMap.entrySet()) { + PartitionData partitionData = entry.getValue(); + expectedDataMap.put(entry.getKey(), new PartitionData( + partitionData.offset, + partitionData.leaderEpoch, + partitionData.metadata, + partitionData.error + )); + } + + Map responseData = oldResponse.partitionDataMap(groupOne); + assertEquals(expectedDataMap, responseData); + + responseData.forEach((tp, rdata) -> assertTrue(rdata.hasError())); + } + } + } + + @Test + public void testShouldThrottle() { + for (short version : ApiKeys.OFFSET_FETCH.allVersions()) { + if (version < 8) { + OffsetFetchResponse response = new OffsetFetchResponse(throttleTimeMs, Errors.NONE, partitionDataMap); + if (version >= 4) { + assertTrue(response.shouldClientThrottle(version)); + } else { + assertFalse(response.shouldClientThrottle(version)); + } + } else { + OffsetFetchResponse response = new OffsetFetchResponse( + throttleTimeMs, + Collections.singletonMap(groupOne, Errors.NOT_COORDINATOR), + Collections.singletonMap(groupOne, partitionDataMap)); + assertTrue(response.shouldClientThrottle(version)); + } + } + } + + @Test + public void testNullableMetadataV0ToV7() { + PartitionData pd = new PartitionData( + offset, + leaderEpochOne, + null, + Errors.UNKNOWN_TOPIC_OR_PARTITION); + // test PartitionData.equals with null metadata + assertEquals(pd, pd); + partitionDataMap.clear(); + partitionDataMap.put(new TopicPartition(topicOne, partitionOne), pd); + + OffsetFetchResponse response = new OffsetFetchResponse(throttleTimeMs, Errors.GROUP_AUTHORIZATION_FAILED, partitionDataMap); + OffsetFetchResponseData expectedData = + new OffsetFetchResponseData() + .setErrorCode(Errors.GROUP_AUTHORIZATION_FAILED.code()) + .setThrottleTimeMs(throttleTimeMs) + .setTopics(Collections.singletonList( + new OffsetFetchResponseTopic() + .setName(topicOne) + .setPartitions(Collections.singletonList( + new OffsetFetchResponsePartition() + .setPartitionIndex(partitionOne) + .setCommittedOffset(offset) + .setCommittedLeaderEpoch(leaderEpochOne.orElse(-1)) + .setErrorCode(Errors.UNKNOWN_TOPIC_OR_PARTITION.code()) + .setMetadata(null)) + )) + ); + assertEquals(expectedData, response.data()); + } + + @Test + public void testNullableMetadataV8AndAbove() { 
+ PartitionData pd = new PartitionData( + offset, + leaderEpochOne, + null, + Errors.UNKNOWN_TOPIC_OR_PARTITION); + // test PartitionData.equals with null metadata + assertEquals(pd, pd); + partitionDataMap.clear(); + partitionDataMap.put(new TopicPartition(topicOne, partitionOne), pd); + + OffsetFetchResponse response = new OffsetFetchResponse( + throttleTimeMs, + Collections.singletonMap(groupOne, Errors.GROUP_AUTHORIZATION_FAILED), + Collections.singletonMap(groupOne, partitionDataMap)); + OffsetFetchResponseData expectedData = + new OffsetFetchResponseData() + .setGroups(Collections.singletonList( + new OffsetFetchResponseGroup() + .setGroupId(groupOne) + .setTopics(Collections.singletonList( + new OffsetFetchResponseTopics() + .setName(topicOne) + .setPartitions(Collections.singletonList( + new OffsetFetchResponsePartitions() + .setPartitionIndex(partitionOne) + .setCommittedOffset(offset) + .setCommittedLeaderEpoch(leaderEpochOne.orElse(-1)) + .setErrorCode(Errors.UNKNOWN_TOPIC_OR_PARTITION.code()) + .setMetadata(null))))) + .setErrorCode(Errors.GROUP_AUTHORIZATION_FAILED.code()))) + .setThrottleTimeMs(throttleTimeMs); + assertEquals(expectedData, response.data()); + } + + @Test + public void testUseDefaultLeaderEpochV0ToV7() { + final Optional emptyLeaderEpoch = Optional.empty(); + partitionDataMap.clear(); + + partitionDataMap.put(new TopicPartition(topicOne, partitionOne), + new PartitionData( + offset, + emptyLeaderEpoch, + metadata, + Errors.UNKNOWN_TOPIC_OR_PARTITION) + ); + + OffsetFetchResponse response = new OffsetFetchResponse(throttleTimeMs, Errors.NOT_COORDINATOR, partitionDataMap); + OffsetFetchResponseData expectedData = + new OffsetFetchResponseData() + .setErrorCode(Errors.NOT_COORDINATOR.code()) + .setThrottleTimeMs(throttleTimeMs) + .setTopics(Collections.singletonList( + new OffsetFetchResponseTopic() + .setName(topicOne) + .setPartitions(Collections.singletonList( + new OffsetFetchResponsePartition() + .setPartitionIndex(partitionOne) + .setCommittedOffset(offset) + .setCommittedLeaderEpoch(RecordBatch.NO_PARTITION_LEADER_EPOCH) + .setErrorCode(Errors.UNKNOWN_TOPIC_OR_PARTITION.code()) + .setMetadata(metadata)) + )) + ); + assertEquals(expectedData, response.data()); + } + + @Test + public void testUseDefaultLeaderEpochV8() { + final Optional emptyLeaderEpoch = Optional.empty(); + partitionDataMap.clear(); + + partitionDataMap.put(new TopicPartition(topicOne, partitionOne), + new PartitionData( + offset, + emptyLeaderEpoch, + metadata, + Errors.UNKNOWN_TOPIC_OR_PARTITION) + ); + OffsetFetchResponse response = new OffsetFetchResponse( + throttleTimeMs, + Collections.singletonMap(groupOne, Errors.NOT_COORDINATOR), + Collections.singletonMap(groupOne, partitionDataMap)); + OffsetFetchResponseData expectedData = + new OffsetFetchResponseData() + .setGroups(Collections.singletonList( + new OffsetFetchResponseGroup() + .setGroupId(groupOne) + .setTopics(Collections.singletonList( + new OffsetFetchResponseTopics() + .setName(topicOne) + .setPartitions(Collections.singletonList( + new OffsetFetchResponsePartitions() + .setPartitionIndex(partitionOne) + .setCommittedOffset(offset) + .setCommittedLeaderEpoch(RecordBatch.NO_PARTITION_LEADER_EPOCH) + .setErrorCode(Errors.UNKNOWN_TOPIC_OR_PARTITION.code()) + .setMetadata(metadata))))) + .setErrorCode(Errors.NOT_COORDINATOR.code()))) + .setThrottleTimeMs(throttleTimeMs); + assertEquals(expectedData, response.data()); + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/requests/OffsetsForLeaderEpochRequestTest.java 
b/clients/src/test/java/org/apache/kafka/common/requests/OffsetsForLeaderEpochRequestTest.java new file mode 100644 index 0000000..e5dacd5 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/requests/OffsetsForLeaderEpochRequestTest.java @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.errors.UnsupportedVersionException; +import org.apache.kafka.common.message.OffsetForLeaderEpochRequestData.OffsetForLeaderTopicCollection; +import org.apache.kafka.common.protocol.ApiKeys; +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +public class OffsetsForLeaderEpochRequestTest { + + @Test + public void testForConsumerRequiresVersion3() { + OffsetsForLeaderEpochRequest.Builder builder = OffsetsForLeaderEpochRequest.Builder.forConsumer(new OffsetForLeaderTopicCollection()); + for (short version = 0; version < 3; version++) { + final short v = version; + assertThrows(UnsupportedVersionException.class, () -> builder.build(v)); + } + + for (short version = 3; version <= ApiKeys.OFFSET_FOR_LEADER_EPOCH.latestVersion(); version++) { + OffsetsForLeaderEpochRequest request = builder.build(version); + assertEquals(OffsetsForLeaderEpochRequest.CONSUMER_REPLICA_ID, request.replicaId()); + } + } + + @Test + public void testDefaultReplicaId() { + for (short version : ApiKeys.OFFSET_FOR_LEADER_EPOCH.allVersions()) { + int replicaId = 1; + OffsetsForLeaderEpochRequest.Builder builder = OffsetsForLeaderEpochRequest.Builder.forFollower( + version, new OffsetForLeaderTopicCollection(), replicaId); + OffsetsForLeaderEpochRequest request = builder.build(); + OffsetsForLeaderEpochRequest parsed = OffsetsForLeaderEpochRequest.parse(request.serialize(), version); + if (version < 3) + assertEquals(OffsetsForLeaderEpochRequest.DEBUGGING_REPLICA_ID, parsed.replicaId()); + else + assertEquals(replicaId, parsed.replicaId()); + } + } + +} diff --git a/clients/src/test/java/org/apache/kafka/common/requests/ProduceRequestTest.java b/clients/src/test/java/org/apache/kafka/common/requests/ProduceRequestTest.java new file mode 100644 index 0000000..fee026e --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/requests/ProduceRequestTest.java @@ -0,0 +1,314 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.InvalidRecordException; +import org.apache.kafka.common.errors.UnsupportedCompressionTypeException; +import org.apache.kafka.common.message.ProduceRequestData; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.record.CompressionType; +import org.apache.kafka.common.record.MemoryRecords; +import org.apache.kafka.common.record.MemoryRecordsBuilder; +import org.apache.kafka.common.record.RecordBatch; +import org.apache.kafka.common.record.RecordVersion; +import org.apache.kafka.common.record.SimpleRecord; +import org.apache.kafka.common.record.TimestampType; +import org.junit.jupiter.api.Test; + +import java.nio.ByteBuffer; +import java.util.Arrays; +import java.util.Collections; +import java.util.stream.IntStream; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class ProduceRequestTest { + + private final SimpleRecord simpleRecord = new SimpleRecord(System.currentTimeMillis(), + "key".getBytes(), + "value".getBytes()); + + @Test + public void shouldBeFlaggedAsTransactionalWhenTransactionalRecords() throws Exception { + final MemoryRecords memoryRecords = MemoryRecords.withTransactionalRecords(0, CompressionType.NONE, 1L, + (short) 1, 1, 1, simpleRecord); + + final ProduceRequest request = ProduceRequest.forCurrentMagic(new ProduceRequestData() + .setTopicData(new ProduceRequestData.TopicProduceDataCollection(Collections.singletonList( + new ProduceRequestData.TopicProduceData() + .setName("topic") + .setPartitionData(Collections.singletonList( + new ProduceRequestData.PartitionProduceData() + .setIndex(1) + .setRecords(memoryRecords)))).iterator())) + .setAcks((short) -1) + .setTimeoutMs(10)).build(); + assertTrue(RequestUtils.hasTransactionalRecords(request)); + } + + @Test + public void shouldNotBeFlaggedAsTransactionalWhenNoRecords() throws Exception { + final ProduceRequest request = createNonIdempotentNonTransactionalRecords(); + assertFalse(RequestUtils.hasTransactionalRecords(request)); + } + + @Test + public void shouldNotBeFlaggedAsIdempotentWhenRecordsNotIdempotent() throws Exception { + final ProduceRequest request = createNonIdempotentNonTransactionalRecords(); + assertFalse(RequestUtils.hasTransactionalRecords(request)); + } + + @Test + public void shouldBeFlaggedAsIdempotentWhenIdempotentRecords() throws Exception { + final MemoryRecords memoryRecords = MemoryRecords.withIdempotentRecords(1, CompressionType.NONE, 1L, + (short) 1, 1, 1, simpleRecord); + final ProduceRequest request = ProduceRequest.forCurrentMagic(new ProduceRequestData() + .setTopicData(new ProduceRequestData.TopicProduceDataCollection(Collections.singletonList( + new ProduceRequestData.TopicProduceData() + .setName("topic") + .setPartitionData(Collections.singletonList( + new ProduceRequestData.PartitionProduceData() + .setIndex(1) + .setRecords(memoryRecords)))).iterator())) + 
.setAcks((short) -1) + .setTimeoutMs(10)).build(); + assertTrue(RequestTestUtils.hasIdempotentRecords(request)); + } + + @Test + public void testBuildWithOldMessageFormat() { + ByteBuffer buffer = ByteBuffer.allocate(256); + MemoryRecordsBuilder builder = MemoryRecords.builder(buffer, RecordBatch.MAGIC_VALUE_V1, CompressionType.NONE, + TimestampType.CREATE_TIME, 0L); + builder.append(10L, null, "a".getBytes()); + ProduceRequest.Builder requestBuilder = ProduceRequest.forMagic(RecordBatch.MAGIC_VALUE_V1, + new ProduceRequestData() + .setTopicData(new ProduceRequestData.TopicProduceDataCollection(Collections.singletonList( + new ProduceRequestData.TopicProduceData().setName("test").setPartitionData(Collections.singletonList( + new ProduceRequestData.PartitionProduceData().setIndex(9).setRecords(builder.build())))) + .iterator())) + .setAcks((short) 1) + .setTimeoutMs(5000)); + assertEquals(2, requestBuilder.oldestAllowedVersion()); + assertEquals(2, requestBuilder.latestAllowedVersion()); + } + + @Test + public void testBuildWithCurrentMessageFormat() { + ByteBuffer buffer = ByteBuffer.allocate(256); + MemoryRecordsBuilder builder = MemoryRecords.builder(buffer, RecordBatch.CURRENT_MAGIC_VALUE, + CompressionType.NONE, TimestampType.CREATE_TIME, 0L); + builder.append(10L, null, "a".getBytes()); + ProduceRequest.Builder requestBuilder = ProduceRequest.forMagic(RecordBatch.CURRENT_MAGIC_VALUE, + new ProduceRequestData() + .setTopicData(new ProduceRequestData.TopicProduceDataCollection(Collections.singletonList( + new ProduceRequestData.TopicProduceData().setName("test").setPartitionData(Collections.singletonList( + new ProduceRequestData.PartitionProduceData().setIndex(9).setRecords(builder.build())))) + .iterator())) + .setAcks((short) 1) + .setTimeoutMs(5000)); + assertEquals(3, requestBuilder.oldestAllowedVersion()); + assertEquals(ApiKeys.PRODUCE.latestVersion(), requestBuilder.latestAllowedVersion()); + } + + @Test + public void testV3AndAboveShouldContainOnlyOneRecordBatch() { + ByteBuffer buffer = ByteBuffer.allocate(256); + MemoryRecordsBuilder builder = MemoryRecords.builder(buffer, CompressionType.NONE, TimestampType.CREATE_TIME, 0L); + builder.append(10L, null, "a".getBytes()); + builder.close(); + + builder = MemoryRecords.builder(buffer, CompressionType.NONE, TimestampType.CREATE_TIME, 1L); + builder.append(11L, "1".getBytes(), "b".getBytes()); + builder.append(12L, null, "c".getBytes()); + builder.close(); + + buffer.flip(); + + ProduceRequest.Builder requestBuilder = ProduceRequest.forCurrentMagic(new ProduceRequestData() + .setTopicData(new ProduceRequestData.TopicProduceDataCollection(Collections.singletonList( + new ProduceRequestData.TopicProduceData() + .setName("test") + .setPartitionData(Collections.singletonList( + new ProduceRequestData.PartitionProduceData() + .setIndex(0) + .setRecords(MemoryRecords.readableRecords(buffer))))).iterator())) + .setAcks((short) 1) + .setTimeoutMs(5000)); + assertThrowsForAllVersions(requestBuilder, InvalidRecordException.class); + } + + @Test + public void testV3AndAboveCannotHaveNoRecordBatches() { + ProduceRequest.Builder requestBuilder = ProduceRequest.forCurrentMagic(new ProduceRequestData() + .setTopicData(new ProduceRequestData.TopicProduceDataCollection(Collections.singletonList( + new ProduceRequestData.TopicProduceData() + .setName("test") + .setPartitionData(Collections.singletonList( + new ProduceRequestData.PartitionProduceData() + .setIndex(0) + .setRecords(MemoryRecords.EMPTY)))).iterator())) + .setAcks((short) 1) + 
.setTimeoutMs(5000)); + assertThrowsForAllVersions(requestBuilder, InvalidRecordException.class); + } + + @Test + public void testV3AndAboveCannotUseMagicV0() { + ByteBuffer buffer = ByteBuffer.allocate(256); + MemoryRecordsBuilder builder = MemoryRecords.builder(buffer, RecordBatch.MAGIC_VALUE_V0, CompressionType.NONE, + TimestampType.NO_TIMESTAMP_TYPE, 0L); + builder.append(10L, null, "a".getBytes()); + + ProduceRequest.Builder requestBuilder = ProduceRequest.forCurrentMagic(new ProduceRequestData() + .setTopicData(new ProduceRequestData.TopicProduceDataCollection(Collections.singletonList( + new ProduceRequestData.TopicProduceData() + .setName("test") + .setPartitionData(Collections.singletonList( + new ProduceRequestData.PartitionProduceData() + .setIndex(0) + .setRecords(builder.build())))).iterator())) + .setAcks((short) 1) + .setTimeoutMs(5000)); + assertThrowsForAllVersions(requestBuilder, InvalidRecordException.class); + } + + @Test + public void testV3AndAboveCannotUseMagicV1() { + ByteBuffer buffer = ByteBuffer.allocate(256); + MemoryRecordsBuilder builder = MemoryRecords.builder(buffer, RecordBatch.MAGIC_VALUE_V1, CompressionType.NONE, + TimestampType.CREATE_TIME, 0L); + builder.append(10L, null, "a".getBytes()); + + ProduceRequest.Builder requestBuilder = ProduceRequest.forCurrentMagic(new ProduceRequestData() + .setTopicData(new ProduceRequestData.TopicProduceDataCollection(Collections.singletonList( + new ProduceRequestData.TopicProduceData() + .setName("test") + .setPartitionData(Collections.singletonList(new ProduceRequestData.PartitionProduceData() + .setIndex(0) + .setRecords(builder.build())))) + .iterator())) + .setAcks((short) 1) + .setTimeoutMs(5000)); + assertThrowsForAllVersions(requestBuilder, InvalidRecordException.class); + } + + @Test + public void testV6AndBelowCannotUseZStdCompression() { + ByteBuffer buffer = ByteBuffer.allocate(256); + MemoryRecordsBuilder builder = MemoryRecords.builder(buffer, RecordBatch.MAGIC_VALUE_V2, CompressionType.ZSTD, + TimestampType.CREATE_TIME, 0L); + builder.append(10L, null, "a".getBytes()); + + ProduceRequestData produceData = new ProduceRequestData() + .setTopicData(new ProduceRequestData.TopicProduceDataCollection(Collections.singletonList( + new ProduceRequestData.TopicProduceData() + .setName("test") + .setPartitionData(Collections.singletonList(new ProduceRequestData.PartitionProduceData() + .setIndex(0) + .setRecords(builder.build())))) + .iterator())) + .setAcks((short) 1) + .setTimeoutMs(1000); + // Can't create ProduceRequest instance with version within [3, 7) + for (short version = 3; version < 7; version++) { + + ProduceRequest.Builder requestBuilder = new ProduceRequest.Builder(version, version, produceData); + assertThrowsForAllVersions(requestBuilder, UnsupportedCompressionTypeException.class); + } + + // Works fine with current version (>= 7) + ProduceRequest.forCurrentMagic(produceData); + } + + @Test + public void testMixedTransactionalData() { + final long producerId = 15L; + final short producerEpoch = 5; + final int sequence = 10; + final String transactionalId = "txnlId"; + + final MemoryRecords nonTxnRecords = MemoryRecords.withRecords(CompressionType.NONE, + new SimpleRecord("foo".getBytes())); + final MemoryRecords txnRecords = MemoryRecords.withTransactionalRecords(CompressionType.NONE, producerId, + producerEpoch, sequence, new SimpleRecord("bar".getBytes())); + + ProduceRequest.Builder builder = ProduceRequest.forMagic(RecordBatch.CURRENT_MAGIC_VALUE, + new ProduceRequestData() + .setTopicData(new 
ProduceRequestData.TopicProduceDataCollection(Arrays.asList( + new ProduceRequestData.TopicProduceData().setName("foo").setPartitionData(Collections.singletonList( + new ProduceRequestData.PartitionProduceData().setIndex(0).setRecords(txnRecords))), + new ProduceRequestData.TopicProduceData().setName("foo").setPartitionData(Collections.singletonList( + new ProduceRequestData.PartitionProduceData().setIndex(1).setRecords(nonTxnRecords)))) + .iterator())) + .setAcks((short) -1) + .setTimeoutMs(5000)); + final ProduceRequest request = builder.build(); + assertTrue(RequestUtils.hasTransactionalRecords(request)); + assertTrue(RequestTestUtils.hasIdempotentRecords(request)); + } + + @Test + public void testMixedIdempotentData() { + final long producerId = 15L; + final short producerEpoch = 5; + final int sequence = 10; + + final MemoryRecords nonTxnRecords = MemoryRecords.withRecords(CompressionType.NONE, + new SimpleRecord("foo".getBytes())); + final MemoryRecords txnRecords = MemoryRecords.withIdempotentRecords(CompressionType.NONE, producerId, + producerEpoch, sequence, new SimpleRecord("bar".getBytes())); + + ProduceRequest.Builder builder = ProduceRequest.forMagic(RecordVersion.current().value, + new ProduceRequestData() + .setTopicData(new ProduceRequestData.TopicProduceDataCollection(Arrays.asList( + new ProduceRequestData.TopicProduceData().setName("foo").setPartitionData(Collections.singletonList( + new ProduceRequestData.PartitionProduceData().setIndex(0).setRecords(txnRecords))), + new ProduceRequestData.TopicProduceData().setName("foo").setPartitionData(Collections.singletonList( + new ProduceRequestData.PartitionProduceData().setIndex(1).setRecords(nonTxnRecords)))) + .iterator())) + .setAcks((short) -1) + .setTimeoutMs(5000)); + + final ProduceRequest request = builder.build(); + assertFalse(RequestUtils.hasTransactionalRecords(request)); + assertTrue(RequestTestUtils.hasIdempotentRecords(request)); + } + + private static void assertThrowsForAllVersions(ProduceRequest.Builder builder, + Class expectedType) { + IntStream.range(builder.oldestAllowedVersion(), builder.latestAllowedVersion() + 1) + .forEach(version -> assertThrows(expectedType, () -> builder.build((short) version).serialize())); + } + + private ProduceRequest createNonIdempotentNonTransactionalRecords() { + return ProduceRequest.forCurrentMagic(new ProduceRequestData() + .setTopicData(new ProduceRequestData.TopicProduceDataCollection(Collections.singletonList( + new ProduceRequestData.TopicProduceData() + .setName("topic") + .setPartitionData(Collections.singletonList(new ProduceRequestData.PartitionProduceData() + .setIndex(1) + .setRecords(MemoryRecords.withRecords(CompressionType.NONE, simpleRecord))))) + .iterator())) + .setAcks((short) -1) + .setTimeoutMs(10)).build(); + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/requests/ProduceResponseTest.java b/clients/src/test/java/org/apache/kafka/common/requests/ProduceResponseTest.java new file mode 100644 index 0000000..d854eb0 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/requests/ProduceResponseTest.java @@ -0,0 +1,127 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.message.ProduceResponseData; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.Errors; +import org.apache.kafka.common.record.RecordBatch; + +import org.junit.jupiter.api.Test; + +import java.nio.ByteBuffer; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.apache.kafka.common.protocol.ApiKeys.PRODUCE; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class ProduceResponseTest { + + @SuppressWarnings("deprecation") + @Test + public void produceResponseV5Test() { + Map responseData = new HashMap<>(); + TopicPartition tp0 = new TopicPartition("test", 0); + responseData.put(tp0, new ProduceResponse.PartitionResponse(Errors.NONE, 10000, RecordBatch.NO_TIMESTAMP, 100)); + + ProduceResponse v5Response = new ProduceResponse(responseData, 10); + short version = 5; + + ByteBuffer buffer = RequestTestUtils.serializeResponseWithHeader(v5Response, version, 0); + + ResponseHeader.parse(buffer, ApiKeys.PRODUCE.responseHeaderVersion(version)); // throw away. 
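+ // serializeResponseWithHeader writes the response header followed by the body, so the header has to be consumed (and discarded) here before the body can be parsed below.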
+ ProduceResponse v5FromBytes = (ProduceResponse) AbstractResponse.parseResponse(ApiKeys.PRODUCE, buffer, version); + + assertEquals(1, v5FromBytes.data().responses().size()); + ProduceResponseData.TopicProduceResponse topicProduceResponse = v5FromBytes.data().responses().iterator().next(); + assertEquals(1, topicProduceResponse.partitionResponses().size()); + ProduceResponseData.PartitionProduceResponse partitionProduceResponse = topicProduceResponse.partitionResponses().iterator().next(); + TopicPartition tp = new TopicPartition(topicProduceResponse.name(), partitionProduceResponse.index()); + assertEquals(tp0, tp); + + assertEquals(100, partitionProduceResponse.logStartOffset()); + assertEquals(10000, partitionProduceResponse.baseOffset()); + assertEquals(RecordBatch.NO_TIMESTAMP, partitionProduceResponse.logAppendTimeMs()); + assertEquals(Errors.NONE, Errors.forCode(partitionProduceResponse.errorCode())); + assertNull(partitionProduceResponse.errorMessage()); + assertTrue(partitionProduceResponse.recordErrors().isEmpty()); + } + + @SuppressWarnings("deprecation") + @Test + public void produceResponseVersionTest() { + Map<TopicPartition, ProduceResponse.PartitionResponse> responseData = new HashMap<>(); + responseData.put(new TopicPartition("test", 0), new ProduceResponse.PartitionResponse(Errors.NONE, 10000, RecordBatch.NO_TIMESTAMP, 100)); + ProduceResponse v0Response = new ProduceResponse(responseData); + ProduceResponse v1Response = new ProduceResponse(responseData, 10); + ProduceResponse v2Response = new ProduceResponse(responseData, 10); + assertEquals(0, v0Response.throttleTimeMs(), "Throttle time must be zero"); + assertEquals(10, v1Response.throttleTimeMs(), "Throttle time must be 10"); + assertEquals(10, v2Response.throttleTimeMs(), "Throttle time must be 10"); + + List<ProduceResponse> arrResponse = Arrays.asList(v0Response, v1Response, v2Response); + for (ProduceResponse produceResponse : arrResponse) { + assertEquals(1, produceResponse.data().responses().size()); + ProduceResponseData.TopicProduceResponse topicProduceResponse = produceResponse.data().responses().iterator().next(); + assertEquals(1, topicProduceResponse.partitionResponses().size()); + ProduceResponseData.PartitionProduceResponse partitionProduceResponse = topicProduceResponse.partitionResponses().iterator().next(); + assertEquals(100, partitionProduceResponse.logStartOffset()); + assertEquals(10000, partitionProduceResponse.baseOffset()); + assertEquals(RecordBatch.NO_TIMESTAMP, partitionProduceResponse.logAppendTimeMs()); + assertEquals(Errors.NONE, Errors.forCode(partitionProduceResponse.errorCode())); + assertNull(partitionProduceResponse.errorMessage()); + assertTrue(partitionProduceResponse.recordErrors().isEmpty()); + } + } + + @SuppressWarnings("deprecation") + @Test + public void produceResponseRecordErrorsTest() { + Map<TopicPartition, ProduceResponse.PartitionResponse> responseData = new HashMap<>(); + TopicPartition tp = new TopicPartition("test", 0); + ProduceResponse.PartitionResponse partResponse = new ProduceResponse.PartitionResponse(Errors.NONE, + 10000, RecordBatch.NO_TIMESTAMP, 100, + Collections.singletonList(new ProduceResponse.RecordError(3, "Record error")), + "Produce failed"); + responseData.put(tp, partResponse); + + for (short version : PRODUCE.allVersions()) { + ProduceResponse response = new ProduceResponse(responseData); + + ProduceResponse produceResponse = ProduceResponse.parse(response.serialize(version), version); + ProduceResponseData.TopicProduceResponse topicProduceResponse = produceResponse.data().responses().iterator().next(); + ProduceResponseData.PartitionProduceResponse deserialized
= topicProduceResponse.partitionResponses().iterator().next(); + if (version >= 8) { + assertEquals(1, deserialized.recordErrors().size()); + assertEquals(3, deserialized.recordErrors().get(0).batchIndex()); + assertEquals("Record error", deserialized.recordErrors().get(0).batchIndexErrorMessage()); + assertEquals("Produce failed", deserialized.errorMessage()); + } else { + assertEquals(0, deserialized.recordErrors().size()); + assertNull(deserialized.errorMessage()); + } + } + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/requests/RequestContextTest.java b/clients/src/test/java/org/apache/kafka/common/requests/RequestContextTest.java new file mode 100644 index 0000000..4415ff9 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/requests/RequestContextTest.java @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.message.ApiVersionsResponseData; +import org.apache.kafka.common.message.ApiVersionsResponseData.ApiVersionCollection; +import org.apache.kafka.common.message.CreateTopicsResponseData; +import org.apache.kafka.common.network.ClientInformation; +import org.apache.kafka.common.network.ListenerName; +import org.apache.kafka.common.network.Send; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.Errors; +import org.apache.kafka.common.security.auth.KafkaPrincipal; +import org.apache.kafka.common.security.auth.SecurityProtocol; +import org.junit.jupiter.api.Test; + +import java.net.InetAddress; +import java.nio.ByteBuffer; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class RequestContextTest { + + @Test + public void testSerdeUnsupportedApiVersionRequest() throws Exception { + int correlationId = 23423; + + RequestHeader header = new RequestHeader(ApiKeys.API_VERSIONS, Short.MAX_VALUE, "", correlationId); + RequestContext context = new RequestContext(header, "0", InetAddress.getLocalHost(), KafkaPrincipal.ANONYMOUS, + new ListenerName("ssl"), SecurityProtocol.SASL_SSL, ClientInformation.EMPTY, false); + assertEquals(0, context.apiVersion()); + + // Write some garbage to the request buffer. This should be ignored since we will treat + // the unknown version type as v0 which has an empty request body. 
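+ // The two ints written below are arbitrary placeholder values; nothing here should ever be read while parsing.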
+ ByteBuffer requestBuffer = ByteBuffer.allocate(8); + requestBuffer.putInt(3709234); + requestBuffer.putInt(29034); + requestBuffer.flip(); + + RequestAndSize requestAndSize = context.parseRequest(requestBuffer); + assertTrue(requestAndSize.request instanceof ApiVersionsRequest); + ApiVersionsRequest request = (ApiVersionsRequest) requestAndSize.request; + assertTrue(request.hasUnsupportedRequestVersion()); + + Send send = context.buildResponseSend(new ApiVersionsResponse(new ApiVersionsResponseData() + .setThrottleTimeMs(0) + .setErrorCode(Errors.UNSUPPORTED_VERSION.code()) + .setApiKeys(new ApiVersionCollection()))); + ByteBufferChannel channel = new ByteBufferChannel(256); + send.writeTo(channel); + + ByteBuffer responseBuffer = channel.buffer(); + responseBuffer.flip(); + responseBuffer.getInt(); // strip off the size + + ResponseHeader responseHeader = ResponseHeader.parse(responseBuffer, + ApiKeys.API_VERSIONS.responseHeaderVersion(header.apiVersion())); + assertEquals(correlationId, responseHeader.correlationId()); + + ApiVersionsResponse response = (ApiVersionsResponse) AbstractResponse.parseResponse(ApiKeys.API_VERSIONS, + responseBuffer, (short) 0); + assertEquals(Errors.UNSUPPORTED_VERSION.code(), response.data().errorCode()); + assertTrue(response.data().apiKeys().isEmpty()); + } + + @Test + public void testEnvelopeResponseSerde() throws Exception { + CreateTopicsResponseData.CreatableTopicResultCollection collection = + new CreateTopicsResponseData.CreatableTopicResultCollection(); + collection.add(new CreateTopicsResponseData.CreatableTopicResult() + .setTopicConfigErrorCode(Errors.CLUSTER_AUTHORIZATION_FAILED.code()) + .setNumPartitions(5)); + CreateTopicsResponseData expectedResponse = new CreateTopicsResponseData() + .setThrottleTimeMs(10) + .setTopics(collection); + + int correlationId = 15; + String clientId = "clientId"; + RequestHeader header = new RequestHeader(ApiKeys.CREATE_TOPICS, ApiKeys.CREATE_TOPICS.latestVersion(), + clientId, correlationId); + + RequestContext context = new RequestContext(header, "0", InetAddress.getLocalHost(), + KafkaPrincipal.ANONYMOUS, new ListenerName("ssl"), SecurityProtocol.SASL_SSL, + ClientInformation.EMPTY, true); + + ByteBuffer buffer = context.buildResponseEnvelopePayload(new CreateTopicsResponse(expectedResponse)); + assertEquals(buffer.capacity(), buffer.limit(), "Buffer limit and capacity should be the same"); + CreateTopicsResponse parsedResponse = (CreateTopicsResponse) AbstractResponse.parseResponse(buffer, header); + assertEquals(expectedResponse, parsedResponse.data()); + } + +} diff --git a/clients/src/test/java/org/apache/kafka/common/requests/RequestHeaderTest.java b/clients/src/test/java/org/apache/kafka/common/requests/RequestHeaderTest.java new file mode 100644 index 0000000..e4f1995 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/requests/RequestHeaderTest.java @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.message.RequestHeaderData; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.protocol.ObjectSerializationCache; +import org.junit.jupiter.api.Test; + +import java.nio.ByteBuffer; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class RequestHeaderTest { + + @Test + public void testSerdeControlledShutdownV0() { + // Verify that version 0 of controlled shutdown does not include the clientId field + short apiVersion = 0; + int correlationId = 2342; + ByteBuffer rawBuffer = ByteBuffer.allocate(32); + rawBuffer.putShort(ApiKeys.CONTROLLED_SHUTDOWN.id); + rawBuffer.putShort(apiVersion); + rawBuffer.putInt(correlationId); + rawBuffer.flip(); + + RequestHeader deserialized = RequestHeader.parse(rawBuffer); + assertEquals(ApiKeys.CONTROLLED_SHUTDOWN, deserialized.apiKey()); + assertEquals(0, deserialized.apiVersion()); + assertEquals(correlationId, deserialized.correlationId()); + assertEquals("", deserialized.clientId()); + assertEquals(0, deserialized.headerVersion()); + + ByteBuffer serializedBuffer = RequestTestUtils.serializeRequestHeader(deserialized); + + assertEquals(ApiKeys.CONTROLLED_SHUTDOWN.id, serializedBuffer.getShort(0)); + assertEquals(0, serializedBuffer.getShort(2)); + assertEquals(correlationId, serializedBuffer.getInt(4)); + assertEquals(8, serializedBuffer.limit()); + } + + @Test + public void testRequestHeaderV1() { + short apiVersion = 1; + RequestHeader header = new RequestHeader(ApiKeys.FIND_COORDINATOR, apiVersion, "", 10); + assertEquals(1, header.headerVersion()); + + ByteBuffer buffer = RequestTestUtils.serializeRequestHeader(header); + assertEquals(10, buffer.remaining()); + RequestHeader deserialized = RequestHeader.parse(buffer); + assertEquals(header, deserialized); + } + + @Test + public void testRequestHeaderV2() { + short apiVersion = 2; + RequestHeader header = new RequestHeader(ApiKeys.CREATE_DELEGATION_TOKEN, apiVersion, "", 10); + assertEquals(2, header.headerVersion()); + + ByteBuffer buffer = RequestTestUtils.serializeRequestHeader(header); + assertEquals(11, buffer.remaining()); + RequestHeader deserialized = RequestHeader.parse(buffer); + assertEquals(header, deserialized); + } + + @Test + public void parseHeaderFromBufferWithNonZeroPosition() { + ByteBuffer buffer = ByteBuffer.allocate(64); + buffer.position(10); + + RequestHeader header = new RequestHeader(ApiKeys.FIND_COORDINATOR, (short) 1, "", 10); + ObjectSerializationCache serializationCache = new ObjectSerializationCache(); + // size must be called before write to avoid an NPE with the current implementation + header.size(serializationCache); + header.write(buffer, serializationCache); + int limit = buffer.position(); + buffer.position(10); + buffer.limit(limit); + + RequestHeader parsed = RequestHeader.parse(buffer); + assertEquals(header, parsed); + } + + @Test + public void parseHeaderWithNullClientId() { + RequestHeaderData headerData = new RequestHeaderData(). + setClientId(null). 
+ setCorrelationId(123). + setRequestApiKey(ApiKeys.FIND_COORDINATOR.id). + setRequestApiVersion((short) 10); + ObjectSerializationCache serializationCache = new ObjectSerializationCache(); + ByteBuffer buffer = ByteBuffer.allocate(headerData.size(serializationCache, (short) 2)); + headerData.write(new ByteBufferAccessor(buffer), serializationCache, (short) 2); + buffer.flip(); + RequestHeader parsed = RequestHeader.parse(buffer); + assertEquals("", parsed.clientId()); + assertEquals(123, parsed.correlationId()); + assertEquals(ApiKeys.FIND_COORDINATOR, parsed.apiKey()); + assertEquals((short) 10, parsed.apiVersion()); + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/requests/RequestResponseTest.java b/clients/src/test/java/org/apache/kafka/common/requests/RequestResponseTest.java new file mode 100644 index 0000000..70b1e05 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/requests/RequestResponseTest.java @@ -0,0 +1,3086 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.ConsumerGroupState; +import org.apache.kafka.common.ElectionType; +import org.apache.kafka.common.IsolationLevel; +import org.apache.kafka.common.Node; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.TopicIdPartition; +import org.apache.kafka.common.Uuid; +import org.apache.kafka.common.acl.AccessControlEntry; +import org.apache.kafka.common.acl.AccessControlEntryFilter; +import org.apache.kafka.common.acl.AclBinding; +import org.apache.kafka.common.acl.AclBindingFilter; +import org.apache.kafka.common.acl.AclOperation; +import org.apache.kafka.common.acl.AclPermissionType; +import org.apache.kafka.common.config.ConfigResource; +import org.apache.kafka.common.errors.NotCoordinatorException; +import org.apache.kafka.common.errors.NotEnoughReplicasException; +import org.apache.kafka.common.errors.SecurityDisabledException; +import org.apache.kafka.common.errors.UnknownServerException; +import org.apache.kafka.common.errors.UnsupportedVersionException; +import org.apache.kafka.common.message.AddOffsetsToTxnRequestData; +import org.apache.kafka.common.message.AddOffsetsToTxnResponseData; +import org.apache.kafka.common.message.AlterClientQuotasResponseData; +import org.apache.kafka.common.message.AlterConfigsResponseData; +import org.apache.kafka.common.message.AlterPartitionReassignmentsRequestData; +import org.apache.kafka.common.message.AlterPartitionReassignmentsResponseData; +import org.apache.kafka.common.message.AlterReplicaLogDirsRequestData; +import org.apache.kafka.common.message.AlterReplicaLogDirsRequestData.AlterReplicaLogDirTopic; +import org.apache.kafka.common.message.AlterReplicaLogDirsRequestData.AlterReplicaLogDirTopicCollection; +import org.apache.kafka.common.message.AlterReplicaLogDirsResponseData; +import org.apache.kafka.common.message.ApiMessageType; +import org.apache.kafka.common.message.ApiVersionsRequestData; +import org.apache.kafka.common.message.ApiVersionsResponseData; +import org.apache.kafka.common.message.ApiVersionsResponseData.ApiVersion; +import org.apache.kafka.common.message.ApiVersionsResponseData.ApiVersionCollection; +import org.apache.kafka.common.message.BrokerHeartbeatRequestData; +import org.apache.kafka.common.message.BrokerHeartbeatResponseData; +import org.apache.kafka.common.message.BrokerRegistrationRequestData; +import org.apache.kafka.common.message.BrokerRegistrationResponseData; +import org.apache.kafka.common.message.ControlledShutdownRequestData; +import org.apache.kafka.common.message.ControlledShutdownResponseData; +import org.apache.kafka.common.message.ControlledShutdownResponseData.RemainingPartition; +import org.apache.kafka.common.message.ControlledShutdownResponseData.RemainingPartitionCollection; +import org.apache.kafka.common.message.CreateAclsRequestData; +import org.apache.kafka.common.message.CreateAclsResponseData; +import org.apache.kafka.common.message.CreateDelegationTokenRequestData; +import org.apache.kafka.common.message.CreateDelegationTokenRequestData.CreatableRenewers; +import org.apache.kafka.common.message.CreateDelegationTokenResponseData; +import org.apache.kafka.common.message.CreatePartitionsRequestData; +import org.apache.kafka.common.message.CreatePartitionsRequestData.CreatePartitionsAssignment; +import org.apache.kafka.common.message.CreatePartitionsRequestData.CreatePartitionsTopic; +import org.apache.kafka.common.message.CreatePartitionsRequestData.CreatePartitionsTopicCollection; +import 
org.apache.kafka.common.message.CreatePartitionsResponseData; +import org.apache.kafka.common.message.CreatePartitionsResponseData.CreatePartitionsTopicResult; +import org.apache.kafka.common.message.CreateTopicsRequestData; +import org.apache.kafka.common.message.CreateTopicsRequestData.CreatableReplicaAssignment; +import org.apache.kafka.common.message.CreateTopicsRequestData.CreatableTopic; +import org.apache.kafka.common.message.CreateTopicsRequestData.CreatableTopicCollection; +import org.apache.kafka.common.message.CreateTopicsRequestData.CreateableTopicConfig; +import org.apache.kafka.common.message.CreateTopicsResponseData; +import org.apache.kafka.common.message.CreateTopicsResponseData.CreatableTopicConfigs; +import org.apache.kafka.common.message.CreateTopicsResponseData.CreatableTopicResult; +import org.apache.kafka.common.message.DeleteAclsRequestData; +import org.apache.kafka.common.message.DeleteAclsResponseData; +import org.apache.kafka.common.message.DeleteGroupsRequestData; +import org.apache.kafka.common.message.DeleteGroupsResponseData; +import org.apache.kafka.common.message.DeleteGroupsResponseData.DeletableGroupResult; +import org.apache.kafka.common.message.DeleteGroupsResponseData.DeletableGroupResultCollection; +import org.apache.kafka.common.message.DeleteTopicsRequestData; +import org.apache.kafka.common.message.DeleteTopicsResponseData; +import org.apache.kafka.common.message.DeleteTopicsResponseData.DeletableTopicResult; +import org.apache.kafka.common.message.DescribeAclsResponseData; +import org.apache.kafka.common.message.DescribeAclsResponseData.AclDescription; +import org.apache.kafka.common.message.DescribeAclsResponseData.DescribeAclsResource; +import org.apache.kafka.common.message.DescribeClientQuotasResponseData; +import org.apache.kafka.common.message.DescribeClusterRequestData; +import org.apache.kafka.common.message.DescribeClusterResponseData; +import org.apache.kafka.common.message.DescribeClusterResponseData.DescribeClusterBroker; +import org.apache.kafka.common.message.DescribeClusterResponseData.DescribeClusterBrokerCollection; +import org.apache.kafka.common.message.DescribeConfigsRequestData; +import org.apache.kafka.common.message.DescribeConfigsResponseData; +import org.apache.kafka.common.message.DescribeConfigsResponseData.DescribeConfigsResourceResult; +import org.apache.kafka.common.message.DescribeConfigsResponseData.DescribeConfigsResult; +import org.apache.kafka.common.message.DescribeGroupsRequestData; +import org.apache.kafka.common.message.DescribeGroupsResponseData; +import org.apache.kafka.common.message.DescribeGroupsResponseData.DescribedGroup; +import org.apache.kafka.common.message.DescribeProducersRequestData; +import org.apache.kafka.common.message.DescribeProducersResponseData; +import org.apache.kafka.common.message.DescribeTransactionsRequestData; +import org.apache.kafka.common.message.DescribeTransactionsResponseData; +import org.apache.kafka.common.message.ElectLeadersResponseData.PartitionResult; +import org.apache.kafka.common.message.ElectLeadersResponseData.ReplicaElectionResult; +import org.apache.kafka.common.message.EndTxnRequestData; +import org.apache.kafka.common.message.EndTxnResponseData; +import org.apache.kafka.common.message.ExpireDelegationTokenRequestData; +import org.apache.kafka.common.message.ExpireDelegationTokenResponseData; +import org.apache.kafka.common.message.FetchRequestData; +import org.apache.kafka.common.message.FetchResponseData; +import 
org.apache.kafka.common.message.FindCoordinatorRequestData; +import org.apache.kafka.common.message.HeartbeatRequestData; +import org.apache.kafka.common.message.HeartbeatResponseData; +import org.apache.kafka.common.message.IncrementalAlterConfigsRequestData; +import org.apache.kafka.common.message.IncrementalAlterConfigsRequestData.AlterConfigsResource; +import org.apache.kafka.common.message.IncrementalAlterConfigsRequestData.AlterableConfig; +import org.apache.kafka.common.message.IncrementalAlterConfigsResponseData; +import org.apache.kafka.common.message.IncrementalAlterConfigsResponseData.AlterConfigsResourceResponse; +import org.apache.kafka.common.message.InitProducerIdRequestData; +import org.apache.kafka.common.message.InitProducerIdResponseData; +import org.apache.kafka.common.message.JoinGroupRequestData; +import org.apache.kafka.common.message.JoinGroupResponseData; +import org.apache.kafka.common.message.JoinGroupResponseData.JoinGroupResponseMember; +import org.apache.kafka.common.message.LeaderAndIsrRequestData.LeaderAndIsrPartitionState; +import org.apache.kafka.common.message.LeaderAndIsrResponseData; +import org.apache.kafka.common.message.LeaderAndIsrResponseData.LeaderAndIsrTopicErrorCollection; +import org.apache.kafka.common.message.LeaveGroupRequestData.MemberIdentity; +import org.apache.kafka.common.message.LeaveGroupResponseData; +import org.apache.kafka.common.message.ListGroupsRequestData; +import org.apache.kafka.common.message.ListGroupsResponseData; +import org.apache.kafka.common.message.ListOffsetsRequestData.ListOffsetsPartition; +import org.apache.kafka.common.message.ListOffsetsRequestData.ListOffsetsTopic; +import org.apache.kafka.common.message.ListOffsetsResponseData; +import org.apache.kafka.common.message.ListOffsetsResponseData.ListOffsetsPartitionResponse; +import org.apache.kafka.common.message.ListOffsetsResponseData.ListOffsetsTopicResponse; +import org.apache.kafka.common.message.ListPartitionReassignmentsRequestData; +import org.apache.kafka.common.message.ListPartitionReassignmentsResponseData; +import org.apache.kafka.common.message.ListTransactionsRequestData; +import org.apache.kafka.common.message.ListTransactionsResponseData; +import org.apache.kafka.common.message.OffsetCommitRequestData; +import org.apache.kafka.common.message.OffsetCommitResponseData; +import org.apache.kafka.common.message.OffsetDeleteRequestData; +import org.apache.kafka.common.message.OffsetDeleteRequestData.OffsetDeleteRequestPartition; +import org.apache.kafka.common.message.OffsetDeleteRequestData.OffsetDeleteRequestTopic; +import org.apache.kafka.common.message.OffsetDeleteRequestData.OffsetDeleteRequestTopicCollection; +import org.apache.kafka.common.message.OffsetDeleteResponseData; +import org.apache.kafka.common.message.OffsetDeleteResponseData.OffsetDeleteResponsePartition; +import org.apache.kafka.common.message.OffsetDeleteResponseData.OffsetDeleteResponsePartitionCollection; +import org.apache.kafka.common.message.OffsetDeleteResponseData.OffsetDeleteResponseTopic; +import org.apache.kafka.common.message.OffsetDeleteResponseData.OffsetDeleteResponseTopicCollection; +import org.apache.kafka.common.message.OffsetForLeaderEpochRequestData.OffsetForLeaderPartition; +import org.apache.kafka.common.message.OffsetForLeaderEpochRequestData.OffsetForLeaderTopic; +import org.apache.kafka.common.message.OffsetForLeaderEpochRequestData.OffsetForLeaderTopicCollection; +import org.apache.kafka.common.message.OffsetForLeaderEpochResponseData; +import 
org.apache.kafka.common.message.OffsetForLeaderEpochResponseData.EpochEndOffset; +import org.apache.kafka.common.message.OffsetForLeaderEpochResponseData.OffsetForLeaderTopicResult; +import org.apache.kafka.common.message.ProduceRequestData; +import org.apache.kafka.common.message.ProduceResponseData; +import org.apache.kafka.common.message.RenewDelegationTokenRequestData; +import org.apache.kafka.common.message.RenewDelegationTokenResponseData; +import org.apache.kafka.common.message.SaslAuthenticateRequestData; +import org.apache.kafka.common.message.SaslAuthenticateResponseData; +import org.apache.kafka.common.message.SaslHandshakeRequestData; +import org.apache.kafka.common.message.SaslHandshakeResponseData; +import org.apache.kafka.common.message.StopReplicaRequestData.StopReplicaPartitionState; +import org.apache.kafka.common.message.StopReplicaRequestData.StopReplicaTopicState; +import org.apache.kafka.common.message.StopReplicaResponseData; +import org.apache.kafka.common.message.SyncGroupRequestData; +import org.apache.kafka.common.message.SyncGroupRequestData.SyncGroupRequestAssignment; +import org.apache.kafka.common.message.SyncGroupResponseData; +import org.apache.kafka.common.message.UnregisterBrokerRequestData; +import org.apache.kafka.common.message.UnregisterBrokerResponseData; +import org.apache.kafka.common.message.UpdateMetadataRequestData.UpdateMetadataBroker; +import org.apache.kafka.common.message.UpdateMetadataRequestData.UpdateMetadataEndpoint; +import org.apache.kafka.common.message.UpdateMetadataRequestData.UpdateMetadataPartitionState; +import org.apache.kafka.common.message.UpdateMetadataResponseData; +import org.apache.kafka.common.network.ListenerName; +import org.apache.kafka.common.network.Send; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.protocol.Errors; +import org.apache.kafka.common.protocol.ObjectSerializationCache; +import org.apache.kafka.common.quota.ClientQuotaAlteration; +import org.apache.kafka.common.quota.ClientQuotaEntity; +import org.apache.kafka.common.quota.ClientQuotaFilter; +import org.apache.kafka.common.record.CompressionType; +import org.apache.kafka.common.record.MemoryRecords; +import org.apache.kafka.common.record.RecordBatch; +import org.apache.kafka.common.record.SimpleRecord; +import org.apache.kafka.common.requests.CreateTopicsRequest.Builder; +import org.apache.kafka.common.requests.DescribeConfigsResponse.ConfigType; +import org.apache.kafka.common.requests.FindCoordinatorRequest.CoordinatorType; +import org.apache.kafka.common.requests.FindCoordinatorRequest.NoBatchedFindCoordinatorsException; +import org.apache.kafka.common.resource.PatternType; +import org.apache.kafka.common.resource.ResourcePattern; +import org.apache.kafka.common.resource.ResourcePatternFilter; +import org.apache.kafka.common.resource.ResourceType; +import org.apache.kafka.common.security.auth.KafkaPrincipal; +import org.apache.kafka.common.security.auth.SecurityProtocol; +import org.apache.kafka.common.security.token.delegation.DelegationToken; +import org.apache.kafka.common.security.token.delegation.TokenInformation; +import org.apache.kafka.common.utils.SecurityUtils; +import org.apache.kafka.common.utils.Utils; +import org.junit.jupiter.api.Test; + +import java.nio.BufferUnderflowException; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import 
java.util.LinkedHashMap; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.stream.Collectors; + +import static java.util.Arrays.asList; +import static java.util.Collections.emptyList; +import static java.util.Collections.singletonList; +import static org.apache.kafka.common.protocol.ApiKeys.CREATE_PARTITIONS; +import static org.apache.kafka.common.protocol.ApiKeys.CREATE_TOPICS; +import static org.apache.kafka.common.protocol.ApiKeys.DELETE_TOPICS; +import static org.apache.kafka.common.protocol.ApiKeys.DESCRIBE_CONFIGS; +import static org.apache.kafka.common.protocol.ApiKeys.FETCH; +import static org.apache.kafka.common.protocol.ApiKeys.FIND_COORDINATOR; +import static org.apache.kafka.common.protocol.ApiKeys.JOIN_GROUP; +import static org.apache.kafka.common.protocol.ApiKeys.LEADER_AND_ISR; +import static org.apache.kafka.common.protocol.ApiKeys.LIST_GROUPS; +import static org.apache.kafka.common.protocol.ApiKeys.LIST_OFFSETS; +import static org.apache.kafka.common.protocol.ApiKeys.OFFSET_FETCH; +import static org.apache.kafka.common.protocol.ApiKeys.STOP_REPLICA; +import static org.apache.kafka.common.protocol.ApiKeys.SYNC_GROUP; +import static org.apache.kafka.common.requests.FetchMetadata.INVALID_SESSION_ID; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; + +public class RequestResponseTest { + + // Exception includes a message that we verify is not included in error responses + private final UnknownServerException unknownServerException = new UnknownServerException("secret"); + + @Test + public void testSerialization() throws Exception { + checkRequest(createControlledShutdownRequest(), true); + checkResponse(createControlledShutdownResponse(), 1, true); + checkErrorResponse(createControlledShutdownRequest(), unknownServerException, true); + checkErrorResponse(createControlledShutdownRequest(0), unknownServerException, true); + checkRequest(createFetchRequest(4), true); + checkResponse(createFetchResponse(true), 4, true); + List toForgetTopics = new ArrayList<>(); + toForgetTopics.add(new TopicIdPartition(Uuid.ZERO_UUID, new TopicPartition("foo", 0))); + toForgetTopics.add(new TopicIdPartition(Uuid.ZERO_UUID, new TopicPartition("foo", 2))); + toForgetTopics.add(new TopicIdPartition(Uuid.ZERO_UUID, new TopicPartition("bar", 0))); + checkRequest(createFetchRequest(7, new FetchMetadata(123, 456), toForgetTopics), true); + checkResponse(createFetchResponse(123), 7, true); + checkResponse(createFetchResponse(Errors.FETCH_SESSION_ID_NOT_FOUND, 123), 7, true); + checkErrorResponse(createFetchRequest(7), unknownServerException, true); + checkRequest(createHeartBeatRequest(), true); + checkErrorResponse(createHeartBeatRequest(), unknownServerException, true); + checkResponse(createHeartBeatResponse(), 0, true); + + for (short version : JOIN_GROUP.allVersions()) { + checkRequest(createJoinGroupRequest(version), true); + checkErrorResponse(createJoinGroupRequest(version), unknownServerException, true); + checkResponse(createJoinGroupResponse(version), version, true); + } + + for (short version : SYNC_GROUP.allVersions()) { + 
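+ // Round-trip SyncGroup request, error response and response serialization for every supported version.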
checkRequest(createSyncGroupRequest(version), true); + checkErrorResponse(createSyncGroupRequest(version), unknownServerException, true); + checkResponse(createSyncGroupResponse(version), version, true); + } + + checkRequest(createLeaveGroupRequest(), true); + checkErrorResponse(createLeaveGroupRequest(), unknownServerException, true); + checkResponse(createLeaveGroupResponse(), 0, true); + + for (short version : ApiKeys.LIST_GROUPS.allVersions()) { + checkRequest(createListGroupsRequest(version), false); + checkErrorResponse(createListGroupsRequest(version), unknownServerException, true); + checkResponse(createListGroupsResponse(version), version, true); + } + + checkRequest(createDescribeGroupRequest(), true); + checkErrorResponse(createDescribeGroupRequest(), unknownServerException, true); + checkResponse(createDescribeGroupResponse(), 0, true); + checkRequest(createDeleteGroupsRequest(), true); + checkErrorResponse(createDeleteGroupsRequest(), unknownServerException, true); + checkResponse(createDeleteGroupsResponse(), 0, true); + for (short version : LIST_OFFSETS.allVersions()) { + checkRequest(createListOffsetRequest(version), true); + checkErrorResponse(createListOffsetRequest(version), unknownServerException, true); + checkResponse(createListOffsetResponse(version), version, true); + } + checkRequest(MetadataRequest.Builder.allTopics().build((short) 2), true); + checkRequest(createMetadataRequest(1, Collections.singletonList("topic1")), true); + checkErrorResponse(createMetadataRequest(1, Collections.singletonList("topic1")), unknownServerException, true); + checkResponse(createMetadataResponse(), 2, true); + checkErrorResponse(createMetadataRequest(2, Collections.singletonList("topic1")), unknownServerException, true); + checkResponse(createMetadataResponse(), 3, true); + checkErrorResponse(createMetadataRequest(3, Collections.singletonList("topic1")), unknownServerException, true); + checkResponse(createMetadataResponse(), 4, true); + checkErrorResponse(createMetadataRequest(4, Collections.singletonList("topic1")), unknownServerException, true); + checkRequest(createOffsetFetchRequest(0, false), true); + checkRequest(createOffsetFetchRequest(1, false), true); + checkRequest(createOffsetFetchRequest(2, false), true); + checkRequest(createOffsetFetchRequest(7, true), true); + checkRequest(createOffsetFetchRequest(8, true), true); + checkRequest(createOffsetFetchRequestWithMultipleGroups(8, true), true); + checkRequest(createOffsetFetchRequestWithMultipleGroups(8, false), true); + checkRequest(createOffsetFetchRequestForAllPartition(7, true), true); + checkRequest(createOffsetFetchRequestForAllPartition(8, true), true); + checkErrorResponse(createOffsetFetchRequest(0, false), unknownServerException, true); + checkErrorResponse(createOffsetFetchRequest(1, false), unknownServerException, true); + checkErrorResponse(createOffsetFetchRequest(2, false), unknownServerException, true); + checkErrorResponse(createOffsetFetchRequest(7, true), unknownServerException, true); + checkErrorResponse(createOffsetFetchRequest(8, true), unknownServerException, true); + checkErrorResponse(createOffsetFetchRequestWithMultipleGroups(8, true), unknownServerException, true); + checkErrorResponse(createOffsetFetchRequestForAllPartition(7, true), + new NotCoordinatorException("Not Coordinator"), true); + checkErrorResponse(createOffsetFetchRequestForAllPartition(8, true), + new NotCoordinatorException("Not Coordinator"), true); + checkErrorResponse(createOffsetFetchRequestWithMultipleGroups(8, true), + new 
NotCoordinatorException("Not Coordinator"), true); + checkResponse(createOffsetFetchResponse(0), 0, true); + checkResponse(createOffsetFetchResponse(7), 7, true); + checkResponse(createOffsetFetchResponse(8), 8, true); + checkRequest(createProduceRequest(2), true); + checkErrorResponse(createProduceRequest(2), unknownServerException, true); + checkRequest(createProduceRequest(3), true); + checkErrorResponse(createProduceRequest(3), unknownServerException, true); + checkResponse(createProduceResponse(), 2, true); + checkResponse(createProduceResponseWithErrorMessage(), 8, true); + + for (short version : STOP_REPLICA.allVersions()) { + checkRequest(createStopReplicaRequest(version, true), true); + checkRequest(createStopReplicaRequest(version, false), true); + checkErrorResponse(createStopReplicaRequest(version, true), unknownServerException, true); + checkErrorResponse(createStopReplicaRequest(version, false), unknownServerException, true); + checkResponse(createStopReplicaResponse(), version, true); + } + + for (short version : LEADER_AND_ISR.allVersions()) { + checkRequest(createLeaderAndIsrRequest(version), true); + checkErrorResponse(createLeaderAndIsrRequest(version), unknownServerException, false); + checkResponse(createLeaderAndIsrResponse(version), version, true); + } + + checkRequest(createSaslHandshakeRequest(), true); + checkErrorResponse(createSaslHandshakeRequest(), unknownServerException, true); + checkResponse(createSaslHandshakeResponse(), 0, true); + checkRequest(createSaslAuthenticateRequest(), true); + checkErrorResponse(createSaslAuthenticateRequest(), unknownServerException, true); + checkResponse(createSaslAuthenticateResponse(), 0, true); + checkResponse(createSaslAuthenticateResponse(), 1, true); + + for (short version : CREATE_TOPICS.allVersions()) { + checkRequest(createCreateTopicRequest(version), true); + checkErrorResponse(createCreateTopicRequest(version), unknownServerException, true); + checkResponse(createCreateTopicResponse(), version, true); + } + + for (short version : DELETE_TOPICS.allVersions()) { + checkRequest(createDeleteTopicsRequest(version), true); + checkErrorResponse(createDeleteTopicsRequest(version), unknownServerException, true); + checkResponse(createDeleteTopicsResponse(), version, true); + } + + for (short version : CREATE_PARTITIONS.allVersions()) { + checkRequest(createCreatePartitionsRequest(version), true); + checkRequest(createCreatePartitionsRequestWithAssignments(version), false); + checkErrorResponse(createCreatePartitionsRequest(version), unknownServerException, true); + checkResponse(createCreatePartitionsResponse(), version, true); + } + + checkRequest(createInitPidRequest(), true); + checkErrorResponse(createInitPidRequest(), unknownServerException, true); + checkResponse(createInitPidResponse(), 0, true); + + checkRequest(createAddPartitionsToTxnRequest(), true); + checkResponse(createAddPartitionsToTxnResponse(), 0, true); + checkErrorResponse(createAddPartitionsToTxnRequest(), unknownServerException, true); + checkRequest(createAddOffsetsToTxnRequest(), true); + checkResponse(createAddOffsetsToTxnResponse(), 0, true); + checkErrorResponse(createAddOffsetsToTxnRequest(), unknownServerException, true); + checkRequest(createEndTxnRequest(), true); + checkResponse(createEndTxnResponse(), 0, true); + checkErrorResponse(createEndTxnRequest(), unknownServerException, true); + checkRequest(createWriteTxnMarkersRequest(), true); + checkResponse(createWriteTxnMarkersResponse(), 0, true); + 
checkErrorResponse(createWriteTxnMarkersRequest(), unknownServerException, true); + + checkOlderFetchVersions(); + checkResponse(createMetadataResponse(), 0, true); + checkResponse(createMetadataResponse(), 1, true); + checkErrorResponse(createMetadataRequest(1, Collections.singletonList("topic1")), unknownServerException, true); + checkRequest(createOffsetCommitRequest(0), true); + checkErrorResponse(createOffsetCommitRequest(0), unknownServerException, true); + checkRequest(createOffsetCommitRequest(1), true); + checkErrorResponse(createOffsetCommitRequest(1), unknownServerException, true); + checkRequest(createOffsetCommitRequest(2), true); + checkErrorResponse(createOffsetCommitRequest(2), unknownServerException, true); + checkRequest(createOffsetCommitRequest(3), true); + checkErrorResponse(createOffsetCommitRequest(3), unknownServerException, true); + checkRequest(createOffsetCommitRequest(4), true); + checkErrorResponse(createOffsetCommitRequest(4), unknownServerException, true); + checkResponse(createOffsetCommitResponse(), 4, true); + checkRequest(createOffsetCommitRequest(5), true); + checkErrorResponse(createOffsetCommitRequest(5), unknownServerException, true); + checkResponse(createOffsetCommitResponse(), 5, true); + checkRequest(createJoinGroupRequest(0), true); + checkRequest(createUpdateMetadataRequest(0, null), false); + checkErrorResponse(createUpdateMetadataRequest(0, null), unknownServerException, true); + checkRequest(createUpdateMetadataRequest(1, null), false); + checkRequest(createUpdateMetadataRequest(1, "rack1"), false); + checkErrorResponse(createUpdateMetadataRequest(1, null), unknownServerException, true); + checkRequest(createUpdateMetadataRequest(2, "rack1"), false); + checkRequest(createUpdateMetadataRequest(2, null), false); + checkErrorResponse(createUpdateMetadataRequest(2, "rack1"), unknownServerException, true); + checkRequest(createUpdateMetadataRequest(3, "rack1"), false); + checkRequest(createUpdateMetadataRequest(3, null), false); + checkErrorResponse(createUpdateMetadataRequest(3, "rack1"), unknownServerException, true); + checkRequest(createUpdateMetadataRequest(4, "rack1"), false); + checkRequest(createUpdateMetadataRequest(4, null), false); + checkErrorResponse(createUpdateMetadataRequest(4, "rack1"), unknownServerException, true); + checkRequest(createUpdateMetadataRequest(5, "rack1"), false); + checkRequest(createUpdateMetadataRequest(5, null), false); + checkErrorResponse(createUpdateMetadataRequest(5, "rack1"), unknownServerException, true); + checkResponse(createUpdateMetadataResponse(), 0, true); + checkRequest(createListOffsetRequest(0), true); + checkErrorResponse(createListOffsetRequest(0), unknownServerException, true); + checkResponse(createListOffsetResponse(0), 0, true); + checkRequest(createLeaderEpochRequestForReplica(0, 1), true); + checkRequest(createLeaderEpochRequestForConsumer(), true); + checkResponse(createLeaderEpochResponse(), 0, true); + checkErrorResponse(createLeaderEpochRequestForConsumer(), unknownServerException, true); + checkRequest(createAddPartitionsToTxnRequest(), true); + checkErrorResponse(createAddPartitionsToTxnRequest(), unknownServerException, true); + checkResponse(createAddPartitionsToTxnResponse(), 0, true); + checkRequest(createAddOffsetsToTxnRequest(), true); + checkErrorResponse(createAddOffsetsToTxnRequest(), unknownServerException, true); + checkResponse(createAddOffsetsToTxnResponse(), 0, true); + checkRequest(createEndTxnRequest(), true); + checkErrorResponse(createEndTxnRequest(), 
unknownServerException, true); + checkResponse(createEndTxnResponse(), 0, true); + checkRequest(createWriteTxnMarkersRequest(), true); + checkErrorResponse(createWriteTxnMarkersRequest(), unknownServerException, true); + checkResponse(createWriteTxnMarkersResponse(), 0, true); + checkRequest(createTxnOffsetCommitRequest(0), true); + checkRequest(createTxnOffsetCommitRequest(3), true); + checkRequest(createTxnOffsetCommitRequestWithAutoDowngrade(2), true); + checkErrorResponse(createTxnOffsetCommitRequest(0), unknownServerException, true); + checkErrorResponse(createTxnOffsetCommitRequest(3), unknownServerException, true); + checkErrorResponse(createTxnOffsetCommitRequestWithAutoDowngrade(2), unknownServerException, true); + checkResponse(createTxnOffsetCommitResponse(), 0, true); + checkRequest(createDescribeAclsRequest(), true); + checkErrorResponse(createDescribeAclsRequest(), new SecurityDisabledException("Security is not enabled."), true); + checkResponse(createDescribeAclsResponse(), ApiKeys.DESCRIBE_ACLS.latestVersion(), true); + checkRequest(createCreateAclsRequest(), true); + checkErrorResponse(createCreateAclsRequest(), new SecurityDisabledException("Security is not enabled."), true); + checkResponse(createCreateAclsResponse(), ApiKeys.CREATE_ACLS.latestVersion(), true); + checkRequest(createDeleteAclsRequest(), true); + checkErrorResponse(createDeleteAclsRequest(), new SecurityDisabledException("Security is not enabled."), true); + checkResponse(createDeleteAclsResponse(ApiKeys.DELETE_ACLS.latestVersion()), ApiKeys.DELETE_ACLS.latestVersion(), true); + checkRequest(createAlterConfigsRequest(), false); + checkErrorResponse(createAlterConfigsRequest(), unknownServerException, true); + checkResponse(createAlterConfigsResponse(), 0, false); + checkRequest(createDescribeConfigsRequest(0), true); + checkRequest(createDescribeConfigsRequestWithConfigEntries(0), false); + checkErrorResponse(createDescribeConfigsRequest(0), unknownServerException, true); + checkResponse(createDescribeConfigsResponse((short) 0), 0, false); + checkRequest(createDescribeConfigsRequest(1), true); + checkRequest(createDescribeConfigsRequestWithConfigEntries(1), false); + checkRequest(createDescribeConfigsRequestWithDocumentation(1), false); + checkRequest(createDescribeConfigsRequestWithDocumentation(2), false); + checkRequest(createDescribeConfigsRequestWithDocumentation(3), false); + checkErrorResponse(createDescribeConfigsRequest(1), unknownServerException, true); + checkResponse(createDescribeConfigsResponse((short) 1), 1, false); + checkDescribeConfigsResponseVersions(); + checkRequest(createCreateTokenRequest(), true); + checkErrorResponse(createCreateTokenRequest(), unknownServerException, true); + checkResponse(createCreateTokenResponse(), 0, true); + checkRequest(createDescribeTokenRequest(), true); + checkErrorResponse(createDescribeTokenRequest(), unknownServerException, true); + checkResponse(createDescribeTokenResponse(), 0, true); + checkRequest(createExpireTokenRequest(), true); + checkErrorResponse(createExpireTokenRequest(), unknownServerException, true); + checkResponse(createExpireTokenResponse(), 0, true); + checkRequest(createRenewTokenRequest(), true); + checkErrorResponse(createRenewTokenRequest(), unknownServerException, true); + checkResponse(createRenewTokenResponse(), 0, true); + checkRequest(createElectLeadersRequest(), true); + checkRequest(createElectLeadersRequestNullPartitions(), true); + checkErrorResponse(createElectLeadersRequest(), unknownServerException, true); + 
checkResponse(createElectLeadersResponse(), 1, true); + checkRequest(createIncrementalAlterConfigsRequest(), true); + checkErrorResponse(createIncrementalAlterConfigsRequest(), unknownServerException, true); + checkResponse(createIncrementalAlterConfigsResponse(), 0, true); + checkRequest(createAlterPartitionReassignmentsRequest(), true); + checkErrorResponse(createAlterPartitionReassignmentsRequest(), unknownServerException, true); + checkResponse(createAlterPartitionReassignmentsResponse(), 0, true); + checkRequest(createListPartitionReassignmentsRequest(), true); + checkErrorResponse(createListPartitionReassignmentsRequest(), unknownServerException, true); + checkResponse(createListPartitionReassignmentsResponse(), 0, true); + checkRequest(createOffsetDeleteRequest(), true); + checkErrorResponse(createOffsetDeleteRequest(), unknownServerException, true); + checkResponse(createOffsetDeleteResponse(), 0, true); + checkRequest(createAlterReplicaLogDirsRequest(), true); + checkErrorResponse(createAlterReplicaLogDirsRequest(), unknownServerException, true); + checkResponse(createAlterReplicaLogDirsResponse(), 0, true); + + checkRequest(createDescribeClientQuotasRequest(), true); + checkErrorResponse(createDescribeClientQuotasRequest(), unknownServerException, true); + checkResponse(createDescribeClientQuotasResponse(), 0, true); + checkRequest(createAlterClientQuotasRequest(), true); + checkErrorResponse(createAlterClientQuotasRequest(), unknownServerException, true); + checkResponse(createAlterClientQuotasResponse(), 0, true); + } + + @Test + public void testApiVersionsSerialization() { + for (short version : ApiKeys.API_VERSIONS.allVersions()) { + checkRequest(createApiVersionRequest(version), true); + checkErrorResponse(createApiVersionRequest(version), unknownServerException, true); + checkErrorResponse(createApiVersionRequest(version), new UnsupportedVersionException("Not Supported"), true); + checkResponse(createApiVersionResponse(), version, true); + checkResponse(ApiVersionsResponse.defaultApiVersionsResponse(ApiMessageType.ListenerType.ZK_BROKER), version, true); + } + } + + @Test + public void testBrokerHeartbeatSerialization() { + for (short version : ApiKeys.BROKER_HEARTBEAT.allVersions()) { + checkRequest(createBrokerHeartbeatRequest(version), true); + checkErrorResponse(createBrokerHeartbeatRequest(version), unknownServerException, true); + checkResponse(createBrokerHeartbeatResponse(), version, true); + } + } + + @Test + public void testBrokerRegistrationSerialization() { + for (short version : ApiKeys.BROKER_REGISTRATION.allVersions()) { + checkRequest(createBrokerRegistrationRequest(version), true); + checkErrorResponse(createBrokerRegistrationRequest(version), unknownServerException, true); + checkResponse(createBrokerRegistrationResponse(), 0, true); + } + } + + @Test + public void testDescribeProducersSerialization() { + for (short version : ApiKeys.DESCRIBE_PRODUCERS.allVersions()) { + checkRequest(createDescribeProducersRequest(version), true); + checkErrorResponse(createDescribeProducersRequest(version), unknownServerException, true); + checkResponse(createDescribeProducersResponse(), version, true); + } + } + + @Test + public void testDescribeTransactionsSerialization() { + for (short version : ApiKeys.DESCRIBE_TRANSACTIONS.allVersions()) { + checkRequest(createDescribeTransactionsRequest(version), true); + checkErrorResponse(createDescribeTransactionsRequest(version), unknownServerException, true); + checkResponse(createDescribeTransactionsResponse(), version, true); 
+ } + } + + @Test + public void testListTransactionsSerialization() { + for (short version : ApiKeys.LIST_TRANSACTIONS.allVersions()) { + checkRequest(createListTransactionsRequest(version), true); + checkErrorResponse(createListTransactionsRequest(version), unknownServerException, true); + checkResponse(createListTransactionsResponse(), version, true); + } + } + + @Test + public void testDescribeClusterSerialization() { + for (short version : ApiKeys.DESCRIBE_CLUSTER.allVersions()) { + checkRequest(createDescribeClusterRequest(version), true); + checkErrorResponse(createDescribeClusterRequest(version), unknownServerException, true); + checkResponse(createDescribeClusterResponse(), version, true); + } + } + + @Test + public void testUnregisterBrokerSerialization() { + for (short version : ApiKeys.UNREGISTER_BROKER.allVersions()) { + checkRequest(createUnregisterBrokerRequest(version), true); + checkErrorResponse(createUnregisterBrokerRequest(version), unknownServerException, true); + checkResponse(createUnregisterBrokerResponse(), version, true); + } + } + + @Test + public void testFindCoordinatorRequestSerialization() { + for (short version : ApiKeys.FIND_COORDINATOR.allVersions()) { + checkRequest(createFindCoordinatorRequest(version), true); + checkRequest(createBatchedFindCoordinatorRequest(Collections.singletonList("group1"), version), true); + if (version < FindCoordinatorRequest.MIN_BATCHED_VERSION) { + assertThrows(NoBatchedFindCoordinatorsException.class, () -> + createBatchedFindCoordinatorRequest(Arrays.asList("group1", "group2"), version)); + } else { + checkRequest(createBatchedFindCoordinatorRequest(Arrays.asList("group1", "group2"), version), true); + } + checkErrorResponse(createFindCoordinatorRequest(version), unknownServerException, true); + checkResponse(createFindCoordinatorResponse(version), version, true); + } + } + + private DescribeClusterRequest createDescribeClusterRequest(short version) { + return new DescribeClusterRequest.Builder( + new DescribeClusterRequestData() + .setIncludeClusterAuthorizedOperations(true)) + .build(version); + } + + private DescribeClusterResponse createDescribeClusterResponse() { + return new DescribeClusterResponse( + new DescribeClusterResponseData() + .setBrokers(new DescribeClusterBrokerCollection( + Collections.singletonList(new DescribeClusterBroker() + .setBrokerId(1) + .setHost("localhost") + .setPort(9092) + .setRack("rack1")).iterator())) + .setClusterId("clusterId") + .setControllerId(1) + .setClusterAuthorizedOperations(10)); + } + + @Test + public void testResponseHeader() { + ResponseHeader header = createResponseHeader((short) 1); + ObjectSerializationCache serializationCache = new ObjectSerializationCache(); + ByteBuffer buffer = ByteBuffer.allocate(header.size(serializationCache)); + header.write(buffer, serializationCache); + buffer.flip(); + ResponseHeader deserialized = ResponseHeader.parse(buffer, header.headerVersion()); + assertEquals(header.correlationId(), deserialized.correlationId()); + } + + private void checkOlderFetchVersions() { + for (short version : FETCH.allVersions()) { + if (version > 7) { + checkErrorResponse(createFetchRequest(version), unknownServerException, true); + } + checkRequest(createFetchRequest(version), true); + checkResponse(createFetchResponse(version >= 4), version, true); + } + } + + private void verifyDescribeConfigsResponse(DescribeConfigsResponse expected, DescribeConfigsResponse actual, + int version) { + for (Map.Entry resource : expected.resultMap().entrySet()) { + List 
actualEntries = actual.resultMap().get(resource.getKey()).configs();
+            List<DescribeConfigsResourceResult> expectedEntries = expected.resultMap().get(resource.getKey()).configs();
+            assertEquals(expectedEntries.size(), actualEntries.size());
+            for (int i = 0; i < actualEntries.size(); ++i) {
+                DescribeConfigsResourceResult actualEntry = actualEntries.get(i);
+                DescribeConfigsResourceResult expectedEntry = expectedEntries.get(i);
+                assertEquals(expectedEntry.name(), actualEntry.name());
+                assertEquals(expectedEntry.value(), actualEntry.value(),
+                    "Non-matching values for " + actualEntry.name() + " in version " + version);
+                assertEquals(expectedEntry.readOnly(), actualEntry.readOnly(),
+                    "Non-matching readonly for " + actualEntry.name() + " in version " + version);
+                assertEquals(expectedEntry.isSensitive(), actualEntry.isSensitive(),
+                    "Non-matching isSensitive for " + actualEntry.name() + " in version " + version);
+                if (version < 3) {
+                    assertEquals(ConfigType.UNKNOWN.id(), actualEntry.configType(),
+                        "Non-matching configType for " + actualEntry.name() + " in version " + version);
+                } else {
+                    assertEquals(expectedEntry.configType(), actualEntry.configType(),
+                        "Non-matching configType for " + actualEntry.name() + " in version " + version);
+                }
+                if (version == 0) {
+                    assertEquals(DescribeConfigsResponse.ConfigSource.STATIC_BROKER_CONFIG.id(), actualEntry.configSource(),
+                        "Non-matching configSource for " + actualEntry.name() + " in version " + version);
+                } else {
+                    assertEquals(expectedEntry.configSource(), actualEntry.configSource(),
+                        "Non-matching configSource for " + actualEntry.name() + " in version " + version);
+                }
+            }
+        }
+    }
+
+    private void checkDescribeConfigsResponseVersions() {
+        for (short version : ApiKeys.DESCRIBE_CONFIGS.allVersions()) {
+            DescribeConfigsResponse response = createDescribeConfigsResponse(version);
+            DescribeConfigsResponse deserialized0 = (DescribeConfigsResponse) AbstractResponse.parseResponse(ApiKeys.DESCRIBE_CONFIGS,
+                response.serialize(version), version);
+            verifyDescribeConfigsResponse(response, deserialized0, version);
+        }
+    }
+
+    private void checkErrorResponse(AbstractRequest req, Throwable e, boolean checkEqualityAndHashCode) {
+        AbstractResponse response = req.getErrorResponse(e);
+        checkResponse(response, req.version(), checkEqualityAndHashCode);
+        Errors error = Errors.forException(e);
+        Map<Errors, Integer> errorCounts = response.errorCounts();
+        assertEquals(Collections.singleton(error), errorCounts.keySet(),
+            "API Key " + req.apiKey().name + " v" + req.version() + " failed errorCounts test");
+        assertTrue(errorCounts.get(error) > 0);
+        if (e instanceof UnknownServerException) {
+            String responseStr = response.toString();
+            assertFalse(responseStr.contains(e.getMessage()),
+                String.format("Unknown message included in response for %s: %s ", req.apiKey(), responseStr));
+        }
+    }
+
+    private void checkRequest(AbstractRequest req, boolean checkEquality) {
+        // Check that we can serialize, deserialize and serialize again
+        // Check for equality of the ByteBuffer only if indicated (it is likely to fail if any of the fields
+        // in the request is a HashMap with multiple elements since ordering of the elements may vary)
+        try {
+            ByteBuffer serializedBytes = req.serialize();
+            AbstractRequest deserialized = AbstractRequest.parseRequest(req.apiKey(), req.version(), serializedBytes).request;
+            ByteBuffer serializedBytes2 = deserialized.serialize();
+            serializedBytes.rewind();
+            if (checkEquality)
+                assertEquals(serializedBytes, serializedBytes2, "Request " + req + " failed equality test");
+        }
catch (Exception e) {
+            throw new RuntimeException("Failed to deserialize request " + req + " with type " + req.getClass(), e);
+        }
+    }
+
+    private void checkResponse(AbstractResponse response, int version, boolean checkEquality) {
+        // Check that we can serialize, deserialize and serialize again
+        // Check for equality of the serialized ByteBuffer only if indicated (it is likely to fail if any of the fields
+        // in the response is a HashMap with multiple elements since ordering of the elements may vary)
+        try {
+            ByteBuffer serializedBytes = response.serialize((short) version);
+            AbstractResponse deserialized = AbstractResponse.parseResponse(response.apiKey(), serializedBytes, (short) version);
+            ByteBuffer serializedBytes2 = deserialized.serialize((short) version);
+            serializedBytes.rewind();
+            if (checkEquality)
+                assertEquals(serializedBytes, serializedBytes2, "Response " + response + " failed equality test");
+        } catch (Exception e) {
+            throw new RuntimeException("Failed to deserialize response " + response + " with type " + response.getClass(), e);
+        }
+    }
+
+    @Test
+    public void cannotUseFindCoordinatorV0ToFindTransactionCoordinator() {
+        FindCoordinatorRequest.Builder builder = new FindCoordinatorRequest.Builder(
+            new FindCoordinatorRequestData()
+                .setKeyType(CoordinatorType.TRANSACTION.id)
+                .setKey("foobar"));
+        assertThrows(UnsupportedVersionException.class, () -> builder.build((short) 0));
+    }
+
+    @Test
+    public void testPartitionSize() {
+        TopicPartition tp0 = new TopicPartition("test", 0);
+        TopicPartition tp1 = new TopicPartition("test", 1);
+        MemoryRecords records0 = MemoryRecords.withRecords(RecordBatch.MAGIC_VALUE_V2,
+            CompressionType.NONE, new SimpleRecord("woot".getBytes()));
+        MemoryRecords records1 = MemoryRecords.withRecords(RecordBatch.MAGIC_VALUE_V2,
+            CompressionType.NONE, new SimpleRecord("woot".getBytes()), new SimpleRecord("woot".getBytes()));
+        ProduceRequest request = ProduceRequest.forMagic(RecordBatch.MAGIC_VALUE_V2,
+            new ProduceRequestData()
+                .setTopicData(new ProduceRequestData.TopicProduceDataCollection(Arrays.asList(
+                    new ProduceRequestData.TopicProduceData().setName(tp0.topic()).setPartitionData(
+                        Collections.singletonList(new ProduceRequestData.PartitionProduceData().setIndex(tp0.partition()).setRecords(records0))),
+                    new ProduceRequestData.TopicProduceData().setName(tp1.topic()).setPartitionData(
+                        Collections.singletonList(new ProduceRequestData.PartitionProduceData().setIndex(tp1.partition()).setRecords(records1))))
+                    .iterator()))
+                .setAcks((short) 1)
+                .setTimeoutMs(5000)
+                .setTransactionalId("transactionalId"))
+            .build((short) 3);
+        assertEquals(2, request.partitionSizes().size());
+        assertEquals(records0.sizeInBytes(), (int) request.partitionSizes().get(tp0));
+        assertEquals(records1.sizeInBytes(), (int) request.partitionSizes().get(tp1));
+    }
+
+    @Test
+    public void produceRequestToStringTest() {
+        ProduceRequest request = createProduceRequest(ApiKeys.PRODUCE.latestVersion());
+        assertEquals(1, request.data().topicData().size());
+        assertFalse(request.toString(false).contains("partitionSizes"));
+        assertTrue(request.toString(false).contains("numPartitions=1"));
+        assertTrue(request.toString(true).contains("partitionSizes"));
+        assertFalse(request.toString(true).contains("numPartitions"));
+
+        request.clearPartitionRecords();
+        try {
+            request.data();
+            fail("data() should fail after clearPartitionRecords()");
+        } catch (IllegalStateException e) {
+            // OK
+        }
+
+        // `toString` should behave the same after `clearPartitionRecords`
+ assertFalse(request.toString(false).contains("partitionSizes")); + assertTrue(request.toString(false).contains("numPartitions=1")); + assertTrue(request.toString(true).contains("partitionSizes")); + assertFalse(request.toString(true).contains("numPartitions")); + } + + @Test + public void produceRequestGetErrorResponseTest() { + ProduceRequest request = createProduceRequest(ApiKeys.PRODUCE.latestVersion()); + + ProduceResponse errorResponse = (ProduceResponse) request.getErrorResponse(new NotEnoughReplicasException()); + ProduceResponseData.TopicProduceResponse topicProduceResponse = errorResponse.data().responses().iterator().next(); + ProduceResponseData.PartitionProduceResponse partitionProduceResponse = topicProduceResponse.partitionResponses().iterator().next(); + + assertEquals(Errors.NOT_ENOUGH_REPLICAS, Errors.forCode(partitionProduceResponse.errorCode())); + assertEquals(ProduceResponse.INVALID_OFFSET, partitionProduceResponse.baseOffset()); + assertEquals(RecordBatch.NO_TIMESTAMP, partitionProduceResponse.logAppendTimeMs()); + + request.clearPartitionRecords(); + + // `getErrorResponse` should behave the same after `clearPartitionRecords` + errorResponse = (ProduceResponse) request.getErrorResponse(new NotEnoughReplicasException()); + topicProduceResponse = errorResponse.data().responses().iterator().next(); + partitionProduceResponse = topicProduceResponse.partitionResponses().iterator().next(); + + assertEquals(Errors.NOT_ENOUGH_REPLICAS, Errors.forCode(partitionProduceResponse.errorCode())); + assertEquals(ProduceResponse.INVALID_OFFSET, partitionProduceResponse.baseOffset()); + assertEquals(RecordBatch.NO_TIMESTAMP, partitionProduceResponse.logAppendTimeMs()); + } + + @Test + public void fetchResponseVersionTest() { + LinkedHashMap responseData = new LinkedHashMap<>(); + Uuid id = Uuid.randomUuid(); + Map topicNames = Collections.singletonMap(id, "test"); + TopicPartition tp = new TopicPartition("test", 0); + + MemoryRecords records = MemoryRecords.readableRecords(ByteBuffer.allocate(10)); + FetchResponseData.PartitionData partitionData = new FetchResponseData.PartitionData() + .setPartitionIndex(0) + .setHighWatermark(1000000) + .setLogStartOffset(-1) + .setRecords(records); + + // Use zero UUID since we are comparing with old request versions + responseData.put(new TopicIdPartition(Uuid.ZERO_UUID, tp), partitionData); + + LinkedHashMap tpResponseData = new LinkedHashMap<>(); + tpResponseData.put(tp, partitionData); + + FetchResponse v0Response = FetchResponse.of(Errors.NONE, 0, INVALID_SESSION_ID, responseData); + FetchResponse v1Response = FetchResponse.of(Errors.NONE, 10, INVALID_SESSION_ID, responseData); + FetchResponse v0Deserialized = FetchResponse.parse(v0Response.serialize((short) 0), (short) 0); + FetchResponse v1Deserialized = FetchResponse.parse(v1Response.serialize((short) 1), (short) 1); + assertEquals(0, v0Deserialized.throttleTimeMs(), "Throttle time must be zero"); + assertEquals(10, v1Deserialized.throttleTimeMs(), "Throttle time must be 10"); + assertEquals(tpResponseData, v0Deserialized.responseData(topicNames, (short) 0), "Response data does not match"); + assertEquals(tpResponseData, v1Deserialized.responseData(topicNames, (short) 1), "Response data does not match"); + + LinkedHashMap idResponseData = new LinkedHashMap<>(); + idResponseData.put(new TopicIdPartition(id, new TopicPartition("test", 0)), + new FetchResponseData.PartitionData() + .setPartitionIndex(0) + .setHighWatermark(1000000) + .setLogStartOffset(-1) + .setRecords(records)); + 
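+        // Serialize the same response at v12 and at the latest FETCH version; the assertions below
+        // expect the topic ID to be retained only by the newer serialization.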
FetchResponse idTestResponse = FetchResponse.of(Errors.NONE, 0, INVALID_SESSION_ID, idResponseData); + FetchResponse v12Deserialized = FetchResponse.parse(idTestResponse.serialize((short) 12), (short) 12); + FetchResponse newestDeserialized = FetchResponse.parse(idTestResponse.serialize(FETCH.latestVersion()), FETCH.latestVersion()); + assertTrue(v12Deserialized.topicIds().isEmpty()); + assertEquals(1, newestDeserialized.topicIds().size()); + assertTrue(newestDeserialized.topicIds().contains(id)); + } + + @Test + public void testFetchResponseV4() { + LinkedHashMap responseData = new LinkedHashMap<>(); + Map topicNames = new HashMap<>(); + topicNames.put(Uuid.randomUuid(), "bar"); + topicNames.put(Uuid.randomUuid(), "foo"); + MemoryRecords records = MemoryRecords.readableRecords(ByteBuffer.allocate(10)); + + List abortedTransactions = asList( + new FetchResponseData.AbortedTransaction().setProducerId(10).setFirstOffset(100), + new FetchResponseData.AbortedTransaction().setProducerId(15).setFirstOffset(50) + ); + + // Use zero UUID since this is an old request version. + responseData.put(new TopicIdPartition(Uuid.ZERO_UUID, new TopicPartition("bar", 0)), + new FetchResponseData.PartitionData() + .setPartitionIndex(0) + .setHighWatermark(1000000) + .setAbortedTransactions(abortedTransactions) + .setRecords(records)); + responseData.put(new TopicIdPartition(Uuid.ZERO_UUID, new TopicPartition("bar", 1)), + new FetchResponseData.PartitionData() + .setPartitionIndex(1) + .setHighWatermark(900000) + .setLastStableOffset(5) + .setRecords(records)); + responseData.put(new TopicIdPartition(Uuid.ZERO_UUID, new TopicPartition("foo", 0)), + new FetchResponseData.PartitionData() + .setPartitionIndex(0) + .setHighWatermark(70000) + .setLastStableOffset(6) + .setRecords(records)); + + FetchResponse response = FetchResponse.of(Errors.NONE, 10, INVALID_SESSION_ID, responseData); + FetchResponse deserialized = FetchResponse.parse(response.serialize((short) 4), (short) 4); + assertEquals(responseData.entrySet().stream().collect(Collectors.toMap(e -> e.getKey().topicPartition(), Map.Entry::getValue)), + deserialized.responseData(topicNames, (short) 4)); + } + + @Test + public void verifyFetchResponseFullWrites() throws Exception { + verifyFetchResponseFullWrite(FETCH.latestVersion(), createFetchResponse(123)); + verifyFetchResponseFullWrite(FETCH.latestVersion(), + createFetchResponse(Errors.FETCH_SESSION_ID_NOT_FOUND, 123)); + for (short version : FETCH.allVersions()) { + verifyFetchResponseFullWrite(version, createFetchResponse(version >= 4)); + } + } + + private void verifyFetchResponseFullWrite(short apiVersion, FetchResponse fetchResponse) throws Exception { + int correlationId = 15; + + short responseHeaderVersion = FETCH.responseHeaderVersion(apiVersion); + Send send = fetchResponse.toSend(new ResponseHeader(correlationId, responseHeaderVersion), apiVersion); + ByteBufferChannel channel = new ByteBufferChannel(send.size()); + send.writeTo(channel); + channel.close(); + + ByteBuffer buf = channel.buffer(); + + // read the size + int size = buf.getInt(); + assertTrue(size > 0); + + // read the header + ResponseHeader responseHeader = ResponseHeader.parse(channel.buffer(), responseHeaderVersion); + assertEquals(correlationId, responseHeader.correlationId()); + + assertEquals(fetchResponse.serialize(apiVersion), buf); + FetchResponseData deserialized = new FetchResponseData(new ByteBufferAccessor(buf), apiVersion); + ObjectSerializationCache serializationCache = new ObjectSerializationCache(); + 
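+        // The size prefix read from the channel should equal the response header size plus the body size.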
assertEquals(size, responseHeader.size(serializationCache) + deserialized.size(serializationCache, apiVersion)); + } + + @Test + public void testControlledShutdownResponse() { + ControlledShutdownResponse response = createControlledShutdownResponse(); + short version = ApiKeys.CONTROLLED_SHUTDOWN.latestVersion(); + ByteBuffer buffer = response.serialize(version); + ControlledShutdownResponse deserialized = ControlledShutdownResponse.parse(buffer, version); + assertEquals(response.error(), deserialized.error()); + assertEquals(response.data().remainingPartitions(), deserialized.data().remainingPartitions()); + } + + @Test + public void testCreateTopicRequestV0FailsIfValidateOnly() { + assertThrows(UnsupportedVersionException.class, + () -> createCreateTopicRequest(0, true)); + } + + @Test + public void testCreateTopicRequestV3FailsIfNoPartitionsOrReplicas() { + final UnsupportedVersionException exception = assertThrows( + UnsupportedVersionException.class, () -> { + CreateTopicsRequestData data = new CreateTopicsRequestData() + .setTimeoutMs(123) + .setValidateOnly(false); + data.topics().add(new CreatableTopic(). + setName("foo"). + setNumPartitions(CreateTopicsRequest.NO_NUM_PARTITIONS). + setReplicationFactor((short) 1)); + data.topics().add(new CreatableTopic(). + setName("bar"). + setNumPartitions(1). + setReplicationFactor(CreateTopicsRequest.NO_REPLICATION_FACTOR)); + + new Builder(data).build((short) 3); + }); + assertTrue(exception.getMessage().contains("supported in CreateTopicRequest version 4+")); + assertTrue(exception.getMessage().contains("[foo, bar]")); + } + + @Test + public void testFetchRequestMaxBytesOldVersions() { + final short version = 1; + FetchRequest fr = createFetchRequest(version); + FetchRequest fr2 = FetchRequest.parse(fr.serialize(), version); + assertEquals(fr2.maxBytes(), fr.maxBytes()); + } + + @Test + public void testFetchRequestIsolationLevel() throws Exception { + FetchRequest request = createFetchRequest(4, IsolationLevel.READ_COMMITTED); + FetchRequest deserialized = (FetchRequest) AbstractRequest.parseRequest(request.apiKey(), request.version(), + request.serialize()).request; + assertEquals(request.isolationLevel(), deserialized.isolationLevel()); + + request = createFetchRequest(4, IsolationLevel.READ_UNCOMMITTED); + deserialized = (FetchRequest) AbstractRequest.parseRequest(request.apiKey(), request.version(), + request.serialize()).request; + assertEquals(request.isolationLevel(), deserialized.isolationLevel()); + } + + @Test + public void testFetchRequestWithMetadata() throws Exception { + FetchRequest request = createFetchRequest(4, IsolationLevel.READ_COMMITTED); + FetchRequest deserialized = (FetchRequest) AbstractRequest.parseRequest(ApiKeys.FETCH, request.version(), + request.serialize()).request; + assertEquals(request.isolationLevel(), deserialized.isolationLevel()); + + request = createFetchRequest(4, IsolationLevel.READ_UNCOMMITTED); + deserialized = (FetchRequest) AbstractRequest.parseRequest(ApiKeys.FETCH, request.version(), + request.serialize()).request; + assertEquals(request.isolationLevel(), deserialized.isolationLevel()); + } + + @Test + public void testFetchRequestCompat() { + Map fetchData = new HashMap<>(); + fetchData.put(new TopicPartition("test", 0), new FetchRequest.PartitionData(Uuid.ZERO_UUID, 100, 2, 100, Optional.of(42))); + FetchRequest req = FetchRequest.Builder + .forConsumer((short) 2, 100, 100, fetchData) + .metadata(new FetchMetadata(10, 20)) + .isolationLevel(IsolationLevel.READ_COMMITTED) + .build((short) 2); 
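+        // Sizing and writing the request data at version 2 verifies that an old-format fetch request
+        // can still be serialized without error.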
+ + FetchRequestData data = req.data(); + ObjectSerializationCache cache = new ObjectSerializationCache(); + int size = data.size(cache, (short) 2); + + ByteBufferAccessor writer = new ByteBufferAccessor(ByteBuffer.allocate(size)); + data.write(writer, cache, (short) 2); + } + + @Test + public void testSerializeWithHeader() { + CreatableTopicCollection topicsToCreate = new CreatableTopicCollection(1); + topicsToCreate.add(new CreatableTopic() + .setName("topic") + .setNumPartitions(3) + .setReplicationFactor((short) 2)); + + CreateTopicsRequest createTopicsRequest = new CreateTopicsRequest.Builder( + new CreateTopicsRequestData() + .setTimeoutMs(10) + .setTopics(topicsToCreate) + ).build(); + + short requestVersion = ApiKeys.CREATE_TOPICS.latestVersion(); + RequestHeader requestHeader = new RequestHeader(ApiKeys.CREATE_TOPICS, requestVersion, "client", 2); + ByteBuffer serializedRequest = createTopicsRequest.serializeWithHeader(requestHeader); + + RequestHeader parsedHeader = RequestHeader.parse(serializedRequest); + assertEquals(requestHeader, parsedHeader); + + RequestAndSize parsedRequest = AbstractRequest.parseRequest( + ApiKeys.CREATE_TOPICS, requestVersion, serializedRequest); + + assertEquals(createTopicsRequest.data(), parsedRequest.request.data()); + } + + @Test + public void testSerializeWithInconsistentHeaderApiKey() { + CreateTopicsRequest createTopicsRequest = new CreateTopicsRequest.Builder( + new CreateTopicsRequestData() + ).build(); + short requestVersion = ApiKeys.CREATE_TOPICS.latestVersion(); + RequestHeader requestHeader = new RequestHeader(DELETE_TOPICS, requestVersion, "client", 2); + assertThrows(IllegalArgumentException.class, () -> createTopicsRequest.serializeWithHeader(requestHeader)); + } + + @Test + public void testSerializeWithInconsistentHeaderVersion() { + CreateTopicsRequest createTopicsRequest = new CreateTopicsRequest.Builder( + new CreateTopicsRequestData() + ).build((short) 2); + RequestHeader requestHeader = new RequestHeader(CREATE_TOPICS, (short) 1, "client", 2); + assertThrows(IllegalArgumentException.class, () -> createTopicsRequest.serializeWithHeader(requestHeader)); + } + + @Test + public void testJoinGroupRequestVersion0RebalanceTimeout() { + final short version = 0; + JoinGroupRequest jgr = createJoinGroupRequest(version); + JoinGroupRequest jgr2 = JoinGroupRequest.parse(jgr.serialize(), version); + assertEquals(jgr2.data().rebalanceTimeoutMs(), jgr.data().rebalanceTimeoutMs()); + } + + @Test + public void testOffsetFetchRequestBuilderToStringV0ToV7() { + List stableFlags = Arrays.asList(true, false); + for (Boolean requireStable : stableFlags) { + String allTopicPartitionsString = new OffsetFetchRequest.Builder("someGroup", + requireStable, + null, + false) + .toString(); + + assertTrue(allTopicPartitionsString.contains("groupId='someGroup', topics=null," + + " groups=[], requireStable=" + requireStable)); + String string = new OffsetFetchRequest.Builder("group1", + requireStable, + Collections.singletonList( + new TopicPartition("test11", 1)), + false) + .toString(); + assertTrue(string.contains("test11")); + assertTrue(string.contains("group1")); + assertTrue(string.contains("requireStable=" + requireStable)); + } + } + + @Test + public void testOffsetFetchRequestBuilderToStringV8AndAbove() { + List stableFlags = Arrays.asList(true, false); + for (Boolean requireStable : stableFlags) { + String allTopicPartitionsString = new OffsetFetchRequest.Builder( + Collections.singletonMap("someGroup", null), + requireStable, + false) + .toString(); 
+ assertTrue(allTopicPartitionsString.contains("groups=[OffsetFetchRequestGroup" + + "(groupId='someGroup', topics=null)], requireStable=" + requireStable)); + + String subsetTopicPartitionsString = new OffsetFetchRequest.Builder( + Collections.singletonMap( + "group1", + Collections.singletonList(new TopicPartition("test11", 1))), + requireStable, + false) + .toString(); + assertTrue(subsetTopicPartitionsString.contains("test11")); + assertTrue(subsetTopicPartitionsString.contains("group1")); + assertTrue(subsetTopicPartitionsString.contains("requireStable=" + requireStable)); + } + } + + @Test + public void testApiVersionsRequestBeforeV3Validation() { + for (short version = 0; version < 3; version++) { + ApiVersionsRequest request = new ApiVersionsRequest(new ApiVersionsRequestData(), version); + assertTrue(request.isValid()); + } + } + + @Test + public void testValidApiVersionsRequest() { + ApiVersionsRequest request; + + request = new ApiVersionsRequest.Builder().build(); + assertTrue(request.isValid()); + + request = new ApiVersionsRequest(new ApiVersionsRequestData() + .setClientSoftwareName("apache-kafka.java") + .setClientSoftwareVersion("0.0.0-SNAPSHOT"), + ApiKeys.API_VERSIONS.latestVersion() + ); + assertTrue(request.isValid()); + } + + @Test + public void testListGroupRequestV3FailsWithStates() { + ListGroupsRequestData data = new ListGroupsRequestData() + .setStatesFilter(asList(ConsumerGroupState.STABLE.name())); + assertThrows(UnsupportedVersionException.class, () -> new ListGroupsRequest.Builder(data).build((short) 3)); + } + + @Test + public void testInvalidApiVersionsRequest() { + testInvalidCase("java@apache_kafka", "0.0.0-SNAPSHOT"); + testInvalidCase("apache-kafka-java", "0.0.0@java"); + testInvalidCase("-apache-kafka-java", "0.0.0"); + testInvalidCase("apache-kafka-java.", "0.0.0"); + } + + private void testInvalidCase(String name, String version) { + ApiVersionsRequest request = new ApiVersionsRequest(new ApiVersionsRequestData() + .setClientSoftwareName(name) + .setClientSoftwareVersion(version), + ApiKeys.API_VERSIONS.latestVersion() + ); + assertFalse(request.isValid()); + } + + @Test + public void testApiVersionResponseWithUnsupportedError() { + for (short version : ApiKeys.API_VERSIONS.allVersions()) { + ApiVersionsRequest request = new ApiVersionsRequest.Builder().build(version); + ApiVersionsResponse response = request.getErrorResponse(0, Errors.UNSUPPORTED_VERSION.exception()); + assertEquals(Errors.UNSUPPORTED_VERSION.code(), response.data().errorCode()); + + ApiVersion apiVersion = response.data().apiKeys().find(ApiKeys.API_VERSIONS.id); + assertNotNull(apiVersion); + assertEquals(ApiKeys.API_VERSIONS.id, apiVersion.apiKey()); + assertEquals(ApiKeys.API_VERSIONS.oldestVersion(), apiVersion.minVersion()); + assertEquals(ApiKeys.API_VERSIONS.latestVersion(), apiVersion.maxVersion()); + } + } + + @Test + public void testApiVersionResponseWithNotUnsupportedError() { + for (short version : ApiKeys.API_VERSIONS.allVersions()) { + ApiVersionsRequest request = new ApiVersionsRequest.Builder().build(version); + ApiVersionsResponse response = request.getErrorResponse(0, Errors.INVALID_REQUEST.exception()); + assertEquals(response.data().errorCode(), Errors.INVALID_REQUEST.code()); + assertTrue(response.data().apiKeys().isEmpty()); + } + } + + private ApiVersionsResponse defaultApiVersionsResponse() { + return ApiVersionsResponse.defaultApiVersionsResponse(ApiMessageType.ListenerType.ZK_BROKER); + } + + @Test + public void testApiVersionResponseParsingFallback() { + 
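+        // A response serialized as v0 should still parse successfully when any supported version is
+        // requested (the fallback path named in the test).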
for (short version : ApiKeys.API_VERSIONS.allVersions()) { + ByteBuffer buffer = defaultApiVersionsResponse().serialize((short) 0); + ApiVersionsResponse response = ApiVersionsResponse.parse(buffer, version); + assertEquals(Errors.NONE.code(), response.data().errorCode()); + } + } + + @Test + public void testApiVersionResponseParsingFallbackException() { + for (final short version : ApiKeys.API_VERSIONS.allVersions()) { + assertThrows(BufferUnderflowException.class, () -> ApiVersionsResponse.parse(ByteBuffer.allocate(0), version)); + } + } + + @Test + public void testApiVersionResponseParsing() { + for (short version : ApiKeys.API_VERSIONS.allVersions()) { + ByteBuffer buffer = defaultApiVersionsResponse().serialize(version); + ApiVersionsResponse response = ApiVersionsResponse.parse(buffer, version); + assertEquals(Errors.NONE.code(), response.data().errorCode()); + } + } + + @Test + public void testInitProducerIdRequestVersions() { + InitProducerIdRequest.Builder bld = new InitProducerIdRequest.Builder( + new InitProducerIdRequestData().setTransactionTimeoutMs(1000). + setTransactionalId("abracadabra"). + setProducerId(123)); + final UnsupportedVersionException exception = assertThrows( + UnsupportedVersionException.class, () -> bld.build((short) 2).serialize()); + assertTrue(exception.getMessage().contains("Attempted to write a non-default producerId at version 2")); + bld.build((short) 3); + } + + @Test + public void testDeletableTopicResultErrorMessageIsNullByDefault() { + DeletableTopicResult result = new DeletableTopicResult() + .setName("topic") + .setErrorCode(Errors.THROTTLING_QUOTA_EXCEEDED.code()); + + assertEquals("topic", result.name()); + assertEquals(Errors.THROTTLING_QUOTA_EXCEEDED.code(), result.errorCode()); + assertNull(result.errorMessage()); + } + + private ResponseHeader createResponseHeader(short headerVersion) { + return new ResponseHeader(10, headerVersion); + } + + private FindCoordinatorRequest createFindCoordinatorRequest(int version) { + return new FindCoordinatorRequest.Builder( + new FindCoordinatorRequestData() + .setKeyType(CoordinatorType.GROUP.id()) + .setKey("test-group")) + .build((short) version); + } + + private FindCoordinatorRequest createBatchedFindCoordinatorRequest(List coordinatorKeys, int version) { + return new FindCoordinatorRequest.Builder( + new FindCoordinatorRequestData() + .setKeyType(CoordinatorType.GROUP.id()) + .setCoordinatorKeys(coordinatorKeys)) + .build((short) version); + } + + private FindCoordinatorResponse createFindCoordinatorResponse(short version) { + Node node = new Node(10, "host1", 2014); + if (version < FindCoordinatorRequest.MIN_BATCHED_VERSION) + return FindCoordinatorResponse.prepareOldResponse(Errors.NONE, node); + else + return FindCoordinatorResponse.prepareResponse(Errors.NONE, "group", node); + } + + private FetchRequest createFetchRequest(int version, FetchMetadata metadata, List toForget) { + LinkedHashMap fetchData = new LinkedHashMap<>(); + fetchData.put(new TopicPartition("test1", 0), + new FetchRequest.PartitionData(Uuid.randomUuid(), 100, -1L, 1000000, Optional.empty())); + fetchData.put(new TopicPartition("test2", 0), + new FetchRequest.PartitionData(Uuid.randomUuid(), 200, -1L, 1000000, Optional.empty())); + return FetchRequest.Builder.forConsumer((short) version, 100, 100000, fetchData). 
+ metadata(metadata).setMaxBytes(1000).removed(toForget).build((short) version); + } + + private FetchRequest createFetchRequest(int version, IsolationLevel isolationLevel) { + LinkedHashMap fetchData = new LinkedHashMap<>(); + fetchData.put(new TopicPartition("test1", 0), + new FetchRequest.PartitionData(Uuid.randomUuid(), 100, -1L, 1000000, Optional.empty())); + fetchData.put(new TopicPartition("test2", 0), + new FetchRequest.PartitionData(Uuid.randomUuid(), 200, -1L, 1000000, Optional.empty())); + return FetchRequest.Builder.forConsumer((short) version, 100, 100000, fetchData). + isolationLevel(isolationLevel).setMaxBytes(1000).build((short) version); + } + + private FetchRequest createFetchRequest(int version) { + LinkedHashMap fetchData = new LinkedHashMap<>(); + fetchData.put(new TopicPartition("test1", 0), + new FetchRequest.PartitionData(Uuid.randomUuid(), 100, -1L, 1000000, Optional.empty())); + fetchData.put(new TopicPartition("test2", 0), + new FetchRequest.PartitionData(Uuid.randomUuid(), 200, -1L, 1000000, Optional.empty())); + return FetchRequest.Builder.forConsumer((short) version, 100, 100000, fetchData).setMaxBytes(1000).build((short) version); + } + + private FetchResponse createFetchResponse(Errors error, int sessionId) { + return FetchResponse.parse( + FetchResponse.of(error, 25, sessionId, new LinkedHashMap<>()).serialize(FETCH.latestVersion()), FETCH.latestVersion()); + } + + private FetchResponse createFetchResponse(int sessionId) { + LinkedHashMap responseData = new LinkedHashMap<>(); + Map topicIds = new HashMap<>(); + topicIds.put("test", Uuid.randomUuid()); + MemoryRecords records = MemoryRecords.withRecords(CompressionType.NONE, new SimpleRecord("blah".getBytes())); + responseData.put(new TopicIdPartition(topicIds.get("test"), new TopicPartition("test", 0)), new FetchResponseData.PartitionData() + .setPartitionIndex(0) + .setHighWatermark(1000000) + .setLogStartOffset(0) + .setRecords(records)); + List abortedTransactions = Collections.singletonList( + new FetchResponseData.AbortedTransaction().setProducerId(234L).setFirstOffset(999L)); + responseData.put(new TopicIdPartition(topicIds.get("test"), new TopicPartition("test", 1)), new FetchResponseData.PartitionData() + .setPartitionIndex(1) + .setHighWatermark(1000000) + .setLogStartOffset(0) + .setAbortedTransactions(abortedTransactions)); + return FetchResponse.parse(FetchResponse.of(Errors.NONE, 25, sessionId, + responseData).serialize(FETCH.latestVersion()), FETCH.latestVersion()); + } + + private FetchResponse createFetchResponse(boolean includeAborted) { + LinkedHashMap responseData = new LinkedHashMap<>(); + Uuid topicId = Uuid.randomUuid(); + MemoryRecords records = MemoryRecords.withRecords(CompressionType.NONE, new SimpleRecord("blah".getBytes())); + responseData.put(new TopicIdPartition(topicId, new TopicPartition("test", 0)), new FetchResponseData.PartitionData() + .setPartitionIndex(0) + .setHighWatermark(1000000) + .setLogStartOffset(0) + .setRecords(records)); + + List abortedTransactions = Collections.emptyList(); + if (includeAborted) { + abortedTransactions = Collections.singletonList( + new FetchResponseData.AbortedTransaction().setProducerId(234L).setFirstOffset(999L)); + } + responseData.put(new TopicIdPartition(topicId, new TopicPartition("test", 1)), new FetchResponseData.PartitionData() + .setPartitionIndex(1) + .setHighWatermark(1000000) + .setLogStartOffset(0) + .setAbortedTransactions(abortedTransactions)); + return FetchResponse.parse(FetchResponse.of(Errors.NONE, 25, 
INVALID_SESSION_ID, + responseData).serialize(FETCH.latestVersion()), FETCH.latestVersion()); + } + + private HeartbeatRequest createHeartBeatRequest() { + return new HeartbeatRequest.Builder(new HeartbeatRequestData() + .setGroupId("group1") + .setGenerationId(1) + .setMemberId("consumer1")).build(); + } + + private HeartbeatResponse createHeartBeatResponse() { + return new HeartbeatResponse(new HeartbeatResponseData().setErrorCode(Errors.NONE.code())); + } + + private JoinGroupRequest createJoinGroupRequest(int version) { + JoinGroupRequestData.JoinGroupRequestProtocolCollection protocols = + new JoinGroupRequestData.JoinGroupRequestProtocolCollection( + Collections.singleton( + new JoinGroupRequestData.JoinGroupRequestProtocol() + .setName("consumer-range") + .setMetadata(new byte[0])).iterator() + ); + + JoinGroupRequestData data = new JoinGroupRequestData() + .setGroupId("group1") + .setSessionTimeoutMs(30000) + .setMemberId("consumer1") + .setProtocolType("consumer") + .setProtocols(protocols); + + // v1 and above contains rebalance timeout + if (version >= 1) + data.setRebalanceTimeoutMs(60000); + + // v5 and above could set group instance id + if (version >= 5) + data.setGroupInstanceId("groupInstanceId"); + + return new JoinGroupRequest.Builder(data).build((short) version); + } + + private JoinGroupResponse createJoinGroupResponse(int version) { + List members = new ArrayList<>(); + + for (int i = 0; i < 2; i++) { + JoinGroupResponseMember member = new JoinGroupResponseData.JoinGroupResponseMember() + .setMemberId("consumer" + i) + .setMetadata(new byte[0]); + + if (version >= 5) + member.setGroupInstanceId("instance" + i); + + members.add(member); + } + + JoinGroupResponseData data = new JoinGroupResponseData() + .setErrorCode(Errors.NONE.code()) + .setGenerationId(1) + .setProtocolType("consumer") // Added in v7 but ignorable + .setProtocolName("range") + .setLeader("leader") + .setMemberId("consumer1") + .setMembers(members); + + // v1 and above could set throttle time + if (version >= 1) + data.setThrottleTimeMs(1000); + + return new JoinGroupResponse(data); + } + + private SyncGroupRequest createSyncGroupRequest(int version) { + List assignments = Collections.singletonList( + new SyncGroupRequestAssignment() + .setMemberId("member") + .setAssignment(new byte[0]) + ); + + SyncGroupRequestData data = new SyncGroupRequestData() + .setGroupId("group1") + .setGenerationId(1) + .setMemberId("member") + .setProtocolType("consumer") // Added in v5 but ignorable + .setProtocolName("range") // Added in v5 but ignorable + .setAssignments(assignments); + + // v3 and above could set group instance id + if (version >= 3) + data.setGroupInstanceId("groupInstanceId"); + + return new SyncGroupRequest.Builder(data).build((short) version); + } + + private SyncGroupResponse createSyncGroupResponse(int version) { + SyncGroupResponseData data = new SyncGroupResponseData() + .setErrorCode(Errors.NONE.code()) + .setProtocolType("consumer") // Added in v5 but ignorable + .setProtocolName("range") // Added in v5 but ignorable + .setAssignment(new byte[0]); + + // v1 and above could set throttle time + if (version >= 1) + data.setThrottleTimeMs(1000); + + return new SyncGroupResponse(data); + } + + private ListGroupsRequest createListGroupsRequest(short version) { + ListGroupsRequestData data = new ListGroupsRequestData(); + if (version >= 4) + data.setStatesFilter(Arrays.asList("Stable")); + return new ListGroupsRequest.Builder(data).build(version); + } + + private ListGroupsResponse 
createListGroupsResponse(int version) { + ListGroupsResponseData.ListedGroup group = new ListGroupsResponseData.ListedGroup() + .setGroupId("test-group") + .setProtocolType("consumer"); + if (version >= 4) + group.setGroupState("Stable"); + ListGroupsResponseData data = new ListGroupsResponseData() + .setErrorCode(Errors.NONE.code()) + .setGroups(Collections.singletonList(group)); + return new ListGroupsResponse(data); + } + + private DescribeGroupsRequest createDescribeGroupRequest() { + return new DescribeGroupsRequest.Builder( + new DescribeGroupsRequestData(). + setGroups(Collections.singletonList("test-group"))).build(); + } + + private DescribeGroupsResponse createDescribeGroupResponse() { + String clientId = "consumer-1"; + String clientHost = "localhost"; + DescribeGroupsResponseData describeGroupsResponseData = new DescribeGroupsResponseData(); + DescribeGroupsResponseData.DescribedGroupMember member = DescribeGroupsResponse.groupMember("memberId", null, + clientId, clientHost, new byte[0], new byte[0]); + DescribedGroup metadata = DescribeGroupsResponse.groupMetadata("test-group", + Errors.NONE, + "STABLE", + "consumer", + "roundrobin", + Collections.singletonList(member), + DescribeGroupsResponse.AUTHORIZED_OPERATIONS_OMITTED); + describeGroupsResponseData.groups().add(metadata); + return new DescribeGroupsResponse(describeGroupsResponseData); + } + + private LeaveGroupRequest createLeaveGroupRequest() { + return new LeaveGroupRequest.Builder( + "group1", Collections.singletonList(new MemberIdentity() + .setMemberId("consumer1")) + ).build(); + } + + private LeaveGroupResponse createLeaveGroupResponse() { + return new LeaveGroupResponse(new LeaveGroupResponseData().setErrorCode(Errors.NONE.code())); + } + + private DeleteGroupsRequest createDeleteGroupsRequest() { + return new DeleteGroupsRequest.Builder( + new DeleteGroupsRequestData() + .setGroupsNames(Collections.singletonList("test-group")) + ).build(); + } + + private DeleteGroupsResponse createDeleteGroupsResponse() { + DeletableGroupResultCollection result = new DeletableGroupResultCollection(); + result.add(new DeletableGroupResult() + .setGroupId("test-group") + .setErrorCode(Errors.NONE.code())); + return new DeleteGroupsResponse( + new DeleteGroupsResponseData() + .setResults(result) + ); + } + + private ListOffsetsRequest createListOffsetRequest(int version) { + if (version == 0) { + ListOffsetsTopic topic = new ListOffsetsTopic() + .setName("test") + .setPartitions(Arrays.asList(new ListOffsetsPartition() + .setPartitionIndex(0) + .setTimestamp(1000000L) + .setMaxNumOffsets(10) + .setCurrentLeaderEpoch(5))); + return ListOffsetsRequest.Builder + .forConsumer(false, IsolationLevel.READ_UNCOMMITTED, false) + .setTargetTimes(Collections.singletonList(topic)) + .build((short) version); + } else if (version == 1) { + ListOffsetsTopic topic = new ListOffsetsTopic() + .setName("test") + .setPartitions(Arrays.asList(new ListOffsetsPartition() + .setPartitionIndex(0) + .setTimestamp(1000000L) + .setCurrentLeaderEpoch(5))); + return ListOffsetsRequest.Builder + .forConsumer(true, IsolationLevel.READ_UNCOMMITTED, false) + .setTargetTimes(Collections.singletonList(topic)) + .build((short) version); + } else if (version >= 2 && version <= LIST_OFFSETS.latestVersion()) { + ListOffsetsPartition partition = new ListOffsetsPartition() + .setPartitionIndex(0) + .setTimestamp(1000000L) + .setCurrentLeaderEpoch(5); + + ListOffsetsTopic topic = new ListOffsetsTopic() + .setName("test") + .setPartitions(Arrays.asList(partition)); + 
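+            // Versions 2 and above support an isolation level, so this branch builds with READ_COMMITTED.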
return ListOffsetsRequest.Builder + .forConsumer(true, IsolationLevel.READ_COMMITTED, false) + .setTargetTimes(Collections.singletonList(topic)) + .build((short) version); + } else { + throw new IllegalArgumentException("Illegal ListOffsetRequest version " + version); + } + } + + private ListOffsetsResponse createListOffsetResponse(int version) { + if (version == 0) { + ListOffsetsResponseData data = new ListOffsetsResponseData() + .setTopics(Collections.singletonList(new ListOffsetsTopicResponse() + .setName("test") + .setPartitions(Collections.singletonList(new ListOffsetsPartitionResponse() + .setPartitionIndex(0) + .setErrorCode(Errors.NONE.code()) + .setOldStyleOffsets(asList(100L)))))); + return new ListOffsetsResponse(data); + } else if (version >= 1 && version <= LIST_OFFSETS.latestVersion()) { + ListOffsetsPartitionResponse partition = new ListOffsetsPartitionResponse() + .setPartitionIndex(0) + .setErrorCode(Errors.NONE.code()) + .setTimestamp(10000L) + .setOffset(100L); + if (version >= 4) { + partition.setLeaderEpoch(27); + } + ListOffsetsResponseData data = new ListOffsetsResponseData() + .setTopics(Collections.singletonList(new ListOffsetsTopicResponse() + .setName("test") + .setPartitions(Collections.singletonList(partition)))); + return new ListOffsetsResponse(data); + } else { + throw new IllegalArgumentException("Illegal ListOffsetResponse version " + version); + } + } + + private MetadataRequest createMetadataRequest(int version, List topics) { + return new MetadataRequest.Builder(topics, true).build((short) version); + } + + private MetadataResponse createMetadataResponse() { + Node node = new Node(1, "host1", 1001); + List replicas = singletonList(node.id()); + List isr = singletonList(node.id()); + List offlineReplicas = emptyList(); + + List allTopicMetadata = new ArrayList<>(); + allTopicMetadata.add(new MetadataResponse.TopicMetadata(Errors.NONE, "__consumer_offsets", true, + asList(new MetadataResponse.PartitionMetadata(Errors.NONE, + new TopicPartition("__consumer_offsets", 1), + Optional.of(node.id()), Optional.of(5), replicas, isr, offlineReplicas)))); + allTopicMetadata.add(new MetadataResponse.TopicMetadata(Errors.LEADER_NOT_AVAILABLE, "topic2", false, + emptyList())); + allTopicMetadata.add(new MetadataResponse.TopicMetadata(Errors.NONE, "topic3", false, + asList(new MetadataResponse.PartitionMetadata(Errors.LEADER_NOT_AVAILABLE, + new TopicPartition("topic3", 0), Optional.empty(), + Optional.empty(), replicas, isr, offlineReplicas)))); + + return RequestTestUtils.metadataResponse(asList(node), null, MetadataResponse.NO_CONTROLLER_ID, allTopicMetadata); + } + + private OffsetCommitRequest createOffsetCommitRequest(int version) { + return new OffsetCommitRequest.Builder(new OffsetCommitRequestData() + .setGroupId("group1") + .setMemberId("consumer1") + .setGroupInstanceId(null) + .setGenerationId(100) + .setTopics(Collections.singletonList( + new OffsetCommitRequestData.OffsetCommitRequestTopic() + .setName("test") + .setPartitions(Arrays.asList( + new OffsetCommitRequestData.OffsetCommitRequestPartition() + .setPartitionIndex(0) + .setCommittedOffset(100) + .setCommittedLeaderEpoch(RecordBatch.NO_PARTITION_LEADER_EPOCH) + .setCommittedMetadata(""), + new OffsetCommitRequestData.OffsetCommitRequestPartition() + .setPartitionIndex(1) + .setCommittedOffset(200) + .setCommittedLeaderEpoch(RecordBatch.NO_PARTITION_LEADER_EPOCH) + .setCommittedMetadata(null) + )) + )) + ).build((short) version); + } + + private OffsetCommitResponse createOffsetCommitResponse() { + 
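+        // A minimal successful response: one topic with a single partition committed without error.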
return new OffsetCommitResponse(new OffsetCommitResponseData() + .setTopics(Collections.singletonList( + new OffsetCommitResponseData.OffsetCommitResponseTopic() + .setName("test") + .setPartitions(Collections.singletonList( + new OffsetCommitResponseData.OffsetCommitResponsePartition() + .setPartitionIndex(0) + .setErrorCode(Errors.NONE.code()) + )) + )) + ); + } + + private OffsetFetchRequest createOffsetFetchRequest(int version, boolean requireStable) { + if (version < 8) { + return new OffsetFetchRequest.Builder( + "group1", + requireStable, + Collections.singletonList(new TopicPartition("test11", 1)), + false) + .build((short) version); + } + return new OffsetFetchRequest.Builder( + Collections.singletonMap( + "group1", + Collections.singletonList(new TopicPartition("test11", 1))), + requireStable, + false) + .build((short) version); + } + + private OffsetFetchRequest createOffsetFetchRequestWithMultipleGroups(int version, + boolean requireStable) { + Map> groupToPartitionMap = new HashMap<>(); + List topic1 = singletonList( + new TopicPartition("topic1", 0)); + List topic2 = Arrays.asList( + new TopicPartition("topic1", 0), + new TopicPartition("topic2", 0), + new TopicPartition("topic2", 1)); + List topic3 = Arrays.asList( + new TopicPartition("topic1", 0), + new TopicPartition("topic2", 0), + new TopicPartition("topic2", 1), + new TopicPartition("topic3", 0), + new TopicPartition("topic3", 1), + new TopicPartition("topic3", 2)); + groupToPartitionMap.put("group1", topic1); + groupToPartitionMap.put("group2", topic2); + groupToPartitionMap.put("group3", topic3); + groupToPartitionMap.put("group4", null); + groupToPartitionMap.put("group5", null); + + return new OffsetFetchRequest.Builder( + groupToPartitionMap, + requireStable, + false + ).build((short) version); + } + + private OffsetFetchRequest createOffsetFetchRequestForAllPartition(int version, boolean requireStable) { + if (version < 8) { + return new OffsetFetchRequest.Builder( + "group1", + requireStable, + null, + false) + .build((short) version); + } + return new OffsetFetchRequest.Builder( + Collections.singletonMap( + "group1", null), + requireStable, + false) + .build((short) version); + } + + private OffsetFetchResponse createOffsetFetchResponse(int version) { + Map responseData = new HashMap<>(); + responseData.put(new TopicPartition("test", 0), new OffsetFetchResponse.PartitionData( + 100L, Optional.empty(), "", Errors.NONE)); + responseData.put(new TopicPartition("test", 1), new OffsetFetchResponse.PartitionData( + 100L, Optional.of(10), null, Errors.NONE)); + if (version < 8) { + return new OffsetFetchResponse(Errors.NONE, responseData); + } + int throttleMs = 10; + return new OffsetFetchResponse(throttleMs, Collections.singletonMap("group1", Errors.NONE), + Collections.singletonMap("group1", responseData)); + } + + private ProduceRequest createProduceRequest(int version) { + if (version < 2) + throw new IllegalArgumentException("Produce request version 2 is not supported"); + byte magic = version == 2 ? 
RecordBatch.MAGIC_VALUE_V1 : RecordBatch.MAGIC_VALUE_V2; + MemoryRecords records = MemoryRecords.withRecords(magic, CompressionType.NONE, new SimpleRecord("woot".getBytes())); + return ProduceRequest.forMagic(magic, + new ProduceRequestData() + .setTopicData(new ProduceRequestData.TopicProduceDataCollection(Collections.singletonList( + new ProduceRequestData.TopicProduceData() + .setName("test") + .setPartitionData(Collections.singletonList(new ProduceRequestData.PartitionProduceData() + .setIndex(0) + .setRecords(records)))).iterator())) + .setAcks((short) 1) + .setTimeoutMs(5000) + .setTransactionalId(version >= 3 ? "transactionalId" : null)) + .build((short) version); + } + + @SuppressWarnings("deprecation") + private ProduceResponse createProduceResponse() { + Map responseData = new HashMap<>(); + responseData.put(new TopicPartition("test", 0), new ProduceResponse.PartitionResponse(Errors.NONE, + 10000, RecordBatch.NO_TIMESTAMP, 100)); + return new ProduceResponse(responseData, 0); + } + + @SuppressWarnings("deprecation") + private ProduceResponse createProduceResponseWithErrorMessage() { + Map responseData = new HashMap<>(); + responseData.put(new TopicPartition("test", 0), new ProduceResponse.PartitionResponse(Errors.NONE, + 10000, RecordBatch.NO_TIMESTAMP, 100, Collections.singletonList(new ProduceResponse.RecordError(0, "error message")), + "global error message")); + return new ProduceResponse(responseData, 0); + } + + private StopReplicaRequest createStopReplicaRequest(int version, boolean deletePartitions) { + List topicStates = new ArrayList<>(); + StopReplicaTopicState topic1 = new StopReplicaTopicState() + .setTopicName("topic1") + .setPartitionStates(Collections.singletonList(new StopReplicaPartitionState() + .setPartitionIndex(0) + .setLeaderEpoch(1) + .setDeletePartition(deletePartitions))); + topicStates.add(topic1); + StopReplicaTopicState topic2 = new StopReplicaTopicState() + .setTopicName("topic2") + .setPartitionStates(Collections.singletonList(new StopReplicaPartitionState() + .setPartitionIndex(1) + .setLeaderEpoch(2) + .setDeletePartition(deletePartitions))); + topicStates.add(topic2); + + return new StopReplicaRequest.Builder((short) version, 0, 1, 0, + deletePartitions, topicStates).build((short) version); + } + + private StopReplicaResponse createStopReplicaResponse() { + List partitions = new ArrayList<>(); + partitions.add(new StopReplicaResponseData.StopReplicaPartitionError() + .setTopicName("test") + .setPartitionIndex(0) + .setErrorCode(Errors.NONE.code())); + return new StopReplicaResponse(new StopReplicaResponseData() + .setErrorCode(Errors.NONE.code()) + .setPartitionErrors(partitions)); + } + + private ControlledShutdownRequest createControlledShutdownRequest() { + ControlledShutdownRequestData data = new ControlledShutdownRequestData() + .setBrokerId(10) + .setBrokerEpoch(0L); + return new ControlledShutdownRequest.Builder( + data, + ApiKeys.CONTROLLED_SHUTDOWN.latestVersion()).build(); + } + + private ControlledShutdownRequest createControlledShutdownRequest(int version) { + ControlledShutdownRequestData data = new ControlledShutdownRequestData() + .setBrokerId(10) + .setBrokerEpoch(0L); + return new ControlledShutdownRequest.Builder( + data, + ApiKeys.CONTROLLED_SHUTDOWN.latestVersion()).build((short) version); + } + + private ControlledShutdownResponse createControlledShutdownResponse() { + RemainingPartition p1 = new RemainingPartition() + .setTopicName("test2") + .setPartitionIndex(5); + RemainingPartition p2 = new RemainingPartition() + 
.setTopicName("test1") + .setPartitionIndex(10); + RemainingPartitionCollection pSet = new RemainingPartitionCollection(); + pSet.add(p1); + pSet.add(p2); + ControlledShutdownResponseData data = new ControlledShutdownResponseData() + .setErrorCode(Errors.NONE.code()) + .setRemainingPartitions(pSet); + return new ControlledShutdownResponse(data); + } + + private LeaderAndIsrRequest createLeaderAndIsrRequest(int version) { + List partitionStates = new ArrayList<>(); + List isr = asList(1, 2); + List replicas = asList(1, 2, 3, 4); + partitionStates.add(new LeaderAndIsrPartitionState() + .setTopicName("topic5") + .setPartitionIndex(105) + .setControllerEpoch(0) + .setLeader(2) + .setLeaderEpoch(1) + .setIsr(isr) + .setZkVersion(2) + .setReplicas(replicas) + .setIsNew(false)); + partitionStates.add(new LeaderAndIsrPartitionState() + .setTopicName("topic5") + .setPartitionIndex(1) + .setControllerEpoch(1) + .setLeader(1) + .setLeaderEpoch(1) + .setIsr(isr) + .setZkVersion(2) + .setReplicas(replicas) + .setIsNew(false)); + partitionStates.add(new LeaderAndIsrPartitionState() + .setTopicName("topic20") + .setPartitionIndex(1) + .setControllerEpoch(1) + .setLeader(0) + .setLeaderEpoch(1) + .setIsr(isr) + .setZkVersion(2) + .setReplicas(replicas) + .setIsNew(false)); + + Set leaders = Utils.mkSet( + new Node(0, "test0", 1223), + new Node(1, "test1", 1223) + ); + + Map topicIds = new HashMap<>(); + topicIds.put("topic5", Uuid.randomUuid()); + topicIds.put("topic20", Uuid.randomUuid()); + + return new LeaderAndIsrRequest.Builder((short) version, 1, 10, 0, + partitionStates, topicIds, leaders).build(); + } + + private LeaderAndIsrResponse createLeaderAndIsrResponse(int version) { + if (version < 5) { + List partitions = new ArrayList<>(); + partitions.add(new LeaderAndIsrResponseData.LeaderAndIsrPartitionError() + .setTopicName("test") + .setPartitionIndex(0) + .setErrorCode(Errors.NONE.code())); + return new LeaderAndIsrResponse(new LeaderAndIsrResponseData() + .setErrorCode(Errors.NONE.code()) + .setPartitionErrors(partitions), (short) version); + } else { + List partition = Collections.singletonList( + new LeaderAndIsrResponseData.LeaderAndIsrPartitionError() + .setPartitionIndex(0) + .setErrorCode(Errors.NONE.code())); + LeaderAndIsrTopicErrorCollection topics = new LeaderAndIsrTopicErrorCollection(); + topics.add(new LeaderAndIsrResponseData.LeaderAndIsrTopicError() + .setTopicId(Uuid.randomUuid()) + .setPartitionErrors(partition)); + return new LeaderAndIsrResponse(new LeaderAndIsrResponseData() + .setTopics(topics), (short) version); + } + } + + private UpdateMetadataRequest createUpdateMetadataRequest(int version, String rack) { + List partitionStates = new ArrayList<>(); + List isr = asList(1, 2); + List replicas = asList(1, 2, 3, 4); + List offlineReplicas = asList(); + partitionStates.add(new UpdateMetadataPartitionState() + .setTopicName("topic5") + .setPartitionIndex(105) + .setControllerEpoch(0) + .setLeader(2) + .setLeaderEpoch(1) + .setIsr(isr) + .setZkVersion(2) + .setReplicas(replicas) + .setOfflineReplicas(offlineReplicas)); + partitionStates.add(new UpdateMetadataPartitionState() + .setTopicName("topic5") + .setPartitionIndex(1) + .setControllerEpoch(1) + .setLeader(1) + .setLeaderEpoch(1) + .setIsr(isr) + .setZkVersion(2) + .setReplicas(replicas) + .setOfflineReplicas(offlineReplicas)); + partitionStates.add(new UpdateMetadataPartitionState() + .setTopicName("topic20") + .setPartitionIndex(1) + .setControllerEpoch(1) + .setLeader(0) + .setLeaderEpoch(1) + .setIsr(isr) + 
.setZkVersion(2) + .setReplicas(replicas) + .setOfflineReplicas(offlineReplicas)); + + Map topicIds = new HashMap<>(); + topicIds.put("topic5", Uuid.randomUuid()); + topicIds.put("topic20", Uuid.randomUuid()); + + SecurityProtocol plaintext = SecurityProtocol.PLAINTEXT; + List endpoints1 = new ArrayList<>(); + endpoints1.add(new UpdateMetadataEndpoint() + .setHost("host1") + .setPort(1223) + .setSecurityProtocol(plaintext.id) + .setListener(ListenerName.forSecurityProtocol(plaintext).value())); + + List endpoints2 = new ArrayList<>(); + endpoints2.add(new UpdateMetadataEndpoint() + .setHost("host1") + .setPort(1244) + .setSecurityProtocol(plaintext.id) + .setListener(ListenerName.forSecurityProtocol(plaintext).value())); + if (version > 0) { + SecurityProtocol ssl = SecurityProtocol.SSL; + endpoints2.add(new UpdateMetadataEndpoint() + .setHost("host2") + .setPort(1234) + .setSecurityProtocol(ssl.id) + .setListener(ListenerName.forSecurityProtocol(ssl).value())); + endpoints2.add(new UpdateMetadataEndpoint() + .setHost("host2") + .setPort(1334) + .setSecurityProtocol(ssl.id)); + if (version >= 3) + endpoints2.get(1).setListener("CLIENT"); + } + + List liveBrokers = Arrays.asList( + new UpdateMetadataBroker() + .setId(0) + .setEndpoints(endpoints1) + .setRack(rack), + new UpdateMetadataBroker() + .setId(1) + .setEndpoints(endpoints2) + .setRack(rack) + ); + return new UpdateMetadataRequest.Builder((short) version, 1, 10, 0, partitionStates, + liveBrokers, Collections.emptyMap()).build(); + } + + private UpdateMetadataResponse createUpdateMetadataResponse() { + return new UpdateMetadataResponse(new UpdateMetadataResponseData().setErrorCode(Errors.NONE.code())); + } + + private SaslHandshakeRequest createSaslHandshakeRequest() { + return new SaslHandshakeRequest.Builder( + new SaslHandshakeRequestData().setMechanism("PLAIN")).build(); + } + + private SaslHandshakeResponse createSaslHandshakeResponse() { + return new SaslHandshakeResponse( + new SaslHandshakeResponseData() + .setErrorCode(Errors.NONE.code()).setMechanisms(Collections.singletonList("GSSAPI"))); + } + + private SaslAuthenticateRequest createSaslAuthenticateRequest() { + SaslAuthenticateRequestData data = new SaslAuthenticateRequestData().setAuthBytes(new byte[0]); + return new SaslAuthenticateRequest(data, ApiKeys.SASL_AUTHENTICATE.latestVersion()); + } + + private SaslAuthenticateResponse createSaslAuthenticateResponse() { + SaslAuthenticateResponseData data = new SaslAuthenticateResponseData() + .setErrorCode(Errors.NONE.code()) + .setAuthBytes(new byte[0]) + .setSessionLifetimeMs(Long.MAX_VALUE); + return new SaslAuthenticateResponse(data); + } + + private ApiVersionsRequest createApiVersionRequest(short version) { + return new ApiVersionsRequest.Builder().build(version); + } + + private ApiVersionsResponse createApiVersionResponse() { + ApiVersionCollection apiVersions = new ApiVersionCollection(); + apiVersions.add(new ApiVersion() + .setApiKey((short) 0) + .setMinVersion((short) 0) + .setMaxVersion((short) 2)); + + return new ApiVersionsResponse(new ApiVersionsResponseData() + .setErrorCode(Errors.NONE.code()) + .setThrottleTimeMs(0) + .setApiKeys(apiVersions)); + } + + private CreateTopicsRequest createCreateTopicRequest(int version) { + return createCreateTopicRequest(version, version >= 1); + } + + private CreateTopicsRequest createCreateTopicRequest(int version, boolean validateOnly) { + CreateTopicsRequestData data = new CreateTopicsRequestData() + .setTimeoutMs(123) + .setValidateOnly(validateOnly); + 
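+        // Two topics are added below: one defined by partition count and replication factor, the
+        // other by explicit replica assignments plus a config entry.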
data.topics().add(new CreatableTopic() + .setNumPartitions(3) + .setReplicationFactor((short) 5)); + + CreatableTopic topic2 = new CreatableTopic(); + data.topics().add(topic2); + topic2.assignments().add(new CreatableReplicaAssignment() + .setPartitionIndex(0) + .setBrokerIds(Arrays.asList(1, 2, 3))); + topic2.assignments().add(new CreatableReplicaAssignment() + .setPartitionIndex(1) + .setBrokerIds(Arrays.asList(2, 3, 4))); + topic2.configs().add(new CreateableTopicConfig() + .setName("config1").setValue("value1")); + + return new CreateTopicsRequest.Builder(data).build((short) version); + } + + private CreateTopicsResponse createCreateTopicResponse() { + CreateTopicsResponseData data = new CreateTopicsResponseData(); + data.topics().add(new CreatableTopicResult() + .setName("t1") + .setErrorCode(Errors.INVALID_TOPIC_EXCEPTION.code()) + .setErrorMessage(null)); + data.topics().add(new CreatableTopicResult() + .setName("t2") + .setErrorCode(Errors.LEADER_NOT_AVAILABLE.code()) + .setErrorMessage("Leader with id 5 is not available.")); + data.topics().add(new CreatableTopicResult() + .setName("t3") + .setErrorCode(Errors.NONE.code()) + .setNumPartitions(1) + .setReplicationFactor((short) 2) + .setConfigs(Collections.singletonList(new CreatableTopicConfigs() + .setName("min.insync.replicas") + .setValue("2")))); + return new CreateTopicsResponse(data); + } + + private DeleteTopicsRequest createDeleteTopicsRequest(int version) { + return new DeleteTopicsRequest.Builder(new DeleteTopicsRequestData() + .setTopicNames(Arrays.asList("my_t1", "my_t2")) + .setTimeoutMs(1000) + ).build((short) version); + } + + private DeleteTopicsResponse createDeleteTopicsResponse() { + DeleteTopicsResponseData data = new DeleteTopicsResponseData(); + data.responses().add(new DeletableTopicResult() + .setName("t1") + .setErrorCode(Errors.INVALID_TOPIC_EXCEPTION.code()) + .setErrorMessage("Error Message")); + data.responses().add(new DeletableTopicResult() + .setName("t2") + .setErrorCode(Errors.TOPIC_AUTHORIZATION_FAILED.code()) + .setErrorMessage("Error Message")); + data.responses().add(new DeletableTopicResult() + .setName("t3") + .setErrorCode(Errors.NOT_CONTROLLER.code())); + data.responses().add(new DeletableTopicResult() + .setName("t4") + .setErrorCode(Errors.NONE.code())); + return new DeleteTopicsResponse(data); + } + + private InitProducerIdRequest createInitPidRequest() { + InitProducerIdRequestData requestData = new InitProducerIdRequestData() + .setTransactionalId(null) + .setTransactionTimeoutMs(100); + return new InitProducerIdRequest.Builder(requestData).build(); + } + + private InitProducerIdResponse createInitPidResponse() { + InitProducerIdResponseData responseData = new InitProducerIdResponseData() + .setErrorCode(Errors.NONE.code()) + .setProducerEpoch((short) 3) + .setProducerId(3332) + .setThrottleTimeMs(0); + return new InitProducerIdResponse(responseData); + } + + private OffsetForLeaderTopicCollection createOffsetForLeaderTopicCollection() { + OffsetForLeaderTopicCollection topics = new OffsetForLeaderTopicCollection(); + topics.add(new OffsetForLeaderTopic() + .setTopic("topic1") + .setPartitions(Arrays.asList( + new OffsetForLeaderPartition() + .setPartition(0) + .setLeaderEpoch(1) + .setCurrentLeaderEpoch(0), + new OffsetForLeaderPartition() + .setPartition(1) + .setLeaderEpoch(1) + .setCurrentLeaderEpoch(0)))); + topics.add(new OffsetForLeaderTopic() + .setTopic("topic2") + .setPartitions(Arrays.asList( + new OffsetForLeaderPartition() + .setPartition(2) + .setLeaderEpoch(3) + 
.setCurrentLeaderEpoch(RecordBatch.NO_PARTITION_LEADER_EPOCH)))); + return topics; + } + + private OffsetsForLeaderEpochRequest createLeaderEpochRequestForConsumer() { + OffsetForLeaderTopicCollection epochs = createOffsetForLeaderTopicCollection(); + return OffsetsForLeaderEpochRequest.Builder.forConsumer(epochs).build(); + } + + private OffsetsForLeaderEpochRequest createLeaderEpochRequestForReplica(int version, int replicaId) { + OffsetForLeaderTopicCollection epochs = createOffsetForLeaderTopicCollection(); + return OffsetsForLeaderEpochRequest.Builder.forFollower((short) version, epochs, replicaId).build(); + } + + private OffsetsForLeaderEpochResponse createLeaderEpochResponse() { + OffsetForLeaderEpochResponseData data = new OffsetForLeaderEpochResponseData(); + data.topics().add(new OffsetForLeaderTopicResult() + .setTopic("topic1") + .setPartitions(Arrays.asList( + new EpochEndOffset() + .setPartition(0) + .setErrorCode(Errors.NONE.code()) + .setLeaderEpoch(1) + .setEndOffset(0), + new EpochEndOffset() + .setPartition(1) + .setErrorCode(Errors.NONE.code()) + .setLeaderEpoch(1) + .setEndOffset(1)))); + data.topics().add(new OffsetForLeaderTopicResult() + .setTopic("topic2") + .setPartitions(Arrays.asList( + new EpochEndOffset() + .setPartition(2) + .setErrorCode(Errors.NONE.code()) + .setLeaderEpoch(1) + .setEndOffset(1)))); + + return new OffsetsForLeaderEpochResponse(data); + } + + private AddPartitionsToTxnRequest createAddPartitionsToTxnRequest() { + return new AddPartitionsToTxnRequest.Builder("tid", 21L, (short) 42, + Collections.singletonList(new TopicPartition("topic", 73))).build(); + } + + private AddPartitionsToTxnResponse createAddPartitionsToTxnResponse() { + return new AddPartitionsToTxnResponse(0, Collections.singletonMap(new TopicPartition("t", 0), Errors.NONE)); + } + + private AddOffsetsToTxnRequest createAddOffsetsToTxnRequest() { + return new AddOffsetsToTxnRequest.Builder( + new AddOffsetsToTxnRequestData() + .setTransactionalId("tid") + .setProducerId(21L) + .setProducerEpoch((short) 42) + .setGroupId("gid") + ).build(); + } + + private AddOffsetsToTxnResponse createAddOffsetsToTxnResponse() { + return new AddOffsetsToTxnResponse(new AddOffsetsToTxnResponseData() + .setErrorCode(Errors.NONE.code()) + .setThrottleTimeMs(0)); + } + + private EndTxnRequest createEndTxnRequest() { + return new EndTxnRequest.Builder( + new EndTxnRequestData() + .setTransactionalId("tid") + .setProducerId(21L) + .setProducerEpoch((short) 42) + .setCommitted(TransactionResult.COMMIT.id) + ).build(); + } + + private EndTxnResponse createEndTxnResponse() { + return new EndTxnResponse( + new EndTxnResponseData() + .setErrorCode(Errors.NONE.code()) + .setThrottleTimeMs(0) + ); + } + + private WriteTxnMarkersRequest createWriteTxnMarkersRequest() { + List partitions = Collections.singletonList(new TopicPartition("topic", 73)); + WriteTxnMarkersRequest.TxnMarkerEntry txnMarkerEntry = new WriteTxnMarkersRequest.TxnMarkerEntry(21L, (short) 42, 73, TransactionResult.ABORT, partitions); + return new WriteTxnMarkersRequest.Builder(ApiKeys.WRITE_TXN_MARKERS.latestVersion(), Collections.singletonList(txnMarkerEntry)).build(); + } + + private WriteTxnMarkersResponse createWriteTxnMarkersResponse() { + final Map errorPerPartitions = new HashMap<>(); + errorPerPartitions.put(new TopicPartition("topic", 73), Errors.NONE); + final Map> response = new HashMap<>(); + response.put(21L, errorPerPartitions); + return new WriteTxnMarkersResponse(response); + } + + private TxnOffsetCommitRequest 
createTxnOffsetCommitRequest(int version) { + final Map offsets = new HashMap<>(); + offsets.put(new TopicPartition("topic", 73), + new TxnOffsetCommitRequest.CommittedOffset(100, null, Optional.empty())); + offsets.put(new TopicPartition("topic", 74), + new TxnOffsetCommitRequest.CommittedOffset(100, "blah", Optional.of(27))); + + if (version < 3) { + return new TxnOffsetCommitRequest.Builder("transactionalId", + "groupId", + 21L, + (short) 42, + offsets).build(); + } else { + return new TxnOffsetCommitRequest.Builder("transactionalId", + "groupId", + 21L, + (short) 42, + offsets, + "member", + 2, + Optional.of("instance")).build(); + } + } + + private TxnOffsetCommitRequest createTxnOffsetCommitRequestWithAutoDowngrade(int version) { + final Map offsets = new HashMap<>(); + offsets.put(new TopicPartition("topic", 73), + new TxnOffsetCommitRequest.CommittedOffset(100, null, Optional.empty())); + offsets.put(new TopicPartition("topic", 74), + new TxnOffsetCommitRequest.CommittedOffset(100, "blah", Optional.of(27))); + + return new TxnOffsetCommitRequest.Builder("transactionalId", + "groupId", + 21L, + (short) 42, + offsets, + "member", + 2, + Optional.of("instance")).build(); + } + + private TxnOffsetCommitResponse createTxnOffsetCommitResponse() { + final Map errorPerPartitions = new HashMap<>(); + errorPerPartitions.put(new TopicPartition("topic", 73), Errors.NONE); + return new TxnOffsetCommitResponse(0, errorPerPartitions); + } + + private DescribeAclsRequest createDescribeAclsRequest() { + return new DescribeAclsRequest.Builder(new AclBindingFilter( + new ResourcePatternFilter(ResourceType.TOPIC, "mytopic", PatternType.LITERAL), + new AccessControlEntryFilter(null, null, AclOperation.ANY, AclPermissionType.ANY))).build(); + } + + private DescribeAclsResponse createDescribeAclsResponse() { + DescribeAclsResponseData data = new DescribeAclsResponseData() + .setErrorCode(Errors.NONE.code()) + .setErrorMessage(Errors.NONE.message()) + .setThrottleTimeMs(0) + .setResources(Collections.singletonList(new DescribeAclsResource() + .setResourceType(ResourceType.TOPIC.code()) + .setResourceName("mytopic") + .setPatternType(PatternType.LITERAL.code()) + .setAcls(Collections.singletonList(new AclDescription() + .setHost("*") + .setOperation(AclOperation.WRITE.code()) + .setPermissionType(AclPermissionType.ALLOW.code()) + .setPrincipal("User:ANONYMOUS"))))); + return new DescribeAclsResponse(data); + } + + private CreateAclsRequest createCreateAclsRequest() { + List creations = new ArrayList<>(); + creations.add(CreateAclsRequest.aclCreation(new AclBinding( + new ResourcePattern(ResourceType.TOPIC, "mytopic", PatternType.LITERAL), + new AccessControlEntry("User:ANONYMOUS", "127.0.0.1", AclOperation.READ, AclPermissionType.ALLOW)))); + creations.add(CreateAclsRequest.aclCreation(new AclBinding( + new ResourcePattern(ResourceType.GROUP, "mygroup", PatternType.LITERAL), + new AccessControlEntry("User:ANONYMOUS", "*", AclOperation.WRITE, AclPermissionType.DENY)))); + CreateAclsRequestData data = new CreateAclsRequestData().setCreations(creations); + return new CreateAclsRequest.Builder(data).build(); + } + + private CreateAclsResponse createCreateAclsResponse() { + return new CreateAclsResponse(new CreateAclsResponseData().setResults(asList( + new CreateAclsResponseData.AclCreationResult(), + new CreateAclsResponseData.AclCreationResult() + .setErrorCode(Errors.NONE.code()) + .setErrorMessage("Foo bar")))); + } + + private DeleteAclsRequest createDeleteAclsRequest() { + DeleteAclsRequestData data = new 
DeleteAclsRequestData().setFilters(asList( + new DeleteAclsRequestData.DeleteAclsFilter() + .setResourceTypeFilter(ResourceType.ANY.code()) + .setResourceNameFilter(null) + .setPatternTypeFilter(PatternType.LITERAL.code()) + .setPrincipalFilter("User:ANONYMOUS") + .setHostFilter(null) + .setOperation(AclOperation.ANY.code()) + .setPermissionType(AclPermissionType.ANY.code()), + new DeleteAclsRequestData.DeleteAclsFilter() + .setResourceTypeFilter(ResourceType.ANY.code()) + .setResourceNameFilter(null) + .setPatternTypeFilter(PatternType.LITERAL.code()) + .setPrincipalFilter("User:bob") + .setHostFilter(null) + .setOperation(AclOperation.ANY.code()) + .setPermissionType(AclPermissionType.ANY.code()) + )); + return new DeleteAclsRequest.Builder(data).build(); + } + + private DeleteAclsResponse createDeleteAclsResponse(int version) { + List filterResults = new ArrayList<>(); + filterResults.add(new DeleteAclsResponseData.DeleteAclsFilterResult().setMatchingAcls(asList( + new DeleteAclsResponseData.DeleteAclsMatchingAcl() + .setResourceType(ResourceType.TOPIC.code()) + .setResourceName("mytopic3") + .setPatternType(PatternType.LITERAL.code()) + .setPrincipal("User:ANONYMOUS") + .setHost("*") + .setOperation(AclOperation.DESCRIBE.code()) + .setPermissionType(AclPermissionType.ALLOW.code()), + new DeleteAclsResponseData.DeleteAclsMatchingAcl() + .setResourceType(ResourceType.TOPIC.code()) + .setResourceName("mytopic4") + .setPatternType(PatternType.LITERAL.code()) + .setPrincipal("User:ANONYMOUS") + .setHost("*") + .setOperation(AclOperation.DESCRIBE.code()) + .setPermissionType(AclPermissionType.DENY.code())))); + filterResults.add(new DeleteAclsResponseData.DeleteAclsFilterResult() + .setErrorCode(Errors.SECURITY_DISABLED.code()) + .setErrorMessage("No security")); + return new DeleteAclsResponse(new DeleteAclsResponseData() + .setThrottleTimeMs(0) + .setFilterResults(filterResults), (short) version); + } + + private DescribeConfigsRequest createDescribeConfigsRequest(int version) { + return new DescribeConfigsRequest.Builder(new DescribeConfigsRequestData() + .setResources(asList( + new DescribeConfigsRequestData.DescribeConfigsResource() + .setResourceType(ConfigResource.Type.BROKER.id()) + .setResourceName("0"), + new DescribeConfigsRequestData.DescribeConfigsResource() + .setResourceType(ConfigResource.Type.TOPIC.id()) + .setResourceName("topic")))) + .build((short) version); + } + + private DescribeConfigsRequest createDescribeConfigsRequestWithConfigEntries(int version) { + return new DescribeConfigsRequest.Builder(new DescribeConfigsRequestData() + .setResources(asList( + new DescribeConfigsRequestData.DescribeConfigsResource() + .setResourceType(ConfigResource.Type.BROKER.id()) + .setResourceName("0") + .setConfigurationKeys(asList("foo", "bar")), + new DescribeConfigsRequestData.DescribeConfigsResource() + .setResourceType(ConfigResource.Type.TOPIC.id()) + .setResourceName("topic") + .setConfigurationKeys(null), + new DescribeConfigsRequestData.DescribeConfigsResource() + .setResourceType(ConfigResource.Type.TOPIC.id()) + .setResourceName("topic a") + .setConfigurationKeys(emptyList())))).build((short) version); + } + + private DescribeConfigsRequest createDescribeConfigsRequestWithDocumentation(int version) { + DescribeConfigsRequestData data = new DescribeConfigsRequestData() + .setResources(asList( + new DescribeConfigsRequestData.DescribeConfigsResource() + .setResourceType(ConfigResource.Type.BROKER.id()) + .setResourceName("0") + .setConfigurationKeys(asList("foo", "bar")))); + 
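+        // IncludeDocumentation only applies from DescribeConfigs v3, so it is set just for that version here.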
        if (version == 3) {
+            data.setIncludeDocumentation(true);
+        }
+        return new DescribeConfigsRequest.Builder(data).build((short) version);
+    }
+
+    private DescribeConfigsResponse createDescribeConfigsResponse(short version) {
+        return new DescribeConfigsResponse(new DescribeConfigsResponseData().setResults(asList(
+            new DescribeConfigsResult()
+                .setErrorCode(Errors.NONE.code())
+                .setResourceType(ConfigResource.Type.BROKER.id())
+                .setResourceName("0")
+                .setConfigs(asList(
+                    new DescribeConfigsResourceResult()
+                        .setName("config_name")
+                        .setValue("config_value")
+                        // Note: the v0 default for this field that should be exposed to callers is
+                        // context-dependent. For example, if the resource is a broker, this should default to 4.
+                        // -1 is just a placeholder value.
+                        .setConfigSource(version == 0 ? DescribeConfigsResponse.ConfigSource.STATIC_BROKER_CONFIG.id() : DescribeConfigsResponse.ConfigSource.DYNAMIC_BROKER_CONFIG.id())
+                        .setIsSensitive(true).setReadOnly(false)
+                        .setSynonyms(emptyList()),
+                    new DescribeConfigsResourceResult()
+                        .setName("yet_another_name")
+                        .setValue("yet another value")
+                        .setConfigSource(version == 0 ? DescribeConfigsResponse.ConfigSource.STATIC_BROKER_CONFIG.id() : DescribeConfigsResponse.ConfigSource.DEFAULT_CONFIG.id())
+                        .setIsSensitive(false).setReadOnly(true)
+                        .setSynonyms(emptyList())
+                        .setConfigType(ConfigType.BOOLEAN.id())
+                        .setDocumentation("some description"),
+                    new DescribeConfigsResourceResult()
+                        .setName("another_name")
+                        .setValue("another value")
+                        .setConfigSource(version == 0 ? DescribeConfigsResponse.ConfigSource.STATIC_BROKER_CONFIG.id() : DescribeConfigsResponse.ConfigSource.DEFAULT_CONFIG.id())
+                        .setIsSensitive(false).setReadOnly(true)
+                        .setSynonyms(emptyList())
+                )),
+            new DescribeConfigsResult()
+                .setErrorCode(Errors.NONE.code())
+                .setResourceType(ConfigResource.Type.TOPIC.id())
+                .setResourceName("topic")
+                .setConfigs(emptyList())
+        )));
+
+    }
+
+    private AlterConfigsRequest createAlterConfigsRequest() {
+        Map<ConfigResource, AlterConfigsRequest.Config> configs = new HashMap<>();
+        List<AlterConfigsRequest.ConfigEntry> configEntries = asList(
+            new AlterConfigsRequest.ConfigEntry("config_name", "config_value"),
+            new AlterConfigsRequest.ConfigEntry("another_name", "another value")
+        );
+        configs.put(new ConfigResource(ConfigResource.Type.BROKER, "0"), new AlterConfigsRequest.Config(configEntries));
+        configs.put(new ConfigResource(ConfigResource.Type.TOPIC, "topic"),
+            new AlterConfigsRequest.Config(Collections.emptyList()));
+        return new AlterConfigsRequest.Builder(configs, false).build((short) 0);
+    }
+
+    private AlterConfigsResponse createAlterConfigsResponse() {
+        AlterConfigsResponseData data = new AlterConfigsResponseData()
+            .setThrottleTimeMs(20);
+        data.responses().add(new AlterConfigsResponseData.AlterConfigsResourceResponse()
+            .setErrorCode(Errors.NONE.code())
+            .setErrorMessage(null)
+            .setResourceName("0")
+            .setResourceType(ConfigResource.Type.BROKER.id()));
+        data.responses().add(new AlterConfigsResponseData.AlterConfigsResourceResponse()
+            .setErrorCode(Errors.INVALID_REQUEST.code())
+            .setErrorMessage("This request is invalid")
+            .setResourceName("topic")
+            .setResourceType(ConfigResource.Type.TOPIC.id()));
+        return new AlterConfigsResponse(data);
+    }
+
+    private CreatePartitionsRequest createCreatePartitionsRequest(int version) {
+        CreatePartitionsTopicCollection topics = new CreatePartitionsTopicCollection();
+        topics.add(new CreatePartitionsTopic()
+            .setName("my_topic")
+            .setCount(3)
+        );
+        topics.add(new CreatePartitionsTopic()
+            .setName("my_other_topic")
+            .setCount(3)
+        );
+
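+        // No explicit assignments are supplied, so the broker chooses replica placement; see the variant with assignments below.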
CreatePartitionsRequestData data = new CreatePartitionsRequestData() + .setTimeoutMs(0) + .setValidateOnly(false) + .setTopics(topics); + + return new CreatePartitionsRequest(data, (short) version); + } + + private CreatePartitionsRequest createCreatePartitionsRequestWithAssignments(int version) { + CreatePartitionsTopicCollection topics = new CreatePartitionsTopicCollection(); + CreatePartitionsAssignment myTopicAssignment = new CreatePartitionsAssignment() + .setBrokerIds(Collections.singletonList(2)); + topics.add(new CreatePartitionsTopic() + .setName("my_topic") + .setCount(3) + .setAssignments(Collections.singletonList(myTopicAssignment)) + ); + + topics.add(new CreatePartitionsTopic() + .setName("my_other_topic") + .setCount(3) + .setAssignments(asList( + new CreatePartitionsAssignment().setBrokerIds(asList(2, 3)), + new CreatePartitionsAssignment().setBrokerIds(asList(3, 1)) + )) + ); + + CreatePartitionsRequestData data = new CreatePartitionsRequestData() + .setTimeoutMs(0) + .setValidateOnly(false) + .setTopics(topics); + + return new CreatePartitionsRequest(data, (short) version); + } + + private CreatePartitionsResponse createCreatePartitionsResponse() { + List results = new LinkedList<>(); + results.add(new CreatePartitionsTopicResult() + .setName("my_topic") + .setErrorCode(Errors.INVALID_REPLICA_ASSIGNMENT.code())); + results.add(new CreatePartitionsTopicResult() + .setName("my_topic") + .setErrorCode(Errors.NONE.code())); + CreatePartitionsResponseData data = new CreatePartitionsResponseData() + .setThrottleTimeMs(42) + .setResults(results); + return new CreatePartitionsResponse(data); + } + + private CreateDelegationTokenRequest createCreateTokenRequest() { + List renewers = new ArrayList<>(); + renewers.add(new CreatableRenewers() + .setPrincipalType("User") + .setPrincipalName("user1")); + renewers.add(new CreatableRenewers() + .setPrincipalType("User") + .setPrincipalName("user2")); + return new CreateDelegationTokenRequest.Builder(new CreateDelegationTokenRequestData() + .setRenewers(renewers) + .setMaxLifetimeMs(System.currentTimeMillis())).build(); + } + + private CreateDelegationTokenResponse createCreateTokenResponse() { + CreateDelegationTokenResponseData data = new CreateDelegationTokenResponseData() + .setThrottleTimeMs(20) + .setErrorCode(Errors.NONE.code()) + .setPrincipalType("User") + .setPrincipalName("user1") + .setIssueTimestampMs(System.currentTimeMillis()) + .setExpiryTimestampMs(System.currentTimeMillis()) + .setMaxTimestampMs(System.currentTimeMillis()) + .setTokenId("token1") + .setHmac("test".getBytes()); + return new CreateDelegationTokenResponse(data); + } + + private RenewDelegationTokenRequest createRenewTokenRequest() { + RenewDelegationTokenRequestData data = new RenewDelegationTokenRequestData() + .setHmac("test".getBytes()) + .setRenewPeriodMs(System.currentTimeMillis()); + return new RenewDelegationTokenRequest.Builder(data).build(); + } + + private RenewDelegationTokenResponse createRenewTokenResponse() { + RenewDelegationTokenResponseData data = new RenewDelegationTokenResponseData() + .setThrottleTimeMs(20) + .setErrorCode(Errors.NONE.code()) + .setExpiryTimestampMs(System.currentTimeMillis()); + return new RenewDelegationTokenResponse(data); + } + + private ExpireDelegationTokenRequest createExpireTokenRequest() { + ExpireDelegationTokenRequestData data = new ExpireDelegationTokenRequestData() + .setHmac("test".getBytes()) + .setExpiryTimePeriodMs(System.currentTimeMillis()); + return new 
ExpireDelegationTokenRequest.Builder(data).build(); + } + + private ExpireDelegationTokenResponse createExpireTokenResponse() { + ExpireDelegationTokenResponseData data = new ExpireDelegationTokenResponseData() + .setThrottleTimeMs(20) + .setErrorCode(Errors.NONE.code()) + .setExpiryTimestampMs(System.currentTimeMillis()); + return new ExpireDelegationTokenResponse(data); + } + + private DescribeDelegationTokenRequest createDescribeTokenRequest() { + List owners = new ArrayList<>(); + owners.add(SecurityUtils.parseKafkaPrincipal("User:user1")); + owners.add(SecurityUtils.parseKafkaPrincipal("User:user2")); + return new DescribeDelegationTokenRequest.Builder(owners).build(); + } + + private DescribeDelegationTokenResponse createDescribeTokenResponse() { + List renewers = new ArrayList<>(); + renewers.add(SecurityUtils.parseKafkaPrincipal("User:user1")); + renewers.add(SecurityUtils.parseKafkaPrincipal("User:user2")); + + List tokenList = new LinkedList<>(); + + TokenInformation tokenInfo1 = new TokenInformation("1", SecurityUtils.parseKafkaPrincipal("User:owner"), renewers, + System.currentTimeMillis(), System.currentTimeMillis(), System.currentTimeMillis()); + + TokenInformation tokenInfo2 = new TokenInformation("2", SecurityUtils.parseKafkaPrincipal("User:owner1"), renewers, + System.currentTimeMillis(), System.currentTimeMillis(), System.currentTimeMillis()); + + tokenList.add(new DelegationToken(tokenInfo1, "test".getBytes())); + tokenList.add(new DelegationToken(tokenInfo2, "test".getBytes())); + + return new DescribeDelegationTokenResponse(20, Errors.NONE, tokenList); + } + + private ElectLeadersRequest createElectLeadersRequestNullPartitions() { + return new ElectLeadersRequest.Builder(ElectionType.PREFERRED, null, 100).build((short) 1); + } + + private ElectLeadersRequest createElectLeadersRequest() { + List partitions = asList(new TopicPartition("data", 1), new TopicPartition("data", 2)); + + return new ElectLeadersRequest.Builder(ElectionType.PREFERRED, partitions, 100).build((short) 1); + } + + private ElectLeadersResponse createElectLeadersResponse() { + String topic = "myTopic"; + List electionResults = new ArrayList<>(); + ReplicaElectionResult electionResult = new ReplicaElectionResult(); + electionResults.add(electionResult); + electionResult.setTopic(topic); + // Add partition 1 result + PartitionResult partitionResult = new PartitionResult(); + partitionResult.setPartitionId(0); + partitionResult.setErrorCode(ApiError.NONE.error().code()); + partitionResult.setErrorMessage(ApiError.NONE.message()); + electionResult.partitionResult().add(partitionResult); + + // Add partition 2 result + partitionResult = new PartitionResult(); + partitionResult.setPartitionId(1); + partitionResult.setErrorCode(Errors.UNKNOWN_TOPIC_OR_PARTITION.code()); + partitionResult.setErrorMessage(Errors.UNKNOWN_TOPIC_OR_PARTITION.message()); + electionResult.partitionResult().add(partitionResult); + + return new ElectLeadersResponse(200, Errors.NONE.code(), electionResults, ApiKeys.ELECT_LEADERS.latestVersion()); + } + + private IncrementalAlterConfigsRequest createIncrementalAlterConfigsRequest() { + IncrementalAlterConfigsRequestData data = new IncrementalAlterConfigsRequestData(); + AlterableConfig alterableConfig = new AlterableConfig() + .setName("retention.ms") + .setConfigOperation((byte) 0) + .setValue("100"); + IncrementalAlterConfigsRequestData.AlterableConfigCollection alterableConfigs = new IncrementalAlterConfigsRequestData.AlterableConfigCollection(); + 
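+        // configOperation 0 above is the SET operation, i.e. assign retention.ms the value 100.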
alterableConfigs.add(alterableConfig); + + data.resources().add(new AlterConfigsResource() + .setResourceName("testtopic") + .setResourceType(ResourceType.TOPIC.code()) + .setConfigs(alterableConfigs)); + return new IncrementalAlterConfigsRequest.Builder(data).build((short) 0); + } + + private IncrementalAlterConfigsResponse createIncrementalAlterConfigsResponse() { + IncrementalAlterConfigsResponseData data = new IncrementalAlterConfigsResponseData(); + + data.responses().add(new AlterConfigsResourceResponse() + .setResourceName("testtopic") + .setResourceType(ResourceType.TOPIC.code()) + .setErrorCode(Errors.NONE.code()) + .setErrorMessage("Duplicate Keys")); + return new IncrementalAlterConfigsResponse(data); + } + + private AlterPartitionReassignmentsRequest createAlterPartitionReassignmentsRequest() { + AlterPartitionReassignmentsRequestData data = new AlterPartitionReassignmentsRequestData(); + data.topics().add( + new AlterPartitionReassignmentsRequestData.ReassignableTopic().setName("topic").setPartitions( + Collections.singletonList( + new AlterPartitionReassignmentsRequestData.ReassignablePartition().setPartitionIndex(0).setReplicas(null) + ) + ) + ); + return new AlterPartitionReassignmentsRequest.Builder(data).build((short) 0); + } + + private AlterPartitionReassignmentsResponse createAlterPartitionReassignmentsResponse() { + AlterPartitionReassignmentsResponseData data = new AlterPartitionReassignmentsResponseData(); + data.responses().add( + new AlterPartitionReassignmentsResponseData.ReassignableTopicResponse() + .setName("topic") + .setPartitions(Collections.singletonList( + new AlterPartitionReassignmentsResponseData.ReassignablePartitionResponse() + .setPartitionIndex(0) + .setErrorCode(Errors.NONE.code()) + .setErrorMessage("No reassignment is in progress for topic topic partition 0") + ) + ) + ); + return new AlterPartitionReassignmentsResponse(data); + } + + private ListPartitionReassignmentsRequest createListPartitionReassignmentsRequest() { + ListPartitionReassignmentsRequestData data = new ListPartitionReassignmentsRequestData(); + data.setTopics( + Collections.singletonList( + new ListPartitionReassignmentsRequestData.ListPartitionReassignmentsTopics() + .setName("topic") + .setPartitionIndexes(Collections.singletonList(1)) + ) + ); + return new ListPartitionReassignmentsRequest.Builder(data).build((short) 0); + } + + private ListPartitionReassignmentsResponse createListPartitionReassignmentsResponse() { + ListPartitionReassignmentsResponseData data = new ListPartitionReassignmentsResponseData(); + data.setTopics(Collections.singletonList( + new ListPartitionReassignmentsResponseData.OngoingTopicReassignment() + .setName("topic") + .setPartitions(Collections.singletonList( + new ListPartitionReassignmentsResponseData.OngoingPartitionReassignment() + .setPartitionIndex(0) + .setReplicas(Arrays.asList(1, 2)) + .setAddingReplicas(Collections.singletonList(2)) + .setRemovingReplicas(Collections.singletonList(1)) + ) + ) + )); + return new ListPartitionReassignmentsResponse(data); + } + + private OffsetDeleteRequest createOffsetDeleteRequest() { + OffsetDeleteRequestTopicCollection topics = new OffsetDeleteRequestTopicCollection(); + topics.add(new OffsetDeleteRequestTopic() + .setName("topic1") + .setPartitions(Collections.singletonList( + new OffsetDeleteRequestPartition() + .setPartitionIndex(0) + ) + ) + ); + + OffsetDeleteRequestData data = new OffsetDeleteRequestData(); + data.setGroupId("group1"); + data.setTopics(topics); + + return new 
OffsetDeleteRequest.Builder(data).build((short) 0); + } + + private OffsetDeleteResponse createOffsetDeleteResponse() { + OffsetDeleteResponsePartitionCollection partitions = new OffsetDeleteResponsePartitionCollection(); + partitions.add(new OffsetDeleteResponsePartition() + .setPartitionIndex(0) + .setErrorCode(Errors.NONE.code()) + ); + + OffsetDeleteResponseTopicCollection topics = new OffsetDeleteResponseTopicCollection(); + topics.add(new OffsetDeleteResponseTopic() + .setName("topic1") + .setPartitions(partitions) + ); + + OffsetDeleteResponseData data = new OffsetDeleteResponseData(); + data.setErrorCode(Errors.NONE.code()); + data.setTopics(topics); + + return new OffsetDeleteResponse(data); + } + + private AlterReplicaLogDirsRequest createAlterReplicaLogDirsRequest() { + AlterReplicaLogDirsRequestData data = new AlterReplicaLogDirsRequestData(); + data.dirs().add( + new AlterReplicaLogDirsRequestData.AlterReplicaLogDir() + .setPath("/data0") + .setTopics(new AlterReplicaLogDirTopicCollection(Collections.singletonList( + new AlterReplicaLogDirTopic() + .setPartitions(singletonList(0)) + .setName("topic") + ).iterator()) + ) + ); + return new AlterReplicaLogDirsRequest.Builder(data).build((short) 0); + } + + private AlterReplicaLogDirsResponse createAlterReplicaLogDirsResponse() { + AlterReplicaLogDirsResponseData data = new AlterReplicaLogDirsResponseData(); + data.results().add( + new AlterReplicaLogDirsResponseData.AlterReplicaLogDirTopicResult() + .setTopicName("topic") + .setPartitions(Collections.singletonList( + new AlterReplicaLogDirsResponseData.AlterReplicaLogDirPartitionResult() + .setPartitionIndex(0) + .setErrorCode(Errors.NONE.code()) + ) + ) + ); + return new AlterReplicaLogDirsResponse(data); + } + + private DescribeClientQuotasRequest createDescribeClientQuotasRequest() { + ClientQuotaFilter filter = ClientQuotaFilter.all(); + return new DescribeClientQuotasRequest.Builder(filter).build((short) 0); + } + + private DescribeClientQuotasResponse createDescribeClientQuotasResponse() { + DescribeClientQuotasResponseData data = new DescribeClientQuotasResponseData().setEntries(asList( + new DescribeClientQuotasResponseData.EntryData() + .setEntity(asList(new DescribeClientQuotasResponseData.EntityData() + .setEntityType(ClientQuotaEntity.USER) + .setEntityName("user"))) + .setValues(asList(new DescribeClientQuotasResponseData.ValueData() + .setKey("request_percentage") + .setValue(1.0))))); + return new DescribeClientQuotasResponse(data); + } + + private AlterClientQuotasRequest createAlterClientQuotasRequest() { + ClientQuotaEntity entity = new ClientQuotaEntity(Collections.singletonMap(ClientQuotaEntity.USER, "user")); + ClientQuotaAlteration.Op op = new ClientQuotaAlteration.Op("request_percentage", 2.0); + ClientQuotaAlteration alteration = new ClientQuotaAlteration(entity, Collections.singleton(op)); + return new AlterClientQuotasRequest.Builder(Collections.singleton(alteration), false).build((short) 0); + } + + private AlterClientQuotasResponse createAlterClientQuotasResponse() { + AlterClientQuotasResponseData data = new AlterClientQuotasResponseData() + .setEntries(asList(new AlterClientQuotasResponseData.EntryData() + .setEntity(asList(new AlterClientQuotasResponseData.EntityData() + .setEntityType(ClientQuotaEntity.USER) + .setEntityName("user"))))); + return new AlterClientQuotasResponse(data); + } + + private DescribeProducersRequest createDescribeProducersRequest(short version) { + DescribeProducersRequestData data = new DescribeProducersRequestData(); + 
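+        // Request producer state for two partitions of the "test" topic.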
DescribeProducersRequestData.TopicRequest topicRequest = new DescribeProducersRequestData.TopicRequest(); + topicRequest.setName("test"); + topicRequest.partitionIndexes().add(0); + topicRequest.partitionIndexes().add(1); + data.topics().add(topicRequest); + return new DescribeProducersRequest.Builder(data).build(version); + } + + private DescribeProducersResponse createDescribeProducersResponse() { + DescribeProducersResponseData data = new DescribeProducersResponseData(); + DescribeProducersResponseData.TopicResponse topicResponse = new DescribeProducersResponseData.TopicResponse(); + topicResponse.partitions().add(new DescribeProducersResponseData.PartitionResponse() + .setErrorCode(Errors.NONE.code()) + .setPartitionIndex(0) + .setActiveProducers(Arrays.asList( + new DescribeProducersResponseData.ProducerState() + .setProducerId(1234L) + .setProducerEpoch(15) + .setLastTimestamp(13490218304L) + .setCurrentTxnStartOffset(5000), + new DescribeProducersResponseData.ProducerState() + .setProducerId(9876L) + .setProducerEpoch(32) + .setLastTimestamp(13490218399L) + )) + ); + data.topics().add(topicResponse); + return new DescribeProducersResponse(data); + } + + private BrokerHeartbeatRequest createBrokerHeartbeatRequest(short v) { + BrokerHeartbeatRequestData data = new BrokerHeartbeatRequestData() + .setBrokerId(1) + .setBrokerEpoch(1) + .setCurrentMetadataOffset(1) + .setWantFence(false) + .setWantShutDown(false); + return new BrokerHeartbeatRequest.Builder(data).build(v); + } + + private BrokerHeartbeatResponse createBrokerHeartbeatResponse() { + BrokerHeartbeatResponseData data = new BrokerHeartbeatResponseData() + .setIsFenced(false) + .setShouldShutDown(false) + .setThrottleTimeMs(0); + return new BrokerHeartbeatResponse(data); + } + + private BrokerRegistrationRequest createBrokerRegistrationRequest(short v) { + BrokerRegistrationRequestData data = new BrokerRegistrationRequestData() + .setBrokerId(1) + .setClusterId(Uuid.randomUuid().toString()) + .setRack("1") + .setFeatures(new BrokerRegistrationRequestData.FeatureCollection(singletonList( + new BrokerRegistrationRequestData.Feature()).iterator())) + .setListeners(new BrokerRegistrationRequestData.ListenerCollection(singletonList( + new BrokerRegistrationRequestData.Listener()).iterator())) + .setIncarnationId(Uuid.randomUuid()); + return new BrokerRegistrationRequest.Builder(data).build(v); + } + + private BrokerRegistrationResponse createBrokerRegistrationResponse() { + BrokerRegistrationResponseData data = new BrokerRegistrationResponseData() + .setBrokerEpoch(1) + .setThrottleTimeMs(0); + return new BrokerRegistrationResponse(data); + } + + private UnregisterBrokerRequest createUnregisterBrokerRequest(short version) { + UnregisterBrokerRequestData data = new UnregisterBrokerRequestData().setBrokerId(1); + return new UnregisterBrokerRequest.Builder(data).build(version); + } + + private UnregisterBrokerResponse createUnregisterBrokerResponse() { + return new UnregisterBrokerResponse(new UnregisterBrokerResponseData()); + } + + /** + * Check that all error codes in the response get included in {@link AbstractResponse#errorCounts()}. 
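+     * A response whose only error is NONE should still report a count of one for it.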
+ */ + @Test + public void testErrorCountsIncludesNone() { + assertEquals(Integer.valueOf(1), createAddOffsetsToTxnResponse().errorCounts().get(Errors.NONE)); + assertEquals(Integer.valueOf(1), createAddPartitionsToTxnResponse().errorCounts().get(Errors.NONE)); + assertEquals(Integer.valueOf(1), createAlterClientQuotasResponse().errorCounts().get(Errors.NONE)); + assertEquals(Integer.valueOf(1), createAlterConfigsResponse().errorCounts().get(Errors.NONE)); + assertEquals(Integer.valueOf(2), createAlterPartitionReassignmentsResponse().errorCounts().get(Errors.NONE)); + assertEquals(Integer.valueOf(1), createAlterReplicaLogDirsResponse().errorCounts().get(Errors.NONE)); + assertEquals(Integer.valueOf(1), createApiVersionResponse().errorCounts().get(Errors.NONE)); + assertEquals(Integer.valueOf(1), createBrokerHeartbeatResponse().errorCounts().get(Errors.NONE)); + assertEquals(Integer.valueOf(1), createBrokerRegistrationResponse().errorCounts().get(Errors.NONE)); + assertEquals(Integer.valueOf(1), createControlledShutdownResponse().errorCounts().get(Errors.NONE)); + assertEquals(Integer.valueOf(2), createCreateAclsResponse().errorCounts().get(Errors.NONE)); + assertEquals(Integer.valueOf(1), createCreatePartitionsResponse().errorCounts().get(Errors.NONE)); + assertEquals(Integer.valueOf(1), createCreateTokenResponse().errorCounts().get(Errors.NONE)); + assertEquals(Integer.valueOf(1), createCreateTopicResponse().errorCounts().get(Errors.NONE)); + assertEquals(Integer.valueOf(1), createDeleteAclsResponse(ApiKeys.DELETE_ACLS.latestVersion()).errorCounts().get(Errors.NONE)); + assertEquals(Integer.valueOf(1), createDeleteGroupsResponse().errorCounts().get(Errors.NONE)); + assertEquals(Integer.valueOf(1), createDeleteTopicsResponse().errorCounts().get(Errors.NONE)); + assertEquals(Integer.valueOf(1), createDescribeAclsResponse().errorCounts().get(Errors.NONE)); + assertEquals(Integer.valueOf(1), createDescribeClientQuotasResponse().errorCounts().get(Errors.NONE)); + assertEquals(Integer.valueOf(2), createDescribeConfigsResponse(DESCRIBE_CONFIGS.latestVersion()).errorCounts().get(Errors.NONE)); + assertEquals(Integer.valueOf(1), createDescribeGroupResponse().errorCounts().get(Errors.NONE)); + assertEquals(Integer.valueOf(1), createDescribeTokenResponse().errorCounts().get(Errors.NONE)); + assertEquals(Integer.valueOf(2), createElectLeadersResponse().errorCounts().get(Errors.NONE)); + assertEquals(Integer.valueOf(1), createEndTxnResponse().errorCounts().get(Errors.NONE)); + assertEquals(Integer.valueOf(1), createExpireTokenResponse().errorCounts().get(Errors.NONE)); + assertEquals(Integer.valueOf(3), createFetchResponse(123).errorCounts().get(Errors.NONE)); + assertEquals(Integer.valueOf(1), createFindCoordinatorResponse(FIND_COORDINATOR.oldestVersion()).errorCounts().get(Errors.NONE)); + assertEquals(Integer.valueOf(1), createFindCoordinatorResponse(FIND_COORDINATOR.latestVersion()).errorCounts().get(Errors.NONE)); + assertEquals(Integer.valueOf(1), createHeartBeatResponse().errorCounts().get(Errors.NONE)); + assertEquals(Integer.valueOf(1), createIncrementalAlterConfigsResponse().errorCounts().get(Errors.NONE)); + assertEquals(Integer.valueOf(1), createJoinGroupResponse(JOIN_GROUP.latestVersion()).errorCounts().get(Errors.NONE)); + assertEquals(Integer.valueOf(2), createLeaderAndIsrResponse(4).errorCounts().get(Errors.NONE)); + assertEquals(Integer.valueOf(2), createLeaderAndIsrResponse(5).errorCounts().get(Errors.NONE)); + assertEquals(Integer.valueOf(3), 
createLeaderEpochResponse().errorCounts().get(Errors.NONE)); + assertEquals(Integer.valueOf(1), createLeaveGroupResponse().errorCounts().get(Errors.NONE)); + assertEquals(Integer.valueOf(1), createListGroupsResponse(LIST_GROUPS.latestVersion()).errorCounts().get(Errors.NONE)); + assertEquals(Integer.valueOf(1), createListOffsetResponse(LIST_OFFSETS.latestVersion()).errorCounts().get(Errors.NONE)); + assertEquals(Integer.valueOf(1), createListPartitionReassignmentsResponse().errorCounts().get(Errors.NONE)); + assertEquals(Integer.valueOf(3), createMetadataResponse().errorCounts().get(Errors.NONE)); + assertEquals(Integer.valueOf(1), createOffsetCommitResponse().errorCounts().get(Errors.NONE)); + assertEquals(Integer.valueOf(2), createOffsetDeleteResponse().errorCounts().get(Errors.NONE)); + assertEquals(Integer.valueOf(3), createOffsetFetchResponse(OFFSET_FETCH.latestVersion()).errorCounts().get(Errors.NONE)); + assertEquals(Integer.valueOf(1), createProduceResponse().errorCounts().get(Errors.NONE)); + assertEquals(Integer.valueOf(1), createRenewTokenResponse().errorCounts().get(Errors.NONE)); + assertEquals(Integer.valueOf(1), createSaslAuthenticateResponse().errorCounts().get(Errors.NONE)); + assertEquals(Integer.valueOf(1), createSaslHandshakeResponse().errorCounts().get(Errors.NONE)); + assertEquals(Integer.valueOf(2), createStopReplicaResponse().errorCounts().get(Errors.NONE)); + assertEquals(Integer.valueOf(1), createSyncGroupResponse(SYNC_GROUP.latestVersion()).errorCounts().get(Errors.NONE)); + assertEquals(Integer.valueOf(1), createTxnOffsetCommitResponse().errorCounts().get(Errors.NONE)); + assertEquals(Integer.valueOf(1), createUpdateMetadataResponse().errorCounts().get(Errors.NONE)); + assertEquals(Integer.valueOf(1), createWriteTxnMarkersResponse().errorCounts().get(Errors.NONE)); + } + + private DescribeTransactionsRequest createDescribeTransactionsRequest(short version) { + DescribeTransactionsRequestData data = new DescribeTransactionsRequestData() + .setTransactionalIds(asList("t1", "t2", "t3")); + return new DescribeTransactionsRequest.Builder(data).build(version); + } + + private DescribeTransactionsResponse createDescribeTransactionsResponse() { + DescribeTransactionsResponseData data = new DescribeTransactionsResponseData(); + data.setTransactionStates(asList( + new DescribeTransactionsResponseData.TransactionState() + .setErrorCode(Errors.NONE.code()) + .setTransactionalId("t1") + .setProducerId(12345L) + .setProducerEpoch((short) 15) + .setTransactionStartTimeMs(13490218304L) + .setTransactionState("Empty"), + new DescribeTransactionsResponseData.TransactionState() + .setErrorCode(Errors.NONE.code()) + .setTransactionalId("t2") + .setProducerId(98765L) + .setProducerEpoch((short) 30) + .setTransactionStartTimeMs(13490218304L) + .setTransactionState("Ongoing") + .setTopics(new DescribeTransactionsResponseData.TopicDataCollection( + asList( + new DescribeTransactionsResponseData.TopicData() + .setTopic("foo") + .setPartitions(asList(1, 3, 5, 7)), + new DescribeTransactionsResponseData.TopicData() + .setTopic("bar") + .setPartitions(asList(1, 3)) + ).iterator() + )), + new DescribeTransactionsResponseData.TransactionState() + .setErrorCode(Errors.NOT_COORDINATOR.code()) + .setTransactionalId("t3") + )); + return new DescribeTransactionsResponse(data); + } + + private ListTransactionsRequest createListTransactionsRequest(short version) { + return new ListTransactionsRequest.Builder(new ListTransactionsRequestData() + .setStateFilters(singletonList("Ongoing")) + 
.setProducerIdFilters(asList(1L, 2L, 15L)) + ).build(version); + } + + private ListTransactionsResponse createListTransactionsResponse() { + ListTransactionsResponseData response = new ListTransactionsResponseData(); + response.setErrorCode(Errors.NONE.code()); + response.setTransactionStates(Arrays.asList( + new ListTransactionsResponseData.TransactionState() + .setTransactionalId("foo") + .setProducerId(12345L) + .setTransactionState("Ongoing"), + new ListTransactionsResponseData.TransactionState() + .setTransactionalId("bar") + .setProducerId(98765L) + .setTransactionState("PrepareAbort") + )); + return new ListTransactionsResponse(response); + } + +} diff --git a/clients/src/test/java/org/apache/kafka/common/requests/RequestTestUtils.java b/clients/src/test/java/org/apache/kafka/common/requests/RequestTestUtils.java new file mode 100644 index 0000000..d50e1b9 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/requests/RequestTestUtils.java @@ -0,0 +1,240 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.requests; + +import java.util.HashMap; +import java.util.Set; +import org.apache.kafka.common.Node; +import org.apache.kafka.common.TopicIdPartition; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.Uuid; +import org.apache.kafka.common.internals.Topic; +import org.apache.kafka.common.message.MetadataResponseData; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.Errors; +import org.apache.kafka.common.protocol.ObjectSerializationCache; +import org.apache.kafka.common.record.RecordBatch; + +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.function.Function; + +public class RequestTestUtils { + + public static boolean hasIdempotentRecords(ProduceRequest request) { + return RequestUtils.flag(request, RecordBatch::hasProducerId); + } + + public static ByteBuffer serializeRequestHeader(RequestHeader header) { + ObjectSerializationCache serializationCache = new ObjectSerializationCache(); + ByteBuffer buffer = ByteBuffer.allocate(header.size(serializationCache)); + header.write(buffer, serializationCache); + buffer.flip(); + return buffer; + } + + public static ByteBuffer serializeResponseWithHeader(AbstractResponse response, short version, int correlationId) { + return response.serializeWithHeader(new ResponseHeader(correlationId, + response.apiKey().responseHeaderVersion(version)), version); + } + + public static MetadataResponse metadataResponse(Collection brokers, + String clusterId, int controllerId, + List topicMetadataList) { + return metadataResponse(brokers, clusterId, controllerId, topicMetadataList, ApiKeys.METADATA.latestVersion()); + } + + public static MetadataResponse metadataResponse(Collection brokers, + String clusterId, int controllerId, + List topicMetadataList, + short responseVersion) { + return metadataResponse(MetadataResponse.DEFAULT_THROTTLE_TIME, brokers, clusterId, controllerId, + topicMetadataList, MetadataResponse.AUTHORIZED_OPERATIONS_OMITTED, responseVersion); + } + + public static MetadataResponse metadataResponse(int throttleTimeMs, Collection brokers, + String clusterId, int controllerId, + List topicMetadatas, + int clusterAuthorizedOperations, + short responseVersion) { + List topics = new ArrayList<>(); + topicMetadatas.forEach(topicMetadata -> { + MetadataResponseData.MetadataResponseTopic metadataResponseTopic = new MetadataResponseData.MetadataResponseTopic(); + metadataResponseTopic + .setErrorCode(topicMetadata.error().code()) + .setName(topicMetadata.topic()) + .setTopicId(topicMetadata.topicId()) + .setIsInternal(topicMetadata.isInternal()) + .setTopicAuthorizedOperations(topicMetadata.authorizedOperations()); + + for (MetadataResponse.PartitionMetadata partitionMetadata : topicMetadata.partitionMetadata()) { + metadataResponseTopic.partitions().add(new MetadataResponseData.MetadataResponsePartition() + .setErrorCode(partitionMetadata.error.code()) + .setPartitionIndex(partitionMetadata.partition()) + .setLeaderId(partitionMetadata.leaderId.orElse(MetadataResponse.NO_LEADER_ID)) + .setLeaderEpoch(partitionMetadata.leaderEpoch.orElse(RecordBatch.NO_PARTITION_LEADER_EPOCH)) + .setReplicaNodes(partitionMetadata.replicaIds) + .setIsrNodes(partitionMetadata.inSyncReplicaIds) + .setOfflineReplicas(partitionMetadata.offlineReplicaIds)); + } + topics.add(metadataResponseTopic); + }); + 
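+        // Wrap the assembled broker and per-topic metadata into a MetadataResponse for the requested version.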
return MetadataResponse.prepareResponse(responseVersion, throttleTimeMs, brokers, clusterId, controllerId, + topics, clusterAuthorizedOperations); } + + public static MetadataResponse metadataUpdateWith(final int numNodes, + final Map topicPartitionCounts) { + return metadataUpdateWith("kafka-cluster", numNodes, topicPartitionCounts); + } + + public static MetadataResponse metadataUpdateWith(final int numNodes, + final Map topicPartitionCounts, + final Function epochSupplier) { + return metadataUpdateWith("kafka-cluster", numNodes, Collections.emptyMap(), + topicPartitionCounts, epochSupplier, MetadataResponse.PartitionMetadata::new, ApiKeys.METADATA.latestVersion(), Collections.emptyMap()); + } + + public static MetadataResponse metadataUpdateWith(final String clusterId, + final int numNodes, + final Map topicPartitionCounts) { + return metadataUpdateWith(clusterId, numNodes, Collections.emptyMap(), + topicPartitionCounts, tp -> null, MetadataResponse.PartitionMetadata::new, ApiKeys.METADATA.latestVersion(), Collections.emptyMap()); + } + + public static MetadataResponse metadataUpdateWith(final String clusterId, + final int numNodes, + final Map topicErrors, + final Map topicPartitionCounts) { + return metadataUpdateWith(clusterId, numNodes, topicErrors, + topicPartitionCounts, tp -> null, MetadataResponse.PartitionMetadata::new, ApiKeys.METADATA.latestVersion(), Collections.emptyMap()); + } + + public static MetadataResponse metadataUpdateWith(final String clusterId, + final int numNodes, + final Map topicErrors, + final Map topicPartitionCounts, + final short responseVersion) { + return metadataUpdateWith(clusterId, numNodes, topicErrors, + topicPartitionCounts, tp -> null, MetadataResponse.PartitionMetadata::new, responseVersion, Collections.emptyMap()); + } + + public static MetadataResponse metadataUpdateWith(final String clusterId, + final int numNodes, + final Map topicErrors, + final Map topicPartitionCounts, + final Function epochSupplier) { + return metadataUpdateWith(clusterId, numNodes, topicErrors, + topicPartitionCounts, epochSupplier, MetadataResponse.PartitionMetadata::new, ApiKeys.METADATA.latestVersion(), Collections.emptyMap()); + } + + public static MetadataResponse metadataUpdateWithIds(final int numNodes, + final Map topicPartitionCounts, + final Map topicIds) { + return metadataUpdateWith("kafka-cluster", numNodes, Collections.emptyMap(), + topicPartitionCounts, tp -> null, MetadataResponse.PartitionMetadata::new, ApiKeys.METADATA.latestVersion(), + topicIds); + } + + public static MetadataResponse metadataUpdateWithIds(final int numNodes, + final Set partitions, + final Function epochSupplier) { + final Map topicPartitionCounts = new HashMap<>(); + final Map topicIds = new HashMap<>(); + + partitions.forEach(partition -> { + topicPartitionCounts.compute(partition.topic(), (key, value) -> value == null ? 
1 : value + 1); + topicIds.putIfAbsent(partition.topic(), partition.topicId()); + }); + + return metadataUpdateWithIds(numNodes, topicPartitionCounts, epochSupplier, topicIds); + } + + public static MetadataResponse metadataUpdateWithIds(final int numNodes, + final Map topicPartitionCounts, + final Function epochSupplier, + final Map topicIds) { + return metadataUpdateWith("kafka-cluster", numNodes, Collections.emptyMap(), + topicPartitionCounts, epochSupplier, MetadataResponse.PartitionMetadata::new, ApiKeys.METADATA.latestVersion(), + topicIds); + } + + public static MetadataResponse metadataUpdateWithIds(final String clusterId, + final int numNodes, + final Map topicErrors, + final Map topicPartitionCounts, + final Function epochSupplier, + final Map topicIds) { + return metadataUpdateWith(clusterId, numNodes, topicErrors, + topicPartitionCounts, epochSupplier, MetadataResponse.PartitionMetadata::new, ApiKeys.METADATA.latestVersion(), topicIds); + } + + public static MetadataResponse metadataUpdateWith(final String clusterId, + final int numNodes, + final Map topicErrors, + final Map topicPartitionCounts, + final Function epochSupplier, + final PartitionMetadataSupplier partitionSupplier, + final short responseVersion, + final Map topicIds) { + final List nodes = new ArrayList<>(numNodes); + for (int i = 0; i < numNodes; i++) + nodes.add(new Node(i, "localhost", 1969 + i)); + + List topicMetadata = new ArrayList<>(); + for (Map.Entry topicPartitionCountEntry : topicPartitionCounts.entrySet()) { + String topic = topicPartitionCountEntry.getKey(); + int numPartitions = topicPartitionCountEntry.getValue(); + + List partitionMetadata = new ArrayList<>(numPartitions); + for (int i = 0; i < numPartitions; i++) { + TopicPartition tp = new TopicPartition(topic, i); + Node leader = nodes.get(i % nodes.size()); + List replicaIds = Collections.singletonList(leader.id()); + partitionMetadata.add(partitionSupplier.supply( + Errors.NONE, tp, Optional.of(leader.id()), Optional.ofNullable(epochSupplier.apply(tp)), + replicaIds, replicaIds, replicaIds)); + } + + topicMetadata.add(new MetadataResponse.TopicMetadata(Errors.NONE, topic, topicIds.getOrDefault(topic, Uuid.ZERO_UUID), + Topic.isInternal(topic), partitionMetadata, MetadataResponse.AUTHORIZED_OPERATIONS_OMITTED)); + } + + for (Map.Entry topicErrorEntry : topicErrors.entrySet()) { + String topic = topicErrorEntry.getKey(); + topicMetadata.add(new MetadataResponse.TopicMetadata(topicErrorEntry.getValue(), topic, + Topic.isInternal(topic), Collections.emptyList())); + } + + return metadataResponse(nodes, clusterId, 0, topicMetadata, responseVersion); + } + + @FunctionalInterface + public interface PartitionMetadataSupplier { + MetadataResponse.PartitionMetadata supply(Errors error, + TopicPartition partition, + Optional leaderId, + Optional leaderEpoch, + List replicas, + List isr, + List offlineReplicas); + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/requests/StopReplicaRequestTest.java b/clients/src/test/java/org/apache/kafka/common/requests/StopReplicaRequestTest.java new file mode 100644 index 0000000..3446d49 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/requests/StopReplicaRequestTest.java @@ -0,0 +1,246 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.errors.ClusterAuthorizationException; +import org.apache.kafka.common.errors.UnsupportedVersionException; +import org.apache.kafka.common.message.StopReplicaRequestData; +import org.apache.kafka.common.message.StopReplicaRequestData.StopReplicaPartitionState; +import org.apache.kafka.common.message.StopReplicaRequestData.StopReplicaPartitionV0; +import org.apache.kafka.common.message.StopReplicaRequestData.StopReplicaTopicV1; +import org.apache.kafka.common.message.StopReplicaRequestData.StopReplicaTopicState; +import org.apache.kafka.common.message.StopReplicaResponseData.StopReplicaPartitionError; +import org.apache.kafka.common.protocol.Errors; +import org.junit.jupiter.api.Test; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import static org.apache.kafka.common.protocol.ApiKeys.STOP_REPLICA; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class StopReplicaRequestTest { + + @Test + public void testUnsupportedVersion() { + StopReplicaRequest.Builder builder = new StopReplicaRequest.Builder( + (short) (STOP_REPLICA.latestVersion() + 1), + 0, 0, 0L, false, Collections.emptyList()); + assertThrows(UnsupportedVersionException.class, builder::build); + } + + @Test + public void testGetErrorResponse() { + List topicStates = topicStates(true); + + Set expectedPartitions = new HashSet<>(); + for (StopReplicaTopicState topicState : topicStates) { + for (StopReplicaPartitionState partitionState: topicState.partitionStates()) { + expectedPartitions.add(new StopReplicaPartitionError() + .setTopicName(topicState.topicName()) + .setPartitionIndex(partitionState.partitionIndex()) + .setErrorCode(Errors.CLUSTER_AUTHORIZATION_FAILED.code())); + } + } + + for (short version : STOP_REPLICA.allVersions()) { + StopReplicaRequest.Builder builder = new StopReplicaRequest.Builder(version, + 0, 0, 0L, false, topicStates); + StopReplicaRequest request = builder.build(); + StopReplicaResponse response = request.getErrorResponse(0, + new ClusterAuthorizationException("Not authorized")); + assertEquals(Errors.CLUSTER_AUTHORIZATION_FAILED, response.error()); + assertEquals(expectedPartitions, new HashSet<>(response.partitionErrors())); + } + } + + @Test + public void testBuilderNormalizationWithAllDeletePartitionEqualToTrue() { + testBuilderNormalization(true); + } + + @Test + public void testBuilderNormalizationWithAllDeletePartitionEqualToFalse() { + testBuilderNormalization(false); + } + + private void testBuilderNormalization(boolean deletePartitions) { + List 
topicStates = topicStates(deletePartitions); + + Map expectedPartitionStates = + StopReplicaRequestTest.partitionStates(topicStates); + + for (short version : STOP_REPLICA.allVersions()) { + StopReplicaRequest request = new StopReplicaRequest.Builder(version, 0, 1, 0, + deletePartitions, topicStates).build(version); + StopReplicaRequestData data = request.data(); + + if (version < 1) { + Set partitions = new HashSet<>(); + for (StopReplicaPartitionV0 partition : data.ungroupedPartitions()) { + partitions.add(new TopicPartition(partition.topicName(), partition.partitionIndex())); + } + assertEquals(expectedPartitionStates.keySet(), partitions); + assertEquals(deletePartitions, data.deletePartitions()); + } else if (version < 3) { + Set partitions = new HashSet<>(); + for (StopReplicaTopicV1 topic : data.topics()) { + for (Integer partition : topic.partitionIndexes()) { + partitions.add(new TopicPartition(topic.name(), partition)); + } + } + assertEquals(expectedPartitionStates.keySet(), partitions); + assertEquals(deletePartitions, data.deletePartitions()); + } else { + Map partitionStates = + StopReplicaRequestTest.partitionStates(data.topicStates()); + assertEquals(expectedPartitionStates, partitionStates); + // Always false from V3 on + assertFalse(data.deletePartitions()); + } + } + } + + @Test + public void testTopicStatesNormalization() { + List topicStates = topicStates(true); + + for (short version : STOP_REPLICA.allVersions()) { + // Create a request for version to get its serialized form + StopReplicaRequest baseRequest = new StopReplicaRequest.Builder(version, 0, 1, 0, + true, topicStates).build(version); + + // Construct the request from the buffer + StopReplicaRequest request = StopReplicaRequest.parse(baseRequest.serialize(), version); + + Map partitionStates = + StopReplicaRequestTest.partitionStates(request.topicStates()); + assertEquals(6, partitionStates.size()); + + for (StopReplicaTopicState expectedTopicState : topicStates) { + for (StopReplicaPartitionState expectedPartitionState: expectedTopicState.partitionStates()) { + TopicPartition tp = new TopicPartition(expectedTopicState.topicName(), + expectedPartitionState.partitionIndex()); + StopReplicaPartitionState partitionState = partitionStates.get(tp); + + assertEquals(expectedPartitionState.partitionIndex(), partitionState.partitionIndex()); + assertTrue(partitionState.deletePartition()); + + if (version >= 3) { + assertEquals(expectedPartitionState.leaderEpoch(), partitionState.leaderEpoch()); + } else { + assertEquals(-1, partitionState.leaderEpoch()); + } + } + } + } + } + + @Test + public void testPartitionStatesNormalization() { + List topicStates = topicStates(true); + + for (short version : STOP_REPLICA.allVersions()) { + // Create a request for version to get its serialized form + StopReplicaRequest baseRequest = new StopReplicaRequest.Builder(version, 0, 1, 0, + true, topicStates).build(version); + + // Construct the request from the buffer + StopReplicaRequest request = StopReplicaRequest.parse(baseRequest.serialize(), version); + + Map partitionStates = request.partitionStates(); + assertEquals(6, partitionStates.size()); + + for (StopReplicaTopicState expectedTopicState : topicStates) { + for (StopReplicaPartitionState expectedPartitionState: expectedTopicState.partitionStates()) { + TopicPartition tp = new TopicPartition(expectedTopicState.topicName(), + expectedPartitionState.partitionIndex()); + StopReplicaPartitionState partitionState = partitionStates.get(tp); + + 
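+                // Partition index and the deletePartition flag should survive the round trip for every
+                // version; leaderEpoch is only carried from version 3 onwards, so older versions fall
+                // back to -1 (asserted below).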
assertEquals(expectedPartitionState.partitionIndex(), partitionState.partitionIndex()); + assertTrue(partitionState.deletePartition()); + + if (version >= 3) { + assertEquals(expectedPartitionState.leaderEpoch(), partitionState.leaderEpoch()); + } else { + assertEquals(-1, partitionState.leaderEpoch()); + } + } + } + } + } + + private List topicStates(boolean deletePartition) { + List topicStates = new ArrayList<>(); + StopReplicaTopicState topic0 = new StopReplicaTopicState() + .setTopicName("topic0"); + topic0.partitionStates().add(new StopReplicaPartitionState() + .setPartitionIndex(0) + .setLeaderEpoch(0) + .setDeletePartition(deletePartition)); + topic0.partitionStates().add(new StopReplicaPartitionState() + .setPartitionIndex(1) + .setLeaderEpoch(1) + .setDeletePartition(deletePartition)); + topicStates.add(topic0); + StopReplicaTopicState topic1 = new StopReplicaTopicState() + .setTopicName("topic1"); + topic1.partitionStates().add(new StopReplicaPartitionState() + .setPartitionIndex(2) + .setLeaderEpoch(2) + .setDeletePartition(deletePartition)); + topic1.partitionStates().add(new StopReplicaPartitionState() + .setPartitionIndex(3) + .setLeaderEpoch(3) + .setDeletePartition(deletePartition)); + topicStates.add(topic1); + StopReplicaTopicState topic3 = new StopReplicaTopicState() + .setTopicName("topic1"); + topic3.partitionStates().add(new StopReplicaPartitionState() + .setPartitionIndex(4) + .setLeaderEpoch(-2) + .setDeletePartition(deletePartition)); + topic3.partitionStates().add(new StopReplicaPartitionState() + .setPartitionIndex(5) + .setLeaderEpoch(-2) + .setDeletePartition(deletePartition)); + topicStates.add(topic3); + return topicStates; + } + + public static Map partitionStates( + Iterable topicStates) { + Map partitionStates = new HashMap<>(); + for (StopReplicaTopicState topicState : topicStates) { + for (StopReplicaPartitionState partitionState: topicState.partitionStates()) { + partitionStates.put( + new TopicPartition(topicState.topicName(), partitionState.partitionIndex()), + partitionState); + } + } + return partitionStates; + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/requests/StopReplicaResponseTest.java b/clients/src/test/java/org/apache/kafka/common/requests/StopReplicaResponseTest.java new file mode 100644 index 0000000..a0a5eda --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/requests/StopReplicaResponseTest.java @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.message.StopReplicaRequestData.StopReplicaPartitionState; +import org.apache.kafka.common.message.StopReplicaRequestData.StopReplicaTopicState; +import org.apache.kafka.common.message.StopReplicaResponseData; +import org.apache.kafka.common.message.StopReplicaResponseData.StopReplicaPartitionError; +import org.apache.kafka.common.protocol.Errors; +import org.junit.jupiter.api.Test; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Map; + +import static org.apache.kafka.common.protocol.ApiKeys.STOP_REPLICA; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class StopReplicaResponseTest { + + @Test + public void testErrorCountsFromGetErrorResponse() { + List topicStates = new ArrayList<>(); + topicStates.add(new StopReplicaTopicState() + .setTopicName("foo") + .setPartitionStates(Arrays.asList( + new StopReplicaPartitionState().setPartitionIndex(0), + new StopReplicaPartitionState().setPartitionIndex(1)))); + + for (short version : STOP_REPLICA.allVersions()) { + StopReplicaRequest request = new StopReplicaRequest.Builder(version, + 15, 20, 0, false, topicStates).build(version); + StopReplicaResponse response = request + .getErrorResponse(0, Errors.CLUSTER_AUTHORIZATION_FAILED.exception()); + assertEquals(Collections.singletonMap(Errors.CLUSTER_AUTHORIZATION_FAILED, 3), + response.errorCounts()); + } + } + + @Test + public void testErrorCountsWithTopLevelError() { + List errors = new ArrayList<>(); + errors.add(new StopReplicaPartitionError().setTopicName("foo").setPartitionIndex(0)); + errors.add(new StopReplicaPartitionError().setTopicName("foo").setPartitionIndex(1) + .setErrorCode(Errors.NOT_LEADER_OR_FOLLOWER.code())); + StopReplicaResponse response = new StopReplicaResponse(new StopReplicaResponseData() + .setErrorCode(Errors.UNKNOWN_SERVER_ERROR.code()) + .setPartitionErrors(errors)); + assertEquals(Collections.singletonMap(Errors.UNKNOWN_SERVER_ERROR, 3), response.errorCounts()); + } + + @Test + public void testErrorCountsNoTopLevelError() { + List errors = new ArrayList<>(); + errors.add(new StopReplicaPartitionError().setTopicName("foo").setPartitionIndex(0)); + errors.add(new StopReplicaPartitionError().setTopicName("foo").setPartitionIndex(1) + .setErrorCode(Errors.CLUSTER_AUTHORIZATION_FAILED.code())); + StopReplicaResponse response = new StopReplicaResponse(new StopReplicaResponseData() + .setErrorCode(Errors.NONE.code()) + .setPartitionErrors(errors)); + Map errorCounts = response.errorCounts(); + assertEquals(2, errorCounts.size()); + assertEquals(2, errorCounts.get(Errors.NONE).intValue()); + assertEquals(1, errorCounts.get(Errors.CLUSTER_AUTHORIZATION_FAILED).intValue()); + } + + @Test + public void testToString() { + List errors = new ArrayList<>(); + errors.add(new StopReplicaPartitionError().setTopicName("foo").setPartitionIndex(0)); + errors.add(new StopReplicaPartitionError().setTopicName("foo").setPartitionIndex(1) + .setErrorCode(Errors.CLUSTER_AUTHORIZATION_FAILED.code())); + StopReplicaResponse response = new StopReplicaResponse(new StopReplicaResponseData().setPartitionErrors(errors)); + String responseStr = response.toString(); + assertTrue(responseStr.contains(StopReplicaResponse.class.getSimpleName())); + assertTrue(responseStr.contains(errors.toString())); + assertTrue(responseStr.contains("errorCode=" + 
Errors.NONE.code())); + } + +} diff --git a/clients/src/test/java/org/apache/kafka/common/requests/SyncGroupRequestTest.java b/clients/src/test/java/org/apache/kafka/common/requests/SyncGroupRequestTest.java new file mode 100644 index 0000000..038c989 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/requests/SyncGroupRequestTest.java @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.errors.UnsupportedVersionException; +import org.apache.kafka.common.message.SyncGroupRequestData; +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertThrows; + +public class SyncGroupRequestTest { + + @Test + public void testRequestVersionCompatibilityFailBuild() { + assertThrows(UnsupportedVersionException.class, () -> new SyncGroupRequest.Builder( + new SyncGroupRequestData() + .setGroupId("groupId") + .setMemberId("consumerId") + .setGroupInstanceId("groupInstanceId") + ).build((short) 2)); + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/requests/TxnOffsetCommitRequestTest.java b/clients/src/test/java/org/apache/kafka/common/requests/TxnOffsetCommitRequestTest.java new file mode 100644 index 0000000..d49bdce --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/requests/TxnOffsetCommitRequestTest.java @@ -0,0 +1,130 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.message.TxnOffsetCommitRequestData.TxnOffsetCommitRequestPartition; +import org.apache.kafka.common.message.TxnOffsetCommitRequestData.TxnOffsetCommitRequestTopic; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.Errors; +import org.apache.kafka.common.requests.TxnOffsetCommitRequest.CommittedOffset; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class TxnOffsetCommitRequestTest extends OffsetCommitRequestTest { + + private static String transactionalId = "transactionalId"; + private static int producerId = 10; + private static short producerEpoch = 1; + private static int generationId = 5; + private static Map offsets = new HashMap<>(); + private static TxnOffsetCommitRequest.Builder builder; + private static TxnOffsetCommitRequest.Builder builderWithGroupMetadata; + + @BeforeEach + @Override + public void setUp() { + super.setUp(); + offsets.clear(); + offsets.put(new TopicPartition(topicOne, partitionOne), + new CommittedOffset( + offset, + metadata, + Optional.of((int) leaderEpoch))); + offsets.put(new TopicPartition(topicTwo, partitionTwo), + new CommittedOffset( + offset, + metadata, + Optional.of((int) leaderEpoch))); + + builder = new TxnOffsetCommitRequest.Builder( + transactionalId, + groupId, + producerId, + producerEpoch, + offsets + ); + + builderWithGroupMetadata = new TxnOffsetCommitRequest.Builder( + transactionalId, + groupId, + producerId, + producerEpoch, + offsets, + memberId, + generationId, + Optional.of(groupInstanceId) + ); + } + + @Test + @Override + public void testConstructor() { + + Map errorsMap = new HashMap<>(); + errorsMap.put(new TopicPartition(topicOne, partitionOne), Errors.NOT_COORDINATOR); + errorsMap.put(new TopicPartition(topicTwo, partitionTwo), Errors.NOT_COORDINATOR); + + List expectedTopics = Arrays.asList( + new TxnOffsetCommitRequestTopic() + .setName(topicOne) + .setPartitions(Collections.singletonList( + new TxnOffsetCommitRequestPartition() + .setPartitionIndex(partitionOne) + .setCommittedOffset(offset) + .setCommittedLeaderEpoch(leaderEpoch) + .setCommittedMetadata(metadata) + )), + new TxnOffsetCommitRequestTopic() + .setName(topicTwo) + .setPartitions(Collections.singletonList( + new TxnOffsetCommitRequestPartition() + .setPartitionIndex(partitionTwo) + .setCommittedOffset(offset) + .setCommittedLeaderEpoch(leaderEpoch) + .setCommittedMetadata(metadata) + )) + ); + + for (short version : ApiKeys.TXN_OFFSET_COMMIT.allVersions()) { + final TxnOffsetCommitRequest request; + if (version < 3) { + request = builder.build(version); + } else { + request = builderWithGroupMetadata.build(version); + } + assertEquals(offsets, request.offsets()); + assertEquals(expectedTopics, TxnOffsetCommitRequest.getTopics(request.offsets())); + + TxnOffsetCommitResponse response = + request.getErrorResponse(throttleTimeMs, Errors.NOT_COORDINATOR.exception()); + + assertEquals(errorsMap, response.errors()); + assertEquals(Collections.singletonMap(Errors.NOT_COORDINATOR, 2), response.errorCounts()); + assertEquals(throttleTimeMs, response.throttleTimeMs()); + } + } +} diff --git 
a/clients/src/test/java/org/apache/kafka/common/requests/TxnOffsetCommitResponseTest.java b/clients/src/test/java/org/apache/kafka/common/requests/TxnOffsetCommitResponseTest.java new file mode 100644 index 0000000..1f19ff2 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/requests/TxnOffsetCommitResponseTest.java @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.message.TxnOffsetCommitResponseData; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.MessageUtil; +import org.junit.jupiter.api.Test; + +import java.util.Arrays; +import java.util.Collections; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class TxnOffsetCommitResponseTest extends OffsetCommitResponseTest { + + @Test + @Override + public void testConstructorWithErrorResponse() { + TxnOffsetCommitResponse response = new TxnOffsetCommitResponse(throttleTimeMs, errorsMap); + + assertEquals(errorsMap, response.errors()); + assertEquals(expectedErrorCounts, response.errorCounts()); + assertEquals(throttleTimeMs, response.throttleTimeMs()); + } + + @Test + @Override + public void testParse() { + TxnOffsetCommitResponseData data = new TxnOffsetCommitResponseData() + .setThrottleTimeMs(throttleTimeMs) + .setTopics(Arrays.asList( + new TxnOffsetCommitResponseData.TxnOffsetCommitResponseTopic().setPartitions( + Collections.singletonList(new TxnOffsetCommitResponseData.TxnOffsetCommitResponsePartition() + .setPartitionIndex(partitionOne) + .setErrorCode(errorOne.code()))), + new TxnOffsetCommitResponseData.TxnOffsetCommitResponseTopic().setPartitions( + Collections.singletonList(new TxnOffsetCommitResponseData.TxnOffsetCommitResponsePartition() + .setPartitionIndex(partitionTwo) + .setErrorCode(errorTwo.code()))) + )); + + for (short version : ApiKeys.TXN_OFFSET_COMMIT.allVersions()) { + TxnOffsetCommitResponse response = TxnOffsetCommitResponse.parse( + MessageUtil.toByteBuffer(data, version), version); + assertEquals(expectedErrorCounts, response.errorCounts()); + assertEquals(throttleTimeMs, response.throttleTimeMs()); + assertEquals(version >= 1, response.shouldClientThrottle(version)); + } + } + +} diff --git a/clients/src/test/java/org/apache/kafka/common/requests/UpdateFeaturesRequestTest.java b/clients/src/test/java/org/apache/kafka/common/requests/UpdateFeaturesRequestTest.java new file mode 100644 index 0000000..1b63aec --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/requests/UpdateFeaturesRequestTest.java @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.errors.UnknownServerException; +import org.apache.kafka.common.message.UpdateFeaturesRequestData; +import org.apache.kafka.common.protocol.Errors; +import org.junit.jupiter.api.Test; + +import java.util.Collections; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class UpdateFeaturesRequestTest { + + @Test + public void testGetErrorResponse() { + UpdateFeaturesRequestData.FeatureUpdateKeyCollection features = + new UpdateFeaturesRequestData.FeatureUpdateKeyCollection(); + + features.add(new UpdateFeaturesRequestData.FeatureUpdateKey() + .setFeature("foo") + .setMaxVersionLevel((short) 2) + ); + + features.add(new UpdateFeaturesRequestData.FeatureUpdateKey() + .setFeature("bar") + .setMaxVersionLevel((short) 3) + ); + + UpdateFeaturesRequest request = new UpdateFeaturesRequest( + new UpdateFeaturesRequestData().setFeatureUpdates(features), + UpdateFeaturesRequestData.HIGHEST_SUPPORTED_VERSION + ); + + UpdateFeaturesResponse response = request.getErrorResponse(0, new UnknownServerException()); + assertEquals(Errors.UNKNOWN_SERVER_ERROR, response.topLevelError().error()); + assertEquals(0, response.data().results().size()); + assertEquals(Collections.singletonMap(Errors.UNKNOWN_SERVER_ERROR, 1), response.errorCounts()); + } + +} diff --git a/clients/src/test/java/org/apache/kafka/common/requests/UpdateFeaturesResponseTest.java b/clients/src/test/java/org/apache/kafka/common/requests/UpdateFeaturesResponseTest.java new file mode 100644 index 0000000..130fb80 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/requests/UpdateFeaturesResponseTest.java @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.message.UpdateFeaturesResponseData; +import org.apache.kafka.common.protocol.Errors; +import org.junit.jupiter.api.Test; + +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class UpdateFeaturesResponseTest { + + @Test + public void testErrorCounts() { + UpdateFeaturesResponseData.UpdatableFeatureResultCollection results = + new UpdateFeaturesResponseData.UpdatableFeatureResultCollection(); + + results.add(new UpdateFeaturesResponseData.UpdatableFeatureResult() + .setFeature("foo") + .setErrorCode(Errors.UNKNOWN_SERVER_ERROR.code()) + ); + + results.add(new UpdateFeaturesResponseData.UpdatableFeatureResult() + .setFeature("bar") + .setErrorCode(Errors.UNKNOWN_SERVER_ERROR.code()) + ); + + results.add(new UpdateFeaturesResponseData.UpdatableFeatureResult() + .setFeature("baz") + .setErrorCode(Errors.FEATURE_UPDATE_FAILED.code()) + ); + + UpdateFeaturesResponse response = new UpdateFeaturesResponse(new UpdateFeaturesResponseData() + .setErrorCode(Errors.INVALID_REQUEST.code()) + .setResults(results) + ); + + Map errorCounts = response.errorCounts(); + assertEquals(3, errorCounts.size()); + assertEquals(1, errorCounts.get(Errors.INVALID_REQUEST).intValue()); + assertEquals(2, errorCounts.get(Errors.UNKNOWN_SERVER_ERROR).intValue()); + assertEquals(1, errorCounts.get(Errors.FEATURE_UPDATE_FAILED).intValue()); + } + +} diff --git a/clients/src/test/java/org/apache/kafka/common/requests/UpdateMetadataRequestTest.java b/clients/src/test/java/org/apache/kafka/common/requests/UpdateMetadataRequestTest.java new file mode 100644 index 0000000..6f9d5c2 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/requests/UpdateMetadataRequestTest.java @@ -0,0 +1,232 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.Uuid; +import org.apache.kafka.common.errors.ClusterAuthorizationException; +import org.apache.kafka.common.errors.UnsupportedVersionException; +import org.apache.kafka.common.message.UpdateMetadataRequestData; +import org.apache.kafka.common.message.UpdateMetadataRequestData.UpdateMetadataBroker; +import org.apache.kafka.common.message.UpdateMetadataRequestData.UpdateMetadataEndpoint; +import org.apache.kafka.common.message.UpdateMetadataRequestData.UpdateMetadataPartitionState; +import org.apache.kafka.common.network.ListenerName; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.protocol.Errors; +import org.apache.kafka.common.security.auth.SecurityProtocol; +import org.apache.kafka.test.TestUtils; +import org.junit.jupiter.api.Test; + +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; +import java.util.stream.StreamSupport; + +import static java.util.Arrays.asList; +import static java.util.Collections.emptyList; +import static org.apache.kafka.common.protocol.ApiKeys.UPDATE_METADATA; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class UpdateMetadataRequestTest { + + @Test + public void testUnsupportedVersion() { + UpdateMetadataRequest.Builder builder = new UpdateMetadataRequest.Builder( + (short) (UPDATE_METADATA.latestVersion() + 1), 0, 0, 0, + Collections.emptyList(), Collections.emptyList(), Collections.emptyMap()); + assertThrows(UnsupportedVersionException.class, builder::build); + } + + @Test + public void testGetErrorResponse() { + for (short version : UPDATE_METADATA.allVersions()) { + UpdateMetadataRequest.Builder builder = new UpdateMetadataRequest.Builder( + version, 0, 0, 0, Collections.emptyList(), Collections.emptyList(), Collections.emptyMap()); + UpdateMetadataRequest request = builder.build(); + UpdateMetadataResponse response = request.getErrorResponse(0, + new ClusterAuthorizationException("Not authorized")); + assertEquals(Errors.CLUSTER_AUTHORIZATION_FAILED, response.error()); + } + } + + /** + * Verifies that the logic we have in UpdateMetadataRequest to present a unified interface across the various versions + * works correctly. For example, `UpdateMetadataPartitionState.topicName` is not serialized/deserialized in + * recent versions, but we set it manually so that we can always present the ungrouped partition states + * independently of the version. 
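+ * <p>
+ * A rough sketch of that normalization idea, assuming the usual generated accessor/setter names on the
+ * message classes; this is an illustration, not necessarily the exact implementation:
+ * <pre>{@code
+ * // after parsing a grouped (topic-keyed) version, copy the topic name down onto each partition
+ * // state so partitionStates() can be iterated the same way regardless of version (assumed shape)
+ * for (UpdateMetadataTopicState topicState : data.topicStates()) {
+ *     for (UpdateMetadataPartitionState partitionState : topicState.partitionStates()) {
+ *         partitionState.setTopicName(topicState.topicName());
+ *     }
+ * }
+ * }</pre>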
+ */ + @Test + public void testVersionLogic() { + String topic0 = "topic0"; + String topic1 = "topic1"; + for (short version : UPDATE_METADATA.allVersions()) { + List partitionStates = asList( + new UpdateMetadataPartitionState() + .setTopicName(topic0) + .setPartitionIndex(0) + .setControllerEpoch(2) + .setLeader(0) + .setLeaderEpoch(10) + .setIsr(asList(0, 1)) + .setZkVersion(10) + .setReplicas(asList(0, 1, 2)) + .setOfflineReplicas(asList(2)), + new UpdateMetadataPartitionState() + .setTopicName(topic0) + .setPartitionIndex(1) + .setControllerEpoch(2) + .setLeader(1) + .setLeaderEpoch(11) + .setIsr(asList(1, 2, 3)) + .setZkVersion(11) + .setReplicas(asList(1, 2, 3)) + .setOfflineReplicas(emptyList()), + new UpdateMetadataPartitionState() + .setTopicName(topic1) + .setPartitionIndex(0) + .setControllerEpoch(2) + .setLeader(2) + .setLeaderEpoch(11) + .setIsr(asList(2, 3)) + .setZkVersion(11) + .setReplicas(asList(2, 3, 4)) + .setOfflineReplicas(emptyList()) + ); + + List broker0Endpoints = new ArrayList<>(); + broker0Endpoints.add( + new UpdateMetadataEndpoint() + .setHost("host0") + .setPort(9090) + .setSecurityProtocol(SecurityProtocol.PLAINTEXT.id)); + + // Non plaintext endpoints are only supported from version 1 + if (version >= 1) { + broker0Endpoints.add(new UpdateMetadataEndpoint() + .setHost("host0") + .setPort(9091) + .setSecurityProtocol(SecurityProtocol.SSL.id)); + } + + // Custom listeners are only supported from version 3 + if (version >= 3) { + broker0Endpoints.get(0).setListener("listener0"); + broker0Endpoints.get(1).setListener("listener1"); + } + + List liveBrokers = asList( + new UpdateMetadataBroker() + .setId(0) + .setRack("rack0") + .setEndpoints(broker0Endpoints), + new UpdateMetadataBroker() + .setId(1) + .setEndpoints(asList( + new UpdateMetadataEndpoint() + .setHost("host1") + .setPort(9090) + .setSecurityProtocol(SecurityProtocol.PLAINTEXT.id) + .setListener("PLAINTEXT") + )) + ); + + Map topicIds = new HashMap<>(); + topicIds.put(topic0, Uuid.randomUuid()); + topicIds.put(topic1, Uuid.randomUuid()); + + UpdateMetadataRequest request = new UpdateMetadataRequest.Builder(version, 1, 2, 3, + partitionStates, liveBrokers, topicIds).build(); + + assertEquals(new HashSet<>(partitionStates), iterableToSet(request.partitionStates())); + assertEquals(liveBrokers, request.liveBrokers()); + assertEquals(1, request.controllerId()); + assertEquals(2, request.controllerEpoch()); + assertEquals(3, request.brokerEpoch()); + + ByteBuffer byteBuffer = request.serialize(); + UpdateMetadataRequest deserializedRequest = new UpdateMetadataRequest(new UpdateMetadataRequestData( + new ByteBufferAccessor(byteBuffer), version), version); + + // Unset fields that are not supported in this version as the deserialized request won't have them + + // Rack is only supported from version 2 + if (version < 2) { + for (UpdateMetadataBroker liveBroker : liveBrokers) + liveBroker.setRack(""); + } + + // Non plaintext listener name is only supported from version 3 + if (version < 3) { + for (UpdateMetadataBroker liveBroker : liveBrokers) { + for (UpdateMetadataEndpoint endpoint : liveBroker.endpoints()) { + SecurityProtocol securityProtocol = SecurityProtocol.forId(endpoint.securityProtocol()); + endpoint.setListener(ListenerName.forSecurityProtocol(securityProtocol).value()); + } + } + } + + // Offline replicas are only supported from version 4 + if (version < 4) + partitionStates.get(0).setOfflineReplicas(emptyList()); + + assertEquals(new HashSet<>(partitionStates), 
iterableToSet(deserializedRequest.partitionStates())); + assertEquals(liveBrokers, deserializedRequest.liveBrokers()); + assertEquals(1, deserializedRequest.controllerId()); + assertEquals(2, deserializedRequest.controllerEpoch()); + // Broker epoch is only supported from version 5 + if (version >= 5) + assertEquals(3, deserializedRequest.brokerEpoch()); + else + assertEquals(-1, deserializedRequest.brokerEpoch()); + + long topicIdCount = deserializedRequest.data().topicStates().stream() + .map(UpdateMetadataRequestData.UpdateMetadataTopicState::topicId) + .filter(topicId -> topicId != Uuid.ZERO_UUID).count(); + if (version >= 7) + assertEquals(2, topicIdCount); + else + assertEquals(0, topicIdCount); + } + } + + @Test + public void testTopicPartitionGroupingSizeReduction() { + Set tps = TestUtils.generateRandomTopicPartitions(10, 10); + List partitionStates = new ArrayList<>(); + for (TopicPartition tp : tps) { + partitionStates.add(new UpdateMetadataPartitionState() + .setTopicName(tp.topic()) + .setPartitionIndex(tp.partition())); + } + UpdateMetadataRequest.Builder builder = new UpdateMetadataRequest.Builder((short) 5, 0, 0, 0, + partitionStates, Collections.emptyList(), Collections.emptyMap()); + + assertTrue(builder.build((short) 5).sizeInBytes() < builder.build((short) 4).sizeInBytes()); + } + + private Set iterableToSet(Iterable iterable) { + return StreamSupport.stream(iterable.spliterator(), false).collect(Collectors.toSet()); + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/requests/WriteTxnMarkersRequestTest.java b/clients/src/test/java/org/apache/kafka/common/requests/WriteTxnMarkersRequestTest.java new file mode 100644 index 0000000..13e8c8c --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/requests/WriteTxnMarkersRequestTest.java @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.Errors; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.util.Collections; +import java.util.List; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class WriteTxnMarkersRequestTest { + + private static long producerId = 10L; + private static short producerEpoch = 2; + private static int coordinatorEpoch = 1; + private static TransactionResult result = TransactionResult.COMMIT; + private static TopicPartition topicPartition = new TopicPartition("topic", 73); + + protected static int throttleTimeMs = 10; + + private static List markers; + + @BeforeEach + public void setUp() { + markers = Collections.singletonList( + new WriteTxnMarkersRequest.TxnMarkerEntry( + producerId, producerEpoch, coordinatorEpoch, + result, Collections.singletonList(topicPartition)) + ); + } + + @Test + public void testConstructor() { + WriteTxnMarkersRequest.Builder builder = new WriteTxnMarkersRequest.Builder(ApiKeys.WRITE_TXN_MARKERS.latestVersion(), markers); + for (short version : ApiKeys.WRITE_TXN_MARKERS.allVersions()) { + WriteTxnMarkersRequest request = builder.build(version); + assertEquals(1, request.markers().size()); + WriteTxnMarkersRequest.TxnMarkerEntry marker = request.markers().get(0); + assertEquals(producerId, marker.producerId()); + assertEquals(producerEpoch, marker.producerEpoch()); + assertEquals(coordinatorEpoch, marker.coordinatorEpoch()); + assertEquals(result, marker.transactionResult()); + assertEquals(Collections.singletonList(topicPartition), marker.partitions()); + } + } + + @Test + public void testGetErrorResponse() { + WriteTxnMarkersRequest.Builder builder = new WriteTxnMarkersRequest.Builder(ApiKeys.WRITE_TXN_MARKERS.latestVersion(), markers); + for (short version : ApiKeys.WRITE_TXN_MARKERS.allVersions()) { + WriteTxnMarkersRequest request = builder.build(version); + WriteTxnMarkersResponse errorResponse = + request.getErrorResponse(throttleTimeMs, Errors.UNKNOWN_PRODUCER_ID.exception()); + + assertEquals(Collections.singletonMap( + topicPartition, Errors.UNKNOWN_PRODUCER_ID), errorResponse.errorsByProducerId().get(producerId)); + assertEquals(Collections.singletonMap(Errors.UNKNOWN_PRODUCER_ID, 1), errorResponse.errorCounts()); + // Write txn marker has no throttle time defined in response. + assertEquals(0, errorResponse.throttleTimeMs()); + } + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/requests/WriteTxnMarkersResponseTest.java b/clients/src/test/java/org/apache/kafka/common/requests/WriteTxnMarkersResponseTest.java new file mode 100644 index 0000000..2a07412 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/requests/WriteTxnMarkersResponseTest.java @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.requests; + +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.protocol.Errors; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class WriteTxnMarkersResponseTest { + + private static long producerIdOne = 1L; + private static long producerIdTwo = 2L; + + private static TopicPartition tp1 = new TopicPartition("topic", 1); + private static TopicPartition tp2 = new TopicPartition("topic", 2); + + private static Errors pidOneError = Errors.UNKNOWN_PRODUCER_ID; + private static Errors pidTwoError = Errors.INVALID_PRODUCER_EPOCH; + + private static Map> errorMap; + + @BeforeEach + public void setUp() { + errorMap = new HashMap<>(); + errorMap.put(producerIdOne, Collections.singletonMap(tp1, pidOneError)); + errorMap.put(producerIdTwo, Collections.singletonMap(tp2, pidTwoError)); + } + + @Test + public void testConstructor() { + Map expectedErrorCounts = new HashMap<>(); + expectedErrorCounts.put(Errors.UNKNOWN_PRODUCER_ID, 1); + expectedErrorCounts.put(Errors.INVALID_PRODUCER_EPOCH, 1); + WriteTxnMarkersResponse response = new WriteTxnMarkersResponse(errorMap); + assertEquals(expectedErrorCounts, response.errorCounts()); + assertEquals(Collections.singletonMap(tp1, pidOneError), response.errorsByProducerId().get(producerIdOne)); + assertEquals(Collections.singletonMap(tp2, pidTwoError), response.errorsByProducerId().get(producerIdTwo)); + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/resource/ResourceTypeTest.java b/clients/src/test/java/org/apache/kafka/common/resource/ResourceTypeTest.java new file mode 100644 index 0000000..fcde968 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/resource/ResourceTypeTest.java @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.resource; + +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class ResourceTypeTest { + private static class AclResourceTypeTestInfo { + private final ResourceType resourceType; + private final int code; + private final String name; + private final boolean unknown; + + AclResourceTypeTestInfo(ResourceType resourceType, int code, String name, boolean unknown) { + this.resourceType = resourceType; + this.code = code; + this.name = name; + this.unknown = unknown; + } + } + + private static final AclResourceTypeTestInfo[] INFOS = { + new AclResourceTypeTestInfo(ResourceType.UNKNOWN, 0, "unknown", true), + new AclResourceTypeTestInfo(ResourceType.ANY, 1, "any", false), + new AclResourceTypeTestInfo(ResourceType.TOPIC, 2, "topic", false), + new AclResourceTypeTestInfo(ResourceType.GROUP, 3, "group", false), + new AclResourceTypeTestInfo(ResourceType.CLUSTER, 4, "cluster", false), + new AclResourceTypeTestInfo(ResourceType.TRANSACTIONAL_ID, 5, "transactional_id", false), + new AclResourceTypeTestInfo(ResourceType.DELEGATION_TOKEN, 6, "delegation_token", false) + }; + + @Test + public void testIsUnknown() { + for (AclResourceTypeTestInfo info : INFOS) { + assertEquals(info.unknown, info.resourceType.isUnknown(), + info.resourceType + " was supposed to have unknown == " + info.unknown); + } + } + + @Test + public void testCode() { + assertEquals(ResourceType.values().length, INFOS.length); + for (AclResourceTypeTestInfo info : INFOS) { + assertEquals(info.code, info.resourceType.code(), + info.resourceType + " was supposed to have code == " + info.code); + assertEquals(info.resourceType, ResourceType.fromCode((byte) info.code), "AclResourceType.fromCode(" + info.code + ") was supposed to be " + + info.resourceType); + } + assertEquals(ResourceType.UNKNOWN, ResourceType.fromCode((byte) 120)); + } + + @Test + public void testName() { + for (AclResourceTypeTestInfo info : INFOS) { + assertEquals(info.resourceType, ResourceType.fromString(info.name), "ResourceType.fromString(" + info.name + ") was supposed to be " + + info.resourceType); + } + assertEquals(ResourceType.UNKNOWN, ResourceType.fromString("something")); + } + + @Test + public void testExhaustive() { + assertEquals(INFOS.length, ResourceType.values().length); + for (int i = 0; i < INFOS.length; i++) { + assertEquals(INFOS[i].resourceType, ResourceType.values()[i]); + } + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/security/JaasContextTest.java b/clients/src/test/java/org/apache/kafka/common/security/JaasContextTest.java new file mode 100644 index 0000000..9e718d8 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/security/JaasContextTest.java @@ -0,0 +1,303 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.security; + +import java.io.File; +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import javax.security.auth.login.AppConfigurationEntry; +import javax.security.auth.login.AppConfigurationEntry.LoginModuleControlFlag; +import javax.security.auth.login.Configuration; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.fail; + +import org.apache.kafka.common.config.SaslConfigs; +import org.apache.kafka.common.config.types.Password; +import org.apache.kafka.common.network.ListenerName; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +/** + * Tests parsing of {@link SaslConfigs#SASL_JAAS_CONFIG} property and verifies that the format + * and parsing are consistent with JAAS configuration files loaded by the JRE. + */ +public class JaasContextTest { + + private File jaasConfigFile; + + @BeforeEach + public void setUp() throws IOException { + jaasConfigFile = File.createTempFile("jaas", ".conf"); + jaasConfigFile.deleteOnExit(); + System.setProperty(JaasUtils.JAVA_LOGIN_CONFIG_PARAM, jaasConfigFile.toString()); + Configuration.setConfiguration(null); + } + + @AfterEach + public void tearDown() throws Exception { + Files.delete(jaasConfigFile.toPath()); + } + + @Test + public void testConfigNoOptions() throws Exception { + checkConfiguration("test.testConfigNoOptions", LoginModuleControlFlag.REQUIRED, new HashMap()); + } + + @Test + public void testControlFlag() throws Exception { + LoginModuleControlFlag[] controlFlags = new LoginModuleControlFlag[] { + LoginModuleControlFlag.REQUIRED, + LoginModuleControlFlag.REQUISITE, + LoginModuleControlFlag.SUFFICIENT, + LoginModuleControlFlag.OPTIONAL + }; + Map options = new HashMap<>(); + options.put("propName", "propValue"); + for (LoginModuleControlFlag controlFlag : controlFlags) { + checkConfiguration("test.testControlFlag", controlFlag, options); + } + } + + @Test + public void testSingleOption() throws Exception { + Map options = new HashMap<>(); + options.put("propName", "propValue"); + checkConfiguration("test.testSingleOption", LoginModuleControlFlag.REQUISITE, options); + } + + @Test + public void testMultipleOptions() throws Exception { + Map options = new HashMap<>(); + for (int i = 0; i < 10; i++) + options.put("propName" + i, "propValue" + i); + checkConfiguration("test.testMultipleOptions", LoginModuleControlFlag.SUFFICIENT, options); + } + + @Test + public void testQuotedOptionValue() throws Exception { + Map options = new HashMap<>(); + options.put("propName", "prop value"); + options.put("propName2", "value1 = 1, value2 = 2"); + String config = String.format("test.testQuotedOptionValue required propName=\"%s\" propName2=\"%s\";", options.get("propName"), options.get("propName2")); + checkConfiguration(config, "test.testQuotedOptionValue", LoginModuleControlFlag.REQUIRED, options); + } + + @Test + public void testQuotedOptionName() throws Exception { + Map options = new HashMap<>(); + options.put("prop name", "propValue"); + String config = 
"test.testQuotedOptionName required \"prop name\"=propValue;"; + checkConfiguration(config, "test.testQuotedOptionName", LoginModuleControlFlag.REQUIRED, options); + } + + @Test + public void testMultipleLoginModules() throws Exception { + StringBuilder builder = new StringBuilder(); + int moduleCount = 3; + Map> moduleOptions = new HashMap<>(); + for (int i = 0; i < moduleCount; i++) { + Map options = new HashMap<>(); + options.put("index", "Index" + i); + options.put("module", "Module" + i); + moduleOptions.put(i, options); + String module = jaasConfigProp("test.Module" + i, LoginModuleControlFlag.REQUIRED, options); + builder.append(' '); + builder.append(module); + } + String jaasConfigProp = builder.toString(); + + String clientContextName = "CLIENT"; + Configuration configuration = new JaasConfig(clientContextName, jaasConfigProp); + AppConfigurationEntry[] dynamicEntries = configuration.getAppConfigurationEntry(clientContextName); + assertEquals(moduleCount, dynamicEntries.length); + + for (int i = 0; i < moduleCount; i++) { + AppConfigurationEntry entry = dynamicEntries[i]; + checkEntry(entry, "test.Module" + i, LoginModuleControlFlag.REQUIRED, moduleOptions.get(i)); + } + + String serverContextName = "SERVER"; + writeConfiguration(serverContextName, jaasConfigProp); + AppConfigurationEntry[] staticEntries = Configuration.getConfiguration().getAppConfigurationEntry(serverContextName); + for (int i = 0; i < moduleCount; i++) { + AppConfigurationEntry staticEntry = staticEntries[i]; + checkEntry(staticEntry, dynamicEntries[i].getLoginModuleName(), LoginModuleControlFlag.REQUIRED, dynamicEntries[i].getOptions()); + } + } + + @Test + public void testMissingLoginModule() throws Exception { + checkInvalidConfiguration(" required option1=value1;"); + } + + @Test + public void testMissingControlFlag() throws Exception { + checkInvalidConfiguration("test.loginModule option1=value1;"); + } + + @Test + public void testMissingOptionValue() throws Exception { + checkInvalidConfiguration("loginModule required option1;"); + } + + @Test + public void testMissingSemicolon() throws Exception { + checkInvalidConfiguration("test.testMissingSemicolon required option1=value1"); + } + + @Test + public void testNumericOptionWithoutQuotes() throws Exception { + checkInvalidConfiguration("test.testNumericOptionWithoutQuotes required option1=3;"); + } + + @Test + public void testInvalidControlFlag() throws Exception { + checkInvalidConfiguration("test.testInvalidControlFlag { option1=3;"); + } + + @Test + public void testNumericOptionWithQuotes() throws Exception { + Map options = new HashMap<>(); + options.put("option1", "3"); + String config = "test.testNumericOptionWithQuotes required option1=\"3\";"; + checkConfiguration(config, "test.testNumericOptionWithQuotes", LoginModuleControlFlag.REQUIRED, options); + } + + @Test + public void testLoadForServerWithListenerNameOverride() throws IOException { + writeConfiguration(Arrays.asList( + "KafkaServer { test.LoginModuleDefault required; };", + "plaintext.KafkaServer { test.LoginModuleOverride requisite; };" + )); + JaasContext context = JaasContext.loadServerContext(new ListenerName("plaintext"), + "SOME-MECHANISM", Collections.emptyMap()); + assertEquals("plaintext.KafkaServer", context.name()); + assertEquals(JaasContext.Type.SERVER, context.type()); + assertEquals(1, context.configurationEntries().size()); + checkEntry(context.configurationEntries().get(0), "test.LoginModuleOverride", + LoginModuleControlFlag.REQUISITE, Collections.emptyMap()); + } + + 
@Test + public void testLoadForServerWithListenerNameAndFallback() throws IOException { + writeConfiguration(Arrays.asList( + "KafkaServer { test.LoginModule required; };", + "other.KafkaServer { test.LoginModuleOther requisite; };" + )); + JaasContext context = JaasContext.loadServerContext(new ListenerName("plaintext"), + "SOME-MECHANISM", Collections.emptyMap()); + assertEquals("KafkaServer", context.name()); + assertEquals(JaasContext.Type.SERVER, context.type()); + assertEquals(1, context.configurationEntries().size()); + checkEntry(context.configurationEntries().get(0), "test.LoginModule", LoginModuleControlFlag.REQUIRED, + Collections.emptyMap()); + } + + @Test + public void testLoadForServerWithWrongListenerName() throws IOException { + writeConfiguration("Server", "test.LoginModule required;"); + assertThrows(IllegalArgumentException.class, () -> JaasContext.loadServerContext(new ListenerName("plaintext"), + "SOME-MECHANISM", Collections.emptyMap())); + } + + private AppConfigurationEntry configurationEntry(JaasContext.Type contextType, String jaasConfigProp) { + Password saslJaasConfig = jaasConfigProp == null ? null : new Password(jaasConfigProp); + JaasContext context = JaasContext.load(contextType, null, contextType.name(), saslJaasConfig); + List entries = context.configurationEntries(); + assertEquals(1, entries.size()); + return entries.get(0); + } + + private String controlFlag(LoginModuleControlFlag loginModuleControlFlag) { + // LoginModuleControlFlag.toString() has format "LoginModuleControlFlag: flag" + String[] tokens = loginModuleControlFlag.toString().split(" "); + return tokens[tokens.length - 1]; + } + + private String jaasConfigProp(String loginModule, LoginModuleControlFlag controlFlag, Map options) { + StringBuilder builder = new StringBuilder(); + builder.append(loginModule); + builder.append(' '); + builder.append(controlFlag(controlFlag)); + for (Map.Entry entry : options.entrySet()) { + builder.append(' '); + builder.append(entry.getKey()); + builder.append('='); + builder.append(entry.getValue()); + } + builder.append(';'); + return builder.toString(); + } + + private void writeConfiguration(String contextName, String jaasConfigProp) throws IOException { + List lines = Arrays.asList(contextName + " { ", jaasConfigProp, "};"); + writeConfiguration(lines); + } + + private void writeConfiguration(List lines) throws IOException { + Files.write(jaasConfigFile.toPath(), lines, StandardCharsets.UTF_8); + Configuration.setConfiguration(null); + } + + private void checkConfiguration(String loginModule, LoginModuleControlFlag controlFlag, Map options) throws Exception { + String jaasConfigProp = jaasConfigProp(loginModule, controlFlag, options); + checkConfiguration(jaasConfigProp, loginModule, controlFlag, options); + } + + private void checkEntry(AppConfigurationEntry entry, String loginModule, LoginModuleControlFlag controlFlag, Map options) { + assertEquals(loginModule, entry.getLoginModuleName()); + assertEquals(controlFlag, entry.getControlFlag()); + assertEquals(options, entry.getOptions()); + } + + private void checkConfiguration(String jaasConfigProp, String loginModule, LoginModuleControlFlag controlFlag, Map options) throws Exception { + AppConfigurationEntry dynamicEntry = configurationEntry(JaasContext.Type.CLIENT, jaasConfigProp); + checkEntry(dynamicEntry, loginModule, controlFlag, options); + assertNull(Configuration.getConfiguration().getAppConfigurationEntry(JaasContext.Type.CLIENT.name()), "Static configuration updated"); + + 
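+        // Now write the same value to the static JAAS configuration file and confirm it parses to an
+        // identical entry when loaded through the JRE's Configuration machinery.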
writeConfiguration(JaasContext.Type.SERVER.name(), jaasConfigProp); + AppConfigurationEntry staticEntry = configurationEntry(JaasContext.Type.SERVER, null); + checkEntry(staticEntry, loginModule, controlFlag, options); + } + + private void checkInvalidConfiguration(String jaasConfigProp) throws IOException { + try { + writeConfiguration(JaasContext.Type.SERVER.name(), jaasConfigProp); + AppConfigurationEntry entry = configurationEntry(JaasContext.Type.SERVER, null); + fail("Invalid JAAS configuration file didn't throw exception, entry=" + entry); + } catch (SecurityException e) { + // Expected exception + } + try { + AppConfigurationEntry entry = configurationEntry(JaasContext.Type.CLIENT, jaasConfigProp); + fail("Invalid JAAS configuration property didn't throw exception, entry=" + entry); + } catch (IllegalArgumentException e) { + // Expected exception + } + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/security/SaslExtensionsTest.java b/clients/src/test/java/org/apache/kafka/common/security/SaslExtensionsTest.java new file mode 100644 index 0000000..9acb78c --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/security/SaslExtensionsTest.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.security; + +import org.apache.kafka.common.security.auth.SaslExtensions; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.util.HashMap; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; + +public class SaslExtensionsTest { + Map map; + + @BeforeEach + public void setUp() { + this.map = new HashMap<>(); + this.map.put("what", "42"); + this.map.put("who", "me"); + } + + @Test + public void testReturnedMapIsImmutable() { + SaslExtensions extensions = new SaslExtensions(this.map); + assertThrows(UnsupportedOperationException.class, () -> extensions.map().put("hello", "test")); + } + + @Test + public void testCannotAddValueToMapReferenceAndGetFromExtensions() { + SaslExtensions extensions = new SaslExtensions(this.map); + + assertNull(extensions.map().get("hello")); + this.map.put("hello", "42"); + assertNull(extensions.map().get("hello")); + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/security/TestSecurityConfig.java b/clients/src/test/java/org/apache/kafka/common/security/TestSecurityConfig.java new file mode 100644 index 0000000..07cbb78 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/security/TestSecurityConfig.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.security; + +import org.apache.kafka.common.config.AbstractConfig; +import org.apache.kafka.common.config.ConfigDef; +import org.apache.kafka.common.config.ConfigDef.Importance; +import org.apache.kafka.common.config.ConfigDef.Type; +import org.apache.kafka.common.config.internals.BrokerSecurityConfigs; + +import java.util.Map; + +public class TestSecurityConfig extends AbstractConfig { + private static final ConfigDef CONFIG = new ConfigDef() + .define(BrokerSecurityConfigs.SSL_CLIENT_AUTH_CONFIG, Type.STRING, null, Importance.MEDIUM, + BrokerSecurityConfigs.SSL_CLIENT_AUTH_DOC) + .define(BrokerSecurityConfigs.SASL_ENABLED_MECHANISMS_CONFIG, Type.LIST, + BrokerSecurityConfigs.DEFAULT_SASL_ENABLED_MECHANISMS, + Importance.MEDIUM, BrokerSecurityConfigs.SASL_ENABLED_MECHANISMS_DOC) + .define(BrokerSecurityConfigs.SASL_SERVER_CALLBACK_HANDLER_CLASS, Type.CLASS, + null, + Importance.MEDIUM, BrokerSecurityConfigs.SASL_SERVER_CALLBACK_HANDLER_CLASS_DOC) + .define(BrokerSecurityConfigs.PRINCIPAL_BUILDER_CLASS_CONFIG, Type.CLASS, + null, Importance.MEDIUM, BrokerSecurityConfigs.PRINCIPAL_BUILDER_CLASS_DOC) + .define(BrokerSecurityConfigs.CONNECTIONS_MAX_REAUTH_MS, Type.LONG, 0L, Importance.MEDIUM, + BrokerSecurityConfigs.CONNECTIONS_MAX_REAUTH_MS_DOC) + .withClientSslSupport() + .withClientSaslSupport(); + + public TestSecurityConfig(Map originals) { + super(CONFIG, originals, false); + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/security/auth/DefaultKafkaPrincipalBuilderTest.java b/clients/src/test/java/org/apache/kafka/common/security/auth/DefaultKafkaPrincipalBuilderTest.java new file mode 100644 index 0000000..73a03ab --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/security/auth/DefaultKafkaPrincipalBuilderTest.java @@ -0,0 +1,194 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.security.auth; + +import javax.security.auth.x500.X500Principal; +import org.apache.kafka.common.config.SaslConfigs; +import org.apache.kafka.common.security.authenticator.DefaultKafkaPrincipalBuilder; +import org.apache.kafka.common.security.kerberos.KerberosShortNamer; +import org.apache.kafka.common.security.scram.internals.ScramMechanism; +import org.apache.kafka.common.security.ssl.SslPrincipalMapper; +import org.junit.jupiter.api.Test; + +import javax.net.ssl.SSLSession; +import javax.security.sasl.SaslServer; +import java.net.InetAddress; +import java.security.Principal; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.atLeastOnce; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +public class DefaultKafkaPrincipalBuilderTest { + + @Test + public void testReturnAnonymousPrincipalForPlaintext() throws Exception { + DefaultKafkaPrincipalBuilder builder = new DefaultKafkaPrincipalBuilder(null, null); + assertEquals(KafkaPrincipal.ANONYMOUS, builder.build( + new PlaintextAuthenticationContext(InetAddress.getLocalHost(), SecurityProtocol.PLAINTEXT.name()))); + } + + @Test + public void testUseSessionPeerPrincipalForSsl() throws Exception { + SSLSession session = mock(SSLSession.class); + + when(session.getPeerPrincipal()).thenReturn(new DummyPrincipal("foo")); + + DefaultKafkaPrincipalBuilder builder = new DefaultKafkaPrincipalBuilder(null, null); + + KafkaPrincipal principal = builder.build( + new SslAuthenticationContext(session, InetAddress.getLocalHost(), SecurityProtocol.PLAINTEXT.name())); + assertEquals(KafkaPrincipal.USER_TYPE, principal.getPrincipalType()); + assertEquals("foo", principal.getName()); + + verify(session, atLeastOnce()).getPeerPrincipal(); + } + + @Test + public void testPrincipalIfSSLPeerIsNotAuthenticated() throws Exception { + SSLSession session = mock(SSLSession.class); + + when(session.getPeerPrincipal()).thenReturn(KafkaPrincipal.ANONYMOUS); + + DefaultKafkaPrincipalBuilder builder = new DefaultKafkaPrincipalBuilder(null, null); + + KafkaPrincipal principal = builder.build( + new SslAuthenticationContext(session, InetAddress.getLocalHost(), SecurityProtocol.PLAINTEXT.name())); + assertEquals(KafkaPrincipal.ANONYMOUS, principal); + + verify(session, atLeastOnce()).getPeerPrincipal(); + } + + + @Test + public void testPrincipalWithSslPrincipalMapper() throws Exception { + SSLSession session = mock(SSLSession.class); + + when(session.getPeerPrincipal()).thenReturn(new X500Principal("CN=Duke, OU=ServiceUsers, O=Org, C=US")) + .thenReturn(new X500Principal("CN=Duke, OU=SME, O=mycp, L=Fulton, ST=MD, C=US")) + .thenReturn(new X500Principal("CN=duke, OU=JavaSoft, O=Sun Microsystems")) + .thenReturn(new X500Principal("OU=JavaSoft, O=Sun Microsystems, C=US")); + + String rules = String.join(", ", + "RULE:^CN=(.*),OU=ServiceUsers.*$/$1/L", + "RULE:^CN=(.*),OU=(.*),O=(.*),L=(.*),ST=(.*),C=(.*)$/$1@$2/L", + "RULE:^.*[Cc][Nn]=([a-zA-Z0-9.]*).*$/$1/U", + "DEFAULT" + ); + + SslPrincipalMapper mapper = SslPrincipalMapper.fromRules(rules); + DefaultKafkaPrincipalBuilder builder = new DefaultKafkaPrincipalBuilder(null, mapper); + + SslAuthenticationContext sslContext = new SslAuthenticationContext(session, InetAddress.getLocalHost(), SecurityProtocol.PLAINTEXT.name()); + + KafkaPrincipal principal = 
builder.build(sslContext); + assertEquals("duke", principal.getName()); + + principal = builder.build(sslContext); + assertEquals("duke@sme", principal.getName()); + + principal = builder.build(sslContext); + assertEquals("DUKE", principal.getName()); + + principal = builder.build(sslContext); + assertEquals("OU=JavaSoft,O=Sun Microsystems,C=US", principal.getName()); + + verify(session, times(4)).getPeerPrincipal(); + } + + @Test + public void testPrincipalBuilderScram() throws Exception { + SaslServer server = mock(SaslServer.class); + + when(server.getMechanismName()).thenReturn(ScramMechanism.SCRAM_SHA_256.mechanismName()); + when(server.getAuthorizationID()).thenReturn("foo"); + + DefaultKafkaPrincipalBuilder builder = new DefaultKafkaPrincipalBuilder(null, null); + + KafkaPrincipal principal = builder.build(new SaslAuthenticationContext(server, + SecurityProtocol.SASL_PLAINTEXT, InetAddress.getLocalHost(), SecurityProtocol.SASL_PLAINTEXT.name())); + assertEquals(KafkaPrincipal.USER_TYPE, principal.getPrincipalType()); + assertEquals("foo", principal.getName()); + + verify(server, atLeastOnce()).getMechanismName(); + verify(server, atLeastOnce()).getAuthorizationID(); + } + + @Test + public void testPrincipalBuilderGssapi() throws Exception { + SaslServer server = mock(SaslServer.class); + KerberosShortNamer kerberosShortNamer = mock(KerberosShortNamer.class); + + when(server.getMechanismName()).thenReturn(SaslConfigs.GSSAPI_MECHANISM); + when(server.getAuthorizationID()).thenReturn("foo/host@REALM.COM"); + when(kerberosShortNamer.shortName(any())).thenReturn("foo"); + + DefaultKafkaPrincipalBuilder builder = new DefaultKafkaPrincipalBuilder(kerberosShortNamer, null); + + KafkaPrincipal principal = builder.build(new SaslAuthenticationContext(server, + SecurityProtocol.SASL_PLAINTEXT, InetAddress.getLocalHost(), SecurityProtocol.SASL_PLAINTEXT.name())); + assertEquals(KafkaPrincipal.USER_TYPE, principal.getPrincipalType()); + assertEquals("foo", principal.getName()); + + verify(server, atLeastOnce()).getMechanismName(); + verify(server, atLeastOnce()).getAuthorizationID(); + verify(kerberosShortNamer, atLeastOnce()).shortName(any()); + } + + @Test + public void testPrincipalBuilderSerde() throws Exception { + SaslServer server = mock(SaslServer.class); + KerberosShortNamer kerberosShortNamer = mock(KerberosShortNamer.class); + + when(server.getMechanismName()).thenReturn(SaslConfigs.GSSAPI_MECHANISM); + when(server.getAuthorizationID()).thenReturn("foo/host@REALM.COM"); + when(kerberosShortNamer.shortName(any())).thenReturn("foo"); + + DefaultKafkaPrincipalBuilder builder = new DefaultKafkaPrincipalBuilder(kerberosShortNamer, null); + + KafkaPrincipal principal = builder.build(new SaslAuthenticationContext(server, + SecurityProtocol.SASL_PLAINTEXT, InetAddress.getLocalHost(), SecurityProtocol.SASL_PLAINTEXT.name())); + assertEquals(KafkaPrincipal.USER_TYPE, principal.getPrincipalType()); + assertEquals("foo", principal.getName()); + + byte[] serializedPrincipal = builder.serialize(principal); + KafkaPrincipal deserializedPrincipal = builder.deserialize(serializedPrincipal); + assertEquals(principal, deserializedPrincipal); + + verify(server, atLeastOnce()).getMechanismName(); + verify(server, atLeastOnce()).getAuthorizationID(); + verify(kerberosShortNamer, atLeastOnce()).shortName(any()); + } + + private static class DummyPrincipal implements Principal { + private final String name; + + private DummyPrincipal(String name) { + this.name = name; + } + + @Override + public String 
getName() { + return name; + } + } + +} diff --git a/clients/src/test/java/org/apache/kafka/common/security/auth/KafkaPrincipalTest.java b/clients/src/test/java/org/apache/kafka/common/security/auth/KafkaPrincipalTest.java new file mode 100644 index 0000000..639254a --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/security/auth/KafkaPrincipalTest.java @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.security.auth; + +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class KafkaPrincipalTest { + + @Test + public void testEqualsAndHashCode() { + String name = "KafkaUser"; + KafkaPrincipal principal1 = new KafkaPrincipal(KafkaPrincipal.USER_TYPE, name); + KafkaPrincipal principal2 = new KafkaPrincipal(KafkaPrincipal.USER_TYPE, name); + + assertEquals(principal1.hashCode(), principal2.hashCode()); + assertEquals(principal1, principal2); + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/security/authenticator/ClientAuthenticationFailureTest.java b/clients/src/test/java/org/apache/kafka/common/security/authenticator/ClientAuthenticationFailureTest.java new file mode 100644 index 0000000..0a0466f --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/security/authenticator/ClientAuthenticationFailureTest.java @@ -0,0 +1,145 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.security.authenticator; + +import org.apache.kafka.clients.CommonClientConfigs; +import org.apache.kafka.clients.admin.Admin; +import org.apache.kafka.clients.admin.TopicDescription; +import org.apache.kafka.clients.consumer.ConsumerConfig; +import org.apache.kafka.clients.consumer.KafkaConsumer; +import org.apache.kafka.clients.producer.KafkaProducer; +import org.apache.kafka.clients.producer.ProducerConfig; +import org.apache.kafka.clients.producer.ProducerRecord; +import org.apache.kafka.clients.producer.RecordMetadata; +import org.apache.kafka.common.KafkaFuture; +import org.apache.kafka.common.config.SaslConfigs; +import org.apache.kafka.common.config.internals.BrokerSecurityConfigs; +import org.apache.kafka.common.errors.SaslAuthenticationException; +import org.apache.kafka.common.network.ListenerName; +import org.apache.kafka.common.network.NetworkTestUtils; +import org.apache.kafka.common.network.NioEchoServer; +import org.apache.kafka.common.security.TestSecurityConfig; +import org.apache.kafka.common.security.auth.SecurityProtocol; +import org.apache.kafka.common.serialization.StringDeserializer; +import org.apache.kafka.common.serialization.StringSerializer; +import org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.test.TestUtils; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.time.Duration; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.Future; + +import static org.junit.jupiter.api.Assertions.assertThrows; + +public class ClientAuthenticationFailureTest { + private static MockTime time = new MockTime(50); + + private NioEchoServer server; + private Map saslServerConfigs; + private Map saslClientConfigs; + private final String topic = "test"; + private TestJaasConfig testJaasConfig; + + @BeforeEach + public void setup() throws Exception { + LoginManager.closeAll(); + SecurityProtocol securityProtocol = SecurityProtocol.SASL_PLAINTEXT; + + saslServerConfigs = new HashMap<>(); + saslServerConfigs.put(BrokerSecurityConfigs.SASL_ENABLED_MECHANISMS_CONFIG, Arrays.asList("PLAIN")); + + saslClientConfigs = new HashMap<>(); + saslClientConfigs.put(CommonClientConfigs.SECURITY_PROTOCOL_CONFIG, "SASL_PLAINTEXT"); + saslClientConfigs.put(SaslConfigs.SASL_MECHANISM, "PLAIN"); + + testJaasConfig = TestJaasConfig.createConfiguration("PLAIN", Arrays.asList("PLAIN")); + testJaasConfig.setClientOptions("PLAIN", TestJaasConfig.USERNAME, "anotherpassword"); + server = createEchoServer(securityProtocol); + } + + @AfterEach + public void teardown() throws Exception { + if (server != null) + server.close(); + } + + @Test + public void testConsumerWithInvalidCredentials() { + Map props = new HashMap<>(saslClientConfigs); + props.put(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:" + server.port()); + props.put(ConsumerConfig.GROUP_ID_CONFIG, ""); + StringDeserializer deserializer = new StringDeserializer(); + + try (KafkaConsumer consumer = new KafkaConsumer<>(props, deserializer, deserializer)) { + assertThrows(SaslAuthenticationException.class, () -> { + consumer.subscribe(Collections.singleton(topic)); + consumer.poll(Duration.ofSeconds(10)); + }); + } + } + + @Test + public void testProducerWithInvalidCredentials() { + Map props = new HashMap<>(saslClientConfigs); + props.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:" + server.port()); + StringSerializer 
serializer = new StringSerializer();
+
+        try (KafkaProducer<String, String> producer = new KafkaProducer<>(props, serializer, serializer)) {
+            ProducerRecord<String, String> record = new ProducerRecord<>(topic, "message");
+            Future<RecordMetadata> future = producer.send(record);
+            TestUtils.assertFutureThrows(future, SaslAuthenticationException.class);
+        }
+    }
+
+    @Test
+    public void testAdminClientWithInvalidCredentials() {
+        Map<String, Object> props = new HashMap<>(saslClientConfigs);
+        props.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:" + server.port());
+        try (Admin client = Admin.create(props)) {
+            KafkaFuture<Map<String, TopicDescription>> future = client.describeTopics(Collections.singleton("test")).allTopicNames();
+            TestUtils.assertFutureThrows(future, SaslAuthenticationException.class);
+        }
+    }
+
+    @Test
+    public void testTransactionalProducerWithInvalidCredentials() {
+        Map<String, Object> props = new HashMap<>(saslClientConfigs);
+        props.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:" + server.port());
+        props.put(ProducerConfig.TRANSACTIONAL_ID_CONFIG, "txclient-1");
+        props.put(ProducerConfig.ENABLE_IDEMPOTENCE_CONFIG, "true");
+        StringSerializer serializer = new StringSerializer();
+
+        try (KafkaProducer<String, String> producer = new KafkaProducer<>(props, serializer, serializer)) {
+            assertThrows(SaslAuthenticationException.class, producer::initTransactions);
+        }
+    }
+
+    private NioEchoServer createEchoServer(SecurityProtocol securityProtocol) throws Exception {
+        return createEchoServer(ListenerName.forSecurityProtocol(securityProtocol), securityProtocol);
+    }
+
+    private NioEchoServer createEchoServer(ListenerName listenerName, SecurityProtocol securityProtocol) throws Exception {
+        return NetworkTestUtils.createEchoServer(listenerName, securityProtocol,
+                new TestSecurityConfig(saslServerConfigs), new CredentialCache(), time);
+    }
+}
diff --git a/clients/src/test/java/org/apache/kafka/common/security/authenticator/LoginManagerTest.java b/clients/src/test/java/org/apache/kafka/common/security/authenticator/LoginManagerTest.java
new file mode 100644
index 0000000..92edfae
--- /dev/null
+++ b/clients/src/test/java/org/apache/kafka/common/security/authenticator/LoginManagerTest.java
@@ -0,0 +1,129 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +package org.apache.kafka.common.security.authenticator; + +import org.apache.kafka.common.config.types.Password; +import org.apache.kafka.common.network.ListenerName; +import org.apache.kafka.common.security.JaasContext; +import org.apache.kafka.common.security.plain.PlainLoginModule; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.assertNotSame; +import static org.junit.jupiter.api.Assertions.assertSame; +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class LoginManagerTest { + + private Password dynamicPlainContext; + private Password dynamicDigestContext; + + @BeforeEach + public void setUp() { + dynamicPlainContext = new Password(PlainLoginModule.class.getName() + + " required user=\"plainuser\" password=\"plain-secret\";"); + dynamicDigestContext = new Password(TestDigestLoginModule.class.getName() + + " required user=\"digestuser\" password=\"digest-secret\";"); + TestJaasConfig.createConfiguration("SCRAM-SHA-256", + Collections.singletonList("SCRAM-SHA-256")); + } + + @AfterEach + public void tearDown() { + LoginManager.closeAll(); + } + + @Test + public void testClientLoginManager() throws Exception { + Map configs = Collections.singletonMap("sasl.jaas.config", dynamicPlainContext); + JaasContext dynamicContext = JaasContext.loadClientContext(configs); + JaasContext staticContext = JaasContext.loadClientContext(Collections.emptyMap()); + + LoginManager dynamicLogin = LoginManager.acquireLoginManager(dynamicContext, "PLAIN", + DefaultLogin.class, configs); + assertEquals(dynamicPlainContext, dynamicLogin.cacheKey()); + LoginManager staticLogin = LoginManager.acquireLoginManager(staticContext, "SCRAM-SHA-256", + DefaultLogin.class, configs); + assertNotSame(dynamicLogin, staticLogin); + assertEquals("KafkaClient", staticLogin.cacheKey()); + + assertSame(dynamicLogin, LoginManager.acquireLoginManager(dynamicContext, "PLAIN", + DefaultLogin.class, configs)); + assertSame(staticLogin, LoginManager.acquireLoginManager(staticContext, "SCRAM-SHA-256", + DefaultLogin.class, configs)); + + verifyLoginManagerRelease(dynamicLogin, 2, dynamicContext, configs); + verifyLoginManagerRelease(staticLogin, 2, staticContext, configs); + } + + @Test + public void testServerLoginManager() throws Exception { + Map configs = new HashMap<>(); + configs.put("plain.sasl.jaas.config", dynamicPlainContext); + configs.put("digest-md5.sasl.jaas.config", dynamicDigestContext); + ListenerName listenerName = new ListenerName("listener1"); + JaasContext plainJaasContext = JaasContext.loadServerContext(listenerName, "PLAIN", configs); + JaasContext digestJaasContext = JaasContext.loadServerContext(listenerName, "DIGEST-MD5", configs); + JaasContext scramJaasContext = JaasContext.loadServerContext(listenerName, "SCRAM-SHA-256", configs); + + LoginManager dynamicPlainLogin = LoginManager.acquireLoginManager(plainJaasContext, "PLAIN", + DefaultLogin.class, configs); + assertEquals(dynamicPlainContext, dynamicPlainLogin.cacheKey()); + LoginManager dynamicDigestLogin = LoginManager.acquireLoginManager(digestJaasContext, "DIGEST-MD5", + DefaultLogin.class, configs); + assertNotSame(dynamicPlainLogin, dynamicDigestLogin); + assertEquals(dynamicDigestContext, dynamicDigestLogin.cacheKey()); + LoginManager staticScramLogin = LoginManager.acquireLoginManager(scramJaasContext, "SCRAM-SHA-256", + 
DefaultLogin.class, configs); + assertNotSame(dynamicPlainLogin, staticScramLogin); + assertEquals("KafkaServer", staticScramLogin.cacheKey()); + + assertSame(dynamicPlainLogin, LoginManager.acquireLoginManager(plainJaasContext, "PLAIN", + DefaultLogin.class, configs)); + assertSame(dynamicDigestLogin, LoginManager.acquireLoginManager(digestJaasContext, "DIGEST-MD5", + DefaultLogin.class, configs)); + assertSame(staticScramLogin, LoginManager.acquireLoginManager(scramJaasContext, "SCRAM-SHA-256", + DefaultLogin.class, configs)); + + verifyLoginManagerRelease(dynamicPlainLogin, 2, plainJaasContext, configs); + verifyLoginManagerRelease(dynamicDigestLogin, 2, digestJaasContext, configs); + verifyLoginManagerRelease(staticScramLogin, 2, scramJaasContext, configs); + } + + private void verifyLoginManagerRelease(LoginManager loginManager, int acquireCount, JaasContext jaasContext, + Map configs) throws Exception { + + // Release all except one reference and verify that the loginManager is still cached + for (int i = 0; i < acquireCount - 1; i++) + loginManager.release(); + assertSame(loginManager, LoginManager.acquireLoginManager(jaasContext, "PLAIN", + DefaultLogin.class, configs)); + + // Release all references and verify that new LoginManager is created on next acquire + for (int i = 0; i < 2; i++) // release all references + loginManager.release(); + LoginManager newLoginManager = LoginManager.acquireLoginManager(jaasContext, "PLAIN", + DefaultLogin.class, configs); + assertNotSame(loginManager, newLoginManager); + newLoginManager.release(); + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/security/authenticator/SaslAuthenticatorFailureDelayTest.java b/clients/src/test/java/org/apache/kafka/common/security/authenticator/SaslAuthenticatorFailureDelayTest.java new file mode 100644 index 0000000..db6ba89 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/security/authenticator/SaslAuthenticatorFailureDelayTest.java @@ -0,0 +1,245 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.security.authenticator; + +import org.apache.kafka.common.config.SaslConfigs; +import org.apache.kafka.common.config.internals.BrokerSecurityConfigs; +import org.apache.kafka.common.errors.SaslAuthenticationException; +import org.apache.kafka.common.network.CertStores; +import org.apache.kafka.common.network.ChannelBuilder; +import org.apache.kafka.common.network.ChannelBuilders; +import org.apache.kafka.common.network.ChannelState; +import org.apache.kafka.common.network.ListenerName; +import org.apache.kafka.common.network.NetworkTestUtils; +import org.apache.kafka.common.network.NioEchoServer; +import org.apache.kafka.common.network.Selector; +import org.apache.kafka.common.security.JaasContext; +import org.apache.kafka.common.security.TestSecurityConfig; +import org.apache.kafka.common.security.auth.SecurityProtocol; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.test.TestUtils; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.io.IOException; +import java.net.InetSocketAddress; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public abstract class SaslAuthenticatorFailureDelayTest { + private static final int BUFFER_SIZE = 4 * 1024; + + private final MockTime time = new MockTime(1); + private NioEchoServer server; + private Selector selector; + private ChannelBuilder channelBuilder; + private CertStores serverCertStores; + private CertStores clientCertStores; + private Map saslClientConfigs; + private Map saslServerConfigs; + private CredentialCache credentialCache; + private long startTimeMs; + private final int failedAuthenticationDelayMs; + + public SaslAuthenticatorFailureDelayTest(int failedAuthenticationDelayMs) { + this.failedAuthenticationDelayMs = failedAuthenticationDelayMs; + } + + @BeforeEach + public void setup() throws Exception { + LoginManager.closeAll(); + serverCertStores = new CertStores(true, "localhost"); + clientCertStores = new CertStores(false, "localhost"); + saslServerConfigs = serverCertStores.getTrustingConfig(clientCertStores); + saslClientConfigs = clientCertStores.getTrustingConfig(serverCertStores); + credentialCache = new CredentialCache(); + SaslAuthenticatorTest.TestLogin.loginCount.set(0); + startTimeMs = time.milliseconds(); + } + + @AfterEach + public void teardown() throws Exception { + long now = time.milliseconds(); + if (server != null) + this.server.close(); + if (selector != null) + this.selector.close(); + assertTrue(now - startTimeMs >= failedAuthenticationDelayMs, "timeSpent: " + (now - startTimeMs)); + } + + /** + * Tests that SASL/PLAIN clients with invalid password fail authentication. 
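+     * The client-side setup exercised here is equivalent to configuring (values illustrative):
+     *   sasl.mechanism=PLAIN
+     *   sasl.jaas.config=org.apache.kafka.common.security.plain.PlainLoginModule required \
+     *       username="myuser" password="invalidpassword";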
+ */ + @Test + public void testInvalidPasswordSaslPlain() throws Exception { + String node = "0"; + SecurityProtocol securityProtocol = SecurityProtocol.SASL_SSL; + TestJaasConfig jaasConfig = configureMechanisms("PLAIN", Arrays.asList("PLAIN")); + jaasConfig.setClientOptions("PLAIN", TestJaasConfig.USERNAME, "invalidpassword"); + + server = createEchoServer(securityProtocol); + createAndCheckClientAuthenticationFailure(securityProtocol, node, "PLAIN", + "Authentication failed: Invalid username or password"); + server.verifyAuthenticationMetrics(0, 1); + } + + /** + * Tests that SASL/SCRAM clients with invalid password fail authentication with + * connection close delay if configured. + */ + @Test + public void testInvalidPasswordSaslScram() throws Exception { + String node = "0"; + SecurityProtocol securityProtocol = SecurityProtocol.SASL_SSL; + TestJaasConfig jaasConfig = configureMechanisms("SCRAM-SHA-256", Collections.singletonList("SCRAM-SHA-256")); + jaasConfig.setClientOptions("SCRAM-SHA-256", TestJaasConfig.USERNAME, "invalidpassword"); + + server = createEchoServer(securityProtocol); + createAndCheckClientAuthenticationFailure(securityProtocol, node, "SCRAM-SHA-256", null); + server.verifyAuthenticationMetrics(0, 1); + } + + /** + * Tests that clients with disabled SASL mechanism fail authentication with + * connection close delay if configured. + */ + @Test + public void testDisabledSaslMechanism() throws Exception { + String node = "0"; + SecurityProtocol securityProtocol = SecurityProtocol.SASL_SSL; + TestJaasConfig jaasConfig = configureMechanisms("SCRAM-SHA-256", Collections.singletonList("SCRAM-SHA-256")); + jaasConfig.setClientOptions("PLAIN", TestJaasConfig.USERNAME, "invalidpassword"); + + server = createEchoServer(securityProtocol); + createAndCheckClientAuthenticationFailure(securityProtocol, node, "SCRAM-SHA-256", null); + server.verifyAuthenticationMetrics(0, 1); + } + + /** + * Tests client connection close before response for authentication failure is sent. 
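+     * The delay exercised here corresponds to the broker setting (value illustrative):
+     *   connection.failed.authentication.delay.ms=200
+     * which postpones closing a connection after a failed authentication attempt.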
+ */ + @Test + public void testClientConnectionClose() throws Exception { + String node = "0"; + SecurityProtocol securityProtocol = SecurityProtocol.SASL_SSL; + TestJaasConfig jaasConfig = configureMechanisms("PLAIN", Arrays.asList("PLAIN")); + jaasConfig.setClientOptions("PLAIN", TestJaasConfig.USERNAME, "invalidpassword"); + + server = createEchoServer(securityProtocol); + createClientConnection(securityProtocol, node); + + Map delayedClosingChannels = NetworkTestUtils.delayedClosingChannels(server.selector()); + + // Wait until server has established connection with client and has processed the auth failure + TestUtils.waitForCondition(() -> { + poll(selector); + return !server.selector().channels().isEmpty(); + }, "Timeout waiting for connection"); + TestUtils.waitForCondition(() -> { + poll(selector); + return failedAuthenticationDelayMs == 0 || !delayedClosingChannels.isEmpty(); + }, "Timeout waiting for auth failure"); + + selector.close(); + selector = null; + + // Now that client connection is closed, wait until server notices the disconnection and removes it from the + // list of connected channels and from delayed response for auth failure + TestUtils.waitForCondition(() -> failedAuthenticationDelayMs == 0 || delayedClosingChannels.isEmpty(), + "Timeout waiting for delayed response remove"); + TestUtils.waitForCondition(() -> server.selector().channels().isEmpty(), + "Timeout waiting for connection close"); + + // Try forcing completion of delayed channel close + TestUtils.waitForCondition(() -> time.milliseconds() > startTimeMs + failedAuthenticationDelayMs + 1, + "Timeout when waiting for auth failure response timeout to elapse"); + NetworkTestUtils.completeDelayedChannelClose(server.selector(), time.nanoseconds()); + } + + private void poll(Selector selector) { + try { + selector.poll(50); + } catch (IOException e) { + throw new RuntimeException("Unexpected failure during selector poll", e); + } + } + + private TestJaasConfig configureMechanisms(String clientMechanism, List serverMechanisms) { + saslClientConfigs.put(SaslConfigs.SASL_MECHANISM, clientMechanism); + saslServerConfigs.put(BrokerSecurityConfigs.SASL_ENABLED_MECHANISMS_CONFIG, serverMechanisms); + if (serverMechanisms.contains("DIGEST-MD5")) { + saslServerConfigs.put("digest-md5." 
+ BrokerSecurityConfigs.SASL_SERVER_CALLBACK_HANDLER_CLASS, + TestDigestLoginModule.DigestServerCallbackHandler.class.getName()); + } + return TestJaasConfig.createConfiguration(clientMechanism, serverMechanisms); + } + + private void createSelector(SecurityProtocol securityProtocol, Map clientConfigs) { + if (selector != null) { + selector.close(); + selector = null; + } + + String saslMechanism = (String) saslClientConfigs.get(SaslConfigs.SASL_MECHANISM); + this.channelBuilder = ChannelBuilders.clientChannelBuilder(securityProtocol, JaasContext.Type.CLIENT, + new TestSecurityConfig(clientConfigs), null, saslMechanism, time, true, + new LogContext()); + this.selector = NetworkTestUtils.createSelector(channelBuilder, time); + } + + private NioEchoServer createEchoServer(SecurityProtocol securityProtocol) throws Exception { + return createEchoServer(ListenerName.forSecurityProtocol(securityProtocol), securityProtocol); + } + + private NioEchoServer createEchoServer(ListenerName listenerName, SecurityProtocol securityProtocol) throws Exception { + return NetworkTestUtils.createEchoServer(listenerName, securityProtocol, + new TestSecurityConfig(saslServerConfigs), credentialCache, time); + } + + private void createClientConnection(SecurityProtocol securityProtocol, String node) throws Exception { + createSelector(securityProtocol, saslClientConfigs); + InetSocketAddress addr = new InetSocketAddress("localhost", server.port()); + selector.connect(node, addr, BUFFER_SIZE, BUFFER_SIZE); + } + + private void createAndCheckClientAuthenticationFailure(SecurityProtocol securityProtocol, String node, + String mechanism, String expectedErrorMessage) throws Exception { + ChannelState finalState = createAndCheckClientConnectionFailure(securityProtocol, node); + Exception exception = finalState.exception(); + assertTrue(exception instanceof SaslAuthenticationException, "Invalid exception class " + exception.getClass()); + if (expectedErrorMessage == null) + expectedErrorMessage = "Authentication failed during authentication due to invalid credentials with SASL mechanism " + mechanism; + assertEquals(expectedErrorMessage, exception.getMessage()); + } + + private ChannelState createAndCheckClientConnectionFailure(SecurityProtocol securityProtocol, String node) + throws Exception { + createClientConnection(securityProtocol, node); + ChannelState finalState = NetworkTestUtils.waitForChannelClose(selector, node, + ChannelState.State.AUTHENTICATION_FAILED); + selector.close(); + selector = null; + return finalState; + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/security/authenticator/SaslAuthenticatorFailureNoDelayTest.java b/clients/src/test/java/org/apache/kafka/common/security/authenticator/SaslAuthenticatorFailureNoDelayTest.java new file mode 100644 index 0000000..bac154e --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/security/authenticator/SaslAuthenticatorFailureNoDelayTest.java @@ -0,0 +1,23 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.security.authenticator; + +public class SaslAuthenticatorFailureNoDelayTest extends SaslAuthenticatorFailureDelayTest { + public SaslAuthenticatorFailureNoDelayTest() { + super(0); + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/security/authenticator/SaslAuthenticatorFailurePositiveDelayTest.java b/clients/src/test/java/org/apache/kafka/common/security/authenticator/SaslAuthenticatorFailurePositiveDelayTest.java new file mode 100644 index 0000000..1a803b6 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/security/authenticator/SaslAuthenticatorFailurePositiveDelayTest.java @@ -0,0 +1,23 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.security.authenticator; + +public class SaslAuthenticatorFailurePositiveDelayTest extends SaslAuthenticatorFailureDelayTest { + public SaslAuthenticatorFailurePositiveDelayTest() { + super(200); + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/security/authenticator/SaslAuthenticatorTest.java b/clients/src/test/java/org/apache/kafka/common/security/authenticator/SaslAuthenticatorTest.java new file mode 100644 index 0000000..988a0f2 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/security/authenticator/SaslAuthenticatorTest.java @@ -0,0 +1,2577 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.security.authenticator; + +import org.apache.kafka.clients.NetworkClient; +import org.apache.kafka.common.KafkaException; +import org.apache.kafka.common.config.SaslConfigs; +import org.apache.kafka.common.config.SslClientAuth; +import org.apache.kafka.common.config.SslConfigs; +import org.apache.kafka.common.config.internals.BrokerSecurityConfigs; +import org.apache.kafka.common.config.types.Password; +import org.apache.kafka.common.errors.SaslAuthenticationException; +import org.apache.kafka.common.errors.SslAuthenticationException; +import org.apache.kafka.common.message.ApiMessageType; +import org.apache.kafka.common.message.ApiVersionsRequestData; +import org.apache.kafka.common.message.ApiVersionsResponseData; +import org.apache.kafka.common.message.ApiVersionsResponseData.ApiVersion; +import org.apache.kafka.common.message.ApiVersionsResponseData.ApiVersionCollection; +import org.apache.kafka.common.message.ListOffsetsResponseData; +import org.apache.kafka.common.message.ListOffsetsResponseData.ListOffsetsPartitionResponse; +import org.apache.kafka.common.message.ListOffsetsResponseData.ListOffsetsTopicResponse; +import org.apache.kafka.common.message.RequestHeaderData; +import org.apache.kafka.common.message.SaslAuthenticateRequestData; +import org.apache.kafka.common.message.SaslHandshakeRequestData; +import org.apache.kafka.common.network.ByteBufferSend; +import org.apache.kafka.common.network.CertStores; +import org.apache.kafka.common.network.ChannelBuilder; +import org.apache.kafka.common.network.ChannelBuilders; +import org.apache.kafka.common.network.ChannelMetadataRegistry; +import org.apache.kafka.common.network.ChannelState; +import org.apache.kafka.common.network.ListenerName; +import org.apache.kafka.common.network.Mode; +import org.apache.kafka.common.network.NetworkSend; +import org.apache.kafka.common.network.NetworkTestUtils; +import org.apache.kafka.common.network.NioEchoServer; +import org.apache.kafka.common.network.SaslChannelBuilder; +import org.apache.kafka.common.network.Selector; +import org.apache.kafka.common.network.TransportLayer; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.Errors; +import org.apache.kafka.common.protocol.types.SchemaException; +import org.apache.kafka.common.requests.AbstractRequest; +import org.apache.kafka.common.requests.AbstractResponse; +import org.apache.kafka.common.requests.ApiVersionsRequest; +import org.apache.kafka.common.requests.ApiVersionsResponse; +import org.apache.kafka.common.requests.ListOffsetsResponse; +import org.apache.kafka.common.requests.MetadataRequest; +import org.apache.kafka.common.requests.RequestHeader; +import org.apache.kafka.common.requests.RequestTestUtils; +import org.apache.kafka.common.requests.ResponseHeader; +import org.apache.kafka.common.requests.SaslAuthenticateRequest; +import org.apache.kafka.common.requests.SaslHandshakeRequest; +import org.apache.kafka.common.requests.SaslHandshakeResponse; +import org.apache.kafka.common.security.JaasContext; +import org.apache.kafka.common.security.TestSecurityConfig; +import org.apache.kafka.common.security.auth.AuthenticateCallbackHandler; +import org.apache.kafka.common.security.auth.AuthenticationContext; +import org.apache.kafka.common.security.auth.KafkaPrincipal; +import org.apache.kafka.common.security.auth.KafkaPrincipalBuilder; +import org.apache.kafka.common.security.auth.Login; +import org.apache.kafka.common.security.auth.SaslAuthenticationContext; +import 
org.apache.kafka.common.security.auth.SecurityProtocol; +import org.apache.kafka.common.security.authenticator.TestDigestLoginModule.DigestServerCallbackHandler; +import org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginModule; +import org.apache.kafka.common.security.oauthbearer.OAuthBearerToken; +import org.apache.kafka.common.security.oauthbearer.OAuthBearerTokenCallback; +import org.apache.kafka.common.security.oauthbearer.internals.unsecured.OAuthBearerConfigException; +import org.apache.kafka.common.security.oauthbearer.internals.unsecured.OAuthBearerIllegalTokenException; +import org.apache.kafka.common.security.oauthbearer.internals.unsecured.OAuthBearerUnsecuredJws; +import org.apache.kafka.common.security.oauthbearer.internals.unsecured.OAuthBearerUnsecuredLoginCallbackHandler; +import org.apache.kafka.common.security.plain.PlainLoginModule; +import org.apache.kafka.common.security.plain.internals.PlainServerCallbackHandler; +import org.apache.kafka.common.security.scram.ScramCredential; +import org.apache.kafka.common.security.scram.ScramLoginModule; +import org.apache.kafka.common.security.scram.internals.ScramCredentialUtils; +import org.apache.kafka.common.security.scram.internals.ScramFormatter; +import org.apache.kafka.common.security.scram.internals.ScramMechanism; +import org.apache.kafka.common.security.token.delegation.TokenInformation; +import org.apache.kafka.common.security.token.delegation.internals.DelegationTokenCache; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.common.utils.SecurityUtils; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.test.TestUtils; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.opentest4j.AssertionFailedError; + +import javax.net.ssl.SSLPeerUnverifiedException; +import javax.security.auth.Subject; +import javax.security.auth.callback.Callback; +import javax.security.auth.callback.CallbackHandler; +import javax.security.auth.callback.NameCallback; +import javax.security.auth.callback.PasswordCallback; +import javax.security.auth.callback.UnsupportedCallbackException; +import javax.security.auth.login.AppConfigurationEntry; +import javax.security.auth.login.Configuration; +import javax.security.auth.login.LoginContext; +import javax.security.auth.login.LoginException; +import javax.security.sasl.SaslClient; +import javax.security.sasl.SaslException; +import java.io.IOException; +import java.net.InetSocketAddress; +import java.nio.ByteBuffer; +import java.nio.channels.SelectionKey; +import java.nio.charset.StandardCharsets; +import java.security.NoSuchAlgorithmException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Base64; +import java.util.Base64.Encoder; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Random; +import java.util.Set; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Function; +import java.util.function.Supplier; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +import static org.apache.kafka.common.protocol.ApiKeys.LIST_OFFSETS; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static 
org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; + +/** + * Tests for the Sasl authenticator. These use a test harness that runs a simple socket server that echos back responses. + */ +public class SaslAuthenticatorTest { + + private static final long CONNECTIONS_MAX_REAUTH_MS_VALUE = 100L; + private static final int BUFFER_SIZE = 4 * 1024; + private static Time time = Time.SYSTEM; + + private NioEchoServer server; + private Selector selector; + private ChannelBuilder channelBuilder; + private CertStores serverCertStores; + private CertStores clientCertStores; + private Map saslClientConfigs; + private Map saslServerConfigs; + private CredentialCache credentialCache; + private int nextCorrelationId; + + @BeforeEach + public void setup() throws Exception { + LoginManager.closeAll(); + time = Time.SYSTEM; + serverCertStores = new CertStores(true, "localhost"); + clientCertStores = new CertStores(false, "localhost"); + saslServerConfigs = serverCertStores.getTrustingConfig(clientCertStores); + saslClientConfigs = clientCertStores.getTrustingConfig(serverCertStores); + credentialCache = new CredentialCache(); + TestLogin.loginCount.set(0); + } + + @AfterEach + public void teardown() throws Exception { + if (server != null) + this.server.close(); + if (selector != null) + this.selector.close(); + } + + /** + * Tests good path SASL/PLAIN client and server channels using SSL transport layer. + * Also tests successful re-authentication. + */ + @Test + public void testValidSaslPlainOverSsl() throws Exception { + String node = "0"; + SecurityProtocol securityProtocol = SecurityProtocol.SASL_SSL; + configureMechanisms("PLAIN", Arrays.asList("PLAIN")); + + server = createEchoServer(securityProtocol); + checkAuthenticationAndReauthentication(securityProtocol, node); + } + + /** + * Tests good path SASL/PLAIN client and server channels using PLAINTEXT transport layer. + * Also tests successful re-authentication. + */ + @Test + public void testValidSaslPlainOverPlaintext() throws Exception { + String node = "0"; + SecurityProtocol securityProtocol = SecurityProtocol.SASL_PLAINTEXT; + configureMechanisms("PLAIN", Arrays.asList("PLAIN")); + + server = createEchoServer(securityProtocol); + checkAuthenticationAndReauthentication(securityProtocol, node); + } + + /** + * Tests that SASL/PLAIN clients with invalid password fail authentication. + */ + @Test + public void testInvalidPasswordSaslPlain() throws Exception { + String node = "0"; + SecurityProtocol securityProtocol = SecurityProtocol.SASL_SSL; + TestJaasConfig jaasConfig = configureMechanisms("PLAIN", Arrays.asList("PLAIN")); + jaasConfig.setClientOptions("PLAIN", TestJaasConfig.USERNAME, "invalidpassword"); + + server = createEchoServer(securityProtocol); + createAndCheckClientAuthenticationFailure(securityProtocol, node, "PLAIN", + "Authentication failed: Invalid username or password"); + server.verifyAuthenticationMetrics(0, 1); + server.verifyReauthenticationMetrics(0, 0); + } + + /** + * Tests that SASL/PLAIN clients with invalid username fail authentication. 
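+     * With SASL/PLAIN, valid users are listed in the broker's JAAS entry as "user_<name>"
+     * options, so a username missing from that list fails authentication, e.g. (illustrative):
+     *   org.apache.kafka.common.security.plain.PlainLoginModule required
+     *       username="admin" password="admin-secret"
+     *       user_admin="admin-secret" user_myuser="mypassword";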
+ */ + @Test + public void testInvalidUsernameSaslPlain() throws Exception { + String node = "0"; + SecurityProtocol securityProtocol = SecurityProtocol.SASL_SSL; + TestJaasConfig jaasConfig = configureMechanisms("PLAIN", Arrays.asList("PLAIN")); + jaasConfig.setClientOptions("PLAIN", "invaliduser", TestJaasConfig.PASSWORD); + + server = createEchoServer(securityProtocol); + createAndCheckClientAuthenticationFailure(securityProtocol, node, "PLAIN", + "Authentication failed: Invalid username or password"); + server.verifyAuthenticationMetrics(0, 1); + server.verifyReauthenticationMetrics(0, 0); + } + + /** + * Tests that SASL/PLAIN clients without valid username fail authentication. + */ + @Test + public void testMissingUsernameSaslPlain() throws Exception { + String node = "0"; + TestJaasConfig jaasConfig = configureMechanisms("PLAIN", Arrays.asList("PLAIN")); + jaasConfig.setClientOptions("PLAIN", null, "mypassword"); + + SecurityProtocol securityProtocol = SecurityProtocol.SASL_SSL; + server = createEchoServer(securityProtocol); + createSelector(securityProtocol, saslClientConfigs); + InetSocketAddress addr = new InetSocketAddress("localhost", server.port()); + try { + selector.connect(node, addr, BUFFER_SIZE, BUFFER_SIZE); + fail("SASL/PLAIN channel created without username"); + } catch (IOException e) { + // Expected exception + assertTrue(selector.channels().isEmpty(), "Channels not closed"); + for (SelectionKey key : selector.keys()) + assertFalse(key.isValid(), "Key not cancelled"); + } + } + + /** + * Tests that SASL/PLAIN clients with missing password in JAAS configuration fail authentication. + */ + @Test + public void testMissingPasswordSaslPlain() throws Exception { + String node = "0"; + TestJaasConfig jaasConfig = configureMechanisms("PLAIN", Arrays.asList("PLAIN")); + jaasConfig.setClientOptions("PLAIN", "myuser", null); + + SecurityProtocol securityProtocol = SecurityProtocol.SASL_SSL; + server = createEchoServer(securityProtocol); + createSelector(securityProtocol, saslClientConfigs); + InetSocketAddress addr = new InetSocketAddress("localhost", server.port()); + try { + selector.connect(node, addr, BUFFER_SIZE, BUFFER_SIZE); + fail("SASL/PLAIN channel created without password"); + } catch (IOException e) { + // Expected exception + } + } + + /** + * Verify that messages from SaslExceptions thrown in the server during authentication are not + * propagated to the client since these may contain sensitive data. 
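+     * A server callback handler that wants its message to reach the client should throw
+     * SaslAuthenticationException, for example:
+     *   throw new SaslAuthenticationException("Credential verification failed");
+     * whereas messages from IOException/SaslException are replaced with a generic error.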
+     */
+    @Test
+    public void testClientExceptionDoesNotContainSensitiveData() throws Exception {
+        InvalidScramServerCallbackHandler.reset();
+
+        SecurityProtocol securityProtocol = SecurityProtocol.SASL_PLAINTEXT;
+        TestJaasConfig jaasConfig = configureMechanisms("SCRAM-SHA-256", Collections.singletonList("SCRAM-SHA-256"));
+        jaasConfig.createOrUpdateEntry(TestJaasConfig.LOGIN_CONTEXT_SERVER, PlainLoginModule.class.getName(), new HashMap<>());
+        String callbackPrefix = ListenerName.forSecurityProtocol(securityProtocol).saslMechanismConfigPrefix("SCRAM-SHA-256");
+        saslServerConfigs.put(callbackPrefix + BrokerSecurityConfigs.SASL_SERVER_CALLBACK_HANDLER_CLASS,
+                InvalidScramServerCallbackHandler.class.getName());
+        server = createEchoServer(securityProtocol);
+
+        try {
+            InvalidScramServerCallbackHandler.sensitiveException =
+                    new IOException("Could not connect to password database locahost:8000");
+            createAndCheckClientAuthenticationFailure(securityProtocol, "1", "SCRAM-SHA-256", null);
+
+            InvalidScramServerCallbackHandler.sensitiveException =
+                    new SaslException("Password for existing user " + TestServerCallbackHandler.USERNAME + " is invalid");
+            createAndCheckClientAuthenticationFailure(securityProtocol, "1", "SCRAM-SHA-256", null);
+
+            InvalidScramServerCallbackHandler.reset();
+            InvalidScramServerCallbackHandler.clientFriendlyException =
+                    new SaslAuthenticationException("Credential verification failed");
+            createAndCheckClientAuthenticationFailure(securityProtocol, "1", "SCRAM-SHA-256",
+                    InvalidScramServerCallbackHandler.clientFriendlyException.getMessage());
+        } finally {
+            InvalidScramServerCallbackHandler.reset();
+        }
+    }
+
+    public static class InvalidScramServerCallbackHandler implements AuthenticateCallbackHandler {
+        // We want to test three types of exceptions:
+        //   1) IOException since we can throw this from callback handlers. This may be sensitive.
+        //   2) SaslException (also an IOException) which may contain data from external (or JRE) servers and callbacks and may be sensitive
+        //   3) SaslAuthenticationException which is from our own code and is used only for client-friendly exceptions
+        // We use two different exceptions here since the only checked exception CallbackHandler can throw is IOException,
+        // covering cases 1) and 2). For case 3), SaslAuthenticationException is a RuntimeException.
+        static volatile IOException sensitiveException;
+        static volatile SaslAuthenticationException clientFriendlyException;
+
+        @Override
+        public void configure(Map<String, ?> configs, String saslMechanism, List<AppConfigurationEntry> jaasConfigEntries) {
+        }
+
+        @Override
+        public void handle(Callback[] callbacks) throws IOException {
+            if (sensitiveException != null)
+                throw sensitiveException;
+            if (clientFriendlyException != null)
+                throw clientFriendlyException;
+        }
+
+        @Override
+        public void close() {
+            reset();
+        }
+
+        static void reset() {
+            sensitiveException = null;
+            clientFriendlyException = null;
+        }
+    }
+
+    /**
+     * Tests that mechanisms that are not supported in Kafka can be plugged in without modifying
+     * Kafka code if Sasl client and server providers are available.
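+     * DIGEST-MD5 is used as the example mechanism here: the JDK already provides the
+     * SaslClient/SaslServer factories, so only configuration is added, e.g. (illustrative,
+     * mirroring the callback handler configuration used in these tests):
+     *   digest-md5.sasl.server.callback.handler.class=<server callback handler class>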
+ */ + @Test + public void testMechanismPluggability() throws Exception { + String node = "0"; + SecurityProtocol securityProtocol = SecurityProtocol.SASL_SSL; + configureMechanisms("DIGEST-MD5", Arrays.asList("DIGEST-MD5")); + configureDigestMd5ServerCallback(securityProtocol); + + server = createEchoServer(securityProtocol); + createAndCheckClientConnection(securityProtocol, node); + } + + /** + * Tests that servers supporting multiple SASL mechanisms work with clients using + * any of the enabled mechanisms. + * Also tests successful re-authentication over multiple mechanisms. + */ + @Test + public void testMultipleServerMechanisms() throws Exception { + SecurityProtocol securityProtocol = SecurityProtocol.SASL_SSL; + configureMechanisms("DIGEST-MD5", Arrays.asList("DIGEST-MD5", "PLAIN", "SCRAM-SHA-256")); + configureDigestMd5ServerCallback(securityProtocol); + server = createEchoServer(securityProtocol); + updateScramCredentialCache(TestJaasConfig.USERNAME, TestJaasConfig.PASSWORD); + + String node1 = "1"; + saslClientConfigs.put(SaslConfigs.SASL_MECHANISM, "PLAIN"); + createAndCheckClientConnection(securityProtocol, node1); + server.verifyAuthenticationMetrics(1, 0); + + Selector selector2 = null; + Selector selector3 = null; + try { + String node2 = "2"; + saslClientConfigs.put(SaslConfigs.SASL_MECHANISM, "DIGEST-MD5"); + createSelector(securityProtocol, saslClientConfigs); + selector2 = selector; + InetSocketAddress addr = new InetSocketAddress("localhost", server.port()); + selector.connect(node2, addr, BUFFER_SIZE, BUFFER_SIZE); + NetworkTestUtils.checkClientConnection(selector, node2, 100, 10); + selector = null; // keeps it from being closed when next one is created + server.verifyAuthenticationMetrics(2, 0); + + String node3 = "3"; + saslClientConfigs.put(SaslConfigs.SASL_MECHANISM, "SCRAM-SHA-256"); + createSelector(securityProtocol, saslClientConfigs); + selector3 = selector; + selector.connect(node3, new InetSocketAddress("localhost", server.port()), BUFFER_SIZE, BUFFER_SIZE); + NetworkTestUtils.checkClientConnection(selector, node3, 100, 10); + server.verifyAuthenticationMetrics(3, 0); + + /* + * Now re-authenticate the connections. First we have to sleep long enough so + * that the next write will cause re-authentication, which we expect to succeed. + */ + delay((long) (CONNECTIONS_MAX_REAUTH_MS_VALUE * 1.1)); + server.verifyReauthenticationMetrics(0, 0); + + NetworkTestUtils.checkClientConnection(selector2, node2, 100, 10); + server.verifyReauthenticationMetrics(1, 0); + + NetworkTestUtils.checkClientConnection(selector3, node3, 100, 10); + server.verifyReauthenticationMetrics(2, 0); + + } finally { + if (selector2 != null) + selector2.close(); + if (selector3 != null) + selector3.close(); + } + } + + /** + * Tests good path SASL/SCRAM-SHA-256 client and server channels. + * Also tests successful re-authentication. + */ + @Test + public void testValidSaslScramSha256() throws Exception { + SecurityProtocol securityProtocol = SecurityProtocol.SASL_SSL; + configureMechanisms("SCRAM-SHA-256", Arrays.asList("SCRAM-SHA-256")); + + server = createEchoServer(securityProtocol); + updateScramCredentialCache(TestJaasConfig.USERNAME, TestJaasConfig.PASSWORD); + checkAuthenticationAndReauthentication(securityProtocol, "0"); + } + + /** + * Tests all supported SCRAM client and server channels. Also tests that all + * supported SCRAM mechanisms can be supported simultaneously on a server. 
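+     * On a broker this corresponds to enabling several mechanisms at once, e.g.:
+     *   sasl.enabled.mechanisms=SCRAM-SHA-256,SCRAM-SHA-512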
+ */ + @Test + public void testValidSaslScramMechanisms() throws Exception { + SecurityProtocol securityProtocol = SecurityProtocol.SASL_SSL; + configureMechanisms("SCRAM-SHA-256", new ArrayList<>(ScramMechanism.mechanismNames())); + server = createEchoServer(securityProtocol); + updateScramCredentialCache(TestJaasConfig.USERNAME, TestJaasConfig.PASSWORD); + + for (String mechanism : ScramMechanism.mechanismNames()) { + saslClientConfigs.put(SaslConfigs.SASL_MECHANISM, mechanism); + createAndCheckClientConnection(securityProtocol, "node-" + mechanism); + } + } + + /** + * Tests that SASL/SCRAM clients fail authentication if password is invalid. + */ + @Test + public void testInvalidPasswordSaslScram() throws Exception { + SecurityProtocol securityProtocol = SecurityProtocol.SASL_SSL; + TestJaasConfig jaasConfig = configureMechanisms("SCRAM-SHA-256", Arrays.asList("SCRAM-SHA-256")); + Map options = new HashMap<>(); + options.put("username", TestJaasConfig.USERNAME); + options.put("password", "invalidpassword"); + jaasConfig.createOrUpdateEntry(TestJaasConfig.LOGIN_CONTEXT_CLIENT, ScramLoginModule.class.getName(), options); + + String node = "0"; + server = createEchoServer(securityProtocol); + updateScramCredentialCache(TestJaasConfig.USERNAME, TestJaasConfig.PASSWORD); + createAndCheckClientAuthenticationFailure(securityProtocol, node, "SCRAM-SHA-256", null); + server.verifyAuthenticationMetrics(0, 1); + server.verifyReauthenticationMetrics(0, 0); + } + + /** + * Tests that SASL/SCRAM clients without valid username fail authentication. + */ + @Test + public void testUnknownUserSaslScram() throws Exception { + SecurityProtocol securityProtocol = SecurityProtocol.SASL_SSL; + TestJaasConfig jaasConfig = configureMechanisms("SCRAM-SHA-256", Arrays.asList("SCRAM-SHA-256")); + Map options = new HashMap<>(); + options.put("username", "unknownUser"); + options.put("password", TestJaasConfig.PASSWORD); + jaasConfig.createOrUpdateEntry(TestJaasConfig.LOGIN_CONTEXT_CLIENT, ScramLoginModule.class.getName(), options); + + String node = "0"; + server = createEchoServer(securityProtocol); + updateScramCredentialCache(TestJaasConfig.USERNAME, TestJaasConfig.PASSWORD); + createAndCheckClientAuthenticationFailure(securityProtocol, node, "SCRAM-SHA-256", null); + server.verifyAuthenticationMetrics(0, 1); + server.verifyReauthenticationMetrics(0, 0); + } + + /** + * Tests that SASL/SCRAM clients fail authentication if credentials are not available for + * the specific SCRAM mechanism. 
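+     * The SCRAM-SHA-256 credential is removed from the server's credential cache while the
+     * SCRAM-SHA-512 credential is left in place, so only the latter mechanism is expected to
+     * authenticate successfully.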
+ */ + @Test + public void testUserCredentialsUnavailableForScramMechanism() throws Exception { + SecurityProtocol securityProtocol = SecurityProtocol.SASL_SSL; + configureMechanisms("SCRAM-SHA-256", new ArrayList<>(ScramMechanism.mechanismNames())); + server = createEchoServer(securityProtocol); + updateScramCredentialCache(TestJaasConfig.USERNAME, TestJaasConfig.PASSWORD); + + server.credentialCache().cache(ScramMechanism.SCRAM_SHA_256.mechanismName(), ScramCredential.class).remove(TestJaasConfig.USERNAME); + String node = "1"; + saslClientConfigs.put(SaslConfigs.SASL_MECHANISM, "SCRAM-SHA-256"); + createAndCheckClientAuthenticationFailure(securityProtocol, node, "SCRAM-SHA-256", null); + server.verifyAuthenticationMetrics(0, 1); + + saslClientConfigs.put(SaslConfigs.SASL_MECHANISM, "SCRAM-SHA-512"); + createAndCheckClientConnection(securityProtocol, "2"); + server.verifyAuthenticationMetrics(1, 1); + server.verifyReauthenticationMetrics(0, 0); + } + + /** + * Tests SASL/SCRAM with username containing characters that need + * to be encoded. + */ + @Test + public void testScramUsernameWithSpecialCharacters() throws Exception { + SecurityProtocol securityProtocol = SecurityProtocol.SASL_SSL; + String username = "special user= test,scram"; + String password = username + "-password"; + TestJaasConfig jaasConfig = configureMechanisms("SCRAM-SHA-256", Arrays.asList("SCRAM-SHA-256")); + Map options = new HashMap<>(); + options.put("username", username); + options.put("password", password); + jaasConfig.createOrUpdateEntry(TestJaasConfig.LOGIN_CONTEXT_CLIENT, ScramLoginModule.class.getName(), options); + + server = createEchoServer(securityProtocol); + updateScramCredentialCache(username, password); + createAndCheckClientConnection(securityProtocol, "0"); + } + + + @Test + public void testTokenAuthenticationOverSaslScram() throws Exception { + SecurityProtocol securityProtocol = SecurityProtocol.SASL_SSL; + TestJaasConfig jaasConfig = configureMechanisms("SCRAM-SHA-256", Arrays.asList("SCRAM-SHA-256")); + + //create jaas config for token auth + Map options = new HashMap<>(); + String tokenId = "token1"; + String tokenHmac = "abcdefghijkl"; + options.put("username", tokenId); //tokenId + options.put("password", tokenHmac); //token hmac + options.put(ScramLoginModule.TOKEN_AUTH_CONFIG, "true"); //enable token authentication + jaasConfig.createOrUpdateEntry(TestJaasConfig.LOGIN_CONTEXT_CLIENT, ScramLoginModule.class.getName(), options); + + server = createEchoServer(securityProtocol); + + //Check invalid tokenId/tokenInfo in tokenCache + createAndCheckClientConnectionFailure(securityProtocol, "0"); + server.verifyAuthenticationMetrics(0, 1); + + //Check valid token Info and invalid credentials + KafkaPrincipal owner = SecurityUtils.parseKafkaPrincipal("User:Owner"); + KafkaPrincipal renewer = SecurityUtils.parseKafkaPrincipal("User:Renewer1"); + TokenInformation tokenInfo = new TokenInformation(tokenId, owner, Collections.singleton(renewer), + System.currentTimeMillis(), System.currentTimeMillis(), System.currentTimeMillis()); + server.tokenCache().addToken(tokenId, tokenInfo); + createAndCheckClientConnectionFailure(securityProtocol, "0"); + server.verifyAuthenticationMetrics(0, 2); + + //Check with valid token Info and credentials + updateTokenCredentialCache(tokenId, tokenHmac); + createAndCheckClientConnection(securityProtocol, "0"); + server.verifyAuthenticationMetrics(1, 2); + server.verifyReauthenticationMetrics(0, 0); + } + + @Test + public void testTokenReauthenticationOverSaslScram() 
throws Exception { + SecurityProtocol securityProtocol = SecurityProtocol.SASL_SSL; + TestJaasConfig jaasConfig = configureMechanisms("SCRAM-SHA-256", Arrays.asList("SCRAM-SHA-256")); + + // create jaas config for token auth + Map options = new HashMap<>(); + String tokenId = "token1"; + String tokenHmac = "abcdefghijkl"; + options.put("username", tokenId); // tokenId + options.put("password", tokenHmac); // token hmac + options.put(ScramLoginModule.TOKEN_AUTH_CONFIG, "true"); // enable token authentication + jaasConfig.createOrUpdateEntry(TestJaasConfig.LOGIN_CONTEXT_CLIENT, ScramLoginModule.class.getName(), options); + + // ensure re-authentication based on token expiry rather than a default value + saslServerConfigs.put(BrokerSecurityConfigs.CONNECTIONS_MAX_REAUTH_MS, Long.MAX_VALUE); + /* + * create a token cache that adjusts the token expiration dynamically so that + * the first time the expiry is read during authentication we use it to define a + * session expiration time that we can then sleep through; then the second time + * the value is read (during re-authentication) it will be in the future. + */ + Function tokenLifetime = callNum -> 10 * callNum * CONNECTIONS_MAX_REAUTH_MS_VALUE; + DelegationTokenCache tokenCache = new DelegationTokenCache(ScramMechanism.mechanismNames()) { + int callNum = 0; + + @Override + public TokenInformation token(String tokenId) { + TokenInformation baseTokenInfo = super.token(tokenId); + long thisLifetimeMs = System.currentTimeMillis() + tokenLifetime.apply(++callNum).longValue(); + TokenInformation retvalTokenInfo = new TokenInformation(baseTokenInfo.tokenId(), baseTokenInfo.owner(), + baseTokenInfo.renewers(), baseTokenInfo.issueTimestamp(), thisLifetimeMs, thisLifetimeMs); + return retvalTokenInfo; + } + }; + server = createEchoServer(ListenerName.forSecurityProtocol(securityProtocol), securityProtocol, tokenCache); + + KafkaPrincipal owner = SecurityUtils.parseKafkaPrincipal("User:Owner"); + KafkaPrincipal renewer = SecurityUtils.parseKafkaPrincipal("User:Renewer1"); + TokenInformation tokenInfo = new TokenInformation(tokenId, owner, Collections.singleton(renewer), + System.currentTimeMillis(), System.currentTimeMillis(), System.currentTimeMillis()); + server.tokenCache().addToken(tokenId, tokenInfo); + updateTokenCredentialCache(tokenId, tokenHmac); + // initial authentication must succeed + createClientConnection(securityProtocol, "0"); + checkClientConnection("0"); + // ensure metrics are as expected before trying to re-authenticate + server.verifyAuthenticationMetrics(1, 0); + server.verifyReauthenticationMetrics(0, 0); + /* + * Now re-authenticate and ensure it succeeds. We have to sleep long enough so + * that the current delegation token will be expired when the next write occurs; + * this will trigger a re-authentication. Then the second time the delegation + * token is read and transmitted to the server it will again have an expiration + * date in the future. + */ + delay(tokenLifetime.apply(1)); + checkClientConnection("0"); + server.verifyReauthenticationMetrics(1, 0); + } + + /** + * Tests that Kafka ApiVersionsRequests are handled by the SASL server authenticator + * prior to SASL handshake flow and that subsequent authentication succeeds + * when transport layer is PLAINTEXT. This test simulates SASL authentication using a + * (non-SASL) PLAINTEXT client and sends ApiVersionsRequest straight after + * connection to the server is established, before any SASL-related packets are sent. 
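+     * Only the SASL requests and ApiVersionsRequest are accepted by the server authenticator before
+     * authentication completes; other Kafka requests sent at this point are rejected, as verified by
+     * {@link #testDisallowedKafkaRequestsBeforeAuthentication()}.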
+ * This test is run with SaslHandshake version 0 and no SaslAuthenticate headers. + */ + @Test + public void testUnauthenticatedApiVersionsRequestOverPlaintextHandshakeVersion0() throws Exception { + testUnauthenticatedApiVersionsRequest(SecurityProtocol.SASL_PLAINTEXT, (short) 0); + } + + /** + * See {@link #testUnauthenticatedApiVersionsRequestOverSslHandshakeVersion0()} for test scenario. + * This test is run with SaslHandshake version 1 and SaslAuthenticate headers. + */ + @Test + public void testUnauthenticatedApiVersionsRequestOverPlaintextHandshakeVersion1() throws Exception { + testUnauthenticatedApiVersionsRequest(SecurityProtocol.SASL_PLAINTEXT, (short) 1); + } + + /** + * Tests that Kafka ApiVersionsRequests are handled by the SASL server authenticator + * prior to SASL handshake flow and that subsequent authentication succeeds + * when transport layer is SSL. This test simulates SASL authentication using a + * (non-SASL) SSL client and sends ApiVersionsRequest straight after + * SSL handshake, before any SASL-related packets are sent. + * This test is run with SaslHandshake version 0 and no SaslAuthenticate headers. + */ + @Test + public void testUnauthenticatedApiVersionsRequestOverSslHandshakeVersion0() throws Exception { + testUnauthenticatedApiVersionsRequest(SecurityProtocol.SASL_SSL, (short) 0); + } + + /** + * See {@link #testUnauthenticatedApiVersionsRequestOverPlaintextHandshakeVersion0()} for test scenario. + * This test is run with SaslHandshake version 1 and SaslAuthenticate headers. + */ + @Test + public void testUnauthenticatedApiVersionsRequestOverSslHandshakeVersion1() throws Exception { + testUnauthenticatedApiVersionsRequest(SecurityProtocol.SASL_SSL, (short) 1); + } + + /** + * Tests that unsupported version of ApiVersionsRequest before SASL handshake request + * returns error response and does not result in authentication failure. This test + * is similar to {@link #testUnauthenticatedApiVersionsRequest(SecurityProtocol, short)} + * where a non-SASL client is used to send requests that are processed by + * {@link SaslServerAuthenticator} of the server prior to client authentication. + */ + @Test + public void testApiVersionsRequestWithServerUnsupportedVersion() throws Exception { + short handshakeVersion = ApiKeys.SASL_HANDSHAKE.latestVersion(); + SecurityProtocol securityProtocol = SecurityProtocol.SASL_PLAINTEXT; + configureMechanisms("PLAIN", Arrays.asList("PLAIN")); + server = createEchoServer(securityProtocol); + + // Send ApiVersionsRequest with unsupported version and validate error response. + String node = "1"; + createClientConnection(SecurityProtocol.PLAINTEXT, node); + + RequestHeader header = new RequestHeader(new RequestHeaderData(). + setRequestApiKey(ApiKeys.API_VERSIONS.id). + setRequestApiVersion(Short.MAX_VALUE). + setClientId("someclient"). 
+ setCorrelationId(1), + (short) 2); + ApiVersionsRequest request = new ApiVersionsRequest.Builder().build(); + selector.send(new NetworkSend(node, request.toSend(header))); + ByteBuffer responseBuffer = waitForResponse(); + ResponseHeader.parse(responseBuffer, ApiKeys.API_VERSIONS.responseHeaderVersion((short) 0)); + ApiVersionsResponse response = ApiVersionsResponse.parse(responseBuffer, (short) 0); + assertEquals(Errors.UNSUPPORTED_VERSION.code(), response.data().errorCode()); + + ApiVersion apiVersion = response.data().apiKeys().find(ApiKeys.API_VERSIONS.id); + assertNotNull(apiVersion); + assertEquals(ApiKeys.API_VERSIONS.id, apiVersion.apiKey()); + assertEquals(ApiKeys.API_VERSIONS.oldestVersion(), apiVersion.minVersion()); + assertEquals(ApiKeys.API_VERSIONS.latestVersion(), apiVersion.maxVersion()); + + // Send ApiVersionsRequest with a supported version. This should succeed. + sendVersionRequestReceiveResponse(node); + + // Test that client can authenticate successfully + sendHandshakeRequestReceiveResponse(node, handshakeVersion); + authenticateUsingSaslPlainAndCheckConnection(node, handshakeVersion > 0); + } + + /** + * Tests correct negotiation of handshake and authenticate api versions by having the server + * return a higher version than supported on the client. + * Note, that due to KAFKA-9577 this will require a workaround to effectively bump + * SASL_HANDSHAKE in the future. + */ + @Test + public void testSaslUnsupportedClientVersions() throws Exception { + configureMechanisms("SCRAM-SHA-512", Arrays.asList("SCRAM-SHA-512")); + + server = startServerApiVersionsUnsupportedByClient(SecurityProtocol.SASL_SSL, "SCRAM-SHA-512"); + updateScramCredentialCache(TestJaasConfig.USERNAME, TestJaasConfig.PASSWORD); + + String node = "0"; + + createClientConnection(SecurityProtocol.SASL_SSL, "SCRAM-SHA-512", node, true); + NetworkTestUtils.checkClientConnection(selector, "0", 100, 10); + } + + /** + * Tests that invalid ApiVersionRequest is handled by the server correctly and + * returns an INVALID_REQUEST error. + */ + @Test + public void testInvalidApiVersionsRequest() throws Exception { + short handshakeVersion = ApiKeys.SASL_HANDSHAKE.latestVersion(); + SecurityProtocol securityProtocol = SecurityProtocol.SASL_PLAINTEXT; + configureMechanisms("PLAIN", Arrays.asList("PLAIN")); + server = createEchoServer(securityProtocol); + + // Send ApiVersionsRequest with invalid version and validate error response. + String node = "1"; + short version = ApiKeys.API_VERSIONS.latestVersion(); + createClientConnection(SecurityProtocol.PLAINTEXT, node); + RequestHeader header = new RequestHeader(ApiKeys.API_VERSIONS, version, "someclient", 1); + ApiVersionsRequest request = new ApiVersionsRequest(new ApiVersionsRequestData(). + setClientSoftwareName(" "). + setClientSoftwareVersion(" "), version); + selector.send(new NetworkSend(node, request.toSend(header))); + ByteBuffer responseBuffer = waitForResponse(); + ResponseHeader.parse(responseBuffer, ApiKeys.API_VERSIONS.responseHeaderVersion(version)); + ApiVersionsResponse response = + ApiVersionsResponse.parse(responseBuffer, version); + assertEquals(Errors.INVALID_REQUEST.code(), response.data().errorCode()); + + // Send ApiVersionsRequest with a supported version. This should succeed. 
+        sendVersionRequestReceiveResponse(node);
+
+        // Test that client can authenticate successfully
+        sendHandshakeRequestReceiveResponse(node, handshakeVersion);
+        authenticateUsingSaslPlainAndCheckConnection(node, handshakeVersion > 0);
+    }
+
+    @Test
+    public void testForBrokenSaslHandshakeVersionBump() {
+        assertEquals(1, ApiKeys.SASL_HANDSHAKE.latestVersion(),
+            "It is not possible to easily bump SASL_HANDSHAKE schema due to improper version negotiation in " +
+            "clients < 2.5. Please see https://issues.apache.org/jira/browse/KAFKA-9577");
+    }
+
+    /**
+     * Tests that a valid ApiVersionsRequest is handled by the server correctly and
+     * returns a NONE error.
+     */
+    @Test
+    public void testValidApiVersionsRequest() throws Exception {
+        short handshakeVersion = ApiKeys.SASL_HANDSHAKE.latestVersion();
+        SecurityProtocol securityProtocol = SecurityProtocol.SASL_PLAINTEXT;
+        configureMechanisms("PLAIN", Arrays.asList("PLAIN"));
+        server = createEchoServer(securityProtocol);
+
+        // Send ApiVersionsRequest with valid version and validate error response.
+        String node = "1";
+        short version = ApiKeys.API_VERSIONS.latestVersion();
+        createClientConnection(SecurityProtocol.PLAINTEXT, node);
+        RequestHeader header = new RequestHeader(ApiKeys.API_VERSIONS, version, "someclient", 1);
+        ApiVersionsRequest request = new ApiVersionsRequest.Builder().build(version);
+        selector.send(new NetworkSend(node, request.toSend(header)));
+        ByteBuffer responseBuffer = waitForResponse();
+        ResponseHeader.parse(responseBuffer, ApiKeys.API_VERSIONS.responseHeaderVersion(version));
+        ApiVersionsResponse response = ApiVersionsResponse.parse(responseBuffer, version);
+        assertEquals(Errors.NONE.code(), response.data().errorCode());
+
+        // Test that client can authenticate successfully
+        sendHandshakeRequestReceiveResponse(node, handshakeVersion);
+        authenticateUsingSaslPlainAndCheckConnection(node, handshakeVersion > 0);
+    }
+
+    /**
+     * Tests that an unsupported version of the SASL handshake request returns an error
+     * response and fails authentication. This test is similar to
+     * {@link #testUnauthenticatedApiVersionsRequest(SecurityProtocol, short)}
+     * where a non-SASL client is used to send requests that are processed by
+     * {@link SaslServerAuthenticator} of the server prior to client authentication.
+     */
+    @Test
+    public void testSaslHandshakeRequestWithUnsupportedVersion() throws Exception {
+        SecurityProtocol securityProtocol = SecurityProtocol.SASL_PLAINTEXT;
+        configureMechanisms("PLAIN", Arrays.asList("PLAIN"));
+        server = createEchoServer(securityProtocol);
+
+        // Send SaslHandshakeRequest and validate that connection is closed by server.
+        String node1 = "invalid1";
+        createClientConnection(SecurityProtocol.PLAINTEXT, node1);
+        SaslHandshakeRequest request = buildSaslHandshakeRequest("PLAIN", ApiKeys.SASL_HANDSHAKE.latestVersion());
+        RequestHeader header = new RequestHeader(ApiKeys.SASL_HANDSHAKE, Short.MAX_VALUE, "someclient", 2);
+
+        selector.send(new NetworkSend(node1, request.toSend(header)));
+        // This test uses a non-SASL PLAINTEXT client in order to do manual handshake.
+        // So the channel is in READY state.
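+        // The server rejects the unsupported handshake version and closes the connection, which the
+        // client observes as a close while in READY state.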
+ NetworkTestUtils.waitForChannelClose(selector, node1, ChannelState.READY.state()); + selector.close(); + + // Test good connection still works + createAndCheckClientConnection(securityProtocol, "good1"); + } + + /** + * Tests that any invalid data during Kafka SASL handshake request flow + * or the actual SASL authentication flow result in authentication failure + * and do not cause any failures in the server. + */ + @Test + public void testInvalidSaslPacket() throws Exception { + SecurityProtocol securityProtocol = SecurityProtocol.SASL_PLAINTEXT; + configureMechanisms("PLAIN", Arrays.asList("PLAIN")); + server = createEchoServer(securityProtocol); + + // Send invalid SASL packet after valid handshake request + String node1 = "invalid1"; + createClientConnection(SecurityProtocol.PLAINTEXT, node1); + sendHandshakeRequestReceiveResponse(node1, (short) 1); + Random random = new Random(); + byte[] bytes = new byte[1024]; + random.nextBytes(bytes); + selector.send(new NetworkSend(node1, ByteBufferSend.sizePrefixed(ByteBuffer.wrap(bytes)))); + NetworkTestUtils.waitForChannelClose(selector, node1, ChannelState.READY.state()); + selector.close(); + + // Test good connection still works + createAndCheckClientConnection(securityProtocol, "good1"); + + // Send invalid SASL packet before handshake request + String node2 = "invalid2"; + createClientConnection(SecurityProtocol.PLAINTEXT, node2); + random.nextBytes(bytes); + selector.send(new NetworkSend(node2, ByteBufferSend.sizePrefixed(ByteBuffer.wrap(bytes)))); + NetworkTestUtils.waitForChannelClose(selector, node2, ChannelState.READY.state()); + selector.close(); + + // Test good connection still works + createAndCheckClientConnection(securityProtocol, "good2"); + } + + /** + * Tests that ApiVersionsRequest after Kafka SASL handshake request flow, + * but prior to actual SASL authentication, results in authentication failure. + * This is similar to {@link #testUnauthenticatedApiVersionsRequest(SecurityProtocol, short)} + * where a non-SASL client is used to send requests that are processed by + * {@link SaslServerAuthenticator} of the server prior to client authentication. + */ + @Test + public void testInvalidApiVersionsRequestSequence() throws Exception { + SecurityProtocol securityProtocol = SecurityProtocol.SASL_PLAINTEXT; + configureMechanisms("PLAIN", Arrays.asList("PLAIN")); + server = createEchoServer(securityProtocol); + + // Send handshake request followed by ApiVersionsRequest + String node1 = "invalid1"; + createClientConnection(SecurityProtocol.PLAINTEXT, node1); + sendHandshakeRequestReceiveResponse(node1, (short) 1); + + ApiVersionsRequest request = createApiVersionsRequestV0(); + RequestHeader versionsHeader = new RequestHeader(ApiKeys.API_VERSIONS, request.version(), "someclient", 2); + selector.send(new NetworkSend(node1, request.toSend(versionsHeader))); + NetworkTestUtils.waitForChannelClose(selector, node1, ChannelState.READY.state()); + selector.close(); + + // Test good connection still works + createAndCheckClientConnection(securityProtocol, "good1"); + } + + /** + * Tests that packets that are too big during Kafka SASL handshake request flow + * or the actual SASL authentication flow result in authentication failure + * and do not cause any failures in the server. 
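+     * Each oversized probe declares a frame size of Integer.MAX_VALUE in its size prefix, which the
+     * server rejects before closing the connection.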
+ */ + @Test + public void testPacketSizeTooBig() throws Exception { + SecurityProtocol securityProtocol = SecurityProtocol.SASL_PLAINTEXT; + configureMechanisms("PLAIN", Arrays.asList("PLAIN")); + server = createEchoServer(securityProtocol); + + // Send SASL packet with large size after valid handshake request + String node1 = "invalid1"; + createClientConnection(SecurityProtocol.PLAINTEXT, node1); + sendHandshakeRequestReceiveResponse(node1, (short) 1); + ByteBuffer buffer = ByteBuffer.allocate(1024); + buffer.putInt(Integer.MAX_VALUE); + buffer.put(new byte[buffer.capacity() - 4]); + buffer.rewind(); + selector.send(new NetworkSend(node1, ByteBufferSend.sizePrefixed(buffer))); + NetworkTestUtils.waitForChannelClose(selector, node1, ChannelState.READY.state()); + selector.close(); + + // Test good connection still works + createAndCheckClientConnection(securityProtocol, "good1"); + + // Send packet with large size before handshake request + String node2 = "invalid2"; + createClientConnection(SecurityProtocol.PLAINTEXT, node2); + buffer.clear(); + buffer.putInt(Integer.MAX_VALUE); + buffer.put(new byte[buffer.capacity() - 4]); + buffer.rewind(); + selector.send(new NetworkSend(node2, ByteBufferSend.sizePrefixed(buffer))); + NetworkTestUtils.waitForChannelClose(selector, node2, ChannelState.READY.state()); + selector.close(); + + // Test good connection still works + createAndCheckClientConnection(securityProtocol, "good2"); + } + + /** + * Tests that Kafka requests that are forbidden until successful authentication result + * in authentication failure and do not cause any failures in the server. + */ + @Test + public void testDisallowedKafkaRequestsBeforeAuthentication() throws Exception { + SecurityProtocol securityProtocol = SecurityProtocol.SASL_PLAINTEXT; + configureMechanisms("PLAIN", Arrays.asList("PLAIN")); + server = createEchoServer(securityProtocol); + + // Send metadata request before Kafka SASL handshake request + String node1 = "invalid1"; + createClientConnection(SecurityProtocol.PLAINTEXT, node1); + MetadataRequest metadataRequest1 = new MetadataRequest.Builder(Collections.singletonList("sometopic"), + true).build(); + RequestHeader metadataRequestHeader1 = new RequestHeader(ApiKeys.METADATA, metadataRequest1.version(), + "someclient", 1); + selector.send(new NetworkSend(node1, metadataRequest1.toSend(metadataRequestHeader1))); + NetworkTestUtils.waitForChannelClose(selector, node1, ChannelState.READY.state()); + selector.close(); + + // Test good connection still works + createAndCheckClientConnection(securityProtocol, "good1"); + + // Send metadata request after Kafka SASL handshake request + String node2 = "invalid2"; + createClientConnection(SecurityProtocol.PLAINTEXT, node2); + sendHandshakeRequestReceiveResponse(node2, (short) 1); + MetadataRequest metadataRequest2 = new MetadataRequest.Builder(Collections.singletonList("sometopic"), true).build(); + RequestHeader metadataRequestHeader2 = new RequestHeader(ApiKeys.METADATA, + metadataRequest2.version(), "someclient", 2); + selector.send(new NetworkSend(node2, metadataRequest2.toSend(metadataRequestHeader2))); + NetworkTestUtils.waitForChannelClose(selector, node2, ChannelState.READY.state()); + selector.close(); + + // Test good connection still works + createAndCheckClientConnection(securityProtocol, "good2"); + } + + /** + * Tests that connections cannot be created if the login module class is unavailable. 
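+     * The client JAAS configuration is updated to reference a login module class that cannot be
+     * loaded, so creating the channel builder is expected to fail with a KafkaException.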
+ */ + @Test + public void testInvalidLoginModule() throws Exception { + TestJaasConfig jaasConfig = configureMechanisms("PLAIN", Arrays.asList("PLAIN")); + jaasConfig.createOrUpdateEntry(TestJaasConfig.LOGIN_CONTEXT_CLIENT, "InvalidLoginModule", TestJaasConfig.defaultClientOptions()); + + SecurityProtocol securityProtocol = SecurityProtocol.SASL_SSL; + server = createEchoServer(securityProtocol); + try { + createSelector(securityProtocol, saslClientConfigs); + fail("SASL/PLAIN channel created without valid login module"); + } catch (KafkaException e) { + // Expected exception + } + } + + /** + * Tests SASL client authentication callback handler override. + */ + @Test + public void testClientAuthenticateCallbackHandler() throws Exception { + SecurityProtocol securityProtocol = SecurityProtocol.SASL_PLAINTEXT; + TestJaasConfig jaasConfig = configureMechanisms("PLAIN", Collections.singletonList("PLAIN")); + saslClientConfigs.put(SaslConfigs.SASL_CLIENT_CALLBACK_HANDLER_CLASS, TestClientCallbackHandler.class.getName()); + jaasConfig.setClientOptions("PLAIN", "", ""); // remove username, password in login context + + Map options = new HashMap<>(); + options.put("user_" + TestClientCallbackHandler.USERNAME, TestClientCallbackHandler.PASSWORD); + jaasConfig.createOrUpdateEntry(TestJaasConfig.LOGIN_CONTEXT_SERVER, PlainLoginModule.class.getName(), options); + server = createEchoServer(securityProtocol); + createAndCheckClientConnection(securityProtocol, "good"); + + options.clear(); + options.put("user_" + TestClientCallbackHandler.USERNAME, "invalid-password"); + jaasConfig.createOrUpdateEntry(TestJaasConfig.LOGIN_CONTEXT_SERVER, PlainLoginModule.class.getName(), options); + createAndCheckClientConnectionFailure(securityProtocol, "invalid"); + } + + /** + * Tests SASL server authentication callback handler override. + */ + @Test + public void testServerAuthenticateCallbackHandler() throws Exception { + SecurityProtocol securityProtocol = SecurityProtocol.SASL_PLAINTEXT; + TestJaasConfig jaasConfig = configureMechanisms("PLAIN", Collections.singletonList("PLAIN")); + jaasConfig.createOrUpdateEntry(TestJaasConfig.LOGIN_CONTEXT_SERVER, PlainLoginModule.class.getName(), new HashMap()); + String callbackPrefix = ListenerName.forSecurityProtocol(securityProtocol).saslMechanismConfigPrefix("PLAIN"); + saslServerConfigs.put(callbackPrefix + BrokerSecurityConfigs.SASL_SERVER_CALLBACK_HANDLER_CLASS, + TestServerCallbackHandler.class.getName()); + server = createEchoServer(securityProtocol); + + // Set client username/password to the values used by `TestServerCallbackHandler` + jaasConfig.setClientOptions("PLAIN", TestServerCallbackHandler.USERNAME, TestServerCallbackHandler.PASSWORD); + createAndCheckClientConnection(securityProtocol, "good"); + + // Set client username/password to the invalid values + jaasConfig.setClientOptions("PLAIN", TestJaasConfig.USERNAME, "invalid-password"); + createAndCheckClientConnectionFailure(securityProtocol, "invalid"); + } + + /** + * Test that callback handlers are only applied to connections for the mechanisms + * configured for the handler. Test enables two mechanisms 'PLAIN` and `DIGEST-MD5` + * on the servers with different callback handlers for the two mechanisms. Verifies + * that clients using both mechanisms authenticate successfully. 
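+     * Handler classes take effect only when configured with the listener and mechanism prefix, for
+     * example "listener.name.sasl_plaintext.plain.sasl.server.callback.handler.class"; a bare
+     * mechanism prefix such as "plain." is not sufficient.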
+ */ + @Test + public void testAuthenticateCallbackHandlerMechanisms() throws Exception { + SecurityProtocol securityProtocol = SecurityProtocol.SASL_PLAINTEXT; + TestJaasConfig jaasConfig = configureMechanisms("DIGEST-MD5", Arrays.asList("DIGEST-MD5", "PLAIN")); + + // Connections should fail using the digest callback handler if listener.mechanism prefix not specified + saslServerConfigs.put("plain." + BrokerSecurityConfigs.SASL_SERVER_CALLBACK_HANDLER_CLASS, + TestServerCallbackHandler.class); + saslServerConfigs.put("digest-md5." + BrokerSecurityConfigs.SASL_SERVER_CALLBACK_HANDLER_CLASS, + DigestServerCallbackHandler.class); + server = createEchoServer(securityProtocol); + createAndCheckClientConnectionFailure(securityProtocol, "invalid"); + + // Connections should succeed using the server callback handler associated with the listener + ListenerName listener = ListenerName.forSecurityProtocol(securityProtocol); + saslServerConfigs.remove("plain." + BrokerSecurityConfigs.SASL_SERVER_CALLBACK_HANDLER_CLASS); + saslServerConfigs.remove("digest-md5." + BrokerSecurityConfigs.SASL_SERVER_CALLBACK_HANDLER_CLASS); + saslServerConfigs.put(listener.saslMechanismConfigPrefix("plain") + BrokerSecurityConfigs.SASL_SERVER_CALLBACK_HANDLER_CLASS, + TestServerCallbackHandler.class); + saslServerConfigs.put(listener.saslMechanismConfigPrefix("digest-md5") + BrokerSecurityConfigs.SASL_SERVER_CALLBACK_HANDLER_CLASS, + DigestServerCallbackHandler.class); + server = createEchoServer(securityProtocol); + + // Verify that DIGEST-MD5 (currently configured for client) works with `DigestServerCallbackHandler` + createAndCheckClientConnection(securityProtocol, "good-digest-md5"); + + // Verify that PLAIN works with `TestServerCallbackHandler` + jaasConfig.setClientOptions("PLAIN", TestServerCallbackHandler.USERNAME, TestServerCallbackHandler.PASSWORD); + saslClientConfigs.put(SaslConfigs.SASL_MECHANISM, "PLAIN"); + createAndCheckClientConnection(securityProtocol, "good-plain"); + } + + /** + * Tests SASL login class override. + */ + @Test + public void testClientLoginOverride() throws Exception { + SecurityProtocol securityProtocol = SecurityProtocol.SASL_PLAINTEXT; + TestJaasConfig jaasConfig = configureMechanisms("PLAIN", Collections.singletonList("PLAIN")); + jaasConfig.setClientOptions("PLAIN", "invaliduser", "invalidpassword"); + server = createEchoServer(securityProtocol); + + // Connection should succeed using login override that sets correct username/password in Subject + saslClientConfigs.put(SaslConfigs.SASL_LOGIN_CLASS, TestLogin.class.getName()); + createAndCheckClientConnection(securityProtocol, "1"); + assertEquals(1, TestLogin.loginCount.get()); + + // Connection should fail without login override since username/password in jaas config is invalid + saslClientConfigs.remove(SaslConfigs.SASL_LOGIN_CLASS); + createAndCheckClientConnectionFailure(securityProtocol, "invalid"); + assertEquals(1, TestLogin.loginCount.get()); + } + + /** + * Tests SASL server login class override. 
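+     * The override is configured with the listener and mechanism prefix, and exactly one login is
+     * expected, performed when the server channel builder is created rather than once per connection.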
+ */ + @Test + public void testServerLoginOverride() throws Exception { + SecurityProtocol securityProtocol = SecurityProtocol.SASL_PLAINTEXT; + configureMechanisms("PLAIN", Collections.singletonList("PLAIN")); + String prefix = ListenerName.forSecurityProtocol(securityProtocol).saslMechanismConfigPrefix("PLAIN"); + saslServerConfigs.put(prefix + SaslConfigs.SASL_LOGIN_CLASS, TestLogin.class.getName()); + server = createEchoServer(securityProtocol); + + // Login is performed when server channel builder is created (before any connections are made on the server) + assertEquals(1, TestLogin.loginCount.get()); + + createAndCheckClientConnection(securityProtocol, "1"); + assertEquals(1, TestLogin.loginCount.get()); + } + + /** + * Tests SASL login callback class override. + */ + @Test + public void testClientLoginCallbackOverride() throws Exception { + SecurityProtocol securityProtocol = SecurityProtocol.SASL_PLAINTEXT; + TestJaasConfig jaasConfig = configureMechanisms("PLAIN", Collections.singletonList("PLAIN")); + jaasConfig.createOrUpdateEntry(TestJaasConfig.LOGIN_CONTEXT_CLIENT, TestPlainLoginModule.class.getName(), + Collections.emptyMap()); + server = createEchoServer(securityProtocol); + + // Connection should succeed using login callback override that sets correct username/password + saslClientConfigs.put(SaslConfigs.SASL_LOGIN_CALLBACK_HANDLER_CLASS, TestLoginCallbackHandler.class.getName()); + createAndCheckClientConnection(securityProtocol, "1"); + + // Connection should fail without login callback override since username/password in jaas config is invalid + saslClientConfigs.remove(SaslConfigs.SASL_LOGIN_CALLBACK_HANDLER_CLASS); + try { + createClientConnection(securityProtocol, "invalid"); + } catch (Exception e) { + assertTrue(e.getCause() instanceof LoginException, "Unexpected exception " + e.getCause()); + } + } + + /** + * Tests SASL server login callback class override. + */ + @Test + public void testServerLoginCallbackOverride() throws Exception { + SecurityProtocol securityProtocol = SecurityProtocol.SASL_PLAINTEXT; + TestJaasConfig jaasConfig = configureMechanisms("PLAIN", Collections.singletonList("PLAIN")); + jaasConfig.createOrUpdateEntry(TestJaasConfig.LOGIN_CONTEXT_SERVER, TestPlainLoginModule.class.getName(), + Collections.emptyMap()); + jaasConfig.setClientOptions("PLAIN", TestServerCallbackHandler.USERNAME, TestServerCallbackHandler.PASSWORD); + ListenerName listenerName = ListenerName.forSecurityProtocol(securityProtocol); + String prefix = listenerName.saslMechanismConfigPrefix("PLAIN"); + saslServerConfigs.put(prefix + BrokerSecurityConfigs.SASL_SERVER_CALLBACK_HANDLER_CLASS, + TestServerCallbackHandler.class); + Class loginCallback = TestLoginCallbackHandler.class; + + try { + createEchoServer(securityProtocol); + fail("Should have failed to create server with default login handler"); + } catch (KafkaException e) { + // Expected exception + } + + try { + saslServerConfigs.put(SaslConfigs.SASL_LOGIN_CALLBACK_HANDLER_CLASS, loginCallback); + createEchoServer(securityProtocol); + fail("Should have failed to create server with login handler config without listener+mechanism prefix"); + } catch (KafkaException e) { + // Expected exception + saslServerConfigs.remove(SaslConfigs.SASL_LOGIN_CALLBACK_HANDLER_CLASS); + } + + try { + saslServerConfigs.put("plain." 
+ SaslConfigs.SASL_LOGIN_CALLBACK_HANDLER_CLASS, loginCallback); + createEchoServer(securityProtocol); + fail("Should have failed to create server with login handler config without listener prefix"); + } catch (KafkaException e) { + // Expected exception + saslServerConfigs.remove("plain." + SaslConfigs.SASL_LOGIN_CALLBACK_HANDLER_CLASS); + } + + try { + saslServerConfigs.put(listenerName.configPrefix() + SaslConfigs.SASL_LOGIN_CALLBACK_HANDLER_CLASS, loginCallback); + createEchoServer(securityProtocol); + fail("Should have failed to create server with login handler config without mechanism prefix"); + } catch (KafkaException e) { + // Expected exception + saslServerConfigs.remove("plain." + SaslConfigs.SASL_LOGIN_CALLBACK_HANDLER_CLASS); + } + + // Connection should succeed using login callback override for mechanism + saslServerConfigs.put(prefix + SaslConfigs.SASL_LOGIN_CALLBACK_HANDLER_CLASS, loginCallback); + server = createEchoServer(securityProtocol); + createAndCheckClientConnection(securityProtocol, "1"); + } + + /** + * Tests that mechanisms with default implementation in Kafka may be disabled in + * the Kafka server by removing from the enabled mechanism list. + */ + @Test + public void testDisabledMechanism() throws Exception { + String node = "0"; + SecurityProtocol securityProtocol = SecurityProtocol.SASL_SSL; + configureMechanisms("PLAIN", Arrays.asList("DIGEST-MD5")); + + server = createEchoServer(securityProtocol); + createAndCheckClientConnectionFailure(securityProtocol, node); + server.verifyAuthenticationMetrics(0, 1); + server.verifyReauthenticationMetrics(0, 0); + } + + /** + * Tests that clients using invalid SASL mechanisms fail authentication. + */ + @Test + public void testInvalidMechanism() throws Exception { + String node = "0"; + SecurityProtocol securityProtocol = SecurityProtocol.SASL_SSL; + configureMechanisms("PLAIN", Arrays.asList("PLAIN")); + saslClientConfigs.put(SaslConfigs.SASL_MECHANISM, "INVALID"); + + server = createEchoServer(securityProtocol); + try { + createAndCheckClientConnectionFailure(securityProtocol, node); + fail("Did not generate exception prior to creating channel"); + } catch (IOException expected) { + server.verifyAuthenticationMetrics(0, 0); + server.verifyReauthenticationMetrics(0, 0); + Throwable underlyingCause = expected.getCause().getCause().getCause(); + assertEquals(SaslAuthenticationException.class, underlyingCause.getClass()); + assertEquals("Failed to create SaslClient with mechanism INVALID", underlyingCause.getMessage()); + } finally { + closeClientConnectionIfNecessary(); + } + } + + /** + * Tests dynamic JAAS configuration property for SASL clients. Invalid client credentials + * are set in the static JVM-wide configuration instance to ensure that the dynamic + * property override is used during authentication. 
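+     * The dynamic override is provided through the client's sasl.jaas.config property, whose value
+     * takes the form:
+     * org.apache.kafka.common.security.plain.PlainLoginModule required username="user1" password="user1-secret";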
+ */ + @Test + public void testClientDynamicJaasConfiguration() throws Exception { + SecurityProtocol securityProtocol = SecurityProtocol.SASL_SSL; + saslClientConfigs.put(SaslConfigs.SASL_MECHANISM, "PLAIN"); + saslServerConfigs.put(BrokerSecurityConfigs.SASL_ENABLED_MECHANISMS_CONFIG, Arrays.asList("PLAIN")); + Map serverOptions = new HashMap<>(); + serverOptions.put("user_user1", "user1-secret"); + serverOptions.put("user_user2", "user2-secret"); + TestJaasConfig staticJaasConfig = new TestJaasConfig(); + staticJaasConfig.createOrUpdateEntry(TestJaasConfig.LOGIN_CONTEXT_SERVER, PlainLoginModule.class.getName(), + serverOptions); + staticJaasConfig.setClientOptions("PLAIN", "user1", "invalidpassword"); + Configuration.setConfiguration(staticJaasConfig); + server = createEchoServer(securityProtocol); + + // Check that client using static Jaas config does not connect since password is invalid + createAndCheckClientConnectionFailure(securityProtocol, "1"); + + // Check that 'user1' can connect with a Jaas config property override + saslClientConfigs.put(SaslConfigs.SASL_JAAS_CONFIG, TestJaasConfig.jaasConfigProperty("PLAIN", "user1", "user1-secret")); + createAndCheckClientConnection(securityProtocol, "2"); + + // Check that invalid password specified as Jaas config property results in connection failure + saslClientConfigs.put(SaslConfigs.SASL_JAAS_CONFIG, TestJaasConfig.jaasConfigProperty("PLAIN", "user1", "user2-secret")); + createAndCheckClientConnectionFailure(securityProtocol, "3"); + + // Check that another user 'user2' can also connect with a Jaas config override without any changes to static configuration + saslClientConfigs.put(SaslConfigs.SASL_JAAS_CONFIG, TestJaasConfig.jaasConfigProperty("PLAIN", "user2", "user2-secret")); + createAndCheckClientConnection(securityProtocol, "4"); + + // Check that clients specifying multiple login modules fail even if the credentials are valid + String module1 = TestJaasConfig.jaasConfigProperty("PLAIN", "user1", "user1-secret").value(); + String module2 = TestJaasConfig.jaasConfigProperty("PLAIN", "user2", "user2-secret").value(); + saslClientConfigs.put(SaslConfigs.SASL_JAAS_CONFIG, new Password(module1 + " " + module2)); + try { + createClientConnection(securityProtocol, "1"); + fail("Connection created with multiple login modules in sasl.jaas.config"); + } catch (IllegalArgumentException e) { + // Expected + } + } + + /** + * Tests dynamic JAAS configuration property for SASL server. Invalid server credentials + * are set in the static JVM-wide configuration instance to ensure that the dynamic + * property override is used during authentication. + */ + @Test + public void testServerDynamicJaasConfiguration() throws Exception { + SecurityProtocol securityProtocol = SecurityProtocol.SASL_SSL; + saslClientConfigs.put(SaslConfigs.SASL_MECHANISM, "PLAIN"); + saslServerConfigs.put(BrokerSecurityConfigs.SASL_ENABLED_MECHANISMS_CONFIG, Arrays.asList("PLAIN")); + Map serverOptions = new HashMap<>(); + serverOptions.put("user_user1", "user1-secret"); + serverOptions.put("user_user2", "user2-secret"); + saslServerConfigs.put("listener.name.sasl_ssl.plain." 
+ SaslConfigs.SASL_JAAS_CONFIG, + TestJaasConfig.jaasConfigProperty("PLAIN", serverOptions)); + TestJaasConfig staticJaasConfig = new TestJaasConfig(); + staticJaasConfig.createOrUpdateEntry(TestJaasConfig.LOGIN_CONTEXT_SERVER, PlainLoginModule.class.getName(), + Collections.emptyMap()); + staticJaasConfig.setClientOptions("PLAIN", "user1", "user1-secret"); + Configuration.setConfiguration(staticJaasConfig); + server = createEchoServer(securityProtocol); + + // Check that 'user1' can connect with static Jaas config + createAndCheckClientConnection(securityProtocol, "1"); + + // Check that user 'user2' can also connect with a Jaas config override + saslClientConfigs.put(SaslConfigs.SASL_JAAS_CONFIG, + TestJaasConfig.jaasConfigProperty("PLAIN", "user2", "user2-secret")); + createAndCheckClientConnection(securityProtocol, "2"); + } + + @Test + public void testJaasConfigurationForListener() throws Exception { + SecurityProtocol securityProtocol = SecurityProtocol.SASL_PLAINTEXT; + saslClientConfigs.put(SaslConfigs.SASL_MECHANISM, "PLAIN"); + saslServerConfigs.put(BrokerSecurityConfigs.SASL_ENABLED_MECHANISMS_CONFIG, Arrays.asList("PLAIN")); + + TestJaasConfig staticJaasConfig = new TestJaasConfig(); + + Map globalServerOptions = new HashMap<>(); + globalServerOptions.put("user_global1", "gsecret1"); + globalServerOptions.put("user_global2", "gsecret2"); + staticJaasConfig.createOrUpdateEntry(TestJaasConfig.LOGIN_CONTEXT_SERVER, PlainLoginModule.class.getName(), + globalServerOptions); + + Map clientListenerServerOptions = new HashMap<>(); + clientListenerServerOptions.put("user_client1", "csecret1"); + clientListenerServerOptions.put("user_client2", "csecret2"); + String clientJaasEntryName = "client." + TestJaasConfig.LOGIN_CONTEXT_SERVER; + staticJaasConfig.createOrUpdateEntry(clientJaasEntryName, PlainLoginModule.class.getName(), clientListenerServerOptions); + Configuration.setConfiguration(staticJaasConfig); + + // Listener-specific credentials + server = createEchoServer(new ListenerName("client"), securityProtocol); + saslClientConfigs.put(SaslConfigs.SASL_JAAS_CONFIG, + TestJaasConfig.jaasConfigProperty("PLAIN", "client1", "csecret1")); + createAndCheckClientConnection(securityProtocol, "1"); + saslClientConfigs.put(SaslConfigs.SASL_JAAS_CONFIG, + TestJaasConfig.jaasConfigProperty("PLAIN", "global1", "gsecret1")); + createAndCheckClientConnectionFailure(securityProtocol, "2"); + server.close(); + + // Global credentials as there is no listener-specific JAAS entry + server = createEchoServer(new ListenerName("other"), securityProtocol); + saslClientConfigs.put(SaslConfigs.SASL_JAAS_CONFIG, + TestJaasConfig.jaasConfigProperty("PLAIN", "global1", "gsecret1")); + createAndCheckClientConnection(securityProtocol, "3"); + saslClientConfigs.put(SaslConfigs.SASL_JAAS_CONFIG, + TestJaasConfig.jaasConfigProperty("PLAIN", "client1", "csecret1")); + createAndCheckClientConnectionFailure(securityProtocol, "4"); + } + + /** + * Tests good path SASL/PLAIN authentication over PLAINTEXT with old version of server + * that does not support SASL_AUTHENTICATE headers and new version of client. + */ + @Test + public void oldSaslPlainPlaintextServerWithoutSaslAuthenticateHeader() throws Exception { + verifySaslAuthenticateHeaderInterop(false, true, SecurityProtocol.SASL_PLAINTEXT, "PLAIN"); + } + + /** + * Tests good path SASL/PLAIN authentication over PLAINTEXT with old version of client + * that does not support SASL_AUTHENTICATE headers and new version of server. 
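+     * Without SASL_AUTHENTICATE headers the SASL tokens are exchanged as raw packets after the
+     * handshake, matching the wire format used before SaslAuthenticateRequest was introduced.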
+ */ + @Test + public void oldSaslPlainPlaintextClientWithoutSaslAuthenticateHeader() throws Exception { + verifySaslAuthenticateHeaderInterop(true, false, SecurityProtocol.SASL_PLAINTEXT, "PLAIN"); + } + + /** + * Tests good path SASL/SCRAM authentication over PLAINTEXT with old version of server + * that does not support SASL_AUTHENTICATE headers and new version of client. + */ + @Test + public void oldSaslScramPlaintextServerWithoutSaslAuthenticateHeader() throws Exception { + verifySaslAuthenticateHeaderInterop(false, true, SecurityProtocol.SASL_PLAINTEXT, "SCRAM-SHA-256"); + } + + /** + * Tests good path SASL/SCRAM authentication over PLAINTEXT with old version of client + * that does not support SASL_AUTHENTICATE headers and new version of server. + */ + @Test + public void oldSaslScramPlaintextClientWithoutSaslAuthenticateHeader() throws Exception { + verifySaslAuthenticateHeaderInterop(true, false, SecurityProtocol.SASL_PLAINTEXT, "SCRAM-SHA-256"); + } + + /** + * Tests good path SASL/PLAIN authentication over SSL with old version of server + * that does not support SASL_AUTHENTICATE headers and new version of client. + */ + @Test + public void oldSaslPlainSslServerWithoutSaslAuthenticateHeader() throws Exception { + verifySaslAuthenticateHeaderInterop(false, true, SecurityProtocol.SASL_SSL, "PLAIN"); + } + + /** + * Tests good path SASL/PLAIN authentication over SSL with old version of client + * that does not support SASL_AUTHENTICATE headers and new version of server. + */ + @Test + public void oldSaslPlainSslClientWithoutSaslAuthenticateHeader() throws Exception { + verifySaslAuthenticateHeaderInterop(true, false, SecurityProtocol.SASL_SSL, "PLAIN"); + } + + /** + * Tests good path SASL/SCRAM authentication over SSL with old version of server + * that does not support SASL_AUTHENTICATE headers and new version of client. + */ + @Test + public void oldSaslScramSslServerWithoutSaslAuthenticateHeader() throws Exception { + verifySaslAuthenticateHeaderInterop(false, true, SecurityProtocol.SASL_SSL, "SCRAM-SHA-512"); + } + + /** + * Tests good path SASL/SCRAM authentication over SSL with old version of client + * that does not support SASL_AUTHENTICATE headers and new version of server. + */ + @Test + public void oldSaslScramSslClientWithoutSaslAuthenticateHeader() throws Exception { + verifySaslAuthenticateHeaderInterop(true, false, SecurityProtocol.SASL_SSL, "SCRAM-SHA-512"); + } + + /** + * Tests SASL/PLAIN authentication failure over PLAINTEXT with old version of server + * that does not support SASL_AUTHENTICATE headers and new version of client. + */ + @Test + public void oldSaslPlainPlaintextServerWithoutSaslAuthenticateHeaderFailure() throws Exception { + verifySaslAuthenticateHeaderInteropWithFailure(false, true, SecurityProtocol.SASL_PLAINTEXT, "PLAIN"); + } + + /** + * Tests SASL/PLAIN authentication failure over PLAINTEXT with old version of client + * that does not support SASL_AUTHENTICATE headers and new version of server. + */ + @Test + public void oldSaslPlainPlaintextClientWithoutSaslAuthenticateHeaderFailure() throws Exception { + verifySaslAuthenticateHeaderInteropWithFailure(true, false, SecurityProtocol.SASL_PLAINTEXT, "PLAIN"); + } + + /** + * Tests SASL/SCRAM authentication failure over PLAINTEXT with old version of server + * that does not support SASL_AUTHENTICATE headers and new version of client. 
+ */ + @Test + public void oldSaslScramPlaintextServerWithoutSaslAuthenticateHeaderFailure() throws Exception { + verifySaslAuthenticateHeaderInteropWithFailure(false, true, SecurityProtocol.SASL_PLAINTEXT, "SCRAM-SHA-256"); + } + + /** + * Tests SASL/SCRAM authentication failure over PLAINTEXT with old version of client + * that does not support SASL_AUTHENTICATE headers and new version of server. + */ + @Test + public void oldSaslScramPlaintextClientWithoutSaslAuthenticateHeaderFailure() throws Exception { + verifySaslAuthenticateHeaderInteropWithFailure(true, false, SecurityProtocol.SASL_PLAINTEXT, "SCRAM-SHA-256"); + } + + /** + * Tests SASL/PLAIN authentication failure over SSL with old version of server + * that does not support SASL_AUTHENTICATE headers and new version of client. + */ + @Test + public void oldSaslPlainSslServerWithoutSaslAuthenticateHeaderFailure() throws Exception { + verifySaslAuthenticateHeaderInteropWithFailure(false, true, SecurityProtocol.SASL_SSL, "PLAIN"); + } + + /** + * Tests SASL/PLAIN authentication failure over SSL with old version of client + * that does not support SASL_AUTHENTICATE headers and new version of server. + */ + @Test + public void oldSaslPlainSslClientWithoutSaslAuthenticateHeaderFailure() throws Exception { + verifySaslAuthenticateHeaderInteropWithFailure(true, false, SecurityProtocol.SASL_SSL, "PLAIN"); + } + + /** + * Tests SASL/SCRAM authentication failure over SSL with old version of server + * that does not support SASL_AUTHENTICATE headers and new version of client. + */ + @Test + public void oldSaslScramSslServerWithoutSaslAuthenticateHeaderFailure() throws Exception { + verifySaslAuthenticateHeaderInteropWithFailure(false, true, SecurityProtocol.SASL_SSL, "SCRAM-SHA-512"); + } + + /** + * Tests SASL/SCRAM authentication failure over SSL with old version of client + * that does not support SASL_AUTHENTICATE headers and new version of server. + */ + @Test + public void oldSaslScramSslClientWithoutSaslAuthenticateHeaderFailure() throws Exception { + verifySaslAuthenticateHeaderInteropWithFailure(true, false, SecurityProtocol.SASL_SSL, "SCRAM-SHA-512"); + } + + /** + * Tests OAUTHBEARER client and server channels. + */ + @Test + public void testValidSaslOauthBearerMechanism() throws Exception { + String node = "0"; + SecurityProtocol securityProtocol = SecurityProtocol.SASL_SSL; + configureMechanisms("OAUTHBEARER", Arrays.asList("OAUTHBEARER")); + server = createEchoServer(securityProtocol); + createAndCheckClientConnection(securityProtocol, node); + } + + /** + * Re-authentication must fail if principal changes + */ + @Test + public void testCannotReauthenticateWithDifferentPrincipal() throws Exception { + String node = "0"; + SecurityProtocol securityProtocol = SecurityProtocol.SASL_SSL; + saslClientConfigs.put(SaslConfigs.SASL_LOGIN_CALLBACK_HANDLER_CLASS, + AlternateLoginCallbackHandler.class.getName()); + configureMechanisms(OAuthBearerLoginModule.OAUTHBEARER_MECHANISM, + Arrays.asList(OAuthBearerLoginModule.OAUTHBEARER_MECHANISM)); + server = createEchoServer(securityProtocol); + // initial authentication must succeed + createClientConnection(securityProtocol, node); + checkClientConnection(node); + // ensure metrics are as expected before trying to re-authenticate + server.verifyAuthenticationMetrics(1, 0); + server.verifyReauthenticationMetrics(0, 0); + /* + * Now re-authenticate with a different principal and ensure it fails. 
We first + * have to sleep long enough for the background refresh thread to replace the + * original token with a new one. + */ + delay(1000L); + assertThrows(AssertionFailedError.class, () -> checkClientConnection(node)); + server.verifyReauthenticationMetrics(0, 1); + } + + @Test + public void testCorrelationId() { + SaslClientAuthenticator authenticator = new SaslClientAuthenticator( + Collections.emptyMap(), + null, + "node", + null, + null, + null, + "plain", + false, + null, + null, + new LogContext() + ) { + @Override + SaslClient createSaslClient() { + return null; + } + }; + int count = (SaslClientAuthenticator.MAX_RESERVED_CORRELATION_ID - SaslClientAuthenticator.MIN_RESERVED_CORRELATION_ID) * 2; + Set ids = IntStream.range(0, count) + .mapToObj(i -> authenticator.nextCorrelationId()) + .collect(Collectors.toSet()); + assertEquals(SaslClientAuthenticator.MAX_RESERVED_CORRELATION_ID - SaslClientAuthenticator.MIN_RESERVED_CORRELATION_ID + 1, ids.size()); + ids.forEach(id -> { + assertTrue(id >= SaslClientAuthenticator.MIN_RESERVED_CORRELATION_ID); + assertTrue(SaslClientAuthenticator.isReserved(id)); + }); + } + + @Test + public void testConvertListOffsetResponseToSaslHandshakeResponse() { + ListOffsetsResponseData data = new ListOffsetsResponseData() + .setThrottleTimeMs(0) + .setTopics(Collections.singletonList(new ListOffsetsTopicResponse() + .setName("topic") + .setPartitions(Collections.singletonList(new ListOffsetsPartitionResponse() + .setErrorCode(Errors.NONE.code()) + .setLeaderEpoch(ListOffsetsResponse.UNKNOWN_EPOCH) + .setPartitionIndex(0) + .setOffset(0) + .setTimestamp(0))))); + ListOffsetsResponse response = new ListOffsetsResponse(data); + ByteBuffer buffer = RequestTestUtils.serializeResponseWithHeader(response, LIST_OFFSETS.latestVersion(), 0); + final RequestHeader header0 = new RequestHeader(LIST_OFFSETS, LIST_OFFSETS.latestVersion(), "id", SaslClientAuthenticator.MIN_RESERVED_CORRELATION_ID); + assertThrows(SchemaException.class, () -> NetworkClient.parseResponse(buffer.duplicate(), header0)); + final RequestHeader header1 = new RequestHeader(LIST_OFFSETS, LIST_OFFSETS.latestVersion(), "id", 1); + assertThrows(IllegalStateException.class, () -> NetworkClient.parseResponse(buffer.duplicate(), header1)); + } + + /** + * Re-authentication must fail if mechanism changes + */ + @Test + public void testCannotReauthenticateWithDifferentMechanism() throws Exception { + String node = "0"; + SecurityProtocol securityProtocol = SecurityProtocol.SASL_SSL; + configureMechanisms("DIGEST-MD5", Arrays.asList("DIGEST-MD5", "PLAIN")); + configureDigestMd5ServerCallback(securityProtocol); + server = createEchoServer(securityProtocol); + + String saslMechanism = (String) saslClientConfigs.get(SaslConfigs.SASL_MECHANISM); + Map configs = new TestSecurityConfig(saslClientConfigs).values(); + this.channelBuilder = new AlternateSaslChannelBuilder(Mode.CLIENT, + Collections.singletonMap(saslMechanism, JaasContext.loadClientContext(configs)), securityProtocol, null, + false, saslMechanism, true, credentialCache, null, time); + this.channelBuilder.configure(configs); + // initial authentication must succeed + this.selector = NetworkTestUtils.createSelector(channelBuilder, time); + InetSocketAddress addr = new InetSocketAddress("localhost", server.port()); + selector.connect(node, addr, BUFFER_SIZE, BUFFER_SIZE); + checkClientConnection(node); + // ensure metrics are as expected before trying to re-authenticate + server.verifyAuthenticationMetrics(1, 0); + 
server.verifyReauthenticationMetrics(0, 0); + /* + * Now re-authenticate with a different mechanism and ensure it fails. We have + * to sleep long enough so that the next write will trigger a re-authentication. + */ + delay((long) (CONNECTIONS_MAX_REAUTH_MS_VALUE * 1.1)); + assertThrows(AssertionFailedError.class, () -> checkClientConnection(node)); + server.verifyAuthenticationMetrics(1, 0); + server.verifyReauthenticationMetrics(0, 1); + } + + /** + * Second re-authentication must fail if it is sooner than one second after the first + */ + @Test + public void testCannotReauthenticateAgainFasterThanOneSecond() throws Exception { + String node = "0"; + time = new MockTime(); + SecurityProtocol securityProtocol = SecurityProtocol.SASL_SSL; + configureMechanisms(OAuthBearerLoginModule.OAUTHBEARER_MECHANISM, + Arrays.asList(OAuthBearerLoginModule.OAUTHBEARER_MECHANISM)); + server = createEchoServer(securityProtocol); + try { + createClientConnection(securityProtocol, node); + checkClientConnection(node); + server.verifyAuthenticationMetrics(1, 0); + server.verifyReauthenticationMetrics(0, 0); + /* + * Now sleep long enough so that the next write will cause re-authentication, + * which we expect to succeed. + */ + time.sleep((long) (CONNECTIONS_MAX_REAUTH_MS_VALUE * 1.1)); + checkClientConnection(node); + server.verifyAuthenticationMetrics(1, 0); + server.verifyReauthenticationMetrics(1, 0); + /* + * Now sleep long enough so that the next write will cause re-authentication, + * but this time we expect re-authentication to not occur since it has been too + * soon. The checkClientConnection() call should return an error saying it + * expected the one byte-plus-node response but got the SaslHandshakeRequest + * instead + */ + time.sleep((long) (CONNECTIONS_MAX_REAUTH_MS_VALUE * 1.1)); + AssertionFailedError exception = assertThrows(AssertionFailedError.class, + () -> NetworkTestUtils.checkClientConnection(selector, node, 1, 1)); + String expectedResponseTextRegex = "\\w-" + node; + String receivedResponseTextRegex = ".*" + OAuthBearerLoginModule.OAUTHBEARER_MECHANISM; + assertTrue(exception.getMessage().matches( + ".*<" + expectedResponseTextRegex + ">.*<" + receivedResponseTextRegex + ".*?>"), + "Should have received the SaslHandshakeRequest bytes back since we re-authenticated too quickly, " + + "but instead we got our generated message echoed back, implying re-auth succeeded when it should not have: " + + exception); + server.verifyReauthenticationMetrics(1, 0); // unchanged + } finally { + selector.close(); + selector = null; + } + } + + /** + * Tests good path SASL/PLAIN client and server channels using SSL transport layer. + * Repeatedly tests successful re-authentication over several seconds. + */ + @Test + public void testRepeatedValidSaslPlainOverSsl() throws Exception { + String node = "0"; + SecurityProtocol securityProtocol = SecurityProtocol.SASL_SSL; + configureMechanisms("PLAIN", Arrays.asList("PLAIN")); + /* + * Make sure 85% of this value is at least 1 second otherwise it is possible for + * the client to start re-authenticating but the server does not start due to + * the 1-second minimum. If this happens the SASL HANDSHAKE request that was + * injected to start re-authentication will be echoed back to the client instead + * of the data that the client explicitly sent, and then the client will not + * recognize that data and will throw an assertion error. 
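+         * The value configured below, 1.1 * 1000 / 0.85 (roughly 1294 ms), keeps 85% of the window at
+         * about 1.1 seconds, comfortably above the 1-second minimum.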
+ */ + saslServerConfigs.put(BrokerSecurityConfigs.CONNECTIONS_MAX_REAUTH_MS, + Double.valueOf(1.1 * 1000L / 0.85).longValue()); + + server = createEchoServer(securityProtocol); + createClientConnection(securityProtocol, node); + checkClientConnection(node); + server.verifyAuthenticationMetrics(1, 0); + server.verifyReauthenticationMetrics(0, 0); + double successfulReauthentications = 0; + int desiredNumReauthentications = 5; + long startMs = Time.SYSTEM.milliseconds(); + long timeoutMs = startMs + 1000 * 15; // stop after 15 seconds + while (successfulReauthentications < desiredNumReauthentications + && Time.SYSTEM.milliseconds() < timeoutMs) { + checkClientConnection(node); + successfulReauthentications = server.metricValue("successful-reauthentication-total"); + } + server.verifyReauthenticationMetrics(desiredNumReauthentications, 0); + } + + /** + * Tests OAUTHBEARER client channels without tokens for the server. + */ + @Test + public void testValidSaslOauthBearerMechanismWithoutServerTokens() throws Exception { + String node = "0"; + SecurityProtocol securityProtocol = SecurityProtocol.SASL_SSL; + saslClientConfigs.put(SaslConfigs.SASL_MECHANISM, "OAUTHBEARER"); + saslServerConfigs.put(BrokerSecurityConfigs.SASL_ENABLED_MECHANISMS_CONFIG, Arrays.asList("OAUTHBEARER")); + saslClientConfigs.put(SaslConfigs.SASL_JAAS_CONFIG, + TestJaasConfig.jaasConfigProperty("OAUTHBEARER", Collections.singletonMap("unsecuredLoginStringClaim_sub", TestJaasConfig.USERNAME))); + saslServerConfigs.put("listener.name.sasl_ssl.oauthbearer." + SaslConfigs.SASL_JAAS_CONFIG, + TestJaasConfig.jaasConfigProperty("OAUTHBEARER", Collections.emptyMap())); + + // Server without a token should start up successfully and authenticate clients. + server = createEchoServer(securityProtocol); + createAndCheckClientConnection(securityProtocol, node); + + // Client without a token should fail to connect + saslClientConfigs.put(SaslConfigs.SASL_JAAS_CONFIG, + TestJaasConfig.jaasConfigProperty("OAUTHBEARER", Collections.emptyMap())); + createAndCheckClientConnectionFailure(securityProtocol, node); + + // Server with extensions, but without a token should fail to start up since it could indicate a configuration error + saslServerConfigs.put("listener.name.sasl_ssl.oauthbearer." + SaslConfigs.SASL_JAAS_CONFIG, + TestJaasConfig.jaasConfigProperty("OAUTHBEARER", Collections.singletonMap("unsecuredLoginExtension_test", "something"))); + try { + createEchoServer(securityProtocol); + fail("Server created with invalid login config containing extensions without a token"); + } catch (Throwable e) { + assertTrue(e.getCause() instanceof LoginException, "Unexpected exception " + Utils.stackTrace(e)); + } + } + + /** + * Tests OAUTHBEARER fails the connection when the client presents a token with + * insufficient scope . 
+ */ + @Test + public void testInsufficientScopeSaslOauthBearerMechanism() throws Exception { + SecurityProtocol securityProtocol = SecurityProtocol.SASL_SSL; + TestJaasConfig jaasConfig = configureMechanisms("OAUTHBEARER", Arrays.asList("OAUTHBEARER")); + // now update the server side to require a scope the client does not provide + Map serverJaasConfigOptionsMap = TestJaasConfig.defaultServerOptions("OAUTHBEARER"); + serverJaasConfigOptionsMap.put("unsecuredValidatorRequiredScope", "LOGIN_TO_KAFKA"); // causes the failure + jaasConfig.createOrUpdateEntry("KafkaServer", + "org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginModule", serverJaasConfigOptionsMap); + server = createEchoServer(securityProtocol); + createAndCheckClientAuthenticationFailure(securityProtocol, + "node-" + OAuthBearerLoginModule.OAUTHBEARER_MECHANISM, OAuthBearerLoginModule.OAUTHBEARER_MECHANISM, + "{\"status\":\"insufficient_scope\", \"scope\":\"[LOGIN_TO_KAFKA]\"}"); + } + + @Test + public void testSslClientAuthDisabledForSaslSslListener() throws Exception { + verifySslClientAuthForSaslSslListener(true, SslClientAuth.NONE); + } + + @Test + public void testSslClientAuthRequestedForSaslSslListener() throws Exception { + verifySslClientAuthForSaslSslListener(true, SslClientAuth.REQUESTED); + } + + @Test + public void testSslClientAuthRequiredForSaslSslListener() throws Exception { + verifySslClientAuthForSaslSslListener(true, SslClientAuth.REQUIRED); + } + + @Test + public void testSslClientAuthRequestedOverriddenForSaslSslListener() throws Exception { + verifySslClientAuthForSaslSslListener(false, SslClientAuth.REQUESTED); + } + + @Test + public void testSslClientAuthRequiredOverriddenForSaslSslListener() throws Exception { + verifySslClientAuthForSaslSslListener(false, SslClientAuth.REQUIRED); + } + + private void verifySslClientAuthForSaslSslListener(boolean useListenerPrefix, + SslClientAuth configuredClientAuth) throws Exception { + + SecurityProtocol securityProtocol = SecurityProtocol.SASL_SSL; + configureMechanisms("PLAIN", Collections.singletonList("PLAIN")); + String listenerPrefix = useListenerPrefix ? ListenerName.forSecurityProtocol(securityProtocol).configPrefix() : ""; + saslServerConfigs.put(listenerPrefix + BrokerSecurityConfigs.SSL_CLIENT_AUTH_CONFIG, configuredClientAuth.name()); + saslServerConfigs.put(BrokerSecurityConfigs.PRINCIPAL_BUILDER_CLASS_CONFIG, SaslSslPrincipalBuilder.class.getName()); + server = createEchoServer(securityProtocol); + + SslClientAuth expectedClientAuth = useListenerPrefix ? configuredClientAuth : SslClientAuth.NONE; + String certDn = "O=A client,CN=localhost"; + KafkaPrincipal principalWithMutualTls = SaslSslPrincipalBuilder.saslSslPrincipal(TestJaasConfig.USERNAME, certDn); + KafkaPrincipal principalWithOneWayTls = SaslSslPrincipalBuilder.saslSslPrincipal(TestJaasConfig.USERNAME, "ANONYMOUS"); + + // Client configured with valid key store + createAndCheckClientConnectionAndPrincipal(securityProtocol, "0", + expectedClientAuth == SslClientAuth.NONE ? 
principalWithOneWayTls : principalWithMutualTls); + + // Client does not configure key store + removeClientSslKeystore(); + if (expectedClientAuth != SslClientAuth.REQUIRED) { + createAndCheckClientConnectionAndPrincipal(securityProtocol, "1", principalWithOneWayTls); + } else { + createAndCheckSslAuthenticationFailure(securityProtocol, "1"); + } + + // Client configures untrusted key store + CertStores newStore = new CertStores(false, "localhost"); + newStore.keyStoreProps().forEach((k, v) -> saslClientConfigs.put(k, v)); + if (expectedClientAuth == SslClientAuth.NONE) { + createAndCheckClientConnectionAndPrincipal(securityProtocol, "2", principalWithOneWayTls); + } else { + createAndCheckSslAuthenticationFailure(securityProtocol, "2"); + } + } + + private void removeClientSslKeystore() { + saslClientConfigs.remove(SslConfigs.SSL_KEYSTORE_LOCATION_CONFIG); + saslClientConfigs.remove(SslConfigs.SSL_KEYSTORE_PASSWORD_CONFIG); + saslClientConfigs.remove(SslConfigs.SSL_KEY_PASSWORD_CONFIG); + } + + private void verifySaslAuthenticateHeaderInterop(boolean enableHeaderOnServer, boolean enableHeaderOnClient, + SecurityProtocol securityProtocol, String saslMechanism) throws Exception { + configureMechanisms(saslMechanism, Arrays.asList(saslMechanism)); + createServer(securityProtocol, saslMechanism, enableHeaderOnServer); + + String node = "0"; + createClientConnection(securityProtocol, saslMechanism, node, enableHeaderOnClient); + NetworkTestUtils.checkClientConnection(selector, "0", 100, 10); + } + + private void verifySaslAuthenticateHeaderInteropWithFailure(boolean enableHeaderOnServer, boolean enableHeaderOnClient, + SecurityProtocol securityProtocol, String saslMechanism) throws Exception { + TestJaasConfig jaasConfig = configureMechanisms(saslMechanism, Arrays.asList(saslMechanism)); + jaasConfig.setClientOptions(saslMechanism, TestJaasConfig.USERNAME, "invalidpassword"); + createServer(securityProtocol, saslMechanism, enableHeaderOnServer); + + String node = "0"; + createClientConnection(securityProtocol, saslMechanism, node, enableHeaderOnClient); + // Without SASL_AUTHENTICATE headers, disconnect state is ChannelState.AUTHENTICATE which is + // a hint that channel was closed during authentication, unlike ChannelState.AUTHENTICATE_FAILED + // which is an actual authentication failure reported by the broker. 
+ NetworkTestUtils.waitForChannelClose(selector, node, ChannelState.State.AUTHENTICATE); + } + + private void createServer(SecurityProtocol securityProtocol, String saslMechanism, + boolean enableSaslAuthenticateHeader) throws Exception { + if (enableSaslAuthenticateHeader) + server = createEchoServer(securityProtocol); + else + server = startServerWithoutSaslAuthenticateHeader(securityProtocol, saslMechanism); + updateScramCredentialCache(TestJaasConfig.USERNAME, TestJaasConfig.PASSWORD); + } + + private void createClientConnection(SecurityProtocol securityProtocol, String saslMechanism, String node, + boolean enableSaslAuthenticateHeader) throws Exception { + if (enableSaslAuthenticateHeader) + createClientConnection(securityProtocol, node); + else + createClientConnectionWithoutSaslAuthenticateHeader(securityProtocol, saslMechanism, node); + } + + private NioEchoServer startServerApiVersionsUnsupportedByClient(final SecurityProtocol securityProtocol, String saslMechanism) throws Exception { + final ListenerName listenerName = ListenerName.forSecurityProtocol(securityProtocol); + final Map configs = Collections.emptyMap(); + final JaasContext jaasContext = JaasContext.loadServerContext(listenerName, saslMechanism, configs); + final Map jaasContexts = Collections.singletonMap(saslMechanism, jaasContext); + + boolean isScram = ScramMechanism.isScram(saslMechanism); + if (isScram) + ScramCredentialUtils.createCache(credentialCache, Arrays.asList(saslMechanism)); + + Supplier apiVersionSupplier = () -> { + ApiVersionCollection versionCollection = new ApiVersionCollection(2); + versionCollection.add(new ApiVersion().setApiKey(ApiKeys.SASL_HANDSHAKE.id).setMinVersion((short) 0).setMaxVersion((short) 100)); + versionCollection.add(new ApiVersion().setApiKey(ApiKeys.SASL_AUTHENTICATE.id).setMinVersion((short) 0).setMaxVersion((short) 100)); + return new ApiVersionsResponse(new ApiVersionsResponseData().setApiKeys(versionCollection)); + }; + + SaslChannelBuilder serverChannelBuilder = new SaslChannelBuilder(Mode.SERVER, jaasContexts, + securityProtocol, listenerName, false, saslMechanism, true, + credentialCache, null, null, time, new LogContext(), apiVersionSupplier); + + serverChannelBuilder.configure(saslServerConfigs); + server = new NioEchoServer(listenerName, securityProtocol, new TestSecurityConfig(saslServerConfigs), + "localhost", serverChannelBuilder, credentialCache, time); + server.start(); + return server; + } + + private NioEchoServer startServerWithoutSaslAuthenticateHeader(final SecurityProtocol securityProtocol, String saslMechanism) + throws Exception { + final ListenerName listenerName = ListenerName.forSecurityProtocol(securityProtocol); + final Map configs = Collections.emptyMap(); + final JaasContext jaasContext = JaasContext.loadServerContext(listenerName, saslMechanism, configs); + final Map jaasContexts = Collections.singletonMap(saslMechanism, jaasContext); + + boolean isScram = ScramMechanism.isScram(saslMechanism); + if (isScram) + ScramCredentialUtils.createCache(credentialCache, Arrays.asList(saslMechanism)); + + Supplier apiVersionSupplier = () -> { + ApiVersionsResponse defaultApiVersionResponse = ApiVersionsResponse.defaultApiVersionsResponse( + ApiMessageType.ListenerType.ZK_BROKER); + ApiVersionCollection apiVersions = new ApiVersionCollection(); + for (ApiVersion apiVersion : defaultApiVersionResponse.data().apiKeys()) { + if (apiVersion.apiKey() != ApiKeys.SASL_AUTHENTICATE.id) { + // ApiVersion can NOT be reused in second ApiVersionCollection + // due to the 
internal pointers it contains. + apiVersions.add(apiVersion.duplicate()); + } + + } + ApiVersionsResponseData data = new ApiVersionsResponseData() + .setErrorCode(Errors.NONE.code()) + .setThrottleTimeMs(0) + .setApiKeys(apiVersions); + return new ApiVersionsResponse(data); + }; + + SaslChannelBuilder serverChannelBuilder = new SaslChannelBuilder(Mode.SERVER, jaasContexts, + securityProtocol, listenerName, false, saslMechanism, true, + credentialCache, null, null, time, new LogContext(), apiVersionSupplier) { + @Override + protected SaslServerAuthenticator buildServerAuthenticator(Map configs, + Map callbackHandlers, + String id, + TransportLayer transportLayer, + Map subjects, + Map connectionsMaxReauthMsByMechanism, + ChannelMetadataRegistry metadataRegistry) { + return new SaslServerAuthenticator(configs, callbackHandlers, id, subjects, null, listenerName, + securityProtocol, transportLayer, connectionsMaxReauthMsByMechanism, metadataRegistry, time, apiVersionSupplier) { + @Override + protected void enableKafkaSaslAuthenticateHeaders(boolean flag) { + // Don't enable Kafka SASL_AUTHENTICATE headers + } + }; + } + }; + serverChannelBuilder.configure(saslServerConfigs); + server = new NioEchoServer(listenerName, securityProtocol, new TestSecurityConfig(saslServerConfigs), + "localhost", serverChannelBuilder, credentialCache, time); + server.start(); + return server; + } + + private void createClientConnectionWithoutSaslAuthenticateHeader(final SecurityProtocol securityProtocol, + final String saslMechanism, String node) throws Exception { + + final ListenerName listenerName = ListenerName.forSecurityProtocol(securityProtocol); + final Map configs = Collections.emptyMap(); + final JaasContext jaasContext = JaasContext.loadClientContext(configs); + final Map jaasContexts = Collections.singletonMap(saslMechanism, jaasContext); + + SaslChannelBuilder clientChannelBuilder = new SaslChannelBuilder(Mode.CLIENT, jaasContexts, + securityProtocol, listenerName, false, saslMechanism, true, + null, null, null, time, new LogContext(), null) { + + @Override + protected SaslClientAuthenticator buildClientAuthenticator(Map configs, + AuthenticateCallbackHandler callbackHandler, + String id, + String serverHost, + String servicePrincipal, + TransportLayer transportLayer, + Subject subject) { + + return new SaslClientAuthenticator(configs, callbackHandler, id, subject, + servicePrincipal, serverHost, saslMechanism, true, + transportLayer, time, new LogContext()) { + @Override + protected SaslHandshakeRequest createSaslHandshakeRequest(short version) { + return buildSaslHandshakeRequest(saslMechanism, (short) 0); + } + @Override + protected void setSaslAuthenticateAndHandshakeVersions(ApiVersionsResponse apiVersionsResponse) { + // Don't set version so that headers are disabled + } + }; + } + }; + clientChannelBuilder.configure(saslClientConfigs); + this.selector = NetworkTestUtils.createSelector(clientChannelBuilder, time); + InetSocketAddress addr = new InetSocketAddress("localhost", server.port()); + selector.connect(node, addr, BUFFER_SIZE, BUFFER_SIZE); + } + + /** + * Tests that Kafka ApiVersionsRequests are handled by the SASL server authenticator + * prior to SASL handshake flow and that subsequent authentication succeeds + * when transport layer is PLAINTEXT/SSL. This test uses a non-SASL client that simulates + * SASL authentication after ApiVersionsRequest. + *

+ * <p>
+ * Test sequence (using securityProtocol=PLAINTEXT as an example):
+ * <ol>
+ * <li>Starts a SASL_PLAINTEXT test server that simply echoes back client requests after authentication.</li>
+ * <li>A (non-SASL) PLAINTEXT test client connects to the SASL server port. Client is now unauthenticated.</li>
+ * <li>The unauthenticated non-SASL client sends an ApiVersionsRequest and validates the response.
+ *     A valid response indicates that {@link SaslServerAuthenticator} of the test server responded to
+ *     the ApiVersionsRequest even though the client is not yet authenticated.</li>
+ * <li>The unauthenticated non-SASL client sends a SaslHandshakeRequest and validates the response. A valid response
+ *     indicates that {@link SaslServerAuthenticator} of the test server responded to the SaslHandshakeRequest
+ *     after processing ApiVersionsRequest.</li>
+ * <li>The unauthenticated non-SASL client sends the SASL/PLAIN packet containing username/password to authenticate
+ *     itself. The client is now authenticated by the server. At this point this test client is at the
+ *     same state as a regular SASL_PLAINTEXT client that is ready.</li>
+ * <li>The authenticated client sends random data to the server and checks that the data is echoed
+ *     back by the test server (ie, not Kafka request-response) to ensure that the client now
+ *     behaves exactly as a regular SASL_PLAINTEXT client that has completed authentication.</li>
+ * </ol>
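+ * <p>
+ * As an illustrative outline only (it simply mirrors the helper calls invoked in the method below and
+ * adds no test logic of its own), the client-side steps roughly correspond to:
+ * <pre>{@code
+ * // 1. ApiVersionsRequest on the still-unauthenticated connection
+ * ApiVersionsResponse versions = sendVersionRequestReceiveResponse(node);
+ * // 2. SaslHandshakeRequest advertising the PLAIN mechanism
+ * SaslHandshakeResponse handshake = sendHandshakeRequestReceiveResponse(node, saslHandshakeVersion);
+ * // 3. SASL/PLAIN authentication bytes, then plain echo traffic over the now-authenticated channel
+ * authenticateUsingSaslPlainAndCheckConnection(node, saslHandshakeVersion > 0);
+ * }</pre>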
            + */ + private void testUnauthenticatedApiVersionsRequest(SecurityProtocol securityProtocol, short saslHandshakeVersion) throws Exception { + configureMechanisms("PLAIN", Arrays.asList("PLAIN")); + server = createEchoServer(securityProtocol); + + // Create non-SASL connection to manually authenticate after ApiVersionsRequest + String node = "1"; + SecurityProtocol clientProtocol; + switch (securityProtocol) { + case SASL_PLAINTEXT: + clientProtocol = SecurityProtocol.PLAINTEXT; + break; + case SASL_SSL: + clientProtocol = SecurityProtocol.SSL; + break; + default: + throw new IllegalArgumentException("Server protocol " + securityProtocol + " is not SASL"); + } + createClientConnection(clientProtocol, node); + NetworkTestUtils.waitForChannelReady(selector, node); + + // Send ApiVersionsRequest and check response + ApiVersionsResponse versionsResponse = sendVersionRequestReceiveResponse(node); + assertEquals(ApiKeys.SASL_HANDSHAKE.oldestVersion(), versionsResponse.apiVersion(ApiKeys.SASL_HANDSHAKE.id).minVersion()); + assertEquals(ApiKeys.SASL_HANDSHAKE.latestVersion(), versionsResponse.apiVersion(ApiKeys.SASL_HANDSHAKE.id).maxVersion()); + assertEquals(ApiKeys.SASL_AUTHENTICATE.oldestVersion(), versionsResponse.apiVersion(ApiKeys.SASL_AUTHENTICATE.id).minVersion()); + assertEquals(ApiKeys.SASL_AUTHENTICATE.latestVersion(), versionsResponse.apiVersion(ApiKeys.SASL_AUTHENTICATE.id).maxVersion()); + + // Send SaslHandshakeRequest and check response + SaslHandshakeResponse handshakeResponse = sendHandshakeRequestReceiveResponse(node, saslHandshakeVersion); + assertEquals(Collections.singletonList("PLAIN"), handshakeResponse.enabledMechanisms()); + + // Complete manual authentication and check send/receive succeed + authenticateUsingSaslPlainAndCheckConnection(node, saslHandshakeVersion > 0); + } + + private void authenticateUsingSaslPlainAndCheckConnection(String node, boolean enableSaslAuthenticateHeader) throws Exception { + // Authenticate using PLAIN username/password + String authString = "\u0000" + TestJaasConfig.USERNAME + "\u0000" + TestJaasConfig.PASSWORD; + ByteBuffer authBuf = ByteBuffer.wrap(Utils.utf8(authString)); + if (enableSaslAuthenticateHeader) { + SaslAuthenticateRequestData data = new SaslAuthenticateRequestData().setAuthBytes(authBuf.array()); + SaslAuthenticateRequest request = new SaslAuthenticateRequest.Builder(data).build(); + sendKafkaRequestReceiveResponse(node, ApiKeys.SASL_AUTHENTICATE, request); + } else { + selector.send(new NetworkSend(node, ByteBufferSend.sizePrefixed(authBuf))); + waitForResponse(); + } + + // Check send/receive on the manually authenticated connection + NetworkTestUtils.checkClientConnection(selector, node, 100, 10); + } + + private TestJaasConfig configureMechanisms(String clientMechanism, List serverMechanisms) { + saslClientConfigs.put(SaslConfigs.SASL_MECHANISM, clientMechanism); + saslServerConfigs.put(BrokerSecurityConfigs.SASL_ENABLED_MECHANISMS_CONFIG, serverMechanisms); + saslServerConfigs.put(BrokerSecurityConfigs.CONNECTIONS_MAX_REAUTH_MS, CONNECTIONS_MAX_REAUTH_MS_VALUE); + if (serverMechanisms.contains("DIGEST-MD5")) { + saslServerConfigs.put("digest-md5." 
+ BrokerSecurityConfigs.SASL_SERVER_CALLBACK_HANDLER_CLASS, + TestDigestLoginModule.DigestServerCallbackHandler.class.getName()); + } + return TestJaasConfig.createConfiguration(clientMechanism, serverMechanisms); + } + + private void configureDigestMd5ServerCallback(SecurityProtocol securityProtocol) { + String callbackPrefix = ListenerName.forSecurityProtocol(securityProtocol).saslMechanismConfigPrefix("DIGEST-MD5"); + saslServerConfigs.put(callbackPrefix + BrokerSecurityConfigs.SASL_SERVER_CALLBACK_HANDLER_CLASS, + TestDigestLoginModule.DigestServerCallbackHandler.class); + } + + private void createSelector(SecurityProtocol securityProtocol, Map clientConfigs) { + if (selector != null) { + selector.close(); + selector = null; + } + + String saslMechanism = (String) saslClientConfigs.get(SaslConfigs.SASL_MECHANISM); + this.channelBuilder = ChannelBuilders.clientChannelBuilder(securityProtocol, JaasContext.Type.CLIENT, + new TestSecurityConfig(clientConfigs), null, saslMechanism, time, + true, new LogContext()); + this.selector = NetworkTestUtils.createSelector(channelBuilder, time); + } + + private NioEchoServer createEchoServer(SecurityProtocol securityProtocol) throws Exception { + return createEchoServer(ListenerName.forSecurityProtocol(securityProtocol), securityProtocol); + } + + private NioEchoServer createEchoServer(ListenerName listenerName, SecurityProtocol securityProtocol) throws Exception { + return NetworkTestUtils.createEchoServer(listenerName, securityProtocol, + new TestSecurityConfig(saslServerConfigs), credentialCache, time); + } + + private NioEchoServer createEchoServer(ListenerName listenerName, SecurityProtocol securityProtocol, + DelegationTokenCache tokenCache) throws Exception { + return NetworkTestUtils.createEchoServer(listenerName, securityProtocol, + new TestSecurityConfig(saslServerConfigs), credentialCache, 100, time, tokenCache); + } + + private void createClientConnection(SecurityProtocol securityProtocol, String node) throws Exception { + createSelector(securityProtocol, saslClientConfigs); + InetSocketAddress addr = new InetSocketAddress("localhost", server.port()); + selector.connect(node, addr, BUFFER_SIZE, BUFFER_SIZE); + } + + private void checkClientConnection(String node) throws Exception { + NetworkTestUtils.checkClientConnection(selector, node, 100, 10); + } + + private void closeClientConnectionIfNecessary() throws Exception { + if (selector != null) { + selector.close(); + selector = null; + } + } + + /* + * Also closes the connection after creating/checking it + */ + private void createAndCheckClientConnection(SecurityProtocol securityProtocol, String node) throws Exception { + try { + createClientConnection(securityProtocol, node); + checkClientConnection(node); + } finally { + closeClientConnectionIfNecessary(); + } + } + + private void createAndCheckClientAuthenticationFailure(SecurityProtocol securityProtocol, String node, + String mechanism, String expectedErrorMessage) throws Exception { + ChannelState finalState = createAndCheckClientConnectionFailure(securityProtocol, node); + Exception exception = finalState.exception(); + assertTrue(exception instanceof SaslAuthenticationException, "Invalid exception class " + exception.getClass()); + String expectedExceptionMessage = expectedErrorMessage != null ? 
expectedErrorMessage : + "Authentication failed during authentication due to invalid credentials with SASL mechanism " + mechanism; + assertEquals(expectedExceptionMessage, exception.getMessage()); + } + + private ChannelState createAndCheckClientConnectionFailure(SecurityProtocol securityProtocol, String node) + throws Exception { + try { + createClientConnection(securityProtocol, node); + ChannelState finalState = NetworkTestUtils.waitForChannelClose(selector, node, ChannelState.State.AUTHENTICATION_FAILED); + return finalState; + } finally { + closeClientConnectionIfNecessary(); + } + } + + private void createAndCheckClientConnectionAndPrincipal(SecurityProtocol securityProtocol, + String node, + KafkaPrincipal expectedPrincipal) throws Exception { + try { + assertEquals(Collections.emptyList(), server.selector().channels()); + createClientConnection(securityProtocol, node); + NetworkTestUtils.waitForChannelReady(selector, node); + assertEquals(expectedPrincipal, server.selector().channels().get(0).principal()); + checkClientConnection(node); + } finally { + closeClientConnectionIfNecessary(); + TestUtils.waitForCondition(() -> server.selector().channels().isEmpty(), "Channel not removed after disconnection"); + } + } + + private void createAndCheckSslAuthenticationFailure(SecurityProtocol securityProtocol, String node) throws Exception { + ChannelState finalState = createAndCheckClientConnectionFailure(securityProtocol, node); + Exception exception = finalState.exception(); + assertEquals(SslAuthenticationException.class, exception.getClass()); + } + + private void checkAuthenticationAndReauthentication(SecurityProtocol securityProtocol, String node) + throws Exception { + try { + createClientConnection(securityProtocol, node); + checkClientConnection(node); + server.verifyAuthenticationMetrics(1, 0); + /* + * Now re-authenticate the connection. First we have to sleep long enough so + * that the next write will cause re-authentication, which we expect to succeed. 
+ */ + delay((long) (CONNECTIONS_MAX_REAUTH_MS_VALUE * 1.1)); + server.verifyReauthenticationMetrics(0, 0); + checkClientConnection(node); + server.verifyReauthenticationMetrics(1, 0); + } finally { + closeClientConnectionIfNecessary(); + } + } + + private AbstractResponse sendKafkaRequestReceiveResponse(String node, ApiKeys apiKey, AbstractRequest request) throws IOException { + RequestHeader header = new RequestHeader(apiKey, request.version(), "someclient", nextCorrelationId++); + NetworkSend send = new NetworkSend(node, request.toSend(header)); + selector.send(send); + ByteBuffer responseBuffer = waitForResponse(); + return NetworkClient.parseResponse(responseBuffer, header); + } + + private SaslHandshakeResponse sendHandshakeRequestReceiveResponse(String node, short version) throws Exception { + SaslHandshakeRequest handshakeRequest = buildSaslHandshakeRequest("PLAIN", version); + SaslHandshakeResponse response = (SaslHandshakeResponse) sendKafkaRequestReceiveResponse(node, ApiKeys.SASL_HANDSHAKE, handshakeRequest); + assertEquals(Errors.NONE, response.error()); + return response; + } + + private ApiVersionsResponse sendVersionRequestReceiveResponse(String node) throws Exception { + ApiVersionsRequest handshakeRequest = createApiVersionsRequestV0(); + ApiVersionsResponse response = (ApiVersionsResponse) sendKafkaRequestReceiveResponse(node, ApiKeys.API_VERSIONS, handshakeRequest); + assertEquals(Errors.NONE.code(), response.data().errorCode()); + return response; + } + + private ByteBuffer waitForResponse() throws IOException { + int waitSeconds = 10; + do { + selector.poll(1000); + } while (selector.completedReceives().isEmpty() && waitSeconds-- > 0); + assertEquals(1, selector.completedReceives().size()); + return selector.completedReceives().iterator().next().payload(); + } + + public static class TestServerCallbackHandler extends PlainServerCallbackHandler { + + static final String USERNAME = "TestServerCallbackHandler-user"; + static final String PASSWORD = "TestServerCallbackHandler-password"; + private volatile boolean configured; + + @Override + public void configure(Map configs, String mechanism, List jaasConfigEntries) { + if (configured) + throw new IllegalStateException("Server callback handler configured twice"); + configured = true; + super.configure(configs, mechanism, jaasConfigEntries); + } + + @Override + protected boolean authenticate(String username, char[] password) { + if (!configured) + throw new IllegalStateException("Server callback handler not configured"); + return USERNAME.equals(username) && new String(password).equals(PASSWORD); + } + } + + private SaslHandshakeRequest buildSaslHandshakeRequest(String mechanism, short version) { + return new SaslHandshakeRequest.Builder( + new SaslHandshakeRequestData().setMechanism(mechanism)).build(version); + } + + @SuppressWarnings("unchecked") + private void updateScramCredentialCache(String username, String password) throws NoSuchAlgorithmException { + for (String mechanism : (List) saslServerConfigs.get(BrokerSecurityConfigs.SASL_ENABLED_MECHANISMS_CONFIG)) { + ScramMechanism scramMechanism = ScramMechanism.forMechanismName(mechanism); + if (scramMechanism != null) { + ScramFormatter formatter = new ScramFormatter(scramMechanism); + ScramCredential credential = formatter.generateCredential(password, 4096); + credentialCache.cache(scramMechanism.mechanismName(), ScramCredential.class).put(username, credential); + } + } + } + + // Creates an ApiVersionsRequest with version 0. 
Using v0 in tests since + // SaslClientAuthenticator always uses version 0 + private ApiVersionsRequest createApiVersionsRequestV0() { + return new ApiVersionsRequest.Builder((short) 0).build(); + } + + @SuppressWarnings("unchecked") + private void updateTokenCredentialCache(String username, String password) throws NoSuchAlgorithmException { + for (String mechanism : (List) saslServerConfigs.get(BrokerSecurityConfigs.SASL_ENABLED_MECHANISMS_CONFIG)) { + ScramMechanism scramMechanism = ScramMechanism.forMechanismName(mechanism); + if (scramMechanism != null) { + ScramFormatter formatter = new ScramFormatter(scramMechanism); + ScramCredential credential = formatter.generateCredential(password, 4096); + server.tokenCache().credentialCache(scramMechanism.mechanismName()).put(username, credential); + } + } + } + + private static void delay(long delayMillis) throws InterruptedException { + final long startTime = System.currentTimeMillis(); + while ((System.currentTimeMillis() - startTime) < delayMillis) + Thread.sleep(CONNECTIONS_MAX_REAUTH_MS_VALUE / 5); + } + + public static class TestClientCallbackHandler implements AuthenticateCallbackHandler { + + static final String USERNAME = "TestClientCallbackHandler-user"; + static final String PASSWORD = "TestClientCallbackHandler-password"; + private volatile boolean configured; + + @Override + public void configure(Map configs, String mechanism, List jaasConfigEntries) { + if (configured) + throw new IllegalStateException("Client callback handler configured twice"); + configured = true; + } + + @Override + public void handle(Callback[] callbacks) throws UnsupportedCallbackException { + if (!configured) + throw new IllegalStateException("Client callback handler not configured"); + for (Callback callback : callbacks) { + if (callback instanceof NameCallback) + ((NameCallback) callback).setName(USERNAME); + else if (callback instanceof PasswordCallback) + ((PasswordCallback) callback).setPassword(PASSWORD.toCharArray()); + else + throw new UnsupportedCallbackException(callback); + } + } + + @Override + public void close() { + } + } + + public static class TestLogin implements Login { + + static AtomicInteger loginCount = new AtomicInteger(); + + private String contextName; + private Configuration configuration; + private Subject subject; + @Override + public void configure(Map configs, String contextName, Configuration configuration, + AuthenticateCallbackHandler callbackHandler) { + assertEquals(1, configuration.getAppConfigurationEntry(contextName).length); + this.contextName = contextName; + this.configuration = configuration; + } + + @Override + public LoginContext login() throws LoginException { + LoginContext context = new LoginContext(contextName, null, new AbstractLogin.DefaultLoginCallbackHandler(), configuration); + context.login(); + subject = context.getSubject(); + subject.getPublicCredentials().clear(); + subject.getPrivateCredentials().clear(); + subject.getPublicCredentials().add(TestJaasConfig.USERNAME); + subject.getPrivateCredentials().add(TestJaasConfig.PASSWORD); + loginCount.incrementAndGet(); + return context; + } + + @Override + public Subject subject() { + return subject; + } + + @Override + public String serviceName() { + return "kafka"; + } + + @Override + public void close() { + } + } + + public static class TestLoginCallbackHandler implements AuthenticateCallbackHandler { + private volatile boolean configured = false; + @Override + public void configure(Map configs, String saslMechanism, List jaasConfigEntries) { + if 
(configured) + throw new IllegalStateException("Login callback handler configured twice"); + configured = true; + } + + @Override + public void handle(Callback[] callbacks) { + if (!configured) + throw new IllegalStateException("Login callback handler not configured"); + + for (Callback callback : callbacks) { + if (callback instanceof NameCallback) + ((NameCallback) callback).setName(TestJaasConfig.USERNAME); + else if (callback instanceof PasswordCallback) + ((PasswordCallback) callback).setPassword(TestJaasConfig.PASSWORD.toCharArray()); + } + } + + @Override + public void close() { + } + } + + public static final class TestPlainLoginModule extends PlainLoginModule { + @Override + public void initialize(Subject subject, CallbackHandler callbackHandler, Map sharedState, Map options) { + try { + NameCallback nameCallback = new NameCallback("name:"); + PasswordCallback passwordCallback = new PasswordCallback("password:", false); + callbackHandler.handle(new Callback[]{nameCallback, passwordCallback}); + subject.getPublicCredentials().add(nameCallback.getName()); + subject.getPrivateCredentials().add(new String(passwordCallback.getPassword())); + } catch (Exception e) { + throw new SaslAuthenticationException("Login initialization failed", e); + } + } + } + + /* + * Create an alternate login callback handler that continually returns a + * different principal + */ + public static class AlternateLoginCallbackHandler implements AuthenticateCallbackHandler { + private static final OAuthBearerUnsecuredLoginCallbackHandler DELEGATE = new OAuthBearerUnsecuredLoginCallbackHandler(); + private static final String QUOTE = "\""; + private static int numInvocations = 0; + + @Override + public void handle(Callback[] callbacks) throws IOException, UnsupportedCallbackException { + DELEGATE.handle(callbacks); + // now change any returned token to have a different principal name + if (callbacks.length > 0) + for (Callback callback : callbacks) { + if (callback instanceof OAuthBearerTokenCallback) { + OAuthBearerTokenCallback oauthBearerTokenCallback = (OAuthBearerTokenCallback) callback; + OAuthBearerToken token = oauthBearerTokenCallback.token(); + if (token != null) { + String changedPrincipalNameToUse = token.principalName() + + String.valueOf(++numInvocations); + String headerJson = "{" + claimOrHeaderJsonText("alg", "none") + "}"; + /* + * Use a short lifetime so the background refresh thread replaces it before we + * re-authenticate + */ + String lifetimeSecondsValueToUse = "1"; + String claimsJson; + try { + claimsJson = String.format("{%s,%s,%s}", + expClaimText(Long.parseLong(lifetimeSecondsValueToUse)), + claimOrHeaderJsonText("iat", time.milliseconds() / 1000.0), + claimOrHeaderJsonText("sub", changedPrincipalNameToUse)); + } catch (NumberFormatException e) { + throw new OAuthBearerConfigException(e.getMessage()); + } + try { + Encoder urlEncoderNoPadding = Base64.getUrlEncoder().withoutPadding(); + OAuthBearerUnsecuredJws jws = new OAuthBearerUnsecuredJws(String.format("%s.%s.", + urlEncoderNoPadding.encodeToString(headerJson.getBytes(StandardCharsets.UTF_8)), + urlEncoderNoPadding + .encodeToString(claimsJson.getBytes(StandardCharsets.UTF_8))), + "sub", "scope"); + oauthBearerTokenCallback.token(jws); + } catch (OAuthBearerIllegalTokenException e) { + // occurs if the principal claim doesn't exist or has an empty value + throw new OAuthBearerConfigException(e.getMessage(), e); + } + } + } + } + } + + private static String claimOrHeaderJsonText(String claimName, String claimValue) { + return 
QUOTE + claimName + QUOTE + ":" + QUOTE + claimValue + QUOTE; + } + + private static String claimOrHeaderJsonText(String claimName, Number claimValue) { + return QUOTE + claimName + QUOTE + ":" + claimValue; + } + + private static String expClaimText(long lifetimeSeconds) { + return claimOrHeaderJsonText("exp", time.milliseconds() / 1000.0 + lifetimeSeconds); + } + + @Override + public void configure(Map configs, String saslMechanism, + List jaasConfigEntries) { + DELEGATE.configure(configs, saslMechanism, jaasConfigEntries); + } + + @Override + public void close() { + DELEGATE.close(); + } + } + + /* + * Define a channel builder that starts with the DIGEST-MD5 mechanism and then + * switches to the PLAIN mechanism + */ + private static class AlternateSaslChannelBuilder extends SaslChannelBuilder { + private int numInvocations = 0; + + public AlternateSaslChannelBuilder(Mode mode, Map jaasContexts, + SecurityProtocol securityProtocol, ListenerName listenerName, boolean isInterBrokerListener, + String clientSaslMechanism, boolean handshakeRequestEnable, CredentialCache credentialCache, + DelegationTokenCache tokenCache, Time time) { + super(mode, jaasContexts, securityProtocol, listenerName, isInterBrokerListener, clientSaslMechanism, + handshakeRequestEnable, credentialCache, tokenCache, null, time, new LogContext(), + () -> ApiVersionsResponse.defaultApiVersionsResponse(ApiMessageType.ListenerType.ZK_BROKER)); + } + + @Override + protected SaslClientAuthenticator buildClientAuthenticator(Map configs, + AuthenticateCallbackHandler callbackHandler, String id, String serverHost, String servicePrincipal, + TransportLayer transportLayer, Subject subject) { + if (++numInvocations == 1) + return new SaslClientAuthenticator(configs, callbackHandler, id, subject, servicePrincipal, serverHost, + "DIGEST-MD5", true, transportLayer, time, new LogContext()); + else + return new SaslClientAuthenticator(configs, callbackHandler, id, subject, servicePrincipal, serverHost, + "PLAIN", true, transportLayer, time, new LogContext()) { + @Override + protected SaslHandshakeRequest createSaslHandshakeRequest(short version) { + return new SaslHandshakeRequest.Builder( + new SaslHandshakeRequestData().setMechanism("PLAIN")).build(version); + } + }; + } + } + + public static class SaslSslPrincipalBuilder implements KafkaPrincipalBuilder { + + @Override + public KafkaPrincipal build(AuthenticationContext context) { + SaslAuthenticationContext saslContext = (SaslAuthenticationContext) context; + assertTrue(saslContext.sslSession().isPresent()); + String sslPrincipal; + try { + sslPrincipal = saslContext.sslSession().get().getPeerPrincipal().getName(); + } catch (SSLPeerUnverifiedException e) { + sslPrincipal = KafkaPrincipal.ANONYMOUS.getName(); + } + String saslPrincipal = saslContext.server().getAuthorizationID(); + return saslSslPrincipal(saslPrincipal, sslPrincipal); + } + + static KafkaPrincipal saslSslPrincipal(String saslPrincipal, String sslPrincipal) { + return new KafkaPrincipal(KafkaPrincipal.USER_TYPE, saslPrincipal + ":" + sslPrincipal); + } + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/security/authenticator/SaslServerAuthenticatorTest.java b/clients/src/test/java/org/apache/kafka/common/security/authenticator/SaslServerAuthenticatorTest.java new file mode 100644 index 0000000..af0fedd --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/security/authenticator/SaslServerAuthenticatorTest.java @@ -0,0 +1,170 @@ +/* + * Licensed to the Apache Software Foundation (ASF) 
under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.security.authenticator; + +import java.net.InetAddress; +import org.apache.kafka.common.config.internals.BrokerSecurityConfigs; +import org.apache.kafka.common.errors.IllegalSaslStateException; +import org.apache.kafka.common.message.ApiMessageType; +import org.apache.kafka.common.network.ChannelMetadataRegistry; +import org.apache.kafka.common.network.ClientInformation; +import org.apache.kafka.common.network.DefaultChannelMetadataRegistry; +import org.apache.kafka.common.network.InvalidReceiveException; +import org.apache.kafka.common.network.ListenerName; +import org.apache.kafka.common.network.TransportLayer; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.requests.ApiVersionsRequest; +import org.apache.kafka.common.requests.ApiVersionsResponse; +import org.apache.kafka.common.requests.RequestTestUtils; +import org.apache.kafka.common.security.auth.AuthenticateCallbackHandler; +import org.apache.kafka.common.security.auth.SecurityProtocol; +import org.apache.kafka.common.requests.RequestHeader; +import org.apache.kafka.common.security.plain.PlainLoginModule; +import org.apache.kafka.common.utils.AppInfoParser; +import org.apache.kafka.common.utils.Time; + +import javax.security.auth.Subject; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +import org.junit.jupiter.api.Test; +import org.mockito.Answers; + +import static org.apache.kafka.common.security.scram.internals.ScramMechanism.SCRAM_SHA_256; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.fail; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +public class SaslServerAuthenticatorTest { + + @Test + public void testOversizeRequest() throws IOException { + TransportLayer transportLayer = mock(TransportLayer.class); + Map configs = Collections.singletonMap(BrokerSecurityConfigs.SASL_ENABLED_MECHANISMS_CONFIG, + Collections.singletonList(SCRAM_SHA_256.mechanismName())); + SaslServerAuthenticator authenticator = setupAuthenticator(configs, transportLayer, + SCRAM_SHA_256.mechanismName(), new DefaultChannelMetadataRegistry()); + + when(transportLayer.read(any(ByteBuffer.class))).then(invocation -> { + invocation.getArgument(0).putInt(SaslServerAuthenticator.MAX_RECEIVE_SIZE + 1); + return 4; + }); + assertThrows(InvalidReceiveException.class, authenticator::authenticate); + verify(transportLayer).read(any(ByteBuffer.class)); + } + + @Test + public void 
testUnexpectedRequestType() throws IOException { + TransportLayer transportLayer = mock(TransportLayer.class); + Map configs = Collections.singletonMap(BrokerSecurityConfigs.SASL_ENABLED_MECHANISMS_CONFIG, + Collections.singletonList(SCRAM_SHA_256.mechanismName())); + SaslServerAuthenticator authenticator = setupAuthenticator(configs, transportLayer, + SCRAM_SHA_256.mechanismName(), new DefaultChannelMetadataRegistry()); + + RequestHeader header = new RequestHeader(ApiKeys.METADATA, (short) 0, "clientId", 13243); + ByteBuffer headerBuffer = RequestTestUtils.serializeRequestHeader(header); + + when(transportLayer.read(any(ByteBuffer.class))).then(invocation -> { + invocation.getArgument(0).putInt(headerBuffer.remaining()); + return 4; + }).then(invocation -> { + // serialize only the request header. the authenticator should not parse beyond this + invocation.getArgument(0).put(headerBuffer.duplicate()); + return headerBuffer.remaining(); + }); + + try { + authenticator.authenticate(); + fail("Expected authenticate() to raise an exception"); + } catch (IllegalSaslStateException e) { + // expected exception + } + + verify(transportLayer, times(2)).read(any(ByteBuffer.class)); + } + + @Test + public void testOldestApiVersionsRequest() throws IOException { + testApiVersionsRequest(ApiKeys.API_VERSIONS.oldestVersion(), + ClientInformation.UNKNOWN_NAME_OR_VERSION, ClientInformation.UNKNOWN_NAME_OR_VERSION); + } + + @Test + public void testLatestApiVersionsRequest() throws IOException { + testApiVersionsRequest(ApiKeys.API_VERSIONS.latestVersion(), + "apache-kafka-java", AppInfoParser.getVersion()); + } + + private void testApiVersionsRequest(short version, String expectedSoftwareName, + String expectedSoftwareVersion) throws IOException { + TransportLayer transportLayer = mock(TransportLayer.class, Answers.RETURNS_DEEP_STUBS); + Map configs = Collections.singletonMap(BrokerSecurityConfigs.SASL_ENABLED_MECHANISMS_CONFIG, + Collections.singletonList(SCRAM_SHA_256.mechanismName())); + ChannelMetadataRegistry metadataRegistry = new DefaultChannelMetadataRegistry(); + SaslServerAuthenticator authenticator = setupAuthenticator(configs, transportLayer, + SCRAM_SHA_256.mechanismName(), metadataRegistry); + + RequestHeader header = new RequestHeader(ApiKeys.API_VERSIONS, version, "clientId", 0); + ByteBuffer headerBuffer = RequestTestUtils.serializeRequestHeader(header); + + ApiVersionsRequest request = new ApiVersionsRequest.Builder().build(version); + ByteBuffer requestBuffer = request.serialize(); + requestBuffer.rewind(); + + when(transportLayer.socketChannel().socket().getInetAddress()).thenReturn(InetAddress.getLoopbackAddress()); + + when(transportLayer.read(any(ByteBuffer.class))).then(invocation -> { + invocation.getArgument(0).putInt(headerBuffer.remaining() + requestBuffer.remaining()); + return 4; + }).then(invocation -> { + invocation.getArgument(0) + .put(headerBuffer.duplicate()) + .put(requestBuffer.duplicate()); + return headerBuffer.remaining() + requestBuffer.remaining(); + }); + + authenticator.authenticate(); + + assertEquals(expectedSoftwareName, metadataRegistry.clientInformation().softwareName()); + assertEquals(expectedSoftwareVersion, metadataRegistry.clientInformation().softwareVersion()); + + verify(transportLayer, times(2)).read(any(ByteBuffer.class)); + } + + private SaslServerAuthenticator setupAuthenticator(Map configs, TransportLayer transportLayer, + String mechanism, ChannelMetadataRegistry metadataRegistry) { + TestJaasConfig jaasConfig = new TestJaasConfig(); + 
jaasConfig.addEntry("jaasContext", PlainLoginModule.class.getName(), new HashMap()); + Map subjects = Collections.singletonMap(mechanism, new Subject()); + Map callbackHandlers = Collections.singletonMap( + mechanism, new SaslServerCallbackHandler()); + ApiVersionsResponse apiVersionsResponse = ApiVersionsResponse.defaultApiVersionsResponse( + ApiMessageType.ListenerType.ZK_BROKER); + return new SaslServerAuthenticator(configs, callbackHandlers, "node", subjects, null, + new ListenerName("ssl"), SecurityProtocol.SASL_SSL, transportLayer, Collections.emptyMap(), + metadataRegistry, Time.SYSTEM, () -> apiVersionsResponse); + } + +} diff --git a/clients/src/test/java/org/apache/kafka/common/security/authenticator/TestDigestLoginModule.java b/clients/src/test/java/org/apache/kafka/common/security/authenticator/TestDigestLoginModule.java new file mode 100644 index 0000000..c27e853 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/security/authenticator/TestDigestLoginModule.java @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.security.authenticator; + +import java.util.List; +import java.util.Map; + +import javax.security.auth.callback.Callback; +import javax.security.auth.callback.NameCallback; +import javax.security.auth.callback.PasswordCallback; +import javax.security.auth.login.AppConfigurationEntry; +import javax.security.sasl.AuthorizeCallback; +import javax.security.sasl.RealmCallback; + +import org.apache.kafka.common.security.auth.AuthenticateCallbackHandler; +import org.apache.kafka.common.security.plain.PlainLoginModule; + +/** + * Digest-MD5 login module for multi-mechanism tests. + * This login module uses the same format as PlainLoginModule and hence simply reuses the same methods. 
+ * + */ +public class TestDigestLoginModule extends PlainLoginModule { + + public static class DigestServerCallbackHandler implements AuthenticateCallbackHandler { + + @Override + public void configure(Map configs, String saslMechanism, List jaasConfigEntries) { + } + + @Override + public void handle(Callback[] callbacks) { + String username = null; + for (Callback callback : callbacks) { + if (callback instanceof NameCallback) { + NameCallback nameCallback = (NameCallback) callback; + if (TestJaasConfig.USERNAME.equals(nameCallback.getDefaultName())) { + nameCallback.setName(nameCallback.getDefaultName()); + username = TestJaasConfig.USERNAME; + } + } else if (callback instanceof PasswordCallback) { + PasswordCallback passwordCallback = (PasswordCallback) callback; + if (TestJaasConfig.USERNAME.equals(username)) + passwordCallback.setPassword(TestJaasConfig.PASSWORD.toCharArray()); + } else if (callback instanceof RealmCallback) { + RealmCallback realmCallback = (RealmCallback) callback; + realmCallback.setText(realmCallback.getDefaultText()); + } else if (callback instanceof AuthorizeCallback) { + AuthorizeCallback authCallback = (AuthorizeCallback) callback; + if (TestJaasConfig.USERNAME.equals(authCallback.getAuthenticationID())) { + authCallback.setAuthorized(true); + authCallback.setAuthorizedID(authCallback.getAuthenticationID()); + } + } + } + } + + @Override + public void close() { + } + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/security/authenticator/TestJaasConfig.java b/clients/src/test/java/org/apache/kafka/common/security/authenticator/TestJaasConfig.java new file mode 100644 index 0000000..f7ad140 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/security/authenticator/TestJaasConfig.java @@ -0,0 +1,155 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.security.authenticator; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import javax.security.auth.login.AppConfigurationEntry; +import javax.security.auth.login.Configuration; +import javax.security.auth.login.AppConfigurationEntry.LoginModuleControlFlag; + +import org.apache.kafka.common.config.types.Password; +import org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginModule; +import org.apache.kafka.common.security.plain.PlainLoginModule; +import org.apache.kafka.common.security.scram.ScramLoginModule; +import org.apache.kafka.common.security.scram.internals.ScramMechanism; + +public class TestJaasConfig extends Configuration { + + static final String LOGIN_CONTEXT_CLIENT = "KafkaClient"; + static final String LOGIN_CONTEXT_SERVER = "KafkaServer"; + + static final String USERNAME = "myuser"; + static final String PASSWORD = "mypassword"; + + private Map entryMap = new HashMap<>(); + + public static TestJaasConfig createConfiguration(String clientMechanism, List serverMechanisms) { + TestJaasConfig config = new TestJaasConfig(); + config.createOrUpdateEntry(LOGIN_CONTEXT_CLIENT, loginModule(clientMechanism), defaultClientOptions(clientMechanism)); + for (String mechanism : serverMechanisms) { + config.addEntry(LOGIN_CONTEXT_SERVER, loginModule(mechanism), defaultServerOptions(mechanism)); + } + Configuration.setConfiguration(config); + return config; + } + + public static Password jaasConfigProperty(String mechanism, String username, String password) { + return new Password(loginModule(mechanism) + " required username=" + username + " password=" + password + ";"); + } + + public static Password jaasConfigProperty(String mechanism, Map options) { + StringBuilder builder = new StringBuilder(); + builder.append(loginModule(mechanism)); + builder.append(" required"); + for (Map.Entry option : options.entrySet()) { + builder.append(' '); + builder.append(option.getKey()); + builder.append('='); + builder.append(option.getValue()); + } + builder.append(';'); + return new Password(builder.toString()); + } + + public void setClientOptions(String saslMechanism, String clientUsername, String clientPassword) { + Map options = new HashMap<>(); + if (clientUsername != null) + options.put("username", clientUsername); + if (clientPassword != null) + options.put("password", clientPassword); + Class loginModuleClass = ScramMechanism.isScram(saslMechanism) ? ScramLoginModule.class : PlainLoginModule.class; + createOrUpdateEntry(LOGIN_CONTEXT_CLIENT, loginModuleClass.getName(), options); + } + + public void createOrUpdateEntry(String name, String loginModule, Map options) { + AppConfigurationEntry entry = new AppConfigurationEntry(loginModule, LoginModuleControlFlag.REQUIRED, options); + entryMap.put(name, new AppConfigurationEntry[] {entry}); + } + + public void addEntry(String name, String loginModule, Map options) { + AppConfigurationEntry entry = new AppConfigurationEntry(loginModule, LoginModuleControlFlag.REQUIRED, options); + AppConfigurationEntry[] existing = entryMap.get(name); + AppConfigurationEntry[] newEntries = existing == null ? 
new AppConfigurationEntry[1] : Arrays.copyOf(existing, existing.length + 1); + newEntries[newEntries.length - 1] = entry; + entryMap.put(name, newEntries); + } + + @Override + public AppConfigurationEntry[] getAppConfigurationEntry(String name) { + return entryMap.get(name); + } + + private static String loginModule(String mechanism) { + String loginModule; + switch (mechanism) { + case "PLAIN": + loginModule = PlainLoginModule.class.getName(); + break; + case "DIGEST-MD5": + loginModule = TestDigestLoginModule.class.getName(); + break; + case "OAUTHBEARER": + loginModule = OAuthBearerLoginModule.class.getName(); + break; + default: + if (ScramMechanism.isScram(mechanism)) + loginModule = ScramLoginModule.class.getName(); + else + throw new IllegalArgumentException("Unsupported mechanism " + mechanism); + } + return loginModule; + } + + public static Map defaultClientOptions(String mechanism) { + switch (mechanism) { + case "OAUTHBEARER": + Map options = new HashMap<>(); + options.put("unsecuredLoginStringClaim_sub", USERNAME); + return options; + default: + return defaultClientOptions(); + } + } + + public static Map defaultClientOptions() { + Map options = new HashMap<>(); + options.put("username", USERNAME); + options.put("password", PASSWORD); + return options; + } + + public static Map defaultServerOptions(String mechanism) { + Map options = new HashMap<>(); + switch (mechanism) { + case "PLAIN": + case "DIGEST-MD5": + options.put("user_" + USERNAME, PASSWORD); + break; + case "OAUTHBEARER": + options.put("unsecuredLoginStringClaim_sub", USERNAME); + break; + default: + if (!ScramMechanism.isScram(mechanism)) + throw new IllegalArgumentException("Unsupported mechanism " + mechanism); + } + return options; + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/security/kerberos/KerberosNameTest.java b/clients/src/test/java/org/apache/kafka/common/security/kerberos/KerberosNameTest.java new file mode 100644 index 0000000..2e06396 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/security/kerberos/KerberosNameTest.java @@ -0,0 +1,141 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.security.kerberos; + +import org.junit.jupiter.api.Test; + +import java.io.IOException; +import java.util.Arrays; +import java.util.List; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.fail; + +public class KerberosNameTest { + + @Test + public void testParse() throws IOException { + List rules = Arrays.asList( + "RULE:[1:$1](App\\..*)s/App\\.(.*)/$1/g", + "RULE:[2:$1](App\\..*)s/App\\.(.*)/$1/g", + "DEFAULT" + ); + + KerberosShortNamer shortNamer = KerberosShortNamer.fromUnparsedRules("REALM.COM", rules); + + KerberosName name = KerberosName.parse("App.service-name/example.com@REALM.COM"); + assertEquals("App.service-name", name.serviceName()); + assertEquals("example.com", name.hostName()); + assertEquals("REALM.COM", name.realm()); + assertEquals("service-name", shortNamer.shortName(name)); + + name = KerberosName.parse("App.service-name@REALM.COM"); + assertEquals("App.service-name", name.serviceName()); + assertNull(name.hostName()); + assertEquals("REALM.COM", name.realm()); + assertEquals("service-name", shortNamer.shortName(name)); + + name = KerberosName.parse("user/host@REALM.COM"); + assertEquals("user", name.serviceName()); + assertEquals("host", name.hostName()); + assertEquals("REALM.COM", name.realm()); + assertEquals("user", shortNamer.shortName(name)); + } + + @Test + public void testToLowerCase() throws Exception { + List rules = Arrays.asList( + "RULE:[1:$1]/L", + "RULE:[2:$1](Test.*)s/ABC///L", + "RULE:[2:$1](ABC.*)s/ABC/XYZ/g/L", + "RULE:[2:$1](App\\..*)s/App\\.(.*)/$1/g/L", + "RULE:[2:$1]/L", + "DEFAULT" + ); + + KerberosShortNamer shortNamer = KerberosShortNamer.fromUnparsedRules("REALM.COM", rules); + + KerberosName name = KerberosName.parse("User@REALM.COM"); + assertEquals("user", shortNamer.shortName(name)); + + name = KerberosName.parse("TestABC/host@FOO.COM"); + assertEquals("test", shortNamer.shortName(name)); + + name = KerberosName.parse("ABC_User_ABC/host@FOO.COM"); + assertEquals("xyz_user_xyz", shortNamer.shortName(name)); + + name = KerberosName.parse("App.SERVICE-name/example.com@REALM.COM"); + assertEquals("service-name", shortNamer.shortName(name)); + + name = KerberosName.parse("User/root@REALM.COM"); + assertEquals("user", shortNamer.shortName(name)); + } + + @Test + public void testToUpperCase() throws Exception { + List rules = Arrays.asList( + "RULE:[1:$1]/U", + "RULE:[2:$1](Test.*)s/ABC///U", + "RULE:[2:$1](ABC.*)s/ABC/XYZ/g/U", + "RULE:[2:$1](App\\..*)s/App\\.(.*)/$1/g/U", + "RULE:[2:$1]/U", + "DEFAULT" + ); + + KerberosShortNamer shortNamer = KerberosShortNamer.fromUnparsedRules("REALM.COM", rules); + + KerberosName name = KerberosName.parse("User@REALM.COM"); + assertEquals("USER", shortNamer.shortName(name)); + + name = KerberosName.parse("TestABC/host@FOO.COM"); + assertEquals("TEST", shortNamer.shortName(name)); + + name = KerberosName.parse("ABC_User_ABC/host@FOO.COM"); + assertEquals("XYZ_USER_XYZ", shortNamer.shortName(name)); + + name = KerberosName.parse("App.SERVICE-name/example.com@REALM.COM"); + assertEquals("SERVICE-NAME", shortNamer.shortName(name)); + + name = KerberosName.parse("User/root@REALM.COM"); + assertEquals("USER", shortNamer.shortName(name)); + } + + @Test + public void testInvalidRules() { + testInvalidRule(Arrays.asList("default")); + testInvalidRule(Arrays.asList("DEFAUL")); + testInvalidRule(Arrays.asList("DEFAULT/L")); + 
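+        // The /L, /U and /g style modifiers are only meaningful on RULE entries,
+        // so they are expected to be rejected when attached to DEFAULT.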
+        testInvalidRule(Arrays.asList("DEFAULT/g"));
+
+        testInvalidRule(Arrays.asList("rule:[1:$1]"));
+        testInvalidRule(Arrays.asList("rule:[1:$1]/L/U"));
+        testInvalidRule(Arrays.asList("rule:[1:$1]/U/L"));
+        testInvalidRule(Arrays.asList("rule:[1:$1]/LU"));
+        testInvalidRule(Arrays.asList("RULE:[1:$1/L"));
+        testInvalidRule(Arrays.asList("RULE:[1:$1]/l"));
+        testInvalidRule(Arrays.asList("RULE:[2:$1](ABC.*)s/ABC/XYZ/L/g"));
+    }
+
+    private void testInvalidRule(List<String> rules) {
+        try {
+            KerberosShortNamer.fromUnparsedRules("REALM.COM", rules);
+            fail("should have thrown IllegalArgumentException");
+        } catch (IllegalArgumentException e) {
+        }
+    }
+}
diff --git a/clients/src/test/java/org/apache/kafka/common/security/kerberos/KerberosRuleTest.java b/clients/src/test/java/org/apache/kafka/common/security/kerberos/KerberosRuleTest.java
new file mode 100644
index 0000000..9c785c4
--- /dev/null
+++ b/clients/src/test/java/org/apache/kafka/common/security/kerberos/KerberosRuleTest.java
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +package org.apache.kafka.common.security.kerberos; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.fail; + +import org.junit.jupiter.api.Test; + +public class KerberosRuleTest { + + @Test + public void testReplaceParameters() throws BadFormatString { + // positive test cases + assertEquals(KerberosRule.replaceParameters("", new String[0]), ""); + assertEquals(KerberosRule.replaceParameters("hello", new String[0]), "hello"); + assertEquals(KerberosRule.replaceParameters("", new String[]{"too", "many", "parameters", "are", "ok"}), ""); + assertEquals(KerberosRule.replaceParameters("hello", new String[]{"too", "many", "parameters", "are", "ok"}), "hello"); + assertEquals(KerberosRule.replaceParameters("hello $0", new String[]{"too", "many", "parameters", "are", "ok"}), "hello too"); + assertEquals(KerberosRule.replaceParameters("hello $0", new String[]{"no recursion $1"}), "hello no recursion $1"); + + // negative test cases + try { + KerberosRule.replaceParameters("$0", new String[]{}); + fail("An out-of-bounds parameter number should trigger an exception!"); + } catch (BadFormatString bfs) { + } + try { + KerberosRule.replaceParameters("hello $a", new String[]{"does not matter"}); + fail("A malformed parameter name should trigger an exception!"); + } catch (BadFormatString bfs) { + } + } + +} diff --git a/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/OAuthBearerExtensionsValidatorCallbackTest.java b/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/OAuthBearerExtensionsValidatorCallbackTest.java new file mode 100644 index 0000000..82ff4fa --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/OAuthBearerExtensionsValidatorCallbackTest.java @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.security.oauthbearer; + +import org.apache.kafka.common.security.auth.SaslExtensions; +import org.junit.jupiter.api.Test; + +import java.util.HashMap; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class OAuthBearerExtensionsValidatorCallbackTest { + private static final OAuthBearerToken TOKEN = new OAuthBearerTokenMock(); + + @Test + public void testValidatedExtensionsAreReturned() { + Map extensions = new HashMap<>(); + extensions.put("hello", "bye"); + + OAuthBearerExtensionsValidatorCallback callback = new OAuthBearerExtensionsValidatorCallback(TOKEN, new SaslExtensions(extensions)); + + assertTrue(callback.validatedExtensions().isEmpty()); + assertTrue(callback.invalidExtensions().isEmpty()); + callback.valid("hello"); + assertFalse(callback.validatedExtensions().isEmpty()); + assertEquals("bye", callback.validatedExtensions().get("hello")); + assertTrue(callback.invalidExtensions().isEmpty()); + } + + @Test + public void testInvalidExtensionsAndErrorMessagesAreReturned() { + Map extensions = new HashMap<>(); + extensions.put("hello", "bye"); + + OAuthBearerExtensionsValidatorCallback callback = new OAuthBearerExtensionsValidatorCallback(TOKEN, new SaslExtensions(extensions)); + + assertTrue(callback.validatedExtensions().isEmpty()); + assertTrue(callback.invalidExtensions().isEmpty()); + callback.error("hello", "error"); + assertFalse(callback.invalidExtensions().isEmpty()); + assertEquals("error", callback.invalidExtensions().get("hello")); + assertTrue(callback.validatedExtensions().isEmpty()); + } + + /** + * Extensions that are neither validated or invalidated must not be present in either maps + */ + @Test + public void testUnvalidatedExtensionsAreIgnored() { + Map extensions = new HashMap<>(); + extensions.put("valid", "valid"); + extensions.put("error", "error"); + extensions.put("nothing", "nothing"); + + OAuthBearerExtensionsValidatorCallback callback = new OAuthBearerExtensionsValidatorCallback(TOKEN, new SaslExtensions(extensions)); + callback.error("error", "error"); + callback.valid("valid"); + + assertFalse(callback.validatedExtensions().containsKey("nothing")); + assertFalse(callback.invalidExtensions().containsKey("nothing")); + assertEquals("nothing", callback.ignoredExtensions().get("nothing")); + } + + @Test + public void testCannotValidateExtensionWhichWasNotGiven() { + Map extensions = new HashMap<>(); + extensions.put("hello", "bye"); + + OAuthBearerExtensionsValidatorCallback callback = new OAuthBearerExtensionsValidatorCallback(TOKEN, new SaslExtensions(extensions)); + + assertThrows(IllegalArgumentException.class, () -> callback.valid("???")); + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/OAuthBearerLoginModuleTest.java b/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/OAuthBearerLoginModuleTest.java new file mode 100644 index 0000000..ea03ec5 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/OAuthBearerLoginModuleTest.java @@ -0,0 +1,439 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.security.oauthbearer; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotSame; +import static org.junit.jupiter.api.Assertions.assertSame; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verifyNoInteractions; + +import java.io.IOException; +import java.util.Collections; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Set; + +import javax.security.auth.Subject; +import javax.security.auth.callback.Callback; +import javax.security.auth.callback.UnsupportedCallbackException; +import javax.security.auth.login.AppConfigurationEntry; +import javax.security.auth.login.LoginException; + +import org.apache.kafka.common.KafkaException; +import org.apache.kafka.common.security.auth.AuthenticateCallbackHandler; +import org.apache.kafka.common.security.auth.SaslExtensionsCallback; +import org.apache.kafka.common.security.auth.SaslExtensions; +import org.junit.jupiter.api.Test; + +public class OAuthBearerLoginModuleTest { + + public static final SaslExtensions RAISE_UNSUPPORTED_CB_EXCEPTION_FLAG = null; + + private static class TestCallbackHandler implements AuthenticateCallbackHandler { + private final OAuthBearerToken[] tokens; + private int index = 0; + private int extensionsIndex = 0; + private final SaslExtensions[] extensions; + + public TestCallbackHandler(OAuthBearerToken[] tokens, SaslExtensions[] extensions) { + this.tokens = Objects.requireNonNull(tokens); + this.extensions = extensions; + } + + @Override + public void handle(Callback[] callbacks) throws IOException, UnsupportedCallbackException { + for (Callback callback : callbacks) { + if (callback instanceof OAuthBearerTokenCallback) + try { + handleCallback((OAuthBearerTokenCallback) callback); + } catch (KafkaException e) { + throw new IOException(e.getMessage(), e); + } + else if (callback instanceof SaslExtensionsCallback) { + try { + handleExtensionsCallback((SaslExtensionsCallback) callback); + } catch (KafkaException e) { + throw new IOException(e.getMessage(), e); + } + } else + throw new UnsupportedCallbackException(callback); + } + } + + @Override + public void configure(Map configs, String saslMechanism, + List jaasConfigEntries) { + // empty + } + + @Override + public void close() { + // empty + } + + private void handleCallback(OAuthBearerTokenCallback callback) throws IOException { + if (callback.token() != null) + throw new IllegalArgumentException("Callback had a token already"); + if (tokens.length > index) + callback.token(tokens[index++]); + else + throw new IOException("no more tokens"); + } + + private void handleExtensionsCallback(SaslExtensionsCallback callback) throws IOException, UnsupportedCallbackException { + 
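+            // Hand out the configured extensions in order. A null element (RAISE_UNSUPPORTED_CB_EXCEPTION_FLAG)
+            // simulates an older handler that does not recognize SaslExtensionsCallback and throws
+            // UnsupportedCallbackException instead of supplying extensions.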
if (extensions.length > extensionsIndex) { + SaslExtensions extension = extensions[extensionsIndex++]; + + if (extension == RAISE_UNSUPPORTED_CB_EXCEPTION_FLAG) { + throw new UnsupportedCallbackException(callback); + } + + callback.extensions(extension); + } else + throw new IOException("no more extensions"); + } + } + + @Test + public void login1Commit1Login2Commit2Logout1Login3Commit3Logout2() throws LoginException { + /* + * Invoke login()/commit() on loginModule1; invoke login/commit() on + * loginModule2; invoke logout() on loginModule1; invoke login()/commit() on + * loginModule3; invoke logout() on loginModule2 + */ + Subject subject = new Subject(); + Set privateCredentials = subject.getPrivateCredentials(); + Set publicCredentials = subject.getPublicCredentials(); + + // Create callback handler + OAuthBearerToken[] tokens = new OAuthBearerToken[] {mock(OAuthBearerToken.class), + mock(OAuthBearerToken.class), mock(OAuthBearerToken.class)}; + SaslExtensions[] extensions = new SaslExtensions[] {mock(SaslExtensions.class), + mock(SaslExtensions.class), mock(SaslExtensions.class)}; + TestCallbackHandler testTokenCallbackHandler = new TestCallbackHandler(tokens, extensions); + + // Create login modules + OAuthBearerLoginModule loginModule1 = new OAuthBearerLoginModule(); + loginModule1.initialize(subject, testTokenCallbackHandler, Collections.emptyMap(), + Collections.emptyMap()); + OAuthBearerLoginModule loginModule2 = new OAuthBearerLoginModule(); + loginModule2.initialize(subject, testTokenCallbackHandler, Collections.emptyMap(), + Collections.emptyMap()); + OAuthBearerLoginModule loginModule3 = new OAuthBearerLoginModule(); + loginModule3.initialize(subject, testTokenCallbackHandler, Collections.emptyMap(), + Collections.emptyMap()); + + // Should start with nothing + assertEquals(0, privateCredentials.size()); + assertEquals(0, publicCredentials.size()); + loginModule1.login(); + // Should still have nothing until commit() is called + assertEquals(0, privateCredentials.size()); + assertEquals(0, publicCredentials.size()); + loginModule1.commit(); + // Now we should have the first token and extensions + assertEquals(1, privateCredentials.size()); + assertEquals(1, publicCredentials.size()); + assertSame(tokens[0], privateCredentials.iterator().next()); + assertSame(extensions[0], publicCredentials.iterator().next()); + + // Now login on loginModule2 to get the second token + // loginModule2 does not support the extensions callback and will raise UnsupportedCallbackException + loginModule2.login(); + // Should still have just the first token and extensions + assertEquals(1, privateCredentials.size()); + assertEquals(1, publicCredentials.size()); + assertSame(tokens[0], privateCredentials.iterator().next()); + assertSame(extensions[0], publicCredentials.iterator().next()); + loginModule2.commit(); + // Should have the first and second tokens at this point + assertEquals(2, privateCredentials.size()); + assertEquals(2, publicCredentials.size()); + Iterator iterator = privateCredentials.iterator(); + Iterator publicIterator = publicCredentials.iterator(); + assertNotSame(tokens[2], iterator.next()); + assertNotSame(tokens[2], iterator.next()); + assertNotSame(extensions[2], publicIterator.next()); + assertNotSame(extensions[2], publicIterator.next()); + // finally logout() on loginModule1 + loginModule1.logout(); + // Now we should have just the second token and extension + assertEquals(1, privateCredentials.size()); + assertEquals(1, publicCredentials.size()); + assertSame(tokens[1], 
privateCredentials.iterator().next()); + assertSame(extensions[1], publicCredentials.iterator().next()); + + // Now login on loginModule3 to get the third token + loginModule3.login(); + // Should still have just the second token and extensions + assertEquals(1, privateCredentials.size()); + assertEquals(1, publicCredentials.size()); + assertSame(tokens[1], privateCredentials.iterator().next()); + assertSame(extensions[1], publicCredentials.iterator().next()); + loginModule3.commit(); + // Should have the second and third tokens at this point + assertEquals(2, privateCredentials.size()); + assertEquals(2, publicCredentials.size()); + iterator = privateCredentials.iterator(); + publicIterator = publicCredentials.iterator(); + assertNotSame(tokens[0], iterator.next()); + assertNotSame(tokens[0], iterator.next()); + assertNotSame(extensions[0], publicIterator.next()); + assertNotSame(extensions[0], publicIterator.next()); + // finally logout() on loginModule2 + loginModule2.logout(); + // Now we should have just the third token + assertEquals(1, privateCredentials.size()); + assertEquals(1, publicCredentials.size()); + assertSame(tokens[2], privateCredentials.iterator().next()); + assertSame(extensions[2], publicCredentials.iterator().next()); + + verifyNoInteractions((Object[]) tokens); + verifyNoInteractions((Object[]) extensions); + } + + @Test + public void login1Commit1Logout1Login2Commit2Logout2() throws LoginException { + /* + * Invoke login()/commit() on loginModule1; invoke logout() on loginModule1; + * invoke login()/commit() on loginModule2; invoke logout() on loginModule2 + */ + Subject subject = new Subject(); + Set privateCredentials = subject.getPrivateCredentials(); + Set publicCredentials = subject.getPublicCredentials(); + + // Create callback handler + OAuthBearerToken[] tokens = new OAuthBearerToken[] {mock(OAuthBearerToken.class), + mock(OAuthBearerToken.class)}; + SaslExtensions[] extensions = new SaslExtensions[] {mock(SaslExtensions.class), + mock(SaslExtensions.class)}; + TestCallbackHandler testTokenCallbackHandler = new TestCallbackHandler(tokens, extensions); + + // Create login modules + OAuthBearerLoginModule loginModule1 = new OAuthBearerLoginModule(); + loginModule1.initialize(subject, testTokenCallbackHandler, Collections.emptyMap(), + Collections.emptyMap()); + OAuthBearerLoginModule loginModule2 = new OAuthBearerLoginModule(); + loginModule2.initialize(subject, testTokenCallbackHandler, Collections.emptyMap(), + Collections.emptyMap()); + + // Should start with nothing + assertEquals(0, privateCredentials.size()); + assertEquals(0, publicCredentials.size()); + loginModule1.login(); + // Should still have nothing until commit() is called + assertEquals(0, privateCredentials.size()); + assertEquals(0, publicCredentials.size()); + loginModule1.commit(); + // Now we should have the first token + assertEquals(1, privateCredentials.size()); + assertEquals(1, publicCredentials.size()); + assertSame(tokens[0], privateCredentials.iterator().next()); + assertSame(extensions[0], publicCredentials.iterator().next()); + loginModule1.logout(); + // Should have nothing again + assertEquals(0, privateCredentials.size()); + assertEquals(0, publicCredentials.size()); + + loginModule2.login(); + // Should still have nothing until commit() is called + assertEquals(0, privateCredentials.size()); + assertEquals(0, publicCredentials.size()); + loginModule2.commit(); + // Now we should have the second token + assertEquals(1, privateCredentials.size()); + assertEquals(1, 
publicCredentials.size()); + assertSame(tokens[1], privateCredentials.iterator().next()); + assertSame(extensions[1], publicCredentials.iterator().next()); + loginModule2.logout(); + // Should have nothing again + assertEquals(0, privateCredentials.size()); + assertEquals(0, publicCredentials.size()); + + verifyNoInteractions((Object[]) tokens); + verifyNoInteractions((Object[]) extensions); + } + + @Test + public void loginAbortLoginCommitLogout() throws LoginException { + /* + * Invoke login(); invoke abort(); invoke login(); logout() + */ + Subject subject = new Subject(); + Set privateCredentials = subject.getPrivateCredentials(); + Set publicCredentials = subject.getPublicCredentials(); + + // Create callback handler + OAuthBearerToken[] tokens = new OAuthBearerToken[] {mock(OAuthBearerToken.class), + mock(OAuthBearerToken.class)}; + SaslExtensions[] extensions = new SaslExtensions[] {mock(SaslExtensions.class), + mock(SaslExtensions.class)}; + TestCallbackHandler testTokenCallbackHandler = new TestCallbackHandler(tokens, extensions); + + // Create login module + OAuthBearerLoginModule loginModule = new OAuthBearerLoginModule(); + loginModule.initialize(subject, testTokenCallbackHandler, Collections.emptyMap(), + Collections.emptyMap()); + + // Should start with nothing + assertEquals(0, privateCredentials.size()); + assertEquals(0, publicCredentials.size()); + loginModule.login(); + // Should still have nothing until commit() is called + assertEquals(0, privateCredentials.size()); + assertEquals(0, publicCredentials.size()); + loginModule.abort(); + // Should still have nothing since we aborted + assertEquals(0, privateCredentials.size()); + assertEquals(0, publicCredentials.size()); + + loginModule.login(); + // Should still have nothing until commit() is called + assertEquals(0, privateCredentials.size()); + assertEquals(0, publicCredentials.size()); + loginModule.commit(); + // Now we should have the second token + assertEquals(1, privateCredentials.size()); + assertEquals(1, publicCredentials.size()); + assertSame(tokens[1], privateCredentials.iterator().next()); + assertSame(extensions[1], publicCredentials.iterator().next()); + loginModule.logout(); + // Should have nothing again + assertEquals(0, privateCredentials.size()); + assertEquals(0, publicCredentials.size()); + + verifyNoInteractions((Object[]) tokens); + verifyNoInteractions((Object[]) extensions); + } + + @Test + public void login1Commit1Login2Abort2Login3Commit3Logout3() throws LoginException { + /* + * Invoke login()/commit() on loginModule1; invoke login()/abort() on + * loginModule2; invoke login()/commit()/logout() on loginModule3 + */ + Subject subject = new Subject(); + Set privateCredentials = subject.getPrivateCredentials(); + Set publicCredentials = subject.getPublicCredentials(); + + // Create callback handler + OAuthBearerToken[] tokens = new OAuthBearerToken[] {mock(OAuthBearerToken.class), + mock(OAuthBearerToken.class), mock(OAuthBearerToken.class)}; + SaslExtensions[] extensions = new SaslExtensions[] {mock(SaslExtensions.class), + mock(SaslExtensions.class), mock(SaslExtensions.class)}; + TestCallbackHandler testTokenCallbackHandler = new TestCallbackHandler(tokens, extensions); + + // Create login modules + OAuthBearerLoginModule loginModule1 = new OAuthBearerLoginModule(); + loginModule1.initialize(subject, testTokenCallbackHandler, Collections.emptyMap(), + Collections.emptyMap()); + OAuthBearerLoginModule loginModule2 = new OAuthBearerLoginModule(); + loginModule2.initialize(subject, 
testTokenCallbackHandler, Collections.emptyMap(), + Collections.emptyMap()); + OAuthBearerLoginModule loginModule3 = new OAuthBearerLoginModule(); + loginModule3.initialize(subject, testTokenCallbackHandler, Collections.emptyMap(), + Collections.emptyMap()); + + // Should start with nothing + assertEquals(0, privateCredentials.size()); + assertEquals(0, publicCredentials.size()); + loginModule1.login(); + // Should still have nothing until commit() is called + assertEquals(0, privateCredentials.size()); + assertEquals(0, publicCredentials.size()); + loginModule1.commit(); + // Now we should have the first token + assertEquals(1, privateCredentials.size()); + assertEquals(1, publicCredentials.size()); + assertSame(tokens[0], privateCredentials.iterator().next()); + assertSame(extensions[0], publicCredentials.iterator().next()); + + // Now go get the second token + loginModule2.login(); + // Should still have first token + assertEquals(1, privateCredentials.size()); + assertEquals(1, publicCredentials.size()); + assertSame(tokens[0], privateCredentials.iterator().next()); + assertSame(extensions[0], publicCredentials.iterator().next()); + loginModule2.abort(); + // Should still have just the first token because we aborted + assertEquals(1, privateCredentials.size()); + assertSame(tokens[0], privateCredentials.iterator().next()); + assertEquals(1, publicCredentials.size()); + assertSame(extensions[0], publicCredentials.iterator().next()); + + // Now go get the third token + loginModule2.login(); + // Should still have first token + assertEquals(1, privateCredentials.size()); + assertSame(tokens[0], privateCredentials.iterator().next()); + assertEquals(1, publicCredentials.size()); + assertSame(extensions[0], publicCredentials.iterator().next()); + loginModule2.commit(); + // Should have first and third tokens at this point + assertEquals(2, privateCredentials.size()); + Iterator iterator = privateCredentials.iterator(); + assertNotSame(tokens[1], iterator.next()); + assertNotSame(tokens[1], iterator.next()); + assertEquals(2, publicCredentials.size()); + Iterator publicIterator = publicCredentials.iterator(); + assertNotSame(extensions[1], publicIterator.next()); + assertNotSame(extensions[1], publicIterator.next()); + loginModule1.logout(); + // Now we should have just the third token + assertEquals(1, privateCredentials.size()); + assertSame(tokens[2], privateCredentials.iterator().next()); + assertEquals(1, publicCredentials.size()); + assertSame(extensions[2], publicCredentials.iterator().next()); + + verifyNoInteractions((Object[]) tokens); + verifyNoInteractions((Object[]) extensions); + } + + /** + * 2.1.0 added customizable SASL extensions and a new callback type. 
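+     * The new callback type is SaslExtensionsCallback, which the login module hands to the configured
+     * callback handler so that it can supply SASL extensions during login.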
+ * Ensure that old, custom-written callbackHandlers that do not handle the callback work + */ + @Test + public void commitDoesNotThrowOnUnsupportedExtensionsCallback() throws LoginException { + Subject subject = new Subject(); + + // Create callback handler + OAuthBearerToken[] tokens = new OAuthBearerToken[] {mock(OAuthBearerToken.class), + mock(OAuthBearerToken.class), mock(OAuthBearerToken.class)}; + TestCallbackHandler testTokenCallbackHandler = new TestCallbackHandler(tokens, new SaslExtensions[] {RAISE_UNSUPPORTED_CB_EXCEPTION_FLAG}); + + // Create login modules + OAuthBearerLoginModule loginModule1 = new OAuthBearerLoginModule(); + loginModule1.initialize(subject, testTokenCallbackHandler, Collections.emptyMap(), + Collections.emptyMap()); + + loginModule1.login(); + // Should populate public credentials with SaslExtensions and not throw an exception + loginModule1.commit(); + SaslExtensions extensions = subject.getPublicCredentials(SaslExtensions.class).iterator().next(); + assertNotNull(extensions); + assertTrue(extensions.map().isEmpty()); + + verifyNoInteractions((Object[]) tokens); + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/OAuthBearerSaslClienCallbackHandlerTest.java b/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/OAuthBearerSaslClienCallbackHandlerTest.java new file mode 100644 index 0000000..b13c0f8 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/OAuthBearerSaslClienCallbackHandlerTest.java @@ -0,0 +1,102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.security.oauthbearer; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +import java.io.IOException; +import java.security.AccessController; +import java.security.PrivilegedActionException; +import java.security.PrivilegedExceptionAction; +import java.util.Collections; +import java.util.Set; + +import javax.security.auth.Subject; +import javax.security.auth.callback.Callback; + +import org.apache.kafka.common.security.oauthbearer.internals.OAuthBearerSaslClientCallbackHandler; +import org.junit.jupiter.api.Test; + +public class OAuthBearerSaslClienCallbackHandlerTest { + private static OAuthBearerToken createTokenWithLifetimeMillis(final long lifetimeMillis) { + return new OAuthBearerToken() { + @Override + public String value() { + return null; + } + + @Override + public Long startTimeMs() { + return null; + } + + @Override + public Set scope() { + return null; + } + + @Override + public String principalName() { + return null; + } + + @Override + public long lifetimeMs() { + return lifetimeMillis; + } + }; + } + + @Test + public void testWithZeroTokens() { + OAuthBearerSaslClientCallbackHandler handler = createCallbackHandler(); + PrivilegedActionException e = assertThrows(PrivilegedActionException.class, () -> Subject.doAs(new Subject(), + (PrivilegedExceptionAction) () -> { + OAuthBearerTokenCallback callback = new OAuthBearerTokenCallback(); + handler.handle(new Callback[] {callback}); + return null; + } + )); + assertEquals(IOException.class, e.getCause().getClass()); + } + + @Test() + public void testWithPotentiallyMultipleTokens() throws Exception { + OAuthBearerSaslClientCallbackHandler handler = createCallbackHandler(); + Subject.doAs(new Subject(), (PrivilegedExceptionAction) () -> { + final int maxTokens = 4; + final Set privateCredentials = Subject.getSubject(AccessController.getContext()) + .getPrivateCredentials(); + privateCredentials.clear(); + for (int num = 1; num <= maxTokens; ++num) { + privateCredentials.add(createTokenWithLifetimeMillis(num)); + OAuthBearerTokenCallback callback = new OAuthBearerTokenCallback(); + handler.handle(new Callback[] {callback}); + assertEquals(num, callback.token().lifetimeMs()); + } + return null; + }); + } + + private static OAuthBearerSaslClientCallbackHandler createCallbackHandler() { + OAuthBearerSaslClientCallbackHandler handler = new OAuthBearerSaslClientCallbackHandler(); + handler.configure(Collections.emptyMap(), OAuthBearerLoginModule.OAUTHBEARER_MECHANISM, + Collections.emptyList()); + return handler; + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/OAuthBearerTokenCallbackTest.java b/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/OAuthBearerTokenCallbackTest.java new file mode 100644 index 0000000..07ce1b2 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/OAuthBearerTokenCallbackTest.java @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.security.oauthbearer; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertSame; + +import java.util.Collections; +import java.util.Set; + +import org.junit.jupiter.api.Test; + +public class OAuthBearerTokenCallbackTest { + private static final OAuthBearerToken TOKEN = new OAuthBearerToken() { + @Override + public String value() { + return "value"; + } + + @Override + public Long startTimeMs() { + return null; + } + + @Override + public Set scope() { + return Collections.emptySet(); + } + + @Override + public String principalName() { + return "principalName"; + } + + @Override + public long lifetimeMs() { + return 0; + } + }; + + @Test + public void testError() { + String errorCode = "errorCode"; + String errorDescription = "errorDescription"; + String errorUri = "errorUri"; + OAuthBearerTokenCallback callback = new OAuthBearerTokenCallback(); + callback.error(errorCode, errorDescription, errorUri); + assertEquals(errorCode, callback.errorCode()); + assertEquals(errorDescription, callback.errorDescription()); + assertEquals(errorUri, callback.errorUri()); + assertNull(callback.token()); + } + + @Test + public void testToken() { + OAuthBearerTokenCallback callback = new OAuthBearerTokenCallback(); + callback.token(TOKEN); + assertSame(TOKEN, callback.token()); + assertNull(callback.errorCode()); + assertNull(callback.errorDescription()); + assertNull(callback.errorUri()); + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/OAuthBearerTokenMock.java b/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/OAuthBearerTokenMock.java new file mode 100644 index 0000000..994c923 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/OAuthBearerTokenMock.java @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.security.oauthbearer; + +import java.util.Set; + +public class OAuthBearerTokenMock implements OAuthBearerToken { + @Override + public String value() { + return null; + } + + @Override + public Set scope() { + return null; + } + + @Override + public long lifetimeMs() { + return 0; + } + + @Override + public String principalName() { + return null; + } + + @Override + public Long startTimeMs() { + return null; + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/OAuthBearerValidatorCallbackTest.java b/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/OAuthBearerValidatorCallbackTest.java new file mode 100644 index 0000000..b266c86 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/OAuthBearerValidatorCallbackTest.java @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.security.oauthbearer; + +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertSame; + +import java.util.Collections; +import java.util.Set; + +public class OAuthBearerValidatorCallbackTest { + private static final OAuthBearerToken TOKEN = new OAuthBearerToken() { + @Override + public String value() { + return "value"; + } + + @Override + public Long startTimeMs() { + return null; + } + + @Override + public Set scope() { + return Collections.emptySet(); + } + + @Override + public String principalName() { + return "principalName"; + } + + @Override + public long lifetimeMs() { + return 0; + } + }; + + @Test + public void testError() { + String errorStatus = "errorStatus"; + String errorScope = "errorScope"; + String errorOpenIDConfiguration = "errorOpenIDConfiguration"; + OAuthBearerValidatorCallback callback = new OAuthBearerValidatorCallback(TOKEN.value()); + callback.error(errorStatus, errorScope, errorOpenIDConfiguration); + assertEquals(errorStatus, callback.errorStatus()); + assertEquals(errorScope, callback.errorScope()); + assertEquals(errorOpenIDConfiguration, callback.errorOpenIDConfiguration()); + assertNull(callback.token()); + } + + @Test + public void testToken() { + OAuthBearerValidatorCallback callback = new OAuthBearerValidatorCallback(TOKEN.value()); + callback.token(TOKEN); + assertSame(TOKEN, callback.token()); + assertNull(callback.errorStatus()); + assertNull(callback.errorScope()); + assertNull(callback.errorOpenIDConfiguration()); + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/internals/OAuthBearerClientInitialResponseTest.java 
b/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/internals/OAuthBearerClientInitialResponseTest.java new file mode 100644 index 0000000..2bf3f84 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/internals/OAuthBearerClientInitialResponseTest.java @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.security.oauthbearer.internals; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import org.apache.kafka.common.security.auth.SaslExtensions; +import org.junit.jupiter.api.Test; + +import javax.security.sasl.SaslException; +import java.nio.charset.StandardCharsets; +import java.util.HashMap; +import java.util.Map; + +public class OAuthBearerClientInitialResponseTest { + + /* + Test how a client would build a response + */ + @Test + public void testBuildClientResponseToBytes() throws Exception { + String expectedMesssage = "n,,\u0001auth=Bearer 123.345.567\u0001nineteen=42\u0001\u0001"; + + Map extensions = new HashMap<>(); + extensions.put("nineteen", "42"); + OAuthBearerClientInitialResponse response = new OAuthBearerClientInitialResponse("123.345.567", new SaslExtensions(extensions)); + + String message = new String(response.toBytes(), StandardCharsets.UTF_8); + + assertEquals(expectedMesssage, message); + } + + @Test + public void testBuildServerResponseToBytes() throws Exception { + String serverMessage = "n,,\u0001auth=Bearer 123.345.567\u0001nineteen=42\u0001\u0001"; + OAuthBearerClientInitialResponse response = new OAuthBearerClientInitialResponse(serverMessage.getBytes(StandardCharsets.UTF_8)); + + String message = new String(response.toBytes(), StandardCharsets.UTF_8); + + assertEquals(serverMessage, message); + } + + @Test + public void testThrowsSaslExceptionOnInvalidExtensionKey() throws Exception { + Map extensions = new HashMap<>(); + extensions.put("19", "42"); // keys can only be a-z + assertThrows(SaslException.class, () -> new OAuthBearerClientInitialResponse("123.345.567", new SaslExtensions(extensions))); + } + + @Test + public void testToken() throws Exception { + String message = "n,,\u0001auth=Bearer 123.345.567\u0001\u0001"; + OAuthBearerClientInitialResponse response = new OAuthBearerClientInitialResponse(message.getBytes(StandardCharsets.UTF_8)); + assertEquals("123.345.567", response.tokenValue()); + assertEquals("", response.authorizationId()); + } + + @Test + public void testAuthorizationId() throws Exception { + String message = "n,a=myuser,\u0001auth=Bearer 345\u0001\u0001"; + OAuthBearerClientInitialResponse response = new 
OAuthBearerClientInitialResponse(message.getBytes(StandardCharsets.UTF_8)); + assertEquals("345", response.tokenValue()); + assertEquals("myuser", response.authorizationId()); + } + + @Test + public void testExtensions() throws Exception { + String message = "n,,\u0001propA=valueA1, valueA2\u0001auth=Bearer 567\u0001propB=valueB\u0001\u0001"; + OAuthBearerClientInitialResponse response = new OAuthBearerClientInitialResponse(message.getBytes(StandardCharsets.UTF_8)); + assertEquals("567", response.tokenValue()); + assertEquals("", response.authorizationId()); + assertEquals("valueA1, valueA2", response.extensions().map().get("propA")); + assertEquals("valueB", response.extensions().map().get("propB")); + } + + // The example in the RFC uses `vF9dft4qmTc2Nvb3RlckBhbHRhdmlzdGEuY29tCg==` as the token + // But since we use Base64Url encoding, padding is omitted. Hence this test verifies without '='. + @Test + public void testRfc7688Example() throws Exception { + String message = "n,a=user@example.com,\u0001host=server.example.com\u0001port=143\u0001" + + "auth=Bearer vF9dft4qmTc2Nvb3RlckBhbHRhdmlzdGEuY29tCg\u0001\u0001"; + OAuthBearerClientInitialResponse response = new OAuthBearerClientInitialResponse(message.getBytes(StandardCharsets.UTF_8)); + assertEquals("vF9dft4qmTc2Nvb3RlckBhbHRhdmlzdGEuY29tCg", response.tokenValue()); + assertEquals("user@example.com", response.authorizationId()); + assertEquals("server.example.com", response.extensions().map().get("host")); + assertEquals("143", response.extensions().map().get("port")); + } + + @Test + public void testNoExtensionsFromByteArray() throws Exception { + String message = "n,a=user@example.com,\u0001" + + "auth=Bearer vF9dft4qmTc2Nvb3RlckBhbHRhdmlzdGEuY29tCg\u0001\u0001"; + OAuthBearerClientInitialResponse response = new OAuthBearerClientInitialResponse(message.getBytes(StandardCharsets.UTF_8)); + assertEquals("vF9dft4qmTc2Nvb3RlckBhbHRhdmlzdGEuY29tCg", response.tokenValue()); + assertEquals("user@example.com", response.authorizationId()); + assertTrue(response.extensions().map().isEmpty()); + } + + @Test + public void testNoExtensionsFromTokenAndNullExtensions() throws Exception { + OAuthBearerClientInitialResponse response = new OAuthBearerClientInitialResponse("token", null); + assertTrue(response.extensions().map().isEmpty()); + } + + @Test + public void testValidateNullExtensions() throws Exception { + OAuthBearerClientInitialResponse.validateExtensions(null); + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/internals/OAuthBearerSaslClientTest.java b/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/internals/OAuthBearerSaslClientTest.java new file mode 100644 index 0000000..50ed3fd --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/internals/OAuthBearerSaslClientTest.java @@ -0,0 +1,151 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.security.oauthbearer.internals; + +import org.apache.kafka.common.config.ConfigException; +import org.apache.kafka.common.security.auth.SaslExtensionsCallback; +import org.apache.kafka.common.security.auth.AuthenticateCallbackHandler; +import org.apache.kafka.common.security.auth.SaslExtensions; +import org.apache.kafka.common.security.oauthbearer.OAuthBearerToken; +import org.apache.kafka.common.security.oauthbearer.OAuthBearerTokenCallback; +import org.junit.jupiter.api.Test; + +import javax.security.auth.callback.Callback; +import javax.security.auth.callback.UnsupportedCallbackException; +import javax.security.auth.login.AppConfigurationEntry; +import javax.security.sasl.SaslException; +import java.nio.charset.StandardCharsets; +import java.util.Collections; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.fail; + +public class OAuthBearerSaslClientTest { + + private static final Map TEST_PROPERTIES = new LinkedHashMap() { + { + put("One", "1"); + put("Two", "2"); + put("Three", "3"); + } + }; + private SaslExtensions testExtensions = new SaslExtensions(TEST_PROPERTIES); + private final String errorMessage = "Error as expected!"; + + public class ExtensionsCallbackHandler implements AuthenticateCallbackHandler { + private boolean configured = false; + private boolean toThrow; + + ExtensionsCallbackHandler(boolean toThrow) { + this.toThrow = toThrow; + } + + public boolean configured() { + return configured; + } + + @Override + public void configure(Map configs, String saslMechanism, List jaasConfigEntries) { + configured = true; + } + + @Override + public void handle(Callback[] callbacks) throws UnsupportedCallbackException { + for (Callback callback : callbacks) { + if (callback instanceof OAuthBearerTokenCallback) + ((OAuthBearerTokenCallback) callback).token(new OAuthBearerToken() { + @Override + public String value() { + return ""; + } + + @Override + public Set scope() { + return Collections.emptySet(); + } + + @Override + public long lifetimeMs() { + return 100; + } + + @Override + public String principalName() { + return "principalName"; + } + + @Override + public Long startTimeMs() { + return null; + } + }); + else if (callback instanceof SaslExtensionsCallback) { + if (toThrow) + throw new ConfigException(errorMessage); + else + ((SaslExtensionsCallback) callback).extensions(testExtensions); + } else + throw new UnsupportedCallbackException(callback); + } + } + + @Override + public void close() { + } + } + + @Test + public void testAttachesExtensionsToFirstClientMessage() throws Exception { + String expectedToken = new String(new OAuthBearerClientInitialResponse("", testExtensions).toBytes(), StandardCharsets.UTF_8); + + OAuthBearerSaslClient client = new OAuthBearerSaslClient(new ExtensionsCallbackHandler(false)); + + String message = new String(client.evaluateChallenge("".getBytes()), StandardCharsets.UTF_8); + + assertEquals(expectedToken, message); + } + + 
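+    // With the extensions map cleared, the first client message should carry only the "auth" token part.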
@Test + public void testNoExtensionsDoesNotAttachAnythingToFirstClientMessage() throws Exception { + TEST_PROPERTIES.clear(); + testExtensions = new SaslExtensions(TEST_PROPERTIES); + String expectedToken = new String(new OAuthBearerClientInitialResponse("", new SaslExtensions(TEST_PROPERTIES)).toBytes(), StandardCharsets.UTF_8); + OAuthBearerSaslClient client = new OAuthBearerSaslClient(new ExtensionsCallbackHandler(false)); + + String message = new String(client.evaluateChallenge("".getBytes()), StandardCharsets.UTF_8); + + assertEquals(expectedToken, message); + } + + @Test + public void testWrapsExtensionsCallbackHandlingErrorInSaslExceptionInFirstClientMessage() { + OAuthBearerSaslClient client = new OAuthBearerSaslClient(new ExtensionsCallbackHandler(true)); + try { + client.evaluateChallenge("".getBytes()); + fail("Should have failed with " + SaslException.class.getName()); + } catch (SaslException e) { + // assert it has caught our expected exception + assertEquals(ConfigException.class, e.getCause().getClass()); + assertEquals(errorMessage, e.getCause().getMessage()); + } + + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/internals/OAuthBearerSaslServerTest.java b/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/internals/OAuthBearerSaslServerTest.java new file mode 100644 index 0000000..089c908 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/internals/OAuthBearerSaslServerTest.java @@ -0,0 +1,225 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.security.oauthbearer.internals; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.assertNull; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +import javax.security.auth.callback.Callback; +import javax.security.auth.callback.UnsupportedCallbackException; +import javax.security.auth.login.LoginException; + +import org.apache.kafka.common.config.SaslConfigs; +import org.apache.kafka.common.config.types.Password; +import org.apache.kafka.common.errors.SaslAuthenticationException; +import org.apache.kafka.common.security.JaasContext; +import org.apache.kafka.common.security.auth.AuthenticateCallbackHandler; +import org.apache.kafka.common.security.auth.SaslExtensions; +import org.apache.kafka.common.security.authenticator.SaslInternalConfigs; +import org.apache.kafka.common.security.oauthbearer.OAuthBearerExtensionsValidatorCallback; +import org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginModule; +import org.apache.kafka.common.security.oauthbearer.OAuthBearerToken; +import org.apache.kafka.common.security.oauthbearer.OAuthBearerTokenCallback; +import org.apache.kafka.common.security.oauthbearer.OAuthBearerTokenMock; +import org.apache.kafka.common.security.oauthbearer.OAuthBearerValidatorCallback; +import org.apache.kafka.common.security.oauthbearer.internals.unsecured.OAuthBearerConfigException; +import org.apache.kafka.common.security.oauthbearer.internals.unsecured.OAuthBearerUnsecuredLoginCallbackHandler; +import org.apache.kafka.common.security.oauthbearer.internals.unsecured.OAuthBearerUnsecuredValidatorCallbackHandler; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +public class OAuthBearerSaslServerTest { + private static final String USER = "user"; + private static final Map CONFIGS; + static { + String jaasConfigText = "org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginModule Required" + + " unsecuredLoginStringClaim_sub=\"" + USER + "\";"; + Map tmp = new HashMap<>(); + tmp.put(SaslConfigs.SASL_JAAS_CONFIG, new Password(jaasConfigText)); + CONFIGS = Collections.unmodifiableMap(tmp); + } + private static final AuthenticateCallbackHandler LOGIN_CALLBACK_HANDLER; + static { + LOGIN_CALLBACK_HANDLER = new OAuthBearerUnsecuredLoginCallbackHandler(); + LOGIN_CALLBACK_HANDLER.configure(CONFIGS, OAuthBearerLoginModule.OAUTHBEARER_MECHANISM, + JaasContext.loadClientContext(CONFIGS).configurationEntries()); + } + private static final AuthenticateCallbackHandler VALIDATOR_CALLBACK_HANDLER; + private static final AuthenticateCallbackHandler EXTENSIONS_VALIDATOR_CALLBACK_HANDLER; + static { + VALIDATOR_CALLBACK_HANDLER = new OAuthBearerUnsecuredValidatorCallbackHandler(); + VALIDATOR_CALLBACK_HANDLER.configure(CONFIGS, OAuthBearerLoginModule.OAUTHBEARER_MECHANISM, + JaasContext.loadClientContext(CONFIGS).configurationEntries()); + // only validate extensions "firstKey" and "secondKey" + EXTENSIONS_VALIDATOR_CALLBACK_HANDLER = new OAuthBearerUnsecuredValidatorCallbackHandler() { + @Override + public void handle(Callback[] callbacks) throws UnsupportedCallbackException { + for (Callback callback : callbacks) { + if (callback instanceof 
OAuthBearerValidatorCallback) { + OAuthBearerValidatorCallback validationCallback = (OAuthBearerValidatorCallback) callback; + validationCallback.token(new OAuthBearerTokenMock()); + } else if (callback instanceof OAuthBearerExtensionsValidatorCallback) { + OAuthBearerExtensionsValidatorCallback extensionsCallback = (OAuthBearerExtensionsValidatorCallback) callback; + extensionsCallback.valid("firstKey"); + extensionsCallback.valid("secondKey"); + } else + throw new UnsupportedCallbackException(callback); + } + } + }; + } + private OAuthBearerSaslServer saslServer; + + @BeforeEach + public void setUp() { + saslServer = new OAuthBearerSaslServer(VALIDATOR_CALLBACK_HANDLER); + } + + @Test + public void noAuthorizationIdSpecified() throws Exception { + byte[] nextChallenge = saslServer + .evaluateResponse(clientInitialResponse(null)); + // also asserts that no authentication error is thrown if OAuthBearerExtensionsValidatorCallback is not supported + assertTrue(nextChallenge.length == 0, "Next challenge is not empty"); + } + + @Test + public void negotiatedProperty() throws Exception { + saslServer.evaluateResponse(clientInitialResponse(USER)); + OAuthBearerToken token = (OAuthBearerToken) saslServer.getNegotiatedProperty("OAUTHBEARER.token"); + assertNotNull(token); + assertEquals(token.lifetimeMs(), + saslServer.getNegotiatedProperty(SaslInternalConfigs.CREDENTIAL_LIFETIME_MS_SASL_NEGOTIATED_PROPERTY_KEY)); + } + + /** + * SASL Extensions that are validated by the callback handler should be accessible through the {@code #getNegotiatedProperty()} method + */ + @Test + public void savesCustomExtensionAsNegotiatedProperty() throws Exception { + Map customExtensions = new HashMap<>(); + customExtensions.put("firstKey", "value1"); + customExtensions.put("secondKey", "value2"); + + byte[] nextChallenge = saslServer + .evaluateResponse(clientInitialResponse(null, false, customExtensions)); + + assertTrue(nextChallenge.length == 0, "Next challenge is not empty"); + assertEquals("value1", saslServer.getNegotiatedProperty("firstKey")); + assertEquals("value2", saslServer.getNegotiatedProperty("secondKey")); + } + + /** + * SASL Extensions that were not recognized (neither validated nor invalidated) + * by the callback handler must not be accessible through the {@code #getNegotiatedProperty()} method + */ + @Test + public void unrecognizedExtensionsAreNotSaved() throws Exception { + saslServer = new OAuthBearerSaslServer(EXTENSIONS_VALIDATOR_CALLBACK_HANDLER); + Map customExtensions = new HashMap<>(); + customExtensions.put("firstKey", "value1"); + customExtensions.put("secondKey", "value1"); + customExtensions.put("thirdKey", "value1"); + + byte[] nextChallenge = saslServer + .evaluateResponse(clientInitialResponse(null, false, customExtensions)); + + assertTrue(nextChallenge.length == 0, "Next challenge is not empty"); + assertNull(saslServer.getNegotiatedProperty("thirdKey"), "Extensions not recognized by the server must be ignored"); + } + + /** + * If the callback handler handles the `OAuthBearerExtensionsValidatorCallback` + * and finds an invalid extension, SaslServer should throw an authentication exception + */ + @Test + public void throwsAuthenticationExceptionOnInvalidExtensions() { + OAuthBearerUnsecuredValidatorCallbackHandler invalidHandler = new OAuthBearerUnsecuredValidatorCallbackHandler() { + @Override + public void handle(Callback[] callbacks) throws UnsupportedCallbackException { + for (Callback callback : callbacks) { + if (callback instanceof OAuthBearerValidatorCallback) { + 
OAuthBearerValidatorCallback validationCallback = (OAuthBearerValidatorCallback) callback;
+                        validationCallback.token(new OAuthBearerTokenMock());
+                    } else if (callback instanceof OAuthBearerExtensionsValidatorCallback) {
+                        OAuthBearerExtensionsValidatorCallback extensionsCallback = (OAuthBearerExtensionsValidatorCallback) callback;
+                        extensionsCallback.error("firstKey", "is not valid");
+                        extensionsCallback.error("secondKey", "is not valid either");
+                    } else
+                        throw new UnsupportedCallbackException(callback);
+                }
+            }
+        };
+        saslServer = new OAuthBearerSaslServer(invalidHandler);
+        Map<String, String> customExtensions = new HashMap<>();
+        customExtensions.put("firstKey", "value");
+        customExtensions.put("secondKey", "value");
+
+        assertThrows(SaslAuthenticationException.class,
+            () -> saslServer.evaluateResponse(clientInitialResponse(null, false, customExtensions)));
+    }
+
+    @Test
+    public void authorizatonIdEqualsAuthenticationId() throws Exception {
+        byte[] nextChallenge = saslServer
+                .evaluateResponse(clientInitialResponse(USER));
+        assertTrue(nextChallenge.length == 0, "Next challenge is not empty");
+    }
+
+    @Test
+    public void authorizatonIdNotEqualsAuthenticationId() {
+        assertThrows(SaslAuthenticationException.class,
+            () -> saslServer.evaluateResponse(clientInitialResponse(USER + "x")));
+    }
+
+    @Test
+    public void illegalToken() throws Exception {
+        byte[] bytes = saslServer.evaluateResponse(clientInitialResponse(null, true, Collections.emptyMap()));
+        String challenge = new String(bytes, StandardCharsets.UTF_8);
+        assertEquals("{\"status\":\"invalid_token\"}", challenge);
+    }
+
+    private byte[] clientInitialResponse(String authorizationId)
+            throws OAuthBearerConfigException, IOException, UnsupportedCallbackException, LoginException {
+        return clientInitialResponse(authorizationId, false);
+    }
+
+    private byte[] clientInitialResponse(String authorizationId, boolean illegalToken)
+            throws OAuthBearerConfigException, IOException, UnsupportedCallbackException {
+        return clientInitialResponse(authorizationId, illegalToken, Collections.emptyMap());
+    }
+
+    private byte[] clientInitialResponse(String authorizationId, boolean illegalToken, Map<String, String> customExtensions)
+            throws OAuthBearerConfigException, IOException, UnsupportedCallbackException {
+        OAuthBearerTokenCallback callback = new OAuthBearerTokenCallback();
+        LOGIN_CALLBACK_HANDLER.handle(new Callback[] {callback});
+        OAuthBearerToken token = callback.token();
+        String compactSerialization = token.value();
+
+        String tokenValue = compactSerialization + (illegalToken ? "AB" : "");
+        return new OAuthBearerClientInitialResponse(tokenValue, authorizationId, new SaslExtensions(customExtensions)).toBytes();
+    }
+}
diff --git a/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/internals/expiring/ExpiringCredentialRefreshConfigTest.java b/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/internals/expiring/ExpiringCredentialRefreshConfigTest.java
new file mode 100644
index 0000000..a9b4f13
--- /dev/null
+++ b/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/internals/expiring/ExpiringCredentialRefreshConfigTest.java
@@ -0,0 +1,43 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.security.oauthbearer.internals.expiring; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.util.Collections; + +import org.apache.kafka.common.config.ConfigDef; +import org.apache.kafka.common.config.SaslConfigs; +import org.junit.jupiter.api.Test; + +public class ExpiringCredentialRefreshConfigTest { + @Test + public void fromGoodConfig() { + ExpiringCredentialRefreshConfig expiringCredentialRefreshConfig = new ExpiringCredentialRefreshConfig( + new ConfigDef().withClientSaslSupport().parse(Collections.emptyMap()), true); + assertEquals(Double.valueOf(SaslConfigs.DEFAULT_LOGIN_REFRESH_WINDOW_FACTOR), + Double.valueOf(expiringCredentialRefreshConfig.loginRefreshWindowFactor())); + assertEquals(Double.valueOf(SaslConfigs.DEFAULT_LOGIN_REFRESH_WINDOW_JITTER), + Double.valueOf(expiringCredentialRefreshConfig.loginRefreshWindowJitter())); + assertEquals(Short.valueOf(SaslConfigs.DEFAULT_LOGIN_REFRESH_MIN_PERIOD_SECONDS), + Short.valueOf(expiringCredentialRefreshConfig.loginRefreshMinPeriodSeconds())); + assertEquals(Short.valueOf(SaslConfigs.DEFAULT_LOGIN_REFRESH_BUFFER_SECONDS), + Short.valueOf(expiringCredentialRefreshConfig.loginRefreshBufferSeconds())); + assertTrue(expiringCredentialRefreshConfig.loginRefreshReloginAllowedBeforeLogout()); + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/internals/expiring/ExpiringCredentialRefreshingLoginTest.java b/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/internals/expiring/ExpiringCredentialRefreshingLoginTest.java new file mode 100644 index 0000000..9a77c73 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/internals/expiring/ExpiringCredentialRefreshingLoginTest.java @@ -0,0 +1,775 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.security.oauthbearer.internals.expiring; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.Mockito.inOrder; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import java.util.ArrayList; +import java.util.Date; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; + +import javax.security.auth.Subject; +import javax.security.auth.login.AppConfigurationEntry; +import javax.security.auth.login.Configuration; +import javax.security.auth.login.LoginContext; +import javax.security.auth.login.LoginException; + +import org.apache.kafka.common.config.ConfigDef; +import org.apache.kafka.common.config.SaslConfigs; +import org.apache.kafka.common.internals.KafkaFutureImpl; +import org.apache.kafka.common.security.oauthbearer.internals.expiring.ExpiringCredentialRefreshingLogin.LoginContextFactory; +import org.apache.kafka.common.utils.MockScheduler; +import org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.common.utils.Time; +import org.junit.jupiter.api.Test; +import org.mockito.InOrder; +import org.mockito.Mockito; + +public class ExpiringCredentialRefreshingLoginTest { + private static final Configuration EMPTY_WILDCARD_CONFIGURATION; + static { + EMPTY_WILDCARD_CONFIGURATION = new Configuration() { + @Override + public AppConfigurationEntry[] getAppConfigurationEntry(String name) { + return new AppConfigurationEntry[0]; // match any name + } + }; + } + + /* + * An ExpiringCredentialRefreshingLogin that we can tell explicitly to + * create/remove an expiring credential with specific + * create/expire/absoluteLastRefresh times + */ + private static class TestExpiringCredentialRefreshingLogin extends ExpiringCredentialRefreshingLogin { + private ExpiringCredential expiringCredential; + private ExpiringCredential tmpExpiringCredential; + private final Time time; + private final long lifetimeMillis; + private final long absoluteLastRefreshTimeMs; + private final boolean clientReloginAllowedBeforeLogout; + + public TestExpiringCredentialRefreshingLogin(ExpiringCredentialRefreshConfig refreshConfig, + LoginContextFactory loginContextFactory, Time time, final long lifetimeMillis, + final long absoluteLastRefreshMs, boolean clientReloginAllowedBeforeLogout) { + super("contextName", EMPTY_WILDCARD_CONFIGURATION, refreshConfig, null, + TestExpiringCredentialRefreshingLogin.class, loginContextFactory, Objects.requireNonNull(time)); + this.time = time; + this.lifetimeMillis = lifetimeMillis; + this.absoluteLastRefreshTimeMs = absoluteLastRefreshMs; + this.clientReloginAllowedBeforeLogout = clientReloginAllowedBeforeLogout; + } + + public long getCreateMs() { + return time.milliseconds(); + } + + public long getExpireTimeMs() { + return time.milliseconds() + lifetimeMillis; + } + + /* + * Invoke at login time + */ + public void createNewExpiringCredential() { + if (!clientReloginAllowedBeforeLogout) + /* + * Was preceded by logout + */ + expiringCredential = internalNewExpiringCredential(); + else { + boolean initialLogin = expiringCredential == null; + if (initialLogin) + // no logout immediately after the initial login + this.expiringCredential = internalNewExpiringCredential(); + else + /* + * This is at least the second invocation of login; 
we will move the credential + * over upon logout, which should be invoked next + */ + this.tmpExpiringCredential = internalNewExpiringCredential(); + } + } + + /* + * Invoke at logout time + */ + public void clearExpiringCredential() { + if (!clientReloginAllowedBeforeLogout) + /* + * Have not yet invoked login + */ + expiringCredential = null; + else + /* + * login has already been invoked + */ + expiringCredential = tmpExpiringCredential; + } + + @Override + public ExpiringCredential expiringCredential() { + return expiringCredential; + } + + private ExpiringCredential internalNewExpiringCredential() { + return new ExpiringCredential() { + private final long createMs = getCreateMs(); + private final long expireTimeMs = getExpireTimeMs(); + + @Override + public String principalName() { + return "Created at " + new Date(createMs); + } + + @Override + public Long startTimeMs() { + return createMs; + } + + @Override + public long expireTimeMs() { + return expireTimeMs; + } + + @Override + public Long absoluteLastRefreshTimeMs() { + return absoluteLastRefreshTimeMs; + } + + // useful in debugger + @Override + public String toString() { + return String.format("startTimeMs=%d, expireTimeMs=%d, absoluteLastRefreshTimeMs=%s", startTimeMs(), + expireTimeMs(), absoluteLastRefreshTimeMs()); + } + + }; + } + } + + /* + * A class that will forward all login/logout/getSubject() calls to a mock while + * also telling an instance of TestExpiringCredentialRefreshingLogin to + * create/remove an expiring credential upon login/logout(). Basically we are + * getting the functionality of a mock while simultaneously in the same method + * call performing creation/removal of expiring credentials. + */ + private static class TestLoginContext extends LoginContext { + private final TestExpiringCredentialRefreshingLogin testExpiringCredentialRefreshingLogin; + private final LoginContext mockLoginContext; + + public TestLoginContext(TestExpiringCredentialRefreshingLogin testExpiringCredentialRefreshingLogin, + LoginContext mockLoginContext) throws LoginException { + super("contextName", null, null, EMPTY_WILDCARD_CONFIGURATION); + this.testExpiringCredentialRefreshingLogin = Objects.requireNonNull(testExpiringCredentialRefreshingLogin); + // sanity check to make sure it is likely a mock + if (Objects.requireNonNull(mockLoginContext).getClass().equals(LoginContext.class) + || mockLoginContext.getClass().equals(getClass())) + throw new IllegalArgumentException(); + this.mockLoginContext = mockLoginContext; + } + + @Override + public void login() throws LoginException { + /* + * Here is where we get the functionality of a mock while simultaneously + * performing the creation of an expiring credential + */ + mockLoginContext.login(); + testExpiringCredentialRefreshingLogin.createNewExpiringCredential(); + } + + @Override + public void logout() throws LoginException { + /* + * Here is where we get the functionality of a mock while simultaneously + * performing the removal of an expiring credential + */ + mockLoginContext.logout(); + testExpiringCredentialRefreshingLogin.clearExpiringCredential(); + } + + @Override + public Subject getSubject() { + // here we just need the functionality of a mock + return mockLoginContext.getSubject(); + } + } + + /* + * An implementation of LoginContextFactory that returns an instance of + * TestLoginContext + */ + private static class TestLoginContextFactory extends LoginContextFactory { + private final KafkaFutureImpl refresherThreadStartedFuture = new KafkaFutureImpl<>(); + private 
final KafkaFutureImpl refresherThreadDoneFuture = new KafkaFutureImpl<>(); + private TestLoginContext testLoginContext; + + public void configure(LoginContext mockLoginContext, + TestExpiringCredentialRefreshingLogin testExpiringCredentialRefreshingLogin) throws LoginException { + // sanity check to make sure it is likely a mock + if (Objects.requireNonNull(mockLoginContext).getClass().equals(LoginContext.class) + || mockLoginContext.getClass().equals(TestLoginContext.class)) + throw new IllegalArgumentException(); + this.testLoginContext = new TestLoginContext(Objects.requireNonNull(testExpiringCredentialRefreshingLogin), + mockLoginContext); + } + + @Override + public LoginContext createLoginContext(ExpiringCredentialRefreshingLogin expiringCredentialRefreshingLogin) throws LoginException { + return new LoginContext("", null, null, EMPTY_WILDCARD_CONFIGURATION) { + private boolean loginSuccess = false; + @Override + public void login() throws LoginException { + testLoginContext.login(); + loginSuccess = true; + } + + @Override + public void logout() throws LoginException { + if (!loginSuccess) + // will cause the refresher thread to exit + throw new IllegalStateException("logout called without a successful login"); + testLoginContext.logout(); + } + + @Override + public Subject getSubject() { + return testLoginContext.getSubject(); + } + }; + } + + @Override + public void refresherThreadStarted() { + refresherThreadStartedFuture.complete(null); + } + + @Override + public void refresherThreadDone() { + refresherThreadDoneFuture.complete(null); + } + + public Future refresherThreadStartedFuture() { + return refresherThreadStartedFuture; + } + + public Future refresherThreadDoneFuture() { + return refresherThreadDoneFuture; + } + } + + @Test + public void testRefresh() throws Exception { + for (int numExpectedRefreshes : new int[] {0, 1, 2}) { + for (boolean clientReloginAllowedBeforeLogout : new boolean[] {true, false}) { + Subject subject = new Subject(); + final LoginContext mockLoginContext = mock(LoginContext.class); + when(mockLoginContext.getSubject()).thenReturn(subject); + + MockTime mockTime = new MockTime(); + long startMs = mockTime.milliseconds(); + /* + * Identify the lifetime of each expiring credential + */ + long lifetimeMinutes = 100L; + /* + * Identify the point at which refresh will occur in that lifetime + */ + long refreshEveryMinutes = 80L; + /* + * Set an absolute last refresh time that will cause the login thread to exit + * after a certain number of re-logins (by adding an extra half of a refresh + * interval). + */ + long absoluteLastRefreshMs = startMs + (1 + numExpectedRefreshes) * 1000 * 60 * refreshEveryMinutes + - 1000 * 60 * refreshEveryMinutes / 2; + /* + * Identify buffer time on either side for the refresh algorithm + */ + short minPeriodSeconds = (short) 0; + short bufferSeconds = minPeriodSeconds; + + /* + * Define some listeners so we can keep track of who gets done and when. All + * added listeners should end up done except the last, extra one, which should + * not. 
+                 */
+                MockScheduler mockScheduler = new MockScheduler(mockTime);
+                List<KafkaFutureImpl<Long>> waiters = addWaiters(mockScheduler, 1000 * 60 * refreshEveryMinutes,
+                        numExpectedRefreshes + 1);
+
+                // Create the ExpiringCredentialRefreshingLogin instance under test
+                TestLoginContextFactory testLoginContextFactory = new TestLoginContextFactory();
+                TestExpiringCredentialRefreshingLogin testExpiringCredentialRefreshingLogin = new TestExpiringCredentialRefreshingLogin(
+                        refreshConfigThatPerformsReloginEveryGivenPercentageOfLifetime(
+                                1.0 * refreshEveryMinutes / lifetimeMinutes, minPeriodSeconds, bufferSeconds,
+                                clientReloginAllowedBeforeLogout),
+                        testLoginContextFactory, mockTime, 1000 * 60 * lifetimeMinutes, absoluteLastRefreshMs,
+                        clientReloginAllowedBeforeLogout);
+                testLoginContextFactory.configure(mockLoginContext, testExpiringCredentialRefreshingLogin);
+
+                /*
+                 * Perform the login, wait up to a certain amount of time for the refresher
+                 * thread to exit, and make sure the correct calls happened at the correct times
+                 */
+                long expectedFinalMs = startMs + numExpectedRefreshes * 1000 * 60 * refreshEveryMinutes;
+                assertFalse(testLoginContextFactory.refresherThreadStartedFuture().isDone());
+                assertFalse(testLoginContextFactory.refresherThreadDoneFuture().isDone());
+                testExpiringCredentialRefreshingLogin.login();
+                assertTrue(testLoginContextFactory.refresherThreadStartedFuture().isDone());
+                testLoginContextFactory.refresherThreadDoneFuture().get(1L, TimeUnit.SECONDS);
+                assertEquals(expectedFinalMs, mockTime.milliseconds());
+                for (int i = 0; i < numExpectedRefreshes; ++i) {
+                    KafkaFutureImpl<Long> waiter = waiters.get(i);
+                    assertTrue(waiter.isDone());
+                    assertEquals((i + 1) * 1000 * 60 * refreshEveryMinutes, waiter.get().longValue() - startMs);
+                }
+                assertFalse(waiters.get(numExpectedRefreshes).isDone());
+
+                /*
+                 * We expect login() to be invoked followed by getSubject() and then ultimately followed by
+                 * numExpectedRefreshes pairs of either login()/logout() or logout()/login() calls
+                 */
+                InOrder inOrder = inOrder(mockLoginContext);
+                inOrder.verify(mockLoginContext).login();
+                inOrder.verify(mockLoginContext).getSubject();
+                for (int i = 0; i < numExpectedRefreshes; ++i) {
+                    if (clientReloginAllowedBeforeLogout) {
+                        inOrder.verify(mockLoginContext).login();
+                        inOrder.verify(mockLoginContext).logout();
+                    } else {
+                        inOrder.verify(mockLoginContext).logout();
+                        inOrder.verify(mockLoginContext).login();
+                    }
+                }
+                testExpiringCredentialRefreshingLogin.close();
+            }
+        }
+    }
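+
+    /*
+     * Worked example of the timing arithmetic above: with numExpectedRefreshes = 2 and
+     * refreshEveryMinutes = 80, absoluteLastRefreshMs = startMs + 3 * 80 minutes - 40 minutes
+     * = startMs + 200 minutes. Refreshes are therefore scheduled at startMs + 80 and
+     * startMs + 160 minutes, while the next candidate at startMs + 240 minutes falls past the
+     * absolute last refresh time, so the refresher thread exits after exactly two re-logins.
+     */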
+
+    @Test
+    public void testRefreshWithExpirationSmallerThanConfiguredBuffers() throws Exception {
+        int numExpectedRefreshes = 1;
+        boolean clientReloginAllowedBeforeLogout = true;
+        final LoginContext mockLoginContext = mock(LoginContext.class);
+        Subject subject = new Subject();
+        when(mockLoginContext.getSubject()).thenReturn(subject);
+
+        MockTime mockTime = new MockTime();
+        long startMs = mockTime.milliseconds();
+        /*
+         * Identify the lifetime of each expiring credential
+         */
+        long lifetimeMinutes = 10L;
+        /*
+         * Identify the point at which refresh will occur in that lifetime
+         */
+        long refreshEveryMinutes = 8L;
+        /*
+         * Set an absolute last refresh time that will cause the login thread to exit
+         * after a certain number of re-logins (by adding an extra half of a refresh
+         * interval).
+         */
+        long absoluteLastRefreshMs = startMs + (1 + numExpectedRefreshes) * 1000 * 60 * refreshEveryMinutes
+                - 1000 * 60 * refreshEveryMinutes / 2;
+        /*
+         * Identify buffer time on either side for the refresh algorithm that will cause
+         * the entire lifetime to be taken up. In other words, make sure there is no way
+         * to honor the buffers.
+         */
+        short minPeriodSeconds = (short) (1 + lifetimeMinutes * 60 / 2);
+        short bufferSeconds = minPeriodSeconds;
+
+        /*
+         * Define some listeners so we can keep track of who gets done and when. All
+         * added listeners should end up done except the last, extra one, which should
+         * not.
+         */
+        MockScheduler mockScheduler = new MockScheduler(mockTime);
+        List<KafkaFutureImpl<Long>> waiters = addWaiters(mockScheduler, 1000 * 60 * refreshEveryMinutes,
+                numExpectedRefreshes + 1);
+
+        // Create the ExpiringCredentialRefreshingLogin instance under test
+        TestLoginContextFactory testLoginContextFactory = new TestLoginContextFactory();
+        TestExpiringCredentialRefreshingLogin testExpiringCredentialRefreshingLogin = new TestExpiringCredentialRefreshingLogin(
+                refreshConfigThatPerformsReloginEveryGivenPercentageOfLifetime(
+                        1.0 * refreshEveryMinutes / lifetimeMinutes, minPeriodSeconds, bufferSeconds,
+                        clientReloginAllowedBeforeLogout),
+                testLoginContextFactory, mockTime, 1000 * 60 * lifetimeMinutes, absoluteLastRefreshMs,
+                clientReloginAllowedBeforeLogout);
+        testLoginContextFactory.configure(mockLoginContext, testExpiringCredentialRefreshingLogin);
+
+        /*
+         * Perform the login, wait up to a certain amount of time for the refresher
+         * thread to exit, and make sure the correct calls happened at the correct times
+         */
+        long expectedFinalMs = startMs + numExpectedRefreshes * 1000 * 60 * refreshEveryMinutes;
+        assertFalse(testLoginContextFactory.refresherThreadStartedFuture().isDone());
+        assertFalse(testLoginContextFactory.refresherThreadDoneFuture().isDone());
+        testExpiringCredentialRefreshingLogin.login();
+        assertTrue(testLoginContextFactory.refresherThreadStartedFuture().isDone());
+        testLoginContextFactory.refresherThreadDoneFuture().get(1L, TimeUnit.SECONDS);
+        assertEquals(expectedFinalMs, mockTime.milliseconds());
+        for (int i = 0; i < numExpectedRefreshes; ++i) {
+            KafkaFutureImpl<Long> waiter = waiters.get(i);
+            assertTrue(waiter.isDone());
+            assertEquals((i + 1) * 1000 * 60 * refreshEveryMinutes, waiter.get().longValue() - startMs);
+        }
+        assertFalse(waiters.get(numExpectedRefreshes).isDone());
+
+        InOrder inOrder = inOrder(mockLoginContext);
+        inOrder.verify(mockLoginContext).login();
+        for (int i = 0; i < numExpectedRefreshes; ++i) {
+            inOrder.verify(mockLoginContext).login();
+            inOrder.verify(mockLoginContext).logout();
+        }
+    }
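+
+    /*
+     * The next test reuses the buffer setup above: with lifetimeMinutes = 10, minPeriodSeconds
+     * and bufferSeconds are each 301 seconds, so together they exceed the 600-second credential
+     * lifetime and cannot be honored; the waiters still expect a refresh every 8 minutes (the
+     * 80% window). In addition, getCreateMs() is overridden so the credential reports a start
+     * time an hour in the past.
+     */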
+
+    @Test
+    public void testRefreshWithExpirationSmallerThanConfiguredBuffersAndOlderCreateTime() throws Exception {
+        int numExpectedRefreshes = 1;
+        boolean clientReloginAllowedBeforeLogout = true;
+        final LoginContext mockLoginContext = mock(LoginContext.class);
+        Subject subject = new Subject();
+        when(mockLoginContext.getSubject()).thenReturn(subject);
+
+        MockTime mockTime = new MockTime();
+        long startMs = mockTime.milliseconds();
+        /*
+         * Identify the lifetime of each expiring credential
+         */
+        long lifetimeMinutes = 10L;
+        /*
+         * Identify the point at which refresh will occur in that lifetime
+         */
+        long refreshEveryMinutes = 8L;
+        /*
+         * Set an absolute last refresh time that will cause the login thread to exit
+         * after a certain number of re-logins (by adding an extra half of a refresh
+         * interval).
+         */
+        long absoluteLastRefreshMs = startMs + (1 + numExpectedRefreshes) * 1000 * 60 * refreshEveryMinutes
+                - 1000 * 60 * refreshEveryMinutes / 2;
+        /*
+         * Identify buffer time on either side for the refresh algorithm that will cause
+         * the entire lifetime to be taken up. In other words, make sure there is no way
+         * to honor the buffers.
+         */
+        short minPeriodSeconds = (short) (1 + lifetimeMinutes * 60 / 2);
+        short bufferSeconds = minPeriodSeconds;
+
+        /*
+         * Define some listeners so we can keep track of who gets done and when. All
+         * added listeners should end up done except the last, extra one, which should
+         * not.
+         */
+        MockScheduler mockScheduler = new MockScheduler(mockTime);
+        List<KafkaFutureImpl<Long>> waiters = addWaiters(mockScheduler, 1000 * 60 * refreshEveryMinutes,
+                numExpectedRefreshes + 1);
+
+        // Create the ExpiringCredentialRefreshingLogin instance under test
+        TestLoginContextFactory testLoginContextFactory = new TestLoginContextFactory();
+        TestExpiringCredentialRefreshingLogin testExpiringCredentialRefreshingLogin = new TestExpiringCredentialRefreshingLogin(
+                refreshConfigThatPerformsReloginEveryGivenPercentageOfLifetime(
+                        1.0 * refreshEveryMinutes / lifetimeMinutes, minPeriodSeconds, bufferSeconds,
+                        clientReloginAllowedBeforeLogout),
+                testLoginContextFactory, mockTime, 1000 * 60 * lifetimeMinutes, absoluteLastRefreshMs,
+                clientReloginAllowedBeforeLogout) {
+
+            @Override
+            public long getCreateMs() {
+                return super.getCreateMs() - 1000 * 60 * 60; // distant past
+            }
+        };
+        testLoginContextFactory.configure(mockLoginContext, testExpiringCredentialRefreshingLogin);
+
+        /*
+         * Perform the login, wait up to a certain amount of time for the refresher
+         * thread to exit, and make sure the correct calls happened at the correct times
+         */
+        long expectedFinalMs = startMs + numExpectedRefreshes * 1000 * 60 * refreshEveryMinutes;
+        assertFalse(testLoginContextFactory.refresherThreadStartedFuture().isDone());
+        assertFalse(testLoginContextFactory.refresherThreadDoneFuture().isDone());
+        testExpiringCredentialRefreshingLogin.login();
+        assertTrue(testLoginContextFactory.refresherThreadStartedFuture().isDone());
+        testLoginContextFactory.refresherThreadDoneFuture().get(1L, TimeUnit.SECONDS);
+        assertEquals(expectedFinalMs, mockTime.milliseconds());
+        for (int i = 0; i < numExpectedRefreshes; ++i) {
+            KafkaFutureImpl<Long> waiter = waiters.get(i);
+            assertTrue(waiter.isDone());
+            assertEquals((i + 1) * 1000 * 60 * refreshEveryMinutes, waiter.get().longValue() - startMs);
+        }
+        assertFalse(waiters.get(numExpectedRefreshes).isDone());
+
+        InOrder inOrder = inOrder(mockLoginContext);
+        inOrder.verify(mockLoginContext).login();
+        for (int i = 0; i < numExpectedRefreshes; ++i) {
+            inOrder.verify(mockLoginContext).login();
+            inOrder.verify(mockLoginContext).logout();
+        }
+    }
+
+    @Test
+    public void testRefreshWithMinPeriodIntrusion() throws Exception {
+        int numExpectedRefreshes = 1;
+        boolean clientReloginAllowedBeforeLogout = true;
+        Subject subject = new Subject();
+        final LoginContext mockLoginContext = mock(LoginContext.class);
+        when(mockLoginContext.getSubject()).thenReturn(subject);
+
+        MockTime mockTime = new MockTime();
+        long startMs = mockTime.milliseconds();
+        /*
+         * Identify the lifetime of each expiring credential
+         */
+        long lifetimeMinutes = 10L;
+        /*
+         * Identify the point at which refresh will occur in that lifetime
+         */
+        long refreshEveryMinutes = 8L;
+        /*
+         * Set an absolute last refresh time that will cause the login thread to exit
+         * after a certain number of re-logins (by adding an extra half of a refresh
+         * interval).
+         */
+        long absoluteLastRefreshMs = startMs + (1 + numExpectedRefreshes) * 1000 * 60 * refreshEveryMinutes
+                - 1000 * 60 * refreshEveryMinutes / 2;
+
+        /*
+         * Identify a minimum period that will cause the refresh time to be delayed a
+         * bit.
+         */
+        int bufferIntrusionSeconds = 1;
+        short minPeriodSeconds = (short) (refreshEveryMinutes * 60 + bufferIntrusionSeconds);
+        short bufferSeconds = (short) 0;
+
+        /*
+         * Define some listeners so we can keep track of who gets done and when. All
+         * added listeners should end up done except the last, extra one, which should
+         * not.
+         */
+        MockScheduler mockScheduler = new MockScheduler(mockTime);
+        List<KafkaFutureImpl<Long>> waiters = addWaiters(mockScheduler,
+                1000 * (60 * refreshEveryMinutes + bufferIntrusionSeconds), numExpectedRefreshes + 1);
+
+        // Create the ExpiringCredentialRefreshingLogin instance under test
+        TestLoginContextFactory testLoginContextFactory = new TestLoginContextFactory();
+        TestExpiringCredentialRefreshingLogin testExpiringCredentialRefreshingLogin = new TestExpiringCredentialRefreshingLogin(
+                refreshConfigThatPerformsReloginEveryGivenPercentageOfLifetime(
+                        1.0 * refreshEveryMinutes / lifetimeMinutes, minPeriodSeconds, bufferSeconds,
+                        clientReloginAllowedBeforeLogout),
+                testLoginContextFactory, mockTime, 1000 * 60 * lifetimeMinutes, absoluteLastRefreshMs,
+                clientReloginAllowedBeforeLogout);
+        testLoginContextFactory.configure(mockLoginContext, testExpiringCredentialRefreshingLogin);
+
+        /*
+         * Perform the login, wait up to a certain amount of time for the refresher
+         * thread to exit, and make sure the correct calls happened at the correct times
+         */
+        long expectedFinalMs = startMs
+                + numExpectedRefreshes * 1000 * (60 * refreshEveryMinutes + bufferIntrusionSeconds);
+        assertFalse(testLoginContextFactory.refresherThreadStartedFuture().isDone());
+        assertFalse(testLoginContextFactory.refresherThreadDoneFuture().isDone());
+        testExpiringCredentialRefreshingLogin.login();
+        assertTrue(testLoginContextFactory.refresherThreadStartedFuture().isDone());
+        testLoginContextFactory.refresherThreadDoneFuture().get(1L, TimeUnit.SECONDS);
+        assertEquals(expectedFinalMs, mockTime.milliseconds());
+        for (int i = 0; i < numExpectedRefreshes; ++i) {
+            KafkaFutureImpl<Long> waiter = waiters.get(i);
+            assertTrue(waiter.isDone());
+            assertEquals((i + 1) * 1000 * (60 * refreshEveryMinutes + bufferIntrusionSeconds),
+                    waiter.get().longValue() - startMs);
+        }
+        assertFalse(waiters.get(numExpectedRefreshes).isDone());
+
+        InOrder inOrder = inOrder(mockLoginContext);
+        inOrder.verify(mockLoginContext).login();
+        for (int i = 0; i < numExpectedRefreshes; ++i) {
+            inOrder.verify(mockLoginContext).login();
+            inOrder.verify(mockLoginContext).logout();
+        }
+    }
+
+    @Test
+    public void testRefreshWithPreExpirationBufferIntrusion() throws Exception {
+        int numExpectedRefreshes = 1;
+        boolean clientReloginAllowedBeforeLogout = true;
+        Subject subject = new Subject();
+        final LoginContext mockLoginContext = mock(LoginContext.class);
+        when(mockLoginContext.getSubject()).thenReturn(subject);
+
+        MockTime mockTime = new MockTime();
+        long startMs = mockTime.milliseconds();
+        /*
+         * Identify the lifetime of each expiring credential
+         */
+        long lifetimeMinutes = 10L;
+        /*
+         * Identify the point at which refresh will occur in that lifetime
+         */
+        long refreshEveryMinutes = 8L;
+        /*
+         * Set an absolute last refresh time that will cause the login thread to exit
+         * after a certain number of re-logins (by adding an extra half of a refresh
+         * interval).
+         */
+        long absoluteLastRefreshMs = startMs + (1 + numExpectedRefreshes) * 1000 * 60 * refreshEveryMinutes
+                - 1000 * 60 * refreshEveryMinutes / 2;
+        /*
+         * Identify a minimum period that will cause the refresh time to be delayed a
+         * bit.
+         */
+        int bufferIntrusionSeconds = 1;
+        short bufferSeconds = (short) ((lifetimeMinutes - refreshEveryMinutes) * 60 + bufferIntrusionSeconds);
+        short minPeriodSeconds = (short) 0;
+
+        /*
+         * Define some listeners so we can keep track of who gets done and when. All
+         * added listeners should end up done except the last, extra one, which should
+         * not.
+         */
+        MockScheduler mockScheduler = new MockScheduler(mockTime);
+        List<KafkaFutureImpl<Long>> waiters = addWaiters(mockScheduler,
+                1000 * (60 * refreshEveryMinutes - bufferIntrusionSeconds), numExpectedRefreshes + 1);
+
+        // Create the ExpiringCredentialRefreshingLogin instance under test
+        TestLoginContextFactory testLoginContextFactory = new TestLoginContextFactory();
+        TestExpiringCredentialRefreshingLogin testExpiringCredentialRefreshingLogin = new TestExpiringCredentialRefreshingLogin(
+                refreshConfigThatPerformsReloginEveryGivenPercentageOfLifetime(
+                        1.0 * refreshEveryMinutes / lifetimeMinutes, minPeriodSeconds, bufferSeconds,
+                        clientReloginAllowedBeforeLogout),
+                testLoginContextFactory, mockTime, 1000 * 60 * lifetimeMinutes, absoluteLastRefreshMs,
+                clientReloginAllowedBeforeLogout);
+        testLoginContextFactory.configure(mockLoginContext, testExpiringCredentialRefreshingLogin);
+
+        /*
+         * Perform the login, wait up to a certain amount of time for the refresher
+         * thread to exit, and make sure the correct calls happened at the correct times
+         */
+        long expectedFinalMs = startMs
+                + numExpectedRefreshes * 1000 * (60 * refreshEveryMinutes - bufferIntrusionSeconds);
+        assertFalse(testLoginContextFactory.refresherThreadStartedFuture().isDone());
+        assertFalse(testLoginContextFactory.refresherThreadDoneFuture().isDone());
+        testExpiringCredentialRefreshingLogin.login();
+        assertTrue(testLoginContextFactory.refresherThreadStartedFuture().isDone());
+        testLoginContextFactory.refresherThreadDoneFuture().get(1L, TimeUnit.SECONDS);
+        assertEquals(expectedFinalMs, mockTime.milliseconds());
+        for (int i = 0; i < numExpectedRefreshes; ++i) {
+            KafkaFutureImpl<Long> waiter = waiters.get(i);
+            assertTrue(waiter.isDone());
+            assertEquals((i + 1) * 1000 * (60 * refreshEveryMinutes - bufferIntrusionSeconds),
+                    waiter.get().longValue() - startMs);
+        }
+        assertFalse(waiters.get(numExpectedRefreshes).isDone());
+
+        InOrder inOrder = inOrder(mockLoginContext);
+        inOrder.verify(mockLoginContext).login();
+        for (int i = 0; i < numExpectedRefreshes; ++i) {
+            inOrder.verify(mockLoginContext).login();
+            inOrder.verify(mockLoginContext).logout();
+        }
+    }
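+
+    /*
+     * In the test below, the stubbing doNothing().doThrow(new LoginException()).doNothing()
+     * makes the first login() succeed, the second one fail, and later ones succeed again. The
+     * LoginContext wrapper created by TestLoginContextFactory throws IllegalStateException if
+     * logout() is called without a prior successful login(), which would kill the refresher
+     * thread; the get(1L, TimeUnit.SECONDS) call below therefore times out if the refresher
+     * logs out a context whose login() failed.
+     */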
+
+    @Test
+    public void testLoginExceptionCausesCorrectLogout() throws Exception {
+        int numExpectedRefreshes = 3;
+        boolean clientReloginAllowedBeforeLogout = true;
+        Subject subject = new Subject();
+        final LoginContext mockLoginContext = mock(LoginContext.class);
+        when(mockLoginContext.getSubject()).thenReturn(subject);
+        Mockito.doNothing().doThrow(new LoginException()).doNothing().when(mockLoginContext).login();
+
+        MockTime mockTime = new MockTime();
+        long startMs = mockTime.milliseconds();
+        /*
+         * Identify the lifetime of each expiring credential
+         */
+        long lifetimeMinutes = 100L;
+        /*
+         * Identify the point at which refresh will occur in that lifetime
+         */
+        long refreshEveryMinutes = 80L;
+        /*
+         * Set an absolute last refresh time that will cause the login thread to exit
+         * after a certain number of re-logins (by adding an extra half of a refresh
+         * interval).
+         */
+        long absoluteLastRefreshMs = startMs + (1 + numExpectedRefreshes) * 1000 * 60 * refreshEveryMinutes
+                - 1000 * 60 * refreshEveryMinutes / 2;
+        /*
+         * Identify buffer time on either side for the refresh algorithm
+         */
+        short minPeriodSeconds = (short) 0;
+        short bufferSeconds = minPeriodSeconds;
+
+        // Create the ExpiringCredentialRefreshingLogin instance under test
+        TestLoginContextFactory testLoginContextFactory = new TestLoginContextFactory();
+        TestExpiringCredentialRefreshingLogin testExpiringCredentialRefreshingLogin = new TestExpiringCredentialRefreshingLogin(
+                refreshConfigThatPerformsReloginEveryGivenPercentageOfLifetime(
+                        1.0 * refreshEveryMinutes / lifetimeMinutes, minPeriodSeconds, bufferSeconds,
+                        clientReloginAllowedBeforeLogout),
+                testLoginContextFactory, mockTime, 1000 * 60 * lifetimeMinutes, absoluteLastRefreshMs,
+                clientReloginAllowedBeforeLogout);
+        testLoginContextFactory.configure(mockLoginContext, testExpiringCredentialRefreshingLogin);
+
+        /*
+         * Perform the login and wait up to a certain amount of time for the refresher
+         * thread to exit. A timeout indicates the thread died due to logout()
+         * being invoked on an instance where the login() invocation had failed.
+         */
+        assertFalse(testLoginContextFactory.refresherThreadStartedFuture().isDone());
+        assertFalse(testLoginContextFactory.refresherThreadDoneFuture().isDone());
+        testExpiringCredentialRefreshingLogin.login();
+        assertTrue(testLoginContextFactory.refresherThreadStartedFuture().isDone());
+        testLoginContextFactory.refresherThreadDoneFuture().get(1L, TimeUnit.SECONDS);
+    }
+
+    private static List<KafkaFutureImpl<Long>> addWaiters(MockScheduler mockScheduler, long refreshEveryMillis,
+            int numWaiters) {
+        List<KafkaFutureImpl<Long>> retvalWaiters = new ArrayList<>(numWaiters);
+        for (int i = 1; i <= numWaiters; ++i) {
+            KafkaFutureImpl<Long> waiter = new KafkaFutureImpl<>();
+            mockScheduler.addWaiter(i * refreshEveryMillis, waiter);
+            retvalWaiters.add(waiter);
+        }
+        return retvalWaiters;
+    }
+
+    private static ExpiringCredentialRefreshConfig refreshConfigThatPerformsReloginEveryGivenPercentageOfLifetime(
+            double refreshWindowFactor, short minPeriodSeconds, short bufferSeconds,
+            boolean clientReloginAllowedBeforeLogout) {
+        Map<String, Object> configs = new HashMap<>();
+        configs.put(SaslConfigs.SASL_LOGIN_REFRESH_WINDOW_FACTOR, refreshWindowFactor);
+        configs.put(SaslConfigs.SASL_LOGIN_REFRESH_WINDOW_JITTER, 0);
+        configs.put(SaslConfigs.SASL_LOGIN_REFRESH_MIN_PERIOD_SECONDS, minPeriodSeconds);
+        configs.put(SaslConfigs.SASL_LOGIN_REFRESH_BUFFER_SECONDS, bufferSeconds);
+        return new ExpiringCredentialRefreshConfig(new ConfigDef().withClientSaslSupport().parse(configs),
+                clientReloginAllowedBeforeLogout);
+    }
+}
diff --git a/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/internals/unsecured/OAuthBearerScopeUtilsTest.java b/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/internals/unsecured/OAuthBearerScopeUtilsTest.java
new file mode 100644
index 0000000..f65440e
--- /dev/null
+++ b/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/internals/unsecured/OAuthBearerScopeUtilsTest.java
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.security.oauthbearer.internals.unsecured; + +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; + +import java.util.List; + +import org.apache.kafka.common.utils.Utils; +import org.junit.jupiter.api.Test; + +public class OAuthBearerScopeUtilsTest { + @Test + public void validScope() { + for (String validScope : new String[] {"", " ", "scope1", " scope1 ", "scope1 Scope2", "scope1 Scope2"}) { + List parsedScope = OAuthBearerScopeUtils.parseScope(validScope); + if (Utils.isBlank(validScope)) { + assertTrue(parsedScope.isEmpty()); + } else if (validScope.contains("Scope2")) { + assertTrue(parsedScope.size() == 2 && parsedScope.get(0).equals("scope1") + && parsedScope.get(1).equals("Scope2")); + } else { + assertTrue(parsedScope.size() == 1 && parsedScope.get(0).equals("scope1")); + } + } + } + + @Test + public void invalidScope() { + for (String invalidScope : new String[] {"\"foo", "\\foo"}) { + try { + OAuthBearerScopeUtils.parseScope(invalidScope); + fail("did not detect invalid scope: " + invalidScope); + } catch (OAuthBearerConfigException expected) { + // empty + } + } + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/internals/unsecured/OAuthBearerUnsecuredJwsTest.java b/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/internals/unsecured/OAuthBearerUnsecuredJwsTest.java new file mode 100644 index 0000000..af259c6 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/internals/unsecured/OAuthBearerUnsecuredJwsTest.java @@ -0,0 +1,159 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.security.oauthbearer.internals.unsecured; + +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.nio.charset.StandardCharsets; +import java.util.Arrays; +import java.util.Base64; +import java.util.Base64.Encoder; +import java.util.HashSet; +import java.util.List; + +public class OAuthBearerUnsecuredJwsTest { + private static final String QUOTE = "\""; + private static final String HEADER_COMPACT_SERIALIZATION = Base64.getUrlEncoder().withoutPadding() + .encodeToString("{\"alg\":\"none\"}".getBytes(StandardCharsets.UTF_8)) + "."; + + @Test + public void validClaims() throws OAuthBearerIllegalTokenException { + double issuedAtSeconds = 100.1; + double expirationTimeSeconds = 300.3; + StringBuilder sb = new StringBuilder("{"); + appendJsonText(sb, "sub", "SUBJECT"); + appendCommaJsonText(sb, "iat", issuedAtSeconds); + appendCommaJsonText(sb, "exp", expirationTimeSeconds); + sb.append("}"); + String compactSerialization = HEADER_COMPACT_SERIALIZATION + + Base64.getUrlEncoder().withoutPadding().encodeToString(sb.toString().getBytes(StandardCharsets.UTF_8)) + + "."; + OAuthBearerUnsecuredJws testJwt = new OAuthBearerUnsecuredJws(compactSerialization, "sub", "scope"); + assertEquals(compactSerialization, testJwt.value()); + assertEquals("sub", testJwt.principalClaimName()); + assertEquals(1, testJwt.header().size()); + assertEquals("none", testJwt.header().get("alg")); + assertEquals("scope", testJwt.scopeClaimName()); + assertEquals(expirationTimeSeconds, testJwt.expirationTime()); + assertTrue(testJwt.isClaimType("exp", Number.class)); + assertEquals(issuedAtSeconds, testJwt.issuedAt()); + assertEquals("SUBJECT", testJwt.subject()); + } + + @Test + public void validCompactSerialization() { + String subject = "foo"; + long issuedAt = 100; + long expirationTime = issuedAt + 60 * 60; + List scope = Arrays.asList("scopeValue1", "scopeValue2"); + String validCompactSerialization = compactSerialization(subject, issuedAt, expirationTime, scope); + OAuthBearerUnsecuredJws jws = new OAuthBearerUnsecuredJws(validCompactSerialization, "sub", "scope"); + assertEquals(1, jws.header().size()); + assertEquals("none", jws.header().get("alg")); + assertEquals(4, jws.claims().size()); + assertEquals(subject, jws.claims().get("sub")); + assertEquals(subject, jws.principalName()); + assertEquals(issuedAt, Number.class.cast(jws.claims().get("iat")).longValue()); + assertEquals(expirationTime, Number.class.cast(jws.claims().get("exp")).longValue()); + assertEquals(expirationTime * 1000, jws.lifetimeMs()); + assertEquals(scope, jws.claims().get("scope")); + assertEquals(new HashSet<>(scope), jws.scope()); + assertEquals(3, jws.splits().size()); + assertEquals(validCompactSerialization.split("\\.")[0], jws.splits().get(0)); + assertEquals(validCompactSerialization.split("\\.")[1], jws.splits().get(1)); + assertEquals("", jws.splits().get(2)); + } + + @Test + public void missingPrincipal() { + String subject = null; + long issuedAt = 100; + Long expirationTime = null; + List scope = Arrays.asList("scopeValue1", "scopeValue2"); + String validCompactSerialization = compactSerialization(subject, issuedAt, expirationTime, scope); + assertThrows(OAuthBearerIllegalTokenException.class, + () -> new OAuthBearerUnsecuredJws(validCompactSerialization, "sub", "scope")); + } + + @Test + public void 
blankPrincipalName() {
+        String subject = " ";
+        long issuedAt = 100;
+        long expirationTime = issuedAt + 60 * 60;
+        List<String> scope = Arrays.asList("scopeValue1", "scopeValue2");
+        String validCompactSerialization = compactSerialization(subject, issuedAt, expirationTime, scope);
+        assertThrows(OAuthBearerIllegalTokenException.class,
+            () -> new OAuthBearerUnsecuredJws(validCompactSerialization, "sub", "scope"));
+    }
+
+    private static String compactSerialization(String subject, Long issuedAt, Long expirationTime, List<String> scope) {
+        Encoder encoder = Base64.getUrlEncoder().withoutPadding();
+        String algorithm = "none";
+        String headerJson = "{\"alg\":\"" + algorithm + "\"}";
+        String encodedHeader = encoder.encodeToString(headerJson.getBytes(StandardCharsets.UTF_8));
+        String subjectJson = subject != null ? "\"sub\":\"" + subject + "\"" : null;
+        String issuedAtJson = issuedAt != null ? "\"iat\":" + issuedAt.longValue() : null;
+        String expirationTimeJson = expirationTime != null ? "\"exp\":" + expirationTime.longValue() : null;
+        String scopeJson = scope != null ? scopeJson(scope) : null;
+        String claimsJson = claimsJson(subjectJson, issuedAtJson, expirationTimeJson, scopeJson);
+        String encodedClaims = encoder.encodeToString(claimsJson.getBytes(StandardCharsets.UTF_8));
+        return encodedHeader + "." + encodedClaims + ".";
+    }
+
+    private static String claimsJson(String... jsonValues) {
+        StringBuilder claimsJsonBuilder = new StringBuilder("{");
+        int initialLength = claimsJsonBuilder.length();
+        for (String jsonValue : jsonValues) {
+            if (jsonValue != null) {
+                if (claimsJsonBuilder.length() > initialLength)
+                    claimsJsonBuilder.append(',');
+                claimsJsonBuilder.append(jsonValue);
+            }
+        }
+        claimsJsonBuilder.append('}');
+        return claimsJsonBuilder.toString();
+    }
+
+    private static String scopeJson(List<String> scope) {
+        StringBuilder scopeJsonBuilder = new StringBuilder("\"scope\":[");
+        int initialLength = scopeJsonBuilder.length();
+        for (String scopeValue : scope) {
+            if (scopeJsonBuilder.length() > initialLength)
+                scopeJsonBuilder.append(',');
+            scopeJsonBuilder.append('"').append(scopeValue).append('"');
+        }
+        scopeJsonBuilder.append(']');
+        return scopeJsonBuilder.toString();
+    }
+
+    private static void appendCommaJsonText(StringBuilder sb, String claimName, Number claimValue) {
+        sb.append(',').append(QUOTE).append(escape(claimName)).append(QUOTE).append(":").append(claimValue);
+    }
+
+    private static void appendJsonText(StringBuilder sb, String claimName, String claimValue) {
+        sb.append(QUOTE).append(escape(claimName)).append(QUOTE).append(":").append(QUOTE).append(escape(claimValue))
+                .append(QUOTE);
+    }
+
+    private static String escape(String jsonStringValue) {
+        // escape backslashes first so the backslashes added for quotes are not themselves escaped again
+        return jsonStringValue.replace("\\", "\\\\").replace("\"", "\\\"");
+    }
+}
diff --git a/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/internals/unsecured/OAuthBearerUnsecuredLoginCallbackHandlerTest.java b/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/internals/unsecured/OAuthBearerUnsecuredLoginCallbackHandlerTest.java
new file mode 100644
index 0000000..443a1de
--- /dev/null
+++ b/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/internals/unsecured/OAuthBearerUnsecuredLoginCallbackHandlerTest.java
@@ -0,0 +1,161 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.security.oauthbearer.internals.unsecured; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertThrows; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; + +import javax.security.auth.callback.Callback; +import javax.security.auth.callback.UnsupportedCallbackException; + +import org.apache.kafka.common.security.auth.SaslExtensionsCallback; +import org.apache.kafka.common.security.authenticator.TestJaasConfig; +import org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginModule; +import org.apache.kafka.common.security.oauthbearer.OAuthBearerTokenCallback; +import org.apache.kafka.common.utils.MockTime; +import org.junit.jupiter.api.Test; + +public class OAuthBearerUnsecuredLoginCallbackHandlerTest { + + @Test + public void addsExtensions() throws IOException, UnsupportedCallbackException { + Map options = new HashMap<>(); + options.put("unsecuredLoginExtension_testId", "1"); + OAuthBearerUnsecuredLoginCallbackHandler callbackHandler = createCallbackHandler(options, new MockTime()); + SaslExtensionsCallback callback = new SaslExtensionsCallback(); + + callbackHandler.handle(new Callback[] {callback}); + + assertEquals("1", callback.extensions().map().get("testId")); + } + + @Test + public void throwsErrorOnInvalidExtensionName() { + Map options = new HashMap<>(); + options.put("unsecuredLoginExtension_test.Id", "1"); + OAuthBearerUnsecuredLoginCallbackHandler callbackHandler = createCallbackHandler(options, new MockTime()); + SaslExtensionsCallback callback = new SaslExtensionsCallback(); + + assertThrows(IOException.class, () -> callbackHandler.handle(new Callback[] {callback})); + } + + @Test + public void throwsErrorOnInvalidExtensionValue() { + Map options = new HashMap<>(); + options.put("unsecuredLoginExtension_testId", "Çalifornia"); + OAuthBearerUnsecuredLoginCallbackHandler callbackHandler = createCallbackHandler(options, new MockTime()); + SaslExtensionsCallback callback = new SaslExtensionsCallback(); + + assertThrows(IOException.class, () -> callbackHandler.handle(new Callback[] {callback})); + } + + @Test + public void minimalToken() throws IOException, UnsupportedCallbackException { + Map options = new HashMap<>(); + String user = "user"; + options.put("unsecuredLoginStringClaim_sub", user); + MockTime mockTime = new MockTime(); + OAuthBearerUnsecuredLoginCallbackHandler callbackHandler = createCallbackHandler(options, mockTime); + OAuthBearerTokenCallback callback = new OAuthBearerTokenCallback(); + callbackHandler.handle(new Callback[] {callback}); + OAuthBearerUnsecuredJws jws = (OAuthBearerUnsecuredJws) callback.token(); + assertNotNull(jws, "create token failed"); + long startMs = 
mockTime.milliseconds(); + confirmCorrectValues(jws, user, startMs, 1000 * 60 * 60); + assertEquals(new HashSet<>(Arrays.asList("sub", "iat", "exp")), jws.claims().keySet()); + } + + @SuppressWarnings("unchecked") + @Test + public void validOptionsWithExplicitOptionValues() + throws IOException, UnsupportedCallbackException { + String explicitScope1 = "scope1"; + String explicitScope2 = "scope2"; + String explicitScopeClaimName = "putScopeInHere"; + String principalClaimName = "principal"; + final String[] scopeClaimNameOptionValues = {null, explicitScopeClaimName}; + for (String scopeClaimNameOptionValue : scopeClaimNameOptionValues) { + Map options = new HashMap<>(); + String user = "user"; + options.put("unsecuredLoginStringClaim_" + principalClaimName, user); + options.put("unsecuredLoginListClaim_" + "list", ",1,2,"); + options.put("unsecuredLoginListClaim_" + "emptyList1", ""); + options.put("unsecuredLoginListClaim_" + "emptyList2", ","); + options.put("unsecuredLoginNumberClaim_" + "number", "1"); + long lifetmeSeconds = 10000; + options.put("unsecuredLoginLifetimeSeconds", String.valueOf(lifetmeSeconds)); + options.put("unsecuredLoginPrincipalClaimName", principalClaimName); + if (scopeClaimNameOptionValue != null) + options.put("unsecuredLoginScopeClaimName", scopeClaimNameOptionValue); + String actualScopeClaimName = scopeClaimNameOptionValue == null ? "scope" : explicitScopeClaimName; + options.put("unsecuredLoginListClaim_" + actualScopeClaimName, + String.format("|%s|%s", explicitScope1, explicitScope2)); + MockTime mockTime = new MockTime(); + OAuthBearerUnsecuredLoginCallbackHandler callbackHandler = createCallbackHandler(options, mockTime); + OAuthBearerTokenCallback callback = new OAuthBearerTokenCallback(); + callbackHandler.handle(new Callback[] {callback}); + OAuthBearerUnsecuredJws jws = (OAuthBearerUnsecuredJws) callback.token(); + assertNotNull(jws, "create token failed"); + long startMs = mockTime.milliseconds(); + confirmCorrectValues(jws, user, startMs, lifetmeSeconds * 1000); + Map claims = jws.claims(); + assertEquals(new HashSet<>(Arrays.asList(actualScopeClaimName, principalClaimName, "iat", "exp", "number", + "list", "emptyList1", "emptyList2")), claims.keySet()); + assertEquals(new HashSet<>(Arrays.asList(explicitScope1, explicitScope2)), + new HashSet<>((List) claims.get(actualScopeClaimName))); + assertEquals(new HashSet<>(Arrays.asList(explicitScope1, explicitScope2)), jws.scope()); + assertEquals(1.0, jws.claim("number", Number.class)); + assertEquals(Arrays.asList("1", "2", ""), jws.claim("list", List.class)); + assertEquals(Collections.emptyList(), jws.claim("emptyList1", List.class)); + assertEquals(Collections.emptyList(), jws.claim("emptyList2", List.class)); + } + } + + @SuppressWarnings({"unchecked", "rawtypes"}) + private static OAuthBearerUnsecuredLoginCallbackHandler createCallbackHandler(Map options, + MockTime mockTime) { + TestJaasConfig config = new TestJaasConfig(); + config.createOrUpdateEntry("KafkaClient", "org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginModule", + (Map) options); + OAuthBearerUnsecuredLoginCallbackHandler callbackHandler = new OAuthBearerUnsecuredLoginCallbackHandler(); + callbackHandler.time(mockTime); + callbackHandler.configure(Collections.emptyMap(), OAuthBearerLoginModule.OAUTHBEARER_MECHANISM, + Arrays.asList(config.getAppConfigurationEntry("KafkaClient")[0])); + return callbackHandler; + } + + private static void confirmCorrectValues(OAuthBearerUnsecuredJws jws, String user, long startMs, + long 
lifetimeMs) throws OAuthBearerIllegalTokenException { + Map header = jws.header(); + assertEquals(1, header.size()); + assertEquals("none", header.get("alg")); + assertEquals(user != null ? user : "", jws.principalName()); + assertEquals(Long.valueOf(startMs), jws.startTimeMs()); + assertEquals(startMs, Math.round(jws.issuedAt().doubleValue() * 1000)); + assertEquals(startMs + lifetimeMs, jws.lifetimeMs()); + assertEquals(jws.lifetimeMs(), Math.round(jws.expirationTime().doubleValue() * 1000)); + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/internals/unsecured/OAuthBearerUnsecuredValidatorCallbackHandlerTest.java b/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/internals/unsecured/OAuthBearerUnsecuredValidatorCallbackHandlerTest.java new file mode 100644 index 0000000..bc1b660 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/internals/unsecured/OAuthBearerUnsecuredValidatorCallbackHandlerTest.java @@ -0,0 +1,182 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.security.oauthbearer.internals.unsecured; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.Arrays; +import java.util.Base64; +import java.util.Base64.Encoder; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +import javax.security.auth.callback.Callback; +import javax.security.auth.callback.UnsupportedCallbackException; + +import org.apache.kafka.common.security.authenticator.TestJaasConfig; +import org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginModule; +import org.apache.kafka.common.security.oauthbearer.OAuthBearerValidatorCallback; +import org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.common.utils.Time; +import org.junit.jupiter.api.Test; + +public class OAuthBearerUnsecuredValidatorCallbackHandlerTest { + private static final String UNSECURED_JWT_HEADER_JSON = "{" + claimOrHeaderText("alg", "none") + "}"; + private static final Time MOCK_TIME = new MockTime(); + private static final String QUOTE = "\""; + private static final String PRINCIPAL_CLAIM_VALUE = "username"; + private static final String PRINCIPAL_CLAIM_TEXT = claimOrHeaderText("principal", PRINCIPAL_CLAIM_VALUE); + private static final String SUB_CLAIM_TEXT = claimOrHeaderText("sub", PRINCIPAL_CLAIM_VALUE); + private static final String BAD_PRINCIPAL_CLAIM_TEXT = claimOrHeaderText("principal", 1); + private static final long LIFETIME_SECONDS_TO_USE = 1000 * 60 * 60; + private static final String EXPIRATION_TIME_CLAIM_TEXT = expClaimText(LIFETIME_SECONDS_TO_USE); + private static final String TOO_EARLY_EXPIRATION_TIME_CLAIM_TEXT = expClaimText(0); + private static final String ISSUED_AT_CLAIM_TEXT = claimOrHeaderText("iat", MOCK_TIME.milliseconds() / 1000.0); + private static final String SCOPE_CLAIM_TEXT = claimOrHeaderText("scope", "scope1"); + private static final Map MODULE_OPTIONS_MAP_NO_SCOPE_REQUIRED; + static { + Map tmp = new HashMap<>(); + tmp.put("unsecuredValidatorPrincipalClaimName", "principal"); + tmp.put("unsecuredValidatorAllowableClockSkewMs", "1"); + MODULE_OPTIONS_MAP_NO_SCOPE_REQUIRED = Collections.unmodifiableMap(tmp); + } + private static final Map MODULE_OPTIONS_MAP_REQUIRE_EXISTING_SCOPE; + static { + Map tmp = new HashMap<>(); + tmp.put("unsecuredValidatorRequiredScope", "scope1"); + MODULE_OPTIONS_MAP_REQUIRE_EXISTING_SCOPE = Collections.unmodifiableMap(tmp); + } + private static final Map MODULE_OPTIONS_MAP_REQUIRE_ADDITIONAL_SCOPE; + static { + Map tmp = new HashMap<>(); + tmp.put("unsecuredValidatorRequiredScope", "scope1 scope2"); + MODULE_OPTIONS_MAP_REQUIRE_ADDITIONAL_SCOPE = Collections.unmodifiableMap(tmp); + } + + @Test + public void validToken() { + for (final boolean includeOptionalIssuedAtClaim : new boolean[] {true, false}) { + String claimsJson = "{" + PRINCIPAL_CLAIM_TEXT + comma(EXPIRATION_TIME_CLAIM_TEXT) + + (includeOptionalIssuedAtClaim ? 
comma(ISSUED_AT_CLAIM_TEXT) : "") + "}"; + Object validationResult = validationResult(UNSECURED_JWT_HEADER_JSON, claimsJson, + MODULE_OPTIONS_MAP_NO_SCOPE_REQUIRED); + assertTrue(validationResult instanceof OAuthBearerValidatorCallback); + assertTrue(((OAuthBearerValidatorCallback) validationResult).token() instanceof OAuthBearerUnsecuredJws); + } + } + + @Test + public void badOrMissingPrincipal() throws IOException, UnsupportedCallbackException { + for (boolean exists : new boolean[] {true, false}) { + String claimsJson = "{" + EXPIRATION_TIME_CLAIM_TEXT + (exists ? comma(BAD_PRINCIPAL_CLAIM_TEXT) : "") + + "}"; + confirmFailsValidation(UNSECURED_JWT_HEADER_JSON, claimsJson, MODULE_OPTIONS_MAP_NO_SCOPE_REQUIRED); + } + } + + @Test + public void tooEarlyExpirationTime() throws IOException, UnsupportedCallbackException { + String claimsJson = "{" + PRINCIPAL_CLAIM_TEXT + comma(ISSUED_AT_CLAIM_TEXT) + + comma(TOO_EARLY_EXPIRATION_TIME_CLAIM_TEXT) + "}"; + confirmFailsValidation(UNSECURED_JWT_HEADER_JSON, claimsJson, MODULE_OPTIONS_MAP_NO_SCOPE_REQUIRED); + } + + @Test + public void includesRequiredScope() { + String claimsJson = "{" + SUB_CLAIM_TEXT + comma(EXPIRATION_TIME_CLAIM_TEXT) + comma(SCOPE_CLAIM_TEXT) + "}"; + Object validationResult = validationResult(UNSECURED_JWT_HEADER_JSON, claimsJson, + MODULE_OPTIONS_MAP_REQUIRE_EXISTING_SCOPE); + assertTrue(validationResult instanceof OAuthBearerValidatorCallback); + assertTrue(((OAuthBearerValidatorCallback) validationResult).token() instanceof OAuthBearerUnsecuredJws); + } + + @Test + public void missingRequiredScope() throws IOException, UnsupportedCallbackException { + String claimsJson = "{" + SUB_CLAIM_TEXT + comma(EXPIRATION_TIME_CLAIM_TEXT) + comma(SCOPE_CLAIM_TEXT) + "}"; + confirmFailsValidation(UNSECURED_JWT_HEADER_JSON, claimsJson, MODULE_OPTIONS_MAP_REQUIRE_ADDITIONAL_SCOPE, + "[scope1, scope2]"); + } + + private static void confirmFailsValidation(String headerJson, String claimsJson, + Map moduleOptionsMap) throws OAuthBearerConfigException, OAuthBearerIllegalTokenException, + IOException, UnsupportedCallbackException { + confirmFailsValidation(headerJson, claimsJson, moduleOptionsMap, null); + } + + private static void confirmFailsValidation(String headerJson, String claimsJson, + Map moduleOptionsMap, String optionalFailureScope) throws OAuthBearerConfigException, + OAuthBearerIllegalTokenException { + Object validationResultObj = validationResult(headerJson, claimsJson, moduleOptionsMap); + assertTrue(validationResultObj instanceof OAuthBearerValidatorCallback); + OAuthBearerValidatorCallback callback = (OAuthBearerValidatorCallback) validationResultObj; + assertNull(callback.token()); + assertNull(callback.errorOpenIDConfiguration()); + if (optionalFailureScope == null) { + assertEquals("invalid_token", callback.errorStatus()); + assertNull(callback.errorScope()); + } else { + assertEquals("insufficient_scope", callback.errorStatus()); + assertEquals(optionalFailureScope, callback.errorScope()); + } + } + + private static Object validationResult(String headerJson, String claimsJson, Map moduleOptionsMap) { + Encoder urlEncoderNoPadding = Base64.getUrlEncoder().withoutPadding(); + try { + String tokenValue = String.format("%s.%s.", + urlEncoderNoPadding.encodeToString(headerJson.getBytes(StandardCharsets.UTF_8)), + urlEncoderNoPadding.encodeToString(claimsJson.getBytes(StandardCharsets.UTF_8))); + OAuthBearerValidatorCallback callback = new OAuthBearerValidatorCallback(tokenValue); + 
createCallbackHandler(moduleOptionsMap).handle(new Callback[] {callback}); + return callback; + } catch (Exception e) { + return e; + } + } + + @SuppressWarnings({"unchecked", "rawtypes"}) + private static OAuthBearerUnsecuredValidatorCallbackHandler createCallbackHandler(Map options) { + TestJaasConfig config = new TestJaasConfig(); + config.createOrUpdateEntry("KafkaClient", "org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginModule", + (Map) options); + OAuthBearerUnsecuredValidatorCallbackHandler callbackHandler = new OAuthBearerUnsecuredValidatorCallbackHandler(); + callbackHandler.configure(Collections.emptyMap(), OAuthBearerLoginModule.OAUTHBEARER_MECHANISM, + Arrays.asList(config.getAppConfigurationEntry("KafkaClient")[0])); + return callbackHandler; + } + + private static String comma(String value) { + return "," + value; + } + + private static String claimOrHeaderText(String claimName, Number claimValue) { + return QUOTE + claimName + QUOTE + ":" + claimValue; + } + + private static String claimOrHeaderText(String claimName, String claimValue) { + return QUOTE + claimName + QUOTE + ":" + QUOTE + claimValue + QUOTE; + } + + private static String expClaimText(long lifetimeSeconds) { + return claimOrHeaderText("exp", MOCK_TIME.milliseconds() / 1000.0 + lifetimeSeconds); + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/internals/unsecured/OAuthBearerValidationUtilsTest.java b/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/internals/unsecured/OAuthBearerValidationUtilsTest.java new file mode 100644 index 0000000..88241b7 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/internals/unsecured/OAuthBearerValidationUtilsTest.java @@ -0,0 +1,245 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.security.oauthbearer.internals.unsecured; + +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.nio.charset.StandardCharsets; +import java.util.Arrays; +import java.util.Base64; +import java.util.Collections; +import java.util.List; + +import org.apache.kafka.common.utils.Time; +import org.junit.jupiter.api.Test; + +public class OAuthBearerValidationUtilsTest { + private static final String QUOTE = "\""; + private static final String HEADER_COMPACT_SERIALIZATION = Base64.getUrlEncoder().withoutPadding() + .encodeToString("{\"alg\":\"none\"}".getBytes(StandardCharsets.UTF_8)) + "."; + private static final Time TIME = Time.SYSTEM; + + @Test + public void validateClaimForExistenceAndType() throws OAuthBearerIllegalTokenException { + String claimName = "foo"; + for (Boolean exists : new Boolean[] {null, Boolean.TRUE, Boolean.FALSE}) { + boolean useErrorValue = exists == null; + for (Boolean required : new boolean[] {true, false}) { + StringBuilder sb = new StringBuilder("{"); + appendJsonText(sb, "exp", 100); + appendCommaJsonText(sb, "sub", "principalName"); + if (useErrorValue) + appendCommaJsonText(sb, claimName, 1); + else if (exists != null && exists.booleanValue()) + appendCommaJsonText(sb, claimName, claimName); + sb.append("}"); + String compactSerialization = HEADER_COMPACT_SERIALIZATION + Base64.getUrlEncoder().withoutPadding() + .encodeToString(sb.toString().getBytes(StandardCharsets.UTF_8)) + "."; + OAuthBearerUnsecuredJws testJwt = new OAuthBearerUnsecuredJws(compactSerialization, "sub", "scope"); + OAuthBearerValidationResult result = OAuthBearerValidationUtils + .validateClaimForExistenceAndType(testJwt, required, claimName, String.class); + if (useErrorValue || required && !exists.booleanValue()) + assertTrue(isFailureWithMessageAndNoFailureScope(result)); + else + assertTrue(isSuccess(result)); + } + } + } + + @Test + public void validateIssuedAt() { + long nowMs = TIME.milliseconds(); + double nowClaimValue = ((double) nowMs) / 1000; + for (boolean exists : new boolean[] {true, false}) { + StringBuilder sb = new StringBuilder("{"); + appendJsonText(sb, "exp", nowClaimValue); + appendCommaJsonText(sb, "sub", "principalName"); + if (exists) + appendCommaJsonText(sb, "iat", nowClaimValue); + sb.append("}"); + String compactSerialization = HEADER_COMPACT_SERIALIZATION + Base64.getUrlEncoder().withoutPadding() + .encodeToString(sb.toString().getBytes(StandardCharsets.UTF_8)) + "."; + OAuthBearerUnsecuredJws testJwt = new OAuthBearerUnsecuredJws(compactSerialization, "sub", "scope"); + for (boolean required : new boolean[] {true, false}) { + for (int allowableClockSkewMs : new int[] {0, 5, 10, 20}) { + for (long whenCheckOffsetMs : new long[] {-10, 0, 10}) { + long whenCheckMs = nowMs + whenCheckOffsetMs; + OAuthBearerValidationResult result = OAuthBearerValidationUtils.validateIssuedAt(testJwt, + required, whenCheckMs, allowableClockSkewMs); + if (required && !exists) + assertTrue(isFailureWithMessageAndNoFailureScope(result), "useErrorValue || required && !exists"); + else if (!required && !exists) + assertTrue(isSuccess(result), "!required && !exists"); + else if (nowClaimValue * 1000 > whenCheckMs + allowableClockSkewMs) // issued in future + assertTrue(isFailureWithMessageAndNoFailureScope(result), + assertionFailureMessage(nowClaimValue, allowableClockSkewMs, whenCheckMs)); + else + assertTrue(isSuccess(result), + assertionFailureMessage(nowClaimValue, allowableClockSkewMs, whenCheckMs)); + } + } + } + } + } + + 
@Test + public void validateExpirationTime() { + long nowMs = TIME.milliseconds(); + double nowClaimValue = ((double) nowMs) / 1000; + StringBuilder sb = new StringBuilder("{"); + appendJsonText(sb, "exp", nowClaimValue); + appendCommaJsonText(sb, "sub", "principalName"); + sb.append("}"); + String compactSerialization = HEADER_COMPACT_SERIALIZATION + + Base64.getUrlEncoder().withoutPadding().encodeToString(sb.toString().getBytes(StandardCharsets.UTF_8)) + + "."; + OAuthBearerUnsecuredJws testJwt = new OAuthBearerUnsecuredJws(compactSerialization, "sub", "scope"); + for (int allowableClockSkewMs : new int[] {0, 5, 10, 20}) { + for (long whenCheckOffsetMs : new long[] {-10, 0, 10}) { + long whenCheckMs = nowMs + whenCheckOffsetMs; + OAuthBearerValidationResult result = OAuthBearerValidationUtils.validateExpirationTime(testJwt, + whenCheckMs, allowableClockSkewMs); + if (whenCheckMs - allowableClockSkewMs >= nowClaimValue * 1000) // expired + assertTrue(isFailureWithMessageAndNoFailureScope(result), + assertionFailureMessage(nowClaimValue, allowableClockSkewMs, whenCheckMs)); + else + assertTrue(isSuccess(result), assertionFailureMessage(nowClaimValue, allowableClockSkewMs, whenCheckMs)); + } + } + } + + @Test + public void validateExpirationTimeAndIssuedAtConsistency() throws OAuthBearerIllegalTokenException { + long nowMs = TIME.milliseconds(); + double nowClaimValue = ((double) nowMs) / 1000; + for (boolean issuedAtExists : new boolean[] {true, false}) { + if (!issuedAtExists) { + StringBuilder sb = new StringBuilder("{"); + appendJsonText(sb, "exp", nowClaimValue); + appendCommaJsonText(sb, "sub", "principalName"); + sb.append("}"); + String compactSerialization = HEADER_COMPACT_SERIALIZATION + Base64.getUrlEncoder().withoutPadding() + .encodeToString(sb.toString().getBytes(StandardCharsets.UTF_8)) + "."; + OAuthBearerUnsecuredJws testJwt = new OAuthBearerUnsecuredJws(compactSerialization, "sub", "scope"); + assertTrue(isSuccess(OAuthBearerValidationUtils.validateTimeConsistency(testJwt))); + } else + for (int expirationTimeOffset = -1; expirationTimeOffset <= 1; ++expirationTimeOffset) { + StringBuilder sb = new StringBuilder("{"); + appendJsonText(sb, "iat", nowClaimValue); + appendCommaJsonText(sb, "exp", nowClaimValue + expirationTimeOffset); + appendCommaJsonText(sb, "sub", "principalName"); + sb.append("}"); + String compactSerialization = HEADER_COMPACT_SERIALIZATION + Base64.getUrlEncoder().withoutPadding() + .encodeToString(sb.toString().getBytes(StandardCharsets.UTF_8)) + "."; + OAuthBearerUnsecuredJws testJwt = new OAuthBearerUnsecuredJws(compactSerialization, "sub", "scope"); + OAuthBearerValidationResult result = OAuthBearerValidationUtils.validateTimeConsistency(testJwt); + if (expirationTimeOffset <= 0) + assertTrue(isFailureWithMessageAndNoFailureScope(result)); + else + assertTrue(isSuccess(result)); + } + } + } + + @SuppressWarnings({"unchecked", "rawtypes"}) + @Test + public void validateScope() { + long nowMs = TIME.milliseconds(); + double nowClaimValue = ((double) nowMs) / 1000; + final List noScope = Collections.emptyList(); + final List scope1 = Arrays.asList("scope1"); + final List scope1And2 = Arrays.asList("scope1", "scope2"); + for (boolean actualScopeExists : new boolean[] {true, false}) { + List scopes = !actualScopeExists ? Arrays.asList((List) null) + : Arrays.asList(noScope, scope1, scope1And2); + for (List actualScope : scopes) { + for (boolean requiredScopeExists : new boolean[] {true, false}) { + List requiredScopes = !requiredScopeExists ? 
Arrays.asList((List) null) + : Arrays.asList(noScope, scope1, scope1And2); + for (List requiredScope : requiredScopes) { + StringBuilder sb = new StringBuilder("{"); + appendJsonText(sb, "exp", nowClaimValue); + appendCommaJsonText(sb, "sub", "principalName"); + if (actualScope != null) + sb.append(',').append(scopeJson(actualScope)); + sb.append("}"); + String compactSerialization = HEADER_COMPACT_SERIALIZATION + Base64.getUrlEncoder() + .withoutPadding().encodeToString(sb.toString().getBytes(StandardCharsets.UTF_8)) + "."; + OAuthBearerUnsecuredJws testJwt = new OAuthBearerUnsecuredJws(compactSerialization, "sub", + "scope"); + OAuthBearerValidationResult result = OAuthBearerValidationUtils.validateScope(testJwt, + requiredScope); + if (!requiredScopeExists || requiredScope.isEmpty()) + assertTrue(isSuccess(result)); + else if (!actualScopeExists || actualScope.size() < requiredScope.size()) + assertTrue(isFailureWithMessageAndFailureScope(result)); + else + assertTrue(isSuccess(result)); + } + } + } + } + } + + private static String assertionFailureMessage(double claimValue, int allowableClockSkewMs, long whenCheckMs) { + return String.format("time=%f seconds, whenCheck = %d ms, allowableClockSkew=%d ms", claimValue, whenCheckMs, + allowableClockSkewMs); + } + + private static boolean isSuccess(OAuthBearerValidationResult result) { + return result.success(); + } + + private static boolean isFailureWithMessageAndNoFailureScope(OAuthBearerValidationResult result) { + return !result.success() && !result.failureDescription().isEmpty() && result.failureScope() == null + && result.failureOpenIdConfig() == null; + } + + private static boolean isFailureWithMessageAndFailureScope(OAuthBearerValidationResult result) { + return !result.success() && !result.failureDescription().isEmpty() && !result.failureScope().isEmpty() + && result.failureOpenIdConfig() == null; + } + + private static void appendCommaJsonText(StringBuilder sb, String claimName, Number claimValue) { + sb.append(',').append(QUOTE).append(escape(claimName)).append(QUOTE).append(":").append(claimValue); + } + + private static void appendCommaJsonText(StringBuilder sb, String claimName, String claimValue) { + sb.append(',').append(QUOTE).append(escape(claimName)).append(QUOTE).append(":").append(QUOTE) + .append(escape(claimValue)).append(QUOTE); + } + + private static void appendJsonText(StringBuilder sb, String claimName, Number claimValue) { + sb.append(QUOTE).append(escape(claimName)).append(QUOTE).append(":").append(claimValue); + } + + private static String escape(String jsonStringValue) { + return jsonStringValue.replace("\"", "\\\"").replace("\\", "\\\\"); + } + + private static String scopeJson(List scope) { + StringBuilder scopeJsonBuilder = new StringBuilder("\"scope\":["); + int initialLength = scopeJsonBuilder.length(); + for (String scopeValue : scope) { + if (scopeJsonBuilder.length() > initialLength) + scopeJsonBuilder.append(','); + scopeJsonBuilder.append('"').append(scopeValue).append('"'); + } + scopeJsonBuilder.append(']'); + return scopeJsonBuilder.toString(); + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/secured/AccessTokenBuilder.java b/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/secured/AccessTokenBuilder.java new file mode 100644 index 0000000..20def92 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/secured/AccessTokenBuilder.java @@ -0,0 +1,192 @@ +/* + * Licensed to the Apache Software Foundation (ASF) 
under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.common.security.oauthbearer.secured; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.node.ArrayNode; +import com.fasterxml.jackson.databind.node.ObjectNode; +import java.io.IOException; +import java.util.Collection; +import org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.common.utils.Time; +import org.jose4j.jwk.RsaJsonWebKey; +import org.jose4j.jwk.RsaJwkGenerator; +import org.jose4j.jws.AlgorithmIdentifiers; +import org.jose4j.jws.JsonWebSignature; +import org.jose4j.jwt.ReservedClaimNames; +import org.jose4j.lang.JoseException; + +public class AccessTokenBuilder { + + private final ObjectMapper objectMapper = new ObjectMapper(); + + private String audience; + + private String subject = "jdoe"; + + private String subjectClaimName = ReservedClaimNames.SUBJECT; + + private Object scope = "engineering"; + + private String scopeClaimName = "scope"; + + private Long issuedAtSeconds; + + private Long expirationSeconds; + + private RsaJsonWebKey jwk; + + public AccessTokenBuilder() throws JoseException { + this(new MockTime()); + } + + public AccessTokenBuilder(Time time) throws JoseException { + this.issuedAtSeconds = time.milliseconds() / 1000; + this.expirationSeconds = this.issuedAtSeconds + 60; + this.jwk = createJwk(); + } + + public static RsaJsonWebKey createJwk() throws JoseException { + RsaJsonWebKey jwk = RsaJwkGenerator.generateJwk(2048); + jwk.setKeyId("key-1"); + return jwk; + } + + public String audience() { + return audience; + } + + public AccessTokenBuilder audience(String audience) { + this.audience = audience; + return this; + } + + public String subject() { + return subject; + } + + public AccessTokenBuilder subject(String subject) { + this.subject = subject; + return this; + } + + public String subjectClaimName() { + return subjectClaimName; + } + + public AccessTokenBuilder subjectClaimName(String subjectClaimName) { + this.subjectClaimName = subjectClaimName; + return this; + } + + public Object scope() { + return scope; + } + + public AccessTokenBuilder scope(Object scope) { + this.scope = scope; + + if (scope instanceof String) { + return this; + } else if (scope instanceof Collection) { + return this; + } else { + throw new IllegalArgumentException(String.format("%s parameter must be a %s or a %s containing %s", + scopeClaimName, + String.class.getName(), + Collection.class.getName(), + String.class.getName())); + } + } + + public String scopeClaimName() { + return scopeClaimName; + } + + public AccessTokenBuilder scopeClaimName(String scopeClaimName) { + this.scopeClaimName = scopeClaimName; + return this; + } + + public Long issuedAtSeconds() { + return issuedAtSeconds; + } + + public AccessTokenBuilder issuedAtSeconds(Long 
issuedAtSeconds) { + this.issuedAtSeconds = issuedAtSeconds; + return this; + } + + public Long expirationSeconds() { + return expirationSeconds; + } + + public AccessTokenBuilder expirationSeconds(Long expirationSeconds) { + this.expirationSeconds = expirationSeconds; + return this; + } + + public RsaJsonWebKey jwk() { + return jwk; + } + + public AccessTokenBuilder jwk(RsaJsonWebKey jwk) { + this.jwk = jwk; + return this; + } + + @SuppressWarnings("unchecked") + public String build() throws JoseException, IOException { + ObjectNode node = objectMapper.createObjectNode(); + + if (audience != null) + node.put(ReservedClaimNames.AUDIENCE, audience); + + if (subject != null) + node.put(subjectClaimName, subject); + + if (scope instanceof String) { + node.put(scopeClaimName, (String) scope); + } else if (scope instanceof Collection) { + ArrayNode child = node.putArray(scopeClaimName); + ((Collection<String>) scope).forEach(child::add); + } else { + throw new IllegalArgumentException(String.format("%s claim must be a %s or a %s containing %s", + scopeClaimName, + String.class.getName(), + Collection.class.getName(), + String.class.getName())); + } + + if (issuedAtSeconds != null) + node.put(ReservedClaimNames.ISSUED_AT, issuedAtSeconds); + + if (expirationSeconds != null) + node.put(ReservedClaimNames.EXPIRATION_TIME, expirationSeconds); + + String json = objectMapper.writeValueAsString(node); + + JsonWebSignature jws = new JsonWebSignature(); + jws.setPayload(json); + jws.setKey(jwk.getPrivateKey()); + jws.setKeyIdHeaderValue(jwk.getKeyId()); + jws.setAlgorithmHeaderValue(AlgorithmIdentifiers.RSA_USING_SHA256); + return jws.getCompactSerialization(); + } + +} diff --git a/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/secured/AccessTokenRetrieverFactoryTest.java b/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/secured/AccessTokenRetrieverFactoryTest.java new file mode 100644 index 0000000..5195315 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/secured/AccessTokenRetrieverFactoryTest.java @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.common.security.oauthbearer.secured; + +import static org.apache.kafka.common.config.SaslConfigs.SASL_OAUTHBEARER_TOKEN_ENDPOINT_URL; +import static org.junit.jupiter.api.Assertions.assertEquals; + +import java.io.File; +import java.util.Collections; +import java.util.Map; +import org.apache.kafka.common.config.ConfigException; +import org.junit.jupiter.api.Test; + +public class AccessTokenRetrieverFactoryTest extends OAuthBearerTest { + + @Test + public void testConfigureRefreshingFileAccessTokenRetriever() throws Exception { + String expected = "{}"; + + File tmpDir = createTempDir("access-token"); + File accessTokenFile = createTempFile(tmpDir, "access-token-", ".json", expected); + + Map configs = Collections.singletonMap(SASL_OAUTHBEARER_TOKEN_ENDPOINT_URL, accessTokenFile.toURI().toString()); + Map jaasConfig = Collections.emptyMap(); + + try (AccessTokenRetriever accessTokenRetriever = AccessTokenRetrieverFactory.create(configs, jaasConfig)) { + accessTokenRetriever.init(); + assertEquals(expected, accessTokenRetriever.retrieve()); + } + } + + @Test + public void testConfigureRefreshingFileAccessTokenRetrieverWithInvalidDirectory() { + // Should fail because the parent path doesn't exist. + Map configs = getSaslConfigs(SASL_OAUTHBEARER_TOKEN_ENDPOINT_URL, new File("/tmp/this-directory-does-not-exist/foo.json").toURI().toString()); + Map jaasConfig = Collections.emptyMap(); + assertThrowsWithMessage(ConfigException.class, () -> AccessTokenRetrieverFactory.create(configs, jaasConfig), "that doesn't exist"); + } + + @Test + public void testConfigureRefreshingFileAccessTokenRetrieverWithInvalidFile() throws Exception { + // Should fail because, while the parent path exists, the file itself doesn't. + File tmpDir = createTempDir("this-directory-does-exist"); + File accessTokenFile = new File(tmpDir, "this-file-does-not-exist.json"); + Map configs = getSaslConfigs(SASL_OAUTHBEARER_TOKEN_ENDPOINT_URL, accessTokenFile.toURI().toString()); + Map jaasConfig = Collections.emptyMap(); + assertThrowsWithMessage(ConfigException.class, () -> AccessTokenRetrieverFactory.create(configs, jaasConfig), "that doesn't exist"); + } + +} diff --git a/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/secured/AccessTokenValidatorFactoryTest.java b/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/secured/AccessTokenValidatorFactoryTest.java new file mode 100644 index 0000000..1270674 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/secured/AccessTokenValidatorFactoryTest.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.common.security.oauthbearer.secured; + +import java.io.IOException; +import java.util.Map; +import org.apache.kafka.common.KafkaException; +import org.junit.jupiter.api.Test; + +public class AccessTokenValidatorFactoryTest extends OAuthBearerTest { + + @Test + public void testConfigureThrowsExceptionOnAccessTokenValidatorInit() { + OAuthBearerLoginCallbackHandler handler = new OAuthBearerLoginCallbackHandler(); + AccessTokenRetriever accessTokenRetriever = new AccessTokenRetriever() { + @Override + public void init() throws IOException { + throw new IOException("My init had an error!"); + } + @Override + public String retrieve() { + return "dummy"; + } + }; + + Map configs = getSaslConfigs(); + AccessTokenValidator accessTokenValidator = AccessTokenValidatorFactory.create(configs); + + assertThrowsWithMessage( + KafkaException.class, () -> handler.init(accessTokenRetriever, accessTokenValidator), "encountered an error when initializing"); + } + + @Test + public void testConfigureThrowsExceptionOnAccessTokenValidatorClose() { + OAuthBearerLoginCallbackHandler handler = new OAuthBearerLoginCallbackHandler(); + AccessTokenRetriever accessTokenRetriever = new AccessTokenRetriever() { + @Override + public void close() throws IOException { + throw new IOException("My close had an error!"); + } + @Override + public String retrieve() { + return "dummy"; + } + }; + + Map configs = getSaslConfigs(); + AccessTokenValidator accessTokenValidator = AccessTokenValidatorFactory.create(configs); + handler.init(accessTokenRetriever, accessTokenValidator); + + // Basically asserting this doesn't throw an exception :( + handler.close(); + } + +} diff --git a/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/secured/AccessTokenValidatorTest.java b/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/secured/AccessTokenValidatorTest.java new file mode 100644 index 0000000..8407ac3 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/secured/AccessTokenValidatorTest.java @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.common.security.oauthbearer.secured; + +import static org.junit.jupiter.api.Assertions.assertThrows; + +import org.jose4j.jws.AlgorithmIdentifiers; +import org.jose4j.jwx.HeaderParameterNames; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInstance; +import org.junit.jupiter.api.TestInstance.Lifecycle; + +@TestInstance(Lifecycle.PER_CLASS) +public abstract class AccessTokenValidatorTest extends OAuthBearerTest { + + protected abstract AccessTokenValidator createAccessTokenValidator(AccessTokenBuilder accessTokenBuilder) throws Exception; + + protected AccessTokenValidator createAccessTokenValidator() throws Exception { + AccessTokenBuilder builder = new AccessTokenBuilder(); + return createAccessTokenValidator(builder); + } + + @Test + public void testNull() throws Exception { + AccessTokenValidator validator = createAccessTokenValidator(); + assertThrowsWithMessage(ValidateException.class, () -> validator.validate(null), "Empty JWT provided"); + } + + @Test + public void testEmptyString() throws Exception { + AccessTokenValidator validator = createAccessTokenValidator(); + assertThrowsWithMessage(ValidateException.class, () -> validator.validate(""), "Empty JWT provided"); + } + + @Test + public void testWhitespace() throws Exception { + AccessTokenValidator validator = createAccessTokenValidator(); + assertThrowsWithMessage(ValidateException.class, () -> validator.validate(" "), "Empty JWT provided"); + } + + @Test + public void testEmptySections() throws Exception { + AccessTokenValidator validator = createAccessTokenValidator(); + assertThrowsWithMessage(ValidateException.class, () -> validator.validate(".."), "Malformed JWT provided"); + } + + @Test + public void testMissingHeader() throws Exception { + AccessTokenValidator validator = createAccessTokenValidator(); + String header = ""; + String payload = createBase64JsonJwtSection(node -> { }); + String signature = ""; + String accessToken = String.format("%s.%s.%s", header, payload, signature); + assertThrows(ValidateException.class, () -> validator.validate(accessToken)); + } + + @Test + public void testMissingPayload() throws Exception { + AccessTokenValidator validator = createAccessTokenValidator(); + String header = createBase64JsonJwtSection(node -> node.put(HeaderParameterNames.ALGORITHM, AlgorithmIdentifiers.NONE)); + String payload = ""; + String signature = ""; + String accessToken = String.format("%s.%s.%s", header, payload, signature); + assertThrows(ValidateException.class, () -> validator.validate(accessToken)); + } + + @Test + public void testMissingSignature() throws Exception { + AccessTokenValidator validator = createAccessTokenValidator(); + String header = createBase64JsonJwtSection(node -> node.put(HeaderParameterNames.ALGORITHM, AlgorithmIdentifiers.NONE)); + String payload = createBase64JsonJwtSection(node -> { }); + String signature = ""; + String accessToken = String.format("%s.%s.%s", header, payload, signature); + assertThrows(ValidateException.class, () -> validator.validate(accessToken)); + } + +} \ No newline at end of file diff --git a/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/secured/BasicOAuthBearerTokenTest.java b/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/secured/BasicOAuthBearerTokenTest.java new file mode 100644 index 0000000..658d07f --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/secured/BasicOAuthBearerTokenTest.java @@ -0,0 +1,90 @@ +/* + * Licensed 
to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.common.security.oauthbearer.secured; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.util.Arrays; +import java.util.Collections; +import java.util.SortedSet; +import java.util.TreeSet; +import org.apache.kafka.common.security.oauthbearer.OAuthBearerToken; +import org.junit.jupiter.api.Test; + +public class BasicOAuthBearerTokenTest { + + @Test + public void basic() { + OAuthBearerToken token = new BasicOAuthBearerToken("not.valid.token", + Collections.emptySet(), + 0L, + "jdoe", + 0L); + assertEquals("not.valid.token", token.value()); + assertTrue(token.scope().isEmpty()); + assertEquals(0L, token.lifetimeMs()); + assertEquals("jdoe", token.principalName()); + assertEquals(0L, token.startTimeMs()); + } + + @Test + public void negativeLifetime() { + OAuthBearerToken token = new BasicOAuthBearerToken("not.valid.token", + Collections.emptySet(), + -1L, + "jdoe", + 0L); + assertEquals("not.valid.token", token.value()); + assertTrue(token.scope().isEmpty()); + assertEquals(-1L, token.lifetimeMs()); + assertEquals("jdoe", token.principalName()); + assertEquals(0L, token.startTimeMs()); + } + + @Test + public void noErrorIfModifyScope() { + // Start with a basic set created by the caller. + SortedSet callerSet = new TreeSet<>(Arrays.asList("a", "b", "c")); + OAuthBearerToken token = new BasicOAuthBearerToken("not.valid.token", + callerSet, + 0L, + "jdoe", + 0L); + + // Make sure it all looks good + assertNotNull(token.scope()); + assertEquals(3, token.scope().size()); + + // Add a value to the caller's set and note that it changes the token's scope set. + // Make sure to make it read-only when it's passed in. + callerSet.add("d"); + assertTrue(token.scope().contains("d")); + + // Similarly, removing a value from the caller's will affect the token's scope set. + // Make sure to make it read-only when it's passed in. + callerSet.remove("c"); + assertFalse(token.scope().contains("c")); + + // Ensure that attempting to change the token's scope set directly will not throw any error. 
+ token.scope().clear(); + } + +} diff --git a/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/secured/ClaimValidationUtilsTest.java b/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/secured/ClaimValidationUtilsTest.java new file mode 100644 index 0000000..0aeb6f7 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/secured/ClaimValidationUtilsTest.java @@ -0,0 +1,164 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.common.security.oauthbearer.secured; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.util.Arrays; +import java.util.Set; +import java.util.SortedSet; +import java.util.TreeSet; +import org.junit.jupiter.api.Test; + +public class ClaimValidationUtilsTest extends OAuthBearerTest { + + @Test + public void testValidateScopes() { + Set scopes = ClaimValidationUtils.validateScopes("scope", Arrays.asList(" a ", " b ")); + + assertEquals(2, scopes.size()); + assertTrue(scopes.contains("a")); + assertTrue(scopes.contains("b")); + } + + @Test + public void testValidateScopesDisallowsDuplicates() { + assertThrows(ValidateException.class, () -> ClaimValidationUtils.validateScopes("scope", Arrays.asList("a", "b", "a"))); + assertThrows(ValidateException.class, () -> ClaimValidationUtils.validateScopes("scope", Arrays.asList("a", "b", " a "))); + } + + @Test + public void testValidateScopesDisallowsEmptyNullAndWhitespace() { + assertThrows(ValidateException.class, () -> ClaimValidationUtils.validateScopes("scope", Arrays.asList("a", ""))); + assertThrows(ValidateException.class, () -> ClaimValidationUtils.validateScopes("scope", Arrays.asList("a", null))); + assertThrows(ValidateException.class, () -> ClaimValidationUtils.validateScopes("scope", Arrays.asList("a", " "))); + } + + @Test + public void testValidateScopesResultIsImmutable() { + SortedSet callerSet = new TreeSet<>(Arrays.asList("a", "b", "c")); + Set scopes = ClaimValidationUtils.validateScopes("scope", callerSet); + + assertEquals(3, scopes.size()); + + callerSet.add("d"); + assertEquals(4, callerSet.size()); + assertTrue(callerSet.contains("d")); + assertEquals(3, scopes.size()); + assertFalse(scopes.contains("d")); + + callerSet.remove("c"); + assertEquals(3, callerSet.size()); + assertFalse(callerSet.contains("c")); + assertEquals(3, scopes.size()); + assertTrue(scopes.contains("c")); + + callerSet.clear(); + assertEquals(0, callerSet.size()); + assertEquals(3, scopes.size()); + } + + @Test + public void testValidateScopesResultThrowsExceptionOnMutation() { + 
SortedSet callerSet = new TreeSet<>(Arrays.asList("a", "b", "c")); + Set scopes = ClaimValidationUtils.validateScopes("scope", callerSet); + assertThrows(UnsupportedOperationException.class, scopes::clear); + } + + @Test + public void testValidateExpiration() { + Long expected = 1L; + Long actual = ClaimValidationUtils.validateExpiration("exp", expected); + assertEquals(expected, actual); + } + + @Test + public void testValidateExpirationAllowsZero() { + Long expected = 0L; + Long actual = ClaimValidationUtils.validateExpiration("exp", expected); + assertEquals(expected, actual); + } + + @Test + public void testValidateExpirationDisallowsNull() { + assertThrows(ValidateException.class, () -> ClaimValidationUtils.validateExpiration("exp", null)); + } + + @Test + public void testValidateExpirationDisallowsNegatives() { + assertThrows(ValidateException.class, () -> ClaimValidationUtils.validateExpiration("exp", -1L)); + } + + @Test + public void testValidateSubject() { + String expected = "jdoe"; + String actual = ClaimValidationUtils.validateSubject("sub", expected); + assertEquals(expected, actual); + } + + @Test + public void testValidateSubjectDisallowsEmptyNullAndWhitespace() { + assertThrows(ValidateException.class, () -> ClaimValidationUtils.validateSubject("sub", "")); + assertThrows(ValidateException.class, () -> ClaimValidationUtils.validateSubject("sub", null)); + assertThrows(ValidateException.class, () -> ClaimValidationUtils.validateSubject("sub", " ")); + } + + @Test + public void testValidateClaimNameOverride() { + String expected = "email"; + String actual = ClaimValidationUtils.validateClaimNameOverride("sub", String.format(" %s ", expected)); + assertEquals(expected, actual); + } + + @Test + public void testValidateClaimNameOverrideDisallowsEmptyNullAndWhitespace() { + assertThrows(ValidateException.class, () -> ClaimValidationUtils.validateClaimNameOverride("sub", "")); + assertThrows(ValidateException.class, () -> ClaimValidationUtils.validateClaimNameOverride("sub", null)); + assertThrows(ValidateException.class, () -> ClaimValidationUtils.validateClaimNameOverride("sub", " ")); + } + + @Test + public void testValidateIssuedAt() { + Long expected = 1L; + Long actual = ClaimValidationUtils.validateIssuedAt("iat", expected); + assertEquals(expected, actual); + } + + @Test + public void testValidateIssuedAtAllowsZero() { + Long expected = 0L; + Long actual = ClaimValidationUtils.validateIssuedAt("iat", expected); + assertEquals(expected, actual); + } + + @Test + public void testValidateIssuedAtAllowsNull() { + Long expected = null; + Long actual = ClaimValidationUtils.validateIssuedAt("iat", expected); + assertEquals(expected, actual); + } + + @Test + public void testValidateIssuedAtDisallowsNegatives() { + assertThrows(ValidateException.class, () -> ClaimValidationUtils.validateIssuedAt("iat", -1L)); + } + +} diff --git a/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/secured/ConfigurationUtilsTest.java b/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/secured/ConfigurationUtilsTest.java new file mode 100644 index 0000000..783579a --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/secured/ConfigurationUtilsTest.java @@ -0,0 +1,136 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.common.security.oauthbearer.secured; + +import java.io.File; +import java.io.IOException; +import java.util.Collections; +import java.util.Map; +import org.apache.kafka.common.config.ConfigException; +import org.apache.kafka.test.TestUtils; +import org.junit.jupiter.api.Test; + +public class ConfigurationUtilsTest extends OAuthBearerTest { + + private final static String URL_CONFIG_NAME = "url"; + + @Test + public void testUrl() { + testUrl("http://www.example.com"); + } + + @Test + public void testUrlWithSuperfluousWhitespace() { + testUrl(String.format(" %s ", "http://www.example.com")); + } + + @Test + public void testUrlCaseInsensitivity() { + testUrl("HTTPS://WWW.EXAMPLE.COM"); + } + + @Test + public void testUrlFile() { + testUrl("file:///tmp/foo.txt"); + } + + @Test + public void testUrlFullPath() { + testUrl("https://myidp.example.com/oauth2/default/v1/token"); + } + + @Test + public void testUrlMissingProtocol() { + assertThrowsWithMessage(ConfigException.class, () -> testUrl("www.example.com"), "no protocol"); + } + + @Test + public void testUrlInvalidProtocol() { + assertThrowsWithMessage(ConfigException.class, () -> testUrl("ftp://ftp.example.com"), "invalid protocol"); + } + + @Test + public void testUrlNull() { + assertThrowsWithMessage(ConfigException.class, () -> testUrl(null), "must be non-null"); + } + + @Test + public void testUrlEmptyString() { + assertThrowsWithMessage(ConfigException.class, () -> testUrl(""), "must not contain only whitespace"); + } + + @Test + public void testUrlWhitespace() { + assertThrowsWithMessage(ConfigException.class, () -> testUrl(" "), "must not contain only whitespace"); + } + + private void testUrl(String value) { + Map configs = Collections.singletonMap(URL_CONFIG_NAME, value); + ConfigurationUtils cu = new ConfigurationUtils(configs); + cu.validateUrl(URL_CONFIG_NAME); + } + + @Test + public void testFile() throws IOException { + File file = TestUtils.tempFile("some contents!"); + testFile(file.toURI().toURL().toString()); + } + + @Test + public void testFileWithSuperfluousWhitespace() throws IOException { + File file = TestUtils.tempFile(); + testFile(String.format(" %s ", file.toURI().toURL())); + } + + @Test + public void testFileDoesNotExist() { + assertThrowsWithMessage(ConfigException.class, () -> testFile(new File("/tmp/not/a/real/file.txt").toURI().toURL().toString()), "that doesn't exist"); + } + + @Test + public void testFileUnreadable() throws IOException { + File file = TestUtils.tempFile(); + + if (!file.setReadable(false)) + throw new IllegalStateException(String.format("Can't test file permissions as test couldn't programmatically make temp file %s un-readable", file.getAbsolutePath())); + + assertThrowsWithMessage(ConfigException.class, () -> testFile(file.toURI().toURL().toString()), "that doesn't have read permission"); + } + + @Test + public void testFileNull() { + assertThrowsWithMessage(ConfigException.class, 
() -> testFile(null), "must be non-null"); + } + + @Test + public void testFileEmptyString() { + assertThrowsWithMessage(ConfigException.class, () -> testFile(""), "must not contain only whitespace"); + } + + @Test + public void testFileWhitespace() { + assertThrowsWithMessage(ConfigException.class, () -> testFile(" "), "must not contain only whitespace"); + } + + protected void testFile(String value) { + Map configs = Collections.singletonMap(URL_CONFIG_NAME, value); + ConfigurationUtils cu = new ConfigurationUtils(configs); + cu.validateFile(URL_CONFIG_NAME); + } + +} diff --git a/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/secured/HttpAccessTokenRetrieverTest.java b/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/secured/HttpAccessTokenRetrieverTest.java new file mode 100644 index 0000000..de3b463 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/secured/HttpAccessTokenRetrieverTest.java @@ -0,0 +1,172 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.common.security.oauthbearer.secured; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.node.ObjectNode; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.net.HttpURLConnection; +import java.util.Base64; +import java.util.Random; +import org.apache.kafka.common.utils.Utils; +import org.junit.jupiter.api.Test; + +public class HttpAccessTokenRetrieverTest extends OAuthBearerTest { + + @Test + public void test() throws IOException { + String expectedResponse = "Hiya, buddy"; + HttpURLConnection mockedCon = createHttpURLConnection(expectedResponse); + String response = HttpAccessTokenRetriever.post(mockedCon, null, null, null, null); + assertEquals(expectedResponse, response); + } + + @Test + public void testEmptyResponse() throws IOException { + HttpURLConnection mockedCon = createHttpURLConnection(""); + assertThrows(IOException.class, () -> HttpAccessTokenRetriever.post(mockedCon, null, null, null, null)); + } + + @Test + public void testErrorReadingResponse() throws IOException { + HttpURLConnection mockedCon = createHttpURLConnection("dummy"); + when(mockedCon.getInputStream()).thenThrow(new IOException("Can't read")); + + assertThrows(IOException.class, () -> HttpAccessTokenRetriever.post(mockedCon, null, null, null, null)); + } + + @Test + public void testCopy() throws IOException { + byte[] expected = new byte[4096 + 1]; + Random r = new Random(); + r.nextBytes(expected); + InputStream in = new ByteArrayInputStream(expected); + ByteArrayOutputStream out = new ByteArrayOutputStream(); + HttpAccessTokenRetriever.copy(in, out); + assertArrayEquals(expected, out.toByteArray()); + } + + @Test + public void testCopyError() throws IOException { + InputStream mockedIn = mock(InputStream.class); + OutputStream out = new ByteArrayOutputStream(); + when(mockedIn.read(any(byte[].class))).thenThrow(new IOException()); + assertThrows(IOException.class, () -> HttpAccessTokenRetriever.copy(mockedIn, out)); + } + + @Test + public void testParseAccessToken() throws IOException { + String expected = "abc"; + ObjectMapper mapper = new ObjectMapper(); + ObjectNode node = mapper.createObjectNode(); + node.put("access_token", expected); + + String actual = HttpAccessTokenRetriever.parseAccessToken(mapper.writeValueAsString(node)); + assertEquals(expected, actual); + } + + @Test + public void testParseAccessTokenEmptyAccessToken() { + ObjectMapper mapper = new ObjectMapper(); + ObjectNode node = mapper.createObjectNode(); + node.put("access_token", ""); + + assertThrows(IllegalArgumentException.class, () -> HttpAccessTokenRetriever.parseAccessToken(mapper.writeValueAsString(node))); + } + + @Test + public void testParseAccessTokenMissingAccessToken() { + ObjectMapper mapper = new ObjectMapper(); + ObjectNode node = mapper.createObjectNode(); + node.put("sub", "jdoe"); + + assertThrows(IllegalArgumentException.class, () -> HttpAccessTokenRetriever.parseAccessToken(mapper.writeValueAsString(node))); + } + + @Test + public void testParseAccessTokenInvalidJson() { + 
assertThrows(IOException.class, () -> HttpAccessTokenRetriever.parseAccessToken("not valid JSON")); + } + + @Test + public void testFormatAuthorizationHeader() throws IOException { + String expected = "Basic " + Base64.getUrlEncoder().encodeToString(Utils.utf8("id:secret")); + + String actual = HttpAccessTokenRetriever.formatAuthorizationHeader("id", "secret"); + assertEquals(expected, actual); + } + + @Test + public void testFormatAuthorizationHeaderMissingValues() { + assertThrows(IllegalArgumentException.class, () -> HttpAccessTokenRetriever.formatAuthorizationHeader(null, "secret")); + assertThrows(IllegalArgumentException.class, () -> HttpAccessTokenRetriever.formatAuthorizationHeader("id", null)); + assertThrows(IllegalArgumentException.class, () -> HttpAccessTokenRetriever.formatAuthorizationHeader(null, null)); + assertThrows(IllegalArgumentException.class, () -> HttpAccessTokenRetriever.formatAuthorizationHeader("", "secret")); + assertThrows(IllegalArgumentException.class, () -> HttpAccessTokenRetriever.formatAuthorizationHeader("id", "")); + assertThrows(IllegalArgumentException.class, () -> HttpAccessTokenRetriever.formatAuthorizationHeader("", "")); + assertThrows(IllegalArgumentException.class, () -> HttpAccessTokenRetriever.formatAuthorizationHeader(" ", "secret")); + assertThrows(IllegalArgumentException.class, () -> HttpAccessTokenRetriever.formatAuthorizationHeader("id", " ")); + assertThrows(IllegalArgumentException.class, () -> HttpAccessTokenRetriever.formatAuthorizationHeader(" ", " ")); + } + + @Test + public void testFormatRequestBody() throws IOException { + String expected = "grant_type=client_credentials&scope=scope"; + String actual = HttpAccessTokenRetriever.formatRequestBody("scope"); + assertEquals(expected, actual); + } + + @Test + public void testFormatRequestBodyWithEscaped() throws IOException { + String questionMark = "%3F"; + String exclamationMark = "%21"; + + String expected = String.format("grant_type=client_credentials&scope=earth+is+great%s", exclamationMark); + String actual = HttpAccessTokenRetriever.formatRequestBody("earth is great!"); + assertEquals(expected, actual); + + expected = String.format("grant_type=client_credentials&scope=what+on+earth%s%s%s%s%s", questionMark, exclamationMark, questionMark, exclamationMark, questionMark); + actual = HttpAccessTokenRetriever.formatRequestBody("what on earth?!?!?"); + assertEquals(expected, actual); + } + + @Test + public void testFormatRequestBodyMissingValues() throws IOException { + String expected = "grant_type=client_credentials"; + String actual = HttpAccessTokenRetriever.formatRequestBody(null); + assertEquals(expected, actual); + + actual = HttpAccessTokenRetriever.formatRequestBody(""); + assertEquals(expected, actual); + + actual = HttpAccessTokenRetriever.formatRequestBody(" "); + assertEquals(expected, actual); + } + +} diff --git a/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/secured/JaasOptionsUtilsTest.java b/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/secured/JaasOptionsUtilsTest.java new file mode 100644 index 0000000..2b32408 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/secured/JaasOptionsUtilsTest.java @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.common.security.oauthbearer.secured; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.net.URL; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import org.apache.kafka.common.config.SslConfigs; +import org.junit.jupiter.api.Test; + +public class JaasOptionsUtilsTest extends OAuthBearerTest { + + @Test + public void testSSLClientConfig() { + Map options = new HashMap<>(); + String sslKeystore = "test.keystore.jks"; + String sslTruststore = "test.truststore.jks"; + + options.put(SslConfigs.SSL_KEYSTORE_LOCATION_CONFIG, sslKeystore); + options.put(SslConfigs.SSL_KEYSTORE_PASSWORD_CONFIG, "$3cr3+"); + options.put(SslConfigs.SSL_TRUSTSTORE_LOCATION_CONFIG, sslTruststore); + + JaasOptionsUtils jou = new JaasOptionsUtils(options); + Map sslClientConfig = jou.getSslClientConfig(); + assertNotNull(sslClientConfig); + assertEquals(sslKeystore, sslClientConfig.get(SslConfigs.SSL_KEYSTORE_LOCATION_CONFIG)); + assertEquals(sslTruststore, sslClientConfig.get(SslConfigs.SSL_TRUSTSTORE_LOCATION_CONFIG)); + assertEquals(SslConfigs.DEFAULT_SSL_PROTOCOL, sslClientConfig.get(SslConfigs.SSL_PROTOCOL_CONFIG)); + } + + @Test + public void testShouldUseSslClientConfig() throws Exception { + JaasOptionsUtils jou = new JaasOptionsUtils(Collections.emptyMap()); + assertFalse(jou.shouldCreateSSLSocketFactory(new URL("http://example.com"))); + assertTrue(jou.shouldCreateSSLSocketFactory(new URL("https://example.com"))); + assertFalse(jou.shouldCreateSSLSocketFactory(new URL("file:///tmp/test.txt"))); + } + +} diff --git a/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/secured/LoginAccessTokenValidatorTest.java b/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/secured/LoginAccessTokenValidatorTest.java new file mode 100644 index 0000000..6fd23f6 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/secured/LoginAccessTokenValidatorTest.java @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.common.security.oauthbearer.secured; + +public class LoginAccessTokenValidatorTest extends AccessTokenValidatorTest { + + @Override + protected AccessTokenValidator createAccessTokenValidator(AccessTokenBuilder builder) { + return new LoginAccessTokenValidator(builder.scopeClaimName(), builder.subjectClaimName()); + } + +} diff --git a/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/secured/OAuthBearerLoginCallbackHandlerTest.java b/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/secured/OAuthBearerLoginCallbackHandlerTest.java new file mode 100644 index 0000000..4be823e --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/secured/OAuthBearerLoginCallbackHandlerTest.java @@ -0,0 +1,227 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.common.security.oauthbearer.secured; + +import static org.apache.kafka.common.config.SaslConfigs.SASL_OAUTHBEARER_TOKEN_ENDPOINT_URL; +import static org.apache.kafka.common.security.oauthbearer.secured.OAuthBearerLoginCallbackHandler.CLIENT_ID_CONFIG; +import static org.apache.kafka.common.security.oauthbearer.secured.OAuthBearerLoginCallbackHandler.CLIENT_SECRET_CONFIG; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.io.File; +import java.io.IOException; +import java.util.Base64; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import javax.security.auth.callback.Callback; +import javax.security.auth.callback.UnsupportedCallbackException; +import org.apache.kafka.common.config.ConfigException; +import org.apache.kafka.common.security.auth.SaslExtensionsCallback; +import org.apache.kafka.common.security.oauthbearer.OAuthBearerToken; +import org.apache.kafka.common.security.oauthbearer.OAuthBearerTokenCallback; +import org.apache.kafka.common.security.oauthbearer.internals.OAuthBearerClientInitialResponse; +import org.apache.kafka.common.utils.Utils; +import org.junit.jupiter.api.Test; + +public class OAuthBearerLoginCallbackHandlerTest extends OAuthBearerTest { + + @Test + public void testHandleTokenCallback() throws Exception { + Map configs = getSaslConfigs(); + AccessTokenBuilder builder = new AccessTokenBuilder(); + String accessToken = builder.build(); + AccessTokenRetriever accessTokenRetriever = () -> accessToken; + + OAuthBearerLoginCallbackHandler handler = 
createHandler(accessTokenRetriever, configs); + + try { + OAuthBearerTokenCallback callback = new OAuthBearerTokenCallback(); + handler.handle(new Callback[] {callback}); + + assertNotNull(callback.token()); + OAuthBearerToken token = callback.token(); + assertEquals(accessToken, token.value()); + assertEquals(builder.subject(), token.principalName()); + assertEquals(builder.expirationSeconds() * 1000, token.lifetimeMs()); + assertEquals(builder.issuedAtSeconds() * 1000, token.startTimeMs()); + } finally { + handler.close(); + } + } + + @Test + public void testHandleSaslExtensionsCallback() throws Exception { + OAuthBearerLoginCallbackHandler handler = new OAuthBearerLoginCallbackHandler(); + Map configs = getSaslConfigs(SASL_OAUTHBEARER_TOKEN_ENDPOINT_URL, "http://www.example.com"); + Map jaasConfig = new HashMap<>(); + jaasConfig.put(CLIENT_ID_CONFIG, "an ID"); + jaasConfig.put(CLIENT_SECRET_CONFIG, "a secret"); + jaasConfig.put("extension_foo", "1"); + jaasConfig.put("extension_bar", 2); + jaasConfig.put("EXTENSION_baz", "3"); + configureHandler(handler, configs, jaasConfig); + + try { + SaslExtensionsCallback callback = new SaslExtensionsCallback(); + handler.handle(new Callback[]{callback}); + + assertNotNull(callback.extensions()); + Map extensions = callback.extensions().map(); + assertEquals("1", extensions.get("foo")); + assertEquals("2", extensions.get("bar")); + assertNull(extensions.get("baz")); + assertEquals(2, extensions.size()); + } finally { + handler.close(); + } + } + + @Test + public void testHandleSaslExtensionsCallbackWithInvalidExtension() { + String illegalKey = "extension_" + OAuthBearerClientInitialResponse.AUTH_KEY; + + OAuthBearerLoginCallbackHandler handler = new OAuthBearerLoginCallbackHandler(); + Map configs = getSaslConfigs(SASL_OAUTHBEARER_TOKEN_ENDPOINT_URL, "http://www.example.com"); + Map jaasConfig = new HashMap<>(); + jaasConfig.put(CLIENT_ID_CONFIG, "an ID"); + jaasConfig.put(CLIENT_SECRET_CONFIG, "a secret"); + jaasConfig.put(illegalKey, "this key isn't allowed per OAuthBearerClientInitialResponse.validateExtensions"); + configureHandler(handler, configs, jaasConfig); + + try { + SaslExtensionsCallback callback = new SaslExtensionsCallback(); + assertThrowsWithMessage(ConfigException.class, + () -> handler.handle(new Callback[]{callback}), + "Extension name " + OAuthBearerClientInitialResponse.AUTH_KEY + " is invalid"); + } finally { + handler.close(); + } + } + + @Test + public void testInvalidCallbackGeneratesUnsupportedCallbackException() { + Map configs = getSaslConfigs(); + OAuthBearerLoginCallbackHandler handler = new OAuthBearerLoginCallbackHandler(); + AccessTokenRetriever accessTokenRetriever = () -> "foo"; + AccessTokenValidator accessTokenValidator = AccessTokenValidatorFactory.create(configs); + handler.init(accessTokenRetriever, accessTokenValidator); + + try { + Callback unsupportedCallback = new Callback() { }; + assertThrows(UnsupportedCallbackException.class, () -> handler.handle(new Callback[]{unsupportedCallback})); + } finally { + handler.close(); + } + } + + @Test + public void testInvalidAccessToken() throws Exception { + testInvalidAccessToken("this isn't valid", "Malformed JWT provided"); + testInvalidAccessToken("this.isn't.valid", "malformed Base64 URL encoded value"); + testInvalidAccessToken(createAccessKey("this", "isn't", "valid"), "malformed JSON"); + testInvalidAccessToken(createAccessKey("{}", "{}", "{}"), "exp value must be non-null"); + } + + @Test + public void testMissingAccessToken() { + AccessTokenRetriever 
accessTokenRetriever = () -> { + throw new IOException("The token endpoint response access_token value must be non-null"); + }; + Map configs = getSaslConfigs(); + OAuthBearerLoginCallbackHandler handler = createHandler(accessTokenRetriever, configs); + + try { + OAuthBearerTokenCallback callback = new OAuthBearerTokenCallback(); + assertThrowsWithMessage(IOException.class, + () -> handler.handle(new Callback[]{callback}), + "token endpoint response access_token value must be non-null"); + } finally { + handler.close(); + } + } + + @Test + public void testNotConfigured() { + OAuthBearerLoginCallbackHandler handler = new OAuthBearerLoginCallbackHandler(); + assertThrowsWithMessage(IllegalStateException.class, () -> handler.handle(new Callback[] {}), "first call the configure or init method"); + } + + @Test + public void testConfigureWithAccessTokenFile() throws Exception { + String expected = "{}"; + + File tmpDir = createTempDir("access-token"); + File accessTokenFile = createTempFile(tmpDir, "access-token-", ".json", expected); + + OAuthBearerLoginCallbackHandler handler = new OAuthBearerLoginCallbackHandler(); + Map configs = getSaslConfigs(SASL_OAUTHBEARER_TOKEN_ENDPOINT_URL, accessTokenFile.toURI().toString()); + Map jaasConfigs = Collections.emptyMap(); + configureHandler(handler, configs, jaasConfigs); + assertTrue(handler.getAccessTokenRetriever() instanceof FileTokenRetriever); + } + + @Test + public void testConfigureWithAccessClientCredentials() { + OAuthBearerLoginCallbackHandler handler = new OAuthBearerLoginCallbackHandler(); + Map configs = getSaslConfigs(SASL_OAUTHBEARER_TOKEN_ENDPOINT_URL, "http://www.example.com"); + Map jaasConfigs = new HashMap<>(); + jaasConfigs.put(CLIENT_ID_CONFIG, "an ID"); + jaasConfigs.put(CLIENT_SECRET_CONFIG, "a secret"); + configureHandler(handler, configs, jaasConfigs); + assertTrue(handler.getAccessTokenRetriever() instanceof HttpAccessTokenRetriever); + } + + private void testInvalidAccessToken(String accessToken, String expectedMessageSubstring) throws Exception { + Map configs = getSaslConfigs(); + OAuthBearerLoginCallbackHandler handler = createHandler(() -> accessToken, configs); + + try { + OAuthBearerTokenCallback callback = new OAuthBearerTokenCallback(); + handler.handle(new Callback[]{callback}); + + assertNull(callback.token()); + String actualMessage = callback.errorDescription(); + assertNotNull(actualMessage); + assertTrue(actualMessage.contains(expectedMessageSubstring), String.format( + "The error message \"%s\" didn't contain the expected substring \"%s\"", + actualMessage, expectedMessageSubstring)); + } finally { + handler.close(); + } + } + + private String createAccessKey(String header, String payload, String signature) { + Base64.Encoder enc = Base64.getEncoder(); + header = enc.encodeToString(Utils.utf8(header)); + payload = enc.encodeToString(Utils.utf8(payload)); + signature = enc.encodeToString(Utils.utf8(signature)); + return String.format("%s.%s.%s", header, payload, signature); + } + + private OAuthBearerLoginCallbackHandler createHandler(AccessTokenRetriever accessTokenRetriever, Map configs) { + OAuthBearerLoginCallbackHandler handler = new OAuthBearerLoginCallbackHandler(); + AccessTokenValidator accessTokenValidator = AccessTokenValidatorFactory.create(configs); + handler.init(accessTokenRetriever, accessTokenValidator); + return handler; + } + +} diff --git a/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/secured/OAuthBearerTest.java 
b/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/secured/OAuthBearerTest.java new file mode 100644 index 0000000..6fec08d --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/secured/OAuthBearerTest.java @@ -0,0 +1,198 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.common.security.oauthbearer.secured; + +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.node.ObjectNode; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.File; +import java.io.FileWriter; +import java.io.IOException; +import java.net.HttpURLConnection; +import java.net.URL; +import java.util.Arrays; +import java.util.Base64; +import java.util.Collections; +import java.util.Iterator; +import java.util.Map; +import java.util.concurrent.ExecutionException; +import java.util.function.Consumer; +import javax.security.auth.login.AppConfigurationEntry; +import org.apache.kafka.common.config.AbstractConfig; +import org.apache.kafka.common.config.ConfigDef; +import org.apache.kafka.common.security.auth.AuthenticateCallbackHandler; +import org.apache.kafka.common.security.authenticator.TestJaasConfig; +import org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginModule; +import org.apache.kafka.common.utils.Utils; +import org.junit.jupiter.api.TestInstance; +import org.junit.jupiter.api.TestInstance.Lifecycle; +import org.junit.jupiter.api.function.Executable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +@TestInstance(Lifecycle.PER_CLASS) +public abstract class OAuthBearerTest { + + protected final Logger log = LoggerFactory.getLogger(getClass()); + + protected ObjectMapper mapper = new ObjectMapper(); + + protected void assertThrowsWithMessage(Class clazz, + Executable executable, + String substring) { + boolean failed = false; + + try { + executable.execute(); + } catch (Throwable t) { + failed = true; + assertTrue(clazz.isInstance(t), String.format("Test failed by exception %s, but expected %s", t.getClass(), clazz)); + + assertErrorMessageContains(t.getMessage(), substring); + } + + if (!failed) + fail("Expected test to fail with " + clazz + " that contains the string " + substring); + } + + protected void assertErrorMessageContains(String actual, String expectedSubstring) { + assertTrue(actual.contains(expectedSubstring), + String.format("Expected exception message (\"%s\") to contain substring (\"%s\")", + actual, + expectedSubstring)); + } + + protected void 
configureHandler(AuthenticateCallbackHandler handler, + Map configs, + Map jaasConfig) { + TestJaasConfig config = new TestJaasConfig(); + config.createOrUpdateEntry("KafkaClient", OAuthBearerLoginModule.class.getName(), jaasConfig); + AppConfigurationEntry kafkaClient = config.getAppConfigurationEntry("KafkaClient")[0]; + + handler.configure(configs, + OAuthBearerLoginModule.OAUTHBEARER_MECHANISM, + Collections.singletonList(kafkaClient)); + } + + protected String createBase64JsonJwtSection(Consumer c) { + String json = createJsonJwtSection(c); + + try { + return Utils.utf8(Base64.getEncoder().encode(Utils.utf8(json))); + } catch (Throwable t) { + fail(t); + + // Shouldn't get to here... + return null; + } + } + + protected String createJsonJwtSection(Consumer c) { + ObjectNode node = mapper.createObjectNode(); + c.accept(node); + + try { + return mapper.writeValueAsString(node); + } catch (Throwable t) { + fail(t); + + // Shouldn't get to here... + return null; + } + } + + protected Retryable createRetryable(Exception[] attempts) { + Iterator i = Arrays.asList(attempts).iterator(); + + return () -> { + Exception e = i.hasNext() ? i.next() : null; + + if (e == null) { + return "success!"; + } else { + if (e instanceof IOException) + throw new ExecutionException(e); + else if (e instanceof RuntimeException) + throw (RuntimeException) e; + else + throw new RuntimeException(e); + } + }; + } + + protected HttpURLConnection createHttpURLConnection(String response) throws IOException { + HttpURLConnection mockedCon = mock(HttpURLConnection.class); + when(mockedCon.getURL()).thenReturn(new URL("https://www.example.com")); + when(mockedCon.getResponseCode()).thenReturn(200); + when(mockedCon.getOutputStream()).thenReturn(new ByteArrayOutputStream()); + when(mockedCon.getInputStream()).thenReturn(new ByteArrayInputStream(Utils.utf8(response))); + return mockedCon; + } + + protected File createTempDir(String directory) throws IOException { + File tmpDir = new File(System.getProperty("java.io.tmpdir")); + + if (directory != null) + tmpDir = new File(tmpDir, directory); + + if (!tmpDir.exists() && !tmpDir.mkdirs()) + throw new IOException("Could not create " + tmpDir); + + tmpDir.deleteOnExit(); + log.debug("Created temp directory {}", tmpDir); + return tmpDir; + } + + protected File createTempFile(File tmpDir, + String prefix, + String suffix, + String contents) + throws IOException { + File file = File.createTempFile(prefix, suffix, tmpDir); + log.debug("Created new temp file {}", file); + file.deleteOnExit(); + + try (FileWriter writer = new FileWriter(file)) { + writer.write(contents); + } + + return file; + } + + protected Map getSaslConfigs(Map configs) { + ConfigDef configDef = new ConfigDef(); + configDef.withClientSaslSupport(); + AbstractConfig sslClientConfig = new AbstractConfig(configDef, configs); + return sslClientConfig.values(); + } + + protected Map getSaslConfigs(String name, Object value) { + return getSaslConfigs(Collections.singletonMap(name, value)); + } + + protected Map getSaslConfigs() { + return getSaslConfigs(Collections.emptyMap()); + } + +} \ No newline at end of file diff --git a/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/secured/OAuthBearerValidatorCallbackHandlerTest.java b/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/secured/OAuthBearerValidatorCallbackHandlerTest.java new file mode 100644 index 0000000..326197d --- /dev/null +++ 
b/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/secured/OAuthBearerValidatorCallbackHandlerTest.java @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.common.security.oauthbearer.secured; + +import static org.apache.kafka.common.config.SaslConfigs.SASL_OAUTHBEARER_EXPECTED_AUDIENCE; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.util.Arrays; +import java.util.Base64; +import java.util.List; +import java.util.Map; +import javax.security.auth.callback.Callback; +import org.apache.kafka.common.security.oauthbearer.OAuthBearerToken; +import org.apache.kafka.common.security.oauthbearer.OAuthBearerValidatorCallback; +import org.apache.kafka.common.utils.Utils; +import org.junit.jupiter.api.Test; + +public class OAuthBearerValidatorCallbackHandlerTest extends OAuthBearerTest { + + @Test + public void testBasic() throws Exception { + String expectedAudience = "a"; + List allAudiences = Arrays.asList(expectedAudience, "b", "c"); + AccessTokenBuilder builder = new AccessTokenBuilder().audience(expectedAudience); + String accessToken = builder.build(); + + Map configs = getSaslConfigs(SASL_OAUTHBEARER_EXPECTED_AUDIENCE, allAudiences); + OAuthBearerValidatorCallbackHandler handler = createHandler(configs, builder); + + try { + OAuthBearerValidatorCallback callback = new OAuthBearerValidatorCallback(accessToken); + handler.handle(new Callback[]{callback}); + + assertNotNull(callback.token()); + OAuthBearerToken token = callback.token(); + assertEquals(accessToken, token.value()); + assertEquals(builder.subject(), token.principalName()); + assertEquals(builder.expirationSeconds() * 1000, token.lifetimeMs()); + assertEquals(builder.issuedAtSeconds() * 1000, token.startTimeMs()); + } finally { + handler.close(); + } + } + + @Test + public void testInvalidAccessToken() throws Exception { + // There aren't different error messages for the validation step, so these are all the + // same :( + String substring = "invalid_token"; + assertInvalidAccessTokenFails("this isn't valid", substring); + assertInvalidAccessTokenFails("this.isn't.valid", substring); + assertInvalidAccessTokenFails(createAccessKey("this", "isn't", "valid"), substring); + assertInvalidAccessTokenFails(createAccessKey("{}", "{}", "{}"), substring); + } + + private void assertInvalidAccessTokenFails(String accessToken, String expectedMessageSubstring) throws Exception { + Map configs = getSaslConfigs(); + OAuthBearerValidatorCallbackHandler handler = createHandler(configs, new AccessTokenBuilder()); + + try { + 
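+ // The validator handler is expected to report validation failures on the callback rather
+ // than throw: the assertions below check that no token was set and that an error status
+ // containing the coarse-grained "invalid_token" code was recorded.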
OAuthBearerValidatorCallback callback = new OAuthBearerValidatorCallback(accessToken); + handler.handle(new Callback[] {callback}); + + assertNull(callback.token()); + String actualMessage = callback.errorStatus(); + assertNotNull(actualMessage); + assertTrue(actualMessage.contains(expectedMessageSubstring), String.format("The error message \"%s\" didn't contain the expected substring \"%s\"", actualMessage, expectedMessageSubstring)); + } finally { + handler.close(); + } + } + + private OAuthBearerValidatorCallbackHandler createHandler(Map options, + AccessTokenBuilder builder) { + OAuthBearerValidatorCallbackHandler handler = new OAuthBearerValidatorCallbackHandler(); + CloseableVerificationKeyResolver verificationKeyResolver = (jws, nestingContext) -> + builder.jwk().getRsaPublicKey(); + AccessTokenValidator accessTokenValidator = AccessTokenValidatorFactory.create(options, verificationKeyResolver); + handler.init(verificationKeyResolver, accessTokenValidator); + return handler; + } + + private String createAccessKey(String header, String payload, String signature) { + Base64.Encoder enc = Base64.getEncoder(); + header = enc.encodeToString(Utils.utf8(header)); + payload = enc.encodeToString(Utils.utf8(payload)); + signature = enc.encodeToString(Utils.utf8(signature)); + return String.format("%s.%s.%s", header, payload, signature); + } + +} diff --git a/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/secured/RefreshingHttpsJwksTest.java b/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/secured/RefreshingHttpsJwksTest.java new file mode 100644 index 0000000..27711ea --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/secured/RefreshingHttpsJwksTest.java @@ -0,0 +1,197 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.common.security.oauthbearer.secured; + +import static org.apache.kafka.common.security.oauthbearer.secured.RefreshingHttpsJwks.MISSING_KEY_ID_CACHE_IN_FLIGHT_MS; +import static org.apache.kafka.common.security.oauthbearer.secured.RefreshingHttpsJwks.MISSING_KEY_ID_MAX_KEY_LENGTH; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.List; +import org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.common.utils.Time; +import org.jose4j.http.SimpleResponse; +import org.jose4j.jwk.HttpsJwks; +import org.junit.jupiter.api.Test; +import org.mockito.Mockito; + +public class RefreshingHttpsJwksTest extends OAuthBearerTest { + + private static final int REFRESH_MS = 5000; + + private static final int RETRY_BACKOFF_MS = 50; + + private static final int RETRY_BACKOFF_MAX_MS = 2000; + + /** + * Test that a key not previously scheduled for refresh will be scheduled without a refresh. + */ + + @Test + public void testBasicScheduleRefresh() throws Exception { + String keyId = "abc123"; + Time time = new MockTime(); + HttpsJwks httpsJwks = spyHttpsJwks(); + + try (RefreshingHttpsJwks refreshingHttpsJwks = getRefreshingHttpsJwks(time, httpsJwks)) { + refreshingHttpsJwks.init(); + verify(httpsJwks, times(1)).refresh(); + assertTrue(refreshingHttpsJwks.maybeExpediteRefresh(keyId)); + verify(httpsJwks, times(1)).refresh(); + } + } + + /** + * Test that a key previously scheduled for refresh will not be scheduled a second time + * if it's requested right away. + */ + + @Test + public void testMaybeExpediteRefreshNoDelay() throws Exception { + String keyId = "abc123"; + Time time = new MockTime(); + HttpsJwks httpsJwks = spyHttpsJwks(); + + try (RefreshingHttpsJwks refreshingHttpsJwks = getRefreshingHttpsJwks(time, httpsJwks)) { + refreshingHttpsJwks.init(); + assertTrue(refreshingHttpsJwks.maybeExpediteRefresh(keyId)); + assertFalse(refreshingHttpsJwks.maybeExpediteRefresh(keyId)); + } + } + + /** + * Test that a key previously scheduled for refresh will be scheduled a second time + * if it's requested after the delay. + */ + + @Test + public void testMaybeExpediteRefreshDelays() throws Exception { + assertMaybeExpediteRefreshWithDelay(MISSING_KEY_ID_CACHE_IN_FLIGHT_MS - 1, false); + assertMaybeExpediteRefreshWithDelay(MISSING_KEY_ID_CACHE_IN_FLIGHT_MS, true); + assertMaybeExpediteRefreshWithDelay(MISSING_KEY_ID_CACHE_IN_FLIGHT_MS + 1, true); + } + + /** + * Test that a "long key" will not be looked up because the key ID is too long. + */ + + @Test + public void testLongKey() throws Exception { + char[] keyIdChars = new char[MISSING_KEY_ID_MAX_KEY_LENGTH + 1]; + Arrays.fill(keyIdChars, '0'); + String keyId = new String(keyIdChars); + + Time time = new MockTime(); + HttpsJwks httpsJwks = spyHttpsJwks(); + + try (RefreshingHttpsJwks refreshingHttpsJwks = getRefreshingHttpsJwks(time, httpsJwks)) { + refreshingHttpsJwks.init(); + verify(httpsJwks, times(1)).refresh(); + assertFalse(refreshingHttpsJwks.maybeExpediteRefresh(keyId)); + verify(httpsJwks, times(1)).refresh(); + } + } + + /** + * Test that if we ask to load a missing key, and then we wait past the sleep time that it will + * call refresh to load the key. 
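+ * Note that the test body deliberately uses the real system clock (see the inline comment
+ * there): the scheduled refresh executor does not respect a mocked clock, so the delay has
+ * to elapse in real time before the expedited refresh can be observed.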
+ */ + + @Test + public void testSecondaryRefreshAfterElapsedDelay() throws Exception { + String keyId = "abc123"; + Time time = MockTime.SYSTEM; // Unfortunately, we can't mock time here because the + // scheduled executor doesn't respect it. + HttpsJwks httpsJwks = spyHttpsJwks(); + + try (RefreshingHttpsJwks refreshingHttpsJwks = getRefreshingHttpsJwks(time, httpsJwks)) { + refreshingHttpsJwks.init(); + verify(httpsJwks, times(1)).refresh(); + assertTrue(refreshingHttpsJwks.maybeExpediteRefresh(keyId)); + time.sleep(REFRESH_MS + 1); + verify(httpsJwks, times(3)).refresh(); + assertFalse(refreshingHttpsJwks.maybeExpediteRefresh(keyId)); + } + } + + private void assertMaybeExpediteRefreshWithDelay(long sleepDelay, boolean shouldBeScheduled) throws Exception { + String keyId = "abc123"; + Time time = new MockTime(); + HttpsJwks httpsJwks = spyHttpsJwks(); + + try (RefreshingHttpsJwks refreshingHttpsJwks = getRefreshingHttpsJwks(time, httpsJwks)) { + refreshingHttpsJwks.init(); + assertTrue(refreshingHttpsJwks.maybeExpediteRefresh(keyId)); + time.sleep(sleepDelay); + assertEquals(shouldBeScheduled, refreshingHttpsJwks.maybeExpediteRefresh(keyId)); + } + } + + private RefreshingHttpsJwks getRefreshingHttpsJwks(final Time time, final HttpsJwks httpsJwks) { + return new RefreshingHttpsJwks(time, httpsJwks, REFRESH_MS, RETRY_BACKOFF_MS, RETRY_BACKOFF_MAX_MS); + } + + /** + * We *spy* (not *mock*) the {@link HttpsJwks} instance because we want to have it + * _partially mocked_ to determine if it's calling its internal refresh method. We want to + * make sure it *doesn't* do that when we call our getJsonWebKeys() method on + * {@link RefreshingHttpsJwks}. + */ + + private HttpsJwks spyHttpsJwks() { + HttpsJwks httpsJwks = new HttpsJwks("https://www.example.com"); + + SimpleResponse simpleResponse = new SimpleResponse() { + @Override + public int getStatusCode() { + return 200; + } + + @Override + public String getStatusMessage() { + return "OK"; + } + + @Override + public Collection getHeaderNames() { + return Collections.emptyList(); + } + + @Override + public List getHeaderValues(String name) { + return Collections.emptyList(); + } + + @Override + public String getBody() { + return "{\"keys\": []}"; + } + }; + + httpsJwks.setSimpleHttpGet(l -> simpleResponse); + + return Mockito.spy(httpsJwks); + } + +} diff --git a/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/secured/RetryTest.java b/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/secured/RetryTest.java new file mode 100644 index 0000000..d04b8c5 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/secured/RetryTest.java @@ -0,0 +1,136 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.common.security.oauthbearer.secured; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +import java.io.IOException; +import java.util.concurrent.ExecutionException; +import org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.common.utils.Time; +import org.junit.jupiter.api.Test; + +public class RetryTest extends OAuthBearerTest { + + @Test + public void test() throws ExecutionException { + Exception[] attempts = new Exception[] { + new IOException("pretend connect error"), + new IOException("pretend timeout error"), + new IOException("pretend read error"), + null // success! + }; + long retryWaitMs = 1000; + long maxWaitMs = 10000; + Retryable call = createRetryable(attempts); + + Time time = new MockTime(0, 0, 0); + assertEquals(0L, time.milliseconds()); + Retry r = new Retry<>(time, retryWaitMs, maxWaitMs); + r.execute(call); + + long secondWait = retryWaitMs * 2; + long thirdWait = retryWaitMs * 4; + long totalWait = retryWaitMs + secondWait + thirdWait; + assertEquals(totalWait, time.milliseconds()); + } + + @Test + public void testIOExceptionFailure() { + Exception[] attempts = new Exception[] { + new IOException("pretend connect error"), + new IOException("pretend timeout error"), + new IOException("pretend read error"), + new IOException("pretend another read error"), + }; + long retryWaitMs = 1000; + long maxWaitMs = 1000 + 2000 + 3999; + Retryable call = createRetryable(attempts); + + Time time = new MockTime(0, 0, 0); + assertEquals(0L, time.milliseconds()); + Retry r = new Retry<>(time, retryWaitMs, maxWaitMs); + + assertThrows(ExecutionException.class, () -> r.execute(call)); + + assertEquals(maxWaitMs, time.milliseconds()); + } + + @Test + public void testRuntimeExceptionFailureOnLastAttempt() { + Exception[] attempts = new Exception[] { + new IOException("pretend connect error"), + new IOException("pretend timeout error"), + new NullPointerException("pretend JSON node /userId in response is null") + }; + long retryWaitMs = 1000; + long maxWaitMs = 10000; + Retryable call = createRetryable(attempts); + + Time time = new MockTime(0, 0, 0); + assertEquals(0L, time.milliseconds()); + Retry r = new Retry<>(time, retryWaitMs, maxWaitMs); + + assertThrows(RuntimeException.class, () -> r.execute(call)); + + long secondWait = retryWaitMs * 2; + long totalWait = retryWaitMs + secondWait; + assertEquals(totalWait, time.milliseconds()); + } + + @Test + public void testRuntimeExceptionFailureOnFirstAttempt() { + Exception[] attempts = new Exception[] { + new NullPointerException("pretend JSON node /userId in response is null"), + null + }; + long retryWaitMs = 1000; + long maxWaitMs = 10000; + Retryable call = createRetryable(attempts); + + Time time = new MockTime(0, 0, 0); + assertEquals(0L, time.milliseconds()); + Retry r = new Retry<>(time, retryWaitMs, maxWaitMs); + + assertThrows(RuntimeException.class, () -> r.execute(call)); + + assertEquals(0, time.milliseconds()); + } + + @Test + public void testUseMaxTimeout() throws IOException { + Exception[] attempts = new Exception[] { + new IOException("pretend connect error"), + new IOException("pretend timeout error"), + new IOException("pretend read error") + }; + long retryWaitMs = 5000; + long maxWaitMs = 5000; + Retryable call = createRetryable(attempts); + + Time time = new MockTime(0, 0, 0); + assertEquals(0L, time.milliseconds()); + Retry r = new Retry<>(time, retryWaitMs, maxWaitMs); + + 
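+ // retryWaitMs equals maxWaitMs here, so the first backoff alone is expected to exhaust the
+ // wait budget; per the assertions below, the retries give up with an ExecutionException
+ // after a total simulated wait of exactly maxWaitMs.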
assertThrows(ExecutionException.class, () -> r.execute(call)); + + assertEquals(maxWaitMs, time.milliseconds()); + } + +} diff --git a/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/secured/ValidatorAccessTokenValidatorTest.java b/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/secured/ValidatorAccessTokenValidatorTest.java new file mode 100644 index 0000000..76333e3 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/secured/ValidatorAccessTokenValidatorTest.java @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.common.security.oauthbearer.secured; + +import java.util.Collections; +import org.apache.kafka.common.security.oauthbearer.OAuthBearerToken; +import org.jose4j.jws.AlgorithmIdentifiers; +import org.jose4j.jws.JsonWebSignature; +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class ValidatorAccessTokenValidatorTest extends AccessTokenValidatorTest { + + @Override + protected AccessTokenValidator createAccessTokenValidator(AccessTokenBuilder builder) { + return new ValidatorAccessTokenValidator(30, + Collections.emptySet(), + null, + (jws, nestingContext) -> builder.jwk().getKey(), + builder.scopeClaimName(), + builder.subjectClaimName()); + } + + @Test + public void testBasicEncryption() throws Exception { + AccessTokenBuilder builder = new AccessTokenBuilder(); + AccessTokenValidator validator = createAccessTokenValidator(builder); + + JsonWebSignature jws = new JsonWebSignature(); + jws.setKey(builder.jwk().getPrivateKey()); + jws.setKeyIdHeaderValue(builder.jwk().getKeyId()); + jws.setAlgorithmHeaderValue(AlgorithmIdentifiers.RSA_USING_SHA256); + String accessToken = builder.build(); + + OAuthBearerToken token = validator.validate(accessToken); + + assertEquals(builder.subject(), token.principalName()); + assertEquals(builder.issuedAtSeconds() * 1000, token.startTimeMs()); + assertEquals(builder.expirationSeconds() * 1000, token.lifetimeMs()); + assertEquals(1, token.scope().size()); + } + +} diff --git a/clients/src/test/java/org/apache/kafka/common/security/plain/internals/PlainSaslServerTest.java b/clients/src/test/java/org/apache/kafka/common/security/plain/internals/PlainSaslServerTest.java new file mode 100644 index 0000000..77882b6 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/security/plain/internals/PlainSaslServerTest.java @@ -0,0 +1,117 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.security.plain.internals; + +import org.apache.kafka.common.security.plain.PlainLoginModule; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.nio.charset.StandardCharsets; +import java.util.HashMap; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +import org.apache.kafka.common.errors.SaslAuthenticationException; +import org.apache.kafka.common.security.JaasContext; +import org.apache.kafka.common.security.authenticator.TestJaasConfig; + +public class PlainSaslServerTest { + + private static final String USER_A = "userA"; + private static final String PASSWORD_A = "passwordA"; + private static final String USER_B = "userB"; + private static final String PASSWORD_B = "passwordB"; + + private PlainSaslServer saslServer; + + @BeforeEach + public void setUp() { + TestJaasConfig jaasConfig = new TestJaasConfig(); + Map options = new HashMap<>(); + options.put("user_" + USER_A, PASSWORD_A); + options.put("user_" + USER_B, PASSWORD_B); + jaasConfig.addEntry("jaasContext", PlainLoginModule.class.getName(), options); + JaasContext jaasContext = new JaasContext("jaasContext", JaasContext.Type.SERVER, jaasConfig, null); + PlainServerCallbackHandler callbackHandler = new PlainServerCallbackHandler(); + callbackHandler.configure(null, "PLAIN", jaasContext.configurationEntries()); + saslServer = new PlainSaslServer(callbackHandler); + } + + @Test + public void noAuthorizationIdSpecified() throws Exception { + byte[] nextChallenge = saslServer.evaluateResponse(saslMessage("", USER_A, PASSWORD_A)); + assertEquals(0, nextChallenge.length); + } + + @Test + public void authorizatonIdEqualsAuthenticationId() throws Exception { + byte[] nextChallenge = saslServer.evaluateResponse(saslMessage(USER_A, USER_A, PASSWORD_A)); + assertEquals(0, nextChallenge.length); + } + + @Test + public void authorizatonIdNotEqualsAuthenticationId() { + assertThrows(SaslAuthenticationException.class, () -> saslServer.evaluateResponse(saslMessage(USER_B, USER_A, PASSWORD_A))); + } + + @Test + public void emptyTokens() { + Exception e = assertThrows(SaslAuthenticationException.class, () -> + saslServer.evaluateResponse(saslMessage("", "", ""))); + assertEquals("Authentication failed: username not specified", e.getMessage()); + + e = assertThrows(SaslAuthenticationException.class, () -> + saslServer.evaluateResponse(saslMessage("", "", "p"))); + assertEquals("Authentication failed: username not specified", e.getMessage()); + + e = assertThrows(SaslAuthenticationException.class, () -> + saslServer.evaluateResponse(saslMessage("", "u", ""))); + assertEquals("Authentication failed: password not specified", e.getMessage()); + + e = assertThrows(SaslAuthenticationException.class, () -> + saslServer.evaluateResponse(saslMessage("a", "", ""))); + assertEquals("Authentication failed: 
username not specified", e.getMessage()); + + e = assertThrows(SaslAuthenticationException.class, () -> + saslServer.evaluateResponse(saslMessage("a", "", "p"))); + assertEquals("Authentication failed: username not specified", e.getMessage()); + + e = assertThrows(SaslAuthenticationException.class, () -> + saslServer.evaluateResponse(saslMessage("a", "u", ""))); + assertEquals("Authentication failed: password not specified", e.getMessage()); + + String nul = "\u0000"; + + e = assertThrows(SaslAuthenticationException.class, () -> + saslServer.evaluateResponse( + String.format("%s%s%s%s%s%s", "a", nul, "u", nul, "p", nul).getBytes(StandardCharsets.UTF_8))); + assertEquals("Invalid SASL/PLAIN response: expected 3 tokens, got 4", e.getMessage()); + + e = assertThrows(SaslAuthenticationException.class, () -> + saslServer.evaluateResponse( + String.format("%s%s%s", "", nul, "u").getBytes(StandardCharsets.UTF_8))); + assertEquals("Invalid SASL/PLAIN response: expected 3 tokens, got 2", e.getMessage()); + } + + private byte[] saslMessage(String authorizationId, String userName, String password) { + String nul = "\u0000"; + String message = String.format("%s%s%s%s%s", authorizationId, nul, userName, nul, password); + return message.getBytes(StandardCharsets.UTF_8); + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/security/scram/internals/ScramCredentialUtilsTest.java b/clients/src/test/java/org/apache/kafka/common/security/scram/internals/ScramCredentialUtilsTest.java new file mode 100644 index 0000000..749c4fd --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/security/scram/internals/ScramCredentialUtilsTest.java @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.security.scram.internals; + +import java.security.NoSuchAlgorithmException; +import java.util.Arrays; + +import org.apache.kafka.common.security.authenticator.CredentialCache; +import org.apache.kafka.common.security.scram.ScramCredential; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + + +public class ScramCredentialUtilsTest { + + private ScramFormatter formatter; + + @BeforeEach + public void setUp() throws NoSuchAlgorithmException { + formatter = new ScramFormatter(ScramMechanism.SCRAM_SHA_256); + } + + @Test + public void stringConversion() { + ScramCredential credential = formatter.generateCredential("password", 1024); + assertTrue(credential.salt().length > 0, "Salt must not be empty"); + assertTrue(credential.storedKey().length > 0, "Stored key must not be empty"); + assertTrue(credential.serverKey().length > 0, "Server key must not be empty"); + ScramCredential credential2 = ScramCredentialUtils.credentialFromString(ScramCredentialUtils.credentialToString(credential)); + assertArrayEquals(credential.salt(), credential2.salt()); + assertArrayEquals(credential.storedKey(), credential2.storedKey()); + assertArrayEquals(credential.serverKey(), credential2.serverKey()); + assertEquals(credential.iterations(), credential2.iterations()); + } + + @Test + public void generateCredential() { + ScramCredential credential1 = formatter.generateCredential("password", 4096); + ScramCredential credential2 = formatter.generateCredential("password", 4096); + // Random salt should ensure that the credentials persisted are different every time + assertNotEquals(ScramCredentialUtils.credentialToString(credential1), ScramCredentialUtils.credentialToString(credential2)); + } + + @Test + public void invalidCredential() { + assertThrows(IllegalArgumentException.class, () -> ScramCredentialUtils.credentialFromString("abc")); + } + + @Test + public void missingFields() { + String cred = ScramCredentialUtils.credentialToString(formatter.generateCredential("password", 2048)); + assertThrows(IllegalArgumentException.class, () -> ScramCredentialUtils.credentialFromString(cred.substring(cred.indexOf(',')))); + } + + @Test + public void extraneousFields() { + String cred = ScramCredentialUtils.credentialToString(formatter.generateCredential("password", 2048)); + assertThrows(IllegalArgumentException.class, () -> ScramCredentialUtils.credentialFromString(cred + ",a=test")); + } + + @Test + public void scramCredentialCache() throws Exception { + CredentialCache cache = new CredentialCache(); + ScramCredentialUtils.createCache(cache, Arrays.asList("SCRAM-SHA-512", "PLAIN")); + assertNotNull(cache.cache(ScramMechanism.SCRAM_SHA_512.mechanismName(), ScramCredential.class), "Cache not created for enabled mechanism"); + assertNull(cache.cache(ScramMechanism.SCRAM_SHA_256.mechanismName(), ScramCredential.class), "Cache created for disabled mechanism"); + + CredentialCache.Cache sha512Cache = cache.cache(ScramMechanism.SCRAM_SHA_512.mechanismName(), ScramCredential.class); + ScramFormatter formatter = new 
ScramFormatter(ScramMechanism.SCRAM_SHA_512); + ScramCredential credentialA = formatter.generateCredential("password", 4096); + sha512Cache.put("userA", credentialA); + assertEquals(credentialA, sha512Cache.get("userA")); + assertNull(sha512Cache.get("userB"), "Invalid user credential"); + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/security/scram/internals/ScramFormatterTest.java b/clients/src/test/java/org/apache/kafka/common/security/scram/internals/ScramFormatterTest.java new file mode 100644 index 0000000..8d7a8ec --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/security/scram/internals/ScramFormatterTest.java @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.security.scram.internals; + +import org.apache.kafka.common.security.scram.internals.ScramMessages.ClientFinalMessage; +import org.apache.kafka.common.security.scram.internals.ScramMessages.ClientFirstMessage; +import org.apache.kafka.common.security.scram.internals.ScramMessages.ServerFinalMessage; +import org.apache.kafka.common.security.scram.internals.ScramMessages.ServerFirstMessage; + +import org.junit.jupiter.api.Test; + +import java.util.Base64; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class ScramFormatterTest { + + /** + * Tests that the formatter implementation produces the same values for the + * example included in RFC 7677 + */ + @Test + public void rfc7677Example() throws Exception { + ScramFormatter formatter = new ScramFormatter(ScramMechanism.SCRAM_SHA_256); + + String password = "pencil"; + String c1 = "n,,n=user,r=rOprNGfwEbeRWgbNEkqO"; + String s1 = "r=rOprNGfwEbeRWgbNEkqO%hvYDpWUa2RaTCAfuxFIlj)hNlF$k0,s=W22ZaJ0SNY7soEsUEjb6gQ==,i=4096"; + String c2 = "c=biws,r=rOprNGfwEbeRWgbNEkqO%hvYDpWUa2RaTCAfuxFIlj)hNlF$k0,p=dHzbZapWIk4jUhN+Ute9ytag9zjfMHgsqmmiz7AndVQ="; + String s2 = "v=6rriTRBi23WpRR/wtup+mMhUZUn/dB5nLTJRsjl95G4="; + ClientFirstMessage clientFirst = new ClientFirstMessage(ScramFormatter.toBytes(c1)); + ServerFirstMessage serverFirst = new ServerFirstMessage(ScramFormatter.toBytes(s1)); + ClientFinalMessage clientFinal = new ClientFinalMessage(ScramFormatter.toBytes(c2)); + ServerFinalMessage serverFinal = new ServerFinalMessage(ScramFormatter.toBytes(s2)); + + String username = clientFirst.saslName(); + assertEquals("user", username); + String clientNonce = clientFirst.nonce(); + assertEquals("rOprNGfwEbeRWgbNEkqO", clientNonce); + String serverNonce = serverFirst.nonce().substring(clientNonce.length()); + assertEquals("%hvYDpWUa2RaTCAfuxFIlj)hNlF$k0", serverNonce); + byte[] salt = serverFirst.salt(); + 
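+ // The base64-encoded expected values below (salt, channel binding and server signature)
+ // are taken verbatim from the SCRAM-SHA-256 example exchange of RFC 7677 quoted above.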
assertArrayEquals(Base64.getDecoder().decode("W22ZaJ0SNY7soEsUEjb6gQ=="), salt); + int iterations = serverFirst.iterations(); + assertEquals(4096, iterations); + byte[] channelBinding = clientFinal.channelBinding(); + assertArrayEquals(Base64.getDecoder().decode("biws"), channelBinding); + byte[] serverSignature = serverFinal.serverSignature(); + assertArrayEquals(Base64.getDecoder().decode("6rriTRBi23WpRR/wtup+mMhUZUn/dB5nLTJRsjl95G4="), serverSignature); + + byte[] saltedPassword = formatter.saltedPassword(password, salt, iterations); + byte[] serverKey = formatter.serverKey(saltedPassword); + byte[] computedProof = formatter.clientProof(saltedPassword, clientFirst, serverFirst, clientFinal); + assertArrayEquals(clientFinal.proof(), computedProof); + byte[] computedSignature = formatter.serverSignature(serverKey, clientFirst, serverFirst, clientFinal); + assertArrayEquals(serverFinal.serverSignature(), computedSignature); + + // Minimum iterations defined in RFC-7677 + assertEquals(4096, ScramMechanism.SCRAM_SHA_256.minIterations()); + } + + /** + * Tests encoding of username + */ + @Test + public void saslName() throws Exception { + String[] usernames = {"user1", "123", "1,2", "user=A", "user==B", "user,1", "user 1", ",", "=", ",=", "=="}; + ScramFormatter formatter = new ScramFormatter(ScramMechanism.SCRAM_SHA_256); + for (String username : usernames) { + String saslName = ScramFormatter.saslName(username); + // There should be no commas in saslName (comma is used as field separator in SASL messages) + assertEquals(-1, saslName.indexOf(',')); + // There should be no "=" in the saslName apart from those used in encoding (comma is =2C and equals is =3D) + assertEquals(-1, saslName.replace("=2C", "").replace("=3D", "").indexOf('=')); + assertEquals(username, ScramFormatter.username(saslName)); + } + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/security/scram/internals/ScramMessagesTest.java b/clients/src/test/java/org/apache/kafka/common/security/scram/internals/ScramMessagesTest.java new file mode 100644 index 0000000..a286085 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/security/scram/internals/ScramMessagesTest.java @@ -0,0 +1,353 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.security.scram.internals; + +import java.nio.charset.StandardCharsets; +import java.util.Base64; +import java.util.Collections; + +import javax.security.sasl.SaslException; + +import org.apache.kafka.common.security.scram.internals.ScramMessages.AbstractScramMessage; +import org.apache.kafka.common.security.scram.internals.ScramMessages.ClientFinalMessage; +import org.apache.kafka.common.security.scram.internals.ScramMessages.ClientFirstMessage; +import org.apache.kafka.common.security.scram.internals.ScramMessages.ServerFinalMessage; +import org.apache.kafka.common.security.scram.internals.ScramMessages.ServerFirstMessage; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; + +public class ScramMessagesTest { + + private static final String[] VALID_EXTENSIONS = { + "ext=val1", + "anotherext=name1=value1 name2=another test value \"\'!$[]()", + "first=val1,second=name1 = value ,third=123" + }; + private static final String[] INVALID_EXTENSIONS = { + "ext1=value", + "ext", + "ext=value1,value2", + "ext=,", + "ext =value" + }; + + private static final String[] VALID_RESERVED = { + "m=reserved-value", + "m=name1=value1 name2=another test value \"\'!$[]()" + }; + private static final String[] INVALID_RESERVED = { + "m", + "m=name,value", + "m=," + }; + + private ScramFormatter formatter; + + @BeforeEach + public void setUp() throws Exception { + formatter = new ScramFormatter(ScramMechanism.SCRAM_SHA_256); + } + + @Test + public void validClientFirstMessage() throws SaslException { + String nonce = formatter.secureRandomString(); + ClientFirstMessage m = new ClientFirstMessage("someuser", nonce, Collections.emptyMap()); + checkClientFirstMessage(m, "someuser", nonce, ""); + + // Default format used by Kafka client: only user and nonce are specified + String str = String.format("n,,n=testuser,r=%s", nonce); + m = createScramMessage(ClientFirstMessage.class, str); + checkClientFirstMessage(m, "testuser", nonce, ""); + m = new ClientFirstMessage(m.toBytes()); + checkClientFirstMessage(m, "testuser", nonce, ""); + + // Username containing comma, encoded as =2C + str = String.format("n,,n=test=2Cuser,r=%s", nonce); + m = createScramMessage(ClientFirstMessage.class, str); + checkClientFirstMessage(m, "test=2Cuser", nonce, ""); + assertEquals("test,user", ScramFormatter.username(m.saslName())); + + // Username containing equals, encoded as =3D + str = String.format("n,,n=test=3Duser,r=%s", nonce); + m = createScramMessage(ClientFirstMessage.class, str); + checkClientFirstMessage(m, "test=3Duser", nonce, ""); + assertEquals("test=user", ScramFormatter.username(m.saslName())); + + // Optional authorization id specified + str = String.format("n,a=testauthzid,n=testuser,r=%s", nonce); + checkClientFirstMessage(createScramMessage(ClientFirstMessage.class, str), "testuser", nonce, "testauthzid"); + + // Optional reserved value specified + for (String reserved : VALID_RESERVED) { + str = String.format("n,,%s,n=testuser,r=%s", reserved, nonce); + checkClientFirstMessage(createScramMessage(ClientFirstMessage.class, str), "testuser", nonce, ""); + } + + // Optional extension specified + for (String extension : VALID_EXTENSIONS) { + str = 
String.format("n,,n=testuser,r=%s,%s", nonce, extension); + checkClientFirstMessage(createScramMessage(ClientFirstMessage.class, str), "testuser", nonce, ""); + } + + //optional tokenauth specified as extensions + str = String.format("n,,n=testuser,r=%s,%s", nonce, "tokenauth=true"); + m = createScramMessage(ClientFirstMessage.class, str); + assertTrue(m.extensions().tokenAuthenticated(), "Token authentication not set from extensions"); + } + + @Test + public void invalidClientFirstMessage() { + String nonce = formatter.secureRandomString(); + // Invalid entry in gs2-header + String invalid = String.format("n,x=something,n=testuser,r=%s", nonce); + checkInvalidScramMessage(ClientFirstMessage.class, invalid); + + // Invalid reserved entry + for (String reserved : INVALID_RESERVED) { + invalid = String.format("n,,%s,n=testuser,r=%s", reserved, nonce); + checkInvalidScramMessage(ClientFirstMessage.class, invalid); + } + + // Invalid extension + for (String extension : INVALID_EXTENSIONS) { + invalid = String.format("n,,n=testuser,r=%s,%s", nonce, extension); + checkInvalidScramMessage(ClientFirstMessage.class, invalid); + } + } + + @Test + public void validServerFirstMessage() throws SaslException { + String clientNonce = formatter.secureRandomString(); + String serverNonce = formatter.secureRandomString(); + String nonce = clientNonce + serverNonce; + String salt = randomBytesAsString(); + + ServerFirstMessage m = new ServerFirstMessage(clientNonce, serverNonce, toBytes(salt), 8192); + checkServerFirstMessage(m, nonce, salt, 8192); + + // Default format used by Kafka clients, only nonce, salt and iterations are specified + String str = String.format("r=%s,s=%s,i=4096", nonce, salt); + m = createScramMessage(ServerFirstMessage.class, str); + checkServerFirstMessage(m, nonce, salt, 4096); + m = new ServerFirstMessage(m.toBytes()); + checkServerFirstMessage(m, nonce, salt, 4096); + + // Optional reserved value + for (String reserved : VALID_RESERVED) { + str = String.format("%s,r=%s,s=%s,i=4096", reserved, nonce, salt); + checkServerFirstMessage(createScramMessage(ServerFirstMessage.class, str), nonce, salt, 4096); + } + + // Optional extension + for (String extension : VALID_EXTENSIONS) { + str = String.format("r=%s,s=%s,i=4096,%s", nonce, salt, extension); + checkServerFirstMessage(createScramMessage(ServerFirstMessage.class, str), nonce, salt, 4096); + } + } + + @Test + public void invalidServerFirstMessage() { + String nonce = formatter.secureRandomString(); + String salt = randomBytesAsString(); + + // Invalid iterations + String invalid = String.format("r=%s,s=%s,i=0", nonce, salt); + checkInvalidScramMessage(ServerFirstMessage.class, invalid); + + // Invalid salt + invalid = String.format("r=%s,s=%s,i=4096", nonce, "=123"); + checkInvalidScramMessage(ServerFirstMessage.class, invalid); + + // Invalid format + invalid = String.format("r=%s,invalid,s=%s,i=4096", nonce, salt); + checkInvalidScramMessage(ServerFirstMessage.class, invalid); + + // Invalid reserved entry + for (String reserved : INVALID_RESERVED) { + invalid = String.format("%s,r=%s,s=%s,i=4096", reserved, nonce, salt); + checkInvalidScramMessage(ServerFirstMessage.class, invalid); + } + + // Invalid extension + for (String extension : INVALID_EXTENSIONS) { + invalid = String.format("r=%s,s=%s,i=4096,%s", nonce, salt, extension); + checkInvalidScramMessage(ServerFirstMessage.class, invalid); + } + } + + @Test + public void validClientFinalMessage() throws SaslException { + String nonce = formatter.secureRandomString(); + 
String channelBinding = randomBytesAsString(); + String proof = randomBytesAsString(); + + ClientFinalMessage m = new ClientFinalMessage(toBytes(channelBinding), nonce); + assertNull(m.proof(), "Invalid proof"); + m.proof(toBytes(proof)); + checkClientFinalMessage(m, channelBinding, nonce, proof); + + // Default format used by Kafka client: channel-binding, nonce and proof are specified + String str = String.format("c=%s,r=%s,p=%s", channelBinding, nonce, proof); + m = createScramMessage(ClientFinalMessage.class, str); + checkClientFinalMessage(m, channelBinding, nonce, proof); + m = new ClientFinalMessage(m.toBytes()); + checkClientFinalMessage(m, channelBinding, nonce, proof); + + // Optional extension specified + for (String extension : VALID_EXTENSIONS) { + str = String.format("c=%s,r=%s,%s,p=%s", channelBinding, nonce, extension, proof); + checkClientFinalMessage(createScramMessage(ClientFinalMessage.class, str), channelBinding, nonce, proof); + } + } + + @Test + public void invalidClientFinalMessage() { + String nonce = formatter.secureRandomString(); + String channelBinding = randomBytesAsString(); + String proof = randomBytesAsString(); + + // Invalid channel binding + String invalid = String.format("c=ab,r=%s,p=%s", nonce, proof); + checkInvalidScramMessage(ClientFirstMessage.class, invalid); + + // Invalid proof + invalid = String.format("c=%s,r=%s,p=123", channelBinding, nonce); + checkInvalidScramMessage(ClientFirstMessage.class, invalid); + + // Invalid extensions + for (String extension : INVALID_EXTENSIONS) { + invalid = String.format("c=%s,r=%s,%s,p=%s", channelBinding, nonce, extension, proof); + checkInvalidScramMessage(ClientFinalMessage.class, invalid); + } + } + + @Test + public void validServerFinalMessage() throws SaslException { + String serverSignature = randomBytesAsString(); + + ServerFinalMessage m = new ServerFinalMessage("unknown-user", null); + checkServerFinalMessage(m, "unknown-user", null); + m = new ServerFinalMessage(null, toBytes(serverSignature)); + checkServerFinalMessage(m, null, serverSignature); + + // Default format used by Kafka clients for successful final message + String str = String.format("v=%s", serverSignature); + m = createScramMessage(ServerFinalMessage.class, str); + checkServerFinalMessage(m, null, serverSignature); + m = new ServerFinalMessage(m.toBytes()); + checkServerFinalMessage(m, null, serverSignature); + + // Default format used by Kafka clients for final message with error + str = "e=other-error"; + m = createScramMessage(ServerFinalMessage.class, str); + checkServerFinalMessage(m, "other-error", null); + m = new ServerFinalMessage(m.toBytes()); + checkServerFinalMessage(m, "other-error", null); + + // Optional extension + for (String extension : VALID_EXTENSIONS) { + str = String.format("v=%s,%s", serverSignature, extension); + checkServerFinalMessage(createScramMessage(ServerFinalMessage.class, str), null, serverSignature); + } + } + + @Test + public void invalidServerFinalMessage() { + String serverSignature = randomBytesAsString(); + + // Invalid error + String invalid = "e=error1,error2"; + checkInvalidScramMessage(ServerFinalMessage.class, invalid); + + // Invalid server signature + invalid = String.format("v=1=23"); + checkInvalidScramMessage(ServerFinalMessage.class, invalid); + + // Invalid extensions + for (String extension : INVALID_EXTENSIONS) { + invalid = String.format("v=%s,%s", serverSignature, extension); + checkInvalidScramMessage(ServerFinalMessage.class, invalid); + + invalid = 
String.format("e=unknown-user,%s", extension); + checkInvalidScramMessage(ServerFinalMessage.class, invalid); + } + } + + private String randomBytesAsString() { + return Base64.getEncoder().encodeToString(formatter.secureRandomBytes()); + } + + private byte[] toBytes(String base64Str) { + return Base64.getDecoder().decode(base64Str); + } + + private void checkClientFirstMessage(ClientFirstMessage message, String saslName, String nonce, String authzid) { + assertEquals(saslName, message.saslName()); + assertEquals(nonce, message.nonce()); + assertEquals(authzid, message.authorizationId()); + } + + private void checkServerFirstMessage(ServerFirstMessage message, String nonce, String salt, int iterations) { + assertEquals(nonce, message.nonce()); + assertArrayEquals(Base64.getDecoder().decode(salt), message.salt()); + assertEquals(iterations, message.iterations()); + } + + private void checkClientFinalMessage(ClientFinalMessage message, String channelBinding, String nonce, String proof) { + assertArrayEquals(Base64.getDecoder().decode(channelBinding), message.channelBinding()); + assertEquals(nonce, message.nonce()); + assertArrayEquals(Base64.getDecoder().decode(proof), message.proof()); + } + + private void checkServerFinalMessage(ServerFinalMessage message, String error, String serverSignature) { + assertEquals(error, message.error()); + if (serverSignature == null) + assertNull(message.serverSignature(), "Unexpected server signature"); + else + assertArrayEquals(Base64.getDecoder().decode(serverSignature), message.serverSignature()); + } + + @SuppressWarnings("unchecked") + private T createScramMessage(Class clazz, String message) throws SaslException { + byte[] bytes = message.getBytes(StandardCharsets.UTF_8); + if (clazz == ClientFirstMessage.class) + return (T) new ClientFirstMessage(bytes); + else if (clazz == ServerFirstMessage.class) + return (T) new ServerFirstMessage(bytes); + else if (clazz == ClientFinalMessage.class) + return (T) new ClientFinalMessage(bytes); + else if (clazz == ServerFinalMessage.class) + return (T) new ServerFinalMessage(bytes); + else + throw new IllegalArgumentException("Unknown message type: " + clazz); + } + + private void checkInvalidScramMessage(Class clazz, String message) { + try { + createScramMessage(clazz, message); + fail("Exception not throws for invalid message of type " + clazz + " : " + message); + } catch (SaslException e) { + // Expected exception + } + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/security/scram/internals/ScramSaslServerTest.java b/clients/src/test/java/org/apache/kafka/common/security/scram/internals/ScramSaslServerTest.java new file mode 100644 index 0000000..121ac59 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/security/scram/internals/ScramSaslServerTest.java @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.common.security.scram.internals;
+
+
+import java.nio.charset.StandardCharsets;
+import java.util.HashMap;
+
+import org.apache.kafka.common.errors.SaslAuthenticationException;
+import org.apache.kafka.common.security.authenticator.CredentialCache;
+import org.apache.kafka.common.security.scram.ScramCredential;
+import org.apache.kafka.common.security.token.delegation.internals.DelegationTokenCache;
+
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+
+import static org.junit.jupiter.api.Assertions.assertThrows;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+public class ScramSaslServerTest {
+
+    private static final String USER_A = "userA";
+    private static final String USER_B = "userB";
+
+    private ScramMechanism mechanism;
+    private ScramFormatter formatter;
+    private ScramSaslServer saslServer;
+
+    @BeforeEach
+    public void setUp() throws Exception {
+        mechanism = ScramMechanism.SCRAM_SHA_256;
+        formatter = new ScramFormatter(mechanism);
+        CredentialCache.Cache<ScramCredential> credentialCache = new CredentialCache().createCache(mechanism.mechanismName(), ScramCredential.class);
+        credentialCache.put(USER_A, formatter.generateCredential("passwordA", 4096));
+        credentialCache.put(USER_B, formatter.generateCredential("passwordB", 4096));
+        ScramServerCallbackHandler callbackHandler = new ScramServerCallbackHandler(credentialCache, new DelegationTokenCache(ScramMechanism.mechanismNames()));
+        saslServer = new ScramSaslServer(mechanism, new HashMap<>(), callbackHandler);
+    }
+
+    @Test
+    public void noAuthorizationIdSpecified() throws Exception {
+        byte[] nextChallenge = saslServer.evaluateResponse(clientFirstMessage(USER_A, null));
+        assertTrue(nextChallenge.length > 0, "Next challenge is empty");
+    }
+
+    @Test
+    public void authorizationIdEqualsAuthenticationId() throws Exception {
+        byte[] nextChallenge = saslServer.evaluateResponse(clientFirstMessage(USER_A, USER_A));
+        assertTrue(nextChallenge.length > 0, "Next challenge is empty");
+    }
+
+    @Test
+    public void authorizationIdNotEqualsAuthenticationId() {
+        assertThrows(SaslAuthenticationException.class, () -> saslServer.evaluateResponse(clientFirstMessage(USER_A, USER_B)));
+    }
+
+    private byte[] clientFirstMessage(String userName, String authorizationId) {
+        String nonce = formatter.secureRandomString();
+        String authorizationField = authorizationId != null ? "a=" + authorizationId : "";
+        String firstMessage = String.format("n,%s,n=%s,r=%s", authorizationField, userName, nonce);
+        return firstMessage.getBytes(StandardCharsets.UTF_8);
+    }
+}
diff --git a/clients/src/test/java/org/apache/kafka/common/security/ssl/DefaultSslEngineFactoryTest.java b/clients/src/test/java/org/apache/kafka/common/security/ssl/DefaultSslEngineFactoryTest.java
new file mode 100644
index 0000000..be45729
--- /dev/null
+++ b/clients/src/test/java/org/apache/kafka/common/security/ssl/DefaultSslEngineFactoryTest.java
@@ -0,0 +1,324 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.security.ssl; + +import org.apache.kafka.common.config.SslConfigs; +import org.apache.kafka.common.config.types.Password; +import org.apache.kafka.common.errors.InvalidConfigurationException; +import org.apache.kafka.test.TestUtils; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.security.KeyStore; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; + +public class DefaultSslEngineFactoryTest { + + /* + * Key and certificates were extracted using openssl from a key store file created with 100 years validity using: + * + * openssl pkcs12 -in server.keystore.p12 -nodes -nocerts -out test.key.pem -passin pass:key-password + * openssl pkcs12 -in server.keystore.p12 -nodes -nokeys -out test.certchain.pem -passin pass:key-password + * openssl pkcs12 -in server.keystore.p12 -nodes -out test.keystore.pem -passin pass:key-password + * openssl pkcs8 -topk8 -v1 pbeWithSHA1And3-KeyTripleDES-CBC -in test.key.pem -out test.key.encrypted.pem -passout pass:key-password + */ + + private static final String CA1 = "-----BEGIN CERTIFICATE-----\n" + + "MIIC0zCCAbugAwIBAgIEStdXHTANBgkqhkiG9w0BAQsFADASMRAwDgYDVQQDEwdU\n" + + "ZXN0Q0ExMCAXDTIwMDkyODA5MDI0MFoYDzIxMjAwOTA0MDkwMjQwWjASMRAwDgYD\n" + + "VQQDEwdUZXN0Q0ExMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAo3Gr\n" + + "WJAkjnvgcuIfjArDhNdtAlRTt094WMUXhYDibgGtd+CLcWqA+c4PEoK4oybnKZqU\n" + + "6MlDfPgesIK2YiNBuSVWMtZ2doageOBnd80Iwbg8DqGtQpUsvw8X5fOmuza+4inv\n" + + "/8IpiTizq8YjSMT4nYDmIjyyRCSNY4atjgMnskutJ0v6i69+ZAA520Y6nn2n4RD5\n" + + "8Yc+y7yCkbZXnYS5xBOFEExmtc0Xa7S9nM157xqKws9Z+rTKZYLrryaHI9JNcXgG\n" + + "kzQEH9fBePASeWfi9AGRvAyS2GMSIBOsihIDIha/mqQcJOGCEqTMtefIj2FaErO2\n" + + "bL9yU7OpW53iIC8y0QIDAQABoy8wLTAMBgNVHRMEBTADAQH/MB0GA1UdDgQWBBRf\n" + + "svKcoQ9ZBvjwyUSV2uMFzlkOWDANBgkqhkiG9w0BAQsFAAOCAQEAEE1ZG2MGE248\n" + + "glO83ROrHbxmnVWSQHt/JZANR1i362sY1ekL83wlhkriuvGVBlHQYWezIfo/4l9y\n" + + "JTHNX3Mrs9eWUkaDXADkHWj3AyLXN3nfeU307x1wA7OvI4YKpwvfb4aYS8RTPz9d\n" + + "JtrfR0r8aGTgsXvCe4SgwDBKv7bckctOwD3S7D/b6y3w7X0s7JCU5+8ZjgoYfcLE\n" + + "gNqQEaOwdT2LHCvxHmGn/2VGs/yatPQIYYuufe5i8yX7pp4Xbd2eD6LULYkHFs3x\n" + + "uJzMRI7BukmIIWuBbAkYI0atxLQIysnVFXdL9pBgvgso2nA3FgP/XeORhkyHVvtL\n" + + "REH2YTlftQ==\n" + + "-----END CERTIFICATE-----"; + + private static final String CA2 = "-----BEGIN CERTIFICATE-----\n" + + "MIIC0zCCAbugAwIBAgIEfk9e9DANBgkqhkiG9w0BAQsFADASMRAwDgYDVQQDEwdU\n" + + "ZXN0Q0EyMCAXDTIwMDkyODA5MDI0MVoYDzIxMjAwOTA0MDkwMjQxWjASMRAwDgYD\n" + + "VQQDEwdUZXN0Q0EyMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAvCh0\n" + + "UO5op9eHfz7mvZ7IySK7AOCTC56QYFJcU+hD6yk1wKg2qot7naI5ozAc8n7c4pMt\n" + + "LjI3D0VtC/oHC29R2HNMSWyHcxIXw8z127XeCLRkCqYWuVAl3nBuWfWVPObjKetH\n" + + "TWlQANYWAfk1VbS6wfzgp9cMaK7wQ+VoGEo4x3pjlrdlyg4k4O2yubcpWmJ2TjxS\n" + + 
"gg7TfKGizUVAvF9wUG9Q4AlCg4uuww5RN9w6vnzDKGhWJhkQ6pf/m1xB+WueFOeU\n" + + "aASGhGqCTqiz3p3M3M4OZzG3KptjQ/yb67x4T5U5RxqoiN4L57E7ZJLREpa6ZZNs\n" + + "ps/gQ8dR9Uo/PRyAkQIDAQABoy8wLTAMBgNVHRMEBTADAQH/MB0GA1UdDgQWBBRg\n" + + "IAOVH5LeE6nZmdScEE3JO/AhvTANBgkqhkiG9w0BAQsFAAOCAQEAHkk1iybwy/Lf\n" + + "iEQMVRy7XfuC008O7jfCUBMgUvE+oO2RadH5MmsXHG3YerdsDM90dui4JqQNZOUh\n" + + "kF8dIWPQHE0xDsR9jiUsemZFpVMN7DcvVZ3eFhbvJA8Q50rxcNGA+tn9xT/xdQ6z\n" + + "1eRq9IPoYcRexQ7s9mincM4T4lLm8GGcd7ZPHy8kw0Bp3E/enRHWaF5b8KbXezXD\n" + + "I3SEYUyRL2K3px4FImT4X9XQm2EX6EONlu4GRcJpD6RPc0zC7c9dwEnSo+0NnewR\n" + + "gjgO34CLzShB/kASLS9VQXcUC6bsggAVK2rWQMmy35SOEUufSuvg8kUFoyuTzfhn\n" + + "hL+PVwIu7g==\n" + + "-----END CERTIFICATE-----"; + + private static final String CERTCHAIN = "Bag Attributes\n" + + " friendlyName: server\n" + + " localKeyID: 54 69 6D 65 20 31 36 30 31 32 38 33 37 36 35 34 32 33 \n" + + "subject=/CN=TestBroker\n" + + "issuer=/CN=TestCA1\n" + + "-----BEGIN CERTIFICATE-----\n" + + "MIIC/zCCAeegAwIBAgIEatBnEzANBgkqhkiG9w0BAQsFADASMRAwDgYDVQQDEwdU\n" + + "ZXN0Q0ExMCAXDTIwMDkyODA5MDI0NFoYDzIxMjAwOTA0MDkwMjQ0WjAVMRMwEQYD\n" + + "VQQDEwpUZXN0QnJva2VyMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA\n" + + "pkw1AS71ej/iOMvzVgVL1dkQOYzI842NcPmx0yFFsue2umL8WVd3085NgWRb3SS1\n" + + "4X676t7zxjPGzYi7jwmA8stCrDt0NAPWd/Ko6ErsCs87CUs4u1Cinf+b3o9NF5u0\n" + + "UPYBQLF4Ir8T1jQ+tKiqsChGDt6urRAg1Cro5i7r10jN1uofY2tBs+r8mALhJ17c\n" + + "T5LKawXeYwNOQ86c5djClbcP0RrfcPyRyj1/Cp1axo28iO0fXFyO2Zf3a4vtt+Ih\n" + + "PW+A2tL+t3JTBd8g7Fl3ozzpcotAi7MDcZaYA9GiTP4DOiKUeDt6yMYQQr3VEqGa\n" + + "pXp4fKY+t9slqnAmcBZ4kQIDAQABo1gwVjAfBgNVHSMEGDAWgBRfsvKcoQ9ZBvjw\n" + + "yUSV2uMFzlkOWDAUBgNVHREEDTALgglsb2NhbGhvc3QwHQYDVR0OBBYEFGWt+27P\n" + + "INk/S5X+PRV/jW3WOhtaMA0GCSqGSIb3DQEBCwUAA4IBAQCLHCjFFvqa+0GcG9eq\n" + + "v1QWaXDohY5t5CCwD8Z+lT9wcSruTxDPwL7LrR36h++D6xJYfiw4iaRighoA40xP\n" + + "W6+0zGK/UtWV4t+ODTDzyAWgls5w+0R5ki6447qGqu5tXlW5DCHkkxWiozMnhNU2\n" + + "G3P/Drh7DhmADDBjtVLsu5M1sagF/xwTP/qCLMdChlJNdeqyLnAUa9SYG1eNZS/i\n" + + "wrCC8m9RUQb4+OlQuFtr0KhaaCkBXfmhigQAmh44zSyO+oa3qQDEavVFo/Mcui9o\n" + + "WBYetcgVbXPNoti+hQEMqmJYBHlLbhxMnkooGn2fa70f453Bdu/Xh6Yphi5NeCHn\n" + + "1I+y\n" + + "-----END CERTIFICATE-----\n" + + "Bag Attributes\n" + + " friendlyName: CN=TestCA1\n" + + "subject=/CN=TestCA1\n" + + "issuer=/CN=TestCA1\n" + + "-----BEGIN CERTIFICATE-----\n" + + "MIIC0zCCAbugAwIBAgIEStdXHTANBgkqhkiG9w0BAQsFADASMRAwDgYDVQQDEwdU\n" + + "ZXN0Q0ExMCAXDTIwMDkyODA5MDI0MFoYDzIxMjAwOTA0MDkwMjQwWjASMRAwDgYD\n" + + "VQQDEwdUZXN0Q0ExMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAo3Gr\n" + + "WJAkjnvgcuIfjArDhNdtAlRTt094WMUXhYDibgGtd+CLcWqA+c4PEoK4oybnKZqU\n" + + "6MlDfPgesIK2YiNBuSVWMtZ2doageOBnd80Iwbg8DqGtQpUsvw8X5fOmuza+4inv\n" + + "/8IpiTizq8YjSMT4nYDmIjyyRCSNY4atjgMnskutJ0v6i69+ZAA520Y6nn2n4RD5\n" + + "8Yc+y7yCkbZXnYS5xBOFEExmtc0Xa7S9nM157xqKws9Z+rTKZYLrryaHI9JNcXgG\n" + + "kzQEH9fBePASeWfi9AGRvAyS2GMSIBOsihIDIha/mqQcJOGCEqTMtefIj2FaErO2\n" + + "bL9yU7OpW53iIC8y0QIDAQABoy8wLTAMBgNVHRMEBTADAQH/MB0GA1UdDgQWBBRf\n" + + "svKcoQ9ZBvjwyUSV2uMFzlkOWDANBgkqhkiG9w0BAQsFAAOCAQEAEE1ZG2MGE248\n" + + "glO83ROrHbxmnVWSQHt/JZANR1i362sY1ekL83wlhkriuvGVBlHQYWezIfo/4l9y\n" + + "JTHNX3Mrs9eWUkaDXADkHWj3AyLXN3nfeU307x1wA7OvI4YKpwvfb4aYS8RTPz9d\n" + + "JtrfR0r8aGTgsXvCe4SgwDBKv7bckctOwD3S7D/b6y3w7X0s7JCU5+8ZjgoYfcLE\n" + + "gNqQEaOwdT2LHCvxHmGn/2VGs/yatPQIYYuufe5i8yX7pp4Xbd2eD6LULYkHFs3x\n" + + "uJzMRI7BukmIIWuBbAkYI0atxLQIysnVFXdL9pBgvgso2nA3FgP/XeORhkyHVvtL\n" + + "REH2YTlftQ==\n" + + "-----END CERTIFICATE-----"; + + private static final String KEY = "Bag Attributes\n" + + " 
friendlyName: server\n" + + " localKeyID: 54 69 6D 65 20 31 36 30 31 32 38 33 37 36 35 34 32 33\n" + + "Key Attributes: \n" + + "-----BEGIN PRIVATE KEY-----\n" + + "MIIEvAIBADANBgkqhkiG9w0BAQEFAASCBKYwggSiAgEAAoIBAQCmTDUBLvV6P+I4\n" + + "y/NWBUvV2RA5jMjzjY1w+bHTIUWy57a6YvxZV3fTzk2BZFvdJLXhfrvq3vPGM8bN\n" + + "iLuPCYDyy0KsO3Q0A9Z38qjoSuwKzzsJSzi7UKKd/5vej00Xm7RQ9gFAsXgivxPW\n" + + "ND60qKqwKEYO3q6tECDUKujmLuvXSM3W6h9ja0Gz6vyYAuEnXtxPksprBd5jA05D\n" + + "zpzl2MKVtw/RGt9w/JHKPX8KnVrGjbyI7R9cXI7Zl/dri+234iE9b4Da0v63clMF\n" + + "3yDsWXejPOlyi0CLswNxlpgD0aJM/gM6IpR4O3rIxhBCvdUSoZqlenh8pj632yWq\n" + + "cCZwFniRAgMBAAECggEAOfC/XwQvf0KW3VciF0yNGZshbgvBUCp3p284J+ml0Smu\n" + + "ns4yQiaZl3B/zJ9c6nYJ8OEpNDIuGVac46vKPZIAHZf4SO4GFMFpji078IN6LmH5\n" + + "nclZoNn9brNKaYbgQ2N6teKgmRu8Uc7laHKXjnZd0jaWAkRP8/h0l7fDob+jaERj\n" + + "oJBx4ux2Z62TTCP6W4VY3KZgSL1p6dQswqlukPVytMeI2XEwWnO+w8ED0BxCxM4F\n" + + "K//dw7nUMGS9GUNkgyDcH1akYSCDzdBeymQBp2latBotVfGNK1hq9nC1iaxmRkJL\n" + + "sYjwVc24n37u+txOovy3daq2ySj9trF7ySAPVYkh4QKBgQDWeN/MR6cy1TLF2j3g\n" + + "eMMeM32LxXArIPsar+mft+uisKWk5LDpsKpph93sl0JjFi4x0t1mqw23h23I+B2c\n" + + "JWiPAHUG3FGvvkPPcfMUvd7pODyE2XaXi+36UZAH7qc94VZGJEb+sPITckSruREE\n" + + "QErWZyrbBRgvQXsmVme5B2/kRQKBgQDGf2HQH0KHl54O2r9vrhiQxWIIMSWlizJC\n" + + "hjboY6DkIsAMwnXp3wn3Bk4tSgeLk8DEVlmEaE3gvGpiIp0vQnSOlME2TXfEthdM\n" + + "uS3+BFXN4Vxxx/qjKL2WfZloyzdaaaF7s+LIwmXgLsFFCUSq+uLtBqfpH2Qv+paX\n" + + "Xqm7LN3V3QKBgH5ssj/Q3RZx5oQKqf7wMNRUteT2dbB2uI56s9SariQwzPPuevrG\n" + + "US30ETWt1ExkfsaP7kLfAi71fhnBaHLq+j+RnWp15REbrw1RtmC7q/L+W25UYjvj\n" + + "GF0+RxDl9V/cvOaL6+2mkIw2B5TSet1uqK7KEdEZp6/zgYyP0oSXhbWhAoGAdnlZ\n" + + "HCtMPjnUcPFHCZVTvDTTSihrW9805FfPNe0g/olvLy5xymEBRZtR1d41mq1ZhNY1\n" + + "H75RnS1YIbKfNrHnd6J5n7ulHJfCWFy+grp7rCIyVwcRJYkPf17/zXhdVW1uoLLB\n" + + "TSoaPDAr0tSxU4vjHa23UoEV/z0F3Nr3W2xwC1ECgYBHKjv6ekLhx7HbP797+Ai+\n" + + "wkHvS2L/MqEBxuHzcQ9G6Mj3ANAeyDB8YSC8qGtDQoEyukv2dO73lpodNgbR8P+Q\n" + + "PDBb6eyntAo2sSeo0jZkiXvDOfRaGuGVrxjuTfaqcVB33jC6BYfi61/3Sr5oG9Nd\n" + + "tDGh1HlOIRm1jD9KQNVZ/Q==\n" + + "-----END PRIVATE KEY-----"; + + private static final String ENCRYPTED_KEY = "-----BEGIN ENCRYPTED PRIVATE KEY-----\n" + + "MIIE6jAcBgoqhkiG9w0BDAEDMA4ECGyAEWAXlaXzAgIIAASCBMgt7QD1Bbz7MAHI\n" + + "Ni0eTrwNiuAPluHirLXzsV57d1O9i4EXVp5nzRy6753cjXbGXARbBeaJD+/+jbZp\n" + + "CBZTHMG8rTCfbsg5kMqxT6XuuqWlKLKc4gaq+QNgHHleKqnpwZQmOQ+awKWEK/Ow\n" + + "Z0KxXqkp+b4/qJK3MqKZDsJtVdyUhO0tLVxd+BHDg9B93oExc87F16h3R0+T4rxE\n" + + "Tvz2c2upBqva49AbLDxpWXLCJC8CRkxM+KHrPkYjpNx3jCjtwiiXfzJCWjuCkVrL\n" + + "2F4bqvpYPIseoPtMvWaplNtoPwhpzBB/hoJ+R+URr4XHX3Y+bz6k6iQnhoCOIviy\n" + + "oEEUvWtKnaEEKSauR+Wyj3MoeB64g9NWMEHv7+SQeA4WqlgV2s4txwRxFGKyKLPq\n" + + "caMSpfxvYujtSh0DOv9GI3cVHPM8WsebCz9cNrbKSR8/8JufcoonTitwF/4vm1Et\n" + + "AdmCuH9JIYVvmFKFVxY9SvRAvo43OQaPmJQHMUa4yDfMtpTSgmB/7HFgxtksYs++\n" + + "Gbrq6F/hon+0bLx+bMz2FK635UU+iVno+qaScKWN3BFqDl+KnZprBhLSXTT3aHmp\n" + + "fisQit/HWp71a0Vzq85WwI4ucMKNc8LemlwNBxWLLiJDp7sNPLb5dIl8yIwSEIgd\n" + + "vC5px9KWEdt3GxTUEqtIeBmagbBhahcv+c9Dq924DLI+Slv6TJKZpIcMqUECgzvi\n" + + "hb8gegyEscBEcDSzl0ojlFVz4Va5eZS/linTjNJhnkx8BKLn/QFco7FpEE6uOmQ3\n" + + "0kF64M2Rv67cJbYVrhD46TgIzH3Y/FOMSi1zFHQ14nVXWMu0yAlBX+QGk7Xl+/aF\n" + + "BIq+i9WcBqbttR3CwyeTnIFXkdC66iTZYhDl9HT6yMcazql2Or2TjIIWr6tfNWH/\n" + + "5dWSEHYM5m8F2/wF0ANWJyR1oPr4ckcUsfl5TfOWVj5wz4QVF6EGV7FxEnQHrdx0\n" + + "6rXThRKFjqxUubsNt1yUEwdlTNz2UFhobGF9MmFeB97BZ6T4v8G825de/Caq9FzO\n" + + "yMFFCRcGC7gIzMXRPEjHIvBdTThm9rbNzKPXHqw0LHG478yIqzxvraCYTRw/4eWN\n" + + "Q+hyOL/5T5QNXHpR8Udp/7sptw7HfRnecQ/Vz9hOKShQq3h4Sz6eQMQm7P9qGo/N\n" + + 
"bltEAIECRVcNYLN8LuEORfeecNcV3BX+4BBniFtdD2bIRsWC0ZUsGf14Yhr4P1OA\n" + + "PtMJzy99mrcq3h+o+hEW6bhIj1gA88JSMJ4iRuwTLRKE81w7EyziScDsotYKvDPu\n" + + "w4+PFbQO3fr/Zga3LgYis8/DMqZoWjVCjAeVoypuOZreieZYC/BgBS8qSUAmDPKq\n" + + "jK+T5pwMMchfXbkV80LTu1kqLfKWdE0AmZfGy8COE/NNZ/FeiWZPdwu2Ix6u/RoY\n" + + "LTjNy4YLIBdVELFXaFJF2GfzLpnwrW5tyNPVVrGmUoiyOzgx8gMyCLGavGtduyoY\n" + + "tBiUTmd05Ugscn4Rz9X30S4NbnjL/h+bWl1m6/M+9FHEe85FPxmt/GRmJPbFPMR5\n" + + "q5EgQGkt4ifiaP6qvyFulwvVwx+m0bf1q6Vb/k3clIyLMcVZWFE1TqNH2Ife46AE\n" + + "2I39ZnGTt0mbWskpHBA=\n" + + "-----END ENCRYPTED PRIVATE KEY-----"; + + private static final Password KEY_PASSWORD = new Password("key-password"); + + private DefaultSslEngineFactory factory = new DefaultSslEngineFactory(); + Map configs = new HashMap<>(); + + @BeforeEach + public void setUp() { + factory = new DefaultSslEngineFactory(); + configs.put(SslConfigs.SSL_PROTOCOL_CONFIG, "TLSv1.2"); + } + + @Test + public void testPemTrustStoreConfigWithOneCert() throws Exception { + configs.put(SslConfigs.SSL_TRUSTSTORE_CERTIFICATES_CONFIG, pemAsConfigValue(CA1)); + configs.put(SslConfigs.SSL_TRUSTSTORE_TYPE_CONFIG, DefaultSslEngineFactory.PEM_TYPE); + factory.configure(configs); + + KeyStore trustStore = factory.truststore(); + List aliases = Collections.list(trustStore.aliases()); + assertEquals(Collections.singletonList("kafka0"), aliases); + assertNotNull(trustStore.getCertificate("kafka0"), "Certificate not loaded"); + assertNull(trustStore.getKey("kafka0", null), "Unexpected private key"); + } + + @Test + public void testPemTrustStoreConfigWithMultipleCerts() throws Exception { + configs.put(SslConfigs.SSL_TRUSTSTORE_CERTIFICATES_CONFIG, pemAsConfigValue(CA1, CA2)); + configs.put(SslConfigs.SSL_TRUSTSTORE_TYPE_CONFIG, DefaultSslEngineFactory.PEM_TYPE); + factory.configure(configs); + + KeyStore trustStore = factory.truststore(); + List aliases = Collections.list(trustStore.aliases()); + assertEquals(Arrays.asList("kafka0", "kafka1"), aliases); + assertNotNull(trustStore.getCertificate("kafka0"), "Certificate not loaded"); + assertNull(trustStore.getKey("kafka0", null), "Unexpected private key"); + assertNotNull(trustStore.getCertificate("kafka1"), "Certificate not loaded"); + assertNull(trustStore.getKey("kafka1", null), "Unexpected private key"); + } + + @Test + public void testPemKeyStoreConfigNoPassword() throws Exception { + verifyPemKeyStoreConfig(KEY, null); + } + + @Test + public void testPemKeyStoreConfigWithKeyPassword() throws Exception { + verifyPemKeyStoreConfig(ENCRYPTED_KEY, KEY_PASSWORD); + } + + @Test + public void testTrailingNewLines() throws Exception { + verifyPemKeyStoreConfig(ENCRYPTED_KEY + "\n\n", KEY_PASSWORD); + } + + @Test + public void testLeadingNewLines() throws Exception { + verifyPemKeyStoreConfig("\n\n" + ENCRYPTED_KEY, KEY_PASSWORD); + } + + @Test + public void testCarriageReturnLineFeed() throws Exception { + verifyPemKeyStoreConfig(ENCRYPTED_KEY.replaceAll("\n", "\r\n"), KEY_PASSWORD); + } + + private void verifyPemKeyStoreConfig(String keyFileName, Password keyPassword) throws Exception { + configs.put(SslConfigs.SSL_KEYSTORE_KEY_CONFIG, pemAsConfigValue(keyFileName)); + configs.put(SslConfigs.SSL_KEYSTORE_CERTIFICATE_CHAIN_CONFIG, pemAsConfigValue(CERTCHAIN)); + configs.put(SslConfigs.SSL_KEY_PASSWORD_CONFIG, keyPassword); + configs.put(SslConfigs.SSL_KEYSTORE_TYPE_CONFIG, DefaultSslEngineFactory.PEM_TYPE); + factory.configure(configs); + + KeyStore keyStore = factory.keystore(); + List aliases = Collections.list(keyStore.aliases()); + 
assertEquals(Collections.singletonList("kafka"), aliases); + assertNotNull(keyStore.getCertificate("kafka"), "Certificate not loaded"); + assertNotNull(keyStore.getKey("kafka", keyPassword == null ? null : keyPassword.value().toCharArray()), + "Private key not loaded"); + } + + @Test + public void testPemTrustStoreFile() throws Exception { + configs.put(SslConfigs.SSL_TRUSTSTORE_LOCATION_CONFIG, pemFilePath(CA1)); + configs.put(SslConfigs.SSL_TRUSTSTORE_TYPE_CONFIG, DefaultSslEngineFactory.PEM_TYPE); + factory.configure(configs); + + KeyStore trustStore = factory.truststore(); + List aliases = Collections.list(trustStore.aliases()); + assertEquals(Collections.singletonList("kafka0"), aliases); + assertNotNull(trustStore.getCertificate("kafka0"), "Certificate not found"); + assertNull(trustStore.getKey("kafka0", null), "Unexpected private key"); + } + + @Test + public void testPemKeyStoreFileNoKeyPassword() throws Exception { + configs.put(SslConfigs.SSL_KEYSTORE_LOCATION_CONFIG, + pemFilePath(pemAsConfigValue(KEY, CERTCHAIN).value())); + configs.put(SslConfigs.SSL_KEYSTORE_TYPE_CONFIG, DefaultSslEngineFactory.PEM_TYPE); + assertThrows(InvalidConfigurationException.class, () -> factory.configure(configs)); + } + + @Test + public void testPemKeyStoreFileWithKeyPassword() throws Exception { + configs.put(SslConfigs.SSL_KEYSTORE_LOCATION_CONFIG, + pemFilePath(pemAsConfigValue(ENCRYPTED_KEY, CERTCHAIN).value())); + configs.put(SslConfigs.SSL_KEY_PASSWORD_CONFIG, KEY_PASSWORD); + configs.put(SslConfigs.SSL_KEYSTORE_TYPE_CONFIG, DefaultSslEngineFactory.PEM_TYPE); + factory.configure(configs); + + KeyStore keyStore = factory.keystore(); + List aliases = Collections.list(keyStore.aliases()); + assertEquals(Collections.singletonList("kafka"), aliases); + assertNotNull(keyStore.getCertificate("kafka"), "Certificate not found"); + assertNotNull(keyStore.getKey("kafka", KEY_PASSWORD.value().toCharArray()), "Private key not found"); + } + + private String pemFilePath(String pem) throws Exception { + return TestUtils.tempFile(pem).getAbsolutePath(); + } + + private Password pemAsConfigValue(String... pemValues) throws Exception { + StringBuilder builder = new StringBuilder(); + for (String pem : pemValues) { + builder.append(pem); + builder.append("\n"); + } + return new Password(builder.toString().trim()); + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/security/ssl/SslFactoryTest.java b/clients/src/test/java/org/apache/kafka/common/security/ssl/SslFactoryTest.java new file mode 100644 index 0000000..cfb37b3 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/security/ssl/SslFactoryTest.java @@ -0,0 +1,546 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.security.ssl; + +import java.io.File; +import java.io.IOException; +import java.nio.file.Files; +import java.security.GeneralSecurityException; +import java.security.KeyStore; +import java.util.Arrays; +import java.util.Map; + +import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLEngine; + +import org.apache.kafka.common.KafkaException; +import org.apache.kafka.common.config.ConfigException; +import org.apache.kafka.common.config.SecurityConfig; +import org.apache.kafka.common.config.SslConfigs; +import org.apache.kafka.common.config.types.Password; +import org.apache.kafka.common.security.TestSecurityConfig; +import org.apache.kafka.common.security.auth.SslEngineFactory; +import org.apache.kafka.common.security.ssl.DefaultSslEngineFactory.FileBasedStore; +import org.apache.kafka.common.security.ssl.DefaultSslEngineFactory.PemStore; +import org.apache.kafka.common.security.ssl.DefaultSslEngineFactory.SecurityStore; +import org.apache.kafka.common.security.ssl.mock.TestKeyManagerFactory; +import org.apache.kafka.common.security.ssl.mock.TestProviderCreator; +import org.apache.kafka.common.security.ssl.mock.TestTrustManagerFactory; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.test.TestSslUtils; +import org.apache.kafka.common.network.Mode; +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNotSame; +import static org.junit.jupiter.api.Assertions.assertSame; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; +import java.security.Security; +import java.util.Properties; + +public abstract class SslFactoryTest { + private final String tlsProtocol; + + public SslFactoryTest(String tlsProtocol) { + this.tlsProtocol = tlsProtocol; + } + + @Test + public void testSslFactoryConfiguration() throws Exception { + File trustStoreFile = File.createTempFile("truststore", ".jks"); + Map serverSslConfig = sslConfigsBuilder(Mode.SERVER) + .createNewTrustStore(trustStoreFile) + .build(); + SslFactory sslFactory = new SslFactory(Mode.SERVER); + sslFactory.configure(serverSslConfig); + //host and port are hints + SSLEngine engine = sslFactory.createSslEngine("localhost", 0); + assertNotNull(engine); + assertEquals(Utils.mkSet(tlsProtocol), Utils.mkSet(engine.getEnabledProtocols())); + assertEquals(false, engine.getUseClientMode()); + } + + @Test + public void testSslFactoryWithCustomKeyManagerConfiguration() { + TestProviderCreator testProviderCreator = new TestProviderCreator(); + Map serverSslConfig = TestSslUtils.createSslConfig( + TestKeyManagerFactory.ALGORITHM, + TestTrustManagerFactory.ALGORITHM, + tlsProtocol + ); + serverSslConfig.put(SecurityConfig.SECURITY_PROVIDERS_CONFIG, testProviderCreator.getClass().getName()); + SslFactory sslFactory = new SslFactory(Mode.SERVER); + sslFactory.configure(serverSslConfig); + assertNotNull(sslFactory.sslEngineFactory(), "SslEngineFactory not created"); + Security.removeProvider(testProviderCreator.getProvider().getName()); + } + + @Test + public void testSslFactoryWithoutProviderClassConfiguration() { + // An exception is thrown as the algorithm is not registered through a provider + Map 
serverSslConfig = TestSslUtils.createSslConfig( + TestKeyManagerFactory.ALGORITHM, + TestTrustManagerFactory.ALGORITHM, + tlsProtocol + ); + SslFactory sslFactory = new SslFactory(Mode.SERVER); + assertThrows(KafkaException.class, () -> sslFactory.configure(serverSslConfig)); + } + + @Test + public void testSslFactoryWithIncorrectProviderClassConfiguration() { + // An exception is thrown as the algorithm is not registered through a provider + Map serverSslConfig = TestSslUtils.createSslConfig( + TestKeyManagerFactory.ALGORITHM, + TestTrustManagerFactory.ALGORITHM, + tlsProtocol + ); + serverSslConfig.put(SecurityConfig.SECURITY_PROVIDERS_CONFIG, + "com.fake.ProviderClass1,com.fake.ProviderClass2"); + SslFactory sslFactory = new SslFactory(Mode.SERVER); + assertThrows(KafkaException.class, () -> sslFactory.configure(serverSslConfig)); + } + + @Test + public void testSslFactoryWithoutPasswordConfiguration() throws Exception { + File trustStoreFile = File.createTempFile("truststore", ".jks"); + Map serverSslConfig = sslConfigsBuilder(Mode.SERVER) + .createNewTrustStore(trustStoreFile) + .build(); + // unset the password + serverSslConfig.remove(SslConfigs.SSL_TRUSTSTORE_PASSWORD_CONFIG); + SslFactory sslFactory = new SslFactory(Mode.SERVER); + try { + sslFactory.configure(serverSslConfig); + } catch (Exception e) { + fail("An exception was thrown when configuring the truststore without a password: " + e); + } + } + + @Test + public void testClientMode() throws Exception { + File trustStoreFile = File.createTempFile("truststore", ".jks"); + Map clientSslConfig = sslConfigsBuilder(Mode.CLIENT) + .createNewTrustStore(trustStoreFile) + .useClientCert(false) + .build(); + SslFactory sslFactory = new SslFactory(Mode.CLIENT); + sslFactory.configure(clientSslConfig); + //host and port are hints + SSLEngine engine = sslFactory.createSslEngine("localhost", 0); + assertTrue(engine.getUseClientMode()); + } + + @Test + public void staleSslEngineFactoryShouldBeClosed() throws IOException, GeneralSecurityException { + File trustStoreFile = File.createTempFile("truststore", ".jks"); + Map clientSslConfig = sslConfigsBuilder(Mode.SERVER) + .createNewTrustStore(trustStoreFile) + .useClientCert(false) + .build(); + clientSslConfig.put(SslConfigs.SSL_ENGINE_FACTORY_CLASS_CONFIG, TestSslUtils.TestSslEngineFactory.class); + SslFactory sslFactory = new SslFactory(Mode.SERVER); + sslFactory.configure(clientSslConfig); + TestSslUtils.TestSslEngineFactory sslEngineFactory = (TestSslUtils.TestSslEngineFactory) sslFactory.sslEngineFactory(); + assertNotNull(sslEngineFactory); + assertFalse(sslEngineFactory.closed); + + trustStoreFile = File.createTempFile("truststore", ".jks"); + clientSslConfig = sslConfigsBuilder(Mode.SERVER) + .createNewTrustStore(trustStoreFile) + .build(); + clientSslConfig.put(SslConfigs.SSL_ENGINE_FACTORY_CLASS_CONFIG, TestSslUtils.TestSslEngineFactory.class); + sslFactory.reconfigure(clientSslConfig); + TestSslUtils.TestSslEngineFactory newSslEngineFactory = (TestSslUtils.TestSslEngineFactory) sslFactory.sslEngineFactory(); + assertNotEquals(sslEngineFactory, newSslEngineFactory); + // the older one should be closed + assertTrue(sslEngineFactory.closed); + } + + @Test + public void testReconfiguration() throws Exception { + File trustStoreFile = File.createTempFile("truststore", ".jks"); + Map sslConfig = sslConfigsBuilder(Mode.SERVER) + .createNewTrustStore(trustStoreFile) + .build(); + SslFactory sslFactory = new SslFactory(Mode.SERVER); + sslFactory.configure(sslConfig); + SslEngineFactory 
sslEngineFactory = sslFactory.sslEngineFactory(); + assertNotNull(sslEngineFactory, "SslEngineFactory not created"); + + // Verify that SslEngineFactory is not recreated on reconfigure() if config and + // file are not changed + sslFactory.reconfigure(sslConfig); + assertSame(sslEngineFactory, sslFactory.sslEngineFactory(), "SslEngineFactory recreated unnecessarily"); + + // Verify that the SslEngineFactory is recreated on reconfigure() if config is changed + trustStoreFile = File.createTempFile("truststore", ".jks"); + sslConfig = sslConfigsBuilder(Mode.SERVER) + .createNewTrustStore(trustStoreFile) + .build(); + sslFactory.reconfigure(sslConfig); + assertNotSame(sslEngineFactory, sslFactory.sslEngineFactory(), "SslEngineFactory not recreated"); + sslEngineFactory = sslFactory.sslEngineFactory(); + + // Verify that builder is recreated on reconfigure() if config is not changed, but truststore file was modified + trustStoreFile.setLastModified(System.currentTimeMillis() + 10000); + sslFactory.reconfigure(sslConfig); + assertNotSame(sslEngineFactory, sslFactory.sslEngineFactory(), "SslEngineFactory not recreated"); + sslEngineFactory = sslFactory.sslEngineFactory(); + + // Verify that builder is recreated on reconfigure() if config is not changed, but keystore file was modified + File keyStoreFile = new File((String) sslConfig.get(SslConfigs.SSL_KEYSTORE_LOCATION_CONFIG)); + keyStoreFile.setLastModified(System.currentTimeMillis() + 10000); + sslFactory.reconfigure(sslConfig); + assertNotSame(sslEngineFactory, sslFactory.sslEngineFactory(), "SslEngineFactory not recreated"); + sslEngineFactory = sslFactory.sslEngineFactory(); + + // Verify that builder is recreated after validation on reconfigure() if config is not changed, but keystore file was modified + keyStoreFile.setLastModified(System.currentTimeMillis() + 15000); + sslFactory.validateReconfiguration(sslConfig); + sslFactory.reconfigure(sslConfig); + assertNotSame(sslEngineFactory, sslFactory.sslEngineFactory(), "SslEngineFactory not recreated"); + sslEngineFactory = sslFactory.sslEngineFactory(); + + // Verify that the builder is not recreated if modification time cannot be determined + keyStoreFile.setLastModified(System.currentTimeMillis() + 20000); + Files.delete(keyStoreFile.toPath()); + sslFactory.reconfigure(sslConfig); + assertSame(sslEngineFactory, sslFactory.sslEngineFactory(), "SslEngineFactory recreated unnecessarily"); + } + + @Test + public void testReconfigurationWithoutTruststore() throws Exception { + File trustStoreFile = File.createTempFile("truststore", ".jks"); + Map sslConfig = sslConfigsBuilder(Mode.SERVER) + .createNewTrustStore(trustStoreFile) + .build(); + sslConfig.remove(SslConfigs.SSL_TRUSTSTORE_LOCATION_CONFIG); + sslConfig.remove(SslConfigs.SSL_TRUSTSTORE_PASSWORD_CONFIG); + sslConfig.remove(SslConfigs.SSL_TRUSTSTORE_TYPE_CONFIG); + SslFactory sslFactory = new SslFactory(Mode.SERVER); + sslFactory.configure(sslConfig); + SSLContext sslContext = ((DefaultSslEngineFactory) sslFactory.sslEngineFactory()).sslContext(); + assertNotNull(sslContext, "SSL context not created"); + assertSame(sslContext, ((DefaultSslEngineFactory) sslFactory.sslEngineFactory()).sslContext(), + "SSL context recreated unnecessarily"); + assertFalse(sslFactory.createSslEngine("localhost", 0).getUseClientMode()); + + Map sslConfig2 = sslConfigsBuilder(Mode.SERVER) + .createNewTrustStore(trustStoreFile) + .build(); + try { + sslFactory.validateReconfiguration(sslConfig2); + fail("Truststore configured dynamically for listener without 
previous truststore"); + } catch (ConfigException e) { + // Expected exception + } + } + + @Test + public void testReconfigurationWithoutKeystore() throws Exception { + File trustStoreFile = File.createTempFile("truststore", ".jks"); + Map sslConfig = sslConfigsBuilder(Mode.SERVER) + .createNewTrustStore(trustStoreFile) + .build(); + sslConfig.remove(SslConfigs.SSL_KEYSTORE_LOCATION_CONFIG); + sslConfig.remove(SslConfigs.SSL_KEYSTORE_PASSWORD_CONFIG); + sslConfig.remove(SslConfigs.SSL_KEYSTORE_TYPE_CONFIG); + SslFactory sslFactory = new SslFactory(Mode.SERVER); + sslFactory.configure(sslConfig); + SSLContext sslContext = ((DefaultSslEngineFactory) sslFactory.sslEngineFactory()).sslContext(); + assertNotNull(sslContext, "SSL context not created"); + assertSame(sslContext, ((DefaultSslEngineFactory) sslFactory.sslEngineFactory()).sslContext(), + "SSL context recreated unnecessarily"); + assertFalse(sslFactory.createSslEngine("localhost", 0).getUseClientMode()); + + File newTrustStoreFile = File.createTempFile("truststore", ".jks"); + sslConfig = sslConfigsBuilder(Mode.SERVER) + .createNewTrustStore(newTrustStoreFile) + .build(); + sslConfig.remove(SslConfigs.SSL_KEYSTORE_LOCATION_CONFIG); + sslConfig.remove(SslConfigs.SSL_KEYSTORE_PASSWORD_CONFIG); + sslConfig.remove(SslConfigs.SSL_KEYSTORE_TYPE_CONFIG); + sslFactory.reconfigure(sslConfig); + assertNotSame(sslContext, ((DefaultSslEngineFactory) sslFactory.sslEngineFactory()).sslContext(), + "SSL context not recreated"); + + sslConfig = sslConfigsBuilder(Mode.SERVER) + .createNewTrustStore(newTrustStoreFile) + .build(); + try { + sslFactory.validateReconfiguration(sslConfig); + fail("Keystore configured dynamically for listener without previous keystore"); + } catch (ConfigException e) { + // Expected exception + } + } + + @Test + public void testPemReconfiguration() throws Exception { + Properties props = new Properties(); + props.putAll(sslConfigsBuilder(Mode.SERVER) + .createNewTrustStore(null) + .usePem(true) + .build()); + TestSecurityConfig sslConfig = new TestSecurityConfig(props); + + SslFactory sslFactory = new SslFactory(Mode.SERVER); + sslFactory.configure(sslConfig.values()); + SslEngineFactory sslEngineFactory = sslFactory.sslEngineFactory(); + assertNotNull(sslEngineFactory, "SslEngineFactory not created"); + + props.put("some.config", "some.value"); + sslConfig = new TestSecurityConfig(props); + sslFactory.reconfigure(sslConfig.values()); + assertSame(sslEngineFactory, sslFactory.sslEngineFactory(), "SslEngineFactory recreated unnecessarily"); + + props.put(SslConfigs.SSL_KEYSTORE_KEY_CONFIG, + new Password(((Password) props.get(SslConfigs.SSL_KEYSTORE_KEY_CONFIG)).value() + " ")); + sslConfig = new TestSecurityConfig(props); + sslFactory.reconfigure(sslConfig.values()); + assertNotSame(sslEngineFactory, sslFactory.sslEngineFactory(), "SslEngineFactory not recreated"); + sslEngineFactory = sslFactory.sslEngineFactory(); + + props.put(SslConfigs.SSL_KEYSTORE_CERTIFICATE_CHAIN_CONFIG, + new Password(((Password) props.get(SslConfigs.SSL_KEYSTORE_CERTIFICATE_CHAIN_CONFIG)).value() + " ")); + sslConfig = new TestSecurityConfig(props); + sslFactory.reconfigure(sslConfig.values()); + assertNotSame(sslEngineFactory, sslFactory.sslEngineFactory(), "SslEngineFactory not recreated"); + sslEngineFactory = sslFactory.sslEngineFactory(); + + props.put(SslConfigs.SSL_TRUSTSTORE_CERTIFICATES_CONFIG, + new Password(((Password) props.get(SslConfigs.SSL_TRUSTSTORE_CERTIFICATES_CONFIG)).value() + " ")); + sslConfig = new TestSecurityConfig(props); 
+ sslFactory.reconfigure(sslConfig.values()); + assertNotSame(sslEngineFactory, sslFactory.sslEngineFactory(), "SslEngineFactory not recreated"); + sslEngineFactory = sslFactory.sslEngineFactory(); + } + + @Test + public void testKeyStoreTrustStoreValidation() throws Exception { + File trustStoreFile = File.createTempFile("truststore", ".jks"); + Map serverSslConfig = sslConfigsBuilder(Mode.SERVER) + .createNewTrustStore(trustStoreFile) + .build(); + SslFactory sslFactory = new SslFactory(Mode.SERVER); + sslFactory.configure(serverSslConfig); + assertNotNull(sslFactory.sslEngineFactory(), "SslEngineFactory not created"); + } + + @Test + public void testUntrustedKeyStoreValidationFails() throws Exception { + File trustStoreFile1 = File.createTempFile("truststore1", ".jks"); + File trustStoreFile2 = File.createTempFile("truststore2", ".jks"); + Map sslConfig1 = sslConfigsBuilder(Mode.SERVER) + .createNewTrustStore(trustStoreFile1) + .build(); + Map sslConfig2 = sslConfigsBuilder(Mode.SERVER) + .createNewTrustStore(trustStoreFile2) + .build(); + SslFactory sslFactory = new SslFactory(Mode.SERVER, null, true); + for (String key : Arrays.asList(SslConfigs.SSL_TRUSTSTORE_LOCATION_CONFIG, + SslConfigs.SSL_TRUSTSTORE_PASSWORD_CONFIG, + SslConfigs.SSL_TRUSTSTORE_TYPE_CONFIG, + SslConfigs.SSL_TRUSTMANAGER_ALGORITHM_CONFIG)) { + sslConfig1.put(key, sslConfig2.get(key)); + } + try { + sslFactory.configure(sslConfig1); + fail("Validation did not fail with untrusted truststore"); + } catch (ConfigException e) { + // Expected exception + } + } + + @Test + public void testKeystoreVerifiableUsingTruststore() throws Exception { + verifyKeystoreVerifiableUsingTruststore(false, tlsProtocol); + } + + @Test + public void testPemKeystoreVerifiableUsingTruststore() throws Exception { + verifyKeystoreVerifiableUsingTruststore(true, tlsProtocol); + } + + private void verifyKeystoreVerifiableUsingTruststore(boolean usePem, String tlsProtocol) throws Exception { + File trustStoreFile1 = usePem ? null : File.createTempFile("truststore1", ".jks"); + Map sslConfig1 = sslConfigsBuilder(Mode.SERVER) + .createNewTrustStore(trustStoreFile1) + .usePem(usePem) + .build(); + SslFactory sslFactory = new SslFactory(Mode.SERVER, null, true); + sslFactory.configure(sslConfig1); + + File trustStoreFile2 = usePem ? null : File.createTempFile("truststore2", ".jks"); + Map sslConfig2 = sslConfigsBuilder(Mode.SERVER) + .createNewTrustStore(trustStoreFile2) + .usePem(usePem) + .build(); + // Verify that `createSSLContext` fails even if certificate from new keystore is trusted by + // the new truststore, if certificate is not trusted by the existing truststore on the `SslFactory`. + // This is to prevent both keystores and truststores to be modified simultaneously on an inter-broker + // listener to stores that may not work with other brokers where the update hasn't yet been performed. + try { + sslFactory.validateReconfiguration(sslConfig2); + fail("ValidateReconfiguration did not fail as expected"); + } catch (ConfigException e) { + // Expected exception + } + } + + @Test + public void testCertificateEntriesValidation() throws Exception { + verifyCertificateEntriesValidation(false, tlsProtocol); + } + + @Test + public void testPemCertificateEntriesValidation() throws Exception { + verifyCertificateEntriesValidation(true, tlsProtocol); + } + + private void verifyCertificateEntriesValidation(boolean usePem, String tlsProtocol) throws Exception { + File trustStoreFile = usePem ? 
null : File.createTempFile("truststore", ".jks"); + Map serverSslConfig = sslConfigsBuilder(Mode.SERVER) + .createNewTrustStore(trustStoreFile) + .usePem(usePem) + .build(); + File newTrustStoreFile = usePem ? null : File.createTempFile("truststore", ".jks"); + Map newCnConfig = sslConfigsBuilder(Mode.SERVER) + .createNewTrustStore(newTrustStoreFile) + .cn("Another CN") + .usePem(usePem) + .build(); + KeyStore ks1 = sslKeyStore(serverSslConfig); + KeyStore ks2 = sslKeyStore(serverSslConfig); + assertEquals(SslFactory.CertificateEntries.create(ks1), SslFactory.CertificateEntries.create(ks2)); + + // Use different alias name, validation should succeed + ks2.setCertificateEntry("another", ks1.getCertificate("localhost")); + assertEquals(SslFactory.CertificateEntries.create(ks1), SslFactory.CertificateEntries.create(ks2)); + + KeyStore ks3 = sslKeyStore(newCnConfig); + assertNotEquals(SslFactory.CertificateEntries.create(ks1), SslFactory.CertificateEntries.create(ks3)); + } + + /** + * Tests client side ssl.engine.factory configuration is used when specified + */ + @Test + public void testClientSpecifiedSslEngineFactoryUsed() throws Exception { + File trustStoreFile = File.createTempFile("truststore", ".jks"); + Map clientSslConfig = sslConfigsBuilder(Mode.CLIENT) + .createNewTrustStore(trustStoreFile) + .useClientCert(false) + .build(); + clientSslConfig.put(SslConfigs.SSL_ENGINE_FACTORY_CLASS_CONFIG, TestSslUtils.TestSslEngineFactory.class); + SslFactory sslFactory = new SslFactory(Mode.CLIENT); + sslFactory.configure(clientSslConfig); + assertTrue(sslFactory.sslEngineFactory() instanceof TestSslUtils.TestSslEngineFactory, + "SslEngineFactory must be of expected type"); + } + + @Test + public void testEngineFactoryClosed() throws Exception { + File trustStoreFile = File.createTempFile("truststore", ".jks"); + Map clientSslConfig = sslConfigsBuilder(Mode.CLIENT) + .createNewTrustStore(trustStoreFile) + .useClientCert(false) + .build(); + clientSslConfig.put(SslConfigs.SSL_ENGINE_FACTORY_CLASS_CONFIG, TestSslUtils.TestSslEngineFactory.class); + SslFactory sslFactory = new SslFactory(Mode.CLIENT); + sslFactory.configure(clientSslConfig); + TestSslUtils.TestSslEngineFactory engine = (TestSslUtils.TestSslEngineFactory) sslFactory.sslEngineFactory(); + assertFalse(engine.closed); + sslFactory.close(); + assertTrue(engine.closed); + } + + /** + * Tests server side ssl.engine.factory configuration is used when specified + */ + @Test + public void testServerSpecifiedSslEngineFactoryUsed() throws Exception { + File trustStoreFile = File.createTempFile("truststore", ".jks"); + Map serverSslConfig = sslConfigsBuilder(Mode.SERVER) + .createNewTrustStore(trustStoreFile) + .useClientCert(false) + .build(); + serverSslConfig.put(SslConfigs.SSL_ENGINE_FACTORY_CLASS_CONFIG, TestSslUtils.TestSslEngineFactory.class); + SslFactory sslFactory = new SslFactory(Mode.SERVER); + sslFactory.configure(serverSslConfig); + assertTrue(sslFactory.sslEngineFactory() instanceof TestSslUtils.TestSslEngineFactory, + "SslEngineFactory must be of expected type"); + } + + /** + * Tests invalid ssl.engine.factory configuration + */ + @Test + public void testInvalidSslEngineFactory() throws Exception { + File trustStoreFile = File.createTempFile("truststore", ".jks"); + Map clientSslConfig = sslConfigsBuilder(Mode.CLIENT) + .createNewTrustStore(trustStoreFile) + .useClientCert(false) + .build(); + clientSslConfig.put(SslConfigs.SSL_ENGINE_FACTORY_CLASS_CONFIG, String.class); + SslFactory sslFactory = new SslFactory(Mode.CLIENT); + 
assertThrows(ClassCastException.class, () -> sslFactory.configure(clientSslConfig)); + } + + @Test + public void testUsedConfigs() throws IOException, GeneralSecurityException { + Map serverSslConfig = sslConfigsBuilder(Mode.SERVER) + .createNewTrustStore(File.createTempFile("truststore", ".jks")) + .useClientCert(false) + .build(); + serverSslConfig.put(SslConfigs.SSL_ENGINE_FACTORY_CLASS_CONFIG, TestSslUtils.TestSslEngineFactory.class); + TestSecurityConfig securityConfig = new TestSecurityConfig(serverSslConfig); + SslFactory sslFactory = new SslFactory(Mode.SERVER); + sslFactory.configure(securityConfig.values()); + assertFalse(securityConfig.unused().contains(SslConfigs.SSL_ENGINE_FACTORY_CLASS_CONFIG)); + } + + private KeyStore sslKeyStore(Map sslConfig) { + SecurityStore store; + if (sslConfig.get(SslConfigs.SSL_KEYSTORE_LOCATION_CONFIG) != null) { + store = new FileBasedStore( + (String) sslConfig.get(SslConfigs.SSL_KEYSTORE_TYPE_CONFIG), + (String) sslConfig.get(SslConfigs.SSL_KEYSTORE_LOCATION_CONFIG), + (Password) sslConfig.get(SslConfigs.SSL_KEYSTORE_PASSWORD_CONFIG), + (Password) sslConfig.get(SslConfigs.SSL_KEY_PASSWORD_CONFIG), + true + ); + } else { + store = new PemStore( + (Password) sslConfig.get(SslConfigs.SSL_KEYSTORE_CERTIFICATE_CHAIN_CONFIG), + (Password) sslConfig.get(SslConfigs.SSL_KEYSTORE_KEY_CONFIG), + (Password) sslConfig.get(SslConfigs.SSL_KEY_PASSWORD_CONFIG) + ); + } + return store.get(); + } + + private TestSslUtils.SslConfigsBuilder sslConfigsBuilder(Mode mode) { + return new TestSslUtils.SslConfigsBuilder(mode).tlsProtocol(tlsProtocol); + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/security/ssl/SslPrincipalMapperTest.java b/clients/src/test/java/org/apache/kafka/common/security/ssl/SslPrincipalMapperTest.java new file mode 100644 index 0000000..ff5a018 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/security/ssl/SslPrincipalMapperTest.java @@ -0,0 +1,128 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.security.ssl; + +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.fail; + +public class SslPrincipalMapperTest { + + @Test + public void testValidRules() { + testValidRule("DEFAULT"); + testValidRule("RULE:^CN=(.*?),OU=ServiceUsers.*$/$1/"); + testValidRule("RULE:^CN=(.*?),OU=ServiceUsers.*$/$1/L, DEFAULT"); + testValidRule("RULE:^CN=(.*?),OU=(.*?),O=(.*?),L=(.*?),ST=(.*?),C=(.*?)$/$1@$2/"); + testValidRule("RULE:^.*[Cc][Nn]=([a-zA-Z0-9.]*).*$/$1/L"); + testValidRule("RULE:^cn=(.?),ou=(.?),dc=(.?),dc=(.?)$/$1@$2/U"); + + testValidRule("RULE:^CN=([^,ADEFLTU,]+)(,.*|$)/$1/"); + testValidRule("RULE:^CN=([^,DEFAULT,]+)(,.*|$)/$1/"); + } + + private void testValidRule(String rules) { + SslPrincipalMapper.fromRules(rules); + } + + @Test + public void testInvalidRules() { + testInvalidRule("default"); + testInvalidRule("DEFAUL"); + testInvalidRule("DEFAULT/L"); + testInvalidRule("DEFAULT/U"); + + testInvalidRule("RULE:CN=(.*?),OU=ServiceUsers.*/$1"); + testInvalidRule("rule:^CN=(.*?),OU=ServiceUsers.*$/$1/"); + testInvalidRule("RULE:^CN=(.*?),OU=ServiceUsers.*$/$1/L/U"); + testInvalidRule("RULE:^CN=(.*?),OU=ServiceUsers.*$/L"); + testInvalidRule("RULE:^CN=(.*?),OU=ServiceUsers.*$/U"); + testInvalidRule("RULE:^CN=(.*?),OU=ServiceUsers.*$/LU"); + } + + private void testInvalidRule(String rules) { + try { + System.out.println(SslPrincipalMapper.fromRules(rules)); + fail("should have thrown IllegalArgumentException"); + } catch (IllegalArgumentException e) { + } + } + + @Test + public void testSslPrincipalMapper() throws Exception { + String rules = String.join(", ", + "RULE:^CN=(.*?),OU=ServiceUsers.*$/$1/L", + "RULE:^CN=(.*?),OU=(.*?),O=(.*?),L=(.*?),ST=(.*?),C=(.*?)$/$1@$2/L", + "RULE:^cn=(.*?),ou=(.*?),dc=(.*?),dc=(.*?)$/$1@$2/U", + "RULE:^.*[Cc][Nn]=([a-zA-Z0-9.]*).*$/$1/U", + "DEFAULT" + ); + + SslPrincipalMapper mapper = SslPrincipalMapper.fromRules(rules); + + assertEquals("duke", mapper.getName("CN=Duke,OU=ServiceUsers,O=Org,C=US")); + assertEquals("duke@sme", mapper.getName("CN=Duke,OU=SME,O=mycp,L=Fulton,ST=MD,C=US")); + assertEquals("DUKE@SME", mapper.getName("cn=duke,ou=sme,dc=mycp,dc=com")); + assertEquals("DUKE", mapper.getName("cN=duke,OU=JavaSoft,O=Sun Microsystems")); + assertEquals("OU=JavaSoft,O=Sun Microsystems,C=US", mapper.getName("OU=JavaSoft,O=Sun Microsystems,C=US")); + } + + private void testRulesSplitting(String expected, String rules) { + SslPrincipalMapper mapper = SslPrincipalMapper.fromRules(rules); + assertEquals(String.format("SslPrincipalMapper(rules = %s)", expected), mapper.toString()); + } + + @Test + public void testRulesSplitting() { + // seeing is believing + testRulesSplitting("[]", ""); + testRulesSplitting("[DEFAULT]", "DEFAULT"); + testRulesSplitting("[RULE:/]", "RULE://"); + testRulesSplitting("[RULE:/.*]", "RULE:/.*/"); + testRulesSplitting("[RULE:/.*/L]", "RULE:/.*/L"); + testRulesSplitting("[RULE:/, DEFAULT]", "RULE://,DEFAULT"); + testRulesSplitting("[RULE:/, DEFAULT]", " RULE:// , DEFAULT "); + testRulesSplitting("[RULE: / , DEFAULT]", " RULE: / / , DEFAULT "); + testRulesSplitting("[RULE: / /U, DEFAULT]", " RULE: / /U ,DEFAULT "); + testRulesSplitting("[RULE:([A-Z]*)/$1/U, RULE:([a-z]+)/$1, DEFAULT]", " RULE:([A-Z]*)/$1/U ,RULE:([a-z]+)/$1/, DEFAULT "); + + // empty rules are ignored + testRulesSplitting("[]", ", , , , , , , "); + testRulesSplitting("[RULE:/, DEFAULT]", ",,RULE://,,,DEFAULT,,"); + 
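+        // whitespace-only fragments between the commas are dropped as well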
testRulesSplitting("[RULE: / , DEFAULT]", ", , RULE: / / ,,, DEFAULT, , "); + testRulesSplitting("[RULE: / /U, DEFAULT]", " , , RULE: / /U ,, ,DEFAULT, ,"); + + // escape sequences + testRulesSplitting("[RULE:\\/\\\\\\(\\)\\n\\t/\\/\\/]", "RULE:\\/\\\\\\(\\)\\n\\t/\\/\\//"); + testRulesSplitting("[RULE:\\**\\/+/*/L, RULE:\\/*\\**/**]", "RULE:\\**\\/+/*/L,RULE:\\/*\\**/**/"); + + // rules rule + testRulesSplitting( + "[RULE:,RULE:,/,RULE:,\\//U, RULE:,/RULE:,, RULE:,RULE:,/L,RULE:,/L, RULE:, DEFAULT, /DEFAULT, DEFAULT]", + "RULE:,RULE:,/,RULE:,\\//U,RULE:,/RULE:,/,RULE:,RULE:,/L,RULE:,/L,RULE:, DEFAULT, /DEFAULT/,DEFAULT" + ); + } + + @Test + public void testCommaWithWhitespace() throws Exception { + String rules = "RULE:^CN=((\\\\, *|\\w)+)(,.*|$)/$1/,DEFAULT"; + + SslPrincipalMapper mapper = SslPrincipalMapper.fromRules(rules); + assertEquals("Tkac\\, Adam", mapper.getName("CN=Tkac\\, Adam,OU=ITZ,DC=geodis,DC=cz")); + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/security/ssl/Tls12SslFactoryTest.java b/clients/src/test/java/org/apache/kafka/common/security/ssl/Tls12SslFactoryTest.java new file mode 100644 index 0000000..67b7202 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/security/ssl/Tls12SslFactoryTest.java @@ -0,0 +1,23 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.security.ssl; + +public class Tls12SslFactoryTest extends SslFactoryTest { + public Tls12SslFactoryTest() { + super("TLSv1.2"); + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/security/ssl/Tls13SslFactoryTest.java b/clients/src/test/java/org/apache/kafka/common/security/ssl/Tls13SslFactoryTest.java new file mode 100644 index 0000000..3d7835f --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/security/ssl/Tls13SslFactoryTest.java @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.security.ssl; + +import org.junit.jupiter.api.condition.DisabledOnJre; +import org.junit.jupiter.api.condition.JRE; + +@DisabledOnJre(JRE.JAVA_8) +public class Tls13SslFactoryTest extends SslFactoryTest { + public Tls13SslFactoryTest() { + super("TLSv1.3"); + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/security/ssl/mock/TestKeyManagerFactory.java b/clients/src/test/java/org/apache/kafka/common/security/ssl/mock/TestKeyManagerFactory.java new file mode 100644 index 0000000..dc686c2 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/security/ssl/mock/TestKeyManagerFactory.java @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.security.ssl.mock; + +import javax.net.ssl.KeyManager; +import javax.net.ssl.KeyManagerFactorySpi; +import javax.net.ssl.ManagerFactoryParameters; +import javax.net.ssl.X509ExtendedKeyManager; +import java.io.File; +import java.io.IOException; +import java.net.Socket; +import java.security.GeneralSecurityException; +import java.security.KeyPair; +import java.security.KeyStore; +import java.security.Principal; +import java.security.PrivateKey; +import java.security.cert.X509Certificate; +import java.util.HashMap; +import java.util.Map; + +import org.apache.kafka.common.config.types.Password; +import org.apache.kafka.test.TestSslUtils; +import org.apache.kafka.test.TestSslUtils.CertificateBuilder; + +public class TestKeyManagerFactory extends KeyManagerFactorySpi { + public static final String ALGORITHM = "TestAlgorithm"; + + @Override + protected void engineInit(KeyStore keyStore, char[] chars) { + + } + + @Override + protected void engineInit(ManagerFactoryParameters managerFactoryParameters) { + + } + + @Override + protected KeyManager[] engineGetKeyManagers() { + return new KeyManager[] {new TestKeyManager()}; + } + + public static class TestKeyManager extends X509ExtendedKeyManager { + + public static String mockTrustStoreFile; + public static final String ALIAS = "TestAlias"; + private static final String CN = "localhost"; + private static final String SIGNATURE_ALGORITHM = "RSA"; + private KeyPair keyPair; + private X509Certificate certificate; + + protected TestKeyManager() { + try { + this.keyPair = TestSslUtils.generateKeyPair(SIGNATURE_ALGORITHM); + CertificateBuilder certBuilder = new CertificateBuilder(); + this.certificate = certBuilder.generate("CN=" + CN + ", O=A server", this.keyPair); + Map certificates = new HashMap<>(); + certificates.put(ALIAS, certificate); + File trustStoreFile = File.createTempFile("testTrustStore", ".jks"); + mockTrustStoreFile = trustStoreFile.getPath(); + TestSslUtils.createTrustStore(mockTrustStoreFile, new Password(TestSslUtils.TRUST_STORE_PASSWORD), 
certificates); + } catch (IOException | GeneralSecurityException e) { + throw new RuntimeException(e); + } + } + + @Override + public String[] getClientAliases(String s, Principal[] principals) { + return new String[] {ALIAS}; + } + + @Override + public String chooseClientAlias(String[] strings, Principal[] principals, Socket socket) { + return ALIAS; + } + + @Override + public String[] getServerAliases(String s, Principal[] principals) { + return new String[] {ALIAS}; + } + + @Override + public String chooseServerAlias(String s, Principal[] principals, Socket socket) { + return ALIAS; + } + + @Override + public X509Certificate[] getCertificateChain(String s) { + return new X509Certificate[] {this.certificate}; + } + + @Override + public PrivateKey getPrivateKey(String s) { + return this.keyPair.getPrivate(); + } + } + +} + diff --git a/clients/src/test/java/org/apache/kafka/common/security/ssl/mock/TestPlainSaslServerProvider.java b/clients/src/test/java/org/apache/kafka/common/security/ssl/mock/TestPlainSaslServerProvider.java new file mode 100644 index 0000000..5e6e82e --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/security/ssl/mock/TestPlainSaslServerProvider.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.security.ssl.mock; + +import java.security.Provider; + +public class TestPlainSaslServerProvider extends Provider { + + public TestPlainSaslServerProvider() { + this("TestPlainSaslServerProvider", 0.1, "test plain sasl server provider"); + } + + protected TestPlainSaslServerProvider(String name, double version, String info) { + super(name, version, info); + } + +} diff --git a/clients/src/test/java/org/apache/kafka/common/security/ssl/mock/TestPlainSaslServerProviderCreator.java b/clients/src/test/java/org/apache/kafka/common/security/ssl/mock/TestPlainSaslServerProviderCreator.java new file mode 100644 index 0000000..0ab927e --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/security/ssl/mock/TestPlainSaslServerProviderCreator.java @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.security.ssl.mock; + +import org.apache.kafka.common.security.auth.SecurityProviderCreator; + +import java.security.Provider; + +public class TestPlainSaslServerProviderCreator implements SecurityProviderCreator { + + private TestPlainSaslServerProvider provider; + + @Override + public Provider getProvider() { + if (provider == null) { + provider = new TestPlainSaslServerProvider(); + } + return provider; + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/security/ssl/mock/TestProvider.java b/clients/src/test/java/org/apache/kafka/common/security/ssl/mock/TestProvider.java new file mode 100644 index 0000000..fb44d3c --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/security/ssl/mock/TestProvider.java @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.security.ssl.mock; + +import java.security.Provider; + +public class TestProvider extends Provider { + + private static final String KEY_MANAGER_FACTORY = String.format("KeyManagerFactory.%s", TestKeyManagerFactory.ALGORITHM); + private static final String TRUST_MANAGER_FACTORY = String.format("TrustManagerFactory.%s", TestTrustManagerFactory.ALGORITHM); + + public TestProvider() { + this("TestProvider", 0.1, "provider for test cases"); + } + + protected TestProvider(String name, double version, String info) { + super(name, version, info); + super.put(KEY_MANAGER_FACTORY, TestKeyManagerFactory.class.getName()); + super.put(TRUST_MANAGER_FACTORY, TestTrustManagerFactory.class.getName()); + } + +} diff --git a/clients/src/test/java/org/apache/kafka/common/security/ssl/mock/TestProviderCreator.java b/clients/src/test/java/org/apache/kafka/common/security/ssl/mock/TestProviderCreator.java new file mode 100644 index 0000000..57c455a --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/security/ssl/mock/TestProviderCreator.java @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.security.ssl.mock; + +import org.apache.kafka.common.security.auth.SecurityProviderCreator; + +import java.security.Provider; + +public class TestProviderCreator implements SecurityProviderCreator { + + private TestProvider provider; + + @Override + public Provider getProvider() { + if (provider == null) { + provider = new TestProvider(); + } + return provider; + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/security/ssl/mock/TestScramSaslServerProvider.java b/clients/src/test/java/org/apache/kafka/common/security/ssl/mock/TestScramSaslServerProvider.java new file mode 100644 index 0000000..c5e8310 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/security/ssl/mock/TestScramSaslServerProvider.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.security.ssl.mock; + +import java.security.Provider; + +public class TestScramSaslServerProvider extends Provider { + + public TestScramSaslServerProvider() { + this("TestScramSaslServerProvider", 0.1, "test scram sasl server provider"); + } + + protected TestScramSaslServerProvider(String name, double version, String info) { + super(name, version, info); + } + +} diff --git a/clients/src/test/java/org/apache/kafka/common/security/ssl/mock/TestScramSaslServerProviderCreator.java b/clients/src/test/java/org/apache/kafka/common/security/ssl/mock/TestScramSaslServerProviderCreator.java new file mode 100644 index 0000000..72eb880 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/security/ssl/mock/TestScramSaslServerProviderCreator.java @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.security.ssl.mock; + +import org.apache.kafka.common.security.auth.SecurityProviderCreator; + +import java.security.Provider; + +public class TestScramSaslServerProviderCreator implements SecurityProviderCreator { + + private TestScramSaslServerProvider provider; + + @Override + public Provider getProvider() { + if (provider == null) { + provider = new TestScramSaslServerProvider(); + } + return provider; + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/security/ssl/mock/TestTrustManagerFactory.java b/clients/src/test/java/org/apache/kafka/common/security/ssl/mock/TestTrustManagerFactory.java new file mode 100644 index 0000000..4115a5f --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/security/ssl/mock/TestTrustManagerFactory.java @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.security.ssl.mock; + +import javax.net.ssl.ManagerFactoryParameters; +import javax.net.ssl.SSLEngine; +import javax.net.ssl.TrustManager; +import javax.net.ssl.TrustManagerFactorySpi; +import javax.net.ssl.X509ExtendedTrustManager; +import java.net.Socket; +import java.security.KeyStore; +import java.security.cert.CertificateException; +import java.security.cert.X509Certificate; + +public class TestTrustManagerFactory extends TrustManagerFactorySpi { + public static final String ALGORITHM = "TestAlgorithm"; + + @Override + protected void engineInit(KeyStore keyStore) { + + } + + @Override + protected void engineInit(ManagerFactoryParameters managerFactoryParameters) { + + } + + @Override + protected TrustManager[] engineGetTrustManagers() { + return new TrustManager[] {new TestTrustManager()}; + } + + public static class TestTrustManager extends X509ExtendedTrustManager { + + public static final String ALIAS = "TestAlias"; + + @Override + public void checkClientTrusted(X509Certificate[] x509Certificates, String s) throws CertificateException { + + } + + @Override + public void checkServerTrusted(X509Certificate[] x509Certificates, String s) throws CertificateException { + + } + + @Override + public X509Certificate[] getAcceptedIssuers() { + return new X509Certificate[0]; + } + + @Override + public void checkClientTrusted(X509Certificate[] x509Certificates, String s, Socket socket) throws CertificateException { + + } + + @Override + public void checkServerTrusted(X509Certificate[] x509Certificates, String s, Socket socket) throws CertificateException { + + } + + @Override + public void checkClientTrusted(X509Certificate[] x509Certificates, String s, SSLEngine sslEngine) throws CertificateException { + + } + + @Override + public void checkServerTrusted(X509Certificate[] x509Certificates, String s, SSLEngine sslEngine) throws CertificateException { + 
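+            // intentionally left empty: this mock trust manager accepts every certificate chain without validation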
+ } + } + +} + diff --git a/clients/src/test/java/org/apache/kafka/common/serialization/ListDeserializerTest.java b/clients/src/test/java/org/apache/kafka/common/serialization/ListDeserializerTest.java new file mode 100644 index 0000000..aff01e3 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/serialization/ListDeserializerTest.java @@ -0,0 +1,251 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.serialization; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import org.apache.kafka.clients.CommonClientConfigs; +import org.apache.kafka.common.KafkaException; +import org.apache.kafka.common.config.ConfigException; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.Map; +import org.junit.jupiter.api.Test; + +@SuppressWarnings("unchecked") +public class ListDeserializerTest { + private final ListDeserializer listDeserializer = new ListDeserializer<>(); + private final Map props = new HashMap<>(); + private final String nonExistingClass = "non.existing.class"; + private static class FakeObject { + } + + @Test + public void testListKeyDeserializerNoArgConstructorsWithClassNames() { + props.put(CommonClientConfigs.DEFAULT_LIST_KEY_SERDE_TYPE_CLASS, ArrayList.class.getName()); + props.put(CommonClientConfigs.DEFAULT_LIST_KEY_SERDE_INNER_CLASS, Serdes.StringSerde.class.getName()); + listDeserializer.configure(props, true); + final Deserializer inner = listDeserializer.innerDeserializer(); + assertNotNull(inner, "Inner deserializer should be not null"); + assertTrue(inner instanceof StringDeserializer, "Inner deserializer type should be StringDeserializer"); + } + + @Test + public void testListValueDeserializerNoArgConstructorsWithClassNames() { + props.put(CommonClientConfigs.DEFAULT_LIST_VALUE_SERDE_TYPE_CLASS, ArrayList.class.getName()); + props.put(CommonClientConfigs.DEFAULT_LIST_VALUE_SERDE_INNER_CLASS, Serdes.IntegerSerde.class.getName()); + listDeserializer.configure(props, false); + final Deserializer inner = listDeserializer.innerDeserializer(); + assertNotNull(inner, "Inner deserializer should be not null"); + assertTrue(inner instanceof IntegerDeserializer, "Inner deserializer type should be IntegerDeserializer"); + } + + @Test + public void testListKeyDeserializerNoArgConstructorsWithClassObjects() { + props.put(CommonClientConfigs.DEFAULT_LIST_KEY_SERDE_TYPE_CLASS, ArrayList.class); + props.put(CommonClientConfigs.DEFAULT_LIST_KEY_SERDE_INNER_CLASS, Serdes.StringSerde.class); + listDeserializer.configure(props, true); + final Deserializer inner = 
listDeserializer.innerDeserializer(); + assertNotNull(inner, "Inner deserializer should be not null"); + assertTrue(inner instanceof StringDeserializer, "Inner deserializer type should be StringDeserializer"); + } + + @Test + public void testListValueDeserializerNoArgConstructorsWithClassObjects() { + props.put(CommonClientConfigs.DEFAULT_LIST_VALUE_SERDE_TYPE_CLASS, ArrayList.class); + props.put(CommonClientConfigs.DEFAULT_LIST_VALUE_SERDE_INNER_CLASS, Serdes.StringSerde.class); + listDeserializer.configure(props, false); + final Deserializer inner = listDeserializer.innerDeserializer(); + assertNotNull(inner, "Inner deserializer should be not null"); + assertTrue(inner instanceof StringDeserializer, "Inner deserializer type should be StringDeserializer"); + } + + @Test + public void testListKeyDeserializerNoArgConstructorsShouldThrowConfigExceptionDueMissingInnerClassProp() { + props.put(CommonClientConfigs.DEFAULT_LIST_KEY_SERDE_TYPE_CLASS, ArrayList.class); + final ConfigException exception = assertThrows( + ConfigException.class, + () -> listDeserializer.configure(props, true) + ); + assertEquals("Not able to determine the inner serde class because " + + "it was neither passed via the constructor nor set in the config.", exception.getMessage()); + } + + @Test + public void testListValueDeserializerNoArgConstructorsShouldThrowConfigExceptionDueMissingInnerClassProp() { + props.put(CommonClientConfigs.DEFAULT_LIST_VALUE_SERDE_TYPE_CLASS, ArrayList.class); + final ConfigException exception = assertThrows( + ConfigException.class, + () -> listDeserializer.configure(props, false) + ); + assertEquals("Not able to determine the inner serde class because " + + "it was neither passed via the constructor nor set in the config.", exception.getMessage()); + } + + @Test + public void testListKeyDeserializerNoArgConstructorsShouldThrowConfigExceptionDueMissingTypeClassProp() { + props.put(CommonClientConfigs.DEFAULT_LIST_KEY_SERDE_INNER_CLASS, Serdes.StringSerde.class); + final ConfigException exception = assertThrows( + ConfigException.class, + () -> listDeserializer.configure(props, true) + ); + assertEquals("Not able to determine the list class because " + + "it was neither passed via the constructor nor set in the config.", exception.getMessage()); + } + + @Test + public void testListValueDeserializerNoArgConstructorsShouldThrowConfigExceptionDueMissingTypeClassProp() { + props.put(CommonClientConfigs.DEFAULT_LIST_VALUE_SERDE_INNER_CLASS, Serdes.StringSerde.class); + final ConfigException exception = assertThrows( + ConfigException.class, + () -> listDeserializer.configure(props, false) + ); + assertEquals("Not able to determine the list class because " + + "it was neither passed via the constructor nor set in the config.", exception.getMessage()); + } + + @Test + public void testListKeyDeserializerNoArgConstructorsShouldThrowKafkaExceptionDueInvalidTypeClass() { + props.put(CommonClientConfigs.DEFAULT_LIST_KEY_SERDE_TYPE_CLASS, new FakeObject()); + props.put(CommonClientConfigs.DEFAULT_LIST_KEY_SERDE_INNER_CLASS, Serdes.StringSerde.class); + final KafkaException exception = assertThrows( + KafkaException.class, + () -> listDeserializer.configure(props, true) + ); + assertEquals("Could not determine the list class instance using " + + "\"" + CommonClientConfigs.DEFAULT_LIST_KEY_SERDE_TYPE_CLASS + "\" property.", exception.getMessage()); + } + + @Test + public void testListValueDeserializerNoArgConstructorsShouldThrowKafkaExceptionDueInvalidTypeClass() { + 
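+        // a config value that is neither a Class nor a class-name String cannot be resolved, so a KafkaException is expected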
props.put(CommonClientConfigs.DEFAULT_LIST_VALUE_SERDE_TYPE_CLASS, new FakeObject()); + props.put(CommonClientConfigs.DEFAULT_LIST_VALUE_SERDE_INNER_CLASS, Serdes.StringSerde.class); + final KafkaException exception = assertThrows( + KafkaException.class, + () -> listDeserializer.configure(props, false) + ); + assertEquals("Could not determine the list class instance using " + + "\"" + CommonClientConfigs.DEFAULT_LIST_VALUE_SERDE_TYPE_CLASS + "\" property.", exception.getMessage()); + } + + @Test + public void testListKeyDeserializerNoArgConstructorsShouldThrowKafkaExceptionDueInvalidInnerClass() { + props.put(CommonClientConfigs.DEFAULT_LIST_KEY_SERDE_TYPE_CLASS, ArrayList.class); + props.put(CommonClientConfigs.DEFAULT_LIST_KEY_SERDE_INNER_CLASS, new FakeObject()); + final KafkaException exception = assertThrows( + KafkaException.class, + () -> listDeserializer.configure(props, true) + ); + assertEquals("Could not determine the inner serde class instance using " + + "\"" + CommonClientConfigs.DEFAULT_LIST_KEY_SERDE_INNER_CLASS + "\" property.", exception.getMessage()); + } + + @Test + public void testListValueDeserializerNoArgConstructorsShouldThrowKafkaExceptionDueInvalidInnerClass() { + props.put(CommonClientConfigs.DEFAULT_LIST_VALUE_SERDE_TYPE_CLASS, ArrayList.class); + props.put(CommonClientConfigs.DEFAULT_LIST_VALUE_SERDE_INNER_CLASS, new FakeObject()); + final KafkaException exception = assertThrows( + KafkaException.class, + () -> listDeserializer.configure(props, false) + ); + assertEquals("Could not determine the inner serde class instance using " + + "\"" + CommonClientConfigs.DEFAULT_LIST_VALUE_SERDE_INNER_CLASS + "\" property.", exception.getMessage()); + } + + @Test + public void testListKeyDeserializerNoArgConstructorsShouldThrowConfigExceptionDueListClassNotFound() { + props.put(CommonClientConfigs.DEFAULT_LIST_KEY_SERDE_TYPE_CLASS, nonExistingClass); + props.put(CommonClientConfigs.DEFAULT_LIST_KEY_SERDE_INNER_CLASS, Serdes.StringSerde.class); + final ConfigException exception = assertThrows( + ConfigException.class, + () -> listDeserializer.configure(props, true) + ); + assertEquals("Invalid value " + nonExistingClass + " for configuration " + + CommonClientConfigs.DEFAULT_LIST_KEY_SERDE_TYPE_CLASS + ": Deserializer's list class " + + "\"" + nonExistingClass + "\" could not be found.", exception.getMessage()); + } + + @Test + public void testListValueDeserializerNoArgConstructorsShouldThrowConfigExceptionDueListClassNotFound() { + props.put(CommonClientConfigs.DEFAULT_LIST_VALUE_SERDE_TYPE_CLASS, nonExistingClass); + props.put(CommonClientConfigs.DEFAULT_LIST_VALUE_SERDE_INNER_CLASS, Serdes.StringSerde.class); + final ConfigException exception = assertThrows( + ConfigException.class, + () -> listDeserializer.configure(props, false) + ); + assertEquals("Invalid value " + nonExistingClass + " for configuration " + + CommonClientConfigs.DEFAULT_LIST_VALUE_SERDE_TYPE_CLASS + ": Deserializer's list class " + + "\"" + nonExistingClass + "\" could not be found.", exception.getMessage()); + } + + @Test + public void testListKeyDeserializerNoArgConstructorsShouldThrowConfigExceptionDueInnerSerdeClassNotFound() { + props.put(CommonClientConfigs.DEFAULT_LIST_KEY_SERDE_TYPE_CLASS, ArrayList.class); + props.put(CommonClientConfigs.DEFAULT_LIST_KEY_SERDE_INNER_CLASS, nonExistingClass); + final ConfigException exception = assertThrows( + ConfigException.class, + () -> listDeserializer.configure(props, true) + ); + assertEquals("Invalid value " + nonExistingClass + " for configuration 
" + + CommonClientConfigs.DEFAULT_LIST_KEY_SERDE_INNER_CLASS + ": Deserializer's inner serde class " + + "\"" + nonExistingClass + "\" could not be found.", exception.getMessage()); + } + + @Test + public void testListValueDeserializerNoArgConstructorsShouldThrowConfigExceptionDueInnerSerdeClassNotFound() { + props.put(CommonClientConfigs.DEFAULT_LIST_VALUE_SERDE_TYPE_CLASS, ArrayList.class); + props.put(CommonClientConfigs.DEFAULT_LIST_VALUE_SERDE_INNER_CLASS, nonExistingClass); + final ConfigException exception = assertThrows( + ConfigException.class, + () -> listDeserializer.configure(props, false) + ); + assertEquals("Invalid value " + nonExistingClass + " for configuration " + + CommonClientConfigs.DEFAULT_LIST_VALUE_SERDE_INNER_CLASS + ": Deserializer's inner serde class " + + "\"" + nonExistingClass + "\" could not be found.", exception.getMessage()); + } + + @Test + public void testListKeyDeserializerShouldThrowConfigExceptionDueAlreadyInitialized() { + props.put(CommonClientConfigs.DEFAULT_LIST_KEY_SERDE_TYPE_CLASS, ArrayList.class); + props.put(CommonClientConfigs.DEFAULT_LIST_KEY_SERDE_INNER_CLASS, Serdes.StringSerde.class); + final ListDeserializer initializedListDeserializer = new ListDeserializer<>(ArrayList.class, + Serdes.Integer().deserializer()); + final ConfigException exception = assertThrows( + ConfigException.class, + () -> initializedListDeserializer.configure(props, true) + ); + assertEquals("List deserializer was already initialized using a non-default constructor", exception.getMessage()); + } + + @Test + public void testListValueDeserializerShouldThrowConfigExceptionDueAlreadyInitialized() { + props.put(CommonClientConfigs.DEFAULT_LIST_VALUE_SERDE_TYPE_CLASS, ArrayList.class); + props.put(CommonClientConfigs.DEFAULT_LIST_VALUE_SERDE_INNER_CLASS, Serdes.StringSerde.class); + final ListDeserializer initializedListDeserializer = new ListDeserializer<>(ArrayList.class, + Serdes.Integer().deserializer()); + final ConfigException exception = assertThrows( + ConfigException.class, + () -> initializedListDeserializer.configure(props, true) + ); + assertEquals("List deserializer was already initialized using a non-default constructor", exception.getMessage()); + } + +} diff --git a/clients/src/test/java/org/apache/kafka/common/serialization/ListSerializerTest.java b/clients/src/test/java/org/apache/kafka/common/serialization/ListSerializerTest.java new file mode 100644 index 0000000..a8ab191 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/serialization/ListSerializerTest.java @@ -0,0 +1,153 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.serialization; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import org.apache.kafka.clients.CommonClientConfigs; +import org.apache.kafka.common.KafkaException; +import org.apache.kafka.common.config.ConfigException; +import org.junit.jupiter.api.Test; + +import java.util.HashMap; +import java.util.Map; + + +public class ListSerializerTest { + private final ListSerializer listSerializer = new ListSerializer<>(); + private final Map props = new HashMap<>(); + private final String nonExistingClass = "non.existing.class"; + private static class FakeObject { + } + + @Test + public void testListKeySerializerNoArgConstructorsWithClassName() { + props.put(CommonClientConfigs.DEFAULT_LIST_KEY_SERDE_INNER_CLASS, Serdes.StringSerde.class.getName()); + listSerializer.configure(props, true); + final Serializer inner = listSerializer.getInnerSerializer(); + assertNotNull(inner, "Inner serializer should be not null"); + assertTrue(inner instanceof StringSerializer, "Inner serializer type should be StringSerializer"); + } + + @Test + public void testListValueSerializerNoArgConstructorsWithClassName() { + props.put(CommonClientConfigs.DEFAULT_LIST_VALUE_SERDE_INNER_CLASS, Serdes.StringSerde.class.getName()); + listSerializer.configure(props, false); + final Serializer inner = listSerializer.getInnerSerializer(); + assertNotNull(inner, "Inner serializer should be not null"); + assertTrue(inner instanceof StringSerializer, "Inner serializer type should be StringSerializer"); + } + + @Test + public void testListKeySerializerNoArgConstructorsWithClassObject() { + props.put(CommonClientConfigs.DEFAULT_LIST_KEY_SERDE_INNER_CLASS, Serdes.StringSerde.class); + listSerializer.configure(props, true); + final Serializer inner = listSerializer.getInnerSerializer(); + assertNotNull(inner, "Inner serializer should be not null"); + assertTrue(inner instanceof StringSerializer, "Inner serializer type should be StringSerializer"); + } + + @Test + public void testListValueSerializerNoArgConstructorsWithClassObject() { + props.put(CommonClientConfigs.DEFAULT_LIST_VALUE_SERDE_INNER_CLASS, Serdes.StringSerde.class); + listSerializer.configure(props, false); + final Serializer inner = listSerializer.getInnerSerializer(); + assertNotNull(inner, "Inner serializer should be not null"); + assertTrue(inner instanceof StringSerializer, "Inner serializer type should be StringSerializer"); + } + + @Test + public void testListSerializerNoArgConstructorsShouldThrowConfigExceptionDueMissingProp() { + ConfigException exception = assertThrows( + ConfigException.class, + () -> listSerializer.configure(props, true) + ); + assertEquals("Not able to determine the serializer class because it was neither passed via the constructor nor set in the config.", exception.getMessage()); + + exception = assertThrows( + ConfigException.class, + () -> listSerializer.configure(props, false) + ); + assertEquals("Not able to determine the serializer class because it was neither passed via the constructor nor set in the config.", exception.getMessage()); + } + + @Test + public void testListKeySerializerNoArgConstructorsShouldThrowKafkaExceptionDueInvalidClass() { + props.put(CommonClientConfigs.DEFAULT_LIST_KEY_SERDE_INNER_CLASS, new FakeObject()); + final KafkaException exception = assertThrows( + 
KafkaException.class, + () -> listSerializer.configure(props, true) + ); + assertEquals("Could not create a serializer class instance using \"" + CommonClientConfigs.DEFAULT_LIST_KEY_SERDE_INNER_CLASS + "\" property.", exception.getMessage()); + } + + @Test + public void testListValueSerializerNoArgConstructorsShouldThrowKafkaExceptionDueInvalidClass() { + props.put(CommonClientConfigs.DEFAULT_LIST_VALUE_SERDE_INNER_CLASS, new FakeObject()); + final KafkaException exception = assertThrows( + KafkaException.class, + () -> listSerializer.configure(props, false) + ); + assertEquals("Could not create a serializer class instance using \"" + CommonClientConfigs.DEFAULT_LIST_VALUE_SERDE_INNER_CLASS + "\" property.", exception.getMessage()); + } + + @Test + public void testListKeySerializerNoArgConstructorsShouldThrowKafkaExceptionDueClassNotFound() { + props.put(CommonClientConfigs.DEFAULT_LIST_KEY_SERDE_INNER_CLASS, nonExistingClass); + final KafkaException exception = assertThrows( + KafkaException.class, + () -> listSerializer.configure(props, true) + ); + assertEquals("Invalid value non.existing.class for configuration " + CommonClientConfigs.DEFAULT_LIST_KEY_SERDE_INNER_CLASS + ": Serializer class " + nonExistingClass + " could not be found.", exception.getMessage()); + } + + @Test + public void testListValueSerializerNoArgConstructorsShouldThrowKafkaExceptionDueClassNotFound() { + props.put(CommonClientConfigs.DEFAULT_LIST_VALUE_SERDE_INNER_CLASS, nonExistingClass); + final KafkaException exception = assertThrows( + KafkaException.class, + () -> listSerializer.configure(props, false) + ); + assertEquals("Invalid value non.existing.class for configuration " + CommonClientConfigs.DEFAULT_LIST_VALUE_SERDE_INNER_CLASS + ": Serializer class " + nonExistingClass + " could not be found.", exception.getMessage()); + } + + @Test + public void testListKeySerializerShouldThrowConfigExceptionDueAlreadyInitialized() { + props.put(CommonClientConfigs.DEFAULT_LIST_KEY_SERDE_INNER_CLASS, Serdes.StringSerde.class); + final ListSerializer initializedListSerializer = new ListSerializer<>(Serdes.Integer().serializer()); + final ConfigException exception = assertThrows( + ConfigException.class, + () -> initializedListSerializer.configure(props, true) + ); + assertEquals("List serializer was already initialized using a non-default constructor", exception.getMessage()); + } + + @Test + public void testListValueSerializerShouldThrowConfigExceptionDueAlreadyInitialized() { + props.put(CommonClientConfigs.DEFAULT_LIST_VALUE_SERDE_INNER_CLASS, Serdes.StringSerde.class); + final ListSerializer initializedListSerializer = new ListSerializer<>(Serdes.Integer().serializer()); + final ConfigException exception = assertThrows( + ConfigException.class, + () -> initializedListSerializer.configure(props, false) + ); + assertEquals("List serializer was already initialized using a non-default constructor", exception.getMessage()); + } + +} diff --git a/clients/src/test/java/org/apache/kafka/common/serialization/SerializationTest.java b/clients/src/test/java/org/apache/kafka/common/serialization/SerializationTest.java new file mode 100644 index 0000000..85c09dd --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/serialization/SerializationTest.java @@ -0,0 +1,371 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.serialization; + +import org.apache.kafka.common.errors.SerializationException; +import org.apache.kafka.common.utils.Bytes; +import org.junit.jupiter.api.Test; + +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.UUID; +import java.util.ArrayList; +import java.util.LinkedList; +import java.util.Stack; + +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; + +public class SerializationTest { + + final private String topic = "testTopic"; + final private Map, List> testData = new HashMap, List>() { + { + put(String.class, Arrays.asList("my string")); + put(Short.class, Arrays.asList((short) 32767, (short) -32768)); + put(Integer.class, Arrays.asList(423412424, -41243432)); + put(Long.class, Arrays.asList(922337203685477580L, -922337203685477581L)); + put(Float.class, Arrays.asList(5678567.12312f, -5678567.12341f)); + put(Double.class, Arrays.asList(5678567.12312d, -5678567.12341d)); + put(byte[].class, Arrays.asList("my string".getBytes())); + put(ByteBuffer.class, Arrays.asList(ByteBuffer.allocate(10).put("my string".getBytes()))); + put(Bytes.class, Arrays.asList(new Bytes("my string".getBytes()))); + put(UUID.class, Arrays.asList(UUID.randomUUID())); + } + }; + + private class DummyClass { + } + + @SuppressWarnings("unchecked") + @Test + public void allSerdesShouldRoundtripInput() { + for (Map.Entry, List> test : testData.entrySet()) { + try (Serde serde = Serdes.serdeFrom((Class) test.getKey())) { + for (Object value : test.getValue()) { + assertEquals(value, serde.deserializer().deserialize(topic, serde.serializer().serialize(topic, value)), + "Should get the original " + test.getKey().getSimpleName() + " after serialization and deserialization"); + } + } + } + } + + @Test + public void allSerdesShouldSupportNull() { + for (Class cls : testData.keySet()) { + try (Serde serde = Serdes.serdeFrom(cls)) { + assertNull(serde.serializer().serialize(topic, null), + "Should support null in " + cls.getSimpleName() + " serialization"); + assertNull(serde.deserializer().deserialize(topic, null), + "Should support null in " + cls.getSimpleName() + " deserialization"); + } + } + } + + @Test + public void testSerdeFromUnknown() { + assertThrows(IllegalArgumentException.class, () -> Serdes.serdeFrom(DummyClass.class)); + } + + @Test + public void testSerdeFromNotNull() { + try (Serde serde = Serdes.Long()) { + assertThrows(IllegalArgumentException.class, () -> Serdes.serdeFrom(null, serde.deserializer())); + } + } + + @Test + public void stringSerdeShouldSupportDifferentEncodings() { + String str = "my string"; + List encodings = 
Arrays.asList(StandardCharsets.UTF_8.name(), StandardCharsets.UTF_16.name()); + + for (String encoding : encodings) { + try (Serde serDeser = getStringSerde(encoding)) { + + Serializer serializer = serDeser.serializer(); + Deserializer deserializer = serDeser.deserializer(); + assertEquals(str, deserializer.deserialize(topic, serializer.serialize(topic, str)), + "Should get the original string after serialization and deserialization with encoding " + encoding); + } + } + } + + @SuppressWarnings("unchecked") + @Test + public void listSerdeShouldReturnEmptyCollection() { + List testData = Arrays.asList(); + Serde> listSerde = Serdes.ListSerde(ArrayList.class, Serdes.Integer()); + assertEquals(testData, + listSerde.deserializer().deserialize(topic, listSerde.serializer().serialize(topic, testData)), + "Should get empty collection after serialization and deserialization on an empty list"); + } + + @SuppressWarnings("unchecked") + @Test + public void listSerdeShouldReturnNull() { + List testData = null; + Serde> listSerde = Serdes.ListSerde(ArrayList.class, Serdes.Integer()); + assertEquals(testData, + listSerde.deserializer().deserialize(topic, listSerde.serializer().serialize(topic, testData)), + "Should get null after serialization and deserialization on an empty list"); + } + + @SuppressWarnings("unchecked") + @Test + public void listSerdeShouldRoundtripIntPrimitiveInput() { + List testData = Arrays.asList(1, 2, 3); + Serde> listSerde = Serdes.ListSerde(ArrayList.class, Serdes.Integer()); + assertEquals(testData, + listSerde.deserializer().deserialize(topic, listSerde.serializer().serialize(topic, testData)), + "Should get the original collection of integer primitives after serialization and deserialization"); + } + + @SuppressWarnings("unchecked") + @Test + public void listSerdeSerializerShouldReturnByteArrayOfFixedSizeForIntPrimitiveInput() { + List testData = Arrays.asList(1, 2, 3); + Serde> listSerde = Serdes.ListSerde(ArrayList.class, Serdes.Integer()); + assertEquals(21, listSerde.serializer().serialize(topic, testData).length, + "Should get length of 21 bytes after serialization"); + } + + @SuppressWarnings("unchecked") + @Test + public void listSerdeShouldRoundtripShortPrimitiveInput() { + List testData = Arrays.asList((short) 1, (short) 2, (short) 3); + Serde> listSerde = Serdes.ListSerde(ArrayList.class, Serdes.Short()); + assertEquals(testData, + listSerde.deserializer().deserialize(topic, listSerde.serializer().serialize(topic, testData)), + "Should get the original collection of short primitives after serialization and deserialization"); + } + + @SuppressWarnings("unchecked") + @Test + public void listSerdeSerializerShouldReturnByteArrayOfFixedSizeForShortPrimitiveInput() { + List testData = Arrays.asList((short) 1, (short) 2, (short) 3); + Serde> listSerde = Serdes.ListSerde(ArrayList.class, Serdes.Short()); + assertEquals(15, listSerde.serializer().serialize(topic, testData).length, + "Should get length of 15 bytes after serialization"); + } + + @SuppressWarnings("unchecked") + @Test + public void listSerdeShouldRoundtripFloatPrimitiveInput() { + List testData = Arrays.asList((float) 1, (float) 2, (float) 3); + Serde> listSerde = Serdes.ListSerde(ArrayList.class, Serdes.Float()); + assertEquals(testData, + listSerde.deserializer().deserialize(topic, listSerde.serializer().serialize(topic, testData)), + "Should get the original collection of float primitives after serialization and deserialization"); + } + + @SuppressWarnings("unchecked") + @Test + public void 
listSerdeSerializerShouldReturnByteArrayOfFixedSizeForFloatPrimitiveInput() { + List testData = Arrays.asList((float) 1, (float) 2, (float) 3); + Serde> listSerde = Serdes.ListSerde(ArrayList.class, Serdes.Float()); + assertEquals(21, listSerde.serializer().serialize(topic, testData).length, + "Should get length of 21 bytes after serialization"); + } + + @SuppressWarnings("unchecked") + @Test + public void listSerdeShouldRoundtripLongPrimitiveInput() { + List testData = Arrays.asList((long) 1, (long) 2, (long) 3); + Serde> listSerde = Serdes.ListSerde(ArrayList.class, Serdes.Long()); + assertEquals(testData, + listSerde.deserializer().deserialize(topic, listSerde.serializer().serialize(topic, testData)), + "Should get the original collection of long primitives after serialization and deserialization"); + } + + @SuppressWarnings("unchecked") + @Test + public void listSerdeSerializerShouldReturnByteArrayOfFixedSizeForLongPrimitiveInput() { + List testData = Arrays.asList((long) 1, (long) 2, (long) 3); + Serde> listSerde = Serdes.ListSerde(ArrayList.class, Serdes.Long()); + assertEquals(33, listSerde.serializer().serialize(topic, testData).length, + "Should get length of 33 bytes after serialization"); + } + + @SuppressWarnings("unchecked") + @Test + public void listSerdeShouldRoundtripDoublePrimitiveInput() { + List testData = Arrays.asList((double) 1, (double) 2, (double) 3); + Serde> listSerde = Serdes.ListSerde(ArrayList.class, Serdes.Double()); + assertEquals(testData, + listSerde.deserializer().deserialize(topic, listSerde.serializer().serialize(topic, testData)), + "Should get the original collection of double primitives after serialization and deserialization"); + } + + @SuppressWarnings("unchecked") + @Test + public void listSerdeSerializerShouldReturnByteArrayOfFixedSizeForDoublePrimitiveInput() { + List testData = Arrays.asList((double) 1, (double) 2, (double) 3); + Serde> listSerde = Serdes.ListSerde(ArrayList.class, Serdes.Double()); + assertEquals(33, listSerde.serializer().serialize(topic, testData).length, + "Should get length of 33 bytes after serialization"); + } + + @SuppressWarnings("unchecked") + @Test + public void listSerdeShouldRoundtripUUIDInput() { + List testData = Arrays.asList(UUID.randomUUID(), UUID.randomUUID(), UUID.randomUUID()); + Serde> listSerde = Serdes.ListSerde(ArrayList.class, Serdes.UUID()); + assertEquals(testData, + listSerde.deserializer().deserialize(topic, listSerde.serializer().serialize(topic, testData)), + "Should get the original collection of UUID after serialization and deserialization"); + } + + @SuppressWarnings("unchecked") + @Test + public void listSerdeSerializerShouldReturnByteArrayOfFixedSizeForUUIDInput() { + List testData = Arrays.asList(UUID.randomUUID(), UUID.randomUUID(), UUID.randomUUID()); + Serde> listSerde = Serdes.ListSerde(ArrayList.class, Serdes.UUID()); + assertEquals(117, listSerde.serializer().serialize(topic, testData).length, + "Should get length of 117 bytes after serialization"); + } + + @SuppressWarnings("unchecked") + @Test + public void listSerdeShouldRoundtripNonPrimitiveInput() { + List testData = Arrays.asList("A", "B", "C"); + Serde> listSerde = Serdes.ListSerde(ArrayList.class, Serdes.String()); + assertEquals(testData, + listSerde.deserializer().deserialize(topic, listSerde.serializer().serialize(topic, testData)), + "Should get the original collection of strings list after serialization and deserialization"); + } + + @SuppressWarnings("unchecked") + @Test + public void 
+    public void listSerdeShouldRoundtripPrimitiveInputWithNullEntries() {
+        List<Integer> testData = Arrays.asList(1, null, 3);
+        Serde<List<Integer>> listSerde = Serdes.ListSerde(ArrayList.class, Serdes.Integer());
+        assertEquals(testData,
+            listSerde.deserializer().deserialize(topic, listSerde.serializer().serialize(topic, testData)),
+            "Should get the original collection of integer primitives with null entries " +
+                "after serialization and deserialization");
+    }
+
+    @SuppressWarnings("unchecked")
+    @Test
+    public void listSerdeShouldRoundtripNonPrimitiveInputWithNullEntries() {
+        List<String> testData = Arrays.asList("A", null, "C");
+        Serde<List<String>> listSerde = Serdes.ListSerde(ArrayList.class, Serdes.String());
+        assertEquals(testData,
+            listSerde.deserializer().deserialize(topic, listSerde.serializer().serialize(topic, testData)),
+            "Should get the original collection of strings list with null entries " +
+                "after serialization and deserialization");
+    }
+
+    @SuppressWarnings("unchecked")
+    @Test
+    public void listSerdeShouldReturnLinkedList() {
+        List<Integer> testData = new LinkedList<>();
+        Serde<List<Integer>> listSerde = Serdes.ListSerde(LinkedList.class, Serdes.Integer());
+        assertTrue(listSerde.deserializer().deserialize(topic, listSerde.serializer().serialize(topic, testData))
+            instanceof LinkedList, "Should return List instance of type LinkedList");
+    }
+
+    @SuppressWarnings("unchecked")
+    @Test
+    public void listSerdeShouldReturnStack() {
+        List<Integer> testData = new Stack<>();
+        Serde<List<Integer>> listSerde = Serdes.ListSerde(Stack.class, Serdes.Integer());
+        assertTrue(listSerde.deserializer().deserialize(topic, listSerde.serializer().serialize(topic, testData))
+            instanceof Stack, "Should return List instance of type Stack");
+    }
+
+    @Test
+    public void floatDeserializerShouldThrowSerializationExceptionOnZeroBytes() {
+        try (Serde<Float> serde = Serdes.Float()) {
+            assertThrows(SerializationException.class, () -> serde.deserializer().deserialize(topic, new byte[0]));
+        }
+    }
+
+    @Test
+    public void floatDeserializerShouldThrowSerializationExceptionOnTooFewBytes() {
+        try (Serde<Float> serde = Serdes.Float()) {
+            assertThrows(SerializationException.class, () -> serde.deserializer().deserialize(topic, new byte[3]));
+        }
+    }
+
+    @Test
+    public void floatDeserializerShouldThrowSerializationExceptionOnTooManyBytes() {
+        try (Serde<Float> serde = Serdes.Float()) {
+            assertThrows(SerializationException.class, () -> serde.deserializer().deserialize(topic, new byte[5]));
+        }
+    }
+
+    @Test
+    public void floatSerdeShouldPreserveNaNValues() {
+        int someNaNAsIntBits = 0x7f800001;
+        float someNaN = Float.intBitsToFloat(someNaNAsIntBits);
+        int anotherNaNAsIntBits = 0x7f800002;
+        float anotherNaN = Float.intBitsToFloat(anotherNaNAsIntBits);
+
+        try (Serde<Float> serde = Serdes.Float()) {
+            // Because of NaN semantics we must assert based on the raw int bits.
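+            // Added note: Float.floatToIntBits would collapse both payloads to the canonical
+            // NaN bits 0x7fc00000, so floatToRawIntBits is the only way to tell them apart.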
+ Float roundtrip = serde.deserializer().deserialize(topic, + serde.serializer().serialize(topic, someNaN)); + assertEquals(someNaNAsIntBits, Float.floatToRawIntBits(roundtrip)); + Float otherRoundtrip = serde.deserializer().deserialize(topic, + serde.serializer().serialize(topic, anotherNaN)); + assertEquals(anotherNaNAsIntBits, Float.floatToRawIntBits(otherRoundtrip)); + } + } + + @Test + public void testSerializeVoid() { + try (Serde serde = Serdes.Void()) { + serde.serializer().serialize(topic, null); + } + } + + @Test + public void testDeserializeVoid() { + try (Serde serde = Serdes.Void()) { + serde.deserializer().deserialize(topic, null); + } + } + + @Test + public void voidDeserializerShouldThrowOnNotNullValues() { + try (Serde serde = Serdes.Void()) { + assertThrows(IllegalArgumentException.class, () -> serde.deserializer().deserialize(topic, new byte[5])); + } + } + + private Serde getStringSerde(String encoder) { + Map serializerConfigs = new HashMap(); + serializerConfigs.put("key.serializer.encoding", encoder); + Serializer serializer = Serdes.String().serializer(); + serializer.configure(serializerConfigs, true); + + Map deserializerConfigs = new HashMap(); + deserializerConfigs.put("key.deserializer.encoding", encoder); + Deserializer deserializer = Serdes.String().deserializer(); + deserializer.configure(deserializerConfigs, true); + + return Serdes.serdeFrom(serializer, deserializer); + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/utils/AbstractIteratorTest.java b/clients/src/test/java/org/apache/kafka/common/utils/AbstractIteratorTest.java new file mode 100644 index 0000000..e100d24 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/utils/AbstractIteratorTest.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.utils; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.Iterator; +import java.util.List; +import java.util.NoSuchElementException; + +import org.junit.jupiter.api.Test; + +public class AbstractIteratorTest { + + @Test + public void testIterator() { + int max = 10; + List l = new ArrayList(); + for (int i = 0; i < max; i++) + l.add(i); + ListIterator iter = new ListIterator(l); + for (int i = 0; i < max; i++) { + Integer value = i; + assertEquals(value, iter.peek()); + assertTrue(iter.hasNext()); + assertEquals(value, iter.next()); + } + assertFalse(iter.hasNext()); + } + + @Test + public void testEmptyIterator() { + Iterator iter = new ListIterator<>(Collections.emptyList()); + assertThrows(NoSuchElementException.class, iter::next); + } + + static class ListIterator extends AbstractIterator { + private List list; + private int position = 0; + + public ListIterator(List l) { + this.list = l; + } + + public T makeNext() { + if (position < list.size()) + return list.get(position++); + else + return allDone(); + } + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/utils/AppInfoParserTest.java b/clients/src/test/java/org/apache/kafka/common/utils/AppInfoParserTest.java new file mode 100644 index 0000000..c33ef39 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/utils/AppInfoParserTest.java @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.utils; + +import org.apache.kafka.common.metrics.Metrics; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import javax.management.JMException; +import javax.management.MBeanServer; +import javax.management.MalformedObjectNameException; +import javax.management.ObjectName; + +import java.lang.management.ManagementFactory; + +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class AppInfoParserTest { + private static final String EXPECTED_COMMIT_VERSION = AppInfoParser.DEFAULT_VALUE; + private static final String EXPECTED_VERSION = AppInfoParser.DEFAULT_VALUE; + private static final Long EXPECTED_START_MS = 1552313875722L; + private static final String METRICS_PREFIX = "app-info-test"; + private static final String METRICS_ID = "test"; + + private Metrics metrics; + private MBeanServer mBeanServer; + + @BeforeEach + public void setUp() { + metrics = new Metrics(new MockTime(1)); + mBeanServer = ManagementFactory.getPlatformMBeanServer(); + } + + @AfterEach + public void tearDown() { + metrics.close(); + } + + @Test + public void testRegisterAppInfoRegistersMetrics() throws JMException { + registerAppInfo(); + } + + @Test + public void testUnregisterAppInfoUnregistersMetrics() throws JMException { + registerAppInfo(); + AppInfoParser.unregisterAppInfo(METRICS_PREFIX, METRICS_ID, metrics); + + assertFalse(mBeanServer.isRegistered(expectedAppObjectName())); + assertNull(metrics.metric(metrics.metricName("commit-id", "app-info"))); + assertNull(metrics.metric(metrics.metricName("version", "app-info"))); + assertNull(metrics.metric(metrics.metricName("start-time-ms", "app-info"))); + } + + private void registerAppInfo() throws JMException { + assertEquals(EXPECTED_COMMIT_VERSION, AppInfoParser.getCommitId()); + assertEquals(EXPECTED_VERSION, AppInfoParser.getVersion()); + + AppInfoParser.registerAppInfo(METRICS_PREFIX, METRICS_ID, metrics, EXPECTED_START_MS); + + assertTrue(mBeanServer.isRegistered(expectedAppObjectName())); + assertEquals(EXPECTED_COMMIT_VERSION, metrics.metric(metrics.metricName("commit-id", "app-info")).metricValue()); + assertEquals(EXPECTED_VERSION, metrics.metric(metrics.metricName("version", "app-info")).metricValue()); + assertEquals(EXPECTED_START_MS, metrics.metric(metrics.metricName("start-time-ms", "app-info")).metricValue()); + } + + private ObjectName expectedAppObjectName() throws MalformedObjectNameException { + return new ObjectName(METRICS_PREFIX + ":type=app-info,id=" + METRICS_ID); + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/utils/ByteBufferInputStreamTest.java b/clients/src/test/java/org/apache/kafka/common/utils/ByteBufferInputStreamTest.java new file mode 100644 index 0000000..46755b7 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/utils/ByteBufferInputStreamTest.java @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.utils; + +import org.junit.jupiter.api.Test; + +import java.nio.ByteBuffer; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class ByteBufferInputStreamTest { + + @Test + public void testReadUnsignedIntFromInputStream() { + ByteBuffer buffer = ByteBuffer.allocate(8); + buffer.put((byte) 10); + buffer.put((byte) 20); + buffer.put((byte) 30); + buffer.rewind(); + + byte[] b = new byte[6]; + + ByteBufferInputStream inputStream = new ByteBufferInputStream(buffer); + assertEquals(10, inputStream.read()); + assertEquals(20, inputStream.read()); + + assertEquals(3, inputStream.read(b, 3, b.length - 3)); + assertEquals(0, inputStream.read()); + + assertEquals(2, inputStream.read(b, 0, b.length)); + assertEquals(-1, inputStream.read(b, 0, b.length)); + assertEquals(0, inputStream.read(b, 0, 0)); + assertEquals(-1, inputStream.read()); + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/utils/ByteBufferOutputStreamTest.java b/clients/src/test/java/org/apache/kafka/common/utils/ByteBufferOutputStreamTest.java new file mode 100644 index 0000000..5bc147d --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/utils/ByteBufferOutputStreamTest.java @@ -0,0 +1,106 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.utils; + +import org.junit.jupiter.api.Test; + +import java.io.IOException; +import java.nio.ByteBuffer; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class ByteBufferOutputStreamTest { + + @Test + public void testExpandByteBufferOnPositionIncrease() throws Exception { + testExpandByteBufferOnPositionIncrease(ByteBuffer.allocate(16)); + } + + @Test + public void testExpandDirectByteBufferOnPositionIncrease() throws Exception { + testExpandByteBufferOnPositionIncrease(ByteBuffer.allocateDirect(16)); + } + + private void testExpandByteBufferOnPositionIncrease(ByteBuffer initialBuffer) throws Exception { + ByteBufferOutputStream output = new ByteBufferOutputStream(initialBuffer); + output.write("hello".getBytes()); + output.position(32); + assertEquals(32, output.position()); + assertEquals(0, initialBuffer.position()); + + ByteBuffer buffer = output.buffer(); + assertEquals(32, buffer.limit()); + buffer.position(0); + buffer.limit(5); + byte[] bytes = new byte[5]; + buffer.get(bytes); + assertArrayEquals("hello".getBytes(), bytes); + output.close(); + } + + @Test + public void testExpandByteBufferOnWrite() throws Exception { + testExpandByteBufferOnWrite(ByteBuffer.allocate(16)); + } + + @Test + public void testExpandDirectByteBufferOnWrite() throws Exception { + testExpandByteBufferOnWrite(ByteBuffer.allocateDirect(16)); + } + + private void testExpandByteBufferOnWrite(ByteBuffer initialBuffer) throws Exception { + ByteBufferOutputStream output = new ByteBufferOutputStream(initialBuffer); + output.write("hello".getBytes()); + output.write(new byte[27]); + assertEquals(32, output.position()); + assertEquals(0, initialBuffer.position()); + + ByteBuffer buffer = output.buffer(); + assertEquals(32, buffer.limit()); + buffer.position(0); + buffer.limit(5); + byte[] bytes = new byte[5]; + buffer.get(bytes); + assertArrayEquals("hello".getBytes(), bytes); + output.close(); + } + + @Test + public void testWriteByteBuffer() throws IOException { + testWriteByteBuffer(ByteBuffer.allocate(16)); + } + + @Test + public void testWriteDirectByteBuffer() throws IOException { + testWriteByteBuffer(ByteBuffer.allocateDirect(16)); + } + + private void testWriteByteBuffer(ByteBuffer input) throws IOException { + long value = 234239230L; + input.putLong(value); + input.flip(); + + ByteBufferOutputStream output = new ByteBufferOutputStream(ByteBuffer.allocate(32)); + output.write(input); + assertEquals(8, input.position()); + assertEquals(8, output.position()); + assertEquals(value, output.buffer().getLong(0)); + output.close(); + } + +} diff --git a/clients/src/test/java/org/apache/kafka/common/utils/ByteBufferUnmapperTest.java b/clients/src/test/java/org/apache/kafka/common/utils/ByteBufferUnmapperTest.java new file mode 100644 index 0000000..795581c --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/utils/ByteBufferUnmapperTest.java @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.common.utils; + +import org.apache.kafka.test.TestUtils; +import org.junit.jupiter.api.Test; + +import java.io.File; +import java.nio.MappedByteBuffer; +import java.nio.channels.FileChannel; + +public class ByteBufferUnmapperTest { + + /** + * Checks that unmap doesn't throw exceptions. + */ + @Test + public void testUnmap() throws Exception { + File file = TestUtils.tempFile(); + try (FileChannel channel = FileChannel.open(file.toPath())) { + MappedByteBuffer map = channel.map(FileChannel.MapMode.READ_ONLY, 0, 0); + ByteBufferUnmapper.unmap(file.getAbsolutePath(), map); + } + } + +} diff --git a/clients/src/test/java/org/apache/kafka/common/utils/ByteUtilsTest.java b/clients/src/test/java/org/apache/kafka/common/utils/ByteUtilsTest.java new file mode 100644 index 0000000..8f432f7 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/utils/ByteUtilsTest.java @@ -0,0 +1,311 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.utils; + +import org.junit.jupiter.api.Test; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.DataInputStream; +import java.io.DataOutputStream; +import java.io.IOException; +import java.nio.ByteBuffer; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +public class ByteUtilsTest { + private final byte x00 = 0x00; + private final byte x01 = 0x01; + private final byte x02 = 0x02; + private final byte x0F = 0x0f; + private final byte x07 = 0x07; + private final byte x08 = 0x08; + private final byte x3F = 0x3f; + private final byte x40 = 0x40; + private final byte x7E = 0x7E; + private final byte x7F = 0x7F; + private final byte xFF = (byte) 0xff; + private final byte x80 = (byte) 0x80; + private final byte x81 = (byte) 0x81; + private final byte xBF = (byte) 0xbf; + private final byte xC0 = (byte) 0xc0; + private final byte xFE = (byte) 0xfe; + + @Test + public void testReadUnsignedIntLEFromArray() { + byte[] array1 = {0x01, 0x02, 0x03, 0x04, 0x05}; + assertEquals(0x04030201, ByteUtils.readUnsignedIntLE(array1, 0)); + assertEquals(0x05040302, ByteUtils.readUnsignedIntLE(array1, 1)); + + byte[] array2 = {(byte) 0xf1, (byte) 0xf2, (byte) 0xf3, (byte) 0xf4, (byte) 0xf5, (byte) 0xf6}; + assertEquals(0xf4f3f2f1, ByteUtils.readUnsignedIntLE(array2, 0)); + assertEquals(0xf6f5f4f3, ByteUtils.readUnsignedIntLE(array2, 2)); + } + + @Test + public void testReadUnsignedIntLEFromInputStream() throws IOException { + byte[] array1 = {0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09}; + ByteArrayInputStream is1 = new ByteArrayInputStream(array1); + assertEquals(0x04030201, ByteUtils.readUnsignedIntLE(is1)); + assertEquals(0x08070605, ByteUtils.readUnsignedIntLE(is1)); + + byte[] array2 = {(byte) 0xf1, (byte) 0xf2, (byte) 0xf3, (byte) 0xf4, (byte) 0xf5, (byte) 0xf6, (byte) 0xf7, (byte) 0xf8}; + ByteArrayInputStream is2 = new ByteArrayInputStream(array2); + assertEquals(0xf4f3f2f1, ByteUtils.readUnsignedIntLE(is2)); + assertEquals(0xf8f7f6f5, ByteUtils.readUnsignedIntLE(is2)); + } + + @Test + public void testReadUnsignedInt() { + ByteBuffer buffer = ByteBuffer.allocate(4); + long writeValue = 133444; + ByteUtils.writeUnsignedInt(buffer, writeValue); + buffer.flip(); + long readValue = ByteUtils.readUnsignedInt(buffer); + assertEquals(writeValue, readValue); + } + + @Test + public void testWriteUnsignedIntLEToArray() { + int value1 = 0x04030201; + + byte[] array1 = new byte[4]; + ByteUtils.writeUnsignedIntLE(array1, 0, value1); + assertArrayEquals(new byte[] {0x01, 0x02, 0x03, 0x04}, array1); + + array1 = new byte[8]; + ByteUtils.writeUnsignedIntLE(array1, 2, value1); + assertArrayEquals(new byte[] {0, 0, 0x01, 0x02, 0x03, 0x04, 0, 0}, array1); + + int value2 = 0xf4f3f2f1; + + byte[] array2 = new byte[4]; + ByteUtils.writeUnsignedIntLE(array2, 0, value2); + assertArrayEquals(new byte[] {(byte) 0xf1, (byte) 0xf2, (byte) 0xf3, (byte) 0xf4}, array2); + + array2 = new byte[8]; + ByteUtils.writeUnsignedIntLE(array2, 2, value2); + assertArrayEquals(new byte[] {0, 0, (byte) 0xf1, (byte) 0xf2, (byte) 0xf3, (byte) 0xf4, 0, 0}, array2); + } + + @Test + public void testWriteUnsignedIntLEToOutputStream() throws IOException { + int value1 = 0x04030201; + ByteArrayOutputStream os1 = new ByteArrayOutputStream(); + ByteUtils.writeUnsignedIntLE(os1, value1); + ByteUtils.writeUnsignedIntLE(os1, 
value1); + assertArrayEquals(new byte[] {0x01, 0x02, 0x03, 0x04, 0x01, 0x02, 0x03, 0x04}, os1.toByteArray()); + + int value2 = 0xf4f3f2f1; + ByteArrayOutputStream os2 = new ByteArrayOutputStream(); + ByteUtils.writeUnsignedIntLE(os2, value2); + assertArrayEquals(new byte[] {(byte) 0xf1, (byte) 0xf2, (byte) 0xf3, (byte) 0xf4}, os2.toByteArray()); + } + + @Test + public void testUnsignedVarintSerde() throws Exception { + assertUnsignedVarintSerde(0, new byte[] {x00}); + assertUnsignedVarintSerde(-1, new byte[] {xFF, xFF, xFF, xFF, x0F}); + assertUnsignedVarintSerde(1, new byte[] {x01}); + assertUnsignedVarintSerde(63, new byte[] {x3F}); + assertUnsignedVarintSerde(-64, new byte[] {xC0, xFF, xFF, xFF, x0F}); + assertUnsignedVarintSerde(64, new byte[] {x40}); + assertUnsignedVarintSerde(8191, new byte[] {xFF, x3F}); + assertUnsignedVarintSerde(-8192, new byte[] {x80, xC0, xFF, xFF, x0F}); + assertUnsignedVarintSerde(8192, new byte[] {x80, x40}); + assertUnsignedVarintSerde(-8193, new byte[] {xFF, xBF, xFF, xFF, x0F}); + assertUnsignedVarintSerde(1048575, new byte[] {xFF, xFF, x3F}); + assertUnsignedVarintSerde(1048576, new byte[] {x80, x80, x40}); + assertUnsignedVarintSerde(Integer.MAX_VALUE, new byte[] {xFF, xFF, xFF, xFF, x07}); + assertUnsignedVarintSerde(Integer.MIN_VALUE, new byte[] {x80, x80, x80, x80, x08}); + } + + @Test + public void testVarintSerde() throws Exception { + assertVarintSerde(0, new byte[] {x00}); + assertVarintSerde(-1, new byte[] {x01}); + assertVarintSerde(1, new byte[] {x02}); + assertVarintSerde(63, new byte[] {x7E}); + assertVarintSerde(-64, new byte[] {x7F}); + assertVarintSerde(64, new byte[] {x80, x01}); + assertVarintSerde(-65, new byte[] {x81, x01}); + assertVarintSerde(8191, new byte[] {xFE, x7F}); + assertVarintSerde(-8192, new byte[] {xFF, x7F}); + assertVarintSerde(8192, new byte[] {x80, x80, x01}); + assertVarintSerde(-8193, new byte[] {x81, x80, x01}); + assertVarintSerde(1048575, new byte[] {xFE, xFF, x7F}); + assertVarintSerde(-1048576, new byte[] {xFF, xFF, x7F}); + assertVarintSerde(1048576, new byte[] {x80, x80, x80, x01}); + assertVarintSerde(-1048577, new byte[] {x81, x80, x80, x01}); + assertVarintSerde(134217727, new byte[] {xFE, xFF, xFF, x7F}); + assertVarintSerde(-134217728, new byte[] {xFF, xFF, xFF, x7F}); + assertVarintSerde(134217728, new byte[] {x80, x80, x80, x80, x01}); + assertVarintSerde(-134217729, new byte[] {x81, x80, x80, x80, x01}); + assertVarintSerde(Integer.MAX_VALUE, new byte[] {xFE, xFF, xFF, xFF, x0F}); + assertVarintSerde(Integer.MIN_VALUE, new byte[] {xFF, xFF, xFF, xFF, x0F}); + } + + @Test + public void testVarlongSerde() throws Exception { + assertVarlongSerde(0, new byte[] {x00}); + assertVarlongSerde(-1, new byte[] {x01}); + assertVarlongSerde(1, new byte[] {x02}); + assertVarlongSerde(63, new byte[] {x7E}); + assertVarlongSerde(-64, new byte[] {x7F}); + assertVarlongSerde(64, new byte[] {x80, x01}); + assertVarlongSerde(-65, new byte[] {x81, x01}); + assertVarlongSerde(8191, new byte[] {xFE, x7F}); + assertVarlongSerde(-8192, new byte[] {xFF, x7F}); + assertVarlongSerde(8192, new byte[] {x80, x80, x01}); + assertVarlongSerde(-8193, new byte[] {x81, x80, x01}); + assertVarlongSerde(1048575, new byte[] {xFE, xFF, x7F}); + assertVarlongSerde(-1048576, new byte[] {xFF, xFF, x7F}); + assertVarlongSerde(1048576, new byte[] {x80, x80, x80, x01}); + assertVarlongSerde(-1048577, new byte[] {x81, x80, x80, x01}); + assertVarlongSerde(134217727, new byte[] {xFE, xFF, xFF, x7F}); + assertVarlongSerde(-134217728, new byte[] 
{xFF, xFF, xFF, x7F}); + assertVarlongSerde(134217728, new byte[] {x80, x80, x80, x80, x01}); + assertVarlongSerde(-134217729, new byte[] {x81, x80, x80, x80, x01}); + assertVarlongSerde(Integer.MAX_VALUE, new byte[] {xFE, xFF, xFF, xFF, x0F}); + assertVarlongSerde(Integer.MIN_VALUE, new byte[] {xFF, xFF, xFF, xFF, x0F}); + assertVarlongSerde(17179869183L, new byte[] {xFE, xFF, xFF, xFF, x7F}); + assertVarlongSerde(-17179869184L, new byte[] {xFF, xFF, xFF, xFF, x7F}); + assertVarlongSerde(17179869184L, new byte[] {x80, x80, x80, x80, x80, x01}); + assertVarlongSerde(-17179869185L, new byte[] {x81, x80, x80, x80, x80, x01}); + assertVarlongSerde(2199023255551L, new byte[] {xFE, xFF, xFF, xFF, xFF, x7F}); + assertVarlongSerde(-2199023255552L, new byte[] {xFF, xFF, xFF, xFF, xFF, x7F}); + assertVarlongSerde(2199023255552L, new byte[] {x80, x80, x80, x80, x80, x80, x01}); + assertVarlongSerde(-2199023255553L, new byte[] {x81, x80, x80, x80, x80, x80, x01}); + assertVarlongSerde(281474976710655L, new byte[] {xFE, xFF, xFF, xFF, xFF, xFF, x7F}); + assertVarlongSerde(-281474976710656L, new byte[] {xFF, xFF, xFF, xFF, xFF, xFF, x7F}); + assertVarlongSerde(281474976710656L, new byte[] {x80, x80, x80, x80, x80, x80, x80, x01}); + assertVarlongSerde(-281474976710657L, new byte[] {x81, x80, x80, x80, x80, x80, x80, 1}); + assertVarlongSerde(36028797018963967L, new byte[] {xFE, xFF, xFF, xFF, xFF, xFF, xFF, x7F}); + assertVarlongSerde(-36028797018963968L, new byte[] {xFF, xFF, xFF, xFF, xFF, xFF, xFF, x7F}); + assertVarlongSerde(36028797018963968L, new byte[] {x80, x80, x80, x80, x80, x80, x80, x80, x01}); + assertVarlongSerde(-36028797018963969L, new byte[] {x81, x80, x80, x80, x80, x80, x80, x80, x01}); + assertVarlongSerde(4611686018427387903L, new byte[] {xFE, xFF, xFF, xFF, xFF, xFF, xFF, xFF, x7F}); + assertVarlongSerde(-4611686018427387904L, new byte[] {xFF, xFF, xFF, xFF, xFF, xFF, xFF, xFF, x7F}); + assertVarlongSerde(4611686018427387904L, new byte[] {x80, x80, x80, x80, x80, x80, x80, x80, x80, x01}); + assertVarlongSerde(-4611686018427387905L, new byte[] {x81, x80, x80, x80, x80, x80, x80, x80, x80, x01}); + assertVarlongSerde(Long.MAX_VALUE, new byte[] {xFE, xFF, xFF, xFF, xFF, xFF, xFF, xFF, xFF, x01}); + assertVarlongSerde(Long.MIN_VALUE, new byte[] {xFF, xFF, xFF, xFF, xFF, xFF, xFF, xFF, xFF, x01}); + } + + @Test + public void testInvalidVarint() { + // varint encoding has one overflow byte + ByteBuffer buf = ByteBuffer.wrap(new byte[] {xFF, xFF, xFF, xFF, xFF, x01}); + assertThrows(IllegalArgumentException.class, () -> ByteUtils.readVarint(buf)); + } + + @Test + public void testInvalidVarlong() { + // varlong encoding has one overflow byte + ByteBuffer buf = ByteBuffer.wrap(new byte[] {xFF, xFF, xFF, xFF, xFF, xFF, xFF, xFF, xFF, xFF, x01}); + assertThrows(IllegalArgumentException.class, () -> ByteUtils.readVarlong(buf)); + } + + @Test + public void testDouble() throws IOException { + assertDoubleSerde(0.0, 0x0L); + assertDoubleSerde(-0.0, 0x8000000000000000L); + assertDoubleSerde(1.0, 0x3FF0000000000000L); + assertDoubleSerde(-1.0, 0xBFF0000000000000L); + assertDoubleSerde(123e45, 0x49B58B82C0E0BB00L); + assertDoubleSerde(-123e45, 0xC9B58B82C0E0BB00L); + assertDoubleSerde(Double.MIN_VALUE, 0x1L); + assertDoubleSerde(-Double.MIN_VALUE, 0x8000000000000001L); + assertDoubleSerde(Double.MAX_VALUE, 0x7FEFFFFFFFFFFFFFL); + assertDoubleSerde(-Double.MAX_VALUE, 0xFFEFFFFFFFFFFFFFL); + assertDoubleSerde(Double.NaN, 0x7FF8000000000000L); + assertDoubleSerde(Double.POSITIVE_INFINITY, 
0x7FF0000000000000L); + assertDoubleSerde(Double.NEGATIVE_INFINITY, 0xFFF0000000000000L); + } + + private void assertUnsignedVarintSerde(int value, byte[] expectedEncoding) throws IOException { + ByteBuffer buf = ByteBuffer.allocate(32); + ByteUtils.writeUnsignedVarint(value, buf); + buf.flip(); + assertArrayEquals(expectedEncoding, Utils.toArray(buf)); + assertEquals(value, ByteUtils.readUnsignedVarint(buf.duplicate())); + + buf.rewind(); + DataOutputStream out = new DataOutputStream(new ByteBufferOutputStream(buf)); + ByteUtils.writeUnsignedVarint(value, out); + buf.flip(); + assertArrayEquals(expectedEncoding, Utils.toArray(buf)); + DataInputStream in = new DataInputStream(new ByteBufferInputStream(buf)); + assertEquals(value, ByteUtils.readUnsignedVarint(in)); + } + + private void assertVarintSerde(int value, byte[] expectedEncoding) throws IOException { + ByteBuffer buf = ByteBuffer.allocate(32); + ByteUtils.writeVarint(value, buf); + buf.flip(); + assertArrayEquals(expectedEncoding, Utils.toArray(buf)); + assertEquals(value, ByteUtils.readVarint(buf.duplicate())); + + buf.rewind(); + DataOutputStream out = new DataOutputStream(new ByteBufferOutputStream(buf)); + ByteUtils.writeVarint(value, out); + buf.flip(); + assertArrayEquals(expectedEncoding, Utils.toArray(buf)); + DataInputStream in = new DataInputStream(new ByteBufferInputStream(buf)); + assertEquals(value, ByteUtils.readVarint(in)); + } + + private void assertVarlongSerde(long value, byte[] expectedEncoding) throws IOException { + ByteBuffer buf = ByteBuffer.allocate(32); + ByteUtils.writeVarlong(value, buf); + buf.flip(); + assertEquals(value, ByteUtils.readVarlong(buf.duplicate())); + assertArrayEquals(expectedEncoding, Utils.toArray(buf)); + + buf.rewind(); + DataOutputStream out = new DataOutputStream(new ByteBufferOutputStream(buf)); + ByteUtils.writeVarlong(value, out); + buf.flip(); + assertArrayEquals(expectedEncoding, Utils.toArray(buf)); + DataInputStream in = new DataInputStream(new ByteBufferInputStream(buf)); + assertEquals(value, ByteUtils.readVarlong(in)); + } + + private void assertDoubleSerde(double value, long expectedLongValue) throws IOException { + byte[] expectedEncoding = new byte[8]; + for (int i = 0; i < 8; i++) { + expectedEncoding[7 - i] = (byte) (expectedLongValue & 0xFF); + expectedLongValue >>= 8; + } + + ByteBuffer buf = ByteBuffer.allocate(8); + ByteUtils.writeDouble(value, buf); + buf.flip(); + assertEquals(value, ByteUtils.readDouble(buf.duplicate()), 0.0); + assertArrayEquals(expectedEncoding, Utils.toArray(buf)); + + buf.rewind(); + DataOutputStream out = new DataOutputStream(new ByteBufferOutputStream(buf)); + ByteUtils.writeDouble(value, out); + buf.flip(); + assertArrayEquals(expectedEncoding, Utils.toArray(buf)); + DataInputStream in = new DataInputStream(new ByteBufferInputStream(buf)); + assertEquals(value, ByteUtils.readDouble(in), 0.0); + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/utils/BytesTest.java b/clients/src/test/java/org/apache/kafka/common/utils/BytesTest.java new file mode 100644 index 0000000..602bbf2 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/utils/BytesTest.java @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.utils; + +import org.junit.jupiter.api.Test; + +import java.util.Comparator; +import java.util.NavigableMap; +import java.util.TreeMap; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class BytesTest { + + @Test + public void testIncrement() { + byte[] input = new byte[]{(byte) 0xAB, (byte) 0xCD, (byte) 0xFF}; + byte[] expected = new byte[]{(byte) 0xAB, (byte) 0xCE, (byte) 0x00}; + Bytes output = Bytes.increment(Bytes.wrap(input)); + assertArrayEquals(output.get(), expected); + } + + @Test + public void testIncrementUpperBoundary() { + byte[] input = new byte[]{(byte) 0xFF, (byte) 0xFF, (byte) 0xFF}; + assertThrows(IndexOutOfBoundsException.class, () -> Bytes.increment(Bytes.wrap(input))); + } + + @Test + public void testIncrementWithSubmap() { + final NavigableMap map = new TreeMap<>(); + Bytes key1 = Bytes.wrap(new byte[]{(byte) 0xAA}); + byte[] val = new byte[]{(byte) 0x00}; + map.put(key1, val); + + Bytes key2 = Bytes.wrap(new byte[]{(byte) 0xAA, (byte) 0xAA}); + map.put(key2, val); + + Bytes key3 = Bytes.wrap(new byte[]{(byte) 0xAA, (byte) 0x00, (byte) 0xFF, (byte) 0xFF, (byte) 0xFF}); + map.put(key3, val); + + Bytes key4 = Bytes.wrap(new byte[]{(byte) 0xAB, (byte) 0x00}); + map.put(key4, val); + + Bytes key5 = Bytes.wrap(new byte[]{(byte) 0x00, (byte) 0x00, (byte) 0x00, (byte) 0x01}); + map.put(key5, val); + + Bytes prefix = key1; + Bytes prefixEnd = Bytes.increment(prefix); + + Comparator comparator = map.comparator(); + final int result = comparator == null ? prefix.compareTo(prefixEnd) : comparator.compare(prefix, prefixEnd); + NavigableMap subMapResults; + if (result > 0) { + //Prefix increment would cause a wrap-around. Get the submap from toKey to the end of the map + subMapResults = map.tailMap(prefix, true); + } else { + subMapResults = map.subMap(prefix, true, prefixEnd, false); + } + + NavigableMap subMapExpected = new TreeMap<>(); + subMapExpected.put(key1, val); + subMapExpected.put(key2, val); + subMapExpected.put(key3, val); + + assertEquals(subMapExpected.keySet(), subMapResults.keySet()); + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/utils/ChecksumsTest.java b/clients/src/test/java/org/apache/kafka/common/utils/ChecksumsTest.java new file mode 100644 index 0000000..f48cd5a --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/utils/ChecksumsTest.java @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.common.utils; + +import org.junit.jupiter.api.Test; + +import java.nio.ByteBuffer; +import java.util.zip.Checksum; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class ChecksumsTest { + + @Test + public void testUpdateByteBuffer() { + byte[] bytes = new byte[]{0, 1, 2, 3, 4, 5}; + doTestUpdateByteBuffer(bytes, ByteBuffer.allocate(bytes.length)); + doTestUpdateByteBuffer(bytes, ByteBuffer.allocateDirect(bytes.length)); + } + + private void doTestUpdateByteBuffer(byte[] bytes, ByteBuffer buffer) { + buffer.put(bytes); + buffer.flip(); + Checksum bufferCrc = new Crc32(); + Checksums.update(bufferCrc, buffer, buffer.remaining()); + assertEquals(Crc32.crc32(bytes), bufferCrc.getValue()); + assertEquals(0, buffer.position()); + } + + @Test + public void testUpdateByteBufferWithOffsetPosition() { + byte[] bytes = new byte[]{-2, -1, 0, 1, 2, 3, 4, 5}; + doTestUpdateByteBufferWithOffsetPosition(bytes, ByteBuffer.allocate(bytes.length), 2); + doTestUpdateByteBufferWithOffsetPosition(bytes, ByteBuffer.allocateDirect(bytes.length), 2); + } + + @Test + public void testUpdateInt() { + final int value = 1000; + final ByteBuffer buffer = ByteBuffer.allocate(4); + buffer.putInt(value); + + Checksum crc1 = Crc32C.create(); + Checksum crc2 = Crc32C.create(); + + Checksums.updateInt(crc1, value); + crc2.update(buffer.array(), buffer.arrayOffset(), 4); + + assertEquals(crc1.getValue(), crc2.getValue(), "Crc values should be the same"); + } + + @Test + public void testUpdateLong() { + final long value = Integer.MAX_VALUE + 1; + final ByteBuffer buffer = ByteBuffer.allocate(8); + buffer.putLong(value); + + Checksum crc1 = new Crc32(); + Checksum crc2 = new Crc32(); + + Checksums.updateLong(crc1, value); + crc2.update(buffer.array(), buffer.arrayOffset(), 8); + + assertEquals(crc1.getValue(), crc2.getValue(), "Crc values should be the same"); + } + + private void doTestUpdateByteBufferWithOffsetPosition(byte[] bytes, ByteBuffer buffer, int offset) { + buffer.put(bytes); + buffer.flip(); + buffer.position(offset); + + Checksum bufferCrc = Crc32C.create(); + Checksums.update(bufferCrc, buffer, buffer.remaining()); + assertEquals(Crc32C.compute(bytes, offset, buffer.remaining()), bufferCrc.getValue()); + assertEquals(offset, buffer.position()); + } + +} diff --git a/clients/src/test/java/org/apache/kafka/common/utils/CircularIteratorTest.java b/clients/src/test/java/org/apache/kafka/common/utils/CircularIteratorTest.java new file mode 100644 index 0000000..64a2ddb --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/utils/CircularIteratorTest.java @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.common.utils; + +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.util.Arrays; +import java.util.Collections; + +public class CircularIteratorTest { + + @Test + public void testNullCollection() { + assertThrows(NullPointerException.class, () -> new CircularIterator<>(null)); + } + + @Test + public void testEmptyCollection() { + assertThrows(IllegalArgumentException.class, () -> new CircularIterator<>(Collections.emptyList())); + } + + @Test() + public void testCycleCollection() { + final CircularIterator it = new CircularIterator<>(Arrays.asList("A", "B", null, "C")); + + assertEquals("A", it.peek()); + assertTrue(it.hasNext()); + assertEquals("A", it.next()); + assertEquals("B", it.peek()); + assertTrue(it.hasNext()); + assertEquals("B", it.next()); + assertNull(it.peek()); + assertTrue(it.hasNext()); + assertNull(it.next()); + assertEquals("C", it.peek()); + assertTrue(it.hasNext()); + assertEquals("C", it.next()); + assertEquals("A", it.peek()); + assertTrue(it.hasNext()); + assertEquals("A", it.next()); + assertEquals("B", it.peek()); + + // Check that peek does not have any side-effects + assertEquals("B", it.peek()); + } + +} diff --git a/clients/src/test/java/org/apache/kafka/common/utils/CollectionUtilsTest.java b/clients/src/test/java/org/apache/kafka/common/utils/CollectionUtilsTest.java new file mode 100644 index 0000000..7f8419e --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/utils/CollectionUtilsTest.java @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.utils; + +import org.junit.jupiter.api.Test; + +import java.util.HashMap; +import java.util.Map; + +import static org.apache.kafka.common.utils.CollectionUtils.subtractMap; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.assertNotSame; + +public class CollectionUtilsTest { + + @Test + public void testSubtractMapRemovesSecondMapsKeys() { + Map mainMap = new HashMap<>(); + mainMap.put("one", "1"); + mainMap.put("two", "2"); + mainMap.put("three", "3"); + Map secondaryMap = new HashMap<>(); + secondaryMap.put("one", "4"); + secondaryMap.put("two", "5"); + + Map newMap = subtractMap(mainMap, secondaryMap); + + assertEquals(3, mainMap.size()); // original map should not be modified + assertEquals(1, newMap.size()); + assertTrue(newMap.containsKey("three")); + assertEquals("3", newMap.get("three")); + } + + @Test + public void testSubtractMapDoesntRemoveAnythingWhenEmptyMap() { + Map mainMap = new HashMap<>(); + mainMap.put("one", "1"); + mainMap.put("two", "2"); + mainMap.put("three", "3"); + Map secondaryMap = new HashMap<>(); + + Map newMap = subtractMap(mainMap, secondaryMap); + + assertEquals(3, newMap.size()); + assertEquals("1", newMap.get("one")); + assertEquals("2", newMap.get("two")); + assertEquals("3", newMap.get("three")); + assertNotSame(newMap, mainMap); + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/utils/ConfigUtilsTest.java b/clients/src/test/java/org/apache/kafka/common/utils/ConfigUtilsTest.java new file mode 100644 index 0000000..d760330 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/utils/ConfigUtilsTest.java @@ -0,0 +1,171 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.common.utils; + +import org.apache.kafka.common.config.ConfigDef; +import org.apache.kafka.common.config.ConfigDef.Importance; +import org.apache.kafka.common.config.ConfigDef.Type; +import org.junit.jupiter.api.Test; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; + +public class ConfigUtilsTest { + + @Test + public void testTranslateDeprecated() { + Map config = new HashMap<>(); + config.put("foo.bar", "baz"); + config.put("foo.bar.deprecated", "quux"); + config.put("chicken", "1"); + config.put("rooster", "2"); + config.put("hen", "3"); + config.put("heifer", "moo"); + config.put("blah", "blah"); + config.put("unexpected.non.string.object", 42); + Map newConfig = ConfigUtils.translateDeprecatedConfigs(config, new String[][]{ + {"foo.bar", "foo.bar.deprecated"}, + {"chicken", "rooster", "hen"}, + {"cow", "beef", "heifer", "steer"} + }); + assertEquals("baz", newConfig.get("foo.bar")); + assertNull(newConfig.get("foobar.deprecated")); + assertEquals("1", newConfig.get("chicken")); + assertNull(newConfig.get("rooster")); + assertNull(newConfig.get("hen")); + assertEquals("moo", newConfig.get("cow")); + assertNull(newConfig.get("beef")); + assertNull(newConfig.get("heifer")); + assertNull(newConfig.get("steer")); + assertNull(config.get("cow")); + assertEquals("blah", config.get("blah")); + assertEquals("blah", newConfig.get("blah")); + assertEquals(42, newConfig.get("unexpected.non.string.object")); + assertEquals(42, config.get("unexpected.non.string.object")); + + } + + @Test + public void testAllowsNewKey() { + Map config = new HashMap<>(); + config.put("foo.bar", "baz"); + Map newConfig = ConfigUtils.translateDeprecatedConfigs(config, new String[][]{ + {"foo.bar", "foo.bar.deprecated"}, + {"chicken", "rooster", "hen"}, + {"cow", "beef", "heifer", "steer"} + }); + assertNotNull(newConfig); + assertEquals("baz", newConfig.get("foo.bar")); + assertNull(newConfig.get("foo.bar.deprecated")); + } + + @Test + public void testAllowDeprecatedNulls() { + Map config = new HashMap<>(); + config.put("foo.bar.deprecated", null); + config.put("foo.bar", "baz"); + Map newConfig = ConfigUtils.translateDeprecatedConfigs(config, new String[][]{ + {"foo.bar", "foo.bar.deprecated"} + }); + assertNotNull(newConfig); + assertEquals("baz", newConfig.get("foo.bar")); + assertNull(newConfig.get("foo.bar.deprecated")); + } + + @Test + public void testAllowNullOverride() { + Map config = new HashMap<>(); + config.put("foo.bar.deprecated", "baz"); + config.put("foo.bar", null); + Map newConfig = ConfigUtils.translateDeprecatedConfigs(config, new String[][]{ + {"foo.bar", "foo.bar.deprecated"} + }); + assertNotNull(newConfig); + assertNull(newConfig.get("foo.bar")); + assertNull(newConfig.get("foo.bar.deprecated")); + } + + @Test + public void testNullMapEntriesWithoutAliasesDoNotThrowNPE() { + Map config = new HashMap<>(); + config.put("other", null); + Map newConfig = ConfigUtils.translateDeprecatedConfigs(config, new String[][]{ + {"foo.bar", "foo.bar.deprecated"} + }); + assertNotNull(newConfig); + assertNull(newConfig.get("other")); + } + + @Test + public void testDuplicateSynonyms() { + Map config = new HashMap<>(); + config.put("foo.bar", "baz"); + config.put("foo.bar.deprecated", "derp"); + Map newConfig = ConfigUtils.translateDeprecatedConfigs(config, new 
String[][]{ + {"foo.bar", "foo.bar.deprecated"}, + {"chicken", "foo.bar.deprecated"} + }); + assertNotNull(newConfig); + assertEquals("baz", newConfig.get("foo.bar")); + assertEquals("derp", newConfig.get("chicken")); + assertNull(newConfig.get("foo.bar.deprecated")); + } + + @Test + public void testMultipleDeprecations() { + Map config = new HashMap<>(); + config.put("foo.bar.deprecated", "derp"); + config.put("foo.bar.even.more.deprecated", "very old configuration"); + Map newConfig = ConfigUtils.translateDeprecatedConfigs(config, new String[][]{ + {"foo.bar", "foo.bar.deprecated", "foo.bar.even.more.deprecated"} + }); + assertNotNull(newConfig); + assertEquals("derp", newConfig.get("foo.bar")); + assertNull(newConfig.get("foo.bar.deprecated")); + assertNull(newConfig.get("foo.bar.even.more.deprecated")); + } + + private static final ConfigDef CONFIG = new ConfigDef(). + define("myPassword", Type.PASSWORD, Importance.HIGH, ""). + define("myString", Type.STRING, Importance.HIGH, ""). + define("myInt", Type.INT, Importance.HIGH, ""). + define("myString2", Type.STRING, Importance.HIGH, ""); + + @Test + public void testConfigMapToRedactedStringForEmptyMap() { + assertEquals("{}", ConfigUtils. + configMapToRedactedString(Collections.emptyMap(), CONFIG)); + } + + @Test + public void testConfigMapToRedactedStringWithSecrets() { + Map testMap1 = new HashMap<>(); + testMap1.put("myString", "whatever"); + testMap1.put("myInt", Integer.valueOf(123)); + testMap1.put("myPassword", "foosecret"); + testMap1.put("myString2", null); + testMap1.put("myUnknown", Integer.valueOf(456)); + assertEquals("{myInt=123, myPassword=(redacted), myString=\"whatever\", myString2=null, myUnknown=(redacted)}", + ConfigUtils.configMapToRedactedString(testMap1, CONFIG)); + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/utils/Crc32CTest.java b/clients/src/test/java/org/apache/kafka/common/utils/Crc32CTest.java new file mode 100644 index 0000000..c05df71 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/utils/Crc32CTest.java @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.utils; + +import org.junit.jupiter.api.Test; + +import java.util.zip.Checksum; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class Crc32CTest { + + @Test + public void testUpdate() { + final byte[] bytes = "Any String you want".getBytes(); + final int len = bytes.length; + + Checksum crc1 = Crc32C.create(); + Checksum crc2 = Crc32C.create(); + Checksum crc3 = Crc32C.create(); + + crc1.update(bytes, 0, len); + for (int i = 0; i < len; i++) + crc2.update(bytes[i]); + crc3.update(bytes, 0, len / 2); + crc3.update(bytes, len / 2, len - len / 2); + + assertEquals(crc1.getValue(), crc2.getValue(), "Crc values should be the same"); + assertEquals(crc1.getValue(), crc3.getValue(), "Crc values should be the same"); + } + + @Test + public void testValue() { + final byte[] bytes = "Some String".getBytes(); + assertEquals(608512271, Crc32C.compute(bytes, 0, bytes.length)); + } + +} diff --git a/clients/src/test/java/org/apache/kafka/common/utils/Crc32Test.java b/clients/src/test/java/org/apache/kafka/common/utils/Crc32Test.java new file mode 100644 index 0000000..a358210 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/utils/Crc32Test.java @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.utils; + +import org.junit.jupiter.api.Test; + +import java.util.zip.Checksum; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class Crc32Test { + + @Test + public void testUpdate() { + final byte[] bytes = "Any String you want".getBytes(); + final int len = bytes.length; + + Checksum crc1 = new Crc32(); + Checksum crc2 = new Crc32(); + Checksum crc3 = new Crc32(); + + crc1.update(bytes, 0, len); + for (int i = 0; i < len; i++) + crc2.update(bytes[i]); + crc3.update(bytes, 0, len / 2); + crc3.update(bytes, len / 2, len - len / 2); + + assertEquals(crc1.getValue(), crc2.getValue(), "Crc values should be the same"); + assertEquals(crc1.getValue(), crc3.getValue(), "Crc values should be the same"); + } + + @Test + public void testValue() { + final byte[] bytes = "Some String".getBytes(); + assertEquals(2021503672, Crc32.crc32(bytes)); + } + +} diff --git a/clients/src/test/java/org/apache/kafka/common/utils/ExitTest.java b/clients/src/test/java/org/apache/kafka/common/utils/ExitTest.java new file mode 100644 index 0000000..8adcca3 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/utils/ExitTest.java @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.common.utils;
+
+import org.junit.jupiter.api.Test;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+
+public class ExitTest {
+    @Test
+    public void shouldHaltImmediately() {
+        List<Object> list = new ArrayList<>();
+        Exit.setHaltProcedure((statusCode, message) -> {
+            list.add(statusCode);
+            list.add(message);
+        });
+        try {
+            int statusCode = 0;
+            String message = "message";
+            Exit.halt(statusCode);
+            Exit.halt(statusCode, message);
+            assertEquals(Arrays.asList(statusCode, null, statusCode, message), list);
+        } finally {
+            Exit.resetHaltProcedure();
+        }
+    }
+
+    @Test
+    public void shouldExitImmediately() {
+        List<Object> list = new ArrayList<>();
+        Exit.setExitProcedure((statusCode, message) -> {
+            list.add(statusCode);
+            list.add(message);
+        });
+        try {
+            int statusCode = 0;
+            String message = "message";
+            Exit.exit(statusCode);
+            Exit.exit(statusCode, message);
+            assertEquals(Arrays.asList(statusCode, null, statusCode, message), list);
+        } finally {
+            Exit.resetExitProcedure();
+        }
+    }
+
+    @Test
+    public void shouldAddShutdownHookImmediately() {
+        List<Object> list = new ArrayList<>();
+        Exit.setShutdownHookAdder((name, runnable) -> {
+            list.add(name);
+            list.add(runnable);
+        });
+        try {
+            Runnable runnable = () -> { };
+            String name = "name";
+            Exit.addShutdownHook(name, runnable);
+            assertEquals(Arrays.asList(name, runnable), list);
+        } finally {
+            Exit.resetShutdownHookAdder();
+        }
+    }
+
+    @Test
+    public void shouldNotInvokeShutdownHookImmediately() {
+        List<Object> list = new ArrayList<>();
+        Runnable runnable = () -> list.add(this);
+        Exit.addShutdownHook("message", runnable);
+        assertEquals(0, list.size());
+    }
+}
diff --git a/clients/src/test/java/org/apache/kafka/common/utils/ExponentialBackoffTest.java b/clients/src/test/java/org/apache/kafka/common/utils/ExponentialBackoffTest.java
new file mode 100644
index 0000000..4e84386
--- /dev/null
+++ b/clients/src/test/java/org/apache/kafka/common/utils/ExponentialBackoffTest.java
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +package org.apache.kafka.common.utils; + +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class ExponentialBackoffTest { + @Test + public void testExponentialBackoff() { + long scaleFactor = 100; + int ratio = 2; + long backoffMax = 2000; + double jitter = 0.2; + ExponentialBackoff exponentialBackoff = new ExponentialBackoff( + scaleFactor, ratio, backoffMax, jitter + ); + + for (int i = 0; i <= 100; i++) { + for (int attempts = 0; attempts <= 10; attempts++) { + if (attempts <= 4) { + assertEquals(scaleFactor * Math.pow(ratio, attempts), + exponentialBackoff.backoff(attempts), + scaleFactor * Math.pow(ratio, attempts) * jitter); + } else { + assertTrue(exponentialBackoff.backoff(attempts) <= backoffMax * (1 + jitter)); + } + } + } + } + + @Test + public void testExponentialBackoffWithoutJitter() { + ExponentialBackoff exponentialBackoff = new ExponentialBackoff(100, 2, 400, 0.0); + assertEquals(100, exponentialBackoff.backoff(0)); + assertEquals(200, exponentialBackoff.backoff(1)); + assertEquals(400, exponentialBackoff.backoff(2)); + assertEquals(400, exponentialBackoff.backoff(3)); + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/utils/FixedOrderMapTest.java b/clients/src/test/java/org/apache/kafka/common/utils/FixedOrderMapTest.java new file mode 100644 index 0000000..e07be6a --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/utils/FixedOrderMapTest.java @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
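Editorial aside, not part of the patch above: the assertions in ExponentialBackoffTest imply that backoff(attempts) grows roughly as scaleFactor * ratio^attempts, is capped near backoffMax, and is randomized by plus or minus the jitter fraction. A small illustrative sketch under that reading; the wrapper class name is invented and the printed values are approximate because of jitter.

import org.apache.kafka.common.utils.ExponentialBackoff;

public class BackoffExample {
    public static void main(String[] args) {
        // ~100, ~200, ~400, ~800, ~1600, then capped near 2000 (each +/- 20% jitter)
        ExponentialBackoff backoff = new ExponentialBackoff(100, 2, 2000, 0.2);
        for (int attempt = 0; attempt < 6; attempt++) {
            System.out.println("attempt " + attempt + " -> " + backoff.backoff(attempt) + " ms");
        }
    }
}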
+ */
+package org.apache.kafka.common.utils;
+
+import org.junit.jupiter.api.Test;
+
+import java.util.Iterator;
+import java.util.Map;
+
+import static org.apache.kafka.common.utils.Utils.mkEntry;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertFalse;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+
+public class FixedOrderMapTest {
+
+    @Test
+    public void shouldMaintainOrderWhenAdding() {
+        final FixedOrderMap<String, Integer> map = new FixedOrderMap<>();
+        map.put("a", 0);
+        map.put("b", 1);
+        map.put("c", 2);
+        map.put("b", 3);
+        final Iterator<Map.Entry<String, Integer>> iterator = map.entrySet().iterator();
+        assertEquals(mkEntry("a", 0), iterator.next());
+        assertEquals(mkEntry("b", 3), iterator.next());
+        assertEquals(mkEntry("c", 2), iterator.next());
+        assertFalse(iterator.hasNext());
+    }
+
+    @SuppressWarnings("deprecation")
+    @Test
+    public void shouldForbidRemove() {
+        final FixedOrderMap<String, Integer> map = new FixedOrderMap<>();
+        map.put("a", 0);
+        assertThrows(UnsupportedOperationException.class, () -> map.remove("a"));
+        assertEquals(0, map.get("a"));
+    }
+
+    @SuppressWarnings("deprecation")
+    @Test
+    public void shouldForbidConditionalRemove() {
+        final FixedOrderMap<String, Integer> map = new FixedOrderMap<>();
+        map.put("a", 0);
+        assertThrows(UnsupportedOperationException.class, () -> map.remove("a", 0));
+        assertEquals(0, map.get("a"));
+    }
+}
diff --git a/clients/src/test/java/org/apache/kafka/common/utils/FlattenedIteratorTest.java b/clients/src/test/java/org/apache/kafka/common/utils/FlattenedIteratorTest.java
new file mode 100644
index 0000000..fe02cbe
--- /dev/null
+++ b/clients/src/test/java/org/apache/kafka/common/utils/FlattenedIteratorTest.java
@@ -0,0 +1,115 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.common.utils;
+
+import org.junit.jupiter.api.Test;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.stream.Collectors;
+
+import static java.util.Arrays.asList;
+import static java.util.Collections.emptyList;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+
+public class FlattenedIteratorTest {
+
+    @Test
+    public void testNestedLists() {
+        List<List<String>> list = asList(
+            asList("foo", "a", "bc"),
+            asList("ddddd"),
+            asList("", "bar2", "baz45"));
+
+        Iterable<String> flattenedIterable = () -> new FlattenedIterator<>(list.iterator(), l -> l.iterator());
+        List<String> flattened = new ArrayList<>();
+        flattenedIterable.forEach(flattened::add);
+
+        assertEquals(list.stream().flatMap(l -> l.stream()).collect(Collectors.toList()), flattened);
+
+        // Ensure we can iterate multiple times
+        List<String> flattened2 = new ArrayList<>();
+        flattenedIterable.forEach(flattened2::add);
+
+        assertEquals(flattened, flattened2);
+    }
+
+    @Test
+    public void testEmptyList() {
+        List<List<String>> list = emptyList();
+
+        Iterable<String> flattenedIterable = () -> new FlattenedIterator<>(list.iterator(), l -> l.iterator());
+        List<String> flattened = new ArrayList<>();
+        flattenedIterable.forEach(flattened::add);
+
+        assertEquals(emptyList(), flattened);
+    }
+
+    @Test
+    public void testNestedSingleEmptyList() {
+        List<List<String>> list = asList(emptyList());
+
+        Iterable<String> flattenedIterable = () -> new FlattenedIterator<>(list.iterator(), l -> l.iterator());
+        List<String> flattened = new ArrayList<>();
+        flattenedIterable.forEach(flattened::add);
+
+        assertEquals(emptyList(), flattened);
+    }
+
+    @Test
+    public void testEmptyListFollowedByNonEmpty() {
+        List<List<String>> list = asList(
+            emptyList(),
+            asList("boo", "b", "de"));
+
+        Iterable<String> flattenedIterable = () -> new FlattenedIterator<>(list.iterator(), l -> l.iterator());
+        List<String> flattened = new ArrayList<>();
+        flattenedIterable.forEach(flattened::add);
+
+        assertEquals(list.stream().flatMap(l -> l.stream()).collect(Collectors.toList()), flattened);
+    }
+
+    @Test
+    public void testEmptyListInBetweenNonEmpty() {
+        List<List<String>> list = asList(
+            asList("aadwdwdw"),
+            emptyList(),
+            asList("ee", "aa", "dd"));
+
+        Iterable<String> flattenedIterable = () -> new FlattenedIterator<>(list.iterator(), l -> l.iterator());
+        List<String> flattened = new ArrayList<>();
+        flattenedIterable.forEach(flattened::add);
+
+        assertEquals(list.stream().flatMap(l -> l.stream()).collect(Collectors.toList()), flattened);
+    }
+
+    @Test
+    public void testEmptyListAtTheEnd() {
+        List<List<String>> list = asList(
+            asList("ee", "dd"),
+            asList("e"),
+            emptyList());
+
+        Iterable<String> flattenedIterable = () -> new FlattenedIterator<>(list.iterator(), l -> l.iterator());
+        List<String> flattened = new ArrayList<>();
+        flattenedIterable.forEach(flattened::add);
+
+        assertEquals(list.stream().flatMap(l -> l.stream()).collect(Collectors.toList()), flattened);
+    }
+
+}
+
diff --git a/clients/src/test/java/org/apache/kafka/common/utils/ImplicitLinkedHashCollectionTest.java b/clients/src/test/java/org/apache/kafka/common/utils/ImplicitLinkedHashCollectionTest.java
new file mode 100644
index 0000000..3c12c98
--- /dev/null
+++ b/clients/src/test/java/org/apache/kafka/common/utils/ImplicitLinkedHashCollectionTest.java
@@ -0,0 +1,670 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
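Editorial aside, not part of the patch above: FlattenedIterator, as driven by FlattenedIteratorTest, walks an outer iterator plus an inner-iterator function and yields a single flat sequence, skipping empty inner collections. A minimal sketch assuming only the constructor shape used in that test; the wrapper class and topic names are invented.

import org.apache.kafka.common.utils.FlattenedIterator;

import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;

public class FlattenedIteratorExample {
    public static void main(String[] args) {
        List<List<String>> topicsByBroker = Arrays.asList(
            Arrays.asList("orders", "payments"),
            Collections.emptyList(),
            Arrays.asList("audit"));

        // Flattens the nested lists lazily; the empty inner list contributes nothing
        Iterator<String> allTopics =
            new FlattenedIterator<>(topicsByBroker.iterator(), List::iterator);
        allTopics.forEachRemaining(System.out::println); // orders, payments, audit
    }
}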
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.utils; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.Comparator; +import java.util.Iterator; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.ListIterator; +import java.util.Random; +import java.util.Set; + +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.fail; +import static org.junit.jupiter.api.Assertions.assertThrows; + +/** + * A unit test for ImplicitLinkedHashCollection. + */ +@Timeout(120) +public class ImplicitLinkedHashCollectionTest { + + final static class TestElement implements ImplicitLinkedHashCollection.Element { + private int prev = ImplicitLinkedHashCollection.INVALID_INDEX; + private int next = ImplicitLinkedHashCollection.INVALID_INDEX; + private final int key; + private final int val; + + TestElement(int key) { + this.key = key; + this.val = 0; + } + + TestElement(int key, int val) { + this.key = key; + this.val = val; + } + + @Override + public int prev() { + return prev; + } + + @Override + public void setPrev(int prev) { + this.prev = prev; + } + + @Override + public int next() { + return next; + } + + @Override + public void setNext(int next) { + this.next = next; + } + + @Override + public boolean elementKeysAreEqual(Object o) { + if (this == o) return true; + if ((o == null) || (o.getClass() != TestElement.class)) return false; + TestElement that = (TestElement) o; + return key == that.key; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if ((o == null) || (o.getClass() != TestElement.class)) return false; + TestElement that = (TestElement) o; + return key == that.key && val == that.val; + } + + @Override + public String toString() { + return "TestElement(key=" + key + ", val=" + val + ")"; + } + + @Override + public int hashCode() { + long hashCode = 2654435761L * key; + return (int) (hashCode >> 32); + } + } + + @Test + public void testNullForbidden() { + ImplicitLinkedHashMultiCollection multiColl = new ImplicitLinkedHashMultiCollection<>(); + assertFalse(multiColl.add(null)); + } + + @Test + public void testInsertDelete() { + ImplicitLinkedHashCollection coll = new ImplicitLinkedHashCollection<>(100); + assertTrue(coll.add(new TestElement(1))); + TestElement second = new TestElement(2); + assertTrue(coll.add(second)); + assertTrue(coll.add(new TestElement(3))); + assertFalse(coll.add(new TestElement(3))); + assertEquals(3, coll.size()); + assertTrue(coll.contains(new TestElement(1))); + assertFalse(coll.contains(new TestElement(4))); + TestElement secondAgain = coll.find(new 
TestElement(2)); + assertTrue(second == secondAgain); + assertTrue(coll.remove(new TestElement(1))); + assertFalse(coll.remove(new TestElement(1))); + assertEquals(2, coll.size()); + coll.clear(); + assertEquals(0, coll.size()); + } + + static void expectTraversal(Iterator iterator, Integer... sequence) { + int i = 0; + while (iterator.hasNext()) { + TestElement element = iterator.next(); + assertTrue(i < sequence.length, "Iterator yieled " + (i + 1) + " elements, but only " + + sequence.length + " were expected."); + assertEquals(sequence[i].intValue(), element.key, "Iterator value number " + (i + 1) + " was incorrect."); + i = i + 1; + } + assertTrue(i == sequence.length, "Iterator yieled " + (i + 1) + " elements, but " + + sequence.length + " were expected."); + } + + static void expectTraversal(Iterator iter, Iterator expectedIter) { + int i = 0; + while (iter.hasNext()) { + TestElement element = iter.next(); + assertTrue(expectedIter.hasNext(), "Iterator yieled " + (i + 1) + " elements, but only " + i + + " were expected."); + Integer expected = expectedIter.next(); + assertEquals(expected.intValue(), element.key, "Iterator value number " + (i + 1) + " was incorrect."); + i = i + 1; + } + assertFalse(expectedIter.hasNext(), "Iterator yieled " + i + " elements, but at least " + (i + 1) + + " were expected."); + } + + @Test + public void testTraversal() { + ImplicitLinkedHashCollection coll = new ImplicitLinkedHashCollection<>(); + expectTraversal(coll.iterator()); + assertTrue(coll.add(new TestElement(2))); + expectTraversal(coll.iterator(), 2); + assertTrue(coll.add(new TestElement(1))); + expectTraversal(coll.iterator(), 2, 1); + assertTrue(coll.add(new TestElement(100))); + expectTraversal(coll.iterator(), 2, 1, 100); + assertTrue(coll.remove(new TestElement(1))); + expectTraversal(coll.iterator(), 2, 100); + assertTrue(coll.add(new TestElement(1))); + expectTraversal(coll.iterator(), 2, 100, 1); + Iterator iter = coll.iterator(); + iter.next(); + iter.next(); + iter.remove(); + iter.next(); + assertFalse(iter.hasNext()); + expectTraversal(coll.iterator(), 2, 1); + List list = new ArrayList<>(); + list.add(new TestElement(1)); + list.add(new TestElement(2)); + assertTrue(coll.removeAll(list)); + assertFalse(coll.removeAll(list)); + expectTraversal(coll.iterator()); + assertEquals(0, coll.size()); + assertTrue(coll.isEmpty()); + } + + @Test + public void testSetViewGet() { + ImplicitLinkedHashCollection coll = new ImplicitLinkedHashCollection<>(); + coll.add(new TestElement(1)); + coll.add(new TestElement(2)); + coll.add(new TestElement(3)); + + Set set = coll.valuesSet(); + assertTrue(set.contains(new TestElement(1))); + assertTrue(set.contains(new TestElement(2))); + assertTrue(set.contains(new TestElement(3))); + assertEquals(3, set.size()); + } + + @Test + public void testSetViewModification() { + ImplicitLinkedHashCollection coll = new ImplicitLinkedHashCollection<>(); + coll.add(new TestElement(1)); + coll.add(new TestElement(2)); + coll.add(new TestElement(3)); + + // Removal from set is reflected in collection + Set set = coll.valuesSet(); + set.remove(new TestElement(1)); + assertFalse(coll.contains(new TestElement(1))); + assertEquals(2, coll.size()); + + // Addition to set is reflected in collection + set.add(new TestElement(4)); + assertTrue(coll.contains(new TestElement(4))); + assertEquals(3, coll.size()); + + // Removal from collection is reflected in set + coll.remove(new TestElement(2)); + assertFalse(set.contains(new TestElement(2))); + assertEquals(2, set.size()); 
+ + // Addition to collection is reflected in set + coll.add(new TestElement(5)); + assertTrue(set.contains(new TestElement(5))); + assertEquals(3, set.size()); + + // Ordering in the collection is maintained + int key = 3; + for (TestElement e : coll) { + assertEquals(key, e.key); + ++key; + } + } + + @Test + public void testListViewGet() { + ImplicitLinkedHashCollection coll = new ImplicitLinkedHashCollection<>(); + coll.add(new TestElement(1)); + coll.add(new TestElement(2)); + coll.add(new TestElement(3)); + + List list = coll.valuesList(); + assertEquals(1, list.get(0).key); + assertEquals(2, list.get(1).key); + assertEquals(3, list.get(2).key); + assertEquals(3, list.size()); + } + + @Test + public void testListViewModification() { + ImplicitLinkedHashCollection coll = new ImplicitLinkedHashCollection<>(); + coll.add(new TestElement(1)); + coll.add(new TestElement(2)); + coll.add(new TestElement(3)); + + // Removal from list is reflected in collection + List list = coll.valuesList(); + list.remove(1); + assertTrue(coll.contains(new TestElement(1))); + assertFalse(coll.contains(new TestElement(2))); + assertTrue(coll.contains(new TestElement(3))); + assertEquals(2, coll.size()); + + // Removal from collection is reflected in list + coll.remove(new TestElement(1)); + assertEquals(3, list.get(0).key); + assertEquals(1, list.size()); + + // Addition to collection is reflected in list + coll.add(new TestElement(4)); + assertEquals(3, list.get(0).key); + assertEquals(4, list.get(1).key); + assertEquals(2, list.size()); + } + + @Test + public void testEmptyListIterator() { + ImplicitLinkedHashCollection coll = new ImplicitLinkedHashCollection<>(); + ListIterator iter = coll.valuesList().listIterator(); + assertFalse(iter.hasNext()); + assertFalse(iter.hasPrevious()); + assertEquals(0, iter.nextIndex()); + assertEquals(-1, iter.previousIndex()); + } + + @Test + public void testListIteratorCreation() { + ImplicitLinkedHashCollection coll = new ImplicitLinkedHashCollection<>(); + coll.add(new TestElement(1)); + coll.add(new TestElement(2)); + coll.add(new TestElement(3)); + + // Iterator created at the start of the list should have a next but no prev + ListIterator iter = coll.valuesList().listIterator(); + assertTrue(iter.hasNext()); + assertFalse(iter.hasPrevious()); + assertEquals(0, iter.nextIndex()); + assertEquals(-1, iter.previousIndex()); + + // Iterator created in the middle of the list should have both a next and a prev + iter = coll.valuesList().listIterator(2); + assertTrue(iter.hasNext()); + assertTrue(iter.hasPrevious()); + assertEquals(2, iter.nextIndex()); + assertEquals(1, iter.previousIndex()); + + // Iterator created at the end of the list should have a prev but no next + iter = coll.valuesList().listIterator(3); + assertFalse(iter.hasNext()); + assertTrue(iter.hasPrevious()); + assertEquals(3, iter.nextIndex()); + assertEquals(2, iter.previousIndex()); + } + + @Test + public void testListIteratorTraversal() { + ImplicitLinkedHashCollection coll = new ImplicitLinkedHashCollection<>(); + coll.add(new TestElement(1)); + coll.add(new TestElement(2)); + coll.add(new TestElement(3)); + ListIterator iter = coll.valuesList().listIterator(); + + // Step the iterator forward to the end of the list + assertTrue(iter.hasNext()); + assertFalse(iter.hasPrevious()); + assertEquals(0, iter.nextIndex()); + assertEquals(-1, iter.previousIndex()); + + assertEquals(1, iter.next().key); + assertTrue(iter.hasNext()); + assertTrue(iter.hasPrevious()); + assertEquals(1, iter.nextIndex()); + 
assertEquals(0, iter.previousIndex()); + + assertEquals(2, iter.next().key); + assertTrue(iter.hasNext()); + assertTrue(iter.hasPrevious()); + assertEquals(2, iter.nextIndex()); + assertEquals(1, iter.previousIndex()); + + assertEquals(3, iter.next().key); + assertFalse(iter.hasNext()); + assertTrue(iter.hasPrevious()); + assertEquals(3, iter.nextIndex()); + assertEquals(2, iter.previousIndex()); + + // Step back to the middle of the list + assertEquals(3, iter.previous().key); + assertTrue(iter.hasNext()); + assertTrue(iter.hasPrevious()); + assertEquals(2, iter.nextIndex()); + assertEquals(1, iter.previousIndex()); + + assertEquals(2, iter.previous().key); + assertTrue(iter.hasNext()); + assertTrue(iter.hasPrevious()); + assertEquals(1, iter.nextIndex()); + assertEquals(0, iter.previousIndex()); + + // Step forward one and then back one, return value should remain the same + assertEquals(2, iter.next().key); + assertTrue(iter.hasNext()); + assertTrue(iter.hasPrevious()); + assertEquals(2, iter.nextIndex()); + assertEquals(1, iter.previousIndex()); + + assertEquals(2, iter.previous().key); + assertTrue(iter.hasNext()); + assertTrue(iter.hasPrevious()); + assertEquals(1, iter.nextIndex()); + assertEquals(0, iter.previousIndex()); + + // Step back to the front of the list + assertEquals(1, iter.previous().key); + assertTrue(iter.hasNext()); + assertFalse(iter.hasPrevious()); + assertEquals(0, iter.nextIndex()); + assertEquals(-1, iter.previousIndex()); + } + + @Test + public void testListIteratorRemove() { + ImplicitLinkedHashCollection coll = new ImplicitLinkedHashCollection<>(); + coll.add(new TestElement(1)); + coll.add(new TestElement(2)); + coll.add(new TestElement(3)); + coll.add(new TestElement(4)); + coll.add(new TestElement(5)); + + ListIterator iter = coll.valuesList().listIterator(); + try { + iter.remove(); + fail("Calling remove() without calling next() or previous() should raise an exception"); + } catch (IllegalStateException e) { + // expected + } + + // Remove after next() + iter.next(); + iter.next(); + iter.next(); + iter.remove(); + assertTrue(iter.hasNext()); + assertTrue(iter.hasPrevious()); + assertEquals(2, iter.nextIndex()); + assertEquals(1, iter.previousIndex()); + + try { + iter.remove(); + fail("Calling remove() twice without calling next() or previous() in between should raise an exception"); + } catch (IllegalStateException e) { + // expected + } + + // Remove after previous() + assertEquals(2, iter.previous().key); + iter.remove(); + assertTrue(iter.hasNext()); + assertTrue(iter.hasPrevious()); + assertEquals(1, iter.nextIndex()); + assertEquals(0, iter.previousIndex()); + + // Remove the first element of the list + assertEquals(1, iter.previous().key); + iter.remove(); + assertTrue(iter.hasNext()); + assertFalse(iter.hasPrevious()); + assertEquals(0, iter.nextIndex()); + assertEquals(-1, iter.previousIndex()); + + // Remove the last element of the list + assertEquals(4, iter.next().key); + assertEquals(5, iter.next().key); + iter.remove(); + assertFalse(iter.hasNext()); + assertTrue(iter.hasPrevious()); + assertEquals(1, iter.nextIndex()); + assertEquals(0, iter.previousIndex()); + + // Remove the final remaining element of the list + assertEquals(4, iter.previous().key); + iter.remove(); + assertFalse(iter.hasNext()); + assertFalse(iter.hasPrevious()); + assertEquals(0, iter.nextIndex()); + assertEquals(-1, iter.previousIndex()); + + } + + @Test + public void testCollisions() { + ImplicitLinkedHashCollection coll = new ImplicitLinkedHashCollection<>(5); + 
assertEquals(11, coll.numSlots()); + assertTrue(coll.add(new TestElement(11))); + assertTrue(coll.add(new TestElement(0))); + assertTrue(coll.add(new TestElement(22))); + assertTrue(coll.add(new TestElement(33))); + assertEquals(11, coll.numSlots()); + expectTraversal(coll.iterator(), 11, 0, 22, 33); + assertTrue(coll.remove(new TestElement(22))); + expectTraversal(coll.iterator(), 11, 0, 33); + assertEquals(3, coll.size()); + assertFalse(coll.isEmpty()); + } + + @Test + public void testEnlargement() { + ImplicitLinkedHashCollection coll = new ImplicitLinkedHashCollection<>(5); + assertEquals(11, coll.numSlots()); + for (int i = 0; i < 6; i++) { + assertTrue(coll.add(new TestElement(i))); + } + assertEquals(23, coll.numSlots()); + assertEquals(6, coll.size()); + expectTraversal(coll.iterator(), 0, 1, 2, 3, 4, 5); + for (int i = 0; i < 6; i++) { + assertTrue(coll.contains(new TestElement(i)), "Failed to find element " + i); + } + coll.remove(new TestElement(3)); + assertEquals(23, coll.numSlots()); + assertEquals(5, coll.size()); + expectTraversal(coll.iterator(), 0, 1, 2, 4, 5); + } + + @Test + public void testManyInsertsAndDeletes() { + Random random = new Random(123); + LinkedHashSet existing = new LinkedHashSet<>(); + ImplicitLinkedHashCollection coll = new ImplicitLinkedHashCollection<>(); + for (int i = 0; i < 100; i++) { + addRandomElement(random, existing, coll); + addRandomElement(random, existing, coll); + addRandomElement(random, existing, coll); + removeRandomElement(random, existing); + expectTraversal(coll.iterator(), existing.iterator()); + } + } + + @Test + public void testInsertingTheSameObjectMultipleTimes() { + ImplicitLinkedHashCollection coll = new ImplicitLinkedHashCollection<>(); + TestElement element = new TestElement(123); + assertTrue(coll.add(element)); + assertFalse(coll.add(element)); + assertFalse(coll.add(element)); + assertTrue(coll.remove(element)); + assertFalse(coll.remove(element)); + assertTrue(coll.add(element)); + assertFalse(coll.add(element)); + } + + @Test + public void testEquals() { + ImplicitLinkedHashCollection coll1 = new ImplicitLinkedHashCollection<>(); + coll1.add(new TestElement(1)); + coll1.add(new TestElement(2)); + coll1.add(new TestElement(3)); + + ImplicitLinkedHashCollection coll2 = new ImplicitLinkedHashCollection<>(); + coll2.add(new TestElement(1)); + coll2.add(new TestElement(2)); + coll2.add(new TestElement(3)); + + ImplicitLinkedHashCollection coll3 = new ImplicitLinkedHashCollection<>(); + coll3.add(new TestElement(1)); + coll3.add(new TestElement(3)); + coll3.add(new TestElement(2)); + + assertEquals(coll1, coll2); + assertNotEquals(coll1, coll3); + assertNotEquals(coll2, coll3); + } + + @Test + public void testFindContainsRemoveOnEmptyCollection() { + ImplicitLinkedHashCollection coll = new ImplicitLinkedHashCollection<>(); + assertNull(coll.find(new TestElement(2))); + assertFalse(coll.contains(new TestElement(2))); + assertFalse(coll.remove(new TestElement(2))); + } + + private void addRandomElement(Random random, LinkedHashSet existing, + ImplicitLinkedHashCollection set) { + int next; + do { + next = random.nextInt(); + } while (existing.contains(next)); + existing.add(next); + set.add(new TestElement(next)); + } + + @SuppressWarnings("unlikely-arg-type") + private void removeRandomElement(Random random, Collection existing) { + int removeIdx = random.nextInt(existing.size()); + Iterator iter = existing.iterator(); + Integer element = null; + for (int i = 0; i <= removeIdx; i++) { + element = iter.next(); + } + 
existing.remove(new TestElement(element)); + } + + @Test + public void testSameKeysDifferentValues() { + ImplicitLinkedHashCollection coll = new ImplicitLinkedHashCollection<>(); + assertTrue(coll.add(new TestElement(1, 1))); + assertFalse(coll.add(new TestElement(1, 2))); + TestElement element2 = new TestElement(1, 2); + TestElement element1 = coll.find(element2); + assertFalse(element2.equals(element1)); + assertTrue(element2.elementKeysAreEqual(element1)); + } + + @Test + public void testMoveToEnd() { + ImplicitLinkedHashCollection coll = new ImplicitLinkedHashCollection<>(); + TestElement e1 = new TestElement(1, 1); + TestElement e2 = new TestElement(2, 2); + TestElement e3 = new TestElement(3, 3); + assertTrue(coll.add(e1)); + assertTrue(coll.add(e2)); + assertTrue(coll.add(e3)); + coll.moveToEnd(e1); + expectTraversal(coll.iterator(), 2, 3, 1); + assertThrows(RuntimeException.class, () -> coll.moveToEnd(new TestElement(4, 4))); + } + + @Test + public void testRemovals() { + ImplicitLinkedHashCollection coll = new ImplicitLinkedHashCollection<>(); + List elements = new ArrayList<>(); + for (int i = 0; i < 100; i++) { + TestElement element = new TestElement(i, i); + elements.add(element); + coll.add(element); + } + assertEquals(100, coll.size()); + Iterator iter = coll.iterator(); + for (int i = 0; i < 50; i++) { + iter.next(); + iter.remove(); + } + assertEquals(50, coll.size()); + for (int i = 50; i < 100; i++) { + assertEquals(new TestElement(i, i), coll.find(elements.get(i))); + } + } + + static class TestElementComparator implements Comparator { + static final TestElementComparator INSTANCE = new TestElementComparator(); + + @Override + public int compare(TestElement a, TestElement b) { + if (a.key < b.key) { + return -1; + } else if (a.key > b.key) { + return 1; + } else if (a.val < b.val) { + return -1; + } else if (a.val > b.val) { + return 1; + } else { + return 0; + } + } + } + + static class ReverseTestElementComparator implements Comparator { + static final ReverseTestElementComparator INSTANCE = new ReverseTestElementComparator(); + + @Override + public int compare(TestElement a, TestElement b) { + return TestElementComparator.INSTANCE.compare(b, a); + } + } + + @Test + public void testSort() { + ImplicitLinkedHashCollection coll = new ImplicitLinkedHashCollection<>(); + coll.add(new TestElement(3, 3)); + coll.add(new TestElement(1, 1)); + coll.add(new TestElement(10, 10)); + coll.add(new TestElement(9, 9)); + coll.add(new TestElement(2, 2)); + coll.add(new TestElement(4, 4)); + coll.add(new TestElement(0, 0)); + coll.add(new TestElement(30, 30)); + coll.add(new TestElement(20, 20)); + coll.add(new TestElement(11, 11)); + coll.add(new TestElement(15, 15)); + coll.add(new TestElement(5, 5)); + + expectTraversal(coll.iterator(), 3, 1, 10, 9, 2, 4, 0, 30, 20, 11, 15, 5); + coll.sort(TestElementComparator.INSTANCE); + expectTraversal(coll.iterator(), 0, 1, 2, 3, 4, 5, 9, 10, 11, 15, 20, 30); + coll.sort(TestElementComparator.INSTANCE); + expectTraversal(coll.iterator(), 0, 1, 2, 3, 4, 5, 9, 10, 11, 15, 20, 30); + coll.sort(ReverseTestElementComparator.INSTANCE); + expectTraversal(coll.iterator(), 30, 20, 15, 11, 10, 9, 5, 4, 3, 2, 1, 0); + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/utils/ImplicitLinkedHashMultiCollectionTest.java b/clients/src/test/java/org/apache/kafka/common/utils/ImplicitLinkedHashMultiCollectionTest.java new file mode 100644 index 0000000..9b3df92 --- /dev/null +++ 
b/clients/src/test/java/org/apache/kafka/common/utils/ImplicitLinkedHashMultiCollectionTest.java
@@ -0,0 +1,170 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.common.utils;
+
+import org.apache.kafka.common.utils.ImplicitLinkedHashCollectionTest.TestElement;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.Timeout;
+
+import java.util.Iterator;
+import java.util.LinkedList;
+import java.util.Random;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertFalse;
+import static org.junit.jupiter.api.Assertions.assertNull;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+import static org.junit.jupiter.api.Assertions.fail;
+
+/**
+ * A unit test for ImplicitLinkedHashMultiCollection.
+ */
+@Timeout(120)
+public class ImplicitLinkedHashMultiCollectionTest {
+
+    @Test
+    public void testNullForbidden() {
+        ImplicitLinkedHashMultiCollection<TestElement> multiSet = new ImplicitLinkedHashMultiCollection<>();
+        assertFalse(multiSet.add(null));
+    }
+
+    @Test
+    public void testFindFindAllContainsRemoveOnEmptyCollection() {
+        ImplicitLinkedHashMultiCollection<TestElement> coll = new ImplicitLinkedHashMultiCollection<>();
+        assertNull(coll.find(new TestElement(2)));
+        assertFalse(coll.contains(new TestElement(2)));
+        assertFalse(coll.remove(new TestElement(2)));
+        assertTrue(coll.findAll(new TestElement(2)).isEmpty());
+    }
+
+    @Test
+    public void testInsertDelete() {
+        ImplicitLinkedHashMultiCollection<TestElement> multiSet = new ImplicitLinkedHashMultiCollection<>(100);
+        TestElement e1 = new TestElement(1);
+        TestElement e2 = new TestElement(1);
+        TestElement e3 = new TestElement(2);
+        multiSet.mustAdd(e1);
+        multiSet.mustAdd(e2);
+        multiSet.mustAdd(e3);
+        assertFalse(multiSet.add(e3));
+        assertEquals(3, multiSet.size());
+        expectExactTraversal(multiSet.findAll(e1).iterator(), e1, e2);
+        expectExactTraversal(multiSet.findAll(e3).iterator(), e3);
+        multiSet.remove(e2);
+        expectExactTraversal(multiSet.findAll(e1).iterator(), e1);
+        assertTrue(multiSet.contains(e2));
+    }
+
+    @Test
+    public void testTraversal() {
+        ImplicitLinkedHashMultiCollection<TestElement> multiSet = new ImplicitLinkedHashMultiCollection<>();
+        expectExactTraversal(multiSet.iterator());
+        TestElement e1 = new TestElement(1);
+        TestElement e2 = new TestElement(1);
+        TestElement e3 = new TestElement(2);
+        assertTrue(multiSet.add(e1));
+        assertTrue(multiSet.add(e2));
+        assertTrue(multiSet.add(e3));
+        expectExactTraversal(multiSet.iterator(), e1, e2, e3);
+        assertTrue(multiSet.remove(e2));
+        expectExactTraversal(multiSet.iterator(), e1, e3);
+        assertTrue(multiSet.remove(e1));
+        expectExactTraversal(multiSet.iterator(), e3);
+    }
+
+    static void expectExactTraversal(Iterator<TestElement> iterator, TestElement... sequence) {
+        int i = 0;
+        while (iterator.hasNext()) {
+            TestElement element = iterator.next();
+            assertTrue(i < sequence.length, "Iterator yielded " + (i + 1) + " elements, but only " +
+                sequence.length + " were expected.");
+            if (sequence[i] != element) {
+                fail("Iterator value number " + (i + 1) + " was incorrect.");
+            }
+            i = i + 1;
+        }
+        assertTrue(i == sequence.length, "Iterator yielded " + (i + 1) + " elements, but " +
+            sequence.length + " were expected.");
+    }
+
+    @Test
+    public void testEnlargement() {
+        ImplicitLinkedHashMultiCollection<TestElement> multiSet = new ImplicitLinkedHashMultiCollection<>(5);
+        assertEquals(11, multiSet.numSlots());
+        TestElement[] testElements = {
+            new TestElement(100),
+            new TestElement(101),
+            new TestElement(102),
+            new TestElement(100),
+            new TestElement(101),
+            new TestElement(105)
+        };
+        for (int i = 0; i < testElements.length; i++) {
+            assertTrue(multiSet.add(testElements[i]));
+        }
+        for (int i = 0; i < testElements.length; i++) {
+            assertFalse(multiSet.add(testElements[i]));
+        }
+        assertEquals(23, multiSet.numSlots());
+        assertEquals(testElements.length, multiSet.size());
+        expectExactTraversal(multiSet.iterator(), testElements);
+        multiSet.remove(testElements[1]);
+        assertEquals(23, multiSet.numSlots());
+        assertEquals(5, multiSet.size());
+        expectExactTraversal(multiSet.iterator(),
+            testElements[0], testElements[2], testElements[3], testElements[4], testElements[5]);
+    }
+
+    @Test
+    public void testManyInsertsAndDeletes() {
+        Random random = new Random(123);
+        LinkedList<TestElement> existing = new LinkedList<>();
+        ImplicitLinkedHashMultiCollection<TestElement> multiSet = new ImplicitLinkedHashMultiCollection<>();
+        for (int i = 0; i < 100; i++) {
+            for (int j = 0; j < 4; j++) {
+                TestElement testElement = new TestElement(random.nextInt());
+                multiSet.mustAdd(testElement);
+                existing.add(testElement);
+            }
+            int elementToRemove = random.nextInt(multiSet.size());
+            Iterator<TestElement> iter1 = multiSet.iterator();
+            Iterator<TestElement> iter2 = existing.iterator();
+            for (int j = 0; j <= elementToRemove; j++) {
+                iter1.next();
+                iter2.next();
+            }
+            iter1.remove();
+            iter2.remove();
+            expectTraversal(multiSet.iterator(), existing.iterator());
+        }
+    }
+
+    void expectTraversal(Iterator<TestElement> iter, Iterator<TestElement> expectedIter) {
+        int i = 0;
+        while (iter.hasNext()) {
+            TestElement element = iter.next();
+            assertTrue(expectedIter.hasNext(),
+                "Iterator yielded " + (i + 1) + " elements, but only " + i + " were expected.");
+            TestElement expected = expectedIter.next();
+            assertTrue(expected == element,
+                "Iterator value number " + (i + 1) + " was incorrect.");
+            i = i + 1;
+        }
+        assertFalse(expectedIter.hasNext(),
+            "Iterator yielded " + i + " elements, but at least " + (i + 1) + " were expected.");
+    }
+}
diff --git a/clients/src/test/java/org/apache/kafka/common/utils/JavaTest.java b/clients/src/test/java/org/apache/kafka/common/utils/JavaTest.java
new file mode 100644
index 0000000..8e1311c
--- /dev/null
+++ b/clients/src/test/java/org/apache/kafka/common/utils/JavaTest.java
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.utils; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +public class JavaTest { + + private String javaVendor; + + @BeforeEach + public void before() { + javaVendor = System.getProperty("java.vendor"); + } + + @AfterEach + public void after() { + System.setProperty("java.vendor", javaVendor); + } + + @Test + public void testIsIBMJdk() { + System.setProperty("java.vendor", "Oracle Corporation"); + assertFalse(Java.isIbmJdk()); + System.setProperty("java.vendor", "IBM Corporation"); + assertTrue(Java.isIbmJdk()); + } + + @Test + public void testLoadKerberosLoginModule() throws ClassNotFoundException { + String clazz = Java.isIbmJdk() + ? "com.ibm.security.auth.module.Krb5LoginModule" + : "com.sun.security.auth.module.Krb5LoginModule"; + Class.forName(clazz); + } + + @Test + public void testJavaVersion() { + Java.Version v = Java.parseVersion("9"); + assertEquals(9, v.majorVersion); + assertEquals(0, v.minorVersion); + assertTrue(v.isJava9Compatible()); + + v = Java.parseVersion("9.0.1"); + assertEquals(9, v.majorVersion); + assertEquals(0, v.minorVersion); + assertTrue(v.isJava9Compatible()); + + v = Java.parseVersion("9.0.0.15"); // Azul Zulu + assertEquals(9, v.majorVersion); + assertEquals(0, v.minorVersion); + assertTrue(v.isJava9Compatible()); + + v = Java.parseVersion("9.1"); + assertEquals(9, v.majorVersion); + assertEquals(1, v.minorVersion); + assertTrue(v.isJava9Compatible()); + + v = Java.parseVersion("1.8.0_152"); + assertEquals(1, v.majorVersion); + assertEquals(8, v.minorVersion); + assertFalse(v.isJava9Compatible()); + + v = Java.parseVersion("1.7.0_80"); + assertEquals(1, v.majorVersion); + assertEquals(7, v.minorVersion); + assertFalse(v.isJava9Compatible()); + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/utils/LoggingSignalHandlerTest.java b/clients/src/test/java/org/apache/kafka/common/utils/LoggingSignalHandlerTest.java new file mode 100644 index 0000000..cdf30c2 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/utils/LoggingSignalHandlerTest.java @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.common.utils;
+
+import org.junit.jupiter.api.Test;
+
+public class LoggingSignalHandlerTest {
+
+    @Test
+    public void testRegister() throws ReflectiveOperationException {
+        new LoggingSignalHandler().register();
+    }
+
+}
diff --git a/clients/src/test/java/org/apache/kafka/common/utils/MappedIteratorTest.java b/clients/src/test/java/org/apache/kafka/common/utils/MappedIteratorTest.java
new file mode 100644
index 0000000..058f2cd
--- /dev/null
+++ b/clients/src/test/java/org/apache/kafka/common/utils/MappedIteratorTest.java
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.common.utils;
+
+import org.junit.jupiter.api.Test;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.function.Function;
+import java.util.stream.Collectors;
+
+import static java.util.Arrays.asList;
+import static java.util.Collections.emptyList;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+
+public class MappedIteratorTest {
+
+    @Test
+    public void testStringToInteger() {
+        List<String> list = asList("foo", "", "bar2", "baz45");
+        Function<String, Integer> mapper = s -> s.length();
+
+        Iterable<Integer> mappedIterable = () -> new MappedIterator<>(list.iterator(), mapper);
+        List<Integer> mapped = new ArrayList<>();
+        mappedIterable.forEach(mapped::add);
+
+        assertEquals(list.stream().map(mapper).collect(Collectors.toList()), mapped);
+
+        // Ensure that we can iterate a second time
+        List<Integer> mapped2 = new ArrayList<>();
+        mappedIterable.forEach(mapped2::add);
+        assertEquals(mapped, mapped2);
+    }
+
+    @Test
+    public void testEmptyList() {
+        List<String> list = emptyList();
+        Function<String, Integer> mapper = s -> s.length();
+
+        Iterable<Integer> mappedIterable = () -> new MappedIterator<>(list.iterator(), mapper);
+        List<Integer> mapped = new ArrayList<>();
+        mappedIterable.forEach(mapped::add);
+
+        assertEquals(emptyList(), mapped);
+    }
+
+}
diff --git a/clients/src/test/java/org/apache/kafka/common/utils/MockScheduler.java b/clients/src/test/java/org/apache/kafka/common/utils/MockScheduler.java
new file mode 100644
index 0000000..78a9060
--- /dev/null
+++ b/clients/src/test/java/org/apache/kafka/common/utils/MockScheduler.java
@@ -0,0 +1,121 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.
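Editorial aside, not part of the patch above: MappedIterator, as used in MappedIteratorTest, applies a function lazily while iterating rather than materializing a mapped list up front. A minimal sketch assuming only the constructor shape shown in that test; the wrapper class and input strings are invented.

import org.apache.kafka.common.utils.MappedIterator;

import java.util.Arrays;
import java.util.Iterator;
import java.util.List;

public class MappedIteratorExample {
    public static void main(String[] args) {
        List<String> names = Arrays.asList("broker-0", "broker-1", "controller");

        // Each element is transformed only when next() is called
        Iterator<Integer> lengths = new MappedIterator<>(names.iterator(), String::length);
        lengths.forEachRemaining(System.out::println); // 8, 8, 10
    }
}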
You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.common.utils;
+
+import org.apache.kafka.common.KafkaFuture;
+import org.apache.kafka.common.internals.KafkaFutureImpl;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+import java.util.TreeMap;
+import java.util.concurrent.Callable;
+import java.util.concurrent.Future;
+import java.util.concurrent.ScheduledExecutorService;
+
+public class MockScheduler implements Scheduler, MockTime.Listener {
+    private static final Logger log = LoggerFactory.getLogger(MockScheduler.class);
+
+    /**
+     * The MockTime object.
+     */
+    private final MockTime time;
+
+    /**
+     * Futures which are waiting for a specified wall-clock time to arrive.
+     */
+    private final TreeMap<Long, List<KafkaFutureImpl<Long>>> waiters = new TreeMap<>();
+
+    public MockScheduler(MockTime time) {
+        this.time = time;
+        time.addListener(this);
+    }
+
+    @Override
+    public Time time() {
+        return time;
+    }
+
+    @Override
+    public synchronized void onTimeUpdated() {
+        long timeMs = time.milliseconds();
+        while (true) {
+            Map.Entry<Long, List<KafkaFutureImpl<Long>>> entry = waiters.firstEntry();
+            if ((entry == null) || (entry.getKey() > timeMs)) {
+                break;
+            }
+            for (KafkaFutureImpl<Long> future : entry.getValue()) {
+                future.complete(timeMs);
+            }
+            waiters.remove(entry.getKey());
+        }
+    }
+
+    public synchronized void addWaiter(long delayMs, KafkaFutureImpl<Long> waiter) {
+        long timeMs = time.milliseconds();
+        if (delayMs <= 0) {
+            waiter.complete(timeMs);
+        } else {
+            long triggerTimeMs = timeMs + delayMs;
+            List<KafkaFutureImpl<Long>> futures = waiters.get(triggerTimeMs);
+            if (futures == null) {
+                futures = new ArrayList<>();
+                waiters.put(triggerTimeMs, futures);
+            }
+            futures.add(waiter);
+        }
+    }
+
+    @Override
+    public <T> Future<T> schedule(final ScheduledExecutorService executor,
+                                  final Callable<T> callable, long delayMs) {
+        final KafkaFutureImpl<T> future = new KafkaFutureImpl<>();
+        KafkaFutureImpl<Long> waiter = new KafkaFutureImpl<>();
+        waiter.thenApply(new KafkaFuture.BaseFunction<Long, Void>() {
+            @Override
+            public Void apply(final Long now) {
+                executor.submit(new Callable<Void>() {
+                    @Override
+                    public Void call() {
+                        // Note: it is possible that we'll execute Callable#call right after
+                        // the future is cancelled. This is a valid sequence of events
+                        // that the author of the Callable needs to be able to handle.
+                        //
+                        // Note 2: If the future is cancelled, we will not remove the waiter
+                        // from this MockTime object. This small bit of inefficiency is acceptable
+                        // in testing code (at least we aren't polling!)
+ if (!future.isCancelled()) { + try { + log.trace("Invoking {} at {}", callable, now); + future.complete(callable.call()); + } catch (Throwable throwable) { + future.completeExceptionally(throwable); + } + } + return null; + } + }); + return null; + } + }); + log.trace("Scheduling {} for {} ms from now.", callable, delayMs); + addWaiter(delayMs, waiter); + return future; + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/utils/MockTime.java b/clients/src/test/java/org/apache/kafka/common/utils/MockTime.java new file mode 100644 index 0000000..ccf3eec --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/utils/MockTime.java @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.utils; + +import org.apache.kafka.common.errors.TimeoutException; + +import java.util.concurrent.CopyOnWriteArrayList; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicLong; +import java.util.function.Supplier; + +/** + * A clock that you can manually advance by calling sleep + */ +public class MockTime implements Time { + + public interface Listener { + void onTimeUpdated(); + } + + /** + * Listeners which are waiting for time changes. 
+ */
+    private final CopyOnWriteArrayList<Listener> listeners = new CopyOnWriteArrayList<>();
+
+    private final long autoTickMs;
+
+    // Values from `nanoTime` and `currentTimeMillis` are not comparable, so we store them separately to allow tests
+    // using this class to detect bugs where this is incorrectly assumed to be true
+    private final AtomicLong timeMs;
+    private final AtomicLong highResTimeNs;
+
+    public MockTime() {
+        this(0);
+    }
+
+    public MockTime(long autoTickMs) {
+        this(autoTickMs, System.currentTimeMillis(), System.nanoTime());
+    }
+
+    public MockTime(long autoTickMs, long currentTimeMs, long currentHighResTimeNs) {
+        this.timeMs = new AtomicLong(currentTimeMs);
+        this.highResTimeNs = new AtomicLong(currentHighResTimeNs);
+        this.autoTickMs = autoTickMs;
+    }
+
+    public void addListener(Listener listener) {
+        listeners.add(listener);
+    }
+
+    @Override
+    public long milliseconds() {
+        maybeSleep(autoTickMs);
+        return timeMs.get();
+    }
+
+    @Override
+    public long nanoseconds() {
+        maybeSleep(autoTickMs);
+        return highResTimeNs.get();
+    }
+
+    private void maybeSleep(long ms) {
+        if (ms != 0)
+            sleep(ms);
+    }
+
+    @Override
+    public void sleep(long ms) {
+        timeMs.addAndGet(ms);
+        highResTimeNs.addAndGet(TimeUnit.MILLISECONDS.toNanos(ms));
+        tick();
+    }
+
+    @Override
+    public void waitObject(Object obj, Supplier<Boolean> condition, long deadlineMs) throws InterruptedException {
+        Listener listener = () -> {
+            synchronized (obj) {
+                obj.notify();
+            }
+        };
+        listeners.add(listener);
+        try {
+            synchronized (obj) {
+                while (milliseconds() < deadlineMs && !condition.get()) {
+                    obj.wait();
+                }
+                if (!condition.get())
+                    throw new TimeoutException("Condition not satisfied before deadline");
+            }
+        } finally {
+            listeners.remove(listener);
+        }
+    }
+
+    public void setCurrentTimeMs(long newMs) {
+        long oldMs = timeMs.getAndSet(newMs);
+
+        // does not allow to set to an older timestamp
+        if (oldMs > newMs)
+            throw new IllegalArgumentException("Setting the time to " + newMs + " while current time " + oldMs + " is newer; this is not allowed");
+
+        highResTimeNs.set(TimeUnit.MILLISECONDS.toNanos(newMs));
+        tick();
+    }
+
+    private void tick() {
+        for (Listener listener : listeners) {
+            listener.onTimeUpdated();
+        }
+    }
+}
diff --git a/clients/src/test/java/org/apache/kafka/common/utils/MockTimeTest.java b/clients/src/test/java/org/apache/kafka/common/utils/MockTimeTest.java
new file mode 100644
index 0000000..88a18fe
--- /dev/null
+++ b/clients/src/test/java/org/apache/kafka/common/utils/MockTimeTest.java
@@ -0,0 +1,50 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
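Editorial aside, not part of the patch above: MockScheduler and MockTime are written so that advancing the mock clock is what releases scheduled work, which keeps tests deterministic. A minimal sketch of that interplay, assuming only the constructors and schedule() signature visible in the two files above; the wrapper class name and executor handling are illustrative.

import org.apache.kafka.common.utils.MockScheduler;
import org.apache.kafka.common.utils.MockTime;

import java.util.concurrent.ExecutionException;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.ScheduledExecutorService;

public class MockSchedulerExample {
    public static void main(String[] args) throws InterruptedException, ExecutionException {
        MockTime time = new MockTime(0, 0, 0);
        MockScheduler scheduler = new MockScheduler(time);
        ScheduledExecutorService executor = Executors.newSingleThreadScheduledExecutor();

        // Nothing runs until the mock clock reaches the trigger time
        Future<String> result = scheduler.schedule(executor, () -> "done", 100);

        time.sleep(100);                   // fires the waiter, which submits the callable
        System.out.println(result.get());  // prints "done"
        executor.shutdown();
    }
}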
+ */ +package org.apache.kafka.common.utils; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +@Timeout(120) +public class MockTimeTest extends TimeTest { + + @Test + public void testAdvanceClock() { + MockTime time = new MockTime(0, 100, 200); + assertEquals(100, time.milliseconds()); + assertEquals(200, time.nanoseconds()); + time.sleep(1); + assertEquals(101, time.milliseconds()); + assertEquals(1000200, time.nanoseconds()); + } + + @Test + public void testAutoTickMs() { + MockTime time = new MockTime(1, 100, 200); + assertEquals(101, time.milliseconds()); + assertEquals(2000200, time.nanoseconds()); + assertEquals(103, time.milliseconds()); + assertEquals(104, time.milliseconds()); + } + + @Override + protected Time createTime() { + return new MockTime(); + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/utils/SanitizerTest.java b/clients/src/test/java/org/apache/kafka/common/utils/SanitizerTest.java new file mode 100644 index 0000000..3024bd3 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/utils/SanitizerTest.java @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.utils; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; + +import java.lang.management.ManagementFactory; + +import javax.management.MBeanException; +import javax.management.MBeanServer; +import javax.management.MalformedObjectNameException; +import javax.management.ObjectName; +import javax.management.OperationsException; + +import org.junit.jupiter.api.Test; + +public class SanitizerTest { + + @Test + public void testSanitize() { + String principal = "CN=Some characters !@#$%&*()_-+=';:,/~"; + String sanitizedPrincipal = Sanitizer.sanitize(principal); + assertTrue(sanitizedPrincipal.replace('%', '_').matches("[a-zA-Z0-9\\._\\-]+")); + assertEquals(principal, Sanitizer.desanitize(sanitizedPrincipal)); + } + + @Test + public void testJmxSanitize() throws MalformedObjectNameException { + int unquoted = 0; + for (int i = 0; i < 65536; i++) { + char c = (char) i; + String value = "value" + c; + String jmxSanitizedValue = Sanitizer.jmxSanitize(value); + if (jmxSanitizedValue.equals(value)) + unquoted++; + verifyJmx(jmxSanitizedValue, i); + String encodedValue = Sanitizer.sanitize(value); + verifyJmx(encodedValue, i); + // jmxSanitize should not sanitize URL-encoded values + assertEquals(encodedValue, Sanitizer.jmxSanitize(encodedValue)); + } + assertEquals(68, unquoted); // a-zA-Z0-9-_% space and tab + } + + private void verifyJmx(String sanitizedValue, int c) throws MalformedObjectNameException { + Object mbean = new TestStat(); + MBeanServer server = ManagementFactory.getPlatformMBeanServer(); + ObjectName objectName = new ObjectName("test:key=" + sanitizedValue); + try { + server.registerMBean(mbean, objectName); + server.unregisterMBean(objectName); + } catch (OperationsException | MBeanException e) { + fail("Could not register char=\\u" + c); + } + } + + public interface TestStatMBean { + int getValue(); + } + + public class TestStat implements TestStatMBean { + public int getValue() { + return 1; + } + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/utils/SecurityUtilsTest.java b/clients/src/test/java/org/apache/kafka/common/utils/SecurityUtilsTest.java new file mode 100644 index 0000000..e651092 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/utils/SecurityUtilsTest.java @@ -0,0 +1,106 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.common.utils; + +import org.apache.kafka.common.config.SecurityConfig; +import org.apache.kafka.common.security.auth.SecurityProviderCreator; +import org.apache.kafka.common.security.auth.KafkaPrincipal; +import org.apache.kafka.common.security.ssl.mock.TestPlainSaslServerProviderCreator; +import org.apache.kafka.common.security.ssl.mock.TestScramSaslServerProviderCreator; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.security.Provider; +import java.security.Security; +import java.util.HashMap; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class SecurityUtilsTest { + + private SecurityProviderCreator testScramSaslServerProviderCreator = new TestScramSaslServerProviderCreator(); + private SecurityProviderCreator testPlainSaslServerProviderCreator = new TestPlainSaslServerProviderCreator(); + + private Provider testScramSaslServerProvider = testScramSaslServerProviderCreator.getProvider(); + private Provider testPlainSaslServerProvider = testPlainSaslServerProviderCreator.getProvider(); + + private void clearTestProviders() { + Security.removeProvider(testScramSaslServerProvider.getName()); + Security.removeProvider(testPlainSaslServerProvider.getName()); + } + + @BeforeEach + // Remove the providers if already added + public void setUp() { + clearTestProviders(); + } + + // Remove the providers after running test cases + @AfterEach + public void tearDown() { + clearTestProviders(); + } + + + @Test + public void testPrincipalNameCanContainSeparator() { + String name = "name:with:separator:in:it"; + KafkaPrincipal principal = SecurityUtils.parseKafkaPrincipal(KafkaPrincipal.USER_TYPE + ":" + name); + assertEquals(KafkaPrincipal.USER_TYPE, principal.getPrincipalType()); + assertEquals(name, principal.getName()); + } + + @Test + public void testParseKafkaPrincipalWithNonUserPrincipalType() { + String name = "foo"; + String principalType = "Group"; + KafkaPrincipal principal = SecurityUtils.parseKafkaPrincipal(principalType + ":" + name); + assertEquals(principalType, principal.getPrincipalType()); + assertEquals(name, principal.getName()); + } + + private int getProviderIndexFromName(String providerName, Provider[] providers) { + for (int index = 0; index < providers.length; index++) { + if (providers[index].getName().equals(providerName)) { + return index; + } + } + return -1; + } + + // Tests if the custom providers configured are being added to the JVM correctly. 
These providers are + // expected to be added at the start of the list of available providers and with the relative ordering maintained + @Test + public void testAddCustomSecurityProvider() { + String customProviderClasses = testScramSaslServerProviderCreator.getClass().getName() + "," + + testPlainSaslServerProviderCreator.getClass().getName(); + Map configs = new HashMap<>(); + configs.put(SecurityConfig.SECURITY_PROVIDERS_CONFIG, customProviderClasses); + SecurityUtils.addConfiguredSecurityProviders(configs); + + Provider[] providers = Security.getProviders(); + int testScramSaslServerProviderIndex = getProviderIndexFromName(testScramSaslServerProvider.getName(), providers); + int testPlainSaslServerProviderIndex = getProviderIndexFromName(testPlainSaslServerProvider.getName(), providers); + + assertEquals(0, testScramSaslServerProviderIndex, + testScramSaslServerProvider.getName() + " testProvider not found at expected index"); + assertEquals(1, testPlainSaslServerProviderIndex, + testPlainSaslServerProvider.getName() + " testProvider not found at expected index"); + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/utils/Serializer.java b/clients/src/test/java/org/apache/kafka/common/utils/Serializer.java new file mode 100644 index 0000000..a902449 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/utils/Serializer.java @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
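testAddCustomSecurityProvider above captures the contract of SecurityUtils.addConfiguredSecurityProviders: the security.providers config takes a comma-separated list of SecurityProviderCreator class names, and the created providers are inserted at the head of the JVM provider list in the order given. A sketch under those assumptions; com.example.MyProviderCreator is a hypothetical implementation, and the exact Map parameter type is assumed to accept String values as in the test above:

import java.security.Security;
import java.util.HashMap;
import java.util.Map;
import org.apache.kafka.common.config.SecurityConfig;
import org.apache.kafka.common.utils.SecurityUtils;

public class SecurityProvidersSketch {
    public static void main(String[] args) {
        Map<String, String> configs = new HashMap<>();
        configs.put(SecurityConfig.SECURITY_PROVIDERS_CONFIG, "com.example.MyProviderCreator");
        SecurityUtils.addConfiguredSecurityProviders(configs);
        // The creator's provider should now be first in the provider list.
        System.out.println(Security.getProviders()[0].getName());
    }
}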
+ */ +package org.apache.kafka.common.utils; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.ObjectOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.ObjectInputStream; + +public class Serializer { + + public static byte[] serialize(Object toSerialize) throws IOException { + ByteArrayOutputStream arrayOutputStream = new ByteArrayOutputStream(); + try (ObjectOutputStream ooStream = new ObjectOutputStream(arrayOutputStream)) { + ooStream.writeObject(toSerialize); + return arrayOutputStream.toByteArray(); + } + } + + public static Object deserialize(InputStream inputStream) throws IOException, ClassNotFoundException { + try (ObjectInputStream objectInputStream = new ObjectInputStream(inputStream)) { + return objectInputStream.readObject(); + } + } + + public static Object deserialize(byte[] byteArray) throws IOException, ClassNotFoundException { + ByteArrayInputStream arrayInputStream = new ByteArrayInputStream(byteArray); + return deserialize(arrayInputStream); + } + + public static Object deserialize(String fileName) throws IOException, ClassNotFoundException { + ClassLoader classLoader = Serializer.class.getClassLoader(); + InputStream fileStream = classLoader.getResourceAsStream(fileName); + return deserialize(fileStream); + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/utils/ShellTest.java b/clients/src/test/java/org/apache/kafka/common/utils/ShellTest.java new file mode 100644 index 0000000..75020d5 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/utils/ShellTest.java @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
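The Serializer helper above is a thin wrapper around plain Java serialization; callers get back an Object and cast it themselves. A round-trip sketch, not part of the patch:

import java.io.IOException;
import org.apache.kafka.common.utils.Serializer;

public class SerializerSketch {
    public static void main(String[] args) throws IOException, ClassNotFoundException {
        byte[] bytes = Serializer.serialize("payload");            // any Serializable value
        String restored = (String) Serializer.deserialize(bytes);  // caller casts the returned Object
        System.out.println(restored);                              // payload
    }
}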
+ */ +package org.apache.kafka.common.utils; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; +import org.junit.jupiter.api.condition.DisabledOnOs; +import org.junit.jupiter.api.condition.OS; + +import java.io.IOException; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +@Timeout(180) +@DisabledOnOs(OS.WINDOWS) +public class ShellTest { + + @Test + public void testEchoHello() throws Exception { + String output = Shell.execCommand("echo", "hello"); + assertEquals("hello\n", output); + } + + @Test + public void testHeadDevZero() throws Exception { + final int length = 100000; + String output = Shell.execCommand("head", "-c", Integer.toString(length), "/dev/zero"); + assertEquals(length, output.length()); + } + + private final static String NONEXISTENT_PATH = "/dev/a/path/that/does/not/exist/in/the/filesystem"; + + @Test + public void testAttemptToRunNonExistentProgram() { + IOException e = assertThrows(IOException.class, () -> Shell.execCommand(NONEXISTENT_PATH), + "Expected to get an exception when trying to run a program that does not exist"); + assertTrue(e.getMessage().contains("No such file"), "Unexpected error message '" + e.getMessage() + "'"); + } + + @Test + public void testRunProgramWithErrorReturn() { + Shell.ExitCodeException e = assertThrows(Shell.ExitCodeException.class, + () -> Shell.execCommand("head", "-c", "0", NONEXISTENT_PATH)); + String message = e.getMessage(); + assertTrue(message.contains("No such file") || message.contains("illegal byte count"), + "Unexpected error message '" + message + "'"); + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/utils/SystemTimeTest.java b/clients/src/test/java/org/apache/kafka/common/utils/SystemTimeTest.java new file mode 100644 index 0000000..edc53d2 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/utils/SystemTimeTest.java @@ -0,0 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.utils; + +public class SystemTimeTest extends TimeTest { + + @Override + protected Time createTime() { + return Time.SYSTEM; + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/utils/ThreadUtilsTest.java b/clients/src/test/java/org/apache/kafka/common/utils/ThreadUtilsTest.java new file mode 100644 index 0000000..0a299dd --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/utils/ThreadUtilsTest.java @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
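ShellTest above relies on Shell.execCommand returning the command's stdout (trailing newline included) and throwing Shell.ExitCodeException, an IOException subtype, on a non-zero exit. A usage sketch under those assumptions:

import java.io.IOException;
import org.apache.kafka.common.utils.Shell;

public class ShellSketch {
    public static void main(String[] args) {
        try {
            System.out.print(Shell.execCommand("echo", "hello"));  // prints "hello\n"
        } catch (Shell.ExitCodeException e) {
            System.err.println("command exited non-zero: " + e.getMessage());
        } catch (IOException e) {
            System.err.println("command could not be run: " + e.getMessage());
        }
    }
}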
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.utils; + +import org.junit.jupiter.api.Test; + +import java.util.concurrent.ThreadFactory; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class ThreadUtilsTest { + + private static final Runnable EMPTY_RUNNABLE = () -> { + }; + private static final String THREAD_NAME = "ThreadName"; + private static final String THREAD_NAME_WITH_NUMBER = THREAD_NAME + "%d"; + + + @Test + public void testThreadNameWithoutNumberNoDemon() { + assertEquals(THREAD_NAME, ThreadUtils.createThreadFactory(THREAD_NAME, false). + newThread(EMPTY_RUNNABLE).getName()); + } + + @Test + public void testThreadNameWithoutNumberDemon() { + Thread daemonThread = ThreadUtils.createThreadFactory(THREAD_NAME, true).newThread(EMPTY_RUNNABLE); + try { + assertEquals(THREAD_NAME, daemonThread.getName()); + assertTrue(daemonThread.isDaemon()); + } finally { + try { + daemonThread.join(); + } catch (InterruptedException e) { + // can be ignored + } + } + } + + @Test + public void testThreadNameWithNumberNoDemon() { + ThreadFactory localThreadFactory = ThreadUtils.createThreadFactory(THREAD_NAME_WITH_NUMBER, false); + assertEquals(THREAD_NAME + "1", localThreadFactory.newThread(EMPTY_RUNNABLE).getName()); + assertEquals(THREAD_NAME + "2", localThreadFactory.newThread(EMPTY_RUNNABLE).getName()); + } + + @Test + public void testThreadNameWithNumberDemon() { + ThreadFactory localThreadFactory = ThreadUtils.createThreadFactory(THREAD_NAME_WITH_NUMBER, true); + Thread daemonThread1 = localThreadFactory.newThread(EMPTY_RUNNABLE); + Thread daemonThread2 = localThreadFactory.newThread(EMPTY_RUNNABLE); + + try { + assertEquals(THREAD_NAME + "1", daemonThread1.getName()); + assertTrue(daemonThread1.isDaemon()); + } finally { + try { + daemonThread1.join(); + } catch (InterruptedException e) { + // can be ignored + } + } + try { + assertEquals(THREAD_NAME + "2", daemonThread2.getName()); + assertTrue(daemonThread2.isDaemon()); + } finally { + try { + daemonThread2.join(); + } catch (InterruptedException e) { + // can be ignored + } + } + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/utils/TimeTest.java b/clients/src/test/java/org/apache/kafka/common/utils/TimeTest.java new file mode 100644 index 0000000..808f63c --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/utils/TimeTest.java @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
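ThreadUtilsTest above documents ThreadUtils.createThreadFactory: a "%d" in the name pattern, if present, is replaced by a counter starting at 1, and the boolean controls daemon status. A short sketch under those assumptions:

import java.util.concurrent.ThreadFactory;
import org.apache.kafka.common.utils.ThreadUtils;

public class ThreadFactorySketch {
    public static void main(String[] args) {
        ThreadFactory factory = ThreadUtils.createThreadFactory("my-worker-%d", true);
        Thread t = factory.newThread(() -> System.out.println("hello from " + Thread.currentThread().getName()));
        System.out.println(t.getName() + " daemon=" + t.isDaemon()); // my-worker-1 daemon=true
        t.start();
    }
}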
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.utils; + +import org.apache.kafka.common.errors.TimeoutException; +import org.junit.jupiter.api.Test; + +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public abstract class TimeTest { + + protected abstract Time createTime(); + + @Test + public void testWaitObjectTimeout() throws InterruptedException { + Object obj = new Object(); + Time time = createTime(); + long timeoutMs = 100; + long deadlineMs = time.milliseconds() + timeoutMs; + AtomicReference caughtException = new AtomicReference<>(); + Thread t = new Thread(() -> { + try { + time.waitObject(obj, () -> false, deadlineMs); + } catch (Exception e) { + caughtException.set(e); + } + }); + + t.start(); + time.sleep(timeoutMs); + t.join(); + + assertEquals(TimeoutException.class, caughtException.get().getClass()); + } + + @Test + public void testWaitObjectConditionSatisfied() throws InterruptedException { + Object obj = new Object(); + Time time = createTime(); + long timeoutMs = 1000000000; + long deadlineMs = time.milliseconds() + timeoutMs; + AtomicBoolean condition = new AtomicBoolean(false); + AtomicReference caughtException = new AtomicReference<>(); + Thread t = new Thread(() -> { + try { + time.waitObject(obj, condition::get, deadlineMs); + } catch (Exception e) { + caughtException.set(e); + } + }); + + t.start(); + + synchronized (obj) { + condition.set(true); + obj.notify(); + } + + t.join(); + + assertTrue(time.milliseconds() < deadlineMs); + assertNull(caughtException.get()); + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/utils/TimerTest.java b/clients/src/test/java/org/apache/kafka/common/utils/TimerTest.java new file mode 100644 index 0000000..3fb7980 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/utils/TimerTest.java @@ -0,0 +1,151 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
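TimeTest above exercises Time.waitObject(obj, condition, deadlineMs): the caller blocks on obj until the condition supplier returns true or the deadline passes, at which point a Kafka TimeoutException is thrown; the flag must be flipped and obj notified while holding the same monitor. A happy-path sketch under those assumptions:

import java.util.concurrent.atomic.AtomicBoolean;
import org.apache.kafka.common.utils.Time;

public class WaitObjectSketch {
    public static void main(String[] args) throws InterruptedException {
        Time time = Time.SYSTEM;
        Object lock = new Object();
        AtomicBoolean ready = new AtomicBoolean(false);

        Thread waiter = new Thread(() -> {
            try {
                time.waitObject(lock, ready::get, time.milliseconds() + 5_000);
                System.out.println("condition satisfied before the deadline");
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
            }
        });
        waiter.start();

        synchronized (lock) {   // flip the flag and notify under the monitor the waiter blocks on
            ready.set(true);
            lock.notify();
        }
        waiter.join();
    }
}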
+ */ + +package org.apache.kafka.common.utils; + +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class TimerTest { + + private final MockTime time = new MockTime(); + + @Test + public void testTimerUpdate() { + Timer timer = time.timer(500); + assertEquals(500, timer.timeoutMs()); + assertEquals(500, timer.remainingMs()); + assertEquals(0, timer.elapsedMs()); + + time.sleep(100); + timer.update(); + + assertEquals(500, timer.timeoutMs()); + assertEquals(400, timer.remainingMs()); + assertEquals(100, timer.elapsedMs()); + + time.sleep(400); + timer.update(time.milliseconds()); + + assertEquals(500, timer.timeoutMs()); + assertEquals(0, timer.remainingMs()); + assertEquals(500, timer.elapsedMs()); + assertTrue(timer.isExpired()); + + // Going over the expiration is fine and the elapsed time can exceed + // the initial timeout. However, remaining time should be stuck at 0. + time.sleep(200); + timer.update(time.milliseconds()); + assertTrue(timer.isExpired()); + assertEquals(500, timer.timeoutMs()); + assertEquals(0, timer.remainingMs()); + assertEquals(700, timer.elapsedMs()); + } + + @Test + public void testTimerUpdateAndReset() { + Timer timer = time.timer(500); + timer.sleep(200); + assertEquals(500, timer.timeoutMs()); + assertEquals(300, timer.remainingMs()); + assertEquals(200, timer.elapsedMs()); + + timer.updateAndReset(400); + assertEquals(400, timer.timeoutMs()); + assertEquals(400, timer.remainingMs()); + assertEquals(0, timer.elapsedMs()); + + timer.sleep(400); + assertTrue(timer.isExpired()); + + timer.updateAndReset(200); + assertEquals(200, timer.timeoutMs()); + assertEquals(200, timer.remainingMs()); + assertEquals(0, timer.elapsedMs()); + assertFalse(timer.isExpired()); + } + + @Test + public void testTimerResetUsesCurrentTime() { + Timer timer = time.timer(500); + timer.sleep(200); + assertEquals(300, timer.remainingMs()); + assertEquals(200, timer.elapsedMs()); + + time.sleep(300); + timer.reset(500); + assertEquals(500, timer.remainingMs()); + + timer.update(); + assertEquals(200, timer.remainingMs()); + } + + @Test + public void testTimerResetDeadlineUsesCurrentTime() { + Timer timer = time.timer(500); + timer.sleep(200); + assertEquals(300, timer.remainingMs()); + assertEquals(200, timer.elapsedMs()); + + timer.sleep(100); + timer.resetDeadline(time.milliseconds() + 200); + assertEquals(200, timer.timeoutMs()); + assertEquals(200, timer.remainingMs()); + + timer.sleep(100); + assertEquals(200, timer.timeoutMs()); + assertEquals(100, timer.remainingMs()); + } + + @Test + public void testTimeoutOverflow() { + Timer timer = time.timer(Long.MAX_VALUE); + assertEquals(Long.MAX_VALUE - timer.currentTimeMs(), timer.remainingMs()); + assertEquals(0, timer.elapsedMs()); + } + + @Test + public void testNonMonotonicUpdate() { + Timer timer = time.timer(100); + long currentTimeMs = timer.currentTimeMs(); + + timer.update(currentTimeMs - 1); + assertEquals(currentTimeMs, timer.currentTimeMs()); + + assertEquals(100, timer.remainingMs()); + assertEquals(0, timer.elapsedMs()); + } + + @Test + public void testTimerSleep() { + Timer timer = time.timer(500); + long currentTimeMs = timer.currentTimeMs(); + + timer.sleep(200); + assertEquals(time.milliseconds(), timer.currentTimeMs()); + assertEquals(currentTimeMs + 200, timer.currentTimeMs()); + + timer.sleep(1000); + assertEquals(time.milliseconds(), 
timer.currentTimeMs()); + assertEquals(currentTimeMs + 500, timer.currentTimeMs()); + assertTrue(timer.isExpired()); + } + +} diff --git a/clients/src/test/java/org/apache/kafka/common/utils/UtilsTest.java b/clients/src/test/java/org/apache/kafka/common/utils/UtilsTest.java new file mode 100755 index 0000000..25218e6 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/utils/UtilsTest.java @@ -0,0 +1,905 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.utils; + +import org.apache.kafka.common.config.ConfigException; +import org.apache.kafka.test.TestUtils; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; +import org.junit.jupiter.api.function.Executable; +import org.mockito.stubbing.OngoingStubbing; + +import java.io.Closeable; +import java.io.DataOutputStream; +import java.io.EOFException; +import java.io.File; +import java.io.IOException; +import java.nio.BufferUnderflowException; +import java.nio.ByteBuffer; +import java.nio.channels.FileChannel; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.nio.file.StandardOpenOption; +import java.time.LocalDateTime; +import java.time.format.DateTimeFormatter; +import java.time.format.DateTimeFormatterBuilder; +import java.time.temporal.ChronoField; +import java.time.temporal.ChronoUnit; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; +import java.util.Map; +import java.util.Properties; +import java.util.Random; +import java.util.Set; +import java.util.TreeSet; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; +import java.util.stream.Collectors; +import java.util.stream.IntStream; +import java.util.stream.Stream; + +import static java.util.Arrays.asList; +import static java.util.Collections.emptySet; +import static org.apache.kafka.common.utils.Utils.diff; +import static org.apache.kafka.common.utils.Utils.formatAddress; +import static org.apache.kafka.common.utils.Utils.formatBytes; +import static org.apache.kafka.common.utils.Utils.getHost; +import static org.apache.kafka.common.utils.Utils.getPort; +import static org.apache.kafka.common.utils.Utils.intersection; +import static org.apache.kafka.common.utils.Utils.mkSet; +import static org.apache.kafka.common.utils.Utils.murmur2; +import static org.apache.kafka.common.utils.Utils.union; +import static org.apache.kafka.common.utils.Utils.validHostPattern; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static 
org.junit.jupiter.api.Assertions.assertNull;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+import static org.junit.jupiter.api.Assertions.fail;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.anyLong;
+import static org.mockito.Mockito.atLeastOnce;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+import java.text.ParseException;
+import java.text.SimpleDateFormat;
+import java.util.Date;
+
+public class UtilsTest {
+
+    @Test
+    public void testMurmur2() {
+        Map<byte[], Integer> cases = new java.util.HashMap<>();
+        cases.put("21".getBytes(), -973932308);
+        cases.put("foobar".getBytes(), -790332482);
+        cases.put("a-little-bit-long-string".getBytes(), -985981536);
+        cases.put("a-little-bit-longer-string".getBytes(), -1486304829);
+        cases.put("lkjh234lh9fiuh90y23oiuhsafujhadof229phr9h19h89h8".getBytes(), -58897971);
+        cases.put(new byte[] {'a', 'b', 'c'}, 479470107);
+
+        for (Map.Entry<byte[], Integer> c : cases.entrySet()) {
+            assertEquals(c.getValue().intValue(), murmur2(c.getKey()));
+        }
+    }
+
+    @Test
+    public void testGetHost() {
+        assertEquals("127.0.0.1", getHost("127.0.0.1:8000"));
+        assertEquals("mydomain.com", getHost("PLAINTEXT://mydomain.com:8080"));
+        assertEquals("MyDomain.com", getHost("PLAINTEXT://MyDomain.com:8080"));
+        assertEquals("My_Domain.com", getHost("PLAINTEXT://My_Domain.com:8080"));
+        assertEquals("::1", getHost("[::1]:1234"));
+        assertEquals("2001:db8:85a3:8d3:1319:8a2e:370:7348", getHost("PLAINTEXT://[2001:db8:85a3:8d3:1319:8a2e:370:7348]:5678"));
+        assertEquals("2001:DB8:85A3:8D3:1319:8A2E:370:7348", getHost("PLAINTEXT://[2001:DB8:85A3:8D3:1319:8A2E:370:7348]:5678"));
+        assertEquals("fe80::b1da:69ca:57f7:63d8%3", getHost("PLAINTEXT://[fe80::b1da:69ca:57f7:63d8%3]:5678"));
+    }
+
+    @Test
+    public void testHostPattern() {
+        assertTrue(validHostPattern("127.0.0.1"));
+        assertTrue(validHostPattern("mydomain.com"));
+        assertTrue(validHostPattern("MyDomain.com"));
+        assertTrue(validHostPattern("My_Domain.com"));
+        assertTrue(validHostPattern("::1"));
+        assertTrue(validHostPattern("2001:db8:85a3:8d3:1319:8a2e:370"));
+    }
+
+    @Test
+    public void testGetPort() {
+        assertEquals(8000, getPort("127.0.0.1:8000").intValue());
+        assertEquals(8080, getPort("mydomain.com:8080").intValue());
+        assertEquals(8080, getPort("MyDomain.com:8080").intValue());
+        assertEquals(1234, getPort("[::1]:1234").intValue());
+        assertEquals(5678, getPort("[2001:db8:85a3:8d3:1319:8a2e:370:7348]:5678").intValue());
+        assertEquals(5678, getPort("[2001:DB8:85A3:8D3:1319:8A2E:370:7348]:5678").intValue());
+        assertEquals(5678, getPort("[fe80::b1da:69ca:57f7:63d8%3]:5678").intValue());
+    }
+
+    @Test
+    public void testFormatAddress() {
+        assertEquals("127.0.0.1:8000", formatAddress("127.0.0.1", 8000));
+        assertEquals("mydomain.com:8080", formatAddress("mydomain.com", 8080));
+        assertEquals("[::1]:1234", formatAddress("::1", 1234));
+        assertEquals("[2001:db8:85a3:8d3:1319:8a2e:370:7348]:5678", formatAddress("2001:db8:85a3:8d3:1319:8a2e:370:7348", 5678));
+    }
+
+    @Test
+    public void testFormatBytes() {
+        assertEquals("-1", formatBytes(-1));
+        assertEquals("1023 B", formatBytes(1023));
+        assertEquals("1 KB", formatBytes(1024));
+        assertEquals("1024 KB", formatBytes((1024 * 1024) - 1));
+        assertEquals("1 MB", formatBytes(1024 * 1024));
+        assertEquals("1.1 MB", formatBytes((long) (1.1 * 1024 * 1024)));
+        assertEquals("10 MB", formatBytes(10 * 1024 *
1024)); + } + + @Test + public void testJoin() { + assertEquals("", Utils.join(Collections.emptyList(), ",")); + assertEquals("1", Utils.join(asList("1"), ",")); + assertEquals("1,2,3", Utils.join(asList(1, 2, 3), ",")); + } + + @Test + public void testAbs() { + assertEquals(0, Utils.abs(Integer.MIN_VALUE)); + assertEquals(10, Utils.abs(-10)); + assertEquals(10, Utils.abs(10)); + assertEquals(0, Utils.abs(0)); + assertEquals(1, Utils.abs(-1)); + } + + @Test + public void writeToBuffer() throws IOException { + byte[] input = {0, 1, 2, 3, 4, 5}; + ByteBuffer source = ByteBuffer.wrap(input); + + doTestWriteToByteBuffer(source, ByteBuffer.allocate(input.length)); + doTestWriteToByteBuffer(source, ByteBuffer.allocateDirect(input.length)); + assertEquals(0, source.position()); + + source.position(2); + doTestWriteToByteBuffer(source, ByteBuffer.allocate(input.length)); + doTestWriteToByteBuffer(source, ByteBuffer.allocateDirect(input.length)); + } + + private void doTestWriteToByteBuffer(ByteBuffer source, ByteBuffer dest) throws IOException { + int numBytes = source.remaining(); + int position = source.position(); + DataOutputStream out = new DataOutputStream(new ByteBufferOutputStream(dest)); + Utils.writeTo(out, source, source.remaining()); + dest.flip(); + assertEquals(numBytes, dest.remaining()); + assertEquals(position, source.position()); + assertEquals(source, dest); + } + + @Test + public void toArray() { + byte[] input = {0, 1, 2, 3, 4}; + ByteBuffer buffer = ByteBuffer.wrap(input); + assertArrayEquals(input, Utils.toArray(buffer)); + assertEquals(0, buffer.position()); + + assertArrayEquals(new byte[] {1, 2}, Utils.toArray(buffer, 1, 2)); + assertEquals(0, buffer.position()); + + buffer.position(2); + assertArrayEquals(new byte[] {2, 3, 4}, Utils.toArray(buffer)); + assertEquals(2, buffer.position()); + } + + @Test + public void toArrayDirectByteBuffer() { + byte[] input = {0, 1, 2, 3, 4}; + ByteBuffer buffer = ByteBuffer.allocateDirect(5); + buffer.put(input); + buffer.rewind(); + + assertArrayEquals(input, Utils.toArray(buffer)); + assertEquals(0, buffer.position()); + + assertArrayEquals(new byte[] {1, 2}, Utils.toArray(buffer, 1, 2)); + assertEquals(0, buffer.position()); + + buffer.position(2); + assertArrayEquals(new byte[] {2, 3, 4}, Utils.toArray(buffer)); + assertEquals(2, buffer.position()); + } + + @Test + public void getNullableSizePrefixedArrayExact() { + byte[] input = {0, 0, 0, 2, 1, 0}; + final ByteBuffer buffer = ByteBuffer.wrap(input); + final byte[] array = Utils.getNullableSizePrefixedArray(buffer); + assertArrayEquals(new byte[] {1, 0}, array); + assertEquals(6, buffer.position()); + assertFalse(buffer.hasRemaining()); + } + + @Test + public void getNullableSizePrefixedArrayExactEmpty() { + byte[] input = {0, 0, 0, 0}; + final ByteBuffer buffer = ByteBuffer.wrap(input); + final byte[] array = Utils.getNullableSizePrefixedArray(buffer); + assertArrayEquals(new byte[] {}, array); + assertEquals(4, buffer.position()); + assertFalse(buffer.hasRemaining()); + } + + @Test + public void getNullableSizePrefixedArrayRemainder() { + byte[] input = {0, 0, 0, 2, 1, 0, 9}; + final ByteBuffer buffer = ByteBuffer.wrap(input); + final byte[] array = Utils.getNullableSizePrefixedArray(buffer); + assertArrayEquals(new byte[] {1, 0}, array); + assertEquals(6, buffer.position()); + assertTrue(buffer.hasRemaining()); + } + + @Test + public void getNullableSizePrefixedArrayNull() { + // -1 + byte[] input = {-1, -1, -1, -1}; + final ByteBuffer buffer = ByteBuffer.wrap(input); + final 
byte[] array = Utils.getNullableSizePrefixedArray(buffer); + assertNull(array); + assertEquals(4, buffer.position()); + assertFalse(buffer.hasRemaining()); + } + + @Test + public void getNullableSizePrefixedArrayInvalid() { + // -2 + byte[] input = {-1, -1, -1, -2}; + final ByteBuffer buffer = ByteBuffer.wrap(input); + assertThrows(NegativeArraySizeException.class, () -> Utils.getNullableSizePrefixedArray(buffer)); + } + + @Test + public void getNullableSizePrefixedArrayUnderflow() { + // Integer.MAX_VALUE + byte[] input = {127, -1, -1, -1}; + final ByteBuffer buffer = ByteBuffer.wrap(input); + // note, we get a buffer underflow exception instead of an OOME, even though the encoded size + // would be 2,147,483,647 aka 2.1 GB, probably larger than the available heap + assertThrows(BufferUnderflowException.class, () -> Utils.getNullableSizePrefixedArray(buffer)); + } + + @Test + public void utf8ByteArraySerde() { + String utf8String = "A\u00ea\u00f1\u00fcC"; + byte[] utf8Bytes = utf8String.getBytes(StandardCharsets.UTF_8); + assertArrayEquals(utf8Bytes, Utils.utf8(utf8String)); + assertEquals(utf8Bytes.length, Utils.utf8Length(utf8String)); + assertEquals(utf8String, Utils.utf8(utf8Bytes)); + } + + @Test + public void utf8ByteBufferSerde() { + doTestUtf8ByteBuffer(ByteBuffer.allocate(20)); + doTestUtf8ByteBuffer(ByteBuffer.allocateDirect(20)); + } + + private void doTestUtf8ByteBuffer(ByteBuffer utf8Buffer) { + String utf8String = "A\u00ea\u00f1\u00fcC"; + byte[] utf8Bytes = utf8String.getBytes(StandardCharsets.UTF_8); + + utf8Buffer.position(4); + utf8Buffer.put(utf8Bytes); + + utf8Buffer.position(4); + assertEquals(utf8String, Utils.utf8(utf8Buffer, utf8Bytes.length)); + assertEquals(4, utf8Buffer.position()); + + utf8Buffer.position(0); + assertEquals(utf8String, Utils.utf8(utf8Buffer, 4, utf8Bytes.length)); + assertEquals(0, utf8Buffer.position()); + } + + private void subTest(ByteBuffer buffer) { + // The first byte should be 'A' + assertEquals('A', (Utils.readBytes(buffer, 0, 1))[0]); + + // The offset is 2, so the first 2 bytes should be skipped. + byte[] results = Utils.readBytes(buffer, 2, 3); + assertEquals('y', results[0]); + assertEquals(' ', results[1]); + assertEquals('S', results[2]); + assertEquals(3, results.length); + + // test readBytes without offset and length specified. + results = Utils.readBytes(buffer); + assertEquals('A', results[0]); + assertEquals('t', results[buffer.limit() - 1]); + assertEquals(buffer.limit(), results.length); + } + + @Test + public void testReadBytes() { + byte[] myvar = "Any String you want".getBytes(); + ByteBuffer buffer = ByteBuffer.allocate(myvar.length); + buffer.put(myvar); + buffer.rewind(); + + this.subTest(buffer); + + // test readonly buffer, different path + buffer = ByteBuffer.wrap(myvar).asReadOnlyBuffer(); + this.subTest(buffer); + } + + @Test + public void testFileAsStringSimpleFile() throws IOException { + File tempFile = TestUtils.tempFile(); + try { + String testContent = "Test Content"; + Files.write(tempFile.toPath(), testContent.getBytes()); + assertEquals(testContent, Utils.readFileAsString(tempFile.getPath())); + } finally { + Files.deleteIfExists(tempFile.toPath()); + } + } + + /** + * Test to read content of named pipe as string. As reading/writing to a pipe can block, + * timeout test after a minute (test finishes within 100 ms normally). 
+ */ + @Timeout(60) + @Test + public void testFileAsStringNamedPipe() throws Exception { + + // Create a temporary name for named pipe + Random random = new Random(); + long n = random.nextLong(); + n = n == Long.MIN_VALUE ? 0 : Math.abs(n); + + // Use the name to create a FIFO in tmp directory + String tmpDir = System.getProperty("java.io.tmpdir"); + String fifoName = "fifo-" + n + ".tmp"; + File fifo = new File(tmpDir, fifoName); + Thread producerThread = null; + try { + Process mkFifoCommand = new ProcessBuilder("mkfifo", fifo.getCanonicalPath()).start(); + mkFifoCommand.waitFor(); + + // Send some data to fifo and then read it back, but as FIFO blocks if the consumer isn't present, + // we need to send data in a separate thread. + final String testFileContent = "This is test"; + producerThread = new Thread(() -> { + try { + Files.write(fifo.toPath(), testFileContent.getBytes()); + } catch (IOException e) { + fail("Error when producing to fifo : " + e.getMessage()); + } + }, "FIFO-Producer"); + producerThread.start(); + + assertEquals(testFileContent, Utils.readFileAsString(fifo.getCanonicalPath())); + } finally { + Files.deleteIfExists(fifo.toPath()); + if (producerThread != null) { + producerThread.join(30 * 1000); // Wait for thread to terminate + assertFalse(producerThread.isAlive()); + } + } + } + + @Test + public void testMin() { + assertEquals(1, Utils.min(1)); + assertEquals(1, Utils.min(1, 2, 3)); + assertEquals(1, Utils.min(2, 1, 3)); + assertEquals(1, Utils.min(2, 3, 1)); + } + + @Test + public void testCloseAll() { + TestCloseable[] closeablesWithoutException = TestCloseable.createCloseables(false, false, false); + try { + Utils.closeAll(closeablesWithoutException); + TestCloseable.checkClosed(closeablesWithoutException); + } catch (IOException e) { + fail("Unexpected exception: " + e); + } + + TestCloseable[] closeablesWithException = TestCloseable.createCloseables(true, true, true); + try { + Utils.closeAll(closeablesWithException); + fail("Expected exception not thrown"); + } catch (IOException e) { + TestCloseable.checkClosed(closeablesWithException); + TestCloseable.checkException(e, closeablesWithException); + } + + TestCloseable[] singleExceptionCloseables = TestCloseable.createCloseables(false, true, false); + try { + Utils.closeAll(singleExceptionCloseables); + fail("Expected exception not thrown"); + } catch (IOException e) { + TestCloseable.checkClosed(singleExceptionCloseables); + TestCloseable.checkException(e, singleExceptionCloseables[1]); + } + + TestCloseable[] mixedCloseables = TestCloseable.createCloseables(false, true, false, true, true); + try { + Utils.closeAll(mixedCloseables); + fail("Expected exception not thrown"); + } catch (IOException e) { + TestCloseable.checkClosed(mixedCloseables); + TestCloseable.checkException(e, mixedCloseables[1], mixedCloseables[3], mixedCloseables[4]); + } + } + + @Test + public void testReadFullyOrFailWithRealFile() throws IOException { + try (FileChannel channel = FileChannel.open(TestUtils.tempFile().toPath(), StandardOpenOption.READ, StandardOpenOption.WRITE)) { + // prepare channel + String msg = "hello, world"; + channel.write(ByteBuffer.wrap(msg.getBytes()), 0); + channel.force(true); + assertEquals(channel.size(), msg.length(), "Message should be written to the file channel"); + + ByteBuffer perfectBuffer = ByteBuffer.allocate(msg.length()); + ByteBuffer smallBuffer = ByteBuffer.allocate(5); + ByteBuffer largeBuffer = ByteBuffer.allocate(msg.length() + 1); + // Scenario 1: test reading into a perfectly-sized 
buffer + Utils.readFullyOrFail(channel, perfectBuffer, 0, "perfect"); + assertFalse(perfectBuffer.hasRemaining(), "Buffer should be filled up"); + assertEquals(msg, new String(perfectBuffer.array()), "Buffer should be populated correctly"); + // Scenario 2: test reading into a smaller buffer + Utils.readFullyOrFail(channel, smallBuffer, 0, "small"); + assertFalse(smallBuffer.hasRemaining(), "Buffer should be filled"); + assertEquals("hello", new String(smallBuffer.array()), "Buffer should be populated correctly"); + // Scenario 3: test reading starting from a non-zero position + smallBuffer.clear(); + Utils.readFullyOrFail(channel, smallBuffer, 7, "small"); + assertFalse(smallBuffer.hasRemaining(), "Buffer should be filled"); + assertEquals("world", new String(smallBuffer.array()), "Buffer should be populated correctly"); + // Scenario 4: test end of stream is reached before buffer is filled up + try { + Utils.readFullyOrFail(channel, largeBuffer, 0, "large"); + fail("Expected EOFException to be raised"); + } catch (EOFException e) { + // expected + } + } + } + + /** + * Tests that `readFullyOrFail` behaves correctly if multiple `FileChannel.read` operations are required to fill + * the destination buffer. + */ + @Test + public void testReadFullyOrFailWithPartialFileChannelReads() throws IOException { + FileChannel channelMock = mock(FileChannel.class); + final int bufferSize = 100; + ByteBuffer buffer = ByteBuffer.allocate(bufferSize); + String expectedBufferContent = fileChannelMockExpectReadWithRandomBytes(channelMock, bufferSize); + Utils.readFullyOrFail(channelMock, buffer, 0L, "test"); + assertEquals(expectedBufferContent, new String(buffer.array()), "The buffer should be populated correctly"); + assertFalse(buffer.hasRemaining(), "The buffer should be filled"); + verify(channelMock, atLeastOnce()).read(any(), anyLong()); + } + + /** + * Tests that `readFullyOrFail` behaves correctly if multiple `FileChannel.read` operations are required to fill + * the destination buffer. 
+ */ + @Test + public void testReadFullyWithPartialFileChannelReads() throws IOException { + FileChannel channelMock = mock(FileChannel.class); + final int bufferSize = 100; + String expectedBufferContent = fileChannelMockExpectReadWithRandomBytes(channelMock, bufferSize); + ByteBuffer buffer = ByteBuffer.allocate(bufferSize); + Utils.readFully(channelMock, buffer, 0L); + assertEquals(expectedBufferContent, new String(buffer.array()), "The buffer should be populated correctly."); + assertFalse(buffer.hasRemaining(), "The buffer should be filled"); + verify(channelMock, atLeastOnce()).read(any(), anyLong()); + } + + @Test + public void testReadFullyIfEofIsReached() throws IOException { + final FileChannel channelMock = mock(FileChannel.class); + final int bufferSize = 100; + final String fileChannelContent = "abcdefghkl"; + ByteBuffer buffer = ByteBuffer.allocate(bufferSize); + when(channelMock.read(any(), anyLong())).then(invocation -> { + ByteBuffer bufferArg = invocation.getArgument(0); + bufferArg.put(fileChannelContent.getBytes()); + return -1; + }); + Utils.readFully(channelMock, buffer, 0L); + assertEquals("abcdefghkl", new String(buffer.array(), 0, buffer.position())); + assertEquals(fileChannelContent.length(), buffer.position()); + assertTrue(buffer.hasRemaining()); + verify(channelMock, atLeastOnce()).read(any(), anyLong()); + } + + @Test + public void testLoadProps() throws IOException { + File tempFile = TestUtils.tempFile(); + try { + String testContent = "a=1\nb=2\n#a comment\n\nc=3\nd="; + Files.write(tempFile.toPath(), testContent.getBytes()); + Properties props = Utils.loadProps(tempFile.getPath()); + assertEquals(4, props.size()); + assertEquals("1", props.get("a")); + assertEquals("2", props.get("b")); + assertEquals("3", props.get("c")); + assertEquals("", props.get("d")); + Properties restrictedProps = Utils.loadProps(tempFile.getPath(), Arrays.asList("b", "d", "e")); + assertEquals(2, restrictedProps.size()); + assertEquals("2", restrictedProps.get("b")); + assertEquals("", restrictedProps.get("d")); + } finally { + Files.deleteIfExists(tempFile.toPath()); + } + } + + /** + * Expectation setter for multiple reads where each one reads random bytes to the buffer. + * + * @param channelMock The mocked FileChannel object + * @param bufferSize The buffer size + * @return Expected buffer string + * @throws IOException If an I/O error occurs + */ + private String fileChannelMockExpectReadWithRandomBytes(final FileChannel channelMock, + final int bufferSize) throws IOException { + final int step = 20; + final Random random = new Random(); + int remainingBytes = bufferSize; + OngoingStubbing when = when(channelMock.read(any(), anyLong())); + StringBuilder expectedBufferContent = new StringBuilder(); + while (remainingBytes > 0) { + final int bytesRead = remainingBytes < step ? remainingBytes : random.nextInt(step); + final String stringRead = IntStream.range(0, bytesRead).mapToObj(i -> "a").collect(Collectors.joining()); + expectedBufferContent.append(stringRead); + when = when.then(invocation -> { + ByteBuffer buffer = invocation.getArgument(0); + buffer.put(stringRead.getBytes()); + return bytesRead; + }); + remainingBytes -= bytesRead; + } + return expectedBufferContent.toString(); + } + + private static class TestCloseable implements Closeable { + private final int id; + private final IOException closeException; + private boolean closed; + + TestCloseable(int id, boolean exceptionOnClose) { + this.id = id; + this.closeException = exceptionOnClose ? 
new IOException("Test close exception " + id) : null; + } + + @Override + public void close() throws IOException { + closed = true; + if (closeException != null) { + throw closeException; + } + } + + static TestCloseable[] createCloseables(boolean... exceptionOnClose) { + TestCloseable[] closeables = new TestCloseable[exceptionOnClose.length]; + for (int i = 0; i < closeables.length; i++) + closeables[i] = new TestCloseable(i, exceptionOnClose[i]); + return closeables; + } + + static void checkClosed(TestCloseable... closeables) { + for (TestCloseable closeable : closeables) + assertTrue(closeable.closed, "Close not invoked for " + closeable.id); + } + + static void checkException(IOException e, TestCloseable... closeablesWithException) { + assertEquals(closeablesWithException[0].closeException, e); + Throwable[] suppressed = e.getSuppressed(); + assertEquals(closeablesWithException.length - 1, suppressed.length); + for (int i = 1; i < closeablesWithException.length; i++) + assertEquals(closeablesWithException[i].closeException, suppressed[i - 1]); + } + } + + @Timeout(120) + @Test + public void testRecursiveDelete() throws IOException { + Utils.delete(null); // delete of null does nothing. + + // Test that deleting a temporary file works. + File tempFile = TestUtils.tempFile(); + Utils.delete(tempFile); + assertFalse(Files.exists(tempFile.toPath())); + + // Test recursive deletes + File tempDir = TestUtils.tempDirectory(); + File tempDir2 = TestUtils.tempDirectory(tempDir.toPath(), "a"); + TestUtils.tempDirectory(tempDir.toPath(), "b"); + TestUtils.tempDirectory(tempDir2.toPath(), "c"); + Utils.delete(tempDir); + assertFalse(Files.exists(tempDir.toPath())); + assertFalse(Files.exists(tempDir2.toPath())); + + // Test that deleting a non-existent directory hierarchy works. 
+ Utils.delete(tempDir); + assertFalse(Files.exists(tempDir.toPath())); + } + + @Test + public void testConvertTo32BitField() { + Set bytes = mkSet((byte) 0, (byte) 1, (byte) 5, (byte) 10, (byte) 31); + int bitField = Utils.to32BitField(bytes); + assertEquals(bytes, Utils.from32BitField(bitField)); + + bytes = new HashSet<>(); + bitField = Utils.to32BitField(bytes); + assertEquals(bytes, Utils.from32BitField(bitField)); + + assertThrows(IllegalArgumentException.class, () -> Utils.to32BitField(mkSet((byte) 0, (byte) 11, (byte) 32))); + } + + @Test + public void testUnion() { + final Set oneSet = mkSet("a", "b", "c"); + final Set anotherSet = mkSet("c", "d", "e"); + final Set union = union(TreeSet::new, oneSet, anotherSet); + + assertEquals(mkSet("a", "b", "c", "d", "e"), union); + assertEquals(TreeSet.class, union.getClass()); + } + + @Test + public void testUnionOfOne() { + final Set oneSet = mkSet("a", "b", "c"); + final Set union = union(TreeSet::new, oneSet); + + assertEquals(mkSet("a", "b", "c"), union); + assertEquals(TreeSet.class, union.getClass()); + } + + @Test + public void testUnionOfMany() { + final Set oneSet = mkSet("a", "b", "c"); + final Set twoSet = mkSet("c", "d", "e"); + final Set threeSet = mkSet("b", "c", "d"); + final Set fourSet = mkSet("x", "y", "z"); + final Set union = union(TreeSet::new, oneSet, twoSet, threeSet, fourSet); + + assertEquals(mkSet("a", "b", "c", "d", "e", "x", "y", "z"), union); + assertEquals(TreeSet.class, union.getClass()); + } + + @Test + public void testUnionOfNone() { + final Set union = union(TreeSet::new); + + assertEquals(emptySet(), union); + assertEquals(TreeSet.class, union.getClass()); + } + + @Test + public void testIntersection() { + final Set oneSet = mkSet("a", "b", "c"); + final Set anotherSet = mkSet("c", "d", "e"); + final Set intersection = intersection(TreeSet::new, oneSet, anotherSet); + + assertEquals(mkSet("c"), intersection); + assertEquals(TreeSet.class, intersection.getClass()); + } + + @Test + public void testIntersectionOfOne() { + final Set oneSet = mkSet("a", "b", "c"); + final Set intersection = intersection(TreeSet::new, oneSet); + + assertEquals(mkSet("a", "b", "c"), intersection); + assertEquals(TreeSet.class, intersection.getClass()); + } + + @Test + public void testIntersectionOfMany() { + final Set oneSet = mkSet("a", "b", "c"); + final Set twoSet = mkSet("c", "d", "e"); + final Set threeSet = mkSet("b", "c", "d"); + final Set intersection = intersection(TreeSet::new, oneSet, twoSet, threeSet); + + assertEquals(mkSet("c"), intersection); + assertEquals(TreeSet.class, intersection.getClass()); + } + + @Test + public void testDisjointIntersectionOfMany() { + final Set oneSet = mkSet("a", "b", "c"); + final Set twoSet = mkSet("c", "d", "e"); + final Set threeSet = mkSet("b", "c", "d"); + final Set fourSet = mkSet("x", "y", "z"); + final Set intersection = intersection(TreeSet::new, oneSet, twoSet, threeSet, fourSet); + + assertEquals(emptySet(), intersection); + assertEquals(TreeSet.class, intersection.getClass()); + } + + @Test + public void testDiff() { + final Set oneSet = mkSet("a", "b", "c"); + final Set anotherSet = mkSet("c", "d", "e"); + final Set diff = diff(TreeSet::new, oneSet, anotherSet); + + assertEquals(mkSet("a", "b"), diff); + assertEquals(TreeSet.class, diff.getClass()); + } + + @Test + public void testPropsToMap() { + assertThrows(ConfigException.class, () -> { + Properties props = new Properties(); + props.put(1, 2); + Utils.propsToMap(props); + }); + assertValue(false); + assertValue(1); + 
assertValue("string");
+        assertValue(1.1);
+        assertValue(Collections.emptySet());
+        assertValue(Collections.emptyList());
+        assertValue(Collections.emptyMap());
+    }
+
+    private static void assertValue(Object value) {
+        Properties props = new Properties();
+        props.put("key", value);
+        assertEquals(Utils.propsToMap(props).get("key"), value);
+    }
+
+    @Test
+    public void testCloseAllQuietly() {
+        AtomicReference<Throwable> exception = new AtomicReference<>();
+        String msg = "you should fail";
+        AtomicInteger count = new AtomicInteger(0);
+        AutoCloseable c0 = () -> {
+            throw new RuntimeException(msg);
+        };
+        AutoCloseable c1 = count::incrementAndGet;
+        Utils.closeAllQuietly(exception, "test", Stream.of(c0, c1).toArray(AutoCloseable[]::new));
+        assertEquals(msg, exception.get().getMessage());
+        assertEquals(1, count.get());
+    }
+
+    @Test
+    public void shouldAcceptValidDateFormats() throws ParseException {
+        // check valid formats
+        invokeGetDateTimeMethod(new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss.SSS"));
+        invokeGetDateTimeMethod(new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss.SSSZ"));
+        invokeGetDateTimeMethod(new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss.SSSX"));
+        invokeGetDateTimeMethod(new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss.SSSXX"));
+        invokeGetDateTimeMethod(new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss.SSSXXX"));
+    }
+
+    @Test
+    public void shouldThrowOnInvalidDateFormatOrNullTimestamp() {
+        // check some invalid formats
+        // test null timestamp
+        assertTrue(assertThrows(IllegalArgumentException.class, () -> {
+            Utils.getDateTime(null);
+        }).getMessage().contains("Error parsing timestamp with null value"));
+
+        // test pattern: yyyy-MM-dd'T'HH:mm:ss.X
+        checkExceptionForGetDateTimeMethod(() -> {
+            invokeGetDateTimeMethod(new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss.X"));
+        });
+
+        // test pattern: yyyy-MM-dd HH:mm:ss
+        assertTrue(assertThrows(ParseException.class, () -> {
+            invokeGetDateTimeMethod(new SimpleDateFormat("yyyy-MM-dd HH:mm:ss"));
+        }).getMessage().contains("It does not contain a 'T' according to ISO8601 format"));
+
+        // KAFKA-10685: use DateTimeFormatter to generate micro/nano second timestamps
+        final DateTimeFormatter formatter = new DateTimeFormatterBuilder()
+            .appendPattern("yyyy-MM-dd'T'HH:mm:ss")
+            .appendFraction(ChronoField.NANO_OF_SECOND, 0, 9, true)
+            .toFormatter();
+        final LocalDateTime timestampWithNanoSeconds = LocalDateTime.of(2020, 11, 9, 12, 34, 56, 123456789);
+        final LocalDateTime timestampWithMicroSeconds = timestampWithNanoSeconds.truncatedTo(ChronoUnit.MICROS);
+        final LocalDateTime timestampWithSeconds = timestampWithNanoSeconds.truncatedTo(ChronoUnit.SECONDS);
+
+        // test pattern: yyyy-MM-dd'T'HH:mm:ss.SSSSSSSSS
+        checkExceptionForGetDateTimeMethod(() -> {
+            Utils.getDateTime(formatter.format(timestampWithNanoSeconds));
+        });
+
+        // test pattern: yyyy-MM-dd'T'HH:mm:ss.SSSSSS
+        checkExceptionForGetDateTimeMethod(() -> {
+            Utils.getDateTime(formatter.format(timestampWithMicroSeconds));
+        });
+
+        // test pattern: yyyy-MM-dd'T'HH:mm:ss
+        checkExceptionForGetDateTimeMethod(() -> {
+            Utils.getDateTime(formatter.format(timestampWithSeconds));
+        });
+    }
+
+    private void checkExceptionForGetDateTimeMethod(Executable executable) {
+        assertTrue(assertThrows(ParseException.class, executable)
+            .getMessage().contains("Unparseable date"));
+    }
+
+    private void invokeGetDateTimeMethod(final SimpleDateFormat format) throws ParseException {
+        final Date checkpoint = new Date();
+        final String formattedCheckpoint = format.format(checkpoint);
+        Utils.getDateTime(formattedCheckpoint);
} + + @Test + void testIsBlank() { + assertTrue(Utils.isBlank(null)); + assertTrue(Utils.isBlank("")); + assertTrue(Utils.isBlank(" ")); + assertFalse(Utils.isBlank("bob")); + assertFalse(Utils.isBlank(" bob ")); + } + + @Test + public void testCharacterArrayEquality() { + assertCharacterArraysAreNotEqual(null, "abc"); + assertCharacterArraysAreNotEqual(null, ""); + assertCharacterArraysAreNotEqual("abc", null); + assertCharacterArraysAreNotEqual("", null); + assertCharacterArraysAreNotEqual("", "abc"); + assertCharacterArraysAreNotEqual("abc", "abC"); + assertCharacterArraysAreNotEqual("abc", "abcd"); + assertCharacterArraysAreNotEqual("abc", "abcdefg"); + assertCharacterArraysAreNotEqual("abcdefg", "abc"); + assertCharacterArraysAreEqual("abc", "abc"); + assertCharacterArraysAreEqual("a", "a"); + assertCharacterArraysAreEqual("", ""); + assertCharacterArraysAreEqual("", ""); + assertCharacterArraysAreEqual(null, null); + } + + private void assertCharacterArraysAreNotEqual(String a, String b) { + char[] first = a != null ? a.toCharArray() : null; + char[] second = b != null ? b.toCharArray() : null; + if (a == null) { + assertNotNull(b); + } else { + assertFalse(a.equals(b)); + } + assertFalse(Utils.isEqualConstantTime(first, second)); + assertFalse(Utils.isEqualConstantTime(second, first)); + } + + private void assertCharacterArraysAreEqual(String a, String b) { + char[] first = a != null ? a.toCharArray() : null; + char[] second = b != null ? b.toCharArray() : null; + if (a == null) { + assertNull(b); + } else { + assertTrue(a.equals(b)); + } + assertTrue(Utils.isEqualConstantTime(first, second)); + assertTrue(Utils.isEqualConstantTime(second, first)); + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/utils/annotation/ApiKeyVersionsProvider.java b/clients/src/test/java/org/apache/kafka/common/utils/annotation/ApiKeyVersionsProvider.java new file mode 100644 index 0000000..2a1f6e4 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/utils/annotation/ApiKeyVersionsProvider.java @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+package org.apache.kafka.common.utils.annotation;
+
+import java.util.stream.Stream;
+import org.apache.kafka.common.protocol.ApiKeys;
+import org.junit.jupiter.api.extension.ExtensionContext;
+import org.junit.jupiter.params.provider.Arguments;
+import org.junit.jupiter.params.provider.ArgumentsProvider;
+import org.junit.jupiter.params.support.AnnotationConsumer;
+
+public class ApiKeyVersionsProvider implements ArgumentsProvider, AnnotationConsumer<ApiKeyVersionsSource> {
+    private ApiKeys apiKey;
+
+    public void accept(ApiKeyVersionsSource source) {
+        apiKey = source.apiKey();
+    }
+
+    public Stream<? extends Arguments> provideArguments(ExtensionContext context) {
+        return apiKey.allVersions().stream().map(Arguments::of);
+    }
+}
diff --git a/clients/src/test/java/org/apache/kafka/common/utils/annotation/ApiKeyVersionsSource.java b/clients/src/test/java/org/apache/kafka/common/utils/annotation/ApiKeyVersionsSource.java
new file mode 100644
index 0000000..9f169b3
--- /dev/null
+++ b/clients/src/test/java/org/apache/kafka/common/utils/annotation/ApiKeyVersionsSource.java
@@ -0,0 +1,31 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.common.utils.annotation;
+
+import java.lang.annotation.ElementType;
+import java.lang.annotation.Retention;
+import java.lang.annotation.RetentionPolicy;
+import java.lang.annotation.Target;
+import org.apache.kafka.common.protocol.ApiKeys;
+import org.junit.jupiter.params.provider.ArgumentsSource;
+
+@Target({ElementType.METHOD})
+@Retention(RetentionPolicy.RUNTIME)
+@ArgumentsSource(ApiKeyVersionsProvider.class)
+public @interface ApiKeyVersionsSource {
+    ApiKeys apiKey();
+}
diff --git a/clients/src/test/java/org/apache/kafka/test/DelayedReceive.java b/clients/src/test/java/org/apache/kafka/test/DelayedReceive.java
new file mode 100644
index 0000000..37d8f74
--- /dev/null
+++ b/clients/src/test/java/org/apache/kafka/test/DelayedReceive.java
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
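ApiKeyVersionsProvider above turns @ApiKeyVersionsSource into one JUnit argument per supported version of the given ApiKeys entry. A sketch of how a parameterized test might consume it, assuming allVersions() yields short version numbers; the class and test names are illustrative only:

import org.apache.kafka.common.protocol.ApiKeys;
import org.apache.kafka.common.utils.annotation.ApiKeyVersionsSource;
import org.junit.jupiter.params.ParameterizedTest;

public class ApiKeyVersionsSketchTest {
    @ParameterizedTest
    @ApiKeyVersionsSource(apiKey = ApiKeys.PRODUCE)
    public void testEveryProduceVersion(short version) {
        // runs once for each version returned by ApiKeys.PRODUCE.allVersions()
        System.out.println("exercising PRODUCE v" + version);
    }
}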
+ */ +package org.apache.kafka.test; + +import org.apache.kafka.common.network.NetworkReceive; + +/** + * Used by MockSelector to allow clients to add responses whose associated requests are added later. + */ +public class DelayedReceive { + private final String source; + private final NetworkReceive receive; + + public DelayedReceive(String source, NetworkReceive receive) { + this.source = source; + this.receive = receive; + } + + public String source() { + return source; + } + + public NetworkReceive receive() { + return receive; + } +} diff --git a/clients/src/test/java/org/apache/kafka/test/IntegrationTest.java b/clients/src/test/java/org/apache/kafka/test/IntegrationTest.java new file mode 100644 index 0000000..c73a681 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/test/IntegrationTest.java @@ -0,0 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.test; + +public interface IntegrationTest { +} diff --git a/clients/src/test/java/org/apache/kafka/test/MetricsBench.java b/clients/src/test/java/org/apache/kafka/test/MetricsBench.java new file mode 100644 index 0000000..93cbf6d --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/test/MetricsBench.java @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
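DelayedReceive simply pairs a connection id with a NetworkReceive; MockSelector (added later in this patch) moves such entries into completedReceives() once a send to the same node completes during poll(). A sketch of priming a response before the matching request exists; the node id, the empty buffer, and the NetworkReceive(source, buffer) constructor are assumptions for illustration:

    import java.nio.ByteBuffer;
    import org.apache.kafka.common.network.NetworkReceive;
    import org.apache.kafka.common.utils.MockTime;

    MockSelector selector = new MockSelector(new MockTime());
    NetworkReceive response = new NetworkReceive("node-1", ByteBuffer.allocate(0));
    // Queued now; surfaced by poll() only after a send destined for "node-1" has completed.
    selector.delayedReceive(new DelayedReceive("node-1", response));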
+ */
+package org.apache.kafka.test;
+
+import java.util.Arrays;
+
+import org.apache.kafka.common.metrics.Metrics;
+import org.apache.kafka.common.metrics.Sensor;
+import org.apache.kafka.common.metrics.stats.Avg;
+import org.apache.kafka.common.metrics.stats.Max;
+import org.apache.kafka.common.metrics.stats.Percentile;
+import org.apache.kafka.common.metrics.stats.Percentiles;
+import org.apache.kafka.common.metrics.stats.Percentiles.BucketSizing;
+import org.apache.kafka.common.metrics.stats.WindowedCount;
+
+public class MetricsBench {
+
+    public static void main(String[] args) {
+        long iters = Long.parseLong(args[0]);
+        Metrics metrics = new Metrics();
+        try {
+            Sensor parent = metrics.sensor("parent");
+            Sensor child = metrics.sensor("child", parent);
+            for (Sensor sensor : Arrays.asList(parent, child)) {
+                sensor.add(metrics.metricName(sensor.name() + ".avg", "grp1"), new Avg());
+                sensor.add(metrics.metricName(sensor.name() + ".count", "grp1"), new WindowedCount());
+                sensor.add(metrics.metricName(sensor.name() + ".max", "grp1"), new Max());
+                sensor.add(new Percentiles(1024,
+                                           0.0,
+                                           iters,
+                                           BucketSizing.CONSTANT,
+                                           new Percentile(metrics.metricName(sensor.name() + ".median", "grp1"), 50.0),
+                                           new Percentile(metrics.metricName(sensor.name() + ".p_99", "grp1"), 99.0)));
+            }
+            long start = System.nanoTime();
+            for (int i = 0; i < iters; i++)
+                parent.record(i);
+            double elapsed = (System.nanoTime() - start) / (double) iters;
+            System.out.println(String.format("%.2f ns per metric recording.", elapsed));
+        } finally {
+            metrics.close();
+        }
+    }
+}
diff --git a/clients/src/test/java/org/apache/kafka/test/Microbenchmarks.java b/clients/src/test/java/org/apache/kafka/test/Microbenchmarks.java
new file mode 100644
index 0000000..cfb5f6c
--- /dev/null
+++ b/clients/src/test/java/org/apache/kafka/test/Microbenchmarks.java
@@ -0,0 +1,196 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +package org.apache.kafka.test; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Random; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.locks.ReentrantLock; + +import org.apache.kafka.common.utils.CopyOnWriteMap; +import org.apache.kafka.common.utils.Time; + +public class Microbenchmarks { + + public static void main(String[] args) throws Exception { + + final int iters = Integer.parseInt(args[0]); + double x = 0.0; + long start = System.nanoTime(); + for (int i = 0; i < iters; i++) + x += Math.sqrt(x); + System.out.println(x); + System.out.println("sqrt: " + (System.nanoTime() - start) / (double) iters); + + // test clocks + systemMillis(iters); + systemNanos(iters); + long total = 0; + start = System.nanoTime(); + total += systemMillis(iters); + System.out.println("System.currentTimeMillis(): " + (System.nanoTime() - start) / iters); + start = System.nanoTime(); + total += systemNanos(iters); + System.out.println("System.nanoTime(): " + (System.nanoTime() - start) / iters); + System.out.println(total); + + // test random + int n = 0; + Random random = new Random(); + start = System.nanoTime(); + for (int i = 0; i < iters; i++) { + n += random.nextInt(); + } + System.out.println(n); + System.out.println("random: " + (System.nanoTime() - start) / iters); + + float[] floats = new float[1024]; + for (int i = 0; i < floats.length; i++) + floats[i] = random.nextFloat(); + Arrays.sort(floats); + + int loc = 0; + start = System.nanoTime(); + for (int i = 0; i < iters; i++) + loc += Arrays.binarySearch(floats, floats[i % floats.length]); + System.out.println(loc); + System.out.println("binary search: " + (System.nanoTime() - start) / iters); + + final Time time = Time.SYSTEM; + final AtomicBoolean done = new AtomicBoolean(false); + final Object lock = new Object(); + Thread t1 = new Thread() { + public void run() { + time.sleep(1); + int counter = 0; + long start = time.nanoseconds(); + for (int i = 0; i < iters; i++) { + synchronized (lock) { + counter++; + } + } + System.out.println("synchronized: " + ((time.nanoseconds() - start) / iters)); + System.out.println(counter); + done.set(true); + } + }; + + Thread t2 = new Thread() { + public void run() { + int counter = 0; + while (!done.get()) { + time.sleep(1); + synchronized (lock) { + counter += 1; + } + } + System.out.println("Counter: " + counter); + } + }; + + t1.start(); + t2.start(); + t1.join(); + t2.join(); + + System.out.println("Testing locks"); + done.set(false); + final ReentrantLock lock2 = new ReentrantLock(); + Thread t3 = new Thread() { + public void run() { + time.sleep(1); + int counter = 0; + long start = time.nanoseconds(); + for (int i = 0; i < iters; i++) { + lock2.lock(); + counter++; + lock2.unlock(); + } + System.out.println("lock: " + ((time.nanoseconds() - start) / iters)); + System.out.println(counter); + done.set(true); + } + }; + + Thread t4 = new Thread() { + public void run() { + int counter = 0; + while (!done.get()) { + time.sleep(1); + lock2.lock(); + counter++; + lock2.unlock(); + } + System.out.println("Counter: " + counter); + } + }; + + t3.start(); + t4.start(); + t3.join(); + t4.join(); + + Map values = new HashMap(); + for (int i = 0; i < 100; i++) + values.put(Integer.toString(i), i); + System.out.println("HashMap:"); + benchMap(2, 1000000, values); + System.out.println("ConcurentHashMap:"); + benchMap(2, 1000000, new 
ConcurrentHashMap(values)); + System.out.println("CopyOnWriteMap:"); + benchMap(2, 1000000, new CopyOnWriteMap(values)); + } + + private static void benchMap(int numThreads, final int iters, final Map map) throws Exception { + final List keys = new ArrayList(map.keySet()); + final List threads = new ArrayList(); + for (int i = 0; i < numThreads; i++) { + threads.add(new Thread() { + public void run() { + long start = System.nanoTime(); + for (int j = 0; j < iters; j++) + map.get(keys.get(j % threads.size())); + System.out.println("Map access time: " + ((System.nanoTime() - start) / (double) iters)); + } + }); + } + for (Thread thread : threads) + thread.start(); + for (Thread thread : threads) + thread.join(); + } + + private static long systemMillis(int iters) { + long total = 0; + for (int i = 0; i < iters; i++) + total += System.currentTimeMillis(); + return total; + } + + private static long systemNanos(int iters) { + long total = 0; + for (int i = 0; i < iters; i++) + total += System.currentTimeMillis(); + return total; + } + +} diff --git a/clients/src/test/java/org/apache/kafka/test/MockClusterResourceListener.java b/clients/src/test/java/org/apache/kafka/test/MockClusterResourceListener.java new file mode 100644 index 0000000..c8185a8 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/test/MockClusterResourceListener.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.test; + +import org.apache.kafka.common.ClusterResourceListener; +import org.apache.kafka.common.ClusterResource; + +import java.util.concurrent.atomic.AtomicBoolean; + +public class MockClusterResourceListener implements ClusterResourceListener { + + private ClusterResource clusterResource; + public static final AtomicBoolean IS_ON_UPDATE_CALLED = new AtomicBoolean(); + + @Override + public void onUpdate(ClusterResource clusterResource) { + IS_ON_UPDATE_CALLED.set(true); + this.clusterResource = clusterResource; + } + + public ClusterResource clusterResource() { + return clusterResource; + } +} diff --git a/clients/src/test/java/org/apache/kafka/test/MockConsumerInterceptor.java b/clients/src/test/java/org/apache/kafka/test/MockConsumerInterceptor.java new file mode 100644 index 0000000..b01584b --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/test/MockConsumerInterceptor.java @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.test; + +import org.apache.kafka.clients.consumer.ConsumerInterceptor; +import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.clients.consumer.ConsumerRecords; +import org.apache.kafka.clients.consumer.OffsetAndMetadata; +import org.apache.kafka.clients.consumer.ConsumerConfig; +import org.apache.kafka.clients.producer.ProducerConfig; +import org.apache.kafka.common.ClusterResourceListener; +import org.apache.kafka.common.ClusterResource; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.config.ConfigException; +import org.apache.kafka.common.header.internals.RecordHeaders; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Optional; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; + +public class MockConsumerInterceptor implements ClusterResourceListener, ConsumerInterceptor { + public static final AtomicInteger INIT_COUNT = new AtomicInteger(0); + public static final AtomicInteger CLOSE_COUNT = new AtomicInteger(0); + public static final AtomicInteger ON_COMMIT_COUNT = new AtomicInteger(0); + public static final AtomicReference CLUSTER_META = new AtomicReference<>(); + public static final ClusterResource NO_CLUSTER_ID = new ClusterResource("no_cluster_id"); + public static final AtomicReference CLUSTER_ID_BEFORE_ON_CONSUME = new AtomicReference<>(NO_CLUSTER_ID); + + public MockConsumerInterceptor() { + INIT_COUNT.incrementAndGet(); + } + + @Override + public void configure(Map configs) { + // clientId must be in configs + Object clientIdValue = configs.get(ConsumerConfig.CLIENT_ID_CONFIG); + if (clientIdValue == null) + throw new ConfigException("Mock consumer interceptor expects configuration " + ProducerConfig.CLIENT_ID_CONFIG); + } + + @Override + public ConsumerRecords onConsume(ConsumerRecords records) { + + // This will ensure that we get the cluster metadata when onConsume is called for the first time + // as subsequent compareAndSet operations will fail. 
+        CLUSTER_ID_BEFORE_ON_CONSUME.compareAndSet(NO_CLUSTER_ID, CLUSTER_META.get());
+
+        Map<TopicPartition, List<ConsumerRecord<String, String>>> recordMap = new HashMap<>();
+        for (TopicPartition tp : records.partitions()) {
+            List<ConsumerRecord<String, String>> lst = new ArrayList<>();
+            for (ConsumerRecord<String, String> record: records.records(tp)) {
+                lst.add(new ConsumerRecord<>(record.topic(), record.partition(), record.offset(),
+                        record.timestamp(), record.timestampType(),
+                        record.serializedKeySize(),
+                        record.serializedValueSize(),
+                        record.key(), record.value().toUpperCase(Locale.ROOT),
+                        new RecordHeaders(), Optional.empty()));
+            }
+            recordMap.put(tp, lst);
+        }
+        return new ConsumerRecords<>(recordMap);
+    }
+
+    @Override
+    public void onCommit(Map<TopicPartition, OffsetAndMetadata> offsets) {
+        ON_COMMIT_COUNT.incrementAndGet();
+    }
+
+    @Override
+    public void close() {
+        CLOSE_COUNT.incrementAndGet();
+    }
+
+    public static void resetCounters() {
+        INIT_COUNT.set(0);
+        CLOSE_COUNT.set(0);
+        ON_COMMIT_COUNT.set(0);
+        CLUSTER_META.set(null);
+        CLUSTER_ID_BEFORE_ON_CONSUME.set(NO_CLUSTER_ID);
+    }
+
+    @Override
+    public void onUpdate(ClusterResource clusterResource) {
+        CLUSTER_META.set(clusterResource);
+    }
+}
diff --git a/clients/src/test/java/org/apache/kafka/test/MockDeserializer.java b/clients/src/test/java/org/apache/kafka/test/MockDeserializer.java
new file mode 100644
index 0000000..ac2865e
--- /dev/null
+++ b/clients/src/test/java/org/apache/kafka/test/MockDeserializer.java
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
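MockConsumerInterceptor.configure() above rejects configurations without a client.id, and all of its counters are static, so tests typically reset them between cases. A hedged sketch of wiring the interceptor into consumer properties (the client id value is illustrative):

    import java.util.Properties;
    import org.apache.kafka.clients.consumer.ConsumerConfig;

    Properties props = new Properties();
    props.put(ConsumerConfig.CLIENT_ID_CONFIG, "interceptor-test-client");   // required by configure()
    props.put(ConsumerConfig.INTERCEPTOR_CLASSES_CONFIG, MockConsumerInterceptor.class.getName());
    // ... run the consumer, then assert on INIT_COUNT / ON_COMMIT_COUNT / CLOSE_COUNT ...
    MockConsumerInterceptor.resetCounters();    // keep the static counters isolated between tests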
+ */ +package org.apache.kafka.test; + +import org.apache.kafka.common.ClusterResource; +import org.apache.kafka.common.ClusterResourceListener; +import org.apache.kafka.common.serialization.Deserializer; + +import java.util.Map; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; + +public class MockDeserializer implements ClusterResourceListener, Deserializer { + public static AtomicInteger initCount = new AtomicInteger(0); + public static AtomicInteger closeCount = new AtomicInteger(0); + public static AtomicReference clusterMeta = new AtomicReference<>(); + public static ClusterResource noClusterId = new ClusterResource("no_cluster_id"); + public static AtomicReference clusterIdBeforeDeserialize = new AtomicReference<>(noClusterId); + + public boolean isKey; + public Map configs; + + public static void resetStaticVariables() { + initCount = new AtomicInteger(0); + closeCount = new AtomicInteger(0); + clusterMeta = new AtomicReference<>(); + clusterIdBeforeDeserialize = new AtomicReference<>(noClusterId); + } + + public MockDeserializer() { + initCount.incrementAndGet(); + } + + @Override + public void configure(Map configs, boolean isKey) { + this.configs = configs; + this.isKey = isKey; + } + + @Override + public byte[] deserialize(String topic, byte[] data) { + // This will ensure that we get the cluster metadata when deserialize is called for the first time + // as subsequent compareAndSet operations will fail. + clusterIdBeforeDeserialize.compareAndSet(noClusterId, clusterMeta.get()); + return data; + } + + @Override + public void close() { + closeCount.incrementAndGet(); + } + + @Override + public void onUpdate(ClusterResource clusterResource) { + clusterMeta.set(clusterResource); + } +} \ No newline at end of file diff --git a/clients/src/test/java/org/apache/kafka/test/MockMetricsReporter.java b/clients/src/test/java/org/apache/kafka/test/MockMetricsReporter.java new file mode 100644 index 0000000..40521f5 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/test/MockMetricsReporter.java @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.test; + +import org.apache.kafka.clients.CommonClientConfigs; +import org.apache.kafka.common.metrics.KafkaMetric; +import org.apache.kafka.common.metrics.MetricsReporter; + +import java.util.List; +import java.util.Map; +import java.util.concurrent.atomic.AtomicInteger; + +public class MockMetricsReporter implements MetricsReporter { + public static final AtomicInteger INIT_COUNT = new AtomicInteger(0); + public static final AtomicInteger CLOSE_COUNT = new AtomicInteger(0); + public String clientId; + + public MockMetricsReporter() { + } + + @Override + public void init(List metrics) { + INIT_COUNT.incrementAndGet(); + } + + @Override + public void metricChange(KafkaMetric metric) {} + + @Override + public void metricRemoval(KafkaMetric metric) {} + + @Override + public void close() { + CLOSE_COUNT.incrementAndGet(); + } + + @Override + public void configure(Map configs) { + clientId = (String) configs.get(CommonClientConfigs.CLIENT_ID_CONFIG); + } +} \ No newline at end of file diff --git a/clients/src/test/java/org/apache/kafka/test/MockPartitioner.java b/clients/src/test/java/org/apache/kafka/test/MockPartitioner.java new file mode 100644 index 0000000..af6d6fd --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/test/MockPartitioner.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.test; + +import org.apache.kafka.common.Cluster; +import org.apache.kafka.clients.producer.Partitioner; + +import java.util.Map; +import java.util.concurrent.atomic.AtomicInteger; + +public class MockPartitioner implements Partitioner { + public static final AtomicInteger INIT_COUNT = new AtomicInteger(0); + public static final AtomicInteger CLOSE_COUNT = new AtomicInteger(0); + + public MockPartitioner() { + INIT_COUNT.incrementAndGet(); + } + + @Override + public void configure(Map configs) { + } + + @Override + public int partition(String topic, Object key, byte[] keyBytes, Object value, byte[] valueBytes, Cluster cluster) { + return 0; + } + + @Override + public void close() { + CLOSE_COUNT.incrementAndGet(); + } + + public static void resetCounters() { + INIT_COUNT.set(0); + CLOSE_COUNT.set(0); + } +} diff --git a/clients/src/test/java/org/apache/kafka/test/MockProducerInterceptor.java b/clients/src/test/java/org/apache/kafka/test/MockProducerInterceptor.java new file mode 100644 index 0000000..133ff56 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/test/MockProducerInterceptor.java @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.test; + +import org.apache.kafka.clients.producer.ProducerConfig; +import org.apache.kafka.clients.producer.ProducerInterceptor; +import org.apache.kafka.clients.producer.ProducerRecord; +import org.apache.kafka.clients.producer.RecordMetadata; +import org.apache.kafka.common.ClusterResourceListener; +import org.apache.kafka.common.ClusterResource; +import org.apache.kafka.common.config.ConfigException; + +import java.util.Map; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; + +public class MockProducerInterceptor implements ClusterResourceListener, ProducerInterceptor { + public static final AtomicInteger INIT_COUNT = new AtomicInteger(0); + public static final AtomicInteger CLOSE_COUNT = new AtomicInteger(0); + public static final AtomicInteger ONSEND_COUNT = new AtomicInteger(0); + public static final AtomicInteger ON_SUCCESS_COUNT = new AtomicInteger(0); + public static final AtomicInteger ON_ERROR_COUNT = new AtomicInteger(0); + public static final AtomicInteger ON_ERROR_WITH_METADATA_COUNT = new AtomicInteger(0); + public static final AtomicReference CLUSTER_META = new AtomicReference<>(); + public static final ClusterResource NO_CLUSTER_ID = new ClusterResource("no_cluster_id"); + public static final AtomicReference CLUSTER_ID_BEFORE_ON_ACKNOWLEDGEMENT = new AtomicReference<>(NO_CLUSTER_ID); + public static final String APPEND_STRING_PROP = "mock.interceptor.append"; + private String appendStr; + + public MockProducerInterceptor() { + INIT_COUNT.incrementAndGet(); + } + + @Override + public void configure(Map configs) { + // ensure this method is called and expected configs are passed in + Object o = configs.get(APPEND_STRING_PROP); + if (o == null) + throw new ConfigException("Mock producer interceptor expects configuration " + APPEND_STRING_PROP); + if (o instanceof String) + appendStr = (String) o; + + // clientId also must be in configs + Object clientIdValue = configs.get(ProducerConfig.CLIENT_ID_CONFIG); + if (clientIdValue == null) + throw new ConfigException("Mock producer interceptor expects configuration " + ProducerConfig.CLIENT_ID_CONFIG); + } + + @Override + public ProducerRecord onSend(ProducerRecord record) { + ONSEND_COUNT.incrementAndGet(); + return new ProducerRecord<>( + record.topic(), record.partition(), record.key(), record.value().concat(appendStr)); + } + + @Override + public void onAcknowledgement(RecordMetadata metadata, Exception exception) { + // This will ensure that we get the cluster metadata when onAcknowledgement is called for the first time + // as subsequent compareAndSet operations will fail. 
+ CLUSTER_ID_BEFORE_ON_ACKNOWLEDGEMENT.compareAndSet(NO_CLUSTER_ID, CLUSTER_META.get()); + + if (exception != null) { + ON_ERROR_COUNT.incrementAndGet(); + if (metadata != null) { + ON_ERROR_WITH_METADATA_COUNT.incrementAndGet(); + } + } else if (metadata != null) + ON_SUCCESS_COUNT.incrementAndGet(); + } + + @Override + public void close() { + CLOSE_COUNT.incrementAndGet(); + } + + public static void resetCounters() { + INIT_COUNT.set(0); + CLOSE_COUNT.set(0); + ONSEND_COUNT.set(0); + ON_SUCCESS_COUNT.set(0); + ON_ERROR_COUNT.set(0); + ON_ERROR_WITH_METADATA_COUNT.set(0); + CLUSTER_META.set(null); + CLUSTER_ID_BEFORE_ON_ACKNOWLEDGEMENT.set(NO_CLUSTER_ID); + } + + @Override + public void onUpdate(ClusterResource clusterResource) { + CLUSTER_META.set(clusterResource); + } +} \ No newline at end of file diff --git a/clients/src/test/java/org/apache/kafka/test/MockSelector.java b/clients/src/test/java/org/apache/kafka/test/MockSelector.java new file mode 100644 index 0000000..d1d79dc --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/test/MockSelector.java @@ -0,0 +1,234 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
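MockProducerInterceptor additionally requires its mock.interceptor.append property, which onSend() concatenates onto every record value so tests can recognise intercepted records. A sketch of the producer-side wiring (the client id and suffix are illustrative):

    import java.util.Properties;
    import org.apache.kafka.clients.producer.ProducerConfig;

    Properties props = new Properties();
    props.put(ProducerConfig.CLIENT_ID_CONFIG, "interceptor-test-client");
    props.put(ProducerConfig.INTERCEPTOR_CLASSES_CONFIG, MockProducerInterceptor.class.getName());
    props.put(MockProducerInterceptor.APPEND_STRING_PROP, "-intercepted");  // appended to each value in onSend()
    // After sending, ON_SUCCESS_COUNT and ON_ERROR_COUNT reflect the acknowledgements seen.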
+ */ +package org.apache.kafka.test; + +import org.apache.kafka.common.errors.AuthenticationException; +import org.apache.kafka.common.network.ChannelState; +import org.apache.kafka.common.network.NetworkReceive; +import org.apache.kafka.common.network.NetworkSend; +import org.apache.kafka.common.network.Selectable; +import org.apache.kafka.common.requests.ByteBufferChannel; +import org.apache.kafka.common.utils.Time; + +import java.io.IOException; +import java.net.InetSocketAddress; +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.function.Predicate; + +/** + * A fake selector to use for testing + */ +public class MockSelector implements Selectable { + + private final Time time; + private final List initiatedSends = new ArrayList<>(); + private final List completedSends = new ArrayList<>(); + private final List completedSendBuffers = new ArrayList<>(); + private final List completedReceives = new ArrayList<>(); + private final Map disconnected = new HashMap<>(); + private final List connected = new ArrayList<>(); + private final List delayedReceives = new ArrayList<>(); + private final Predicate canConnect; + + public MockSelector(Time time) { + this(time, null); + } + + public MockSelector(Time time, Predicate canConnect) { + this.time = time; + this.canConnect = canConnect; + } + + @Override + public void connect(String id, InetSocketAddress address, int sendBufferSize, int receiveBufferSize) throws IOException { + if (canConnect == null || canConnect.test(address)) { + this.connected.add(id); + } + } + + @Override + public void wakeup() { + } + + @Override + public void close() { + } + + @Override + public void close(String id) { + // Note that there are no notifications for client-side disconnects + + removeSendsForNode(id, completedSends); + removeSendsForNode(id, initiatedSends); + + for (int i = 0; i < this.connected.size(); i++) { + if (this.connected.get(i).equals(id)) { + this.connected.remove(i); + break; + } + } + } + + /** + * Since MockSelector.connect will always succeed and add the + * connection id to the Set connected, we can only simulate + * that the connection is still pending by remove the connection + * id from the Set connected + * + * @param id connection id + */ + public void serverConnectionBlocked(String id) { + this.connected.remove(id); + } + + /** + * Simulate a server disconnect. This id will be present in {@link #disconnected()} on + * the next {@link #poll(long)}. 
+ */ + public void serverDisconnect(String id) { + this.disconnected.put(id, ChannelState.READY); + close(id); + } + + public void serverAuthenticationFailed(String id) { + ChannelState authFailed = new ChannelState(ChannelState.State.AUTHENTICATION_FAILED, + new AuthenticationException("Authentication failed"), null); + this.disconnected.put(id, authFailed); + close(id); + } + + private void removeSendsForNode(String id, Collection sends) { + sends.removeIf(send -> id.equals(send.destinationId())); + } + + public void clear() { + this.completedSends.clear(); + this.completedReceives.clear(); + this.completedSendBuffers.clear(); + this.disconnected.clear(); + this.connected.clear(); + } + + @Override + public void send(NetworkSend send) { + this.initiatedSends.add(send); + } + + @Override + public void poll(long timeout) throws IOException { + completeInitiatedSends(); + completeDelayedReceives(); + time.sleep(timeout); + } + + private void completeInitiatedSends() throws IOException { + for (NetworkSend send : initiatedSends) { + completeSend(send); + } + this.initiatedSends.clear(); + } + + private void completeSend(NetworkSend send) throws IOException { + // Consume the send so that we will be able to send more requests to the destination + try (ByteBufferChannel discardChannel = new ByteBufferChannel(send.size())) { + while (!send.completed()) { + send.writeTo(discardChannel); + } + completedSends.add(send); + completedSendBuffers.add(discardChannel); + } + } + + private void completeDelayedReceives() { + for (NetworkSend completedSend : completedSends) { + Iterator delayedReceiveIterator = delayedReceives.iterator(); + while (delayedReceiveIterator.hasNext()) { + DelayedReceive delayedReceive = delayedReceiveIterator.next(); + if (delayedReceive.source().equals(completedSend.destinationId())) { + completedReceives.add(delayedReceive.receive()); + delayedReceiveIterator.remove(); + } + } + } + } + + @Override + public List completedSends() { + return completedSends; + } + + public List completedSendBuffers() { + return completedSendBuffers; + } + + @Override + public List completedReceives() { + return completedReceives; + } + + public void completeReceive(NetworkReceive receive) { + this.completedReceives.add(receive); + } + + public void delayedReceive(DelayedReceive receive) { + this.delayedReceives.add(receive); + } + + @Override + public Map disconnected() { + return disconnected; + } + + @Override + public List connected() { + List currentConnected = new ArrayList<>(connected); + connected.clear(); + return currentConnected; + } + + @Override + public void mute(String id) { + } + + @Override + public void unmute(String id) { + } + + @Override + public void muteAll() { + } + + @Override + public void unmuteAll() { + } + + @Override + public boolean isChannelReady(String id) { + return true; + } + + public void reset() { + clear(); + initiatedSends.clear(); + delayedReceives.clear(); + } +} diff --git a/clients/src/test/java/org/apache/kafka/test/MockSerializer.java b/clients/src/test/java/org/apache/kafka/test/MockSerializer.java new file mode 100644 index 0000000..1c14445 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/test/MockSerializer.java @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
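Because MockSelector keeps sends, receives, and disconnects in plain collections instead of touching sockets, a test can script connection failures deterministically. A minimal sketch, intended for a test method that declares throws IOException (the node id and address are illustrative; MockTime is assumed from the common test utilities):

    import java.net.InetSocketAddress;
    import org.apache.kafka.common.utils.MockTime;

    MockSelector selector = new MockSelector(new MockTime());
    selector.connect("node-1", new InetSocketAddress("localhost", 9092), 64 * 1024, 64 * 1024);
    selector.serverDisconnect("node-1");          // recorded in disconnected() with ChannelState.READY
    selector.poll(0);                             // completes pending sends and advances the MockTime
    boolean sawDisconnect = selector.disconnected().containsKey("node-1");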
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.test; + +import org.apache.kafka.common.ClusterResourceListener; +import org.apache.kafka.common.ClusterResource; +import org.apache.kafka.common.serialization.Serializer; + +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; + +public class MockSerializer implements ClusterResourceListener, Serializer { + public static final AtomicInteger INIT_COUNT = new AtomicInteger(0); + public static final AtomicInteger CLOSE_COUNT = new AtomicInteger(0); + public static final AtomicReference CLUSTER_META = new AtomicReference<>(); + public static final ClusterResource NO_CLUSTER_ID = new ClusterResource("no_cluster_id"); + public static final AtomicReference CLUSTER_ID_BEFORE_SERIALIZE = new AtomicReference<>(NO_CLUSTER_ID); + + public MockSerializer() { + INIT_COUNT.incrementAndGet(); + } + + @Override + public byte[] serialize(String topic, byte[] data) { + // This will ensure that we get the cluster metadata when serialize is called for the first time + // as subsequent compareAndSet operations will fail. + CLUSTER_ID_BEFORE_SERIALIZE.compareAndSet(NO_CLUSTER_ID, CLUSTER_META.get()); + return data; + } + + @Override + public void close() { + CLOSE_COUNT.incrementAndGet(); + } + + @Override + public void onUpdate(ClusterResource clusterResource) { + CLUSTER_META.set(clusterResource); + } +} diff --git a/clients/src/test/java/org/apache/kafka/test/NoRetryException.java b/clients/src/test/java/org/apache/kafka/test/NoRetryException.java new file mode 100644 index 0000000..a6f7db9 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/test/NoRetryException.java @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.test; + +/** + * This class can be used in the callback given to {@link TestUtils#retryOnExceptionWithTimeout(long, long, ValuelessCallable)} + * to indicate that a particular exception should not be retried. Instead the retry operation will + * be aborted immediately and the exception will be rethrown. 
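MockSerializer (like MockDeserializer earlier in this patch) captures whichever ClusterResource was visible the first time serialize() ran; the compareAndSet means later metadata updates cannot overwrite that snapshot. A small illustration with a made-up cluster id; in real use onUpdate() is invoked by the client rather than by the test:

    import java.nio.charset.StandardCharsets;
    import org.apache.kafka.common.ClusterResource;

    MockSerializer serializer = new MockSerializer();
    serializer.onUpdate(new ClusterResource("test-cluster-id"));          // normally called by the client
    serializer.serialize("topic", "value".getBytes(StandardCharsets.UTF_8));
    // CLUSTER_ID_BEFORE_SERIALIZE now holds "test-cluster-id" and keeps it even if onUpdate() fires again.
    String observed = MockSerializer.CLUSTER_ID_BEFORE_SERIALIZE.get().clusterId();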
+ */ +public class NoRetryException extends RuntimeException { + private final Throwable cause; + + public NoRetryException(Throwable cause) { + this.cause = cause; + } + + @Override + public Throwable getCause() { + return this.cause; + } +} diff --git a/clients/src/test/java/org/apache/kafka/test/TestCondition.java b/clients/src/test/java/org/apache/kafka/test/TestCondition.java new file mode 100644 index 0000000..a30ee68 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/test/TestCondition.java @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.test; + +/** + * Interface to wrap actions that are required to wait until a condition is met + * for testing purposes. Note that this is not intended to do any assertions. + */ +@FunctionalInterface +public interface TestCondition { + + boolean conditionMet() throws Exception; +} diff --git a/clients/src/test/java/org/apache/kafka/test/TestSslUtils.java b/clients/src/test/java/org/apache/kafka/test/TestSslUtils.java new file mode 100644 index 0000000..fc72d3d --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/test/TestSslUtils.java @@ -0,0 +1,623 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
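TestCondition backs the polling helpers in TestUtils (added later in this patch), and NoRetryException lets the callable passed to retryOnExceptionWithTimeout abort retries early. A sketch of typical use; the exact waitForCondition overload and the helper methods named below are assumptions, not shown in this diff:

    import java.util.concurrent.atomic.AtomicBoolean;

    AtomicBoolean done = new AtomicBoolean(false);
    // ... start asynchronous work that eventually sets done ...
    TestUtils.waitForCondition(done::get, 15_000, "operation did not complete in time");

    TestUtils.retryOnExceptionWithTimeout(30_000, 100, () -> {
        if (fatalStateDetected())                       // hypothetical check
            throw new NoRetryException(new IllegalStateException("do not retry"));
        attemptFlakyOperation();                        // hypothetical flaky call
    });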
+ */ +package org.apache.kafka.test; + +import org.apache.kafka.common.config.SslConfigs; +import org.apache.kafka.common.config.types.Password; +import org.apache.kafka.common.network.Mode; +import org.apache.kafka.common.security.auth.SslEngineFactory; +import org.apache.kafka.common.security.ssl.DefaultSslEngineFactory; +import org.bouncycastle.asn1.DEROctetString; +import org.bouncycastle.asn1.DERSequence; +import org.bouncycastle.asn1.x500.X500Name; +import org.bouncycastle.asn1.x509.AlgorithmIdentifier; +import org.bouncycastle.asn1.x509.Extension; +import org.bouncycastle.asn1.x509.GeneralName; +import org.bouncycastle.asn1.x509.GeneralNames; +import org.bouncycastle.asn1.x509.SubjectPublicKeyInfo; +import org.bouncycastle.cert.X509CertificateHolder; +import org.bouncycastle.cert.X509v3CertificateBuilder; +import org.bouncycastle.cert.jcajce.JcaX509CertificateConverter; +import org.bouncycastle.crypto.params.AsymmetricKeyParameter; +import org.bouncycastle.crypto.util.PrivateKeyFactory; +import org.bouncycastle.jce.provider.BouncyCastleProvider; +import org.bouncycastle.openssl.PKCS8Generator; +import org.bouncycastle.openssl.jcajce.JcaMiscPEMGenerator; +import org.bouncycastle.openssl.jcajce.JcaPKCS8Generator; +import org.bouncycastle.openssl.jcajce.JceOpenSSLPKCS8EncryptorBuilder; +import org.bouncycastle.operator.ContentSigner; +import org.bouncycastle.operator.DefaultDigestAlgorithmIdentifierFinder; +import org.bouncycastle.operator.DefaultSignatureAlgorithmIdentifierFinder; +import org.bouncycastle.operator.bc.BcContentSignerBuilder; +import org.bouncycastle.operator.bc.BcDSAContentSignerBuilder; +import org.bouncycastle.operator.bc.BcECContentSignerBuilder; +import org.bouncycastle.operator.bc.BcRSAContentSignerBuilder; +import org.bouncycastle.util.io.pem.PemWriter; + +import java.io.ByteArrayOutputStream; +import java.io.EOFException; +import java.io.File; +import java.io.FileInputStream; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.io.OutputStreamWriter; +import java.math.BigInteger; +import java.net.InetAddress; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.nio.file.Paths; +import java.security.GeneralSecurityException; +import java.security.Key; +import java.security.KeyPair; +import java.security.KeyPairGenerator; +import java.security.KeyStore; +import java.security.NoSuchAlgorithmException; +import java.security.PrivateKey; +import java.security.SecureRandom; +import java.security.Security; +import java.security.cert.Certificate; +import java.security.cert.CertificateException; +import java.security.cert.X509Certificate; +import java.util.ArrayList; +import java.util.Collections; +import java.util.Date; +import java.util.Enumeration; +import java.util.HashMap; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Properties; +import java.util.Set; + +import javax.net.ssl.SSLEngine; +import javax.net.ssl.TrustManagerFactory; + +import static org.apache.kafka.common.security.ssl.DefaultSslEngineFactory.PEM_TYPE; + +public class TestSslUtils { + + public static final String TRUST_STORE_PASSWORD = "TrustStorePassword"; + public static final String DEFAULT_TLS_PROTOCOL_FOR_TESTS = SslConfigs.DEFAULT_SSL_PROTOCOL; + + /** + * Create a self-signed X.509 Certificate. + * From http://bfo.com/blog/2011/03/08/odds_and_ends_creating_a_new_x_509_certificate.html. 
+ * + * @param dn the X.509 Distinguished Name, eg "CN=Test, L=London, C=GB" + * @param pair the KeyPair + * @param days how many days from now the Certificate is valid for + * @param algorithm the signing algorithm, eg "SHA1withRSA" + * @return the self-signed certificate + * @throws CertificateException thrown if a security error or an IO error occurred. + */ + public static X509Certificate generateCertificate(String dn, KeyPair pair, + int days, String algorithm) + throws CertificateException { + return new CertificateBuilder(days, algorithm).generate(dn, pair); + } + + public static KeyPair generateKeyPair(String algorithm) throws NoSuchAlgorithmException { + KeyPairGenerator keyGen = KeyPairGenerator.getInstance(algorithm); + keyGen.initialize(algorithm.equals("EC") ? 256 : 2048); + return keyGen.genKeyPair(); + } + + private static KeyStore createEmptyKeyStore() throws GeneralSecurityException, IOException { + KeyStore ks = KeyStore.getInstance("JKS"); + ks.load(null, null); // initialize + return ks; + } + + private static void saveKeyStore(KeyStore ks, String filename, + Password password) throws GeneralSecurityException, IOException { + try (OutputStream out = Files.newOutputStream(Paths.get(filename))) { + ks.store(out, password.value().toCharArray()); + } + } + + /** + * Creates a keystore with a single key and saves it to a file. + * + * @param filename String file to save + * @param password String store password to set on keystore + * @param keyPassword String key password to set on key + * @param alias String alias to use for the key + * @param privateKey Key to save in keystore + * @param cert Certificate to use as certificate chain associated to key + * @throws GeneralSecurityException for any error with the security APIs + * @throws IOException if there is an I/O error saving the file + */ + public static void createKeyStore(String filename, + Password password, Password keyPassword, String alias, + Key privateKey, Certificate cert) throws GeneralSecurityException, IOException { + KeyStore ks = createEmptyKeyStore(); + ks.setKeyEntry(alias, privateKey, keyPassword.value().toCharArray(), + new Certificate[]{cert}); + saveKeyStore(ks, filename, password); + } + + public static void createTrustStore( + String filename, Password password, Map certs) throws GeneralSecurityException, IOException { + KeyStore ks = KeyStore.getInstance("JKS"); + try (InputStream in = Files.newInputStream(Paths.get(filename))) { + ks.load(in, password.value().toCharArray()); + } catch (EOFException e) { + ks = createEmptyKeyStore(); + } + for (Map.Entry cert : certs.entrySet()) { + ks.setCertificateEntry(cert.getKey(), cert.getValue()); + } + saveKeyStore(ks, filename, password); + } + + public static Map createSslConfig(String keyManagerAlgorithm, String trustManagerAlgorithm, String tlsProtocol) { + Map sslConfigs = new HashMap<>(); + sslConfigs.put(SslConfigs.SSL_PROTOCOL_CONFIG, tlsProtocol); // protocol to create SSLContext + + sslConfigs.put(SslConfigs.SSL_KEYMANAGER_ALGORITHM_CONFIG, keyManagerAlgorithm); + sslConfigs.put(SslConfigs.SSL_TRUSTMANAGER_ALGORITHM_CONFIG, trustManagerAlgorithm); + + List enabledProtocols = new ArrayList<>(); + enabledProtocols.add(tlsProtocol); + sslConfigs.put(SslConfigs.SSL_ENABLED_PROTOCOLS_CONFIG, enabledProtocols); + + return sslConfigs; + } + + public static Map createSslConfig(boolean useClientCert, boolean trustStore, Mode mode, File trustStoreFile, String certAlias) + throws IOException, GeneralSecurityException { + return createSslConfig(useClientCert, 
trustStore, mode, trustStoreFile, certAlias, "localhost"); + } + + public static Map createSslConfig(boolean useClientCert, boolean trustStore, + Mode mode, File trustStoreFile, String certAlias, String cn) + throws IOException, GeneralSecurityException { + return createSslConfig(useClientCert, trustStore, mode, trustStoreFile, certAlias, cn, new CertificateBuilder()); + } + + public static Map createSslConfig(boolean useClientCert, boolean createTrustStore, + Mode mode, File trustStoreFile, String certAlias, String cn, CertificateBuilder certBuilder) + throws IOException, GeneralSecurityException { + SslConfigsBuilder builder = new SslConfigsBuilder(mode) + .useClientCert(useClientCert) + .certAlias(certAlias) + .cn(cn) + .certBuilder(certBuilder); + if (createTrustStore) + builder = builder.createNewTrustStore(trustStoreFile); + else + builder = builder.useExistingTrustStore(trustStoreFile); + return builder.build(); + } + + public static void convertToPem(Map sslProps, boolean writeToFile, boolean encryptPrivateKey) throws Exception { + String tsPath = (String) sslProps.get(SslConfigs.SSL_TRUSTSTORE_LOCATION_CONFIG); + String tsType = (String) sslProps.get(SslConfigs.SSL_TRUSTSTORE_TYPE_CONFIG); + Password tsPassword = (Password) sslProps.remove(SslConfigs.SSL_TRUSTSTORE_PASSWORD_CONFIG); + Password trustCerts = (Password) sslProps.remove(SslConfigs.SSL_TRUSTSTORE_CERTIFICATES_CONFIG); + if (trustCerts == null && tsPath != null) { + trustCerts = exportCertificates(tsPath, tsPassword, tsType); + } + if (trustCerts != null) { + if (tsPath == null) { + tsPath = File.createTempFile("truststore", ".pem").getPath(); + sslProps.put(SslConfigs.SSL_TRUSTSTORE_LOCATION_CONFIG, tsPath); + } + sslProps.put(SslConfigs.SSL_TRUSTSTORE_TYPE_CONFIG, PEM_TYPE); + if (writeToFile) + writeToFile(tsPath, trustCerts); + else { + sslProps.put(SslConfigs.SSL_TRUSTSTORE_CERTIFICATES_CONFIG, trustCerts); + sslProps.remove(SslConfigs.SSL_TRUSTSTORE_LOCATION_CONFIG); + } + } + + String ksPath = (String) sslProps.get(SslConfigs.SSL_KEYSTORE_LOCATION_CONFIG); + Password certChain = (Password) sslProps.remove(SslConfigs.SSL_KEYSTORE_CERTIFICATE_CHAIN_CONFIG); + Password key = (Password) sslProps.remove(SslConfigs.SSL_KEYSTORE_KEY_CONFIG); + if (certChain == null && ksPath != null) { + String ksType = (String) sslProps.get(SslConfigs.SSL_KEYSTORE_TYPE_CONFIG); + Password ksPassword = (Password) sslProps.remove(SslConfigs.SSL_KEYSTORE_PASSWORD_CONFIG); + Password keyPassword = (Password) sslProps.get(SslConfigs.SSL_KEY_PASSWORD_CONFIG); + certChain = exportCertificates(ksPath, ksPassword, ksType); + Password pemKeyPassword = encryptPrivateKey ? keyPassword : null; + key = exportPrivateKey(ksPath, ksPassword, keyPassword, ksType, pemKeyPassword); + if (!encryptPrivateKey) + sslProps.remove(SslConfigs.SSL_KEY_PASSWORD_CONFIG); + } + + if (certChain != null) { + if (ksPath == null) { + ksPath = File.createTempFile("keystore", ".pem").getPath(); + sslProps.put(SslConfigs.SSL_KEYSTORE_LOCATION_CONFIG, ksPath); + } + sslProps.put(SslConfigs.SSL_KEYSTORE_TYPE_CONFIG, PEM_TYPE); + if (writeToFile) + writeToFile(ksPath, key, certChain); + else { + sslProps.put(SslConfigs.SSL_KEYSTORE_KEY_CONFIG, key); + sslProps.put(SslConfigs.SSL_KEYSTORE_CERTIFICATE_CHAIN_CONFIG, certChain); + sslProps.remove(SslConfigs.SSL_KEYSTORE_LOCATION_CONFIG); + } + } + } + + private static void writeToFile(String path, Password... 
entries) throws IOException { + try (FileOutputStream out = new FileOutputStream(path)) { + for (Password entry: entries) { + out.write(entry.value().getBytes(StandardCharsets.UTF_8)); + } + } + } + + public static void convertToPemWithoutFiles(Properties sslProps) throws Exception { + String tsPath = sslProps.getProperty(SslConfigs.SSL_TRUSTSTORE_LOCATION_CONFIG); + if (tsPath != null) { + Password trustCerts = exportCertificates(tsPath, + (Password) sslProps.get(SslConfigs.SSL_TRUSTSTORE_PASSWORD_CONFIG), + sslProps.getProperty(SslConfigs.SSL_TRUSTSTORE_TYPE_CONFIG)); + sslProps.remove(SslConfigs.SSL_TRUSTSTORE_LOCATION_CONFIG); + sslProps.remove(SslConfigs.SSL_TRUSTSTORE_PASSWORD_CONFIG); + sslProps.setProperty(SslConfigs.SSL_TRUSTSTORE_TYPE_CONFIG, PEM_TYPE); + sslProps.put(SslConfigs.SSL_TRUSTSTORE_CERTIFICATES_CONFIG, trustCerts); + } + String ksPath = sslProps.getProperty(SslConfigs.SSL_KEYSTORE_LOCATION_CONFIG); + if (ksPath != null) { + String ksType = sslProps.getProperty(SslConfigs.SSL_KEYSTORE_TYPE_CONFIG); + Password ksPassword = (Password) sslProps.get(SslConfigs.SSL_KEYSTORE_PASSWORD_CONFIG); + Password keyPassword = (Password) sslProps.get(SslConfigs.SSL_KEY_PASSWORD_CONFIG); + Password certChain = exportCertificates(ksPath, ksPassword, ksType); + Password key = exportPrivateKey(ksPath, ksPassword, keyPassword, ksType, keyPassword); + sslProps.remove(SslConfigs.SSL_KEYSTORE_LOCATION_CONFIG); + sslProps.remove(SslConfigs.SSL_KEYSTORE_PASSWORD_CONFIG); + sslProps.setProperty(SslConfigs.SSL_KEYSTORE_TYPE_CONFIG, PEM_TYPE); + sslProps.put(SslConfigs.SSL_KEYSTORE_CERTIFICATE_CHAIN_CONFIG, certChain); + sslProps.put(SslConfigs.SSL_KEYSTORE_KEY_CONFIG, key); + } + } + + public static Password exportCertificates(String storePath, Password storePassword, String storeType) throws Exception { + StringBuilder builder = new StringBuilder(); + try (FileInputStream in = new FileInputStream(storePath)) { + KeyStore ks = KeyStore.getInstance(storeType); + ks.load(in, storePassword.value().toCharArray()); + Enumeration aliases = ks.aliases(); + if (!aliases.hasMoreElements()) + throw new IllegalArgumentException("No certificates found in file " + storePath); + while (aliases.hasMoreElements()) { + String alias = aliases.nextElement(); + Certificate[] certs = ks.getCertificateChain(alias); + if (certs != null) { + for (Certificate cert : certs) { + builder.append(pem(cert)); + } + } else { + builder.append(pem(ks.getCertificate(alias))); + } + } + } + return new Password(builder.toString()); + } + + public static Password exportPrivateKey(String storePath, + Password storePassword, + Password keyPassword, + String storeType, + Password pemKeyPassword) throws Exception { + try (FileInputStream in = new FileInputStream(storePath)) { + KeyStore ks = KeyStore.getInstance(storeType); + ks.load(in, storePassword.value().toCharArray()); + String alias = ks.aliases().nextElement(); + return new Password(pem((PrivateKey) ks.getKey(alias, keyPassword.value().toCharArray()), pemKeyPassword)); + } + } + + static String pem(Certificate cert) throws IOException { + ByteArrayOutputStream out = new ByteArrayOutputStream(); + try (PemWriter pemWriter = new PemWriter(new OutputStreamWriter(out, StandardCharsets.UTF_8.name()))) { + pemWriter.writeObject(new JcaMiscPEMGenerator(cert)); + } + return new String(out.toByteArray(), StandardCharsets.UTF_8); + } + + static String pem(PrivateKey privateKey, Password password) throws IOException { + ByteArrayOutputStream out = new ByteArrayOutputStream(); + try 
(PemWriter pemWriter = new PemWriter(new OutputStreamWriter(out, StandardCharsets.UTF_8.name()))) { + if (password == null) { + pemWriter.writeObject(new JcaPKCS8Generator(privateKey, null)); + } else { + JceOpenSSLPKCS8EncryptorBuilder encryptorBuilder = new JceOpenSSLPKCS8EncryptorBuilder(PKCS8Generator.PBE_SHA1_3DES); + encryptorBuilder.setPasssword(password.value().toCharArray()); + try { + pemWriter.writeObject(new JcaPKCS8Generator(privateKey, encryptorBuilder.build())); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + } + return new String(out.toByteArray(), StandardCharsets.UTF_8); + } + + public static class CertificateBuilder { + private final int days; + private final String algorithm; + private byte[] subjectAltName; + + public CertificateBuilder() { + this(30, "SHA1withRSA"); + } + + public CertificateBuilder(int days, String algorithm) { + this.days = days; + this.algorithm = algorithm; + } + + public CertificateBuilder sanDnsNames(String... hostNames) throws IOException { + GeneralName[] altNames = new GeneralName[hostNames.length]; + for (int i = 0; i < hostNames.length; i++) + altNames[i] = new GeneralName(GeneralName.dNSName, hostNames[i]); + subjectAltName = GeneralNames.getInstance(new DERSequence(altNames)).getEncoded(); + return this; + } + + public CertificateBuilder sanIpAddress(InetAddress hostAddress) throws IOException { + subjectAltName = new GeneralNames(new GeneralName(GeneralName.iPAddress, new DEROctetString(hostAddress.getAddress()))).getEncoded(); + return this; + } + + public X509Certificate generate(String dn, KeyPair keyPair) throws CertificateException { + try { + Security.addProvider(new BouncyCastleProvider()); + AlgorithmIdentifier sigAlgId = new DefaultSignatureAlgorithmIdentifierFinder().find(algorithm); + AlgorithmIdentifier digAlgId = new DefaultDigestAlgorithmIdentifierFinder().find(sigAlgId); + AsymmetricKeyParameter privateKeyAsymKeyParam = PrivateKeyFactory.createKey(keyPair.getPrivate().getEncoded()); + SubjectPublicKeyInfo subPubKeyInfo = SubjectPublicKeyInfo.getInstance(keyPair.getPublic().getEncoded()); + BcContentSignerBuilder signerBuilder; + String keyAlgorithm = keyPair.getPublic().getAlgorithm(); + if (keyAlgorithm.equals("RSA")) + signerBuilder = new BcRSAContentSignerBuilder(sigAlgId, digAlgId); + else if (keyAlgorithm.equals("DSA")) + signerBuilder = new BcDSAContentSignerBuilder(sigAlgId, digAlgId); + else if (keyAlgorithm.equals("EC")) + signerBuilder = new BcECContentSignerBuilder(sigAlgId, digAlgId); + else + throw new IllegalArgumentException("Unsupported algorithm " + keyAlgorithm); + ContentSigner sigGen = signerBuilder.build(privateKeyAsymKeyParam); + X500Name name = new X500Name(dn); + Date from = new Date(); + Date to = new Date(from.getTime() + days * 86400000L); + BigInteger sn = new BigInteger(64, new SecureRandom()); + X509v3CertificateBuilder v3CertGen = new X509v3CertificateBuilder(name, sn, from, to, name, subPubKeyInfo); + + if (subjectAltName != null) + v3CertGen.addExtension(Extension.subjectAlternativeName, false, subjectAltName); + X509CertificateHolder certificateHolder = v3CertGen.build(sigGen); + return new JcaX509CertificateConverter().setProvider("BC").getCertificate(certificateHolder); + } catch (CertificateException ce) { + throw ce; + } catch (Exception e) { + throw new CertificateException(e); + } + } + } + + public static class SslConfigsBuilder { + final Mode mode; + String tlsProtocol; + boolean useClientCert; + boolean createTrustStore; + File trustStoreFile; + Password 
trustStorePassword; + Password keyStorePassword; + Password keyPassword; + String certAlias; + String cn; + String algorithm; + CertificateBuilder certBuilder; + boolean usePem; + + public SslConfigsBuilder(Mode mode) { + this.mode = mode; + this.tlsProtocol = DEFAULT_TLS_PROTOCOL_FOR_TESTS; + trustStorePassword = new Password(TRUST_STORE_PASSWORD); + keyStorePassword = mode == Mode.SERVER ? new Password("ServerPassword") : new Password("ClientPassword"); + keyPassword = keyStorePassword; + this.certBuilder = new CertificateBuilder(); + this.cn = "localhost"; + this.certAlias = mode.name().toLowerCase(Locale.ROOT); + this.algorithm = "RSA"; + this.createTrustStore = true; + } + + public SslConfigsBuilder tlsProtocol(String tlsProtocol) { + this.tlsProtocol = tlsProtocol; + return this; + } + + public SslConfigsBuilder createNewTrustStore(File trustStoreFile) { + this.trustStoreFile = trustStoreFile; + this.createTrustStore = true; + return this; + } + + public SslConfigsBuilder useExistingTrustStore(File trustStoreFile) { + this.trustStoreFile = trustStoreFile; + this.createTrustStore = false; + return this; + } + + public SslConfigsBuilder useClientCert(boolean useClientCert) { + this.useClientCert = useClientCert; + return this; + } + + public SslConfigsBuilder certAlias(String certAlias) { + this.certAlias = certAlias; + return this; + } + + public SslConfigsBuilder cn(String cn) { + this.cn = cn; + return this; + } + + public SslConfigsBuilder algorithm(String algorithm) { + this.algorithm = algorithm; + return this; + } + + public SslConfigsBuilder certBuilder(CertificateBuilder certBuilder) { + this.certBuilder = certBuilder; + return this; + } + + public SslConfigsBuilder usePem(boolean usePem) { + this.usePem = usePem; + return this; + } + + public Map build() throws IOException, GeneralSecurityException { + if (usePem) { + return buildPem(); + } else + return buildJks(); + } + + private Map buildJks() throws IOException, GeneralSecurityException { + Map certs = new HashMap<>(); + File keyStoreFile = null; + + if (mode == Mode.CLIENT && useClientCert) { + keyStoreFile = File.createTempFile("clientKS", ".jks"); + KeyPair cKP = generateKeyPair(algorithm); + X509Certificate cCert = certBuilder.generate("CN=" + cn + ", O=A client", cKP); + createKeyStore(keyStoreFile.getPath(), keyStorePassword, keyPassword, "client", cKP.getPrivate(), cCert); + certs.put(certAlias, cCert); + } else if (mode == Mode.SERVER) { + keyStoreFile = File.createTempFile("serverKS", ".jks"); + KeyPair sKP = generateKeyPair(algorithm); + X509Certificate sCert = certBuilder.generate("CN=" + cn + ", O=A server", sKP); + createKeyStore(keyStoreFile.getPath(), keyStorePassword, keyPassword, "server", sKP.getPrivate(), sCert); + certs.put(certAlias, sCert); + keyStoreFile.deleteOnExit(); + } + + if (createTrustStore) { + createTrustStore(trustStoreFile.getPath(), trustStorePassword, certs); + trustStoreFile.deleteOnExit(); + } + + Map sslConfigs = new HashMap<>(); + + sslConfigs.put(SslConfigs.SSL_PROTOCOL_CONFIG, tlsProtocol); // protocol to create SSLContext + + if (mode == Mode.SERVER || (mode == Mode.CLIENT && keyStoreFile != null)) { + sslConfigs.put(SslConfigs.SSL_KEYSTORE_LOCATION_CONFIG, keyStoreFile.getPath()); + sslConfigs.put(SslConfigs.SSL_KEYSTORE_TYPE_CONFIG, "JKS"); + sslConfigs.put(SslConfigs.SSL_KEYMANAGER_ALGORITHM_CONFIG, TrustManagerFactory.getDefaultAlgorithm()); + sslConfigs.put(SslConfigs.SSL_KEYSTORE_PASSWORD_CONFIG, keyStorePassword); + sslConfigs.put(SslConfigs.SSL_KEY_PASSWORD_CONFIG, 
keyPassword); + } + + sslConfigs.put(SslConfigs.SSL_TRUSTSTORE_LOCATION_CONFIG, trustStoreFile.getPath()); + sslConfigs.put(SslConfigs.SSL_TRUSTSTORE_PASSWORD_CONFIG, trustStorePassword); + sslConfigs.put(SslConfigs.SSL_TRUSTSTORE_TYPE_CONFIG, "JKS"); + sslConfigs.put(SslConfigs.SSL_TRUSTMANAGER_ALGORITHM_CONFIG, TrustManagerFactory.getDefaultAlgorithm()); + + List enabledProtocols = new ArrayList<>(); + enabledProtocols.add(tlsProtocol); + sslConfigs.put(SslConfigs.SSL_ENABLED_PROTOCOLS_CONFIG, enabledProtocols); + + return sslConfigs; + } + + private Map buildPem() throws IOException, GeneralSecurityException { + if (!createTrustStore) { + throw new IllegalArgumentException("PEM configs cannot be created with existing trust stores"); + } + + Map sslConfigs = new HashMap<>(); + sslConfigs.put(SslConfigs.SSL_PROTOCOL_CONFIG, tlsProtocol); + sslConfigs.put(SslConfigs.SSL_ENABLED_PROTOCOLS_CONFIG, Collections.singletonList(tlsProtocol)); + + if (mode != Mode.CLIENT || useClientCert) { + KeyPair keyPair = generateKeyPair(algorithm); + X509Certificate cert = certBuilder.generate("CN=" + cn + ", O=A " + mode.name().toLowerCase(Locale.ROOT), keyPair); + + Password privateKeyPem = new Password(pem(keyPair.getPrivate(), keyPassword)); + Password certPem = new Password(pem(cert)); + sslConfigs.put(SslConfigs.SSL_KEYSTORE_TYPE_CONFIG, PEM_TYPE); + sslConfigs.put(SslConfigs.SSL_TRUSTSTORE_TYPE_CONFIG, PEM_TYPE); + sslConfigs.put(SslConfigs.SSL_KEYSTORE_KEY_CONFIG, privateKeyPem); + sslConfigs.put(SslConfigs.SSL_KEYSTORE_CERTIFICATE_CHAIN_CONFIG, certPem); + sslConfigs.put(SslConfigs.SSL_KEY_PASSWORD_CONFIG, keyPassword); + sslConfigs.put(SslConfigs.SSL_TRUSTSTORE_CERTIFICATES_CONFIG, certPem); + } + return sslConfigs; + } + } + + public static final class TestSslEngineFactory implements SslEngineFactory { + + public boolean closed = false; + + DefaultSslEngineFactory defaultSslEngineFactory = new DefaultSslEngineFactory(); + + @Override + public SSLEngine createClientSslEngine(String peerHost, int peerPort, String endpointIdentification) { + return defaultSslEngineFactory.createClientSslEngine(peerHost, peerPort, endpointIdentification); + } + + @Override + public SSLEngine createServerSslEngine(String peerHost, int peerPort) { + return defaultSslEngineFactory.createServerSslEngine(peerHost, peerPort); + } + + @Override + public boolean shouldBeRebuilt(Map nextConfigs) { + return defaultSslEngineFactory.shouldBeRebuilt(nextConfigs); + } + + @Override + public Set reconfigurableConfigs() { + return defaultSslEngineFactory.reconfigurableConfigs(); + } + + @Override + public KeyStore keystore() { + return defaultSslEngineFactory.keystore(); + } + + @Override + public KeyStore truststore() { + return defaultSslEngineFactory.truststore(); + } + + @Override + public void close() throws IOException { + defaultSslEngineFactory.close(); + closed = true; + } + + @Override + public void configure(Map configs) { + defaultSslEngineFactory.configure(configs); + } + } +} diff --git a/clients/src/test/java/org/apache/kafka/test/TestUtils.java b/clients/src/test/java/org/apache/kafka/test/TestUtils.java new file mode 100644 index 0000000..3c819be --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/test/TestUtils.java @@ -0,0 +1,579 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.test; + +import org.apache.kafka.clients.consumer.ConsumerConfig; +import org.apache.kafka.clients.producer.ProducerConfig; +import org.apache.kafka.common.Cluster; +import org.apache.kafka.common.Node; +import org.apache.kafka.common.PartitionInfo; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.network.NetworkReceive; +import org.apache.kafka.common.network.Send; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.record.UnalignedRecords; +import org.apache.kafka.common.requests.ByteBufferChannel; +import org.apache.kafka.common.requests.RequestHeader; +import org.apache.kafka.common.utils.Exit; +import org.apache.kafka.common.utils.Utils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.File; +import java.io.FileWriter; +import java.io.IOException; +import java.lang.reflect.Field; +import java.nio.ByteBuffer; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.ArrayList; +import java.util.Base64; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.Properties; +import java.util.Random; +import java.util.Set; +import java.util.UUID; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Future; +import java.util.function.Consumer; +import java.util.function.Supplier; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +import static java.util.Arrays.asList; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; + +/** + * Helper functions for writing unit tests + */ +public class TestUtils { + private static final Logger log = LoggerFactory.getLogger(TestUtils.class); + + public static final File IO_TMP_DIR = new File(System.getProperty("java.io.tmpdir")); + + public static final String LETTERS = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"; + public static final String DIGITS = "0123456789"; + public static final String LETTERS_AND_DIGITS = LETTERS + DIGITS; + + /* A consistent random number generator to make tests repeatable */ + public static final Random SEEDED_RANDOM = new Random(192348092834L); + public static final Random RANDOM = new Random(); + public static final long DEFAULT_POLL_INTERVAL_MS = 100; + public static final long DEFAULT_MAX_WAIT_MS = 15000; + + public static Cluster singletonCluster() { + return clusterWith(1); + } + + public static Cluster singletonCluster(final String topic, final int partitions) { + return clusterWith(1, topic, partitions); + } + + 
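+    // Illustrative usage sketch: a test that needs a fake metadata view can build one from the
+    // cluster helpers in this class. The topic name and the node/partition counts below are
+    // arbitrary example values, and this helper is not referenced by any test.
+    private static Cluster exampleTwoNodeCluster() {
+        // Two brokers, one topic ("example-topic") with four partitions spread across them.
+        return clusterWith(2, "example-topic", 4);
+    }
+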
public static Cluster clusterWith(int nodes) { + return clusterWith(nodes, new HashMap<>()); + } + + public static Cluster clusterWith(final int nodes, final Map topicPartitionCounts) { + final Node[] ns = new Node[nodes]; + for (int i = 0; i < nodes; i++) + ns[i] = new Node(i, "localhost", 1969); + final List parts = new ArrayList<>(); + for (final Map.Entry topicPartition : topicPartitionCounts.entrySet()) { + final String topic = topicPartition.getKey(); + final int partitions = topicPartition.getValue(); + for (int i = 0; i < partitions; i++) + parts.add(new PartitionInfo(topic, i, ns[i % ns.length], ns, ns)); + } + return new Cluster("kafka-cluster", asList(ns), parts, Collections.emptySet(), Collections.emptySet()); + } + + public static Cluster clusterWith(final int nodes, final String topic, final int partitions) { + return clusterWith(nodes, Collections.singletonMap(topic, partitions)); + } + + /** + * Generate an array of random bytes + * + * @param size The size of the array + */ + public static byte[] randomBytes(final int size) { + final byte[] bytes = new byte[size]; + SEEDED_RANDOM.nextBytes(bytes); + return bytes; + } + + /** + * Generate a random string of letters and digits of the given length + * + * @param len The length of the string + * @return The random string + */ + public static String randomString(final int len) { + final StringBuilder b = new StringBuilder(); + for (int i = 0; i < len; i++) + b.append(LETTERS_AND_DIGITS.charAt(SEEDED_RANDOM.nextInt(LETTERS_AND_DIGITS.length()))); + return b.toString(); + } + + /** + * Create an empty file in the default temporary-file directory, using `kafka` as the prefix and `tmp` as the + * suffix to generate its name. + */ + public static File tempFile() throws IOException { + final File file = File.createTempFile("kafka", ".tmp"); + file.deleteOnExit(); + + return file; + } + + /** + * Create a file with the given contents in the default temporary-file directory, + * using `kafka` as the prefix and `tmp` as the suffix to generate its name. + */ + public static File tempFile(final String contents) throws IOException { + final File file = tempFile(); + final FileWriter writer = new FileWriter(file); + writer.write(contents); + writer.close(); + + return file; + } + + /** + * Create a temporary relative directory in the default temporary-file directory with the given prefix. + * + * @param prefix The prefix of the temporary directory, if null using "kafka-" as default prefix + */ + public static File tempDirectory(final String prefix) { + return tempDirectory(null, prefix); + } + + /** + * Create a temporary relative directory in the default temporary-file directory with a + * prefix of "kafka-" + * + * @return the temporary directory just created. + */ + public static File tempDirectory() { + return tempDirectory(null); + } + + /** + * Create a temporary relative directory in the specified parent directory with the given prefix. + * + * @param parent The parent folder path name, if null using the default temporary-file directory + * @param prefix The prefix of the temporary directory, if null using "kafka-" as default prefix + */ + public static File tempDirectory(final Path parent, String prefix) { + final File file; + prefix = prefix == null ? "kafka-" : prefix; + try { + file = parent == null ? 
+ Files.createTempDirectory(prefix).toFile() : Files.createTempDirectory(parent, prefix).toFile(); + } catch (final IOException ex) { + throw new RuntimeException("Failed to create a temp dir", ex); + } + file.deleteOnExit(); + + Exit.addShutdownHook("delete-temp-file-shutdown-hook", () -> { + try { + Utils.delete(file); + } catch (IOException e) { + log.error("Error deleting {}", file.getAbsolutePath(), e); + } + }); + + return file; + } + + public static Properties producerConfig(final String bootstrapServers, + final Class keySerializer, + final Class valueSerializer, + final Properties additional) { + final Properties properties = new Properties(); + properties.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, bootstrapServers); + properties.put(ProducerConfig.ACKS_CONFIG, "all"); + properties.put(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG, keySerializer); + properties.put(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG, valueSerializer); + properties.putAll(additional); + return properties; + } + + public static Properties producerConfig(final String bootstrapServers, final Class keySerializer, final Class valueSerializer) { + return producerConfig(bootstrapServers, keySerializer, valueSerializer, new Properties()); + } + + public static Properties consumerConfig(final String bootstrapServers, + final String groupId, + final Class keyDeserializer, + final Class valueDeserializer, + final Properties additional) { + + final Properties consumerConfig = new Properties(); + consumerConfig.put(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, bootstrapServers); + consumerConfig.put(ConsumerConfig.GROUP_ID_CONFIG, groupId); + consumerConfig.put(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "earliest"); + consumerConfig.put(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, keyDeserializer); + consumerConfig.put(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, valueDeserializer); + consumerConfig.putAll(additional); + return consumerConfig; + } + + public static Properties consumerConfig(final String bootstrapServers, + final String groupId, + final Class keyDeserializer, + final Class valueDeserializer) { + return consumerConfig(bootstrapServers, + groupId, + keyDeserializer, + valueDeserializer, + new Properties()); + } + + /** + * returns consumer config with random UUID for the Group ID + */ + public static Properties consumerConfig(final String bootstrapServers, final Class keyDeserializer, final Class valueDeserializer) { + return consumerConfig(bootstrapServers, + UUID.randomUUID().toString(), + keyDeserializer, + valueDeserializer, + new Properties()); + } + + /** + * uses default value of 15 seconds for timeout + */ + public static void waitForCondition(final TestCondition testCondition, final String conditionDetails) throws InterruptedException { + waitForCondition(testCondition, DEFAULT_MAX_WAIT_MS, () -> conditionDetails); + } + + /** + * uses default value of 15 seconds for timeout + */ + public static void waitForCondition(final TestCondition testCondition, final Supplier conditionDetailsSupplier) throws InterruptedException { + waitForCondition(testCondition, DEFAULT_MAX_WAIT_MS, conditionDetailsSupplier); + } + + /** + * Wait for condition to be met for at most {@code maxWaitMs} and throw assertion failure otherwise. + * This should be used instead of {@code Thread.sleep} whenever possible as it allows a longer timeout to be used + * without unnecessarily increasing test time (as the condition is checked frequently). 
The longer timeout is needed to + * avoid transient failures due to slow or overloaded machines. + */ + public static void waitForCondition(final TestCondition testCondition, final long maxWaitMs, String conditionDetails) throws InterruptedException { + waitForCondition(testCondition, maxWaitMs, () -> conditionDetails); + } + + /** + * Wait for condition to be met for at most {@code maxWaitMs} and throw assertion failure otherwise. + * This should be used instead of {@code Thread.sleep} whenever possible as it allows a longer timeout to be used + * without unnecessarily increasing test time (as the condition is checked frequently). The longer timeout is needed to + * avoid transient failures due to slow or overloaded machines. + */ + public static void waitForCondition(final TestCondition testCondition, final long maxWaitMs, Supplier conditionDetailsSupplier) throws InterruptedException { + waitForCondition(testCondition, maxWaitMs, DEFAULT_POLL_INTERVAL_MS, conditionDetailsSupplier); + } + + /** + * Wait for condition to be met for at most {@code maxWaitMs} with a polling interval of {@code pollIntervalMs} + * and throw assertion failure otherwise. This should be used instead of {@code Thread.sleep} whenever possible + * as it allows a longer timeout to be used without unnecessarily increasing test time (as the condition is + * checked frequently). The longer timeout is needed to avoid transient failures due to slow or overloaded + * machines. + */ + public static void waitForCondition( + final TestCondition testCondition, + final long maxWaitMs, + final long pollIntervalMs, + Supplier conditionDetailsSupplier + ) throws InterruptedException { + retryOnExceptionWithTimeout(maxWaitMs, pollIntervalMs, () -> { + String conditionDetailsSupplied = conditionDetailsSupplier != null ? conditionDetailsSupplier.get() : null; + String conditionDetails = conditionDetailsSupplied != null ? conditionDetailsSupplied : ""; + assertTrue(testCondition.conditionMet(), + "Condition not met within timeout " + maxWaitMs + ". " + conditionDetails); + }); + } + + /** + * Wait for the given runnable to complete successfully, i.e. throw now {@link Exception}s or + * {@link AssertionError}s, or for the given timeout to expire. If the timeout expires then the + * last exception or assertion failure will be thrown thus providing context for the failure. + * + * @param timeoutMs the total time in milliseconds to wait for {@code runnable} to complete successfully. + * @param runnable the code to attempt to execute successfully. + * @throws InterruptedException if the current thread is interrupted while waiting for {@code runnable} to complete successfully. + */ + public static void retryOnExceptionWithTimeout(final long timeoutMs, + final ValuelessCallable runnable) throws InterruptedException { + retryOnExceptionWithTimeout(timeoutMs, DEFAULT_POLL_INTERVAL_MS, runnable); + } + + /** + * Wait for the given runnable to complete successfully, i.e. throw now {@link Exception}s or + * {@link AssertionError}s, or for the default timeout to expire. If the timeout expires then the + * last exception or assertion failure will be thrown thus providing context for the failure. + * + * @param runnable the code to attempt to execute successfully. + * @throws InterruptedException if the current thread is interrupted while waiting for {@code runnable} to complete successfully. 
+ */ + public static void retryOnExceptionWithTimeout(final ValuelessCallable runnable) throws InterruptedException { + retryOnExceptionWithTimeout(DEFAULT_MAX_WAIT_MS, DEFAULT_POLL_INTERVAL_MS, runnable); + } + + /** + * Wait for the given runnable to complete successfully, i.e. throw now {@link Exception}s or + * {@link AssertionError}s, or for the given timeout to expire. If the timeout expires then the + * last exception or assertion failure will be thrown thus providing context for the failure. + * + * @param timeoutMs the total time in milliseconds to wait for {@code runnable} to complete successfully. + * @param pollIntervalMs the interval in milliseconds to wait between invoking {@code runnable}. + * @param runnable the code to attempt to execute successfully. + * @throws InterruptedException if the current thread is interrupted while waiting for {@code runnable} to complete successfully. + */ + public static void retryOnExceptionWithTimeout(final long timeoutMs, + final long pollIntervalMs, + final ValuelessCallable runnable) throws InterruptedException { + final long expectedEnd = System.currentTimeMillis() + timeoutMs; + + while (true) { + try { + runnable.call(); + return; + } catch (final NoRetryException e) { + throw e; + } catch (final AssertionError t) { + if (expectedEnd <= System.currentTimeMillis()) { + throw t; + } + } catch (final Exception e) { + if (expectedEnd <= System.currentTimeMillis()) { + throw new AssertionError(String.format("Assertion failed with an exception after %s ms", timeoutMs), e); + } + } + Thread.sleep(Math.min(pollIntervalMs, timeoutMs)); + } + } + + /** + * Checks if a cluster id is valid. + * @param clusterId + */ + public static void isValidClusterId(String clusterId) { + assertNotNull(clusterId); + + // Base 64 encoded value is 22 characters + assertEquals(clusterId.length(), 22); + + Pattern clusterIdPattern = Pattern.compile("[a-zA-Z0-9_\\-]+"); + Matcher matcher = clusterIdPattern.matcher(clusterId); + assertTrue(matcher.matches()); + + // Convert into normal variant and add padding at the end. + String originalClusterId = String.format("%s==", clusterId.replace("_", "/").replace("-", "+")); + byte[] decodedUuid = Base64.getDecoder().decode(originalClusterId); + + // We expect 16 bytes, same as the input UUID. + assertEquals(decodedUuid.length, 16); + + //Check if it can be converted back to a UUID. + try { + ByteBuffer uuidBuffer = ByteBuffer.wrap(decodedUuid); + new UUID(uuidBuffer.getLong(), uuidBuffer.getLong()).toString(); + } catch (Exception e) { + fail(clusterId + " cannot be converted back to UUID."); + } + } + + /** + * Checks the two iterables for equality by first converting both to a list. 
+ */ + public static void checkEquals(Iterable it1, Iterable it2) { + assertEquals(toList(it1), toList(it2)); + } + + public static void checkEquals(Iterator it1, Iterator it2) { + assertEquals(Utils.toList(it1), Utils.toList(it2)); + } + + public static void checkEquals(Set c1, Set c2, String firstDesc, String secondDesc) { + if (!c1.equals(c2)) { + Set missing1 = new HashSet<>(c2); + missing1.removeAll(c1); + Set missing2 = new HashSet<>(c1); + missing2.removeAll(c2); + fail(String.format("Sets not equal, missing %s=%s, missing %s=%s", firstDesc, missing1, secondDesc, missing2)); + } + } + + public static List toList(Iterable iterable) { + List list = new ArrayList<>(); + for (T item : iterable) + list.add(item); + return list; + } + + public static Set toSet(Collection collection) { + return new HashSet<>(collection); + } + + public static ByteBuffer toBuffer(Send send) { + ByteBufferChannel channel = new ByteBufferChannel(send.size()); + try { + assertEquals(send.size(), send.writeTo(channel)); + channel.close(); + return channel.buffer(); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + public static ByteBuffer toBuffer(UnalignedRecords records) { + return toBuffer(records.toSend()); + } + + public static Set generateRandomTopicPartitions(int numTopic, int numPartitionPerTopic) { + Set tps = new HashSet<>(); + for (int i = 0; i < numTopic; i++) { + String topic = randomString(32); + for (int j = 0; j < numPartitionPerTopic; j++) { + tps.add(new TopicPartition(topic, j)); + } + } + return tps; + } + + /** + * Assert that a future raises an expected exception cause type. Return the exception cause + * if the assertion succeeds; otherwise raise AssertionError. + * + * @param future The future to await + * @param exceptionCauseClass Class of the expected exception cause + * @param Exception cause type parameter + * @return The caught exception cause + */ + public static T assertFutureThrows(Future future, Class exceptionCauseClass) { + ExecutionException exception = assertThrows(ExecutionException.class, future::get); + assertTrue(exceptionCauseClass.isInstance(exception.getCause()), + "Unexpected exception cause " + exception.getCause()); + return exceptionCauseClass.cast(exception.getCause()); + } + + public static void assertFutureThrows( + Future future, + Class expectedCauseClassApiException, + String expectedMessage + ) { + T receivedException = assertFutureThrows(future, expectedCauseClassApiException); + assertEquals(expectedMessage, receivedException.getMessage()); + } + + public static void assertFutureError(Future future, Class exceptionClass) + throws InterruptedException { + try { + future.get(); + fail("Expected a " + exceptionClass.getSimpleName() + " exception, but got success."); + } catch (ExecutionException ee) { + Throwable cause = ee.getCause(); + assertEquals(exceptionClass, cause.getClass(), + "Expected a " + exceptionClass.getSimpleName() + " exception, but got " + + cause.getClass().getSimpleName()); + } + } + + public static ApiKeys apiKeyFrom(NetworkReceive networkReceive) { + return RequestHeader.parse(networkReceive.payload().duplicate()).apiKey(); + } + + public static void assertOptional(Optional optional, Consumer assertion) { + if (optional.isPresent()) { + assertion.accept(optional.get()); + } else { + fail("Missing value from Optional"); + } + } + + @SuppressWarnings("unchecked") + public static T fieldValue(Object o, Class clazz, String fieldName) { + try { + Field field = clazz.getDeclaredField(fieldName); + 
field.setAccessible(true); + return (T) field.get(o); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + public static void setFieldValue(Object obj, String fieldName, Object value) throws Exception { + Field field = obj.getClass().getDeclaredField(fieldName); + field.setAccessible(true); + field.set(obj, value); + } + + /** + * Returns true if both iterators have same elements in the same order. + * + * @param iterator1 first iterator. + * @param iterator2 second iterator. + * @param type of element in the iterators. + */ + public static boolean sameElementsWithOrder(Iterator iterator1, + Iterator iterator2) { + while (iterator1.hasNext()) { + if (!iterator2.hasNext()) { + return false; + } + + if (!Objects.equals(iterator1.next(), iterator2.next())) { + return false; + } + } + + return !iterator2.hasNext(); + } + + /** + * Returns true if both the iterators have same set of elements irrespective of order and duplicates. + * + * @param iterator1 first iterator. + * @param iterator2 second iterator. + * @param type of element in the iterators. + */ + public static boolean sameElementsWithoutOrder(Iterator iterator1, + Iterator iterator2) { + // Check both the iterators have the same set of elements irrespective of order and duplicates. + Set allSegmentsSet = new HashSet<>(); + iterator1.forEachRemaining(allSegmentsSet::add); + Set expectedSegmentsSet = new HashSet<>(); + iterator2.forEachRemaining(expectedSegmentsSet::add); + + return allSegmentsSet.equals(expectedSegmentsSet); + } +} diff --git a/clients/src/test/java/org/apache/kafka/test/ValuelessCallable.java b/clients/src/test/java/org/apache/kafka/test/ValuelessCallable.java new file mode 100644 index 0000000..62863b3 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/test/ValuelessCallable.java @@ -0,0 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.test; + +/** + * Like a {@link Runnable} that allows exceptions to be thrown or a {@link java.util.concurrent.Callable} + * that does not return a value. + */ +public interface ValuelessCallable { + void call() throws Exception; +} diff --git a/clients/src/test/resources/common/message/SimpleExampleMessage.json b/clients/src/test/resources/common/message/SimpleExampleMessage.json new file mode 100644 index 0000000..342a9b9 --- /dev/null +++ b/clients/src/test/resources/common/message/SimpleExampleMessage.json @@ -0,0 +1,63 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. 
+// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +{ + "name": "SimpleExampleMessage", + "type": "header", + "validVersions": "0-2", + "flexibleVersions": "1+", + "fields": [ + { "name": "processId", "versions": "1+", "type": "uuid" }, + { "name": "myTaggedIntArray", "type": "[]int32", + "taggedVersions": "1+", "tag": 0 }, + { "name": "myNullableString", "type": "string", "default": "null", + "nullableVersions": "1+", "taggedVersions": "1+", "tag": 1 }, + { "name": "myInt16", "type": "int16", "default": "123", + "taggedVersions": "1+", "tag": 2 }, + { "name": "myFloat64", "type": "float64", "default": "12.34", + "taggedVersions": "1+", "tag": 3 }, + { "name": "myString", "type": "string", "taggedVersions": "1+", "tag": 4 }, + { "name": "myBytes", "type": "bytes", + "nullableVersions": "1+", "taggedVersions": "1+", "tag": 5 }, + { "name": "taggedUuid", "type": "uuid", "default": "H3KKO4NTRPaCWtEmm3vW7A", + "taggedVersions": "1+", "tag": 6 }, + { "name": "taggedLong", "type": "int64", "default": "0xcafcacafcacafca", + "taggedVersions": "1+", "tag": 7 }, + { "name": "zeroCopyByteBuffer", "versions": "1", "type": "bytes", "zeroCopy": true }, + { "name": "nullableZeroCopyByteBuffer", "versions": "1", "nullableVersions": "0+", + "type": "bytes", "zeroCopy": true }, + { "name": "myStruct", "type": "MyStruct", "versions": "2+", "about": "Test Struct field", + "fields": [ + { "name": "structId", "type": "int32", "versions": "2+", "about": "Int field in struct"}, + { "name": "arrayInStruct", "type": "[]StructArray", "versions": "2+", + "fields": [ + { "name": "arrayFieldId", "type": "int32", "versions": "2+"} + ]} + ]}, + { "name": "myTaggedStruct", "type": "TaggedStruct", "versions": "2+", "about": "Test Tagged Struct field", + "taggedVersions": "2+", "tag": 8, + "fields": [ + { "name": "structId", "type": "string", "versions": "2+", "about": "String field in struct"} + ]}, + { "name": "myCommonStruct", "type": "TestCommonStruct", "versions": "0+"}, + { "name": "myOtherCommonStruct", "type": "TestCommonStruct", "versions": "0+"}, + { "name": "myUint16", "type": "uint16", "versions": "1+", "default": "33000" } + ], + "commonStructs": [ + { "name": "TestCommonStruct", "versions": "0+", "fields": [ + { "name": "foo", "type": "int32", "default": "123", "versions": "0+" }, + { "name": "bar", "type": "int32", "default": "123", "versions": "0+" } + ]} + ] +} diff --git a/clients/src/test/resources/common/message/SimpleRecordsMessage.json b/clients/src/test/resources/common/message/SimpleRecordsMessage.json new file mode 100644 index 0000000..ce81b29 --- /dev/null +++ b/clients/src/test/resources/common/message/SimpleRecordsMessage.json @@ -0,0 +1,26 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. 
+// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +{ + "name": "SimpleRecordsMessage", + "type": "header", + "validVersions": "0-1", + "flexibleVersions": "1+", + "fields": [ + { "name": "Topic", "type": "string", "versions": "0+", "entityType": "topicName", + "about": "The topic name." }, + { "name": "RecordSet", "type": "records", "versions": "0+", + "nullableVersions": "0+", "about": "The record data." } + ] +} diff --git a/clients/src/test/resources/log4j.properties b/clients/src/test/resources/log4j.properties new file mode 100644 index 0000000..b1d5b7f --- /dev/null +++ b/clients/src/test/resources/log4j.properties @@ -0,0 +1,21 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +log4j.rootLogger=OFF, stdout + +log4j.appender.stdout=org.apache.log4j.ConsoleAppender +log4j.appender.stdout.layout=org.apache.log4j.PatternLayout +log4j.appender.stdout.layout.ConversionPattern=[%d] %p %m (%c:%L)%n + +log4j.logger.org.apache.kafka=ERROR diff --git a/clients/src/test/resources/serializedData/offsetAndMetadataBeforeLeaderEpoch b/clients/src/test/resources/serializedData/offsetAndMetadataBeforeLeaderEpoch new file mode 100644 index 0000000..95319cb Binary files /dev/null and b/clients/src/test/resources/serializedData/offsetAndMetadataBeforeLeaderEpoch differ diff --git a/clients/src/test/resources/serializedData/offsetAndMetadataWithLeaderEpoch b/clients/src/test/resources/serializedData/offsetAndMetadataWithLeaderEpoch new file mode 100644 index 0000000..ddf3956 Binary files /dev/null and b/clients/src/test/resources/serializedData/offsetAndMetadataWithLeaderEpoch differ diff --git a/clients/src/test/resources/serializedData/topicPartitionSerializedfile b/clients/src/test/resources/serializedData/topicPartitionSerializedfile new file mode 100644 index 0000000..2c1c501 Binary files /dev/null and b/clients/src/test/resources/serializedData/topicPartitionSerializedfile differ diff --git a/config/connect-console-sink.properties b/config/connect-console-sink.properties new file mode 100644 index 0000000..e240a8f --- /dev/null +++ b/config/connect-console-sink.properties @@ -0,0 +1,19 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. 
See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +name=local-console-sink +connector.class=org.apache.kafka.connect.file.FileStreamSinkConnector +tasks.max=1 +topics=connect-test \ No newline at end of file diff --git a/config/connect-console-source.properties b/config/connect-console-source.properties new file mode 100644 index 0000000..d0e2069 --- /dev/null +++ b/config/connect-console-source.properties @@ -0,0 +1,19 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +name=local-console-source +connector.class=org.apache.kafka.connect.file.FileStreamSourceConnector +tasks.max=1 +topic=connect-test \ No newline at end of file diff --git a/config/connect-distributed.properties b/config/connect-distributed.properties new file mode 100644 index 0000000..cedad9a --- /dev/null +++ b/config/connect-distributed.properties @@ -0,0 +1,89 @@ +## +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +## + +# This file contains some of the configurations for the Kafka Connect distributed worker. This file is intended +# to be used with the examples, and some settings may differ from those used in a production system, especially +# the `bootstrap.servers` and those specifying replication factors. + +# A list of host/port pairs to use for establishing the initial connection to the Kafka cluster. +bootstrap.servers=localhost:9092 + +# unique name for the cluster, used in forming the Connect cluster group. 
Note that this must not conflict with consumer group IDs +group.id=connect-cluster + +# The converters specify the format of data in Kafka and how to translate it into Connect data. Every Connect user will +# need to configure these based on the format they want their data in when loaded from or stored into Kafka +key.converter=org.apache.kafka.connect.json.JsonConverter +value.converter=org.apache.kafka.connect.json.JsonConverter +# Converter-specific settings can be passed in by prefixing the Converter's setting with the converter we want to apply +# it to +key.converter.schemas.enable=true +value.converter.schemas.enable=true + +# Topic to use for storing offsets. This topic should have many partitions and be replicated and compacted. +# Kafka Connect will attempt to create the topic automatically when needed, but you can always manually create +# the topic before starting Kafka Connect if a specific topic configuration is needed. +# Most users will want to use the built-in default replication factor of 3 or in some cases even specify a larger value. +# Since this means there must be at least as many brokers as the maximum replication factor used, we'd like to be able +# to run this example on a single-broker cluster and so here we instead set the replication factor to 1. +offset.storage.topic=connect-offsets +offset.storage.replication.factor=1 +#offset.storage.partitions=25 + +# Topic to use for storing connector and task configurations; note that this should be a single partition, highly replicated, +# and compacted topic. Kafka Connect will attempt to create the topic automatically when needed, but you can always manually create +# the topic before starting Kafka Connect if a specific topic configuration is needed. +# Most users will want to use the built-in default replication factor of 3 or in some cases even specify a larger value. +# Since this means there must be at least as many brokers as the maximum replication factor used, we'd like to be able +# to run this example on a single-broker cluster and so here we instead set the replication factor to 1. +config.storage.topic=connect-configs +config.storage.replication.factor=1 + +# Topic to use for storing statuses. This topic can have multiple partitions and should be replicated and compacted. +# Kafka Connect will attempt to create the topic automatically when needed, but you can always manually create +# the topic before starting Kafka Connect if a specific topic configuration is needed. +# Most users will want to use the built-in default replication factor of 3 or in some cases even specify a larger value. +# Since this means there must be at least as many brokers as the maximum replication factor used, we'd like to be able +# to run this example on a single-broker cluster and so here we instead set the replication factor to 1. +status.storage.topic=connect-status +status.storage.replication.factor=1 +#status.storage.partitions=5 + +# Flush much faster than normal, which is useful for testing/debugging +offset.flush.interval.ms=10000 + +# List of comma-separated URIs the REST API will listen on. The supported protocols are HTTP and HTTPS. +# Specify hostname as 0.0.0.0 to bind to all interfaces. +# Leave hostname empty to bind to default interface. +# Examples of legal listener lists: HTTP://myhost:8083,HTTPS://myhost:8084" +#listeners=HTTP://:8083 + +# The Hostname & Port that will be given out to other workers to connect to i.e. URLs that are routable from other servers. 
+# If not set, it uses the value for "listeners" if configured. +#rest.advertised.host.name= +#rest.advertised.port= +#rest.advertised.listener= + +# Set to a list of filesystem paths separated by commas (,) to enable class loading isolation for plugins +# (connectors, converters, transformations). The list should consist of top level directories that include +# any combination of: +# a) directories immediately containing jars with plugins and their dependencies +# b) uber-jars with plugins and their dependencies +# c) directories immediately containing the package directory structure of classes of plugins and their dependencies +# Examples: +# plugin.path=/usr/local/share/java,/usr/local/share/kafka/plugins,/opt/connectors, +#plugin.path= diff --git a/config/connect-file-sink.properties b/config/connect-file-sink.properties new file mode 100644 index 0000000..594ccc6 --- /dev/null +++ b/config/connect-file-sink.properties @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +name=local-file-sink +connector.class=FileStreamSink +tasks.max=1 +file=test.sink.txt +topics=connect-test \ No newline at end of file diff --git a/config/connect-file-source.properties b/config/connect-file-source.properties new file mode 100644 index 0000000..599cf4c --- /dev/null +++ b/config/connect-file-source.properties @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +name=local-file-source +connector.class=FileStreamSource +tasks.max=1 +file=test.txt +topic=connect-test \ No newline at end of file diff --git a/config/connect-log4j.properties b/config/connect-log4j.properties new file mode 100644 index 0000000..157d593 --- /dev/null +++ b/config/connect-log4j.properties @@ -0,0 +1,42 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. 
+# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +log4j.rootLogger=INFO, stdout, connectAppender + +# Send the logs to the console. +# +log4j.appender.stdout=org.apache.log4j.ConsoleAppender +log4j.appender.stdout.layout=org.apache.log4j.PatternLayout + +# Send the logs to a file, rolling the file at midnight local time. For example, the `File` option specifies the +# location of the log files (e.g. ${kafka.logs.dir}/connect.log), and at midnight local time the file is closed +# and copied in the same directory but with a filename that ends in the `DatePattern` option. +# +log4j.appender.connectAppender=org.apache.log4j.DailyRollingFileAppender +log4j.appender.connectAppender.DatePattern='.'yyyy-MM-dd-HH +log4j.appender.connectAppender.File=${kafka.logs.dir}/connect.log +log4j.appender.connectAppender.layout=org.apache.log4j.PatternLayout + +# The `%X{connector.context}` parameter in the layout includes connector-specific and task-specific information +# in the log messages, where appropriate. This makes it easier to identify those log messages that apply to a +# specific connector. +# +connect.log.pattern=[%d] %p %X{connector.context}%m (%c:%L)%n + +log4j.appender.stdout.layout.ConversionPattern=${connect.log.pattern} +log4j.appender.connectAppender.layout.ConversionPattern=${connect.log.pattern} + +log4j.logger.org.apache.zookeeper=ERROR +log4j.logger.org.reflections=ERROR diff --git a/config/connect-mirror-maker.properties b/config/connect-mirror-maker.properties new file mode 100644 index 0000000..40afda5 --- /dev/null +++ b/config/connect-mirror-maker.properties @@ -0,0 +1,59 @@ +# Licensed to the Apache Software Foundation (ASF) under A or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# see org.apache.kafka.clients.consumer.ConsumerConfig for more details + +# Sample MirrorMaker 2.0 top-level configuration file +# Run with ./bin/connect-mirror-maker.sh connect-mirror-maker.properties + +# specify any number of cluster aliases +clusters = A, B + +# connection information for each cluster +# This is a comma separated host:port pairs for each cluster +# for e.g. 
"A_host1:9092, A_host2:9092, A_host3:9092" +A.bootstrap.servers = A_host1:9092, A_host2:9092, A_host3:9092 +B.bootstrap.servers = B_host1:9092, B_host2:9092, B_host3:9092 + +# enable and configure individual replication flows +A->B.enabled = true + +# regex which defines which topics gets replicated. For eg "foo-.*" +A->B.topics = .* + +B->A.enabled = true +B->A.topics = .* + +# Setting replication factor of newly created remote topics +replication.factor=1 + +############################# Internal Topic Settings ############################# +# The replication factor for mm2 internal topics "heartbeats", "B.checkpoints.internal" and +# "mm2-offset-syncs.B.internal" +# For anything other than development testing, a value greater than 1 is recommended to ensure availability such as 3. +checkpoints.topic.replication.factor=1 +heartbeats.topic.replication.factor=1 +offset-syncs.topic.replication.factor=1 + +# The replication factor for connect internal topics "mm2-configs.B.internal", "mm2-offsets.B.internal" and +# "mm2-status.B.internal" +# For anything other than development testing, a value greater than 1 is recommended to ensure availability such as 3. +offset.storage.replication.factor=1 +status.storage.replication.factor=1 +config.storage.replication.factor=1 + +# customize as needed +# replication.policy.separator = _ +# sync.topic.acls.enabled = false +# emit.heartbeats.interval.seconds = 5 diff --git a/config/connect-standalone.properties b/config/connect-standalone.properties new file mode 100644 index 0000000..a340a3b --- /dev/null +++ b/config/connect-standalone.properties @@ -0,0 +1,41 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# These are defaults. This file just demonstrates how to override some settings. +bootstrap.servers=localhost:9092 + +# The converters specify the format of data in Kafka and how to translate it into Connect data. Every Connect user will +# need to configure these based on the format they want their data in when loaded from or stored into Kafka +key.converter=org.apache.kafka.connect.json.JsonConverter +value.converter=org.apache.kafka.connect.json.JsonConverter +# Converter-specific settings can be passed in by prefixing the Converter's setting with the converter we want to apply +# it to +key.converter.schemas.enable=true +value.converter.schemas.enable=true + +offset.storage.file.filename=/tmp/connect.offsets +# Flush much faster than normal, which is useful for testing/debugging +offset.flush.interval.ms=10000 + +# Set to a list of filesystem paths separated by commas (,) to enable class loading isolation for plugins +# (connectors, converters, transformations). 
The list should consist of top level directories that include +# any combination of: +# a) directories immediately containing jars with plugins and their dependencies +# b) uber-jars with plugins and their dependencies +# c) directories immediately containing the package directory structure of classes of plugins and their dependencies +# Note: symlinks will be followed to discover dependencies or plugins. +# Examples: +# plugin.path=/usr/local/share/java,/usr/local/share/kafka/plugins,/opt/connectors, +#plugin.path= diff --git a/config/consumer.properties b/config/consumer.properties new file mode 100644 index 0000000..01bb12e --- /dev/null +++ b/config/consumer.properties @@ -0,0 +1,26 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# see org.apache.kafka.clients.consumer.ConsumerConfig for more details + +# list of brokers used for bootstrapping knowledge about the rest of the cluster +# format: host1:port1,host2:port2 ... +bootstrap.servers=localhost:9092 + +# consumer group id +group.id=test-consumer-group + +# What to do when there is no initial offset in Kafka or if the current +# offset does not exist any more on the server: latest, earliest, none +#auto.offset.reset= diff --git a/config/kraft/README.md b/config/kraft/README.md new file mode 100644 index 0000000..80bc8ca --- /dev/null +++ b/config/kraft/README.md @@ -0,0 +1,173 @@ +KRaft (aka KIP-500) mode Preview Release +========================================================= + +# Introduction +It is now possible to run Apache Kafka without Apache ZooKeeper! We call this the [Kafka Raft metadata mode](https://cwiki.apache.org/confluence/display/KAFKA/KIP-500%3A+Replace+ZooKeeper+with+a+Self-Managed+Metadata+Quorum), typically shortened to `KRaft mode`. +`KRaft` is intended to be pronounced like `craft` (as in `craftsmanship`). It is currently *PREVIEW AND SHOULD NOT BE USED IN PRODUCTION*, but it +is available for testing in the Kafka 3.1 release. + +When the Kafka cluster is in KRaft mode, it does not store its metadata in ZooKeeper. In fact, you do not have to run ZooKeeper at all, because it stores its metadata in a KRaft quorum of controller nodes. + +KRaft mode has many benefits -- some obvious, and some not so obvious. Clearly, it is nice to manage and configure one service rather than two services. In addition, you can now run a single process Kafka cluster. +Most important of all, KRaft mode is more scalable. We expect to be able to [support many more topics and partitions](https://www.confluent.io/kafka-summit-san-francisco-2019/kafka-needs-no-keeper/) in this mode. + +# Quickstart + +## Warning +KRaft mode in Kafka 3.1 is provided for testing only, *NOT* for production. We do not yet support upgrading existing ZooKeeper-based Kafka clusters into this mode. 
+There may be bugs, including serious ones. You should *assume that your data could be lost at any time* if you try the preview release of KRaft mode. + +## Generate a cluster ID +The first step is to generate an ID for your new cluster, using the kafka-storage tool: + +~~~~ +$ ./bin/kafka-storage.sh random-uuid +xtzWWN4bTjitpL3kfd9s5g +~~~~ + +## Format Storage Directories +The next step is to format your storage directories. If you are running in single-node mode, you can do this with one command: + +~~~~ +$ ./bin/kafka-storage.sh format -t <uuid> -c ./config/kraft/server.properties +Formatting /tmp/kraft-combined-logs +~~~~ + +If you are using multiple nodes, then you should run the format command on each node. Be sure to use the same cluster ID for each one. + +## Start the Kafka Server +Finally, you are ready to start the Kafka server on each node. + +~~~~ +$ ./bin/kafka-server-start.sh ./config/kraft/server.properties +[2021-02-26 15:37:11,071] INFO Registered kafka:type=kafka.Log4jController MBean (kafka.utils.Log4jControllerRegistration$) +[2021-02-26 15:37:11,294] INFO Setting -D jdk.tls.rejectClientInitiatedRenegotiation=true to disable client-initiated TLS renegotiation (org.apache.zookeeper.common.X509Util) +[2021-02-26 15:37:11,466] INFO [Log partition=__cluster_metadata-0, dir=/tmp/kraft-combined-logs] Loading producer state till offset 0 with message format version 2 (kafka.log.Log) +[2021-02-26 15:37:11,509] INFO [raft-expiration-reaper]: Starting (kafka.raft.TimingWheelExpirationService$ExpiredOperationReaper) +[2021-02-26 15:37:11,640] INFO [RaftManager nodeId=1] Completed transition to Unattached(epoch=0, voters=[1], electionTimeoutMs=9037) (org.apache.kafka.raft.QuorumState) +... +~~~~ + +Just like with a ZooKeeper based broker, you can connect to port 9092 (or whatever port you configured) to perform administrative operations or produce or consume data. + +~~~~ +$ ./bin/kafka-topics.sh --create --topic foo --partitions 1 --replication-factor 1 --bootstrap-server localhost:9092 +Created topic foo. +~~~~ + +# Deployment + +## Controller Servers +In KRaft mode, only a small group of specially selected servers can act as controllers (unlike the ZooKeeper-based mode, where any server can become the +Controller). The specially selected controller servers will participate in the metadata quorum. Each controller server is either active, or a hot +standby for the current active controller server. + +You will typically select 3 or 5 servers for this role, depending on factors like cost and the number of concurrent failures your system should withstand +without availability impact. Just like with ZooKeeper, you must keep a majority of the controllers alive in order to maintain availability. So if you have 3 +controllers, you can tolerate 1 failure; with 5 controllers, you can tolerate 2 failures. + +## Process Roles +Each Kafka server now has a new configuration key called `process.roles` which can have the following values: + +* If `process.roles` is set to `broker`, the server acts as a broker in KRaft mode. +* If `process.roles` is set to `controller`, the server acts as a controller in KRaft mode. +* If `process.roles` is set to `broker,controller`, the server acts as both a broker and a controller in KRaft mode. +* If `process.roles` is not set at all then we are assumed to be in ZooKeeper mode. As mentioned earlier, you can't currently transition back and forth between ZooKeeper mode and KRaft mode without reformatting.
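For example, a node that should act only as a broker might be configured along the following lines. This is a minimal sketch drawn from `config/kraft/broker.properties` elsewhere in this repository; the node ID, addresses, and quorum voters are illustrative and must be adjusted to match your own cluster layout.

```
# Broker-only KRaft node (illustrative values; see config/kraft/broker.properties for a complete example)
process.roles=broker
node.id=2
controller.quorum.voters=1@localhost:9093
listeners=PLAINTEXT://localhost:9092
controller.listener.names=CONTROLLER
listener.security.protocol.map=CONTROLLER:PLAINTEXT,PLAINTEXT:PLAINTEXT
```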
+ +Nodes that act as both brokers and controllers are referred to as "combined" nodes. Combined nodes are simpler to operate for simple use cases and allow you to avoid +some fixed memory overheads associated with JVMs. The key disadvantage is that the controller will be less isolated from the rest of the system. For example, if activity on the broker causes an out of +memory condition, the controller part of the server is not isolated from that OOM condition. + +## Quorum Voters +All nodes in the system must set the `controller.quorum.voters` configuration. This identifies the quorum controller servers that should be used. All the controllers must be enumerated. +This is similar to how, when using ZooKeeper, the `zookeeper.connect` configuration must contain all the ZooKeeper servers. Unlike with the ZooKeeper config, however, `controller.quorum.voters` +also has IDs for each node. The format is id1@host1:port1,id2@host2:port2, etc. + +So if you have 10 brokers and 3 controllers named controller1, controller2, controller3, you might have the following configuration on controller1: +``` +process.roles=controller +node.id=1 +listeners=CONTROLLER://controller1.example.com:9093 +controller.quorum.voters=1@controller1.example.com:9093,2@controller2.example.com:9093,3@controller3.example.com:9093 +``` + +Each broker and each controller must set `controller.quorum.voters`. Note that the node ID supplied in the `controller.quorum.voters` configuration must match that supplied to the server. +So on controller1, node.id must be set to 1, and so forth. Note that there is no requirement for controller IDs to start at 0 or 1. However, the easiest and least confusing way to allocate +node IDs is probably just to give each server a numeric ID, starting from 0. + +Note that clients never need to configure `controller.quorum.voters`; only servers do. + +## Kafka Storage Tool +As described above in the QuickStart section, you must use the `kafka-storage.sh` tool to generate a cluster ID for your new cluster, and then run the format command on each node before starting the node. + +This is different from how Kafka has operated in the past. Previously, Kafka would format blank storage directories automatically, and also generate a new cluster UUID automatically. One reason for the change +is that auto-formatting can sometimes obscure an error condition. For example, under UNIX, if a data directory can't be mounted, it may show up as blank. In this case, auto-formatting would be the wrong thing to do. + +This is particularly important for the metadata log maintained by the controller servers. If two controllers out of three controllers were able to start with blank logs, a leader might be able to be elected with +nothing in the log, which would cause all metadata to be lost. + +# Missing Features +We don't support any kind of upgrade right now, either to or from KRaft mode. This is an important gap that we are working on. + +Finally, the following Kafka features have not yet been fully implemented: + +* Support for certain security features: configuring a KRaft-based Authorizer, setting up SCRAM, delegation tokens, and so forth + (although note that you can use authorizers such as `kafka.security.authorizer.AclAuthorizer` with KRaft clusters, even + if they are ZooKeeper-based: simply define `authorizer.class.name` and configure the authorizer as you normally would). 
+* Support for some configurations, like enabling unclean leader election by default or dynamically changing broker endpoints +* Support for KIP-112 "JBOD" modes + +We've tried to make it clear when a feature is not supported in the preview release, but you may encounter some rough edges. We will cover these feature gaps incrementally in the `trunk` branch. + +# Debugging +If you encounter an issue, you might want to take a look at the metadata log. + +## kafka-dump-log +One way to view the metadata log is with kafka-dump-log.sh tool, like so: + +~~~~ +$ ./bin/kafka-dump-log.sh --cluster-metadata-decoder --skip-record-metadata --files /tmp/kraft-combined-logs/__cluster_metadata-0/*.log +Dumping /tmp/kraft-combined-logs/__cluster_metadata-0/00000000000000000000.log +Starting offset: 0 +baseOffset: 0 lastOffset: 0 count: 1 baseSequence: -1 lastSequence: -1 producerId: -1 producerEpoch: -1 partitionLeaderEpoch: 1 isTransactional: false isControl: true position: 0 CreateTime: 1614382631640 size: 89 magic: 2 compresscodec: NONE crc: 1438115474 isvalid: true + +baseOffset: 1 lastOffset: 1 count: 1 baseSequence: -1 lastSequence: -1 producerId: -1 producerEpoch: -1 partitionLeaderEpoch: 1 isTransactional: false isControl: false position: 89 CreateTime: 1614382632329 size: 137 magic: 2 compresscodec: NONE crc: 1095855865 isvalid: true + payload: {"type":"REGISTER_BROKER_RECORD","version":0,"data":{"brokerId":1,"incarnationId":"P3UFsWoNR-erL9PK98YLsA","brokerEpoch":0,"endPoints":[{"name":"PLAINTEXT","host":"localhost","port":9092,"securityProtocol":0}],"features":[],"rack":null}} +baseOffset: 2 lastOffset: 2 count: 1 baseSequence: -1 lastSequence: -1 producerId: -1 producerEpoch: -1 partitionLeaderEpoch: 1 isTransactional: false isControl: false position: 226 CreateTime: 1614382632453 size: 83 magic: 2 compresscodec: NONE crc: 455187130 isvalid: true + payload: {"type":"UNFENCE_BROKER_RECORD","version":0,"data":{"id":1,"epoch":0}} +baseOffset: 3 lastOffset: 3 count: 1 baseSequence: -1 lastSequence: -1 producerId: -1 producerEpoch: -1 partitionLeaderEpoch: 1 isTransactional: false isControl: false position: 309 CreateTime: 1614382634484 size: 83 magic: 2 compresscodec: NONE crc: 4055692847 isvalid: true + payload: {"type":"FENCE_BROKER_RECORD","version":0,"data":{"id":1,"epoch":0}} +baseOffset: 4 lastOffset: 4 count: 1 baseSequence: -1 lastSequence: -1 producerId: -1 producerEpoch: -1 partitionLeaderEpoch: 2 isTransactional: false isControl: true position: 392 CreateTime: 1614382671857 size: 89 magic: 2 compresscodec: NONE crc: 1318571838 isvalid: true + +baseOffset: 5 lastOffset: 5 count: 1 baseSequence: -1 lastSequence: -1 producerId: -1 producerEpoch: -1 partitionLeaderEpoch: 2 isTransactional: false isControl: false position: 481 CreateTime: 1614382672440 size: 137 magic: 2 compresscodec: NONE crc: 841144615 isvalid: true + payload: {"type":"REGISTER_BROKER_RECORD","version":0,"data":{"brokerId":1,"incarnationId":"RXRJu7cnScKRZOnWQGs86g","brokerEpoch":4,"endPoints":[{"name":"PLAINTEXT","host":"localhost","port":9092,"securityProtocol":0}],"features":[],"rack":null}} +baseOffset: 6 lastOffset: 6 count: 1 baseSequence: -1 lastSequence: -1 producerId: -1 producerEpoch: -1 partitionLeaderEpoch: 2 isTransactional: false isControl: false position: 618 CreateTime: 1614382672544 size: 83 magic: 2 compresscodec: NONE crc: 4155905922 isvalid: true + payload: {"type":"UNFENCE_BROKER_RECORD","version":0,"data":{"id":1,"epoch":4}} +baseOffset: 7 lastOffset: 8 count: 2 baseSequence: -1 lastSequence: -1 
producerId: -1 producerEpoch: -1 partitionLeaderEpoch: 2 isTransactional: false isControl: false position: 701 CreateTime: 1614382712158 size: 159 magic: 2 compresscodec: NONE crc: 3726758683 isvalid: true + payload: {"type":"TOPIC_RECORD","version":0,"data":{"name":"foo","topicId":"5zoAlv-xEh9xRANKXt1Lbg"}} + payload: {"type":"PARTITION_RECORD","version":0,"data":{"partitionId":0,"topicId":"5zoAlv-xEh9xRANKXt1Lbg","replicas":[1],"isr":[1],"removingReplicas":null,"addingReplicas":null,"leader":1,"leaderEpoch":0,"partitionEpoch":0}} +~~~~ + +## The Metadata Shell +Another tool for examining the metadata logs is the Kafka metadata shell. Just like the ZooKeeper shell, this allows you to inspect the metadata of the cluster. + +~~~~ +$ ./bin/kafka-metadata-shell.sh --snapshot /tmp/kraft-combined-logs/__cluster_metadata-0/00000000000000000000.log +>> ls / +brokers local metadataQuorum topicIds topics +>> ls /topics +foo +>> cat /topics/foo/0/data +{ + "partitionId" : 0, + "topicId" : "5zoAlv-xEh9xRANKXt1Lbg", + "replicas" : [ 1 ], + "isr" : [ 1 ], + "removingReplicas" : null, + "addingReplicas" : null, + "leader" : 1, + "leaderEpoch" : 0, + "partitionEpoch" : 0 +} +>> exit +~~~~ diff --git a/config/kraft/broker.properties b/config/kraft/broker.properties new file mode 100644 index 0000000..dfbd6ec --- /dev/null +++ b/config/kraft/broker.properties @@ -0,0 +1,128 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# +# This configuration file is intended for use in KRaft mode, where +# Apache ZooKeeper is not present. See config/kraft/README.md for details. +# + +############################# Server Basics ############################# + +# The role of this server. Setting this puts us in KRaft mode +process.roles=broker + +# The node id associated with this instance's roles +node.id=2 + +# The connect string for the controller quorum +controller.quorum.voters=1@localhost:9093 + +############################# Socket Server Settings ############################# + +# The address the socket server listens on. It will get the value returned from +# java.net.InetAddress.getCanonicalHostName() if not configured. +# FORMAT: +# listeners = listener_name://host_name:port +# EXAMPLE: +# listeners = PLAINTEXT://your.host.name:9092 +listeners=PLAINTEXT://localhost:9092 +inter.broker.listener.name=PLAINTEXT + +# Hostname and port the broker will advertise to producers and consumers. If not set, +# it uses the value for "listeners" if configured. Otherwise, it will use the value +# returned from java.net.InetAddress.getCanonicalHostName(). +advertised.listeners=PLAINTEXT://localhost:9092 + +# Listener, host name, and port for the controller to advertise to the brokers. If +# this server is a controller, this listener must be configured. 
+controller.listener.names=CONTROLLER + +# Maps listener names to security protocols, the default is for them to be the same. See the config documentation for more details +listener.security.protocol.map=CONTROLLER:PLAINTEXT,PLAINTEXT:PLAINTEXT,SSL:SSL,SASL_PLAINTEXT:SASL_PLAINTEXT,SASL_SSL:SASL_SSL + +# The number of threads that the server uses for receiving requests from the network and sending responses to the network +num.network.threads=3 + +# The number of threads that the server uses for processing requests, which may include disk I/O +num.io.threads=8 + +# The send buffer (SO_SNDBUF) used by the socket server +socket.send.buffer.bytes=102400 + +# The receive buffer (SO_RCVBUF) used by the socket server +socket.receive.buffer.bytes=102400 + +# The maximum size of a request that the socket server will accept (protection against OOM) +socket.request.max.bytes=104857600 + + +############################# Log Basics ############################# + +# A comma separated list of directories under which to store log files +log.dirs=/tmp/kraft-broker-logs + +# The default number of log partitions per topic. More partitions allow greater +# parallelism for consumption, but this will also result in more files across +# the brokers. +num.partitions=1 + +# The number of threads per data directory to be used for log recovery at startup and flushing at shutdown. +# This value is recommended to be increased for installations with data dirs located in RAID array. +num.recovery.threads.per.data.dir=1 + +############################# Internal Topic Settings ############################# +# The replication factor for the group metadata internal topics "__consumer_offsets" and "__transaction_state" +# For anything other than development testing, a value greater than 1 is recommended to ensure availability such as 3. +offsets.topic.replication.factor=1 +transaction.state.log.replication.factor=1 +transaction.state.log.min.isr=1 + +############################# Log Flush Policy ############################# + +# Messages are immediately written to the filesystem but by default we only fsync() to sync +# the OS cache lazily. The following configurations control the flush of data to disk. +# There are a few important trade-offs here: +# 1. Durability: Unflushed data may be lost if you are not using replication. +# 2. Latency: Very large flush intervals may lead to latency spikes when the flush does occur as there will be a lot of data to flush. +# 3. Throughput: The flush is generally the most expensive operation, and a small flush interval may lead to excessive seeks. +# The settings below allow one to configure the flush policy to flush data after a period of time or +# every N messages (or both). This can be done globally and overridden on a per-topic basis. + +# The number of messages to accept before forcing a flush of data to disk +#log.flush.interval.messages=10000 + +# The maximum amount of time a message can sit in a log before we force a flush +#log.flush.interval.ms=1000 + +############################# Log Retention Policy ############################# + +# The following configurations control the disposal of log segments. The policy can +# be set to delete segments after a period of time, or after a given size has accumulated. +# A segment will be deleted whenever *either* of these criteria are met. Deletion always happens +# from the end of the log. + +# The minimum age of a log file to be eligible for deletion due to age +log.retention.hours=168 + +# A size-based retention policy for logs. 
Segments are pruned from the log unless the remaining +# segments drop below log.retention.bytes. Functions independently of log.retention.hours. +#log.retention.bytes=1073741824 + +# The maximum size of a log segment file. When this size is reached a new log segment will be created. +log.segment.bytes=1073741824 + +# The interval at which log segments are checked to see if they can be deleted according +# to the retention policies +log.retention.check.interval.ms=300000 diff --git a/config/kraft/controller.properties b/config/kraft/controller.properties new file mode 100644 index 0000000..54aa7fb --- /dev/null +++ b/config/kraft/controller.properties @@ -0,0 +1,127 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# +# This configuration file is intended for use in KRaft mode, where +# Apache ZooKeeper is not present. See config/kraft/README.md for details. +# + +############################# Server Basics ############################# + +# The role of this server. Setting this puts us in KRaft mode +process.roles=controller + +# The node id associated with this instance's roles +node.id=1 + +# The connect string for the controller quorum +controller.quorum.voters=1@localhost:9093 + +############################# Socket Server Settings ############################# + +# The address the socket server listens on. It will get the value returned from +# java.net.InetAddress.getCanonicalHostName() if not configured. +# FORMAT: +# listeners = listener_name://host_name:port +# EXAMPLE: +# listeners = PLAINTEXT://your.host.name:9092 +listeners=PLAINTEXT://:9093 + +# Hostname and port the broker will advertise to producers and consumers. If not set, +# it uses the value for "listeners" if configured. Otherwise, it will use the value +# returned from java.net.InetAddress.getCanonicalHostName(). +#advertised.listeners=PLAINTEXT://your.host.name:9092 + +# Listener, host name, and port for the controller to advertise to the brokers. If +# this server is a controller, this listener must be configured. +controller.listener.names=PLAINTEXT + +# Maps listener names to security protocols, the default is for them to be the same. 
See the config documentation for more details +#listener.security.protocol.map=PLAINTEXT:PLAINTEXT,SSL:SSL,SASL_PLAINTEXT:SASL_PLAINTEXT,SASL_SSL:SASL_SSL + +# The number of threads that the server uses for receiving requests from the network and sending responses to the network +num.network.threads=3 + +# The number of threads that the server uses for processing requests, which may include disk I/O +num.io.threads=8 + +# The send buffer (SO_SNDBUF) used by the socket server +socket.send.buffer.bytes=102400 + +# The receive buffer (SO_RCVBUF) used by the socket server +socket.receive.buffer.bytes=102400 + +# The maximum size of a request that the socket server will accept (protection against OOM) +socket.request.max.bytes=104857600 + + +############################# Log Basics ############################# + +# A comma separated list of directories under which to store log files +log.dirs=/tmp/kraft-controller-logs + +# The default number of log partitions per topic. More partitions allow greater +# parallelism for consumption, but this will also result in more files across +# the brokers. +num.partitions=1 + +# The number of threads per data directory to be used for log recovery at startup and flushing at shutdown. +# This value is recommended to be increased for installations with data dirs located in RAID array. +num.recovery.threads.per.data.dir=1 + +############################# Internal Topic Settings ############################# +# The replication factor for the group metadata internal topics "__consumer_offsets" and "__transaction_state" +# For anything other than development testing, a value greater than 1 is recommended to ensure availability such as 3. +offsets.topic.replication.factor=1 +transaction.state.log.replication.factor=1 +transaction.state.log.min.isr=1 + +############################# Log Flush Policy ############################# + +# Messages are immediately written to the filesystem but by default we only fsync() to sync +# the OS cache lazily. The following configurations control the flush of data to disk. +# There are a few important trade-offs here: +# 1. Durability: Unflushed data may be lost if you are not using replication. +# 2. Latency: Very large flush intervals may lead to latency spikes when the flush does occur as there will be a lot of data to flush. +# 3. Throughput: The flush is generally the most expensive operation, and a small flush interval may lead to excessive seeks. +# The settings below allow one to configure the flush policy to flush data after a period of time or +# every N messages (or both). This can be done globally and overridden on a per-topic basis. + +# The number of messages to accept before forcing a flush of data to disk +#log.flush.interval.messages=10000 + +# The maximum amount of time a message can sit in a log before we force a flush +#log.flush.interval.ms=1000 + +############################# Log Retention Policy ############################# + +# The following configurations control the disposal of log segments. The policy can +# be set to delete segments after a period of time, or after a given size has accumulated. +# A segment will be deleted whenever *either* of these criteria are met. Deletion always happens +# from the end of the log. + +# The minimum age of a log file to be eligible for deletion due to age +log.retention.hours=168 + +# A size-based retention policy for logs. Segments are pruned from the log unless the remaining +# segments drop below log.retention.bytes. 
Functions independently of log.retention.hours. +#log.retention.bytes=1073741824 + +# The maximum size of a log segment file. When this size is reached a new log segment will be created. +log.segment.bytes=1073741824 + +# The interval at which log segments are checked to see if they can be deleted according +# to the retention policies +log.retention.check.interval.ms=300000 diff --git a/config/kraft/server.properties b/config/kraft/server.properties new file mode 100644 index 0000000..8e6406c --- /dev/null +++ b/config/kraft/server.properties @@ -0,0 +1,128 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# +# This configuration file is intended for use in KRaft mode, where +# Apache ZooKeeper is not present. See config/kraft/README.md for details. +# + +############################# Server Basics ############################# + +# The role of this server. Setting this puts us in KRaft mode +process.roles=broker,controller + +# The node id associated with this instance's roles +node.id=1 + +# The connect string for the controller quorum +controller.quorum.voters=1@localhost:9093 + +############################# Socket Server Settings ############################# + +# The address the socket server listens on. It will get the value returned from +# java.net.InetAddress.getCanonicalHostName() if not configured. +# FORMAT: +# listeners = listener_name://host_name:port +# EXAMPLE: +# listeners = PLAINTEXT://your.host.name:9092 +listeners=PLAINTEXT://:9092,CONTROLLER://:9093 +inter.broker.listener.name=PLAINTEXT + +# Hostname and port the broker will advertise to producers and consumers. If not set, +# it uses the value for "listeners" if configured. Otherwise, it will use the value +# returned from java.net.InetAddress.getCanonicalHostName(). +advertised.listeners=PLAINTEXT://localhost:9092 + +# Listener, host name, and port for the controller to advertise to the brokers. If +# this server is a controller, this listener must be configured. +controller.listener.names=CONTROLLER + +# Maps listener names to security protocols, the default is for them to be the same. 
See the config documentation for more details +listener.security.protocol.map=CONTROLLER:PLAINTEXT,PLAINTEXT:PLAINTEXT,SSL:SSL,SASL_PLAINTEXT:SASL_PLAINTEXT,SASL_SSL:SASL_SSL + +# The number of threads that the server uses for receiving requests from the network and sending responses to the network +num.network.threads=3 + +# The number of threads that the server uses for processing requests, which may include disk I/O +num.io.threads=8 + +# The send buffer (SO_SNDBUF) used by the socket server +socket.send.buffer.bytes=102400 + +# The receive buffer (SO_RCVBUF) used by the socket server +socket.receive.buffer.bytes=102400 + +# The maximum size of a request that the socket server will accept (protection against OOM) +socket.request.max.bytes=104857600 + + +############################# Log Basics ############################# + +# A comma separated list of directories under which to store log files +log.dirs=/tmp/kraft-combined-logs + +# The default number of log partitions per topic. More partitions allow greater +# parallelism for consumption, but this will also result in more files across +# the brokers. +num.partitions=1 + +# The number of threads per data directory to be used for log recovery at startup and flushing at shutdown. +# This value is recommended to be increased for installations with data dirs located in RAID array. +num.recovery.threads.per.data.dir=1 + +############################# Internal Topic Settings ############################# +# The replication factor for the group metadata internal topics "__consumer_offsets" and "__transaction_state" +# For anything other than development testing, a value greater than 1 is recommended to ensure availability such as 3. +offsets.topic.replication.factor=1 +transaction.state.log.replication.factor=1 +transaction.state.log.min.isr=1 + +############################# Log Flush Policy ############################# + +# Messages are immediately written to the filesystem but by default we only fsync() to sync +# the OS cache lazily. The following configurations control the flush of data to disk. +# There are a few important trade-offs here: +# 1. Durability: Unflushed data may be lost if you are not using replication. +# 2. Latency: Very large flush intervals may lead to latency spikes when the flush does occur as there will be a lot of data to flush. +# 3. Throughput: The flush is generally the most expensive operation, and a small flush interval may lead to excessive seeks. +# The settings below allow one to configure the flush policy to flush data after a period of time or +# every N messages (or both). This can be done globally and overridden on a per-topic basis. + +# The number of messages to accept before forcing a flush of data to disk +#log.flush.interval.messages=10000 + +# The maximum amount of time a message can sit in a log before we force a flush +#log.flush.interval.ms=1000 + +############################# Log Retention Policy ############################# + +# The following configurations control the disposal of log segments. The policy can +# be set to delete segments after a period of time, or after a given size has accumulated. +# A segment will be deleted whenever *either* of these criteria are met. Deletion always happens +# from the end of the log. + +# The minimum age of a log file to be eligible for deletion due to age +log.retention.hours=168 + +# A size-based retention policy for logs. Segments are pruned from the log unless the remaining +# segments drop below log.retention.bytes. 
Functions independently of log.retention.hours. +#log.retention.bytes=1073741824 + +# The maximum size of a log segment file. When this size is reached a new log segment will be created. +log.segment.bytes=1073741824 + +# The interval at which log segments are checked to see if they can be deleted according +# to the retention policies +log.retention.check.interval.ms=300000 diff --git a/config/log4j.properties b/config/log4j.properties new file mode 100644 index 0000000..4cbce9d --- /dev/null +++ b/config/log4j.properties @@ -0,0 +1,91 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Unspecified loggers and loggers with additivity=true output to server.log and stdout +# Note that INFO only applies to unspecified loggers, the log level of the child logger is used otherwise +log4j.rootLogger=INFO, stdout, kafkaAppender + +log4j.appender.stdout=org.apache.log4j.ConsoleAppender +log4j.appender.stdout.layout=org.apache.log4j.PatternLayout +log4j.appender.stdout.layout.ConversionPattern=[%d] %p %m (%c)%n + +log4j.appender.kafkaAppender=org.apache.log4j.DailyRollingFileAppender +log4j.appender.kafkaAppender.DatePattern='.'yyyy-MM-dd-HH +log4j.appender.kafkaAppender.File=${kafka.logs.dir}/server.log +log4j.appender.kafkaAppender.layout=org.apache.log4j.PatternLayout +log4j.appender.kafkaAppender.layout.ConversionPattern=[%d] %p %m (%c)%n + +log4j.appender.stateChangeAppender=org.apache.log4j.DailyRollingFileAppender +log4j.appender.stateChangeAppender.DatePattern='.'yyyy-MM-dd-HH +log4j.appender.stateChangeAppender.File=${kafka.logs.dir}/state-change.log +log4j.appender.stateChangeAppender.layout=org.apache.log4j.PatternLayout +log4j.appender.stateChangeAppender.layout.ConversionPattern=[%d] %p %m (%c)%n + +log4j.appender.requestAppender=org.apache.log4j.DailyRollingFileAppender +log4j.appender.requestAppender.DatePattern='.'yyyy-MM-dd-HH +log4j.appender.requestAppender.File=${kafka.logs.dir}/kafka-request.log +log4j.appender.requestAppender.layout=org.apache.log4j.PatternLayout +log4j.appender.requestAppender.layout.ConversionPattern=[%d] %p %m (%c)%n + +log4j.appender.cleanerAppender=org.apache.log4j.DailyRollingFileAppender +log4j.appender.cleanerAppender.DatePattern='.'yyyy-MM-dd-HH +log4j.appender.cleanerAppender.File=${kafka.logs.dir}/log-cleaner.log +log4j.appender.cleanerAppender.layout=org.apache.log4j.PatternLayout +log4j.appender.cleanerAppender.layout.ConversionPattern=[%d] %p %m (%c)%n + +log4j.appender.controllerAppender=org.apache.log4j.DailyRollingFileAppender +log4j.appender.controllerAppender.DatePattern='.'yyyy-MM-dd-HH +log4j.appender.controllerAppender.File=${kafka.logs.dir}/controller.log +log4j.appender.controllerAppender.layout=org.apache.log4j.PatternLayout +log4j.appender.controllerAppender.layout.ConversionPattern=[%d] %p %m (%c)%n + 
+log4j.appender.authorizerAppender=org.apache.log4j.DailyRollingFileAppender +log4j.appender.authorizerAppender.DatePattern='.'yyyy-MM-dd-HH +log4j.appender.authorizerAppender.File=${kafka.logs.dir}/kafka-authorizer.log +log4j.appender.authorizerAppender.layout=org.apache.log4j.PatternLayout +log4j.appender.authorizerAppender.layout.ConversionPattern=[%d] %p %m (%c)%n + +# Change the line below to adjust ZK client logging +log4j.logger.org.apache.zookeeper=INFO + +# Change the two lines below to adjust the general broker logging level (output to server.log and stdout) +log4j.logger.kafka=INFO +log4j.logger.org.apache.kafka=INFO + +# Change to DEBUG or TRACE to enable request logging +log4j.logger.kafka.request.logger=WARN, requestAppender +log4j.additivity.kafka.request.logger=false + +# Uncomment the lines below and change log4j.logger.kafka.network.RequestChannel$ to TRACE for additional output +# related to the handling of requests +#log4j.logger.kafka.network.Processor=TRACE, requestAppender +#log4j.logger.kafka.server.KafkaApis=TRACE, requestAppender +#log4j.additivity.kafka.server.KafkaApis=false +log4j.logger.kafka.network.RequestChannel$=WARN, requestAppender +log4j.additivity.kafka.network.RequestChannel$=false + +log4j.logger.kafka.controller=TRACE, controllerAppender +log4j.additivity.kafka.controller=false + +log4j.logger.kafka.log.LogCleaner=INFO, cleanerAppender +log4j.additivity.kafka.log.LogCleaner=false + +log4j.logger.state.change.logger=INFO, stateChangeAppender +log4j.additivity.state.change.logger=false + +# Access denials are logged at INFO level, change to DEBUG to also log allowed accesses +log4j.logger.kafka.authorizer.logger=INFO, authorizerAppender +log4j.additivity.kafka.authorizer.logger=false + diff --git a/config/producer.properties b/config/producer.properties new file mode 100644 index 0000000..4786b98 --- /dev/null +++ b/config/producer.properties @@ -0,0 +1,45 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# see org.apache.kafka.clients.producer.ProducerConfig for more details + +############################# Producer Basics ############################# + +# list of brokers used for bootstrapping knowledge about the rest of the cluster +# format: host1:port1,host2:port2 ... 
+bootstrap.servers=localhost:9092 + +# specify the compression codec for all data generated: none, gzip, snappy, lz4, zstd +compression.type=none + +# name of the partitioner class for partitioning events; default partition spreads data randomly +#partitioner.class= + +# the maximum amount of time the client will wait for the response of a request +#request.timeout.ms= + +# how long `KafkaProducer.send` and `KafkaProducer.partitionsFor` will block for +#max.block.ms= + +# the producer will wait for up to the given delay to allow other records to be sent so that the sends can be batched together +#linger.ms= + +# the maximum size of a request in bytes +#max.request.size= + +# the default batch size in bytes when batching multiple records sent to a partition +#batch.size= + +# the total bytes of memory the producer can use to buffer records waiting to be sent to the server +#buffer.memory= diff --git a/config/server.properties b/config/server.properties new file mode 100644 index 0000000..b1cf5c4 --- /dev/null +++ b/config/server.properties @@ -0,0 +1,136 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# see kafka.server.KafkaConfig for additional details and defaults + +############################# Server Basics ############################# + +# The id of the broker. This must be set to a unique integer for each broker. +broker.id=0 + +############################# Socket Server Settings ############################# + +# The address the socket server listens on. It will get the value returned from +# java.net.InetAddress.getCanonicalHostName() if not configured. +# FORMAT: +# listeners = listener_name://host_name:port +# EXAMPLE: +# listeners = PLAINTEXT://your.host.name:9092 +#listeners=PLAINTEXT://:9092 + +# Hostname and port the broker will advertise to producers and consumers. If not set, +# it uses the value for "listeners" if configured. Otherwise, it will use the value +# returned from java.net.InetAddress.getCanonicalHostName(). +#advertised.listeners=PLAINTEXT://your.host.name:9092 + +# Maps listener names to security protocols, the default is for them to be the same. 
See the config documentation for more details +#listener.security.protocol.map=PLAINTEXT:PLAINTEXT,SSL:SSL,SASL_PLAINTEXT:SASL_PLAINTEXT,SASL_SSL:SASL_SSL + +# The number of threads that the server uses for receiving requests from the network and sending responses to the network +num.network.threads=3 + +# The number of threads that the server uses for processing requests, which may include disk I/O +num.io.threads=8 + +# The send buffer (SO_SNDBUF) used by the socket server +socket.send.buffer.bytes=102400 + +# The receive buffer (SO_RCVBUF) used by the socket server +socket.receive.buffer.bytes=102400 + +# The maximum size of a request that the socket server will accept (protection against OOM) +socket.request.max.bytes=104857600 + + +############################# Log Basics ############################# + +# A comma separated list of directories under which to store log files +log.dirs=/tmp/kafka-logs + +# The default number of log partitions per topic. More partitions allow greater +# parallelism for consumption, but this will also result in more files across +# the brokers. +num.partitions=1 + +# The number of threads per data directory to be used for log recovery at startup and flushing at shutdown. +# This value is recommended to be increased for installations with data dirs located in RAID array. +num.recovery.threads.per.data.dir=1 + +############################# Internal Topic Settings ############################# +# The replication factor for the group metadata internal topics "__consumer_offsets" and "__transaction_state" +# For anything other than development testing, a value greater than 1 is recommended to ensure availability such as 3. +offsets.topic.replication.factor=1 +transaction.state.log.replication.factor=1 +transaction.state.log.min.isr=1 + +############################# Log Flush Policy ############################# + +# Messages are immediately written to the filesystem but by default we only fsync() to sync +# the OS cache lazily. The following configurations control the flush of data to disk. +# There are a few important trade-offs here: +# 1. Durability: Unflushed data may be lost if you are not using replication. +# 2. Latency: Very large flush intervals may lead to latency spikes when the flush does occur as there will be a lot of data to flush. +# 3. Throughput: The flush is generally the most expensive operation, and a small flush interval may lead to excessive seeks. +# The settings below allow one to configure the flush policy to flush data after a period of time or +# every N messages (or both). This can be done globally and overridden on a per-topic basis. + +# The number of messages to accept before forcing a flush of data to disk +#log.flush.interval.messages=10000 + +# The maximum amount of time a message can sit in a log before we force a flush +#log.flush.interval.ms=1000 + +############################# Log Retention Policy ############################# + +# The following configurations control the disposal of log segments. The policy can +# be set to delete segments after a period of time, or after a given size has accumulated. +# A segment will be deleted whenever *either* of these criteria are met. Deletion always happens +# from the end of the log. + +# The minimum age of a log file to be eligible for deletion due to age +log.retention.hours=168 + +# A size-based retention policy for logs. Segments are pruned from the log unless the remaining +# segments drop below log.retention.bytes. Functions independently of log.retention.hours. 
+#log.retention.bytes=1073741824 + +# The maximum size of a log segment file. When this size is reached a new log segment will be created. +log.segment.bytes=1073741824 + +# The interval at which log segments are checked to see if they can be deleted according +# to the retention policies +log.retention.check.interval.ms=300000 + +############################# Zookeeper ############################# + +# Zookeeper connection string (see zookeeper docs for details). +# This is a comma separated host:port pairs, each corresponding to a zk +# server. e.g. "127.0.0.1:3000,127.0.0.1:3001,127.0.0.1:3002". +# You can also append an optional chroot string to the urls to specify the +# root directory for all kafka znodes. +zookeeper.connect=localhost:2181 + +# Timeout in ms for connecting to zookeeper +zookeeper.connection.timeout.ms=18000 + + +############################# Group Coordinator Settings ############################# + +# The following configuration specifies the time, in milliseconds, that the GroupCoordinator will delay the initial consumer rebalance. +# The rebalance will be further delayed by the value of group.initial.rebalance.delay.ms as new members join the group, up to a maximum of max.poll.interval.ms. +# The default value for this is 3 seconds. +# We override this to 0 here as it makes for a better out-of-the-box experience for development and testing. +# However, in production environments the default value of 3 seconds is more suitable as this will help to avoid unnecessary, and potentially expensive, rebalances during application startup. +group.initial.rebalance.delay.ms=0 diff --git a/config/tools-log4j.properties b/config/tools-log4j.properties new file mode 100644 index 0000000..b19e343 --- /dev/null +++ b/config/tools-log4j.properties @@ -0,0 +1,21 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +log4j.rootLogger=WARN, stderr + +log4j.appender.stderr=org.apache.log4j.ConsoleAppender +log4j.appender.stderr.layout=org.apache.log4j.PatternLayout +log4j.appender.stderr.layout.ConversionPattern=[%d] %p %m (%c)%n +log4j.appender.stderr.Target=System.err diff --git a/config/trogdor.conf b/config/trogdor.conf new file mode 100644 index 0000000..320cbe7 --- /dev/null +++ b/config/trogdor.conf @@ -0,0 +1,25 @@ +{ + "_comment": [ + "Licensed to the Apache Software Foundation (ASF) under one or more", + "contributor license agreements. See the NOTICE file distributed with", + "this work for additional information regarding copyright ownership.", + "The ASF licenses this file to You under the Apache License, Version 2.0", + "(the \"License\"); you may not use this file except in compliance with", + "the License. 
You may obtain a copy of the License at", + "", + "http://www.apache.org/licenses/LICENSE-2.0", + "", + "Unless required by applicable law or agreed to in writing, software", + "distributed under the License is distributed on an \"AS IS\" BASIS,", + "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.", + "See the License for the specific language governing permissions and", + "limitations under the License." + ], + "platform": "org.apache.kafka.trogdor.basic.BasicPlatform", "nodes": { + "node0": { + "hostname": "localhost", + "trogdor.agent.port": 8888, + "trogdor.coordinator.port": 8889 + } + } +} diff --git a/config/zookeeper.properties b/config/zookeeper.properties new file mode 100644 index 0000000..90f4332 --- /dev/null +++ b/config/zookeeper.properties @@ -0,0 +1,24 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# the directory where the snapshot is stored. +dataDir=/tmp/zookeeper +# the port at which the clients will connect +clientPort=2181 +# disable the per-ip limit on the number of connections since this is a non-production config +maxClientCnxns=0 +# Disable the adminserver by default to avoid port conflicts. +# Set the port to something non-conflicting if choosing to enable this +admin.enableServer=false +# admin.serverPort=8080 diff --git a/connect/api/.gitignore b/connect/api/.gitignore new file mode 100644 index 0000000..ae3c172 --- /dev/null +++ b/connect/api/.gitignore @@ -0,0 +1 @@ +/bin/ diff --git a/connect/api/src/main/java/org/apache/kafka/connect/components/Versioned.java b/connect/api/src/main/java/org/apache/kafka/connect/components/Versioned.java new file mode 100644 index 0000000..adabe8f --- /dev/null +++ b/connect/api/src/main/java/org/apache/kafka/connect/components/Versioned.java @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.connect.components; + +/** + * Connect requires some components implement this interface to define a version string. 
+ */ +public interface Versioned { + /** + * Get the version of this component. + * + * @return the version, formatted as a String. The version may not be {@code null} or empty. + */ + String version(); +} diff --git a/connect/api/src/main/java/org/apache/kafka/connect/connector/ConnectRecord.java b/connect/api/src/main/java/org/apache/kafka/connect/connector/ConnectRecord.java new file mode 100644 index 0000000..1cc756b --- /dev/null +++ b/connect/api/src/main/java/org/apache/kafka/connect/connector/ConnectRecord.java @@ -0,0 +1,182 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.connector; + +import org.apache.kafka.connect.data.Schema; +import org.apache.kafka.connect.header.ConnectHeaders; +import org.apache.kafka.connect.header.Header; +import org.apache.kafka.connect.header.Headers; + +import java.util.Objects; + +/** + *
<p>
            + * Base class for records containing data to be copied to/from Kafka. This corresponds closely to + * Kafka's {@link org.apache.kafka.clients.producer.ProducerRecord ProducerRecord} and {@link org.apache.kafka.clients.consumer.ConsumerRecord ConsumerRecord} classes, and holds the data that may be used by both + * sources and sinks (topic, kafkaPartition, key, value). Although both implementations include a + * notion of offset, it is not included here because they differ in type. + *
</p>
+ */ +public abstract class ConnectRecord<R extends ConnectRecord<R>> { + private final String topic; + private final Integer kafkaPartition; + private final Schema keySchema; + private final Object key; + private final Schema valueSchema; + private final Object value; + private final Long timestamp; + private final Headers headers; + + public ConnectRecord(String topic, Integer kafkaPartition, + Schema keySchema, Object key, + Schema valueSchema, Object value, + Long timestamp) { + this(topic, kafkaPartition, keySchema, key, valueSchema, value, timestamp, new ConnectHeaders()); + } + + public ConnectRecord(String topic, Integer kafkaPartition, + Schema keySchema, Object key, + Schema valueSchema, Object value, + Long timestamp, Iterable<Header>
            headers) { + this.topic = topic; + this.kafkaPartition = kafkaPartition; + this.keySchema = keySchema; + this.key = key; + this.valueSchema = valueSchema; + this.value = value; + this.timestamp = timestamp; + if (headers instanceof ConnectHeaders) { + this.headers = (ConnectHeaders) headers; + } else { + this.headers = new ConnectHeaders(headers); + } + } + + public String topic() { + return topic; + } + + public Integer kafkaPartition() { + return kafkaPartition; + } + + public Object key() { + return key; + } + + public Schema keySchema() { + return keySchema; + } + + public Object value() { + return value; + } + + public Schema valueSchema() { + return valueSchema; + } + + public Long timestamp() { + return timestamp; + } + + /** + * Get the headers for this record. + * + * @return the headers; never null + */ + public Headers headers() { + return headers; + } + + /** + * Create a new record of the same type as itself, with the specified parameter values. All other fields in this record will be copied + * over to the new record. Since the headers are mutable, the resulting record will have a copy of this record's headers. + * + * @param topic the name of the topic; may be null + * @param kafkaPartition the partition number for the Kafka topic; may be null + * @param keySchema the schema for the key; may be null + * @param key the key; may be null + * @param valueSchema the schema for the value; may be null + * @param value the value; may be null + * @param timestamp the timestamp; may be null + * @return the new record + */ + public abstract R newRecord(String topic, Integer kafkaPartition, Schema keySchema, Object key, Schema valueSchema, Object value, Long timestamp); + + /** + * Create a new record of the same type as itself, with the specified parameter values. All other fields in this record will be copied + * over to the new record. + * + * @param topic the name of the topic; may be null + * @param kafkaPartition the partition number for the Kafka topic; may be null + * @param keySchema the schema for the key; may be null + * @param key the key; may be null + * @param valueSchema the schema for the value; may be null + * @param value the value; may be null + * @param timestamp the timestamp; may be null + * @param headers the headers; may be null or empty + * @return the new record + */ + public abstract R newRecord(String topic, Integer kafkaPartition, Schema keySchema, Object key, Schema valueSchema, Object value, Long timestamp, Iterable
            headers); + + @Override + public String toString() { + return "ConnectRecord{" + + "topic='" + topic + '\'' + + ", kafkaPartition=" + kafkaPartition + + ", key=" + key + + ", keySchema=" + keySchema + + ", value=" + value + + ", valueSchema=" + valueSchema + + ", timestamp=" + timestamp + + ", headers=" + headers + + '}'; + } + + @Override + public boolean equals(Object o) { + if (this == o) + return true; + if (o == null || getClass() != o.getClass()) + return false; + + ConnectRecord that = (ConnectRecord) o; + + return Objects.equals(kafkaPartition, that.kafkaPartition) + && Objects.equals(topic, that.topic) + && Objects.equals(keySchema, that.keySchema) + && Objects.equals(key, that.key) + && Objects.equals(valueSchema, that.valueSchema) + && Objects.equals(value, that.value) + && Objects.equals(timestamp, that.timestamp) + && Objects.equals(headers, that.headers); + } + + @Override + public int hashCode() { + int result = topic != null ? topic.hashCode() : 0; + result = 31 * result + (kafkaPartition != null ? kafkaPartition.hashCode() : 0); + result = 31 * result + (keySchema != null ? keySchema.hashCode() : 0); + result = 31 * result + (key != null ? key.hashCode() : 0); + result = 31 * result + (valueSchema != null ? valueSchema.hashCode() : 0); + result = 31 * result + (value != null ? value.hashCode() : 0); + result = 31 * result + (timestamp != null ? timestamp.hashCode() : 0); + result = 31 * result + headers.hashCode(); + return result; + } +} diff --git a/connect/api/src/main/java/org/apache/kafka/connect/connector/Connector.java b/connect/api/src/main/java/org/apache/kafka/connect/connector/Connector.java new file mode 100644 index 0000000..6d54aab --- /dev/null +++ b/connect/api/src/main/java/org/apache/kafka/connect/connector/Connector.java @@ -0,0 +1,153 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.connector; + +import org.apache.kafka.common.config.Config; +import org.apache.kafka.common.config.ConfigDef; +import org.apache.kafka.common.config.ConfigValue; +import org.apache.kafka.connect.errors.ConnectException; +import org.apache.kafka.connect.components.Versioned; + +import java.util.List; +import java.util.Map; + +/** + *

            + * Connectors manage integration of Kafka Connect with another system, either as an input that ingests + * data into Kafka or an output that passes data to an external system. Implementations should + * not use this class directly; they should inherit from {@link org.apache.kafka.connect.source.SourceConnector SourceConnector} + * or {@link org.apache.kafka.connect.sink.SinkConnector SinkConnector}. + *

            + *

            + * Connectors have two primary tasks. First, given some configuration, they are responsible for + * creating configurations for a set of {@link Task}s that split up the data processing. For + * example, a database Connector might create Tasks by dividing the set of tables evenly among + * tasks. Second, they are responsible for monitoring inputs for changes that require + * reconfiguration and notifying the Kafka Connect runtime via the {@link ConnectorContext}. Continuing the + * previous example, the connector might periodically check for new tables and notify Kafka Connect of + * additions and deletions. Kafka Connect will then request new configurations and update the running + * Tasks. + *
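To make the task-splitting pattern above concrete, here is an illustrative sketch, not prescribed by this API: the `tables` setting, the class names, and the round-robin split are all invented for the example.

    import org.apache.kafka.common.config.ConfigDef;
    import org.apache.kafka.connect.connector.Connector;
    import org.apache.kafka.connect.connector.Task;

    import java.util.ArrayList;
    import java.util.Arrays;
    import java.util.HashMap;
    import java.util.List;
    import java.util.Map;

    public class TableConnector extends Connector {
        private List<String> tables;

        @Override
        public String version() {
            return "0.1.0";
        }

        @Override
        public void start(Map<String, String> props) {
            // A real connector would discover these from the external system and call
            // context().requestTaskReconfiguration() when the set of tables changes.
            tables = Arrays.asList(props.getOrDefault("tables", "orders,customers,shipments").split(","));
        }

        @Override
        public Class<? extends Task> taskClass() {
            return TableTask.class;
        }

        @Override
        public List<Map<String, String>> taskConfigs(int maxTasks) {
            // Divide the tables round-robin across at most maxTasks task configurations.
            int groups = Math.min(maxTasks, tables.size());
            List<Map<String, String>> configs = new ArrayList<>();
            for (int i = 0; i < groups; i++)
                configs.add(new HashMap<>());
            for (int i = 0; i < tables.size(); i++)
                configs.get(i % groups).merge("tables", tables.get(i), (a, b) -> a + "," + b);
            return configs;
        }

        @Override
        public void stop() { }

        @Override
        public ConfigDef config() {
            return new ConfigDef().define("tables", ConfigDef.Type.LIST, ConfigDef.Importance.HIGH, "Tables to copy");
        }

        // Minimal Task so the sketch is self-contained; a real connector's task would extend SourceTask or SinkTask.
        public static class TableTask implements Task {
            @Override
            public String version() { return "0.1.0"; }

            @Override
            public void start(Map<String, String> props) { }

            @Override
            public void stop() { }
        }
    }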

            + */ +public abstract class Connector implements Versioned { + + protected ConnectorContext context; + + + /** + * Initialize this connector, using the provided ConnectorContext to notify the runtime of + * input configuration changes. + * @param ctx context object used to interact with the Kafka Connect runtime + */ + public void initialize(ConnectorContext ctx) { + context = ctx; + } + + /** + *

            + * Initialize this connector, using the provided ConnectorContext to notify the runtime of + * input configuration changes and using the provided set of Task configurations. + * This version is only used to recover from failures. + *

            + *

            + * The default implementation ignores the provided Task configurations. During recovery, Kafka Connect will request + * an updated set of configurations and update the running Tasks appropriately. However, Connectors should + * implement special handling of this case if it will avoid unnecessary changes to running Tasks. + *

            + * + * @param ctx context object used to interact with the Kafka Connect runtime + * @param taskConfigs existing task configurations, which may be used when generating new task configs to avoid + * churn in partition to task assignments + */ + public void initialize(ConnectorContext ctx, List> taskConfigs) { + context = ctx; + // Ignore taskConfigs. May result in more churn of tasks during recovery if updated configs + // are very different, but reduces the difficulty of implementing a Connector + } + + /** + * Returns the context object used to interact with the Kafka Connect runtime. + * + * @return the context for this Connector. + */ + protected ConnectorContext context() { + return context; + } + + /** + * Start this Connector. This method will only be called on a clean Connector, i.e. it has + * either just been instantiated and initialized or {@link #stop()} has been invoked. + * + * @param props configuration settings + */ + public abstract void start(Map props); + + /** + * Reconfigure this Connector. Most implementations will not override this, using the default + * implementation that calls {@link #stop()} followed by {@link #start(Map)}. + * Implementations only need to override this if they want to handle this process more + * efficiently, e.g. without shutting down network connections to the external system. + * + * @param props new configuration settings + */ + public void reconfigure(Map props) { + stop(); + start(props); + } + + /** + * Returns the Task implementation for this Connector. + */ + public abstract Class taskClass(); + + /** + * Returns a set of configurations for Tasks based on the current configuration, + * producing at most count configurations. + * + * @param maxTasks maximum number of configurations to generate + * @return configurations for Tasks + */ + public abstract List> taskConfigs(int maxTasks); + + /** + * Stop this connector. + */ + public abstract void stop(); + + /** + * Validate the connector configuration values against configuration definitions. + * @param connectorConfigs the provided configuration values + * @return List of Config, each Config contains the updated configuration information given + * the current configuration values. + */ + public Config validate(Map connectorConfigs) { + ConfigDef configDef = config(); + if (null == configDef) { + throw new ConnectException( + String.format("%s.config() must return a ConfigDef that is not null.", this.getClass().getName()) + ); + } + List configValues = configDef.validate(connectorConfigs); + return new Config(configValues); + } + + /** + * Define the configuration for the connector. + * @return The ConfigDef for this connector; may not be null. + */ + public abstract ConfigDef config(); +} diff --git a/connect/api/src/main/java/org/apache/kafka/connect/connector/ConnectorContext.java b/connect/api/src/main/java/org/apache/kafka/connect/connector/ConnectorContext.java new file mode 100644 index 0000000..3f98c6a --- /dev/null +++ b/connect/api/src/main/java/org/apache/kafka/connect/connector/ConnectorContext.java @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.connector; + +/** + * ConnectorContext allows Connectors to proactively interact with the Kafka Connect runtime. + */ +public interface ConnectorContext { + /** + * Requests that the runtime reconfigure the Tasks for this source. This should be used to + * indicate to the runtime that something about the input/output has changed (e.g. partitions + * added/removed) and the running Tasks will need to be modified. + */ + void requestTaskReconfiguration(); + + /** + * Raise an unrecoverable exception to the Connect framework. This will cause the status of the + * connector to transition to FAILED. + * @param e Exception to be raised. + */ + void raiseError(Exception e); +} diff --git a/connect/api/src/main/java/org/apache/kafka/connect/connector/Task.java b/connect/api/src/main/java/org/apache/kafka/connect/connector/Task.java new file mode 100644 index 0000000..42b87cd --- /dev/null +++ b/connect/api/src/main/java/org/apache/kafka/connect/connector/Task.java @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.connector; + +import java.util.Map; + +/** + *

            + * Tasks contain the code that actually copies data to/from another system. They receive + * a configuration from their parent Connector, assigning them a fraction of a Kafka Connect job's work. + * The Kafka Connect framework then pushes/pulls data from the Task. The Task must also be able to + * respond to reconfiguration requests. + *

            + *

            + * Task only contains the minimal shared functionality between + * {@link org.apache.kafka.connect.source.SourceTask} and + * {@link org.apache.kafka.connect.sink.SinkTask}. + *
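As an illustrative sketch of that lifecycle (the `tables` key is hypothetical), a task typically reads the configuration its parent connector produced in start() and releases resources in stop():

    import org.apache.kafka.connect.connector.Task;

    import java.util.Map;

    public class TableTask implements Task {
        private volatile String tables;

        @Override
        public String version() {
            return "0.1.0"; // usually the same as the parent Connector's version
        }

        @Override
        public void start(Map<String, String> props) {
            // props is one of the maps returned by the parent Connector's taskConfigs(int).
            tables = props.get("tables");
        }

        @Override
        public void stop() {
            tables = null; // release anything acquired in start()
        }
    }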

            + */ +public interface Task { + /** + * Get the version of this task. Usually this should be the same as the corresponding {@link Connector} class's version. + * + * @return the version, formatted as a String + */ + String version(); + + /** + * Start the Task + * @param props initial configuration + */ + void start(Map props); + + /** + * Stop this task. + */ + void stop(); +} diff --git a/connect/api/src/main/java/org/apache/kafka/connect/connector/policy/ConnectorClientConfigOverridePolicy.java b/connect/api/src/main/java/org/apache/kafka/connect/connector/policy/ConnectorClientConfigOverridePolicy.java new file mode 100644 index 0000000..94e5fd6 --- /dev/null +++ b/connect/api/src/main/java/org/apache/kafka/connect/connector/policy/ConnectorClientConfigOverridePolicy.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.connect.connector.policy; + +import org.apache.kafka.common.Configurable; +import org.apache.kafka.common.config.ConfigValue; + +import java.util.List; + +/** + *

            An interface for enforcing a policy on overriding of client configs via the connector configs. + * + *

            Common use cases are ability to provide principal per connector, sasl.jaas.config + * and/or enforcing that the producer/consumer configurations for optimizations are within acceptable ranges. + */ +public interface ConnectorClientConfigOverridePolicy extends Configurable, AutoCloseable { + + + /** + * Worker will invoke this while constructing the producer for the SourceConnectors, DLQ for SinkConnectors and the consumer for the + * SinkConnectors to validate if all of the overridden client configurations are allowed per the + * policy implementation. This would also be invoked during the validate of connector configs via the Rest API. + * + * If there are any policy violations, the connector will not be started. + * + * @param connectorClientConfigRequest an instance of {@code ConnectorClientConfigRequest} that provides the configs to overridden and + * its context; never {@code null} + * @return list of {@link ConfigValue} instances that describe each client configuration in the request and includes an + {@link ConfigValue#errorMessages error} if the configuration is not allowed by the policy; never null + */ + List validate(ConnectorClientConfigRequest connectorClientConfigRequest); +} diff --git a/connect/api/src/main/java/org/apache/kafka/connect/connector/policy/ConnectorClientConfigRequest.java b/connect/api/src/main/java/org/apache/kafka/connect/connector/policy/ConnectorClientConfigRequest.java new file mode 100644 index 0000000..4ee0b94 --- /dev/null +++ b/connect/api/src/main/java/org/apache/kafka/connect/connector/policy/ConnectorClientConfigRequest.java @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.connect.connector.policy; + +import org.apache.kafka.connect.connector.Connector; +import org.apache.kafka.connect.health.ConnectorType; + +import java.util.Map; + +public class ConnectorClientConfigRequest { + + private Map clientProps; + private ClientType clientType; + private String connectorName; + private ConnectorType connectorType; + private Class connectorClass; + + public ConnectorClientConfigRequest( + String connectorName, + ConnectorType connectorType, + Class connectorClass, + Map clientProps, + ClientType clientType) { + this.clientProps = clientProps; + this.clientType = clientType; + this.connectorName = connectorName; + this.connectorType = connectorType; + this.connectorClass = connectorClass; + } + + /** + * Provides Config with prefix {@code producer.override.} for {@link ConnectorType#SOURCE}. + * Provides Config with prefix {@code consumer.override.} for {@link ConnectorType#SINK}. + * Provides Config with prefix {@code producer.override.} for {@link ConnectorType#SINK} for DLQ. 
+ * Provides Config with prefix {@code admin.override.} for {@link ConnectorType#SINK} for DLQ. + * + * @return The client properties specified in the Connector Config with prefix {@code producer.override.} , + * {@code consumer.override.} and {@code admin.override.}. The configs don't include the prefixes. + */ + public Map clientProps() { + return clientProps; + } + + /** + * {@link ClientType#PRODUCER} for {@link ConnectorType#SOURCE} + * {@link ClientType#CONSUMER} for {@link ConnectorType#SINK} + * {@link ClientType#PRODUCER} for DLQ in {@link ConnectorType#SINK} + * {@link ClientType#ADMIN} for DLQ Topic Creation in {@link ConnectorType#SINK} + * + * @return enumeration specifying the client type that is being overriden by the worker; never null. + */ + public ClientType clientType() { + return clientType; + } + + /** + * Name of the connector specified in the connector config. + * + * @return name of the connector; never null. + */ + public String connectorName() { + return connectorName; + } + + /** + * Type of the Connector. + * + * @return enumeration specifying the type of the connector {@link ConnectorType#SINK} or {@link ConnectorType#SOURCE}. + */ + public ConnectorType connectorType() { + return connectorType; + } + + /** + * The class of the Connector. + * + * @return the class of the Connector being created; never null + */ + public Class connectorClass() { + return connectorClass; + } + + public enum ClientType { + PRODUCER, CONSUMER, ADMIN + } +} \ No newline at end of file diff --git a/connect/api/src/main/java/org/apache/kafka/connect/data/ConnectSchema.java b/connect/api/src/main/java/org/apache/kafka/connect/data/ConnectSchema.java new file mode 100644 index 0000000..6892bfc --- /dev/null +++ b/connect/api/src/main/java/org/apache/kafka/connect/data/ConnectSchema.java @@ -0,0 +1,350 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.data; + +import org.apache.kafka.connect.errors.DataException; + +import java.math.BigDecimal; +import java.nio.ByteBuffer; +import java.util.Arrays; +import java.util.Collections; +import java.util.EnumMap; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; + +public class ConnectSchema implements Schema { + /** + * Maps Schema.Types to a list of Java classes that can be used to represent them. + */ + private static final Map>> SCHEMA_TYPE_CLASSES = new EnumMap<>(Type.class); + /** + * Maps known logical types to a list of Java classes that can be used to represent them. + */ + private static final Map>> LOGICAL_TYPE_CLASSES = new HashMap<>(); + + /** + * Maps the Java classes to the corresponding Schema.Type. 
+ */ + private static final Map, Type> JAVA_CLASS_SCHEMA_TYPES = new HashMap<>(); + + static { + SCHEMA_TYPE_CLASSES.put(Type.INT8, Collections.singletonList(Byte.class)); + SCHEMA_TYPE_CLASSES.put(Type.INT16, Collections.singletonList(Short.class)); + SCHEMA_TYPE_CLASSES.put(Type.INT32, Collections.singletonList(Integer.class)); + SCHEMA_TYPE_CLASSES.put(Type.INT64, Collections.singletonList(Long.class)); + SCHEMA_TYPE_CLASSES.put(Type.FLOAT32, Collections.singletonList(Float.class)); + SCHEMA_TYPE_CLASSES.put(Type.FLOAT64, Collections.singletonList(Double.class)); + SCHEMA_TYPE_CLASSES.put(Type.BOOLEAN, Collections.singletonList(Boolean.class)); + SCHEMA_TYPE_CLASSES.put(Type.STRING, Collections.singletonList(String.class)); + // Bytes are special and have 2 representations. byte[] causes problems because it doesn't handle equals() and + // hashCode() like we want objects to, so we support both byte[] and ByteBuffer. Using plain byte[] can cause + // those methods to fail, so ByteBuffers are recommended + SCHEMA_TYPE_CLASSES.put(Type.BYTES, Arrays.asList(byte[].class, ByteBuffer.class)); + SCHEMA_TYPE_CLASSES.put(Type.ARRAY, Collections.singletonList(List.class)); + SCHEMA_TYPE_CLASSES.put(Type.MAP, Collections.singletonList(Map.class)); + SCHEMA_TYPE_CLASSES.put(Type.STRUCT, Collections.singletonList(Struct.class)); + + for (Map.Entry>> schemaClasses : SCHEMA_TYPE_CLASSES.entrySet()) { + for (Class schemaClass : schemaClasses.getValue()) + JAVA_CLASS_SCHEMA_TYPES.put(schemaClass, schemaClasses.getKey()); + } + + LOGICAL_TYPE_CLASSES.put(Decimal.LOGICAL_NAME, Collections.singletonList(BigDecimal.class)); + LOGICAL_TYPE_CLASSES.put(Date.LOGICAL_NAME, Collections.singletonList(java.util.Date.class)); + LOGICAL_TYPE_CLASSES.put(Time.LOGICAL_NAME, Collections.singletonList(java.util.Date.class)); + LOGICAL_TYPE_CLASSES.put(Timestamp.LOGICAL_NAME, Collections.singletonList(java.util.Date.class)); + // We don't need to put these into JAVA_CLASS_SCHEMA_TYPES since that's only used to determine schemas for + // schemaless data and logical types will have ambiguous schemas (e.g. many of them use the same Java class) so + // they should not be used without schemas. + } + + // The type of the field + private final Type type; + private final boolean optional; + private final Object defaultValue; + + private final List fields; + private final Map fieldsByName; + + private final Schema keySchema; + private final Schema valueSchema; + + // Optional name and version provide a built-in way to indicate what type of data is included. Most + // useful for structs to indicate the semantics of the struct and map it to some existing underlying + // serializer-specific schema. However, can also be useful in specifying other logical types (e.g. a set is an array + // with additional constraints). + private final String name; + private final Integer version; + // Optional human readable documentation describing this schema. + private final String doc; + private final Map parameters; + // precomputed hash code. There is no need to re-compute every time hashCode() is called. + private Integer hash = null; + + /** + * Construct a Schema. Most users should not construct schemas manually, preferring {@link SchemaBuilder} instead. 
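As a small illustration of that advice, the direct constructor call and the builder below should describe equal schemas; the class name is invented for the example:

    import org.apache.kafka.connect.data.ConnectSchema;
    import org.apache.kafka.connect.data.Schema;
    import org.apache.kafka.connect.data.SchemaBuilder;

    public class SchemaConstructionExample {
        public static void main(String[] args) {
            // An optional string described directly and via the builder.
            Schema direct = new ConnectSchema(Schema.Type.STRING, true, null, null, null, null);
            Schema built = SchemaBuilder.string().optional().build();
            System.out.println(direct.equals(built)); // true: same type, optionality, and metadata
        }
    }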
+ */ + public ConnectSchema(Type type, boolean optional, Object defaultValue, String name, Integer version, String doc, Map parameters, List fields, Schema keySchema, Schema valueSchema) { + this.type = type; + this.optional = optional; + this.defaultValue = defaultValue; + this.name = name; + this.version = version; + this.doc = doc; + this.parameters = parameters; + + if (this.type == Type.STRUCT) { + this.fields = fields == null ? Collections.emptyList() : fields; + this.fieldsByName = new HashMap<>(this.fields.size()); + for (Field field : this.fields) + fieldsByName.put(field.name(), field); + } else { + this.fields = null; + this.fieldsByName = null; + } + + this.keySchema = keySchema; + this.valueSchema = valueSchema; + } + + /** + * Construct a Schema for a primitive type, setting schema parameters, struct fields, and key and value schemas to null. + */ + public ConnectSchema(Type type, boolean optional, Object defaultValue, String name, Integer version, String doc) { + this(type, optional, defaultValue, name, version, doc, null, null, null, null); + } + + /** + * Construct a default schema for a primitive type. The schema is required, has no default value, name, version, + * or documentation. + */ + public ConnectSchema(Type type) { + this(type, false, null, null, null, null); + } + + @Override + public Type type() { + return type; + } + + @Override + public boolean isOptional() { + return optional; + } + + @Override + public Object defaultValue() { + return defaultValue; + } + + @Override + public String name() { + return name; + } + + @Override + public Integer version() { + return version; + } + + @Override + public String doc() { + return doc; + } + + @Override + public Map parameters() { + return parameters; + } + + @Override + public List fields() { + if (type != Type.STRUCT) + throw new DataException("Cannot list fields on non-struct type"); + return fields; + } + + @Override + public Field field(String fieldName) { + if (type != Type.STRUCT) + throw new DataException("Cannot look up fields on non-struct type"); + return fieldsByName.get(fieldName); + } + + @Override + public Schema keySchema() { + if (type != Type.MAP) + throw new DataException("Cannot look up key schema on non-map type"); + return keySchema; + } + + @Override + public Schema valueSchema() { + if (type != Type.MAP && type != Type.ARRAY) + throw new DataException("Cannot look up value schema on non-array and non-map type"); + return valueSchema; + } + + + + /** + * Validate that the value can be used with the schema, i.e. that its type matches the schema type and nullability + * requirements. Throws a DataException if the value is invalid. 
+ * @param schema Schema to test + * @param value value to test + */ + public static void validateValue(Schema schema, Object value) { + validateValue(null, schema, value); + } + + public static void validateValue(String name, Schema schema, Object value) { + if (value == null) { + if (!schema.isOptional()) + throw new DataException("Invalid value: null used for required field: \"" + name + + "\", schema type: " + schema.type()); + return; + } + + List> expectedClasses = expectedClassesFor(schema); + boolean foundMatch = false; + for (Class expectedClass : expectedClasses) { + if (expectedClass.isInstance(value)) { + foundMatch = true; + break; + } + } + + if (!foundMatch) { + StringBuilder exceptionMessage = new StringBuilder("Invalid Java object for schema"); + if (schema.name() != null) { + exceptionMessage.append(" \"").append(schema.name()).append("\""); + } + exceptionMessage.append(" with type ").append(schema.type()).append(": ").append(value.getClass()); + if (name != null) { + exceptionMessage.append(" for field: \"").append(name).append("\""); + } + throw new DataException(exceptionMessage.toString()); + } + + switch (schema.type()) { + case STRUCT: + Struct struct = (Struct) value; + if (!struct.schema().equals(schema)) + throw new DataException("Struct schemas do not match."); + struct.validate(); + break; + case ARRAY: + List array = (List) value; + for (Object entry : array) + validateValue(schema.valueSchema(), entry); + break; + case MAP: + Map map = (Map) value; + for (Map.Entry entry : map.entrySet()) { + validateValue(schema.keySchema(), entry.getKey()); + validateValue(schema.valueSchema(), entry.getValue()); + } + break; + } + } + + private static List> expectedClassesFor(Schema schema) { + List> expectedClasses = LOGICAL_TYPE_CLASSES.get(schema.name()); + if (expectedClasses == null) + expectedClasses = SCHEMA_TYPE_CLASSES.getOrDefault(schema.type(), Collections.emptyList()); + return expectedClasses; + } + + /** + * Validate that the value can be used for this schema, i.e. that its type matches the schema type and optional + * requirements. Throws a DataException if the value is invalid. + * @param value the value to validate + */ + public void validateValue(Object value) { + validateValue(this, value); + } + + @Override + public ConnectSchema schema() { + return this; + } + + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + ConnectSchema schema = (ConnectSchema) o; + return Objects.equals(optional, schema.optional) && + Objects.equals(version, schema.version) && + Objects.equals(name, schema.name) && + Objects.equals(doc, schema.doc) && + Objects.equals(type, schema.type) && + Objects.deepEquals(defaultValue, schema.defaultValue) && + Objects.equals(fields, schema.fields) && + Objects.equals(keySchema, schema.keySchema) && + Objects.equals(valueSchema, schema.valueSchema) && + Objects.equals(parameters, schema.parameters); + } + + @Override + public int hashCode() { + if (this.hash == null) { + this.hash = Objects.hash(type, optional, defaultValue, fields, keySchema, valueSchema, name, version, doc, + parameters); + } + return this.hash; + } + + @Override + public String toString() { + if (name != null) + return "Schema{" + name + ":" + type + "}"; + else + return "Schema{" + type + "}"; + } + + + /** + * Get the {@link Schema.Type} associated with the given class. 
+ * + * @param klass the Class to + * @return the corresponding type, or null if there is no matching type + */ + public static Type schemaType(Class klass) { + synchronized (JAVA_CLASS_SCHEMA_TYPES) { + Type schemaType = JAVA_CLASS_SCHEMA_TYPES.get(klass); + if (schemaType != null) + return schemaType; + + // Since the lookup only checks the class, we need to also try + for (Map.Entry, Type> entry : JAVA_CLASS_SCHEMA_TYPES.entrySet()) { + try { + klass.asSubclass(entry.getKey()); + // Cache this for subsequent lookups + JAVA_CLASS_SCHEMA_TYPES.put(klass, entry.getValue()); + return entry.getValue(); + } catch (ClassCastException e) { + // Expected, ignore + } + } + } + return null; + } +} diff --git a/connect/api/src/main/java/org/apache/kafka/connect/data/Date.java b/connect/api/src/main/java/org/apache/kafka/connect/data/Date.java new file mode 100644 index 0000000..e87188c --- /dev/null +++ b/connect/api/src/main/java/org/apache/kafka/connect/data/Date.java @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.data; + +import org.apache.kafka.connect.errors.DataException; + +import java.util.Calendar; +import java.util.TimeZone; + +/** + *

            + * A date representing a calendar day with no time of day or timezone. The corresponding Java type is a java.util.Date + * with hours, minutes, seconds, milliseconds set to 0. The underlying representation is an integer representing the + * number of standardized days (based on a number of milliseconds with 24 hours/day, 60 minutes/hour, 60 seconds/minute, + * 1000 milliseconds/second, with no leap seconds) since Unix epoch. + *
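An illustrative round trip through this logical type; the particular calendar date is arbitrary:

    import org.apache.kafka.connect.data.Date;
    import org.apache.kafka.connect.data.Schema;

    import java.util.Calendar;
    import java.util.TimeZone;

    public class DateExample {
        public static void main(String[] args) {
            Schema schema = Date.SCHEMA;

            Calendar calendar = Calendar.getInstance(TimeZone.getTimeZone("UTC"));
            calendar.clear();                      // all time-of-day fields must be zero
            calendar.set(2015, Calendar.APRIL, 4); // midnight UTC on 2015-04-04

            int days = Date.fromLogical(schema, calendar.getTime());
            java.util.Date roundTrip = Date.toLogical(schema, days);

            System.out.println(days);                                 // days since the Unix epoch
            System.out.println(roundTrip.equals(calendar.getTime())); // true
        }
    }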

            + */ +public class Date { + public static final String LOGICAL_NAME = "org.apache.kafka.connect.data.Date"; + + private static final long MILLIS_PER_DAY = 24 * 60 * 60 * 1000; + + private static final TimeZone UTC = TimeZone.getTimeZone("UTC"); + + /** + * Returns a SchemaBuilder for a Date. By returning a SchemaBuilder you can override additional schema settings such + * as required/optional, default value, and documentation. + * @return a SchemaBuilder + */ + public static SchemaBuilder builder() { + return SchemaBuilder.int32() + .name(LOGICAL_NAME) + .version(1); + } + + public static final Schema SCHEMA = builder().schema(); + + /** + * Convert a value from its logical format (Date) to it's encoded format. + * @param value the logical value + * @return the encoded value + */ + public static int fromLogical(Schema schema, java.util.Date value) { + if (!(LOGICAL_NAME.equals(schema.name()))) + throw new DataException("Requested conversion of Date object but the schema does not match."); + Calendar calendar = Calendar.getInstance(UTC); + calendar.setTime(value); + if (calendar.get(Calendar.HOUR_OF_DAY) != 0 || calendar.get(Calendar.MINUTE) != 0 || + calendar.get(Calendar.SECOND) != 0 || calendar.get(Calendar.MILLISECOND) != 0) { + throw new DataException("Kafka Connect Date type should not have any time fields set to non-zero values."); + } + long unixMillis = calendar.getTimeInMillis(); + return (int) (unixMillis / MILLIS_PER_DAY); + } + + public static java.util.Date toLogical(Schema schema, int value) { + if (!(LOGICAL_NAME.equals(schema.name()))) + throw new DataException("Requested conversion of Date object but the schema does not match."); + return new java.util.Date(value * MILLIS_PER_DAY); + } +} diff --git a/connect/api/src/main/java/org/apache/kafka/connect/data/Decimal.java b/connect/api/src/main/java/org/apache/kafka/connect/data/Decimal.java new file mode 100644 index 0000000..a0a8895 --- /dev/null +++ b/connect/api/src/main/java/org/apache/kafka/connect/data/Decimal.java @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.data; + +import org.apache.kafka.connect.errors.DataException; + +import java.math.BigDecimal; +import java.math.BigInteger; + +/** + *

            + * An arbitrary-precision signed decimal number. The value is unscaled * 10 ^ -scale where:
            + * <ul>
            + *     <li>unscaled is an integer</li>
            + *     <li>scale is an integer representing how many digits the decimal point should be shifted on the unscaled value</li>
            + * </ul>
            + *
            + * Decimal does not provide a fixed schema because it is parameterized by the scale, which is fixed on the schema
            + * rather than being part of the value.
            + *
            + * The underlying representation of this type is bytes containing a two's complement integer.
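An illustrative encode/decode round trip; the scale of 2 is arbitrary, and the value's scale must match the scale recorded in the schema:

    import org.apache.kafka.connect.data.Decimal;
    import org.apache.kafka.connect.data.Schema;

    import java.math.BigDecimal;

    public class DecimalExample {
        public static void main(String[] args) {
            Schema schema = Decimal.schema(2);           // scale is carried by the schema
            BigDecimal price = new BigDecimal("12.34");  // scale 2, matching the schema

            byte[] encoded = Decimal.fromLogical(schema, price);  // unscaled value 1234 as two's complement bytes
            BigDecimal decoded = Decimal.toLogical(schema, encoded);

            System.out.println(decoded);               // 12.34
            System.out.println(decoded.equals(price)); // true
        }
    }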

            + */ +public class Decimal { + public static final String LOGICAL_NAME = "org.apache.kafka.connect.data.Decimal"; + public static final String SCALE_FIELD = "scale"; + + /** + * Returns a SchemaBuilder for a Decimal with the given scale factor. By returning a SchemaBuilder you can override + * additional schema settings such as required/optional, default value, and documentation. + * @param scale the scale factor to apply to unscaled values + * @return a SchemaBuilder + */ + public static SchemaBuilder builder(int scale) { + return SchemaBuilder.bytes() + .name(LOGICAL_NAME) + .parameter(SCALE_FIELD, Integer.toString(scale)) + .version(1); + } + + public static Schema schema(int scale) { + return builder(scale).build(); + } + + /** + * Convert a value from its logical format (BigDecimal) to it's encoded format. + * @param value the logical value + * @return the encoded value + */ + public static byte[] fromLogical(Schema schema, BigDecimal value) { + int schemaScale = scale(schema); + if (value.scale() != schemaScale) + throw new DataException(String.format( + "Decimal value has mismatching scale for given Decimal schema. " + + "Schema has scale %d, value has scale %d.", + schemaScale, + value.scale() + )); + return value.unscaledValue().toByteArray(); + } + + public static BigDecimal toLogical(Schema schema, byte[] value) { + return new BigDecimal(new BigInteger(value), scale(schema)); + } + + private static int scale(Schema schema) { + String scaleString = schema.parameters().get(SCALE_FIELD); + if (scaleString == null) + throw new DataException("Invalid Decimal schema: scale parameter not found."); + try { + return Integer.parseInt(scaleString); + } catch (NumberFormatException e) { + throw new DataException("Invalid scale parameter found in Decimal schema: ", e); + } + } +} diff --git a/connect/api/src/main/java/org/apache/kafka/connect/data/Field.java b/connect/api/src/main/java/org/apache/kafka/connect/data/Field.java new file mode 100644 index 0000000..b5d3f02 --- /dev/null +++ b/connect/api/src/main/java/org/apache/kafka/connect/data/Field.java @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.data; + +import java.util.Objects; + +/** + *

            + * A field in a {@link Struct}, consisting of a field name, index, and {@link Schema} for the field value. + *

            + */ +public class Field { + private final String name; + private final int index; + private final Schema schema; + + public Field(String name, int index, Schema schema) { + this.name = name; + this.index = index; + this.schema = schema; + } + + /** + * Get the name of this field. + * @return the name of this field + */ + public String name() { + return name; + } + + + /** + * Get the index of this field within the struct. + * @return the index of this field + */ + public int index() { + return index; + } + + /** + * Get the schema of this field + * @return the schema of values of this field + */ + public Schema schema() { + return schema; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + Field field = (Field) o; + return Objects.equals(index, field.index) && + Objects.equals(name, field.name) && + Objects.equals(schema, field.schema); + } + + @Override + public int hashCode() { + return Objects.hash(name, index, schema); + } + + @Override + public String toString() { + return "Field{" + + "name=" + name + + ", index=" + index + + ", schema=" + schema + + "}"; + } +} diff --git a/connect/api/src/main/java/org/apache/kafka/connect/data/Schema.java b/connect/api/src/main/java/org/apache/kafka/connect/data/Schema.java new file mode 100644 index 0000000..c234217 --- /dev/null +++ b/connect/api/src/main/java/org/apache/kafka/connect/data/Schema.java @@ -0,0 +1,222 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.data; + +import java.util.List; +import java.util.Locale; +import java.util.Map; + +/** + *

            + * Definition of an abstract data type. Data types can be primitive types (integer types, floating point types, + * boolean, strings, and bytes) or complex types (typed arrays, maps with one key schema and value schema, + * and structs that have a fixed set of field names each with an associated value schema). Any type can be specified + * as optional, allowing it to be omitted (resulting in null values when it is missing) and can specify a default + * value. + *

            + *

            + * All schemas may have some associated metadata: a name, version, and documentation. These are all considered part + * of the schema itself and included when comparing schemas. Besides adding important metadata, these fields enable + * the specification of logical types that specify additional constraints and semantics (e.g. UNIX timestamps are + * just an int64, but the user needs to know about the additional semantics to interpret it properly). + *

            + *

            + * Schemas can be created directly, but in most cases using {@link SchemaBuilder} will be simpler. + *
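For illustration, that metadata can be inspected on the predefined schemas and on a logical type such as Date:

    import org.apache.kafka.connect.data.Date;
    import org.apache.kafka.connect.data.Schema;

    public class SchemaMetadataExample {
        public static void main(String[] args) {
            System.out.println(Schema.OPTIONAL_INT32_SCHEMA.type());       // INT32
            System.out.println(Schema.OPTIONAL_INT32_SCHEMA.isOptional()); // true

            // A logical type is a plain type plus a name/version that carries the extra semantics.
            Schema date = Date.SCHEMA;
            System.out.println(date.type());    // INT32
            System.out.println(date.name());    // org.apache.kafka.connect.data.Date
            System.out.println(date.version()); // 1
        }
    }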

            + */ +public interface Schema { + /** + * The type of a schema. These only include the core types; logical types must be determined by checking the schema name. + */ + enum Type { + /** + * 8-bit signed integer + * + * Note that if you have an unsigned 8-bit data source, {@link Type#INT16} will be required to safely capture all valid values + */ + INT8, + /** + * 16-bit signed integer + * + * Note that if you have an unsigned 16-bit data source, {@link Type#INT32} will be required to safely capture all valid values + */ + INT16, + /** + * 32-bit signed integer + * + * Note that if you have an unsigned 32-bit data source, {@link Type#INT64} will be required to safely capture all valid values + */ + INT32, + /** + * 64-bit signed integer + * + * Note that if you have an unsigned 64-bit data source, the {@link Decimal} logical type (encoded as {@link Type#BYTES}) + * will be required to safely capture all valid values + */ + INT64, + /** + * 32-bit IEEE 754 floating point number + */ + FLOAT32, + /** + * 64-bit IEEE 754 floating point number + */ + FLOAT64, + /** + * Boolean value (true or false) + */ + BOOLEAN, + /** + * Character string that supports all Unicode characters. + * + * Note that this does not imply any specific encoding (e.g. UTF-8) as this is an in-memory representation. + */ + STRING, + /** + * Sequence of unsigned 8-bit bytes + */ + BYTES, + /** + * An ordered sequence of elements, each of which shares the same type. + */ + ARRAY, + /** + * A mapping from keys to values. Both keys and values can be arbitrarily complex types, including complex types + * such as {@link Struct}. + */ + MAP, + /** + * A structured record containing a set of named fields, each field using a fixed, independent {@link Schema}. + */ + STRUCT; + + private String name; + + Type() { + this.name = this.name().toLowerCase(Locale.ROOT); + } + + public String getName() { + return name; + } + + public boolean isPrimitive() { + switch (this) { + case INT8: + case INT16: + case INT32: + case INT64: + case FLOAT32: + case FLOAT64: + case BOOLEAN: + case STRING: + case BYTES: + return true; + } + return false; + } + } + + + Schema INT8_SCHEMA = SchemaBuilder.int8().build(); + Schema INT16_SCHEMA = SchemaBuilder.int16().build(); + Schema INT32_SCHEMA = SchemaBuilder.int32().build(); + Schema INT64_SCHEMA = SchemaBuilder.int64().build(); + Schema FLOAT32_SCHEMA = SchemaBuilder.float32().build(); + Schema FLOAT64_SCHEMA = SchemaBuilder.float64().build(); + Schema BOOLEAN_SCHEMA = SchemaBuilder.bool().build(); + Schema STRING_SCHEMA = SchemaBuilder.string().build(); + Schema BYTES_SCHEMA = SchemaBuilder.bytes().build(); + + Schema OPTIONAL_INT8_SCHEMA = SchemaBuilder.int8().optional().build(); + Schema OPTIONAL_INT16_SCHEMA = SchemaBuilder.int16().optional().build(); + Schema OPTIONAL_INT32_SCHEMA = SchemaBuilder.int32().optional().build(); + Schema OPTIONAL_INT64_SCHEMA = SchemaBuilder.int64().optional().build(); + Schema OPTIONAL_FLOAT32_SCHEMA = SchemaBuilder.float32().optional().build(); + Schema OPTIONAL_FLOAT64_SCHEMA = SchemaBuilder.float64().optional().build(); + Schema OPTIONAL_BOOLEAN_SCHEMA = SchemaBuilder.bool().optional().build(); + Schema OPTIONAL_STRING_SCHEMA = SchemaBuilder.string().optional().build(); + Schema OPTIONAL_BYTES_SCHEMA = SchemaBuilder.bytes().optional().build(); + + /** + * @return the type of this schema + */ + Type type(); + + /** + * @return true if this field is optional, false otherwise + */ + boolean isOptional(); + + /** + * @return the default value for this schema 
+ */ + Object defaultValue(); + + /** + * @return the name of this schema + */ + String name(); + + /** + * Get the optional version of the schema. If a version is included, newer versions *must* be larger than older ones. + * @return the version of this schema + */ + Integer version(); + + /** + * @return the documentation for this schema + */ + String doc(); + + /** + * Get a map of schema parameters. + * @return Map containing parameters for this schema, or null if there are no parameters + */ + Map parameters(); + + /** + * Get the key schema for this map schema. Throws a DataException if this schema is not a map. + * @return the key schema + */ + Schema keySchema(); + + /** + * Get the value schema for this map or array schema. Throws a DataException if this schema is not a map or array. + * @return the value schema + */ + Schema valueSchema(); + + /** + * Get the list of fields for this Schema. Throws a DataException if this schema is not a struct. + * @return the list of fields for this Schema + */ + List fields(); + + /** + * Get a field for this Schema by name. Throws a DataException if this schema is not a struct. + * @param fieldName the name of the field to look up + * @return the Field object for the specified field, or null if there is no field with the given name + */ + Field field(String fieldName); + + /** + * Return a concrete instance of the {@link Schema} + * @return the {@link Schema} + */ + Schema schema(); +} diff --git a/connect/api/src/main/java/org/apache/kafka/connect/data/SchemaAndValue.java b/connect/api/src/main/java/org/apache/kafka/connect/data/SchemaAndValue.java new file mode 100644 index 0000000..b9a539e --- /dev/null +++ b/connect/api/src/main/java/org/apache/kafka/connect/data/SchemaAndValue.java @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.connect.data; + +import java.util.Objects; + +public class SchemaAndValue { + private final Schema schema; + private final Object value; + + public static final SchemaAndValue NULL = new SchemaAndValue(null, null); + + public SchemaAndValue(Schema schema, Object value) { + this.value = value; + this.schema = schema; + } + + public Schema schema() { + return schema; + } + + public Object value() { + return value; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + SchemaAndValue that = (SchemaAndValue) o; + return Objects.equals(schema, that.schema) && + Objects.equals(value, that.value); + } + + @Override + public int hashCode() { + return Objects.hash(schema, value); + } + + @Override + public String toString() { + return "SchemaAndValue{" + + "schema=" + schema + + ", value=" + value + + '}'; + } +} diff --git a/connect/api/src/main/java/org/apache/kafka/connect/data/SchemaBuilder.java b/connect/api/src/main/java/org/apache/kafka/connect/data/SchemaBuilder.java new file mode 100644 index 0000000..9b4be4f --- /dev/null +++ b/connect/api/src/main/java/org/apache/kafka/connect/data/SchemaBuilder.java @@ -0,0 +1,444 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.data; + +import org.apache.kafka.connect.errors.DataException; +import org.apache.kafka.connect.errors.SchemaBuilderException; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; + +/** + *

            + * SchemaBuilder provides a fluent API for constructing {@link Schema} objects. It allows you to set each of the + * properties for the schema and each call returns the SchemaBuilder so the calls can be chained. When nested types + * are required, use one of the predefined schemas from {@link Schema} or use a second SchemaBuilder inline. + *

            + *

            + * Here is an example of building a struct schema:
            + * <pre>
            + *     Schema dateSchema = SchemaBuilder.struct()
            + *         .name("com.example.CalendarDate").version(2).doc("A calendar date including month, day, and year.")
            + *         .field("month", Schema.STRING_SCHEMA)
            + *         .field("day", Schema.INT8_SCHEMA)
            + *         .field("year", Schema.INT16_SCHEMA)
            + *         .build();
            + *     </pre>
            + *
            + * Here is an example of using a second SchemaBuilder to construct complex, nested types:
            + * <pre>
            + *     Schema userListSchema = SchemaBuilder.array(
            + *         SchemaBuilder.struct().name("com.example.User").field("username", Schema.STRING_SCHEMA).field("id", Schema.INT64_SCHEMA).build()
            + *     ).build();
            + *     </pre>
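Building on the examples above, an illustrative sketch of pairing a built schema with a Struct value; the field names are invented for the example:

    import org.apache.kafka.connect.data.Schema;
    import org.apache.kafka.connect.data.SchemaBuilder;
    import org.apache.kafka.connect.data.Struct;

    public class StructExample {
        public static void main(String[] args) {
            Schema userSchema = SchemaBuilder.struct().name("com.example.User")
                .field("username", Schema.STRING_SCHEMA)
                .field("id", Schema.INT64_SCHEMA)
                .field("nickname", SchemaBuilder.string().optional().build())
                .build();

            Struct user = new Struct(userSchema)
                .put("username", "alice")
                .put("id", 42L);       // "nickname" is optional and may be left unset

            user.validate();           // throws DataException if a required field is missing
            System.out.println(user.getString("username")); // alice
        }
    }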

            + */ +public class SchemaBuilder implements Schema { + private static final String TYPE_FIELD = "type"; + private static final String OPTIONAL_FIELD = "optional"; + private static final String DEFAULT_FIELD = "default"; + private static final String NAME_FIELD = "name"; + private static final String VERSION_FIELD = "version"; + private static final String DOC_FIELD = "doc"; + + + private final Type type; + private Boolean optional = null; + private Object defaultValue = null; + + private Map fields = null; + private Schema keySchema = null; + private Schema valueSchema = null; + + private String name; + private Integer version; + // Optional human readable documentation describing this schema. + private String doc; + // Additional parameters for logical types. + private Map parameters; + + public SchemaBuilder(Type type) { + if (null == type) + throw new SchemaBuilderException("type cannot be null"); + this.type = type; + if (type == Type.STRUCT) { + fields = new LinkedHashMap<>(); + } + } + + // Common/metadata fields + + @Override + public boolean isOptional() { + return optional == null ? false : optional; + } + + /** + * Set this schema as optional. + * @return the SchemaBuilder + */ + public SchemaBuilder optional() { + checkCanSet(OPTIONAL_FIELD, optional, true); + optional = true; + return this; + } + + /** + * Set this schema as required. This is the default, but this method can be used to make this choice explicit. + * @return the SchemaBuilder + */ + public SchemaBuilder required() { + checkCanSet(OPTIONAL_FIELD, optional, false); + optional = false; + return this; + } + + @Override + public Object defaultValue() { + return defaultValue; + } + + /** + * Set the default value for this schema. The value is validated against the schema type, throwing a + * {@link SchemaBuilderException} if it does not match. + * @param value the default value + * @return the SchemaBuilder + */ + public SchemaBuilder defaultValue(Object value) { + checkCanSet(DEFAULT_FIELD, defaultValue, value); + checkNotNull(TYPE_FIELD, type, DEFAULT_FIELD); + try { + ConnectSchema.validateValue(this, value); + } catch (DataException e) { + throw new SchemaBuilderException("Invalid default value", e); + } + defaultValue = value; + return this; + } + + @Override + public String name() { + return name; + } + + /** + * Set the name of this schema. + * @param name the schema name + * @return the SchemaBuilder + */ + public SchemaBuilder name(String name) { + checkCanSet(NAME_FIELD, this.name, name); + this.name = name; + return this; + } + + @Override + public Integer version() { + return version; + } + + /** + * Set the version of this schema. Schema versions are integers which, if provided, must indicate which schema is + * newer and which is older by their ordering. + * @param version the schema version + * @return the SchemaBuilder + */ + public SchemaBuilder version(Integer version) { + checkCanSet(VERSION_FIELD, this.version, version); + this.version = version; + return this; + } + + @Override + public String doc() { + return doc; + } + + /** + * Set the documentation for this schema. + * @param doc the documentation + * @return the SchemaBuilder + */ + public SchemaBuilder doc(String doc) { + checkCanSet(DOC_FIELD, this.doc, doc); + this.doc = doc; + return this; + } + + @Override + public Map parameters() { + return parameters == null ? null : Collections.unmodifiableMap(parameters); + } + + /** + * Set a schema parameter. 
+ * @param propertyName name of the schema property to define + * @param propertyValue value of the schema property to define, as a String + * @return the SchemaBuilder + */ + public SchemaBuilder parameter(String propertyName, String propertyValue) { + // Preserve order of insertion with a LinkedHashMap. This isn't strictly necessary, but is nice if logical types + // can print their properties in a consistent order. + if (parameters == null) + parameters = new LinkedHashMap<>(); + parameters.put(propertyName, propertyValue); + return this; + } + + /** + * Set schema parameters. This operation is additive; it does not remove existing parameters that do not appear in + * the set of properties pass to this method. + * @param props Map of properties to set + * @return the SchemaBuilder + */ + public SchemaBuilder parameters(Map props) { + // Avoid creating an empty set of properties so we never have an empty map + if (props.isEmpty()) + return this; + if (parameters == null) + parameters = new LinkedHashMap<>(); + parameters.putAll(props); + return this; + } + + @Override + public Type type() { + return type; + } + + /** + * Create a SchemaBuilder for the specified type. + * + * Usually it will be simpler to use one of the variants like {@link #string()} or {@link #struct()}, but this form + * can be useful when generating schemas dynamically. + * + * @param type the schema type + * @return a new SchemaBuilder + */ + public static SchemaBuilder type(Type type) { + return new SchemaBuilder(type); + } + + // Primitive types + + /** + * @return a new {@link Schema.Type#INT8} SchemaBuilder + */ + public static SchemaBuilder int8() { + return new SchemaBuilder(Type.INT8); + } + + /** + * @return a new {@link Schema.Type#INT16} SchemaBuilder + */ + public static SchemaBuilder int16() { + return new SchemaBuilder(Type.INT16); + } + + /** + * @return a new {@link Schema.Type#INT32} SchemaBuilder + */ + public static SchemaBuilder int32() { + return new SchemaBuilder(Type.INT32); + } + + /** + * @return a new {@link Schema.Type#INT64} SchemaBuilder + */ + public static SchemaBuilder int64() { + return new SchemaBuilder(Type.INT64); + } + + /** + * @return a new {@link Schema.Type#FLOAT32} SchemaBuilder + */ + public static SchemaBuilder float32() { + return new SchemaBuilder(Type.FLOAT32); + } + + /** + * @return a new {@link Schema.Type#FLOAT64} SchemaBuilder + */ + public static SchemaBuilder float64() { + return new SchemaBuilder(Type.FLOAT64); + } + + /** + * @return a new {@link Schema.Type#BOOLEAN} SchemaBuilder + */ + public static SchemaBuilder bool() { + return new SchemaBuilder(Type.BOOLEAN); + } + + /** + * @return a new {@link Schema.Type#STRING} SchemaBuilder + */ + public static SchemaBuilder string() { + return new SchemaBuilder(Type.STRING); + } + + /** + * @return a new {@link Schema.Type#BYTES} SchemaBuilder + */ + public static SchemaBuilder bytes() { + return new SchemaBuilder(Type.BYTES); + } + + + // Structs + + /** + * @return a new {@link Schema.Type#STRUCT} SchemaBuilder + */ + public static SchemaBuilder struct() { + return new SchemaBuilder(Type.STRUCT); + } + + /** + * Add a field to this struct schema. Throws a SchemaBuilderException if this is not a struct schema. 
+ * @param fieldName the name of the field to add + * @param fieldSchema the Schema for the field's value + * @return the SchemaBuilder + */ + public SchemaBuilder field(String fieldName, Schema fieldSchema) { + if (type != Type.STRUCT) + throw new SchemaBuilderException("Cannot create fields on type " + type); + if (null == fieldName || fieldName.isEmpty()) + throw new SchemaBuilderException("fieldName cannot be null."); + if (null == fieldSchema) + throw new SchemaBuilderException("fieldSchema for field " + fieldName + " cannot be null."); + int fieldIndex = fields.size(); + if (fields.containsKey(fieldName)) + throw new SchemaBuilderException("Cannot create field because of field name duplication " + fieldName); + fields.put(fieldName, new Field(fieldName, fieldIndex, fieldSchema)); + return this; + } + + /** + * Get the list of fields for this Schema. Throws a DataException if this schema is not a struct. + * @return the list of fields for this Schema + */ + @Override + public List fields() { + if (type != Type.STRUCT) + throw new DataException("Cannot list fields on non-struct type"); + return new ArrayList<>(fields.values()); + } + + @Override + public Field field(String fieldName) { + if (type != Type.STRUCT) + throw new DataException("Cannot look up fields on non-struct type"); + return fields.get(fieldName); + } + + + + // Maps & Arrays + + /** + * @param valueSchema the schema for elements of the array + * @return a new {@link Schema.Type#ARRAY} SchemaBuilder + */ + public static SchemaBuilder array(Schema valueSchema) { + if (null == valueSchema) + throw new SchemaBuilderException("valueSchema cannot be null."); + SchemaBuilder builder = new SchemaBuilder(Type.ARRAY); + builder.valueSchema = valueSchema; + return builder; + } + + /** + * @param keySchema the schema for keys in the map + * @param valueSchema the schema for values in the map + * @return a new {@link Schema.Type#MAP} SchemaBuilder + */ + public static SchemaBuilder map(Schema keySchema, Schema valueSchema) { + if (null == keySchema) + throw new SchemaBuilderException("keySchema cannot be null."); + if (null == valueSchema) + throw new SchemaBuilderException("valueSchema cannot be null."); + SchemaBuilder builder = new SchemaBuilder(Type.MAP); + builder.keySchema = keySchema; + builder.valueSchema = valueSchema; + return builder; + } + + static SchemaBuilder arrayOfNull() { + return new SchemaBuilder(Type.ARRAY); + } + + static SchemaBuilder mapOfNull() { + return new SchemaBuilder(Type.MAP); + } + + static SchemaBuilder mapWithNullKeys(Schema valueSchema) { + SchemaBuilder result = new SchemaBuilder(Type.MAP); + result.valueSchema = valueSchema; + return result; + } + + static SchemaBuilder mapWithNullValues(Schema keySchema) { + SchemaBuilder result = new SchemaBuilder(Type.MAP); + result.keySchema = keySchema; + return result; + } + + @Override + public Schema keySchema() { + return keySchema; + } + + @Override + public Schema valueSchema() { + return valueSchema; + } + + + /** + * Build the Schema using the current settings + * @return the {@link Schema} + */ + public Schema build() { + return new ConnectSchema(type, isOptional(), defaultValue, name, version, doc, + parameters == null ? null : Collections.unmodifiableMap(parameters), + fields == null ? 
null : Collections.unmodifiableList(new ArrayList<>(fields.values())), keySchema, valueSchema); + } + + /** + * Return a concrete instance of the {@link Schema} specified by this builder + * @return the {@link Schema} + */ + @Override + public Schema schema() { + return build(); + } + + private static void checkCanSet(String fieldName, Object fieldVal, Object val) { + if (fieldVal != null && fieldVal != val) + throw new SchemaBuilderException("Invalid SchemaBuilder call: " + fieldName + " has already been set."); + } + + private static void checkNotNull(String fieldName, Object val, String fieldToSet) { + if (val == null) + throw new SchemaBuilderException("Invalid SchemaBuilder call: " + fieldName + " must be specified to set " + fieldToSet); + } +} \ No newline at end of file diff --git a/connect/api/src/main/java/org/apache/kafka/connect/data/SchemaProjector.java b/connect/api/src/main/java/org/apache/kafka/connect/data/SchemaProjector.java new file mode 100644 index 0000000..5400705 --- /dev/null +++ b/connect/api/src/main/java/org/apache/kafka/connect/data/SchemaProjector.java @@ -0,0 +1,196 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.data; + +import org.apache.kafka.connect.data.Schema.Type; +import org.apache.kafka.connect.errors.SchemaProjectorException; + +import java.util.AbstractMap; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Set; + +/** + *
+ * SchemaProjector is a utility to project a value between compatible schemas, throwing an exception
+ * when incompatible schemas are provided.
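+ * <pre>
+ *     // Illustrative sketch (not in the original javadoc): projecting a null value written with an
+ *     // optional source schema onto a target schema that supplies a default value. The optional() and
+ *     // defaultValue() calls are SchemaBuilder methods defined elsewhere in that class.
+ *     Schema source = SchemaBuilder.int32().optional().build();
+ *     Schema target = SchemaBuilder.int32().defaultValue(42).build();
+ *     Object projected = SchemaProjector.project(source, null, target);   // yields 42, the target default
+ * </pre>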
            + */ + +public class SchemaProjector { + + private static Set> promotable = new HashSet<>(); + + static { + Type[] promotableTypes = {Type.INT8, Type.INT16, Type.INT32, Type.INT64, Type.FLOAT32, Type.FLOAT64}; + for (int i = 0; i < promotableTypes.length; ++i) { + for (int j = i; j < promotableTypes.length; ++j) { + promotable.add(new AbstractMap.SimpleImmutableEntry<>(promotableTypes[i], promotableTypes[j])); + } + } + } + + /** + * This method project a value between compatible schemas and throw exceptions when non compatible schemas are provided + * @param source the schema used to construct the record + * @param record the value to project from source schema to target schema + * @param target the schema to project the record to + * @return the projected value with target schema + * @throws SchemaProjectorException + */ + public static Object project(Schema source, Object record, Schema target) throws SchemaProjectorException { + checkMaybeCompatible(source, target); + if (source.isOptional() && !target.isOptional()) { + if (target.defaultValue() != null) { + if (record != null) { + return projectRequiredSchema(source, record, target); + } else { + return target.defaultValue(); + } + } else { + throw new SchemaProjectorException("Writer schema is optional, however, target schema does not provide a default value."); + } + } else { + if (record != null) { + return projectRequiredSchema(source, record, target); + } else { + return null; + } + } + } + + private static Object projectRequiredSchema(Schema source, Object record, Schema target) throws SchemaProjectorException { + switch (target.type()) { + case INT8: + case INT16: + case INT32: + case INT64: + case FLOAT32: + case FLOAT64: + case BOOLEAN: + case BYTES: + case STRING: + return projectPrimitive(source, record, target); + case STRUCT: + return projectStruct(source, (Struct) record, target); + case ARRAY: + return projectArray(source, record, target); + case MAP: + return projectMap(source, record, target); + } + return null; + } + + private static Object projectStruct(Schema source, Struct sourceStruct, Schema target) throws SchemaProjectorException { + Struct targetStruct = new Struct(target); + for (Field targetField : target.fields()) { + String fieldName = targetField.name(); + Field sourceField = source.field(fieldName); + if (sourceField != null) { + Object sourceFieldValue = sourceStruct.get(fieldName); + try { + Object targetFieldValue = project(sourceField.schema(), sourceFieldValue, targetField.schema()); + targetStruct.put(fieldName, targetFieldValue); + } catch (SchemaProjectorException e) { + throw new SchemaProjectorException("Error projecting " + sourceField.name(), e); + } + } else if (targetField.schema().isOptional()) { + // Ignore missing field + } else if (targetField.schema().defaultValue() != null) { + targetStruct.put(fieldName, targetField.schema().defaultValue()); + } else { + throw new SchemaProjectorException("Required field `" + fieldName + "` is missing from source schema: " + source); + } + } + return targetStruct; + } + + + private static void checkMaybeCompatible(Schema source, Schema target) { + if (source.type() != target.type() && !isPromotable(source.type(), target.type())) { + throw new SchemaProjectorException("Schema type mismatch. source type: " + source.type() + " and target type: " + target.type()); + } else if (!Objects.equals(source.name(), target.name())) { + throw new SchemaProjectorException("Schema name mismatch. 
source name: " + source.name() + " and target name: " + target.name()); + } else if (!Objects.equals(source.parameters(), target.parameters())) { + throw new SchemaProjectorException("Schema parameters not equal. source parameters: " + source.parameters() + " and target parameters: " + target.parameters()); + } + } + + private static Object projectArray(Schema source, Object record, Schema target) throws SchemaProjectorException { + List array = (List) record; + List retArray = new ArrayList<>(); + for (Object entry : array) { + retArray.add(project(source.valueSchema(), entry, target.valueSchema())); + } + return retArray; + } + + private static Object projectMap(Schema source, Object record, Schema target) throws SchemaProjectorException { + Map map = (Map) record; + Map retMap = new HashMap<>(); + for (Map.Entry entry : map.entrySet()) { + Object key = entry.getKey(); + Object value = entry.getValue(); + Object retKey = project(source.keySchema(), key, target.keySchema()); + Object retValue = project(source.valueSchema(), value, target.valueSchema()); + retMap.put(retKey, retValue); + } + return retMap; + } + + private static Object projectPrimitive(Schema source, Object record, Schema target) throws SchemaProjectorException { + assert source.type().isPrimitive(); + assert target.type().isPrimitive(); + Object result; + if (isPromotable(source.type(), target.type()) && record instanceof Number) { + Number numberRecord = (Number) record; + switch (target.type()) { + case INT8: + result = numberRecord.byteValue(); + break; + case INT16: + result = numberRecord.shortValue(); + break; + case INT32: + result = numberRecord.intValue(); + break; + case INT64: + result = numberRecord.longValue(); + break; + case FLOAT32: + result = numberRecord.floatValue(); + break; + case FLOAT64: + result = numberRecord.doubleValue(); + break; + default: + throw new SchemaProjectorException("Not promotable type."); + } + } else { + result = record; + } + return result; + } + + private static boolean isPromotable(Type sourceType, Type targetType) { + return promotable.contains(new AbstractMap.SimpleImmutableEntry<>(sourceType, targetType)); + } +} diff --git a/connect/api/src/main/java/org/apache/kafka/connect/data/Struct.java b/connect/api/src/main/java/org/apache/kafka/connect/data/Struct.java new file mode 100644 index 0000000..1f542e5 --- /dev/null +++ b/connect/api/src/main/java/org/apache/kafka/connect/data/Struct.java @@ -0,0 +1,287 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.data; + +import org.apache.kafka.connect.errors.DataException; + +import java.nio.ByteBuffer; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.Objects; + +/** + *
+ * <p>
+ *     A structured record containing a set of named fields with values, each field using an independent {@link Schema}.
+ *     Struct objects must specify a complete {@link Schema} up front, and only fields specified in the Schema may be set.
+ * </p>
+ * <p>
+ *     The Struct's {@link #put(String, Object)} method returns the Struct itself to provide a fluent API for constructing
+ *     complete objects:
+ *     <pre>
+ *         Schema schema = SchemaBuilder.struct().name("com.example.Person")
+ *             .field("name", Schema.STRING_SCHEMA).field("age", Schema.INT32_SCHEMA).build()
+ *         Struct struct = new Struct(schema).put("name", "Bobby McGee").put("age", 21)
+ *     </pre>
+ * </p>
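+ * <pre>
+ *     // Illustrative sketch (not in the original javadoc): reading values back from the struct above
+ *     // with the typed getters, and validating it before use.
+ *     String name = struct.getString("name");
+ *     Integer age = struct.getInt32("age");
+ *     struct.validate();   // throws DataException if a required field is missing or invalid
+ * </pre>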
            + */ +public class Struct { + + private final Schema schema; + private final Object[] values; + + /** + * Create a new Struct for this {@link Schema} + * @param schema the {@link Schema} for the Struct + */ + public Struct(Schema schema) { + if (schema.type() != Schema.Type.STRUCT) + throw new DataException("Not a struct schema: " + schema); + this.schema = schema; + this.values = new Object[schema.fields().size()]; + } + + /** + * Get the schema for this Struct. + * @return the Struct's schema + */ + public Schema schema() { + return schema; + } + + /** + * Get the value of a field, returning the default value if no value has been set yet and a default value is specified + * in the field's schema. Because this handles fields of all types, the value is returned as an {@link Object} and + * must be cast to a more specific type. + * @param fieldName the field name to lookup + * @return the value for the field + */ + public Object get(String fieldName) { + Field field = lookupField(fieldName); + return get(field); + } + + /** + * Get the value of a field, returning the default value if no value has been set yet and a default value is specified + * in the field's schema. Because this handles fields of all types, the value is returned as an {@link Object} and + * must be cast to a more specific type. + * @param field the field to lookup + * @return the value for the field + */ + public Object get(Field field) { + Object val = values[field.index()]; + if (val == null && field.schema().defaultValue() != null) { + val = field.schema().defaultValue(); + } + return val; + } + + /** + * Get the underlying raw value for the field without accounting for default values. + * @param fieldName the field to get the value of + * @return the raw value + */ + public Object getWithoutDefault(String fieldName) { + Field field = lookupField(fieldName); + return values[field.index()]; + } + + // Note that all getters have to have boxed return types since the fields might be optional + + /** + * Equivalent to calling {@link #get(String)} and casting the result to a Byte. + */ + public Byte getInt8(String fieldName) { + return (Byte) getCheckType(fieldName, Schema.Type.INT8); + } + + /** + * Equivalent to calling {@link #get(String)} and casting the result to a Short. + */ + public Short getInt16(String fieldName) { + return (Short) getCheckType(fieldName, Schema.Type.INT16); + } + + /** + * Equivalent to calling {@link #get(String)} and casting the result to a Integer. + */ + public Integer getInt32(String fieldName) { + return (Integer) getCheckType(fieldName, Schema.Type.INT32); + } + + /** + * Equivalent to calling {@link #get(String)} and casting the result to a Long. + */ + public Long getInt64(String fieldName) { + return (Long) getCheckType(fieldName, Schema.Type.INT64); + } + + /** + * Equivalent to calling {@link #get(String)} and casting the result to a Float. + */ + public Float getFloat32(String fieldName) { + return (Float) getCheckType(fieldName, Schema.Type.FLOAT32); + } + + /** + * Equivalent to calling {@link #get(String)} and casting the result to a Double. + */ + public Double getFloat64(String fieldName) { + return (Double) getCheckType(fieldName, Schema.Type.FLOAT64); + } + + /** + * Equivalent to calling {@link #get(String)} and casting the result to a Boolean. + */ + public Boolean getBoolean(String fieldName) { + return (Boolean) getCheckType(fieldName, Schema.Type.BOOLEAN); + } + + /** + * Equivalent to calling {@link #get(String)} and casting the result to a String. 
+ */ + public String getString(String fieldName) { + return (String) getCheckType(fieldName, Schema.Type.STRING); + } + + /** + * Equivalent to calling {@link #get(String)} and casting the result to a byte[]. + */ + public byte[] getBytes(String fieldName) { + Object bytes = getCheckType(fieldName, Schema.Type.BYTES); + if (bytes instanceof ByteBuffer) + return ((ByteBuffer) bytes).array(); + return (byte[]) bytes; + } + + /** + * Equivalent to calling {@link #get(String)} and casting the result to a List. + */ + @SuppressWarnings("unchecked") + public List getArray(String fieldName) { + return (List) getCheckType(fieldName, Schema.Type.ARRAY); + } + + /** + * Equivalent to calling {@link #get(String)} and casting the result to a Map. + */ + @SuppressWarnings("unchecked") + public Map getMap(String fieldName) { + return (Map) getCheckType(fieldName, Schema.Type.MAP); + } + + /** + * Equivalent to calling {@link #get(String)} and casting the result to a Struct. + */ + public Struct getStruct(String fieldName) { + return (Struct) getCheckType(fieldName, Schema.Type.STRUCT); + } + + /** + * Set the value of a field. Validates the value, throwing a {@link DataException} if it does not match the field's + * {@link Schema}. + * @param fieldName the name of the field to set + * @param value the value of the field + * @return the Struct, to allow chaining of {@link #put(String, Object)} calls + */ + public Struct put(String fieldName, Object value) { + Field field = lookupField(fieldName); + return put(field, value); + } + + /** + * Set the value of a field. Validates the value, throwing a {@link DataException} if it does not match the field's + * {@link Schema}. + * @param field the field to set + * @param value the value of the field + * @return the Struct, to allow chaining of {@link #put(String, Object)} calls + */ + public Struct put(Field field, Object value) { + if (null == field) + throw new DataException("field cannot be null."); + ConnectSchema.validateValue(field.name(), field.schema(), value); + values[field.index()] = value; + return this; + } + + + /** + * Validates that this struct has filled in all the necessary data with valid values. For required fields + * without defaults, this validates that a value has been set and has matching types/schemas. If any validation + * fails, throws a DataException. + */ + public void validate() { + for (Field field : schema.fields()) { + Schema fieldSchema = field.schema(); + Object value = values[field.index()]; + if (value == null && (fieldSchema.isOptional() || fieldSchema.defaultValue() != null)) + continue; + ConnectSchema.validateValue(field.name(), fieldSchema, value); + } + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + Struct struct = (Struct) o; + return Objects.equals(schema, struct.schema) && + Arrays.deepEquals(values, struct.values); + } + + @Override + public int hashCode() { + return Objects.hash(schema, Arrays.deepHashCode(values)); + } + + private Field lookupField(String fieldName) { + Field field = schema.field(fieldName); + if (field == null) + throw new DataException(fieldName + " is not a valid field name"); + return field; + } + + // Get the field's value, but also check that the field matches the specified type, throwing an exception if it doesn't. 
+ // Used to implement the get*() methods that return typed data instead of Object + private Object getCheckType(String fieldName, Schema.Type type) { + Field field = lookupField(fieldName); + if (field.schema().type() != type) + throw new DataException("Field '" + fieldName + "' is not of type " + type); + return values[field.index()]; + } + + @Override + public String toString() { + final StringBuilder sb = new StringBuilder("Struct{"); + boolean first = true; + for (int i = 0; i < values.length; i++) { + final Object value = values[i]; + if (value != null) { + final Field field = schema.fields().get(i); + if (first) { + first = false; + } else { + sb.append(","); + } + sb.append(field.name()).append("=").append(value); + } + } + return sb.append("}").toString(); + } + +} + diff --git a/connect/api/src/main/java/org/apache/kafka/connect/data/Time.java b/connect/api/src/main/java/org/apache/kafka/connect/data/Time.java new file mode 100644 index 0000000..ad642e4 --- /dev/null +++ b/connect/api/src/main/java/org/apache/kafka/connect/data/Time.java @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.data; + +import org.apache.kafka.connect.errors.DataException; + +import java.util.Calendar; +import java.util.TimeZone; + +/** + *
+ * A time representing a specific point in a day, not tied to any specific date. The corresponding Java type is a
+ * java.util.Date where only hours, minutes, seconds, and milliseconds can be non-zero. This effectively makes it a
+ * point in time during the first day after the Unix epoch. The underlying representation is an integer
+ * representing the number of milliseconds after midnight.
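+ * <pre>
+ *     // Illustrative round trip (not in the original javadoc):
+ *     java.util.Date tenSecondsIn = Time.toLogical(Time.SCHEMA, 10000);   // 00:00:10.000 UTC
+ *     int encoded = Time.fromLogical(Time.SCHEMA, tenSecondsIn);          // 10000
+ * </pre>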
            + */ +public class Time { + public static final String LOGICAL_NAME = "org.apache.kafka.connect.data.Time"; + + private static final long MILLIS_PER_DAY = 24 * 60 * 60 * 1000; + + private static final TimeZone UTC = TimeZone.getTimeZone("UTC"); + + /** + * Returns a SchemaBuilder for a Time. By returning a SchemaBuilder you can override additional schema settings such + * as required/optional, default value, and documentation. + * @return a SchemaBuilder + */ + public static SchemaBuilder builder() { + return SchemaBuilder.int32() + .name(LOGICAL_NAME) + .version(1); + } + + public static final Schema SCHEMA = builder().schema(); + + /** + * Convert a value from its logical format (Time) to it's encoded format. + * @param value the logical value + * @return the encoded value + */ + public static int fromLogical(Schema schema, java.util.Date value) { + if (!(LOGICAL_NAME.equals(schema.name()))) + throw new DataException("Requested conversion of Time object but the schema does not match."); + Calendar calendar = Calendar.getInstance(UTC); + calendar.setTime(value); + long unixMillis = calendar.getTimeInMillis(); + if (unixMillis < 0 || unixMillis > MILLIS_PER_DAY) { + throw new DataException("Kafka Connect Time type should not have any date fields set to non-zero values."); + } + return (int) unixMillis; + } + + public static java.util.Date toLogical(Schema schema, int value) { + if (!(LOGICAL_NAME.equals(schema.name()))) + throw new DataException("Requested conversion of Date object but the schema does not match."); + if (value < 0 || value > MILLIS_PER_DAY) + throw new DataException("Time values must use number of milliseconds greater than 0 and less than 86400000"); + return new java.util.Date(value); + } +} diff --git a/connect/api/src/main/java/org/apache/kafka/connect/data/Timestamp.java b/connect/api/src/main/java/org/apache/kafka/connect/data/Timestamp.java new file mode 100644 index 0000000..2da6107 --- /dev/null +++ b/connect/api/src/main/java/org/apache/kafka/connect/data/Timestamp.java @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.data; + +import org.apache.kafka.connect.errors.DataException; + +/** + *
+ * A timestamp representing an absolute time, without timezone information. The corresponding Java type is a
+ * java.util.Date. The underlying representation is a long representing the number of milliseconds since Unix epoch.
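+ * <pre>
+ *     // Illustrative round trip (not in the original javadoc):
+ *     long now = System.currentTimeMillis();
+ *     java.util.Date ts = Timestamp.toLogical(Timestamp.SCHEMA, now);
+ *     long encoded = Timestamp.fromLogical(Timestamp.SCHEMA, ts);   // == now
+ * </pre>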
            + */ +public class Timestamp { + public static final String LOGICAL_NAME = "org.apache.kafka.connect.data.Timestamp"; + + /** + * Returns a SchemaBuilder for a Timestamp. By returning a SchemaBuilder you can override additional schema settings such + * as required/optional, default value, and documentation. + * @return a SchemaBuilder + */ + public static SchemaBuilder builder() { + return SchemaBuilder.int64() + .name(LOGICAL_NAME) + .version(1); + } + + public static final Schema SCHEMA = builder().schema(); + + /** + * Convert a value from its logical format (Date) to it's encoded format. + * @param value the logical value + * @return the encoded value + */ + public static long fromLogical(Schema schema, java.util.Date value) { + if (!(LOGICAL_NAME.equals(schema.name()))) + throw new DataException("Requested conversion of Timestamp object but the schema does not match."); + return value.getTime(); + } + + public static java.util.Date toLogical(Schema schema, long value) { + if (!(LOGICAL_NAME.equals(schema.name()))) + throw new DataException("Requested conversion of Timestamp object but the schema does not match."); + return new java.util.Date(value); + } +} diff --git a/connect/api/src/main/java/org/apache/kafka/connect/data/Values.java b/connect/api/src/main/java/org/apache/kafka/connect/data/Values.java new file mode 100644 index 0000000..31f4183 --- /dev/null +++ b/connect/api/src/main/java/org/apache/kafka/connect/data/Values.java @@ -0,0 +1,1265 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.data; + +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.connect.data.Schema.Type; +import org.apache.kafka.connect.errors.DataException; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.math.BigDecimal; +import java.nio.ByteBuffer; +import java.text.CharacterIterator; +import java.text.DateFormat; +import java.text.ParseException; +import java.text.SimpleDateFormat; +import java.text.StringCharacterIterator; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Base64; +import java.util.Calendar; +import java.util.Collections; +import java.util.HashSet; +import java.util.Iterator; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.NoSuchElementException; +import java.util.Set; +import java.util.TimeZone; +import java.util.regex.Pattern; + +/** + * Utility for converting from one Connect value to a different form. This is useful when the caller expects a value of a particular type + * but is uncertain whether the actual value is one that isn't directly that type but can be converted into that type. + * + *
+ * <p>For example, a caller might expect a particular {@link org.apache.kafka.connect.header.Header} to contain an {@link Type#INT64}
+ * value, when in fact that header contains a string representation of a 32-bit integer. Here, the caller can use the methods in this
+ * class to convert the value to the desired type:
+ * <pre>
+ *     Header header = ...
+ *     long value = Values.convertToLong(header.schema(), header.value());
+ * </pre>
+ *
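+ * <pre>
+ *     // Illustrative sketch (not in the original javadoc): string representations can be parsed back
+ *     // into a schema and value.
+ *     SchemaAndValue parsed = Values.parseString("[1, 2, 3]");   // an array schema with numeric elements
+ * </pre>
+ *
+ * <p>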
            This class is able to convert any value to a string representation as well as parse those string representations back into most of + * the types. The only exception is {@link Struct} values that require a schema and thus cannot be parsed from a simple string. + */ +public class Values { + + private static final Logger LOG = LoggerFactory.getLogger(Values.class); + + private static final TimeZone UTC = TimeZone.getTimeZone("UTC"); + private static final SchemaAndValue NULL_SCHEMA_AND_VALUE = new SchemaAndValue(null, null); + private static final SchemaAndValue TRUE_SCHEMA_AND_VALUE = new SchemaAndValue(Schema.BOOLEAN_SCHEMA, Boolean.TRUE); + private static final SchemaAndValue FALSE_SCHEMA_AND_VALUE = new SchemaAndValue(Schema.BOOLEAN_SCHEMA, Boolean.FALSE); + private static final Schema ARRAY_SELECTOR_SCHEMA = SchemaBuilder.array(Schema.STRING_SCHEMA).build(); + private static final Schema MAP_SELECTOR_SCHEMA = SchemaBuilder.map(Schema.STRING_SCHEMA, Schema.STRING_SCHEMA).build(); + private static final Schema STRUCT_SELECTOR_SCHEMA = SchemaBuilder.struct().build(); + private static final String TRUE_LITERAL = Boolean.TRUE.toString(); + private static final String FALSE_LITERAL = Boolean.FALSE.toString(); + private static final long MILLIS_PER_DAY = 24 * 60 * 60 * 1000; + private static final String NULL_VALUE = "null"; + static final String ISO_8601_DATE_FORMAT_PATTERN = "yyyy-MM-dd"; + static final String ISO_8601_TIME_FORMAT_PATTERN = "HH:mm:ss.SSS'Z'"; + static final String ISO_8601_TIMESTAMP_FORMAT_PATTERN = ISO_8601_DATE_FORMAT_PATTERN + "'T'" + ISO_8601_TIME_FORMAT_PATTERN; + private static final Set TEMPORAL_LOGICAL_TYPE_NAMES = + Collections.unmodifiableSet( + new HashSet<>( + Arrays.asList(Time.LOGICAL_NAME, + Timestamp.LOGICAL_NAME, + Date.LOGICAL_NAME + ) + ) + ); + + private static final String QUOTE_DELIMITER = "\""; + private static final String COMMA_DELIMITER = ","; + private static final String ENTRY_DELIMITER = ":"; + private static final String ARRAY_BEGIN_DELIMITER = "["; + private static final String ARRAY_END_DELIMITER = "]"; + private static final String MAP_BEGIN_DELIMITER = "{"; + private static final String MAP_END_DELIMITER = "}"; + private static final int ISO_8601_DATE_LENGTH = ISO_8601_DATE_FORMAT_PATTERN.length(); + private static final int ISO_8601_TIME_LENGTH = ISO_8601_TIME_FORMAT_PATTERN.length() - 2; // subtract single quotes + private static final int ISO_8601_TIMESTAMP_LENGTH = ISO_8601_TIMESTAMP_FORMAT_PATTERN.length() - 4; // subtract single quotes + + private static final Pattern TWO_BACKSLASHES = Pattern.compile("\\\\"); + + private static final Pattern DOUBLEQOUTE = Pattern.compile("\""); + + /** + * Convert the specified value to an {@link Type#BOOLEAN} value. The supplied schema is required if the value is a logical + * type when the schema contains critical information that might be necessary for converting to a boolean. + * + * @param schema the schema for the value; may be null + * @param value the value to be converted; may be null + * @return the representation as a boolean, or null if the supplied value was null + * @throws DataException if the value could not be converted to a boolean + */ + public static Boolean convertToBoolean(Schema schema, Object value) throws DataException { + return (Boolean) convertTo(Schema.OPTIONAL_BOOLEAN_SCHEMA, schema, value); + } + + /** + * Convert the specified value to an {@link Type#INT8} byte value. 
The supplied schema is required if the value is a logical + * type when the schema contains critical information that might be necessary for converting to a byte. + * + * @param schema the schema for the value; may be null + * @param value the value to be converted; may be null + * @return the representation as a byte, or null if the supplied value was null + * @throws DataException if the value could not be converted to a byte + */ + public static Byte convertToByte(Schema schema, Object value) throws DataException { + return (Byte) convertTo(Schema.OPTIONAL_INT8_SCHEMA, schema, value); + } + + /** + * Convert the specified value to an {@link Type#INT16} short value. The supplied schema is required if the value is a logical + * type when the schema contains critical information that might be necessary for converting to a short. + * + * @param schema the schema for the value; may be null + * @param value the value to be converted; may be null + * @return the representation as a short, or null if the supplied value was null + * @throws DataException if the value could not be converted to a short + */ + public static Short convertToShort(Schema schema, Object value) throws DataException { + return (Short) convertTo(Schema.OPTIONAL_INT16_SCHEMA, schema, value); + } + + /** + * Convert the specified value to an {@link Type#INT32} int value. The supplied schema is required if the value is a logical + * type when the schema contains critical information that might be necessary for converting to an integer. + * + * @param schema the schema for the value; may be null + * @param value the value to be converted; may be null + * @return the representation as an integer, or null if the supplied value was null + * @throws DataException if the value could not be converted to an integer + */ + public static Integer convertToInteger(Schema schema, Object value) throws DataException { + return (Integer) convertTo(Schema.OPTIONAL_INT32_SCHEMA, schema, value); + } + + /** + * Convert the specified value to an {@link Type#INT64} long value. The supplied schema is required if the value is a logical + * type when the schema contains critical information that might be necessary for converting to a long. + * + * @param schema the schema for the value; may be null + * @param value the value to be converted; may be null + * @return the representation as a long, or null if the supplied value was null + * @throws DataException if the value could not be converted to a long + */ + public static Long convertToLong(Schema schema, Object value) throws DataException { + return (Long) convertTo(Schema.OPTIONAL_INT64_SCHEMA, schema, value); + } + + /** + * Convert the specified value to an {@link Type#FLOAT32} float value. The supplied schema is required if the value is a logical + * type when the schema contains critical information that might be necessary for converting to a floating point number. + * + * @param schema the schema for the value; may be null + * @param value the value to be converted; may be null + * @return the representation as a float, or null if the supplied value was null + * @throws DataException if the value could not be converted to a float + */ + public static Float convertToFloat(Schema schema, Object value) throws DataException { + return (Float) convertTo(Schema.OPTIONAL_FLOAT32_SCHEMA, schema, value); + } + + /** + * Convert the specified value to an {@link Type#FLOAT64} double value. 
The supplied schema is required if the value is a logical + * type when the schema contains critical information that might be necessary for converting to a floating point number. + * + * @param schema the schema for the value; may be null + * @param value the value to be converted; may be null + * @return the representation as a double, or null if the supplied value was null + * @throws DataException if the value could not be converted to a double + */ + public static Double convertToDouble(Schema schema, Object value) throws DataException { + return (Double) convertTo(Schema.OPTIONAL_FLOAT64_SCHEMA, schema, value); + } + + /** + * Convert the specified value to an {@link Type#STRING} value. + * Not supplying a schema may limit the ability to convert to the desired type. + * + * @param schema the schema for the value; may be null + * @param value the value to be converted; may be null + * @return the representation as a string, or null if the supplied value was null + */ + public static String convertToString(Schema schema, Object value) { + return (String) convertTo(Schema.OPTIONAL_STRING_SCHEMA, schema, value); + } + + /** + * Convert the specified value to an {@link Type#ARRAY} value. If the value is a string representation of an array, this method + * will parse the string and its elements to infer the schemas for those elements. Thus, this method supports + * arrays of other primitives and structured types. If the value is already an array (or list), this method simply casts and + * returns it. + * + *
+ * <p>This method currently does not use the schema, though it may be used in the future.
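+ * <pre>
+ *     // Illustrative sketch (not in the original javadoc):
+ *     List values = Values.convertToList(Schema.STRING_SCHEMA, "[1, 2, 3]");   // three numeric elements
+ * </pre>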
            + * + * @param schema the schema for the value; may be null + * @param value the value to be converted; may be null + * @return the representation as a list, or null if the supplied value was null + * @throws DataException if the value cannot be converted to a list value + */ + public static List convertToList(Schema schema, Object value) { + return (List) convertTo(ARRAY_SELECTOR_SCHEMA, schema, value); + } + + /** + * Convert the specified value to an {@link Type#MAP} value. If the value is a string representation of a map, this method + * will parse the string and its entries to infer the schemas for those entries. Thus, this method supports + * maps with primitives and structured keys and values. If the value is already a map, this method simply casts and returns it. + * + *
+ * <p>This method currently does not use the schema, though it may be used in the future.
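+ * <pre>
+ *     // Illustrative sketch (not in the original javadoc):
+ *     Map settings = Values.convertToMap(Schema.STRING_SCHEMA, "{\"a\":1,\"b\":2}");   // string keys, numeric values
+ * </pre>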
            + * + * @param schema the schema for the value; may be null + * @param value the value to be converted; may be null + * @return the representation as a map, or null if the supplied value was null + * @throws DataException if the value cannot be converted to a map value + */ + public static Map convertToMap(Schema schema, Object value) { + return (Map) convertTo(MAP_SELECTOR_SCHEMA, schema, value); + } + + /** + * Convert the specified value to an {@link Type#STRUCT} value. Structs cannot be converted from other types, so this method returns + * a struct only if the supplied value is a struct. If not a struct, this method throws an exception. + * + *
+ * <p>This method currently does not use the schema, though it may be used in the future.
            + * + * @param schema the schema for the value; may be null + * @param value the value to be converted; may be null + * @return the representation as a struct, or null if the supplied value was null + * @throws DataException if the value is not a struct + */ + public static Struct convertToStruct(Schema schema, Object value) { + return (Struct) convertTo(STRUCT_SELECTOR_SCHEMA, schema, value); + } + + /** + * Convert the specified value to an {@link Time#SCHEMA time} value. + * Not supplying a schema may limit the ability to convert to the desired type. + * + * @param schema the schema for the value; may be null + * @param value the value to be converted; may be null + * @return the representation as a time, or null if the supplied value was null + * @throws DataException if the value cannot be converted to a time value + */ + public static java.util.Date convertToTime(Schema schema, Object value) { + return (java.util.Date) convertTo(Time.SCHEMA, schema, value); + } + + /** + * Convert the specified value to an {@link Date#SCHEMA date} value. + * Not supplying a schema may limit the ability to convert to the desired type. + * + * @param schema the schema for the value; may be null + * @param value the value to be converted; may be null + * @return the representation as a date, or null if the supplied value was null + * @throws DataException if the value cannot be converted to a date value + */ + public static java.util.Date convertToDate(Schema schema, Object value) { + return (java.util.Date) convertTo(Date.SCHEMA, schema, value); + } + + /** + * Convert the specified value to an {@link Timestamp#SCHEMA timestamp} value. + * Not supplying a schema may limit the ability to convert to the desired type. + * + * @param schema the schema for the value; may be null + * @param value the value to be converted; may be null + * @return the representation as a timestamp, or null if the supplied value was null + * @throws DataException if the value cannot be converted to a timestamp value + */ + public static java.util.Date convertToTimestamp(Schema schema, Object value) { + return (java.util.Date) convertTo(Timestamp.SCHEMA, schema, value); + } + + /** + * Convert the specified value to an {@link Decimal decimal} value. + * Not supplying a schema may limit the ability to convert to the desired type. + * + * @param schema the schema for the value; may be null + * @param value the value to be converted; may be null + * @return the representation as a decimal, or null if the supplied value was null + * @throws DataException if the value cannot be converted to a decimal value + */ + public static BigDecimal convertToDecimal(Schema schema, Object value, int scale) { + return (BigDecimal) convertTo(Decimal.schema(scale), schema, value); + } + + /** + * If possible infer a schema for the given value. 
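+ * <pre>
+ *     // Illustrative sketch (not in the original javadoc):
+ *     Schema s1 = Values.inferSchema("hello");   // Schema.STRING_SCHEMA
+ *     Schema s2 = Values.inferSchema(42);        // Schema.INT32_SCHEMA
+ * </pre>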
+ * + * @param value the value whose schema is to be inferred; may be null + * @return the inferred schema, or null if the value is null or no schema could be inferred + */ + public static Schema inferSchema(Object value) { + if (value instanceof String) { + return Schema.STRING_SCHEMA; + } + if (value instanceof Boolean) { + return Schema.BOOLEAN_SCHEMA; + } + if (value instanceof Byte) { + return Schema.INT8_SCHEMA; + } + if (value instanceof Short) { + return Schema.INT16_SCHEMA; + } + if (value instanceof Integer) { + return Schema.INT32_SCHEMA; + } + if (value instanceof Long) { + return Schema.INT64_SCHEMA; + } + if (value instanceof Float) { + return Schema.FLOAT32_SCHEMA; + } + if (value instanceof Double) { + return Schema.FLOAT64_SCHEMA; + } + if (value instanceof byte[] || value instanceof ByteBuffer) { + return Schema.BYTES_SCHEMA; + } + if (value instanceof List) { + List list = (List) value; + if (list.isEmpty()) { + return null; + } + SchemaDetector detector = new SchemaDetector(); + for (Object element : list) { + if (!detector.canDetect(element)) { + return null; + } + } + return SchemaBuilder.array(detector.schema()).build(); + } + if (value instanceof Map) { + Map map = (Map) value; + if (map.isEmpty()) { + return null; + } + SchemaDetector keyDetector = new SchemaDetector(); + SchemaDetector valueDetector = new SchemaDetector(); + for (Map.Entry entry : map.entrySet()) { + if (!keyDetector.canDetect(entry.getKey()) || !valueDetector.canDetect(entry.getValue())) { + return null; + } + } + return SchemaBuilder.map(keyDetector.schema(), valueDetector.schema()).build(); + } + if (value instanceof Struct) { + return ((Struct) value).schema(); + } + return null; + } + + + /** + * Parse the specified string representation of a value into its schema and value. + * + * @param value the string form of the value + * @return the schema and value; never null, but whose schema and value may be null + * @see #convertToString + */ + public static SchemaAndValue parseString(String value) { + if (value == null) { + return NULL_SCHEMA_AND_VALUE; + } + if (value.isEmpty()) { + return new SchemaAndValue(Schema.STRING_SCHEMA, value); + } + Parser parser = new Parser(value); + return parse(parser, false); + } + + /** + * Convert the value to the desired type. + * + * @param toSchema the schema for the desired type; may not be null + * @param fromSchema the schema for the supplied value; may be null if not known + * @return the converted value; never null + * @throws DataException if the value could not be converted to the desired type + */ + protected static Object convertTo(Schema toSchema, Schema fromSchema, Object value) throws DataException { + if (value == null) { + if (toSchema.isOptional()) { + return null; + } + throw new DataException("Unable to convert a null value to a schema that requires a value"); + } + switch (toSchema.type()) { + case BYTES: + if (Decimal.LOGICAL_NAME.equals(toSchema.name())) { + if (value instanceof ByteBuffer) { + value = Utils.toArray((ByteBuffer) value); + } + if (value instanceof byte[]) { + return Decimal.toLogical(toSchema, (byte[]) value); + } + if (value instanceof BigDecimal) { + return value; + } + if (value instanceof Number) { + // Not already a decimal, so treat it as a double ... 
+ double converted = ((Number) value).doubleValue(); + return BigDecimal.valueOf(converted); + } + if (value instanceof String) { + return new BigDecimal(value.toString()).doubleValue(); + } + } + if (value instanceof ByteBuffer) { + return Utils.toArray((ByteBuffer) value); + } + if (value instanceof byte[]) { + return value; + } + if (value instanceof BigDecimal) { + return Decimal.fromLogical(toSchema, (BigDecimal) value); + } + break; + case STRING: + StringBuilder sb = new StringBuilder(); + append(sb, value, false); + return sb.toString(); + case BOOLEAN: + if (value instanceof Boolean) { + return value; + } + if (value instanceof String) { + SchemaAndValue parsed = parseString(value.toString()); + if (parsed.value() instanceof Boolean) { + return parsed.value(); + } + } + return asLong(value, fromSchema, null) == 0L ? Boolean.FALSE : Boolean.TRUE; + case INT8: + if (value instanceof Byte) { + return value; + } + return (byte) asLong(value, fromSchema, null); + case INT16: + if (value instanceof Short) { + return value; + } + return (short) asLong(value, fromSchema, null); + case INT32: + if (Date.LOGICAL_NAME.equals(toSchema.name())) { + if (value instanceof String) { + SchemaAndValue parsed = parseString(value.toString()); + value = parsed.value(); + } + if (value instanceof java.util.Date) { + if (fromSchema != null) { + String fromSchemaName = fromSchema.name(); + if (Date.LOGICAL_NAME.equals(fromSchemaName)) { + return value; + } + if (Timestamp.LOGICAL_NAME.equals(fromSchemaName)) { + // Just get the number of days from this timestamp + long millis = ((java.util.Date) value).getTime(); + int days = (int) (millis / MILLIS_PER_DAY); // truncates + return Date.toLogical(toSchema, days); + } + } else { + // There is no fromSchema, so no conversion is needed + return value; + } + } + long numeric = asLong(value, fromSchema, null); + return Date.toLogical(toSchema, (int) numeric); + } + if (Time.LOGICAL_NAME.equals(toSchema.name())) { + if (value instanceof String) { + SchemaAndValue parsed = parseString(value.toString()); + value = parsed.value(); + } + if (value instanceof java.util.Date) { + if (fromSchema != null) { + String fromSchemaName = fromSchema.name(); + if (Time.LOGICAL_NAME.equals(fromSchemaName)) { + return value; + } + if (Timestamp.LOGICAL_NAME.equals(fromSchemaName)) { + // Just get the time portion of this timestamp + Calendar calendar = Calendar.getInstance(UTC); + calendar.setTime((java.util.Date) value); + calendar.set(Calendar.YEAR, 1970); + calendar.set(Calendar.MONTH, 0); // Months are zero-based + calendar.set(Calendar.DAY_OF_MONTH, 1); + return Time.toLogical(toSchema, (int) calendar.getTimeInMillis()); + } + } else { + // There is no fromSchema, so no conversion is needed + return value; + } + } + long numeric = asLong(value, fromSchema, null); + return Time.toLogical(toSchema, (int) numeric); + } + if (value instanceof Integer) { + return value; + } + return (int) asLong(value, fromSchema, null); + case INT64: + if (Timestamp.LOGICAL_NAME.equals(toSchema.name())) { + if (value instanceof String) { + SchemaAndValue parsed = parseString(value.toString()); + value = parsed.value(); + } + if (value instanceof java.util.Date) { + java.util.Date date = (java.util.Date) value; + if (fromSchema != null) { + String fromSchemaName = fromSchema.name(); + if (Date.LOGICAL_NAME.equals(fromSchemaName)) { + int days = Date.fromLogical(fromSchema, date); + long millis = days * MILLIS_PER_DAY; + return Timestamp.toLogical(toSchema, millis); + } + if 
(Time.LOGICAL_NAME.equals(fromSchemaName)) { + long millis = Time.fromLogical(fromSchema, date); + return Timestamp.toLogical(toSchema, millis); + } + if (Timestamp.LOGICAL_NAME.equals(fromSchemaName)) { + return value; + } + } else { + // There is no fromSchema, so no conversion is needed + return value; + } + } + long numeric = asLong(value, fromSchema, null); + return Timestamp.toLogical(toSchema, numeric); + } + if (value instanceof Long) { + return value; + } + return asLong(value, fromSchema, null); + case FLOAT32: + if (value instanceof Float) { + return value; + } + return (float) asDouble(value, fromSchema, null); + case FLOAT64: + if (value instanceof Double) { + return value; + } + return asDouble(value, fromSchema, null); + case ARRAY: + if (value instanceof String) { + SchemaAndValue schemaAndValue = parseString(value.toString()); + value = schemaAndValue.value(); + } + if (value instanceof List) { + return value; + } + break; + case MAP: + if (value instanceof String) { + SchemaAndValue schemaAndValue = parseString(value.toString()); + value = schemaAndValue.value(); + } + if (value instanceof Map) { + return value; + } + break; + case STRUCT: + if (value instanceof Struct) { + Struct struct = (Struct) value; + return struct; + } + } + throw new DataException("Unable to convert " + value + " (" + value.getClass() + ") to " + toSchema); + } + + /** + * Convert the specified value to the desired scalar value type. + * + * @param value the value to be converted; may not be null + * @param fromSchema the schema for the current value type; may not be null + * @param error any previous error that should be included in an exception message; may be null + * @return the long value after conversion; never null + * @throws DataException if the value could not be converted to a long + */ + protected static long asLong(Object value, Schema fromSchema, Throwable error) { + try { + if (value instanceof Number) { + Number number = (Number) value; + return number.longValue(); + } + if (value instanceof String) { + return new BigDecimal(value.toString()).longValue(); + } + } catch (NumberFormatException e) { + error = e; + // fall through + } + if (fromSchema != null) { + String schemaName = fromSchema.name(); + if (value instanceof java.util.Date) { + if (Date.LOGICAL_NAME.equals(schemaName)) { + return Date.fromLogical(fromSchema, (java.util.Date) value); + } + if (Time.LOGICAL_NAME.equals(schemaName)) { + return Time.fromLogical(fromSchema, (java.util.Date) value); + } + if (Timestamp.LOGICAL_NAME.equals(schemaName)) { + return Timestamp.fromLogical(fromSchema, (java.util.Date) value); + } + } + throw new DataException("Unable to convert " + value + " (" + value.getClass() + ") to " + fromSchema, error); + } + throw new DataException("Unable to convert " + value + " (" + value.getClass() + ") to a number", error); + } + + /** + * Convert the specified value with the desired floating point type. 
+ * + * @param value the value to be converted; may not be null + * @param schema the schema for the current value type; may not be null + * @param error any previous error that should be included in an exception message; may be null + * @return the double value after conversion; never null + * @throws DataException if the value could not be converted to a double + */ + protected static double asDouble(Object value, Schema schema, Throwable error) { + try { + if (value instanceof Number) { + Number number = (Number) value; + return number.doubleValue(); + } + if (value instanceof String) { + return new BigDecimal(value.toString()).doubleValue(); + } + } catch (NumberFormatException e) { + error = e; + // fall through + } + return asLong(value, schema, error); + } + + protected static void append(StringBuilder sb, Object value, boolean embedded) { + if (value == null) { + sb.append(NULL_VALUE); + } else if (value instanceof Number) { + sb.append(value); + } else if (value instanceof Boolean) { + sb.append(value); + } else if (value instanceof String) { + if (embedded) { + String escaped = escape((String) value); + sb.append('"').append(escaped).append('"'); + } else { + sb.append(value); + } + } else if (value instanceof byte[]) { + value = Base64.getEncoder().encodeToString((byte[]) value); + if (embedded) { + sb.append('"').append(value).append('"'); + } else { + sb.append(value); + } + } else if (value instanceof ByteBuffer) { + byte[] bytes = Utils.readBytes((ByteBuffer) value); + append(sb, bytes, embedded); + } else if (value instanceof List) { + List list = (List) value; + sb.append('['); + appendIterable(sb, list.iterator()); + sb.append(']'); + } else if (value instanceof Map) { + Map map = (Map) value; + sb.append('{'); + appendIterable(sb, map.entrySet().iterator()); + sb.append('}'); + } else if (value instanceof Struct) { + Struct struct = (Struct) value; + Schema schema = struct.schema(); + boolean first = true; + sb.append('{'); + for (Field field : schema.fields()) { + if (first) { + first = false; + } else { + sb.append(','); + } + append(sb, field.name(), true); + sb.append(':'); + append(sb, struct.get(field), true); + } + sb.append('}'); + } else if (value instanceof Map.Entry) { + Map.Entry entry = (Map.Entry) value; + append(sb, entry.getKey(), true); + sb.append(':'); + append(sb, entry.getValue(), true); + } else if (value instanceof java.util.Date) { + java.util.Date dateValue = (java.util.Date) value; + String formatted = dateFormatFor(dateValue).format(dateValue); + sb.append(formatted); + } else { + throw new DataException("Failed to serialize unexpected value type " + value.getClass().getName() + ": " + value); + } + } + + protected static void appendIterable(StringBuilder sb, Iterator iter) { + if (iter.hasNext()) { + append(sb, iter.next(), true); + while (iter.hasNext()) { + sb.append(','); + append(sb, iter.next(), true); + } + } + } + + protected static String escape(String value) { + String replace1 = TWO_BACKSLASHES.matcher(value).replaceAll("\\\\\\\\"); + return DOUBLEQOUTE.matcher(replace1).replaceAll("\\\\\""); + } + + public static DateFormat dateFormatFor(java.util.Date value) { + if (value.getTime() < MILLIS_PER_DAY) { + return new SimpleDateFormat(ISO_8601_TIME_FORMAT_PATTERN); + } + if (value.getTime() % MILLIS_PER_DAY == 0) { + return new SimpleDateFormat(ISO_8601_DATE_FORMAT_PATTERN); + } + return new SimpleDateFormat(ISO_8601_TIMESTAMP_FORMAT_PATTERN); + } + + protected static boolean canParseSingleTokenLiteral(Parser parser, boolean embedded, 
String tokenLiteral) { + int startPosition = parser.mark(); + // If the next token is what we expect, then either... + if (parser.canConsume(tokenLiteral)) { + // ...we're reading an embedded value, in which case the next token will be handled appropriately + // by the caller if it's something like an end delimiter for a map or array, or a comma to + // separate multiple embedded values... + // ...or it's being parsed as part of a top-level string, in which case, any other tokens should + // cause use to stop parsing this single-token literal as such and instead just treat it like + // a string. For example, the top-level string "true}" will be tokenized as the tokens "true" and + // "}", but should ultimately be parsed as just the string "true}" instead of the boolean true. + if (embedded || !parser.hasNext()) { + return true; + } + } + parser.rewindTo(startPosition); + return false; + } + + protected static SchemaAndValue parse(Parser parser, boolean embedded) throws NoSuchElementException { + if (!parser.hasNext()) { + return null; + } + if (embedded) { + if (parser.canConsume(QUOTE_DELIMITER)) { + StringBuilder sb = new StringBuilder(); + while (parser.hasNext()) { + if (parser.canConsume(QUOTE_DELIMITER)) { + break; + } + sb.append(parser.next()); + } + String content = sb.toString(); + // We can parse string literals as temporal logical types, but all others + // are treated as strings + SchemaAndValue parsed = parseString(content); + if (parsed != null && TEMPORAL_LOGICAL_TYPE_NAMES.contains(parsed.schema().name())) { + return parsed; + } + return new SchemaAndValue(Schema.STRING_SCHEMA, content); + } + } + + if (canParseSingleTokenLiteral(parser, embedded, NULL_VALUE)) { + return null; + } + if (canParseSingleTokenLiteral(parser, embedded, TRUE_LITERAL)) { + return TRUE_SCHEMA_AND_VALUE; + } + if (canParseSingleTokenLiteral(parser, embedded, FALSE_LITERAL)) { + return FALSE_SCHEMA_AND_VALUE; + } + + int startPosition = parser.mark(); + + try { + if (parser.canConsume(ARRAY_BEGIN_DELIMITER)) { + List result = new ArrayList<>(); + Schema elementSchema = null; + while (parser.hasNext()) { + if (parser.canConsume(ARRAY_END_DELIMITER)) { + Schema listSchema; + if (elementSchema != null) { + listSchema = SchemaBuilder.array(elementSchema).schema(); + result = alignListEntriesWithSchema(listSchema, result); + } else { + // Every value is null + listSchema = SchemaBuilder.arrayOfNull().build(); + } + return new SchemaAndValue(listSchema, result); + } + + if (parser.canConsume(COMMA_DELIMITER)) { + throw new DataException("Unable to parse an empty array element: " + parser.original()); + } + SchemaAndValue element = parse(parser, true); + elementSchema = commonSchemaFor(elementSchema, element); + result.add(element != null ? 
element.value() : null); + + int currentPosition = parser.mark(); + if (parser.canConsume(ARRAY_END_DELIMITER)) { + parser.rewindTo(currentPosition); + } else if (!parser.canConsume(COMMA_DELIMITER)) { + throw new DataException("Array elements missing '" + COMMA_DELIMITER + "' delimiter"); + } + } + + // Missing either a comma or an end delimiter + if (COMMA_DELIMITER.equals(parser.previous())) { + throw new DataException("Array is missing element after ',': " + parser.original()); + } + throw new DataException("Array is missing terminating ']': " + parser.original()); + } + + if (parser.canConsume(MAP_BEGIN_DELIMITER)) { + Map result = new LinkedHashMap<>(); + Schema keySchema = null; + Schema valueSchema = null; + while (parser.hasNext()) { + if (parser.canConsume(MAP_END_DELIMITER)) { + Schema mapSchema; + if (keySchema != null && valueSchema != null) { + mapSchema = SchemaBuilder.map(keySchema, valueSchema).build(); + result = alignMapKeysAndValuesWithSchema(mapSchema, result); + } else if (keySchema != null) { + mapSchema = SchemaBuilder.mapWithNullValues(keySchema); + result = alignMapKeysWithSchema(mapSchema, result); + } else { + mapSchema = SchemaBuilder.mapOfNull().build(); + } + return new SchemaAndValue(mapSchema, result); + } + + if (parser.canConsume(COMMA_DELIMITER)) { + throw new DataException("Unable to parse a map entry with no key or value: " + parser.original()); + } + SchemaAndValue key = parse(parser, true); + if (key == null || key.value() == null) { + throw new DataException("Map entry may not have a null key: " + parser.original()); + } + + if (!parser.canConsume(ENTRY_DELIMITER)) { + throw new DataException("Map entry is missing '" + ENTRY_DELIMITER + + "' at " + parser.position() + + " in " + parser.original()); + } + SchemaAndValue value = parse(parser, true); + Object entryValue = value != null ? value.value() : null; + result.put(key.value(), entryValue); + + parser.canConsume(COMMA_DELIMITER); + keySchema = commonSchemaFor(keySchema, key); + valueSchema = commonSchemaFor(valueSchema, value); + } + // Missing either a comma or an end delimiter + if (COMMA_DELIMITER.equals(parser.previous())) { + throw new DataException("Map is missing element after ',': " + parser.original()); + } + throw new DataException("Map is missing terminating '}': " + parser.original()); + } + } catch (DataException e) { + LOG.trace("Unable to parse the value as a map or an array; reverting to string", e); + parser.rewindTo(startPosition); + } + + String token = parser.next(); + if (Utils.isBlank(token)) { + return new SchemaAndValue(Schema.STRING_SCHEMA, token); + } + token = token.trim(); + + char firstChar = token.charAt(0); + boolean firstCharIsDigit = Character.isDigit(firstChar); + + // Temporal types are more restrictive, so try them first + if (firstCharIsDigit) { + // The time and timestamp literals may be split into 5 tokens since an unescaped colon + // is a delimiter. 
Check these first since the first of these tokens is a simple numeric + int position = parser.mark(); + String remainder = parser.next(4); + if (remainder != null) { + String timeOrTimestampStr = token + remainder; + SchemaAndValue temporal = parseAsTemporal(timeOrTimestampStr); + if (temporal != null) { + return temporal; + } + } + // No match was found using the 5 tokens, so rewind and see if the current token has a date, time, or timestamp + parser.rewindTo(position); + SchemaAndValue temporal = parseAsTemporal(token); + if (temporal != null) { + return temporal; + } + } + if (firstCharIsDigit || firstChar == '+' || firstChar == '-') { + try { + // Try to parse as a number ... + BigDecimal decimal = new BigDecimal(token); + try { + return new SchemaAndValue(Schema.INT8_SCHEMA, decimal.byteValueExact()); + } catch (ArithmeticException e) { + // continue + } + try { + return new SchemaAndValue(Schema.INT16_SCHEMA, decimal.shortValueExact()); + } catch (ArithmeticException e) { + // continue + } + try { + return new SchemaAndValue(Schema.INT32_SCHEMA, decimal.intValueExact()); + } catch (ArithmeticException e) { + // continue + } + try { + return new SchemaAndValue(Schema.INT64_SCHEMA, decimal.longValueExact()); + } catch (ArithmeticException e) { + // continue + } + float fValue = decimal.floatValue(); + if (fValue != Float.NEGATIVE_INFINITY && fValue != Float.POSITIVE_INFINITY + && decimal.scale() != 0) { + return new SchemaAndValue(Schema.FLOAT32_SCHEMA, fValue); + } + double dValue = decimal.doubleValue(); + if (dValue != Double.NEGATIVE_INFINITY && dValue != Double.POSITIVE_INFINITY + && decimal.scale() != 0) { + return new SchemaAndValue(Schema.FLOAT64_SCHEMA, dValue); + } + Schema schema = Decimal.schema(decimal.scale()); + return new SchemaAndValue(schema, decimal); + } catch (NumberFormatException e) { + // can't parse as a number + } + } + if (embedded) { + throw new DataException("Failed to parse embedded value"); + } + // At this point, the only thing this non-embedded value can be is a string. 
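+ // (For example, the top-level input "true}" described above fails the single-token literal check and all of the temporal and numeric checks, so the entire original string is returned as-is.)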
+ return new SchemaAndValue(Schema.STRING_SCHEMA, parser.original()); + } + + private static SchemaAndValue parseAsTemporal(String token) { + if (token == null) { + return null; + } + // If the colons were escaped, we'll see the escape chars and need to remove them + token = token.replace("\\:", ":"); + int tokenLength = token.length(); + if (tokenLength == ISO_8601_TIME_LENGTH) { + try { + return new SchemaAndValue(Time.SCHEMA, new SimpleDateFormat(ISO_8601_TIME_FORMAT_PATTERN).parse(token)); + } catch (ParseException e) { + // not a valid date + } + } else if (tokenLength == ISO_8601_TIMESTAMP_LENGTH) { + try { + return new SchemaAndValue(Timestamp.SCHEMA, new SimpleDateFormat(ISO_8601_TIMESTAMP_FORMAT_PATTERN).parse(token)); + } catch (ParseException e) { + // not a valid date + } + } else if (tokenLength == ISO_8601_DATE_LENGTH) { + try { + return new SchemaAndValue(Date.SCHEMA, new SimpleDateFormat(ISO_8601_DATE_FORMAT_PATTERN).parse(token)); + } catch (ParseException e) { + // not a valid date + } + } + return null; + } + + protected static Schema commonSchemaFor(Schema previous, SchemaAndValue latest) { + if (latest == null) { + return previous; + } + if (previous == null) { + return latest.schema(); + } + Schema newSchema = latest.schema(); + Type previousType = previous.type(); + Type newType = newSchema.type(); + if (previousType != newType) { + switch (previous.type()) { + case INT8: + if (newType == Type.INT16 || newType == Type.INT32 || newType == Type.INT64 || newType == Type.FLOAT32 || newType == + Type.FLOAT64) { + return newSchema; + } + break; + case INT16: + if (newType == Type.INT8) { + return previous; + } + if (newType == Type.INT32 || newType == Type.INT64 || newType == Type.FLOAT32 || newType == Type.FLOAT64) { + return newSchema; + } + break; + case INT32: + if (newType == Type.INT8 || newType == Type.INT16) { + return previous; + } + if (newType == Type.INT64 || newType == Type.FLOAT32 || newType == Type.FLOAT64) { + return newSchema; + } + break; + case INT64: + if (newType == Type.INT8 || newType == Type.INT16 || newType == Type.INT32) { + return previous; + } + if (newType == Type.FLOAT32 || newType == Type.FLOAT64) { + return newSchema; + } + break; + case FLOAT32: + if (newType == Type.INT8 || newType == Type.INT16 || newType == Type.INT32 || newType == Type.INT64) { + return previous; + } + if (newType == Type.FLOAT64) { + return newSchema; + } + break; + case FLOAT64: + if (newType == Type.INT8 || newType == Type.INT16 || newType == Type.INT32 || newType == Type.INT64 || newType == + Type.FLOAT32) { + return previous; + } + break; + } + return null; + } + if (previous.isOptional() == newSchema.isOptional()) { + // Use the optional one + return previous.isOptional() ? 
previous : newSchema; + } + if (!previous.equals(newSchema)) { + return null; + } + return previous; + } + + protected static List alignListEntriesWithSchema(Schema schema, List input) { + Schema valueSchema = schema.valueSchema(); + List result = new ArrayList<>(); + for (Object value : input) { + Object newValue = convertTo(valueSchema, null, value); + result.add(newValue); + } + return result; + } + + protected static Map alignMapKeysAndValuesWithSchema(Schema mapSchema, Map input) { + Schema keySchema = mapSchema.keySchema(); + Schema valueSchema = mapSchema.valueSchema(); + Map result = new LinkedHashMap<>(); + for (Map.Entry entry : input.entrySet()) { + Object newKey = convertTo(keySchema, null, entry.getKey()); + Object newValue = convertTo(valueSchema, null, entry.getValue()); + result.put(newKey, newValue); + } + return result; + } + + protected static Map alignMapKeysWithSchema(Schema mapSchema, Map input) { + Schema keySchema = mapSchema.keySchema(); + Map result = new LinkedHashMap<>(); + for (Map.Entry entry : input.entrySet()) { + Object newKey = convertTo(keySchema, null, entry.getKey()); + result.put(newKey, entry.getValue()); + } + return result; + } + + protected static class SchemaDetector { + private Type knownType = null; + private boolean optional = false; + + public SchemaDetector() { + } + + public boolean canDetect(Object value) { + if (value == null) { + optional = true; + return true; + } + Schema schema = inferSchema(value); + if (schema == null) { + return false; + } + if (knownType == null) { + knownType = schema.type(); + } else if (knownType != schema.type()) { + return false; + } + return true; + } + + public Schema schema() { + SchemaBuilder builder = SchemaBuilder.type(knownType); + if (optional) { + builder.optional(); + } + return builder.schema(); + } + } + + protected static class Parser { + private final String original; + private final CharacterIterator iter; + private String nextToken = null; + private String previousToken = null; + + public Parser(String original) { + this.original = original; + this.iter = new StringCharacterIterator(this.original); + } + + public int position() { + return iter.getIndex(); + } + + public int mark() { + return iter.getIndex() - (nextToken != null ? 
nextToken.length() : 0); + } + + public void rewindTo(int position) { + iter.setIndex(position); + nextToken = null; + previousToken = null; + } + + public String original() { + return original; + } + + public boolean hasNext() { + return nextToken != null || canConsumeNextToken(); + } + + protected boolean canConsumeNextToken() { + return iter.getEndIndex() > iter.getIndex(); + } + + public String next() { + if (nextToken != null) { + previousToken = nextToken; + nextToken = null; + } else { + previousToken = consumeNextToken(); + } + return previousToken; + } + + public String next(int n) { + int current = mark(); + int start = mark(); + for (int i = 0; i != n; ++i) { + if (!hasNext()) { + rewindTo(start); + return null; + } + next(); + } + return original.substring(current, position()); + } + + private String consumeNextToken() throws NoSuchElementException { + boolean escaped = false; + int start = iter.getIndex(); + char c = iter.current(); + while (canConsumeNextToken()) { + switch (c) { + case '\\': + escaped = !escaped; + break; + case ':': + case ',': + case '{': + case '}': + case '[': + case ']': + case '\"': + if (!escaped) { + if (start < iter.getIndex()) { + // Return the previous token + return original.substring(start, iter.getIndex()); + } + // Consume and return this delimiter as a token + iter.next(); + return original.substring(start, start + 1); + } + // escaped, so continue + escaped = false; + break; + default: + // If escaped, then we don't care what was escaped + escaped = false; + break; + } + c = iter.next(); + } + return original.substring(start, iter.getIndex()); + } + + public String previous() { + return previousToken; + } + + public boolean canConsume(String expected) { + return canConsume(expected, true); + } + + public boolean canConsume(String expected, boolean ignoreLeadingAndTrailingWhitespace) { + if (isNext(expected, ignoreLeadingAndTrailingWhitespace)) { + // consume this token ... + nextToken = null; + return true; + } + return false; + } + + protected boolean isNext(String expected, boolean ignoreLeadingAndTrailingWhitespace) { + if (nextToken == null) { + if (!hasNext()) { + return false; + } + // There's another token, so consume it + nextToken = consumeNextToken(); + } + if (ignoreLeadingAndTrailingWhitespace) { + while (Utils.isBlank(nextToken) && canConsumeNextToken()) { + nextToken = consumeNextToken(); + } + } + return ignoreLeadingAndTrailingWhitespace + ? nextToken.trim().equals(expected) + : nextToken.equals(expected); + } + } +} diff --git a/connect/api/src/main/java/org/apache/kafka/connect/errors/AlreadyExistsException.java b/connect/api/src/main/java/org/apache/kafka/connect/errors/AlreadyExistsException.java new file mode 100644 index 0000000..a37f615 --- /dev/null +++ b/connect/api/src/main/java/org/apache/kafka/connect/errors/AlreadyExistsException.java @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.errors; + +/** + * Indicates the operation tried to create an entity that already exists. + */ +public class AlreadyExistsException extends ConnectException { + public AlreadyExistsException(String s) { + super(s); + } + + public AlreadyExistsException(String s, Throwable throwable) { + super(s, throwable); + } + + public AlreadyExistsException(Throwable throwable) { + super(throwable); + } +} diff --git a/connect/api/src/main/java/org/apache/kafka/connect/errors/ConnectException.java b/connect/api/src/main/java/org/apache/kafka/connect/errors/ConnectException.java new file mode 100644 index 0000000..3cbde36 --- /dev/null +++ b/connect/api/src/main/java/org/apache/kafka/connect/errors/ConnectException.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.errors; + +import org.apache.kafka.common.KafkaException; + +/** + * ConnectException is the top-level exception type generated by Kafka Connect and connector implementations. + */ +public class ConnectException extends KafkaException { + + public ConnectException(String s) { + super(s); + } + + public ConnectException(String s, Throwable throwable) { + super(s, throwable); + } + + public ConnectException(Throwable throwable) { + super(throwable); + } +} diff --git a/connect/api/src/main/java/org/apache/kafka/connect/errors/DataException.java b/connect/api/src/main/java/org/apache/kafka/connect/errors/DataException.java new file mode 100644 index 0000000..a850347 --- /dev/null +++ b/connect/api/src/main/java/org/apache/kafka/connect/errors/DataException.java @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.errors; + +/** + * Base class for all Kafka Connect data API exceptions. + */ +public class DataException extends ConnectException { + public DataException(String s) { + super(s); + } + + public DataException(String s, Throwable throwable) { + super(s, throwable); + } + + public DataException(Throwable throwable) { + super(throwable); + } +} diff --git a/connect/api/src/main/java/org/apache/kafka/connect/errors/IllegalWorkerStateException.java b/connect/api/src/main/java/org/apache/kafka/connect/errors/IllegalWorkerStateException.java new file mode 100644 index 0000000..be9cd34 --- /dev/null +++ b/connect/api/src/main/java/org/apache/kafka/connect/errors/IllegalWorkerStateException.java @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.errors; + +/** + * Indicates that a method has been invoked illegally or at an invalid time by a connector or task. + */ +public class IllegalWorkerStateException extends ConnectException { + public IllegalWorkerStateException(String s) { + super(s); + } + + public IllegalWorkerStateException(String s, Throwable throwable) { + super(s, throwable); + } + + public IllegalWorkerStateException(Throwable throwable) { + super(throwable); + } +} diff --git a/connect/api/src/main/java/org/apache/kafka/connect/errors/NotFoundException.java b/connect/api/src/main/java/org/apache/kafka/connect/errors/NotFoundException.java new file mode 100644 index 0000000..90f0179 --- /dev/null +++ b/connect/api/src/main/java/org/apache/kafka/connect/errors/NotFoundException.java @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.errors; + +/** + * Indicates that an operation attempted to modify or delete a connector or task that is not present on the worker. 
+ */ +public class NotFoundException extends ConnectException { + public NotFoundException(String s) { + super(s); + } + + public NotFoundException(String s, Throwable throwable) { + super(s, throwable); + } + + public NotFoundException(Throwable throwable) { + super(throwable); + } +} diff --git a/connect/api/src/main/java/org/apache/kafka/connect/errors/RetriableException.java b/connect/api/src/main/java/org/apache/kafka/connect/errors/RetriableException.java new file mode 100644 index 0000000..0b34bd0 --- /dev/null +++ b/connect/api/src/main/java/org/apache/kafka/connect/errors/RetriableException.java @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.errors; + +/** + * An exception that indicates the operation can be reattempted. + */ +public class RetriableException extends ConnectException { + public RetriableException(String s) { + super(s); + } + + public RetriableException(String s, Throwable throwable) { + super(s, throwable); + } + + public RetriableException(Throwable throwable) { + super(throwable); + } +} diff --git a/connect/api/src/main/java/org/apache/kafka/connect/errors/SchemaBuilderException.java b/connect/api/src/main/java/org/apache/kafka/connect/errors/SchemaBuilderException.java new file mode 100644 index 0000000..41843c3 --- /dev/null +++ b/connect/api/src/main/java/org/apache/kafka/connect/errors/SchemaBuilderException.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.connect.errors; + +public class SchemaBuilderException extends DataException { + public SchemaBuilderException(String s) { + super(s); + } + + public SchemaBuilderException(String s, Throwable throwable) { + super(s, throwable); + } + + public SchemaBuilderException(Throwable throwable) { + super(throwable); + } +} diff --git a/connect/api/src/main/java/org/apache/kafka/connect/errors/SchemaProjectorException.java b/connect/api/src/main/java/org/apache/kafka/connect/errors/SchemaProjectorException.java new file mode 100644 index 0000000..e2b840a --- /dev/null +++ b/connect/api/src/main/java/org/apache/kafka/connect/errors/SchemaProjectorException.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.errors; + +public class SchemaProjectorException extends DataException { + public SchemaProjectorException(String s) { + super(s); + } + + public SchemaProjectorException(String s, Throwable throwable) { + super(s, throwable); + } + + public SchemaProjectorException(Throwable throwable) { + super(throwable); + } +} diff --git a/connect/api/src/main/java/org/apache/kafka/connect/header/ConnectHeader.java b/connect/api/src/main/java/org/apache/kafka/connect/header/ConnectHeader.java new file mode 100644 index 0000000..3b9f347 --- /dev/null +++ b/connect/api/src/main/java/org/apache/kafka/connect/header/ConnectHeader.java @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.header; + +import org.apache.kafka.connect.data.Schema; +import org.apache.kafka.connect.data.SchemaAndValue; +import org.apache.kafka.connect.data.Struct; + +import java.util.Objects; + +/** + * A {@link Header} implementation. 
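+ * Instances are immutable: {@link #with(Schema, Object)} and {@link #rename(String)} return new or unchanged {@code Header} objects rather than modifying this one.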
+ */ +class ConnectHeader implements Header { + + private static final SchemaAndValue NULL_SCHEMA_AND_VALUE = new SchemaAndValue(null, null); + + private final String key; + private final SchemaAndValue schemaAndValue; + + protected ConnectHeader(String key, SchemaAndValue schemaAndValue) { + Objects.requireNonNull(key, "Null header keys are not permitted"); + this.key = key; + this.schemaAndValue = schemaAndValue != null ? schemaAndValue : NULL_SCHEMA_AND_VALUE; + } + + @Override + public String key() { + return key; + } + + @Override + public Object value() { + return schemaAndValue.value(); + } + + @Override + public Schema schema() { + Schema schema = schemaAndValue.schema(); + if (schema == null && value() instanceof Struct) { + schema = ((Struct) value()).schema(); + } + return schema; + } + + @Override + public Header rename(String key) { + Objects.requireNonNull(key, "Null header keys are not permitted"); + if (this.key.equals(key)) { + return this; + } + return new ConnectHeader(key, schemaAndValue); + } + + @Override + public Header with(Schema schema, Object value) { + return new ConnectHeader(key, new SchemaAndValue(schema, value)); + } + + @Override + public int hashCode() { + return Objects.hash(key, schemaAndValue); + } + + @Override + public boolean equals(Object obj) { + if (obj == this) { + return true; + } + if (obj instanceof Header) { + Header that = (Header) obj; + return Objects.equals(this.key, that.key()) && Objects.equals(this.schema(), that.schema()) && Objects.equals(this.value(), + that.value()); + } + return false; + } + + @Override + public String toString() { + return "ConnectHeader(key=" + key + ", value=" + value() + ", schema=" + schema() + ")"; + } +} diff --git a/connect/api/src/main/java/org/apache/kafka/connect/header/ConnectHeaders.java b/connect/api/src/main/java/org/apache/kafka/connect/header/ConnectHeaders.java new file mode 100644 index 0000000..5c37ddc --- /dev/null +++ b/connect/api/src/main/java/org/apache/kafka/connect/header/ConnectHeaders.java @@ -0,0 +1,497 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.connect.header; + +import org.apache.kafka.common.utils.AbstractIterator; +import org.apache.kafka.connect.data.Date; +import org.apache.kafka.connect.data.Decimal; +import org.apache.kafka.connect.data.Schema; +import org.apache.kafka.connect.data.Schema.Type; +import org.apache.kafka.connect.data.SchemaAndValue; +import org.apache.kafka.connect.data.Struct; +import org.apache.kafka.connect.data.Time; +import org.apache.kafka.connect.data.Timestamp; +import org.apache.kafka.connect.errors.DataException; + +import java.math.BigDecimal; +import java.nio.ByteBuffer; +import java.util.Collections; +import java.util.HashSet; +import java.util.Iterator; +import java.util.LinkedList; +import java.util.List; +import java.util.ListIterator; +import java.util.Map; +import java.util.Objects; +import java.util.Set; + +/** + * A basic {@link Headers} implementation. + */ +public class ConnectHeaders implements Headers { + + private static final int EMPTY_HASH = Objects.hash(new LinkedList<>()); + + private LinkedList
<Header> headers; + + public ConnectHeaders() { + } + + public ConnectHeaders(Iterable<Header>
            original) { + if (original == null) { + return; + } + if (original instanceof ConnectHeaders) { + ConnectHeaders originalHeaders = (ConnectHeaders) original; + if (!originalHeaders.isEmpty()) { + headers = new LinkedList<>(originalHeaders.headers); + } + } else { + headers = new LinkedList<>(); + for (Header header : original) { + Objects.requireNonNull(header, "Unable to add a null header."); + headers.add(header); + } + } + } + + @Override + public int size() { + return headers == null ? 0 : headers.size(); + } + + @Override + public boolean isEmpty() { + return headers == null || headers.isEmpty(); + } + + @Override + public Headers clear() { + if (headers != null) { + headers.clear(); + } + return this; + } + + @Override + public Headers add(Header header) { + Objects.requireNonNull(header, "Unable to add a null header."); + if (headers == null) { + headers = new LinkedList<>(); + } + headers.add(header); + return this; + } + + protected Headers addWithoutValidating(String key, Object value, Schema schema) { + return add(new ConnectHeader(key, new SchemaAndValue(schema, value))); + } + + @Override + public Headers add(String key, SchemaAndValue schemaAndValue) { + checkSchemaMatches(schemaAndValue); + return add(new ConnectHeader(key, schemaAndValue != null ? schemaAndValue : SchemaAndValue.NULL)); + } + + @Override + public Headers add(String key, Object value, Schema schema) { + return add(key, value != null || schema != null ? new SchemaAndValue(schema, value) : SchemaAndValue.NULL); + } + + @Override + public Headers addString(String key, String value) { + return addWithoutValidating(key, value, value != null ? Schema.STRING_SCHEMA : Schema.OPTIONAL_STRING_SCHEMA); + } + + @Override + public Headers addBytes(String key, byte[] value) { + return addWithoutValidating(key, value, value != null ? 
Schema.BYTES_SCHEMA : Schema.OPTIONAL_BYTES_SCHEMA); + } + + @Override + public Headers addBoolean(String key, boolean value) { + return addWithoutValidating(key, value, Schema.BOOLEAN_SCHEMA); + } + + @Override + public Headers addByte(String key, byte value) { + return addWithoutValidating(key, value, Schema.INT8_SCHEMA); + } + + @Override + public Headers addShort(String key, short value) { + return addWithoutValidating(key, value, Schema.INT16_SCHEMA); + } + + @Override + public Headers addInt(String key, int value) { + return addWithoutValidating(key, value, Schema.INT32_SCHEMA); + } + + @Override + public Headers addLong(String key, long value) { + return addWithoutValidating(key, value, Schema.INT64_SCHEMA); + } + + @Override + public Headers addFloat(String key, float value) { + return addWithoutValidating(key, value, Schema.FLOAT32_SCHEMA); + } + + @Override + public Headers addDouble(String key, double value) { + return addWithoutValidating(key, value, Schema.FLOAT64_SCHEMA); + } + + @Override + public Headers addList(String key, List value, Schema schema) { + if (value == null) { + return add(key, null, null); + } + checkSchemaType(schema, Type.ARRAY); + return addWithoutValidating(key, value, schema); + } + + @Override + public Headers addMap(String key, Map value, Schema schema) { + if (value == null) { + return add(key, null, null); + } + checkSchemaType(schema, Type.MAP); + return addWithoutValidating(key, value, schema); + } + + @Override + public Headers addStruct(String key, Struct value) { + if (value == null) { + return add(key, null, null); + } + checkSchemaType(value.schema(), Type.STRUCT); + return addWithoutValidating(key, value, value.schema()); + } + + @Override + public Headers addDecimal(String key, BigDecimal value) { + if (value == null) { + return add(key, null, null); + } + // Check that this is a decimal ... + Schema schema = Decimal.schema(value.scale()); + Decimal.fromLogical(schema, value); + return addWithoutValidating(key, value, schema); + } + + @Override + public Headers addDate(String key, java.util.Date value) { + if (value != null) { + // Check that this is a date ... + Date.fromLogical(Date.SCHEMA, value); + } + return addWithoutValidating(key, value, Date.SCHEMA); + } + + @Override + public Headers addTime(String key, java.util.Date value) { + if (value != null) { + // Check that this is a time ... + Time.fromLogical(Time.SCHEMA, value); + } + return addWithoutValidating(key, value, Time.SCHEMA); + } + + @Override + public Headers addTimestamp(String key, java.util.Date value) { + if (value != null) { + // Check that this is a timestamp ... + Timestamp.fromLogical(Timestamp.SCHEMA, value); + } + return addWithoutValidating(key, value, Timestamp.SCHEMA); + } + + @Override + public Header lastWithName(String key) { + checkKey(key); + if (headers != null) { + ListIterator
<Header> iter = headers.listIterator(headers.size()); + while (iter.hasPrevious()) { + Header header = iter.previous(); + if (key.equals(header.key())) { + return header; + } + } + } + return null; + } + + @Override + public Iterator
<Header> allWithName(String key) { + return new FilterByKeyIterator(iterator(), key); + } + + @Override + public Iterator
<Header> iterator() { + return headers == null ? Collections.emptyIterator() : + headers.iterator(); + } + + @Override + public Headers remove(String key) { + checkKey(key); + if (!isEmpty()) { + Iterator
<Header> iterator = iterator(); + while (iterator.hasNext()) { + if (iterator.next().key().equals(key)) { + iterator.remove(); + } + } + } + return this; + } + + @Override + public Headers retainLatest() { + if (!isEmpty()) { + Set<String> keys = new HashSet<>(); + ListIterator
<Header> iter = headers.listIterator(headers.size()); + while (iter.hasPrevious()) { + Header header = iter.previous(); + String key = header.key(); + if (!keys.add(key)) { + iter.remove(); + } + } + } + return this; + } + + @Override + public Headers retainLatest(String key) { + checkKey(key); + if (!isEmpty()) { + boolean found = false; + ListIterator
<Header> iter = headers.listIterator(headers.size()); + while (iter.hasPrevious()) { + String headerKey = iter.previous().key(); + if (key.equals(headerKey)) { + if (found) + iter.remove(); + found = true; + } + } + } + return this; + } + + @Override + public Headers apply(String key, HeaderTransform transform) { + checkKey(key); + if (!isEmpty()) { + ListIterator
<Header> iter = headers.listIterator(); + while (iter.hasNext()) { + Header orig = iter.next(); + if (orig.key().equals(key)) { + Header updated = transform.apply(orig); + if (updated != null) { + iter.set(updated); + } else { + iter.remove(); + } + } + } + } + return this; + } + + @Override + public Headers apply(HeaderTransform transform) { + if (!isEmpty()) { + ListIterator
<Header> iter = headers.listIterator(); + while (iter.hasNext()) { + Header orig = iter.next(); + Header updated = transform.apply(orig); + if (updated != null) { + iter.set(updated); + } else { + iter.remove(); + } + } + } + return this; + } + + @Override + public int hashCode() { + return isEmpty() ? EMPTY_HASH : Objects.hash(headers); + } + + @Override + public boolean equals(Object obj) { + if (obj == this) { + return true; + } + if (obj instanceof Headers) { + Headers that = (Headers) obj; + Iterator
<Header> thisIter = this.iterator(); + Iterator<Header>
            thatIter = that.iterator(); + while (thisIter.hasNext() && thatIter.hasNext()) { + if (!Objects.equals(thisIter.next(), thatIter.next())) + return false; + } + return !thisIter.hasNext() && !thatIter.hasNext(); + } + return false; + } + + @Override + public String toString() { + return "ConnectHeaders(headers=" + (headers != null ? headers : "") + ")"; + } + + @Override + public ConnectHeaders duplicate() { + return new ConnectHeaders(this); + } + + /** + * Check that the key is not null + * + * @param key the key; may not be null + * @throws NullPointerException if the supplied key is null + */ + private void checkKey(String key) { + Objects.requireNonNull(key, "Header key cannot be null"); + } + + /** + * Check the {@link Schema#type() schema's type} matches the specified type. + * + * @param schema the schema; never null + * @param type the expected type + * @throws DataException if the schema's type does not match the expected type + */ + private void checkSchemaType(Schema schema, Type type) { + if (schema.type() != type) { + throw new DataException("Expecting " + type + " but instead found " + schema.type()); + } + } + + /** + * Check that the value and its schema are compatible. + * + * @param schemaAndValue the schema and value pair + * @throws DataException if the schema is not compatible with the value + */ + // visible for testing + void checkSchemaMatches(SchemaAndValue schemaAndValue) { + if (schemaAndValue != null) { + Schema schema = schemaAndValue.schema(); + if (schema == null) + return; + schema = schema.schema(); // in case a SchemaBuilder is used + Object value = schemaAndValue.value(); + if (value == null && !schema.isOptional()) { + throw new DataException("A null value requires an optional schema but was " + schema); + } + if (value != null) { + switch (schema.type()) { + case BYTES: + if (value instanceof ByteBuffer) + return; + if (value instanceof byte[]) + return; + if (value instanceof BigDecimal && Decimal.LOGICAL_NAME.equals(schema.name())) + return; + break; + case STRING: + if (value instanceof String) + return; + break; + case BOOLEAN: + if (value instanceof Boolean) + return; + break; + case INT8: + if (value instanceof Byte) + return; + break; + case INT16: + if (value instanceof Short) + return; + break; + case INT32: + if (value instanceof Integer) + return; + if (value instanceof java.util.Date && Date.LOGICAL_NAME.equals(schema.name())) + return; + if (value instanceof java.util.Date && Time.LOGICAL_NAME.equals(schema.name())) + return; + break; + case INT64: + if (value instanceof Long) + return; + if (value instanceof java.util.Date && Timestamp.LOGICAL_NAME.equals(schema.name())) + return; + break; + case FLOAT32: + if (value instanceof Float) + return; + break; + case FLOAT64: + if (value instanceof Double) + return; + break; + case ARRAY: + if (value instanceof List) + return; + break; + case MAP: + if (value instanceof Map) + return; + break; + case STRUCT: + if (value instanceof Struct) + return; + break; + } + throw new DataException("The value " + value + " is not compatible with the schema " + schema); + } + } + } + + private static final class FilterByKeyIterator extends AbstractIterator
<Header> { + + private final Iterator
<Header> original; + private final String key; + + private FilterByKeyIterator(Iterator<Header>
            original, String key) { + this.original = original; + this.key = key; + } + + @Override + protected Header makeNext() { + while (original.hasNext()) { + Header header = original.next(); + if (!header.key().equals(key)) { + continue; + } + return header; + } + return this.allDone(); + } + } +} diff --git a/connect/api/src/main/java/org/apache/kafka/connect/header/Header.java b/connect/api/src/main/java/org/apache/kafka/connect/header/Header.java new file mode 100644 index 0000000..a70d1dc --- /dev/null +++ b/connect/api/src/main/java/org/apache/kafka/connect/header/Header.java @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.header; + +import org.apache.kafka.connect.data.Schema; + +/** + * A {@link Header} is a key-value pair, and multiple headers can be included with the key, value, and timestamp in each Kafka message. + * If the value contains schema information, then the header will have a non-null {@link #schema() schema}. + *

            + * This is an immutable interface. + */ +public interface Header { + + /** + * The header's key, which is not necessarily unique within the set of headers on a Kafka message. + * + * @return the header's key; never null + */ + String key(); + + /** + * Return the {@link Schema} associated with this header, if there is one. Not all headers will have schemas. + * + * @return the header's schema, or null if no schema is associated with this header + */ + Schema schema(); + + /** + * Get the header's value as deserialized by Connect's header converter. + * + * @return the deserialized object representation of the header's value; may be null + */ + Object value(); + + /** + * Return a new {@link Header} object that has the same key but with the supplied value. + * + * @param schema the schema for the new value; may be null + * @param value the new value + * @return the new {@link Header}; never null + */ + Header with(Schema schema, Object value); + + /** + * Return a new {@link Header} object that has the same schema and value but with the supplied key. + * + * @param key the key for the new header; may not be null + * @return the new {@link Header}; never null + */ + Header rename(String key); +} diff --git a/connect/api/src/main/java/org/apache/kafka/connect/header/Headers.java b/connect/api/src/main/java/org/apache/kafka/connect/header/Headers.java new file mode 100644 index 0000000..d7bd779 --- /dev/null +++ b/connect/api/src/main/java/org/apache/kafka/connect/header/Headers.java @@ -0,0 +1,308 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.header; + +import org.apache.kafka.connect.data.Schema; +import org.apache.kafka.connect.data.SchemaAndValue; +import org.apache.kafka.connect.data.Struct; +import org.apache.kafka.connect.errors.DataException; + +import java.math.BigDecimal; +import java.util.Iterator; +import java.util.List; +import java.util.Map; + +/** + * A mutable ordered collection of {@link Header} objects. Note that multiple headers may have the same {@link Header#key() key}. + */ +public interface Headers extends Iterable

<Header> { + + /** + * Get the number of headers in this object. + * + * @return the number of headers; never negative + */ + int size(); + + /** + * Determine whether this object has no headers. + * + * @return true if there are no headers, or false if there is at least one header + */ + boolean isEmpty(); + + /** + * Get the collection of {@link Header} objects whose {@link Header#key() keys} all match the specified key. + * + * @param key the key; may not be null + * @return the iterator over headers with the specified key; may be null if there are no headers with the specified key + */ + Iterator<Header>
            allWithName(String key); + + /** + * Return the last {@link Header} with the specified key. + * + * @param key the key for the header; may not be null + * @return the last Header, or null if there are no headers with the specified key + */ + Header lastWithName(String key); + + /** + * Add the given {@link Header} to this collection. + * + * @param header the header; may not be null + * @return this object to facilitate chaining multiple methods; never null + */ + Headers add(Header header); + + /** + * Add to this collection a {@link Header} with the given key and value. + * + * @param key the header's key; may not be null + * @param schemaAndValue the {@link SchemaAndValue} for the header; may be null + * @return this object to facilitate chaining multiple methods; never null + */ + Headers add(String key, SchemaAndValue schemaAndValue); + + /** + * Add to this collection a {@link Header} with the given key and value. + * + * @param key the header's key; may not be null + * @param value the header's value; may be null + * @param schema the schema for the header's value; may not be null if the value is not null + * @return this object to facilitate chaining multiple methods; never null + */ + Headers add(String key, Object value, Schema schema); + + /** + * Add to this collection a {@link Header} with the given key and value. + * + * @param key the header's key; may not be null + * @param value the header's value; may be null + * @return this object to facilitate chaining multiple methods; never null + */ + Headers addString(String key, String value); + + /** + * Add to this collection a {@link Header} with the given key and value. + * + * @param key the header's key; may not be null + * @param value the header's value; may be null + * @return this object to facilitate chaining multiple methods; never null + */ + Headers addBoolean(String key, boolean value); + + /** + * Add to this collection a {@link Header} with the given key and value. + * + * @param key the header's key; may not be null + * @param value the header's value; may be null + * @return this object to facilitate chaining multiple methods; never null + */ + Headers addByte(String key, byte value); + + /** + * Add to this collection a {@link Header} with the given key and value. + * + * @param key the header's key; may not be null + * @param value the header's value; may be null + * @return this object to facilitate chaining multiple methods; never null + */ + Headers addShort(String key, short value); + + /** + * Add to this collection a {@link Header} with the given key and value. + * + * @param key the header's key; may not be null + * @param value the header's value; may be null + * @return this object to facilitate chaining multiple methods; never null + */ + Headers addInt(String key, int value); + + /** + * Add to this collection a {@link Header} with the given key and value. + * + * @param key the header's key; may not be null + * @param value the header's value; may be null + * @return this object to facilitate chaining multiple methods; never null + */ + Headers addLong(String key, long value); + + /** + * Add to this collection a {@link Header} with the given key and value. + * + * @param key the header's key; may not be null + * @param value the header's value; may be null + * @return this object to facilitate chaining multiple methods; never null + */ + Headers addFloat(String key, float value); + + /** + * Add to this collection a {@link Header} with the given key and value. 
+ * + * @param key the header's key; may not be null + * @param value the header's value; may be null + * @return this object to facilitate chaining multiple methods; never null + */ + Headers addDouble(String key, double value); + + /** + * Add to this collection a {@link Header} with the given key and value. + * + * @param key the header's key; may not be null + * @param value the header's value; may be null + * @return this object to facilitate chaining multiple methods; never null + */ + Headers addBytes(String key, byte[] value); + + /** + * Add to this collection a {@link Header} with the given key and value. + * + * @param key the header's key; may not be null + * @param value the header's value; may be null + * @param schema the schema describing the list value; may not be null + * @return this object to facilitate chaining multiple methods; never null + * @throws DataException if the header's value is invalid + */ + Headers addList(String key, List value, Schema schema); + + /** + * Add to this collection a {@link Header} with the given key and value. + * + * @param key the header's key; may not be null + * @param value the header's value; may be null + * @param schema the schema describing the map value; may not be null + * @return this object to facilitate chaining multiple methods; never null + * @throws DataException if the header's value is invalid + */ + Headers addMap(String key, Map value, Schema schema); + + /** + * Add to this collection a {@link Header} with the given key and value. + * + * @param key the header's key; may not be null + * @param value the header's value; may be null + * @return this object to facilitate chaining multiple methods; never null + * @throws DataException if the header's value is invalid + */ + Headers addStruct(String key, Struct value); + + /** + * Add to this collection a {@link Header} with the given key and {@link org.apache.kafka.connect.data.Decimal} value. + * + * @param key the header's key; may not be null + * @param value the header's {@link org.apache.kafka.connect.data.Decimal} value; may be null + * @return this object to facilitate chaining multiple methods; never null + */ + Headers addDecimal(String key, BigDecimal value); + + /** + * Add to this collection a {@link Header} with the given key and {@link org.apache.kafka.connect.data.Date} value. + * + * @param key the header's key; may not be null + * @param value the header's {@link org.apache.kafka.connect.data.Date} value; may be null + * @return this object to facilitate chaining multiple methods; never null + */ + Headers addDate(String key, java.util.Date value); + + /** + * Add to this collection a {@link Header} with the given key and {@link org.apache.kafka.connect.data.Time} value. + * + * @param key the header's key; may not be null + * @param value the header's {@link org.apache.kafka.connect.data.Time} value; may be null + * @return this object to facilitate chaining multiple methods; never null + */ + Headers addTime(String key, java.util.Date value); + + /** + * Add to this collection a {@link Header} with the given key and {@link org.apache.kafka.connect.data.Timestamp} value. 
+ * + * @param key the header's key; may not be null + * @param value the header's {@link org.apache.kafka.connect.data.Timestamp} value; may be null + * @return this object to facilitate chaining multiple methods; never null + */ + Headers addTimestamp(String key, java.util.Date value); + + /** + * Removes all {@link Header} objects whose {@link Header#key() key} matches the specified key. + * + * @param key the key; may not be null + * @return this object to facilitate chaining multiple methods; never null + */ + Headers remove(String key); + + /** + * Removes all but the latest {@link Header} objects whose {@link Header#key() key} matches the specified key. + * + * @param key the key; may not be null + * @return this object to facilitate chaining multiple methods; never null + */ + Headers retainLatest(String key); + + /** + * Removes all but the last {@link Header} object with each key. + * + * @return this object to facilitate chaining multiple methods; never null + */ + Headers retainLatest(); + + /** + * Removes all headers from this object. + * + * @return this object to facilitate chaining multiple methods; never null + */ + Headers clear(); + + /** + * Create a copy of this {@link Headers} object. The new copy will contain all of the same {@link Header} objects as this object. + * @return the copy; never null + */ + Headers duplicate(); + + /** + * Get all {@link Header}s, apply the transform to each and store the result in place of the original. + * + * @param transform the transform to apply; may not be null + * @return this object to facilitate chaining multiple methods; never null + * @throws DataException if the header's value is invalid + */ + Headers apply(HeaderTransform transform); + + /** + * Get all {@link Header}s with the given key, apply the transform to each and store the result in place of the original. + * + * @param key the header's key; may not be null + * @param transform the transform to apply; may not be null + * @return this object to facilitate chaining multiple methods; never null + * @throws DataException if the header's value is invalid + */ + Headers apply(String key, HeaderTransform transform); + + /** + * A function to transform the supplied {@link Header}. Implementations will likely need to use {@link Header#with(Schema, Object)} + * to create the new instance. + */ + interface HeaderTransform { + /** + * Transform the given {@link Header} and return the updated {@link Header}. + * + * @param header the input header; never null + * @return the new header, or null if the supplied {@link Header} is to be removed + */ + Header apply(Header header); + } +} diff --git a/connect/api/src/main/java/org/apache/kafka/connect/health/AbstractState.java b/connect/api/src/main/java/org/apache/kafka/connect/health/AbstractState.java new file mode 100644 index 0000000..ff65715 --- /dev/null +++ b/connect/api/src/main/java/org/apache/kafka/connect/health/AbstractState.java @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.connect.health; + +import java.util.Objects; + +import org.apache.kafka.common.utils.Utils; + +/** + * Provides the current status along with identifier for Connect worker and tasks. + */ +public abstract class AbstractState { + + private final String state; + private final String traceMessage; + private final String workerId; + + /** + * Construct a state for connector or task. + * + * @param state the status of connector or task; may not be null or empty + * @param workerId the workerId associated with the connector or the task; may not be null or empty + * @param traceMessage any error trace message associated with the connector or the task; may be null or empty + */ + public AbstractState(String state, String workerId, String traceMessage) { + if (Utils.isBlank(state)) { + throw new IllegalArgumentException("State must not be null or empty"); + } + if (Utils.isBlank(workerId)) { + throw new IllegalArgumentException("Worker ID must not be null or empty"); + } + this.state = state; + this.workerId = workerId; + this.traceMessage = traceMessage; + } + + /** + * Provides the current state of the connector or task. + * + * @return state, never {@code null} or empty + */ + public String state() { + return state; + } + + /** + * The identifier of the worker associated with the connector or the task. + * + * @return workerId, never {@code null} or empty. + */ + public String workerId() { + return workerId; + } + + /** + * The error message associated with the connector or task. + * + * @return traceMessage, can be {@code null} or empty. + */ + public String traceMessage() { + return traceMessage; + } + + @Override + public boolean equals(Object o) { + if (this == o) + return true; + if (o == null || getClass() != o.getClass()) + return false; + AbstractState that = (AbstractState) o; + return state.equals(that.state) + && Objects.equals(traceMessage, that.traceMessage) + && workerId.equals(that.workerId); + } + + @Override + public int hashCode() { + return Objects.hash(state, traceMessage, workerId); + } +} diff --git a/connect/api/src/main/java/org/apache/kafka/connect/health/ConnectClusterDetails.java b/connect/api/src/main/java/org/apache/kafka/connect/health/ConnectClusterDetails.java new file mode 100644 index 0000000..edde6ff --- /dev/null +++ b/connect/api/src/main/java/org/apache/kafka/connect/health/ConnectClusterDetails.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.connect.health; + +/** + * Provides immutable Connect cluster information, such as the ID of the backing Kafka cluster. The + * Connect framework provides the implementation for this interface. + */ +public interface ConnectClusterDetails { + + /** + * Get the cluster ID of the Kafka cluster backing this Connect cluster. + * + * @return the cluster ID of the Kafka cluster backing this Connect cluster + **/ + String kafkaClusterId(); +} \ No newline at end of file diff --git a/connect/api/src/main/java/org/apache/kafka/connect/health/ConnectClusterState.java b/connect/api/src/main/java/org/apache/kafka/connect/health/ConnectClusterState.java new file mode 100644 index 0000000..753ee1a --- /dev/null +++ b/connect/api/src/main/java/org/apache/kafka/connect/health/ConnectClusterState.java @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.connect.health; + +import java.util.Collection; +import java.util.Map; + +/** + * Provides the ability to lookup connector metadata, including status and configurations, as well + * as immutable cluster information such as Kafka cluster ID. This is made available to + * {@link org.apache.kafka.connect.rest.ConnectRestExtension} implementations. The Connect framework + * provides the implementation for this interface. + */ +public interface ConnectClusterState { + + /** + * Get the names of the connectors currently deployed in this cluster. This is a full list of connectors in the cluster gathered from + * the current configuration, which may change over time. + * + * @return collection of connector names, never {@code null} + */ + Collection connectors(); + + /** + * Lookup the current health of a connector and its tasks. This provides the current snapshot of health by querying the underlying + * herder. A connector returned by previous invocation of {@link #connectors()} may no longer be available and could result in {@link + * org.apache.kafka.connect.errors.NotFoundException}. + * + * @param connName name of the connector + * @return the health of the connector for the connector name + * @throws org.apache.kafka.connect.errors.NotFoundException if the requested connector can't be found + */ + ConnectorHealth connectorHealth(String connName); + + /** + * Lookup the current configuration of a connector. This provides the current snapshot of configuration by querying the underlying + * herder. A connector returned by previous invocation of {@link #connectors()} may no longer be available and could result in {@link + * org.apache.kafka.connect.errors.NotFoundException}. 
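+     * <p>
+     * For example, a REST extension might combine this with {@link #connectors()} and
+     * {@link #connectorHealth(String)}; the snippet below is an illustrative sketch only and the
+     * variable names are not part of the API:
+     * <pre>{@code
+     * for (String name : clusterState.connectors()) {
+     *     ConnectorHealth health = clusterState.connectorHealth(name);
+     *     Map<String, String> config = clusterState.connectorConfig(name);
+     * }
+     * }</pre>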
+ * + * @param connName name of the connector + * @return the configuration of the connector for the connector name + * @throws org.apache.kafka.connect.errors.NotFoundException if the requested connector can't be found + * @throws java.lang.UnsupportedOperationException if the default implementation has not been overridden + */ + default Map connectorConfig(String connName) { + throw new UnsupportedOperationException(); + } + + /** + * Get details about the setup of the Connect cluster. + * @return a {@link ConnectClusterDetails} object containing information about the cluster + * @throws java.lang.UnsupportedOperationException if the default implementation has not been overridden + **/ + default ConnectClusterDetails clusterDetails() { + throw new UnsupportedOperationException(); + } +} diff --git a/connect/api/src/main/java/org/apache/kafka/connect/health/ConnectorHealth.java b/connect/api/src/main/java/org/apache/kafka/connect/health/ConnectorHealth.java new file mode 100644 index 0000000..1f78157 --- /dev/null +++ b/connect/api/src/main/java/org/apache/kafka/connect/health/ConnectorHealth.java @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.health; + +import java.util.Map; +import java.util.Objects; + +import org.apache.kafka.common.utils.Utils; + +/** + * Provides basic health information about the connector and its tasks. + */ +public class ConnectorHealth { + + private final String name; + private final ConnectorState connectorState; + private final Map tasks; + private final ConnectorType type; + + + public ConnectorHealth(String name, + ConnectorState connectorState, + Map tasks, + ConnectorType type) { + if (Utils.isBlank(name)) { + throw new IllegalArgumentException("Connector name is required"); + } + Objects.requireNonNull(connectorState, "connectorState can't be null"); + Objects.requireNonNull(tasks, "tasks can't be null"); + Objects.requireNonNull(type, "type can't be null"); + this.name = name; + this.connectorState = connectorState; + this.tasks = tasks; + this.type = type; + } + + /** + * Provides the name of the connector. + * + * @return name, never {@code null} or empty + */ + public String name() { + return name; + } + + /** + * Provides the current state of the connector. + * + * @return the connector state, never {@code null} + */ + public ConnectorState connectorState() { + return connectorState; + } + + /** + * Provides the current state of the connector tasks. + * + * @return the state for each task ID; never {@code null} + */ + public Map tasksState() { + return tasks; + } + + /** + * Provides the type of the connector. 
+ * + * @return type, never {@code null} + */ + public ConnectorType type() { + return type; + } + + @Override + public boolean equals(Object o) { + if (this == o) + return true; + if (o == null || getClass() != o.getClass()) + return false; + ConnectorHealth that = (ConnectorHealth) o; + return name.equals(that.name) + && connectorState.equals(that.connectorState) + && tasks.equals(that.tasks) + && type == that.type; + } + + @Override + public int hashCode() { + return Objects.hash(name, connectorState, tasks, type); + } + + @Override + public String toString() { + return "ConnectorHealth{" + + "name='" + name + '\'' + + ", connectorState=" + connectorState + + ", tasks=" + tasks + + ", type=" + type + + '}'; + } +} diff --git a/connect/api/src/main/java/org/apache/kafka/connect/health/ConnectorState.java b/connect/api/src/main/java/org/apache/kafka/connect/health/ConnectorState.java new file mode 100644 index 0000000..6304426 --- /dev/null +++ b/connect/api/src/main/java/org/apache/kafka/connect/health/ConnectorState.java @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.connect.health; + +/** + * Describes the status, worker ID, and any errors associated with a connector. + */ +public class ConnectorState extends AbstractState { + + /** + * Provides an instance of the ConnectorState. + * + * @param state - the status of connector, may not be {@code null} or empty + * @param workerId - the workerId associated with the connector, may not be {@code null} or empty + * @param traceMessage - any error message associated with the connector, may be {@code null} or empty + */ + public ConnectorState(String state, String workerId, String traceMessage) { + super(state, workerId, traceMessage); + } + + @Override + public String toString() { + return "ConnectorState{" + + "state='" + state() + '\'' + + ", traceMessage='" + traceMessage() + '\'' + + ", workerId='" + workerId() + '\'' + + '}'; + } +} diff --git a/connect/api/src/main/java/org/apache/kafka/connect/health/ConnectorType.java b/connect/api/src/main/java/org/apache/kafka/connect/health/ConnectorType.java new file mode 100644 index 0000000..fa9db6f --- /dev/null +++ b/connect/api/src/main/java/org/apache/kafka/connect/health/ConnectorType.java @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.health; + +import java.util.Locale; + +/** + * Enum definition that identifies the type of the connector. + */ +public enum ConnectorType { + /** + * Identifies a source connector + */ + SOURCE, + /** + * Identifies a sink connector + */ + SINK, + /** + * Identifies a connector whose type could not be inferred + */ + UNKNOWN; + + @Override + public String toString() { + return super.toString().toLowerCase(Locale.ROOT); + } + +} diff --git a/connect/api/src/main/java/org/apache/kafka/connect/health/TaskState.java b/connect/api/src/main/java/org/apache/kafka/connect/health/TaskState.java new file mode 100644 index 0000000..ae78a5f --- /dev/null +++ b/connect/api/src/main/java/org/apache/kafka/connect/health/TaskState.java @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.connect.health; + +import java.util.Objects; + +/** + * Describes the state, IDs, and any errors of a connector task. + */ +public class TaskState extends AbstractState { + + private final int taskId; + + /** + * Provides an instance of {@link TaskState}. + * + * @param taskId the id associated with the connector task + * @param state the status of the task, may not be {@code null} or empty + * @param workerId id of the worker the task is associated with, may not be {@code null} or empty + * @param trace error message if that task had failed or errored out, may be {@code null} or empty + */ + public TaskState(int taskId, String state, String workerId, String trace) { + super(state, workerId, trace); + this.taskId = taskId; + } + + /** + * Provides the ID of the task. 
+ * + * @return the task ID + */ + public int taskId() { + return taskId; + } + + @Override + public boolean equals(Object o) { + if (this == o) + return true; + if (o == null || getClass() != o.getClass()) + return false; + if (!super.equals(o)) + return false; + TaskState taskState = (TaskState) o; + return taskId == taskState.taskId; + } + + @Override + public int hashCode() { + return Objects.hash(super.hashCode(), taskId); + } + + @Override + public String toString() { + return "TaskState{" + + "taskId='" + taskId + '\'' + + "state='" + state() + '\'' + + ", traceMessage='" + traceMessage() + '\'' + + ", workerId='" + workerId() + '\'' + + '}'; + } +} diff --git a/connect/api/src/main/java/org/apache/kafka/connect/rest/ConnectRestExtension.java b/connect/api/src/main/java/org/apache/kafka/connect/rest/ConnectRestExtension.java new file mode 100644 index 0000000..aa479a3 --- /dev/null +++ b/connect/api/src/main/java/org/apache/kafka/connect/rest/ConnectRestExtension.java @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.connect.rest; + +import org.apache.kafka.common.Configurable; +import org.apache.kafka.connect.components.Versioned; +import org.apache.kafka.connect.health.ConnectClusterState; + +import java.io.Closeable; +import java.util.Map; + +/** + * A plugin interface to allow registration of new JAX-RS resources like Filters, REST endpoints, providers, etc. The implementations will + * be discovered using the standard Java {@link java.util.ServiceLoader} mechanism by Connect's plugin class loading mechanism. + * + *

+ * <p>The extension class(es) must be packaged as a plugin, with one JAR containing the implementation classes and a {@code
+ * META-INF/services/org.apache.kafka.connect.rest.extension.ConnectRestExtension} file that contains the fully qualified name of the
+ * class(es) that implement the ConnectRestExtension interface. The plugin should also include the JARs of all dependencies except those
+ * already provided by the Connect framework.
+ *
+ * <p>To install into a Connect installation, add a directory named for the plugin and containing the plugin's JARs into a directory that is
+ * on Connect's {@code plugin.path}, and (re)start the Connect worker.
+ *
+ * <p>When the Connect worker process starts up, it will read its configuration and instantiate all of the REST extension implementation
+ * classes that are specified in the {@code rest.extension.classes} configuration property. Connect will then pass its configuration to each
+ * extension via the {@link Configurable#configure(Map)} method, and will then call {@link #register} with a provided context.
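+ *
+ * <p>A minimal sketch of an extension is shown below; the class and resource names are illustrative only and
+ * not part of the API:
+ * <pre>{@code
+ * public class ExampleRestExtension implements ConnectRestExtension {
+ *     public void configure(Map<String, ?> configs) { }    // receives the worker's configuration
+ *     public String version() { return "1.0"; }
+ *     public void register(ConnectRestExtensionContext context) {
+ *         // register a custom JAX-RS resource with the Connect REST server
+ *         context.configurable().register(new ExampleHealthResource(context.clusterState()));
+ *     }
+ *     public void close() { }
+ * }
+ * }</pre>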

            When the Connect worker shuts down, it will call the extension's {@link #close} method to allow the implementation to release all of + * its resources. + */ +public interface ConnectRestExtension extends Configurable, Versioned, Closeable { + + /** + * ConnectRestExtension implementations can register custom JAX-RS resources via the {@link #register(ConnectRestExtensionContext)} + * method. The Connect framework will invoke this method after registering the default Connect resources. If the implementations attempt + * to re-register any of the Connect resources, it will be be ignored and will be logged. + * + * @param restPluginContext The context provides access to JAX-RS {@link javax.ws.rs.core.Configurable} and {@link + * ConnectClusterState}.The custom JAX-RS resources can be registered via the {@link + * ConnectRestExtensionContext#configurable()} + */ + void register(ConnectRestExtensionContext restPluginContext); +} diff --git a/connect/api/src/main/java/org/apache/kafka/connect/rest/ConnectRestExtensionContext.java b/connect/api/src/main/java/org/apache/kafka/connect/rest/ConnectRestExtensionContext.java new file mode 100644 index 0000000..c951627 --- /dev/null +++ b/connect/api/src/main/java/org/apache/kafka/connect/rest/ConnectRestExtensionContext.java @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.connect.rest; + +import org.apache.kafka.connect.health.ConnectClusterState; + +import javax.ws.rs.core.Configurable; + +/** + * The interface provides the ability for {@link ConnectRestExtension} implementations to access the JAX-RS + * {@link javax.ws.rs.core.Configurable} and cluster state {@link ConnectClusterState}. The implementation for the interface is provided + * by the Connect framework. + */ +public interface ConnectRestExtensionContext { + + /** + * Provides an implementation of {@link javax.ws.rs.core.Configurable} that be used to register JAX-RS resources. + * + * @return @return the JAX-RS {@link javax.ws.rs.core.Configurable}; never {@code null} + */ + Configurable> configurable(); + + /** + * Provides the cluster state and health information about the connectors and tasks. + * + * @return the cluster state information; never {@code null} + */ + ConnectClusterState clusterState(); +} diff --git a/connect/api/src/main/java/org/apache/kafka/connect/sink/ErrantRecordReporter.java b/connect/api/src/main/java/org/apache/kafka/connect/sink/ErrantRecordReporter.java new file mode 100644 index 0000000..a20e1e3 --- /dev/null +++ b/connect/api/src/main/java/org/apache/kafka/connect/sink/ErrantRecordReporter.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.sink; + +import java.util.concurrent.Future; +import org.apache.kafka.connect.errors.ConnectException; + +/** + * Component that the sink task can use as it {@link SinkTask#put(java.util.Collection)}. + * Reporter of problematic records and the corresponding problems. + * + * @since 2.6 + */ +public interface ErrantRecordReporter { + + /** + * Report a problematic record and the corresponding error to be written to the sink + * connector's dead letter queue (DLQ). + * + *

            This call is asynchronous and returns a {@link java.util.concurrent.Future Future}. + * Invoking {@link java.util.concurrent.Future#get() get()} on this future will block until the + * record has been written or throw any exception that occurred while sending the record. + * If you want to simulate a simple blocking call you can call the get() method + * immediately. + * + * Connect guarantees that sink records reported through this reporter will be written to the error topic + * before the framework calls the {@link SinkTask#preCommit(java.util.Map)} method and therefore before + * committing the consumer offsets. SinkTask implementations can use the Future when stronger guarantees + * are required. + * + * @param record the problematic record; may not be null + * @param error the error capturing the problem with the record; may not be null + * @return a future that can be used to block until the record and error are reported + * to the DLQ + * @throws ConnectException if the error reporter and DLQ fails to write a reported record + * @since 2.6 + */ + Future report(SinkRecord record, Throwable error); +} diff --git a/connect/api/src/main/java/org/apache/kafka/connect/sink/SinkConnector.java b/connect/api/src/main/java/org/apache/kafka/connect/sink/SinkConnector.java new file mode 100644 index 0000000..9627571 --- /dev/null +++ b/connect/api/src/main/java/org/apache/kafka/connect/sink/SinkConnector.java @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.sink; + +import org.apache.kafka.connect.connector.Connector; + +/** + * SinkConnectors implement the Connector interface to send Kafka data to another system. + */ +public abstract class SinkConnector extends Connector { + + /** + *

+     * Configuration key for the list of input topics for this connector.
+     * <p>
+     * Usually this setting is only relevant to the Kafka Connect framework, but is provided here for
+     * the convenience of Connector developers if they also need to know the set of topics.
            + */ + public static final String TOPICS_CONFIG = "topics"; + + @Override + protected SinkConnectorContext context() { + return (SinkConnectorContext) context; + } + +} diff --git a/connect/api/src/main/java/org/apache/kafka/connect/sink/SinkConnectorContext.java b/connect/api/src/main/java/org/apache/kafka/connect/sink/SinkConnectorContext.java new file mode 100644 index 0000000..5e2b07a --- /dev/null +++ b/connect/api/src/main/java/org/apache/kafka/connect/sink/SinkConnectorContext.java @@ -0,0 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.sink; + +import org.apache.kafka.connect.connector.ConnectorContext; + +/** + * A context to allow a {@link SinkConnector} to interact with the Kafka Connect runtime. + */ +public interface SinkConnectorContext extends ConnectorContext { +} diff --git a/connect/api/src/main/java/org/apache/kafka/connect/sink/SinkRecord.java b/connect/api/src/main/java/org/apache/kafka/connect/sink/SinkRecord.java new file mode 100644 index 0000000..12c7ee1 --- /dev/null +++ b/connect/api/src/main/java/org/apache/kafka/connect/sink/SinkRecord.java @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.sink; + +import org.apache.kafka.common.record.TimestampType; +import org.apache.kafka.connect.connector.ConnectRecord; +import org.apache.kafka.connect.data.Schema; +import org.apache.kafka.connect.header.Header; + +/** + * SinkRecord is a {@link ConnectRecord} that has been read from Kafka and includes the kafkaOffset of + * the record in the Kafka topic-partition in addition to the standard fields. This information + * should be used by the SinkTask to coordinate kafkaOffset commits. + * + * It also includes the {@link TimestampType}, which may be {@link TimestampType#NO_TIMESTAMP_TYPE}, and the relevant + * timestamp, which may be {@code null}. 
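+ * <p>
+ * For example, a task might use these extra fields when tracking delivery in the target system; the loop
+ * below is an illustrative sketch only:
+ * <pre>{@code
+ * for (SinkRecord record : records) {
+ *     long offset = record.kafkaOffset();
+ *     TimestampType timestampType = record.timestampType();
+ *     Long timestamp = record.timestamp();    // may be null when timestampType is NO_TIMESTAMP_TYPE
+ *     // write the value and remember (topic, partition, offset) in the target system
+ * }
+ * }</pre>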
+ */ +public class SinkRecord extends ConnectRecord { + private final long kafkaOffset; + private final TimestampType timestampType; + + public SinkRecord(String topic, int partition, Schema keySchema, Object key, Schema valueSchema, Object value, long kafkaOffset) { + this(topic, partition, keySchema, key, valueSchema, value, kafkaOffset, null, TimestampType.NO_TIMESTAMP_TYPE); + } + + public SinkRecord(String topic, int partition, Schema keySchema, Object key, Schema valueSchema, Object value, long kafkaOffset, + Long timestamp, TimestampType timestampType) { + this(topic, partition, keySchema, key, valueSchema, value, kafkaOffset, timestamp, timestampType, null); + } + + public SinkRecord(String topic, int partition, Schema keySchema, Object key, Schema valueSchema, Object value, long kafkaOffset, + Long timestamp, TimestampType timestampType, Iterable
            headers) { + super(topic, partition, keySchema, key, valueSchema, value, timestamp, headers); + this.kafkaOffset = kafkaOffset; + this.timestampType = timestampType; + } + + public long kafkaOffset() { + return kafkaOffset; + } + + public TimestampType timestampType() { + return timestampType; + } + + @Override + public SinkRecord newRecord(String topic, Integer kafkaPartition, Schema keySchema, Object key, Schema valueSchema, Object value, Long timestamp) { + return newRecord(topic, kafkaPartition, keySchema, key, valueSchema, value, timestamp, headers().duplicate()); + } + + @Override + public SinkRecord newRecord(String topic, Integer kafkaPartition, Schema keySchema, Object key, Schema valueSchema, Object value, + Long timestamp, Iterable
            headers) { + return new SinkRecord(topic, kafkaPartition, keySchema, key, valueSchema, value, kafkaOffset(), timestamp, timestampType, headers); + } + + @Override + public boolean equals(Object o) { + if (this == o) + return true; + if (o == null || getClass() != o.getClass()) + return false; + if (!super.equals(o)) + return false; + + SinkRecord that = (SinkRecord) o; + + if (kafkaOffset != that.kafkaOffset) + return false; + + return timestampType == that.timestampType; + } + + @Override + public int hashCode() { + int result = super.hashCode(); + result = 31 * result + Long.hashCode(kafkaOffset); + result = 31 * result + timestampType.hashCode(); + return result; + } + + @Override + public String toString() { + return "SinkRecord{" + + "kafkaOffset=" + kafkaOffset + + ", timestampType=" + timestampType + + "} " + super.toString(); + } +} diff --git a/connect/api/src/main/java/org/apache/kafka/connect/sink/SinkTask.java b/connect/api/src/main/java/org/apache/kafka/connect/sink/SinkTask.java new file mode 100644 index 0000000..5d308c4 --- /dev/null +++ b/connect/api/src/main/java/org/apache/kafka/connect/sink/SinkTask.java @@ -0,0 +1,174 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.sink; + +import org.apache.kafka.clients.consumer.OffsetAndMetadata; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.connect.connector.Task; + +import java.util.Collection; +import java.util.Map; + +/** + * SinkTask is a Task that takes records loaded from Kafka and sends them to another system. Each task + * instance is assigned a set of partitions by the Connect framework and will handle all records received + * from those partitions. As records are fetched from Kafka, they will be passed to the sink task using the + * {@link #put(Collection)} API, which should either write them to the downstream system or batch them for + * later writing. Periodically, Connect will call {@link #flush(Map)} to ensure that batched records are + * actually pushed to the downstream system.. + * + * Below we describe the lifecycle of a SinkTask. + * + *
+ * <ol>
+ *     <li><b>Initialization:</b> SinkTasks are first initialized using {@link #initialize(SinkTaskContext)}
+ *     to prepare the task's context and {@link #start(Map)} to accept configuration and start any services
+ *     needed for processing.</li>
+ *     <li><b>Partition Assignment:</b> After initialization, Connect will assign the task a set of partitions
+ *     using {@link #open(Collection)}. These partitions are owned exclusively by this task until they
+ *     have been closed with {@link #close(Collection)}.</li>
+ *     <li><b>Record Processing:</b> Once partitions have been opened for writing, Connect will begin forwarding
+ *     records from Kafka using the {@link #put(Collection)} API. Periodically, Connect will ask the task
+ *     to flush records using {@link #flush(Map)} as described above.</li>
+ *     <li><b>Partition Rebalancing:</b> Occasionally, Connect will need to change the assignment of this task.
+ *     When this happens, the currently assigned partitions will be closed with {@link #close(Collection)} and
+ *     the new assignment will be opened using {@link #open(Collection)}.</li>
+ *     <li><b>Shutdown:</b> When the task needs to be shut down, Connect will close active partitions (if there
+ *     are any) and stop the task using {@link #stop()}.</li>
+ * </ol>
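+ *
+ * <p>A rough skeleton of a task that follows this lifecycle is sketched below; this is illustrative only
+ * (the class name is made up and the actual write logic is reduced to comments), not a prescribed structure:
+ * <pre>{@code
+ * public class ExampleSinkTask extends SinkTask {
+ *     public String version() { return "1.0"; }
+ *     public void start(Map<String, String> props) { }    // parse config, create clients for the target system
+ *     public void open(Collection<TopicPartition> partitions) { }    // create per-partition writers
+ *     public void put(Collection<SinkRecord> records) { }    // buffer or write the records
+ *     public void flush(Map<TopicPartition, OffsetAndMetadata> currentOffsets) { }    // push buffered records
+ *     public void close(Collection<TopicPartition> partitions) { }    // release per-partition writers
+ *     public void stop() { }    // shut down clients
+ * }
+ * }</pre>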
+ */
+public abstract class SinkTask implements Task {
+
+    /**
+     * The configuration key that provides the list of topics that are inputs for this
+     * SinkTask.
+     */
+    public static final String TOPICS_CONFIG = "topics";
+
+    /**
+     * The configuration key that provides a regex specifying which topics to include as inputs
+     * for this SinkTask.
            + */ + public static final String TOPICS_REGEX_CONFIG = "topics.regex"; + + protected SinkTaskContext context; + + /** + * Initialize the context of this task. Note that the partition assignment will be empty until + * Connect has opened the partitions for writing with {@link #open(Collection)}. + * @param context The sink task's context + */ + public void initialize(SinkTaskContext context) { + this.context = context; + } + + /** + * Start the Task. This should handle any configuration parsing and one-time setup of the task. + * @param props initial configuration + */ + @Override + public abstract void start(Map props); + + /** + * Put the records in the sink. Usually this should send the records to the sink asynchronously + * and immediately return. + * + * If this operation fails, the SinkTask may throw a {@link org.apache.kafka.connect.errors.RetriableException} to + * indicate that the framework should attempt to retry the same call again. Other exceptions will cause the task to + * be stopped immediately. {@link SinkTaskContext#timeout(long)} can be used to set the maximum time before the + * batch will be retried. + * + * @param records the set of records to send + */ + public abstract void put(Collection records); + + /** + * Flush all records that have been {@link #put(Collection)} for the specified topic-partitions. + * + * @param currentOffsets the current offset state as of the last call to {@link #put(Collection)}}, + * provided for convenience but could also be determined by tracking all offsets included in the {@link SinkRecord}s + * passed to {@link #put}. + */ + public void flush(Map currentOffsets) { + } + + /** + * Pre-commit hook invoked prior to an offset commit. + * + * The default implementation simply invokes {@link #flush(Map)} and is thus able to assume all {@code currentOffsets} are safe to commit. + * + * @param currentOffsets the current offset state as of the last call to {@link #put(Collection)}}, + * provided for convenience but could also be determined by tracking all offsets included in the {@link SinkRecord}s + * passed to {@link #put}. + * + * @return an empty map if Connect-managed offset commit is not desired, otherwise a map of offsets by topic-partition that are safe to commit. + */ + public Map preCommit(Map currentOffsets) { + flush(currentOffsets); + return currentOffsets; + } + + /** + * The SinkTask use this method to create writers for newly assigned partitions in case of partition + * rebalance. This method will be called after partition re-assignment completes and before the SinkTask starts + * fetching data. Note that any errors raised from this method will cause the task to stop. + * @param partitions The list of partitions that are now assigned to the task (may include + * partitions previously assigned to the task) + */ + public void open(Collection partitions) { + this.onPartitionsAssigned(partitions); + } + + /** + * @deprecated Use {@link #open(Collection)} for partition initialization. + */ + @Deprecated + public void onPartitionsAssigned(Collection partitions) { + } + + /** + * The SinkTask use this method to close writers for partitions that are no + * longer assigned to the SinkTask. This method will be called before a rebalance operation starts + * and after the SinkTask stops fetching data. After being closed, Connect will not write + * any records to the task until a new set of partitions has been opened. Note that any errors raised + * from this method will cause the task to stop. 
+ * @param partitions The list of partitions that should be closed + */ + public void close(Collection partitions) { + this.onPartitionsRevoked(partitions); + } + + /** + * @deprecated Use {@link #close(Collection)} instead for partition cleanup. + */ + @Deprecated + public void onPartitionsRevoked(Collection partitions) { + } + + /** + * Perform any cleanup to stop this task. In SinkTasks, this method is invoked only once outstanding calls to other + * methods have completed (e.g., {@link #put(Collection)} has returned) and a final {@link #flush(Map)} and offset + * commit has completed. Implementations of this method should only need to perform final cleanup operations, such + * as closing network connections to the sink system. + */ + @Override + public abstract void stop(); +} diff --git a/connect/api/src/main/java/org/apache/kafka/connect/sink/SinkTaskContext.java b/connect/api/src/main/java/org/apache/kafka/connect/sink/SinkTaskContext.java new file mode 100644 index 0000000..c4522c7 --- /dev/null +++ b/connect/api/src/main/java/org/apache/kafka/connect/sink/SinkTaskContext.java @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.sink; + +import org.apache.kafka.common.TopicPartition; + +import java.util.Map; +import java.util.Set; + +/** + * Context passed to SinkTasks, allowing them to access utilities in the Kafka Connect runtime. + */ +public interface SinkTaskContext { + + /** + * Get the Task configuration. This is the latest configuration and may differ from that passed on startup. + * + * For example, this method can be used to obtain the latest configuration if an external secret has changed, + * and the configuration is using variable references such as those compatible with + * {@link org.apache.kafka.common.config.ConfigTransformer}. + */ + Map configs(); + + /** + * Reset the consumer offsets for the given topic partitions. SinkTasks should use this if they manage offsets + * in the sink data store rather than using Kafka consumer offsets. For example, an HDFS connector might record + * offsets in HDFS to provide exactly once delivery. When the SinkTask is started or a rebalance occurs, the task + * would reload offsets from HDFS and use this method to reset the consumer to those offsets. + * + * SinkTasks that do not manage their own offsets do not need to use this method. + * + * @param offsets map of offsets for topic partitions + */ + void offset(Map offsets); + + /** + * Reset the consumer offsets for the given topic partition. SinkTasks should use if they manage offsets + * in the sink data store rather than using Kafka consumer offsets. For example, an HDFS connector might record + * offsets in HDFS to provide exactly once delivery. 
When the topic partition is recovered the task + * would reload offsets from HDFS and use this method to reset the consumer to the offset. + * + * SinkTasks that do not manage their own offsets do not need to use this method. + * + * @param tp the topic partition to reset offset. + * @param offset the offset to reset to. + */ + void offset(TopicPartition tp, long offset); + + /** + * Set the timeout in milliseconds. SinkTasks should use this to indicate that they need to retry certain + * operations after the timeout. SinkTasks may have certain operations on external systems that may need + * to retry in case of failures. For example, append a record to an HDFS file may fail due to temporary network + * issues. SinkTasks use this method to set how long to wait before retrying. + * @param timeoutMs the backoff timeout in milliseconds. + */ + void timeout(long timeoutMs); + + /** + * Get the current set of assigned TopicPartitions for this task. + * @return the set of currently assigned TopicPartitions + */ + Set assignment(); + + /** + * Pause consumption of messages from the specified TopicPartitions. + * @param partitions the partitions which should be paused + */ + void pause(TopicPartition... partitions); + + /** + * Resume consumption of messages from previously paused TopicPartitions. + * @param partitions the partitions to resume + */ + void resume(TopicPartition... partitions); + + /** + * Request an offset commit. Sink tasks can use this to minimize the potential for redelivery + * by requesting an offset commit as soon as they flush data to the destination system. + * + * It is only a hint to the runtime and no timing guarantee should be assumed. + */ + void requestCommit(); + + /** + * Get the reporter to which the sink task can report problematic or failed {@link SinkRecord records} + * passed to the {@link SinkTask#put(java.util.Collection)} method. When reporting a failed record, + * the sink task will receive a {@link java.util.concurrent.Future} that the task can optionally use to wait until + * the failed record and exception have been written to Kafka. Note that the result of + * this method may be null if this connector has not been configured to use a reporter. + * + *

+     * <p>This method was added in Apache Kafka 2.6. Sink tasks that use this method but want to
+     * maintain backward compatibility so they can also be deployed to older Connect runtimes
+     * should guard the call to this method with a try-catch block, since calling this method will result in a
+     * {@link NoSuchMethodException} or {@link NoClassDefFoundError} when the sink connector is deployed to
+     * Connect runtimes older than Kafka 2.6. For example:
+     * <pre>
+     *     ErrantRecordReporter reporter;
+     *     try {
+     *         reporter = context.errantRecordReporter();
+     *     } catch (NoSuchMethodError | NoClassDefFoundError e) {
+     *         reporter = null;
+     *     }
+     * </pre>
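+     * <p>
+     * When present, the reporter is typically used from {@link SinkTask#put(java.util.Collection)}; the
+     * following is an illustrative sketch only, where {@code process(record)} stands in for the task's own
+     * write logic:
+     * <pre>{@code
+     * for (SinkRecord record : records) {
+     *     try {
+     *         process(record);    // task-specific write logic (assumed to exist)
+     *     } catch (Exception e) {
+     *         if (reporter != null) {
+     *             reporter.report(record, e);    // asynchronously route the failed record to the DLQ
+     *         } else {
+     *             throw new ConnectException("Failed to process record", e);
+     *         }
+     *     }
+     * }
+     * }</pre>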
            + * + * @return the reporter; null if no error reporter has been configured for the connector + * @since 2.6 + */ + default ErrantRecordReporter errantRecordReporter() { + return null; + } + +} diff --git a/connect/api/src/main/java/org/apache/kafka/connect/source/SourceConnector.java b/connect/api/src/main/java/org/apache/kafka/connect/source/SourceConnector.java new file mode 100644 index 0000000..6e96940 --- /dev/null +++ b/connect/api/src/main/java/org/apache/kafka/connect/source/SourceConnector.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.source; + +import org.apache.kafka.connect.connector.Connector; + +/** + * SourceConnectors implement the connector interface to pull data from another system and send + * it to Kafka. + */ +public abstract class SourceConnector extends Connector { + + @Override + protected SourceConnectorContext context() { + return (SourceConnectorContext) context; + } +} diff --git a/connect/api/src/main/java/org/apache/kafka/connect/source/SourceConnectorContext.java b/connect/api/src/main/java/org/apache/kafka/connect/source/SourceConnectorContext.java new file mode 100644 index 0000000..417fbdd --- /dev/null +++ b/connect/api/src/main/java/org/apache/kafka/connect/source/SourceConnectorContext.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.source; + +import org.apache.kafka.connect.connector.ConnectorContext; +import org.apache.kafka.connect.storage.OffsetStorageReader; + +/** + * A context to allow a {@link SourceConnector} to interact with the Kafka Connect runtime. + */ +public interface SourceConnectorContext extends ConnectorContext { + + /** + * Returns the {@link OffsetStorageReader} for this SourceConnectorContext. + * @return the OffsetStorageReader for this connector. 
+ */ + OffsetStorageReader offsetStorageReader(); +} diff --git a/connect/api/src/main/java/org/apache/kafka/connect/source/SourceRecord.java b/connect/api/src/main/java/org/apache/kafka/connect/source/SourceRecord.java new file mode 100644 index 0000000..2c390ee --- /dev/null +++ b/connect/api/src/main/java/org/apache/kafka/connect/source/SourceRecord.java @@ -0,0 +1,136 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.source; + +import org.apache.kafka.connect.connector.ConnectRecord; +import org.apache.kafka.connect.data.Schema; +import org.apache.kafka.connect.header.Header; + +import java.util.Map; +import java.util.Objects; + +/** + *

+ * SourceRecords are generated by SourceTasks and passed to Kafka Connect for storage in
+ * Kafka. In addition to the standard fields in {@link ConnectRecord} which specify where data is stored
+ * in Kafka, they also include a sourcePartition and sourceOffset.
+ * <p>
+ * The sourcePartition represents a single input sourcePartition that the record came from (e.g. a filename, table
+ * name, or topic-partition). The sourceOffset represents a position in that sourcePartition which can be used
+ * to resume consumption of data.
+ * <p>
+ * These values can have arbitrary structure and should be represented using
+ * org.apache.kafka.connect.data objects (or primitive values). For example, a database connector
+ * might specify the sourcePartition as a record containing { "db": "database_name", "table":
+ * "table_name"} and the sourceOffset as a Long containing the timestamp of the row.
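+ * <p>
+ * As an illustrative sketch only (the partition and offset keys shown are arbitrary choices, not part of the
+ * API), a task reading a log file might construct records like this:
+ * <pre>{@code
+ * Map<String, ?> sourcePartition = Collections.singletonMap("filename", "/var/log/app.log");
+ * Map<String, ?> sourceOffset = Collections.singletonMap("position", 4096L);
+ * SourceRecord record = new SourceRecord(sourcePartition, sourceOffset,
+ *         "log-topic", Schema.STRING_SCHEMA, "a line from the file");
+ * }</pre>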

            + */ +public class SourceRecord extends ConnectRecord { + private final Map sourcePartition; + private final Map sourceOffset; + + public SourceRecord(Map sourcePartition, Map sourceOffset, + String topic, Integer partition, Schema valueSchema, Object value) { + this(sourcePartition, sourceOffset, topic, partition, null, null, valueSchema, value); + } + + public SourceRecord(Map sourcePartition, Map sourceOffset, + String topic, Schema valueSchema, Object value) { + this(sourcePartition, sourceOffset, topic, null, null, null, valueSchema, value); + } + + public SourceRecord(Map sourcePartition, Map sourceOffset, + String topic, Schema keySchema, Object key, Schema valueSchema, Object value) { + this(sourcePartition, sourceOffset, topic, null, keySchema, key, valueSchema, value); + } + + public SourceRecord(Map sourcePartition, Map sourceOffset, + String topic, Integer partition, + Schema keySchema, Object key, Schema valueSchema, Object value) { + this(sourcePartition, sourceOffset, topic, partition, keySchema, key, valueSchema, value, null); + } + + public SourceRecord(Map sourcePartition, Map sourceOffset, + String topic, Integer partition, + Schema keySchema, Object key, + Schema valueSchema, Object value, + Long timestamp) { + this(sourcePartition, sourceOffset, topic, partition, keySchema, key, valueSchema, value, timestamp, null); + } + + public SourceRecord(Map sourcePartition, Map sourceOffset, + String topic, Integer partition, + Schema keySchema, Object key, + Schema valueSchema, Object value, + Long timestamp, Iterable
            headers) { + super(topic, partition, keySchema, key, valueSchema, value, timestamp, headers); + this.sourcePartition = sourcePartition; + this.sourceOffset = sourceOffset; + } + + public Map sourcePartition() { + return sourcePartition; + } + + public Map sourceOffset() { + return sourceOffset; + } + + @Override + public SourceRecord newRecord(String topic, Integer kafkaPartition, Schema keySchema, Object key, Schema valueSchema, Object value, Long timestamp) { + return newRecord(topic, kafkaPartition, keySchema, key, valueSchema, value, timestamp, headers().duplicate()); + } + + @Override + public SourceRecord newRecord(String topic, Integer kafkaPartition, Schema keySchema, Object key, Schema valueSchema, Object value, + Long timestamp, Iterable
            headers) { + return new SourceRecord(sourcePartition, sourceOffset, topic, kafkaPartition, keySchema, key, valueSchema, value, timestamp, headers); + } + + @Override + public boolean equals(Object o) { + if (this == o) + return true; + if (o == null || getClass() != o.getClass()) + return false; + if (!super.equals(o)) + return false; + + SourceRecord that = (SourceRecord) o; + + return Objects.equals(sourcePartition, that.sourcePartition) && + Objects.equals(sourceOffset, that.sourceOffset); + } + + @Override + public int hashCode() { + int result = super.hashCode(); + result = 31 * result + (sourcePartition != null ? sourcePartition.hashCode() : 0); + result = 31 * result + (sourceOffset != null ? sourceOffset.hashCode() : 0); + return result; + } + + @Override + public String toString() { + return "SourceRecord{" + + "sourcePartition=" + sourcePartition + + ", sourceOffset=" + sourceOffset + + "} " + super.toString(); + } +} diff --git a/connect/api/src/main/java/org/apache/kafka/connect/source/SourceTask.java b/connect/api/src/main/java/org/apache/kafka/connect/source/SourceTask.java new file mode 100644 index 0000000..225b080 --- /dev/null +++ b/connect/api/src/main/java/org/apache/kafka/connect/source/SourceTask.java @@ -0,0 +1,140 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.source; + +import org.apache.kafka.connect.connector.Task; +import org.apache.kafka.clients.producer.RecordMetadata; + +import java.util.List; +import java.util.Map; + +/** + * SourceTask is a Task that pulls records from another system for storage in Kafka. + */ +public abstract class SourceTask implements Task { + + protected SourceTaskContext context; + + /** + * Initialize this SourceTask with the specified context object. + */ + public void initialize(SourceTaskContext context) { + this.context = context; + } + + /** + * Start the Task. This should handle any configuration parsing and one-time setup of the task. + * @param props initial configuration + */ + @Override + public abstract void start(Map props); + + /** + *

+     * Poll this source task for new records. If no data is currently available, this method
+     * should block but return control to the caller regularly (by returning {@code null}) in
+     * order for the task to transition to the {@code PAUSED} state if requested to do so.
+     * <p>
+     * The task will be {@link #stop() stopped} on a separate thread, and when that happens
+     * this method is expected to unblock, quickly finish up any remaining processing, and
+     * return.
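+     * <p>
+     * An illustrative sketch only ({@code fetchBatch()} and {@code pollIntervalMs} are hypothetical helpers,
+     * not part of the API):
+     * <pre>{@code
+     * public List<SourceRecord> poll() throws InterruptedException {
+     *     List<SourceRecord> records = fetchBatch();    // non-blocking read from the external system
+     *     if (records.isEmpty()) {
+     *         Thread.sleep(pollIntervalMs);    // back off briefly ...
+     *         return null;                     // ... and return control to the framework
+     *     }
+     *     return records;
+     * }
+     * }</pre>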

+     *
+     * @return a list of source records
+     */
+    public abstract List<SourceRecord> poll() throws InterruptedException;
+
+    /**
+     * Commit the offsets, up to the offsets that have been returned by {@link #poll()}. This
+     * method should block until the commit is complete.
+     * <p>
+     * SourceTasks are not required to implement this functionality; Kafka Connect will record offsets
+     * automatically. This hook is provided for systems that also need to store offsets internally
+     * in their own system.
            + */ + public void commit() throws InterruptedException { + // This space intentionally left blank. + } + + /** + * Signal this SourceTask to stop. In SourceTasks, this method only needs to signal to the task that it should stop + * trying to poll for new data and interrupt any outstanding poll() requests. It is not required that the task has + * fully stopped. Note that this method necessarily may be invoked from a different thread than {@link #poll()} and + * {@link #commit()}. + * + * For example, if a task uses a {@link java.nio.channels.Selector} to receive data over the network, this method + * could set a flag that will force {@link #poll()} to exit immediately and invoke + * {@link java.nio.channels.Selector#wakeup() wakeup()} to interrupt any ongoing requests. + */ + @Override + public abstract void stop(); + + /** + *
            + * Commit an individual {@link SourceRecord} when the callback from the producer client is received. This method is + * also called when a record is filtered by a transformation, and thus will never be ACK'd by a broker. + * + * + * This is an alias for {@link #commitRecord(SourceRecord, RecordMetadata)} for backwards compatibility. The default + * implementation of {@link #commitRecord(SourceRecord, RecordMetadata)} just calls this method. It is not necessary + * to override both methods. + * + * + * SourceTasks are not required to implement this functionality; Kafka Connect will record offsets + * automatically. This hook is provided for systems that also need to store offsets internally + * in their own system. + *
            + * + * @param record {@link SourceRecord} that was successfully sent via the producer or filtered by a transformation + * @throws InterruptedException + * @deprecated Use {@link #commitRecord(SourceRecord, RecordMetadata)} instead. + */ + @Deprecated + public void commitRecord(SourceRecord record) throws InterruptedException { + // This space intentionally left blank. + } + + /** + *
            + * Commit an individual {@link SourceRecord} when the callback from the producer client is received. This method is + * also called when a record is filtered by a transformation, and thus will never be ACK'd by a broker. In this case + * {@code metadata} will be null. + * + * + * SourceTasks are not required to implement this functionality; Kafka Connect will record offsets + * automatically. This hook is provided for systems that also need to store offsets internally + * in their own system. + * + * + * The default implementation just calls {@link #commitRecord(SourceRecord)}, which is a nop by default. It is + * not necessary to implement both methods. + *
            + * + * @param record {@link SourceRecord} that was successfully sent via the producer or filtered by a transformation + * @param metadata {@link RecordMetadata} record metadata returned from the broker, or null if the record was filtered + * @throws InterruptedException + */ + public void commitRecord(SourceRecord record, RecordMetadata metadata) + throws InterruptedException { + // by default, just call other method for backwards compatibility + commitRecord(record); + } +} diff --git a/connect/api/src/main/java/org/apache/kafka/connect/source/SourceTaskContext.java b/connect/api/src/main/java/org/apache/kafka/connect/source/SourceTaskContext.java new file mode 100644 index 0000000..ddb0a78 --- /dev/null +++ b/connect/api/src/main/java/org/apache/kafka/connect/source/SourceTaskContext.java @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.source; + +import org.apache.kafka.connect.storage.OffsetStorageReader; + +import java.util.Map; + +/** + * SourceTaskContext is provided to SourceTasks to allow them to interact with the underlying + * runtime. + */ +public interface SourceTaskContext { + /** + * Get the Task configuration. This is the latest configuration and may differ from that passed on startup. + * + * For example, this method can be used to obtain the latest configuration if an external secret has changed, + * and the configuration is using variable references such as those compatible with + * {@link org.apache.kafka.common.config.ConfigTransformer}. + */ + Map configs(); + + /** + * Get the OffsetStorageReader for this SourceTask. + */ + OffsetStorageReader offsetStorageReader(); +} diff --git a/connect/api/src/main/java/org/apache/kafka/connect/storage/Converter.java b/connect/api/src/main/java/org/apache/kafka/connect/storage/Converter.java new file mode 100644 index 0000000..2d2ef4a --- /dev/null +++ b/connect/api/src/main/java/org/apache/kafka/connect/storage/Converter.java @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
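To make the SourceTask lifecycle described above concrete, here is a minimal sketch (not part of this patch) of a task that emits an incrementing counter; the class name, topic, and offset keys are illustrative assumptions. It shows start() recovering a previously committed offset through SourceTaskContext.offsetStorageReader() and poll() returning a small batch of SourceRecords.

import org.apache.kafka.connect.data.Schema;
import org.apache.kafka.connect.source.SourceRecord;
import org.apache.kafka.connect.source.SourceTask;

import java.util.Collections;
import java.util.List;
import java.util.Map;

// Hypothetical example task: emits an incrementing counter as String values.
public class CounterSourceTask extends SourceTask {
    private Map<String, String> sourcePartition;
    private long next = 0L;

    @Override
    public String version() {
        return "0.0.1";
    }

    @Override
    public void start(Map<String, String> props) {
        sourcePartition = Collections.singletonMap("source", "counter");
        // Resume from the last committed offset, if any (see OffsetStorageReader).
        Map<String, Object> offset = context.offsetStorageReader().offset(sourcePartition);
        if (offset != null && offset.get("next") instanceof Long) {
            next = (Long) offset.get("next");
        }
    }

    @Override
    public List<SourceRecord> poll() throws InterruptedException {
        Thread.sleep(100); // do not spin; returning regularly also lets the task pause
        Map<String, Long> sourceOffset = Collections.singletonMap("next", next + 1);
        return Collections.singletonList(new SourceRecord(sourcePartition, sourceOffset,
                "counter-topic", Schema.STRING_SCHEMA, Long.toString(next++)));
    }

    @Override
    public void stop() {
        // nothing to release in this sketch
    }
}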
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.storage; + +import org.apache.kafka.common.header.Headers; +import org.apache.kafka.connect.data.Schema; +import org.apache.kafka.connect.data.SchemaAndValue; + +import java.util.Map; + +/** + * The Converter interface provides support for translating between Kafka Connect's runtime data format + * and byte[]. Internally, this likely includes an intermediate step to the format used by the serialization + * layer (e.g. JsonNode, GenericRecord, Message). + */ +public interface Converter { + + /** + * Configure this class. + * @param configs configs in key/value pairs + * @param isKey whether is for key or value + */ + void configure(Map configs, boolean isKey); + + /** + * Convert a Kafka Connect data object to a native object for serialization. + * @param topic the topic associated with the data + * @param schema the schema for the value + * @param value the value to convert + * @return the serialized value + */ + byte[] fromConnectData(String topic, Schema schema, Object value); + + /** + * Convert a Kafka Connect data object to a native object for serialization, + * potentially using the supplied topic and headers in the record as necessary. + * + *
            Connect uses this method directly, and for backward compatibility reasons this method + * by default will call the {@link #fromConnectData(String, Schema, Object)} method. + * Override this method to make use of the supplied headers.
            + * @param topic the topic associated with the data + * @param headers the headers associated with the data; any changes done to the headers + * are applied to the message sent to the broker + * @param schema the schema for the value + * @param value the value to convert + * @return the serialized value + */ + default byte[] fromConnectData(String topic, Headers headers, Schema schema, Object value) { + return fromConnectData(topic, schema, value); + } + + /** + * Convert a native object to a Kafka Connect data object. + * @param topic the topic associated with the data + * @param value the value to convert + * @return an object containing the {@link Schema} and the converted value + */ + SchemaAndValue toConnectData(String topic, byte[] value); + + /** + * Convert a native object to a Kafka Connect data object, + * potentially using the supplied topic and headers in the record as necessary. + * + *
            Connect uses this method directly, and for backward compatibility reasons this method + * by default will call the {@link #toConnectData(String, byte[])} method. + * Override this method to make use of the supplied headers.
            + * @param topic the topic associated with the data + * @param headers the headers associated with the data + * @param value the value to convert + * @return an object containing the {@link Schema} and the converted value + */ + default SchemaAndValue toConnectData(String topic, Headers headers, byte[] value) { + return toConnectData(topic, value); + } +} diff --git a/connect/api/src/main/java/org/apache/kafka/connect/storage/ConverterConfig.java b/connect/api/src/main/java/org/apache/kafka/connect/storage/ConverterConfig.java new file mode 100644 index 0000000..cea6995 --- /dev/null +++ b/connect/api/src/main/java/org/apache/kafka/connect/storage/ConverterConfig.java @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.storage; + +import org.apache.kafka.common.config.AbstractConfig; +import org.apache.kafka.common.config.ConfigDef; +import org.apache.kafka.common.config.ConfigDef.Importance; +import org.apache.kafka.common.config.ConfigDef.Type; + +import java.util.Map; + +import static org.apache.kafka.common.config.ConfigDef.ValidString.in; + +/** + * Abstract class that defines the configuration options for {@link Converter} and {@link HeaderConverter} instances. + */ +public abstract class ConverterConfig extends AbstractConfig { + + public static final String TYPE_CONFIG = "converter.type"; + private static final String TYPE_DOC = "How this converter will be used."; + + /** + * Create a new {@link ConfigDef} instance containing the configurations defined by ConverterConfig. This can be called by subclasses. + * + * @return the ConfigDef; never null + */ + public static ConfigDef newConfigDef() { + return new ConfigDef().define(TYPE_CONFIG, Type.STRING, ConfigDef.NO_DEFAULT_VALUE, + in(ConverterType.KEY.getName(), ConverterType.VALUE.getName(), ConverterType.HEADER.getName()), + Importance.LOW, TYPE_DOC); + } + + protected ConverterConfig(ConfigDef configDef, Map props) { + super(configDef, props, true); + } + + /** + * Get the type of converter as defined by the {@link #TYPE_CONFIG} configuration. + * @return the converter type; never null + */ + public ConverterType type() { + return ConverterType.withName(getString(TYPE_CONFIG)); + } +} diff --git a/connect/api/src/main/java/org/apache/kafka/connect/storage/ConverterType.java b/connect/api/src/main/java/org/apache/kafka/connect/storage/ConverterType.java new file mode 100644 index 0000000..446ff8b --- /dev/null +++ b/connect/api/src/main/java/org/apache/kafka/connect/storage/ConverterType.java @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
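As an illustration of the Converter contract above (not part of this change), a custom converter only needs configure, fromConnectData, and toConnectData; the header-aware default methods delegate to those two automatically. The class name below is hypothetical and it assumes the connector handles raw BYTES data.

import org.apache.kafka.connect.data.Schema;
import org.apache.kafka.connect.data.SchemaAndValue;
import org.apache.kafka.connect.storage.Converter;

import java.util.Map;

// Hypothetical converter that treats payloads as opaque bytes.
public class BytesPassThroughConverter implements Converter {

    @Override
    public void configure(Map<String, ?> configs, boolean isKey) {
        // no options in this sketch
    }

    @Override
    public byte[] fromConnectData(String topic, Schema schema, Object value) {
        // assumes the connector produces BYTES data; null stays null
        return value == null ? null : (byte[]) value;
    }

    @Override
    public SchemaAndValue toConnectData(String topic, byte[] value) {
        return new SchemaAndValue(Schema.OPTIONAL_BYTES_SCHEMA, value);
    }
}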
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.storage; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Locale; +import java.util.Map; + +/** + * The type of {@link Converter} and {@link HeaderConverter}. + */ +public enum ConverterType { + KEY, + VALUE, + HEADER; + + private static final Map NAME_TO_TYPE; + + static { + ConverterType[] types = ConverterType.values(); + Map nameToType = new HashMap<>(types.length); + for (ConverterType type : types) { + nameToType.put(type.name, type); + } + NAME_TO_TYPE = Collections.unmodifiableMap(nameToType); + } + + /** + * Find the ConverterType with the given name, using a case-insensitive match. + * @param name the name of the converter type; may be null + * @return the matching converter type, or null if the supplied name is null or does not match the name of the known types + */ + public static ConverterType withName(String name) { + if (name == null) { + return null; + } + return NAME_TO_TYPE.get(name.toLowerCase(Locale.getDefault())); + } + + private String name; + + ConverterType() { + this.name = this.name().toLowerCase(Locale.ROOT); + } + + public String getName() { + return name; + } +} diff --git a/connect/api/src/main/java/org/apache/kafka/connect/storage/HeaderConverter.java b/connect/api/src/main/java/org/apache/kafka/connect/storage/HeaderConverter.java new file mode 100644 index 0000000..3f9d504 --- /dev/null +++ b/connect/api/src/main/java/org/apache/kafka/connect/storage/HeaderConverter.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.storage; + +import org.apache.kafka.common.Configurable; +import org.apache.kafka.common.config.ConfigDef; +import org.apache.kafka.connect.data.Schema; +import org.apache.kafka.connect.data.SchemaAndValue; +import org.apache.kafka.connect.header.Header; + +import java.io.Closeable; + +public interface HeaderConverter extends Configurable, Closeable { + + /** + * Convert the header name and byte array value into a {@link Header} object. 
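A small sketch of how the ConverterType.withName lookup above behaves; the wrapper class name is illustrative only.

import org.apache.kafka.connect.storage.ConverterType;

public class ConverterTypeExample {
    public static void main(String[] args) {
        // Lookup is case-insensitive; unknown names return null.
        System.out.println(ConverterType.withName("KEY"));    // KEY
        System.out.println(ConverterType.withName("header")); // HEADER
        System.out.println(ConverterType.withName("other"));  // null
    }
}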
+ * @param topic the name of the topic for the record containing the header + * @param headerKey the header's key; may not be null + * @param value the header's raw value; may be null + * @return the {@link SchemaAndValue}; may not be null + */ + SchemaAndValue toConnectHeader(String topic, String headerKey, byte[] value); + + /** + * Convert the {@link Header}'s {@link Header#value() value} into its byte array representation. + * @param topic the name of the topic for the record containing the header + * @param headerKey the header's key; may not be null + * @param schema the schema for the header's value; may be null + * @param value the header's value to convert; may be null + * @return the byte array form of the Header's value; may be null if the value is null + */ + byte[] fromConnectHeader(String topic, String headerKey, Schema schema, Object value); + + /** + * Configuration specification for this set of header converters. + * @return the configuration specification; may not be null + */ + ConfigDef config(); +} diff --git a/connect/api/src/main/java/org/apache/kafka/connect/storage/OffsetStorageReader.java b/connect/api/src/main/java/org/apache/kafka/connect/storage/OffsetStorageReader.java new file mode 100644 index 0000000..7f94b88 --- /dev/null +++ b/connect/api/src/main/java/org/apache/kafka/connect/storage/OffsetStorageReader.java @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.storage; + +import java.util.Collection; +import java.util.Map; + +/** + *
            + * OffsetStorageReader provides access to the offset storage used by sources. This can be used by + * connectors to determine offsets to start consuming data from. This is most commonly used during + * initialization of a task, but can also be used during runtime, e.g. when reconfiguring a task. + * + * + * Offsets are always defined as Maps of Strings to primitive types, i.e. all types supported by + * {@link org.apache.kafka.connect.data.Schema} other than Array, Map, and Struct. + *
            + */ +public interface OffsetStorageReader { + /** + * Get the offset for the specified partition. If the data isn't already available locally, this + * gets it from the backing store, which may require some network round trips. + * + * @param partition object uniquely identifying the partition of data + * @return object uniquely identifying the offset in the partition of data + */ + Map offset(Map partition); + + /** + *
            + * Get a set of offsets for the specified partition identifiers. This may be more efficient + * than calling {@link #offset(Map)} repeatedly. + * + * + * Note that when errors occur, this method omits the associated data and tries to return as + * many of the requested values as possible. This allows a task that's managing many partitions to + * still proceed with any available data. Therefore, implementations should take care to check + * that the data is actually available in the returned response. The only case when an + * exception will be thrown is if the entire request failed, e.g. because the underlying + * storage was unavailable. + *
            + * + * @param partitions set of identifiers for partitions of data + * @return a map of partition identifiers to decoded offsets + */ + Map, Map> offsets(Collection> partitions); +} diff --git a/connect/api/src/main/java/org/apache/kafka/connect/storage/SimpleHeaderConverter.java b/connect/api/src/main/java/org/apache/kafka/connect/storage/SimpleHeaderConverter.java new file mode 100644 index 0000000..69c4b86 --- /dev/null +++ b/connect/api/src/main/java/org/apache/kafka/connect/storage/SimpleHeaderConverter.java @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.storage; + +import org.apache.kafka.common.config.ConfigDef; +import org.apache.kafka.connect.data.Schema; +import org.apache.kafka.connect.data.SchemaAndValue; +import org.apache.kafka.connect.data.Values; +import org.apache.kafka.connect.errors.DataException; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.nio.charset.Charset; +import java.nio.charset.StandardCharsets; +import java.util.Map; +import java.util.NoSuchElementException; + +/** + * A {@link HeaderConverter} that serializes header values as strings and that deserializes header values to the most appropriate + * numeric, boolean, array, or map representation. Schemas are not serialized, but are inferred upon deserialization when possible. 
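A hedged usage sketch for SimpleHeaderConverter (not part of this patch): header values are written as their string form and parsed back with an inferred schema. The topic and header names are made up for illustration.

import org.apache.kafka.connect.data.SchemaAndValue;
import org.apache.kafka.connect.storage.SimpleHeaderConverter;

import java.nio.charset.StandardCharsets;

public class SimpleHeaderConverterExample {
    public static void main(String[] args) {
        SimpleHeaderConverter converter = new SimpleHeaderConverter();

        // Header values are serialized as their string form.
        byte[] bytes = converter.fromConnectHeader("orders", "retries", null, 3);
        System.out.println(new String(bytes, StandardCharsets.UTF_8)); // "3"

        // On the way back the schema is inferred (a numeric type for "3").
        SchemaAndValue parsed = converter.toConnectHeader("orders", "retries", bytes);
        System.out.println(parsed.schema().type() + " -> " + parsed.value());
    }
}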
+ */ +public class SimpleHeaderConverter implements HeaderConverter { + + private static final Logger LOG = LoggerFactory.getLogger(SimpleHeaderConverter.class); + private static final ConfigDef CONFIG_DEF = new ConfigDef(); + private static final SchemaAndValue NULL_SCHEMA_AND_VALUE = new SchemaAndValue(null, null); + private static final Charset UTF_8 = StandardCharsets.UTF_8; + + @Override + public ConfigDef config() { + return CONFIG_DEF; + } + + @Override + public void configure(Map configs) { + // do nothing + } + + @Override + public SchemaAndValue toConnectHeader(String topic, String headerKey, byte[] value) { + if (value == null) { + return NULL_SCHEMA_AND_VALUE; + } + try { + String str = new String(value, UTF_8); + if (str.isEmpty()) { + return new SchemaAndValue(Schema.STRING_SCHEMA, str); + } + return Values.parseString(str); + } catch (NoSuchElementException e) { + throw new DataException("Failed to deserialize value for header '" + headerKey + "' on topic '" + topic + "'", e); + } catch (Throwable t) { + LOG.warn("Failed to deserialize value for header '{}' on topic '{}', so using byte array", headerKey, topic, t); + return new SchemaAndValue(Schema.BYTES_SCHEMA, value); + } + } + + @Override + public byte[] fromConnectHeader(String topic, String headerKey, Schema schema, Object value) { + if (value == null) { + return null; + } + return Values.convertToString(schema, value).getBytes(UTF_8); + } + + @Override + public void close() throws IOException { + // do nothing + } +} diff --git a/connect/api/src/main/java/org/apache/kafka/connect/storage/StringConverter.java b/connect/api/src/main/java/org/apache/kafka/connect/storage/StringConverter.java new file mode 100644 index 0000000..534cddd --- /dev/null +++ b/connect/api/src/main/java/org/apache/kafka/connect/storage/StringConverter.java @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.storage; + +import org.apache.kafka.common.config.ConfigDef; +import org.apache.kafka.common.errors.SerializationException; +import org.apache.kafka.common.serialization.StringDeserializer; +import org.apache.kafka.common.serialization.StringSerializer; +import org.apache.kafka.connect.data.Schema; +import org.apache.kafka.connect.data.SchemaAndValue; +import org.apache.kafka.connect.errors.DataException; + +import java.util.HashMap; +import java.util.Map; + +/** + * {@link Converter} and {@link HeaderConverter} implementation that only supports serializing to strings. When converting Kafka Connect + * data to bytes, the schema will be ignored and {@link Object#toString()} will always be invoked to convert the data to a String. 
+ * When converting from bytes to Kafka Connect format, the converter will only ever return an optional string schema and + * a string or null. + * + * Encoding configuration is identical to {@link StringSerializer} and {@link StringDeserializer}, but for convenience + * this class can also be configured to use the same encoding for both encoding and decoding with the + * {@link StringConverterConfig#ENCODING_CONFIG converter.encoding} setting. + * + * This implementation currently does nothing with the topic names or header names. + */ +public class StringConverter implements Converter, HeaderConverter { + + private final StringSerializer serializer = new StringSerializer(); + private final StringDeserializer deserializer = new StringDeserializer(); + + public StringConverter() { + } + + @Override + public ConfigDef config() { + return StringConverterConfig.configDef(); + } + + @Override + public void configure(Map configs) { + StringConverterConfig conf = new StringConverterConfig(configs); + String encoding = conf.encoding(); + + Map serializerConfigs = new HashMap<>(configs); + Map deserializerConfigs = new HashMap<>(configs); + serializerConfigs.put("serializer.encoding", encoding); + deserializerConfigs.put("deserializer.encoding", encoding); + + boolean isKey = conf.type() == ConverterType.KEY; + serializer.configure(serializerConfigs, isKey); + deserializer.configure(deserializerConfigs, isKey); + } + + @Override + public void configure(Map configs, boolean isKey) { + Map conf = new HashMap<>(configs); + conf.put(StringConverterConfig.TYPE_CONFIG, isKey ? ConverterType.KEY.getName() : ConverterType.VALUE.getName()); + configure(conf); + } + + @Override + public byte[] fromConnectData(String topic, Schema schema, Object value) { + try { + return serializer.serialize(topic, value == null ? null : value.toString()); + } catch (SerializationException e) { + throw new DataException("Failed to serialize to a string: ", e); + } + } + + @Override + public SchemaAndValue toConnectData(String topic, byte[] value) { + try { + return new SchemaAndValue(Schema.OPTIONAL_STRING_SCHEMA, deserializer.deserialize(topic, value)); + } catch (SerializationException e) { + throw new DataException("Failed to deserialize string: ", e); + } + } + + @Override + public byte[] fromConnectHeader(String topic, String headerKey, Schema schema, Object value) { + return fromConnectData(topic, schema, value); + } + + @Override + public SchemaAndValue toConnectHeader(String topic, String headerKey, byte[] value) { + return toConnectData(topic, value); + } + + @Override + public void close() { + // do nothing + } +} diff --git a/connect/api/src/main/java/org/apache/kafka/connect/storage/StringConverterConfig.java b/connect/api/src/main/java/org/apache/kafka/connect/storage/StringConverterConfig.java new file mode 100644 index 0000000..96d0b5b --- /dev/null +++ b/connect/api/src/main/java/org/apache/kafka/connect/storage/StringConverterConfig.java @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
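A minimal usage sketch for the StringConverter described above; the topic name is illustrative and the encoding is left at its UTF-8 default.

import org.apache.kafka.connect.data.Schema;
import org.apache.kafka.connect.data.SchemaAndValue;
import org.apache.kafka.connect.storage.StringConverter;

import java.util.Collections;

public class StringConverterExample {
    public static void main(String[] args) {
        StringConverter converter = new StringConverter();
        converter.configure(Collections.emptyMap(), false); // value converter, UTF-8 default

        byte[] bytes = converter.fromConnectData("events", Schema.STRING_SCHEMA, "hello");
        SchemaAndValue roundTrip = converter.toConnectData("events", bytes);
        System.out.println(roundTrip.value()); // hello (with an optional string schema)
    }
}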
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.storage; + +import org.apache.kafka.common.config.ConfigDef; +import org.apache.kafka.common.config.ConfigDef.Importance; +import org.apache.kafka.common.config.ConfigDef.Type; +import org.apache.kafka.common.config.ConfigDef.Width; + +import java.nio.charset.StandardCharsets; +import java.util.Map; + +/** + * Configuration options for {@link StringConverter} instances. + */ +public class StringConverterConfig extends ConverterConfig { + + public static final String ENCODING_CONFIG = "converter.encoding"; + public static final String ENCODING_DEFAULT = StandardCharsets.UTF_8.name(); + private static final String ENCODING_DOC = "The name of the Java character set to use for encoding strings as byte arrays."; + private static final String ENCODING_DISPLAY = "Encoding"; + + private final static ConfigDef CONFIG; + + static { + CONFIG = ConverterConfig.newConfigDef(); + CONFIG.define(ENCODING_CONFIG, Type.STRING, ENCODING_DEFAULT, Importance.HIGH, ENCODING_DOC, null, -1, Width.MEDIUM, + ENCODING_DISPLAY); + } + + public static ConfigDef configDef() { + return CONFIG; + } + + public StringConverterConfig(Map props) { + super(CONFIG, props); + } + + /** + * Get the string encoding. + * + * @return the encoding; never null + */ + public String encoding() { + return getString(ENCODING_CONFIG); + } +} diff --git a/connect/api/src/main/java/org/apache/kafka/connect/transforms/Transformation.java b/connect/api/src/main/java/org/apache/kafka/connect/transforms/Transformation.java new file mode 100644 index 0000000..238a642 --- /dev/null +++ b/connect/api/src/main/java/org/apache/kafka/connect/transforms/Transformation.java @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.transforms; + +import org.apache.kafka.common.Configurable; +import org.apache.kafka.common.config.ConfigDef; +import org.apache.kafka.connect.connector.ConnectRecord; + +import java.io.Closeable; + +/** + * Single message transformation for Kafka Connect record types. + * + * Connectors can be configured with transformations to make lightweight message-at-a-time modifications. 
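Before the Transformation interface declaration that follows, here is a hedged sketch of what an implementation might look like: a hypothetical transform that upper-cases String values and, per the contract below, never mutates the incoming record.

import org.apache.kafka.common.config.ConfigDef;
import org.apache.kafka.connect.connector.ConnectRecord;
import org.apache.kafka.connect.transforms.Transformation;

import java.util.Locale;
import java.util.Map;

public class UpperCaseValue<R extends ConnectRecord<R>> implements Transformation<R> {

    @Override
    public void configure(Map<String, ?> configs) {
        // no options in this sketch
    }

    @Override
    public R apply(R record) {
        if (!(record.value() instanceof String)) {
            return record; // pass anything that is not a String value through unchanged
        }
        String upper = ((String) record.value()).toUpperCase(Locale.ROOT);
        // Build a new record rather than mutating the input.
        return record.newRecord(record.topic(), record.kafkaPartition(),
                record.keySchema(), record.key(),
                record.valueSchema(), upper,
                record.timestamp());
    }

    @Override
    public ConfigDef config() {
        return new ConfigDef();
    }

    @Override
    public void close() {
        // nothing to release
    }
}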
+ */ +public interface Transformation> extends Configurable, Closeable { + + /** + * Apply transformation to the {@code record} and return another record object (which may be {@code record} itself) or {@code null}, + * corresponding to a map or filter operation respectively. + * + * A transformation must not mutate objects reachable from the given {@code record} + * (including, but not limited to, {@link org.apache.kafka.connect.header.Headers Headers}, + * {@link org.apache.kafka.connect.data.Struct Structs}, {@code Lists}, and {@code Maps}). + * If such objects need to be changed, a new ConnectRecord should be created and returned. + * + * The implementation must be thread-safe. + */ + R apply(R record); + + /** Configuration specification for this transformation. **/ + ConfigDef config(); + + /** Signal that this transformation instance will no longer will be used. **/ + @Override + void close(); + +} diff --git a/connect/api/src/main/java/org/apache/kafka/connect/transforms/predicates/Predicate.java b/connect/api/src/main/java/org/apache/kafka/connect/transforms/predicates/Predicate.java new file mode 100644 index 0000000..cc38fc0 --- /dev/null +++ b/connect/api/src/main/java/org/apache/kafka/connect/transforms/predicates/Predicate.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.transforms.predicates; + +import org.apache.kafka.common.Configurable; +import org.apache.kafka.common.config.ConfigDef; +import org.apache.kafka.connect.connector.ConnectRecord; + +/** + *
            A predicate on records. + * Predicates can be used to conditionally apply a {@link org.apache.kafka.connect.transforms.Transformation} + * by configuring the transformation's {@code predicate} (and {@code negate}) configuration parameters. + * In particular, the {@code Filter} transformation can be conditionally applied in order to filter + * certain records from further processing. + * + *
            Implementations of this interface must be public and have a public constructor with no parameters. + * + * @param The type of record. + */ +public interface Predicate> extends Configurable, AutoCloseable { + + /** + * Configuration specification for this predicate. + * + * @return the configuration definition for this predicate; never null + */ + ConfigDef config(); + + /** + * Returns whether the given record satisfies this predicate. + * + * @param record the record to evaluate; may not be null + * @return true if the predicate matches, or false otherwise + */ + boolean test(R record); + + @Override + void close(); +} diff --git a/connect/api/src/main/java/org/apache/kafka/connect/util/ConnectorUtils.java b/connect/api/src/main/java/org/apache/kafka/connect/util/ConnectorUtils.java new file mode 100644 index 0000000..7c09093 --- /dev/null +++ b/connect/api/src/main/java/org/apache/kafka/connect/util/ConnectorUtils.java @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.util; + +import java.util.ArrayList; +import java.util.List; + +/** + * Utilities that connector implementations might find useful. Contains common building blocks + * for writing connectors. + */ +public class ConnectorUtils { + /** + * Given a list of elements and a target number of groups, generates list of groups of + * elements to match the target number of groups, spreading them evenly among the groups. + * This generates groups with contiguous elements, which results in intuitive ordering if + * your elements are also ordered (e.g. alphabetical lists of table names if you sort + * table names alphabetically to generate the raw partitions) or can result in efficient + * partitioning if elements are sorted according to some criteria that affects performance + * (e.g. topic partitions with the same leader). + * + * @param elements list of elements to partition + * @param numGroups the number of output groups to generate. + */ + public static List> groupPartitions(List elements, int numGroups) { + if (numGroups <= 0) + throw new IllegalArgumentException("Number of groups must be positive."); + + List> result = new ArrayList<>(numGroups); + + // Each group has either n+1 or n raw partitions + int perGroup = elements.size() / numGroups; + int leftover = elements.size() - (numGroups * perGroup); + + int assigned = 0; + for (int group = 0; group < numGroups; group++) { + int numThisGroup = group < leftover ? 
perGroup + 1 : perGroup; + List groupList = new ArrayList<>(numThisGroup); + for (int i = 0; i < numThisGroup; i++) { + groupList.add(elements.get(assigned)); + assigned++; + } + result.add(groupList); + } + + return result; + } +} diff --git a/connect/api/src/test/java/org/apache/kafka/connect/connector/ConnectorReconfigurationTest.java b/connect/api/src/test/java/org/apache/kafka/connect/connector/ConnectorReconfigurationTest.java new file mode 100644 index 0000000..b895ed3 --- /dev/null +++ b/connect/api/src/test/java/org/apache/kafka/connect/connector/ConnectorReconfigurationTest.java @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.connector; + +import org.apache.kafka.common.config.ConfigDef; +import org.apache.kafka.connect.errors.ConnectException; + +import org.junit.jupiter.api.Test; + +import java.util.Collections; +import java.util.List; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +public class ConnectorReconfigurationTest { + + @Test + public void testDefaultReconfigure() { + TestConnector conn = new TestConnector(false); + conn.reconfigure(Collections.emptyMap()); + assertEquals(conn.stopOrder, 0); + assertEquals(conn.configureOrder, 1); + } + + @Test + public void testReconfigureStopException() { + TestConnector conn = new TestConnector(true); + assertThrows(ConnectException.class, () -> conn.reconfigure(Collections.emptyMap())); + } + + private static class TestConnector extends Connector { + + private boolean stopException; + private int order = 0; + public int stopOrder = -1; + public int configureOrder = -1; + + public TestConnector(boolean stopException) { + this.stopException = stopException; + } + + @Override + public String version() { + return "1.0"; + } + + @Override + public void start(Map props) { + configureOrder = order++; + } + + @Override + public Class taskClass() { + return null; + } + + @Override + public List> taskConfigs(int count) { + return null; + } + + @Override + public void stop() { + stopOrder = order++; + if (stopException) + throw new ConnectException("error"); + } + + @Override + public ConfigDef config() { + return new ConfigDef(); + } + } +} diff --git a/connect/api/src/test/java/org/apache/kafka/connect/connector/ConnectorTest.java b/connect/api/src/test/java/org/apache/kafka/connect/connector/ConnectorTest.java new file mode 100644 index 0000000..ce0c1d4 --- /dev/null +++ b/connect/api/src/test/java/org/apache/kafka/connect/connector/ConnectorTest.java @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.connector; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +public abstract class ConnectorTest { + + protected ConnectorContext context; + protected Connector connector; + protected AssertableConnector assertableConnector; + + @BeforeEach + public void beforeEach() { + connector = createConnector(); + context = createContext(); + assertableConnector = (AssertableConnector) connector; + } + + @Test + public void shouldInitializeContext() { + connector.initialize(context); + assertableConnector.assertInitialized(); + assertableConnector.assertContext(context); + assertableConnector.assertTaskConfigs(null); + } + + @Test + public void shouldInitializeContextWithTaskConfigs() { + List> taskConfigs = new ArrayList<>(); + connector.initialize(context, taskConfigs); + assertableConnector.assertInitialized(); + assertableConnector.assertContext(context); + assertableConnector.assertTaskConfigs(taskConfigs); + } + + @Test + public void shouldStopAndStartWhenReconfigure() { + Map props = new HashMap<>(); + connector.initialize(context); + assertableConnector.assertContext(context); + assertableConnector.assertStarted(false); + assertableConnector.assertStopped(false); + connector.reconfigure(props); + assertableConnector.assertStarted(true); + assertableConnector.assertStopped(true); + assertableConnector.assertProperties(props); + } + + protected abstract ConnectorContext createContext(); + + protected abstract Connector createConnector(); + + public interface AssertableConnector { + + void assertContext(ConnectorContext expected); + + void assertInitialized(); + + void assertTaskConfigs(List> expectedTaskConfigs); + + void assertStarted(boolean expected); + + void assertStopped(boolean expected); + + void assertProperties(Map expected); + } +} \ No newline at end of file diff --git a/connect/api/src/test/java/org/apache/kafka/connect/data/ConnectSchemaTest.java b/connect/api/src/test/java/org/apache/kafka/connect/data/ConnectSchemaTest.java new file mode 100644 index 0000000..25e6db3 --- /dev/null +++ b/connect/api/src/test/java/org/apache/kafka/connect/data/ConnectSchemaTest.java @@ -0,0 +1,333 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
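Referring back to ConnectorUtils.groupPartitions above, a small illustrative run (with made-up inputs) shows the contiguous grouping it produces: the first groups absorb the leftover elements.

import org.apache.kafka.connect.util.ConnectorUtils;

import java.util.Arrays;
import java.util.List;

public class GroupPartitionsExample {
    public static void main(String[] args) {
        List<String> tables = Arrays.asList("a", "b", "c", "d", "e", "f", "g");
        // 7 elements into 3 groups: sizes 3, 2, 2, contiguous and in order.
        List<List<String>> groups = ConnectorUtils.groupPartitions(tables, 3);
        System.out.println(groups); // [[a, b, c], [d, e], [f, g]]
    }
}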
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.data; + +import org.apache.kafka.connect.errors.DataException; +import org.junit.jupiter.api.Test; + +import java.math.BigDecimal; +import java.math.BigInteger; +import java.nio.ByteBuffer; +import java.nio.CharBuffer; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; + +public class ConnectSchemaTest { + private static final Schema MAP_INT_STRING_SCHEMA = SchemaBuilder.map(Schema.INT32_SCHEMA, Schema.STRING_SCHEMA).build(); + private static final Schema FLAT_STRUCT_SCHEMA = SchemaBuilder.struct() + .field("field", Schema.INT32_SCHEMA) + .build(); + private static final Schema STRUCT_SCHEMA = SchemaBuilder.struct() + .field("first", Schema.INT32_SCHEMA) + .field("second", Schema.STRING_SCHEMA) + .field("array", SchemaBuilder.array(Schema.INT32_SCHEMA).build()) + .field("map", SchemaBuilder.map(Schema.INT32_SCHEMA, Schema.STRING_SCHEMA).build()) + .field("nested", FLAT_STRUCT_SCHEMA) + .build(); + private static final Schema PARENT_STRUCT_SCHEMA = SchemaBuilder.struct() + .field("nested", FLAT_STRUCT_SCHEMA) + .build(); + + @Test + public void testFieldsOnStructSchema() { + Schema schema = SchemaBuilder.struct() + .field("foo", Schema.BOOLEAN_SCHEMA) + .field("bar", Schema.INT32_SCHEMA) + .build(); + + assertEquals(2, schema.fields().size()); + // Validate field lookup by name + Field foo = schema.field("foo"); + assertEquals(0, foo.index()); + Field bar = schema.field("bar"); + assertEquals(1, bar.index()); + // Any other field name should fail + assertNull(schema.field("other")); + } + + + @Test + public void testFieldsOnlyValidForStructs() { + assertThrows(DataException.class, + Schema.INT8_SCHEMA::fields); + } + + @Test + public void testValidateValueMatchingType() { + ConnectSchema.validateValue(Schema.INT8_SCHEMA, (byte) 1); + ConnectSchema.validateValue(Schema.INT16_SCHEMA, (short) 1); + ConnectSchema.validateValue(Schema.INT32_SCHEMA, 1); + ConnectSchema.validateValue(Schema.INT64_SCHEMA, (long) 1); + ConnectSchema.validateValue(Schema.FLOAT32_SCHEMA, 1.f); + ConnectSchema.validateValue(Schema.FLOAT64_SCHEMA, 1.); + ConnectSchema.validateValue(Schema.BOOLEAN_SCHEMA, true); + ConnectSchema.validateValue(Schema.STRING_SCHEMA, "a string"); + ConnectSchema.validateValue(Schema.BYTES_SCHEMA, "a byte array".getBytes()); + ConnectSchema.validateValue(Schema.BYTES_SCHEMA, ByteBuffer.wrap("a byte array".getBytes())); + ConnectSchema.validateValue(SchemaBuilder.array(Schema.INT32_SCHEMA).build(), Arrays.asList(1, 2, 3)); + ConnectSchema.validateValue( + SchemaBuilder.map(Schema.INT32_SCHEMA, Schema.STRING_SCHEMA).build(), + Collections.singletonMap(1, "value") + ); + // Struct tests the basic struct layout + complex field types + nested structs + Struct structValue = new Struct(STRUCT_SCHEMA) + .put("first", 1) + .put("second", "foo") + 
.put("array", Arrays.asList(1, 2, 3)) + .put("map", Collections.singletonMap(1, "value")) + .put("nested", new Struct(FLAT_STRUCT_SCHEMA).put("field", 12)); + ConnectSchema.validateValue(STRUCT_SCHEMA, structValue); + } + + @Test + public void testValidateValueMatchingLogicalType() { + ConnectSchema.validateValue(Decimal.schema(2), new BigDecimal(new BigInteger("156"), 2)); + ConnectSchema.validateValue(Date.SCHEMA, new java.util.Date(0)); + ConnectSchema.validateValue(Time.SCHEMA, new java.util.Date(0)); + ConnectSchema.validateValue(Timestamp.SCHEMA, new java.util.Date(0)); + } + + // To avoid requiring excessive numbers of tests, these checks for invalid types use a similar type where possible + // to only include a single test for each type + + @Test + public void testValidateValueMismatchInt8() { + assertThrows(DataException.class, + () -> ConnectSchema.validateValue(Schema.INT8_SCHEMA, 1)); + } + + @Test + public void testValidateValueMismatchInt16() { + assertThrows(DataException.class, + () -> ConnectSchema.validateValue(Schema.INT16_SCHEMA, 1)); + } + + @Test + public void testValidateValueMismatchInt32() { + assertThrows(DataException.class, + () -> ConnectSchema.validateValue(Schema.INT32_SCHEMA, (long) 1)); + } + + @Test + public void testValidateValueMismatchInt64() { + assertThrows(DataException.class, + () -> ConnectSchema.validateValue(Schema.INT64_SCHEMA, 1)); + } + + @Test + public void testValidateValueMismatchFloat() { + assertThrows(DataException.class, + () -> ConnectSchema.validateValue(Schema.FLOAT32_SCHEMA, 1.0)); + } + + @Test + public void testValidateValueMismatchDouble() { + assertThrows(DataException.class, + () -> ConnectSchema.validateValue(Schema.FLOAT64_SCHEMA, 1.f)); + } + + @Test + public void testValidateValueMismatchBoolean() { + assertThrows(DataException.class, + () -> ConnectSchema.validateValue(Schema.BOOLEAN_SCHEMA, 1.f)); + } + + @Test + public void testValidateValueMismatchString() { + // CharSequence is a similar type (supertype of String), but we restrict to String. + CharBuffer cbuf = CharBuffer.wrap("abc"); + assertThrows(DataException.class, + () -> ConnectSchema.validateValue(Schema.STRING_SCHEMA, cbuf)); + } + + @Test + public void testValidateValueMismatchBytes() { + assertThrows(DataException.class, + () -> ConnectSchema.validateValue(Schema.BYTES_SCHEMA, new Object[]{1, "foo"})); + } + + @Test + public void testValidateValueMismatchArray() { + assertThrows(DataException.class, + () -> ConnectSchema.validateValue(SchemaBuilder.array(Schema.INT32_SCHEMA).build(), Arrays.asList("a", "b", "c"))); + } + + @Test + public void testValidateValueMismatchArraySomeMatch() { + // Even if some match the right type, this should fail if any mismatch. 
In this case, type erasure loses + // the fact that the list is actually List, but we couldn't tell if only checking the first element + assertThrows(DataException.class, + () -> ConnectSchema.validateValue(SchemaBuilder.array(Schema.INT32_SCHEMA).build(), Arrays.asList(1, 2, "c"))); + } + + @Test + public void testValidateValueMismatchMapKey() { + assertThrows(DataException.class, + () -> ConnectSchema.validateValue(MAP_INT_STRING_SCHEMA, Collections.singletonMap("wrong key type", "value"))); + } + + @Test + public void testValidateValueMismatchMapValue() { + assertThrows(DataException.class, + () -> ConnectSchema.validateValue(MAP_INT_STRING_SCHEMA, Collections.singletonMap(1, 2))); + } + + @Test + public void testValidateValueMismatchMapSomeKeys() { + Map data = new HashMap<>(); + data.put(1, "abc"); + data.put("wrong", "it's as easy as one two three"); + assertThrows(DataException.class, + () -> ConnectSchema.validateValue(MAP_INT_STRING_SCHEMA, data)); + } + + @Test + public void testValidateValueMismatchMapSomeValues() { + Map data = new HashMap<>(); + data.put(1, "abc"); + data.put(2, "wrong".getBytes()); + assertThrows(DataException.class, + () -> ConnectSchema.validateValue(MAP_INT_STRING_SCHEMA, data)); + } + + @Test + public void testValidateValueMismatchStructWrongSchema() { + // Completely mismatching schemas + assertThrows(DataException.class, () -> ConnectSchema.validateValue(FLAT_STRUCT_SCHEMA, + new Struct(SchemaBuilder.struct().field("x", Schema.INT32_SCHEMA).build()).put("x", 1))); + } + + @Test + public void testValidateValueMismatchStructWrongNestedSchema() { + // Top-level schema matches, but nested does not. + assertThrows(DataException.class, () -> ConnectSchema.validateValue(PARENT_STRUCT_SCHEMA, + new Struct(PARENT_STRUCT_SCHEMA) + .put("nested", new Struct(SchemaBuilder.struct() + .field("x", Schema.INT32_SCHEMA).build()).put("x", 1)))); + } + + @Test + public void testValidateValueMismatchDecimal() { + assertThrows(DataException.class, () -> ConnectSchema.validateValue(Decimal.schema(2), new BigInteger("156"))); + } + + @Test + public void testValidateValueMismatchDate() { + assertThrows(DataException.class, () -> ConnectSchema.validateValue(Date.SCHEMA, 1000L)); + } + + @Test + public void testValidateValueMismatchTime() { + assertThrows(DataException.class, () -> ConnectSchema.validateValue(Time.SCHEMA, 1000L)); + } + + @Test + public void testValidateValueMismatchTimestamp() { + assertThrows(DataException.class, () -> ConnectSchema.validateValue(Timestamp.SCHEMA, 1000L)); + } + + @Test + public void testPrimitiveEquality() { + // Test that primitive types, which only need to consider all the type & metadata fields, handle equality correctly + ConnectSchema s1 = new ConnectSchema(Schema.Type.INT8, false, null, "name", 2, "doc"); + ConnectSchema s2 = new ConnectSchema(Schema.Type.INT8, false, null, "name", 2, "doc"); + ConnectSchema differentType = new ConnectSchema(Schema.Type.INT16, false, null, "name", 2, "doc"); + ConnectSchema differentOptional = new ConnectSchema(Schema.Type.INT8, true, null, "name", 2, "doc"); + ConnectSchema differentDefault = new ConnectSchema(Schema.Type.INT8, false, true, "name", 2, "doc"); + ConnectSchema differentName = new ConnectSchema(Schema.Type.INT8, false, null, "otherName", 2, "doc"); + ConnectSchema differentVersion = new ConnectSchema(Schema.Type.INT8, false, null, "name", 4, "doc"); + ConnectSchema differentDoc = new ConnectSchema(Schema.Type.INT8, false, null, "name", 2, "other doc"); + ConnectSchema differentParameters 
= new ConnectSchema(Schema.Type.INT8, false, null, "name", 2, "doc", Collections.singletonMap("param", "value"), null, null, null); + + assertEquals(s1, s2); + assertNotEquals(s1, differentType); + assertNotEquals(s1, differentOptional); + assertNotEquals(s1, differentDefault); + assertNotEquals(s1, differentName); + assertNotEquals(s1, differentVersion); + assertNotEquals(s1, differentDoc); + assertNotEquals(s1, differentParameters); + } + + @Test + public void testArrayEquality() { + // Validate that the value type for the array is tested for equality. This test makes sure the same schema object is + // never reused to ensure we're actually checking equality + ConnectSchema s1 = new ConnectSchema(Schema.Type.ARRAY, false, null, null, null, null, null, null, null, SchemaBuilder.int8().build()); + ConnectSchema s2 = new ConnectSchema(Schema.Type.ARRAY, false, null, null, null, null, null, null, null, SchemaBuilder.int8().build()); + ConnectSchema differentValueSchema = new ConnectSchema(Schema.Type.ARRAY, false, null, null, null, null, null, null, null, SchemaBuilder.int16().build()); + + assertEquals(s1, s2); + assertNotEquals(s1, differentValueSchema); + } + + @Test + public void testArrayDefaultValueEquality() { + ConnectSchema s1 = new ConnectSchema(Schema.Type.ARRAY, false, new String[] {"a", "b"}, null, null, null, null, null, null, SchemaBuilder.int8().build()); + ConnectSchema s2 = new ConnectSchema(Schema.Type.ARRAY, false, new String[] {"a", "b"}, null, null, null, null, null, null, SchemaBuilder.int8().build()); + ConnectSchema differentValueSchema = new ConnectSchema(Schema.Type.ARRAY, false, new String[] {"b", "c"}, null, null, null, null, null, null, SchemaBuilder.int8().build()); + + assertEquals(s1, s2); + assertNotEquals(s1, differentValueSchema); + } + + @Test + public void testMapEquality() { + // Same as testArrayEquality, but for both key and value schemas + ConnectSchema s1 = new ConnectSchema(Schema.Type.MAP, false, null, null, null, null, null, null, SchemaBuilder.int8().build(), SchemaBuilder.int16().build()); + ConnectSchema s2 = new ConnectSchema(Schema.Type.MAP, false, null, null, null, null, null, null, SchemaBuilder.int8().build(), SchemaBuilder.int16().build()); + ConnectSchema differentKeySchema = new ConnectSchema(Schema.Type.MAP, false, null, null, null, null, null, null, SchemaBuilder.string().build(), SchemaBuilder.int16().build()); + ConnectSchema differentValueSchema = new ConnectSchema(Schema.Type.MAP, false, null, null, null, null, null, null, SchemaBuilder.int8().build(), SchemaBuilder.string().build()); + + assertEquals(s1, s2); + assertNotEquals(s1, differentKeySchema); + assertNotEquals(s1, differentValueSchema); + } + + @Test + public void testStructEquality() { + // Same as testArrayEquality, but checks differences in fields. 
Only does a simple check, relying on tests of + // Field's equals() method to validate all variations in the list of fields will be checked + ConnectSchema s1 = new ConnectSchema(Schema.Type.STRUCT, false, null, null, null, null, null, + Arrays.asList(new Field("field", 0, SchemaBuilder.int8().build()), + new Field("field2", 1, SchemaBuilder.int16().build())), null, null); + ConnectSchema s2 = new ConnectSchema(Schema.Type.STRUCT, false, null, null, null, null, null, + Arrays.asList(new Field("field", 0, SchemaBuilder.int8().build()), + new Field("field2", 1, SchemaBuilder.int16().build())), null, null); + ConnectSchema differentField = new ConnectSchema(Schema.Type.STRUCT, false, null, null, null, null, null, + Arrays.asList(new Field("field", 0, SchemaBuilder.int8().build()), + new Field("different field name", 1, SchemaBuilder.int16().build())), null, null); + + assertEquals(s1, s2); + assertNotEquals(s1, differentField); + } + + @Test + public void testEmptyStruct() { + final ConnectSchema emptyStruct = new ConnectSchema(Schema.Type.STRUCT, false, null, null, null, null); + assertEquals(0, emptyStruct.fields().size()); + new Struct(emptyStruct); + } + +} diff --git a/connect/api/src/test/java/org/apache/kafka/connect/data/DateTest.java b/connect/api/src/test/java/org/apache/kafka/connect/data/DateTest.java new file mode 100644 index 0000000..2cd656c --- /dev/null +++ b/connect/api/src/test/java/org/apache/kafka/connect/data/DateTest.java @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.connect.data; + +import org.apache.kafka.connect.errors.DataException; +import org.junit.jupiter.api.Test; + +import java.util.Calendar; +import java.util.GregorianCalendar; +import java.util.TimeZone; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +public class DateTest { + private static final GregorianCalendar EPOCH; + private static final GregorianCalendar EPOCH_PLUS_TEN_THOUSAND_DAYS; + private static final GregorianCalendar EPOCH_PLUS_TIME_COMPONENT; + static { + EPOCH = new GregorianCalendar(1970, Calendar.JANUARY, 1, 0, 0, 0); + EPOCH.setTimeZone(TimeZone.getTimeZone("UTC")); + + EPOCH_PLUS_TIME_COMPONENT = new GregorianCalendar(1970, Calendar.JANUARY, 1, 0, 0, 1); + EPOCH_PLUS_TIME_COMPONENT.setTimeZone(TimeZone.getTimeZone("UTC")); + + EPOCH_PLUS_TEN_THOUSAND_DAYS = new GregorianCalendar(1970, Calendar.JANUARY, 1, 0, 0, 0); + EPOCH_PLUS_TEN_THOUSAND_DAYS.setTimeZone(TimeZone.getTimeZone("UTC")); + EPOCH_PLUS_TEN_THOUSAND_DAYS.add(Calendar.DATE, 10000); + } + + @Test + public void testBuilder() { + Schema plain = Date.SCHEMA; + assertEquals(Date.LOGICAL_NAME, plain.name()); + assertEquals(1, (Object) plain.version()); + } + + @Test + public void testFromLogical() { + assertEquals(0, Date.fromLogical(Date.SCHEMA, EPOCH.getTime())); + assertEquals(10000, Date.fromLogical(Date.SCHEMA, EPOCH_PLUS_TEN_THOUSAND_DAYS.getTime())); + } + + @Test + public void testFromLogicalInvalidSchema() { + assertThrows(DataException.class, + () -> Date.fromLogical(Date.builder().name("invalid").build(), EPOCH.getTime())); + } + + @Test + public void testFromLogicalInvalidHasTimeComponents() { + assertThrows(DataException.class, + () -> Date.fromLogical(Date.SCHEMA, EPOCH_PLUS_TIME_COMPONENT.getTime())); + } + + @Test + public void testToLogical() { + assertEquals(EPOCH.getTime(), Date.toLogical(Date.SCHEMA, 0)); + assertEquals(EPOCH_PLUS_TEN_THOUSAND_DAYS.getTime(), Date.toLogical(Date.SCHEMA, 10000)); + } + + @Test + public void testToLogicalInvalidSchema() { + assertThrows(DataException.class, + () -> Date.toLogical(Date.builder().name("invalid").build(), 0)); + } +} diff --git a/connect/api/src/test/java/org/apache/kafka/connect/data/DecimalTest.java b/connect/api/src/test/java/org/apache/kafka/connect/data/DecimalTest.java new file mode 100644 index 0000000..9592fb9 --- /dev/null +++ b/connect/api/src/test/java/org/apache/kafka/connect/data/DecimalTest.java @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.connect.data; + +import org.junit.jupiter.api.Test; + +import java.math.BigDecimal; +import java.math.BigInteger; +import java.util.Collections; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class DecimalTest { + private static final int TEST_SCALE = 2; + private static final BigDecimal TEST_DECIMAL = new BigDecimal(new BigInteger("156"), TEST_SCALE); + private static final BigDecimal TEST_DECIMAL_NEGATIVE = new BigDecimal(new BigInteger("-156"), TEST_SCALE); + private static final byte[] TEST_BYTES = new byte[]{0, -100}; + private static final byte[] TEST_BYTES_NEGATIVE = new byte[]{-1, 100}; + + @Test + public void testBuilder() { + Schema plain = Decimal.builder(2).build(); + assertEquals(Decimal.LOGICAL_NAME, plain.name()); + assertEquals(Collections.singletonMap(Decimal.SCALE_FIELD, "2"), plain.parameters()); + assertEquals(1, (Object) plain.version()); + } + + @Test + public void testFromLogical() { + Schema schema = Decimal.schema(TEST_SCALE); + byte[] encoded = Decimal.fromLogical(schema, TEST_DECIMAL); + assertArrayEquals(TEST_BYTES, encoded); + + encoded = Decimal.fromLogical(schema, TEST_DECIMAL_NEGATIVE); + assertArrayEquals(TEST_BYTES_NEGATIVE, encoded); + } + + @Test + public void testToLogical() { + Schema schema = Decimal.schema(2); + BigDecimal converted = Decimal.toLogical(schema, TEST_BYTES); + assertEquals(TEST_DECIMAL, converted); + + converted = Decimal.toLogical(schema, TEST_BYTES_NEGATIVE); + assertEquals(TEST_DECIMAL_NEGATIVE, converted); + } +} diff --git a/connect/api/src/test/java/org/apache/kafka/connect/data/FakeSchema.java b/connect/api/src/test/java/org/apache/kafka/connect/data/FakeSchema.java new file mode 100644 index 0000000..bc6fe3e --- /dev/null +++ b/connect/api/src/test/java/org/apache/kafka/connect/data/FakeSchema.java @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.connect.data; + +import java.util.List; +import java.util.Map; + +public class FakeSchema implements Schema { + @Override + public Type type() { + return null; + } + + @Override + public boolean isOptional() { + return false; + } + + @Override + public Object defaultValue() { + return null; + } + + @Override + public String name() { + return "fake"; + } + + @Override + public Integer version() { + return null; + } + + @Override + public String doc() { + return null; + } + + @Override + public Map parameters() { + return null; + } + + @Override + public Schema keySchema() { + return null; + } + + @Override + public Schema valueSchema() { + return null; + } + + @Override + public List fields() { + return null; + } + + @Override + public Field field(String fieldName) { + return null; + } + + @Override + public Schema schema() { + return null; + } +} diff --git a/connect/api/src/test/java/org/apache/kafka/connect/data/FieldTest.java b/connect/api/src/test/java/org/apache/kafka/connect/data/FieldTest.java new file mode 100644 index 0000000..6b2ffa4 --- /dev/null +++ b/connect/api/src/test/java/org/apache/kafka/connect/data/FieldTest.java @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.data; + +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotEquals; + +public class FieldTest { + + @Test + public void testEquality() { + Field field1 = new Field("name", 0, Schema.INT8_SCHEMA); + Field field2 = new Field("name", 0, Schema.INT8_SCHEMA); + Field differentName = new Field("name2", 0, Schema.INT8_SCHEMA); + Field differentIndex = new Field("name", 1, Schema.INT8_SCHEMA); + Field differentSchema = new Field("name", 0, Schema.INT16_SCHEMA); + + assertEquals(field1, field2); + assertNotEquals(field1, differentName); + assertNotEquals(field1, differentIndex); + assertNotEquals(field1, differentSchema); + } +} diff --git a/connect/api/src/test/java/org/apache/kafka/connect/data/SchemaBuilderTest.java b/connect/api/src/test/java/org/apache/kafka/connect/data/SchemaBuilderTest.java new file mode 100644 index 0000000..ba7c574 --- /dev/null +++ b/connect/api/src/test/java/org/apache/kafka/connect/data/SchemaBuilderTest.java @@ -0,0 +1,381 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.data; + +import org.apache.kafka.connect.errors.SchemaBuilderException; +import org.junit.jupiter.api.Test; + +import java.nio.ByteBuffer; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; + +public class SchemaBuilderTest { + private static final String NAME = "name"; + private static final Integer VERSION = 2; + private static final String DOC = "doc"; + private static final Map NO_PARAMS = null; + + @Test + public void testInt8Builder() { + Schema schema = SchemaBuilder.int8().build(); + assertTypeAndDefault(schema, Schema.Type.INT8, false, null); + assertNoMetadata(schema); + + schema = SchemaBuilder.int8().name(NAME).optional().defaultValue((byte) 12) + .version(VERSION).doc(DOC).build(); + assertTypeAndDefault(schema, Schema.Type.INT8, true, (byte) 12); + assertMetadata(schema, NAME, VERSION, DOC, NO_PARAMS); + } + + @Test + public void testInt8BuilderInvalidDefault() { + assertThrows(SchemaBuilderException.class, () -> SchemaBuilder.int8().defaultValue("invalid")); + } + + @Test + public void testInt16Builder() { + Schema schema = SchemaBuilder.int16().build(); + assertTypeAndDefault(schema, Schema.Type.INT16, false, null); + assertNoMetadata(schema); + + schema = SchemaBuilder.int16().name(NAME).optional().defaultValue((short) 12) + .version(VERSION).doc(DOC).build(); + assertTypeAndDefault(schema, Schema.Type.INT16, true, (short) 12); + assertMetadata(schema, NAME, VERSION, DOC, NO_PARAMS); + } + + @Test + public void testInt16BuilderInvalidDefault() { + assertThrows(SchemaBuilderException.class, () -> SchemaBuilder.int16().defaultValue("invalid")); + } + + @Test + public void testInt32Builder() { + Schema schema = SchemaBuilder.int32().build(); + assertTypeAndDefault(schema, Schema.Type.INT32, false, null); + assertNoMetadata(schema); + + schema = SchemaBuilder.int32().name(NAME).optional().defaultValue(12) + .version(VERSION).doc(DOC).build(); + assertTypeAndDefault(schema, Schema.Type.INT32, true, 12); + assertMetadata(schema, NAME, VERSION, DOC, NO_PARAMS); + } + + @Test + public void testInt32BuilderInvalidDefault() { + assertThrows(SchemaBuilderException.class, () -> SchemaBuilder.int32().defaultValue("invalid")); + } + + @Test + public void testInt64Builder() { + Schema schema = SchemaBuilder.int64().build(); + assertTypeAndDefault(schema, Schema.Type.INT64, false, null); + assertNoMetadata(schema); + + schema = SchemaBuilder.int64().name(NAME).optional().defaultValue((long) 12) + .version(VERSION).doc(DOC).build(); + assertTypeAndDefault(schema, Schema.Type.INT64, true, (long) 12); + assertMetadata(schema, NAME, VERSION, DOC, NO_PARAMS); + } + + @Test + public void testInt64BuilderInvalidDefault() { + assertThrows(SchemaBuilderException.class, () -> SchemaBuilder.int64().defaultValue("invalid")); + } + + @Test + public void testFloatBuilder() { + Schema schema 
= SchemaBuilder.float32().build(); + assertTypeAndDefault(schema, Schema.Type.FLOAT32, false, null); + assertNoMetadata(schema); + + schema = SchemaBuilder.float32().name(NAME).optional().defaultValue(12.f) + .version(VERSION).doc(DOC).build(); + assertTypeAndDefault(schema, Schema.Type.FLOAT32, true, 12.f); + assertMetadata(schema, NAME, VERSION, DOC, NO_PARAMS); + } + + @Test + public void testFloatBuilderInvalidDefault() { + assertThrows(SchemaBuilderException.class, () -> SchemaBuilder.float32().defaultValue("invalid")); + } + + @Test + public void testDoubleBuilder() { + Schema schema = SchemaBuilder.float64().build(); + assertTypeAndDefault(schema, Schema.Type.FLOAT64, false, null); + assertNoMetadata(schema); + + schema = SchemaBuilder.float64().name(NAME).optional().defaultValue(12.0) + .version(VERSION).doc(DOC).build(); + assertTypeAndDefault(schema, Schema.Type.FLOAT64, true, 12.0); + assertMetadata(schema, NAME, VERSION, DOC, NO_PARAMS); + } + + @Test + public void testDoubleBuilderInvalidDefault() { + assertThrows(SchemaBuilderException.class, () -> SchemaBuilder.float64().defaultValue("invalid")); + } + + @Test + public void testBooleanBuilder() { + Schema schema = SchemaBuilder.bool().build(); + assertTypeAndDefault(schema, Schema.Type.BOOLEAN, false, null); + assertNoMetadata(schema); + + schema = SchemaBuilder.bool().name(NAME).optional().defaultValue(true) + .version(VERSION).doc(DOC).build(); + assertTypeAndDefault(schema, Schema.Type.BOOLEAN, true, true); + assertMetadata(schema, NAME, VERSION, DOC, NO_PARAMS); + } + + @Test + public void testBooleanBuilderInvalidDefault() { + assertThrows(SchemaBuilderException.class, () -> SchemaBuilder.bool().defaultValue("invalid")); + } + + @Test + public void testStringBuilder() { + Schema schema = SchemaBuilder.string().build(); + assertTypeAndDefault(schema, Schema.Type.STRING, false, null); + assertNoMetadata(schema); + + schema = SchemaBuilder.string().name(NAME).optional().defaultValue("a default string") + .version(VERSION).doc(DOC).build(); + assertTypeAndDefault(schema, Schema.Type.STRING, true, "a default string"); + assertMetadata(schema, NAME, VERSION, DOC, NO_PARAMS); + } + + @Test + public void testStringBuilderInvalidDefault() { + assertThrows(SchemaBuilderException.class, () -> SchemaBuilder.string().defaultValue(true)); + } + + @Test + public void testBytesBuilder() { + Schema schema = SchemaBuilder.bytes().build(); + assertTypeAndDefault(schema, Schema.Type.BYTES, false, null); + assertNoMetadata(schema); + + schema = SchemaBuilder.bytes().name(NAME).optional().defaultValue("a default byte array".getBytes()) + .version(VERSION).doc(DOC).build(); + assertTypeAndDefault(schema, Schema.Type.BYTES, true, "a default byte array".getBytes()); + assertMetadata(schema, NAME, VERSION, DOC, NO_PARAMS); + } + + @Test + public void testBytesBuilderInvalidDefault() { + assertThrows(SchemaBuilderException.class, () -> SchemaBuilder.bytes().defaultValue("a string, not bytes")); + } + + + @Test + public void testParameters() { + Map expectedParameters = new HashMap<>(); + expectedParameters.put("foo", "val"); + expectedParameters.put("bar", "baz"); + + Schema schema = SchemaBuilder.string().parameter("foo", "val").parameter("bar", "baz").build(); + assertTypeAndDefault(schema, Schema.Type.STRING, false, null); + assertMetadata(schema, null, null, null, expectedParameters); + + schema = SchemaBuilder.string().parameters(expectedParameters).build(); + assertTypeAndDefault(schema, Schema.Type.STRING, false, null); + 
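+ // Illustrative sketch (an extra check beyond the original assertions): whether parameters are added one at a time via parameter() or in bulk via parameters(), the built schema is expected to expose the same map. + assertEquals(expectedParameters, schema.parameters());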
assertMetadata(schema, null, null, null, expectedParameters); + } + + + @Test + public void testStructBuilder() { + Schema schema = SchemaBuilder.struct() + .field("field1", Schema.INT8_SCHEMA) + .field("field2", Schema.INT8_SCHEMA) + .build(); + assertTypeAndDefault(schema, Schema.Type.STRUCT, false, null); + assertEquals(2, schema.fields().size()); + assertEquals("field1", schema.fields().get(0).name()); + assertEquals(0, schema.fields().get(0).index()); + assertEquals(Schema.INT8_SCHEMA, schema.fields().get(0).schema()); + assertEquals("field2", schema.fields().get(1).name()); + assertEquals(1, schema.fields().get(1).index()); + assertEquals(Schema.INT8_SCHEMA, schema.fields().get(1).schema()); + assertNoMetadata(schema); + } + + @Test + public void testNonStructCantHaveFields() { + assertThrows(SchemaBuilderException.class, () -> SchemaBuilder.int8().field("field", SchemaBuilder.int8().build())); + } + + + @Test + public void testArrayBuilder() { + Schema schema = SchemaBuilder.array(Schema.INT8_SCHEMA).build(); + assertTypeAndDefault(schema, Schema.Type.ARRAY, false, null); + assertEquals(schema.valueSchema(), Schema.INT8_SCHEMA); + assertNoMetadata(schema); + + // Default value + List defArray = Arrays.asList((byte) 1, (byte) 2); + schema = SchemaBuilder.array(Schema.INT8_SCHEMA).defaultValue(defArray).build(); + assertTypeAndDefault(schema, Schema.Type.ARRAY, false, defArray); + assertEquals(schema.valueSchema(), Schema.INT8_SCHEMA); + assertNoMetadata(schema); + } + + @Test + public void testArrayBuilderInvalidDefault() { + // Array, but wrong embedded type + assertThrows(SchemaBuilderException.class, + () -> SchemaBuilder.array(Schema.INT8_SCHEMA).defaultValue(Collections.singletonList("string")).build()); + } + + @Test + public void testMapBuilder() { + // SchemaBuilder should also pass the check + Schema schema = SchemaBuilder.map(Schema.INT8_SCHEMA, Schema.INT8_SCHEMA); + assertTypeAndDefault(schema, Schema.Type.MAP, false, null); + assertEquals(schema.keySchema(), Schema.INT8_SCHEMA); + assertEquals(schema.valueSchema(), Schema.INT8_SCHEMA); + assertNoMetadata(schema); + + schema = SchemaBuilder.map(Schema.INT8_SCHEMA, Schema.INT8_SCHEMA).build(); + assertTypeAndDefault(schema, Schema.Type.MAP, false, null); + assertEquals(schema.keySchema(), Schema.INT8_SCHEMA); + assertEquals(schema.valueSchema(), Schema.INT8_SCHEMA); + assertNoMetadata(schema); + + // Default value + Map defMap = Collections.singletonMap((byte) 5, (byte) 10); + schema = SchemaBuilder.map(Schema.INT8_SCHEMA, Schema.INT8_SCHEMA) + .defaultValue(defMap).build(); + assertTypeAndDefault(schema, Schema.Type.MAP, false, defMap); + assertEquals(schema.keySchema(), Schema.INT8_SCHEMA); + assertEquals(schema.valueSchema(), Schema.INT8_SCHEMA); + assertNoMetadata(schema); + } + + @Test + public void testMapBuilderInvalidDefault() { + // Map, but wrong embedded type + Map defMap = Collections.singletonMap((byte) 5, "foo"); + assertThrows(SchemaBuilderException.class, () -> SchemaBuilder.map(Schema.INT8_SCHEMA, Schema.INT8_SCHEMA) + .defaultValue(defMap).build()); + } + + @Test + public void testEmptyStruct() { + final SchemaBuilder emptyStructSchemaBuilder = SchemaBuilder.struct(); + assertEquals(0, emptyStructSchemaBuilder.fields().size()); + new Struct(emptyStructSchemaBuilder); + + final Schema emptyStructSchema = emptyStructSchemaBuilder.build(); + assertEquals(0, emptyStructSchema.fields().size()); + new Struct(emptyStructSchema); + } + + @Test + public void testDuplicateFields() { + 
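+ // Illustrative sketch (not in the original test): the same builder accepts distinct field names, so it is only the repeated "id" below that should trigger the exception. + SchemaBuilder.struct().name("testing").field("id", SchemaBuilder.string().doc("").build()).field("id2", SchemaBuilder.string().doc("").build()).build();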
assertThrows(SchemaBuilderException.class, () -> SchemaBuilder.struct() + .name("testing") + .field("id", SchemaBuilder.string().doc("").build()) + .field("id", SchemaBuilder.string().doc("").build()) + .build()); + } + + @Test + public void testDefaultFieldsSameValueOverwriting() { + final SchemaBuilder schemaBuilder = SchemaBuilder.string().name("testing").version(123); + + schemaBuilder.name("testing"); + schemaBuilder.version(123); + + assertEquals("testing", schemaBuilder.name()); + } + + @Test + public void testDefaultFieldsDifferentValueOverwriting() { + final SchemaBuilder schemaBuilder = SchemaBuilder.string().name("testing").version(123); + + schemaBuilder.name("testing"); + assertThrows(SchemaBuilderException.class, () -> schemaBuilder.version(456)); + } + + @Test + public void testFieldNameNull() { + assertThrows(SchemaBuilderException.class, + () -> SchemaBuilder.struct().field(null, Schema.STRING_SCHEMA).build()); + } + + @Test + public void testFieldSchemaNull() { + assertThrows(SchemaBuilderException.class, + () -> SchemaBuilder.struct().field("fieldName", null).build()); + } + + @Test + public void testArraySchemaNull() { + assertThrows(SchemaBuilderException.class, () -> SchemaBuilder.array(null).build()); + } + + @Test + public void testMapKeySchemaNull() { + assertThrows(SchemaBuilderException.class, () -> SchemaBuilder.map(null, Schema.STRING_SCHEMA).build()); + } + + @Test + public void testMapValueSchemaNull() { + assertThrows(SchemaBuilderException.class, () -> SchemaBuilder.map(Schema.STRING_SCHEMA, null).build()); + } + + @Test + public void testTypeNotNull() { + assertThrows(SchemaBuilderException.class, () -> SchemaBuilder.type(null)); + } + + private void assertTypeAndDefault(Schema schema, Schema.Type type, boolean optional, Object defaultValue) { + assertEquals(type, schema.type()); + assertEquals(optional, schema.isOptional()); + if (type == Schema.Type.BYTES) { + // byte[] is not comparable, need to wrap to check correctly + if (defaultValue == null) + assertNull(schema.defaultValue()); + else + assertEquals(ByteBuffer.wrap((byte[]) defaultValue), ByteBuffer.wrap((byte[]) schema.defaultValue())); + } else { + assertEquals(defaultValue, schema.defaultValue()); + } + } + + private void assertMetadata(Schema schema, String name, Integer version, String doc, Map parameters) { + assertEquals(name, schema.name()); + assertEquals(version, schema.version()); + assertEquals(doc, schema.doc()); + assertEquals(parameters, schema.parameters()); + } + + private void assertNoMetadata(Schema schema) { + assertMetadata(schema, null, null, null, null); + } +} diff --git a/connect/api/src/test/java/org/apache/kafka/connect/data/SchemaProjectorTest.java b/connect/api/src/test/java/org/apache/kafka/connect/data/SchemaProjectorTest.java new file mode 100644 index 0000000..3e0c9de --- /dev/null +++ b/connect/api/src/test/java/org/apache/kafka/connect/data/SchemaProjectorTest.java @@ -0,0 +1,477 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.data; + +import org.apache.kafka.connect.data.Schema.Type; +import org.apache.kafka.connect.errors.DataException; +import org.apache.kafka.connect.errors.SchemaProjectorException; +import org.junit.jupiter.api.Test; + +import java.math.BigDecimal; +import java.math.BigInteger; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class SchemaProjectorTest { + + @Test + public void testPrimitiveTypeProjection() { + Object projected; + projected = SchemaProjector.project(Schema.BOOLEAN_SCHEMA, false, Schema.BOOLEAN_SCHEMA); + assertEquals(false, projected); + + byte[] bytes = {(byte) 1, (byte) 2}; + projected = SchemaProjector.project(Schema.BYTES_SCHEMA, bytes, Schema.BYTES_SCHEMA); + assertEquals(bytes, projected); + + projected = SchemaProjector.project(Schema.STRING_SCHEMA, "abc", Schema.STRING_SCHEMA); + assertEquals("abc", projected); + + projected = SchemaProjector.project(Schema.BOOLEAN_SCHEMA, false, Schema.OPTIONAL_BOOLEAN_SCHEMA); + assertEquals(false, projected); + + projected = SchemaProjector.project(Schema.BYTES_SCHEMA, bytes, Schema.OPTIONAL_BYTES_SCHEMA); + assertEquals(bytes, projected); + + projected = SchemaProjector.project(Schema.STRING_SCHEMA, "abc", Schema.OPTIONAL_STRING_SCHEMA); + assertEquals("abc", projected); + + assertThrows(DataException.class, () -> SchemaProjector.project(Schema.OPTIONAL_BOOLEAN_SCHEMA, false, + Schema.BOOLEAN_SCHEMA), "Cannot project optional schema to schema with no default value."); + + assertThrows(DataException.class, () -> SchemaProjector.project(Schema.OPTIONAL_BYTES_SCHEMA, bytes, + Schema.BYTES_SCHEMA), "Cannot project optional schema to schema with no default value."); + + assertThrows(DataException.class, () -> SchemaProjector.project(Schema.OPTIONAL_STRING_SCHEMA, "abc", + Schema.STRING_SCHEMA), "Cannot project optional schema to schema with no default value."); + } + + @Test + public void testNumericTypeProjection() { + Schema[] promotableSchemas = {Schema.INT8_SCHEMA, Schema.INT16_SCHEMA, Schema.INT32_SCHEMA, Schema.INT64_SCHEMA, Schema.FLOAT32_SCHEMA, Schema.FLOAT64_SCHEMA}; + Schema[] promotableOptionalSchemas = {Schema.OPTIONAL_INT8_SCHEMA, Schema.OPTIONAL_INT16_SCHEMA, Schema.OPTIONAL_INT32_SCHEMA, Schema.OPTIONAL_INT64_SCHEMA, + Schema.OPTIONAL_FLOAT32_SCHEMA, Schema.OPTIONAL_FLOAT64_SCHEMA}; + + Object[] values = {(byte) 127, (short) 255, 32767, 327890L, 1.2F, 1.2345}; + Map<Object, List<?>> expectedProjected = new HashMap<>(); + expectedProjected.put(values[0], Arrays.asList((byte) 127, (short) 127, 127, 127L, 127.F, 127.)); + expectedProjected.put(values[1], Arrays.asList((short) 255, 255, 255L, 255.F, 255.)); + expectedProjected.put(values[2], Arrays.asList(32767, 32767L, 32767.F, 32767.));
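+ // (Each entry lists the expected projections of that value into the equal-or-wider numeric types, in promotion order INT8 -> INT16 -> INT32 -> INT64 -> FLOAT32 -> FLOAT64.)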
expectedProjected.put(values[3], Arrays.asList(327890L, 327890.F, 327890.)); + expectedProjected.put(values[4], Arrays.asList(1.2F, 1.2)); + expectedProjected.put(values[5], Arrays.asList(1.2345)); + + Object promoted; + for (int i = 0; i < promotableSchemas.length; ++i) { + Schema source = promotableSchemas[i]; + List expected = expectedProjected.get(values[i]); + for (int j = i; j < promotableSchemas.length; ++j) { + Schema target = promotableSchemas[j]; + promoted = SchemaProjector.project(source, values[i], target); + if (target.type() == Type.FLOAT64) { + assertEquals((Double) (expected.get(j - i)), (double) promoted, 1e-6); + } else { + assertEquals(expected.get(j - i), promoted); + } + } + for (int j = i; j < promotableOptionalSchemas.length; ++j) { + Schema target = promotableOptionalSchemas[j]; + promoted = SchemaProjector.project(source, values[i], target); + if (target.type() == Type.FLOAT64) { + assertEquals((Double) (expected.get(j - i)), (double) promoted, 1e-6); + } else { + assertEquals(expected.get(j - i), promoted); + } + } + } + + for (int i = 0; i < promotableOptionalSchemas.length; ++i) { + Schema source = promotableSchemas[i]; + List expected = expectedProjected.get(values[i]); + for (int j = i; j < promotableOptionalSchemas.length; ++j) { + Schema target = promotableOptionalSchemas[j]; + promoted = SchemaProjector.project(source, values[i], target); + if (target.type() == Type.FLOAT64) { + assertEquals((Double) (expected.get(j - i)), (double) promoted, 1e-6); + } else { + assertEquals(expected.get(j - i), promoted); + } + } + } + + Schema[] nonPromotableSchemas = {Schema.BOOLEAN_SCHEMA, Schema.BYTES_SCHEMA, Schema.STRING_SCHEMA}; + for (Schema promotableSchema: promotableSchemas) { + for (Schema nonPromotableSchema: nonPromotableSchemas) { + Object dummy = new Object(); + + assertThrows(DataException.class, () -> SchemaProjector.project(promotableSchema, dummy, nonPromotableSchema), + "Cannot promote " + promotableSchema.type() + " to " + nonPromotableSchema.type()); + } + } + } + + @Test + public void testPrimitiveOptionalProjection() { + verifyOptionalProjection(Schema.OPTIONAL_BOOLEAN_SCHEMA, Type.BOOLEAN, false, true, false, true); + verifyOptionalProjection(Schema.OPTIONAL_BOOLEAN_SCHEMA, Type.BOOLEAN, false, true, false, false); + + byte[] bytes = {(byte) 1, (byte) 2}; + byte[] defaultBytes = {(byte) 3, (byte) 4}; + verifyOptionalProjection(Schema.OPTIONAL_BYTES_SCHEMA, Type.BYTES, bytes, defaultBytes, bytes, true); + verifyOptionalProjection(Schema.OPTIONAL_BYTES_SCHEMA, Type.BYTES, bytes, defaultBytes, bytes, false); + + verifyOptionalProjection(Schema.OPTIONAL_STRING_SCHEMA, Type.STRING, "abc", "def", "abc", true); + verifyOptionalProjection(Schema.OPTIONAL_STRING_SCHEMA, Type.STRING, "abc", "def", "abc", false); + + verifyOptionalProjection(Schema.OPTIONAL_INT8_SCHEMA, Type.INT8, (byte) 12, (byte) 127, (byte) 12, true); + verifyOptionalProjection(Schema.OPTIONAL_INT8_SCHEMA, Type.INT8, (byte) 12, (byte) 127, (byte) 12, false); + verifyOptionalProjection(Schema.OPTIONAL_INT8_SCHEMA, Type.INT16, (byte) 12, (short) 127, (short) 12, true); + verifyOptionalProjection(Schema.OPTIONAL_INT8_SCHEMA, Type.INT16, (byte) 12, (short) 127, (short) 12, false); + verifyOptionalProjection(Schema.OPTIONAL_INT8_SCHEMA, Type.INT32, (byte) 12, 12789, 12, true); + verifyOptionalProjection(Schema.OPTIONAL_INT8_SCHEMA, Type.INT32, (byte) 12, 12789, 12, false); + verifyOptionalProjection(Schema.OPTIONAL_INT8_SCHEMA, Type.INT64, (byte) 12, 127890L, 12L, true); + 
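+ // (verifyOptionalProjection arguments: optional source schema, target type, source value, target default value, expected projected value, and whether the target schema is itself optional.)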
verifyOptionalProjection(Schema.OPTIONAL_INT8_SCHEMA, Type.INT64, (byte) 12, 127890L, 12L, false); + verifyOptionalProjection(Schema.OPTIONAL_INT8_SCHEMA, Type.FLOAT32, (byte) 12, 3.45F, 12.F, true); + verifyOptionalProjection(Schema.OPTIONAL_INT8_SCHEMA, Type.FLOAT32, (byte) 12, 3.45F, 12.F, false); + verifyOptionalProjection(Schema.OPTIONAL_INT8_SCHEMA, Type.FLOAT64, (byte) 12, 3.4567, 12., true); + verifyOptionalProjection(Schema.OPTIONAL_INT8_SCHEMA, Type.FLOAT64, (byte) 12, 3.4567, 12., false); + + verifyOptionalProjection(Schema.OPTIONAL_INT16_SCHEMA, Type.INT16, (short) 12, (short) 127, (short) 12, true); + verifyOptionalProjection(Schema.OPTIONAL_INT16_SCHEMA, Type.INT16, (short) 12, (short) 127, (short) 12, false); + verifyOptionalProjection(Schema.OPTIONAL_INT16_SCHEMA, Type.INT32, (short) 12, 12789, 12, true); + verifyOptionalProjection(Schema.OPTIONAL_INT16_SCHEMA, Type.INT32, (short) 12, 12789, 12, false); + verifyOptionalProjection(Schema.OPTIONAL_INT16_SCHEMA, Type.INT64, (short) 12, 127890L, 12L, true); + verifyOptionalProjection(Schema.OPTIONAL_INT16_SCHEMA, Type.INT64, (short) 12, 127890L, 12L, false); + verifyOptionalProjection(Schema.OPTIONAL_INT16_SCHEMA, Type.FLOAT32, (short) 12, 3.45F, 12.F, true); + verifyOptionalProjection(Schema.OPTIONAL_INT16_SCHEMA, Type.FLOAT32, (short) 12, 3.45F, 12.F, false); + verifyOptionalProjection(Schema.OPTIONAL_INT16_SCHEMA, Type.FLOAT64, (short) 12, 3.4567, 12., true); + verifyOptionalProjection(Schema.OPTIONAL_INT16_SCHEMA, Type.FLOAT64, (short) 12, 3.4567, 12., false); + + verifyOptionalProjection(Schema.OPTIONAL_INT32_SCHEMA, Type.INT32, 12, 12789, 12, true); + verifyOptionalProjection(Schema.OPTIONAL_INT32_SCHEMA, Type.INT32, 12, 12789, 12, false); + verifyOptionalProjection(Schema.OPTIONAL_INT32_SCHEMA, Type.INT64, 12, 127890L, 12L, true); + verifyOptionalProjection(Schema.OPTIONAL_INT32_SCHEMA, Type.INT64, 12, 127890L, 12L, false); + verifyOptionalProjection(Schema.OPTIONAL_INT32_SCHEMA, Type.FLOAT32, 12, 3.45F, 12.F, true); + verifyOptionalProjection(Schema.OPTIONAL_INT32_SCHEMA, Type.FLOAT32, 12, 3.45F, 12.F, false); + verifyOptionalProjection(Schema.OPTIONAL_INT32_SCHEMA, Type.FLOAT64, 12, 3.4567, 12., true); + verifyOptionalProjection(Schema.OPTIONAL_INT32_SCHEMA, Type.FLOAT64, 12, 3.4567, 12., false); + + verifyOptionalProjection(Schema.OPTIONAL_INT64_SCHEMA, Type.INT64, 12L, 127890L, 12L, true); + verifyOptionalProjection(Schema.OPTIONAL_INT64_SCHEMA, Type.INT64, 12L, 127890L, 12L, false); + verifyOptionalProjection(Schema.OPTIONAL_INT64_SCHEMA, Type.FLOAT32, 12L, 3.45F, 12.F, true); + verifyOptionalProjection(Schema.OPTIONAL_INT64_SCHEMA, Type.FLOAT32, 12L, 3.45F, 12.F, false); + verifyOptionalProjection(Schema.OPTIONAL_INT64_SCHEMA, Type.FLOAT64, 12L, 3.4567, 12., true); + verifyOptionalProjection(Schema.OPTIONAL_INT64_SCHEMA, Type.FLOAT64, 12L, 3.4567, 12., false); + + verifyOptionalProjection(Schema.OPTIONAL_FLOAT32_SCHEMA, Type.FLOAT32, 12.345F, 3.45F, 12.345F, true); + verifyOptionalProjection(Schema.OPTIONAL_FLOAT32_SCHEMA, Type.FLOAT32, 12.345F, 3.45F, 12.345F, false); + verifyOptionalProjection(Schema.OPTIONAL_FLOAT32_SCHEMA, Type.FLOAT64, 12.345F, 3.4567, 12.345, true); + verifyOptionalProjection(Schema.OPTIONAL_FLOAT32_SCHEMA, Type.FLOAT64, 12.345F, 3.4567, 12.345, false); + + verifyOptionalProjection(Schema.OPTIONAL_FLOAT32_SCHEMA, Type.FLOAT64, 12.345, 3.4567, 12.345, true); + verifyOptionalProjection(Schema.OPTIONAL_FLOAT32_SCHEMA, Type.FLOAT64, 12.345, 3.4567, 12.345, false); + } + + @Test + public void 
testStructAddField() { + Schema source = SchemaBuilder.struct() + .field("field", Schema.INT32_SCHEMA) + .build(); + Struct sourceStruct = new Struct(source); + sourceStruct.put("field", 1); + + Schema target = SchemaBuilder.struct() + .field("field", Schema.INT32_SCHEMA) + .field("field2", SchemaBuilder.int32().defaultValue(123).build()) + .build(); + + Struct targetStruct = (Struct) SchemaProjector.project(source, sourceStruct, target); + + assertEquals(1, (int) targetStruct.getInt32("field")); + assertEquals(123, (int) targetStruct.getInt32("field2")); + + Schema incompatibleTargetSchema = SchemaBuilder.struct() + .field("field", Schema.INT32_SCHEMA) + .field("field2", Schema.INT32_SCHEMA) + .build(); + + assertThrows(DataException.class, () -> SchemaProjector.project(source, sourceStruct, incompatibleTargetSchema), + "Incompatible schema."); + } + + @Test + public void testStructRemoveField() { + Schema source = SchemaBuilder.struct() + .field("field", Schema.INT32_SCHEMA) + .field("field2", Schema.INT32_SCHEMA) + .build(); + Struct sourceStruct = new Struct(source); + sourceStruct.put("field", 1); + sourceStruct.put("field2", 234); + + Schema target = SchemaBuilder.struct() + .field("field", Schema.INT32_SCHEMA) + .build(); + Struct targetStruct = (Struct) SchemaProjector.project(source, sourceStruct, target); + + assertEquals(1, targetStruct.get("field")); + assertThrows(DataException.class, () -> targetStruct.get("field2"), + "field2 is not part of the projected struct"); + } + + @Test + public void testStructDefaultValue() { + Schema source = SchemaBuilder.struct().optional() + .field("field", Schema.INT32_SCHEMA) + .field("field2", Schema.INT32_SCHEMA) + .build(); + + SchemaBuilder builder = SchemaBuilder.struct() + .field("field", Schema.INT32_SCHEMA) + .field("field2", Schema.INT32_SCHEMA); + + Struct defaultStruct = new Struct(builder).put("field", 12).put("field2", 345); + builder.defaultValue(defaultStruct); + Schema target = builder.build(); + + Object projected = SchemaProjector.project(source, null, target); + assertEquals(defaultStruct, projected); + + Struct sourceStruct = new Struct(source).put("field", 45).put("field2", 678); + Struct targetStruct = (Struct) SchemaProjector.project(source, sourceStruct, target); + + assertEquals(sourceStruct.get("field"), targetStruct.get("field")); + assertEquals(sourceStruct.get("field2"), targetStruct.get("field2")); + } + + @Test + public void testNestedSchemaProjection() { + Schema sourceFlatSchema = SchemaBuilder.struct() + .field("field", Schema.INT32_SCHEMA) + .build(); + Schema targetFlatSchema = SchemaBuilder.struct() + .field("field", Schema.INT32_SCHEMA) + .field("field2", SchemaBuilder.int32().defaultValue(123).build()) + .build(); + Schema sourceNestedSchema = SchemaBuilder.struct() + .field("first", Schema.INT32_SCHEMA) + .field("second", Schema.STRING_SCHEMA) + .field("array", SchemaBuilder.array(Schema.INT32_SCHEMA).build()) + .field("map", SchemaBuilder.map(Schema.INT32_SCHEMA, Schema.STRING_SCHEMA).build()) + .field("nested", sourceFlatSchema) + .build(); + Schema targetNestedSchema = SchemaBuilder.struct() + .field("first", Schema.INT32_SCHEMA) + .field("second", Schema.STRING_SCHEMA) + .field("array", SchemaBuilder.array(Schema.INT32_SCHEMA).build()) + .field("map", SchemaBuilder.map(Schema.INT32_SCHEMA, Schema.STRING_SCHEMA).build()) + .field("nested", targetFlatSchema) + .build(); + + Struct sourceFlatStruct = new Struct(sourceFlatSchema); + sourceFlatStruct.put("field", 113); + + Struct sourceNestedStruct = new 
Struct(sourceNestedSchema); + sourceNestedStruct.put("first", 1); + sourceNestedStruct.put("second", "abc"); + sourceNestedStruct.put("array", Arrays.asList(1, 2)); + sourceNestedStruct.put("map", Collections.singletonMap(5, "def")); + sourceNestedStruct.put("nested", sourceFlatStruct); + + Struct targetNestedStruct = (Struct) SchemaProjector.project(sourceNestedSchema, sourceNestedStruct, + targetNestedSchema); + assertEquals(1, targetNestedStruct.get("first")); + assertEquals("abc", targetNestedStruct.get("second")); + assertEquals(Arrays.asList(1, 2), targetNestedStruct.get("array")); + assertEquals(Collections.singletonMap(5, "def"), targetNestedStruct.get("map")); + + Struct projectedStruct = (Struct) targetNestedStruct.get("nested"); + assertEquals(113, projectedStruct.get("field")); + assertEquals(123, projectedStruct.get("field2")); + } + + @Test + public void testLogicalTypeProjection() { + Schema[] logicalTypeSchemas = {Decimal.schema(2), Date.SCHEMA, Time.SCHEMA, Timestamp.SCHEMA}; + Object projected; + + BigDecimal testDecimal = new BigDecimal(new BigInteger("156"), 2); + projected = SchemaProjector.project(Decimal.schema(2), testDecimal, Decimal.schema(2)); + assertEquals(testDecimal, projected); + + projected = SchemaProjector.project(Date.SCHEMA, 1000, Date.SCHEMA); + assertEquals(1000, projected); + + projected = SchemaProjector.project(Time.SCHEMA, 231, Time.SCHEMA); + assertEquals(231, projected); + + projected = SchemaProjector.project(Timestamp.SCHEMA, 34567L, Timestamp.SCHEMA); + assertEquals(34567L, projected); + + java.util.Date date = new java.util.Date(); + + projected = SchemaProjector.project(Date.SCHEMA, date, Date.SCHEMA); + assertEquals(date, projected); + + projected = SchemaProjector.project(Time.SCHEMA, date, Time.SCHEMA); + assertEquals(date, projected); + + projected = SchemaProjector.project(Timestamp.SCHEMA, date, Timestamp.SCHEMA); + assertEquals(date, projected); + + Schema namedSchema = SchemaBuilder.int32().name("invalidLogicalTypeName").build(); + for (Schema logicalTypeSchema: logicalTypeSchemas) { + assertThrows(SchemaProjectorException.class, () -> SchemaProjector.project(logicalTypeSchema, null, + Schema.BOOLEAN_SCHEMA), "Cannot project logical types to non-logical types."); + + assertThrows(SchemaProjectorException.class, () -> SchemaProjector.project(logicalTypeSchema, null, + namedSchema), "Reader name is not a valid logical type name."); + + assertThrows(SchemaProjectorException.class, () -> SchemaProjector.project(Schema.BOOLEAN_SCHEMA, + null, logicalTypeSchema), "Cannot project non-logical types to logical types."); + } + } + + @Test + public void testArrayProjection() { + Schema source = SchemaBuilder.array(Schema.INT32_SCHEMA).build(); + + Object projected = SchemaProjector.project(source, Arrays.asList(1, 2, 3), source); + assertEquals(Arrays.asList(1, 2, 3), projected); + + Schema optionalSource = SchemaBuilder.array(Schema.INT32_SCHEMA).optional().build(); + Schema target = SchemaBuilder.array(Schema.INT32_SCHEMA).defaultValue(Arrays.asList(1, 2, 3)).build(); + projected = SchemaProjector.project(optionalSource, Arrays.asList(4, 5), target); + assertEquals(Arrays.asList(4, 5), projected); + projected = SchemaProjector.project(optionalSource, null, target); + assertEquals(Arrays.asList(1, 2, 3), projected); + + Schema promotedTarget = SchemaBuilder.array(Schema.INT64_SCHEMA).defaultValue(Arrays.asList(1L, 2L, 3L)).build(); + projected = SchemaProjector.project(optionalSource, Arrays.asList(4, 5), promotedTarget); + List 
expectedProjected = Arrays.asList(4L, 5L); + assertEquals(expectedProjected, projected); + projected = SchemaProjector.project(optionalSource, null, promotedTarget); + assertEquals(Arrays.asList(1L, 2L, 3L), projected); + + Schema noDefaultValueTarget = SchemaBuilder.array(Schema.INT32_SCHEMA).build(); + assertThrows(SchemaProjectorException.class, () -> SchemaProjector.project(optionalSource, null, + noDefaultValueTarget), "Target schema does not provide a default value."); + + Schema nonPromotableTarget = SchemaBuilder.array(Schema.BOOLEAN_SCHEMA).build(); + assertThrows(SchemaProjectorException.class, + () -> SchemaProjector.project(optionalSource, null, nonPromotableTarget), + "Neither source type matches target type nor source type can be promoted to target type"); + } + + @Test + public void testMapProjection() { + Schema source = SchemaBuilder.map(Schema.INT32_SCHEMA, Schema.INT32_SCHEMA).optional().build(); + + Schema target = SchemaBuilder.map(Schema.INT32_SCHEMA, Schema.INT32_SCHEMA).defaultValue(Collections.singletonMap(1, 2)).build(); + Object projected = SchemaProjector.project(source, Collections.singletonMap(3, 4), target); + assertEquals(Collections.singletonMap(3, 4), projected); + projected = SchemaProjector.project(source, null, target); + assertEquals(Collections.singletonMap(1, 2), projected); + + Schema promotedTarget = SchemaBuilder.map(Schema.INT64_SCHEMA, Schema.FLOAT32_SCHEMA).defaultValue( + Collections.singletonMap(3L, 4.5F)).build(); + projected = SchemaProjector.project(source, Collections.singletonMap(3, 4), promotedTarget); + assertEquals(Collections.singletonMap(3L, 4.F), projected); + projected = SchemaProjector.project(source, null, promotedTarget); + assertEquals(Collections.singletonMap(3L, 4.5F), projected); + + Schema noDefaultValueTarget = SchemaBuilder.map(Schema.INT32_SCHEMA, Schema.INT32_SCHEMA).build(); + assertThrows(SchemaProjectorException.class, + () -> SchemaProjector.project(source, null, noDefaultValueTarget), + "Reader does not provide a default value."); + + Schema nonPromotableTarget = SchemaBuilder.map(Schema.BOOLEAN_SCHEMA, Schema.STRING_SCHEMA).build(); + assertThrows(SchemaProjectorException.class, + () -> SchemaProjector.project(source, null, nonPromotableTarget), + "Neither source type matches target type nor source type can be promoted to target type"); + } + + @Test + public void testMaybeCompatible() { + Schema source = SchemaBuilder.int32().name("source").build(); + Schema target = SchemaBuilder.int32().name("target").build(); + + assertThrows(SchemaProjectorException.class, + () -> SchemaProjector.project(source, 12, target), + "Source name and target name mismatch."); + + Schema targetWithParameters = SchemaBuilder.int32().parameters(Collections.singletonMap("key", "value")); + assertThrows(SchemaProjectorException.class, + () -> SchemaProjector.project(source, 34, targetWithParameters), + "Source parameters and target parameters mismatch."); + } + + @Test + public void testProjectMissingDefaultValuedStructField() { + final Schema source = SchemaBuilder.struct().build(); + final Schema target = SchemaBuilder.struct().field("id", SchemaBuilder.int64().defaultValue(42L).build()).build(); + assertEquals(42L, (long) ((Struct) SchemaProjector.project(source, new Struct(source), target)).getInt64("id")); + } + + @Test + public void testProjectMissingOptionalStructField() { + final Schema source = SchemaBuilder.struct().build(); + final Schema target = SchemaBuilder.struct().field("id", SchemaBuilder.OPTIONAL_INT64_SCHEMA).build(); 
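+ // The "id" field is optional with no default value, so projecting a source struct that lacks it is expected to yield null rather than an error.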
+ assertNull(((Struct) SchemaProjector.project(source, new Struct(source), target)).getInt64("id")); + } + + @Test + public void testProjectMissingRequiredField() { + final Schema source = SchemaBuilder.struct().build(); + final Schema target = SchemaBuilder.struct().field("id", SchemaBuilder.INT64_SCHEMA).build(); + assertThrows(SchemaProjectorException.class, () -> SchemaProjector.project(source, new Struct(source), target)); + } + + private void verifyOptionalProjection(Schema source, Type targetType, Object value, Object defaultValue, Object expectedProjected, boolean optional) { + Schema target; + assertTrue(source.isOptional()); + assertNotNull(value); + + if (optional) { + target = SchemaBuilder.type(targetType).optional().defaultValue(defaultValue).build(); + } else { + target = SchemaBuilder.type(targetType).defaultValue(defaultValue).build(); + } + Object projected = SchemaProjector.project(source, value, target); + if (targetType == Type.FLOAT64) { + assertEquals((double) expectedProjected, (double) projected, 1e-6); + } else { + assertEquals(expectedProjected, projected); + } + + projected = SchemaProjector.project(source, null, target); + if (optional) { + assertNull(projected); + } else { + assertEquals(defaultValue, projected); + } + } +} diff --git a/connect/api/src/test/java/org/apache/kafka/connect/data/StructTest.java b/connect/api/src/test/java/org/apache/kafka/connect/data/StructTest.java new file mode 100644 index 0000000..55ccc81 --- /dev/null +++ b/connect/api/src/test/java/org/apache/kafka/connect/data/StructTest.java @@ -0,0 +1,361 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.connect.data; + +import org.apache.kafka.connect.errors.DataException; + +import org.junit.jupiter.api.Test; + +import java.nio.ByteBuffer; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; + + +public class StructTest { + + private static final Schema FLAT_STRUCT_SCHEMA = SchemaBuilder.struct() + .field("int8", Schema.INT8_SCHEMA) + .field("int16", Schema.INT16_SCHEMA) + .field("int32", Schema.INT32_SCHEMA) + .field("int64", Schema.INT64_SCHEMA) + .field("float32", Schema.FLOAT32_SCHEMA) + .field("float64", Schema.FLOAT64_SCHEMA) + .field("boolean", Schema.BOOLEAN_SCHEMA) + .field("string", Schema.STRING_SCHEMA) + .field("bytes", Schema.BYTES_SCHEMA) + .build(); + + private static final Schema ARRAY_SCHEMA = SchemaBuilder.array(Schema.INT8_SCHEMA).build(); + private static final Schema MAP_SCHEMA = SchemaBuilder.map( + Schema.INT32_SCHEMA, + Schema.STRING_SCHEMA + ).build(); + private static final Schema NESTED_CHILD_SCHEMA = SchemaBuilder.struct() + .field("int8", Schema.INT8_SCHEMA) + .build(); + private static final Schema NESTED_SCHEMA = SchemaBuilder.struct() + .field("array", ARRAY_SCHEMA) + .field("map", MAP_SCHEMA) + .field("nested", NESTED_CHILD_SCHEMA) + .build(); + + private static final Schema REQUIRED_FIELD_SCHEMA = Schema.INT8_SCHEMA; + private static final Schema OPTIONAL_FIELD_SCHEMA = SchemaBuilder.int8().optional().build(); + private static final Schema DEFAULT_FIELD_SCHEMA = SchemaBuilder.int8().defaultValue((byte) 0).build(); + + @Test + public void testFlatStruct() { + Struct struct = new Struct(FLAT_STRUCT_SCHEMA) + .put("int8", (byte) 12) + .put("int16", (short) 12) + .put("int32", 12) + .put("int64", (long) 12) + .put("float32", 12.f) + .put("float64", 12.) + .put("boolean", true) + .put("string", "foobar") + .put("bytes", "foobar".getBytes()); + + // Test equality, and also the type-specific getters + assertEquals((byte) 12, (byte) struct.getInt8("int8")); + assertEquals((short) 12, (short) struct.getInt16("int16")); + assertEquals(12, (int) struct.getInt32("int32")); + assertEquals((long) 12, (long) struct.getInt64("int64")); + assertEquals((Float) 12.f, struct.getFloat32("float32")); + assertEquals((Double) 12., struct.getFloat64("float64")); + assertEquals(true, struct.getBoolean("boolean")); + assertEquals("foobar", struct.getString("string")); + assertEquals(ByteBuffer.wrap("foobar".getBytes()), ByteBuffer.wrap(struct.getBytes("bytes"))); + + struct.validate(); + } + + @Test + public void testComplexStruct() { + List array = Arrays.asList((byte) 1, (byte) 2); + Map map = Collections.singletonMap(1, "string"); + Struct struct = new Struct(NESTED_SCHEMA) + .put("array", array) + .put("map", map) + .put("nested", new Struct(NESTED_CHILD_SCHEMA).put("int8", (byte) 12)); + + // Separate the call to get the array and map to validate the typed get methods work properly + List arrayExtracted = struct.getArray("array"); + assertEquals(array, arrayExtracted); + Map mapExtracted = struct.getMap("map"); + assertEquals(map, mapExtracted); + assertEquals((byte) 12, struct.getStruct("nested").get("int8")); + + struct.validate(); + } + + + // These don't test all the ways validation can fail, just one for each element. 
See more extensive validation + // tests in SchemaTest. These are meant to ensure that we are invoking the same code path and that we do deeper + // inspection than just checking the class of the object + + @Test + public void testInvalidFieldType() { + assertThrows(DataException.class, + () -> new Struct(FLAT_STRUCT_SCHEMA).put("int8", "should fail because this is a string, not int8")); + } + + @Test + public void testInvalidArrayFieldElements() { + assertThrows(DataException.class, + () -> new Struct(NESTED_SCHEMA).put("array", Collections.singletonList("should fail since elements should be int8s"))); + } + + @Test + public void testInvalidMapKeyElements() { + assertThrows(DataException.class, + () -> new Struct(NESTED_SCHEMA).put("map", Collections.singletonMap("should fail because keys should be int8s", (byte) 12))); + } + + @Test + public void testInvalidStructFieldSchema() { + assertThrows(DataException.class, + () -> new Struct(NESTED_SCHEMA).put("nested", new Struct(MAP_SCHEMA))); + } + + @Test + public void testInvalidStructFieldValue() { + assertThrows(DataException.class, + () -> new Struct(NESTED_SCHEMA).put("nested", new Struct(NESTED_CHILD_SCHEMA))); + } + + + @Test + public void testMissingFieldValidation() { + // Required int8 field + Schema schema = SchemaBuilder.struct().field("field", REQUIRED_FIELD_SCHEMA).build(); + Struct struct = new Struct(schema); + assertThrows(DataException.class, struct::validate); + } + + @Test + public void testMissingOptionalFieldValidation() { + Schema schema = SchemaBuilder.struct().field("field", OPTIONAL_FIELD_SCHEMA).build(); + Struct struct = new Struct(schema); + struct.validate(); + } + + @Test + public void testMissingFieldWithDefaultValidation() { + Schema schema = SchemaBuilder.struct().field("field", DEFAULT_FIELD_SCHEMA).build(); + Struct struct = new Struct(schema); + struct.validate(); + } + + @Test + public void testMissingFieldWithDefaultValue() { + Schema schema = SchemaBuilder.struct().field("field", DEFAULT_FIELD_SCHEMA).build(); + Struct struct = new Struct(schema); + assertEquals((byte) 0, struct.get("field")); + } + + @Test + public void testMissingFieldWithoutDefaultValue() { + Schema schema = SchemaBuilder.struct().field("field", REQUIRED_FIELD_SCHEMA).build(); + Struct struct = new Struct(schema); + assertNull(struct.get("field")); + } + + + @Test + public void testEquals() { + Struct struct1 = new Struct(FLAT_STRUCT_SCHEMA) + .put("int8", (byte) 12) + .put("int16", (short) 12) + .put("int32", 12) + .put("int64", (long) 12) + .put("float32", 12.f) + .put("float64", 12.) + .put("boolean", true) + .put("string", "foobar") + .put("bytes", ByteBuffer.wrap("foobar".getBytes())); + Struct struct2 = new Struct(FLAT_STRUCT_SCHEMA) + .put("int8", (byte) 12) + .put("int16", (short) 12) + .put("int32", 12) + .put("int64", (long) 12) + .put("float32", 12.f) + .put("float64", 12.) + .put("boolean", true) + .put("string", "foobar") + .put("bytes", ByteBuffer.wrap("foobar".getBytes())); + Struct struct3 = new Struct(FLAT_STRUCT_SCHEMA) + .put("int8", (byte) 12) + .put("int16", (short) 12) + .put("int32", 12) + .put("int64", (long) 12) + .put("float32", 12.f) + .put("float64", 12.) 
+ .put("boolean", true) + .put("string", "mismatching string") + .put("bytes", ByteBuffer.wrap("foobar".getBytes())); + + assertEquals(struct1, struct2); + assertNotEquals(struct1, struct3); + + List array = Arrays.asList((byte) 1, (byte) 2); + Map map = Collections.singletonMap(1, "string"); + struct1 = new Struct(NESTED_SCHEMA) + .put("array", array) + .put("map", map) + .put("nested", new Struct(NESTED_CHILD_SCHEMA).put("int8", (byte) 12)); + List array2 = Arrays.asList((byte) 1, (byte) 2); + Map map2 = Collections.singletonMap(1, "string"); + struct2 = new Struct(NESTED_SCHEMA) + .put("array", array2) + .put("map", map2) + .put("nested", new Struct(NESTED_CHILD_SCHEMA).put("int8", (byte) 12)); + List array3 = Arrays.asList((byte) 1, (byte) 2, (byte) 3); + Map map3 = Collections.singletonMap(2, "string"); + struct3 = new Struct(NESTED_SCHEMA) + .put("array", array3) + .put("map", map3) + .put("nested", new Struct(NESTED_CHILD_SCHEMA).put("int8", (byte) 13)); + + assertEquals(struct1, struct2); + assertNotEquals(struct1, struct3); + } + + @Test + public void testEqualsAndHashCodeWithByteArrayValue() { + Struct struct1 = new Struct(FLAT_STRUCT_SCHEMA) + .put("int8", (byte) 12) + .put("int16", (short) 12) + .put("int32", 12) + .put("int64", (long) 12) + .put("float32", 12.f) + .put("float64", 12.) + .put("boolean", true) + .put("string", "foobar") + .put("bytes", "foobar".getBytes()); + + Struct struct2 = new Struct(FLAT_STRUCT_SCHEMA) + .put("int8", (byte) 12) + .put("int16", (short) 12) + .put("int32", 12) + .put("int64", (long) 12) + .put("float32", 12.f) + .put("float64", 12.) + .put("boolean", true) + .put("string", "foobar") + .put("bytes", "foobar".getBytes()); + + Struct struct3 = new Struct(FLAT_STRUCT_SCHEMA) + .put("int8", (byte) 12) + .put("int16", (short) 12) + .put("int32", 12) + .put("int64", (long) 12) + .put("float32", 12.f) + .put("float64", 12.) + .put("boolean", true) + .put("string", "foobar") + .put("bytes", "mismatching_string".getBytes()); + + // Verify contract for equals: method must be reflexive and transitive + assertEquals(struct1, struct2); + assertEquals(struct2, struct1); + assertNotEquals(struct1, struct3); + assertNotEquals(struct2, struct3); + // Testing hashCode against a hardcoded value here would be incorrect: hashCode values need not be equal for any + // two distinct executions. However, based on the general contract for hashCode, if two objects are equal, their + // hashCodes must be equal. If they are not equal, their hashCodes should not be equal for performance reasons. 
+ assertEquals(struct1.hashCode(), struct2.hashCode()); + assertNotEquals(struct1.hashCode(), struct3.hashCode()); + assertNotEquals(struct2.hashCode(), struct3.hashCode()); + } + + @Test + public void testValidateStructWithNullValue() { + Schema schema = SchemaBuilder.struct() + .field("one", Schema.STRING_SCHEMA) + .field("two", Schema.STRING_SCHEMA) + .field("three", Schema.STRING_SCHEMA) + .build(); + + Struct struct = new Struct(schema); + Exception e = assertThrows(DataException.class, struct::validate); + assertEquals("Invalid value: null used for required field: \"one\", schema type: STRING", + e.getMessage()); + } + + @Test + public void testValidateFieldWithInvalidValueType() { + String fieldName = "field"; + FakeSchema fakeSchema = new FakeSchema(); + + Exception e = assertThrows(DataException.class, () -> ConnectSchema.validateValue(fieldName, + fakeSchema, new Object())); + assertEquals("Invalid Java object for schema \"fake\" with type null: class java.lang.Object for field: \"field\"", + e.getMessage()); + + e = assertThrows(DataException.class, () -> ConnectSchema.validateValue(fieldName, + Schema.INT8_SCHEMA, new Object())); + assertEquals("Invalid Java object for schema with type INT8: class java.lang.Object for field: \"field\"", + e.getMessage()); + + e = assertThrows(DataException.class, () -> ConnectSchema.validateValue(Schema.INT8_SCHEMA, new Object())); + assertEquals("Invalid Java object for schema with type INT8: class java.lang.Object", e.getMessage()); + } + + @Test + public void testValidateFieldWithInvalidValueMismatchTimestamp() { + String fieldName = "field"; + long longValue = 1000L; + + // Does not throw + ConnectSchema.validateValue(fieldName, Schema.INT64_SCHEMA, longValue); + + Exception e = assertThrows(DataException.class, () -> ConnectSchema.validateValue(fieldName, + Timestamp.SCHEMA, longValue)); + assertEquals("Invalid Java object for schema \"org.apache.kafka.connect.data.Timestamp\" " + + "with type INT64: class java.lang.Long for field: \"field\"", e.getMessage()); + } + + @Test + public void testPutNullField() { + final String fieldName = "fieldName"; + Schema testSchema = SchemaBuilder.struct() + .field(fieldName, Schema.STRING_SCHEMA); + Struct struct = new Struct(testSchema); + + assertThrows(DataException.class, () -> struct.put((Field) null, "valid")); + } + + @Test + public void testInvalidPutIncludesFieldName() { + final String fieldName = "fieldName"; + Schema testSchema = SchemaBuilder.struct() + .field(fieldName, Schema.STRING_SCHEMA); + Struct struct = new Struct(testSchema); + + Exception e = assertThrows(DataException.class, () -> struct.put(fieldName, null)); + assertEquals("Invalid value: null used for required field: \"fieldName\", schema type: STRING", + e.getMessage()); + } +} diff --git a/connect/api/src/test/java/org/apache/kafka/connect/data/TimeTest.java b/connect/api/src/test/java/org/apache/kafka/connect/data/TimeTest.java new file mode 100644 index 0000000..b07ccc0 --- /dev/null +++ b/connect/api/src/test/java/org/apache/kafka/connect/data/TimeTest.java @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.data; + +import org.apache.kafka.connect.errors.DataException; +import org.junit.jupiter.api.Test; + +import java.util.Calendar; +import java.util.GregorianCalendar; +import java.util.TimeZone; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +public class TimeTest { + private static final GregorianCalendar EPOCH; + private static final GregorianCalendar EPOCH_PLUS_DATE_COMPONENT; + private static final GregorianCalendar EPOCH_PLUS_TEN_THOUSAND_MILLIS; + static { + EPOCH = new GregorianCalendar(1970, Calendar.JANUARY, 1, 0, 0, 0); + EPOCH.setTimeZone(TimeZone.getTimeZone("UTC")); + + EPOCH_PLUS_TEN_THOUSAND_MILLIS = new GregorianCalendar(1970, Calendar.JANUARY, 1, 0, 0, 0); + EPOCH_PLUS_TEN_THOUSAND_MILLIS.setTimeZone(TimeZone.getTimeZone("UTC")); + EPOCH_PLUS_TEN_THOUSAND_MILLIS.add(Calendar.MILLISECOND, 10000); + + + EPOCH_PLUS_DATE_COMPONENT = new GregorianCalendar(1970, Calendar.JANUARY, 1, 0, 0, 0); + EPOCH_PLUS_DATE_COMPONENT.setTimeZone(TimeZone.getTimeZone("UTC")); + EPOCH_PLUS_DATE_COMPONENT.add(Calendar.DATE, 10000); + } + + @Test + public void testBuilder() { + Schema plain = Time.SCHEMA; + assertEquals(Time.LOGICAL_NAME, plain.name()); + assertEquals(1, (Object) plain.version()); + } + + @Test + public void testFromLogical() { + assertEquals(0, Time.fromLogical(Time.SCHEMA, EPOCH.getTime())); + assertEquals(10000, Time.fromLogical(Time.SCHEMA, EPOCH_PLUS_TEN_THOUSAND_MILLIS.getTime())); + } + + @Test + public void testFromLogicalInvalidSchema() { + assertThrows(DataException.class, + () -> Time.fromLogical(Time.builder().name("invalid").build(), EPOCH.getTime())); + } + + @Test + public void testFromLogicalInvalidHasDateComponents() { + assertThrows(DataException.class, + () -> Time.fromLogical(Time.SCHEMA, EPOCH_PLUS_DATE_COMPONENT.getTime())); + } + + @Test + public void testToLogical() { + assertEquals(EPOCH.getTime(), Time.toLogical(Time.SCHEMA, 0)); + assertEquals(EPOCH_PLUS_TEN_THOUSAND_MILLIS.getTime(), Time.toLogical(Time.SCHEMA, 10000)); + } + + @Test + public void testToLogicalInvalidSchema() { + assertThrows(DataException.class, + () -> Time.toLogical(Time.builder().name("invalid").build(), 0)); + } +} diff --git a/connect/api/src/test/java/org/apache/kafka/connect/data/TimestampTest.java b/connect/api/src/test/java/org/apache/kafka/connect/data/TimestampTest.java new file mode 100644 index 0000000..94f67b4 --- /dev/null +++ b/connect/api/src/test/java/org/apache/kafka/connect/data/TimestampTest.java @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.data; + +import org.apache.kafka.connect.errors.DataException; +import org.junit.jupiter.api.Test; + +import java.util.Calendar; +import java.util.GregorianCalendar; +import java.util.TimeZone; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +public class TimestampTest { + private static final GregorianCalendar EPOCH; + private static final GregorianCalendar EPOCH_PLUS_MILLIS; + + private static final int NUM_MILLIS = 2000000000; + private static final long TOTAL_MILLIS = ((long) NUM_MILLIS) * 2; + + static { + EPOCH = new GregorianCalendar(1970, Calendar.JANUARY, 1, 0, 0, 0); + EPOCH.setTimeZone(TimeZone.getTimeZone("UTC")); + + + EPOCH_PLUS_MILLIS = new GregorianCalendar(1970, Calendar.JANUARY, 1, 0, 0, 0); + EPOCH_PLUS_MILLIS.setTimeZone(TimeZone.getTimeZone("UTC")); + EPOCH_PLUS_MILLIS.add(Calendar.MILLISECOND, NUM_MILLIS); + EPOCH_PLUS_MILLIS.add(Calendar.MILLISECOND, NUM_MILLIS); + } + + @Test + public void testBuilder() { + Schema plain = Date.SCHEMA; + assertEquals(Date.LOGICAL_NAME, plain.name()); + assertEquals(1, (Object) plain.version()); + } + + @Test + public void testFromLogical() { + assertEquals(0L, Timestamp.fromLogical(Timestamp.SCHEMA, EPOCH.getTime())); + assertEquals(TOTAL_MILLIS, Timestamp.fromLogical(Timestamp.SCHEMA, EPOCH_PLUS_MILLIS.getTime())); + } + + @Test + public void testFromLogicalInvalidSchema() { + assertThrows(DataException.class, + () -> Timestamp.fromLogical(Timestamp.builder().name("invalid").build(), EPOCH.getTime())); + } + + @Test + public void testToLogical() { + assertEquals(EPOCH.getTime(), Timestamp.toLogical(Timestamp.SCHEMA, 0L)); + assertEquals(EPOCH_PLUS_MILLIS.getTime(), Timestamp.toLogical(Timestamp.SCHEMA, TOTAL_MILLIS)); + } + + @Test + public void testToLogicalInvalidSchema() { + assertThrows(DataException.class, + () -> Date.toLogical(Date.builder().name("invalid").build(), 0)); + } +} diff --git a/connect/api/src/test/java/org/apache/kafka/connect/data/ValuesTest.java b/connect/api/src/test/java/org/apache/kafka/connect/data/ValuesTest.java new file mode 100644 index 0000000..3700a6e --- /dev/null +++ b/connect/api/src/test/java/org/apache/kafka/connect/data/ValuesTest.java @@ -0,0 +1,1003 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.connect.data; + +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.connect.data.Schema.Type; +import org.apache.kafka.connect.data.Values.Parser; +import org.apache.kafka.connect.errors.DataException; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; + +import java.math.BigDecimal; +import java.math.BigInteger; +import java.nio.charset.StandardCharsets; +import java.text.SimpleDateFormat; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; + +public class ValuesTest { + + private static final String WHITESPACE = "\n \t \t\n"; + + private static final long MILLIS_PER_DAY = 24 * 60 * 60 * 1000; + + private static final Map STRING_MAP = new LinkedHashMap<>(); + private static final Schema STRING_MAP_SCHEMA = SchemaBuilder.map(Schema.STRING_SCHEMA, Schema.STRING_SCHEMA).schema(); + + private static final Map STRING_SHORT_MAP = new LinkedHashMap<>(); + private static final Schema STRING_SHORT_MAP_SCHEMA = SchemaBuilder.map(Schema.STRING_SCHEMA, Schema.INT16_SCHEMA).schema(); + + private static final Map STRING_INT_MAP = new LinkedHashMap<>(); + private static final Schema STRING_INT_MAP_SCHEMA = SchemaBuilder.map(Schema.STRING_SCHEMA, Schema.INT32_SCHEMA).schema(); + + private static final List INT_LIST = new ArrayList<>(); + private static final Schema INT_LIST_SCHEMA = SchemaBuilder.array(Schema.INT32_SCHEMA).schema(); + + private static final List STRING_LIST = new ArrayList<>(); + private static final Schema STRING_LIST_SCHEMA = SchemaBuilder.array(Schema.STRING_SCHEMA).schema(); + + static { + STRING_MAP.put("foo", "123"); + STRING_MAP.put("bar", "baz"); + STRING_SHORT_MAP.put("foo", (short) 12345); + STRING_SHORT_MAP.put("bar", (short) 0); + STRING_SHORT_MAP.put("baz", (short) -4321); + STRING_INT_MAP.put("foo", 1234567890); + STRING_INT_MAP.put("bar", 0); + STRING_INT_MAP.put("baz", -987654321); + STRING_LIST.add("foo"); + STRING_LIST.add("bar"); + INT_LIST.add(1234567890); + INT_LIST.add(-987654321); + } + + @Test + @Timeout(5) + public void shouldNotEncounterInfiniteLoop() { + // This byte sequence gets parsed as CharacterIterator.DONE and can cause issues if + // comparisons to that character are done to check if the end of a string has been reached. 
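+        // (The bytes -17, -65, -65 are 0xEF 0xBF 0xBF, the UTF-8 encoding of U+FFFF, which is the same value as CharacterIterator.DONE.)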
+ // For more information, see https://issues.apache.org/jira/browse/KAFKA-10574 + byte[] bytes = new byte[] {-17, -65, -65}; + String str = new String(bytes, StandardCharsets.UTF_8); + SchemaAndValue schemaAndValue = Values.parseString(str); + assertEquals(Type.STRING, schemaAndValue.schema().type()); + assertEquals(str, schemaAndValue.value()); + } + + @Test + public void shouldNotParseUnquotedEmbeddedMapKeysAsStrings() { + SchemaAndValue schemaAndValue = Values.parseString("{foo: 3}"); + assertEquals(Type.STRING, schemaAndValue.schema().type()); + assertEquals("{foo: 3}", schemaAndValue.value()); + } + + @Test + public void shouldNotParseUnquotedEmbeddedMapValuesAsStrings() { + SchemaAndValue schemaAndValue = Values.parseString("{3: foo}"); + assertEquals(Type.STRING, schemaAndValue.schema().type()); + assertEquals("{3: foo}", schemaAndValue.value()); + } + + @Test + public void shouldNotParseUnquotedArrayElementsAsStrings() { + SchemaAndValue schemaAndValue = Values.parseString("[foo]"); + assertEquals(Type.STRING, schemaAndValue.schema().type()); + assertEquals("[foo]", schemaAndValue.value()); + } + + @Test + public void shouldNotParseStringsBeginningWithNullAsStrings() { + SchemaAndValue schemaAndValue = Values.parseString("null="); + assertEquals(Type.STRING, schemaAndValue.schema().type()); + assertEquals("null=", schemaAndValue.value()); + } + + @Test + public void shouldParseStringsBeginningWithTrueAsStrings() { + SchemaAndValue schemaAndValue = Values.parseString("true}"); + assertEquals(Type.STRING, schemaAndValue.schema().type()); + assertEquals("true}", schemaAndValue.value()); + } + + @Test + public void shouldParseStringsBeginningWithFalseAsStrings() { + SchemaAndValue schemaAndValue = Values.parseString("false]"); + assertEquals(Type.STRING, schemaAndValue.schema().type()); + assertEquals("false]", schemaAndValue.value()); + } + + @Test + public void shouldParseTrueAsBooleanIfSurroundedByWhitespace() { + SchemaAndValue schemaAndValue = Values.parseString(WHITESPACE + "true" + WHITESPACE); + assertEquals(Type.BOOLEAN, schemaAndValue.schema().type()); + assertEquals(true, schemaAndValue.value()); + } + + @Test + public void shouldParseFalseAsBooleanIfSurroundedByWhitespace() { + SchemaAndValue schemaAndValue = Values.parseString(WHITESPACE + "false" + WHITESPACE); + assertEquals(Type.BOOLEAN, schemaAndValue.schema().type()); + assertEquals(false, schemaAndValue.value()); + } + + @Test + public void shouldParseNullAsNullIfSurroundedByWhitespace() { + SchemaAndValue schemaAndValue = Values.parseString(WHITESPACE + "null" + WHITESPACE); + assertNull(schemaAndValue); + } + + @Test + public void shouldParseBooleanLiteralsEmbeddedInArray() { + SchemaAndValue schemaAndValue = Values.parseString("[true, false]"); + assertEquals(Type.ARRAY, schemaAndValue.schema().type()); + assertEquals(Type.BOOLEAN, schemaAndValue.schema().valueSchema().type()); + assertEquals(Arrays.asList(true, false), schemaAndValue.value()); + } + + @Test + public void shouldParseBooleanLiteralsEmbeddedInMap() { + SchemaAndValue schemaAndValue = Values.parseString("{true: false, false: true}"); + assertEquals(Type.MAP, schemaAndValue.schema().type()); + assertEquals(Type.BOOLEAN, schemaAndValue.schema().keySchema().type()); + assertEquals(Type.BOOLEAN, schemaAndValue.schema().valueSchema().type()); + Map expectedValue = new HashMap<>(); + expectedValue.put(true, false); + expectedValue.put(false, true); + assertEquals(expectedValue, schemaAndValue.value()); + } + + @Test + public void 
shouldNotParseAsMapWithoutCommas() { + SchemaAndValue schemaAndValue = Values.parseString("{6:9 4:20}"); + assertEquals(Type.STRING, schemaAndValue.schema().type()); + assertEquals("{6:9 4:20}", schemaAndValue.value()); + } + + @Test + public void shouldNotParseAsArrayWithoutCommas() { + SchemaAndValue schemaAndValue = Values.parseString("[0 1 2]"); + assertEquals(Type.STRING, schemaAndValue.schema().type()); + assertEquals("[0 1 2]", schemaAndValue.value()); + } + + @Test + public void shouldParseEmptyMap() { + SchemaAndValue schemaAndValue = Values.parseString("{}"); + assertEquals(Type.MAP, schemaAndValue.schema().type()); + assertEquals(Collections.emptyMap(), schemaAndValue.value()); + } + + @Test + public void shouldParseEmptyArray() { + SchemaAndValue schemaAndValue = Values.parseString("[]"); + assertEquals(Type.ARRAY, schemaAndValue.schema().type()); + assertEquals(Collections.emptyList(), schemaAndValue.value()); + } + + @Test + public void shouldNotParseAsMapWithNullKeys() { + SchemaAndValue schemaAndValue = Values.parseString("{null: 3}"); + assertEquals(Type.STRING, schemaAndValue.schema().type()); + assertEquals("{null: 3}", schemaAndValue.value()); + } + + @Test + public void shouldParseNull() { + SchemaAndValue schemaAndValue = Values.parseString("null"); + assertNull(schemaAndValue); + } + + @Test + public void shouldConvertStringOfNull() { + assertRoundTrip(Schema.STRING_SCHEMA, "null"); + } + + @Test + public void shouldParseNullMapValues() { + SchemaAndValue schemaAndValue = Values.parseString("{3: null}"); + assertEquals(Type.MAP, schemaAndValue.schema().type()); + assertEquals(Type.INT8, schemaAndValue.schema().keySchema().type()); + assertEquals(Collections.singletonMap((byte) 3, null), schemaAndValue.value()); + } + + @Test + public void shouldParseNullArrayElements() { + SchemaAndValue schemaAndValue = Values.parseString("[null]"); + assertEquals(Type.ARRAY, schemaAndValue.schema().type()); + assertEquals(Collections.singletonList(null), schemaAndValue.value()); + } + + @Test + public void shouldEscapeStringsWithEmbeddedQuotesAndBackslashes() { + String original = "three\"blind\\\"mice"; + String expected = "three\\\"blind\\\\\\\"mice"; + assertEquals(expected, Values.escape(original)); + } + + @Test + public void shouldConvertNullValue() { + assertRoundTrip(Schema.STRING_SCHEMA, Schema.STRING_SCHEMA, null); + assertRoundTrip(Schema.OPTIONAL_STRING_SCHEMA, Schema.STRING_SCHEMA, null); + } + + @Test + public void shouldConvertBooleanValues() { + assertRoundTrip(Schema.BOOLEAN_SCHEMA, Schema.BOOLEAN_SCHEMA, Boolean.FALSE); + SchemaAndValue resultFalse = roundTrip(Schema.BOOLEAN_SCHEMA, "false"); + assertEquals(Schema.BOOLEAN_SCHEMA, resultFalse.schema()); + assertEquals(Boolean.FALSE, resultFalse.value()); + + assertRoundTrip(Schema.BOOLEAN_SCHEMA, Schema.BOOLEAN_SCHEMA, Boolean.TRUE); + SchemaAndValue resultTrue = roundTrip(Schema.BOOLEAN_SCHEMA, "true"); + assertEquals(Schema.BOOLEAN_SCHEMA, resultTrue.schema()); + assertEquals(Boolean.TRUE, resultTrue.value()); + } + + @Test + public void shouldFailToParseInvalidBooleanValueString() { + assertThrows(DataException.class, () -> Values.convertToBoolean(Schema.STRING_SCHEMA, "\"green\"")); + } + + @Test + public void shouldConvertSimpleString() { + assertRoundTrip(Schema.STRING_SCHEMA, "simple"); + } + + @Test + public void shouldConvertEmptyString() { + assertRoundTrip(Schema.STRING_SCHEMA, ""); + } + + @Test + public void shouldConvertStringWithQuotesAndOtherDelimiterCharacters() { + 
assertRoundTrip(Schema.STRING_SCHEMA, Schema.STRING_SCHEMA, "three\"blind\\\"mice"); + assertRoundTrip(Schema.STRING_SCHEMA, Schema.STRING_SCHEMA, "string with delimiters: <>?,./\\=+-!@#$%^&*(){}[]|;':"); + } + + @Test + public void shouldConvertMapWithStringKeys() { + assertRoundTrip(STRING_MAP_SCHEMA, STRING_MAP_SCHEMA, STRING_MAP); + } + + @Test + public void shouldParseStringOfMapWithStringValuesWithoutWhitespaceAsMap() { + SchemaAndValue result = roundTrip(STRING_MAP_SCHEMA, "{\"foo\":\"123\",\"bar\":\"baz\"}"); + assertEquals(STRING_MAP_SCHEMA, result.schema()); + assertEquals(STRING_MAP, result.value()); + } + + @Test + public void shouldParseStringOfMapWithStringValuesWithWhitespaceAsMap() { + SchemaAndValue result = roundTrip(STRING_MAP_SCHEMA, "{ \"foo\" : \"123\", \n\"bar\" : \"baz\" } "); + assertEquals(STRING_MAP_SCHEMA, result.schema()); + assertEquals(STRING_MAP, result.value()); + } + + @Test + public void shouldConvertMapWithStringKeysAndShortValues() { + assertRoundTrip(STRING_SHORT_MAP_SCHEMA, STRING_SHORT_MAP_SCHEMA, STRING_SHORT_MAP); + } + + @Test + public void shouldParseStringOfMapWithShortValuesWithoutWhitespaceAsMap() { + SchemaAndValue result = roundTrip(STRING_SHORT_MAP_SCHEMA, "{\"foo\":12345,\"bar\":0,\"baz\":-4321}"); + assertEquals(STRING_SHORT_MAP_SCHEMA, result.schema()); + assertEquals(STRING_SHORT_MAP, result.value()); + } + + @Test + public void shouldParseStringOfMapWithShortValuesWithWhitespaceAsMap() { + SchemaAndValue result = roundTrip(STRING_SHORT_MAP_SCHEMA, " { \"foo\" : 12345 , \"bar\" : 0, \"baz\" : -4321 } "); + assertEquals(STRING_SHORT_MAP_SCHEMA, result.schema()); + assertEquals(STRING_SHORT_MAP, result.value()); + } + + @Test + public void shouldConvertMapWithStringKeysAndIntegerValues() { + assertRoundTrip(STRING_INT_MAP_SCHEMA, STRING_INT_MAP_SCHEMA, STRING_INT_MAP); + } + + @Test + public void shouldParseStringOfMapWithIntValuesWithoutWhitespaceAsMap() { + SchemaAndValue result = roundTrip(STRING_INT_MAP_SCHEMA, "{\"foo\":1234567890,\"bar\":0,\"baz\":-987654321}"); + assertEquals(STRING_INT_MAP_SCHEMA, result.schema()); + assertEquals(STRING_INT_MAP, result.value()); + } + + @Test + public void shouldParseStringOfMapWithIntValuesWithWhitespaceAsMap() { + SchemaAndValue result = roundTrip(STRING_INT_MAP_SCHEMA, " { \"foo\" : 1234567890 , \"bar\" : 0, \"baz\" : -987654321 } "); + assertEquals(STRING_INT_MAP_SCHEMA, result.schema()); + assertEquals(STRING_INT_MAP, result.value()); + } + + @Test + public void shouldConvertListWithStringValues() { + assertRoundTrip(STRING_LIST_SCHEMA, STRING_LIST_SCHEMA, STRING_LIST); + } + + @Test + public void shouldConvertListWithIntegerValues() { + assertRoundTrip(INT_LIST_SCHEMA, INT_LIST_SCHEMA, INT_LIST); + } + + /** + * The parsed array has byte values and one int value, so we should return list with single unified type of integers. + */ + @Test + public void shouldConvertStringOfListWithOnlyNumericElementTypesIntoListOfLargestNumericType() { + int thirdValue = Short.MAX_VALUE + 1; + List list = Values.convertToList(Schema.STRING_SCHEMA, "[1, 2, " + thirdValue + "]"); + assertEquals(3, list.size()); + assertEquals(1, ((Number) list.get(0)).intValue()); + assertEquals(2, ((Number) list.get(1)).intValue()); + assertEquals(thirdValue, ((Number) list.get(2)).intValue()); + } + + /** + * The parsed array has byte values and one int value, so we should return list with single unified type of integers. 
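+     * In this case the array also contains a string element, so the parsed list ends up holding elements of different types.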
+ */ + @Test + public void shouldConvertStringOfListWithMixedElementTypesIntoListWithDifferentElementTypes() { + String str = "[1, 2, \"three\"]"; + List list = Values.convertToList(Schema.STRING_SCHEMA, str); + assertEquals(3, list.size()); + assertEquals(1, ((Number) list.get(0)).intValue()); + assertEquals(2, ((Number) list.get(1)).intValue()); + assertEquals("three", list.get(2)); + } + + /** + * We parse into different element types, but cannot infer a common element schema. + */ + @Test + public void shouldParseStringListWithMultipleElementTypesAndReturnListWithNoSchema() { + String str = "[1, 2, 3, \"four\"]"; + SchemaAndValue result = Values.parseString(str); + assertEquals(Type.ARRAY, result.schema().type()); + assertNull(result.schema().valueSchema()); + List list = (List) result.value(); + assertEquals(4, list.size()); + assertEquals(1, ((Number) list.get(0)).intValue()); + assertEquals(2, ((Number) list.get(1)).intValue()); + assertEquals(3, ((Number) list.get(2)).intValue()); + assertEquals("four", list.get(3)); + } + + /** + * We can't infer or successfully parse into a different type, so this returns the same string. + */ + @Test + public void shouldParseStringListWithExtraDelimitersAndReturnString() { + String str = "[1, 2, 3,,,]"; + SchemaAndValue result = Values.parseString(str); + assertEquals(Type.STRING, result.schema().type()); + assertEquals(str, result.value()); + } + + @Test + public void shouldParseTimestampStringAsTimestamp() throws Exception { + String str = "2019-08-23T14:34:54.346Z"; + SchemaAndValue result = Values.parseString(str); + assertEquals(Type.INT64, result.schema().type()); + assertEquals(Timestamp.LOGICAL_NAME, result.schema().name()); + java.util.Date expected = new SimpleDateFormat(Values.ISO_8601_TIMESTAMP_FORMAT_PATTERN).parse(str); + assertEquals(expected, result.value()); + } + + @Test + public void shouldParseDateStringAsDate() throws Exception { + String str = "2019-08-23"; + SchemaAndValue result = Values.parseString(str); + assertEquals(Type.INT32, result.schema().type()); + assertEquals(Date.LOGICAL_NAME, result.schema().name()); + java.util.Date expected = new SimpleDateFormat(Values.ISO_8601_DATE_FORMAT_PATTERN).parse(str); + assertEquals(expected, result.value()); + } + + @Test + public void shouldParseTimeStringAsDate() throws Exception { + String str = "14:34:54.346Z"; + SchemaAndValue result = Values.parseString(str); + assertEquals(Type.INT32, result.schema().type()); + assertEquals(Time.LOGICAL_NAME, result.schema().name()); + java.util.Date expected = new SimpleDateFormat(Values.ISO_8601_TIME_FORMAT_PATTERN).parse(str); + assertEquals(expected, result.value()); + } + + @Test + public void shouldParseTimestampStringWithEscapedColonsAsTimestamp() throws Exception { + String str = "2019-08-23T14\\:34\\:54.346Z"; + SchemaAndValue result = Values.parseString(str); + assertEquals(Type.INT64, result.schema().type()); + assertEquals(Timestamp.LOGICAL_NAME, result.schema().name()); + String expectedStr = "2019-08-23T14:34:54.346Z"; + java.util.Date expected = new SimpleDateFormat(Values.ISO_8601_TIMESTAMP_FORMAT_PATTERN).parse(expectedStr); + assertEquals(expected, result.value()); + } + + @Test + public void shouldParseTimeStringWithEscapedColonsAsDate() throws Exception { + String str = "14\\:34\\:54.346Z"; + SchemaAndValue result = Values.parseString(str); + assertEquals(Type.INT32, result.schema().type()); + assertEquals(Time.LOGICAL_NAME, result.schema().name()); + String expectedStr = "14:34:54.346Z"; + java.util.Date expected = 
new SimpleDateFormat(Values.ISO_8601_TIME_FORMAT_PATTERN).parse(expectedStr); + assertEquals(expected, result.value()); + } + + @Test + public void shouldParseDateStringAsDateInArray() throws Exception { + String dateStr = "2019-08-23"; + String arrayStr = "[" + dateStr + "]"; + SchemaAndValue result = Values.parseString(arrayStr); + assertEquals(Type.ARRAY, result.schema().type()); + Schema elementSchema = result.schema().valueSchema(); + assertEquals(Type.INT32, elementSchema.type()); + assertEquals(Date.LOGICAL_NAME, elementSchema.name()); + java.util.Date expected = new SimpleDateFormat(Values.ISO_8601_DATE_FORMAT_PATTERN).parse(dateStr); + assertEquals(Collections.singletonList(expected), result.value()); + } + + @Test + public void shouldParseTimeStringAsTimeInArray() throws Exception { + String timeStr = "14:34:54.346Z"; + String arrayStr = "[" + timeStr + "]"; + SchemaAndValue result = Values.parseString(arrayStr); + assertEquals(Type.ARRAY, result.schema().type()); + Schema elementSchema = result.schema().valueSchema(); + assertEquals(Type.INT32, elementSchema.type()); + assertEquals(Time.LOGICAL_NAME, elementSchema.name()); + java.util.Date expected = new SimpleDateFormat(Values.ISO_8601_TIME_FORMAT_PATTERN).parse(timeStr); + assertEquals(Collections.singletonList(expected), result.value()); + } + + @Test + public void shouldParseTimestampStringAsTimestampInArray() throws Exception { + String tsStr = "2019-08-23T14:34:54.346Z"; + String arrayStr = "[" + tsStr + "]"; + SchemaAndValue result = Values.parseString(arrayStr); + assertEquals(Type.ARRAY, result.schema().type()); + Schema elementSchema = result.schema().valueSchema(); + assertEquals(Type.INT64, elementSchema.type()); + assertEquals(Timestamp.LOGICAL_NAME, elementSchema.name()); + java.util.Date expected = new SimpleDateFormat(Values.ISO_8601_TIMESTAMP_FORMAT_PATTERN).parse(tsStr); + assertEquals(Collections.singletonList(expected), result.value()); + } + + @Test + public void shouldParseMultipleTimestampStringAsTimestampInArray() throws Exception { + String tsStr1 = "2019-08-23T14:34:54.346Z"; + String tsStr2 = "2019-01-23T15:12:34.567Z"; + String tsStr3 = "2019-04-23T19:12:34.567Z"; + String arrayStr = "[" + tsStr1 + "," + tsStr2 + ", " + tsStr3 + "]"; + SchemaAndValue result = Values.parseString(arrayStr); + assertEquals(Type.ARRAY, result.schema().type()); + Schema elementSchema = result.schema().valueSchema(); + assertEquals(Type.INT64, elementSchema.type()); + assertEquals(Timestamp.LOGICAL_NAME, elementSchema.name()); + java.util.Date expected1 = new SimpleDateFormat(Values.ISO_8601_TIMESTAMP_FORMAT_PATTERN).parse(tsStr1); + java.util.Date expected2 = new SimpleDateFormat(Values.ISO_8601_TIMESTAMP_FORMAT_PATTERN).parse(tsStr2); + java.util.Date expected3 = new SimpleDateFormat(Values.ISO_8601_TIMESTAMP_FORMAT_PATTERN).parse(tsStr3); + assertEquals(Arrays.asList(expected1, expected2, expected3), result.value()); + } + + @Test + public void shouldParseQuotedTimeStringAsTimeInMap() throws Exception { + String keyStr = "k1"; + String timeStr = "14:34:54.346Z"; + String mapStr = "{\"" + keyStr + "\":\"" + timeStr + "\"}"; + SchemaAndValue result = Values.parseString(mapStr); + assertEquals(Type.MAP, result.schema().type()); + Schema keySchema = result.schema().keySchema(); + Schema valueSchema = result.schema().valueSchema(); + assertEquals(Type.STRING, keySchema.type()); + assertEquals(Type.INT32, valueSchema.type()); + assertEquals(Time.LOGICAL_NAME, valueSchema.name()); + java.util.Date expected = new 
SimpleDateFormat(Values.ISO_8601_TIME_FORMAT_PATTERN).parse(timeStr); + assertEquals(Collections.singletonMap(keyStr, expected), result.value()); + } + + @Test + public void shouldParseTimeStringAsTimeInMap() throws Exception { + String keyStr = "k1"; + String timeStr = "14:34:54.346Z"; + String mapStr = "{\"" + keyStr + "\":" + timeStr + "}"; + SchemaAndValue result = Values.parseString(mapStr); + assertEquals(Type.MAP, result.schema().type()); + Schema keySchema = result.schema().keySchema(); + Schema valueSchema = result.schema().valueSchema(); + assertEquals(Type.STRING, keySchema.type()); + assertEquals(Type.INT32, valueSchema.type()); + assertEquals(Time.LOGICAL_NAME, valueSchema.name()); + java.util.Date expected = new SimpleDateFormat(Values.ISO_8601_TIME_FORMAT_PATTERN).parse(timeStr); + assertEquals(Collections.singletonMap(keyStr, expected), result.value()); + } + + /** + * This is technically invalid JSON, and we don't want to simply ignore the blank elements. + */ + @Test + public void shouldFailToConvertToListFromStringWithExtraDelimiters() { + assertThrows(DataException.class, () -> Values.convertToList(Schema.STRING_SCHEMA, "[1, 2, 3,,,]")); + } + + /** + * Schema of type ARRAY requires a schema for the values, but Connect has no union or "any" schema type. + * Therefore, we can't represent this. + */ + @Test + public void shouldFailToConvertToListFromStringWithNonCommonElementTypeAndBlankElement() { + assertThrows(DataException.class, () -> Values.convertToList(Schema.STRING_SCHEMA, "[1, 2, 3, \"four\",,,]")); + } + + /** + * This is technically invalid JSON, and we don't want to simply ignore the blank entry. + */ + @Test + public void shouldFailToParseStringOfMapWithIntValuesWithBlankEntry() { + assertThrows(DataException.class, + () -> Values.convertToMap(Schema.STRING_SCHEMA, " { \"foo\" : 1234567890 ,, \"bar\" : 0, \"baz\" : -987654321 } ")); + } + + /** + * This is technically invalid JSON, and we don't want to simply ignore the malformed entry. + */ + @Test + public void shouldFailToParseStringOfMalformedMap() { + assertThrows(DataException.class, + () -> Values.convertToMap(Schema.STRING_SCHEMA, " { \"foo\" : 1234567890 , \"a\", \"bar\" : 0, \"baz\" : -987654321 } ")); + } + + /** + * This is technically invalid JSON, and we don't want to simply ignore the blank entries. + */ + @Test + public void shouldFailToParseStringOfMapWithIntValuesWithOnlyBlankEntries() { + assertThrows(DataException.class, () -> Values.convertToMap(Schema.STRING_SCHEMA, " { ,, , , } ")); + } + + /** + * This is technically invalid JSON, and we don't want to simply ignore the blank entry. 
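+     * Conversion should instead fail with a DataException rather than silently dropping the blank entry.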
+ */ + @Test + public void shouldFailToParseStringOfMapWithIntValuesWithBlankEntries() { + assertThrows(DataException.class, + () -> Values.convertToMap(Schema.STRING_SCHEMA, " { \"foo\" : \"1234567890\" ,, \"bar\" : \"0\", \"baz\" : \"boz\" } ")); + } + + @Test + public void shouldConsumeMultipleTokens() { + String value = "a:b:c:d:e:f:g:h"; + Parser parser = new Parser(value); + String firstFive = parser.next(5); + assertEquals("a:b:c", firstFive); + assertEquals(":", parser.next()); + assertEquals("d", parser.next()); + assertEquals(":", parser.next()); + String lastEight = parser.next(8); // only 7 remain + assertNull(lastEight); + assertEquals("e", parser.next()); + } + + @Test + public void shouldParseStringsWithoutDelimiters() { + //assertParsed(""); + assertParsed(" "); + assertParsed("simple"); + assertParsed("simple string"); + assertParsed("simple \n\t\bstring"); + assertParsed("'simple' string"); + assertParsed("si\\mple"); + assertParsed("si\\\\mple"); + } + + @Test + public void shouldParseStringsWithEscapedDelimiters() { + assertParsed("si\\\"mple"); + assertParsed("si\\{mple"); + assertParsed("si\\}mple"); + assertParsed("si\\]mple"); + assertParsed("si\\[mple"); + assertParsed("si\\:mple"); + assertParsed("si\\,mple"); + } + + @Test + public void shouldParseStringsWithSingleDelimiter() { + assertParsed("a{b", "a", "{", "b"); + assertParsed("a}b", "a", "}", "b"); + assertParsed("a[b", "a", "[", "b"); + assertParsed("a]b", "a", "]", "b"); + assertParsed("a:b", "a", ":", "b"); + assertParsed("a,b", "a", ",", "b"); + assertParsed("a\"b", "a", "\"", "b"); + assertParsed("{b", "{", "b"); + assertParsed("}b", "}", "b"); + assertParsed("[b", "[", "b"); + assertParsed("]b", "]", "b"); + assertParsed(":b", ":", "b"); + assertParsed(",b", ",", "b"); + assertParsed("\"b", "\"", "b"); + assertParsed("{", "{"); + assertParsed("}", "}"); + assertParsed("[", "["); + assertParsed("]", "]"); + assertParsed(":", ":"); + assertParsed(",", ","); + assertParsed("\"", "\""); + } + + @Test + public void shouldParseStringsWithMultipleDelimiters() { + assertParsed("\"simple\" string", "\"", "simple", "\"", " string"); + assertParsed("a{bc}d", "a", "{", "bc", "}", "d"); + assertParsed("a { b c } d", "a ", "{", " b c ", "}", " d"); + assertParsed("a { b c } d", "a ", "{", " b c ", "}", " d"); + } + + @Test + public void shouldConvertTimeValues() { + java.util.Date current = new java.util.Date(); + long currentMillis = current.getTime() % MILLIS_PER_DAY; + + // java.util.Date - just copy + java.util.Date t1 = Values.convertToTime(Time.SCHEMA, current); + assertEquals(current, t1); + + // java.util.Date as a Timestamp - discard the date and keep just day's milliseconds + t1 = Values.convertToTime(Timestamp.SCHEMA, current); + assertEquals(new java.util.Date(currentMillis), t1); + + // ISO8601 strings - currently broken because tokenization breaks at colon + + // Millis as string + java.util.Date t3 = Values.convertToTime(Time.SCHEMA, Long.toString(currentMillis)); + assertEquals(currentMillis, t3.getTime()); + + // Millis as long + java.util.Date t4 = Values.convertToTime(Time.SCHEMA, currentMillis); + assertEquals(currentMillis, t4.getTime()); + } + + @Test + public void shouldConvertDateValues() { + java.util.Date current = new java.util.Date(); + long currentMillis = current.getTime() % MILLIS_PER_DAY; + long days = current.getTime() / MILLIS_PER_DAY; + + // java.util.Date - just copy + java.util.Date d1 = Values.convertToDate(Date.SCHEMA, current); + assertEquals(current, d1); + + // java.util.Date 
as a Timestamp - discard the day's milliseconds and keep the date + java.util.Date currentDate = new java.util.Date(current.getTime() - currentMillis); + d1 = Values.convertToDate(Timestamp.SCHEMA, currentDate); + assertEquals(currentDate, d1); + + // ISO8601 strings - currently broken because tokenization breaks at colon + + // Days as string + java.util.Date d3 = Values.convertToDate(Date.SCHEMA, Long.toString(days)); + assertEquals(currentDate, d3); + + // Days as long + java.util.Date d4 = Values.convertToDate(Date.SCHEMA, days); + assertEquals(currentDate, d4); + } + + @Test + public void shouldConvertTimestampValues() { + java.util.Date current = new java.util.Date(); + long currentMillis = current.getTime() % MILLIS_PER_DAY; + + // java.util.Date - just copy + java.util.Date ts1 = Values.convertToTimestamp(Timestamp.SCHEMA, current); + assertEquals(current, ts1); + + // java.util.Date as a Timestamp - discard the day's milliseconds and keep the date + java.util.Date currentDate = new java.util.Date(current.getTime() - currentMillis); + ts1 = Values.convertToTimestamp(Date.SCHEMA, currentDate); + assertEquals(currentDate, ts1); + + // java.util.Date as a Time - discard the date and keep the day's milliseconds + ts1 = Values.convertToTimestamp(Time.SCHEMA, currentMillis); + assertEquals(new java.util.Date(currentMillis), ts1); + + // ISO8601 strings - currently broken because tokenization breaks at colon + + // Millis as string + java.util.Date ts3 = Values.convertToTimestamp(Timestamp.SCHEMA, Long.toString(current.getTime())); + assertEquals(current, ts3); + + // Millis as long + java.util.Date ts4 = Values.convertToTimestamp(Timestamp.SCHEMA, current.getTime()); + assertEquals(current, ts4); + } + + @Test + public void canConsume() { + } + + @Test + public void shouldParseBigIntegerAsDecimalWithZeroScale() { + BigInteger value = BigInteger.valueOf(Long.MAX_VALUE).add(new BigInteger("1")); + SchemaAndValue schemaAndValue = Values.parseString( + String.valueOf(value) + ); + assertEquals(Decimal.schema(0), schemaAndValue.schema()); + assertTrue(schemaAndValue.value() instanceof BigDecimal); + assertEquals(value, ((BigDecimal) schemaAndValue.value()).unscaledValue()); + value = BigInteger.valueOf(Long.MIN_VALUE).subtract(new BigInteger("1")); + schemaAndValue = Values.parseString( + String.valueOf(value) + ); + assertEquals(Decimal.schema(0), schemaAndValue.schema()); + assertTrue(schemaAndValue.value() instanceof BigDecimal); + assertEquals(value, ((BigDecimal) schemaAndValue.value()).unscaledValue()); + } + + @Test + public void shouldParseByteAsInt8() { + Byte value = Byte.MAX_VALUE; + SchemaAndValue schemaAndValue = Values.parseString( + String.valueOf(value) + ); + assertEquals(Schema.INT8_SCHEMA, schemaAndValue.schema()); + assertTrue(schemaAndValue.value() instanceof Byte); + assertEquals(value.byteValue(), ((Byte) schemaAndValue.value()).byteValue()); + value = Byte.MIN_VALUE; + schemaAndValue = Values.parseString( + String.valueOf(value) + ); + assertEquals(Schema.INT8_SCHEMA, schemaAndValue.schema()); + assertTrue(schemaAndValue.value() instanceof Byte); + assertEquals(value.byteValue(), ((Byte) schemaAndValue.value()).byteValue()); + } + + @Test + public void shouldParseShortAsInt16() { + Short value = Short.MAX_VALUE; + SchemaAndValue schemaAndValue = Values.parseString( + String.valueOf(value) + ); + assertEquals(Schema.INT16_SCHEMA, schemaAndValue.schema()); + assertTrue(schemaAndValue.value() instanceof Short); + assertEquals(value.shortValue(), ((Short) 
schemaAndValue.value()).shortValue()); + value = Short.MIN_VALUE; + schemaAndValue = Values.parseString( + String.valueOf(value) + ); + assertEquals(Schema.INT16_SCHEMA, schemaAndValue.schema()); + assertTrue(schemaAndValue.value() instanceof Short); + assertEquals(value.shortValue(), ((Short) schemaAndValue.value()).shortValue()); + } + + @Test + public void shouldParseIntegerAsInt32() { + Integer value = Integer.MAX_VALUE; + SchemaAndValue schemaAndValue = Values.parseString( + String.valueOf(value) + ); + assertEquals(Schema.INT32_SCHEMA, schemaAndValue.schema()); + assertTrue(schemaAndValue.value() instanceof Integer); + assertEquals(value.intValue(), ((Integer) schemaAndValue.value()).intValue()); + value = Integer.MIN_VALUE; + schemaAndValue = Values.parseString( + String.valueOf(value) + ); + assertEquals(Schema.INT32_SCHEMA, schemaAndValue.schema()); + assertTrue(schemaAndValue.value() instanceof Integer); + assertEquals(value.intValue(), ((Integer) schemaAndValue.value()).intValue()); + } + + @Test + public void shouldParseLongAsInt64() { + Long value = Long.MAX_VALUE; + SchemaAndValue schemaAndValue = Values.parseString( + String.valueOf(value) + ); + assertEquals(Schema.INT64_SCHEMA, schemaAndValue.schema()); + assertTrue(schemaAndValue.value() instanceof Long); + assertEquals(value.longValue(), ((Long) schemaAndValue.value()).longValue()); + value = Long.MIN_VALUE; + schemaAndValue = Values.parseString( + String.valueOf(value) + ); + assertEquals(Schema.INT64_SCHEMA, schemaAndValue.schema()); + assertTrue(schemaAndValue.value() instanceof Long); + assertEquals(value.longValue(), ((Long) schemaAndValue.value()).longValue()); + } + + @Test + public void shouldParseFloatAsFloat32() { + Float value = Float.MAX_VALUE; + SchemaAndValue schemaAndValue = Values.parseString( + String.valueOf(value) + ); + assertEquals(Schema.FLOAT32_SCHEMA, schemaAndValue.schema()); + assertTrue(schemaAndValue.value() instanceof Float); + assertEquals(value, (Float) schemaAndValue.value(), 0); + value = -Float.MAX_VALUE; + schemaAndValue = Values.parseString( + String.valueOf(value) + ); + assertEquals(Schema.FLOAT32_SCHEMA, schemaAndValue.schema()); + assertTrue(schemaAndValue.value() instanceof Float); + assertEquals(value, (Float) schemaAndValue.value(), 0); + } + + @Test + public void shouldParseDoubleAsFloat64() { + Double value = Double.MAX_VALUE; + SchemaAndValue schemaAndValue = Values.parseString( + String.valueOf(value) + ); + assertEquals(Schema.FLOAT64_SCHEMA, schemaAndValue.schema()); + assertTrue(schemaAndValue.value() instanceof Double); + assertEquals(value, (Double) schemaAndValue.value(), 0); + value = -Double.MAX_VALUE; + schemaAndValue = Values.parseString( + String.valueOf(value) + ); + assertEquals(Schema.FLOAT64_SCHEMA, schemaAndValue.schema()); + assertTrue(schemaAndValue.value() instanceof Double); + assertEquals(value, (Double) schemaAndValue.value(), 0); + } + + protected void assertParsed(String input) { + assertParsed(input, input); + } + + protected void assertParsed(String input, String... 
expectedTokens) { + Parser parser = new Parser(input); + if (!parser.hasNext()) { + assertEquals(1, expectedTokens.length); + assertTrue(expectedTokens[0].isEmpty()); + return; + } + + for (String expectedToken : expectedTokens) { + assertTrue(parser.hasNext()); + int position = parser.mark(); + assertEquals(expectedToken, parser.next()); + assertEquals(position + expectedToken.length(), parser.position()); + assertEquals(expectedToken, parser.previous()); + parser.rewindTo(position); + assertEquals(position, parser.position()); + assertEquals(expectedToken, parser.next()); + int newPosition = parser.mark(); + assertEquals(position + expectedToken.length(), newPosition); + assertEquals(expectedToken, parser.previous()); + } + assertFalse(parser.hasNext()); + + // Rewind and try consuming expected tokens ... + parser.rewindTo(0); + assertConsumable(parser, expectedTokens); + + // Parse again and try consuming expected tokens ... + parser = new Parser(input); + assertConsumable(parser, expectedTokens); + } + + protected void assertConsumable(Parser parser, String... expectedTokens) { + for (String expectedToken : expectedTokens) { + if (!Utils.isBlank(expectedToken)) { + int position = parser.mark(); + assertTrue(parser.canConsume(expectedToken.trim())); + parser.rewindTo(position); + assertTrue(parser.canConsume(expectedToken.trim(), true)); + parser.rewindTo(position); + assertTrue(parser.canConsume(expectedToken, false)); + } + } + } + + protected SchemaAndValue roundTrip(Schema desiredSchema, String currentValue) { + return roundTrip(desiredSchema, new SchemaAndValue(Schema.STRING_SCHEMA, currentValue)); + } + + protected SchemaAndValue roundTrip(Schema desiredSchema, SchemaAndValue input) { + String serialized = Values.convertToString(input.schema(), input.value()); + if (input != null && input.value() != null) { + assertNotNull(serialized); + } + if (desiredSchema == null) { + desiredSchema = Values.inferSchema(input); + assertNotNull(desiredSchema); + } + Object newValue = null; + Schema newSchema = null; + switch (desiredSchema.type()) { + case STRING: + newValue = Values.convertToString(Schema.STRING_SCHEMA, serialized); + break; + case INT8: + newValue = Values.convertToByte(Schema.STRING_SCHEMA, serialized); + break; + case INT16: + newValue = Values.convertToShort(Schema.STRING_SCHEMA, serialized); + break; + case INT32: + newValue = Values.convertToInteger(Schema.STRING_SCHEMA, serialized); + break; + case INT64: + newValue = Values.convertToLong(Schema.STRING_SCHEMA, serialized); + break; + case FLOAT32: + newValue = Values.convertToFloat(Schema.STRING_SCHEMA, serialized); + break; + case FLOAT64: + newValue = Values.convertToDouble(Schema.STRING_SCHEMA, serialized); + break; + case BOOLEAN: + newValue = Values.convertToBoolean(Schema.STRING_SCHEMA, serialized); + break; + case ARRAY: + newValue = Values.convertToList(Schema.STRING_SCHEMA, serialized); + break; + case MAP: + newValue = Values.convertToMap(Schema.STRING_SCHEMA, serialized); + break; + case STRUCT: + newValue = Values.convertToStruct(Schema.STRING_SCHEMA, serialized); + break; + case BYTES: + fail("unexpected schema type"); + break; + } + newSchema = Values.inferSchema(newValue); + return new SchemaAndValue(newSchema, newValue); + } + + protected void assertRoundTrip(Schema schema, String value) { + assertRoundTrip(schema, Schema.STRING_SCHEMA, value); + } + + protected void assertRoundTrip(Schema schema, Schema currentSchema, Object value) { + SchemaAndValue result = roundTrip(schema, new 
SchemaAndValue(currentSchema, value)); + + if (value == null) { + assertNull(result.schema()); + assertNull(result.value()); + } else { + assertEquals(value, result.value()); + assertEquals(schema, result.schema()); + + SchemaAndValue result2 = roundTrip(result.schema(), result); + assertEquals(schema, result2.schema()); + assertEquals(value, result2.value()); + assertEquals(result, result2); + } + } +} diff --git a/connect/api/src/test/java/org/apache/kafka/connect/header/ConnectHeaderTest.java b/connect/api/src/test/java/org/apache/kafka/connect/header/ConnectHeaderTest.java new file mode 100644 index 0000000..8a84d44 --- /dev/null +++ b/connect/api/src/test/java/org/apache/kafka/connect/header/ConnectHeaderTest.java @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.header; + +import org.apache.kafka.connect.data.Schema; +import org.apache.kafka.connect.data.SchemaAndValue; +import org.apache.kafka.connect.data.SchemaBuilder; +import org.apache.kafka.connect.data.Struct; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertSame; + +public class ConnectHeaderTest { + + private String key; + private ConnectHeader header; + + @BeforeEach + public void beforeEach() { + key = "key"; + withString("value"); + } + + protected Header withValue(Schema schema, Object value) { + header = new ConnectHeader(key, new SchemaAndValue(schema, value)); + return header; + } + + protected Header withString(String value) { + return withValue(Schema.STRING_SCHEMA, value); + } + + @Test + public void shouldAllowNullValues() { + withValue(Schema.OPTIONAL_STRING_SCHEMA, null); + } + + @Test + public void shouldAllowNullSchema() { + withValue(null, null); + assertNull(header.schema()); + assertNull(header.value()); + + String value = "non-null value"; + withValue(null, value); + assertNull(header.schema()); + assertSame(value, header.value()); + } + + @Test + public void shouldAllowNonNullValue() { + String value = "non-null value"; + withValue(Schema.STRING_SCHEMA, value); + assertSame(Schema.STRING_SCHEMA, header.schema()); + assertEquals(value, header.value()); + + withValue(Schema.BOOLEAN_SCHEMA, true); + assertSame(Schema.BOOLEAN_SCHEMA, header.schema()); + assertEquals(true, header.value()); + } + + @Test + public void shouldGetSchemaFromStruct() { + Schema schema = SchemaBuilder.struct() + .field("foo", Schema.STRING_SCHEMA) + .field("bar", Schema.INT32_SCHEMA) + .build(); + Struct value = new Struct(schema); + value.put("foo", "value"); + 
value.put("bar", 100); + withValue(null, value); + assertSame(schema, header.schema()); + assertSame(value, header.value()); + } + + @Test + public void shouldSatisfyEquals() { + String value = "non-null value"; + Header h1 = withValue(Schema.STRING_SCHEMA, value); + assertSame(Schema.STRING_SCHEMA, header.schema()); + assertEquals(value, header.value()); + + Header h2 = withValue(Schema.STRING_SCHEMA, value); + assertEquals(h1, h2); + assertEquals(h1.hashCode(), h2.hashCode()); + + Header h3 = withValue(Schema.INT8_SCHEMA, 100); + assertNotEquals(h3, h2); + } +} \ No newline at end of file diff --git a/connect/api/src/test/java/org/apache/kafka/connect/header/ConnectHeadersTest.java b/connect/api/src/test/java/org/apache/kafka/connect/header/ConnectHeadersTest.java new file mode 100644 index 0000000..b9b9174 --- /dev/null +++ b/connect/api/src/test/java/org/apache/kafka/connect/header/ConnectHeadersTest.java @@ -0,0 +1,575 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.header; + +import org.apache.kafka.connect.data.Date; +import org.apache.kafka.connect.data.Decimal; +import org.apache.kafka.connect.data.Schema; +import org.apache.kafka.connect.data.Schema.Type; +import org.apache.kafka.connect.data.SchemaAndValue; +import org.apache.kafka.connect.data.SchemaBuilder; +import org.apache.kafka.connect.data.Struct; +import org.apache.kafka.connect.data.Time; +import org.apache.kafka.connect.data.Timestamp; +import org.apache.kafka.connect.data.Values; +import org.apache.kafka.connect.errors.DataException; +import org.apache.kafka.connect.header.Headers.HeaderTransform; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.math.BigDecimal; +import java.math.RoundingMode; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Calendar; +import java.util.Collections; +import java.util.GregorianCalendar; +import java.util.HashMap; +import java.util.Iterator; +import java.util.TimeZone; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNotSame; +import static org.junit.jupiter.api.Assertions.assertSame; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; + +public class ConnectHeadersTest { + + private static final GregorianCalendar EPOCH_PLUS_TEN_THOUSAND_DAYS; + private static final GregorianCalendar EPOCH_PLUS_TEN_THOUSAND_MILLIS; + + static { + EPOCH_PLUS_TEN_THOUSAND_DAYS = 
new GregorianCalendar(1970, Calendar.JANUARY, 1, 0, 0, 0);
+        EPOCH_PLUS_TEN_THOUSAND_DAYS.setTimeZone(TimeZone.getTimeZone("UTC"));
+        EPOCH_PLUS_TEN_THOUSAND_DAYS.add(Calendar.DATE, 10000);
+
+        EPOCH_PLUS_TEN_THOUSAND_MILLIS = new GregorianCalendar(1970, Calendar.JANUARY, 1, 0, 0, 0);
+        EPOCH_PLUS_TEN_THOUSAND_MILLIS.setTimeZone(TimeZone.getTimeZone("UTC"));
+        EPOCH_PLUS_TEN_THOUSAND_MILLIS.add(Calendar.MILLISECOND, 10000);
+    }
+
+    private ConnectHeaders headers;
+    private Iterator<Header>
            iter; + private String key; + private String other; + + @BeforeEach + public void beforeEach() { + headers = new ConnectHeaders(); + key = "k1"; + other = "other key"; + } + + @Test + public void shouldNotAllowNullKey() { + assertThrows(NullPointerException.class, + () -> headers.add(null, "value", Schema.STRING_SCHEMA)); + } + + protected void populate(Headers headers) { + headers.addBoolean(key, true); + headers.addInt(key, 0); + headers.addString(other, "other value"); + headers.addString(key, null); + headers.addString(key, "third"); + } + + @Test + public void shouldBeEquals() { + Headers other = new ConnectHeaders(); + assertEquals(headers, other); + assertEquals(headers.hashCode(), other.hashCode()); + + populate(headers); + assertNotEquals(headers, other); + assertNotEquals(headers.hashCode(), other.hashCode()); + + populate(other); + assertEquals(headers, other); + assertEquals(headers.hashCode(), other.hashCode()); + + headers.addString("wow", "some value"); + assertNotEquals(headers, other); + } + + @Test + public void shouldHaveToString() { + // empty + assertNotNull(headers.toString()); + + // not empty + populate(headers); + assertNotNull(headers.toString()); + } + + @Test + public void shouldRetainLatestWhenEmpty() { + headers.retainLatest(other); + headers.retainLatest(key); + headers.retainLatest(); + assertTrue(headers.isEmpty()); + } + + @Test + public void shouldAddMultipleHeadersWithSameKeyAndRetainLatest() { + populate(headers); + + Header header = headers.lastWithName(key); + assertHeader(header, key, Schema.STRING_SCHEMA, "third"); + + iter = headers.allWithName(key); + assertNextHeader(iter, key, Schema.BOOLEAN_SCHEMA, true); + assertNextHeader(iter, key, Schema.INT32_SCHEMA, 0); + assertNextHeader(iter, key, Schema.OPTIONAL_STRING_SCHEMA, null); + assertNextHeader(iter, key, Schema.STRING_SCHEMA, "third"); + assertNoNextHeader(iter); + + iter = headers.allWithName(other); + assertOnlyNextHeader(iter, other, Schema.STRING_SCHEMA, "other value"); + + headers.retainLatest(other); + assertOnlySingleHeader(other, Schema.STRING_SCHEMA, "other value"); + + headers.retainLatest(key); + assertOnlySingleHeader(key, Schema.STRING_SCHEMA, "third"); + + headers.retainLatest(); + assertOnlySingleHeader(other, Schema.STRING_SCHEMA, "other value"); + assertOnlySingleHeader(key, Schema.STRING_SCHEMA, "third"); + } + + @Test + public void shouldAddHeadersWithPrimitiveValues() { + String key = "k1"; + headers.addBoolean(key, true); + headers.addByte(key, (byte) 0); + headers.addShort(key, (short) 0); + headers.addInt(key, 0); + headers.addLong(key, 0); + headers.addFloat(key, 1.0f); + headers.addDouble(key, 1.0d); + headers.addString(key, null); + headers.addString(key, "third"); + } + + @Test + public void shouldAddHeadersWithNullObjectValuesWithOptionalSchema() { + addHeader("k1", Schema.BOOLEAN_SCHEMA, true); + addHeader("k2", Schema.STRING_SCHEMA, "hello"); + addHeader("k3", Schema.OPTIONAL_STRING_SCHEMA, null); + } + + @Test + public void shouldNotAddHeadersWithNullObjectValuesWithNonOptionalSchema() { + attemptAndFailToAddHeader("k1", Schema.BOOLEAN_SCHEMA, null); + attemptAndFailToAddHeader("k2", Schema.STRING_SCHEMA, null); + } + + @Test + public void shouldNotAddHeadersWithObjectValuesAndMismatchedSchema() { + attemptAndFailToAddHeader("k1", Schema.BOOLEAN_SCHEMA, "wrong"); + attemptAndFailToAddHeader("k2", Schema.OPTIONAL_STRING_SCHEMA, 0L); + } + + @Test + public void shouldRemoveAllHeadersWithSameKeyWhenEmpty() { + headers.remove(key); + 
assertNoHeaderWithKey(key); + } + + @Test + public void shouldRemoveAllHeadersWithSameKey() { + populate(headers); + + iter = headers.allWithName(key); + assertContainsHeader(key, Schema.BOOLEAN_SCHEMA, true); + assertContainsHeader(key, Schema.INT32_SCHEMA, 0); + assertContainsHeader(key, Schema.STRING_SCHEMA, "third"); + assertOnlySingleHeader(other, Schema.STRING_SCHEMA, "other value"); + + headers.remove(key); + assertNoHeaderWithKey(key); + assertOnlySingleHeader(other, Schema.STRING_SCHEMA, "other value"); + } + + @Test + public void shouldRemoveAllHeaders() { + populate(headers); + + iter = headers.allWithName(key); + assertContainsHeader(key, Schema.BOOLEAN_SCHEMA, true); + assertContainsHeader(key, Schema.INT32_SCHEMA, 0); + assertContainsHeader(key, Schema.STRING_SCHEMA, "third"); + assertOnlySingleHeader(other, Schema.STRING_SCHEMA, "other value"); + + headers.clear(); + assertNoHeaderWithKey(key); + assertNoHeaderWithKey(other); + assertEquals(0, headers.size()); + assertTrue(headers.isEmpty()); + } + + @Test + public void shouldTransformHeadersWhenEmpty() { + headers.apply(appendToKey("-suffix")); + headers.apply(key, appendToKey("-suffix")); + assertTrue(headers.isEmpty()); + } + + @Test + public void shouldTransformHeaders() { + populate(headers); + + iter = headers.allWithName(key); + assertNextHeader(iter, key, Schema.BOOLEAN_SCHEMA, true); + assertNextHeader(iter, key, Schema.INT32_SCHEMA, 0); + assertNextHeader(iter, key, Schema.OPTIONAL_STRING_SCHEMA, null); + assertNextHeader(iter, key, Schema.STRING_SCHEMA, "third"); + assertNoNextHeader(iter); + + iter = headers.allWithName(other); + assertOnlyNextHeader(iter, other, Schema.STRING_SCHEMA, "other value"); + + // Transform the headers + assertEquals(5, headers.size()); + headers.apply(appendToKey("-suffix")); + assertEquals(5, headers.size()); + + assertNoHeaderWithKey(key); + assertNoHeaderWithKey(other); + + String altKey = key + "-suffix"; + iter = headers.allWithName(altKey); + assertNextHeader(iter, altKey, Schema.BOOLEAN_SCHEMA, true); + assertNextHeader(iter, altKey, Schema.INT32_SCHEMA, 0); + assertNextHeader(iter, altKey, Schema.OPTIONAL_STRING_SCHEMA, null); + assertNextHeader(iter, altKey, Schema.STRING_SCHEMA, "third"); + assertNoNextHeader(iter); + + iter = headers.allWithName(other + "-suffix"); + assertOnlyNextHeader(iter, other + "-suffix", Schema.STRING_SCHEMA, "other value"); + } + + @Test + public void shouldTransformHeadersWithKey() { + populate(headers); + + iter = headers.allWithName(key); + assertNextHeader(iter, key, Schema.BOOLEAN_SCHEMA, true); + assertNextHeader(iter, key, Schema.INT32_SCHEMA, 0); + assertNextHeader(iter, key, Schema.OPTIONAL_STRING_SCHEMA, null); + assertNextHeader(iter, key, Schema.STRING_SCHEMA, "third"); + assertNoNextHeader(iter); + + iter = headers.allWithName(other); + assertOnlyNextHeader(iter, other, Schema.STRING_SCHEMA, "other value"); + + // Transform the headers + assertEquals(5, headers.size()); + headers.apply(key, appendToKey("-suffix")); + assertEquals(5, headers.size()); + + assertNoHeaderWithKey(key); + + String altKey = key + "-suffix"; + iter = headers.allWithName(altKey); + assertNextHeader(iter, altKey, Schema.BOOLEAN_SCHEMA, true); + assertNextHeader(iter, altKey, Schema.INT32_SCHEMA, 0); + assertNextHeader(iter, altKey, Schema.OPTIONAL_STRING_SCHEMA, null); + assertNextHeader(iter, altKey, Schema.STRING_SCHEMA, "third"); + assertNoNextHeader(iter); + + iter = headers.allWithName(other); + assertOnlyNextHeader(iter, other, Schema.STRING_SCHEMA, "other 
value"); + } + + @Test + public void shouldTransformAndRemoveHeaders() { + populate(headers); + + iter = headers.allWithName(key); + assertNextHeader(iter, key, Schema.BOOLEAN_SCHEMA, true); + assertNextHeader(iter, key, Schema.INT32_SCHEMA, 0); + assertNextHeader(iter, key, Schema.OPTIONAL_STRING_SCHEMA, null); + assertNextHeader(iter, key, Schema.STRING_SCHEMA, "third"); + assertNoNextHeader(iter); + + iter = headers.allWithName(other); + assertOnlyNextHeader(iter, other, Schema.STRING_SCHEMA, "other value"); + + // Transform the headers + assertEquals(5, headers.size()); + headers.apply(key, removeHeadersOfType(Type.STRING)); + assertEquals(3, headers.size()); + + iter = headers.allWithName(key); + assertNextHeader(iter, key, Schema.BOOLEAN_SCHEMA, true); + assertNextHeader(iter, key, Schema.INT32_SCHEMA, 0); + assertNoNextHeader(iter); + + assertHeader(headers.lastWithName(key), key, Schema.INT32_SCHEMA, 0); + + iter = headers.allWithName(other); + assertOnlyNextHeader(iter, other, Schema.STRING_SCHEMA, "other value"); + + // Transform the headers + assertEquals(3, headers.size()); + headers.apply(removeHeadersOfType(Type.STRING)); + assertEquals(2, headers.size()); + + assertNoHeaderWithKey(other); + + iter = headers.allWithName(key); + assertNextHeader(iter, key, Schema.BOOLEAN_SCHEMA, true); + assertNextHeader(iter, key, Schema.INT32_SCHEMA, 0); + assertNoNextHeader(iter); + } + + protected HeaderTransform appendToKey(final String suffix) { + return header -> header.rename(header.key() + suffix); + } + + protected HeaderTransform removeHeadersOfType(final Type type) { + return header -> { + Schema schema = header.schema(); + if (schema != null && schema.type() == type) { + return null; + } + return header; + }; + } + + @Test + public void shouldValidateBuildInTypes() { + assertSchemaMatches(Schema.OPTIONAL_BOOLEAN_SCHEMA, null); + assertSchemaMatches(Schema.OPTIONAL_BYTES_SCHEMA, null); + assertSchemaMatches(Schema.OPTIONAL_INT8_SCHEMA, null); + assertSchemaMatches(Schema.OPTIONAL_INT16_SCHEMA, null); + assertSchemaMatches(Schema.OPTIONAL_INT32_SCHEMA, null); + assertSchemaMatches(Schema.OPTIONAL_INT64_SCHEMA, null); + assertSchemaMatches(Schema.OPTIONAL_FLOAT32_SCHEMA, null); + assertSchemaMatches(Schema.OPTIONAL_FLOAT64_SCHEMA, null); + assertSchemaMatches(Schema.OPTIONAL_STRING_SCHEMA, null); + assertSchemaMatches(Schema.BOOLEAN_SCHEMA, true); + assertSchemaMatches(Schema.BYTES_SCHEMA, new byte[]{}); + assertSchemaMatches(Schema.INT8_SCHEMA, (byte) 0); + assertSchemaMatches(Schema.INT16_SCHEMA, (short) 0); + assertSchemaMatches(Schema.INT32_SCHEMA, 0); + assertSchemaMatches(Schema.INT64_SCHEMA, 0L); + assertSchemaMatches(Schema.FLOAT32_SCHEMA, 1.0f); + assertSchemaMatches(Schema.FLOAT64_SCHEMA, 1.0d); + assertSchemaMatches(Schema.STRING_SCHEMA, "value"); + assertSchemaMatches(SchemaBuilder.array(Schema.STRING_SCHEMA), new ArrayList()); + assertSchemaMatches(SchemaBuilder.array(Schema.STRING_SCHEMA), Collections.singletonList("value")); + assertSchemaMatches(SchemaBuilder.map(Schema.STRING_SCHEMA, Schema.INT32_SCHEMA), new HashMap()); + assertSchemaMatches(SchemaBuilder.map(Schema.STRING_SCHEMA, Schema.INT32_SCHEMA), Collections.singletonMap("a", 0)); + Schema emptyStructSchema = SchemaBuilder.struct(); + assertSchemaMatches(emptyStructSchema, new Struct(emptyStructSchema)); + Schema structSchema = SchemaBuilder.struct().field("foo", Schema.OPTIONAL_BOOLEAN_SCHEMA).field("bar", Schema.STRING_SCHEMA) + .schema(); + assertSchemaMatches(structSchema, new 
Struct(structSchema).put("foo", true).put("bar", "v")); + } + + @Test + public void shouldValidateLogicalTypes() { + assertSchemaMatches(Decimal.schema(3), new BigDecimal(100.00)); + assertSchemaMatches(Time.SCHEMA, new java.util.Date()); + assertSchemaMatches(Date.SCHEMA, new java.util.Date()); + assertSchemaMatches(Timestamp.SCHEMA, new java.util.Date()); + } + + @Test + public void shouldNotValidateNullValuesWithBuiltInTypes() { + assertSchemaDoesNotMatch(Schema.BOOLEAN_SCHEMA, null); + assertSchemaDoesNotMatch(Schema.BYTES_SCHEMA, null); + assertSchemaDoesNotMatch(Schema.INT8_SCHEMA, null); + assertSchemaDoesNotMatch(Schema.INT16_SCHEMA, null); + assertSchemaDoesNotMatch(Schema.INT32_SCHEMA, null); + assertSchemaDoesNotMatch(Schema.INT64_SCHEMA, null); + assertSchemaDoesNotMatch(Schema.FLOAT32_SCHEMA, null); + assertSchemaDoesNotMatch(Schema.FLOAT64_SCHEMA, null); + assertSchemaDoesNotMatch(Schema.STRING_SCHEMA, null); + assertSchemaDoesNotMatch(SchemaBuilder.array(Schema.STRING_SCHEMA), null); + assertSchemaDoesNotMatch(SchemaBuilder.map(Schema.STRING_SCHEMA, Schema.INT32_SCHEMA), null); + assertSchemaDoesNotMatch(SchemaBuilder.struct(), null); + } + + @Test + public void shouldNotValidateMismatchedValuesWithBuiltInTypes() { + assertSchemaDoesNotMatch(Schema.BOOLEAN_SCHEMA, 0L); + assertSchemaDoesNotMatch(Schema.BYTES_SCHEMA, "oops"); + assertSchemaDoesNotMatch(Schema.INT8_SCHEMA, 1.0f); + assertSchemaDoesNotMatch(Schema.INT16_SCHEMA, 1.0f); + assertSchemaDoesNotMatch(Schema.INT32_SCHEMA, 0L); + assertSchemaDoesNotMatch(Schema.INT64_SCHEMA, 1.0f); + assertSchemaDoesNotMatch(Schema.FLOAT32_SCHEMA, 1L); + assertSchemaDoesNotMatch(Schema.FLOAT64_SCHEMA, 1L); + assertSchemaDoesNotMatch(Schema.STRING_SCHEMA, true); + assertSchemaDoesNotMatch(SchemaBuilder.array(Schema.STRING_SCHEMA), "value"); + assertSchemaDoesNotMatch(SchemaBuilder.map(Schema.STRING_SCHEMA, Schema.INT32_SCHEMA), "value"); + assertSchemaDoesNotMatch(SchemaBuilder.struct(), new ArrayList()); + } + + @Test + public void shouldAddDate() { + java.util.Date dateObj = EPOCH_PLUS_TEN_THOUSAND_DAYS.getTime(); + int days = Date.fromLogical(Date.SCHEMA, dateObj); + headers.addDate(key, dateObj); + Header header = headers.lastWithName(key); + assertEquals(days, (int) Values.convertToInteger(header.schema(), header.value())); + assertSame(dateObj, Values.convertToDate(header.schema(), header.value())); + + headers.addInt(other, days); + header = headers.lastWithName(other); + assertEquals(days, (int) Values.convertToInteger(header.schema(), header.value())); + assertEquals(dateObj, Values.convertToDate(header.schema(), header.value())); + } + + @Test + public void shouldAddTime() { + java.util.Date dateObj = EPOCH_PLUS_TEN_THOUSAND_MILLIS.getTime(); + long millis = Time.fromLogical(Time.SCHEMA, dateObj); + headers.addTime(key, dateObj); + Header header = headers.lastWithName(key); + assertEquals(millis, (long) Values.convertToLong(header.schema(), header.value())); + assertSame(dateObj, Values.convertToTime(header.schema(), header.value())); + + headers.addLong(other, millis); + header = headers.lastWithName(other); + assertEquals(millis, (long) Values.convertToLong(header.schema(), header.value())); + assertEquals(dateObj, Values.convertToTime(header.schema(), header.value())); + } + + @Test + public void shouldAddTimestamp() { + java.util.Date dateObj = EPOCH_PLUS_TEN_THOUSAND_MILLIS.getTime(); + long millis = Timestamp.fromLogical(Timestamp.SCHEMA, dateObj); + headers.addTimestamp(key, dateObj); + Header header = 
headers.lastWithName(key); + assertEquals(millis, (long) Values.convertToLong(header.schema(), header.value())); + assertSame(dateObj, Values.convertToTimestamp(header.schema(), header.value())); + + headers.addLong(other, millis); + header = headers.lastWithName(other); + assertEquals(millis, (long) Values.convertToLong(header.schema(), header.value())); + assertEquals(dateObj, Values.convertToTimestamp(header.schema(), header.value())); + } + + @Test + public void shouldAddDecimal() { + BigDecimal value = new BigDecimal("3.038573478e+3"); + headers.addDecimal(key, value); + Header header = headers.lastWithName(key); + assertEquals(value.doubleValue(), Values.convertToDouble(header.schema(), header.value()), 0.00001d); + assertEquals(value, Values.convertToDecimal(header.schema(), header.value(), value.scale())); + + value = value.setScale(3, RoundingMode.DOWN); + BigDecimal decimal = Values.convertToDecimal(header.schema(), header.value(), value.scale()); + assertEquals(value, decimal.setScale(value.scale(), RoundingMode.DOWN)); + } + + @Test + public void shouldDuplicateAndAlwaysReturnEquivalentButDifferentObject() { + assertEquals(headers, headers.duplicate()); + assertNotSame(headers, headers.duplicate()); + } + + @Test + public void shouldNotAllowToAddNullHeader() { + final ConnectHeaders headers = new ConnectHeaders(); + assertThrows(NullPointerException.class, () -> headers.add(null)); + } + + @Test + public void shouldThrowNpeWhenAddingCollectionWithNullHeader() { + final Iterable
<Header> header = Arrays.asList(new ConnectHeader[1]); + assertThrows(NullPointerException.class, () -> new ConnectHeaders(header)); + } + + protected void assertSchemaMatches(Schema schema, Object value) { + headers.checkSchemaMatches(new SchemaAndValue(schema.schema(), value)); + } + + protected void assertSchemaDoesNotMatch(Schema schema, Object value) { + try { + assertSchemaMatches(schema, value); + fail("Should have failed to validate value '" + value + "' and schema: " + schema); + } catch (DataException e) { + // expected + } + } + + protected void attemptAndFailToAddHeader(String key, Schema schema, Object value) { + try { + headers.add(key, value, schema); + fail("Should have failed to add header with key '" + key + "', value '" + value + "', and schema: " + schema); + } catch (DataException e) { + // expected + } + } + + protected void addHeader(String key, Schema schema, Object value) { + headers.add(key, value, schema); + Header header = headers.lastWithName(key); + assertNotNull(header); + assertHeader(header, key, schema, value); + } + + protected void assertNoHeaderWithKey(String key) { + assertNoNextHeader(headers.allWithName(key)); + } + + protected void assertContainsHeader(String key, Schema schema, Object value) { + Header expected = new ConnectHeader(key, new SchemaAndValue(schema, value)); + Iterator
<Header> iter = headers.allWithName(key); + while (iter.hasNext()) { + Header header = iter.next(); + if (header.equals(expected)) + return; + } + fail("Should have found header " + expected); + } + + protected void assertOnlySingleHeader(String key, Schema schema, Object value) { + assertOnlyNextHeader(headers.allWithName(key), key, schema, value); + } + + protected void assertOnlyNextHeader(Iterator
<Header> iter, String key, Schema schema, Object value) { + assertNextHeader(iter, key, schema, value); + assertNoNextHeader(iter); + } + + protected void assertNextHeader(Iterator
<Header> iter, String key, Schema schema, Object value) { + Header header = iter.next(); + assertHeader(header, key, schema, value); + } + + protected void assertNoNextHeader(Iterator<Header>
            iter) { + assertFalse(iter.hasNext()); + } + + protected void assertHeader(Header header, String key, Schema schema, Object value) { + assertNotNull(header); + assertSame(schema, header.schema()); + assertSame(value, header.value()); + } +} diff --git a/connect/api/src/test/java/org/apache/kafka/connect/sink/SinkConnectorTest.java b/connect/api/src/test/java/org/apache/kafka/connect/sink/SinkConnectorTest.java new file mode 100644 index 0000000..2cf2278 --- /dev/null +++ b/connect/api/src/test/java/org/apache/kafka/connect/sink/SinkConnectorTest.java @@ -0,0 +1,146 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.sink; + +import java.util.List; +import java.util.Map; + +import org.apache.kafka.common.config.ConfigDef; +import org.apache.kafka.connect.connector.ConnectorContext; +import org.apache.kafka.connect.connector.ConnectorTest; +import org.apache.kafka.connect.connector.Task; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertSame; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class SinkConnectorTest extends ConnectorTest { + + @Override + protected TestSinkConnectorContext createContext() { + return new TestSinkConnectorContext(); + } + + @Override + protected TestSinkConnector createConnector() { + return new TestSinkConnector(); + } + + private static class TestSinkConnectorContext implements SinkConnectorContext { + + @Override + public void requestTaskReconfiguration() { + // Unexpected in these tests + throw new UnsupportedOperationException(); + } + + @Override + public void raiseError(Exception e) { + // Unexpected in these tests + throw new UnsupportedOperationException(); + } + } + + protected static class TestSinkConnector extends SinkConnector implements ConnectorTest.AssertableConnector { + + public static final String VERSION = "an entirely different version"; + + private boolean initialized; + private List> taskConfigs; + private Map props; + private boolean started; + private boolean stopped; + + @Override + public String version() { + return VERSION; + } + + @Override + public void initialize(ConnectorContext ctx) { + super.initialize(ctx); + initialized = true; + this.taskConfigs = null; + } + + @Override + public void initialize(ConnectorContext ctx, List> taskConfigs) { + super.initialize(ctx, taskConfigs); + initialized = true; + this.taskConfigs = taskConfigs; + } + + @Override + public void start(Map props) { + this.props = props; + started = true; + } + + @Override + public Class taskClass() { + return null; + } + + @Override + public List> taskConfigs(int maxTasks) { + return null; + } + + @Override + public void stop() { + stopped = true; + } + + @Override + public 
ConfigDef config() { + return new ConfigDef() + .define("required", ConfigDef.Type.STRING, ConfigDef.Importance.HIGH, "required docs") + .define("optional", ConfigDef.Type.STRING, "defaultVal", ConfigDef.Importance.HIGH, "optional docs"); + } + + @Override + public void assertContext(ConnectorContext expected) { + assertSame(expected, context); + assertSame(expected, context()); + } + + @Override + public void assertInitialized() { + assertTrue(initialized); + } + + @Override + public void assertTaskConfigs(List> expectedTaskConfigs) { + assertSame(expectedTaskConfigs, taskConfigs); + } + + @Override + public void assertStarted(boolean expected) { + assertEquals(expected, started); + } + + @Override + public void assertStopped(boolean expected) { + assertEquals(expected, stopped); + } + + @Override + public void assertProperties(Map expected) { + assertSame(expected, props); + } + } +} \ No newline at end of file diff --git a/connect/api/src/test/java/org/apache/kafka/connect/sink/SinkRecordTest.java b/connect/api/src/test/java/org/apache/kafka/connect/sink/SinkRecordTest.java new file mode 100644 index 0000000..02a6b2a --- /dev/null +++ b/connect/api/src/test/java/org/apache/kafka/connect/sink/SinkRecordTest.java @@ -0,0 +1,128 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.connect.sink; + +import org.apache.kafka.common.record.TimestampType; +import org.apache.kafka.connect.data.Schema; +import org.apache.kafka.connect.data.Values; +import org.apache.kafka.connect.header.ConnectHeaders; +import org.apache.kafka.connect.header.Header; +import org.apache.kafka.connect.header.Headers; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNotSame; +import static org.junit.jupiter.api.Assertions.assertSame; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class SinkRecordTest { + + private static final String TOPIC_NAME = "myTopic"; + private static final Integer PARTITION_NUMBER = 0; + private static final long KAFKA_OFFSET = 0L; + private static final Long KAFKA_TIMESTAMP = 0L; + private static final TimestampType TS_TYPE = TimestampType.CREATE_TIME; + + private SinkRecord record; + + @BeforeEach + public void beforeEach() { + record = new SinkRecord(TOPIC_NAME, PARTITION_NUMBER, Schema.STRING_SCHEMA, "key", Schema.BOOLEAN_SCHEMA, false, KAFKA_OFFSET, + KAFKA_TIMESTAMP, TS_TYPE, null); + } + + @Test + public void shouldCreateSinkRecordWithHeaders() { + Headers headers = new ConnectHeaders().addString("h1", "hv1").addBoolean("h2", true); + record = new SinkRecord(TOPIC_NAME, PARTITION_NUMBER, Schema.STRING_SCHEMA, "key", Schema.BOOLEAN_SCHEMA, false, KAFKA_OFFSET, + KAFKA_TIMESTAMP, TS_TYPE, headers); + assertNotNull(record.headers()); + assertSame(headers, record.headers()); + assertFalse(record.headers().isEmpty()); + } + + @Test + public void shouldCreateSinkRecordWithEmptyHeaders() { + assertEquals(TOPIC_NAME, record.topic()); + assertEquals(PARTITION_NUMBER, record.kafkaPartition()); + assertEquals(Schema.STRING_SCHEMA, record.keySchema()); + assertEquals("key", record.key()); + assertEquals(Schema.BOOLEAN_SCHEMA, record.valueSchema()); + assertEquals(false, record.value()); + assertEquals(KAFKA_OFFSET, record.kafkaOffset()); + assertEquals(KAFKA_TIMESTAMP, record.timestamp()); + assertEquals(TS_TYPE, record.timestampType()); + assertNotNull(record.headers()); + assertTrue(record.headers().isEmpty()); + } + + @Test + public void shouldDuplicateRecordAndCloneHeaders() { + SinkRecord duplicate = record.newRecord(TOPIC_NAME, PARTITION_NUMBER, Schema.STRING_SCHEMA, "key", Schema.BOOLEAN_SCHEMA, false, + KAFKA_TIMESTAMP); + + assertEquals(TOPIC_NAME, duplicate.topic()); + assertEquals(PARTITION_NUMBER, duplicate.kafkaPartition()); + assertEquals(Schema.STRING_SCHEMA, duplicate.keySchema()); + assertEquals("key", duplicate.key()); + assertEquals(Schema.BOOLEAN_SCHEMA, duplicate.valueSchema()); + assertEquals(false, duplicate.value()); + assertEquals(KAFKA_OFFSET, duplicate.kafkaOffset()); + assertEquals(KAFKA_TIMESTAMP, duplicate.timestamp()); + assertEquals(TS_TYPE, duplicate.timestampType()); + assertNotNull(duplicate.headers()); + assertTrue(duplicate.headers().isEmpty()); + assertNotSame(record.headers(), duplicate.headers()); + assertEquals(record.headers(), duplicate.headers()); + } + + + @Test + public void shouldDuplicateRecordUsingNewHeaders() { + Headers newHeaders = new ConnectHeaders().addString("h3", "hv3"); + SinkRecord duplicate = record.newRecord(TOPIC_NAME, 
PARTITION_NUMBER, Schema.STRING_SCHEMA, "key", Schema.BOOLEAN_SCHEMA, false, + KAFKA_TIMESTAMP, newHeaders); + + assertEquals(TOPIC_NAME, duplicate.topic()); + assertEquals(PARTITION_NUMBER, duplicate.kafkaPartition()); + assertEquals(Schema.STRING_SCHEMA, duplicate.keySchema()); + assertEquals("key", duplicate.key()); + assertEquals(Schema.BOOLEAN_SCHEMA, duplicate.valueSchema()); + assertEquals(false, duplicate.value()); + assertEquals(KAFKA_OFFSET, duplicate.kafkaOffset()); + assertEquals(KAFKA_TIMESTAMP, duplicate.timestamp()); + assertEquals(TS_TYPE, duplicate.timestampType()); + assertNotNull(duplicate.headers()); + assertEquals(newHeaders, duplicate.headers()); + assertSame(newHeaders, duplicate.headers()); + assertNotSame(record.headers(), duplicate.headers()); + assertNotEquals(record.headers(), duplicate.headers()); + } + + @Test + public void shouldModifyRecordHeader() { + assertTrue(record.headers().isEmpty()); + record.headers().addInt("intHeader", 100); + assertEquals(1, record.headers().size()); + Header header = record.headers().lastWithName("intHeader"); + assertEquals(100, (int) Values.convertToInteger(header.schema(), header.value())); + } +} \ No newline at end of file diff --git a/connect/api/src/test/java/org/apache/kafka/connect/source/SourceConnectorTest.java b/connect/api/src/test/java/org/apache/kafka/connect/source/SourceConnectorTest.java new file mode 100644 index 0000000..3359b1a --- /dev/null +++ b/connect/api/src/test/java/org/apache/kafka/connect/source/SourceConnectorTest.java @@ -0,0 +1,152 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.connect.source; + +import java.util.List; +import java.util.Map; + +import org.apache.kafka.common.config.ConfigDef; +import org.apache.kafka.connect.connector.ConnectorContext; +import org.apache.kafka.connect.connector.ConnectorTest; +import org.apache.kafka.connect.connector.Task; +import org.apache.kafka.connect.storage.OffsetStorageReader; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertSame; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class SourceConnectorTest extends ConnectorTest { + + @Override + protected ConnectorContext createContext() { + return new TestSourceConnectorContext(); + } + + @Override + protected TestSourceConnector createConnector() { + return new TestSourceConnector(); + } + + private static class TestSourceConnectorContext implements SourceConnectorContext { + + @Override + public void requestTaskReconfiguration() { + // Unexpected in these tests + throw new UnsupportedOperationException(); + } + + @Override + public void raiseError(Exception e) { + // Unexpected in these tests + throw new UnsupportedOperationException(); + } + + @Override + public OffsetStorageReader offsetStorageReader() { + return null; + } + } + + private static class TestSourceConnector extends SourceConnector implements AssertableConnector { + + public static final String VERSION = "an entirely different version"; + + private boolean initialized; + private List> taskConfigs; + private Map props; + private boolean started; + private boolean stopped; + + @Override + public String version() { + return VERSION; + } + + @Override + public void initialize(ConnectorContext ctx) { + super.initialize(ctx); + initialized = true; + this.taskConfigs = null; + } + + @Override + public void initialize(ConnectorContext ctx, List> taskConfigs) { + super.initialize(ctx, taskConfigs); + initialized = true; + this.taskConfigs = taskConfigs; + } + + @Override + public void start(Map props) { + this.props = props; + started = true; + } + + @Override + public Class taskClass() { + return null; + } + + @Override + public List> taskConfigs(int maxTasks) { + return null; + } + + @Override + public void stop() { + stopped = true; + } + + @Override + public ConfigDef config() { + return new ConfigDef() + .define("required", ConfigDef.Type.STRING, ConfigDef.Importance.HIGH, "required docs") + .define("optional", ConfigDef.Type.STRING, "defaultVal", ConfigDef.Importance.HIGH, "optional docs"); + } + + @Override + public void assertContext(ConnectorContext expected) { + assertSame(expected, context); + assertSame(expected, context()); + } + + @Override + public void assertInitialized() { + assertTrue(initialized); + } + + @Override + public void assertTaskConfigs(List> expectedTaskConfigs) { + assertSame(expectedTaskConfigs, taskConfigs); + } + + @Override + public void assertStarted(boolean expected) { + assertEquals(expected, started); + } + + @Override + public void assertStopped(boolean expected) { + assertEquals(expected, stopped); + } + + @Override + public void assertProperties(Map expected) { + assertSame(expected, props); + } + } +} \ No newline at end of file diff --git a/connect/api/src/test/java/org/apache/kafka/connect/source/SourceRecordTest.java b/connect/api/src/test/java/org/apache/kafka/connect/source/SourceRecordTest.java new file mode 100644 index 0000000..f859005 --- /dev/null +++ b/connect/api/src/test/java/org/apache/kafka/connect/source/SourceRecordTest.java @@ -0,0 
+1,129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.source; + +import org.apache.kafka.connect.data.Schema; +import org.apache.kafka.connect.data.Values; +import org.apache.kafka.connect.header.ConnectHeaders; +import org.apache.kafka.connect.header.Header; +import org.apache.kafka.connect.header.Headers; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.util.Collections; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNotSame; +import static org.junit.jupiter.api.Assertions.assertSame; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class SourceRecordTest { + + private static final Map SOURCE_PARTITION = Collections.singletonMap("src", "abc"); + private static final Map SOURCE_OFFSET = Collections.singletonMap("offset", "1"); + private static final String TOPIC_NAME = "myTopic"; + private static final Integer PARTITION_NUMBER = 0; + private static final Long KAFKA_TIMESTAMP = 0L; + + private SourceRecord record; + + @BeforeEach + public void beforeEach() { + record = new SourceRecord(SOURCE_PARTITION, SOURCE_OFFSET, TOPIC_NAME, PARTITION_NUMBER, Schema.STRING_SCHEMA, "key", + Schema.BOOLEAN_SCHEMA, false, KAFKA_TIMESTAMP, null); + } + + @Test + public void shouldCreateSinkRecordWithHeaders() { + Headers headers = new ConnectHeaders().addString("h1", "hv1").addBoolean("h2", true); + record = new SourceRecord(SOURCE_PARTITION, SOURCE_OFFSET, TOPIC_NAME, PARTITION_NUMBER, Schema.STRING_SCHEMA, "key", + Schema.BOOLEAN_SCHEMA, false, KAFKA_TIMESTAMP, headers); + assertNotNull(record.headers()); + assertSame(headers, record.headers()); + assertFalse(record.headers().isEmpty()); + } + + @Test + public void shouldCreateSinkRecordWithEmtpyHeaders() { + assertEquals(SOURCE_PARTITION, record.sourcePartition()); + assertEquals(SOURCE_OFFSET, record.sourceOffset()); + assertEquals(TOPIC_NAME, record.topic()); + assertEquals(PARTITION_NUMBER, record.kafkaPartition()); + assertEquals(Schema.STRING_SCHEMA, record.keySchema()); + assertEquals("key", record.key()); + assertEquals(Schema.BOOLEAN_SCHEMA, record.valueSchema()); + assertEquals(false, record.value()); + assertEquals(KAFKA_TIMESTAMP, record.timestamp()); + assertNotNull(record.headers()); + assertTrue(record.headers().isEmpty()); + } + + @Test + public void shouldDuplicateRecordAndCloneHeaders() { + SourceRecord duplicate = record.newRecord(TOPIC_NAME, PARTITION_NUMBER, Schema.STRING_SCHEMA, "key", 
Schema.BOOLEAN_SCHEMA, false, + KAFKA_TIMESTAMP); + + assertEquals(SOURCE_PARTITION, duplicate.sourcePartition()); + assertEquals(SOURCE_OFFSET, duplicate.sourceOffset()); + assertEquals(TOPIC_NAME, duplicate.topic()); + assertEquals(PARTITION_NUMBER, duplicate.kafkaPartition()); + assertEquals(Schema.STRING_SCHEMA, duplicate.keySchema()); + assertEquals("key", duplicate.key()); + assertEquals(Schema.BOOLEAN_SCHEMA, duplicate.valueSchema()); + assertEquals(false, duplicate.value()); + assertEquals(KAFKA_TIMESTAMP, duplicate.timestamp()); + assertNotNull(duplicate.headers()); + assertTrue(duplicate.headers().isEmpty()); + assertNotSame(record.headers(), duplicate.headers()); + assertEquals(record.headers(), duplicate.headers()); + } + + @Test + public void shouldDuplicateRecordUsingNewHeaders() { + Headers newHeaders = new ConnectHeaders().addString("h3", "hv3"); + SourceRecord duplicate = record.newRecord(TOPIC_NAME, PARTITION_NUMBER, Schema.STRING_SCHEMA, "key", Schema.BOOLEAN_SCHEMA, false, + KAFKA_TIMESTAMP, newHeaders); + + assertEquals(SOURCE_PARTITION, duplicate.sourcePartition()); + assertEquals(SOURCE_OFFSET, duplicate.sourceOffset()); + assertEquals(TOPIC_NAME, duplicate.topic()); + assertEquals(PARTITION_NUMBER, duplicate.kafkaPartition()); + assertEquals(Schema.STRING_SCHEMA, duplicate.keySchema()); + assertEquals("key", duplicate.key()); + assertEquals(Schema.BOOLEAN_SCHEMA, duplicate.valueSchema()); + assertEquals(false, duplicate.value()); + assertEquals(KAFKA_TIMESTAMP, duplicate.timestamp()); + assertNotNull(duplicate.headers()); + assertEquals(newHeaders, duplicate.headers()); + assertSame(newHeaders, duplicate.headers()); + assertNotSame(record.headers(), duplicate.headers()); + assertNotEquals(record.headers(), duplicate.headers()); + } + + @Test + public void shouldModifyRecordHeader() { + assertTrue(record.headers().isEmpty()); + record.headers().addInt("intHeader", 100); + assertEquals(1, record.headers().size()); + Header header = record.headers().lastWithName("intHeader"); + assertEquals(100, (int) Values.convertToInteger(header.schema(), header.value())); + } +} \ No newline at end of file diff --git a/connect/api/src/test/java/org/apache/kafka/connect/storage/ConverterTypeTest.java b/connect/api/src/test/java/org/apache/kafka/connect/storage/ConverterTypeTest.java new file mode 100644 index 0000000..f88ca91 --- /dev/null +++ b/connect/api/src/test/java/org/apache/kafka/connect/storage/ConverterTypeTest.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.connect.storage; + +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class ConverterTypeTest { + + @Test + public void shouldFindByName() { + for (ConverterType type : ConverterType.values()) { + assertEquals(type, ConverterType.withName(type.getName())); + } + } +} \ No newline at end of file diff --git a/connect/api/src/test/java/org/apache/kafka/connect/storage/SimpleHeaderConverterTest.java b/connect/api/src/test/java/org/apache/kafka/connect/storage/SimpleHeaderConverterTest.java new file mode 100644 index 0000000..8c9306c --- /dev/null +++ b/connect/api/src/test/java/org/apache/kafka/connect/storage/SimpleHeaderConverterTest.java @@ -0,0 +1,236 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.storage; + +import org.apache.kafka.connect.data.Schema; +import org.apache.kafka.connect.data.SchemaAndValue; +import org.apache.kafka.connect.data.SchemaBuilder; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.util.ArrayList; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; + +public class SimpleHeaderConverterTest { + + private static final String TOPIC = "topic"; + private static final String HEADER = "header"; + + private static final Map STRING_MAP = new LinkedHashMap<>(); + private static final Schema STRING_MAP_SCHEMA = SchemaBuilder.map(Schema.STRING_SCHEMA, Schema.STRING_SCHEMA).schema(); + + private static final Map STRING_SHORT_MAP = new LinkedHashMap<>(); + private static final Schema STRING_SHORT_MAP_SCHEMA = SchemaBuilder.map(Schema.STRING_SCHEMA, Schema.INT16_SCHEMA).schema(); + + private static final Map STRING_INT_MAP = new LinkedHashMap<>(); + private static final Schema STRING_INT_MAP_SCHEMA = SchemaBuilder.map(Schema.STRING_SCHEMA, Schema.INT32_SCHEMA).schema(); + + private static final List INT_LIST = new ArrayList<>(); + private static final Schema INT_LIST_SCHEMA = SchemaBuilder.array(Schema.INT32_SCHEMA).schema(); + + private static final List STRING_LIST = new ArrayList<>(); + private static final Schema STRING_LIST_SCHEMA = SchemaBuilder.array(Schema.STRING_SCHEMA).schema(); + + static { + STRING_MAP.put("foo", "123"); + STRING_MAP.put("bar", "baz"); + STRING_SHORT_MAP.put("foo", (short) 12345); + STRING_SHORT_MAP.put("bar", (short) 0); + STRING_SHORT_MAP.put("baz", (short) -4321); + STRING_INT_MAP.put("foo", 1234567890); + STRING_INT_MAP.put("bar", 0); + 
STRING_INT_MAP.put("baz", -987654321); + STRING_LIST.add("foo"); + STRING_LIST.add("bar"); + INT_LIST.add(1234567890); + INT_LIST.add(-987654321); + } + + private SimpleHeaderConverter converter; + + @BeforeEach + public void beforeEach() { + converter = new SimpleHeaderConverter(); + } + + @Test + public void shouldConvertNullValue() { + assertRoundTrip(Schema.STRING_SCHEMA, null); + assertRoundTrip(Schema.OPTIONAL_STRING_SCHEMA, null); + } + + @Test + public void shouldConvertSimpleString() { + assertRoundTrip(Schema.STRING_SCHEMA, "simple"); + } + + @Test + public void shouldConvertEmptyString() { + assertRoundTrip(Schema.STRING_SCHEMA, ""); + } + + @Test + public void shouldConvertStringWithQuotesAndOtherDelimiterCharacters() { + assertRoundTrip(Schema.STRING_SCHEMA, "three\"blind\\\"mice"); + assertRoundTrip(Schema.STRING_SCHEMA, "string with delimiters: <>?,./\\=+-!@#$%^&*(){}[]|;':"); + } + + @Test + public void shouldConvertMapWithStringKeys() { + assertRoundTrip(STRING_MAP_SCHEMA, STRING_MAP); + } + + @Test + public void shouldParseStringOfMapWithStringValuesWithoutWhitespaceAsMap() { + SchemaAndValue result = roundTrip(Schema.STRING_SCHEMA, "{\"foo\":\"123\",\"bar\":\"baz\"}"); + assertEquals(STRING_MAP_SCHEMA, result.schema()); + assertEquals(STRING_MAP, result.value()); + } + + @Test + public void shouldParseStringOfMapWithStringValuesWithWhitespaceAsMap() { + SchemaAndValue result = roundTrip(Schema.STRING_SCHEMA, "{ \"foo\" : \"123\", \n\"bar\" : \"baz\" } "); + assertEquals(STRING_MAP_SCHEMA, result.schema()); + assertEquals(STRING_MAP, result.value()); + } + + @Test + public void shouldConvertMapWithStringKeysAndShortValues() { + assertRoundTrip(STRING_SHORT_MAP_SCHEMA, STRING_SHORT_MAP); + } + + @Test + public void shouldParseStringOfMapWithShortValuesWithoutWhitespaceAsMap() { + SchemaAndValue result = roundTrip(Schema.STRING_SCHEMA, "{\"foo\":12345,\"bar\":0,\"baz\":-4321}"); + assertEquals(STRING_SHORT_MAP_SCHEMA, result.schema()); + assertEquals(STRING_SHORT_MAP, result.value()); + } + + @Test + public void shouldParseStringOfMapWithShortValuesWithWhitespaceAsMap() { + SchemaAndValue result = roundTrip(Schema.STRING_SCHEMA, " { \"foo\" : 12345 , \"bar\" : 0, \"baz\" : -4321 } "); + assertEquals(STRING_SHORT_MAP_SCHEMA, result.schema()); + assertEquals(STRING_SHORT_MAP, result.value()); + } + + @Test + public void shouldConvertMapWithStringKeysAndIntegerValues() { + assertRoundTrip(STRING_INT_MAP_SCHEMA, STRING_INT_MAP); + } + + @Test + public void shouldParseStringOfMapWithIntValuesWithoutWhitespaceAsMap() { + SchemaAndValue result = roundTrip(Schema.STRING_SCHEMA, "{\"foo\":1234567890,\"bar\":0,\"baz\":-987654321}"); + assertEquals(STRING_INT_MAP_SCHEMA, result.schema()); + assertEquals(STRING_INT_MAP, result.value()); + } + + @Test + public void shouldParseStringOfMapWithIntValuesWithWhitespaceAsMap() { + SchemaAndValue result = roundTrip(Schema.STRING_SCHEMA, " { \"foo\" : 1234567890 , \"bar\" : 0, \"baz\" : -987654321 } "); + assertEquals(STRING_INT_MAP_SCHEMA, result.schema()); + assertEquals(STRING_INT_MAP, result.value()); + } + + @Test + public void shouldConvertListWithStringValues() { + assertRoundTrip(STRING_LIST_SCHEMA, STRING_LIST); + } + + @Test + public void shouldConvertListWithIntegerValues() { + assertRoundTrip(INT_LIST_SCHEMA, INT_LIST); + } + + @Test + public void shouldConvertMapWithStringKeysAndMixedValuesToMap() { + Map map = new LinkedHashMap<>(); + map.put("foo", "bar"); + map.put("baz", (short) 3456); + SchemaAndValue result = roundTrip(null, 
map); + assertEquals(Schema.Type.MAP, result.schema().type()); + assertEquals(Schema.Type.STRING, result.schema().keySchema().type()); + assertNull(result.schema().valueSchema()); + assertEquals(map, result.value()); + } + + @Test + public void shouldConvertListWithMixedValuesToListWithoutSchema() { + List list = new ArrayList<>(); + list.add("foo"); + list.add((short) 13344); + SchemaAndValue result = roundTrip(null, list); + assertEquals(Schema.Type.ARRAY, result.schema().type()); + assertNull(result.schema().valueSchema()); + assertEquals(list, result.value()); + } + + @Test + public void shouldConvertEmptyMapToMap() { + Map map = new LinkedHashMap<>(); + SchemaAndValue result = roundTrip(null, map); + assertEquals(Schema.Type.MAP, result.schema().type()); + assertNull(result.schema().keySchema()); + assertNull(result.schema().valueSchema()); + assertEquals(map, result.value()); + } + + @Test + public void shouldConvertEmptyListToList() { + List list = new ArrayList<>(); + SchemaAndValue result = roundTrip(null, list); + assertEquals(Schema.Type.ARRAY, result.schema().type()); + assertNull(result.schema().valueSchema()); + assertEquals(list, result.value()); + } + + protected SchemaAndValue roundTrip(Schema schema, Object input) { + byte[] serialized = converter.fromConnectHeader(TOPIC, HEADER, schema, input); + return converter.toConnectHeader(TOPIC, HEADER, serialized); + } + + protected void assertRoundTrip(Schema schema, Object value) { + byte[] serialized = converter.fromConnectHeader(TOPIC, HEADER, schema, value); + SchemaAndValue result = converter.toConnectHeader(TOPIC, HEADER, serialized); + + if (value == null) { + assertNull(serialized); + assertNull(result.schema()); + assertNull(result.value()); + } else { + assertNotNull(serialized); + assertEquals(value, result.value()); + assertEquals(schema, result.schema()); + + byte[] serialized2 = converter.fromConnectHeader(TOPIC, HEADER, result.schema(), result.value()); + SchemaAndValue result2 = converter.toConnectHeader(TOPIC, HEADER, serialized2); + assertNotNull(serialized2); + assertEquals(schema, result2.schema()); + assertEquals(value, result2.value()); + assertEquals(result, result2); + assertArrayEquals(serialized, serialized); + } + } + +} \ No newline at end of file diff --git a/connect/api/src/test/java/org/apache/kafka/connect/storage/StringConverterTest.java b/connect/api/src/test/java/org/apache/kafka/connect/storage/StringConverterTest.java new file mode 100644 index 0000000..648bb7e --- /dev/null +++ b/connect/api/src/test/java/org/apache/kafka/connect/storage/StringConverterTest.java @@ -0,0 +1,102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.connect.storage; + +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.connect.data.Schema; +import org.apache.kafka.connect.data.SchemaAndValue; +import org.junit.jupiter.api.Test; + +import java.nio.charset.StandardCharsets; +import java.util.Collections; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNull; + +public class StringConverterTest { + private static final String TOPIC = "topic"; + private static final String SAMPLE_STRING = "a string"; + + private StringConverter converter = new StringConverter(); + + @Test + public void testStringToBytes() { + assertArrayEquals(Utils.utf8(SAMPLE_STRING), converter.fromConnectData(TOPIC, Schema.STRING_SCHEMA, SAMPLE_STRING)); + } + + @Test + public void testNonStringToBytes() { + assertArrayEquals(Utils.utf8("true"), converter.fromConnectData(TOPIC, Schema.BOOLEAN_SCHEMA, true)); + } + + @Test + public void testNullToBytes() { + assertNull(converter.fromConnectData(TOPIC, Schema.OPTIONAL_STRING_SCHEMA, null)); + } + + @Test + public void testToBytesIgnoresSchema() { + assertArrayEquals(Utils.utf8("true"), converter.fromConnectData(TOPIC, null, true)); + } + + @Test + public void testToBytesNonUtf8Encoding() { + converter.configure(Collections.singletonMap("converter.encoding", StandardCharsets.UTF_16.name()), true); + assertArrayEquals(SAMPLE_STRING.getBytes(StandardCharsets.UTF_16), converter.fromConnectData(TOPIC, Schema.STRING_SCHEMA, SAMPLE_STRING)); + } + + @Test + public void testBytesToString() { + SchemaAndValue data = converter.toConnectData(TOPIC, SAMPLE_STRING.getBytes()); + assertEquals(Schema.OPTIONAL_STRING_SCHEMA, data.schema()); + assertEquals(SAMPLE_STRING, data.value()); + } + + @Test + public void testBytesNullToString() { + SchemaAndValue data = converter.toConnectData(TOPIC, null); + assertEquals(Schema.OPTIONAL_STRING_SCHEMA, data.schema()); + assertNull(data.value()); + } + + @Test + public void testBytesToStringNonUtf8Encoding() { + converter.configure(Collections.singletonMap("converter.encoding", StandardCharsets.UTF_16.name()), true); + SchemaAndValue data = converter.toConnectData(TOPIC, SAMPLE_STRING.getBytes(StandardCharsets.UTF_16)); + assertEquals(Schema.OPTIONAL_STRING_SCHEMA, data.schema()); + assertEquals(SAMPLE_STRING, data.value()); + } + + // Note: the header conversion methods delegates to the data conversion methods, which are tested above. + // The following simply verify that the delegation works. 
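For context, a minimal, self-contained sketch of the round trip that these delegation tests rely on, using only converter calls that already appear in this patch (configure, fromConnectHeader, toConnectHeader); the class name, topic, and header key below are illustrative placeholders, not part of the patch:

import org.apache.kafka.connect.data.Schema;
import org.apache.kafka.connect.data.SchemaAndValue;
import org.apache.kafka.connect.storage.StringConverter;

import java.nio.charset.StandardCharsets;
import java.util.Collections;

public class StringConverterHeaderExample {
    public static void main(String[] args) {
        StringConverter converter = new StringConverter();
        // isKey = false: configure the converter for values; the encoding option is optional and defaults to UTF-8.
        converter.configure(Collections.singletonMap("converter.encoding", StandardCharsets.UTF_8.name()), false);

        // Serialize a header value; internally this goes through the same string encoding as fromConnectData.
        byte[] bytes = converter.fromConnectHeader("example-topic", "example-header", Schema.STRING_SCHEMA, "a string");

        // Deserialize it back; StringConverter always reports an optional string schema.
        SchemaAndValue roundTripped = converter.toConnectHeader("example-topic", "example-header", bytes);
        System.out.println(roundTripped.schema() + " -> " + roundTripped.value());
    }
}

Because both directions share one encoding path, the header tests that follow only spot-check that the delegation is wired up.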
+ + @Test + public void testStringHeaderValueToBytes() { + assertArrayEquals(Utils.utf8(SAMPLE_STRING), converter.fromConnectHeader(TOPIC, "hdr", Schema.STRING_SCHEMA, SAMPLE_STRING)); + } + + @Test + public void testNonStringHeaderValueToBytes() { + assertArrayEquals(Utils.utf8("true"), converter.fromConnectHeader(TOPIC, "hdr", Schema.BOOLEAN_SCHEMA, true)); + } + + @Test + public void testNullHeaderValueToBytes() { + assertNull(converter.fromConnectHeader(TOPIC, "hdr", Schema.OPTIONAL_STRING_SCHEMA, null)); + } +} diff --git a/connect/api/src/test/java/org/apache/kafka/connect/util/ConnectorUtilsTest.java b/connect/api/src/test/java/org/apache/kafka/connect/util/ConnectorUtilsTest.java new file mode 100644 index 0000000..b6f96bd --- /dev/null +++ b/connect/api/src/test/java/org/apache/kafka/connect/util/ConnectorUtilsTest.java @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.util; + +import org.junit.jupiter.api.Test; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +public class ConnectorUtilsTest { + + private static final List FIVE_ELEMENTS = Arrays.asList(1, 2, 3, 4, 5); + + @Test + public void testGroupPartitions() { + + List> grouped = ConnectorUtils.groupPartitions(FIVE_ELEMENTS, 1); + assertEquals(Arrays.asList(FIVE_ELEMENTS), grouped); + + grouped = ConnectorUtils.groupPartitions(FIVE_ELEMENTS, 2); + assertEquals(Arrays.asList(Arrays.asList(1, 2, 3), Arrays.asList(4, 5)), grouped); + + grouped = ConnectorUtils.groupPartitions(FIVE_ELEMENTS, 3); + assertEquals(Arrays.asList(Arrays.asList(1, 2), + Arrays.asList(3, 4), + Arrays.asList(5)), grouped); + + grouped = ConnectorUtils.groupPartitions(FIVE_ELEMENTS, 5); + assertEquals(Arrays.asList(Arrays.asList(1), + Arrays.asList(2), + Arrays.asList(3), + Arrays.asList(4), + Arrays.asList(5)), grouped); + + grouped = ConnectorUtils.groupPartitions(FIVE_ELEMENTS, 7); + assertEquals(Arrays.asList(Arrays.asList(1), + Arrays.asList(2), + Arrays.asList(3), + Arrays.asList(4), + Arrays.asList(5), + Collections.emptyList(), + Collections.emptyList()), grouped); + } + + @Test + public void testGroupPartitionsInvalidCount() { + assertThrows(IllegalArgumentException.class, + () -> ConnectorUtils.groupPartitions(FIVE_ELEMENTS, 0)); + } +} diff --git a/connect/basic-auth-extension/src/main/java/org/apache/kafka/connect/rest/basic/auth/extension/BasicAuthSecurityRestExtension.java b/connect/basic-auth-extension/src/main/java/org/apache/kafka/connect/rest/basic/auth/extension/BasicAuthSecurityRestExtension.java new file mode 100644 index 0000000..8c41762 --- /dev/null +++ 
b/connect/basic-auth-extension/src/main/java/org/apache/kafka/connect/rest/basic/auth/extension/BasicAuthSecurityRestExtension.java @@ -0,0 +1,117 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.connect.rest.basic.auth.extension; + +import org.apache.kafka.common.utils.AppInfoParser; +import org.apache.kafka.connect.errors.ConnectException; +import org.apache.kafka.connect.rest.ConnectRestExtension; +import org.apache.kafka.connect.rest.ConnectRestExtensionContext; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.security.auth.login.Configuration; +import java.io.IOException; +import java.util.Map; +import java.util.function.Supplier; + +/** + * Provides the ability to authenticate incoming BasicAuth credentials using the configured JAAS {@link + * javax.security.auth.spi.LoginModule}. An entry with the name {@code KafkaConnect} is expected in the JAAS config file configured in the + * JVM. An implementation of {@link javax.security.auth.spi.LoginModule} needs to be provided in the JAAS config file. The {@code + * LoginModule} implementation should configure the {@link javax.security.auth.callback.CallbackHandler} with only {@link + * javax.security.auth.callback.NameCallback} and {@link javax.security.auth.callback.PasswordCallback}. + * + *

<p>To use this extension, one needs to add the following config in the {@code worker.properties} + * <pre> + *     rest.extension.classes = org.apache.kafka.connect.rest.basic.auth.extension.BasicAuthSecurityRestExtension + * </pre> + * + * <p>An example JAAS config would look as below + * <pre> + *         KafkaConnect { + *              org.apache.kafka.connect.rest.basic.auth.extension.PropertyFileLoginModule required + *              file="/mnt/secret/credentials.properties"; + *         }; + * </pre> + * + * <p>This is a reference implementation of the {@link ConnectRestExtension} interface. It registers an implementation of {@link + * javax.ws.rs.container.ContainerRequestFilter} that does JAAS based authentication of incoming Basic Auth credentials. {@link + * ConnectRestExtension} implementations are loaded via the plugin class loader using {@link java.util.ServiceLoader} mechanism and hence + * the packaged jar includes {@code META-INF/services/org.apache.kafka.connect.rest.extension.ConnectRestExtension} with the entry + * {@code org.apache.kafka.connect.extension.auth.jaas.BasicAuthSecurityRestExtension} + * + * <p>
            NOTE: The implementation ships with a default {@link PropertyFileLoginModule} that helps authenticate the request against a + * property file. {@link PropertyFileLoginModule} is NOT intended to be used in production since the credentials are stored in PLAINTEXT. One can use + * this extension in production by using their own implementation of {@link javax.security.auth.spi.LoginModule} that authenticates against + * stores like LDAP, DB, etc. + */ +public class BasicAuthSecurityRestExtension implements ConnectRestExtension { + + private static final Logger log = LoggerFactory.getLogger(BasicAuthSecurityRestExtension.class); + + private static final Supplier CONFIGURATION = initializeConfiguration(Configuration::getConfiguration); + + // Capture the JVM's global JAAS configuration as soon as possible, as it may be altered later + // by connectors, converters, other REST extensions, etc. + static Supplier initializeConfiguration(Supplier configurationSupplier) { + try { + Configuration configuration = configurationSupplier.get(); + return () -> configuration; + } catch (Exception e) { + // We have to be careful not to throw anything here as this static block gets executed during plugin scanning and any exceptions will + // cause the worker to fail during startup, even if it's not configured to use the basic auth extension. + return () -> { + throw new ConnectException("Failed to retrieve JAAS configuration", e); + }; + } + } + + private final Supplier configuration; + + public BasicAuthSecurityRestExtension() { + this(CONFIGURATION); + } + + // For testing + BasicAuthSecurityRestExtension(Supplier configuration) { + this.configuration = configuration; + } + + @Override + public void register(ConnectRestExtensionContext restPluginContext) { + log.trace("Registering JAAS basic auth filter"); + restPluginContext.configurable().register(new JaasBasicAuthFilter(configuration.get())); + log.trace("Finished registering JAAS basic auth filter"); + } + + @Override + public void close() throws IOException { + + } + + @Override + public void configure(Map configs) { + // If we failed to retrieve a JAAS configuration during startup, throw that exception now + configuration.get(); + } + + @Override + public String version() { + return AppInfoParser.getVersion(); + } +} diff --git a/connect/basic-auth-extension/src/main/java/org/apache/kafka/connect/rest/basic/auth/extension/JaasBasicAuthFilter.java b/connect/basic-auth-extension/src/main/java/org/apache/kafka/connect/rest/basic/auth/extension/JaasBasicAuthFilter.java new file mode 100644 index 0000000..0299cbb --- /dev/null +++ b/connect/basic-auth-extension/src/main/java/org/apache/kafka/connect/rest/basic/auth/extension/JaasBasicAuthFilter.java @@ -0,0 +1,156 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.connect.rest.basic.auth.extension; + +import java.util.ArrayList; +import java.util.List; +import java.util.regex.Pattern; +import javax.security.auth.login.Configuration; +import javax.ws.rs.HttpMethod; +import org.apache.kafka.common.config.ConfigException; +import org.apache.kafka.connect.errors.ConnectException; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.Base64; + +import javax.security.auth.callback.Callback; +import javax.security.auth.callback.CallbackHandler; +import javax.security.auth.callback.NameCallback; +import javax.security.auth.callback.PasswordCallback; +import javax.security.auth.callback.UnsupportedCallbackException; +import javax.security.auth.login.LoginContext; +import javax.security.auth.login.LoginException; +import javax.ws.rs.container.ContainerRequestContext; +import javax.ws.rs.container.ContainerRequestFilter; +import javax.ws.rs.core.Response; + +public class JaasBasicAuthFilter implements ContainerRequestFilter { + + private static final Logger log = LoggerFactory.getLogger(JaasBasicAuthFilter.class); + private static final Pattern TASK_REQUEST_PATTERN = Pattern.compile("/?connectors/([^/]+)/tasks/?"); + private static final String CONNECT_LOGIN_MODULE = "KafkaConnect"; + + static final String AUTHORIZATION = "Authorization"; + + // Package-private for testing + final Configuration configuration; + + public JaasBasicAuthFilter(Configuration configuration) { + this.configuration = configuration; + } + + @Override + public void filter(ContainerRequestContext requestContext) throws IOException { + if (isInternalTaskConfigRequest(requestContext)) { + log.trace("Skipping authentication for internal request"); + return; + } + + try { + log.debug("Authenticating request"); + LoginContext loginContext = new LoginContext( + CONNECT_LOGIN_MODULE, + null, + new BasicAuthCallBackHandler(requestContext.getHeaderString(AUTHORIZATION)), + configuration); + loginContext.login(); + } catch (LoginException | ConfigException e) { + // Log at debug here in order to avoid polluting log files whenever someone mistypes their credentials + log.debug("Request failed authentication", e); + requestContext.abortWith( + Response.status(Response.Status.UNAUTHORIZED) + .entity("User cannot access the resource.") + .build()); + } + } + + private static boolean isInternalTaskConfigRequest(ContainerRequestContext requestContext) { + return requestContext.getMethod().equals(HttpMethod.POST) + && TASK_REQUEST_PATTERN.matcher(requestContext.getUriInfo().getPath()).matches(); + } + + + public static class BasicAuthCallBackHandler implements CallbackHandler { + + private static final String BASIC = "basic"; + private static final char COLON = ':'; + private static final char SPACE = ' '; + private String username; + private String password; + + public BasicAuthCallBackHandler(String credentials) { + if (credentials == null) { + log.trace("No credentials were provided with the request"); + return; + } + + int space = credentials.indexOf(SPACE); + if (space <= 0) { + log.trace("Request credentials were malformed; no space present in value for authorization header"); + return; + } + + String method = credentials.substring(0, space); + if (!BASIC.equalsIgnoreCase(method)) { + log.trace("Request credentials used {} authentication, but only {} supported; 
ignoring", method, BASIC); + return; + } + + credentials = credentials.substring(space + 1); + credentials = new String(Base64.getDecoder().decode(credentials), + StandardCharsets.UTF_8); + int i = credentials.indexOf(COLON); + if (i <= 0) { + log.trace("Request credentials were malformed; no colon present between username and password"); + return; + } + + username = credentials.substring(0, i); + password = credentials.substring(i + 1); + } + + @Override + public void handle(Callback[] callbacks) throws UnsupportedCallbackException { + List unsupportedCallbacks = new ArrayList<>(); + for (Callback callback : callbacks) { + if (callback instanceof NameCallback) { + ((NameCallback) callback).setName(username); + } else if (callback instanceof PasswordCallback) { + ((PasswordCallback) callback).setPassword(password != null + ? password.toCharArray() + : null + ); + } else { + unsupportedCallbacks.add(callback); + } + } + if (!unsupportedCallbacks.isEmpty()) + throw new ConnectException(String.format( + "Unsupported callbacks %s; request authentication will fail. " + + "This indicates the Connect worker was configured with a JAAS " + + "LoginModule that is incompatible with the %s, and will need to be " + + "corrected and restarted.", + unsupportedCallbacks, + BasicAuthSecurityRestExtension.class.getSimpleName() + )); + } + } +} diff --git a/connect/basic-auth-extension/src/main/java/org/apache/kafka/connect/rest/basic/auth/extension/PropertyFileLoginModule.java b/connect/basic-auth-extension/src/main/java/org/apache/kafka/connect/rest/basic/auth/extension/PropertyFileLoginModule.java new file mode 100644 index 0000000..8b8e324 --- /dev/null +++ b/connect/basic-auth-extension/src/main/java/org/apache/kafka/connect/rest/basic/auth/extension/PropertyFileLoginModule.java @@ -0,0 +1,155 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.connect.rest.basic.auth.extension; + +import org.apache.kafka.common.config.ConfigException; +import org.apache.kafka.common.utils.Utils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.io.InputStream; +import java.nio.file.Files; +import java.nio.file.Paths; +import java.util.Map; +import java.util.Properties; +import java.util.concurrent.ConcurrentHashMap; + +import javax.security.auth.Subject; +import javax.security.auth.callback.Callback; +import javax.security.auth.callback.CallbackHandler; +import javax.security.auth.callback.NameCallback; +import javax.security.auth.callback.PasswordCallback; +import javax.security.auth.login.LoginException; +import javax.security.auth.spi.LoginModule; + +/** + * {@link PropertyFileLoginModule} authenticates against a properties file. 
+ * The credentials should be stored in the format {username}={password} in the properties file. + * The absolute path of the file needs to be specified using the option 'file'. + * + *

            NOTE: This implementation is NOT intended to be used in production since the credentials are stored in PLAINTEXT in the + * properties file. + */ +public class PropertyFileLoginModule implements LoginModule { + private static final Logger log = LoggerFactory.getLogger(PropertyFileLoginModule.class); + + private CallbackHandler callbackHandler; + private static final String FILE_OPTIONS = "file"; + private String fileName; + private boolean authenticated; + + private static Map credentialPropertiesMap = new ConcurrentHashMap<>(); + + @Override + public void initialize(Subject subject, CallbackHandler callbackHandler, Map sharedState, Map options) { + this.callbackHandler = callbackHandler; + fileName = (String) options.get(FILE_OPTIONS); + if (Utils.isBlank(fileName)) { + throw new ConfigException("Property Credentials file must be specified"); + } + + if (!credentialPropertiesMap.containsKey(fileName)) { + log.trace("Opening credential properties file '{}'", fileName); + Properties credentialProperties = new Properties(); + try { + try (InputStream inputStream = Files.newInputStream(Paths.get(fileName))) { + log.trace("Parsing credential properties file '{}'", fileName); + credentialProperties.load(inputStream); + } + credentialPropertiesMap.putIfAbsent(fileName, credentialProperties); + if (credentialProperties.isEmpty()) + log.warn("Credential properties file '{}' is empty; all requests will be permitted", + fileName); + } catch (IOException e) { + log.error("Error loading credentials file ", e); + throw new ConfigException("Error loading Property Credentials file"); + } + } else { + log.trace( + "Credential properties file '{}' has already been opened and parsed; will read from cached, in-memory store", + fileName); + } + } + + @Override + public boolean login() throws LoginException { + Callback[] callbacks = configureCallbacks(); + try { + log.trace("Authenticating user; invoking JAAS login callbacks"); + callbackHandler.handle(callbacks); + } catch (Exception e) { + log.warn("Authentication failed while invoking JAAS login callbacks", e); + throw new LoginException(e.getMessage()); + } + + String username = ((NameCallback) callbacks[0]).getName(); + char[] passwordChars = ((PasswordCallback) callbacks[1]).getPassword(); + String password = passwordChars != null ? 
new String(passwordChars) : null; + Properties credentialProperties = credentialPropertiesMap.get(fileName); + + if (credentialProperties.isEmpty()) { + log.trace("Not validating credentials for user '{}' as credential properties file '{}' is empty", + username, + fileName); + authenticated = true; + } else if (username == null) { + log.trace("No credentials were provided or the provided credentials were malformed"); + authenticated = false; + } else if (password != null && password.equals(credentialProperties.get(username))) { + log.trace("Credentials provided for user '{}' match those present in the credential properties file '{}'", + username, + fileName); + authenticated = true; + } else if (!credentialProperties.containsKey(username)) { + log.trace("User '{}' is not present in the credential properties file '{}'", + username, + fileName); + authenticated = false; + } else { + log.trace("Credentials provided for user '{}' do not match those present in the credential properties file '{}'", + username, + fileName); + authenticated = false; + } + + return authenticated; + } + + @Override + public boolean commit() throws LoginException { + return authenticated; + } + + @Override + public boolean abort() throws LoginException { + return true; + } + + @Override + public boolean logout() throws LoginException { + return true; + } + + private Callback[] configureCallbacks() { + Callback[] callbacks = new Callback[2]; + callbacks[0] = new NameCallback("Enter user name"); + callbacks[1] = new PasswordCallback("Enter password", false); + return callbacks; + } +} diff --git a/connect/basic-auth-extension/src/main/resources/META-INF/services/org.apache.kafka.connect.rest.ConnectRestExtension b/connect/basic-auth-extension/src/main/resources/META-INF/services/org.apache.kafka.connect.rest.ConnectRestExtension new file mode 100644 index 0000000..ba7ae5b --- /dev/null +++ b/connect/basic-auth-extension/src/main/resources/META-INF/services/org.apache.kafka.connect.rest.ConnectRestExtension @@ -0,0 +1,16 @@ + # Licensed to the Apache Software Foundation (ASF) under one or more + # contributor license agreements. See the NOTICE file distributed with + # this work for additional information regarding copyright ownership. + # The ASF licenses this file to You under the Apache License, Version 2.0 + # (the "License"); you may not use this file except in compliance with + # the License. You may obtain a copy of the License at + # + # http://www.apache.org/licenses/LICENSE-2.0 + # + # Unless required by applicable law or agreed to in writing, software + # distributed under the License is distributed on an "AS IS" BASIS, + # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + # See the License for the specific language governing permissions and + # limitations under the License. + +org.apache.kafka.connect.rest.basic.auth.extension.BasicAuthSecurityRestExtension \ No newline at end of file diff --git a/connect/basic-auth-extension/src/test/java/org/apache/kafka/connect/rest/basic/auth/extension/BasicAuthSecurityRestExtensionTest.java b/connect/basic-auth-extension/src/test/java/org/apache/kafka/connect/rest/basic/auth/extension/BasicAuthSecurityRestExtensionTest.java new file mode 100644 index 0000000..b1b5b1e --- /dev/null +++ b/connect/basic-auth-extension/src/test/java/org/apache/kafka/connect/rest/basic/auth/extension/BasicAuthSecurityRestExtensionTest.java @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
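The tests that follow exercise initializeConfiguration() in BasicAuthSecurityRestExtension. The idea is worth spelling out: the JAAS Configuration is captured once, eagerly, and if that capture fails the exception is not thrown during plugin scanning but deferred until the supplier is actually consumed (for example from configure() or register()). A small self-contained sketch of the same pattern, using a plain RuntimeException where the real code throws ConnectException:

    import java.util.function.Supplier;

    public class DeferredCaptureSketch {
        // Evaluate the source once; on failure, return a supplier that re-throws later.
        static <T> Supplier<T> capture(Supplier<T> source) {
            try {
                T value = source.get();
                return () -> value;                       // later callers see the captured snapshot
            } catch (Exception e) {
                return () -> {                            // failure is deferred until first use
                    throw new RuntimeException("Failed to capture value", e);
                };
            }
        }

        public static void main(String[] args) {
            Supplier<String> ok = capture(() -> "snapshot");
            System.out.println(ok.get());                 // prints "snapshot"

            Supplier<String> bad = capture(() -> { throw new IllegalStateException("boom"); });
            try {
                bad.get();
            } catch (RuntimeException e) {
                System.out.println("deferred: " + e.getCause().getMessage());   // prints "deferred: boom"
            }
        }
    }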
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.connect.rest.basic.auth.extension; + +import org.apache.kafka.connect.errors.ConnectException; +import org.apache.kafka.connect.rest.ConnectRestExtensionContext; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; + +import javax.security.auth.login.Configuration; +import javax.ws.rs.core.Configurable; + +import java.io.IOException; +import java.util.Collections; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.Supplier; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class BasicAuthSecurityRestExtensionTest { + + Configuration priorConfiguration; + + @BeforeEach + public void setup() { + priorConfiguration = Configuration.getConfiguration(); + } + + @AfterEach + public void tearDown() { + Configuration.setConfiguration(priorConfiguration); + } + + @SuppressWarnings("unchecked") + @Test + public void testJaasConfigurationNotOverwritten() { + ArgumentCaptor jaasFilter = ArgumentCaptor.forClass(JaasBasicAuthFilter.class); + Configurable> configurable = mock(Configurable.class); + when(configurable.register(jaasFilter.capture())).thenReturn(null); + + ConnectRestExtensionContext context = mock(ConnectRestExtensionContext.class); + when(context.configurable()).thenReturn((Configurable) configurable); + + BasicAuthSecurityRestExtension extension = new BasicAuthSecurityRestExtension(); + Configuration overwrittenConfiguration = mock(Configuration.class); + Configuration.setConfiguration(overwrittenConfiguration); + extension.register(context); + + assertNotEquals(overwrittenConfiguration, jaasFilter.getValue().configuration, + "Overwritten JAAS configuration should not be used by basic auth REST extension"); + } + + @Test + public void testBadJaasConfigInitialization() { + SecurityException jaasConfigurationException = new SecurityException(new IOException("Bad JAAS config is bad")); + Supplier configuration = BasicAuthSecurityRestExtension.initializeConfiguration(() -> { + throw jaasConfigurationException; + }); + + ConnectException thrownException = assertThrows(ConnectException.class, configuration::get); + assertEquals(jaasConfigurationException, thrownException.getCause()); + } + + @Test + public void testGoodJaasConfigInitialization() { + AtomicBoolean configurationInitializerEvaluated = new AtomicBoolean(false); + Configuration mockConfiguration = mock(Configuration.class); + Supplier configuration = 
BasicAuthSecurityRestExtension.initializeConfiguration(() -> { + configurationInitializerEvaluated.set(true); + return mockConfiguration; + }); + + assertTrue(configurationInitializerEvaluated.get()); + assertEquals(mockConfiguration, configuration.get()); + } + + @Test + public void testBadJaasConfigExtensionSetup() { + SecurityException jaasConfigurationException = new SecurityException(new IOException("Bad JAAS config is bad")); + Supplier configuration = () -> { + throw jaasConfigurationException; + }; + + BasicAuthSecurityRestExtension extension = new BasicAuthSecurityRestExtension(configuration); + + Exception thrownException = assertThrows(Exception.class, () -> extension.configure(Collections.emptyMap())); + assertEquals(jaasConfigurationException, thrownException); + + thrownException = assertThrows(Exception.class, () -> extension.register(mock(ConnectRestExtensionContext.class))); + assertEquals(jaasConfigurationException, thrownException); + } +} diff --git a/connect/basic-auth-extension/src/test/java/org/apache/kafka/connect/rest/basic/auth/extension/JaasBasicAuthFilterTest.java b/connect/basic-auth-extension/src/test/java/org/apache/kafka/connect/rest/basic/auth/extension/JaasBasicAuthFilterTest.java new file mode 100644 index 0000000..561095f --- /dev/null +++ b/connect/basic-auth-extension/src/test/java/org/apache/kafka/connect/rest/basic/auth/extension/JaasBasicAuthFilterTest.java @@ -0,0 +1,234 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
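The JaasBasicAuthFilter tests that follow build their Authorization headers with a small authHeader() helper. For reference, the value BasicAuthCallBackHandler expects is the standard HTTP Basic scheme: the literal word Basic, a space, then the Base64 encoding of username:password, which is later split on the first colon so that passwords may themselves contain colons. A stand-alone sketch (class and method names here are illustrative only):

    import java.nio.charset.StandardCharsets;
    import java.util.Base64;

    public class BasicHeaderSketch {
        static String encode(String username, String password) {
            String token = Base64.getEncoder()
                    .encodeToString((username + ":" + password).getBytes(StandardCharsets.UTF_8));
            return "Basic " + token;
        }

        static String[] decode(String header) {
            int space = header.indexOf(' ');                      // scheme / credentials boundary
            byte[] raw = Base64.getDecoder().decode(header.substring(space + 1));
            String decoded = new String(raw, StandardCharsets.UTF_8);
            int colon = decoded.indexOf(':');                     // split on the FIRST colon only
            return new String[] {decoded.substring(0, colon), decoded.substring(colon + 1)};
        }

        public static void main(String[] args) {
            String header = encode("alice", "s3cr3t:pw");
            System.out.println(header);                           // Basic YWxpY2U6czNjcjN0OnB3
            String[] parts = decode(header);
            System.out.println(parts[0] + " / " + parts[1]);      // alice / s3cr3t:pw
        }
    }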
+ */ + +package org.apache.kafka.connect.rest.basic.auth.extension; + +import javax.security.auth.callback.Callback; +import javax.security.auth.callback.CallbackHandler; +import javax.security.auth.callback.ChoiceCallback; +import javax.ws.rs.HttpMethod; +import javax.ws.rs.core.UriInfo; + +import org.apache.kafka.common.security.authenticator.TestJaasConfig; +import org.apache.kafka.connect.errors.ConnectException; +import org.junit.jupiter.api.Test; + +import java.io.File; +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.util.ArrayList; +import java.util.Base64; +import java.util.Collections; +import java.util.List; +import java.util.Map; + +import javax.ws.rs.container.ContainerRequestContext; +import javax.ws.rs.core.Response; + +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +public class JaasBasicAuthFilterTest { + + private static final String LOGIN_MODULE = + "org.apache.kafka.connect.rest.basic.auth.extension.PropertyFileLoginModule"; + + @Test + public void testSuccess() throws IOException { + File credentialFile = setupPropertyLoginFile(true); + JaasBasicAuthFilter jaasBasicAuthFilter = setupJaasFilter("KafkaConnect", credentialFile.getPath()); + ContainerRequestContext requestContext = setMock("Basic", "user", "password"); + jaasBasicAuthFilter.filter(requestContext); + + verify(requestContext).getMethod(); + verify(requestContext).getHeaderString(JaasBasicAuthFilter.AUTHORIZATION); + } + + @Test + public void testEmptyCredentialsFile() throws IOException { + File credentialFile = setupPropertyLoginFile(false); + JaasBasicAuthFilter jaasBasicAuthFilter = setupJaasFilter("KafkaConnect", credentialFile.getPath()); + ContainerRequestContext requestContext = setMock("Basic", "user", "password"); + jaasBasicAuthFilter.filter(requestContext); + + verify(requestContext).getMethod(); + verify(requestContext).getHeaderString(JaasBasicAuthFilter.AUTHORIZATION); + } + + @Test + public void testBadCredential() throws IOException { + File credentialFile = setupPropertyLoginFile(true); + JaasBasicAuthFilter jaasBasicAuthFilter = setupJaasFilter("KafkaConnect", credentialFile.getPath()); + ContainerRequestContext requestContext = setMock("Basic", "user1", "password"); + jaasBasicAuthFilter.filter(requestContext); + + verify(requestContext).abortWith(any(Response.class)); + verify(requestContext).getMethod(); + verify(requestContext).getHeaderString(JaasBasicAuthFilter.AUTHORIZATION); + } + + @Test + public void testBadPassword() throws IOException { + File credentialFile = setupPropertyLoginFile(true); + JaasBasicAuthFilter jaasBasicAuthFilter = setupJaasFilter("KafkaConnect", credentialFile.getPath()); + ContainerRequestContext requestContext = setMock("Basic", "user", "password1"); + jaasBasicAuthFilter.filter(requestContext); + + verify(requestContext).abortWith(any(Response.class)); + verify(requestContext).getMethod(); + verify(requestContext).getHeaderString(JaasBasicAuthFilter.AUTHORIZATION); + } + + @Test + public void testUnknownBearer() throws IOException { + File credentialFile = setupPropertyLoginFile(true); + JaasBasicAuthFilter jaasBasicAuthFilter = setupJaasFilter("KafkaConnect", credentialFile.getPath()); + ContainerRequestContext requestContext = setMock("Unknown", "user", "password"); + 
jaasBasicAuthFilter.filter(requestContext); + + verify(requestContext).abortWith(any(Response.class)); + verify(requestContext).getMethod(); + verify(requestContext).getHeaderString(JaasBasicAuthFilter.AUTHORIZATION); + } + + @Test + public void testUnknownLoginModule() throws IOException { + File credentialFile = setupPropertyLoginFile(true); + JaasBasicAuthFilter jaasBasicAuthFilter = setupJaasFilter("KafkaConnect1", credentialFile.getPath()); + ContainerRequestContext requestContext = setMock("Basic", "user", "password"); + jaasBasicAuthFilter.filter(requestContext); + + verify(requestContext).abortWith(any(Response.class)); + verify(requestContext).getMethod(); + verify(requestContext).getHeaderString(JaasBasicAuthFilter.AUTHORIZATION); + } + + @Test + public void testUnknownCredentialsFile() throws IOException { + JaasBasicAuthFilter jaasBasicAuthFilter = setupJaasFilter("KafkaConnect", "/tmp/testcrednetial"); + ContainerRequestContext requestContext = setMock("Basic", "user", "password"); + jaasBasicAuthFilter.filter(requestContext); + + verify(requestContext).abortWith(any(Response.class)); + verify(requestContext).getMethod(); + verify(requestContext).getHeaderString(JaasBasicAuthFilter.AUTHORIZATION); + } + + @Test + public void testNoFileOption() throws IOException { + JaasBasicAuthFilter jaasBasicAuthFilter = setupJaasFilter("KafkaConnect", null); + ContainerRequestContext requestContext = setMock("Basic", "user", "password"); + jaasBasicAuthFilter.filter(requestContext); + + verify(requestContext).abortWith(any(Response.class)); + verify(requestContext).getMethod(); + verify(requestContext).getHeaderString(JaasBasicAuthFilter.AUTHORIZATION); + } + + @Test + public void testPostWithoutAppropriateCredential() throws IOException { + UriInfo uriInfo = mock(UriInfo.class); + when(uriInfo.getPath()).thenReturn("connectors/connName/tasks"); + + ContainerRequestContext requestContext = mock(ContainerRequestContext.class); + when(requestContext.getMethod()).thenReturn(HttpMethod.POST); + when(requestContext.getUriInfo()).thenReturn(uriInfo); + + File credentialFile = setupPropertyLoginFile(true); + JaasBasicAuthFilter jaasBasicAuthFilter = setupJaasFilter("KafkaConnect1", credentialFile.getPath()); + + jaasBasicAuthFilter.filter(requestContext); + + verify(uriInfo).getPath(); + verify(requestContext).getMethod(); + verify(requestContext).getUriInfo(); + } + + @Test + public void testPostNotChangingConnectorTask() throws IOException { + UriInfo uriInfo = mock(UriInfo.class); + when(uriInfo.getPath()).thenReturn("local:randomport/connectors/connName"); + + ContainerRequestContext requestContext = mock(ContainerRequestContext.class); + when(requestContext.getMethod()).thenReturn(HttpMethod.POST); + when(requestContext.getUriInfo()).thenReturn(uriInfo); + String authHeader = "Basic" + Base64.getEncoder().encodeToString(("user" + ":" + "password").getBytes()); + when(requestContext.getHeaderString(JaasBasicAuthFilter.AUTHORIZATION)) + .thenReturn(authHeader); + + File credentialFile = setupPropertyLoginFile(true); + JaasBasicAuthFilter jaasBasicAuthFilter = setupJaasFilter("KafkaConnect", credentialFile.getPath()); + + jaasBasicAuthFilter.filter(requestContext); + + verify(requestContext).abortWith(any(Response.class)); + verify(requestContext).getUriInfo(); + verify(requestContext).getUriInfo(); + } + + @Test + public void testUnsupportedCallback() { + String authHeader = authHeader("basic", "user", "pwd"); + CallbackHandler callbackHandler = new 
JaasBasicAuthFilter.BasicAuthCallBackHandler(authHeader); + Callback unsupportedCallback = new ChoiceCallback( + "You take the blue pill... the story ends, you wake up in your bed and believe whatever you want to believe. " + + "You take the red pill... you stay in Wonderland, and I show you how deep the rabbit hole goes.", + new String[] {"blue pill", "red pill"}, + 1, + true + ); + assertThrows(ConnectException.class, () -> callbackHandler.handle(new Callback[] {unsupportedCallback})); + } + + private String authHeader(String authorization, String username, String password) { + return authorization + " " + Base64.getEncoder().encodeToString((username + ":" + password).getBytes()); + } + + private ContainerRequestContext setMock(String authorization, String username, String password) { + ContainerRequestContext requestContext = mock(ContainerRequestContext.class); + when(requestContext.getMethod()).thenReturn(HttpMethod.GET); + when(requestContext.getHeaderString(JaasBasicAuthFilter.AUTHORIZATION)) + .thenReturn(authHeader(authorization, username, password)); + return requestContext; + } + + private File setupPropertyLoginFile(boolean includeUsers) throws IOException { + File credentialFile = File.createTempFile("credential", ".properties"); + credentialFile.deleteOnExit(); + if (includeUsers) { + List lines = new ArrayList<>(); + lines.add("user=password"); + lines.add("user1=password1"); + Files.write(credentialFile.toPath(), lines, StandardCharsets.UTF_8); + } + return credentialFile; + } + + private JaasBasicAuthFilter setupJaasFilter(String name, String credentialFilePath) { + TestJaasConfig configuration = new TestJaasConfig(); + Map moduleOptions = credentialFilePath != null + ? Collections.singletonMap("file", credentialFilePath) + : Collections.emptyMap(); + configuration.addEntry(name, LOGIN_MODULE, moduleOptions); + return new JaasBasicAuthFilter(configuration); + } + +} diff --git a/connect/file/src/main/java/org/apache/kafka/connect/file/FileStreamSinkConnector.java b/connect/file/src/main/java/org/apache/kafka/connect/file/FileStreamSinkConnector.java new file mode 100644 index 0000000..136e899 --- /dev/null +++ b/connect/file/src/main/java/org/apache/kafka/connect/file/FileStreamSinkConnector.java @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
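The two POST tests above (testPostWithoutAppropriateCredential and testPostNotChangingConnectorTask) hinge on the TASK_REQUEST_PATTERN check in JaasBasicAuthFilter: only POSTs whose path matches the internal task-config endpoint skip authentication, every other request goes through the login flow. A quick illustration of which paths the pattern accepts:

    import java.util.regex.Pattern;

    public class TaskRequestPatternSketch {
        private static final Pattern TASK_REQUEST_PATTERN = Pattern.compile("/?connectors/([^/]+)/tasks/?");

        public static void main(String[] args) {
            String[] paths = {
                "connectors/my-connector/tasks",       // matches  -> a POST to this path skips authentication
                "/connectors/my-connector/tasks/",     // matches
                "connectors/my-connector",             // no match -> request is authenticated
                "connectors/my-connector/tasks/0"      // no match -> request is authenticated
            };
            for (String path : paths)
                System.out.printf("%-35s matches=%b%n", path, TASK_REQUEST_PATTERN.matcher(path).matches());
        }
    }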
+ */ +package org.apache.kafka.connect.file; + +import org.apache.kafka.common.config.AbstractConfig; +import org.apache.kafka.common.config.ConfigDef; +import org.apache.kafka.common.config.ConfigDef.Importance; +import org.apache.kafka.common.config.ConfigDef.Type; +import org.apache.kafka.common.utils.AppInfoParser; +import org.apache.kafka.connect.connector.Task; +import org.apache.kafka.connect.sink.SinkConnector; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** + * Very simple connector that works with the console. This connector supports both source and + * sink modes via its 'mode' setting. + */ +public class FileStreamSinkConnector extends SinkConnector { + + public static final String FILE_CONFIG = "file"; + private static final ConfigDef CONFIG_DEF = new ConfigDef() + .define(FILE_CONFIG, Type.STRING, null, Importance.HIGH, "Destination filename. If not specified, the standard output will be used"); + + private String filename; + + @Override + public String version() { + return AppInfoParser.getVersion(); + } + + @Override + public void start(Map props) { + AbstractConfig parsedConfig = new AbstractConfig(CONFIG_DEF, props); + filename = parsedConfig.getString(FILE_CONFIG); + } + + @Override + public Class taskClass() { + return FileStreamSinkTask.class; + } + + @Override + public List> taskConfigs(int maxTasks) { + ArrayList> configs = new ArrayList<>(); + for (int i = 0; i < maxTasks; i++) { + Map config = new HashMap<>(); + if (filename != null) + config.put(FILE_CONFIG, filename); + configs.add(config); + } + return configs; + } + + @Override + public void stop() { + // Nothing to do since FileStreamSinkConnector has no background monitoring. + } + + @Override + public ConfigDef config() { + return CONFIG_DEF; + } +} diff --git a/connect/file/src/main/java/org/apache/kafka/connect/file/FileStreamSinkTask.java b/connect/file/src/main/java/org/apache/kafka/connect/file/FileStreamSinkTask.java new file mode 100644 index 0000000..3d1d2b8 --- /dev/null +++ b/connect/file/src/main/java/org/apache/kafka/connect/file/FileStreamSinkTask.java @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
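FileStreamSinkConnector above takes a single optional 'file' property (falling back to standard output) plus the standard sink 'topics' list, and simply replicates that configuration across however many tasks are requested. As a rough sketch, a standalone-mode connector configuration for it might look like the following; names and paths are placeholders:

    name=local-file-sink
    connector.class=org.apache.kafka.connect.file.FileStreamSinkConnector
    tasks.max=1
    file=/tmp/connect.output
    topics=connect-test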
+ */ +package org.apache.kafka.connect.file; + +import org.apache.kafka.clients.consumer.OffsetAndMetadata; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.connect.errors.ConnectException; +import org.apache.kafka.connect.sink.SinkRecord; +import org.apache.kafka.connect.sink.SinkTask; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.io.PrintStream; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.nio.file.Paths; +import java.nio.file.StandardOpenOption; +import java.util.Collection; +import java.util.Map; + +/** + * FileStreamSinkTask writes records to stdout or a file. + */ +public class FileStreamSinkTask extends SinkTask { + private static final Logger log = LoggerFactory.getLogger(FileStreamSinkTask.class); + + private String filename; + private PrintStream outputStream; + + public FileStreamSinkTask() { + } + + // for testing + public FileStreamSinkTask(PrintStream outputStream) { + filename = null; + this.outputStream = outputStream; + } + + @Override + public String version() { + return new FileStreamSinkConnector().version(); + } + + @Override + public void start(Map props) { + filename = props.get(FileStreamSinkConnector.FILE_CONFIG); + if (filename == null) { + outputStream = System.out; + } else { + try { + outputStream = new PrintStream( + Files.newOutputStream(Paths.get(filename), StandardOpenOption.CREATE, StandardOpenOption.APPEND), + false, + StandardCharsets.UTF_8.name()); + } catch (IOException e) { + throw new ConnectException("Couldn't find or create file '" + filename + "' for FileStreamSinkTask", e); + } + } + } + + @Override + public void put(Collection sinkRecords) { + for (SinkRecord record : sinkRecords) { + log.trace("Writing line to {}: {}", logFilename(), record.value()); + outputStream.println(record.value()); + } + } + + @Override + public void flush(Map offsets) { + log.trace("Flushing output stream for {}", logFilename()); + outputStream.flush(); + } + + @Override + public void stop() { + if (outputStream != null && outputStream != System.out) + outputStream.close(); + } + + private String logFilename() { + return filename == null ? "stdout" : filename; + } +} diff --git a/connect/file/src/main/java/org/apache/kafka/connect/file/FileStreamSourceConnector.java b/connect/file/src/main/java/org/apache/kafka/connect/file/FileStreamSourceConnector.java new file mode 100644 index 0000000..74b5f7c --- /dev/null +++ b/connect/file/src/main/java/org/apache/kafka/connect/file/FileStreamSourceConnector.java @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.connect.file; + +import org.apache.kafka.common.config.AbstractConfig; +import org.apache.kafka.common.config.ConfigDef; +import org.apache.kafka.common.config.ConfigDef.Importance; +import org.apache.kafka.common.config.ConfigDef.Type; +import org.apache.kafka.common.config.ConfigException; +import org.apache.kafka.common.utils.AppInfoParser; +import org.apache.kafka.connect.connector.Task; +import org.apache.kafka.connect.source.SourceConnector; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** + * Very simple connector that works with the console. This connector supports both source and + * sink modes via its 'mode' setting. + */ +public class FileStreamSourceConnector extends SourceConnector { + public static final String TOPIC_CONFIG = "topic"; + public static final String FILE_CONFIG = "file"; + public static final String TASK_BATCH_SIZE_CONFIG = "batch.size"; + + public static final int DEFAULT_TASK_BATCH_SIZE = 2000; + + private static final ConfigDef CONFIG_DEF = new ConfigDef() + .define(FILE_CONFIG, Type.STRING, null, Importance.HIGH, "Source filename. If not specified, the standard input will be used") + .define(TOPIC_CONFIG, Type.LIST, Importance.HIGH, "The topic to publish data to") + .define(TASK_BATCH_SIZE_CONFIG, Type.INT, DEFAULT_TASK_BATCH_SIZE, Importance.LOW, + "The maximum number of records the Source task can read from file one time"); + + private String filename; + private String topic; + private int batchSize; + + @Override + public String version() { + return AppInfoParser.getVersion(); + } + + @Override + public void start(Map props) { + AbstractConfig parsedConfig = new AbstractConfig(CONFIG_DEF, props); + filename = parsedConfig.getString(FILE_CONFIG); + List topics = parsedConfig.getList(TOPIC_CONFIG); + if (topics.size() != 1) { + throw new ConfigException("'topic' in FileStreamSourceConnector configuration requires definition of a single topic"); + } + topic = topics.get(0); + batchSize = parsedConfig.getInt(TASK_BATCH_SIZE_CONFIG); + } + + @Override + public Class taskClass() { + return FileStreamSourceTask.class; + } + + @Override + public List> taskConfigs(int maxTasks) { + ArrayList> configs = new ArrayList<>(); + // Only one input stream makes sense. + Map config = new HashMap<>(); + if (filename != null) + config.put(FILE_CONFIG, filename); + config.put(TOPIC_CONFIG, topic); + config.put(TASK_BATCH_SIZE_CONFIG, String.valueOf(batchSize)); + configs.add(config); + return configs; + } + + @Override + public void stop() { + // Nothing to do since FileStreamSourceConnector has no background monitoring. + } + + @Override + public ConfigDef config() { + return CONFIG_DEF; + } +} diff --git a/connect/file/src/main/java/org/apache/kafka/connect/file/FileStreamSourceTask.java b/connect/file/src/main/java/org/apache/kafka/connect/file/FileStreamSourceTask.java new file mode 100644 index 0000000..8e3fb89 --- /dev/null +++ b/connect/file/src/main/java/org/apache/kafka/connect/file/FileStreamSourceTask.java @@ -0,0 +1,251 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.file; + +import java.io.BufferedReader; +import java.io.IOException; +import java.io.InputStream; +import java.io.InputStreamReader; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.nio.file.NoSuchFileException; +import java.nio.file.Paths; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; + +import org.apache.kafka.connect.data.Schema; +import org.apache.kafka.connect.errors.ConnectException; +import org.apache.kafka.connect.source.SourceRecord; +import org.apache.kafka.connect.source.SourceTask; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * FileStreamSourceTask reads from stdin or a file. + */ +public class FileStreamSourceTask extends SourceTask { + private static final Logger log = LoggerFactory.getLogger(FileStreamSourceTask.class); + public static final String FILENAME_FIELD = "filename"; + public static final String POSITION_FIELD = "position"; + private static final Schema VALUE_SCHEMA = Schema.STRING_SCHEMA; + + private String filename; + private InputStream stream; + private BufferedReader reader = null; + private char[] buffer; + private int offset = 0; + private String topic = null; + private int batchSize = FileStreamSourceConnector.DEFAULT_TASK_BATCH_SIZE; + + private Long streamOffset; + + public FileStreamSourceTask() { + this(1024); + } + + /* visible for testing */ + FileStreamSourceTask(int initialBufferSize) { + buffer = new char[initialBufferSize]; + } + + @Override + public String version() { + return new FileStreamSourceConnector().version(); + } + + @Override + public void start(Map props) { + filename = props.get(FileStreamSourceConnector.FILE_CONFIG); + if (filename == null || filename.isEmpty()) { + stream = System.in; + // Tracking offset for stdin doesn't make sense + streamOffset = null; + reader = new BufferedReader(new InputStreamReader(stream, StandardCharsets.UTF_8)); + } + // Missing topic or parsing error is not possible because we've parsed the config in the + // Connector + topic = props.get(FileStreamSourceConnector.TOPIC_CONFIG); + batchSize = Integer.parseInt(props.get(FileStreamSourceConnector.TASK_BATCH_SIZE_CONFIG)); + } + + @Override + public List poll() throws InterruptedException { + if (stream == null) { + try { + stream = Files.newInputStream(Paths.get(filename)); + Map offset = context.offsetStorageReader().offset(Collections.singletonMap(FILENAME_FIELD, filename)); + if (offset != null) { + Object lastRecordedOffset = offset.get(POSITION_FIELD); + if (lastRecordedOffset != null && !(lastRecordedOffset instanceof Long)) + throw new ConnectException("Offset position is the incorrect type"); + if (lastRecordedOffset != null) { + log.debug("Found previous offset, trying to skip to file offset {}", lastRecordedOffset); + long skipLeft = (Long) lastRecordedOffset; + while (skipLeft > 0) { + try { + long skipped = stream.skip(skipLeft); + skipLeft -= skipped; + } catch (IOException e) { + log.error("Error while trying to seek to previous offset in file {}: ", filename, e); + throw 
new ConnectException(e); + } + } + log.debug("Skipped to offset {}", lastRecordedOffset); + } + streamOffset = (lastRecordedOffset != null) ? (Long) lastRecordedOffset : 0L; + } else { + streamOffset = 0L; + } + reader = new BufferedReader(new InputStreamReader(stream, StandardCharsets.UTF_8)); + log.debug("Opened {} for reading", logFilename()); + } catch (NoSuchFileException e) { + log.warn("Couldn't find file {} for FileStreamSourceTask, sleeping to wait for it to be created", logFilename()); + synchronized (this) { + this.wait(1000); + } + return null; + } catch (IOException e) { + log.error("Error while trying to open file {}: ", filename, e); + throw new ConnectException(e); + } + } + + // Unfortunately we can't just use readLine() because it blocks in an uninterruptible way. + // Instead we have to manage splitting lines ourselves, using simple backoff when no new data + // is available. + try { + final BufferedReader readerCopy; + synchronized (this) { + readerCopy = reader; + } + if (readerCopy == null) + return null; + + ArrayList records = null; + + int nread = 0; + while (readerCopy.ready()) { + nread = readerCopy.read(buffer, offset, buffer.length - offset); + log.trace("Read {} bytes from {}", nread, logFilename()); + + if (nread > 0) { + offset += nread; + String line; + boolean foundOneLine = false; + do { + line = extractLine(); + if (line != null) { + foundOneLine = true; + log.trace("Read a line from {}", logFilename()); + if (records == null) + records = new ArrayList<>(); + records.add(new SourceRecord(offsetKey(filename), offsetValue(streamOffset), topic, null, + null, null, VALUE_SCHEMA, line, System.currentTimeMillis())); + + if (records.size() >= batchSize) { + return records; + } + } + } while (line != null); + + if (!foundOneLine && offset == buffer.length) { + char[] newbuf = new char[buffer.length * 2]; + System.arraycopy(buffer, 0, newbuf, 0, buffer.length); + log.info("Increased buffer from {} to {}", buffer.length, newbuf.length); + buffer = newbuf; + } + } + } + + if (nread <= 0) + synchronized (this) { + this.wait(1000); + } + + return records; + } catch (IOException e) { + // Underlying stream was killed, probably as a result of calling stop. Allow to return + // null, and driving thread will handle any shutdown if necessary. + } + return null; + } + + private String extractLine() { + int until = -1, newStart = -1; + for (int i = 0; i < offset; i++) { + if (buffer[i] == '\n') { + until = i; + newStart = i + 1; + break; + } else if (buffer[i] == '\r') { + // We need to check for \r\n, so we must skip this if we can't check the next char + if (i + 1 >= offset) + return null; + + until = i; + newStart = (buffer[i + 1] == '\n') ? 
i + 2 : i + 1; + break; + } + } + + if (until != -1) { + String result = new String(buffer, 0, until); + System.arraycopy(buffer, newStart, buffer, 0, buffer.length - newStart); + offset = offset - newStart; + if (streamOffset != null) + streamOffset += newStart; + return result; + } else { + return null; + } + } + + @Override + public void stop() { + log.trace("Stopping"); + synchronized (this) { + try { + if (stream != null && stream != System.in) { + stream.close(); + log.trace("Closed input stream"); + } + } catch (IOException e) { + log.error("Failed to close FileStreamSourceTask stream: ", e); + } + this.notify(); + } + } + + private Map offsetKey(String filename) { + return Collections.singletonMap(FILENAME_FIELD, filename); + } + + private Map offsetValue(Long pos) { + return Collections.singletonMap(POSITION_FIELD, pos); + } + + private String logFilename() { + return filename == null ? "stdin" : filename; + } + + /* visible for testing */ + int bufferSize() { + return buffer.length; + } +} diff --git a/connect/file/src/test/java/org/apache/kafka/connect/file/FileStreamSinkConnectorTest.java b/connect/file/src/test/java/org/apache/kafka/connect/file/FileStreamSinkConnectorTest.java new file mode 100644 index 0000000..548e388 --- /dev/null +++ b/connect/file/src/test/java/org/apache/kafka/connect/file/FileStreamSinkConnectorTest.java @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
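FileStreamSourceTask above cannot rely on readLine() because it blocks uninterruptibly, so it accumulates characters in a buffer, cuts complete lines off the front on '\n', '\r\n' or a lone '\r', and doubles the buffer when a full read arrives without a line terminator. The following stand-alone sketch mirrors the splitting logic in simplified form (the growth policy here is "grow until the chunk fits" rather than the task's "double when full"):

    import java.util.ArrayList;
    import java.util.List;

    public class LineBufferSketch {
        private char[] buffer = new char[4];    // the real task starts at 1024 and doubles on demand
        private int offset = 0;

        List<String> append(String chunk) {
            while (offset + chunk.length() > buffer.length) {          // grow until the new data fits
                char[] bigger = new char[buffer.length * 2];
                System.arraycopy(buffer, 0, bigger, 0, offset);
                buffer = bigger;
            }
            chunk.getChars(0, chunk.length(), buffer, offset);
            offset += chunk.length();

            List<String> lines = new ArrayList<>();
            String line;
            while ((line = extractLine()) != null)
                lines.add(line);
            return lines;
        }

        private String extractLine() {
            for (int i = 0; i < offset; i++) {
                if (buffer[i] == '\n')
                    return cut(i, i + 1);
                if (buffer[i] == '\r') {
                    if (i + 1 >= offset)
                        return null;                                   // might be the start of "\r\n"; wait for more data
                    return cut(i, buffer[i + 1] == '\n' ? i + 2 : i + 1);
                }
            }
            return null;                                               // no complete line buffered yet
        }

        private String cut(int until, int newStart) {
            String line = new String(buffer, 0, until);
            System.arraycopy(buffer, newStart, buffer, 0, offset - newStart);
            offset -= newStart;
            return line;
        }

        public static void main(String[] args) {
            LineBufferSketch sketch = new LineBufferSketch();
            System.out.println(sketch.append("line1\rline2\r\nline3\n"));   // [line1, line2, line3]
            System.out.println(sketch.append("partial"));                   // []
            System.out.println(sketch.append(" line\n"));                   // [partial line]
        }
    }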
+ */ +package org.apache.kafka.connect.file; + +import org.apache.kafka.common.config.ConfigValue; +import org.apache.kafka.connect.connector.ConnectorContext; +import org.apache.kafka.connect.sink.SinkConnector; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.mockito.Mockito.mock; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +public class FileStreamSinkConnectorTest { + + private static final String MULTIPLE_TOPICS = "test1,test2"; + private static final String FILENAME = "/afilename"; + + private FileStreamSinkConnector connector; + private ConnectorContext ctx; + private Map sinkProperties; + + @BeforeEach + public void setup() { + connector = new FileStreamSinkConnector(); + ctx = mock(ConnectorContext.class); + connector.initialize(ctx); + + sinkProperties = new HashMap<>(); + sinkProperties.put(SinkConnector.TOPICS_CONFIG, MULTIPLE_TOPICS); + sinkProperties.put(FileStreamSinkConnector.FILE_CONFIG, FILENAME); + } + + @Test + public void testConnectorConfigValidation() { + List configValues = connector.config().validate(sinkProperties); + for (ConfigValue val : configValues) { + assertEquals(0, val.errorMessages().size(), "Config property errors: " + val.errorMessages()); + } + } + + @Test + public void testSinkTasks() { + connector.start(sinkProperties); + List> taskConfigs = connector.taskConfigs(1); + assertEquals(1, taskConfigs.size()); + assertEquals(FILENAME, taskConfigs.get(0).get(FileStreamSinkConnector.FILE_CONFIG)); + + taskConfigs = connector.taskConfigs(2); + assertEquals(2, taskConfigs.size()); + for (int i = 0; i < 2; i++) { + assertEquals(FILENAME, taskConfigs.get(0).get(FileStreamSinkConnector.FILE_CONFIG)); + } + } + + @Test + public void testSinkTasksStdout() { + sinkProperties.remove(FileStreamSourceConnector.FILE_CONFIG); + connector.start(sinkProperties); + List> taskConfigs = connector.taskConfigs(1); + assertEquals(1, taskConfigs.size()); + assertNull(taskConfigs.get(0).get(FileStreamSourceConnector.FILE_CONFIG)); + } + + @Test + public void testTaskClass() { + connector.start(sinkProperties); + assertEquals(FileStreamSinkTask.class, connector.taskClass()); + } +} diff --git a/connect/file/src/test/java/org/apache/kafka/connect/file/FileStreamSinkTaskTest.java b/connect/file/src/test/java/org/apache/kafka/connect/file/FileStreamSinkTaskTest.java new file mode 100644 index 0000000..3878530 --- /dev/null +++ b/connect/file/src/test/java/org/apache/kafka/connect/file/FileStreamSinkTaskTest.java @@ -0,0 +1,115 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
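For completeness, the source-side counterpart of the earlier sink example: FileStreamSourceConnector requires exactly one 'topic', takes an optional 'file' (falling back to standard input) and a 'batch.size' (defaulting to 2000), and always produces a single task configuration. A sketch of a standalone-mode configuration, again with placeholder names and paths:

    name=local-file-source
    connector.class=org.apache.kafka.connect.file.FileStreamSourceConnector
    tasks.max=1
    file=/tmp/connect.input
    topic=connect-test
    # optional; defaults to 2000
    batch.size=2000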
+ */ +package org.apache.kafka.connect.file; + +import org.apache.kafka.clients.consumer.OffsetAndMetadata; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.connect.data.Schema; +import org.apache.kafka.connect.sink.SinkRecord; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; + +import java.io.BufferedReader; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.Arrays; +import java.util.HashMap; +import java.util.Map; +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class FileStreamSinkTaskTest { + + private FileStreamSinkTask task; + private ByteArrayOutputStream os; + private PrintStream printStream; + + @TempDir + public Path topDir; + private String outputFile; + + @BeforeEach + public void setup() { + os = new ByteArrayOutputStream(); + printStream = new PrintStream(os); + task = new FileStreamSinkTask(printStream); + outputFile = topDir.resolve("connect.output").toAbsolutePath().toString(); + } + + @Test + public void testPutFlush() { + HashMap offsets = new HashMap<>(); + final String newLine = System.getProperty("line.separator"); + + // We do not call task.start() since it would override the output stream + + task.put(Arrays.asList( + new SinkRecord("topic1", 0, null, null, Schema.STRING_SCHEMA, "line1", 1) + )); + offsets.put(new TopicPartition("topic1", 0), new OffsetAndMetadata(1L)); + task.flush(offsets); + assertEquals("line1" + newLine, os.toString()); + + task.put(Arrays.asList( + new SinkRecord("topic1", 0, null, null, Schema.STRING_SCHEMA, "line2", 2), + new SinkRecord("topic2", 0, null, null, Schema.STRING_SCHEMA, "line3", 1) + )); + offsets.put(new TopicPartition("topic1", 0), new OffsetAndMetadata(2L)); + offsets.put(new TopicPartition("topic2", 0), new OffsetAndMetadata(1L)); + task.flush(offsets); + assertEquals("line1" + newLine + "line2" + newLine + "line3" + newLine, os.toString()); + } + + @Test + public void testStart() throws IOException { + task = new FileStreamSinkTask(); + Map props = new HashMap<>(); + props.put(FileStreamSinkConnector.FILE_CONFIG, outputFile); + task.start(props); + + HashMap offsets = new HashMap<>(); + task.put(Arrays.asList( + new SinkRecord("topic1", 0, null, null, Schema.STRING_SCHEMA, "line0", 1) + )); + offsets.put(new TopicPartition("topic1", 0), new OffsetAndMetadata(1L)); + task.flush(offsets); + + int numLines = 3; + String[] lines = new String[numLines]; + int i = 0; + try (BufferedReader reader = Files.newBufferedReader(Paths.get(outputFile))) { + lines[i++] = reader.readLine(); + task.put(Arrays.asList( + new SinkRecord("topic1", 0, null, null, Schema.STRING_SCHEMA, "line1", 2), + new SinkRecord("topic2", 0, null, null, Schema.STRING_SCHEMA, "line2", 1) + )); + offsets.put(new TopicPartition("topic1", 0), new OffsetAndMetadata(2L)); + offsets.put(new TopicPartition("topic2", 0), new OffsetAndMetadata(1L)); + task.flush(offsets); + lines[i++] = reader.readLine(); + lines[i++] = reader.readLine(); + } + + while (--i >= 0) { + assertEquals("line" + i, lines[i]); + } + } +} diff --git a/connect/file/src/test/java/org/apache/kafka/connect/file/FileStreamSourceConnectorTest.java b/connect/file/src/test/java/org/apache/kafka/connect/file/FileStreamSourceConnectorTest.java new file mode 100644 index 0000000..3550d5c --- /dev/null +++ 
b/connect/file/src/test/java/org/apache/kafka/connect/file/FileStreamSourceConnectorTest.java @@ -0,0 +1,136 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.file; + +import org.apache.kafka.common.config.ConfigException; +import org.apache.kafka.common.config.ConfigValue; +import org.apache.kafka.connect.connector.ConnectorContext; +import org.easymock.EasyMock; +import org.easymock.EasyMockSupport; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +public class FileStreamSourceConnectorTest extends EasyMockSupport { + + private static final String SINGLE_TOPIC = "test"; + private static final String MULTIPLE_TOPICS = "test1,test2"; + private static final String FILENAME = "/somefilename"; + + private FileStreamSourceConnector connector; + private ConnectorContext ctx; + private Map sourceProperties; + + @BeforeEach + public void setup() { + connector = new FileStreamSourceConnector(); + ctx = createMock(ConnectorContext.class); + connector.initialize(ctx); + + sourceProperties = new HashMap<>(); + sourceProperties.put(FileStreamSourceConnector.TOPIC_CONFIG, SINGLE_TOPIC); + sourceProperties.put(FileStreamSourceConnector.FILE_CONFIG, FILENAME); + } + + @Test + public void testConnectorConfigValidation() { + replayAll(); + List configValues = connector.config().validate(sourceProperties); + for (ConfigValue val : configValues) { + assertEquals(0, val.errorMessages().size(), "Config property errors: " + val.errorMessages()); + } + verifyAll(); + } + + @Test + public void testSourceTasks() { + replayAll(); + + connector.start(sourceProperties); + List> taskConfigs = connector.taskConfigs(1); + assertEquals(1, taskConfigs.size()); + assertEquals(FILENAME, + taskConfigs.get(0).get(FileStreamSourceConnector.FILE_CONFIG)); + assertEquals(SINGLE_TOPIC, + taskConfigs.get(0).get(FileStreamSourceConnector.TOPIC_CONFIG)); + + // Should be able to return fewer than requested # + taskConfigs = connector.taskConfigs(2); + assertEquals(1, taskConfigs.size()); + assertEquals(FILENAME, + taskConfigs.get(0).get(FileStreamSourceConnector.FILE_CONFIG)); + assertEquals(SINGLE_TOPIC, + taskConfigs.get(0).get(FileStreamSourceConnector.TOPIC_CONFIG)); + + verifyAll(); + } + + @Test + public void testSourceTasksStdin() { + EasyMock.replay(ctx); + + sourceProperties.remove(FileStreamSourceConnector.FILE_CONFIG); + connector.start(sourceProperties); + List> taskConfigs = connector.taskConfigs(1); + assertEquals(1, taskConfigs.size()); + 
assertNull(taskConfigs.get(0).get(FileStreamSourceConnector.FILE_CONFIG)); + + EasyMock.verify(ctx); + } + + @Test + public void testMultipleSourcesInvalid() { + sourceProperties.put(FileStreamSourceConnector.TOPIC_CONFIG, MULTIPLE_TOPICS); + assertThrows(ConfigException.class, () -> connector.start(sourceProperties)); + } + + @Test + public void testTaskClass() { + EasyMock.replay(ctx); + + connector.start(sourceProperties); + assertEquals(FileStreamSourceTask.class, connector.taskClass()); + + EasyMock.verify(ctx); + } + + @Test + public void testMissingTopic() { + sourceProperties.remove(FileStreamSourceConnector.TOPIC_CONFIG); + assertThrows(ConfigException.class, () -> connector.start(sourceProperties)); + } + + @Test + public void testBlankTopic() { + // Because of trimming this tests is same as testing for empty string. + sourceProperties.put(FileStreamSourceConnector.TOPIC_CONFIG, " "); + assertThrows(ConfigException.class, () -> connector.start(sourceProperties)); + } + + @Test + public void testInvalidBatchSize() { + sourceProperties.put(FileStreamSourceConnector.TASK_BATCH_SIZE_CONFIG, "abcd"); + assertThrows(ConfigException.class, () -> connector.start(sourceProperties)); + } +} diff --git a/connect/file/src/test/java/org/apache/kafka/connect/file/FileStreamSourceTaskTest.java b/connect/file/src/test/java/org/apache/kafka/connect/file/FileStreamSourceTaskTest.java new file mode 100644 index 0000000..d02463d --- /dev/null +++ b/connect/file/src/test/java/org/apache/kafka/connect/file/FileStreamSourceTaskTest.java @@ -0,0 +1,236 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
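FileStreamSourceTaskTest below asserts the exact sourcePartition and sourceOffset maps attached to each record. As a reminder of the bookkeeping in FileStreamSourceTask: the partition identifies the file, the offset is the position immediately after the last emitted line (terminator included), and on restart the stored position is looked up and skipped over before reading resumes. A minimal illustration, with placeholder values, outside the Connect runtime:

    import java.util.Collections;
    import java.util.Map;

    public class SourceOffsetSketch {
        public static void main(String[] args) {
            // What FileStreamSourceTask attaches to each SourceRecord:
            Map<String, String> sourcePartition = Collections.singletonMap("filename", "/tmp/connect.input");
            Map<String, Long> sourceOffset = Collections.singletonMap("position", 22L);
            System.out.println(sourcePartition + " -> " + sourceOffset);

            // On restart, the task reads the stored position back through context.offsetStorageReader()
            // and calls InputStream.skip(position) before it resumes reading lines.
        }
    }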
+ */ +package org.apache.kafka.connect.file; + +import org.apache.kafka.connect.source.SourceRecord; +import org.apache.kafka.connect.source.SourceTaskContext; +import org.apache.kafka.connect.storage.OffsetStorageReader; +import org.easymock.EasyMock; +import org.easymock.EasyMockSupport; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.io.ByteArrayInputStream; +import java.io.File; +import java.io.IOException; +import java.io.OutputStream; +import java.nio.file.Files; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNull; + +public class FileStreamSourceTaskTest extends EasyMockSupport { + + private static final String TOPIC = "test"; + + private File tempFile; + private Map config; + private OffsetStorageReader offsetStorageReader; + private SourceTaskContext context; + private FileStreamSourceTask task; + + private boolean verifyMocks = false; + + @BeforeEach + public void setup() throws IOException { + tempFile = File.createTempFile("file-stream-source-task-test", null); + config = new HashMap<>(); + config.put(FileStreamSourceConnector.FILE_CONFIG, tempFile.getAbsolutePath()); + config.put(FileStreamSourceConnector.TOPIC_CONFIG, TOPIC); + config.put(FileStreamSourceConnector.TASK_BATCH_SIZE_CONFIG, String.valueOf(FileStreamSourceConnector.DEFAULT_TASK_BATCH_SIZE)); + task = new FileStreamSourceTask(2); + offsetStorageReader = createMock(OffsetStorageReader.class); + context = createMock(SourceTaskContext.class); + task.initialize(context); + } + + @AfterEach + public void teardown() { + tempFile.delete(); + + if (verifyMocks) + verifyAll(); + } + + private void replay() { + replayAll(); + verifyMocks = true; + } + + @Test + public void testNormalLifecycle() throws InterruptedException, IOException { + expectOffsetLookupReturnNone(); + replay(); + + task.start(config); + + OutputStream os = Files.newOutputStream(tempFile.toPath()); + assertNull(task.poll()); + os.write("partial line".getBytes()); + os.flush(); + assertNull(task.poll()); + os.write(" finished\n".getBytes()); + os.flush(); + List records = task.poll(); + assertEquals(1, records.size()); + assertEquals(TOPIC, records.get(0).topic()); + assertEquals("partial line finished", records.get(0).value()); + assertEquals(Collections.singletonMap(FileStreamSourceTask.FILENAME_FIELD, tempFile.getAbsolutePath()), records.get(0).sourcePartition()); + assertEquals(Collections.singletonMap(FileStreamSourceTask.POSITION_FIELD, 22L), records.get(0).sourceOffset()); + assertNull(task.poll()); + + // Different line endings, and make sure the final \r doesn't result in a line until we can + // read the subsequent byte. 
+ os.write("line1\rline2\r\nline3\nline4\n\r".getBytes()); + os.flush(); + records = task.poll(); + assertEquals(4, records.size()); + assertEquals("line1", records.get(0).value()); + assertEquals(Collections.singletonMap(FileStreamSourceTask.FILENAME_FIELD, tempFile.getAbsolutePath()), records.get(0).sourcePartition()); + assertEquals(Collections.singletonMap(FileStreamSourceTask.POSITION_FIELD, 28L), records.get(0).sourceOffset()); + assertEquals("line2", records.get(1).value()); + assertEquals(Collections.singletonMap(FileStreamSourceTask.FILENAME_FIELD, tempFile.getAbsolutePath()), records.get(1).sourcePartition()); + assertEquals(Collections.singletonMap(FileStreamSourceTask.POSITION_FIELD, 35L), records.get(1).sourceOffset()); + assertEquals("line3", records.get(2).value()); + assertEquals(Collections.singletonMap(FileStreamSourceTask.FILENAME_FIELD, tempFile.getAbsolutePath()), records.get(2).sourcePartition()); + assertEquals(Collections.singletonMap(FileStreamSourceTask.POSITION_FIELD, 41L), records.get(2).sourceOffset()); + assertEquals("line4", records.get(3).value()); + assertEquals(Collections.singletonMap(FileStreamSourceTask.FILENAME_FIELD, tempFile.getAbsolutePath()), records.get(3).sourcePartition()); + assertEquals(Collections.singletonMap(FileStreamSourceTask.POSITION_FIELD, 47L), records.get(3).sourceOffset()); + + os.write("subsequent text".getBytes()); + os.flush(); + records = task.poll(); + assertEquals(1, records.size()); + assertEquals("", records.get(0).value()); + assertEquals(Collections.singletonMap(FileStreamSourceTask.FILENAME_FIELD, tempFile.getAbsolutePath()), records.get(0).sourcePartition()); + assertEquals(Collections.singletonMap(FileStreamSourceTask.POSITION_FIELD, 48L), records.get(0).sourceOffset()); + + os.close(); + task.stop(); + } + + @Test + public void testBatchSize() throws IOException, InterruptedException { + expectOffsetLookupReturnNone(); + replay(); + + config.put(FileStreamSourceConnector.TASK_BATCH_SIZE_CONFIG, "5000"); + task.start(config); + + OutputStream os = Files.newOutputStream(tempFile.toPath()); + writeTimesAndFlush(os, 10_000, + "Neque porro quisquam est qui dolorem ipsum quia dolor sit amet, consectetur, adipisci velit...\n".getBytes() + ); + + assertEquals(2, task.bufferSize()); + List records = task.poll(); + assertEquals(5000, records.size()); + assertEquals(128, task.bufferSize()); + + records = task.poll(); + assertEquals(5000, records.size()); + assertEquals(128, task.bufferSize()); + + os.close(); + task.stop(); + } + + @Test + public void testBufferResize() throws IOException, InterruptedException { + int batchSize = 1000; + expectOffsetLookupReturnNone(); + replay(); + + config.put(FileStreamSourceConnector.TASK_BATCH_SIZE_CONFIG, Integer.toString(batchSize)); + task.start(config); + + OutputStream os = Files.newOutputStream(tempFile.toPath()); + + assertEquals(2, task.bufferSize()); + writeAndAssertBufferSize(batchSize, os, "1\n".getBytes(), 2); + writeAndAssertBufferSize(batchSize, os, "3 \n".getBytes(), 4); + writeAndAssertBufferSize(batchSize, os, "7 \n".getBytes(), 8); + writeAndAssertBufferSize(batchSize, os, "8 \n".getBytes(), 8); + writeAndAssertBufferSize(batchSize, os, "9 \n".getBytes(), 16); + + byte[] bytes = new byte[1025]; + Arrays.fill(bytes, (byte) '*'); + bytes[bytes.length - 1] = '\n'; + writeAndAssertBufferSize(batchSize, os, bytes, 2048); + writeAndAssertBufferSize(batchSize, os, "9 \n".getBytes(), 2048); + os.close(); + task.stop(); + } + + private void writeAndAssertBufferSize(int batchSize, 
OutputStream os, byte[] bytes, int expectBufferSize) + throws IOException, InterruptedException { + writeTimesAndFlush(os, batchSize, bytes); + List records = task.poll(); + assertEquals(batchSize, records.size()); + String expectedLine = new String(bytes, 0, bytes.length - 1); // remove \n + for (SourceRecord record : records) { + assertEquals(expectedLine, record.value()); + } + assertEquals(expectBufferSize, task.bufferSize()); + } + + private void writeTimesAndFlush(OutputStream os, int times, byte[] line) throws IOException { + for (int i = 0; i < times; i++) { + os.write(line); + } + os.flush(); + } + + @Test + public void testMissingFile() throws InterruptedException { + replay(); + + String data = "line\n"; + System.setIn(new ByteArrayInputStream(data.getBytes())); + + config.remove(FileStreamSourceConnector.FILE_CONFIG); + task.start(config); + + List records = task.poll(); + assertEquals(1, records.size()); + assertEquals(TOPIC, records.get(0).topic()); + assertEquals("line", records.get(0).value()); + + task.stop(); + } + + public void testInvalidFile() throws InterruptedException { + config.put(FileStreamSourceConnector.FILE_CONFIG, "bogusfilename"); + task.start(config); + // Currently the task retries indefinitely if the file isn't found, but shouldn't return any data. + for (int i = 0; i < 100; i++) + assertNull(task.poll()); + } + + + private void expectOffsetLookupReturnNone() { + EasyMock.expect(context.offsetStorageReader()).andReturn(offsetStorageReader); + EasyMock.expect(offsetStorageReader.offset(EasyMock.>anyObject())).andReturn(null); + } +} diff --git a/connect/json/.gitignore b/connect/json/.gitignore new file mode 100644 index 0000000..ae3c172 --- /dev/null +++ b/connect/json/.gitignore @@ -0,0 +1 @@ +/bin/ diff --git a/connect/json/src/main/java/org/apache/kafka/connect/json/DecimalFormat.java b/connect/json/src/main/java/org/apache/kafka/connect/json/DecimalFormat.java new file mode 100644 index 0000000..b4a7fc5 --- /dev/null +++ b/connect/json/src/main/java/org/apache/kafka/connect/json/DecimalFormat.java @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.json; + +/** + * Represents the valid {@link org.apache.kafka.connect.data.Decimal} serialization formats + * in a {@link JsonConverter}. + */ +public enum DecimalFormat { + + /** + * Serializes the JSON Decimal as a base-64 string. For example, serializing the value + * `10.2345` with the BASE64 setting will result in `"D3J5"`. + */ + BASE64, + + /** + * Serializes the JSON Decimal as a JSON number. For example, serializing the value + * `10.2345` with the NUMERIC setting will result in `10.2345`. 
+ */ + NUMERIC +} diff --git a/connect/json/src/main/java/org/apache/kafka/connect/json/JsonConverter.java b/connect/json/src/main/java/org/apache/kafka/connect/json/JsonConverter.java new file mode 100644 index 0000000..10fde8f --- /dev/null +++ b/connect/json/src/main/java/org/apache/kafka/connect/json/JsonConverter.java @@ -0,0 +1,741 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.json; + +import com.fasterxml.jackson.databind.DeserializationFeature; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.node.ArrayNode; +import com.fasterxml.jackson.databind.node.JsonNodeFactory; +import com.fasterxml.jackson.databind.node.ObjectNode; +import org.apache.kafka.common.cache.Cache; +import org.apache.kafka.common.cache.LRUCache; +import org.apache.kafka.common.cache.SynchronizedCache; +import org.apache.kafka.common.config.ConfigDef; +import org.apache.kafka.common.errors.SerializationException; +import org.apache.kafka.connect.data.SchemaBuilder; +import org.apache.kafka.connect.data.Schema; +import org.apache.kafka.connect.data.Struct; +import org.apache.kafka.connect.data.Field; +import org.apache.kafka.connect.data.ConnectSchema; +import org.apache.kafka.connect.data.SchemaAndValue; +import org.apache.kafka.connect.data.Timestamp; +import org.apache.kafka.connect.data.Time; +import org.apache.kafka.connect.data.Decimal; +import org.apache.kafka.connect.data.Date; +import org.apache.kafka.connect.errors.DataException; +import org.apache.kafka.connect.storage.Converter; +import org.apache.kafka.connect.storage.ConverterType; +import org.apache.kafka.connect.storage.HeaderConverter; +import org.apache.kafka.connect.storage.StringConverterConfig; + +import java.io.IOException; +import java.math.BigDecimal; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Collection; +import java.util.EnumMap; +import java.util.HashMap; +import java.util.Iterator; +import java.util.Map; + +import static org.apache.kafka.common.utils.Utils.mkSet; + +/** + * Implementation of Converter that uses JSON to store schemas and objects. By default this converter will serialize Connect keys, values, + * and headers with schemas, although this can be disabled with {@link JsonConverterConfig#SCHEMAS_ENABLE_CONFIG schemas.enable} + * configuration option. + * + * This implementation currently does nothing with the topic names or header names. 
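[Editor's note, not part of the patch] A minimal standalone usage sketch of the converter described by the Javadoc above, using only the public API introduced in this file (the topic name and printed envelope are illustrative assumptions):

import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.Map;

import org.apache.kafka.connect.data.Schema;
import org.apache.kafka.connect.data.SchemaAndValue;
import org.apache.kafka.connect.json.JsonConverter;

public class JsonConverterUsageSketch {
    public static void main(String[] args) {
        JsonConverter converter = new JsonConverter();
        Map<String, Object> config = new HashMap<>();
        config.put("schemas.enable", true);       // keep the schema/payload envelope
        converter.configure(config, false);       // false = acting as a value converter

        // Serialize a Connect value; with schemas enabled the output is an envelope,
        // e.g. {"schema":{"type":"string","optional":false},"payload":"hello"}
        byte[] bytes = converter.fromConnectData("example-topic", Schema.STRING_SCHEMA, "hello");
        System.out.println(new String(bytes, StandardCharsets.UTF_8));

        // Deserialize back into a schema + value pair
        SchemaAndValue restored = converter.toConnectData("example-topic", bytes);
        System.out.println(restored.schema().type() + " / " + restored.value());
    }
}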
+ */ +public class JsonConverter implements Converter, HeaderConverter { + + private static final Map TO_CONNECT_CONVERTERS = new EnumMap<>(Schema.Type.class); + + static { + TO_CONNECT_CONVERTERS.put(Schema.Type.BOOLEAN, (schema, value) -> value.booleanValue()); + TO_CONNECT_CONVERTERS.put(Schema.Type.INT8, (schema, value) -> (byte) value.intValue()); + TO_CONNECT_CONVERTERS.put(Schema.Type.INT16, (schema, value) -> (short) value.intValue()); + TO_CONNECT_CONVERTERS.put(Schema.Type.INT32, (schema, value) -> value.intValue()); + TO_CONNECT_CONVERTERS.put(Schema.Type.INT64, (schema, value) -> value.longValue()); + TO_CONNECT_CONVERTERS.put(Schema.Type.FLOAT32, (schema, value) -> value.floatValue()); + TO_CONNECT_CONVERTERS.put(Schema.Type.FLOAT64, (schema, value) -> value.doubleValue()); + TO_CONNECT_CONVERTERS.put(Schema.Type.BYTES, (schema, value) -> { + try { + return value.binaryValue(); + } catch (IOException e) { + throw new DataException("Invalid bytes field", e); + } + }); + TO_CONNECT_CONVERTERS.put(Schema.Type.STRING, (schema, value) -> value.textValue()); + TO_CONNECT_CONVERTERS.put(Schema.Type.ARRAY, (schema, value) -> { + Schema elemSchema = schema == null ? null : schema.valueSchema(); + ArrayList result = new ArrayList<>(); + for (JsonNode elem : value) { + result.add(convertToConnect(elemSchema, elem)); + } + return result; + }); + TO_CONNECT_CONVERTERS.put(Schema.Type.MAP, (schema, value) -> { + Schema keySchema = schema == null ? null : schema.keySchema(); + Schema valueSchema = schema == null ? null : schema.valueSchema(); + + // If the map uses strings for keys, it should be encoded in the natural JSON format. If it uses other + // primitive types or a complex type as a key, it will be encoded as a list of pairs. If we don't have a + // schema, we default to encoding in a Map. + Map result = new HashMap<>(); + if (schema == null || keySchema.type() == Schema.Type.STRING) { + if (!value.isObject()) + throw new DataException("Maps with string fields should be encoded as JSON objects, but found " + value.getNodeType()); + Iterator> fieldIt = value.fields(); + while (fieldIt.hasNext()) { + Map.Entry entry = fieldIt.next(); + result.put(entry.getKey(), convertToConnect(valueSchema, entry.getValue())); + } + } else { + if (!value.isArray()) + throw new DataException("Maps with non-string fields should be encoded as JSON array of tuples, but found " + value.getNodeType()); + for (JsonNode entry : value) { + if (!entry.isArray()) + throw new DataException("Found invalid map entry instead of array tuple: " + entry.getNodeType()); + if (entry.size() != 2) + throw new DataException("Found invalid map entry, expected length 2 but found :" + entry.size()); + result.put(convertToConnect(keySchema, entry.get(0)), + convertToConnect(valueSchema, entry.get(1))); + } + } + return result; + }); + TO_CONNECT_CONVERTERS.put(Schema.Type.STRUCT, (schema, value) -> { + if (!value.isObject()) + throw new DataException("Structs should be encoded as JSON objects, but found " + value.getNodeType()); + + // We only have ISchema here but need Schema, so we need to materialize the actual schema. Using ISchema + // avoids having to materialize the schema for non-Struct types but it cannot be avoided for Structs since + // they require a schema to be provided at construction. However, the schema is only a SchemaBuilder during + // translation of schemas to JSON; during the more common translation of data to JSON, the call to schema.schema() + // just returns the schema Object and has no overhead. 
+ Struct result = new Struct(schema.schema()); + for (Field field : schema.fields()) + result.put(field, convertToConnect(field.schema(), value.get(field.name()))); + + return result; + }); + } + + // Convert values in Kafka Connect form into/from their logical types. These logical converters are discovered by logical type + // names specified in the field + private static final HashMap LOGICAL_CONVERTERS = new HashMap<>(); + + private static final JsonNodeFactory JSON_NODE_FACTORY = JsonNodeFactory.withExactBigDecimals(true); + + static { + LOGICAL_CONVERTERS.put(Decimal.LOGICAL_NAME, new LogicalTypeConverter() { + @Override + public JsonNode toJson(final Schema schema, final Object value, final JsonConverterConfig config) { + if (!(value instanceof BigDecimal)) + throw new DataException("Invalid type for Decimal, expected BigDecimal but was " + value.getClass()); + + final BigDecimal decimal = (BigDecimal) value; + switch (config.decimalFormat()) { + case NUMERIC: + return JSON_NODE_FACTORY.numberNode(decimal); + case BASE64: + return JSON_NODE_FACTORY.binaryNode(Decimal.fromLogical(schema, decimal)); + default: + throw new DataException("Unexpected " + JsonConverterConfig.DECIMAL_FORMAT_CONFIG + ": " + config.decimalFormat()); + } + } + + @Override + public Object toConnect(final Schema schema, final JsonNode value) { + if (value.isNumber()) return value.decimalValue(); + if (value.isBinary() || value.isTextual()) { + try { + return Decimal.toLogical(schema, value.binaryValue()); + } catch (Exception e) { + throw new DataException("Invalid bytes for Decimal field", e); + } + } + + throw new DataException("Invalid type for Decimal, underlying representation should be numeric or bytes but was " + value.getNodeType()); + } + }); + + LOGICAL_CONVERTERS.put(Date.LOGICAL_NAME, new LogicalTypeConverter() { + @Override + public JsonNode toJson(final Schema schema, final Object value, final JsonConverterConfig config) { + if (!(value instanceof java.util.Date)) + throw new DataException("Invalid type for Date, expected Date but was " + value.getClass()); + return JSON_NODE_FACTORY.numberNode(Date.fromLogical(schema, (java.util.Date) value)); + } + + @Override + public Object toConnect(final Schema schema, final JsonNode value) { + if (!(value.isInt())) + throw new DataException("Invalid type for Date, underlying representation should be integer but was " + value.getNodeType()); + return Date.toLogical(schema, value.intValue()); + } + }); + + LOGICAL_CONVERTERS.put(Time.LOGICAL_NAME, new LogicalTypeConverter() { + @Override + public JsonNode toJson(final Schema schema, final Object value, final JsonConverterConfig config) { + if (!(value instanceof java.util.Date)) + throw new DataException("Invalid type for Time, expected Date but was " + value.getClass()); + return JSON_NODE_FACTORY.numberNode(Time.fromLogical(schema, (java.util.Date) value)); + } + + @Override + public Object toConnect(final Schema schema, final JsonNode value) { + if (!(value.isInt())) + throw new DataException("Invalid type for Time, underlying representation should be integer but was " + value.getNodeType()); + return Time.toLogical(schema, value.intValue()); + } + }); + + LOGICAL_CONVERTERS.put(Timestamp.LOGICAL_NAME, new LogicalTypeConverter() { + @Override + public JsonNode toJson(final Schema schema, final Object value, final JsonConverterConfig config) { + if (!(value instanceof java.util.Date)) + throw new DataException("Invalid type for Timestamp, expected Date but was " + value.getClass()); + return 
JSON_NODE_FACTORY.numberNode(Timestamp.fromLogical(schema, (java.util.Date) value)); + } + + @Override + public Object toConnect(final Schema schema, final JsonNode value) { + if (!(value.isIntegralNumber())) + throw new DataException("Invalid type for Timestamp, underlying representation should be integral but was " + value.getNodeType()); + return Timestamp.toLogical(schema, value.longValue()); + } + }); + } + + private JsonConverterConfig config; + private Cache fromConnectSchemaCache; + private Cache toConnectSchemaCache; + + private final JsonSerializer serializer; + private final JsonDeserializer deserializer; + + public JsonConverter() { + serializer = new JsonSerializer( + mkSet(), + JSON_NODE_FACTORY + ); + + deserializer = new JsonDeserializer( + mkSet( + // this ensures that the JsonDeserializer maintains full precision on + // floating point numbers that cannot fit into float64 + DeserializationFeature.USE_BIG_DECIMAL_FOR_FLOATS + ), + JSON_NODE_FACTORY + ); + } + + // visible for testing + long sizeOfFromConnectSchemaCache() { + return fromConnectSchemaCache.size(); + } + + // visible for testing + long sizeOfToConnectSchemaCache() { + return toConnectSchemaCache.size(); + } + + @Override + public ConfigDef config() { + return JsonConverterConfig.configDef(); + } + + @Override + public void configure(Map configs) { + config = new JsonConverterConfig(configs); + + serializer.configure(configs, config.type() == ConverterType.KEY); + deserializer.configure(configs, config.type() == ConverterType.KEY); + + fromConnectSchemaCache = new SynchronizedCache<>(new LRUCache<>(config.schemaCacheSize())); + toConnectSchemaCache = new SynchronizedCache<>(new LRUCache<>(config.schemaCacheSize())); + } + + @Override + public void configure(Map configs, boolean isKey) { + Map conf = new HashMap<>(configs); + conf.put(StringConverterConfig.TYPE_CONFIG, isKey ? ConverterType.KEY.getName() : ConverterType.VALUE.getName()); + configure(conf); + } + + @Override + public void close() { + // do nothing + } + + @Override + public byte[] fromConnectHeader(String topic, String headerKey, Schema schema, Object value) { + return fromConnectData(topic, schema, value); + } + + @Override + public SchemaAndValue toConnectHeader(String topic, String headerKey, byte[] value) { + return toConnectData(topic, value); + } + + @Override + public byte[] fromConnectData(String topic, Schema schema, Object value) { + if (schema == null && value == null) { + return null; + } + + JsonNode jsonValue = config.schemasEnabled() ? 
convertToJsonWithEnvelope(schema, value) : convertToJsonWithoutEnvelope(schema, value); + try { + return serializer.serialize(topic, jsonValue); + } catch (SerializationException e) { + throw new DataException("Converting Kafka Connect data to byte[] failed due to serialization error: ", e); + } + } + + @Override + public SchemaAndValue toConnectData(String topic, byte[] value) { + JsonNode jsonValue; + + // This handles a tombstone message + if (value == null) { + return SchemaAndValue.NULL; + } + + try { + jsonValue = deserializer.deserialize(topic, value); + } catch (SerializationException e) { + throw new DataException("Converting byte[] to Kafka Connect data failed due to serialization error: ", e); + } + + if (config.schemasEnabled() && (!jsonValue.isObject() || jsonValue.size() != 2 || !jsonValue.has(JsonSchema.ENVELOPE_SCHEMA_FIELD_NAME) || !jsonValue.has(JsonSchema.ENVELOPE_PAYLOAD_FIELD_NAME))) + throw new DataException("JsonConverter with schemas.enable requires \"schema\" and \"payload\" fields and may not contain additional fields." + + " If you are trying to deserialize plain JSON data, set schemas.enable=false in your converter configuration."); + + // The deserialized data should either be an envelope object containing the schema and the payload or the schema + // was stripped during serialization and we need to fill in an all-encompassing schema. + if (!config.schemasEnabled()) { + ObjectNode envelope = JSON_NODE_FACTORY.objectNode(); + envelope.set(JsonSchema.ENVELOPE_SCHEMA_FIELD_NAME, null); + envelope.set(JsonSchema.ENVELOPE_PAYLOAD_FIELD_NAME, jsonValue); + jsonValue = envelope; + } + + Schema schema = asConnectSchema(jsonValue.get(JsonSchema.ENVELOPE_SCHEMA_FIELD_NAME)); + return new SchemaAndValue( + schema, + convertToConnect(schema, jsonValue.get(JsonSchema.ENVELOPE_PAYLOAD_FIELD_NAME)) + ); + } + + public ObjectNode asJsonSchema(Schema schema) { + if (schema == null) + return null; + + ObjectNode cached = fromConnectSchemaCache.get(schema); + if (cached != null) + return cached; + + final ObjectNode jsonSchema; + switch (schema.type()) { + case BOOLEAN: + jsonSchema = JsonSchema.BOOLEAN_SCHEMA.deepCopy(); + break; + case BYTES: + jsonSchema = JsonSchema.BYTES_SCHEMA.deepCopy(); + break; + case FLOAT64: + jsonSchema = JsonSchema.DOUBLE_SCHEMA.deepCopy(); + break; + case FLOAT32: + jsonSchema = JsonSchema.FLOAT_SCHEMA.deepCopy(); + break; + case INT8: + jsonSchema = JsonSchema.INT8_SCHEMA.deepCopy(); + break; + case INT16: + jsonSchema = JsonSchema.INT16_SCHEMA.deepCopy(); + break; + case INT32: + jsonSchema = JsonSchema.INT32_SCHEMA.deepCopy(); + break; + case INT64: + jsonSchema = JsonSchema.INT64_SCHEMA.deepCopy(); + break; + case STRING: + jsonSchema = JsonSchema.STRING_SCHEMA.deepCopy(); + break; + case ARRAY: + jsonSchema = JSON_NODE_FACTORY.objectNode().put(JsonSchema.SCHEMA_TYPE_FIELD_NAME, JsonSchema.ARRAY_TYPE_NAME); + jsonSchema.set(JsonSchema.ARRAY_ITEMS_FIELD_NAME, asJsonSchema(schema.valueSchema())); + break; + case MAP: + jsonSchema = JSON_NODE_FACTORY.objectNode().put(JsonSchema.SCHEMA_TYPE_FIELD_NAME, JsonSchema.MAP_TYPE_NAME); + jsonSchema.set(JsonSchema.MAP_KEY_FIELD_NAME, asJsonSchema(schema.keySchema())); + jsonSchema.set(JsonSchema.MAP_VALUE_FIELD_NAME, asJsonSchema(schema.valueSchema())); + break; + case STRUCT: + jsonSchema = JSON_NODE_FACTORY.objectNode().put(JsonSchema.SCHEMA_TYPE_FIELD_NAME, JsonSchema.STRUCT_TYPE_NAME); + ArrayNode fields = JSON_NODE_FACTORY.arrayNode(); + for (Field field : schema.fields()) { + ObjectNode 
fieldJsonSchema = asJsonSchema(field.schema()).deepCopy(); + fieldJsonSchema.put(JsonSchema.STRUCT_FIELD_NAME_FIELD_NAME, field.name()); + fields.add(fieldJsonSchema); + } + jsonSchema.set(JsonSchema.STRUCT_FIELDS_FIELD_NAME, fields); + break; + default: + throw new DataException("Couldn't translate unsupported schema type " + schema + "."); + } + + jsonSchema.put(JsonSchema.SCHEMA_OPTIONAL_FIELD_NAME, schema.isOptional()); + if (schema.name() != null) + jsonSchema.put(JsonSchema.SCHEMA_NAME_FIELD_NAME, schema.name()); + if (schema.version() != null) + jsonSchema.put(JsonSchema.SCHEMA_VERSION_FIELD_NAME, schema.version()); + if (schema.doc() != null) + jsonSchema.put(JsonSchema.SCHEMA_DOC_FIELD_NAME, schema.doc()); + if (schema.parameters() != null) { + ObjectNode jsonSchemaParams = JSON_NODE_FACTORY.objectNode(); + for (Map.Entry prop : schema.parameters().entrySet()) + jsonSchemaParams.put(prop.getKey(), prop.getValue()); + jsonSchema.set(JsonSchema.SCHEMA_PARAMETERS_FIELD_NAME, jsonSchemaParams); + } + if (schema.defaultValue() != null) + jsonSchema.set(JsonSchema.SCHEMA_DEFAULT_FIELD_NAME, convertToJson(schema, schema.defaultValue())); + + fromConnectSchemaCache.put(schema, jsonSchema); + return jsonSchema; + } + + + public Schema asConnectSchema(JsonNode jsonSchema) { + if (jsonSchema.isNull()) + return null; + + Schema cached = toConnectSchemaCache.get(jsonSchema); + if (cached != null) + return cached; + + JsonNode schemaTypeNode = jsonSchema.get(JsonSchema.SCHEMA_TYPE_FIELD_NAME); + if (schemaTypeNode == null || !schemaTypeNode.isTextual()) + throw new DataException("Schema must contain 'type' field"); + + final SchemaBuilder builder; + switch (schemaTypeNode.textValue()) { + case JsonSchema.BOOLEAN_TYPE_NAME: + builder = SchemaBuilder.bool(); + break; + case JsonSchema.INT8_TYPE_NAME: + builder = SchemaBuilder.int8(); + break; + case JsonSchema.INT16_TYPE_NAME: + builder = SchemaBuilder.int16(); + break; + case JsonSchema.INT32_TYPE_NAME: + builder = SchemaBuilder.int32(); + break; + case JsonSchema.INT64_TYPE_NAME: + builder = SchemaBuilder.int64(); + break; + case JsonSchema.FLOAT_TYPE_NAME: + builder = SchemaBuilder.float32(); + break; + case JsonSchema.DOUBLE_TYPE_NAME: + builder = SchemaBuilder.float64(); + break; + case JsonSchema.BYTES_TYPE_NAME: + builder = SchemaBuilder.bytes(); + break; + case JsonSchema.STRING_TYPE_NAME: + builder = SchemaBuilder.string(); + break; + case JsonSchema.ARRAY_TYPE_NAME: + JsonNode elemSchema = jsonSchema.get(JsonSchema.ARRAY_ITEMS_FIELD_NAME); + if (elemSchema == null || elemSchema.isNull()) + throw new DataException("Array schema did not specify the element type"); + builder = SchemaBuilder.array(asConnectSchema(elemSchema)); + break; + case JsonSchema.MAP_TYPE_NAME: + JsonNode keySchema = jsonSchema.get(JsonSchema.MAP_KEY_FIELD_NAME); + if (keySchema == null) + throw new DataException("Map schema did not specify the key type"); + JsonNode valueSchema = jsonSchema.get(JsonSchema.MAP_VALUE_FIELD_NAME); + if (valueSchema == null) + throw new DataException("Map schema did not specify the value type"); + builder = SchemaBuilder.map(asConnectSchema(keySchema), asConnectSchema(valueSchema)); + break; + case JsonSchema.STRUCT_TYPE_NAME: + builder = SchemaBuilder.struct(); + JsonNode fields = jsonSchema.get(JsonSchema.STRUCT_FIELDS_FIELD_NAME); + if (fields == null || !fields.isArray()) + throw new DataException("Struct schema's \"fields\" argument is not an array."); + for (JsonNode field : fields) { + JsonNode jsonFieldName = 
field.get(JsonSchema.STRUCT_FIELD_NAME_FIELD_NAME); + if (jsonFieldName == null || !jsonFieldName.isTextual()) + throw new DataException("Struct schema's field name not specified properly"); + builder.field(jsonFieldName.asText(), asConnectSchema(field)); + } + break; + default: + throw new DataException("Unknown schema type: " + schemaTypeNode.textValue()); + } + + + JsonNode schemaOptionalNode = jsonSchema.get(JsonSchema.SCHEMA_OPTIONAL_FIELD_NAME); + if (schemaOptionalNode != null && schemaOptionalNode.isBoolean() && schemaOptionalNode.booleanValue()) + builder.optional(); + else + builder.required(); + + JsonNode schemaNameNode = jsonSchema.get(JsonSchema.SCHEMA_NAME_FIELD_NAME); + if (schemaNameNode != null && schemaNameNode.isTextual()) + builder.name(schemaNameNode.textValue()); + + JsonNode schemaVersionNode = jsonSchema.get(JsonSchema.SCHEMA_VERSION_FIELD_NAME); + if (schemaVersionNode != null && schemaVersionNode.isIntegralNumber()) { + builder.version(schemaVersionNode.intValue()); + } + + JsonNode schemaDocNode = jsonSchema.get(JsonSchema.SCHEMA_DOC_FIELD_NAME); + if (schemaDocNode != null && schemaDocNode.isTextual()) + builder.doc(schemaDocNode.textValue()); + + JsonNode schemaParamsNode = jsonSchema.get(JsonSchema.SCHEMA_PARAMETERS_FIELD_NAME); + if (schemaParamsNode != null && schemaParamsNode.isObject()) { + Iterator> paramsIt = schemaParamsNode.fields(); + while (paramsIt.hasNext()) { + Map.Entry entry = paramsIt.next(); + JsonNode paramValue = entry.getValue(); + if (!paramValue.isTextual()) + throw new DataException("Schema parameters must have string values."); + builder.parameter(entry.getKey(), paramValue.textValue()); + } + } + + JsonNode schemaDefaultNode = jsonSchema.get(JsonSchema.SCHEMA_DEFAULT_FIELD_NAME); + if (schemaDefaultNode != null) + builder.defaultValue(convertToConnect(builder, schemaDefaultNode)); + + Schema result = builder.build(); + toConnectSchemaCache.put(jsonSchema, result); + return result; + } + + + /** + * Convert this object, in org.apache.kafka.connect.data format, into a JSON object with an envelope object + * containing schema and payload fields. + * @param schema the schema for the data + * @param value the value + * @return JsonNode-encoded version + */ + private JsonNode convertToJsonWithEnvelope(Schema schema, Object value) { + return new JsonSchema.Envelope(asJsonSchema(schema), convertToJson(schema, value)).toJsonNode(); + } + + private JsonNode convertToJsonWithoutEnvelope(Schema schema, Object value) { + return convertToJson(schema, value); + } + + /** + * Convert this object, in the org.apache.kafka.connect.data format, into a JSON object, returning both the schema + * and the converted object. 
+ */ + private JsonNode convertToJson(Schema schema, Object value) { + if (value == null) { + if (schema == null) // Any schema is valid and we don't have a default, so treat this as an optional schema + return null; + if (schema.defaultValue() != null) + return convertToJson(schema, schema.defaultValue()); + if (schema.isOptional()) + return JSON_NODE_FACTORY.nullNode(); + throw new DataException("Conversion error: null value for field that is required and has no default value"); + } + + if (schema != null && schema.name() != null) { + LogicalTypeConverter logicalConverter = LOGICAL_CONVERTERS.get(schema.name()); + if (logicalConverter != null) + return logicalConverter.toJson(schema, value, config); + } + + try { + final Schema.Type schemaType; + if (schema == null) { + schemaType = ConnectSchema.schemaType(value.getClass()); + if (schemaType == null) + throw new DataException("Java class " + value.getClass() + " does not have corresponding schema type."); + } else { + schemaType = schema.type(); + } + switch (schemaType) { + case INT8: + return JSON_NODE_FACTORY.numberNode((Byte) value); + case INT16: + return JSON_NODE_FACTORY.numberNode((Short) value); + case INT32: + return JSON_NODE_FACTORY.numberNode((Integer) value); + case INT64: + return JSON_NODE_FACTORY.numberNode((Long) value); + case FLOAT32: + return JSON_NODE_FACTORY.numberNode((Float) value); + case FLOAT64: + return JSON_NODE_FACTORY.numberNode((Double) value); + case BOOLEAN: + return JSON_NODE_FACTORY.booleanNode((Boolean) value); + case STRING: + CharSequence charSeq = (CharSequence) value; + return JSON_NODE_FACTORY.textNode(charSeq.toString()); + case BYTES: + if (value instanceof byte[]) + return JSON_NODE_FACTORY.binaryNode((byte[]) value); + else if (value instanceof ByteBuffer) + return JSON_NODE_FACTORY.binaryNode(((ByteBuffer) value).array()); + else + throw new DataException("Invalid type for bytes type: " + value.getClass()); + case ARRAY: { + Collection collection = (Collection) value; + ArrayNode list = JSON_NODE_FACTORY.arrayNode(); + for (Object elem : collection) { + Schema valueSchema = schema == null ? null : schema.valueSchema(); + JsonNode fieldValue = convertToJson(valueSchema, elem); + list.add(fieldValue); + } + return list; + } + case MAP: { + Map map = (Map) value; + // If true, using string keys and JSON object; if false, using non-string keys and Array-encoding + boolean objectMode; + if (schema == null) { + objectMode = true; + for (Map.Entry entry : map.entrySet()) { + if (!(entry.getKey() instanceof String)) { + objectMode = false; + break; + } + } + } else { + objectMode = schema.keySchema().type() == Schema.Type.STRING; + } + ObjectNode obj = null; + ArrayNode list = null; + if (objectMode) + obj = JSON_NODE_FACTORY.objectNode(); + else + list = JSON_NODE_FACTORY.arrayNode(); + for (Map.Entry entry : map.entrySet()) { + Schema keySchema = schema == null ? null : schema.keySchema(); + Schema valueSchema = schema == null ? null : schema.valueSchema(); + JsonNode mapKey = convertToJson(keySchema, entry.getKey()); + JsonNode mapValue = convertToJson(valueSchema, entry.getValue()); + + if (objectMode) + obj.set(mapKey.asText(), mapValue); + else + list.add(JSON_NODE_FACTORY.arrayNode().add(mapKey).add(mapValue)); + } + return objectMode ? 
obj : list; + } + case STRUCT: { + Struct struct = (Struct) value; + if (!struct.schema().equals(schema)) + throw new DataException("Mismatching schema."); + ObjectNode obj = JSON_NODE_FACTORY.objectNode(); + for (Field field : schema.fields()) { + obj.set(field.name(), convertToJson(field.schema(), struct.get(field))); + } + return obj; + } + } + + throw new DataException("Couldn't convert " + value + " to JSON."); + } catch (ClassCastException e) { + String schemaTypeStr = (schema != null) ? schema.type().toString() : "unknown schema"; + throw new DataException("Invalid type for " + schemaTypeStr + ": " + value.getClass()); + } + } + + + private static Object convertToConnect(Schema schema, JsonNode jsonValue) { + final Schema.Type schemaType; + if (schema != null) { + schemaType = schema.type(); + if (jsonValue == null || jsonValue.isNull()) { + if (schema.defaultValue() != null) + return schema.defaultValue(); // any logical type conversions should already have been applied + if (schema.isOptional()) + return null; + throw new DataException("Invalid null value for required " + schemaType + " field"); + } + } else { + switch (jsonValue.getNodeType()) { + case NULL: + case MISSING: + // Special case. With no schema + return null; + case BOOLEAN: + schemaType = Schema.Type.BOOLEAN; + break; + case NUMBER: + if (jsonValue.isIntegralNumber()) + schemaType = Schema.Type.INT64; + else + schemaType = Schema.Type.FLOAT64; + break; + case ARRAY: + schemaType = Schema.Type.ARRAY; + break; + case OBJECT: + schemaType = Schema.Type.MAP; + break; + case STRING: + schemaType = Schema.Type.STRING; + break; + + case BINARY: + case POJO: + default: + schemaType = null; + break; + } + } + + final JsonToConnectTypeConverter typeConverter = TO_CONNECT_CONVERTERS.get(schemaType); + if (typeConverter == null) + throw new DataException("Unknown schema type: " + schemaType); + + if (schema != null && schema.name() != null) { + LogicalTypeConverter logicalConverter = LOGICAL_CONVERTERS.get(schema.name()); + if (logicalConverter != null) + return logicalConverter.toConnect(schema, jsonValue); + } + + return typeConverter.convert(schema, jsonValue); + } + + private interface JsonToConnectTypeConverter { + Object convert(Schema schema, JsonNode value); + } + + private interface LogicalTypeConverter { + JsonNode toJson(Schema schema, Object value, JsonConverterConfig config); + Object toConnect(Schema schema, JsonNode value); + } +} diff --git a/connect/json/src/main/java/org/apache/kafka/connect/json/JsonConverterConfig.java b/connect/json/src/main/java/org/apache/kafka/connect/json/JsonConverterConfig.java new file mode 100644 index 0000000..efb4979 --- /dev/null +++ b/connect/json/src/main/java/org/apache/kafka/connect/json/JsonConverterConfig.java @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.json; + +import java.util.Locale; +import org.apache.kafka.common.config.ConfigDef; +import org.apache.kafka.common.config.ConfigDef.Importance; +import org.apache.kafka.common.config.ConfigDef.Type; +import org.apache.kafka.common.config.ConfigDef.Width; +import org.apache.kafka.connect.storage.ConverterConfig; + +import java.util.Map; + +/** + * Configuration options for {@link JsonConverter} instances. + */ +public class JsonConverterConfig extends ConverterConfig { + + public static final String SCHEMAS_ENABLE_CONFIG = "schemas.enable"; + public static final boolean SCHEMAS_ENABLE_DEFAULT = true; + private static final String SCHEMAS_ENABLE_DOC = "Include schemas within each of the serialized values and keys."; + private static final String SCHEMAS_ENABLE_DISPLAY = "Enable Schemas"; + + public static final String SCHEMAS_CACHE_SIZE_CONFIG = "schemas.cache.size"; + public static final int SCHEMAS_CACHE_SIZE_DEFAULT = 1000; + private static final String SCHEMAS_CACHE_SIZE_DOC = "The maximum number of schemas that can be cached in this converter instance."; + private static final String SCHEMAS_CACHE_SIZE_DISPLAY = "Schema Cache Size"; + + public static final String DECIMAL_FORMAT_CONFIG = "decimal.format"; + public static final String DECIMAL_FORMAT_DEFAULT = DecimalFormat.BASE64.name(); + private static final String DECIMAL_FORMAT_DOC = "Controls which format this converter will serialize decimals in." + + " This value is case insensitive and can be either 'BASE64' (default) or 'NUMERIC'"; + private static final String DECIMAL_FORMAT_DISPLAY = "Decimal Format"; + + private final static ConfigDef CONFIG; + + static { + String group = "Schemas"; + int orderInGroup = 0; + CONFIG = ConverterConfig.newConfigDef(); + CONFIG.define(SCHEMAS_ENABLE_CONFIG, Type.BOOLEAN, SCHEMAS_ENABLE_DEFAULT, Importance.HIGH, SCHEMAS_ENABLE_DOC, group, + orderInGroup++, Width.MEDIUM, SCHEMAS_ENABLE_DISPLAY); + CONFIG.define(SCHEMAS_CACHE_SIZE_CONFIG, Type.INT, SCHEMAS_CACHE_SIZE_DEFAULT, Importance.HIGH, SCHEMAS_CACHE_SIZE_DOC, group, + orderInGroup++, Width.MEDIUM, SCHEMAS_CACHE_SIZE_DISPLAY); + + group = "Serialization"; + orderInGroup = 0; + CONFIG.define( + DECIMAL_FORMAT_CONFIG, Type.STRING, DECIMAL_FORMAT_DEFAULT, + ConfigDef.CaseInsensitiveValidString.in( + DecimalFormat.BASE64.name(), + DecimalFormat.NUMERIC.name()), + Importance.LOW, DECIMAL_FORMAT_DOC, group, orderInGroup++, + Width.MEDIUM, DECIMAL_FORMAT_DISPLAY); + } + + public static ConfigDef configDef() { + return CONFIG; + } + + // cached config values + private final boolean schemasEnabled; + private final int schemaCacheSize; + private final DecimalFormat decimalFormat; + + public JsonConverterConfig(Map props) { + super(CONFIG, props); + this.schemasEnabled = getBoolean(SCHEMAS_ENABLE_CONFIG); + this.schemaCacheSize = getInt(SCHEMAS_CACHE_SIZE_CONFIG); + this.decimalFormat = DecimalFormat.valueOf(getString(DECIMAL_FORMAT_CONFIG).toUpperCase(Locale.ROOT)); + } + + /** + * Return whether schemas are enabled. + * + * @return true if enabled, or false otherwise + */ + public boolean schemasEnabled() { + return schemasEnabled; + } + + /** + * Get the cache size. + * + * @return the cache size + */ + public int schemaCacheSize() { + return schemaCacheSize; + } + + /** + * Get the serialization format for decimal types. 
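[Editor's note, not part of the patch] A hedged sketch of the decimal.format option defined above; the exact serialized strings are inferred from the converter code in this patch and may differ in edge cases:

import java.math.BigDecimal;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.Map;

import org.apache.kafka.connect.data.Decimal;
import org.apache.kafka.connect.json.JsonConverter;

public class DecimalFormatSketch {
    public static void main(String[] args) {
        Map<String, Object> config = new HashMap<>();
        config.put("schemas.enable", false);      // emit the bare payload, no envelope
        config.put("decimal.format", "NUMERIC");  // case insensitive; default is BASE64
        JsonConverter converter = new JsonConverter();
        converter.configure(config, false);

        byte[] numeric = converter.fromConnectData("example-topic", Decimal.schema(2), new BigDecimal("10.23"));
        // NUMERIC keeps the value readable as a JSON number, e.g. 10.23
        System.out.println(new String(numeric, StandardCharsets.UTF_8));

        config.put("decimal.format", "BASE64");
        converter.configure(config, false);
        byte[] base64 = converter.fromConnectData("example-topic", Decimal.schema(2), new BigDecimal("10.23"));
        // BASE64 serializes the unscaled bytes, e.g. "A/8=" for the unscaled value 1023
        System.out.println(new String(base64, StandardCharsets.UTF_8));
    }
}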
+ * + * @return the decimal serialization format + */ + public DecimalFormat decimalFormat() { + return decimalFormat; + } + +} diff --git a/connect/json/src/main/java/org/apache/kafka/connect/json/JsonDeserializer.java b/connect/json/src/main/java/org/apache/kafka/connect/json/JsonDeserializer.java new file mode 100644 index 0000000..2e6e821 --- /dev/null +++ b/connect/json/src/main/java/org/apache/kafka/connect/json/JsonDeserializer.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.json; + +import com.fasterxml.jackson.databind.DeserializationFeature; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.node.JsonNodeFactory; +import java.util.Collections; +import java.util.Set; +import org.apache.kafka.common.errors.SerializationException; +import org.apache.kafka.common.serialization.Deserializer; + +/** + * JSON deserializer for Jackson's JsonNode tree model. Using the tree model allows it to work with arbitrarily + * structured data without having associated Java classes. This deserializer also supports Connect schemas. + */ +public class JsonDeserializer implements Deserializer { + private final ObjectMapper objectMapper = new ObjectMapper(); + + /** + * Default constructor needed by Kafka + */ + public JsonDeserializer() { + this(Collections.emptySet(), JsonNodeFactory.withExactBigDecimals(true)); + } + + /** + * A constructor that additionally specifies some {@link DeserializationFeature} + * for the deserializer + * + * @param deserializationFeatures the specified deserialization features + * @param jsonNodeFactory the json node factory to use. + */ + JsonDeserializer( + final Set deserializationFeatures, + final JsonNodeFactory jsonNodeFactory + ) { + deserializationFeatures.forEach(objectMapper::enable); + objectMapper.setNodeFactory(jsonNodeFactory); + } + + @Override + public JsonNode deserialize(String topic, byte[] bytes) { + if (bytes == null) + return null; + + JsonNode data; + try { + data = objectMapper.readTree(bytes); + } catch (Exception e) { + throw new SerializationException(e); + } + + return data; + } +} diff --git a/connect/json/src/main/java/org/apache/kafka/connect/json/JsonSchema.java b/connect/json/src/main/java/org/apache/kafka/connect/json/JsonSchema.java new file mode 100644 index 0000000..e15d97d --- /dev/null +++ b/connect/json/src/main/java/org/apache/kafka/connect/json/JsonSchema.java @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.json; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.node.JsonNodeFactory; +import com.fasterxml.jackson.databind.node.ObjectNode; + +public class JsonSchema { + + static final String ENVELOPE_SCHEMA_FIELD_NAME = "schema"; + static final String ENVELOPE_PAYLOAD_FIELD_NAME = "payload"; + static final String SCHEMA_TYPE_FIELD_NAME = "type"; + static final String SCHEMA_OPTIONAL_FIELD_NAME = "optional"; + static final String SCHEMA_NAME_FIELD_NAME = "name"; + static final String SCHEMA_VERSION_FIELD_NAME = "version"; + static final String SCHEMA_DOC_FIELD_NAME = "doc"; + static final String SCHEMA_PARAMETERS_FIELD_NAME = "parameters"; + static final String SCHEMA_DEFAULT_FIELD_NAME = "default"; + static final String ARRAY_ITEMS_FIELD_NAME = "items"; + static final String MAP_KEY_FIELD_NAME = "keys"; + static final String MAP_VALUE_FIELD_NAME = "values"; + static final String STRUCT_FIELDS_FIELD_NAME = "fields"; + static final String STRUCT_FIELD_NAME_FIELD_NAME = "field"; + static final String BOOLEAN_TYPE_NAME = "boolean"; + static final ObjectNode BOOLEAN_SCHEMA = JsonNodeFactory.instance.objectNode().put(SCHEMA_TYPE_FIELD_NAME, BOOLEAN_TYPE_NAME); + static final String INT8_TYPE_NAME = "int8"; + static final ObjectNode INT8_SCHEMA = JsonNodeFactory.instance.objectNode().put(SCHEMA_TYPE_FIELD_NAME, INT8_TYPE_NAME); + static final String INT16_TYPE_NAME = "int16"; + static final ObjectNode INT16_SCHEMA = JsonNodeFactory.instance.objectNode().put(SCHEMA_TYPE_FIELD_NAME, INT16_TYPE_NAME); + static final String INT32_TYPE_NAME = "int32"; + static final ObjectNode INT32_SCHEMA = JsonNodeFactory.instance.objectNode().put(SCHEMA_TYPE_FIELD_NAME, INT32_TYPE_NAME); + static final String INT64_TYPE_NAME = "int64"; + static final ObjectNode INT64_SCHEMA = JsonNodeFactory.instance.objectNode().put(SCHEMA_TYPE_FIELD_NAME, INT64_TYPE_NAME); + static final String FLOAT_TYPE_NAME = "float"; + static final ObjectNode FLOAT_SCHEMA = JsonNodeFactory.instance.objectNode().put(SCHEMA_TYPE_FIELD_NAME, FLOAT_TYPE_NAME); + static final String DOUBLE_TYPE_NAME = "double"; + static final ObjectNode DOUBLE_SCHEMA = JsonNodeFactory.instance.objectNode().put(SCHEMA_TYPE_FIELD_NAME, DOUBLE_TYPE_NAME); + static final String BYTES_TYPE_NAME = "bytes"; + static final ObjectNode BYTES_SCHEMA = JsonNodeFactory.instance.objectNode().put(SCHEMA_TYPE_FIELD_NAME, BYTES_TYPE_NAME); + static final String STRING_TYPE_NAME = "string"; + static final ObjectNode STRING_SCHEMA = JsonNodeFactory.instance.objectNode().put(SCHEMA_TYPE_FIELD_NAME, STRING_TYPE_NAME); + static final String ARRAY_TYPE_NAME = "array"; + static final String MAP_TYPE_NAME = "map"; + static final String STRUCT_TYPE_NAME = "struct"; + + public static ObjectNode envelope(JsonNode schema, JsonNode payload) { + ObjectNode result = JsonNodeFactory.instance.objectNode(); + result.set(ENVELOPE_SCHEMA_FIELD_NAME, 
schema); + result.set(ENVELOPE_PAYLOAD_FIELD_NAME, payload); + return result; + } + + static class Envelope { + public JsonNode schema; + public JsonNode payload; + + public Envelope(JsonNode schema, JsonNode payload) { + this.schema = schema; + this.payload = payload; + } + + public ObjectNode toJsonNode() { + return envelope(schema, payload); + } + } +} diff --git a/connect/json/src/main/java/org/apache/kafka/connect/json/JsonSerializer.java b/connect/json/src/main/java/org/apache/kafka/connect/json/JsonSerializer.java new file mode 100644 index 0000000..0f2b62b --- /dev/null +++ b/connect/json/src/main/java/org/apache/kafka/connect/json/JsonSerializer.java @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.json; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.SerializationFeature; +import com.fasterxml.jackson.databind.node.JsonNodeFactory; +import org.apache.kafka.common.errors.SerializationException; +import org.apache.kafka.common.serialization.Serializer; + +import java.util.Collections; +import java.util.Set; + +/** + * Serialize Jackson JsonNode tree model objects to UTF-8 JSON. Using the tree model allows handling arbitrarily + * structured data without corresponding Java classes. This serializer also supports Connect schemas. + */ +public class JsonSerializer implements Serializer { + private final ObjectMapper objectMapper = new ObjectMapper(); + + /** + * Default constructor needed by Kafka + */ + public JsonSerializer() { + this(Collections.emptySet(), JsonNodeFactory.withExactBigDecimals(true)); + } + + /** + * A constructor that additionally specifies some {@link SerializationFeature} + * for the serializer + * + * @param serializationFeatures the specified serialization features + * @param jsonNodeFactory the json node factory to use. 
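[Editor's note, not part of the patch] A small round-trip sketch for the serializer/deserializer pair added here, using only the public no-argument constructors shown in this patch (names are illustrative):

import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.node.JsonNodeFactory;
import com.fasterxml.jackson.databind.node.ObjectNode;
import org.apache.kafka.connect.json.JsonDeserializer;
import org.apache.kafka.connect.json.JsonSerializer;

public class JsonNodeRoundTripSketch {
    public static void main(String[] args) {
        // Build an arbitrary tree-model value; no POJO class is needed
        ObjectNode node = JsonNodeFactory.instance.objectNode();
        node.put("name", "kafka");
        node.put("partitions", 3);

        JsonSerializer serializer = new JsonSerializer();
        byte[] bytes = serializer.serialize("example-topic", node);    // UTF-8 JSON bytes

        JsonDeserializer deserializer = new JsonDeserializer();
        JsonNode restored = deserializer.deserialize("example-topic", bytes);
        System.out.println(restored.get("partitions").intValue());     // 3
    }
}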
+ */ + JsonSerializer( + final Set serializationFeatures, + final JsonNodeFactory jsonNodeFactory + ) { + serializationFeatures.forEach(objectMapper::enable); + objectMapper.setNodeFactory(jsonNodeFactory); + } + + @Override + public byte[] serialize(String topic, JsonNode data) { + if (data == null) + return null; + + try { + return objectMapper.writeValueAsBytes(data); + } catch (Exception e) { + throw new SerializationException("Error serializing JSON message", e); + } + } +} diff --git a/connect/json/src/test/java/org/apache/kafka/connect/json/JsonConverterConfigTest.java b/connect/json/src/test/java/org/apache/kafka/connect/json/JsonConverterConfigTest.java new file mode 100644 index 0000000..efa1f60 --- /dev/null +++ b/connect/json/src/test/java/org/apache/kafka/connect/json/JsonConverterConfigTest.java @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.json; + +import org.apache.kafka.connect.storage.ConverterConfig; +import org.apache.kafka.connect.storage.ConverterType; +import org.junit.jupiter.api.Test; + +import java.util.HashMap; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class JsonConverterConfigTest { + + @Test + public void shouldBeCaseInsensitiveForDecimalFormatConfig() { + final Map configValues = new HashMap<>(); + configValues.put(ConverterConfig.TYPE_CONFIG, ConverterType.KEY.getName()); + configValues.put(JsonConverterConfig.DECIMAL_FORMAT_CONFIG, "NuMeRiC"); + + final JsonConverterConfig config = new JsonConverterConfig(configValues); + assertEquals(config.decimalFormat(), DecimalFormat.NUMERIC); + } + +} \ No newline at end of file diff --git a/connect/json/src/test/java/org/apache/kafka/connect/json/JsonConverterTest.java b/connect/json/src/test/java/org/apache/kafka/connect/json/JsonConverterTest.java new file mode 100644 index 0000000..4e4c53b --- /dev/null +++ b/connect/json/src/test/java/org/apache/kafka/connect/json/JsonConverterTest.java @@ -0,0 +1,931 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.json; + +import com.fasterxml.jackson.databind.DeserializationFeature; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.node.ArrayNode; +import com.fasterxml.jackson.databind.node.JsonNodeFactory; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.connect.data.Date; +import org.apache.kafka.connect.data.Decimal; +import org.apache.kafka.connect.data.Schema; +import org.apache.kafka.connect.data.SchemaAndValue; +import org.apache.kafka.connect.data.SchemaBuilder; +import org.apache.kafka.connect.data.Struct; +import org.apache.kafka.connect.data.Time; +import org.apache.kafka.connect.data.Timestamp; +import org.apache.kafka.connect.errors.DataException; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.io.File; +import java.io.IOException; +import java.math.BigDecimal; +import java.math.BigInteger; +import java.net.URISyntaxException; +import java.net.URL; +import java.nio.ByteBuffer; +import java.util.Arrays; +import java.util.Calendar; +import java.util.Collections; +import java.util.GregorianCalendar; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; +import java.util.TimeZone; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; + +public class JsonConverterTest { + private static final String TOPIC = "topic"; + + private final ObjectMapper objectMapper = new ObjectMapper() + .enable(DeserializationFeature.USE_BIG_DECIMAL_FOR_FLOATS) + .setNodeFactory(JsonNodeFactory.withExactBigDecimals(true)); + + private final JsonConverter converter = new JsonConverter(); + + @BeforeEach + public void setUp() { + converter.configure(Collections.emptyMap(), false); + } + + // Schema metadata + + @Test + public void testConnectSchemaMetadataTranslation() { + // this validates the non-type fields are translated and handled properly + assertEquals(new SchemaAndValue(Schema.BOOLEAN_SCHEMA, true), converter.toConnectData(TOPIC, "{ \"schema\": { \"type\": \"boolean\" }, \"payload\": true }".getBytes())); + assertEquals(new SchemaAndValue(Schema.OPTIONAL_BOOLEAN_SCHEMA, null), converter.toConnectData(TOPIC, "{ \"schema\": { \"type\": \"boolean\", \"optional\": true }, \"payload\": null }".getBytes())); + assertEquals(new SchemaAndValue(SchemaBuilder.bool().defaultValue(true).build(), true), + converter.toConnectData(TOPIC, "{ \"schema\": { \"type\": \"boolean\", \"default\": true }, \"payload\": null }".getBytes())); + assertEquals(new SchemaAndValue(SchemaBuilder.bool().required().name("bool").version(2).doc("the documentation").parameter("foo", "bar").build(), true), + converter.toConnectData(TOPIC, "{ \"schema\": { \"type\": \"boolean\", \"optional\": false, \"name\": \"bool\", \"version\": 2, \"doc\": \"the documentation\", \"parameters\": { \"foo\": \"bar\" }}, \"payload\": true }".getBytes())); + } + + // Schema types + + @Test + public void booleanToConnect() { + assertEquals(new 
SchemaAndValue(Schema.BOOLEAN_SCHEMA, true), converter.toConnectData(TOPIC, "{ \"schema\": { \"type\": \"boolean\" }, \"payload\": true }".getBytes())); + assertEquals(new SchemaAndValue(Schema.BOOLEAN_SCHEMA, false), converter.toConnectData(TOPIC, "{ \"schema\": { \"type\": \"boolean\" }, \"payload\": false }".getBytes())); + } + + @Test + public void byteToConnect() { + assertEquals(new SchemaAndValue(Schema.INT8_SCHEMA, (byte) 12), converter.toConnectData(TOPIC, "{ \"schema\": { \"type\": \"int8\" }, \"payload\": 12 }".getBytes())); + } + + @Test + public void shortToConnect() { + assertEquals(new SchemaAndValue(Schema.INT16_SCHEMA, (short) 12), converter.toConnectData(TOPIC, "{ \"schema\": { \"type\": \"int16\" }, \"payload\": 12 }".getBytes())); + } + + @Test + public void intToConnect() { + assertEquals(new SchemaAndValue(Schema.INT32_SCHEMA, 12), converter.toConnectData(TOPIC, "{ \"schema\": { \"type\": \"int32\" }, \"payload\": 12 }".getBytes())); + } + + @Test + public void longToConnect() { + assertEquals(new SchemaAndValue(Schema.INT64_SCHEMA, 12L), converter.toConnectData(TOPIC, "{ \"schema\": { \"type\": \"int64\" }, \"payload\": 12 }".getBytes())); + assertEquals(new SchemaAndValue(Schema.INT64_SCHEMA, 4398046511104L), converter.toConnectData(TOPIC, "{ \"schema\": { \"type\": \"int64\" }, \"payload\": 4398046511104 }".getBytes())); + } + + @Test + public void floatToConnect() { + assertEquals(new SchemaAndValue(Schema.FLOAT32_SCHEMA, 12.34f), converter.toConnectData(TOPIC, "{ \"schema\": { \"type\": \"float\" }, \"payload\": 12.34 }".getBytes())); + } + + @Test + public void doubleToConnect() { + assertEquals(new SchemaAndValue(Schema.FLOAT64_SCHEMA, 12.34), converter.toConnectData(TOPIC, "{ \"schema\": { \"type\": \"double\" }, \"payload\": 12.34 }".getBytes())); + } + + + @Test + public void bytesToConnect() { + ByteBuffer reference = ByteBuffer.wrap(Utils.utf8("test-string")); + String msg = "{ \"schema\": { \"type\": \"bytes\" }, \"payload\": \"dGVzdC1zdHJpbmc=\" }"; + SchemaAndValue schemaAndValue = converter.toConnectData(TOPIC, msg.getBytes()); + ByteBuffer converted = ByteBuffer.wrap((byte[]) schemaAndValue.value()); + assertEquals(reference, converted); + } + + @Test + public void stringToConnect() { + assertEquals(new SchemaAndValue(Schema.STRING_SCHEMA, "foo-bar-baz"), converter.toConnectData(TOPIC, "{ \"schema\": { \"type\": \"string\" }, \"payload\": \"foo-bar-baz\" }".getBytes())); + } + + @Test + public void arrayToConnect() { + byte[] arrayJson = "{ \"schema\": { \"type\": \"array\", \"items\": { \"type\" : \"int32\" } }, \"payload\": [1, 2, 3] }".getBytes(); + assertEquals(new SchemaAndValue(SchemaBuilder.array(Schema.INT32_SCHEMA).build(), Arrays.asList(1, 2, 3)), converter.toConnectData(TOPIC, arrayJson)); + } + + @Test + public void mapToConnectStringKeys() { + byte[] mapJson = "{ \"schema\": { \"type\": \"map\", \"keys\": { \"type\" : \"string\" }, \"values\": { \"type\" : \"int32\" } }, \"payload\": { \"key1\": 12, \"key2\": 15} }".getBytes(); + Map expected = new HashMap<>(); + expected.put("key1", 12); + expected.put("key2", 15); + assertEquals(new SchemaAndValue(SchemaBuilder.map(Schema.STRING_SCHEMA, Schema.INT32_SCHEMA).build(), expected), converter.toConnectData(TOPIC, mapJson)); + } + + @Test + public void mapToConnectNonStringKeys() { + byte[] mapJson = "{ \"schema\": { \"type\": \"map\", \"keys\": { \"type\" : \"int32\" }, \"values\": { \"type\" : \"int32\" } }, \"payload\": [ [1, 12], [2, 15] ] }".getBytes(); + Map expected = new HashMap<>(); + 
expected.put(1, 12); + expected.put(2, 15); + assertEquals(new SchemaAndValue(SchemaBuilder.map(Schema.INT32_SCHEMA, Schema.INT32_SCHEMA).build(), expected), converter.toConnectData(TOPIC, mapJson)); + } + + @Test + public void structToConnect() { + byte[] structJson = "{ \"schema\": { \"type\": \"struct\", \"fields\": [{ \"field\": \"field1\", \"type\": \"boolean\" }, { \"field\": \"field2\", \"type\": \"string\" }] }, \"payload\": { \"field1\": true, \"field2\": \"string\" } }".getBytes(); + Schema expectedSchema = SchemaBuilder.struct().field("field1", Schema.BOOLEAN_SCHEMA).field("field2", Schema.STRING_SCHEMA).build(); + Struct expected = new Struct(expectedSchema).put("field1", true).put("field2", "string"); + SchemaAndValue converted = converter.toConnectData(TOPIC, structJson); + assertEquals(new SchemaAndValue(expectedSchema, expected), converted); + } + + @Test + public void structWithOptionalFieldToConnect() { + byte[] structJson = "{ \"schema\": { \"type\": \"struct\", \"fields\": [{ \"field\":\"optional\", \"type\": \"string\", \"optional\": true }, { \"field\": \"required\", \"type\": \"string\" }] }, \"payload\": { \"required\": \"required\" } }".getBytes(); + Schema expectedSchema = SchemaBuilder.struct().field("optional", Schema.OPTIONAL_STRING_SCHEMA).field("required", Schema.STRING_SCHEMA).build(); + Struct expected = new Struct(expectedSchema).put("required", "required"); + SchemaAndValue converted = converter.toConnectData(TOPIC, structJson); + assertEquals(new SchemaAndValue(expectedSchema, expected), converted); + } + + @Test + public void nullToConnect() { + // When schemas are enabled, trying to decode a tombstone should be an empty envelope + // the behavior is the same as when the json is "{ "schema": null, "payload": null }" + // to keep compatibility with the record + SchemaAndValue converted = converter.toConnectData(TOPIC, null); + assertEquals(SchemaAndValue.NULL, converted); + } + + /** + * When schemas are disabled, empty data should be decoded to an empty envelope. + * This test verifies the case where `schemas.enable` configuration is set to false, and + * {@link JsonConverter} converts empty bytes to {@link SchemaAndValue#NULL}. + */ + @Test + public void emptyBytesToConnect() { + // This characterizes the messages with empty data when Json schemas is disabled + Map props = Collections.singletonMap("schemas.enable", false); + converter.configure(props, true); + SchemaAndValue converted = converter.toConnectData(TOPIC, "".getBytes()); + assertEquals(SchemaAndValue.NULL, converted); + } + + /** + * When schemas are disabled, fields are mapped to Connect maps. 
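+ * Values keep their JSON-inferred types: in this test the empty string value stays an empty string and the
+ * explicit JSON null becomes a null entry in the resulting map.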
+ */ + @Test + public void schemalessWithEmptyFieldValueToConnect() { + // This characterizes the messages with empty data when Json schemas is disabled + Map props = Collections.singletonMap("schemas.enable", false); + converter.configure(props, true); + String input = "{ \"a\": \"\", \"b\": null}"; + SchemaAndValue converted = converter.toConnectData(TOPIC, input.getBytes()); + Map expected = new HashMap<>(); + expected.put("a", ""); + expected.put("b", null); + assertEquals(new SchemaAndValue(null, expected), converted); + } + + @Test + public void nullSchemaPrimitiveToConnect() { + SchemaAndValue converted = converter.toConnectData(TOPIC, "{ \"schema\": null, \"payload\": null }".getBytes()); + assertEquals(SchemaAndValue.NULL, converted); + + converted = converter.toConnectData(TOPIC, "{ \"schema\": null, \"payload\": true }".getBytes()); + assertEquals(new SchemaAndValue(null, true), converted); + + // Integers: Connect has more data types, and JSON unfortunately mixes all number types. We try to preserve + // info as best we can, so we always use the largest integer and floating point numbers we can and have Jackson + // determine if it's an integer or not + converted = converter.toConnectData(TOPIC, "{ \"schema\": null, \"payload\": 12 }".getBytes()); + assertEquals(new SchemaAndValue(null, 12L), converted); + + converted = converter.toConnectData(TOPIC, "{ \"schema\": null, \"payload\": 12.24 }".getBytes()); + assertEquals(new SchemaAndValue(null, 12.24), converted); + + converted = converter.toConnectData(TOPIC, "{ \"schema\": null, \"payload\": \"a string\" }".getBytes()); + assertEquals(new SchemaAndValue(null, "a string"), converted); + + converted = converter.toConnectData(TOPIC, "{ \"schema\": null, \"payload\": [1, \"2\", 3] }".getBytes()); + assertEquals(new SchemaAndValue(null, Arrays.asList(1L, "2", 3L)), converted); + + converted = converter.toConnectData(TOPIC, "{ \"schema\": null, \"payload\": { \"field1\": 1, \"field2\": 2} }".getBytes()); + Map obj = new HashMap<>(); + obj.put("field1", 1L); + obj.put("field2", 2L); + assertEquals(new SchemaAndValue(null, obj), converted); + } + + @Test + public void decimalToConnect() { + Schema schema = Decimal.schema(2); + BigDecimal reference = new BigDecimal(new BigInteger("156"), 2); + // Payload is base64 encoded byte[]{0, -100}, which is the two's complement encoding of 156. 
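+ // Worked example: the unscaled value 156 is 0x009C, i.e. the bytes {0, -100} (the leading zero byte keeps
+ // the two's complement value positive); base64-encoding those bytes gives "AJw=", and applying the
+ // "scale" parameter of 2 yields 156 * 10^-2 = 1.56.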
+ String msg = "{ \"schema\": { \"type\": \"bytes\", \"name\": \"org.apache.kafka.connect.data.Decimal\", \"version\": 1, \"parameters\": { \"scale\": \"2\" } }, \"payload\": \"AJw=\" }"; + SchemaAndValue schemaAndValue = converter.toConnectData(TOPIC, msg.getBytes()); + BigDecimal converted = (BigDecimal) schemaAndValue.value(); + assertEquals(schema, schemaAndValue.schema()); + assertEquals(reference, converted); + } + + @Test + public void decimalToConnectOptional() { + Schema schema = Decimal.builder(2).optional().schema(); + String msg = "{ \"schema\": { \"type\": \"bytes\", \"name\": \"org.apache.kafka.connect.data.Decimal\", \"version\": 1, \"optional\": true, \"parameters\": { \"scale\": \"2\" } }, \"payload\": null }"; + SchemaAndValue schemaAndValue = converter.toConnectData(TOPIC, msg.getBytes()); + assertEquals(schema, schemaAndValue.schema()); + assertNull(schemaAndValue.value()); + } + + @Test + public void decimalToConnectWithDefaultValue() { + BigDecimal reference = new BigDecimal(new BigInteger("156"), 2); + Schema schema = Decimal.builder(2).defaultValue(reference).build(); + String msg = "{ \"schema\": { \"type\": \"bytes\", \"name\": \"org.apache.kafka.connect.data.Decimal\", \"version\": 1, \"default\": \"AJw=\", \"parameters\": { \"scale\": \"2\" } }, \"payload\": null }"; + SchemaAndValue schemaAndValue = converter.toConnectData(TOPIC, msg.getBytes()); + assertEquals(schema, schemaAndValue.schema()); + assertEquals(reference, schemaAndValue.value()); + } + + @Test + public void decimalToConnectOptionalWithDefaultValue() { + BigDecimal reference = new BigDecimal(new BigInteger("156"), 2); + Schema schema = Decimal.builder(2).optional().defaultValue(reference).build(); + String msg = "{ \"schema\": { \"type\": \"bytes\", \"name\": \"org.apache.kafka.connect.data.Decimal\", \"version\": 1, \"optional\": true, \"default\": \"AJw=\", \"parameters\": { \"scale\": \"2\" } }, \"payload\": null }"; + SchemaAndValue schemaAndValue = converter.toConnectData(TOPIC, msg.getBytes()); + assertEquals(schema, schemaAndValue.schema()); + assertEquals(reference, schemaAndValue.value()); + } + + @Test + public void numericDecimalToConnect() { + BigDecimal reference = new BigDecimal(new BigInteger("156"), 2); + Schema schema = Decimal.schema(2); + String msg = "{ \"schema\": { \"type\": \"bytes\", \"name\": \"org.apache.kafka.connect.data.Decimal\", \"version\": 1, \"parameters\": { \"scale\": \"2\" } }, \"payload\": 1.56 }"; + SchemaAndValue schemaAndValue = converter.toConnectData(TOPIC, msg.getBytes()); + assertEquals(schema, schemaAndValue.schema()); + assertEquals(reference, schemaAndValue.value()); + } + + @Test + public void numericDecimalWithTrailingZerosToConnect() { + BigDecimal reference = new BigDecimal(new BigInteger("15600"), 4); + Schema schema = Decimal.schema(4); + String msg = "{ \"schema\": { \"type\": \"bytes\", \"name\": \"org.apache.kafka.connect.data.Decimal\", \"version\": 1, \"parameters\": { \"scale\": \"4\" } }, \"payload\": 1.5600 }"; + SchemaAndValue schemaAndValue = converter.toConnectData(TOPIC, msg.getBytes()); + assertEquals(schema, schemaAndValue.schema()); + assertEquals(reference, schemaAndValue.value()); + } + + @Test + public void highPrecisionNumericDecimalToConnect() { + // this number is too big to be kept in a float64! 
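+ // A float64 carries only about 15-17 significant decimal digits, while 1.23456789123456789 has 18,
+ // so the assertions below only hold if the numeric payload is parsed into a BigDecimal directly
+ // rather than going through a double.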
+ BigDecimal reference = new BigDecimal("1.23456789123456789"); + Schema schema = Decimal.schema(17); + String msg = "{ \"schema\": { \"type\": \"bytes\", \"name\": \"org.apache.kafka.connect.data.Decimal\", \"version\": 1, \"parameters\": { \"scale\": \"17\" } }, \"payload\": 1.23456789123456789 }"; + SchemaAndValue schemaAndValue = converter.toConnectData(TOPIC, msg.getBytes()); + assertEquals(schema, schemaAndValue.schema()); + assertEquals(reference, schemaAndValue.value()); + } + + @Test + public void dateToConnect() { + Schema schema = Date.SCHEMA; + GregorianCalendar calendar = new GregorianCalendar(1970, Calendar.JANUARY, 1, 0, 0, 0); + calendar.setTimeZone(TimeZone.getTimeZone("UTC")); + calendar.add(Calendar.DATE, 10000); + java.util.Date reference = calendar.getTime(); + String msg = "{ \"schema\": { \"type\": \"int32\", \"name\": \"org.apache.kafka.connect.data.Date\", \"version\": 1 }, \"payload\": 10000 }"; + SchemaAndValue schemaAndValue = converter.toConnectData(TOPIC, msg.getBytes()); + java.util.Date converted = (java.util.Date) schemaAndValue.value(); + assertEquals(schema, schemaAndValue.schema()); + assertEquals(reference, converted); + } + + @Test + public void dateToConnectOptional() { + Schema schema = Date.builder().optional().schema(); + String msg = "{ \"schema\": { \"type\": \"int32\", \"name\": \"org.apache.kafka.connect.data.Date\", \"version\": 1, \"optional\": true }, \"payload\": null }"; + SchemaAndValue schemaAndValue = converter.toConnectData(TOPIC, msg.getBytes()); + assertEquals(schema, schemaAndValue.schema()); + assertNull(schemaAndValue.value()); + } + + @Test + public void dateToConnectWithDefaultValue() { + java.util.Date reference = new java.util.Date(0); + Schema schema = Date.builder().defaultValue(reference).schema(); + String msg = "{ \"schema\": { \"type\": \"int32\", \"name\": \"org.apache.kafka.connect.data.Date\", \"version\": 1, \"default\": 0 }, \"payload\": null }"; + SchemaAndValue schemaAndValue = converter.toConnectData(TOPIC, msg.getBytes()); + assertEquals(schema, schemaAndValue.schema()); + assertEquals(reference, schemaAndValue.value()); + } + + @Test + public void dateToConnectOptionalWithDefaultValue() { + java.util.Date reference = new java.util.Date(0); + Schema schema = Date.builder().optional().defaultValue(reference).schema(); + String msg = "{ \"schema\": { \"type\": \"int32\", \"name\": \"org.apache.kafka.connect.data.Date\", \"version\": 1, \"optional\": true, \"default\": 0 }, \"payload\": null }"; + SchemaAndValue schemaAndValue = converter.toConnectData(TOPIC, msg.getBytes()); + assertEquals(schema, schemaAndValue.schema()); + assertEquals(reference, schemaAndValue.value()); + } + + @Test + public void timeToConnect() { + Schema schema = Time.SCHEMA; + GregorianCalendar calendar = new GregorianCalendar(1970, Calendar.JANUARY, 1, 0, 0, 0); + calendar.setTimeZone(TimeZone.getTimeZone("UTC")); + calendar.add(Calendar.MILLISECOND, 14400000); + java.util.Date reference = calendar.getTime(); + String msg = "{ \"schema\": { \"type\": \"int32\", \"name\": \"org.apache.kafka.connect.data.Time\", \"version\": 1 }, \"payload\": 14400000 }"; + SchemaAndValue schemaAndValue = converter.toConnectData(TOPIC, msg.getBytes()); + java.util.Date converted = (java.util.Date) schemaAndValue.value(); + assertEquals(schema, schemaAndValue.schema()); + assertEquals(reference, converted); + } + + @Test + public void timeToConnectOptional() { + Schema schema = Time.builder().optional().schema(); + String msg = "{ \"schema\": { \"type\": 
\"int32\", \"name\": \"org.apache.kafka.connect.data.Time\", \"version\": 1, \"optional\": true }, \"payload\": null }"; + SchemaAndValue schemaAndValue = converter.toConnectData(TOPIC, msg.getBytes()); + assertEquals(schema, schemaAndValue.schema()); + assertNull(schemaAndValue.value()); + } + + @Test + public void timeToConnectWithDefaultValue() { + java.util.Date reference = new java.util.Date(0); + Schema schema = Time.builder().defaultValue(reference).schema(); + String msg = "{ \"schema\": { \"type\": \"int32\", \"name\": \"org.apache.kafka.connect.data.Time\", \"version\": 1, \"default\": 0 }, \"payload\": null }"; + SchemaAndValue schemaAndValue = converter.toConnectData(TOPIC, msg.getBytes()); + assertEquals(schema, schemaAndValue.schema()); + assertEquals(reference, schemaAndValue.value()); + } + + @Test + public void timeToConnectOptionalWithDefaultValue() { + java.util.Date reference = new java.util.Date(0); + Schema schema = Time.builder().optional().defaultValue(reference).schema(); + String msg = "{ \"schema\": { \"type\": \"int32\", \"name\": \"org.apache.kafka.connect.data.Time\", \"version\": 1, \"optional\": true, \"default\": 0 }, \"payload\": null }"; + SchemaAndValue schemaAndValue = converter.toConnectData(TOPIC, msg.getBytes()); + assertEquals(schema, schemaAndValue.schema()); + assertEquals(reference, schemaAndValue.value()); + } + + @Test + public void timestampToConnect() { + Schema schema = Timestamp.SCHEMA; + GregorianCalendar calendar = new GregorianCalendar(1970, Calendar.JANUARY, 1, 0, 0, 0); + calendar.setTimeZone(TimeZone.getTimeZone("UTC")); + calendar.add(Calendar.MILLISECOND, 2000000000); + calendar.add(Calendar.MILLISECOND, 2000000000); + java.util.Date reference = calendar.getTime(); + String msg = "{ \"schema\": { \"type\": \"int64\", \"name\": \"org.apache.kafka.connect.data.Timestamp\", \"version\": 1 }, \"payload\": 4000000000 }"; + SchemaAndValue schemaAndValue = converter.toConnectData(TOPIC, msg.getBytes()); + java.util.Date converted = (java.util.Date) schemaAndValue.value(); + assertEquals(schema, schemaAndValue.schema()); + assertEquals(reference, converted); + } + + @Test + public void timestampToConnectOptional() { + Schema schema = Timestamp.builder().optional().schema(); + String msg = "{ \"schema\": { \"type\": \"int64\", \"name\": \"org.apache.kafka.connect.data.Timestamp\", \"version\": 1, \"optional\": true }, \"payload\": null }"; + SchemaAndValue schemaAndValue = converter.toConnectData(TOPIC, msg.getBytes()); + assertEquals(schema, schemaAndValue.schema()); + assertNull(schemaAndValue.value()); + } + + @Test + public void timestampToConnectWithDefaultValue() { + Schema schema = Timestamp.builder().defaultValue(new java.util.Date(42)).schema(); + String msg = "{ \"schema\": { \"type\": \"int64\", \"name\": \"org.apache.kafka.connect.data.Timestamp\", \"version\": 1, \"default\": 42 }, \"payload\": null }"; + SchemaAndValue schemaAndValue = converter.toConnectData(TOPIC, msg.getBytes()); + assertEquals(schema, schemaAndValue.schema()); + assertEquals(new java.util.Date(42), schemaAndValue.value()); + } + + @Test + public void timestampToConnectOptionalWithDefaultValue() { + Schema schema = Timestamp.builder().optional().defaultValue(new java.util.Date(42)).schema(); + String msg = "{ \"schema\": { \"type\": \"int64\", \"name\": \"org.apache.kafka.connect.data.Timestamp\", \"version\": 1, \"optional\": true, \"default\": 42 }, \"payload\": null }"; + SchemaAndValue schemaAndValue = converter.toConnectData(TOPIC, msg.getBytes()); + 
assertEquals(schema, schemaAndValue.schema()); + assertEquals(new java.util.Date(42), schemaAndValue.value()); + } + + // Schema metadata + + @Test + public void testJsonSchemaMetadataTranslation() { + JsonNode converted = parse(converter.fromConnectData(TOPIC, Schema.BOOLEAN_SCHEMA, true)); + validateEnvelope(converted); + assertEquals(parse("{ \"type\": \"boolean\", \"optional\": false }"), converted.get(JsonSchema.ENVELOPE_SCHEMA_FIELD_NAME)); + assertTrue(converted.get(JsonSchema.ENVELOPE_PAYLOAD_FIELD_NAME).booleanValue()); + + converted = parse(converter.fromConnectData(TOPIC, Schema.OPTIONAL_BOOLEAN_SCHEMA, null)); + validateEnvelope(converted); + assertEquals(parse("{ \"type\": \"boolean\", \"optional\": true }"), converted.get(JsonSchema.ENVELOPE_SCHEMA_FIELD_NAME)); + assertTrue(converted.get(JsonSchema.ENVELOPE_PAYLOAD_FIELD_NAME).isNull()); + + converted = parse(converter.fromConnectData(TOPIC, SchemaBuilder.bool().defaultValue(true).build(), true)); + validateEnvelope(converted); + assertEquals(parse("{ \"type\": \"boolean\", \"optional\": false, \"default\": true }"), converted.get(JsonSchema.ENVELOPE_SCHEMA_FIELD_NAME)); + assertTrue(converted.get(JsonSchema.ENVELOPE_PAYLOAD_FIELD_NAME).booleanValue()); + + converted = parse(converter.fromConnectData(TOPIC, SchemaBuilder.bool().required().name("bool").version(3).doc("the documentation").parameter("foo", "bar").build(), true)); + validateEnvelope(converted); + assertEquals(parse("{ \"type\": \"boolean\", \"optional\": false, \"name\": \"bool\", \"version\": 3, \"doc\": \"the documentation\", \"parameters\": { \"foo\": \"bar\" }}"), + converted.get(JsonSchema.ENVELOPE_SCHEMA_FIELD_NAME)); + assertTrue(converted.get(JsonSchema.ENVELOPE_PAYLOAD_FIELD_NAME).booleanValue()); + } + + + @Test + public void testCacheSchemaToConnectConversion() { + assertEquals(0, converter.sizeOfToConnectSchemaCache()); + + converter.toConnectData(TOPIC, "{ \"schema\": { \"type\": \"boolean\" }, \"payload\": true }".getBytes()); + assertEquals(1, converter.sizeOfToConnectSchemaCache()); + + converter.toConnectData(TOPIC, "{ \"schema\": { \"type\": \"boolean\" }, \"payload\": true }".getBytes()); + assertEquals(1, converter.sizeOfToConnectSchemaCache()); + + // Different schema should also get cached + converter.toConnectData(TOPIC, "{ \"schema\": { \"type\": \"boolean\", \"optional\": true }, \"payload\": true }".getBytes()); + assertEquals(2, converter.sizeOfToConnectSchemaCache()); + + // Even equivalent, but different JSON encoding of schema, should get different cache entry + converter.toConnectData(TOPIC, "{ \"schema\": { \"type\": \"boolean\", \"optional\": false }, \"payload\": true }".getBytes()); + assertEquals(3, converter.sizeOfToConnectSchemaCache()); + } + + // Schema types + + @Test + public void booleanToJson() { + JsonNode converted = parse(converter.fromConnectData(TOPIC, Schema.BOOLEAN_SCHEMA, true)); + validateEnvelope(converted); + assertEquals(parse("{ \"type\": \"boolean\", \"optional\": false }"), converted.get(JsonSchema.ENVELOPE_SCHEMA_FIELD_NAME)); + assertTrue(converted.get(JsonSchema.ENVELOPE_PAYLOAD_FIELD_NAME).booleanValue()); + } + + @Test + public void byteToJson() { + JsonNode converted = parse(converter.fromConnectData(TOPIC, Schema.INT8_SCHEMA, (byte) 12)); + validateEnvelope(converted); + assertEquals(parse("{ \"type\": \"int8\", \"optional\": false }"), converted.get(JsonSchema.ENVELOPE_SCHEMA_FIELD_NAME)); + assertEquals(12, converted.get(JsonSchema.ENVELOPE_PAYLOAD_FIELD_NAME).intValue()); + } + + @Test + public 
void shortToJson() { + JsonNode converted = parse(converter.fromConnectData(TOPIC, Schema.INT16_SCHEMA, (short) 12)); + validateEnvelope(converted); + assertEquals(parse("{ \"type\": \"int16\", \"optional\": false }"), converted.get(JsonSchema.ENVELOPE_SCHEMA_FIELD_NAME)); + assertEquals(12, converted.get(JsonSchema.ENVELOPE_PAYLOAD_FIELD_NAME).intValue()); + } + + @Test + public void intToJson() { + JsonNode converted = parse(converter.fromConnectData(TOPIC, Schema.INT32_SCHEMA, 12)); + validateEnvelope(converted); + assertEquals(parse("{ \"type\": \"int32\", \"optional\": false }"), converted.get(JsonSchema.ENVELOPE_SCHEMA_FIELD_NAME)); + assertEquals(12, converted.get(JsonSchema.ENVELOPE_PAYLOAD_FIELD_NAME).intValue()); + } + + @Test + public void longToJson() { + JsonNode converted = parse(converter.fromConnectData(TOPIC, Schema.INT64_SCHEMA, 4398046511104L)); + validateEnvelope(converted); + assertEquals(parse("{ \"type\": \"int64\", \"optional\": false }"), converted.get(JsonSchema.ENVELOPE_SCHEMA_FIELD_NAME)); + assertEquals(4398046511104L, converted.get(JsonSchema.ENVELOPE_PAYLOAD_FIELD_NAME).longValue()); + } + + @Test + public void floatToJson() { + JsonNode converted = parse(converter.fromConnectData(TOPIC, Schema.FLOAT32_SCHEMA, 12.34f)); + validateEnvelope(converted); + assertEquals(parse("{ \"type\": \"float\", \"optional\": false }"), converted.get(JsonSchema.ENVELOPE_SCHEMA_FIELD_NAME)); + assertEquals(12.34f, converted.get(JsonSchema.ENVELOPE_PAYLOAD_FIELD_NAME).floatValue(), 0.001); + } + + @Test + public void doubleToJson() { + JsonNode converted = parse(converter.fromConnectData(TOPIC, Schema.FLOAT64_SCHEMA, 12.34)); + validateEnvelope(converted); + assertEquals(parse("{ \"type\": \"double\", \"optional\": false }"), converted.get(JsonSchema.ENVELOPE_SCHEMA_FIELD_NAME)); + assertEquals(12.34, converted.get(JsonSchema.ENVELOPE_PAYLOAD_FIELD_NAME).doubleValue(), 0.001); + } + + @Test + public void bytesToJson() throws IOException { + JsonNode converted = parse(converter.fromConnectData(TOPIC, Schema.BYTES_SCHEMA, "test-string".getBytes())); + validateEnvelope(converted); + assertEquals(parse("{ \"type\": \"bytes\", \"optional\": false }"), converted.get(JsonSchema.ENVELOPE_SCHEMA_FIELD_NAME)); + assertEquals(ByteBuffer.wrap("test-string".getBytes()), + ByteBuffer.wrap(converted.get(JsonSchema.ENVELOPE_PAYLOAD_FIELD_NAME).binaryValue())); + } + + @Test + public void stringToJson() { + JsonNode converted = parse(converter.fromConnectData(TOPIC, Schema.STRING_SCHEMA, "test-string")); + validateEnvelope(converted); + assertEquals(parse("{ \"type\": \"string\", \"optional\": false }"), converted.get(JsonSchema.ENVELOPE_SCHEMA_FIELD_NAME)); + assertEquals("test-string", converted.get(JsonSchema.ENVELOPE_PAYLOAD_FIELD_NAME).textValue()); + } + + @Test + public void arrayToJson() { + Schema int32Array = SchemaBuilder.array(Schema.INT32_SCHEMA).build(); + JsonNode converted = parse(converter.fromConnectData(TOPIC, int32Array, Arrays.asList(1, 2, 3))); + validateEnvelope(converted); + assertEquals(parse("{ \"type\": \"array\", \"items\": { \"type\": \"int32\", \"optional\": false }, \"optional\": false }"), + converted.get(JsonSchema.ENVELOPE_SCHEMA_FIELD_NAME)); + assertEquals(JsonNodeFactory.instance.arrayNode().add(1).add(2).add(3), + converted.get(JsonSchema.ENVELOPE_PAYLOAD_FIELD_NAME)); + } + + @Test + public void mapToJsonStringKeys() { + Schema stringIntMap = SchemaBuilder.map(Schema.STRING_SCHEMA, Schema.INT32_SCHEMA).build(); + Map input = new HashMap<>(); + 
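+ // String-keyed maps serialize to a plain JSON object; contrast mapToJsonNonStringKeys below, where each
+ // entry becomes a [key, value] array because the keys cannot be used as JSON field names.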
input.put("key1", 12); + input.put("key2", 15); + JsonNode converted = parse(converter.fromConnectData(TOPIC, stringIntMap, input)); + validateEnvelope(converted); + assertEquals(parse("{ \"type\": \"map\", \"keys\": { \"type\" : \"string\", \"optional\": false }, \"values\": { \"type\" : \"int32\", \"optional\": false }, \"optional\": false }"), + converted.get(JsonSchema.ENVELOPE_SCHEMA_FIELD_NAME)); + assertEquals(JsonNodeFactory.instance.objectNode().put("key1", 12).put("key2", 15), + converted.get(JsonSchema.ENVELOPE_PAYLOAD_FIELD_NAME)); + } + + @Test + public void mapToJsonNonStringKeys() { + Schema intIntMap = SchemaBuilder.map(Schema.INT32_SCHEMA, Schema.INT32_SCHEMA).build(); + Map input = new HashMap<>(); + input.put(1, 12); + input.put(2, 15); + JsonNode converted = parse(converter.fromConnectData(TOPIC, intIntMap, input)); + validateEnvelope(converted); + assertEquals(parse("{ \"type\": \"map\", \"keys\": { \"type\" : \"int32\", \"optional\": false }, \"values\": { \"type\" : \"int32\", \"optional\": false }, \"optional\": false }"), + converted.get(JsonSchema.ENVELOPE_SCHEMA_FIELD_NAME)); + + assertTrue(converted.get(JsonSchema.ENVELOPE_PAYLOAD_FIELD_NAME).isArray()); + ArrayNode payload = (ArrayNode) converted.get(JsonSchema.ENVELOPE_PAYLOAD_FIELD_NAME); + assertEquals(2, payload.size()); + Set payloadEntries = new HashSet<>(); + for (JsonNode elem : payload) + payloadEntries.add(elem); + assertEquals(new HashSet<>(Arrays.asList(JsonNodeFactory.instance.arrayNode().add(1).add(12), + JsonNodeFactory.instance.arrayNode().add(2).add(15))), + payloadEntries + ); + } + + @Test + public void structToJson() { + Schema schema = SchemaBuilder.struct().field("field1", Schema.BOOLEAN_SCHEMA).field("field2", Schema.STRING_SCHEMA).field("field3", Schema.STRING_SCHEMA).field("field4", Schema.BOOLEAN_SCHEMA).build(); + Struct input = new Struct(schema).put("field1", true).put("field2", "string2").put("field3", "string3").put("field4", false); + JsonNode converted = parse(converter.fromConnectData(TOPIC, schema, input)); + validateEnvelope(converted); + assertEquals(parse("{ \"type\": \"struct\", \"optional\": false, \"fields\": [{ \"field\": \"field1\", \"type\": \"boolean\", \"optional\": false }, { \"field\": \"field2\", \"type\": \"string\", \"optional\": false }, { \"field\": \"field3\", \"type\": \"string\", \"optional\": false }, { \"field\": \"field4\", \"type\": \"boolean\", \"optional\": false }] }"), + converted.get(JsonSchema.ENVELOPE_SCHEMA_FIELD_NAME)); + assertEquals(JsonNodeFactory.instance.objectNode() + .put("field1", true) + .put("field2", "string2") + .put("field3", "string3") + .put("field4", false), + converted.get(JsonSchema.ENVELOPE_PAYLOAD_FIELD_NAME)); + } + + @Test + public void structSchemaIdentical() { + Schema schema = SchemaBuilder.struct().field("field1", Schema.BOOLEAN_SCHEMA) + .field("field2", Schema.STRING_SCHEMA) + .field("field3", Schema.STRING_SCHEMA) + .field("field4", Schema.BOOLEAN_SCHEMA).build(); + Schema inputSchema = SchemaBuilder.struct().field("field1", Schema.BOOLEAN_SCHEMA) + .field("field2", Schema.STRING_SCHEMA) + .field("field3", Schema.STRING_SCHEMA) + .field("field4", Schema.BOOLEAN_SCHEMA).build(); + Struct input = new Struct(inputSchema).put("field1", true).put("field2", "string2").put("field3", "string3").put("field4", false); + assertStructSchemaEqual(schema, input); + } + + + @Test + public void decimalToJson() throws IOException { + JsonNode converted = parse(converter.fromConnectData(TOPIC, Decimal.schema(2), new BigDecimal(new 
BigInteger("156"), 2))); + validateEnvelope(converted); + assertEquals(parse("{ \"type\": \"bytes\", \"optional\": false, \"name\": \"org.apache.kafka.connect.data.Decimal\", \"version\": 1, \"parameters\": { \"scale\": \"2\" } }"), + converted.get(JsonSchema.ENVELOPE_SCHEMA_FIELD_NAME)); + assertTrue(converted.get(JsonSchema.ENVELOPE_PAYLOAD_FIELD_NAME).isTextual(), "expected node to be base64 text"); + assertArrayEquals(new byte[]{0, -100}, converted.get(JsonSchema.ENVELOPE_PAYLOAD_FIELD_NAME).binaryValue()); + } + + @Test + public void decimalToNumericJson() { + converter.configure(Collections.singletonMap(JsonConverterConfig.DECIMAL_FORMAT_CONFIG, DecimalFormat.NUMERIC.name()), false); + JsonNode converted = parse(converter.fromConnectData(TOPIC, Decimal.schema(2), new BigDecimal(new BigInteger("156"), 2))); + validateEnvelope(converted); + assertEquals(parse("{ \"type\": \"bytes\", \"optional\": false, \"name\": \"org.apache.kafka.connect.data.Decimal\", \"version\": 1, \"parameters\": { \"scale\": \"2\" } }"), + converted.get(JsonSchema.ENVELOPE_SCHEMA_FIELD_NAME)); + assertTrue(converted.get(JsonSchema.ENVELOPE_PAYLOAD_FIELD_NAME).isNumber(), "expected node to be numeric"); + assertEquals(new BigDecimal("1.56"), converted.get(JsonSchema.ENVELOPE_PAYLOAD_FIELD_NAME).decimalValue()); + } + + @Test + public void decimalWithTrailingZerosToNumericJson() { + converter.configure(Collections.singletonMap(JsonConverterConfig.DECIMAL_FORMAT_CONFIG, DecimalFormat.NUMERIC.name()), false); + JsonNode converted = parse(converter.fromConnectData(TOPIC, Decimal.schema(4), new BigDecimal(new BigInteger("15600"), 4))); + validateEnvelope(converted); + assertEquals(parse("{ \"type\": \"bytes\", \"optional\": false, \"name\": \"org.apache.kafka.connect.data.Decimal\", \"version\": 1, \"parameters\": { \"scale\": \"4\" } }"), + converted.get(JsonSchema.ENVELOPE_SCHEMA_FIELD_NAME)); + assertTrue(converted.get(JsonSchema.ENVELOPE_PAYLOAD_FIELD_NAME).isNumber(), "expected node to be numeric"); + assertEquals(new BigDecimal("1.5600"), converted.get(JsonSchema.ENVELOPE_PAYLOAD_FIELD_NAME).decimalValue()); + } + + @Test + public void decimalToJsonWithoutSchema() { + assertThrows( + DataException.class, + () -> converter.fromConnectData(TOPIC, null, new BigDecimal(new BigInteger("156"), 2)), + "expected data exception when serializing BigDecimal without schema"); + } + + @Test + public void dateToJson() { + GregorianCalendar calendar = new GregorianCalendar(1970, Calendar.JANUARY, 1, 0, 0, 0); + calendar.setTimeZone(TimeZone.getTimeZone("UTC")); + calendar.add(Calendar.DATE, 10000); + java.util.Date date = calendar.getTime(); + + JsonNode converted = parse(converter.fromConnectData(TOPIC, Date.SCHEMA, date)); + validateEnvelope(converted); + assertEquals(parse("{ \"type\": \"int32\", \"optional\": false, \"name\": \"org.apache.kafka.connect.data.Date\", \"version\": 1 }"), + converted.get(JsonSchema.ENVELOPE_SCHEMA_FIELD_NAME)); + JsonNode payload = converted.get(JsonSchema.ENVELOPE_PAYLOAD_FIELD_NAME); + assertTrue(payload.isInt()); + assertEquals(10000, payload.intValue()); + } + + @Test + public void timeToJson() { + GregorianCalendar calendar = new GregorianCalendar(1970, Calendar.JANUARY, 1, 0, 0, 0); + calendar.setTimeZone(TimeZone.getTimeZone("UTC")); + calendar.add(Calendar.MILLISECOND, 14400000); + java.util.Date date = calendar.getTime(); + + JsonNode converted = parse(converter.fromConnectData(TOPIC, Time.SCHEMA, date)); + validateEnvelope(converted); + assertEquals(parse("{ \"type\": \"int32\", 
\"optional\": false, \"name\": \"org.apache.kafka.connect.data.Time\", \"version\": 1 }"), + converted.get(JsonSchema.ENVELOPE_SCHEMA_FIELD_NAME)); + JsonNode payload = converted.get(JsonSchema.ENVELOPE_PAYLOAD_FIELD_NAME); + assertTrue(payload.isInt()); + assertEquals(14400000, payload.longValue()); + } + + @Test + public void timestampToJson() { + GregorianCalendar calendar = new GregorianCalendar(1970, Calendar.JANUARY, 1, 0, 0, 0); + calendar.setTimeZone(TimeZone.getTimeZone("UTC")); + calendar.add(Calendar.MILLISECOND, 2000000000); + calendar.add(Calendar.MILLISECOND, 2000000000); + java.util.Date date = calendar.getTime(); + + JsonNode converted = parse(converter.fromConnectData(TOPIC, Timestamp.SCHEMA, date)); + validateEnvelope(converted); + assertEquals(parse("{ \"type\": \"int64\", \"optional\": false, \"name\": \"org.apache.kafka.connect.data.Timestamp\", \"version\": 1 }"), + converted.get(JsonSchema.ENVELOPE_SCHEMA_FIELD_NAME)); + JsonNode payload = converted.get(JsonSchema.ENVELOPE_PAYLOAD_FIELD_NAME); + assertTrue(payload.isLong()); + assertEquals(4000000000L, payload.longValue()); + } + + + @Test + public void nullSchemaAndPrimitiveToJson() { + // This still needs to do conversion of data, null schema means "anything goes" + JsonNode converted = parse(converter.fromConnectData(TOPIC, null, true)); + validateEnvelopeNullSchema(converted); + assertTrue(converted.get(JsonSchema.ENVELOPE_SCHEMA_FIELD_NAME).isNull()); + assertTrue(converted.get(JsonSchema.ENVELOPE_PAYLOAD_FIELD_NAME).booleanValue()); + } + + @Test + public void nullSchemaAndArrayToJson() { + // This still needs to do conversion of data, null schema means "anything goes". Make sure we mix and match + // types to verify conversion still works. + JsonNode converted = parse(converter.fromConnectData(TOPIC, null, Arrays.asList(1, "string", true))); + validateEnvelopeNullSchema(converted); + assertTrue(converted.get(JsonSchema.ENVELOPE_SCHEMA_FIELD_NAME).isNull()); + assertEquals(JsonNodeFactory.instance.arrayNode().add(1).add("string").add(true), + converted.get(JsonSchema.ENVELOPE_PAYLOAD_FIELD_NAME)); + } + + @Test + public void nullSchemaAndMapToJson() { + // This still needs to do conversion of data, null schema means "anything goes". Make sure we mix and match + // types to verify conversion still works. + Map input = new HashMap<>(); + input.put("key1", 12); + input.put("key2", "string"); + input.put("key3", true); + JsonNode converted = parse(converter.fromConnectData(TOPIC, null, input)); + validateEnvelopeNullSchema(converted); + assertTrue(converted.get(JsonSchema.ENVELOPE_SCHEMA_FIELD_NAME).isNull()); + assertEquals(JsonNodeFactory.instance.objectNode().put("key1", 12).put("key2", "string").put("key3", true), + converted.get(JsonSchema.ENVELOPE_PAYLOAD_FIELD_NAME)); + } + + @Test + public void nullSchemaAndMapNonStringKeysToJson() { + // This still needs to do conversion of data, null schema means "anything goes". Make sure we mix and match + // types to verify conversion still works. 
+ Map input = new HashMap<>(); + input.put("string", 12); + input.put(52, "string"); + input.put(false, true); + JsonNode converted = parse(converter.fromConnectData(TOPIC, null, input)); + validateEnvelopeNullSchema(converted); + assertTrue(converted.get(JsonSchema.ENVELOPE_SCHEMA_FIELD_NAME).isNull()); + assertTrue(converted.get(JsonSchema.ENVELOPE_PAYLOAD_FIELD_NAME).isArray()); + ArrayNode payload = (ArrayNode) converted.get(JsonSchema.ENVELOPE_PAYLOAD_FIELD_NAME); + assertEquals(3, payload.size()); + Set payloadEntries = new HashSet<>(); + for (JsonNode elem : payload) + payloadEntries.add(elem); + assertEquals(new HashSet<>(Arrays.asList(JsonNodeFactory.instance.arrayNode().add("string").add(12), + JsonNodeFactory.instance.arrayNode().add(52).add("string"), + JsonNodeFactory.instance.arrayNode().add(false).add(true))), + payloadEntries + ); + } + + @Test + public void nullSchemaAndNullValueToJson() { + // This characterizes the production of tombstone messages when Json schemas is enabled + Map props = Collections.singletonMap("schemas.enable", true); + converter.configure(props, true); + byte[] converted = converter.fromConnectData(TOPIC, null, null); + assertNull(converted); + } + + @Test + public void nullValueToJson() { + // This characterizes the production of tombstone messages when Json schemas is not enabled + Map props = Collections.singletonMap("schemas.enable", false); + converter.configure(props, true); + byte[] converted = converter.fromConnectData(TOPIC, null, null); + assertNull(converted); + } + + @Test + public void mismatchSchemaJson() { + // If we have mismatching schema info, we should properly convert to a DataException + assertThrows(DataException.class, + () -> converter.fromConnectData(TOPIC, Schema.FLOAT64_SCHEMA, true)); + } + + @Test + public void noSchemaToConnect() { + Map props = Collections.singletonMap("schemas.enable", false); + converter.configure(props, true); + assertEquals(new SchemaAndValue(null, true), converter.toConnectData(TOPIC, "true".getBytes())); + } + + @Test + public void noSchemaToJson() { + Map props = Collections.singletonMap("schemas.enable", false); + converter.configure(props, true); + JsonNode converted = parse(converter.fromConnectData(TOPIC, null, true)); + assertTrue(converted.isBoolean()); + assertTrue(converted.booleanValue()); + } + + @Test + public void testCacheSchemaToJsonConversion() { + assertEquals(0, converter.sizeOfFromConnectSchemaCache()); + + // Repeated conversion of the same schema, even if the schema object is different should return the same Java + // object + converter.fromConnectData(TOPIC, SchemaBuilder.bool().build(), true); + assertEquals(1, converter.sizeOfFromConnectSchemaCache()); + + converter.fromConnectData(TOPIC, SchemaBuilder.bool().build(), true); + assertEquals(1, converter.sizeOfFromConnectSchemaCache()); + + // Validate that a similar, but different schema correctly returns a different schema. + converter.fromConnectData(TOPIC, SchemaBuilder.bool().optional().build(), true); + assertEquals(2, converter.sizeOfFromConnectSchemaCache()); + } + + @Test + public void testJsonSchemaCacheSizeFromConfigFile() throws URISyntaxException, IOException { + URL url = getClass().getResource("/connect-test.properties"); + File propFile = new File(url.toURI()); + String workerPropsFile = propFile.getAbsolutePath(); + Map workerProps = !workerPropsFile.isEmpty() ? 
+ Utils.propsToStringMap(Utils.loadProps(workerPropsFile)) : Collections.emptyMap(); + + JsonConverter rc = new JsonConverter(); + rc.configure(workerProps, false); + } + + + // Note: the header conversion methods delegates to the data conversion methods, which are tested above. + // The following simply verify that the delegation works. + + @Test + public void testStringHeaderToJson() { + JsonNode converted = parse(converter.fromConnectHeader(TOPIC, "headerName", Schema.STRING_SCHEMA, "test-string")); + validateEnvelope(converted); + assertEquals(parse("{ \"type\": \"string\", \"optional\": false }"), converted.get(JsonSchema.ENVELOPE_SCHEMA_FIELD_NAME)); + assertEquals("test-string", converted.get(JsonSchema.ENVELOPE_PAYLOAD_FIELD_NAME).textValue()); + } + + @Test + public void stringHeaderToConnect() { + assertEquals(new SchemaAndValue(Schema.STRING_SCHEMA, "foo-bar-baz"), converter.toConnectHeader(TOPIC, "headerName", "{ \"schema\": { \"type\": \"string\" }, \"payload\": \"foo-bar-baz\" }".getBytes())); + } + + + private JsonNode parse(byte[] json) { + try { + return objectMapper.readTree(json); + } catch (IOException e) { + fail("IOException during JSON parse: " + e.getMessage()); + throw new RuntimeException("failed"); + } + } + + private JsonNode parse(String json) { + try { + return objectMapper.readTree(json); + } catch (IOException e) { + fail("IOException during JSON parse: " + e.getMessage()); + throw new RuntimeException("failed"); + } + } + + private void validateEnvelope(JsonNode env) { + assertNotNull(env); + assertTrue(env.isObject()); + assertEquals(2, env.size()); + assertTrue(env.has(JsonSchema.ENVELOPE_SCHEMA_FIELD_NAME)); + assertTrue(env.get(JsonSchema.ENVELOPE_SCHEMA_FIELD_NAME).isObject()); + assertTrue(env.has(JsonSchema.ENVELOPE_PAYLOAD_FIELD_NAME)); + } + + private void validateEnvelopeNullSchema(JsonNode env) { + assertNotNull(env); + assertTrue(env.isObject()); + assertEquals(2, env.size()); + assertTrue(env.has(JsonSchema.ENVELOPE_SCHEMA_FIELD_NAME)); + assertTrue(env.get(JsonSchema.ENVELOPE_SCHEMA_FIELD_NAME).isNull()); + assertTrue(env.has(JsonSchema.ENVELOPE_PAYLOAD_FIELD_NAME)); + } + + private void assertStructSchemaEqual(Schema schema, Struct struct) { + converter.fromConnectData(TOPIC, schema, struct); + assertEquals(schema, struct.schema()); + } +} diff --git a/connect/json/src/test/resources/connect-test.properties b/connect/json/src/test/resources/connect-test.properties new file mode 100644 index 0000000..9a48f68 --- /dev/null +++ b/connect/json/src/test/resources/connect-test.properties @@ -0,0 +1,17 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +schemas.cache.size=1 + diff --git a/connect/mirror-client/src/main/java/org/apache/kafka/connect/mirror/Checkpoint.java b/connect/mirror-client/src/main/java/org/apache/kafka/connect/mirror/Checkpoint.java new file mode 100644 index 0000000..8729005 --- /dev/null +++ b/connect/mirror-client/src/main/java/org/apache/kafka/connect/mirror/Checkpoint.java @@ -0,0 +1,184 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.mirror; + +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.protocol.types.Field; +import org.apache.kafka.common.protocol.types.Schema; +import org.apache.kafka.common.protocol.types.Struct; +import org.apache.kafka.common.protocol.types.Type; +import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.clients.consumer.OffsetAndMetadata; + +import java.util.Map; +import java.util.HashMap; +import java.nio.ByteBuffer; + +/** Checkpoint records emitted from MirrorCheckpointConnector. Encodes remote consumer group state. 
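+ * <p>
+ * For example, a record consumed from a checkpoints topic can be decoded and turned into a local
+ * consumer-group offset as follows:
+ * <pre>
+ *   Checkpoint checkpoint = Checkpoint.deserializeRecord(record);
+ *   OffsetAndMetadata localOffset = checkpoint.offsetAndMetadata();
+ * </pre>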
*/ +public class Checkpoint { + public static final String TOPIC_KEY = "topic"; + public static final String PARTITION_KEY = "partition"; + public static final String CONSUMER_GROUP_ID_KEY = "group"; + public static final String UPSTREAM_OFFSET_KEY = "upstreamOffset"; + public static final String DOWNSTREAM_OFFSET_KEY = "offset"; + public static final String METADATA_KEY = "metadata"; + public static final String VERSION_KEY = "version"; + public static final short VERSION = 0; + + public static final Schema VALUE_SCHEMA_V0 = new Schema( + new Field(UPSTREAM_OFFSET_KEY, Type.INT64), + new Field(DOWNSTREAM_OFFSET_KEY, Type.INT64), + new Field(METADATA_KEY, Type.STRING)); + + public static final Schema KEY_SCHEMA = new Schema( + new Field(CONSUMER_GROUP_ID_KEY, Type.STRING), + new Field(TOPIC_KEY, Type.STRING), + new Field(PARTITION_KEY, Type.INT32)); + + public static final Schema HEADER_SCHEMA = new Schema( + new Field(VERSION_KEY, Type.INT16)); + + private String consumerGroupId; + private TopicPartition topicPartition; + private long upstreamOffset; + private long downstreamOffset; + private String metadata; + + public Checkpoint(String consumerGroupId, TopicPartition topicPartition, long upstreamOffset, + long downstreamOffset, String metadata) { + this.consumerGroupId = consumerGroupId; + this.topicPartition = topicPartition; + this.upstreamOffset = upstreamOffset; + this.downstreamOffset = downstreamOffset; + this.metadata = metadata; + } + + public String consumerGroupId() { + return consumerGroupId; + } + + public TopicPartition topicPartition() { + return topicPartition; + } + + public long upstreamOffset() { + return upstreamOffset; + } + + public long downstreamOffset() { + return downstreamOffset; + } + + public String metadata() { + return metadata; + } + + public OffsetAndMetadata offsetAndMetadata() { + return new OffsetAndMetadata(downstreamOffset, metadata); + } + + @Override + public String toString() { + return String.format("Checkpoint{consumerGroupId=%s, topicPartition=%s, " + + "upstreamOffset=%d, downstreamOffset=%d, metatadata=%s}", + consumerGroupId, topicPartition, upstreamOffset, downstreamOffset, metadata); + } + + ByteBuffer serializeValue(short version) { + Struct header = headerStruct(version); + Schema valueSchema = valueSchema(version); + Struct valueStruct = valueStruct(valueSchema); + ByteBuffer buffer = ByteBuffer.allocate(HEADER_SCHEMA.sizeOf(header) + valueSchema.sizeOf(valueStruct)); + HEADER_SCHEMA.write(buffer, header); + valueSchema.write(buffer, valueStruct); + buffer.flip(); + return buffer; + } + + ByteBuffer serializeKey() { + Struct struct = keyStruct(); + ByteBuffer buffer = ByteBuffer.allocate(KEY_SCHEMA.sizeOf(struct)); + KEY_SCHEMA.write(buffer, struct); + buffer.flip(); + return buffer; + } + + public static Checkpoint deserializeRecord(ConsumerRecord record) { + ByteBuffer value = ByteBuffer.wrap(record.value()); + Struct header = HEADER_SCHEMA.read(value); + short version = header.getShort(VERSION_KEY); + Schema valueSchema = valueSchema(version); + Struct valueStruct = valueSchema.read(value); + long upstreamOffset = valueStruct.getLong(UPSTREAM_OFFSET_KEY); + long downstreamOffset = valueStruct.getLong(DOWNSTREAM_OFFSET_KEY); + String metadata = valueStruct.getString(METADATA_KEY); + Struct keyStruct = KEY_SCHEMA.read(ByteBuffer.wrap(record.key())); + String group = keyStruct.getString(CONSUMER_GROUP_ID_KEY); + String topic = keyStruct.getString(TOPIC_KEY); + int partition = keyStruct.getInt(PARTITION_KEY); + return new 
Checkpoint(group, new TopicPartition(topic, partition), upstreamOffset, + downstreamOffset, metadata); + } + + private static Schema valueSchema(short version) { + assert version == 0; + return VALUE_SCHEMA_V0; + } + + private Struct valueStruct(Schema schema) { + Struct struct = new Struct(schema); + struct.set(UPSTREAM_OFFSET_KEY, upstreamOffset); + struct.set(DOWNSTREAM_OFFSET_KEY, downstreamOffset); + struct.set(METADATA_KEY, metadata); + return struct; + } + + private Struct keyStruct() { + Struct struct = new Struct(KEY_SCHEMA); + struct.set(CONSUMER_GROUP_ID_KEY, consumerGroupId); + struct.set(TOPIC_KEY, topicPartition.topic()); + struct.set(PARTITION_KEY, topicPartition.partition()); + return struct; + } + + private Struct headerStruct(short version) { + Struct struct = new Struct(HEADER_SCHEMA); + struct.set(VERSION_KEY, version); + return struct; + } + + Map connectPartition() { + Map partition = new HashMap<>(); + partition.put(CONSUMER_GROUP_ID_KEY, consumerGroupId); + partition.put(TOPIC_KEY, topicPartition.topic()); + partition.put(PARTITION_KEY, topicPartition.partition()); + return partition; + } + + static String unwrapGroup(Map connectPartition) { + return connectPartition.get(CONSUMER_GROUP_ID_KEY).toString(); + } + + byte[] recordKey() { + return serializeKey().array(); + } + + byte[] recordValue() { + return serializeValue(VERSION).array(); + } +} + diff --git a/connect/mirror-client/src/main/java/org/apache/kafka/connect/mirror/DefaultReplicationPolicy.java b/connect/mirror-client/src/main/java/org/apache/kafka/connect/mirror/DefaultReplicationPolicy.java new file mode 100644 index 0000000..9de50b6 --- /dev/null +++ b/connect/mirror-client/src/main/java/org/apache/kafka/connect/mirror/DefaultReplicationPolicy.java @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.mirror; + +import org.apache.kafka.common.Configurable; + +import java.util.Map; +import java.util.regex.Pattern; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** Defines remote topics like "us-west.topic1". The separator is customizable and defaults to a period. */ +public class DefaultReplicationPolicy implements ReplicationPolicy, Configurable { + + private static final Logger log = LoggerFactory.getLogger(DefaultReplicationPolicy.class); + + // In order to work with various metrics stores, we allow custom separators. 
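+ // For example, topic "topic1" replicated from cluster "us-west" becomes "us-west.topic1" with the default
+ // separator, and "us-west_topic1" if the separator (SEPARATOR_CONFIG) is configured as "_".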
+ public static final String SEPARATOR_CONFIG = MirrorClientConfig.REPLICATION_POLICY_SEPARATOR; + public static final String SEPARATOR_DEFAULT = "."; + + private String separator = SEPARATOR_DEFAULT; + private Pattern separatorPattern = Pattern.compile(Pattern.quote(SEPARATOR_DEFAULT)); + + @Override + public void configure(Map props) { + if (props.containsKey(SEPARATOR_CONFIG)) { + separator = (String) props.get(SEPARATOR_CONFIG); + log.info("Using custom remote topic separator: '{}'", separator); + separatorPattern = Pattern.compile(Pattern.quote(separator)); + } + } + + @Override + public String formatRemoteTopic(String sourceClusterAlias, String topic) { + return sourceClusterAlias + separator + topic; + } + + @Override + public String topicSource(String topic) { + String[] parts = separatorPattern.split(topic); + if (parts.length < 2) { + // this is not a remote topic + return null; + } else { + return parts[0]; + } + } + + @Override + public String upstreamTopic(String topic) { + String source = topicSource(topic); + if (source == null) { + return null; + } else { + return topic.substring(source.length() + separator.length()); + } + } + + private String internalSuffix() { + return separator + "internal"; + } + + private String checkpointsTopicSuffix() { + return separator + "checkpoints" + internalSuffix(); + } + + @Override + public String offsetSyncsTopic(String clusterAlias) { + return "mm2-offset-syncs" + separator + clusterAlias + internalSuffix(); + } + + @Override + public String checkpointsTopic(String clusterAlias) { + return clusterAlias + checkpointsTopicSuffix(); + } + + @Override + public boolean isCheckpointsTopic(String topic) { + return topic.endsWith(checkpointsTopicSuffix()); + } + + @Override + public boolean isMM2InternalTopic(String topic) { + return topic.endsWith(internalSuffix()); + } +} diff --git a/connect/mirror-client/src/main/java/org/apache/kafka/connect/mirror/Heartbeat.java b/connect/mirror-client/src/main/java/org/apache/kafka/connect/mirror/Heartbeat.java new file mode 100644 index 0000000..5f70055 --- /dev/null +++ b/connect/mirror-client/src/main/java/org/apache/kafka/connect/mirror/Heartbeat.java @@ -0,0 +1,145 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.mirror; + +import org.apache.kafka.common.protocol.types.Field; +import org.apache.kafka.common.protocol.types.Schema; +import org.apache.kafka.common.protocol.types.Struct; +import org.apache.kafka.common.protocol.types.Type; +import org.apache.kafka.clients.consumer.ConsumerRecord; + +import java.util.Map; +import java.util.HashMap; +import java.nio.ByteBuffer; + +/** Heartbeat message sent from MirrorHeartbeatTask to target cluster. Heartbeats are always replicated. 
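+ * <p>
+ * The record key carries the source and target cluster aliases and the value carries an emission timestamp,
+ * which downstream consumers can use to monitor the replication flow, e.g.:
+ * <pre>
+ *   Heartbeat heartbeat = Heartbeat.deserializeRecord(record);
+ *   long emittedAt = heartbeat.timestamp();
+ * </pre>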
*/ +public class Heartbeat { + public static final String SOURCE_CLUSTER_ALIAS_KEY = "sourceClusterAlias"; + public static final String TARGET_CLUSTER_ALIAS_KEY = "targetClusterAlias"; + public static final String TIMESTAMP_KEY = "timestamp"; + public static final String VERSION_KEY = "version"; + public static final short VERSION = 0; + + public static final Schema VALUE_SCHEMA_V0 = new Schema( + new Field(TIMESTAMP_KEY, Type.INT64)); + + public static final Schema KEY_SCHEMA = new Schema( + new Field(SOURCE_CLUSTER_ALIAS_KEY, Type.STRING), + new Field(TARGET_CLUSTER_ALIAS_KEY, Type.STRING)); + + public static final Schema HEADER_SCHEMA = new Schema( + new Field(VERSION_KEY, Type.INT16)); + + private String sourceClusterAlias; + private String targetClusterAlias; + private long timestamp; + + public Heartbeat(String sourceClusterAlias, String targetClusterAlias, long timestamp) { + this.sourceClusterAlias = sourceClusterAlias; + this.targetClusterAlias = targetClusterAlias; + this.timestamp = timestamp; + } + + public String sourceClusterAlias() { + return sourceClusterAlias; + } + + public String targetClusterAlias() { + return targetClusterAlias; + } + + public long timestamp() { + return timestamp; + } + + @Override + public String toString() { + return String.format("Heartbeat{sourceClusterAlias=%s, targetClusterAlias=%s, timestamp=%d}", + sourceClusterAlias, targetClusterAlias, timestamp); + } + + ByteBuffer serializeValue(short version) { + Schema valueSchema = valueSchema(version); + Struct header = headerStruct(version); + Struct value = valueStruct(valueSchema); + ByteBuffer buffer = ByteBuffer.allocate(HEADER_SCHEMA.sizeOf(header) + valueSchema.sizeOf(value)); + HEADER_SCHEMA.write(buffer, header); + valueSchema.write(buffer, value); + buffer.flip(); + return buffer; + } + + ByteBuffer serializeKey() { + Struct struct = keyStruct(); + ByteBuffer buffer = ByteBuffer.allocate(KEY_SCHEMA.sizeOf(struct)); + KEY_SCHEMA.write(buffer, struct); + buffer.flip(); + return buffer; + } + + public static Heartbeat deserializeRecord(ConsumerRecord record) { + ByteBuffer value = ByteBuffer.wrap(record.value()); + Struct headerStruct = HEADER_SCHEMA.read(value); + short version = headerStruct.getShort(VERSION_KEY); + Struct valueStruct = valueSchema(version).read(value); + long timestamp = valueStruct.getLong(TIMESTAMP_KEY); + Struct keyStruct = KEY_SCHEMA.read(ByteBuffer.wrap(record.key())); + String sourceClusterAlias = keyStruct.getString(SOURCE_CLUSTER_ALIAS_KEY); + String targetClusterAlias = keyStruct.getString(TARGET_CLUSTER_ALIAS_KEY); + return new Heartbeat(sourceClusterAlias, targetClusterAlias, timestamp); + } + + private Struct headerStruct(short version) { + Struct struct = new Struct(HEADER_SCHEMA); + struct.set(VERSION_KEY, version); + return struct; + } + + private Struct valueStruct(Schema schema) { + Struct struct = new Struct(schema); + struct.set(TIMESTAMP_KEY, timestamp); + return struct; + } + + private Struct keyStruct() { + Struct struct = new Struct(KEY_SCHEMA); + struct.set(SOURCE_CLUSTER_ALIAS_KEY, sourceClusterAlias); + struct.set(TARGET_CLUSTER_ALIAS_KEY, targetClusterAlias); + return struct; + } + + Map connectPartition() { + Map partition = new HashMap<>(); + partition.put(SOURCE_CLUSTER_ALIAS_KEY, sourceClusterAlias); + partition.put(TARGET_CLUSTER_ALIAS_KEY, targetClusterAlias); + return partition; + } + + byte[] recordKey() { + return serializeKey().array(); + } + + byte[] recordValue() { + return serializeValue(VERSION).array(); + } + + private static Schema 
valueSchema(short version) { + assert version == 0; + return VALUE_SCHEMA_V0; + } +} + diff --git a/connect/mirror-client/src/main/java/org/apache/kafka/connect/mirror/IdentityReplicationPolicy.java b/connect/mirror-client/src/main/java/org/apache/kafka/connect/mirror/IdentityReplicationPolicy.java new file mode 100644 index 0000000..1f0df63 --- /dev/null +++ b/connect/mirror-client/src/main/java/org/apache/kafka/connect/mirror/IdentityReplicationPolicy.java @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.mirror; + +import java.util.Map; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** IdentityReplicationPolicy does not rename remote topics. This is useful for migrating + * from legacy MM1, or for any use-case involving one-way replication. + * + * N.B. MirrorMaker is not able to prevent cycles when using this class, so take care that + * your replication topology is acyclic. If migrating from MirrorMaker v1, this will likely + * already be the case. + */ +public class IdentityReplicationPolicy extends DefaultReplicationPolicy { + private static final Logger log = LoggerFactory.getLogger(IdentityReplicationPolicy.class); + + public static final String SOURCE_CLUSTER_ALIAS_CONFIG = "source.cluster.alias"; + + private String sourceClusterAlias = null; + + @Override + public void configure(Map props) { + super.configure(props); + if (props.containsKey(SOURCE_CLUSTER_ALIAS_CONFIG)) { + sourceClusterAlias = (String) props.get(SOURCE_CLUSTER_ALIAS_CONFIG); + log.info("Using source cluster alias `{}`.", sourceClusterAlias); + } + } + + /** Unlike DefaultReplicationPolicy, IdentityReplicationPolicy does not include the source + * cluster alias in the remote topic name. Instead, topic names are unchanged. + * + * In the special case of heartbeats, we defer to DefaultReplicationPolicy. + */ + @Override + public String formatRemoteTopic(String sourceClusterAlias, String topic) { + if (looksLikeHeartbeat(topic)) { + return super.formatRemoteTopic(sourceClusterAlias, topic); + } else { + return topic; + } + } + + /** Unlike DefaultReplicationPolicy, IdendityReplicationPolicy cannot know the source of + * a remote topic based on its name alone. If `source.cluster.alias` is provided, + * `topicSource` will return that. + * + * In the special case of heartbeats, we defer to DefaultReplicationPolicy. + */ + @Override + public String topicSource(String topic) { + if (looksLikeHeartbeat(topic)) { + return super.topicSource(topic); + } else { + return sourceClusterAlias; + } + } + + /** Since any topic may be a "remote topic", this just returns `topic`. + * + * In the special case of heartbeats, we defer to DefaultReplicationPolicy. 
+ */ + @Override + public String upstreamTopic(String topic) { + if (looksLikeHeartbeat(topic)) { + return super.upstreamTopic(topic); + } else { + return topic; + } + } + + private boolean looksLikeHeartbeat(String topic) { + return topic != null && topic.endsWith(heartbeatsTopic()); + } +} diff --git a/connect/mirror-client/src/main/java/org/apache/kafka/connect/mirror/MirrorClient.java b/connect/mirror-client/src/main/java/org/apache/kafka/connect/mirror/MirrorClient.java new file mode 100644 index 0000000..28790c1 --- /dev/null +++ b/connect/mirror-client/src/main/java/org/apache/kafka/connect/mirror/MirrorClient.java @@ -0,0 +1,249 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.connect.mirror; + +import org.apache.kafka.clients.admin.AdminClient; +import org.apache.kafka.clients.consumer.Consumer; +import org.apache.kafka.clients.consumer.KafkaConsumer; +import org.apache.kafka.common.serialization.ByteArrayDeserializer; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.KafkaException; +import org.apache.kafka.clients.consumer.OffsetAndMetadata; +import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.clients.consumer.ConsumerRecords; +import org.apache.kafka.common.protocol.types.SchemaException; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.time.Duration; +import java.util.Set; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.HashMap; +import java.util.Collections; +import java.util.Collection; +import java.util.stream.Collectors; +import java.util.concurrent.ExecutionException; + +/** Interprets MM2's internal topics (checkpoints, heartbeats) on a given cluster. + *

            + * Given a top-level "mm2.properties" configuration file, MirrorClients can be constructed + * for individual clusters as follows: + *

            + *
            + *    MirrorMakerConfig mmConfig = new MirrorMakerConfig(props);
            + *    MirrorClientConfig mmClientConfig = mmConfig.clientConfig("some-cluster");
+ *    MirrorClient mmClient = new MirrorClient(mmClientConfig);
            + *  
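+ * The resulting client can then be used to inspect MM2's internal state on that cluster,
+ * for example (a brief sketch; the "us-west" alias is a placeholder):
+ *
+ *    Set<String> heartbeats = mmClient.heartbeatTopics();
+ *    Set<String> upstreamClusters = mmClient.upstreamClusters();
+ *    int hops = mmClient.replicationHops("us-west");
+ *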
            + */ +public class MirrorClient implements AutoCloseable { + private static final Logger log = LoggerFactory.getLogger(MirrorClient.class); + + private AdminClient adminClient; + private ReplicationPolicy replicationPolicy; + private Map consumerConfig; + + public MirrorClient(Map props) { + this(new MirrorClientConfig(props)); + } + + public MirrorClient(MirrorClientConfig config) { + adminClient = AdminClient.create(config.adminConfig()); + consumerConfig = config.consumerConfig(); + replicationPolicy = config.replicationPolicy(); + } + + // for testing + MirrorClient(AdminClient adminClient, ReplicationPolicy replicationPolicy, + Map consumerConfig) { + this.adminClient = adminClient; + this.replicationPolicy = replicationPolicy; + this.consumerConfig = consumerConfig; + } + + /** Close internal clients. */ + public void close() { + adminClient.close(); + } + + /** Get the ReplicationPolicy instance used to interpret remote topics. This instance is constructed based on + * relevant configuration properties, including {@code replication.policy.class}. */ + public ReplicationPolicy replicationPolicy() { + return replicationPolicy; + } + + /** Compute shortest number of hops from an upstream source cluster. + * For example, given replication flow A->B->C, there are two hops from A to C. + * Returns -1 if upstream cluster is unreachable. + */ + public int replicationHops(String upstreamClusterAlias) throws InterruptedException { + return heartbeatTopics().stream() + .map(x -> countHopsForTopic(x, upstreamClusterAlias)) + .filter(x -> x != -1) + .mapToInt(x -> x) + .min() + .orElse(-1); + } + + /** Find all heartbeat topics on this cluster. Heartbeat topics are replicated from other clusters. */ + public Set heartbeatTopics() throws InterruptedException { + return listTopics().stream() + .filter(this::isHeartbeatTopic) + .collect(Collectors.toSet()); + } + + /** Find all checkpoint topics on this cluster. */ + public Set checkpointTopics() throws InterruptedException { + return listTopics().stream() + .filter(this::isCheckpointTopic) + .collect(Collectors.toSet()); + } + + /** Find upstream clusters, which may be multiple hops away, based on incoming heartbeats. */ + public Set upstreamClusters() throws InterruptedException { + return listTopics().stream() + .filter(this::isHeartbeatTopic) + .flatMap(x -> allSources(x).stream()) + .distinct() + .collect(Collectors.toSet()); + } + + /** Find all remote topics on this cluster. This does not include internal topics (heartbeats, checkpoints). */ + public Set remoteTopics() throws InterruptedException { + return listTopics().stream() + .filter(this::isRemoteTopic) + .collect(Collectors.toSet()); + } + + /** Find all remote topics that have been replicated directly from the given source cluster. */ + public Set remoteTopics(String source) throws InterruptedException { + return listTopics().stream() + .filter(this::isRemoteTopic) + .filter(x -> source.equals(replicationPolicy.topicSource(x))) + .distinct() + .collect(Collectors.toSet()); + } + + /** Translate a remote consumer group's offsets into corresponding local offsets. Topics are automatically + * renamed according to the ReplicationPolicy. 
+ * @param consumerGroupId group ID of remote consumer group + * @param remoteClusterAlias alias of remote cluster + * @param timeout timeout + */ + public Map remoteConsumerOffsets(String consumerGroupId, + String remoteClusterAlias, Duration timeout) { + long deadline = System.currentTimeMillis() + timeout.toMillis(); + Map offsets = new HashMap<>(); + KafkaConsumer consumer = new KafkaConsumer<>(consumerConfig, + new ByteArrayDeserializer(), new ByteArrayDeserializer()); + try { + // checkpoint topics are not "remote topics", as they are not replicated. So we don't need + // to use ReplicationPolicy to create the checkpoint topic here. + String checkpointTopic = replicationPolicy.checkpointsTopic(remoteClusterAlias); + List checkpointAssignment = + Collections.singletonList(new TopicPartition(checkpointTopic, 0)); + consumer.assign(checkpointAssignment); + consumer.seekToBeginning(checkpointAssignment); + while (System.currentTimeMillis() < deadline && !endOfStream(consumer, checkpointAssignment)) { + ConsumerRecords records = consumer.poll(timeout); + for (ConsumerRecord record : records) { + try { + Checkpoint checkpoint = Checkpoint.deserializeRecord(record); + if (checkpoint.consumerGroupId().equals(consumerGroupId)) { + offsets.put(checkpoint.topicPartition(), checkpoint.offsetAndMetadata()); + } + } catch (SchemaException e) { + log.info("Could not deserialize record. Skipping.", e); + } + } + } + log.info("Consumed {} checkpoint records for {} from {}.", offsets.size(), + consumerGroupId, checkpointTopic); + } finally { + consumer.close(); + } + return offsets; + } + + Set listTopics() throws InterruptedException { + try { + return adminClient.listTopics().names().get(); + } catch (ExecutionException e) { + throw new KafkaException(e.getCause()); + } + } + + int countHopsForTopic(String topic, String sourceClusterAlias) { + int hops = 0; + Set visited = new HashSet<>(); + while (true) { + hops++; + String source = replicationPolicy.topicSource(topic); + if (source == null) { + return -1; + } + if (source.equals(sourceClusterAlias)) { + return hops; + } + if (visited.contains(source)) { + // Extra check for IdentityReplicationPolicy and similar impls that cannot prevent cycles. + // We assume we're stuck in a cycle and will never find sourceClusterAlias. + return -1; + } + visited.add(source); + topic = replicationPolicy.upstreamTopic(topic); + } + } + + boolean isHeartbeatTopic(String topic) { + return replicationPolicy.isHeartbeatsTopic(topic); + } + + boolean isCheckpointTopic(String topic) { + return replicationPolicy.isCheckpointsTopic(topic); + } + + boolean isRemoteTopic(String topic) { + return !replicationPolicy.isInternalTopic(topic) + && replicationPolicy.topicSource(topic) != null; + } + + Set allSources(String topic) { + Set sources = new HashSet<>(); + String source = replicationPolicy.topicSource(topic); + while (source != null && !sources.contains(source)) { + // The extra Set.contains above is for ReplicationPolicies that cannot prevent cycles. 
+ sources.add(source); + topic = replicationPolicy.upstreamTopic(topic); + source = replicationPolicy.topicSource(topic); + } + return sources; + } + + static private boolean endOfStream(Consumer consumer, Collection assignments) { + Map endOffsets = consumer.endOffsets(assignments); + for (TopicPartition topicPartition : assignments) { + if (consumer.position(topicPartition) < endOffsets.get(topicPartition)) { + return false; + } + } + return true; + } +} diff --git a/connect/mirror-client/src/main/java/org/apache/kafka/connect/mirror/MirrorClientConfig.java b/connect/mirror-client/src/main/java/org/apache/kafka/connect/mirror/MirrorClientConfig.java new file mode 100644 index 0000000..4305366 --- /dev/null +++ b/connect/mirror-client/src/main/java/org/apache/kafka/connect/mirror/MirrorClientConfig.java @@ -0,0 +1,132 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.mirror; + +import org.apache.kafka.common.config.AbstractConfig; +import org.apache.kafka.common.config.ConfigDef; +import org.apache.kafka.common.config.ConfigDef.Type; +import org.apache.kafka.common.config.ConfigDef.Importance; +import org.apache.kafka.clients.CommonClientConfigs; + +import java.util.Map; +import java.util.HashMap; + +/** Configuration required for MirrorClient to talk to a given target cluster. + *

            + * Generally, these properties come from an mm2.properties configuration file + * (@see MirrorMakerConfig.clientConfig): + *

            + *
            + *    MirrorMakerConfig mmConfig = new MirrorMakerConfig(props);
            + *    MirrorClientConfig mmClientConfig = mmConfig.clientConfig("some-cluster");
            + *  
            + *

            + * In addition to the properties defined here, sub-configs are supported for Admin, Consumer, and Producer clients. + * For example: + *

            + *
            + *      bootstrap.servers = host1:9092
            + *      consumer.client.id = mm2-client
            + *      replication.policy.separator = __
            + *  
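+ * The prefixed sub-configs can then be retrieved through the corresponding accessors,
+ * for example (a brief sketch; the variable names are illustrative only):
+ *
+ *    MirrorClientConfig clientConfig = mmConfig.clientConfig("some-cluster");
+ *    Map<String, Object> adminProps = clientConfig.adminConfig();       // sub-config for Admin clients
+ *    Map<String, Object> consumerProps = clientConfig.consumerConfig(); // sub-config for Consumer clients
+ *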
            + */ +public class MirrorClientConfig extends AbstractConfig { + public static final String REPLICATION_POLICY_CLASS = "replication.policy.class"; + private static final String REPLICATION_POLICY_CLASS_DOC = "Class which defines the remote topic naming convention."; + public static final Class REPLICATION_POLICY_CLASS_DEFAULT = DefaultReplicationPolicy.class; + public static final String REPLICATION_POLICY_SEPARATOR = "replication.policy.separator"; + private static final String REPLICATION_POLICY_SEPARATOR_DOC = "Separator used in remote topic naming convention."; + public static final String REPLICATION_POLICY_SEPARATOR_DEFAULT = + DefaultReplicationPolicy.SEPARATOR_DEFAULT; + + public static final String ADMIN_CLIENT_PREFIX = "admin."; + public static final String CONSUMER_CLIENT_PREFIX = "consumer."; + public static final String PRODUCER_CLIENT_PREFIX = "producer."; + + MirrorClientConfig(Map props) { + super(CONFIG_DEF, props, true); + } + + public ReplicationPolicy replicationPolicy() { + return getConfiguredInstance(REPLICATION_POLICY_CLASS, ReplicationPolicy.class); + } + + /** Sub-config for Admin clients. */ + public Map adminConfig() { + return clientConfig(ADMIN_CLIENT_PREFIX); + } + + /** Sub-config for Consumer clients. */ + public Map consumerConfig() { + return clientConfig(CONSUMER_CLIENT_PREFIX); + } + + /** Sub-config for Producer clients. */ + public Map producerConfig() { + return clientConfig(PRODUCER_CLIENT_PREFIX); + } + + private Map clientConfig(String prefix) { + Map props = new HashMap<>(); + props.putAll(valuesWithPrefixOverride(prefix)); + props.keySet().retainAll(CLIENT_CONFIG_DEF.names()); + props.entrySet().removeIf(x -> x.getValue() == null); + return props; + } + + // Properties passed to internal Kafka clients + static final ConfigDef CLIENT_CONFIG_DEF = new ConfigDef() + .define(CommonClientConfigs.BOOTSTRAP_SERVERS_CONFIG, + Type.LIST, + null, + Importance.HIGH, + CommonClientConfigs.BOOTSTRAP_SERVERS_DOC) + // security support + .define(CommonClientConfigs.SECURITY_PROTOCOL_CONFIG, + Type.STRING, + CommonClientConfigs.DEFAULT_SECURITY_PROTOCOL, + Importance.MEDIUM, + CommonClientConfigs.SECURITY_PROTOCOL_DOC) + .withClientSslSupport() + .withClientSaslSupport(); + + static final ConfigDef CONFIG_DEF = new ConfigDef() + .define(CommonClientConfigs.BOOTSTRAP_SERVERS_CONFIG, + Type.STRING, + null, + Importance.HIGH, + CommonClientConfigs.BOOTSTRAP_SERVERS_DOC) + .define( + REPLICATION_POLICY_CLASS, + ConfigDef.Type.CLASS, + REPLICATION_POLICY_CLASS_DEFAULT, + ConfigDef.Importance.LOW, + REPLICATION_POLICY_CLASS_DOC) + .define( + REPLICATION_POLICY_SEPARATOR, + ConfigDef.Type.STRING, + REPLICATION_POLICY_SEPARATOR_DEFAULT, + ConfigDef.Importance.LOW, + REPLICATION_POLICY_SEPARATOR_DOC) + .define(CommonClientConfigs.SECURITY_PROTOCOL_CONFIG, + Type.STRING, + CommonClientConfigs.DEFAULT_SECURITY_PROTOCOL, + Importance.MEDIUM, + CommonClientConfigs.SECURITY_PROTOCOL_DOC) + .withClientSslSupport() + .withClientSaslSupport(); +} diff --git a/connect/mirror-client/src/main/java/org/apache/kafka/connect/mirror/RemoteClusterUtils.java b/connect/mirror-client/src/main/java/org/apache/kafka/connect/mirror/RemoteClusterUtils.java new file mode 100644 index 0000000..49da62d --- /dev/null +++ b/connect/mirror-client/src/main/java/org/apache/kafka/connect/mirror/RemoteClusterUtils.java @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.connect.mirror; + +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.clients.consumer.OffsetAndMetadata; + +import java.util.Map; +import java.util.Set; +import java.util.concurrent.TimeoutException; +import java.time.Duration; + + +/** Convenience methods for multi-cluster environments. Wraps MirrorClient (@see MirrorClient). + *

            + * Properties passed to these methods are used to construct internal Admin and Consumer clients. + * Sub-configs like "admin.xyz" are also supported. For example: + *

            + *
            + *      bootstrap.servers = host1:9092
            + *      consumer.client.id = mm2-client
            + *  
            + *
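+ * For example, to translate committed offsets for a consumer group replicated from a cluster
+ * aliased "primary" (a brief sketch; the property values and group name are placeholders):
+ *
+ *    Map<String, Object> props = new HashMap<>();
+ *    props.put("bootstrap.servers", "host2:9092");
+ *    Map<TopicPartition, OffsetAndMetadata> offsets =
+ *        RemoteClusterUtils.translateOffsets(props, "primary", "my-group", Duration.ofMinutes(1));
+ *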

            + * @see MirrorClientConfig for additional properties used by the internal MirrorClient. + *

            + */ +public final class RemoteClusterUtils { + + // utility class + private RemoteClusterUtils() {} + + /** Find shortest number of hops from an upstream cluster. + * Returns -1 if the cluster is unreachable */ + public static int replicationHops(Map properties, String upstreamClusterAlias) + throws InterruptedException, TimeoutException { + try (MirrorClient client = new MirrorClient(properties)) { + return client.replicationHops(upstreamClusterAlias); + } + } + + /** Find all heartbeat topics */ + public static Set heartbeatTopics(Map properties) + throws InterruptedException, TimeoutException { + try (MirrorClient client = new MirrorClient(properties)) { + return client.heartbeatTopics(); + } + } + + /** Find all checkpoint topics */ + public static Set checkpointTopics(Map properties) + throws InterruptedException, TimeoutException { + try (MirrorClient client = new MirrorClient(properties)) { + return client.checkpointTopics(); + } + } + + /** Find all upstream clusters */ + public static Set upstreamClusters(Map properties) + throws InterruptedException, TimeoutException { + try (MirrorClient client = new MirrorClient(properties)) { + return client.upstreamClusters(); + } + } + + /** Translate a remote consumer group's offsets into corresponding local offsets. Topics are automatically + * renamed according to the ReplicationPolicy. + * @param properties @see MirrorClientConfig + * @param consumerGroupId group ID of remote consumer group + * @param remoteClusterAlias alias of remote cluster + * @param timeout timeout + */ + public static Map translateOffsets(Map properties, + String remoteClusterAlias, String consumerGroupId, Duration timeout) + throws InterruptedException, TimeoutException { + try (MirrorClient client = new MirrorClient(properties)) { + return client.remoteConsumerOffsets(consumerGroupId, remoteClusterAlias, timeout); + } + } +} diff --git a/connect/mirror-client/src/main/java/org/apache/kafka/connect/mirror/ReplicationPolicy.java b/connect/mirror-client/src/main/java/org/apache/kafka/connect/mirror/ReplicationPolicy.java new file mode 100644 index 0000000..d8d5593 --- /dev/null +++ b/connect/mirror-client/src/main/java/org/apache/kafka/connect/mirror/ReplicationPolicy.java @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.connect.mirror; + +import org.apache.kafka.common.annotation.InterfaceStability; + +/** Defines which topics are "remote topics". e.g. "us-west.topic1". */ +@InterfaceStability.Evolving +public interface ReplicationPolicy { + + /** How to rename remote topics; generally should be like us-west.topic1. */ + String formatRemoteTopic(String sourceClusterAlias, String topic); + + /** Source cluster alias of given remote topic, e.g. 
"us-west" for "us-west.topic1". + * Returns null if not a remote topic. + */ + String topicSource(String topic); + + /** Name of topic on the source cluster, e.g. "topic1" for "us-west.topic1". + * + * Topics may be replicated multiple hops, so the immediately upstream topic + * may itself be a remote topic. + * + * Returns null if not a remote topic. + */ + String upstreamTopic(String topic); + + /** The name of the original source-topic, which may have been replicated multiple hops. + * Returns the topic if it is not a remote topic. + */ + default String originalTopic(String topic) { + String upstream = upstreamTopic(topic); + if (upstream == null || upstream.equals(topic)) { + return topic; + } else { + return originalTopic(upstream); + } + } + + /** Returns heartbeats topic name.*/ + default String heartbeatsTopic() { + return "heartbeats"; + } + + /** Returns the offset-syncs topic for given cluster alias. */ + default String offsetSyncsTopic(String clusterAlias) { + return "mm2-offset-syncs." + clusterAlias + ".internal"; + } + + /** Returns the name checkpoint topic for given cluster alias. */ + default String checkpointsTopic(String clusterAlias) { + return clusterAlias + ".checkpoints.internal"; + } + + /** check if topic is a heartbeat topic, e.g heartbeats, us-west.heartbeats. */ + default boolean isHeartbeatsTopic(String topic) { + return heartbeatsTopic().equals(originalTopic(topic)); + } + + /** check if topic is a checkpoint topic. */ + default boolean isCheckpointsTopic(String topic) { + return topic.endsWith(".checkpoints.internal"); + } + + /** Check topic is one of MM2 internal topic, this is used to make sure the topic doesn't need to be replicated.*/ + default boolean isMM2InternalTopic(String topic) { + return topic.endsWith(".internal"); + } + + /** Internal topics are never replicated. */ + default boolean isInternalTopic(String topic) { + boolean isKafkaInternalTopic = topic.startsWith("__") || topic.startsWith("."); + boolean isDefaultConnectTopic = topic.endsWith("-internal") || topic.endsWith(".internal"); + return isMM2InternalTopic(topic) || isKafkaInternalTopic || isDefaultConnectTopic; + } +} diff --git a/connect/mirror-client/src/main/java/org/apache/kafka/connect/mirror/SourceAndTarget.java b/connect/mirror-client/src/main/java/org/apache/kafka/connect/mirror/SourceAndTarget.java new file mode 100644 index 0000000..e2f3a37 --- /dev/null +++ b/connect/mirror-client/src/main/java/org/apache/kafka/connect/mirror/SourceAndTarget.java @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.mirror; + +/** Directional pair of clusters, where source is replicated to target. 
*/ +public class SourceAndTarget { + private String source; + private String target; + + public SourceAndTarget(String source, String target) { + this.source = source; + this.target = target; + } + + public String source() { + return source; + } + + public String target() { + return target; + } + + @Override + public String toString() { + return source + "->" + target; + } + + @Override + public int hashCode() { + return toString().hashCode(); + } + + @Override + public boolean equals(Object other) { + return other != null && toString().equals(other.toString()); + } +} + diff --git a/connect/mirror-client/src/test/java/org/apache/kafka/connect/mirror/MirrorClientTest.java b/connect/mirror-client/src/test/java/org/apache/kafka/connect/mirror/MirrorClientTest.java new file mode 100644 index 0000000..2e1b9b7 --- /dev/null +++ b/connect/mirror-client/src/test/java/org/apache/kafka/connect/mirror/MirrorClientTest.java @@ -0,0 +1,218 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.mirror; + +import org.apache.kafka.common.Configurable; + +import java.util.Collections; +import java.util.List; +import java.util.Set; +import java.util.HashSet; +import java.util.Arrays; + +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.assertFalse; + +public class MirrorClientTest { + + private static class FakeMirrorClient extends MirrorClient { + + List topics; + + FakeMirrorClient(List topics) { + this(new DefaultReplicationPolicy(), topics); + } + + FakeMirrorClient(ReplicationPolicy replicationPolicy, List topics) { + super(null, replicationPolicy, null); + this.topics = topics; + } + + FakeMirrorClient() { + this(Collections.emptyList()); + } + + @Override + protected Set listTopics() { + return new HashSet<>(topics); + } + } + + @Test + public void testIsHeartbeatTopic() { + MirrorClient client = new FakeMirrorClient(); + assertTrue(client.isHeartbeatTopic("heartbeats")); + assertTrue(client.isHeartbeatTopic("source1.heartbeats")); + assertTrue(client.isHeartbeatTopic("source2.source1.heartbeats")); + assertFalse(client.isHeartbeatTopic("heartbeats!")); + assertFalse(client.isHeartbeatTopic("!heartbeats")); + assertFalse(client.isHeartbeatTopic("source1heartbeats")); + assertFalse(client.isHeartbeatTopic("source1-heartbeats")); + } + + @Test + public void testIsCheckpointTopic() { + MirrorClient client = new FakeMirrorClient(); + assertTrue(client.isCheckpointTopic("source1.checkpoints.internal")); + assertFalse(client.isCheckpointTopic("checkpoints.internal")); + assertFalse(client.isCheckpointTopic("checkpoints-internal")); + 
assertFalse(client.isCheckpointTopic("checkpoints.internal!")); + assertFalse(client.isCheckpointTopic("!checkpoints.internal")); + assertFalse(client.isCheckpointTopic("source1checkpointsinternal")); + } + + @Test + public void countHopsForTopicTest() { + MirrorClient client = new FakeMirrorClient(); + assertEquals(-1, client.countHopsForTopic("topic", "source")); + assertEquals(-1, client.countHopsForTopic("source", "source")); + assertEquals(-1, client.countHopsForTopic("sourcetopic", "source")); + assertEquals(-1, client.countHopsForTopic("source1.topic", "source2")); + assertEquals(1, client.countHopsForTopic("source1.topic", "source1")); + assertEquals(1, client.countHopsForTopic("source2.source1.topic", "source2")); + assertEquals(2, client.countHopsForTopic("source2.source1.topic", "source1")); + assertEquals(3, client.countHopsForTopic("source3.source2.source1.topic", "source1")); + assertEquals(-1, client.countHopsForTopic("source3.source2.source1.topic", "source4")); + } + + @Test + public void heartbeatTopicsTest() throws InterruptedException { + MirrorClient client = new FakeMirrorClient(Arrays.asList("topic1", "topic2", "heartbeats", + "source1.heartbeats", "source2.source1.heartbeats", "source3.heartbeats")); + Set heartbeatTopics = client.heartbeatTopics(); + assertEquals(heartbeatTopics, new HashSet<>(Arrays.asList("heartbeats", "source1.heartbeats", + "source2.source1.heartbeats", "source3.heartbeats"))); + } + + @Test + public void checkpointsTopicsTest() throws InterruptedException { + MirrorClient client = new FakeMirrorClient(Arrays.asList("topic1", "topic2", "checkpoints.internal", + "source1.checkpoints.internal", "source2.source1.checkpoints.internal", "source3.checkpoints.internal")); + Set checkpointTopics = client.checkpointTopics(); + assertEquals(new HashSet<>(Arrays.asList("source1.checkpoints.internal", + "source2.source1.checkpoints.internal", "source3.checkpoints.internal")), checkpointTopics); + } + + @Test + public void replicationHopsTest() throws InterruptedException { + MirrorClient client = new FakeMirrorClient(Arrays.asList("topic1", "topic2", "heartbeats", + "source1.heartbeats", "source1.source2.heartbeats", "source3.heartbeats")); + assertEquals(1, client.replicationHops("source1")); + assertEquals(2, client.replicationHops("source2")); + assertEquals(1, client.replicationHops("source3")); + assertEquals(-1, client.replicationHops("source4")); + } + + @Test + public void upstreamClustersTest() throws InterruptedException { + MirrorClient client = new FakeMirrorClient(Arrays.asList("topic1", "topic2", "heartbeats", + "source1.heartbeats", "source1.source2.heartbeats", "source3.source4.source5.heartbeats")); + Set sources = client.upstreamClusters(); + assertTrue(sources.contains("source1")); + assertTrue(sources.contains("source2")); + assertTrue(sources.contains("source3")); + assertTrue(sources.contains("source4")); + assertTrue(sources.contains("source5")); + assertFalse(sources.contains("sourceX")); + assertFalse(sources.contains("")); + assertFalse(sources.contains(null)); + } + + @Test + public void testIdentityReplicationUpstreamClusters() throws InterruptedException { + // IdentityReplicationPolicy treats heartbeats as a special case, so these should work as usual. 
+ MirrorClient client = new FakeMirrorClient(identityReplicationPolicy("source"), Arrays.asList("topic1", + "topic2", "heartbeats", "source1.heartbeats", "source1.source2.heartbeats", + "source3.source4.source5.heartbeats")); + Set sources = client.upstreamClusters(); + assertTrue(sources.contains("source1")); + assertTrue(sources.contains("source2")); + assertTrue(sources.contains("source3")); + assertTrue(sources.contains("source4")); + assertTrue(sources.contains("source5")); + assertFalse(sources.contains("")); + assertFalse(sources.contains(null)); + assertEquals(5, sources.size()); + } + + @Test + public void remoteTopicsTest() throws InterruptedException { + MirrorClient client = new FakeMirrorClient(Arrays.asList("topic1", "topic2", "topic3", + "source1.topic4", "source1.source2.topic5", "source3.source4.source5.topic6")); + Set remoteTopics = client.remoteTopics(); + assertFalse(remoteTopics.contains("topic1")); + assertFalse(remoteTopics.contains("topic2")); + assertFalse(remoteTopics.contains("topic3")); + assertTrue(remoteTopics.contains("source1.topic4")); + assertTrue(remoteTopics.contains("source1.source2.topic5")); + assertTrue(remoteTopics.contains("source3.source4.source5.topic6")); + } + + @Test + public void testIdentityReplicationRemoteTopics() throws InterruptedException { + // IdentityReplicationPolicy should consider any topic to be remote. + MirrorClient client = new FakeMirrorClient(identityReplicationPolicy("source"), Arrays.asList( + "topic1", "topic2", "topic3", "heartbeats", "backup.heartbeats")); + Set remoteTopics = client.remoteTopics(); + assertTrue(remoteTopics.contains("topic1")); + assertTrue(remoteTopics.contains("topic2")); + assertTrue(remoteTopics.contains("topic3")); + // Heartbeats are treated as a special case + assertFalse(remoteTopics.contains("heartbeats")); + assertTrue(remoteTopics.contains("backup.heartbeats")); + } + + @Test + public void remoteTopicsSeparatorTest() throws InterruptedException { + MirrorClient client = new FakeMirrorClient(Arrays.asList("topic1", "topic2", "topic3", + "source1__topic4", "source1__source2__topic5", "source3__source4__source5__topic6")); + ((Configurable) client.replicationPolicy()).configure( + Collections.singletonMap("replication.policy.separator", "__")); + Set remoteTopics = client.remoteTopics(); + assertFalse(remoteTopics.contains("topic1")); + assertFalse(remoteTopics.contains("topic2")); + assertFalse(remoteTopics.contains("topic3")); + assertTrue(remoteTopics.contains("source1__topic4")); + assertTrue(remoteTopics.contains("source1__source2__topic5")); + assertTrue(remoteTopics.contains("source3__source4__source5__topic6")); + } + + @Test + public void testIdentityReplicationTopicSource() { + MirrorClient client = new FakeMirrorClient( + identityReplicationPolicy("primary"), Arrays.asList()); + assertEquals("topic1", client.replicationPolicy() + .formatRemoteTopic("primary", "topic1")); + assertEquals("primary", client.replicationPolicy() + .topicSource("topic1")); + // Heartbeats are handled as a special case + assertEquals("backup.heartbeats", client.replicationPolicy() + .formatRemoteTopic("backup", "heartbeats")); + assertEquals("backup", client.replicationPolicy() + .topicSource("backup.heartbeats")); + } + + private ReplicationPolicy identityReplicationPolicy(String source) { + IdentityReplicationPolicy policy = new IdentityReplicationPolicy(); + policy.configure(Collections.singletonMap( + IdentityReplicationPolicy.SOURCE_CLUSTER_ALIAS_CONFIG, source)); + return policy; + } +} diff --git 
a/connect/mirror-client/src/test/java/org/apache/kafka/connect/mirror/ReplicationPolicyTest.java b/connect/mirror-client/src/test/java/org/apache/kafka/connect/mirror/ReplicationPolicyTest.java new file mode 100644 index 0000000..4810f0e --- /dev/null +++ b/connect/mirror-client/src/test/java/org/apache/kafka/connect/mirror/ReplicationPolicyTest.java @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.connect.mirror; + +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.assertFalse; + +public class ReplicationPolicyTest { + private static final DefaultReplicationPolicy DEFAULT_REPLICATION_POLICY = new DefaultReplicationPolicy(); + + @Test + public void testInternalTopic() { + // starts with '__' + assertTrue(DEFAULT_REPLICATION_POLICY.isInternalTopic("__consumer_offsets")); + // starts with '.' + assertTrue(DEFAULT_REPLICATION_POLICY.isInternalTopic(".hiddentopic")); + + // ends with '.internal': default DistributedConfig.OFFSET_STORAGE_TOPIC_CONFIG in standalone mode. + assertTrue(DEFAULT_REPLICATION_POLICY.isInternalTopic("mm2-offsets.CLUSTER.internal")); + // ends with '-internal' + assertTrue(DEFAULT_REPLICATION_POLICY.isInternalTopic("mm2-offsets-CLUSTER-internal")); + // non-internal topic. + assertFalse(DEFAULT_REPLICATION_POLICY.isInternalTopic("mm2-offsets_CLUSTER_internal")); + } +} diff --git a/connect/mirror/README.md b/connect/mirror/README.md new file mode 100644 index 0000000..3c8aebc --- /dev/null +++ b/connect/mirror/README.md @@ -0,0 +1,255 @@ + +# MirrorMaker 2.0 + +MM2 leverages the Connect framework to replicate topics between Kafka +clusters. MM2 includes several new features, including: + + - both topics and consumer groups are replicated + - topic configuration and ACLs are replicated + - cross-cluster offsets are synchronized + - partitioning is preserved + +## Replication flows + +MM2 replicates topics and consumer groups from upstream source clusters +to downstream target clusters. These directional flows are notated +`A->B`. + +It's possible to create complex replication topologies based on these +`source->target` flows, including: + + - *fan-out*, e.g. `K->A, K->B, K->C` + - *aggregation*, e.g. `A->K, B->K, C->K` + - *active/active*, e.g. `A->B, B->A` + +Each replication flow can be configured independently, e.g. to replicate +specific topics or groups: + + A->B.topics = topic-1, topic-2 + A->B.groups = group-1, group-2 + +By default, all topics and consumer groups are replicated (except +excluded ones), across all enabled replication flows. 
Each +replication flow must be explicitly enabled to begin replication: + + A->B.enabled = true + B->A.enabled = true + +## Starting an MM2 process + +You can run any number of MM2 processes as needed. Any MM2 processes +which are configured to replicate the same Kafka clusters will find each +other, share configuration, load balance, etc. + +To start an MM2 process, first specify Kafka cluster information in a +configuration file as follows: + + # mm2.properties + clusters = us-west, us-east + us-west.bootstrap.servers = host1:9092 + us-east.bootstrap.servers = host2:9092 + +You can list any number of clusters this way. + +Optionally, you can override default MirrorMaker properties: + + topics = .* + groups = group1, group2 + emit.checkpoints.interval.seconds = 10 + +These will apply to all replication flows. You can also override default +properties for specific clusters or replication flows: + + # configure a specific cluster + us-west.offset.storage.topic = mm2-offsets + + # configure a specific source->target replication flow + us-west->us-east.emit.heartbeats = false + +Next, enable individual replication flows as follows: + + us-west->us-east.enabled = true # disabled by default + +Finally, launch one or more MirrorMaker processes with the `connect-mirror-maker.sh` +script: + + $ ./bin/connect-mirror-maker.sh mm2.properties + +## Multicluster environments + +MM2 supports replication between multiple Kafka clusters, whether in the +same data center or across multiple data centers. A single MM2 cluster +can span multiple data centers, but it is recommended to keep MM2's producers +as close as possible to their target clusters. To do so, specify a subset +of clusters for each MM2 node as follows: + + # in west DC: + $ ./bin/connect-mirror-maker.sh mm2.properties --clusters west-1 west-2 + +This signals to the node that the given clusters are nearby, and prevents the +node from sending records or configuration to clusters in other data centers. + +### Example + +Say there are three data centers (west, east, north) with two Kafka +clusters in each data center (west-1, west-2 etc). We can configure MM2 +for active/active replication within each data center, as well as cross data +center replication (XDCR) as follows: + + # mm2.properties + clusters: west-1, west-2, east-1, east-2, north-1, north-2 + + west-1.bootstrap.servers = ... + ---%<--- + + # active/active in west + west-1->west-2.enabled = true + west-2->west-1.enabled = true + + # active/active in east + east-1->east-2.enabled = true + east-2->east-1.enabled = true + + # active/active in north + north-1->north-2.enabled = true + north-2->north-1.enabled = true + + # XDCR via west-1, east-1, north-1 + west-1->east-1.enabled = true + west-1->north-1.enabled = true + east-1->west-1.enabled = true + east-1->north-1.enabled = true + north-1->west-1.enabled = true + north-1->east-1.enabled = true + +Then, launch MM2 in each data center as follows: + + # in west: + $ ./bin/connect-mirror-maker.sh mm2.properties --clusters west-1 west-2 + + # in east: + $ ./bin/connect-mirror-maker.sh mm2.properties --clusters east-1 east-2 + + # in north: + $ ./bin/connect-mirror-maker.sh mm2.properties --clusters north-1 north-2 + +With this configuration, records produced to any cluster will be replicated +within the data center, as well as across to other data centers. By providing +the `--clusters` parameter, we ensure that each node only produces records to +nearby clusters. + +N.B. that the `--clusters` parameter is not technically required here. 
MM2 will work fine without it; however, throughput may suffer from "producer lag" between
+data centers, and you may incur unnecessary data transfer costs.
+
+## Configuration
+The following sections apply to dedicated MM2 clusters. If running MM2 in a Connect cluster, please refer to [KIP-382: MirrorMaker 2.0](https://cwiki.apache.org/confluence/display/KAFKA/KIP-382%3A+MirrorMaker+2.0) for guidance.
+
+### General Kafka Connect Config
+All Kafka Connect, Source Connector, and Sink Connector configs, as defined in the [Kafka official doc](https://kafka.apache.org/documentation/#connectconfigs), can be
+used directly in the MM2 configuration without a prefix in the configuration name. As a starting point, most of these defaults work well, with the exception of `tasks.max`.
+
+In order to evenly distribute the workload across more than one MM2 instance, it is advised to set `tasks.max` to at least 2, or higher, depending on the hardware resources
+and the total number of partitions to be replicated.
+
+### Kafka Connect Config for a Specific Connector
+If needed, Kafka Connect worker-level configs can even be specified "per connector", using the format `cluster_alias.config_name` in the MM2 configuration. For example,
+
+    backup.ssl.truststore.location = /usr/lib/jvm/zulu-8-amd64/jre/lib/security/cacerts // SSL cert location
+    backup.security.protocol = SSL // if the target cluster requires SSL to send messages
+
+### MM2 Config for a Specific Connector
+MM2 itself has many configs to control how it behaves. To override their default values, add the config name using the format `source_cluster_alias->target_cluster_alias.config_name` in the MM2 configuration. For example,
+
+    backup->primary.enabled = false // set to false if one-way replication is desired
+    primary->backup.topics.blacklist = topics_to_blacklist
+    primary->backup.emit.heartbeats.enabled = false
+    primary->backup.sync.group.offsets = true
+
+### Producer / Consumer / Admin Config used by MM2
+In many cases, customized values for producer or consumer configurations are needed. To override the default values of the producer, consumer, or admin clients used by MM2,
+use the formats `target_cluster_alias.producer.producer_config_name`, `source_cluster_alias.consumer.consumer_config_name`, or `cluster_alias.admin.admin_config_name` in the MM2 configuration. For example,
+
+    backup.producer.compression.type = gzip
+    backup.producer.buffer.memory = 32768
+    primary.consumer.isolation.level = read_committed
+    primary.admin.bootstrap.servers = localhost:9092
+
+### Shared configuration
+
+MM2 processes share configuration via their target Kafka clusters.
+For example, the following two processes would be racy:
+
+    # process1:
+    A->B.enabled = true
+    A->B.topics = foo
+
+    # process2:
+    A->B.enabled = true
+    A->B.topics = bar
+
+In this case, the two processes will share configuration via cluster `B`.
+Depending on which process is elected "leader", the result will be
+that either `foo` or `bar` is replicated -- but not both. For this reason,
+it is important to keep configuration consistent across flows to the same
+target cluster. In most cases, your entire organization should use a single
+MM2 configuration file.
+
+## Remote topics
+
+MM2 employs a naming convention to ensure that records from different
+clusters are not written to the same partition.
By default, replicated +topics are renamed based on "source cluster aliases": + + topic-1 --> source.topic-1 + +This can be customized by overriding the `replication.policy.separator` +property (default is a period). If you need more control over how +remote topics are defined, you can implement a custom `ReplicationPolicy` +and override `replication.policy.class` (default is +`DefaultReplicationPolicy`). + +## Monitoring an MM2 process + +MM2 is built on the Connect framework and inherits all of Connect's metrics, e.g. +`source-record-poll-rate`. In addition, MM2 produces its own metrics under the +`kafka.connect.mirror` metric group. Metrics are tagged with the following properties: + + - *target*: alias of target cluster + - *source*: alias of source cluster + - *topic*: remote topic on target cluster + - *partition*: partition being replicated + +Metrics are tracked for each *remote* topic. The source cluster can be inferred +from the topic name. For example, replicating `topic1` from `A->B` will yield metrics +like: + + - `target=B` + - `topic=A.topic1` + - `partition=1` + +The following metrics are emitted: + + # MBean: kafka.connect.mirror:type=MirrorSourceConnector,target=([-.w]+),topic=([-.w]+),partition=([0-9]+) + + record-count # number of records replicated source -> target + record-age-ms # age of records when they are replicated + record-age-ms-min + record-age-ms-max + record-age-ms-avg + replication-latency-ms # time it takes records to propagate source->target + replication-latency-ms-min + replication-latency-ms-max + replication-latency-ms-avg + byte-rate # average number of bytes/sec in replicated records + + + # MBean: kafka.connect.mirror:type=MirrorCheckpointConnector,source=([-.w]+),target=([-.w]+) + + checkpoint-latency-ms # time it takes to replicate consumer offsets + checkpoint-latency-ms-min + checkpoint-latency-ms-max + checkpoint-latency-ms-avg + +These metrics do not discern between created-at and log-append timestamps. + + diff --git a/connect/mirror/src/main/java/org/apache/kafka/connect/mirror/ConfigPropertyFilter.java b/connect/mirror/src/main/java/org/apache/kafka/connect/mirror/ConfigPropertyFilter.java new file mode 100644 index 0000000..ec6b3b9 --- /dev/null +++ b/connect/mirror/src/main/java/org/apache/kafka/connect/mirror/ConfigPropertyFilter.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.connect.mirror; + +import org.apache.kafka.common.Configurable; +import org.apache.kafka.common.annotation.InterfaceStability; +import java.util.Map; + +/** Defines which topic configuration properties should be replicated. 
*/ +@InterfaceStability.Evolving +public interface ConfigPropertyFilter extends Configurable, AutoCloseable { + + boolean shouldReplicateConfigProperty(String prop); + + default void close() { + //nop + } + + default void configure(Map props) { + //nop + } +} diff --git a/connect/mirror/src/main/java/org/apache/kafka/connect/mirror/DefaultConfigPropertyFilter.java b/connect/mirror/src/main/java/org/apache/kafka/connect/mirror/DefaultConfigPropertyFilter.java new file mode 100644 index 0000000..0c85f50 --- /dev/null +++ b/connect/mirror/src/main/java/org/apache/kafka/connect/mirror/DefaultConfigPropertyFilter.java @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.mirror; + +import org.apache.kafka.common.config.AbstractConfig; +import org.apache.kafka.common.config.ConfigDef; +import org.apache.kafka.common.config.ConfigDef.Type; +import org.apache.kafka.common.config.ConfigDef.Importance; +import org.apache.kafka.common.utils.ConfigUtils; + +import java.util.Map; +import java.util.regex.Pattern; + +/** Filters excluded property names or regexes. 
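+ * The excluded properties are controlled by the config.properties.exclude setting (or its
+ * deprecated alias config.properties.blacklist). For example, a hypothetical override that
+ * replaces the default exclude list:
+ *
+ *    config.properties.exclude = follower\.replication\.throttled\.replicas, retention\.ms
+ *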
*/ +public class DefaultConfigPropertyFilter implements ConfigPropertyFilter { + + public static final String CONFIG_PROPERTIES_EXCLUDE_CONFIG = "config.properties.exclude"; + public static final String CONFIG_PROPERTIES_EXCLUDE_ALIAS_CONFIG = "config.properties.blacklist"; + + private static final String CONFIG_PROPERTIES_EXCLUDE_DOC = "List of topic configuration properties and/or regexes " + + "that should not be replicated."; + public static final String CONFIG_PROPERTIES_EXCLUDE_DEFAULT = "follower\\.replication\\.throttled\\.replicas, " + + "leader\\.replication\\.throttled\\.replicas, " + + "message\\.timestamp\\.difference\\.max\\.ms, " + + "message\\.timestamp\\.type, " + + "unclean\\.leader\\.election\\.enable, " + + "min\\.insync\\.replicas"; + private Pattern excludePattern = MirrorUtils.compilePatternList(CONFIG_PROPERTIES_EXCLUDE_DEFAULT); + + @Override + public void configure(Map props) { + ConfigPropertyFilterConfig config = new ConfigPropertyFilterConfig(props); + excludePattern = config.excludePattern(); + } + + @Override + public void close() { + } + + private boolean excluded(String prop) { + return excludePattern != null && excludePattern.matcher(prop).matches(); + } + + @Override + public boolean shouldReplicateConfigProperty(String prop) { + return !excluded(prop); + } + + static class ConfigPropertyFilterConfig extends AbstractConfig { + + static final ConfigDef DEF = new ConfigDef() + .define(CONFIG_PROPERTIES_EXCLUDE_CONFIG, + Type.LIST, + CONFIG_PROPERTIES_EXCLUDE_DEFAULT, + Importance.HIGH, + CONFIG_PROPERTIES_EXCLUDE_DOC) + .define(CONFIG_PROPERTIES_EXCLUDE_ALIAS_CONFIG, + Type.LIST, + null, + Importance.HIGH, + "Deprecated. Use " + CONFIG_PROPERTIES_EXCLUDE_CONFIG + " instead."); + + ConfigPropertyFilterConfig(Map props) { + super(DEF, ConfigUtils.translateDeprecatedConfigs(props, new String[][]{ + {CONFIG_PROPERTIES_EXCLUDE_CONFIG, CONFIG_PROPERTIES_EXCLUDE_ALIAS_CONFIG}}), false); + } + + Pattern excludePattern() { + return MirrorUtils.compilePatternList(getList(CONFIG_PROPERTIES_EXCLUDE_CONFIG)); + } + } +} diff --git a/connect/mirror/src/main/java/org/apache/kafka/connect/mirror/DefaultGroupFilter.java b/connect/mirror/src/main/java/org/apache/kafka/connect/mirror/DefaultGroupFilter.java new file mode 100644 index 0000000..179067e --- /dev/null +++ b/connect/mirror/src/main/java/org/apache/kafka/connect/mirror/DefaultGroupFilter.java @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.connect.mirror; + +import org.apache.kafka.common.config.AbstractConfig; +import org.apache.kafka.common.config.ConfigDef; +import org.apache.kafka.common.config.ConfigDef.Type; +import org.apache.kafka.common.config.ConfigDef.Importance; +import org.apache.kafka.common.utils.ConfigUtils; + +import java.util.Map; +import java.util.regex.Pattern; + +/** Uses an include and exclude pattern. */ +public class DefaultGroupFilter implements GroupFilter { + + public static final String GROUPS_INCLUDE_CONFIG = "groups"; + private static final String GROUPS_INCLUDE_DOC = "List of consumer group names and/or regexes to replicate."; + public static final String GROUPS_INCLUDE_DEFAULT = ".*"; + + public static final String GROUPS_EXCLUDE_CONFIG = "groups.exclude"; + public static final String GROUPS_EXCLUDE_CONFIG_ALIAS = "groups.blacklist"; + + private static final String GROUPS_EXCLUDE_DOC = "List of consumer group names and/or regexes that should not be replicated."; + public static final String GROUPS_EXCLUDE_DEFAULT = "console-consumer-.*, connect-.*, __.*"; + + private Pattern includePattern; + private Pattern excludePattern; + + @Override + public void configure(Map props) { + GroupFilterConfig config = new GroupFilterConfig(props); + includePattern = config.includePattern(); + excludePattern = config.excludePattern(); + } + + @Override + public void close() { + } + + private boolean included(String group) { + return includePattern != null && includePattern.matcher(group).matches(); + } + + private boolean excluded(String group) { + return excludePattern != null && excludePattern.matcher(group).matches(); + } + + @Override + public boolean shouldReplicateGroup(String group) { + return included(group) && !excluded(group); + } + + static class GroupFilterConfig extends AbstractConfig { + + static final ConfigDef DEF = new ConfigDef() + .define(GROUPS_INCLUDE_CONFIG, + Type.LIST, + GROUPS_INCLUDE_DEFAULT, + Importance.HIGH, + GROUPS_INCLUDE_DOC) + .define(GROUPS_EXCLUDE_CONFIG, + Type.LIST, + GROUPS_EXCLUDE_DEFAULT, + Importance.HIGH, + GROUPS_EXCLUDE_DOC) + .define(GROUPS_EXCLUDE_CONFIG_ALIAS, + Type.LIST, + null, + Importance.HIGH, + "Deprecated. Use " + GROUPS_EXCLUDE_CONFIG + " instead."); + + GroupFilterConfig(Map props) { + super(DEF, ConfigUtils.translateDeprecatedConfigs(props, new String[][]{ + {GROUPS_EXCLUDE_CONFIG, GROUPS_EXCLUDE_CONFIG_ALIAS}}), false); + } + + Pattern includePattern() { + return MirrorUtils.compilePatternList(getList(GROUPS_INCLUDE_CONFIG)); + } + + Pattern excludePattern() { + return MirrorUtils.compilePatternList(getList(GROUPS_EXCLUDE_CONFIG)); + } + } +} diff --git a/connect/mirror/src/main/java/org/apache/kafka/connect/mirror/DefaultTopicFilter.java b/connect/mirror/src/main/java/org/apache/kafka/connect/mirror/DefaultTopicFilter.java new file mode 100644 index 0000000..f808ce8 --- /dev/null +++ b/connect/mirror/src/main/java/org/apache/kafka/connect/mirror/DefaultTopicFilter.java @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.mirror; + +import org.apache.kafka.common.config.AbstractConfig; +import org.apache.kafka.common.config.ConfigDef; +import org.apache.kafka.common.config.ConfigDef.Type; +import org.apache.kafka.common.config.ConfigDef.Importance; +import org.apache.kafka.common.utils.ConfigUtils; + +import java.util.Map; +import java.util.regex.Pattern; + +/** Uses an include and exclude pattern. */ +public class DefaultTopicFilter implements TopicFilter { + + public static final String TOPICS_INCLUDE_CONFIG = "topics"; + private static final String TOPICS_INCLUDE_DOC = "List of topics and/or regexes to replicate."; + public static final String TOPICS_INCLUDE_DEFAULT = ".*"; + + public static final String TOPICS_EXCLUDE_CONFIG = "topics.exclude"; + public static final String TOPICS_EXCLUDE_CONFIG_ALIAS = "topics.blacklist"; + private static final String TOPICS_EXCLUDE_DOC = "List of topics and/or regexes that should not be replicated."; + public static final String TOPICS_EXCLUDE_DEFAULT = ".*[\\-\\.]internal, .*\\.replica, __.*"; + + private Pattern includePattern; + private Pattern excludePattern; + + @Override + public void configure(Map props) { + TopicFilterConfig config = new TopicFilterConfig(props); + includePattern = config.includePattern(); + excludePattern = config.excludePattern(); + } + + @Override + public void close() { + } + + private boolean included(String topic) { + return includePattern != null && includePattern.matcher(topic).matches(); + } + + private boolean excluded(String topic) { + return excludePattern != null && excludePattern.matcher(topic).matches(); + } + + @Override + public boolean shouldReplicateTopic(String topic) { + return included(topic) && !excluded(topic); + } + + static class TopicFilterConfig extends AbstractConfig { + + static final ConfigDef DEF = new ConfigDef() + .define(TOPICS_INCLUDE_CONFIG, + Type.LIST, + TOPICS_INCLUDE_DEFAULT, + Importance.HIGH, + TOPICS_INCLUDE_DOC) + .define(TOPICS_EXCLUDE_CONFIG, + Type.LIST, + TOPICS_EXCLUDE_DEFAULT, + Importance.HIGH, + TOPICS_EXCLUDE_DOC) + .define(TOPICS_EXCLUDE_CONFIG_ALIAS, + Type.LIST, + null, + Importance.HIGH, + "Deprecated. Use " + TOPICS_EXCLUDE_CONFIG + " instead."); + + TopicFilterConfig(Map props) { + super(DEF, ConfigUtils.translateDeprecatedConfigs(props, new String[][]{ + {TOPICS_EXCLUDE_CONFIG, TOPICS_EXCLUDE_CONFIG_ALIAS}}), false); + } + + Pattern includePattern() { + return MirrorUtils.compilePatternList(getList(TOPICS_INCLUDE_CONFIG)); + } + + Pattern excludePattern() { + return MirrorUtils.compilePatternList(getList(TOPICS_EXCLUDE_CONFIG)); + } + } +} diff --git a/connect/mirror/src/main/java/org/apache/kafka/connect/mirror/GroupFilter.java b/connect/mirror/src/main/java/org/apache/kafka/connect/mirror/GroupFilter.java new file mode 100644 index 0000000..0202dd5 --- /dev/null +++ b/connect/mirror/src/main/java/org/apache/kafka/connect/mirror/GroupFilter.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.connect.mirror; + +import org.apache.kafka.common.Configurable; +import org.apache.kafka.common.annotation.InterfaceStability; +import java.util.Map; + +/** Defines which consumer groups should be replicated. */ +@InterfaceStability.Evolving +public interface GroupFilter extends Configurable, AutoCloseable { + + boolean shouldReplicateGroup(String group); + + default void close() { + //nop + } + + default void configure(Map props) { + //nop + } +} diff --git a/connect/mirror/src/main/java/org/apache/kafka/connect/mirror/MirrorCheckpointConnector.java b/connect/mirror/src/main/java/org/apache/kafka/connect/mirror/MirrorCheckpointConnector.java new file mode 100644 index 0000000..5118ee1 --- /dev/null +++ b/connect/mirror/src/main/java/org/apache/kafka/connect/mirror/MirrorCheckpointConnector.java @@ -0,0 +1,169 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.mirror; + +import org.apache.kafka.clients.admin.AdminClient; +import org.apache.kafka.clients.admin.ConsumerGroupListing; +import org.apache.kafka.common.config.ConfigDef; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.connect.connector.Task; +import org.apache.kafka.connect.source.SourceConnector; +import org.apache.kafka.connect.util.ConnectorUtils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Collection; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.ExecutionException; +import java.util.stream.Collectors; + +/** Replicate consumer group state between clusters. Emits checkpoint records. + * + * @see MirrorConnectorConfig for supported config properties. 
+ */ +public class MirrorCheckpointConnector extends SourceConnector { + + private static final Logger log = LoggerFactory.getLogger(MirrorCheckpointConnector.class); + + private Scheduler scheduler; + private MirrorConnectorConfig config; + private GroupFilter groupFilter; + private AdminClient sourceAdminClient; + private SourceAndTarget sourceAndTarget; + private String connectorName; + private List knownConsumerGroups = Collections.emptyList(); + + public MirrorCheckpointConnector() { + // nop + } + + // visible for testing + MirrorCheckpointConnector(List knownConsumerGroups, MirrorConnectorConfig config) { + this.knownConsumerGroups = knownConsumerGroups; + this.config = config; + } + + @Override + public void start(Map props) { + config = new MirrorConnectorConfig(props); + if (!config.enabled()) { + return; + } + connectorName = config.connectorName(); + sourceAndTarget = new SourceAndTarget(config.sourceClusterAlias(), config.targetClusterAlias()); + groupFilter = config.groupFilter(); + sourceAdminClient = AdminClient.create(config.sourceAdminConfig()); + scheduler = new Scheduler(MirrorCheckpointConnector.class, config.adminTimeout()); + scheduler.execute(this::createInternalTopics, "creating internal topics"); + scheduler.execute(this::loadInitialConsumerGroups, "loading initial consumer groups"); + scheduler.scheduleRepeatingDelayed(this::refreshConsumerGroups, config.refreshGroupsInterval(), + "refreshing consumer groups"); + log.info("Started {} with {} consumer groups.", connectorName, knownConsumerGroups.size()); + log.debug("Started {} with consumer groups: {}", connectorName, knownConsumerGroups); + } + + @Override + public void stop() { + if (!config.enabled()) { + return; + } + Utils.closeQuietly(scheduler, "scheduler"); + Utils.closeQuietly(groupFilter, "group filter"); + Utils.closeQuietly(sourceAdminClient, "source admin client"); + } + + @Override + public Class taskClass() { + return MirrorCheckpointTask.class; + } + + // divide consumer groups among tasks + @Override + public List> taskConfigs(int maxTasks) { + // if the replication is disabled, known consumer group is empty, or checkpoint emission is + // disabled by setting 'emit.checkpoints.enabled' to false, the interval of checkpoint emission + // will be negative and no 'MirrorHeartbeatTask' will be created + if (!config.enabled() || knownConsumerGroups.isEmpty() + || config.emitCheckpointsInterval().isNegative()) { + return Collections.emptyList(); + } + int numTasks = Math.min(maxTasks, knownConsumerGroups.size()); + return ConnectorUtils.groupPartitions(knownConsumerGroups, numTasks).stream() + .map(config::taskConfigForConsumerGroups) + .collect(Collectors.toList()); + } + + @Override + public ConfigDef config() { + return MirrorConnectorConfig.CONNECTOR_CONFIG_DEF; + } + + @Override + public String version() { + return "1"; + } + + private void refreshConsumerGroups() + throws InterruptedException, ExecutionException { + List consumerGroups = findConsumerGroups(); + Set newConsumerGroups = new HashSet<>(); + newConsumerGroups.addAll(consumerGroups); + newConsumerGroups.removeAll(knownConsumerGroups); + Set deadConsumerGroups = new HashSet<>(); + deadConsumerGroups.addAll(knownConsumerGroups); + deadConsumerGroups.removeAll(consumerGroups); + if (!newConsumerGroups.isEmpty() || !deadConsumerGroups.isEmpty()) { + log.info("Found {} consumer groups for {}. {} are new. {} were removed. 
Previously had {}.", + consumerGroups.size(), sourceAndTarget, newConsumerGroups.size(), deadConsumerGroups.size(), + knownConsumerGroups.size()); + log.debug("Found new consumer groups: {}", newConsumerGroups); + knownConsumerGroups = consumerGroups; + context.requestTaskReconfiguration(); + } + } + + private void loadInitialConsumerGroups() + throws InterruptedException, ExecutionException { + knownConsumerGroups = findConsumerGroups(); + } + + List findConsumerGroups() + throws InterruptedException, ExecutionException { + return listConsumerGroups().stream() + .map(ConsumerGroupListing::groupId) + .filter(this::shouldReplicate) + .collect(Collectors.toList()); + } + + Collection listConsumerGroups() + throws InterruptedException, ExecutionException { + return sourceAdminClient.listConsumerGroups().valid().get(); + } + + private void createInternalTopics() { + MirrorUtils.createSinglePartitionCompactedTopic(config.checkpointsTopic(), + config.checkpointsTopicReplicationFactor(), config.targetAdminConfig()); + } + + boolean shouldReplicate(String group) { + return groupFilter.shouldReplicateGroup(group); + } +} diff --git a/connect/mirror/src/main/java/org/apache/kafka/connect/mirror/MirrorCheckpointTask.java b/connect/mirror/src/main/java/org/apache/kafka/connect/mirror/MirrorCheckpointTask.java new file mode 100644 index 0000000..09eb0fd --- /dev/null +++ b/connect/mirror/src/main/java/org/apache/kafka/connect/mirror/MirrorCheckpointTask.java @@ -0,0 +1,321 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.mirror; + +import org.apache.kafka.clients.admin.Admin; +import org.apache.kafka.clients.admin.ConsumerGroupDescription; +import org.apache.kafka.common.ConsumerGroupState; +import org.apache.kafka.common.KafkaFuture; +import org.apache.kafka.connect.source.SourceTask; +import org.apache.kafka.connect.source.SourceRecord; +import org.apache.kafka.connect.data.Schema; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.clients.admin.AdminClient; +import org.apache.kafka.clients.consumer.OffsetAndMetadata; +import org.apache.kafka.clients.producer.RecordMetadata; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.HashMap; +import java.util.Map.Entry; +import java.util.Map; +import java.util.List; +import java.util.ArrayList; +import java.util.Set; +import java.util.Collections; +import java.util.stream.Collectors; +import java.util.concurrent.ExecutionException; +import java.time.Duration; + +/** Emits checkpoints for upstream consumer groups. 
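 + * <p>
 + * For illustration only (this mirrors the Checkpoint constructor used by this task, not a formal schema):
 + * each emitted checkpoint pairs an upstream consumer offset with its translated downstream offset, roughly
 + * <pre>
 + *   Checkpoint(group, topicPartition, upstreamOffset, downstreamOffset, metadata)
 + * </pre>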
*/ +public class MirrorCheckpointTask extends SourceTask { + + private static final Logger log = LoggerFactory.getLogger(MirrorCheckpointTask.class); + + private Admin sourceAdminClient; + private Admin targetAdminClient; + private String sourceClusterAlias; + private String targetClusterAlias; + private String checkpointsTopic; + private Duration interval; + private Duration pollTimeout; + private TopicFilter topicFilter; + private Set consumerGroups; + private ReplicationPolicy replicationPolicy; + private OffsetSyncStore offsetSyncStore; + private boolean stopping; + private MirrorMetrics metrics; + private Scheduler scheduler; + private Map> idleConsumerGroupsOffset; + private Map> checkpointsPerConsumerGroup; + public MirrorCheckpointTask() {} + + // for testing + MirrorCheckpointTask(String sourceClusterAlias, String targetClusterAlias, + ReplicationPolicy replicationPolicy, OffsetSyncStore offsetSyncStore, + Map> idleConsumerGroupsOffset, + Map> checkpointsPerConsumerGroup) { + this.sourceClusterAlias = sourceClusterAlias; + this.targetClusterAlias = targetClusterAlias; + this.replicationPolicy = replicationPolicy; + this.offsetSyncStore = offsetSyncStore; + this.idleConsumerGroupsOffset = idleConsumerGroupsOffset; + this.checkpointsPerConsumerGroup = checkpointsPerConsumerGroup; + } + + @Override + public void start(Map props) { + MirrorTaskConfig config = new MirrorTaskConfig(props); + stopping = false; + sourceClusterAlias = config.sourceClusterAlias(); + targetClusterAlias = config.targetClusterAlias(); + consumerGroups = config.taskConsumerGroups(); + checkpointsTopic = config.checkpointsTopic(); + topicFilter = config.topicFilter(); + replicationPolicy = config.replicationPolicy(); + interval = config.emitCheckpointsInterval(); + pollTimeout = config.consumerPollTimeout(); + offsetSyncStore = new OffsetSyncStore(config); + sourceAdminClient = AdminClient.create(config.sourceAdminConfig()); + targetAdminClient = AdminClient.create(config.targetAdminConfig()); + metrics = config.metrics(); + idleConsumerGroupsOffset = new HashMap<>(); + checkpointsPerConsumerGroup = new HashMap<>(); + scheduler = new Scheduler(MirrorCheckpointTask.class, config.adminTimeout()); + scheduler.scheduleRepeating(this::refreshIdleConsumerGroupOffset, config.syncGroupOffsetsInterval(), + "refreshing idle consumers group offsets at target cluster"); + scheduler.scheduleRepeatingDelayed(this::syncGroupOffset, config.syncGroupOffsetsInterval(), + "sync idle consumer group offset from source to target"); + } + + @Override + public void commit() throws InterruptedException { + // nop + } + + @Override + public void stop() { + long start = System.currentTimeMillis(); + stopping = true; + Utils.closeQuietly(topicFilter, "topic filter"); + Utils.closeQuietly(offsetSyncStore, "offset sync store"); + Utils.closeQuietly(sourceAdminClient, "source admin client"); + Utils.closeQuietly(targetAdminClient, "target admin client"); + Utils.closeQuietly(metrics, "metrics"); + Utils.closeQuietly(scheduler, "scheduler"); + log.info("Stopping {} took {} ms.", Thread.currentThread().getName(), System.currentTimeMillis() - start); + } + + @Override + public String version() { + return "1"; + } + + @Override + public List poll() throws InterruptedException { + try { + long deadline = System.currentTimeMillis() + interval.toMillis(); + while (!stopping && System.currentTimeMillis() < deadline) { + offsetSyncStore.update(pollTimeout); + } + List records = new ArrayList<>(); + for (String group : consumerGroups) { + 
records.addAll(sourceRecordsForGroup(group)); + } + if (records.isEmpty()) { + // WorkerSourceTask expects non-zero batches or null + return null; + } else { + return records; + } + } catch (Throwable e) { + log.warn("Failure polling consumer state for checkpoints.", e); + return null; + } + } + + + private List sourceRecordsForGroup(String group) throws InterruptedException { + try { + long timestamp = System.currentTimeMillis(); + List checkpoints = checkpointsForGroup(group); + checkpointsPerConsumerGroup.put(group, checkpoints); + return checkpoints.stream() + .map(x -> checkpointRecord(x, timestamp)) + .collect(Collectors.toList()); + } catch (ExecutionException e) { + log.error("Error querying offsets for consumer group {} on cluster {}.", group, sourceClusterAlias, e); + return Collections.emptyList(); + } + } + + private List checkpointsForGroup(String group) throws ExecutionException, InterruptedException { + return listConsumerGroupOffsets(group).entrySet().stream() + .filter(x -> shouldCheckpointTopic(x.getKey().topic())) + .map(x -> checkpoint(group, x.getKey(), x.getValue())) + .filter(x -> x.downstreamOffset() >= 0) // ignore offsets we cannot translate accurately + .collect(Collectors.toList()); + } + + private Map listConsumerGroupOffsets(String group) + throws InterruptedException, ExecutionException { + if (stopping) { + // short circuit if stopping + return Collections.emptyMap(); + } + return sourceAdminClient.listConsumerGroupOffsets(group).partitionsToOffsetAndMetadata().get(); + } + + Checkpoint checkpoint(String group, TopicPartition topicPartition, + OffsetAndMetadata offsetAndMetadata) { + long upstreamOffset = offsetAndMetadata.offset(); + long downstreamOffset = offsetSyncStore.translateDownstream(topicPartition, upstreamOffset); + return new Checkpoint(group, renameTopicPartition(topicPartition), + upstreamOffset, downstreamOffset, offsetAndMetadata.metadata()); + } + + SourceRecord checkpointRecord(Checkpoint checkpoint, long timestamp) { + return new SourceRecord( + checkpoint.connectPartition(), MirrorUtils.wrapOffset(0), + checkpointsTopic, 0, + Schema.BYTES_SCHEMA, checkpoint.recordKey(), + Schema.BYTES_SCHEMA, checkpoint.recordValue(), + timestamp); + } + + TopicPartition renameTopicPartition(TopicPartition upstreamTopicPartition) { + if (targetClusterAlias.equals(replicationPolicy.topicSource(upstreamTopicPartition.topic()))) { + // this topic came from the target cluster, so we rename like us-west.topic1 -> topic1 + return new TopicPartition(replicationPolicy.originalTopic(upstreamTopicPartition.topic()), + upstreamTopicPartition.partition()); + } else { + // rename like topic1 -> us-west.topic1 + return new TopicPartition(replicationPolicy.formatRemoteTopic(sourceClusterAlias, + upstreamTopicPartition.topic()), upstreamTopicPartition.partition()); + } + } + + boolean shouldCheckpointTopic(String topic) { + return topicFilter.shouldReplicateTopic(topic); + } + + @Override + public void commitRecord(SourceRecord record, RecordMetadata metadata) { + metrics.checkpointLatency(MirrorUtils.unwrapPartition(record.sourcePartition()), + Checkpoint.unwrapGroup(record.sourcePartition()), + System.currentTimeMillis() - record.timestamp()); + } + + private void refreshIdleConsumerGroupOffset() { + Map> consumerGroupsDesc = targetAdminClient + .describeConsumerGroups(consumerGroups).describedGroups(); + + for (String group : consumerGroups) { + try { + ConsumerGroupDescription consumerGroupDesc = consumerGroupsDesc.get(group).get(); + ConsumerGroupState 
consumerGroupState = consumerGroupDesc.state(); + // sync offset to the target cluster only if the state of current consumer group is: + // (1) idle: because the consumer at target is not actively consuming the mirrored topic + // (2) dead: the new consumer that is recently created at source and never exist at target + if (consumerGroupState.equals(ConsumerGroupState.EMPTY)) { + idleConsumerGroupsOffset.put(group, targetAdminClient.listConsumerGroupOffsets(group) + .partitionsToOffsetAndMetadata().get().entrySet().stream().collect( + Collectors.toMap(Entry::getKey, Entry::getValue))); + } + // new consumer upstream has state "DEAD" and will be identified during the offset sync-up + } catch (InterruptedException | ExecutionException e) { + log.error("Error querying for consumer group {} on cluster {}.", group, targetClusterAlias, e); + } + } + } + + Map> syncGroupOffset() { + Map> offsetToSyncAll = new HashMap<>(); + + // first, sync offsets for the idle consumers at target + for (Entry> group : getConvertedUpstreamOffset().entrySet()) { + String consumerGroupId = group.getKey(); + // for each idle consumer at target, read the checkpoints (converted upstream offset) + // from the pre-populated map + Map convertedUpstreamOffset = group.getValue(); + + Map offsetToSync = new HashMap<>(); + Map targetConsumerOffset = idleConsumerGroupsOffset.get(consumerGroupId); + if (targetConsumerOffset == null) { + // this is a new consumer, just sync the offset to target + syncGroupOffset(consumerGroupId, convertedUpstreamOffset); + offsetToSyncAll.put(consumerGroupId, convertedUpstreamOffset); + continue; + } + + for (Entry convertedEntry : convertedUpstreamOffset.entrySet()) { + + TopicPartition topicPartition = convertedEntry.getKey(); + OffsetAndMetadata convertedOffset = convertedUpstreamOffset.get(topicPartition); + if (!targetConsumerOffset.containsKey(topicPartition)) { + // if is a new topicPartition from upstream, just sync the offset to target + offsetToSync.put(topicPartition, convertedOffset); + continue; + } + + // if translated offset from upstream is smaller than the current consumer offset + // in the target, skip updating the offset for that partition + long latestDownstreamOffset = targetConsumerOffset.get(topicPartition).offset(); + if (latestDownstreamOffset >= convertedOffset.offset()) { + log.trace("latestDownstreamOffset {} is larger than or equal to convertedUpstreamOffset {} for " + + "TopicPartition {}", latestDownstreamOffset, convertedOffset.offset(), topicPartition); + continue; + } + offsetToSync.put(topicPartition, convertedOffset); + } + + if (offsetToSync.size() == 0) { + log.trace("skip syncing the offset for consumer group: {}", consumerGroupId); + continue; + } + syncGroupOffset(consumerGroupId, offsetToSync); + + offsetToSyncAll.put(consumerGroupId, offsetToSync); + } + idleConsumerGroupsOffset.clear(); + return offsetToSyncAll; + } + + void syncGroupOffset(String consumerGroupId, Map offsetToSync) { + if (targetAdminClient != null) { + targetAdminClient.alterConsumerGroupOffsets(consumerGroupId, offsetToSync); + log.trace("sync-ed the offset for consumer group: {} with {} number of offset entries", + consumerGroupId, offsetToSync.size()); + } + } + + Map> getConvertedUpstreamOffset() { + Map> result = new HashMap<>(); + + for (Entry> entry : checkpointsPerConsumerGroup.entrySet()) { + String consumerId = entry.getKey(); + Map convertedUpstreamOffset = new HashMap<>(); + for (Checkpoint checkpoint : entry.getValue()) { + 
convertedUpstreamOffset.put(checkpoint.topicPartition(), checkpoint.offsetAndMetadata()); + } + result.put(consumerId, convertedUpstreamOffset); + } + return result; + } +} diff --git a/connect/mirror/src/main/java/org/apache/kafka/connect/mirror/MirrorConnectorConfig.java b/connect/mirror/src/main/java/org/apache/kafka/connect/mirror/MirrorConnectorConfig.java new file mode 100644 index 0000000..c964ab0 --- /dev/null +++ b/connect/mirror/src/main/java/org/apache/kafka/connect/mirror/MirrorConnectorConfig.java @@ -0,0 +1,726 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.mirror; + +import org.apache.kafka.common.config.AbstractConfig; +import org.apache.kafka.common.config.ConfigDef; +import org.apache.kafka.common.config.ConfigDef.ValidString; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.metrics.KafkaMetricsContext; +import org.apache.kafka.common.metrics.MetricsReporter; +import org.apache.kafka.common.metrics.JmxReporter; +import org.apache.kafka.common.metrics.MetricsContext; +import org.apache.kafka.clients.CommonClientConfigs; +import org.apache.kafka.common.utils.ConfigUtils; +import org.apache.kafka.connect.runtime.ConnectorConfig; +import static org.apache.kafka.clients.consumer.ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG; +import static org.apache.kafka.clients.consumer.ConsumerConfig.AUTO_OFFSET_RESET_CONFIG; + +import java.util.Map; +import java.util.HashMap; +import java.util.List; +import java.util.stream.Collectors; +import java.time.Duration; + +/** Shared config properties used by MirrorSourceConnector, MirrorCheckpointConnector, and MirrorHeartbeatConnector. + *
 + * Generally, these properties are filled-in automatically by MirrorMaker based on a top-level mm2.properties file.
 + * However, when running MM2 connectors as plugins on a Connect-as-a-Service cluster, these properties must be configured manually,
 + * e.g. via the Connect REST API.
 + * <p>
 + * An example configuration when running on Connect (not via MirrorMaker driver):
 + * <pre>
            + *      {
            + *        "name": "MirrorSourceConnector",
            + *        "connector.class": "org.apache.kafka.connect.mirror.MirrorSourceConnector",
            + *        "replication.factor": "1",
            + *        "source.cluster.alias": "backup",
            + *        "target.cluster.alias": "primary",
            + *        "source.cluster.bootstrap.servers": "vip1:9092",
            + *        "target.cluster.bootstrap.servers": "vip2:9092",
            + *        "topics": ".*test-topic-.*",
            + *        "groups": "consumer-group-.*",
            + *        "emit.checkpoints.interval.seconds": "1",
            + *        "emit.heartbeats.interval.seconds": "1",
            + *        "sync.topic.acls.enabled": "false"
            + *      }
 + *  </pre>
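 + * <p>
 + * As an illustrative sketch (values are placeholders, not defaults), a checkpoint flow for the same cluster
 + * pair can be configured analogously by swapping in MirrorCheckpointConnector and filtering on groups:
 + * <pre>
 + *      {
 + *        "name": "MirrorCheckpointConnector",
 + *        "connector.class": "org.apache.kafka.connect.mirror.MirrorCheckpointConnector",
 + *        "source.cluster.alias": "backup",
 + *        "target.cluster.alias": "primary",
 + *        "source.cluster.bootstrap.servers": "vip1:9092",
 + *        "target.cluster.bootstrap.servers": "vip2:9092",
 + *        "groups": "consumer-group-.*",
 + *        "emit.checkpoints.interval.seconds": "1"
 + *      }
 + *  </pre>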
            + */ +public class MirrorConnectorConfig extends AbstractConfig { + + protected static final String ENABLED_SUFFIX = ".enabled"; + protected static final String INTERVAL_SECONDS_SUFFIX = ".interval.seconds"; + + protected static final String REFRESH_TOPICS = "refresh.topics"; + protected static final String REFRESH_GROUPS = "refresh.groups"; + protected static final String SYNC_TOPIC_CONFIGS = "sync.topic.configs"; + protected static final String SYNC_TOPIC_ACLS = "sync.topic.acls"; + protected static final String EMIT_HEARTBEATS = "emit.heartbeats"; + protected static final String EMIT_CHECKPOINTS = "emit.checkpoints"; + protected static final String SYNC_GROUP_OFFSETS = "sync.group.offsets"; + + public static final String ENABLED = "enabled"; + private static final String ENABLED_DOC = "Whether to replicate source->target."; + public static final String SOURCE_CLUSTER_ALIAS = "source.cluster.alias"; + public static final String SOURCE_CLUSTER_ALIAS_DEFAULT = "source"; + private static final String SOURCE_CLUSTER_ALIAS_DOC = "Alias of source cluster"; + public static final String TARGET_CLUSTER_ALIAS = "target.cluster.alias"; + public static final String TARGET_CLUSTER_ALIAS_DEFAULT = "target"; + private static final String TARGET_CLUSTER_ALIAS_DOC = "Alias of target cluster. Used in metrics reporting."; + public static final String REPLICATION_POLICY_CLASS = MirrorClientConfig.REPLICATION_POLICY_CLASS; + public static final Class REPLICATION_POLICY_CLASS_DEFAULT = MirrorClientConfig.REPLICATION_POLICY_CLASS_DEFAULT; + private static final String REPLICATION_POLICY_CLASS_DOC = "Class which defines the remote topic naming convention."; + public static final String REPLICATION_POLICY_SEPARATOR = MirrorClientConfig.REPLICATION_POLICY_SEPARATOR; + private static final String REPLICATION_POLICY_SEPARATOR_DOC = "Separator used in remote topic naming convention."; + public static final String REPLICATION_POLICY_SEPARATOR_DEFAULT = + MirrorClientConfig.REPLICATION_POLICY_SEPARATOR_DEFAULT; + public static final String REPLICATION_FACTOR = "replication.factor"; + private static final String REPLICATION_FACTOR_DOC = "Replication factor for newly created remote topics."; + public static final int REPLICATION_FACTOR_DEFAULT = 2; + public static final String TOPICS = DefaultTopicFilter.TOPICS_INCLUDE_CONFIG; + public static final String TOPICS_DEFAULT = DefaultTopicFilter.TOPICS_INCLUDE_DEFAULT; + private static final String TOPICS_DOC = "Topics to replicate. Supports comma-separated topic names and regexes."; + public static final String TOPICS_EXCLUDE = DefaultTopicFilter.TOPICS_EXCLUDE_CONFIG; + public static final String TOPICS_EXCLUDE_ALIAS = DefaultTopicFilter.TOPICS_EXCLUDE_CONFIG_ALIAS; + public static final String TOPICS_EXCLUDE_DEFAULT = DefaultTopicFilter.TOPICS_EXCLUDE_DEFAULT; + private static final String TOPICS_EXCLUDE_DOC = "Excluded topics. Supports comma-separated topic names and regexes." + + " Excludes take precedence over includes."; + public static final String GROUPS = DefaultGroupFilter.GROUPS_INCLUDE_CONFIG; + public static final String GROUPS_DEFAULT = DefaultGroupFilter.GROUPS_INCLUDE_DEFAULT; + private static final String GROUPS_DOC = "Consumer groups to replicate. 
Supports comma-separated group IDs and regexes."; + public static final String GROUPS_EXCLUDE = DefaultGroupFilter.GROUPS_EXCLUDE_CONFIG; + public static final String GROUPS_EXCLUDE_ALIAS = DefaultGroupFilter.GROUPS_EXCLUDE_CONFIG_ALIAS; + + public static final String GROUPS_EXCLUDE_DEFAULT = DefaultGroupFilter.GROUPS_EXCLUDE_DEFAULT; + private static final String GROUPS_EXCLUDE_DOC = "Exclude groups. Supports comma-separated group IDs and regexes." + + " Excludes take precedence over includes."; + public static final String CONFIG_PROPERTIES_EXCLUDE = DefaultConfigPropertyFilter.CONFIG_PROPERTIES_EXCLUDE_CONFIG; + public static final String CONFIG_PROPERTIES_EXCLUDE_ALIAS = DefaultConfigPropertyFilter.CONFIG_PROPERTIES_EXCLUDE_ALIAS_CONFIG; + public static final String CONFIG_PROPERTIES_EXCLUDE_DEFAULT = DefaultConfigPropertyFilter.CONFIG_PROPERTIES_EXCLUDE_DEFAULT; + private static final String CONFIG_PROPERTIES_EXCLUDE_DOC = "Topic config properties that should not be replicated. Supports " + + "comma-separated property names and regexes."; + + public static final String HEARTBEATS_TOPIC_REPLICATION_FACTOR = "heartbeats.topic.replication.factor"; + public static final String HEARTBEATS_TOPIC_REPLICATION_FACTOR_DOC = "Replication factor for heartbeats topic."; + public static final short HEARTBEATS_TOPIC_REPLICATION_FACTOR_DEFAULT = 3; + + public static final String CHECKPOINTS_TOPIC_REPLICATION_FACTOR = "checkpoints.topic.replication.factor"; + public static final String CHECKPOINTS_TOPIC_REPLICATION_FACTOR_DOC = "Replication factor for checkpoints topic."; + public static final short CHECKPOINTS_TOPIC_REPLICATION_FACTOR_DEFAULT = 3; + + public static final String OFFSET_SYNCS_TOPIC_REPLICATION_FACTOR = "offset-syncs.topic.replication.factor"; + public static final String OFFSET_SYNCS_TOPIC_REPLICATION_FACTOR_DOC = "Replication factor for offset-syncs topic."; + public static final short OFFSET_SYNCS_TOPIC_REPLICATION_FACTOR_DEFAULT = 3; + + protected static final String TASK_TOPIC_PARTITIONS = "task.assigned.partitions"; + protected static final String TASK_CONSUMER_GROUPS = "task.assigned.groups"; + + public static final String CONSUMER_POLL_TIMEOUT_MILLIS = "consumer.poll.timeout.ms"; + private static final String CONSUMER_POLL_TIMEOUT_MILLIS_DOC = "Timeout when polling source cluster."; + public static final long CONSUMER_POLL_TIMEOUT_MILLIS_DEFAULT = 1000L; + + public static final String ADMIN_TASK_TIMEOUT_MILLIS = "admin.timeout.ms"; + private static final String ADMIN_TASK_TIMEOUT_MILLIS_DOC = "Timeout for administrative tasks, e.g. 
detecting new topics."; + public static final long ADMIN_TASK_TIMEOUT_MILLIS_DEFAULT = 60000L; + + public static final String REFRESH_TOPICS_ENABLED = REFRESH_TOPICS + ENABLED_SUFFIX; + private static final String REFRESH_TOPICS_ENABLED_DOC = "Whether to periodically check for new topics and partitions."; + public static final boolean REFRESH_TOPICS_ENABLED_DEFAULT = true; + public static final String REFRESH_TOPICS_INTERVAL_SECONDS = REFRESH_TOPICS + INTERVAL_SECONDS_SUFFIX; + private static final String REFRESH_TOPICS_INTERVAL_SECONDS_DOC = "Frequency of topic refresh."; + public static final long REFRESH_TOPICS_INTERVAL_SECONDS_DEFAULT = 10 * 60; + + public static final String REFRESH_GROUPS_ENABLED = REFRESH_GROUPS + ENABLED_SUFFIX; + private static final String REFRESH_GROUPS_ENABLED_DOC = "Whether to periodically check for new consumer groups."; + public static final boolean REFRESH_GROUPS_ENABLED_DEFAULT = true; + public static final String REFRESH_GROUPS_INTERVAL_SECONDS = REFRESH_GROUPS + INTERVAL_SECONDS_SUFFIX; + private static final String REFRESH_GROUPS_INTERVAL_SECONDS_DOC = "Frequency of group refresh."; + public static final long REFRESH_GROUPS_INTERVAL_SECONDS_DEFAULT = 10 * 60; + + public static final String SYNC_TOPIC_CONFIGS_ENABLED = SYNC_TOPIC_CONFIGS + ENABLED_SUFFIX; + private static final String SYNC_TOPIC_CONFIGS_ENABLED_DOC = "Whether to periodically configure remote topics to match their corresponding upstream topics."; + public static final boolean SYNC_TOPIC_CONFIGS_ENABLED_DEFAULT = true; + public static final String SYNC_TOPIC_CONFIGS_INTERVAL_SECONDS = SYNC_TOPIC_CONFIGS + INTERVAL_SECONDS_SUFFIX; + private static final String SYNC_TOPIC_CONFIGS_INTERVAL_SECONDS_DOC = "Frequency of topic config sync."; + public static final long SYNC_TOPIC_CONFIGS_INTERVAL_SECONDS_DEFAULT = 10 * 60; + + public static final String SYNC_TOPIC_ACLS_ENABLED = SYNC_TOPIC_ACLS + ENABLED_SUFFIX; + private static final String SYNC_TOPIC_ACLS_ENABLED_DOC = "Whether to periodically configure remote topic ACLs to match their corresponding upstream topics."; + public static final boolean SYNC_TOPIC_ACLS_ENABLED_DEFAULT = true; + public static final String SYNC_TOPIC_ACLS_INTERVAL_SECONDS = SYNC_TOPIC_ACLS + INTERVAL_SECONDS_SUFFIX; + private static final String SYNC_TOPIC_ACLS_INTERVAL_SECONDS_DOC = "Frequency of topic ACL sync."; + public static final long SYNC_TOPIC_ACLS_INTERVAL_SECONDS_DEFAULT = 10 * 60; + + public static final String EMIT_HEARTBEATS_ENABLED = EMIT_HEARTBEATS + ENABLED_SUFFIX; + private static final String EMIT_HEARTBEATS_ENABLED_DOC = "Whether to emit heartbeats to target cluster."; + public static final boolean EMIT_HEARTBEATS_ENABLED_DEFAULT = true; + public static final String EMIT_HEARTBEATS_INTERVAL_SECONDS = EMIT_HEARTBEATS + INTERVAL_SECONDS_SUFFIX; + private static final String EMIT_HEARTBEATS_INTERVAL_SECONDS_DOC = "Frequency of heartbeats."; + public static final long EMIT_HEARTBEATS_INTERVAL_SECONDS_DEFAULT = 1; + + public static final String EMIT_CHECKPOINTS_ENABLED = EMIT_CHECKPOINTS + ENABLED_SUFFIX; + private static final String EMIT_CHECKPOINTS_ENABLED_DOC = "Whether to replicate consumer offsets to target cluster."; + public static final boolean EMIT_CHECKPOINTS_ENABLED_DEFAULT = true; + public static final String EMIT_CHECKPOINTS_INTERVAL_SECONDS = EMIT_CHECKPOINTS + INTERVAL_SECONDS_SUFFIX; + private static final String EMIT_CHECKPOINTS_INTERVAL_SECONDS_DOC = "Frequency of checkpoints."; + public static final long 
EMIT_CHECKPOINTS_INTERVAL_SECONDS_DEFAULT = 60; + + + public static final String SYNC_GROUP_OFFSETS_ENABLED = SYNC_GROUP_OFFSETS + ENABLED_SUFFIX; + private static final String SYNC_GROUP_OFFSETS_ENABLED_DOC = "Whether to periodically write the translated offsets to __consumer_offsets topic in target cluster, as long as no active consumers in that group are connected to the target cluster"; + public static final boolean SYNC_GROUP_OFFSETS_ENABLED_DEFAULT = false; + public static final String SYNC_GROUP_OFFSETS_INTERVAL_SECONDS = SYNC_GROUP_OFFSETS + INTERVAL_SECONDS_SUFFIX; + private static final String SYNC_GROUP_OFFSETS_INTERVAL_SECONDS_DOC = "Frequency of consumer group offset sync."; + public static final long SYNC_GROUP_OFFSETS_INTERVAL_SECONDS_DEFAULT = 60; + + public static final String TOPIC_FILTER_CLASS = "topic.filter.class"; + private static final String TOPIC_FILTER_CLASS_DOC = "TopicFilter to use. Selects topics to replicate."; + public static final Class TOPIC_FILTER_CLASS_DEFAULT = DefaultTopicFilter.class; + public static final String GROUP_FILTER_CLASS = "group.filter.class"; + private static final String GROUP_FILTER_CLASS_DOC = "GroupFilter to use. Selects consumer groups to replicate."; + public static final Class GROUP_FILTER_CLASS_DEFAULT = DefaultGroupFilter.class; + public static final String CONFIG_PROPERTY_FILTER_CLASS = "config.property.filter.class"; + private static final String CONFIG_PROPERTY_FILTER_CLASS_DOC = "ConfigPropertyFilter to use. Selects topic config " + + " properties to replicate."; + public static final Class CONFIG_PROPERTY_FILTER_CLASS_DEFAULT = DefaultConfigPropertyFilter.class; + + public static final String OFFSET_LAG_MAX = "offset.lag.max"; + private static final String OFFSET_LAG_MAX_DOC = "How out-of-sync a remote partition can be before it is resynced."; + public static final long OFFSET_LAG_MAX_DEFAULT = 100L; + + private static final String OFFSET_SYNCS_TOPIC_LOCATION = "offset-syncs.topic.location"; + private static final String OFFSET_SYNCS_TOPIC_LOCATION_DEFAULT = SOURCE_CLUSTER_ALIAS_DEFAULT; + private static final String OFFSET_SYNCS_TOPIC_LOCATION_DOC = "The location (source/target) of the offset-syncs topic."; + + protected static final String SOURCE_CLUSTER_PREFIX = MirrorMakerConfig.SOURCE_CLUSTER_PREFIX; + protected static final String TARGET_CLUSTER_PREFIX = MirrorMakerConfig.TARGET_CLUSTER_PREFIX; + protected static final String SOURCE_PREFIX = MirrorMakerConfig.SOURCE_PREFIX; + protected static final String TARGET_PREFIX = MirrorMakerConfig.TARGET_PREFIX; + protected static final String PRODUCER_CLIENT_PREFIX = "producer."; + protected static final String CONSUMER_CLIENT_PREFIX = "consumer."; + protected static final String ADMIN_CLIENT_PREFIX = "admin."; + + public MirrorConnectorConfig(Map props) { + this(CONNECTOR_CONFIG_DEF, ConfigUtils.translateDeprecatedConfigs(props, new String[][]{ + {TOPICS_EXCLUDE, TOPICS_EXCLUDE_ALIAS}, + {GROUPS_EXCLUDE, GROUPS_EXCLUDE_ALIAS}, + {CONFIG_PROPERTIES_EXCLUDE, CONFIG_PROPERTIES_EXCLUDE_ALIAS}})); + } + + protected MirrorConnectorConfig(ConfigDef configDef, Map props) { + super(configDef, props, true); + } + + String connectorName() { + return getString(ConnectorConfig.NAME_CONFIG); + } + + boolean enabled() { + return getBoolean(ENABLED); + } + + Duration consumerPollTimeout() { + return Duration.ofMillis(getLong(CONSUMER_POLL_TIMEOUT_MILLIS)); + } + + Duration adminTimeout() { + return Duration.ofMillis(getLong(ADMIN_TASK_TIMEOUT_MILLIS)); + } + + Map sourceProducerConfig() { + 
Map props = new HashMap<>(); + props.putAll(originalsWithPrefix(SOURCE_CLUSTER_PREFIX)); + props.keySet().retainAll(MirrorClientConfig.CLIENT_CONFIG_DEF.names()); + props.putAll(originalsWithPrefix(PRODUCER_CLIENT_PREFIX)); + props.putAll(originalsWithPrefix(SOURCE_PREFIX + PRODUCER_CLIENT_PREFIX)); + return props; + } + + Map sourceConsumerConfig() { + Map props = new HashMap<>(); + props.putAll(originalsWithPrefix(SOURCE_CLUSTER_PREFIX)); + props.keySet().retainAll(MirrorClientConfig.CLIENT_CONFIG_DEF.names()); + props.putAll(originalsWithPrefix(CONSUMER_CLIENT_PREFIX)); + props.putAll(originalsWithPrefix(SOURCE_PREFIX + CONSUMER_CLIENT_PREFIX)); + props.put(ENABLE_AUTO_COMMIT_CONFIG, "false"); + props.putIfAbsent(AUTO_OFFSET_RESET_CONFIG, "earliest"); + return props; + } + + Map taskConfigForTopicPartitions(List topicPartitions) { + Map props = originalsStrings(); + String topicPartitionsString = topicPartitions.stream() + .map(MirrorUtils::encodeTopicPartition) + .collect(Collectors.joining(",")); + props.put(TASK_TOPIC_PARTITIONS, topicPartitionsString); + return props; + } + + Map taskConfigForConsumerGroups(List groups) { + Map props = originalsStrings(); + props.put(TASK_CONSUMER_GROUPS, String.join(",", groups)); + return props; + } + + Map targetAdminConfig() { + Map props = new HashMap<>(); + props.putAll(originalsWithPrefix(TARGET_CLUSTER_PREFIX)); + props.keySet().retainAll(MirrorClientConfig.CLIENT_CONFIG_DEF.names()); + props.putAll(originalsWithPrefix(ADMIN_CLIENT_PREFIX)); + props.putAll(originalsWithPrefix(TARGET_PREFIX + ADMIN_CLIENT_PREFIX)); + return props; + } + + Map targetProducerConfig() { + Map props = new HashMap<>(); + props.putAll(originalsWithPrefix(TARGET_CLUSTER_PREFIX)); + props.keySet().retainAll(MirrorClientConfig.CLIENT_CONFIG_DEF.names()); + props.putAll(originalsWithPrefix(PRODUCER_CLIENT_PREFIX)); + props.putAll(originalsWithPrefix(TARGET_PREFIX + PRODUCER_CLIENT_PREFIX)); + return props; + } + + Map targetConsumerConfig() { + Map props = new HashMap<>(); + props.putAll(originalsWithPrefix(TARGET_CLUSTER_PREFIX)); + props.keySet().retainAll(MirrorClientConfig.CLIENT_CONFIG_DEF.names()); + props.putAll(originalsWithPrefix(CONSUMER_CLIENT_PREFIX)); + props.putAll(originalsWithPrefix(TARGET_PREFIX + CONSUMER_CLIENT_PREFIX)); + props.put(ENABLE_AUTO_COMMIT_CONFIG, "false"); + props.putIfAbsent(AUTO_OFFSET_RESET_CONFIG, "earliest"); + return props; + } + + Map sourceAdminConfig() { + Map props = new HashMap<>(); + props.putAll(originalsWithPrefix(SOURCE_CLUSTER_PREFIX)); + props.keySet().retainAll(MirrorClientConfig.CLIENT_CONFIG_DEF.names()); + props.putAll(originalsWithPrefix(ADMIN_CLIENT_PREFIX)); + props.putAll(originalsWithPrefix(SOURCE_PREFIX + ADMIN_CLIENT_PREFIX)); + return props; + } + + List metricsReporters() { + List reporters = getConfiguredInstances( + CommonClientConfigs.METRIC_REPORTER_CLASSES_CONFIG, MetricsReporter.class); + JmxReporter jmxReporter = new JmxReporter(); + jmxReporter.configure(this.originals()); + reporters.add(jmxReporter); + MetricsContext metricsContext = new KafkaMetricsContext("kafka.connect.mirror"); + + for (MetricsReporter reporter : reporters) { + reporter.contextChange(metricsContext); + } + + return reporters; + } + + String sourceClusterAlias() { + return getString(SOURCE_CLUSTER_ALIAS); + } + + String targetClusterAlias() { + return getString(TARGET_CLUSTER_ALIAS); + } + + String offsetSyncsTopic() { + String otherClusterAlias = SOURCE_CLUSTER_ALIAS_DEFAULT.equals(offsetSyncsTopicLocation()) + ? 
targetClusterAlias() + : sourceClusterAlias(); + return replicationPolicy().offsetSyncsTopic(otherClusterAlias); + } + + String offsetSyncsTopicLocation() { + return getString(OFFSET_SYNCS_TOPIC_LOCATION); + } + + Map offsetSyncsTopicAdminConfig() { + return SOURCE_CLUSTER_ALIAS_DEFAULT.equals(offsetSyncsTopicLocation()) + ? sourceAdminConfig() + : targetAdminConfig(); + } + + Map offsetSyncsTopicProducerConfig() { + return SOURCE_CLUSTER_ALIAS_DEFAULT.equals(offsetSyncsTopicLocation()) + ? sourceProducerConfig() + : targetProducerConfig(); + } + + Map offsetSyncsTopicConsumerConfig() { + return SOURCE_CLUSTER_ALIAS_DEFAULT.equals(offsetSyncsTopicLocation()) + ? sourceConsumerConfig() + : targetConsumerConfig(); + } + + String heartbeatsTopic() { + return replicationPolicy().heartbeatsTopic(); + } + + // e.g. source1.heartbeats + String targetHeartbeatsTopic() { + return replicationPolicy().formatRemoteTopic(sourceClusterAlias(), heartbeatsTopic()); + } + + String checkpointsTopic() { + return replicationPolicy().checkpointsTopic(sourceClusterAlias()); + } + + long maxOffsetLag() { + return getLong(OFFSET_LAG_MAX); + } + + Duration emitHeartbeatsInterval() { + if (getBoolean(EMIT_HEARTBEATS_ENABLED)) { + return Duration.ofSeconds(getLong(EMIT_HEARTBEATS_INTERVAL_SECONDS)); + } else { + // negative interval to disable + return Duration.ofMillis(-1); + } + } + + Duration emitCheckpointsInterval() { + if (getBoolean(EMIT_CHECKPOINTS_ENABLED)) { + return Duration.ofSeconds(getLong(EMIT_CHECKPOINTS_INTERVAL_SECONDS)); + } else { + // negative interval to disable + return Duration.ofMillis(-1); + } + } + + Duration refreshTopicsInterval() { + if (getBoolean(REFRESH_TOPICS_ENABLED)) { + return Duration.ofSeconds(getLong(REFRESH_TOPICS_INTERVAL_SECONDS)); + } else { + // negative interval to disable + return Duration.ofMillis(-1); + } + } + + Duration refreshGroupsInterval() { + if (getBoolean(REFRESH_GROUPS_ENABLED)) { + return Duration.ofSeconds(getLong(REFRESH_GROUPS_INTERVAL_SECONDS)); + } else { + // negative interval to disable + return Duration.ofMillis(-1); + } + } + + Duration syncTopicConfigsInterval() { + if (getBoolean(SYNC_TOPIC_CONFIGS_ENABLED)) { + return Duration.ofSeconds(getLong(SYNC_TOPIC_CONFIGS_INTERVAL_SECONDS)); + } else { + // negative interval to disable + return Duration.ofMillis(-1); + } + } + + Duration syncTopicAclsInterval() { + if (getBoolean(SYNC_TOPIC_ACLS_ENABLED)) { + return Duration.ofSeconds(getLong(SYNC_TOPIC_ACLS_INTERVAL_SECONDS)); + } else { + // negative interval to disable + return Duration.ofMillis(-1); + } + } + + ReplicationPolicy replicationPolicy() { + return getConfiguredInstance(REPLICATION_POLICY_CLASS, ReplicationPolicy.class); + } + + int replicationFactor() { + return getInt(REPLICATION_FACTOR); + } + + short heartbeatsTopicReplicationFactor() { + return getShort(HEARTBEATS_TOPIC_REPLICATION_FACTOR); + } + + short checkpointsTopicReplicationFactor() { + return getShort(CHECKPOINTS_TOPIC_REPLICATION_FACTOR); + } + + short offsetSyncsTopicReplicationFactor() { + return getShort(OFFSET_SYNCS_TOPIC_REPLICATION_FACTOR); + } + + TopicFilter topicFilter() { + return getConfiguredInstance(TOPIC_FILTER_CLASS, TopicFilter.class); + } + + GroupFilter groupFilter() { + return getConfiguredInstance(GROUP_FILTER_CLASS, GroupFilter.class); + } + + ConfigPropertyFilter configPropertyFilter() { + return getConfiguredInstance(CONFIG_PROPERTY_FILTER_CLASS, ConfigPropertyFilter.class); + } + + Duration syncGroupOffsetsInterval() { + if 
(getBoolean(SYNC_GROUP_OFFSETS_ENABLED)) { + return Duration.ofSeconds(getLong(SYNC_GROUP_OFFSETS_INTERVAL_SECONDS)); + } else { + // negative interval to disable + return Duration.ofMillis(-1); + } + } + + protected static final ConfigDef CONNECTOR_CONFIG_DEF = ConnectorConfig.configDef() + .define( + ENABLED, + ConfigDef.Type.BOOLEAN, + true, + ConfigDef.Importance.LOW, + ENABLED_DOC) + .define( + TOPICS, + ConfigDef.Type.LIST, + TOPICS_DEFAULT, + ConfigDef.Importance.HIGH, + TOPICS_DOC) + .define( + TOPICS_EXCLUDE, + ConfigDef.Type.LIST, + TOPICS_EXCLUDE_DEFAULT, + ConfigDef.Importance.HIGH, + TOPICS_EXCLUDE_DOC) + .define( + TOPICS_EXCLUDE_ALIAS, + ConfigDef.Type.LIST, + null, + ConfigDef.Importance.HIGH, + "Deprecated. Use " + TOPICS_EXCLUDE + " instead.") + .define( + GROUPS, + ConfigDef.Type.LIST, + GROUPS_DEFAULT, + ConfigDef.Importance.HIGH, + GROUPS_DOC) + .define( + GROUPS_EXCLUDE, + ConfigDef.Type.LIST, + GROUPS_EXCLUDE_DEFAULT, + ConfigDef.Importance.HIGH, + GROUPS_EXCLUDE_DOC) + .define( + GROUPS_EXCLUDE_ALIAS, + ConfigDef.Type.LIST, + null, + ConfigDef.Importance.HIGH, + "Deprecated. Use " + GROUPS_EXCLUDE + " instead.") + .define( + CONFIG_PROPERTIES_EXCLUDE, + ConfigDef.Type.LIST, + CONFIG_PROPERTIES_EXCLUDE_DEFAULT, + ConfigDef.Importance.HIGH, + CONFIG_PROPERTIES_EXCLUDE_DOC) + .define( + CONFIG_PROPERTIES_EXCLUDE_ALIAS, + ConfigDef.Type.LIST, + null, + ConfigDef.Importance.HIGH, + "Deprecated. Use " + CONFIG_PROPERTIES_EXCLUDE + " instead.") + .define( + TOPIC_FILTER_CLASS, + ConfigDef.Type.CLASS, + TOPIC_FILTER_CLASS_DEFAULT, + ConfigDef.Importance.LOW, + TOPIC_FILTER_CLASS_DOC) + .define( + GROUP_FILTER_CLASS, + ConfigDef.Type.CLASS, + GROUP_FILTER_CLASS_DEFAULT, + ConfigDef.Importance.LOW, + GROUP_FILTER_CLASS_DOC) + .define( + CONFIG_PROPERTY_FILTER_CLASS, + ConfigDef.Type.CLASS, + CONFIG_PROPERTY_FILTER_CLASS_DEFAULT, + ConfigDef.Importance.LOW, + CONFIG_PROPERTY_FILTER_CLASS_DOC) + .define( + SOURCE_CLUSTER_ALIAS, + ConfigDef.Type.STRING, + ConfigDef.Importance.HIGH, + SOURCE_CLUSTER_ALIAS_DOC) + .define( + TARGET_CLUSTER_ALIAS, + ConfigDef.Type.STRING, + TARGET_CLUSTER_ALIAS_DEFAULT, + ConfigDef.Importance.HIGH, + TARGET_CLUSTER_ALIAS_DOC) + .define( + CONSUMER_POLL_TIMEOUT_MILLIS, + ConfigDef.Type.LONG, + CONSUMER_POLL_TIMEOUT_MILLIS_DEFAULT, + ConfigDef.Importance.LOW, + CONSUMER_POLL_TIMEOUT_MILLIS_DOC) + .define( + ADMIN_TASK_TIMEOUT_MILLIS, + ConfigDef.Type.LONG, + ADMIN_TASK_TIMEOUT_MILLIS_DEFAULT, + ConfigDef.Importance.LOW, + ADMIN_TASK_TIMEOUT_MILLIS_DOC) + .define( + REFRESH_TOPICS_ENABLED, + ConfigDef.Type.BOOLEAN, + REFRESH_TOPICS_ENABLED_DEFAULT, + ConfigDef.Importance.LOW, + REFRESH_TOPICS_ENABLED_DOC) + .define( + REFRESH_TOPICS_INTERVAL_SECONDS, + ConfigDef.Type.LONG, + REFRESH_TOPICS_INTERVAL_SECONDS_DEFAULT, + ConfigDef.Importance.LOW, + REFRESH_TOPICS_INTERVAL_SECONDS_DOC) + .define( + REFRESH_GROUPS_ENABLED, + ConfigDef.Type.BOOLEAN, + REFRESH_GROUPS_ENABLED_DEFAULT, + ConfigDef.Importance.LOW, + REFRESH_GROUPS_ENABLED_DOC) + .define( + REFRESH_GROUPS_INTERVAL_SECONDS, + ConfigDef.Type.LONG, + REFRESH_GROUPS_INTERVAL_SECONDS_DEFAULT, + ConfigDef.Importance.LOW, + REFRESH_GROUPS_INTERVAL_SECONDS_DOC) + .define( + SYNC_TOPIC_CONFIGS_ENABLED, + ConfigDef.Type.BOOLEAN, + SYNC_TOPIC_CONFIGS_ENABLED_DEFAULT, + ConfigDef.Importance.LOW, + SYNC_TOPIC_CONFIGS_ENABLED_DOC) + .define( + SYNC_TOPIC_CONFIGS_INTERVAL_SECONDS, + ConfigDef.Type.LONG, + SYNC_TOPIC_CONFIGS_INTERVAL_SECONDS_DEFAULT, + ConfigDef.Importance.LOW, + 
SYNC_TOPIC_CONFIGS_INTERVAL_SECONDS_DOC) + .define( + SYNC_TOPIC_ACLS_ENABLED, + ConfigDef.Type.BOOLEAN, + SYNC_TOPIC_ACLS_ENABLED_DEFAULT, + ConfigDef.Importance.LOW, + SYNC_TOPIC_ACLS_ENABLED_DOC) + .define( + SYNC_TOPIC_ACLS_INTERVAL_SECONDS, + ConfigDef.Type.LONG, + SYNC_TOPIC_ACLS_INTERVAL_SECONDS_DEFAULT, + ConfigDef.Importance.LOW, + SYNC_TOPIC_ACLS_INTERVAL_SECONDS_DOC) + .define( + EMIT_HEARTBEATS_ENABLED, + ConfigDef.Type.BOOLEAN, + EMIT_HEARTBEATS_ENABLED_DEFAULT, + ConfigDef.Importance.LOW, + EMIT_HEARTBEATS_ENABLED_DOC) + .define( + EMIT_HEARTBEATS_INTERVAL_SECONDS, + ConfigDef.Type.LONG, + EMIT_HEARTBEATS_INTERVAL_SECONDS_DEFAULT, + ConfigDef.Importance.LOW, + EMIT_HEARTBEATS_INTERVAL_SECONDS_DOC) + .define( + EMIT_CHECKPOINTS_ENABLED, + ConfigDef.Type.BOOLEAN, + EMIT_CHECKPOINTS_ENABLED_DEFAULT, + ConfigDef.Importance.LOW, + EMIT_CHECKPOINTS_ENABLED_DOC) + .define( + EMIT_CHECKPOINTS_INTERVAL_SECONDS, + ConfigDef.Type.LONG, + EMIT_CHECKPOINTS_INTERVAL_SECONDS_DEFAULT, + ConfigDef.Importance.LOW, + EMIT_CHECKPOINTS_INTERVAL_SECONDS_DOC) + .define( + SYNC_GROUP_OFFSETS_ENABLED, + ConfigDef.Type.BOOLEAN, + SYNC_GROUP_OFFSETS_ENABLED_DEFAULT, + ConfigDef.Importance.LOW, + SYNC_GROUP_OFFSETS_ENABLED_DOC) + .define( + SYNC_GROUP_OFFSETS_INTERVAL_SECONDS, + ConfigDef.Type.LONG, + SYNC_GROUP_OFFSETS_INTERVAL_SECONDS_DEFAULT, + ConfigDef.Importance.LOW, + SYNC_GROUP_OFFSETS_INTERVAL_SECONDS_DOC) + .define( + REPLICATION_POLICY_CLASS, + ConfigDef.Type.CLASS, + REPLICATION_POLICY_CLASS_DEFAULT, + ConfigDef.Importance.LOW, + REPLICATION_POLICY_CLASS_DOC) + .define( + REPLICATION_POLICY_SEPARATOR, + ConfigDef.Type.STRING, + REPLICATION_POLICY_SEPARATOR_DEFAULT, + ConfigDef.Importance.LOW, + REPLICATION_POLICY_SEPARATOR_DOC) + .define( + REPLICATION_FACTOR, + ConfigDef.Type.INT, + REPLICATION_FACTOR_DEFAULT, + ConfigDef.Importance.LOW, + REPLICATION_FACTOR_DOC) + .define( + HEARTBEATS_TOPIC_REPLICATION_FACTOR, + ConfigDef.Type.SHORT, + HEARTBEATS_TOPIC_REPLICATION_FACTOR_DEFAULT, + ConfigDef.Importance.LOW, + HEARTBEATS_TOPIC_REPLICATION_FACTOR_DOC) + .define( + CHECKPOINTS_TOPIC_REPLICATION_FACTOR, + ConfigDef.Type.SHORT, + CHECKPOINTS_TOPIC_REPLICATION_FACTOR_DEFAULT, + ConfigDef.Importance.LOW, + CHECKPOINTS_TOPIC_REPLICATION_FACTOR_DOC) + .define( + OFFSET_SYNCS_TOPIC_REPLICATION_FACTOR, + ConfigDef.Type.SHORT, + OFFSET_SYNCS_TOPIC_REPLICATION_FACTOR_DEFAULT, + ConfigDef.Importance.LOW, + OFFSET_SYNCS_TOPIC_REPLICATION_FACTOR_DOC) + .define( + OFFSET_LAG_MAX, + ConfigDef.Type.LONG, + OFFSET_LAG_MAX_DEFAULT, + ConfigDef.Importance.LOW, + OFFSET_LAG_MAX_DOC) + .define( + OFFSET_SYNCS_TOPIC_LOCATION, + ConfigDef.Type.STRING, + OFFSET_SYNCS_TOPIC_LOCATION_DEFAULT, + ValidString.in(SOURCE_CLUSTER_ALIAS_DEFAULT, TARGET_CLUSTER_ALIAS_DEFAULT), + ConfigDef.Importance.LOW, + OFFSET_SYNCS_TOPIC_LOCATION_DOC) + .define( + CommonClientConfigs.METRIC_REPORTER_CLASSES_CONFIG, + ConfigDef.Type.LIST, + null, + ConfigDef.Importance.LOW, + CommonClientConfigs.METRIC_REPORTER_CLASSES_DOC) + .define( + CommonClientConfigs.SECURITY_PROTOCOL_CONFIG, + ConfigDef.Type.STRING, + CommonClientConfigs.DEFAULT_SECURITY_PROTOCOL, + ConfigDef.Importance.MEDIUM, + CommonClientConfigs.SECURITY_PROTOCOL_DOC) + .withClientSslSupport() + .withClientSaslSupport(); +} diff --git a/connect/mirror/src/main/java/org/apache/kafka/connect/mirror/MirrorHeartbeatConnector.java b/connect/mirror/src/main/java/org/apache/kafka/connect/mirror/MirrorHeartbeatConnector.java new file mode 100644 index 0000000..8b2d064 --- /dev/null +++ 
b/connect/mirror/src/main/java/org/apache/kafka/connect/mirror/MirrorHeartbeatConnector.java @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.mirror; + +import org.apache.kafka.connect.connector.Task; +import org.apache.kafka.connect.source.SourceConnector; +import org.apache.kafka.common.config.ConfigDef; +import org.apache.kafka.common.utils.Utils; + +import java.util.Map; +import java.util.List; +import java.util.Collections; + +/** Emits heartbeats to Kafka. + */ +public class MirrorHeartbeatConnector extends SourceConnector { + private MirrorConnectorConfig config; + private Scheduler scheduler; + + public MirrorHeartbeatConnector() { + // nop + } + + // visible for testing + MirrorHeartbeatConnector(MirrorConnectorConfig config) { + this.config = config; + } + + @Override + public void start(Map props) { + config = new MirrorConnectorConfig(props); + scheduler = new Scheduler(MirrorHeartbeatConnector.class, config.adminTimeout()); + scheduler.execute(this::createInternalTopics, "creating internal topics"); + } + + @Override + public void stop() { + Utils.closeQuietly(scheduler, "scheduler"); + } + + @Override + public Class taskClass() { + return MirrorHeartbeatTask.class; + } + + @Override + public List> taskConfigs(int maxTasks) { + // if the heartbeats emission is disabled by setting `emit.heartbeats.enabled` to `false`, + // the interval heartbeat emission will be negative and no `MirrorHeartbeatTask` will be created + if (config.emitHeartbeatsInterval().isNegative()) { + return Collections.emptyList(); + } + // just need a single task + return Collections.singletonList(config.originalsStrings()); + } + + @Override + public ConfigDef config() { + return MirrorConnectorConfig.CONNECTOR_CONFIG_DEF; + } + + @Override + public String version() { + return "1"; + } + + private void createInternalTopics() { + MirrorUtils.createSinglePartitionCompactedTopic(config.heartbeatsTopic(), + config.heartbeatsTopicReplicationFactor(), config.targetAdminConfig()); + } +} diff --git a/connect/mirror/src/main/java/org/apache/kafka/connect/mirror/MirrorHeartbeatTask.java b/connect/mirror/src/main/java/org/apache/kafka/connect/mirror/MirrorHeartbeatTask.java new file mode 100644 index 0000000..9f38b59 --- /dev/null +++ b/connect/mirror/src/main/java/org/apache/kafka/connect/mirror/MirrorHeartbeatTask.java @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.mirror; + +import org.apache.kafka.clients.producer.RecordMetadata; +import org.apache.kafka.connect.source.SourceTask; +import org.apache.kafka.connect.source.SourceRecord; +import org.apache.kafka.connect.data.Schema; + +import java.util.Map; +import java.util.List; +import java.util.Collections; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.time.Duration; + +/** Emits heartbeats. */ +public class MirrorHeartbeatTask extends SourceTask { + private String sourceClusterAlias; + private String targetClusterAlias; + private String heartbeatsTopic; + private Duration interval; + private CountDownLatch stopped; + + @Override + public void start(Map props) { + stopped = new CountDownLatch(1); + MirrorTaskConfig config = new MirrorTaskConfig(props); + sourceClusterAlias = config.sourceClusterAlias(); + targetClusterAlias = config.targetClusterAlias(); + heartbeatsTopic = config.heartbeatsTopic(); + interval = config.emitHeartbeatsInterval(); + } + + @Override + public void commit() { + // nop + } + + @Override + public void stop() { + stopped.countDown(); + } + + @Override + public String version() { + return "1"; + } + + @Override + public List poll() throws InterruptedException { + // pause to throttle, unless we've stopped + if (stopped.await(interval.toMillis(), TimeUnit.MILLISECONDS)) { + // SourceWorkerTask expects non-zero batches or null + return null; + } + long timestamp = System.currentTimeMillis(); + Heartbeat heartbeat = new Heartbeat(sourceClusterAlias, targetClusterAlias, timestamp); + SourceRecord record = new SourceRecord( + heartbeat.connectPartition(), MirrorUtils.wrapOffset(0), + heartbeatsTopic, 0, + Schema.BYTES_SCHEMA, heartbeat.recordKey(), + Schema.BYTES_SCHEMA, heartbeat.recordValue(), + timestamp); + return Collections.singletonList(record); + } + + @Override + public void commitRecord(SourceRecord record, RecordMetadata metadata) { + } +} diff --git a/connect/mirror/src/main/java/org/apache/kafka/connect/mirror/MirrorMaker.java b/connect/mirror/src/main/java/org/apache/kafka/connect/mirror/MirrorMaker.java new file mode 100644 index 0000000..7ac6831 --- /dev/null +++ b/connect/mirror/src/main/java/org/apache/kafka/connect/mirror/MirrorMaker.java @@ -0,0 +1,318 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.connect.mirror; + +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.common.utils.Exit; +import org.apache.kafka.connect.runtime.Herder; +import org.apache.kafka.connect.runtime.isolation.Plugins; +import org.apache.kafka.connect.runtime.Worker; +import org.apache.kafka.connect.runtime.WorkerConfigTransformer; +import org.apache.kafka.connect.runtime.distributed.DistributedConfig; +import org.apache.kafka.connect.runtime.distributed.DistributedHerder; +import org.apache.kafka.connect.runtime.distributed.NotLeaderException; +import org.apache.kafka.connect.storage.KafkaOffsetBackingStore; +import org.apache.kafka.connect.storage.StatusBackingStore; +import org.apache.kafka.connect.storage.KafkaStatusBackingStore; +import org.apache.kafka.connect.storage.ConfigBackingStore; +import org.apache.kafka.connect.storage.KafkaConfigBackingStore; +import org.apache.kafka.connect.storage.Converter; +import org.apache.kafka.connect.util.ConnectUtils; +import org.apache.kafka.connect.connector.policy.AllConnectorClientConfigOverridePolicy; +import org.apache.kafka.connect.connector.policy.ConnectorClientConfigOverridePolicy; + +import org.apache.kafka.connect.util.SharedTopicAdmin; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import net.sourceforge.argparse4j.impl.Arguments; +import net.sourceforge.argparse4j.inf.Namespace; +import net.sourceforge.argparse4j.inf.ArgumentParser; +import net.sourceforge.argparse4j.inf.ArgumentParserException; +import net.sourceforge.argparse4j.ArgumentParsers; + +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.Map; +import java.util.HashMap; +import java.util.Set; +import java.util.HashSet; +import java.util.List; +import java.util.Arrays; +import java.util.Properties; +import java.util.stream.Collectors; +import java.io.File; + +/** + * Entry point for "MirrorMaker 2.0". + *
+ * MirrorMaker runs a set of Connectors between multiple clusters, in order to replicate data, configuration,
+ * ACL rules, and consumer group state.
+ * <p>
+ * Configuration is via a top-level "mm2.properties" file, which supports per-cluster and per-replication
+ * sub-configs. Each source->target replication must be explicitly enabled. For example:
+ * <pre>
+ *    clusters = primary, backup
+ *    primary.bootstrap.servers = vip1:9092
+ *    backup.bootstrap.servers = vip2:9092
+ *    primary->backup.enabled = true
+ *    backup->primary.enabled = true
+ *  </pre>
+ * <p>
+ * Run as follows:
+ * <pre>
+ *    ./bin/connect-mirror-maker.sh mm2.properties
+ *  </pre>
+ * <p>
+ * Additional information and example configurations are provided in ./connect/mirror/README.md
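+ * <p>
+ * As a further illustration (assuming the "primary" and "backup" aliases above), a single node can be limited to
+ * driving replication into particular target clusters via the optional "--clusters" argument accepted by the driver:
+ * <pre>
+ *    ./bin/connect-mirror-maker.sh mm2.properties --clusters backup
+ *  </pre>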
            + */ +public class MirrorMaker { + private static final Logger log = LoggerFactory.getLogger(MirrorMaker.class); + + private static final long SHUTDOWN_TIMEOUT_SECONDS = 60L; + private static final ConnectorClientConfigOverridePolicy CLIENT_CONFIG_OVERRIDE_POLICY = + new AllConnectorClientConfigOverridePolicy(); + + private static final List> CONNECTOR_CLASSES = Arrays.asList( + MirrorSourceConnector.class, + MirrorHeartbeatConnector.class, + MirrorCheckpointConnector.class); + + private final Map herders = new HashMap<>(); + private CountDownLatch startLatch; + private CountDownLatch stopLatch; + private final AtomicBoolean shutdown = new AtomicBoolean(false); + private final ShutdownHook shutdownHook; + private final String advertisedBaseUrl; + private final Time time; + private final MirrorMakerConfig config; + private final Set clusters; + private final Set herderPairs; + + /** + * @param config MM2 configuration from mm2.properties file + * @param clusters target clusters for this node. These must match cluster + * aliases as defined in the config. If null or empty list, + * uses all clusters in the config. + * @param time time source + */ + public MirrorMaker(MirrorMakerConfig config, List clusters, Time time) { + log.debug("Kafka MirrorMaker instance created"); + this.time = time; + this.advertisedBaseUrl = "NOTUSED"; + this.config = config; + if (clusters != null && !clusters.isEmpty()) { + this.clusters = new HashSet<>(clusters); + } else { + // default to all clusters + this.clusters = config.clusters(); + } + log.info("Targeting clusters {}", this.clusters); + this.herderPairs = config.clusterPairs().stream() + .filter(x -> this.clusters.contains(x.target())) + .collect(Collectors.toSet()); + if (herderPairs.isEmpty()) { + throw new IllegalArgumentException("No source->target replication flows."); + } + this.herderPairs.forEach(this::addHerder); + shutdownHook = new ShutdownHook(); + } + + /** + * @param config MM2 configuration from mm2.properties file + * @param clusters target clusters for this node. These must match cluster + * aliases as defined in the config. If null or empty list, + * uses all clusters in the config. 
+ * @param time time source + */ + public MirrorMaker(Map config, List clusters, Time time) { + this(new MirrorMakerConfig(config), clusters, time); + } + + public MirrorMaker(Map props, List clusters) { + this(props, clusters, Time.SYSTEM); + } + + public MirrorMaker(Map props) { + this(props, null); + } + + + public void start() { + log.info("Kafka MirrorMaker starting with {} herders.", herders.size()); + if (startLatch != null) { + throw new IllegalStateException("MirrorMaker instance already started"); + } + startLatch = new CountDownLatch(herders.size()); + stopLatch = new CountDownLatch(herders.size()); + Exit.addShutdownHook("mirror-maker-shutdown-hook", shutdownHook); + for (Herder herder : herders.values()) { + try { + herder.start(); + } finally { + startLatch.countDown(); + } + } + log.info("Configuring connectors..."); + herderPairs.forEach(this::configureConnectors); + log.info("Kafka MirrorMaker started"); + } + + public void stop() { + boolean wasShuttingDown = shutdown.getAndSet(true); + if (!wasShuttingDown) { + log.info("Kafka MirrorMaker stopping"); + for (Herder herder : herders.values()) { + try { + herder.stop(); + } finally { + stopLatch.countDown(); + } + } + log.info("Kafka MirrorMaker stopped."); + } + } + + public void awaitStop() { + try { + stopLatch.await(); + } catch (InterruptedException e) { + log.error("Interrupted waiting for MirrorMaker to shutdown"); + } + } + + private void configureConnector(SourceAndTarget sourceAndTarget, Class connectorClass) { + checkHerder(sourceAndTarget); + Map connectorProps = config.connectorBaseConfig(sourceAndTarget, connectorClass); + herders.get(sourceAndTarget) + .putConnectorConfig(connectorClass.getSimpleName(), connectorProps, true, (e, x) -> { + if (e instanceof NotLeaderException) { + // No way to determine if the connector is a leader or not beforehand. + log.info("Connector {} is a follower. 
Using existing configuration.", sourceAndTarget); + } else { + log.info("Connector {} configured.", sourceAndTarget, e); + } + }); + } + + private void checkHerder(SourceAndTarget sourceAndTarget) { + if (!herders.containsKey(sourceAndTarget)) { + throw new IllegalArgumentException("No herder for " + sourceAndTarget.toString()); + } + } + + private void configureConnectors(SourceAndTarget sourceAndTarget) { + CONNECTOR_CLASSES.forEach(x -> configureConnector(sourceAndTarget, x)); + } + + private void addHerder(SourceAndTarget sourceAndTarget) { + log.info("creating herder for " + sourceAndTarget.toString()); + Map workerProps = config.workerConfig(sourceAndTarget); + String advertisedUrl = advertisedBaseUrl + "/" + sourceAndTarget.source(); + String workerId = sourceAndTarget.toString(); + Plugins plugins = new Plugins(workerProps); + plugins.compareAndSwapWithDelegatingLoader(); + DistributedConfig distributedConfig = new DistributedConfig(workerProps); + String kafkaClusterId = ConnectUtils.lookupKafkaClusterId(distributedConfig); + // Create the admin client to be shared by all backing stores for this herder + Map adminProps = new HashMap<>(distributedConfig.originals()); + ConnectUtils.addMetricsContextProperties(adminProps, distributedConfig, kafkaClusterId); + SharedTopicAdmin sharedAdmin = new SharedTopicAdmin(adminProps); + KafkaOffsetBackingStore offsetBackingStore = new KafkaOffsetBackingStore(sharedAdmin); + offsetBackingStore.configure(distributedConfig); + Worker worker = new Worker(workerId, time, plugins, distributedConfig, offsetBackingStore, CLIENT_CONFIG_OVERRIDE_POLICY); + WorkerConfigTransformer configTransformer = worker.configTransformer(); + Converter internalValueConverter = worker.getInternalValueConverter(); + StatusBackingStore statusBackingStore = new KafkaStatusBackingStore(time, internalValueConverter, sharedAdmin); + statusBackingStore.configure(distributedConfig); + ConfigBackingStore configBackingStore = new KafkaConfigBackingStore( + internalValueConverter, + distributedConfig, + configTransformer, + sharedAdmin); + // Pass the shared admin to the distributed herder as an additional AutoCloseable object that should be closed when the + // herder is stopped. MirrorMaker has multiple herders, and having the herder own the close responsibility is much easier than + // tracking the various shared admin objects in this class. + Herder herder = new DistributedHerder(distributedConfig, time, worker, + kafkaClusterId, statusBackingStore, configBackingStore, + advertisedUrl, CLIENT_CONFIG_OVERRIDE_POLICY, sharedAdmin); + herders.put(sourceAndTarget, herder); + } + + private class ShutdownHook extends Thread { + @Override + public void run() { + try { + if (!startLatch.await(SHUTDOWN_TIMEOUT_SECONDS, TimeUnit.SECONDS)) { + log.error("Timed out in shutdown hook waiting for MirrorMaker startup to finish. Unable to shutdown cleanly."); + } + } catch (InterruptedException e) { + log.error("Interrupted in shutdown hook while waiting for MirrorMaker startup to finish. 
Unable to shutdown cleanly."); + } finally { + MirrorMaker.this.stop(); + } + } + } + + public static void main(String[] args) { + ArgumentParser parser = ArgumentParsers.newArgumentParser("connect-mirror-maker"); + parser.description("MirrorMaker 2.0 driver"); + parser.addArgument("config").type(Arguments.fileType().verifyCanRead()) + .metavar("mm2.properties").required(true) + .help("MM2 configuration file."); + parser.addArgument("--clusters").nargs("+").metavar("CLUSTER").required(false) + .help("Target cluster to use for this node."); + Namespace ns; + try { + ns = parser.parseArgs(args); + } catch (ArgumentParserException e) { + parser.handleError(e); + Exit.exit(-1); + return; + } + File configFile = ns.get("config"); + List clusters = ns.getList("clusters"); + try { + log.info("Kafka MirrorMaker initializing ..."); + + Properties props = Utils.loadProps(configFile.getPath()); + Map config = Utils.propsToStringMap(props); + MirrorMaker mirrorMaker = new MirrorMaker(config, clusters, Time.SYSTEM); + + try { + mirrorMaker.start(); + } catch (Exception e) { + log.error("Failed to start MirrorMaker", e); + mirrorMaker.stop(); + Exit.exit(3); + } + + mirrorMaker.awaitStop(); + + } catch (Throwable t) { + log.error("Stopping due to error", t); + Exit.exit(2); + } + } + +} diff --git a/connect/mirror/src/main/java/org/apache/kafka/connect/mirror/MirrorMakerConfig.java b/connect/mirror/src/main/java/org/apache/kafka/connect/mirror/MirrorMakerConfig.java new file mode 100644 index 0000000..33cd8a7 --- /dev/null +++ b/connect/mirror/src/main/java/org/apache/kafka/connect/mirror/MirrorMakerConfig.java @@ -0,0 +1,292 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.mirror; + +import java.util.Map.Entry; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.common.config.AbstractConfig; +import org.apache.kafka.common.config.ConfigDef; +import org.apache.kafka.common.config.ConfigDef.Type; +import org.apache.kafka.common.config.ConfigDef.Importance; +import org.apache.kafka.common.config.provider.ConfigProvider; +import org.apache.kafka.common.config.ConfigTransformer; +import org.apache.kafka.clients.CommonClientConfigs; +import org.apache.kafka.connect.runtime.WorkerConfig; +import org.apache.kafka.connect.runtime.distributed.DistributedConfig; +import org.apache.kafka.connect.runtime.isolation.Plugins; + +import java.util.Map; +import java.util.HashMap; +import java.util.List; +import java.util.Set; +import java.util.HashSet; +import java.util.ArrayList; +import java.util.Collections; +import java.util.stream.Collectors; + +/** Top-level config describing replication flows between multiple Kafka clusters. 
+ * + * Supports cluster-level properties of the form cluster.x.y.z, and replication-level + * properties of the form source->target.x.y.z. + * e.g. + * + * clusters = A, B, C + * A.bootstrap.servers = aaa:9092 + * A.security.protocol = SSL + * --->%--- + * A->B.enabled = true + * A->B.producer.client.id = "A-B-producer" + * --->%--- + * + */ +public class MirrorMakerConfig extends AbstractConfig { + + public static final String CLUSTERS_CONFIG = "clusters"; + private static final String CLUSTERS_DOC = "List of cluster aliases."; + public static final String CONFIG_PROVIDERS_CONFIG = WorkerConfig.CONFIG_PROVIDERS_CONFIG; + private static final String CONFIG_PROVIDERS_DOC = "Names of ConfigProviders to use."; + + private static final String NAME = "name"; + private static final String CONNECTOR_CLASS = "connector.class"; + private static final String SOURCE_CLUSTER_ALIAS = "source.cluster.alias"; + private static final String TARGET_CLUSTER_ALIAS = "target.cluster.alias"; + private static final String GROUP_ID_CONFIG = "group.id"; + private static final String KEY_CONVERTER_CLASS_CONFIG = "key.converter"; + private static final String VALUE_CONVERTER_CLASS_CONFIG = "value.converter"; + private static final String HEADER_CONVERTER_CLASS_CONFIG = "header.converter"; + private static final String BYTE_ARRAY_CONVERTER_CLASS = + "org.apache.kafka.connect.converters.ByteArrayConverter"; + + static final String SOURCE_CLUSTER_PREFIX = "source.cluster."; + static final String TARGET_CLUSTER_PREFIX = "target.cluster."; + static final String SOURCE_PREFIX = "source."; + static final String TARGET_PREFIX = "target."; + + private final Plugins plugins; + + public MirrorMakerConfig(Map props) { + super(CONFIG_DEF, props, true); + plugins = new Plugins(originalsStrings()); + } + + public Set clusters() { + return new HashSet<>(getList(CLUSTERS_CONFIG)); + } + + public List clusterPairs() { + List pairs = new ArrayList<>(); + Set clusters = clusters(); + Map originalStrings = originalsStrings(); + boolean globalHeartbeatsEnabled = MirrorConnectorConfig.EMIT_HEARTBEATS_ENABLED_DEFAULT; + if (originalStrings.containsKey(MirrorConnectorConfig.EMIT_HEARTBEATS_ENABLED)) { + globalHeartbeatsEnabled = Boolean.valueOf(originalStrings.get(MirrorConnectorConfig.EMIT_HEARTBEATS_ENABLED)); + } + + for (String source : clusters) { + for (String target : clusters) { + if (!source.equals(target)) { + String clusterPairConfigPrefix = source + "->" + target + "."; + boolean clusterPairEnabled = Boolean.valueOf(originalStrings.getOrDefault(clusterPairConfigPrefix + "enabled", "false")); + boolean clusterPairHeartbeatsEnabled = globalHeartbeatsEnabled; + if (originalStrings.containsKey(clusterPairConfigPrefix + MirrorConnectorConfig.EMIT_HEARTBEATS_ENABLED)) { + clusterPairHeartbeatsEnabled = Boolean.valueOf(originalStrings.get(clusterPairConfigPrefix + MirrorConnectorConfig.EMIT_HEARTBEATS_ENABLED)); + } + + // By default, all source->target Herder combinations are created even if `x->y.enabled=false` + // Unless `emit.heartbeats.enabled=false` or `x->y.emit.heartbeats.enabled=false` + // Reason for this behavior: for a given replication flow A->B with heartbeats, 2 herders are required : + // B->A for the MirrorHeartbeatConnector (emits heartbeats into A for monitoring replication health) + // A->B for the MirrorSourceConnector (actual replication flow) + if (clusterPairEnabled || clusterPairHeartbeatsEnabled) { + pairs.add(new SourceAndTarget(source, target)); + } + } + } + } + return pairs; + } + + /** Construct a 
MirrorClientConfig from properties of the form cluster.x.y.z. + * Use to connect to a cluster based on the MirrorMaker top-level config file. + */ + public MirrorClientConfig clientConfig(String cluster) { + Map props = new HashMap<>(); + props.putAll(originalsStrings()); + props.putAll(clusterProps(cluster)); + return new MirrorClientConfig(transform(props)); + } + + // loads properties of the form cluster.x.y.z + Map clusterProps(String cluster) { + Map props = new HashMap<>(); + Map strings = originalsStrings(); + + props.putAll(stringsWithPrefixStripped(cluster + ".")); + + for (String k : MirrorClientConfig.CLIENT_CONFIG_DEF.names()) { + String v = props.get(k); + if (v != null) { + props.putIfAbsent("producer." + k, v); + props.putIfAbsent("consumer." + k, v); + props.putIfAbsent("admin." + k, v); + } + } + + for (String k : MirrorClientConfig.CLIENT_CONFIG_DEF.names()) { + String v = strings.get(k); + if (v != null) { + props.putIfAbsent("producer." + k, v); + props.putIfAbsent("consumer." + k, v); + props.putIfAbsent("admin." + k, v); + props.putIfAbsent(k, v); + } + } + + return props; + } + + // loads worker configs based on properties of the form x.y.z and cluster.x.y.z + public Map workerConfig(SourceAndTarget sourceAndTarget) { + Map props = new HashMap<>(); + props.putAll(clusterProps(sourceAndTarget.target())); + + // Accept common top-level configs that are otherwise ignored by MM2. + // N.B. all other worker properties should be configured for specific herders, + // e.g. primary->backup.client.id + props.putAll(stringsWithPrefix("offset.storage")); + props.putAll(stringsWithPrefix("config.storage")); + props.putAll(stringsWithPrefix("status.storage")); + props.putAll(stringsWithPrefix("key.converter")); + props.putAll(stringsWithPrefix("value.converter")); + props.putAll(stringsWithPrefix("header.converter")); + props.putAll(stringsWithPrefix("task")); + props.putAll(stringsWithPrefix("worker")); + props.putAll(stringsWithPrefix("replication.policy")); + + // transform any expression like ${provider:path:key}, since the worker doesn't do so + props = transform(props); + props.putAll(stringsWithPrefix(CONFIG_PROVIDERS_CONFIG)); + + // fill in reasonable defaults + props.putIfAbsent(GROUP_ID_CONFIG, sourceAndTarget.source() + "-mm2"); + props.putIfAbsent(DistributedConfig.OFFSET_STORAGE_TOPIC_CONFIG, "mm2-offsets." + + sourceAndTarget.source() + ".internal"); + props.putIfAbsent(DistributedConfig.STATUS_STORAGE_TOPIC_CONFIG, "mm2-status." + + sourceAndTarget.source() + ".internal"); + props.putIfAbsent(DistributedConfig.CONFIG_TOPIC_CONFIG, "mm2-configs." 
+ + sourceAndTarget.source() + ".internal"); + props.putIfAbsent(KEY_CONVERTER_CLASS_CONFIG, BYTE_ARRAY_CONVERTER_CLASS); + props.putIfAbsent(VALUE_CONVERTER_CLASS_CONFIG, BYTE_ARRAY_CONVERTER_CLASS); + props.putIfAbsent(HEADER_CONVERTER_CLASS_CONFIG, BYTE_ARRAY_CONVERTER_CLASS); + + return props; + } + + // loads properties of the form cluster.x.y.z and source->target.x.y.z + public Map connectorBaseConfig(SourceAndTarget sourceAndTarget, Class connectorClass) { + Map props = new HashMap<>(); + + props.putAll(originalsStrings()); + props.keySet().retainAll(MirrorConnectorConfig.CONNECTOR_CONFIG_DEF.names()); + + props.putAll(stringsWithPrefix(CONFIG_PROVIDERS_CONFIG)); + props.putAll(stringsWithPrefix("replication.policy")); + + Map sourceClusterProps = clusterProps(sourceAndTarget.source()); + // attrs non prefixed with producer|consumer|admin + props.putAll(clusterConfigsWithPrefix(SOURCE_CLUSTER_PREFIX, sourceClusterProps)); + // attrs prefixed with producer|consumer|admin + props.putAll(clientConfigsWithPrefix(SOURCE_PREFIX, sourceClusterProps)); + + Map targetClusterProps = clusterProps(sourceAndTarget.target()); + props.putAll(clusterConfigsWithPrefix(TARGET_CLUSTER_PREFIX, targetClusterProps)); + props.putAll(clientConfigsWithPrefix(TARGET_PREFIX, targetClusterProps)); + + props.putIfAbsent(NAME, connectorClass.getSimpleName()); + props.putIfAbsent(CONNECTOR_CLASS, connectorClass.getName()); + props.putIfAbsent(SOURCE_CLUSTER_ALIAS, sourceAndTarget.source()); + props.putIfAbsent(TARGET_CLUSTER_ALIAS, sourceAndTarget.target()); + + // override with connector-level properties + props.putAll(stringsWithPrefixStripped(sourceAndTarget.source() + "->" + + sourceAndTarget.target() + ".")); + + // disabled by default + props.putIfAbsent(MirrorConnectorConfig.ENABLED, "false"); + + // don't transform -- the worker will handle transformation of Connector and Task configs + return props; + } + + List configProviders() { + return getList(CONFIG_PROVIDERS_CONFIG); + } + + Map transform(Map props) { + // transform worker config according to config.providers + List providerNames = configProviders(); + Map providers = new HashMap<>(); + for (String name : providerNames) { + ConfigProvider configProvider = plugins.newConfigProvider( + this, + CONFIG_PROVIDERS_CONFIG + "." 
+ name, + Plugins.ClassLoaderUsage.PLUGINS + ); + providers.put(name, configProvider); + } + ConfigTransformer transformer = new ConfigTransformer(providers); + Map transformed = transformer.transform(props).data(); + providers.values().forEach(x -> Utils.closeQuietly(x, "config provider")); + return transformed; + } + + protected static final ConfigDef CONFIG_DEF = new ConfigDef() + .define(CLUSTERS_CONFIG, Type.LIST, Importance.HIGH, CLUSTERS_DOC) + .define(CONFIG_PROVIDERS_CONFIG, Type.LIST, Collections.emptyList(), Importance.LOW, CONFIG_PROVIDERS_DOC) + // security support + .define(CommonClientConfigs.SECURITY_PROTOCOL_CONFIG, + Type.STRING, + CommonClientConfigs.DEFAULT_SECURITY_PROTOCOL, + Importance.MEDIUM, + CommonClientConfigs.SECURITY_PROTOCOL_DOC) + .withClientSslSupport() + .withClientSaslSupport(); + + private Map stringsWithPrefixStripped(String prefix) { + return originalsStrings().entrySet().stream() + .filter(x -> x.getKey().startsWith(prefix)) + .collect(Collectors.toMap(x -> x.getKey().substring(prefix.length()), Entry::getValue)); + } + + private Map stringsWithPrefix(String prefix) { + Map strings = originalsStrings(); + strings.keySet().removeIf(x -> !x.startsWith(prefix)); + return strings; + } + + static Map clusterConfigsWithPrefix(String prefix, Map props) { + return props.entrySet().stream() + .filter(x -> !x.getKey().matches("(^consumer.*|^producer.*|^admin.*)")) + .collect(Collectors.toMap(x -> prefix + x.getKey(), Entry::getValue)); + } + + static Map clientConfigsWithPrefix(String prefix, Map props) { + return props.entrySet().stream() + .filter(x -> x.getKey().matches("(^consumer.*|^producer.*|^admin.*)")) + .collect(Collectors.toMap(x -> prefix + x.getKey(), Entry::getValue)); + } +} diff --git a/connect/mirror/src/main/java/org/apache/kafka/connect/mirror/MirrorMetrics.java b/connect/mirror/src/main/java/org/apache/kafka/connect/mirror/MirrorMetrics.java new file mode 100644 index 0000000..4bd03f3 --- /dev/null +++ b/connect/mirror/src/main/java/org/apache/kafka/connect/mirror/MirrorMetrics.java @@ -0,0 +1,208 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.connect.mirror; + +import org.apache.kafka.common.MetricNameTemplate; +import org.apache.kafka.common.metrics.Metrics; +import org.apache.kafka.common.metrics.MetricsReporter; +import org.apache.kafka.common.metrics.Sensor; +import org.apache.kafka.common.metrics.stats.Value; +import org.apache.kafka.common.metrics.stats.Min; +import org.apache.kafka.common.metrics.stats.Max; +import org.apache.kafka.common.metrics.stats.Avg; +import org.apache.kafka.common.metrics.stats.Meter; +import org.apache.kafka.common.TopicPartition; + +import java.util.Arrays; +import java.util.Set; +import java.util.HashSet; +import java.util.Map; +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.stream.Collectors; + +/** Metrics for replicated topic-partitions */ +class MirrorMetrics implements AutoCloseable { + + private static final String SOURCE_CONNECTOR_GROUP = MirrorSourceConnector.class.getSimpleName(); + private static final String CHECKPOINT_CONNECTOR_GROUP = MirrorCheckpointConnector.class.getSimpleName(); + + private static final Set PARTITION_TAGS = new HashSet<>(Arrays.asList("target", "topic", "partition")); + private static final Set GROUP_TAGS = new HashSet<>(Arrays.asList("source", "target", "group", "topic", "partition")); + + private static final MetricNameTemplate RECORD_COUNT = new MetricNameTemplate( + "record-count", SOURCE_CONNECTOR_GROUP, + "Number of source records replicated to the target cluster.", PARTITION_TAGS); + private static final MetricNameTemplate RECORD_RATE = new MetricNameTemplate( + "record-rate", SOURCE_CONNECTOR_GROUP, + "Average number of source records replicated to the target cluster per second.", PARTITION_TAGS); + private static final MetricNameTemplate RECORD_AGE = new MetricNameTemplate( + "record-age-ms", SOURCE_CONNECTOR_GROUP, + "The age of incoming source records when replicated to the target cluster.", PARTITION_TAGS); + private static final MetricNameTemplate RECORD_AGE_MAX = new MetricNameTemplate( + "record-age-ms-max", SOURCE_CONNECTOR_GROUP, + "The max age of incoming source records when replicated to the target cluster.", PARTITION_TAGS); + private static final MetricNameTemplate RECORD_AGE_MIN = new MetricNameTemplate( + "record-age-ms-min", SOURCE_CONNECTOR_GROUP, + "The min age of incoming source records when replicated to the target cluster.", PARTITION_TAGS); + private static final MetricNameTemplate RECORD_AGE_AVG = new MetricNameTemplate( + "record-age-ms-avg", SOURCE_CONNECTOR_GROUP, + "The average age of incoming source records when replicated to the target cluster.", PARTITION_TAGS); + private static final MetricNameTemplate BYTE_COUNT = new MetricNameTemplate( + "byte-count", SOURCE_CONNECTOR_GROUP, + "Number of bytes replicated to the target cluster.", PARTITION_TAGS); + private static final MetricNameTemplate BYTE_RATE = new MetricNameTemplate( + "byte-rate", SOURCE_CONNECTOR_GROUP, + "Average number of bytes replicated per second.", PARTITION_TAGS); + private static final MetricNameTemplate REPLICATION_LATENCY = new MetricNameTemplate( + "replication-latency-ms", SOURCE_CONNECTOR_GROUP, + "Time it takes records to replicate from source to target cluster.", PARTITION_TAGS); + private static final MetricNameTemplate REPLICATION_LATENCY_MAX = new MetricNameTemplate( + "replication-latency-ms-max", SOURCE_CONNECTOR_GROUP, + "Max time it takes records to replicate from source to target cluster.", PARTITION_TAGS); + private static final MetricNameTemplate REPLICATION_LATENCY_MIN = new 
MetricNameTemplate( + "replication-latency-ms-min", SOURCE_CONNECTOR_GROUP, + "Min time it takes records to replicate from source to target cluster.", PARTITION_TAGS); + private static final MetricNameTemplate REPLICATION_LATENCY_AVG = new MetricNameTemplate( + "replication-latency-ms-avg", SOURCE_CONNECTOR_GROUP, + "Average time it takes records to replicate from source to target cluster.", PARTITION_TAGS); + + private static final MetricNameTemplate CHECKPOINT_LATENCY = new MetricNameTemplate( + "checkpoint-latency-ms", CHECKPOINT_CONNECTOR_GROUP, + "Time it takes consumer group offsets to replicate from source to target cluster.", GROUP_TAGS); + private static final MetricNameTemplate CHECKPOINT_LATENCY_MAX = new MetricNameTemplate( + "checkpoint-latency-ms-max", CHECKPOINT_CONNECTOR_GROUP, + "Max time it takes consumer group offsets to replicate from source to target cluster.", GROUP_TAGS); + private static final MetricNameTemplate CHECKPOINT_LATENCY_MIN = new MetricNameTemplate( + "checkpoint-latency-ms-min", CHECKPOINT_CONNECTOR_GROUP, + "Min time it takes consumer group offsets to replicate from source to target cluster.", GROUP_TAGS); + private static final MetricNameTemplate CHECKPOINT_LATENCY_AVG = new MetricNameTemplate( + "checkpoint-latency-ms-avg", CHECKPOINT_CONNECTOR_GROUP, + "Average time it takes consumer group offsets to replicate from source to target cluster.", GROUP_TAGS); + + + private final Metrics metrics; + private final Map partitionMetrics; + private final Map groupMetrics = new HashMap<>(); + private final String source; + private final String target; + + MirrorMetrics(MirrorTaskConfig taskConfig) { + this.target = taskConfig.targetClusterAlias(); + this.source = taskConfig.sourceClusterAlias(); + this.metrics = new Metrics(); + + // for side-effect + metrics.sensor("record-count"); + metrics.sensor("byte-rate"); + metrics.sensor("record-age"); + metrics.sensor("replication-latency"); + + ReplicationPolicy replicationPolicy = taskConfig.replicationPolicy(); + partitionMetrics = taskConfig.taskTopicPartitions().stream() + .map(x -> new TopicPartition(replicationPolicy.formatRemoteTopic(source, x.topic()), x.partition())) + .collect(Collectors.toMap(x -> x, PartitionMetrics::new)); + + } + + @Override + public void close() { + metrics.close(); + } + + void countRecord(TopicPartition topicPartition) { + partitionMetrics.get(topicPartition).recordSensor.record(); + } + + void recordAge(TopicPartition topicPartition, long ageMillis) { + partitionMetrics.get(topicPartition).recordAgeSensor.record((double) ageMillis); + } + + void replicationLatency(TopicPartition topicPartition, long millis) { + partitionMetrics.get(topicPartition).replicationLatencySensor.record((double) millis); + } + + void recordBytes(TopicPartition topicPartition, long bytes) { + partitionMetrics.get(topicPartition).byteSensor.record((double) bytes); + } + + void checkpointLatency(TopicPartition topicPartition, String group, long millis) { + group(topicPartition, group).checkpointLatencySensor.record((double) millis); + } + + GroupMetrics group(TopicPartition topicPartition, String group) { + return groupMetrics.computeIfAbsent(String.join("-", topicPartition.toString(), group), + x -> new GroupMetrics(topicPartition, group)); + } + + void addReporter(MetricsReporter reporter) { + metrics.addReporter(reporter); + } + + private class PartitionMetrics { + private final Sensor recordSensor; + private final Sensor byteSensor; + private final Sensor recordAgeSensor; + private final Sensor 
replicationLatencySensor; + + PartitionMetrics(TopicPartition topicPartition) { + String prefix = topicPartition.topic() + "-" + topicPartition.partition() + "-"; + + Map tags = new LinkedHashMap<>(); + tags.put("target", target); + tags.put("topic", topicPartition.topic()); + tags.put("partition", Integer.toString(topicPartition.partition())); + + recordSensor = metrics.sensor(prefix + "records-sent"); + recordSensor.add(new Meter(metrics.metricInstance(RECORD_RATE, tags), metrics.metricInstance(RECORD_COUNT, tags))); + + byteSensor = metrics.sensor(prefix + "bytes-sent"); + byteSensor.add(new Meter(metrics.metricInstance(BYTE_RATE, tags), metrics.metricInstance(BYTE_COUNT, tags))); + + recordAgeSensor = metrics.sensor(prefix + "record-age"); + recordAgeSensor.add(metrics.metricInstance(RECORD_AGE, tags), new Value()); + recordAgeSensor.add(metrics.metricInstance(RECORD_AGE_MAX, tags), new Max()); + recordAgeSensor.add(metrics.metricInstance(RECORD_AGE_MIN, tags), new Min()); + recordAgeSensor.add(metrics.metricInstance(RECORD_AGE_AVG, tags), new Avg()); + + replicationLatencySensor = metrics.sensor(prefix + "replication-latency"); + replicationLatencySensor.add(metrics.metricInstance(REPLICATION_LATENCY, tags), new Value()); + replicationLatencySensor.add(metrics.metricInstance(REPLICATION_LATENCY_MAX, tags), new Max()); + replicationLatencySensor.add(metrics.metricInstance(REPLICATION_LATENCY_MIN, tags), new Min()); + replicationLatencySensor.add(metrics.metricInstance(REPLICATION_LATENCY_AVG, tags), new Avg()); + } + } + + private class GroupMetrics { + private final Sensor checkpointLatencySensor; + + GroupMetrics(TopicPartition topicPartition, String group) { + Map tags = new LinkedHashMap<>(); + tags.put("source", source); + tags.put("target", target); + tags.put("group", group); + tags.put("topic", topicPartition.topic()); + tags.put("partition", Integer.toString(topicPartition.partition())); + + checkpointLatencySensor = metrics.sensor("checkpoint-latency"); + checkpointLatencySensor.add(metrics.metricInstance(CHECKPOINT_LATENCY, tags), new Value()); + checkpointLatencySensor.add(metrics.metricInstance(CHECKPOINT_LATENCY_MAX, tags), new Max()); + checkpointLatencySensor.add(metrics.metricInstance(CHECKPOINT_LATENCY_MIN, tags), new Min()); + checkpointLatencySensor.add(metrics.metricInstance(CHECKPOINT_LATENCY_AVG, tags), new Avg()); + } + } +} diff --git a/connect/mirror/src/main/java/org/apache/kafka/connect/mirror/MirrorSourceConnector.java b/connect/mirror/src/main/java/org/apache/kafka/connect/mirror/MirrorSourceConnector.java new file mode 100644 index 0000000..b2cbe02 --- /dev/null +++ b/connect/mirror/src/main/java/org/apache/kafka/connect/mirror/MirrorSourceConnector.java @@ -0,0 +1,507 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.mirror; + +import java.util.Map.Entry; +import org.apache.kafka.connect.connector.Task; +import org.apache.kafka.connect.source.SourceConnector; +import org.apache.kafka.common.config.ConfigDef; +import org.apache.kafka.common.config.ConfigResource; +import org.apache.kafka.common.acl.AclBinding; +import org.apache.kafka.common.acl.AclBindingFilter; +import org.apache.kafka.common.acl.AccessControlEntry; +import org.apache.kafka.common.acl.AccessControlEntryFilter; +import org.apache.kafka.common.acl.AclPermissionType; +import org.apache.kafka.common.acl.AclOperation; +import org.apache.kafka.common.resource.ResourceType; +import org.apache.kafka.common.resource.ResourcePattern; +import org.apache.kafka.common.resource.ResourcePatternFilter; +import org.apache.kafka.common.resource.PatternType; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.errors.InvalidPartitionsException; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.clients.admin.AdminClient; +import org.apache.kafka.clients.admin.TopicDescription; +import org.apache.kafka.clients.admin.Config; +import org.apache.kafka.clients.admin.ConfigEntry; +import org.apache.kafka.clients.admin.NewPartitions; +import org.apache.kafka.clients.admin.NewTopic; +import org.apache.kafka.clients.admin.CreateTopicsOptions; + +import java.util.Map; +import java.util.List; +import java.util.ArrayList; +import java.util.Set; +import java.util.HashSet; +import java.util.Collection; +import java.util.Collections; +import java.util.function.Function; +import java.util.stream.Collectors; +import java.util.stream.Stream; +import java.util.concurrent.ExecutionException; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** Replicate data, configuration, and ACLs between clusters. + * + * @see MirrorConnectorConfig for supported config properties. 
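+ *
+ * For example, assuming the default replication policy is in use (it prefixes remote topic names with the
+ * source cluster alias), a topic "orders" replicated from a cluster aliased "primary" appears on the target
+ * cluster as the remote topic "primary.orders".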
+ */ +public class MirrorSourceConnector extends SourceConnector { + + private static final Logger log = LoggerFactory.getLogger(MirrorSourceConnector.class); + private static final ResourcePatternFilter ANY_TOPIC = new ResourcePatternFilter(ResourceType.TOPIC, + null, PatternType.ANY); + private static final AclBindingFilter ANY_TOPIC_ACL = new AclBindingFilter(ANY_TOPIC, AccessControlEntryFilter.ANY); + + private Scheduler scheduler; + private MirrorConnectorConfig config; + private SourceAndTarget sourceAndTarget; + private String connectorName; + private TopicFilter topicFilter; + private ConfigPropertyFilter configPropertyFilter; + private List knownSourceTopicPartitions = Collections.emptyList(); + private List knownTargetTopicPartitions = Collections.emptyList(); + private ReplicationPolicy replicationPolicy; + private int replicationFactor; + private AdminClient sourceAdminClient; + private AdminClient targetAdminClient; + + public MirrorSourceConnector() { + // nop + } + + // visible for testing + MirrorSourceConnector(List knownSourceTopicPartitions, MirrorConnectorConfig config) { + this.knownSourceTopicPartitions = knownSourceTopicPartitions; + this.config = config; + } + + // visible for testing + MirrorSourceConnector(SourceAndTarget sourceAndTarget, ReplicationPolicy replicationPolicy, + TopicFilter topicFilter, ConfigPropertyFilter configPropertyFilter) { + this.sourceAndTarget = sourceAndTarget; + this.replicationPolicy = replicationPolicy; + this.topicFilter = topicFilter; + this.configPropertyFilter = configPropertyFilter; + } + + @Override + public void start(Map props) { + long start = System.currentTimeMillis(); + config = new MirrorConnectorConfig(props); + if (!config.enabled()) { + return; + } + connectorName = config.connectorName(); + sourceAndTarget = new SourceAndTarget(config.sourceClusterAlias(), config.targetClusterAlias()); + topicFilter = config.topicFilter(); + configPropertyFilter = config.configPropertyFilter(); + replicationPolicy = config.replicationPolicy(); + replicationFactor = config.replicationFactor(); + sourceAdminClient = AdminClient.create(config.sourceAdminConfig()); + targetAdminClient = AdminClient.create(config.targetAdminConfig()); + scheduler = new Scheduler(MirrorSourceConnector.class, config.adminTimeout()); + scheduler.execute(this::createOffsetSyncsTopic, "creating upstream offset-syncs topic"); + scheduler.execute(this::loadTopicPartitions, "loading initial set of topic-partitions"); + scheduler.execute(this::computeAndCreateTopicPartitions, "creating downstream topic-partitions"); + scheduler.execute(this::refreshKnownTargetTopics, "refreshing known target topics"); + scheduler.scheduleRepeating(this::syncTopicAcls, config.syncTopicAclsInterval(), "syncing topic ACLs"); + scheduler.scheduleRepeating(this::syncTopicConfigs, config.syncTopicConfigsInterval(), + "syncing topic configs"); + scheduler.scheduleRepeatingDelayed(this::refreshTopicPartitions, config.refreshTopicsInterval(), + "refreshing topics"); + log.info("Started {} with {} topic-partitions.", connectorName, knownSourceTopicPartitions.size()); + log.info("Starting {} took {} ms.", connectorName, System.currentTimeMillis() - start); + } + + @Override + public void stop() { + long start = System.currentTimeMillis(); + if (!config.enabled()) { + return; + } + Utils.closeQuietly(scheduler, "scheduler"); + Utils.closeQuietly(topicFilter, "topic filter"); + Utils.closeQuietly(configPropertyFilter, "config property filter"); + Utils.closeQuietly(sourceAdminClient, "source 
admin client"); + Utils.closeQuietly(targetAdminClient, "target admin client"); + log.info("Stopping {} took {} ms.", connectorName, System.currentTimeMillis() - start); + } + + @Override + public Class taskClass() { + return MirrorSourceTask.class; + } + + // divide topic-partitions among tasks + // since each mirrored topic has different traffic and number of partitions, to balance the load + // across all mirrormaker instances (workers), 'roundrobin' helps to evenly assign all + // topic-partition to the tasks, then the tasks are further distributed to workers. + // For example, 3 tasks to mirror 3 topics with 8, 2 and 2 partitions respectively. + // 't1' denotes 'task 1', 't0p5' denotes 'topic 0, partition 5' + // t1 -> [t0p0, t0p3, t0p6, t1p1] + // t2 -> [t0p1, t0p4, t0p7, t2p0] + // t3 -> [t0p2, t0p5, t1p0, t2p1] + @Override + public List> taskConfigs(int maxTasks) { + if (!config.enabled() || knownSourceTopicPartitions.isEmpty()) { + return Collections.emptyList(); + } + int numTasks = Math.min(maxTasks, knownSourceTopicPartitions.size()); + List> roundRobinByTask = new ArrayList<>(numTasks); + for (int i = 0; i < numTasks; i++) { + roundRobinByTask.add(new ArrayList<>()); + } + int count = 0; + for (TopicPartition partition : knownSourceTopicPartitions) { + int index = count % numTasks; + roundRobinByTask.get(index).add(partition); + count++; + } + + return roundRobinByTask.stream().map(config::taskConfigForTopicPartitions) + .collect(Collectors.toList()); + } + + @Override + public ConfigDef config() { + return MirrorConnectorConfig.CONNECTOR_CONFIG_DEF; + } + + @Override + public String version() { + return "1"; + } + + // visible for testing + List findSourceTopicPartitions() + throws InterruptedException, ExecutionException { + Set topics = listTopics(sourceAdminClient).stream() + .filter(this::shouldReplicateTopic) + .collect(Collectors.toSet()); + return describeTopics(sourceAdminClient, topics).stream() + .flatMap(MirrorSourceConnector::expandTopicDescription) + .collect(Collectors.toList()); + } + + // visible for testing + List findTargetTopicPartitions() + throws InterruptedException, ExecutionException { + Set topics = listTopics(targetAdminClient).stream() + .filter(t -> sourceAndTarget.source().equals(replicationPolicy.topicSource(t))) + .filter(t -> !t.equals(config.checkpointsTopic())) + .collect(Collectors.toSet()); + return describeTopics(targetAdminClient, topics).stream() + .flatMap(MirrorSourceConnector::expandTopicDescription) + .collect(Collectors.toList()); + } + + // visible for testing + void refreshTopicPartitions() + throws InterruptedException, ExecutionException { + + List sourceTopicPartitions = findSourceTopicPartitions(); + List targetTopicPartitions = findTargetTopicPartitions(); + + Set sourceTopicPartitionsSet = new HashSet<>(sourceTopicPartitions); + Set knownSourceTopicPartitionsSet = new HashSet<>(knownSourceTopicPartitions); + + Set upstreamTargetTopicPartitions = targetTopicPartitions.stream() + .map(x -> new TopicPartition(replicationPolicy.upstreamTopic(x.topic()), x.partition())) + .collect(Collectors.toSet()); + + Set missingInTarget = new HashSet<>(sourceTopicPartitions); + missingInTarget.removeAll(upstreamTargetTopicPartitions); + + knownTargetTopicPartitions = targetTopicPartitions; + + // Detect if topic-partitions were added or deleted from the source cluster + // or if topic-partitions are missing from the target cluster + if (!knownSourceTopicPartitionsSet.equals(sourceTopicPartitionsSet) || !missingInTarget.isEmpty()) { + + Set 
newTopicPartitions = sourceTopicPartitionsSet; + newTopicPartitions.removeAll(knownSourceTopicPartitions); + + Set deletedTopicPartitions = knownSourceTopicPartitionsSet; + deletedTopicPartitions.removeAll(sourceTopicPartitions); + + log.info("Found {} new topic-partitions on {}. " + + "Found {} deleted topic-partitions on {}. " + + "Found {} topic-partitions missing on {}.", + newTopicPartitions.size(), sourceAndTarget.source(), + deletedTopicPartitions.size(), sourceAndTarget.source(), + missingInTarget.size(), sourceAndTarget.target()); + + log.trace("Found new topic-partitions on {}: {}", sourceAndTarget.source(), newTopicPartitions); + log.trace("Found deleted topic-partitions on {}: {}", sourceAndTarget.source(), deletedTopicPartitions); + log.trace("Found missing topic-partitions on {}: {}", sourceAndTarget.target(), missingInTarget); + + knownSourceTopicPartitions = sourceTopicPartitions; + computeAndCreateTopicPartitions(); + context.requestTaskReconfiguration(); + } + } + + private void loadTopicPartitions() + throws InterruptedException, ExecutionException { + knownSourceTopicPartitions = findSourceTopicPartitions(); + knownTargetTopicPartitions = findTargetTopicPartitions(); + } + + private void refreshKnownTargetTopics() + throws InterruptedException, ExecutionException { + knownTargetTopicPartitions = findTargetTopicPartitions(); + } + + private Set topicsBeingReplicated() { + Set knownTargetTopics = toTopics(knownTargetTopicPartitions); + return knownSourceTopicPartitions.stream() + .map(TopicPartition::topic) + .distinct() + .filter(x -> knownTargetTopics.contains(formatRemoteTopic(x))) + .collect(Collectors.toSet()); + } + + private Set toTopics(Collection tps) { + return tps.stream() + .map(TopicPartition::topic) + .collect(Collectors.toSet()); + } + + private void syncTopicAcls() + throws InterruptedException, ExecutionException { + List bindings = listTopicAclBindings().stream() + .filter(x -> x.pattern().resourceType() == ResourceType.TOPIC) + .filter(x -> x.pattern().patternType() == PatternType.LITERAL) + .filter(this::shouldReplicateAcl) + .filter(x -> shouldReplicateTopic(x.pattern().name())) + .map(this::targetAclBinding) + .collect(Collectors.toList()); + updateTopicAcls(bindings); + } + + private void syncTopicConfigs() + throws InterruptedException, ExecutionException { + Map sourceConfigs = describeTopicConfigs(topicsBeingReplicated()); + Map targetConfigs = sourceConfigs.entrySet().stream() + .collect(Collectors.toMap(x -> formatRemoteTopic(x.getKey()), x -> targetConfig(x.getValue()))); + updateTopicConfigs(targetConfigs); + } + + private void createOffsetSyncsTopic() { + MirrorUtils.createSinglePartitionCompactedTopic(config.offsetSyncsTopic(), config.offsetSyncsTopicReplicationFactor(), config.offsetSyncsTopicAdminConfig()); + } + + void computeAndCreateTopicPartitions() throws ExecutionException, InterruptedException { + // get source and target topics with respective partition counts + Map sourceTopicToPartitionCounts = knownSourceTopicPartitions.stream() + .collect(Collectors.groupingBy(TopicPartition::topic, Collectors.counting())).entrySet().stream() + .collect(Collectors.toMap(Entry::getKey, Entry::getValue)); + Map targetTopicToPartitionCounts = knownTargetTopicPartitions.stream() + .collect(Collectors.groupingBy(TopicPartition::topic, Collectors.counting())).entrySet().stream() + .collect(Collectors.toMap(Entry::getKey, Entry::getValue)); + + Set knownSourceTopics = sourceTopicToPartitionCounts.keySet(); + Set knownTargetTopics = 
targetTopicToPartitionCounts.keySet(); + Map sourceToRemoteTopics = knownSourceTopics.stream() + .collect(Collectors.toMap(Function.identity(), sourceTopic -> formatRemoteTopic(sourceTopic))); + + // compute existing and new source topics + Map> partitionedSourceTopics = knownSourceTopics.stream() + .collect(Collectors.partitioningBy(sourceTopic -> knownTargetTopics.contains(sourceToRemoteTopics.get(sourceTopic)), + Collectors.toSet())); + Set existingSourceTopics = partitionedSourceTopics.get(true); + Set newSourceTopics = partitionedSourceTopics.get(false); + + // create new topics + if (!newSourceTopics.isEmpty()) + createNewTopics(newSourceTopics, sourceTopicToPartitionCounts); + + // compute topics with new partitions + Map sourceTopicsWithNewPartitions = existingSourceTopics.stream() + .filter(sourceTopic -> { + String targetTopic = sourceToRemoteTopics.get(sourceTopic); + return sourceTopicToPartitionCounts.get(sourceTopic) > targetTopicToPartitionCounts.get(targetTopic); + }) + .collect(Collectors.toMap(Function.identity(), sourceTopicToPartitionCounts::get)); + + // create new partitions + if (!sourceTopicsWithNewPartitions.isEmpty()) { + Map newTargetPartitions = sourceTopicsWithNewPartitions.entrySet().stream() + .collect(Collectors.toMap(sourceTopicAndPartitionCount -> sourceToRemoteTopics.get(sourceTopicAndPartitionCount.getKey()), + sourceTopicAndPartitionCount -> NewPartitions.increaseTo(sourceTopicAndPartitionCount.getValue().intValue()))); + createNewPartitions(newTargetPartitions); + } + } + + private void createNewTopics(Set newSourceTopics, Map sourceTopicToPartitionCounts) + throws ExecutionException, InterruptedException { + Map sourceTopicToConfig = describeTopicConfigs(newSourceTopics); + Map newTopics = newSourceTopics.stream() + .map(sourceTopic -> { + String remoteTopic = formatRemoteTopic(sourceTopic); + int partitionCount = sourceTopicToPartitionCounts.get(sourceTopic).intValue(); + Map configs = configToMap(sourceTopicToConfig.get(sourceTopic)); + return new NewTopic(remoteTopic, partitionCount, (short) replicationFactor) + .configs(configs); + }) + .collect(Collectors.toMap(NewTopic::name, Function.identity())); + createNewTopics(newTopics); + } + + // visible for testing + void createNewTopics(Map newTopics) { + targetAdminClient.createTopics(newTopics.values(), new CreateTopicsOptions()).values().forEach((k, v) -> v.whenComplete((x, e) -> { + if (e != null) { + log.warn("Could not create topic {}.", k, e); + } else { + log.info("Created remote topic {} with {} partitions.", k, newTopics.get(k).numPartitions()); + } + })); + } + + void createNewPartitions(Map newPartitions) { + targetAdminClient.createPartitions(newPartitions).values().forEach((k, v) -> v.whenComplete((x, e) -> { + if (e instanceof InvalidPartitionsException) { + // swallow, this is normal + } else if (e != null) { + log.warn("Could not create topic-partitions for {}.", k, e); + } else { + log.info("Increased size of {} to {} partitions.", k, newPartitions.get(k).totalCount()); + } + })); + } + + private Set listTopics(AdminClient adminClient) + throws InterruptedException, ExecutionException { + return adminClient.listTopics().names().get(); + } + + private Collection listTopicAclBindings() + throws InterruptedException, ExecutionException { + return sourceAdminClient.describeAcls(ANY_TOPIC_ACL).values().get(); + } + + private static Collection describeTopics(AdminClient adminClient, Collection topics) + throws InterruptedException, ExecutionException { + return 
adminClient.describeTopics(topics).allTopicNames().get().values(); + } + + static Map configToMap(Config config) { + return config.entries().stream() + .collect(Collectors.toMap(ConfigEntry::name, ConfigEntry::value)); + } + + @SuppressWarnings("deprecation") + // use deprecated alterConfigs API for broker compatibility back to 0.11.0 + private void updateTopicConfigs(Map topicConfigs) { + Map configs = topicConfigs.entrySet().stream() + .collect(Collectors.toMap(x -> + new ConfigResource(ConfigResource.Type.TOPIC, x.getKey()), Entry::getValue)); + log.trace("Syncing configs for {} topics.", configs.size()); + targetAdminClient.alterConfigs(configs).values().forEach((k, v) -> v.whenComplete((x, e) -> { + if (e != null) { + log.warn("Could not alter configuration of topic {}.", k.name(), e); + } + })); + } + + private void updateTopicAcls(List bindings) { + log.trace("Syncing {} topic ACL bindings.", bindings.size()); + targetAdminClient.createAcls(bindings).values().forEach((k, v) -> v.whenComplete((x, e) -> { + if (e != null) { + log.warn("Could not sync ACL of topic {}.", k.pattern().name(), e); + } + })); + } + + private static Stream expandTopicDescription(TopicDescription description) { + String topic = description.name(); + return description.partitions().stream() + .map(x -> new TopicPartition(topic, x.partition())); + } + + Map describeTopicConfigs(Set topics) + throws InterruptedException, ExecutionException { + Set resources = topics.stream() + .map(x -> new ConfigResource(ConfigResource.Type.TOPIC, x)) + .collect(Collectors.toSet()); + return sourceAdminClient.describeConfigs(resources).all().get().entrySet().stream() + .collect(Collectors.toMap(x -> x.getKey().name(), Entry::getValue)); + } + + Config targetConfig(Config sourceConfig) { + List entries = sourceConfig.entries().stream() + .filter(x -> !x.isDefault() && !x.isReadOnly() && !x.isSensitive()) + .filter(x -> x.source() != ConfigEntry.ConfigSource.STATIC_BROKER_CONFIG) + .filter(x -> shouldReplicateTopicConfigurationProperty(x.name())) + .collect(Collectors.toList()); + return new Config(entries); + } + + private static AccessControlEntry downgradeAllowAllACL(AccessControlEntry entry) { + return new AccessControlEntry(entry.principal(), entry.host(), AclOperation.READ, entry.permissionType()); + } + + AclBinding targetAclBinding(AclBinding sourceAclBinding) { + String targetTopic = formatRemoteTopic(sourceAclBinding.pattern().name()); + final AccessControlEntry entry; + if (sourceAclBinding.entry().permissionType() == AclPermissionType.ALLOW + && sourceAclBinding.entry().operation() == AclOperation.ALL) { + entry = downgradeAllowAllACL(sourceAclBinding.entry()); + } else { + entry = sourceAclBinding.entry(); + } + return new AclBinding(new ResourcePattern(ResourceType.TOPIC, targetTopic, PatternType.LITERAL), entry); + } + + boolean shouldReplicateTopic(String topic) { + return (topicFilter.shouldReplicateTopic(topic) || replicationPolicy.isHeartbeatsTopic(topic)) + && !replicationPolicy.isInternalTopic(topic) && !isCycle(topic); + } + + boolean shouldReplicateAcl(AclBinding aclBinding) { + return !(aclBinding.entry().permissionType() == AclPermissionType.ALLOW + && aclBinding.entry().operation() == AclOperation.WRITE); + } + + boolean shouldReplicateTopicConfigurationProperty(String property) { + return configPropertyFilter.shouldReplicateConfigProperty(property); + } + + // Recurse upstream to detect cycles, i.e. 
whether this topic is already on the target cluster + boolean isCycle(String topic) { + String source = replicationPolicy.topicSource(topic); + if (source == null) { + return false; + } else if (source.equals(sourceAndTarget.target())) { + return true; + } else { + String upstreamTopic = replicationPolicy.upstreamTopic(topic); + if (upstreamTopic.equals(topic)) { + // Extra check for IdentityReplicationPolicy and similar impls that don't prevent cycles. + return false; + } + return isCycle(upstreamTopic); + } + } + + String formatRemoteTopic(String topic) { + return replicationPolicy.formatRemoteTopic(sourceAndTarget.source(), topic); + } +} diff --git a/connect/mirror/src/main/java/org/apache/kafka/connect/mirror/MirrorSourceTask.java b/connect/mirror/src/main/java/org/apache/kafka/connect/mirror/MirrorSourceTask.java new file mode 100644 index 0000000..fb5c844 --- /dev/null +++ b/connect/mirror/src/main/java/org/apache/kafka/connect/mirror/MirrorSourceTask.java @@ -0,0 +1,297 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.mirror; + +import org.apache.kafka.connect.data.Schema; +import org.apache.kafka.connect.source.SourceTask; +import org.apache.kafka.connect.source.SourceRecord; +import org.apache.kafka.connect.header.Headers; +import org.apache.kafka.connect.header.ConnectHeaders; +import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.clients.consumer.ConsumerRecords; +import org.apache.kafka.clients.consumer.KafkaConsumer; +import org.apache.kafka.clients.producer.RecordMetadata; +import org.apache.kafka.clients.producer.ProducerRecord; +import org.apache.kafka.clients.producer.KafkaProducer; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.header.Header; +import org.apache.kafka.common.errors.WakeupException; +import org.apache.kafka.common.KafkaException; +import org.apache.kafka.common.utils.Utils; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Map; +import java.util.HashMap; +import java.util.List; +import java.util.Set; +import java.util.ArrayList; +import java.util.stream.Collectors; +import java.util.concurrent.Semaphore; +import java.time.Duration; + +/** Replicates a set of topic-partitions. 
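+ * Each task owns a dedicated consumer for its assigned source topic-partitions, hands the polled records to
+ * the Connect framework as SourceRecords, and, once downstream offsets are known in commitRecord(), emits
+ * offset-sync records to the internal offset-syncs topic via maybeSyncOffsets().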
*/ +public class MirrorSourceTask extends SourceTask { + + private static final Logger log = LoggerFactory.getLogger(MirrorSourceTask.class); + + private static final int MAX_OUTSTANDING_OFFSET_SYNCS = 10; + + private KafkaConsumer consumer; + private KafkaProducer offsetProducer; + private String sourceClusterAlias; + private String offsetSyncsTopic; + private Duration pollTimeout; + private long maxOffsetLag; + private Map partitionStates; + private ReplicationPolicy replicationPolicy; + private MirrorMetrics metrics; + private boolean stopping = false; + private Semaphore outstandingOffsetSyncs; + private Semaphore consumerAccess; + + public MirrorSourceTask() {} + + // for testing + MirrorSourceTask(KafkaConsumer consumer, MirrorMetrics metrics, String sourceClusterAlias, + ReplicationPolicy replicationPolicy, long maxOffsetLag) { + this.consumer = consumer; + this.metrics = metrics; + this.sourceClusterAlias = sourceClusterAlias; + this.replicationPolicy = replicationPolicy; + this.maxOffsetLag = maxOffsetLag; + consumerAccess = new Semaphore(1); + } + + @Override + public void start(Map props) { + MirrorTaskConfig config = new MirrorTaskConfig(props); + outstandingOffsetSyncs = new Semaphore(MAX_OUTSTANDING_OFFSET_SYNCS); + consumerAccess = new Semaphore(1); // let one thread at a time access the consumer + sourceClusterAlias = config.sourceClusterAlias(); + metrics = config.metrics(); + pollTimeout = config.consumerPollTimeout(); + maxOffsetLag = config.maxOffsetLag(); + replicationPolicy = config.replicationPolicy(); + partitionStates = new HashMap<>(); + offsetSyncsTopic = config.offsetSyncsTopic(); + consumer = MirrorUtils.newConsumer(config.sourceConsumerConfig()); + offsetProducer = MirrorUtils.newProducer(config.offsetSyncsTopicProducerConfig()); + Set taskTopicPartitions = config.taskTopicPartitions(); + Map topicPartitionOffsets = loadOffsets(taskTopicPartitions); + consumer.assign(topicPartitionOffsets.keySet()); + log.info("Starting with {} previously uncommitted partitions.", topicPartitionOffsets.entrySet().stream() + .filter(x -> x.getValue() == 0L).count()); + log.trace("Seeking offsets: {}", topicPartitionOffsets); + topicPartitionOffsets.forEach(consumer::seek); + log.info("{} replicating {} topic-partitions {}->{}: {}.", Thread.currentThread().getName(), + taskTopicPartitions.size(), sourceClusterAlias, config.targetClusterAlias(), taskTopicPartitions); + } + + @Override + public void commit() { + // nop + } + + @Override + public void stop() { + long start = System.currentTimeMillis(); + stopping = true; + consumer.wakeup(); + try { + consumerAccess.acquire(); + } catch (InterruptedException e) { + log.warn("Interrupted waiting for access to consumer. 
Will try closing anyway."); + } + Utils.closeQuietly(consumer, "source consumer"); + Utils.closeQuietly(offsetProducer, "offset producer"); + Utils.closeQuietly(metrics, "metrics"); + log.info("Stopping {} took {} ms.", Thread.currentThread().getName(), System.currentTimeMillis() - start); + } + + @Override + public String version() { + return "1"; + } + + @Override + public List poll() { + if (!consumerAccess.tryAcquire()) { + return null; + } + if (stopping) { + return null; + } + try { + ConsumerRecords records = consumer.poll(pollTimeout); + List sourceRecords = new ArrayList<>(records.count()); + for (ConsumerRecord record : records) { + SourceRecord converted = convertRecord(record); + sourceRecords.add(converted); + TopicPartition topicPartition = new TopicPartition(converted.topic(), converted.kafkaPartition()); + metrics.recordAge(topicPartition, System.currentTimeMillis() - record.timestamp()); + metrics.recordBytes(topicPartition, byteSize(record.value())); + } + if (sourceRecords.isEmpty()) { + // WorkerSourceTasks expects non-zero batch size + return null; + } else { + log.trace("Polled {} records from {}.", sourceRecords.size(), records.partitions()); + return sourceRecords; + } + } catch (WakeupException e) { + return null; + } catch (KafkaException e) { + log.warn("Failure during poll.", e); + return null; + } catch (Throwable e) { + log.error("Failure during poll.", e); + // allow Connect to deal with the exception + throw e; + } finally { + consumerAccess.release(); + } + } + + @Override + public void commitRecord(SourceRecord record, RecordMetadata metadata) { + try { + if (stopping) { + return; + } + if (!metadata.hasOffset()) { + log.error("RecordMetadata has no offset -- can't sync offsets for {}.", record.topic()); + return; + } + TopicPartition topicPartition = new TopicPartition(record.topic(), record.kafkaPartition()); + long latency = System.currentTimeMillis() - record.timestamp(); + metrics.countRecord(topicPartition); + metrics.replicationLatency(topicPartition, latency); + TopicPartition sourceTopicPartition = MirrorUtils.unwrapPartition(record.sourcePartition()); + long upstreamOffset = MirrorUtils.unwrapOffset(record.sourceOffset()); + long downstreamOffset = metadata.offset(); + maybeSyncOffsets(sourceTopicPartition, upstreamOffset, downstreamOffset); + } catch (Throwable e) { + log.warn("Failure committing record.", e); + } + } + + // updates partition state and sends OffsetSync if necessary + private void maybeSyncOffsets(TopicPartition topicPartition, long upstreamOffset, + long downstreamOffset) { + PartitionState partitionState = + partitionStates.computeIfAbsent(topicPartition, x -> new PartitionState(maxOffsetLag)); + if (partitionState.update(upstreamOffset, downstreamOffset)) { + sendOffsetSync(topicPartition, upstreamOffset, downstreamOffset); + } + } + + // sends OffsetSync record upstream to internal offsets topic + private void sendOffsetSync(TopicPartition topicPartition, long upstreamOffset, + long downstreamOffset) { + if (!outstandingOffsetSyncs.tryAcquire()) { + // Too many outstanding offset syncs. 
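+            // Dropping the sync rather than blocking keeps poll() and commitRecord() non-blocking;
+            // a later commit for the same partition can trigger another sync attempt.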
+ return; + } + OffsetSync offsetSync = new OffsetSync(topicPartition, upstreamOffset, downstreamOffset); + ProducerRecord record = new ProducerRecord<>(offsetSyncsTopic, 0, + offsetSync.recordKey(), offsetSync.recordValue()); + offsetProducer.send(record, (x, e) -> { + if (e != null) { + log.error("Failure sending offset sync.", e); + } else { + log.trace("Sync'd offsets for {}: {}=={}", topicPartition, + upstreamOffset, downstreamOffset); + } + outstandingOffsetSyncs.release(); + }); + } + + private Map loadOffsets(Set topicPartitions) { + return topicPartitions.stream().collect(Collectors.toMap(x -> x, this::loadOffset)); + } + + private Long loadOffset(TopicPartition topicPartition) { + Map wrappedPartition = MirrorUtils.wrapPartition(topicPartition, sourceClusterAlias); + Map wrappedOffset = context.offsetStorageReader().offset(wrappedPartition); + return MirrorUtils.unwrapOffset(wrappedOffset) + 1; + } + + // visible for testing + SourceRecord convertRecord(ConsumerRecord record) { + String targetTopic = formatRemoteTopic(record.topic()); + Headers headers = convertHeaders(record); + return new SourceRecord( + MirrorUtils.wrapPartition(new TopicPartition(record.topic(), record.partition()), sourceClusterAlias), + MirrorUtils.wrapOffset(record.offset()), + targetTopic, record.partition(), + Schema.OPTIONAL_BYTES_SCHEMA, record.key(), + Schema.BYTES_SCHEMA, record.value(), + record.timestamp(), headers); + } + + private Headers convertHeaders(ConsumerRecord record) { + ConnectHeaders headers = new ConnectHeaders(); + for (Header header : record.headers()) { + headers.addBytes(header.key(), header.value()); + } + return headers; + } + + private String formatRemoteTopic(String topic) { + return replicationPolicy.formatRemoteTopic(sourceClusterAlias, topic); + } + + private static int byteSize(byte[] bytes) { + if (bytes == null) { + return 0; + } else { + return bytes.length; + } + } + + static class PartitionState { + long previousUpstreamOffset = -1L; + long previousDownstreamOffset = -1L; + long lastSyncUpstreamOffset = -1L; + long lastSyncDownstreamOffset = -1L; + long maxOffsetLag; + + PartitionState(long maxOffsetLag) { + this.maxOffsetLag = maxOffsetLag; + } + + // true if we should emit an offset sync + boolean update(long upstreamOffset, long downstreamOffset) { + boolean shouldSyncOffsets = false; + long upstreamStep = upstreamOffset - lastSyncUpstreamOffset; + long downstreamTargetOffset = lastSyncDownstreamOffset + upstreamStep; + if (lastSyncDownstreamOffset == -1L + || downstreamOffset - downstreamTargetOffset >= maxOffsetLag + || upstreamOffset - previousUpstreamOffset != 1L + || downstreamOffset < previousDownstreamOffset) { + lastSyncUpstreamOffset = upstreamOffset; + lastSyncDownstreamOffset = downstreamOffset; + shouldSyncOffsets = true; + } + previousUpstreamOffset = upstreamOffset; + previousDownstreamOffset = downstreamOffset; + return shouldSyncOffsets; + } + } +} diff --git a/connect/mirror/src/main/java/org/apache/kafka/connect/mirror/MirrorTaskConfig.java b/connect/mirror/src/main/java/org/apache/kafka/connect/mirror/MirrorTaskConfig.java new file mode 100644 index 0000000..73024f5 --- /dev/null +++ b/connect/mirror/src/main/java/org/apache/kafka/connect/mirror/MirrorTaskConfig.java @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.mirror; + +import org.apache.kafka.common.config.ConfigDef; +import org.apache.kafka.common.TopicPartition; + +import java.util.Map; +import java.util.Set; +import java.util.List; +import java.util.HashSet; +import java.util.Collections; +import java.util.stream.Collectors; + +public class MirrorTaskConfig extends MirrorConnectorConfig { + + private static final String TASK_TOPIC_PARTITIONS_DOC = "Topic-partitions assigned to this task to replicate."; + private static final String TASK_CONSUMER_GROUPS_DOC = "Consumer groups assigned to this task to replicate."; + + public MirrorTaskConfig(Map props) { + super(TASK_CONFIG_DEF, props); + } + + Set taskTopicPartitions() { + List fields = getList(TASK_TOPIC_PARTITIONS); + if (fields == null || fields.isEmpty()) { + return Collections.emptySet(); + } + return fields.stream() + .map(MirrorUtils::decodeTopicPartition) + .collect(Collectors.toSet()); + } + + Set taskConsumerGroups() { + List fields = getList(TASK_CONSUMER_GROUPS); + if (fields == null || fields.isEmpty()) { + return Collections.emptySet(); + } + return new HashSet<>(fields); + } + + MirrorMetrics metrics() { + MirrorMetrics metrics = new MirrorMetrics(this); + metricsReporters().forEach(metrics::addReporter); + return metrics; + } + + protected static final ConfigDef TASK_CONFIG_DEF = new ConfigDef(CONNECTOR_CONFIG_DEF) + .define( + TASK_TOPIC_PARTITIONS, + ConfigDef.Type.LIST, + null, + ConfigDef.Importance.LOW, + TASK_TOPIC_PARTITIONS_DOC) + .define( + TASK_CONSUMER_GROUPS, + ConfigDef.Type.LIST, + null, + ConfigDef.Importance.LOW, + TASK_CONSUMER_GROUPS_DOC); +} diff --git a/connect/mirror/src/main/java/org/apache/kafka/connect/mirror/MirrorUtils.java b/connect/mirror/src/main/java/org/apache/kafka/connect/mirror/MirrorUtils.java new file mode 100644 index 0000000..f15dda8 --- /dev/null +++ b/connect/mirror/src/main/java/org/apache/kafka/connect/mirror/MirrorUtils.java @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.connect.mirror; + +import org.apache.kafka.clients.admin.NewTopic; +import org.apache.kafka.clients.producer.KafkaProducer; +import org.apache.kafka.clients.consumer.KafkaConsumer; +import org.apache.kafka.common.serialization.ByteArrayDeserializer; +import org.apache.kafka.common.serialization.ByteArraySerializer; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.connect.util.TopicAdmin; + +import java.util.Arrays; +import java.util.Map; +import java.util.List; +import java.util.HashMap; +import java.util.Collections; +import java.util.regex.Pattern; + +/** Internal utility methods. */ +final class MirrorUtils { + + // utility class + private MirrorUtils() {} + + static KafkaProducer newProducer(Map props) { + return new KafkaProducer<>(props, new ByteArraySerializer(), new ByteArraySerializer()); + } + + static KafkaConsumer newConsumer(Map props) { + return new KafkaConsumer<>(props, new ByteArrayDeserializer(), new ByteArrayDeserializer()); + } + + static String encodeTopicPartition(TopicPartition topicPartition) { + return topicPartition.toString(); + } + + static Map wrapPartition(TopicPartition topicPartition, String sourceClusterAlias) { + Map wrapped = new HashMap<>(); + wrapped.put("topic", topicPartition.topic()); + wrapped.put("partition", topicPartition.partition()); + wrapped.put("cluster", sourceClusterAlias); + return wrapped; + } + + static Map wrapOffset(long offset) { + return Collections.singletonMap("offset", offset); + } + + static TopicPartition unwrapPartition(Map wrapped) { + String topic = (String) wrapped.get("topic"); + int partition = (Integer) wrapped.get("partition"); + return new TopicPartition(topic, partition); + } + + static Long unwrapOffset(Map wrapped) { + if (wrapped == null || wrapped.get("offset") == null) { + return -1L; + } + return (Long) wrapped.get("offset"); + } + + static TopicPartition decodeTopicPartition(String topicPartitionString) { + int sep = topicPartitionString.lastIndexOf('-'); + String topic = topicPartitionString.substring(0, sep); + String partitionString = topicPartitionString.substring(sep + 1); + int partition = Integer.parseInt(partitionString); + return new TopicPartition(topic, partition); + } + + // returns null if given empty list + static Pattern compilePatternList(List fields) { + if (fields.isEmpty()) { + // The empty pattern matches _everything_, but a blank + // config property should match _nothing_. + return null; + } else { + String joined = String.join("|", fields); + return Pattern.compile(joined); + } + } + + static Pattern compilePatternList(String fields) { + return compilePatternList(Arrays.asList(fields.split("\\W*,\\W*"))); + } + + static void createCompactedTopic(String topicName, short partitions, short replicationFactor, Map adminProps) { + NewTopic topicDescription = TopicAdmin.defineTopic(topicName). + compacted(). + partitions(partitions). + replicationFactor(replicationFactor). 
+ build(); + + try (TopicAdmin admin = new TopicAdmin(adminProps)) { + admin.createTopics(topicDescription); + } + } + + static void createSinglePartitionCompactedTopic(String topicName, short replicationFactor, Map adminProps) { + createCompactedTopic(topicName, (short) 1, replicationFactor, adminProps); + } +} diff --git a/connect/mirror/src/main/java/org/apache/kafka/connect/mirror/OffsetSync.java b/connect/mirror/src/main/java/org/apache/kafka/connect/mirror/OffsetSync.java new file mode 100644 index 0000000..68e6441 --- /dev/null +++ b/connect/mirror/src/main/java/org/apache/kafka/connect/mirror/OffsetSync.java @@ -0,0 +1,120 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.mirror; + +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.protocol.types.Field; +import org.apache.kafka.common.protocol.types.Schema; +import org.apache.kafka.common.protocol.types.Struct; +import org.apache.kafka.common.protocol.types.Type; +import org.apache.kafka.clients.consumer.ConsumerRecord; + +import java.nio.ByteBuffer; + +public class OffsetSync { + public static final String TOPIC_KEY = "topic"; + public static final String PARTITION_KEY = "partition"; + public static final String UPSTREAM_OFFSET_KEY = "upstreamOffset"; + public static final String DOWNSTREAM_OFFSET_KEY = "offset"; + + public static final Schema VALUE_SCHEMA = new Schema( + new Field(UPSTREAM_OFFSET_KEY, Type.INT64), + new Field(DOWNSTREAM_OFFSET_KEY, Type.INT64)); + + public static final Schema KEY_SCHEMA = new Schema( + new Field(TOPIC_KEY, Type.STRING), + new Field(PARTITION_KEY, Type.INT32)); + + private TopicPartition topicPartition; + private long upstreamOffset; + private long downstreamOffset; + + public OffsetSync(TopicPartition topicPartition, long upstreamOffset, long downstreamOffset) { + this.topicPartition = topicPartition; + this.upstreamOffset = upstreamOffset; + this.downstreamOffset = downstreamOffset; + } + + public TopicPartition topicPartition() { + return topicPartition; + } + + public long upstreamOffset() { + return upstreamOffset; + } + + public long downstreamOffset() { + return downstreamOffset; + } + + @Override + public String toString() { + return String.format("OffsetSync{topicPartition=%s, upstreamOffset=%d, downstreamOffset=%d}", + topicPartition, upstreamOffset, downstreamOffset); + } + + ByteBuffer serializeValue() { + Struct struct = valueStruct(); + ByteBuffer buffer = ByteBuffer.allocate(VALUE_SCHEMA.sizeOf(struct)); + VALUE_SCHEMA.write(buffer, struct); + buffer.flip(); + return buffer; + } + + ByteBuffer serializeKey() { + Struct struct = keyStruct(); + ByteBuffer buffer = ByteBuffer.allocate(KEY_SCHEMA.sizeOf(struct)); + KEY_SCHEMA.write(buffer, struct); + buffer.flip(); + return 
buffer; + } + + public static OffsetSync deserializeRecord(ConsumerRecord record) { + Struct keyStruct = KEY_SCHEMA.read(ByteBuffer.wrap(record.key())); + String topic = keyStruct.getString(TOPIC_KEY); + int partition = keyStruct.getInt(PARTITION_KEY); + + Struct valueStruct = VALUE_SCHEMA.read(ByteBuffer.wrap(record.value())); + long upstreamOffset = valueStruct.getLong(UPSTREAM_OFFSET_KEY); + long downstreamOffset = valueStruct.getLong(DOWNSTREAM_OFFSET_KEY); + + return new OffsetSync(new TopicPartition(topic, partition), upstreamOffset, downstreamOffset); + } + + private Struct valueStruct() { + Struct struct = new Struct(VALUE_SCHEMA); + struct.set(UPSTREAM_OFFSET_KEY, upstreamOffset); + struct.set(DOWNSTREAM_OFFSET_KEY, downstreamOffset); + return struct; + } + + private Struct keyStruct() { + Struct struct = new Struct(KEY_SCHEMA); + struct.set(TOPIC_KEY, topicPartition.topic()); + struct.set(PARTITION_KEY, topicPartition.partition()); + return struct; + } + + byte[] recordKey() { + return serializeKey().array(); + } + + byte[] recordValue() { + return serializeValue().array(); + } +} + diff --git a/connect/mirror/src/main/java/org/apache/kafka/connect/mirror/OffsetSyncStore.java b/connect/mirror/src/main/java/org/apache/kafka/connect/mirror/OffsetSyncStore.java new file mode 100644 index 0000000..600dda4 --- /dev/null +++ b/connect/mirror/src/main/java/org/apache/kafka/connect/mirror/OffsetSyncStore.java @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.mirror; + +import org.apache.kafka.clients.consumer.KafkaConsumer; +import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.errors.WakeupException; +import org.apache.kafka.common.serialization.ByteArrayDeserializer; +import org.apache.kafka.common.utils.Utils; + +import java.util.Map; +import java.util.HashMap; +import java.util.Collections; +import java.time.Duration; + +/** Used internally by MirrorMaker. Stores offset syncs and performs offset translation. 
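+ *
+ * <p>Translation is linear between sync points: given the latest OffsetSync for a partition,
+ * say (upstreamOffset=100, downstreamOffset=200), an upstream offset of 150 maps to downstream
+ * offset 200 + (150 - 100) = 250. Upstream offsets older than the latest sync cannot be
+ * translated accurately and map to -1.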
*/ +class OffsetSyncStore implements AutoCloseable { + private KafkaConsumer consumer; + private Map offsetSyncs = new HashMap<>(); + private TopicPartition offsetSyncTopicPartition; + + OffsetSyncStore(MirrorConnectorConfig config) { + consumer = new KafkaConsumer<>(config.offsetSyncsTopicConsumerConfig(), + new ByteArrayDeserializer(), new ByteArrayDeserializer()); + offsetSyncTopicPartition = new TopicPartition(config.offsetSyncsTopic(), 0); + consumer.assign(Collections.singleton(offsetSyncTopicPartition)); + } + + // for testing + OffsetSyncStore(KafkaConsumer consumer, TopicPartition offsetSyncTopicPartition) { + this.consumer = consumer; + this.offsetSyncTopicPartition = offsetSyncTopicPartition; + } + + long translateDownstream(TopicPartition sourceTopicPartition, long upstreamOffset) { + OffsetSync offsetSync = latestOffsetSync(sourceTopicPartition); + if (offsetSync.upstreamOffset() > upstreamOffset) { + // Offset is too far in the past to translate accurately + return -1; + } + long upstreamStep = upstreamOffset - offsetSync.upstreamOffset(); + return offsetSync.downstreamOffset() + upstreamStep; + } + + // poll and handle records + synchronized void update(Duration pollTimeout) { + try { + consumer.poll(pollTimeout).forEach(this::handleRecord); + } catch (WakeupException e) { + // swallow + } + } + + public synchronized void close() { + consumer.wakeup(); + Utils.closeQuietly(consumer, "offset sync store consumer"); + } + + protected void handleRecord(ConsumerRecord record) { + OffsetSync offsetSync = OffsetSync.deserializeRecord(record); + TopicPartition sourceTopicPartition = offsetSync.topicPartition(); + offsetSyncs.put(sourceTopicPartition, offsetSync); + } + + private OffsetSync latestOffsetSync(TopicPartition topicPartition) { + return offsetSyncs.computeIfAbsent(topicPartition, x -> new OffsetSync(topicPartition, + -1, -1)); + } +} diff --git a/connect/mirror/src/main/java/org/apache/kafka/connect/mirror/Scheduler.java b/connect/mirror/src/main/java/org/apache/kafka/connect/mirror/Scheduler.java new file mode 100644 index 0000000..20f2ca7 --- /dev/null +++ b/connect/mirror/src/main/java/org/apache/kafka/connect/mirror/Scheduler.java @@ -0,0 +1,115 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.connect.mirror; + +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeoutException; +import java.time.Duration; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +class Scheduler implements AutoCloseable { + private static Logger log = LoggerFactory.getLogger(Scheduler.class); + + private final String name; + private final ScheduledExecutorService executor = Executors.newSingleThreadScheduledExecutor(); + private final Duration timeout; + private boolean closed = false; + + Scheduler(String name, Duration timeout) { + this.name = name; + this.timeout = timeout; + } + + Scheduler(Class clazz, Duration timeout) { + this("Scheduler for " + clazz.getSimpleName(), timeout); + } + + void scheduleRepeating(Task task, Duration interval, String description) { + if (interval.toMillis() < 0L) { + return; + } + executor.scheduleAtFixedRate(() -> executeThread(task, description), 0, interval.toMillis(), TimeUnit.MILLISECONDS); + } + + void scheduleRepeatingDelayed(Task task, Duration interval, String description) { + if (interval.toMillis() < 0L) { + return; + } + executor.scheduleAtFixedRate(() -> executeThread(task, description), interval.toMillis(), + interval.toMillis(), TimeUnit.MILLISECONDS); + } + + void execute(Task task, String description) { + try { + executor.submit(() -> executeThread(task, description)).get(timeout.toMillis(), TimeUnit.MILLISECONDS); + } catch (InterruptedException e) { + log.warn("{} was interrupted running task: {}", name, description); + } catch (TimeoutException e) { + log.error("{} timed out running task: {}", name, description); + } catch (Throwable e) { + log.error("{} caught exception in task: {}", name, description, e); + } + } + + public void close() { + closed = true; + executor.shutdown(); + try { + boolean terminated = executor.awaitTermination(timeout.toMillis(), TimeUnit.MILLISECONDS); + if (!terminated) { + log.error("{} timed out during shutdown of internal scheduler.", name); + } + } catch (InterruptedException e) { + log.warn("{} was interrupted during shutdown of internal scheduler.", name); + } + } + + interface Task { + void run() throws InterruptedException, ExecutionException; + } + + private void run(Task task, String description) { + try { + long start = System.currentTimeMillis(); + task.run(); + long elapsed = System.currentTimeMillis() - start; + log.info("{} took {} ms", description, elapsed); + if (elapsed > timeout.toMillis()) { + log.warn("{} took too long ({} ms) running task: {}", name, elapsed, description); + } + } catch (InterruptedException e) { + log.warn("{} was interrupted running task: {}", name, description); + } catch (Throwable e) { + log.error("{} caught exception in scheduled task: {}", name, description, e); + } + } + + private void executeThread(Task task, String description) { + Thread.currentThread().setName(name + "-" + description); + if (closed) { + log.info("{} skipping task due to shutdown: {}", name, description); + return; + } + run(task, description); + } +} + diff --git a/connect/mirror/src/main/java/org/apache/kafka/connect/mirror/TopicFilter.java b/connect/mirror/src/main/java/org/apache/kafka/connect/mirror/TopicFilter.java new file mode 100644 index 0000000..f13453f --- /dev/null +++ b/connect/mirror/src/main/java/org/apache/kafka/connect/mirror/TopicFilter.java @@ -0,0 +1,37 @@ +/* + * Licensed to the 
Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.connect.mirror; + +import org.apache.kafka.common.Configurable; +import org.apache.kafka.common.annotation.InterfaceStability; +import java.util.Map; + +/** Defines which topics should be replicated. */ +@InterfaceStability.Evolving +public interface TopicFilter extends Configurable, AutoCloseable { + + boolean shouldReplicateTopic(String topic); + + default void close() { + //nop + } + + default void configure(Map props) { + //nop + } +} diff --git a/connect/mirror/src/main/java/org/apache/kafka/connect/mirror/formatters/CheckpointFormatter.java b/connect/mirror/src/main/java/org/apache/kafka/connect/mirror/formatters/CheckpointFormatter.java new file mode 100644 index 0000000..33fe695 --- /dev/null +++ b/connect/mirror/src/main/java/org/apache/kafka/connect/mirror/formatters/CheckpointFormatter.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.mirror.formatters; + +import java.io.PrintStream; + +import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.common.MessageFormatter; +import org.apache.kafka.connect.mirror.Checkpoint; + +public class CheckpointFormatter implements MessageFormatter { + + @Override + public void writeTo(ConsumerRecord record, PrintStream output) { + output.println(Checkpoint.deserializeRecord(record)); + } +} diff --git a/connect/mirror/src/main/java/org/apache/kafka/connect/mirror/formatters/HeartbeatFormatter.java b/connect/mirror/src/main/java/org/apache/kafka/connect/mirror/formatters/HeartbeatFormatter.java new file mode 100644 index 0000000..a193dbe --- /dev/null +++ b/connect/mirror/src/main/java/org/apache/kafka/connect/mirror/formatters/HeartbeatFormatter.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.mirror.formatters; + +import java.io.PrintStream; + +import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.common.MessageFormatter; +import org.apache.kafka.connect.mirror.Heartbeat; + +public class HeartbeatFormatter implements MessageFormatter { + + @Override + public void writeTo(ConsumerRecord record, PrintStream output) { + output.println(Heartbeat.deserializeRecord(record)); + } +} diff --git a/connect/mirror/src/main/java/org/apache/kafka/connect/mirror/formatters/OffsetSyncFormatter.java b/connect/mirror/src/main/java/org/apache/kafka/connect/mirror/formatters/OffsetSyncFormatter.java new file mode 100644 index 0000000..dacae60 --- /dev/null +++ b/connect/mirror/src/main/java/org/apache/kafka/connect/mirror/formatters/OffsetSyncFormatter.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.mirror.formatters; + +import java.io.PrintStream; + +import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.common.MessageFormatter; +import org.apache.kafka.connect.mirror.OffsetSync; + +public class OffsetSyncFormatter implements MessageFormatter { + + @Override + public void writeTo(ConsumerRecord record, PrintStream output) { + output.println(OffsetSync.deserializeRecord(record)); + } +} diff --git a/connect/mirror/src/test/java/org/apache/kafka/connect/mirror/CheckpointTest.java b/connect/mirror/src/test/java/org/apache/kafka/connect/mirror/CheckpointTest.java new file mode 100644 index 0000000..f008f99 --- /dev/null +++ b/connect/mirror/src/test/java/org/apache/kafka/connect/mirror/CheckpointTest.java @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.mirror; + +import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.common.TopicPartition; + +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class CheckpointTest { + + @Test + public void testSerde() { + Checkpoint checkpoint = new Checkpoint("group-1", new TopicPartition("topic-2", 3), 4, 5, "metadata-6"); + byte[] key = checkpoint.recordKey(); + byte[] value = checkpoint.recordValue(); + ConsumerRecord record = new ConsumerRecord<>("any-topic", 7, 8, key, value); + Checkpoint deserialized = Checkpoint.deserializeRecord(record); + assertEquals(checkpoint.consumerGroupId(), deserialized.consumerGroupId(), + "Failure on checkpoint consumerGroupId serde"); + assertEquals(checkpoint.topicPartition(), deserialized.topicPartition(), + "Failure on checkpoint topicPartition serde"); + assertEquals(checkpoint.upstreamOffset(), deserialized.upstreamOffset(), + "Failure on checkpoint upstreamOffset serde"); + assertEquals(checkpoint.downstreamOffset(), deserialized.downstreamOffset(), + "Failure on checkpoint downstreamOffset serde"); + } +} diff --git a/connect/mirror/src/test/java/org/apache/kafka/connect/mirror/HeartbeatTest.java b/connect/mirror/src/test/java/org/apache/kafka/connect/mirror/HeartbeatTest.java new file mode 100644 index 0000000..723b0dc --- /dev/null +++ b/connect/mirror/src/test/java/org/apache/kafka/connect/mirror/HeartbeatTest.java @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.connect.mirror; + +import org.apache.kafka.clients.consumer.ConsumerRecord; + +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class HeartbeatTest { + + @Test + public void testSerde() { + Heartbeat heartbeat = new Heartbeat("source-1", "target-2", 1234567890L); + byte[] key = heartbeat.recordKey(); + byte[] value = heartbeat.recordValue(); + ConsumerRecord record = new ConsumerRecord<>("any-topic", 6, 7, key, value); + Heartbeat deserialized = Heartbeat.deserializeRecord(record); + assertEquals(heartbeat.sourceClusterAlias(), deserialized.sourceClusterAlias(), + "Failure on heartbeat sourceClusterAlias serde"); + assertEquals(heartbeat.targetClusterAlias(), deserialized.targetClusterAlias(), + "Failure on heartbeat targetClusterAlias serde"); + assertEquals(heartbeat.timestamp(), deserialized.timestamp(), + "Failure on heartbeat timestamp serde"); + } +} diff --git a/connect/mirror/src/test/java/org/apache/kafka/connect/mirror/MirrorCheckpointConnectorTest.java b/connect/mirror/src/test/java/org/apache/kafka/connect/mirror/MirrorCheckpointConnectorTest.java new file mode 100644 index 0000000..1391e76 --- /dev/null +++ b/connect/mirror/src/test/java/org/apache/kafka/connect/mirror/MirrorCheckpointConnectorTest.java @@ -0,0 +1,135 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.connect.mirror; + +import org.apache.kafka.clients.admin.ConsumerGroupListing; +import org.junit.jupiter.api.Test; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; + +import static org.apache.kafka.connect.mirror.TestUtils.makeProps; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.spy; + + +public class MirrorCheckpointConnectorTest { + + private static final String CONSUMER_GROUP = "consumer-group-1"; + + @Test + public void testMirrorCheckpointConnectorDisabled() { + // disable the checkpoint emission + MirrorConnectorConfig config = new MirrorConnectorConfig( + makeProps("emit.checkpoints.enabled", "false")); + + List knownConsumerGroups = new ArrayList<>(); + knownConsumerGroups.add(CONSUMER_GROUP); + // MirrorCheckpointConnector as minimum to run taskConfig() + MirrorCheckpointConnector connector = new MirrorCheckpointConnector(knownConsumerGroups, + config); + List> output = connector.taskConfigs(1); + // expect no task will be created + assertEquals(0, output.size(), "MirrorCheckpointConnector not disabled"); + } + + @Test + public void testMirrorCheckpointConnectorEnabled() { + // enable the checkpoint emission + MirrorConnectorConfig config = new MirrorConnectorConfig( + makeProps("emit.checkpoints.enabled", "true")); + + List knownConsumerGroups = new ArrayList<>(); + knownConsumerGroups.add(CONSUMER_GROUP); + // MirrorCheckpointConnector as minimum to run taskConfig() + MirrorCheckpointConnector connector = new MirrorCheckpointConnector(knownConsumerGroups, + config); + List> output = connector.taskConfigs(1); + // expect 1 task will be created + assertEquals(1, output.size(), + "MirrorCheckpointConnectorEnabled for " + CONSUMER_GROUP + " has incorrect size"); + assertEquals(CONSUMER_GROUP, output.get(0).get(MirrorConnectorConfig.TASK_CONSUMER_GROUPS), + "MirrorCheckpointConnectorEnabled for " + CONSUMER_GROUP + " failed"); + } + + @Test + public void testNoConsumerGroup() { + MirrorConnectorConfig config = new MirrorConnectorConfig(makeProps()); + MirrorCheckpointConnector connector = new MirrorCheckpointConnector(new ArrayList<>(), config); + List> output = connector.taskConfigs(1); + // expect no task will be created + assertEquals(0, output.size(), "ConsumerGroup shouldn't exist"); + } + + @Test + public void testReplicationDisabled() { + // disable the replication + MirrorConnectorConfig config = new MirrorConnectorConfig(makeProps("enabled", "false")); + + List knownConsumerGroups = new ArrayList<>(); + knownConsumerGroups.add(CONSUMER_GROUP); + // MirrorCheckpointConnector as minimum to run taskConfig() + MirrorCheckpointConnector connector = new MirrorCheckpointConnector(knownConsumerGroups, config); + List> output = connector.taskConfigs(1); + // expect no task will be created + assertEquals(0, output.size(), "Replication isn't disabled"); + } + + @Test + public void testReplicationEnabled() { + // enable the replication + MirrorConnectorConfig config = new MirrorConnectorConfig(makeProps("enabled", "true")); + + List knownConsumerGroups = new ArrayList<>(); + knownConsumerGroups.add(CONSUMER_GROUP); + // MirrorCheckpointConnector as minimum to run taskConfig() + MirrorCheckpointConnector 
connector = new MirrorCheckpointConnector(knownConsumerGroups, config); + List> output = connector.taskConfigs(1); + // expect 1 task will be created + assertEquals(1, output.size(), "Replication for consumer-group-1 has incorrect size"); + assertEquals(CONSUMER_GROUP, output.get(0).get(MirrorConnectorConfig.TASK_CONSUMER_GROUPS), + "Replication for consumer-group-1 failed"); + } + + @Test + public void testFindConsumerGroups() throws Exception { + MirrorConnectorConfig config = new MirrorConnectorConfig(makeProps()); + MirrorCheckpointConnector connector = new MirrorCheckpointConnector(Collections.emptyList(), config); + connector = spy(connector); + + Collection groups = Arrays.asList( + new ConsumerGroupListing("g1", true), + new ConsumerGroupListing("g2", false)); + doReturn(groups).when(connector).listConsumerGroups(); + doReturn(true).when(connector).shouldReplicate(anyString()); + List groupFound = connector.findConsumerGroups(); + + Set expectedGroups = groups.stream().map(ConsumerGroupListing::groupId).collect(Collectors.toSet()); + assertEquals(expectedGroups, new HashSet<>(groupFound), + "Expected groups are not the same as findConsumerGroups"); + } + +} diff --git a/connect/mirror/src/test/java/org/apache/kafka/connect/mirror/MirrorCheckpointTaskTest.java b/connect/mirror/src/test/java/org/apache/kafka/connect/mirror/MirrorCheckpointTaskTest.java new file mode 100644 index 0000000..7ef878a --- /dev/null +++ b/connect/mirror/src/test/java/org/apache/kafka/connect/mirror/MirrorCheckpointTaskTest.java @@ -0,0 +1,141 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.connect.mirror; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Collections; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.clients.consumer.OffsetAndMetadata; +import org.apache.kafka.connect.source.SourceRecord; + +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class MirrorCheckpointTaskTest { + + @Test + public void testDownstreamTopicRenaming() { + MirrorCheckpointTask mirrorCheckpointTask = new MirrorCheckpointTask("source1", "target2", + new DefaultReplicationPolicy(), null, Collections.emptyMap(), Collections.emptyMap()); + assertEquals(new TopicPartition("source1.topic3", 4), + mirrorCheckpointTask.renameTopicPartition(new TopicPartition("topic3", 4)), + "Renaming source1.topic3 failed"); + assertEquals(new TopicPartition("topic3", 5), + mirrorCheckpointTask.renameTopicPartition(new TopicPartition("target2.topic3", 5)), + "Renaming target2.topic3 failed"); + assertEquals(new TopicPartition("source1.source6.topic7", 8), + mirrorCheckpointTask.renameTopicPartition(new TopicPartition("source6.topic7", 8)), + "Renaming source1.source6.topic7 failed"); + } + + @Test + public void testCheckpoint() { + OffsetSyncStoreTest.FakeOffsetSyncStore offsetSyncStore = new OffsetSyncStoreTest.FakeOffsetSyncStore(); + MirrorCheckpointTask mirrorCheckpointTask = new MirrorCheckpointTask("source1", "target2", + new DefaultReplicationPolicy(), offsetSyncStore, Collections.emptyMap(), Collections.emptyMap()); + offsetSyncStore.sync(new TopicPartition("topic1", 2), 3L, 4L); + offsetSyncStore.sync(new TopicPartition("target2.topic5", 6), 7L, 8L); + Checkpoint checkpoint1 = mirrorCheckpointTask.checkpoint("group9", new TopicPartition("topic1", 2), + new OffsetAndMetadata(10, null)); + SourceRecord sourceRecord1 = mirrorCheckpointTask.checkpointRecord(checkpoint1, 123L); + assertEquals(new TopicPartition("source1.topic1", 2), checkpoint1.topicPartition(), + "checkpoint group9 source1.topic1 failed"); + assertEquals("group9", checkpoint1.consumerGroupId(), + "checkpoint group9 consumerGroupId failed"); + assertEquals("group9", Checkpoint.unwrapGroup(sourceRecord1.sourcePartition()), + "checkpoint group9 sourcePartition failed"); + assertEquals(10, checkpoint1.upstreamOffset(), + "checkpoint group9 upstreamOffset failed"); + assertEquals(11, checkpoint1.downstreamOffset(), + "checkpoint group9 downstreamOffset failed"); + assertEquals(123L, sourceRecord1.timestamp().longValue(), + "checkpoint group9 timestamp failed"); + Checkpoint checkpoint2 = mirrorCheckpointTask.checkpoint("group11", new TopicPartition("target2.topic5", 6), + new OffsetAndMetadata(12, null)); + SourceRecord sourceRecord2 = mirrorCheckpointTask.checkpointRecord(checkpoint2, 234L); + assertEquals(new TopicPartition("topic5", 6), checkpoint2.topicPartition(), + "checkpoint group11 topic5 failed"); + assertEquals("group11", checkpoint2.consumerGroupId(), + "checkpoint group11 consumerGroupId failed"); + assertEquals("group11", Checkpoint.unwrapGroup(sourceRecord2.sourcePartition()), + "checkpoint group11 sourcePartition failed"); + assertEquals(12, checkpoint2.upstreamOffset(), + "checkpoint group11 upstreamOffset failed"); + assertEquals(13, checkpoint2.downstreamOffset(), + "checkpoint group11 downstreamOffset failed"); + assertEquals(234L, sourceRecord2.timestamp().longValue(), + "checkpoint group11 timestamp failed"); + } + + @Test + public void 
testSyncOffset() { + Map> idleConsumerGroupsOffset = new HashMap<>(); + Map> checkpointsPerConsumerGroup = new HashMap<>(); + + String consumer1 = "consumer1"; + String consumer2 = "consumer2"; + + String topic1 = "topic1"; + String topic2 = "topic2"; + + // 'c1t1' denotes consumer offsets of all partitions of topic1 for consumer1 + Map c1t1 = new HashMap<>(); + // 't1p0' denotes topic1, partition 0 + TopicPartition t1p0 = new TopicPartition(topic1, 0); + + c1t1.put(t1p0, new OffsetAndMetadata(100)); + + Map c2t2 = new HashMap<>(); + TopicPartition t2p0 = new TopicPartition(topic2, 0); + + c2t2.put(t2p0, new OffsetAndMetadata(50)); + + idleConsumerGroupsOffset.put(consumer1, c1t1); + idleConsumerGroupsOffset.put(consumer2, c2t2); + + // 'cpC1T1P0' denotes 'checkpoint' of topic1, partition 0 for consumer1 + Checkpoint cpC1T1P0 = new Checkpoint(consumer1, new TopicPartition(topic1, 0), 200, 101, "metadata"); + + // 'cpC2T2p0' denotes 'checkpoint' of topic2, partition 0 for consumer2 + Checkpoint cpC2T2P0 = new Checkpoint(consumer2, new TopicPartition(topic2, 0), 100, 51, "metadata"); + + // 'checkpointListC1' denotes 'checkpoint' list for consumer1 + List checkpointListC1 = new ArrayList<>(); + checkpointListC1.add(cpC1T1P0); + + // 'checkpointListC2' denotes 'checkpoint' list for consumer2 + List checkpointListC2 = new ArrayList<>(); + checkpointListC2.add(cpC2T2P0); + + checkpointsPerConsumerGroup.put(consumer1, checkpointListC1); + checkpointsPerConsumerGroup.put(consumer2, checkpointListC2); + + MirrorCheckpointTask mirrorCheckpointTask = new MirrorCheckpointTask("source1", "target2", + new DefaultReplicationPolicy(), null, idleConsumerGroupsOffset, checkpointsPerConsumerGroup); + + Map> output = mirrorCheckpointTask.syncGroupOffset(); + + assertEquals(101, output.get(consumer1).get(t1p0).offset(), + "Consumer 1 " + topic1 + " failed"); + assertEquals(51, output.get(consumer2).get(t2p0).offset(), + "Consumer 2 " + topic2 + " failed"); + } +} diff --git a/connect/mirror/src/test/java/org/apache/kafka/connect/mirror/MirrorConnectorConfigTest.java b/connect/mirror/src/test/java/org/apache/kafka/connect/mirror/MirrorConnectorConfigTest.java new file mode 100644 index 0000000..c7f629e --- /dev/null +++ b/connect/mirror/src/test/java/org/apache/kafka/connect/mirror/MirrorConnectorConfigTest.java @@ -0,0 +1,330 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.connect.mirror; + +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.config.ConfigDef; +import org.apache.kafka.common.config.ConfigException; +import org.junit.jupiter.api.Test; + +import java.util.Arrays; +import java.util.Collection; +import java.util.List; +import java.util.Map; +import java.util.HashMap; +import java.util.HashSet; + +import static org.apache.kafka.connect.mirror.TestUtils.makeProps; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class MirrorConnectorConfigTest { + + @Test + public void testTaskConfigTopicPartitions() { + List topicPartitions = Arrays.asList(new TopicPartition("topic-1", 2), + new TopicPartition("topic-3", 4), new TopicPartition("topic-5", 6)); + MirrorConnectorConfig config = new MirrorConnectorConfig(makeProps()); + Map props = config.taskConfigForTopicPartitions(topicPartitions); + MirrorTaskConfig taskConfig = new MirrorTaskConfig(props); + assertEquals(taskConfig.taskTopicPartitions(), new HashSet<>(topicPartitions), + "Setting topic property configuration failed"); + } + + @Test + public void testTaskConfigConsumerGroups() { + List groups = Arrays.asList("consumer-1", "consumer-2", "consumer-3"); + MirrorConnectorConfig config = new MirrorConnectorConfig(makeProps()); + Map props = config.taskConfigForConsumerGroups(groups); + MirrorTaskConfig taskConfig = new MirrorTaskConfig(props); + assertEquals(taskConfig.taskConsumerGroups(), new HashSet<>(groups), + "Setting consumer groups property configuration failed"); + } + + @Test + public void testTopicMatching() { + MirrorConnectorConfig config = new MirrorConnectorConfig(makeProps("topics", "topic1")); + assertTrue(config.topicFilter().shouldReplicateTopic("topic1"), + "topic1 replication property configuration failed"); + assertFalse(config.topicFilter().shouldReplicateTopic("topic2"), + "topic2 replication property configuration failed"); + } + + @Test + public void testGroupMatching() { + MirrorConnectorConfig config = new MirrorConnectorConfig(makeProps("groups", "group1")); + assertTrue(config.groupFilter().shouldReplicateGroup("group1"), + "topic1 group matching property configuration failed"); + assertFalse(config.groupFilter().shouldReplicateGroup("group2"), + "topic2 group matching property configuration failed"); + } + + @Test + public void testConfigPropertyMatching() { + MirrorConnectorConfig config = new MirrorConnectorConfig( + makeProps("config.properties.exclude", "prop2")); + assertTrue(config.configPropertyFilter().shouldReplicateConfigProperty("prop1"), + "config.properties.exclude incorrectly excluded prop1"); + assertFalse(config.configPropertyFilter().shouldReplicateConfigProperty("prop2"), + "config.properties.exclude incorrectly included prop2"); + } + + @Test + public void testConfigBackwardsCompatibility() { + MirrorConnectorConfig config = new MirrorConnectorConfig( + makeProps("config.properties.blacklist", "prop1", + "groups.blacklist", "group-1", + "topics.blacklist", "topic-1")); + assertFalse(config.configPropertyFilter().shouldReplicateConfigProperty("prop1")); + assertTrue(config.configPropertyFilter().shouldReplicateConfigProperty("prop2")); + assertFalse(config.topicFilter().shouldReplicateTopic("topic-1")); + assertTrue(config.topicFilter().shouldReplicateTopic("topic-2")); + 
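+        // The deprecated groups.blacklist property likewise continues to act as the
+        // consumer-group exclusion filter.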
assertFalse(config.groupFilter().shouldReplicateGroup("group-1")); + assertTrue(config.groupFilter().shouldReplicateGroup("group-2")); + } + + @Test + public void testNoTopics() { + MirrorConnectorConfig config = new MirrorConnectorConfig(makeProps("topics", "")); + assertFalse(config.topicFilter().shouldReplicateTopic("topic1"), "topic1 shouldn't exist"); + assertFalse(config.topicFilter().shouldReplicateTopic("topic2"), "topic2 shouldn't exist"); + assertFalse(config.topicFilter().shouldReplicateTopic(""), "Empty topic shouldn't exist"); + } + + @Test + public void testAllTopics() { + MirrorConnectorConfig config = new MirrorConnectorConfig(makeProps("topics", ".*")); + assertTrue(config.topicFilter().shouldReplicateTopic("topic1"), + "topic1 created from wildcard should exist"); + assertTrue(config.topicFilter().shouldReplicateTopic("topic2"), + "topic2 created from wildcard should exist"); + } + + @Test + public void testListOfTopics() { + MirrorConnectorConfig config = new MirrorConnectorConfig(makeProps("topics", "topic1, topic2")); + assertTrue(config.topicFilter().shouldReplicateTopic("topic1"), "topic1 created from list should exist"); + assertTrue(config.topicFilter().shouldReplicateTopic("topic2"), "topic2 created from list should exist"); + assertFalse(config.topicFilter().shouldReplicateTopic("topic3"), "topic3 created from list should exist"); + } + + @Test + public void testNonMutationOfConfigDef() { + Collection taskSpecificProperties = Arrays.asList( + MirrorConnectorConfig.TASK_TOPIC_PARTITIONS, + MirrorConnectorConfig.TASK_CONSUMER_GROUPS + ); + + // Sanity check to make sure that these properties are actually defined for the task config, + // and that the task config class has been loaded and statically initialized by the JVM + ConfigDef taskConfigDef = MirrorTaskConfig.TASK_CONFIG_DEF; + taskSpecificProperties.forEach(taskSpecificProperty -> assertTrue( + taskConfigDef.names().contains(taskSpecificProperty), + taskSpecificProperty + " should be defined for task ConfigDef" + )); + + // Ensure that the task config class hasn't accidentally modified the connector config + ConfigDef connectorConfigDef = MirrorConnectorConfig.CONNECTOR_CONFIG_DEF; + taskSpecificProperties.forEach(taskSpecificProperty -> assertFalse( + connectorConfigDef.names().contains(taskSpecificProperty), + taskSpecificProperty + " should not be defined for connector ConfigDef" + )); + } + + @Test + public void testSourceConsumerConfig() { + Map connectorProps = makeProps( + MirrorConnectorConfig.CONSUMER_CLIENT_PREFIX + "max.poll.interval.ms", "120000" + ); + MirrorConnectorConfig config = new MirrorConnectorConfig(connectorProps); + Map connectorConsumerProps = config.sourceConsumerConfig(); + Map expectedConsumerProps = new HashMap<>(); + expectedConsumerProps.put("enable.auto.commit", "false"); + expectedConsumerProps.put("auto.offset.reset", "earliest"); + expectedConsumerProps.put("max.poll.interval.ms", "120000"); + assertEquals(expectedConsumerProps, connectorConsumerProps); + + // checking auto.offset.reset override works + connectorProps = makeProps( + MirrorConnectorConfig.CONSUMER_CLIENT_PREFIX + "auto.offset.reset", "latest" + ); + config = new MirrorConnectorConfig(connectorProps); + connectorConsumerProps = config.sourceConsumerConfig(); + expectedConsumerProps.put("auto.offset.reset", "latest"); + expectedConsumerProps.remove("max.poll.interval.ms"); + assertEquals(expectedConsumerProps, connectorConsumerProps, + MirrorConnectorConfig.CONSUMER_CLIENT_PREFIX + " source consumer config not 
matching"); + } + + @Test + public void testSourceConsumerConfigWithSourcePrefix() { + String prefix = MirrorConnectorConfig.SOURCE_PREFIX + MirrorConnectorConfig.CONSUMER_CLIENT_PREFIX; + Map connectorProps = makeProps( + prefix + "auto.offset.reset", "latest", + prefix + "max.poll.interval.ms", "100" + ); + MirrorConnectorConfig config = new MirrorConnectorConfig(connectorProps); + Map connectorConsumerProps = config.sourceConsumerConfig(); + Map expectedConsumerProps = new HashMap<>(); + expectedConsumerProps.put("enable.auto.commit", "false"); + expectedConsumerProps.put("auto.offset.reset", "latest"); + expectedConsumerProps.put("max.poll.interval.ms", "100"); + assertEquals(expectedConsumerProps, connectorConsumerProps, + prefix + " source consumer config not matching"); + } + + @Test + public void testSourceProducerConfig() { + Map connectorProps = makeProps( + MirrorConnectorConfig.PRODUCER_CLIENT_PREFIX + "acks", "1" + ); + MirrorConnectorConfig config = new MirrorConnectorConfig(connectorProps); + Map connectorProducerProps = config.sourceProducerConfig(); + Map expectedProducerProps = new HashMap<>(); + expectedProducerProps.put("acks", "1"); + assertEquals(expectedProducerProps, connectorProducerProps, + MirrorConnectorConfig.PRODUCER_CLIENT_PREFIX + " source product config not matching"); + } + + @Test + public void testSourceProducerConfigWithSourcePrefix() { + String prefix = MirrorConnectorConfig.SOURCE_PREFIX + MirrorConnectorConfig.PRODUCER_CLIENT_PREFIX; + Map connectorProps = makeProps(prefix + "acks", "1"); + MirrorConnectorConfig config = new MirrorConnectorConfig(connectorProps); + Map connectorProducerProps = config.sourceProducerConfig(); + Map expectedProducerProps = new HashMap<>(); + expectedProducerProps.put("acks", "1"); + assertEquals(expectedProducerProps, connectorProducerProps, + prefix + " source producer config not matching"); + } + + @Test + public void testSourceAdminConfig() { + Map connectorProps = makeProps( + MirrorConnectorConfig.ADMIN_CLIENT_PREFIX + + "connections.max.idle.ms", "10000" + ); + MirrorConnectorConfig config = new MirrorConnectorConfig(connectorProps); + Map connectorAdminProps = config.sourceAdminConfig(); + Map expectedAdminProps = new HashMap<>(); + expectedAdminProps.put("connections.max.idle.ms", "10000"); + assertEquals(expectedAdminProps, connectorAdminProps, + MirrorConnectorConfig.ADMIN_CLIENT_PREFIX + " source connector admin props not matching"); + } + + @Test + public void testSourceAdminConfigWithSourcePrefix() { + String prefix = MirrorConnectorConfig.SOURCE_PREFIX + MirrorConnectorConfig.ADMIN_CLIENT_PREFIX; + Map connectorProps = makeProps(prefix + "connections.max.idle.ms", "10000"); + MirrorConnectorConfig config = new MirrorConnectorConfig(connectorProps); + Map connectorAdminProps = config.sourceAdminConfig(); + Map expectedAdminProps = new HashMap<>(); + expectedAdminProps.put("connections.max.idle.ms", "10000"); + assertEquals(expectedAdminProps, connectorAdminProps, prefix + " source connector admin props not matching"); + } + + @Test + public void testTargetAdminConfig() { + Map connectorProps = makeProps( + MirrorConnectorConfig.ADMIN_CLIENT_PREFIX + + "connections.max.idle.ms", "10000" + ); + MirrorConnectorConfig config = new MirrorConnectorConfig(connectorProps); + Map connectorAdminProps = config.targetAdminConfig(); + Map expectedAdminProps = new HashMap<>(); + expectedAdminProps.put("connections.max.idle.ms", "10000"); + assertEquals(expectedAdminProps, connectorAdminProps, + 
MirrorConnectorConfig.ADMIN_CLIENT_PREFIX + " target connector admin props not matching"); + } + + @Test + public void testTargetAdminConfigWithSourcePrefix() { + String prefix = MirrorConnectorConfig.TARGET_PREFIX + MirrorConnectorConfig.ADMIN_CLIENT_PREFIX; + Map connectorProps = makeProps(prefix + "connections.max.idle.ms", "10000"); + MirrorConnectorConfig config = new MirrorConnectorConfig(connectorProps); + Map connectorAdminProps = config.targetAdminConfig(); + Map expectedAdminProps = new HashMap<>(); + expectedAdminProps.put("connections.max.idle.ms", "10000"); + assertEquals(expectedAdminProps, connectorAdminProps, prefix + " source connector admin props not matching"); + } + + @Test + public void testOffsetSyncsTopic() { + // Invalid location + Map connectorProps = makeProps("offset-syncs.topic.location", "something"); + assertThrows(ConfigException.class, () -> new MirrorConnectorConfig(connectorProps)); + + connectorProps.put("offset-syncs.topic.location", "source"); + MirrorConnectorConfig config = new MirrorConnectorConfig(connectorProps); + assertEquals("mm2-offset-syncs.target2.internal", config.offsetSyncsTopic()); + connectorProps.put("offset-syncs.topic.location", "target"); + config = new MirrorConnectorConfig(connectorProps); + assertEquals("mm2-offset-syncs.source1.internal", config.offsetSyncsTopic()); + // Default to source + connectorProps.remove("offset-syncs.topic.location"); + config = new MirrorConnectorConfig(connectorProps); + assertEquals("mm2-offset-syncs.target2.internal", config.offsetSyncsTopic()); + } + + @Test + public void testConsumerConfigsForOffsetSyncsTopic() { + Map connectorProps = makeProps( + "source.consumer.max.partition.fetch.bytes", "1", + "target.consumer.heartbeat.interval.ms", "1", + "consumer.max.poll.interval.ms", "1", + "fetch.min.bytes", "1" + ); + MirrorConnectorConfig config = new MirrorConnectorConfig(connectorProps); + assertEquals(config.sourceConsumerConfig(), config.offsetSyncsTopicConsumerConfig()); + connectorProps.put("offset-syncs.topic.location", "target"); + config = new MirrorConnectorConfig(connectorProps); + assertEquals(config.targetConsumerConfig(), config.offsetSyncsTopicConsumerConfig()); + } + + @Test + public void testProducerConfigsForOffsetSyncsTopic() { + Map connectorProps = makeProps( + "source.producer.batch.size", "1", + "target.producer.acks", "1", + "producer.max.poll.interval.ms", "1", + "fetch.min.bytes", "1" + ); + MirrorConnectorConfig config = new MirrorConnectorConfig(connectorProps); + assertEquals(config.sourceProducerConfig(), config.offsetSyncsTopicProducerConfig()); + connectorProps.put("offset-syncs.topic.location", "target"); + config = new MirrorConnectorConfig(connectorProps); + assertEquals(config.targetProducerConfig(), config.offsetSyncsTopicProducerConfig()); + } + + @Test + public void testAdminConfigsForOffsetSyncsTopic() { + Map connectorProps = makeProps( + "source.admin.request.timeout.ms", "1", + "target.admin.send.buffer.bytes", "1", + "admin.reconnect.backoff.max.ms", "1", + "retries", "123" + ); + MirrorConnectorConfig config = new MirrorConnectorConfig(connectorProps); + assertEquals(config.sourceAdminConfig(), config.offsetSyncsTopicAdminConfig()); + connectorProps.put("offset-syncs.topic.location", "target"); + config = new MirrorConnectorConfig(connectorProps); + assertEquals(config.targetAdminConfig(), config.offsetSyncsTopicAdminConfig()); + } + +} diff --git a/connect/mirror/src/test/java/org/apache/kafka/connect/mirror/MirrorHeartBeatConnectorTest.java 
b/connect/mirror/src/test/java/org/apache/kafka/connect/mirror/MirrorHeartBeatConnectorTest.java new file mode 100644 index 0000000..ec06919 --- /dev/null +++ b/connect/mirror/src/test/java/org/apache/kafka/connect/mirror/MirrorHeartBeatConnectorTest.java @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.mirror; + +import static org.apache.kafka.connect.mirror.TestUtils.makeProps; +import static org.junit.jupiter.api.Assertions.assertEquals; +import java.util.List; +import java.util.Map; +import org.junit.jupiter.api.Test; + +public class MirrorHeartBeatConnectorTest { + + @Test + public void testMirrorHeartbeatConnectorDisabled() { + // disable the heartbeat emission + MirrorConnectorConfig config = new MirrorConnectorConfig( + makeProps("emit.heartbeats.enabled", "false")); + + // MirrorHeartbeatConnector as minimum to run taskConfig() + MirrorHeartbeatConnector connector = new MirrorHeartbeatConnector(config); + List> output = connector.taskConfigs(1); + // expect no task will be created + assertEquals(0, output.size(), "Expected task to not be created"); + } + + @Test + public void testReplicationDisabled() { + // disable the replication + MirrorConnectorConfig config = new MirrorConnectorConfig( + makeProps("enabled", "false")); + + // MirrorHeartbeatConnector as minimum to run taskConfig() + MirrorHeartbeatConnector connector = new MirrorHeartbeatConnector(config); + List> output = connector.taskConfigs(1); + // expect one task will be created, even the replication is disabled + assertEquals(1, output.size(), "Task should have been created even with replication disabled"); + } +} diff --git a/connect/mirror/src/test/java/org/apache/kafka/connect/mirror/MirrorHeartbeatTaskTest.java b/connect/mirror/src/test/java/org/apache/kafka/connect/mirror/MirrorHeartbeatTaskTest.java new file mode 100644 index 0000000..39fd6df --- /dev/null +++ b/connect/mirror/src/test/java/org/apache/kafka/connect/mirror/MirrorHeartbeatTaskTest.java @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.mirror; + +import org.apache.kafka.connect.source.SourceRecord; +import org.junit.jupiter.api.Test; + +import java.util.List; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class MirrorHeartbeatTaskTest { + + @Test + public void testPollCreatesRecords() throws InterruptedException { + MirrorHeartbeatTask heartbeatTask = new MirrorHeartbeatTask(); + heartbeatTask.start(TestUtils.makeProps("source.cluster.alias", "testSource", + "target.cluster.alias", "testTarget")); + List records = heartbeatTask.poll(); + assertEquals(1, records.size()); + Map sourcePartition = records.iterator().next().sourcePartition(); + assertEquals(sourcePartition.get(Heartbeat.SOURCE_CLUSTER_ALIAS_KEY), "testSource", + "sourcePartition's " + Heartbeat.SOURCE_CLUSTER_ALIAS_KEY + " record was not created"); + assertEquals(sourcePartition.get(Heartbeat.TARGET_CLUSTER_ALIAS_KEY), "testTarget", + "sourcePartition's " + Heartbeat.TARGET_CLUSTER_ALIAS_KEY + " record was not created"); + } +} diff --git a/connect/mirror/src/test/java/org/apache/kafka/connect/mirror/MirrorMakerConfigTest.java b/connect/mirror/src/test/java/org/apache/kafka/connect/mirror/MirrorMakerConfigTest.java new file mode 100644 index 0000000..41bcacb --- /dev/null +++ b/connect/mirror/src/test/java/org/apache/kafka/connect/mirror/MirrorMakerConfigTest.java @@ -0,0 +1,355 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.mirror; + +import org.apache.kafka.common.config.types.Password; +import org.apache.kafka.common.config.provider.ConfigProvider; +import org.apache.kafka.common.config.ConfigData; +import org.apache.kafka.common.metrics.FakeMetricsReporter; + +import org.junit.jupiter.api.Test; + +import java.util.Map; +import java.util.Set; +import java.util.Collections; +import java.util.HashMap; +import java.util.Arrays; +import java.util.List; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; + +public class MirrorMakerConfigTest { + + private Map makeProps(String... 
keyValues) { + Map props = new HashMap<>(); + for (int i = 0; i < keyValues.length; i += 2) { + props.put(keyValues[i], keyValues[i + 1]); + } + return props; + } + + @Test + public void testClusterConfigProperties() { + MirrorMakerConfig mirrorConfig = new MirrorMakerConfig(makeProps( + "clusters", "a, b", + "a.bootstrap.servers", "servers-one", + "b.bootstrap.servers", "servers-two", + "security.protocol", "SASL", + "replication.factor", "4")); + Map connectorProps = mirrorConfig.connectorBaseConfig(new SourceAndTarget("a", "b"), + MirrorSourceConnector.class); + assertEquals("servers-one", connectorProps.get("source.cluster.bootstrap.servers"), + "source.cluster.bootstrap.servers is set"); + assertEquals("servers-two", connectorProps.get("target.cluster.bootstrap.servers"), + "target.cluster.bootstrap.servers is set"); + assertEquals("SASL", connectorProps.get("security.protocol"), + "top-level security.protocol is passed through to connector config"); + } + + @Test + public void testReplicationConfigProperties() { + MirrorMakerConfig mirrorConfig = new MirrorMakerConfig(makeProps( + "clusters", "a, b", + "a->b.tasks.max", "123")); + Map connectorProps = mirrorConfig.connectorBaseConfig(new SourceAndTarget("a", "b"), + MirrorSourceConnector.class); + assertEquals("123", connectorProps.get("tasks.max"), "connector props should include tasks.max"); + } + + @Test + public void testClientConfigProperties() { + MirrorMakerConfig mirrorConfig = new MirrorMakerConfig(makeProps( + "clusters", "a, b", + "config.providers", "fake", + "config.providers.fake.class", FakeConfigProvider.class.getName(), + "replication.policy.separator", "__", + "ssl.truststore.password", "secret1", + "ssl.key.password", "${fake:secret:password}", // resolves to "secret2" + "security.protocol", "SSL", + "a.security.protocol", "PLAINTEXT", + "a.producer.security.protocol", "SASL", + "a.bootstrap.servers", "one:9092, two:9092", + "metrics.reporter", FakeMetricsReporter.class.getName(), + "a.metrics.reporter", FakeMetricsReporter.class.getName(), + "b->a.metrics.reporter", FakeMetricsReporter.class.getName(), + "a.xxx", "yyy", + "xxx", "zzz")); + MirrorClientConfig aClientConfig = mirrorConfig.clientConfig("a"); + MirrorClientConfig bClientConfig = mirrorConfig.clientConfig("b"); + assertEquals("__", aClientConfig.getString("replication.policy.separator"), + "replication.policy.separator is picked up in MirrorClientConfig"); + assertEquals("b__topic1", aClientConfig.replicationPolicy().formatRemoteTopic("b", "topic1"), + "replication.policy.separator is honored"); + assertEquals("one:9092, two:9092", aClientConfig.adminConfig().get("bootstrap.servers"), + "client configs include bootstrap.servers"); + assertEquals("PLAINTEXT", aClientConfig.adminConfig().get("security.protocol"), + "client configs include security.protocol"); + assertEquals("SASL", aClientConfig.producerConfig().get("security.protocol"), + "producer configs include security.protocol"); + assertFalse(aClientConfig.adminConfig().containsKey("xxx"), + "unknown properties aren't included in client configs"); + assertFalse(aClientConfig.adminConfig().containsKey("metric.reporters"), + "top-level metrics reporters aren't included in client configs"); + assertEquals("secret1", aClientConfig.getPassword("ssl.truststore.password").value(), + "security properties are picked up in MirrorClientConfig"); + assertEquals("secret1", ((Password) aClientConfig.adminConfig().get("ssl.truststore.password")).value(), + "client configs include top-level security properties"); + 
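        // Minimal illustration of the ConfigProvider indirection relied on above (illustrative only,
        // not part of the original test): "${fake:secret:password}" names the provider registered as
        // "fake", the path "secret" and the key "password", which the FakeConfigProvider defined at
        // the bottom of this class maps to "secret2"; resolving it by hand looks roughly like this:
        ConfigProvider fakeProviderSketch = new FakeConfigProvider();
        assertEquals("secret2", fakeProviderSketch.get("secret").data().get("password"));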
assertEquals("secret2", aClientConfig.getPassword("ssl.key.password").value(), + "security properties are translated from external sources"); + assertEquals("secret2", ((Password) aClientConfig.adminConfig().get("ssl.key.password")).value(), + "client configs are translated from external sources"); + assertFalse(aClientConfig.producerConfig().containsKey("metrics.reporter"), + "client configs should not include metrics reporter"); + assertFalse(bClientConfig.adminConfig().containsKey("metrics.reporter"), + "client configs should not include metrics reporter"); + } + + @Test + public void testIncludesConnectorConfigProperties() { + MirrorMakerConfig mirrorConfig = new MirrorMakerConfig(makeProps( + "clusters", "a, b", + "tasks.max", "100", + "topics", "topic-1", + "groups", "group-2", + "replication.policy.separator", "__", + "config.properties.exclude", "property-3", + "metric.reporters", "FakeMetricsReporter", + "topic.filter.class", DefaultTopicFilter.class.getName(), + "xxx", "yyy")); + SourceAndTarget sourceAndTarget = new SourceAndTarget("source", "target"); + Map connectorProps = mirrorConfig.connectorBaseConfig(sourceAndTarget, + MirrorSourceConnector.class); + MirrorConnectorConfig connectorConfig = new MirrorConnectorConfig(connectorProps); + assertEquals(100, (int) connectorConfig.getInt("tasks.max"), + "Connector properties like tasks.max should be passed through to underlying Connectors."); + assertEquals(Collections.singletonList("topic-1"), connectorConfig.getList("topics"), + "Topics include should be passed through to underlying Connectors."); + assertEquals(Collections.singletonList("group-2"), connectorConfig.getList("groups"), + "Groups include should be passed through to underlying Connectors."); + assertEquals(Collections.singletonList("property-3"), connectorConfig.getList("config.properties.exclude"), + "Config properties exclude should be passed through to underlying Connectors."); + assertEquals(Collections.singletonList("FakeMetricsReporter"), connectorConfig.getList("metric.reporters"), + "Metrics reporters should be passed through to underlying Connectors."); + assertEquals("DefaultTopicFilter", connectorConfig.getClass("topic.filter.class").getSimpleName(), + "Filters should be passed through to underlying Connectors."); + assertEquals("__", connectorConfig.getString("replication.policy.separator"), + "replication policy separator should be passed through to underlying Connectors."); + assertFalse(connectorConfig.originals().containsKey("xxx"), + "Unknown properties should not be passed through to Connectors."); + } + + @Test + public void testConfigBackwardsCompatibility() { + MirrorMakerConfig mirrorConfig = new MirrorMakerConfig(makeProps( + "clusters", "a, b", + "groups.blacklist", "group-7", + "topics.blacklist", "topic3", + "config.properties.blacklist", "property-3", + "topic.filter.class", DefaultTopicFilter.class.getName())); + SourceAndTarget sourceAndTarget = new SourceAndTarget("source", "target"); + Map connectorProps = mirrorConfig.connectorBaseConfig(sourceAndTarget, + MirrorSourceConnector.class); + MirrorConnectorConfig connectorConfig = new MirrorConnectorConfig(connectorProps); + DefaultTopicFilter.TopicFilterConfig filterConfig = + new DefaultTopicFilter.TopicFilterConfig(connectorProps); + + assertEquals(Collections.singletonList("topic3"), filterConfig.getList("topics.exclude"), + "Topics exclude should be backwards compatible."); + + assertEquals(Collections.singletonList("group-7"), connectorConfig.getList("groups.exclude"), + "Groups 
exclude should be backwards compatible."); + + assertEquals(Collections.singletonList("property-3"), connectorConfig.getList("config.properties.exclude"), + "Config properties exclude should be backwards compatible."); + + } + + @Test + public void testConfigBackwardsCompatibilitySourceTarget() { + MirrorMakerConfig mirrorConfig = new MirrorMakerConfig(makeProps( + "clusters", "a, b", + "source->target.topics.blacklist", "topic3", + "source->target.groups.blacklist", "group-7", + "topic.filter.class", DefaultTopicFilter.class.getName())); + SourceAndTarget sourceAndTarget = new SourceAndTarget("source", "target"); + Map connectorProps = mirrorConfig.connectorBaseConfig(sourceAndTarget, + MirrorSourceConnector.class); + MirrorConnectorConfig connectorConfig = new MirrorConnectorConfig(connectorProps); + DefaultTopicFilter.TopicFilterConfig filterConfig = + new DefaultTopicFilter.TopicFilterConfig(connectorProps); + + assertEquals(Collections.singletonList("topic3"), filterConfig.getList("topics.exclude"), + "Topics exclude should be backwards compatible."); + + assertEquals(Collections.singletonList("group-7"), connectorConfig.getList("groups.exclude"), + "Groups exclude should be backwards compatible."); + } + + @Test + public void testIncludesTopicFilterProperties() { + MirrorMakerConfig mirrorConfig = new MirrorMakerConfig(makeProps( + "clusters", "a, b", + "source->target.topics", "topic1, topic2", + "source->target.topics.exclude", "topic3")); + SourceAndTarget sourceAndTarget = new SourceAndTarget("source", "target"); + Map connectorProps = mirrorConfig.connectorBaseConfig(sourceAndTarget, + MirrorSourceConnector.class); + DefaultTopicFilter.TopicFilterConfig filterConfig = + new DefaultTopicFilter.TopicFilterConfig(connectorProps); + assertEquals(Arrays.asList("topic1", "topic2"), filterConfig.getList("topics"), + "source->target.topics should be passed through to TopicFilters."); + assertEquals(Collections.singletonList("topic3"), filterConfig.getList("topics.exclude"), + "source->target.topics.exclude should be passed through to TopicFilters."); + } + + @Test + public void testWorkerConfigs() { + MirrorMakerConfig mirrorConfig = new MirrorMakerConfig(makeProps( + "clusters", "a, b", + "config.providers", "fake", + "config.providers.fake.class", FakeConfigProvider.class.getName(), + "replication.policy.separator", "__", + "offset.storage.replication.factor", "123", + "b.status.storage.replication.factor", "456", + "b.producer.client.id", "client-one", + "b.security.protocol", "PLAINTEXT", + "b.producer.security.protocol", "SASL", + "ssl.truststore.password", "secret1", + "ssl.key.password", "${fake:secret:password}", // resolves to "secret2" + "b.xxx", "yyy")); + SourceAndTarget a = new SourceAndTarget("b", "a"); + SourceAndTarget b = new SourceAndTarget("a", "b"); + Map aProps = mirrorConfig.workerConfig(a); + assertEquals("123", aProps.get("offset.storage.replication.factor")); + assertEquals("__", aProps.get("replication.policy.separator")); + Map bProps = mirrorConfig.workerConfig(b); + assertEquals("456", bProps.get("status.storage.replication.factor")); + assertEquals("client-one", bProps.get("producer.client.id"), + "producer props should be passed through to worker producer config: " + bProps); + assertEquals("SASL", bProps.get("producer.security.protocol"), + "replication-level security props should be passed through to worker producer config"); + assertEquals("SASL", bProps.get("producer.security.protocol"), + "replication-level security props should be passed through to 
worker producer config"); + assertEquals("PLAINTEXT", bProps.get("consumer.security.protocol"), + "replication-level security props should be passed through to worker consumer config"); + assertEquals("secret1", bProps.get("ssl.truststore.password"), + "security properties should be passed through to worker config: " + bProps); + assertEquals("secret1", bProps.get("producer.ssl.truststore.password"), + "security properties should be passed through to worker producer config: " + bProps); + assertEquals("secret2", bProps.get("ssl.key.password"), + "security properties should be transformed in worker config"); + assertEquals("secret2", bProps.get("producer.ssl.key.password"), + "security properties should be transformed in worker producer config"); + assertEquals("__", bProps.get("replication.policy.separator")); + } + + @Test + public void testClusterPairsWithDefaultSettings() { + MirrorMakerConfig mirrorConfig = new MirrorMakerConfig(makeProps( + "clusters", "a, b, c")); + // implicit configuration associated + // a->b.enabled=false + // a->b.emit.heartbeat.enabled=true + // a->c.enabled=false + // a->c.emit.heartbeat.enabled=true + // b->a.enabled=false + // b->a.emit.heartbeat.enabled=true + // b->c.enabled=false + // b->c.emit.heartbeat.enabled=true + // c->a.enabled=false + // c->a.emit.heartbeat.enabled=true + // c->b.enabled=false + // c->b.emit.heartbeat.enabled=true + List clusterPairs = mirrorConfig.clusterPairs(); + assertEquals(6, clusterPairs.size(), "clusterPairs count should match all combinations count"); + } + + @Test + public void testEmptyClusterPairsWithGloballyDisabledHeartbeats() { + MirrorMakerConfig mirrorConfig = new MirrorMakerConfig(makeProps( + "clusters", "a, b, c", + "emit.heartbeats.enabled", "false")); + assertEquals(0, mirrorConfig.clusterPairs().size(), "clusterPairs count should be 0"); + } + + @Test + public void testClusterPairsWithTwoDisabledHeartbeats() { + MirrorMakerConfig mirrorConfig = new MirrorMakerConfig(makeProps( + "clusters", "a, b, c", + "a->b.emit.heartbeats.enabled", "false", + "a->c.emit.heartbeats.enabled", "false")); + List clusterPairs = mirrorConfig.clusterPairs(); + assertEquals(4, clusterPairs.size(), + "clusterPairs count should match all combinations count except x->y.emit.heartbeats.enabled=false"); + } + + @Test + public void testClusterPairsWithGloballyDisabledHeartbeats() { + MirrorMakerConfig mirrorConfig = new MirrorMakerConfig(makeProps( + "clusters", "a, b, c, d, e, f", + "emit.heartbeats.enabled", "false", + "a->b.enabled", "true", + "a->c.enabled", "true", + "a->d.enabled", "true", + "a->e.enabled", "false", + "a->f.enabled", "false")); + List clusterPairs = mirrorConfig.clusterPairs(); + assertEquals(3, clusterPairs.size(), + "clusterPairs count should match (x->y.enabled=true or x->y.emit.heartbeats.enabled=true) count"); + + // Link b->a.enabled doesn't exist therefore it must not be in clusterPairs + SourceAndTarget sourceAndTarget = new SourceAndTarget("b", "a"); + assertFalse(clusterPairs.contains(sourceAndTarget), "disabled/unset link x->y should not be in clusterPairs"); + } + + @Test + public void testClusterPairsWithGloballyDisabledHeartbeatsCentralLocal() { + MirrorMakerConfig mirrorConfig = new MirrorMakerConfig(makeProps( + "clusters", "central, local_one, local_two, beats_emitter", + "emit.heartbeats.enabled", "false", + "central->local_one.enabled", "true", + "central->local_two.enabled", "true", + "beats_emitter->central.emit.heartbeats.enabled", "true")); + + assertEquals(3, 
mirrorConfig.clusterPairs().size(), + "clusterPairs count should match (x->y.enabled=true or x->y.emit.heartbeats.enabled=true) count"); + } + + public static class FakeConfigProvider implements ConfigProvider { + + Map secrets = Collections.singletonMap("password", "secret2"); + + @Override + public void configure(Map props) { + } + + @Override + public void close() { + } + + @Override + public ConfigData get(String path) { + return new ConfigData(secrets); + } + + @Override + public ConfigData get(String path, Set keys) { + return get(path); + } + } +} diff --git a/connect/mirror/src/test/java/org/apache/kafka/connect/mirror/MirrorSourceConnectorTest.java b/connect/mirror/src/test/java/org/apache/kafka/connect/mirror/MirrorSourceConnectorTest.java new file mode 100644 index 0000000..b4c8ca6 --- /dev/null +++ b/connect/mirror/src/test/java/org/apache/kafka/connect/mirror/MirrorSourceConnectorTest.java @@ -0,0 +1,276 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.mirror; + +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.acl.AccessControlEntry; +import org.apache.kafka.common.acl.AclBinding; +import org.apache.kafka.common.acl.AclOperation; +import org.apache.kafka.common.acl.AclPermissionType; +import org.apache.kafka.common.resource.PatternType; +import org.apache.kafka.common.resource.ResourcePattern; +import org.apache.kafka.common.resource.ResourceType; +import org.apache.kafka.clients.admin.Config; +import org.apache.kafka.connect.connector.ConnectorContext; +import org.apache.kafka.clients.admin.ConfigEntry; +import org.apache.kafka.clients.admin.NewTopic; + +import org.junit.jupiter.api.Test; + +import static org.apache.kafka.connect.mirror.MirrorConnectorConfig.TASK_TOPIC_PARTITIONS; +import static org.apache.kafka.connect.mirror.TestUtils.makeProps; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.mockito.Mockito.any; +import static org.mockito.Mockito.doNothing; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +public class MirrorSourceConnectorTest { + + @Test + public void testReplicatesHeartbeatsByDefault() { + MirrorSourceConnector connector = new MirrorSourceConnector(new SourceAndTarget("source", "target"), + new DefaultReplicationPolicy(), new 
DefaultTopicFilter(), new DefaultConfigPropertyFilter()); + assertTrue(connector.shouldReplicateTopic("heartbeats"), "should replicate heartbeats"); + assertTrue(connector.shouldReplicateTopic("us-west.heartbeats"), "should replicate upstream heartbeats"); + } + + @Test + public void testReplicatesHeartbeatsDespiteFilter() { + MirrorSourceConnector connector = new MirrorSourceConnector(new SourceAndTarget("source", "target"), + new DefaultReplicationPolicy(), x -> false, new DefaultConfigPropertyFilter()); + assertTrue(connector.shouldReplicateTopic("heartbeats"), "should replicate heartbeats"); + assertTrue(connector.shouldReplicateTopic("us-west.heartbeats"), "should replicate upstream heartbeats"); + } + + @Test + public void testNoCycles() { + MirrorSourceConnector connector = new MirrorSourceConnector(new SourceAndTarget("source", "target"), + new DefaultReplicationPolicy(), x -> true, x -> true); + assertFalse(connector.shouldReplicateTopic("target.topic1"), "should not allow cycles"); + assertFalse(connector.shouldReplicateTopic("target.source.topic1"), "should not allow cycles"); + assertFalse(connector.shouldReplicateTopic("source.target.topic1"), "should not allow cycles"); + assertFalse(connector.shouldReplicateTopic("target.source.target.topic1"), "should not allow cycles"); + assertFalse(connector.shouldReplicateTopic("source.target.source.topic1"), "should not allow cycles"); + assertTrue(connector.shouldReplicateTopic("topic1"), "should allow anything else"); + assertTrue(connector.shouldReplicateTopic("source.topic1"), "should allow anything else"); + } + + @Test + public void testIdentityReplication() { + MirrorSourceConnector connector = new MirrorSourceConnector(new SourceAndTarget("source", "target"), + new IdentityReplicationPolicy(), x -> true, x -> true); + assertTrue(connector.shouldReplicateTopic("target.topic1"), "should allow cycles"); + assertTrue(connector.shouldReplicateTopic("target.source.topic1"), "should allow cycles"); + assertTrue(connector.shouldReplicateTopic("source.target.topic1"), "should allow cycles"); + assertTrue(connector.shouldReplicateTopic("target.source.target.topic1"), "should allow cycles"); + assertTrue(connector.shouldReplicateTopic("source.target.source.topic1"), "should allow cycles"); + assertTrue(connector.shouldReplicateTopic("topic1"), "should allow normal topics"); + assertTrue(connector.shouldReplicateTopic("othersource.topic1"), "should allow normal topics"); + assertFalse(connector.shouldReplicateTopic("target.heartbeats"), "should not allow heartbeat cycles"); + assertFalse(connector.shouldReplicateTopic("target.source.heartbeats"), "should not allow heartbeat cycles"); + assertFalse(connector.shouldReplicateTopic("source.target.heartbeats"), "should not allow heartbeat cycles"); + assertFalse(connector.shouldReplicateTopic("target.source.target.heartbeats"), "should not allow heartbeat cycles"); + assertFalse(connector.shouldReplicateTopic("source.target.source.heartbeats"), "should not allow heartbeat cycles"); + assertTrue(connector.shouldReplicateTopic("heartbeats"), "should allow heartbeat topics"); + assertTrue(connector.shouldReplicateTopic("othersource.heartbeats"), "should allow heartbeat topics"); + } + + @Test + public void testAclFiltering() { + MirrorSourceConnector connector = new MirrorSourceConnector(new SourceAndTarget("source", "target"), + new DefaultReplicationPolicy(), x -> true, x -> true); + assertFalse(connector.shouldReplicateAcl( + new AclBinding(new ResourcePattern(ResourceType.TOPIC, "test_topic", 
PatternType.LITERAL), + new AccessControlEntry("kafka", "", AclOperation.WRITE, AclPermissionType.ALLOW))), "should not replicate ALLOW WRITE"); + assertTrue(connector.shouldReplicateAcl( + new AclBinding(new ResourcePattern(ResourceType.TOPIC, "test_topic", PatternType.LITERAL), + new AccessControlEntry("kafka", "", AclOperation.ALL, AclPermissionType.ALLOW))), "should replicate ALLOW ALL"); + } + + @Test + public void testAclTransformation() { + MirrorSourceConnector connector = new MirrorSourceConnector(new SourceAndTarget("source", "target"), + new DefaultReplicationPolicy(), x -> true, x -> true); + AclBinding allowAllAclBinding = new AclBinding( + new ResourcePattern(ResourceType.TOPIC, "test_topic", PatternType.LITERAL), + new AccessControlEntry("kafka", "", AclOperation.ALL, AclPermissionType.ALLOW)); + AclBinding processedAllowAllAclBinding = connector.targetAclBinding(allowAllAclBinding); + String expectedRemoteTopicName = "source" + DefaultReplicationPolicy.SEPARATOR_DEFAULT + + allowAllAclBinding.pattern().name(); + assertEquals(expectedRemoteTopicName, processedAllowAllAclBinding.pattern().name(), "should change topic name"); + assertEquals(processedAllowAllAclBinding.entry().operation(), AclOperation.READ, "should change ALL to READ"); + assertEquals(processedAllowAllAclBinding.entry().permissionType(), AclPermissionType.ALLOW, "should not change ALLOW"); + + AclBinding denyAllAclBinding = new AclBinding( + new ResourcePattern(ResourceType.TOPIC, "test_topic", PatternType.LITERAL), + new AccessControlEntry("kafka", "", AclOperation.ALL, AclPermissionType.DENY)); + AclBinding processedDenyAllAclBinding = connector.targetAclBinding(denyAllAclBinding); + assertEquals(processedDenyAllAclBinding.entry().operation(), AclOperation.ALL, "should not change ALL"); + assertEquals(processedDenyAllAclBinding.entry().permissionType(), AclPermissionType.DENY, "should not change DENY"); + } + + @Test + public void testConfigPropertyFiltering() { + MirrorSourceConnector connector = new MirrorSourceConnector(new SourceAndTarget("source", "target"), + new DefaultReplicationPolicy(), x -> true, new DefaultConfigPropertyFilter()); + ArrayList entries = new ArrayList<>(); + entries.add(new ConfigEntry("name-1", "value-1")); + entries.add(new ConfigEntry("min.insync.replicas", "2")); + Config config = new Config(entries); + Config targetConfig = connector.targetConfig(config); + assertTrue(targetConfig.entries().stream() + .anyMatch(x -> x.name().equals("name-1")), "should replicate properties"); + assertFalse(targetConfig.entries().stream() + .anyMatch(x -> x.name().equals("min.insync.replicas")), "should not replicate excluded properties"); + } + + @Test + public void testMirrorSourceConnectorTaskConfig() { + List knownSourceTopicPartitions = new ArrayList<>(); + + // topic `t0` has 8 partitions + knownSourceTopicPartitions.add(new TopicPartition("t0", 0)); + knownSourceTopicPartitions.add(new TopicPartition("t0", 1)); + knownSourceTopicPartitions.add(new TopicPartition("t0", 2)); + knownSourceTopicPartitions.add(new TopicPartition("t0", 3)); + knownSourceTopicPartitions.add(new TopicPartition("t0", 4)); + knownSourceTopicPartitions.add(new TopicPartition("t0", 5)); + knownSourceTopicPartitions.add(new TopicPartition("t0", 6)); + knownSourceTopicPartitions.add(new TopicPartition("t0", 7)); + + // topic `t1` has 2 partitions + knownSourceTopicPartitions.add(new TopicPartition("t1", 0)); + knownSourceTopicPartitions.add(new TopicPartition("t1", 1)); + + // topic `t2` has 2 partitions + 
knownSourceTopicPartitions.add(new TopicPartition("t2", 0)); + knownSourceTopicPartitions.add(new TopicPartition("t2", 1)); + + // MirrorConnectorConfig example for test + MirrorConnectorConfig config = new MirrorConnectorConfig(makeProps()); + + // MirrorSourceConnector as minimum to run taskConfig() + MirrorSourceConnector connector = new MirrorSourceConnector(knownSourceTopicPartitions, config); + + // distribute the topic-partition to 3 tasks by round-robin + List> output = connector.taskConfigs(3); + + // the expected assignments over 3 tasks: + // t1 -> [t0p0, t0p3, t0p6, t1p1] + // t2 -> [t0p1, t0p4, t0p7, t2p0] + // t3 -> [t0p2, t0p5, t1p0, t2p1] + + Map t1 = output.get(0); + assertEquals("t0-0,t0-3,t0-6,t1-1", t1.get(TASK_TOPIC_PARTITIONS), "Config for t1 is incorrect"); + + Map t2 = output.get(1); + assertEquals("t0-1,t0-4,t0-7,t2-0", t2.get(TASK_TOPIC_PARTITIONS), "Config for t2 is incorrect"); + + Map t3 = output.get(2); + assertEquals("t0-2,t0-5,t1-0,t2-1", t3.get(TASK_TOPIC_PARTITIONS), "Config for t3 is incorrect"); + } + + @Test + public void testRefreshTopicPartitions() throws Exception { + MirrorSourceConnector connector = new MirrorSourceConnector(new SourceAndTarget("source", "target"), + new DefaultReplicationPolicy(), new DefaultTopicFilter(), new DefaultConfigPropertyFilter()); + connector.initialize(mock(ConnectorContext.class)); + connector = spy(connector); + + Config topicConfig = new Config(Arrays.asList( + new ConfigEntry("cleanup.policy", "compact"), + new ConfigEntry("segment.bytes", "100"))); + Map configs = Collections.singletonMap("topic", topicConfig); + + List sourceTopicPartitions = Collections.singletonList(new TopicPartition("topic", 0)); + doReturn(sourceTopicPartitions).when(connector).findSourceTopicPartitions(); + doReturn(Collections.emptyList()).when(connector).findTargetTopicPartitions(); + doReturn(configs).when(connector).describeTopicConfigs(Collections.singleton("topic")); + doNothing().when(connector).createNewTopics(any()); + + connector.refreshTopicPartitions(); + // if target topic is not created, refreshTopicPartitions() will call createTopicPartitions() again + connector.refreshTopicPartitions(); + + Map expectedPartitionCounts = new HashMap<>(); + expectedPartitionCounts.put("source.topic", 1L); + Map configMap = MirrorSourceConnector.configToMap(topicConfig); + assertEquals(2, configMap.size(), "configMap has incorrect size"); + + Map expectedNewTopics = new HashMap<>(); + expectedNewTopics.put("source.topic", new NewTopic("source.topic", 1, (short) 0).configs(configMap)); + + verify(connector, times(2)).computeAndCreateTopicPartitions(); + verify(connector, times(2)).createNewTopics(eq(expectedNewTopics)); + verify(connector, times(0)).createNewPartitions(any()); + + List targetTopicPartitions = Collections.singletonList(new TopicPartition("source.topic", 0)); + doReturn(targetTopicPartitions).when(connector).findTargetTopicPartitions(); + connector.refreshTopicPartitions(); + + // once target topic is created, refreshTopicPartitions() will NOT call computeAndCreateTopicPartitions() again + verify(connector, times(2)).computeAndCreateTopicPartitions(); + } + + @Test + public void testRefreshTopicPartitionsTopicOnTargetFirst() throws Exception { + MirrorSourceConnector connector = new MirrorSourceConnector(new SourceAndTarget("source", "target"), + new DefaultReplicationPolicy(), new DefaultTopicFilter(), new DefaultConfigPropertyFilter()); + connector.initialize(mock(ConnectorContext.class)); + connector = spy(connector); + + 
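        // As in the previous test, the connector is wrapped in a Mockito spy so that only the
        // cluster-facing lookups are stubbed; a minimal sketch of the idiom used below
        // (the topic name here is hypothetical):
        //   doReturn(Collections.singletonList(new TopicPartition("example-topic", 0)))
        //           .when(connector).findSourceTopicPartitions();
        // This lets refreshTopicPartitions() be driven end-to-end without a running Kafka cluster.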
Config topicConfig = new Config(Arrays.asList( + new ConfigEntry("cleanup.policy", "compact"), + new ConfigEntry("segment.bytes", "100"))); + Map configs = Collections.singletonMap("source.topic", topicConfig); + + List sourceTopicPartitions = Collections.emptyList(); + List targetTopicPartitions = Collections.singletonList(new TopicPartition("source.topic", 0)); + doReturn(sourceTopicPartitions).when(connector).findSourceTopicPartitions(); + doReturn(targetTopicPartitions).when(connector).findTargetTopicPartitions(); + doReturn(configs).when(connector).describeTopicConfigs(Collections.singleton("source.topic")); + doReturn(Collections.emptyMap()).when(connector).describeTopicConfigs(Collections.emptySet()); + doNothing().when(connector).createNewTopics(any()); + doNothing().when(connector).createNewPartitions(any()); + + // partitions appearing on the target cluster should not cause reconfiguration + connector.refreshTopicPartitions(); + connector.refreshTopicPartitions(); + verify(connector, times(0)).computeAndCreateTopicPartitions(); + + sourceTopicPartitions = Collections.singletonList(new TopicPartition("topic", 0)); + doReturn(sourceTopicPartitions).when(connector).findSourceTopicPartitions(); + + // when partitions are added to the source cluster, reconfiguration is triggered + connector.refreshTopicPartitions(); + verify(connector, times(1)).computeAndCreateTopicPartitions(); + } +} diff --git a/connect/mirror/src/test/java/org/apache/kafka/connect/mirror/MirrorSourceTaskTest.java b/connect/mirror/src/test/java/org/apache/kafka/connect/mirror/MirrorSourceTaskTest.java new file mode 100644 index 0000000..feb2f7f --- /dev/null +++ b/connect/mirror/src/test/java/org/apache/kafka/connect/mirror/MirrorSourceTaskTest.java @@ -0,0 +1,174 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.connect.mirror; + +import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.clients.consumer.ConsumerRecords; +import org.apache.kafka.clients.consumer.KafkaConsumer; +import org.apache.kafka.common.header.Header; +import org.apache.kafka.common.header.internals.RecordHeader; +import org.apache.kafka.common.record.TimestampType; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.header.Headers; +import org.apache.kafka.common.header.internals.RecordHeaders; +import org.apache.kafka.connect.source.SourceRecord; + +import org.junit.jupiter.api.Test; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Optional; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class MirrorSourceTaskTest { + + @Test + public void testSerde() { + byte[] key = new byte[]{'a', 'b', 'c', 'd', 'e'}; + byte[] value = new byte[]{'f', 'g', 'h', 'i', 'j', 'k'}; + Headers headers = new RecordHeaders(); + headers.add("header1", new byte[]{'l', 'm', 'n', 'o'}); + headers.add("header2", new byte[]{'p', 'q', 'r', 's', 't'}); + ConsumerRecord consumerRecord = new ConsumerRecord<>("topic1", 2, 3L, 4L, + TimestampType.CREATE_TIME, 5, 6, key, value, headers, Optional.empty()); + MirrorSourceTask mirrorSourceTask = new MirrorSourceTask(null, null, "cluster7", + new DefaultReplicationPolicy(), 50); + SourceRecord sourceRecord = mirrorSourceTask.convertRecord(consumerRecord); + assertEquals("cluster7.topic1", sourceRecord.topic(), + "Failure on cluster7.topic1 consumerRecord serde"); + assertEquals(2, sourceRecord.kafkaPartition().intValue(), + "sourceRecord kafka partition is incorrect"); + assertEquals(new TopicPartition("topic1", 2), MirrorUtils.unwrapPartition(sourceRecord.sourcePartition()), + "topic1 unwrapped from sourcePartition is incorrect"); + assertEquals(3L, MirrorUtils.unwrapOffset(sourceRecord.sourceOffset()).longValue(), + "sourceRecord's sourceOffset is incorrect"); + assertEquals(4L, sourceRecord.timestamp().longValue(), + "sourceRecord's timestamp is incorrect"); + assertEquals(key, sourceRecord.key(), "sourceRecord's key is incorrect"); + assertEquals(value, sourceRecord.value(), "sourceRecord's value is incorrect"); + assertEquals(headers.lastHeader("header1").value(), sourceRecord.headers().lastWithName("header1").value(), + "sourceRecord's header1 is incorrect"); + assertEquals(headers.lastHeader("header2").value(), sourceRecord.headers().lastWithName("header2").value(), + "sourceRecord's header2 is incorrect"); + } + + @Test + public void testOffsetSync() { + MirrorSourceTask.PartitionState partitionState = new MirrorSourceTask.PartitionState(50); + + assertTrue(partitionState.update(0, 100), "always emit offset sync on first update"); + assertTrue(partitionState.update(2, 102), "upstream offset skipped -> resync"); + assertFalse(partitionState.update(3, 152), "no sync"); + assertFalse(partitionState.update(4, 153), "no sync"); + assertFalse(partitionState.update(5, 154), "no sync"); + assertTrue(partitionState.update(6, 205), "one past target offset"); + assertTrue(partitionState.update(2, 206), "upstream reset"); + assertFalse(partitionState.update(3, 207), "no sync"); + 
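        // Reading of the PartitionState(maxOffsetLag = 50) expectations above and below (a simplified
        // summary inferred from this test's assertions, not a verbatim copy of the implementation):
        // a sync is re-emitted when upstream offsets are not consecutive (a skip or a reset), when the
        // downstream offset has drifted at least maxOffsetLag past the last synced translation, or when
        // the downstream offset moves backwards; roughly,
        //   sync if upstream != previousUpstream + 1
        //        || downstream < previousDownstream
        //        || downstream - (lastSyncDownstream + (upstream - lastSyncUpstream)) >= maxOffsetLag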
assertTrue(partitionState.update(4, 3), "downstream reset"); + assertFalse(partitionState.update(5, 4), "no sync"); + } + + @Test + public void testZeroOffsetSync() { + MirrorSourceTask.PartitionState partitionState = new MirrorSourceTask.PartitionState(0); + + // if max offset lag is zero, should always emit offset syncs + assertTrue(partitionState.update(0, 100), "zeroOffsetSync downStreamOffset 100 is incorrect"); + assertTrue(partitionState.update(2, 102), "zeroOffsetSync downStreamOffset 102 is incorrect"); + assertTrue(partitionState.update(3, 153), "zeroOffsetSync downStreamOffset 153 is incorrect"); + assertTrue(partitionState.update(4, 154), "zeroOffsetSync downStreamOffset 154 is incorrect"); + assertTrue(partitionState.update(5, 155), "zeroOffsetSync downStreamOffset 155 is incorrect"); + assertTrue(partitionState.update(6, 207), "zeroOffsetSync downStreamOffset 207 is incorrect"); + assertTrue(partitionState.update(2, 208), "zeroOffsetSync downStreamOffset 208 is incorrect"); + assertTrue(partitionState.update(3, 209), "zeroOffsetSync downStreamOffset 209 is incorrect"); + assertTrue(partitionState.update(4, 3), "zeroOffsetSync downStreamOffset 3 is incorrect"); + assertTrue(partitionState.update(5, 4), "zeroOffsetSync downStreamOffset 4 is incorrect"); + } + + @Test + public void testPoll() { + // Create a consumer mock + byte[] key1 = "abc".getBytes(); + byte[] value1 = "fgh".getBytes(); + byte[] key2 = "123".getBytes(); + byte[] value2 = "456".getBytes(); + List> consumerRecordsList = new ArrayList<>(); + String topicName = "test"; + String headerKey = "key"; + RecordHeaders headers = new RecordHeaders(new Header[] { + new RecordHeader(headerKey, "value".getBytes()), + }); + consumerRecordsList.add(new ConsumerRecord<>(topicName, 0, 0, System.currentTimeMillis(), + TimestampType.CREATE_TIME, key1.length, value1.length, key1, value1, headers, Optional.empty())); + consumerRecordsList.add(new ConsumerRecord<>(topicName, 1, 1, System.currentTimeMillis(), + TimestampType.CREATE_TIME, key2.length, value2.length, key2, value2, headers, Optional.empty())); + ConsumerRecords consumerRecords = + new ConsumerRecords<>(Collections.singletonMap(new TopicPartition(topicName, 0), consumerRecordsList)); + + @SuppressWarnings("unchecked") + KafkaConsumer consumer = mock(KafkaConsumer.class); + when(consumer.poll(any())).thenReturn(consumerRecords); + + MirrorMetrics metrics = mock(MirrorMetrics.class); + + String sourceClusterName = "cluster1"; + ReplicationPolicy replicationPolicy = new DefaultReplicationPolicy(); + MirrorSourceTask mirrorSourceTask = new MirrorSourceTask(consumer, metrics, sourceClusterName, + replicationPolicy, 50); + List sourceRecords = mirrorSourceTask.poll(); + + assertEquals(2, sourceRecords.size()); + for (int i = 0; i < sourceRecords.size(); i++) { + SourceRecord sourceRecord = sourceRecords.get(i); + ConsumerRecord consumerRecord = consumerRecordsList.get(i); + assertEquals(consumerRecord.key(), sourceRecord.key(), + "consumerRecord key does not equal sourceRecord key"); + assertEquals(consumerRecord.value(), sourceRecord.value(), + "consumerRecord value does not equal sourceRecord value"); + // We expect that the topicname will be based on the replication policy currently used + assertEquals(replicationPolicy.formatRemoteTopic(sourceClusterName, topicName), + sourceRecord.topic(), "topicName not the same as the current replicationPolicy"); + // We expect that MirrorMaker will keep the same partition assignment + assertEquals(consumerRecord.partition(), 
sourceRecord.kafkaPartition().intValue(), + "partition assignment not the same as the current replicationPolicy"); + // Check header values + List
            expectedHeaders = new ArrayList<>(); + consumerRecord.headers().forEach(expectedHeaders::add); + List<org.apache.kafka.connect.header.Header> taskHeaders = new ArrayList<>(); + sourceRecord.headers().forEach(taskHeaders::add); + compareHeaders(expectedHeaders, taskHeaders); + } + } + + private void compareHeaders(List<Header>
            expectedHeaders, List taskHeaders) { + assertEquals(expectedHeaders.size(), taskHeaders.size()); + for (int i = 0; i < expectedHeaders.size(); i++) { + Header expectedHeader = expectedHeaders.get(i); + org.apache.kafka.connect.header.Header taskHeader = taskHeaders.get(i); + assertEquals(expectedHeader.key(), taskHeader.key(), + "taskHeader's key expected to equal " + taskHeader.key()); + assertEquals(expectedHeader.value(), taskHeader.value(), + "taskHeader's value expected to equal " + taskHeader.value().toString()); + } + } +} diff --git a/connect/mirror/src/test/java/org/apache/kafka/connect/mirror/OffsetSyncStoreTest.java b/connect/mirror/src/test/java/org/apache/kafka/connect/mirror/OffsetSyncStoreTest.java new file mode 100644 index 0000000..9307c60 --- /dev/null +++ b/connect/mirror/src/test/java/org/apache/kafka/connect/mirror/OffsetSyncStoreTest.java @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.mirror; + +import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.common.TopicPartition; + +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class OffsetSyncStoreTest { + + static TopicPartition tp = new TopicPartition("topic1", 2); + + static class FakeOffsetSyncStore extends OffsetSyncStore { + + FakeOffsetSyncStore() { + super(null, null); + } + + void sync(TopicPartition topicPartition, long upstreamOffset, long downstreamOffset) { + OffsetSync offsetSync = new OffsetSync(topicPartition, upstreamOffset, downstreamOffset); + byte[] key = offsetSync.recordKey(); + byte[] value = offsetSync.recordValue(); + ConsumerRecord record = new ConsumerRecord<>("test.offsets.internal", 0, 3, key, value); + handleRecord(record); + } + } + + @Test + public void testOffsetTranslation() { + FakeOffsetSyncStore store = new FakeOffsetSyncStore(); + + store.sync(tp, 100, 200); + assertEquals(store.translateDownstream(tp, 150), 250, + "Failure in translating downstream offset 250"); + + // Translate exact offsets + store.sync(tp, 150, 251); + assertEquals(store.translateDownstream(tp, 150), 251, + "Failure in translating exact downstream offset 251"); + + // Use old offset (5) prior to any sync -> can't translate + assertEquals(-1, store.translateDownstream(tp, 5), + "Expected old offset to not translate"); + + // Downstream offsets reset + store.sync(tp, 200, 10); + assertEquals(store.translateDownstream(tp, 200), 10, + "Failure in resetting translation of downstream offset"); + + // Upstream offsets reset + store.sync(tp, 20, 20); + assertEquals(store.translateDownstream(tp, 20), 20, + "Failure in resetting translation of upstream offset"); + } +} diff --git 
a/connect/mirror/src/test/java/org/apache/kafka/connect/mirror/OffsetSyncTest.java b/connect/mirror/src/test/java/org/apache/kafka/connect/mirror/OffsetSyncTest.java new file mode 100644 index 0000000..dc7efe2 --- /dev/null +++ b/connect/mirror/src/test/java/org/apache/kafka/connect/mirror/OffsetSyncTest.java @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.mirror; + +import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.common.TopicPartition; + +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class OffsetSyncTest { + + @Test + public void testSerde() { + OffsetSync offsetSync = new OffsetSync(new TopicPartition("topic-1", 2), 3, 4); + byte[] key = offsetSync.recordKey(); + byte[] value = offsetSync.recordValue(); + ConsumerRecord record = new ConsumerRecord<>("any-topic", 6, 7, key, value); + OffsetSync deserialized = OffsetSync.deserializeRecord(record); + assertEquals(offsetSync.topicPartition(), deserialized.topicPartition(), + "Failure on offset sync topic partition serde"); + assertEquals(offsetSync.upstreamOffset(), deserialized.upstreamOffset(), + "Failure on upstream offset serde"); + assertEquals(offsetSync.downstreamOffset(), deserialized.downstreamOffset(), + "Failure on downstream offset serde"); + } +} diff --git a/connect/mirror/src/test/java/org/apache/kafka/connect/mirror/TestUtils.java b/connect/mirror/src/test/java/org/apache/kafka/connect/mirror/TestUtils.java new file mode 100644 index 0000000..64f689a --- /dev/null +++ b/connect/mirror/src/test/java/org/apache/kafka/connect/mirror/TestUtils.java @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.mirror; + +import java.util.HashMap; +import java.util.Map; + +public class TestUtils { + + static Map makeProps(String... 
keyValues) { + Map props = new HashMap<>(); + props.put("name", "ConnectorName"); + props.put("connector.class", "ConnectorClass"); + props.put("source.cluster.alias", "source1"); + props.put("target.cluster.alias", "target2"); + for (int i = 0; i < keyValues.length; i += 2) { + props.put(keyValues[i], keyValues[i + 1]); + } + return props; + } + + /* + * return records with different but predictable key and value + */ + public static Map generateRecords(int numRecords) { + Map records = new HashMap<>(); + for (int i = 0; i < numRecords; i++) { + records.put("key-" + i, "message-" + i); + } + return records; + } +} diff --git a/connect/mirror/src/test/java/org/apache/kafka/connect/mirror/integration/IdentityReplicationIntegrationTest.java b/connect/mirror/src/test/java/org/apache/kafka/connect/mirror/integration/IdentityReplicationIntegrationTest.java new file mode 100644 index 0000000..43b1fcb --- /dev/null +++ b/connect/mirror/src/test/java/org/apache/kafka/connect/mirror/integration/IdentityReplicationIntegrationTest.java @@ -0,0 +1,263 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.mirror.integration; + +import org.apache.kafka.clients.admin.Admin; +import org.apache.kafka.clients.consumer.Consumer; +import org.apache.kafka.clients.consumer.ConsumerRecords; +import org.apache.kafka.clients.consumer.OffsetAndMetadata; +import org.apache.kafka.common.config.TopicConfig; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.connect.mirror.IdentityReplicationPolicy; +import org.apache.kafka.connect.mirror.MirrorClient; +import org.apache.kafka.connect.mirror.MirrorHeartbeatConnector; +import org.apache.kafka.connect.mirror.MirrorMakerConfig; + +import java.time.Duration; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.TimeUnit; + +import org.junit.jupiter.api.Tag; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.BeforeEach; + +/** + * Tests MM2 replication and failover logic for {@link IdentityReplicationPolicy}. + * + *

            MM2 is configured with active/passive replication between two Kafka clusters with {@link IdentityReplicationPolicy}. + * Tests validate that records sent to the primary cluster arrive at the backup cluster. Then, a consumer group is + * migrated from the primary cluster to the backup cluster. Tests validate that consumer offsets + * are translated and replicated from the primary cluster to the backup cluster during this failover. + */ +@Tag("integration") +public class IdentityReplicationIntegrationTest extends MirrorConnectorsIntegrationBaseTest { + @BeforeEach + public void startClusters() throws Exception { + super.startClusters(new HashMap() {{ + put("replication.policy.class", IdentityReplicationPolicy.class.getName()); + put("topics", "test-topic-.*"); + put(BACKUP_CLUSTER_ALIAS + "->" + PRIMARY_CLUSTER_ALIAS + ".enabled", "false"); + put(PRIMARY_CLUSTER_ALIAS + "->" + BACKUP_CLUSTER_ALIAS + ".enabled", "true"); + }}); + } + + @Test + public void testReplication() throws Exception { + produceMessages(primary, "test-topic-1"); + String consumerGroupName = "consumer-group-testReplication"; + Map consumerProps = new HashMap() {{ + put("group.id", consumerGroupName); + put("auto.offset.reset", "latest"); + }}; + // warm up consumers before starting the connectors so we don't need to wait for discovery + warmUpConsumer(consumerProps); + + mm2Config = new MirrorMakerConfig(mm2Props); + + waitUntilMirrorMakerIsRunning(backup, CONNECTOR_LIST, mm2Config, PRIMARY_CLUSTER_ALIAS, BACKUP_CLUSTER_ALIAS); + waitUntilMirrorMakerIsRunning(primary, Collections.singletonList(MirrorHeartbeatConnector.class), mm2Config, BACKUP_CLUSTER_ALIAS, PRIMARY_CLUSTER_ALIAS); + + MirrorClient primaryClient = new MirrorClient(mm2Config.clientConfig(PRIMARY_CLUSTER_ALIAS)); + MirrorClient backupClient = new MirrorClient(mm2Config.clientConfig(BACKUP_CLUSTER_ALIAS)); + + // make sure the topic is auto-created in the other cluster + waitForTopicCreated(primary, "test-topic-1"); + waitForTopicCreated(backup, "test-topic-1"); + assertEquals(TopicConfig.CLEANUP_POLICY_COMPACT, getTopicConfig(backup.kafka(), "test-topic-1", TopicConfig.CLEANUP_POLICY_CONFIG), + "topic config was not synced"); + + assertEquals(NUM_RECORDS_PRODUCED, primary.kafka().consume(NUM_RECORDS_PRODUCED, RECORD_TRANSFER_DURATION_MS, "test-topic-1").count(), + "Records were not produced to primary cluster."); + assertEquals(NUM_RECORDS_PRODUCED, backup.kafka().consume(NUM_RECORDS_PRODUCED, RECORD_TRANSFER_DURATION_MS, "test-topic-1").count(), + "Records were not replicated to backup cluster."); + + assertTrue(primary.kafka().consume(1, RECORD_TRANSFER_DURATION_MS, "heartbeats").count() > 0, + "Heartbeats were not emitted to primary cluster."); + assertTrue(backup.kafka().consume(1, RECORD_TRANSFER_DURATION_MS, "heartbeats").count() > 0, + "Heartbeats were not emitted to backup cluster."); + assertTrue(backup.kafka().consume(1, RECORD_TRANSFER_DURATION_MS, "primary.heartbeats").count() > 0, + "Heartbeats were not replicated downstream to backup cluster."); + assertTrue(primary.kafka().consume(1, RECORD_TRANSFER_DURATION_MS, "heartbeats").count() > 0, + "Heartbeats were not replicated downstream to primary cluster."); + + assertTrue(backupClient.upstreamClusters().contains(PRIMARY_CLUSTER_ALIAS), "Did not find upstream primary cluster."); + assertEquals(1, backupClient.replicationHops(PRIMARY_CLUSTER_ALIAS), "Did not calculate replication hops correctly."); + assertTrue(backup.kafka().consume(1, CHECKPOINT_DURATION_MS, 
"primary.checkpoints.internal").count() > 0, + "Checkpoints were not emitted downstream to backup cluster."); + + Map backupOffsets = backupClient.remoteConsumerOffsets(consumerGroupName, PRIMARY_CLUSTER_ALIAS, + Duration.ofMillis(CHECKPOINT_DURATION_MS)); + + assertTrue(backupOffsets.containsKey( + new TopicPartition("test-topic-1", 0)), "Offsets not translated downstream to backup cluster. Found: " + backupOffsets); + + // Failover consumer group to backup cluster. + try (Consumer primaryConsumer = backup.kafka().createConsumer(Collections.singletonMap("group.id", consumerGroupName))) { + primaryConsumer.assign(backupOffsets.keySet()); + backupOffsets.forEach(primaryConsumer::seek); + primaryConsumer.poll(CONSUMER_POLL_TIMEOUT_MS); + primaryConsumer.commitAsync(); + + assertTrue(primaryConsumer.position(new TopicPartition("test-topic-1", 0)) > 0, "Consumer failedover to zero offset."); + assertTrue(primaryConsumer.position( + new TopicPartition("test-topic-1", 0)) <= NUM_RECORDS_PRODUCED, "Consumer failedover beyond expected offset."); + } + + primaryClient.close(); + backupClient.close(); + + // create more matching topics + primary.kafka().createTopic("test-topic-2", NUM_PARTITIONS); + + // make sure the topic is auto-created in the other cluster + waitForTopicCreated(backup, "test-topic-2"); + + // only produce messages to the first partition + produceMessages(primary, "test-topic-2", 1); + + // expect total consumed messages equals to NUM_RECORDS_PER_PARTITION + assertEquals(NUM_RECORDS_PER_PARTITION, primary.kafka().consume(NUM_RECORDS_PER_PARTITION, RECORD_TRANSFER_DURATION_MS, "test-topic-2").count(), + "Records were not produced to primary cluster."); + assertEquals(NUM_RECORDS_PER_PARTITION, backup.kafka().consume(NUM_RECORDS_PER_PARTITION, 2 * RECORD_TRANSFER_DURATION_MS, "test-topic-2").count(), + "New topic was not replicated to backup cluster."); + } + + @Test + public void testReplicationWithEmptyPartition() throws Exception { + String consumerGroupName = "consumer-group-testReplicationWithEmptyPartition"; + Map consumerProps = Collections.singletonMap("group.id", consumerGroupName); + + // create topic + String topic = "test-topic-with-empty-partition"; + primary.kafka().createTopic(topic, NUM_PARTITIONS); + + // produce to all test-topic-empty's partitions, except the last partition + produceMessages(primary, topic, NUM_PARTITIONS - 1); + + // consume before starting the connectors so we don't need to wait for discovery + int expectedRecords = NUM_RECORDS_PER_PARTITION * (NUM_PARTITIONS - 1); + try (Consumer primaryConsumer = primary.kafka().createConsumerAndSubscribeTo(consumerProps, topic)) { + waitForConsumingAllRecords(primaryConsumer, expectedRecords); + } + + // one way replication from primary to backup + mm2Props.put(BACKUP_CLUSTER_ALIAS + "->" + PRIMARY_CLUSTER_ALIAS + ".enabled", "false"); + mm2Config = new MirrorMakerConfig(mm2Props); + waitUntilMirrorMakerIsRunning(backup, CONNECTOR_LIST, mm2Config, PRIMARY_CLUSTER_ALIAS, BACKUP_CLUSTER_ALIAS); + + // sleep few seconds to have MM2 finish replication so that "end" consumer will consume some record + Thread.sleep(TimeUnit.SECONDS.toMillis(3)); + + // note that with IdentityReplicationPolicy, topics on the backup are NOT renamed to PRIMARY_CLUSTER_ALIAS + "." 
+ topic + String backupTopic = topic; + + // consume all records from backup cluster + try (Consumer backupConsumer = backup.kafka().createConsumerAndSubscribeTo(consumerProps, + backupTopic)) { + waitForConsumingAllRecords(backupConsumer, expectedRecords); + } + + try (Admin backupClient = backup.kafka().createAdminClient()) { + // retrieve the consumer group offset from backup cluster + Map remoteOffsets = + backupClient.listConsumerGroupOffsets(consumerGroupName).partitionsToOffsetAndMetadata().get(); + + // pinpoint the offset of the last partition which does not receive records + OffsetAndMetadata offset = remoteOffsets.get(new TopicPartition(backupTopic, NUM_PARTITIONS - 1)); + // offset of the last partition should exist, but its value should be 0 + assertNotNull(offset, "Offset of last partition was not replicated"); + assertEquals(0, offset.offset(), "Offset of last partition is not zero"); + } + } + + @Test + public void testOneWayReplicationWithAutoOffsetSync() throws InterruptedException { + produceMessages(primary, "test-topic-1"); + String consumerGroupName = "consumer-group-testOneWayReplicationWithAutoOffsetSync"; + Map consumerProps = new HashMap() {{ + put("group.id", consumerGroupName); + put("auto.offset.reset", "earliest"); + }}; + // create consumers before starting the connectors so we don't need to wait for discovery + try (Consumer primaryConsumer = primary.kafka().createConsumerAndSubscribeTo(consumerProps, + "test-topic-1")) { + // we need to wait for consuming all the records for MM2 replicating the expected offsets + waitForConsumingAllRecords(primaryConsumer, NUM_RECORDS_PRODUCED); + } + + // enable automated consumer group offset sync + mm2Props.put("sync.group.offsets.enabled", "true"); + mm2Props.put("sync.group.offsets.interval.seconds", "1"); + // one way replication from primary to backup + mm2Props.put(BACKUP_CLUSTER_ALIAS + "->" + PRIMARY_CLUSTER_ALIAS + ".enabled", "false"); + + mm2Config = new MirrorMakerConfig(mm2Props); + + waitUntilMirrorMakerIsRunning(backup, CONNECTOR_LIST, mm2Config, PRIMARY_CLUSTER_ALIAS, BACKUP_CLUSTER_ALIAS); + + // make sure the topic is created in the other cluster + waitForTopicCreated(primary, "backup.test-topic-1"); + waitForTopicCreated(backup, "test-topic-1"); + // create a consumer at backup cluster with same consumer group Id to consume 1 topic + Consumer backupConsumer = backup.kafka().createConsumerAndSubscribeTo( + consumerProps, "test-topic-1"); + + waitForConsumerGroupOffsetSync(backup, backupConsumer, Collections.singletonList("test-topic-1"), + consumerGroupName, NUM_RECORDS_PRODUCED); + + ConsumerRecords records = backupConsumer.poll(CONSUMER_POLL_TIMEOUT_MS); + + // the size of consumer record should be zero, because the offsets of the same consumer group + // have been automatically synchronized from primary to backup by the background job, so no + // more records to consume from the replicated topic by the same consumer group at backup cluster + assertEquals(0, records.count(), "consumer record size is not zero"); + + // now create a new topic in primary cluster + primary.kafka().createTopic("test-topic-2", NUM_PARTITIONS); + // make sure the topic is created in backup cluster + waitForTopicCreated(backup, "test-topic-2"); + + // produce some records to the new topic in primary cluster + produceMessages(primary, "test-topic-2"); + + // create a consumer at primary cluster to consume the new topic + try (Consumer consumer1 = primary.kafka().createConsumerAndSubscribeTo(Collections.singletonMap( + 
"group.id", "consumer-group-1"), "test-topic-2")) { + // we need to wait for consuming all the records for MM2 replicating the expected offsets + waitForConsumingAllRecords(consumer1, NUM_RECORDS_PRODUCED); + } + + // create a consumer at backup cluster with same consumer group Id to consume old and new topic + backupConsumer = backup.kafka().createConsumerAndSubscribeTo(Collections.singletonMap( + "group.id", consumerGroupName), "test-topic-1", "test-topic-2"); + + waitForConsumerGroupOffsetSync(backup, backupConsumer, Arrays.asList("test-topic-1", "test-topic-2"), + consumerGroupName, NUM_RECORDS_PRODUCED); + + records = backupConsumer.poll(CONSUMER_POLL_TIMEOUT_MS); + // similar reasoning as above, no more records to consume by the same consumer group at backup cluster + assertEquals(0, records.count(), "consumer record size is not zero"); + backupConsumer.close(); + } +} diff --git a/connect/mirror/src/test/java/org/apache/kafka/connect/mirror/integration/MirrorConnectorsIntegrationBaseTest.java b/connect/mirror/src/test/java/org/apache/kafka/connect/mirror/integration/MirrorConnectorsIntegrationBaseTest.java new file mode 100644 index 0000000..6fb7a81 --- /dev/null +++ b/connect/mirror/src/test/java/org/apache/kafka/connect/mirror/integration/MirrorConnectorsIntegrationBaseTest.java @@ -0,0 +1,711 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.connect.mirror.integration; + +import org.apache.kafka.clients.admin.Admin; +import org.apache.kafka.clients.admin.AdminClientConfig; +import org.apache.kafka.clients.admin.Config; +import org.apache.kafka.clients.admin.DescribeConfigsResult; +import org.apache.kafka.clients.consumer.Consumer; +import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.clients.consumer.ConsumerRecords; +import org.apache.kafka.clients.consumer.OffsetAndMetadata; +import org.apache.kafka.common.config.ConfigResource; +import org.apache.kafka.common.config.TopicConfig; +import org.apache.kafka.common.utils.Exit; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.connect.connector.Connector; +import org.apache.kafka.connect.mirror.MirrorClient; +import org.apache.kafka.connect.mirror.MirrorHeartbeatConnector; +import org.apache.kafka.connect.mirror.MirrorMakerConfig; +import org.apache.kafka.connect.mirror.MirrorSourceConnector; +import org.apache.kafka.connect.mirror.SourceAndTarget; +import org.apache.kafka.connect.mirror.Checkpoint; +import org.apache.kafka.connect.mirror.MirrorCheckpointConnector; +import org.apache.kafka.connect.mirror.ReplicationPolicy; +import org.apache.kafka.connect.util.clusters.EmbeddedConnectCluster; +import org.apache.kafka.connect.util.clusters.EmbeddedKafkaCluster; +import org.apache.kafka.connect.util.clusters.UngracefulShutdownException; +import static org.apache.kafka.test.TestUtils.waitForCondition; + +import java.time.Duration; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.List; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.Properties; +import java.util.Set; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; + +import org.junit.jupiter.api.Tag; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; + +import static org.apache.kafka.connect.mirror.TestUtils.generateRecords; + +/** + * Tests MM2 replication and failover/failback logic. + * + * MM2 is configured with active/active replication between two Kafka clusters. Tests validate that + * records sent to either cluster arrive at the other cluster. Then, a consumer group is migrated from + * one cluster to the other and back. Tests validate that consumer offsets are translated and replicated + * between clusters during this failover and failback. 
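(For orientation, a minimal sketch of the flat property map this active/active setup boils down to, using the same keys as basicMM2Config() and the @BeforeEach overrides further down in this file; the bootstrap addresses are illustrative placeholders, since the tests wire in embedded clusters instead.)

    import java.util.HashMap;
    import java.util.Map;
    import org.apache.kafka.connect.mirror.MirrorClient;
    import org.apache.kafka.connect.mirror.MirrorMakerConfig;

    public class ActiveActiveMm2Sketch {
        public static void main(String[] args) {
            Map<String, String> props = new HashMap<>();
            props.put("clusters", "primary, backup");
            props.put("primary->backup.enabled", "true");              // replicate primary -> backup
            props.put("backup->primary.enabled", "true");              // and backup -> primary
            props.put("topics", "test-topic-.*, primary.test-topic-.*, backup.test-topic-.*");
            props.put("replication.factor", "1");                      // single-broker test clusters
            props.put("primary.bootstrap.servers", "localhost:9092");  // placeholder address
            props.put("backup.bootstrap.servers", "localhost:9093");   // placeholder address
            MirrorMakerConfig config = new MirrorMakerConfig(props);
            // Per-cluster client configs, as the tests use to build MirrorClient instances.
            MirrorClient backupClient = new MirrorClient(config.clientConfig("backup"));
            backupClient.close();
        }
    }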
+ */ +@Tag("integration") +public abstract class MirrorConnectorsIntegrationBaseTest { + private static final Logger log = LoggerFactory.getLogger(MirrorConnectorsIntegrationBaseTest.class); + + protected static final int NUM_RECORDS_PER_PARTITION = 10; + protected static final int NUM_PARTITIONS = 10; + protected static final int NUM_RECORDS_PRODUCED = NUM_PARTITIONS * NUM_RECORDS_PER_PARTITION; + protected static final int RECORD_TRANSFER_DURATION_MS = 30_000; + protected static final int CHECKPOINT_DURATION_MS = 20_000; + private static final int RECORD_CONSUME_DURATION_MS = 20_000; + private static final int OFFSET_SYNC_DURATION_MS = 30_000; + private static final int TOPIC_SYNC_DURATION_MS = 60_000; + private static final int REQUEST_TIMEOUT_DURATION_MS = 60_000; + private static final int NUM_WORKERS = 3; + protected static final Duration CONSUMER_POLL_TIMEOUT_MS = Duration.ofMillis(500); + protected static final String PRIMARY_CLUSTER_ALIAS = "primary"; + protected static final String BACKUP_CLUSTER_ALIAS = "backup"; + protected static final List> CONNECTOR_LIST = + Arrays.asList(MirrorSourceConnector.class, MirrorCheckpointConnector.class, MirrorHeartbeatConnector.class); + + private volatile boolean shuttingDown; + protected Map mm2Props = new HashMap<>(); + protected MirrorMakerConfig mm2Config; + protected EmbeddedConnectCluster primary; + protected EmbeddedConnectCluster backup; + + protected Exit.Procedure exitProcedure; + private Exit.Procedure haltProcedure; + + protected Properties primaryBrokerProps = new Properties(); + protected Properties backupBrokerProps = new Properties(); + protected Map primaryWorkerProps = new HashMap<>(); + protected Map backupWorkerProps = new HashMap<>(); + + @BeforeEach + public void startClusters() throws Exception { + startClusters(new HashMap() {{ + put("topics", "test-topic-.*, primary.test-topic-.*, backup.test-topic-.*"); + put(PRIMARY_CLUSTER_ALIAS + "->" + BACKUP_CLUSTER_ALIAS + ".enabled", "true"); + put(BACKUP_CLUSTER_ALIAS + "->" + PRIMARY_CLUSTER_ALIAS + ".enabled", "true"); + }}); + } + + public void startClusters(Map additionalMM2Config) throws Exception { + shuttingDown = false; + exitProcedure = (code, message) -> { + if (shuttingDown) { + // ignore this since we're shutting down Connect and Kafka and timing isn't always great + return; + } + if (code != 0) { + String exitMessage = "Abrupt service exit with code " + code + " and message " + message; + log.warn(exitMessage); + throw new UngracefulShutdownException(exitMessage); + } + }; + haltProcedure = (code, message) -> { + if (shuttingDown) { + // ignore this since we're shutting down Connect and Kafka and timing isn't always great + return; + } + if (code != 0) { + String haltMessage = "Abrupt service halt with code " + code + " and message " + message; + log.warn(haltMessage); + throw new UngracefulShutdownException(haltMessage); + } + }; + // Override the exit and halt procedure that Connect and Kafka will use. 
For these integration tests, + // we don't want to exit the JVM and instead simply want to fail the test + Exit.setExitProcedure(exitProcedure); + Exit.setHaltProcedure(haltProcedure); + + primaryBrokerProps.put("auto.create.topics.enable", "false"); + backupBrokerProps.put("auto.create.topics.enable", "false"); + + mm2Props.putAll(basicMM2Config()); + mm2Props.putAll(additionalMM2Config); + + mm2Config = new MirrorMakerConfig(mm2Props); + primaryWorkerProps = mm2Config.workerConfig(new SourceAndTarget(BACKUP_CLUSTER_ALIAS, PRIMARY_CLUSTER_ALIAS)); + backupWorkerProps.putAll(mm2Config.workerConfig(new SourceAndTarget(PRIMARY_CLUSTER_ALIAS, BACKUP_CLUSTER_ALIAS))); + + primary = new EmbeddedConnectCluster.Builder() + .name(PRIMARY_CLUSTER_ALIAS + "-connect-cluster") + .numWorkers(NUM_WORKERS) + .numBrokers(1) + .brokerProps(primaryBrokerProps) + .workerProps(primaryWorkerProps) + .maskExitProcedures(false) + .build(); + + backup = new EmbeddedConnectCluster.Builder() + .name(BACKUP_CLUSTER_ALIAS + "-connect-cluster") + .numWorkers(NUM_WORKERS) + .numBrokers(1) + .brokerProps(backupBrokerProps) + .workerProps(backupWorkerProps) + .maskExitProcedures(false) + .build(); + + primary.start(); + primary.assertions().assertAtLeastNumWorkersAreUp(NUM_WORKERS, + "Workers of " + PRIMARY_CLUSTER_ALIAS + "-connect-cluster did not start in time."); + + waitForTopicCreated(primary, "mm2-status.backup.internal"); + waitForTopicCreated(primary, "mm2-offsets.backup.internal"); + waitForTopicCreated(primary, "mm2-configs.backup.internal"); + + backup.start(); + backup.assertions().assertAtLeastNumWorkersAreUp(NUM_WORKERS, + "Workers of " + BACKUP_CLUSTER_ALIAS + "-connect-cluster did not start in time."); + + waitForTopicCreated(backup, "mm2-status.primary.internal"); + waitForTopicCreated(backup, "mm2-offsets.primary.internal"); + waitForTopicCreated(backup, "mm2-configs.primary.internal"); + + createTopics(); + + warmUpConsumer(Collections.singletonMap("group.id", "consumer-group-dummy")); + + log.info(PRIMARY_CLUSTER_ALIAS + " REST service: {}", primary.endpointForResource("connectors")); + log.info(BACKUP_CLUSTER_ALIAS + " REST service: {}", backup.endpointForResource("connectors")); + log.info(PRIMARY_CLUSTER_ALIAS + " brokers: {}", primary.kafka().bootstrapServers()); + log.info(BACKUP_CLUSTER_ALIAS + " brokers: {}", backup.kafka().bootstrapServers()); + + // now that the brokers are running, we can finish setting up the Connectors + mm2Props.put(PRIMARY_CLUSTER_ALIAS + ".bootstrap.servers", primary.kafka().bootstrapServers()); + mm2Props.put(BACKUP_CLUSTER_ALIAS + ".bootstrap.servers", backup.kafka().bootstrapServers()); + } + + @AfterEach + public void shutdownClusters() throws Exception { + try { + for (String x : primary.connectors()) { + primary.deleteConnector(x); + } + for (String x : backup.connectors()) { + backup.deleteConnector(x); + } + deleteAllTopics(primary.kafka()); + deleteAllTopics(backup.kafka()); + } finally { + shuttingDown = true; + try { + try { + primary.stop(); + } finally { + backup.stop(); + } + } finally { + Exit.resetExitProcedure(); + Exit.resetHaltProcedure(); + } + } + } + + @Test + public void testReplication() throws Exception { + produceMessages(primary, "test-topic-1"); + produceMessages(backup, "test-topic-1"); + String consumerGroupName = "consumer-group-testReplication"; + Map consumerProps = new HashMap() {{ + put("group.id", consumerGroupName); + put("auto.offset.reset", "latest"); + }}; + // warm up consumers before starting the connectors so we don't need to 
wait for discovery + warmUpConsumer(consumerProps); + + mm2Config = new MirrorMakerConfig(mm2Props); + + waitUntilMirrorMakerIsRunning(backup, CONNECTOR_LIST, mm2Config, PRIMARY_CLUSTER_ALIAS, BACKUP_CLUSTER_ALIAS); + waitUntilMirrorMakerIsRunning(primary, CONNECTOR_LIST, mm2Config, BACKUP_CLUSTER_ALIAS, PRIMARY_CLUSTER_ALIAS); + + MirrorClient primaryClient = new MirrorClient(mm2Config.clientConfig(PRIMARY_CLUSTER_ALIAS)); + MirrorClient backupClient = new MirrorClient(mm2Config.clientConfig(BACKUP_CLUSTER_ALIAS)); + + // make sure the topic is auto-created in the other cluster + waitForTopicCreated(primary, "backup.test-topic-1"); + waitForTopicCreated(backup, "primary.test-topic-1"); + waitForTopicCreated(primary, "mm2-offset-syncs.backup.internal"); + assertEquals(TopicConfig.CLEANUP_POLICY_COMPACT, getTopicConfig(backup.kafka(), "primary.test-topic-1", TopicConfig.CLEANUP_POLICY_CONFIG), + "topic config was not synced"); + + assertEquals(NUM_RECORDS_PRODUCED, primary.kafka().consume(NUM_RECORDS_PRODUCED, RECORD_TRANSFER_DURATION_MS, "test-topic-1").count(), + "Records were not produced to primary cluster."); + assertEquals(NUM_RECORDS_PRODUCED, backup.kafka().consume(NUM_RECORDS_PRODUCED, RECORD_TRANSFER_DURATION_MS, "primary.test-topic-1").count(), + "Records were not replicated to backup cluster."); + assertEquals(NUM_RECORDS_PRODUCED, backup.kafka().consume(NUM_RECORDS_PRODUCED, RECORD_TRANSFER_DURATION_MS, "test-topic-1").count(), + "Records were not produced to backup cluster."); + assertEquals(NUM_RECORDS_PRODUCED, primary.kafka().consume(NUM_RECORDS_PRODUCED, RECORD_TRANSFER_DURATION_MS, "backup.test-topic-1").count(), + "Records were not replicated to primary cluster."); + + assertEquals(NUM_RECORDS_PRODUCED * 2, primary.kafka().consume(NUM_RECORDS_PRODUCED * 2, RECORD_TRANSFER_DURATION_MS, "backup.test-topic-1", "test-topic-1").count(), + "Primary cluster doesn't have all records from both clusters."); + assertEquals(NUM_RECORDS_PRODUCED * 2, backup.kafka().consume(NUM_RECORDS_PRODUCED * 2, RECORD_TRANSFER_DURATION_MS, "primary.test-topic-1", "test-topic-1").count(), + "Backup cluster doesn't have all records from both clusters."); + + assertTrue(primary.kafka().consume(1, RECORD_TRANSFER_DURATION_MS, "heartbeats").count() > 0, + "Heartbeats were not emitted to primary cluster."); + assertTrue(backup.kafka().consume(1, RECORD_TRANSFER_DURATION_MS, "heartbeats").count() > 0, + "Heartbeats were not emitted to backup cluster."); + assertTrue(backup.kafka().consume(1, RECORD_TRANSFER_DURATION_MS, "primary.heartbeats").count() > 0, + "Heartbeats were not replicated downstream to backup cluster."); + assertTrue(primary.kafka().consume(1, RECORD_TRANSFER_DURATION_MS, "backup.heartbeats").count() > 0, + "Heartbeats were not replicated downstream to primary cluster."); + + assertTrue(backupClient.upstreamClusters().contains(PRIMARY_CLUSTER_ALIAS), "Did not find upstream primary cluster."); + assertEquals(1, backupClient.replicationHops(PRIMARY_CLUSTER_ALIAS), "Did not calculate replication hops correctly."); + assertTrue(primaryClient.upstreamClusters().contains(BACKUP_CLUSTER_ALIAS), "Did not find upstream backup cluster."); + assertEquals(1, primaryClient.replicationHops(BACKUP_CLUSTER_ALIAS), "Did not calculate replication hops correctly."); + assertTrue(backup.kafka().consume(1, CHECKPOINT_DURATION_MS, "primary.checkpoints.internal").count() > 0, + "Checkpoints were not emitted downstream to backup cluster."); + + Map backupOffsets = 
backupClient.remoteConsumerOffsets(consumerGroupName, PRIMARY_CLUSTER_ALIAS, + Duration.ofMillis(CHECKPOINT_DURATION_MS)); + + assertTrue(backupOffsets.containsKey( + new TopicPartition("primary.test-topic-1", 0)), "Offsets not translated downstream to backup cluster. Found: " + backupOffsets); + + // Failover consumer group to backup cluster. + try (Consumer primaryConsumer = backup.kafka().createConsumer(Collections.singletonMap("group.id", consumerGroupName))) { + primaryConsumer.assign(backupOffsets.keySet()); + backupOffsets.forEach(primaryConsumer::seek); + primaryConsumer.poll(CONSUMER_POLL_TIMEOUT_MS); + primaryConsumer.commitAsync(); + + assertTrue(primaryConsumer.position(new TopicPartition("primary.test-topic-1", 0)) > 0, "Consumer failedover to zero offset."); + assertTrue(primaryConsumer.position( + new TopicPartition("primary.test-topic-1", 0)) <= NUM_RECORDS_PRODUCED, "Consumer failedover beyond expected offset."); + assertTrue(primary.kafka().consume(1, CHECKPOINT_DURATION_MS, "backup.checkpoints.internal").count() > 0, + "Checkpoints were not emitted upstream to primary cluster."); + } + + waitForCondition(() -> primaryClient.remoteConsumerOffsets(consumerGroupName, BACKUP_CLUSTER_ALIAS, + Duration.ofMillis(CHECKPOINT_DURATION_MS)).containsKey(new TopicPartition("backup.test-topic-1", 0)), CHECKPOINT_DURATION_MS, "Offsets not translated downstream to primary cluster."); + + waitForCondition(() -> primaryClient.remoteConsumerOffsets(consumerGroupName, BACKUP_CLUSTER_ALIAS, + Duration.ofMillis(CHECKPOINT_DURATION_MS)).containsKey(new TopicPartition("test-topic-1", 0)), CHECKPOINT_DURATION_MS, "Offsets not translated upstream to primary cluster."); + + Map primaryOffsets = primaryClient.remoteConsumerOffsets(consumerGroupName, BACKUP_CLUSTER_ALIAS, + Duration.ofMillis(CHECKPOINT_DURATION_MS)); + + primaryClient.close(); + backupClient.close(); + + // Failback consumer group to primary cluster + try (Consumer backupConsumer = primary.kafka().createConsumer(Collections.singletonMap("group.id", consumerGroupName))) { + backupConsumer.assign(primaryOffsets.keySet()); + primaryOffsets.forEach(backupConsumer::seek); + backupConsumer.poll(CONSUMER_POLL_TIMEOUT_MS); + backupConsumer.commitAsync(); + + assertTrue(backupConsumer.position(new TopicPartition("test-topic-1", 0)) > 0, "Consumer failedback to zero upstream offset."); + assertTrue(backupConsumer.position(new TopicPartition("backup.test-topic-1", 0)) > 0, "Consumer failedback to zero downstream offset."); + assertTrue(backupConsumer.position( + new TopicPartition("test-topic-1", 0)) <= NUM_RECORDS_PRODUCED, "Consumer failedback beyond expected upstream offset."); + assertTrue(backupConsumer.position( + new TopicPartition("backup.test-topic-1", 0)) <= NUM_RECORDS_PRODUCED, "Consumer failedback beyond expected downstream offset."); + } + + // create more matching topics + primary.kafka().createTopic("test-topic-2", NUM_PARTITIONS); + backup.kafka().createTopic("test-topic-3", NUM_PARTITIONS); + + // make sure the topic is auto-created in the other cluster + waitForTopicCreated(backup, "primary.test-topic-2"); + waitForTopicCreated(primary, "backup.test-topic-3"); + + // only produce messages to the first partition + produceMessages(primary, "test-topic-2", 1); + produceMessages(backup, "test-topic-3", 1); + + // expect total consumed messages equals to NUM_RECORDS_PER_PARTITION + assertEquals(NUM_RECORDS_PER_PARTITION, primary.kafka().consume(NUM_RECORDS_PER_PARTITION, RECORD_TRANSFER_DURATION_MS, "test-topic-2").count(), + 
"Records were not produced to primary cluster."); + assertEquals(NUM_RECORDS_PER_PARTITION, backup.kafka().consume(NUM_RECORDS_PER_PARTITION, RECORD_TRANSFER_DURATION_MS, "test-topic-3").count(), + "Records were not produced to backup cluster."); + + assertEquals(NUM_RECORDS_PER_PARTITION, primary.kafka().consume(NUM_RECORDS_PER_PARTITION, 2 * RECORD_TRANSFER_DURATION_MS, "backup.test-topic-3").count(), + "New topic was not replicated to primary cluster."); + assertEquals(NUM_RECORDS_PER_PARTITION, backup.kafka().consume(NUM_RECORDS_PER_PARTITION, 2 * RECORD_TRANSFER_DURATION_MS, "primary.test-topic-2").count(), + "New topic was not replicated to backup cluster."); + } + + @Test + public void testReplicationWithEmptyPartition() throws Exception { + String consumerGroupName = "consumer-group-testReplicationWithEmptyPartition"; + Map consumerProps = Collections.singletonMap("group.id", consumerGroupName); + + // create topic + String topic = "test-topic-with-empty-partition"; + primary.kafka().createTopic(topic, NUM_PARTITIONS); + + // produce to all test-topic-empty's partitions, except the last partition + produceMessages(primary, topic, NUM_PARTITIONS - 1); + + // consume before starting the connectors so we don't need to wait for discovery + int expectedRecords = NUM_RECORDS_PER_PARTITION * (NUM_PARTITIONS - 1); + try (Consumer primaryConsumer = primary.kafka().createConsumerAndSubscribeTo(consumerProps, topic)) { + waitForConsumingAllRecords(primaryConsumer, expectedRecords); + } + + // one way replication from primary to backup + mm2Props.put(BACKUP_CLUSTER_ALIAS + "->" + PRIMARY_CLUSTER_ALIAS + ".enabled", "false"); + mm2Config = new MirrorMakerConfig(mm2Props); + waitUntilMirrorMakerIsRunning(backup, CONNECTOR_LIST, mm2Config, PRIMARY_CLUSTER_ALIAS, BACKUP_CLUSTER_ALIAS); + + // sleep few seconds to have MM2 finish replication so that "end" consumer will consume some record + Thread.sleep(TimeUnit.SECONDS.toMillis(3)); + + String backupTopic = PRIMARY_CLUSTER_ALIAS + "." 
+ topic; + + // consume all records from backup cluster + try (Consumer backupConsumer = backup.kafka().createConsumerAndSubscribeTo(consumerProps, + backupTopic)) { + waitForConsumingAllRecords(backupConsumer, expectedRecords); + } + + try (Admin backupClient = backup.kafka().createAdminClient()) { + // retrieve the consumer group offset from backup cluster + Map remoteOffsets = + backupClient.listConsumerGroupOffsets(consumerGroupName).partitionsToOffsetAndMetadata().get(); + + // pinpoint the offset of the last partition which does not receive records + OffsetAndMetadata offset = remoteOffsets.get(new TopicPartition(backupTopic, NUM_PARTITIONS - 1)); + // offset of the last partition should exist, but its value should be 0 + assertNotNull(offset, "Offset of last partition was not replicated"); + assertEquals(0, offset.offset(), "Offset of last partition is not zero"); + } + } + + @Test + public void testOneWayReplicationWithAutoOffsetSync() throws InterruptedException { + produceMessages(primary, "test-topic-1"); + String consumerGroupName = "consumer-group-testOneWayReplicationWithAutoOffsetSync"; + Map consumerProps = new HashMap() {{ + put("group.id", consumerGroupName); + put("auto.offset.reset", "earliest"); + }}; + // create consumers before starting the connectors so we don't need to wait for discovery + try (Consumer primaryConsumer = primary.kafka().createConsumerAndSubscribeTo(consumerProps, + "test-topic-1")) { + // we need to wait for consuming all the records for MM2 replicating the expected offsets + waitForConsumingAllRecords(primaryConsumer, NUM_RECORDS_PRODUCED); + } + + // enable automated consumer group offset sync + mm2Props.put("sync.group.offsets.enabled", "true"); + mm2Props.put("sync.group.offsets.interval.seconds", "1"); + // one way replication from primary to backup + mm2Props.put(BACKUP_CLUSTER_ALIAS + "->" + PRIMARY_CLUSTER_ALIAS + ".enabled", "false"); + + mm2Config = new MirrorMakerConfig(mm2Props); + + waitUntilMirrorMakerIsRunning(backup, CONNECTOR_LIST, mm2Config, PRIMARY_CLUSTER_ALIAS, BACKUP_CLUSTER_ALIAS); + + // make sure the topic is created in the other cluster + waitForTopicCreated(primary, "backup.test-topic-1"); + waitForTopicCreated(backup, "primary.test-topic-1"); + // create a consumer at backup cluster with same consumer group Id to consume 1 topic + Consumer backupConsumer = backup.kafka().createConsumerAndSubscribeTo( + consumerProps, "primary.test-topic-1"); + + waitForConsumerGroupOffsetSync(backup, backupConsumer, Collections.singletonList("primary.test-topic-1"), + consumerGroupName, NUM_RECORDS_PRODUCED); + + ConsumerRecords records = backupConsumer.poll(CONSUMER_POLL_TIMEOUT_MS); + + // the size of consumer record should be zero, because the offsets of the same consumer group + // have been automatically synchronized from primary to backup by the background job, so no + // more records to consume from the replicated topic by the same consumer group at backup cluster + assertEquals(0, records.count(), "consumer record size is not zero"); + + // now create a new topic in primary cluster + primary.kafka().createTopic("test-topic-2", NUM_PARTITIONS); + // make sure the topic is created in backup cluster + waitForTopicCreated(backup, "primary.test-topic-2"); + + // produce some records to the new topic in primary cluster + produceMessages(primary, "test-topic-2"); + + // create a consumer at primary cluster to consume the new topic + try (Consumer consumer1 = primary.kafka().createConsumerAndSubscribeTo(Collections.singletonMap( + 
"group.id", "consumer-group-1"), "test-topic-2")) { + // we need to wait for consuming all the records for MM2 replicating the expected offsets + waitForConsumingAllRecords(consumer1, NUM_RECORDS_PRODUCED); + } + + // create a consumer at backup cluster with same consumer group Id to consume old and new topic + backupConsumer = backup.kafka().createConsumerAndSubscribeTo(Collections.singletonMap( + "group.id", consumerGroupName), "primary.test-topic-1", "primary.test-topic-2"); + + waitForConsumerGroupOffsetSync(backup, backupConsumer, Arrays.asList("primary.test-topic-1", "primary.test-topic-2"), + consumerGroupName, NUM_RECORDS_PRODUCED); + + records = backupConsumer.poll(CONSUMER_POLL_TIMEOUT_MS); + // similar reasoning as above, no more records to consume by the same consumer group at backup cluster + assertEquals(0, records.count(), "consumer record size is not zero"); + backupConsumer.close(); + } + + @Test + public void testOffsetSyncsTopicsOnTarget() throws Exception { + // move offset-syncs topics to target + mm2Props.put(PRIMARY_CLUSTER_ALIAS + "->" + BACKUP_CLUSTER_ALIAS + ".offset-syncs.topic.location", "target"); + // one way replication from primary to backup + mm2Props.put(BACKUP_CLUSTER_ALIAS + "->" + PRIMARY_CLUSTER_ALIAS + ".enabled", "false"); + + mm2Config = new MirrorMakerConfig(mm2Props); + + waitUntilMirrorMakerIsRunning(backup, CONNECTOR_LIST, mm2Config, PRIMARY_CLUSTER_ALIAS, BACKUP_CLUSTER_ALIAS); + + // Ensure the offset syncs topic is created in the target cluster + waitForTopicCreated(backup.kafka(), "mm2-offset-syncs." + PRIMARY_CLUSTER_ALIAS + ".internal"); + + produceMessages(primary, "test-topic-1"); + + ReplicationPolicy replicationPolicy = new MirrorClient(mm2Config.clientConfig(BACKUP_CLUSTER_ALIAS)).replicationPolicy(); + String remoteTopic = replicationPolicy.formatRemoteTopic(PRIMARY_CLUSTER_ALIAS, "test-topic-1"); + + // Check offsets are pushed to the checkpoint topic + Consumer backupConsumer = backup.kafka().createConsumerAndSubscribeTo(Collections.singletonMap( + "auto.offset.reset", "earliest"), PRIMARY_CLUSTER_ALIAS + ".checkpoints.internal"); + waitForCondition(() -> { + ConsumerRecords records = backupConsumer.poll(Duration.ofSeconds(1L)); + for (ConsumerRecord record : records) { + Checkpoint checkpoint = Checkpoint.deserializeRecord(record); + if (remoteTopic.equals(checkpoint.topicPartition().topic())) { + return true; + } + } + return false; + }, 30_000, + "Unable to find checkpoints for " + PRIMARY_CLUSTER_ALIAS + ".test-topic-1" + ); + + // Ensure no offset-syncs topics have been created on the primary cluster + Set primaryTopics = primary.kafka().createAdminClient().listTopics().names().get(); + assertFalse(primaryTopics.contains("mm2-offset-syncs." + PRIMARY_CLUSTER_ALIAS + ".internal")); + assertFalse(primaryTopics.contains("mm2-offset-syncs." 
+ BACKUP_CLUSTER_ALIAS + ".internal")); + } + + /* + * launch the connectors on kafka connect cluster and check if they are running + */ + protected static void waitUntilMirrorMakerIsRunning(EmbeddedConnectCluster connectCluster, + List> connectorClasses, MirrorMakerConfig mm2Config, + String primary, String backup) throws InterruptedException { + for (Class connector : connectorClasses) { + connectCluster.configureConnector(connector.getSimpleName(), mm2Config.connectorBaseConfig( + new SourceAndTarget(primary, backup), connector)); + } + + // we wait for the connector and tasks to come up for each connector, so that when we do the + // actual testing, we are certain that the tasks are up and running; this will prevent + // flaky tests where the connector and tasks didn't start up in time for the tests to be run + for (Class connector : connectorClasses) { + connectCluster.assertions().assertConnectorAndAtLeastNumTasksAreRunning(connector.getSimpleName(), 1, + "Connector " + connector.getSimpleName() + " tasks did not start in time on cluster: " + connectCluster.getName()); + } + } + + /* + * wait for the topic created on the cluster + */ + protected static void waitForTopicCreated(EmbeddedConnectCluster cluster, String topicName) throws InterruptedException { + try (final Admin adminClient = cluster.kafka().createAdminClient()) { + waitForCondition(() -> adminClient.listTopics().names().get().contains(topicName), TOPIC_SYNC_DURATION_MS, + "Topic: " + topicName + " didn't get created on cluster: " + cluster.getName() + ); + } + } + + /* + * delete all topics of the input kafka cluster + */ + private static void deleteAllTopics(EmbeddedKafkaCluster cluster) throws Exception { + try (final Admin adminClient = cluster.createAdminClient()) { + Set topicsToBeDeleted = adminClient.listTopics().names().get(); + log.debug("Deleting topics: {} ", topicsToBeDeleted); + adminClient.deleteTopics(topicsToBeDeleted).all().get(); + } + } + + /* + * retrieve the config value based on the input cluster, topic and config name + */ + protected static String getTopicConfig(EmbeddedKafkaCluster cluster, String topic, String configName) throws Exception { + try (Admin client = cluster.createAdminClient()) { + Collection cr = Collections.singleton( + new ConfigResource(ConfigResource.Type.TOPIC, topic)); + + DescribeConfigsResult configsResult = client.describeConfigs(cr); + Config allConfigs = (Config) configsResult.all().get().values().toArray()[0]; + return allConfigs.get(configName).value(); + } + } + + /* + * produce messages to the cluster and topic + */ + protected void produceMessages(EmbeddedConnectCluster cluster, String topicName) { + Map recordSent = generateRecords(NUM_RECORDS_PRODUCED); + for (Map.Entry entry : recordSent.entrySet()) { + cluster.kafka().produce(topicName, entry.getKey(), entry.getValue()); + } + } + + /* + * produce messages to the cluster and topic partition less than numPartitions + */ + protected void produceMessages(EmbeddedConnectCluster cluster, String topicName, int numPartitions) { + int cnt = 0; + for (int r = 0; r < NUM_RECORDS_PER_PARTITION; r++) + for (int p = 0; p < numPartitions; p++) + cluster.kafka().produce(topicName, p, "key", "value-" + cnt++); + } + + /* + * given consumer group, topics and expected number of records, make sure the consumer group + * offsets are eventually synced to the expected offset numbers + */ + protected static void waitForConsumerGroupOffsetSync(EmbeddedConnectCluster connect, + Consumer consumer, List topics, String consumerGroupId, int 
numRecords) + throws InterruptedException { + try (Admin adminClient = connect.kafka().createAdminClient()) { + List tps = new ArrayList<>(NUM_PARTITIONS * topics.size()); + for (int partitionIndex = 0; partitionIndex < NUM_PARTITIONS; partitionIndex++) { + for (String topic : topics) { + tps.add(new TopicPartition(topic, partitionIndex)); + } + } + long expectedTotalOffsets = numRecords * topics.size(); + + waitForCondition(() -> { + Map consumerGroupOffsets = + adminClient.listConsumerGroupOffsets(consumerGroupId).partitionsToOffsetAndMetadata().get(); + long consumerGroupOffsetTotal = consumerGroupOffsets.values().stream() + .mapToLong(OffsetAndMetadata::offset).sum(); + + Map offsets = consumer.endOffsets(tps, CONSUMER_POLL_TIMEOUT_MS); + long totalOffsets = offsets.values().stream().mapToLong(l -> l).sum(); + + // make sure the consumer group offsets are synced to expected number + return totalOffsets == expectedTotalOffsets && consumerGroupOffsetTotal > 0; + }, OFFSET_SYNC_DURATION_MS, "Consumer group offset sync is not complete in time"); + } + } + + /* + * make sure the consumer to consume expected number of records + */ + protected static void waitForConsumingAllRecords(Consumer consumer, int numExpectedRecords) + throws InterruptedException { + final AtomicInteger totalConsumedRecords = new AtomicInteger(0); + waitForCondition(() -> { + ConsumerRecords records = consumer.poll(CONSUMER_POLL_TIMEOUT_MS); + return numExpectedRecords == totalConsumedRecords.addAndGet(records.count()); + }, RECORD_CONSUME_DURATION_MS, "Consumer cannot consume all records in time"); + consumer.commitSync(); + } + + /* + * MM2 config to use in integration tests + */ + protected static Map basicMM2Config() { + Map mm2Props = new HashMap<>(); + mm2Props.put("clusters", PRIMARY_CLUSTER_ALIAS + ", " + BACKUP_CLUSTER_ALIAS); + mm2Props.put("max.tasks", "10"); + mm2Props.put("groups", "consumer-group-.*"); + mm2Props.put("sync.topic.acls.enabled", "false"); + mm2Props.put("emit.checkpoints.interval.seconds", "1"); + mm2Props.put("emit.heartbeats.interval.seconds", "1"); + mm2Props.put("refresh.topics.interval.seconds", "1"); + mm2Props.put("refresh.groups.interval.seconds", "1"); + mm2Props.put("checkpoints.topic.replication.factor", "1"); + mm2Props.put("heartbeats.topic.replication.factor", "1"); + mm2Props.put("offset-syncs.topic.replication.factor", "1"); + mm2Props.put("config.storage.replication.factor", "1"); + mm2Props.put("offset.storage.replication.factor", "1"); + mm2Props.put("status.storage.replication.factor", "1"); + mm2Props.put("replication.factor", "1"); + + return mm2Props; + } + + private void createTopics() { + // to verify topic config will be sync-ed across clusters + Map topicConfig = Collections.singletonMap(TopicConfig.CLEANUP_POLICY_CONFIG, TopicConfig.CLEANUP_POLICY_COMPACT); + Map emptyMap = Collections.emptyMap(); + + // increase admin client request timeout value to make the tests reliable. 
+ Properties adminClientConfig = new Properties(); + adminClientConfig.put(AdminClientConfig.REQUEST_TIMEOUT_MS_CONFIG, REQUEST_TIMEOUT_DURATION_MS); + + // create these topics before starting the connectors so we don't need to wait for discovery + primary.kafka().createTopic("test-topic-1", NUM_PARTITIONS, 1, topicConfig, adminClientConfig); + primary.kafka().createTopic("backup.test-topic-1", 1, 1, emptyMap, adminClientConfig); + primary.kafka().createTopic("heartbeats", 1, 1, emptyMap, adminClientConfig); + backup.kafka().createTopic("test-topic-1", NUM_PARTITIONS, 1, emptyMap, adminClientConfig); + backup.kafka().createTopic("primary.test-topic-1", 1, 1, emptyMap, adminClientConfig); + backup.kafka().createTopic("heartbeats", 1, 1, emptyMap, adminClientConfig); + } + + /* + * Generate some consumer activity on both clusters to ensure the checkpoint connector always starts promptly + */ + protected void warmUpConsumer(Map consumerProps) throws InterruptedException { + Consumer dummyConsumer = primary.kafka().createConsumerAndSubscribeTo(consumerProps, "test-topic-1"); + dummyConsumer.poll(CONSUMER_POLL_TIMEOUT_MS); + dummyConsumer.commitSync(); + dummyConsumer.close(); + dummyConsumer = backup.kafka().createConsumerAndSubscribeTo(consumerProps, "test-topic-1"); + dummyConsumer.poll(CONSUMER_POLL_TIMEOUT_MS); + dummyConsumer.commitSync(); + dummyConsumer.close(); + } + + /* + * wait for the topic created on the cluster + */ + private static void waitForTopicCreated(EmbeddedKafkaCluster cluster, String topicName) throws InterruptedException { + try (final Admin adminClient = cluster.createAdminClient()) { + waitForCondition(() -> { + Set topics = adminClient.listTopics().names().get(); + return topics.contains(topicName); + }, OFFSET_SYNC_DURATION_MS, + "Topic: " + topicName + " didn't get created in the cluster" + ); + } + } +} diff --git a/connect/mirror/src/test/java/org/apache/kafka/connect/mirror/integration/MirrorConnectorsIntegrationSSLTest.java b/connect/mirror/src/test/java/org/apache/kafka/connect/mirror/integration/MirrorConnectorsIntegrationSSLTest.java new file mode 100644 index 0000000..eb2af48 --- /dev/null +++ b/connect/mirror/src/test/java/org/apache/kafka/connect/mirror/integration/MirrorConnectorsIntegrationSSLTest.java @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.connect.mirror.integration; + +import java.util.Map; +import java.util.Properties; +import java.util.stream.Collectors; + +import kafka.server.KafkaConfig; +import org.apache.kafka.clients.CommonClientConfigs; +import org.apache.kafka.common.config.SslConfigs; +import org.apache.kafka.common.config.types.Password; +import org.apache.kafka.common.network.Mode; +import org.apache.kafka.test.TestSslUtils; +import org.apache.kafka.test.TestUtils; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Tag; + +/** + * Tests MM2 replication with SSL enabled at backup kafka cluster + */ +@Tag("integration") +public class MirrorConnectorsIntegrationSSLTest extends MirrorConnectorsIntegrationBaseTest { + + @BeforeEach + public void startClusters() throws Exception { + Map sslConfig = TestSslUtils.createSslConfig(false, true, Mode.SERVER, TestUtils.tempFile(), "testCert"); + // enable SSL on backup kafka broker + backupBrokerProps.put(KafkaConfig.ListenersProp(), "SSL://localhost:0"); + backupBrokerProps.put(KafkaConfig.InterBrokerListenerNameProp(), "SSL"); + backupBrokerProps.putAll(sslConfig); + + Properties sslProps = new Properties(); + sslProps.put(SslConfigs.SSL_TRUSTSTORE_LOCATION_CONFIG, sslConfig.get(SslConfigs.SSL_TRUSTSTORE_LOCATION_CONFIG)); + sslProps.put(SslConfigs.SSL_TRUSTSTORE_PASSWORD_CONFIG, ((Password) sslConfig.get(SslConfigs.SSL_TRUSTSTORE_PASSWORD_CONFIG)).value()); + sslProps.put(CommonClientConfigs.SECURITY_PROTOCOL_CONFIG, "SSL"); + + // set SSL config for kafka connect worker + backupWorkerProps.putAll(sslProps.entrySet().stream().collect(Collectors.toMap( + e -> String.valueOf(e.getKey()), e -> String.valueOf(e.getValue())))); + + mm2Props.putAll(sslProps.entrySet().stream().collect(Collectors.toMap( + e -> BACKUP_CLUSTER_ALIAS + "." + e.getKey(), e -> String.valueOf(e.getValue())))); + // set SSL config for producer used by source task in MM2 + mm2Props.putAll(sslProps.entrySet().stream().collect(Collectors.toMap( + e -> BACKUP_CLUSTER_ALIAS + ".producer." + e.getKey(), e -> String.valueOf(e.getValue())))); + + super.startClusters(); + } +} + diff --git a/connect/mirror/src/test/java/org/apache/kafka/connect/mirror/integration/MirrorConnectorsIntegrationTest.java b/connect/mirror/src/test/java/org/apache/kafka/connect/mirror/integration/MirrorConnectorsIntegrationTest.java new file mode 100644 index 0000000..ed82aa9 --- /dev/null +++ b/connect/mirror/src/test/java/org/apache/kafka/connect/mirror/integration/MirrorConnectorsIntegrationTest.java @@ -0,0 +1,23 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.connect.mirror.integration; + +import org.junit.jupiter.api.Tag; + +@Tag("integration") +public class MirrorConnectorsIntegrationTest extends MirrorConnectorsIntegrationBaseTest { +} diff --git a/connect/mirror/src/test/resources/log4j.properties b/connect/mirror/src/test/resources/log4j.properties new file mode 100644 index 0000000..a2ac021 --- /dev/null +++ b/connect/mirror/src/test/resources/log4j.properties @@ -0,0 +1,34 @@ +## +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +## +log4j.rootLogger=ERROR, stdout + +log4j.appender.stdout=org.apache.log4j.ConsoleAppender +log4j.appender.stdout.layout=org.apache.log4j.PatternLayout +# +# The `%X{connector.context}` parameter in the layout includes connector-specific and task-specific information +# in the log message, where appropriate. This makes it easier to identify those log messages that apply to a +# specific connector. Simply add this parameter to the log layout configuration below to include the contextual information. +# +log4j.appender.stdout.layout.ConversionPattern=[%d] %p %X{connector.context}%m (%c:%L)%n +# +# The following line includes no MDC context parameters: +#log4j.appender.stdout.layout.ConversionPattern=[%d] %p %m (%c:%L)%n (%t) + +log4j.logger.org.reflections=OFF +log4j.logger.kafka=OFF +log4j.logger.state.change.logger=OFF +log4j.logger.org.apache.kafka.connect.mirror=INFO diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/cli/ConnectDistributed.java b/connect/runtime/src/main/java/org/apache/kafka/connect/cli/ConnectDistributed.java new file mode 100644 index 0000000..8d93e79 --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/cli/ConnectDistributed.java @@ -0,0 +1,150 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.connect.cli; + +import org.apache.kafka.common.utils.Exit; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.connect.connector.policy.ConnectorClientConfigOverridePolicy; +import org.apache.kafka.connect.runtime.Connect; +import org.apache.kafka.connect.runtime.Worker; +import org.apache.kafka.connect.runtime.WorkerConfig; +import org.apache.kafka.connect.runtime.WorkerConfigTransformer; +import org.apache.kafka.connect.runtime.WorkerInfo; +import org.apache.kafka.connect.runtime.distributed.DistributedConfig; +import org.apache.kafka.connect.runtime.distributed.DistributedHerder; +import org.apache.kafka.connect.runtime.isolation.Plugins; +import org.apache.kafka.connect.runtime.rest.RestServer; +import org.apache.kafka.connect.storage.ConfigBackingStore; +import org.apache.kafka.connect.storage.Converter; +import org.apache.kafka.connect.storage.KafkaConfigBackingStore; +import org.apache.kafka.connect.storage.KafkaOffsetBackingStore; +import org.apache.kafka.connect.storage.KafkaStatusBackingStore; +import org.apache.kafka.connect.storage.StatusBackingStore; +import org.apache.kafka.connect.util.ConnectUtils; +import org.apache.kafka.connect.util.SharedTopicAdmin; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.net.URI; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +/** + *
+ * Command line utility that runs Kafka Connect in distributed mode. In this mode, the process joins a group of other workers + * and work is distributed among them. This is useful for running Connect as a service, where connectors can be + * submitted to the cluster to be automatically executed in a scalable, distributed fashion. This also allows you to + * easily scale out horizontally, elastically adding or removing capacity simply by starting or stopping worker + * instances. + *
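The comment above describes connectors being submitted to a running cluster. A minimal sketch of that workflow, assuming the worker's REST API is reachable on the default port 8083 and using the FileStreamSourceConnector that ships with Kafka; the connector name, file, and topic below are illustrative placeholders.

import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;

public class SubmitConnectorSketch {
    public static void main(String[] args) throws Exception {
        // JSON body for POST /connectors: a connector name plus its config map.
        String body = "{\"name\":\"file-source\",\"config\":{"
                + "\"connector.class\":\"org.apache.kafka.connect.file.FileStreamSourceConnector\","
                + "\"file\":\"/tmp/input.txt\",\"topic\":\"connect-test\"}}";
        HttpRequest request = HttpRequest.newBuilder(URI.create("http://localhost:8083/connectors"))
                .header("Content-Type", "application/json")
                .POST(HttpRequest.BodyPublishers.ofString(body))
                .build();
        // Send the request and print the worker's response.
        HttpResponse<String> response = HttpClient.newHttpClient()
                .send(request, HttpResponse.BodyHandlers.ofString());
        System.out.println(response.statusCode() + " " + response.body());
    }
}

A 201 response would indicate the connector was created and will be scheduled across the workers in the group.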
            + */ +public class ConnectDistributed { + private static final Logger log = LoggerFactory.getLogger(ConnectDistributed.class); + + private final Time time = Time.SYSTEM; + private final long initStart = time.hiResClockMs(); + + public static void main(String[] args) { + + if (args.length < 1 || Arrays.asList(args).contains("--help")) { + log.info("Usage: ConnectDistributed worker.properties"); + Exit.exit(1); + } + + try { + WorkerInfo initInfo = new WorkerInfo(); + initInfo.logAll(); + + String workerPropsFile = args[0]; + Map workerProps = !workerPropsFile.isEmpty() ? + Utils.propsToStringMap(Utils.loadProps(workerPropsFile)) : Collections.emptyMap(); + + ConnectDistributed connectDistributed = new ConnectDistributed(); + Connect connect = connectDistributed.startConnect(workerProps); + + // Shutdown will be triggered by Ctrl-C or via HTTP shutdown request + connect.awaitStop(); + + } catch (Throwable t) { + log.error("Stopping due to error", t); + Exit.exit(2); + } + } + + public Connect startConnect(Map workerProps) { + log.info("Scanning for plugin classes. This might take a moment ..."); + Plugins plugins = new Plugins(workerProps); + plugins.compareAndSwapWithDelegatingLoader(); + DistributedConfig config = new DistributedConfig(workerProps); + + String kafkaClusterId = ConnectUtils.lookupKafkaClusterId(config); + log.debug("Kafka cluster ID: {}", kafkaClusterId); + + RestServer rest = new RestServer(config); + rest.initializeServer(); + + URI advertisedUrl = rest.advertisedUrl(); + String workerId = advertisedUrl.getHost() + ":" + advertisedUrl.getPort(); + + // Create the admin client to be shared by all backing stores. + Map adminProps = new HashMap<>(config.originals()); + ConnectUtils.addMetricsContextProperties(adminProps, config, kafkaClusterId); + SharedTopicAdmin sharedAdmin = new SharedTopicAdmin(adminProps); + + KafkaOffsetBackingStore offsetBackingStore = new KafkaOffsetBackingStore(sharedAdmin); + offsetBackingStore.configure(config); + + ConnectorClientConfigOverridePolicy connectorClientConfigOverridePolicy = plugins.newPlugin( + config.getString(WorkerConfig.CONNECTOR_CLIENT_POLICY_CLASS_CONFIG), + config, ConnectorClientConfigOverridePolicy.class); + + Worker worker = new Worker(workerId, time, plugins, config, offsetBackingStore, connectorClientConfigOverridePolicy); + WorkerConfigTransformer configTransformer = worker.configTransformer(); + + Converter internalValueConverter = worker.getInternalValueConverter(); + StatusBackingStore statusBackingStore = new KafkaStatusBackingStore(time, internalValueConverter, sharedAdmin); + statusBackingStore.configure(config); + + ConfigBackingStore configBackingStore = new KafkaConfigBackingStore( + internalValueConverter, + config, + configTransformer, + sharedAdmin); + + // Pass the shared admin to the distributed herder as an additional AutoCloseable object that should be closed when the + // herder is stopped. This is easier than having to track and own the lifecycle ourselves. 
+ DistributedHerder herder = new DistributedHerder(config, time, worker, + kafkaClusterId, statusBackingStore, configBackingStore, + advertisedUrl.toString(), connectorClientConfigOverridePolicy, sharedAdmin); + + final Connect connect = new Connect(herder, rest); + log.info("Kafka Connect distributed worker initialization took {}ms", time.hiResClockMs() - initStart); + try { + connect.start(); + } catch (Exception e) { + log.error("Failed to start Connect", e); + connect.stop(); + Exit.exit(3); + } + + return connect; + } + +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/cli/ConnectStandalone.java b/connect/runtime/src/main/java/org/apache/kafka/connect/cli/ConnectStandalone.java new file mode 100644 index 0000000..19cc115 --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/cli/ConnectStandalone.java @@ -0,0 +1,130 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.cli; + +import org.apache.kafka.common.utils.Exit; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.connect.connector.policy.ConnectorClientConfigOverridePolicy; +import org.apache.kafka.connect.runtime.Connect; +import org.apache.kafka.connect.runtime.ConnectorConfig; +import org.apache.kafka.connect.runtime.Herder; +import org.apache.kafka.connect.runtime.Worker; +import org.apache.kafka.connect.runtime.WorkerConfig; +import org.apache.kafka.connect.runtime.WorkerInfo; +import org.apache.kafka.connect.runtime.isolation.Plugins; +import org.apache.kafka.connect.runtime.rest.RestServer; +import org.apache.kafka.connect.runtime.rest.entities.ConnectorInfo; +import org.apache.kafka.connect.runtime.standalone.StandaloneConfig; +import org.apache.kafka.connect.runtime.standalone.StandaloneHerder; +import org.apache.kafka.connect.storage.FileOffsetBackingStore; +import org.apache.kafka.connect.util.ConnectUtils; +import org.apache.kafka.connect.util.FutureCallback; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.net.URI; +import java.util.Arrays; +import java.util.Collections; +import java.util.Map; + +/** + *
            + * Command line utility that runs Kafka Connect as a standalone process. In this mode, work is not + * distributed. Instead, all the normal Connect machinery works within a single process. This is + * useful for ad hoc, small, or experimental jobs. + *
            + *
            + * By default, no job configs or offset data is persistent. You can make jobs persistent and + * fault tolerant by overriding the settings to use file storage for both. + *
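As a rough sketch of the paragraph above: the standalone worker is started with a worker properties file followed by one or more connector properties files, and setting offset.storage.file.filename in the worker properties is what makes source offsets survive restarts. The paths below are placeholders.

import org.apache.kafka.connect.cli.ConnectStandalone;

public class StandaloneLauncherSketch {
    public static void main(String[] args) {
        // Equivalent to: bin/connect-standalone.sh config/connect-standalone.properties config/my-connector.properties
        // For persistence, the worker properties file would typically include, for example:
        //   offset.storage.file.filename=/tmp/connect.offsets
        ConnectStandalone.main(new String[] {
            "config/connect-standalone.properties",   // worker config (placeholder path)
            "config/my-connector.properties"          // connector config (placeholder path)
        });
    }
}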
            + */ +public class ConnectStandalone { + private static final Logger log = LoggerFactory.getLogger(ConnectStandalone.class); + + public static void main(String[] args) { + + if (args.length < 2 || Arrays.asList(args).contains("--help")) { + log.info("Usage: ConnectStandalone worker.properties connector1.properties [connector2.properties ...]"); + Exit.exit(1); + } + + try { + Time time = Time.SYSTEM; + log.info("Kafka Connect standalone worker initializing ..."); + long initStart = time.hiResClockMs(); + WorkerInfo initInfo = new WorkerInfo(); + initInfo.logAll(); + + String workerPropsFile = args[0]; + Map workerProps = !workerPropsFile.isEmpty() ? + Utils.propsToStringMap(Utils.loadProps(workerPropsFile)) : Collections.emptyMap(); + + log.info("Scanning for plugin classes. This might take a moment ..."); + Plugins plugins = new Plugins(workerProps); + plugins.compareAndSwapWithDelegatingLoader(); + StandaloneConfig config = new StandaloneConfig(workerProps); + + String kafkaClusterId = ConnectUtils.lookupKafkaClusterId(config); + log.debug("Kafka cluster ID: {}", kafkaClusterId); + + RestServer rest = new RestServer(config); + rest.initializeServer(); + + URI advertisedUrl = rest.advertisedUrl(); + String workerId = advertisedUrl.getHost() + ":" + advertisedUrl.getPort(); + + ConnectorClientConfigOverridePolicy connectorClientConfigOverridePolicy = plugins.newPlugin( + config.getString(WorkerConfig.CONNECTOR_CLIENT_POLICY_CLASS_CONFIG), + config, ConnectorClientConfigOverridePolicy.class); + Worker worker = new Worker(workerId, time, plugins, config, new FileOffsetBackingStore(), + connectorClientConfigOverridePolicy); + + Herder herder = new StandaloneHerder(worker, kafkaClusterId, connectorClientConfigOverridePolicy); + final Connect connect = new Connect(herder, rest); + log.info("Kafka Connect standalone worker initialization took {}ms", time.hiResClockMs() - initStart); + + try { + connect.start(); + for (final String connectorPropsFile : Arrays.copyOfRange(args, 1, args.length)) { + Map connectorProps = Utils.propsToStringMap(Utils.loadProps(connectorPropsFile)); + FutureCallback> cb = new FutureCallback<>((error, info) -> { + if (error != null) + log.error("Failed to create job for {}", connectorPropsFile); + else + log.info("Created connector {}", info.result().name()); + }); + herder.putConnectorConfig( + connectorProps.get(ConnectorConfig.NAME_CONFIG), + connectorProps, false, cb); + cb.get(); + } + } catch (Throwable t) { + log.error("Stopping after connector error", t); + connect.stop(); + Exit.exit(3); + } + + // Shutdown will be triggered by Ctrl-C or via HTTP shutdown request + connect.awaitStop(); + + } catch (Throwable t) { + log.error("Stopping due to error", t); + Exit.exit(2); + } + } +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/connector/policy/AbstractConnectorClientConfigOverridePolicy.java b/connect/runtime/src/main/java/org/apache/kafka/connect/connector/policy/AbstractConnectorClientConfigOverridePolicy.java new file mode 100644 index 0000000..859041a --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/connector/policy/AbstractConnectorClientConfigOverridePolicy.java @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.connect.connector.policy; + +import org.apache.kafka.common.config.ConfigValue; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +public abstract class AbstractConnectorClientConfigOverridePolicy implements ConnectorClientConfigOverridePolicy { + + @Override + public void close() throws Exception { + + } + + @Override + public final List validate(ConnectorClientConfigRequest connectorClientConfigRequest) { + Map inputConfig = connectorClientConfigRequest.clientProps(); + return inputConfig.entrySet().stream().map(this::configValue).collect(Collectors.toList()); + } + + protected ConfigValue configValue(Map.Entry configEntry) { + ConfigValue configValue = + new ConfigValue(configEntry.getKey(), configEntry.getValue(), new ArrayList<>(), new ArrayList<>()); + validate(configValue); + return configValue; + } + + protected void validate(ConfigValue configValue) { + if (!isAllowed(configValue)) { + configValue.addErrorMessage("The '" + policyName() + "' policy does not allow '" + configValue.name() + + "' to be overridden in the connector configuration."); + } + } + + protected abstract String policyName(); + + protected abstract boolean isAllowed(ConfigValue configValue); +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/connector/policy/AllConnectorClientConfigOverridePolicy.java b/connect/runtime/src/main/java/org/apache/kafka/connect/connector/policy/AllConnectorClientConfigOverridePolicy.java new file mode 100644 index 0000000..e808857 --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/connector/policy/AllConnectorClientConfigOverridePolicy.java @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.connect.connector.policy; + +import org.apache.kafka.common.config.ConfigValue; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Map; + +/** + * Allows all client configurations to be overridden via the connector configs by setting {@code connector.client.config.override.policy} to {@code All} + */ +public class AllConnectorClientConfigOverridePolicy extends AbstractConnectorClientConfigOverridePolicy { + private static final Logger log = LoggerFactory.getLogger(AllConnectorClientConfigOverridePolicy.class); + + @Override + protected String policyName() { + return "All"; + } + + @Override + protected boolean isAllowed(ConfigValue configValue) { + return true; + } + + @Override + public void configure(Map configs) { + log.info("Setting up All Policy for ConnectorClientConfigOverride. This will allow all client configurations to be overridden"); + } +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/connector/policy/NoneConnectorClientConfigOverridePolicy.java b/connect/runtime/src/main/java/org/apache/kafka/connect/connector/policy/NoneConnectorClientConfigOverridePolicy.java new file mode 100644 index 0000000..9b414c4 --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/connector/policy/NoneConnectorClientConfigOverridePolicy.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.connect.connector.policy; + +import org.apache.kafka.common.config.ConfigValue; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Map; + +/** + * Disallow any client configuration to be overridden via the connector configs by setting {@code connector.client.config.override.policy} to {@code None}. + * This is the default behavior. + */ +public class NoneConnectorClientConfigOverridePolicy extends AbstractConnectorClientConfigOverridePolicy { + private static final Logger log = LoggerFactory.getLogger(NoneConnectorClientConfigOverridePolicy.class); + + @Override + protected String policyName() { + return "None"; + } + + @Override + protected boolean isAllowed(ConfigValue configValue) { + return false; + } + + @Override + public void configure(Map configs) { + log.info("Setting up None Policy for ConnectorClientConfigOverride. 
This will disallow any client configuration to be overridden"); + } +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/connector/policy/PrincipalConnectorClientConfigOverridePolicy.java b/connect/runtime/src/main/java/org/apache/kafka/connect/connector/policy/PrincipalConnectorClientConfigOverridePolicy.java new file mode 100644 index 0000000..492c5a9 --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/connector/policy/PrincipalConnectorClientConfigOverridePolicy.java @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.connect.connector.policy; + +import org.apache.kafka.clients.CommonClientConfigs; +import org.apache.kafka.common.config.ConfigValue; +import org.apache.kafka.common.config.SaslConfigs; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +/** + * Allows all {@code sasl} configurations to be overridden via the connector configs by setting {@code connector.client.config.override.policy} to + * {@code Principal}. This allows to set a principal per connector. + */ +public class PrincipalConnectorClientConfigOverridePolicy extends AbstractConnectorClientConfigOverridePolicy { + private static final Logger log = LoggerFactory.getLogger(PrincipalConnectorClientConfigOverridePolicy.class); + + private static final Set ALLOWED_CONFIG = + Stream.of(SaslConfigs.SASL_JAAS_CONFIG, SaslConfigs.SASL_MECHANISM, CommonClientConfigs.SECURITY_PROTOCOL_CONFIG). + collect(Collectors.toSet()); + + @Override + protected String policyName() { + return "Principal"; + } + + @Override + protected boolean isAllowed(ConfigValue configValue) { + return ALLOWED_CONFIG.contains(configValue.name()); + } + + @Override + public void configure(Map configs) { + log.info("Setting up Principal policy for ConnectorClientConfigOverride. This will allow `sasl` client configuration to be " + + "overridden."); + } +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/converters/ByteArrayConverter.java b/connect/runtime/src/main/java/org/apache/kafka/connect/converters/ByteArrayConverter.java new file mode 100644 index 0000000..34c552e --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/converters/ByteArrayConverter.java @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.connect.converters; + +import org.apache.kafka.common.config.ConfigDef; +import org.apache.kafka.connect.data.Schema; +import org.apache.kafka.connect.data.SchemaAndValue; +import org.apache.kafka.connect.errors.DataException; +import org.apache.kafka.connect.storage.Converter; +import org.apache.kafka.connect.storage.ConverterConfig; +import org.apache.kafka.connect.storage.HeaderConverter; + +import java.util.Map; + +/** + * Pass-through converter for raw byte data. + * + * This implementation currently does nothing with the topic names or header names. + */ +public class ByteArrayConverter implements Converter, HeaderConverter { + + private static final ConfigDef CONFIG_DEF = ConverterConfig.newConfigDef(); + + @Override + public ConfigDef config() { + return CONFIG_DEF; + } + + @Override + public void configure(Map configs) { + } + + @Override + public void configure(Map configs, boolean isKey) { + } + + @Override + public byte[] fromConnectData(String topic, Schema schema, Object value) { + if (schema != null && schema.type() != Schema.Type.BYTES) + throw new DataException("Invalid schema type for ByteArrayConverter: " + schema.type().toString()); + + if (value != null && !(value instanceof byte[])) + throw new DataException("ByteArrayConverter is not compatible with objects of type " + value.getClass()); + + return (byte[]) value; + } + + @Override + public SchemaAndValue toConnectData(String topic, byte[] value) { + return new SchemaAndValue(Schema.OPTIONAL_BYTES_SCHEMA, value); + } + + @Override + public byte[] fromConnectHeader(String topic, String headerKey, Schema schema, Object value) { + return fromConnectData(topic, schema, value); + } + + @Override + public SchemaAndValue toConnectHeader(String topic, String headerKey, byte[] value) { + return toConnectData(topic, value); + } + + @Override + public void close() { + // do nothing + } +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/converters/DoubleConverter.java b/connect/runtime/src/main/java/org/apache/kafka/connect/converters/DoubleConverter.java new file mode 100644 index 0000000..684caa1 --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/converters/DoubleConverter.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.connect.converters; + +import org.apache.kafka.common.serialization.DoubleDeserializer; +import org.apache.kafka.common.serialization.DoubleSerializer; +import org.apache.kafka.connect.data.Schema; +import org.apache.kafka.connect.storage.Converter; +import org.apache.kafka.connect.storage.HeaderConverter; + +/** + * {@link Converter} and {@link HeaderConverter} implementation that only supports serializing to and deserializing from double values. + * It does support handling nulls. When converting from bytes to Kafka Connect format, the converter will always return an + * optional FLOAT64 schema. + *
            + * This implementation currently does nothing with the topic names or header names. + */ +public class DoubleConverter extends NumberConverter { + + public DoubleConverter() { + super("double", Schema.OPTIONAL_FLOAT64_SCHEMA, new DoubleSerializer(), new DoubleDeserializer()); + } +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/converters/FloatConverter.java b/connect/runtime/src/main/java/org/apache/kafka/connect/converters/FloatConverter.java new file mode 100644 index 0000000..3f92b96 --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/converters/FloatConverter.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.converters; + +import org.apache.kafka.common.serialization.FloatDeserializer; +import org.apache.kafka.common.serialization.FloatSerializer; +import org.apache.kafka.connect.data.Schema; +import org.apache.kafka.connect.storage.Converter; +import org.apache.kafka.connect.storage.HeaderConverter; + +/** + * {@link Converter} and {@link HeaderConverter} implementation that only supports serializing to and deserializing from float values. + * It does support handling nulls. When converting from bytes to Kafka Connect format, the converter will always return an + * optional FLOAT32 schema. + *
            + * This implementation currently does nothing with the topic names or header names. + */ +public class FloatConverter extends NumberConverter { + + public FloatConverter() { + super("float", Schema.OPTIONAL_FLOAT32_SCHEMA, new FloatSerializer(), new FloatDeserializer()); + } +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/converters/IntegerConverter.java b/connect/runtime/src/main/java/org/apache/kafka/connect/converters/IntegerConverter.java new file mode 100644 index 0000000..f5388ce --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/converters/IntegerConverter.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.converters; + +import org.apache.kafka.common.serialization.IntegerDeserializer; +import org.apache.kafka.common.serialization.IntegerSerializer; +import org.apache.kafka.connect.data.Schema; +import org.apache.kafka.connect.storage.Converter; +import org.apache.kafka.connect.storage.HeaderConverter; + +/** + * {@link Converter} and {@link HeaderConverter} implementation that only supports serializing to and deserializing from integer values. + * It does support handling nulls. When converting from bytes to Kafka Connect format, the converter will always return an + * optional INT32 schema. + *
            + * This implementation currently does nothing with the topic names or header names. + */ +public class IntegerConverter extends NumberConverter { + + public IntegerConverter() { + super("integer", Schema.OPTIONAL_INT32_SCHEMA, new IntegerSerializer(), new IntegerDeserializer()); + } +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/converters/LongConverter.java b/connect/runtime/src/main/java/org/apache/kafka/connect/converters/LongConverter.java new file mode 100644 index 0000000..f91f4fa --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/converters/LongConverter.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.converters; + +import org.apache.kafka.common.serialization.LongDeserializer; +import org.apache.kafka.common.serialization.LongSerializer; +import org.apache.kafka.connect.data.Schema; +import org.apache.kafka.connect.storage.Converter; +import org.apache.kafka.connect.storage.HeaderConverter; + +/** + * {@link Converter} and {@link HeaderConverter} implementation that only supports serializing to and deserializing from long values. + * It does support handling nulls. When converting from bytes to Kafka Connect format, the converter will always return an + * optional INT64 schema. + *
            + * This implementation currently does nothing with the topic names or header names. + */ +public class LongConverter extends NumberConverter { + + public LongConverter() { + super("long", Schema.OPTIONAL_INT64_SCHEMA, new LongSerializer(), new LongDeserializer()); + } + +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/converters/NumberConverter.java b/connect/runtime/src/main/java/org/apache/kafka/connect/converters/NumberConverter.java new file mode 100644 index 0000000..0af4aac --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/converters/NumberConverter.java @@ -0,0 +1,127 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.converters; + +import org.apache.kafka.common.config.ConfigDef; +import org.apache.kafka.common.errors.SerializationException; +import org.apache.kafka.common.serialization.Deserializer; +import org.apache.kafka.common.serialization.Serializer; +import org.apache.kafka.connect.data.Schema; +import org.apache.kafka.connect.data.SchemaAndValue; +import org.apache.kafka.connect.errors.DataException; +import org.apache.kafka.connect.storage.Converter; +import org.apache.kafka.connect.storage.ConverterType; +import org.apache.kafka.connect.storage.HeaderConverter; +import org.apache.kafka.connect.storage.StringConverterConfig; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + +/** + * {@link Converter} and {@link HeaderConverter} implementation that only supports serializing to and deserializing from number values. + * It does support handling nulls. When converting from bytes to Kafka Connect format, the converter will always return the specified + * schema. + *
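To make the round-trip behaviour described above concrete, here is a minimal sketch using the IntegerConverter subclass added earlier in this patch; the topic name and value are arbitrary.

import java.util.Collections;

import org.apache.kafka.connect.converters.IntegerConverter;
import org.apache.kafka.connect.data.Schema;
import org.apache.kafka.connect.data.SchemaAndValue;

public class IntegerConverterRoundTrip {
    public static void main(String[] args) {
        IntegerConverter converter = new IntegerConverter();
        converter.configure(Collections.emptyMap(), false);  // configure as a value converter

        // Serialize an Integer using the wrapped IntegerSerializer.
        byte[] bytes = converter.fromConnectData("my-topic", Schema.OPTIONAL_INT32_SCHEMA, 42);

        // Deserialize back; the returned schema is always the optional INT32 schema.
        SchemaAndValue result = converter.toConnectData("my-topic", bytes);
        System.out.println(result.value());  // 42
    }
}

Passing a null value would round-trip to a null payload, which matches the null handling mentioned above.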
            + * This implementation currently does nothing with the topic names or header names. + */ +abstract class NumberConverter implements Converter, HeaderConverter { + + private final Serializer serializer; + private final Deserializer deserializer; + private final String typeName; + private final Schema schema; + + /** + * Create the converter. + * + * @param typeName the displayable name of the type; may not be null + * @param schema the optional schema to be used for all deserialized forms; may not be null + * @param serializer the serializer; may not be null + * @param deserializer the deserializer; may not be null + */ + protected NumberConverter(String typeName, Schema schema, Serializer serializer, Deserializer deserializer) { + this.typeName = typeName; + this.schema = schema; + this.serializer = serializer; + this.deserializer = deserializer; + assert this.serializer != null; + assert this.deserializer != null; + assert this.typeName != null; + assert this.schema != null; + } + + @Override + public ConfigDef config() { + return NumberConverterConfig.configDef(); + } + + @Override + public void configure(Map configs) { + NumberConverterConfig conf = new NumberConverterConfig(configs); + boolean isKey = conf.type() == ConverterType.KEY; + serializer.configure(configs, isKey); + deserializer.configure(configs, isKey); + + } + + @Override + public void configure(Map configs, boolean isKey) { + Map conf = new HashMap<>(configs); + conf.put(StringConverterConfig.TYPE_CONFIG, isKey ? ConverterType.KEY.getName() : ConverterType.VALUE.getName()); + configure(conf); + } + + @SuppressWarnings("unchecked") + protected T cast(Object value) { + return (T) value; + } + + @Override + public byte[] fromConnectData(String topic, Schema schema, Object value) { + try { + return serializer.serialize(topic, value == null ? null : cast(value)); + } catch (ClassCastException e) { + throw new DataException("Failed to serialize to " + typeName + " (was " + value.getClass() + "): ", e); + } catch (SerializationException e) { + throw new DataException("Failed to serialize to " + typeName + ": ", e); + } + } + + @Override + public SchemaAndValue toConnectData(String topic, byte[] value) { + try { + return new SchemaAndValue(schema, deserializer.deserialize(topic, value)); + } catch (SerializationException e) { + throw new DataException("Failed to deserialize " + typeName + ": ", e); + } + } + + @Override + public byte[] fromConnectHeader(String topic, String headerKey, Schema schema, Object value) { + return fromConnectData(topic, schema, value); + } + + @Override + public SchemaAndValue toConnectHeader(String topic, String headerKey, byte[] value) { + return toConnectData(topic, value); + } + + @Override + public void close() throws IOException { + } +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/converters/NumberConverterConfig.java b/connect/runtime/src/main/java/org/apache/kafka/connect/converters/NumberConverterConfig.java new file mode 100644 index 0000000..49ad986 --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/converters/NumberConverterConfig.java @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.converters; + +import org.apache.kafka.common.config.ConfigDef; +import org.apache.kafka.connect.storage.ConverterConfig; + +import java.util.Map; + +/** + * Configuration options for instances of {@link LongConverter}, {@link IntegerConverter}, {@link ShortConverter}, {@link DoubleConverter}, + * and {@link FloatConverter} instances. + */ +public class NumberConverterConfig extends ConverterConfig { + + private final static ConfigDef CONFIG = ConverterConfig.newConfigDef(); + + public static ConfigDef configDef() { + return CONFIG; + } + + public NumberConverterConfig(Map props) { + super(CONFIG, props); + } +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/converters/ShortConverter.java b/connect/runtime/src/main/java/org/apache/kafka/connect/converters/ShortConverter.java new file mode 100644 index 0000000..1c455b1 --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/converters/ShortConverter.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.converters; + +import org.apache.kafka.common.serialization.ShortDeserializer; +import org.apache.kafka.common.serialization.ShortSerializer; +import org.apache.kafka.connect.data.Schema; +import org.apache.kafka.connect.storage.Converter; +import org.apache.kafka.connect.storage.HeaderConverter; + +/** + * {@link Converter} and {@link HeaderConverter} implementation that only supports serializing to and deserializing from short values. + * It does support handling nulls. When converting from bytes to Kafka Connect format, the converter will always return an + * optional INT16 schema. + *
            + * This implementation currently does nothing with the topic names or header names. + */ +public class ShortConverter extends NumberConverter { + + public ShortConverter() { + super("short", Schema.OPTIONAL_INT16_SCHEMA, new ShortSerializer(), new ShortDeserializer()); + } +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/AbstractHerder.java b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/AbstractHerder.java new file mode 100644 index 0000000..30555ef --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/AbstractHerder.java @@ -0,0 +1,746 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.runtime; + +import org.apache.kafka.clients.producer.ProducerConfig; +import org.apache.kafka.common.config.AbstractConfig; +import org.apache.kafka.common.config.Config; +import org.apache.kafka.common.config.ConfigDef; +import org.apache.kafka.common.config.ConfigDef.ConfigKey; +import org.apache.kafka.common.config.ConfigDef.Type; +import org.apache.kafka.common.config.ConfigTransformer; +import org.apache.kafka.common.config.ConfigValue; +import org.apache.kafka.connect.connector.Connector; +import org.apache.kafka.connect.connector.policy.ConnectorClientConfigOverridePolicy; +import org.apache.kafka.connect.connector.policy.ConnectorClientConfigRequest; +import org.apache.kafka.connect.errors.NotFoundException; +import org.apache.kafka.connect.runtime.distributed.ClusterConfigState; +import org.apache.kafka.connect.runtime.isolation.Plugins; +import org.apache.kafka.connect.runtime.rest.entities.ActiveTopicsInfo; +import org.apache.kafka.connect.runtime.rest.entities.ConfigInfo; +import org.apache.kafka.connect.runtime.rest.entities.ConfigInfos; +import org.apache.kafka.connect.runtime.rest.entities.ConfigKeyInfo; +import org.apache.kafka.connect.runtime.rest.entities.ConfigValueInfo; +import org.apache.kafka.connect.runtime.rest.entities.ConnectorInfo; +import org.apache.kafka.connect.runtime.rest.entities.ConnectorStateInfo; +import org.apache.kafka.connect.runtime.rest.entities.ConnectorType; +import org.apache.kafka.connect.runtime.rest.errors.BadRequestException; +import org.apache.kafka.connect.source.SourceConnector; +import org.apache.kafka.connect.storage.ConfigBackingStore; +import org.apache.kafka.connect.storage.StatusBackingStore; +import org.apache.kafka.connect.util.Callback; +import org.apache.kafka.connect.util.ConnectorTaskId; + +import java.io.ByteArrayOutputStream; +import java.io.PrintStream; +import java.io.UnsupportedEncodingException; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import 
java.util.HashSet; +import java.util.LinkedHashMap; +import java.util.LinkedHashSet; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.regex.Matcher; +import java.util.regex.Pattern; +import java.util.stream.Collectors; + +/** + * Abstract Herder implementation which handles connector/task lifecycle tracking. Extensions + * must invoke the lifecycle hooks appropriately. + * + * This class takes the following approach for sending status updates to the backing store: + * + * 1) When the connector or task is starting, we overwrite the previous state blindly. This ensures that + * every rebalance will reset the state of tasks to the proper state. The intuition is that there should + * be less chance of write conflicts when the worker has just received its assignment and is starting tasks. + * In particular, this prevents us from depending on the generation absolutely. If the group disappears + * and the generation is reset, then we'll overwrite the status information with the older (and larger) + * generation with the updated one. The danger of this approach is that slow starting tasks may cause the + * status to be overwritten after a rebalance has completed. + * + * 2) If the connector or task fails or is shutdown, we use {@link StatusBackingStore#putSafe(ConnectorStatus)}, + * which provides a little more protection if the worker is no longer in the group (in which case the + * task may have already been started on another worker). Obviously this is still racy. If the task has just + * started on another worker, we may not have the updated status cached yet. In this case, we'll overwrite + * the value which will cause the state to be inconsistent (most likely until the next rebalance). Until + * we have proper producer groups with fenced groups, there is not much else we can do. 
+ */ +public abstract class AbstractHerder implements Herder, TaskStatus.Listener, ConnectorStatus.Listener { + + private final String workerId; + protected final Worker worker; + private final String kafkaClusterId; + protected final StatusBackingStore statusBackingStore; + protected final ConfigBackingStore configBackingStore; + private final ConnectorClientConfigOverridePolicy connectorClientConfigOverridePolicy; + protected volatile boolean running = false; + private final ExecutorService connectorExecutor; + + private ConcurrentMap tempConnectors = new ConcurrentHashMap<>(); + + public AbstractHerder(Worker worker, + String workerId, + String kafkaClusterId, + StatusBackingStore statusBackingStore, + ConfigBackingStore configBackingStore, + ConnectorClientConfigOverridePolicy connectorClientConfigOverridePolicy) { + this.worker = worker; + this.worker.herder = this; + this.workerId = workerId; + this.kafkaClusterId = kafkaClusterId; + this.statusBackingStore = statusBackingStore; + this.configBackingStore = configBackingStore; + this.connectorClientConfigOverridePolicy = connectorClientConfigOverridePolicy; + this.connectorExecutor = Executors.newCachedThreadPool(); + } + + @Override + public String kafkaClusterId() { + return kafkaClusterId; + } + + protected abstract int generation(); + + protected void startServices() { + this.worker.start(); + this.statusBackingStore.start(); + this.configBackingStore.start(); + } + + protected void stopServices() { + this.statusBackingStore.stop(); + this.configBackingStore.stop(); + this.worker.stop(); + this.connectorExecutor.shutdown(); + } + + @Override + public boolean isRunning() { + return running; + } + + @Override + public void onStartup(String connector) { + statusBackingStore.put(new ConnectorStatus(connector, ConnectorStatus.State.RUNNING, + workerId, generation())); + } + + @Override + public void onPause(String connector) { + statusBackingStore.put(new ConnectorStatus(connector, ConnectorStatus.State.PAUSED, + workerId, generation())); + } + + @Override + public void onResume(String connector) { + statusBackingStore.put(new ConnectorStatus(connector, TaskStatus.State.RUNNING, + workerId, generation())); + } + + @Override + public void onShutdown(String connector) { + statusBackingStore.putSafe(new ConnectorStatus(connector, ConnectorStatus.State.UNASSIGNED, + workerId, generation())); + } + + @Override + public void onFailure(String connector, Throwable cause) { + statusBackingStore.putSafe(new ConnectorStatus(connector, ConnectorStatus.State.FAILED, + trace(cause), workerId, generation())); + } + + @Override + public void onStartup(ConnectorTaskId id) { + statusBackingStore.put(new TaskStatus(id, TaskStatus.State.RUNNING, workerId, generation())); + } + + @Override + public void onFailure(ConnectorTaskId id, Throwable cause) { + statusBackingStore.putSafe(new TaskStatus(id, TaskStatus.State.FAILED, workerId, generation(), trace(cause))); + } + + @Override + public void onShutdown(ConnectorTaskId id) { + statusBackingStore.putSafe(new TaskStatus(id, TaskStatus.State.UNASSIGNED, workerId, generation())); + } + + @Override + public void onResume(ConnectorTaskId id) { + statusBackingStore.put(new TaskStatus(id, TaskStatus.State.RUNNING, workerId, generation())); + } + + @Override + public void onPause(ConnectorTaskId id) { + statusBackingStore.put(new TaskStatus(id, TaskStatus.State.PAUSED, workerId, generation())); + } + + @Override + public void onDeletion(String connector) { + for (TaskStatus status : 
statusBackingStore.getAll(connector)) + onDeletion(status.id()); + statusBackingStore.put(new ConnectorStatus(connector, ConnectorStatus.State.DESTROYED, workerId, generation())); + } + + @Override + public void onDeletion(ConnectorTaskId id) { + statusBackingStore.put(new TaskStatus(id, TaskStatus.State.DESTROYED, workerId, generation())); + } + + public void onRestart(String connector) { + statusBackingStore.put(new ConnectorStatus(connector, ConnectorStatus.State.RESTARTING, + workerId, generation())); + } + + public void onRestart(ConnectorTaskId id) { + statusBackingStore.put(new TaskStatus(id, TaskStatus.State.RESTARTING, workerId, generation())); + } + + @Override + public void pauseConnector(String connector) { + if (!configBackingStore.contains(connector)) + throw new NotFoundException("Unknown connector " + connector); + configBackingStore.putTargetState(connector, TargetState.PAUSED); + } + + @Override + public void resumeConnector(String connector) { + if (!configBackingStore.contains(connector)) + throw new NotFoundException("Unknown connector " + connector); + configBackingStore.putTargetState(connector, TargetState.STARTED); + } + + @Override + public Plugins plugins() { + return worker.getPlugins(); + } + + /* + * Retrieves raw config map by connector name. + */ + protected abstract Map rawConfig(String connName); + + @Override + public void connectorConfig(String connName, Callback> callback) { + // Subset of connectorInfo, so piggy back on that implementation + connectorInfo(connName, (error, result) -> { + if (error != null) + callback.onCompletion(error, null); + else + callback.onCompletion(null, result.config()); + }); + } + + @Override + public Collection connectors() { + return configBackingStore.snapshot().connectors(); + } + + @Override + public ConnectorInfo connectorInfo(String connector) { + final ClusterConfigState configState = configBackingStore.snapshot(); + + if (!configState.contains(connector)) + return null; + Map config = configState.rawConnectorConfig(connector); + + return new ConnectorInfo( + connector, + config, + configState.tasks(connector), + connectorTypeForClass(config.get(ConnectorConfig.CONNECTOR_CLASS_CONFIG)) + ); + } + + protected Map> buildTasksConfig(String connector) { + final ClusterConfigState configState = configBackingStore.snapshot(); + + if (!configState.contains(connector)) + return Collections.emptyMap(); + + Map> configs = new HashMap<>(); + for (ConnectorTaskId cti : configState.tasks(connector)) { + configs.put(cti, configState.taskConfig(cti)); + } + + return configs; + } + + @Override + public ConnectorStateInfo connectorStatus(String connName) { + ConnectorStatus connector = statusBackingStore.get(connName); + if (connector == null) + throw new NotFoundException("No status found for connector " + connName); + + Collection tasks = statusBackingStore.getAll(connName); + + ConnectorStateInfo.ConnectorState connectorState = new ConnectorStateInfo.ConnectorState( + connector.state().toString(), connector.workerId(), connector.trace()); + List taskStates = new ArrayList<>(); + + for (TaskStatus status : tasks) { + taskStates.add(new ConnectorStateInfo.TaskState(status.id().task(), + status.state().toString(), status.workerId(), status.trace())); + } + + Collections.sort(taskStates); + + Map conf = rawConfig(connName); + return new ConnectorStateInfo(connName, connectorState, taskStates, + conf == null ? 
ConnectorType.UNKNOWN : connectorTypeForClass(conf.get(ConnectorConfig.CONNECTOR_CLASS_CONFIG))); + } + + @Override + public ActiveTopicsInfo connectorActiveTopics(String connName) { + Collection topics = statusBackingStore.getAllTopics(connName).stream() + .map(TopicStatus::topic) + .collect(Collectors.toList()); + return new ActiveTopicsInfo(connName, topics); + } + + @Override + public void resetConnectorActiveTopics(String connName) { + statusBackingStore.getAllTopics(connName).stream() + .forEach(status -> statusBackingStore.deleteTopic(status.connector(), status.topic())); + } + + @Override + public StatusBackingStore statusBackingStore() { + return statusBackingStore; + } + + @Override + public ConnectorStateInfo.TaskState taskStatus(ConnectorTaskId id) { + TaskStatus status = statusBackingStore.get(id); + + if (status == null) + throw new NotFoundException("No status found for task " + id); + + return new ConnectorStateInfo.TaskState(id.task(), status.state().toString(), + status.workerId(), status.trace()); + } + + protected Map validateBasicConnectorConfig(Connector connector, + ConfigDef configDef, + Map config) { + return configDef.validateAll(config); + } + + @Override + public void validateConnectorConfig(Map connectorProps, Callback callback) { + validateConnectorConfig(connectorProps, callback, true); + } + + @Override + public void validateConnectorConfig(Map connectorProps, Callback callback, boolean doLog) { + connectorExecutor.submit(() -> { + try { + ConfigInfos result = validateConnectorConfig(connectorProps, doLog); + callback.onCompletion(null, result); + } catch (Throwable t) { + callback.onCompletion(t, null); + } + }); + } + + /** + * Build the {@link RestartPlan} that describes what should and should not be restarted given the restart request + * and the current status of the connector and task instances. + * + * @param request the restart request; may not be null + * @return the restart plan, or empty if this worker has no status for the connector named in the request and therefore the + * connector cannot be restarted + */ + public Optional buildRestartPlan(RestartRequest request) { + String connectorName = request.connectorName(); + ConnectorStatus connectorStatus = statusBackingStore.get(connectorName); + if (connectorStatus == null) { + return Optional.empty(); + } + + // If requested, mark the connector as restarting + AbstractStatus.State connectorState = request.shouldRestartConnector(connectorStatus) ? AbstractStatus.State.RESTARTING : connectorStatus.state(); + ConnectorStateInfo.ConnectorState connectorInfoState = new ConnectorStateInfo.ConnectorState( + connectorState.toString(), + connectorStatus.workerId(), + connectorStatus.trace() + ); + + // Collect the task states, If requested, mark the task as restarting + List taskStates = statusBackingStore.getAll(connectorName) + .stream() + .map(taskStatus -> { + AbstractStatus.State taskState = request.shouldRestartTask(taskStatus) ? AbstractStatus.State.RESTARTING : taskStatus.state(); + return new ConnectorStateInfo.TaskState( + taskStatus.id().task(), + taskState.toString(), + taskStatus.workerId(), + taskStatus.trace() + ); + }) + .collect(Collectors.toList()); + // Construct the response from the various states + Map conf = rawConfig(connectorName); + ConnectorStateInfo stateInfo = new ConnectorStateInfo( + connectorName, + connectorInfoState, + taskStates, + conf == null ? 
ConnectorType.UNKNOWN : connectorTypeForClass(conf.get(ConnectorConfig.CONNECTOR_CLASS_CONFIG)) + ); + return Optional.of(new RestartPlan(request, stateInfo)); + + } + + ConfigInfos validateConnectorConfig(Map connectorProps, boolean doLog) { + if (worker.configTransformer() != null) { + connectorProps = worker.configTransformer().transform(connectorProps); + } + String connType = connectorProps.get(ConnectorConfig.CONNECTOR_CLASS_CONFIG); + if (connType == null) + throw new BadRequestException("Connector config " + connectorProps + " contains no connector type"); + + Connector connector = getConnector(connType); + org.apache.kafka.connect.health.ConnectorType connectorType; + ClassLoader savedLoader = plugins().compareAndSwapLoaders(connector); + try { + ConfigDef baseConfigDef; + if (connector instanceof SourceConnector) { + baseConfigDef = SourceConnectorConfig.configDef(); + connectorType = org.apache.kafka.connect.health.ConnectorType.SOURCE; + } else { + baseConfigDef = SinkConnectorConfig.configDef(); + SinkConnectorConfig.validate(connectorProps); + connectorType = org.apache.kafka.connect.health.ConnectorType.SINK; + } + ConfigDef enrichedConfigDef = ConnectorConfig.enrich(plugins(), baseConfigDef, connectorProps, false); + Map validatedConnectorConfig = validateBasicConnectorConfig( + connector, + enrichedConfigDef, + connectorProps + ); + List configValues = new ArrayList<>(validatedConnectorConfig.values()); + Map configKeys = new LinkedHashMap<>(enrichedConfigDef.configKeys()); + Set allGroups = new LinkedHashSet<>(enrichedConfigDef.groups()); + + // do custom connector-specific validation + ConfigDef configDef = connector.config(); + if (null == configDef) { + throw new BadRequestException( + String.format( + "%s.config() must return a ConfigDef that is not null.", + connector.getClass().getName() + ) + ); + } + Config config = connector.validate(connectorProps); + if (null == config) { + throw new BadRequestException( + String.format( + "%s.validate() must return a Config that is not null.", + connector.getClass().getName() + ) + ); + } + configKeys.putAll(configDef.configKeys()); + allGroups.addAll(configDef.groups()); + configValues.addAll(config.configValues()); + ConfigInfos configInfos = generateResult(connType, configKeys, configValues, new ArrayList<>(allGroups)); + + AbstractConfig connectorConfig = new AbstractConfig(new ConfigDef(), connectorProps, doLog); + String connName = connectorProps.get(ConnectorConfig.NAME_CONFIG); + ConfigInfos producerConfigInfos = null; + ConfigInfos consumerConfigInfos = null; + ConfigInfos adminConfigInfos = null; + if (connectorType.equals(org.apache.kafka.connect.health.ConnectorType.SOURCE)) { + producerConfigInfos = validateClientOverrides(connName, + ConnectorConfig.CONNECTOR_CLIENT_PRODUCER_OVERRIDES_PREFIX, + connectorConfig, + ProducerConfig.configDef(), + connector.getClass(), + connectorType, + ConnectorClientConfigRequest.ClientType.PRODUCER, + connectorClientConfigOverridePolicy); + return mergeConfigInfos(connType, configInfos, producerConfigInfos); + } else { + consumerConfigInfos = validateClientOverrides(connName, + ConnectorConfig.CONNECTOR_CLIENT_CONSUMER_OVERRIDES_PREFIX, + connectorConfig, + ProducerConfig.configDef(), + connector.getClass(), + connectorType, + ConnectorClientConfigRequest.ClientType.CONSUMER, + connectorClientConfigOverridePolicy); + // check if topic for dead letter queue exists + String topic = connectorProps.get(SinkConnectorConfig.DLQ_TOPIC_NAME_CONFIG); + if (topic != null && 
!topic.isEmpty()) { + adminConfigInfos = validateClientOverrides(connName, + ConnectorConfig.CONNECTOR_CLIENT_ADMIN_OVERRIDES_PREFIX, + connectorConfig, + ProducerConfig.configDef(), + connector.getClass(), + connectorType, + ConnectorClientConfigRequest.ClientType.ADMIN, + connectorClientConfigOverridePolicy); + } + + } + return mergeConfigInfos(connType, configInfos, producerConfigInfos, consumerConfigInfos, adminConfigInfos); + } finally { + Plugins.compareAndSwapLoaders(savedLoader); + } + } + + private static ConfigInfos mergeConfigInfos(String connType, ConfigInfos... configInfosList) { + int errorCount = 0; + List configInfoList = new LinkedList<>(); + Set groups = new LinkedHashSet<>(); + for (ConfigInfos configInfos : configInfosList) { + if (configInfos != null) { + errorCount += configInfos.errorCount(); + configInfoList.addAll(configInfos.values()); + groups.addAll(configInfos.groups()); + } + } + return new ConfigInfos(connType, errorCount, new ArrayList<>(groups), configInfoList); + } + + private static ConfigInfos validateClientOverrides(String connName, + String prefix, + AbstractConfig connectorConfig, + ConfigDef configDef, + Class connectorClass, + org.apache.kafka.connect.health.ConnectorType connectorType, + ConnectorClientConfigRequest.ClientType clientType, + ConnectorClientConfigOverridePolicy connectorClientConfigOverridePolicy) { + int errorCount = 0; + List configInfoList = new LinkedList<>(); + Map configKeys = configDef.configKeys(); + Set groups = new LinkedHashSet<>(); + Map clientConfigs = new HashMap<>(); + for (Map.Entry rawClientConfig : connectorConfig.originalsWithPrefix(prefix).entrySet()) { + String configName = rawClientConfig.getKey(); + Object rawConfigValue = rawClientConfig.getValue(); + ConfigKey configKey = configDef.configKeys().get(configName); + Object parsedConfigValue = configKey != null + ? ConfigDef.parseType(configName, rawConfigValue, configKey.type) + : rawConfigValue; + clientConfigs.put(configName, parsedConfigValue); + } + ConnectorClientConfigRequest connectorClientConfigRequest = new ConnectorClientConfigRequest( + connName, connectorType, connectorClass, clientConfigs, clientType); + List configValues = connectorClientConfigOverridePolicy.validate(connectorClientConfigRequest); + if (configValues != null) { + for (ConfigValue validatedConfigValue : configValues) { + ConfigKey configKey = configKeys.get(validatedConfigValue.name()); + ConfigKeyInfo configKeyInfo = null; + if (configKey != null) { + if (configKey.group != null) { + groups.add(configKey.group); + } + configKeyInfo = convertConfigKey(configKey, prefix); + } + + ConfigValue configValue = new ConfigValue(prefix + validatedConfigValue.name(), validatedConfigValue.value(), + validatedConfigValue.recommendedValues(), validatedConfigValue.errorMessages()); + if (configValue.errorMessages().size() > 0) { + errorCount++; + } + ConfigValueInfo configValueInfo = convertConfigValue(configValue, configKey != null ? 
configKey.type : null); + configInfoList.add(new ConfigInfo(configKeyInfo, configValueInfo)); + } + } + return new ConfigInfos(connectorClass.toString(), errorCount, new ArrayList<>(groups), configInfoList); + } + + // public for testing + public static ConfigInfos generateResult(String connType, Map configKeys, List configValues, List groups) { + int errorCount = 0; + List configInfoList = new LinkedList<>(); + + Map configValueMap = new HashMap<>(); + for (ConfigValue configValue: configValues) { + String configName = configValue.name(); + configValueMap.put(configName, configValue); + if (!configKeys.containsKey(configName)) { + configInfoList.add(new ConfigInfo(null, convertConfigValue(configValue, null))); + errorCount += configValue.errorMessages().size(); + } + } + + for (Map.Entry entry : configKeys.entrySet()) { + String configName = entry.getKey(); + ConfigKeyInfo configKeyInfo = convertConfigKey(entry.getValue()); + Type type = entry.getValue().type; + ConfigValueInfo configValueInfo = null; + if (configValueMap.containsKey(configName)) { + ConfigValue configValue = configValueMap.get(configName); + configValueInfo = convertConfigValue(configValue, type); + errorCount += configValue.errorMessages().size(); + } + configInfoList.add(new ConfigInfo(configKeyInfo, configValueInfo)); + } + return new ConfigInfos(connType, errorCount, groups, configInfoList); + } + + private static ConfigKeyInfo convertConfigKey(ConfigKey configKey) { + return convertConfigKey(configKey, ""); + } + + private static ConfigKeyInfo convertConfigKey(ConfigKey configKey, String prefix) { + String name = prefix + configKey.name; + Type type = configKey.type; + String typeName = configKey.type.name(); + + boolean required = false; + String defaultValue; + if (ConfigDef.NO_DEFAULT_VALUE.equals(configKey.defaultValue)) { + defaultValue = null; + required = true; + } else { + defaultValue = ConfigDef.convertToString(configKey.defaultValue, type); + } + String importance = configKey.importance.name(); + String documentation = configKey.documentation; + String group = configKey.group; + int orderInGroup = configKey.orderInGroup; + String width = configKey.width.name(); + String displayName = configKey.displayName; + List dependents = configKey.dependents; + return new ConfigKeyInfo(name, typeName, required, defaultValue, importance, documentation, group, orderInGroup, width, displayName, dependents); + } + + private static ConfigValueInfo convertConfigValue(ConfigValue configValue, Type type) { + String value = ConfigDef.convertToString(configValue.value(), type); + List recommendedValues = new LinkedList<>(); + + if (type == Type.LIST) { + for (Object object: configValue.recommendedValues()) { + recommendedValues.add(ConfigDef.convertToString(object, Type.STRING)); + } + } else { + for (Object object : configValue.recommendedValues()) { + recommendedValues.add(ConfigDef.convertToString(object, type)); + } + } + return new ConfigValueInfo(configValue.name(), value, recommendedValues, configValue.errorMessages(), configValue.visible()); + } + + protected Connector getConnector(String connType) { + return tempConnectors.computeIfAbsent(connType, k -> plugins().newConnector(k)); + } + + /* + * Retrieves ConnectorType for the corresponding connector class + * @param connClass class of the connector + */ + public ConnectorType connectorTypeForClass(String connClass) { + return ConnectorType.from(getConnector(connClass).getClass()); + } + + /** + * Checks a given {@link ConfigInfos} for validation error messages and 
adds an exception + * to the given {@link Callback} if any were found. + * + * @param configInfos configInfos to read Errors from + * @param callback callback to add config error exception to + * @return true if errors were found in the config + */ + protected final boolean maybeAddConfigErrors( + ConfigInfos configInfos, + Callback> callback + ) { + int errors = configInfos.errorCount(); + boolean hasErrors = errors > 0; + if (hasErrors) { + StringBuilder messages = new StringBuilder(); + messages.append("Connector configuration is invalid and contains the following ") + .append(errors).append(" error(s):"); + for (ConfigInfo configInfo : configInfos.values()) { + for (String msg : configInfo.configValue().errors()) { + messages.append('\n').append(msg); + } + } + callback.onCompletion( + new BadRequestException( + messages.append( + "\nYou can also find the above list of errors at the endpoint `/connector-plugins/{connectorType}/config/validate`" + ).toString() + ), null + ); + } + return hasErrors; + } + + private String trace(Throwable t) { + ByteArrayOutputStream output = new ByteArrayOutputStream(); + try { + t.printStackTrace(new PrintStream(output, false, StandardCharsets.UTF_8.name())); + return output.toString(StandardCharsets.UTF_8.name()); + } catch (UnsupportedEncodingException e) { + return null; + } + } + + /* + * Performs a reverse transformation on a set of task configs, by replacing values with variable references. + */ + public static List> reverseTransform(String connName, + ClusterConfigState configState, + List> configs) { + + // Find the config keys in the raw connector config that have variable references + Map rawConnConfig = configState.rawConnectorConfig(connName); + Set connKeysWithVariableValues = keysWithVariableValues(rawConnConfig, ConfigTransformer.DEFAULT_PATTERN); + + List> result = new ArrayList<>(); + for (Map config : configs) { + Map newConfig = new HashMap<>(config); + for (String key : connKeysWithVariableValues) { + if (newConfig.containsKey(key)) { + newConfig.put(key, rawConnConfig.get(key)); + } + } + result.add(newConfig); + } + return result; + } + + // Visible for testing + static Set keysWithVariableValues(Map rawConfig, Pattern pattern) { + Set keys = new HashSet<>(); + for (Map.Entry config : rawConfig.entrySet()) { + if (config.getValue() != null) { + Matcher matcher = pattern.matcher(config.getValue()); + if (matcher.find()) { + keys.add(config.getKey()); + } + } + } + return keys; + } + +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/AbstractStatus.java b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/AbstractStatus.java new file mode 100644 index 0000000..c5e0702 --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/AbstractStatus.java @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
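For illustration, a minimal sketch of driving the callback-based validation API defined above. The connector name, class alias, and topic are placeholders, `herder` stands for any concrete `AbstractHerder` (standalone or distributed), and the import locations are assumed to match the Connect runtime module; treat this as a sketch, not canonical usage.

```java
import java.util.HashMap;
import java.util.Map;

import org.apache.kafka.connect.runtime.AbstractHerder;
import org.apache.kafka.connect.runtime.ConnectorConfig;
import org.apache.kafka.connect.runtime.rest.entities.ConfigInfos;
import org.apache.kafka.connect.util.Callback;

public class ValidationSketch {

    // 'herder' is assumed to be a fully constructed AbstractHerder subclass.
    static void validate(AbstractHerder herder) {
        Map<String, String> props = new HashMap<>();
        props.put(ConnectorConfig.NAME_CONFIG, "example-sink");
        props.put(ConnectorConfig.CONNECTOR_CLASS_CONFIG, "FileStreamSink");
        props.put("topics", "example-topic"); // sink-specific, not a ConnectorConfig option

        // Validation is submitted to the herder's connectorExecutor; the callback
        // receives either the aggregated ConfigInfos or the failure cause.
        Callback<ConfigInfos> callback = (error, infos) -> {
            if (error != null) {
                System.err.println("Validation could not be performed: " + error);
            } else {
                System.out.println("Validation completed with " + infos.errorCount() + " error(s)");
            }
        };
        herder.validateConnectorConfig(props, callback);
    }
}
```

Callers that need synchronous behavior simply block on the callback; `maybeAddConfigErrors` above shows how a non-zero `errorCount()` is then surfaced as a `BadRequestException`.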
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.runtime; + +import java.util.Objects; + +public abstract class AbstractStatus { + + public enum State { + UNASSIGNED, + RUNNING, + PAUSED, + FAILED, + DESTROYED, + RESTARTING, + } + + private final T id; + private final State state; + private final String trace; + private final String workerId; + private final int generation; + + public AbstractStatus(T id, + State state, + String workerId, + int generation, + String trace) { + this.id = id; + this.state = state; + this.workerId = workerId; + this.generation = generation; + this.trace = trace; + } + + public T id() { + return id; + } + + public State state() { + return state; + } + + public String trace() { + return trace; + } + + public String workerId() { + return workerId; + } + + public int generation() { + return generation; + } + + @Override + public String toString() { + return "Status{" + + "id=" + id + + ", state=" + state + + ", workerId='" + workerId + '\'' + + ", generation=" + generation + + '}'; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + AbstractStatus that = (AbstractStatus) o; + + return generation == that.generation + && Objects.equals(id, that.id) + && state == that.state + && Objects.equals(trace, that.trace) + && Objects.equals(workerId, that.workerId); + } + + @Override + public int hashCode() { + int result = id != null ? id.hashCode() : 0; + result = 31 * result + (state != null ? state.hashCode() : 0); + result = 31 * result + (trace != null ? trace.hashCode() : 0); + result = 31 * result + (workerId != null ? workerId.hashCode() : 0); + result = 31 * result + generation; + return result; + } +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/CloseableConnectorContext.java b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/CloseableConnectorContext.java new file mode 100644 index 0000000..7a09a90 --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/CloseableConnectorContext.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.runtime; + +import org.apache.kafka.connect.connector.ConnectorContext; +import org.apache.kafka.connect.errors.ConnectException; + +import java.io.Closeable; + +public interface CloseableConnectorContext extends ConnectorContext, Closeable { + + /** + * Close this connector context, causing all future calls to it to throw {@link ConnectException}. 
+ * This is useful to prevent zombie connector threads from making such calls after their connector + * instance should be shut down. + */ + void close(); +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/Connect.java b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/Connect.java new file mode 100644 index 0000000..80eef03 --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/Connect.java @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.runtime; + +import org.apache.kafka.common.utils.Exit; +import org.apache.kafka.connect.runtime.rest.RestServer; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.net.URI; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.atomic.AtomicBoolean; + +/** + * This class ties together all the components of a Kafka Connect process (herder, worker, + * storage, command interface), managing their lifecycle. 
+ */ +public class Connect { + private static final Logger log = LoggerFactory.getLogger(Connect.class); + + private final Herder herder; + private final RestServer rest; + private final CountDownLatch startLatch = new CountDownLatch(1); + private final CountDownLatch stopLatch = new CountDownLatch(1); + private final AtomicBoolean shutdown = new AtomicBoolean(false); + private final ShutdownHook shutdownHook; + + public Connect(Herder herder, RestServer rest) { + log.debug("Kafka Connect instance created"); + this.herder = herder; + this.rest = rest; + shutdownHook = new ShutdownHook(); + } + + public void start() { + try { + log.info("Kafka Connect starting"); + Exit.addShutdownHook("connect-shutdown-hook", shutdownHook); + + herder.start(); + rest.initializeResources(herder); + + log.info("Kafka Connect started"); + } finally { + startLatch.countDown(); + } + } + + public void stop() { + try { + boolean wasShuttingDown = shutdown.getAndSet(true); + if (!wasShuttingDown) { + log.info("Kafka Connect stopping"); + + rest.stop(); + herder.stop(); + + log.info("Kafka Connect stopped"); + } + } finally { + stopLatch.countDown(); + } + } + + public void awaitStop() { + try { + stopLatch.await(); + } catch (InterruptedException e) { + log.error("Interrupted waiting for Kafka Connect to shutdown"); + } + } + + public boolean isRunning() { + return herder.isRunning(); + } + + // Visible for testing + public URI restUrl() { + return rest.serverUrl(); + } + + public URI adminUrl() { + return rest.adminUrl(); + } + + private class ShutdownHook extends Thread { + @Override + public void run() { + try { + startLatch.await(); + Connect.this.stop(); + } catch (InterruptedException e) { + log.error("Interrupted in shutdown hook while waiting for Kafka Connect startup to finish"); + } + } + } +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/ConnectMetrics.java b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/ConnectMetrics.java new file mode 100644 index 0000000..2871bbe --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/ConnectMetrics.java @@ -0,0 +1,446 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
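The `Connect` class above is only a thin lifecycle wrapper, so a runnable sketch is short. It assumes a `Herder` and an already-initialized `RestServer` have been built elsewhere (as the standalone and distributed entry points normally do); the error handling is illustrative rather than the canonical bootstrap code.

```java
import org.apache.kafka.connect.runtime.Connect;
import org.apache.kafka.connect.runtime.Herder;
import org.apache.kafka.connect.runtime.rest.RestServer;

public class LifecycleSketch {

    static void run(Herder herder, RestServer rest) {
        Connect connect = new Connect(herder, rest);
        try {
            // Registers the shutdown hook, starts the herder and exposes the REST resources.
            connect.start();
        } catch (Throwable t) {
            // A failed startup should still release the herder and REST server.
            connect.stop();
            return;
        }
        // Blocks until stop() is invoked, either directly or through the shutdown hook.
        connect.awaitStop();
    }
}
```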
+ */ +package org.apache.kafka.connect.runtime; + +import org.apache.kafka.clients.CommonClientConfigs; +import org.apache.kafka.common.MetricName; +import org.apache.kafka.common.MetricNameTemplate; +import org.apache.kafka.common.metrics.Gauge; +import org.apache.kafka.common.metrics.JmxReporter; +import org.apache.kafka.common.metrics.KafkaMetricsContext; +import org.apache.kafka.common.metrics.MetricConfig; +import org.apache.kafka.common.metrics.Metrics; +import org.apache.kafka.common.metrics.MetricsContext; +import org.apache.kafka.common.metrics.MetricsReporter; +import org.apache.kafka.common.metrics.Sensor; +import org.apache.kafka.common.metrics.internals.MetricsUtils; +import org.apache.kafka.common.utils.AppInfoParser; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.connect.runtime.distributed.DistributedConfig; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.TimeUnit; + +/** + * The Connect metrics with JMX reporter. + */ +public class ConnectMetrics { + + public static final String JMX_PREFIX = "kafka.connect"; + + private static final Logger LOG = LoggerFactory.getLogger(ConnectMetrics.class); + + private final Metrics metrics; + private final Time time; + private final String workerId; + private final ConcurrentMap groupsByName = new ConcurrentHashMap<>(); + private final ConnectMetricsRegistry registry = new ConnectMetricsRegistry(); + + /** + * Create an instance. + * + * @param workerId the worker identifier; may not be null + * @param config the worker configuration; may not be null + * @param time the time; may not be null + * @param clusterId the Kafka cluster ID + */ + public ConnectMetrics(String workerId, WorkerConfig config, Time time, String clusterId) { + this.workerId = workerId; + this.time = time; + + int numSamples = config.getInt(CommonClientConfigs.METRICS_NUM_SAMPLES_CONFIG); + long sampleWindowMs = config.getLong(CommonClientConfigs.METRICS_SAMPLE_WINDOW_MS_CONFIG); + String metricsRecordingLevel = config.getString(CommonClientConfigs.METRICS_RECORDING_LEVEL_CONFIG); + List reporters = config.getConfiguredInstances(CommonClientConfigs.METRIC_REPORTER_CLASSES_CONFIG, MetricsReporter.class); + + MetricConfig metricConfig = new MetricConfig().samples(numSamples) + .timeWindow(sampleWindowMs, TimeUnit.MILLISECONDS).recordLevel( + Sensor.RecordingLevel.forName(metricsRecordingLevel)); + JmxReporter jmxReporter = new JmxReporter(); + jmxReporter.configure(config.originals()); + reporters.add(jmxReporter); + + Map contextLabels = new HashMap<>(); + contextLabels.putAll(config.originalsWithPrefix(CommonClientConfigs.METRICS_CONTEXT_PREFIX)); + contextLabels.put(WorkerConfig.CONNECT_KAFKA_CLUSTER_ID, clusterId); + Object groupId = config.originals().get(DistributedConfig.GROUP_ID_CONFIG); + if (groupId != null) { + contextLabels.put(WorkerConfig.CONNECT_GROUP_ID, groupId); + } + MetricsContext metricsContext = new KafkaMetricsContext(JMX_PREFIX, contextLabels); + this.metrics = new Metrics(metricConfig, reporters, time, metricsContext); + + LOG.debug("Registering Connect metrics with JMX for worker '{}'", workerId); + AppInfoParser.registerAppInfo(JMX_PREFIX, workerId, metrics, time.milliseconds()); 
+ } + + /** + * Get the worker identifier. + * + * @return the worker ID; never null + */ + public String workerId() { + return workerId; + } + + /** + * Get the {@link Metrics Kafka Metrics} that are managed by this object and that should be used to + * add sensors and individual metrics. + * + * @return the Kafka Metrics instance; never null + */ + public Metrics metrics() { + return metrics; + } + + /** + * Get the registry of metric names. + * + * @return the registry for the Connect metrics; never null + */ + public ConnectMetricsRegistry registry() { + return registry; + } + + /** + * Get or create a {@link MetricGroup} with the specified group name and the given tags. + * Each group is uniquely identified by the name and tags. + * + * @param groupName the name of the metric group; may not be null + * @param tagKeyValues pairs of tag name and values + * @return the {@link MetricGroup} that can be used to create metrics; never null + * @throws IllegalArgumentException if the group name is not valid + */ + public MetricGroup group(String groupName, String... tagKeyValues) { + MetricGroupId groupId = groupId(groupName, tagKeyValues); + MetricGroup group = groupsByName.get(groupId); + if (group == null) { + group = new MetricGroup(groupId); + MetricGroup previous = groupsByName.putIfAbsent(groupId, group); + if (previous != null) + group = previous; + } + return group; + } + + protected MetricGroupId groupId(String groupName, String... tagKeyValues) { + Map tags = MetricsUtils.getTags(tagKeyValues); + return new MetricGroupId(groupName, tags); + } + + /** + * Get the time. + * + * @return the time; never null + */ + public Time time() { + return time; + } + + /** + * Stop and unregister the metrics from any reporters. + */ + public void stop() { + metrics.close(); + LOG.debug("Unregistering Connect metrics with JMX for worker '{}'", workerId); + AppInfoParser.unregisterAppInfo(JMX_PREFIX, workerId, metrics); + } + + public static class MetricGroupId { + private final String groupName; + private final Map tags; + private final int hc; + private final String str; + + public MetricGroupId(String groupName, Map tags) { + Objects.requireNonNull(groupName); + Objects.requireNonNull(tags); + this.groupName = groupName; + this.tags = Collections.unmodifiableMap(new LinkedHashMap<>(tags)); + this.hc = Objects.hash(this.groupName, this.tags); + StringBuilder sb = new StringBuilder(this.groupName); + for (Map.Entry entry : this.tags.entrySet()) { + sb.append(";").append(entry.getKey()).append('=').append(entry.getValue()); + } + this.str = sb.toString(); + } + + /** + * Get the group name. + * + * @return the group name; never null + */ + public String groupName() { + return groupName; + } + + /** + * Get the immutable map of tag names and values. + * + * @return the tags; never null + */ + public Map tags() { + return tags; + } + + /** + * Determine if the supplied metric name is part of this group identifier. 
+ * + * @param metricName the metric name + * @return true if the metric name's group and tags match this group identifier, or false otherwise + */ + public boolean includes(MetricName metricName) { + return metricName != null && groupName.equals(metricName.group()) && tags.equals(metricName.tags()); + } + + @Override + public int hashCode() { + return hc; + } + + @Override + public boolean equals(Object obj) { + if (obj == this) + return true; + if (obj instanceof MetricGroupId) { + MetricGroupId that = (MetricGroupId) obj; + return this.groupName.equals(that.groupName) && this.tags.equals(that.tags); + } + return false; + } + + @Override + public String toString() { + return str; + } + } + + /** + * A group of metrics. Each group maps to a JMX MBean and each metric maps to an MBean attribute. + *
            + * Sensors should be added via the {@code sensor} methods on this class, rather than directly through + * the {@link Metrics} class, so that the sensor names are made to be unique (based on the group name) + * and so the sensors are removed when this group is {@link #close() closed}. + */ + public class MetricGroup implements AutoCloseable { + private final MetricGroupId groupId; + private final Set sensorNames = new HashSet<>(); + private final String sensorPrefix; + + /** + * Create a group of Connect metrics. + * + * @param groupId the identifier of the group; may not be null and must be valid + */ + protected MetricGroup(MetricGroupId groupId) { + Objects.requireNonNull(groupId); + this.groupId = groupId; + sensorPrefix = "connect-sensor-group: " + groupId.toString() + ";"; + } + + /** + * Get the group identifier. + * + * @return the group identifier; never null + */ + public MetricGroupId groupId() { + return groupId; + } + + /** + * Create the name of a metric that belongs to this group and has the group's tags. + * + * @param template the name template for the metric; may not be null + * @return the metric name; never null + * @throws IllegalArgumentException if the name is not valid + */ + public MetricName metricName(MetricNameTemplate template) { + return metrics.metricInstance(template, groupId.tags()); + } + + // for testing only + MetricName metricName(String name) { + return metrics.metricName(name, groupId.groupName(), "", groupId.tags()); + } + + /** + * The {@link Metrics} that this group belongs to. + *
            + * Do not use this to add {@link Sensor Sensors}, since they will not be removed when this group is + * {@link #close() closed}. Metrics can be added directly, as long as the metric names are obtained from + * this group via the {@link #metricName(MetricNameTemplate)} method. + * + * @return the metrics; never null + */ + public Metrics metrics() { + return metrics; + } + + /** + * The tags of this group. + * + * @return the unmodifiable tags; never null but may be empty + */ + Map tags() { + return groupId.tags(); + } + + /** + * Add to this group an indicator metric with a function that returns the current value. + * + * @param nameTemplate the name template for the metric; may not be null + * @param supplier the function used to determine the literal value of the metric; may not be null + * @throws IllegalArgumentException if the name is not valid + */ + public void addValueMetric(MetricNameTemplate nameTemplate, final LiteralSupplier supplier) { + MetricName metricName = metricName(nameTemplate); + if (metrics().metric(metricName) == null) { + metrics().addMetric(metricName, (Gauge) (config, now) -> supplier.metricValue(now)); + } + } + + /** + * Add to this group an indicator metric that always returns the specified value. + * + * @param nameTemplate the name template for the metric; may not be null + * @param value the value; may not be null + * @throws IllegalArgumentException if the name is not valid + */ + public void addImmutableValueMetric(MetricNameTemplate nameTemplate, final T value) { + MetricName metricName = metricName(nameTemplate); + if (metrics().metric(metricName) == null) { + metrics().addMetric(metricName, (Gauge) (config, now) -> value); + } + } + + /** + * Get or create a sensor with the given unique name and no parent sensors. This uses + * a default recording level of INFO. + * + * @param name The sensor name + * @return The sensor + */ + public Sensor sensor(String name) { + return sensor(name, null, Sensor.RecordingLevel.INFO); + } + + /** + * Get or create a sensor with the given unique name and no parent sensors. This uses + * a default recording level of INFO. + * + * @param name The sensor name + * @return The sensor + */ + public Sensor sensor(String name, Sensor... parents) { + return sensor(name, null, Sensor.RecordingLevel.INFO, parents); + } + + /** + * Get or create a sensor with the given unique name and zero or more parent sensors. All parent sensors will + * receive every value recorded with this sensor. + * + * @param name The name of the sensor + * @param recordingLevel The recording level. + * @param parents The parent sensors + * @return The sensor that is created + */ + public Sensor sensor(String name, Sensor.RecordingLevel recordingLevel, Sensor... parents) { + return sensor(name, null, recordingLevel, parents); + } + + /** + * Get or create a sensor with the given unique name and zero or more parent sensors. All parent sensors will + * receive every value recorded with this sensor. + * + * @param name The name of the sensor + * @param config A default configuration to use for this sensor for metrics that don't have their own config + * @param parents The parent sensors + * @return The sensor that is created + */ + public Sensor sensor(String name, MetricConfig config, Sensor... parents) { + return sensor(name, config, Sensor.RecordingLevel.INFO, parents); + } + + /** + * Get or create a sensor with the given unique name and zero or more parent sensors. All parent sensors will + * receive every value recorded with this sensor. 
+ * + * @param name The name of the sensor + * @param config A default configuration to use for this sensor for metrics that don't have their own config + * @param recordingLevel The recording level. + * @param parents The parent sensors + * @return The sensor that is created + */ + public synchronized Sensor sensor(String name, MetricConfig config, Sensor.RecordingLevel recordingLevel, Sensor... parents) { + // We need to make sure that all sensor names are unique across all groups, so use the sensor prefix + Sensor result = metrics.sensor(sensorPrefix + name, config, Long.MAX_VALUE, recordingLevel, parents); + if (result != null) + sensorNames.add(result.name()); + return result; + } + + /** + * Remove all sensors and metrics associated with this group. + */ + public synchronized void close() { + for (String sensorName : sensorNames) { + metrics.removeSensor(sensorName); + } + sensorNames.clear(); + for (MetricName metricName : new HashSet<>(metrics.metrics().keySet())) { + if (groupId.includes(metricName)) { + metrics.removeMetric(metricName); + } + } + } + } + + /** + * A simple functional interface that returns a literal value. + */ + public interface LiteralSupplier { + + /** + * Return the literal value for the metric. + * + * @param now the current time in milliseconds + * @return the literal metric value; may not be null + */ + T metricValue(long now); + } + + /** + * Utility to generate the documentation for the Connect metrics. + * + * @param args the arguments + */ + public static void main(String[] args) { + ConnectMetricsRegistry metrics = new ConnectMetricsRegistry(); + System.out.println(Metrics.toHtmlTable(JMX_PREFIX, metrics.getAllTemplates())); + } +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/ConnectMetricsRegistry.java b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/ConnectMetricsRegistry.java new file mode 100644 index 0000000..cf56745 --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/ConnectMetricsRegistry.java @@ -0,0 +1,417 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
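Since the metric-group API above is the piece most connector and task code interacts with, here is a minimal sketch of registering a couple of connector-level gauges. The worker id, connector name, and class name are placeholders, and the `WorkerConfig` and Kafka cluster id are assumed to come from a running worker; this illustrates the API shape, not production wiring.

```java
import org.apache.kafka.common.utils.Time;
import org.apache.kafka.connect.runtime.ConnectMetrics;
import org.apache.kafka.connect.runtime.ConnectMetricsRegistry;
import org.apache.kafka.connect.runtime.WorkerConfig;

public class MetricGroupSketch {

    static void register(WorkerConfig config, String kafkaClusterId) {
        ConnectMetrics connectMetrics =
                new ConnectMetrics("worker-1", config, Time.SYSTEM, kafkaClusterId);
        ConnectMetricsRegistry registry = connectMetrics.registry();

        // One group per connector; group name and tag key come from the registry
        // so the resulting MBean matches the documented templates.
        ConnectMetrics.MetricGroup group = connectMetrics.group(
                registry.connectorGroupName(),
                registry.connectorTagName(), "example-connector");

        // A gauge backed by a LiteralSupplier lambda, re-evaluated on each read.
        group.addValueMetric(registry.connectorStatus, now -> "running");

        // A gauge whose value never changes.
        group.addImmutableValueMetric(registry.connectorClass, "com.example.ExampleConnector");

        // Removing the group also removes its metrics and any sensors created via group.sensor(...).
        group.close();
    }
}
```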
+ */ +package org.apache.kafka.connect.runtime; + +import org.apache.kafka.common.MetricNameTemplate; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; + +public class ConnectMetricsRegistry { + + public static final String CONNECTOR_TAG_NAME = "connector"; + public static final String TASK_TAG_NAME = "task"; + public static final String CONNECTOR_GROUP_NAME = "connector-metrics"; + public static final String TASK_GROUP_NAME = "connector-task-metrics"; + public static final String SOURCE_TASK_GROUP_NAME = "source-task-metrics"; + public static final String SINK_TASK_GROUP_NAME = "sink-task-metrics"; + public static final String WORKER_GROUP_NAME = "connect-worker-metrics"; + public static final String WORKER_REBALANCE_GROUP_NAME = "connect-worker-rebalance-metrics"; + public static final String TASK_ERROR_HANDLING_GROUP_NAME = "task-error-metrics"; + + private final List allTemplates = new ArrayList<>(); + public final MetricNameTemplate connectorStatus; + public final MetricNameTemplate connectorType; + public final MetricNameTemplate connectorClass; + public final MetricNameTemplate connectorVersion; + public final MetricNameTemplate connectorTotalTaskCount; + public final MetricNameTemplate connectorRunningTaskCount; + public final MetricNameTemplate connectorPausedTaskCount; + public final MetricNameTemplate connectorFailedTaskCount; + public final MetricNameTemplate connectorUnassignedTaskCount; + public final MetricNameTemplate connectorDestroyedTaskCount; + public final MetricNameTemplate connectorRestartingTaskCount; + public final MetricNameTemplate taskStatus; + public final MetricNameTemplate taskRunningRatio; + public final MetricNameTemplate taskPauseRatio; + public final MetricNameTemplate taskCommitTimeMax; + public final MetricNameTemplate taskCommitTimeAvg; + public final MetricNameTemplate taskBatchSizeMax; + public final MetricNameTemplate taskBatchSizeAvg; + public final MetricNameTemplate taskCommitFailurePercentage; + public final MetricNameTemplate taskCommitSuccessPercentage; + public final MetricNameTemplate sourceRecordPollRate; + public final MetricNameTemplate sourceRecordPollTotal; + public final MetricNameTemplate sourceRecordWriteRate; + public final MetricNameTemplate sourceRecordWriteTotal; + public final MetricNameTemplate sourceRecordPollBatchTimeMax; + public final MetricNameTemplate sourceRecordPollBatchTimeAvg; + public final MetricNameTemplate sourceRecordActiveCount; + public final MetricNameTemplate sourceRecordActiveCountMax; + public final MetricNameTemplate sourceRecordActiveCountAvg; + public final MetricNameTemplate sinkRecordReadRate; + public final MetricNameTemplate sinkRecordReadTotal; + public final MetricNameTemplate sinkRecordSendRate; + public final MetricNameTemplate sinkRecordSendTotal; + public final MetricNameTemplate sinkRecordLagMax; + public final MetricNameTemplate sinkRecordPartitionCount; + public final MetricNameTemplate sinkRecordOffsetCommitSeqNum; + public final MetricNameTemplate sinkRecordOffsetCommitCompletionRate; + public final MetricNameTemplate sinkRecordOffsetCommitCompletionTotal; + public final MetricNameTemplate sinkRecordOffsetCommitSkipRate; + public final MetricNameTemplate sinkRecordOffsetCommitSkipTotal; + public final MetricNameTemplate sinkRecordPutBatchTimeMax; + public final MetricNameTemplate sinkRecordPutBatchTimeAvg; + public final MetricNameTemplate 
sinkRecordActiveCount; + public final MetricNameTemplate sinkRecordActiveCountMax; + public final MetricNameTemplate sinkRecordActiveCountAvg; + public final MetricNameTemplate connectorCount; + public final MetricNameTemplate taskCount; + public final MetricNameTemplate connectorStartupAttemptsTotal; + public final MetricNameTemplate connectorStartupSuccessTotal; + public final MetricNameTemplate connectorStartupSuccessPercentage; + public final MetricNameTemplate connectorStartupFailureTotal; + public final MetricNameTemplate connectorStartupFailurePercentage; + public final MetricNameTemplate taskStartupAttemptsTotal; + public final MetricNameTemplate taskStartupSuccessTotal; + public final MetricNameTemplate taskStartupSuccessPercentage; + public final MetricNameTemplate taskStartupFailureTotal; + public final MetricNameTemplate taskStartupFailurePercentage; + public final MetricNameTemplate connectProtocol; + public final MetricNameTemplate leaderName; + public final MetricNameTemplate epoch; + public final MetricNameTemplate rebalanceCompletedTotal; + public final MetricNameTemplate rebalanceMode; + public final MetricNameTemplate rebalanceTimeMax; + public final MetricNameTemplate rebalanceTimeAvg; + public final MetricNameTemplate rebalanceTimeSinceLast; + public final MetricNameTemplate recordProcessingFailures; + public final MetricNameTemplate recordProcessingErrors; + public final MetricNameTemplate recordsSkipped; + public final MetricNameTemplate retries; + public final MetricNameTemplate errorsLogged; + public final MetricNameTemplate dlqProduceRequests; + public final MetricNameTemplate dlqProduceFailures; + public final MetricNameTemplate lastErrorTimestamp; + + public Map connectorStatusMetrics; + + public ConnectMetricsRegistry() { + this(new LinkedHashSet<>()); + } + + public ConnectMetricsRegistry(Set tags) { + /***** Connector level *****/ + Set connectorTags = new LinkedHashSet<>(tags); + connectorTags.add(CONNECTOR_TAG_NAME); + + connectorStatus = createTemplate("status", CONNECTOR_GROUP_NAME, + "The status of the connector. One of 'unassigned', 'running', 'paused', 'failed', or " + + "'destroyed'.", + connectorTags); + connectorType = createTemplate("connector-type", CONNECTOR_GROUP_NAME, "The type of the connector. One of 'source' or 'sink'.", + connectorTags); + connectorClass = createTemplate("connector-class", CONNECTOR_GROUP_NAME, "The name of the connector class.", connectorTags); + connectorVersion = createTemplate("connector-version", CONNECTOR_GROUP_NAME, + "The version of the connector class, as reported by the connector.", connectorTags); + + /***** Worker task level *****/ + Set workerTaskTags = new LinkedHashSet<>(tags); + workerTaskTags.add(CONNECTOR_TAG_NAME); + workerTaskTags.add(TASK_TAG_NAME); + + taskStatus = createTemplate("status", TASK_GROUP_NAME, + "The status of the connector task. 
One of 'unassigned', 'running', 'paused', 'failed', or " + + "'destroyed'.", + workerTaskTags); + taskRunningRatio = createTemplate("running-ratio", TASK_GROUP_NAME, + "The fraction of time this task has spent in the running state.", workerTaskTags); + taskPauseRatio = createTemplate("pause-ratio", TASK_GROUP_NAME, "The fraction of time this task has spent in the pause state.", + workerTaskTags); + taskCommitTimeMax = createTemplate("offset-commit-max-time-ms", TASK_GROUP_NAME, + "The maximum time in milliseconds taken by this task to commit offsets.", workerTaskTags); + taskCommitTimeAvg = createTemplate("offset-commit-avg-time-ms", TASK_GROUP_NAME, + "The average time in milliseconds taken by this task to commit offsets.", workerTaskTags); + taskBatchSizeMax = createTemplate("batch-size-max", TASK_GROUP_NAME, "The maximum size of the batches processed by the connector.", + workerTaskTags); + taskBatchSizeAvg = createTemplate("batch-size-avg", TASK_GROUP_NAME, "The average size of the batches processed by the connector.", + workerTaskTags); + taskCommitFailurePercentage = createTemplate("offset-commit-failure-percentage", TASK_GROUP_NAME, + "The average percentage of this task's offset commit attempts that failed.", + workerTaskTags); + taskCommitSuccessPercentage = createTemplate("offset-commit-success-percentage", TASK_GROUP_NAME, + "The average percentage of this task's offset commit attempts that succeeded.", + workerTaskTags); + + /***** Source worker task level *****/ + Set sourceTaskTags = new LinkedHashSet<>(tags); + sourceTaskTags.add(CONNECTOR_TAG_NAME); + sourceTaskTags.add(TASK_TAG_NAME); + + sourceRecordPollRate = createTemplate("source-record-poll-rate", SOURCE_TASK_GROUP_NAME, + "The average per-second number of records produced/polled (before transformation) by " + + "this task belonging to the named source connector in this worker.", + sourceTaskTags); + sourceRecordPollTotal = createTemplate("source-record-poll-total", SOURCE_TASK_GROUP_NAME, + "The total number of records produced/polled (before transformation) by this task " + + "belonging to the named source connector in this worker.", + sourceTaskTags); + sourceRecordWriteRate = createTemplate("source-record-write-rate", SOURCE_TASK_GROUP_NAME, + "The average per-second number of records output from the transformations and written" + + " to Kafka for this task belonging to the named source connector in this worker. 
This" + + " is after transformations are applied and excludes any records filtered out by the " + + "transformations.", + sourceTaskTags); + sourceRecordWriteTotal = createTemplate("source-record-write-total", SOURCE_TASK_GROUP_NAME, + "The number of records output from the transformations and written to Kafka for this" + + " task belonging to the named source connector in this worker, since the task was " + + "last restarted.", + sourceTaskTags); + sourceRecordPollBatchTimeMax = createTemplate("poll-batch-max-time-ms", SOURCE_TASK_GROUP_NAME, + "The maximum time in milliseconds taken by this task to poll for a batch of " + + "source records.", + sourceTaskTags); + sourceRecordPollBatchTimeAvg = createTemplate("poll-batch-avg-time-ms", SOURCE_TASK_GROUP_NAME, + "The average time in milliseconds taken by this task to poll for a batch of " + + "source records.", + sourceTaskTags); + sourceRecordActiveCount = createTemplate("source-record-active-count", SOURCE_TASK_GROUP_NAME, + "The number of records that have been produced by this task but not yet completely " + + "written to Kafka.", + sourceTaskTags); + sourceRecordActiveCountMax = createTemplate("source-record-active-count-max", SOURCE_TASK_GROUP_NAME, + "The maximum number of records that have been produced by this task but not yet " + + "completely written to Kafka.", + sourceTaskTags); + sourceRecordActiveCountAvg = createTemplate("source-record-active-count-avg", SOURCE_TASK_GROUP_NAME, + "The average number of records that have been produced by this task but not yet " + + "completely written to Kafka.", + sourceTaskTags); + + /***** Sink worker task level *****/ + Set sinkTaskTags = new LinkedHashSet<>(tags); + sinkTaskTags.add(CONNECTOR_TAG_NAME); + sinkTaskTags.add(TASK_TAG_NAME); + + sinkRecordReadRate = createTemplate("sink-record-read-rate", SINK_TASK_GROUP_NAME, + "The average per-second number of records read from Kafka for this task belonging to the" + + " named sink connector in this worker. This is before transformations are applied.", + sinkTaskTags); + sinkRecordReadTotal = createTemplate("sink-record-read-total", SINK_TASK_GROUP_NAME, + "The total number of records read from Kafka by this task belonging to the named sink " + + "connector in this worker, since the task was last restarted.", + sinkTaskTags); + sinkRecordSendRate = createTemplate("sink-record-send-rate", SINK_TASK_GROUP_NAME, + "The average per-second number of records output from the transformations and sent/put " + + "to this task belonging to the named sink connector in this worker. 
This is after " + + "transformations are applied and excludes any records filtered out by the " + + "transformations.", + sinkTaskTags); + sinkRecordSendTotal = createTemplate("sink-record-send-total", SINK_TASK_GROUP_NAME, + "The total number of records output from the transformations and sent/put to this task " + + "belonging to the named sink connector in this worker, since the task was last " + + "restarted.", + sinkTaskTags); + sinkRecordLagMax = createTemplate("sink-record-lag-max", SINK_TASK_GROUP_NAME, + "The maximum lag in terms of number of records that the sink task is behind the consumer's " + + "position for any topic partitions.", + sinkTaskTags); + sinkRecordPartitionCount = createTemplate("partition-count", SINK_TASK_GROUP_NAME, + "The number of topic partitions assigned to this task belonging to the named sink " + + "connector in this worker.", + sinkTaskTags); + sinkRecordOffsetCommitSeqNum = createTemplate("offset-commit-seq-no", SINK_TASK_GROUP_NAME, + "The current sequence number for offset commits.", sinkTaskTags); + sinkRecordOffsetCommitCompletionRate = createTemplate("offset-commit-completion-rate", SINK_TASK_GROUP_NAME, + "The average per-second number of offset commit completions that were " + + "completed successfully.", + sinkTaskTags); + sinkRecordOffsetCommitCompletionTotal = createTemplate("offset-commit-completion-total", SINK_TASK_GROUP_NAME, + "The total number of offset commit completions that were completed " + + "successfully.", + sinkTaskTags); + sinkRecordOffsetCommitSkipRate = createTemplate("offset-commit-skip-rate", SINK_TASK_GROUP_NAME, + "The average per-second number of offset commit completions that were " + + "received too late and skipped/ignored.", + sinkTaskTags); + sinkRecordOffsetCommitSkipTotal = createTemplate("offset-commit-skip-total", SINK_TASK_GROUP_NAME, + "The total number of offset commit completions that were received too late " + + "and skipped/ignored.", + sinkTaskTags); + sinkRecordPutBatchTimeMax = createTemplate("put-batch-max-time-ms", SINK_TASK_GROUP_NAME, + "The maximum time taken by this task to put a batch of sinks records.", sinkTaskTags); + sinkRecordPutBatchTimeAvg = createTemplate("put-batch-avg-time-ms", SINK_TASK_GROUP_NAME, + "The average time taken by this task to put a batch of sinks records.", sinkTaskTags); + sinkRecordActiveCount = createTemplate("sink-record-active-count", SINK_TASK_GROUP_NAME, + "The number of records that have been read from Kafka but not yet completely " + + "committed/flushed/acknowledged by the sink task.", + sinkTaskTags); + sinkRecordActiveCountMax = createTemplate("sink-record-active-count-max", SINK_TASK_GROUP_NAME, + "The maximum number of records that have been read from Kafka but not yet completely " + + "committed/flushed/acknowledged by the sink task.", + sinkTaskTags); + sinkRecordActiveCountAvg = createTemplate("sink-record-active-count-avg", SINK_TASK_GROUP_NAME, + "The average number of records that have been read from Kafka but not yet completely " + + "committed/flushed/acknowledged by the sink task.", + sinkTaskTags); + + /***** Worker level *****/ + Set workerTags = new LinkedHashSet<>(tags); + + connectorCount = createTemplate("connector-count", WORKER_GROUP_NAME, "The number of connectors run in this worker.", workerTags); + taskCount = createTemplate("task-count", WORKER_GROUP_NAME, "The number of tasks run in this worker.", workerTags); + connectorStartupAttemptsTotal = createTemplate("connector-startup-attempts-total", WORKER_GROUP_NAME, + "The total number of 
connector startups that this worker has attempted.", workerTags); + connectorStartupSuccessTotal = createTemplate("connector-startup-success-total", WORKER_GROUP_NAME, + "The total number of connector starts that succeeded.", workerTags); + connectorStartupSuccessPercentage = createTemplate("connector-startup-success-percentage", WORKER_GROUP_NAME, + "The average percentage of this worker's connectors starts that succeeded.", workerTags); + connectorStartupFailureTotal = createTemplate("connector-startup-failure-total", WORKER_GROUP_NAME, + "The total number of connector starts that failed.", workerTags); + connectorStartupFailurePercentage = createTemplate("connector-startup-failure-percentage", WORKER_GROUP_NAME, + "The average percentage of this worker's connectors starts that failed.", workerTags); + taskStartupAttemptsTotal = createTemplate("task-startup-attempts-total", WORKER_GROUP_NAME, + "The total number of task startups that this worker has attempted.", workerTags); + taskStartupSuccessTotal = createTemplate("task-startup-success-total", WORKER_GROUP_NAME, + "The total number of task starts that succeeded.", workerTags); + taskStartupSuccessPercentage = createTemplate("task-startup-success-percentage", WORKER_GROUP_NAME, + "The average percentage of this worker's tasks starts that succeeded.", workerTags); + taskStartupFailureTotal = createTemplate("task-startup-failure-total", WORKER_GROUP_NAME, + "The total number of task starts that failed.", workerTags); + taskStartupFailurePercentage = createTemplate("task-startup-failure-percentage", WORKER_GROUP_NAME, + "The average percentage of this worker's tasks starts that failed.", workerTags); + + Set workerConnectorTags = new LinkedHashSet<>(tags); + workerConnectorTags.add(CONNECTOR_TAG_NAME); + connectorTotalTaskCount = createTemplate("connector-total-task-count", WORKER_GROUP_NAME, + "The number of tasks of the connector on the worker.", workerConnectorTags); + connectorRunningTaskCount = createTemplate("connector-running-task-count", WORKER_GROUP_NAME, + "The number of running tasks of the connector on the worker.", workerConnectorTags); + connectorPausedTaskCount = createTemplate("connector-paused-task-count", WORKER_GROUP_NAME, + "The number of paused tasks of the connector on the worker.", workerConnectorTags); + connectorFailedTaskCount = createTemplate("connector-failed-task-count", WORKER_GROUP_NAME, + "The number of failed tasks of the connector on the worker.", workerConnectorTags); + connectorUnassignedTaskCount = createTemplate("connector-unassigned-task-count", + WORKER_GROUP_NAME, + "The number of unassigned tasks of the connector on the worker.", workerConnectorTags); + connectorDestroyedTaskCount = createTemplate("connector-destroyed-task-count", + WORKER_GROUP_NAME, + "The number of destroyed tasks of the connector on the worker.", workerConnectorTags); + connectorRestartingTaskCount = createTemplate("connector-restarting-task-count", + WORKER_GROUP_NAME, + "The number of restarting tasks of the connector on the worker.", workerConnectorTags); + + connectorStatusMetrics = new HashMap<>(); + connectorStatusMetrics.put(connectorRunningTaskCount, TaskStatus.State.RUNNING); + connectorStatusMetrics.put(connectorPausedTaskCount, TaskStatus.State.PAUSED); + connectorStatusMetrics.put(connectorFailedTaskCount, TaskStatus.State.FAILED); + connectorStatusMetrics.put(connectorUnassignedTaskCount, TaskStatus.State.UNASSIGNED); + connectorStatusMetrics.put(connectorDestroyedTaskCount, TaskStatus.State.DESTROYED); + 
connectorStatusMetrics.put(connectorRestartingTaskCount, TaskStatus.State.RESTARTING); + connectorStatusMetrics = Collections.unmodifiableMap(connectorStatusMetrics); + + /***** Worker rebalance level *****/ + Set rebalanceTags = new LinkedHashSet<>(tags); + + connectProtocol = createTemplate("connect-protocol", WORKER_REBALANCE_GROUP_NAME, "The Connect protocol used by this cluster", rebalanceTags); + leaderName = createTemplate("leader-name", WORKER_REBALANCE_GROUP_NAME, "The name of the group leader.", rebalanceTags); + epoch = createTemplate("epoch", WORKER_REBALANCE_GROUP_NAME, "The epoch or generation number of this worker.", rebalanceTags); + rebalanceCompletedTotal = createTemplate("completed-rebalances-total", WORKER_REBALANCE_GROUP_NAME, + "The total number of rebalances completed by this worker.", rebalanceTags); + rebalanceMode = createTemplate("rebalancing", WORKER_REBALANCE_GROUP_NAME, + "Whether this worker is currently rebalancing.", rebalanceTags); + rebalanceTimeMax = createTemplate("rebalance-max-time-ms", WORKER_REBALANCE_GROUP_NAME, + "The maximum time in milliseconds spent by this worker to rebalance.", rebalanceTags); + rebalanceTimeAvg = createTemplate("rebalance-avg-time-ms", WORKER_REBALANCE_GROUP_NAME, + "The average time in milliseconds spent by this worker to rebalance.", rebalanceTags); + rebalanceTimeSinceLast = createTemplate("time-since-last-rebalance-ms", WORKER_REBALANCE_GROUP_NAME, + "The time in milliseconds since this worker completed the most recent rebalance.", rebalanceTags); + + /***** Task Error Handling Metrics *****/ + Set taskErrorHandlingTags = new LinkedHashSet<>(tags); + taskErrorHandlingTags.add(CONNECTOR_TAG_NAME); + taskErrorHandlingTags.add(TASK_TAG_NAME); + + recordProcessingFailures = createTemplate("total-record-failures", TASK_ERROR_HANDLING_GROUP_NAME, + "The number of record processing failures in this task.", taskErrorHandlingTags); + recordProcessingErrors = createTemplate("total-record-errors", TASK_ERROR_HANDLING_GROUP_NAME, + "The number of record processing errors in this task. 
", taskErrorHandlingTags); + recordsSkipped = createTemplate("total-records-skipped", TASK_ERROR_HANDLING_GROUP_NAME, + "The number of records skipped due to errors.", taskErrorHandlingTags); + retries = createTemplate("total-retries", TASK_ERROR_HANDLING_GROUP_NAME, + "The number of operations retried.", taskErrorHandlingTags); + errorsLogged = createTemplate("total-errors-logged", TASK_ERROR_HANDLING_GROUP_NAME, + "The number of errors that were logged.", taskErrorHandlingTags); + dlqProduceRequests = createTemplate("deadletterqueue-produce-requests", TASK_ERROR_HANDLING_GROUP_NAME, + "The number of attempted writes to the dead letter queue.", taskErrorHandlingTags); + dlqProduceFailures = createTemplate("deadletterqueue-produce-failures", TASK_ERROR_HANDLING_GROUP_NAME, + "The number of failed writes to the dead letter queue.", taskErrorHandlingTags); + lastErrorTimestamp = createTemplate("last-error-timestamp", TASK_ERROR_HANDLING_GROUP_NAME, + "The epoch timestamp when this task last encountered an error.", taskErrorHandlingTags); + } + + private MetricNameTemplate createTemplate(String name, String group, String doc, Set tags) { + MetricNameTemplate template = new MetricNameTemplate(name, group, doc, tags); + allTemplates.add(template); + return template; + } + + public List getAllTemplates() { + return Collections.unmodifiableList(allTemplates); + } + + public String connectorTagName() { + return CONNECTOR_TAG_NAME; + } + + public String taskTagName() { + return TASK_TAG_NAME; + } + + public String connectorGroupName() { + return CONNECTOR_GROUP_NAME; + } + + public String taskGroupName() { + return TASK_GROUP_NAME; + } + + public String sinkTaskGroupName() { + return SINK_TASK_GROUP_NAME; + } + + public String sourceTaskGroupName() { + return SOURCE_TASK_GROUP_NAME; + } + + public String workerGroupName() { + return WORKER_GROUP_NAME; + } + + public String workerRebalanceGroupName() { + return WORKER_REBALANCE_GROUP_NAME; + } + + public String taskErrorHandlingGroupName() { + return TASK_ERROR_HANDLING_GROUP_NAME; + } +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/ConnectorConfig.java b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/ConnectorConfig.java new file mode 100644 index 0000000..4ba1ddd --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/ConnectorConfig.java @@ -0,0 +1,554 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.connect.runtime; + +import org.apache.kafka.common.config.AbstractConfig; +import org.apache.kafka.common.config.ConfigDef; +import org.apache.kafka.common.config.ConfigDef.Importance; +import org.apache.kafka.common.config.ConfigDef.Type; +import org.apache.kafka.common.config.ConfigDef.Width; +import org.apache.kafka.common.config.ConfigException; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.connect.connector.ConnectRecord; +import org.apache.kafka.connect.errors.ConnectException; +import org.apache.kafka.connect.runtime.errors.ToleranceType; +import org.apache.kafka.connect.runtime.isolation.PluginDesc; +import org.apache.kafka.connect.runtime.isolation.Plugins; +import org.apache.kafka.connect.transforms.Transformation; +import org.apache.kafka.connect.transforms.predicates.Predicate; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.lang.reflect.Modifier; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashSet; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import static org.apache.kafka.common.config.ConfigDef.NonEmptyStringWithoutControlChars.nonEmptyStringWithoutControlChars; +import static org.apache.kafka.common.config.ConfigDef.Range.atLeast; +import static org.apache.kafka.common.config.ConfigDef.ValidString.in; + +/** + *

+ * Configuration options for Connectors. These only include Kafka Connect system-level configuration + * options (e.g. Connector class name, timeouts used by Connect to control the connector) but does + * not include Connector-specific options (e.g. database connection settings). + *
+ * Note that some of these options are not required for all connectors. For example TOPICS_CONFIG + * is sink-specific. + *
            + */ +public class ConnectorConfig extends AbstractConfig { + private static final Logger log = LoggerFactory.getLogger(ConnectorConfig.class); + + protected static final String COMMON_GROUP = "Common"; + protected static final String TRANSFORMS_GROUP = "Transforms"; + protected static final String PREDICATES_GROUP = "Predicates"; + protected static final String ERROR_GROUP = "Error Handling"; + + public static final String NAME_CONFIG = "name"; + private static final String NAME_DOC = "Globally unique name to use for this connector."; + private static final String NAME_DISPLAY = "Connector name"; + + public static final String CONNECTOR_CLASS_CONFIG = "connector.class"; + private static final String CONNECTOR_CLASS_DOC = + "Name or alias of the class for this connector. Must be a subclass of org.apache.kafka.connect.connector.Connector. " + + "If the connector is org.apache.kafka.connect.file.FileStreamSinkConnector, you can either specify this full name, " + + " or use \"FileStreamSink\" or \"FileStreamSinkConnector\" to make the configuration a bit shorter"; + private static final String CONNECTOR_CLASS_DISPLAY = "Connector class"; + + public static final String KEY_CONVERTER_CLASS_CONFIG = WorkerConfig.KEY_CONVERTER_CLASS_CONFIG; + public static final String KEY_CONVERTER_CLASS_DOC = WorkerConfig.KEY_CONVERTER_CLASS_DOC; + public static final String KEY_CONVERTER_CLASS_DISPLAY = "Key converter class"; + + public static final String VALUE_CONVERTER_CLASS_CONFIG = WorkerConfig.VALUE_CONVERTER_CLASS_CONFIG; + public static final String VALUE_CONVERTER_CLASS_DOC = WorkerConfig.VALUE_CONVERTER_CLASS_DOC; + public static final String VALUE_CONVERTER_CLASS_DISPLAY = "Value converter class"; + + public static final String HEADER_CONVERTER_CLASS_CONFIG = WorkerConfig.HEADER_CONVERTER_CLASS_CONFIG; + public static final String HEADER_CONVERTER_CLASS_DOC = WorkerConfig.HEADER_CONVERTER_CLASS_DOC; + public static final String HEADER_CONVERTER_CLASS_DISPLAY = "Header converter class"; + // The Connector config should not have a default for the header converter, since the absence of a config property means that + // the worker config settings should be used. Thus, we set the default to null here. + public static final String HEADER_CONVERTER_CLASS_DEFAULT = null; + + public static final String TASKS_MAX_CONFIG = "tasks.max"; + private static final String TASKS_MAX_DOC = "Maximum number of tasks to use for this connector."; + public static final int TASKS_MAX_DEFAULT = 1; + private static final int TASKS_MIN_CONFIG = 1; + + private static final String TASK_MAX_DISPLAY = "Tasks max"; + + public static final String TRANSFORMS_CONFIG = "transforms"; + private static final String TRANSFORMS_DOC = "Aliases for the transformations to be applied to records."; + private static final String TRANSFORMS_DISPLAY = "Transforms"; + + public static final String PREDICATES_CONFIG = "predicates"; + private static final String PREDICATES_DOC = "Aliases for the predicates used by transformations."; + private static final String PREDICATES_DISPLAY = "Predicates"; + + public static final String CONFIG_RELOAD_ACTION_CONFIG = "config.action.reload"; + private static final String CONFIG_RELOAD_ACTION_DOC = + "The action that Connect should take on the connector when changes in external " + + "configuration providers result in a change in the connector's configuration properties. " + + "A value of 'none' indicates that Connect will do nothing. 
" + + "A value of 'restart' indicates that Connect should restart/reload the connector with the " + + "updated configuration properties." + + "The restart may actually be scheduled in the future if the external configuration provider " + + "indicates that a configuration value will expire in the future."; + + private static final String CONFIG_RELOAD_ACTION_DISPLAY = "Reload Action"; + public static final String CONFIG_RELOAD_ACTION_NONE = Herder.ConfigReloadAction.NONE.name().toLowerCase(Locale.ROOT); + public static final String CONFIG_RELOAD_ACTION_RESTART = Herder.ConfigReloadAction.RESTART.name().toLowerCase(Locale.ROOT); + + public static final String ERRORS_RETRY_TIMEOUT_CONFIG = "errors.retry.timeout"; + public static final String ERRORS_RETRY_TIMEOUT_DISPLAY = "Retry Timeout for Errors"; + public static final int ERRORS_RETRY_TIMEOUT_DEFAULT = 0; + public static final String ERRORS_RETRY_TIMEOUT_DOC = "The maximum duration in milliseconds that a failed operation " + + "will be reattempted. The default is 0, which means no retries will be attempted. Use -1 for infinite retries."; + + public static final String ERRORS_RETRY_MAX_DELAY_CONFIG = "errors.retry.delay.max.ms"; + public static final String ERRORS_RETRY_MAX_DELAY_DISPLAY = "Maximum Delay Between Retries for Errors"; + public static final int ERRORS_RETRY_MAX_DELAY_DEFAULT = 60000; + public static final String ERRORS_RETRY_MAX_DELAY_DOC = "The maximum duration in milliseconds between consecutive retry attempts. " + + "Jitter will be added to the delay once this limit is reached to prevent thundering herd issues."; + + public static final String ERRORS_TOLERANCE_CONFIG = "errors.tolerance"; + public static final String ERRORS_TOLERANCE_DISPLAY = "Error Tolerance"; + public static final ToleranceType ERRORS_TOLERANCE_DEFAULT = ToleranceType.NONE; + public static final String ERRORS_TOLERANCE_DOC = "Behavior for tolerating errors during connector operation. 'none' is the default value " + + "and signals that any error will result in an immediate connector task failure; 'all' changes the behavior to skip over problematic records."; + + public static final String ERRORS_LOG_ENABLE_CONFIG = "errors.log.enable"; + public static final String ERRORS_LOG_ENABLE_DISPLAY = "Log Errors"; + public static final boolean ERRORS_LOG_ENABLE_DEFAULT = false; + public static final String ERRORS_LOG_ENABLE_DOC = "If true, write each error and the details of the failed operation and problematic record " + + "to the Connect application log. This is 'false' by default, so that only errors that are not tolerated are reported."; + + public static final String ERRORS_LOG_INCLUDE_MESSAGES_CONFIG = "errors.log.include.messages"; + public static final String ERRORS_LOG_INCLUDE_MESSAGES_DISPLAY = "Log Error Details"; + public static final boolean ERRORS_LOG_INCLUDE_MESSAGES_DEFAULT = false; + public static final String ERRORS_LOG_INCLUDE_MESSAGES_DOC = "Whether to the include in the log the Connect record that resulted in " + + "a failure. 
This is 'false' by default, which will prevent record keys, values, and headers from being written to log files, " + + "although some information such as topic and partition number will still be logged."; + + + public static final String CONNECTOR_CLIENT_PRODUCER_OVERRIDES_PREFIX = "producer.override."; + public static final String CONNECTOR_CLIENT_CONSUMER_OVERRIDES_PREFIX = "consumer.override."; + public static final String CONNECTOR_CLIENT_ADMIN_OVERRIDES_PREFIX = "admin.override."; + public static final String PREDICATES_PREFIX = "predicates."; + + private final EnrichedConnectorConfig enrichedConfig; + private static class EnrichedConnectorConfig extends AbstractConfig { + EnrichedConnectorConfig(ConfigDef configDef, Map props) { + super(configDef, props); + } + + @Override + public Object get(String key) { + return super.get(key); + } + } + + public static ConfigDef configDef() { + int orderInGroup = 0; + int orderInErrorGroup = 0; + return new ConfigDef() + .define(NAME_CONFIG, Type.STRING, ConfigDef.NO_DEFAULT_VALUE, nonEmptyStringWithoutControlChars(), Importance.HIGH, NAME_DOC, COMMON_GROUP, ++orderInGroup, Width.MEDIUM, NAME_DISPLAY) + .define(CONNECTOR_CLASS_CONFIG, Type.STRING, Importance.HIGH, CONNECTOR_CLASS_DOC, COMMON_GROUP, ++orderInGroup, Width.LONG, CONNECTOR_CLASS_DISPLAY) + .define(TASKS_MAX_CONFIG, Type.INT, TASKS_MAX_DEFAULT, atLeast(TASKS_MIN_CONFIG), Importance.HIGH, TASKS_MAX_DOC, COMMON_GROUP, ++orderInGroup, Width.SHORT, TASK_MAX_DISPLAY) + .define(KEY_CONVERTER_CLASS_CONFIG, Type.CLASS, null, Importance.LOW, KEY_CONVERTER_CLASS_DOC, COMMON_GROUP, ++orderInGroup, Width.SHORT, KEY_CONVERTER_CLASS_DISPLAY) + .define(VALUE_CONVERTER_CLASS_CONFIG, Type.CLASS, null, Importance.LOW, VALUE_CONVERTER_CLASS_DOC, COMMON_GROUP, ++orderInGroup, Width.SHORT, VALUE_CONVERTER_CLASS_DISPLAY) + .define(HEADER_CONVERTER_CLASS_CONFIG, Type.CLASS, HEADER_CONVERTER_CLASS_DEFAULT, Importance.LOW, HEADER_CONVERTER_CLASS_DOC, COMMON_GROUP, ++orderInGroup, Width.SHORT, HEADER_CONVERTER_CLASS_DISPLAY) + .define(TRANSFORMS_CONFIG, Type.LIST, Collections.emptyList(), aliasValidator("transformation"), Importance.LOW, TRANSFORMS_DOC, TRANSFORMS_GROUP, ++orderInGroup, Width.LONG, TRANSFORMS_DISPLAY) + .define(PREDICATES_CONFIG, Type.LIST, Collections.emptyList(), aliasValidator("predicate"), Importance.LOW, PREDICATES_DOC, PREDICATES_GROUP, ++orderInGroup, Width.LONG, PREDICATES_DISPLAY) + .define(CONFIG_RELOAD_ACTION_CONFIG, Type.STRING, CONFIG_RELOAD_ACTION_RESTART, + in(CONFIG_RELOAD_ACTION_NONE, CONFIG_RELOAD_ACTION_RESTART), Importance.LOW, + CONFIG_RELOAD_ACTION_DOC, COMMON_GROUP, ++orderInGroup, Width.MEDIUM, CONFIG_RELOAD_ACTION_DISPLAY) + .define(ERRORS_RETRY_TIMEOUT_CONFIG, Type.LONG, ERRORS_RETRY_TIMEOUT_DEFAULT, Importance.MEDIUM, + ERRORS_RETRY_TIMEOUT_DOC, ERROR_GROUP, ++orderInErrorGroup, Width.MEDIUM, ERRORS_RETRY_TIMEOUT_DISPLAY) + .define(ERRORS_RETRY_MAX_DELAY_CONFIG, Type.LONG, ERRORS_RETRY_MAX_DELAY_DEFAULT, Importance.MEDIUM, + ERRORS_RETRY_MAX_DELAY_DOC, ERROR_GROUP, ++orderInErrorGroup, Width.MEDIUM, ERRORS_RETRY_MAX_DELAY_DISPLAY) + .define(ERRORS_TOLERANCE_CONFIG, Type.STRING, ERRORS_TOLERANCE_DEFAULT.value(), + in(ToleranceType.NONE.value(), ToleranceType.ALL.value()), Importance.MEDIUM, + ERRORS_TOLERANCE_DOC, ERROR_GROUP, ++orderInErrorGroup, Width.SHORT, ERRORS_TOLERANCE_DISPLAY) + .define(ERRORS_LOG_ENABLE_CONFIG, Type.BOOLEAN, ERRORS_LOG_ENABLE_DEFAULT, Importance.MEDIUM, + ERRORS_LOG_ENABLE_DOC, ERROR_GROUP, ++orderInErrorGroup, Width.SHORT, 
ERRORS_LOG_ENABLE_DISPLAY) + .define(ERRORS_LOG_INCLUDE_MESSAGES_CONFIG, Type.BOOLEAN, ERRORS_LOG_INCLUDE_MESSAGES_DEFAULT, Importance.MEDIUM, + ERRORS_LOG_INCLUDE_MESSAGES_DOC, ERROR_GROUP, ++orderInErrorGroup, Width.SHORT, ERRORS_LOG_INCLUDE_MESSAGES_DISPLAY); + } + + private static ConfigDef.CompositeValidator aliasValidator(String kind) { + return ConfigDef.CompositeValidator.of(new ConfigDef.NonNullValidator(), new ConfigDef.Validator() { + @SuppressWarnings("unchecked") + @Override + public void ensureValid(String name, Object value) { + final List aliases = (List) value; + if (aliases.size() > new HashSet<>(aliases).size()) { + throw new ConfigException(name, value, "Duplicate alias provided."); + } + } + + @Override + public String toString() { + return "unique " + kind + " aliases"; + } + }); + } + + public ConnectorConfig(Plugins plugins) { + this(plugins, Collections.emptyMap()); + } + + public ConnectorConfig(Plugins plugins, Map props) { + this(plugins, configDef(), props); + } + + public ConnectorConfig(Plugins plugins, ConfigDef configDef, Map props) { + super(configDef, props); + enrichedConfig = new EnrichedConnectorConfig( + enrich(plugins, configDef, props, true), + props + ); + } + + @Override + public Object get(String key) { + return enrichedConfig.get(key); + } + + public long errorRetryTimeout() { + return getLong(ERRORS_RETRY_TIMEOUT_CONFIG); + } + + public long errorMaxDelayInMillis() { + return getLong(ERRORS_RETRY_MAX_DELAY_CONFIG); + } + + public ToleranceType errorToleranceType() { + String tolerance = getString(ERRORS_TOLERANCE_CONFIG); + for (ToleranceType type: ToleranceType.values()) { + if (type.name().equalsIgnoreCase(tolerance)) { + return type; + } + } + return ERRORS_TOLERANCE_DEFAULT; + } + + public boolean enableErrorLog() { + return getBoolean(ERRORS_LOG_ENABLE_CONFIG); + } + + public boolean includeRecordDetailsInErrorLog() { + return getBoolean(ERRORS_LOG_INCLUDE_MESSAGES_CONFIG); + } + + /** + * Returns the initialized list of {@link Transformation} which are specified in {@link #TRANSFORMS_CONFIG}. + */ + public > List> transformations() { + final List transformAliases = getList(TRANSFORMS_CONFIG); + + final List> transformations = new ArrayList<>(transformAliases.size()); + for (String alias : transformAliases) { + final String prefix = TRANSFORMS_CONFIG + "." + alias + "."; + + try { + @SuppressWarnings("unchecked") + final Transformation transformation = Utils.newInstance(getClass(prefix + "type"), Transformation.class); + Map configs = originalsWithPrefix(prefix); + Object predicateAlias = configs.remove(PredicatedTransformation.PREDICATE_CONFIG); + Object negate = configs.remove(PredicatedTransformation.NEGATE_CONFIG); + transformation.configure(configs); + if (predicateAlias != null) { + String predicatePrefix = PREDICATES_PREFIX + predicateAlias + "."; + @SuppressWarnings("unchecked") + Predicate predicate = Utils.newInstance(getClass(predicatePrefix + "type"), Predicate.class); + predicate.configure(originalsWithPrefix(predicatePrefix)); + transformations.add(new PredicatedTransformation<>(predicate, negate == null ? false : Boolean.parseBoolean(negate.toString()), transformation)); + } else { + transformations.add(transformation); + } + } catch (Exception e) { + throw new ConnectException(e); + } + } + + return transformations; + } + + /** + * Returns an enriched {@link ConfigDef} building upon the {@code ConfigDef}, using the current configuration specified in {@code props} as an input. + *

            + * {@code requireFullConfig} specifies whether required config values that are missing should cause an exception to be thrown. + */ + @SuppressWarnings({"rawtypes", "unchecked"}) + public static ConfigDef enrich(Plugins plugins, ConfigDef baseConfigDef, Map props, boolean requireFullConfig) { + ConfigDef newDef = new ConfigDef(baseConfigDef); + new EnrichablePlugin>("Transformation", TRANSFORMS_CONFIG, TRANSFORMS_GROUP, (Class) Transformation.class, + props, requireFullConfig) { + @SuppressWarnings("rawtypes") + @Override + protected Set>> plugins() { + return (Set) plugins.transformations(); + } + + @Override + protected ConfigDef initialConfigDef() { + // All Transformations get these config parameters implicitly + return super.initialConfigDef() + .define(PredicatedTransformation.PREDICATE_CONFIG, Type.STRING, "", Importance.MEDIUM, + "The alias of a predicate used to determine whether to apply this transformation.") + .define(PredicatedTransformation.NEGATE_CONFIG, Type.BOOLEAN, false, Importance.MEDIUM, + "Whether the configured predicate should be negated."); + } + + @Override + protected Stream> configDefsForClass(String typeConfig) { + return super.configDefsForClass(typeConfig) + .filter(entry -> { + // The implicit parameters mask any from the transformer with the same name + if (PredicatedTransformation.PREDICATE_CONFIG.equals(entry.getKey()) + || PredicatedTransformation.NEGATE_CONFIG.equals(entry.getKey())) { + log.warn("Transformer config {} is masked by implicit config of that name", + entry.getKey()); + return false; + } else { + return true; + } + }); + } + + @Override + protected ConfigDef config(Transformation transformation) { + return transformation.config(); + } + + @Override + protected void validateProps(String prefix) { + String prefixedNegate = prefix + PredicatedTransformation.NEGATE_CONFIG; + String prefixedPredicate = prefix + PredicatedTransformation.PREDICATE_CONFIG; + if (props.containsKey(prefixedNegate) && + !props.containsKey(prefixedPredicate)) { + throw new ConfigException("Config '" + prefixedNegate + "' was provided " + + "but there is no config '" + prefixedPredicate + "' defining a predicate to be negated."); + } + } + }.enrich(newDef); + + new EnrichablePlugin>("Predicate", PREDICATES_CONFIG, PREDICATES_GROUP, + (Class) Predicate.class, props, requireFullConfig) { + @Override + protected Set>> plugins() { + return (Set) plugins.predicates(); + } + + @Override + protected ConfigDef config(Predicate predicate) { + return predicate.config(); + } + }.enrich(newDef); + return newDef; + } + + /** + * An abstraction over "enrichable plugins" ({@link Transformation}s and {@link Predicate}s) used for computing the + * contribution to a Connectors ConfigDef. + * + * This is not entirely elegant because + * although they basically use the same "alias prefix" configuration idiom there are some differences. + * The abstract method pattern is used to cope with this. + * @param The type of plugin (either {@code Transformation} or {@code Predicate}). 
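The alias lists in `transforms`/`predicates` and the per-alias `<config>.<alias>.` prefixes parsed by `transformations()` and `enrich(...)` are easier to see with a concrete property set. A small sketch using the bundled `Filter` SMT and `RecordIsTombstone` predicate; the connector name, topic, and alias names are invented for the example:

```java
import java.util.LinkedHashMap;
import java.util.Map;

public class TransformConfigExample {
    public static void main(String[] args) {
        // Connector-level properties as a worker would receive them over the REST API.
        Map<String, String> props = new LinkedHashMap<>();
        props.put("name", "example-sink");
        props.put("connector.class", "org.apache.kafka.connect.file.FileStreamSinkConnector");
        props.put("tasks.max", "1");
        props.put("topics", "example-topic");

        // transforms/predicates list the aliases; per-alias settings use the "<config>.<alias>." prefix.
        props.put("transforms", "dropTombstones");
        props.put("transforms.dropTombstones.type", "org.apache.kafka.connect.transforms.Filter");
        props.put("transforms.dropTombstones.predicate", "isTombstone");
        props.put("predicates", "isTombstone");
        props.put("predicates.isTombstone.type",
                "org.apache.kafka.connect.transforms.predicates.RecordIsTombstone");

        // transformations() would wrap Filter in a PredicatedTransformation keyed off the
        // "isTombstone" predicate; here we only show how the keys group by prefix.
        props.keySet().stream()
                .filter(k -> k.startsWith("transforms.dropTombstones."))
                .forEach(System.out::println);
    }
}
```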
+ */ + static abstract class EnrichablePlugin { + + private final String aliasKind; + private final String aliasConfig; + private final String aliasGroup; + private final Class baseClass; + private final Map props; + private final boolean requireFullConfig; + + public EnrichablePlugin( + String aliasKind, + String aliasConfig, String aliasGroup, Class baseClass, + Map props, boolean requireFullConfig) { + this.aliasKind = aliasKind; + this.aliasConfig = aliasConfig; + this.aliasGroup = aliasGroup; + this.baseClass = baseClass; + this.props = props; + this.requireFullConfig = requireFullConfig; + } + + /** Add the configs for this alias to the given {@code ConfigDef}. */ + void enrich(ConfigDef newDef) { + Object aliases = ConfigDef.parseType(aliasConfig, props.get(aliasConfig), Type.LIST); + if (!(aliases instanceof List)) { + return; + } + + LinkedHashSet uniqueAliases = new LinkedHashSet<>((List) aliases); + for (Object o : uniqueAliases) { + if (!(o instanceof String)) { + throw new ConfigException("Item in " + aliasConfig + " property is not of " + + "type String"); + } + String alias = (String) o; + final String prefix = aliasConfig + "." + alias + "."; + final String group = aliasGroup + ": " + alias; + int orderInGroup = 0; + + final String typeConfig = prefix + "type"; + final ConfigDef.Validator typeValidator = ConfigDef.LambdaValidator.with( + (String name, Object value) -> { + validateProps(prefix); + getConfigDefFromConfigProvidingClass(typeConfig, (Class) value); + }, + () -> "valid configs for " + alias + " " + aliasKind.toLowerCase(Locale.ENGLISH)); + newDef.define(typeConfig, Type.CLASS, ConfigDef.NO_DEFAULT_VALUE, typeValidator, Importance.HIGH, + "Class for the '" + alias + "' " + aliasKind.toLowerCase(Locale.ENGLISH) + ".", group, orderInGroup++, Width.LONG, + baseClass.getSimpleName() + " type for " + alias, + Collections.emptyList(), new ClassRecommender()); + + final ConfigDef configDef = populateConfigDef(typeConfig); + if (configDef == null) continue; + newDef.embed(prefix, group, orderInGroup, configDef); + } + } + + /** Subclasses can add extra validation of the {@link #props}. */ + protected void validateProps(String prefix) { } + + /** + * Populates the ConfigDef according to the configs returned from {@code configs()} method of class + * named in the {@code ...type} parameter of the {@code props}. + */ + protected ConfigDef populateConfigDef(String typeConfig) { + final ConfigDef configDef = initialConfigDef(); + try { + configDefsForClass(typeConfig) + .forEach(entry -> configDef.define(entry.getValue())); + + } catch (ConfigException e) { + if (requireFullConfig) { + throw e; + } else { + return null; + } + } + return configDef; + } + + /** + * Return a stream of configs provided by the {@code configs()} method of class + * named in the {@code ...type} parameter of the {@code props}. + */ + protected Stream> configDefsForClass(String typeConfig) { + final Class cls = (Class) ConfigDef.parseType(typeConfig, props.get(typeConfig), Type.CLASS); + return getConfigDefFromConfigProvidingClass(typeConfig, cls) + .configKeys().entrySet().stream(); + } + + /** Get an initial ConfigDef */ + protected ConfigDef initialConfigDef() { + return new ConfigDef(); + } + + /** + * Return {@link ConfigDef} from {@code cls}, which is expected to be a non-null {@code Class}, + * by instantiating it and invoking {@link #config(T)}. + * @param key + * @param cls The subclass of the baseclass. 
+ */ + ConfigDef getConfigDefFromConfigProvidingClass(String key, Class cls) { + if (cls == null || !baseClass.isAssignableFrom(cls)) { + throw new ConfigException(key, String.valueOf(cls), "Not a " + baseClass.getSimpleName()); + } + if (Modifier.isAbstract(cls.getModifiers())) { + String childClassNames = Stream.of(cls.getClasses()) + .filter(cls::isAssignableFrom) + .filter(c -> !Modifier.isAbstract(c.getModifiers())) + .filter(c -> Modifier.isPublic(c.getModifiers())) + .map(Class::getName) + .collect(Collectors.joining(", ")); + String message = Utils.isBlank(childClassNames) ? + aliasKind + " is abstract and cannot be created." : + aliasKind + " is abstract and cannot be created. Did you mean " + childClassNames + "?"; + throw new ConfigException(key, String.valueOf(cls), message); + } + T transformation; + try { + transformation = Utils.newInstance(cls, baseClass); + } catch (Exception e) { + throw new ConfigException(key, String.valueOf(cls), "Error getting config definition from " + baseClass.getSimpleName() + ": " + e.getMessage()); + } + ConfigDef configDef = config(transformation); + if (null == configDef) { + throw new ConnectException( + String.format( + "%s.config() must return a ConfigDef that is not null.", + cls.getName() + ) + ); + } + return configDef; + } + + /** + * Get the ConfigDef from the given entity. + * This is necessary because there's no abstraction across {@link Transformation#config()} and + * {@link Predicate#config()}. + */ + protected abstract ConfigDef config(T t); + + /** + * The transformation or predicate plugins (as appropriate for T) to be used + * for the {@link ClassRecommender}. + */ + protected abstract Set> plugins(); + + /** + * Recommend bundled transformations or predicates. + */ + final class ClassRecommender implements ConfigDef.Recommender { + + @Override + public List validValues(String name, Map parsedConfig) { + List result = new ArrayList<>(); + for (PluginDesc plugin : plugins()) { + result.add(plugin.pluginClass()); + } + return Collections.unmodifiableList(result); + } + + @Override + public boolean visible(String name, Map parsedConfig) { + return true; + } + } + } + +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/ConnectorStatus.java b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/ConnectorStatus.java new file mode 100644 index 0000000..6772c8b --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/ConnectorStatus.java @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
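`getConfigDefFromConfigProvidingClass(...)` instantiates the plugin class and requires its `config()` to return a non-null `ConfigDef`; that is what lets the enrichment embed the plugin's own options under the alias prefix. A minimal, illustrative `Transformation` that satisfies this contract (the class name and the `suffix` option are invented for the example):

```java
import java.util.Map;

import org.apache.kafka.common.config.AbstractConfig;
import org.apache.kafka.common.config.ConfigDef;
import org.apache.kafka.connect.connector.ConnectRecord;
import org.apache.kafka.connect.transforms.Transformation;

/** Minimal SMT sketch: renames the topic by appending a configurable suffix. */
public class TopicSuffixTransformation<R extends ConnectRecord<R>> implements Transformation<R> {

    public static final String SUFFIX_CONFIG = "suffix";
    private static final ConfigDef CONFIG_DEF = new ConfigDef()
            .define(SUFFIX_CONFIG, ConfigDef.Type.STRING, ".copy",
                    ConfigDef.Importance.MEDIUM, "Suffix appended to the topic name.");

    private String suffix;

    @Override
    public void configure(Map<String, ?> configs) {
        suffix = new AbstractConfig(CONFIG_DEF, configs).getString(SUFFIX_CONFIG);
    }

    @Override
    public R apply(R record) {
        return record.newRecord(record.topic() + suffix, record.kafkaPartition(),
                record.keySchema(), record.key(), record.valueSchema(), record.value(),
                record.timestamp());
    }

    @Override
    public ConfigDef config() {
        return CONFIG_DEF;  // must not be null, or connector config enrichment fails
    }

    @Override
    public void close() { }
}
```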
+ */ +package org.apache.kafka.connect.runtime; + + +public class ConnectorStatus extends AbstractStatus { + + public ConnectorStatus(String connector, State state, String msg, String workerUrl, int generation) { + super(connector, state, workerUrl, generation, msg); + } + + public ConnectorStatus(String connector, State state, String workerUrl, int generation) { + super(connector, state, workerUrl, generation, null); + } + + public interface Listener { + + /** + * Invoked after connector has successfully been shutdown. + * @param connector The connector name + */ + void onShutdown(String connector); + + /** + * Invoked from the Connector using {@link org.apache.kafka.connect.connector.ConnectorContext#raiseError(Exception)} + * or if either {@link org.apache.kafka.connect.connector.Connector#start(java.util.Map)} or + * {@link org.apache.kafka.connect.connector.Connector#stop()} throw an exception. + * Note that no shutdown event will follow after the task has been failed. + * @param connector The connector name + * @param cause Error raised from the connector. + */ + void onFailure(String connector, Throwable cause); + + /** + * Invoked when the connector is paused through the REST API + * @param connector The connector name + */ + void onPause(String connector); + + /** + * Invoked after the connector has been resumed. + * @param connector The connector name + */ + void onResume(String connector); + + /** + * Invoked after successful startup of the connector. + * @param connector The connector name + */ + void onStartup(String connector); + + /** + * Invoked when the connector is deleted through the REST API. + * @param connector The connector name + */ + void onDeletion(String connector); + + /** + * Invoked when the connector is restarted asynchronously by the herder on processing a restart request. + * @param connector The connector name + */ + void onRestart(String connector); + } +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/Herder.java b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/Herder.java new file mode 100644 index 0000000..945797f --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/Herder.java @@ -0,0 +1,297 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
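A `ConnectorStatus.Listener` implementation can be as simple as logging each transition. The sketch below is illustrative only and is not how the runtime wires its real listeners (which update a status backing store):

```java
import org.apache.kafka.connect.runtime.ConnectorStatus;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/** Illustrative listener that logs connector lifecycle transitions. */
public class LoggingConnectorListener implements ConnectorStatus.Listener {
    private static final Logger log = LoggerFactory.getLogger(LoggingConnectorListener.class);

    @Override public void onStartup(String connector)  { log.info("{} started", connector); }
    @Override public void onPause(String connector)    { log.info("{} paused", connector); }
    @Override public void onResume(String connector)   { log.info("{} resumed", connector); }
    @Override public void onShutdown(String connector) { log.info("{} shut down", connector); }
    @Override public void onDeletion(String connector) { log.info("{} deleted", connector); }
    @Override public void onRestart(String connector)  { log.info("{} restarting", connector); }

    @Override
    public void onFailure(String connector, Throwable cause) {
        log.error("{} failed", connector, cause);
    }
}
```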
+ */ +package org.apache.kafka.connect.runtime; + +import org.apache.kafka.connect.runtime.isolation.Plugins; +import org.apache.kafka.connect.runtime.rest.InternalRequestSignature; +import org.apache.kafka.connect.runtime.rest.entities.ActiveTopicsInfo; +import org.apache.kafka.connect.runtime.rest.entities.ConfigInfos; +import org.apache.kafka.connect.runtime.rest.entities.ConnectorInfo; +import org.apache.kafka.connect.runtime.rest.entities.ConnectorStateInfo; +import org.apache.kafka.connect.runtime.rest.entities.TaskInfo; +import org.apache.kafka.connect.storage.StatusBackingStore; +import org.apache.kafka.connect.util.Callback; +import org.apache.kafka.connect.util.ConnectorTaskId; + +import java.util.Collection; +import java.util.List; +import java.util.Map; +import java.util.Objects; + +/** + *

+ * The herder interface tracks and manages workers and connectors. It is the main interface for external components + * to make changes to the state of the cluster. For example, in distributed mode, an implementation of this class + * knows how to accept a connector configuration, may need to route it to the current leader worker for the cluster so + * the config can be written to persistent storage, and then ensures the new connector is correctly instantiated on one + * of the workers. + *
+ * This class must implement all the actions that can be taken on the cluster (add/remove connectors, pause/resume tasks, + * get state of connectors and tasks, etc). The non-Java interfaces to the cluster (REST API and CLI) are very simple + * wrappers of the functionality provided by this interface. + *
+ * In standalone mode, the implementation of this class will be trivial because no coordination is needed. In that case, + * the implementation will mainly be delegating tasks directly to other components. For example, when creating a new + * connector in standalone mode, there is no need to persist the config and the connector and its tasks must run in the + * same process, so the standalone herder implementation can immediately instantiate and start the connector and its + * tasks. + *
            + */ +public interface Herder { + + void start(); + + void stop(); + + boolean isRunning(); + + /** + * Get a list of connectors currently running in this cluster. This is a full list of connectors in the cluster gathered + * from the current configuration. However, note + * + * @return A list of connector names + * @throws org.apache.kafka.connect.runtime.distributed.RequestTargetException if this node can not resolve the request + * (e.g., because it has not joined the cluster or does not have configs in sync with the group) and it is + * not the leader or the task owner (e.g., task restart must be handled by the worker which owns the task) + * @throws org.apache.kafka.connect.errors.ConnectException if this node is the leader, but still cannot resolve the + * request (e.g., it is not in sync with other worker's config state) + */ + void connectors(Callback> callback); + + /** + * Get the definition and status of a connector. + */ + void connectorInfo(String connName, Callback callback); + + /** + * Get the configuration for a connector. + * @param connName name of the connector + * @param callback callback to invoke with the configuration + */ + void connectorConfig(String connName, Callback> callback); + + /** + * Get the configuration for all tasks. + * @param connName name of the connector + * @param callback callback to invoke with the configuration + */ + void tasksConfig(String connName, Callback>> callback); + + /** + * Set the configuration for a connector. This supports creation and updating. + * @param connName name of the connector + * @param config the connectors configuration, or null if deleting the connector + * @param allowReplace if true, allow overwriting previous configs; if false, throw AlreadyExistsException if a connector + * with the same name already exists + * @param callback callback to invoke when the configuration has been written + */ + void putConnectorConfig(String connName, Map config, boolean allowReplace, Callback> callback); + + /** + * Delete a connector and its configuration. + * @param connName name of the connector + * @param callback callback to invoke when the configuration has been written + */ + void deleteConnectorConfig(String connName, Callback> callback); + + /** + * Requests reconfiguration of the task. This should only be triggered by + * {@link HerderConnectorContext}. + * + * @param connName name of the connector that should be reconfigured + */ + void requestTaskReconfiguration(String connName); + + /** + * Get the configurations for the current set of tasks of a connector. + * @param connName connector to update + * @param callback callback to invoke upon completion + */ + void taskConfigs(String connName, Callback> callback); + + /** + * Set the configurations for the tasks of a connector. This should always include all tasks in the connector; if + * there are existing configurations and fewer are provided, this will reduce the number of tasks, and if more are + * provided it will increase the number of tasks. + * @param connName connector to update + * @param configs list of configurations + * @param callback callback to invoke upon completion + * @param requestSignature the signature of the request made for this task (re-)configuration; + * may be null if no signature was provided + */ + void putTaskConfigs(String connName, List> configs, Callback callback, InternalRequestSignature requestSignature); + + /** + * Get a list of connectors currently running in this cluster. 
+ * @return A list of connector names + */ + Collection connectors(); + + /** + * Get the definition and status of a connector. + * @param connName name of the connector + */ + ConnectorInfo connectorInfo(String connName); + + /** + * Lookup the current status of a connector. + * @param connName name of the connector + */ + ConnectorStateInfo connectorStatus(String connName); + + /** + * Lookup the set of topics currently used by a connector. + * + * @param connName name of the connector + * @return the set of active topics + */ + ActiveTopicsInfo connectorActiveTopics(String connName); + + /** + * Request to asynchronously reset the active topics for the named connector. + * + * @param connName name of the connector + */ + void resetConnectorActiveTopics(String connName); + + /** + * Return a reference to the status backing store used by this herder. + * + * @return the status backing store used by this herder + */ + StatusBackingStore statusBackingStore(); + + /** + * Lookup the status of the a task. + * @param id id of the task + */ + ConnectorStateInfo.TaskState taskStatus(ConnectorTaskId id); + + /** + * Validate the provided connector config values against the configuration definition. + * @param connectorConfig the provided connector config values + * @param callback the callback to invoke after validation has completed (successfully or not) + */ + void validateConnectorConfig(Map connectorConfig, Callback callback); + + /** + * Validate the provided connector config values against the configuration definition. + * @param connectorConfig the provided connector config values + * @param callback the callback to invoke after validation has completed (successfully or not) + * @param doLog if true log all the connector configurations at INFO level; if false, no connector configurations are logged. + * Note that logging of configuration is not necessary in every endpoint that uses this method. + */ + default void validateConnectorConfig(Map connectorConfig, Callback callback, boolean doLog) { + validateConnectorConfig(connectorConfig, callback); + } + + /** + * Restart the task with the given id. + * @param id id of the task + * @param cb callback to invoke upon completion + */ + void restartTask(ConnectorTaskId id, Callback cb); + + /** + * Restart the connector. + * @param connName name of the connector + * @param cb callback to invoke upon completion + */ + void restartConnector(String connName, Callback cb); + + /** + * Restart the connector. + * @param delayMs delay before restart + * @param connName name of the connector + * @param cb callback to invoke upon completion + * @return The id of the request + */ + HerderRequest restartConnector(long delayMs, String connName, Callback cb); + + /** + * Restart the connector and optionally its tasks. + * @param request the details of the restart request + * @param cb callback to invoke upon completion with the connector state info + */ + void restartConnectorAndTasks(RestartRequest request, Callback cb); + + /** + * Pause the connector. This call will asynchronously suspend processing by the connector and all + * of its tasks. + * @param connector name of the connector + */ + void pauseConnector(String connector); + + /** + * Resume the connector. This call will asynchronously start the connector and its tasks (if + * not started already). + * @param connector name of the connector + */ + void resumeConnector(String connector); + + /** + * Returns a handle to the plugin factory used by this herder and its worker. 
+ * + * @return a reference to the plugin factory. + */ + Plugins plugins(); + + /** + * Get the cluster ID of the Kafka cluster backing this Connect cluster. + * @return the cluster ID of the Kafka cluster backing this connect cluster + */ + String kafkaClusterId(); + + enum ConfigReloadAction { + NONE, + RESTART + } + + class Created { + private final boolean created; + private final T result; + + public Created(boolean created, T result) { + this.created = created; + this.result = result; + } + + public boolean created() { + return created; + } + + public T result() { + return result; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + Created created1 = (Created) o; + return Objects.equals(created, created1.created) && + Objects.equals(result, created1.result); + } + + @Override + public int hashCode() { + return Objects.hash(created, result); + } + } +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/HerderConnectorContext.java b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/HerderConnectorContext.java new file mode 100644 index 0000000..60092ba --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/HerderConnectorContext.java @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
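Callers of the asynchronous `Herder` methods typically pass a `FutureCallback`, which implements both `Callback` and `Future`, and then block with a timeout, which is roughly what the REST resources do. A rough sketch of listing connector names that way (the 10-second timeout is arbitrary):

```java
import java.util.Collection;
import java.util.concurrent.TimeUnit;

import org.apache.kafka.connect.runtime.Herder;
import org.apache.kafka.connect.util.FutureCallback;

public class HerderClientSketch {
    /** List connector names through the asynchronous callback variant of connectors(). */
    public static Collection<String> listConnectors(Herder herder) throws Exception {
        FutureCallback<Collection<String>> cb = new FutureCallback<>();
        herder.connectors(cb);                  // asynchronous: the herder completes the callback
        return cb.get(10, TimeUnit.SECONDS);    // block briefly for the result
    }
}
```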
+ */ +package org.apache.kafka.connect.runtime; + +import org.apache.kafka.connect.errors.ConnectException; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * ConnectorContext for use with a Herder + */ +public class HerderConnectorContext implements CloseableConnectorContext { + + private static final Logger log = LoggerFactory.getLogger(HerderConnectorContext.class); + + private final AbstractHerder herder; + private final String connectorName; + private volatile boolean closed; + + public HerderConnectorContext(AbstractHerder herder, String connectorName) { + this.herder = herder; + this.connectorName = connectorName; + this.closed = false; + } + + @Override + public void requestTaskReconfiguration() { + if (closed) { + throw new ConnectException("The request for task reconfiguration has been rejected " + + "because this instance of the connector '" + connectorName + "' has already " + + "been shut down."); + } + + // Local herder runs in memory in this process + // Distributed herder will forward the request to the leader if needed + herder.requestTaskReconfiguration(connectorName); + } + + @Override + public void raiseError(Exception e) { + if (closed) { + log.warn("Connector {} attempted to raise error after shutdown:", connectorName, e); + throw new ConnectException("The request to fail the connector has been rejected " + + "because this instance of the connector '" + connectorName + "' has already " + + "been shut down."); + } + + herder.onFailure(connectorName, e); + } + + @Override + public void close() { + closed = true; + } +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/HerderRequest.java b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/HerderRequest.java new file mode 100644 index 0000000..627da4d --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/HerderRequest.java @@ -0,0 +1,21 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.runtime; + +public interface HerderRequest { + void cancel(); +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/InternalSinkRecord.java b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/InternalSinkRecord.java new file mode 100644 index 0000000..69554ff --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/InternalSinkRecord.java @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
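On the connector side, the counterpart of `HerderConnectorContext` is the `ConnectorContext` handed to the connector. A hedged sketch of a source connector asking for task reconfiguration when it notices an external change; the class name and the monitoring hook are invented, and it assumes a Connect version where `Connector.context()` is available:

```java
import org.apache.kafka.connect.source.SourceConnector;

/**
 * Sketch of the connector side of this contract: when the connector's own monitoring
 * (not shown) detects that the external source changed, it asks the herder, via the
 * context, to recompute task configs; the herder then calls taskConfigs() again.
 */
public abstract class ReconfiguringSourceConnector extends SourceConnector {

    /** Invoked by some monitoring mechanism owned by the connector (not shown). */
    protected void onExternalChange() {
        // Delegates to HerderConnectorContext in the runtime; throws if the connector is shut down.
        context().requestTaskReconfiguration();
    }
}
```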
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.connect.runtime; + +import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.common.record.TimestampType; +import org.apache.kafka.connect.data.Schema; +import org.apache.kafka.connect.header.Header; +import org.apache.kafka.connect.sink.SinkRecord; + +/** + * A specialization of {@link SinkRecord} that allows a {@link WorkerSinkTask} to track the + * original {@link ConsumerRecord} for each {@link SinkRecord}. It is used internally and not + * exposed to connectors. + */ +public class InternalSinkRecord extends SinkRecord { + + private final ConsumerRecord originalRecord; + + public InternalSinkRecord(ConsumerRecord originalRecord, SinkRecord record) { + super(record.topic(), record.kafkaPartition(), record.keySchema(), record.key(), + record.valueSchema(), record.value(), record.kafkaOffset(), record.timestamp(), + record.timestampType(), record.headers()); + this.originalRecord = originalRecord; + } + + protected InternalSinkRecord(ConsumerRecord originalRecord, String topic, + int partition, Schema keySchema, Object key, Schema valueSchema, + Object value, long kafkaOffset, Long timestamp, + TimestampType timestampType, Iterable
<Header> headers) { + super(topic, partition, keySchema, key, valueSchema, value, kafkaOffset, timestamp, timestampType, headers); + this.originalRecord = originalRecord; + } + + @Override + public SinkRecord newRecord(String topic, Integer kafkaPartition, Schema keySchema, Object key, + Schema valueSchema, Object value, Long timestamp, + Iterable<Header>
            headers) { + return new InternalSinkRecord(originalRecord, topic, kafkaPartition, keySchema, key, + valueSchema, value, kafkaOffset(), timestamp, timestampType(), headers()); + } + + @Override + public boolean equals(Object o) { + return super.equals(o); + } + + @Override + public int hashCode() { + return super.hashCode(); + } + + @Override + public String toString() { + return super.toString(); + } + + /** + * Return the original consumer record that this sink record represents. + * + * @return the original consumer record; never null + */ + public ConsumerRecord originalRecord() { + return originalRecord; + } +} + diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/PredicatedTransformation.java b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/PredicatedTransformation.java new file mode 100644 index 0000000..d61772f --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/PredicatedTransformation.java @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.runtime; + +import java.util.Map; + +import org.apache.kafka.common.config.ConfigDef; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.connect.connector.ConnectRecord; +import org.apache.kafka.connect.errors.ConnectException; +import org.apache.kafka.connect.transforms.Transformation; +import org.apache.kafka.connect.transforms.predicates.Predicate; + +/** + * Decorator for a {@link Transformation} which applies the delegate only when a + * {@link Predicate} is true (or false, according to {@code negate}). 
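Referring back to `InternalSinkRecord` above: it pairs the pre-conversion `ConsumerRecord` with the converted `SinkRecord`, so the original bytes stay available (for example, for dead-letter-queue reporting) even after transformations have rewritten the record. A small sketch with invented topic and payload values; in the runtime this wrapping happens inside the sink task, not in user code:

```java
import java.nio.charset.StandardCharsets;

import org.apache.kafka.clients.consumer.ConsumerRecord;
import org.apache.kafka.connect.data.Schema;
import org.apache.kafka.connect.runtime.InternalSinkRecord;
import org.apache.kafka.connect.sink.SinkRecord;

public class InternalSinkRecordSketch {
    public static void main(String[] args) {
        // The raw record as consumed from Kafka (keys/values are bytes before conversion).
        ConsumerRecord<byte[], byte[]> raw = new ConsumerRecord<>(
                "orders", 0, 42L, null, "{\"id\":1}".getBytes(StandardCharsets.UTF_8));

        // The converted record a sink task would hand to transformations and the connector.
        SinkRecord converted = new SinkRecord("orders", 0, null, null,
                Schema.STRING_SCHEMA, "{\"id\":1}", 42L);

        // InternalSinkRecord keeps both views together.
        InternalSinkRecord internal = new InternalSinkRecord(raw, converted);
        System.out.println(internal.originalRecord().topic() + "@" + internal.originalRecord().offset());
    }
}
```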
+ * @param + */ +class PredicatedTransformation> implements Transformation { + + static final String PREDICATE_CONFIG = "predicate"; + static final String NEGATE_CONFIG = "negate"; + Predicate predicate; + Transformation delegate; + boolean negate; + + PredicatedTransformation(Predicate predicate, boolean negate, Transformation delegate) { + this.predicate = predicate; + this.negate = negate; + this.delegate = delegate; + } + + @Override + public void configure(Map configs) { + throw new ConnectException(PredicatedTransformation.class.getName() + ".configure() " + + "should never be called directly."); + } + + @Override + public R apply(R record) { + if (negate ^ predicate.test(record)) { + return delegate.apply(record); + } + return record; + } + + @Override + public ConfigDef config() { + throw new ConnectException(PredicatedTransformation.class.getName() + ".config() " + + "should never be called directly."); + } + + @Override + public void close() { + Utils.closeQuietly(delegate, "predicated"); + Utils.closeQuietly(predicate, "predicate"); + } + + @Override + public String toString() { + return "PredicatedTransformation{" + + "predicate=" + predicate + + ", delegate=" + delegate + + ", negate=" + negate + + '}'; + } +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/RestartPlan.java b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/RestartPlan.java new file mode 100644 index 0000000..a57ce99 --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/RestartPlan.java @@ -0,0 +1,149 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.runtime; + +import java.util.Collection; +import java.util.Collections; +import java.util.Objects; +import java.util.stream.Collectors; + +import org.apache.kafka.connect.connector.Connector; +import org.apache.kafka.connect.connector.Task; +import org.apache.kafka.connect.runtime.rest.entities.ConnectorStateInfo; +import org.apache.kafka.connect.util.ConnectorTaskId; + +/** + * An immutable restart plan per connector. + */ +public class RestartPlan { + + private final RestartRequest request; + private final ConnectorStateInfo stateInfo; + private final Collection idsToRestart; + + /** + * Create a new plan to restart a connector and optionally its tasks. 
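The `apply()` logic in `PredicatedTransformation` reduces to an XOR of the predicate result and the `negate` flag. A standalone sketch of just that decision table, with no Connect classes involved:

```java
public class PredicateNegateSketch {
    static boolean shouldTransform(boolean predicateMatches, boolean negate) {
        return negate ^ predicateMatches;   // same expression as PredicatedTransformation.apply()
    }

    public static void main(String[] args) {
        // matches=true,  negate=false -> transform
        // matches=true,  negate=true  -> pass through unchanged
        // matches=false, negate=false -> pass through unchanged
        // matches=false, negate=true  -> transform
        for (boolean matches : new boolean[]{true, false})
            for (boolean negate : new boolean[]{true, false})
                System.out.printf("matches=%b negate=%b -> transform=%b%n",
                        matches, negate, shouldTransform(matches, negate));
    }
}
```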
+ * + * @param request the restart request; may not be null + * @param restartStateInfo the current state info for the connector; may not be null + */ + public RestartPlan(RestartRequest request, ConnectorStateInfo restartStateInfo) { + this.request = Objects.requireNonNull(request, "RestartRequest name may not be null"); + this.stateInfo = Objects.requireNonNull(restartStateInfo, "ConnectorStateInfo name may not be null"); + // Collect the task IDs to stop and restart (may be none) + this.idsToRestart = Collections.unmodifiableList( + stateInfo.tasks() + .stream() + .filter(this::isRestarting) + .map(taskState -> new ConnectorTaskId(request.connectorName(), taskState.id())) + .collect(Collectors.toList()) + ); + } + + /** + * Get the connector name. + * + * @return the name of the connector; never null + */ + public String connectorName() { + return request.connectorName(); + } + + /** + * Get the original {@link RestartRequest}. + * + * @return the restart request; never null + */ + public RestartRequest restartRequest() { + return request; + } + + /** + * Get the {@link ConnectorStateInfo} that reflects the current state of the connector except with the {@code status} + * set to {@link AbstractStatus.State#RESTARTING} for the {@link Connector} instance and any {@link Task} instances that + * are to be restarted, based upon the {@link #restartRequest() restart request}. + * + * @return the connector state info that reflects the restart plan; never null + */ + public ConnectorStateInfo restartConnectorStateInfo() { + return stateInfo; + } + + /** + * Get the immutable collection of {@link ConnectorTaskId} for all tasks to be restarted + * based upon the {@link #restartRequest() restart request}. + * + * @return the IDs of the tasks to be restarted; never null but possibly empty + */ + public Collection taskIdsToRestart() { + return idsToRestart; + } + + /** + * Determine whether the {@link Connector} instance is to be restarted + * based upon the {@link #restartRequest() restart request}. + * + * @return true if the {@link Connector} instance is to be restarted, or false otherwise + */ + public boolean shouldRestartConnector() { + return isRestarting(stateInfo.connector()); + } + + /** + * Determine whether at least one {@link Task} instance is to be restarted + * based upon the {@link #restartRequest() restart request}. + * + * @return true if any {@link Task} instances are to be restarted, or false if none are to be restarted + */ + public boolean shouldRestartTasks() { + return !taskIdsToRestart().isEmpty(); + } + + /** + * Get the number of connector tasks that are to be restarted + * based upon the {@link #restartRequest() restart request}. + * + * @return the number of {@link Task} instance is to be restarted + */ + public int restartTaskCount() { + return taskIdsToRestart().size(); + } + + /** + * Get the total number of tasks in the connector. + * + * @return the total number of tasks + */ + public int totalTaskCount() { + return stateInfo.tasks().size(); + } + + private boolean isRestarting(ConnectorStateInfo.AbstractState state) { + return isRestarting(state.state()); + } + + private boolean isRestarting(String state) { + return AbstractStatus.State.RESTARTING.toString().equalsIgnoreCase(state); + } + + @Override + public String toString() { + return shouldRestartConnector() + ? 
String.format("plan to restart connector and %d of %d tasks for %s", restartTaskCount(), totalTaskCount(), request) + : String.format("plan to restart %d of %d tasks for %s", restartTaskCount(), totalTaskCount(), request); + } +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/RestartRequest.java b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/RestartRequest.java new file mode 100644 index 0000000..425c639 --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/RestartRequest.java @@ -0,0 +1,144 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.runtime; + +import java.util.Objects; + +import org.apache.kafka.connect.connector.Connector; +import org.apache.kafka.connect.connector.Task; + +/** + * A request to restart a connector and/or task instances. + *

            The natural order is based first upon the connector name and then requested restart behaviors. + * If two requests have the same connector name, then the requests are ordered based on the + * probable number of tasks/connector this request is going to restart. + */ +public class RestartRequest implements Comparable { + + private final String connectorName; + private final boolean onlyFailed; + private final boolean includeTasks; + + /** + * Create a new request to restart a connector and optionally its tasks. + * + * @param connectorName the name of the connector; may not be null + * @param onlyFailed true if only failed instances should be restarted + * @param includeTasks true if tasks should be restarted, or false if only the connector should be restarted + */ + public RestartRequest(String connectorName, boolean onlyFailed, boolean includeTasks) { + this.connectorName = Objects.requireNonNull(connectorName, "Connector name may not be null"); + this.onlyFailed = onlyFailed; + this.includeTasks = includeTasks; + } + + /** + * Get the name of the connector. + * + * @return the connector name; never null + */ + public String connectorName() { + return connectorName; + } + + /** + * Determine whether only failed instances be restarted. + * + * @return true if only failed instances should be restarted, or false if all applicable instances should be restarted + */ + public boolean onlyFailed() { + return onlyFailed; + } + + /** + * Determine whether {@link Task} instances should also be restarted in addition to the {@link Connector} instance. + * + * @return true if the connector and task instances should be restarted, or false if just the connector should be restarted + */ + public boolean includeTasks() { + return includeTasks; + } + + /** + * Determine whether the connector with the given status is to be restarted. + * + * @param status the connector status; may not be null + * @return true if the connector is to be restarted, or false otherwise + */ + public boolean shouldRestartConnector(ConnectorStatus status) { + return !onlyFailed || status.state() == AbstractStatus.State.FAILED; + } + + /** + * Determine whether only the {@link Connector} instance is to be restarted even if not failed. + * + * @return true if only the {@link Connector} instance is to be restarted even if not failed, or false otherwise + */ + public boolean forceRestartConnectorOnly() { + return !onlyFailed() && !includeTasks(); + } + + /** + * Determine whether the task instance with the given status is to be restarted. + * + * @param status the task status; may not be null + * @return true if the task is to be restarted, or false otherwise + */ + public boolean shouldRestartTask(TaskStatus status) { + return includeTasks && (!onlyFailed || status.state() == AbstractStatus.State.FAILED); + } + + @Override + public int compareTo(RestartRequest o) { + int result = connectorName.compareTo(o.connectorName); + return result == 0 ? 
impactRank() - o.impactRank() : result; + } + //calculates an internal rank for the restart request based on the probable number of tasks/connector this request is going to restart + private int impactRank() { + if (onlyFailed && !includeTasks) { //restarts only failed connector so least impactful + return 0; + } else if (onlyFailed && includeTasks) { //restarts only failed connector and tasks + return 1; + } else if (!onlyFailed && !includeTasks) { //restart connector in any state but no tasks + return 2; + } + //onlyFailed==false&&includeTasks restarts both connector and tasks in any state so highest impact + return 3; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + RestartRequest that = (RestartRequest) o; + return onlyFailed == that.onlyFailed && includeTasks == that.includeTasks && Objects.equals(connectorName, that.connectorName); + } + + @Override + public int hashCode() { + return Objects.hash(connectorName, onlyFailed, includeTasks); + } + + @Override + public String toString() { + return "restart request for {" + "connectorName='" + connectorName + "', onlyFailed=" + onlyFailed + ", includeTasks=" + includeTasks + '}'; + } +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/SessionKey.java b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/SessionKey.java new file mode 100644 index 0000000..ab5476e --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/SessionKey.java @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.runtime; + +import javax.crypto.SecretKey; +import java.util.Objects; + +/** + * A session key, which can be used to validate internal REST requests between workers. + */ +public class SessionKey { + + private final SecretKey key; + private final long creationTimestamp; + + /** + * Create a new session key with the given key value and creation timestamp + * @param key the actual cryptographic key to use for request validation; may not be null + * @param creationTimestamp the time at which the key was generated + */ + public SessionKey(SecretKey key, long creationTimestamp) { + this.key = Objects.requireNonNull(key, "Key may not be null"); + this.creationTimestamp = creationTimestamp; + } + + /** + * Get the cryptographic key to use for request validation. + * + * @return the cryptographic key; may not be null + */ + public SecretKey key() { + return key; + } + + /** + * Get the time at which the key was generated. 
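A minimal sketch of the natural ordering that RestartRequest.compareTo and impactRank establish above (connector name first, then the probable impact of the restart); the connector names are placeholders and the class is assumed to be on the classpath:

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

import org.apache.kafka.connect.runtime.RestartRequest;

public class RestartRequestOrderingSketch {
    public static void main(String[] args) {
        List<RestartRequest> requests = new ArrayList<>();
        requests.add(new RestartRequest("sink-b", false, true));  // any state, connector and tasks
        requests.add(new RestartRequest("sink-a", false, false)); // any state, connector only
        requests.add(new RestartRequest("sink-a", true, false));  // failed connector only, least impact
        Collections.sort(requests);
        // Sorted by connector name, then by impact rank:
        // sink-a (onlyFailed, no tasks), sink-a (any state, no tasks), sink-b (any state, with tasks)
        requests.forEach(System.out::println);
    }
}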
+ * + * @return the time at which the key was generated + */ + public long creationTimestamp() { + return creationTimestamp; + } + + @Override + public boolean equals(Object o) { + if (this == o) + return true; + if (o == null || getClass() != o.getClass()) + return false; + SessionKey that = (SessionKey) o; + return creationTimestamp == that.creationTimestamp + && key.equals(that.key); + } + + @Override + public int hashCode() { + return Objects.hash(key, creationTimestamp); + } +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/SinkConnectorConfig.java b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/SinkConnectorConfig.java new file mode 100644 index 0000000..93c2cb4 --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/SinkConnectorConfig.java @@ -0,0 +1,175 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.runtime; + +import org.apache.kafka.common.config.ConfigDef; +import org.apache.kafka.common.config.ConfigDef.Importance; +import org.apache.kafka.common.config.ConfigDef.Type; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.common.config.ConfigException; +import org.apache.kafka.connect.runtime.isolation.Plugins; +import org.apache.kafka.connect.sink.SinkTask; +import org.apache.kafka.connect.transforms.util.RegexValidator; + +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.regex.Pattern; +import java.util.stream.Collectors; + +/** + * Configuration needed for all sink connectors + */ + +public class SinkConnectorConfig extends ConnectorConfig { + + public static final String TOPICS_CONFIG = SinkTask.TOPICS_CONFIG; + private static final String TOPICS_DOC = "List of topics to consume, separated by commas"; + public static final String TOPICS_DEFAULT = ""; + private static final String TOPICS_DISPLAY = "Topics"; + + public static final String TOPICS_REGEX_CONFIG = SinkTask.TOPICS_REGEX_CONFIG; + private static final String TOPICS_REGEX_DOC = "Regular expression giving topics to consume. " + + "Under the hood, the regex is compiled to a java.util.regex.Pattern. " + + "Only one of " + TOPICS_CONFIG + " or " + TOPICS_REGEX_CONFIG + " should be specified."; + public static final String TOPICS_REGEX_DEFAULT = ""; + private static final String TOPICS_REGEX_DISPLAY = "Topics regex"; + + public static final String DLQ_PREFIX = "errors.deadletterqueue."; + + public static final String DLQ_TOPIC_NAME_CONFIG = DLQ_PREFIX + "topic.name"; + public static final String DLQ_TOPIC_NAME_DOC = "The name of the topic to be used as the dead letter queue (DLQ) for messages that " + + "result in an error when processed by this sink connector, or its transformations or converters. 
The topic name is blank by default, " + + "which means that no messages are to be recorded in the DLQ."; + public static final String DLQ_TOPIC_DEFAULT = ""; + private static final String DLQ_TOPIC_DISPLAY = "Dead Letter Queue Topic Name"; + + public static final String DLQ_TOPIC_REPLICATION_FACTOR_CONFIG = DLQ_PREFIX + "topic.replication.factor"; + private static final String DLQ_TOPIC_REPLICATION_FACTOR_CONFIG_DOC = "Replication factor used to create the dead letter queue topic when it doesn't already exist."; + public static final short DLQ_TOPIC_REPLICATION_FACTOR_CONFIG_DEFAULT = 3; + private static final String DLQ_TOPIC_REPLICATION_FACTOR_CONFIG_DISPLAY = "Dead Letter Queue Topic Replication Factor"; + + public static final String DLQ_CONTEXT_HEADERS_ENABLE_CONFIG = DLQ_PREFIX + "context.headers.enable"; + public static final boolean DLQ_CONTEXT_HEADERS_ENABLE_DEFAULT = false; + public static final String DLQ_CONTEXT_HEADERS_ENABLE_DOC = "If true, add headers containing error context to the messages " + + "written to the dead letter queue. To avoid clashing with headers from the original record, all error context header " + + "keys, all error context header keys will start with __connect.errors."; + private static final String DLQ_CONTEXT_HEADERS_ENABLE_DISPLAY = "Enable Error Context Headers"; + + static ConfigDef config = ConnectorConfig.configDef() + .define(TOPICS_CONFIG, ConfigDef.Type.LIST, TOPICS_DEFAULT, ConfigDef.Importance.HIGH, TOPICS_DOC, COMMON_GROUP, 4, ConfigDef.Width.LONG, TOPICS_DISPLAY) + .define(TOPICS_REGEX_CONFIG, ConfigDef.Type.STRING, TOPICS_REGEX_DEFAULT, new RegexValidator(), ConfigDef.Importance.HIGH, TOPICS_REGEX_DOC, COMMON_GROUP, 4, ConfigDef.Width.LONG, TOPICS_REGEX_DISPLAY) + .define(DLQ_TOPIC_NAME_CONFIG, ConfigDef.Type.STRING, DLQ_TOPIC_DEFAULT, Importance.MEDIUM, DLQ_TOPIC_NAME_DOC, ERROR_GROUP, 6, ConfigDef.Width.MEDIUM, DLQ_TOPIC_DISPLAY) + .define(DLQ_TOPIC_REPLICATION_FACTOR_CONFIG, ConfigDef.Type.SHORT, DLQ_TOPIC_REPLICATION_FACTOR_CONFIG_DEFAULT, Importance.MEDIUM, DLQ_TOPIC_REPLICATION_FACTOR_CONFIG_DOC, ERROR_GROUP, 7, ConfigDef.Width.MEDIUM, DLQ_TOPIC_REPLICATION_FACTOR_CONFIG_DISPLAY) + .define(DLQ_CONTEXT_HEADERS_ENABLE_CONFIG, ConfigDef.Type.BOOLEAN, DLQ_CONTEXT_HEADERS_ENABLE_DEFAULT, Importance.MEDIUM, DLQ_CONTEXT_HEADERS_ENABLE_DOC, ERROR_GROUP, 8, ConfigDef.Width.MEDIUM, DLQ_CONTEXT_HEADERS_ENABLE_DISPLAY); + + public static ConfigDef configDef() { + return config; + } + + public SinkConnectorConfig(Plugins plugins, Map props) { + super(plugins, config, props); + } + + /** + * Throw an exception if the passed-in properties do not constitute a valid sink. 
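The validation that follows enforces, among other things, that topics and topics.regex are mutually exclusive and that the DLQ topic is not itself consumed by the connector. A sketch of how a rejected configuration surfaces; the property values are placeholders:

import java.util.HashMap;
import java.util.Map;

import org.apache.kafka.common.config.ConfigException;
import org.apache.kafka.connect.runtime.SinkConnectorConfig;

public class SinkConfigValidationSketch {
    public static void main(String[] args) {
        Map<String, String> props = new HashMap<>();
        props.put("topics", "orders,orders-dlq");
        props.put("errors.deadletterqueue.topic.name", "orders-dlq"); // DLQ topic is also consumed: invalid
        try {
            SinkConnectorConfig.validate(props);
        } catch (ConfigException e) {
            System.out.println("Rejected: " + e.getMessage());
        }
    }
}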
+ * @param props sink configuration properties + */ + public static void validate(Map props) { + final boolean hasTopicsConfig = hasTopicsConfig(props); + final boolean hasTopicsRegexConfig = hasTopicsRegexConfig(props); + final boolean hasDlqTopicConfig = hasDlqTopicConfig(props); + + if (hasTopicsConfig && hasTopicsRegexConfig) { + throw new ConfigException(SinkTask.TOPICS_CONFIG + " and " + SinkTask.TOPICS_REGEX_CONFIG + + " are mutually exclusive options, but both are set."); + } + + if (!hasTopicsConfig && !hasTopicsRegexConfig) { + throw new ConfigException("Must configure one of " + + SinkTask.TOPICS_CONFIG + " or " + SinkTask.TOPICS_REGEX_CONFIG); + } + + if (hasDlqTopicConfig) { + String dlqTopic = props.get(DLQ_TOPIC_NAME_CONFIG).trim(); + if (hasTopicsConfig) { + List topics = parseTopicsList(props); + if (topics.contains(dlqTopic)) { + throw new ConfigException(String.format("The DLQ topic '%s' may not be included in the list of " + + "topics ('%s=%s') consumed by the connector", dlqTopic, SinkTask.TOPICS_REGEX_CONFIG, topics)); + } + } + if (hasTopicsRegexConfig) { + String topicsRegexStr = props.get(SinkTask.TOPICS_REGEX_CONFIG); + Pattern pattern = Pattern.compile(topicsRegexStr); + if (pattern.matcher(dlqTopic).matches()) { + throw new ConfigException(String.format("The DLQ topic '%s' may not be included in the regex matching the " + + "topics ('%s=%s') consumed by the connector", dlqTopic, SinkTask.TOPICS_REGEX_CONFIG, topicsRegexStr)); + } + } + } + } + + public static boolean hasTopicsConfig(Map props) { + String topicsStr = props.get(TOPICS_CONFIG); + return !Utils.isBlank(topicsStr); + } + + public static boolean hasTopicsRegexConfig(Map props) { + String topicsRegexStr = props.get(TOPICS_REGEX_CONFIG); + return !Utils.isBlank(topicsRegexStr); + } + + public static boolean hasDlqTopicConfig(Map props) { + String dqlTopicStr = props.get(DLQ_TOPIC_NAME_CONFIG); + return !Utils.isBlank(dqlTopicStr); + } + + @SuppressWarnings("unchecked") + public static List parseTopicsList(Map props) { + List topics = (List) ConfigDef.parseType(TOPICS_CONFIG, props.get(TOPICS_CONFIG), Type.LIST); + if (topics == null) { + return Collections.emptyList(); + } + return topics + .stream() + .filter(topic -> !topic.isEmpty()) + .distinct() + .collect(Collectors.toList()); + } + + public String dlqTopicName() { + return getString(DLQ_TOPIC_NAME_CONFIG); + } + + public short dlqTopicReplicationFactor() { + return getShort(DLQ_TOPIC_REPLICATION_FACTOR_CONFIG); + } + + public boolean isDlqContextHeadersEnabled() { + return getBoolean(DLQ_CONTEXT_HEADERS_ENABLE_CONFIG); + } + + public boolean enableErrantRecordReporter() { + String dqlTopic = dlqTopicName(); + return !dqlTopic.isEmpty() || enableErrorLog(); + } + + public static void main(String[] args) { + System.out.println(config.toHtml(4, config -> "sinkconnectorconfigs_" + config)); + } +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/SourceConnectorConfig.java b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/SourceConnectorConfig.java new file mode 100644 index 0000000..7cf5d67 --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/SourceConnectorConfig.java @@ -0,0 +1,186 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
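For reference, a sink configuration sketch that exercises the errors.deadletterqueue.* options defined above for SinkConnectorConfig; the connector class and topic names are placeholders:

import java.util.HashMap;
import java.util.Map;

public class DlqPropsSketch {
    public static void main(String[] args) {
        Map<String, String> props = new HashMap<>();
        props.put("name", "example-sink");
        props.put("connector.class", "org.apache.kafka.connect.file.FileStreamSinkConnector");
        props.put("topics", "orders");
        props.put("errors.deadletterqueue.topic.name", "orders-dlq");       // enables the DLQ
        props.put("errors.deadletterqueue.topic.replication.factor", "3");  // matches the documented default
        props.put("errors.deadletterqueue.context.headers.enable", "true"); // adds __connect.errors.* headers
        props.forEach((k, v) -> System.out.println(k + "=" + v));
    }
}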
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.runtime; + +import org.apache.kafka.common.config.AbstractConfig; +import org.apache.kafka.common.config.ConfigDef; +import org.apache.kafka.common.config.ConfigException; +import org.apache.kafka.connect.runtime.isolation.Plugins; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +import static org.apache.kafka.connect.runtime.TopicCreationConfig.DEFAULT_TOPIC_CREATION_GROUP; +import static org.apache.kafka.connect.runtime.TopicCreationConfig.DEFAULT_TOPIC_CREATION_PREFIX; +import static org.apache.kafka.connect.runtime.TopicCreationConfig.EXCLUDE_REGEX_CONFIG; +import static org.apache.kafka.connect.runtime.TopicCreationConfig.INCLUDE_REGEX_CONFIG; +import static org.apache.kafka.connect.runtime.TopicCreationConfig.PARTITIONS_CONFIG; +import static org.apache.kafka.connect.runtime.TopicCreationConfig.REPLICATION_FACTOR_CONFIG; + +public class SourceConnectorConfig extends ConnectorConfig { + + protected static final String TOPIC_CREATION_GROUP = "Topic Creation"; + + public static final String TOPIC_CREATION_PREFIX = "topic.creation."; + + public static final String TOPIC_CREATION_GROUPS_CONFIG = TOPIC_CREATION_PREFIX + "groups"; + private static final String TOPIC_CREATION_GROUPS_DOC = "Groups of configurations for topics " + + "created by source connectors"; + private static final String TOPIC_CREATION_GROUPS_DISPLAY = "Topic Creation Groups"; + + private static class EnrichedSourceConnectorConfig extends ConnectorConfig { + EnrichedSourceConnectorConfig(Plugins plugins, ConfigDef configDef, Map props) { + super(plugins, configDef, props); + } + + @Override + public Object get(String key) { + return super.get(key); + } + } + + private static ConfigDef config = SourceConnectorConfig.configDef(); + private final EnrichedSourceConnectorConfig enrichedSourceConfig; + + public static ConfigDef configDef() { + int orderInGroup = 0; + return new ConfigDef(ConnectorConfig.configDef()) + .define(TOPIC_CREATION_GROUPS_CONFIG, ConfigDef.Type.LIST, Collections.emptyList(), + ConfigDef.CompositeValidator.of(new ConfigDef.NonNullValidator(), ConfigDef.LambdaValidator.with( + (name, value) -> { + List groupAliases = (List) value; + if (groupAliases.size() > new HashSet<>(groupAliases).size()) { + throw new ConfigException(name, value, "Duplicate alias provided."); + } + }, + () -> "unique topic creation groups")), + ConfigDef.Importance.LOW, TOPIC_CREATION_GROUPS_DOC, TOPIC_CREATION_GROUP, + ++orderInGroup, ConfigDef.Width.LONG, TOPIC_CREATION_GROUPS_DISPLAY); + } + + public static ConfigDef embedDefaultGroup(ConfigDef baseConfigDef) { + String defaultGroup = "default"; + ConfigDef newDefaultDef = new ConfigDef(baseConfigDef); + newDefaultDef.embed(DEFAULT_TOPIC_CREATION_PREFIX, defaultGroup, 0, 
TopicCreationConfig.defaultGroupConfigDef()); + return newDefaultDef; + } + + /** + * Returns an enriched {@link ConfigDef} building upon the {@code ConfigDef}, using the current configuration specified in {@code props} as an input. + * + * @param baseConfigDef the base configuration definition to be enriched + * @param props the non parsed configuration properties + * @return the enriched configuration definition + */ + public static ConfigDef enrich(ConfigDef baseConfigDef, Map props, AbstractConfig defaultGroupConfig) { + List topicCreationGroups = new ArrayList<>(); + Object aliases = ConfigDef.parseType(TOPIC_CREATION_GROUPS_CONFIG, props.get(TOPIC_CREATION_GROUPS_CONFIG), ConfigDef.Type.LIST); + if (aliases instanceof List) { + topicCreationGroups.addAll((List) aliases); + } + + ConfigDef newDef = new ConfigDef(baseConfigDef); + String defaultGroupPrefix = TOPIC_CREATION_PREFIX + DEFAULT_TOPIC_CREATION_GROUP + "."; + short defaultGroupReplicationFactor = defaultGroupConfig.getShort(defaultGroupPrefix + REPLICATION_FACTOR_CONFIG); + int defaultGroupPartitions = defaultGroupConfig.getInt(defaultGroupPrefix + PARTITIONS_CONFIG); + topicCreationGroups.stream().distinct().forEach(group -> { + if (!(group instanceof String)) { + throw new ConfigException("Item in " + TOPIC_CREATION_GROUPS_CONFIG + " property is not of type String"); + } + String alias = (String) group; + String prefix = TOPIC_CREATION_PREFIX + alias + "."; + String configGroup = TOPIC_CREATION_GROUP + ": " + alias; + newDef.embed(prefix, configGroup, 0, + TopicCreationConfig.configDef(configGroup, defaultGroupReplicationFactor, defaultGroupPartitions)); + }); + return newDef; + } + + public SourceConnectorConfig(Plugins plugins, Map props, boolean createTopics) { + super(plugins, config, props); + if (createTopics && props.entrySet().stream().anyMatch(e -> e.getKey().startsWith(TOPIC_CREATION_PREFIX))) { + ConfigDef defaultConfigDef = embedDefaultGroup(config); + // This config is only used to set default values for partitions and replication + // factor from the default group and otherwise it remains unused + AbstractConfig defaultGroup = new AbstractConfig(defaultConfigDef, props, false); + + // If the user has added regex of include or exclude patterns in the default group, + // they should be ignored. + Map propsWithoutRegexForDefaultGroup = new HashMap<>(props); + propsWithoutRegexForDefaultGroup.entrySet() + .removeIf(e -> e.getKey().equals(DEFAULT_TOPIC_CREATION_PREFIX + INCLUDE_REGEX_CONFIG) + || e.getKey().equals(DEFAULT_TOPIC_CREATION_PREFIX + EXCLUDE_REGEX_CONFIG)); + enrichedSourceConfig = new EnrichedSourceConnectorConfig(plugins, + enrich(defaultConfigDef, props, defaultGroup), + propsWithoutRegexForDefaultGroup); + } else { + enrichedSourceConfig = null; + } + } + + @Override + public Object get(String key) { + return enrichedSourceConfig != null ? enrichedSourceConfig.get(key) : super.get(key); + } + + /** + * Returns whether this configuration uses topic creation properties. + * + * @return true if the configuration should be validated and used for topic creation; false otherwise + */ + public boolean usesTopicCreation() { + return enrichedSourceConfig != null; + } + + public List topicCreationInclude(String group) { + return getList(TOPIC_CREATION_PREFIX + group + '.' + INCLUDE_REGEX_CONFIG); + } + + public List topicCreationExclude(String group) { + return getList(TOPIC_CREATION_PREFIX + group + '.' 
+ EXCLUDE_REGEX_CONFIG); + } + + public Short topicCreationReplicationFactor(String group) { + return getShort(TOPIC_CREATION_PREFIX + group + '.' + REPLICATION_FACTOR_CONFIG); + } + + public Integer topicCreationPartitions(String group) { + return getInt(TOPIC_CREATION_PREFIX + group + '.' + PARTITIONS_CONFIG); + } + + public Map topicCreationOtherConfigs(String group) { + if (enrichedSourceConfig == null) { + return Collections.emptyMap(); + } + return enrichedSourceConfig.originalsWithPrefix(TOPIC_CREATION_PREFIX + group + '.').entrySet().stream() + .filter(e -> { + String key = e.getKey(); + return !(INCLUDE_REGEX_CONFIG.equals(key) || EXCLUDE_REGEX_CONFIG.equals(key) + || REPLICATION_FACTOR_CONFIG.equals(key) || PARTITIONS_CONFIG.equals(key)); + }) + .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); + } + + public static void main(String[] args) { + System.out.println(config.toHtml(4, config -> "sourceconnectorconfigs_" + config)); + } +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/SourceTaskOffsetCommitter.java b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/SourceTaskOffsetCommitter.java new file mode 100644 index 0000000..c3416be --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/SourceTaskOffsetCommitter.java @@ -0,0 +1,121 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.runtime; + +import org.apache.kafka.connect.errors.ConnectException; +import org.apache.kafka.connect.util.ConnectorTaskId; +import org.apache.kafka.connect.util.LoggingContext; +import org.apache.kafka.common.utils.ThreadUtils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.concurrent.CancellationException; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Executors; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.ScheduledFuture; +import java.util.concurrent.TimeUnit; + +/** + *

+ * <p>
+ * Manages offset commit scheduling and execution for SourceTasks.
+ * </p>
+ * <p>
+ * Unlike sink tasks which directly manage their offset commits in the main poll() thread since
+ * they drive the event loop and control (for all intents and purposes) the timeouts, source
+ * tasks are at the whim of the connector and cannot be guaranteed to wake up on the necessary
+ * schedule. Instead, this class tracks all the active tasks, their schedule for commits, and
+ * ensures they are invoked in a timely fashion.
+ * </p>
            + */ +class SourceTaskOffsetCommitter { + private static final Logger log = LoggerFactory.getLogger(SourceTaskOffsetCommitter.class); + + private final WorkerConfig config; + private final ScheduledExecutorService commitExecutorService; + private final ConcurrentMap> committers; + + // visible for testing + SourceTaskOffsetCommitter(WorkerConfig config, + ScheduledExecutorService commitExecutorService, + ConcurrentMap> committers) { + this.config = config; + this.commitExecutorService = commitExecutorService; + this.committers = committers; + } + + public SourceTaskOffsetCommitter(WorkerConfig config) { + this(config, Executors.newSingleThreadScheduledExecutor(ThreadUtils.createThreadFactory( + SourceTaskOffsetCommitter.class.getSimpleName() + "-%d", false)), + new ConcurrentHashMap<>()); + } + + public void close(long timeoutMs) { + commitExecutorService.shutdown(); + try { + if (!commitExecutorService.awaitTermination(timeoutMs, TimeUnit.MILLISECONDS)) { + log.error("Graceful shutdown of offset commitOffsets thread timed out."); + } + } catch (InterruptedException e) { + // ignore and allow to exit immediately + } + } + + public void schedule(final ConnectorTaskId id, final WorkerSourceTask workerTask) { + long commitIntervalMs = config.getLong(WorkerConfig.OFFSET_COMMIT_INTERVAL_MS_CONFIG); + ScheduledFuture commitFuture = commitExecutorService.scheduleWithFixedDelay(() -> { + try (LoggingContext loggingContext = LoggingContext.forOffsets(id)) { + commit(workerTask); + } + }, commitIntervalMs, commitIntervalMs, TimeUnit.MILLISECONDS); + committers.put(id, commitFuture); + } + + public void remove(ConnectorTaskId id) { + final ScheduledFuture task = committers.remove(id); + if (task == null) + return; + + try (LoggingContext loggingContext = LoggingContext.forTask(id)) { + task.cancel(false); + if (!task.isDone()) + task.get(); + } catch (CancellationException e) { + // ignore + log.trace("Offset commit thread was cancelled by another thread while removing connector task with id: {}", id); + } catch (ExecutionException | InterruptedException e) { + throw new ConnectException("Unexpected interruption in SourceTaskOffsetCommitter while removing task with id: " + id, e); + } + } + + private void commit(WorkerSourceTask workerTask) { + log.debug("{} Committing offsets", workerTask); + try { + if (workerTask.commitOffsets()) { + return; + } + log.error("{} Failed to commit offsets", workerTask); + } catch (Throwable t) { + // We're very careful about exceptions here since any uncaught exceptions in the commit + // thread would cause the fixed interval schedule on the ExecutorService to stop running + // for that task + log.error("{} Unhandled exception when committing: ", workerTask, t); + } + } +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/StateTracker.java b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/StateTracker.java new file mode 100644 index 0000000..297d473 --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/StateTracker.java @@ -0,0 +1,182 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
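The committer above boils down to scheduleWithFixedDelay on a single-threaded executor, keyed by task id. A reduced, self-contained sketch of that pattern; the interval and log message are made up:

import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.TimeUnit;

public class PeriodicCommitSketch {
    public static void main(String[] args) throws InterruptedException {
        ScheduledExecutorService executor = Executors.newSingleThreadScheduledExecutor();
        long commitIntervalMs = 1000L; // stands in for offset.flush.interval.ms
        ScheduledFuture<?> committer = executor.scheduleWithFixedDelay(
                () -> System.out.println("committing offsets..."),
                commitIntervalMs, commitIntervalMs, TimeUnit.MILLISECONDS);
        Thread.sleep(3500);      // let a few commits run
        committer.cancel(false); // mirrors SourceTaskOffsetCommitter.remove()
        executor.shutdown();
        executor.awaitTermination(5, TimeUnit.SECONDS);
    }
}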
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.runtime; + +import org.apache.kafka.connect.runtime.AbstractStatus.State; + +import java.util.concurrent.atomic.AtomicReference; + +/** + * Utility class that tracks the current state and the duration of time spent in each state. + * This class is threadsafe. + */ +public class StateTracker { + + private final AtomicReference lastState = new AtomicReference<>(new StateChange()); + + /** + * Change the current state. + *
<p>
            + * This method is synchronized to ensure that all state changes are captured correctly and in the same order. + * Synchronization is acceptable since it is assumed that state changes will be relatively infrequent. + * + * @param newState the current state; may not be null + * @param now the current time in milliseconds + */ + public synchronized void changeState(State newState, long now) { + // JDK8: remove synchronization by using lastState.getAndUpdate(oldState->oldState.newState(newState, now)); + lastState.set(lastState.get().newState(newState, now)); + } + + /** + * Calculate the ratio of time spent in the specified state. + * + * @param ratioState the state for which the ratio is to be calculated; may not be null + * @param now the current time in milliseconds + * @return the ratio of time spent in the specified state to the time spent in all states + */ + public double durationRatio(State ratioState, long now) { + return lastState.get().durationRatio(ratioState, now); + } + + /** + * Get the current state. + * + * @return the current state; may be null if no state change has been recorded + */ + public State currentState() { + return lastState.get().state; + } + + /** + * An immutable record of the accumulated times at the most recent state change. This class is required to + * efficiently make {@link StateTracker} threadsafe. + */ + private static final class StateChange { + + private final State state; + private final long startTime; + private final long unassignedTotalTimeMs; + private final long runningTotalTimeMs; + private final long pausedTotalTimeMs; + private final long failedTotalTimeMs; + private final long destroyedTotalTimeMs; + private final long restartingTotalTimeMs; + + /** + * The initial StateChange instance before any state has changed. + */ + StateChange() { + this(null, 0L, 0L, 0L, 0L, 0L, 0L, 0L); + } + + StateChange(State state, long startTime, long unassignedTotalTimeMs, long runningTotalTimeMs, + long pausedTotalTimeMs, long failedTotalTimeMs, long destroyedTotalTimeMs, long restartingTotalTimeMs) { + this.state = state; + this.startTime = startTime; + this.unassignedTotalTimeMs = unassignedTotalTimeMs; + this.runningTotalTimeMs = runningTotalTimeMs; + this.pausedTotalTimeMs = pausedTotalTimeMs; + this.failedTotalTimeMs = failedTotalTimeMs; + this.destroyedTotalTimeMs = destroyedTotalTimeMs; + this.restartingTotalTimeMs = restartingTotalTimeMs; + } + + /** + * Return a new StateChange that includes the accumulated times of this state plus the time spent in the + * current state. + * + * @param state the new state; may not be null + * @param now the time at which the state transition occurs. 
+ * @return the new StateChange, though may be this instance of the state did not actually change; never null + */ + public StateChange newState(State state, long now) { + if (this.state == null) { + return new StateChange(state, now, 0L, 0L, 0L, 0L, 0L, 0L); + } + if (state == this.state) { + return this; + } + long unassignedTime = this.unassignedTotalTimeMs; + long runningTime = this.runningTotalTimeMs; + long pausedTime = this.pausedTotalTimeMs; + long failedTime = this.failedTotalTimeMs; + long destroyedTime = this.destroyedTotalTimeMs; + long restartingTime = this.restartingTotalTimeMs; + long duration = now - startTime; + switch (this.state) { + case UNASSIGNED: + unassignedTime += duration; + break; + case RUNNING: + runningTime += duration; + break; + case PAUSED: + pausedTime += duration; + break; + case FAILED: + failedTime += duration; + break; + case DESTROYED: + destroyedTime += duration; + break; + case RESTARTING: + restartingTime += duration; + break; + } + return new StateChange(state, now, unassignedTime, runningTime, pausedTime, failedTime, destroyedTime, restartingTime); + } + + /** + * Calculate the ratio of time spent in the specified state. + * + * @param ratioState the state for which the ratio is to be calculated; may not be null + * @param now the current time in milliseconds + * @return the ratio of time spent in the specified state to the time spent in all states + */ + public double durationRatio(State ratioState, long now) { + if (state == null) { + return 0.0d; + } + long durationCurrent = now - startTime; // since last state change + long durationDesired = ratioState == state ? durationCurrent : 0L; + switch (ratioState) { + case UNASSIGNED: + durationDesired += unassignedTotalTimeMs; + break; + case RUNNING: + durationDesired += runningTotalTimeMs; + break; + case PAUSED: + durationDesired += pausedTotalTimeMs; + break; + case FAILED: + durationDesired += failedTotalTimeMs; + break; + case DESTROYED: + durationDesired += destroyedTotalTimeMs; + break; + case RESTARTING: + durationDesired += restartingTotalTimeMs; + break; + } + long total = durationCurrent + unassignedTotalTimeMs + runningTotalTimeMs + pausedTotalTimeMs + + failedTotalTimeMs + destroyedTotalTimeMs + restartingTotalTimeMs; + return total == 0.0d ? 0.0d : (double) durationDesired / total; + } + } +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/SubmittedRecords.java b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/SubmittedRecords.java new file mode 100644 index 0000000..6cdd2c1 --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/SubmittedRecords.java @@ -0,0 +1,340 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
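A small usage sketch of the StateTracker defined above; the timestamps are arbitrary milliseconds chosen for the illustration, and AbstractStatus.State is assumed to be available from elsewhere in this patch:

import org.apache.kafka.connect.runtime.AbstractStatus.State;
import org.apache.kafka.connect.runtime.StateTracker;

public class StateTrackerSketch {
    public static void main(String[] args) {
        StateTracker tracker = new StateTracker();
        tracker.changeState(State.RUNNING, 0L);
        tracker.changeState(State.PAUSED, 800L); // RUNNING for the first 800 ms
        // At t = 1000 ms: 800 ms RUNNING, 200 ms PAUSED
        System.out.println(tracker.durationRatio(State.RUNNING, 1000L)); // 0.8
        System.out.println(tracker.durationRatio(State.PAUSED, 1000L));  // 0.2
    }
}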
+ */ +package org.apache.kafka.connect.runtime; + +import org.apache.kafka.connect.source.SourceRecord; +import org.apache.kafka.connect.source.SourceTask; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Collections; +import java.util.Deque; +import java.util.HashMap; +import java.util.LinkedList; +import java.util.Map; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; + +/** + * Used to track source records that have been (or are about to be) dispatched to a producer and their accompanying + * source offsets. Records are tracked in the order in which they are submitted, which should match the order they were + * returned from {@link SourceTask#poll()}. The latest-eligible offsets for each source partition can be retrieved via + * {@link #committableOffsets()}, where every record up to and including the record for each returned offset has been + * either {@link SubmittedRecord#ack() acknowledged} or {@link #removeLastOccurrence(SubmittedRecord) removed}. + * Note that this class is not thread-safe, though a {@link SubmittedRecord} can be + * {@link SubmittedRecord#ack() acknowledged} from a different thread. + */ +class SubmittedRecords { + + private static final Logger log = LoggerFactory.getLogger(SubmittedRecords.class); + + // Visible for testing + final Map, Deque> records = new HashMap<>(); + private int numUnackedMessages = 0; + private CountDownLatch messageDrainLatch; + + public SubmittedRecords() { + } + + /** + * Enqueue a new source record before dispatching it to a producer. + * The returned {@link SubmittedRecord} should either be {@link SubmittedRecord#ack() acknowledged} in the + * producer callback, or {@link #removeLastOccurrence(SubmittedRecord) removed} if the record could not be successfully + * sent to the producer. + * + * @param record the record about to be dispatched; may not be null but may have a null + * {@link SourceRecord#sourcePartition()} and/or {@link SourceRecord#sourceOffset()} + * @return a {@link SubmittedRecord} that can be either {@link SubmittedRecord#ack() acknowledged} once ack'd by + * the producer, or {@link #removeLastOccurrence removed} if synchronously rejected by the producer + */ + @SuppressWarnings("unchecked") + public SubmittedRecord submit(SourceRecord record) { + return submit((Map) record.sourcePartition(), (Map) record.sourceOffset()); + } + + // Convenience method for testing + SubmittedRecord submit(Map partition, Map offset) { + SubmittedRecord result = new SubmittedRecord(partition, offset); + records.computeIfAbsent(result.partition(), p -> new LinkedList<>()) + .add(result); + synchronized (this) { + numUnackedMessages++; + } + return result; + } + + /** + * Remove a source record and do not take it into account any longer when tracking offsets. + * Useful if the record has been synchronously rejected by the producer. + * If multiple instances of the same {@link SubmittedRecord} have been submitted already, only the first one found + * (traversing from the end of the deque backward) will be removed. 
+ * @param record the {@link #submit previously-submitted} record to stop tracking; may not be null + * @return whether an instance of the record was removed + */ + public boolean removeLastOccurrence(SubmittedRecord record) { + Deque deque = records.get(record.partition()); + if (deque == null) { + log.warn("Attempted to remove record from submitted queue for partition {}, but no records with that partition appear to have been submitted", record.partition()); + return false; + } + boolean result = deque.removeLastOccurrence(record); + if (deque.isEmpty()) { + records.remove(record.partition()); + } + if (result) { + messageAcked(); + } else { + log.warn("Attempted to remove record from submitted queue for partition {}, but the record has not been submitted or has already been removed", record.partition()); + } + return result; + } + + /** + * Clear out any acknowledged records at the head of the deques and return a {@link CommittableOffsets snapshot} of the offsets and offset metadata + * accrued between the last time this method was invoked and now. This snapshot can be {@link CommittableOffsets#updatedWith(CommittableOffsets) combined} + * with an existing snapshot if desired. + * Note that this may take some time to complete if a large number of records has built up, which may occur if a + * Kafka partition is offline and all records targeting that partition go unacknowledged while records targeting + * other partitions continue to be dispatched to the producer and sent successfully + * @return a fresh offset snapshot; never null + */ + public CommittableOffsets committableOffsets() { + Map, Map> offsets = new HashMap<>(); + int totalCommittableMessages = 0; + int totalUncommittableMessages = 0; + int largestDequeSize = 0; + Map largestDequePartition = null; + for (Map.Entry, Deque> entry : records.entrySet()) { + Map partition = entry.getKey(); + Deque queuedRecords = entry.getValue(); + int initialDequeSize = queuedRecords.size(); + if (canCommitHead(queuedRecords)) { + Map offset = committableOffset(queuedRecords); + offsets.put(partition, offset); + } + int uncommittableMessages = queuedRecords.size(); + int committableMessages = initialDequeSize - uncommittableMessages; + totalCommittableMessages += committableMessages; + totalUncommittableMessages += uncommittableMessages; + if (uncommittableMessages > largestDequeSize) { + largestDequeSize = uncommittableMessages; + largestDequePartition = partition; + } + } + // Clear out all empty deques from the map to keep it from growing indefinitely + records.values().removeIf(Deque::isEmpty); + return new CommittableOffsets(offsets, totalCommittableMessages, totalUncommittableMessages, records.size(), largestDequeSize, largestDequePartition); + } + + /** + * Wait for all currently in-flight messages to be acknowledged, up to the requested timeout. + * This method is expected to be called from the same thread that calls {@link #committableOffsets()}. 
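A sketch of the acknowledgement/commit interplay described above. SubmittedRecords is package-private, so the example assumes it lives in the same package; the partition and offset maps are placeholders:

package org.apache.kafka.connect.runtime;

import java.util.Collections;
import java.util.Map;

public class SubmittedRecordsSketch {
    public static void main(String[] args) {
        SubmittedRecords submitted = new SubmittedRecords();
        Map<String, Object> partition = Collections.singletonMap("file", "orders.txt");
        SubmittedRecords.SubmittedRecord first =
                submitted.submit(partition, Collections.<String, Object>singletonMap("line", 1));
        SubmittedRecords.SubmittedRecord second =
                submitted.submit(partition, Collections.<String, Object>singletonMap("line", 2));

        second.ack(); // acked out of order: nothing is committable yet
        System.out.println(submitted.committableOffsets().offsets()); // {}

        first.ack();  // everything up to offset {line=2} is now committable
        System.out.println(submitted.committableOffsets().offsets()); // {{file=orders.txt}={line=2}}
    }
}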
+ * @param timeout the maximum time to wait + * @param timeUnit the time unit of the timeout argument + * @return whether all in-flight messages were acknowledged before the timeout elapsed + */ + public boolean awaitAllMessages(long timeout, TimeUnit timeUnit) { + // Create a new message drain latch as a local variable to avoid SpotBugs warnings about inconsistent synchronization + // on an instance variable when invoking CountDownLatch::await outside a synchronized block + CountDownLatch messageDrainLatch; + synchronized (this) { + messageDrainLatch = new CountDownLatch(numUnackedMessages); + this.messageDrainLatch = messageDrainLatch; + } + try { + return messageDrainLatch.await(timeout, timeUnit); + } catch (InterruptedException e) { + return false; + } + } + + // Note that this will return null if either there are no committable offsets for the given deque, or the latest + // committable offset is itself null. The caller is responsible for distinguishing between the two cases. + private Map committableOffset(Deque queuedRecords) { + Map result = null; + while (canCommitHead(queuedRecords)) { + result = queuedRecords.poll().offset(); + } + return result; + } + + private boolean canCommitHead(Deque queuedRecords) { + return queuedRecords.peek() != null && queuedRecords.peek().acked(); + } + + // Synchronize in order to ensure that the number of unacknowledged messages isn't modified in the middle of a call + // to awaitAllMessages (which might cause us to decrement first, then create a new message drain latch, then count down + // that latch here, effectively double-acking the message) + private synchronized void messageAcked() { + numUnackedMessages--; + if (messageDrainLatch != null) { + messageDrainLatch.countDown(); + } + } + + class SubmittedRecord { + private final Map partition; + private final Map offset; + private final AtomicBoolean acked; + + public SubmittedRecord(Map partition, Map offset) { + this.partition = partition; + this.offset = offset; + this.acked = new AtomicBoolean(false); + } + + /** + * Acknowledge this record; signals that its offset may be safely committed. + * This is safe to be called from a different thread than what called {@link SubmittedRecords#submit(SourceRecord)}. + */ + public void ack() { + if (this.acked.compareAndSet(false, true)) { + messageAcked(); + } + } + + private boolean acked() { + return acked.get(); + } + + private Map partition() { + return partition; + } + + private Map offset() { + return offset; + } + } + + /** + * Contains a snapshot of offsets that can be committed for a source task and metadata for that offset commit + * (such as the number of messages for which offsets can and cannot be committed). + */ + static class CommittableOffsets { + + /** + * An "empty" snapshot that contains no offsets to commit and whose metadata contains no committable or uncommitable messages. + */ + public static final CommittableOffsets EMPTY = new CommittableOffsets(Collections.emptyMap(), 0, 0, 0, 0, null); + + private final Map, Map> offsets; + private final int numCommittableMessages; + private final int numUncommittableMessages; + private final int numDeques; + private final int largestDequeSize; + private final Map largestDequePartition; + + CommittableOffsets( + Map, Map> offsets, + int numCommittableMessages, + int numUncommittableMessages, + int numDeques, + int largestDequeSize, + Map largestDequePartition + ) { + this.offsets = offsets != null ? 
new HashMap<>(offsets) : Collections.emptyMap(); + this.numCommittableMessages = numCommittableMessages; + this.numUncommittableMessages = numUncommittableMessages; + this.numDeques = numDeques; + this.largestDequeSize = largestDequeSize; + this.largestDequePartition = largestDequePartition; + } + + /** + * @return the offsets that can be committed at the time of the snapshot + */ + public Map, Map> offsets() { + return Collections.unmodifiableMap(offsets); + } + + /** + * @return the number of committable messages at the time of the snapshot, where a committable message is both + * acknowledged and not preceded by any unacknowledged messages in the deque for its source partition + */ + public int numCommittableMessages() { + return numCommittableMessages; + } + + /** + * @return the number of uncommittable messages at the time of the snapshot, where an uncommittable message + * is either unacknowledged, or preceded in the deque for its source partition by an unacknowledged message + */ + public int numUncommittableMessages() { + return numUncommittableMessages; + } + + /** + * @return the number of non-empty deques tracking uncommittable messages at the time of the snapshot + */ + public int numDeques() { + return numDeques; + } + + /** + * @return the size of the largest deque at the time of the snapshot + */ + public int largestDequeSize() { + return largestDequeSize; + } + + /** + * Get the partition for the deque with the most uncommitted messages at the time of the snapshot. + * @return the applicable partition, which may be null, or null if there are no uncommitted messages; + * it is the caller's responsibility to distinguish between these two cases via {@link #hasPending()} + */ + public Map largestDequePartition() { + return largestDequePartition; + } + + /** + * @return whether there were any uncommittable messages at the time of the snapshot + */ + public boolean hasPending() { + return numUncommittableMessages > 0; + } + + /** + * @return whether there were any committable or uncommittable messages at the time of the snapshot + */ + public boolean isEmpty() { + return numCommittableMessages == 0 && numUncommittableMessages == 0 && offsets.isEmpty(); + } + + /** + * Create a new snapshot by combining the data for this snapshot with newer data in a more recent snapshot. + * Offsets are combined (giving precedence to the newer snapshot in case of conflict), the total number of + * committable messages is summed across the two snapshots, and the newer snapshot's information on pending + * messages (num deques, largest deque size, etc.) is used. 
+ * @param newerOffsets the newer snapshot to combine with this snapshot + * @return the new offset snapshot containing information from this snapshot and the newer snapshot; never null + */ + public CommittableOffsets updatedWith(CommittableOffsets newerOffsets) { + Map, Map> offsets = new HashMap<>(this.offsets); + offsets.putAll(newerOffsets.offsets); + + return new CommittableOffsets( + offsets, + this.numCommittableMessages + newerOffsets.numCommittableMessages, + newerOffsets.numUncommittableMessages, + newerOffsets.numDeques, + newerOffsets.largestDequeSize, + newerOffsets.largestDequePartition + ); + } + } +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/TargetState.java b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/TargetState.java new file mode 100644 index 0000000..eb25b3d --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/TargetState.java @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.runtime; + +/** + * The target state of a connector is its desired state as indicated by the user + * through interaction with the REST API. When a connector is first created, its + * target state is "STARTED." This does not mean it has actually started, just that + * the Connect framework will attempt to start it after its tasks have been assigned. + * After the connector has been paused, the target state will change to PAUSED, + * and all the tasks will stop doing work. + * + * Target states are persisted in the config topic, which is read by all of the + * workers in the group. When a worker sees a new target state for a connector which + * is running, it will transition any tasks which it owns (i.e. which have been + * assigned to it by the leader) into the desired target state. Upon completion of + * a task rebalance, the worker will start the task in the last known target state. + */ +public enum TargetState { + STARTED, + PAUSED, +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/TaskConfig.java b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/TaskConfig.java new file mode 100644 index 0000000..649bc00 --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/TaskConfig.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.runtime; + +import org.apache.kafka.common.config.AbstractConfig; +import org.apache.kafka.common.config.ConfigDef; +import org.apache.kafka.common.config.ConfigDef.Importance; +import org.apache.kafka.common.config.ConfigDef.Type; + +import java.util.HashMap; +import java.util.Map; + +/** + *

+ * <p>
+ * Configuration options for Tasks. These only include Kafka Connect system-level configuration
+ * options.
+ * </p>
            + */ +public class TaskConfig extends AbstractConfig { + + public static final String TASK_CLASS_CONFIG = "task.class"; + private static final String TASK_CLASS_DOC = + "Name of the class for this task. Must be a subclass of org.apache.kafka.connect.connector.Task"; + + private static ConfigDef config; + + static { + config = new ConfigDef() + .define(TASK_CLASS_CONFIG, Type.CLASS, Importance.HIGH, TASK_CLASS_DOC); + } + + public TaskConfig() { + this(new HashMap()); + } + + public TaskConfig(Map props) { + super(config, props, true); + } +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/TaskStatus.java b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/TaskStatus.java new file mode 100644 index 0000000..e35efca --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/TaskStatus.java @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.runtime; + +import org.apache.kafka.connect.util.ConnectorTaskId; + +public class TaskStatus extends AbstractStatus { + + public TaskStatus(ConnectorTaskId id, State state, String workerUrl, int generation, String trace) { + super(id, state, workerUrl, generation, trace); + } + + public TaskStatus(ConnectorTaskId id, State state, String workerUrl, int generation) { + super(id, state, workerUrl, generation, null); + } + + public interface Listener { + + /** + * Invoked after successful startup of the task. + * @param id The id of the task + */ + void onStartup(ConnectorTaskId id); + + /** + * Invoked after the task has been paused. + * @param id The id of the task + */ + void onPause(ConnectorTaskId id); + + /** + * Invoked after the task has been resumed. + * @param id The id of the task + */ + void onResume(ConnectorTaskId id); + + /** + * Invoked if the task raises an error. No shutdown event will follow. + * @param id The id of the task + * @param cause The error raised by the task. + */ + void onFailure(ConnectorTaskId id, Throwable cause); + + /** + * Invoked after successful shutdown of the task. + * @param id The id of the task + */ + void onShutdown(ConnectorTaskId id); + + /** + * Invoked after the task has been deleted. Can be called if the + * connector tasks have been reduced, or if the connector itself has + * been deleted. + * @param id The id of the task + */ + void onDeletion(ConnectorTaskId id); + + /** + * Invoked when the task is restarted asynchronously by the herder on processing a restart request. 
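As a usage sketch, the TaskStatus.Listener contract above can be satisfied by a trivial logging implementation; the class name is made up:

import org.apache.kafka.connect.runtime.TaskStatus;
import org.apache.kafka.connect.util.ConnectorTaskId;

public class LoggingTaskListener implements TaskStatus.Listener {
    @Override public void onStartup(ConnectorTaskId id)                  { System.out.println(id + " started"); }
    @Override public void onPause(ConnectorTaskId id)                    { System.out.println(id + " paused"); }
    @Override public void onResume(ConnectorTaskId id)                   { System.out.println(id + " resumed"); }
    @Override public void onFailure(ConnectorTaskId id, Throwable cause) { System.out.println(id + " failed: " + cause); }
    @Override public void onShutdown(ConnectorTaskId id)                 { System.out.println(id + " shut down"); }
    @Override public void onDeletion(ConnectorTaskId id)                 { System.out.println(id + " deleted"); }
    @Override public void onRestart(ConnectorTaskId id)                  { System.out.println(id + " restarted"); }
}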
+ * @param id The id of the task + */ + void onRestart(ConnectorTaskId id); + } +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/TopicCreationConfig.java b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/TopicCreationConfig.java new file mode 100644 index 0000000..e5c5f15 --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/TopicCreationConfig.java @@ -0,0 +1,142 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.runtime; + +import org.apache.kafka.common.config.ConfigDef; +import org.apache.kafka.common.config.ConfigException; +import org.apache.kafka.connect.util.TopicAdmin; + +import java.util.Collections; +import java.util.List; +import java.util.regex.Pattern; +import java.util.regex.PatternSyntaxException; + +public class TopicCreationConfig { + + public static final String DEFAULT_TOPIC_CREATION_PREFIX = "topic.creation.default."; + public static final String DEFAULT_TOPIC_CREATION_GROUP = "default"; + + public static final String INCLUDE_REGEX_CONFIG = "include"; + private static final String INCLUDE_REGEX_DOC = "A list of regular expression literals " + + "used to match the topic names used by the source connector. This list is used " + + "to include topics that should be created using the topic settings defined by this group."; + + public static final String EXCLUDE_REGEX_CONFIG = "exclude"; + private static final String EXCLUDE_REGEX_DOC = "A list of regular expression literals " + + "used to match the topic names used by the source connector. This list is used " + + "to exclude topics from being created with the topic settings defined by this group. " + + "Note that exclusion rules have precedent and override any inclusion rules for the topics."; + + public static final String REPLICATION_FACTOR_CONFIG = "replication.factor"; + private static final String REPLICATION_FACTOR_DOC = "The replication factor for new topics " + + "created for this connector using this group. This value may be -1 to use the broker's" + + "default replication factor, or may be a positive number not larger than the number of " + + "brokers in the Kafka cluster. A value larger than the number of brokers in the Kafka cluster " + + "will result in an error when the new topic is created. For the default group this configuration " + + "is required. For any other group defined in topic.creation.groups this config is " + + "optional and if it's missing it gets the value of the default group"; + + public static final String PARTITIONS_CONFIG = "partitions"; + private static final String PARTITIONS_DOC = "The number of partitions new topics created for " + + "this connector. 
This value may be -1 to use the broker's default number of partitions, " + + "or a positive number representing the desired number of partitions. " + + "For the default group this configuration is required. For any " + + "other group defined in topic.creation.groups this config is optional and if it's " + + "missing it gets the value of the default group"; + + public static final ConfigDef.Validator REPLICATION_FACTOR_VALIDATOR = ConfigDef.LambdaValidator.with( + (name, value) -> validateReplicationFactor(name, (short) value), + () -> "Positive number not larger than the number of brokers in the Kafka cluster, or -1 to use the broker's default" + ); + public static final ConfigDef.Validator PARTITIONS_VALIDATOR = ConfigDef.LambdaValidator.with( + (name, value) -> validatePartitions(name, (int) value), + () -> "Positive number, or -1 to use the broker's default" + ); + @SuppressWarnings("unchecked") + public static final ConfigDef.Validator REGEX_VALIDATOR = ConfigDef.LambdaValidator.with( + (name, value) -> { + try { + ((List) value).forEach(Pattern::compile); + } catch (PatternSyntaxException e) { + throw new ConfigException(name, value, + "Syntax error in regular expression: " + e.getMessage()); + } + }, + () -> "Positive number, or -1 to use the broker's default" + ); + + private static void validatePartitions(String configName, int factor) { + if (factor != TopicAdmin.NO_PARTITIONS && factor < 1) { + throw new ConfigException(configName, factor, + "Number of partitions must be positive, or -1 to use the broker's default"); + } + } + + private static void validateReplicationFactor(String configName, short factor) { + if (factor != TopicAdmin.NO_REPLICATION_FACTOR && factor < 1) { + throw new ConfigException(configName, factor, + "Replication factor must be positive and not larger than the number of brokers in the Kafka cluster, or -1 to use the broker's default"); + } + } + + public static ConfigDef configDef(String group, short defaultReplicationFactor, int defaultParitionCount) { + int orderInGroup = 0; + ConfigDef configDef = new ConfigDef(); + configDef + .define(INCLUDE_REGEX_CONFIG, ConfigDef.Type.LIST, Collections.emptyList(), + REGEX_VALIDATOR, ConfigDef.Importance.LOW, + INCLUDE_REGEX_DOC, group, ++orderInGroup, ConfigDef.Width.LONG, + "Inclusion Topic Pattern for " + group) + .define(EXCLUDE_REGEX_CONFIG, ConfigDef.Type.LIST, Collections.emptyList(), + REGEX_VALIDATOR, ConfigDef.Importance.LOW, + EXCLUDE_REGEX_DOC, group, ++orderInGroup, ConfigDef.Width.LONG, + "Exclusion Topic Pattern for " + group) + .define(REPLICATION_FACTOR_CONFIG, ConfigDef.Type.SHORT, + defaultReplicationFactor, REPLICATION_FACTOR_VALIDATOR, + ConfigDef.Importance.LOW, REPLICATION_FACTOR_DOC, group, ++orderInGroup, + ConfigDef.Width.LONG, "Replication Factor for Topics in " + group) + .define(PARTITIONS_CONFIG, ConfigDef.Type.INT, + defaultParitionCount, PARTITIONS_VALIDATOR, + ConfigDef.Importance.LOW, PARTITIONS_DOC, group, ++orderInGroup, + ConfigDef.Width.LONG, "Partition Count for Topics in " + group); + return configDef; + } + + public static ConfigDef defaultGroupConfigDef() { + int orderInGroup = 0; + ConfigDef configDef = new ConfigDef(); + configDef + .define(INCLUDE_REGEX_CONFIG, ConfigDef.Type.LIST, ".*", + new ConfigDef.NonNullValidator(), ConfigDef.Importance.LOW, + INCLUDE_REGEX_DOC, DEFAULT_TOPIC_CREATION_GROUP, ++orderInGroup, ConfigDef.Width.LONG, + "Inclusion Topic Pattern for " + DEFAULT_TOPIC_CREATION_GROUP) + .define(EXCLUDE_REGEX_CONFIG, ConfigDef.Type.LIST, 
Collections.emptyList(), + new ConfigDef.NonNullValidator(), ConfigDef.Importance.LOW, + EXCLUDE_REGEX_DOC, DEFAULT_TOPIC_CREATION_GROUP, ++orderInGroup, ConfigDef.Width.LONG, + "Exclusion Topic Pattern for " + DEFAULT_TOPIC_CREATION_GROUP) + .define(REPLICATION_FACTOR_CONFIG, ConfigDef.Type.SHORT, + ConfigDef.NO_DEFAULT_VALUE, REPLICATION_FACTOR_VALIDATOR, + ConfigDef.Importance.LOW, REPLICATION_FACTOR_DOC, DEFAULT_TOPIC_CREATION_GROUP, ++orderInGroup, + ConfigDef.Width.LONG, "Replication Factor for Topics in " + DEFAULT_TOPIC_CREATION_GROUP) + .define(PARTITIONS_CONFIG, ConfigDef.Type.INT, + ConfigDef.NO_DEFAULT_VALUE, PARTITIONS_VALIDATOR, + ConfigDef.Importance.LOW, PARTITIONS_DOC, DEFAULT_TOPIC_CREATION_GROUP, ++orderInGroup, + ConfigDef.Width.LONG, "Partition Count for Topics in " + DEFAULT_TOPIC_CREATION_GROUP); + return configDef; + } + +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/TopicStatus.java b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/TopicStatus.java new file mode 100644 index 0000000..16dcd80 --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/TopicStatus.java @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.runtime; + +import org.apache.kafka.connect.util.ConnectorTaskId; + +import java.util.Objects; + +/** + * Represents the metadata that is stored as the value of the record that is stored in the + * {@link org.apache.kafka.connect.storage.StatusBackingStore#put(TopicStatus)}, + */ +public class TopicStatus { + private final String topic; + private final String connector; + private final int task; + private final long discoverTimestamp; + + public TopicStatus(String topic, ConnectorTaskId task, long discoverTimestamp) { + this(topic, task.connector(), task.task(), discoverTimestamp); + } + + public TopicStatus(String topic, String connector, int task, long discoverTimestamp) { + this.topic = Objects.requireNonNull(topic); + this.connector = Objects.requireNonNull(connector); + this.task = task; + this.discoverTimestamp = discoverTimestamp; + } + + /** + * Get the name of the topic. + * + * @return the topic name; never null + */ + public String topic() { + return topic; + } + + /** + * Get the name of the connector. + * + * @return the connector name; never null + */ + public String connector() { + return connector; + } + + /** + * Get the ID of the task that stored the topic status. + * + * @return the task ID + */ + public int task() { + return task; + } + + /** + * Get a timestamp that represents when this topic was discovered as being actively used by + * this connector. 
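+ * The value is expected to be epoch milliseconds, for example as returned by
+ * {@link org.apache.kafka.common.utils.Time#milliseconds()}.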
+ * + * @return the discovery timestamp + */ + public long discoverTimestamp() { + return discoverTimestamp; + } + + @Override + public String toString() { + return "TopicStatus{" + + "topic='" + topic + '\'' + + ", connector='" + connector + '\'' + + ", task=" + task + + ", discoverTimestamp=" + discoverTimestamp + + '}'; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (!(o instanceof TopicStatus)) { + return false; + } + TopicStatus that = (TopicStatus) o; + return task == that.task && + discoverTimestamp == that.discoverTimestamp && + topic.equals(that.topic) && + connector.equals(that.connector); + } + + @Override + public int hashCode() { + return Objects.hash(topic, connector, task, discoverTimestamp); + } +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/TransformationChain.java b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/TransformationChain.java new file mode 100644 index 0000000..6777a96 --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/TransformationChain.java @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.connect.runtime; + +import org.apache.kafka.connect.connector.ConnectRecord; +import org.apache.kafka.connect.runtime.errors.RetryWithToleranceOperator; +import org.apache.kafka.connect.runtime.errors.Stage; +import org.apache.kafka.connect.transforms.Transformation; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.List; +import java.util.Objects; +import java.util.StringJoiner; + +public class TransformationChain> implements AutoCloseable { + private static final Logger log = LoggerFactory.getLogger(TransformationChain.class); + + private final List> transformations; + private final RetryWithToleranceOperator retryWithToleranceOperator; + + public TransformationChain(List> transformations, RetryWithToleranceOperator retryWithToleranceOperator) { + this.transformations = transformations; + this.retryWithToleranceOperator = retryWithToleranceOperator; + } + + public R apply(R record) { + if (transformations.isEmpty()) return record; + + for (final Transformation transformation : transformations) { + final R current = record; + + log.trace("Applying transformation {} to {}", + transformation.getClass().getName(), record); + // execute the operation + record = retryWithToleranceOperator.execute(() -> transformation.apply(current), Stage.TRANSFORMATION, transformation.getClass()); + + if (record == null) break; + } + + return record; + } + + @Override + public void close() { + for (Transformation transformation : transformations) { + transformation.close(); + } + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + TransformationChain that = (TransformationChain) o; + return Objects.equals(transformations, that.transformations); + } + + @Override + public int hashCode() { + return Objects.hash(transformations); + } + + public String toString() { + StringJoiner chain = new StringJoiner(", ", getClass().getName() + "{", "}"); + for (Transformation transformation : transformations) { + chain.add(transformation.getClass().getName()); + } + return chain.toString(); + } +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/Worker.java b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/Worker.java new file mode 100644 index 0000000..11af818 --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/Worker.java @@ -0,0 +1,1053 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.connect.runtime; + +import org.apache.kafka.clients.admin.AdminClientConfig; +import org.apache.kafka.clients.consumer.ConsumerConfig; +import org.apache.kafka.clients.consumer.KafkaConsumer; +import org.apache.kafka.clients.producer.KafkaProducer; +import org.apache.kafka.clients.producer.ProducerConfig; +import org.apache.kafka.common.MetricNameTemplate; +import org.apache.kafka.common.config.ConfigValue; +import org.apache.kafka.common.config.provider.ConfigProvider; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.connect.connector.Connector; +import org.apache.kafka.connect.connector.Task; +import org.apache.kafka.connect.connector.policy.ConnectorClientConfigOverridePolicy; +import org.apache.kafka.connect.connector.policy.ConnectorClientConfigRequest; +import org.apache.kafka.connect.errors.ConnectException; +import org.apache.kafka.connect.health.ConnectorType; +import org.apache.kafka.connect.json.JsonConverter; +import org.apache.kafka.connect.json.JsonConverterConfig; +import org.apache.kafka.connect.runtime.ConnectMetrics.MetricGroup; +import org.apache.kafka.connect.runtime.distributed.ClusterConfigState; +import org.apache.kafka.connect.runtime.errors.DeadLetterQueueReporter; +import org.apache.kafka.connect.runtime.errors.ErrorHandlingMetrics; +import org.apache.kafka.connect.runtime.errors.ErrorReporter; +import org.apache.kafka.connect.runtime.errors.LogReporter; +import org.apache.kafka.connect.runtime.errors.RetryWithToleranceOperator; +import org.apache.kafka.connect.runtime.errors.WorkerErrantRecordReporter; +import org.apache.kafka.connect.runtime.isolation.Plugins; +import org.apache.kafka.connect.runtime.isolation.Plugins.ClassLoaderUsage; +import org.apache.kafka.connect.sink.SinkRecord; +import org.apache.kafka.connect.sink.SinkTask; +import org.apache.kafka.connect.source.SourceRecord; +import org.apache.kafka.connect.source.SourceTask; +import org.apache.kafka.connect.storage.CloseableOffsetStorageReader; +import org.apache.kafka.connect.storage.Converter; +import org.apache.kafka.connect.storage.HeaderConverter; +import org.apache.kafka.connect.storage.OffsetBackingStore; +import org.apache.kafka.connect.storage.OffsetStorageReader; +import org.apache.kafka.connect.storage.OffsetStorageReaderImpl; +import org.apache.kafka.connect.storage.OffsetStorageWriter; +import org.apache.kafka.connect.util.Callback; +import org.apache.kafka.connect.util.ConnectUtils; +import org.apache.kafka.connect.util.ConnectorTaskId; +import org.apache.kafka.connect.util.LoggingContext; +import org.apache.kafka.connect.util.SinkUtils; +import org.apache.kafka.connect.util.TopicAdmin; +import org.apache.kafka.connect.util.TopicCreationGroup; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; + +/** + *

+ * Worker runs a (dynamic) set of tasks in a set of threads, doing the work of actually moving
+ * data to/from Kafka.
+ *
+ * Since each task has a dedicated thread, this is mainly just a container for them.
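+ *
+ * A minimal usage sketch (normally a herder drives this; workerId, plugins, workerConfig,
+ * offsetStore and overridePolicy are assumed here and are not defined in this file):
+ * <pre>
+ *   Worker worker = new Worker(workerId, Time.SYSTEM, plugins, workerConfig, offsetStore, overridePolicy);
+ *   worker.start();
+ *   // connectors and tasks are then started via startConnector(...) and startTask(...)
+ *   worker.stop();
+ * </pre>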

            + */ +public class Worker { + + public static final long CONNECTOR_GRACEFUL_SHUTDOWN_TIMEOUT_MS = TimeUnit.SECONDS.toMillis(5); + + private static final Logger log = LoggerFactory.getLogger(Worker.class); + + protected Herder herder; + private final ExecutorService executor; + private final Time time; + private final String workerId; + //kafka cluster id + private final String kafkaClusterId; + private final Plugins plugins; + private final ConnectMetrics metrics; + private final WorkerMetricsGroup workerMetricsGroup; + private ConnectorStatusMetricsGroup connectorStatusMetricsGroup; + private final WorkerConfig config; + private final Converter internalKeyConverter; + private final Converter internalValueConverter; + private final OffsetBackingStore offsetBackingStore; + + private final ConcurrentMap connectors = new ConcurrentHashMap<>(); + private final ConcurrentMap tasks = new ConcurrentHashMap<>(); + private SourceTaskOffsetCommitter sourceTaskOffsetCommitter; + private final WorkerConfigTransformer workerConfigTransformer; + private final ConnectorClientConfigOverridePolicy connectorClientConfigOverridePolicy; + + public Worker( + String workerId, + Time time, + Plugins plugins, + WorkerConfig config, + OffsetBackingStore offsetBackingStore, + ConnectorClientConfigOverridePolicy connectorClientConfigOverridePolicy) { + this(workerId, time, plugins, config, offsetBackingStore, Executors.newCachedThreadPool(), connectorClientConfigOverridePolicy); + } + + @SuppressWarnings("deprecation") + Worker( + String workerId, + Time time, + Plugins plugins, + WorkerConfig config, + OffsetBackingStore offsetBackingStore, + ExecutorService executorService, + ConnectorClientConfigOverridePolicy connectorClientConfigOverridePolicy + ) { + this.kafkaClusterId = ConnectUtils.lookupKafkaClusterId(config); + this.metrics = new ConnectMetrics(workerId, config, time, kafkaClusterId); + this.executor = executorService; + this.workerId = workerId; + this.time = time; + this.plugins = plugins; + this.config = config; + this.connectorClientConfigOverridePolicy = connectorClientConfigOverridePolicy; + this.workerMetricsGroup = new WorkerMetricsGroup(this.connectors, this.tasks, metrics); + + Map internalConverterConfig = Collections.singletonMap(JsonConverterConfig.SCHEMAS_ENABLE_CONFIG, "false"); + this.internalKeyConverter = plugins.newInternalConverter(true, JsonConverter.class.getName(), internalConverterConfig); + this.internalValueConverter = plugins.newInternalConverter(false, JsonConverter.class.getName(), internalConverterConfig); + + this.offsetBackingStore = offsetBackingStore; + this.offsetBackingStore.configure(config); + + this.workerConfigTransformer = initConfigTransformer(); + + } + + private WorkerConfigTransformer initConfigTransformer() { + final List providerNames = config.getList(WorkerConfig.CONFIG_PROVIDERS_CONFIG); + Map providerMap = new HashMap<>(); + for (String providerName : providerNames) { + ConfigProvider configProvider = plugins.newConfigProvider( + config, + WorkerConfig.CONFIG_PROVIDERS_CONFIG + "." + providerName, + ClassLoaderUsage.PLUGINS + ); + providerMap.put(providerName, configProvider); + } + return new WorkerConfigTransformer(this, providerMap); + } + + public WorkerConfigTransformer configTransformer() { + return workerConfigTransformer; + } + + protected Herder herder() { + return herder; + } + + /** + * Start worker. 
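+ * Starts the offset backing store and the source task offset committer and creates the
+ * connector status metrics group; the herder is expected to call this before any connectors
+ * or tasks are started.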
+ */ + public void start() { + log.info("Worker starting"); + + offsetBackingStore.start(); + sourceTaskOffsetCommitter = new SourceTaskOffsetCommitter(config); + + connectorStatusMetricsGroup = new ConnectorStatusMetricsGroup(metrics, tasks, herder); + + log.info("Worker started"); + } + + /** + * Stop worker. + */ + public void stop() { + log.info("Worker stopping"); + + long started = time.milliseconds(); + long limit = started + config.getLong(WorkerConfig.TASK_SHUTDOWN_GRACEFUL_TIMEOUT_MS_CONFIG); + + if (!connectors.isEmpty()) { + log.warn("Shutting down connectors {} uncleanly; herder should have shut down connectors before the Worker is stopped", connectors.keySet()); + stopAndAwaitConnectors(); + } + + if (!tasks.isEmpty()) { + log.warn("Shutting down tasks {} uncleanly; herder should have shut down tasks before the Worker is stopped", tasks.keySet()); + stopAndAwaitTasks(); + } + + long timeoutMs = limit - time.milliseconds(); + sourceTaskOffsetCommitter.close(timeoutMs); + + offsetBackingStore.stop(); + metrics.stop(); + + log.info("Worker stopped"); + + workerMetricsGroup.close(); + connectorStatusMetricsGroup.close(); + + workerConfigTransformer.close(); + } + + /** + * Start a connector managed by this worker. + * + * @param connName the connector name. + * @param connProps the properties of the connector. + * @param ctx the connector runtime context. + * @param statusListener a listener for the runtime status transitions of the connector. + * @param initialState the initial state of the connector. + * @param onConnectorStateChange invoked when the initial state change of the connector is completed + */ + public void startConnector( + String connName, + Map connProps, + CloseableConnectorContext ctx, + ConnectorStatus.Listener statusListener, + TargetState initialState, + Callback onConnectorStateChange + ) { + final ConnectorStatus.Listener connectorStatusListener = workerMetricsGroup.wrapStatusListener(statusListener); + try (LoggingContext loggingContext = LoggingContext.forConnector(connName)) { + if (connectors.containsKey(connName)) { + onConnectorStateChange.onCompletion( + new ConnectException("Connector with name " + connName + " already exists"), + null); + return; + } + + final WorkerConnector workerConnector; + ClassLoader savedLoader = plugins.currentThreadLoader(); + try { + // By the time we arrive here, CONNECTOR_CLASS_CONFIG has been validated already + // Getting this value from the unparsed map will allow us to instantiate the + // right config (source or sink) + final String connClass = connProps.get(ConnectorConfig.CONNECTOR_CLASS_CONFIG); + ClassLoader connectorLoader = plugins.delegatingLoader().connectorLoader(connClass); + savedLoader = Plugins.compareAndSwapLoaders(connectorLoader); + + log.info("Creating connector {} of type {}", connName, connClass); + final Connector connector = plugins.newConnector(connClass); + final ConnectorConfig connConfig = ConnectUtils.isSinkConnector(connector) + ? 
new SinkConnectorConfig(plugins, connProps) + : new SourceConnectorConfig(plugins, connProps, config.topicCreationEnable()); + + final OffsetStorageReader offsetReader = new OffsetStorageReaderImpl( + offsetBackingStore, connName, internalKeyConverter, internalValueConverter); + workerConnector = new WorkerConnector( + connName, connector, connConfig, ctx, metrics, connectorStatusListener, offsetReader, connectorLoader); + log.info("Instantiated connector {} with version {} of type {}", connName, connector.version(), connector.getClass()); + workerConnector.transitionTo(initialState, onConnectorStateChange); + Plugins.compareAndSwapLoaders(savedLoader); + } catch (Throwable t) { + log.error("Failed to start connector {}", connName, t); + // Can't be put in a finally block because it needs to be swapped before the call on + // statusListener + Plugins.compareAndSwapLoaders(savedLoader); + connectorStatusListener.onFailure(connName, t); + onConnectorStateChange.onCompletion(t, null); + return; + } + + WorkerConnector existing = connectors.putIfAbsent(connName, workerConnector); + if (existing != null) { + onConnectorStateChange.onCompletion( + new ConnectException("Connector with name " + connName + " already exists"), + null); + // Don't need to do any cleanup of the WorkerConnector instance (such as calling + // shutdown() on it) here because it hasn't actually started running yet + return; + } + + executor.submit(workerConnector); + + log.info("Finished creating connector {}", connName); + } + } + + /** + * Return true if the connector associated with this worker is a sink connector. + * + * @param connName the connector name. + * @return true if the connector belongs to the worker and is a sink connector. + * @throws ConnectException if the worker does not manage a connector with the given name. + */ + public boolean isSinkConnector(String connName) { + WorkerConnector workerConnector = connectors.get(connName); + if (workerConnector == null) + throw new ConnectException("Connector " + connName + " not found in this worker."); + + ClassLoader savedLoader = plugins.currentThreadLoader(); + try { + savedLoader = Plugins.compareAndSwapLoaders(workerConnector.loader()); + return workerConnector.isSinkConnector(); + } finally { + Plugins.compareAndSwapLoaders(savedLoader); + } + } + + /** + * Get a list of updated task properties for the tasks of this connector. + * + * @param connName the connector name. + * @return a list of updated tasks properties. 
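+ * Each returned map also carries the task.class entry, plus any topics or
+ * topics.regex settings copied over from the connector configuration.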
+ */ + public List> connectorTaskConfigs(String connName, ConnectorConfig connConfig) { + List> result = new ArrayList<>(); + try (LoggingContext loggingContext = LoggingContext.forConnector(connName)) { + log.trace("Reconfiguring connector tasks for {}", connName); + + WorkerConnector workerConnector = connectors.get(connName); + if (workerConnector == null) + throw new ConnectException("Connector " + connName + " not found in this worker."); + + int maxTasks = connConfig.getInt(ConnectorConfig.TASKS_MAX_CONFIG); + Map connOriginals = connConfig.originalsStrings(); + + Connector connector = workerConnector.connector(); + ClassLoader savedLoader = plugins.currentThreadLoader(); + try { + savedLoader = Plugins.compareAndSwapLoaders(workerConnector.loader()); + String taskClassName = connector.taskClass().getName(); + for (Map taskProps : connector.taskConfigs(maxTasks)) { + // Ensure we don't modify the connector's copy of the config + Map taskConfig = new HashMap<>(taskProps); + taskConfig.put(TaskConfig.TASK_CLASS_CONFIG, taskClassName); + if (connOriginals.containsKey(SinkTask.TOPICS_CONFIG)) { + taskConfig.put(SinkTask.TOPICS_CONFIG, connOriginals.get(SinkTask.TOPICS_CONFIG)); + } + if (connOriginals.containsKey(SinkTask.TOPICS_REGEX_CONFIG)) { + taskConfig.put(SinkTask.TOPICS_REGEX_CONFIG, connOriginals.get(SinkTask.TOPICS_REGEX_CONFIG)); + } + result.add(taskConfig); + } + } finally { + Plugins.compareAndSwapLoaders(savedLoader); + } + } + + return result; + } + + /** + * Stop a connector managed by this worker. + * + * @param connName the connector name. + */ + private void stopConnector(String connName) { + try (LoggingContext loggingContext = LoggingContext.forConnector(connName)) { + WorkerConnector workerConnector = connectors.get(connName); + log.info("Stopping connector {}", connName); + + if (workerConnector == null) { + log.warn("Ignoring stop request for unowned connector {}", connName); + return; + } + + ClassLoader savedLoader = plugins.currentThreadLoader(); + try { + savedLoader = Plugins.compareAndSwapLoaders(workerConnector.loader()); + workerConnector.shutdown(); + } finally { + Plugins.compareAndSwapLoaders(savedLoader); + } + } + } + + private void stopConnectors(Collection ids) { + // Herder is responsible for stopping connectors. This is an internal method to sequentially + // stop connectors that have not explicitly been stopped. + for (String connector: ids) + stopConnector(connector); + } + + private void awaitStopConnector(String connName, long timeout) { + try (LoggingContext loggingContext = LoggingContext.forConnector(connName)) { + WorkerConnector connector = connectors.remove(connName); + if (connector == null) { + log.warn("Ignoring await stop request for non-present connector {}", connName); + return; + } + + if (!connector.awaitShutdown(timeout)) { + log.error("Connector ‘{}’ failed to properly shut down, has become unresponsive, and " + + "may be consuming external resources. Correct the configuration for " + + "this connector or remove the connector. 
After fixing the connector, it " + + "may be necessary to restart this worker to release any consumed " + + "resources.", connName); + connector.cancel(); + } else { + log.debug("Graceful stop of connector {} succeeded.", connName); + } + } + } + + private void awaitStopConnectors(Collection ids) { + long now = time.milliseconds(); + long deadline = now + CONNECTOR_GRACEFUL_SHUTDOWN_TIMEOUT_MS; + for (String id : ids) { + long remaining = Math.max(0, deadline - time.milliseconds()); + awaitStopConnector(id, remaining); + } + } + + /** + * Stop asynchronously all the worker's connectors and await their termination. + */ + public void stopAndAwaitConnectors() { + stopAndAwaitConnectors(new ArrayList<>(connectors.keySet())); + } + + /** + * Stop asynchronously a collection of connectors that belong to this worker and await their + * termination. + * + * @param ids the collection of connectors to be stopped. + */ + public void stopAndAwaitConnectors(Collection ids) { + stopConnectors(ids); + awaitStopConnectors(ids); + } + + /** + * Stop a connector that belongs to this worker and await its termination. + * + * @param connName the name of the connector to be stopped. + */ + public void stopAndAwaitConnector(String connName) { + stopConnector(connName); + awaitStopConnectors(Collections.singletonList(connName)); + } + + /** + * Get the IDs of the connectors currently running in this worker. + * + * @return the set of connector IDs. + */ + public Set connectorNames() { + return connectors.keySet(); + } + + /** + * Return true if a connector with the given name is managed by this worker and is currently running. + * + * @param connName the connector name. + * @return true if the connector is running, false if the connector is not running or is not manages by this worker. + */ + public boolean isRunning(String connName) { + WorkerConnector workerConnector = connectors.get(connName); + return workerConnector != null && workerConnector.isRunning(); + } + + /** + * Start a task managed by this worker. + * + * @param id the task ID. + * @param connProps the connector properties. + * @param taskProps the tasks properties. + * @param statusListener a listener for the runtime status transitions of the task. + * @param initialState the initial state of the connector. + * @return true if the task started successfully. 
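+ * Returns false if the task or its converters could not be created; in that case the
+ * failure is also reported to the supplied status listener.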
+ */ + public boolean startTask( + ConnectorTaskId id, + ClusterConfigState configState, + Map connProps, + Map taskProps, + TaskStatus.Listener statusListener, + TargetState initialState + ) { + final WorkerTask workerTask; + final TaskStatus.Listener taskStatusListener = workerMetricsGroup.wrapStatusListener(statusListener); + try (LoggingContext loggingContext = LoggingContext.forTask(id)) { + log.info("Creating task {}", id); + + if (tasks.containsKey(id)) + throw new ConnectException("Task already exists in this worker: " + id); + + connectorStatusMetricsGroup.recordTaskAdded(id); + ClassLoader savedLoader = plugins.currentThreadLoader(); + try { + String connType = connProps.get(ConnectorConfig.CONNECTOR_CLASS_CONFIG); + ClassLoader connectorLoader = plugins.delegatingLoader().connectorLoader(connType); + savedLoader = Plugins.compareAndSwapLoaders(connectorLoader); + final ConnectorConfig connConfig = new ConnectorConfig(plugins, connProps); + final TaskConfig taskConfig = new TaskConfig(taskProps); + final Class taskClass = taskConfig.getClass(TaskConfig.TASK_CLASS_CONFIG).asSubclass(Task.class); + final Task task = plugins.newTask(taskClass); + log.info("Instantiated task {} with version {} of type {}", id, task.version(), taskClass.getName()); + + // By maintaining connector's specific class loader for this thread here, we first + // search for converters within the connector dependencies. + // If any of these aren't found, that means the connector didn't configure specific converters, + // so we should instantiate based upon the worker configuration + Converter keyConverter = plugins.newConverter(connConfig, WorkerConfig.KEY_CONVERTER_CLASS_CONFIG, ClassLoaderUsage + .CURRENT_CLASSLOADER); + Converter valueConverter = plugins.newConverter(connConfig, WorkerConfig.VALUE_CONVERTER_CLASS_CONFIG, ClassLoaderUsage.CURRENT_CLASSLOADER); + HeaderConverter headerConverter = plugins.newHeaderConverter(connConfig, WorkerConfig.HEADER_CONVERTER_CLASS_CONFIG, + ClassLoaderUsage.CURRENT_CLASSLOADER); + if (keyConverter == null) { + keyConverter = plugins.newConverter(config, WorkerConfig.KEY_CONVERTER_CLASS_CONFIG, ClassLoaderUsage.PLUGINS); + log.info("Set up the key converter {} for task {} using the worker config", keyConverter.getClass(), id); + } else { + log.info("Set up the key converter {} for task {} using the connector config", keyConverter.getClass(), id); + } + if (valueConverter == null) { + valueConverter = plugins.newConverter(config, WorkerConfig.VALUE_CONVERTER_CLASS_CONFIG, ClassLoaderUsage.PLUGINS); + log.info("Set up the value converter {} for task {} using the worker config", valueConverter.getClass(), id); + } else { + log.info("Set up the value converter {} for task {} using the connector config", valueConverter.getClass(), id); + } + if (headerConverter == null) { + headerConverter = plugins.newHeaderConverter(config, WorkerConfig.HEADER_CONVERTER_CLASS_CONFIG, ClassLoaderUsage + .PLUGINS); + log.info("Set up the header converter {} for task {} using the worker config", headerConverter.getClass(), id); + } else { + log.info("Set up the header converter {} for task {} using the connector config", headerConverter.getClass(), id); + } + + workerTask = buildWorkerTask(configState, connConfig, id, task, taskStatusListener, + initialState, keyConverter, valueConverter, headerConverter, connectorLoader); + workerTask.initialize(taskConfig); + Plugins.compareAndSwapLoaders(savedLoader); + } catch (Throwable t) { + log.error("Failed to start task {}", id, t); + // Can't be put 
in a finally block because it needs to be swapped before the call on + // statusListener + Plugins.compareAndSwapLoaders(savedLoader); + connectorStatusMetricsGroup.recordTaskRemoved(id); + taskStatusListener.onFailure(id, t); + return false; + } + + WorkerTask existing = tasks.putIfAbsent(id, workerTask); + if (existing != null) + throw new ConnectException("Task already exists in this worker: " + id); + + executor.submit(workerTask); + if (workerTask instanceof WorkerSourceTask) { + sourceTaskOffsetCommitter.schedule(id, (WorkerSourceTask) workerTask); + } + return true; + } + } + + private WorkerTask buildWorkerTask(ClusterConfigState configState, + ConnectorConfig connConfig, + ConnectorTaskId id, + Task task, + TaskStatus.Listener statusListener, + TargetState initialState, + Converter keyConverter, + Converter valueConverter, + HeaderConverter headerConverter, + ClassLoader loader) { + ErrorHandlingMetrics errorHandlingMetrics = errorHandlingMetrics(id); + final Class connectorClass = plugins.connectorClass( + connConfig.getString(ConnectorConfig.CONNECTOR_CLASS_CONFIG)); + RetryWithToleranceOperator retryWithToleranceOperator = new RetryWithToleranceOperator(connConfig.errorRetryTimeout(), + connConfig.errorMaxDelayInMillis(), connConfig.errorToleranceType(), Time.SYSTEM); + retryWithToleranceOperator.metrics(errorHandlingMetrics); + + // Decide which type of worker task we need based on the type of task. + if (task instanceof SourceTask) { + SourceConnectorConfig sourceConfig = new SourceConnectorConfig(plugins, + connConfig.originalsStrings(), config.topicCreationEnable()); + retryWithToleranceOperator.reporters(sourceTaskReporters(id, sourceConfig, errorHandlingMetrics)); + TransformationChain transformationChain = new TransformationChain<>(sourceConfig.transformations(), retryWithToleranceOperator); + log.info("Initializing: {}", transformationChain); + CloseableOffsetStorageReader offsetReader = new OffsetStorageReaderImpl(offsetBackingStore, id.connector(), + internalKeyConverter, internalValueConverter); + OffsetStorageWriter offsetWriter = new OffsetStorageWriter(offsetBackingStore, id.connector(), + internalKeyConverter, internalValueConverter); + Map producerProps = producerConfigs(id, "connector-producer-" + id, config, sourceConfig, connectorClass, + connectorClientConfigOverridePolicy, kafkaClusterId); + KafkaProducer producer = new KafkaProducer<>(producerProps); + TopicAdmin admin; + Map topicCreationGroups; + if (config.topicCreationEnable() && sourceConfig.usesTopicCreation()) { + Map adminProps = adminConfigs(id, "connector-adminclient-" + id, config, + sourceConfig, connectorClass, connectorClientConfigOverridePolicy, kafkaClusterId); + admin = new TopicAdmin(adminProps); + topicCreationGroups = TopicCreationGroup.configuredGroups(sourceConfig); + } else { + admin = null; + topicCreationGroups = null; + } + + // Note we pass the configState as it performs dynamic transformations under the covers + return new WorkerSourceTask(id, (SourceTask) task, statusListener, initialState, keyConverter, valueConverter, + headerConverter, transformationChain, producer, admin, topicCreationGroups, + offsetReader, offsetWriter, config, configState, metrics, loader, time, retryWithToleranceOperator, herder.statusBackingStore(), executor); + } else if (task instanceof SinkTask) { + TransformationChain transformationChain = new TransformationChain<>(connConfig.transformations(), retryWithToleranceOperator); + log.info("Initializing: {}", transformationChain); + SinkConnectorConfig 
sinkConfig = new SinkConnectorConfig(plugins, connConfig.originalsStrings()); + retryWithToleranceOperator.reporters(sinkTaskReporters(id, sinkConfig, errorHandlingMetrics, connectorClass)); + WorkerErrantRecordReporter workerErrantRecordReporter = createWorkerErrantRecordReporter(sinkConfig, retryWithToleranceOperator, + keyConverter, valueConverter, headerConverter); + + Map consumerProps = consumerConfigs(id, config, connConfig, connectorClass, connectorClientConfigOverridePolicy, kafkaClusterId); + KafkaConsumer consumer = new KafkaConsumer<>(consumerProps); + + return new WorkerSinkTask(id, (SinkTask) task, statusListener, initialState, config, configState, metrics, keyConverter, + valueConverter, headerConverter, transformationChain, consumer, loader, time, + retryWithToleranceOperator, workerErrantRecordReporter, herder.statusBackingStore()); + } else { + log.error("Tasks must be a subclass of either SourceTask or SinkTask and current is {}", task); + throw new ConnectException("Tasks must be a subclass of either SourceTask or SinkTask"); + } + } + + static Map producerConfigs(ConnectorTaskId id, + String defaultClientId, + WorkerConfig config, + ConnectorConfig connConfig, + Class connectorClass, + ConnectorClientConfigOverridePolicy connectorClientConfigOverridePolicy, + String clusterId) { + Map producerProps = new HashMap<>(); + producerProps.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, Utils.join(config.getList(WorkerConfig.BOOTSTRAP_SERVERS_CONFIG), ",")); + producerProps.put(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG, "org.apache.kafka.common.serialization.ByteArraySerializer"); + producerProps.put(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG, "org.apache.kafka.common.serialization.ByteArraySerializer"); + // These settings will execute infinite retries on retriable exceptions. They *may* be overridden via configs passed to the worker, + // but this may compromise the delivery guarantees of Kafka Connect. 
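+ // For example, a worker configuration entry such as producer.delivery.timeout.ms=120000
+ // (an illustrative value, not a recommendation) would take precedence, because the
+ // "producer."-prefixed worker overrides are applied after these defaults.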
+ producerProps.put(ProducerConfig.MAX_BLOCK_MS_CONFIG, Long.toString(Long.MAX_VALUE)); + producerProps.put(ProducerConfig.ACKS_CONFIG, "all"); + producerProps.put(ProducerConfig.MAX_IN_FLIGHT_REQUESTS_PER_CONNECTION, "1"); + producerProps.put(ProducerConfig.DELIVERY_TIMEOUT_MS_CONFIG, Integer.toString(Integer.MAX_VALUE)); + producerProps.put(ProducerConfig.CLIENT_ID_CONFIG, defaultClientId); + // User-specified overrides + producerProps.putAll(config.originalsWithPrefix("producer.")); + //add client metrics.context properties + ConnectUtils.addMetricsContextProperties(producerProps, config, clusterId); + + // Connector-specified overrides + Map producerOverrides = + connectorClientConfigOverrides(id, connConfig, connectorClass, ConnectorConfig.CONNECTOR_CLIENT_PRODUCER_OVERRIDES_PREFIX, + ConnectorType.SOURCE, ConnectorClientConfigRequest.ClientType.PRODUCER, + connectorClientConfigOverridePolicy); + producerProps.putAll(producerOverrides); + + return producerProps; + } + + static Map consumerConfigs(ConnectorTaskId id, + WorkerConfig config, + ConnectorConfig connConfig, + Class connectorClass, + ConnectorClientConfigOverridePolicy connectorClientConfigOverridePolicy, + String clusterId) { + // Include any unknown worker configs so consumer configs can be set globally on the worker + // and through to the task + Map consumerProps = new HashMap<>(); + + consumerProps.put(ConsumerConfig.GROUP_ID_CONFIG, SinkUtils.consumerGroupId(id.connector())); + consumerProps.put(ConsumerConfig.CLIENT_ID_CONFIG, "connector-consumer-" + id); + consumerProps.put(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, + Utils.join(config.getList(WorkerConfig.BOOTSTRAP_SERVERS_CONFIG), ",")); + consumerProps.put(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG, "false"); + consumerProps.put(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "earliest"); + consumerProps.put(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, "org.apache.kafka.common.serialization.ByteArrayDeserializer"); + consumerProps.put(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, "org.apache.kafka.common.serialization.ByteArrayDeserializer"); + + consumerProps.putAll(config.originalsWithPrefix("consumer.")); + //add client metrics.context properties + ConnectUtils.addMetricsContextProperties(consumerProps, config, clusterId); + // Connector-specified overrides + Map consumerOverrides = + connectorClientConfigOverrides(id, connConfig, connectorClass, ConnectorConfig.CONNECTOR_CLIENT_CONSUMER_OVERRIDES_PREFIX, + ConnectorType.SINK, ConnectorClientConfigRequest.ClientType.CONSUMER, + connectorClientConfigOverridePolicy); + consumerProps.putAll(consumerOverrides); + + return consumerProps; + } + + static Map adminConfigs(ConnectorTaskId id, + String defaultClientId, + WorkerConfig config, + ConnectorConfig connConfig, + Class connectorClass, + ConnectorClientConfigOverridePolicy connectorClientConfigOverridePolicy, + String clusterId) { + Map adminProps = new HashMap<>(); + // Use the top-level worker configs to retain backwards compatibility with older releases which + // did not require a prefix for connector admin client configs in the worker configuration file + // Ignore configs that begin with "admin." since those will be added next (with the prefix stripped) + // and those that begin with "producer." 
and "consumer.", since we know they aren't intended for + // the admin client + Map nonPrefixedWorkerConfigs = config.originals().entrySet().stream() + .filter(e -> !e.getKey().startsWith("admin.") + && !e.getKey().startsWith("producer.") + && !e.getKey().startsWith("consumer.")) + .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); + adminProps.put(AdminClientConfig.BOOTSTRAP_SERVERS_CONFIG, + Utils.join(config.getList(WorkerConfig.BOOTSTRAP_SERVERS_CONFIG), ",")); + adminProps.put(AdminClientConfig.CLIENT_ID_CONFIG, defaultClientId); + adminProps.putAll(nonPrefixedWorkerConfigs); + + // Admin client-specific overrides in the worker config + adminProps.putAll(config.originalsWithPrefix("admin.")); + + // Connector-specified overrides + Map adminOverrides = + connectorClientConfigOverrides(id, connConfig, connectorClass, ConnectorConfig.CONNECTOR_CLIENT_ADMIN_OVERRIDES_PREFIX, + ConnectorType.SINK, ConnectorClientConfigRequest.ClientType.ADMIN, + connectorClientConfigOverridePolicy); + adminProps.putAll(adminOverrides); + + //add client metrics.context properties + ConnectUtils.addMetricsContextProperties(adminProps, config, clusterId); + + return adminProps; + } + + private static Map connectorClientConfigOverrides(ConnectorTaskId id, + ConnectorConfig connConfig, + Class connectorClass, + String clientConfigPrefix, + ConnectorType connectorType, + ConnectorClientConfigRequest.ClientType clientType, + ConnectorClientConfigOverridePolicy connectorClientConfigOverridePolicy) { + Map clientOverrides = connConfig.originalsWithPrefix(clientConfigPrefix); + ConnectorClientConfigRequest connectorClientConfigRequest = new ConnectorClientConfigRequest( + id.connector(), + connectorType, + connectorClass, + clientOverrides, + clientType + ); + List configValues = connectorClientConfigOverridePolicy.validate(connectorClientConfigRequest); + List errorConfigs = configValues.stream(). 
+ filter(configValue -> configValue.errorMessages().size() > 0).collect(Collectors.toList()); + // These should be caught when the herder validates the connector configuration, but just in case + if (errorConfigs.size() > 0) { + throw new ConnectException("Client Config Overrides not allowed " + errorConfigs); + } + return clientOverrides; + } + + ErrorHandlingMetrics errorHandlingMetrics(ConnectorTaskId id) { + return new ErrorHandlingMetrics(id, metrics); + } + + private List sinkTaskReporters(ConnectorTaskId id, SinkConnectorConfig connConfig, + ErrorHandlingMetrics errorHandlingMetrics, + Class connectorClass) { + ArrayList reporters = new ArrayList<>(); + LogReporter logReporter = new LogReporter(id, connConfig, errorHandlingMetrics); + reporters.add(logReporter); + + // check if topic for dead letter queue exists + String topic = connConfig.dlqTopicName(); + if (topic != null && !topic.isEmpty()) { + Map producerProps = producerConfigs(id, "connector-dlq-producer-" + id, config, connConfig, connectorClass, + connectorClientConfigOverridePolicy, kafkaClusterId); + Map adminProps = adminConfigs(id, "connector-dlq-adminclient-", config, connConfig, connectorClass, connectorClientConfigOverridePolicy, kafkaClusterId); + DeadLetterQueueReporter reporter = DeadLetterQueueReporter.createAndSetup(adminProps, id, connConfig, producerProps, errorHandlingMetrics); + + reporters.add(reporter); + } + + return reporters; + } + + private List sourceTaskReporters(ConnectorTaskId id, ConnectorConfig connConfig, + ErrorHandlingMetrics errorHandlingMetrics) { + List reporters = new ArrayList<>(); + LogReporter logReporter = new LogReporter(id, connConfig, errorHandlingMetrics); + reporters.add(logReporter); + + return reporters; + } + + private WorkerErrantRecordReporter createWorkerErrantRecordReporter( + SinkConnectorConfig connConfig, + RetryWithToleranceOperator retryWithToleranceOperator, + Converter keyConverter, + Converter valueConverter, + HeaderConverter headerConverter + ) { + // check if errant record reporter topic is configured + if (connConfig.enableErrantRecordReporter()) { + return new WorkerErrantRecordReporter(retryWithToleranceOperator, keyConverter, valueConverter, headerConverter); + } + return null; + } + + private void stopTask(ConnectorTaskId taskId) { + try (LoggingContext loggingContext = LoggingContext.forTask(taskId)) { + WorkerTask task = tasks.get(taskId); + if (task == null) { + log.warn("Ignoring stop request for unowned task {}", taskId); + return; + } + + log.info("Stopping task {}", task.id()); + if (task instanceof WorkerSourceTask) + sourceTaskOffsetCommitter.remove(task.id()); + + ClassLoader savedLoader = plugins.currentThreadLoader(); + try { + savedLoader = Plugins.compareAndSwapLoaders(task.loader()); + task.stop(); + } finally { + Plugins.compareAndSwapLoaders(savedLoader); + } + } + } + + private void stopTasks(Collection ids) { + // Herder is responsible for stopping tasks. This is an internal method to sequentially + // stop the tasks that have not explicitly been stopped. 
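+ // Stop is requested for every task first; callers such as stopAndAwaitTasks(ids) then block
+ // in awaitStopTasks(ids), so the tasks can wind down concurrently while the caller waits on each in turn.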
+ for (ConnectorTaskId taskId : ids) { + stopTask(taskId); + } + } + + private void awaitStopTask(ConnectorTaskId taskId, long timeout) { + try (LoggingContext loggingContext = LoggingContext.forTask(taskId)) { + WorkerTask task = tasks.remove(taskId); + if (task == null) { + log.warn("Ignoring await stop request for non-present task {}", taskId); + return; + } + + if (!task.awaitStop(timeout)) { + log.error("Graceful stop of task {} failed.", task.id()); + task.cancel(); + } else { + log.debug("Graceful stop of task {} succeeded.", task.id()); + } + + try { + task.removeMetrics(); + } finally { + connectorStatusMetricsGroup.recordTaskRemoved(taskId); + } + } + } + + private void awaitStopTasks(Collection ids) { + long now = time.milliseconds(); + long deadline = now + config.getLong(WorkerConfig.TASK_SHUTDOWN_GRACEFUL_TIMEOUT_MS_CONFIG); + for (ConnectorTaskId id : ids) { + long remaining = Math.max(0, deadline - time.milliseconds()); + awaitStopTask(id, remaining); + } + } + + /** + * Stop asynchronously all the worker's tasks and await their termination. + */ + public void stopAndAwaitTasks() { + stopAndAwaitTasks(new ArrayList<>(tasks.keySet())); + } + + /** + * Stop asynchronously a collection of tasks that belong to this worker and await their termination. + * + * @param ids the collection of tasks to be stopped. + */ + public void stopAndAwaitTasks(Collection ids) { + stopTasks(ids); + awaitStopTasks(ids); + } + + /** + * Stop a task that belongs to this worker and await its termination. + * + * @param taskId the ID of the task to be stopped. + */ + public void stopAndAwaitTask(ConnectorTaskId taskId) { + stopTask(taskId); + awaitStopTasks(Collections.singletonList(taskId)); + } + + /** + * Get the IDs of the tasks currently running in this worker. + */ + public Set taskIds() { + return tasks.keySet(); + } + + public Converter getInternalKeyConverter() { + return internalKeyConverter; + } + + public Converter getInternalValueConverter() { + return internalValueConverter; + } + + public Plugins getPlugins() { + return plugins; + } + + public String workerId() { + return workerId; + } + + /** + * Returns whether this worker is configured to allow source connectors to create the topics + * that they use with custom configurations, if these topics don't already exist. + * + * @return true if topic creation by source connectors is allowed; false otherwise + */ + public boolean isTopicCreationEnabled() { + return config.topicCreationEnable(); + } + + /** + * Get the {@link ConnectMetrics} that uses Kafka Metrics and manages the JMX reporter. 
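+ * The same instance backs the worker metrics group and the per-connector status metrics
+ * created by this worker.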
+ * @return the Connect-specific metrics; never null + */ + public ConnectMetrics metrics() { + return metrics; + } + + public void setTargetState(String connName, TargetState state, Callback stateChangeCallback) { + log.info("Setting connector {} state to {}", connName, state); + + WorkerConnector workerConnector = connectors.get(connName); + if (workerConnector != null) { + ClassLoader connectorLoader = + plugins.delegatingLoader().connectorLoader(workerConnector.connector()); + executeStateTransition( + () -> workerConnector.transitionTo(state, stateChangeCallback), + connectorLoader); + } + + for (Map.Entry taskEntry : tasks.entrySet()) { + if (taskEntry.getKey().connector().equals(connName)) { + WorkerTask workerTask = taskEntry.getValue(); + executeStateTransition(() -> workerTask.transitionTo(state), workerTask.loader); + } + } + } + + private void executeStateTransition(Runnable stateTransition, ClassLoader loader) { + ClassLoader savedLoader = plugins.currentThreadLoader(); + try { + savedLoader = Plugins.compareAndSwapLoaders(loader); + stateTransition.run(); + } finally { + Plugins.compareAndSwapLoaders(savedLoader); + } + } + + ConnectorStatusMetricsGroup connectorStatusMetricsGroup() { + return connectorStatusMetricsGroup; + } + + WorkerMetricsGroup workerMetricsGroup() { + return workerMetricsGroup; + } + + static class ConnectorStatusMetricsGroup { + private final ConnectMetrics connectMetrics; + private final ConnectMetricsRegistry registry; + private final ConcurrentMap connectorStatusMetrics = new ConcurrentHashMap<>(); + private final Herder herder; + private final ConcurrentMap tasks; + + + protected ConnectorStatusMetricsGroup( + ConnectMetrics connectMetrics, ConcurrentMap tasks, Herder herder) { + this.connectMetrics = connectMetrics; + this.registry = connectMetrics.registry(); + this.tasks = tasks; + this.herder = herder; + } + + protected ConnectMetrics.LiteralSupplier taskCounter(String connName) { + return now -> tasks.keySet() + .stream() + .filter(taskId -> taskId.connector().equals(connName)) + .count(); + } + + protected ConnectMetrics.LiteralSupplier taskStatusCounter(String connName, TaskStatus.State state) { + return now -> tasks.values() + .stream() + .filter(task -> + task.id().connector().equals(connName) && + herder.taskStatus(task.id()).state().equalsIgnoreCase(state.toString())) + .count(); + } + + protected synchronized void recordTaskAdded(ConnectorTaskId connectorTaskId) { + if (connectorStatusMetrics.containsKey(connectorTaskId.connector())) { + return; + } + + String connName = connectorTaskId.connector(); + + MetricGroup metricGroup = connectMetrics.group(registry.workerGroupName(), + registry.connectorTagName(), connName); + + metricGroup.addValueMetric(registry.connectorTotalTaskCount, taskCounter(connName)); + for (Map.Entry statusMetric : registry.connectorStatusMetrics + .entrySet()) { + metricGroup.addValueMetric(statusMetric.getKey(), taskStatusCounter(connName, + statusMetric.getValue())); + } + connectorStatusMetrics.put(connectorTaskId.connector(), metricGroup); + } + + protected synchronized void recordTaskRemoved(ConnectorTaskId connectorTaskId) { + // Unregister connector task count metric if we remove the last task of the connector + if (tasks.keySet().stream().noneMatch(id -> id.connector().equals(connectorTaskId.connector()))) { + connectorStatusMetrics.get(connectorTaskId.connector()).close(); + connectorStatusMetrics.remove(connectorTaskId.connector()); + } + } + + protected synchronized void close() { + for (MetricGroup 
metricGroup: connectorStatusMetrics.values()) { + metricGroup.close(); + } + } + + protected MetricGroup metricGroup(String connectorId) { + return connectorStatusMetrics.get(connectorId); + } + } + +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/WorkerConfig.java b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/WorkerConfig.java new file mode 100644 index 0000000..73b743b --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/WorkerConfig.java @@ -0,0 +1,492 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.runtime; + +import org.apache.kafka.clients.ClientDnsLookup; +import org.apache.kafka.clients.CommonClientConfigs; +import org.apache.kafka.common.config.AbstractConfig; +import org.apache.kafka.common.config.ConfigDef; +import org.apache.kafka.common.config.ConfigDef.Importance; +import org.apache.kafka.common.config.ConfigDef.Type; +import org.apache.kafka.common.config.ConfigException; +import org.apache.kafka.common.config.internals.BrokerSecurityConfigs; +import org.apache.kafka.common.metrics.Sensor; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.connect.storage.SimpleHeaderConverter; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.regex.Pattern; + +import org.eclipse.jetty.util.StringUtil; + +import static org.apache.kafka.common.config.ConfigDef.Range.atLeast; +import static org.apache.kafka.common.config.ConfigDef.ValidString.in; +import static org.apache.kafka.connect.runtime.SourceConnectorConfig.TOPIC_CREATION_PREFIX; + +/** + * Common base class providing configuration for Kafka Connect workers, whether standalone or distributed. + */ +public class WorkerConfig extends AbstractConfig { + private static final Logger log = LoggerFactory.getLogger(WorkerConfig.class); + + private static final Pattern COMMA_WITH_WHITESPACE = Pattern.compile("\\s*,\\s*"); + private static final Collection HEADER_ACTIONS = Collections.unmodifiableList( + Arrays.asList("set", "add", "setDate", "addDate") + ); + + public static final String BOOTSTRAP_SERVERS_CONFIG = "bootstrap.servers"; + public static final String BOOTSTRAP_SERVERS_DOC + = "A list of host/port pairs to use for establishing the initial connection to the Kafka " + + "cluster. The client will make use of all servers irrespective of which servers are " + + "specified here for bootstrapping—this list only impacts the initial hosts used " + + "to discover the full set of servers. 
This list should be in the form " + + "host1:port1,host2:port2,.... Since these servers are just used for the " + + "initial connection to discover the full cluster membership (which may change " + + "dynamically), this list need not contain the full set of servers (you may want more " + + "than one, though, in case a server is down)."; + public static final String BOOTSTRAP_SERVERS_DEFAULT = "localhost:9092"; + + public static final String CLIENT_DNS_LOOKUP_CONFIG = CommonClientConfigs.CLIENT_DNS_LOOKUP_CONFIG; + public static final String CLIENT_DNS_LOOKUP_DOC = CommonClientConfigs.CLIENT_DNS_LOOKUP_DOC; + + public static final String KEY_CONVERTER_CLASS_CONFIG = "key.converter"; + public static final String KEY_CONVERTER_CLASS_DOC = + "Converter class used to convert between Kafka Connect format and the serialized form that is written to Kafka." + + " This controls the format of the keys in messages written to or read from Kafka, and since this is" + + " independent of connectors it allows any connector to work with any serialization format." + + " Examples of common formats include JSON and Avro."; + + public static final String VALUE_CONVERTER_CLASS_CONFIG = "value.converter"; + public static final String VALUE_CONVERTER_CLASS_DOC = + "Converter class used to convert between Kafka Connect format and the serialized form that is written to Kafka." + + " This controls the format of the values in messages written to or read from Kafka, and since this is" + + " independent of connectors it allows any connector to work with any serialization format." + + " Examples of common formats include JSON and Avro."; + + public static final String HEADER_CONVERTER_CLASS_CONFIG = "header.converter"; + public static final String HEADER_CONVERTER_CLASS_DOC = + "HeaderConverter class used to convert between Kafka Connect format and the serialized form that is written to Kafka." + + " This controls the format of the header values in messages written to or read from Kafka, and since this is" + + " independent of connectors it allows any connector to work with any serialization format." + + " Examples of common formats include JSON and Avro. By default, the SimpleHeaderConverter is used to serialize" + + " header values to strings and deserialize them by inferring the schemas."; + public static final String HEADER_CONVERTER_CLASS_DEFAULT = SimpleHeaderConverter.class.getName(); + + public static final String TASK_SHUTDOWN_GRACEFUL_TIMEOUT_MS_CONFIG + = "task.shutdown.graceful.timeout.ms"; + private static final String TASK_SHUTDOWN_GRACEFUL_TIMEOUT_MS_DOC = + "Amount of time to wait for tasks to shutdown gracefully. This is the total amount of time," + + " not per task. 
Once shutdown is triggered for all tasks, they are then waited on sequentially.";
+    private static final String TASK_SHUTDOWN_GRACEFUL_TIMEOUT_MS_DEFAULT = "5000";
+
+    public static final String OFFSET_COMMIT_INTERVAL_MS_CONFIG = "offset.flush.interval.ms";
+    private static final String OFFSET_COMMIT_INTERVAL_MS_DOC
+            = "Interval at which to try committing offsets for tasks.";
+    public static final long OFFSET_COMMIT_INTERVAL_MS_DEFAULT = 60000L;
+
+    public static final String OFFSET_COMMIT_TIMEOUT_MS_CONFIG = "offset.flush.timeout.ms";
+    private static final String OFFSET_COMMIT_TIMEOUT_MS_DOC
+            = "Maximum number of milliseconds to wait for records to flush and partition offset data to be"
+            + " committed to offset storage before cancelling the process and restoring the offset "
+            + "data to be committed in a future attempt.";
+    public static final long OFFSET_COMMIT_TIMEOUT_MS_DEFAULT = 5000L;
+
+    public static final String LISTENERS_CONFIG = "listeners";
+    private static final String LISTENERS_DOC
+            = "List of comma-separated URIs the REST API will listen on. The supported protocols are HTTP and HTTPS.\n" +
+            " Specify hostname as 0.0.0.0 to bind to all interfaces.\n" +
+            " Leave hostname empty to bind to default interface.\n" +
+            " Examples of legal listener lists: HTTP://myhost:8083,HTTPS://myhost:8084";
+    static final List<String> LISTENERS_DEFAULT = Collections.singletonList("http://:8083");
+
+    public static final String REST_ADVERTISED_HOST_NAME_CONFIG = "rest.advertised.host.name";
+    private static final String REST_ADVERTISED_HOST_NAME_DOC
+            = "If this is set, this is the hostname that will be given out to other workers to connect to.";
+
+    public static final String REST_ADVERTISED_PORT_CONFIG = "rest.advertised.port";
+    private static final String REST_ADVERTISED_PORT_DOC
+            = "If this is set, this is the port that will be given out to other workers to connect to.";
+
+    public static final String REST_ADVERTISED_LISTENER_CONFIG = "rest.advertised.listener";
+    private static final String REST_ADVERTISED_LISTENER_DOC
+            = "Sets the advertised listener (HTTP or HTTPS) which will be given to other workers to use.";
+
+    public static final String ACCESS_CONTROL_ALLOW_ORIGIN_CONFIG = "access.control.allow.origin";
+    protected static final String ACCESS_CONTROL_ALLOW_ORIGIN_DOC =
+            "Value to set the Access-Control-Allow-Origin header to for REST API requests. " +
+            "To enable cross origin access, set this to the domain of the application that should be permitted" +
+            " to access the API, or '*' to allow access from any domain. The default value only allows access" +
+            " from the domain of the REST API.";
+    protected static final String ACCESS_CONTROL_ALLOW_ORIGIN_DEFAULT = "";
+
+    public static final String ACCESS_CONTROL_ALLOW_METHODS_CONFIG = "access.control.allow.methods";
+    protected static final String ACCESS_CONTROL_ALLOW_METHODS_DOC =
+            "Sets the methods supported for cross origin requests by setting the Access-Control-Allow-Methods header. " +
+            "The default value of the Access-Control-Allow-Methods header allows cross origin requests for GET, POST and HEAD.";
+    protected static final String ACCESS_CONTROL_ALLOW_METHODS_DEFAULT = "";
+
+    public static final String ADMIN_LISTENERS_CONFIG = "admin.listeners";
+    protected static final String ADMIN_LISTENERS_DOC = "List of comma-separated URIs the Admin REST API will listen on." +
+            " The supported protocols are HTTP and HTTPS." +
+            " An empty or blank string will disable this feature." +
+            " The default behavior is to use the regular listener (specified by the 'listeners' property).";
+    public static final String ADMIN_LISTENERS_HTTPS_CONFIGS_PREFIX = "admin.listeners.https.";
+
+    public static final String PLUGIN_PATH_CONFIG = "plugin.path";
+    protected static final String PLUGIN_PATH_DOC = "List of paths separated by commas (,) that " +
+            "contain plugins (connectors, converters, transformations). The list should consist" +
+            " of top level directories that include any combination of: \n" +
+            "a) directories immediately containing jars with plugins and their dependencies\n" +
+            "b) uber-jars with plugins and their dependencies\n" +
+            "c) directories immediately containing the package directory structure of classes of " +
+            "plugins and their dependencies\n" +
+            "Note: symlinks will be followed to discover dependencies or plugins.\n" +
+            "Examples: plugin.path=/usr/local/share/java,/usr/local/share/kafka/plugins," +
+            "/opt/connectors\n" +
+            "Do not use config provider variables in this property, since the raw path is used " +
+            "by the worker's scanner before config providers are initialized and used to " +
+            "replace variables.";
+
+    public static final String CONFIG_PROVIDERS_CONFIG = "config.providers";
+    protected static final String CONFIG_PROVIDERS_DOC =
+            "Comma-separated names of ConfigProvider classes, loaded and used " +
+            "in the order specified. Implementing the interface " +
+            "ConfigProvider allows you to replace variable references in connector configurations, " +
+            "such as for externalized secrets. ";
+
+    public static final String REST_EXTENSION_CLASSES_CONFIG = "rest.extension.classes";
+    protected static final String REST_EXTENSION_CLASSES_DOC =
+            "Comma-separated names of ConnectRestExtension classes, loaded and called " +
+            "in the order specified. Implementing the interface " +
+            "ConnectRestExtension allows you to inject into Connect's REST API user defined resources like filters. " +
+            "Typically used to add custom capability like logging, security, etc. ";
+
+    public static final String CONNECTOR_CLIENT_POLICY_CLASS_CONFIG = "connector.client.config.override.policy";
+    public static final String CONNECTOR_CLIENT_POLICY_CLASS_DOC =
+            "Class name or alias of implementation of ConnectorClientConfigOverridePolicy. Defines what client configurations can be " +
+            "overridden by the connector. The default implementation is `All`, meaning connector configurations can override all client properties. 
" + + "The other possible policies in the framework include `None` to disallow connectors from overriding client properties, " + + "and `Principal` to allow connectors to override only client principals."; + public static final String CONNECTOR_CLIENT_POLICY_CLASS_DEFAULT = "All"; + + + public static final String METRICS_SAMPLE_WINDOW_MS_CONFIG = CommonClientConfigs.METRICS_SAMPLE_WINDOW_MS_CONFIG; + public static final String METRICS_NUM_SAMPLES_CONFIG = CommonClientConfigs.METRICS_NUM_SAMPLES_CONFIG; + public static final String METRICS_RECORDING_LEVEL_CONFIG = CommonClientConfigs.METRICS_RECORDING_LEVEL_CONFIG; + public static final String METRIC_REPORTER_CLASSES_CONFIG = CommonClientConfigs.METRIC_REPORTER_CLASSES_CONFIG; + + public static final String TOPIC_TRACKING_ENABLE_CONFIG = "topic.tracking.enable"; + protected static final String TOPIC_TRACKING_ENABLE_DOC = "Enable tracking the set of active " + + "topics per connector during runtime."; + protected static final boolean TOPIC_TRACKING_ENABLE_DEFAULT = true; + + public static final String TOPIC_TRACKING_ALLOW_RESET_CONFIG = "topic.tracking.allow.reset"; + protected static final String TOPIC_TRACKING_ALLOW_RESET_DOC = "If set to true, it allows " + + "user requests to reset the set of active topics per connector."; + protected static final boolean TOPIC_TRACKING_ALLOW_RESET_DEFAULT = true; + + public static final String CONNECT_KAFKA_CLUSTER_ID = "connect.kafka.cluster.id"; + public static final String CONNECT_GROUP_ID = "connect.group.id"; + + public static final String TOPIC_CREATION_ENABLE_CONFIG = "topic.creation.enable"; + protected static final String TOPIC_CREATION_ENABLE_DOC = "Whether to allow " + + "automatic creation of topics used by source connectors, when source connectors " + + "are configured with `" + TOPIC_CREATION_PREFIX + "` properties. Each task will use an " + + "admin client to create its topics and will not depend on the Kafka brokers " + + "to create topics automatically."; + protected static final boolean TOPIC_CREATION_ENABLE_DEFAULT = true; + + public static final String RESPONSE_HTTP_HEADERS_CONFIG = "response.http.headers.config"; + protected static final String RESPONSE_HTTP_HEADERS_DOC = "Rules for REST API HTTP response headers"; + protected static final String RESPONSE_HTTP_HEADERS_DEFAULT = ""; + + /** + * Get a basic ConfigDef for a WorkerConfig. This includes all the common settings. Subclasses can use this to + * bootstrap their own ConfigDef. 
+ * @return a ConfigDef with all the common options specified + */ + protected static ConfigDef baseConfigDef() { + return new ConfigDef() + .define(BOOTSTRAP_SERVERS_CONFIG, Type.LIST, BOOTSTRAP_SERVERS_DEFAULT, + Importance.HIGH, BOOTSTRAP_SERVERS_DOC) + .define(CLIENT_DNS_LOOKUP_CONFIG, + Type.STRING, + ClientDnsLookup.USE_ALL_DNS_IPS.toString(), + in(ClientDnsLookup.USE_ALL_DNS_IPS.toString(), + ClientDnsLookup.RESOLVE_CANONICAL_BOOTSTRAP_SERVERS_ONLY.toString()), + Importance.MEDIUM, + CLIENT_DNS_LOOKUP_DOC) + .define(KEY_CONVERTER_CLASS_CONFIG, Type.CLASS, + Importance.HIGH, KEY_CONVERTER_CLASS_DOC) + .define(VALUE_CONVERTER_CLASS_CONFIG, Type.CLASS, + Importance.HIGH, VALUE_CONVERTER_CLASS_DOC) + .define(TASK_SHUTDOWN_GRACEFUL_TIMEOUT_MS_CONFIG, Type.LONG, + TASK_SHUTDOWN_GRACEFUL_TIMEOUT_MS_DEFAULT, Importance.LOW, + TASK_SHUTDOWN_GRACEFUL_TIMEOUT_MS_DOC) + .define(OFFSET_COMMIT_INTERVAL_MS_CONFIG, Type.LONG, OFFSET_COMMIT_INTERVAL_MS_DEFAULT, + Importance.LOW, OFFSET_COMMIT_INTERVAL_MS_DOC) + .define(OFFSET_COMMIT_TIMEOUT_MS_CONFIG, Type.LONG, OFFSET_COMMIT_TIMEOUT_MS_DEFAULT, + Importance.LOW, OFFSET_COMMIT_TIMEOUT_MS_DOC) + .define(LISTENERS_CONFIG, Type.LIST, LISTENERS_DEFAULT, new ListenersValidator(), Importance.LOW, LISTENERS_DOC) + .define(REST_ADVERTISED_HOST_NAME_CONFIG, Type.STRING, null, Importance.LOW, REST_ADVERTISED_HOST_NAME_DOC) + .define(REST_ADVERTISED_PORT_CONFIG, Type.INT, null, Importance.LOW, REST_ADVERTISED_PORT_DOC) + .define(REST_ADVERTISED_LISTENER_CONFIG, Type.STRING, null, Importance.LOW, REST_ADVERTISED_LISTENER_DOC) + .define(ACCESS_CONTROL_ALLOW_ORIGIN_CONFIG, Type.STRING, + ACCESS_CONTROL_ALLOW_ORIGIN_DEFAULT, Importance.LOW, + ACCESS_CONTROL_ALLOW_ORIGIN_DOC) + .define(ACCESS_CONTROL_ALLOW_METHODS_CONFIG, Type.STRING, + ACCESS_CONTROL_ALLOW_METHODS_DEFAULT, Importance.LOW, + ACCESS_CONTROL_ALLOW_METHODS_DOC) + .define(PLUGIN_PATH_CONFIG, + Type.LIST, + null, + Importance.LOW, + PLUGIN_PATH_DOC) + .define(METRICS_SAMPLE_WINDOW_MS_CONFIG, Type.LONG, + 30000, atLeast(0), Importance.LOW, + CommonClientConfigs.METRICS_SAMPLE_WINDOW_MS_DOC) + .define(METRICS_NUM_SAMPLES_CONFIG, Type.INT, + 2, atLeast(1), Importance.LOW, + CommonClientConfigs.METRICS_NUM_SAMPLES_DOC) + .define(METRICS_RECORDING_LEVEL_CONFIG, Type.STRING, + Sensor.RecordingLevel.INFO.toString(), + in(Sensor.RecordingLevel.INFO.toString(), Sensor.RecordingLevel.DEBUG.toString()), + Importance.LOW, + CommonClientConfigs.METRICS_RECORDING_LEVEL_DOC) + .define(METRIC_REPORTER_CLASSES_CONFIG, Type.LIST, + "", Importance.LOW, + CommonClientConfigs.METRIC_REPORTER_CLASSES_DOC) + .define(BrokerSecurityConfigs.SSL_CLIENT_AUTH_CONFIG, + ConfigDef.Type.STRING, "none", ConfigDef.Importance.LOW, BrokerSecurityConfigs.SSL_CLIENT_AUTH_DOC) + .define(HEADER_CONVERTER_CLASS_CONFIG, Type.CLASS, + HEADER_CONVERTER_CLASS_DEFAULT, + Importance.LOW, HEADER_CONVERTER_CLASS_DOC) + .define(CONFIG_PROVIDERS_CONFIG, Type.LIST, + Collections.emptyList(), + Importance.LOW, CONFIG_PROVIDERS_DOC) + .define(REST_EXTENSION_CLASSES_CONFIG, Type.LIST, "", + Importance.LOW, REST_EXTENSION_CLASSES_DOC) + .define(ADMIN_LISTENERS_CONFIG, Type.LIST, null, + new AdminListenersValidator(), Importance.LOW, ADMIN_LISTENERS_DOC) + .define(CONNECTOR_CLIENT_POLICY_CLASS_CONFIG, Type.STRING, CONNECTOR_CLIENT_POLICY_CLASS_DEFAULT, + Importance.MEDIUM, CONNECTOR_CLIENT_POLICY_CLASS_DOC) + .define(TOPIC_TRACKING_ENABLE_CONFIG, Type.BOOLEAN, TOPIC_TRACKING_ENABLE_DEFAULT, + Importance.LOW, TOPIC_TRACKING_ENABLE_DOC) + 
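+                // The options defined in this chain are what a worker properties file supplies. Since
+                // key.converter and value.converter are declared above without defaults, a minimal worker
+                // file has to provide at least those two; for example (illustrative values only):
+                //   bootstrap.servers=localhost:9092
+                //   key.converter=org.apache.kafka.connect.json.JsonConverter
+                //   value.converter=org.apache.kafka.connect.json.JsonConverter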
.define(TOPIC_TRACKING_ALLOW_RESET_CONFIG, Type.BOOLEAN, TOPIC_TRACKING_ALLOW_RESET_DEFAULT,
+                        Importance.LOW, TOPIC_TRACKING_ALLOW_RESET_DOC)
+                .define(TOPIC_CREATION_ENABLE_CONFIG, Type.BOOLEAN, TOPIC_CREATION_ENABLE_DEFAULT, Importance.LOW,
+                        TOPIC_CREATION_ENABLE_DOC)
+                .define(RESPONSE_HTTP_HEADERS_CONFIG, Type.STRING, RESPONSE_HTTP_HEADERS_DEFAULT,
+                        new ResponseHttpHeadersValidator(), Importance.LOW, RESPONSE_HTTP_HEADERS_DOC)
+                // security support
+                .withClientSslSupport();
+    }
+
+    private void logInternalConverterRemovalWarnings(Map<String, String> props) {
+        List<String> removedProperties = new ArrayList<>();
+        for (String property : Arrays.asList("internal.key.converter", "internal.value.converter")) {
+            if (props.containsKey(property)) {
+                removedProperties.add(property);
+            }
+            removedProperties.addAll(originalsWithPrefix(property + ".").keySet());
+        }
+        if (!removedProperties.isEmpty()) {
+            log.warn(
+                    "The worker has been configured with one or more internal converter properties ({}). " +
+                    "Support for these properties was deprecated in version 2.0 and removed in version 3.0, " +
+                    "and specifying them will have no effect. " +
+                    "Instead, an instance of the JsonConverter with schemas.enable " +
+                    "set to false will be used. For more information, please visit " +
+                    "http://kafka.apache.org/documentation/#upgrade and consult the upgrade notes " +
+                    "for the 3.0 release.",
+                    removedProperties);
+        }
+    }
+
+    private void logPluginPathConfigProviderWarning(Map<String, String> rawOriginals) {
+        String rawPluginPath = rawOriginals.get(PLUGIN_PATH_CONFIG);
+        // Can't use AbstractConfig::originalsStrings here since some values may be null, which
+        // causes that method to fail
+        String transformedPluginPath = Objects.toString(originals().get(PLUGIN_PATH_CONFIG));
+        if (!Objects.equals(rawPluginPath, transformedPluginPath)) {
+            log.warn(
+                    "Variables cannot be used in the 'plugin.path' property, since the property is " +
+                    "used by plugin scanning before the config providers that replace the " +
+                    "variables are initialized. The raw value '{}' was used for plugin scanning, as " +
+                    "opposed to the transformed value '{}', and this may cause unexpected results.",
+                    rawPluginPath,
+                    transformedPluginPath
+            );
+        }
+    }
+
+    public Integer getRebalanceTimeout() {
+        return null;
+    }
+
+    public boolean topicCreationEnable() {
+        return getBoolean(TOPIC_CREATION_ENABLE_CONFIG);
+    }
+
+    @Override
+    protected Map<String, Object> postProcessParsedConfig(final Map<String, Object> parsedValues) {
+        return CommonClientConfigs.postProcessReconnectBackoffConfigs(this, parsedValues);
+    }
+
+    public static List<String> pluginLocations(Map<String, String> props) {
+        String locationList = props.get(WorkerConfig.PLUGIN_PATH_CONFIG);
+        return locationList == null
+                ? new ArrayList<>()
+                : Arrays.asList(COMMA_WITH_WHITESPACE.split(locationList.trim(), -1));
+    }
+
+    public WorkerConfig(ConfigDef definition, Map<String, String> props) {
+        super(definition, props);
+        logInternalConverterRemovalWarnings(props);
+        logPluginPathConfigProviderWarning(props);
+    }
+
+    // Visible for testing
+    static void validateHttpResponseHeaderConfig(String config) {
+        try {
+            // validate format
+            String[] configTokens = config.trim().split("\\s+", 2);
+            if (configTokens.length != 2) {
+                throw new ConfigException(String.format("Invalid format of header config '%s\'. " +
+                        "Expected: '[action] [header name]:[header value]'", config));
+            }
+
+            // validate action
+            String method = configTokens[0].trim();
+            validateHeaderConfigAction(method);
+
+            // validate header name and header value pair
+            String header = configTokens[1];
+            String[] headerTokens = header.trim().split(":");
+            if (headerTokens.length != 2) {
+                throw new ConfigException(
+                        String.format("Invalid format of header name and header value pair '%s'. " +
+                                "Expected: '[header name]:[header value]'", header));
+            }
+
+            // validate header name
+            String headerName = headerTokens[0].trim();
+            if (headerName.isEmpty() || headerName.matches(".*\\s+.*")) {
+                throw new ConfigException(String.format("Invalid header name '%s'. " +
+                        "The '[header name]' cannot contain whitespace", headerName));
+            }
+        } catch (ArrayIndexOutOfBoundsException e) {
+            throw new ConfigException(String.format("Invalid header config '%s'.", config), e);
+        }
+    }
+
+    // Visible for testing
+    static void validateHeaderConfigAction(String action) {
+        if (!HEADER_ACTIONS.stream().anyMatch(action::equalsIgnoreCase)) {
+            throw new ConfigException(String.format("Invalid header config action: '%s'. " +
+                    "Expected one of %s", action, HEADER_ACTIONS));
+        }
+    }
+
+    private static class ListenersValidator implements ConfigDef.Validator {
+        @Override
+        public void ensureValid(String name, Object value) {
+            if (!(value instanceof List)) {
+                throw new ConfigException("Invalid value type for listeners (expected list of URLs, ex: http://localhost:8080,https://localhost:8443).");
+            }
+
+            List<?> items = (List<?>) value;
+            if (items.isEmpty()) {
+                throw new ConfigException("Invalid value for listeners, at least one URL is expected, ex: http://localhost:8080,https://localhost:8443.");
+            }
+
+            for (Object item : items) {
+                if (!(item instanceof String)) {
+                    throw new ConfigException("Invalid type for listeners (expected String).");
+                }
+                if (Utils.isBlank((String) item)) {
+                    throw new ConfigException("Empty URL found when parsing listeners list.");
+                }
+            }
+        }
+
+        @Override
+        public String toString() {
+            return "List of comma-separated URLs, ex: http://localhost:8080,https://localhost:8443.";
+        }
+    }
+
+    private static class AdminListenersValidator implements ConfigDef.Validator {
+        @Override
+        public void ensureValid(String name, Object value) {
+            if (value == null) {
+                return;
+            }
+
+            if (!(value instanceof List)) {
+                throw new ConfigException("Invalid value type for admin.listeners (expected list).");
+            }
+
+            List<?> items = (List<?>) value;
+            if (items.isEmpty()) {
+                return;
+            }
+
+            for (Object item : items) {
+                if (!(item instanceof String)) {
+                    throw new ConfigException("Invalid type for admin.listeners (expected String).");
+                }
+                if (Utils.isBlank((String) item)) {
+                    throw new ConfigException("Empty URL found when parsing admin.listeners list.");
+                }
+            }
+        }
+
+        @Override
+        public String toString() {
+            return "List of comma-separated URLs, ex: http://localhost:8080,https://localhost:8443.";
+        }
+    }
+
+    private static class ResponseHttpHeadersValidator implements ConfigDef.Validator {
+        @Override
+        public void ensureValid(String name, Object value) {
+            String strValue = (String) value;
+            if (Utils.isBlank(strValue)) {
+                return;
+            }
+
+            String[] configs = StringUtil.csvSplit(strValue); // handles and removes surrounding quotes
+            Arrays.stream(configs).forEach(WorkerConfig::validateHttpResponseHeaderConfig);
+        }
+
+        @Override
+        public String toString() {
+            return "Comma-separated header rules, where each header rule is of the form " + "'[action] [header 
name]:[header value]' and optionally surrounded by double quotes " + + "if any part of a header rule contains a comma"; + } + } +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/WorkerConfigTransformer.java b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/WorkerConfigTransformer.java new file mode 100644 index 0000000..4d9c4c1 --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/WorkerConfigTransformer.java @@ -0,0 +1,106 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.runtime; + +import org.apache.kafka.common.config.ConfigDef; +import org.apache.kafka.common.config.provider.ConfigProvider; +import org.apache.kafka.common.config.ConfigTransformer; +import org.apache.kafka.common.config.ConfigTransformerResult; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.connect.runtime.Herder.ConfigReloadAction; +import org.apache.kafka.connect.util.Callback; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Locale; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; + +/** + * A wrapper class to perform configuration transformations and schedule reloads for any + * retrieved TTL values. + */ +public class WorkerConfigTransformer implements AutoCloseable { + private static final Logger log = LoggerFactory.getLogger(WorkerConfigTransformer.class); + + private final Worker worker; + private final ConfigTransformer configTransformer; + private final ConcurrentMap> requests = new ConcurrentHashMap<>(); + private final Map configProviders; + + public WorkerConfigTransformer(Worker worker, Map configProviders) { + this.worker = worker; + this.configProviders = configProviders; + this.configTransformer = new ConfigTransformer(configProviders); + } + + public Map transform(Map configs) { + return transform(null, configs); + } + + public Map transform(String connectorName, Map configs) { + if (configs == null) return null; + ConfigTransformerResult result = configTransformer.transform(configs); + if (connectorName != null) { + String key = ConnectorConfig.CONFIG_RELOAD_ACTION_CONFIG; + String action = (String) ConfigDef.parseType(key, configs.get(key), ConfigDef.Type.STRING); + if (action == null) { + // The default action is "restart". 
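+                // A connector that references an externalized value, e.g. "my.password=${vault:connect/secrets:password}"
+                // (hypothetical provider, path and key), gets a restart scheduled below whenever the provider returns a
+                // TTL for the resolved value, so the secret is re-resolved once it expires. Connectors can opt out of
+                // this by setting config.reload.action=none in their configuration.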
+ action = ConnectorConfig.CONFIG_RELOAD_ACTION_RESTART; + } + ConfigReloadAction reloadAction = ConfigReloadAction.valueOf(action.toUpperCase(Locale.ROOT)); + if (reloadAction == ConfigReloadAction.RESTART) { + scheduleReload(connectorName, result.ttls()); + } + } + return result.data(); + } + + private void scheduleReload(String connectorName, Map ttls) { + for (Map.Entry entry : ttls.entrySet()) { + scheduleReload(connectorName, entry.getKey(), entry.getValue()); + } + } + + private void scheduleReload(String connectorName, String path, long ttl) { + Map connectorRequests = requests.get(connectorName); + if (connectorRequests == null) { + connectorRequests = new ConcurrentHashMap<>(); + requests.put(connectorName, connectorRequests); + } else { + HerderRequest previousRequest = connectorRequests.get(path); + if (previousRequest != null) { + // Delete previous request for ttl which is now stale + previousRequest.cancel(); + } + } + log.info("Scheduling a restart of connector {} in {} ms", connectorName, ttl); + Callback cb = (error, result) -> { + if (error != null) { + log.error("Unexpected error during connector restart: ", error); + } + }; + HerderRequest request = worker.herder().restartConnector(ttl, connectorName, cb); + connectorRequests.put(path, request); + } + + @Override + public void close() { + configProviders.values().forEach(x -> Utils.closeQuietly(x, "config provider")); + } +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/WorkerConnector.java b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/WorkerConnector.java new file mode 100644 index 0000000..09b57fd --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/WorkerConnector.java @@ -0,0 +1,533 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.connect.runtime; + +import org.apache.kafka.connect.connector.Connector; +import org.apache.kafka.connect.connector.ConnectorContext; +import org.apache.kafka.connect.errors.ConnectException; +import org.apache.kafka.connect.runtime.ConnectMetrics.MetricGroup; +import org.apache.kafka.connect.runtime.isolation.Plugins; +import org.apache.kafka.connect.sink.SinkConnectorContext; +import org.apache.kafka.connect.source.SourceConnectorContext; +import org.apache.kafka.connect.storage.OffsetStorageReader; +import org.apache.kafka.connect.util.Callback; +import org.apache.kafka.connect.util.ConnectUtils; +import org.apache.kafka.connect.util.LoggingContext; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Locale; +import java.util.Map; +import java.util.Objects; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; + +/** + * Container for connectors which is responsible for managing their lifecycle (e.g. handling startup, + * shutdown, pausing, etc.). Internally, we manage the runtime state of the connector and transition according + * to target state changes. Note that unlike connector tasks, the connector does not really have a "pause" + * state which is distinct from being stopped. We therefore treat pause operations as requests to momentarily + * stop the connector, and resume operations as requests to restart it (without reinitialization). Connector + * failures, whether in initialization or after startup, are treated as fatal, which means that we will not attempt + * to restart this connector instance after failure. What this means from a user perspective is that you must + * use the /restart REST API to restart a failed task. This behavior is consistent with task failures. + * + * Note that this class is NOT thread-safe. + */ +public class WorkerConnector implements Runnable { + private static final Logger log = LoggerFactory.getLogger(WorkerConnector.class); + private static final String THREAD_NAME_PREFIX = "connector-thread-"; + + private enum State { + INIT, // initial state before startup + STOPPED, // the connector has been stopped/paused. + STARTED, // the connector has been started/resumed. + FAILED, // the connector has failed (no further transitions are possible after this state) + } + + private final String connName; + private final Map config; + private final ConnectorStatus.Listener statusListener; + private final ClassLoader loader; + private final CloseableConnectorContext ctx; + private final Connector connector; + private final ConnectorMetricsGroup metrics; + private final AtomicReference pendingTargetStateChange; + private final AtomicReference> pendingStateChangeCallback; + private final CountDownLatch shutdownLatch; + private volatile boolean stopping; // indicates whether the Worker has asked the connector to stop + private volatile boolean cancelled; // indicates whether the Worker has cancelled the connector (e.g. 
because of slow shutdown) + + private State state; + private final OffsetStorageReader offsetStorageReader; + + public WorkerConnector(String connName, + Connector connector, + ConnectorConfig connectorConfig, + CloseableConnectorContext ctx, + ConnectMetrics metrics, + ConnectorStatus.Listener statusListener, + OffsetStorageReader offsetStorageReader, + ClassLoader loader) { + this.connName = connName; + this.config = connectorConfig.originalsStrings(); + this.loader = loader; + this.ctx = ctx; + this.connector = connector; + this.state = State.INIT; + this.metrics = new ConnectorMetricsGroup(metrics, AbstractStatus.State.UNASSIGNED, statusListener); + this.statusListener = this.metrics; + this.offsetStorageReader = offsetStorageReader; + this.pendingTargetStateChange = new AtomicReference<>(); + this.pendingStateChangeCallback = new AtomicReference<>(); + this.shutdownLatch = new CountDownLatch(1); + this.stopping = false; + this.cancelled = false; + } + + public ClassLoader loader() { + return loader; + } + + @Override + public void run() { + // Clear all MDC parameters, in case this thread is being reused + LoggingContext.clear(); + + try (LoggingContext loggingContext = LoggingContext.forConnector(connName)) { + ClassLoader savedLoader = Plugins.compareAndSwapLoaders(loader); + String savedName = Thread.currentThread().getName(); + try { + Thread.currentThread().setName(THREAD_NAME_PREFIX + connName); + doRun(); + } finally { + Thread.currentThread().setName(savedName); + Plugins.compareAndSwapLoaders(savedLoader); + } + } finally { + // In the rare case of an exception being thrown outside the doRun() method, or an + // uncaught one being thrown from within it, mark the connector as shut down to avoid + // unnecessarily blocking and eventually timing out during awaitShutdown + shutdownLatch.countDown(); + } + } + + void doRun() { + initialize(); + while (!stopping) { + TargetState newTargetState; + Callback stateChangeCallback; + synchronized (this) { + newTargetState = pendingTargetStateChange.getAndSet(null); + stateChangeCallback = pendingStateChangeCallback.getAndSet(null); + } + if (newTargetState != null && !stopping) { + doTransitionTo(newTargetState, stateChangeCallback); + } + synchronized (this) { + if (pendingTargetStateChange.get() != null || stopping) { + // An update occurred before we entered the synchronized block; no big deal, + // just start the loop again until we've handled everything + } else { + try { + wait(); + } catch (InterruptedException e) { + // We'll pick up any potential state changes at the top of the loop + } + } + } + } + doShutdown(); + } + + void initialize() { + try { + if (!isSourceConnector() && !isSinkConnector()) { + throw new ConnectException("Connector implementations must be a subclass of either SourceConnector or SinkConnector"); + } + log.debug("{} Initializing connector {}", this, connName); + if (isSinkConnector()) { + SinkConnectorConfig.validate(config); + connector.initialize(new WorkerSinkConnectorContext()); + } else { + connector.initialize(new WorkerSourceConnectorContext(offsetStorageReader)); + } + } catch (Throwable t) { + log.error("{} Error initializing connector", this, t); + onFailure(t); + } + } + + private boolean doStart() throws Throwable { + try { + switch (state) { + case STARTED: + return false; + + case INIT: + case STOPPED: + connector.start(config); + this.state = State.STARTED; + return true; + + default: + throw new IllegalArgumentException("Cannot start connector in state " + state); + } + } catch (Throwable 
t) { + log.error("{} Error while starting connector", this, t); + onFailure(t); + throw t; + } + } + + private void onFailure(Throwable t) { + statusListener.onFailure(connName, t); + this.state = State.FAILED; + } + + private void resume() throws Throwable { + if (doStart()) + statusListener.onResume(connName); + } + + private void start() throws Throwable { + if (doStart()) + statusListener.onStartup(connName); + } + + public boolean isRunning() { + return state == State.STARTED; + } + + @SuppressWarnings("fallthrough") + private void pause() { + try { + switch (state) { + case STOPPED: + return; + + case STARTED: + connector.stop(); + // fall through + + case INIT: + statusListener.onPause(connName); + this.state = State.STOPPED; + break; + + default: + throw new IllegalArgumentException("Cannot pause connector in state " + state); + } + } catch (Throwable t) { + log.error("{} Error while shutting down connector", this, t); + statusListener.onFailure(connName, t); + this.state = State.FAILED; + } + } + + /** + * Stop this connector. This method does not block, it only triggers shutdown. Use + * #{@link #awaitShutdown} to block until completion. + */ + public synchronized void shutdown() { + log.info("Scheduled shutdown for {}", this); + stopping = true; + notify(); + } + + void doShutdown() { + try { + TargetState preEmptedState = pendingTargetStateChange.getAndSet(null); + Callback stateChangeCallback = pendingStateChangeCallback.getAndSet(null); + if (stateChangeCallback != null) { + stateChangeCallback.onCompletion( + new ConnectException( + "Could not begin changing connector state to " + preEmptedState.name() + + " as the connector has been scheduled for shutdown"), + null); + } + if (state == State.STARTED) + connector.stop(); + this.state = State.STOPPED; + statusListener.onShutdown(connName); + log.info("Completed shutdown for {}", this); + } catch (Throwable t) { + log.error("{} Error while shutting down connector", this, t); + state = State.FAILED; + statusListener.onFailure(connName, t); + } finally { + ctx.close(); + metrics.close(); + } + } + + public synchronized void cancel() { + // Proactively update the status of the connector to UNASSIGNED since this connector + // instance is being abandoned and we won't update the status on its behalf any more + // after this since a new instance may be started soon + statusListener.onShutdown(connName); + ctx.close(); + cancelled = true; + } + + /** + * Wait for this connector to finish shutting down. 
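+     * A typical caller requests the stop first and then bounds the wait, for example (illustrative only):
+     * {@code connector.shutdown(); if (!connector.awaitShutdown(timeoutMs)) connector.cancel();}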
+ * + * @param timeoutMs time in milliseconds to await shutdown + * @return true if successful, false if the timeout was reached + */ + public boolean awaitShutdown(long timeoutMs) { + try { + return shutdownLatch.await(timeoutMs, TimeUnit.MILLISECONDS); + } catch (InterruptedException e) { + return false; + } + } + + public void transitionTo(TargetState targetState, Callback stateChangeCallback) { + Callback preEmptedStateChangeCallback; + TargetState preEmptedState; + synchronized (this) { + preEmptedStateChangeCallback = pendingStateChangeCallback.getAndSet(stateChangeCallback); + preEmptedState = pendingTargetStateChange.getAndSet(targetState); + notify(); + } + if (preEmptedStateChangeCallback != null) { + preEmptedStateChangeCallback.onCompletion( + new ConnectException( + "Could not begin changing connector state to " + preEmptedState.name() + + " before another request to change state was made;" + + " the new request (which is to change the state to " + targetState.name() + + ") has pre-empted this one"), + null + ); + } + } + + void doTransitionTo(TargetState targetState, Callback stateChangeCallback) { + if (state == State.FAILED) { + stateChangeCallback.onCompletion( + new ConnectException(this + " Cannot transition connector to " + targetState + " since it has failed"), + null); + return; + } + + try { + doTransitionTo(targetState); + stateChangeCallback.onCompletion(null, targetState); + } catch (Throwable t) { + stateChangeCallback.onCompletion( + new ConnectException( + "Failed to transition connector " + connName + " to state " + targetState, + t), + null); + } + } + + private void doTransitionTo(TargetState targetState) throws Throwable { + log.debug("{} Transition connector to {}", this, targetState); + if (targetState == TargetState.PAUSED) { + pause(); + } else if (targetState == TargetState.STARTED) { + if (state == State.INIT) + start(); + else + resume(); + } else { + throw new IllegalArgumentException("Unhandled target state " + targetState); + } + } + + public boolean isSinkConnector() { + return ConnectUtils.isSinkConnector(connector); + } + + public boolean isSourceConnector() { + return ConnectUtils.isSourceConnector(connector); + } + + protected String connectorType() { + if (isSinkConnector()) + return "sink"; + if (isSourceConnector()) + return "source"; + return "unknown"; + } + + public Connector connector() { + return connector; + } + + ConnectorMetricsGroup metrics() { + return metrics; + } + + @Override + public String toString() { + return "WorkerConnector{" + + "id=" + connName + + '}'; + } + + class ConnectorMetricsGroup implements ConnectorStatus.Listener, AutoCloseable { + /** + * Use {@link AbstractStatus.State} since it has all of the states we want, + * unlike {@link WorkerConnector.State}. + */ + private volatile AbstractStatus.State state; + private final MetricGroup metricGroup; + private final ConnectorStatus.Listener delegate; + + public ConnectorMetricsGroup(ConnectMetrics connectMetrics, AbstractStatus.State initialState, ConnectorStatus.Listener delegate) { + Objects.requireNonNull(connectMetrics); + Objects.requireNonNull(connector); + Objects.requireNonNull(initialState); + Objects.requireNonNull(delegate); + this.delegate = delegate; + this.state = initialState; + ConnectMetricsRegistry registry = connectMetrics.registry(); + this.metricGroup = connectMetrics.group(registry.connectorGroupName(), + registry.connectorTagName(), connName); + // prevent collisions by removing any previously created metrics in this group. 
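+            // The same connector name may be brought up again on this worker (e.g. after a rebalance or an
+            // explicit restart) before the previous instance's metrics have been cleaned up, so closing the
+            // group first avoids registering duplicate metric names for this connector.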
+ metricGroup.close(); + + metricGroup.addImmutableValueMetric(registry.connectorType, connectorType()); + metricGroup.addImmutableValueMetric(registry.connectorClass, connector.getClass().getName()); + metricGroup.addImmutableValueMetric(registry.connectorVersion, connector.version()); + metricGroup.addValueMetric(registry.connectorStatus, now -> state.toString().toLowerCase(Locale.getDefault())); + } + + public void close() { + metricGroup.close(); + } + + @Override + public void onStartup(String connector) { + state = AbstractStatus.State.RUNNING; + synchronized (this) { + if (!cancelled) { + delegate.onStartup(connector); + } + } + } + + @Override + public void onShutdown(String connector) { + state = AbstractStatus.State.UNASSIGNED; + synchronized (this) { + if (!cancelled) { + delegate.onShutdown(connector); + } + } + } + + @Override + public void onPause(String connector) { + state = AbstractStatus.State.PAUSED; + synchronized (this) { + if (!cancelled) { + delegate.onPause(connector); + } + } + } + + @Override + public void onResume(String connector) { + state = AbstractStatus.State.RUNNING; + synchronized (this) { + if (!cancelled) { + delegate.onResume(connector); + } + } + } + + @Override + public void onFailure(String connector, Throwable cause) { + state = AbstractStatus.State.FAILED; + synchronized (this) { + if (!cancelled) { + delegate.onFailure(connector, cause); + } + } + } + + @Override + public void onDeletion(String connector) { + state = AbstractStatus.State.DESTROYED; + delegate.onDeletion(connector); + } + + @Override + public void onRestart(String connector) { + state = AbstractStatus.State.RESTARTING; + delegate.onRestart(connector); + } + + boolean isUnassigned() { + return state == AbstractStatus.State.UNASSIGNED; + } + + boolean isRunning() { + return state == AbstractStatus.State.RUNNING; + } + + boolean isPaused() { + return state == AbstractStatus.State.PAUSED; + } + + boolean isFailed() { + return state == AbstractStatus.State.FAILED; + } + + protected MetricGroup metricGroup() { + return metricGroup; + } + } + + private abstract class WorkerConnectorContext implements ConnectorContext { + + @Override + public void requestTaskReconfiguration() { + WorkerConnector.this.ctx.requestTaskReconfiguration(); + } + + @Override + public void raiseError(Exception e) { + log.error("{} Connector raised an error", WorkerConnector.this, e); + onFailure(e); + WorkerConnector.this.ctx.raiseError(e); + } + } + + private class WorkerSinkConnectorContext extends WorkerConnectorContext implements SinkConnectorContext { + } + + private class WorkerSourceConnectorContext extends WorkerConnectorContext implements SourceConnectorContext { + + private final OffsetStorageReader offsetStorageReader; + + WorkerSourceConnectorContext(OffsetStorageReader offsetStorageReader) { + this.offsetStorageReader = offsetStorageReader; + } + + @Override + public OffsetStorageReader offsetStorageReader() { + return offsetStorageReader; + } + } +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/WorkerInfo.java b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/WorkerInfo.java new file mode 100644 index 0000000..7d13226 --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/WorkerInfo.java @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.runtime; + +import org.apache.kafka.common.utils.Utils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.lang.management.ManagementFactory; +import java.lang.management.OperatingSystemMXBean; +import java.lang.management.RuntimeMXBean; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; + +/** + * Connect Worker system and runtime information. + */ +public class WorkerInfo { + private static final Logger log = LoggerFactory.getLogger(WorkerInfo.class); + private static final RuntimeMXBean RUNTIME; + private static final OperatingSystemMXBean OS; + + static { + RUNTIME = ManagementFactory.getRuntimeMXBean(); + OS = ManagementFactory.getOperatingSystemMXBean(); + } + + private final Map values; + + /** + * Constructor. + */ + public WorkerInfo() { + this.values = new LinkedHashMap<>(); + addRuntimeInfo(); + addSystemInfo(); + } + + /** + * Log the values of this object at level INFO. + */ + // Equivalent to logAll in AbstractConfig + public void logAll() { + StringBuilder b = new StringBuilder(); + b.append(getClass().getSimpleName()); + b.append(" values: "); + b.append(Utils.NL); + + for (Map.Entry entry : values.entrySet()) { + b.append('\t'); + b.append(entry.getKey()); + b.append(" = "); + b.append(format(entry.getValue())); + b.append(Utils.NL); + } + log.info(b.toString()); + } + + private static Object format(Object value) { + return value == null ? "NA" : value; + } + + /** + * Collect general runtime information. + */ + protected void addRuntimeInfo() { + List jvmArgs = RUNTIME.getInputArguments(); + values.put("jvm.args", Utils.join(jvmArgs, ", ")); + String[] jvmSpec = { + RUNTIME.getVmVendor(), + RUNTIME.getVmName(), + RUNTIME.getSystemProperties().get("java.version"), + RUNTIME.getVmVersion() + }; + values.put("jvm.spec", Utils.join(jvmSpec, ", ")); + values.put("jvm.classpath", RUNTIME.getClassPath()); + } + + /** + * Collect system information. + */ + protected void addSystemInfo() { + String[] osInfo = { + OS.getName(), + OS.getArch(), + OS.getVersion(), + }; + values.put("os.spec", Utils.join(osInfo, ", ")); + values.put("os.vcpus", String.valueOf(OS.getAvailableProcessors())); + } + +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/WorkerMetricsGroup.java b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/WorkerMetricsGroup.java new file mode 100644 index 0000000..f03bc4f --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/WorkerMetricsGroup.java @@ -0,0 +1,214 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.runtime; + +import org.apache.kafka.common.MetricName; +import org.apache.kafka.common.metrics.Sensor; +import org.apache.kafka.common.metrics.stats.CumulativeSum; +import org.apache.kafka.common.metrics.stats.Frequencies; +import org.apache.kafka.connect.util.ConnectorTaskId; + +import java.util.Map; + +class WorkerMetricsGroup { + private final ConnectMetrics.MetricGroup metricGroup; + private final Sensor connectorStartupAttempts; + private final Sensor connectorStartupSuccesses; + private final Sensor connectorStartupFailures; + private final Sensor connectorStartupResults; + private final Sensor taskStartupAttempts; + private final Sensor taskStartupSuccesses; + private final Sensor taskStartupFailures; + private final Sensor taskStartupResults; + + public WorkerMetricsGroup(final Map connectors, Map tasks, ConnectMetrics connectMetrics) { + ConnectMetricsRegistry registry = connectMetrics.registry(); + metricGroup = connectMetrics.group(registry.workerGroupName()); + + metricGroup.addValueMetric(registry.connectorCount, now -> (double) connectors.size()); + metricGroup.addValueMetric(registry.taskCount, now -> (double) tasks.size()); + + MetricName connectorFailurePct = metricGroup.metricName(registry.connectorStartupFailurePercentage); + MetricName connectorSuccessPct = metricGroup.metricName(registry.connectorStartupSuccessPercentage); + Frequencies connectorStartupResultFrequencies = Frequencies.forBooleanValues(connectorFailurePct, connectorSuccessPct); + connectorStartupResults = metricGroup.sensor("connector-startup-results"); + connectorStartupResults.add(connectorStartupResultFrequencies); + + connectorStartupAttempts = metricGroup.sensor("connector-startup-attempts"); + connectorStartupAttempts.add(metricGroup.metricName(registry.connectorStartupAttemptsTotal), new CumulativeSum()); + + connectorStartupSuccesses = metricGroup.sensor("connector-startup-successes"); + connectorStartupSuccesses.add(metricGroup.metricName(registry.connectorStartupSuccessTotal), new CumulativeSum()); + + connectorStartupFailures = metricGroup.sensor("connector-startup-failures"); + connectorStartupFailures.add(metricGroup.metricName(registry.connectorStartupFailureTotal), new CumulativeSum()); + + MetricName taskFailurePct = metricGroup.metricName(registry.taskStartupFailurePercentage); + MetricName taskSuccessPct = metricGroup.metricName(registry.taskStartupSuccessPercentage); + Frequencies taskStartupResultFrequencies = Frequencies.forBooleanValues(taskFailurePct, taskSuccessPct); + taskStartupResults = metricGroup.sensor("task-startup-results"); + taskStartupResults.add(taskStartupResultFrequencies); + + taskStartupAttempts = metricGroup.sensor("task-startup-attempts"); + taskStartupAttempts.add(metricGroup.metricName(registry.taskStartupAttemptsTotal), new CumulativeSum()); + + taskStartupSuccesses = metricGroup.sensor("task-startup-successes"); + 
taskStartupSuccesses.add(metricGroup.metricName(registry.taskStartupSuccessTotal), new CumulativeSum()); + + taskStartupFailures = metricGroup.sensor("task-startup-failures"); + taskStartupFailures.add(metricGroup.metricName(registry.taskStartupFailureTotal), new CumulativeSum()); + } + + void close() { + metricGroup.close(); + } + + void recordConnectorStartupFailure() { + connectorStartupAttempts.record(1.0); + connectorStartupFailures.record(1.0); + connectorStartupResults.record(0.0); + } + + void recordConnectorStartupSuccess() { + connectorStartupAttempts.record(1.0); + connectorStartupSuccesses.record(1.0); + connectorStartupResults.record(1.0); + } + + void recordTaskFailure() { + taskStartupAttempts.record(1.0); + taskStartupFailures.record(1.0); + taskStartupResults.record(0.0); + } + + void recordTaskSuccess() { + taskStartupAttempts.record(1.0); + taskStartupSuccesses.record(1.0); + taskStartupResults.record(1.0); + } + + protected ConnectMetrics.MetricGroup metricGroup() { + return metricGroup; + } + + ConnectorStatus.Listener wrapStatusListener(ConnectorStatus.Listener delegateListener) { + return new ConnectorStatusListener(delegateListener); + } + + TaskStatus.Listener wrapStatusListener(TaskStatus.Listener delegateListener) { + return new TaskStatusListener(delegateListener); + } + + class ConnectorStatusListener implements ConnectorStatus.Listener { + private final ConnectorStatus.Listener delegateListener; + private volatile boolean startupSucceeded = false; + + ConnectorStatusListener(ConnectorStatus.Listener delegateListener) { + this.delegateListener = delegateListener; + } + + @Override + public void onStartup(final String connector) { + startupSucceeded = true; + recordConnectorStartupSuccess(); + delegateListener.onStartup(connector); + } + + @Override + public void onPause(final String connector) { + delegateListener.onPause(connector); + } + + @Override + public void onResume(final String connector) { + delegateListener.onResume(connector); + } + + @Override + public void onFailure(final String connector, final Throwable cause) { + if (!startupSucceeded) { + recordConnectorStartupFailure(); + } + delegateListener.onFailure(connector, cause); + } + + @Override + public void onRestart(String connector) { + delegateListener.onRestart(connector); + } + + @Override + public void onShutdown(final String connector) { + delegateListener.onShutdown(connector); + } + + @Override + public void onDeletion(final String connector) { + delegateListener.onDeletion(connector); + } + } + + class TaskStatusListener implements TaskStatus.Listener { + private final TaskStatus.Listener delegatedListener; + private volatile boolean startupSucceeded = false; + + TaskStatusListener(TaskStatus.Listener delegatedListener) { + this.delegatedListener = delegatedListener; + } + + @Override + public void onStartup(final ConnectorTaskId id) { + recordTaskSuccess(); + startupSucceeded = true; + delegatedListener.onStartup(id); + } + + @Override + public void onPause(final ConnectorTaskId id) { + delegatedListener.onPause(id); + } + + @Override + public void onResume(final ConnectorTaskId id) { + delegatedListener.onResume(id); + } + + @Override + public void onFailure(final ConnectorTaskId id, final Throwable cause) { + if (!startupSucceeded) { + recordTaskFailure(); + } + delegatedListener.onFailure(id, cause); + } + + @Override + public void onRestart(ConnectorTaskId id) { + delegatedListener.onRestart(id); + } + + @Override + public void onShutdown(final ConnectorTaskId id) { + 
delegatedListener.onShutdown(id); + } + + @Override + public void onDeletion(final ConnectorTaskId id) { + delegatedListener.onDeletion(id); + } + } + +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/WorkerSinkTask.java b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/WorkerSinkTask.java new file mode 100644 index 0000000..72ee749 --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/WorkerSinkTask.java @@ -0,0 +1,914 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.runtime; + +import org.apache.kafka.clients.consumer.ConsumerRebalanceListener; +import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.clients.consumer.ConsumerRecords; +import org.apache.kafka.clients.consumer.KafkaConsumer; +import org.apache.kafka.clients.consumer.OffsetAndMetadata; +import org.apache.kafka.clients.consumer.OffsetCommitCallback; +import org.apache.kafka.common.KafkaException; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.errors.WakeupException; +import org.apache.kafka.common.metrics.Sensor; +import org.apache.kafka.common.metrics.stats.Avg; +import org.apache.kafka.common.metrics.stats.CumulativeSum; +import org.apache.kafka.common.metrics.stats.Max; +import org.apache.kafka.common.metrics.stats.Rate; +import org.apache.kafka.common.metrics.stats.Value; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.common.utils.Utils.UncheckedCloseable; +import org.apache.kafka.connect.data.SchemaAndValue; +import org.apache.kafka.connect.errors.ConnectException; +import org.apache.kafka.connect.errors.RetriableException; +import org.apache.kafka.connect.header.ConnectHeaders; +import org.apache.kafka.connect.header.Headers; +import org.apache.kafka.connect.runtime.ConnectMetrics.MetricGroup; +import org.apache.kafka.connect.runtime.distributed.ClusterConfigState; +import org.apache.kafka.connect.runtime.errors.RetryWithToleranceOperator; +import org.apache.kafka.connect.runtime.errors.Stage; +import org.apache.kafka.connect.runtime.errors.WorkerErrantRecordReporter; +import org.apache.kafka.connect.sink.SinkRecord; +import org.apache.kafka.connect.sink.SinkTask; +import org.apache.kafka.connect.storage.Converter; +import org.apache.kafka.connect.storage.HeaderConverter; +import org.apache.kafka.connect.storage.StatusBackingStore; +import org.apache.kafka.connect.util.ConnectUtils; +import org.apache.kafka.connect.util.ConnectorTaskId; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.time.Duration; +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashMap; +import java.util.List; +import 
java.util.Map; +import java.util.regex.Pattern; +import java.util.stream.Collectors; + +import static java.util.Collections.singleton; +import static org.apache.kafka.connect.runtime.WorkerConfig.TOPIC_TRACKING_ENABLE_CONFIG; + +/** + * WorkerTask that uses a SinkTask to export data from Kafka. + */ +class WorkerSinkTask extends WorkerTask { + private static final Logger log = LoggerFactory.getLogger(WorkerSinkTask.class); + + private final WorkerConfig workerConfig; + private final SinkTask task; + private final ClusterConfigState configState; + private Map taskConfig; + private final Converter keyConverter; + private final Converter valueConverter; + private final HeaderConverter headerConverter; + private final TransformationChain transformationChain; + private final SinkTaskMetricsGroup sinkTaskMetricsGroup; + private final boolean isTopicTrackingEnabled; + private KafkaConsumer consumer; + private WorkerSinkTaskContext context; + private final List messageBatch; + private Map lastCommittedOffsets; + private Map currentOffsets; + private final Map origOffsets; + private RuntimeException rebalanceException; + private long nextCommit; + private int commitSeqno; + private long commitStarted; + private int commitFailures; + private boolean pausedForRedelivery; + private boolean committing; + private boolean taskStopped; + private final WorkerErrantRecordReporter workerErrantRecordReporter; + + public WorkerSinkTask(ConnectorTaskId id, + SinkTask task, + TaskStatus.Listener statusListener, + TargetState initialState, + WorkerConfig workerConfig, + ClusterConfigState configState, + ConnectMetrics connectMetrics, + Converter keyConverter, + Converter valueConverter, + HeaderConverter headerConverter, + TransformationChain transformationChain, + KafkaConsumer consumer, + ClassLoader loader, + Time time, + RetryWithToleranceOperator retryWithToleranceOperator, + WorkerErrantRecordReporter workerErrantRecordReporter, + StatusBackingStore statusBackingStore) { + super(id, statusListener, initialState, loader, connectMetrics, + retryWithToleranceOperator, time, statusBackingStore); + + this.workerConfig = workerConfig; + this.task = task; + this.configState = configState; + this.keyConverter = keyConverter; + this.valueConverter = valueConverter; + this.headerConverter = headerConverter; + this.transformationChain = transformationChain; + this.messageBatch = new ArrayList<>(); + this.lastCommittedOffsets = new HashMap<>(); + this.currentOffsets = new HashMap<>(); + this.origOffsets = new HashMap<>(); + this.pausedForRedelivery = false; + this.rebalanceException = null; + this.nextCommit = time.milliseconds() + + workerConfig.getLong(WorkerConfig.OFFSET_COMMIT_INTERVAL_MS_CONFIG); + this.committing = false; + this.commitSeqno = 0; + this.commitStarted = -1; + this.commitFailures = 0; + this.sinkTaskMetricsGroup = new SinkTaskMetricsGroup(id, connectMetrics); + this.sinkTaskMetricsGroup.recordOffsetSequenceNumber(commitSeqno); + this.consumer = consumer; + this.isTopicTrackingEnabled = workerConfig.getBoolean(TOPIC_TRACKING_ENABLE_CONFIG); + this.taskStopped = false; + this.workerErrantRecordReporter = workerErrantRecordReporter; + } + + @Override + public void initialize(TaskConfig taskConfig) { + try { + this.taskConfig = taskConfig.originalsStrings(); + this.context = new WorkerSinkTaskContext(consumer, this, configState); + } catch (Throwable t) { + log.error("{} Task failed initialization and will not be started.", this, t); + onFailure(t); + } + } + + @Override + public void stop() { + // 
Offset commit is handled upon exit in work thread + super.stop(); + consumer.wakeup(); + } + + @Override + protected void close() { + // FIXME Kafka needs to add a timeout parameter here for us to properly obey the timeout + // passed in + try { + task.stop(); + } catch (Throwable t) { + log.warn("Could not stop task", t); + } + taskStopped = true; + Utils.closeQuietly(consumer, "consumer"); + Utils.closeQuietly(transformationChain, "transformation chain"); + Utils.closeQuietly(retryWithToleranceOperator, "retry operator"); + } + + @Override + public void removeMetrics() { + try { + sinkTaskMetricsGroup.close(); + } finally { + super.removeMetrics(); + } + } + + @Override + public void transitionTo(TargetState state) { + super.transitionTo(state); + consumer.wakeup(); + } + + @Override + public void execute() { + log.info("{} Executing sink task", this); + // Make sure any uncommitted data has been committed and the task has + // a chance to clean up its state + try (UncheckedCloseable suppressible = this::closeAllPartitions) { + while (!isStopping()) + iteration(); + } catch (WakeupException e) { + log.trace("Consumer woken up during initial offset commit attempt, " + + "but succeeded during a later attempt"); + } + } + + protected void iteration() { + final long offsetCommitIntervalMs = workerConfig.getLong(WorkerConfig.OFFSET_COMMIT_INTERVAL_MS_CONFIG); + + try { + long now = time.milliseconds(); + + // Maybe commit + if (!committing && (context.isCommitRequested() || now >= nextCommit)) { + commitOffsets(now, false); + nextCommit = now + offsetCommitIntervalMs; + context.clearCommitRequest(); + } + + final long commitTimeoutMs = commitStarted + workerConfig.getLong(WorkerConfig.OFFSET_COMMIT_TIMEOUT_MS_CONFIG); + + // Check for timed out commits + if (committing && now >= commitTimeoutMs) { + log.warn("{} Commit of offsets timed out", this); + commitFailures++; + committing = false; + } + + // And process messages + long timeoutMs = Math.max(nextCommit - now, 0); + poll(timeoutMs); + } catch (WakeupException we) { + log.trace("{} Consumer woken up", this); + + if (isStopping()) + return; + + if (shouldPause()) { + pauseAll(); + onPause(); + context.requestCommit(); + } else if (!pausedForRedelivery) { + resumeAll(); + onResume(); + } + } + } + + /** + * Respond to a previous commit attempt that may or may not have succeeded. Note that due to our use of async commits, + * these invocations may come out of order and thus the need for the commit sequence number. 
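The sequence-number guard described in the comment above is a general pattern for asynchronous consumer commits, and it can be illustrated outside of Connect. The sketch below uses a plain KafkaConsumer with made-up broker, group, and topic names; commit callbacks are delivered on the polling thread, so a plain int field is enough to detect and ignore stale completions:

```java
import java.time.Duration;
import java.util.Collections;
import java.util.Map;
import java.util.Properties;

import org.apache.kafka.clients.consumer.KafkaConsumer;
import org.apache.kafka.clients.consumer.OffsetAndMetadata;
import org.apache.kafka.common.TopicPartition;
import org.apache.kafka.common.serialization.ByteArrayDeserializer;

public class SequencedAsyncCommitter {

    private int commitSeqno = 0;  // bumped for every commit attempt

    /** Kick off an async commit tagged with the sequence number of this attempt. */
    void commitAsync(KafkaConsumer<byte[], byte[]> consumer) {
        final int seqno = ++commitSeqno;
        consumer.commitAsync((Map<TopicPartition, OffsetAndMetadata> offsets, Exception error) -> {
            if (seqno != commitSeqno) {
                // A newer commit has been issued since this one started, so this result
                // (success or failure) is stale and is deliberately ignored.
                return;
            }
            if (error != null) {
                System.err.println("Commit " + seqno + " failed: " + error);
            } else {
                System.out.println("Commit " + seqno + " succeeded: " + offsets);
            }
        });
    }

    public static void main(String[] args) {
        Properties props = new Properties();
        props.put("bootstrap.servers", "localhost:9092");    // illustrative broker address
        props.put("group.id", "sequenced-committer-demo");   // illustrative group id
        props.put("enable.auto.commit", "false");
        props.put("key.deserializer", ByteArrayDeserializer.class.getName());
        props.put("value.deserializer", ByteArrayDeserializer.class.getName());

        SequencedAsyncCommitter demo = new SequencedAsyncCommitter();
        try (KafkaConsumer<byte[], byte[]> consumer = new KafkaConsumer<>(props)) {
            consumer.subscribe(Collections.singletonList("demo-topic"));  // illustrative topic
            for (int i = 0; i < 10; i++) {
                consumer.poll(Duration.ofMillis(500));
                demo.commitAsync(consumer);  // commits the offsets from the last poll, asynchronously
            }
            consumer.commitSync();  // one final synchronous commit before closing, mirroring the closing path
        }
    }
}
```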
+ * + * @param error the error resulting from the commit, or null if the commit succeeded without error + * @param seqno the sequence number at the time the commit was requested + * @param committedOffsets the offsets that were committed; may be null if the commit did not complete successfully + * or if no new offsets were committed + */ + private void onCommitCompleted(Throwable error, long seqno, Map committedOffsets) { + if (commitSeqno != seqno) { + log.debug("{} Received out of order commit callback for sequence number {}, but most recent sequence number is {}", + this, seqno, commitSeqno); + sinkTaskMetricsGroup.recordOffsetCommitSkip(); + } else { + long durationMillis = time.milliseconds() - commitStarted; + if (error != null) { + log.error("{} Commit of offsets threw an unexpected exception for sequence number {}: {}", + this, seqno, committedOffsets, error); + commitFailures++; + recordCommitFailure(durationMillis, error); + } else { + log.debug("{} Finished offset commit successfully in {} ms for sequence number {}: {}", + this, durationMillis, seqno, committedOffsets); + if (committedOffsets != null) { + log.trace("{} Adding to last committed offsets: {}", this, committedOffsets); + lastCommittedOffsets.putAll(committedOffsets); + log.debug("{} Last committed offsets are now {}", this, committedOffsets); + sinkTaskMetricsGroup.recordCommittedOffsets(committedOffsets); + } + commitFailures = 0; + recordCommitSuccess(durationMillis); + } + committing = false; + } + } + + public int commitFailures() { + return commitFailures; + } + + /** + * Initializes and starts the SinkTask. + */ + @Override + protected void initializeAndStart() { + SinkConnectorConfig.validate(taskConfig); + + if (SinkConnectorConfig.hasTopicsConfig(taskConfig)) { + List topics = SinkConnectorConfig.parseTopicsList(taskConfig); + consumer.subscribe(topics, new HandleRebalance()); + log.debug("{} Initializing and starting task for topics {}", this, Utils.join(topics, ", ")); + } else { + String topicsRegexStr = taskConfig.get(SinkTask.TOPICS_REGEX_CONFIG); + Pattern pattern = Pattern.compile(topicsRegexStr); + consumer.subscribe(pattern, new HandleRebalance()); + log.debug("{} Initializing and starting task for topics regex {}", this, topicsRegexStr); + } + + task.initialize(context); + task.start(taskConfig); + log.info("{} Sink task finished initialization and start", this); + } + + /** + * Poll for new messages with the given timeout. Should only be invoked by the worker thread. 
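Earlier in this file, initializeAndStart() subscribes the consumer from either a fixed `topics` list or a `topics.regex` pattern in the connector configuration, and SinkConnectorConfig.validate enforces that exactly one of the two is set. A sketch of the two corresponding sink connector configurations; the connector class and topic names are illustrative:

```java
import java.util.HashMap;
import java.util.Map;

public class SinkSubscriptionConfigs {
    public static void main(String[] args) {
        // Fixed topic list: the worker calls consumer.subscribe(List, rebalanceListener).
        Map<String, String> explicitTopics = new HashMap<>();
        explicitTopics.put("connector.class", "com.example.MySinkConnector"); // illustrative class
        explicitTopics.put("tasks.max", "2");
        explicitTopics.put("topics", "orders,payments");                      // comma-separated list

        // Regex subscription: the worker calls consumer.subscribe(Pattern, rebalanceListener).
        Map<String, String> regexTopics = new HashMap<>();
        regexTopics.put("connector.class", "com.example.MySinkConnector");    // illustrative class
        regexTopics.put("tasks.max", "2");
        regexTopics.put("topics.regex", "metrics-.*");                        // java.util.regex pattern

        System.out.println(explicitTopics);
        System.out.println(regexTopics);
    }
}
```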
+ */ + protected void poll(long timeoutMs) { + rewind(); + long retryTimeout = context.timeout(); + if (retryTimeout > 0) { + timeoutMs = Math.min(timeoutMs, retryTimeout); + context.timeout(-1L); + } + + log.trace("{} Polling consumer with timeout {} ms", this, timeoutMs); + ConsumerRecords msgs = pollConsumer(timeoutMs); + assert messageBatch.isEmpty() || msgs.isEmpty(); + log.trace("{} Polling returned {} messages", this, msgs.count()); + + convertMessages(msgs); + deliverMessages(); + } + + // Visible for testing + boolean isCommitting() { + return committing; + } + + private void doCommitSync(Map offsets, int seqno) { + log.debug("{} Committing offsets synchronously using sequence number {}: {}", this, seqno, offsets); + try { + consumer.commitSync(offsets); + onCommitCompleted(null, seqno, offsets); + } catch (WakeupException e) { + // retry the commit to ensure offsets get pushed, then propagate the wakeup up to poll + doCommitSync(offsets, seqno); + throw e; + } catch (KafkaException e) { + onCommitCompleted(e, seqno, offsets); + } + } + + private void doCommitAsync(Map offsets, final int seqno) { + log.debug("{} Committing offsets asynchronously using sequence number {}: {}", this, seqno, offsets); + OffsetCommitCallback cb = (tpOffsets, error) -> onCommitCompleted(error, seqno, tpOffsets); + consumer.commitAsync(offsets, cb); + } + + /** + * Starts an offset commit by flushing outstanding messages from the task and then starting + * the write commit. + **/ + private void doCommit(Map offsets, boolean closing, int seqno) { + if (closing) { + doCommitSync(offsets, seqno); + } else { + doCommitAsync(offsets, seqno); + } + } + + private void commitOffsets(long now, boolean closing) { + commitOffsets(now, closing, consumer.assignment()); + } + + private void commitOffsets(long now, boolean closing, Collection topicPartitions) { + log.trace("Committing offsets for partitions {}", topicPartitions); + if (workerErrantRecordReporter != null) { + log.trace("Awaiting reported errors to be completed"); + workerErrantRecordReporter.awaitFutures(topicPartitions); + log.trace("Completed reported errors"); + } + + Map offsetsToCommit = currentOffsets.entrySet().stream() + .filter(e -> topicPartitions.contains(e.getKey())) + .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); + + if (offsetsToCommit.isEmpty()) + return; + + committing = true; + commitSeqno += 1; + commitStarted = now; + sinkTaskMetricsGroup.recordOffsetSequenceNumber(commitSeqno); + + Map lastCommittedOffsetsForPartitions = this.lastCommittedOffsets.entrySet().stream() + .filter(e -> offsetsToCommit.containsKey(e.getKey())) + .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); + + final Map taskProvidedOffsets; + try { + log.trace("{} Calling task.preCommit with current offsets: {}", this, offsetsToCommit); + taskProvidedOffsets = task.preCommit(new HashMap<>(offsetsToCommit)); + } catch (Throwable t) { + if (closing) { + log.warn("{} Offset commit failed during close", this); + } else { + log.error("{} Offset commit failed, rewinding to last committed offsets", this, t); + for (Map.Entry entry : lastCommittedOffsetsForPartitions.entrySet()) { + log.debug("{} Rewinding topic partition {} to offset {}", this, entry.getKey(), entry.getValue().offset()); + consumer.seek(entry.getKey(), entry.getValue().offset()); + } + currentOffsets.putAll(lastCommittedOffsetsForPartitions); + } + onCommitCompleted(t, commitSeqno, null); + return; + } finally { + if (closing) { + log.trace("{} Closing the task before 
committing the offsets: {}", this, offsetsToCommit); + task.close(topicPartitions); + } + } + + if (taskProvidedOffsets.isEmpty()) { + log.debug("{} Skipping offset commit, task opted-out by returning no offsets from preCommit", this); + onCommitCompleted(null, commitSeqno, null); + return; + } + + Collection allAssignedTopicPartitions = consumer.assignment(); + final Map committableOffsets = new HashMap<>(lastCommittedOffsetsForPartitions); + for (Map.Entry taskProvidedOffsetEntry : taskProvidedOffsets.entrySet()) { + final TopicPartition partition = taskProvidedOffsetEntry.getKey(); + final OffsetAndMetadata taskProvidedOffset = taskProvidedOffsetEntry.getValue(); + if (committableOffsets.containsKey(partition)) { + long taskOffset = taskProvidedOffset.offset(); + long currentOffset = offsetsToCommit.get(partition).offset(); + if (taskOffset <= currentOffset) { + committableOffsets.put(partition, taskProvidedOffset); + } else { + log.warn("{} Ignoring invalid task provided offset {}/{} -- not yet consumed, taskOffset={} currentOffset={}", + this, partition, taskProvidedOffset, taskOffset, currentOffset); + } + } else if (!allAssignedTopicPartitions.contains(partition)) { + log.warn("{} Ignoring invalid task provided offset {}/{} -- partition not assigned, assignment={}", + this, partition, taskProvidedOffset, allAssignedTopicPartitions); + } else { + log.debug("{} Ignoring task provided offset {}/{} -- partition not requested, requested={}", + this, partition, taskProvidedOffset, committableOffsets.keySet()); + } + } + + if (committableOffsets.equals(lastCommittedOffsetsForPartitions)) { + log.debug("{} Skipping offset commit, no change since last commit", this); + onCommitCompleted(null, commitSeqno, null); + return; + } + + doCommit(committableOffsets, closing, commitSeqno); + } + + + @Override + public String toString() { + return "WorkerSinkTask{" + + "id=" + id + + '}'; + } + + private ConsumerRecords pollConsumer(long timeoutMs) { + ConsumerRecords msgs = consumer.poll(Duration.ofMillis(timeoutMs)); + + // Exceptions raised from the task during a rebalance should be rethrown to stop the worker + if (rebalanceException != null) { + RuntimeException e = rebalanceException; + rebalanceException = null; + throw e; + } + + sinkTaskMetricsGroup.recordRead(msgs.count()); + return msgs; + } + + private void convertMessages(ConsumerRecords msgs) { + origOffsets.clear(); + for (ConsumerRecord msg : msgs) { + log.trace("{} Consuming and converting message in topic '{}' partition {} at offset {} and timestamp {}", + this, msg.topic(), msg.partition(), msg.offset(), msg.timestamp()); + + retryWithToleranceOperator.consumerRecord(msg); + + SinkRecord transRecord = convertAndTransformRecord(msg); + + origOffsets.put( + new TopicPartition(msg.topic(), msg.partition()), + new OffsetAndMetadata(msg.offset() + 1) + ); + if (transRecord != null) { + messageBatch.add(transRecord); + } else { + log.trace( + "{} Converters and transformations returned null, possibly because of too many retries, so " + + "dropping record in topic '{}' partition {} at offset {}", + this, msg.topic(), msg.partition(), msg.offset() + ); + } + } + sinkTaskMetricsGroup.recordConsumedOffsets(origOffsets); + } + + private SinkRecord convertAndTransformRecord(final ConsumerRecord msg) { + SchemaAndValue keyAndSchema = retryWithToleranceOperator.execute(() -> keyConverter.toConnectData(msg.topic(), msg.headers(), msg.key()), + Stage.KEY_CONVERTER, keyConverter.getClass()); + + SchemaAndValue valueAndSchema = 
retryWithToleranceOperator.execute(() -> valueConverter.toConnectData(msg.topic(), msg.headers(), msg.value()), + Stage.VALUE_CONVERTER, valueConverter.getClass()); + + Headers headers = retryWithToleranceOperator.execute(() -> convertHeadersFor(msg), Stage.HEADER_CONVERTER, headerConverter.getClass()); + + if (retryWithToleranceOperator.failed()) { + return null; + } + + Long timestamp = ConnectUtils.checkAndConvertTimestamp(msg.timestamp()); + SinkRecord origRecord = new SinkRecord(msg.topic(), msg.partition(), + keyAndSchema.schema(), keyAndSchema.value(), + valueAndSchema.schema(), valueAndSchema.value(), + msg.offset(), + timestamp, + msg.timestampType(), + headers); + log.trace("{} Applying transformations to record in topic '{}' partition {} at offset {} and timestamp {} with key {} and value {}", + this, msg.topic(), msg.partition(), msg.offset(), timestamp, keyAndSchema.value(), valueAndSchema.value()); + if (isTopicTrackingEnabled) { + recordActiveTopic(origRecord.topic()); + } + + // Apply the transformations + SinkRecord transformedRecord = transformationChain.apply(origRecord); + if (transformedRecord == null) { + return null; + } + // Error reporting will need to correlate each sink record with the original consumer record + return new InternalSinkRecord(msg, transformedRecord); + } + + private Headers convertHeadersFor(ConsumerRecord record) { + Headers result = new ConnectHeaders(); + org.apache.kafka.common.header.Headers recordHeaders = record.headers(); + if (recordHeaders != null) { + String topic = record.topic(); + for (org.apache.kafka.common.header.Header recordHeader : recordHeaders) { + SchemaAndValue schemaAndValue = headerConverter.toConnectHeader(topic, recordHeader.key(), recordHeader.value()); + result.add(recordHeader.key(), schemaAndValue); + } + } + return result; + } + + protected WorkerErrantRecordReporter workerErrantRecordReporter() { + return workerErrantRecordReporter; + } + + private void resumeAll() { + for (TopicPartition tp : consumer.assignment()) + if (!context.pausedPartitions().contains(tp)) + consumer.resume(singleton(tp)); + } + + private void pauseAll() { + consumer.pause(consumer.assignment()); + } + + private void deliverMessages() { + // Finally, deliver this batch to the sink + try { + // Since we reuse the messageBatch buffer, ensure we give the task its own copy + log.trace("{} Delivering batch of {} messages to task", this, messageBatch.size()); + long start = time.milliseconds(); + task.put(new ArrayList<>(messageBatch)); + // if errors raised from the operator were swallowed by the task implementation, an + // exception needs to be thrown to kill the task indicating the tolerance was exceeded + if (retryWithToleranceOperator.failed() && !retryWithToleranceOperator.withinToleranceLimits()) { + throw new ConnectException("Tolerance exceeded in error handler", + retryWithToleranceOperator.error()); + } + recordBatch(messageBatch.size()); + sinkTaskMetricsGroup.recordPut(time.milliseconds() - start); + currentOffsets.putAll(origOffsets); + messageBatch.clear(); + // If we had paused all consumer topic partitions to try to redeliver data, then we should resume any that + // the task had not explicitly paused + if (pausedForRedelivery) { + if (!shouldPause()) + resumeAll(); + pausedForRedelivery = false; + } + } catch (RetriableException e) { + log.error("{} RetriableException from SinkTask:", this, e); + if (!pausedForRedelivery) { + // If we're retrying a previous batch, make sure we've paused all topic partitions so we don't get new 
data, + // but will still be able to poll in order to handle user-requested timeouts, keep group membership, etc. + pausedForRedelivery = true; + pauseAll(); + } + // Let this exit normally, the batch will be reprocessed on the next loop. + } catch (Throwable t) { + log.error("{} Task threw an uncaught and unrecoverable exception. Task is being killed and will not " + + "recover until manually restarted. Error: {}", this, t.getMessage(), t); + throw new ConnectException("Exiting WorkerSinkTask due to unrecoverable exception.", t); + } + } + + private void rewind() { + Map offsets = context.offsets(); + if (offsets.isEmpty()) { + return; + } + for (Map.Entry entry: offsets.entrySet()) { + TopicPartition tp = entry.getKey(); + Long offset = entry.getValue(); + if (offset != null) { + log.trace("{} Rewind {} to offset {}", this, tp, offset); + consumer.seek(tp, offset); + lastCommittedOffsets.put(tp, new OffsetAndMetadata(offset)); + currentOffsets.put(tp, new OffsetAndMetadata(offset)); + } else { + log.warn("{} Cannot rewind {} to null offset", this, tp); + } + } + context.clearOffsets(); + } + + private void openPartitions(Collection partitions) { + updatePartitionCount(); + task.open(partitions); + } + + private void closeAllPartitions() { + closePartitions(currentOffsets.keySet(), false); + } + + private void closePartitions(Collection topicPartitions, boolean lost) { + if (!lost) { + commitOffsets(time.milliseconds(), true, topicPartitions); + } else { + log.trace("{} Closing the task as partitions have been lost: {}", this, topicPartitions); + task.close(topicPartitions); + if (workerErrantRecordReporter != null) { + log.trace("Cancelling reported errors for {}", topicPartitions); + workerErrantRecordReporter.cancelFutures(topicPartitions); + log.trace("Cancelled all reported errors for {}", topicPartitions); + } + currentOffsets.keySet().removeAll(topicPartitions); + } + updatePartitionCount(); + lastCommittedOffsets.keySet().removeAll(topicPartitions); + } + + private void updatePartitionCount() { + sinkTaskMetricsGroup.recordPartitionCount(consumer.assignment().size()); + } + + @Override + protected void recordBatch(int size) { + super.recordBatch(size); + sinkTaskMetricsGroup.recordSend(size); + } + + @Override + protected void recordCommitFailure(long duration, Throwable error) { + super.recordCommitFailure(duration, error); + } + + @Override + protected void recordCommitSuccess(long duration) { + super.recordCommitSuccess(duration); + sinkTaskMetricsGroup.recordOffsetCommitSuccess(); + } + + SinkTaskMetricsGroup sinkTaskMetricsGroup() { + return sinkTaskMetricsGroup; + } + + // Visible for testing + long getNextCommit() { + return nextCommit; + } + + private class HandleRebalance implements ConsumerRebalanceListener { + @Override + public void onPartitionsAssigned(Collection partitions) { + log.debug("{} Partitions assigned {}", WorkerSinkTask.this, partitions); + + for (TopicPartition tp : partitions) { + long pos = consumer.position(tp); + lastCommittedOffsets.put(tp, new OffsetAndMetadata(pos)); + currentOffsets.put(tp, new OffsetAndMetadata(pos)); + log.debug("{} Assigned topic partition {} with offset {}", WorkerSinkTask.this, tp, pos); + } + sinkTaskMetricsGroup.assignedOffsets(currentOffsets); + + boolean wasPausedForRedelivery = pausedForRedelivery; + pausedForRedelivery = wasPausedForRedelivery && !messageBatch.isEmpty(); + if (pausedForRedelivery) { + // Re-pause here in case we picked up new partitions in the rebalance + pauseAll(); + } else { + // If we paused 
everything for redelivery and all partitions for the failed deliveries have been revoked, make + // sure anything we paused that the task didn't request to be paused *and* which we still own is resumed. + // Also make sure our tracking of paused partitions is updated to remove any partitions we no longer own. + if (wasPausedForRedelivery) { + resumeAll(); + } + // Ensure that the paused partitions contains only assigned partitions and repause as necessary + context.pausedPartitions().retainAll(consumer.assignment()); + if (shouldPause()) + pauseAll(); + else if (!context.pausedPartitions().isEmpty()) + consumer.pause(context.pausedPartitions()); + } + + if (partitions.isEmpty()) { + return; + } + + // Instead of invoking the assignment callback on initialization, we guarantee the consumer is ready upon + // task start. Since this callback gets invoked during that initial setup before we've started the task, we + // need to guard against invoking the user's callback method during that period. + if (rebalanceException == null || rebalanceException instanceof WakeupException) { + try { + openPartitions(partitions); + // Rewind should be applied only if openPartitions succeeds. + rewind(); + } catch (RuntimeException e) { + // The consumer swallows exceptions raised in the rebalance listener, so we need to store + // exceptions and rethrow when poll() returns. + rebalanceException = e; + } + } + } + + @Override + public void onPartitionsRevoked(Collection partitions) { + onPartitionsRemoved(partitions, false); + } + + @Override + public void onPartitionsLost(Collection partitions) { + onPartitionsRemoved(partitions, true); + } + + private void onPartitionsRemoved(Collection partitions, boolean lost) { + if (taskStopped) { + log.trace("Skipping partition revocation callback as task has already been stopped"); + return; + } + log.debug("{} Partitions {}: {}", WorkerSinkTask.this, lost ? "lost" : "revoked", partitions); + + if (partitions.isEmpty()) + return; + + try { + closePartitions(partitions, lost); + sinkTaskMetricsGroup.clearOffsets(partitions); + } catch (RuntimeException e) { + // The consumer swallows exceptions raised in the rebalance listener, so we need to store + // exceptions and rethrow when poll() returns. 
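The "store the exception and rethrow after poll() returns" approach noted in the comment above is a reusable pattern for any consumer whose rebalance callbacks can fail. A standalone sketch under the consumer API's single-threaded poll-loop assumption; the topic, group, and per-partition hooks are illustrative:

```java
import java.time.Duration;
import java.util.Collection;
import java.util.Collections;
import java.util.Properties;

import org.apache.kafka.clients.consumer.ConsumerRebalanceListener;
import org.apache.kafka.clients.consumer.ConsumerRecords;
import org.apache.kafka.clients.consumer.KafkaConsumer;
import org.apache.kafka.common.TopicPartition;
import org.apache.kafka.common.serialization.ByteArrayDeserializer;

public class RebalanceErrorPropagation {

    // Written from the rebalance callbacks, read after poll() returns; the callbacks run on the
    // polling thread, so a plain field is sufficient here.
    private RuntimeException rebalanceException;

    public void run(Properties consumerProps) {
        try (KafkaConsumer<byte[], byte[]> consumer = new KafkaConsumer<>(consumerProps)) {
            consumer.subscribe(Collections.singletonList("demo-topic"), new ConsumerRebalanceListener() { // illustrative topic
                @Override
                public void onPartitionsAssigned(Collection<TopicPartition> partitions) {
                    try {
                        // open per-partition resources here
                    } catch (RuntimeException e) {
                        rebalanceException = e;  // defer: surface it from the poll loop instead
                    }
                }

                @Override
                public void onPartitionsRevoked(Collection<TopicPartition> partitions) {
                    try {
                        // flush/commit per-partition state here
                    } catch (RuntimeException e) {
                        rebalanceException = e;
                    }
                }
            });

            while (true) {
                ConsumerRecords<byte[], byte[]> records = consumer.poll(Duration.ofSeconds(1));
                if (rebalanceException != null) {
                    RuntimeException e = rebalanceException;
                    rebalanceException = null;
                    throw e;  // fail the loop with the original callback error
                }
                records.forEach(r -> { /* process */ });
            }
        }
    }

    public static void main(String[] args) {
        Properties props = new Properties();
        props.put("bootstrap.servers", "localhost:9092");  // illustrative
        props.put("group.id", "rebalance-error-demo");     // illustrative
        props.put("key.deserializer", ByteArrayDeserializer.class.getName());
        props.put("value.deserializer", ByteArrayDeserializer.class.getName());
        new RebalanceErrorPropagation().run(props);
    }
}
```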
+ rebalanceException = e; + } + + // Make sure we don't have any leftover data since offsets for these partitions will be reset to committed positions + messageBatch.removeIf(record -> partitions.contains(new TopicPartition(record.topic(), record.kafkaPartition()))); + } + } + + static class SinkTaskMetricsGroup { + private final ConnectorTaskId id; + private final ConnectMetrics metrics; + private final MetricGroup metricGroup; + private final Sensor sinkRecordRead; + private final Sensor sinkRecordSend; + private final Sensor partitionCount; + private final Sensor offsetSeqNum; + private final Sensor offsetCompletion; + private final Sensor offsetCompletionSkip; + private final Sensor putBatchTime; + private final Sensor sinkRecordActiveCount; + private long activeRecords; + private Map consumedOffsets = new HashMap<>(); + private Map committedOffsets = new HashMap<>(); + + public SinkTaskMetricsGroup(ConnectorTaskId id, ConnectMetrics connectMetrics) { + this.metrics = connectMetrics; + this.id = id; + + ConnectMetricsRegistry registry = connectMetrics.registry(); + metricGroup = connectMetrics + .group(registry.sinkTaskGroupName(), registry.connectorTagName(), id.connector(), registry.taskTagName(), + Integer.toString(id.task())); + // prevent collisions by removing any previously created metrics in this group. + metricGroup.close(); + + sinkRecordRead = metricGroup.sensor("sink-record-read"); + sinkRecordRead.add(metricGroup.metricName(registry.sinkRecordReadRate), new Rate()); + sinkRecordRead.add(metricGroup.metricName(registry.sinkRecordReadTotal), new CumulativeSum()); + + sinkRecordSend = metricGroup.sensor("sink-record-send"); + sinkRecordSend.add(metricGroup.metricName(registry.sinkRecordSendRate), new Rate()); + sinkRecordSend.add(metricGroup.metricName(registry.sinkRecordSendTotal), new CumulativeSum()); + + sinkRecordActiveCount = metricGroup.sensor("sink-record-active-count"); + sinkRecordActiveCount.add(metricGroup.metricName(registry.sinkRecordActiveCount), new Value()); + sinkRecordActiveCount.add(metricGroup.metricName(registry.sinkRecordActiveCountMax), new Max()); + sinkRecordActiveCount.add(metricGroup.metricName(registry.sinkRecordActiveCountAvg), new Avg()); + + partitionCount = metricGroup.sensor("partition-count"); + partitionCount.add(metricGroup.metricName(registry.sinkRecordPartitionCount), new Value()); + + offsetSeqNum = metricGroup.sensor("offset-seq-number"); + offsetSeqNum.add(metricGroup.metricName(registry.sinkRecordOffsetCommitSeqNum), new Value()); + + offsetCompletion = metricGroup.sensor("offset-commit-completion"); + offsetCompletion.add(metricGroup.metricName(registry.sinkRecordOffsetCommitCompletionRate), new Rate()); + offsetCompletion.add(metricGroup.metricName(registry.sinkRecordOffsetCommitCompletionTotal), new CumulativeSum()); + + offsetCompletionSkip = metricGroup.sensor("offset-commit-completion-skip"); + offsetCompletionSkip.add(metricGroup.metricName(registry.sinkRecordOffsetCommitSkipRate), new Rate()); + offsetCompletionSkip.add(metricGroup.metricName(registry.sinkRecordOffsetCommitSkipTotal), new CumulativeSum()); + + putBatchTime = metricGroup.sensor("put-batch-time"); + putBatchTime.add(metricGroup.metricName(registry.sinkRecordPutBatchTimeMax), new Max()); + putBatchTime.add(metricGroup.metricName(registry.sinkRecordPutBatchTimeAvg), new Avg()); + } + + void computeSinkRecordLag() { + Map consumed = this.consumedOffsets; + Map committed = this.committedOffsets; + activeRecords = 0L; + for (Map.Entry committedOffsetEntry : 
committed.entrySet()) { + final TopicPartition partition = committedOffsetEntry.getKey(); + final OffsetAndMetadata consumedOffsetMeta = consumed.get(partition); + if (consumedOffsetMeta != null) { + final OffsetAndMetadata committedOffsetMeta = committedOffsetEntry.getValue(); + long consumedOffset = consumedOffsetMeta.offset(); + long committedOffset = committedOffsetMeta.offset(); + long diff = consumedOffset - committedOffset; + // Connector tasks can return offsets, so make sure nothing wonky happens + activeRecords += Math.max(diff, 0L); + } + } + sinkRecordActiveCount.record(activeRecords); + } + + void close() { + metricGroup.close(); + } + + void recordRead(int batchSize) { + sinkRecordRead.record(batchSize); + } + + void recordSend(int batchSize) { + sinkRecordSend.record(batchSize); + } + + void recordPut(long duration) { + putBatchTime.record(duration); + } + + void recordPartitionCount(int assignedPartitionCount) { + partitionCount.record(assignedPartitionCount); + } + + void recordOffsetSequenceNumber(int seqNum) { + offsetSeqNum.record(seqNum); + } + + void recordConsumedOffsets(Map offsets) { + consumedOffsets.putAll(offsets); + computeSinkRecordLag(); + } + + void recordCommittedOffsets(Map offsets) { + committedOffsets = offsets; + computeSinkRecordLag(); + } + + void assignedOffsets(Map offsets) { + consumedOffsets = new HashMap<>(offsets); + committedOffsets = offsets; + computeSinkRecordLag(); + } + + void clearOffsets(Collection topicPartitions) { + consumedOffsets.keySet().removeAll(topicPartitions); + committedOffsets.keySet().removeAll(topicPartitions); + computeSinkRecordLag(); + } + + void recordOffsetCommitSuccess() { + offsetCompletion.record(1.0); + } + + void recordOffsetCommitSkip() { + offsetCompletionSkip.record(1.0); + } + + protected MetricGroup metricGroup() { + return metricGroup; + } + } +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/WorkerSinkTaskContext.java b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/WorkerSinkTaskContext.java new file mode 100644 index 0000000..724b02e --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/WorkerSinkTaskContext.java @@ -0,0 +1,173 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
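The sensors registered by SinkTaskMetricsGroup above surface through Kafka's metrics reporters, including JMX, under the sink-task-metrics group. A rough sketch of reading a few of them in-process; the exact MBean name format and attribute set should be verified against the Connect monitoring documentation for the running version, and the connector name and task number are illustrative:

```java
import java.lang.management.ManagementFactory;

import javax.management.MBeanServer;
import javax.management.ObjectName;

public class SinkTaskMetricsProbe {
    public static void main(String[] args) throws Exception {
        // Assumes it runs inside the Connect worker JVM (e.g. from a test or plugin);
        // a remote worker would be reached through JMXConnectorFactory instead.
        MBeanServer server = ManagementFactory.getPlatformMBeanServer();

        // Illustrative connector name and task number; the exact ObjectName (including any
        // quoting of tag values) is easiest to confirm with a tool such as jconsole.
        ObjectName name = new ObjectName(
                "kafka.connect:type=sink-task-metrics,connector=my-sink,task=0");

        // Attribute names correspond to the sensors registered above; the authoritative
        // list is the Connect monitoring documentation.
        for (String attribute : new String[] {
                "sink-record-read-rate", "sink-record-send-total",
                "put-batch-avg-time-ms", "partition-count", "offset-commit-seq-no"}) {
            Object value = server.getAttribute(name, attribute);
            System.out.println(attribute + " = " + value);
        }
    }
}
```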
+ */ +package org.apache.kafka.connect.runtime; + +import org.apache.kafka.clients.consumer.KafkaConsumer; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.connect.errors.IllegalWorkerStateException; +import org.apache.kafka.connect.runtime.distributed.ClusterConfigState; +import org.apache.kafka.connect.sink.ErrantRecordReporter; +import org.apache.kafka.connect.sink.SinkTaskContext; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; + +public class WorkerSinkTaskContext implements SinkTaskContext { + + private final Logger log = LoggerFactory.getLogger(getClass()); + private Map offsets; + private long timeoutMs; + private KafkaConsumer consumer; + private final WorkerSinkTask sinkTask; + private final ClusterConfigState configState; + private final Set pausedPartitions; + private boolean commitRequested; + + public WorkerSinkTaskContext(KafkaConsumer consumer, + WorkerSinkTask sinkTask, + ClusterConfigState configState) { + this.offsets = new HashMap<>(); + this.timeoutMs = -1L; + this.consumer = consumer; + this.sinkTask = sinkTask; + this.configState = configState; + this.pausedPartitions = new HashSet<>(); + } + + @Override + public Map configs() { + return configState.taskConfig(sinkTask.id()); + } + + @Override + public void offset(Map offsets) { + log.debug("{} Setting offsets for topic partitions {}", this, offsets); + this.offsets.putAll(offsets); + } + + @Override + public void offset(TopicPartition tp, long offset) { + log.debug("{} Setting offset for topic partition {} to {}", this, tp, offset); + offsets.put(tp, offset); + } + + public void clearOffsets() { + offsets.clear(); + } + + /** + * Get offsets that the SinkTask has submitted to be reset. Used by the Kafka Connect framework. + * @return the map of offsets + */ + public Map offsets() { + return offsets; + } + + @Override + public void timeout(long timeoutMs) { + log.debug("{} Setting timeout to {} ms", this, timeoutMs); + this.timeoutMs = timeoutMs; + } + + /** + * Get the timeout in milliseconds set by SinkTasks. Used by the Kafka Connect framework. + * @return the backoff timeout in milliseconds. + */ + public long timeout() { + return timeoutMs; + } + + @Override + public Set assignment() { + if (consumer == null) { + throw new IllegalWorkerStateException("SinkTaskContext may not be used to look up partition assignment until the task is initialized"); + } + return consumer.assignment(); + } + + @Override + public void pause(TopicPartition... partitions) { + if (consumer == null) { + throw new IllegalWorkerStateException("SinkTaskContext may not be used to pause consumption until the task is initialized"); + } + try { + Collections.addAll(pausedPartitions, partitions); + if (sinkTask.shouldPause()) { + log.debug("{} Connector is paused, so not pausing consumer's partitions {}", this, partitions); + } else { + consumer.pause(Arrays.asList(partitions)); + log.debug("{} Pausing partitions {}. Connector is not paused.", this, partitions); + } + } catch (IllegalStateException e) { + throw new IllegalWorkerStateException("SinkTasks may not pause partitions that are not currently assigned to them.", e); + } + } + + @Override + public void resume(TopicPartition... 
partitions) { + if (consumer == null) { + throw new IllegalWorkerStateException("SinkTaskContext may not be used to resume consumption until the task is initialized"); + } + try { + pausedPartitions.removeAll(Arrays.asList(partitions)); + if (sinkTask.shouldPause()) { + log.debug("{} Connector is paused, so not resuming consumer's partitions {}", this, partitions); + } else { + consumer.resume(Arrays.asList(partitions)); + log.debug("{} Resuming partitions: {}", this, partitions); + } + } catch (IllegalStateException e) { + throw new IllegalWorkerStateException("SinkTasks may not resume partitions that are not currently assigned to them.", e); + } + } + + public Set pausedPartitions() { + return pausedPartitions; + } + + @Override + public void requestCommit() { + log.debug("{} Requesting commit", this); + commitRequested = true; + } + + public boolean isCommitRequested() { + return commitRequested; + } + + public void clearCommitRequest() { + commitRequested = false; + } + + @Override + public ErrantRecordReporter errantRecordReporter() { + return sinkTask.workerErrantRecordReporter(); + } + + @Override + public String toString() { + return "WorkerSinkTaskContext{" + + "id=" + sinkTask.id + + '}'; + } +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/WorkerSourceTask.java b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/WorkerSourceTask.java new file mode 100644 index 0000000..ed36676 --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/WorkerSourceTask.java @@ -0,0 +1,687 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
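The WorkerSinkTaskContext above is what a connector's SinkTask sees as its `context`. A sketch of a task that leans on it for backoff, redelivery, and error reporting; the downstream system, its exception type, and the write helper are hypothetical:

```java
import java.util.Collection;
import java.util.Map;

import org.apache.kafka.connect.errors.ConnectException;
import org.apache.kafka.connect.errors.RetriableException;
import org.apache.kafka.connect.sink.ErrantRecordReporter;
import org.apache.kafka.connect.sink.SinkRecord;
import org.apache.kafka.connect.sink.SinkTask;

/**
 * Illustrative sink task built against the context the worker supplies.
 * The downstream system, its exception type, and the write helper are hypothetical.
 */
public class ContextAwareSinkTask extends SinkTask {

    private ErrantRecordReporter reporter;  // null when the runtime predates error reporting

    @Override
    public void start(Map<String, String> props) {
        try {
            reporter = context.errantRecordReporter();
        } catch (NoSuchMethodError | NoClassDefFoundError e) {
            reporter = null;  // older worker without the errant-record API
        }
    }

    @Override
    public void put(Collection<SinkRecord> records) {
        for (SinkRecord record : records) {
            try {
                writeDownstream(record);                      // hypothetical helper
            } catch (TransientDownstreamException e) {        // hypothetical exception type
                // Ask the worker to redeliver the batch after a backoff: it pauses all
                // partitions and shortens its next consumer poll to roughly this timeout.
                context.timeout(5_000);
                throw new RetriableException(e);
            } catch (Exception e) {
                if (reporter != null) {
                    reporter.report(record, e);               // hand off to the configured error handling
                } else {
                    throw new ConnectException(e);
                }
            }
        }
        context.requestCommit();  // optional: request an offset commit at the next opportunity
    }

    @Override
    public void stop() { }

    @Override
    public String version() {
        return "0.0.1";
    }

    private void writeDownstream(SinkRecord record) { /* hypothetical write to the sink system */ }

    private static class TransientDownstreamException extends RuntimeException {
    }
}
```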
+ */ +package org.apache.kafka.connect.runtime; + +import org.apache.kafka.clients.admin.NewTopic; +import org.apache.kafka.clients.admin.TopicDescription; +import org.apache.kafka.clients.producer.KafkaProducer; +import org.apache.kafka.clients.producer.ProducerRecord; +import org.apache.kafka.clients.producer.RecordMetadata; +import org.apache.kafka.common.KafkaException; +import org.apache.kafka.common.header.internals.RecordHeaders; +import org.apache.kafka.common.metrics.Sensor; +import org.apache.kafka.common.metrics.stats.Avg; +import org.apache.kafka.common.metrics.stats.CumulativeSum; +import org.apache.kafka.common.metrics.stats.Max; +import org.apache.kafka.common.metrics.stats.Rate; +import org.apache.kafka.common.metrics.stats.Value; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.connect.errors.ConnectException; +import org.apache.kafka.connect.errors.RetriableException; +import org.apache.kafka.connect.header.Header; +import org.apache.kafka.connect.header.Headers; +import org.apache.kafka.connect.runtime.ConnectMetrics.MetricGroup; +import org.apache.kafka.connect.runtime.SubmittedRecords.SubmittedRecord; +import org.apache.kafka.connect.runtime.distributed.ClusterConfigState; +import org.apache.kafka.connect.runtime.errors.RetryWithToleranceOperator; +import org.apache.kafka.connect.runtime.errors.Stage; +import org.apache.kafka.connect.source.SourceRecord; +import org.apache.kafka.connect.source.SourceTask; +import org.apache.kafka.connect.storage.CloseableOffsetStorageReader; +import org.apache.kafka.connect.storage.Converter; +import org.apache.kafka.connect.storage.HeaderConverter; +import org.apache.kafka.connect.storage.OffsetStorageWriter; +import org.apache.kafka.connect.storage.StatusBackingStore; +import org.apache.kafka.connect.util.ConnectUtils; +import org.apache.kafka.connect.util.ConnectorTaskId; +import org.apache.kafka.connect.util.TopicAdmin; +import org.apache.kafka.connect.util.TopicCreation; +import org.apache.kafka.connect.util.TopicCreationGroup; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.time.Duration; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Executor; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import java.util.concurrent.atomic.AtomicReference; + +import static org.apache.kafka.connect.runtime.SubmittedRecords.CommittableOffsets; +import static org.apache.kafka.connect.runtime.WorkerConfig.TOPIC_TRACKING_ENABLE_CONFIG; + +/** + * WorkerTask that uses a SourceTask to ingest data into Kafka. 
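For contrast with the sink side, the lifecycle this class drives can be seen from a connector author's perspective: the worker calls start(), repeatedly calls poll(), commits the source offsets attached to each record once the producer acknowledges it, and finally calls stop(). A minimal, illustrative SourceTask; the topic name, config key, and offset keys are invented for the example:

```java
import java.util.Collections;
import java.util.List;
import java.util.Map;

import org.apache.kafka.connect.data.Schema;
import org.apache.kafka.connect.source.SourceRecord;
import org.apache.kafka.connect.source.SourceTask;

/**
 * Illustrative source task that emits one record per poll with an increasing offset.
 */
public class CounterSourceTask extends SourceTask {

    private static final Map<String, String> SOURCE_PARTITION =
            Collections.singletonMap("source", "counter");    // identifies the external resource

    private String topic;
    private long nextValue;

    @Override
    public void start(Map<String, String> props) {
        topic = props.getOrDefault("topic", "counter-topic");  // illustrative config key
        // Resume from the last committed source offset, if the framework has one stored.
        Map<String, Object> offset = context.offsetStorageReader().offset(SOURCE_PARTITION);
        nextValue = offset == null ? 0L : (Long) offset.get("next");
    }

    @Override
    public List<SourceRecord> poll() throws InterruptedException {
        Thread.sleep(1000);  // stand-in for waiting on external data; returning null is also allowed
        Map<String, Long> sourceOffset = Collections.singletonMap("next", nextValue + 1);
        SourceRecord record = new SourceRecord(
                SOURCE_PARTITION, sourceOffset, topic,
                Schema.STRING_SCHEMA, "value-" + nextValue);
        nextValue++;
        return Collections.singletonList(record);
    }

    @Override
    public void stop() { }

    @Override
    public String version() {
        return "0.0.1";
    }
}
```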
+ */ +class WorkerSourceTask extends WorkerTask { + private static final Logger log = LoggerFactory.getLogger(WorkerSourceTask.class); + + private static final long SEND_FAILED_BACKOFF_MS = 100; + + private final WorkerConfig workerConfig; + private final SourceTask task; + private final ClusterConfigState configState; + private final Converter keyConverter; + private final Converter valueConverter; + private final HeaderConverter headerConverter; + private final TransformationChain transformationChain; + private final KafkaProducer producer; + private final TopicAdmin admin; + private final CloseableOffsetStorageReader offsetReader; + private final OffsetStorageWriter offsetWriter; + private final Executor closeExecutor; + private final SourceTaskMetricsGroup sourceTaskMetricsGroup; + private final AtomicReference producerSendException; + private final boolean isTopicTrackingEnabled; + private final TopicCreation topicCreation; + + private List toSend; + private volatile CommittableOffsets committableOffsets; + private final SubmittedRecords submittedRecords; + private final CountDownLatch stopRequestedLatch; + + private Map taskConfig; + private boolean started = false; + + public WorkerSourceTask(ConnectorTaskId id, + SourceTask task, + TaskStatus.Listener statusListener, + TargetState initialState, + Converter keyConverter, + Converter valueConverter, + HeaderConverter headerConverter, + TransformationChain transformationChain, + KafkaProducer producer, + TopicAdmin admin, + Map topicGroups, + CloseableOffsetStorageReader offsetReader, + OffsetStorageWriter offsetWriter, + WorkerConfig workerConfig, + ClusterConfigState configState, + ConnectMetrics connectMetrics, + ClassLoader loader, + Time time, + RetryWithToleranceOperator retryWithToleranceOperator, + StatusBackingStore statusBackingStore, + Executor closeExecutor) { + + super(id, statusListener, initialState, loader, connectMetrics, + retryWithToleranceOperator, time, statusBackingStore); + + this.workerConfig = workerConfig; + this.task = task; + this.configState = configState; + this.keyConverter = keyConverter; + this.valueConverter = valueConverter; + this.headerConverter = headerConverter; + this.transformationChain = transformationChain; + this.producer = producer; + this.admin = admin; + this.offsetReader = offsetReader; + this.offsetWriter = offsetWriter; + this.closeExecutor = closeExecutor; + + this.toSend = null; + this.committableOffsets = CommittableOffsets.EMPTY; + this.submittedRecords = new SubmittedRecords(); + this.stopRequestedLatch = new CountDownLatch(1); + this.sourceTaskMetricsGroup = new SourceTaskMetricsGroup(id, connectMetrics); + this.producerSendException = new AtomicReference<>(); + this.isTopicTrackingEnabled = workerConfig.getBoolean(TOPIC_TRACKING_ENABLE_CONFIG); + this.topicCreation = TopicCreation.newTopicCreation(workerConfig, topicGroups); + } + + @Override + public void initialize(TaskConfig taskConfig) { + try { + this.taskConfig = taskConfig.originalsStrings(); + } catch (Throwable t) { + log.error("{} Task failed initialization and will not be started.", this, t); + onFailure(t); + } + } + + @Override + protected void close() { + if (started) { + try { + task.stop(); + } catch (Throwable t) { + log.warn("Could not stop task", t); + } + } + + closeProducer(Duration.ofSeconds(30)); + + if (admin != null) { + try { + admin.close(Duration.ofSeconds(30)); + } catch (Throwable t) { + log.warn("Failed to close admin client on time", t); + } + } + Utils.closeQuietly(transformationChain, 
"transformation chain"); + Utils.closeQuietly(retryWithToleranceOperator, "retry operator"); + } + + @Override + public void removeMetrics() { + try { + sourceTaskMetricsGroup.close(); + } finally { + super.removeMetrics(); + } + } + + @Override + public void cancel() { + super.cancel(); + offsetReader.close(); + // We proactively close the producer here as the main work thread for the task may + // be blocked indefinitely in a call to Producer::send if automatic topic creation is + // not enabled on either the connector or the Kafka cluster. Closing the producer should + // unblock it in that case and allow shutdown to proceed normally. + // With a duration of 0, the producer's own shutdown logic should be fairly quick, + // but closing user-pluggable classes like interceptors may lag indefinitely. So, we + // call close on a separate thread in order to avoid blocking the herder's tick thread. + closeExecutor.execute(() -> closeProducer(Duration.ZERO)); + } + + @Override + public void stop() { + super.stop(); + stopRequestedLatch.countDown(); + } + + @Override + protected void initializeAndStart() { + // If we try to start the task at all by invoking initialize, then count this as + // "started" and expect a subsequent call to the task's stop() method + // to properly clean up any resources allocated by its initialize() or + // start() methods. If the task throws an exception during stop(), + // the worst thing that happens is another exception gets logged for an already- + // failed task + started = true; + task.initialize(new WorkerSourceTaskContext(offsetReader, this, configState)); + task.start(taskConfig); + log.info("{} Source task finished initialization and start", this); + } + + @Override + public void execute() { + try { + log.info("{} Executing source task", this); + while (!isStopping()) { + updateCommittableOffsets(); + + if (shouldPause()) { + onPause(); + if (awaitUnpause()) { + onResume(); + } + continue; + } + + maybeThrowProducerSendException(); + if (toSend == null) { + log.trace("{} Nothing to send to Kafka. Polling source for additional records", this); + long start = time.milliseconds(); + toSend = poll(); + if (toSend != null) { + recordPollReturned(toSend.size(), time.milliseconds() - start); + } + } + + if (toSend == null) + continue; + log.trace("{} About to send {} records to Kafka", this, toSend.size()); + if (!sendRecords()) + stopRequestedLatch.await(SEND_FAILED_BACKOFF_MS, TimeUnit.MILLISECONDS); + } + } catch (InterruptedException e) { + // Ignore and allow to exit. + } finally { + submittedRecords.awaitAllMessages( + workerConfig.getLong(WorkerConfig.OFFSET_COMMIT_TIMEOUT_MS_CONFIG), + TimeUnit.MILLISECONDS + ); + // It should still be safe to commit offsets since any exception would have + // simply resulted in not getting more records but all the existing records should be ok to flush + // and commit offsets. Worst case, task.flush() will also throw an exception causing the offset commit + // to fail. 
+ updateCommittableOffsets(); + commitOffsets(); + } + } + + private void closeProducer(Duration duration) { + if (producer != null) { + try { + producer.close(duration); + } catch (Throwable t) { + log.warn("Could not close producer for {}", id, t); + } + } + } + + private void maybeThrowProducerSendException() { + if (producerSendException.get() != null) { + throw new ConnectException( + "Unrecoverable exception from producer send callback", + producerSendException.get() + ); + } + } + + private void updateCommittableOffsets() { + CommittableOffsets newOffsets = submittedRecords.committableOffsets(); + synchronized (this) { + this.committableOffsets = this.committableOffsets.updatedWith(newOffsets); + } + } + + protected List poll() throws InterruptedException { + try { + return task.poll(); + } catch (RetriableException | org.apache.kafka.common.errors.RetriableException e) { + log.warn("{} failed to poll records from SourceTask. Will retry operation.", this, e); + // Do nothing. Let the framework poll whenever it's ready. + return null; + } + } + + /** + * Convert the source record into a producer record. + * + * @param record the transformed record + * @return the producer record which can sent over to Kafka. A null is returned if the input is null or + * if an error was encountered during any of the converter stages. + */ + private ProducerRecord convertTransformedRecord(SourceRecord record) { + if (record == null) { + return null; + } + + RecordHeaders headers = retryWithToleranceOperator.execute(() -> convertHeaderFor(record), Stage.HEADER_CONVERTER, headerConverter.getClass()); + + byte[] key = retryWithToleranceOperator.execute(() -> keyConverter.fromConnectData(record.topic(), headers, record.keySchema(), record.key()), + Stage.KEY_CONVERTER, keyConverter.getClass()); + + byte[] value = retryWithToleranceOperator.execute(() -> valueConverter.fromConnectData(record.topic(), headers, record.valueSchema(), record.value()), + Stage.VALUE_CONVERTER, valueConverter.getClass()); + + if (retryWithToleranceOperator.failed()) { + return null; + } + + return new ProducerRecord<>(record.topic(), record.kafkaPartition(), + ConnectUtils.checkAndConvertTimestamp(record.timestamp()), key, value, headers); + } + + /** + * Try to send a batch of records. If a send fails and is retriable, this saves the remainder of the batch so it can + * be retried after backing off. If a send fails and is not retriable, this will throw a ConnectException. + * @return true if all messages were sent, false if some need to be retried + */ + private boolean sendRecords() { + int processed = 0; + recordBatch(toSend.size()); + final SourceRecordWriteCounter counter = + toSend.size() > 0 ? 
new SourceRecordWriteCounter(toSend.size(), sourceTaskMetricsGroup) : null; + for (final SourceRecord preTransformRecord : toSend) { + maybeThrowProducerSendException(); + + retryWithToleranceOperator.sourceRecord(preTransformRecord); + final SourceRecord record = transformationChain.apply(preTransformRecord); + final ProducerRecord producerRecord = convertTransformedRecord(record); + if (producerRecord == null || retryWithToleranceOperator.failed()) { + counter.skipRecord(); + commitTaskRecord(preTransformRecord, null); + continue; + } + + log.trace("{} Appending record to the topic {} with key {}, value {}", this, record.topic(), record.key(), record.value()); + SubmittedRecord submittedRecord = submittedRecords.submit(record); + try { + maybeCreateTopic(record.topic()); + final String topic = producerRecord.topic(); + producer.send( + producerRecord, + (recordMetadata, e) -> { + if (e != null) { + log.error("{} failed to send record to {}: ", WorkerSourceTask.this, topic, e); + log.trace("{} Failed record: {}", WorkerSourceTask.this, preTransformRecord); + producerSendException.compareAndSet(null, e); + } else { + submittedRecord.ack(); + counter.completeRecord(); + log.trace("{} Wrote record successfully: topic {} partition {} offset {}", + WorkerSourceTask.this, + recordMetadata.topic(), recordMetadata.partition(), + recordMetadata.offset()); + commitTaskRecord(preTransformRecord, recordMetadata); + if (isTopicTrackingEnabled) { + recordActiveTopic(producerRecord.topic()); + } + } + }); + } catch (RetriableException | org.apache.kafka.common.errors.RetriableException e) { + log.warn("{} Failed to send record to topic '{}' and partition '{}'. Backing off before retrying: ", + this, producerRecord.topic(), producerRecord.partition(), e); + toSend = toSend.subList(processed, toSend.size()); + submittedRecords.removeLastOccurrence(submittedRecord); + counter.retryRemaining(); + return false; + } catch (ConnectException e) { + log.warn("{} Failed to send record to topic '{}' and partition '{}' due to an unrecoverable exception: ", + this, producerRecord.topic(), producerRecord.partition(), e); + log.trace("{} Failed to send {} with unrecoverable exception: ", this, producerRecord, e); + throw e; + } catch (KafkaException e) { + throw new ConnectException("Unrecoverable exception trying to send", e); + } + processed++; + } + toSend = null; + return true; + } + + // Due to transformations that may change the destination topic of a record (such as + // RegexRouter) topic creation can not be batched for multiple topics + private void maybeCreateTopic(String topic) { + if (!topicCreation.isTopicCreationRequired(topic)) { + log.trace("Topic creation by the connector is disabled or the topic {} was previously created." + + "If auto.create.topics.enable is enabled on the broker, " + + "the topic will be created with default settings", topic); + return; + } + log.info("The task will send records to topic '{}' for the first time. 
Checking " + + "whether topic exists", topic); + Map existing = admin.describeTopics(topic); + if (!existing.isEmpty()) { + log.info("Topic '{}' already exists.", topic); + topicCreation.addTopic(topic); + return; + } + + log.info("Creating topic '{}'", topic); + TopicCreationGroup topicGroup = topicCreation.findFirstGroup(topic); + log.debug("Topic '{}' matched topic creation group: {}", topic, topicGroup); + NewTopic newTopic = topicGroup.newTopic(topic); + + TopicAdmin.TopicCreationResponse response = admin.createOrFindTopics(newTopic); + if (response.isCreated(newTopic.name())) { + topicCreation.addTopic(topic); + log.info("Created topic '{}' using creation group {}", newTopic, topicGroup); + } else if (response.isExisting(newTopic.name())) { + topicCreation.addTopic(topic); + log.info("Found existing topic '{}'", newTopic); + } else { + // The topic still does not exist and could not be created, so treat it as a task failure + log.warn("Request to create new topic '{}' failed", topic); + throw new ConnectException("Task failed to create new topic " + newTopic + ". Ensure " + + "that the task is authorized to create topics or that the topic exists and " + + "restart the task"); + } + } + + private RecordHeaders convertHeaderFor(SourceRecord record) { + Headers headers = record.headers(); + RecordHeaders result = new RecordHeaders(); + if (headers != null) { + String topic = record.topic(); + for (Header header : headers) { + String key = header.key(); + byte[] rawHeader = headerConverter.fromConnectHeader(topic, key, header.schema(), header.value()); + result.add(key, rawHeader); + } + } + return result; + } + + private void commitTaskRecord(SourceRecord record, RecordMetadata metadata) { + try { + task.commitRecord(record, metadata); + } catch (Throwable t) { + log.error("{} Exception thrown while calling task.commitRecord()", this, t); + } + } + + public boolean commitOffsets() { + long commitTimeoutMs = workerConfig.getLong(WorkerConfig.OFFSET_COMMIT_TIMEOUT_MS_CONFIG); + + log.debug("{} Committing offsets", this); + + long started = time.milliseconds(); + long timeout = started + commitTimeoutMs; + + CommittableOffsets offsetsToCommit; + synchronized (this) { + offsetsToCommit = this.committableOffsets; + this.committableOffsets = CommittableOffsets.EMPTY; + } + + if (committableOffsets.isEmpty()) { + log.info("{} Either no records were produced by the task since the last offset commit, " + + "or every record has been filtered out by a transformation " + + "or dropped due to transformation or conversion errors.", + this + ); + // We continue with the offset commit process here instead of simply returning immediately + // in order to invoke SourceTask::commit and record metrics for a successful offset commit + } else { + log.info("{} Committing offsets for {} acknowledged messages", this, committableOffsets.numCommittableMessages()); + if (committableOffsets.hasPending()) { + log.debug("{} There are currently {} pending messages spread across {} source partitions whose offsets will not be committed. 
" + + "The source partition with the most pending messages is {}, with {} pending messages", + this, + committableOffsets.numUncommittableMessages(), + committableOffsets.numDeques(), + committableOffsets.largestDequePartition(), + committableOffsets.largestDequeSize() + ); + } else { + log.debug("{} There are currently no pending messages for this offset commit; " + + "all messages dispatched to the task's producer since the last commit have been acknowledged", + this + ); + } + } + + // Update the offset writer with any new offsets for records that have been acked. + // The offset writer will continue to track all offsets until they are able to be successfully flushed. + // IOW, if the offset writer fails to flush, it keeps those offset for the next attempt, + // though we may update them here with newer offsets for acked records. + offsetsToCommit.offsets().forEach(offsetWriter::offset); + + if (!offsetWriter.beginFlush()) { + // There was nothing in the offsets to process, but we still mark a successful offset commit. + long durationMillis = time.milliseconds() - started; + recordCommitSuccess(durationMillis); + log.debug("{} Finished offset commitOffsets successfully in {} ms", + this, durationMillis); + + commitSourceTask(); + return true; + } + + // Now we can actually flush the offsets to user storage. + Future flushFuture = offsetWriter.doFlush((error, result) -> { + if (error != null) { + log.error("{} Failed to flush offsets to storage: ", WorkerSourceTask.this, error); + } else { + log.trace("{} Finished flushing offsets to storage", WorkerSourceTask.this); + } + }); + // Very rare case: offsets were unserializable and we finished immediately, unable to store + // any data + if (flushFuture == null) { + offsetWriter.cancelFlush(); + recordCommitFailure(time.milliseconds() - started, null); + return false; + } + try { + flushFuture.get(Math.max(timeout - time.milliseconds(), 0), TimeUnit.MILLISECONDS); + // There's a small race here where we can get the callback just as this times out (and log + // success), but then catch the exception below and cancel everything. This won't cause any + // errors, is only wasteful in this minor edge case, and the worst result is that the log + // could look a little confusing. 
+ } catch (InterruptedException e) { + log.warn("{} Flush of offsets interrupted, cancelling", this); + offsetWriter.cancelFlush(); + recordCommitFailure(time.milliseconds() - started, e); + return false; + } catch (ExecutionException e) { + log.error("{} Flush of offsets threw an unexpected exception: ", this, e); + offsetWriter.cancelFlush(); + recordCommitFailure(time.milliseconds() - started, e); + return false; + } catch (TimeoutException e) { + log.error("{} Timed out waiting to flush offsets to storage; will try again on next flush interval with latest offsets", this); + offsetWriter.cancelFlush(); + recordCommitFailure(time.milliseconds() - started, null); + return false; + } + + long durationMillis = time.milliseconds() - started; + recordCommitSuccess(durationMillis); + log.debug("{} Finished commitOffsets successfully in {} ms", + this, durationMillis); + + commitSourceTask(); + + return true; + } + + private void commitSourceTask() { + try { + this.task.commit(); + } catch (Throwable t) { + log.error("{} Exception thrown while calling task.commit()", this, t); + } + } + + @Override + public String toString() { + return "WorkerSourceTask{" + + "id=" + id + + '}'; + } + + protected void recordPollReturned(int numRecordsInBatch, long duration) { + sourceTaskMetricsGroup.recordPoll(numRecordsInBatch, duration); + } + + SourceTaskMetricsGroup sourceTaskMetricsGroup() { + return sourceTaskMetricsGroup; + } + + static class SourceRecordWriteCounter { + private final SourceTaskMetricsGroup metricsGroup; + private final int batchSize; + private boolean completed = false; + private int counter; + public SourceRecordWriteCounter(int batchSize, SourceTaskMetricsGroup metricsGroup) { + assert batchSize > 0; + assert metricsGroup != null; + this.batchSize = batchSize; + counter = batchSize; + this.metricsGroup = metricsGroup; + } + public void skipRecord() { + if (counter > 0 && --counter == 0) { + finishedAllWrites(); + } + } + public void completeRecord() { + if (counter > 0 && --counter == 0) { + finishedAllWrites(); + } + } + public void retryRemaining() { + finishedAllWrites(); + } + private void finishedAllWrites() { + if (!completed) { + metricsGroup.recordWrite(batchSize - counter); + completed = true; + } + } + } + + static class SourceTaskMetricsGroup { + private final MetricGroup metricGroup; + private final Sensor sourceRecordPoll; + private final Sensor sourceRecordWrite; + private final Sensor sourceRecordActiveCount; + private final Sensor pollTime; + private int activeRecordCount; + + public SourceTaskMetricsGroup(ConnectorTaskId id, ConnectMetrics connectMetrics) { + ConnectMetricsRegistry registry = connectMetrics.registry(); + metricGroup = connectMetrics.group(registry.sourceTaskGroupName(), + registry.connectorTagName(), id.connector(), + registry.taskTagName(), Integer.toString(id.task())); + // remove any previously created metrics in this group to prevent collisions. 
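The sensor registrations that continue below use Kafka's general-purpose metrics library: each sensor is given one or more statistics (rate, cumulative total, max, avg) and every record() call updates all of them. A self-contained sketch of that pattern with illustrative metric names, not the runtime's registry names:

import org.apache.kafka.common.metrics.Metrics;
import org.apache.kafka.common.metrics.Sensor;
import org.apache.kafka.common.metrics.stats.Avg;
import org.apache.kafka.common.metrics.stats.CumulativeSum;
import org.apache.kafka.common.metrics.stats.Max;
import org.apache.kafka.common.metrics.stats.Rate;

public class SensorSketch {
    public static void main(String[] args) {
        try (Metrics metrics = new Metrics()) {
            Sensor poll = metrics.sensor("source-record-poll");
            poll.add(metrics.metricName("source-record-poll-rate", "sketch-group"), new Rate());
            poll.add(metrics.metricName("source-record-poll-total", "sketch-group"), new CumulativeSum());

            Sensor pollTime = metrics.sensor("poll-batch-time");
            pollTime.add(metrics.metricName("poll-batch-max-time-ms", "sketch-group"), new Max());
            pollTime.add(metrics.metricName("poll-batch-avg-time-ms", "sketch-group"), new Avg());

            poll.record(100);    // one poll returned 100 records
            pollTime.record(25); // and took 25 ms

            metrics.metrics().forEach((name, metric) ->
                    System.out.println(name.name() + " = " + metric.metricValue()));
        }
    }
}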
+ metricGroup.close(); + + sourceRecordPoll = metricGroup.sensor("source-record-poll"); + sourceRecordPoll.add(metricGroup.metricName(registry.sourceRecordPollRate), new Rate()); + sourceRecordPoll.add(metricGroup.metricName(registry.sourceRecordPollTotal), new CumulativeSum()); + + sourceRecordWrite = metricGroup.sensor("source-record-write"); + sourceRecordWrite.add(metricGroup.metricName(registry.sourceRecordWriteRate), new Rate()); + sourceRecordWrite.add(metricGroup.metricName(registry.sourceRecordWriteTotal), new CumulativeSum()); + + pollTime = metricGroup.sensor("poll-batch-time"); + pollTime.add(metricGroup.metricName(registry.sourceRecordPollBatchTimeMax), new Max()); + pollTime.add(metricGroup.metricName(registry.sourceRecordPollBatchTimeAvg), new Avg()); + + sourceRecordActiveCount = metricGroup.sensor("source-record-active-count"); + sourceRecordActiveCount.add(metricGroup.metricName(registry.sourceRecordActiveCount), new Value()); + sourceRecordActiveCount.add(metricGroup.metricName(registry.sourceRecordActiveCountMax), new Max()); + sourceRecordActiveCount.add(metricGroup.metricName(registry.sourceRecordActiveCountAvg), new Avg()); + } + + void close() { + metricGroup.close(); + } + + void recordPoll(int batchSize, long duration) { + sourceRecordPoll.record(batchSize); + pollTime.record(duration); + activeRecordCount += batchSize; + sourceRecordActiveCount.record(activeRecordCount); + } + + void recordWrite(int recordCount) { + sourceRecordWrite.record(recordCount); + activeRecordCount -= recordCount; + activeRecordCount = Math.max(0, activeRecordCount); + sourceRecordActiveCount.record(activeRecordCount); + } + + protected MetricGroup metricGroup() { + return metricGroup; + } + } +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/WorkerSourceTaskContext.java b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/WorkerSourceTaskContext.java new file mode 100644 index 0000000..fe1409b --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/WorkerSourceTaskContext.java @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.connect.runtime; + +import org.apache.kafka.connect.runtime.distributed.ClusterConfigState; +import org.apache.kafka.connect.source.SourceTaskContext; +import org.apache.kafka.connect.storage.OffsetStorageReader; + +import java.util.Map; + +public class WorkerSourceTaskContext implements SourceTaskContext { + + private final OffsetStorageReader reader; + private final WorkerSourceTask task; + private final ClusterConfigState configState; + + public WorkerSourceTaskContext(OffsetStorageReader reader, + WorkerSourceTask task, + ClusterConfigState configState) { + this.reader = reader; + this.task = task; + this.configState = configState; + } + + @Override + public Map configs() { + return configState.taskConfig(task.id()); + } + + @Override + public OffsetStorageReader offsetStorageReader() { + return reader; + } +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/WorkerTask.java b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/WorkerTask.java new file mode 100644 index 0000000..0d893f5 --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/WorkerTask.java @@ -0,0 +1,466 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.runtime; + +import org.apache.kafka.common.MetricName; +import org.apache.kafka.common.MetricNameTemplate; +import org.apache.kafka.common.metrics.Sensor; +import org.apache.kafka.common.metrics.stats.Avg; +import org.apache.kafka.common.metrics.stats.Frequencies; +import org.apache.kafka.common.metrics.stats.Max; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.connect.runtime.AbstractStatus.State; +import org.apache.kafka.connect.runtime.ConnectMetrics.MetricGroup; +import org.apache.kafka.connect.runtime.errors.RetryWithToleranceOperator; +import org.apache.kafka.connect.runtime.isolation.Plugins; +import org.apache.kafka.connect.storage.StatusBackingStore; +import org.apache.kafka.connect.util.ConnectorTaskId; +import org.apache.kafka.connect.util.LoggingContext; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Locale; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; + +/** + * Handles processing for an individual task. This interface only provides the basic methods + * used by {@link Worker} to manage the tasks. Implementations combine a user-specified Task with + * Kafka to create a data flow. + * + * Note on locking: since the task runs in its own thread, special care must be taken to ensure + * that state transitions are reported correctly, in particular since some state transitions are + * asynchronous (e.g. pause/resume). 
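WorkerSourceTaskContext above is the object handed to a connector's SourceTask, giving it its task config and an OffsetStorageReader. A hedged sketch of how a task on the connector side typically uses it to resume from the last committed offset; the "file"/"filename"/"position" keys and the cast are illustrative, not part of this patch:

import java.util.Collections;
import java.util.List;
import java.util.Map;

import org.apache.kafka.connect.source.SourceRecord;
import org.apache.kafka.connect.source.SourceTask;

public class ResumingSourceTask extends SourceTask {
    private long position = 0L;

    @Override
    public void start(Map<String, String> props) {
        // Ask the framework for the offset it last committed for this task's source partition.
        Map<String, Object> lastOffset = context.offsetStorageReader()
                .offset(Collections.singletonMap("filename", props.get("file")));
        if (lastOffset != null && lastOffset.get("position") != null)
            position = (Long) lastOffset.get("position");
    }

    @Override
    public List<SourceRecord> poll() throws InterruptedException {
        return null; // a real task would read from `position` onwards and emit SourceRecords
    }

    @Override
    public void stop() {
    }

    @Override
    public String version() {
        return "0.0.1";
    }
}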
For example, changing the state to paused could cause a race + * if the task fails at the same time. To protect from these cases, we synchronize status updates + * using the WorkerTask's monitor. + */ +abstract class WorkerTask implements Runnable { + private static final Logger log = LoggerFactory.getLogger(WorkerTask.class); + private static final String THREAD_NAME_PREFIX = "task-thread-"; + + protected final ConnectorTaskId id; + private final TaskStatus.Listener statusListener; + protected final ClassLoader loader; + protected final StatusBackingStore statusBackingStore; + protected final Time time; + private final CountDownLatch shutdownLatch = new CountDownLatch(1); + private final TaskMetricsGroup taskMetricsGroup; + private volatile TargetState targetState; + private volatile boolean stopping; // indicates whether the Worker has asked the task to stop + private volatile boolean cancelled; // indicates whether the Worker has cancelled the task (e.g. because of slow shutdown) + + protected final RetryWithToleranceOperator retryWithToleranceOperator; + + public WorkerTask(ConnectorTaskId id, + TaskStatus.Listener statusListener, + TargetState initialState, + ClassLoader loader, + ConnectMetrics connectMetrics, + RetryWithToleranceOperator retryWithToleranceOperator, + Time time, + StatusBackingStore statusBackingStore) { + this.id = id; + this.taskMetricsGroup = new TaskMetricsGroup(this.id, connectMetrics, statusListener); + this.statusListener = taskMetricsGroup; + this.loader = loader; + this.targetState = initialState; + this.stopping = false; + this.cancelled = false; + this.taskMetricsGroup.recordState(this.targetState); + this.retryWithToleranceOperator = retryWithToleranceOperator; + this.time = time; + this.statusBackingStore = statusBackingStore; + } + + public ConnectorTaskId id() { + return id; + } + + public ClassLoader loader() { + return loader; + } + + /** + * Initialize the task for execution. + * + * @param taskConfig initial configuration + */ + public abstract void initialize(TaskConfig taskConfig); + + + private void triggerStop() { + synchronized (this) { + stopping = true; + + // wakeup any threads that are waiting for unpause + this.notifyAll(); + } + } + + /** + * Stop this task from processing messages. This method does not block, it only triggers + * shutdown. Use #{@link #awaitStop} to block until completion. + */ + public void stop() { + triggerStop(); + } + + /** + * Cancel this task. This won't actually stop it, but it will prevent the state from being + * updated when it eventually does shutdown. + */ + public void cancel() { + cancelled = true; + } + + /** + * Wait for this task to finish stopping. + * + * @param timeoutMs time in milliseconds to await stop + * @return true if successful, false if the timeout was reached + */ + public boolean awaitStop(long timeoutMs) { + try { + return shutdownLatch.await(timeoutMs, TimeUnit.MILLISECONDS); + } catch (InterruptedException e) { + return false; + } + } + + /** + * Remove all metrics published by this task. 
+ */ + public void removeMetrics() { + taskMetricsGroup.close(); + } + + protected abstract void initializeAndStart(); + + protected abstract void execute(); + + protected abstract void close(); + + protected boolean isStopping() { + return stopping; + } + + protected boolean isCancelled() { + return cancelled; + } + + private void doClose() { + try { + close(); + } catch (Throwable t) { + log.error("{} Task threw an uncaught and unrecoverable exception during shutdown", this, t); + throw t; + } + } + + private void doRun() throws InterruptedException { + try { + synchronized (this) { + if (stopping) + return; + + if (targetState == TargetState.PAUSED) { + onPause(); + if (!awaitUnpause()) return; + } + } + + initializeAndStart(); + statusListener.onStartup(id); + execute(); + } catch (Throwable t) { + if (cancelled) { + log.warn("{} After being scheduled for shutdown, the orphan task threw an uncaught exception. A newer instance of this task might be already running", this, t); + } else if (stopping) { + log.warn("{} After being scheduled for shutdown, task threw an uncaught exception.", this, t); + } else { + log.error("{} Task threw an uncaught and unrecoverable exception. Task is being killed and will not recover until manually restarted", this, t); + throw t; + } + } finally { + doClose(); + } + } + + private void onShutdown() { + synchronized (this) { + triggerStop(); + + // if we were cancelled, skip the status update since the task may have already been + // started somewhere else + if (!cancelled) + statusListener.onShutdown(id); + } + } + + protected void onFailure(Throwable t) { + synchronized (this) { + triggerStop(); + + // if we were cancelled, skip the status update since the task may have already been + // started somewhere else + if (!cancelled) + statusListener.onFailure(id, t); + } + } + + protected synchronized void onPause() { + statusListener.onPause(id); + } + + protected synchronized void onResume() { + statusListener.onResume(id); + } + + @Override + public void run() { + // Clear all MDC parameters, in case this thread is being reused + LoggingContext.clear(); + + try (LoggingContext loggingContext = LoggingContext.forTask(id())) { + ClassLoader savedLoader = Plugins.compareAndSwapLoaders(loader); + String savedName = Thread.currentThread().getName(); + try { + Thread.currentThread().setName(THREAD_NAME_PREFIX + id); + doRun(); + onShutdown(); + } catch (Throwable t) { + onFailure(t); + + if (t instanceof Error) + throw (Error) t; + } finally { + Thread.currentThread().setName(savedName); + Plugins.compareAndSwapLoaders(savedLoader); + shutdownLatch.countDown(); + } + } + } + + public boolean shouldPause() { + return this.targetState == TargetState.PAUSED; + } + + /** + * Await task resumption. + * + * @return true if the task's target state is not paused, false if the task is shutdown before resumption + * @throws InterruptedException + */ + protected boolean awaitUnpause() throws InterruptedException { + synchronized (this) { + while (targetState == TargetState.PAUSED) { + if (stopping) + return false; + this.wait(); + } + return true; + } + } + + public void transitionTo(TargetState state) { + synchronized (this) { + // ignore the state change if we are stopping + if (stopping) + return; + + this.targetState = state; + this.notifyAll(); + } + } + + /** + * Include this topic to the set of active topics for the connector that this worker task + * is running. This information is persisted in the status backing store used by this worker. 
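awaitUnpause() and transitionTo() above implement a guarded wait on the task's own monitor: the task thread sleeps while the target state is PAUSED and is woken by notifyAll() when it is resumed or stopped. A stripped-down, runnable sketch of that idiom with illustrative names:

import java.util.concurrent.TimeUnit;

public class PauseGate {
    enum TargetState { STARTED, PAUSED }

    private TargetState targetState = TargetState.STARTED;
    private boolean stopping = false;

    public synchronized boolean awaitUnpause() throws InterruptedException {
        while (targetState == TargetState.PAUSED) {
            if (stopping)
                return false;   // asked to shut down while paused
            wait();             // releases the monitor until notifyAll()
        }
        return true;
    }

    public synchronized void transitionTo(TargetState state) {
        if (stopping)
            return;             // ignore transitions once stopping
        this.targetState = state;
        notifyAll();            // wake any thread parked in awaitUnpause()
    }

    public synchronized void stop() {
        stopping = true;
        notifyAll();
    }

    public static void main(String[] args) throws Exception {
        PauseGate gate = new PauseGate();
        gate.transitionTo(TargetState.PAUSED);
        Thread worker = new Thread(() -> {
            try {
                System.out.println("resumed? " + gate.awaitUnpause());
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
            }
        });
        worker.start();
        TimeUnit.MILLISECONDS.sleep(100);
        gate.transitionTo(TargetState.STARTED); // wakes the worker thread
        worker.join();
    }
}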
+ * + * @param topic the topic to mark as active for this connector + */ + protected void recordActiveTopic(String topic) { + if (statusBackingStore.getTopic(id.connector(), topic) != null) { + // The topic is already recorded as active. No further action is required. + return; + } + statusBackingStore.put(new TopicStatus(topic, id, time.milliseconds())); + } + + /** + * Record that offsets have been committed. + * + * @param duration the length of time in milliseconds for the commit attempt to complete + */ + protected void recordCommitSuccess(long duration) { + taskMetricsGroup.recordCommit(duration, true, null); + } + + /** + * Record that offsets have been committed. + * + * @param duration the length of time in milliseconds for the commit attempt to complete + * @param error the unexpected error that occurred; may be null in the case of timeouts or interruptions + */ + protected void recordCommitFailure(long duration, Throwable error) { + taskMetricsGroup.recordCommit(duration, false, error); + } + + /** + * Record that a batch of records has been processed. + * + * @param size the number of records in the batch + */ + protected void recordBatch(int size) { + taskMetricsGroup.recordBatch(size); + } + + TaskMetricsGroup taskMetricsGroup() { + return taskMetricsGroup; + } + + static class TaskMetricsGroup implements TaskStatus.Listener { + private final TaskStatus.Listener delegateListener; + private final MetricGroup metricGroup; + private final Time time; + private final StateTracker taskStateTimer; + private final Sensor commitTime; + private final Sensor batchSize; + private final Sensor commitAttempts; + + public TaskMetricsGroup(ConnectorTaskId id, ConnectMetrics connectMetrics, TaskStatus.Listener statusListener) { + delegateListener = statusListener; + time = connectMetrics.time(); + taskStateTimer = new StateTracker(); + ConnectMetricsRegistry registry = connectMetrics.registry(); + metricGroup = connectMetrics.group(registry.taskGroupName(), + registry.connectorTagName(), id.connector(), + registry.taskTagName(), Integer.toString(id.task())); + // prevent collisions by removing any previously created metrics in this group. 
+ metricGroup.close(); + + metricGroup.addValueMetric(registry.taskStatus, now -> + taskStateTimer.currentState().toString().toLowerCase(Locale.getDefault()) + ); + + addRatioMetric(State.RUNNING, registry.taskRunningRatio); + addRatioMetric(State.PAUSED, registry.taskPauseRatio); + + commitTime = metricGroup.sensor("commit-time"); + commitTime.add(metricGroup.metricName(registry.taskCommitTimeMax), new Max()); + commitTime.add(metricGroup.metricName(registry.taskCommitTimeAvg), new Avg()); + + batchSize = metricGroup.sensor("batch-size"); + batchSize.add(metricGroup.metricName(registry.taskBatchSizeMax), new Max()); + batchSize.add(metricGroup.metricName(registry.taskBatchSizeAvg), new Avg()); + + MetricName offsetCommitFailures = metricGroup.metricName(registry.taskCommitFailurePercentage); + MetricName offsetCommitSucceeds = metricGroup.metricName(registry.taskCommitSuccessPercentage); + Frequencies commitFrequencies = Frequencies.forBooleanValues(offsetCommitFailures, offsetCommitSucceeds); + commitAttempts = metricGroup.sensor("offset-commit-completion"); + commitAttempts.add(commitFrequencies); + } + + private void addRatioMetric(final State matchingState, MetricNameTemplate template) { + MetricName metricName = metricGroup.metricName(template); + if (metricGroup.metrics().metric(metricName) == null) { + metricGroup.metrics().addMetric(metricName, (config, now) -> + taskStateTimer.durationRatio(matchingState, now)); + } + } + + void close() { + metricGroup.close(); + } + + void recordCommit(long duration, boolean success, Throwable error) { + if (success) { + commitTime.record(duration); + commitAttempts.record(1.0d); + } else { + commitAttempts.record(0.0d); + } + } + + void recordBatch(int size) { + batchSize.record(size); + } + + @Override + public void onStartup(ConnectorTaskId id) { + taskStateTimer.changeState(State.RUNNING, time.milliseconds()); + delegateListener.onStartup(id); + } + + @Override + public void onFailure(ConnectorTaskId id, Throwable cause) { + taskStateTimer.changeState(State.FAILED, time.milliseconds()); + delegateListener.onFailure(id, cause); + } + + @Override + public void onPause(ConnectorTaskId id) { + taskStateTimer.changeState(State.PAUSED, time.milliseconds()); + delegateListener.onPause(id); + } + + @Override + public void onResume(ConnectorTaskId id) { + taskStateTimer.changeState(State.RUNNING, time.milliseconds()); + delegateListener.onResume(id); + } + + @Override + public void onShutdown(ConnectorTaskId id) { + taskStateTimer.changeState(State.UNASSIGNED, time.milliseconds()); + delegateListener.onShutdown(id); + } + + @Override + public void onDeletion(ConnectorTaskId id) { + taskStateTimer.changeState(State.DESTROYED, time.milliseconds()); + delegateListener.onDeletion(id); + } + + @Override + public void onRestart(ConnectorTaskId id) { + taskStateTimer.changeState(State.RESTARTING, time.milliseconds()); + delegateListener.onRestart(id); + } + + public void recordState(TargetState state) { + switch (state) { + case STARTED: + taskStateTimer.changeState(State.RUNNING, time.milliseconds()); + break; + case PAUSED: + taskStateTimer.changeState(State.PAUSED, time.milliseconds()); + break; + default: + break; + } + } + + public State state() { + return taskStateTimer.currentState(); + } + + protected MetricGroup metricGroup() { + return metricGroup; + } + } +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/distributed/ClusterConfigState.java 
b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/distributed/ClusterConfigState.java new file mode 100644 index 0000000..717120d --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/distributed/ClusterConfigState.java @@ -0,0 +1,283 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.runtime.distributed; + +import org.apache.kafka.common.config.provider.ConfigProvider; +import org.apache.kafka.connect.runtime.SessionKey; +import org.apache.kafka.connect.runtime.WorkerConfigTransformer; +import org.apache.kafka.connect.runtime.TargetState; +import org.apache.kafka.connect.util.ConnectorTaskId; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Set; +import java.util.TreeMap; + +/** + * An immutable snapshot of the configuration state of connectors and tasks in a Kafka Connect cluster. + */ +public class ClusterConfigState { + public static final long NO_OFFSET = -1; + public static final ClusterConfigState EMPTY = new ClusterConfigState( + NO_OFFSET, + null, + Collections.emptyMap(), + Collections.emptyMap(), + Collections.emptyMap(), + Collections.emptyMap(), + Collections.emptySet()); + + private final long offset; + private final SessionKey sessionKey; + private final Map connectorTaskCounts; + private final Map> connectorConfigs; + private final Map connectorTargetStates; + private final Map> taskConfigs; + private final Set inconsistentConnectors; + private final WorkerConfigTransformer configTransformer; + + public ClusterConfigState(long offset, + SessionKey sessionKey, + Map connectorTaskCounts, + Map> connectorConfigs, + Map connectorTargetStates, + Map> taskConfigs, + Set inconsistentConnectors) { + this(offset, + sessionKey, + connectorTaskCounts, + connectorConfigs, + connectorTargetStates, + taskConfigs, + inconsistentConnectors, + null); + } + + public ClusterConfigState(long offset, + SessionKey sessionKey, + Map connectorTaskCounts, + Map> connectorConfigs, + Map connectorTargetStates, + Map> taskConfigs, + Set inconsistentConnectors, + WorkerConfigTransformer configTransformer) { + this.offset = offset; + this.sessionKey = sessionKey; + this.connectorTaskCounts = connectorTaskCounts; + this.connectorConfigs = connectorConfigs; + this.connectorTargetStates = connectorTargetStates; + this.taskConfigs = taskConfigs; + this.inconsistentConnectors = inconsistentConnectors; + this.configTransformer = configTransformer; + } + + /** + * Get the last offset read to generate this config state. This offset is not guaranteed to be perfectly consistent + * with the recorded state because some partial updates to task configs may have been read. 
+ * @return the latest config offset + */ + public long offset() { + return offset; + } + + /** + * Get the latest session key from the config state + * @return the {@link SessionKey session key}; may be null if no key has been read yet + */ + public SessionKey sessionKey() { + return sessionKey; + } + + /** + * Check whether this snapshot contains configuration for a connector. + * @param connector name of the connector + * @return true if this state contains configuration for the connector, false otherwise + */ + public boolean contains(String connector) { + return connectorConfigs.containsKey(connector); + } + + /** + * Get a list of the connectors in this configuration + */ + public Set connectors() { + return connectorConfigs.keySet(); + } + + /** + * Get the configuration for a connector. The configuration will have been transformed by + * {@link org.apache.kafka.common.config.ConfigTransformer} by having all variable + * references replaced with the current values from external instances of + * {@link ConfigProvider}, and may include secrets. + * @param connector name of the connector + * @return a map containing configuration parameters + */ + public Map connectorConfig(String connector) { + Map configs = connectorConfigs.get(connector); + if (configTransformer != null) { + configs = configTransformer.transform(connector, configs); + } + return configs; + } + + public Map rawConnectorConfig(String connector) { + return connectorConfigs.get(connector); + } + + /** + * Get the target state of the connector + * @param connector name of the connector + * @return the target state + */ + public TargetState targetState(String connector) { + return connectorTargetStates.get(connector); + } + + /** + * Get the configuration for a task. The configuration will have been transformed by + * {@link org.apache.kafka.common.config.ConfigTransformer} by having all variable + * references replaced with the current values from external instances of + * {@link ConfigProvider}, and may include secrets. + * @param task id of the task + * @return a map containing configuration parameters + */ + public Map taskConfig(ConnectorTaskId task) { + Map configs = taskConfigs.get(task); + if (configTransformer != null) { + configs = configTransformer.transform(task.connector(), configs); + } + return configs; + } + + public Map rawTaskConfig(ConnectorTaskId task) { + return taskConfigs.get(task); + } + + /** + * Get all task configs for a connector. The configurations will have been transformed by + * {@link org.apache.kafka.common.config.ConfigTransformer} by having all variable + * references replaced with the current values from external instances of + * {@link ConfigProvider}, and may include secrets. + * @param connector name of the connector + * @return a list of task configurations + */ + public List> allTaskConfigs(String connector) { + Map> taskConfigs = new TreeMap<>(); + for (Map.Entry> taskConfigEntry : this.taskConfigs.entrySet()) { + if (taskConfigEntry.getKey().connector().equals(connector)) { + Map configs = taskConfigEntry.getValue(); + if (configTransformer != null) { + configs = configTransformer.transform(connector, configs); + } + taskConfigs.put(taskConfigEntry.getKey().task(), configs); + } + } + return Collections.unmodifiableList(new ArrayList<>(taskConfigs.values())); + } + + /** + * Get the number of tasks assigned for the given connector. 
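connectorConfig() and taskConfig() above return configurations whose ${provider:path:key} placeholders have already been resolved against ConfigProvider instances. A toy illustration of that substitution, not the real WorkerConfigTransformer; the provider, path, and secret values are all made up:

import java.util.HashMap;
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

public class ConfigSubstitutionSketch {
    // Matches placeholders of the form ${provider:path:key}.
    private static final Pattern PLACEHOLDER = Pattern.compile("\\$\\{([^:}]+):([^:}]+):([^}]+)}");

    static Map<String, String> transform(Map<String, String> raw, Map<String, String> secrets) {
        Map<String, String> resolved = new HashMap<>();
        for (Map.Entry<String, String> entry : raw.entrySet()) {
            Matcher m = PLACEHOLDER.matcher(entry.getValue());
            StringBuffer sb = new StringBuffer();
            while (m.find()) {
                // Look up "path:key"; leave the placeholder untouched if no value is known.
                String replacement = secrets.getOrDefault(m.group(2) + ":" + m.group(3), m.group(0));
                m.appendReplacement(sb, Matcher.quoteReplacement(replacement));
            }
            m.appendTail(sb);
            resolved.put(entry.getKey(), sb.toString());
        }
        return resolved;
    }

    public static void main(String[] args) {
        Map<String, String> raw = new HashMap<>();
        raw.put("connection.password", "${file:/opt/secrets.properties:db.password}");
        Map<String, String> secrets = new HashMap<>();
        secrets.put("/opt/secrets.properties:db.password", "hunter2");
        System.out.println(transform(raw, secrets)); // {connection.password=hunter2}
    }
}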
+ * @param connectorName name of the connector to look up tasks for + * @return the number of tasks + */ + public int taskCount(String connectorName) { + Integer count = connectorTaskCounts.get(connectorName); + return count == null ? 0 : count; + } + + /** + * Get the current set of task IDs for the specified connector. + * @param connectorName the name of the connector to look up task configs for + * @return the current set of connector task IDs + */ + public List tasks(String connectorName) { + if (inconsistentConnectors.contains(connectorName)) { + return Collections.emptyList(); + } + + Integer numTasks = connectorTaskCounts.get(connectorName); + if (numTasks == null) { + return Collections.emptyList(); + } + + List taskIds = new ArrayList<>(numTasks); + for (int taskIndex = 0; taskIndex < numTasks; taskIndex++) { + ConnectorTaskId taskId = new ConnectorTaskId(connectorName, taskIndex); + taskIds.add(taskId); + } + return Collections.unmodifiableList(taskIds); + } + + /** + * Get the set of connectors which have inconsistent data in this snapshot. These inconsistencies can occur due to + * partially completed writes combined with log compaction. + * + * Connectors in this set will appear in the output of {@link #connectors()} since their connector configuration is + * available, but not in the output of {@link #taskConfig(ConnectorTaskId)} since the task configs are incomplete. + * + * When a worker detects a connector in this state, it should request that the connector regenerate its task + * configurations. + * + * @return the set of inconsistent connectors + */ + public Set inconsistentConnectors() { + return inconsistentConnectors; + } + + @Override + public String toString() { + return "ClusterConfigState{" + + "offset=" + offset + + ", sessionKey=" + (sessionKey != null ? "[hidden]" : "null") + + ", connectorTaskCounts=" + connectorTaskCounts + + ", connectorConfigs=" + connectorConfigs + + ", taskConfigs=" + taskConfigs + + ", inconsistentConnectors=" + inconsistentConnectors + + '}'; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + ClusterConfigState that = (ClusterConfigState) o; + return offset == that.offset && + Objects.equals(sessionKey, that.sessionKey) && + Objects.equals(connectorTaskCounts, that.connectorTaskCounts) && + Objects.equals(connectorConfigs, that.connectorConfigs) && + Objects.equals(connectorTargetStates, that.connectorTargetStates) && + Objects.equals(taskConfigs, that.taskConfigs) && + Objects.equals(inconsistentConnectors, that.inconsistentConnectors) && + Objects.equals(configTransformer, that.configTransformer); + } + + @Override + public int hashCode() { + return Objects.hash( + offset, + sessionKey, + connectorTaskCounts, + connectorConfigs, + connectorTargetStates, + taskConfigs, + inconsistentConnectors, + configTransformer); + } +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/distributed/ConnectAssignor.java b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/distributed/ConnectAssignor.java new file mode 100644 index 0000000..752e62e --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/distributed/ConnectAssignor.java @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.runtime.distributed; + +import org.apache.kafka.common.message.JoinGroupResponseData; + +import java.nio.ByteBuffer; +import java.util.List; +import java.util.Map; + +/** + * An assignor that computes a distribution of connectors and tasks among the workers of the group + * that performs rebalancing. + */ +public interface ConnectAssignor { + /** + * Based on the member metadata and the information stored in the worker coordinator this + * method computes an assignment of connectors and tasks among the members of the worker group. + * + * @param leaderId the leader of the group + * @param protocol the protocol type; for Connect assignors this is normally "connect" + * @param allMemberMetadata the metadata of all the active workers of the group + * @param coordinator the worker coordinator that runs this assignor + * @return the assignment of connectors and tasks to workers + */ + Map performAssignment(String leaderId, String protocol, + List allMemberMetadata, + WorkerCoordinator coordinator); +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/distributed/ConnectProtocol.java b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/distributed/ConnectProtocol.java new file mode 100644 index 0000000..c167e80 --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/distributed/ConnectProtocol.java @@ -0,0 +1,407 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.connect.runtime.distributed; + +import org.apache.kafka.common.protocol.types.ArrayOf; +import org.apache.kafka.common.protocol.types.Field; +import org.apache.kafka.common.protocol.types.Schema; +import org.apache.kafka.common.protocol.types.SchemaException; +import org.apache.kafka.common.protocol.types.Struct; +import org.apache.kafka.common.protocol.types.Type; +import org.apache.kafka.connect.util.ConnectorTaskId; + +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.HashSet; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; + +import static org.apache.kafka.common.message.JoinGroupRequestData.JoinGroupRequestProtocol; +import static org.apache.kafka.common.message.JoinGroupRequestData.JoinGroupRequestProtocolCollection; +import static org.apache.kafka.connect.runtime.distributed.ConnectProtocolCompatibility.EAGER; + +/** + * This class implements the protocol for Kafka Connect workers in a group. It includes the format of worker state used when + * joining the group and distributing assignments, and the format of assignments of connectors and tasks to workers. + */ +public class ConnectProtocol { + public static final String VERSION_KEY_NAME = "version"; + public static final String URL_KEY_NAME = "url"; + public static final String CONFIG_OFFSET_KEY_NAME = "config-offset"; + public static final String CONNECTOR_KEY_NAME = "connector"; + public static final String LEADER_KEY_NAME = "leader"; + public static final String LEADER_URL_KEY_NAME = "leader-url"; + public static final String ERROR_KEY_NAME = "error"; + public static final String TASKS_KEY_NAME = "tasks"; + public static final String ASSIGNMENT_KEY_NAME = "assignment"; + public static final int CONNECTOR_TASK = -1; + + public static final short CONNECT_PROTOCOL_V0 = 0; + public static final Schema CONNECT_PROTOCOL_HEADER_SCHEMA = new Schema( + new Field(VERSION_KEY_NAME, Type.INT16)); + + /** + * Connect Protocol Header V0: + *
+     *   Version            => Int16
+     * </pre>
+     */
+    private static final Struct CONNECT_PROTOCOL_HEADER_V0 = new Struct(CONNECT_PROTOCOL_HEADER_SCHEMA)
+            .set(VERSION_KEY_NAME, CONNECT_PROTOCOL_V0);
+
+    /**
+     * Config State V0:
+     * <pre>
+     *   Url                => [String]
+     *   ConfigOffset       => Int64
+     * </pre>
+     */
+    public static final Schema CONFIG_STATE_V0 = new Schema(
+            new Field(URL_KEY_NAME, Type.STRING),
+            new Field(CONFIG_OFFSET_KEY_NAME, Type.INT64));
+
+    /**
+     * Connector Assignment V0:
+     * <pre>
+     *   Connector          => [String]
+     *   Tasks              => [Int32]
+     * </pre>
+     *
+     * Assignments for each worker are a set of connectors and tasks. These are categorized by
+     * connector ID. A sentinel task ID (CONNECTOR_TASK) is used to indicate the connector itself
+     * (i.e. that the assignment includes responsibility for running the Connector instance in
+     * addition to any tasks it generates).
+     */
+    public static final Schema CONNECTOR_ASSIGNMENT_V0 = new Schema(
+            new Field(CONNECTOR_KEY_NAME, Type.STRING),
+            new Field(TASKS_KEY_NAME, new ArrayOf(Type.INT32)));
+
+    /**
+     * Assignment V0:
+     * <pre>
+     *   Error              => Int16
+     *   Leader             => [String]
+     *   LeaderUrl          => [String]
+     *   ConfigOffset       => Int64
+     *   Assignment         => [Connector Assignment]
+     * </pre>
+     */
+    public static final Schema ASSIGNMENT_V0 = new Schema(
+            new Field(ERROR_KEY_NAME, Type.INT16),
+            new Field(LEADER_KEY_NAME, Type.STRING),
+            new Field(LEADER_URL_KEY_NAME, Type.STRING),
+            new Field(CONFIG_OFFSET_KEY_NAME, Type.INT64),
+            new Field(ASSIGNMENT_KEY_NAME, new ArrayOf(CONNECTOR_ASSIGNMENT_V0)));
+
+    /**
+     * The fields are serialized in sequence as follows:
+     * Subscription V0:
+     * <pre>
+     *   Version            => Int16
+     *   Url                => [String]
+     *   ConfigOffset       => Int64
+     * </pre>
+     *
+     * @param workerState the current state of the worker metadata
+     * @return the serialized state of the worker metadata
+     */
+    public static ByteBuffer serializeMetadata(WorkerState workerState) {
+        Struct struct = new Struct(CONFIG_STATE_V0);
+        struct.set(URL_KEY_NAME, workerState.url());
+        struct.set(CONFIG_OFFSET_KEY_NAME, workerState.offset());
+        ByteBuffer buffer = ByteBuffer.allocate(CONNECT_PROTOCOL_HEADER_V0.sizeOf() + CONFIG_STATE_V0.sizeOf(struct));
+        CONNECT_PROTOCOL_HEADER_V0.writeTo(buffer);
+        CONFIG_STATE_V0.write(buffer, struct);
+        buffer.flip();
+        return buffer;
+    }
+
+    /**
+     * Returns the collection of Connect protocols that are supported by this version along
+     * with their serialized metadata. The protocols are ordered by preference.
+     *
+     * @param workerState the current state of the worker metadata
+     * @return the collection of Connect protocol metadata
+     */
+    public static JoinGroupRequestProtocolCollection metadataRequest(WorkerState workerState) {
+        return new JoinGroupRequestProtocolCollection(Collections.singleton(
+                new JoinGroupRequestProtocol()
+                        .setName(EAGER.protocol())
+                        .setMetadata(ConnectProtocol.serializeMetadata(workerState).array()))
+                .iterator());
+    }
+
+    /**
+     * Given a byte buffer that contains protocol metadata return the deserialized form of the
+     * metadata.
+     *
+     * @param buffer A buffer containing the protocols metadata
+     * @return the deserialized metadata
+     * @throws SchemaException on incompatible Connect protocol version
+     */
+    public static WorkerState deserializeMetadata(ByteBuffer buffer) {
+        Struct header = CONNECT_PROTOCOL_HEADER_SCHEMA.read(buffer);
+        Short version = header.getShort(VERSION_KEY_NAME);
+        checkVersionCompatibility(version);
+        Struct struct = CONFIG_STATE_V0.read(buffer);
+        long configOffset = struct.getLong(CONFIG_OFFSET_KEY_NAME);
+        String url = struct.getString(URL_KEY_NAME);
+        return new WorkerState(url, configOffset);
+    }
+
+    /**
+     * The fields are serialized in sequence as follows:
+     * Complete Assignment V0:
+     * <pre>
+     *   Version            => Int16
+     *   Error              => Int16
+     *   Leader             => [String]
+     *   LeaderUrl          => [String]
+     *   ConfigOffset       => Int64
+     *   Assignment         => [Connector Assignment]
+     * </pre>
            + */ + public static ByteBuffer serializeAssignment(Assignment assignment) { + Struct struct = new Struct(ASSIGNMENT_V0); + struct.set(ERROR_KEY_NAME, assignment.error()); + struct.set(LEADER_KEY_NAME, assignment.leader()); + struct.set(LEADER_URL_KEY_NAME, assignment.leaderUrl()); + struct.set(CONFIG_OFFSET_KEY_NAME, assignment.offset()); + List taskAssignments = new ArrayList<>(); + for (Map.Entry> connectorEntry : assignment.asMap().entrySet()) { + Struct taskAssignment = new Struct(CONNECTOR_ASSIGNMENT_V0); + taskAssignment.set(CONNECTOR_KEY_NAME, connectorEntry.getKey()); + Collection tasks = connectorEntry.getValue(); + taskAssignment.set(TASKS_KEY_NAME, tasks.toArray()); + taskAssignments.add(taskAssignment); + } + struct.set(ASSIGNMENT_KEY_NAME, taskAssignments.toArray()); + + ByteBuffer buffer = ByteBuffer.allocate(CONNECT_PROTOCOL_HEADER_V0.sizeOf() + ASSIGNMENT_V0.sizeOf(struct)); + CONNECT_PROTOCOL_HEADER_V0.writeTo(buffer); + ASSIGNMENT_V0.write(buffer, struct); + buffer.flip(); + return buffer; + } + + /** + * Given a byte buffer that contains an assignment as defined by this protocol, return the + * deserialized form of the assignment. + * + * @param buffer the buffer containing a serialized assignment + * @return the deserialized assignment + * @throws SchemaException on incompatible Connect protocol version + */ + public static Assignment deserializeAssignment(ByteBuffer buffer) { + Struct header = CONNECT_PROTOCOL_HEADER_SCHEMA.read(buffer); + Short version = header.getShort(VERSION_KEY_NAME); + checkVersionCompatibility(version); + Struct struct = ASSIGNMENT_V0.read(buffer); + short error = struct.getShort(ERROR_KEY_NAME); + String leader = struct.getString(LEADER_KEY_NAME); + String leaderUrl = struct.getString(LEADER_URL_KEY_NAME); + long offset = struct.getLong(CONFIG_OFFSET_KEY_NAME); + List connectorIds = new ArrayList<>(); + List taskIds = new ArrayList<>(); + for (Object structObj : struct.getArray(ASSIGNMENT_KEY_NAME)) { + Struct assignment = (Struct) structObj; + String connector = assignment.getString(CONNECTOR_KEY_NAME); + for (Object taskIdObj : assignment.getArray(TASKS_KEY_NAME)) { + Integer taskId = (Integer) taskIdObj; + if (taskId == CONNECTOR_TASK) + connectorIds.add(connector); + else + taskIds.add(new ConnectorTaskId(connector, taskId)); + } + } + return new Assignment(error, leader, leaderUrl, offset, connectorIds, taskIds); + } + + /** + * A class that captures the deserialized form of a worker's metadata. + */ + public static class WorkerState { + private final String url; + private final long offset; + + public WorkerState(String url, long offset) { + this.url = url; + this.offset = offset; + } + + public String url() { + return url; + } + + /** + * The most up-to-date (maximum) configuration offset according known to this worker. + * + * @return the configuration offset + */ + public long offset() { + return offset; + } + + @Override + public String toString() { + return "WorkerState{" + + "url='" + url + '\'' + + ", offset=" + offset + + '}'; + } + } + + /** + * The basic assignment of connectors and tasks introduced with V0 version of the Connect protocol. + */ + public static class Assignment { + public static final short NO_ERROR = 0; + // Configuration offsets mismatched in a way that the leader could not resolve. 
Workers should read to the end + // of the config log and try to re-join + public static final short CONFIG_MISMATCH = 1; + + private final short error; + private final String leader; + private final String leaderUrl; + private final long offset; + private final Collection connectorIds; + private final Collection taskIds; + + /** + * Create an assignment indicating responsibility for the given connector instances and task Ids. + * + * @param error error code for this assignment; {@code ConnectProtocol.Assignment.NO_ERROR} + * indicates no error during assignment + * @param leader Connect group's leader Id; may be null only on the empty assignment + * @param leaderUrl Connect group's leader URL; may be null only on the empty assignment + * @param configOffset the most up-to-date configuration offset according to this assignment + * @param connectorIds list of connectors that the worker should instantiate and run; may not be null + * @param taskIds list of task IDs that the worker should instantiate and run; may not be null + */ + public Assignment(short error, String leader, String leaderUrl, long configOffset, + Collection connectorIds, Collection taskIds) { + this.error = error; + this.leader = leader; + this.leaderUrl = leaderUrl; + this.offset = configOffset; + this.connectorIds = Objects.requireNonNull(connectorIds, + "Assigned connector IDs may be empty but not null"); + this.taskIds = Objects.requireNonNull(taskIds, + "Assigned task IDs may be empty but not null"); + } + + /** + * Return the error code of this assignment; 0 signals successful assignment ({@code ConnectProtocol.Assignment.NO_ERROR}). + * + * @return the error code of the assignment + */ + public short error() { + return error; + } + + /** + * Return the ID of the leader Connect worker in this assignment. + * + * @return the ID of the leader + */ + public String leader() { + return leader; + } + + /** + * Return the URL to which the leader accepts requests from other members of the group. + * + * @return the leader URL + */ + public String leaderUrl() { + return leaderUrl; + } + + /** + * Check if this assignment failed. + * + * @return true if this assignment failed; false otherwise + */ + public boolean failed() { + return error != NO_ERROR; + } + + /** + * Return the most up-to-date offset in the configuration topic according to this assignment + * + * @return the configuration topic + */ + public long offset() { + return offset; + } + + /** + * The connectors included in this assignment. + * + * @return the connectors + */ + public Collection connectors() { + return connectorIds; + } + + /** + * The tasks included in this assignment. 
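The asMap() helper a little further below folds an assignment into a per-connector list of task indices, using the CONNECTOR_TASK sentinel (-1) to mark responsibility for running the Connector instance itself. A standalone sketch of that representation with made-up connector and task names:

import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

public class AssignmentMapSketch {
    static final int CONNECTOR_TASK = -1;

    public static void main(String[] args) {
        List<String> connectorIds = Arrays.asList("file-source");
        List<String> taskIds = Arrays.asList("file-source-0", "file-source-1", "other-sink-0");

        // LinkedHashMap keeps insertion order, which makes the output easier to read and test.
        Map<String, List<Integer>> taskMap = new LinkedHashMap<>();
        for (String connectorId : connectorIds)
            taskMap.computeIfAbsent(connectorId, key -> new ArrayList<>()).add(CONNECTOR_TASK);
        for (String taskId : taskIds) {
            int dash = taskId.lastIndexOf('-');
            String connectorId = taskId.substring(0, dash);
            int taskIndex = Integer.parseInt(taskId.substring(dash + 1));
            taskMap.computeIfAbsent(connectorId, key -> new ArrayList<>()).add(taskIndex);
        }
        System.out.println(taskMap); // {file-source=[-1, 0, 1], other-sink=[0]}
    }
}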
+ * + * @return the tasks + */ + public Collection tasks() { + return taskIds; + } + + @Override + public String toString() { + return "Assignment{" + + "error=" + error + + ", leader='" + leader + '\'' + + ", leaderUrl='" + leaderUrl + '\'' + + ", offset=" + offset + + ", connectorIds=" + connectorIds + + ", taskIds=" + taskIds + + '}'; + } + + protected Map> asMap() { + // Using LinkedHashMap preserves the ordering, which is helpful for tests and debugging + Map> taskMap = new LinkedHashMap<>(); + for (String connectorId : new HashSet<>(connectorIds)) { + taskMap.computeIfAbsent(connectorId, key -> new ArrayList<>()).add(CONNECTOR_TASK); + } + for (ConnectorTaskId taskId : taskIds) { + String connectorId = taskId.connector(); + taskMap.computeIfAbsent(connectorId, key -> new ArrayList<>()).add(taskId.task()); + } + return taskMap; + } + } + + private static void checkVersionCompatibility(short version) { + // check for invalid versions + if (version < CONNECT_PROTOCOL_V0) + throw new SchemaException("Unsupported subscription version: " + version); + + // otherwise, assume versions can be parsed as V0 + } +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/distributed/ConnectProtocolCompatibility.java b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/distributed/ConnectProtocolCompatibility.java new file mode 100644 index 0000000..d618fe2 --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/distributed/ConnectProtocolCompatibility.java @@ -0,0 +1,148 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.runtime.distributed; + +import java.util.Arrays; +import java.util.Locale; + +import static org.apache.kafka.connect.runtime.distributed.ConnectProtocol.CONNECT_PROTOCOL_V0; +import static org.apache.kafka.connect.runtime.distributed.IncrementalCooperativeConnectProtocol.CONNECT_PROTOCOL_V1; +import static org.apache.kafka.connect.runtime.distributed.IncrementalCooperativeConnectProtocol.CONNECT_PROTOCOL_V2; + +/** + * An enumeration of the modes available to the worker to signal which Connect protocols are + * enabled at any time. + * + * {@code EAGER} signifies that this worker only supports prompt release of assigned connectors + * and tasks in every rebalance. Corresponds to Connect protocol V0. + * + * {@code COMPATIBLE} signifies that this worker supports both eager and incremental cooperative + * Connect protocols and will use the version that is elected by the Kafka broker coordinator + * during rebalance. 
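With ConnectProtocol complete above, its metadata helpers can be exercised in isolation. A small sketch that round-trips a WorkerState through serializeMetadata() and deserializeMetadata(); the URL and offset are arbitrary:

import java.nio.ByteBuffer;

import org.apache.kafka.connect.runtime.distributed.ConnectProtocol;

public class ProtocolRoundTrip {
    public static void main(String[] args) {
        ConnectProtocol.WorkerState state = new ConnectProtocol.WorkerState("http://worker-1:8083", 42L);
        ByteBuffer wire = ConnectProtocol.serializeMetadata(state);   // protocol header (version) + Config State V0
        ConnectProtocol.WorkerState decoded = ConnectProtocol.deserializeMetadata(wire);
        System.out.println(decoded.url() + " @ config offset " + decoded.offset()); // http://worker-1:8083 @ config offset 42
    }
}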
+ * + * {@code SESSIONED} signifies that this worker supports all of the above protocols in addition to + * a protocol that uses incremental cooperative rebalancing for worker assignment and uses session + * keys distributed via the config topic to verify internal REST requests + */ +public enum ConnectProtocolCompatibility { + EAGER { + @Override + public String protocol() { + return "default"; + } + + @Override + public short protocolVersion() { + return CONNECT_PROTOCOL_V0; + } + }, + + COMPATIBLE { + @Override + public String protocol() { + return "compatible"; + } + + @Override + public short protocolVersion() { + return CONNECT_PROTOCOL_V1; + } + }, + + SESSIONED { + @Override + public String protocol() { + return "sessioned"; + } + + @Override + public short protocolVersion() { + return CONNECT_PROTOCOL_V2; + } + }; + + /** + * Return the enum that corresponds to the name that is given as an argument; + * if no mapping is found {@code IllegalArgumentException} is thrown. + * + * @param name the name of the protocol compatibility mode + * @return the enum that corresponds to the protocol compatibility mode + */ + public static ConnectProtocolCompatibility compatibility(String name) { + return Arrays.stream(ConnectProtocolCompatibility.values()) + .filter(mode -> mode.name().equalsIgnoreCase(name)) + .findFirst() + .orElseThrow(() -> new IllegalArgumentException( + "Unknown Connect protocol compatibility mode: " + name)); + } + + /** + * Return the enum that corresponds to the Connect protocol version that is given as an argument; + * if no mapping is found {@code IllegalArgumentException} is thrown. + * + * @param protocolVersion the version of the protocol; for example, + * {@link ConnectProtocol#CONNECT_PROTOCOL_V0 CONNECT_PROTOCOL_V0}. May not be null + * @return the enum that corresponds to the protocol compatibility mode + */ + public static ConnectProtocolCompatibility fromProtocolVersion(short protocolVersion) { + switch (protocolVersion) { + case CONNECT_PROTOCOL_V0: + return EAGER; + case CONNECT_PROTOCOL_V1: + return COMPATIBLE; + case CONNECT_PROTOCOL_V2: + return SESSIONED; + default: + throw new IllegalArgumentException("Unknown Connect protocol version: " + protocolVersion); + } + } + + @Override + public String toString() { + return name().toLowerCase(Locale.ROOT); + } + + /** + * Return the version of the protocol for this mode. + * + * @return the protocol version + */ + public abstract short protocolVersion(); + + /** + * Return the name of the protocol that this mode will use in {@code ProtocolMetadata}. + * + * @return the protocol name + */ + public abstract String protocol(); + + /** + * Return the enum that corresponds to the protocol name that is given as an argument; + * if no mapping is found {@code IllegalArgumentException} is thrown. 
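A short usage sketch of the lookups this enum provides, using only the methods defined above; the printed protocol name and version follow directly from those definitions:

import org.apache.kafka.connect.runtime.distributed.ConnectProtocolCompatibility;

public class CompatibilityLookup {
    public static void main(String[] args) {
        // "sessioned" is the default connect.protocol value; the lookup is case-insensitive.
        ConnectProtocolCompatibility mode = ConnectProtocolCompatibility.compatibility("sessioned");
        System.out.println(mode + " uses protocol '" + mode.protocol()
                + "' (version " + mode.protocolVersion() + ")");

        // Map a version negotiated during rebalance back to a compatibility mode.
        System.out.println(ConnectProtocolCompatibility.fromProtocolVersion((short) 0)); // eager
    }
}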
+ * + * @param protocolName the name of the connect protocol + * @return the enum that corresponds to the protocol compatibility mode that supports the + * given protocol + */ + public static ConnectProtocolCompatibility fromProtocol(String protocolName) { + return Arrays.stream(ConnectProtocolCompatibility.values()) + .filter(mode -> mode.protocol().equalsIgnoreCase(protocolName)) + .findFirst() + .orElseThrow(() -> new IllegalArgumentException( + "Not found Connect protocol compatibility mode for protocol: " + protocolName)); + } +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/distributed/DistributedConfig.java b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/distributed/DistributedConfig.java new file mode 100644 index 0000000..0823fbc --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/distributed/DistributedConfig.java @@ -0,0 +1,494 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.runtime.distributed; + +import org.apache.kafka.clients.CommonClientConfigs; +import org.apache.kafka.common.config.ConfigDef; +import org.apache.kafka.common.config.ConfigException; +import org.apache.kafka.common.config.TopicConfig; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.connect.runtime.WorkerConfig; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.crypto.KeyGenerator; +import javax.crypto.Mac; +import java.security.InvalidParameterException; +import java.security.NoSuchAlgorithmException; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.concurrent.TimeUnit; + +import static org.apache.kafka.common.config.ConfigDef.Range.atLeast; +import static org.apache.kafka.common.config.ConfigDef.Range.between; +import static org.apache.kafka.connect.runtime.TopicCreationConfig.PARTITIONS_VALIDATOR; +import static org.apache.kafka.connect.runtime.TopicCreationConfig.REPLICATION_FACTOR_VALIDATOR; + +public class DistributedConfig extends WorkerConfig { + + private static final Logger log = LoggerFactory.getLogger(DistributedConfig.class); + + /* + * NOTE: DO NOT CHANGE EITHER CONFIG STRINGS OR THEIR JAVA VARIABLE NAMES AS + * THESE ARE PART OF THE PUBLIC API AND CHANGE WILL BREAK USER CODE. 
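The definitions that follow are plain worker properties for distributed mode. A minimal sketch of such a configuration assembled in code; the broker address, topic names, and sizing values are illustrative only:

import java.util.HashMap;
import java.util.Map;

public class DistributedWorkerProps {
    public static void main(String[] args) {
        Map<String, String> props = new HashMap<>();
        props.put("bootstrap.servers", "localhost:9092");        // illustrative broker address
        props.put("group.id", "connect-cluster");                // identifies the Connect cluster group
        props.put("config.storage.topic", "connect-configs");
        props.put("offset.storage.topic", "connect-offsets");
        props.put("status.storage.topic", "connect-status");
        props.put("config.storage.replication.factor", "3");
        props.put("offset.storage.partitions", "25");
        props.put("connect.protocol", "sessioned");              // the default compatibility mode
        props.forEach((key, value) -> System.out.println(key + "=" + value));
    }
}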
+ */ + + /** + * group.id + */ + public static final String GROUP_ID_CONFIG = CommonClientConfigs.GROUP_ID_CONFIG; + private static final String GROUP_ID_DOC = "A unique string that identifies the Connect cluster group this worker belongs to."; + + /** + * session.timeout.ms + */ + public static final String SESSION_TIMEOUT_MS_CONFIG = CommonClientConfigs.SESSION_TIMEOUT_MS_CONFIG; + private static final String SESSION_TIMEOUT_MS_DOC = "The timeout used to detect worker failures. " + + "The worker sends periodic heartbeats to indicate its liveness to the broker. If no heartbeats are " + + "received by the broker before the expiration of this session timeout, then the broker will remove the " + + "worker from the group and initiate a rebalance. Note that the value must be in the allowable range as " + + "configured in the broker configuration by group.min.session.timeout.ms " + + "and group.max.session.timeout.ms."; + + /** + * heartbeat.interval.ms + */ + public static final String HEARTBEAT_INTERVAL_MS_CONFIG = CommonClientConfigs.HEARTBEAT_INTERVAL_MS_CONFIG; + private static final String HEARTBEAT_INTERVAL_MS_DOC = "The expected time between heartbeats to the group " + + "coordinator when using Kafka's group management facilities. Heartbeats are used to ensure that the " + + "worker's session stays active and to facilitate rebalancing when new members join or leave the group. " + + "The value must be set lower than session.timeout.ms, but typically should be set no higher " + + "than 1/3 of that value. It can be adjusted even lower to control the expected time for normal rebalances."; + + /** + * rebalance.timeout.ms + */ + public static final String REBALANCE_TIMEOUT_MS_CONFIG = CommonClientConfigs.REBALANCE_TIMEOUT_MS_CONFIG; + private static final String REBALANCE_TIMEOUT_MS_DOC = CommonClientConfigs.REBALANCE_TIMEOUT_MS_DOC; + + /** + * worker.sync.timeout.ms + */ + public static final String WORKER_SYNC_TIMEOUT_MS_CONFIG = "worker.sync.timeout.ms"; + private static final String WORKER_SYNC_TIMEOUT_MS_DOC = "When the worker is out of sync with other workers and needs" + + " to resynchronize configurations, wait up to this amount of time before giving up, leaving the group, and" + + " waiting a backoff period before rejoining."; + + /** + * group.unsync.timeout.ms + */ + public static final String WORKER_UNSYNC_BACKOFF_MS_CONFIG = "worker.unsync.backoff.ms"; + private static final String WORKER_UNSYNC_BACKOFF_MS_DOC = "When the worker is out of sync with other workers and " + + " fails to catch up within worker.sync.timeout.ms, leave the Connect cluster for this long before rejoining."; + public static final int WORKER_UNSYNC_BACKOFF_MS_DEFAULT = 5 * 60 * 1000; + + public static final String CONFIG_STORAGE_PREFIX = "config.storage."; + public static final String OFFSET_STORAGE_PREFIX = "offset.storage."; + public static final String STATUS_STORAGE_PREFIX = "status.storage."; + public static final String TOPIC_SUFFIX = "topic"; + public static final String PARTITIONS_SUFFIX = "partitions"; + public static final String REPLICATION_FACTOR_SUFFIX = "replication.factor"; + + /** + * offset.storage.topic + */ + public static final String OFFSET_STORAGE_TOPIC_CONFIG = OFFSET_STORAGE_PREFIX + TOPIC_SUFFIX; + private static final String OFFSET_STORAGE_TOPIC_CONFIG_DOC = "The name of the Kafka topic where connector offsets are stored"; + + /** + * offset.storage.partitions + */ + public static final String OFFSET_STORAGE_PARTITIONS_CONFIG = OFFSET_STORAGE_PREFIX + PARTITIONS_SUFFIX; + private 
static final String OFFSET_STORAGE_PARTITIONS_CONFIG_DOC = "The number of partitions used when creating the offset storage topic"; + + /** + * offset.storage.replication.factor + */ + public static final String OFFSET_STORAGE_REPLICATION_FACTOR_CONFIG = OFFSET_STORAGE_PREFIX + REPLICATION_FACTOR_SUFFIX; + private static final String OFFSET_STORAGE_REPLICATION_FACTOR_CONFIG_DOC = "Replication factor used when creating the offset storage topic"; + + /** + * config.storage.topic + */ + public static final String CONFIG_TOPIC_CONFIG = CONFIG_STORAGE_PREFIX + TOPIC_SUFFIX; + private static final String CONFIG_TOPIC_CONFIG_DOC = "The name of the Kafka topic where connector configurations are stored"; + + /** + * config.storage.replication.factor + */ + public static final String CONFIG_STORAGE_REPLICATION_FACTOR_CONFIG = CONFIG_STORAGE_PREFIX + REPLICATION_FACTOR_SUFFIX; + private static final String CONFIG_STORAGE_REPLICATION_FACTOR_CONFIG_DOC = "Replication factor used when creating the configuration storage topic"; + + /** + * status.storage.topic + */ + public static final String STATUS_STORAGE_TOPIC_CONFIG = STATUS_STORAGE_PREFIX + TOPIC_SUFFIX; + public static final String STATUS_STORAGE_TOPIC_CONFIG_DOC = "The name of the Kafka topic where connector and task status are stored"; + + /** + * status.storage.partitions + */ + public static final String STATUS_STORAGE_PARTITIONS_CONFIG = STATUS_STORAGE_PREFIX + PARTITIONS_SUFFIX; + private static final String STATUS_STORAGE_PARTITIONS_CONFIG_DOC = "The number of partitions used when creating the status storage topic"; + + /** + * status.storage.replication.factor + */ + public static final String STATUS_STORAGE_REPLICATION_FACTOR_CONFIG = STATUS_STORAGE_PREFIX + REPLICATION_FACTOR_SUFFIX; + private static final String STATUS_STORAGE_REPLICATION_FACTOR_CONFIG_DOC = "Replication factor used when creating the status storage topic"; + + /** + * connect.protocol + */ + public static final String CONNECT_PROTOCOL_CONFIG = "connect.protocol"; + public static final String CONNECT_PROTOCOL_DOC = "Compatibility mode for Kafka Connect Protocol"; + public static final String CONNECT_PROTOCOL_DEFAULT = ConnectProtocolCompatibility.SESSIONED.toString(); + + /** + * scheduled.rebalance.max.delay.ms + */ + public static final String SCHEDULED_REBALANCE_MAX_DELAY_MS_CONFIG = "scheduled.rebalance.max.delay.ms"; + public static final String SCHEDULED_REBALANCE_MAX_DELAY_MS_DOC = "The maximum delay that is " + + "scheduled in order to wait for the return of one or more departed workers before " + + "rebalancing and reassigning their connectors and tasks to the group. During this " + + "period the connectors and tasks of the departed workers remain unassigned"; + public static final int SCHEDULED_REBALANCE_MAX_DELAY_MS_DEFAULT = Math.toIntExact(TimeUnit.SECONDS.toMillis(300)); + + public static final String INTER_WORKER_KEY_GENERATION_ALGORITHM_CONFIG = "inter.worker.key.generation.algorithm"; + public static final String INTER_WORKER_KEY_GENERATION_ALGORITHM_DOC = "The algorithm to use for generating internal request keys"; + public static final String INTER_WORKER_KEY_GENERATION_ALGORITHM_DEFAULT = "HmacSHA256"; + + public static final String INTER_WORKER_KEY_SIZE_CONFIG = "inter.worker.key.size"; + public static final String INTER_WORKER_KEY_SIZE_DOC = "The size of the key to use for signing internal requests, in bits. 
" + + "If null, the default key size for the key generation algorithm will be used."; + public static final Long INTER_WORKER_KEY_SIZE_DEFAULT = null; + + public static final String INTER_WORKER_KEY_TTL_MS_CONFIG = "inter.worker.key.ttl.ms"; + public static final String INTER_WORKER_KEY_TTL_MS_MS_DOC = "The TTL of generated session keys used for " + + "internal request validation (in milliseconds)"; + public static final int INTER_WORKER_KEY_TTL_MS_MS_DEFAULT = Math.toIntExact(TimeUnit.HOURS.toMillis(1)); + + public static final String INTER_WORKER_SIGNATURE_ALGORITHM_CONFIG = "inter.worker.signature.algorithm"; + public static final String INTER_WORKER_SIGNATURE_ALGORITHM_DOC = "The algorithm used to sign internal requests"; + public static final String INTER_WORKER_SIGNATURE_ALGORITHM_DEFAULT = "HmacSHA256"; + + public static final String INTER_WORKER_VERIFICATION_ALGORITHMS_CONFIG = "inter.worker.verification.algorithms"; + public static final String INTER_WORKER_VERIFICATION_ALGORITHMS_DOC = "A list of permitted algorithms for verifying internal requests"; + public static final List INTER_WORKER_VERIFICATION_ALGORITHMS_DEFAULT = Collections.singletonList(INTER_WORKER_SIGNATURE_ALGORITHM_DEFAULT); + + @SuppressWarnings("unchecked") + private static final ConfigDef CONFIG = baseConfigDef() + .define(GROUP_ID_CONFIG, + ConfigDef.Type.STRING, + ConfigDef.Importance.HIGH, + GROUP_ID_DOC) + .define(SESSION_TIMEOUT_MS_CONFIG, + ConfigDef.Type.INT, + Math.toIntExact(TimeUnit.SECONDS.toMillis(10)), + ConfigDef.Importance.HIGH, + SESSION_TIMEOUT_MS_DOC) + .define(REBALANCE_TIMEOUT_MS_CONFIG, + ConfigDef.Type.INT, + Math.toIntExact(TimeUnit.MINUTES.toMillis(1)), + ConfigDef.Importance.HIGH, + REBALANCE_TIMEOUT_MS_DOC) + .define(HEARTBEAT_INTERVAL_MS_CONFIG, + ConfigDef.Type.INT, + Math.toIntExact(TimeUnit.SECONDS.toMillis(3)), + ConfigDef.Importance.HIGH, + HEARTBEAT_INTERVAL_MS_DOC) + .define(CommonClientConfigs.METADATA_MAX_AGE_CONFIG, + ConfigDef.Type.LONG, + TimeUnit.MINUTES.toMillis(5), + atLeast(0), + ConfigDef.Importance.LOW, + CommonClientConfigs.METADATA_MAX_AGE_DOC) + .define(CommonClientConfigs.CLIENT_ID_CONFIG, + ConfigDef.Type.STRING, + "", + ConfigDef.Importance.LOW, + CommonClientConfigs.CLIENT_ID_DOC) + .define(CommonClientConfigs.SEND_BUFFER_CONFIG, + ConfigDef.Type.INT, + 128 * 1024, + atLeast(0), + ConfigDef.Importance.MEDIUM, + CommonClientConfigs.SEND_BUFFER_DOC) + .define(CommonClientConfigs.RECEIVE_BUFFER_CONFIG, + ConfigDef.Type.INT, + 32 * 1024, + atLeast(0), + ConfigDef.Importance.MEDIUM, + CommonClientConfigs.RECEIVE_BUFFER_DOC) + .define(CommonClientConfigs.RECONNECT_BACKOFF_MS_CONFIG, + ConfigDef.Type.LONG, + 50L, + atLeast(0L), + ConfigDef.Importance.LOW, + CommonClientConfigs.RECONNECT_BACKOFF_MS_DOC) + .define(CommonClientConfigs.RECONNECT_BACKOFF_MAX_MS_CONFIG, + ConfigDef.Type.LONG, + TimeUnit.SECONDS.toMillis(1), + atLeast(0L), + ConfigDef.Importance.LOW, + CommonClientConfigs.RECONNECT_BACKOFF_MAX_MS_DOC) + .define(CommonClientConfigs.SOCKET_CONNECTION_SETUP_TIMEOUT_MS_CONFIG, + ConfigDef.Type.LONG, + CommonClientConfigs.DEFAULT_SOCKET_CONNECTION_SETUP_TIMEOUT_MS, + atLeast(0L), + ConfigDef.Importance.LOW, + CommonClientConfigs.SOCKET_CONNECTION_SETUP_TIMEOUT_MS_DOC) + .define(CommonClientConfigs.SOCKET_CONNECTION_SETUP_TIMEOUT_MAX_MS_CONFIG, + ConfigDef.Type.LONG, + CommonClientConfigs.DEFAULT_SOCKET_CONNECTION_SETUP_TIMEOUT_MAX_MS, + atLeast(0L), + ConfigDef.Importance.LOW, + CommonClientConfigs.SOCKET_CONNECTION_SETUP_TIMEOUT_MAX_MS_DOC) + 
.define(CommonClientConfigs.RETRY_BACKOFF_MS_CONFIG, + ConfigDef.Type.LONG, + 100L, + atLeast(0L), + ConfigDef.Importance.LOW, + CommonClientConfigs.RETRY_BACKOFF_MS_DOC) + .define(CommonClientConfigs.REQUEST_TIMEOUT_MS_CONFIG, + ConfigDef.Type.INT, + Math.toIntExact(TimeUnit.SECONDS.toMillis(40)), + atLeast(0), + ConfigDef.Importance.MEDIUM, + CommonClientConfigs.REQUEST_TIMEOUT_MS_DOC) + /* default is set to be a bit lower than the server default (10 min), to avoid both client and server closing connection at same time */ + .define(CommonClientConfigs.CONNECTIONS_MAX_IDLE_MS_CONFIG, + ConfigDef.Type.LONG, + TimeUnit.MINUTES.toMillis(9), + ConfigDef.Importance.MEDIUM, + CommonClientConfigs.CONNECTIONS_MAX_IDLE_MS_DOC) + // security support + .define(CommonClientConfigs.SECURITY_PROTOCOL_CONFIG, + ConfigDef.Type.STRING, + CommonClientConfigs.DEFAULT_SECURITY_PROTOCOL, + ConfigDef.Importance.MEDIUM, + CommonClientConfigs.SECURITY_PROTOCOL_DOC) + .withClientSaslSupport() + .define(WORKER_SYNC_TIMEOUT_MS_CONFIG, + ConfigDef.Type.INT, + 3000, + ConfigDef.Importance.MEDIUM, + WORKER_SYNC_TIMEOUT_MS_DOC) + .define(WORKER_UNSYNC_BACKOFF_MS_CONFIG, + ConfigDef.Type.INT, + WORKER_UNSYNC_BACKOFF_MS_DEFAULT, + ConfigDef.Importance.MEDIUM, + WORKER_UNSYNC_BACKOFF_MS_DOC) + .define(OFFSET_STORAGE_TOPIC_CONFIG, + ConfigDef.Type.STRING, + ConfigDef.Importance.HIGH, + OFFSET_STORAGE_TOPIC_CONFIG_DOC) + .define(OFFSET_STORAGE_PARTITIONS_CONFIG, + ConfigDef.Type.INT, + 25, + PARTITIONS_VALIDATOR, + ConfigDef.Importance.LOW, + OFFSET_STORAGE_PARTITIONS_CONFIG_DOC) + .define(OFFSET_STORAGE_REPLICATION_FACTOR_CONFIG, + ConfigDef.Type.SHORT, + (short) 3, + REPLICATION_FACTOR_VALIDATOR, + ConfigDef.Importance.LOW, + OFFSET_STORAGE_REPLICATION_FACTOR_CONFIG_DOC) + .define(CONFIG_TOPIC_CONFIG, + ConfigDef.Type.STRING, + ConfigDef.Importance.HIGH, + CONFIG_TOPIC_CONFIG_DOC) + .define(CONFIG_STORAGE_REPLICATION_FACTOR_CONFIG, + ConfigDef.Type.SHORT, + (short) 3, + REPLICATION_FACTOR_VALIDATOR, + ConfigDef.Importance.LOW, + CONFIG_STORAGE_REPLICATION_FACTOR_CONFIG_DOC) + .define(STATUS_STORAGE_TOPIC_CONFIG, + ConfigDef.Type.STRING, + ConfigDef.Importance.HIGH, + STATUS_STORAGE_TOPIC_CONFIG_DOC) + .define(STATUS_STORAGE_PARTITIONS_CONFIG, + ConfigDef.Type.INT, + 5, + PARTITIONS_VALIDATOR, + ConfigDef.Importance.LOW, + STATUS_STORAGE_PARTITIONS_CONFIG_DOC) + .define(STATUS_STORAGE_REPLICATION_FACTOR_CONFIG, + ConfigDef.Type.SHORT, + (short) 3, + REPLICATION_FACTOR_VALIDATOR, + ConfigDef.Importance.LOW, + STATUS_STORAGE_REPLICATION_FACTOR_CONFIG_DOC) + .define(CONNECT_PROTOCOL_CONFIG, + ConfigDef.Type.STRING, + CONNECT_PROTOCOL_DEFAULT, + ConfigDef.LambdaValidator.with( + (name, value) -> { + try { + ConnectProtocolCompatibility.compatibility((String) value); + } catch (Throwable t) { + throw new ConfigException(name, value, "Invalid Connect protocol " + + "compatibility"); + } + }, + () -> "[" + Utils.join(ConnectProtocolCompatibility.values(), ", ") + "]"), + ConfigDef.Importance.LOW, + CONNECT_PROTOCOL_DOC) + .define(SCHEDULED_REBALANCE_MAX_DELAY_MS_CONFIG, + ConfigDef.Type.INT, + SCHEDULED_REBALANCE_MAX_DELAY_MS_DEFAULT, + between(0, Integer.MAX_VALUE), + ConfigDef.Importance.LOW, + SCHEDULED_REBALANCE_MAX_DELAY_MS_DOC) + .define(INTER_WORKER_KEY_TTL_MS_CONFIG, + ConfigDef.Type.INT, + INTER_WORKER_KEY_TTL_MS_MS_DEFAULT, + between(0, Integer.MAX_VALUE), + ConfigDef.Importance.LOW, + INTER_WORKER_KEY_TTL_MS_MS_DOC) + .define(INTER_WORKER_KEY_GENERATION_ALGORITHM_CONFIG, + ConfigDef.Type.STRING, + 
INTER_WORKER_KEY_GENERATION_ALGORITHM_DEFAULT, + ConfigDef.LambdaValidator.with( + (name, value) -> validateKeyAlgorithm(name, (String) value), + () -> "Any KeyGenerator algorithm supported by the worker JVM" + ), + ConfigDef.Importance.LOW, + INTER_WORKER_KEY_GENERATION_ALGORITHM_DOC) + .define(INTER_WORKER_KEY_SIZE_CONFIG, + ConfigDef.Type.INT, + INTER_WORKER_KEY_SIZE_DEFAULT, + ConfigDef.Importance.LOW, + INTER_WORKER_KEY_SIZE_DOC) + .define(INTER_WORKER_SIGNATURE_ALGORITHM_CONFIG, + ConfigDef.Type.STRING, + INTER_WORKER_SIGNATURE_ALGORITHM_DEFAULT, + ConfigDef.LambdaValidator.with( + (name, value) -> validateSignatureAlgorithm(name, (String) value), + () -> "Any MAC algorithm supported by the worker JVM"), + ConfigDef.Importance.LOW, + INTER_WORKER_SIGNATURE_ALGORITHM_DOC) + .define(INTER_WORKER_VERIFICATION_ALGORITHMS_CONFIG, + ConfigDef.Type.LIST, + INTER_WORKER_VERIFICATION_ALGORITHMS_DEFAULT, + ConfigDef.LambdaValidator.with( + (name, value) -> validateSignatureAlgorithms(name, (List) value), + () -> "A list of one or more MAC algorithms, each supported by the worker JVM" + ), + ConfigDef.Importance.LOW, + INTER_WORKER_VERIFICATION_ALGORITHMS_DOC); + + @Override + public Integer getRebalanceTimeout() { + return getInt(DistributedConfig.REBALANCE_TIMEOUT_MS_CONFIG); + } + + public DistributedConfig(Map props) { + super(CONFIG, props); + getInternalRequestKeyGenerator(); // Check here for a valid key size + key algorithm to fail fast if either are invalid + validateKeyAlgorithmAndVerificationAlgorithms(); + } + + public static void main(String[] args) { + System.out.println(CONFIG.toHtml(4, config -> "connectconfigs_" + config)); + } + + public KeyGenerator getInternalRequestKeyGenerator() { + try { + KeyGenerator result = KeyGenerator.getInstance(getString(INTER_WORKER_KEY_GENERATION_ALGORITHM_CONFIG)); + Optional.ofNullable(getInt(INTER_WORKER_KEY_SIZE_CONFIG)).ifPresent(result::init); + return result; + } catch (NoSuchAlgorithmException | InvalidParameterException e) { + throw new ConfigException(String.format( + "Unable to create key generator with algorithm %s and key size %d: %s", + getString(INTER_WORKER_KEY_GENERATION_ALGORITHM_CONFIG), + getInt(INTER_WORKER_KEY_SIZE_CONFIG), + e.getMessage() + )); + } + } + + private Map topicSettings(String prefix) { + Map result = originalsWithPrefix(prefix); + if (CONFIG_STORAGE_PREFIX.equals(prefix) && result.containsKey(PARTITIONS_SUFFIX)) { + log.warn("Ignoring '{}{}={}' setting, since config topic partitions is always 1", prefix, PARTITIONS_SUFFIX, result.get("partitions")); + } + Object removedPolicy = result.remove(TopicConfig.CLEANUP_POLICY_CONFIG); + if (removedPolicy != null) { + log.warn("Ignoring '{}cleanup.policy={}' setting, since compaction is always used", prefix, removedPolicy); + } + result.remove(TOPIC_SUFFIX); + result.remove(REPLICATION_FACTOR_SUFFIX); + result.remove(PARTITIONS_SUFFIX); + return result; + } + + public Map configStorageTopicSettings() { + return topicSettings(CONFIG_STORAGE_PREFIX); + } + + public Map offsetStorageTopicSettings() { + return topicSettings(OFFSET_STORAGE_PREFIX); + } + + public Map statusStorageTopicSettings() { + return topicSettings(STATUS_STORAGE_PREFIX); + } + + private void validateKeyAlgorithmAndVerificationAlgorithms() { + String keyAlgorithm = getString(INTER_WORKER_KEY_GENERATION_ALGORITHM_CONFIG); + List verificationAlgorithms = getList(INTER_WORKER_VERIFICATION_ALGORITHMS_CONFIG); + if (!verificationAlgorithms.contains(keyAlgorithm)) { + throw new ConfigException( + 
INTER_WORKER_KEY_GENERATION_ALGORITHM_CONFIG, + keyAlgorithm, + String.format("Key generation algorithm must be present in %s list", INTER_WORKER_VERIFICATION_ALGORITHMS_CONFIG) + ); + } + } + + private static void validateSignatureAlgorithms(String configName, List algorithms) { + if (algorithms.isEmpty()) { + throw new ConfigException( + configName, + algorithms, + "At least one signature verification algorithm must be provided" + ); + } + algorithms.forEach(algorithm -> validateSignatureAlgorithm(configName, algorithm)); + } + + private static void validateSignatureAlgorithm(String configName, String algorithm) { + try { + Mac.getInstance(algorithm); + } catch (NoSuchAlgorithmException e) { + throw new ConfigException(configName, algorithm, e.getMessage()); + } + } + + private static void validateKeyAlgorithm(String configName, String algorithm) { + try { + KeyGenerator.getInstance(algorithm); + } catch (NoSuchAlgorithmException e) { + throw new ConfigException(configName, algorithm, e.getMessage()); + } + } +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/distributed/DistributedHerder.java b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/distributed/DistributedHerder.java new file mode 100644 index 0000000..5aa327e --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/distributed/DistributedHerder.java @@ -0,0 +1,2019 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.connect.runtime.distributed; + +import org.apache.kafka.clients.CommonClientConfigs; +import org.apache.kafka.common.config.ConfigDef; +import org.apache.kafka.common.config.ConfigValue; +import org.apache.kafka.common.errors.WakeupException; +import org.apache.kafka.common.metrics.Sensor; +import org.apache.kafka.common.metrics.stats.Avg; +import org.apache.kafka.common.metrics.stats.CumulativeSum; +import org.apache.kafka.common.metrics.stats.Max; +import org.apache.kafka.common.utils.Exit; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.common.utils.ThreadUtils; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.connect.connector.Connector; +import org.apache.kafka.connect.connector.policy.ConnectorClientConfigOverridePolicy; +import org.apache.kafka.connect.errors.AlreadyExistsException; +import org.apache.kafka.connect.errors.ConnectException; +import org.apache.kafka.connect.errors.NotFoundException; +import org.apache.kafka.connect.runtime.AbstractHerder; +import org.apache.kafka.connect.runtime.CloseableConnectorContext; +import org.apache.kafka.connect.runtime.ConnectMetrics; +import org.apache.kafka.connect.runtime.ConnectMetrics.MetricGroup; +import org.apache.kafka.connect.runtime.ConnectMetricsRegistry; +import org.apache.kafka.connect.runtime.ConnectorConfig; +import org.apache.kafka.connect.runtime.HerderConnectorContext; +import org.apache.kafka.connect.runtime.HerderRequest; +import org.apache.kafka.connect.runtime.RestartPlan; +import org.apache.kafka.connect.runtime.RestartRequest; +import org.apache.kafka.connect.runtime.SessionKey; +import org.apache.kafka.connect.runtime.SinkConnectorConfig; +import org.apache.kafka.connect.runtime.SourceConnectorConfig; +import org.apache.kafka.connect.runtime.TargetState; +import org.apache.kafka.connect.runtime.TaskStatus; +import org.apache.kafka.connect.runtime.Worker; +import org.apache.kafka.connect.runtime.rest.InternalRequestSignature; +import org.apache.kafka.connect.runtime.rest.RestClient; +import org.apache.kafka.connect.runtime.rest.entities.ConnectorInfo; +import org.apache.kafka.connect.runtime.rest.entities.ConnectorStateInfo; +import org.apache.kafka.connect.runtime.rest.entities.TaskInfo; +import org.apache.kafka.connect.runtime.rest.errors.BadRequestException; +import org.apache.kafka.connect.runtime.rest.errors.ConnectRestException; +import org.apache.kafka.connect.sink.SinkConnector; +import org.apache.kafka.connect.storage.ConfigBackingStore; +import org.apache.kafka.connect.storage.StatusBackingStore; +import org.apache.kafka.connect.util.Callback; +import org.apache.kafka.connect.util.ConnectorTaskId; +import org.apache.kafka.connect.util.SinkUtils; +import org.slf4j.Logger; + +import javax.crypto.KeyGenerator; +import javax.crypto.SecretKey; +import javax.ws.rs.core.Response; +import javax.ws.rs.core.UriBuilder; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.NavigableSet; +import java.util.NoSuchElementException; +import java.util.Objects; +import java.util.Optional; +import java.util.Set; +import java.util.concurrent.Callable; +import java.util.concurrent.ConcurrentSkipListSet; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import 
java.util.concurrent.LinkedBlockingDeque; +import java.util.concurrent.ThreadPoolExecutor; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.atomic.AtomicReference; +import java.util.stream.Collectors; + +import static org.apache.kafka.connect.runtime.WorkerConfig.TOPIC_TRACKING_ENABLE_CONFIG; +import static org.apache.kafka.connect.runtime.distributed.ConnectProtocol.CONNECT_PROTOCOL_V0; +import static org.apache.kafka.connect.runtime.distributed.ConnectProtocolCompatibility.EAGER; +import static org.apache.kafka.connect.runtime.distributed.IncrementalCooperativeConnectProtocol.CONNECT_PROTOCOL_V1; +import static org.apache.kafka.connect.runtime.distributed.IncrementalCooperativeConnectProtocol.CONNECT_PROTOCOL_V2; + +/** + *

            + * Distributed "herder" that coordinates with other workers to spread work across multiple processes. + *

            + *

            + * Under the hood, this is implemented as a group managed by Kafka's group membership facilities (i.e. the generalized + * group/consumer coordinator). Each instance of DistributedHerder joins the group and indicates what it's current + * configuration state is (where it is in the configuration log). The group coordinator selects one member to take + * this information and assign each instance a subset of the active connectors & tasks to execute. This assignment + * is currently performed in a simple round-robin fashion, but this is not guaranteed -- the herder may also choose + * to, e.g., use a sticky assignment to avoid the usual start/stop costs associated with connectors and tasks. Once + * an assignment is received, the DistributedHerder simply runs its assigned connectors and tasks in a Worker. + *

            + *

            + * In addition to distributing work, the DistributedHerder uses the leader determined during the work assignment + * to select a leader for this generation of the group who is responsible for other tasks that can only be performed + * by a single node at a time. Most importantly, this includes writing updated configurations for connectors and tasks, + * (and therefore, also for creating, destroy, and scaling up/down connectors). + *

            + *

            + * The DistributedHerder uses a single thread for most of its processing. This includes processing + * config changes, handling task rebalances and serving requests from the HTTP layer. The latter are pushed + * into a queue until the thread has time to handle them. A consequence of this is that requests can get blocked + * behind a worker rebalance. When the herder knows that a rebalance is expected, it typically returns an error + * immediately to the request, but this is not always possible (in particular when another worker has requested + * the rebalance). Similar to handling HTTP requests, config changes which are observed asynchronously by polling + * the config log are batched for handling in the work thread. + *

            + */ +public class DistributedHerder extends AbstractHerder implements Runnable { + private static final AtomicInteger CONNECT_CLIENT_ID_SEQUENCE = new AtomicInteger(1); + private final Logger log; + + private static final long FORWARD_REQUEST_SHUTDOWN_TIMEOUT_MS = TimeUnit.SECONDS.toMillis(10); + private static final long START_AND_STOP_SHUTDOWN_TIMEOUT_MS = TimeUnit.SECONDS.toMillis(1); + private static final long RECONFIGURE_CONNECTOR_TASKS_BACKOFF_MS = 250; + private static final int START_STOP_THREAD_POOL_SIZE = 8; + private static final short BACKOFF_RETRIES = 5; + + private final AtomicLong requestSeqNum = new AtomicLong(); + + private final Time time; + private final HerderMetrics herderMetrics; + private final List uponShutdown; + + private final String workerGroupId; + private final int workerSyncTimeoutMs; + private final long workerTasksShutdownTimeoutMs; + private final int workerUnsyncBackoffMs; + private final int keyRotationIntervalMs; + private final String requestSignatureAlgorithm; + private final List keySignatureVerificationAlgorithms; + private final KeyGenerator keyGenerator; + + private final ExecutorService herderExecutor; + private final ExecutorService forwardRequestExecutor; + private final ExecutorService startAndStopExecutor; + private final WorkerGroupMember member; + private final AtomicBoolean stopping; + private final boolean isTopicTrackingEnabled; + + // Track enough information about the current membership state to be able to determine which requests via the API + // and the from other nodes are safe to process + private boolean rebalanceResolved; + private ExtendedAssignment runningAssignment = ExtendedAssignment.empty(); + private Set tasksToRestart = new HashSet<>(); + // visible for testing + ExtendedAssignment assignment; + private boolean canReadConfigs; + // visible for testing + protected ClusterConfigState configState; + + // To handle most external requests, like creating or destroying a connector, we can use a generic request where + // the caller specifies all the code that should be executed. + final NavigableSet requests = new ConcurrentSkipListSet<>(); + // Config updates can be collected and applied together when possible. Also, we need to take care to rebalance when + // needed (e.g. task reconfiguration, which requires everyone to coordinate offset commits). + private Set connectorConfigUpdates = new HashSet<>(); + private Set taskConfigUpdates = new HashSet<>(); + // Similarly collect target state changes (when observed by the config storage listener) for handling in the + // herder's main thread. + private Set connectorTargetStateChanges = new HashSet<>(); + private boolean needsReconfigRebalance; + private volatile int generation; + private volatile long scheduledRebalance; + private volatile SecretKey sessionKey; + private volatile long keyExpiration; + private short currentProtocolVersion; + private short backoffRetries; + + // visible for testing + // The latest pending restart request for each named connector + final Map pendingRestartRequests = new HashMap<>(); + + private final DistributedConfig config; + + /** + * Create a herder that will form a Connect cluster with other {@link DistributedHerder} instances (in this or other JVMs) + * that have the same group ID. 
+ * + * @param config the configuration for the worker; may not be null + * @param time the clock to use; may not be null + * @param worker the {@link Worker} instance to use; may not be null + * @param kafkaClusterId the identifier of the Kafka cluster to use for internal topics; may not be null + * @param statusBackingStore the backing store for statuses; may not be null + * @param configBackingStore the backing store for connector configurations; may not be null + * @param restUrl the URL of this herder's REST API; may not be null + * @param connectorClientConfigOverridePolicy the policy specifying the client configuration properties that may be overridden + * in connector configurations; may not be null + * @param uponShutdown any {@link AutoCloseable} objects that should be closed when this herder is {@link #stop() stopped}, + * after all services and resources owned by this herder are stopped + */ + public DistributedHerder(DistributedConfig config, + Time time, + Worker worker, + String kafkaClusterId, + StatusBackingStore statusBackingStore, + ConfigBackingStore configBackingStore, + String restUrl, + ConnectorClientConfigOverridePolicy connectorClientConfigOverridePolicy, + AutoCloseable... uponShutdown) { + this(config, worker, worker.workerId(), kafkaClusterId, statusBackingStore, configBackingStore, null, restUrl, worker.metrics(), + time, connectorClientConfigOverridePolicy, uponShutdown); + configBackingStore.setUpdateListener(new ConfigUpdateListener()); + } + + // visible for testing + DistributedHerder(DistributedConfig config, + Worker worker, + String workerId, + String kafkaClusterId, + StatusBackingStore statusBackingStore, + ConfigBackingStore configBackingStore, + WorkerGroupMember member, + String restUrl, + ConnectMetrics metrics, + Time time, + ConnectorClientConfigOverridePolicy connectorClientConfigOverridePolicy, + AutoCloseable... uponShutdown) { + super(worker, workerId, kafkaClusterId, statusBackingStore, configBackingStore, connectorClientConfigOverridePolicy); + + this.time = time; + this.herderMetrics = new HerderMetrics(metrics); + this.workerGroupId = config.getString(DistributedConfig.GROUP_ID_CONFIG); + this.workerSyncTimeoutMs = config.getInt(DistributedConfig.WORKER_SYNC_TIMEOUT_MS_CONFIG); + this.workerTasksShutdownTimeoutMs = config.getLong(DistributedConfig.TASK_SHUTDOWN_GRACEFUL_TIMEOUT_MS_CONFIG); + this.workerUnsyncBackoffMs = config.getInt(DistributedConfig.WORKER_UNSYNC_BACKOFF_MS_CONFIG); + this.requestSignatureAlgorithm = config.getString(DistributedConfig.INTER_WORKER_SIGNATURE_ALGORITHM_CONFIG); + this.keyRotationIntervalMs = config.getInt(DistributedConfig.INTER_WORKER_KEY_TTL_MS_CONFIG); + this.keySignatureVerificationAlgorithms = config.getList(DistributedConfig.INTER_WORKER_VERIFICATION_ALGORITHMS_CONFIG); + this.keyGenerator = config.getInternalRequestKeyGenerator(); + this.isTopicTrackingEnabled = config.getBoolean(TOPIC_TRACKING_ENABLE_CONFIG); + this.uponShutdown = Arrays.asList(uponShutdown); + + String clientIdConfig = config.getString(CommonClientConfigs.CLIENT_ID_CONFIG); + String clientId = clientIdConfig.length() <= 0 ? "connect-" + CONNECT_CLIENT_ID_SEQUENCE.getAndIncrement() : clientIdConfig; + LogContext logContext = new LogContext("[Worker clientId=" + clientId + ", groupId=" + this.workerGroupId + "] "); + log = logContext.logger(DistributedHerder.class); + + this.member = member != null + ? 
member + : new WorkerGroupMember(config, restUrl, this.configBackingStore, + new RebalanceListener(time), time, clientId, logContext); + + this.herderExecutor = new ThreadPoolExecutor(1, 1, 0L, + TimeUnit.MILLISECONDS, + new LinkedBlockingDeque<>(1), + ThreadUtils.createThreadFactory( + this.getClass().getSimpleName() + "-" + clientId + "-%d", false)); + + this.forwardRequestExecutor = Executors.newFixedThreadPool(1, + ThreadUtils.createThreadFactory( + "ForwardRequestExecutor-" + clientId + "-%d", false)); + this.startAndStopExecutor = Executors.newFixedThreadPool(START_STOP_THREAD_POOL_SIZE, + ThreadUtils.createThreadFactory( + "StartAndStopExecutor-" + clientId + "-%d", false)); + this.config = config; + + stopping = new AtomicBoolean(false); + configState = ClusterConfigState.EMPTY; + rebalanceResolved = true; // If we still need to follow up after a rebalance occurred, starting up tasks + needsReconfigRebalance = false; + canReadConfigs = true; // We didn't try yet, but Configs are readable until proven otherwise + scheduledRebalance = Long.MAX_VALUE; + keyExpiration = Long.MAX_VALUE; + sessionKey = null; + backoffRetries = BACKOFF_RETRIES; + + currentProtocolVersion = ConnectProtocolCompatibility.compatibility( + config.getString(DistributedConfig.CONNECT_PROTOCOL_CONFIG) + ).protocolVersion(); + if (!internalRequestValidationEnabled(currentProtocolVersion)) { + log.warn( + "Internal request verification will be disabled for this cluster as this worker's {} configuration has been set to '{}'. " + + "If this is not intentional, either remove the '{}' configuration from the worker config file or change its value " + + "to '{}'. If this configuration is left as-is, the cluster will be insecure; for more information, see KIP-507: " + + "https://cwiki.apache.org/confluence/display/KAFKA/KIP-507%3A+Securing+Internal+Connect+REST+Endpoints", + DistributedConfig.CONNECT_PROTOCOL_CONFIG, + config.getString(DistributedConfig.CONNECT_PROTOCOL_CONFIG), + DistributedConfig.CONNECT_PROTOCOL_CONFIG, + ConnectProtocolCompatibility.SESSIONED.name() + ); + } + } + + @Override + public void start() { + this.herderExecutor.submit(this); + } + + @Override + public void run() { + try { + log.info("Herder starting"); + + startServices(); + + log.info("Herder started"); + running = true; + + while (!stopping.get()) { + tick(); + } + + halt(); + + log.info("Herder stopped"); + herderMetrics.close(); + } catch (Throwable t) { + log.error("Uncaught exception in herder work thread, exiting: ", t); + Exit.exit(1); + } finally { + running = false; + } + } + + // public for testing + public void tick() { + // The main loop does two primary things: 1) drive the group membership protocol, responding to rebalance events + // as they occur, and 2) handle external requests targeted at the leader. All the "real" work of the herder is + // performed in this thread, which keeps synchronization straightforward at the cost of some operations possibly + // blocking up this thread (especially those in callbacks due to rebalance events). 
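The comment above summarizes the herder's threading model, so a compact standalone sketch may help make it concrete: a single thread owns a time-ordered request queue, runs everything that is due, and then blocks for at most the delay until the next scheduled request. This is only an illustration of the pattern, not code from this patch; the names TickLoopSketch and ScheduledRequest are hypothetical, and group membership and config handling are omitted.

    import java.util.concurrent.ConcurrentSkipListSet;
    import java.util.concurrent.atomic.AtomicLong;

    public class TickLoopSketch {

        // A queued action plus the earliest time it may run; the sequence number breaks ties
        // so that two requests scheduled for the same millisecond are both retained.
        static final class ScheduledRequest implements Comparable<ScheduledRequest> {
            final long at;
            final long seq;
            final Runnable action;

            ScheduledRequest(long at, long seq, Runnable action) {
                this.at = at;
                this.seq = seq;
                this.action = action;
            }

            @Override
            public int compareTo(ScheduledRequest other) {
                int byTime = Long.compare(at, other.at);
                return byTime != 0 ? byTime : Long.compare(seq, other.seq);
            }
        }

        private final ConcurrentSkipListSet<ScheduledRequest> requests = new ConcurrentSkipListSet<>();
        private final AtomicLong sequence = new AtomicLong();
        private volatile boolean stopping = false;

        // Any thread may enqueue work; only the tick thread below ever executes it.
        public void addRequest(long delayMs, Runnable action) {
            requests.add(new ScheduledRequest(System.currentTimeMillis() + delayMs,
                    sequence.getAndIncrement(), action));
        }

        public void stop() {
            stopping = true;
        }

        public void run() throws InterruptedException {
            while (!stopping) {
                long now = System.currentTimeMillis();
                long nextTimeoutMs = 1000L; // upper bound on how long one tick may block
                ScheduledRequest next;
                while ((next = requests.pollFirst()) != null) {
                    if (now < next.at) {
                        requests.add(next); // not due yet: put it back and wait for it
                        nextTimeoutMs = Math.min(nextTimeoutMs, next.at - now);
                        break;
                    }
                    next.action.run(); // all "real" work happens on this single thread
                }
                // The real herder blocks in member.poll(nextRequestTimeoutMs) here so that config
                // updates and wakeups can interrupt the wait; this sketch simply sleeps.
                Thread.sleep(nextTimeoutMs);
            }
        }
    }

A sorted concurrent set keeps the requests ordered by scheduled time, which is the same reason the herder holds its DistributedHerderRequest instances in a ConcurrentSkipListSet.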
+ + try { + // if we failed to read to end of log before, we need to make sure the issue was resolved before joining group + // Joining and immediately leaving for failure to read configs is exceedingly impolite + if (!canReadConfigs) { + if (readConfigToEnd(workerSyncTimeoutMs)) { + canReadConfigs = true; + } else { + return; // Safe to return and tick immediately because readConfigToEnd will do the backoff for us + } + } + + log.debug("Ensuring group membership is still active"); + member.ensureActive(); + // Ensure we're in a good state in our group. If not restart and everything should be setup to rejoin + if (!handleRebalanceCompleted()) return; + } catch (WakeupException e) { + // May be due to a request from another thread, or might be stopping. If the latter, we need to check the + // flag immediately. If the former, we need to re-run the ensureActive call since we can't handle requests + // unless we're in the group. + log.trace("Woken up while ensure group membership is still active"); + return; + } + + long now = time.milliseconds(); + + if (checkForKeyRotation(now)) { + log.debug("Distributing new session key"); + keyExpiration = Long.MAX_VALUE; + try { + configBackingStore.putSessionKey(new SessionKey( + keyGenerator.generateKey(), + now + )); + } catch (Exception e) { + log.info("Failed to write new session key to config topic; forcing a read to the end of the config topic before possibly retrying"); + canReadConfigs = false; + return; + } + } + + // Process any external requests + // TODO: Some of these can be performed concurrently or even optimized away entirely. + // For example, if three different connectors are slated to be restarted, it's fine to + // restart all three at the same time instead. + // Another example: if multiple configurations are submitted for the same connector, + // the only one that actually has to be written to the config topic is the + // most-recently one. 
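The TODO above notes that redundant external requests can be coalesced, for example keeping only the most recent configuration submitted for a connector; the pendingRestartRequests map declared earlier ("the latest pending restart request for each named connector") already applies that idea to restarts. Below is a minimal latest-wins sketch of the pattern; RequestCoalescer and its methods are hypothetical names, not part of this patch.

    import java.util.ArrayList;
    import java.util.HashMap;
    import java.util.List;
    import java.util.Map;

    public class RequestCoalescer<R> {

        // Latest pending request per connector; newer submissions replace older ones.
        private final Map<String, R> pending = new HashMap<>();

        public synchronized void submit(String connectorName, R request) {
            pending.put(connectorName, request);
        }

        // Drain into a local list so the (potentially slow) handling happens outside the lock.
        public List<R> drain() {
            List<R> snapshot;
            synchronized (this) {
                if (pending.isEmpty()) {
                    return new ArrayList<>();
                }
                snapshot = new ArrayList<>(pending.values());
                pending.clear();
            }
            return snapshot;
        }
    }

Copying and clearing the map under the lock and then processing the copies mirrors what processRestartRequests() does with pendingRestartRequests further down in this class.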
+ long nextRequestTimeoutMs = Long.MAX_VALUE; + while (true) { + final DistributedHerderRequest next = peekWithoutException(); + if (next == null) { + break; + } else if (now >= next.at) { + requests.pollFirst(); + } else { + nextRequestTimeoutMs = next.at - now; + break; + } + + try { + next.action().call(); + next.callback().onCompletion(null, null); + } catch (Throwable t) { + next.callback().onCompletion(t, null); + } + } + + // Process all pending connector restart requests + processRestartRequests(); + + if (scheduledRebalance < Long.MAX_VALUE) { + nextRequestTimeoutMs = Math.min(nextRequestTimeoutMs, Math.max(scheduledRebalance - now, 0)); + rebalanceResolved = false; + log.debug("Scheduled rebalance at: {} (now: {} nextRequestTimeoutMs: {}) ", + scheduledRebalance, now, nextRequestTimeoutMs); + } + if (isLeader() && internalRequestValidationEnabled() && keyExpiration < Long.MAX_VALUE) { + nextRequestTimeoutMs = Math.min(nextRequestTimeoutMs, Math.max(keyExpiration - now, 0)); + log.debug("Scheduled next key rotation at: {} (now: {} nextRequestTimeoutMs: {}) ", + keyExpiration, now, nextRequestTimeoutMs); + } + + // Process any configuration updates + AtomicReference> connectorConfigUpdatesCopy = new AtomicReference<>(); + AtomicReference> connectorTargetStateChangesCopy = new AtomicReference<>(); + AtomicReference> taskConfigUpdatesCopy = new AtomicReference<>(); + + boolean shouldReturn; + if (member.currentProtocolVersion() == CONNECT_PROTOCOL_V0) { + shouldReturn = updateConfigsWithEager(connectorConfigUpdatesCopy, + connectorTargetStateChangesCopy); + // With eager protocol we should return immediately if needsReconfigRebalance has + // been set to retain the old workflow + if (shouldReturn) { + return; + } + + if (connectorConfigUpdatesCopy.get() != null) { + processConnectorConfigUpdates(connectorConfigUpdatesCopy.get()); + } + + if (connectorTargetStateChangesCopy.get() != null) { + processTargetStateChanges(connectorTargetStateChangesCopy.get()); + } + } else { + shouldReturn = updateConfigsWithIncrementalCooperative(connectorConfigUpdatesCopy, + connectorTargetStateChangesCopy, taskConfigUpdatesCopy); + + if (connectorConfigUpdatesCopy.get() != null) { + processConnectorConfigUpdates(connectorConfigUpdatesCopy.get()); + } + + if (connectorTargetStateChangesCopy.get() != null) { + processTargetStateChanges(connectorTargetStateChangesCopy.get()); + } + + if (taskConfigUpdatesCopy.get() != null) { + processTaskConfigUpdatesWithIncrementalCooperative(taskConfigUpdatesCopy.get()); + } + + if (shouldReturn) { + return; + } + } + + // Let the group take any actions it needs to + try { + log.trace("Polling for group activity; will wait for {}ms or until poll is interrupted by " + + "either config backing store updates or a new external request", + nextRequestTimeoutMs); + member.poll(nextRequestTimeoutMs); + // Ensure we're in a good state in our group. If not restart and everything should be setup to rejoin + handleRebalanceCompleted(); + } catch (WakeupException e) { // FIXME should not be WakeupException + log.trace("Woken up while polling for group activity"); + // Ignore. Just indicates we need to check the exit flag, for requested actions, etc. 
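The WakeupException caught above is the same interruption mechanism the public consumer API exposes: another thread calls wakeup(), the blocked poll throws, and the polling thread re-checks its stop flag and pending work before blocking again. The sketch below shows that pattern with a plain KafkaConsumer; WakeupSketch is a hypothetical class, and the broker address, group id, and topic name are placeholders.

    import org.apache.kafka.clients.consumer.ConsumerConfig;
    import org.apache.kafka.clients.consumer.ConsumerRecords;
    import org.apache.kafka.clients.consumer.KafkaConsumer;
    import org.apache.kafka.common.errors.WakeupException;
    import org.apache.kafka.common.serialization.ByteArrayDeserializer;

    import java.time.Duration;
    import java.util.Collections;
    import java.util.Properties;
    import java.util.concurrent.atomic.AtomicBoolean;

    public class WakeupSketch {

        private final AtomicBoolean stopping = new AtomicBoolean(false);
        private final KafkaConsumer<byte[], byte[]> consumer;

        public WakeupSketch() {
            Properties props = new Properties();
            props.put(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9092");   // placeholder
            props.put(ConsumerConfig.GROUP_ID_CONFIG, "wakeup-sketch");             // placeholder
            props.put(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, ByteArrayDeserializer.class.getName());
            props.put(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, ByteArrayDeserializer.class.getName());
            consumer = new KafkaConsumer<>(props);
        }

        // Runs on the polling thread, loosely analogous to the herder's loop around member.poll().
        public void pollLoop() {
            consumer.subscribe(Collections.singletonList("wakeup-sketch-topic"));   // placeholder
            try {
                while (!stopping.get()) {
                    try {
                        ConsumerRecords<byte[], byte[]> records = consumer.poll(Duration.ofSeconds(1));
                        // ... handle records and any queued requests here ...
                    } catch (WakeupException e) {
                        // Expected: another thread called wakeup(); loop around and re-check the flag.
                    }
                }
            } finally {
                consumer.close();
            }
        }

        // May be called from any thread, like member.wakeup() in DistributedHerder.stop().
        public void shutdown() {
            stopping.set(true);
            consumer.wakeup();
        }
    }

This is the behavior DistributedHerder.stop() relies on when it sets the stopping flag and then calls member.wakeup() to break the worker out of its current poll.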
+ } + } + + private boolean checkForKeyRotation(long now) { + SecretKey key; + long expiration; + synchronized (this) { + key = sessionKey; + expiration = keyExpiration; + } + + if (internalRequestValidationEnabled()) { + if (isLeader()) { + if (key == null) { + log.debug("Internal request signing is enabled but no session key has been distributed yet. " + + "Distributing new key now."); + return true; + } else if (expiration <= now) { + log.debug("Existing key has expired. Distributing new key now."); + return true; + } else if (!key.getAlgorithm().equals(keyGenerator.getAlgorithm()) + || key.getEncoded().length != keyGenerator.generateKey().getEncoded().length) { + log.debug("Previously-distributed key uses different algorithm/key size " + + "than required by current worker configuration. Distributing new key now."); + return true; + } + } else if (key == null && configState.sessionKey() != null) { + // This happens on startup for follower workers; the snapshot contains the session key, + // but no callback in the config update listener has been fired for it yet. + sessionKey = configState.sessionKey().key(); + } + } + return false; + } + + private synchronized boolean updateConfigsWithEager(AtomicReference> connectorConfigUpdatesCopy, + AtomicReference> connectorTargetStateChangesCopy) { + // This branch is here to avoid creating a snapshot if not needed + if (needsReconfigRebalance + || !connectorConfigUpdates.isEmpty() + || !connectorTargetStateChanges.isEmpty()) { + log.trace("Handling config updates with eager rebalancing"); + // Connector reconfigs only need local updates since there is no coordination between workers required. + // However, if connectors were added or removed, work needs to be rebalanced since we have more work + // items to distribute among workers. + configState = configBackingStore.snapshot(); + + if (needsReconfigRebalance) { + // Task reconfigs require a rebalance. Request the rebalance, clean out state, and then restart + // this loop, which will then ensure the rebalance occurs without any other requests being + // processed until it completes. + log.debug("Requesting rebalance due to reconfiguration of tasks (needsReconfigRebalance: {})", + needsReconfigRebalance); + member.requestRejoin(); + needsReconfigRebalance = false; + // Any connector config updates or target state changes will be addressed during the rebalance too + connectorConfigUpdates.clear(); + connectorTargetStateChanges.clear(); + return true; + } else { + if (!connectorConfigUpdates.isEmpty()) { + // We can't start/stop while locked since starting connectors can cause task updates that will + // require writing configs, which in turn make callbacks into this class from another thread that + // require acquiring a lock. This leads to deadlock. Instead, just copy the info we need and process + // the updates after unlocking. 
+ connectorConfigUpdatesCopy.set(connectorConfigUpdates); + connectorConfigUpdates = new HashSet<>(); + } + + if (!connectorTargetStateChanges.isEmpty()) { + // Similarly for target state changes which can cause connectors to be restarted + connectorTargetStateChangesCopy.set(connectorTargetStateChanges); + connectorTargetStateChanges = new HashSet<>(); + } + } + } else { + log.trace("Skipping config updates with eager rebalancing " + + "since no config rebalance is required " + + "and there are no connector config, task config, or target state changes pending"); + } + return false; + } + + private synchronized boolean updateConfigsWithIncrementalCooperative(AtomicReference> connectorConfigUpdatesCopy, + AtomicReference> connectorTargetStateChangesCopy, + AtomicReference> taskConfigUpdatesCopy) { + boolean retValue = false; + // This branch is here to avoid creating a snapshot if not needed + if (needsReconfigRebalance + || !connectorConfigUpdates.isEmpty() + || !connectorTargetStateChanges.isEmpty() + || !taskConfigUpdates.isEmpty()) { + log.trace("Handling config updates with incremental cooperative rebalancing"); + // Connector reconfigs only need local updates since there is no coordination between workers required. + // However, if connectors were added or removed, work needs to be rebalanced since we have more work + // items to distribute among workers. + configState = configBackingStore.snapshot(); + + if (needsReconfigRebalance) { + log.debug("Requesting rebalance due to reconfiguration of tasks (needsReconfigRebalance: {})", + needsReconfigRebalance); + member.requestRejoin(); + needsReconfigRebalance = false; + retValue = true; + } + + if (!connectorConfigUpdates.isEmpty()) { + // We can't start/stop while locked since starting connectors can cause task updates that will + // require writing configs, which in turn make callbacks into this class from another thread that + // require acquiring a lock. This leads to deadlock. Instead, just copy the info we need and process + // the updates after unlocking. + connectorConfigUpdatesCopy.set(connectorConfigUpdates); + connectorConfigUpdates = new HashSet<>(); + } + + if (!connectorTargetStateChanges.isEmpty()) { + // Similarly for target state changes which can cause connectors to be restarted + connectorTargetStateChangesCopy.set(connectorTargetStateChanges); + connectorTargetStateChanges = new HashSet<>(); + } + + if (!taskConfigUpdates.isEmpty()) { + // Similarly for task config updates + taskConfigUpdatesCopy.set(taskConfigUpdates); + taskConfigUpdates = new HashSet<>(); + } + } else { + log.trace("Skipping config updates with incremental cooperative rebalancing " + + "since no config rebalance is required " + + "and there are no connector config, task config, or target state changes pending"); + } + return retValue; + } + + private void processConnectorConfigUpdates(Set connectorConfigUpdates) { + // If we only have connector config updates, we can just bounce the updated connectors that are + // currently assigned to this worker. + Set localConnectors = assignment == null ? 
Collections.emptySet() : new HashSet<>(assignment.connectors()); + log.trace("Processing connector config updates; " + + "currently-owned connectors are {}, and to-be-updated connectors are {}", + localConnectors, + connectorConfigUpdates); + for (String connectorName : connectorConfigUpdates) { + if (!localConnectors.contains(connectorName)) { + log.trace("Skipping config update for connector {} as it is not owned by this worker", + connectorName); + continue; + } + boolean remains = configState.contains(connectorName); + log.info("Handling connector-only config update by {} connector {}", + remains ? "restarting" : "stopping", connectorName); + worker.stopAndAwaitConnector(connectorName); + // The update may be a deletion, so verify we actually need to restart the connector + if (remains) { + startConnector(connectorName, (error, result) -> { + if (error != null) { + log.error("Failed to start connector '" + connectorName + "'", error); + } + }); + } + } + } + + private void processTargetStateChanges(Set connectorTargetStateChanges) { + log.trace("Processing target state updates; " + + "currently-known connectors are {}, and to-be-updated connectors are {}", + configState.connectors(), connectorTargetStateChanges); + for (String connector : connectorTargetStateChanges) { + TargetState targetState = configState.targetState(connector); + if (!configState.connectors().contains(connector)) { + log.debug("Received target state change for unknown connector: {}", connector); + continue; + } + + // we must propagate the state change to the worker so that the connector's + // tasks can transition to the new target state + worker.setTargetState(connector, targetState, (error, newState) -> { + if (error != null) { + log.error("Failed to transition connector to target state", error); + return; + } + // additionally, if the worker is running the connector itself, then we need to + // request reconfiguration to ensure that config changes while paused take effect + if (newState == TargetState.STARTED) { + requestTaskReconfiguration(connector); + } + }); + } + } + + private void processTaskConfigUpdatesWithIncrementalCooperative(Set taskConfigUpdates) { + Set localTasks = assignment == null + ? Collections.emptySet() + : new HashSet<>(assignment.tasks()); + log.trace("Processing task config updates with incremental cooperative rebalance protocol; " + + "currently-owned tasks are {}, and to-be-updated tasks are {}", + localTasks, taskConfigUpdates); + Set connectorsWhoseTasksToStop = taskConfigUpdates.stream() + .map(ConnectorTaskId::connector).collect(Collectors.toSet()); + + List tasksToStop = localTasks.stream() + .filter(taskId -> connectorsWhoseTasksToStop.contains(taskId.connector())) + .collect(Collectors.toList()); + log.info("Handling task config update by restarting tasks {}", tasksToStop); + worker.stopAndAwaitTasks(tasksToStop); + tasksToRestart.addAll(tasksToStop); + } + + // public for testing + public void halt() { + synchronized (this) { + // Clean up any connectors and tasks that are still running. 
+ log.info("Stopping connectors and tasks that are still assigned to this worker."); + List> callables = new ArrayList<>(); + for (String connectorName : new ArrayList<>(worker.connectorNames())) { + callables.add(getConnectorStoppingCallable(connectorName)); + } + for (ConnectorTaskId taskId : new ArrayList<>(worker.taskIds())) { + callables.add(getTaskStoppingCallable(taskId)); + } + startAndStop(callables); + + member.stop(); + + // Explicitly fail any outstanding requests so they actually get a response and get an + // understandable reason for their failure. + DistributedHerderRequest request = requests.pollFirst(); + while (request != null) { + request.callback().onCompletion(new ConnectException("Worker is shutting down"), null); + request = requests.pollFirst(); + } + + stopServices(); + } + } + + @Override + protected void stopServices() { + try { + super.stopServices(); + } finally { + this.uponShutdown.forEach(closeable -> Utils.closeQuietly(closeable, closeable != null ? closeable.toString() : "")); + } + } + + @Override + public void stop() { + log.info("Herder stopping"); + + stopping.set(true); + member.wakeup(); + herderExecutor.shutdown(); + try { + if (!herderExecutor.awaitTermination(workerTasksShutdownTimeoutMs, TimeUnit.MILLISECONDS)) + herderExecutor.shutdownNow(); + + forwardRequestExecutor.shutdown(); + startAndStopExecutor.shutdown(); + + if (!forwardRequestExecutor.awaitTermination(FORWARD_REQUEST_SHUTDOWN_TIMEOUT_MS, TimeUnit.MILLISECONDS)) + forwardRequestExecutor.shutdownNow(); + if (!startAndStopExecutor.awaitTermination(START_AND_STOP_SHUTDOWN_TIMEOUT_MS, TimeUnit.MILLISECONDS)) + startAndStopExecutor.shutdownNow(); + } catch (InterruptedException e) { + // ignore + } + + log.info("Herder stopped"); + running = false; + } + + @Override + public void connectors(final Callback> callback) { + log.trace("Submitting connector listing request"); + + addRequest( + () -> { + if (!checkRebalanceNeeded(callback)) + callback.onCompletion(null, configState.connectors()); + return null; + }, + forwardErrorCallback(callback) + ); + } + + @Override + public void connectorInfo(final String connName, final Callback callback) { + log.trace("Submitting connector info request {}", connName); + + addRequest( + () -> { + if (checkRebalanceNeeded(callback)) + return null; + + if (!configState.contains(connName)) { + callback.onCompletion( + new NotFoundException("Connector " + connName + " not found"), null); + } else { + callback.onCompletion(null, connectorInfo(connName)); + } + return null; + }, + forwardErrorCallback(callback) + ); + } + + @Override + public void tasksConfig(String connName, final Callback>> callback) { + log.trace("Submitting tasks config request {}", connName); + + addRequest( + () -> { + if (checkRebalanceNeeded(callback)) + return null; + + if (!configState.contains(connName)) { + callback.onCompletion(new NotFoundException("Connector " + connName + " not found"), null); + } else { + callback.onCompletion(null, buildTasksConfig(connName)); + } + return null; + }, + forwardErrorCallback(callback) + ); + } + + @Override + protected Map rawConfig(String connName) { + return configState.rawConnectorConfig(connName); + } + + @Override + public void connectorConfig(String connName, final Callback> callback) { + log.trace("Submitting connector config read request {}", connName); + super.connectorConfig(connName, callback); + } + + @Override + public void deleteConnectorConfig(final String connName, final Callback> callback) { + addRequest( + () -> { + 
log.trace("Handling connector config request {}", connName); + if (!isLeader()) { + callback.onCompletion(new NotLeaderException("Only the leader can delete connector configs.", leaderUrl()), null); + return null; + } + + if (!configState.contains(connName)) { + callback.onCompletion(new NotFoundException("Connector " + connName + " not found"), null); + } else { + log.trace("Removing connector config {} {}", connName, configState.connectors()); + configBackingStore.removeConnectorConfig(connName); + callback.onCompletion(null, new Created<>(false, null)); + } + return null; + }, + forwardErrorCallback(callback) + ); + } + + @Override + protected Map validateBasicConnectorConfig(Connector connector, + ConfigDef configDef, + Map config) { + Map validatedConfig = super.validateBasicConnectorConfig(connector, configDef, config); + if (connector instanceof SinkConnector) { + ConfigValue validatedName = validatedConfig.get(ConnectorConfig.NAME_CONFIG); + String name = (String) validatedName.value(); + if (workerGroupId.equals(SinkUtils.consumerGroupId(name))) { + validatedName.addErrorMessage("Consumer group for sink connector named " + name + + " conflicts with Connect worker group " + workerGroupId); + } + } + return validatedConfig; + } + + + @Override + public void putConnectorConfig(final String connName, final Map config, final boolean allowReplace, + final Callback> callback) { + log.trace("Submitting connector config write request {}", connName); + addRequest( + () -> { + validateConnectorConfig(config, (error, configInfos) -> { + if (error != null) { + callback.onCompletion(error, null); + return; + } + + // Complete the connector config write via another herder request in order to + // perform the write to the backing store (or forward to the leader) during + // the "external request" portion of the tick loop + addRequest( + () -> { + if (maybeAddConfigErrors(configInfos, callback)) { + return null; + } + + log.trace("Handling connector config request {}", connName); + if (!isLeader()) { + callback.onCompletion(new NotLeaderException("Only the leader can set connector configs.", leaderUrl()), null); + return null; + } + boolean exists = configState.contains(connName); + if (!allowReplace && exists) { + callback.onCompletion(new AlreadyExistsException("Connector " + connName + " already exists"), null); + return null; + } + + log.trace("Submitting connector config {} {} {}", connName, allowReplace, configState.connectors()); + configBackingStore.putConnectorConfig(connName, config); + + // Note that we use the updated connector config despite the fact that we don't have an updated + // snapshot yet. The existing task info should still be accurate. 
+ ConnectorInfo info = new ConnectorInfo(connName, config, configState.tasks(connName), + // validateConnectorConfig have checked the existence of CONNECTOR_CLASS_CONFIG + connectorTypeForClass(config.get(ConnectorConfig.CONNECTOR_CLASS_CONFIG))); + callback.onCompletion(null, new Created<>(!exists, info)); + return null; + }, + forwardErrorCallback(callback) + ); + }); + return null; + }, + forwardErrorCallback(callback) + ); + } + + @Override + public void requestTaskReconfiguration(final String connName) { + log.trace("Submitting connector task reconfiguration request {}", connName); + + addRequest( + () -> { + reconfigureConnectorTasksWithRetry(time.milliseconds(), connName); + return null; + }, + (error, result) -> { + if (error != null) { + log.error("Unexpected error during task reconfiguration: ", error); + log.error("Task reconfiguration for {} failed unexpectedly, this connector will not be properly reconfigured unless manually triggered.", connName); + } + } + ); + } + + @Override + public void taskConfigs(final String connName, final Callback> callback) { + log.trace("Submitting get task configuration request {}", connName); + + addRequest( + () -> { + if (checkRebalanceNeeded(callback)) + return null; + + if (!configState.contains(connName)) { + callback.onCompletion(new NotFoundException("Connector " + connName + " not found"), null); + } else { + List result = new ArrayList<>(); + for (int i = 0; i < configState.taskCount(connName); i++) { + ConnectorTaskId id = new ConnectorTaskId(connName, i); + result.add(new TaskInfo(id, configState.rawTaskConfig(id))); + } + callback.onCompletion(null, result); + } + return null; + }, + forwardErrorCallback(callback) + ); + } + + @Override + public void putTaskConfigs(final String connName, final List> configs, final Callback callback, InternalRequestSignature requestSignature) { + log.trace("Submitting put task configuration request {}", connName); + if (internalRequestValidationEnabled()) { + ConnectRestException requestValidationError = null; + if (requestSignature == null) { + requestValidationError = new BadRequestException("Internal request missing required signature"); + } else if (!keySignatureVerificationAlgorithms.contains(requestSignature.keyAlgorithm())) { + requestValidationError = new BadRequestException(String.format( + "This worker does not support the '%s' key signing algorithm used by other workers. " + + "This worker is currently configured to use: %s. " + + "Check that all workers' configuration files permit the same set of signature algorithms, " + + "and correct any misconfigured worker and restart it.", + requestSignature.keyAlgorithm(), + keySignatureVerificationAlgorithms + )); + } else { + if (!requestSignature.isValid(sessionKey)) { + requestValidationError = new ConnectRestException( + Response.Status.FORBIDDEN, + "Internal request contained invalid signature." 
+ ); + } + } + if (requestValidationError != null) { + callback.onCompletion(requestValidationError, null); + return; + } + } + + addRequest( + () -> { + if (!isLeader()) + callback.onCompletion(new NotLeaderException("Only the leader may write task configurations.", leaderUrl()), null); + else if (!configState.contains(connName)) + callback.onCompletion(new NotFoundException("Connector " + connName + " not found"), null); + else { + configBackingStore.putTaskConfigs(connName, configs); + callback.onCompletion(null, null); + } + return null; + }, + forwardErrorCallback(callback) + ); + } + + @Override + public void restartConnector(final String connName, final Callback callback) { + restartConnector(0, connName, callback); + } + + @Override + public HerderRequest restartConnector(final long delayMs, final String connName, final Callback callback) { + return addRequest( + delayMs, + () -> { + if (checkRebalanceNeeded(callback)) + return null; + + if (!configState.connectors().contains(connName)) { + callback.onCompletion(new NotFoundException("Unknown connector: " + connName), null); + return null; + } + + if (assignment.connectors().contains(connName)) { + try { + worker.stopAndAwaitConnector(connName); + startConnector(connName, callback); + } catch (Throwable t) { + callback.onCompletion(t, null); + } + } else if (isLeader()) { + callback.onCompletion(new NotAssignedException("Cannot restart connector since it is not assigned to this member", member.ownerUrl(connName)), null); + } else { + callback.onCompletion(new NotLeaderException("Only the leader can process restart requests.", leaderUrl()), null); + } + return null; + }, + forwardErrorCallback(callback)); + } + + @Override + public void restartTask(final ConnectorTaskId id, final Callback callback) { + addRequest( + () -> { + if (checkRebalanceNeeded(callback)) + return null; + + if (!configState.connectors().contains(id.connector())) { + callback.onCompletion(new NotFoundException("Unknown connector: " + id.connector()), null); + return null; + } + + if (configState.taskConfig(id) == null) { + callback.onCompletion(new NotFoundException("Unknown task: " + id), null); + return null; + } + + if (assignment.tasks().contains(id)) { + try { + worker.stopAndAwaitTask(id); + if (startTask(id)) + callback.onCompletion(null, null); + else + callback.onCompletion(new ConnectException("Failed to start task: " + id), null); + } catch (Throwable t) { + callback.onCompletion(t, null); + } + } else if (isLeader()) { + callback.onCompletion(new NotAssignedException("Cannot restart task since it is not assigned to this member", member.ownerUrl(id)), null); + } else { + callback.onCompletion(new NotLeaderException("Cannot restart task since it is not assigned to this member", leaderUrl()), null); + } + return null; + }, + forwardErrorCallback(callback)); + } + + @Override + public int generation() { + return generation; + } + + @Override + public void restartConnectorAndTasks(RestartRequest request, Callback callback) { + final String connectorName = request.connectorName(); + addRequest( + () -> { + if (checkRebalanceNeeded(callback)) { + return null; + } + if (!configState.connectors().contains(request.connectorName())) { + callback.onCompletion(new NotFoundException("Unknown connector: " + connectorName), null); + return null; + } + if (isLeader()) { + // Write a restart request to the config backing store, to be executed asynchronously in tick() + configBackingStore.putRestartRequest(request); + // Compute and send the response that this was 
accepted + Optional plan = buildRestartPlan(request); + if (!plan.isPresent()) { + callback.onCompletion(new NotFoundException("Status for connector " + connectorName + " not found", null), null); + } else { + callback.onCompletion(null, plan.get().restartConnectorStateInfo()); + } + } else { + callback.onCompletion(new NotLeaderException("Only the leader can process restart requests.", leaderUrl()), null); + } + return null; + }, + forwardErrorCallback(callback) + ); + } + + /** + * Process all pending restart requests. There can be at most one request per connector. + * + *
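+     * Requests are deduplicated per connector in {@code ConfigUpdateListener#onRestartRequest},
+     * which keeps only the highest-impact request for each connector.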
+     * <p>
            This method is called from within the {@link #tick()} method. + */ + void processRestartRequests() { + List restartRequests; + synchronized (this) { + if (pendingRestartRequests.isEmpty()) { + return; + } + //dequeue into a local list to minimize the work being done within the synchronized block + restartRequests = new ArrayList<>(pendingRestartRequests.values()); + pendingRestartRequests.clear(); + } + restartRequests.forEach(restartRequest -> { + try { + doRestartConnectorAndTasks(restartRequest); + } catch (Exception e) { + log.warn("Unexpected error while trying to process " + restartRequest + ", the restart request will be skipped.", e); + } + }); + } + + /** + * Builds and executes a restart plan for the connector and its tasks from request. + * Execution of a plan involves triggering the stop of eligible connector/tasks and then queuing the start for eligible connector/tasks. + * + * @param request the request to restart connector and tasks + */ + protected synchronized void doRestartConnectorAndTasks(RestartRequest request) { + String connectorName = request.connectorName(); + Optional maybePlan = buildRestartPlan(request); + if (!maybePlan.isPresent()) { + log.debug("Skipping restart of connector '{}' since no status is available: {}", connectorName, request); + return; + } + RestartPlan plan = maybePlan.get(); + log.info("Executing {}", plan); + + // If requested, stop the connector and any tasks, marking each as restarting + final ExtendedAssignment currentAssignments = assignment; + final Collection assignedIdsToRestart = plan.taskIdsToRestart() + .stream() + .filter(taskId -> currentAssignments.tasks().contains(taskId)) + .collect(Collectors.toList()); + final boolean restartConnector = plan.shouldRestartConnector() && currentAssignments.connectors().contains(connectorName); + final boolean restartTasks = !assignedIdsToRestart.isEmpty(); + if (restartConnector) { + worker.stopAndAwaitConnector(connectorName); + onRestart(connectorName); + } + if (restartTasks) { + // Stop the tasks and mark as restarting + worker.stopAndAwaitTasks(assignedIdsToRestart); + assignedIdsToRestart.forEach(this::onRestart); + } + + // Now restart the connector and tasks + if (restartConnector) { + try { + startConnector(connectorName, (error, targetState) -> { + if (error == null) { + log.info("Connector '{}' restart successful", connectorName); + } else { + log.error("Connector '{}' restart failed", connectorName, error); + } + }); + } catch (Throwable t) { + log.error("Connector '{}' restart failed", connectorName, t); + } + } + if (restartTasks) { + log.debug("Restarting {} of {} tasks for {}", plan.restartTaskCount(), plan.totalTaskCount(), request); + plan.taskIdsToRestart().forEach(taskId -> { + try { + if (startTask(taskId)) { + log.info("Task '{}' restart successful", taskId); + } else { + log.error("Task '{}' restart failed", taskId); + } + } catch (Throwable t) { + log.error("Task '{}' restart failed", taskId, t); + } + }); + log.debug("Restarted {} of {} tasks for {} as requested", plan.restartTaskCount(), plan.totalTaskCount(), request); + } + log.info("Completed {}", plan); + } + + // Should only be called from work thread, so synchronization should not be needed + private boolean isLeader() { + return assignment != null && member.memberId().equals(assignment.leader()); + } + + /** + * Get the URL for the leader's REST interface, or null if we do not have the leader's URL yet. 
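+     * <p>Callers use this, for example, to populate {@link NotLeaderException} responses and to
+     * forward task configurations to the leader's REST endpoint in {@code reconfigureConnector}.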
+ */ + private String leaderUrl() { + if (assignment == null) + return null; + return assignment.leaderUrl(); + } + + /** + * Handle post-assignment operations, either trying to resolve issues that kept assignment from completing, getting + * this node into sync and its work started. + * + * @return false if we couldn't finish + */ + private boolean handleRebalanceCompleted() { + if (rebalanceResolved) { + log.trace("Returning early because rebalance is marked as resolved (rebalanceResolved: true)"); + return true; + } + log.debug("Handling completed but unresolved rebalance"); + + // We need to handle a variety of cases after a rebalance: + // 1. Assignment failed + // 1a. We are the leader for the round. We will be leader again if we rejoin now, so we need to catch up before + // even attempting to. If we can't we should drop out of the group because we will block everyone from making + // progress. We can backoff and try rejoining later. + // 1b. We are not the leader. We might need to catch up. If we're already caught up we can rejoin immediately, + // otherwise, we just want to wait reasonable amount of time to catch up and rejoin if we are ready. + // 2. Assignment succeeded. + // 2a. We are caught up on configs. Awesome! We can proceed to run our assigned work. + // 2b. We need to try to catch up - try reading configs for reasonable amount of time. + + boolean needsReadToEnd = false; + boolean needsRejoin = false; + if (assignment.failed()) { + needsRejoin = true; + if (isLeader()) { + log.warn("Join group completed, but assignment failed and we are the leader. Reading to end of config and retrying."); + needsReadToEnd = true; + } else if (configState.offset() < assignment.offset()) { + log.warn("Join group completed, but assignment failed and we lagging. Reading to end of config and retrying."); + needsReadToEnd = true; + } else { + log.warn("Join group completed, but assignment failed. We were up to date, so just retrying."); + } + } else { + if (configState.offset() < assignment.offset()) { + log.warn("Catching up to assignment's config offset."); + needsReadToEnd = true; + } + } + + long now = time.milliseconds(); + if (scheduledRebalance <= now) { + log.debug("Requesting rebalance because scheduled rebalance timeout has been reached " + + "(now: {} scheduledRebalance: {}", scheduledRebalance, now); + + needsRejoin = true; + scheduledRebalance = Long.MAX_VALUE; + } + + if (needsReadToEnd) { + // Force exiting this method to avoid creating any connectors/tasks and require immediate rejoining if + // we timed out. This should only happen if we failed to read configuration for long enough, + // in which case giving back control to the main loop will prevent hanging around indefinitely after getting kicked out of the group. + // We also indicate to the main loop that we failed to readConfigs so it will check that the issue was resolved before trying to join the group + if (readConfigToEnd(workerSyncTimeoutMs)) { + canReadConfigs = true; + } else { + canReadConfigs = false; + needsRejoin = true; + } + } + + if (needsRejoin) { + member.requestRejoin(); + return false; + } + + // Should still validate that they match since we may have gone *past* the required offset, in which case we + // should *not* start any tasks and rejoin + if (configState.offset() != assignment.offset()) { + log.info("Current config state offset {} does not match group assignment {}. 
Forcing rebalance.", configState.offset(), assignment.offset()); + member.requestRejoin(); + return false; + } + + startWork(); + + // We only mark this as resolved once we've actually started work, which allows us to correctly track whether + // what work is currently active and running. If we bail early, the main tick loop + having requested rejoin + // guarantees we'll attempt to rejoin before executing this method again. + herderMetrics.rebalanceSucceeded(time.milliseconds()); + rebalanceResolved = true; + + if (!assignment.revokedConnectors().isEmpty() || !assignment.revokedTasks().isEmpty()) { + assignment.revokedConnectors().clear(); + assignment.revokedTasks().clear(); + member.requestRejoin(); + return false; + } + return true; + } + + /** + * Try to read to the end of the config log within the given timeout + * @param timeoutMs maximum time to wait to sync to the end of the log + * @return true if successful, false if timed out + */ + private boolean readConfigToEnd(long timeoutMs) { + if (configState.offset() < assignment.offset()) { + log.info("Current config state offset {} is behind group assignment {}, reading to end of config log", configState.offset(), assignment.offset()); + } else { + log.info("Reading to end of config log; current config state offset: {}", configState.offset()); + } + try { + configBackingStore.refresh(timeoutMs, TimeUnit.MILLISECONDS); + configState = configBackingStore.snapshot(); + log.info("Finished reading to end of log and updated config snapshot, new config log offset: {}", configState.offset()); + backoffRetries = BACKOFF_RETRIES; + return true; + } catch (TimeoutException e) { + // in case reading the log takes too long, leave the group to ensure a quick rebalance (although by default we should be out of the group already) + // and back off to avoid a tight loop of rejoin-attempt-to-catch-up-leave + log.warn("Didn't reach end of config log quickly enough", e); + member.maybeLeaveGroup("taking too long to read the log"); + backoff(workerUnsyncBackoffMs); + return false; + } + } + + private void backoff(long ms) { + if (ConnectProtocolCompatibility.fromProtocolVersion(currentProtocolVersion) == EAGER) { + time.sleep(ms); + return; + } + + if (backoffRetries > 0) { + int rebalanceDelayFraction = + config.getInt(DistributedConfig.SCHEDULED_REBALANCE_MAX_DELAY_MS_CONFIG) / 10 / backoffRetries; + time.sleep(rebalanceDelayFraction); + --backoffRetries; + return; + } + + ExtendedAssignment runningAssignmentSnapshot; + synchronized (this) { + runningAssignmentSnapshot = ExtendedAssignment.duplicate(runningAssignment); + } + log.info("Revoking current running assignment {} because after {} retries the worker " + + "has not caught up with the latest Connect cluster updates", + runningAssignmentSnapshot, BACKOFF_RETRIES); + member.revokeAssignment(runningAssignmentSnapshot); + backoffRetries = BACKOFF_RETRIES; + } + + private void startAndStop(Collection> callables) { + try { + startAndStopExecutor.invokeAll(callables); + } catch (InterruptedException e) { + // ignore + } + } + + private void startWork() { + // Start assigned connectors and tasks + List> callables = new ArrayList<>(); + + // The sets in runningAssignment may change when onRevoked is called voluntarily by this + // herder (e.g. when a broker coordinator failure is detected). Otherwise the + // runningAssignment is always replaced by the assignment here. 
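+        // Only the delta between the newly received assignment and the currently running one is
+        // started below (see assignmentDifference). For example, if the new assignment contains
+        // connectors {A, B} while only {A} is already running, just B is submitted as a startup callable.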
+ synchronized (this) { + log.info("Starting connectors and tasks using config offset {}", assignment.offset()); + log.debug("Received assignment: {}", assignment); + log.debug("Currently running assignment: {}", runningAssignment); + + for (String connectorName : assignmentDifference(assignment.connectors(), runningAssignment.connectors())) { + callables.add(getConnectorStartingCallable(connectorName)); + } + + // These tasks have been stopped by this worker due to task reconfiguration. In order to + // restart them, they are removed just before the overall task startup from the set of + // currently running tasks. Therefore, they'll be restarted only if they are included in + // the assignment that was just received after rebalancing. + log.debug("Tasks to restart from currently running assignment: {}", tasksToRestart); + runningAssignment.tasks().removeAll(tasksToRestart); + tasksToRestart.clear(); + for (ConnectorTaskId taskId : assignmentDifference(assignment.tasks(), runningAssignment.tasks())) { + callables.add(getTaskStartingCallable(taskId)); + } + } + + startAndStop(callables); + + synchronized (this) { + runningAssignment = member.currentProtocolVersion() == CONNECT_PROTOCOL_V0 + ? ExtendedAssignment.empty() + : assignment; + } + + log.info("Finished starting connectors and tasks"); + } + + // arguments should assignment collections (connectors or tasks) and should not be null + private static Collection assignmentDifference(Collection update, Collection running) { + if (running.isEmpty()) { + return update; + } + HashSet diff = new HashSet<>(update); + diff.removeAll(running); + return diff; + } + + private boolean startTask(ConnectorTaskId taskId) { + log.info("Starting task {}", taskId); + return worker.startTask( + taskId, + configState, + configState.connectorConfig(taskId.connector()), + configState.taskConfig(taskId), + this, + configState.targetState(taskId.connector()) + ); + } + + private Callable getTaskStartingCallable(final ConnectorTaskId taskId) { + return () -> { + try { + startTask(taskId); + } catch (Throwable t) { + log.error("Couldn't instantiate task {} because it has an invalid task configuration. This task will not execute until reconfigured.", + taskId, t); + onFailure(taskId, t); + } + return null; + }; + } + + private Callable getTaskStoppingCallable(final ConnectorTaskId taskId) { + return () -> { + worker.stopAndAwaitTask(taskId); + return null; + }; + } + + // Helper for starting a connector with the given name, which will extract & parse the config, generate connector + // context and add to the worker. This needs to be called from within the main worker thread for this herder. + // The callback is invoked after the connector has finished startup and generated task configs, or failed in the process. 
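+    // The callback completes with (null, null) once the connector reaches its initial target state
+    // (and, if that state is STARTED, after task reconfiguration has been triggered); startup
+    // failures are reported as callback.onCompletion(error, null). Illustrative usage, mirroring
+    // the restart path elsewhere in this class:
+    //   startConnector(connName, (error, ignored) -> {
+    //       if (error != null)
+    //           log.error("Connector '{}' restart failed", connName, error);
+    //   });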
+ private void startConnector(String connectorName, Callback callback) { + log.info("Starting connector {}", connectorName); + final Map configProps = configState.connectorConfig(connectorName); + final CloseableConnectorContext ctx = new HerderConnectorContext(this, connectorName); + final TargetState initialState = configState.targetState(connectorName); + final Callback onInitialStateChange = (error, newState) -> { + if (error != null) { + callback.onCompletion(new ConnectException("Failed to start connector: " + connectorName, error), null); + return; + } + + // Use newState here in case the connector has been paused right after being created + if (newState == TargetState.STARTED) { + addRequest( + () -> { + // Request configuration since this could be a brand new connector. However, also only update those + // task configs if they are actually different from the existing ones to avoid unnecessary updates when this is + // just restoring an existing connector. + reconfigureConnectorTasksWithRetry(time.milliseconds(), connectorName); + callback.onCompletion(null, null); + return null; + }, + forwardErrorCallback(callback) + ); + } else { + callback.onCompletion(null, null); + } + }; + worker.startConnector(connectorName, configProps, ctx, this, initialState, onInitialStateChange); + } + + private Callable getConnectorStartingCallable(final String connectorName) { + return () -> { + try { + startConnector(connectorName, (error, result) -> { + if (error != null) { + log.error("Failed to start connector '" + connectorName + "'", error); + } + }); + } catch (Throwable t) { + log.error("Unexpected error while trying to start connector " + connectorName, t); + onFailure(connectorName, t); + } + return null; + }; + } + + private Callable getConnectorStoppingCallable(final String connectorName) { + return () -> { + try { + worker.stopAndAwaitConnector(connectorName); + } catch (Throwable t) { + log.error("Failed to shut down connector " + connectorName, t); + } + return null; + }; + } + + private void reconfigureConnectorTasksWithRetry(long initialRequestTime, final String connName) { + reconfigureConnector(connName, (error, result) -> { + // If we encountered an error, we don't have much choice but to just retry. If we don't, we could get + // stuck with a connector that thinks it has generated tasks, but wasn't actually successful and therefore + // never makes progress. The retry has to run through a DistributedHerderRequest since this callback could be happening + // from the HTTP request forwarding thread. + if (error != null) { + if (isPossibleExpiredKeyException(initialRequestTime, error)) { + log.debug("Failed to reconfigure connector's tasks ({}), possibly due to expired session key. 
Retrying after backoff", connName); + } else { + log.error("Failed to reconfigure connector's tasks ({}), retrying after backoff:", connName, error); + } + addRequest(RECONFIGURE_CONNECTOR_TASKS_BACKOFF_MS, + () -> { + reconfigureConnectorTasksWithRetry(initialRequestTime, connName); + return null; + }, (err, res) -> { + if (err != null) { + log.error("Unexpected error during connector task reconfiguration: ", err); + log.error("Task reconfiguration for {} failed unexpectedly, this connector will not be properly reconfigured unless manually triggered.", connName); + } + } + ); + } + }); + } + + boolean isPossibleExpiredKeyException(long initialRequestTime, Throwable error) { + if (error instanceof ConnectRestException) { + ConnectRestException connectError = (ConnectRestException) error; + return connectError.statusCode() == Response.Status.FORBIDDEN.getStatusCode() + && initialRequestTime + TimeUnit.MINUTES.toMillis(1) >= time.milliseconds(); + } + return false; + } + + // Updates configurations for a connector by requesting them from the connector, filling in parameters provided + // by the system, then checks whether any configs have actually changed before submitting the new configs to storage + private void reconfigureConnector(final String connName, final Callback cb) { + try { + if (!worker.isRunning(connName)) { + log.info("Skipping reconfiguration of connector {} since it is not running", connName); + return; + } + + Map configs = configState.connectorConfig(connName); + + ConnectorConfig connConfig; + if (worker.isSinkConnector(connName)) { + connConfig = new SinkConnectorConfig(plugins(), configs); + } else { + connConfig = new SourceConnectorConfig(plugins(), configs, worker.isTopicCreationEnabled()); + } + + final List> taskProps = worker.connectorTaskConfigs(connName, connConfig); + boolean changed = false; + int currentNumTasks = configState.taskCount(connName); + if (taskProps.size() != currentNumTasks) { + log.debug("Change in connector task count from {} to {}, writing updated task configurations", currentNumTasks, taskProps.size()); + changed = true; + } else { + int index = 0; + for (Map taskConfig : taskProps) { + if (!taskConfig.equals(configState.taskConfig(new ConnectorTaskId(connName, index)))) { + log.debug("Change in task configurations, writing updated task configurations"); + changed = true; + break; + } + index++; + } + } + if (changed) { + List> rawTaskProps = reverseTransform(connName, configState, taskProps); + if (isLeader()) { + configBackingStore.putTaskConfigs(connName, rawTaskProps); + cb.onCompletion(null, null); + } else { + // We cannot forward the request on the same thread because this reconfiguration can happen as a result of connector + // addition or removal. If we blocked waiting for the response from leader, we may be kicked out of the worker group. 
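+                    // The forwarded call below is the internal equivalent of POST /connectors/{connector}/tasks
+                    // issued against the leader; if a session key has been distributed, RestClient.httpRequest
+                    // signs the request with it using requestSignatureAlgorithm.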
+ forwardRequestExecutor.submit(() -> { + try { + String leaderUrl = leaderUrl(); + if (Utils.isBlank(leaderUrl)) { + cb.onCompletion(new ConnectException("Request to leader to " + + "reconfigure connector tasks failed " + + "because the URL of the leader's REST interface is empty!"), null); + return; + } + String reconfigUrl = UriBuilder.fromUri(leaderUrl) + .path("connectors") + .path(connName) + .path("tasks") + .build() + .toString(); + log.trace("Forwarding task configurations for connector {} to leader", connName); + RestClient.httpRequest(reconfigUrl, "POST", null, rawTaskProps, null, config, sessionKey, requestSignatureAlgorithm); + cb.onCompletion(null, null); + } catch (ConnectException e) { + log.error("Request to leader to reconfigure connector tasks failed", e); + cb.onCompletion(e, null); + } + }); + } + } + } catch (Throwable t) { + cb.onCompletion(t, null); + } + } + + private boolean checkRebalanceNeeded(Callback callback) { + // Raise an error if we are expecting a rebalance to begin. This prevents us from forwarding requests + // based on stale leadership or assignment information + if (needsReconfigRebalance) { + callback.onCompletion(new RebalanceNeededException("Request cannot be completed because a rebalance is expected"), null); + return true; + } + return false; + } + + DistributedHerderRequest addRequest(Callable action, Callback callback) { + return addRequest(0, action, callback); + } + + DistributedHerderRequest addRequest(long delayMs, Callable action, Callback callback) { + DistributedHerderRequest req = new DistributedHerderRequest(time.milliseconds() + delayMs, requestSeqNum.incrementAndGet(), action, callback); + requests.add(req); + if (peekWithoutException() == req) + member.wakeup(); + return req; + } + + private boolean internalRequestValidationEnabled() { + return internalRequestValidationEnabled(member.currentProtocolVersion()); + } + + private static boolean internalRequestValidationEnabled(short protocolVersion) { + return protocolVersion >= CONNECT_PROTOCOL_V2; + } + + private DistributedHerderRequest peekWithoutException() { + try { + return requests.isEmpty() ? null : requests.first(); + } catch (NoSuchElementException e) { + // Ignore exception. Should be rare. Means that the collection became empty between + // checking the size and retrieving the first element. + } + return null; + } + + public class ConfigUpdateListener implements ConfigBackingStore.UpdateListener { + @Override + public void onConnectorConfigRemove(String connector) { + log.info("Connector {} config removed", connector); + + synchronized (DistributedHerder.this) { + // rebalance after connector removal to ensure that existing tasks are balanced among workers + if (configState.contains(connector)) + needsReconfigRebalance = true; + connectorConfigUpdates.add(connector); + } + member.wakeup(); + } + + @Override + public void onConnectorConfigUpdate(String connector) { + log.info("Connector {} config updated", connector); + + // Stage the update and wake up the work thread. Connector config *changes* only need the one connector + // to be bounced. However, this callback may also indicate a connector *addition*, which does require + // a rebalance, so we need to be careful about what operation we request. 
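+            // If the connector is absent from the current config snapshot, this update is an addition
+            // and a rebalance is requested via needsReconfigRebalance; in either case the connector is
+            // staged in connectorConfigUpdates and handled from the herder's main (tick) thread.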
+ synchronized (DistributedHerder.this) { + if (!configState.contains(connector)) + needsReconfigRebalance = true; + connectorConfigUpdates.add(connector); + } + member.wakeup(); + } + + @Override + public void onTaskConfigUpdate(Collection tasks) { + log.info("Tasks {} configs updated", tasks); + + // Stage the update and wake up the work thread. + // The set of tasks is recorder for incremental cooperative rebalancing, in which + // tasks don't get restarted unless they are balanced between workers. + // With eager rebalancing there's no need to record the set of tasks because task reconfigs + // always need a rebalance to ensure offsets get committed. In eager rebalancing the + // recorded set of tasks remains unused. + // TODO: As an optimization, some task config updates could avoid a rebalance. In particular, single-task + // connectors clearly don't need any coordination. + synchronized (DistributedHerder.this) { + needsReconfigRebalance = true; + taskConfigUpdates.addAll(tasks); + } + member.wakeup(); + } + + @Override + public void onConnectorTargetStateChange(String connector) { + log.info("Connector {} target state change", connector); + + synchronized (DistributedHerder.this) { + connectorTargetStateChanges.add(connector); + } + member.wakeup(); + } + + @Override + public void onSessionKeyUpdate(SessionKey sessionKey) { + log.info("Session key updated"); + + synchronized (DistributedHerder.this) { + DistributedHerder.this.sessionKey = sessionKey.key(); + // Track the expiration of the key. + // Followers will receive rotated keys from the leader and won't be responsible for + // tracking expiration and distributing new keys themselves, but may become leaders + // later on and will need to know when to update the key. + if (keyRotationIntervalMs > 0) { + DistributedHerder.this.keyExpiration = sessionKey.creationTimestamp() + keyRotationIntervalMs; + } + } + } + + @Override + public void onRestartRequest(RestartRequest request) { + log.info("Received and enqueuing {}", request); + + synchronized (DistributedHerder.this) { + String connectorName = request.connectorName(); + //preserve the highest impact request + pendingRestartRequests.compute(connectorName, (k, existingRequest) -> { + if (existingRequest == null || request.compareTo(existingRequest) > 0) { + log.debug("Overwriting existing {} and enqueuing the higher impact {}", existingRequest, request); + return request; + } else { + log.debug("Preserving existing higher impact {} and ignoring incoming {}", existingRequest, request); + return existingRequest; + } + }); + } + member.wakeup(); + } + } + + class DistributedHerderRequest implements HerderRequest, Comparable { + private final long at; + private final long seq; + private final Callable action; + private final Callback callback; + + public DistributedHerderRequest(long at, long seq, Callable action, Callback callback) { + this.at = at; + this.seq = seq; + this.action = action; + this.callback = callback; + } + + public Callable action() { + return action; + } + + public Callback callback() { + return callback; + } + + @Override + public void cancel() { + DistributedHerder.this.requests.remove(this); + } + + @Override + public int compareTo(DistributedHerderRequest o) { + final int cmp = Long.compare(at, o.at); + return cmp == 0 ? 
Long.compare(seq, o.seq) : cmp; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (!(o instanceof DistributedHerderRequest)) + return false; + DistributedHerderRequest other = (DistributedHerderRequest) o; + return compareTo(other) == 0; + } + + @Override + public int hashCode() { + return Objects.hash(at, seq); + } + } + + private static Callback forwardErrorCallback(final Callback callback) { + return (error, result) -> { + if (error != null) + callback.onCompletion(error, null); + }; + } + + private void updateDeletedConnectorStatus() { + ClusterConfigState snapshot = configBackingStore.snapshot(); + Set connectors = snapshot.connectors(); + for (String connector : statusBackingStore.connectors()) { + if (!connectors.contains(connector)) { + log.debug("Cleaning status information for connector {}", connector); + onDeletion(connector); + } + } + } + + private void updateDeletedTaskStatus() { + ClusterConfigState snapshot = configBackingStore.snapshot(); + for (String connector : statusBackingStore.connectors()) { + Set remainingTasks = new HashSet<>(snapshot.tasks(connector)); + + statusBackingStore.getAll(connector).stream() + .map(TaskStatus::id) + .filter(task -> !remainingTasks.contains(task)) + .forEach(this::onDeletion); + } + } + + protected HerderMetrics herderMetrics() { + return herderMetrics; + } + + // Rebalances are triggered internally from the group member, so these are always executed in the work thread. + public class RebalanceListener implements WorkerRebalanceListener { + private final Time time; + RebalanceListener(Time time) { + this.time = time; + } + + @Override + public void onAssigned(ExtendedAssignment assignment, int generation) { + // This callback just logs the info and saves it. The actual response is handled in the main loop, which + // ensures the group member's logic for rebalancing can complete, potentially long-running steps to + // catch up (or backoff if we fail) not executed in a callback, and so we'll be able to invoke other + // group membership actions (e.g., we may need to explicitly leave the group if we cannot handle the + // assigned tasks). + short priorProtocolVersion = currentProtocolVersion; + DistributedHerder.this.currentProtocolVersion = member.currentProtocolVersion(); + log.info( + "Joined group at generation {} with protocol version {} and got assignment: {} with rebalance delay: {}", + generation, + DistributedHerder.this.currentProtocolVersion, + assignment, + assignment.delay() + ); + synchronized (DistributedHerder.this) { + DistributedHerder.this.assignment = assignment; + DistributedHerder.this.generation = generation; + int delay = assignment.delay(); + DistributedHerder.this.scheduledRebalance = delay > 0 + ? time.milliseconds() + delay + : Long.MAX_VALUE; + + boolean requestValidationWasEnabled = internalRequestValidationEnabled(priorProtocolVersion); + boolean requestValidationNowEnabled = internalRequestValidationEnabled(currentProtocolVersion); + if (requestValidationNowEnabled != requestValidationWasEnabled) { + // Internal request verification has been switched on or off; let the user know + if (requestValidationNowEnabled) { + log.info("Internal request validation has been re-enabled"); + } else { + log.warn( + "The protocol used by this Connect cluster has been downgraded from '{}' to '{}' and internal request " + + "validation is now disabled. 
This is most likely caused by a new worker joining the cluster with an " + + "older protocol specified for the {} configuration; if this is not intentional, either remove the {} " + + "configuration from that worker's config file, or change its value to '{}'. If this configuration is " + + "left as-is, the cluster will be insecure; for more information, see KIP-507: " + + "https://cwiki.apache.org/confluence/display/KAFKA/KIP-507%3A+Securing+Internal+Connect+REST+Endpoints", + ConnectProtocolCompatibility.fromProtocolVersion(priorProtocolVersion), + ConnectProtocolCompatibility.fromProtocolVersion(DistributedHerder.this.currentProtocolVersion), + DistributedConfig.CONNECT_PROTOCOL_CONFIG, + DistributedConfig.CONNECT_PROTOCOL_CONFIG, + ConnectProtocolCompatibility.SESSIONED.name() + ); + } + } + + rebalanceResolved = false; + herderMetrics.rebalanceStarted(time.milliseconds()); + } + + // Delete the statuses of all connectors and tasks removed prior to the start of this rebalance. This + // has to be done after the rebalance completes to avoid race conditions as the previous generation + // attempts to change the state to UNASSIGNED after tasks have been stopped. + if (isLeader()) { + updateDeletedConnectorStatus(); + updateDeletedTaskStatus(); + } + + // We *must* interrupt any poll() call since this could occur when the poll starts, and we might then + // sleep in the poll() for a long time. Forcing a wakeup ensures we'll get to process this event in the + // main thread. + member.wakeup(); + } + + @Override + public void onRevoked(String leader, Collection connectors, Collection tasks) { + // Note that since we don't reset the assignment, we don't revoke leadership here. During a rebalance, + // it is still important to have a leader that can write configs, offsets, etc. + + if (rebalanceResolved || currentProtocolVersion >= CONNECT_PROTOCOL_V1) { + List> callables = new ArrayList<>(); + for (final String connectorName : connectors) { + callables.add(getConnectorStoppingCallable(connectorName)); + } + + // TODO: We need to at least commit task offsets, but if we could commit offsets & pause them instead of + // stopping them then state could continue to be reused when the task remains on this worker. For example, + // this would avoid having to close a connection and then reopen it when the task is assigned back to this + // worker again. + for (final ConnectorTaskId taskId : tasks) { + callables.add(getTaskStoppingCallable(taskId)); + } + + // The actual timeout for graceful task/connector stop is applied in worker's + // stopAndAwaitTask/stopAndAwaitConnector methods. + startAndStop(callables); + log.info("Finished stopping tasks in preparation for rebalance"); + + synchronized (DistributedHerder.this) { + log.debug("Removing connectors from running assignment {}", connectors); + runningAssignment.connectors().removeAll(connectors); + log.debug("Removing tasks from running assignment {}", tasks); + runningAssignment.tasks().removeAll(tasks); + } + + if (isTopicTrackingEnabled) { + // Send tombstones to reset active topics for removed connectors only after + // connectors and tasks have been stopped, or these tombstones will be overwritten + resetActiveTopics(connectors, tasks); + } + + // Ensure that all status updates have been pushed to the storage system before rebalancing. + // Otherwise, we may inadvertently overwrite the state with a stale value after the rebalance + // completes. 
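+                // flush() is expected to block until in-flight status writes have completed, so that
+                // the next generation observes up-to-date connector and task states.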
+ statusBackingStore.flush(); + log.info("Finished flushing status backing store in preparation for rebalance"); + } else { + log.info("Wasn't able to resume work after last rebalance, can skip stopping connectors and tasks"); + } + } + + private void resetActiveTopics(Collection connectors, Collection tasks) { + connectors.stream() + .filter(connectorName -> !configState.contains(connectorName)) + .forEach(DistributedHerder.this::resetConnectorActiveTopics); + tasks.stream() + .map(ConnectorTaskId::connector) + .distinct() + .filter(connectorName -> !configState.contains(connectorName)) + .forEach(DistributedHerder.this::resetConnectorActiveTopics); + } + } + + class HerderMetrics { + private final MetricGroup metricGroup; + private final Sensor rebalanceCompletedCounts; + private final Sensor rebalanceTime; + private volatile long lastRebalanceCompletedAtMillis = Long.MIN_VALUE; + private volatile boolean rebalancing = false; + private volatile long rebalanceStartedAtMillis = 0L; + + public HerderMetrics(ConnectMetrics connectMetrics) { + ConnectMetricsRegistry registry = connectMetrics.registry(); + metricGroup = connectMetrics.group(registry.workerRebalanceGroupName()); + + metricGroup.addValueMetric(registry.connectProtocol, now -> + ConnectProtocolCompatibility.fromProtocolVersion(member.currentProtocolVersion()).name() + ); + metricGroup.addValueMetric(registry.leaderName, now -> leaderUrl()); + metricGroup.addValueMetric(registry.epoch, now -> (double) generation); + metricGroup.addValueMetric(registry.rebalanceMode, now -> rebalancing ? 1.0d : 0.0d); + + rebalanceCompletedCounts = metricGroup.sensor("completed-rebalance-count"); + rebalanceCompletedCounts.add(metricGroup.metricName(registry.rebalanceCompletedTotal), new CumulativeSum()); + + rebalanceTime = metricGroup.sensor("rebalance-time"); + rebalanceTime.add(metricGroup.metricName(registry.rebalanceTimeMax), new Max()); + rebalanceTime.add(metricGroup.metricName(registry.rebalanceTimeAvg), new Avg()); + + metricGroup.addValueMetric(registry.rebalanceTimeSinceLast, now -> + lastRebalanceCompletedAtMillis == Long.MIN_VALUE ? Double.POSITIVE_INFINITY : (double) (now - lastRebalanceCompletedAtMillis)); + } + + void close() { + metricGroup.close(); + } + + void rebalanceStarted(long now) { + rebalanceStartedAtMillis = now; + rebalancing = true; + } + + void rebalanceSucceeded(long now) { + long duration = Math.max(0L, now - rebalanceStartedAtMillis); + rebalancing = false; + rebalanceCompletedCounts.record(1.0); + rebalanceTime.record(duration); + lastRebalanceCompletedAtMillis = now; + } + + protected MetricGroup metricGroup() { + return metricGroup; + } + } +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/distributed/EagerAssignor.java b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/distributed/EagerAssignor.java new file mode 100644 index 0000000..b6dbd09 --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/distributed/EagerAssignor.java @@ -0,0 +1,182 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.runtime.distributed; + +import org.apache.kafka.common.utils.CircularIterator; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.connect.util.ConnectorTaskId; +import org.slf4j.Logger; + +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.apache.kafka.common.message.JoinGroupResponseData.JoinGroupResponseMember; +import static org.apache.kafka.connect.runtime.distributed.ConnectProtocol.Assignment; +import static org.apache.kafka.connect.runtime.distributed.WorkerCoordinator.LeaderState; + + +/** + * An assignor that computes a unweighted round-robin distribution of connectors and tasks. The + * connectors are assigned to the workers first, followed by the tasks. This is to avoid + * load imbalance when several 1-task connectors are running, given that a connector is usually + * more lightweight than a task. + * + * Note that this class is NOT thread-safe. + */ +public class EagerAssignor implements ConnectAssignor { + private final Logger log; + + public EagerAssignor(LogContext logContext) { + this.log = logContext.logger(EagerAssignor.class); + } + + @Override + public Map performAssignment(String leaderId, String protocol, + List allMemberMetadata, + WorkerCoordinator coordinator) { + log.debug("Performing task assignment"); + Map memberConfigs = new HashMap<>(); + for (JoinGroupResponseMember member : allMemberMetadata) + memberConfigs.put(member.memberId(), IncrementalCooperativeConnectProtocol.deserializeMetadata(ByteBuffer.wrap(member.metadata()))); + + long maxOffset = findMaxMemberConfigOffset(memberConfigs, coordinator); + Long leaderOffset = ensureLeaderConfig(maxOffset, coordinator); + if (leaderOffset == null) + return fillAssignmentsAndSerialize(memberConfigs.keySet(), Assignment.CONFIG_MISMATCH, + leaderId, memberConfigs.get(leaderId).url(), maxOffset, + new HashMap<>(), new HashMap<>()); + return performTaskAssignment(leaderId, leaderOffset, memberConfigs, coordinator); + } + + private Long ensureLeaderConfig(long maxOffset, WorkerCoordinator coordinator) { + // If this leader is behind some other members, we can't do assignment + if (coordinator.configSnapshot().offset() < maxOffset) { + // We might be able to take a new snapshot to catch up immediately and avoid another round of syncing here. + // Alternatively, if this node has already passed the maximum reported by any other member of the group, it + // is also safe to use this newer state. + ClusterConfigState updatedSnapshot = coordinator.configFreshSnapshot(); + if (updatedSnapshot.offset() < maxOffset) { + log.info("Was selected to perform assignments, but do not have latest config found in sync request. 
" + + "Returning an empty configuration to trigger re-sync."); + return null; + } else { + coordinator.configSnapshot(updatedSnapshot); + return updatedSnapshot.offset(); + } + } + return maxOffset; + } + + private Map performTaskAssignment(String leaderId, long maxOffset, + Map memberConfigs, + WorkerCoordinator coordinator) { + Map> connectorAssignments = new HashMap<>(); + Map> taskAssignments = new HashMap<>(); + + // Perform round-robin task assignment. Assign all connectors and then all tasks because assigning both the + // connector and its tasks can lead to very uneven distribution of work in some common cases (e.g. for connectors + // that generate only 1 task each; in a cluster of 2 or an even # of nodes, only even nodes will be assigned + // connectors and only odd nodes will be assigned tasks, but tasks are, on average, actually more resource + // intensive than connectors). + List connectorsSorted = sorted(coordinator.configSnapshot().connectors()); + CircularIterator memberIt = new CircularIterator<>(sorted(memberConfigs.keySet())); + for (String connectorId : connectorsSorted) { + String connectorAssignedTo = memberIt.next(); + log.trace("Assigning connector {} to {}", connectorId, connectorAssignedTo); + Collection memberConnectors = connectorAssignments.get(connectorAssignedTo); + if (memberConnectors == null) { + memberConnectors = new ArrayList<>(); + connectorAssignments.put(connectorAssignedTo, memberConnectors); + } + memberConnectors.add(connectorId); + } + for (String connectorId : connectorsSorted) { + for (ConnectorTaskId taskId : sorted(coordinator.configSnapshot().tasks(connectorId))) { + String taskAssignedTo = memberIt.next(); + log.trace("Assigning task {} to {}", taskId, taskAssignedTo); + Collection memberTasks = taskAssignments.get(taskAssignedTo); + if (memberTasks == null) { + memberTasks = new ArrayList<>(); + taskAssignments.put(taskAssignedTo, memberTasks); + } + memberTasks.add(taskId); + } + } + + coordinator.leaderState(new LeaderState(memberConfigs, connectorAssignments, taskAssignments)); + + return fillAssignmentsAndSerialize(memberConfigs.keySet(), Assignment.NO_ERROR, + leaderId, memberConfigs.get(leaderId).url(), maxOffset, connectorAssignments, taskAssignments); + } + + private Map fillAssignmentsAndSerialize(Collection members, + short error, + String leaderId, + String leaderUrl, + long maxOffset, + Map> connectorAssignments, + Map> taskAssignments) { + + Map groupAssignment = new HashMap<>(); + for (String member : members) { + Collection connectors = connectorAssignments.get(member); + if (connectors == null) { + connectors = Collections.emptyList(); + } + Collection tasks = taskAssignments.get(member); + if (tasks == null) { + tasks = Collections.emptyList(); + } + Assignment assignment = new Assignment(error, leaderId, leaderUrl, maxOffset, connectors, tasks); + log.debug("Assignment: {} -> {}", member, assignment); + groupAssignment.put(member, ConnectProtocol.serializeAssignment(assignment)); + } + log.debug("Finished assignment"); + return groupAssignment; + } + + private long findMaxMemberConfigOffset(Map memberConfigs, + WorkerCoordinator coordinator) { + // The new config offset is the maximum seen by any member. We always perform assignment using this offset, + // even if some members have fallen behind. The config offset used to generate the assignment is included in + // the response so members that have fallen behind will not use the assignment until they have caught up. 
+ Long maxOffset = null; + for (Map.Entry stateEntry : memberConfigs.entrySet()) { + long memberRootOffset = stateEntry.getValue().offset(); + if (maxOffset == null) + maxOffset = memberRootOffset; + else + maxOffset = Math.max(maxOffset, memberRootOffset); + } + + log.debug("Max config offset root: {}, local snapshot config offsets root: {}", + maxOffset, coordinator.configSnapshot().offset()); + return maxOffset; + } + + private static > List sorted(Collection members) { + List res = new ArrayList<>(members); + Collections.sort(res); + return res; + } + +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/distributed/ExtendedAssignment.java b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/distributed/ExtendedAssignment.java new file mode 100644 index 0000000..e544407 --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/distributed/ExtendedAssignment.java @@ -0,0 +1,284 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.runtime.distributed; + +import org.apache.kafka.common.protocol.types.Struct; +import org.apache.kafka.connect.util.ConnectorTaskId; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.LinkedHashMap; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.stream.Collectors; + +import static org.apache.kafka.connect.runtime.distributed.ConnectProtocol.ASSIGNMENT_KEY_NAME; +import static org.apache.kafka.connect.runtime.distributed.ConnectProtocol.CONFIG_OFFSET_KEY_NAME; +import static org.apache.kafka.connect.runtime.distributed.ConnectProtocol.CONNECTOR_KEY_NAME; +import static org.apache.kafka.connect.runtime.distributed.ConnectProtocol.CONNECTOR_TASK; +import static org.apache.kafka.connect.runtime.distributed.ConnectProtocol.ERROR_KEY_NAME; +import static org.apache.kafka.connect.runtime.distributed.ConnectProtocol.LEADER_KEY_NAME; +import static org.apache.kafka.connect.runtime.distributed.ConnectProtocol.LEADER_URL_KEY_NAME; +import static org.apache.kafka.connect.runtime.distributed.ConnectProtocol.TASKS_KEY_NAME; +import static org.apache.kafka.connect.runtime.distributed.IncrementalCooperativeConnectProtocol.ASSIGNMENT_V1; +import static org.apache.kafka.connect.runtime.distributed.IncrementalCooperativeConnectProtocol.CONNECTOR_ASSIGNMENT_V1; +import static org.apache.kafka.connect.runtime.distributed.IncrementalCooperativeConnectProtocol.CONNECT_PROTOCOL_V1; +import static org.apache.kafka.connect.runtime.distributed.IncrementalCooperativeConnectProtocol.REVOKED_KEY_NAME; +import static 
org.apache.kafka.connect.runtime.distributed.IncrementalCooperativeConnectProtocol.SCHEDULED_DELAY_KEY_NAME; + +/** + * The extended assignment of connectors and tasks that includes revoked connectors and tasks + * as well as a scheduled rebalancing delay. + */ +public class ExtendedAssignment extends ConnectProtocol.Assignment { + private final short version; + private final Collection revokedConnectorIds; + private final Collection revokedTaskIds; + private final int delay; + + private static final ExtendedAssignment EMPTY = new ExtendedAssignment( + CONNECT_PROTOCOL_V1, ConnectProtocol.Assignment.NO_ERROR, null, null, -1, + Collections.emptyList(), Collections.emptyList(), Collections.emptyList(), Collections.emptyList(), 0); + + /** + * Create an assignment indicating responsibility for the given connector instances and task Ids. + * + * @param version Connect protocol version + * @param error error code for this assignment; {@code ConnectProtocol.Assignment.NO_ERROR} + * indicates no error during assignment + * @param leader Connect group's leader Id; may be null only on the empty assignment + * @param leaderUrl Connect group's leader URL; may be null only on the empty assignment + * @param configOffset the offset in the config topic that this assignment is corresponding to + * @param connectorIds list of connectors that the worker should instantiate and run; may not be null + * @param taskIds list of task IDs that the worker should instantiate and run; may not be null + * @param revokedConnectorIds list of connectors that the worker should stop running; may not be null + * @param revokedTaskIds list of task IDs that the worker should stop running; may not be null + * @param delay the scheduled delay after which the worker should rejoin the group + */ + public ExtendedAssignment(short version, short error, String leader, String leaderUrl, long configOffset, + Collection connectorIds, Collection taskIds, + Collection revokedConnectorIds, Collection revokedTaskIds, + int delay) { + super(error, leader, leaderUrl, configOffset, connectorIds, taskIds); + this.version = version; + this.revokedConnectorIds = Objects.requireNonNull(revokedConnectorIds, + "Revoked connector IDs may be empty but not null"); + this.revokedTaskIds = Objects.requireNonNull(revokedTaskIds, + "Revoked task IDs may be empty but not null"); + this.delay = delay; + } + + public static ExtendedAssignment duplicate(ExtendedAssignment assignment) { + return new ExtendedAssignment( + assignment.version(), + assignment.error(), + assignment.leader(), + assignment.leaderUrl(), + assignment.offset(), + new LinkedHashSet<>(assignment.connectors()), + new LinkedHashSet<>(assignment.tasks()), + new LinkedHashSet<>(assignment.revokedConnectors()), + new LinkedHashSet<>(assignment.revokedTasks()), + assignment.delay()); + } + + /** + * Return the version of the connect protocol that this assignment belongs to. + * + * @return the connect protocol version of this assignment + */ + public short version() { + return version; + } + + /** + * Return the IDs of the connectors that are revoked by this assignment. + * + * @return the revoked connector IDs; empty if there are no revoked connectors + */ + public Collection revokedConnectors() { + return revokedConnectorIds; + } + + /** + * Return the IDs of the tasks that are revoked by this assignment. 
+ * + * @return the revoked task IDs; empty if there are no revoked tasks + */ + public Collection revokedTasks() { + return revokedTaskIds; + } + + /** + * Return the delay for the rebalance that is scheduled by this assignment. + * + * @return the scheduled delay + */ + public int delay() { + return delay; + } + + /** + * Return an empty assignment. + * + * @return an empty assignment + */ + public static ExtendedAssignment empty() { + return EMPTY; + } + + @Override + public String toString() { + return "Assignment{" + + "error=" + error() + + ", leader='" + leader() + '\'' + + ", leaderUrl='" + leaderUrl() + '\'' + + ", offset=" + offset() + + ", connectorIds=" + connectors() + + ", taskIds=" + tasks() + + ", revokedConnectorIds=" + revokedConnectorIds + + ", revokedTaskIds=" + revokedTaskIds + + ", delay=" + delay + + '}'; + } + + private Map> revokedAsMap() { + if (revokedConnectorIds == null && revokedTaskIds == null) { + return null; + } + // Using LinkedHashMap preserves the ordering, which is helpful for tests and debugging + Map> taskMap = new LinkedHashMap<>(); + Optional.ofNullable(revokedConnectorIds) + .orElseGet(Collections::emptyList) + .stream() + .distinct() + .forEachOrdered(connectorId -> { + Collection connectorTasks = + taskMap.computeIfAbsent(connectorId, v -> new ArrayList<>()); + connectorTasks.add(CONNECTOR_TASK); + }); + + Optional.ofNullable(revokedTaskIds) + .orElseGet(Collections::emptyList) + .forEach(taskId -> { + String connectorId = taskId.connector(); + Collection connectorTasks = + taskMap.computeIfAbsent(connectorId, v -> new ArrayList<>()); + connectorTasks.add(taskId.task()); + }); + return taskMap; + } + + /** + * Return the {@code Struct} that corresponds to this assignment. + * + * @return the assignment struct + */ + public Struct toStruct() { + Collection assigned = taskAssignments(asMap()); + Collection revoked = taskAssignments(revokedAsMap()); + return new Struct(ASSIGNMENT_V1) + .set(ERROR_KEY_NAME, error()) + .set(LEADER_KEY_NAME, leader()) + .set(LEADER_URL_KEY_NAME, leaderUrl()) + .set(CONFIG_OFFSET_KEY_NAME, offset()) + .set(ASSIGNMENT_KEY_NAME, assigned != null ? assigned.toArray() : null) + .set(REVOKED_KEY_NAME, revoked != null ? revoked.toArray() : null) + .set(SCHEDULED_DELAY_KEY_NAME, delay); + } + + /** + * Given a {@code Struct} that encodes an assignment return the assignment object. + * + * @param struct a struct representing an assignment + * @return the assignment + */ + public static ExtendedAssignment fromStruct(short version, Struct struct) { + return struct == null + ? null + : new ExtendedAssignment( + version, + struct.getShort(ERROR_KEY_NAME), + struct.getString(LEADER_KEY_NAME), + struct.getString(LEADER_URL_KEY_NAME), + struct.getLong(CONFIG_OFFSET_KEY_NAME), + extractConnectors(struct, ASSIGNMENT_KEY_NAME), + extractTasks(struct, ASSIGNMENT_KEY_NAME), + extractConnectors(struct, REVOKED_KEY_NAME), + extractTasks(struct, REVOKED_KEY_NAME), + struct.getInt(SCHEDULED_DELAY_KEY_NAME)); + } + + private static Collection taskAssignments(Map> assignments) { + return assignments == null + ? 
null + : assignments.entrySet().stream() + .map(connectorEntry -> { + Struct taskAssignment = new Struct(CONNECTOR_ASSIGNMENT_V1); + taskAssignment.set(CONNECTOR_KEY_NAME, connectorEntry.getKey()); + taskAssignment.set(TASKS_KEY_NAME, connectorEntry.getValue().toArray()); + return taskAssignment; + }).collect(Collectors.toList()); + } + + private static Collection extractConnectors(Struct struct, String key) { + assert REVOKED_KEY_NAME.equals(key) || ASSIGNMENT_KEY_NAME.equals(key); + + Object[] connectors = struct.getArray(key); + if (connectors == null) { + return Collections.emptyList(); + } + List connectorIds = new ArrayList<>(); + for (Object structObj : connectors) { + Struct assignment = (Struct) structObj; + String connector = assignment.getString(CONNECTOR_KEY_NAME); + for (Object taskIdObj : assignment.getArray(TASKS_KEY_NAME)) { + Integer taskId = (Integer) taskIdObj; + if (taskId == CONNECTOR_TASK) { + connectorIds.add(connector); + } + } + } + return connectorIds; + } + + private static Collection extractTasks(Struct struct, String key) { + assert REVOKED_KEY_NAME.equals(key) || ASSIGNMENT_KEY_NAME.equals(key); + + Object[] tasks = struct.getArray(key); + if (tasks == null) { + return Collections.emptyList(); + } + List tasksIds = new ArrayList<>(); + for (Object structObj : tasks) { + Struct assignment = (Struct) structObj; + String connector = assignment.getString(CONNECTOR_KEY_NAME); + for (Object taskIdObj : assignment.getArray(TASKS_KEY_NAME)) { + Integer taskId = (Integer) taskIdObj; + if (taskId != CONNECTOR_TASK) { + tasksIds.add(new ConnectorTaskId(connector, taskId)); + } + } + } + return tasksIds; + } + +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/distributed/ExtendedWorkerState.java b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/distributed/ExtendedWorkerState.java new file mode 100644 index 0000000..663979b --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/distributed/ExtendedWorkerState.java @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.runtime.distributed; + +/** + * A class that captures the deserialized form of a worker's metadata. + */ +public class ExtendedWorkerState extends ConnectProtocol.WorkerState { + private final ExtendedAssignment assignment; + + public ExtendedWorkerState(String url, long offset, ExtendedAssignment assignment) { + super(url, offset); + this.assignment = assignment != null ? assignment : ExtendedAssignment.empty(); + } + + /** + * This method returns which was the assignment of connectors and tasks on a worker at the + * moment that its state was captured by this class. 
+ * + * @return the assignment of connectors and tasks + */ + public ExtendedAssignment assignment() { + return assignment; + } + + @Override + public String toString() { + return "WorkerState{" + + "url='" + url() + '\'' + + ", offset=" + offset() + + ", " + assignment + + '}'; + } +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/distributed/IncrementalCooperativeAssignor.java b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/distributed/IncrementalCooperativeAssignor.java new file mode 100644 index 0000000..f7fa55d --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/distributed/IncrementalCooperativeAssignor.java @@ -0,0 +1,766 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.runtime.distributed; + +import java.util.Map.Entry; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.connect.runtime.distributed.WorkerCoordinator.ConnectorsAndTasks; +import org.apache.kafka.connect.runtime.distributed.WorkerCoordinator.WorkerLoad; +import org.apache.kafka.connect.util.ConnectorTaskId; +import org.slf4j.Logger; + +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Iterator; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Set; +import java.util.TreeSet; +import java.util.function.Function; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +import static org.apache.kafka.common.message.JoinGroupResponseData.JoinGroupResponseMember; +import static org.apache.kafka.connect.runtime.distributed.ConnectProtocol.Assignment; +import static org.apache.kafka.connect.runtime.distributed.IncrementalCooperativeConnectProtocol.CONNECT_PROTOCOL_V1; +import static org.apache.kafka.connect.runtime.distributed.IncrementalCooperativeConnectProtocol.CONNECT_PROTOCOL_V2; +import static org.apache.kafka.connect.runtime.distributed.WorkerCoordinator.LeaderState; + +/** + * An assignor that computes a distribution of connectors and tasks according to the incremental + * cooperative strategy for rebalancing. {@see + * https://cwiki.apache.org/confluence/display/KAFKA/KIP-415%3A+Incremental+Cooperative + * +Rebalancing+in+Kafka+Connect} for a description of the assignment policy. + * + * Note that this class is NOT thread-safe. 
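+ * <p>At a high level, the assignor revokes connectors and tasks only when the distribution would
+ * otherwise become unbalanced, and it may postpone reassignment of work lost with a departed worker
+ * for up to the configured scheduled rebalance max delay, giving that worker a chance to rejoin and
+ * reclaim its previous assignment.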
+ */ +public class IncrementalCooperativeAssignor implements ConnectAssignor { + private final Logger log; + private final Time time; + private final int maxDelay; + private ConnectorsAndTasks previousAssignment; + private ConnectorsAndTasks previousRevocation; + private boolean canRevoke; + // visible for testing + protected final Set candidateWorkersForReassignment; + protected long scheduledRebalance; + protected int delay; + protected int previousGenerationId; + protected Set previousMembers; + + public IncrementalCooperativeAssignor(LogContext logContext, Time time, int maxDelay) { + this.log = logContext.logger(IncrementalCooperativeAssignor.class); + this.time = time; + this.maxDelay = maxDelay; + this.previousAssignment = ConnectorsAndTasks.EMPTY; + this.previousRevocation = new ConnectorsAndTasks.Builder().build(); + this.canRevoke = true; + this.scheduledRebalance = 0; + this.candidateWorkersForReassignment = new LinkedHashSet<>(); + this.delay = 0; + this.previousGenerationId = -1; + this.previousMembers = Collections.emptySet(); + } + + @Override + public Map performAssignment(String leaderId, String protocol, + List allMemberMetadata, + WorkerCoordinator coordinator) { + log.debug("Performing task assignment"); + + Map memberConfigs = new HashMap<>(); + for (JoinGroupResponseMember member : allMemberMetadata) { + memberConfigs.put( + member.memberId(), + IncrementalCooperativeConnectProtocol.deserializeMetadata(ByteBuffer.wrap(member.metadata()))); + } + log.debug("Member configs: {}", memberConfigs); + + // The new config offset is the maximum seen by any member. We always perform assignment using this offset, + // even if some members have fallen behind. The config offset used to generate the assignment is included in + // the response so members that have fallen behind will not use the assignment until they have caught up. + long maxOffset = memberConfigs.values().stream().map(ExtendedWorkerState::offset).max(Long::compare).get(); + log.debug("Max config offset root: {}, local snapshot config offsets root: {}", + maxOffset, coordinator.configSnapshot().offset()); + + short protocolVersion = memberConfigs.values().stream() + .allMatch(state -> state.assignment().version() == CONNECT_PROTOCOL_V2) + ? CONNECT_PROTOCOL_V2 + : CONNECT_PROTOCOL_V1; + + Long leaderOffset = ensureLeaderConfig(maxOffset, coordinator); + if (leaderOffset == null) { + Map assignments = fillAssignments( + memberConfigs.keySet(), Assignment.CONFIG_MISMATCH, + leaderId, memberConfigs.get(leaderId).url(), maxOffset, Collections.emptyMap(), + Collections.emptyMap(), Collections.emptyMap(), 0, protocolVersion); + return serializeAssignments(assignments); + } + return performTaskAssignment(leaderId, leaderOffset, memberConfigs, coordinator, protocolVersion); + } + + private Long ensureLeaderConfig(long maxOffset, WorkerCoordinator coordinator) { + // If this leader is behind some other members, we can't do assignment + if (coordinator.configSnapshot().offset() < maxOffset) { + // We might be able to take a new snapshot to catch up immediately and avoid another round of syncing here. + // Alternatively, if this node has already passed the maximum reported by any other member of the group, it + // is also safe to use this newer state. + ClusterConfigState updatedSnapshot = coordinator.configFreshSnapshot(); + if (updatedSnapshot.offset() < maxOffset) { + log.info("Was selected to perform assignments, but do not have latest config found in sync request. 
" + + "Returning an empty configuration to trigger re-sync."); + return null; + } else { + coordinator.configSnapshot(updatedSnapshot); + return updatedSnapshot.offset(); + } + } + return maxOffset; + } + + /** + * Performs task assignment based on the incremental cooperative connect protocol. + * Read more on the design and implementation in: + * {@see https://cwiki.apache.org/confluence/display/KAFKA/KIP-415%3A+Incremental+Cooperative+Rebalancing+in+Kafka+Connect} + * + * @param leaderId the ID of the group leader + * @param maxOffset the latest known offset of the configuration topic + * @param memberConfigs the metadata of all the members of the group as gather in the current + * round of rebalancing + * @param coordinator the worker coordinator instance that provide the configuration snapshot + * and get assigned the leader state during this assignment + * @param protocolVersion the Connect subprotocol version + * @return the serialized assignment of tasks to the whole group, including assigned or + * revoked tasks + */ + protected Map performTaskAssignment(String leaderId, long maxOffset, + Map memberConfigs, + WorkerCoordinator coordinator, short protocolVersion) { + log.debug("Performing task assignment during generation: {} with memberId: {}", + coordinator.generationId(), coordinator.memberId()); + + // Base set: The previous assignment of connectors-and-tasks is a standalone snapshot that + // can be used to calculate derived sets + log.debug("Previous assignments: {}", previousAssignment); + int lastCompletedGenerationId = coordinator.lastCompletedGenerationId(); + if (previousGenerationId != lastCompletedGenerationId) { + log.debug("Clearing the view of previous assignments due to generation mismatch between " + + "previous generation ID {} and last completed generation ID {}. This can " + + "happen if the leader fails to sync the assignment within a rebalancing round. " + + "The following view of previous assignments might be outdated and will be " + + "ignored by the leader in the current computation of new assignments. " + + "Possibly outdated previous assignments: {}", + previousGenerationId, lastCompletedGenerationId, previousAssignment); + this.previousAssignment = ConnectorsAndTasks.EMPTY; + } + + ClusterConfigState snapshot = coordinator.configSnapshot(); + Set configuredConnectors = new TreeSet<>(snapshot.connectors()); + Set configuredTasks = configuredConnectors.stream() + .flatMap(c -> snapshot.tasks(c).stream()) + .collect(Collectors.toSet()); + + // Base set: The set of configured connectors-and-tasks is a standalone snapshot that can + // be used to calculate derived sets + ConnectorsAndTasks configured = new ConnectorsAndTasks.Builder() + .with(configuredConnectors, configuredTasks).build(); + log.debug("Configured assignments: {}", configured); + + // Base set: The set of active connectors-and-tasks is a standalone snapshot that can be + // used to calculate derived sets + ConnectorsAndTasks activeAssignments = assignment(memberConfigs); + log.debug("Active assignments: {}", activeAssignments); + + // This means that a previous revocation did not take effect. 
In this case, reset + // appropriately and be ready to re-apply revocation of tasks + if (!previousRevocation.isEmpty()) { + if (previousRevocation.connectors().stream().anyMatch(c -> activeAssignments.connectors().contains(c)) + || previousRevocation.tasks().stream().anyMatch(t -> activeAssignments.tasks().contains(t))) { + previousAssignment = activeAssignments; + canRevoke = true; + } + previousRevocation.connectors().clear(); + previousRevocation.tasks().clear(); + } + + // Derived set: The set of deleted connectors-and-tasks is a derived set from the set + // difference of previous - configured + ConnectorsAndTasks deleted = diff(previousAssignment, configured); + log.debug("Deleted assignments: {}", deleted); + + // Derived set: The set of remaining active connectors-and-tasks is a derived set from the + // set difference of active - deleted + ConnectorsAndTasks remainingActive = diff(activeAssignments, deleted); + log.debug("Remaining (excluding deleted) active assignments: {}", remainingActive); + + // Derived set: The set of lost or unaccounted connectors-and-tasks is a derived set from + // the set difference of previous - active - deleted + ConnectorsAndTasks lostAssignments = diff(previousAssignment, activeAssignments, deleted); + log.debug("Lost assignments: {}", lostAssignments); + + // Derived set: The set of new connectors-and-tasks is a derived set from the set + // difference of configured - previous - active + ConnectorsAndTasks newSubmissions = diff(configured, previousAssignment, activeAssignments); + log.debug("New assignments: {}", newSubmissions); + + // A collection of the complete assignment + List completeWorkerAssignment = workerAssignment(memberConfigs, ConnectorsAndTasks.EMPTY); + log.debug("Complete (ignoring deletions) worker assignments: {}", completeWorkerAssignment); + + // Per worker connector assignments without removing deleted connectors yet + Map> connectorAssignments = + completeWorkerAssignment.stream().collect(Collectors.toMap(WorkerLoad::worker, WorkerLoad::connectors)); + log.debug("Complete (ignoring deletions) connector assignments: {}", connectorAssignments); + + // Per worker task assignments without removing deleted connectors yet + Map> taskAssignments = + completeWorkerAssignment.stream().collect(Collectors.toMap(WorkerLoad::worker, WorkerLoad::tasks)); + log.debug("Complete (ignoring deletions) task assignments: {}", taskAssignments); + + // A collection of the current assignment excluding the connectors-and-tasks to be deleted + List currentWorkerAssignment = workerAssignment(memberConfigs, deleted); + + Map toRevoke = computeDeleted(deleted, connectorAssignments, taskAssignments); + log.debug("Connector and task to delete assignments: {}", toRevoke); + + // Revoking redundant connectors/tasks if the workers have duplicate assignments + toRevoke.putAll(computeDuplicatedAssignments(memberConfigs, connectorAssignments, taskAssignments)); + log.debug("Connector and task to revoke assignments (include duplicated assignments): {}", toRevoke); + + // Recompute the complete assignment excluding the deleted connectors-and-tasks + completeWorkerAssignment = workerAssignment(memberConfigs, deleted); + connectorAssignments = + completeWorkerAssignment.stream().collect(Collectors.toMap(WorkerLoad::worker, WorkerLoad::connectors)); + taskAssignments = + completeWorkerAssignment.stream().collect(Collectors.toMap(WorkerLoad::worker, WorkerLoad::tasks)); + + handleLostAssignments(lostAssignments, newSubmissions, completeWorkerAssignment, 
memberConfigs); + + // Do not revoke resources for re-assignment while a delayed rebalance is active + // Also we do not revoke in two consecutive rebalances by the same leader + canRevoke = delay == 0 && canRevoke; + + // Compute the connectors-and-tasks to be revoked for load balancing without taking into + // account the deleted ones. + log.debug("Can leader revoke tasks in this assignment? {} (delay: {})", canRevoke, delay); + if (canRevoke) { + Map toExplicitlyRevoke = + performTaskRevocation(activeAssignments, currentWorkerAssignment); + + log.debug("Connector and task to revoke assignments: {}", toRevoke); + + toExplicitlyRevoke.forEach( + (worker, assignment) -> { + ConnectorsAndTasks existing = toRevoke.computeIfAbsent( + worker, + v -> new ConnectorsAndTasks.Builder().build()); + existing.connectors().addAll(assignment.connectors()); + existing.tasks().addAll(assignment.tasks()); + } + ); + canRevoke = toExplicitlyRevoke.size() == 0; + } else { + canRevoke = delay == 0; + } + + assignConnectors(completeWorkerAssignment, newSubmissions.connectors()); + assignTasks(completeWorkerAssignment, newSubmissions.tasks()); + log.debug("Current complete assignments: {}", currentWorkerAssignment); + log.debug("New complete assignments: {}", completeWorkerAssignment); + + Map> currentConnectorAssignments = + currentWorkerAssignment.stream().collect(Collectors.toMap(WorkerLoad::worker, WorkerLoad::connectors)); + Map> currentTaskAssignments = + currentWorkerAssignment.stream().collect(Collectors.toMap(WorkerLoad::worker, WorkerLoad::tasks)); + Map> incrementalConnectorAssignments = + diff(connectorAssignments, currentConnectorAssignments); + Map> incrementalTaskAssignments = + diff(taskAssignments, currentTaskAssignments); + + log.debug("Incremental connector assignments: {}", incrementalConnectorAssignments); + log.debug("Incremental task assignments: {}", incrementalTaskAssignments); + + coordinator.leaderState(new LeaderState(memberConfigs, connectorAssignments, taskAssignments)); + + Map assignments = + fillAssignments(memberConfigs.keySet(), Assignment.NO_ERROR, leaderId, + memberConfigs.get(leaderId).url(), maxOffset, incrementalConnectorAssignments, + incrementalTaskAssignments, toRevoke, delay, protocolVersion); + previousAssignment = computePreviousAssignment(toRevoke, connectorAssignments, taskAssignments, lostAssignments); + previousGenerationId = coordinator.generationId(); + previousMembers = memberConfigs.keySet(); + log.debug("Actual assignments: {}", assignments); + return serializeAssignments(assignments); + } + + private Map computeDeleted(ConnectorsAndTasks deleted, + Map> connectorAssignments, + Map> taskAssignments) { + // Connector to worker reverse lookup map + Map connectorOwners = WorkerCoordinator.invertAssignment(connectorAssignments); + // Task to worker reverse lookup map + Map taskOwners = WorkerCoordinator.invertAssignment(taskAssignments); + + Map toRevoke = new HashMap<>(); + // Add the connectors that have been deleted to the revoked set + deleted.connectors().forEach(c -> + toRevoke.computeIfAbsent( + connectorOwners.get(c), + v -> new ConnectorsAndTasks.Builder().build() + ).connectors().add(c)); + // Add the tasks that have been deleted to the revoked set + deleted.tasks().forEach(t -> + toRevoke.computeIfAbsent( + taskOwners.get(t), + v -> new ConnectorsAndTasks.Builder().build() + ).tasks().add(t)); + log.debug("Connectors and tasks to delete assignments: {}", toRevoke); + return toRevoke; + } + + private ConnectorsAndTasks computePreviousAssignment(Map 
toRevoke, + Map> connectorAssignments, + Map> taskAssignments, + ConnectorsAndTasks lostAssignments) { + ConnectorsAndTasks previousAssignment = new ConnectorsAndTasks.Builder().with( + connectorAssignments.values().stream().flatMap(Collection::stream).collect(Collectors.toSet()), + taskAssignments.values() .stream() .flatMap(Collection::stream).collect(Collectors.toSet())) + .build(); + + for (ConnectorsAndTasks revoked : toRevoke.values()) { + previousAssignment.connectors().removeAll(revoked.connectors()); + previousAssignment.tasks().removeAll(revoked.tasks()); + previousRevocation.connectors().addAll(revoked.connectors()); + previousRevocation.tasks().addAll(revoked.tasks()); + } + + // Depends on the previous assignment's collections being sets at the moment. + // TODO: make it independent + previousAssignment.connectors().addAll(lostAssignments.connectors()); + previousAssignment.tasks().addAll(lostAssignments.tasks()); + + return previousAssignment; + } + + private ConnectorsAndTasks duplicatedAssignments(Map memberConfigs) { + Set connectors = memberConfigs.entrySet().stream() + .flatMap(memberConfig -> memberConfig.getValue().assignment().connectors().stream()) + .collect(Collectors.groupingBy(Function.identity(), Collectors.counting())) + .entrySet().stream() + .filter(entry -> entry.getValue() > 1L) + .map(Entry::getKey) + .collect(Collectors.toSet()); + + Set tasks = memberConfigs.values().stream() + .flatMap(state -> state.assignment().tasks().stream()) + .collect(Collectors.groupingBy(Function.identity(), Collectors.counting())) + .entrySet().stream() + .filter(entry -> entry.getValue() > 1L) + .map(Entry::getKey) + .collect(Collectors.toSet()); + return new ConnectorsAndTasks.Builder().with(connectors, tasks).build(); + } + + private Map computeDuplicatedAssignments(Map memberConfigs, + Map> connectorAssignments, + Map> taskAssignment) { + ConnectorsAndTasks duplicatedAssignments = duplicatedAssignments(memberConfigs); + log.debug("Duplicated assignments: {}", duplicatedAssignments); + + Map toRevoke = new HashMap<>(); + if (!duplicatedAssignments.connectors().isEmpty()) { + connectorAssignments.entrySet().stream() + .forEach(entry -> { + Set duplicatedConnectors = new HashSet<>(duplicatedAssignments.connectors()); + duplicatedConnectors.retainAll(entry.getValue()); + if (!duplicatedConnectors.isEmpty()) { + toRevoke.computeIfAbsent( + entry.getKey(), + v -> new ConnectorsAndTasks.Builder().build() + ).connectors().addAll(duplicatedConnectors); + } + }); + } + if (!duplicatedAssignments.tasks().isEmpty()) { + taskAssignment.entrySet().stream() + .forEach(entry -> { + Set duplicatedTasks = new HashSet<>(duplicatedAssignments.tasks()); + duplicatedTasks.retainAll(entry.getValue()); + if (!duplicatedTasks.isEmpty()) { + toRevoke.computeIfAbsent( + entry.getKey(), + v -> new ConnectorsAndTasks.Builder().build() + ).tasks().addAll(duplicatedTasks); + } + }); + } + return toRevoke; + } + + // visible for testing + protected void handleLostAssignments(ConnectorsAndTasks lostAssignments, + ConnectorsAndTasks newSubmissions, + List completeWorkerAssignment, + Map memberConfigs) { + if (lostAssignments.isEmpty()) { + resetDelay(); + return; + } + + final long now = time.milliseconds(); + log.debug("Found the following connectors and tasks missing from previous assignments: " + + lostAssignments); + + if (scheduledRebalance <= 0 && memberConfigs.keySet().containsAll(previousMembers)) { + log.debug("No worker seems to have departed the group during the rebalance. 
The " + + "missing assignments that the leader is detecting are probably due to some " + + "workers failing to receive the new assignments in the previous rebalance. " + + "Will reassign missing tasks as new tasks"); + newSubmissions.connectors().addAll(lostAssignments.connectors()); + newSubmissions.tasks().addAll(lostAssignments.tasks()); + return; + } + + if (scheduledRebalance > 0 && now >= scheduledRebalance) { + // delayed rebalance expired and it's time to assign resources + log.debug("Delayed rebalance expired. Reassigning lost tasks"); + List candidateWorkerLoad = Collections.emptyList(); + if (!candidateWorkersForReassignment.isEmpty()) { + candidateWorkerLoad = pickCandidateWorkerForReassignment(completeWorkerAssignment); + } + + if (!candidateWorkerLoad.isEmpty()) { + log.debug("Assigning lost tasks to {} candidate workers: {}", + candidateWorkerLoad.size(), + candidateWorkerLoad.stream().map(WorkerLoad::worker).collect(Collectors.joining(","))); + Iterator candidateWorkerIterator = candidateWorkerLoad.iterator(); + for (String connector : lostAssignments.connectors()) { + // Loop over the candidate workers as many times as it takes + if (!candidateWorkerIterator.hasNext()) { + candidateWorkerIterator = candidateWorkerLoad.iterator(); + } + WorkerLoad worker = candidateWorkerIterator.next(); + log.debug("Assigning connector id {} to member {}", connector, worker.worker()); + worker.assign(connector); + } + candidateWorkerIterator = candidateWorkerLoad.iterator(); + for (ConnectorTaskId task : lostAssignments.tasks()) { + if (!candidateWorkerIterator.hasNext()) { + candidateWorkerIterator = candidateWorkerLoad.iterator(); + } + WorkerLoad worker = candidateWorkerIterator.next(); + log.debug("Assigning task id {} to member {}", task, worker.worker()); + worker.assign(task); + } + } else { + log.debug("No single candidate worker was found to assign lost tasks. Treating lost tasks as new tasks"); + newSubmissions.connectors().addAll(lostAssignments.connectors()); + newSubmissions.tasks().addAll(lostAssignments.tasks()); + } + resetDelay(); + } else { + candidateWorkersForReassignment + .addAll(candidateWorkersForReassignment(completeWorkerAssignment)); + if (now < scheduledRebalance) { + // a delayed rebalance is in progress, but it's not yet time to reassign + // unaccounted resources + delay = calculateDelay(now); + log.debug("Delayed rebalance in progress. Task reassignment is postponed. New computed rebalance delay: {}", delay); + } else { + // This means scheduledRebalance == 0 + // We could also also extract the current minimum delay from the group, to make + // independent of consecutive leader failures, but this optimization is skipped + // at the moment + delay = maxDelay; + log.debug("Resetting rebalance delay to the max: {}. 
scheduledRebalance: {} now: {} diff scheduledRebalance - now: {}", + delay, scheduledRebalance, now, scheduledRebalance - now); + } + scheduledRebalance = now + delay; + } + } + + private void resetDelay() { + candidateWorkersForReassignment.clear(); + scheduledRebalance = 0; + if (delay != 0) { + log.debug("Resetting delay from previous value: {} to 0", delay); + } + delay = 0; + } + + private Set candidateWorkersForReassignment(List completeWorkerAssignment) { + return completeWorkerAssignment.stream() + .filter(WorkerLoad::isEmpty) + .map(WorkerLoad::worker) + .collect(Collectors.toSet()); + } + + private List pickCandidateWorkerForReassignment(List completeWorkerAssignment) { + Map activeWorkers = completeWorkerAssignment.stream() + .collect(Collectors.toMap(WorkerLoad::worker, Function.identity())); + return candidateWorkersForReassignment.stream() + .map(activeWorkers::get) + .filter(Objects::nonNull) + .collect(Collectors.toList()); + } + + /** + * Task revocation is based on an rough estimation of the lower average number of tasks before + * and after new workers join the group. If no new workers join, no revocation takes place. + * Based on this estimation, tasks are revoked until the new floor average is reached for + * each existing worker. The revoked tasks, once assigned to the new workers will maintain + * a balanced load among the group. + * + * @param activeAssignments + * @param completeWorkerAssignment + * @return + */ + private Map performTaskRevocation(ConnectorsAndTasks activeAssignments, + Collection completeWorkerAssignment) { + int totalActiveConnectorsNum = activeAssignments.connectors().size(); + int totalActiveTasksNum = activeAssignments.tasks().size(); + Collection existingWorkers = completeWorkerAssignment.stream() + .filter(wl -> wl.size() > 0) + .collect(Collectors.toList()); + int existingWorkersNum = existingWorkers.size(); + int totalWorkersNum = completeWorkerAssignment.size(); + int newWorkersNum = totalWorkersNum - existingWorkersNum; + + if (log.isDebugEnabled()) { + completeWorkerAssignment.forEach(wl -> log.debug( + "Per worker current load size; worker: {} connectors: {} tasks: {}", + wl.worker(), wl.connectorsSize(), wl.tasksSize())); + } + + Map revoking = new HashMap<>(); + // If there are no new workers, or no existing workers to revoke tasks from return early + // after logging the status + if (!(newWorkersNum > 0 && existingWorkersNum > 0)) { + log.debug("No task revocation required; workers with existing load: {} workers with " + + "no load {} total workers {}", + existingWorkersNum, newWorkersNum, totalWorkersNum); + // This is intentionally empty but mutable, because the map is used to include deleted + // connectors and tasks as well + return revoking; + } + + log.debug("Task revocation is required; workers with existing load: {} workers with " + + "no load {} total workers {}", + existingWorkersNum, newWorkersNum, totalWorkersNum); + + // We have at least one worker assignment (the leader itself) so totalWorkersNum can't be 0 + log.debug("Previous rounded down (floor) average number of connectors per worker {}", totalActiveConnectorsNum / existingWorkersNum); + int floorConnectors = totalActiveConnectorsNum / totalWorkersNum; + int ceilConnectors = floorConnectors + ((totalActiveConnectorsNum % totalWorkersNum == 0) ? 
0 : 1); + log.debug("New average number of connectors per worker rounded down (floor) {} and rounded up (ceil) {}", floorConnectors, ceilConnectors); + + + log.debug("Previous rounded down (floor) average number of tasks per worker {}", totalActiveTasksNum / existingWorkersNum); + int floorTasks = totalActiveTasksNum / totalWorkersNum; + int ceilTasks = floorTasks + ((totalActiveTasksNum % totalWorkersNum == 0) ? 0 : 1); + log.debug("New average number of tasks per worker rounded down (floor) {} and rounded up (ceil) {}", floorTasks, ceilTasks); + int numToRevoke; + + for (WorkerLoad existing : existingWorkers) { + Iterator connectors = existing.connectors().iterator(); + numToRevoke = existing.connectorsSize() - ceilConnectors; + for (int i = existing.connectorsSize(); i > floorConnectors && numToRevoke > 0; --i, --numToRevoke) { + ConnectorsAndTasks resources = revoking.computeIfAbsent( + existing.worker(), + w -> new ConnectorsAndTasks.Builder().build()); + resources.connectors().add(connectors.next()); + } + } + + for (WorkerLoad existing : existingWorkers) { + Iterator tasks = existing.tasks().iterator(); + numToRevoke = existing.tasksSize() - ceilTasks; + log.debug("Tasks on worker {} is higher than ceiling, so revoking {} tasks", existing, numToRevoke); + for (int i = existing.tasksSize(); i > floorTasks && numToRevoke > 0; --i, --numToRevoke) { + ConnectorsAndTasks resources = revoking.computeIfAbsent( + existing.worker(), + w -> new ConnectorsAndTasks.Builder().build()); + resources.tasks().add(tasks.next()); + } + } + + return revoking; + } + + private Map fillAssignments(Collection members, short error, + String leaderId, String leaderUrl, long maxOffset, + Map> connectorAssignments, + Map> taskAssignments, + Map revoked, + int delay, short protocolVersion) { + Map groupAssignment = new HashMap<>(); + for (String member : members) { + Collection connectorsToStart = connectorAssignments.getOrDefault(member, Collections.emptyList()); + Collection tasksToStart = taskAssignments.getOrDefault(member, Collections.emptyList()); + Collection connectorsToStop = revoked.getOrDefault(member, ConnectorsAndTasks.EMPTY).connectors(); + Collection tasksToStop = revoked.getOrDefault(member, ConnectorsAndTasks.EMPTY).tasks(); + ExtendedAssignment assignment = + new ExtendedAssignment(protocolVersion, error, leaderId, leaderUrl, maxOffset, + connectorsToStart, tasksToStart, connectorsToStop, tasksToStop, delay); + log.debug("Filling assignment: {} -> {}", member, assignment); + groupAssignment.put(member, assignment); + } + log.debug("Finished assignment"); + return groupAssignment; + } + + /** + * From a map of workers to assignment object generate the equivalent map of workers to byte + * buffers of serialized assignments. + * + * @param assignments the map of worker assignments + * @return the serialized map of assignments to workers + */ + protected Map serializeAssignments(Map assignments) { + return assignments.entrySet() + .stream() + .collect(Collectors.toMap( + Map.Entry::getKey, + e -> IncrementalCooperativeConnectProtocol.serializeAssignment(e.getValue()))); + } + + private static ConnectorsAndTasks diff(ConnectorsAndTasks base, + ConnectorsAndTasks... 
toSubtract) { + Collection connectors = new TreeSet<>(base.connectors()); + Collection tasks = new TreeSet<>(base.tasks()); + for (ConnectorsAndTasks sub : toSubtract) { + connectors.removeAll(sub.connectors()); + tasks.removeAll(sub.tasks()); + } + return new ConnectorsAndTasks.Builder().with(connectors, tasks).build(); + } + + private static Map> diff(Map> base, + Map> toSubtract) { + Map> incremental = new HashMap<>(); + for (Map.Entry> entry : base.entrySet()) { + List values = new ArrayList<>(entry.getValue()); + values.removeAll(toSubtract.get(entry.getKey())); + incremental.put(entry.getKey(), values); + } + return incremental; + } + + private ConnectorsAndTasks assignment(Map memberConfigs) { + log.debug("Received assignments: {}", memberConfigs); + Set connectors = memberConfigs.values() + .stream() + .flatMap(state -> state.assignment().connectors().stream()) + .collect(Collectors.toSet()); + Set tasks = memberConfigs.values() + .stream() + .flatMap(state -> state.assignment().tasks().stream()) + .collect(Collectors.toSet()); + return new ConnectorsAndTasks.Builder().with(connectors, tasks).build(); + } + + private int calculateDelay(long now) { + long diff = scheduledRebalance - now; + return diff > 0 ? (int) Math.min(diff, maxDelay) : 0; + } + + /** + * Perform a round-robin assignment of connectors to workers with existing worker load. This + * assignment tries to balance the load between workers, by assigning connectors to workers + * that have equal load, starting with the least loaded workers. + * + * @param workerAssignment the current worker assignment; assigned connectors are added to this list + * @param connectors the connectors to be assigned + */ + protected void assignConnectors(List workerAssignment, Collection connectors) { + workerAssignment.sort(WorkerLoad.connectorComparator()); + WorkerLoad first = workerAssignment.get(0); + + Iterator load = connectors.iterator(); + while (load.hasNext()) { + int firstLoad = first.connectorsSize(); + int upTo = IntStream.range(0, workerAssignment.size()) + .filter(i -> workerAssignment.get(i).connectorsSize() > firstLoad) + .findFirst() + .orElse(workerAssignment.size()); + for (WorkerLoad worker : workerAssignment.subList(0, upTo)) { + String connector = load.next(); + log.debug("Assigning connector {} to {}", connector, worker.worker()); + worker.assign(connector); + if (!load.hasNext()) { + break; + } + } + } + } + + /** + * Perform a round-robin assignment of tasks to workers with existing worker load. This + * assignment tries to balance the load between workers, by assigning tasks to workers that + * have equal load, starting with the least loaded workers. 
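+     * For illustration (hypothetical load figures): if worker w1 currently runs 2 tasks and
+     * worker w2 runs 4, the first two new tasks go to w1, bringing both workers to 4 tasks;
+     * any remaining tasks are then handed out round-robin across w1 and w2.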
+ * + * @param workerAssignment the current worker assignment; assigned tasks are added to this list + * @param tasks the tasks to be assigned + */ + protected void assignTasks(List workerAssignment, Collection tasks) { + workerAssignment.sort(WorkerLoad.taskComparator()); + WorkerLoad first = workerAssignment.get(0); + + Iterator load = tasks.iterator(); + while (load.hasNext()) { + int firstLoad = first.tasksSize(); + int upTo = IntStream.range(0, workerAssignment.size()) + .filter(i -> workerAssignment.get(i).tasksSize() > firstLoad) + .findFirst() + .orElse(workerAssignment.size()); + for (WorkerLoad worker : workerAssignment.subList(0, upTo)) { + ConnectorTaskId task = load.next(); + log.debug("Assigning task {} to {}", task, worker.worker()); + worker.assign(task); + if (!load.hasNext()) { + break; + } + } + } + } + + private static List workerAssignment(Map memberConfigs, + ConnectorsAndTasks toExclude) { + ConnectorsAndTasks ignore = new ConnectorsAndTasks.Builder() + .with(new HashSet<>(toExclude.connectors()), new HashSet<>(toExclude.tasks())) + .build(); + + return memberConfigs.entrySet().stream() + .map(e -> new WorkerLoad.Builder(e.getKey()).with( + e.getValue().assignment().connectors().stream() + .filter(v -> !ignore.connectors().contains(v)) + .collect(Collectors.toList()), + e.getValue().assignment().tasks().stream() + .filter(v -> !ignore.tasks().contains(v)) + .collect(Collectors.toList()) + ).build() + ).collect(Collectors.toList()); + } + +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/distributed/IncrementalCooperativeConnectProtocol.java b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/distributed/IncrementalCooperativeConnectProtocol.java new file mode 100644 index 0000000..6bcf9be --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/distributed/IncrementalCooperativeConnectProtocol.java @@ -0,0 +1,274 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.connect.runtime.distributed; + +import org.apache.kafka.common.protocol.types.ArrayOf; +import org.apache.kafka.common.protocol.types.Field; +import org.apache.kafka.common.protocol.types.Schema; +import org.apache.kafka.common.protocol.types.SchemaException; +import org.apache.kafka.common.protocol.types.Struct; +import org.apache.kafka.common.protocol.types.Type; + +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.List; + +import static org.apache.kafka.common.message.JoinGroupRequestData.JoinGroupRequestProtocol; +import static org.apache.kafka.common.message.JoinGroupRequestData.JoinGroupRequestProtocolCollection; +import static org.apache.kafka.common.protocol.types.Type.NULLABLE_BYTES; +import static org.apache.kafka.connect.runtime.distributed.ConnectProtocol.ASSIGNMENT_KEY_NAME; +import static org.apache.kafka.connect.runtime.distributed.ConnectProtocol.CONFIG_OFFSET_KEY_NAME; +import static org.apache.kafka.connect.runtime.distributed.ConnectProtocol.CONFIG_STATE_V0; +import static org.apache.kafka.connect.runtime.distributed.ConnectProtocol.CONNECTOR_ASSIGNMENT_V0; +import static org.apache.kafka.connect.runtime.distributed.ConnectProtocol.CONNECT_PROTOCOL_HEADER_SCHEMA; +import static org.apache.kafka.connect.runtime.distributed.ConnectProtocol.CONNECT_PROTOCOL_V0; +import static org.apache.kafka.connect.runtime.distributed.ConnectProtocol.ERROR_KEY_NAME; +import static org.apache.kafka.connect.runtime.distributed.ConnectProtocol.LEADER_KEY_NAME; +import static org.apache.kafka.connect.runtime.distributed.ConnectProtocol.LEADER_URL_KEY_NAME; +import static org.apache.kafka.connect.runtime.distributed.ConnectProtocol.URL_KEY_NAME; +import static org.apache.kafka.connect.runtime.distributed.ConnectProtocol.VERSION_KEY_NAME; +import static org.apache.kafka.connect.runtime.distributed.ConnectProtocolCompatibility.COMPATIBLE; +import static org.apache.kafka.connect.runtime.distributed.ConnectProtocolCompatibility.EAGER; +import static org.apache.kafka.connect.runtime.distributed.ConnectProtocolCompatibility.SESSIONED; + + +/** + * This class implements a group protocol for Kafka Connect workers that support incremental and + * cooperative rebalancing of connectors and tasks. It includes the format of worker state used when + * joining the group and distributing assignments, and the format of assignments of connectors + * and tasks to workers. + */ +public class IncrementalCooperativeConnectProtocol { + public static final String ALLOCATION_KEY_NAME = "allocation"; + public static final String REVOKED_KEY_NAME = "revoked"; + public static final String SCHEDULED_DELAY_KEY_NAME = "delay"; + public static final short CONNECT_PROTOCOL_V1 = 1; + public static final short CONNECT_PROTOCOL_V2 = 2; + public static final boolean TOLERATE_MISSING_FIELDS_WITH_DEFAULTS = true; + + /** + * Connect Protocol Header V1: + *
            +     *   Version            => Int16
            +     * 
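+     * Header version 1 accompanies the incremental cooperative ("compatible") subprotocol entry;
+     * header version 2, defined below, accompanies the sessioned subprotocol entry when it is enabled.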
            + */ + private static final Struct CONNECT_PROTOCOL_HEADER_V1 = new Struct(CONNECT_PROTOCOL_HEADER_SCHEMA) + .set(VERSION_KEY_NAME, CONNECT_PROTOCOL_V1); + + /** + * Connect Protocol Header V2: + *
            +     *   Version            => Int16
            +     * 
            + * The V2 protocol is schematically identical to V1, but is used to signify that internal request + * verification and distribution of session keys is enabled (for more information, see KIP-507: + * https://cwiki.apache.org/confluence/display/KAFKA/KIP-507%3A+Securing+Internal+Connect+REST+Endpoints) + */ + private static final Struct CONNECT_PROTOCOL_HEADER_V2 = new Struct(CONNECT_PROTOCOL_HEADER_SCHEMA) + .set(VERSION_KEY_NAME, CONNECT_PROTOCOL_V2); + + + /** + * Config State V1: + *
            +     *   Url                => [String]
            +     *   ConfigOffset       => Int64
            +     * 
            + */ + public static final Schema CONFIG_STATE_V1 = CONFIG_STATE_V0; + + /** + * Allocation V1 + *
            +     *   Current Assignment => [Byte]
            +     * 
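+     * The raw bytes carry a complete serialized assignment, including its own embedded protocol
+     * version (or null when the worker has no current assignment), as produced by
+     * {@link #serializeAssignment(ExtendedAssignment)} and read back by
+     * {@link #deserializeAssignment(ByteBuffer)}.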
            + */ + public static final Schema ALLOCATION_V1 = new Schema( + TOLERATE_MISSING_FIELDS_WITH_DEFAULTS, + new Field(ALLOCATION_KEY_NAME, NULLABLE_BYTES, null, true, null)); + + /** + * + * Connector Assignment V1: + *
            +     *   Connector          => [String]
            +     *   Tasks              => [Int32]
            +     * 
+     *
+     *
+     * Assignments for each worker are a set of connectors and tasks. These are categorized by
+     * connector ID. A sentinel task ID (CONNECTOR_TASK) is used to indicate the connector itself
+     * (i.e. that the assignment includes responsibility for running the Connector instance in
+     * addition to any tasks it generates).
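+     * For example (illustrative connector name only), an assignment entry of the form
+     * {@code {Connector: "file-source", Tasks: [CONNECTOR_TASK, 0, 1]}} means that the worker
+     * runs the Connector instance of "file-source" as well as tasks 0 and 1 of that connector.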
            + */ + public static final Schema CONNECTOR_ASSIGNMENT_V1 = CONNECTOR_ASSIGNMENT_V0; + + /** + * Raw (non versioned) assignment V1: + *
            +     *   Error              => Int16
            +     *   Leader             => [String]
            +     *   LeaderUrl          => [String]
            +     *   ConfigOffset       => Int64
            +     *   Assignment         => [Connector Assignment]
            +     *   Revoked            => [Connector Assignment]
            +     *   ScheduledDelay     => Int32
            +     * 
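+     * ScheduledDelay carries the remaining rebalance delay, in milliseconds, that the leader
+     * applies while reassignment of lost connectors and tasks is postponed; it is 0 when no
+     * delayed rebalance is in progress (see IncrementalCooperativeAssignor#handleLostAssignments).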
            + */ + public static final Schema ASSIGNMENT_V1 = new Schema( + TOLERATE_MISSING_FIELDS_WITH_DEFAULTS, + new Field(ERROR_KEY_NAME, Type.INT16), + new Field(LEADER_KEY_NAME, Type.STRING), + new Field(LEADER_URL_KEY_NAME, Type.STRING), + new Field(CONFIG_OFFSET_KEY_NAME, Type.INT64), + new Field(ASSIGNMENT_KEY_NAME, ArrayOf.nullable(CONNECTOR_ASSIGNMENT_V1), null, true, null), + new Field(REVOKED_KEY_NAME, ArrayOf.nullable(CONNECTOR_ASSIGNMENT_V1), null, true, null), + new Field(SCHEDULED_DELAY_KEY_NAME, Type.INT32, null, 0)); + + /** + * The fields are serialized in sequence as follows: + * Subscription V1: + *
            +     *   Version            => Int16
            +     *   Url                => [String]
            +     *   ConfigOffset       => Int64
            +     *   Current Assignment => [Byte]
            +     * 
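+     * A minimal round-trip sketch (the worker URL and config offset are illustrative values only):
+     * <pre>{@code
+     * ExtendedWorkerState state = new ExtendedWorkerState("http://worker1:8083", 5L, ExtendedAssignment.empty());
+     * ByteBuffer serialized = serializeMetadata(state, false);
+     * ExtendedWorkerState roundTripped = deserializeMetadata(serialized);
+     * }</pre>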
            + */ + public static ByteBuffer serializeMetadata(ExtendedWorkerState workerState, boolean sessioned) { + Struct configState = new Struct(CONFIG_STATE_V1) + .set(URL_KEY_NAME, workerState.url()) + .set(CONFIG_OFFSET_KEY_NAME, workerState.offset()); + // Not a big issue if we embed the protocol version with the assignment in the metadata + Struct allocation = new Struct(ALLOCATION_V1) + .set(ALLOCATION_KEY_NAME, serializeAssignment(workerState.assignment())); + Struct connectProtocolHeader = sessioned ? CONNECT_PROTOCOL_HEADER_V2 : CONNECT_PROTOCOL_HEADER_V1; + ByteBuffer buffer = ByteBuffer.allocate(connectProtocolHeader.sizeOf() + + CONFIG_STATE_V1.sizeOf(configState) + + ALLOCATION_V1.sizeOf(allocation)); + connectProtocolHeader.writeTo(buffer); + CONFIG_STATE_V1.write(buffer, configState); + ALLOCATION_V1.write(buffer, allocation); + buffer.flip(); + return buffer; + } + + /** + * Returns the collection of Connect protocols that are supported by this version along + * with their serialized metadata. The protocols are ordered by preference. + * + * @param workerState the current state of the worker metadata + * @param sessioned whether the {@link ConnectProtocolCompatibility#SESSIONED} protocol should + * be included in the collection of supported protocols + * @return the collection of Connect protocol metadata + */ + public static JoinGroupRequestProtocolCollection metadataRequest(ExtendedWorkerState workerState, boolean sessioned) { + // Order matters in terms of protocol preference + List joinGroupRequestProtocols = new ArrayList<>(); + if (sessioned) { + joinGroupRequestProtocols.add(new JoinGroupRequestProtocol() + .setName(SESSIONED.protocol()) + .setMetadata(IncrementalCooperativeConnectProtocol.serializeMetadata(workerState, true).array()) + ); + } + joinGroupRequestProtocols.add(new JoinGroupRequestProtocol() + .setName(COMPATIBLE.protocol()) + .setMetadata(IncrementalCooperativeConnectProtocol.serializeMetadata(workerState, false).array()) + ); + joinGroupRequestProtocols.add(new JoinGroupRequestProtocol() + .setName(EAGER.protocol()) + .setMetadata(ConnectProtocol.serializeMetadata(workerState).array()) + ); + return new JoinGroupRequestProtocolCollection(joinGroupRequestProtocols.iterator()); + } + + /** + * Given a byte buffer that contains protocol metadata return the deserialized form of the + * metadata. + * + * @param buffer A buffer containing the protocols metadata + * @return the deserialized metadata + * @throws SchemaException on incompatible Connect protocol version + */ + public static ExtendedWorkerState deserializeMetadata(ByteBuffer buffer) { + Struct header = CONNECT_PROTOCOL_HEADER_SCHEMA.read(buffer); + Short version = header.getShort(VERSION_KEY_NAME); + checkVersionCompatibility(version); + Struct configState = CONFIG_STATE_V1.read(buffer); + long configOffset = configState.getLong(CONFIG_OFFSET_KEY_NAME); + String url = configState.getString(URL_KEY_NAME); + Struct allocation = ALLOCATION_V1.read(buffer); + // Protocol version is embedded with the assignment in the metadata + ExtendedAssignment assignment = deserializeAssignment(allocation.getBytes(ALLOCATION_KEY_NAME)); + return new ExtendedWorkerState(url, configOffset, assignment); + } + + /** + * The fields are serialized in sequence as follows: + * Complete Assignment V1: + *
            +     *   Version            => Int16
            +     *   Error              => Int16
            +     *   Leader             => [String]
            +     *   LeaderUrl          => [String]
            +     *   ConfigOffset       => Int64
            +     *   Assignment         => [Connector Assignment]
            +     *   Revoked            => [Connector Assignment]
            +     *   ScheduledDelay     => Int32
            +     * 
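+     * A minimal round-trip sketch (leader id, URL, offset, and connector name are illustrative
+     * values only; the constructor arguments mirror the fields listed above):
+     * <pre>{@code
+     * ExtendedAssignment assignment = new ExtendedAssignment(
+     *         CONNECT_PROTOCOL_V1, ConnectProtocol.Assignment.NO_ERROR, "leader", "http://leader:8083", 5L,
+     *         Collections.singletonList("connector-a"),
+     *         Collections.singletonList(new ConnectorTaskId("connector-a", 0)),
+     *         Collections.emptyList(), Collections.emptyList(), 0);
+     * ByteBuffer serialized = serializeAssignment(assignment);
+     * ExtendedAssignment deserialized = deserializeAssignment(serialized);
+     * }</pre>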
            + */ + public static ByteBuffer serializeAssignment(ExtendedAssignment assignment) { + // comparison depends on reference equality for now + if (assignment == null || ExtendedAssignment.empty().equals(assignment)) { + return null; + } + Struct struct = assignment.toStruct(); + ByteBuffer buffer = ByteBuffer.allocate(CONNECT_PROTOCOL_HEADER_V1.sizeOf() + + ASSIGNMENT_V1.sizeOf(struct)); + CONNECT_PROTOCOL_HEADER_V1.writeTo(buffer); + ASSIGNMENT_V1.write(buffer, struct); + buffer.flip(); + return buffer; + } + + /** + * Given a byte buffer that contains an assignment as defined by this protocol, return the + * deserialized form of the assignment. + * + * @param buffer the buffer containing a serialized assignment + * @return the deserialized assignment + * @throws SchemaException on incompatible Connect protocol version + */ + public static ExtendedAssignment deserializeAssignment(ByteBuffer buffer) { + if (buffer == null) { + return null; + } + Struct header = CONNECT_PROTOCOL_HEADER_SCHEMA.read(buffer); + Short version = header.getShort(VERSION_KEY_NAME); + checkVersionCompatibility(version); + Struct struct = ASSIGNMENT_V1.read(buffer); + return ExtendedAssignment.fromStruct(version, struct); + } + + private static void checkVersionCompatibility(short version) { + // check for invalid versions + if (version < CONNECT_PROTOCOL_V0) + throw new SchemaException("Unsupported subscription version: " + version); + + // otherwise, assume versions can be parsed + } + +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/distributed/NotAssignedException.java b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/distributed/NotAssignedException.java new file mode 100644 index 0000000..ee0270e --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/distributed/NotAssignedException.java @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.runtime.distributed; + +/** + * Thrown when a request intended for the owner of a task or connector is received by a worker which doesn't + * own it (typically the leader). 
+ */ +public class NotAssignedException extends RequestTargetException { + + public NotAssignedException(String message, String ownerUrl) { + super(message, ownerUrl); + } + +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/distributed/NotLeaderException.java b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/distributed/NotLeaderException.java new file mode 100644 index 0000000..5ccea8a --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/distributed/NotLeaderException.java @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.runtime.distributed; + +/** + * Indicates an operation was not permitted because it can only be performed on the leader and this worker is not currently + * the leader. + */ +public class NotLeaderException extends RequestTargetException { + + public NotLeaderException(String msg, String leaderUrl) { + super(msg, leaderUrl); + } + +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/distributed/RebalanceNeededException.java b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/distributed/RebalanceNeededException.java new file mode 100644 index 0000000..922fabe --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/distributed/RebalanceNeededException.java @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.connect.runtime.distributed; + +import org.apache.kafka.connect.errors.ConnectException; + +public class RebalanceNeededException extends ConnectException { + + public RebalanceNeededException(String s) { + super(s); + } + +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/distributed/RequestTargetException.java b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/distributed/RequestTargetException.java new file mode 100644 index 0000000..3c03e7b --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/distributed/RequestTargetException.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.runtime.distributed; + +import org.apache.kafka.connect.errors.ConnectException; + +/** + * Raised when a request has been received by a worker which cannot handle it, + * but can forward it to the right target + */ +public class RequestTargetException extends ConnectException { + private final String forwardUrl; + + public RequestTargetException(String s, String forwardUrl) { + super(s); + this.forwardUrl = forwardUrl; + } + + public RequestTargetException(String s, Throwable throwable, String forwardUrl) { + super(s, throwable); + this.forwardUrl = forwardUrl; + } + + public RequestTargetException(Throwable throwable, String forwardUrl) { + super(throwable); + this.forwardUrl = forwardUrl; + } + + public String forwardUrl() { + return forwardUrl; + } + +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/distributed/WorkerCoordinator.java b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/distributed/WorkerCoordinator.java new file mode 100644 index 0000000..425213f --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/distributed/WorkerCoordinator.java @@ -0,0 +1,604 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.connect.runtime.distributed; + +import org.apache.kafka.clients.consumer.internals.AbstractCoordinator; +import org.apache.kafka.clients.consumer.internals.ConsumerNetworkClient; +import org.apache.kafka.clients.GroupRebalanceConfig; +import org.apache.kafka.common.metrics.Measurable; +import org.apache.kafka.common.metrics.Metrics; +import org.apache.kafka.common.requests.JoinGroupRequest; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.common.utils.Timer; +import org.apache.kafka.connect.storage.ConfigBackingStore; +import org.apache.kafka.connect.util.ConnectorTaskId; +import org.slf4j.Logger; + +import java.io.Closeable; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.Comparator; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; + +import static org.apache.kafka.common.message.JoinGroupRequestData.JoinGroupRequestProtocolCollection; +import static org.apache.kafka.common.message.JoinGroupResponseData.JoinGroupResponseMember; +import static org.apache.kafka.connect.runtime.distributed.ConnectProtocolCompatibility.EAGER; + +/** + * This class manages the coordination process with the Kafka group coordinator on the broker for managing assignments + * to workers. + */ +public class WorkerCoordinator extends AbstractCoordinator implements Closeable { + private final Logger log; + private final String restUrl; + private final ConfigBackingStore configStorage; + private volatile ExtendedAssignment assignmentSnapshot; + private ClusterConfigState configSnapshot; + private final WorkerRebalanceListener listener; + private final ConnectProtocolCompatibility protocolCompatibility; + private LeaderState leaderState; + + private boolean rejoinRequested; + private volatile ConnectProtocolCompatibility currentConnectProtocol; + private volatile int lastCompletedGenerationId; + private final ConnectAssignor eagerAssignor; + private final ConnectAssignor incrementalAssignor; + private final int coordinatorDiscoveryTimeoutMs; + + /** + * Initialize the coordination manager. 
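+     * The {@code protocolCompatibility} mode determines which Connect subprotocols are advertised
+     * in {@link #metadata()}, and {@code maxDelay} is handed to the incremental cooperative
+     * assignor as the upper bound for the scheduled rebalance delay.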
+ */ + public WorkerCoordinator(GroupRebalanceConfig config, + LogContext logContext, + ConsumerNetworkClient client, + Metrics metrics, + String metricGrpPrefix, + Time time, + String restUrl, + ConfigBackingStore configStorage, + WorkerRebalanceListener listener, + ConnectProtocolCompatibility protocolCompatibility, + int maxDelay) { + super(config, + logContext, + client, + metrics, + metricGrpPrefix, + time); + this.log = logContext.logger(WorkerCoordinator.class); + this.restUrl = restUrl; + this.configStorage = configStorage; + this.assignmentSnapshot = null; + new WorkerCoordinatorMetrics(metrics, metricGrpPrefix); + this.listener = listener; + this.rejoinRequested = false; + this.protocolCompatibility = protocolCompatibility; + this.incrementalAssignor = new IncrementalCooperativeAssignor(logContext, time, maxDelay); + this.eagerAssignor = new EagerAssignor(logContext); + this.currentConnectProtocol = protocolCompatibility; + this.coordinatorDiscoveryTimeoutMs = config.heartbeatIntervalMs; + this.lastCompletedGenerationId = Generation.NO_GENERATION.generationId; + } + + @Override + public void requestRejoin(final String reason) { + log.debug("Request joining group due to: {}", reason); + rejoinRequested = true; + } + + @Override + public String protocolType() { + return "connect"; + } + + // expose for tests + @Override + protected synchronized boolean ensureCoordinatorReady(final Timer timer) { + return super.ensureCoordinatorReady(timer); + } + + public void poll(long timeout) { + // poll for io until the timeout expires + final long start = time.milliseconds(); + long now = start; + long remaining; + + do { + if (coordinatorUnknown()) { + log.debug("Broker coordinator is marked unknown. Attempting discovery with a timeout of {}ms", + coordinatorDiscoveryTimeoutMs); + if (ensureCoordinatorReady(time.timer(coordinatorDiscoveryTimeoutMs))) { + log.debug("Broker coordinator is ready"); + } else { + log.debug("Can not connect to broker coordinator"); + final ExtendedAssignment localAssignmentSnapshot = assignmentSnapshot; + if (localAssignmentSnapshot != null && !localAssignmentSnapshot.failed()) { + log.info("Broker coordinator was unreachable for {}ms. Revoking previous assignment {} to " + + "avoid running tasks while not being a member the group", coordinatorDiscoveryTimeoutMs, localAssignmentSnapshot); + listener.onRevoked(localAssignmentSnapshot.leader(), localAssignmentSnapshot.connectors(), localAssignmentSnapshot.tasks()); + assignmentSnapshot = null; + } + } + now = time.milliseconds(); + } + + if (rejoinNeededOrPending()) { + ensureActiveGroup(); + now = time.milliseconds(); + } + + pollHeartbeat(now); + + long elapsed = now - start; + remaining = timeout - elapsed; + + // Note that because the network client is shared with the background heartbeat thread, + // we do not want to block in poll longer than the time to the next heartbeat. 
+ long pollTimeout = Math.min(Math.max(0, remaining), timeToNextHeartbeat(now)); + client.poll(time.timer(pollTimeout)); + + now = time.milliseconds(); + elapsed = now - start; + remaining = timeout - elapsed; + } while (remaining > 0); + } + + @Override + public JoinGroupRequestProtocolCollection metadata() { + configSnapshot = configStorage.snapshot(); + final ExtendedAssignment localAssignmentSnapshot = assignmentSnapshot; + ExtendedWorkerState workerState = new ExtendedWorkerState(restUrl, configSnapshot.offset(), localAssignmentSnapshot); + switch (protocolCompatibility) { + case EAGER: + return ConnectProtocol.metadataRequest(workerState); + case COMPATIBLE: + return IncrementalCooperativeConnectProtocol.metadataRequest(workerState, false); + case SESSIONED: + return IncrementalCooperativeConnectProtocol.metadataRequest(workerState, true); + default: + throw new IllegalStateException("Unknown Connect protocol compatibility mode " + protocolCompatibility); + } + } + + @Override + protected void onJoinComplete(int generation, String memberId, String protocol, ByteBuffer memberAssignment) { + ExtendedAssignment newAssignment = IncrementalCooperativeConnectProtocol.deserializeAssignment(memberAssignment); + log.debug("Deserialized new assignment: {}", newAssignment); + currentConnectProtocol = ConnectProtocolCompatibility.fromProtocol(protocol); + // At this point we always consider ourselves to be a member of the cluster, even if there was an assignment + // error (the leader couldn't make the assignment) or we are behind the config and cannot yet work on our assigned + // tasks. It's the responsibility of the code driving this process to decide how to react (e.g. trying to get + // up to date, try to rejoin again, leaving the group and backing off, etc.). + rejoinRequested = false; + if (currentConnectProtocol != EAGER) { + if (!newAssignment.revokedConnectors().isEmpty() || !newAssignment.revokedTasks().isEmpty()) { + listener.onRevoked(newAssignment.leader(), newAssignment.revokedConnectors(), newAssignment.revokedTasks()); + } + + final ExtendedAssignment localAssignmentSnapshot = assignmentSnapshot; + if (localAssignmentSnapshot != null) { + localAssignmentSnapshot.connectors().removeAll(newAssignment.revokedConnectors()); + localAssignmentSnapshot.tasks().removeAll(newAssignment.revokedTasks()); + log.debug("After revocations snapshot of assignment: {}", localAssignmentSnapshot); + newAssignment.connectors().addAll(localAssignmentSnapshot.connectors()); + newAssignment.tasks().addAll(localAssignmentSnapshot.tasks()); + } + log.debug("Augmented new assignment: {}", newAssignment); + } + assignmentSnapshot = newAssignment; + lastCompletedGenerationId = generation; + listener.onAssigned(newAssignment, generation); + } + + @Override + protected Map performAssignment(String leaderId, String protocol, List allMemberMetadata) { + return ConnectProtocolCompatibility.fromProtocol(protocol) == EAGER + ? 
eagerAssignor.performAssignment(leaderId, protocol, allMemberMetadata, this) + : incrementalAssignor.performAssignment(leaderId, protocol, allMemberMetadata, this); + } + + @Override + protected void onJoinPrepare(int generation, String memberId) { + log.info("Rebalance started"); + leaderState(null); + final ExtendedAssignment localAssignmentSnapshot = assignmentSnapshot; + if (currentConnectProtocol == EAGER) { + log.debug("Revoking previous assignment {}", localAssignmentSnapshot); + if (localAssignmentSnapshot != null && !localAssignmentSnapshot.failed()) + listener.onRevoked(localAssignmentSnapshot.leader(), localAssignmentSnapshot.connectors(), localAssignmentSnapshot.tasks()); + } else { + log.debug("Cooperative rebalance triggered. Keeping assignment {} until it's " + + "explicitly revoked.", localAssignmentSnapshot); + } + } + + @Override + protected boolean rejoinNeededOrPending() { + final ExtendedAssignment localAssignmentSnapshot = assignmentSnapshot; + return super.rejoinNeededOrPending() || (localAssignmentSnapshot == null || localAssignmentSnapshot.failed()) || rejoinRequested; + } + + @Override + public String memberId() { + Generation generation = generationIfStable(); + if (generation != null) + return generation.memberId; + return JoinGroupRequest.UNKNOWN_MEMBER_ID; + } + + /** + * Return the current generation. The generation refers to this worker's knowledge with + * respect to which generation is the latest one and, therefore, this information is local. + * + * @return the generation ID or -1 if no generation is defined + */ + public int generationId() { + return super.generation().generationId; + } + + /** + * Return id that corresponds to the group generation that was active when the last join was successful + * + * @return the generation ID of the last group that was joined successfully by this member or -1 if no generation + * was stable at that point + */ + public int lastCompletedGenerationId() { + return lastCompletedGenerationId; + } + + public void revokeAssignment(ExtendedAssignment assignment) { + listener.onRevoked(assignment.leader(), assignment.connectors(), assignment.tasks()); + } + + private boolean isLeader() { + final ExtendedAssignment localAssignmentSnapshot = assignmentSnapshot; + return localAssignmentSnapshot != null && memberId().equals(localAssignmentSnapshot.leader()); + } + + public String ownerUrl(String connector) { + if (rejoinNeededOrPending() || !isLeader()) + return null; + return leaderState().ownerUrl(connector); + } + + public String ownerUrl(ConnectorTaskId task) { + if (rejoinNeededOrPending() || !isLeader()) + return null; + return leaderState().ownerUrl(task); + } + + /** + * Get an up-to-date snapshot of the cluster configuration. + * + * @return the state of the cluster configuration; the result is not locally cached + */ + public ClusterConfigState configFreshSnapshot() { + return configStorage.snapshot(); + } + + /** + * Get a snapshot of the cluster configuration. + * + * @return the state of the cluster configuration + */ + public ClusterConfigState configSnapshot() { + return configSnapshot; + } + + /** + * Set the state of the cluster configuration to this worker coordinator. + * + * @param update the updated state of the cluster configuration + */ + public void configSnapshot(ClusterConfigState update) { + configSnapshot = update; + } + + /** + * Get the leader state stored in this worker coordinator. 
+ * + * @return the leader state + */ + private LeaderState leaderState() { + return leaderState; + } + + /** + * Store the leader state to this worker coordinator. + * + * @param update the updated leader state + */ + public void leaderState(LeaderState update) { + leaderState = update; + } + + /** + * Get the version of the connect protocol that is currently active in the group of workers. + * + * @return the current connect protocol version + */ + public short currentProtocolVersion() { + return currentConnectProtocol.protocolVersion(); + } + + private class WorkerCoordinatorMetrics { + public final String metricGrpName; + + public WorkerCoordinatorMetrics(Metrics metrics, String metricGrpPrefix) { + this.metricGrpName = metricGrpPrefix + "-coordinator-metrics"; + + Measurable numConnectors = (config, now) -> { + final ExtendedAssignment localAssignmentSnapshot = assignmentSnapshot; + if (localAssignmentSnapshot == null) { + return 0.0; + } + return localAssignmentSnapshot.connectors().size(); + }; + + Measurable numTasks = (config, now) -> { + final ExtendedAssignment localAssignmentSnapshot = assignmentSnapshot; + if (localAssignmentSnapshot == null) { + return 0.0; + } + return localAssignmentSnapshot.tasks().size(); + }; + + metrics.addMetric(metrics.metricName("assigned-connectors", + this.metricGrpName, + "The number of connector instances currently assigned to this consumer"), numConnectors); + metrics.addMetric(metrics.metricName("assigned-tasks", + this.metricGrpName, + "The number of tasks currently assigned to this consumer"), numTasks); + } + } + + public static Map invertAssignment(Map> assignment) { + Map inverted = new HashMap<>(); + for (Map.Entry> assignmentEntry : assignment.entrySet()) { + K key = assignmentEntry.getKey(); + for (V value : assignmentEntry.getValue()) + inverted.put(value, key); + } + return inverted; + } + + public static class LeaderState { + private final Map allMembers; + private final Map connectorOwners; + private final Map taskOwners; + + public LeaderState(Map allMembers, + Map> connectorAssignment, + Map> taskAssignment) { + this.allMembers = allMembers; + this.connectorOwners = invertAssignment(connectorAssignment); + this.taskOwners = invertAssignment(taskAssignment); + } + + private String ownerUrl(ConnectorTaskId id) { + String ownerId = taskOwners.get(id); + if (ownerId == null) + return null; + return allMembers.get(ownerId).url(); + } + + private String ownerUrl(String connector) { + String ownerId = connectorOwners.get(connector); + if (ownerId == null) + return null; + return allMembers.get(ownerId).url(); + } + + } + + public static class ConnectorsAndTasks { + public static final ConnectorsAndTasks EMPTY = + new ConnectorsAndTasks(Collections.emptyList(), Collections.emptyList()); + + private final Collection connectors; + private final Collection tasks; + + private ConnectorsAndTasks(Collection connectors, Collection tasks) { + this.connectors = connectors; + this.tasks = tasks; + } + + public static class Builder { + private Collection withConnectors; + private Collection withTasks; + + public Builder() { + } + + public ConnectorsAndTasks.Builder withCopies(Collection connectors, + Collection tasks) { + withConnectors = new ArrayList<>(connectors); + withTasks = new ArrayList<>(tasks); + return this; + } + + public ConnectorsAndTasks.Builder with(Collection connectors, + Collection tasks) { + withConnectors = new ArrayList<>(connectors); + withTasks = new ArrayList<>(tasks); + return this; + } + + public ConnectorsAndTasks build() 
{ + return new ConnectorsAndTasks( + withConnectors != null ? withConnectors : new ArrayList<>(), + withTasks != null ? withTasks : new ArrayList<>()); + } + } + + public Collection connectors() { + return connectors; + } + + public Collection tasks() { + return tasks; + } + + public int size() { + return connectors.size() + tasks.size(); + } + + public boolean isEmpty() { + return connectors.isEmpty() && tasks.isEmpty(); + } + + @Override + public String toString() { + return "{ connectorIds=" + connectors + ", taskIds=" + tasks + '}'; + } + } + + public static class WorkerLoad { + private final String worker; + private final Collection connectors; + private final Collection tasks; + + private WorkerLoad( + String worker, + Collection connectors, + Collection tasks + ) { + this.worker = worker; + this.connectors = connectors; + this.tasks = tasks; + } + + public static class Builder { + private String withWorker; + private Collection withConnectors; + private Collection withTasks; + + public Builder(String worker) { + this.withWorker = Objects.requireNonNull(worker, "worker cannot be null"); + } + + public WorkerLoad.Builder withCopies(Collection connectors, + Collection tasks) { + withConnectors = new ArrayList<>( + Objects.requireNonNull(connectors, "connectors may be empty but not null")); + withTasks = new ArrayList<>( + Objects.requireNonNull(tasks, "tasks may be empty but not null")); + return this; + } + + public WorkerLoad.Builder with(Collection connectors, + Collection tasks) { + withConnectors = Objects.requireNonNull(connectors, + "connectors may be empty but not null"); + withTasks = Objects.requireNonNull(tasks, "tasks may be empty but not null"); + return this; + } + + public WorkerLoad build() { + return new WorkerLoad( + withWorker, + withConnectors != null ? withConnectors : new ArrayList<>(), + withTasks != null ? withTasks : new ArrayList<>()); + } + } + + public String worker() { + return worker; + } + + public Collection connectors() { + return connectors; + } + + public Collection tasks() { + return tasks; + } + + public int connectorsSize() { + return connectors.size(); + } + + public int tasksSize() { + return tasks.size(); + } + + public void assign(String connector) { + connectors.add(connector); + } + + public void assign(ConnectorTaskId task) { + tasks.add(task); + } + + public int size() { + return connectors.size() + tasks.size(); + } + + public boolean isEmpty() { + return connectors.isEmpty() && tasks.isEmpty(); + } + + public static Comparator connectorComparator() { + return (left, right) -> { + int res = left.connectors.size() - right.connectors.size(); + return res != 0 ? res : left.worker == null + ? right.worker == null ? 0 : -1 + : left.worker.compareTo(right.worker); + }; + } + + public static Comparator taskComparator() { + return (left, right) -> { + int res = left.tasks.size() - right.tasks.size(); + return res != 0 ? res : left.worker == null + ? right.worker == null ? 
0 : -1 + : left.worker.compareTo(right.worker); + }; + } + + @Override + public String toString() { + return "{ worker=" + worker + ", connectorIds=" + connectors + ", taskIds=" + tasks + '}'; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (!(o instanceof WorkerLoad)) { + return false; + } + WorkerLoad that = (WorkerLoad) o; + return worker.equals(that.worker) && + connectors.equals(that.connectors) && + tasks.equals(that.tasks); + } + + @Override + public int hashCode() { + return Objects.hash(worker, connectors, tasks); + } + } + +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/distributed/WorkerGroupMember.java b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/distributed/WorkerGroupMember.java new file mode 100644 index 0000000..4c1d6a5 --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/distributed/WorkerGroupMember.java @@ -0,0 +1,247 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.connect.runtime.distributed; + +import org.apache.kafka.clients.ApiVersions; +import org.apache.kafka.clients.ClientUtils; +import org.apache.kafka.clients.CommonClientConfigs; +import org.apache.kafka.clients.Metadata; +import org.apache.kafka.clients.NetworkClient; +import org.apache.kafka.clients.consumer.internals.ConsumerNetworkClient; +import org.apache.kafka.clients.GroupRebalanceConfig; +import org.apache.kafka.common.KafkaException; +import org.apache.kafka.common.internals.ClusterResourceListeners; +import org.apache.kafka.common.metrics.JmxReporter; +import org.apache.kafka.common.metrics.MetricsContext; +import org.apache.kafka.common.metrics.KafkaMetricsContext; +import org.apache.kafka.common.metrics.MetricConfig; +import org.apache.kafka.common.metrics.Metrics; +import org.apache.kafka.common.metrics.MetricsReporter; +import org.apache.kafka.common.network.ChannelBuilder; +import org.apache.kafka.common.network.Selector; +import org.apache.kafka.common.utils.AppInfoParser; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.connect.runtime.WorkerConfig; +import org.apache.kafka.connect.storage.ConfigBackingStore; +import org.apache.kafka.connect.util.ConnectUtils; +import org.apache.kafka.connect.util.ConnectorTaskId; +import org.slf4j.Logger; + +import java.net.InetSocketAddress; +import java.util.Collections; +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; + +/** + * This class manages the coordination process with brokers for the Connect cluster group membership. It ties together + * the Coordinator, which implements the group member protocol, with all the other pieces needed to drive the connection + * to the group coordinator broker. This isolates all the networking to a single thread managed by this class, with + * higher level operations in response to group membership events being handled by the herder. 
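+ * <p>
+ * A minimal usage sketch (variable names are illustrative only; in practice the herder owns this lifecycle):
+ * <pre>{@code
+ * WorkerGroupMember member = new WorkerGroupMember(distributedConfig, restUrl, configBackingStore,
+ *         rebalanceListener, Time.SYSTEM, clientId, logContext);
+ * member.ensureActive();       // join the group; the listener receives the resulting assignment
+ * member.poll(pollTimeoutMs);  // drive heartbeats and any pending group I/O
+ * member.stop();               // shut down the coordinator, network client and metrics
+ * }</pre>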
+ */ +public class WorkerGroupMember { + private static final String JMX_PREFIX = "kafka.connect"; + + private final Logger log; + private final Time time; + private final String clientId; + private final ConsumerNetworkClient client; + private final Metrics metrics; + private final Metadata metadata; + private final long retryBackoffMs; + private final WorkerCoordinator coordinator; + + private boolean stopped = false; + + public WorkerGroupMember(DistributedConfig config, + String restUrl, + ConfigBackingStore configStorage, + WorkerRebalanceListener listener, + Time time, + String clientId, + LogContext logContext) { + try { + this.time = time; + this.clientId = clientId; + this.log = logContext.logger(WorkerGroupMember.class); + + Map metricsTags = new LinkedHashMap<>(); + metricsTags.put("client-id", clientId); + MetricConfig metricConfig = new MetricConfig().samples(config.getInt(CommonClientConfigs.METRICS_NUM_SAMPLES_CONFIG)) + .timeWindow(config.getLong(CommonClientConfigs.METRICS_SAMPLE_WINDOW_MS_CONFIG), TimeUnit.MILLISECONDS) + .tags(metricsTags); + List reporters = config.getConfiguredInstances(CommonClientConfigs.METRIC_REPORTER_CLASSES_CONFIG, + MetricsReporter.class, + Collections.singletonMap(CommonClientConfigs.CLIENT_ID_CONFIG, clientId)); + JmxReporter jmxReporter = new JmxReporter(); + jmxReporter.configure(config.originals()); + reporters.add(jmxReporter); + + Map contextLabels = new HashMap<>(); + contextLabels.putAll(config.originalsWithPrefix(CommonClientConfigs.METRICS_CONTEXT_PREFIX)); + contextLabels.put(WorkerConfig.CONNECT_KAFKA_CLUSTER_ID, ConnectUtils.lookupKafkaClusterId(config)); + contextLabels.put(WorkerConfig.CONNECT_GROUP_ID, config.getString(DistributedConfig.GROUP_ID_CONFIG)); + MetricsContext metricsContext = new KafkaMetricsContext(JMX_PREFIX, contextLabels); + + this.metrics = new Metrics(metricConfig, reporters, time, metricsContext); + this.retryBackoffMs = config.getLong(CommonClientConfigs.RETRY_BACKOFF_MS_CONFIG); + this.metadata = new Metadata(retryBackoffMs, config.getLong(CommonClientConfigs.METADATA_MAX_AGE_CONFIG), + logContext, new ClusterResourceListeners()); + List addresses = ClientUtils.parseAndValidateAddresses( + config.getList(CommonClientConfigs.BOOTSTRAP_SERVERS_CONFIG), + config.getString(CommonClientConfigs.CLIENT_DNS_LOOKUP_CONFIG)); + this.metadata.bootstrap(addresses); + String metricGrpPrefix = "connect"; + ChannelBuilder channelBuilder = ClientUtils.createChannelBuilder(config, time, logContext); + NetworkClient netClient = new NetworkClient( + new Selector(config.getLong(CommonClientConfigs.CONNECTIONS_MAX_IDLE_MS_CONFIG), metrics, time, metricGrpPrefix, channelBuilder, logContext), + this.metadata, + clientId, + 100, // a fixed large enough value will suffice + config.getLong(CommonClientConfigs.RECONNECT_BACKOFF_MS_CONFIG), + config.getLong(CommonClientConfigs.RECONNECT_BACKOFF_MAX_MS_CONFIG), + config.getInt(CommonClientConfigs.SEND_BUFFER_CONFIG), + config.getInt(CommonClientConfigs.RECEIVE_BUFFER_CONFIG), + config.getInt(CommonClientConfigs.REQUEST_TIMEOUT_MS_CONFIG), + config.getLong(CommonClientConfigs.SOCKET_CONNECTION_SETUP_TIMEOUT_MS_CONFIG), + config.getLong(CommonClientConfigs.SOCKET_CONNECTION_SETUP_TIMEOUT_MAX_MS_CONFIG), + time, + true, + new ApiVersions(), + logContext); + this.client = new ConsumerNetworkClient( + logContext, + netClient, + metadata, + time, + retryBackoffMs, + config.getInt(CommonClientConfigs.REQUEST_TIMEOUT_MS_CONFIG), + Integer.MAX_VALUE); + this.coordinator = new WorkerCoordinator( 
+ new GroupRebalanceConfig(config, GroupRebalanceConfig.ProtocolType.CONNECT), + logContext, + this.client, + metrics, + metricGrpPrefix, + this.time, + restUrl, + configStorage, + listener, + ConnectProtocolCompatibility.compatibility(config.getString(DistributedConfig.CONNECT_PROTOCOL_CONFIG)), + config.getInt(DistributedConfig.SCHEDULED_REBALANCE_MAX_DELAY_MS_CONFIG)); + + AppInfoParser.registerAppInfo(JMX_PREFIX, clientId, metrics, time.milliseconds()); + log.debug("Connect group member created"); + } catch (Throwable t) { + // call close methods if internal objects are already constructed + // this is to prevent resource leak. see KAFKA-2121 + stop(true); + // now propagate the exception + throw new KafkaException("Failed to construct kafka consumer", t); + } + } + + public void stop() { + if (stopped) return; + stop(false); + } + + /** + * Ensure that the connection to the broker coordinator is up and that the worker is an + * active member of the group. + */ + public void ensureActive() { + coordinator.poll(0); + } + + public void poll(long timeout) { + if (timeout < 0) + throw new IllegalArgumentException("Timeout must not be negative"); + coordinator.poll(timeout); + } + + /** + * Interrupt any running poll() calls, causing a WakeupException to be thrown in the thread invoking that method. + */ + public void wakeup() { + this.client.wakeup(); + } + + /** + * Get the member ID of this worker in the group of workers. + * + * This ID is the unique member ID automatically generated. + * + * @return the member ID + */ + public String memberId() { + return coordinator.memberId(); + } + + public void requestRejoin() { + coordinator.requestRejoin("connect worker requested rejoin"); + } + + public void maybeLeaveGroup(String leaveReason) { + coordinator.maybeLeaveGroup(leaveReason); + } + + public String ownerUrl(String connector) { + return coordinator.ownerUrl(connector); + } + + public String ownerUrl(ConnectorTaskId task) { + return coordinator.ownerUrl(task); + } + + /** + * Get the version of the connect protocol that is currently active in the group of workers. 
+ * + * @return the current connect protocol version + */ + public short currentProtocolVersion() { + return coordinator.currentProtocolVersion(); + } + + public void revokeAssignment(ExtendedAssignment assignment) { + coordinator.revokeAssignment(assignment); + } + + private void stop(boolean swallowException) { + log.trace("Stopping the Connect group member."); + AtomicReference firstException = new AtomicReference<>(); + this.stopped = true; + Utils.closeQuietly(coordinator, "coordinator", firstException); + Utils.closeQuietly(metrics, "consumer metrics", firstException); + Utils.closeQuietly(client, "consumer network client", firstException); + AppInfoParser.unregisterAppInfo(JMX_PREFIX, clientId, metrics); + if (firstException.get() != null && !swallowException) + throw new KafkaException("Failed to stop the Connect group member", firstException.get()); + else + log.debug("The Connect group member has stopped."); + } + + // Visible for testing + Metrics metrics() { + return this.metrics; + } +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/distributed/WorkerRebalanceListener.java b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/distributed/WorkerRebalanceListener.java new file mode 100644 index 0000000..93d0327 --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/distributed/WorkerRebalanceListener.java @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.runtime.distributed; + +import org.apache.kafka.connect.util.ConnectorTaskId; + +import java.util.Collection; + +/** + * Listener for rebalance events in the worker group. + */ +public interface WorkerRebalanceListener { + /** + * Invoked when a new assignment is created by joining the Connect worker group. This is + * invoked for both successful and unsuccessful assignments. + */ + void onAssigned(ExtendedAssignment assignment, int generation); + + /** + * Invoked when a rebalance operation starts, revoking ownership for the set of connectors + * and tasks. Depending on the Connect protocol version, the collection of revoked connectors + * or tasks might refer to all or some of the connectors and tasks running on the worker. 
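+ * For example (illustrative): under the EAGER protocol every connector and task on this worker is revoked at
+ * the start of each rebalance, whereas under the incremental cooperative protocols (COMPATIBLE, SESSIONED)
+ * only the connectors and tasks that are explicitly revoked in the new assignment are included here.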
+ */ + void onRevoked(String leader, Collection connectors, Collection tasks); +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/errors/DeadLetterQueueReporter.java b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/errors/DeadLetterQueueReporter.java new file mode 100644 index 0000000..a4480f4 --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/errors/DeadLetterQueueReporter.java @@ -0,0 +1,216 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.runtime.errors; + +import org.apache.kafka.clients.admin.Admin; +import org.apache.kafka.clients.admin.NewTopic; +import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.clients.producer.KafkaProducer; +import org.apache.kafka.clients.producer.ProducerRecord; +import org.apache.kafka.clients.producer.RecordMetadata; +import org.apache.kafka.common.errors.TopicExistsException; +import org.apache.kafka.common.header.Headers; +import org.apache.kafka.common.record.RecordBatch; +import org.apache.kafka.connect.errors.ConnectException; +import org.apache.kafka.connect.runtime.SinkConnectorConfig; +import org.apache.kafka.connect.util.ConnectorTaskId; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import java.nio.charset.StandardCharsets; +import java.util.Map; +import java.util.Objects; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Future; + +import static java.util.Collections.singleton; + +/** + * Write the original consumed record into a dead letter queue. The dead letter queue is a Kafka topic located + * on the same cluster used by the worker to maintain internal topics. Each connector is typically configured + * with its own Kafka topic dead letter queue. By default, the topic name is not set, and if the + * connector config doesn't specify one, this feature is disabled. 
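+ * <p>
+ * As a rough illustration (assuming the standard sink connector error-handling property names), a sink
+ * connector enables this reporter with configuration along these lines:
+ * <pre>
+ * errors.deadletterqueue.topic.name=my-connector-dlq
+ * errors.deadletterqueue.topic.replication.factor=3
+ * errors.deadletterqueue.context.headers.enable=true
+ * </pre>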
+ */ +public class DeadLetterQueueReporter implements ErrorReporter { + + private static final Logger log = LoggerFactory.getLogger(DeadLetterQueueReporter.class); + + private static final int DLQ_NUM_DESIRED_PARTITIONS = 1; + + public static final String HEADER_PREFIX = "__connect.errors."; + public static final String ERROR_HEADER_ORIG_TOPIC = HEADER_PREFIX + "topic"; + public static final String ERROR_HEADER_ORIG_PARTITION = HEADER_PREFIX + "partition"; + public static final String ERROR_HEADER_ORIG_OFFSET = HEADER_PREFIX + "offset"; + public static final String ERROR_HEADER_CONNECTOR_NAME = HEADER_PREFIX + "connector.name"; + public static final String ERROR_HEADER_TASK_ID = HEADER_PREFIX + "task.id"; + public static final String ERROR_HEADER_STAGE = HEADER_PREFIX + "stage"; + public static final String ERROR_HEADER_EXECUTING_CLASS = HEADER_PREFIX + "class.name"; + public static final String ERROR_HEADER_EXCEPTION = HEADER_PREFIX + "exception.class.name"; + public static final String ERROR_HEADER_EXCEPTION_MESSAGE = HEADER_PREFIX + "exception.message"; + public static final String ERROR_HEADER_EXCEPTION_STACK_TRACE = HEADER_PREFIX + "exception.stacktrace"; + + private final SinkConnectorConfig connConfig; + private final ConnectorTaskId connectorTaskId; + private final ErrorHandlingMetrics errorHandlingMetrics; + private final String dlqTopicName; + + private KafkaProducer kafkaProducer; + + public static DeadLetterQueueReporter createAndSetup(Map adminProps, + ConnectorTaskId id, + SinkConnectorConfig sinkConfig, Map producerProps, + ErrorHandlingMetrics errorHandlingMetrics) { + String topic = sinkConfig.dlqTopicName(); + + try (Admin admin = Admin.create(adminProps)) { + if (!admin.listTopics().names().get().contains(topic)) { + log.error("Topic {} doesn't exist. Will attempt to create topic.", topic); + NewTopic schemaTopicRequest = new NewTopic(topic, DLQ_NUM_DESIRED_PARTITIONS, sinkConfig.dlqTopicReplicationFactor()); + admin.createTopics(singleton(schemaTopicRequest)).all().get(); + } + } catch (InterruptedException e) { + throw new ConnectException("Could not initialize dead letter queue with topic=" + topic, e); + } catch (ExecutionException e) { + if (!(e.getCause() instanceof TopicExistsException)) { + throw new ConnectException("Could not initialize dead letter queue with topic=" + topic, e); + } + } + + KafkaProducer dlqProducer = new KafkaProducer<>(producerProps); + return new DeadLetterQueueReporter(dlqProducer, sinkConfig, id, errorHandlingMetrics); + } + + /** + * Initialize the dead letter queue reporter with a {@link KafkaProducer}. + * + * @param kafkaProducer a Kafka Producer to produce the original consumed records. + */ + // Visible for testing + DeadLetterQueueReporter(KafkaProducer kafkaProducer, SinkConnectorConfig connConfig, + ConnectorTaskId id, ErrorHandlingMetrics errorHandlingMetrics) { + Objects.requireNonNull(kafkaProducer); + Objects.requireNonNull(connConfig); + Objects.requireNonNull(id); + Objects.requireNonNull(errorHandlingMetrics); + + this.kafkaProducer = kafkaProducer; + this.connConfig = connConfig; + this.connectorTaskId = id; + this.errorHandlingMetrics = errorHandlingMetrics; + this.dlqTopicName = connConfig.dlqTopicName().trim(); + } + + /** + * Write the raw records into a Kafka topic and return the producer future. + * + * @param context processing context containing the raw record at {@link ProcessingContext#consumerRecord()}. 
+ * @return the future associated with the writing of this record; never null + */ + public Future report(ProcessingContext context) { + if (dlqTopicName.isEmpty()) { + return CompletableFuture.completedFuture(null); + } + errorHandlingMetrics.recordDeadLetterQueueProduceRequest(); + + ConsumerRecord originalMessage = context.consumerRecord(); + if (originalMessage == null) { + errorHandlingMetrics.recordDeadLetterQueueProduceFailed(); + return CompletableFuture.completedFuture(null); + } + + ProducerRecord producerRecord; + if (originalMessage.timestamp() == RecordBatch.NO_TIMESTAMP) { + producerRecord = new ProducerRecord<>(dlqTopicName, null, + originalMessage.key(), originalMessage.value(), originalMessage.headers()); + } else { + producerRecord = new ProducerRecord<>(dlqTopicName, null, originalMessage.timestamp(), + originalMessage.key(), originalMessage.value(), originalMessage.headers()); + } + + if (connConfig.isDlqContextHeadersEnabled()) { + populateContextHeaders(producerRecord, context); + } + + return this.kafkaProducer.send(producerRecord, (metadata, exception) -> { + if (exception != null) { + log.error("Could not produce message to dead letter queue. topic=" + dlqTopicName, exception); + errorHandlingMetrics.recordDeadLetterQueueProduceFailed(); + } + }); + } + + // Visible for testing + void populateContextHeaders(ProducerRecord producerRecord, ProcessingContext context) { + Headers headers = producerRecord.headers(); + if (context.consumerRecord() != null) { + headers.add(ERROR_HEADER_ORIG_TOPIC, toBytes(context.consumerRecord().topic())); + headers.add(ERROR_HEADER_ORIG_PARTITION, toBytes(context.consumerRecord().partition())); + headers.add(ERROR_HEADER_ORIG_OFFSET, toBytes(context.consumerRecord().offset())); + } + + headers.add(ERROR_HEADER_CONNECTOR_NAME, toBytes(connectorTaskId.connector())); + headers.add(ERROR_HEADER_TASK_ID, toBytes(String.valueOf(connectorTaskId.task()))); + headers.add(ERROR_HEADER_STAGE, toBytes(context.stage().name())); + headers.add(ERROR_HEADER_EXECUTING_CLASS, toBytes(context.executingClass().getName())); + if (context.error() != null) { + headers.add(ERROR_HEADER_EXCEPTION, toBytes(context.error().getClass().getName())); + headers.add(ERROR_HEADER_EXCEPTION_MESSAGE, toBytes(context.error().getMessage())); + byte[] trace; + if ((trace = stacktrace(context.error())) != null) { + headers.add(ERROR_HEADER_EXCEPTION_STACK_TRACE, trace); + } + } + } + + private byte[] stacktrace(Throwable error) { + ByteArrayOutputStream bos = new ByteArrayOutputStream(); + try { + PrintStream stream = new PrintStream(bos, true, StandardCharsets.UTF_8.name()); + error.printStackTrace(stream); + bos.close(); + return bos.toByteArray(); + } catch (IOException e) { + log.error("Could not serialize stacktrace.", e); + } + return null; + } + + private byte[] toBytes(int value) { + return toBytes(String.valueOf(value)); + } + + private byte[] toBytes(long value) { + return toBytes(String.valueOf(value)); + } + + private byte[] toBytes(String value) { + if (value != null) { + return value.getBytes(StandardCharsets.UTF_8); + } else { + return null; + } + } + + @Override + public void close() { + kafkaProducer.close(); + } +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/errors/ErrorHandlingMetrics.java b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/errors/ErrorHandlingMetrics.java new file mode 100644 index 0000000..419bea9 --- /dev/null +++ 
b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/errors/ErrorHandlingMetrics.java @@ -0,0 +1,141 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.runtime.errors; + +import org.apache.kafka.common.metrics.Sensor; +import org.apache.kafka.common.metrics.stats.CumulativeSum; +import org.apache.kafka.common.utils.SystemTime; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.connect.runtime.ConnectMetrics; +import org.apache.kafka.connect.runtime.ConnectMetricsRegistry; +import org.apache.kafka.connect.util.ConnectorTaskId; + +/** + * Contains various sensors used for monitoring errors. + */ +public class ErrorHandlingMetrics { + + private final Time time = new SystemTime(); + + private final ConnectMetrics.MetricGroup metricGroup; + + // metrics + private final Sensor recordProcessingFailures; + private final Sensor recordProcessingErrors; + private final Sensor recordsSkipped; + private final Sensor retries; + private final Sensor errorsLogged; + private final Sensor dlqProduceRequests; + private final Sensor dlqProduceFailures; + private long lastErrorTime = 0; + + public ErrorHandlingMetrics(ConnectorTaskId id, ConnectMetrics connectMetrics) { + + ConnectMetricsRegistry registry = connectMetrics.registry(); + metricGroup = connectMetrics.group(registry.taskErrorHandlingGroupName(), + registry.connectorTagName(), id.connector(), registry.taskTagName(), Integer.toString(id.task())); + + // prevent collisions by removing any previously created metrics in this group. 
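+        // (e.g. if a task is restarted on the same worker, sensors such as "total-record-failures" registered
+        // by the previous instance would otherwise collide with the ones created below)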
+ metricGroup.close(); + + recordProcessingFailures = metricGroup.sensor("total-record-failures"); + recordProcessingFailures.add(metricGroup.metricName(registry.recordProcessingFailures), new CumulativeSum()); + + recordProcessingErrors = metricGroup.sensor("total-record-errors"); + recordProcessingErrors.add(metricGroup.metricName(registry.recordProcessingErrors), new CumulativeSum()); + + recordsSkipped = metricGroup.sensor("total-records-skipped"); + recordsSkipped.add(metricGroup.metricName(registry.recordsSkipped), new CumulativeSum()); + + retries = metricGroup.sensor("total-retries"); + retries.add(metricGroup.metricName(registry.retries), new CumulativeSum()); + + errorsLogged = metricGroup.sensor("total-errors-logged"); + errorsLogged.add(metricGroup.metricName(registry.errorsLogged), new CumulativeSum()); + + dlqProduceRequests = metricGroup.sensor("deadletterqueue-produce-requests"); + dlqProduceRequests.add(metricGroup.metricName(registry.dlqProduceRequests), new CumulativeSum()); + + dlqProduceFailures = metricGroup.sensor("deadletterqueue-produce-failures"); + dlqProduceFailures.add(metricGroup.metricName(registry.dlqProduceFailures), new CumulativeSum()); + + metricGroup.addValueMetric(registry.lastErrorTimestamp, now -> lastErrorTime); + } + + /** + * Increment the number of failed operations (retriable and non-retriable). + */ + public void recordFailure() { + recordProcessingFailures.record(); + } + + /** + * Increment the number of operations which could not be successfully executed. + */ + public void recordError() { + recordProcessingErrors.record(); + } + + /** + * Increment the number of records skipped. + */ + public void recordSkipped() { + recordsSkipped.record(); + } + + /** + * The number of retries made while executing operations. + */ + public void recordRetry() { + retries.record(); + } + + /** + * The number of errors logged by the {@link LogReporter}. + */ + public void recordErrorLogged() { + errorsLogged.record(); + } + + /** + * The number of produce requests to the {@link DeadLetterQueueReporter}. + */ + public void recordDeadLetterQueueProduceRequest() { + dlqProduceRequests.record(); + } + + /** + * The number of produce requests to the {@link DeadLetterQueueReporter} which failed to be successfully produced into Kafka. + */ + public void recordDeadLetterQueueProduceFailed() { + dlqProduceFailures.record(); + } + + /** + * Record the time of error. + */ + public void recordErrorTimestamp() { + this.lastErrorTime = time.milliseconds(); + } + + /** + * @return the metric group for this class. + */ + public ConnectMetrics.MetricGroup metricGroup() { + return metricGroup; + } +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/errors/ErrorReporter.java b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/errors/ErrorReporter.java new file mode 100644 index 0000000..f9bc2f2 --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/errors/ErrorReporter.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.runtime.errors; + +import org.apache.kafka.clients.producer.RecordMetadata; + +import java.util.concurrent.Future; + +/** + * Report an error using the information contained in the {@link ProcessingContext}. + */ +public interface ErrorReporter extends AutoCloseable { + + /** + * Report an error and return the producer future. + * + * @param context the processing context (cannot be null). + * @return future result from the producer sending a record to Kafka. + */ + Future report(ProcessingContext context); + + @Override + default void close() { } +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/errors/LogReporter.java b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/errors/LogReporter.java new file mode 100644 index 0000000..cf9db2c --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/errors/LogReporter.java @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.runtime.errors; + +import org.apache.kafka.clients.producer.RecordMetadata; +import org.apache.kafka.connect.runtime.ConnectorConfig; +import org.apache.kafka.connect.util.ConnectorTaskId; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Objects; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.Future; + +/** + * Writes errors and their context to application logs. + */ +public class LogReporter implements ErrorReporter { + + private static final Logger log = LoggerFactory.getLogger(LogReporter.class); + private static final Future COMPLETED = CompletableFuture.completedFuture(null); + + private final ConnectorTaskId id; + private final ConnectorConfig connConfig; + private final ErrorHandlingMetrics errorHandlingMetrics; + + public LogReporter(ConnectorTaskId id, ConnectorConfig connConfig, ErrorHandlingMetrics errorHandlingMetrics) { + Objects.requireNonNull(id); + Objects.requireNonNull(connConfig); + Objects.requireNonNull(errorHandlingMetrics); + + this.id = id; + this.connConfig = connConfig; + this.errorHandlingMetrics = errorHandlingMetrics; + } + + /** + * Log error context. + * + * @param context the processing context. 
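+     * @return a future that is already completed, since errors are written to the log synchronously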
+ */ + @Override + public Future report(ProcessingContext context) { + if (!connConfig.enableErrorLog()) { + return COMPLETED; + } + + if (!context.failed()) { + return COMPLETED; + } + + log.error(message(context), context.error()); + errorHandlingMetrics.recordErrorLogged(); + return COMPLETED; + } + + // Visible for testing + String message(ProcessingContext context) { + return String.format("Error encountered in task %s. %s", String.valueOf(id), + context.toString(connConfig.includeRecordDetailsInErrorLog())); + } + +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/errors/Operation.java b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/errors/Operation.java new file mode 100644 index 0000000..3e0f792 --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/errors/Operation.java @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.runtime.errors; + +import java.util.concurrent.Callable; + +/** + * A recoverable operation evaluated in the connector pipeline. + * + * @param return type of the result of the operation. + */ +public interface Operation extends Callable { + +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/errors/ProcessingContext.java b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/errors/ProcessingContext.java new file mode 100644 index 0000000..b49c93c --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/errors/ProcessingContext.java @@ -0,0 +1,252 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.connect.runtime.errors; + +import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.clients.producer.RecordMetadata; +import org.apache.kafka.common.record.TimestampType; +import org.apache.kafka.connect.errors.ConnectException; +import org.apache.kafka.connect.runtime.errors.WorkerErrantRecordReporter.ErrantRecordFuture; +import org.apache.kafka.connect.source.SourceRecord; + +import java.util.Collection; +import java.util.Collections; +import java.util.List; +import java.util.Objects; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.Future; +import java.util.stream.Collectors; + +/** + * Contains all the metadata related to the currently evaluating operation. Only one instance of this class is meant + * to exist per task in a JVM. + */ +class ProcessingContext implements AutoCloseable { + + private Collection reporters = Collections.emptyList(); + + private ConsumerRecord consumedMessage; + private SourceRecord sourceRecord; + + /** + * The following fields need to be reset every time a new record is seen. + */ + + private Stage position; + private Class klass; + private int attempt; + private Throwable error; + + /** + * Reset the internal fields before executing operations on a new record. + */ + private void reset() { + attempt = 0; + position = null; + klass = null; + error = null; + } + + /** + * Set the record consumed from Kafka in a sink connector. + * + * @param consumedMessage the record + */ + public void consumerRecord(ConsumerRecord consumedMessage) { + this.consumedMessage = consumedMessage; + reset(); + } + + /** + * @return the record consumed from Kafka. could be null + */ + public ConsumerRecord consumerRecord() { + return consumedMessage; + } + + /** + * @return the source record being processed. + */ + public SourceRecord sourceRecord() { + return sourceRecord; + } + + /** + * Set the source record being processed in the connect pipeline. + * + * @param record the source record + */ + public void sourceRecord(SourceRecord record) { + this.sourceRecord = record; + reset(); + } + + /** + * Set the stage in the connector pipeline which is currently executing. + * + * @param position the stage + */ + public void position(Stage position) { + this.position = position; + } + + /** + * @return the stage in the connector pipeline which is currently executing. + */ + public Stage stage() { + return position; + } + + /** + * @return the class which is going to execute the current operation. + */ + public Class executingClass() { + return klass; + } + + /** + * @param klass set the class which is currently executing. + */ + public void executingClass(Class klass) { + this.klass = klass; + } + + /** + * A helper method to set both the stage and the class. + * + * @param stage the stage + * @param klass the class which will execute the operation in this stage. + */ + public void currentContext(Stage stage, Class klass) { + position(stage); + executingClass(klass); + } + + /** + * Report errors. Should be called only if an error was encountered while executing the operation. 
+ * + * @return a errant record future that potentially aggregates the producer futures + */ + public Future report() { + if (reporters.size() == 1) { + return new ErrantRecordFuture(Collections.singletonList(reporters.iterator().next().report(this))); + } + + List> futures = reporters.stream() + .map(r -> r.report(this)) + .filter(f -> !f.isDone()) + .collect(Collectors.toList()); + if (futures.isEmpty()) { + return CompletableFuture.completedFuture(null); + } + return new ErrantRecordFuture(futures); + } + + @Override + public String toString() { + return toString(false); + } + + public String toString(boolean includeMessage) { + StringBuilder builder = new StringBuilder(); + builder.append("Executing stage '"); + builder.append(stage().name()); + builder.append("' with class '"); + builder.append(executingClass() == null ? "null" : executingClass().getName()); + builder.append('\''); + if (includeMessage && sourceRecord() != null) { + builder.append(", where source record is = "); + builder.append(sourceRecord()); + } else if (includeMessage && consumerRecord() != null) { + ConsumerRecord msg = consumerRecord(); + builder.append(", where consumed record is "); + builder.append("{topic='").append(msg.topic()).append('\''); + builder.append(", partition=").append(msg.partition()); + builder.append(", offset=").append(msg.offset()); + if (msg.timestampType() == TimestampType.CREATE_TIME || msg.timestampType() == TimestampType.LOG_APPEND_TIME) { + builder.append(", timestamp=").append(msg.timestamp()); + builder.append(", timestampType=").append(msg.timestampType()); + } + builder.append("}"); + } + builder.append('.'); + return builder.toString(); + } + + /** + * @param attempt the number of attempts made to execute the current operation. + */ + public void attempt(int attempt) { + this.attempt = attempt; + } + + /** + * @return the number of attempts made to execute the current operation. + */ + public int attempt() { + return attempt; + } + + /** + * @return the error (if any) which was encountered while processing the current stage. + */ + public Throwable error() { + return error; + } + + /** + * The error (if any) which was encountered while processing the current stage. + * + * @param error the error + */ + public void error(Throwable error) { + this.error = error; + } + + /** + * @return true, if the last operation encountered an error; false otherwise + */ + public boolean failed() { + return error() != null; + } + + /** + * Set the error reporters for this connector. + * + * @param reporters the error reporters (should not be null). + */ + public void reporters(Collection reporters) { + Objects.requireNonNull(reporters); + this.reporters = reporters; + } + + @Override + public void close() { + ConnectException e = null; + for (ErrorReporter reporter : reporters) { + try { + reporter.close(); + } catch (Throwable t) { + e = e != null ? e : new ConnectException("Failed to close all reporters"); + e.addSuppressed(t); + } + } + if (e != null) { + throw e; + } + } +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/errors/RetryWithToleranceOperator.java b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/errors/RetryWithToleranceOperator.java new file mode 100644 index 0000000..ce4c1e2 --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/errors/RetryWithToleranceOperator.java @@ -0,0 +1,314 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.runtime.errors; + +import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.common.config.ConfigException; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.connect.errors.ConnectException; +import org.apache.kafka.connect.errors.RetriableException; +import org.apache.kafka.connect.runtime.ConnectorConfig; +import org.apache.kafka.connect.source.SourceRecord; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.Future; +import java.util.concurrent.ThreadLocalRandom; + +/** + * Attempt to recover a failed operation with retries and tolerance limits. + *

+ * <p>
+ *
+ * A retry is attempted if the operation throws a {@link RetriableException}. Retries are accompanied by exponential backoffs, starting with
+ * {@link #RETRIES_DELAY_MIN_MS}, up to what is specified with {@link ConnectorConfig#errorMaxDelayInMillis()}.
+ * Including the first attempt and subsequent retries, the total time taken to evaluate the operation should be within
+ * {@link ConnectorConfig#errorRetryTimeout()} millis.
+ * <p>
+ *
+ * This executor will tolerate failures, as specified by {@link ConnectorConfig#errorToleranceType()}.
+ * For transformations and converters, all exceptions are tolerated. For other operations, only {@link RetriableException} are tolerated.
+ * <p>
+ *
+ * There are three outcomes to executing an operation. It might succeed, in which case the result is returned to the caller.
+ * If it fails, this class does one of two things: (1) if the failure occurred due to a tolerable exception, the
+ * appropriate error reason is set in the {@link ProcessingContext} and null is returned, or (2) if the exception is not
+ * tolerated, it is wrapped into a ConnectException and rethrown to the caller.
+ * <p>
+ *
+ * Instances of this class are thread safe.
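+ * <p>
+ *
+ * Rough worked example (hypothetical connector settings, for illustration only): with
+ * {@code errors.retry.timeout = 5000} and {@code errors.retry.delay.max.ms = 60000}, a persistently failing
+ * retriable operation is retried after backoffs of roughly 300, 600, 1200 and 2400 millis; once the 5 second
+ * retry timeout is spent the failure is either tolerated (the record is skipped) or rethrown, depending on
+ * {@code errors.tolerance}.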

            + */ +public class RetryWithToleranceOperator implements AutoCloseable { + + private static final Logger log = LoggerFactory.getLogger(RetryWithToleranceOperator.class); + + public static final long RETRIES_DELAY_MIN_MS = 300; + + private static final Map> TOLERABLE_EXCEPTIONS = new HashMap<>(); + static { + TOLERABLE_EXCEPTIONS.put(Stage.TRANSFORMATION, Exception.class); + TOLERABLE_EXCEPTIONS.put(Stage.HEADER_CONVERTER, Exception.class); + TOLERABLE_EXCEPTIONS.put(Stage.KEY_CONVERTER, Exception.class); + TOLERABLE_EXCEPTIONS.put(Stage.VALUE_CONVERTER, Exception.class); + } + + private final long errorRetryTimeout; + private final long errorMaxDelayInMillis; + private final ToleranceType errorToleranceType; + + private long totalFailures = 0; + private final Time time; + private ErrorHandlingMetrics errorHandlingMetrics; + + protected final ProcessingContext context; + + public RetryWithToleranceOperator(long errorRetryTimeout, long errorMaxDelayInMillis, + ToleranceType toleranceType, Time time) { + this(errorRetryTimeout, errorMaxDelayInMillis, toleranceType, time, new ProcessingContext()); + } + + RetryWithToleranceOperator(long errorRetryTimeout, long errorMaxDelayInMillis, + ToleranceType toleranceType, Time time, + ProcessingContext context) { + this.errorRetryTimeout = errorRetryTimeout; + this.errorMaxDelayInMillis = errorMaxDelayInMillis; + this.errorToleranceType = toleranceType; + this.time = time; + this.context = context; + } + + public synchronized Future executeFailed(Stage stage, Class executingClass, + ConsumerRecord consumerRecord, + Throwable error) { + + markAsFailed(); + context.consumerRecord(consumerRecord); + context.currentContext(stage, executingClass); + context.error(error); + errorHandlingMetrics.recordFailure(); + Future errantRecordFuture = context.report(); + if (!withinToleranceLimits()) { + errorHandlingMetrics.recordError(); + throw new ConnectException("Tolerance exceeded in error handler", error); + } + return errantRecordFuture; + } + + /** + * Execute the recoverable operation. If the operation is already in a failed state, then simply return + * with the existing failure. + * + * @param operation the recoverable operation + * @param return type of the result of the operation. + * @return result of the operation + */ + public synchronized V execute(Operation operation, Stage stage, Class executingClass) { + context.currentContext(stage, executingClass); + + if (context.failed()) { + log.debug("ProcessingContext is already in failed state. Ignoring requested operation."); + return null; + } + + try { + Class ex = TOLERABLE_EXCEPTIONS.getOrDefault(context.stage(), RetriableException.class); + return execAndHandleError(operation, ex); + } finally { + if (context.failed()) { + errorHandlingMetrics.recordError(); + context.report(); + } + } + } + + /** + * Attempt to execute an operation. Retry if a {@link RetriableException} is raised. Re-throw everything else. + * + * @param operation the operation to be executed. + * @param the return type of the result of the operation. + * @return the result of the operation. 
+ * @throws Exception rethrow if a non-retriable Exception is thrown by the operation + */ + protected V execAndRetry(Operation operation) throws Exception { + int attempt = 0; + long startTime = time.milliseconds(); + long deadline = startTime + errorRetryTimeout; + do { + try { + attempt++; + return operation.call(); + } catch (RetriableException e) { + log.trace("Caught a retriable exception while executing {} operation with {}", context.stage(), context.executingClass()); + errorHandlingMetrics.recordFailure(); + if (checkRetry(startTime)) { + backoff(attempt, deadline); + if (Thread.currentThread().isInterrupted()) { + log.trace("Thread was interrupted. Marking operation as failed."); + context.error(e); + return null; + } + errorHandlingMetrics.recordRetry(); + } else { + log.trace("Can't retry. start={}, attempt={}, deadline={}", startTime, attempt, deadline); + context.error(e); + return null; + } + } finally { + context.attempt(attempt); + } + } while (true); + } + + /** + * Execute a given operation multiple times (if needed), and tolerate certain exceptions. + * + * @param operation the operation to be executed. + * @param tolerated the class of exceptions which can be tolerated. + * @param The return type of the result of the operation. + * @return the result of the operation + */ + // Visible for testing + protected V execAndHandleError(Operation operation, Class tolerated) { + try { + V result = execAndRetry(operation); + if (context.failed()) { + markAsFailed(); + errorHandlingMetrics.recordSkipped(); + } + return result; + } catch (Exception e) { + errorHandlingMetrics.recordFailure(); + markAsFailed(); + context.error(e); + + if (!tolerated.isAssignableFrom(e.getClass())) { + throw new ConnectException("Unhandled exception in error handler", e); + } + + if (!withinToleranceLimits()) { + throw new ConnectException("Tolerance exceeded in error handler", e); + } + + errorHandlingMetrics.recordSkipped(); + return null; + } + } + + // Visible for testing + void markAsFailed() { + errorHandlingMetrics.recordErrorTimestamp(); + totalFailures++; + } + + @SuppressWarnings("fallthrough") + public synchronized boolean withinToleranceLimits() { + switch (errorToleranceType) { + case NONE: + if (totalFailures > 0) return false; + case ALL: + return true; + default: + throw new ConfigException("Unknown tolerance type: {}", errorToleranceType); + } + } + + // Visible for testing + boolean checkRetry(long startTime) { + return (time.milliseconds() - startTime) < errorRetryTimeout; + } + + // Visible for testing + void backoff(int attempt, long deadline) { + int numRetry = attempt - 1; + long delay = RETRIES_DELAY_MIN_MS << numRetry; + if (delay > errorMaxDelayInMillis) { + delay = ThreadLocalRandom.current().nextLong(errorMaxDelayInMillis); + } + if (delay + time.milliseconds() > deadline) { + delay = deadline - time.milliseconds(); + } + log.debug("Sleeping for {} millis", delay); + time.sleep(delay); + } + + public synchronized void metrics(ErrorHandlingMetrics errorHandlingMetrics) { + this.errorHandlingMetrics = errorHandlingMetrics; + } + + @Override + public String toString() { + return "RetryWithToleranceOperator{" + + "errorRetryTimeout=" + errorRetryTimeout + + ", errorMaxDelayInMillis=" + errorMaxDelayInMillis + + ", errorToleranceType=" + errorToleranceType + + ", totalFailures=" + totalFailures + + ", time=" + time + + ", context=" + context + + '}'; + } + + /** + * Set the error reporters for this connector. + * + * @param reporters the error reporters (should not be null). 
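To make the backoff arithmetic above easier to follow, here is a standalone sketch of the same calculation (illustrative only; the 10-second cap and 60-second deadline are assumed configuration values):

long retriesDelayMinMs = 300L;                               // RETRIES_DELAY_MIN_MS
long errorMaxDelayInMillis = 10_000L;                        // assumed errors.retry.delay.max.ms
long deadline = System.currentTimeMillis() + 60_000L;        // start time + errors.retry.timeout

for (int attempt = 1; attempt <= 6; attempt++) {
    long delay = retriesDelayMinMs << (attempt - 1);         // 300, 600, 1200, 2400, 4800, 9600 ms
    if (delay > errorMaxDelayInMillis) {
        // once the exponential delay exceeds the cap, fall back to a random delay below the cap
        delay = java.util.concurrent.ThreadLocalRandom.current().nextLong(errorMaxDelayInMillis);
    }
    if (delay + System.currentTimeMillis() > deadline) {
        delay = deadline - System.currentTimeMillis();       // never sleep past the retry deadline
    }
    System.out.printf("attempt %d -> sleep %d ms%n", attempt, delay);
}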
+ */ + public synchronized void reporters(List reporters) { + this.context.reporters(reporters); + } + + /** + * Set the source record being processed in the connect pipeline. + * + * @param preTransformRecord the source record + */ + public synchronized void sourceRecord(SourceRecord preTransformRecord) { + this.context.sourceRecord(preTransformRecord); + } + + /** + * Set the record consumed from Kafka in a sink connector. + * + * @param consumedMessage the record + */ + public synchronized void consumerRecord(ConsumerRecord consumedMessage) { + this.context.consumerRecord(consumedMessage); + } + + /** + * @return true, if the last operation encountered an error; false otherwise + */ + public synchronized boolean failed() { + return this.context.failed(); + } + + /** + * Returns the error encountered when processing the current stage. + * + * @return the error encountered when processing the current stage + */ + public synchronized Throwable error() { + return this.context.error(); + } + + @Override + public synchronized void close() { + this.context.close(); + } +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/errors/Stage.java b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/errors/Stage.java new file mode 100644 index 0000000..b9aa1f2 --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/errors/Stage.java @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.runtime.errors; + +/** + * A logical stage in a Connect pipeline. 
+ */ +public enum Stage { + + /** + * When calling the poll() method on a SourceConnector + */ + TASK_POLL, + + /** + * When calling the put() method on a SinkConnector + */ + TASK_PUT, + + /** + * When running any transformation operation on a record + */ + TRANSFORMATION, + + /** + * When using the key converter to serialize/deserialize keys in ConnectRecords + */ + KEY_CONVERTER, + + /** + * When using the value converter to serialize/deserialize values in ConnectRecords + */ + VALUE_CONVERTER, + + /** + * When using the header converter to serialize/deserialize headers in ConnectRecords + */ + HEADER_CONVERTER, + + /** + * When producing to Kafka topic + */ + KAFKA_PRODUCE, + + /** + * When consuming from a Kafka topic + */ + KAFKA_CONSUME +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/errors/ToleranceType.java b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/errors/ToleranceType.java new file mode 100644 index 0000000..dd40a60 --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/errors/ToleranceType.java @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.runtime.errors; + +import java.util.Locale; + +/** + * The different levels of error tolerance. + */ +public enum ToleranceType { + + /** + * Tolerate no errors. + */ + NONE, + + /** + * Tolerate all errors. + */ + ALL; + + public String value() { + return name().toLowerCase(Locale.ROOT); + } + +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/errors/WorkerErrantRecordReporter.java b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/errors/WorkerErrantRecordReporter.java new file mode 100644 index 0000000..ed48f79 --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/errors/WorkerErrantRecordReporter.java @@ -0,0 +1,196 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
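Before the implementation that follows, a short sketch of how a sink task reaches this reporter through the public API (illustrative; record and e are placeholders, and the reporter can be null on workers that predate this feature):

// Inside the task's own SinkTask#put(Collection<SinkRecord>) implementation:
ErrantRecordReporter reporter = context.errantRecordReporter();
if (reporter != null) {
    Future<Void> reported = reporter.report(record, e);  // asynchronous by default
    // reported.get();                                    // optionally block until the report completes
}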
+ */ +package org.apache.kafka.connect.runtime.errors; + +import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.clients.producer.RecordMetadata; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.header.internals.RecordHeaders; +import org.apache.kafka.connect.errors.ConnectException; +import org.apache.kafka.connect.header.Header; +import org.apache.kafka.connect.runtime.InternalSinkRecord; +import org.apache.kafka.connect.sink.ErrantRecordReporter; +import org.apache.kafka.connect.sink.SinkRecord; + +import org.apache.kafka.connect.sink.SinkTask; +import org.apache.kafka.connect.storage.Converter; +import org.apache.kafka.connect.storage.HeaderConverter; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; +import java.util.Optional; +import java.util.Objects; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import java.util.stream.Collectors; + +public class WorkerErrantRecordReporter implements ErrantRecordReporter { + + private static final Logger log = LoggerFactory.getLogger(WorkerErrantRecordReporter.class); + + private final RetryWithToleranceOperator retryWithToleranceOperator; + private final Converter keyConverter; + private final Converter valueConverter; + private final HeaderConverter headerConverter; + + // Visible for testing + protected final ConcurrentMap>> futures; + + public WorkerErrantRecordReporter( + RetryWithToleranceOperator retryWithToleranceOperator, + Converter keyConverter, + Converter valueConverter, + HeaderConverter headerConverter + ) { + this.retryWithToleranceOperator = retryWithToleranceOperator; + this.keyConverter = keyConverter; + this.valueConverter = valueConverter; + this.headerConverter = headerConverter; + this.futures = new ConcurrentHashMap<>(); + } + + @Override + public Future report(SinkRecord record, Throwable error) { + ConsumerRecord consumerRecord; + + // Most of the records will be an internal sink record, but the task could potentially + // report modified or new records, so handle both cases + if (record instanceof InternalSinkRecord) { + consumerRecord = ((InternalSinkRecord) record).originalRecord(); + } else { + // Generate a new consumer record from the modified sink record. We prefer + // to send the original consumer record (pre-transformed) to the DLQ, + // but in this case we don't have one and send the potentially transformed + // record instead + String topic = record.topic(); + byte[] key = keyConverter.fromConnectData(topic, record.keySchema(), record.key()); + byte[] value = valueConverter.fromConnectData(topic, + record.valueSchema(), record.value()); + + RecordHeaders headers = new RecordHeaders(); + if (record.headers() != null) { + for (Header header : record.headers()) { + String headerKey = header.key(); + byte[] rawHeader = headerConverter.fromConnectHeader(topic, headerKey, + header.schema(), header.value()); + headers.add(headerKey, rawHeader); + } + } + + int keyLength = key != null ? key.length : -1; + int valLength = value != null ? 
value.length : -1; + + consumerRecord = new ConsumerRecord<>(record.topic(), record.kafkaPartition(), + record.kafkaOffset(), record.timestamp(), record.timestampType(), keyLength, + valLength, key, value, headers, Optional.empty()); + } + + Future future = retryWithToleranceOperator.executeFailed(Stage.TASK_PUT, SinkTask.class, consumerRecord, error); + + if (!future.isDone()) { + TopicPartition partition = new TopicPartition(consumerRecord.topic(), consumerRecord.partition()); + futures.computeIfAbsent(partition, p -> new ArrayList<>()).add(future); + } + return future; + } + + /** + * Awaits the completion of all error reports for a given set of topic partitions + * @param topicPartitions the topic partitions to await reporter completion for + */ + public void awaitFutures(Collection topicPartitions) { + futuresFor(topicPartitions).forEach(future -> { + try { + future.get(); + } catch (InterruptedException | ExecutionException e) { + log.error("Encountered an error while awaiting an errant record future's completion.", e); + throw new ConnectException(e); + } + }); + } + + /** + * Cancels all active error reports for a given set of topic partitions + * @param topicPartitions the topic partitions to cancel reporting for + */ + public void cancelFutures(Collection topicPartitions) { + futuresFor(topicPartitions).forEach(future -> { + try { + future.cancel(true); + } catch (Exception e) { + log.error("Encountered an error while cancelling an errant record future", e); + // No need to throw the exception here; it's enough to log an error message + } + }); + } + + // Removes and returns all futures for the given topic partitions from the set of currently-active futures + private Collection> futuresFor(Collection topicPartitions) { + return topicPartitions.stream() + .map(futures::remove) + .filter(Objects::nonNull) + .flatMap(List::stream) + .collect(Collectors.toList()); + } + + /** + * Wrapper class to aggregate producer futures and abstract away the record metadata from the + * Connect user. + */ + public static class ErrantRecordFuture implements Future { + + private final List> futures; + + public ErrantRecordFuture(List> producerFutures) { + futures = producerFutures; + } + + public boolean cancel(boolean mayInterruptIfRunning) { + throw new UnsupportedOperationException("Reporting an errant record cannot be cancelled."); + } + + public boolean isCancelled() { + return false; + } + + public boolean isDone() { + return futures.stream().allMatch(Future::isDone); + } + + public Void get() throws InterruptedException, ExecutionException { + for (Future future: futures) { + future.get(); + } + return null; + } + + public Void get(long timeout, TimeUnit unit) + throws InterruptedException, ExecutionException, TimeoutException { + for (Future future: futures) { + future.get(timeout, unit); + } + return null; + } + } +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/health/ConnectClusterDetailsImpl.java b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/health/ConnectClusterDetailsImpl.java new file mode 100644 index 0000000..09f09bd --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/health/ConnectClusterDetailsImpl.java @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.connect.runtime.health; + +import org.apache.kafka.connect.health.ConnectClusterDetails; + +public class ConnectClusterDetailsImpl implements ConnectClusterDetails { + + private final String kafkaClusterId; + + public ConnectClusterDetailsImpl(String kafkaClusterId) { + this.kafkaClusterId = kafkaClusterId; + } + + @Override + public String kafkaClusterId() { + return kafkaClusterId; + } +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/health/ConnectClusterStateImpl.java b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/health/ConnectClusterStateImpl.java new file mode 100644 index 0000000..6b7285d --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/health/ConnectClusterStateImpl.java @@ -0,0 +1,115 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
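To show where these implementations surface, here is a hypothetical REST extension (the class name is invented) that reads cluster details and connector health through the public ConnectClusterState API implemented below:

import org.apache.kafka.connect.health.ConnectClusterState;
import org.apache.kafka.connect.health.ConnectorHealth;
import org.apache.kafka.connect.rest.ConnectRestExtension;
import org.apache.kafka.connect.rest.ConnectRestExtensionContext;

import java.util.Map;

public class StatusLoggingExtension implements ConnectRestExtension {

    @Override
    public void register(ConnectRestExtensionContext restPluginContext) {
        ConnectClusterState state = restPluginContext.clusterState();
        String kafkaClusterId = state.clusterDetails().kafkaClusterId();
        for (String connector : state.connectors()) {
            ConnectorHealth health = state.connectorHealth(connector);
            System.out.println(kafkaClusterId + " / " + connector + ": " + health.connectorState().state());
        }
    }

    @Override
    public void configure(Map<String, ?> configs) { }

    @Override
    public String version() { return "0.1.0"; }

    @Override
    public void close() { }
}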
+ */ + +package org.apache.kafka.connect.runtime.health; + +import org.apache.kafka.connect.errors.ConnectException; +import org.apache.kafka.connect.health.ConnectClusterDetails; +import org.apache.kafka.connect.health.ConnectClusterState; +import org.apache.kafka.connect.health.ConnectorHealth; +import org.apache.kafka.connect.health.ConnectorState; +import org.apache.kafka.connect.health.ConnectorType; +import org.apache.kafka.connect.health.TaskState; +import org.apache.kafka.connect.runtime.Herder; +import org.apache.kafka.connect.runtime.rest.entities.ConnectorStateInfo; +import org.apache.kafka.connect.util.FutureCallback; + +import java.util.Collection; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + +public class ConnectClusterStateImpl implements ConnectClusterState { + + private final long herderRequestTimeoutMs; + private final ConnectClusterDetails clusterDetails; + private final Herder herder; + + public ConnectClusterStateImpl( + long connectorsTimeoutMs, + ConnectClusterDetails clusterDetails, + Herder herder + ) { + this.herderRequestTimeoutMs = connectorsTimeoutMs; + this.clusterDetails = clusterDetails; + this.herder = herder; + } + + @Override + public Collection connectors() { + FutureCallback> connectorsCallback = new FutureCallback<>(); + herder.connectors(connectorsCallback); + try { + return connectorsCallback.get(herderRequestTimeoutMs, TimeUnit.MILLISECONDS); + } catch (InterruptedException | ExecutionException | TimeoutException e) { + throw new ConnectException("Failed to retrieve list of connectors", e); + } + } + + @Override + public ConnectorHealth connectorHealth(String connName) { + ConnectorStateInfo state = herder.connectorStatus(connName); + ConnectorState connectorState = new ConnectorState( + state.connector().state(), + state.connector().workerId(), + state.connector().trace() + ); + Map taskStates = taskStates(state.tasks()); + ConnectorHealth connectorHealth = new ConnectorHealth( + connName, + connectorState, + taskStates, + ConnectorType.valueOf(state.type().name()) + ); + return connectorHealth; + } + + @Override + public Map connectorConfig(String connName) { + FutureCallback> connectorConfigCallback = new FutureCallback<>(); + herder.connectorConfig(connName, connectorConfigCallback); + try { + return new HashMap<>(connectorConfigCallback.get(herderRequestTimeoutMs, TimeUnit.MILLISECONDS)); + } catch (InterruptedException | ExecutionException | TimeoutException e) { + throw new ConnectException( + String.format("Failed to retrieve configuration for connector '%s'", connName), + e + ); + } + } + + @Override + public ConnectClusterDetails clusterDetails() { + return clusterDetails; + } + + private Map taskStates(List states) { + + Map taskStates = new HashMap<>(); + + for (ConnectorStateInfo.TaskState state : states) { + taskStates.put( + state.id(), + new TaskState(state.id(), state.state(), state.workerId(), state.trace()) + ); + } + return taskStates; + } +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/isolation/DelegatingClassLoader.java b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/isolation/DelegatingClassLoader.java new file mode 100644 index 0000000..4a1de5e --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/isolation/DelegatingClassLoader.java @@ -0,0 +1,520 @@ +/* + * Licensed to the Apache 
Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.runtime.isolation; + +import org.apache.kafka.common.config.provider.ConfigProvider; +import org.apache.kafka.connect.components.Versioned; +import org.apache.kafka.connect.connector.Connector; +import org.apache.kafka.connect.connector.policy.ConnectorClientConfigOverridePolicy; +import org.apache.kafka.connect.rest.ConnectRestExtension; +import org.apache.kafka.connect.storage.Converter; +import org.apache.kafka.connect.storage.HeaderConverter; +import org.apache.kafka.connect.transforms.Transformation; +import org.apache.kafka.connect.transforms.predicates.Predicate; +import org.reflections.Configuration; +import org.reflections.Reflections; +import org.reflections.ReflectionsException; +import org.reflections.scanners.SubTypesScanner; +import org.reflections.util.ClasspathHelper; +import org.reflections.util.ConfigurationBuilder; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.net.MalformedURLException; +import java.net.URL; +import java.net.URLClassLoader; +import java.nio.file.Files; +import java.nio.file.InvalidPathException; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.security.AccessController; +import java.security.PrivilegedAction; +import java.sql.Driver; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.Enumeration; +import java.util.Iterator; +import java.util.List; +import java.util.ServiceLoader; +import java.util.Set; +import java.util.SortedMap; +import java.util.SortedSet; +import java.util.TreeMap; +import java.util.TreeSet; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; +import java.util.stream.Collectors; + +/** + * A custom classloader dedicated to loading Connect plugin classes in classloading isolation. + * + *

            + * Under the current scheme for classloading isolation in Connect, the delegating classloader loads + * plugin classes that it finds in its child plugin classloaders. For classes that are not plugins, + * this delegating classloader delegates its loading to its parent. This makes this classloader a + * child-first classloader. + *
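A sketch of how the runtime typically uses this delegation (not code from this file; delegatingLoader stands for the worker's DelegatingClassLoader instance, and "FileStreamSource" for an alias registered by addAliases() further below):

ClassLoader connectorLoader = delegatingLoader.connectorLoader("FileStreamSource");
ClassLoader previous = Plugins.compareAndSwapLoaders(connectorLoader);
try {
    // With the plugin's loader installed as the thread context classloader, classes and
    // resources needed by the connector resolve child-first against its own PluginClassLoader.
} finally {
    Plugins.compareAndSwapLoaders(previous);
}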

            + * This class is thread-safe and parallel capable. + */ +public class DelegatingClassLoader extends URLClassLoader { + private static final Logger log = LoggerFactory.getLogger(DelegatingClassLoader.class); + private static final String CLASSPATH_NAME = "classpath"; + private static final String UNDEFINED_VERSION = "undefined"; + + private final ConcurrentMap, ClassLoader>> pluginLoaders; + private final ConcurrentMap aliases; + private final SortedSet> connectors; + private final SortedSet> converters; + private final SortedSet> headerConverters; + private final SortedSet>> transformations; + private final SortedSet>> predicates; + private final SortedSet> configProviders; + private final SortedSet> restExtensions; + private final SortedSet> connectorClientConfigPolicies; + private final List pluginPaths; + + private static final String MANIFEST_PREFIX = "META-INF/services/"; + private static final Class[] SERVICE_LOADER_PLUGINS = new Class[] {ConnectRestExtension.class, ConfigProvider.class}; + private static final Set PLUGIN_MANIFEST_FILES = + Arrays.stream(SERVICE_LOADER_PLUGINS).map(serviceLoaderPlugin -> MANIFEST_PREFIX + serviceLoaderPlugin.getName()) + .collect(Collectors.toSet()); + + // Although this classloader does not load classes directly but rather delegates loading to a + // PluginClassLoader or its parent through its base class, because of the use of inheritance in + // in the latter case, this classloader needs to also be declared as parallel capable to use + // fine-grain locking when loading classes. + static { + ClassLoader.registerAsParallelCapable(); + } + + public DelegatingClassLoader(List pluginPaths, ClassLoader parent) { + super(new URL[0], parent); + this.pluginPaths = pluginPaths; + this.pluginLoaders = new ConcurrentHashMap<>(); + this.aliases = new ConcurrentHashMap<>(); + this.connectors = new TreeSet<>(); + this.converters = new TreeSet<>(); + this.headerConverters = new TreeSet<>(); + this.transformations = new TreeSet<>(); + this.predicates = new TreeSet<>(); + this.configProviders = new TreeSet<>(); + this.restExtensions = new TreeSet<>(); + this.connectorClientConfigPolicies = new TreeSet<>(); + } + + public DelegatingClassLoader(List pluginPaths) { + // Use as parent the classloader that loaded this class. In most cases this will be the + // System classloader. But this choice here provides additional flexibility in managed + // environments that control classloading differently (OSGi, Spring and others) and don't + // depend on the System classloader to load Connect's classes. + this(pluginPaths, DelegatingClassLoader.class.getClassLoader()); + } + + public Set> connectors() { + return connectors; + } + + public Set> converters() { + return converters; + } + + public Set> headerConverters() { + return headerConverters; + } + + public Set>> transformations() { + return transformations; + } + + public Set>> predicates() { + return predicates; + } + + public Set> configProviders() { + return configProviders; + } + + public Set> restExtensions() { + return restExtensions; + } + + public Set> connectorClientConfigPolicies() { + return connectorClientConfigPolicies; + } + + /** + * Retrieve the PluginClassLoader associated with a plugin class + * @param name The fully qualified class name of the plugin + * @return the PluginClassLoader that should be used to load this, or null if the plugin is not isolated. 
+ */ + public PluginClassLoader pluginClassLoader(String name) { + if (!PluginUtils.shouldLoadInIsolation(name)) { + return null; + } + SortedMap, ClassLoader> inner = pluginLoaders.get(name); + if (inner == null) { + return null; + } + ClassLoader pluginLoader = inner.get(inner.lastKey()); + return pluginLoader instanceof PluginClassLoader + ? (PluginClassLoader) pluginLoader + : null; + } + + public ClassLoader connectorLoader(Connector connector) { + return connectorLoader(connector.getClass().getName()); + } + + public ClassLoader connectorLoader(String connectorClassOrAlias) { + String fullName = aliases.containsKey(connectorClassOrAlias) + ? aliases.get(connectorClassOrAlias) + : connectorClassOrAlias; + ClassLoader classLoader = pluginClassLoader(fullName); + if (classLoader == null) classLoader = this; + log.debug( + "Getting plugin class loader: '{}' for connector: {}", + classLoader, + connectorClassOrAlias + ); + return classLoader; + } + + protected PluginClassLoader newPluginClassLoader( + final URL pluginLocation, + final URL[] urls, + final ClassLoader parent + ) { + return AccessController.doPrivileged( + (PrivilegedAction) () -> new PluginClassLoader(pluginLocation, urls, parent) + ); + } + + private void addPlugins(Collection> plugins, ClassLoader loader) { + for (PluginDesc plugin : plugins) { + String pluginClassName = plugin.className(); + SortedMap, ClassLoader> inner = pluginLoaders.get(pluginClassName); + if (inner == null) { + inner = new TreeMap<>(); + pluginLoaders.put(pluginClassName, inner); + // TODO: once versioning is enabled this line should be moved outside this if branch + log.info("Added plugin '{}'", pluginClassName); + } + inner.put(plugin, loader); + } + } + + protected void initLoaders() { + for (String configPath : pluginPaths) { + initPluginLoader(configPath); + } + // Finally add parent/system loader. + initPluginLoader(CLASSPATH_NAME); + addAllAliases(); + } + + private void initPluginLoader(String path) { + try { + if (CLASSPATH_NAME.equals(path)) { + scanUrlsAndAddPlugins( + getParent(), + ClasspathHelper.forJavaClassPath().toArray(new URL[0]), + null + ); + } else { + Path pluginPath = Paths.get(path).toAbsolutePath(); + // Update for exception handling + path = pluginPath.toString(); + // Currently 'plugin.paths' property is a list of top-level directories + // containing plugins + if (Files.isDirectory(pluginPath)) { + for (Path pluginLocation : PluginUtils.pluginLocations(pluginPath)) { + registerPlugin(pluginLocation); + } + } else if (PluginUtils.isArchive(pluginPath)) { + registerPlugin(pluginPath); + } + } + } catch (InvalidPathException | MalformedURLException e) { + log.error("Invalid path in plugin path: {}. Ignoring.", path, e); + } catch (IOException e) { + log.error("Could not get listing for plugin path: {}. Ignoring.", path, e); + } catch (ReflectiveOperationException e) { + log.error("Could not instantiate plugins in: {}. 
Ignoring: {}", path, e); + } + } + + private void registerPlugin(Path pluginLocation) + throws IOException, ReflectiveOperationException { + log.info("Loading plugin from: {}", pluginLocation); + List pluginUrls = new ArrayList<>(); + for (Path path : PluginUtils.pluginUrls(pluginLocation)) { + pluginUrls.add(path.toUri().toURL()); + } + URL[] urls = pluginUrls.toArray(new URL[0]); + if (log.isDebugEnabled()) { + log.debug("Loading plugin urls: {}", Arrays.toString(urls)); + } + PluginClassLoader loader = newPluginClassLoader( + pluginLocation.toUri().toURL(), + urls, + this + ); + scanUrlsAndAddPlugins(loader, urls, pluginLocation); + } + + private void scanUrlsAndAddPlugins( + ClassLoader loader, + URL[] urls, + Path pluginLocation + ) throws ReflectiveOperationException { + PluginScanResult plugins = scanPluginPath(loader, urls); + log.info("Registered loader: {}", loader); + if (!plugins.isEmpty()) { + addPlugins(plugins.connectors(), loader); + connectors.addAll(plugins.connectors()); + addPlugins(plugins.converters(), loader); + converters.addAll(plugins.converters()); + addPlugins(plugins.headerConverters(), loader); + headerConverters.addAll(plugins.headerConverters()); + addPlugins(plugins.transformations(), loader); + transformations.addAll(plugins.transformations()); + addPlugins(plugins.predicates(), loader); + predicates.addAll(plugins.predicates()); + addPlugins(plugins.configProviders(), loader); + configProviders.addAll(plugins.configProviders()); + addPlugins(plugins.restExtensions(), loader); + restExtensions.addAll(plugins.restExtensions()); + addPlugins(plugins.connectorClientConfigPolicies(), loader); + connectorClientConfigPolicies.addAll(plugins.connectorClientConfigPolicies()); + } + + loadJdbcDrivers(loader); + } + + private void loadJdbcDrivers(final ClassLoader loader) { + // Apply here what java.sql.DriverManager does to discover and register classes + // implementing the java.sql.Driver interface. 
+ AccessController.doPrivileged( + (PrivilegedAction) () -> { + ServiceLoader loadedDrivers = ServiceLoader.load( + Driver.class, + loader + ); + Iterator driversIterator = loadedDrivers.iterator(); + try { + while (driversIterator.hasNext()) { + Driver driver = driversIterator.next(); + log.debug( + "Registered java.sql.Driver: {} to java.sql.DriverManager", + driver + ); + } + } catch (Throwable t) { + log.debug( + "Ignoring java.sql.Driver classes listed in resources but not" + + " present in class loader's classpath: ", + t + ); + } + return null; + } + ); + } + + private PluginScanResult scanPluginPath( + ClassLoader loader, + URL[] urls + ) throws ReflectiveOperationException { + ConfigurationBuilder builder = new ConfigurationBuilder(); + builder.setClassLoaders(new ClassLoader[]{loader}); + builder.addUrls(urls); + builder.setScanners(new SubTypesScanner()); + builder.useParallelExecutor(); + Reflections reflections = new InternalReflections(builder); + + return new PluginScanResult( + getPluginDesc(reflections, Connector.class, loader), + getPluginDesc(reflections, Converter.class, loader), + getPluginDesc(reflections, HeaderConverter.class, loader), + getTransformationPluginDesc(loader, reflections), + getPredicatePluginDesc(loader, reflections), + getServiceLoaderPluginDesc(ConfigProvider.class, loader), + getServiceLoaderPluginDesc(ConnectRestExtension.class, loader), + getServiceLoaderPluginDesc(ConnectorClientConfigOverridePolicy.class, loader) + ); + } + + @SuppressWarnings({"unchecked"}) + private Collection>> getPredicatePluginDesc(ClassLoader loader, Reflections reflections) throws ReflectiveOperationException { + return (Collection>>) (Collection) getPluginDesc(reflections, Predicate.class, loader); + } + + @SuppressWarnings({"unchecked"}) + private Collection>> getTransformationPluginDesc(ClassLoader loader, Reflections reflections) throws ReflectiveOperationException { + return (Collection>>) (Collection) getPluginDesc(reflections, Transformation.class, loader); + } + + private Collection> getPluginDesc( + Reflections reflections, + Class klass, + ClassLoader loader + ) throws ReflectiveOperationException { + Set> plugins; + try { + plugins = reflections.getSubTypesOf(klass); + } catch (ReflectionsException e) { + log.debug("Reflections scanner could not find any classes for URLs: " + + reflections.getConfiguration().getUrls(), e); + return Collections.emptyList(); + } + + Collection> result = new ArrayList<>(); + for (Class plugin : plugins) { + if (PluginUtils.isConcrete(plugin)) { + result.add(pluginDesc(plugin, versionFor(plugin), loader)); + } else { + log.debug("Skipping {} as it is not concrete implementation", plugin); + } + } + return result; + } + + @SuppressWarnings({"rawtypes", "unchecked"}) + private PluginDesc pluginDesc(Class plugin, String version, ClassLoader loader) { + return new PluginDesc(plugin, version, loader); + } + + @SuppressWarnings("unchecked") + private Collection> getServiceLoaderPluginDesc(Class klass, ClassLoader loader) { + ClassLoader savedLoader = Plugins.compareAndSwapLoaders(loader); + Collection> result = new ArrayList<>(); + try { + ServiceLoader serviceLoader = ServiceLoader.load(klass, loader); + for (T pluginImpl : serviceLoader) { + result.add(pluginDesc((Class) pluginImpl.getClass(), + versionFor(pluginImpl), loader)); + } + } finally { + Plugins.compareAndSwapLoaders(savedLoader); + } + return result; + } + + private static String versionFor(T pluginImpl) { + return pluginImpl instanceof Versioned ? 
((Versioned) pluginImpl).version() : UNDEFINED_VERSION; + } + + private static String versionFor(Class pluginKlass) throws ReflectiveOperationException { + // Temporary workaround until all the plugins are versioned. + return Connector.class.isAssignableFrom(pluginKlass) ? + versionFor(pluginKlass.getDeclaredConstructor().newInstance()) : UNDEFINED_VERSION; + } + + @Override + protected Class loadClass(String name, boolean resolve) throws ClassNotFoundException { + String fullName = aliases.containsKey(name) ? aliases.get(name) : name; + PluginClassLoader pluginLoader = pluginClassLoader(fullName); + if (pluginLoader != null) { + log.trace("Retrieving loaded class '{}' from '{}'", fullName, pluginLoader); + return pluginLoader.loadClass(fullName, resolve); + } + + return super.loadClass(fullName, resolve); + } + + private void addAllAliases() { + addAliases(connectors); + addAliases(converters); + addAliases(headerConverters); + addAliases(transformations); + addAliases(predicates); + addAliases(restExtensions); + addAliases(connectorClientConfigPolicies); + } + + private void addAliases(Collection> plugins) { + for (PluginDesc plugin : plugins) { + if (PluginUtils.isAliasUnique(plugin, plugins)) { + String simple = PluginUtils.simpleName(plugin); + String pruned = PluginUtils.prunedName(plugin); + aliases.put(simple, plugin.className()); + if (simple.equals(pruned)) { + log.info("Added alias '{}' to plugin '{}'", simple, plugin.className()); + } else { + aliases.put(pruned, plugin.className()); + log.info( + "Added aliases '{}' and '{}' to plugin '{}'", + simple, + pruned, + plugin.className() + ); + } + } + } + } + + private static class InternalReflections extends Reflections { + + public InternalReflections(Configuration configuration) { + super(configuration); + } + + // When Reflections is used for parallel scans, it has a bug where it propagates ReflectionsException + // as RuntimeException. Override the scan behavior to emulate the singled-threaded logic. + @Override + protected void scan(URL url) { + try { + super.scan(url); + } catch (ReflectionsException e) { + Logger log = Reflections.log; + if (log != null && log.isWarnEnabled()) { + log.warn("could not create Vfs.Dir from url. ignoring the exception and continuing", e); + } + } + } + } + + @Override + public URL getResource(String name) { + if (serviceLoaderManifestForPlugin(name)) { + // Default implementation of getResource searches the parent class loader and if not available/found, its own URL paths. + // This will enable thePluginClassLoader to limit its resource search only to its own URL paths. + return null; + } else { + return super.getResource(name); + } + } + + @Override + public Enumeration getResources(String name) throws IOException { + if (serviceLoaderManifestForPlugin(name)) { + // Default implementation of getResources searches the parent class loader and and also its own URL paths. This will enable the + // PluginClassLoader to limit its resource search to only its own URL paths. 
+ return null; + } else { + return super.getResources(name); + } + } + + //Visible for testing + static boolean serviceLoaderManifestForPlugin(String name) { + return PLUGIN_MANIFEST_FILES.contains(name); + } +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/isolation/PluginClassLoader.java b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/isolation/PluginClassLoader.java new file mode 100644 index 0000000..bc0df79 --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/isolation/PluginClassLoader.java @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.runtime.isolation; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.net.URL; +import java.net.URLClassLoader; + +/** + * A custom classloader dedicated to loading Connect plugin classes in classloading isolation. + *

            + * Under the current scheme for classloading isolation in Connect, a plugin classloader loads the + * classes that it finds in its urls. For classes that are either not found or are not supposed to + * be loaded in isolation, this plugin classloader delegates their loading to its parent. This makes + * this classloader a child-first classloader. + *
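For example (illustrative only, using classes from Connect's own modules):

// Child-first resolution order for a plugin classloader:
//   loadClass("org.apache.kafka.connect.json.JsonConverter")
//       -> matches PluginUtils.shouldLoadInIsolation(), so it is searched in this loader's own URLs first
//   loadClass("org.apache.kafka.connect.data.Schema")
//       -> part of the Connect API and excluded from isolation, so loading is delegated to the parent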

            + * This class is thread-safe and parallel capable. + */ +public class PluginClassLoader extends URLClassLoader { + private static final Logger log = LoggerFactory.getLogger(PluginClassLoader.class); + private final URL pluginLocation; + + static { + ClassLoader.registerAsParallelCapable(); + } + + /** + * Constructor that accepts a specific classloader as parent. + * + * @param pluginLocation the top-level location of the plugin to be loaded in isolation by this + * classloader. + * @param urls the list of urls from which to load classes and resources for this plugin. + * @param parent the parent classloader to be used for delegation for classes that were + * not found or should not be loaded in isolation by this classloader. + */ + public PluginClassLoader(URL pluginLocation, URL[] urls, ClassLoader parent) { + super(urls, parent); + this.pluginLocation = pluginLocation; + } + + /** + * Constructor that defines the system classloader as parent of this plugin classloader. + * + * @param pluginLocation the top-level location of the plugin to be loaded in isolation by this + * classloader. + * @param urls the list of urls from which to load classes and resources for this plugin. + */ + public PluginClassLoader(URL pluginLocation, URL[] urls) { + super(urls); + this.pluginLocation = pluginLocation; + } + + /** + * Returns the top-level location of the classes and dependencies required by the plugin that + * is loaded by this classloader. + * + * @return the plugin location. + */ + public String location() { + return pluginLocation.toString(); + } + + @Override + public String toString() { + return "PluginClassLoader{pluginLocation=" + pluginLocation + "}"; + } + + // This method needs to be thread-safe because it is supposed to be called by multiple + // Connect tasks. While findClass is thread-safe, defineClass called within loadClass of the + // base method is not. More on multithreaded classloaders in: + // https://docs.oracle.com/javase/7/docs/technotes/guides/lang/cl-mt.html + @Override + protected Class loadClass(String name, boolean resolve) throws ClassNotFoundException { + synchronized (getClassLoadingLock(name)) { + Class klass = findLoadedClass(name); + if (klass == null) { + try { + if (PluginUtils.shouldLoadInIsolation(name)) { + klass = findClass(name); + } + } catch (ClassNotFoundException e) { + // Not found in loader's path. Search in parents. + log.trace("Class '{}' not found. Delegating to parent", name); + } + } + if (klass == null) { + klass = super.loadClass(name, false); + } + if (resolve) { + resolveClass(klass); + } + return klass; + } + } +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/isolation/PluginDesc.java b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/isolation/PluginDesc.java new file mode 100644 index 0000000..62a7d6c --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/isolation/PluginDesc.java @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.runtime.isolation; + +import com.fasterxml.jackson.annotation.JsonProperty; +import org.apache.maven.artifact.versioning.DefaultArtifactVersion; + +import java.util.Objects; + +public class PluginDesc implements Comparable> { + private final Class klass; + private final String name; + private final String version; + private final DefaultArtifactVersion encodedVersion; + private final PluginType type; + private final String typeName; + private final String location; + + public PluginDesc(Class klass, String version, ClassLoader loader) { + this.klass = klass; + this.name = klass.getName(); + this.version = version != null ? version : "null"; + this.encodedVersion = new DefaultArtifactVersion(this.version); + this.type = PluginType.from(klass); + this.typeName = type.toString(); + this.location = loader instanceof PluginClassLoader + ? ((PluginClassLoader) loader).location() + : "classpath"; + } + + @Override + public String toString() { + return "PluginDesc{" + + "klass=" + klass + + ", name='" + name + '\'' + + ", version='" + version + '\'' + + ", encodedVersion=" + encodedVersion + + ", type=" + type + + ", typeName='" + typeName + '\'' + + ", location='" + location + '\'' + + '}'; + } + + public Class pluginClass() { + return klass; + } + + @JsonProperty("class") + public String className() { + return name; + } + + @JsonProperty("version") + public String version() { + return version; + } + + public PluginType type() { + return type; + } + + @JsonProperty("type") + public String typeName() { + return typeName; + } + + @JsonProperty("location") + public String location() { + return location; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (!(o instanceof PluginDesc)) { + return false; + } + PluginDesc that = (PluginDesc) o; + return Objects.equals(klass, that.klass) && + Objects.equals(version, that.version) && + type == that.type; + } + + @Override + public int hashCode() { + return Objects.hash(klass, version, type); + } + + @Override + public int compareTo(PluginDesc other) { + int nameComp = name.compareTo(other.name); + return nameComp != 0 ? nameComp : encodedVersion.compareTo(other.encodedVersion); + } +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/isolation/PluginScanResult.java b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/isolation/PluginScanResult.java new file mode 100644 index 0000000..e98945e --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/isolation/PluginScanResult.java @@ -0,0 +1,106 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.runtime.isolation; + +import org.apache.kafka.common.config.provider.ConfigProvider; +import org.apache.kafka.connect.connector.Connector; +import org.apache.kafka.connect.connector.policy.ConnectorClientConfigOverridePolicy; +import org.apache.kafka.connect.rest.ConnectRestExtension; +import org.apache.kafka.connect.storage.Converter; +import org.apache.kafka.connect.storage.HeaderConverter; +import org.apache.kafka.connect.transforms.Transformation; +import org.apache.kafka.connect.transforms.predicates.Predicate; + +import java.util.Arrays; +import java.util.Collection; +import java.util.List; + +public class PluginScanResult { + private final Collection> connectors; + private final Collection> converters; + private final Collection> headerConverters; + private final Collection>> transformations; + private final Collection>> predicates; + private final Collection> configProviders; + private final Collection> restExtensions; + private final Collection> connectorClientConfigPolicies; + + private final List> allPlugins; + + public PluginScanResult( + Collection> connectors, + Collection> converters, + Collection> headerConverters, + Collection>> transformations, + Collection>> predicates, + Collection> configProviders, + Collection> restExtensions, + Collection> connectorClientConfigPolicies + ) { + this.connectors = connectors; + this.converters = converters; + this.headerConverters = headerConverters; + this.transformations = transformations; + this.predicates = predicates; + this.configProviders = configProviders; + this.restExtensions = restExtensions; + this.connectorClientConfigPolicies = connectorClientConfigPolicies; + this.allPlugins = + Arrays.asList(connectors, converters, headerConverters, transformations, configProviders, + connectorClientConfigPolicies); + } + + public Collection> connectors() { + return connectors; + } + + public Collection> converters() { + return converters; + } + + public Collection> headerConverters() { + return headerConverters; + } + + public Collection>> transformations() { + return transformations; + } + + public Collection>> predicates() { + return predicates; + } + + public Collection> configProviders() { + return configProviders; + } + + public Collection> restExtensions() { + return restExtensions; + } + + public Collection> connectorClientConfigPolicies() { + return connectorClientConfigPolicies; + } + + public boolean isEmpty() { + boolean isEmpty = true; + for (Collection plugins : allPlugins) { + isEmpty = isEmpty && plugins.isEmpty(); + } + return isEmpty; + } +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/isolation/PluginType.java b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/isolation/PluginType.java new file mode 100644 index 0000000..8b42f59 --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/isolation/PluginType.java @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.runtime.isolation; + +import org.apache.kafka.common.config.provider.ConfigProvider; +import org.apache.kafka.connect.connector.Connector; +import org.apache.kafka.connect.connector.policy.ConnectorClientConfigOverridePolicy; +import org.apache.kafka.connect.rest.ConnectRestExtension; +import org.apache.kafka.connect.sink.SinkConnector; +import org.apache.kafka.connect.source.SourceConnector; +import org.apache.kafka.connect.storage.Converter; +import org.apache.kafka.connect.transforms.Transformation; + +import java.util.Locale; + +public enum PluginType { + SOURCE(SourceConnector.class), + SINK(SinkConnector.class), + CONNECTOR(Connector.class), + CONVERTER(Converter.class), + TRANSFORMATION(Transformation.class), + CONFIGPROVIDER(ConfigProvider.class), + REST_EXTENSION(ConnectRestExtension.class), + CONNECTOR_CLIENT_CONFIG_OVERRIDE_POLICY(ConnectorClientConfigOverridePolicy.class), + UNKNOWN(Object.class); + + private Class klass; + + PluginType(Class klass) { + this.klass = klass; + } + + public static PluginType from(Class klass) { + for (PluginType type : PluginType.values()) { + if (type.klass.isAssignableFrom(klass)) { + return type; + } + } + return UNKNOWN; + } + + public String simpleName() { + return klass.getSimpleName(); + } + + @Override + public String toString() { + return super.toString().toLowerCase(Locale.ROOT); + } +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/isolation/PluginUtils.java b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/isolation/PluginUtils.java new file mode 100644 index 0000000..12cb186 --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/isolation/PluginUtils.java @@ -0,0 +1,373 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.connect.runtime.isolation; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.lang.reflect.Modifier; +import java.nio.file.DirectoryStream; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.HashSet; +import java.util.Iterator; +import java.util.LinkedList; +import java.util.List; +import java.util.Locale; +import java.util.Set; +import java.util.TreeSet; +import java.util.regex.Pattern; + +/** + * Connect plugin utility methods. + */ +public class PluginUtils { + private static final Logger log = LoggerFactory.getLogger(PluginUtils.class); + + // Be specific about javax packages and exclude those existing in Java SE and Java EE libraries. + private static final Pattern EXCLUDE = Pattern.compile("^(?:" + + "java" + + "|javax\\.accessibility" + + "|javax\\.activation" + + "|javax\\.activity" + + "|javax\\.annotation" + + "|javax\\.batch\\.api" + + "|javax\\.batch\\.operations" + + "|javax\\.batch\\.runtime" + + "|javax\\.crypto" + + "|javax\\.decorator" + + "|javax\\.ejb" + + "|javax\\.el" + + "|javax\\.enterprise\\.concurrent" + + "|javax\\.enterprise\\.context" + + "|javax\\.enterprise\\.context\\.spi" + + "|javax\\.enterprise\\.deploy\\.model" + + "|javax\\.enterprise\\.deploy\\.shared" + + "|javax\\.enterprise\\.deploy\\.spi" + + "|javax\\.enterprise\\.event" + + "|javax\\.enterprise\\.inject" + + "|javax\\.enterprise\\.inject\\.spi" + + "|javax\\.enterprise\\.util" + + "|javax\\.faces" + + "|javax\\.imageio" + + "|javax\\.inject" + + "|javax\\.interceptor" + + "|javax\\.jms" + + "|javax\\.json" + + "|javax\\.jws" + + "|javax\\.lang\\.model" + + "|javax\\.mail" + + "|javax\\.management" + + "|javax\\.management\\.j2ee" + + "|javax\\.naming" + + "|javax\\.net" + + "|javax\\.persistence" + + "|javax\\.print" + + "|javax\\.resource" + + "|javax\\.rmi" + + "|javax\\.script" + + "|javax\\.security\\.auth" + + "|javax\\.security\\.auth\\.message" + + "|javax\\.security\\.cert" + + "|javax\\.security\\.jacc" + + "|javax\\.security\\.sasl" + + "|javax\\.servlet" + + "|javax\\.sound\\.midi" + + "|javax\\.sound\\.sampled" + + "|javax\\.sql" + + "|javax\\.swing" + + "|javax\\.tools" + + "|javax\\.transaction" + + "|javax\\.validation" + + "|javax\\.websocket" + + "|javax\\.ws\\.rs" + + "|javax\\.xml" + + "|javax\\.xml\\.bind" + + "|javax\\.xml\\.registry" + + "|javax\\.xml\\.rpc" + + "|javax\\.xml\\.soap" + + "|javax\\.xml\\.ws" + + "|org\\.ietf\\.jgss" + + "|org\\.omg\\.CORBA" + + "|org\\.omg\\.CosNaming" + + "|org\\.omg\\.Dynamic" + + "|org\\.omg\\.DynamicAny" + + "|org\\.omg\\.IOP" + + "|org\\.omg\\.Messaging" + + "|org\\.omg\\.PortableInterceptor" + + "|org\\.omg\\.PortableServer" + + "|org\\.omg\\.SendingContext" + + "|org\\.omg\\.stub\\.java\\.rmi" + + "|org\\.w3c\\.dom" + + "|org\\.xml\\.sax" + + "|org\\.apache\\.kafka" + + "|org\\.slf4j" + + ")\\..*$"); + + // If the base interface or class that will be used to identify Connect plugins resides within + // the same java package as the plugins that need to be loaded in isolation (and thus are + // added to the INCLUDE pattern), then this base interface or class needs to be excluded in the + // regular expression pattern + private static final Pattern INCLUDE = Pattern.compile("^org\\.apache\\.kafka\\.(?:connect\\.(?:" + + "transforms\\.(?!Transformation|predicates\\.Predicate$).*" + + "|json\\..*" + + "|file\\..*" + + 
"|mirror\\..*" + + "|mirror-client\\..*" + + "|converters\\..*" + + "|storage\\.StringConverter" + + "|storage\\.SimpleHeaderConverter" + + "|rest\\.basic\\.auth\\.extension\\.BasicAuthSecurityRestExtension" + + "|connector\\.policy\\.(?!ConnectorClientConfig(?:OverridePolicy|Request(?:\\$ClientType)?)$).*" + + ")" + + "|common\\.config\\.provider\\.(?!ConfigProvider$).*" + + ")$"); + + private static final DirectoryStream.Filter PLUGIN_PATH_FILTER = path -> + Files.isDirectory(path) || isArchive(path) || isClassFile(path); + + /** + * Return whether the class with the given name should be loaded in isolation using a plugin + * classloader. + * + * @param name the fully qualified name of the class. + * @return true if this class should be loaded in isolation, false otherwise. + */ + public static boolean shouldLoadInIsolation(String name) { + return !(EXCLUDE.matcher(name).matches() && !INCLUDE.matcher(name).matches()); + } + + /** + * Verify the given class corresponds to a concrete class and not to an abstract class or + * interface. + * @param klass the class object. + * @return true if the argument is a concrete class, false if it's abstract or interface. + */ + public static boolean isConcrete(Class klass) { + int mod = klass.getModifiers(); + return !Modifier.isAbstract(mod) && !Modifier.isInterface(mod); + } + + /** + * Return whether a path corresponds to a JAR or ZIP archive. + * + * @param path the path to validate. + * @return true if the path is a JAR or ZIP archive file, otherwise false. + */ + public static boolean isArchive(Path path) { + String archivePath = path.toString().toLowerCase(Locale.ROOT); + return archivePath.endsWith(".jar") || archivePath.endsWith(".zip"); + } + + /** + * Return whether a path corresponds java class file. + * + * @param path the path to validate. + * @return true if the path is a java class file, otherwise false. + */ + public static boolean isClassFile(Path path) { + return path.toString().toLowerCase(Locale.ROOT).endsWith(".class"); + } + + public static List pluginLocations(Path topPath) throws IOException { + List locations = new ArrayList<>(); + try ( + DirectoryStream listing = Files.newDirectoryStream( + topPath, + PLUGIN_PATH_FILTER + ) + ) { + for (Path dir : listing) { + locations.add(dir); + } + } + return locations; + } + + /** + * Given a top path in the filesystem, return a list of paths to archives (JAR or ZIP + * files) contained under this top path. If the top path contains only java class files, + * return the top path itself. This method follows symbolic links to discover archives and + * returns the such archives as absolute paths. + * + * @param topPath the path to use as root of plugin search. + * @return a list of potential plugin paths, or empty list if no such paths exist. 
+ * @throws IOException + */ + public static List pluginUrls(Path topPath) throws IOException { + boolean containsClassFiles = false; + Set archives = new TreeSet<>(); + LinkedList dfs = new LinkedList<>(); + Set visited = new HashSet<>(); + + if (isArchive(topPath)) { + return Collections.singletonList(topPath); + } + + DirectoryStream topListing = Files.newDirectoryStream( + topPath, + PLUGIN_PATH_FILTER + ); + dfs.push(new DirectoryEntry(topListing)); + visited.add(topPath); + try { + while (!dfs.isEmpty()) { + Iterator neighbors = dfs.peek().iterator; + if (!neighbors.hasNext()) { + dfs.pop().stream.close(); + continue; + } + + Path adjacent = neighbors.next(); + if (Files.isSymbolicLink(adjacent)) { + try { + Path symlink = Files.readSymbolicLink(adjacent); + // if symlink is absolute resolve() returns the absolute symlink itself + Path parent = adjacent.getParent(); + if (parent == null) { + continue; + } + Path absolute = parent.resolve(symlink).toRealPath(); + if (Files.exists(absolute)) { + adjacent = absolute; + } else { + continue; + } + } catch (IOException e) { + // See https://issues.apache.org/jira/browse/KAFKA-6288 for a reported + // failure. Such a failure at this stage is not easily reproducible and + // therefore an exception is caught and ignored after issuing a + // warning. This allows class scanning to continue for non-broken plugins. + log.warn( + "Resolving symbolic link '{}' failed. Ignoring this path.", + adjacent, + e + ); + continue; + } + } + + if (!visited.contains(adjacent)) { + visited.add(adjacent); + if (isArchive(adjacent)) { + archives.add(adjacent); + } else if (isClassFile(adjacent)) { + containsClassFiles = true; + } else { + DirectoryStream listing = Files.newDirectoryStream( + adjacent, + PLUGIN_PATH_FILTER + ); + dfs.push(new DirectoryEntry(listing)); + } + } + } + } finally { + while (!dfs.isEmpty()) { + dfs.pop().stream.close(); + } + } + + if (containsClassFiles) { + if (archives.isEmpty()) { + return Collections.singletonList(topPath); + } + log.warn("Plugin path contains both java archives and class files. Returning only the" + + " archives"); + } + return Arrays.asList(archives.toArray(new Path[0])); + } + + /** + * Return the simple class name of a plugin as {@code String}. + * + * @param plugin the plugin descriptor. + * @return the plugin's simple class name. + */ + public static String simpleName(PluginDesc plugin) { + return plugin.pluginClass().getSimpleName(); + } + + /** + * Remove the plugin type name at the end of a plugin class name, if such suffix is present. + * This method is meant to be used to extract plugin aliases. + * + * @param plugin the plugin descriptor. + * @return the pruned simple class name of the plugin. + */ + public static String prunedName(PluginDesc plugin) { + // It's currently simpler to switch on type than do pattern matching. + switch (plugin.type()) { + case SOURCE: + case SINK: + case CONNECTOR: + return prunePluginName(plugin, "Connector"); + default: + return prunePluginName(plugin, plugin.type().simpleName()); + } + } + + /** + * Verify whether a given plugin's alias matches another alias in a collection of plugins. + * + * @param alias the plugin descriptor to test for alias matching. + * @param plugins the collection of plugins to test against. + * @param the plugin type. + * @return false if a match was found in the collection, otherwise true. 
+ */ + public static <U> boolean isAliasUnique( + PluginDesc<U> alias, + Collection<PluginDesc<U>> plugins + ) { + boolean matched = false; + for (PluginDesc<U> plugin : plugins) { + if (simpleName(alias).equals(simpleName(plugin)) + || prunedName(alias).equals(prunedName(plugin))) { + if (matched) { + return false; + } + matched = true; + } + } + return true; + } + + private static String prunePluginName(PluginDesc<?> plugin, String suffix) { + String simple = plugin.pluginClass().getSimpleName(); + int pos = simple.lastIndexOf(suffix); + if (pos > 0) { + return simple.substring(0, pos); + } + return simple; + } + + private static class DirectoryEntry { + final DirectoryStream<Path> stream; + final Iterator<Path> iterator; + + DirectoryEntry(DirectoryStream<Path> stream) { + this.stream = stream; + this.iterator = stream.iterator(); + } + } + +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/isolation/Plugins.java b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/isolation/Plugins.java new file mode 100644 index 0000000..9013c61 --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/isolation/Plugins.java @@ -0,0 +1,474 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
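/*
 * Editor's illustrative sketch (not part of this patch): exercising the PluginUtils helpers
 * defined above. The plugin directory path is hypothetical; the shouldLoadInIsolation()
 * results follow from the EXCLUDE/INCLUDE patterns in this class, and the List<Path> return
 * type of pluginUrls() is assumed from its implementation.
 */
import org.apache.kafka.connect.runtime.isolation.PluginUtils;

import java.io.IOException;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.List;

public class PluginUtilsExample {
    public static void main(String[] args) throws IOException {
        // Framework and JDK classes are never loaded in isolation ...
        System.out.println(PluginUtils.shouldLoadInIsolation("org.apache.kafka.connect.connector.Connector")); // false
        System.out.println(PluginUtils.shouldLoadInIsolation("javax.crypto.Mac"));                             // false
        // ... while third-party plugins and the whitelisted bundled plugins are
        System.out.println(PluginUtils.shouldLoadInIsolation("io.example.MySinkConnector"));                   // true (hypothetical class)
        System.out.println(PluginUtils.shouldLoadInIsolation("org.apache.kafka.connect.json.JsonConverter"));  // true

        // Discover archives (or a bare class-file directory) under one plugin.path entry
        Path pluginDir = Paths.get("/opt/connect/plugins/my-connector"); // hypothetical location
        List<Path> archives = PluginUtils.pluginUrls(pluginDir);
        archives.forEach(System.out::println);
    }
}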
+ */ +package org.apache.kafka.connect.runtime.isolation; + +import org.apache.kafka.common.Configurable; +import org.apache.kafka.common.config.AbstractConfig; +import org.apache.kafka.common.config.provider.ConfigProvider; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.connect.components.Versioned; +import org.apache.kafka.connect.connector.ConnectRecord; +import org.apache.kafka.connect.connector.Connector; +import org.apache.kafka.connect.connector.Task; +import org.apache.kafka.connect.errors.ConnectException; +import org.apache.kafka.connect.runtime.WorkerConfig; +import org.apache.kafka.connect.storage.Converter; +import org.apache.kafka.connect.storage.ConverterConfig; +import org.apache.kafka.connect.storage.ConverterType; +import org.apache.kafka.connect.storage.HeaderConverter; +import org.apache.kafka.connect.transforms.Transformation; +import org.apache.kafka.connect.transforms.predicates.Predicate; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.security.AccessController; +import java.security.PrivilegedAction; +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; +import java.util.Map; +import java.util.Set; + +public class Plugins { + + public enum ClassLoaderUsage { + CURRENT_CLASSLOADER, + PLUGINS + } + + private static final Logger log = LoggerFactory.getLogger(Plugins.class); + private final DelegatingClassLoader delegatingLoader; + + public Plugins(Map props) { + List pluginLocations = WorkerConfig.pluginLocations(props); + delegatingLoader = newDelegatingClassLoader(pluginLocations); + delegatingLoader.initLoaders(); + } + + protected DelegatingClassLoader newDelegatingClassLoader(final List paths) { + return AccessController.doPrivileged( + (PrivilegedAction) () -> new DelegatingClassLoader(paths) + ); + } + + private static String pluginNames(Collection> plugins) { + return Utils.join(plugins, ", "); + } + + protected static T newPlugin(Class klass) { + // KAFKA-8340: The thread classloader is used during static initialization and must be + // set to the plugin's classloader during instantiation + ClassLoader savedLoader = compareAndSwapLoaders(klass.getClassLoader()); + try { + return Utils.newInstance(klass); + } catch (Throwable t) { + throw new ConnectException("Instantiation error", t); + } finally { + compareAndSwapLoaders(savedLoader); + } + } + + @SuppressWarnings("unchecked") + protected Class pluginClassFromConfig( + AbstractConfig config, + String propertyName, + Class pluginClass, + Collection> plugins + ) { + Class klass = config.getClass(propertyName); + if (pluginClass.isAssignableFrom(klass)) { + return (Class) klass; + } + throw new ConnectException( + "Failed to find any class that implements " + pluginClass.getSimpleName() + + " for the config " + + propertyName + ", available classes are: " + + pluginNames(plugins) + ); + } + + @SuppressWarnings("unchecked") + protected static Class pluginClass( + DelegatingClassLoader loader, + String classOrAlias, + Class pluginClass + ) throws ClassNotFoundException { + Class klass = loader.loadClass(classOrAlias, false); + if (pluginClass.isAssignableFrom(klass)) { + return (Class) klass; + } + + throw new ClassNotFoundException( + "Requested class: " + + classOrAlias + + " does not extend " + pluginClass.getSimpleName() + ); + } + + public static ClassLoader compareAndSwapLoaders(ClassLoader loader) { + ClassLoader current = Thread.currentThread().getContextClassLoader(); + if (!current.equals(loader)) { + 
Thread.currentThread().setContextClassLoader(loader); + } + return current; + } + + public ClassLoader currentThreadLoader() { + return Thread.currentThread().getContextClassLoader(); + } + + public ClassLoader compareAndSwapWithDelegatingLoader() { + ClassLoader current = Thread.currentThread().getContextClassLoader(); + if (!current.equals(delegatingLoader)) { + Thread.currentThread().setContextClassLoader(delegatingLoader); + } + return current; + } + + public ClassLoader compareAndSwapLoaders(Connector connector) { + ClassLoader connectorLoader = delegatingLoader.connectorLoader(connector); + return compareAndSwapLoaders(connectorLoader); + } + + public DelegatingClassLoader delegatingLoader() { + return delegatingLoader; + } + + public Set> connectors() { + return delegatingLoader.connectors(); + } + + public Set> converters() { + return delegatingLoader.converters(); + } + + public Set>> transformations() { + return delegatingLoader.transformations(); + } + + public Set>> predicates() { + return delegatingLoader.predicates(); + } + + public Set> configProviders() { + return delegatingLoader.configProviders(); + } + + public Connector newConnector(String connectorClassOrAlias) { + Class klass = connectorClass(connectorClassOrAlias); + return newPlugin(klass); + } + + public Class connectorClass(String connectorClassOrAlias) { + Class klass; + try { + klass = pluginClass( + delegatingLoader, + connectorClassOrAlias, + Connector.class + ); + } catch (ClassNotFoundException e) { + List> matches = new ArrayList<>(); + for (PluginDesc plugin : delegatingLoader.connectors()) { + Class pluginClass = plugin.pluginClass(); + String simpleName = pluginClass.getSimpleName(); + if (simpleName.equals(connectorClassOrAlias) + || simpleName.equals(connectorClassOrAlias + "Connector")) { + matches.add(plugin); + } + } + + if (matches.isEmpty()) { + throw new ConnectException( + "Failed to find any class that implements Connector and which name matches " + + connectorClassOrAlias + + ", available connectors are: " + + pluginNames(delegatingLoader.connectors()) + ); + } + if (matches.size() > 1) { + throw new ConnectException( + "More than one connector matches alias " + + connectorClassOrAlias + + + ". Please use full package and class name instead. Classes found: " + + pluginNames(matches) + ); + } + + PluginDesc entry = matches.get(0); + klass = entry.pluginClass(); + } + return klass; + } + + public Task newTask(Class taskClass) { + return newPlugin(taskClass); + } + + /** + * If the given configuration defines a {@link Converter} using the named configuration property, return a new configured instance. 
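/*
 * Editor's illustrative sketch (not part of this patch): the save/swap/restore idiom that the
 * methods above rely on. The thread context classloader is switched to the plugin's loader for
 * the duration of the work and always restored in a finally block, even on failure.
 * `pluginClass` is a placeholder for any class obtained from a plugin classloader.
 */
import org.apache.kafka.connect.runtime.isolation.Plugins;

public class ClassLoaderSwapExample {
    static void withPluginLoader(Class<?> pluginClass, Runnable work) {
        ClassLoader savedLoader = Plugins.compareAndSwapLoaders(pluginClass.getClassLoader());
        try {
            work.run();  // class lookups in here resolve against the plugin's classloader first
        } finally {
            Plugins.compareAndSwapLoaders(savedLoader);  // restore the previous loader unconditionally
        }
    }
}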
+ * + * @param config the configuration containing the {@link Converter}'s configuration; may not be null + * @param classPropertyName the name of the property that contains the name of the {@link Converter} class; may not be null + * @param classLoaderUsage which classloader should be used + * @return the instantiated and configured {@link Converter}; null if the configuration did not define the specified property + * @throws ConnectException if the {@link Converter} implementation class could not be found + */ + public Converter newConverter(AbstractConfig config, String classPropertyName, ClassLoaderUsage classLoaderUsage) { + if (!config.originals().containsKey(classPropertyName)) { + // This configuration does not define the converter via the specified property name + return null; + } + + Class klass = null; + switch (classLoaderUsage) { + case CURRENT_CLASSLOADER: + // Attempt to load first with the current classloader, and plugins as a fallback. + // Note: we can't use config.getConfiguredInstance because Converter doesn't implement Configurable, and even if it did + // we have to remove the property prefixes before calling config(...) and we still always want to call Converter.config. + klass = pluginClassFromConfig(config, classPropertyName, Converter.class, delegatingLoader.converters()); + break; + case PLUGINS: + // Attempt to load with the plugin class loader, which uses the current classloader as a fallback + String converterClassOrAlias = config.getClass(classPropertyName).getName(); + try { + klass = pluginClass(delegatingLoader, converterClassOrAlias, Converter.class); + } catch (ClassNotFoundException e) { + throw new ConnectException( + "Failed to find any class that implements Converter and which name matches " + + converterClassOrAlias + ", available converters are: " + + pluginNames(delegatingLoader.converters()) + ); + } + break; + } + if (klass == null) { + throw new ConnectException("Unable to initialize the Converter specified in '" + classPropertyName + "'"); + } + + // Determine whether this is a key or value converter based upon the supplied property name ... + final boolean isKeyConverter = WorkerConfig.KEY_CONVERTER_CLASS_CONFIG.equals(classPropertyName); + + // Configure the Converter using only the old configuration mechanism ... + String configPrefix = classPropertyName + "."; + Map converterConfig = config.originalsWithPrefix(configPrefix); + log.debug("Configuring the {} converter with configuration keys:{}{}", + isKeyConverter ? "key" : "value", System.lineSeparator(), converterConfig.keySet()); + + Converter plugin; + ClassLoader savedLoader = compareAndSwapLoaders(klass.getClassLoader()); + try { + plugin = newPlugin(klass); + plugin.configure(converterConfig, isKeyConverter); + } finally { + compareAndSwapLoaders(savedLoader); + } + return plugin; + } + + /** + * Load an internal converter, used by the worker for (de)serializing data in internal topics. 
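/*
 * Editor's illustrative sketch (not part of this patch): obtaining configured key/value
 * converters through newConverter() above. The worker properties are hypothetical and
 * abbreviated; StandaloneConfig is used here only as one concrete WorkerConfig.
 */
import org.apache.kafka.connect.runtime.WorkerConfig;
import org.apache.kafka.connect.runtime.isolation.Plugins;
import org.apache.kafka.connect.runtime.standalone.StandaloneConfig;
import org.apache.kafka.connect.storage.Converter;

import java.util.HashMap;
import java.util.Map;

public class NewConverterExample {
    public static void main(String[] args) {
        Map<String, String> workerProps = new HashMap<>();
        workerProps.put("key.converter", "org.apache.kafka.connect.json.JsonConverter");
        workerProps.put("key.converter.schemas.enable", "false"); // prefixed props are passed to configure()
        workerProps.put("value.converter", "org.apache.kafka.connect.storage.StringConverter");
        workerProps.put("offset.storage.file.filename", "/tmp/connect.offsets"); // required by StandaloneConfig

        Plugins plugins = new Plugins(workerProps);
        WorkerConfig config = new StandaloneConfig(workerProps);

        // Resolved with the current classloader first, plugin loaders as fallback
        Converter keyConverter = plugins.newConverter(config, WorkerConfig.KEY_CONVERTER_CLASS_CONFIG,
                Plugins.ClassLoaderUsage.CURRENT_CLASSLOADER);
        // Resolved through the delegating plugin classloader
        Converter valueConverter = plugins.newConverter(config, WorkerConfig.VALUE_CONVERTER_CLASS_CONFIG,
                Plugins.ClassLoaderUsage.PLUGINS);
    }
}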
+ * + * @param isKey whether the converter is a key converter + * @param className the class name of the converter + * @param converterConfig the properties to configure the converter with + * @return the instantiated and configured {@link Converter}; never null + * @throws ConnectException if the {@link Converter} implementation class could not be found + */ + public Converter newInternalConverter(boolean isKey, String className, Map converterConfig) { + Class klass; + try { + klass = pluginClass(delegatingLoader, className, Converter.class); + } catch (ClassNotFoundException e) { + throw new ConnectException("Failed to load internal converter class " + className); + } + + Converter plugin; + ClassLoader savedLoader = compareAndSwapLoaders(klass.getClassLoader()); + try { + plugin = newPlugin(klass); + plugin.configure(converterConfig, isKey); + } finally { + compareAndSwapLoaders(savedLoader); + } + return plugin; + } + + /** + * If the given configuration defines a {@link HeaderConverter} using the named configuration property, return a new configured + * instance. + * + * @param config the configuration containing the {@link Converter}'s configuration; may not be null + * @param classPropertyName the name of the property that contains the name of the {@link Converter} class; may not be null + * @param classLoaderUsage which classloader should be used + * @return the instantiated and configured {@link HeaderConverter}; null if the configuration did not define the specified property + * @throws ConnectException if the {@link HeaderConverter} implementation class could not be found + */ + public HeaderConverter newHeaderConverter(AbstractConfig config, String classPropertyName, ClassLoaderUsage classLoaderUsage) { + Class klass = null; + switch (classLoaderUsage) { + case CURRENT_CLASSLOADER: + if (!config.originals().containsKey(classPropertyName)) { + // This connector configuration does not define the header converter via the specified property name + return null; + } + // Attempt to load first with the current classloader, and plugins as a fallback. + // Note: we can't use config.getConfiguredInstance because we have to remove the property prefixes + // before calling config(...) + klass = pluginClassFromConfig(config, classPropertyName, HeaderConverter.class, delegatingLoader.headerConverters()); + break; + case PLUGINS: + // Attempt to load with the plugin class loader, which uses the current classloader as a fallback. 
+ // Note that there will always be at least a default header converter for the worker + String converterClassOrAlias = config.getClass(classPropertyName).getName(); + try { + klass = pluginClass( + delegatingLoader, + converterClassOrAlias, + HeaderConverter.class + ); + } catch (ClassNotFoundException e) { + throw new ConnectException( + "Failed to find any class that implements HeaderConverter and which name matches " + + converterClassOrAlias + + ", available header converters are: " + + pluginNames(delegatingLoader.headerConverters()) + ); + } + } + if (klass == null) { + throw new ConnectException("Unable to initialize the HeaderConverter specified in '" + classPropertyName + "'"); + } + + String configPrefix = classPropertyName + "."; + Map converterConfig = config.originalsWithPrefix(configPrefix); + converterConfig.put(ConverterConfig.TYPE_CONFIG, ConverterType.HEADER.getName()); + log.debug("Configuring the header converter with configuration keys:{}{}", System.lineSeparator(), converterConfig.keySet()); + + HeaderConverter plugin; + ClassLoader savedLoader = compareAndSwapLoaders(klass.getClassLoader()); + try { + plugin = newPlugin(klass); + plugin.configure(converterConfig); + } finally { + compareAndSwapLoaders(savedLoader); + } + return plugin; + } + + public ConfigProvider newConfigProvider(AbstractConfig config, String providerPrefix, ClassLoaderUsage classLoaderUsage) { + String classPropertyName = providerPrefix + ".class"; + Map originalConfig = config.originalsStrings(); + if (!originalConfig.containsKey(classPropertyName)) { + // This configuration does not define the config provider via the specified property name + return null; + } + Class klass = null; + switch (classLoaderUsage) { + case CURRENT_CLASSLOADER: + // Attempt to load first with the current classloader, and plugins as a fallback. + klass = pluginClassFromConfig(config, classPropertyName, ConfigProvider.class, delegatingLoader.configProviders()); + break; + case PLUGINS: + // Attempt to load with the plugin class loader, which uses the current classloader as a fallback + String configProviderClassOrAlias = originalConfig.get(classPropertyName); + try { + klass = pluginClass(delegatingLoader, configProviderClassOrAlias, ConfigProvider.class); + } catch (ClassNotFoundException e) { + throw new ConnectException( + "Failed to find any class that implements ConfigProvider and which name matches " + + configProviderClassOrAlias + ", available ConfigProviders are: " + + pluginNames(delegatingLoader.configProviders()) + ); + } + break; + } + if (klass == null) { + throw new ConnectException("Unable to initialize the ConfigProvider specified in '" + classPropertyName + "'"); + } + + // Configure the ConfigProvider + String configPrefix = providerPrefix + ".param."; + Map configProviderConfig = config.originalsWithPrefix(configPrefix); + + ConfigProvider plugin; + ClassLoader savedLoader = compareAndSwapLoaders(klass.getClassLoader()); + try { + plugin = newPlugin(klass); + plugin.configure(configProviderConfig); + } finally { + compareAndSwapLoaders(savedLoader); + } + return plugin; + } + + /** + * If the given class names are available in the classloader, return a list of new configured + * instances. 
If the instances implement {@link Configurable}, they are configured with provided {@param config} + * + * @param klassNames the list of class names of plugins that needs to instantiated and configured + * @param config the configuration containing the {@link org.apache.kafka.connect.runtime.Worker}'s configuration; may not be {@code null} + * @param pluginKlass the type of the plugin class that is being instantiated + * @return the instantiated and configured list of plugins of type ; empty list if the {@param klassNames} is {@code null} or empty + * @throws ConnectException if the implementation class could not be found + */ + public List newPlugins(List klassNames, AbstractConfig config, Class pluginKlass) { + List plugins = new ArrayList<>(); + if (klassNames != null) { + for (String klassName : klassNames) { + plugins.add(newPlugin(klassName, config, pluginKlass)); + } + } + return plugins; + } + + public T newPlugin(String klassName, AbstractConfig config, Class pluginKlass) { + T plugin; + Class klass; + try { + klass = pluginClass(delegatingLoader, klassName, pluginKlass); + } catch (ClassNotFoundException e) { + String msg = String.format("Failed to find any class that implements %s and which " + + "name matches %s", pluginKlass, klassName); + throw new ConnectException(msg); + } + ClassLoader savedLoader = compareAndSwapLoaders(klass.getClassLoader()); + try { + plugin = newPlugin(klass); + if (plugin instanceof Versioned) { + Versioned versionedPlugin = (Versioned) plugin; + if (Utils.isBlank(versionedPlugin.version())) { + throw new ConnectException("Version not defined for '" + klassName + "'"); + } + } + if (plugin instanceof Configurable) { + ((Configurable) plugin).configure(config.originals()); + } + } finally { + compareAndSwapLoaders(savedLoader); + } + return plugin; + } + + public > Transformation newTranformations( + String transformationClassOrAlias + ) { + return null; + } + +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/rest/ConnectRestConfigurable.java b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/rest/ConnectRestConfigurable.java new file mode 100644 index 0000000..f33ce19 --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/rest/ConnectRestConfigurable.java @@ -0,0 +1,138 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.connect.runtime.rest; + +import org.glassfish.jersey.server.ResourceConfig; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Map; +import java.util.Objects; + +import javax.ws.rs.core.Configurable; +import javax.ws.rs.core.Configuration; + +/** + * The implementation delegates to {@link ResourceConfig} so that we can handle duplicate + * registrations deterministically by not re-registering them again. + */ +public class ConnectRestConfigurable implements Configurable { + + private static final Logger log = LoggerFactory.getLogger(ConnectRestConfigurable.class); + + private static final boolean ALLOWED_TO_REGISTER = true; + private static final boolean NOT_ALLOWED_TO_REGISTER = false; + + private ResourceConfig resourceConfig; + + public ConnectRestConfigurable(ResourceConfig resourceConfig) { + Objects.requireNonNull(resourceConfig, "ResourceConfig can't be null"); + this.resourceConfig = resourceConfig; + } + + + @Override + public Configuration getConfiguration() { + return resourceConfig.getConfiguration(); + } + + @Override + public ResourceConfig property(String name, Object value) { + return resourceConfig.property(name, value); + } + + @Override + public ResourceConfig register(Object component) { + if (allowedToRegister(component)) { + resourceConfig.register(component); + } + return resourceConfig; + } + + @Override + public ResourceConfig register(Object component, int priority) { + if (allowedToRegister(component)) { + resourceConfig.register(component, priority); + } + return resourceConfig; + } + + @Override + public ResourceConfig register(Object component, Map, Integer> contracts) { + if (allowedToRegister(component)) { + resourceConfig.register(component, contracts); + } + return resourceConfig; + } + + @Override + public ResourceConfig register(Object component, Class... contracts) { + if (allowedToRegister(component)) { + resourceConfig.register(component, contracts); + } + return resourceConfig; + } + + @Override + public ResourceConfig register(Class componentClass, Map, Integer> contracts) { + if (allowedToRegister(componentClass)) { + resourceConfig.register(componentClass, contracts); + } + return resourceConfig; + } + + @Override + public ResourceConfig register(Class componentClass, Class... 
contracts) { + if (allowedToRegister(componentClass)) { + resourceConfig.register(componentClass, contracts); + } + return resourceConfig; + } + + @Override + public ResourceConfig register(Class componentClass, int priority) { + if (allowedToRegister(componentClass)) { + resourceConfig.register(componentClass, priority); + } + return resourceConfig; + } + + @Override + public ResourceConfig register(Class componentClass) { + if (allowedToRegister(componentClass)) { + resourceConfig.register(componentClass); + } + return resourceConfig; + } + + private boolean allowedToRegister(Object component) { + if (resourceConfig.isRegistered(component)) { + log.warn("The resource {} is already registered", component); + return NOT_ALLOWED_TO_REGISTER; + } + return ALLOWED_TO_REGISTER; + } + + private boolean allowedToRegister(Class componentClass) { + if (resourceConfig.isRegistered(componentClass)) { + log.warn("The resource {} is already registered", componentClass); + return NOT_ALLOWED_TO_REGISTER; + } + return ALLOWED_TO_REGISTER; + } +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/rest/ConnectRestExtensionContextImpl.java b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/rest/ConnectRestExtensionContextImpl.java new file mode 100644 index 0000000..6d0a2a2 --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/rest/ConnectRestExtensionContextImpl.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.connect.runtime.rest; + +import org.apache.kafka.connect.health.ConnectClusterState; +import org.apache.kafka.connect.rest.ConnectRestExtensionContext; + +import javax.ws.rs.core.Configurable; + +public class ConnectRestExtensionContextImpl implements ConnectRestExtensionContext { + + private Configurable> configurable; + private ConnectClusterState clusterState; + + public ConnectRestExtensionContextImpl( + Configurable> configurable, + ConnectClusterState clusterState + ) { + this.configurable = configurable; + this.clusterState = clusterState; + } + + @Override + public Configurable> configurable() { + return configurable; + } + + @Override + public ConnectClusterState clusterState() { + return clusterState; + } +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/rest/InternalRequestSignature.java b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/rest/InternalRequestSignature.java new file mode 100644 index 0000000..3cee577 --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/rest/InternalRequestSignature.java @@ -0,0 +1,149 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
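/*
 * Editor's illustrative sketch (not part of this patch): the duplicate-registration guard of
 * ConnectRestConfigurable in action. JacksonJsonProvider stands in for any JAX-RS component a
 * REST extension might try to register more than once.
 */
import com.fasterxml.jackson.jaxrs.json.JacksonJsonProvider;
import org.apache.kafka.connect.runtime.rest.ConnectRestConfigurable;
import org.glassfish.jersey.server.ResourceConfig;

public class RestConfigurableExample {
    public static void main(String[] args) {
        ResourceConfig resourceConfig = new ResourceConfig();
        ConnectRestConfigurable configurable = new ConnectRestConfigurable(resourceConfig);

        JacksonJsonProvider jsonProvider = new JacksonJsonProvider();
        configurable.register(jsonProvider); // forwarded to the underlying ResourceConfig
        configurable.register(jsonProvider); // duplicate: logged as a warning and not re-registered
    }
}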
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.runtime.rest; + +import org.apache.kafka.connect.errors.ConnectException; +import org.apache.kafka.connect.runtime.rest.errors.BadRequestException; +import org.eclipse.jetty.client.api.Request; + +import javax.crypto.Mac; +import javax.crypto.SecretKey; +import javax.ws.rs.core.HttpHeaders; +import java.security.InvalidKeyException; +import java.security.MessageDigest; +import java.security.NoSuchAlgorithmException; +import java.util.Arrays; +import java.util.Base64; +import java.util.Objects; + +public class InternalRequestSignature { + + public static final String SIGNATURE_HEADER = "X-Connect-Authorization"; + public static final String SIGNATURE_ALGORITHM_HEADER = "X-Connect-Request-Signature-Algorithm"; + + private final byte[] requestBody; + private final Mac mac; + private final byte[] requestSignature; + + /** + * Add a signature to a request. + * @param key the key to sign the request with; may not be null + * @param requestBody the body of the request; may not be null + * @param signatureAlgorithm the algorithm to use to sign the request; may not be null + * @param request the request to add the signature to; may not be null + */ + public static void addToRequest(SecretKey key, byte[] requestBody, String signatureAlgorithm, Request request) { + Mac mac; + try { + mac = mac(signatureAlgorithm); + } catch (NoSuchAlgorithmException e) { + throw new ConnectException(e); + } + byte[] requestSignature = sign(mac, key, requestBody); + request.header(InternalRequestSignature.SIGNATURE_HEADER, Base64.getEncoder().encodeToString(requestSignature)) + .header(InternalRequestSignature.SIGNATURE_ALGORITHM_HEADER, signatureAlgorithm); + } + + /** + * Extract a signature from a request. 
+ * @param requestBody the body of the request; may not be null + * @param headers the headers for the request; may be null + * @return the signature extracted from the request, or null if one or more request signature + * headers was not present + */ + public static InternalRequestSignature fromHeaders(byte[] requestBody, HttpHeaders headers) { + if (headers == null) { + return null; + } + + String signatureAlgorithm = headers.getHeaderString(SIGNATURE_ALGORITHM_HEADER); + String encodedSignature = headers.getHeaderString(SIGNATURE_HEADER); + if (signatureAlgorithm == null || encodedSignature == null) { + return null; + } + + Mac mac; + try { + mac = mac(signatureAlgorithm); + } catch (NoSuchAlgorithmException e) { + throw new BadRequestException(e.getMessage()); + } + + byte[] decodedSignature; + try { + decodedSignature = Base64.getDecoder().decode(encodedSignature); + } catch (IllegalArgumentException e) { + throw new BadRequestException(e.getMessage()); + } + + return new InternalRequestSignature( + requestBody, + mac, + decodedSignature + ); + } + + // Public for testing + public InternalRequestSignature(byte[] requestBody, Mac mac, byte[] requestSignature) { + this.requestBody = requestBody; + this.mac = mac; + this.requestSignature = requestSignature; + } + + public String keyAlgorithm() { + return mac.getAlgorithm(); + } + + public boolean isValid(SecretKey key) { + return MessageDigest.isEqual(sign(mac, key, requestBody), requestSignature); + } + + private static Mac mac(String signatureAlgorithm) throws NoSuchAlgorithmException { + return Mac.getInstance(signatureAlgorithm); + } + + private static byte[] sign(Mac mac, SecretKey key, byte[] requestBody) { + try { + mac.init(key); + } catch (InvalidKeyException e) { + throw new ConnectException(e); + } + return mac.doFinal(requestBody); + } + + @Override + public boolean equals(Object o) { + if (this == o) + return true; + if (o == null || getClass() != o.getClass()) + return false; + InternalRequestSignature that = (InternalRequestSignature) o; + return Arrays.equals(requestBody, that.requestBody) + && mac.getAlgorithm().equals(that.mac.getAlgorithm()) + && mac.getMacLength() == that.mac.getMacLength() + && mac.getProvider().equals(that.mac.getProvider()) + && Arrays.equals(requestSignature, that.requestSignature); + } + + @Override + public int hashCode() { + int result = Objects.hash(mac); + result = 31 * result + Arrays.hashCode(requestBody); + result = 31 * result + Arrays.hashCode(requestSignature); + return result; + } +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/rest/RestClient.java b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/rest/RestClient.java new file mode 100644 index 0000000..81c5a84 --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/rest/RestClient.java @@ -0,0 +1,213 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
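/*
 * Editor's illustrative sketch (not part of this patch): signing a request body and verifying
 * it with InternalRequestSignature. HmacSHA256 is used as one algorithm name supported by the
 * JDK Mac API; the payload is hypothetical.
 */
import org.apache.kafka.connect.runtime.rest.InternalRequestSignature;

import javax.crypto.KeyGenerator;
import javax.crypto.Mac;
import javax.crypto.SecretKey;
import java.nio.charset.StandardCharsets;

public class RequestSignatureExample {
    public static void main(String[] args) throws Exception {
        byte[] body = "{\"task-configs\":[]}".getBytes(StandardCharsets.UTF_8); // hypothetical payload
        SecretKey sessionKey = KeyGenerator.getInstance("HmacSHA256").generateKey();

        // Compute the signature the same way addToRequest() does on the sending side
        Mac mac = Mac.getInstance("HmacSHA256");
        mac.init(sessionKey);
        byte[] signature = mac.doFinal(body);

        // On the receiving side the signature is rebuilt from the headers; the public
        // constructor (exposed for testing) lets us model that directly here.
        InternalRequestSignature received =
                new InternalRequestSignature(body, Mac.getInstance("HmacSHA256"), signature);
        System.out.println(received.isValid(sessionKey));                                          // true
        System.out.println(received.isValid(KeyGenerator.getInstance("HmacSHA256").generateKey())); // false
    }
}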
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.connect.runtime.rest; + +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; + +import javax.crypto.SecretKey; +import javax.ws.rs.core.HttpHeaders; + +import org.apache.kafka.connect.runtime.WorkerConfig; +import org.apache.kafka.connect.runtime.rest.entities.ErrorMessage; +import org.apache.kafka.connect.runtime.rest.errors.ConnectRestException; +import org.apache.kafka.connect.runtime.rest.util.SSLUtils; +import org.eclipse.jetty.client.HttpClient; +import org.eclipse.jetty.client.api.ContentResponse; +import org.eclipse.jetty.client.api.Request; +import org.eclipse.jetty.client.util.StringContentProvider; +import org.eclipse.jetty.http.HttpField; +import org.eclipse.jetty.http.HttpFields; +import org.eclipse.jetty.http.HttpStatus; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.ws.rs.core.Response; +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeoutException; + +public class RestClient { + private static final Logger log = LoggerFactory.getLogger(RestClient.class); + private static final ObjectMapper JSON_SERDE = new ObjectMapper(); + + /** + * Sends HTTP request to remote REST server + * + * @param url HTTP connection will be established with this url. + * @param method HTTP method ("GET", "POST", "PUT", etc.) + * @param headers HTTP headers from REST endpoint + * @param requestBodyData Object to serialize as JSON and send in the request body. + * @param responseFormat Expected format of the response to the HTTP request. + * @param The type of the deserialized response to the HTTP request. + * @return The deserialized response to the HTTP request, or null if no data is expected. + */ + public static HttpResponse httpRequest(String url, String method, HttpHeaders headers, Object requestBodyData, + TypeReference responseFormat, WorkerConfig config) { + return httpRequest(url, method, headers, requestBodyData, responseFormat, config, null, null); + } + + /** + * Sends HTTP request to remote REST server + * + * @param url HTTP connection will be established with this url. + * @param method HTTP method ("GET", "POST", "PUT", etc.) + * @param headers HTTP headers from REST endpoint + * @param requestBodyData Object to serialize as JSON and send in the request body. + * @param responseFormat Expected format of the response to the HTTP request. + * @param The type of the deserialized response to the HTTP request. + * @param sessionKey The key to sign the request with (intended for internal requests only); + * may be null if the request doesn't need to be signed + * @param requestSignatureAlgorithm The algorithm to sign the request with (intended for internal requests only); + * may be null if the request doesn't need to be signed + * @return The deserialized response to the HTTP request, or null if no data is expected. 
+ */ + public static HttpResponse httpRequest(String url, String method, HttpHeaders headers, Object requestBodyData, + TypeReference responseFormat, WorkerConfig config, + SecretKey sessionKey, String requestSignatureAlgorithm) { + HttpClient client; + + if (url.startsWith("https://")) { + client = new HttpClient(SSLUtils.createClientSideSslContextFactory(config)); + } else { + client = new HttpClient(); + } + + client.setFollowRedirects(false); + + try { + client.start(); + } catch (Exception e) { + log.error("Failed to start RestClient: ", e); + throw new ConnectRestException(Response.Status.INTERNAL_SERVER_ERROR, "Failed to start RestClient: " + e.getMessage(), e); + } + + try { + String serializedBody = requestBodyData == null ? null : JSON_SERDE.writeValueAsString(requestBodyData); + log.trace("Sending {} with input {} to {}", method, serializedBody, url); + + Request req = client.newRequest(url); + req.method(method); + req.accept("application/json"); + req.agent("kafka-connect"); + addHeadersToRequest(headers, req); + + if (serializedBody != null) { + req.content(new StringContentProvider(serializedBody, StandardCharsets.UTF_8), "application/json"); + if (sessionKey != null && requestSignatureAlgorithm != null) { + InternalRequestSignature.addToRequest( + sessionKey, + serializedBody.getBytes(StandardCharsets.UTF_8), + requestSignatureAlgorithm, + req + ); + } + } + + ContentResponse res = req.send(); + + int responseCode = res.getStatus(); + log.debug("Request's response code: {}", responseCode); + if (responseCode == HttpStatus.NO_CONTENT_204) { + return new HttpResponse<>(responseCode, convertHttpFieldsToMap(res.getHeaders()), null); + } else if (responseCode >= 400) { + ErrorMessage errorMessage = JSON_SERDE.readValue(res.getContentAsString(), ErrorMessage.class); + throw new ConnectRestException(responseCode, errorMessage.errorCode(), errorMessage.message()); + } else if (responseCode >= 200 && responseCode < 300) { + T result = JSON_SERDE.readValue(res.getContentAsString(), responseFormat); + return new HttpResponse<>(responseCode, convertHttpFieldsToMap(res.getHeaders()), result); + } else { + throw new ConnectRestException(Response.Status.INTERNAL_SERVER_ERROR, + Response.Status.INTERNAL_SERVER_ERROR.getStatusCode(), + "Unexpected status code when handling forwarded request: " + responseCode); + } + } catch (IOException | InterruptedException | TimeoutException | ExecutionException e) { + log.error("IO error forwarding REST request: ", e); + throw new ConnectRestException(Response.Status.INTERNAL_SERVER_ERROR, "IO Error trying to forward REST request: " + e.getMessage(), e); + } catch (Throwable t) { + log.error("Error forwarding REST request", t); + throw new ConnectRestException(Response.Status.INTERNAL_SERVER_ERROR, "Error trying to forward REST request: " + t.getMessage(), t); + } finally { + try { + client.stop(); + } catch (Exception e) { + log.error("Failed to stop HTTP client", e); + } + } + } + + + /** + * Extract headers from REST call and add to client request + * @param headers Headers from REST endpoint + * @param req The client request to modify + */ + private static void addHeadersToRequest(HttpHeaders headers, Request req) { + if (headers != null) { + String credentialAuthorization = headers.getHeaderString(HttpHeaders.AUTHORIZATION); + if (credentialAuthorization != null) { + req.header(HttpHeaders.AUTHORIZATION, credentialAuthorization); + } + } + } + + /** + * Convert response parameters from Jetty format (HttpFields) + * @param httpFields + * @return + 
*/ + private static Map convertHttpFieldsToMap(HttpFields httpFields) { + Map headers = new HashMap<>(); + + if (httpFields == null || httpFields.size() == 0) + return headers; + + for (HttpField field : httpFields) { + headers.put(field.getName(), field.getValue()); + } + + return headers; + } + + public static class HttpResponse { + private int status; + private Map headers; + private T body; + + public HttpResponse(int status, Map headers, T body) { + this.status = status; + this.headers = headers; + this.body = body; + } + + public int status() { + return status; + } + + public Map headers() { + return headers; + } + + public T body() { + return body; + } + } +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/rest/RestServer.java b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/rest/RestServer.java new file mode 100644 index 0000000..b337451 --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/rest/RestServer.java @@ -0,0 +1,468 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
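/*
 * Editor's illustrative sketch (not part of this patch): forwarding a request to another
 * worker with RestClient above. The target URL is hypothetical, `workerConfig` is assumed to
 * be an already-built WorkerConfig, and the generic <T>/TypeReference<T> signature of
 * httpRequest()/HttpResponse is assumed.
 */
import com.fasterxml.jackson.core.type.TypeReference;
import org.apache.kafka.connect.runtime.WorkerConfig;
import org.apache.kafka.connect.runtime.rest.RestClient;

import java.util.List;

public class RestClientExample {
    static List<String> fetchConnectorNames(String leaderUrl, WorkerConfig workerConfig) {
        RestClient.HttpResponse<List<String>> response = RestClient.httpRequest(
                leaderUrl + "/connectors",          // e.g. "http://other-worker:8083/connectors" (hypothetical)
                "GET",
                null,                               // no incoming headers to forward
                null,                               // no request body
                new TypeReference<List<String>>() { },
                workerConfig);
        return response.body();                     // deserialized JSON array of connector names
    }
}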
+ */ +package org.apache.kafka.connect.runtime.rest; + +import com.fasterxml.jackson.jaxrs.json.JacksonJsonProvider; +import org.apache.kafka.common.config.ConfigException; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.connect.errors.ConnectException; +import org.apache.kafka.connect.health.ConnectClusterDetails; +import org.apache.kafka.connect.rest.ConnectRestExtension; +import org.apache.kafka.connect.rest.ConnectRestExtensionContext; +import org.apache.kafka.connect.runtime.Herder; +import org.apache.kafka.connect.runtime.WorkerConfig; +import org.apache.kafka.connect.runtime.health.ConnectClusterDetailsImpl; +import org.apache.kafka.connect.runtime.health.ConnectClusterStateImpl; +import org.apache.kafka.connect.runtime.rest.errors.ConnectExceptionMapper; +import org.apache.kafka.connect.runtime.rest.resources.ConnectorPluginsResource; +import org.apache.kafka.connect.runtime.rest.resources.ConnectorsResource; +import org.apache.kafka.connect.runtime.rest.resources.LoggingResource; +import org.apache.kafka.connect.runtime.rest.resources.RootResource; +import org.apache.kafka.connect.runtime.rest.util.SSLUtils; +import org.eclipse.jetty.server.Connector; +import org.eclipse.jetty.server.CustomRequestLog; +import org.eclipse.jetty.server.Handler; +import org.eclipse.jetty.server.Server; +import org.eclipse.jetty.server.ServerConnector; +import org.eclipse.jetty.server.Slf4jRequestLogWriter; +import org.eclipse.jetty.server.handler.ContextHandlerCollection; +import org.eclipse.jetty.server.handler.DefaultHandler; +import org.eclipse.jetty.server.handler.RequestLogHandler; +import org.eclipse.jetty.server.handler.StatisticsHandler; +import org.eclipse.jetty.servlet.FilterHolder; +import org.eclipse.jetty.servlet.ServletContextHandler; +import org.eclipse.jetty.servlet.ServletHolder; +import org.eclipse.jetty.servlets.CrossOriginFilter; +import org.eclipse.jetty.servlets.HeaderFilter; +import org.eclipse.jetty.util.ssl.SslContextFactory; +import org.glassfish.jersey.server.ResourceConfig; +import org.glassfish.jersey.server.ServerProperties; +import org.glassfish.jersey.servlet.ServletContainer; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.servlet.DispatcherType; +import javax.ws.rs.core.UriBuilder; +import java.io.IOException; +import java.net.URI; +import java.util.ArrayList; +import java.util.Collections; +import java.util.EnumSet; +import java.util.List; +import java.util.Locale; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +import static org.apache.kafka.connect.runtime.WorkerConfig.ADMIN_LISTENERS_HTTPS_CONFIGS_PREFIX; + +/** + * Embedded server for the REST API that provides the control plane for Kafka Connect workers. 
+ */ +public class RestServer { + private static final Logger log = LoggerFactory.getLogger(RestServer.class); + + // Used to distinguish between Admin connectors and regular REST API connectors when binding admin handlers + private static final String ADMIN_SERVER_CONNECTOR_NAME = "Admin"; + + private static final Pattern LISTENER_PATTERN = Pattern.compile("^(.*)://\\[?([0-9a-zA-Z\\-%._:]*)\\]?:(-?[0-9]+)"); + private static final long GRACEFUL_SHUTDOWN_TIMEOUT_MS = 60 * 1000; + + private static final String PROTOCOL_HTTP = "http"; + private static final String PROTOCOL_HTTPS = "https"; + + private final WorkerConfig config; + private ContextHandlerCollection handlers; + private Server jettyServer; + + private List connectRestExtensions = Collections.emptyList(); + + /** + * Create a REST server for this herder using the specified configs. + */ + public RestServer(WorkerConfig config) { + this.config = config; + + List listeners = config.getList(WorkerConfig.LISTENERS_CONFIG); + List adminListeners = config.getList(WorkerConfig.ADMIN_LISTENERS_CONFIG); + + jettyServer = new Server(); + handlers = new ContextHandlerCollection(); + + createConnectors(listeners, adminListeners); + } + + /** + * Adds Jetty connector for each configured listener + */ + public void createConnectors(List listeners, List adminListeners) { + List connectors = new ArrayList<>(); + + for (String listener : listeners) { + Connector connector = createConnector(listener); + connectors.add(connector); + log.info("Added connector for {}", listener); + } + + jettyServer.setConnectors(connectors.toArray(new Connector[0])); + + if (adminListeners != null && !adminListeners.isEmpty()) { + for (String adminListener : adminListeners) { + Connector conn = createConnector(adminListener, true); + jettyServer.addConnector(conn); + log.info("Added admin connector for {}", adminListener); + } + } + } + + /** + * Creates regular (non-admin) Jetty connector according to configuration + */ + public Connector createConnector(String listener) { + return createConnector(listener, false); + } + + /** + * Creates Jetty connector according to configuration + */ + public Connector createConnector(String listener, boolean isAdmin) { + Matcher listenerMatcher = LISTENER_PATTERN.matcher(listener); + + if (!listenerMatcher.matches()) + throw new ConfigException("Listener doesn't have the right format (protocol://hostname:port)."); + + String protocol = listenerMatcher.group(1).toLowerCase(Locale.ENGLISH); + + if (!PROTOCOL_HTTP.equals(protocol) && !PROTOCOL_HTTPS.equals(protocol)) + throw new ConfigException(String.format("Listener protocol must be either \"%s\" or \"%s\".", PROTOCOL_HTTP, PROTOCOL_HTTPS)); + + String hostname = listenerMatcher.group(2); + int port = Integer.parseInt(listenerMatcher.group(3)); + + ServerConnector connector; + + if (PROTOCOL_HTTPS.equals(protocol)) { + SslContextFactory ssl; + if (isAdmin) { + ssl = SSLUtils.createServerSideSslContextFactory(config, ADMIN_LISTENERS_HTTPS_CONFIGS_PREFIX); + } else { + ssl = SSLUtils.createServerSideSslContextFactory(config); + } + connector = new ServerConnector(jettyServer, ssl); + if (!isAdmin) { + connector.setName(String.format("%s_%s%d", PROTOCOL_HTTPS, hostname, port)); + } + } else { + connector = new ServerConnector(jettyServer); + if (!isAdmin) { + connector.setName(String.format("%s_%s%d", PROTOCOL_HTTP, hostname, port)); + } + } + + if (isAdmin) { + connector.setName(ADMIN_SERVER_CONNECTOR_NAME); + } + + if (!hostname.isEmpty()) + connector.setHost(hostname); + + 
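/*
 * Editor's illustrative sketch (not part of this patch): how LISTENER_PATTERN above decomposes
 * a listener string into the protocol, host and port that createConnector() has just applied.
 * The listener values below are hypothetical examples.
 */
import java.util.Locale;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

public class ListenerParsingExample {
    private static final Pattern LISTENER_PATTERN =
            Pattern.compile("^(.*)://\\[?([0-9a-zA-Z\\-%._:]*)\\]?:(-?[0-9]+)");

    public static void main(String[] args) {
        String[] listeners = {"http://:8083", "https://worker1.example.com:8443", "http://[::1]:8083"};
        for (String listener : listeners) {
            Matcher m = LISTENER_PATTERN.matcher(listener);
            if (!m.matches()) {
                throw new IllegalArgumentException("Listener doesn't have the right format (protocol://hostname:port).");
            }
            String protocol = m.group(1).toLowerCase(Locale.ENGLISH);
            String hostname = m.group(2);           // may be empty, meaning bind to all interfaces
            int port = Integer.parseInt(m.group(3));
            System.out.printf("%s -> protocol=%s host=%s port=%d%n", listener, protocol, hostname, port);
        }
    }
}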
connector.setPort(port); + + return connector; + } + + public void initializeServer() { + log.info("Initializing REST server"); + + /* Needed for graceful shutdown as per `setStopTimeout` documentation */ + StatisticsHandler statsHandler = new StatisticsHandler(); + statsHandler.setHandler(handlers); + jettyServer.setHandler(statsHandler); + jettyServer.setStopTimeout(GRACEFUL_SHUTDOWN_TIMEOUT_MS); + jettyServer.setStopAtShutdown(true); + + try { + jettyServer.start(); + } catch (Exception e) { + throw new ConnectException("Unable to initialize REST server", e); + } + + log.info("REST server listening at " + jettyServer.getURI() + ", advertising URL " + advertisedUrl()); + log.info("REST admin endpoints at " + adminUrl()); + } + + public void initializeResources(Herder herder) { + log.info("Initializing REST resources"); + + ResourceConfig resourceConfig = new ResourceConfig(); + resourceConfig.register(new JacksonJsonProvider()); + + resourceConfig.register(new RootResource(herder)); + resourceConfig.register(new ConnectorsResource(herder, config)); + resourceConfig.register(new ConnectorPluginsResource(herder)); + + resourceConfig.register(ConnectExceptionMapper.class); + resourceConfig.property(ServerProperties.WADL_FEATURE_DISABLE, true); + + registerRestExtensions(herder, resourceConfig); + + List adminListeners = config.getList(WorkerConfig.ADMIN_LISTENERS_CONFIG); + ResourceConfig adminResourceConfig; + if (adminListeners == null) { + log.info("Adding admin resources to main listener"); + adminResourceConfig = resourceConfig; + adminResourceConfig.register(new LoggingResource()); + } else if (adminListeners.size() > 0) { + // TODO: we need to check if these listeners are same as 'listeners' + // TODO: the following code assumes that they are different + log.info("Adding admin resources to admin listener"); + adminResourceConfig = new ResourceConfig(); + adminResourceConfig.register(new JacksonJsonProvider()); + adminResourceConfig.register(new LoggingResource()); + adminResourceConfig.register(ConnectExceptionMapper.class); + } else { + log.info("Skipping adding admin resources"); + // set up adminResource but add no handlers to it + adminResourceConfig = resourceConfig; + } + + ServletContainer servletContainer = new ServletContainer(resourceConfig); + ServletHolder servletHolder = new ServletHolder(servletContainer); + List contextHandlers = new ArrayList<>(); + + ServletContextHandler context = new ServletContextHandler(ServletContextHandler.SESSIONS); + context.setContextPath("/"); + context.addServlet(servletHolder, "/*"); + contextHandlers.add(context); + + ServletContextHandler adminContext = null; + if (adminResourceConfig != resourceConfig) { + adminContext = new ServletContextHandler(ServletContextHandler.SESSIONS); + ServletHolder adminServletHolder = new ServletHolder(new ServletContainer(adminResourceConfig)); + adminContext.setContextPath("/"); + adminContext.addServlet(adminServletHolder, "/*"); + adminContext.setVirtualHosts(new String[]{"@" + ADMIN_SERVER_CONNECTOR_NAME}); + contextHandlers.add(adminContext); + } + + String allowedOrigins = config.getString(WorkerConfig.ACCESS_CONTROL_ALLOW_ORIGIN_CONFIG); + if (!Utils.isBlank(allowedOrigins)) { + FilterHolder filterHolder = new FilterHolder(new CrossOriginFilter()); + filterHolder.setName("cross-origin"); + filterHolder.setInitParameter(CrossOriginFilter.ALLOWED_ORIGINS_PARAM, allowedOrigins); + String allowedMethods = config.getString(WorkerConfig.ACCESS_CONTROL_ALLOW_METHODS_CONFIG); + if 
(!Utils.isBlank(allowedMethods)) { + filterHolder.setInitParameter(CrossOriginFilter.ALLOWED_METHODS_PARAM, allowedMethods); + } + context.addFilter(filterHolder, "/*", EnumSet.of(DispatcherType.REQUEST)); + } + + String headerConfig = config.getString(WorkerConfig.RESPONSE_HTTP_HEADERS_CONFIG); + if (!Utils.isBlank(headerConfig)) { + configureHttpResponsHeaderFilter(context); + } + + RequestLogHandler requestLogHandler = new RequestLogHandler(); + Slf4jRequestLogWriter slf4jRequestLogWriter = new Slf4jRequestLogWriter(); + slf4jRequestLogWriter.setLoggerName(RestServer.class.getCanonicalName()); + CustomRequestLog requestLog = new CustomRequestLog(slf4jRequestLogWriter, CustomRequestLog.EXTENDED_NCSA_FORMAT + " %{ms}T"); + requestLogHandler.setRequestLog(requestLog); + + contextHandlers.add(new DefaultHandler()); + contextHandlers.add(requestLogHandler); + + handlers.setHandlers(contextHandlers.toArray(new Handler[0])); + try { + context.start(); + } catch (Exception e) { + throw new ConnectException("Unable to initialize REST resources", e); + } + + if (adminResourceConfig != resourceConfig) { + try { + log.debug("Starting admin context"); + adminContext.start(); + } catch (Exception e) { + throw new ConnectException("Unable to initialize Admin REST resources", e); + } + } + + log.info("REST resources initialized; server is started and ready to handle requests"); + } + + public URI serverUrl() { + return jettyServer.getURI(); + } + + public void stop() { + log.info("Stopping REST server"); + + try { + for (ConnectRestExtension connectRestExtension : connectRestExtensions) { + try { + connectRestExtension.close(); + } catch (IOException e) { + log.warn("Error while invoking close on " + connectRestExtension.getClass(), e); + } + } + jettyServer.stop(); + jettyServer.join(); + } catch (Exception e) { + jettyServer.destroy(); + throw new ConnectException("Unable to stop REST server", e); + } + + log.info("REST server stopped"); + } + + /** + * Get the URL to advertise to other workers and clients. This uses the default connector from the embedded Jetty + * server, unless overrides for advertised hostname and/or port are provided via configs. {@link #initializeServer()} + * must be invoked successfully before calling this method. + */ + public URI advertisedUrl() { + UriBuilder builder = UriBuilder.fromUri(jettyServer.getURI()); + + String advertisedSecurityProtocol = determineAdvertisedProtocol(); + ServerConnector serverConnector = findConnector(advertisedSecurityProtocol); + builder.scheme(advertisedSecurityProtocol); + + String advertisedHostname = config.getString(WorkerConfig.REST_ADVERTISED_HOST_NAME_CONFIG); + if (advertisedHostname != null && !advertisedHostname.isEmpty()) + builder.host(advertisedHostname); + else if (serverConnector != null && serverConnector.getHost() != null && serverConnector.getHost().length() > 0) + builder.host(serverConnector.getHost()); + + Integer advertisedPort = config.getInt(WorkerConfig.REST_ADVERTISED_PORT_CONFIG); + if (advertisedPort != null) + builder.port(advertisedPort); + else if (serverConnector != null && serverConnector.getPort() > 0) + builder.port(serverConnector.getPort()); + + log.info("Advertised URI: {}", builder.build()); + + return builder.build(); + } + + /** + * @return the admin url for this worker. can be null if admin endpoints are disabled. 
+ */ + public URI adminUrl() { + ServerConnector adminConnector = null; + for (Connector connector : jettyServer.getConnectors()) { + if (ADMIN_SERVER_CONNECTOR_NAME.equals(connector.getName())) + adminConnector = (ServerConnector) connector; + } + + if (adminConnector == null) { + List adminListeners = config.getList(WorkerConfig.ADMIN_LISTENERS_CONFIG); + if (adminListeners == null) { + return advertisedUrl(); + } else if (adminListeners.isEmpty()) { + return null; + } else { + log.error("No admin connector found for listeners {}", adminListeners); + return null; + } + } + + UriBuilder builder = UriBuilder.fromUri(jettyServer.getURI()); + builder.port(adminConnector.getLocalPort()); + + return builder.build(); + } + + String determineAdvertisedProtocol() { + String advertisedSecurityProtocol = config.getString(WorkerConfig.REST_ADVERTISED_LISTENER_CONFIG); + if (advertisedSecurityProtocol == null) { + String listeners = (String) config.originals().get(WorkerConfig.LISTENERS_CONFIG); + + if (listeners == null) + return PROTOCOL_HTTP; + else + listeners = listeners.toLowerCase(Locale.ENGLISH); + + if (listeners.contains(String.format("%s://", PROTOCOL_HTTP))) + return PROTOCOL_HTTP; + else if (listeners.contains(String.format("%s://", PROTOCOL_HTTPS))) + return PROTOCOL_HTTPS; + else + return PROTOCOL_HTTP; + } else { + return advertisedSecurityProtocol.toLowerCase(Locale.ENGLISH); + } + } + + /** + * Locate a Jetty connector for the standard (non-admin) REST API that uses the given protocol. + * @param protocol the protocol for the connector (e.g., "http" or "https"). + * @return a {@link ServerConnector} for the server that uses the requested protocol, or + * {@code null} if none exist. + */ + ServerConnector findConnector(String protocol) { + for (Connector connector : jettyServer.getConnectors()) { + String connectorName = connector.getName(); + // We set the names for these connectors when instantiating them, beginning with the + // protocol for the connector and then an underscore ("_"). We rely on that format here + // when trying to locate a connector with the requested protocol; if the naming format + // for the connectors we create is ever changed, we'll need to adjust the logic here + // accordingly. + if (connectorName.startsWith(protocol + "_") && !ADMIN_SERVER_CONNECTOR_NAME.equals(connectorName)) + return (ServerConnector) connector; + } + + return null; + } + + void registerRestExtensions(Herder herder, ResourceConfig resourceConfig) { + connectRestExtensions = herder.plugins().newPlugins( + config.getList(WorkerConfig.REST_EXTENSION_CLASSES_CONFIG), + config, ConnectRestExtension.class); + + long herderRequestTimeoutMs = ConnectorsResource.REQUEST_TIMEOUT_MS; + + Integer rebalanceTimeoutMs = config.getRebalanceTimeout(); + + if (rebalanceTimeoutMs != null) { + herderRequestTimeoutMs = Math.min(herderRequestTimeoutMs, rebalanceTimeoutMs.longValue()); + } + + ConnectClusterDetails connectClusterDetails = new ConnectClusterDetailsImpl( + herder.kafkaClusterId() + ); + + ConnectRestExtensionContext connectRestExtensionContext = + new ConnectRestExtensionContextImpl( + new ConnectRestConfigurable(resourceConfig), + new ConnectClusterStateImpl(herderRequestTimeoutMs, connectClusterDetails, herder) + ); + for (ConnectRestExtension connectRestExtension : connectRestExtensions) { + connectRestExtension.register(connectRestExtensionContext); + } + + } + + /** + * Register header filter to ServletContextHandler. 
+ * @param context The serverlet context handler + */ + protected void configureHttpResponsHeaderFilter(ServletContextHandler context) { + String headerConfig = config.getString(WorkerConfig.RESPONSE_HTTP_HEADERS_CONFIG); + FilterHolder headerFilterHolder = new FilterHolder(HeaderFilter.class); + headerFilterHolder.setInitParameter("headerConfig", headerConfig); + context.addFilter(headerFilterHolder, "/*", EnumSet.of(DispatcherType.REQUEST)); + } +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/rest/entities/ActiveTopicsInfo.java b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/rest/entities/ActiveTopicsInfo.java new file mode 100644 index 0000000..b43c5aa --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/rest/entities/ActiveTopicsInfo.java @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.runtime.rest.entities; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.util.Collection; + +public class ActiveTopicsInfo { + private final String connector; + private final Collection topics; + + @JsonCreator + public ActiveTopicsInfo(String connector, @JsonProperty("topics") Collection topics) { + this.connector = connector; + this.topics = topics; + } + + public String connector() { + return connector; + } + + @JsonProperty + public Collection topics() { + return topics; + } + +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/rest/entities/ConfigInfo.java b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/rest/entities/ConfigInfo.java new file mode 100644 index 0000000..49a2f6f --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/rest/entities/ConfigInfo.java @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
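As a rough usage sketch of the REST server wired up above: once initializeServer() and initializeResources() have completed, the worker answers JSON over the configured listener, with the registered RootResource assumed to sit at the base path. The address http://localhost:8083 below is a hypothetical default, not taken from this patch.

import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;

public class RestServerSmokeCheck {
    public static void main(String[] args) throws Exception {
        // Hypothetical worker address; substitute the advertisedUrl() the worker logs at startup.
        HttpClient client = HttpClient.newHttpClient();
        HttpRequest request = HttpRequest.newBuilder(URI.create("http://localhost:8083/"))
                .header("Accept", "application/json")
                .GET()
                .build();
        HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString());
        // Expected body (per the ServerInfo entity later in this patch): version, commit, kafka_cluster_id.
        System.out.println(response.statusCode() + " " + response.body());
    }
}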
+ */ +package org.apache.kafka.connect.runtime.rest.entities; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.util.Objects; + +public class ConfigInfo { + + private ConfigKeyInfo configKey; + private ConfigValueInfo configValue; + + @JsonCreator + public ConfigInfo( + @JsonProperty("definition") ConfigKeyInfo configKey, + @JsonProperty("value") ConfigValueInfo configValue) { + this.configKey = configKey; + this.configValue = configValue; + } + + @JsonProperty("definition") + public ConfigKeyInfo configKey() { + return configKey; + } + + @JsonProperty("value") + public ConfigValueInfo configValue() { + return configValue; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + ConfigInfo that = (ConfigInfo) o; + return Objects.equals(configKey, that.configKey) && + Objects.equals(configValue, that.configValue); + } + + @Override + public int hashCode() { + return Objects.hash(configKey, configValue); + } + + @Override + public String toString() { + return "[" + configKey.toString() + "," + configValue.toString() + "]"; + } +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/rest/entities/ConfigInfos.java b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/rest/entities/ConfigInfos.java new file mode 100644 index 0000000..d5970b5 --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/rest/entities/ConfigInfos.java @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.connect.runtime.rest.entities; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.util.List; +import java.util.Objects; + +public class ConfigInfos { + + @JsonProperty("name") + private final String name; + + @JsonProperty("error_count") + private final int errorCount; + + @JsonProperty("groups") + private final List groups; + + @JsonProperty("configs") + private final List configs; + + @JsonCreator + public ConfigInfos(@JsonProperty("name") String name, + @JsonProperty("error_count") int errorCount, + @JsonProperty("groups") List groups, + @JsonProperty("configs") List configs) { + this.name = name; + this.groups = groups; + this.errorCount = errorCount; + this.configs = configs; + } + + @JsonProperty + public String name() { + return name; + } + + @JsonProperty + public List groups() { + return groups; + } + + @JsonProperty("error_count") + public int errorCount() { + return errorCount; + } + + @JsonProperty("configs") + public List values() { + return configs; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + ConfigInfos that = (ConfigInfos) o; + return Objects.equals(name, that.name) && + Objects.equals(errorCount, that.errorCount) && + Objects.equals(groups, that.groups) && + Objects.equals(configs, that.configs); + } + + @Override + public int hashCode() { + return Objects.hash(name, errorCount, groups, configs); + } + + @Override + public String toString() { + StringBuffer sb = new StringBuffer(); + sb.append("[") + .append(name) + .append(",") + .append(errorCount) + .append(",") + .append(groups) + .append(",") + .append(configs) + .append("]"); + return sb.toString(); + } + +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/rest/entities/ConfigKeyInfo.java b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/rest/entities/ConfigKeyInfo.java new file mode 100644 index 0000000..728ecc5 --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/rest/entities/ConfigKeyInfo.java @@ -0,0 +1,170 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
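A minimal serialization sketch for the ConfigInfos entity above, illustrating the snake_case property names fixed by its @JsonProperty annotations; the connector name is made up and the groups/configs lists are left empty for brevity.

import com.fasterxml.jackson.databind.ObjectMapper;
import org.apache.kafka.connect.runtime.rest.entities.ConfigInfos;

import java.util.Collections;

public class ConfigInfosJsonExample {
    public static void main(String[] args) throws Exception {
        ObjectMapper mapper = new ObjectMapper();
        // Hypothetical validation result: no groups, no per-key results, zero errors.
        ConfigInfos infos = new ConfigInfos("my-connector", 0,
                Collections.emptyList(), Collections.emptyList());
        // Prints a JSON object like {"name":"my-connector","error_count":0,"groups":[],"configs":[]}
        // (field order may vary).
        System.out.println(mapper.writeValueAsString(infos));
    }
}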
+ */ +package org.apache.kafka.connect.runtime.rest.entities; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.util.List; +import java.util.Objects; + +public class ConfigKeyInfo { + + private final String name; + private final String type; + private final boolean required; + private final String defaultValue; + private final String importance; + private final String documentation; + private final String group; + private final int orderInGroup; + private final String width; + private final String displayName; + private final List dependents; + + @JsonCreator + public ConfigKeyInfo(@JsonProperty("name") String name, + @JsonProperty("type") String type, + @JsonProperty("required") boolean required, + @JsonProperty("default_value") String defaultValue, + @JsonProperty("importance") String importance, + @JsonProperty("documentation") String documentation, + @JsonProperty("group") String group, + @JsonProperty("order_in_group") int orderInGroup, + @JsonProperty("width") String width, + @JsonProperty("display_name") String displayName, + @JsonProperty("dependents") List dependents) { + this.name = name; + this.type = type; + this.required = required; + this.defaultValue = defaultValue; + this.importance = importance; + this.documentation = documentation; + this.group = group; + this.orderInGroup = orderInGroup; + this.width = width; + this.displayName = displayName; + this.dependents = dependents; + } + + @JsonProperty + public String name() { + return name; + } + + @JsonProperty + public String type() { + return type; + } + + @JsonProperty + public boolean required() { + return required; + } + + @JsonProperty("default_value") + public String defaultValue() { + return defaultValue; + } + + @JsonProperty + public String documentation() { + return documentation; + } + + @JsonProperty + public String group() { + return group; + } + + @JsonProperty("order") + public int orderInGroup() { + return orderInGroup; + } + + @JsonProperty + public String width() { + return width; + } + + @JsonProperty + public String importance() { + return importance; + } + + @JsonProperty("display_name") + public String displayName() { + return displayName; + } + + @JsonProperty + public List dependents() { + return dependents; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + ConfigKeyInfo that = (ConfigKeyInfo) o; + return Objects.equals(name, that.name) && + Objects.equals(type, that.type) && + Objects.equals(required, that.required) && + Objects.equals(defaultValue, that.defaultValue) && + Objects.equals(importance, that.importance) && + Objects.equals(documentation, that.documentation) && + Objects.equals(group, that.group) && + Objects.equals(orderInGroup, that.orderInGroup) && + Objects.equals(width, that.width) && + Objects.equals(displayName, that.displayName) && + Objects.equals(dependents, that.dependents); + } + + @Override + public int hashCode() { + return Objects.hash(name, type, required, defaultValue, importance, documentation, group, orderInGroup, width, displayName, dependents); + } + + @Override + public String toString() { + StringBuffer sb = new StringBuffer(); + sb.append("[") + .append(name) + .append(",") + .append(type) + .append(",") + .append(required) + .append(",") + .append(defaultValue) + .append(",") + .append(importance) + .append(",") + .append(documentation) + .append(",") + .append(group) + .append(",") + 
.append(orderInGroup) + .append(",") + .append(width) + .append(",") + .append(displayName) + .append(",") + .append(dependents) + .append("]"); + return sb.toString(); + } +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/rest/entities/ConfigValueInfo.java b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/rest/entities/ConfigValueInfo.java new file mode 100644 index 0000000..abdcf93 --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/rest/entities/ConfigValueInfo.java @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.runtime.rest.entities; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.util.List; +import java.util.Objects; + +public class ConfigValueInfo { + private String name; + private String value; + private List recommendedValues; + private List errors; + private boolean visible; + + @JsonCreator + public ConfigValueInfo( + @JsonProperty("name") String name, + @JsonProperty("value") String value, + @JsonProperty("recommended_values") List recommendedValues, + @JsonProperty("errors") List errors, + @JsonProperty("visible") boolean visible) { + this.name = name; + this.value = value; + this.recommendedValues = recommendedValues; + this.errors = errors; + this.visible = visible; + } + + @JsonProperty + public String name() { + return name; + } + + @JsonProperty + public String value() { + return value; + } + + @JsonProperty("recommended_values") + public List recommendedValues() { + return recommendedValues; + } + + @JsonProperty + public List errors() { + return errors; + } + + @JsonProperty + public boolean visible() { + return visible; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + ConfigValueInfo that = (ConfigValueInfo) o; + return Objects.equals(name, that.name) && + Objects.equals(value, that.value) && + Objects.equals(recommendedValues, that.recommendedValues) && + Objects.equals(errors, that.errors) && + Objects.equals(visible, that.visible); + } + + @Override + public int hashCode() { + return Objects.hash(name, value, recommendedValues, errors, visible); + } + + @Override + public String toString() { + StringBuffer sb = new StringBuffer(); + sb.append("[") + .append(name) + .append(",") + .append(value) + .append(",") + .append(recommendedValues) + .append(",") + .append(errors) + .append(",") + .append(visible) + .append("]"); + return sb.toString(); + } + +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/rest/entities/ConnectorInfo.java 
b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/rest/entities/ConnectorInfo.java new file mode 100644 index 0000000..f36ee74 --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/rest/entities/ConnectorInfo.java @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.runtime.rest.entities; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import org.apache.kafka.connect.util.ConnectorTaskId; + +import java.util.List; +import java.util.Map; +import java.util.Objects; + +public class ConnectorInfo { + + private final String name; + private final Map config; + private final List tasks; + private final ConnectorType type; + + @JsonCreator + public ConnectorInfo(@JsonProperty("name") String name, + @JsonProperty("config") Map config, + @JsonProperty("tasks") List tasks, + @JsonProperty("type") ConnectorType type) { + this.name = name; + this.config = config; + this.tasks = tasks; + this.type = type; + } + + + @JsonProperty + public String name() { + return name; + } + + @JsonProperty + public ConnectorType type() { + return type; + } + + @JsonProperty + public Map config() { + return config; + } + + @JsonProperty + public List tasks() { + return tasks; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + ConnectorInfo that = (ConnectorInfo) o; + return Objects.equals(name, that.name) && + Objects.equals(config, that.config) && + Objects.equals(tasks, that.tasks) && + Objects.equals(type, that.type); + } + + @Override + public int hashCode() { + return Objects.hash(name, config, tasks, type); + } + +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/rest/entities/ConnectorPluginInfo.java b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/rest/entities/ConnectorPluginInfo.java new file mode 100644 index 0000000..36b896f --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/rest/entities/ConnectorPluginInfo.java @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.runtime.rest.entities; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import org.apache.kafka.connect.connector.Connector; +import org.apache.kafka.connect.runtime.isolation.PluginDesc; + +import java.util.Objects; + +public class ConnectorPluginInfo { + private String className; + private ConnectorType type; + private String version; + + @JsonCreator + public ConnectorPluginInfo( + @JsonProperty("class") String className, + @JsonProperty("type") ConnectorType type, + @JsonProperty("version") String version + ) { + this.className = className; + this.type = type; + this.version = version; + } + + public ConnectorPluginInfo(PluginDesc plugin) { + this(plugin.className(), ConnectorType.from(plugin.pluginClass()), plugin.version()); + } + + @JsonProperty("class") + public String className() { + return className; + } + + @JsonProperty("type") + public ConnectorType type() { + return type; + } + + @JsonProperty("version") + public String version() { + return version; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + ConnectorPluginInfo that = (ConnectorPluginInfo) o; + return Objects.equals(className, that.className) && + type == that.type && + Objects.equals(version, that.version); + } + + @Override + public int hashCode() { + return Objects.hash(className, type, version); + } + + @Override + public String toString() { + final StringBuilder sb = new StringBuilder("ConnectorPluginInfo{"); + sb.append("className='").append(className).append('\''); + sb.append(", type=").append(type); + sb.append(", version='").append(version).append('\''); + sb.append('}'); + return sb.toString(); + } +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/rest/entities/ConnectorStateInfo.java b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/rest/entities/ConnectorStateInfo.java new file mode 100644 index 0000000..6280473 --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/rest/entities/ConnectorStateInfo.java @@ -0,0 +1,139 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.connect.runtime.rest.entities; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.util.List; +import java.util.Objects; + +public class ConnectorStateInfo { + + private final String name; + private final ConnectorState connector; + private final List tasks; + private final ConnectorType type; + + @JsonCreator + public ConnectorStateInfo(@JsonProperty("name") String name, + @JsonProperty("connector") ConnectorState connector, + @JsonProperty("tasks") List tasks, + @JsonProperty("type") ConnectorType type) { + this.name = name; + this.connector = connector; + this.tasks = tasks; + this.type = type; + } + + @JsonProperty + public String name() { + return name; + } + + @JsonProperty + public ConnectorState connector() { + return connector; + } + + @JsonProperty + public List tasks() { + return tasks; + } + + @JsonProperty + public ConnectorType type() { + return type; + } + + public abstract static class AbstractState { + private final String state; + private final String trace; + private final String workerId; + + public AbstractState(String state, String workerId, String trace) { + this.state = state; + this.workerId = workerId; + this.trace = trace; + } + + @JsonProperty + public String state() { + return state; + } + + @JsonProperty("worker_id") + public String workerId() { + return workerId; + } + + @JsonProperty + @JsonInclude(JsonInclude.Include.NON_EMPTY) + public String trace() { + return trace; + } + } + + public static class ConnectorState extends AbstractState { + @JsonCreator + public ConnectorState(@JsonProperty("state") String state, + @JsonProperty("worker_id") String worker, + @JsonProperty("msg") String msg) { + super(state, worker, msg); + } + } + + public static class TaskState extends AbstractState implements Comparable { + private final int id; + + @JsonCreator + public TaskState(@JsonProperty("id") int id, + @JsonProperty("state") String state, + @JsonProperty("worker_id") String worker, + @JsonProperty("msg") String msg) { + super(state, worker, msg); + this.id = id; + } + + @JsonProperty + public int id() { + return id; + } + + @Override + public int compareTo(TaskState that) { + return Integer.compare(this.id, that.id); + } + + @Override + public boolean equals(Object o) { + if (o == this) + return true; + if (!(o instanceof TaskState)) + return false; + TaskState other = (TaskState) o; + return compareTo(other) == 0; + } + + @Override + public int hashCode() { + return Objects.hash(id); + } + } + +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/rest/entities/ConnectorType.java b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/rest/entities/ConnectorType.java new file mode 100644 index 0000000..292a1ee --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/rest/entities/ConnectorType.java @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.runtime.rest.entities; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonValue; + +import org.apache.kafka.connect.connector.Connector; +import org.apache.kafka.connect.sink.SinkConnector; +import org.apache.kafka.connect.source.SourceConnector; + +import java.util.Locale; + +public enum ConnectorType { + SOURCE, SINK, UNKNOWN; + + public static ConnectorType from(Class clazz) { + if (SinkConnector.class.isAssignableFrom(clazz)) { + return SINK; + } + if (SourceConnector.class.isAssignableFrom(clazz)) { + return SOURCE; + } + + return UNKNOWN; + } + + @Override + @JsonValue + public String toString() { + return super.toString().toLowerCase(Locale.ROOT); + } + + @JsonCreator + public static ConnectorType forValue(String value) { + return ConnectorType.valueOf(value.toUpperCase(Locale.ROOT)); + } +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/rest/entities/CreateConnectorRequest.java b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/rest/entities/CreateConnectorRequest.java new file mode 100644 index 0000000..1c52d8d --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/rest/entities/CreateConnectorRequest.java @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
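A tiny illustration of the ConnectorType mappings defined above; the inputs are chosen for the example and are not taken from this patch.

import org.apache.kafka.connect.connector.Connector;
import org.apache.kafka.connect.runtime.rest.entities.ConnectorType;

public class ConnectorTypeExample {
    public static void main(String[] args) {
        // Lookup is case-insensitive thanks to the Locale.ROOT upper-casing in forValue().
        ConnectorType sink = ConnectorType.forValue("Sink");
        // toString() lower-cases the constant, which is also what Jackson writes via @JsonValue.
        System.out.println(sink);                                  // prints "sink"
        // A class that is neither a SinkConnector nor a SourceConnector maps to UNKNOWN.
        System.out.println(ConnectorType.from(Connector.class));   // prints "unknown"
    }
}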
+ */ +package org.apache.kafka.connect.runtime.rest.entities; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.util.Map; +import java.util.Objects; + +public class CreateConnectorRequest { + private final String name; + private final Map config; + + @JsonCreator + public CreateConnectorRequest(@JsonProperty("name") String name, @JsonProperty("config") Map config) { + this.name = name; + this.config = config; + } + + @JsonProperty + public String name() { + return name; + } + + @JsonProperty + public Map config() { + return config; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + CreateConnectorRequest that = (CreateConnectorRequest) o; + return Objects.equals(name, that.name) && + Objects.equals(config, that.config); + } + + @Override + public int hashCode() { + return Objects.hash(name, config); + } +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/rest/entities/ErrorMessage.java b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/rest/entities/ErrorMessage.java new file mode 100644 index 0000000..ecc4de5 --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/rest/entities/ErrorMessage.java @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.runtime.rest.entities; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.util.Objects; + +/** + * Standard error format for all REST API failures. These are generated automatically by + * {@link org.apache.kafka.connect.runtime.rest.errors.ConnectExceptionMapper} in response to uncaught + * {@link org.apache.kafka.connect.errors.ConnectException}s. 
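CreateConnectorRequest above is the body shape accepted by POST /connectors (see ConnectorsResource further down). A sketch of building and submitting such a request follows; the worker address and the connector settings (connector.class, tasks.max, topic) are hypothetical placeholders.

import com.fasterxml.jackson.databind.ObjectMapper;
import org.apache.kafka.connect.runtime.rest.entities.CreateConnectorRequest;

import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.util.LinkedHashMap;
import java.util.Map;

public class CreateConnectorExample {
    public static void main(String[] args) throws Exception {
        Map<String, String> config = new LinkedHashMap<>();
        config.put("connector.class", "org.apache.kafka.connect.file.FileStreamSourceConnector"); // hypothetical plugin
        config.put("tasks.max", "1");
        config.put("topic", "connect-test");

        String body = new ObjectMapper()
                .writeValueAsString(new CreateConnectorRequest("local-file-source", config));

        HttpRequest request = HttpRequest.newBuilder(URI.create("http://localhost:8083/connectors")) // hypothetical address
                .header("Content-Type", "application/json")
                .POST(HttpRequest.BodyPublishers.ofString(body))
                .build();
        HttpResponse<String> response = HttpClient.newHttpClient()
                .send(request, HttpResponse.BodyHandlers.ofString());
        // 201 Created with the ConnectorInfo entity on success; an ErrorMessage JSON body otherwise.
        System.out.println(response.statusCode() + " " + response.body());
    }
}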
+ */ +public class ErrorMessage { + private final int errorCode; + private final String message; + + @JsonCreator + public ErrorMessage(@JsonProperty("error_code") int errorCode, @JsonProperty("message") String message) { + this.errorCode = errorCode; + this.message = message; + } + + @JsonProperty("error_code") + public int errorCode() { + return errorCode; + } + + @JsonProperty + public String message() { + return message; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + ErrorMessage that = (ErrorMessage) o; + return Objects.equals(errorCode, that.errorCode) && + Objects.equals(message, that.message); + } + + @Override + public int hashCode() { + return Objects.hash(errorCode, message); + } +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/rest/entities/ServerInfo.java b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/rest/entities/ServerInfo.java new file mode 100644 index 0000000..e5c5553 --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/rest/entities/ServerInfo.java @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.runtime.rest.entities; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import org.apache.kafka.common.utils.AppInfoParser; + +public class ServerInfo { + private final String version; + private final String commit; + private final String kafkaClusterId; + + @JsonCreator + private ServerInfo(@JsonProperty("version") String version, + @JsonProperty("commit") String commit, + @JsonProperty("kafka_cluster_id") String kafkaClusterId) { + this.version = version; + this.commit = commit; + this.kafkaClusterId = kafkaClusterId; + } + + public ServerInfo(String kafkaClusterId) { + this(AppInfoParser.getVersion(), AppInfoParser.getCommitId(), kafkaClusterId); + } + + @JsonProperty + public String version() { + return version; + } + + @JsonProperty + public String commit() { + return commit; + } + + @JsonProperty("kafka_cluster_id") + public String clusterId() { + return kafkaClusterId; + } +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/rest/entities/TaskInfo.java b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/rest/entities/TaskInfo.java new file mode 100644 index 0000000..8e6f3d7 --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/rest/entities/TaskInfo.java @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.runtime.rest.entities; + +import com.fasterxml.jackson.annotation.JsonProperty; +import org.apache.kafka.connect.util.ConnectorTaskId; + +import java.util.Map; +import java.util.Objects; + +public class TaskInfo { + private final ConnectorTaskId id; + private final Map config; + + public TaskInfo(ConnectorTaskId id, Map config) { + this.id = id; + this.config = config; + } + + @JsonProperty + public ConnectorTaskId id() { + return id; + } + + @JsonProperty + public Map config() { + return config; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + TaskInfo taskInfo = (TaskInfo) o; + return Objects.equals(id, taskInfo.id) && + Objects.equals(config, taskInfo.config); + } + + @Override + public int hashCode() { + return Objects.hash(id, config); + } +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/rest/errors/BadRequestException.java b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/rest/errors/BadRequestException.java new file mode 100644 index 0000000..bc9c7f2 --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/rest/errors/BadRequestException.java @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.connect.runtime.rest.errors; + +import javax.ws.rs.core.Response; + +public class BadRequestException extends ConnectRestException { + + public BadRequestException(String message) { + super(Response.Status.BAD_REQUEST, message); + } + +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/rest/errors/ConnectExceptionMapper.java b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/rest/errors/ConnectExceptionMapper.java new file mode 100644 index 0000000..8678fbf --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/rest/errors/ConnectExceptionMapper.java @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.runtime.rest.errors; + +import org.apache.kafka.connect.errors.AlreadyExistsException; +import org.apache.kafka.connect.errors.NotFoundException; +import org.apache.kafka.connect.runtime.rest.entities.ErrorMessage; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.ws.rs.WebApplicationException; +import javax.ws.rs.core.Context; +import javax.ws.rs.core.Response; +import javax.ws.rs.core.UriInfo; +import javax.ws.rs.ext.ExceptionMapper; + +public class ConnectExceptionMapper implements ExceptionMapper { + private static final Logger log = LoggerFactory.getLogger(ConnectExceptionMapper.class); + + @Context + private UriInfo uriInfo; + + @Override + public Response toResponse(Exception exception) { + log.debug("Uncaught exception in REST call to /{}", uriInfo.getPath(), exception); + + if (exception instanceof ConnectRestException) { + ConnectRestException restException = (ConnectRestException) exception; + return Response.status(restException.statusCode()) + .entity(new ErrorMessage(restException.errorCode(), restException.getMessage())) + .build(); + } + + if (exception instanceof NotFoundException) { + return Response.status(Response.Status.NOT_FOUND) + .entity(new ErrorMessage(Response.Status.NOT_FOUND.getStatusCode(), exception.getMessage())) + .build(); + } + + if (exception instanceof AlreadyExistsException) { + return Response.status(Response.Status.CONFLICT) + .entity(new ErrorMessage(Response.Status.CONFLICT.getStatusCode(), exception.getMessage())) + .build(); + } + + if (!log.isDebugEnabled()) { + log.error("Uncaught exception in REST call to /{}", uriInfo.getPath(), exception); + } + + final int statusCode; + if (exception instanceof WebApplicationException) { + Response.StatusType statusInfo = ((WebApplicationException) exception).getResponse().getStatusInfo(); + statusCode = statusInfo.getStatusCode(); + } else { + statusCode = Response.Status.INTERNAL_SERVER_ERROR.getStatusCode(); + } + return Response.status(statusCode) + .entity(new ErrorMessage(statusCode, 
exception.getMessage())) + .build(); + } +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/rest/errors/ConnectRestException.java b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/rest/errors/ConnectRestException.java new file mode 100644 index 0000000..f45f72d --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/rest/errors/ConnectRestException.java @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.runtime.rest.errors; + +import org.apache.kafka.connect.errors.ConnectException; + +import javax.ws.rs.core.Response; + +public class ConnectRestException extends ConnectException { + private final int statusCode; + private final int errorCode; + + public ConnectRestException(int statusCode, int errorCode, String message, Throwable t) { + super(message, t); + this.statusCode = statusCode; + this.errorCode = errorCode; + } + + public ConnectRestException(Response.Status status, int errorCode, String message, Throwable t) { + this(status.getStatusCode(), errorCode, message, t); + } + + public ConnectRestException(int statusCode, int errorCode, String message) { + this(statusCode, errorCode, message, null); + } + + public ConnectRestException(Response.Status status, int errorCode, String message) { + this(status, errorCode, message, null); + } + + public ConnectRestException(int statusCode, String message, Throwable t) { + this(statusCode, statusCode, message, t); + } + + public ConnectRestException(Response.Status status, String message, Throwable t) { + this(status, status.getStatusCode(), message, t); + } + + public ConnectRestException(int statusCode, String message) { + this(statusCode, statusCode, message, null); + } + + public ConnectRestException(Response.Status status, String message) { + this(status.getStatusCode(), status.getStatusCode(), message, null); + } + + + public int statusCode() { + return statusCode; + } + + public int errorCode() { + return errorCode; + } +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/rest/resources/ConnectorPluginsResource.java b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/rest/resources/ConnectorPluginsResource.java new file mode 100644 index 0000000..0854c8f --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/rest/resources/ConnectorPluginsResource.java @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
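Together, ConnectExceptionMapper and ConnectRestException above define the API's error contract: the HTTP status comes from statusCode() and the JSON body is an ErrorMessage built from errorCode() and the exception message. A small sketch of the body a client would see for a hypothetical 409 failure; the message text is invented for the example.

import com.fasterxml.jackson.databind.ObjectMapper;
import org.apache.kafka.connect.runtime.rest.entities.ErrorMessage;
import org.apache.kafka.connect.runtime.rest.errors.ConnectRestException;

import javax.ws.rs.core.Response;

public class ErrorMappingExample {
    public static void main(String[] args) throws Exception {
        // The two-argument constructor reuses the HTTP status code as the error code.
        ConnectRestException e = new ConnectRestException(Response.Status.CONFLICT,
                "example failure while handling the request");
        // This mirrors what ConnectExceptionMapper builds for a ConnectRestException.
        ErrorMessage body = new ErrorMessage(e.errorCode(), e.getMessage());
        // Prints a JSON object like {"error_code":409,"message":"example failure while handling the request"}.
        System.out.println(new ObjectMapper().writeValueAsString(body));
    }
}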
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.runtime.rest.resources; + +import org.apache.kafka.connect.connector.Connector; +import org.apache.kafka.connect.runtime.ConnectorConfig; +import org.apache.kafka.connect.runtime.Herder; +import org.apache.kafka.connect.runtime.isolation.PluginDesc; +import org.apache.kafka.connect.runtime.rest.entities.ConfigInfos; +import org.apache.kafka.connect.runtime.rest.entities.ConnectorPluginInfo; +import org.apache.kafka.connect.runtime.rest.errors.ConnectRestException; +import org.apache.kafka.connect.tools.MockConnector; +import org.apache.kafka.connect.tools.MockSinkConnector; +import org.apache.kafka.connect.tools.MockSourceConnector; +import org.apache.kafka.connect.tools.SchemaSourceConnector; +import org.apache.kafka.connect.tools.VerifiableSinkConnector; +import org.apache.kafka.connect.tools.VerifiableSourceConnector; +import org.apache.kafka.connect.util.FutureCallback; + +import javax.ws.rs.BadRequestException; +import javax.ws.rs.Consumes; +import javax.ws.rs.GET; +import javax.ws.rs.PUT; +import javax.ws.rs.Path; +import javax.ws.rs.PathParam; +import javax.ws.rs.Produces; +import javax.ws.rs.core.MediaType; +import javax.ws.rs.core.Response; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + +@Path("/connector-plugins") +@Produces(MediaType.APPLICATION_JSON) +@Consumes(MediaType.APPLICATION_JSON) +public class ConnectorPluginsResource { + + private static final String ALIAS_SUFFIX = "Connector"; + private final Herder herder; + private final List connectorPlugins; + + private static final List> CONNECTOR_EXCLUDES = Arrays.asList( + VerifiableSourceConnector.class, VerifiableSinkConnector.class, + MockConnector.class, MockSourceConnector.class, MockSinkConnector.class, + SchemaSourceConnector.class + ); + + public ConnectorPluginsResource(Herder herder) { + this.herder = herder; + this.connectorPlugins = new ArrayList<>(); + } + + @PUT + @Path("/{connectorType}/config/validate") + public ConfigInfos validateConfigs( + final @PathParam("connectorType") String connType, + final Map connectorConfig + ) throws Throwable { + String includedConnType = connectorConfig.get(ConnectorConfig.CONNECTOR_CLASS_CONFIG); + if (includedConnType != null + && !normalizedPluginName(includedConnType).endsWith(normalizedPluginName(connType))) { + throw new BadRequestException( + "Included connector type " + includedConnType + " does not match request type " + + connType + ); + } + + // the validated configs don't need to be logged + FutureCallback validationCallback = new FutureCallback<>(); + herder.validateConnectorConfig(connectorConfig, validationCallback, false); + + try { + return validationCallback.get(ConnectorsResource.REQUEST_TIMEOUT_MS, TimeUnit.MILLISECONDS); + } catch (TimeoutException e) { + // This timeout is 
for the operation itself. None of the timeout error codes are relevant, so internal server + // error is the best option + throw new ConnectRestException(Response.Status.INTERNAL_SERVER_ERROR.getStatusCode(), "Request timed out"); + } catch (InterruptedException e) { + throw new ConnectRestException(Response.Status.INTERNAL_SERVER_ERROR.getStatusCode(), "Request interrupted"); + } + } + + @GET + @Path("/") + public List listConnectorPlugins() { + return getConnectorPlugins(); + } + + // TODO: improve once plugins are allowed to be added/removed during runtime. + private synchronized List getConnectorPlugins() { + if (connectorPlugins.isEmpty()) { + for (PluginDesc plugin : herder.plugins().connectors()) { + if (!CONNECTOR_EXCLUDES.contains(plugin.pluginClass())) { + connectorPlugins.add(new ConnectorPluginInfo(plugin)); + } + } + } + + return Collections.unmodifiableList(connectorPlugins); + } + + private String normalizedPluginName(String pluginName) { + // Works for both full and simple class names. In the latter case, it generates the alias. + return pluginName.endsWith(ALIAS_SUFFIX) && pluginName.length() > ALIAS_SUFFIX.length() + ? pluginName.substring(0, pluginName.length() - ALIAS_SUFFIX.length()) + : pluginName; + } +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/rest/resources/ConnectorsResource.java b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/rest/resources/ConnectorsResource.java new file mode 100644 index 0000000..18b10c9 --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/rest/resources/ConnectorsResource.java @@ -0,0 +1,458 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
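The validate endpoint above accepts either the fully-qualified connector class or its short alias in the path, since normalizedPluginName() strips a trailing "Connector" from both sides before comparing. A sketch of such a call; the plugin class, its settings, and the worker address are hypothetical.

import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;

public class ValidatePluginConfigExample {
    public static void main(String[] args) throws Exception {
        // "FileStreamSink" is the alias form of org.apache.kafka.connect.file.FileStreamSinkConnector (hypothetical choice);
        // after normalization both sides reduce to "...FileStreamSink", so the request is accepted.
        String json = "{\"connector.class\":\"org.apache.kafka.connect.file.FileStreamSinkConnector\","
                + "\"topics\":\"connect-test\",\"file\":\"/tmp/sink.txt\"}";
        HttpRequest request = HttpRequest.newBuilder(
                        URI.create("http://localhost:8083/connector-plugins/FileStreamSink/config/validate"))
                .header("Content-Type", "application/json")
                .PUT(HttpRequest.BodyPublishers.ofString(json))
                .build();
        HttpResponse<String> response = HttpClient.newHttpClient()
                .send(request, HttpResponse.BodyHandlers.ofString());
        // The response body is the ConfigInfos entity described earlier (name, error_count, groups, configs).
        System.out.println(response.body());
    }
}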
+ */ +package org.apache.kafka.connect.runtime.rest.resources; + +import com.fasterxml.jackson.core.type.TypeReference; + +import javax.ws.rs.DefaultValue; +import javax.ws.rs.core.HttpHeaders; + +import com.fasterxml.jackson.databind.ObjectMapper; +import org.apache.kafka.connect.errors.NotFoundException; +import org.apache.kafka.connect.runtime.ConnectorConfig; +import org.apache.kafka.connect.runtime.Herder; +import org.apache.kafka.connect.runtime.RestartRequest; +import org.apache.kafka.connect.runtime.WorkerConfig; +import org.apache.kafka.connect.runtime.distributed.RebalanceNeededException; +import org.apache.kafka.connect.runtime.distributed.RequestTargetException; +import org.apache.kafka.connect.runtime.rest.InternalRequestSignature; +import org.apache.kafka.connect.runtime.rest.RestClient; +import org.apache.kafka.connect.runtime.rest.entities.ActiveTopicsInfo; +import org.apache.kafka.connect.runtime.rest.entities.ConnectorInfo; +import org.apache.kafka.connect.runtime.rest.entities.ConnectorStateInfo; +import org.apache.kafka.connect.runtime.rest.entities.CreateConnectorRequest; +import org.apache.kafka.connect.runtime.rest.entities.TaskInfo; +import org.apache.kafka.connect.runtime.rest.errors.ConnectRestException; +import org.apache.kafka.connect.util.ConnectorTaskId; +import org.apache.kafka.connect.util.FutureCallback; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.servlet.ServletContext; +import javax.ws.rs.BadRequestException; +import javax.ws.rs.Consumes; +import javax.ws.rs.DELETE; +import javax.ws.rs.GET; +import javax.ws.rs.POST; +import javax.ws.rs.PUT; +import javax.ws.rs.Path; +import javax.ws.rs.PathParam; +import javax.ws.rs.Produces; +import javax.ws.rs.QueryParam; +import javax.ws.rs.core.Context; +import javax.ws.rs.core.MediaType; +import javax.ws.rs.core.Response; +import javax.ws.rs.core.UriBuilder; +import javax.ws.rs.core.UriInfo; + +import java.net.URI; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + +import static org.apache.kafka.connect.runtime.WorkerConfig.TOPIC_TRACKING_ALLOW_RESET_CONFIG; +import static org.apache.kafka.connect.runtime.WorkerConfig.TOPIC_TRACKING_ENABLE_CONFIG; + +@Path("/connectors") +@Produces(MediaType.APPLICATION_JSON) +@Consumes(MediaType.APPLICATION_JSON) +public class ConnectorsResource { + private static final Logger log = LoggerFactory.getLogger(ConnectorsResource.class); + private static final TypeReference>> TASK_CONFIGS_TYPE = + new TypeReference>>() { }; + + // TODO: This should not be so long. However, due to potentially long rebalances that may have to wait a full + // session timeout to complete, during which we cannot serve some requests. Ideally we could reduce this, but + // we need to consider all possible scenarios this could fail. It might be ok to fail with a timeout in rare cases, + // but currently a worker simply leaving the group can take this long as well. 
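Since the comment above warns that a REST call may legitimately block for up to REQUEST_TIMEOUT_MS (90 seconds, per the constant that follows) while a rebalance completes, a client should allow at least that long before timing out. A sketch with java.net.http; the worker address is hypothetical.

import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.time.Duration;

public class PatientConnectClient {
    public static void main(String[] args) throws Exception {
        HttpClient client = HttpClient.newBuilder()
                .connectTimeout(Duration.ofSeconds(10))
                .build();
        HttpRequest request = HttpRequest.newBuilder(URI.create("http://localhost:8083/connectors"))
                // Give the worker the full 90s it may spend waiting out a rebalance before answering.
                .timeout(Duration.ofSeconds(90))
                .GET()
                .build();
        HttpResponse<String> response = client.send(request, HttpResponse.BodyHandlers.ofString());
        System.out.println(response.body());
    }
}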
+ public static final long REQUEST_TIMEOUT_MS = 90 * 1000; + // Mutable for integration testing; otherwise, some tests would take at least REQUEST_TIMEOUT_MS + // to run + private static long requestTimeoutMs = REQUEST_TIMEOUT_MS; + + private final Herder herder; + private final WorkerConfig config; + @javax.ws.rs.core.Context + private ServletContext context; + private final boolean isTopicTrackingDisabled; + private final boolean isTopicTrackingResetDisabled; + + public ConnectorsResource(Herder herder, WorkerConfig config) { + this.herder = herder; + this.config = config; + isTopicTrackingDisabled = !config.getBoolean(TOPIC_TRACKING_ENABLE_CONFIG); + isTopicTrackingResetDisabled = !config.getBoolean(TOPIC_TRACKING_ALLOW_RESET_CONFIG); + } + + // For testing purposes only + public static void setRequestTimeout(long requestTimeoutMs) { + ConnectorsResource.requestTimeoutMs = requestTimeoutMs; + } + + public static void resetRequestTimeout() { + ConnectorsResource.requestTimeoutMs = REQUEST_TIMEOUT_MS; + } + + @GET + @Path("/") + public Response listConnectors( + final @Context UriInfo uriInfo, + final @Context HttpHeaders headers + ) { + if (uriInfo.getQueryParameters().containsKey("expand")) { + Map> out = new HashMap<>(); + for (String connector : herder.connectors()) { + try { + Map connectorExpansions = new HashMap<>(); + for (String expansion : uriInfo.getQueryParameters().get("expand")) { + switch (expansion) { + case "status": + connectorExpansions.put("status", herder.connectorStatus(connector)); + break; + case "info": + connectorExpansions.put("info", herder.connectorInfo(connector)); + break; + default: + log.info("Ignoring unknown expansion type {}", expansion); + } + } + out.put(connector, connectorExpansions); + } catch (NotFoundException e) { + // this likely means that a connector has been removed while we look its info up + // we can just not include this connector in the return entity + log.debug("Unable to get connector info for {} on this worker", connector); + } + + } + return Response.ok(out).build(); + } else { + return Response.ok(herder.connectors()).build(); + } + } + + @POST + @Path("/") + public Response createConnector(final @QueryParam("forward") Boolean forward, + final @Context HttpHeaders headers, + final CreateConnectorRequest createRequest) throws Throwable { + // Trim leading and trailing whitespaces from the connector name, replace null with empty string + // if no name element present to keep validation within validator (NonEmptyStringWithoutControlChars + // allows null values) + String name = createRequest.name() == null ? 
"" : createRequest.name().trim(); + + Map configs = createRequest.config(); + checkAndPutConnectorConfigName(name, configs); + + FutureCallback> cb = new FutureCallback<>(); + herder.putConnectorConfig(name, configs, false, cb); + Herder.Created info = completeOrForwardRequest(cb, "/connectors", "POST", headers, createRequest, + new TypeReference() { }, new CreatedConnectorInfoTranslator(), forward); + + URI location = UriBuilder.fromUri("/connectors").path(name).build(); + return Response.created(location).entity(info.result()).build(); + } + + @GET + @Path("/{connector}") + public ConnectorInfo getConnector(final @PathParam("connector") String connector, + final @Context HttpHeaders headers, + final @QueryParam("forward") Boolean forward) throws Throwable { + FutureCallback cb = new FutureCallback<>(); + herder.connectorInfo(connector, cb); + return completeOrForwardRequest(cb, "/connectors/" + connector, "GET", headers, null, forward); + } + + @GET + @Path("/{connector}/config") + public Map getConnectorConfig(final @PathParam("connector") String connector, + final @Context HttpHeaders headers, + final @QueryParam("forward") Boolean forward) throws Throwable { + FutureCallback> cb = new FutureCallback<>(); + herder.connectorConfig(connector, cb); + return completeOrForwardRequest(cb, "/connectors/" + connector + "/config", "GET", headers, null, forward); + } + + @GET + @Path("/{connector}/tasks-config") + public Map> getTasksConfig( + final @PathParam("connector") String connector, + final @Context HttpHeaders headers, + final @QueryParam("forward") Boolean forward) throws Throwable { + FutureCallback>> cb = new FutureCallback<>(); + herder.tasksConfig(connector, cb); + return completeOrForwardRequest(cb, "/connectors/" + connector + "/tasks-config", "GET", headers, null, forward); + } + + @GET + @Path("/{connector}/status") + public ConnectorStateInfo getConnectorStatus(final @PathParam("connector") String connector) { + return herder.connectorStatus(connector); + } + + @GET + @Path("/{connector}/topics") + public Response getConnectorActiveTopics(final @PathParam("connector") String connector) { + if (isTopicTrackingDisabled) { + throw new ConnectRestException(Response.Status.FORBIDDEN.getStatusCode(), + "Topic tracking is disabled."); + } + ActiveTopicsInfo info = herder.connectorActiveTopics(connector); + return Response.ok(Collections.singletonMap(info.connector(), info)).build(); + } + + @PUT + @Path("/{connector}/topics/reset") + public Response resetConnectorActiveTopics(final @PathParam("connector") String connector, final @Context HttpHeaders headers) { + if (isTopicTrackingDisabled) { + throw new ConnectRestException(Response.Status.FORBIDDEN.getStatusCode(), + "Topic tracking is disabled."); + } + if (isTopicTrackingResetDisabled) { + throw new ConnectRestException(Response.Status.FORBIDDEN.getStatusCode(), + "Topic tracking reset is disabled."); + } + herder.resetConnectorActiveTopics(connector); + return Response.accepted().build(); + } + + @PUT + @Path("/{connector}/config") + public Response putConnectorConfig(final @PathParam("connector") String connector, + final @Context HttpHeaders headers, + final @QueryParam("forward") Boolean forward, + final Map connectorConfig) throws Throwable { + FutureCallback> cb = new FutureCallback<>(); + checkAndPutConnectorConfigName(connector, connectorConfig); + + herder.putConnectorConfig(connector, connectorConfig, true, cb); + Herder.Created createdInfo = completeOrForwardRequest(cb, "/connectors/" + connector + "/config", + "PUT", 
headers, connectorConfig, new TypeReference() { }, new CreatedConnectorInfoTranslator(), forward); + Response.ResponseBuilder response; + if (createdInfo.created()) { + URI location = UriBuilder.fromUri("/connectors").path(connector).build(); + response = Response.created(location); + } else { + response = Response.ok(); + } + return response.entity(createdInfo.result()).build(); + } + + @POST + @Path("/{connector}/restart") + public Response restartConnector(final @PathParam("connector") String connector, + final @Context HttpHeaders headers, + final @DefaultValue("false") @QueryParam("includeTasks") Boolean includeTasks, + final @DefaultValue("false") @QueryParam("onlyFailed") Boolean onlyFailed, + final @QueryParam("forward") Boolean forward) throws Throwable { + RestartRequest restartRequest = new RestartRequest(connector, onlyFailed, includeTasks); + String forwardingPath = "/connectors/" + connector + "/restart"; + if (restartRequest.forceRestartConnectorOnly()) { + // For backward compatibility, just restart the connector instance and return OK with no body + FutureCallback cb = new FutureCallback<>(); + herder.restartConnector(connector, cb); + completeOrForwardRequest(cb, forwardingPath, "POST", headers, null, forward); + return Response.noContent().build(); + } + + // In all other cases, submit the async restart request and return connector state + FutureCallback cb = new FutureCallback<>(); + herder.restartConnectorAndTasks(restartRequest, cb); + Map queryParameters = new HashMap<>(); + queryParameters.put("includeTasks", includeTasks.toString()); + queryParameters.put("onlyFailed", onlyFailed.toString()); + ConnectorStateInfo stateInfo = completeOrForwardRequest(cb, forwardingPath, "POST", headers, queryParameters, null, new TypeReference() { + }, new IdentityTranslator<>(), forward); + return Response.accepted().entity(stateInfo).build(); + } + + @PUT + @Path("/{connector}/pause") + public Response pauseConnector(@PathParam("connector") String connector, final @Context HttpHeaders headers) { + herder.pauseConnector(connector); + return Response.accepted().build(); + } + + @PUT + @Path("/{connector}/resume") + public Response resumeConnector(@PathParam("connector") String connector) { + herder.resumeConnector(connector); + return Response.accepted().build(); + } + + @GET + @Path("/{connector}/tasks") + public List getTaskConfigs(final @PathParam("connector") String connector, + final @Context HttpHeaders headers, + final @QueryParam("forward") Boolean forward) throws Throwable { + FutureCallback> cb = new FutureCallback<>(); + herder.taskConfigs(connector, cb); + return completeOrForwardRequest(cb, "/connectors/" + connector + "/tasks", "GET", headers, null, new TypeReference>() { + }, forward); + } + + @POST + @Path("/{connector}/tasks") + public void putTaskConfigs(final @PathParam("connector") String connector, + final @Context HttpHeaders headers, + final @QueryParam("forward") Boolean forward, + final byte[] requestBody) throws Throwable { + List> taskConfigs = new ObjectMapper().readValue(requestBody, TASK_CONFIGS_TYPE); + FutureCallback cb = new FutureCallback<>(); + herder.putTaskConfigs(connector, taskConfigs, cb, InternalRequestSignature.fromHeaders(requestBody, headers)); + completeOrForwardRequest(cb, "/connectors/" + connector + "/tasks", "POST", headers, taskConfigs, forward); + } + + @GET + @Path("/{connector}/tasks/{task}/status") + public ConnectorStateInfo.TaskState getTaskStatus(final @PathParam("connector") String connector, + final @Context HttpHeaders 
headers, + final @PathParam("task") Integer task) { + return herder.taskStatus(new ConnectorTaskId(connector, task)); + } + + @POST + @Path("/{connector}/tasks/{task}/restart") + public void restartTask(final @PathParam("connector") String connector, + final @PathParam("task") Integer task, + final @Context HttpHeaders headers, + final @QueryParam("forward") Boolean forward) throws Throwable { + FutureCallback cb = new FutureCallback<>(); + ConnectorTaskId taskId = new ConnectorTaskId(connector, task); + herder.restartTask(taskId, cb); + completeOrForwardRequest(cb, "/connectors/" + connector + "/tasks/" + task + "/restart", "POST", headers, null, forward); + } + + @DELETE + @Path("/{connector}") + public void destroyConnector(final @PathParam("connector") String connector, + final @Context HttpHeaders headers, + final @QueryParam("forward") Boolean forward) throws Throwable { + FutureCallback> cb = new FutureCallback<>(); + herder.deleteConnectorConfig(connector, cb); + completeOrForwardRequest(cb, "/connectors/" + connector, "DELETE", headers, null, forward); + } + + // Check whether the connector name from the url matches the one (if there is one) provided in the connectorConfig + // object. Throw BadRequestException on mismatch, otherwise put connectorName in config + private void checkAndPutConnectorConfigName(String connectorName, Map connectorConfig) { + String includedName = connectorConfig.get(ConnectorConfig.NAME_CONFIG); + if (includedName != null) { + if (!includedName.equals(connectorName)) + throw new BadRequestException("Connector name configuration (" + includedName + ") doesn't match connector name in the URL (" + connectorName + ")"); + } else { + connectorConfig.put(ConnectorConfig.NAME_CONFIG, connectorName); + } + } + + // Wait for a FutureCallback to complete. If it succeeds, return the parsed response. If it fails, try to forward the + // request to the leader. + private T completeOrForwardRequest(FutureCallback cb, + String path, + String method, + HttpHeaders headers, + Map queryParameters, + Object body, + TypeReference resultType, + Translator translator, + Boolean forward) throws Throwable { + try { + return cb.get(requestTimeoutMs, TimeUnit.MILLISECONDS); + } catch (ExecutionException e) { + Throwable cause = e.getCause(); + + if (cause instanceof RequestTargetException) { + if (forward == null || forward) { + // the only time we allow recursive forwarding is when no forward flag has + // been set, which should only be seen by the first worker to handle a user request. + // this gives two total hops to resolve the request before giving up. + boolean recursiveForward = forward == null; + RequestTargetException targetException = (RequestTargetException) cause; + String forwardedUrl = targetException.forwardUrl(); + if (forwardedUrl == null) { + // the target didn't know of the leader at this moment. 
+ throw new ConnectRestException(Response.Status.CONFLICT.getStatusCode(), + "Cannot complete request momentarily due to no known leader URL, " + + "likely because a rebalance was underway."); + } + UriBuilder uriBuilder = UriBuilder.fromUri(forwardedUrl) + .path(path) + .queryParam("forward", recursiveForward); + if (queryParameters != null) { + queryParameters.forEach((k, v) -> uriBuilder.queryParam(k, v)); + } + String forwardUrl = uriBuilder.build().toString(); + log.debug("Forwarding request {} {} {}", forwardUrl, method, body); + return translator.translate(RestClient.httpRequest(forwardUrl, method, headers, body, resultType, config)); + } else { + // we should find the right target for the query within two hops, so if + // we don't, it probably means that a rebalance has taken place. + throw new ConnectRestException(Response.Status.CONFLICT.getStatusCode(), + "Cannot complete request because of a conflicting operation (e.g. worker rebalance)"); + } + } else if (cause instanceof RebalanceNeededException) { + throw new ConnectRestException(Response.Status.CONFLICT.getStatusCode(), + "Cannot complete request momentarily due to stale configuration (typically caused by a concurrent config change)"); + } + + throw cause; + } catch (TimeoutException e) { + // This timeout is for the operation itself. None of the timeout error codes are relevant, so internal server + // error is the best option + throw new ConnectRestException(Response.Status.INTERNAL_SERVER_ERROR.getStatusCode(), "Request timed out"); + } catch (InterruptedException e) { + throw new ConnectRestException(Response.Status.INTERNAL_SERVER_ERROR.getStatusCode(), "Request interrupted"); + } + } + + private T completeOrForwardRequest(FutureCallback cb, String path, String method, HttpHeaders headers, Object body, + TypeReference resultType, Translator translator, Boolean forward) throws Throwable { + return completeOrForwardRequest(cb, path, method, headers, null, body, resultType, translator, forward); + } + + private T completeOrForwardRequest(FutureCallback cb, String path, String method, HttpHeaders headers, Object body, + TypeReference resultType, Boolean forward) throws Throwable { + return completeOrForwardRequest(cb, path, method, headers, body, resultType, new IdentityTranslator<>(), forward); + } + + private T completeOrForwardRequest(FutureCallback cb, String path, String method, HttpHeaders headers, + Object body, Boolean forward) throws Throwable { + return completeOrForwardRequest(cb, path, method, headers, body, null, new IdentityTranslator<>(), forward); + } + + private interface Translator { + T translate(RestClient.HttpResponse response); + } + + private static class IdentityTranslator implements Translator { + @Override + public T translate(RestClient.HttpResponse response) { + return response.body(); + } + } + + private static class CreatedConnectorInfoTranslator implements Translator, ConnectorInfo> { + @Override + public Herder.Created translate(RestClient.HttpResponse response) { + boolean created = response.status() == 201; + return new Herder.Created<>(created, response.body()); + } + } +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/rest/resources/LoggingResource.java b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/rest/resources/LoggingResource.java new file mode 100644 index 0000000..ce9ce14 --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/rest/resources/LoggingResource.java @@ -0,0 +1,206 @@ +/* + * Licensed to the 
Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.runtime.rest.resources; + +import org.apache.kafka.connect.errors.NotFoundException; +import org.apache.kafka.connect.runtime.rest.errors.BadRequestException; +import org.apache.log4j.Level; +import org.apache.log4j.LogManager; +import org.apache.log4j.Logger; + +import javax.ws.rs.Consumes; +import javax.ws.rs.GET; +import javax.ws.rs.PUT; +import javax.ws.rs.Path; +import javax.ws.rs.PathParam; +import javax.ws.rs.Produces; +import javax.ws.rs.core.MediaType; +import javax.ws.rs.core.Response; +import java.util.ArrayList; +import java.util.Collections; +import java.util.Enumeration; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Objects; +import java.util.TreeMap; + +/** + * A set of endpoints to adjust the log levels of runtime loggers. + */ +@Path("/admin/loggers") +@Produces(MediaType.APPLICATION_JSON) +@Consumes(MediaType.APPLICATION_JSON) +public class LoggingResource { + + /** + * Log4j uses "root" (case insensitive) as name of the root logger. + */ + private static final String ROOT_LOGGER_NAME = "root"; + + /** + * List the current loggers that have their levels explicitly set and their log levels. + * + * @return a list of current loggers and their levels. + */ + @GET + @Path("/") + public Response listLoggers() { + Map> loggers = new TreeMap<>(); + Enumeration enumeration = currentLoggers(); + Collections.list(enumeration) + .stream() + .filter(logger -> logger.getLevel() != null) + .forEach(logger -> loggers.put(logger.getName(), levelToMap(logger))); + + Logger root = rootLogger(); + if (root.getLevel() != null) { + loggers.put(ROOT_LOGGER_NAME, levelToMap(root)); + } + + return Response.ok(loggers).build(); + } + + /** + * Get the log level of a named logger. + * + * @param namedLogger name of a logger + * @return level of the logger, effective level if the level was not explicitly set. + */ + @GET + @Path("/{logger}") + public Response getLogger(final @PathParam("logger") String namedLogger) { + Objects.requireNonNull(namedLogger, "require non-null name"); + + Logger logger = null; + if (ROOT_LOGGER_NAME.equalsIgnoreCase(namedLogger)) { + logger = rootLogger(); + } else { + Enumeration en = currentLoggers(); + // search within existing loggers for the given name. + // using LogManger.getLogger() will create a logger if it doesn't exist + // (potential leak since these don't get cleaned up). 
+ while (en.hasMoreElements()) { + Logger l = en.nextElement(); + if (namedLogger.equals(l.getName())) { + logger = l; + break; + } + } + } + if (logger == null) { + throw new NotFoundException("Logger " + namedLogger + " not found."); + } else { + return Response.ok(effectiveLevelToMap(logger)).build(); + } + } + + + /** + * Adjust level of a named logger. if name corresponds to an ancestor, then the log level is applied to all child loggers. + * + * @param namedLogger name of the logger + * @param levelMap a map that is expected to contain one key 'level', and a value that is one of the log4j levels: + * DEBUG, ERROR, FATAL, INFO, TRACE, WARN + * @return names of loggers whose levels were modified + */ + @PUT + @Path("/{logger}") + public Response setLevel(final @PathParam("logger") String namedLogger, + final Map levelMap) { + String desiredLevelStr = levelMap.get("level"); + if (desiredLevelStr == null) { + throw new BadRequestException("Desired 'level' parameter was not specified in request."); + } + + Level level = Level.toLevel(desiredLevelStr.toUpperCase(Locale.ROOT), null); + if (level == null) { + throw new NotFoundException("invalid log level '" + desiredLevelStr + "'."); + } + + List childLoggers; + if (ROOT_LOGGER_NAME.equalsIgnoreCase(namedLogger)) { + childLoggers = Collections.list(currentLoggers()); + childLoggers.add(rootLogger()); + } else { + childLoggers = new ArrayList<>(); + Logger ancestorLogger = lookupLogger(namedLogger); + Enumeration en = currentLoggers(); + boolean present = false; + while (en.hasMoreElements()) { + Logger current = en.nextElement(); + if (current.getName().startsWith(namedLogger)) { + childLoggers.add(current); + } + if (namedLogger.equals(current.getName())) { + present = true; + } + } + if (!present) { + childLoggers.add(ancestorLogger); + } + } + + List modifiedLoggerNames = new ArrayList<>(); + for (Logger logger: childLoggers) { + logger.setLevel(level); + modifiedLoggerNames.add(logger.getName()); + } + Collections.sort(modifiedLoggerNames); + + return Response.ok(modifiedLoggerNames).build(); + } + + protected Logger lookupLogger(String namedLogger) { + return LogManager.getLogger(namedLogger); + } + + @SuppressWarnings("unchecked") + protected Enumeration currentLoggers() { + return LogManager.getCurrentLoggers(); + } + + protected Logger rootLogger() { + return LogManager.getRootLogger(); + } + + /** + * + * Map representation of a logger's effective log level. + * + * @param logger a non-null log4j logger + * @return a singleton map whose key is level and the value is the string representation of the logger's effective log level. + */ + private static Map effectiveLevelToMap(Logger logger) { + Level level = logger.getLevel(); + if (level == null) { + level = logger.getEffectiveLevel(); + } + return Collections.singletonMap("level", String.valueOf(level)); + } + + /** + * + * Map representation of a logger's log level. + * + * @param logger a non-null log4j logger + * @return a singleton map whose key is level and the value is the string representation of the logger's log level. 
+ */ + private static Map levelToMap(Logger logger) { + return Collections.singletonMap("level", String.valueOf(logger.getLevel())); + } +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/rest/resources/RootResource.java b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/rest/resources/RootResource.java new file mode 100644 index 0000000..9666bf1 --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/rest/resources/RootResource.java @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.runtime.rest.resources; + +import org.apache.kafka.connect.runtime.Herder; +import org.apache.kafka.connect.runtime.rest.entities.ServerInfo; + +import javax.ws.rs.GET; +import javax.ws.rs.Path; +import javax.ws.rs.Produces; +import javax.ws.rs.core.MediaType; + +@Path("/") +@Produces(MediaType.APPLICATION_JSON) +public class RootResource { + + private final Herder herder; + + public RootResource(Herder herder) { + this.herder = herder; + } + + @GET + @Path("/") + public ServerInfo serverInfo() { + return new ServerInfo(herder.kafkaClusterId()); + } +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/rest/util/SSLUtils.java b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/rest/util/SSLUtils.java new file mode 100644 index 0000000..bf22bb6 --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/rest/util/SSLUtils.java @@ -0,0 +1,171 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.connect.runtime.rest.util; + +import org.apache.kafka.common.config.SslConfigs; +import org.apache.kafka.common.config.internals.BrokerSecurityConfigs; +import org.apache.kafka.common.config.types.Password; +import org.apache.kafka.connect.runtime.WorkerConfig; +import org.eclipse.jetty.util.ssl.SslContextFactory; + +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.regex.Pattern; + +/** + * Helper class for setting up SSL for RestServer and RestClient + */ +public class SSLUtils { + + private static final Pattern COMMA_WITH_WHITESPACE = Pattern.compile("\\s*,\\s*"); + + + /** + * Configures SSL/TLS for HTTPS Jetty Server using configs with the given prefix + */ + public static SslContextFactory createServerSideSslContextFactory(WorkerConfig config, String prefix) { + Map sslConfigValues = config.valuesWithPrefixAllOrNothing(prefix); + + final SslContextFactory.Server ssl = new SslContextFactory.Server(); + + configureSslContextFactoryKeyStore(ssl, sslConfigValues); + configureSslContextFactoryTrustStore(ssl, sslConfigValues); + configureSslContextFactoryAlgorithms(ssl, sslConfigValues); + configureSslContextFactoryAuthentication(ssl, sslConfigValues); + + return ssl; + } + + /** + * Configures SSL/TLS for HTTPS Jetty Server + */ + public static SslContextFactory createServerSideSslContextFactory(WorkerConfig config) { + return createServerSideSslContextFactory(config, "listeners.https."); + } + + /** + * Configures SSL/TLS for HTTPS Jetty Client + */ + public static SslContextFactory createClientSideSslContextFactory(WorkerConfig config) { + Map sslConfigValues = config.valuesWithPrefixAllOrNothing("listeners.https."); + + final SslContextFactory.Client ssl = new SslContextFactory.Client(); + + configureSslContextFactoryKeyStore(ssl, sslConfigValues); + configureSslContextFactoryTrustStore(ssl, sslConfigValues); + configureSslContextFactoryAlgorithms(ssl, sslConfigValues); + configureSslContextFactoryEndpointIdentification(ssl, sslConfigValues); + + return ssl; + } + + /** + * Configures KeyStore related settings in SslContextFactory + */ + protected static void configureSslContextFactoryKeyStore(SslContextFactory ssl, Map sslConfigValues) { + ssl.setKeyStoreType((String) getOrDefault(sslConfigValues, SslConfigs.SSL_KEYSTORE_TYPE_CONFIG, SslConfigs.DEFAULT_SSL_KEYSTORE_TYPE)); + + String sslKeystoreLocation = (String) sslConfigValues.get(SslConfigs.SSL_KEYSTORE_LOCATION_CONFIG); + if (sslKeystoreLocation != null) + ssl.setKeyStorePath(sslKeystoreLocation); + + Password sslKeystorePassword = (Password) sslConfigValues.get(SslConfigs.SSL_KEYSTORE_PASSWORD_CONFIG); + if (sslKeystorePassword != null) + ssl.setKeyStorePassword(sslKeystorePassword.value()); + + Password sslKeyPassword = (Password) sslConfigValues.get(SslConfigs.SSL_KEY_PASSWORD_CONFIG); + if (sslKeyPassword != null) + ssl.setKeyManagerPassword(sslKeyPassword.value()); + } + + protected static Object getOrDefault(Map configMap, String key, Object defaultValue) { + if (configMap.containsKey(key)) + return configMap.get(key); + + return defaultValue; + } + + /** + * Configures TrustStore related settings in SslContextFactory + */ + protected static void configureSslContextFactoryTrustStore(SslContextFactory ssl, Map sslConfigValues) { + ssl.setTrustStoreType((String) getOrDefault(sslConfigValues, SslConfigs.SSL_TRUSTSTORE_TYPE_CONFIG, SslConfigs.DEFAULT_SSL_TRUSTSTORE_TYPE)); + + String sslTruststoreLocation = (String) 
sslConfigValues.get(SslConfigs.SSL_TRUSTSTORE_LOCATION_CONFIG); + if (sslTruststoreLocation != null) + ssl.setTrustStorePath(sslTruststoreLocation); + + Password sslTruststorePassword = (Password) sslConfigValues.get(SslConfigs.SSL_TRUSTSTORE_PASSWORD_CONFIG); + if (sslTruststorePassword != null) + ssl.setTrustStorePassword(sslTruststorePassword.value()); + } + + /** + * Configures Protocol, Algorithm and Provider related settings in SslContextFactory + */ + @SuppressWarnings("unchecked") + protected static void configureSslContextFactoryAlgorithms(SslContextFactory ssl, Map sslConfigValues) { + List sslEnabledProtocols = (List) getOrDefault(sslConfigValues, SslConfigs.SSL_ENABLED_PROTOCOLS_CONFIG, Arrays.asList(COMMA_WITH_WHITESPACE.split(SslConfigs.DEFAULT_SSL_ENABLED_PROTOCOLS))); + ssl.setIncludeProtocols(sslEnabledProtocols.toArray(new String[0])); + + String sslProvider = (String) sslConfigValues.get(SslConfigs.SSL_PROVIDER_CONFIG); + if (sslProvider != null) + ssl.setProvider(sslProvider); + + ssl.setProtocol((String) getOrDefault(sslConfigValues, SslConfigs.SSL_PROTOCOL_CONFIG, SslConfigs.DEFAULT_SSL_PROTOCOL)); + + List sslCipherSuites = (List) sslConfigValues.get(SslConfigs.SSL_CIPHER_SUITES_CONFIG); + if (sslCipherSuites != null) + ssl.setIncludeCipherSuites(sslCipherSuites.toArray(new String[0])); + + ssl.setKeyManagerFactoryAlgorithm((String) getOrDefault(sslConfigValues, SslConfigs.SSL_KEYMANAGER_ALGORITHM_CONFIG, SslConfigs.DEFAULT_SSL_KEYMANGER_ALGORITHM)); + + String sslSecureRandomImpl = (String) sslConfigValues.get(SslConfigs.SSL_SECURE_RANDOM_IMPLEMENTATION_CONFIG); + if (sslSecureRandomImpl != null) + ssl.setSecureRandomAlgorithm(sslSecureRandomImpl); + + ssl.setTrustManagerFactoryAlgorithm((String) getOrDefault(sslConfigValues, SslConfigs.SSL_TRUSTMANAGER_ALGORITHM_CONFIG, SslConfigs.DEFAULT_SSL_TRUSTMANAGER_ALGORITHM)); + } + + /** + * Configures Protocol, Algorithm and Provider related settings in SslContextFactory + */ + protected static void configureSslContextFactoryEndpointIdentification(SslContextFactory ssl, Map sslConfigValues) { + String sslEndpointIdentificationAlg = (String) sslConfigValues.get(SslConfigs.SSL_ENDPOINT_IDENTIFICATION_ALGORITHM_CONFIG); + if (sslEndpointIdentificationAlg != null) + ssl.setEndpointIdentificationAlgorithm(sslEndpointIdentificationAlg); + } + + /** + * Configures Authentication related settings in SslContextFactory + */ + protected static void configureSslContextFactoryAuthentication(SslContextFactory.Server ssl, Map sslConfigValues) { + String sslClientAuth = (String) getOrDefault(sslConfigValues, BrokerSecurityConfigs.SSL_CLIENT_AUTH_CONFIG, "none"); + switch (sslClientAuth) { + case "requested": + ssl.setWantClientAuth(true); + break; + case "required": + ssl.setNeedClientAuth(true); + break; + default: + ssl.setNeedClientAuth(false); + ssl.setWantClientAuth(false); + } + } +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/standalone/StandaloneConfig.java b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/standalone/StandaloneConfig.java new file mode 100644 index 0000000..f950edf --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/standalone/StandaloneConfig.java @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.runtime.standalone; + +import org.apache.kafka.common.config.ConfigDef; +import org.apache.kafka.connect.runtime.WorkerConfig; + +import java.util.Map; + +public class StandaloneConfig extends WorkerConfig { + private static final ConfigDef CONFIG; + + /** + * offset.storage.file.filename + */ + public static final String OFFSET_STORAGE_FILE_FILENAME_CONFIG = "offset.storage.file.filename"; + private static final String OFFSET_STORAGE_FILE_FILENAME_DOC = "File to store offset data in"; + + static { + CONFIG = baseConfigDef() + .define(OFFSET_STORAGE_FILE_FILENAME_CONFIG, + ConfigDef.Type.STRING, + ConfigDef.Importance.HIGH, + OFFSET_STORAGE_FILE_FILENAME_DOC); + } + + public StandaloneConfig(Map props) { + super(CONFIG, props); + } +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/standalone/StandaloneHerder.java b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/standalone/StandaloneHerder.java new file mode 100644 index 0000000..dac389b --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/runtime/standalone/StandaloneHerder.java @@ -0,0 +1,509 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.connect.runtime.standalone; + +import org.apache.kafka.connect.connector.policy.ConnectorClientConfigOverridePolicy; +import org.apache.kafka.connect.errors.AlreadyExistsException; +import org.apache.kafka.connect.errors.ConnectException; +import org.apache.kafka.connect.errors.NotFoundException; +import org.apache.kafka.connect.runtime.AbstractHerder; +import org.apache.kafka.connect.runtime.ConnectorConfig; +import org.apache.kafka.connect.runtime.HerderConnectorContext; +import org.apache.kafka.connect.runtime.HerderRequest; +import org.apache.kafka.connect.runtime.RestartPlan; +import org.apache.kafka.connect.runtime.RestartRequest; +import org.apache.kafka.connect.runtime.SessionKey; +import org.apache.kafka.connect.runtime.SinkConnectorConfig; +import org.apache.kafka.connect.runtime.SourceConnectorConfig; +import org.apache.kafka.connect.runtime.TargetState; +import org.apache.kafka.connect.runtime.Worker; +import org.apache.kafka.connect.runtime.distributed.ClusterConfigState; +import org.apache.kafka.connect.runtime.rest.InternalRequestSignature; +import org.apache.kafka.connect.runtime.rest.entities.ConfigInfos; +import org.apache.kafka.connect.runtime.rest.entities.ConnectorInfo; +import org.apache.kafka.connect.runtime.rest.entities.ConnectorStateInfo; +import org.apache.kafka.connect.runtime.rest.entities.TaskInfo; +import org.apache.kafka.connect.storage.ConfigBackingStore; +import org.apache.kafka.connect.storage.MemoryConfigBackingStore; +import org.apache.kafka.connect.storage.MemoryStatusBackingStore; +import org.apache.kafka.connect.storage.StatusBackingStore; +import org.apache.kafka.connect.util.Callback; +import org.apache.kafka.connect.util.ConnectorTaskId; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.concurrent.Executors; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.ScheduledFuture; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicLong; + + +/** + * Single process, in-memory "herder". Useful for a standalone Kafka Connect process. 
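+ *
+ * <p>A minimal usage sketch, assuming a {@code Worker}, Kafka cluster id, and client config
+ * override policy have already been constructed elsewhere (the variable names below are
+ * illustrative, not defined in this file):
+ * <pre>{@code
+ * StandaloneHerder herder = new StandaloneHerder(worker, kafkaClusterId, overridePolicy);
+ * herder.start();
+ * herder.putConnectorConfig("my-connector", connectorProps, false, callback);
+ * herder.stop();
+ * }</pre>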
+ */ +public class StandaloneHerder extends AbstractHerder { + private static final Logger log = LoggerFactory.getLogger(StandaloneHerder.class); + + private final AtomicLong requestSeqNum = new AtomicLong(); + private final ScheduledExecutorService requestExecutorService; + + private ClusterConfigState configState; + + public StandaloneHerder(Worker worker, String kafkaClusterId, + ConnectorClientConfigOverridePolicy connectorClientConfigOverridePolicy) { + this(worker, + worker.workerId(), + kafkaClusterId, + new MemoryStatusBackingStore(), + new MemoryConfigBackingStore(worker.configTransformer()), + connectorClientConfigOverridePolicy); + } + + // visible for testing + StandaloneHerder(Worker worker, + String workerId, + String kafkaClusterId, + StatusBackingStore statusBackingStore, + MemoryConfigBackingStore configBackingStore, + ConnectorClientConfigOverridePolicy connectorClientConfigOverridePolicy) { + super(worker, workerId, kafkaClusterId, statusBackingStore, configBackingStore, connectorClientConfigOverridePolicy); + this.configState = ClusterConfigState.EMPTY; + this.requestExecutorService = Executors.newSingleThreadScheduledExecutor(); + configBackingStore.setUpdateListener(new ConfigUpdateListener()); + } + + @Override + public synchronized void start() { + log.info("Herder starting"); + startServices(); + running = true; + log.info("Herder started"); + } + + @Override + public synchronized void stop() { + log.info("Herder stopping"); + requestExecutorService.shutdown(); + try { + if (!requestExecutorService.awaitTermination(30, TimeUnit.SECONDS)) + requestExecutorService.shutdownNow(); + } catch (InterruptedException e) { + // ignore + } + + // There's no coordination/hand-off to do here since this is all standalone. Instead, we + // should just clean up the stuff we normally would, i.e. cleanly checkpoint and shutdown all + // the tasks. 
+ for (String connName : connectors()) { + removeConnectorTasks(connName); + worker.stopAndAwaitConnector(connName); + } + stopServices(); + running = false; + log.info("Herder stopped"); + } + + @Override + public int generation() { + return 0; + } + + @Override + public synchronized void connectors(Callback> callback) { + callback.onCompletion(null, connectors()); + } + + @Override + public synchronized void connectorInfo(String connName, Callback callback) { + ConnectorInfo connectorInfo = connectorInfo(connName); + if (connectorInfo == null) { + callback.onCompletion(new NotFoundException("Connector " + connName + " not found"), null); + return; + } + callback.onCompletion(null, connectorInfo); + } + + private synchronized ConnectorInfo createConnectorInfo(String connector) { + if (!configState.contains(connector)) + return null; + Map config = configState.rawConnectorConfig(connector); + return new ConnectorInfo(connector, config, configState.tasks(connector), + connectorTypeForClass(config.get(ConnectorConfig.CONNECTOR_CLASS_CONFIG))); + } + + @Override + protected synchronized Map rawConfig(String connName) { + return configState.rawConnectorConfig(connName); + } + + @Override + public synchronized void deleteConnectorConfig(String connName, Callback> callback) { + try { + if (!configState.contains(connName)) { + // Deletion, must already exist + callback.onCompletion(new NotFoundException("Connector " + connName + " not found", null), null); + return; + } + + removeConnectorTasks(connName); + worker.stopAndAwaitConnector(connName); + configBackingStore.removeConnectorConfig(connName); + onDeletion(connName); + callback.onCompletion(null, new Created<>(false, null)); + } catch (ConnectException e) { + callback.onCompletion(e, null); + } + + } + + @Override + public synchronized void putConnectorConfig(String connName, + final Map config, + boolean allowReplace, + final Callback> callback) { + try { + validateConnectorConfig(config, (error, configInfos) -> { + if (error != null) { + callback.onCompletion(error, null); + return; + } + + requestExecutorService.submit( + () -> putConnectorConfig(connName, config, allowReplace, callback, configInfos) + ); + }); + } catch (Throwable t) { + callback.onCompletion(t, null); + } + } + + private synchronized void putConnectorConfig(String connName, + final Map config, + boolean allowReplace, + final Callback> callback, + ConfigInfos configInfos) { + try { + if (maybeAddConfigErrors(configInfos, callback)) { + return; + } + + final boolean created; + if (configState.contains(connName)) { + if (!allowReplace) { + callback.onCompletion(new AlreadyExistsException("Connector " + connName + " already exists"), null); + return; + } + worker.stopAndAwaitConnector(connName); + created = false; + } else { + created = true; + } + + configBackingStore.putConnectorConfig(connName, config); + + startConnector(connName, (error, result) -> { + if (error != null) { + callback.onCompletion(error, null); + return; + } + + requestExecutorService.submit(() -> { + updateConnectorTasks(connName); + callback.onCompletion(null, new Created<>(created, createConnectorInfo(connName))); + }); + }); + } catch (Throwable t) { + callback.onCompletion(t, null); + } + } + + @Override + public synchronized void requestTaskReconfiguration(String connName) { + if (!worker.connectorNames().contains(connName)) { + log.error("Task that requested reconfiguration does not exist: {}", connName); + return; + } + updateConnectorTasks(connName); + } + + @Override + public synchronized 
void taskConfigs(String connName, Callback> callback) { + if (!configState.contains(connName)) { + callback.onCompletion(new NotFoundException("Connector " + connName + " not found", null), null); + return; + } + + List result = new ArrayList<>(); + for (ConnectorTaskId taskId : configState.tasks(connName)) + result.add(new TaskInfo(taskId, configState.rawTaskConfig(taskId))); + callback.onCompletion(null, result); + } + + @Override + public void putTaskConfigs(String connName, List> configs, Callback callback, InternalRequestSignature requestSignature) { + throw new UnsupportedOperationException("Kafka Connect in standalone mode does not support externally setting task configurations."); + } + + @Override + public synchronized void restartTask(ConnectorTaskId taskId, Callback cb) { + if (!configState.contains(taskId.connector())) + cb.onCompletion(new NotFoundException("Connector " + taskId.connector() + " not found", null), null); + + Map taskConfigProps = configState.taskConfig(taskId); + if (taskConfigProps == null) + cb.onCompletion(new NotFoundException("Task " + taskId + " not found", null), null); + Map connConfigProps = configState.connectorConfig(taskId.connector()); + + TargetState targetState = configState.targetState(taskId.connector()); + worker.stopAndAwaitTask(taskId); + if (worker.startTask(taskId, configState, connConfigProps, taskConfigProps, this, targetState)) + cb.onCompletion(null, null); + else + cb.onCompletion(new ConnectException("Failed to start task: " + taskId), null); + } + + @Override + public synchronized void restartConnector(String connName, Callback cb) { + if (!configState.contains(connName)) + cb.onCompletion(new NotFoundException("Connector " + connName + " not found", null), null); + + worker.stopAndAwaitConnector(connName); + + startConnector(connName, (error, result) -> cb.onCompletion(error, null)); + } + + @Override + public synchronized HerderRequest restartConnector(long delayMs, final String connName, final Callback cb) { + ScheduledFuture future = requestExecutorService.schedule( + () -> restartConnector(connName, cb), delayMs, TimeUnit.MILLISECONDS); + + return new StandaloneHerderRequest(requestSeqNum.incrementAndGet(), future); + } + + @Override + public synchronized void restartConnectorAndTasks(RestartRequest request, Callback cb) { + // Ensure the connector exists + String connectorName = request.connectorName(); + if (!configState.contains(connectorName)) { + cb.onCompletion(new NotFoundException("Unknown connector: " + connectorName, null), null); + return; + } + + Optional maybePlan = buildRestartPlan(request); + if (!maybePlan.isPresent()) { + cb.onCompletion(new NotFoundException("Status for connector " + connectorName + " not found", null), null); + return; + } + RestartPlan plan = maybePlan.get(); + + // If requested, stop the connector and any tasks, marking each as restarting + log.info("Received {}", plan); + if (plan.shouldRestartConnector()) { + worker.stopAndAwaitConnector(connectorName); + onRestart(connectorName); + } + if (plan.shouldRestartTasks()) { + // Stop the tasks and mark as restarting + worker.stopAndAwaitTasks(plan.taskIdsToRestart()); + plan.taskIdsToRestart().forEach(this::onRestart); + } + + // Now restart the connector and tasks + if (plan.shouldRestartConnector()) { + log.debug("Restarting connector '{}'", connectorName); + startConnector(connectorName, (error, targetState) -> { + if (error == null) { + log.info("Connector '{}' restart successful", connectorName); + } else { + log.error("Connector '{}' 
restart failed", connectorName, error); + } + }); + } + if (plan.shouldRestartTasks()) { + log.debug("Restarting {} of {} tasks for {}", plan.restartTaskCount(), plan.totalTaskCount(), request); + createConnectorTasks(connectorName, plan.taskIdsToRestart()); + log.debug("Restarted {} of {} tasks for {} as requested", plan.restartTaskCount(), plan.totalTaskCount(), request); + } + // Complete the restart request + log.info("Completed {}", plan); + cb.onCompletion(null, plan.restartConnectorStateInfo()); + } + + private void startConnector(String connName, Callback onStart) { + Map connConfigs = configState.connectorConfig(connName); + TargetState targetState = configState.targetState(connName); + worker.startConnector(connName, connConfigs, new HerderConnectorContext(this, connName), this, targetState, onStart); + } + + private List> recomputeTaskConfigs(String connName) { + Map config = configState.connectorConfig(connName); + + ConnectorConfig connConfig = worker.isSinkConnector(connName) ? + new SinkConnectorConfig(plugins(), config) : + new SourceConnectorConfig(plugins(), config, worker.isTopicCreationEnabled()); + + return worker.connectorTaskConfigs(connName, connConfig); + } + + private void createConnectorTasks(String connName) { + List taskIds = configState.tasks(connName); + createConnectorTasks(connName, taskIds); + } + + private void createConnectorTasks(String connName, Collection taskIds) { + TargetState initialState = configState.targetState(connName); + Map connConfigs = configState.connectorConfig(connName); + for (ConnectorTaskId taskId : taskIds) { + Map taskConfigMap = configState.taskConfig(taskId); + worker.startTask(taskId, configState, connConfigs, taskConfigMap, this, initialState); + } + } + + private void removeConnectorTasks(String connName) { + Collection tasks = configState.tasks(connName); + if (!tasks.isEmpty()) { + worker.stopAndAwaitTasks(tasks); + configBackingStore.removeTaskConfigs(connName); + tasks.forEach(this::onDeletion); + } + } + + private void updateConnectorTasks(String connName) { + if (!worker.isRunning(connName)) { + log.info("Skipping update of connector {} since it is not running", connName); + return; + } + + List> newTaskConfigs = recomputeTaskConfigs(connName); + List> oldTaskConfigs = configState.allTaskConfigs(connName); + + if (!newTaskConfigs.equals(oldTaskConfigs)) { + removeConnectorTasks(connName); + List> rawTaskConfigs = reverseTransform(connName, configState, newTaskConfigs); + configBackingStore.putTaskConfigs(connName, rawTaskConfigs); + createConnectorTasks(connName); + } + } + + // This update listener assumes synchronous updates the ConfigBackingStore, which only works + // with the MemoryConfigBackingStore. This allows us to write a change (e.g. through + // ConfigBackingStore.putConnectorConfig()) and then immediately read it back from an updated + // snapshot. + // TODO: To get any real benefit from the backing store abstraction, we should move some of + // the handling into the callbacks in this listener. 
+ private class ConfigUpdateListener implements ConfigBackingStore.UpdateListener { + + @Override + public void onConnectorConfigRemove(String connector) { + synchronized (StandaloneHerder.this) { + configState = configBackingStore.snapshot(); + } + } + + @Override + public void onConnectorConfigUpdate(String connector) { + // TODO: move connector configuration update handling here to be consistent with + // the semantics of the config backing store + + synchronized (StandaloneHerder.this) { + configState = configBackingStore.snapshot(); + } + } + + @Override + public void onTaskConfigUpdate(Collection tasks) { + synchronized (StandaloneHerder.this) { + configState = configBackingStore.snapshot(); + } + } + + @Override + public void onConnectorTargetStateChange(String connector) { + synchronized (StandaloneHerder.this) { + configState = configBackingStore.snapshot(); + TargetState targetState = configState.targetState(connector); + worker.setTargetState(connector, targetState, (error, newState) -> { + if (error != null) { + log.error("Failed to transition connector {} to target state {}", connector, targetState, error); + return; + } + + if (newState == TargetState.STARTED) { + requestExecutorService.submit(() -> updateConnectorTasks(connector)); + } + }); + } + } + + @Override + public void onSessionKeyUpdate(SessionKey sessionKey) { + // no-op + } + + @Override + public void onRestartRequest(RestartRequest restartRequest) { + // no-op + } + } + + static class StandaloneHerderRequest implements HerderRequest { + private final long seq; + private final ScheduledFuture future; + + public StandaloneHerderRequest(long seq, ScheduledFuture future) { + this.seq = seq; + this.future = future; + } + + @Override + public void cancel() { + future.cancel(false); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (!(o instanceof StandaloneHerderRequest)) + return false; + StandaloneHerderRequest other = (StandaloneHerderRequest) o; + return seq == other.seq; + } + + @Override + public int hashCode() { + return Objects.hash(seq); + } + } + + @Override + public void tasksConfig(String connName, Callback>> callback) { + Map> tasksConfig = buildTasksConfig(connName); + if (tasksConfig.isEmpty()) { + callback.onCompletion(new NotFoundException("Connector " + connName + " not found"), tasksConfig); + return; + } + callback.onCompletion(null, tasksConfig); + } + +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/storage/CloseableOffsetStorageReader.java b/connect/runtime/src/main/java/org/apache/kafka/connect/storage/CloseableOffsetStorageReader.java new file mode 100644 index 0000000..b902739 --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/storage/CloseableOffsetStorageReader.java @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.storage; + +import java.io.Closeable; +import java.util.Collection; +import java.util.Map; +import java.util.concurrent.Future; + +public interface CloseableOffsetStorageReader extends Closeable, OffsetStorageReader { + + /** + * {@link Future#cancel(boolean) Cancel} all outstanding offset read requests, and throw an + * exception in all current and future calls to {@link #offsets(Collection)} and + * {@link #offset(Map)}. This is useful for unblocking task threads which need to shut down but + * are blocked on offset reads. + */ + void close(); +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/storage/ConfigBackingStore.java b/connect/runtime/src/main/java/org/apache/kafka/connect/storage/ConfigBackingStore.java new file mode 100644 index 0000000..826f934 --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/storage/ConfigBackingStore.java @@ -0,0 +1,146 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.storage; + +import org.apache.kafka.connect.runtime.RestartRequest; +import org.apache.kafka.connect.runtime.SessionKey; +import org.apache.kafka.connect.runtime.TargetState; +import org.apache.kafka.connect.runtime.distributed.ClusterConfigState; +import org.apache.kafka.connect.util.ConnectorTaskId; + +import java.util.Collection; +import java.util.List; +import java.util.Map; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + +public interface ConfigBackingStore { + + void start(); + + void stop(); + + /** + * Get a snapshot of the current configuration state including all connector and task + * configurations. + * @return the cluster config state + */ + ClusterConfigState snapshot(); + + /** + * Check if the store has configuration for a connector. + * @param connector name of the connector + * @return true if the backing store contains configuration for the connector + */ + boolean contains(String connector); + + /** + * Update the configuration for a connector. + * @param connector name of the connector + * @param properties the connector configuration + */ + void putConnectorConfig(String connector, Map properties); + + /** + * Remove configuration for a connector + * @param connector name of the connector + */ + void removeConnectorConfig(String connector); + + /** + * Update the task configurations for a connector. + * @param connector name of the connector + * @param configs the new task configs for the connector + */ + void putTaskConfigs(String connector, List> configs); + + /** + * Remove the task configs associated with a connector. 
+ * @param connector name of the connector + */ + void removeTaskConfigs(String connector); + + /** + * Refresh the backing store. This forces the store to ensure that it has the latest + * configs that have been written. + * @param timeout max time to wait for the refresh to complete + * @param unit unit of timeout + * @throws TimeoutException if the timeout expires before the refresh has completed + */ + void refresh(long timeout, TimeUnit unit) throws TimeoutException; + + /** + * Transition a connector to a new target state (e.g. paused). + * @param connector name of the connector + * @param state the state to transition to + */ + void putTargetState(String connector, TargetState state); + + void putSessionKey(SessionKey sessionKey); + + /** + * Request a restart of a connector and optionally its tasks. + * @param restartRequest the restart request details + */ + void putRestartRequest(RestartRequest restartRequest); + + /** + * Set an update listener to get notifications when there are config/target state + * changes. + * @param listener non-null listener + */ + void setUpdateListener(UpdateListener listener); + + interface UpdateListener { + /** + * Invoked when a connector configuration has been removed + * @param connector name of the connector + */ + void onConnectorConfigRemove(String connector); + + /** + * Invoked when a connector configuration has been updated. + * @param connector name of the connector + */ + void onConnectorConfigUpdate(String connector); + + /** + * Invoked when task configs are updated. + * @param tasks all the tasks whose configs have been updated + */ + void onTaskConfigUpdate(Collection tasks); + + /** + * Invoked when the user has set a new target state (e.g. paused) + * @param connector name of the connector + */ + void onConnectorTargetStateChange(String connector); + + /** + * Invoked when the leader has distributed a new session key + * @param sessionKey the {@link SessionKey session key} + */ + void onSessionKeyUpdate(SessionKey sessionKey); + + /** + * Invoked when a connector and possibly its tasks have been requested to be restarted. + * @param restartRequest the {@link RestartRequest restart request} + */ + void onRestartRequest(RestartRequest restartRequest); + } + +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/storage/FileOffsetBackingStore.java b/connect/runtime/src/main/java/org/apache/kafka/connect/storage/FileOffsetBackingStore.java new file mode 100644 index 0000000..8f828fb --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/storage/FileOffsetBackingStore.java @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.connect.storage; + +import org.apache.kafka.connect.errors.ConnectException; +import org.apache.kafka.connect.runtime.WorkerConfig; +import org.apache.kafka.connect.runtime.standalone.StandaloneConfig; +import org.apache.kafka.connect.util.SafeObjectInputStream; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.EOFException; +import java.io.File; +import java.io.IOException; +import java.io.ObjectOutputStream; +import java.nio.ByteBuffer; +import java.nio.file.Files; +import java.nio.file.NoSuchFileException; +import java.util.HashMap; +import java.util.Map; + +/** + * Implementation of OffsetBackingStore that saves data locally to a file. To ensure this behaves + * similarly to a real backing store, operations are executed asynchronously on a background thread. + */ +public class FileOffsetBackingStore extends MemoryOffsetBackingStore { + private static final Logger log = LoggerFactory.getLogger(FileOffsetBackingStore.class); + + private File file; + + public FileOffsetBackingStore() { + + } + + @Override + public void configure(WorkerConfig config) { + super.configure(config); + file = new File(config.getString(StandaloneConfig.OFFSET_STORAGE_FILE_FILENAME_CONFIG)); + } + + @Override + public synchronized void start() { + super.start(); + log.info("Starting FileOffsetBackingStore with file {}", file); + load(); + } + + @Override + public synchronized void stop() { + super.stop(); + // Nothing to do since this doesn't maintain any outstanding connections/data + log.info("Stopped FileOffsetBackingStore"); + } + + @SuppressWarnings("unchecked") + private void load() { + try (SafeObjectInputStream is = new SafeObjectInputStream(Files.newInputStream(file.toPath()))) { + Object obj = is.readObject(); + if (!(obj instanceof HashMap)) + throw new ConnectException("Expected HashMap but found " + obj.getClass()); + Map raw = (Map) obj; + data = new HashMap<>(); + for (Map.Entry mapEntry : raw.entrySet()) { + ByteBuffer key = (mapEntry.getKey() != null) ? ByteBuffer.wrap(mapEntry.getKey()) : null; + ByteBuffer value = (mapEntry.getValue() != null) ? ByteBuffer.wrap(mapEntry.getValue()) : null; + data.put(key, value); + } + } catch (NoSuchFileException | EOFException e) { + // NoSuchFileException: Ignore, may be new. + // EOFException: Ignore, this means the file was missing or corrupt + } catch (IOException | ClassNotFoundException e) { + throw new ConnectException(e); + } + } + + @Override + protected void save() { + try (ObjectOutputStream os = new ObjectOutputStream(Files.newOutputStream(file.toPath()))) { + Map raw = new HashMap<>(); + for (Map.Entry mapEntry : data.entrySet()) { + byte[] key = (mapEntry.getKey() != null) ? mapEntry.getKey().array() : null; + byte[] value = (mapEntry.getValue() != null) ? mapEntry.getValue().array() : null; + raw.put(key, value); + } + os.writeObject(raw); + } catch (IOException e) { + throw new ConnectException(e); + } + } +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/storage/KafkaConfigBackingStore.java b/connect/runtime/src/main/java/org/apache/kafka/connect/storage/KafkaConfigBackingStore.java new file mode 100644 index 0000000..e77bcfa --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/storage/KafkaConfigBackingStore.java @@ -0,0 +1,904 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.storage; + +import org.apache.kafka.clients.admin.NewTopic; +import org.apache.kafka.clients.consumer.ConsumerConfig; +import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.clients.producer.ProducerConfig; +import org.apache.kafka.common.config.ConfigException; +import org.apache.kafka.common.config.TopicConfig; +import org.apache.kafka.common.serialization.ByteArrayDeserializer; +import org.apache.kafka.common.serialization.ByteArraySerializer; +import org.apache.kafka.common.serialization.StringDeserializer; +import org.apache.kafka.common.serialization.StringSerializer; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.connect.data.Schema; +import org.apache.kafka.connect.data.SchemaAndValue; +import org.apache.kafka.connect.data.SchemaBuilder; +import org.apache.kafka.connect.data.Struct; +import org.apache.kafka.connect.errors.ConnectException; +import org.apache.kafka.connect.errors.DataException; +import org.apache.kafka.connect.runtime.RestartRequest; +import org.apache.kafka.connect.runtime.SessionKey; +import org.apache.kafka.connect.runtime.TargetState; +import org.apache.kafka.connect.runtime.WorkerConfig; +import org.apache.kafka.connect.runtime.WorkerConfigTransformer; +import org.apache.kafka.connect.runtime.distributed.ClusterConfigState; +import org.apache.kafka.connect.runtime.distributed.DistributedConfig; +import org.apache.kafka.connect.util.Callback; +import org.apache.kafka.connect.util.ConnectUtils; +import org.apache.kafka.connect.util.ConnectorTaskId; +import org.apache.kafka.connect.util.KafkaBasedLog; +import org.apache.kafka.connect.util.SharedTopicAdmin; +import org.apache.kafka.connect.util.TopicAdmin; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.crypto.spec.SecretKeySpec; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Base64; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.TreeSet; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import java.util.function.Supplier; + +/** + *

            + * Provides persistent storage of Kafka Connect connector configurations in a Kafka topic. + *

            + *

            + * This class manages both connector and task configurations. It tracks three types of configuration entries: + *

+ * 1. Connector config: map of string -> string configurations passed to the Connector class, with support for + * expanding this format if necessary. (Kafka key: connector-[connector-id]). + * These configs are *not* ephemeral. They represent the source of truth. If the entire Connect + * cluster goes down, this is all that is really needed to recover. + * 2. Task configs: map of string -> string configurations passed to the Task class, with support for expanding + * this format if necessary. (Kafka key: task-[connector-id]-[task-id]). + * These configs are ephemeral; they are stored here to a) disseminate them to all workers while + * ensuring agreement and b) to allow faster cluster/worker recovery since the common case + * of recovery (restoring a connector) will simply result in the same configuration as before + * the failure. + * 3. Task commit "configs": records indicating that previous task config entries should be committed and all task + * configs for a connector can be applied. (Kafka key: commit-[connector-id]). + * This config has two effects. First, it records the number of tasks the connector is currently + * running (and can therefore increase/decrease parallelism). Second, because each task config + * is stored separately but they need to be applied together to ensure each partition is assigned + * to a single task, this record also indicates that task configs for the specified connector + * can be "applied" or "committed". + *
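+ * As a purely illustrative example (the connector name "foo" and its two tasks are hypothetical), a complete
+ * (re)configuration of a connector produces records whose keys and logical values look like:
+ *   connector-foo -> { "properties": { ...connector config... } }
+ *   task-foo-0    -> { "properties": { ...task 0 config... } }
+ *   task-foo-1    -> { "properties": { ...task 1 config... } }
+ *   commit-foo    -> { "tasks": 2 }
+ *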

            + *

+ * This configuration is expected to be stored in a *single partition*, *compacted* topic. Using a single partition + * ensures we can enforce ordering on messages, allowing Kafka to be used as a write-ahead log. Compaction allows + * us to clean up outdated configurations over time. However, this combination has some important implications for + * the implementation of this class and the configuration state that it may expose. + *
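+ * Concretely, setupAndCreateKafkaBasedLog below defines the internal topic roughly as
+ *   TopicAdmin.defineTopic(topic).compacted().partitions(1)
+ *       .replicationFactor(config.getShort(DistributedConfig.CONFIG_STORAGE_REPLICATION_FACTOR_CONFIG)).build();
+ * and start() fails with a ConfigException if the existing config topic has more than one partition.
+ *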

            + *

+ * Connector configurations are independent of all other configs, so they are handled easily. Writing a single record + * is already atomic, so these can be applied as soon as they are read. One connector's config does not affect any + * others, and they do not need to coordinate with the connector's task configuration at all. + *

            + *

+ * The most obvious implication for task configs is the need for the commit messages. Because Kafka does not + * currently have multi-record transactions or support atomic batch record writes, task commit messages are required + * to ensure that readers do not end up using inconsistent configs. For example, consider if a connector wrote configs + * for its tasks, then was reconfigured and only managed to write updated configs for half its tasks. If task configs + * were applied immediately you could be using half the old configs and half the new configs. In that condition, some + * partitions may be double-assigned because the old config and new config may use completely different assignments. + * Therefore, when reading the log, we must buffer config updates for a connector's tasks and only apply them atomically + * once a commit message has been read. + *
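+ * In this implementation that buffering is the deferredTaskUpdates map below: task records are staged per
+ * connector as they are read, and are only copied into taskConfigs once the matching commit-[connector-id]
+ * record is seen (see the COMMIT_TASKS_PREFIX branch of ConsumeCallback).
+ *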

            + *

            + * However, there are also further challenges. This simple buffering approach would work fine as long as the entire log was + * always available, but we would like to be able to enable compaction so our configuration topic does not grow + * indefinitely. Compaction may break a normal log because old entries will suddenly go missing. A new worker reading + * from the beginning of the log in order to build up the full current configuration will see task commits, but some + * records required for those commits will have been removed because the same keys have subsequently been rewritten. + * For example, if you have a sequence of record keys [connector-foo-config, task-foo-1-config, task-foo-2-config, + * commit-foo (2 tasks), task-foo-1-config, commit-foo (1 task)], we can end up with a compacted log containing + * [connector-foo-config, task-foo-2-config, commit-foo (2 tasks), task-foo-1-config, commit-foo (1 task)]. When read + * back, the first commit will see an invalid state because the first task-foo-1-config has been cleaned up. + *

            + *

            + * Compaction can further complicate things if writing new task configs fails mid-write. Consider a similar scenario + * as the previous one, but in this case both the first and second update will write 2 task configs. However, the + * second write fails half of the way through: + * [connector-foo-config, task-foo-1-config, task-foo-2-config, commit-foo (2 tasks), task-foo-1-config]. Now compaction + * occurs and we're left with + * [connector-foo-config, task-foo-2-config, commit-foo (2 tasks), task-foo-1-config]. At the first commit, we don't + * have a complete set of configs. And because of the failure, there is no second commit. We are left in an inconsistent + * state with no obvious way to resolve the issue -- we can try to keep on reading, but the failed node may never + * recover and write the updated config. Meanwhile, other workers may have seen the entire log; they will see the second + * task-foo-1-config waiting to be applied, but will otherwise think everything is ok -- they have a valid set of task + * configs for connector "foo". + *

            + *

            + * Because we can encounter these inconsistencies and addressing them requires support from the rest of the system + * (resolving the task configuration inconsistencies requires support from the connector instance to regenerate updated + * configs), this class exposes not only the current set of configs, but also which connectors have inconsistent data. + * This allows users of this class (i.e., Herder implementations) to take action to resolve any inconsistencies. These + * inconsistencies should be rare (as described above, due to compaction combined with leader failures in the middle + * of updating task configurations). + *

            + *

            + * Note that the expectation is that this config storage system has only a single writer at a time. + * The caller (Herder) must ensure this is the case. In distributed mode this will require forwarding config change + * requests to the leader in the cluster (i.e. the worker group coordinated by the Kafka broker). + *

            + *

            + * Since processing of the config log occurs in a background thread, callers must take care when using accessors. + * To simplify handling this correctly, this class only exposes a mechanism to snapshot the current state of the cluster. + * Updates may continue to be applied (and callbacks invoked) in the background. Callers must take care that they are + * using a consistent snapshot and only update when it is safe. In particular, if task configs are updated which require + * synchronization across workers to commit offsets and update the configuration, callbacks and updates during the + * rebalance must be deferred. + *
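+ * A minimal usage sketch from a caller's perspective (the ClusterConfigState accessor names below are assumed
+ * for illustration and are not defined in this class):
+ *   ClusterConfigState state = store.snapshot();              // safe to read from any thread
+ *   Map<String, String> fooConfig = state.connectorConfig("foo");
+ *   if (state.inconsistentConnectors().contains("foo")) {
+ *       // e.g. ask the herder to have the connector regenerate and re-commit its task configs
+ *   }
+ *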

            + */ +public class KafkaConfigBackingStore implements ConfigBackingStore { + private static final Logger log = LoggerFactory.getLogger(KafkaConfigBackingStore.class); + + public static final String TARGET_STATE_PREFIX = "target-state-"; + + public static String TARGET_STATE_KEY(String connectorName) { + return TARGET_STATE_PREFIX + connectorName; + } + + public static final String CONNECTOR_PREFIX = "connector-"; + + public static String CONNECTOR_KEY(String connectorName) { + return CONNECTOR_PREFIX + connectorName; + } + + public static final String TASK_PREFIX = "task-"; + + public static String TASK_KEY(ConnectorTaskId taskId) { + return TASK_PREFIX + taskId.connector() + "-" + taskId.task(); + } + + public static final String COMMIT_TASKS_PREFIX = "commit-"; + + public static String COMMIT_TASKS_KEY(String connectorName) { + return COMMIT_TASKS_PREFIX + connectorName; + } + + public static final String SESSION_KEY_KEY = "session-key"; + + // Note that while using real serialization for values as we have here, but ad hoc string serialization for keys, + // isn't ideal, we use this approach because it avoids any potential problems with schema evolution or + // converter/serializer changes causing keys to change. We need to absolutely ensure that the keys remain precisely + // the same. + public static final Schema CONNECTOR_CONFIGURATION_V0 = SchemaBuilder.struct() + .field("properties", SchemaBuilder.map(Schema.STRING_SCHEMA, Schema.OPTIONAL_STRING_SCHEMA).build()) + .build(); + public static final Schema TASK_CONFIGURATION_V0 = CONNECTOR_CONFIGURATION_V0; + public static final Schema CONNECTOR_TASKS_COMMIT_V0 = SchemaBuilder.struct() + .field("tasks", Schema.INT32_SCHEMA) + .build(); + public static final Schema TARGET_STATE_V0 = SchemaBuilder.struct() + .field("state", Schema.STRING_SCHEMA) + .build(); + // The key is logically a byte array, but we can't use the JSON converter to (de-)serialize that without a schema. + // So instead, we base 64-encode it before serializing and decode it after deserializing. + public static final Schema SESSION_KEY_V0 = SchemaBuilder.struct() + .field("key", Schema.STRING_SCHEMA) + .field("algorithm", Schema.STRING_SCHEMA) + .field("creation-timestamp", Schema.INT64_SCHEMA) + .build(); + + public static final String RESTART_PREFIX = "restart-connector-"; + + public static String RESTART_KEY(String connectorName) { + return RESTART_PREFIX + connectorName; + } + + public static final boolean ONLY_FAILED_DEFAULT = false; + public static final boolean INCLUDE_TASKS_DEFAULT = false; + public static final String ONLY_FAILED_FIELD_NAME = "only-failed"; + public static final String INCLUDE_TASKS_FIELD_NAME = "include-tasks"; + public static final Schema RESTART_REQUEST_V0 = SchemaBuilder.struct() + .field(INCLUDE_TASKS_FIELD_NAME, Schema.BOOLEAN_SCHEMA) + .field(ONLY_FAILED_FIELD_NAME, Schema.BOOLEAN_SCHEMA) + .build(); + + private static final long READ_TO_END_TIMEOUT_MS = 30000; + + private final Object lock; + private final Converter converter; + private volatile boolean started; + // Although updateListener is not final, it's guaranteed to be visible to any thread after its + // initialization as long as we always read the volatile variable "started" before we access the listener. + private UpdateListener updateListener; + + private final String topic; + // Data is passed to the log already serialized. 
We use a converter to handle translating to/from generic Connect + // format to serialized form + private final KafkaBasedLog configLog; + // Connector -> # of tasks + private final Map connectorTaskCounts = new HashMap<>(); + // Connector and task configs: name or id -> config map + private final Map> connectorConfigs = new HashMap<>(); + private final Map> taskConfigs = new HashMap<>(); + private final Supplier topicAdminSupplier; + private SharedTopicAdmin ownTopicAdmin; + + // Set of connectors where we saw a task commit with an incomplete set of task config updates, indicating the data + // is in an inconsistent state and we cannot safely use them until they have been refreshed. + private final Set inconsistent = new HashSet<>(); + // The most recently read offset. This does not take into account deferred task updates/commits, so we may have + // outstanding data to be applied. + private volatile long offset; + // The most recently read session key, to use for validating internal REST requests. + private volatile SessionKey sessionKey; + + // Connector -> Map[ConnectorTaskId -> Configs] + private final Map>> deferredTaskUpdates = new HashMap<>(); + + private final Map connectorTargetStates = new HashMap<>(); + + private final WorkerConfigTransformer configTransformer; + + @Deprecated + public KafkaConfigBackingStore(Converter converter, WorkerConfig config, WorkerConfigTransformer configTransformer) { + this(converter, config, configTransformer, null); + } + + public KafkaConfigBackingStore(Converter converter, WorkerConfig config, WorkerConfigTransformer configTransformer, Supplier adminSupplier) { + this.lock = new Object(); + this.started = false; + this.converter = converter; + this.offset = -1; + this.topicAdminSupplier = adminSupplier; + + this.topic = config.getString(DistributedConfig.CONFIG_TOPIC_CONFIG); + if (this.topic == null || this.topic.trim().length() == 0) + throw new ConfigException("Must specify topic for connector configuration."); + + configLog = setupAndCreateKafkaBasedLog(this.topic, config); + this.configTransformer = configTransformer; + } + + @Override + public void setUpdateListener(UpdateListener listener) { + this.updateListener = listener; + } + + @Override + public void start() { + log.info("Starting KafkaConfigBackingStore"); + // Before startup, callbacks are *not* invoked. You can grab a snapshot after starting -- just take care that + // updates can continue to occur in the background + configLog.start(); + + int partitionCount = configLog.partitionCount(); + if (partitionCount > 1) { + String msg = String.format("Topic '%s' supplied via the '%s' property is required " + + "to have a single partition in order to guarantee consistency of " + + "connector configurations, but found %d partitions.", + topic, DistributedConfig.CONFIG_TOPIC_CONFIG, partitionCount); + throw new ConfigException(msg); + } + + started = true; + log.info("Started KafkaConfigBackingStore"); + } + + @Override + public void stop() { + log.info("Closing KafkaConfigBackingStore"); + try { + configLog.stop(); + } finally { + if (ownTopicAdmin != null) { + ownTopicAdmin.close(); + } + } + log.info("Closed KafkaConfigBackingStore"); + } + + /** + * Get a snapshot of the current state of the cluster. 
+ */ + @Override + public ClusterConfigState snapshot() { + synchronized (lock) { + // Only a shallow copy is performed here; in order to avoid accidentally corrupting the worker's view + // of the config topic, any nested structures should be copied before making modifications + return new ClusterConfigState( + offset, + sessionKey, + new HashMap<>(connectorTaskCounts), + new HashMap<>(connectorConfigs), + new HashMap<>(connectorTargetStates), + new HashMap<>(taskConfigs), + new HashSet<>(inconsistent), + configTransformer + ); + } + } + + @Override + public boolean contains(String connector) { + synchronized (lock) { + return connectorConfigs.containsKey(connector); + } + } + + /** + * Write this connector configuration to persistent storage and wait until it has been acknowledged and read back by + * tailing the Kafka log with a consumer. + * + * @param connector name of the connector to write data for + * @param properties the configuration to write + */ + @Override + public void putConnectorConfig(String connector, Map properties) { + log.debug("Writing connector configuration for connector '{}'", connector); + Struct connectConfig = new Struct(CONNECTOR_CONFIGURATION_V0); + connectConfig.put("properties", properties); + byte[] serializedConfig = converter.fromConnectData(topic, CONNECTOR_CONFIGURATION_V0, connectConfig); + updateConnectorConfig(connector, serializedConfig); + } + + /** + * Remove configuration for a given connector. + * @param connector name of the connector to remove + */ + @Override + public void removeConnectorConfig(String connector) { + log.debug("Removing connector configuration for connector '{}'", connector); + try { + configLog.send(CONNECTOR_KEY(connector), null); + configLog.send(TARGET_STATE_KEY(connector), null); + configLog.readToEnd().get(READ_TO_END_TIMEOUT_MS, TimeUnit.MILLISECONDS); + } catch (InterruptedException | ExecutionException | TimeoutException e) { + log.error("Failed to remove connector configuration from Kafka: ", e); + throw new ConnectException("Error removing connector configuration from Kafka", e); + } + } + + @Override + public void removeTaskConfigs(String connector) { + throw new UnsupportedOperationException("Removal of tasks is not currently supported"); + } + + private void updateConnectorConfig(String connector, byte[] serializedConfig) { + try { + configLog.send(CONNECTOR_KEY(connector), serializedConfig); + configLog.readToEnd().get(READ_TO_END_TIMEOUT_MS, TimeUnit.MILLISECONDS); + } catch (InterruptedException | ExecutionException | TimeoutException e) { + log.error("Failed to write connector configuration to Kafka: ", e); + throw new ConnectException("Error writing connector configuration to Kafka", e); + } + } + + /** + * Write these task configurations and associated commit messages, unless an inconsistency is found that indicates + * that we would be leaving one of the referenced connectors with an inconsistent state. + * + * @param connector the connector to write task configuration + * @param configs list of task configurations for the connector + * @throws ConnectException if the task configurations do not resolve inconsistencies found in the existing root + * and task configurations. + */ + @Override + public void putTaskConfigs(String connector, List> configs) { + // Make sure we're at the end of the log. We should be the only writer, but we want to make sure we don't have + // any outstanding lagging data to consume. 
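+        // In outline, the write protocol below is: read to the end of the log, send one task-[connector-id]-[task-id]
+        // record per task, read to the end again so all task configs are known to have been written, send the single
+        // commit-[connector-id] record carrying the task count, and finally read to the end once more so this
+        // worker's own snapshot reflects the new configs before returning.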
+ try { + configLog.readToEnd().get(READ_TO_END_TIMEOUT_MS, TimeUnit.MILLISECONDS); + } catch (InterruptedException | ExecutionException | TimeoutException e) { + log.error("Failed to write root configuration to Kafka: ", e); + throw new ConnectException("Error writing root configuration to Kafka", e); + } + + int taskCount = configs.size(); + + // Start sending all the individual updates + int index = 0; + for (Map taskConfig: configs) { + Struct connectConfig = new Struct(TASK_CONFIGURATION_V0); + connectConfig.put("properties", taskConfig); + byte[] serializedConfig = converter.fromConnectData(topic, TASK_CONFIGURATION_V0, connectConfig); + log.debug("Writing configuration for connector '{}' task {}", connector, index); + ConnectorTaskId connectorTaskId = new ConnectorTaskId(connector, index); + configLog.send(TASK_KEY(connectorTaskId), serializedConfig); + index++; + } + + // Finally, send the commit to update the number of tasks and apply the new configs, then wait until we read to + // the end of the log + try { + // Read to end to ensure all the task configs have been written + if (taskCount > 0) { + configLog.readToEnd().get(READ_TO_END_TIMEOUT_MS, TimeUnit.MILLISECONDS); + } + // Write the commit message + Struct connectConfig = new Struct(CONNECTOR_TASKS_COMMIT_V0); + connectConfig.put("tasks", taskCount); + byte[] serializedConfig = converter.fromConnectData(topic, CONNECTOR_TASKS_COMMIT_V0, connectConfig); + log.debug("Writing commit for connector '{}' with {} tasks.", connector, taskCount); + configLog.send(COMMIT_TASKS_KEY(connector), serializedConfig); + + // Read to end to ensure all the commit messages have been written + configLog.readToEnd().get(READ_TO_END_TIMEOUT_MS, TimeUnit.MILLISECONDS); + } catch (InterruptedException | ExecutionException | TimeoutException e) { + log.error("Failed to write root configuration to Kafka: ", e); + throw new ConnectException("Error writing root configuration to Kafka", e); + } + } + + @Override + public void refresh(long timeout, TimeUnit unit) throws TimeoutException { + try { + configLog.readToEnd().get(timeout, unit); + } catch (InterruptedException | ExecutionException e) { + throw new ConnectException("Error trying to read to end of config log", e); + } + } + + @Override + public void putTargetState(String connector, TargetState state) { + Struct connectTargetState = new Struct(TARGET_STATE_V0); + connectTargetState.put("state", state.name()); + byte[] serializedTargetState = converter.fromConnectData(topic, TARGET_STATE_V0, connectTargetState); + log.debug("Writing target state {} for connector {}", state, connector); + configLog.send(TARGET_STATE_KEY(connector), serializedTargetState); + } + + @Override + public void putSessionKey(SessionKey sessionKey) { + log.debug("Distributing new session key"); + Struct sessionKeyStruct = new Struct(SESSION_KEY_V0); + sessionKeyStruct.put("key", Base64.getEncoder().encodeToString(sessionKey.key().getEncoded())); + sessionKeyStruct.put("algorithm", sessionKey.key().getAlgorithm()); + sessionKeyStruct.put("creation-timestamp", sessionKey.creationTimestamp()); + byte[] serializedSessionKey = converter.fromConnectData(topic, SESSION_KEY_V0, sessionKeyStruct); + try { + configLog.send(SESSION_KEY_KEY, serializedSessionKey); + configLog.readToEnd().get(READ_TO_END_TIMEOUT_MS, TimeUnit.MILLISECONDS); + } catch (InterruptedException | ExecutionException | TimeoutException e) { + log.error("Failed to write session key to Kafka: ", e); + throw new ConnectException("Error writing session key to 
Kafka", e); + } + } + + @Override + public void putRestartRequest(RestartRequest restartRequest) { + log.debug("Writing {} to Kafka", restartRequest); + String key = RESTART_KEY(restartRequest.connectorName()); + Struct value = new Struct(RESTART_REQUEST_V0); + value.put(INCLUDE_TASKS_FIELD_NAME, restartRequest.includeTasks()); + value.put(ONLY_FAILED_FIELD_NAME, restartRequest.onlyFailed()); + byte[] serializedValue = converter.fromConnectData(topic, value.schema(), value); + try { + configLog.send(key, serializedValue); + configLog.readToEnd().get(READ_TO_END_TIMEOUT_MS, TimeUnit.MILLISECONDS); + } catch (InterruptedException | ExecutionException | TimeoutException e) { + log.error("Failed to write {} to Kafka: ", restartRequest, e); + throw new ConnectException("Error writing " + restartRequest + " to Kafka", e); + } + } + + // package private for testing + KafkaBasedLog setupAndCreateKafkaBasedLog(String topic, final WorkerConfig config) { + String clusterId = ConnectUtils.lookupKafkaClusterId(config); + Map originals = config.originals(); + Map producerProps = new HashMap<>(originals); + producerProps.put(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG, StringSerializer.class.getName()); + producerProps.put(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG, ByteArraySerializer.class.getName()); + producerProps.put(ProducerConfig.DELIVERY_TIMEOUT_MS_CONFIG, Integer.MAX_VALUE); + ConnectUtils.addMetricsContextProperties(producerProps, config, clusterId); + + Map consumerProps = new HashMap<>(originals); + consumerProps.put(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, StringDeserializer.class.getName()); + consumerProps.put(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, ByteArrayDeserializer.class.getName()); + ConnectUtils.addMetricsContextProperties(consumerProps, config, clusterId); + + Map adminProps = new HashMap<>(originals); + ConnectUtils.addMetricsContextProperties(adminProps, config, clusterId); + Supplier adminSupplier; + if (topicAdminSupplier != null) { + adminSupplier = topicAdminSupplier; + } else { + // Create our own topic admin supplier that we'll close when we're stopped + ownTopicAdmin = new SharedTopicAdmin(adminProps); + adminSupplier = ownTopicAdmin; + } + Map topicSettings = config instanceof DistributedConfig + ? 
((DistributedConfig) config).configStorageTopicSettings() + : Collections.emptyMap(); + NewTopic topicDescription = TopicAdmin.defineTopic(topic) + .config(topicSettings) // first so that we override user-supplied settings as needed + .compacted() + .partitions(1) + .replicationFactor(config.getShort(DistributedConfig.CONFIG_STORAGE_REPLICATION_FACTOR_CONFIG)) + .build(); + + return createKafkaBasedLog(topic, producerProps, consumerProps, new ConsumeCallback(), topicDescription, adminSupplier); + } + + private KafkaBasedLog createKafkaBasedLog(String topic, Map producerProps, + Map consumerProps, + Callback> consumedCallback, + final NewTopic topicDescription, Supplier adminSupplier) { + java.util.function.Consumer createTopics = admin -> { + log.debug("Creating admin client to manage Connect internal config topic"); + // Create the topic if it doesn't exist + Set newTopics = admin.createTopics(topicDescription); + if (!newTopics.contains(topic)) { + // It already existed, so check that the topic cleanup policy is compact only and not delete + log.debug("Using admin client to check cleanup policy of '{}' topic is '{}'", topic, TopicConfig.CLEANUP_POLICY_COMPACT); + admin.verifyTopicCleanupPolicyOnlyCompact(topic, + DistributedConfig.CONFIG_TOPIC_CONFIG, "connector configurations"); + } + }; + return new KafkaBasedLog<>(topic, producerProps, consumerProps, adminSupplier, consumedCallback, Time.SYSTEM, createTopics); + } + + @SuppressWarnings("unchecked") + private class ConsumeCallback implements Callback> { + @Override + public void onCompletion(Throwable error, ConsumerRecord record) { + if (error != null) { + log.error("Unexpected in consumer callback for KafkaConfigBackingStore: ", error); + return; + } + + final SchemaAndValue value; + try { + value = converter.toConnectData(topic, record.value()); + } catch (DataException e) { + log.error("Failed to convert config data to Kafka Connect format: ", e); + return; + } + // Make the recorded offset match the API used for positions in the consumer -- return the offset of the + // *next record*, not the last one consumed. + offset = record.offset() + 1; + + if (record.key().startsWith(TARGET_STATE_PREFIX)) { + String connectorName = record.key().substring(TARGET_STATE_PREFIX.length()); + boolean removed = false; + synchronized (lock) { + if (value.value() == null) { + // When connector configs are removed, we also write tombstones for the target state. + log.debug("Removed target state for connector {} due to null value in topic.", connectorName); + connectorTargetStates.remove(connectorName); + removed = true; + + // If for some reason we still have configs for the connector, add back the default + // STARTED state to ensure each connector always has a valid target state. + if (connectorConfigs.containsKey(connectorName)) + connectorTargetStates.put(connectorName, TargetState.STARTED); + } else { + if (!(value.value() instanceof Map)) { + log.error("Found target state ({}) in wrong format: {}", record.key(), value.value().getClass()); + return; + } + Object targetState = ((Map) value.value()).get("state"); + if (!(targetState instanceof String)) { + log.error("Invalid data for target state for connector '{}': 'state' field should be a Map but is {}", + connectorName, targetState == null ? 
null : targetState.getClass()); + return; + } + + try { + TargetState state = TargetState.valueOf((String) targetState); + log.debug("Setting target state for connector '{}' to {}", connectorName, targetState); + connectorTargetStates.put(connectorName, state); + } catch (IllegalArgumentException e) { + log.error("Invalid target state for connector '{}': {}", connectorName, targetState); + return; + } + } + } + + // Note that we do not notify the update listener if the target state has been removed. + // Instead we depend on the removal callback of the connector config itself to notify the worker. + if (started && !removed) + updateListener.onConnectorTargetStateChange(connectorName); + + } else if (record.key().startsWith(CONNECTOR_PREFIX)) { + String connectorName = record.key().substring(CONNECTOR_PREFIX.length()); + boolean removed = false; + synchronized (lock) { + if (value.value() == null) { + // Connector deletion will be written as a null value + log.info("Successfully processed removal of connector '{}'", connectorName); + connectorConfigs.remove(connectorName); + connectorTaskCounts.remove(connectorName); + taskConfigs.keySet().removeIf(taskId -> taskId.connector().equals(connectorName)); + removed = true; + } else { + // Connector configs can be applied and callbacks invoked immediately + if (!(value.value() instanceof Map)) { + log.error("Found configuration for connector '{}' in wrong format: {}", record.key(), value.value().getClass()); + return; + } + Object newConnectorConfig = ((Map) value.value()).get("properties"); + if (!(newConnectorConfig instanceof Map)) { + log.error("Invalid data for config for connector '{}': 'properties' field should be a Map but is {}", + connectorName, newConnectorConfig == null ? null : newConnectorConfig.getClass()); + return; + } + log.debug("Updating configuration for connector '{}'", connectorName); + connectorConfigs.put(connectorName, (Map) newConnectorConfig); + + // Set the initial state of the connector to STARTED, which ensures that any connectors + // which were created with 0.9 Connect will be initialized in the STARTED state. 
+ if (!connectorTargetStates.containsKey(connectorName)) + connectorTargetStates.put(connectorName, TargetState.STARTED); + } + } + if (started) { + if (removed) + updateListener.onConnectorConfigRemove(connectorName); + else + updateListener.onConnectorConfigUpdate(connectorName); + } + } else if (record.key().startsWith(TASK_PREFIX)) { + synchronized (lock) { + ConnectorTaskId taskId = parseTaskId(record.key()); + if (taskId == null) { + log.error("Ignoring task configuration because {} couldn't be parsed as a task config key", record.key()); + return; + } + if (value.value() == null) { + log.error("Ignoring task configuration for task {} because it is unexpectedly null", taskId); + return; + } + if (!(value.value() instanceof Map)) { + log.error("Ignoring task configuration for task {} because the value is not a Map but is {}", taskId, value.value().getClass()); + return; + } + + Object newTaskConfig = ((Map) value.value()).get("properties"); + if (!(newTaskConfig instanceof Map)) { + log.error("Invalid data for config of task {} 'properties' field should be a Map but is {}", taskId, newTaskConfig.getClass()); + return; + } + + Map> deferred = deferredTaskUpdates.get(taskId.connector()); + if (deferred == null) { + deferred = new HashMap<>(); + deferredTaskUpdates.put(taskId.connector(), deferred); + } + log.debug("Storing new config for task {}; this will wait for a commit message before the new config will take effect.", taskId); + deferred.put(taskId, (Map) newTaskConfig); + } + } else if (record.key().startsWith(COMMIT_TASKS_PREFIX)) { + String connectorName = record.key().substring(COMMIT_TASKS_PREFIX.length()); + List updatedTasks = new ArrayList<>(); + synchronized (lock) { + // Apply any outstanding deferred task updates for the given connector. Note that just because we + // encounter a commit message does not mean it will result in consistent output. In particular due to + // compaction, there may be cases where . For example if we have the following sequence of writes: + // + // 1. Write connector "foo"'s config + // 2. Write connector "foo", task 1's config <-- compacted + // 3. Write connector "foo", task 2's config + // 4. Write connector "foo" task commit message + // 5. Write connector "foo", task 1's config + // 6. Write connector "foo", task 2's config + // 7. Write connector "foo" task commit message + // + // then when a new worker starts up, if message 2 had been compacted, then when message 4 is applied + // "foo" will not have a complete set of configs. Only when message 7 is applied will the complete + // configuration be available. Worse, if the leader died while writing messages 5, 6, and 7 such that + // only 5 was written, then there may be nothing that will finish writing the configs and get the + // log back into a consistent state. + // + // It is expected that the user of this class (i.e., the Herder) will take the necessary action to + // resolve this (i.e., get the connector to recommit its configuration). This inconsistent state is + // exposed in the snapshots provided via ClusterConfigState so they are easy to handle. 
+ if (!(value.value() instanceof Map)) { // Schema-less, so we get maps instead of structs + log.error("Ignoring connector tasks configuration commit for connector '{}' because it is in the wrong format: {}", connectorName, value.value()); + return; + } + Map> deferred = deferredTaskUpdates.get(connectorName); + + int newTaskCount = intValue(((Map) value.value()).get("tasks")); + + // Validate the configs we're supposed to update to ensure we're getting a complete configuration + // update of all tasks that are expected based on the number of tasks in the commit message. + Set taskIdSet = taskIds(connectorName, deferred); + if (!completeTaskIdSet(taskIdSet, newTaskCount)) { + // Given the logic for writing commit messages, we should only hit this condition due to compacted + // historical data, in which case we would not have applied any updates yet and there will be no + // task config data already committed for the connector, so we shouldn't have to clear any data + // out. All we need to do is add the flag marking it inconsistent. + log.debug("We have an incomplete set of task configs for connector '{}' probably due to compaction. So we are not doing anything with the new configuration.", connectorName); + inconsistent.add(connectorName); + } else { + if (deferred != null) { + taskConfigs.putAll(deferred); + updatedTasks.addAll(deferred.keySet()); + } + inconsistent.remove(connectorName); + } + // Always clear the deferred entries, even if we didn't apply them. If they represented an inconsistent + // update, then we need to see a completely fresh set of configs after this commit message, so we don't + // want any of these outdated configs + if (deferred != null) + deferred.clear(); + + connectorTaskCounts.put(connectorName, newTaskCount); + } + + if (started) + updateListener.onTaskConfigUpdate(updatedTasks); + } else if (record.key().startsWith(RESTART_PREFIX)) { + RestartRequest request = recordToRestartRequest(record, value); + // Only notify the listener if this backing store is already successfully started (having caught up the first time) + if (request != null && started) { + updateListener.onRestartRequest(request); + } + } else if (record.key().equals(SESSION_KEY_KEY)) { + if (value.value() == null) { + log.error("Ignoring session key because it is unexpectedly null"); + return; + } + if (!(value.value() instanceof Map)) { + log.error("Ignoring session key because the value is not a Map but is {}", value.value().getClass()); + return; + } + + Map valueAsMap = (Map) value.value(); + + Object sessionKey = valueAsMap.get("key"); + if (!(sessionKey instanceof String)) { + log.error("Invalid data for session key 'key' field should be a String but is {}", sessionKey.getClass()); + return; + } + byte[] key = Base64.getDecoder().decode((String) sessionKey); + + Object keyAlgorithm = valueAsMap.get("algorithm"); + if (!(keyAlgorithm instanceof String)) { + log.error("Invalid data for session key 'algorithm' field should be a String but it is {}", keyAlgorithm.getClass()); + return; + } + + Object creationTimestamp = valueAsMap.get("creation-timestamp"); + if (!(creationTimestamp instanceof Long)) { + log.error("Invalid data for session key 'creation-timestamp' field should be a long but it is {}", creationTimestamp.getClass()); + return; + } + KafkaConfigBackingStore.this.sessionKey = new SessionKey( + new SecretKeySpec(key, (String) keyAlgorithm), + (long) creationTimestamp + ); + + if (started) + updateListener.onSessionKeyUpdate(KafkaConfigBackingStore.this.sessionKey); + } else { + 
log.error("Discarding config update record with invalid key: {}", record.key()); + } + } + + } + + @SuppressWarnings("unchecked") + RestartRequest recordToRestartRequest(ConsumerRecord record, SchemaAndValue value) { + String connectorName = record.key().substring(RESTART_PREFIX.length()); + if (!(value.value() instanceof Map)) { + log.error("Ignoring restart request because the value is not a Map but is {}", value.value() == null ? "null" : value.value().getClass()); + return null; + } + + Map valueAsMap = (Map) value.value(); + + Object failed = valueAsMap.get(ONLY_FAILED_FIELD_NAME); + boolean onlyFailed; + if (!(failed instanceof Boolean)) { + log.warn("Invalid data for restart request '{}' field should be a Boolean but is {}, defaulting to {}", ONLY_FAILED_FIELD_NAME, failed == null ? "null" : failed.getClass(), ONLY_FAILED_DEFAULT); + onlyFailed = ONLY_FAILED_DEFAULT; + } else { + onlyFailed = (Boolean) failed; + } + + Object withTasks = valueAsMap.get(INCLUDE_TASKS_FIELD_NAME); + boolean includeTasks; + if (!(withTasks instanceof Boolean)) { + log.warn("Invalid data for restart request '{}' field should be a Boolean but is {}, defaulting to {}", INCLUDE_TASKS_FIELD_NAME, withTasks == null ? "null" : withTasks.getClass(), INCLUDE_TASKS_DEFAULT); + includeTasks = INCLUDE_TASKS_DEFAULT; + } else { + includeTasks = (Boolean) withTasks; + } + return new RestartRequest(connectorName, onlyFailed, includeTasks); + } + + private ConnectorTaskId parseTaskId(String key) { + String[] parts = key.split("-"); + if (parts.length < 3) return null; + + try { + int taskNum = Integer.parseInt(parts[parts.length - 1]); + String connectorName = Utils.join(Arrays.copyOfRange(parts, 1, parts.length - 1), "-"); + return new ConnectorTaskId(connectorName, taskNum); + } catch (NumberFormatException e) { + return null; + } + } + + /** + * Given task configurations, get a set of integer task IDs for the connector. + */ + private Set taskIds(String connector, Map> configs) { + Set tasks = new TreeSet<>(); + if (configs == null) { + return tasks; + } + for (ConnectorTaskId taskId : configs.keySet()) { + assert taskId.connector().equals(connector); + tasks.add(taskId.task()); + } + return tasks; + } + + private boolean completeTaskIdSet(Set idSet, int expectedSize) { + // Note that we do *not* check for the exact set. This is an important implication of compaction. If we start out + // with 2 tasks, then reduce to 1, we'll end up with log entries like: + // + // 1. Connector "foo" config + // 2. Connector "foo", task 1 config + // 3. Connector "foo", task 2 config + // 4. Connector "foo", commit 2 tasks + // 5. Connector "foo", task 1 config + // 6. Connector "foo", commit 1 tasks + // + // However, due to compaction we could end up with a log that looks like this: + // + // 1. Connector "foo" config + // 3. Connector "foo", task 2 config + // 5. Connector "foo", task 1 config + // 6. Connector "foo", commit 1 tasks + // + // which isn't incorrect, but would appear in this code to have an extra task configuration. Instead, we just + // validate that all the configs specified by the commit message are present. This should be fine because the + // logic for writing configs ensures all the task configs are written (and reads them back) before writing the + // commit message. + + if (idSet.size() < expectedSize) + return false; + + for (int i = 0; i < expectedSize; i++) + if (!idSet.contains(i)) + return false; + return true; + } + + // Convert an integer value extracted from a schemaless struct to an int. 
This handles potentially different + // encodings by different Converters. + private static int intValue(Object value) { + if (value instanceof Integer) + return (int) value; + else if (value instanceof Long) + return (int) (long) value; + else + throw new ConnectException("Expected integer value to be either Integer or Long"); + } +} + diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/storage/KafkaOffsetBackingStore.java b/connect/runtime/src/main/java/org/apache/kafka/connect/storage/KafkaOffsetBackingStore.java new file mode 100644 index 0000000..313baf7 --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/storage/KafkaOffsetBackingStore.java @@ -0,0 +1,274 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.storage; + +import org.apache.kafka.clients.admin.NewTopic; +import org.apache.kafka.clients.consumer.ConsumerConfig; +import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.clients.producer.ProducerConfig; +import org.apache.kafka.clients.producer.RecordMetadata; +import org.apache.kafka.common.config.ConfigException; +import org.apache.kafka.common.config.TopicConfig; +import org.apache.kafka.common.serialization.ByteArrayDeserializer; +import org.apache.kafka.common.serialization.ByteArraySerializer; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.connect.runtime.WorkerConfig; +import org.apache.kafka.connect.runtime.distributed.DistributedConfig; +import org.apache.kafka.connect.util.Callback; +import org.apache.kafka.connect.util.ConnectUtils; +import org.apache.kafka.connect.util.ConvertingFutureCallback; +import org.apache.kafka.connect.util.KafkaBasedLog; +import org.apache.kafka.connect.util.SharedTopicAdmin; +import org.apache.kafka.connect.util.TopicAdmin; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.nio.ByteBuffer; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; +import java.util.Set; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import java.util.function.Supplier; + +/** + *

            + * Implementation of OffsetBackingStore that uses a Kafka topic to store offset data. + *

            + *

            + * Internally, this implementation both produces to and consumes from a Kafka topic which stores the offsets. + * It accepts producer and consumer overrides via its configuration but forces some settings to specific values + * to ensure correct behavior (e.g. acks, auto.offset.reset). + *
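+ * A minimal usage sketch (the variable names and handleFlushResult are hypothetical; the callback passed to
+ * set() is the single-method org.apache.kafka.connect.util.Callback, written here as a lambda):
+ *   store.configure(workerConfig);
+ *   store.start();
+ *   store.set(offsets, (error, ignored) -> handleFlushResult(error)).get();   // offsets: Map<ByteBuffer, ByteBuffer>
+ *   Map<ByteBuffer, ByteBuffer> current = store.get(keys).get();              // keys: Collection<ByteBuffer>
+ *   store.stop();
+ *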

            + */ +public class KafkaOffsetBackingStore implements OffsetBackingStore { + private static final Logger log = LoggerFactory.getLogger(KafkaOffsetBackingStore.class); + + private KafkaBasedLog offsetLog; + private HashMap data; + private final Supplier topicAdminSupplier; + private SharedTopicAdmin ownTopicAdmin; + + @Deprecated + public KafkaOffsetBackingStore() { + this.topicAdminSupplier = null; + } + + public KafkaOffsetBackingStore(Supplier topicAdmin) { + this.topicAdminSupplier = Objects.requireNonNull(topicAdmin); + } + + @Override + public void configure(final WorkerConfig config) { + String topic = config.getString(DistributedConfig.OFFSET_STORAGE_TOPIC_CONFIG); + if (topic == null || topic.trim().length() == 0) + throw new ConfigException("Offset storage topic must be specified"); + + String clusterId = ConnectUtils.lookupKafkaClusterId(config); + data = new HashMap<>(); + + Map originals = config.originals(); + Map producerProps = new HashMap<>(originals); + producerProps.put(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG, ByteArraySerializer.class.getName()); + producerProps.put(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG, ByteArraySerializer.class.getName()); + producerProps.put(ProducerConfig.DELIVERY_TIMEOUT_MS_CONFIG, Integer.MAX_VALUE); + ConnectUtils.addMetricsContextProperties(producerProps, config, clusterId); + + Map consumerProps = new HashMap<>(originals); + consumerProps.put(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, ByteArrayDeserializer.class.getName()); + consumerProps.put(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, ByteArrayDeserializer.class.getName()); + ConnectUtils.addMetricsContextProperties(consumerProps, config, clusterId); + + Map adminProps = new HashMap<>(originals); + ConnectUtils.addMetricsContextProperties(adminProps, config, clusterId); + Supplier adminSupplier; + if (topicAdminSupplier != null) { + adminSupplier = topicAdminSupplier; + } else { + // Create our own topic admin supplier that we'll close when we're stopped + ownTopicAdmin = new SharedTopicAdmin(adminProps); + adminSupplier = ownTopicAdmin; + } + Map topicSettings = config instanceof DistributedConfig + ? 
((DistributedConfig) config).offsetStorageTopicSettings() + : Collections.emptyMap(); + NewTopic topicDescription = TopicAdmin.defineTopic(topic) + .config(topicSettings) // first so that we override user-supplied settings as needed + .compacted() + .partitions(config.getInt(DistributedConfig.OFFSET_STORAGE_PARTITIONS_CONFIG)) + .replicationFactor(config.getShort(DistributedConfig.OFFSET_STORAGE_REPLICATION_FACTOR_CONFIG)) + .build(); + + offsetLog = createKafkaBasedLog(topic, producerProps, consumerProps, consumedCallback, topicDescription, adminSupplier); + } + + private KafkaBasedLog createKafkaBasedLog(String topic, Map producerProps, + Map consumerProps, + Callback> consumedCallback, + final NewTopic topicDescription, Supplier adminSupplier) { + java.util.function.Consumer createTopics = admin -> { + log.debug("Creating admin client to manage Connect internal offset topic"); + // Create the topic if it doesn't exist + Set newTopics = admin.createTopics(topicDescription); + if (!newTopics.contains(topic)) { + // It already existed, so check that the topic cleanup policy is compact only and not delete + log.debug("Using admin client to check cleanup policy for '{}' topic is '{}'", topic, TopicConfig.CLEANUP_POLICY_COMPACT); + admin.verifyTopicCleanupPolicyOnlyCompact(topic, + DistributedConfig.OFFSET_STORAGE_TOPIC_CONFIG, "source connector offsets"); + } + }; + return new KafkaBasedLog<>(topic, producerProps, consumerProps, adminSupplier, consumedCallback, Time.SYSTEM, createTopics); + } + + @Override + public void start() { + log.info("Starting KafkaOffsetBackingStore"); + offsetLog.start(); + log.info("Finished reading offsets topic and starting KafkaOffsetBackingStore"); + } + + @Override + public void stop() { + log.info("Stopping KafkaOffsetBackingStore"); + try { + offsetLog.stop(); + } finally { + if (ownTopicAdmin != null) { + ownTopicAdmin.close(); + } + } + log.info("Stopped KafkaOffsetBackingStore"); + } + + @Override + public Future> get(final Collection keys) { + ConvertingFutureCallback> future = new ConvertingFutureCallback>() { + @Override + public Map convert(Void result) { + Map values = new HashMap<>(); + for (ByteBuffer key : keys) + values.put(key, data.get(key)); + return values; + } + }; + // This operation may be relatively (but not too) expensive since it always requires checking end offsets, even + // if we've already read up to the end. However, it also should not be common (offsets should only be read when + // resetting a task). Always requiring that we read to the end is simpler than trying to differentiate when it + // is safe not to (which should only be if we *know* we've maintained ownership since the last write). + offsetLog.readToEnd(future); + return future; + } + + @Override + public Future set(final Map values, final Callback callback) { + SetCallbackFuture producerCallback = new SetCallbackFuture(values.size(), callback); + + for (Map.Entry entry : values.entrySet()) { + ByteBuffer key = entry.getKey(); + ByteBuffer value = entry.getValue(); + offsetLog.send(key == null ? null : key.array(), value == null ? null : value.array(), producerCallback); + } + + return producerCallback; + } + + private final Callback> consumedCallback = new Callback>() { + @Override + public void onCompletion(Throwable error, ConsumerRecord record) { + ByteBuffer key = record.key() != null ? ByteBuffer.wrap(record.key()) : null; + ByteBuffer value = record.value() != null ? 
ByteBuffer.wrap(record.value()) : null; + data.put(key, value); + } + }; + + private static class SetCallbackFuture implements org.apache.kafka.clients.producer.Callback, Future { + private int numLeft; + private boolean completed = false; + private Throwable exception = null; + private final Callback callback; + + public SetCallbackFuture(int numRecords, Callback callback) { + numLeft = numRecords; + this.callback = callback; + } + + @Override + public synchronized void onCompletion(RecordMetadata metadata, Exception exception) { + if (exception != null) { + if (!completed) { + this.exception = exception; + callback.onCompletion(exception, null); + completed = true; + this.notify(); + } + return; + } + + numLeft -= 1; + if (numLeft == 0) { + callback.onCompletion(null, null); + completed = true; + this.notify(); + } + } + + @Override + public synchronized boolean cancel(boolean mayInterruptIfRunning) { + return false; + } + + @Override + public synchronized boolean isCancelled() { + return false; + } + + @Override + public synchronized boolean isDone() { + return completed; + } + + @Override + public synchronized Void get() throws InterruptedException, ExecutionException { + while (!completed) { + this.wait(); + } + if (exception != null) + throw new ExecutionException(exception); + return null; + } + + @Override + public synchronized Void get(long timeout, TimeUnit unit) throws InterruptedException, ExecutionException, TimeoutException { + long started = System.currentTimeMillis(); + long limit = started + unit.toMillis(timeout); + while (!completed) { + long leftMs = limit - System.currentTimeMillis(); + if (leftMs < 0) + throw new TimeoutException("KafkaOffsetBackingStore Future timed out."); + this.wait(leftMs); + } + if (exception != null) + throw new ExecutionException(exception); + return null; + } + } +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/storage/KafkaStatusBackingStore.java b/connect/runtime/src/main/java/org/apache/kafka/connect/storage/KafkaStatusBackingStore.java new file mode 100644 index 0000000..44902c0 --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/storage/KafkaStatusBackingStore.java @@ -0,0 +1,685 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.connect.storage; + +import org.apache.kafka.clients.admin.NewTopic; +import org.apache.kafka.clients.consumer.ConsumerConfig; +import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.clients.producer.ProducerConfig; +import org.apache.kafka.clients.producer.RecordMetadata; +import org.apache.kafka.common.config.ConfigException; +import org.apache.kafka.common.config.TopicConfig; +import org.apache.kafka.common.errors.RetriableException; +import org.apache.kafka.common.serialization.ByteArrayDeserializer; +import org.apache.kafka.common.serialization.ByteArraySerializer; +import org.apache.kafka.common.serialization.StringDeserializer; +import org.apache.kafka.common.serialization.StringSerializer; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.connect.data.Schema; +import org.apache.kafka.connect.data.SchemaAndValue; +import org.apache.kafka.connect.data.SchemaBuilder; +import org.apache.kafka.connect.data.Struct; +import org.apache.kafka.connect.runtime.AbstractStatus; +import org.apache.kafka.connect.runtime.ConnectorStatus; +import org.apache.kafka.connect.runtime.TaskStatus; +import org.apache.kafka.connect.runtime.TopicStatus; +import org.apache.kafka.connect.runtime.WorkerConfig; +import org.apache.kafka.connect.runtime.distributed.DistributedConfig; +import org.apache.kafka.connect.util.Callback; +import org.apache.kafka.connect.util.ConnectUtils; +import org.apache.kafka.connect.util.ConnectorTaskId; +import org.apache.kafka.connect.util.KafkaBasedLog; +import org.apache.kafka.connect.util.SharedTopicAdmin; +import org.apache.kafka.connect.util.Table; +import org.apache.kafka.connect.util.TopicAdmin; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; +import java.util.function.Supplier; + +/** + * StatusBackingStore implementation which uses a compacted topic for storage + * of connector and task status information. When a state change is observed, + * the new state is written to the compacted topic. The new state will not be + * visible until it has been read back from the topic. + * + * In spite of their names, the putSafe() methods cannot guarantee the safety + * of the write (since Kafka itself cannot provide such guarantees currently), + * but it can avoid specific unsafe conditions. In particular, we putSafe() + * allows writes in the following conditions: + * + * 1) It is (probably) safe to overwrite the state if there is no previous + * value. + * 2) It is (probably) safe to overwrite the state if the previous value was + * set by a worker with the same workerId. + * 3) It is (probably) safe to overwrite the previous state if the current + * generation is higher than the previous . + * + * Basically all these conditions do is reduce the window for conflicts. They + * obviously cannot take into account in-flight requests. 
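The three conditions above boil down to a single guard; a minimal sketch of that check, mirroring the CacheEntry.canWriteSafely(...) logic defined near the end of this class (the StatusView type here is a hypothetical stand-in for the status classes, introduced only for this sketch):

// Hypothetical minimal view of a status entry, for this sketch only.
interface StatusView {
    String workerId();
    int generation();
}

final class SafeWriteGuard {
    private SafeWriteGuard() { }

    // Overwriting 'previous' with 'candidate' is (probably) safe when...
    static boolean canWriteSafely(StatusView previous, StatusView candidate) {
        return previous == null                                      // 1) nothing stored yet
                || previous.workerId().equals(candidate.workerId())  // 2) the same worker wrote it
                || previous.generation() <= candidate.generation();  // 3) the candidate is not older
    }
}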
+ * + */ +public class KafkaStatusBackingStore implements StatusBackingStore { + private static final Logger log = LoggerFactory.getLogger(KafkaStatusBackingStore.class); + + public static final String TASK_STATUS_PREFIX = "status-task-"; + public static final String CONNECTOR_STATUS_PREFIX = "status-connector-"; + public static final String TOPIC_STATUS_PREFIX = "status-topic-"; + public static final String TOPIC_STATUS_SEPARATOR = ":connector-"; + + public static final String STATE_KEY_NAME = "state"; + public static final String TRACE_KEY_NAME = "trace"; + public static final String WORKER_ID_KEY_NAME = "worker_id"; + public static final String GENERATION_KEY_NAME = "generation"; + + public static final String TOPIC_STATE_KEY = "topic"; + public static final String TOPIC_NAME_KEY = "name"; + public static final String TOPIC_CONNECTOR_KEY = "connector"; + public static final String TOPIC_TASK_KEY = "task"; + public static final String TOPIC_DISCOVER_TIMESTAMP_KEY = "discoverTimestamp"; + + private static final Schema STATUS_SCHEMA_V0 = SchemaBuilder.struct() + .field(STATE_KEY_NAME, Schema.STRING_SCHEMA) + .field(TRACE_KEY_NAME, SchemaBuilder.string().optional().build()) + .field(WORKER_ID_KEY_NAME, Schema.STRING_SCHEMA) + .field(GENERATION_KEY_NAME, Schema.INT32_SCHEMA) + .build(); + + private static final Schema TOPIC_STATUS_VALUE_SCHEMA_V0 = SchemaBuilder.struct() + .field(TOPIC_NAME_KEY, Schema.STRING_SCHEMA) + .field(TOPIC_CONNECTOR_KEY, Schema.STRING_SCHEMA) + .field(TOPIC_TASK_KEY, Schema.INT32_SCHEMA) + .field(TOPIC_DISCOVER_TIMESTAMP_KEY, Schema.INT64_SCHEMA) + .build(); + + private static final Schema TOPIC_STATUS_SCHEMA_V0 = SchemaBuilder.map( + Schema.STRING_SCHEMA, + TOPIC_STATUS_VALUE_SCHEMA_V0 + ).build(); + + private final Time time; + private final Converter converter; + //visible for testing + protected final Table> tasks; + protected final Map> connectors; + protected final ConcurrentMap> topics; + private final Supplier topicAdminSupplier; + + private String statusTopic; + private KafkaBasedLog kafkaLog; + private int generation; + private SharedTopicAdmin ownTopicAdmin; + + @Deprecated + public KafkaStatusBackingStore(Time time, Converter converter) { + this(time, converter, null); + } + + public KafkaStatusBackingStore(Time time, Converter converter, Supplier topicAdminSupplier) { + this.time = time; + this.converter = converter; + this.tasks = new Table<>(); + this.connectors = new HashMap<>(); + this.topics = new ConcurrentHashMap<>(); + this.topicAdminSupplier = topicAdminSupplier; + } + + // visible for testing + KafkaStatusBackingStore(Time time, Converter converter, String statusTopic, KafkaBasedLog kafkaLog) { + this(time, converter); + this.kafkaLog = kafkaLog; + this.statusTopic = statusTopic; + } + + @Override + public void configure(final WorkerConfig config) { + this.statusTopic = config.getString(DistributedConfig.STATUS_STORAGE_TOPIC_CONFIG); + if (this.statusTopic == null || this.statusTopic.trim().length() == 0) + throw new ConfigException("Must specify topic for connector status."); + + String clusterId = ConnectUtils.lookupKafkaClusterId(config); + Map originals = config.originals(); + Map producerProps = new HashMap<>(originals); + producerProps.put(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG, StringSerializer.class.getName()); + producerProps.put(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG, ByteArraySerializer.class.getName()); + producerProps.put(ProducerConfig.RETRIES_CONFIG, 0); // we handle retries in this class + 
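For reference, the settings read in configure() come from the distributed worker properties; a hedged sketch of the relevant subset (key names taken from the DistributedConfig constants referenced here, values purely illustrative and not a complete worker configuration):

import java.util.HashMap;
import java.util.Map;

// Illustrative subset of a distributed worker configuration; only the keys that
// drive the status topic setup above are shown, with made-up values.
final class StatusStorageWorkerProps {
    static Map<String, String> example() {
        Map<String, String> props = new HashMap<>();
        props.put("bootstrap.servers", "localhost:9092");
        props.put("group.id", "connect-cluster");
        props.put("status.storage.topic", "connect-status");       // STATUS_STORAGE_TOPIC_CONFIG
        props.put("status.storage.partitions", "5");               // STATUS_STORAGE_PARTITIONS_CONFIG
        props.put("status.storage.replication.factor", "3");       // STATUS_STORAGE_REPLICATION_FACTOR_CONFIG
        return props;
    }
}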
ConnectUtils.addMetricsContextProperties(producerProps, config, clusterId); + + Map consumerProps = new HashMap<>(originals); + consumerProps.put(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, StringDeserializer.class.getName()); + consumerProps.put(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, ByteArrayDeserializer.class.getName()); + ConnectUtils.addMetricsContextProperties(consumerProps, config, clusterId); + + Map adminProps = new HashMap<>(originals); + ConnectUtils.addMetricsContextProperties(adminProps, config, clusterId); + Supplier adminSupplier; + if (topicAdminSupplier != null) { + adminSupplier = topicAdminSupplier; + } else { + // Create our own topic admin supplier that we'll close when we're stopped + ownTopicAdmin = new SharedTopicAdmin(adminProps); + adminSupplier = ownTopicAdmin; + } + + Map topicSettings = config instanceof DistributedConfig + ? ((DistributedConfig) config).statusStorageTopicSettings() + : Collections.emptyMap(); + NewTopic topicDescription = TopicAdmin.defineTopic(statusTopic) + .config(topicSettings) // first so that we override user-supplied settings as needed + .compacted() + .partitions(config.getInt(DistributedConfig.STATUS_STORAGE_PARTITIONS_CONFIG)) + .replicationFactor(config.getShort(DistributedConfig.STATUS_STORAGE_REPLICATION_FACTOR_CONFIG)) + .build(); + + Callback> readCallback = (error, record) -> read(record); + this.kafkaLog = createKafkaBasedLog(statusTopic, producerProps, consumerProps, readCallback, topicDescription, adminSupplier); + } + + private KafkaBasedLog createKafkaBasedLog(String topic, Map producerProps, + Map consumerProps, + Callback> consumedCallback, + final NewTopic topicDescription, Supplier adminSupplier) { + java.util.function.Consumer createTopics = admin -> { + log.debug("Creating admin client to manage Connect internal status topic"); + // Create the topic if it doesn't exist + Set newTopics = admin.createTopics(topicDescription); + if (!newTopics.contains(topic)) { + // It already existed, so check that the topic cleanup policy is compact only and not delete + log.debug("Using admin client to check cleanup policy of '{}' topic is '{}'", topic, TopicConfig.CLEANUP_POLICY_COMPACT); + admin.verifyTopicCleanupPolicyOnlyCompact(topic, + DistributedConfig.STATUS_STORAGE_TOPIC_CONFIG, "connector and task statuses"); + } + }; + return new KafkaBasedLog<>(topic, producerProps, consumerProps, adminSupplier, consumedCallback, time, createTopics); + } + + @Override + public void start() { + kafkaLog.start(); + + // read to the end on startup to ensure that api requests see the most recent states + kafkaLog.readToEnd(); + } + + @Override + public void stop() { + try { + kafkaLog.stop(); + } finally { + if (ownTopicAdmin != null) { + ownTopicAdmin.close(); + } + } + } + + @Override + public void put(final ConnectorStatus status) { + sendConnectorStatus(status, false); + } + + @Override + public void putSafe(final ConnectorStatus status) { + sendConnectorStatus(status, true); + } + + @Override + public void put(final TaskStatus status) { + sendTaskStatus(status, false); + } + + @Override + public void putSafe(final TaskStatus status) { + sendTaskStatus(status, true); + } + + @Override + public void put(final TopicStatus status) { + sendTopicStatus(status.connector(), status.topic(), status); + } + + @Override + public void flush() { + kafkaLog.flush(); + } + + private void sendConnectorStatus(final ConnectorStatus status, boolean safeWrite) { + String connector = status.id(); + CacheEntry entry = getOrAdd(connector); + String 
key = CONNECTOR_STATUS_PREFIX + connector; + send(key, status, entry, safeWrite); + } + + private void sendTaskStatus(final TaskStatus status, boolean safeWrite) { + ConnectorTaskId taskId = status.id(); + CacheEntry entry = getOrAdd(taskId); + String key = TASK_STATUS_PREFIX + taskId.connector() + "-" + taskId.task(); + send(key, status, entry, safeWrite); + } + + private void sendTopicStatus(final String connector, final String topic, final TopicStatus status) { + String key = TOPIC_STATUS_PREFIX + topic + TOPIC_STATUS_SEPARATOR + connector; + + final byte[] value = serializeTopicStatus(status); + + kafkaLog.send(key, value, new org.apache.kafka.clients.producer.Callback() { + @Override + public void onCompletion(RecordMetadata metadata, Exception exception) { + if (exception == null) return; + // TODO: retry more gracefully and not forever + if (exception instanceof RetriableException) { + kafkaLog.send(key, value, this); + } else { + log.error("Failed to write status update", exception); + } + } + }); + } + + private > void send(final String key, + final V status, + final CacheEntry entry, + final boolean safeWrite) { + final int sequence; + synchronized (this) { + this.generation = status.generation(); + if (safeWrite && !entry.canWriteSafely(status)) + return; + sequence = entry.increment(); + } + + final byte[] value = status.state() == ConnectorStatus.State.DESTROYED ? null : serialize(status); + + kafkaLog.send(key, value, new org.apache.kafka.clients.producer.Callback() { + @Override + public void onCompletion(RecordMetadata metadata, Exception exception) { + if (exception == null) return; + if (exception instanceof RetriableException) { + synchronized (KafkaStatusBackingStore.this) { + if (entry.isDeleted() + || status.generation() != generation + || (safeWrite && !entry.canWriteSafely(status, sequence))) + return; + } + kafkaLog.send(key, value, this); + } else { + log.error("Failed to write status update", exception); + } + } + }); + } + + private synchronized CacheEntry getOrAdd(String connector) { + CacheEntry entry = connectors.get(connector); + if (entry == null) { + entry = new CacheEntry<>(); + connectors.put(connector, entry); + } + return entry; + } + + private synchronized void remove(String connector) { + CacheEntry removed = connectors.remove(connector); + if (removed != null) + removed.delete(); + + Map> tasks = this.tasks.remove(connector); + if (tasks != null) { + for (CacheEntry taskEntry : tasks.values()) + taskEntry.delete(); + } + } + + private synchronized CacheEntry getOrAdd(ConnectorTaskId task) { + CacheEntry entry = tasks.get(task.connector(), task.task()); + if (entry == null) { + entry = new CacheEntry<>(); + tasks.put(task.connector(), task.task(), entry); + } + return entry; + } + + private synchronized void remove(ConnectorTaskId id) { + CacheEntry removed = tasks.remove(id.connector(), id.task()); + if (removed != null) + removed.delete(); + } + + private void removeTopic(String topic, String connector) { + ConcurrentMap activeTopics = topics.get(connector); + if (activeTopics == null) { + return; + } + activeTopics.remove(topic); + } + + @Override + public synchronized TaskStatus get(ConnectorTaskId id) { + CacheEntry entry = tasks.get(id.connector(), id.task()); + return entry == null ? null : entry.get(); + } + + @Override + public synchronized ConnectorStatus get(String connector) { + CacheEntry entry = connectors.get(connector); + return entry == null ? 
null : entry.get(); + } + + @Override + public synchronized Collection getAll(String connector) { + List res = new ArrayList<>(); + for (CacheEntry statusEntry : tasks.row(connector).values()) { + TaskStatus status = statusEntry.get(); + if (status != null) + res.add(status); + } + return res; + } + + @Override + public TopicStatus getTopic(String connector, String topic) { + ConcurrentMap activeTopics = topics.get(Objects.requireNonNull(connector)); + return activeTopics != null ? activeTopics.get(Objects.requireNonNull(topic)) : null; + } + + @Override + public Collection getAllTopics(String connector) { + ConcurrentMap activeTopics = topics.get(Objects.requireNonNull(connector)); + return activeTopics != null + ? Collections.unmodifiableCollection(Objects.requireNonNull(activeTopics.values())) + : Collections.emptySet(); + } + + @Override + public void deleteTopic(String connector, String topic) { + sendTopicStatus(Objects.requireNonNull(connector), Objects.requireNonNull(topic), null); + } + + @Override + public synchronized Set connectors() { + return new HashSet<>(connectors.keySet()); + } + + private ConnectorStatus parseConnectorStatus(String connector, byte[] data) { + try { + SchemaAndValue schemaAndValue = converter.toConnectData(statusTopic, data); + if (!(schemaAndValue.value() instanceof Map)) { + log.error("Invalid connector status type {}", schemaAndValue.value().getClass()); + return null; + } + + @SuppressWarnings("unchecked") + Map statusMap = (Map) schemaAndValue.value(); + TaskStatus.State state = TaskStatus.State.valueOf((String) statusMap.get(STATE_KEY_NAME)); + String trace = (String) statusMap.get(TRACE_KEY_NAME); + String workerUrl = (String) statusMap.get(WORKER_ID_KEY_NAME); + int generation = ((Long) statusMap.get(GENERATION_KEY_NAME)).intValue(); + return new ConnectorStatus(connector, state, trace, workerUrl, generation); + } catch (Exception e) { + log.error("Failed to deserialize connector status", e); + return null; + } + } + + private TaskStatus parseTaskStatus(ConnectorTaskId taskId, byte[] data) { + try { + SchemaAndValue schemaAndValue = converter.toConnectData(statusTopic, data); + if (!(schemaAndValue.value() instanceof Map)) { + log.error("Invalid task status type {}", schemaAndValue.value().getClass()); + return null; + } + @SuppressWarnings("unchecked") + Map statusMap = (Map) schemaAndValue.value(); + TaskStatus.State state = TaskStatus.State.valueOf((String) statusMap.get(STATE_KEY_NAME)); + String trace = (String) statusMap.get(TRACE_KEY_NAME); + String workerUrl = (String) statusMap.get(WORKER_ID_KEY_NAME); + int generation = ((Long) statusMap.get(GENERATION_KEY_NAME)).intValue(); + return new TaskStatus(taskId, state, workerUrl, generation, trace); + } catch (Exception e) { + log.error("Failed to deserialize task status", e); + return null; + } + } + + protected TopicStatus parseTopicStatus(byte[] data) { + try { + SchemaAndValue schemaAndValue = converter.toConnectData(statusTopic, data); + if (!(schemaAndValue.value() instanceof Map)) { + log.error("Invalid topic status value {}", schemaAndValue.value()); + return null; + } + @SuppressWarnings("unchecked") + Object innerValue = ((Map) schemaAndValue.value()).get(TOPIC_STATE_KEY); + if (!(innerValue instanceof Map)) { + log.error("Invalid topic status value {} for field {}", innerValue, TOPIC_STATE_KEY); + return null; + } + @SuppressWarnings("unchecked") + Map topicStatusMetadata = (Map) innerValue; + return new TopicStatus((String) topicStatusMetadata.get(TOPIC_NAME_KEY), + (String) 
topicStatusMetadata.get(TOPIC_CONNECTOR_KEY), + ((Long) topicStatusMetadata.get(TOPIC_TASK_KEY)).intValue(), + (long) topicStatusMetadata.get(TOPIC_DISCOVER_TIMESTAMP_KEY)); + } catch (Exception e) { + log.error("Failed to deserialize topic status", e); + return null; + } + } + + private byte[] serialize(AbstractStatus status) { + Struct struct = new Struct(STATUS_SCHEMA_V0); + struct.put(STATE_KEY_NAME, status.state().name()); + if (status.trace() != null) + struct.put(TRACE_KEY_NAME, status.trace()); + struct.put(WORKER_ID_KEY_NAME, status.workerId()); + struct.put(GENERATION_KEY_NAME, status.generation()); + return converter.fromConnectData(statusTopic, STATUS_SCHEMA_V0, struct); + } + + //visible for testing + protected byte[] serializeTopicStatus(TopicStatus status) { + if (status == null) { + // This should send a tombstone record that will represent delete + return null; + } + Struct struct = new Struct(TOPIC_STATUS_VALUE_SCHEMA_V0); + struct.put(TOPIC_NAME_KEY, status.topic()); + struct.put(TOPIC_CONNECTOR_KEY, status.connector()); + struct.put(TOPIC_TASK_KEY, status.task()); + struct.put(TOPIC_DISCOVER_TIMESTAMP_KEY, status.discoverTimestamp()); + return converter.fromConnectData( + statusTopic, + TOPIC_STATUS_SCHEMA_V0, + Collections.singletonMap(TOPIC_STATE_KEY, struct)); + } + + private String parseConnectorStatusKey(String key) { + return key.substring(CONNECTOR_STATUS_PREFIX.length()); + } + + private ConnectorTaskId parseConnectorTaskId(String key) { + String[] parts = key.split("-"); + if (parts.length < 4) return null; + + try { + int taskNum = Integer.parseInt(parts[parts.length - 1]); + String connectorName = Utils.join(Arrays.copyOfRange(parts, 2, parts.length - 1), "-"); + return new ConnectorTaskId(connectorName, taskNum); + } catch (NumberFormatException e) { + log.warn("Invalid task status key {}", key); + return null; + } + } + + private void readConnectorStatus(String key, byte[] value) { + String connector = parseConnectorStatusKey(key); + if (connector == null || connector.isEmpty()) { + log.warn("Discarding record with invalid connector status key {}", key); + return; + } + + if (value == null) { + log.trace("Removing status for connector {}", connector); + remove(connector); + return; + } + + ConnectorStatus status = parseConnectorStatus(connector, value); + if (status == null) + return; + + synchronized (this) { + log.trace("Received connector {} status update {}", connector, status); + CacheEntry entry = getOrAdd(connector); + entry.put(status); + } + } + + private void readTaskStatus(String key, byte[] value) { + ConnectorTaskId id = parseConnectorTaskId(key); + if (id == null) { + log.warn("Discarding record with invalid task status key {}", key); + return; + } + + if (value == null) { + log.trace("Removing task status for {}", id); + remove(id); + return; + } + + TaskStatus status = parseTaskStatus(id, value); + if (status == null) { + log.warn("Failed to parse task status with key {}", key); + return; + } + + synchronized (this) { + log.trace("Received task {} status update {}", id, status); + CacheEntry entry = getOrAdd(id); + entry.put(status); + } + } + + private void readTopicStatus(String key, byte[] value) { + int delimiterPos = key.indexOf(':'); + int beginPos = TOPIC_STATUS_PREFIX.length(); + if (beginPos > delimiterPos) { + log.warn("Discarding record with invalid topic status key {}", key); + return; + } + + String topic = key.substring(beginPos, delimiterPos); + if (topic.isEmpty()) { + log.warn("Discarding record with invalid topic status key 
containing empty topic {}", key); + return; + } + + beginPos = delimiterPos + TOPIC_STATUS_SEPARATOR.length(); + int endPos = key.length(); + if (beginPos > endPos) { + log.warn("Discarding record with invalid topic status key {}", key); + return; + } + + String connector = key.substring(beginPos); + if (connector.isEmpty()) { + log.warn("Discarding record with invalid topic status key containing empty connector {}", key); + return; + } + + if (value == null) { + log.trace("Removing status for topic {} and connector {}", topic, connector); + removeTopic(topic, connector); + return; + } + + TopicStatus status = parseTopicStatus(value); + if (status == null) { + log.warn("Failed to parse topic status with key {}", key); + return; + } + + log.trace("Received topic status update {}", status); + topics.computeIfAbsent(connector, k -> new ConcurrentHashMap<>()) + .put(topic, status); + } + + // visible for testing + void read(ConsumerRecord record) { + String key = record.key(); + if (key.startsWith(CONNECTOR_STATUS_PREFIX)) { + readConnectorStatus(key, record.value()); + } else if (key.startsWith(TASK_STATUS_PREFIX)) { + readTaskStatus(key, record.value()); + } else if (key.startsWith(TOPIC_STATUS_PREFIX)) { + readTopicStatus(key, record.value()); + } else { + log.warn("Discarding record with invalid key {}", key); + } + } + + private static class CacheEntry> { + private T value = null; + private int sequence = 0; + private boolean deleted = false; + + public int increment() { + return ++sequence; + } + + public void put(T value) { + this.value = value; + } + + public T get() { + return value; + } + + public void delete() { + this.deleted = true; + } + + public boolean isDeleted() { + return deleted; + } + + public boolean canWriteSafely(T status) { + return value == null + || value.workerId().equals(status.workerId()) + || value.generation() <= status.generation(); + } + + public boolean canWriteSafely(T status, int sequence) { + return canWriteSafely(status) && this.sequence == sequence; + } + + } + +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/storage/MemoryConfigBackingStore.java b/connect/runtime/src/main/java/org/apache/kafka/connect/storage/MemoryConfigBackingStore.java new file mode 100644 index 0000000..38acf2e --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/storage/MemoryConfigBackingStore.java @@ -0,0 +1,184 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
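The read() dispatch above keys off the record key layout; a small sketch of how those keys are assembled, using the prefix and separator constants declared at the top of this class (connector and topic names are illustrative):

// How status record keys are laid out; the literals match the CONNECTOR_STATUS_PREFIX,
// TASK_STATUS_PREFIX, TOPIC_STATUS_PREFIX and TOPIC_STATUS_SEPARATOR constants above.
final class StatusKeys {
    private StatusKeys() { }

    static String connectorKey(String connector) {
        return "status-connector-" + connector;                      // e.g. status-connector-my-sink
    }

    static String taskKey(String connector, int task) {
        return "status-task-" + connector + "-" + task;              // e.g. status-task-my-sink-0
    }

    static String topicKey(String connector, String topic) {
        return "status-topic-" + topic + ":connector-" + connector;  // e.g. status-topic-orders:connector-my-sink
    }
}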
+ */ +package org.apache.kafka.connect.storage; + +import org.apache.kafka.connect.runtime.RestartRequest; +import org.apache.kafka.connect.runtime.SessionKey; +import org.apache.kafka.connect.runtime.TargetState; +import org.apache.kafka.connect.runtime.WorkerConfigTransformer; +import org.apache.kafka.connect.runtime.distributed.ClusterConfigState; +import org.apache.kafka.connect.util.ConnectorTaskId; + +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.TreeMap; +import java.util.concurrent.TimeUnit; + +public class MemoryConfigBackingStore implements ConfigBackingStore { + + private Map connectors = new HashMap<>(); + private UpdateListener updateListener; + private WorkerConfigTransformer configTransformer; + + public MemoryConfigBackingStore() { + } + + public MemoryConfigBackingStore(WorkerConfigTransformer configTransformer) { + this.configTransformer = configTransformer; + } + + @Override + public synchronized void start() { + } + + @Override + public synchronized void stop() { + } + + @Override + public synchronized ClusterConfigState snapshot() { + Map connectorTaskCounts = new HashMap<>(); + Map> connectorConfigs = new HashMap<>(); + Map connectorTargetStates = new HashMap<>(); + Map> taskConfigs = new HashMap<>(); + + for (Map.Entry connectorStateEntry : connectors.entrySet()) { + String connector = connectorStateEntry.getKey(); + ConnectorState connectorState = connectorStateEntry.getValue(); + connectorTaskCounts.put(connector, connectorState.taskConfigs.size()); + connectorConfigs.put(connector, connectorState.connConfig); + connectorTargetStates.put(connector, connectorState.targetState); + taskConfigs.putAll(connectorState.taskConfigs); + } + + return new ClusterConfigState( + ClusterConfigState.NO_OFFSET, + null, + connectorTaskCounts, + connectorConfigs, + connectorTargetStates, + taskConfigs, + Collections.emptySet(), + configTransformer); + } + + @Override + public synchronized boolean contains(String connector) { + return connectors.containsKey(connector); + } + + @Override + public synchronized void putConnectorConfig(String connector, Map properties) { + ConnectorState state = connectors.get(connector); + if (state == null) + connectors.put(connector, new ConnectorState(properties)); + else + state.connConfig = properties; + + if (updateListener != null) + updateListener.onConnectorConfigUpdate(connector); + } + + @Override + public synchronized void removeConnectorConfig(String connector) { + ConnectorState state = connectors.remove(connector); + + if (updateListener != null && state != null) + updateListener.onConnectorConfigRemove(connector); + } + + @Override + public synchronized void removeTaskConfigs(String connector) { + ConnectorState state = connectors.get(connector); + if (state == null) + throw new IllegalArgumentException("Cannot remove tasks for non-existing connector"); + + HashSet taskIds = new HashSet<>(state.taskConfigs.keySet()); + state.taskConfigs.clear(); + + if (updateListener != null) + updateListener.onTaskConfigUpdate(taskIds); + } + + @Override + public synchronized void putTaskConfigs(String connector, List> configs) { + ConnectorState state = connectors.get(connector); + if (state == null) + throw new IllegalArgumentException("Cannot put tasks for non-existing connector"); + + Map> taskConfigsMap = taskConfigListAsMap(connector, configs); + state.taskConfigs = taskConfigsMap; + + if (updateListener != null) + 
updateListener.onTaskConfigUpdate(taskConfigsMap.keySet()); + } + + @Override + public void refresh(long timeout, TimeUnit unit) { + } + + @Override + public synchronized void putTargetState(String connector, TargetState state) { + ConnectorState connectorState = connectors.get(connector); + if (connectorState == null) + throw new IllegalArgumentException("No connector `" + connector + "` configured"); + + connectorState.targetState = state; + + if (updateListener != null) + updateListener.onConnectorTargetStateChange(connector); + } + + @Override + public void putSessionKey(SessionKey sessionKey) { + // no-op + } + + @Override + public void putRestartRequest(RestartRequest restartRequest) { + // no-op + } + + @Override + public synchronized void setUpdateListener(UpdateListener listener) { + this.updateListener = listener; + } + + private static class ConnectorState { + private TargetState targetState; + private Map connConfig; + private Map> taskConfigs; + + public ConnectorState(Map connConfig) { + this.targetState = TargetState.STARTED; + this.connConfig = connConfig; + this.taskConfigs = new HashMap<>(); + } + } + + private static Map> taskConfigListAsMap(String connector, List> configs) { + int index = 0; + Map> result = new TreeMap<>(); + for (Map taskConfigMap: configs) { + result.put(new ConnectorTaskId(connector, index++), taskConfigMap); + } + return result; + } +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/storage/MemoryOffsetBackingStore.java b/connect/runtime/src/main/java/org/apache/kafka/connect/storage/MemoryOffsetBackingStore.java new file mode 100644 index 0000000..ceefd13 --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/storage/MemoryOffsetBackingStore.java @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.storage; + +import org.apache.kafka.connect.errors.ConnectException; +import org.apache.kafka.connect.runtime.WorkerConfig; +import org.apache.kafka.connect.util.Callback; +import org.apache.kafka.common.utils.ThreadUtils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.nio.ByteBuffer; +import java.util.Collection; +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; + +/** + * Implementation of OffsetBackingStore that doesn't actually persist any data. To ensure this + * behaves similarly to a real backing store, operations are executed asynchronously on a + * background thread. 
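A usage sketch of the in-memory store described above, assuming the connect-runtime classes are on the classpath (error handling omitted; key and value bytes are illustrative):

import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.Collections;
import java.util.Map;

import org.apache.kafka.connect.storage.MemoryOffsetBackingStore;

// Round-trip a single key/value pair through the asynchronous API.
public class MemoryOffsetBackingStoreExample {
    public static void main(String[] args) throws Exception {
        MemoryOffsetBackingStore store = new MemoryOffsetBackingStore();
        store.start();                                       // spins up the background thread
        try {
            ByteBuffer key = ByteBuffer.wrap("my-key".getBytes(StandardCharsets.UTF_8));
            ByteBuffer value = ByteBuffer.wrap("my-value".getBytes(StandardCharsets.UTF_8));

            // set() and get() both run on the store's executor and return futures.
            store.set(Collections.singletonMap(key, value), null).get();
            Map<ByteBuffer, ByteBuffer> read = store.get(Collections.singleton(key)).get();
            System.out.println(read.get(key).equals(value)); // true
        } finally {
            store.stop();                                    // shuts the executor down
        }
    }
}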
+ */ +public class MemoryOffsetBackingStore implements OffsetBackingStore { + private static final Logger log = LoggerFactory.getLogger(MemoryOffsetBackingStore.class); + + protected Map data = new HashMap<>(); + protected ExecutorService executor; + + public MemoryOffsetBackingStore() { + + } + + @Override + public void configure(WorkerConfig config) { + } + + @Override + public void start() { + executor = Executors.newFixedThreadPool(1, ThreadUtils.createThreadFactory( + this.getClass().getSimpleName() + "-%d", false)); + } + + @Override + public void stop() { + if (executor != null) { + executor.shutdown(); + // Best effort wait for any get() and set() tasks (and caller's callbacks) to complete. + try { + executor.awaitTermination(30, TimeUnit.SECONDS); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + if (!executor.shutdownNow().isEmpty()) { + throw new ConnectException("Failed to stop MemoryOffsetBackingStore. Exiting without cleanly " + + "shutting down pending tasks and/or callbacks."); + } + executor = null; + } + } + + @Override + public Future> get(final Collection keys) { + return executor.submit(() -> { + Map result = new HashMap<>(); + for (ByteBuffer key : keys) { + result.put(key, data.get(key)); + } + return result; + }); + } + + @Override + public Future set(final Map values, + final Callback callback) { + return executor.submit(() -> { + for (Map.Entry entry : values.entrySet()) { + data.put(entry.getKey(), entry.getValue()); + } + save(); + if (callback != null) + callback.onCompletion(null, null); + return null; + }); + } + + // Hook to allow subclasses to persist data + protected void save() { + + } +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/storage/MemoryStatusBackingStore.java b/connect/runtime/src/main/java/org/apache/kafka/connect/storage/MemoryStatusBackingStore.java new file mode 100644 index 0000000..fbd7048 --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/storage/MemoryStatusBackingStore.java @@ -0,0 +1,141 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.connect.storage; + +import org.apache.kafka.connect.runtime.ConnectorStatus; +import org.apache.kafka.connect.runtime.TaskStatus; +import org.apache.kafka.connect.runtime.TopicStatus; +import org.apache.kafka.connect.runtime.WorkerConfig; +import org.apache.kafka.connect.util.ConnectorTaskId; +import org.apache.kafka.connect.util.Table; + +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; +import java.util.Objects; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; + +public class MemoryStatusBackingStore implements StatusBackingStore { + private final Table tasks; + private final Map connectors; + private final ConcurrentMap> topics; + + public MemoryStatusBackingStore() { + this.tasks = new Table<>(); + this.connectors = new HashMap<>(); + this.topics = new ConcurrentHashMap<>(); + } + + @Override + public void configure(WorkerConfig config) { + + } + + @Override + public void start() { + + } + + @Override + public void stop() { + + } + + @Override + public synchronized void put(ConnectorStatus status) { + if (status.state() == ConnectorStatus.State.DESTROYED) + connectors.remove(status.id()); + else + connectors.put(status.id(), status); + } + + @Override + public synchronized void putSafe(ConnectorStatus status) { + put(status); + } + + @Override + public synchronized void put(TaskStatus status) { + if (status.state() == TaskStatus.State.DESTROYED) + tasks.remove(status.id().connector(), status.id().task()); + else + tasks.put(status.id().connector(), status.id().task(), status); + } + + @Override + public synchronized void putSafe(TaskStatus status) { + put(status); + } + + @Override + public void put(final TopicStatus status) { + topics.computeIfAbsent(status.connector(), k -> new ConcurrentHashMap<>()) + .put(status.topic(), status); + } + + @Override + public synchronized TaskStatus get(ConnectorTaskId id) { + return tasks.get(id.connector(), id.task()); + } + + @Override + public synchronized ConnectorStatus get(String connector) { + return connectors.get(connector); + } + + @Override + public synchronized Collection getAll(String connector) { + return new HashSet<>(tasks.row(connector).values()); + } + + @Override + public TopicStatus getTopic(String connector, String topic) { + ConcurrentMap activeTopics = topics.get(Objects.requireNonNull(connector)); + return activeTopics != null ? activeTopics.get(Objects.requireNonNull(topic)) : null; + } + + @Override + public Collection getAllTopics(String connector) { + ConcurrentMap activeTopics = topics.get(Objects.requireNonNull(connector)); + return activeTopics != null + ? 
Collections.unmodifiableCollection(activeTopics.values()) + : Collections.emptySet(); + } + + @Override + public void deleteTopic(String connector, String topic) { + ConcurrentMap activeTopics = topics.get(Objects.requireNonNull(connector)); + if (activeTopics != null) { + activeTopics.remove(Objects.requireNonNull(topic)); + } + } + + @Override + public synchronized Set connectors() { + return new HashSet<>(connectors.keySet()); + } + + @Override + public void flush() { + + } + +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/storage/OffsetBackingStore.java b/connect/runtime/src/main/java/org/apache/kafka/connect/storage/OffsetBackingStore.java new file mode 100644 index 0000000..1e4375b --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/storage/OffsetBackingStore.java @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.storage; + +import org.apache.kafka.connect.runtime.WorkerConfig; +import org.apache.kafka.connect.util.Callback; + +import java.nio.ByteBuffer; +import java.util.Collection; +import java.util.Map; +import java.util.concurrent.Future; + +/** + *

            + * OffsetBackingStore is an interface for storage backends that store key-value data. The backing + * store doesn't need to handle serialization or deserialization. It only needs to support + * reading/writing bytes. Since it is expected these operations will require network + * operations, only bulk operations are supported. + *

            + *


            + * Since OffsetBackingStore is a shared resource that may be used by many OffsetStorage instances + * that are associated with individual tasks, the caller must be sure keys include information about the + * connector so that the shared namespace does not result in conflicting keys. + *

            + */ +public interface OffsetBackingStore { + + /** + * Start this offset store. + */ + void start(); + + /** + * Stop the backing store. Implementations should attempt to shutdown gracefully, but not block + * indefinitely. + */ + void stop(); + + /** + * Get the values for the specified keys + * @param keys list of keys to look up + * @return future for the resulting map from key to value + */ + Future> get(Collection keys); + + /** + * Set the specified keys and values. + * @param values map from key to value + * @param callback callback to invoke on completion + * @return void future for the operation + */ + Future set(Map values, Callback callback); + + /** + * Configure class with the given key-value pairs + * @param config can be DistributedConfig or StandaloneConfig + */ + void configure(WorkerConfig config); +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/storage/OffsetStorageReaderImpl.java b/connect/runtime/src/main/java/org/apache/kafka/connect/storage/OffsetStorageReaderImpl.java new file mode 100644 index 0000000..a1eea43 --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/storage/OffsetStorageReaderImpl.java @@ -0,0 +1,158 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.storage; + +import org.apache.kafka.connect.data.SchemaAndValue; +import org.apache.kafka.connect.errors.ConnectException; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.nio.ByteBuffer; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.CancellationException; +import java.util.concurrent.Future; +import java.util.concurrent.atomic.AtomicBoolean; + +/** + * Implementation of OffsetStorageReader. Unlike OffsetStorageWriter which is implemented + * directly, the interface is only separate from this implementation because it needs to be + * included in the public API package. 
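The interface above leaves key namespacing to the caller; in the reader and writer further below, each partition map is wrapped as [namespace, partition] before it is handed to the key converter. A hedged sketch of what that produces with the JSON converter (converter choice, namespace, and partition contents are illustrative):

import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;

import org.apache.kafka.connect.json.JsonConverter;

// Shows the [namespace, partition] key wrapping used by OffsetStorageReaderImpl and
// OffsetStorageWriter. The JsonConverter is just one possible internal converter.
public class OffsetKeyExample {
    public static void main(String[] args) {
        JsonConverter converter = new JsonConverter();
        // Internal converters are typically configured without embedded schemas.
        converter.configure(Collections.singletonMap("schemas.enable", "false"), true);

        Map<String, Object> partition = new HashMap<>();
        partition.put("filename", "/var/log/app.log");     // illustrative source partition

        String namespace = "my-source-connector";           // usually the connector name
        byte[] key = converter.fromConnectData(namespace, null, Arrays.asList(namespace, partition));
        System.out.println(new String(key, StandardCharsets.UTF_8));
        // e.g. ["my-source-connector",{"filename":"/var/log/app.log"}]
    }
}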
+ */ +public class OffsetStorageReaderImpl implements CloseableOffsetStorageReader { + private static final Logger log = LoggerFactory.getLogger(OffsetStorageReaderImpl.class); + + private final OffsetBackingStore backingStore; + private final String namespace; + private final Converter keyConverter; + private final Converter valueConverter; + private final AtomicBoolean closed; + private final Set>> offsetReadFutures; + + public OffsetStorageReaderImpl(OffsetBackingStore backingStore, String namespace, + Converter keyConverter, Converter valueConverter) { + this.backingStore = backingStore; + this.namespace = namespace; + this.keyConverter = keyConverter; + this.valueConverter = valueConverter; + this.closed = new AtomicBoolean(false); + this.offsetReadFutures = new HashSet<>(); + } + + @Override + public Map offset(Map partition) { + return offsets(Collections.singletonList(partition)).get(partition); + } + + @Override + @SuppressWarnings("unchecked") + public Map, Map> offsets(Collection> partitions) { + // Serialize keys so backing store can work with them + Map> serializedToOriginal = new HashMap<>(partitions.size()); + for (Map key : partitions) { + try { + // Offsets are treated as schemaless, their format is only validated here (and the returned value below) + OffsetUtils.validateFormat(key); + byte[] keySerialized = keyConverter.fromConnectData(namespace, null, Arrays.asList(namespace, key)); + ByteBuffer keyBuffer = (keySerialized != null) ? ByteBuffer.wrap(keySerialized) : null; + serializedToOriginal.put(keyBuffer, key); + } catch (Throwable t) { + log.error("CRITICAL: Failed to serialize partition key when getting offsets for task with " + + "namespace {}. No value for this data will be returned, which may break the " + + "task or cause it to skip some data.", namespace, t); + } + } + + // Get serialized key -> serialized value from backing store + Map raw; + try { + Future> offsetReadFuture; + synchronized (offsetReadFutures) { + if (closed.get()) { + throw new ConnectException( + "Offset reader is closed. This is likely because the task has already been " + + "scheduled to stop but has taken longer than the graceful shutdown " + + "period to do so."); + } + offsetReadFuture = backingStore.get(serializedToOriginal.keySet()); + offsetReadFutures.add(offsetReadFuture); + } + + try { + raw = offsetReadFuture.get(); + } catch (CancellationException e) { + throw new ConnectException( + "Offset reader closed while attempting to read offsets. 
This is likely because " + + "the task was been scheduled to stop but has taken longer than the " + + "graceful shutdown period to do so."); + } finally { + synchronized (offsetReadFutures) { + offsetReadFutures.remove(offsetReadFuture); + } + } + } catch (Exception e) { + log.error("Failed to fetch offsets from namespace {}: ", namespace, e); + throw new ConnectException("Failed to fetch offsets.", e); + } + + // Deserialize all the values and map back to the original keys + Map, Map> result = new HashMap<>(partitions.size()); + for (Map.Entry rawEntry : raw.entrySet()) { + try { + // Since null could be a valid key, explicitly check whether map contains the key + if (!serializedToOriginal.containsKey(rawEntry.getKey())) { + log.error("Should be able to map {} back to a requested partition-offset key, backing " + + "store may have returned invalid data", rawEntry.getKey()); + continue; + } + Map origKey = serializedToOriginal.get(rawEntry.getKey()); + SchemaAndValue deserializedSchemaAndValue = valueConverter.toConnectData(namespace, rawEntry.getValue() != null ? rawEntry.getValue().array() : null); + Object deserializedValue = deserializedSchemaAndValue.value(); + OffsetUtils.validateFormat(deserializedValue); + + result.put(origKey, (Map) deserializedValue); + } catch (Throwable t) { + log.error("CRITICAL: Failed to deserialize offset data when getting offsets for task with" + + " namespace {}. No value for this data will be returned, which may break the " + + "task or cause it to skip some data. This could either be due to an error in " + + "the connector implementation or incompatible schema.", namespace, t); + } + } + + return result; + } + + public void close() { + if (!closed.getAndSet(true)) { + synchronized (offsetReadFutures) { + for (Future> offsetReadFuture : offsetReadFutures) { + try { + offsetReadFuture.cancel(true); + } catch (Throwable t) { + log.error("Failed to cancel offset read future", t); + } + } + offsetReadFutures.clear(); + } + } + } +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/storage/OffsetStorageWriter.java b/connect/runtime/src/main/java/org/apache/kafka/connect/storage/OffsetStorageWriter.java new file mode 100644 index 0000000..7766e2c --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/storage/OffsetStorageWriter.java @@ -0,0 +1,211 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.storage; + +import org.apache.kafka.connect.errors.ConnectException; +import org.apache.kafka.connect.util.Callback; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.nio.ByteBuffer; +import java.util.Arrays; +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.Future; + +/** + *

            + * OffsetStorageWriter is a buffered writer that wraps the simple OffsetBackingStore interface. + * It maintains a copy of the key-value data in memory and buffers writes. It allows you to take + * a snapshot, which can then be asynchronously flushed to the backing store while new writes + * continue to be processed. This allows Kafka Connect to process offset commits in the background + * while continuing to process messages. + *

            + * Connect uses an OffsetStorage implementation to save state about the current progress of + * source (import to Kafka) jobs, which may have many input partitions and "offsets" may not be as + * simple as they are for Kafka partitions or files. Offset storage is not required for sink jobs + * because they can use Kafka's native offset storage (or the sink data store can handle offset + * storage to achieve exactly once semantics). + *

            + * Both partitions and offsets are generic data objects. This allows different connectors to use + * whatever representation they need, even arbitrarily complex records. These are translated + * internally into the serialized form the OffsetBackingStore uses. + *

            + * Note that this only provides write functionality. This is intentional to ensure stale data is + * never read. Offset data should only be read during startup or reconfiguration of a task. By + * always serving those requests by reading the values from the backing store, we ensure we never + * accidentally use stale data. (One example of how this can occur: a task is processing input + * partition A, writing offsets; reconfiguration causes partition A to be reassigned elsewhere; + * reconfiguration causes partition A to be reassigned to this node, but now the offset data is out + * of date). Since these offsets are created and managed by the connector itself, there's no way + * for the offset management layer to know which keys are "owned" by which tasks at any given + * time. + *

+ * This class is thread-safe.
            + */ +public class OffsetStorageWriter { + private static final Logger log = LoggerFactory.getLogger(OffsetStorageWriter.class); + + private final OffsetBackingStore backingStore; + private final Converter keyConverter; + private final Converter valueConverter; + private final String namespace; + // Offset data in Connect format + private Map, Map> data = new HashMap<>(); + + private Map, Map> toFlush = null; + // Unique ID for each flush request to handle callbacks after timeouts + private long currentFlushId = 0; + + public OffsetStorageWriter(OffsetBackingStore backingStore, + String namespace, Converter keyConverter, Converter valueConverter) { + this.backingStore = backingStore; + this.namespace = namespace; + this.keyConverter = keyConverter; + this.valueConverter = valueConverter; + } + + /** + * Set an offset for a partition using Connect data values + * @param partition the partition to store an offset for + * @param offset the offset + */ + public synchronized void offset(Map partition, Map offset) { + data.put(partition, offset); + } + + private boolean flushing() { + return toFlush != null; + } + + /** + * Performs the first step of a flush operation, snapshotting the current state. This does not + * actually initiate the flush with the underlying storage. + * + * @return true if a flush was initiated, false if no data was available + */ + public synchronized boolean beginFlush() { + if (flushing()) { + log.error("Invalid call to OffsetStorageWriter flush() while already flushing, the " + + "framework should not allow this"); + throw new ConnectException("OffsetStorageWriter is already flushing"); + } + + if (data.isEmpty()) + return false; + + assert !flushing(); + toFlush = data; + data = new HashMap<>(); + return true; + } + + /** + * Flush the current offsets and clear them from this writer. This is non-blocking: it + * moves the current set of offsets out of the way, serializes the data, and asynchronously + * writes the data to the backing store. If no offsets need to be written, the callback is + * still invoked, but no Future is returned. + * + * @return a Future, or null if there are no offsets to commitOffsets + */ + public Future doFlush(final Callback callback) { + + final long flushId; + // Serialize + final Map offsetsSerialized; + + synchronized (this) { + flushId = currentFlushId; + + try { + offsetsSerialized = new HashMap<>(toFlush.size()); + for (Map.Entry, Map> entry : toFlush.entrySet()) { + // Offsets are specified as schemaless to the converter, using whatever internal schema is appropriate + // for that data. The only enforcement of the format is here. + OffsetUtils.validateFormat(entry.getKey()); + OffsetUtils.validateFormat(entry.getValue()); + // When serializing the key, we add in the namespace information so the key is [namespace, real key] + byte[] key = keyConverter.fromConnectData(namespace, null, Arrays.asList(namespace, entry.getKey())); + ByteBuffer keyBuffer = (key != null) ? ByteBuffer.wrap(key) : null; + byte[] value = valueConverter.fromConnectData(namespace, null, entry.getValue()); + ByteBuffer valueBuffer = (value != null) ? ByteBuffer.wrap(value) : null; + offsetsSerialized.put(keyBuffer, valueBuffer); + } + } catch (Throwable t) { + // Must handle errors properly here or the writer will be left mid-flush forever and be + // unable to make progress. + log.error("CRITICAL: Failed to serialize offset data, making it impossible to commit " + + "offsets under namespace {}. 
This likely won't recover unless the " + + "unserializable partition or offset information is overwritten.", namespace); + log.error("Cause of serialization failure:", t); + callback.onCompletion(t, null); + return null; + } + + // And submit the data + log.debug("Submitting {} entries to backing store. The offsets are: {}", offsetsSerialized.size(), toFlush); + } + + return backingStore.set(offsetsSerialized, (error, result) -> { + boolean isCurrent = handleFinishWrite(flushId, error, result); + if (isCurrent && callback != null) { + callback.onCompletion(error, result); + } + }); + } + + /** + * Cancel a flush that has been initiated by {@link #beginFlush}. This should not be called if + * {@link #doFlush} has already been invoked. It should be used if an operation performed + * between beginFlush and doFlush failed. + */ + public synchronized void cancelFlush() { + // Verify we're still flushing data to handle a race between cancelFlush() calls from up the + // call stack and callbacks from the write request to underlying storage + if (flushing()) { + // Just recombine the data and place it back in the primary storage + toFlush.putAll(data); + data = toFlush; + currentFlushId++; + toFlush = null; + } + } + + /** + * Handle completion of a write. Returns true if this callback is for the current flush + * operation, false if it's for an old operation that should now be ignored. + */ + private synchronized boolean handleFinishWrite(long flushId, Throwable error, Void result) { + // Callbacks need to be handled carefully since the flush operation may have already timed + // out and been cancelled. + if (flushId != currentFlushId) + return false; + + if (error != null) { + cancelFlush(); + } else { + currentFlushId++; + toFlush = null; + } + return true; + } +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/storage/OffsetUtils.java b/connect/runtime/src/main/java/org/apache/kafka/connect/storage/OffsetUtils.java new file mode 100644 index 0000000..cfae221 --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/storage/OffsetUtils.java @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
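A hedged sketch of the offset(...) / beginFlush() / doFlush(...) cycle described above, as a task's offset committer might run it (the store and converters are assumed to be wired up elsewhere, for example by the worker; names and values are illustrative, and per OffsetUtils below the partition and offset maps must use String keys and primitive values):

import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;

import org.apache.kafka.connect.storage.Converter;
import org.apache.kafka.connect.storage.OffsetBackingStore;
import org.apache.kafka.connect.storage.OffsetStorageWriter;

// Sketch of a single commit cycle against an already-constructed backing store.
final class OffsetCommitSketch {
    static void commitOnce(OffsetBackingStore store, Converter keyConverter, Converter valueConverter)
            throws Exception {
        OffsetStorageWriter writer =
                new OffsetStorageWriter(store, "my-source-connector", keyConverter, valueConverter);

        Map<String, Object> partition = new HashMap<>();
        partition.put("filename", "/var/log/app.log");      // String keys, primitive values only
        Map<String, Object> offset = new HashMap<>();
        offset.put("position", 1024L);

        writer.offset(partition, offset);                    // buffer the latest offset

        if (writer.beginFlush()) {                           // snapshot the buffered offsets
            Future<Void> flush = writer.doFlush((error, result) -> {
                if (error != null) {
                    // On failure the writer rolls the snapshot back via cancelFlush()
                    System.err.println("Offset flush failed: " + error);
                }
            });
            if (flush != null) {
                flush.get(5, TimeUnit.SECONDS);              // wait for the backing store write
            }
        }
    }
}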
+ */ +package org.apache.kafka.connect.storage; + +import org.apache.kafka.connect.data.ConnectSchema; +import org.apache.kafka.connect.data.Schema; +import org.apache.kafka.connect.errors.DataException; + +import java.util.Map; + +public class OffsetUtils { + @SuppressWarnings("unchecked") + public static void validateFormat(Object offsetData) { + if (offsetData == null) + return; + + if (!(offsetData instanceof Map)) + throw new DataException("Offsets must be specified as a Map"); + validateFormat((Map) offsetData); + } + + public static void validateFormat(Map offsetData) { + // Both keys and values for offsets may be null. For values, this is a useful way to delete offsets or indicate + // that there's not usable concept of offsets in your source system. + if (offsetData == null) + return; + + for (Map.Entry entry : offsetData.entrySet()) { + if (!(entry.getKey() instanceof String)) + throw new DataException("Offsets may only use String keys"); + + Object value = entry.getValue(); + if (value == null) + continue; + Schema.Type schemaType = ConnectSchema.schemaType(value.getClass()); + if (schemaType == null) + throw new DataException("Offsets may only contain primitive types as values, but field " + entry.getKey() + " contains " + value.getClass()); + if (!schemaType.isPrimitive()) + throw new DataException("Offsets may only contain primitive types as values, but field " + entry.getKey() + " contains " + schemaType); + } + } +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/storage/StatusBackingStore.java b/connect/runtime/src/main/java/org/apache/kafka/connect/storage/StatusBackingStore.java new file mode 100644 index 0000000..0250932 --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/storage/StatusBackingStore.java @@ -0,0 +1,135 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.storage; + +import org.apache.kafka.connect.runtime.ConnectorStatus; +import org.apache.kafka.connect.runtime.TaskStatus; +import org.apache.kafka.connect.runtime.TopicStatus; +import org.apache.kafka.connect.runtime.WorkerConfig; +import org.apache.kafka.connect.util.ConnectorTaskId; + +import java.util.Collection; +import java.util.Set; + +public interface StatusBackingStore { + + /** + * Start dependent services (if needed) + */ + void start(); + + /** + * Stop dependent services (if needed) + */ + void stop(); + + /** + * Set the state of the connector to the given value. + * @param status the status of the connector + */ + void put(ConnectorStatus status); + + /** + * Safely set the state of the connector to the given value. 
What is + * considered "safe" depends on the implementation, but basically it + * means that the store can provide higher assurance that another worker + * hasn't concurrently written any conflicting data. + * @param status the status of the connector + */ + void putSafe(ConnectorStatus status); + + /** + * Set the state of the connector to the given value. + * @param status the status of the task + */ + void put(TaskStatus status); + + /** + * Safely set the state of the task to the given value. What is + * considered "safe" depends on the implementation, but basically it + * means that the store can provide higher assurance that another worker + * hasn't concurrently written any conflicting data. + * @param status the status of the task + */ + void putSafe(TaskStatus status); + + /** + * Set the state of a connector's topic to the given value. + * @param status the status of the topic used by a connector + */ + void put(TopicStatus status); + + /** + * Get the current state of the task. + * @param id the id of the task + * @return the state or null if there is none + */ + TaskStatus get(ConnectorTaskId id); + + /** + * Get the current state of the connector. + * @param connector the connector name + * @return the state or null if there is none + */ + ConnectorStatus get(String connector); + + /** + * Get the states of all tasks for the given connector. + * @param connector the connector name + * @return a map from task ids to their respective status + */ + Collection getAll(String connector); + + /** + * Get the status of a connector's topic if the connector is actively using this topic + * @param connector the connector name; never null + * @param topic the topic name; never null + * @return the state or null if there is none + */ + TopicStatus getTopic(String connector, String topic); + + /** + * Get the states of all topics that a connector is using. + * @param connector the connector name; never null + * @return a collection of topic states or an empty collection if there is none + */ + Collection getAllTopics(String connector); + + /** + * Delete this topic from the connector's set of active topics + * @param connector the connector name; never null + * @param topic the topic name; never null + */ + void deleteTopic(String connector, String topic); + + /** + * Get all cached connectors. + * @return the set of connector names + */ + Set connectors(); + + /** + * Flush any pending writes + */ + void flush(); + + /** + * Configure class with the given key-value pairs + * @param config config for StatusBackingStore + */ + void configure(WorkerConfig config); +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/tools/MockConnector.java b/connect/runtime/src/main/java/org/apache/kafka/connect/tools/MockConnector.java new file mode 100644 index 0000000..c7d24d9 --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/tools/MockConnector.java @@ -0,0 +1,115 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.tools; + +import org.apache.kafka.common.config.ConfigDef; +import org.apache.kafka.common.utils.AppInfoParser; +import org.apache.kafka.connect.connector.Connector; +import org.apache.kafka.connect.connector.Task; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.concurrent.Executors; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.TimeUnit; + +/** + * This connector provides support for mocking certain connector behaviors. For example, + * this can be used to simulate connector or task failures. It works by passing a "mock mode" + * through configuration from the system test. New mock behavior can be implemented either + * in the connector or in the task by providing a new mode implementation. + * + * At the moment, this connector only supports a single task and shares configuration between + * the connector and its tasks. + * + * @see MockSinkConnector + * @see MockSourceConnector + */ +public class MockConnector extends Connector { + public static final String MOCK_MODE_KEY = "mock_mode"; + public static final String DELAY_MS_KEY = "delay_ms"; + + public static final String CONNECTOR_FAILURE = "connector-failure"; + public static final String TASK_FAILURE = "task-failure"; + + public static final long DEFAULT_FAILURE_DELAY_MS = 15000; + + private static final Logger log = LoggerFactory.getLogger(MockConnector.class); + + private Map config; + private ScheduledExecutorService executor; + + @Override + public String version() { + return AppInfoParser.getVersion(); + } + + @Override + public void start(Map config) { + this.config = config; + + if (CONNECTOR_FAILURE.equals(config.get(MOCK_MODE_KEY))) { + // Schedule this connector to raise an exception after some delay + + String delayMsString = config.get(DELAY_MS_KEY); + long delayMs = DEFAULT_FAILURE_DELAY_MS; + if (delayMsString != null) + delayMs = Long.parseLong(delayMsString); + + log.debug("Started MockConnector with failure delay of {} ms", delayMs); + executor = Executors.newSingleThreadScheduledExecutor(); + executor.schedule(() -> { + log.debug("Triggering connector failure"); + context.raiseError(new RuntimeException()); + }, delayMs, TimeUnit.MILLISECONDS); + } + } + + @Override + public Class taskClass() { + throw new UnsupportedOperationException(); + } + + @Override + public List> taskConfigs(int maxTasks) { + log.debug("Creating single task for MockConnector"); + return Collections.singletonList(config); + } + + @Override + public void stop() { + if (executor != null) { + executor.shutdownNow(); + + try { + if (!executor.awaitTermination(20, TimeUnit.SECONDS)) + throw new RuntimeException("Failed timely termination of scheduler"); + } catch (InterruptedException e) { + throw new RuntimeException("Task was interrupted during shutdown"); + } + } + } + + @Override + public ConfigDef config() { + return new ConfigDef(); + } + +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/tools/MockSinkConnector.java 
b/connect/runtime/src/main/java/org/apache/kafka/connect/tools/MockSinkConnector.java new file mode 100644 index 0000000..2550e51 --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/tools/MockSinkConnector.java @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.tools; + +import org.apache.kafka.common.config.Config; +import org.apache.kafka.common.config.ConfigDef; +import org.apache.kafka.connect.connector.ConnectorContext; +import org.apache.kafka.connect.connector.Task; +import org.apache.kafka.connect.sink.SinkConnector; + +import java.util.List; +import java.util.Map; + +/** + * Mock sink implementation which delegates to {@link MockConnector}. + */ +public class MockSinkConnector extends SinkConnector { + + private MockConnector delegate = new MockConnector(); + + @Override + public void initialize(ConnectorContext ctx) { + delegate.initialize(ctx); + } + + @Override + public void initialize(ConnectorContext ctx, List> taskConfigs) { + delegate.initialize(ctx, taskConfigs); + } + + @Override + public void reconfigure(Map props) { + delegate.reconfigure(props); + } + + @Override + public Config validate(Map connectorConfigs) { + return delegate.validate(connectorConfigs); + } + + @Override + public String version() { + return delegate.version(); + } + + @Override + public void start(Map props) { + delegate.start(props); + } + + @Override + public Class taskClass() { + return MockSinkTask.class; + } + + @Override + public List> taskConfigs(int maxTasks) { + return delegate.taskConfigs(maxTasks); + } + + @Override + public void stop() { + delegate.stop(); + } + + @Override + public ConfigDef config() { + return delegate.config(); + } +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/tools/MockSinkTask.java b/connect/runtime/src/main/java/org/apache/kafka/connect/tools/MockSinkTask.java new file mode 100644 index 0000000..f48bd31 --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/tools/MockSinkTask.java @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.tools; + +import org.apache.kafka.clients.consumer.OffsetAndMetadata; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.utils.AppInfoParser; +import org.apache.kafka.connect.sink.SinkRecord; +import org.apache.kafka.connect.sink.SinkTask; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Collection; +import java.util.Map; + +public class MockSinkTask extends SinkTask { + private static final Logger log = LoggerFactory.getLogger(MockSinkTask.class); + + private String mockMode; + private long startTimeMs; + private long failureDelayMs; + + @Override + public String version() { + return AppInfoParser.getVersion(); + } + + @Override + public void start(Map config) { + this.mockMode = config.get(MockConnector.MOCK_MODE_KEY); + + if (MockConnector.TASK_FAILURE.equals(mockMode)) { + this.startTimeMs = System.currentTimeMillis(); + + String delayMsString = config.get(MockConnector.DELAY_MS_KEY); + this.failureDelayMs = MockConnector.DEFAULT_FAILURE_DELAY_MS; + if (delayMsString != null) + failureDelayMs = Long.parseLong(delayMsString); + + log.debug("Started MockSinkTask at {} with failure scheduled in {} ms", startTimeMs, failureDelayMs); + setTimeout(); + } + } + + @Override + public void put(Collection records) { + if (MockConnector.TASK_FAILURE.equals(mockMode)) { + long now = System.currentTimeMillis(); + if (now - startTimeMs > failureDelayMs) { + log.debug("Triggering sink task failure"); + throw new RuntimeException(); + } + setTimeout(); + } + } + + @Override + public void flush(Map offsets) { + + } + + @Override + public void stop() { + + } + + private void setTimeout() { + // Set a reasonable minimum delay. Since this mock task may not actually consume any data from Kafka, it may only + // see put() calls triggered by wakeups for offset commits. To make sure we aren't tied to the offset commit + // interval, we force a wakeup every 250ms or after the failure delay, whichever is smaller. This is not overly + // aggressive but ensures any scheduled tasks this connector performs are reasonably close to the target time. + context.timeout(Math.min(failureDelayMs, 250)); + } +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/tools/MockSourceConnector.java b/connect/runtime/src/main/java/org/apache/kafka/connect/tools/MockSourceConnector.java new file mode 100644 index 0000000..90868d8 --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/tools/MockSourceConnector.java @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.connect.tools; + +import org.apache.kafka.common.config.Config; +import org.apache.kafka.common.config.ConfigDef; +import org.apache.kafka.connect.connector.ConnectorContext; +import org.apache.kafka.connect.connector.Task; +import org.apache.kafka.connect.source.SourceConnector; + +import java.util.List; +import java.util.Map; + +/** + * Mock source implementation which delegates to {@link MockConnector}. + */ +public class MockSourceConnector extends SourceConnector { + + private MockConnector delegate = new MockConnector(); + + @Override + public void initialize(ConnectorContext ctx) { + delegate.initialize(ctx); + } + + @Override + public void initialize(ConnectorContext ctx, List> taskConfigs) { + delegate.initialize(ctx, taskConfigs); + } + + @Override + public void reconfigure(Map props) { + delegate.reconfigure(props); + } + + @Override + public Config validate(Map connectorConfigs) { + return delegate.validate(connectorConfigs); + } + + @Override + public String version() { + return delegate.version(); + } + + @Override + public void start(Map props) { + delegate.start(props); + } + + @Override + public Class taskClass() { + return MockSourceTask.class; + } + + @Override + public List> taskConfigs(int maxTasks) { + return delegate.taskConfigs(maxTasks); + } + + @Override + public void stop() { + delegate.stop(); + } + + @Override + public ConfigDef config() { + return delegate.config(); + } +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/tools/MockSourceTask.java b/connect/runtime/src/main/java/org/apache/kafka/connect/tools/MockSourceTask.java new file mode 100644 index 0000000..4decf03 --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/tools/MockSourceTask.java @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.connect.tools; + +import org.apache.kafka.common.utils.AppInfoParser; +import org.apache.kafka.connect.source.SourceRecord; +import org.apache.kafka.connect.source.SourceTask; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Collections; +import java.util.List; +import java.util.Map; + +public class MockSourceTask extends SourceTask { + private static final Logger log = LoggerFactory.getLogger(MockSourceTask.class); + + private String mockMode; + private long startTimeMs; + private long failureDelayMs; + + @Override + public String version() { + return AppInfoParser.getVersion(); + } + + @Override + public void start(Map config) { + this.mockMode = config.get(MockConnector.MOCK_MODE_KEY); + + if (MockConnector.TASK_FAILURE.equals(mockMode)) { + this.startTimeMs = System.currentTimeMillis(); + + String delayMsString = config.get(MockConnector.DELAY_MS_KEY); + this.failureDelayMs = MockConnector.DEFAULT_FAILURE_DELAY_MS; + if (delayMsString != null) + failureDelayMs = Long.parseLong(delayMsString); + + log.debug("Started MockSourceTask at {} with failure scheduled in {} ms", startTimeMs, failureDelayMs); + } + } + + @Override + public List poll() throws InterruptedException { + if (MockConnector.TASK_FAILURE.equals(mockMode)) { + long now = System.currentTimeMillis(); + if (now - startTimeMs > failureDelayMs) { + log.debug("Triggering source task failure"); + throw new RuntimeException(); + } + } + return Collections.emptyList(); + } + + @Override + public void stop() { + + } +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/tools/PredicateDoc.java b/connect/runtime/src/main/java/org/apache/kafka/connect/tools/PredicateDoc.java new file mode 100644 index 0000000..d4399d6 --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/tools/PredicateDoc.java @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.tools; + +import org.apache.kafka.common.config.ConfigDef; +import org.apache.kafka.connect.runtime.isolation.Plugins; +import org.apache.kafka.connect.transforms.predicates.Predicate; + +import java.io.PrintStream; +import java.util.Collections; +import java.util.Comparator; +import java.util.List; +import java.util.stream.Collectors; + +public class PredicateDoc { + + private static final class DocInfo { + final String predicateName; + final String overview; + final ConfigDef configDef; + + private

<P extends Predicate<?>> DocInfo(Class<P>
            predicateClass, String overview, ConfigDef configDef) { + this.predicateName = predicateClass.getName(); + this.overview = overview; + this.configDef = configDef; + } + } + + private static final List PREDICATES; + static { + List collect = new Plugins(Collections.emptyMap()).predicates().stream() + .map(p -> { + try { + String overviewDoc = (String) p.pluginClass().getDeclaredField("OVERVIEW_DOC").get(null); + ConfigDef configDef = (ConfigDef) p.pluginClass().getDeclaredField("CONFIG_DEF").get(null); + return new DocInfo(p.pluginClass(), overviewDoc, configDef); + } catch (ReflectiveOperationException e) { + throw new RuntimeException("Predicate class " + p.pluginClass().getName() + " lacks either a `public static final String OVERVIEW_DOC` or `public static final ConfigDef CONFIG_DEF`"); + } + }) + .collect(Collectors.toList()); + collect.sort(Comparator.comparing(docInfo -> docInfo.predicateName)); + PREDICATES = collect; + } + + private static void printPredicateHtml(PrintStream out, DocInfo docInfo) { + out.println("
<div id=\"" + docInfo.predicateName + "\">"); + + out.print("<h5>"); + out.print(docInfo.predicateName); + out.println("</h5>"); + + out.println(docInfo.overview); + + out.println("<p/>"); + + out.println(docInfo.configDef.toHtml(6, key -> docInfo.predicateName + "_" + key)); + + out.println("</div>
            "); + } + + private static void printHtml(PrintStream out) { + for (final DocInfo docInfo : PREDICATES) { + printPredicateHtml(out, docInfo); + } + } + + public static void main(String... args) { + printHtml(System.out); + } +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/tools/SchemaSourceConnector.java b/connect/runtime/src/main/java/org/apache/kafka/connect/tools/SchemaSourceConnector.java new file mode 100644 index 0000000..06379d5 --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/tools/SchemaSourceConnector.java @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.tools; + +import org.apache.kafka.common.config.ConfigDef; +import org.apache.kafka.common.utils.AppInfoParser; +import org.apache.kafka.connect.connector.Task; +import org.apache.kafka.connect.source.SourceConnector; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +public class SchemaSourceConnector extends SourceConnector { + + private Map config; + + @Override + public String version() { + return AppInfoParser.getVersion(); + } + + @Override + public void start(Map props) { + this.config = props; + } + + @Override + public Class taskClass() { + return SchemaSourceTask.class; + } + + @Override + public List> taskConfigs(int maxTasks) { + ArrayList> configs = new ArrayList<>(); + for (int i = 0; i < maxTasks; i++) { + Map props = new HashMap<>(config); + props.put(SchemaSourceTask.ID_CONFIG, String.valueOf(i)); + configs.add(props); + } + return configs; + } + + @Override + public void stop() { + } + + @Override + public ConfigDef config() { + return new ConfigDef(); + } +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/tools/SchemaSourceTask.java b/connect/runtime/src/main/java/org/apache/kafka/connect/tools/SchemaSourceTask.java new file mode 100644 index 0000000..6fde784 --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/tools/SchemaSourceTask.java @@ -0,0 +1,169 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.tools; + +import org.apache.kafka.connect.data.Schema; +import org.apache.kafka.connect.data.SchemaBuilder; +import org.apache.kafka.connect.data.Struct; +import org.apache.kafka.connect.errors.ConnectException; +import org.apache.kafka.connect.source.SourceRecord; +import org.apache.kafka.connect.source.SourceTask; +import org.apache.kafka.tools.ThroughputThrottler; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Collections; +import java.util.List; +import java.util.Map; + +public class SchemaSourceTask extends SourceTask { + + private static final Logger log = LoggerFactory.getLogger(SchemaSourceTask.class); + + public static final String NAME_CONFIG = "name"; + public static final String ID_CONFIG = "id"; + public static final String TOPIC_CONFIG = "topic"; + public static final String NUM_MSGS_CONFIG = "num.messages"; + public static final String THROUGHPUT_CONFIG = "throughput"; + public static final String MULTIPLE_SCHEMA_CONFIG = "multiple.schema"; + public static final String PARTITION_COUNT_CONFIG = "partition.count"; + + private static final String ID_FIELD = "id"; + private static final String SEQNO_FIELD = "seqno"; + private ThroughputThrottler throttler; + + private String name; // Connector name + private int id; // Task ID + private String topic; + private Map partition; + private long startingSeqno; + private long seqno; + private long count; + private long maxNumMsgs; + private boolean multipleSchema; + private int partitionCount; + + private static Schema valueSchema = SchemaBuilder.struct().version(1).name("record") + .field("boolean", Schema.BOOLEAN_SCHEMA) + .field("int", Schema.INT32_SCHEMA) + .field("long", Schema.INT64_SCHEMA) + .field("float", Schema.FLOAT32_SCHEMA) + .field("double", Schema.FLOAT64_SCHEMA) + .field("partitioning", Schema.INT32_SCHEMA) + .field("id", Schema.INT32_SCHEMA) + .field("seqno", Schema.INT64_SCHEMA) + .build(); + + private static Schema valueSchema2 = SchemaBuilder.struct().version(2).name("record") + .field("boolean", Schema.BOOLEAN_SCHEMA) + .field("int", Schema.INT32_SCHEMA) + .field("long", Schema.INT64_SCHEMA) + .field("float", Schema.FLOAT32_SCHEMA) + .field("double", Schema.FLOAT64_SCHEMA) + .field("partitioning", Schema.INT32_SCHEMA) + .field("string", SchemaBuilder.string().defaultValue("abc").build()) + .field("id", Schema.INT32_SCHEMA) + .field("seqno", Schema.INT64_SCHEMA) + .build(); + + @Override + public String version() { + return new SchemaSourceConnector().version(); + } + + @Override + public void start(Map props) { + final long throughput; + try { + name = props.get(NAME_CONFIG); + id = Integer.parseInt(props.get(ID_CONFIG)); + topic = props.get(TOPIC_CONFIG); + maxNumMsgs = Long.parseLong(props.get(NUM_MSGS_CONFIG)); + multipleSchema = Boolean.parseBoolean(props.get(MULTIPLE_SCHEMA_CONFIG)); + partitionCount = Integer.parseInt(props.containsKey(PARTITION_COUNT_CONFIG) ? 
props.get(PARTITION_COUNT_CONFIG) : "1"); + throughput = Long.parseLong(props.get(THROUGHPUT_CONFIG)); + } catch (NumberFormatException e) { + throw new ConnectException("Invalid SchemaSourceTask configuration", e); + } + + throttler = new ThroughputThrottler(throughput, System.currentTimeMillis()); + partition = Collections.singletonMap(ID_FIELD, id); + Map previousOffset = this.context.offsetStorageReader().offset(partition); + if (previousOffset != null) { + seqno = (Long) previousOffset.get(SEQNO_FIELD) + 1; + } else { + seqno = 0; + } + startingSeqno = seqno; + count = 0; + log.info("Started SchemaSourceTask {}-{} producing to topic {} resuming from seqno {}", name, id, topic, startingSeqno); + } + + @Override + public List poll() throws InterruptedException { + if (count < maxNumMsgs) { + long sendStartMs = System.currentTimeMillis(); + if (throttler.shouldThrottle(seqno - startingSeqno, sendStartMs)) { + throttler.throttle(); + } + + Map ccOffset = Collections.singletonMap(SEQNO_FIELD, seqno); + int partitionVal = (int) (seqno % partitionCount); + final Struct data; + final SourceRecord srcRecord; + if (!multipleSchema || count % 2 == 0) { + data = new Struct(valueSchema) + .put("boolean", true) + .put("int", 12) + .put("long", 12L) + .put("float", 12.2f) + .put("double", 12.2) + .put("partitioning", partitionVal) + .put("id", id) + .put("seqno", seqno); + + srcRecord = new SourceRecord(partition, ccOffset, topic, id, Schema.STRING_SCHEMA, "key", valueSchema, data); + } else { + data = new Struct(valueSchema2) + .put("boolean", true) + .put("int", 12) + .put("long", 12L) + .put("float", 12.2f) + .put("double", 12.2) + .put("partitioning", partitionVal) + .put("string", "def") + .put("id", id) + .put("seqno", seqno); + + srcRecord = new SourceRecord(partition, ccOffset, topic, id, Schema.STRING_SCHEMA, "key", valueSchema2, data); + } + + System.out.println("{\"task\": " + id + ", \"seqno\": " + seqno + "}"); + seqno++; + count++; + return Collections.singletonList(srcRecord); + } else { + throttler.throttle(); + return Collections.emptyList(); + } + } + + @Override + public void stop() { + throttler.wakeup(); + } +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/tools/TransformationDoc.java b/connect/runtime/src/main/java/org/apache/kafka/connect/tools/TransformationDoc.java new file mode 100644 index 0000000..5771a6b --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/tools/TransformationDoc.java @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.connect.tools; + +import org.apache.kafka.common.config.ConfigDef; +import org.apache.kafka.connect.transforms.Cast; +import org.apache.kafka.connect.transforms.DropHeaders; +import org.apache.kafka.connect.transforms.ExtractField; +import org.apache.kafka.connect.transforms.Filter; +import org.apache.kafka.connect.transforms.Flatten; +import org.apache.kafka.connect.transforms.HeaderFrom; +import org.apache.kafka.connect.transforms.HoistField; +import org.apache.kafka.connect.transforms.InsertField; +import org.apache.kafka.connect.transforms.InsertHeader; +import org.apache.kafka.connect.transforms.MaskField; +import org.apache.kafka.connect.transforms.RegexRouter; +import org.apache.kafka.connect.transforms.ReplaceField; +import org.apache.kafka.connect.transforms.SetSchemaMetadata; +import org.apache.kafka.connect.transforms.TimestampConverter; +import org.apache.kafka.connect.transforms.TimestampRouter; +import org.apache.kafka.connect.transforms.ValueToKey; + +import java.io.PrintStream; +import java.util.Arrays; +import java.util.List; + +public class TransformationDoc { + + private static final class DocInfo { + final String transformationName; + final String overview; + final ConfigDef configDef; + + private DocInfo(String transformationName, String overview, ConfigDef configDef) { + this.transformationName = transformationName; + this.overview = overview; + this.configDef = configDef; + } + } + + private static final List TRANSFORMATIONS = Arrays.asList( + new DocInfo(InsertField.class.getName(), InsertField.OVERVIEW_DOC, InsertField.CONFIG_DEF), + new DocInfo(ReplaceField.class.getName(), ReplaceField.OVERVIEW_DOC, ReplaceField.CONFIG_DEF), + new DocInfo(MaskField.class.getName(), MaskField.OVERVIEW_DOC, MaskField.CONFIG_DEF), + new DocInfo(ValueToKey.class.getName(), ValueToKey.OVERVIEW_DOC, ValueToKey.CONFIG_DEF), + new DocInfo(HoistField.class.getName(), HoistField.OVERVIEW_DOC, HoistField.CONFIG_DEF), + new DocInfo(ExtractField.class.getName(), ExtractField.OVERVIEW_DOC, ExtractField.CONFIG_DEF), + new DocInfo(SetSchemaMetadata.class.getName(), SetSchemaMetadata.OVERVIEW_DOC, SetSchemaMetadata.CONFIG_DEF), + new DocInfo(TimestampRouter.class.getName(), TimestampRouter.OVERVIEW_DOC, TimestampRouter.CONFIG_DEF), + new DocInfo(RegexRouter.class.getName(), RegexRouter.OVERVIEW_DOC, RegexRouter.CONFIG_DEF), + new DocInfo(Flatten.class.getName(), Flatten.OVERVIEW_DOC, Flatten.CONFIG_DEF), + new DocInfo(Cast.class.getName(), Cast.OVERVIEW_DOC, Cast.CONFIG_DEF), + new DocInfo(TimestampConverter.class.getName(), TimestampConverter.OVERVIEW_DOC, TimestampConverter.CONFIG_DEF), + new DocInfo(Filter.class.getName(), Filter.OVERVIEW_DOC, Filter.CONFIG_DEF), + new DocInfo(InsertHeader.class.getName(), InsertHeader.OVERVIEW_DOC, InsertHeader.CONFIG_DEF), + new DocInfo(DropHeaders.class.getName(), DropHeaders.OVERVIEW_DOC, DropHeaders.CONFIG_DEF), + new DocInfo(HeaderFrom.class.getName(), HeaderFrom.OVERVIEW_DOC, HeaderFrom.CONFIG_DEF) + ); + + private static void printTransformationHtml(PrintStream out, DocInfo docInfo) { + out.println("
            "); + + out.print("
            "); + out.print(docInfo.transformationName); + out.println("
            "); + + out.println(docInfo.overview); + + out.println("

            "); + + out.println(docInfo.configDef.toHtml(6, key -> docInfo.transformationName + "_" + key)); + + out.println("

            "); + } + + private static void printHtml(PrintStream out) { + for (final DocInfo docInfo : TRANSFORMATIONS) { + printTransformationHtml(out, docInfo); + } + } + + public static void main(String... args) { + printHtml(System.out); + } + +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/tools/VerifiableSinkConnector.java b/connect/runtime/src/main/java/org/apache/kafka/connect/tools/VerifiableSinkConnector.java new file mode 100644 index 0000000..55f95a3 --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/tools/VerifiableSinkConnector.java @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.tools; + +import org.apache.kafka.common.config.ConfigDef; +import org.apache.kafka.common.utils.AppInfoParser; +import org.apache.kafka.connect.connector.Task; +import org.apache.kafka.connect.source.SourceConnector; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** + * @see VerifiableSinkTask + */ +public class VerifiableSinkConnector extends SourceConnector { + private Map config; + + @Override + public String version() { + return AppInfoParser.getVersion(); + } + + @Override + public void start(Map props) { + this.config = props; + } + + @Override + public Class taskClass() { + return VerifiableSinkTask.class; + } + + @Override + public List> taskConfigs(int maxTasks) { + ArrayList> configs = new ArrayList<>(); + for (int i = 0; i < maxTasks; i++) { + Map props = new HashMap<>(config); + props.put(VerifiableSinkTask.ID_CONFIG, String.valueOf(i)); + configs.add(props); + } + return configs; + } + + @Override + public void stop() { + } + + @Override + public ConfigDef config() { + return new ConfigDef(); + } +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/tools/VerifiableSinkTask.java b/connect/runtime/src/main/java/org/apache/kafka/connect/tools/VerifiableSinkTask.java new file mode 100644 index 0000000..ff71ff8 --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/tools/VerifiableSinkTask.java @@ -0,0 +1,117 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.tools; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import org.apache.kafka.clients.consumer.OffsetAndMetadata; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.connect.errors.ConnectException; +import org.apache.kafka.connect.sink.SinkRecord; +import org.apache.kafka.connect.sink.SinkTask; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** + * Counterpart to {@link VerifiableSourceTask} that consumes records and logs information about each to stdout. This + * allows validation of processing of messages by sink tasks on distributed workers even in the face of worker restarts + * and failures. This task relies on the offset management provided by the Kafka Connect framework and therefore can detect + * bugs in its implementation. + */ +public class VerifiableSinkTask extends SinkTask { + public static final String NAME_CONFIG = "name"; + public static final String ID_CONFIG = "id"; + + private static final ObjectMapper JSON_SERDE = new ObjectMapper(); + + private String name; // Connector name + private int id; // Task ID + + private final Map>> unflushed = new HashMap<>(); + + @Override + public String version() { + return new VerifiableSinkConnector().version(); + } + + @Override + public void start(Map props) { + try { + name = props.get(NAME_CONFIG); + id = Integer.parseInt(props.get(ID_CONFIG)); + } catch (NumberFormatException e) { + throw new ConnectException("Invalid VerifiableSourceTask configuration", e); + } + } + + @Override + public void put(Collection records) { + long nowMs = System.currentTimeMillis(); + for (SinkRecord record : records) { + Map data = new HashMap<>(); + data.put("name", name); + data.put("task", record.key()); // VerifiableSourceTask's input task (source partition) + data.put("sinkTask", id); + data.put("topic", record.topic()); + data.put("time_ms", nowMs); + data.put("seqno", record.value()); + data.put("offset", record.kafkaOffset()); + String dataJson; + try { + dataJson = JSON_SERDE.writeValueAsString(data); + } catch (JsonProcessingException e) { + dataJson = "Bad data can't be written as json: " + e.getMessage(); + } + System.out.println(dataJson); + unflushed.computeIfAbsent( + new TopicPartition(record.topic(), record.kafkaPartition()), + tp -> new ArrayList<>() + ).add(data); + } + } + + @Override + public void flush(Map offsets) { + long nowMs = System.currentTimeMillis(); + for (TopicPartition topicPartition : offsets.keySet()) { + if (!unflushed.containsKey(topicPartition)) { + continue; + } + for (Map data : unflushed.get(topicPartition)) { + data.put("time_ms", nowMs); + data.put("flushed", true); + String dataJson; + try { + dataJson = JSON_SERDE.writeValueAsString(data); + } catch (JsonProcessingException e) { + dataJson = "Bad data can't be written as json: " + e.getMessage(); + } + System.out.println(dataJson); + } + unflushed.remove(topicPartition); + } + } + + @Override + public void stop() { + + } +} 
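VerifiableSinkTask reports each record twice on stdout: once from put() when the record arrives, and again from flush() with an added "flushed": true field once the record's partition has been committed, so a system test can pair the two lines to confirm end-to-end delivery. The short standalone sketch below only illustrates that two-line output shape; the class name, connector name, topic, and field values are invented for the example, and it assumes jackson-databind is available on the classpath.

    import com.fasterxml.jackson.databind.ObjectMapper;

    import java.util.HashMap;
    import java.util.Map;

    // Illustrative only: mimics the JSON lines a VerifiableSinkTask-style tool prints,
    // first when a record is received and then when its partition is flushed.
    public class VerifiableSinkOutputExample {
        public static void main(String[] args) throws Exception {
            ObjectMapper mapper = new ObjectMapper();

            Map<String, Object> data = new HashMap<>();
            data.put("name", "verifiable-sink");   // connector name (example value)
            data.put("task", 3);                   // source task id carried in the record key
            data.put("sinkTask", 0);               // id of the sink task doing the reporting
            data.put("topic", "test-topic");       // example topic
            data.put("time_ms", System.currentTimeMillis());
            data.put("seqno", 42L);                // sequence number carried in the record value
            data.put("offset", 42L);               // Kafka offset of the record

            // Line printed when the record arrives (put()):
            System.out.println(mapper.writeValueAsString(data));

            // Line printed once the partition is flushed (flush()):
            data.put("time_ms", System.currentTimeMillis());
            data.put("flushed", true);
            System.out.println(mapper.writeValueAsString(data));
        }
    }

The real task additionally buffers each record per partition in its unflushed map, so only records whose partitions were actually flushed ever produce the second line.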
diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/tools/VerifiableSourceConnector.java b/connect/runtime/src/main/java/org/apache/kafka/connect/tools/VerifiableSourceConnector.java new file mode 100644 index 0000000..6262cc3 --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/tools/VerifiableSourceConnector.java @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.tools; + +import org.apache.kafka.common.config.ConfigDef; +import org.apache.kafka.common.utils.AppInfoParser; +import org.apache.kafka.connect.connector.Task; +import org.apache.kafka.connect.source.SourceConnector; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** + * @see VerifiableSourceTask + */ +public class VerifiableSourceConnector extends SourceConnector { + private Map config; + + @Override + public String version() { + return AppInfoParser.getVersion(); + } + + @Override + public void start(Map props) { + this.config = props; + } + + @Override + public Class taskClass() { + return VerifiableSourceTask.class; + } + + @Override + public List> taskConfigs(int maxTasks) { + ArrayList> configs = new ArrayList<>(); + for (int i = 0; i < maxTasks; i++) { + Map props = new HashMap<>(config); + props.put(VerifiableSourceTask.ID_CONFIG, String.valueOf(i)); + configs.add(props); + } + return configs; + } + + @Override + public void stop() { + } + + @Override + public ConfigDef config() { + return new ConfigDef(); + } +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/tools/VerifiableSourceTask.java b/connect/runtime/src/main/java/org/apache/kafka/connect/tools/VerifiableSourceTask.java new file mode 100644 index 0000000..8afbdff --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/tools/VerifiableSourceTask.java @@ -0,0 +1,146 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.connect.tools; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import org.apache.kafka.tools.ThroughputThrottler; +import org.apache.kafka.clients.producer.RecordMetadata; +import org.apache.kafka.connect.data.Schema; +import org.apache.kafka.connect.errors.ConnectException; +import org.apache.kafka.connect.source.SourceRecord; +import org.apache.kafka.connect.source.SourceTask; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** + * A connector primarily intended for system tests. The connector simply generates as many tasks as requested. The + * tasks print metadata in the form of JSON to stdout for each message generated, making externally visible which + * messages have been sent. Each message is also assigned a unique, increasing seqno that is passed to Kafka Connect; when + * tasks are started on new nodes, this seqno is used to resume where the task previously left off, allowing for + * testing of distributed Kafka Connect. + * + * If logging is left enabled, log output on stdout can be easily ignored by checking whether a given line is valid JSON. + */ +public class VerifiableSourceTask extends SourceTask { + private static final Logger log = LoggerFactory.getLogger(VerifiableSourceTask.class); + + public static final String NAME_CONFIG = "name"; + public static final String ID_CONFIG = "id"; + public static final String TOPIC_CONFIG = "topic"; + public static final String THROUGHPUT_CONFIG = "throughput"; + + private static final String ID_FIELD = "id"; + private static final String SEQNO_FIELD = "seqno"; + + private static final ObjectMapper JSON_SERDE = new ObjectMapper(); + + private String name; // Connector name + private int id; // Task ID + private String topic; + private Map partition; + private long startingSeqno; + private long seqno; + private ThroughputThrottler throttler; + + @Override + public String version() { + return new VerifiableSourceConnector().version(); + } + + @Override + public void start(Map props) { + final long throughput; + try { + name = props.get(NAME_CONFIG); + id = Integer.parseInt(props.get(ID_CONFIG)); + topic = props.get(TOPIC_CONFIG); + throughput = Long.parseLong(props.get(THROUGHPUT_CONFIG)); + } catch (NumberFormatException e) { + throw new ConnectException("Invalid VerifiableSourceTask configuration", e); + } + + partition = Collections.singletonMap(ID_FIELD, id); + Map previousOffset = this.context.offsetStorageReader().offset(partition); + if (previousOffset != null) + seqno = (Long) previousOffset.get(SEQNO_FIELD) + 1; + else + seqno = 0; + startingSeqno = seqno; + throttler = new ThroughputThrottler(throughput, System.currentTimeMillis()); + + log.info("Started VerifiableSourceTask {}-{} producing to topic {} resuming from seqno {}", name, id, topic, startingSeqno); + } + + @Override + public List poll() throws InterruptedException { + long sendStartMs = System.currentTimeMillis(); + if (throttler.shouldThrottle(seqno - startingSeqno, sendStartMs)) + throttler.throttle(); + + long nowMs = System.currentTimeMillis(); + + Map data = new HashMap<>(); + data.put("name", name); + data.put("task", id); + data.put("topic", this.topic); + data.put("time_ms", nowMs); + data.put("seqno", seqno); + String dataJson; + try { + dataJson = JSON_SERDE.writeValueAsString(data); + } catch (JsonProcessingException e) { + dataJson = "Bad 
data can't be written as json: " + e.getMessage(); + } + System.out.println(dataJson); + + Map ccOffset = Collections.singletonMap(SEQNO_FIELD, seqno); + SourceRecord srcRecord = new SourceRecord(partition, ccOffset, topic, Schema.INT32_SCHEMA, id, Schema.INT64_SCHEMA, seqno); + List result = Collections.singletonList(srcRecord); + seqno++; + return result; + } + + @Override + public void commitRecord(SourceRecord record, RecordMetadata metadata) throws InterruptedException { + Map data = new HashMap<>(); + data.put("name", name); + data.put("task", id); + data.put("topic", this.topic); + data.put("time_ms", System.currentTimeMillis()); + data.put("seqno", record.value()); + data.put("committed", true); + + String dataJson; + try { + dataJson = JSON_SERDE.writeValueAsString(data); + } catch (JsonProcessingException e) { + dataJson = "Bad data can't be written as json: " + e.getMessage(); + } + System.out.println(dataJson); + } + + @Override + public void stop() { + throttler.wakeup(); + } +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/util/Callback.java b/connect/runtime/src/main/java/org/apache/kafka/connect/util/Callback.java new file mode 100644 index 0000000..277863b --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/util/Callback.java @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.util; + +/** + * Generic interface for callbacks + */ +public interface Callback { + /** + * Invoked upon completion of the operation. + * + * @param error the error that caused the operation to fail, or null if no error occurred + * @param result the return value, or null if the operation failed + */ + void onCompletion(Throwable error, V result); +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/util/ConnectUtils.java b/connect/runtime/src/main/java/org/apache/kafka/connect/util/ConnectUtils.java new file mode 100644 index 0000000..c1b83eb --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/util/ConnectUtils.java @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.util; + +import org.apache.kafka.clients.CommonClientConfigs; +import org.apache.kafka.clients.admin.Admin; +import org.apache.kafka.common.KafkaFuture; +import org.apache.kafka.common.InvalidRecordException; +import org.apache.kafka.common.record.RecordBatch; +import org.apache.kafka.connect.connector.Connector; +import org.apache.kafka.connect.errors.ConnectException; +import org.apache.kafka.connect.runtime.WorkerConfig; +import org.apache.kafka.connect.runtime.distributed.DistributedConfig; +import org.apache.kafka.connect.sink.SinkConnector; +import org.apache.kafka.connect.source.SourceConnector; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Map; +import java.util.concurrent.ExecutionException; + +public final class ConnectUtils { + private static final Logger log = LoggerFactory.getLogger(ConnectUtils.class); + + public static Long checkAndConvertTimestamp(Long timestamp) { + if (timestamp == null || timestamp >= 0) + return timestamp; + else if (timestamp == RecordBatch.NO_TIMESTAMP) + return null; + else + throw new InvalidRecordException(String.format("Invalid record timestamp %d", timestamp)); + } + + public static String lookupKafkaClusterId(WorkerConfig config) { + log.info("Creating Kafka admin client"); + try (Admin adminClient = Admin.create(config.originals())) { + return lookupKafkaClusterId(adminClient); + } + } + + static String lookupKafkaClusterId(Admin adminClient) { + log.debug("Looking up Kafka cluster ID"); + try { + KafkaFuture clusterIdFuture = adminClient.describeCluster().clusterId(); + if (clusterIdFuture == null) { + log.info("Kafka cluster version is too old to return cluster ID"); + return null; + } + log.debug("Fetching Kafka cluster ID"); + String kafkaClusterId = clusterIdFuture.get(); + log.info("Kafka cluster ID: {}", kafkaClusterId); + return kafkaClusterId; + } catch (InterruptedException e) { + throw new ConnectException("Unexpectedly interrupted when looking up Kafka cluster info", e); + } catch (ExecutionException e) { + throw new ConnectException("Failed to connect to and describe Kafka cluster. " + + "Check worker's broker connection and security properties.", e); + } + } + + public static void addMetricsContextProperties(Map prop, WorkerConfig config, String clusterId) { + //add all properties predefined with "metrics.context." 
+ prop.putAll(config.originalsWithPrefix(CommonClientConfigs.METRICS_CONTEXT_PREFIX, false)); + //add connect properties + prop.put(CommonClientConfigs.METRICS_CONTEXT_PREFIX + WorkerConfig.CONNECT_KAFKA_CLUSTER_ID, clusterId); + Object groupId = config.originals().get(DistributedConfig.GROUP_ID_CONFIG); + if (groupId != null) { + prop.put(CommonClientConfigs.METRICS_CONTEXT_PREFIX + WorkerConfig.CONNECT_GROUP_ID, groupId); + } + } + + public static boolean isSinkConnector(Connector connector) { + return SinkConnector.class.isAssignableFrom(connector.getClass()); + } + + public static boolean isSourceConnector(Connector connector) { + return SourceConnector.class.isAssignableFrom(connector.getClass()); + } +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/util/ConnectorTaskId.java b/connect/runtime/src/main/java/org/apache/kafka/connect/util/ConnectorTaskId.java new file mode 100644 index 0000000..1b69bd0 --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/util/ConnectorTaskId.java @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.util; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.io.Serializable; +import java.util.Objects; + +/** + * Unique ID for a single task. It includes a unique connector ID and a task ID that is unique within + * the connector. + */ +public class ConnectorTaskId implements Serializable, Comparable { + private final String connector; + private final int task; + + @JsonCreator + public ConnectorTaskId(@JsonProperty("connector") String connector, @JsonProperty("task") int task) { + this.connector = connector; + this.task = task; + } + + @JsonProperty + public String connector() { + return connector; + } + + @JsonProperty + public int task() { + return task; + } + + @Override + public boolean equals(Object o) { + if (this == o) + return true; + if (o == null || getClass() != o.getClass()) + return false; + + ConnectorTaskId that = (ConnectorTaskId) o; + + if (task != that.task) + return false; + + return Objects.equals(connector, that.connector); + } + + @Override + public int hashCode() { + int result = connector != null ? 
connector.hashCode() : 0; + result = 31 * result + task; + return result; + } + + @Override + public String toString() { + return connector + '-' + task; + } + + @Override + public int compareTo(ConnectorTaskId o) { + int connectorCmp = connector.compareTo(o.connector); + if (connectorCmp != 0) + return connectorCmp; + return Integer.compare(task, o.task); + } +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/util/ConvertingFutureCallback.java b/connect/runtime/src/main/java/org/apache/kafka/connect/util/ConvertingFutureCallback.java new file mode 100644 index 0000000..e15c38e --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/util/ConvertingFutureCallback.java @@ -0,0 +1,120 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.util; + +import org.apache.kafka.connect.errors.ConnectException; + +import java.util.concurrent.CancellationException; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + +public abstract class ConvertingFutureCallback implements Callback, Future { + + private final Callback underlying; + private final CountDownLatch finishedLatch; + private volatile T result = null; + private volatile Throwable exception = null; + private volatile boolean cancelled = false; + + public ConvertingFutureCallback() { + this(null); + } + + public ConvertingFutureCallback(Callback underlying) { + this.underlying = underlying; + this.finishedLatch = new CountDownLatch(1); + } + + public abstract T convert(U result); + + @Override + public void onCompletion(Throwable error, U result) { + synchronized (this) { + if (isDone()) { + return; + } + + if (error != null) { + this.exception = error; + } else { + this.result = convert(result); + } + + if (underlying != null) + underlying.onCompletion(error, this.result); + finishedLatch.countDown(); + } + } + + @Override + public boolean cancel(boolean mayInterruptIfRunning) { + synchronized (this) { + if (isDone()) { + return false; + } + if (mayInterruptIfRunning) { + this.cancelled = true; + finishedLatch.countDown(); + return true; + } + } + try { + finishedLatch.await(); + } catch (InterruptedException e) { + throw new ConnectException("Interrupted while waiting for task to complete", e); + } + return false; + } + + @Override + public boolean isCancelled() { + return cancelled; + } + + @Override + public boolean isDone() { + return finishedLatch.getCount() == 0; + } + + @Override + public T get() throws InterruptedException, ExecutionException { + finishedLatch.await(); + return result(); + } + + @Override + public T get(long l, TimeUnit timeUnit) + 
throws InterruptedException, ExecutionException, TimeoutException { + if (!finishedLatch.await(l, timeUnit)) + throw new TimeoutException("Timed out waiting for future"); + return result(); + } + + private T result() throws ExecutionException { + if (cancelled) { + throw new CancellationException(); + } + if (exception != null) { + throw new ExecutionException(exception); + } + return result; + } +} + diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/util/FutureCallback.java b/connect/runtime/src/main/java/org/apache/kafka/connect/util/FutureCallback.java new file mode 100644 index 0000000..a151926 --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/util/FutureCallback.java @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.util; + +public class FutureCallback extends ConvertingFutureCallback { + + public FutureCallback(Callback underlying) { + super(underlying); + } + + public FutureCallback() { + super(null); + } + + @Override + public T convert(T result) { + return result; + } +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/util/KafkaBasedLog.java b/connect/runtime/src/main/java/org/apache/kafka/connect/util/KafkaBasedLog.java new file mode 100644 index 0000000..b1920d5 --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/util/KafkaBasedLog.java @@ -0,0 +1,436 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.connect.util; + +import org.apache.kafka.clients.consumer.Consumer; +import org.apache.kafka.clients.consumer.ConsumerConfig; +import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.clients.consumer.ConsumerRecords; +import org.apache.kafka.clients.consumer.KafkaConsumer; +import org.apache.kafka.clients.producer.KafkaProducer; +import org.apache.kafka.clients.producer.Producer; +import org.apache.kafka.clients.producer.ProducerConfig; +import org.apache.kafka.clients.producer.ProducerRecord; +import org.apache.kafka.common.KafkaException; +import org.apache.kafka.common.PartitionInfo; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.errors.RetriableException; +import org.apache.kafka.common.errors.TimeoutException; +import org.apache.kafka.common.errors.UnsupportedVersionException; +import org.apache.kafka.common.errors.WakeupException; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.connect.errors.ConnectException; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.time.Duration; +import java.util.ArrayDeque; +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Queue; +import java.util.Set; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import java.util.function.Supplier; + + +/** + *

            + * KafkaBasedLog provides a generic implementation of a shared, compacted log of records stored in Kafka that all + * clients need to consume and, at times, agree on their offset / that they have read to the end of the log. + *

            + *

            + * This functionality is useful for storing different types of data that all clients may need to agree on -- + * offsets or config for example. This class runs a consumer in a background thread to continuously tail the target + * topic, accepts write requests which it writes to the topic using an internal producer, and provides some helpful + * utilities like checking the current log end offset and waiting until the current end of the log is reached. + *

            + *

            + * To support different use cases, this class works with either single- or multi-partition topics. + *

            + *

            + * Since this class is generic, it delegates the details of data storage via a callback that is invoked for each + * record that is consumed from the topic. The invocation of callbacks is guaranteed to be serialized -- if the + * calling class keeps track of state based on the log and only writes to it when consume callbacks are invoked + * and only reads it in {@link #readToEnd(Callback)} callbacks then no additional synchronization will be required. + *

            + */ +public class KafkaBasedLog { + private static final Logger log = LoggerFactory.getLogger(KafkaBasedLog.class); + private static final long CREATE_TOPIC_TIMEOUT_NS = TimeUnit.SECONDS.toNanos(30); + private static final long MAX_SLEEP_MS = TimeUnit.SECONDS.toMillis(1); + + private Time time; + private final String topic; + private int partitionCount; + private final Map producerConfigs; + private final Map consumerConfigs; + private final Callback> consumedCallback; + private final Supplier topicAdminSupplier; + private Consumer consumer; + private Producer producer; + private TopicAdmin admin; + + private Thread thread; + private boolean stopRequested; + private Queue> readLogEndOffsetCallbacks; + private java.util.function.Consumer initializer; + + /** + * Create a new KafkaBasedLog object. This does not start reading the log and writing is not permitted until + * {@link #start()} is invoked. + * + * @param topic the topic to treat as a log + * @param producerConfigs configuration options to use when creating the internal producer. At a minimum this must + * contain compatible serializer settings for the generic types used on this class. Some + * setting, such as the number of acks, will be overridden to ensure correct behavior of this + * class. + * @param consumerConfigs configuration options to use when creating the internal consumer. At a minimum this must + * contain compatible serializer settings for the generic types used on this class. Some + * setting, such as the auto offset reset policy, will be overridden to ensure correct + * behavior of this class. + * @param consumedCallback callback to invoke for each {@link ConsumerRecord} consumed when tailing the log + * @param time Time interface + * @param initializer the component that should be run when this log is {@link #start() started}; may be null + * @deprecated Replaced by {@link #KafkaBasedLog(String, Map, Map, Supplier, Callback, Time, java.util.function.Consumer)} + */ + @Deprecated + public KafkaBasedLog(String topic, + Map producerConfigs, + Map consumerConfigs, + Callback> consumedCallback, + Time time, + Runnable initializer) { + this(topic, producerConfigs, consumerConfigs, () -> null, consumedCallback, time, initializer != null ? admin -> initializer.run() : null); + } + + /** + * Create a new KafkaBasedLog object. This does not start reading the log and writing is not permitted until + * {@link #start()} is invoked. + * + * @param topic the topic to treat as a log + * @param producerConfigs configuration options to use when creating the internal producer. At a minimum this must + * contain compatible serializer settings for the generic types used on this class. Some + * setting, such as the number of acks, will be overridden to ensure correct behavior of this + * class. + * @param consumerConfigs configuration options to use when creating the internal consumer. At a minimum this must + * contain compatible serializer settings for the generic types used on this class. Some + * setting, such as the auto offset reset policy, will be overridden to ensure correct + * behavior of this class. 
+ * @param topicAdminSupplier supplier function for an admin client, the lifecycle of which is expected to be controlled + * by the calling component; may not be null + * @param consumedCallback callback to invoke for each {@link ConsumerRecord} consumed when tailing the log + * @param time Time interface + * @param initializer the function that should be run when this log is {@link #start() started}; may be null + */ + public KafkaBasedLog(String topic, + Map producerConfigs, + Map consumerConfigs, + Supplier topicAdminSupplier, + Callback> consumedCallback, + Time time, + java.util.function.Consumer initializer) { + this.topic = topic; + this.producerConfigs = producerConfigs; + this.consumerConfigs = consumerConfigs; + this.topicAdminSupplier = Objects.requireNonNull(topicAdminSupplier); + this.consumedCallback = consumedCallback; + this.stopRequested = false; + this.readLogEndOffsetCallbacks = new ArrayDeque<>(); + this.time = time; + this.initializer = initializer != null ? initializer : admin -> { }; + } + + public void start() { + log.info("Starting KafkaBasedLog with topic " + topic); + + // Create the topic admin client and initialize the topic ... + admin = topicAdminSupplier.get(); // may be null + initializer.accept(admin); + + // Then create the producer and consumer + producer = createProducer(); + consumer = createConsumer(); + + List partitions = new ArrayList<>(); + + // We expect that the topics will have been created either manually by the user or automatically by the herder + List partitionInfos = consumer.partitionsFor(topic); + long started = time.nanoseconds(); + long sleepMs = 100; + while (partitionInfos.isEmpty() && time.nanoseconds() - started < CREATE_TOPIC_TIMEOUT_NS) { + time.sleep(sleepMs); + sleepMs = Math.min(2 * sleepMs, MAX_SLEEP_MS); + partitionInfos = consumer.partitionsFor(topic); + } + if (partitionInfos.isEmpty()) + throw new ConnectException("Could not look up partition metadata for offset backing store topic in" + + " allotted period. This could indicate a connectivity issue, unavailable topic partitions, or if" + + " this is your first use of the topic it may have taken too long to create."); + + for (PartitionInfo partition : partitionInfos) + partitions.add(new TopicPartition(partition.topic(), partition.partition())); + partitionCount = partitions.size(); + consumer.assign(partitions); + + // Always consume from the beginning of all partitions. Necessary to ensure that we don't use committed offsets + // when a 'group.id' is specified (if offsets happen to have been committed unexpectedly). + consumer.seekToBeginning(partitions); + + readToLogEnd(); + + thread = new WorkThread(); + thread.start(); + + log.info("Finished reading KafkaBasedLog for topic " + topic); + + log.info("Started KafkaBasedLog for topic " + topic); + } + + public void stop() { + log.info("Stopping KafkaBasedLog for topic " + topic); + + synchronized (this) { + stopRequested = true; + } + consumer.wakeup(); + + try { + thread.join(); + } catch (InterruptedException e) { + throw new ConnectException("Failed to stop KafkaBasedLog. 
Exiting without cleanly shutting " + + "down it's producer and consumer.", e); + } + + try { + producer.close(); + } catch (KafkaException e) { + log.error("Failed to stop KafkaBasedLog producer", e); + } + + try { + consumer.close(); + } catch (KafkaException e) { + log.error("Failed to stop KafkaBasedLog consumer", e); + } + + // do not close the admin client, since we don't own it + admin = null; + + log.info("Stopped KafkaBasedLog for topic " + topic); + } + + /** + * Flushes any outstanding writes and then reads to the current end of the log and invokes the specified callback. + * Note that this checks the current, offsets, reads to them, and invokes the callback regardless of whether + * additional records have been written to the log. If the caller needs to ensure they have truly reached the end + * of the log, they must ensure there are no other writers during this period. + * + * This waits until the end of all partitions has been reached. + * + * This method is asynchronous. If you need a synchronous version, pass an instance of + * {@link org.apache.kafka.connect.util.FutureCallback} as the {@param callback} parameter and wait on it to block. + * + * @param callback the callback to invoke once the end of the log has been reached. + */ + public void readToEnd(Callback callback) { + log.trace("Starting read to end log for topic {}", topic); + producer.flush(); + synchronized (this) { + readLogEndOffsetCallbacks.add(callback); + } + consumer.wakeup(); + } + + /** + * Flush the underlying producer to ensure that all pending writes have been sent. + */ + public void flush() { + producer.flush(); + } + + /** + * Same as {@link #readToEnd(Callback)} but provides a {@link Future} instead of using a callback. + * @return the future associated with the operation + */ + public Future readToEnd() { + FutureCallback future = new FutureCallback<>(null); + readToEnd(future); + return future; + } + + public void send(K key, V value) { + send(key, value, null); + } + + public void send(K key, V value, org.apache.kafka.clients.producer.Callback callback) { + producer.send(new ProducerRecord<>(topic, key, value), callback); + } + + public int partitionCount() { + return partitionCount; + } + + private Producer createProducer() { + // Always require producer acks to all to ensure durable writes + producerConfigs.put(ProducerConfig.ACKS_CONFIG, "all"); + + // Don't allow more than one in-flight request to prevent reordering on retry (if enabled) + producerConfigs.put(ProducerConfig.MAX_IN_FLIGHT_REQUESTS_PER_CONNECTION, 1); + return new KafkaProducer<>(producerConfigs); + } + + private Consumer createConsumer() { + // Always force reset to the beginning of the log since this class wants to consume all available log data + consumerConfigs.put(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "earliest"); + + // Turn off autocommit since we always want to consume the full log + consumerConfigs.put(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG, false); + return new KafkaConsumer<>(consumerConfigs); + } + + private void poll(long timeoutMs) { + try { + ConsumerRecords records = consumer.poll(Duration.ofMillis(timeoutMs)); + for (ConsumerRecord record : records) + consumedCallback.onCompletion(null, record); + } catch (WakeupException e) { + // Expected on get() or stop(). 
The calling code should handle this + throw e; + } catch (KafkaException e) { + log.error("Error polling: " + e); + } + } + + private void readToLogEnd() { + Set assignment = consumer.assignment(); + Map endOffsets = readEndOffsets(assignment); + log.trace("Reading to end of log offsets {}", endOffsets); + + while (!endOffsets.isEmpty()) { + Iterator> it = endOffsets.entrySet().iterator(); + while (it.hasNext()) { + Map.Entry entry = it.next(); + TopicPartition topicPartition = entry.getKey(); + long endOffset = entry.getValue(); + long lastConsumedOffset = consumer.position(topicPartition); + if (lastConsumedOffset >= endOffset) { + log.trace("Read to end offset {} for {}", endOffset, topicPartition); + it.remove(); + } else { + log.trace("Behind end offset {} for {}; last-read offset is {}", + endOffset, topicPartition, lastConsumedOffset); + poll(Integer.MAX_VALUE); + break; + } + } + } + } + + // Visible for testing + Map readEndOffsets(Set assignment) { + log.trace("Reading to end of offset log"); + + // Note that we'd prefer to not use the consumer to find the end offsets for the assigned topic partitions. + // That is because it's possible that the consumer is already blocked waiting for new records to appear, when + // the consumer is already at the end. In such cases, using 'consumer.endOffsets(...)' will block until at least + // one more record becomes available, meaning we can't even check whether we're at the end offset. + // Since all we're trying to do here is get the end offset, we should use the supplied admin client + // (if available) to obtain the end offsets for the given topic partitions. + + // Deprecated constructors do not provide an admin supplier, so the admin is potentially null. + if (admin != null) { + // Use the admin client to immediately find the end offsets for the assigned topic partitions. + // Unlike using the consumer + try { + return admin.endOffsets(assignment); + } catch (UnsupportedVersionException e) { + // This may happen with really old brokers that don't support the auto topic creation + // field in metadata requests + log.debug("Reading to end of log offsets with consumer since admin client is unsupported: {}", e.getMessage()); + // Forget the reference to the admin so that we won't even try to use the admin the next time this method is called + admin = null; + // continue and let the consumer handle the read + } + // Other errors, like timeouts and retriable exceptions are intentionally propagated + } + // The admin may be null if older deprecated constructor is used or if the admin client is using a broker that doesn't + // support getting the end offsets (e.g., 0.10.x). In such cases, we should use the consumer, which is not ideal (see above). + return consumer.endOffsets(assignment); + } + + private class WorkThread extends Thread { + public WorkThread() { + super("KafkaBasedLog Work Thread - " + topic); + } + + @Override + public void run() { + try { + log.trace("{} started execution", this); + while (true) { + int numCallbacks; + synchronized (KafkaBasedLog.this) { + if (stopRequested) + break; + numCallbacks = readLogEndOffsetCallbacks.size(); + } + + if (numCallbacks > 0) { + try { + readToLogEnd(); + log.trace("Finished read to end log for topic {}", topic); + } catch (TimeoutException e) { + log.warn("Timeout while reading log to end for topic '{}'. Retrying automatically. " + + "This may occur when brokers are unavailable or unreachable. 
Reason: {}", topic, e.getMessage()); + continue; + } catch (RetriableException | org.apache.kafka.connect.errors.RetriableException e) { + log.warn("Retriable error while reading log to end for topic '{}'. Retrying automatically. " + + "Reason: {}", topic, e.getMessage()); + continue; + } catch (WakeupException e) { + // Either received another get() call and need to retry reading to end of log or stop() was + // called. Both are handled by restarting this loop. + continue; + } + } + + synchronized (KafkaBasedLog.this) { + // Only invoke exactly the number of callbacks we found before triggering the read to log end + // since it is possible for another write + readToEnd to sneak in the meantime + for (int i = 0; i < numCallbacks; i++) { + Callback cb = readLogEndOffsetCallbacks.poll(); + cb.onCompletion(null, null); + } + } + + try { + poll(Integer.MAX_VALUE); + } catch (WakeupException e) { + // See previous comment, both possible causes of this wakeup are handled by starting this loop again + continue; + } + } + } catch (Throwable t) { + log.error("Unexpected exception in {}", this, t); + } + } + } +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/util/LoggingContext.java b/connect/runtime/src/main/java/org/apache/kafka/connect/util/LoggingContext.java new file mode 100644 index 0000000..8df5f9c --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/util/LoggingContext.java @@ -0,0 +1,227 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.util; + +import org.slf4j.MDC; + +import java.util.Collection; +import java.util.Collections; +import java.util.Map; +import java.util.Objects; + +/** + * A utility for defining Mapped Diagnostic Context (MDC) for SLF4J logs. + * + *

            {@link LoggingContext} instances should be created in a try-with-resources block to ensure + * that the logging context is properly closed. The only exception is the logging context created + * upon thread creation that is to be used for the entire lifetime of the thread. + * + *

            Any logger created on the thread will inherit the MDC context, so this mechanism is ideal for + * providing additional information in the log messages without requiring connector + * implementations to use a specific Connect API or SLF4J API. {@link LoggingContext#close()} + * will also properly restore the Connect MDC parameters to their state just prior to when the + * LoggingContext was created. Use {@link #clear()} to remove all MDC parameters from the + * current thread context. + * + *
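+ * <p>A minimal usage sketch (illustrative only; the connector name and task number are assumptions,
+ * and an SLF4J logger named {@code log} is assumed to exist):
+ * <pre>{@code
+ * ConnectorTaskId id = new ConnectorTaskId("local-file-source", 0);
+ * try (LoggingContext context = LoggingContext.forTask(id)) {
+ *     // log messages emitted on this thread now carry the "[local-file-source|task-0] " MDC prefix
+ *     log.info("Starting task");
+ * }
+ * }</pre>
+ *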

            Compare this approach to {@link org.apache.kafka.common.utils.LogContext}, which must be + * used to create a new {@link org.slf4j.Logger} instance pre-configured with the desired prefix. + * Currently LogContext does not allow the prefix to be changed, and it requires that all + * components use the LogContext to create their Logger instance. + */ +public final class LoggingContext implements AutoCloseable { + + /** + * The name of the Mapped Diagnostic Context (MDC) key that defines the context for a connector. + */ + public static final String CONNECTOR_CONTEXT = "connector.context"; + + public static final Collection ALL_CONTEXTS = Collections.singleton(CONNECTOR_CONTEXT); + + /** + * The Scope values used by Connect when specifying the context. + */ + public enum Scope { + /** + * The scope value for the worker as it starts a connector. + */ + WORKER("worker"), + + /** + * The scope value for Task implementations. + */ + TASK("task"), + + /** + * The scope value for committing offsets. + */ + OFFSETS("offsets"), + + /** + * The scope value for validating connector configurations. + */ + VALIDATE("validate"); + + private final String text; + Scope(String value) { + this.text = value; + } + + @Override + public String toString() { + return text; + } + } + + /** + * Clear all MDC parameters. + */ + public static void clear() { + MDC.clear(); + } + + /** + * Modify the current {@link MDC} logging context to set the {@link #CONNECTOR_CONTEXT connector context} to include the + * supplied name and the {@link Scope#WORKER} scope. + * + * @param connectorName the connector name; may not be null + */ + public static LoggingContext forConnector(String connectorName) { + Objects.requireNonNull(connectorName); + LoggingContext context = new LoggingContext(); + MDC.put(CONNECTOR_CONTEXT, prefixFor(connectorName, Scope.WORKER, null)); + return context; + } + + /** + * Modify the current {@link MDC} logging context to set the {@link #CONNECTOR_CONTEXT connector context} to include the + * supplied connector name and the {@link Scope#VALIDATE} scope. + * + * @param connectorName the connector name + */ + public static LoggingContext forValidation(String connectorName) { + LoggingContext context = new LoggingContext(); + MDC.put(CONNECTOR_CONTEXT, prefixFor(connectorName, Scope.VALIDATE, null)); + return context; + } + + /** + * Modify the current {@link MDC} logging context to set the {@link #CONNECTOR_CONTEXT connector context} to include the + * connector name and task number using the supplied {@link ConnectorTaskId}, and to set the scope to {@link Scope#TASK}. + * + * @param id the connector task ID; may not be null + */ + public static LoggingContext forTask(ConnectorTaskId id) { + Objects.requireNonNull(id); + LoggingContext context = new LoggingContext(); + MDC.put(CONNECTOR_CONTEXT, prefixFor(id.connector(), Scope.TASK, id.task())); + return context; + } + + /** + * Modify the current {@link MDC} logging context to set the {@link #CONNECTOR_CONTEXT connector context} to include the + * connector name and task number using the supplied {@link ConnectorTaskId}, and to set the scope to {@link Scope#OFFSETS}. 
+ * + * @param id the connector task ID; may not be null + */ + public static LoggingContext forOffsets(ConnectorTaskId id) { + Objects.requireNonNull(id); + LoggingContext context = new LoggingContext(); + MDC.put(CONNECTOR_CONTEXT, prefixFor(id.connector(), Scope.OFFSETS, id.task())); + return context; + } + + /** + * Return the prefix that uses the specified connector name, task number, and scope. The + * format is as follows: + * + *

            +     *     [<connectorName>|<scope>]<sp>
            +     * 
            + * + * where "<connectorName>" is the name of the connector, + * "<sp>" indicates a trailing space, and + * "<scope>" is one of the following: + * + *
              + *
            • "task-n" for the operation of the numbered task, including calling the + * task methods and the producer/consumer; here "n" is the 0-based task number + *
            • "task-n|offset" for the committing of source offsets for the numbered + * task; here "n" is the * zero-based task number + *
            • "worker" for the creation and usage of connector instances + *
            + * + *

            The following are examples of the connector context for a connector named "my-connector": + * + *

              + *
            • `[my-connector|worker]` - used on log messages where the Connect worker is + * validating the configuration for or starting/stopping the "local-file-source" connector + * via the SourceConnector / SinkConnector implementation methods. + *
            • `[my-connector|task-0]` - used on log messages where the Connect worker is executing + * task 0 of the "local-file-source" connector, including calling any of the SourceTask / + * SinkTask implementation methods, processing the messages for/from the task, and + * calling the task's * producer/consumer. + *
            • `[my-connector|task-0|offsets]` - used on log messages where the Connect worker is + * committing * source offsets for task 0 of the "local-file-source" connector. + *
            + * + * @param connectorName the name of the connector; may not be null + * @param scope the scope; may not be null + * @param taskNumber the 0-based task number; may be null if there is no associated task + * @return the prefix; never null + */ + protected static String prefixFor(String connectorName, Scope scope, Integer taskNumber) { + StringBuilder sb = new StringBuilder(); + sb.append("["); + sb.append(connectorName); + if (taskNumber != null) { + // There is a task number, so this is a task + sb.append("|"); + sb.append(Scope.TASK.toString()); + sb.append("-"); + sb.append(taskNumber.toString()); + } + // Append non-task scopes (e.g., worker and offset) + if (scope != Scope.TASK) { + sb.append("|"); + sb.append(scope.toString()); + } + sb.append("] "); + return sb.toString(); + } + + private final Map previous; + + private LoggingContext() { + previous = MDC.getCopyOfContextMap(); // may be null! + } + + /** + * Close this logging context, restoring the Connect {@link MDC} parameters back to the state + * just before this context was created. This does not affect other MDC parameters set by + * connectors or tasks. + */ + @Override + public void close() { + for (String param : ALL_CONTEXTS) { + if (previous != null && previous.containsKey(param)) { + MDC.put(param, previous.get(param)); + } else { + MDC.remove(param); + } + } + } +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/util/SafeObjectInputStream.java b/connect/runtime/src/main/java/org/apache/kafka/connect/util/SafeObjectInputStream.java new file mode 100644 index 0000000..0ad3889 --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/util/SafeObjectInputStream.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.connect.util; + +import java.io.IOException; +import java.io.InputStream; +import java.io.ObjectInputStream; +import java.io.ObjectStreamClass; +import java.util.Collections; +import java.util.HashSet; +import java.util.Set; + +public class SafeObjectInputStream extends ObjectInputStream { + + protected static final Set DEFAULT_NO_DESERIALIZE_CLASS_NAMES; + + static { + + Set s = new HashSet<>(); + s.add("org.apache.commons.collections.functors.InvokerTransformer"); + s.add("org.apache.commons.collections.functors.InstantiateTransformer"); + s.add("org.apache.commons.collections4.functors.InvokerTransformer"); + s.add("org.apache.commons.collections4.functors.InstantiateTransformer"); + s.add("org.codehaus.groovy.runtime.ConvertedClosure"); + s.add("org.codehaus.groovy.runtime.MethodClosure"); + s.add("org.springframework.beans.factory.ObjectFactory"); + s.add("com.sun.org.apache.xalan.internal.xsltc.trax.TemplatesImpl"); + s.add("org.apache.xalan.xsltc.trax.TemplatesImpl"); + DEFAULT_NO_DESERIALIZE_CLASS_NAMES = Collections.unmodifiableSet(s); + } + + + public SafeObjectInputStream(InputStream in) throws IOException { + super(in); + } + + @Override + protected Class resolveClass(ObjectStreamClass desc) throws IOException, ClassNotFoundException { + String name = desc.getName(); + + if (isBlocked(name)) { + throw new SecurityException("Illegal type to deserialize: prevented for security reasons"); + } + + return super.resolveClass(desc); + } + + private boolean isBlocked(String name) { + for (String list : DEFAULT_NO_DESERIALIZE_CLASS_NAMES) { + if (name.endsWith(list)) { + return true; + } + } + + return false; + } +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/util/SharedTopicAdmin.java b/connect/runtime/src/main/java/org/apache/kafka/connect/util/SharedTopicAdmin.java new file mode 100644 index 0000000..a99514e --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/util/SharedTopicAdmin.java @@ -0,0 +1,145 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.util; + +import java.time.Duration; +import java.util.Map; +import java.util.Objects; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Function; +import java.util.function.Supplier; +import java.util.function.UnaryOperator; + +import org.apache.kafka.clients.admin.AdminClientConfig; +import org.apache.kafka.connect.errors.ConnectException; + +/** + * A holder of a {@link TopicAdmin} object that is lazily and atomically created when needed by multiple callers. 
+ * As soon as one of the getters is called, all getters will return the same shared {@link TopicAdmin} + * instance until this SharedAdmin is closed via {@link #close()} or {@link #close(Duration)}. + * + *
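+ * <p>A minimal usage sketch (illustrative only; the admin configuration map {@code adminProps} is an assumption):
+ * <pre>{@code
+ * SharedTopicAdmin sharedAdmin = new SharedTopicAdmin(adminProps);
+ * // pass the shared admin wherever a Supplier<TopicAdmin> is expected; the TopicAdmin is created lazily on first use
+ * TopicAdmin admin = sharedAdmin.get();
+ * // ... use the admin ...
+ * sharedAdmin.close(); // closes the underlying TopicAdmin, if one was created
+ * }</pre>
+ *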

            The owner of this object is responsible for ensuring that either {@link #close()} or {@link #close(Duration)} + * is called when the {@link TopicAdmin} instance is no longer needed. Consequently, once this + * {@link SharedTopicAdmin} instance has been closed, the {@link #get()} and {@link #topicAdmin()} methods, + * nor any previously returned {@link TopicAdmin} instances may be used. + * + *

            This class is thread-safe. It also appears as immutable to callers that obtain the {@link TopicAdmin} object, + * until this object is closed, at which point it cannot be used anymore + */ +public class SharedTopicAdmin implements AutoCloseable, Supplier { + + // Visible for testing + static final Duration DEFAULT_CLOSE_DURATION = Duration.ofMillis(Long.MAX_VALUE); + + private final Map adminProps; + private final AtomicReference admin = new AtomicReference<>(); + private final AtomicBoolean closed = new AtomicBoolean(false); + private final Function, TopicAdmin> factory; + + public SharedTopicAdmin(Map adminProps) { + this(adminProps, TopicAdmin::new); + } + + // Visible for testing + SharedTopicAdmin(Map adminProps, Function, TopicAdmin> factory) { + this.adminProps = Objects.requireNonNull(adminProps); + this.factory = Objects.requireNonNull(factory); + } + + /** + * Get the shared {@link TopicAdmin} instance. + * + * @return the shared instance; never null + * @throws ConnectException if this object has already been closed + */ + @Override + public TopicAdmin get() { + return topicAdmin(); + } + + /** + * Get the shared {@link TopicAdmin} instance. + * + * @return the shared instance; never null + * @throws ConnectException if this object has already been closed + */ + public TopicAdmin topicAdmin() { + return admin.updateAndGet(this::createAdmin); + } + + /** + * Get the string containing the list of bootstrap server addresses to the Kafka broker(s) to which + * the admin client connects. + * + * @return the bootstrap servers as a string; never null + */ + public String bootstrapServers() { + return adminProps.getOrDefault(AdminClientConfig.BOOTSTRAP_SERVERS_CONFIG, "").toString(); + } + + /** + * Close the underlying {@link TopicAdmin} instance, if one has been created, and prevent new ones from being created. + * + *

            Once this method is called, the {@link #get()} and {@link #topicAdmin()} methods, + * nor any previously returned {@link TopicAdmin} instances may be used. + */ + @Override + public void close() { + close(DEFAULT_CLOSE_DURATION); + } + + /** + * Close the underlying {@link TopicAdmin} instance, if one has been created, and prevent new ones from being created. + * + *

            Once this method is called, the {@link #get()} and {@link #topicAdmin()} methods, + * nor any previously returned {@link TopicAdmin} instances may be used. + * + * @param timeout the maximum time to wait while the underlying admin client is closed; may not be null + */ + public void close(Duration timeout) { + Objects.requireNonNull(timeout); + if (this.closed.compareAndSet(false, true)) { + TopicAdmin admin = this.admin.getAndSet(null); + if (admin != null) { + admin.close(timeout); + } + } + } + + @Override + public String toString() { + return "admin client for brokers at " + bootstrapServers(); + } + + /** + * Method used to create a {@link TopicAdmin} instance. This method must be side-effect free, since it is called from within + * the {@link AtomicReference#updateAndGet(UnaryOperator)}. + * + * @param existing the existing instance; may be null + * @return the + */ + protected TopicAdmin createAdmin(TopicAdmin existing) { + if (closed.get()) { + throw new ConnectException("The " + this + " has already been closed and cannot be used."); + } + if (existing != null) { + return existing; + } + return factory.apply(adminProps); + } +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/util/ShutdownableThread.java b/connect/runtime/src/main/java/org/apache/kafka/connect/util/ShutdownableThread.java new file mode 100644 index 0000000..daf005b --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/util/ShutdownableThread.java @@ -0,0 +1,144 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.util; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; + +/** + *

            + * Thread class with support for triggering graceful and forcible shutdown. In graceful shutdown, + * a flag is set, which the thread should detect and try to exit gracefully from. In forcible + * shutdown, the thread is interrupted. These can be combined to give a thread a chance to exit + * gracefully, but then force it to exit if it takes too long. + *

            + *

            + * Implementations should override the {@link #execute} method and check {@link #getRunning} to + * determine whether they should try to gracefully exit. + *

            + */ +public abstract class ShutdownableThread extends Thread { + private static final Logger log = LoggerFactory.getLogger(ShutdownableThread.class); + + private AtomicBoolean isRunning = new AtomicBoolean(true); + private CountDownLatch shutdownLatch = new CountDownLatch(1); + + /** + * An UncaughtExceptionHandler to register on every instance of this class. This is useful for + * testing, where AssertionExceptions in the thread may not cause the test to fail. Since one + * instance is used for all threads, it must be thread-safe. + */ + volatile public static UncaughtExceptionHandler funcaughtExceptionHandler = null; + + public ShutdownableThread(String name) { + // The default is daemon=true so that these threads will not prevent shutdown. We use this + // default because threads that are running user code that may not clean up properly, even + // when we attempt to forcibly shut them down. + this(name, true); + } + + public ShutdownableThread(String name, boolean daemon) { + super(name); + this.setDaemon(daemon); + if (funcaughtExceptionHandler != null) + this.setUncaughtExceptionHandler(funcaughtExceptionHandler); + } + + /** + * Implementations should override this method with the main body for the thread. + */ + public abstract void execute(); + + /** + * Returns true if the thread hasn't exited yet and none of the shutdown methods have been + * invoked + */ + public boolean getRunning() { + return isRunning.get(); + } + + @Override + public void run() { + try { + execute(); + } catch (Error | RuntimeException e) { + log.error("Thread {} exiting with uncaught exception: ", getName(), e); + throw e; + } finally { + shutdownLatch.countDown(); + } + } + + /** + * Shutdown the thread, first trying to shut down gracefully using the specified timeout, then + * forcibly interrupting the thread. + * @param gracefulTimeout the maximum time to wait for a graceful exit + * @param unit the time unit of the timeout argument + */ + public void shutdown(long gracefulTimeout, TimeUnit unit) + throws InterruptedException { + boolean success = gracefulShutdown(gracefulTimeout, unit); + if (!success) + forceShutdown(); + } + + /** + * Attempt graceful shutdown + * @param timeout the maximum time to wait + * @param unit the time unit of the timeout argument + * @return true if successful, false if the timeout elapsed + */ + public boolean gracefulShutdown(long timeout, TimeUnit unit) throws InterruptedException { + startGracefulShutdown(); + return awaitShutdown(timeout, unit); + } + + /** + * Start shutting down this thread gracefully, but do not block waiting for it to exit. + */ + public void startGracefulShutdown() { + log.info("Starting graceful shutdown of thread {}", getName()); + isRunning.set(false); + } + + /** + * Awaits shutdown of this thread, waiting up to the timeout. + * @param timeout the maximum time to wait + * @param unit the time unit of the timeout argument + * @return true if successful, false if the timeout elapsed + * @throws InterruptedException + */ + public boolean awaitShutdown(long timeout, TimeUnit unit) throws InterruptedException { + return shutdownLatch.await(timeout, unit); + } + + /** + * Immediately tries to force the thread to shut down by interrupting it. This does not try to + * wait for the thread to truly exit because forcible shutdown is not always possible. By + * default, threads are marked as daemon threads so they will not prevent the process from + * exiting. 
+ */ + public void forceShutdown() throws InterruptedException { + log.info("Forcing shutdown of thread {}", getName()); + isRunning.set(false); + interrupt(); + } +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/util/SinkUtils.java b/connect/runtime/src/main/java/org/apache/kafka/connect/util/SinkUtils.java new file mode 100644 index 0000000..7777174 --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/util/SinkUtils.java @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.util; + +public final class SinkUtils { + + private SinkUtils() {} + + public static String consumerGroupId(String connector) { + return "connect-" + connector; + } + +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/util/Table.java b/connect/runtime/src/main/java/org/apache/kafka/connect/util/Table.java new file mode 100644 index 0000000..1b7131a --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/util/Table.java @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.connect.util; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +public class Table { + + private Map> table = new HashMap<>(); + + public V put(R row, C column, V value) { + Map columns = table.get(row); + if (columns == null) { + columns = new HashMap<>(); + table.put(row, columns); + } + return columns.put(column, value); + } + + public V get(R row, C column) { + Map columns = table.get(row); + if (columns == null) + return null; + return columns.get(column); + } + + public Map remove(R row) { + return table.remove(row); + } + + public V remove(R row, C column) { + Map columns = table.get(row); + if (columns == null) + return null; + + V value = columns.remove(column); + if (columns.isEmpty()) + table.remove(row); + return value; + } + + public Map row(R row) { + Map columns = table.get(row); + if (columns == null) + return Collections.emptyMap(); + return Collections.unmodifiableMap(columns); + } + + public boolean isEmpty() { + return table.isEmpty(); + } +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/util/TopicAdmin.java b/connect/runtime/src/main/java/org/apache/kafka/connect/util/TopicAdmin.java new file mode 100644 index 0000000..7b2f152 --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/util/TopicAdmin.java @@ -0,0 +1,719 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.connect.util; + +import org.apache.kafka.clients.admin.Admin; +import org.apache.kafka.clients.admin.AdminClientConfig; +import org.apache.kafka.clients.admin.Config; +import org.apache.kafka.clients.admin.ConfigEntry; +import org.apache.kafka.clients.admin.CreateTopicsOptions; +import org.apache.kafka.clients.admin.DescribeConfigsOptions; +import org.apache.kafka.clients.admin.DescribeTopicsOptions; +import org.apache.kafka.clients.admin.ListOffsetsResult; +import org.apache.kafka.clients.admin.ListOffsetsResult.ListOffsetsResultInfo; +import org.apache.kafka.clients.admin.NewTopic; +import org.apache.kafka.clients.admin.OffsetSpec; +import org.apache.kafka.clients.admin.TopicDescription; +import org.apache.kafka.common.KafkaFuture; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.config.ConfigException; +import org.apache.kafka.common.config.ConfigResource; +import org.apache.kafka.common.config.TopicConfig; +import org.apache.kafka.common.errors.AuthorizationException; +import org.apache.kafka.common.errors.ClusterAuthorizationException; +import org.apache.kafka.common.errors.InvalidConfigurationException; +import org.apache.kafka.common.errors.LeaderNotAvailableException; +import org.apache.kafka.common.errors.TimeoutException; +import org.apache.kafka.common.errors.TopicAuthorizationException; +import org.apache.kafka.common.errors.TopicExistsException; +import org.apache.kafka.common.errors.UnknownTopicOrPartitionException; +import org.apache.kafka.common.errors.UnsupportedVersionException; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.connect.errors.ConnectException; +import org.apache.kafka.connect.errors.RetriableException; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.time.Duration; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.Set; +import java.util.concurrent.ExecutionException; +import java.util.function.Function; +import java.util.stream.Collectors; + +/** + * Utility to simplify creating and managing topics via the {@link Admin}. 
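+ *
+ * <p>A minimal usage sketch (illustrative only; the topic name, partition count, replication factor, and the
+ * {@code adminConfig} map are assumptions):
+ * <pre>{@code
+ * NewTopic topic = TopicAdmin.defineTopic("connect-offsets")
+ *         .partitions(25)
+ *         .replicationFactor((short) 3)
+ *         .compacted()
+ *         .build();
+ * try (TopicAdmin admin = new TopicAdmin(adminConfig)) {
+ *     admin.createTopics(topic); // returns only the names of topics actually created; existing topics are skipped
+ * }
+ * }</pre>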
+ */ +public class TopicAdmin implements AutoCloseable { + + public static final TopicCreationResponse EMPTY_CREATION = new TopicCreationResponse(Collections.emptySet(), Collections.emptySet()); + + public static class TopicCreationResponse { + + private final Set created; + private final Set existing; + + public TopicCreationResponse(Set createdTopicNames, Set existingTopicNames) { + this.created = Collections.unmodifiableSet(createdTopicNames); + this.existing = Collections.unmodifiableSet(existingTopicNames); + } + + public Set createdTopics() { + return created; + } + + public Set existingTopics() { + return existing; + } + + public boolean isCreated(String topicName) { + return created.contains(topicName); + } + + public boolean isExisting(String topicName) { + return existing.contains(topicName); + } + + public boolean isCreatedOrExisting(String topicName) { + return isCreated(topicName) || isExisting(topicName); + } + + public int createdTopicsCount() { + return created.size(); + } + + public int existingTopicsCount() { + return existing.size(); + } + + public int createdOrExistingTopicsCount() { + return createdTopicsCount() + existingTopicsCount(); + } + + public boolean isEmpty() { + return createdOrExistingTopicsCount() == 0; + } + + @Override + public String toString() { + return "TopicCreationResponse{" + "created=" + created + ", existing=" + existing + '}'; + } + } + + public static final int NO_PARTITIONS = -1; + public static final short NO_REPLICATION_FACTOR = -1; + + private static final String CLEANUP_POLICY_CONFIG = TopicConfig.CLEANUP_POLICY_CONFIG; + private static final String CLEANUP_POLICY_COMPACT = TopicConfig.CLEANUP_POLICY_COMPACT; + private static final String MIN_INSYNC_REPLICAS_CONFIG = TopicConfig.MIN_IN_SYNC_REPLICAS_CONFIG; + private static final String UNCLEAN_LEADER_ELECTION_ENABLE_CONFIG = TopicConfig.UNCLEAN_LEADER_ELECTION_ENABLE_CONFIG; + + /** + * A builder of {@link NewTopic} instances. + */ + public static class NewTopicBuilder { + private final String name; + private int numPartitions = NO_PARTITIONS; + private short replicationFactor = NO_REPLICATION_FACTOR; + private final Map configs = new HashMap<>(); + + NewTopicBuilder(String name) { + this.name = name; + } + + /** + * Specify the desired number of partitions for the topic. + * + * @param numPartitions the desired number of partitions; must be positive, or -1 to + * signify using the broker's default + * @return this builder to allow methods to be chained; never null + */ + public NewTopicBuilder partitions(int numPartitions) { + this.numPartitions = numPartitions; + return this; + } + + /** + * Specify the topic's number of partition should be the broker configuration for + * {@code num.partitions}. + * + * @return this builder to allow methods to be chained; never null + */ + public NewTopicBuilder defaultPartitions() { + this.numPartitions = NO_PARTITIONS; + return this; + } + + /** + * Specify the desired replication factor for the topic. + * + * @param replicationFactor the desired replication factor; must be positive, or -1 to + * signify using the broker's default + * @return this builder to allow methods to be chained; never null + */ + public NewTopicBuilder replicationFactor(short replicationFactor) { + this.replicationFactor = replicationFactor; + return this; + } + + /** + * Specify the replication factor for the topic should be the broker configurations for + * {@code default.replication.factor}. 
+ * + * @return this builder to allow methods to be chained; never null + */ + public NewTopicBuilder defaultReplicationFactor() { + this.replicationFactor = NO_REPLICATION_FACTOR; + return this; + } + + /** + * Specify that the topic should be compacted. + * + * @return this builder to allow methods to be chained; never null + */ + public NewTopicBuilder compacted() { + this.configs.put(CLEANUP_POLICY_CONFIG, CLEANUP_POLICY_COMPACT); + return this; + } + + /** + * Specify the minimum number of in-sync replicas required for this topic. + * + * @param minInSyncReplicas the minimum number of in-sync replicas allowed for the topic; must be positive + * @return this builder to allow methods to be chained; never null + */ + public NewTopicBuilder minInSyncReplicas(short minInSyncReplicas) { + this.configs.put(MIN_INSYNC_REPLICAS_CONFIG, Short.toString(minInSyncReplicas)); + return this; + } + + /** + * Specify whether the broker is allowed to elect a leader that was not an in-sync replica when no ISRs + * are available. + * + * @param allow true if unclean leaders can be elected, or false if they are not allowed + * @return this builder to allow methods to be chained; never null + */ + public NewTopicBuilder uncleanLeaderElection(boolean allow) { + this.configs.put(UNCLEAN_LEADER_ELECTION_ENABLE_CONFIG, Boolean.toString(allow)); + return this; + } + + /** + * Specify the configuration properties for the topic, overwriting any previously-set properties. + * + * @param configs the desired topic configuration properties, or null if all existing properties should be cleared + * @return this builder to allow methods to be chained; never null + */ + public NewTopicBuilder config(Map configs) { + if (configs != null) { + for (Map.Entry entry : configs.entrySet()) { + Object value = entry.getValue(); + this.configs.put(entry.getKey(), value != null ? value.toString() : null); + } + } else { + this.configs.clear(); + } + return this; + } + + /** + * Build the {@link NewTopic} representation. + * + * @return the topic description; never null + */ + public NewTopic build() { + return new NewTopic( + name, + Optional.of(numPartitions), + Optional.of(replicationFactor) + ).configs(configs); + } + } + + /** + * Obtain a {@link NewTopicBuilder builder} to define a {@link NewTopic}. + * + * @param topicName the name of the topic + * @return the {@link NewTopic} description of the topic; never null + */ + public static NewTopicBuilder defineTopic(String topicName) { + return new NewTopicBuilder(topicName); + } + + private static final Logger log = LoggerFactory.getLogger(TopicAdmin.class); + private final Map adminConfig; + private final Admin admin; + private final boolean logCreation; + + /** + * Create a new topic admin component with the given configuration. + * + * @param adminConfig the configuration for the {@link Admin} + */ + public TopicAdmin(Map adminConfig) { + this(adminConfig, Admin.create(adminConfig)); + } + + // visible for testing + TopicAdmin(Map adminConfig, Admin adminClient) { + this(adminConfig, adminClient, true); + } + + // visible for testing + TopicAdmin(Map adminConfig, Admin adminClient, boolean logCreation) { + this.admin = adminClient; + this.adminConfig = adminConfig != null ? adminConfig : Collections.emptyMap(); + this.logCreation = logCreation; + } + + /** + * Get the {@link Admin} client used by this topic admin object. 
+ * @return the Kafka admin instance; never null + */ + public Admin admin() { + return admin; + } + + /** + * Attempt to create the topic described by the given definition, returning true if the topic was created or false + * if the topic already existed. + * + * @param topic the specification of the topic + * @return true if the topic was created or false if the topic already existed. + * @throws ConnectException if an error occurs, the operation takes too long, or the thread is interrupted while + * attempting to perform this operation + * @throws UnsupportedVersionException if the broker does not support the necessary APIs to perform this request + */ + public boolean createTopic(NewTopic topic) { + if (topic == null) return false; + Set newTopicNames = createTopics(topic); + return newTopicNames.contains(topic.name()); + } + + /** + * Attempt to create the topics described by the given definitions, returning all of the names of those topics that + * were created by this request. Any existing topics with the same name are unchanged, and the names of such topics + * are excluded from the result. + *
+ * If multiple topic definitions have the same topic name, the last one with that name will be used. + *
            + * Apache Kafka added support for creating topics in 0.10.1.0, so this method works as expected with that and later versions. + * With brokers older than 0.10.1.0, this method is unable to create topics and always returns an empty set. + * + * @param topics the specifications of the topics + * @return the names of the topics that were created by this operation; never null but possibly empty + * @throws ConnectException if an error occurs, the operation takes too long, or the thread is interrupted while + * attempting to perform this operation + */ + public Set createTopics(NewTopic... topics) { + return createOrFindTopics(topics).createdTopics(); + } + + /** + * Attempt to find or create the topic described by the given definition, returning true if the topic was created or had + * already existed, or false if the topic did not exist and could not be created. + * + * @param topic the specification of the topic + * @return true if the topic was created or existed, or false if the topic could not already existed. + * @throws ConnectException if an error occurs, the operation takes too long, or the thread is interrupted while + * attempting to perform this operation + * @throws UnsupportedVersionException if the broker does not support the necessary APIs to perform this request + */ + public boolean createOrFindTopic(NewTopic topic) { + if (topic == null) return false; + return createOrFindTopics(topic).isCreatedOrExisting(topic.name()); + } + + /** + * Attempt to create the topics described by the given definitions, returning all of the names of those topics that + * were created by this request. Any existing topics with the same name are unchanged, and the names of such topics + * are excluded from the result. + *
+ * If multiple topic definitions have the same topic name, the last one with that name will be used. + *
            + * Apache Kafka added support for creating topics in 0.10.1.0, so this method works as expected with that and later versions. + * With brokers older than 0.10.1.0, this method is unable to create topics and always returns an empty set. + * + * @param topics the specifications of the topics + * @return the {@link TopicCreationResponse} with the names of the newly created and existing topics; + * never null but possibly empty + * @throws ConnectException if an error occurs, the operation takes too long, or the thread is interrupted while + * attempting to perform this operation + */ + public TopicCreationResponse createOrFindTopics(NewTopic... topics) { + Map topicsByName = new HashMap<>(); + if (topics != null) { + for (NewTopic topic : topics) { + if (topic != null) topicsByName.put(topic.name(), topic); + } + } + if (topicsByName.isEmpty()) return EMPTY_CREATION; + String bootstrapServers = bootstrapServers(); + String topicNameList = Utils.join(topicsByName.keySet(), "', '"); + + // Attempt to create any missing topics + CreateTopicsOptions args = new CreateTopicsOptions().validateOnly(false); + Map> newResults = admin.createTopics(topicsByName.values(), args).values(); + + // Iterate over each future so that we can handle individual failures like when some topics already exist + Set newlyCreatedTopicNames = new HashSet<>(); + Set existingTopicNames = new HashSet<>(); + for (Map.Entry> entry : newResults.entrySet()) { + String topic = entry.getKey(); + try { + entry.getValue().get(); + if (logCreation) { + log.info("Created topic {} on brokers at {}", topicsByName.get(topic), bootstrapServers); + } + newlyCreatedTopicNames.add(topic); + } catch (ExecutionException e) { + Throwable cause = e.getCause(); + if (cause instanceof TopicExistsException) { + log.debug("Found existing topic '{}' on the brokers at {}", topic, bootstrapServers); + existingTopicNames.add(topic); + continue; + } + if (cause instanceof UnsupportedVersionException) { + log.debug("Unable to create topic(s) '{}' since the brokers at {} do not support the CreateTopics API." + + " Falling back to assume topic(s) exist or will be auto-created by the broker.", + topicNameList, bootstrapServers); + return EMPTY_CREATION; + } + if (cause instanceof ClusterAuthorizationException) { + log.debug("Not authorized to create topic(s) '{}' upon the brokers {}." + + " Falling back to assume topic(s) exist or will be auto-created by the broker.", + topicNameList, bootstrapServers); + return EMPTY_CREATION; + } + if (cause instanceof TopicAuthorizationException) { + log.debug("Not authorized to create topic(s) '{}' upon the brokers {}." + + " Falling back to assume topic(s) exist or will be auto-created by the broker.", + topicNameList, bootstrapServers); + return EMPTY_CREATION; + } + if (cause instanceof InvalidConfigurationException) { + throw new ConnectException("Unable to create topic(s) '" + topicNameList + "': " + cause.getMessage(), + cause); + } + if (cause instanceof TimeoutException) { + // Timed out waiting for the operation to complete + throw new ConnectException("Timed out while checking for or creating topic(s) '" + topicNameList + "'." 
+ + " This could indicate a connectivity issue, unavailable topic partitions, or if" + + " this is your first use of the topic it may have taken too long to create.", cause); + } + throw new ConnectException("Error while attempting to create/find topic(s) '" + topicNameList + "'", e); + } catch (InterruptedException e) { + Thread.interrupted(); + throw new ConnectException("Interrupted while attempting to create/find topic(s) '" + topicNameList + "'", e); + } + } + return new TopicCreationResponse(newlyCreatedTopicNames, existingTopicNames); + } + + /** + * Attempt to fetch the descriptions of the given topics + * Apache Kafka added support for describing topics in 0.10.0.0, so this method works as expected with that and later versions. + * With brokers older than 0.10.0.0, this method is unable to describe topics and always returns an empty set. + * + * @param topics the topics to describe + * @return a map of topic names to topic descriptions of the topics that were requested; never null but possibly empty + * @throws RetriableException if a retriable error occurs, the operation takes too long, or the + * thread is interrupted while attempting to perform this operation + * @throws ConnectException if a non retriable error occurs + */ + public Map describeTopics(String... topics) { + if (topics == null) { + return Collections.emptyMap(); + } + String bootstrapServers = bootstrapServers(); + String topicNameList = String.join(", ", topics); + + Map> newResults = + admin.describeTopics(Arrays.asList(topics), new DescribeTopicsOptions()).topicNameValues(); + + // Iterate over each future so that we can handle individual failures like when some topics don't exist + Map existingTopics = new HashMap<>(); + newResults.forEach((topic, desc) -> { + try { + existingTopics.put(topic, desc.get()); + } catch (ExecutionException e) { + Throwable cause = e.getCause(); + if (cause instanceof UnknownTopicOrPartitionException) { + log.debug("Topic '{}' does not exist on the brokers at {}", topic, bootstrapServers); + return; + } + if (cause instanceof ClusterAuthorizationException || cause instanceof TopicAuthorizationException) { + String msg = String.format("Not authorized to describe topic(s) '%s' on the brokers %s", + topicNameList, bootstrapServers); + throw new ConnectException(msg, cause); + } + if (cause instanceof UnsupportedVersionException) { + String msg = String.format("Unable to describe topic(s) '%s' since the brokers " + + "at %s do not support the DescribeTopics API.", + topicNameList, bootstrapServers); + throw new ConnectException(msg, cause); + } + if (cause instanceof TimeoutException) { + // Timed out waiting for the operation to complete + throw new RetriableException("Timed out while describing topics '" + topicNameList + "'", cause); + } + throw new ConnectException("Error while attempting to describe topics '" + topicNameList + "'", e); + } catch (InterruptedException e) { + Thread.interrupted(); + throw new RetriableException("Interrupted while attempting to describe topics '" + topicNameList + "'", e); + } + }); + return existingTopics; + } + + /** + * Verify the named topic uses only compaction for the cleanup policy. 
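(Aside, not part of this patch: a sketch of how createOrFindTopics(...) above might be driven from a worker; the bootstrap address and topic settings are made-up values.)

// Hypothetical usage sketch; not part of the patch above.
import java.util.Collections;
import java.util.Map;

import org.apache.kafka.clients.admin.AdminClientConfig;
import org.apache.kafka.clients.admin.NewTopic;
import org.apache.kafka.connect.util.TopicAdmin;

public class CreateOrFindTopicsSketch {
    public static void main(String[] args) {
        Map<String, Object> adminConfig =
                Collections.singletonMap(AdminClientConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9092");
        NewTopic configTopic = TopicAdmin.defineTopic("connect-configs")
                .compacted()
                .partitions(1)
                .replicationFactor((short) 3)
                .build();
        try (TopicAdmin admin = new TopicAdmin(adminConfig)) {
            // Existing topics are left untouched and reported separately from newly created ones.
            TopicAdmin.TopicCreationResponse response = admin.createOrFindTopics(configTopic);
            System.out.println("created=" + response.createdTopics() + ", existing=" + response.existingTopics());
        }
    }
}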
+ * + * @param topic the name of the topic + * @param workerTopicConfig the name of the worker configuration that specifies the topic name + * @return true if the admin client could be used to verify the topic setting, or false if + * the verification could not be performed, likely because the admin client principal + * did not have the required permissions or because the broker was older than 0.11.0.0 + * @throws ConfigException if the actual topic setting did not match the required setting + */ + public boolean verifyTopicCleanupPolicyOnlyCompact(String topic, String workerTopicConfig, + String topicPurpose) { + Set cleanupPolicies = topicCleanupPolicy(topic); + if (cleanupPolicies.isEmpty()) { + log.info("Unable to use admin client to verify the cleanup policy of '{}' " + + "topic is '{}', either because the broker is an older " + + "version or because the Kafka principal used for Connect " + + "internal topics does not have the required permission to " + + "describe topic configurations.", topic, TopicConfig.CLEANUP_POLICY_COMPACT); + return false; + } + Set expectedPolicies = Collections.singleton(TopicConfig.CLEANUP_POLICY_COMPACT); + if (!cleanupPolicies.equals(expectedPolicies)) { + String expectedPolicyStr = String.join(",", expectedPolicies); + String cleanupPolicyStr = String.join(",", cleanupPolicies); + String msg = String.format("Topic '%s' supplied via the '%s' property is required " + + "to have '%s=%s' to guarantee consistency and durability of " + + "%s, but found the topic currently has '%s=%s'. Continuing would likely " + + "result in eventually losing %s and problems restarting this Connect " + + "cluster in the future. Change the '%s' property in the " + + "Connect worker configurations to use a topic with '%s=%s'.", + topic, workerTopicConfig, TopicConfig.CLEANUP_POLICY_CONFIG, expectedPolicyStr, + topicPurpose, TopicConfig.CLEANUP_POLICY_CONFIG, cleanupPolicyStr, topicPurpose, + workerTopicConfig, TopicConfig.CLEANUP_POLICY_CONFIG, expectedPolicyStr); + throw new ConfigException(msg); + } + return true; + } + + /** + * Get the cleanup policy for a topic. + * + * @param topic the name of the topic + * @return the set of cleanup policies set for the topic; may be empty if the topic does not + * exist or the topic's cleanup policy could not be retrieved + */ + public Set topicCleanupPolicy(String topic) { + Config topicConfig = describeTopicConfig(topic); + if (topicConfig == null) { + // The topic must not exist + log.debug("Unable to find topic '{}' when getting cleanup policy", topic); + return Collections.emptySet(); + } + ConfigEntry entry = topicConfig.get(CLEANUP_POLICY_CONFIG); + if (entry != null && entry.value() != null) { + String policyStr = entry.value(); + log.debug("Found cleanup.policy={} for topic '{}'", policyStr, topic); + return Arrays.stream(policyStr.split(",")) + .map(String::trim) + .filter(s -> !s.isEmpty()) + .map(String::toLowerCase) + .collect(Collectors.toSet()); + } + // This is unexpected, as the topic config should include the cleanup.policy even if + // the topic settings don't override the broker's log.cleanup.policy. But just to be safe. + log.debug("Found no cleanup.policy for topic '{}'", topic); + return Collections.emptySet(); + } + + /** + * Attempt to fetch the topic configuration for the given topic. + * Apache Kafka added support for describing topic configurations in 0.11.0.0, so this method + * works as expected with that and later versions. 
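(Aside, not part of this patch: the comma-separated parsing applied by topicCleanupPolicy(...) above, isolated as a small sketch with a made-up policy string.)

// Hypothetical sketch; mirrors the parsing in topicCleanupPolicy, not part of the patch.
import java.util.Arrays;
import java.util.Set;
import java.util.stream.Collectors;

public class CleanupPolicySketch {
    static Set<String> parseCleanupPolicy(String value) {
        return Arrays.stream(value.split(","))
                .map(String::trim)
                .filter(s -> !s.isEmpty())
                .map(String::toLowerCase)
                .collect(Collectors.toSet());
    }

    public static void main(String[] args) {
        // "compact,delete" is not compact-only, so verifyTopicCleanupPolicyOnlyCompact would reject it.
        System.out.println(parseCleanupPolicy("compact,delete"));
    }
}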
With brokers older than 0.11.0.0, this method + * is unable to get the topic configurations and always returns a null value. + * + *
            If the topic does not exist, a null value is returned. + * + * @param topic the name of the topic for which the topic configuration should be obtained + * @return the topic configuration if the topic exists, or null if the topic did not exist + * @throws RetriableException if a retriable error occurs, the operation takes too long, or the + * thread is interrupted while attempting to perform this operation + * @throws ConnectException if a non retriable error occurs + */ + public Config describeTopicConfig(String topic) { + return describeTopicConfigs(topic).get(topic); + } + + /** + * Attempt to fetch the topic configurations for the given topics. + * Apache Kafka added support for describing topic configurations in 0.11.0.0, so this method + * works as expected with that and later versions. With brokers older than 0.11.0.0, this method + * is unable get the topic configurations and always returns an empty set. + * + *
            An entry with a null Config is placed into the resulting map for any topic that does + * not exist on the brokers. + * + * @param topicNames the topics to obtain configurations + * @return the map of topic configurations for each existing topic, or an empty map if none + * of the topics exist + * @throws RetriableException if a retriable error occurs, the operation takes too long, or the + * thread is interrupted while attempting to perform this operation + * @throws ConnectException if a non retriable error occurs + */ + public Map describeTopicConfigs(String... topicNames) { + if (topicNames == null) { + return Collections.emptyMap(); + } + Collection topics = Arrays.stream(topicNames) + .filter(Objects::nonNull) + .map(String::trim) + .filter(s -> !s.isEmpty()) + .collect(Collectors.toList()); + if (topics.isEmpty()) { + return Collections.emptyMap(); + } + String bootstrapServers = bootstrapServers(); + String topicNameList = topics.stream().collect(Collectors.joining(", ")); + Collection resources = topics.stream() + .map(t -> new ConfigResource(ConfigResource.Type.TOPIC, t)) + .collect(Collectors.toList()); + + Map> newResults = admin.describeConfigs(resources, new DescribeConfigsOptions()).values(); + + // Iterate over each future so that we can handle individual failures like when some topics don't exist + Map result = new HashMap<>(); + newResults.forEach((resource, configs) -> { + String topic = resource.name(); + try { + result.put(topic, configs.get()); + } catch (ExecutionException e) { + Throwable cause = e.getCause(); + if (cause instanceof UnknownTopicOrPartitionException) { + log.debug("Topic '{}' does not exist on the brokers at {}", topic, bootstrapServers); + result.put(topic, null); + } else if (cause instanceof ClusterAuthorizationException || cause instanceof TopicAuthorizationException) { + log.debug("Not authorized to describe topic config for topic '{}' on brokers at {}", topic, bootstrapServers); + } else if (cause instanceof UnsupportedVersionException) { + log.debug("API to describe topic config for topic '{}' is unsupported on brokers at {}", topic, bootstrapServers); + } else if (cause instanceof TimeoutException) { + String msg = String.format("Timed out while waiting to describe topic config for topic '%s' on brokers at %s", + topic, bootstrapServers); + throw new RetriableException(msg, e); + } else { + String msg = String.format("Error while attempting to describe topic config for topic '%s' on brokers at %s", + topic, bootstrapServers); + throw new ConnectException(msg, e); + } + } catch (InterruptedException e) { + Thread.interrupted(); + String msg = String.format("Interrupted while attempting to describe topic configs '%s'", topicNameList); + throw new RetriableException(msg, e); + } + }); + return result; + } + + /** + * Fetch the most recent offset for each of the supplied {@link TopicPartition} objects. 
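(Aside, not part of this patch: the raw Admin listOffsets pattern that endOffsets(...) below wraps; the bootstrap address, topic, and partition are made-up values.)

// Hypothetical usage sketch; not part of the patch above.
import java.util.Collections;
import java.util.Map;
import java.util.Properties;

import org.apache.kafka.clients.admin.Admin;
import org.apache.kafka.clients.admin.AdminClientConfig;
import org.apache.kafka.clients.admin.ListOffsetsResult;
import org.apache.kafka.clients.admin.OffsetSpec;
import org.apache.kafka.common.TopicPartition;

public class EndOffsetsSketch {
    public static void main(String[] args) throws Exception {
        Properties props = new Properties();
        props.put(AdminClientConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9092");
        try (Admin admin = Admin.create(props)) {
            TopicPartition tp = new TopicPartition("connect-offsets", 0);
            Map<TopicPartition, OffsetSpec> request = Collections.singletonMap(tp, OffsetSpec.latest());
            ListOffsetsResult result = admin.listOffsets(request);
            long endOffset = result.partitionResult(tp).get().offset();
            System.out.println(tp + " end offset = " + endOffset);
        }
    }
}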
+ * + * @param partitions the topic partitions + * @return the map of offset for each topic partition, or an empty map if the supplied partitions + * are null or empty + * @throws UnsupportedVersionException if the admin client cannot read end offsets + * @throws TimeoutException if the offset metadata could not be fetched before the amount of time allocated + * by {@code request.timeout.ms} expires, and this call can be retried + * @throws LeaderNotAvailableException if the leader was not available and this call can be retried + * @throws RetriableException if a retriable error occurs, or the thread is interrupted while attempting + * to perform this operation + * @throws ConnectException if a non retriable error occurs + */ + public Map endOffsets(Set partitions) { + if (partitions == null || partitions.isEmpty()) { + return Collections.emptyMap(); + } + Map offsetSpecMap = partitions.stream().collect(Collectors.toMap(Function.identity(), tp -> OffsetSpec.latest())); + ListOffsetsResult resultFuture = admin.listOffsets(offsetSpecMap); + // Get the individual result for each topic partition so we have better error messages + Map result = new HashMap<>(); + for (TopicPartition partition : partitions) { + try { + ListOffsetsResultInfo info = resultFuture.partitionResult(partition).get(); + result.put(partition, info.offset()); + } catch (ExecutionException e) { + Throwable cause = e.getCause(); + String topic = partition.topic(); + if (cause instanceof AuthorizationException) { + String msg = String.format("Not authorized to get the end offsets for topic '%s' on brokers at %s", topic, bootstrapServers()); + throw new ConnectException(msg, e); + } else if (cause instanceof UnsupportedVersionException) { + // Should theoretically never happen, because this method is the same as what the consumer uses and therefore + // should exist in the broker since before the admin client was added + String msg = String.format("API to get the get the end offsets for topic '%s' is unsupported on brokers at %s", topic, bootstrapServers()); + throw new UnsupportedVersionException(msg, e); + } else if (cause instanceof TimeoutException) { + String msg = String.format("Timed out while waiting to get end offsets for topic '%s' on brokers at %s", topic, bootstrapServers()); + throw new TimeoutException(msg, e); + } else if (cause instanceof LeaderNotAvailableException) { + String msg = String.format("Unable to get end offsets during leader election for topic '%s' on brokers at %s", topic, bootstrapServers()); + throw new LeaderNotAvailableException(msg, e); + } else if (cause instanceof org.apache.kafka.common.errors.RetriableException) { + throw (org.apache.kafka.common.errors.RetriableException) cause; + } else { + String msg = String.format("Error while getting end offsets for topic '%s' on brokers at %s", topic, bootstrapServers()); + throw new ConnectException(msg, e); + } + } catch (InterruptedException e) { + Thread.interrupted(); + String msg = String.format("Interrupted while attempting to read end offsets for topic '%s' on brokers at %s", partition.topic(), bootstrapServers()); + throw new RetriableException(msg, e); + } + } + return result; + } + + @Override + public void close() { + admin.close(); + } + + public void close(Duration timeout) { + admin.close(timeout); + } + + private String bootstrapServers() { + Object servers = adminConfig.get(AdminClientConfig.BOOTSTRAP_SERVERS_CONFIG); + return servers != null ? 
servers.toString() : ""; + } +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/util/TopicCreation.java b/connect/runtime/src/main/java/org/apache/kafka/connect/util/TopicCreation.java new file mode 100644 index 0000000..f914ed9 --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/util/TopicCreation.java @@ -0,0 +1,148 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.util; + +import org.apache.kafka.connect.runtime.WorkerConfig; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Collections; +import java.util.HashSet; +import java.util.LinkedHashMap; +import java.util.Map; +import java.util.Set; + +import static org.apache.kafka.connect.runtime.TopicCreationConfig.DEFAULT_TOPIC_CREATION_GROUP; + +/** + * Utility to be used by worker source tasks in order to create topics, if topic creation is + * enabled for source connectors at the worker and the connector configurations. + */ +public class TopicCreation { + private static final Logger log = LoggerFactory.getLogger(TopicCreation.class); + private static final TopicCreation EMPTY = + new TopicCreation(false, null, Collections.emptyMap(), Collections.emptySet()); + + private final boolean isTopicCreationEnabled; + private final TopicCreationGroup defaultTopicGroup; + private final Map topicGroups; + private final Set topicCache; + + protected TopicCreation(boolean isTopicCreationEnabled, + TopicCreationGroup defaultTopicGroup, + Map topicGroups, + Set topicCache) { + this.isTopicCreationEnabled = isTopicCreationEnabled; + this.defaultTopicGroup = defaultTopicGroup; + this.topicGroups = topicGroups; + this.topicCache = topicCache; + } + + public static TopicCreation newTopicCreation(WorkerConfig workerConfig, + Map topicGroups) { + if (!workerConfig.topicCreationEnable() || topicGroups == null) { + return EMPTY; + } + Map groups = new LinkedHashMap<>(topicGroups); + groups.remove(DEFAULT_TOPIC_CREATION_GROUP); + return new TopicCreation(true, topicGroups.get(DEFAULT_TOPIC_CREATION_GROUP), groups, new HashSet<>()); + } + + /** + * Return an instance of this utility that represents what the state of the internal data + * structures should be when topic creation is disabled. + * + * @return the utility when topic creation is disabled + */ + public static TopicCreation empty() { + return EMPTY; + } + + /** + * Check whether topic creation is enabled for this utility instance. This is state is set at + * instantiation time and remains unchanged for the lifetime of every {@link TopicCreation} + * object. 
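(Aside, not part of this patch: a rough sketch of how a worker source task could combine TopicCreation with TopicAdmin before writing to a topic; the helper name and the pre-configured 'topicCreation' and 'admin' arguments are assumptions.)

// Hypothetical sketch; not part of the patch above.
import org.apache.kafka.clients.admin.NewTopic;
import org.apache.kafka.connect.util.TopicAdmin;
import org.apache.kafka.connect.util.TopicCreation;
import org.apache.kafka.connect.util.TopicCreationGroup;

public class TopicCreationFlowSketch {
    static void maybeCreateTopic(TopicCreation topicCreation, TopicAdmin admin, String topic) {
        if (!topicCreation.isTopicCreationRequired(topic)) {
            return; // topic creation is disabled, or this topic was already handled
        }
        TopicCreationGroup group = topicCreation.findFirstGroup(topic);
        NewTopic newTopic = group.newTopic(topic);
        if (admin.createOrFindTopic(newTopic)) {
            topicCreation.addTopic(topic); // skip the admin round trip next time
        }
    }
}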
+ * + * @return true if topic creation is enabled; false otherwise + */ + public boolean isTopicCreationEnabled() { + return isTopicCreationEnabled; + } + + /** + * Check whether topic creation may be required for a specific topic name. + * + * @return true if topic creation is enabled and the topic name is not in the topic cache; + * false otherwise + */ + public boolean isTopicCreationRequired(String topic) { + return isTopicCreationEnabled && !topicCache.contains(topic); + } + + /** + * Return the default topic creation group. This group is always defined when topic creation is + * enabled but is {@code null} if topic creation is disabled. + * + * @return the default topic creation group if topic creation is enabled; {@code null} otherwise + */ + public TopicCreationGroup defaultTopicGroup() { + return defaultTopicGroup; + } + + /** + * Return the topic creation groups defined for a source connector as a map of topic creation + * group name to topic creation group instance. This map maintains all the optionally defined + * groups besides the default group which is defined for any connector when topic creation is + * enabled. + * + * @return the map of all the topic creation groups besides the default group; may be empty + * but not {@code null} + */ + public Map topicGroups() { + return topicGroups; + } + + /** + * Inform this utility instance that a topic has been created and its creation will no + * longer be required. After {@link #addTopic(String)} is called for a give {@param topic} + * any subsequent calls to {@link #isTopicCreationRequired} will return {@code false} for the + * same topic. + * + * @param topic the topic name to mark as created + */ + public void addTopic(String topic) { + if (isTopicCreationEnabled) { + topicCache.add(topic); + } + } + + /** + * Get the first topic creation group that is configured to match the given {@param topic} + * name. If topic creation is enabled, any topic should match at least the default topic + * creation group. + * + * @param topic the topic name to match against group configurations + * + * @return the first group that matches the given topic + */ + public TopicCreationGroup findFirstGroup(String topic) { + return topicGroups.values().stream() + .filter(group -> group.matches(topic)) + .findFirst() + .orElse(defaultTopicGroup); + } +} diff --git a/connect/runtime/src/main/java/org/apache/kafka/connect/util/TopicCreationGroup.java b/connect/runtime/src/main/java/org/apache/kafka/connect/util/TopicCreationGroup.java new file mode 100644 index 0000000..1197033 --- /dev/null +++ b/connect/runtime/src/main/java/org/apache/kafka/connect/util/TopicCreationGroup.java @@ -0,0 +1,151 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.connect.util; + +import org.apache.kafka.clients.admin.Admin; +import org.apache.kafka.clients.admin.NewTopic; +import org.apache.kafka.connect.runtime.SourceConnectorConfig; + +import java.util.Collections; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.regex.Pattern; + +import static org.apache.kafka.connect.runtime.SourceConnectorConfig.TOPIC_CREATION_GROUPS_CONFIG; +import static org.apache.kafka.connect.runtime.TopicCreationConfig.DEFAULT_TOPIC_CREATION_GROUP; + +/** + * Utility to simplify creating and managing topics via the {@link Admin}. + */ +public class TopicCreationGroup { + private final String name; + private final Pattern inclusionPattern; + private final Pattern exclusionPattern; + private final int numPartitions; + private final short replicationFactor; + private final Map otherConfigs; + + protected TopicCreationGroup(String group, SourceConnectorConfig config) { + this.name = group; + this.inclusionPattern = Pattern.compile(String.join("|", config.topicCreationInclude(group))); + this.exclusionPattern = Pattern.compile(String.join("|", config.topicCreationExclude(group))); + this.numPartitions = config.topicCreationPartitions(group); + this.replicationFactor = config.topicCreationReplicationFactor(group); + this.otherConfigs = config.topicCreationOtherConfigs(group); + } + + /** + * Parses the configuration of a source connector and returns the topic creation groups + * defined in the given configuration as a map of group names to {@link TopicCreation} objects. + * + * @param config the source connector configuration + * + * @return the map of topic creation groups; may be empty but not {@code null} + */ + public static Map configuredGroups(SourceConnectorConfig config) { + if (!config.usesTopicCreation()) { + return Collections.emptyMap(); + } + List groupNames = config.getList(TOPIC_CREATION_GROUPS_CONFIG); + Map groups = new LinkedHashMap<>(); + for (String group : groupNames) { + groups.put(group, new TopicCreationGroup(group, config)); + } + // Even if there was a group called 'default' in the config, it will be overridden here. + // Order matters for all the topic groups besides the default, since it will be + // removed from this collection by the Worker + groups.put(DEFAULT_TOPIC_CREATION_GROUP, new TopicCreationGroup(DEFAULT_TOPIC_CREATION_GROUP, config)); + return groups; + } + + /** + * Return the name of the topic creation group. + * + * @return the name of the topic creation group + */ + public String name() { + return name; + } + + /** + * Answer whether this topic creation group is configured to allow the creation of the given + * {@param topic} name. + * + * @param topic the topic name to check against the groups configuration + * + * @return true if the topic name matches the inclusion regex and does + * not match the exclusion regex of this group's configuration; false otherwise + */ + public boolean matches(String topic) { + return !exclusionPattern.matcher(topic).matches() && inclusionPattern.matcher(topic) + .matches(); + } + + /** + * Return the description for a new topic with the given {@param topic} name with the topic + * settings defined for this topic creation group. 
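(Aside, not part of this patch: the include/exclude regex semantics used by matches(...) above, demonstrated with made-up patterns and topic names.)

// Hypothetical sketch; not part of the patch above.
import java.util.regex.Pattern;

public class TopicMatchSketch {
    public static void main(String[] args) {
        // e.g. an inclusion regex of "metrics-.*" and an exclusion regex of ".*-internal"
        Pattern inclusion = Pattern.compile(String.join("|", "metrics-.*"));
        Pattern exclusion = Pattern.compile(String.join("|", ".*-internal"));

        String[] topics = {"metrics-cpu", "metrics-cpu-internal", "logs-app"};
        for (String topic : topics) {
            boolean matches = !exclusion.matcher(topic).matches() && inclusion.matcher(topic).matches();
            System.out.println(topic + " -> " + matches); // true, false, false
        }
    }
}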
+ * + * @param topic the name of the topic to be created + * + * @return the topic description of the given topic with settings of this topic creation group + */ + public NewTopic newTopic(String topic) { + TopicAdmin.NewTopicBuilder builder = new TopicAdmin.NewTopicBuilder(topic); + return builder.partitions(numPartitions) + .replicationFactor(replicationFactor) + .config(otherConfigs) + .build(); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (!(o instanceof TopicCreationGroup)) { + return false; + } + TopicCreationGroup that = (TopicCreationGroup) o; + return Objects.equals(name, that.name) + && numPartitions == that.numPartitions + && replicationFactor == that.replicationFactor + && Objects.equals(inclusionPattern.pattern(), that.inclusionPattern.pattern()) + && Objects.equals(exclusionPattern.pattern(), that.exclusionPattern.pattern()) + && Objects.equals(otherConfigs, that.otherConfigs); + } + + @Override + public int hashCode() { + return Objects.hash(name, numPartitions, replicationFactor, inclusionPattern.pattern(), + exclusionPattern.pattern(), otherConfigs + ); + } + + @Override + public String toString() { + return "TopicCreationGroup{" + + "name='" + name + '\'' + + ", inclusionPattern=" + inclusionPattern + + ", exclusionPattern=" + exclusionPattern + + ", numPartitions=" + numPartitions + + ", replicationFactor=" + replicationFactor + + ", otherConfigs=" + otherConfigs + + '}'; + } +} diff --git a/connect/runtime/src/main/resources/META-INF/services/org.apache.kafka.connect.connector.policy.ConnectorClientConfigOverridePolicy b/connect/runtime/src/main/resources/META-INF/services/org.apache.kafka.connect.connector.policy.ConnectorClientConfigOverridePolicy new file mode 100644 index 0000000..8b76ce4 --- /dev/null +++ b/connect/runtime/src/main/resources/META-INF/services/org.apache.kafka.connect.connector.policy.ConnectorClientConfigOverridePolicy @@ -0,0 +1,18 @@ + # Licensed to the Apache Software Foundation (ASF) under one or more + # contributor license agreements. See the NOTICE file distributed with + # this work for additional information regarding copyright ownership. + # The ASF licenses this file to You under the Apache License, Version 2.0 + # (the "License"); you may not use this file except in compliance with + # the License. You may obtain a copy of the License at + # + # http://www.apache.org/licenses/LICENSE-2.0 + # + # Unless required by applicable law or agreed to in writing, software + # distributed under the License is distributed on an "AS IS" BASIS, + # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + # See the License for the specific language governing permissions and + # limitations under the License. 
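(Aside, not part of this patch: the provider entries that follow this header are standard java.util.ServiceLoader declarations; a plain JVM could discover them roughly as sketched below, though Connect layers its own plugin class-loading on top of this mechanism.)

// Hypothetical sketch; not part of the patch above.
import java.util.ServiceLoader;

import org.apache.kafka.connect.connector.policy.ConnectorClientConfigOverridePolicy;

public class PolicyDiscoverySketch {
    public static void main(String[] args) {
        ServiceLoader<ConnectorClientConfigOverridePolicy> policies =
                ServiceLoader.load(ConnectorClientConfigOverridePolicy.class);
        for (ConnectorClientConfigOverridePolicy policy : policies) {
            System.out.println(policy.getClass().getName());
        }
    }
}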
+ +org.apache.kafka.connect.connector.policy.AllConnectorClientConfigOverridePolicy +org.apache.kafka.connect.connector.policy.PrincipalConnectorClientConfigOverridePolicy +org.apache.kafka.connect.connector.policy.NoneConnectorClientConfigOverridePolicy \ No newline at end of file diff --git a/connect/runtime/src/test/java/org/apache/kafka/connect/connector/policy/BaseConnectorClientConfigOverridePolicyTest.java b/connect/runtime/src/test/java/org/apache/kafka/connect/connector/policy/BaseConnectorClientConfigOverridePolicyTest.java new file mode 100644 index 0000000..28fee73 --- /dev/null +++ b/connect/runtime/src/test/java/org/apache/kafka/connect/connector/policy/BaseConnectorClientConfigOverridePolicyTest.java @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.connect.connector.policy; + +import org.apache.kafka.common.config.ConfigValue; +import org.apache.kafka.connect.health.ConnectorType; +import org.apache.kafka.connect.runtime.WorkerTest; +import org.junit.Assert; + +import java.util.List; +import java.util.Map; + +public abstract class BaseConnectorClientConfigOverridePolicyTest { + + protected abstract ConnectorClientConfigOverridePolicy policyToTest(); + + protected void testValidOverride(Map clientConfig) { + List configValues = configValues(clientConfig); + assertNoError(configValues); + } + + protected void testInvalidOverride(Map clientConfig) { + List configValues = configValues(clientConfig); + assertError(configValues); + } + + private List configValues(Map clientConfig) { + ConnectorClientConfigRequest connectorClientConfigRequest = new ConnectorClientConfigRequest( + "test", + ConnectorType.SOURCE, + WorkerTest.WorkerTestConnector.class, + clientConfig, + ConnectorClientConfigRequest.ClientType.PRODUCER); + return policyToTest().validate(connectorClientConfigRequest); + } + + protected void assertNoError(List configValues) { + Assert.assertTrue(configValues.stream().allMatch(configValue -> configValue.errorMessages().size() == 0)); + } + + protected void assertError(List configValues) { + Assert.assertTrue(configValues.stream().anyMatch(configValue -> configValue.errorMessages().size() > 0)); + } +} diff --git a/connect/runtime/src/test/java/org/apache/kafka/connect/connector/policy/NoneConnectorClientConfigOverridePolicyTest.java b/connect/runtime/src/test/java/org/apache/kafka/connect/connector/policy/NoneConnectorClientConfigOverridePolicyTest.java new file mode 100644 index 0000000..2c7b078 --- /dev/null +++ b/connect/runtime/src/test/java/org/apache/kafka/connect/connector/policy/NoneConnectorClientConfigOverridePolicyTest.java @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.connect.connector.policy; + +import org.apache.kafka.clients.producer.ProducerConfig; +import org.apache.kafka.common.config.SaslConfigs; +import org.junit.Test; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +public class NoneConnectorClientConfigOverridePolicyTest extends BaseConnectorClientConfigOverridePolicyTest { + + ConnectorClientConfigOverridePolicy noneConnectorClientConfigOverridePolicy = new NoneConnectorClientConfigOverridePolicy(); + + @Test + public void testNoOverrides() { + testValidOverride(Collections.emptyMap()); + } + + @Test + public void testWithOverrides() { + Map clientConfig = new HashMap<>(); + clientConfig.put(SaslConfigs.SASL_JAAS_CONFIG, "test"); + clientConfig.put(ProducerConfig.ACKS_CONFIG, "none"); + testInvalidOverride(clientConfig); + } + + @Override + protected ConnectorClientConfigOverridePolicy policyToTest() { + return noneConnectorClientConfigOverridePolicy; + } +} diff --git a/connect/runtime/src/test/java/org/apache/kafka/connect/connector/policy/PrincipalConnectorClientConfigOverridePolicyTest.java b/connect/runtime/src/test/java/org/apache/kafka/connect/connector/policy/PrincipalConnectorClientConfigOverridePolicyTest.java new file mode 100644 index 0000000..0e79c8a --- /dev/null +++ b/connect/runtime/src/test/java/org/apache/kafka/connect/connector/policy/PrincipalConnectorClientConfigOverridePolicyTest.java @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.connect.connector.policy; + +import org.apache.kafka.clients.producer.ProducerConfig; +import org.apache.kafka.common.config.SaslConfigs; +import org.junit.Test; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +public class PrincipalConnectorClientConfigOverridePolicyTest extends BaseConnectorClientConfigOverridePolicyTest { + + ConnectorClientConfigOverridePolicy principalConnectorClientConfigOverridePolicy = new PrincipalConnectorClientConfigOverridePolicy(); + + @Test + public void testPrincipalOnly() { + Map clientConfig = Collections.singletonMap(SaslConfigs.SASL_JAAS_CONFIG, "test"); + testValidOverride(clientConfig); + } + + @Test + public void testPrincipalPlusOtherConfigs() { + Map clientConfig = new HashMap<>(); + clientConfig.put(SaslConfigs.SASL_JAAS_CONFIG, "test"); + clientConfig.put(ProducerConfig.ACKS_CONFIG, "none"); + testInvalidOverride(clientConfig); + } + + @Override + protected ConnectorClientConfigOverridePolicy policyToTest() { + return principalConnectorClientConfigOverridePolicy; + } +} diff --git a/connect/runtime/src/test/java/org/apache/kafka/connect/converters/ByteArrayConverterTest.java b/connect/runtime/src/test/java/org/apache/kafka/connect/converters/ByteArrayConverterTest.java new file mode 100644 index 0000000..f14b76f --- /dev/null +++ b/connect/runtime/src/test/java/org/apache/kafka/connect/converters/ByteArrayConverterTest.java @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.connect.converters; + +import org.apache.kafka.connect.data.Schema; +import org.apache.kafka.connect.data.SchemaAndValue; +import org.apache.kafka.connect.errors.DataException; +import org.junit.Before; +import org.junit.Test; + +import java.nio.charset.StandardCharsets; +import java.util.Arrays; +import java.util.Collections; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; + +public class ByteArrayConverterTest { + private static final String TOPIC = "topic"; + private static final byte[] SAMPLE_BYTES = "sample string".getBytes(StandardCharsets.UTF_8); + + private ByteArrayConverter converter = new ByteArrayConverter(); + + @Before + public void setUp() { + converter.configure(Collections.emptyMap(), false); + } + + @Test + public void testFromConnect() { + assertArrayEquals( + SAMPLE_BYTES, + converter.fromConnectData(TOPIC, Schema.BYTES_SCHEMA, SAMPLE_BYTES) + ); + } + + @Test + public void testFromConnectSchemaless() { + assertArrayEquals( + SAMPLE_BYTES, + converter.fromConnectData(TOPIC, null, SAMPLE_BYTES) + ); + } + + @Test + public void testFromConnectBadSchema() { + assertThrows(DataException.class, + () -> converter.fromConnectData(TOPIC, Schema.INT32_SCHEMA, SAMPLE_BYTES)); + } + + @Test + public void testFromConnectInvalidValue() { + assertThrows(DataException.class, + () -> converter.fromConnectData(TOPIC, Schema.BYTES_SCHEMA, 12)); + } + + @Test + public void testFromConnectNull() { + assertNull(converter.fromConnectData(TOPIC, Schema.BYTES_SCHEMA, null)); + } + + @Test + public void testToConnect() { + SchemaAndValue data = converter.toConnectData(TOPIC, SAMPLE_BYTES); + assertEquals(Schema.OPTIONAL_BYTES_SCHEMA, data.schema()); + assertTrue(Arrays.equals(SAMPLE_BYTES, (byte[]) data.value())); + } + + @Test + public void testToConnectNull() { + SchemaAndValue data = converter.toConnectData(TOPIC, null); + assertEquals(Schema.OPTIONAL_BYTES_SCHEMA, data.schema()); + assertNull(data.value()); + } +} diff --git a/connect/runtime/src/test/java/org/apache/kafka/connect/converters/DoubleConverterTest.java b/connect/runtime/src/test/java/org/apache/kafka/connect/converters/DoubleConverterTest.java new file mode 100644 index 0000000..acc3dde --- /dev/null +++ b/connect/runtime/src/test/java/org/apache/kafka/connect/converters/DoubleConverterTest.java @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.connect.converters; + +import org.apache.kafka.common.serialization.DoubleSerializer; +import org.apache.kafka.common.serialization.Serializer; +import org.apache.kafka.connect.data.Schema; + +public class DoubleConverterTest extends NumberConverterTest { + + public Double[] samples() { + return new Double[]{Double.MIN_VALUE, 1234.31, Double.MAX_VALUE}; + } + + @Override + protected Schema schema() { + return Schema.OPTIONAL_FLOAT64_SCHEMA; + } + + @Override + protected NumberConverter createConverter() { + return new DoubleConverter(); + } + + @Override + protected Serializer createSerializer() { + return new DoubleSerializer(); + } +} diff --git a/connect/runtime/src/test/java/org/apache/kafka/connect/converters/FloatConverterTest.java b/connect/runtime/src/test/java/org/apache/kafka/connect/converters/FloatConverterTest.java new file mode 100644 index 0000000..e95ff56 --- /dev/null +++ b/connect/runtime/src/test/java/org/apache/kafka/connect/converters/FloatConverterTest.java @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.converters; + +import org.apache.kafka.common.serialization.FloatSerializer; +import org.apache.kafka.common.serialization.Serializer; +import org.apache.kafka.connect.data.Schema; + +public class FloatConverterTest extends NumberConverterTest { + + public Float[] samples() { + return new Float[]{Float.MIN_VALUE, 1234.31f, Float.MAX_VALUE}; + } + + @Override + protected Schema schema() { + return Schema.OPTIONAL_FLOAT32_SCHEMA; + } + + @Override + protected NumberConverter createConverter() { + return new FloatConverter(); + } + + @Override + protected Serializer createSerializer() { + return new FloatSerializer(); + } +} diff --git a/connect/runtime/src/test/java/org/apache/kafka/connect/converters/IntegerConverterTest.java b/connect/runtime/src/test/java/org/apache/kafka/connect/converters/IntegerConverterTest.java new file mode 100644 index 0000000..0c9ed28 --- /dev/null +++ b/connect/runtime/src/test/java/org/apache/kafka/connect/converters/IntegerConverterTest.java @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.converters; + +import org.apache.kafka.common.serialization.IntegerSerializer; +import org.apache.kafka.common.serialization.Serializer; +import org.apache.kafka.connect.data.Schema; + +public class IntegerConverterTest extends NumberConverterTest { + + public Integer[] samples() { + return new Integer[]{Integer.MIN_VALUE, 1234, Integer.MAX_VALUE}; + } + + @Override + protected Schema schema() { + return Schema.OPTIONAL_INT32_SCHEMA; + } + + @Override + protected NumberConverter createConverter() { + return new IntegerConverter(); + } + + @Override + protected Serializer createSerializer() { + return new IntegerSerializer(); + } +} diff --git a/connect/runtime/src/test/java/org/apache/kafka/connect/converters/LongConverterTest.java b/connect/runtime/src/test/java/org/apache/kafka/connect/converters/LongConverterTest.java new file mode 100644 index 0000000..35d26b7 --- /dev/null +++ b/connect/runtime/src/test/java/org/apache/kafka/connect/converters/LongConverterTest.java @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.converters; + +import org.apache.kafka.common.serialization.LongSerializer; +import org.apache.kafka.common.serialization.Serializer; +import org.apache.kafka.connect.data.Schema; + +public class LongConverterTest extends NumberConverterTest { + + public Long[] samples() { + return new Long[]{Long.MIN_VALUE, 1234L, Long.MAX_VALUE}; + } + + @Override + protected Schema schema() { + return Schema.OPTIONAL_INT64_SCHEMA; + } + + @Override + protected NumberConverter createConverter() { + return new LongConverter(); + } + + @Override + protected Serializer createSerializer() { + return new LongSerializer(); + } +} diff --git a/connect/runtime/src/test/java/org/apache/kafka/connect/converters/NumberConverterTest.java b/connect/runtime/src/test/java/org/apache/kafka/connect/converters/NumberConverterTest.java new file mode 100644 index 0000000..d377c69 --- /dev/null +++ b/connect/runtime/src/test/java/org/apache/kafka/connect/converters/NumberConverterTest.java @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.converters; + +import org.apache.kafka.common.serialization.Serializer; +import org.apache.kafka.connect.data.Schema; +import org.apache.kafka.connect.data.SchemaAndValue; +import org.apache.kafka.connect.errors.DataException; +import org.junit.Before; +import org.junit.Test; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThrows; + +public abstract class NumberConverterTest { + private static final String TOPIC = "topic"; + private static final String HEADER_NAME = "header"; + + private T[] samples; + private Schema schema; + private NumberConverter converter; + private Serializer serializer; + + protected abstract T[] samples(); + + protected abstract NumberConverter createConverter(); + + protected abstract Serializer createSerializer(); + + protected abstract Schema schema(); + + @Before + public void setup() { + converter = createConverter(); + serializer = createSerializer(); + schema = schema(); + samples = samples(); + } + + @Test + public void testConvertingSamplesToAndFromBytes() throws UnsupportedOperationException { + for (T sample : samples) { + byte[] expected = serializer.serialize(TOPIC, sample); + + // Data conversion + assertArrayEquals(expected, converter.fromConnectData(TOPIC, schema, sample)); + SchemaAndValue data = converter.toConnectData(TOPIC, expected); + assertEquals(schema, data.schema()); + assertEquals(sample, data.value()); + + // Header conversion + assertArrayEquals(expected, converter.fromConnectHeader(TOPIC, HEADER_NAME, schema, sample)); + data = converter.toConnectHeader(TOPIC, HEADER_NAME, expected); + assertEquals(schema, data.schema()); + assertEquals(sample, data.value()); + } + } + + @Test + public void testDeserializingDataWithTooManyBytes() { + assertThrows(DataException.class, () -> converter.toConnectData(TOPIC, new byte[10])); + } + + @Test + public void testDeserializingHeaderWithTooManyBytes() { + assertThrows(DataException.class, () -> converter.toConnectHeader(TOPIC, HEADER_NAME, new byte[10])); + } + + @Test + public void testSerializingIncorrectType() { + assertThrows(DataException.class, () -> converter.fromConnectData(TOPIC, schema, "not a valid number")); + } + + @Test + public void testSerializingIncorrectHeader() { + assertThrows(DataException.class, + () -> converter.fromConnectHeader(TOPIC, HEADER_NAME, schema, "not a valid number")); + } + + @Test + public void testNullToBytes() { + assertNull(converter.fromConnectData(TOPIC, schema, null)); + } + + @Test + public void testBytesNullToNumber() { + SchemaAndValue data = converter.toConnectData(TOPIC, null); + assertEquals(schema(), data.schema()); + assertNull(data.value()); + } +} diff --git a/connect/runtime/src/test/java/org/apache/kafka/connect/converters/ShortConverterTest.java 
b/connect/runtime/src/test/java/org/apache/kafka/connect/converters/ShortConverterTest.java new file mode 100644 index 0000000..d1237c9 --- /dev/null +++ b/connect/runtime/src/test/java/org/apache/kafka/connect/converters/ShortConverterTest.java @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.converters; + +import org.apache.kafka.common.serialization.Serializer; +import org.apache.kafka.common.serialization.ShortSerializer; +import org.apache.kafka.connect.data.Schema; + +public class ShortConverterTest extends NumberConverterTest { + + public Short[] samples() { + return new Short[]{Short.MIN_VALUE, 123, Short.MAX_VALUE}; + } + + @Override + protected Schema schema() { + return Schema.OPTIONAL_INT16_SCHEMA; + } + + @Override + protected NumberConverter createConverter() { + return new ShortConverter(); + } + + @Override + protected Serializer createSerializer() { + return new ShortSerializer(); + } +} + diff --git a/connect/runtime/src/test/java/org/apache/kafka/connect/integration/BlockingConnectorTest.java b/connect/runtime/src/test/java/org/apache/kafka/connect/integration/BlockingConnectorTest.java new file mode 100644 index 0000000..8268de2 --- /dev/null +++ b/connect/runtime/src/test/java/org/apache/kafka/connect/integration/BlockingConnectorTest.java @@ -0,0 +1,799 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.connect.integration; + +import org.apache.kafka.clients.consumer.OffsetAndMetadata; +import org.apache.kafka.clients.producer.RecordMetadata; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.config.AbstractConfig; +import org.apache.kafka.common.config.Config; +import org.apache.kafka.common.config.ConfigDef; +import org.apache.kafka.connect.connector.Connector; +import org.apache.kafka.connect.connector.ConnectorContext; +import org.apache.kafka.connect.connector.Task; +import org.apache.kafka.connect.runtime.Worker; +import org.apache.kafka.connect.runtime.rest.errors.ConnectRestException; +import org.apache.kafka.connect.runtime.rest.resources.ConnectorsResource; +import org.apache.kafka.connect.sink.SinkConnector; +import org.apache.kafka.connect.sink.SinkRecord; +import org.apache.kafka.connect.sink.SinkTask; +import org.apache.kafka.connect.sink.SinkTaskContext; +import org.apache.kafka.connect.source.SourceConnector; +import org.apache.kafka.connect.source.SourceRecord; +import org.apache.kafka.connect.source.SourceTask; +import org.apache.kafka.connect.source.SourceTaskContext; +import org.apache.kafka.connect.util.clusters.EmbeddedConnectCluster; +import org.apache.kafka.test.IntegrationTest; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Properties; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +import static org.apache.kafka.connect.runtime.ConnectorConfig.CONNECTOR_CLASS_CONFIG; +import static org.apache.kafka.connect.runtime.ConnectorConfig.TASKS_MAX_CONFIG; +import static org.apache.kafka.connect.runtime.SinkConnectorConfig.TOPICS_CONFIG; +import static org.apache.kafka.test.TestUtils.waitForCondition; +import static org.junit.Assert.assertThrows; + +/** + * Tests situations during which certain connector operations, such as start, validation, + * configuration and others, take longer than expected. 
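+ *
+ * <p>Each test configures a connector (or task) that blocks indefinitely in a single callback and
+ * then checks that the worker stays usable by creating a separate, well-behaved connector.
+ * The general pattern, taken from the tests below, is:
+ * <pre>{@code
+ * createConnectorWithBlock(BlockingConnector.class, CONNECTOR_START);
+ * Block.waitForBlock();
+ * createNormalConnector();
+ * verifyNormalConnector();
+ * }</pre>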
+ */ +@Category(IntegrationTest.class) +public class BlockingConnectorTest { + + private static final Logger log = LoggerFactory.getLogger(BlockingConnectorTest.class); + + private static final int NUM_WORKERS = 1; + private static final String BLOCKING_CONNECTOR_NAME = "blocking-connector"; + private static final String NORMAL_CONNECTOR_NAME = "normal-connector"; + private static final String TEST_TOPIC = "normal-topic"; + private static final int NUM_RECORDS_PRODUCED = 100; + private static final long CONNECT_WORKER_STARTUP_TIMEOUT = TimeUnit.SECONDS.toMillis(60); + private static final long RECORD_TRANSFER_DURATION_MS = TimeUnit.SECONDS.toMillis(30); + private static final long REST_REQUEST_TIMEOUT = Worker.CONNECTOR_GRACEFUL_SHUTDOWN_TIMEOUT_MS * 2; + + private static final String CONNECTOR_INITIALIZE = "Connector::initialize"; + private static final String CONNECTOR_INITIALIZE_WITH_TASK_CONFIGS = "Connector::initializeWithTaskConfigs"; + private static final String CONNECTOR_START = "Connector::start"; + private static final String CONNECTOR_RECONFIGURE = "Connector::reconfigure"; + private static final String CONNECTOR_TASK_CLASS = "Connector::taskClass"; + private static final String CONNECTOR_TASK_CONFIGS = "Connector::taskConfigs"; + private static final String CONNECTOR_STOP = "Connector::stop"; + private static final String CONNECTOR_VALIDATE = "Connector::validate"; + private static final String CONNECTOR_CONFIG = "Connector::config"; + private static final String CONNECTOR_VERSION = "Connector::version"; + private static final String TASK_START = "Task::start"; + private static final String TASK_STOP = "Task::stop"; + private static final String TASK_VERSION = "Task::version"; + private static final String SINK_TASK_INITIALIZE = "SinkTask::initialize"; + private static final String SINK_TASK_PUT = "SinkTask::put"; + private static final String SINK_TASK_FLUSH = "SinkTask::flush"; + private static final String SINK_TASK_PRE_COMMIT = "SinkTask::preCommit"; + private static final String SINK_TASK_OPEN = "SinkTask::open"; + private static final String SINK_TASK_ON_PARTITIONS_ASSIGNED = "SinkTask::onPartitionsAssigned"; + private static final String SINK_TASK_CLOSE = "SinkTask::close"; + private static final String SINK_TASK_ON_PARTITIONS_REVOKED = "SinkTask::onPartitionsRevoked"; + private static final String SOURCE_TASK_INITIALIZE = "SourceTask::initialize"; + private static final String SOURCE_TASK_POLL = "SourceTask::poll"; + private static final String SOURCE_TASK_COMMIT = "SourceTask::commit"; + private static final String SOURCE_TASK_COMMIT_RECORD = "SourceTask::commitRecord"; + private static final String SOURCE_TASK_COMMIT_RECORD_WITH_METADATA = "SourceTask::commitRecordWithMetadata"; + + private EmbeddedConnectCluster connect; + private ConnectorHandle normalConnectorHandle; + + @Before + public void setup() throws Exception { + // Artificially reduce the REST request timeout so that these don't take forever + ConnectorsResource.setRequestTimeout(REST_REQUEST_TIMEOUT); + // build a Connect cluster backed by Kafka and Zk + connect = new EmbeddedConnectCluster.Builder() + .name("connect-cluster") + .numWorkers(NUM_WORKERS) + .numBrokers(1) + .workerProps(new HashMap<>()) + .brokerProps(new Properties()) + .build(); + + // start the clusters + connect.start(); + + // wait for the Connect REST API to become available. 
necessary because of the reduced REST + // request timeout; otherwise, we may get an unexpected 500 with our first real REST request + // if the worker is still getting on its feet. + waitForCondition( + () -> connect.requestGet(connect.endpointForResource("connectors/nonexistent")).getStatus() == 404, + CONNECT_WORKER_STARTUP_TIMEOUT, + "Worker did not complete startup in time" + ); + } + + @After + public void close() { + // stop all Connect, Kafka and Zk threads. + connect.stop(); + ConnectorsResource.resetRequestTimeout(); + Block.resetBlockLatch(); + } + + @Test + public void testBlockInConnectorValidate() throws Exception { + log.info("Starting test testBlockInConnectorValidate"); + assertThrows(ConnectRestException.class, () -> createConnectorWithBlock(ValidateBlockingConnector.class, CONNECTOR_VALIDATE)); + // Will NOT assert that connector has failed, since the request should fail before it's even created + + // Connector should already be blocked so this should return immediately, but check just to + // make sure that it actually did block + Block.waitForBlock(); + + createNormalConnector(); + verifyNormalConnector(); + } + + @Test + public void testBlockInConnectorConfig() throws Exception { + log.info("Starting test testBlockInConnectorConfig"); + assertThrows(ConnectRestException.class, () -> createConnectorWithBlock(ConfigBlockingConnector.class, CONNECTOR_CONFIG)); + // Will NOT assert that connector has failed, since the request should fail before it's even created + + // Connector should already be blocked so this should return immediately, but check just to + // make sure that it actually did block + Block.waitForBlock(); + + createNormalConnector(); + verifyNormalConnector(); + } + + @Test + public void testBlockInConnectorInitialize() throws Exception { + log.info("Starting test testBlockInConnectorInitialize"); + createConnectorWithBlock(InitializeBlockingConnector.class, CONNECTOR_INITIALIZE); + Block.waitForBlock(); + + createNormalConnector(); + verifyNormalConnector(); + } + + @Test + public void testBlockInConnectorStart() throws Exception { + log.info("Starting test testBlockInConnectorStart"); + createConnectorWithBlock(BlockingConnector.class, CONNECTOR_START); + Block.waitForBlock(); + + createNormalConnector(); + verifyNormalConnector(); + } + + @Test + public void testBlockInConnectorStop() throws Exception { + log.info("Starting test testBlockInConnectorStop"); + createConnectorWithBlock(BlockingConnector.class, CONNECTOR_STOP); + waitForConnectorStart(BLOCKING_CONNECTOR_NAME); + connect.deleteConnector(BLOCKING_CONNECTOR_NAME); + Block.waitForBlock(); + + createNormalConnector(); + verifyNormalConnector(); + } + + @Test + public void testBlockInSourceTaskStart() throws Exception { + log.info("Starting test testBlockInSourceTaskStart"); + createConnectorWithBlock(BlockingSourceConnector.class, TASK_START); + Block.waitForBlock(); + + createNormalConnector(); + verifyNormalConnector(); + } + + @Test + public void testBlockInSourceTaskStop() throws Exception { + log.info("Starting test testBlockInSourceTaskStop"); + createConnectorWithBlock(BlockingSourceConnector.class, TASK_STOP); + waitForConnectorStart(BLOCKING_CONNECTOR_NAME); + connect.deleteConnector(BLOCKING_CONNECTOR_NAME); + Block.waitForBlock(); + + createNormalConnector(); + verifyNormalConnector(); + } + + @Test + public void testBlockInSinkTaskStart() throws Exception { + log.info("Starting test testBlockInSinkTaskStart"); + createConnectorWithBlock(BlockingSinkConnector.class, TASK_START); + 
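// the sink task should reach the Block latch while blocking in Task::start; wait for that before moving on
+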
Block.waitForBlock(); + + createNormalConnector(); + verifyNormalConnector(); + } + + @Test + public void testBlockInSinkTaskStop() throws Exception { + log.info("Starting test testBlockInSinkTaskStop"); + createConnectorWithBlock(BlockingSinkConnector.class, TASK_STOP); + waitForConnectorStart(BLOCKING_CONNECTOR_NAME); + connect.deleteConnector(BLOCKING_CONNECTOR_NAME); + Block.waitForBlock(); + + createNormalConnector(); + verifyNormalConnector(); + } + + @Test + public void testWorkerRestartWithBlockInConnectorStart() throws Exception { + log.info("Starting test testWorkerRestartWithBlockInConnectorStart"); + createConnectorWithBlock(BlockingConnector.class, CONNECTOR_START); + // First instance of the connector should block on startup + Block.waitForBlock(); + createNormalConnector(); + connect.removeWorker(); + + connect.addWorker(); + // After stopping the only worker and restarting it, a new instance of the blocking + // connector should be created and we can ensure that it blocks again + Block.waitForBlock(); + verifyNormalConnector(); + } + + @Test + public void testWorkerRestartWithBlockInConnectorStop() throws Exception { + log.info("Starting test testWorkerRestartWithBlockInConnectorStop"); + createConnectorWithBlock(BlockingConnector.class, CONNECTOR_STOP); + waitForConnectorStart(BLOCKING_CONNECTOR_NAME); + createNormalConnector(); + waitForConnectorStart(NORMAL_CONNECTOR_NAME); + connect.removeWorker(); + Block.waitForBlock(); + + connect.addWorker(); + waitForConnectorStart(BLOCKING_CONNECTOR_NAME); + verifyNormalConnector(); + } + + private void createConnectorWithBlock(Class connectorClass, String block) { + Map props = new HashMap<>(); + props.put(CONNECTOR_CLASS_CONFIG, connectorClass.getName()); + props.put(TASKS_MAX_CONFIG, "1"); + props.put(TOPICS_CONFIG, "t1"); // Required for sink connectors + props.put(Block.BLOCK_CONFIG, Objects.requireNonNull(block)); + log.info("Creating blocking connector of type {} with block in {}", connectorClass.getSimpleName(), block); + try { + connect.configureConnector(BLOCKING_CONNECTOR_NAME, props); + } catch (RuntimeException e) { + log.info("Failed to create connector", e); + throw e; + } + } + + private void createNormalConnector() { + connect.kafka().createTopic(TEST_TOPIC, 3); + + normalConnectorHandle = RuntimeHandles.get().connectorHandle(NORMAL_CONNECTOR_NAME); + normalConnectorHandle.expectedRecords(NUM_RECORDS_PRODUCED); + normalConnectorHandle.expectedCommits(NUM_RECORDS_PRODUCED); + + Map props = new HashMap<>(); + props.put(CONNECTOR_CLASS_CONFIG, MonitorableSourceConnector.class.getName()); + props.put(TASKS_MAX_CONFIG, "1"); + props.put(MonitorableSourceConnector.TOPIC_CONFIG, TEST_TOPIC); + log.info("Creating normal connector"); + try { + connect.configureConnector(NORMAL_CONNECTOR_NAME, props); + } catch (RuntimeException e) { + log.info("Failed to create connector", e); + throw e; + } + } + + private void waitForConnectorStart(String connector) throws InterruptedException { + connect.assertions().assertConnectorAndAtLeastNumTasksAreRunning( + connector, + 0, + String.format( + "Failed to observe transition to 'RUNNING' state for connector '%s' in time", + connector + ) + ); + } + + private void verifyNormalConnector() throws InterruptedException { + waitForConnectorStart(NORMAL_CONNECTOR_NAME); + normalConnectorHandle.awaitRecords(RECORD_TRANSFER_DURATION_MS); + normalConnectorHandle.awaitCommits(RECORD_TRANSFER_DURATION_MS); + } + + private static class Block { + private static CountDownLatch blockLatch; + + 
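// the location (e.g. "Connector::start") at which this Block instance should block, taken from the 'block' config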
+        private final String block;
+
+        public static final String BLOCK_CONFIG = "block";
+
+        private static ConfigDef config() {
+            return new ConfigDef()
+                .define(
+                    BLOCK_CONFIG,
+                    ConfigDef.Type.STRING,
+                    "",
+                    ConfigDef.Importance.MEDIUM,
+                    "Where to block indefinitely, e.g., 'Connector::start', 'Connector::initialize', " +
+                        "'Connector::taskConfigs', 'Task::version', 'SinkTask::put', 'SourceTask::poll'"
+                );
+        }
+
+        public static void waitForBlock() throws InterruptedException {
+            synchronized (Block.class) {
+                if (blockLatch == null) {
+                    throw new IllegalArgumentException("No connector has been created yet");
+                }
+            }
+
+            log.debug("Waiting for connector to block");
+            blockLatch.await();
+            log.debug("Connector should now be blocked");
+        }
+
+        // Note that there is only ever at most one global block latch at a time, which makes tests that
+        // use blocks in multiple places impossible. If necessary, this can be addressed in the future by
+        // adding support for multiple block latches at a time, possibly identifiable by a connector/task
+        // ID, the location of the expected block, or both.
+        public static void resetBlockLatch() {
+            synchronized (Block.class) {
+                if (blockLatch != null) {
+                    blockLatch.countDown();
+                    blockLatch = null;
+                }
+            }
+        }
+
+        public Block(Map<String, String> props) {
+            this(new AbstractConfig(config(), props).getString(BLOCK_CONFIG));
+        }
+
+        public Block(String block) {
+            this.block = block;
+            synchronized (Block.class) {
+                if (blockLatch != null) {
+                    blockLatch.countDown();
+                }
+                blockLatch = new CountDownLatch(1);
+            }
+        }
+
+        public Map<String, String> taskConfig() {
+            return Collections.singletonMap(BLOCK_CONFIG, block);
+        }
+
+        public void maybeBlockOn(String block) {
+            if (block.equals(this.block)) {
+                log.info("Will block on {}", block);
+                blockLatch.countDown();
+                while (true) {
+                    try {
+                        Thread.sleep(Long.MAX_VALUE);
+                    } catch (InterruptedException e) {
+                        // No-op. Just keep blocking.
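+                        // swallowing the interrupt keeps the simulated hang in place even if the worker interrupts this thread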
+ } + } + } else { + log.debug("Will not block on {}", block); + } + } + } + + // Used to test blocks in Connector (as opposed to Task) methods + public static class BlockingConnector extends SourceConnector { + + private Block block; + + // No-args constructor required by the framework + public BlockingConnector() { + this(null); + } + + protected BlockingConnector(String block) { + this.block = new Block(block); + } + + @Override + public void initialize(ConnectorContext ctx) { + block.maybeBlockOn(CONNECTOR_INITIALIZE); + super.initialize(ctx); + } + + @Override + public void initialize(ConnectorContext ctx, List> taskConfigs) { + block.maybeBlockOn(CONNECTOR_INITIALIZE_WITH_TASK_CONFIGS); + super.initialize(ctx, taskConfigs); + } + + @Override + public void start(Map props) { + this.block = new Block(props); + block.maybeBlockOn(CONNECTOR_START); + } + + @Override + public void reconfigure(Map props) { + block.maybeBlockOn(CONNECTOR_RECONFIGURE); + super.reconfigure(props); + } + + @Override + public Class taskClass() { + block.maybeBlockOn(CONNECTOR_TASK_CLASS); + return BlockingTask.class; + } + + @Override + public List> taskConfigs(int maxTasks) { + block.maybeBlockOn(CONNECTOR_TASK_CONFIGS); + return Collections.singletonList(Collections.emptyMap()); + } + + @Override + public void stop() { + block.maybeBlockOn(CONNECTOR_STOP); + } + + @Override + public Config validate(Map connectorConfigs) { + block.maybeBlockOn(CONNECTOR_VALIDATE); + return super.validate(connectorConfigs); + } + + @Override + public ConfigDef config() { + block.maybeBlockOn(CONNECTOR_CONFIG); + return Block.config(); + } + + @Override + public String version() { + block.maybeBlockOn(CONNECTOR_VERSION); + return "0.0.0"; + } + + public static class BlockingTask extends SourceTask { + @Override + public void start(Map props) { + } + + @Override + public List poll() { + return null; + } + + @Override + public void stop() { + } + + @Override + public String version() { + return "0.0.0"; + } + } + } + + // Some methods are called before Connector::start, so we use this as a workaround + public static class InitializeBlockingConnector extends BlockingConnector { + public InitializeBlockingConnector() { + super(CONNECTOR_INITIALIZE); + } + } + + public static class ConfigBlockingConnector extends BlockingConnector { + public ConfigBlockingConnector() { + super(CONNECTOR_CONFIG); + } + } + + public static class ValidateBlockingConnector extends BlockingConnector { + public ValidateBlockingConnector() { + super(CONNECTOR_VALIDATE); + } + } + + // Used to test blocks in SourceTask methods + public static class BlockingSourceConnector extends SourceConnector { + + private Map props; + private final Class taskClass; + + // No-args constructor required by the framework + public BlockingSourceConnector() { + this(BlockingSourceTask.class); + } + + protected BlockingSourceConnector(Class taskClass) { + this.taskClass = taskClass; + } + + @Override + public void start(Map props) { + this.props = props; + } + + @Override + public Class taskClass() { + return taskClass; + } + + @Override + public List> taskConfigs(int maxTasks) { + return IntStream.range(0, maxTasks) + .mapToObj(i -> new HashMap<>(props)) + .collect(Collectors.toList()); + } + + @Override + public void stop() { + } + + @Override + public Config validate(Map connectorConfigs) { + return super.validate(connectorConfigs); + } + + @Override + public ConfigDef config() { + return Block.config(); + } + + @Override + public String version() { + return "0.0.0"; + } + + 
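// task counterpart of BlockingSourceConnector: blocks in whichever SourceTask method matches the configured block location
+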
public static class BlockingSourceTask extends SourceTask { + private Block block; + + // No-args constructor required by the framework + public BlockingSourceTask() { + this(null); + } + + protected BlockingSourceTask(String block) { + this.block = new Block(block); + } + + @Override + public void start(Map props) { + this.block = new Block(props); + block.maybeBlockOn(TASK_START); + } + + @Override + public List poll() { + block.maybeBlockOn(SOURCE_TASK_POLL); + return null; + } + + @Override + public void stop() { + block.maybeBlockOn(TASK_STOP); + } + + @Override + public String version() { + block.maybeBlockOn(TASK_VERSION); + return "0.0.0"; + } + + @Override + public void initialize(SourceTaskContext context) { + block.maybeBlockOn(SOURCE_TASK_INITIALIZE); + super.initialize(context); + } + + @Override + public void commit() throws InterruptedException { + block.maybeBlockOn(SOURCE_TASK_COMMIT); + super.commit(); + } + + @Override + @SuppressWarnings("deprecation") + public void commitRecord(SourceRecord record) throws InterruptedException { + block.maybeBlockOn(SOURCE_TASK_COMMIT_RECORD); + super.commitRecord(record); + } + + @Override + public void commitRecord(SourceRecord record, RecordMetadata metadata) throws InterruptedException { + block.maybeBlockOn(SOURCE_TASK_COMMIT_RECORD_WITH_METADATA); + super.commitRecord(record, metadata); + } + } + } + + public static class TaskInitializeBlockingSourceConnector extends BlockingSourceConnector { + public TaskInitializeBlockingSourceConnector() { + super(InitializeBlockingSourceTask.class); + } + + public static class InitializeBlockingSourceTask extends BlockingSourceTask { + public InitializeBlockingSourceTask() { + super(SOURCE_TASK_INITIALIZE); + } + } + } + + // Used to test blocks in SinkTask methods + public static class BlockingSinkConnector extends SinkConnector { + + private Map props; + private final Class taskClass; + + // No-args constructor required by the framework + public BlockingSinkConnector() { + this(BlockingSinkTask.class); + } + + protected BlockingSinkConnector(Class taskClass) { + this.taskClass = taskClass; + } + + @Override + public void start(Map props) { + this.props = props; + } + + @Override + public Class taskClass() { + return taskClass; + } + + @Override + public List> taskConfigs(int maxTasks) { + return IntStream.rangeClosed(0, maxTasks) + .mapToObj(i -> new HashMap<>(props)) + .collect(Collectors.toList()); + } + + @Override + public void stop() { + } + + @Override + public Config validate(Map connectorConfigs) { + return super.validate(connectorConfigs); + } + + @Override + public ConfigDef config() { + return Block.config(); + } + + @Override + public String version() { + return "0.0.0"; + } + + public static class BlockingSinkTask extends SinkTask { + private Block block; + + // No-args constructor required by the framework + public BlockingSinkTask() { + this(null); + } + + protected BlockingSinkTask(String block) { + this.block = new Block(block); + } + + @Override + public void start(Map props) { + this.block = new Block(props); + block.maybeBlockOn(TASK_START); + } + + @Override + public void put(Collection records) { + block.maybeBlockOn(SINK_TASK_PUT); + } + + @Override + public void stop() { + block.maybeBlockOn(TASK_STOP); + } + + @Override + public String version() { + block.maybeBlockOn(TASK_VERSION); + return "0.0.0"; + } + + @Override + public void initialize(SinkTaskContext context) { + block.maybeBlockOn(SINK_TASK_INITIALIZE); + super.initialize(context); + } + + @Override + public 
void flush(Map currentOffsets) { + block.maybeBlockOn(SINK_TASK_FLUSH); + super.flush(currentOffsets); + } + + @Override + public Map preCommit(Map currentOffsets) { + block.maybeBlockOn(SINK_TASK_PRE_COMMIT); + return super.preCommit(currentOffsets); + } + + @Override + public void open(Collection partitions) { + block.maybeBlockOn(SINK_TASK_OPEN); + super.open(partitions); + } + + @Override + @SuppressWarnings("deprecation") + public void onPartitionsAssigned(Collection partitions) { + block.maybeBlockOn(SINK_TASK_ON_PARTITIONS_ASSIGNED); + super.onPartitionsAssigned(partitions); + } + + @Override + public void close(Collection partitions) { + block.maybeBlockOn(SINK_TASK_CLOSE); + super.close(partitions); + } + + @Override + @SuppressWarnings("deprecation") + public void onPartitionsRevoked(Collection partitions) { + block.maybeBlockOn(SINK_TASK_ON_PARTITIONS_REVOKED); + super.onPartitionsRevoked(partitions); + } + } + } + + public static class TaskInitializeBlockingSinkConnector extends BlockingSinkConnector { + public TaskInitializeBlockingSinkConnector() { + super(InitializeBlockingSinkTask.class); + } + + public static class InitializeBlockingSinkTask extends BlockingSinkTask { + public InitializeBlockingSinkTask() { + super(SINK_TASK_INITIALIZE); + } + } + } + + // We don't declare a class here that blocks in the version() method since that method is used + // in plugin path scanning. Until/unless plugin path scanning is altered to not block completely + // on connectors' version() methods, we can't even declare a class that does that without + // causing the workers in this test to hang on startup. +} diff --git a/connect/runtime/src/test/java/org/apache/kafka/connect/integration/ConnectIntegrationTestUtils.java b/connect/runtime/src/test/java/org/apache/kafka/connect/integration/ConnectIntegrationTestUtils.java new file mode 100644 index 0000000..058dbe2 --- /dev/null +++ b/connect/runtime/src/test/java/org/apache/kafka/connect/integration/ConnectIntegrationTestUtils.java @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.connect.integration; + +import org.junit.rules.TestRule; +import org.junit.rules.TestWatcher; +import org.junit.runner.Description; +import org.slf4j.Logger; + +/** + * A utility class for Connect's integration tests + */ +public class ConnectIntegrationTestUtils { + public static TestRule newTestWatcher(Logger log) { + return new TestWatcher() { + @Override + protected void starting(Description description) { + super.starting(description); + log.info("Starting test {}", description.getMethodName()); + } + + @Override + protected void finished(Description description) { + super.finished(description); + log.info("Finished test {}", description.getMethodName()); + } + }; + } +} diff --git a/connect/runtime/src/test/java/org/apache/kafka/connect/integration/ConnectWorkerIntegrationTest.java b/connect/runtime/src/test/java/org/apache/kafka/connect/integration/ConnectWorkerIntegrationTest.java new file mode 100644 index 0000000..5cd794e --- /dev/null +++ b/connect/runtime/src/test/java/org/apache/kafka/connect/integration/ConnectWorkerIntegrationTest.java @@ -0,0 +1,342 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.connect.integration; + +import org.apache.kafka.connect.runtime.distributed.DistributedConfig; +import org.apache.kafka.connect.storage.StringConverter; +import org.apache.kafka.connect.util.clusters.EmbeddedConnectCluster; +import org.apache.kafka.connect.util.clusters.WorkerHandle; +import org.apache.kafka.test.IntegrationTest; +import org.junit.After; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.rules.TestRule; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; +import java.util.Properties; +import java.util.Set; +import java.util.concurrent.TimeUnit; + +import static org.apache.kafka.clients.CommonClientConfigs.BOOTSTRAP_SERVERS_CONFIG; +import static org.apache.kafka.connect.integration.MonitorableSourceConnector.TOPIC_CONFIG; +import static org.apache.kafka.connect.runtime.ConnectorConfig.CONNECTOR_CLASS_CONFIG; +import static org.apache.kafka.connect.runtime.ConnectorConfig.CONNECTOR_CLIENT_PRODUCER_OVERRIDES_PREFIX; +import static org.apache.kafka.connect.runtime.ConnectorConfig.KEY_CONVERTER_CLASS_CONFIG; +import static org.apache.kafka.connect.runtime.ConnectorConfig.TASKS_MAX_CONFIG; +import static org.apache.kafka.connect.runtime.ConnectorConfig.VALUE_CONVERTER_CLASS_CONFIG; +import static org.apache.kafka.connect.runtime.TopicCreationConfig.DEFAULT_TOPIC_CREATION_PREFIX; +import static org.apache.kafka.connect.runtime.TopicCreationConfig.PARTITIONS_CONFIG; +import static org.apache.kafka.connect.runtime.TopicCreationConfig.REPLICATION_FACTOR_CONFIG; +import static org.apache.kafka.connect.runtime.WorkerConfig.CONNECTOR_CLIENT_POLICY_CLASS_CONFIG; +import static org.apache.kafka.connect.runtime.WorkerConfig.OFFSET_COMMIT_INTERVAL_MS_CONFIG; +import static org.apache.kafka.connect.util.clusters.EmbeddedConnectClusterAssertions.CONNECTOR_SETUP_DURATION_MS; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +/** + * Test simple operations on the workers of a Connect cluster. 
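+ *
+ * <p>Scenarios covered below include adding and removing workers, restarting a failed task over
+ * the REST API, recovering after the broker that acts as group coordinator is bounced, reflecting
+ * changes to {@code tasks.max} in the reported task statuses, and shutting down a source connector
+ * whose target topic does not exist.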
+ */ +@Category(IntegrationTest.class) +public class ConnectWorkerIntegrationTest { + private static final Logger log = LoggerFactory.getLogger(ConnectWorkerIntegrationTest.class); + + private static final int NUM_TOPIC_PARTITIONS = 3; + private static final long OFFSET_COMMIT_INTERVAL_MS = TimeUnit.SECONDS.toMillis(30); + private static final int NUM_WORKERS = 3; + private static final int NUM_TASKS = 4; + private static final int MESSAGES_PER_POLL = 10; + private static final String CONNECTOR_NAME = "simple-source"; + private static final String TOPIC_NAME = "test-topic"; + + private EmbeddedConnectCluster.Builder connectBuilder; + private EmbeddedConnectCluster connect; + private Map workerProps; + private Properties brokerProps; + + @Rule + public TestRule watcher = ConnectIntegrationTestUtils.newTestWatcher(log); + + @Before + public void setup() { + // setup Connect worker properties + workerProps = new HashMap<>(); + workerProps.put(OFFSET_COMMIT_INTERVAL_MS_CONFIG, String.valueOf(OFFSET_COMMIT_INTERVAL_MS)); + workerProps.put(CONNECTOR_CLIENT_POLICY_CLASS_CONFIG, "All"); + + // setup Kafka broker properties + brokerProps = new Properties(); + brokerProps.put("auto.create.topics.enable", String.valueOf(false)); + + // build a Connect cluster backed by Kafka and Zk + connectBuilder = new EmbeddedConnectCluster.Builder() + .name("connect-cluster") + .numWorkers(NUM_WORKERS) + .workerProps(workerProps) + .brokerProps(brokerProps) + .maskExitProcedures(true); // true is the default, setting here as example + } + + @After + public void close() { + // stop all Connect, Kafka and Zk threads. + connect.stop(); + } + + /** + * Simple test case to add and then remove a worker from the embedded Connect cluster while + * running a simple source connector. + */ + @Test + public void testAddAndRemoveWorker() throws Exception { + connect = connectBuilder.build(); + // start the clusters + connect.start(); + + // create test topic + connect.kafka().createTopic(TOPIC_NAME, NUM_TOPIC_PARTITIONS); + + // set up props for the source connector + Map props = defaultSourceConnectorProps(TOPIC_NAME); + + connect.assertions().assertAtLeastNumWorkersAreUp(NUM_WORKERS, + "Initial group of workers did not start in time."); + + // start a source connector + connect.configureConnector(CONNECTOR_NAME, props); + + connect.assertions().assertConnectorAndAtLeastNumTasksAreRunning(CONNECTOR_NAME, NUM_TASKS, + "Connector tasks did not start in time."); + + WorkerHandle extraWorker = connect.addWorker(); + + connect.assertions().assertAtLeastNumWorkersAreUp(NUM_WORKERS + 1, + "Expanded group of workers did not start in time."); + + connect.assertions().assertConnectorAndAtLeastNumTasksAreRunning(CONNECTOR_NAME, NUM_TASKS, + "Connector tasks are not all in running state."); + + Set workers = connect.activeWorkers(); + assertTrue(workers.contains(extraWorker)); + + connect.removeWorker(extraWorker); + + connect.assertions().assertExactlyNumWorkersAreUp(NUM_WORKERS, + "Group of workers did not shrink in time."); + + workers = connect.activeWorkers(); + assertFalse(workers.contains(extraWorker)); + } + + /** + * Verify that a failed task can be restarted successfully. + */ + @Test + public void testRestartFailedTask() throws Exception { + connect = connectBuilder.build(); + // start the clusters + connect.start(); + + int numTasks = 1; + + // setup up props for the source connector + Map props = defaultSourceConnectorProps(TOPIC_NAME); + // Properties for the source connector. 
The task should fail at startup due to the bad broker address. + props.put(TASKS_MAX_CONFIG, Objects.toString(numTasks)); + props.put(CONNECTOR_CLIENT_PRODUCER_OVERRIDES_PREFIX + BOOTSTRAP_SERVERS_CONFIG, "nobrokerrunningatthisaddress"); + + connect.assertions().assertExactlyNumWorkersAreUp(NUM_WORKERS, + "Initial group of workers did not start in time."); + + // Try to start the connector and its single task. + connect.configureConnector(CONNECTOR_NAME, props); + + connect.assertions().assertConnectorIsRunningAndTasksHaveFailed(CONNECTOR_NAME, numTasks, + "Connector tasks did not fail in time"); + + // Reconfigure the connector without the bad broker address. + props.remove(CONNECTOR_CLIENT_PRODUCER_OVERRIDES_PREFIX + BOOTSTRAP_SERVERS_CONFIG); + connect.configureConnector(CONNECTOR_NAME, props); + + // Restart the failed task + String taskRestartEndpoint = connect.endpointForResource( + String.format("connectors/%s/tasks/0/restart", CONNECTOR_NAME)); + connect.requestPost(taskRestartEndpoint, "", Collections.emptyMap()); + + // Ensure the task started successfully this time + connect.assertions().assertConnectorAndAtLeastNumTasksAreRunning(CONNECTOR_NAME, numTasks, + "Connector tasks are not all in running state."); + } + + /** + * Verify that a set of tasks restarts correctly after a broker goes offline and back online + */ + @Test + public void testBrokerCoordinator() throws Exception { + ConnectorHandle connectorHandle = RuntimeHandles.get().connectorHandle(CONNECTOR_NAME); + workerProps.put(DistributedConfig.SCHEDULED_REBALANCE_MAX_DELAY_MS_CONFIG, String.valueOf(5000)); + connect = connectBuilder.workerProps(workerProps).build(); + // start the clusters + connect.start(); + int numTasks = 4; + // create test topic + connect.kafka().createTopic(TOPIC_NAME, NUM_TOPIC_PARTITIONS); + + // set up props for the source connector + Map props = defaultSourceConnectorProps(TOPIC_NAME); + + connect.assertions().assertAtLeastNumWorkersAreUp(NUM_WORKERS, + "Initial group of workers did not start in time."); + + // start a source connector + connect.configureConnector(CONNECTOR_NAME, props); + + connect.assertions().assertConnectorAndAtLeastNumTasksAreRunning(CONNECTOR_NAME, numTasks, + "Connector tasks did not start in time."); + + // expect that the connector will be stopped once the coordinator is detected to be down + StartAndStopLatch stopLatch = connectorHandle.expectedStops(1, false); + + connect.kafka().stopOnlyKafka(); + + connect.assertions().assertExactlyNumWorkersAreUp(NUM_WORKERS, + "Group of workers did not remain the same after broker shutdown"); + + // Allow for the workers to discover that the coordinator is unavailable, wait is + // heartbeat timeout * 2 + 4sec + Thread.sleep(TimeUnit.SECONDS.toMillis(10)); + + // Wait for the connector to be stopped + assertTrue("Failed to stop connector and tasks after coordinator failure within " + + CONNECTOR_SETUP_DURATION_MS + "ms", + stopLatch.await(CONNECTOR_SETUP_DURATION_MS, TimeUnit.MILLISECONDS)); + + StartAndStopLatch startLatch = connectorHandle.expectedStarts(1, false); + connect.kafka().startOnlyKafkaOnSamePorts(); + + // Allow for the kafka brokers to come back online + Thread.sleep(TimeUnit.SECONDS.toMillis(10)); + + connect.assertions().assertExactlyNumWorkersAreUp(NUM_WORKERS, + "Group of workers did not remain the same within the designated time."); + + // Allow for the workers to rebalance and reach a steady state + Thread.sleep(TimeUnit.SECONDS.toMillis(10)); + + 
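// once the rebalance settles, the connector and its tasks should return to the RUNNING state
+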
connect.assertions().assertConnectorAndAtLeastNumTasksAreRunning(CONNECTOR_NAME, numTasks, + "Connector tasks did not start in time."); + + // Expect that the connector has started again + assertTrue("Failed to stop connector and tasks after coordinator failure within " + + CONNECTOR_SETUP_DURATION_MS + "ms", + startLatch.await(CONNECTOR_SETUP_DURATION_MS, TimeUnit.MILLISECONDS)); + } + + /** + * Verify that the number of tasks listed in the REST API is updated correctly after changes to + * the "tasks.max" connector configuration. + */ + @Test + public void testTaskStatuses() throws Exception { + connect = connectBuilder.build(); + // start the clusters + connect.start(); + + connect.assertions().assertAtLeastNumWorkersAreUp(NUM_WORKERS, + "Initial group of workers did not start in time."); + + // base connector props + Map props = defaultSourceConnectorProps(TOPIC_NAME); + props.put(CONNECTOR_CLASS_CONFIG, MonitorableSourceConnector.class.getSimpleName()); + + // start the connector with only one task + int initialNumTasks = 1; + props.put(TASKS_MAX_CONFIG, String.valueOf(initialNumTasks)); + connect.configureConnector(CONNECTOR_NAME, props); + connect.assertions().assertConnectorAndExactlyNumTasksAreRunning(CONNECTOR_NAME, + initialNumTasks, "Connector tasks did not start in time"); + + // then reconfigure it to use more tasks + int increasedNumTasks = 5; + props.put(TASKS_MAX_CONFIG, String.valueOf(increasedNumTasks)); + connect.configureConnector(CONNECTOR_NAME, props); + connect.assertions().assertConnectorAndExactlyNumTasksAreRunning(CONNECTOR_NAME, + increasedNumTasks, "Connector task statuses did not update in time."); + + // then reconfigure it to use fewer tasks + int decreasedNumTasks = 3; + props.put(TASKS_MAX_CONFIG, String.valueOf(decreasedNumTasks)); + connect.configureConnector(CONNECTOR_NAME, props); + connect.assertions().assertConnectorAndExactlyNumTasksAreRunning(CONNECTOR_NAME, + decreasedNumTasks, "Connector task statuses did not update in time."); + } + + @Test + public void testSourceTaskNotBlockedOnShutdownWithNonExistentTopic() throws Exception { + // When automatic topic creation is disabled on the broker + brokerProps.put("auto.create.topics.enable", "false"); + connect = connectBuilder + .brokerProps(brokerProps) + .numWorkers(1) + .numBrokers(1) + .build(); + connect.start(); + + connect.assertions().assertAtLeastNumWorkersAreUp(1, "Initial group of workers did not start in time."); + + // and when the connector is not configured to create topics + Map props = defaultSourceConnectorProps("nonexistenttopic"); + props.remove(DEFAULT_TOPIC_CREATION_PREFIX + REPLICATION_FACTOR_CONFIG); + props.remove(DEFAULT_TOPIC_CREATION_PREFIX + PARTITIONS_CONFIG); + props.put("throughput", "-1"); + + ConnectorHandle connector = RuntimeHandles.get().connectorHandle(CONNECTOR_NAME); + connector.expectedRecords(NUM_TASKS * MESSAGES_PER_POLL); + connect.configureConnector(CONNECTOR_NAME, props); + connect.assertions().assertConnectorAndExactlyNumTasksAreRunning(CONNECTOR_NAME, + NUM_TASKS, "Connector tasks did not start in time"); + connector.awaitRecords(TimeUnit.MINUTES.toMillis(1)); + + // Then if we delete the connector, it and each of its tasks should be stopped by the framework + // even though the producer is blocked because there is no topic + StartAndStopLatch stopCounter = connector.expectedStops(1); + connect.deleteConnector(CONNECTOR_NAME); + + assertTrue("Connector and all tasks were not stopped in time", stopCounter.await(1, TimeUnit.MINUTES)); + } + + private Map 
defaultSourceConnectorProps(String topic) { + // setup up props for the source connector + Map props = new HashMap<>(); + props.put(CONNECTOR_CLASS_CONFIG, MonitorableSourceConnector.class.getSimpleName()); + props.put(TASKS_MAX_CONFIG, String.valueOf(NUM_TASKS)); + props.put(TOPIC_CONFIG, topic); + props.put("throughput", "10"); + props.put("messages.per.poll", String.valueOf(MESSAGES_PER_POLL)); + props.put(KEY_CONVERTER_CLASS_CONFIG, StringConverter.class.getName()); + props.put(VALUE_CONVERTER_CLASS_CONFIG, StringConverter.class.getName()); + props.put(DEFAULT_TOPIC_CREATION_PREFIX + REPLICATION_FACTOR_CONFIG, String.valueOf(1)); + props.put(DEFAULT_TOPIC_CREATION_PREFIX + PARTITIONS_CONFIG, String.valueOf(1)); + return props; + } +} diff --git a/connect/runtime/src/test/java/org/apache/kafka/connect/integration/ConnectorClientPolicyIntegrationTest.java b/connect/runtime/src/test/java/org/apache/kafka/connect/integration/ConnectorClientPolicyIntegrationTest.java new file mode 100644 index 0000000..a0abece --- /dev/null +++ b/connect/runtime/src/test/java/org/apache/kafka/connect/integration/ConnectorClientPolicyIntegrationTest.java @@ -0,0 +1,166 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.connect.integration; + +import org.apache.kafka.clients.CommonClientConfigs; +import org.apache.kafka.clients.consumer.ConsumerConfig; +import org.apache.kafka.common.config.SaslConfigs; +import org.apache.kafka.connect.runtime.ConnectorConfig; +import org.apache.kafka.connect.runtime.WorkerConfig; +import org.apache.kafka.connect.runtime.rest.errors.ConnectRestException; +import org.apache.kafka.connect.storage.StringConverter; +import org.apache.kafka.connect.util.clusters.EmbeddedConnectCluster; +import org.apache.kafka.test.IntegrationTest; +import org.junit.After; +import org.junit.Test; +import org.junit.experimental.categories.Category; + +import java.util.HashMap; +import java.util.Map; +import java.util.Properties; + +import static org.apache.kafka.connect.runtime.ConnectorConfig.CONNECTOR_CLASS_CONFIG; +import static org.apache.kafka.connect.runtime.ConnectorConfig.KEY_CONVERTER_CLASS_CONFIG; +import static org.apache.kafka.connect.runtime.ConnectorConfig.TASKS_MAX_CONFIG; +import static org.apache.kafka.connect.runtime.ConnectorConfig.VALUE_CONVERTER_CLASS_CONFIG; +import static org.apache.kafka.connect.runtime.SinkConnectorConfig.TOPICS_CONFIG; +import static org.apache.kafka.connect.runtime.WorkerConfig.OFFSET_COMMIT_INTERVAL_MS_CONFIG; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.fail; + +@Category(IntegrationTest.class) +public class ConnectorClientPolicyIntegrationTest { + + private static final int NUM_TASKS = 1; + private static final int NUM_WORKERS = 1; + private static final String CONNECTOR_NAME = "simple-conn"; + + @After + public void close() { + } + + @Test + public void testCreateWithOverridesForNonePolicy() throws Exception { + Map props = basicConnectorConfig(); + props.put(ConnectorConfig.CONNECTOR_CLIENT_CONSUMER_OVERRIDES_PREFIX + SaslConfigs.SASL_JAAS_CONFIG, "sasl"); + assertFailCreateConnector("None", props); + } + + @Test + public void testCreateWithNotAllowedOverridesForPrincipalPolicy() throws Exception { + Map props = basicConnectorConfig(); + props.put(ConnectorConfig.CONNECTOR_CLIENT_CONSUMER_OVERRIDES_PREFIX + SaslConfigs.SASL_JAAS_CONFIG, "sasl"); + props.put(ConnectorConfig.CONNECTOR_CLIENT_CONSUMER_OVERRIDES_PREFIX + ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "latest"); + assertFailCreateConnector("Principal", props); + } + + @Test + public void testCreateWithAllowedOverridesForPrincipalPolicy() throws Exception { + Map props = basicConnectorConfig(); + props.put(ConnectorConfig.CONNECTOR_CLIENT_CONSUMER_OVERRIDES_PREFIX + CommonClientConfigs.SECURITY_PROTOCOL_CONFIG, "PLAINTEXT"); + assertPassCreateConnector("Principal", props); + } + + @Test + public void testCreateWithAllowedOverridesForAllPolicy() throws Exception { + // setup up props for the sink connector + Map props = basicConnectorConfig(); + props.put(ConnectorConfig.CONNECTOR_CLIENT_CONSUMER_OVERRIDES_PREFIX + CommonClientConfigs.CLIENT_ID_CONFIG, "test"); + assertPassCreateConnector("All", props); + } + + @Test + public void testCreateWithNoAllowedOverridesForNonePolicy() throws Exception { + // setup up props for the sink connector + Map props = basicConnectorConfig(); + assertPassCreateConnector("None", props); + } + + @Test + public void testCreateWithAllowedOverridesForDefaultPolicy() throws Exception { + // setup up props for the sink connector + Map props = basicConnectorConfig(); + props.put(ConnectorConfig.CONNECTOR_CLIENT_CONSUMER_OVERRIDES_PREFIX + CommonClientConfigs.CLIENT_ID_CONFIG, "test"); + 
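// a null policy means connectClusterWithPolicy leaves the worker's override policy unset, so this exercises the default policy
+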
assertPassCreateConnector(null, props); + } + + private EmbeddedConnectCluster connectClusterWithPolicy(String policy) throws InterruptedException { + // setup Connect worker properties + Map workerProps = new HashMap<>(); + workerProps.put(OFFSET_COMMIT_INTERVAL_MS_CONFIG, String.valueOf(5_000)); + if (policy != null) { + workerProps.put(WorkerConfig.CONNECTOR_CLIENT_POLICY_CLASS_CONFIG, policy); + } + + // setup Kafka broker properties + Properties exampleBrokerProps = new Properties(); + exampleBrokerProps.put("auto.create.topics.enable", "false"); + + // build a Connect cluster backed by Kafka and Zk + EmbeddedConnectCluster connect = new EmbeddedConnectCluster.Builder() + .name("connect-cluster") + .numWorkers(NUM_WORKERS) + .numBrokers(1) + .workerProps(workerProps) + .brokerProps(exampleBrokerProps) + .build(); + + // start the clusters + connect.start(); + connect.assertions().assertAtLeastNumWorkersAreUp(NUM_WORKERS, + "Initial group of workers did not start in time."); + + return connect; + } + + private void assertFailCreateConnector(String policy, Map props) throws InterruptedException { + EmbeddedConnectCluster connect = connectClusterWithPolicy(policy); + try { + connect.configureConnector(CONNECTOR_NAME, props); + fail("Shouldn't be able to create connector"); + } catch (ConnectRestException e) { + assertEquals(e.statusCode(), 400); + } finally { + connect.stop(); + } + } + + private void assertPassCreateConnector(String policy, Map props) throws InterruptedException { + EmbeddedConnectCluster connect = connectClusterWithPolicy(policy); + try { + connect.configureConnector(CONNECTOR_NAME, props); + connect.assertions().assertConnectorAndAtLeastNumTasksAreRunning(CONNECTOR_NAME, NUM_TASKS, + "Connector tasks did not start in time."); + } catch (ConnectRestException e) { + fail("Should be able to create connector"); + } finally { + connect.stop(); + } + } + + + public Map basicConnectorConfig() { + Map props = new HashMap<>(); + props.put(CONNECTOR_CLASS_CONFIG, MonitorableSinkConnector.class.getSimpleName()); + props.put(TASKS_MAX_CONFIG, String.valueOf(NUM_TASKS)); + props.put(TOPICS_CONFIG, "test-topic"); + props.put(KEY_CONVERTER_CLASS_CONFIG, StringConverter.class.getName()); + props.put(VALUE_CONVERTER_CLASS_CONFIG, StringConverter.class.getName()); + return props; + } + +} diff --git a/connect/runtime/src/test/java/org/apache/kafka/connect/integration/ConnectorHandle.java b/connect/runtime/src/test/java/org/apache/kafka/connect/integration/ConnectorHandle.java new file mode 100644 index 0000000..b31455b --- /dev/null +++ b/connect/runtime/src/test/java/org/apache/kafka/connect/integration/ConnectorHandle.java @@ -0,0 +1,369 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.connect.integration; + +import org.apache.kafka.connect.errors.DataException; +import org.apache.kafka.connect.sink.SinkRecord; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Collection; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.function.Consumer; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +/** + * A handle to a connector executing in a Connect cluster. + */ +public class ConnectorHandle { + + private static final Logger log = LoggerFactory.getLogger(ConnectorHandle.class); + + private final String connectorName; + private final Map taskHandles = new ConcurrentHashMap<>(); + private final StartAndStopCounter startAndStopCounter = new StartAndStopCounter(); + + private CountDownLatch recordsRemainingLatch; + private CountDownLatch recordsToCommitLatch; + private int expectedRecords = -1; + private int expectedCommits = -1; + + public ConnectorHandle(String connectorName) { + this.connectorName = connectorName; + } + + /** + * Get or create a task handle for a given task id. The task need not be created when this method is called. If the + * handle is called before the task is created, the task will bind to the handle once it starts (or restarts). + * + * @param taskId the task id + * @return a non-null {@link TaskHandle} + */ + public TaskHandle taskHandle(String taskId) { + return taskHandle(taskId, null); + } + + /** + * Get or create a task handle for a given task id. The task need not be created when this method is called. If the + * handle is called before the task is created, the task will bind to the handle once it starts (or restarts). + * + * @param taskId the task id + * @param consumer A callback invoked when a sink task processes a record. + * @return a non-null {@link TaskHandle} + */ + public TaskHandle taskHandle(String taskId, Consumer consumer) { + return taskHandles.computeIfAbsent(taskId, k -> new TaskHandle(this, taskId, consumer)); + } + + /** + * Gets the start and stop counter corresponding to this handle. + * + * @return the start and stop counter + */ + public StartAndStopCounter startAndStopCounter() { + return startAndStopCounter; + } + + /** + * Get the connector's name corresponding to this handle. + * + * @return the connector's name + */ + public String name() { + return connectorName; + } + + /** + * Get the list of tasks handles monitored by this connector handle. + * + * @return the task handle list + */ + public Collection tasks() { + return taskHandles.values(); + } + + /** + * Delete the task handle for this task id. + * + * @param taskId the task id. + */ + public void deleteTask(String taskId) { + log.info("Removing handle for {} task in connector {}", taskId, connectorName); + taskHandles.remove(taskId); + } + + /** + * Set the number of expected records for this connector. + * + * @param expected number of records + */ + public void expectedRecords(int expected) { + expectedRecords = expected; + recordsRemainingLatch = new CountDownLatch(expected); + } + + /** + * Set the number of expected commits performed by this connector. + * + * @param expected number of commits + */ + public void expectedCommits(int expected) { + expectedCommits = expected; + recordsToCommitLatch = new CountDownLatch(expected); + } + + /** + * Record a message arrival at the connector. 
+ */ + public void record() { + if (recordsRemainingLatch != null) { + recordsRemainingLatch.countDown(); + } + } + + /** + * Record arrival of a batch of messages at the connector. + * + * @param batchSize the number of messages + */ + public void record(int batchSize) { + if (recordsRemainingLatch != null) { + IntStream.range(0, batchSize).forEach(i -> recordsRemainingLatch.countDown()); + } + } + + /** + * Record a message commit from the connector. + */ + public void commit() { + if (recordsToCommitLatch != null) { + recordsToCommitLatch.countDown(); + } + } + + /** + * Record commit on a batch of messages from the connector. + * + * @param batchSize the number of messages + */ + public void commit(int batchSize) { + if (recordsToCommitLatch != null) { + IntStream.range(0, batchSize).forEach(i -> recordsToCommitLatch.countDown()); + } + } + + /** + * Wait for this connector to meet the expected number of records as defined by {@code + * expectedRecords}. + * + * @param timeout max duration to wait for records + * @throws InterruptedException if another threads interrupts this one while waiting for records + */ + public void awaitRecords(long timeout) throws InterruptedException { + if (recordsRemainingLatch == null || expectedRecords < 0) { + throw new IllegalStateException("expectedRecords() was not set for this connector?"); + } + if (!recordsRemainingLatch.await(timeout, TimeUnit.MILLISECONDS)) { + String msg = String.format( + "Insufficient records seen by connector %s in %d millis. Records expected=%d, actual=%d", + connectorName, + timeout, + expectedRecords, + expectedRecords - recordsRemainingLatch.getCount()); + throw new DataException(msg); + } + } + + /** + * Wait for this connector to meet the expected number of commits as defined by {@code + * expectedCommits}. + * + * @param timeout duration to wait for commits + * @throws InterruptedException if another threads interrupts this one while waiting for commits + */ + public void awaitCommits(long timeout) throws InterruptedException { + if (recordsToCommitLatch == null || expectedCommits < 0) { + throw new IllegalStateException("expectedCommits() was not set for this connector?"); + } + if (!recordsToCommitLatch.await(timeout, TimeUnit.MILLISECONDS)) { + String msg = String.format( + "Insufficient records committed by connector %s in %d millis. Records expected=%d, actual=%d", + connectorName, + timeout, + expectedCommits, + expectedCommits - recordsToCommitLatch.getCount()); + throw new DataException(msg); + } + } + + /** + * Record that this connector has been started. This should be called by the connector under + * test. + * + * @see #expectedStarts(int) + */ + public void recordConnectorStart() { + startAndStopCounter.recordStart(); + } + + /** + * Record that this connector has been stopped. This should be called by the connector under + * test. + * + * @see #expectedStarts(int) + */ + public void recordConnectorStop() { + startAndStopCounter.recordStop(); + } + + /** + * Obtain a {@link StartAndStopLatch} that can be used to wait until the connector using this handle + * and all tasks using {@link TaskHandle} have completed the expected number of + * starts, starting the counts at the time this method is called. + * + *

            A test can call this method, specifying the number of times the connector and tasks + * will each be stopped and started from that point (typically {@code expectedStarts(1)}). + * The test should then change the connector or otherwise cause the connector to restart one or + * more times, and then can call {@link StartAndStopLatch#await(long, TimeUnit)} to wait up to a + * specified duration for the connector and all tasks to be started at least the specified + * number of times. + * + *

            This method does not track the number of times the connector and tasks are stopped, and + * only tracks the number of times the connector and tasks are started. + * + * @param expectedStarts the minimum number of starts that are expected once this method is + * called + * @return the latch that can be used to wait for the starts to complete; never null + */ + public StartAndStopLatch expectedStarts(int expectedStarts) { + return expectedStarts(expectedStarts, true); + } + + /** + * Obtain a {@link StartAndStopLatch} that can be used to wait until the connector using this handle + * and optionally all tasks using {@link TaskHandle} have completed the expected number of + * starts, starting the counts at the time this method is called. + * + *

            A test can call this method, specifying the number of times the connector and tasks + * will each be stopped and started from that point (typically {@code expectedStarts(1)}). + * The test should then change the connector or otherwise cause the connector to restart one or + * more times, and then can call {@link StartAndStopLatch#await(long, TimeUnit)} to wait up to a + * specified duration for the connector and all tasks to be started at least the specified + * number of times. + * + *
<p>This method does not track the number of times the connector and tasks are stopped, and + * only tracks the number of times the connector and tasks are started. + * + * @param expectedStarts the minimum number of starts that are expected once this method is + * called + * @param includeTasks true if the latch should also wait for the tasks to be started the + * specified minimum number of times + * @return the latch that can be used to wait for the starts to complete; never null + */ + public StartAndStopLatch expectedStarts(int expectedStarts, boolean includeTasks) { + List taskLatches = includeTasks + ? taskHandles.values().stream() + .map(task -> task.expectedStarts(expectedStarts)) + .collect(Collectors.toList()) + : Collections.emptyList(); + return startAndStopCounter.expectedStarts(expectedStarts, taskLatches); + } + + public StartAndStopLatch expectedStarts(int expectedStarts, Map expectedTasksStarts, boolean includeTasks) { + List taskLatches = includeTasks + ? taskHandles.values().stream() + .map(task -> task.expectedStarts(expectedTasksStarts.get(task.taskId()))) + .collect(Collectors.toList()) + : Collections.emptyList(); + return startAndStopCounter.expectedStarts(expectedStarts, taskLatches); + } + + /** + * Obtain a {@link StartAndStopLatch} that can be used to wait until the connector using this handle + * and optionally all tasks using {@link TaskHandle} have completed the minimum number of + * stops, starting the counts at the time this method is called. + * + *
<p>A test can call this method, specifying the number of times the connector and tasks + * will each be stopped from that point (typically {@code expectedStops(1)}). + * The test should then change the connector or otherwise cause the connector to stop (or + * restart) one or more times, and then can call + * {@link StartAndStopLatch#await(long, TimeUnit)} to wait up to a specified duration for the + * connector and all tasks to be stopped at least the specified number of times. + * + *
<p>This method does not track the number of times the connector and tasks are started, and + * only tracks the number of times the connector and tasks are stopped. + * + * @param expectedStops the minimum number of stops that are expected once this method is + * called + * @return the latch that can be used to wait for the stops to complete; never null + */ + public StartAndStopLatch expectedStops(int expectedStops) { + return expectedStops(expectedStops, true); + } + + /** + * Obtain a {@link StartAndStopLatch} that can be used to wait until the connector using this handle + * and optionally all tasks using {@link TaskHandle} have completed the minimum number of + * stops, starting the counts at the time this method is called. + * + *
<p>A test can call this method, specifying the number of times the connector and tasks + * will each be stopped from that point (typically {@code expectedStops(1)}). + * The test should then change the connector or otherwise cause the connector to stop (or + * restart) one or more times, and then can call + * {@link StartAndStopLatch#await(long, TimeUnit)} to wait up to a specified duration for the + * connector and all tasks to be stopped at least the specified number of times. + * + *
<p>This method does not track the number of times the connector and tasks are started, and + * only tracks the number of times the connector and tasks are stopped. + * + * @param expectedStops the minimum number of stops that are expected once this method is + * called + * @param includeTasks true if the latch should also wait for the tasks to be stopped the + * specified minimum number of times + * @return the latch that can be used to wait for the stops to complete; never null + */ + public StartAndStopLatch expectedStops(int expectedStops, boolean includeTasks) { + List taskLatches = includeTasks + ? taskHandles.values().stream() + .map(task -> task.expectedStops(expectedStops)) + .collect(Collectors.toList()) + : Collections.emptyList(); + return startAndStopCounter.expectedStops(expectedStops, taskLatches); + } + + public StartAndStopLatch expectedStops(int expectedStops, Map expectedTasksStops, boolean includeTasks) { + List taskLatches = includeTasks + ? taskHandles.values().stream() + .map(task -> task.expectedStops(expectedTasksStops.get(task.taskId()))) + .collect(Collectors.toList()) + : Collections.emptyList(); + return startAndStopCounter.expectedStops(expectedStops, taskLatches); + } + + @Override + public String toString() { + return "ConnectorHandle{" + + "connectorName='" + connectorName + '\'' + + '}'; + } +} diff --git a/connect/runtime/src/test/java/org/apache/kafka/connect/integration/ConnectorRestartApiIntegrationTest.java b/connect/runtime/src/test/java/org/apache/kafka/connect/integration/ConnectorRestartApiIntegrationTest.java new file mode 100644 index 0000000..7d5646f --- /dev/null +++ b/connect/runtime/src/test/java/org/apache/kafka/connect/integration/ConnectorRestartApiIntegrationTest.java @@ -0,0 +1,435 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ +package org.apache.kafka.connect.integration; + +import org.apache.kafka.connect.runtime.AbstractStatus; +import org.apache.kafka.connect.runtime.rest.entities.ConnectorStateInfo; +import org.apache.kafka.connect.storage.StringConverter; +import org.apache.kafka.connect.util.clusters.EmbeddedConnectCluster; +import org.apache.kafka.test.IntegrationTest; +import org.junit.After; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.rules.TestName; +import org.junit.rules.TestRule; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.ws.rs.core.Response; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; +import java.util.Properties; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; + +import static org.apache.kafka.connect.integration.MonitorableSourceConnector.TOPIC_CONFIG; +import static org.apache.kafka.connect.runtime.ConnectorConfig.CONNECTOR_CLASS_CONFIG; +import static org.apache.kafka.connect.runtime.ConnectorConfig.KEY_CONVERTER_CLASS_CONFIG; +import static org.apache.kafka.connect.runtime.ConnectorConfig.TASKS_MAX_CONFIG; +import static org.apache.kafka.connect.runtime.ConnectorConfig.VALUE_CONVERTER_CLASS_CONFIG; +import static org.apache.kafka.connect.runtime.TopicCreationConfig.DEFAULT_TOPIC_CREATION_PREFIX; +import static org.apache.kafka.connect.runtime.TopicCreationConfig.PARTITIONS_CONFIG; +import static org.apache.kafka.connect.runtime.TopicCreationConfig.REPLICATION_FACTOR_CONFIG; +import static org.apache.kafka.connect.runtime.WorkerConfig.CONNECTOR_CLIENT_POLICY_CLASS_CONFIG; +import static org.apache.kafka.connect.runtime.WorkerConfig.OFFSET_COMMIT_INTERVAL_MS_CONFIG; +import static org.apache.kafka.connect.util.clusters.EmbeddedConnectClusterAssertions.CONNECTOR_SETUP_DURATION_MS; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotEquals; +import static org.junit.Assert.assertTrue; + +/** + * Test connectors restart API use cases. + */ +@Category(IntegrationTest.class) +public class ConnectorRestartApiIntegrationTest { + private static final Logger log = LoggerFactory.getLogger(ConnectorRestartApiIntegrationTest.class); + + private static final long OFFSET_COMMIT_INTERVAL_MS = TimeUnit.SECONDS.toMillis(30); + private static final int ONE_WORKER = 1; + private static final int NUM_TASKS = 4; + private static final int MESSAGES_PER_POLL = 10; + private static final String CONNECTOR_NAME_PREFIX = "conn-"; + + private static final String TOPIC_NAME = "test-topic"; + + private static Map connectClusterMap = new ConcurrentHashMap<>(); + + private EmbeddedConnectCluster connect; + private ConnectorHandle connectorHandle; + private String connectorName; + @Rule + public TestRule watcher = ConnectIntegrationTestUtils.newTestWatcher(log); + @Rule + public TestName testName = new TestName(); + + @Before + public void setup() { + connectorName = CONNECTOR_NAME_PREFIX + testName.getMethodName(); + // get connector handles before starting test. 
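+ // (RuntimeHandles acts as a process-global registry: the Monitorable* test connectors used below
+ // look up the same ConnectorHandle/TaskHandle instances through it to report their starts, stops,
+ // records and commits, which is what the assertions in these tests rely on.)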
+ connectorHandle = RuntimeHandles.get().connectorHandle(connectorName); + } + + private void startOrReuseConnectWithNumWorkers(int numWorkers) throws Exception { + connect = connectClusterMap.computeIfAbsent(numWorkers, n -> { + // setup Connect worker properties + Map workerProps = new HashMap<>(); + workerProps.put(OFFSET_COMMIT_INTERVAL_MS_CONFIG, String.valueOf(OFFSET_COMMIT_INTERVAL_MS)); + workerProps.put(CONNECTOR_CLIENT_POLICY_CLASS_CONFIG, "All"); + + // setup Kafka broker properties + Properties brokerProps = new Properties(); + brokerProps.put("auto.create.topics.enable", String.valueOf(false)); + + EmbeddedConnectCluster.Builder connectBuilder = new EmbeddedConnectCluster.Builder() + .name("connect-cluster") + .numWorkers(numWorkers) + .workerProps(workerProps) + .brokerProps(brokerProps) + // true is the default, setting here as example + .maskExitProcedures(true); + EmbeddedConnectCluster connect = connectBuilder.build(); + // start the clusters + connect.start(); + return connect; + }); + connect.assertions().assertExactlyNumWorkersAreUp(numWorkers, + "Initial group of workers did not start in time."); + } + + @After + public void tearDown() { + RuntimeHandles.get().deleteConnector(connectorName); + } + + @AfterClass + public static void close() { + // stop all Connect, Kafka and Zk threads. + connectClusterMap.values().forEach(c -> c.stop()); + } + + @Test + public void testRestartUnknownConnectorNoParams() throws Exception { + String connectorName = "Unknown"; + + // build a Connect cluster backed by Kafka and Zk + startOrReuseConnectWithNumWorkers(ONE_WORKER); + // Call the Restart API + String restartEndpoint = connect.endpointForResource( + String.format("connectors/%s/restart", connectorName)); + Response response = connect.requestPost(restartEndpoint, "", Collections.emptyMap()); + assertEquals(Response.Status.NOT_FOUND.getStatusCode(), response.getStatus()); + + } + + @Test + public void testRestartUnknownConnector() throws Exception { + restartUnknownConnector(false, false); + restartUnknownConnector(false, true); + restartUnknownConnector(true, false); + restartUnknownConnector(true, true); + } + + private void restartUnknownConnector(boolean onlyFailed, boolean includeTasks) throws Exception { + String connectorName = "Unknown"; + + // build a Connect cluster backed by Kafka and Zk + startOrReuseConnectWithNumWorkers(ONE_WORKER); + // Call the Restart API + String restartEndpoint = connect.endpointForResource( + String.format("connectors/%s/restart?onlyFailed=" + onlyFailed + "&includeTasks=" + includeTasks, connectorName)); + Response response = connect.requestPost(restartEndpoint, "", Collections.emptyMap()); + assertEquals(Response.Status.NOT_FOUND.getStatusCode(), response.getStatus()); + } + + @Test + public void testRunningConnectorAndTasksRestartOnlyConnector() throws Exception { + runningConnectorAndTasksRestart(false, false, 1, allTasksExpectedRestarts(0), false); + } + + @Test + public void testRunningConnectorAndTasksRestartBothConnectorAndTasks() throws Exception { + runningConnectorAndTasksRestart(false, true, 1, allTasksExpectedRestarts(1), false); + } + + @Test + public void testRunningConnectorAndTasksRestartOnlyFailedConnectorNoop() throws Exception { + runningConnectorAndTasksRestart(true, false, 0, allTasksExpectedRestarts(0), true); + } + + @Test + public void testRunningConnectorAndTasksRestartBothConnectorAndTasksNoop() throws Exception { + runningConnectorAndTasksRestart(true, true, 0, allTasksExpectedRestarts(0), true); + } + + @Test + 
public void testFailedConnectorRestartOnlyConnector() throws Exception { + failedConnectorRestart(false, false, 1); + } + + @Test + public void testFailedConnectorRestartBothConnectorAndTasks() throws Exception { + failedConnectorRestart(false, true, 1); + } + + @Test + public void testFailedConnectorRestartOnlyFailedConnectorAndTasks() throws Exception { + failedConnectorRestart(true, true, 1); + } + + @Test + public void testFailedTasksRestartOnlyConnector() throws Exception { + failedTasksRestart(false, false, 1, allTasksExpectedRestarts(0), buildAllTasksToFail(), false); + } + + @Test + public void testFailedTasksRestartOnlyTasks() throws Exception { + failedTasksRestart(true, true, 0, allTasksExpectedRestarts(1), buildAllTasksToFail(), false); + } + + @Test + public void testFailedTasksRestartWithoutIncludeTasksNoop() throws Exception { + failedTasksRestart(true, false, 0, allTasksExpectedRestarts(0), buildAllTasksToFail(), true); + } + + @Test + public void testFailedTasksRestartBothConnectorAndTasks() throws Exception { + failedTasksRestart(false, true, 1, allTasksExpectedRestarts(1), buildAllTasksToFail(), false); + } + + @Test + public void testOneFailedTasksRestartOnlyOneTasks() throws Exception { + Set tasksToFail = Collections.singleton(taskId(1)); + failedTasksRestart(true, true, 0, buildExpectedTasksRestarts(tasksToFail), tasksToFail, false); + } + + @Test + public void testMultiWorkerRestartOnlyConnector() throws Exception { + //run two additional workers to ensure that one worker will always be free and not running any tasks or connector instance for this connector + //we will call restart on that worker and that will test the distributed behavior of the restart API + int numWorkers = NUM_TASKS + 2; + runningConnectorAndTasksRestart(false, false, 1, allTasksExpectedRestarts(0), false, numWorkers); + } + + @Test + public void testMultiWorkerRestartBothConnectorAndTasks() throws Exception { + //run 2 additional workers to ensure 1 worker will be free that is not running any tasks or connector instance for this connector + int numWorkers = NUM_TASKS + 2; + runningConnectorAndTasksRestart(false, true, 1, allTasksExpectedRestarts(1), false, numWorkers); + } + + private void runningConnectorAndTasksRestart(boolean onlyFailed, boolean includeTasks, int expectedConnectorRestarts, Map expectedTasksRestarts, boolean noopRequest) throws Exception { + runningConnectorAndTasksRestart(onlyFailed, includeTasks, expectedConnectorRestarts, expectedTasksRestarts, noopRequest, ONE_WORKER); + } + + private void runningConnectorAndTasksRestart(boolean onlyFailed, boolean includeTasks, int expectedConnectorRestarts, Map expectedTasksRestarts, boolean noopRequest, int numWorkers) throws Exception { + startOrReuseConnectWithNumWorkers(numWorkers); + // setup up props for the source connector + Map props = defaultSourceConnectorProps(TOPIC_NAME); + // Try to start the connector and its single task. 
+ connect.configureConnector(connectorName, props); + + connect.assertions().assertConnectorAndAtLeastNumTasksAreRunning(connectorName, NUM_TASKS, + "Connector tasks are not all in running state."); + + StartsAndStops beforeSnapshot = connectorHandle.startAndStopCounter().countsSnapshot(); + Map beforeTasksSnapshot = connectorHandle.tasks().stream().collect(Collectors.toMap(TaskHandle::taskId, task -> task.startAndStopCounter().countsSnapshot())); + + StartAndStopLatch stopLatch = connectorHandle.expectedStops(expectedConnectorRestarts, expectedTasksRestarts, includeTasks); + StartAndStopLatch startLatch = connectorHandle.expectedStarts(expectedConnectorRestarts, expectedTasksRestarts, includeTasks); + ConnectorStateInfo connectorStateInfo; + // Call the Restart API + if (numWorkers == 1) { + connectorStateInfo = connect.restartConnectorAndTasks(connectorName, onlyFailed, includeTasks, false); + } else { + connectorStateInfo = connect.restartConnectorAndTasks(connectorName, onlyFailed, includeTasks, true); + } + + if (noopRequest) { + assertNoRestartingState(connectorStateInfo); + } + + // Wait for the connector to be stopped + assertTrue("Failed to stop connector and tasks within " + + CONNECTOR_SETUP_DURATION_MS + "ms", + stopLatch.await(CONNECTOR_SETUP_DURATION_MS, TimeUnit.MILLISECONDS)); + + connect.assertions().assertConnectorAndAtLeastNumTasksAreRunning(connectorName, NUM_TASKS, + "Connector tasks are not all in running state."); + // Expect that the connector has started again + assertTrue("Failed to start connector and tasks within " + + CONNECTOR_SETUP_DURATION_MS + "ms", + startLatch.await(CONNECTOR_SETUP_DURATION_MS, TimeUnit.MILLISECONDS)); + StartsAndStops afterSnapshot = connectorHandle.startAndStopCounter().countsSnapshot(); + + assertEquals(beforeSnapshot.starts() + expectedConnectorRestarts, afterSnapshot.starts()); + assertEquals(beforeSnapshot.stops() + expectedConnectorRestarts, afterSnapshot.stops()); + connectorHandle.tasks().forEach(t -> { + StartsAndStops afterTaskSnapshot = t.startAndStopCounter().countsSnapshot(); + if (numWorkers == 1) { + assertEquals(beforeTasksSnapshot.get(t.taskId()).starts() + expectedTasksRestarts.get(t.taskId()), afterTaskSnapshot.starts()); + assertEquals(beforeTasksSnapshot.get(t.taskId()).stops() + expectedTasksRestarts.get(t.taskId()), afterTaskSnapshot.stops()); + } else { + //validate tasks stop/start counts only in single worker test because the multi worker rebalance triggers stop/start on task and this make the exact counts unpredictable + assertTrue(afterTaskSnapshot.starts() >= beforeTasksSnapshot.get(t.taskId()).starts() + expectedTasksRestarts.get(t.taskId())); + assertTrue(afterTaskSnapshot.stops() >= beforeTasksSnapshot.get(t.taskId()).stops() + expectedTasksRestarts.get(t.taskId())); + } + }); + } + + private void failedConnectorRestart(boolean onlyFailed, boolean includeTasks, int expectedConnectorRestarts) throws Exception { + //as connector is failed we expect 0 task to be started + Map expectedTasksStarts = allTasksExpectedRestarts(0); + + // setup up props for the source connector + Map props = defaultSourceConnectorProps(TOPIC_NAME); + props.put("connector.start.inject.error", "true"); + // build a Connect cluster backed by Kafka and Zk + startOrReuseConnectWithNumWorkers(ONE_WORKER); + + // Try to start the connector and its single task. 
+ connect.configureConnector(connectorName, props); + + connect.assertions().assertConnectorIsFailedAndTasksHaveFailed(connectorName, 0, + "Connector or tasks are in running state."); + + StartsAndStops beforeSnapshot = connectorHandle.startAndStopCounter().countsSnapshot(); + + StartAndStopLatch startLatch = connectorHandle.expectedStarts(expectedConnectorRestarts, expectedTasksStarts, includeTasks); + + // Call the Restart API + connect.restartConnectorAndTasks(connectorName, onlyFailed, includeTasks, false); + + connect.assertions().assertConnectorIsFailedAndTasksHaveFailed(connectorName, 0, + "Connector tasks are not all in running state."); + // Expect that the connector has started again + assertTrue("Failed to start connector and tasks after coordinator failure within " + + CONNECTOR_SETUP_DURATION_MS + "ms", + startLatch.await(CONNECTOR_SETUP_DURATION_MS, TimeUnit.MILLISECONDS)); + StartsAndStops afterSnapshot = connectorHandle.startAndStopCounter().countsSnapshot(); + + assertEquals(beforeSnapshot.starts() + expectedConnectorRestarts, afterSnapshot.starts()); + } + + private void failedTasksRestart(boolean onlyFailed, boolean includeTasks, int expectedConnectorRestarts, Map expectedTasksRestarts, Set tasksToFail, boolean noopRequest) throws Exception { + // setup up props for the source connector + Map props = defaultSourceConnectorProps(TOPIC_NAME); + tasksToFail.forEach(taskId -> props.put("task-" + taskId + ".start.inject.error", "true")); + // build a Connect cluster backed by Kafka and Zk + startOrReuseConnectWithNumWorkers(ONE_WORKER); + + // Try to start the connector and its single task. + connect.configureConnector(connectorName, props); + + connect.assertions().assertConnectorIsRunningAndNumTasksHaveFailed(connectorName, NUM_TASKS, tasksToFail.size(), + "Connector tasks are in running state."); + + StartsAndStops beforeSnapshot = connectorHandle.startAndStopCounter().countsSnapshot(); + Map beforeTasksSnapshot = connectorHandle.tasks().stream().collect(Collectors.toMap(TaskHandle::taskId, task -> task.startAndStopCounter().countsSnapshot())); + + StartAndStopLatch stopLatch = connectorHandle.expectedStops(expectedConnectorRestarts, expectedTasksRestarts, includeTasks); + StartAndStopLatch startLatch = connectorHandle.expectedStarts(expectedConnectorRestarts, expectedTasksRestarts, includeTasks); + + // Call the Restart API + ConnectorStateInfo connectorStateInfo = connect.restartConnectorAndTasks(connectorName, onlyFailed, includeTasks, false); + + if (noopRequest) { + assertNoRestartingState(connectorStateInfo); + } + + // Wait for the connector to be stopped + assertTrue("Failed to stop connector and tasks within " + + CONNECTOR_SETUP_DURATION_MS + "ms", + stopLatch.await(CONNECTOR_SETUP_DURATION_MS, TimeUnit.MILLISECONDS)); + + connect.assertions().assertConnectorIsRunningAndNumTasksHaveFailed(connectorName, NUM_TASKS, tasksToFail.size(), + "Connector tasks are not all in running state."); + // Expect that the connector has started again + assertTrue("Failed to start connector and tasks within " + + CONNECTOR_SETUP_DURATION_MS + "ms", + startLatch.await(CONNECTOR_SETUP_DURATION_MS, TimeUnit.MILLISECONDS)); + + StartsAndStops afterSnapshot = connectorHandle.startAndStopCounter().countsSnapshot(); + + assertEquals(beforeSnapshot.starts() + expectedConnectorRestarts, afterSnapshot.starts()); + assertEquals(beforeSnapshot.stops() + expectedConnectorRestarts, afterSnapshot.stops()); + connectorHandle.tasks().forEach(t -> { + StartsAndStops afterTaskSnapshot = 
t.startAndStopCounter().countsSnapshot(); + assertEquals(beforeTasksSnapshot.get(t.taskId()).starts() + expectedTasksRestarts.get(t.taskId()), afterTaskSnapshot.starts()); + assertEquals(beforeTasksSnapshot.get(t.taskId()).stops() + expectedTasksRestarts.get(t.taskId()), afterTaskSnapshot.stops()); + }); + } + + private void assertNoRestartingState(ConnectorStateInfo connectorStateInfo) { + //for noop requests as everything is in RUNNING state, assert that plan was empty which means + // no RESTARTING state for the connector or tasks + assertNotEquals(AbstractStatus.State.RESTARTING.name(), connectorStateInfo.connector().state()); + connectorStateInfo.tasks().forEach(t -> assertNotEquals(AbstractStatus.State.RESTARTING.name(), t.state())); + } + + private Set buildAllTasksToFail() { + Set tasksToFail = new HashSet<>(); + for (int i = 0; i < NUM_TASKS; i++) { + String taskId = taskId(i); + tasksToFail.add(taskId); + } + return tasksToFail; + } + + private Map allTasksExpectedRestarts(int expectedRestarts) { + Map expectedTasksRestarts = new HashMap<>(); + for (int i = 0; i < NUM_TASKS; i++) { + String taskId = taskId(i); + expectedTasksRestarts.put(taskId, expectedRestarts); + } + return expectedTasksRestarts; + } + + private Map buildExpectedTasksRestarts(Set tasksToFail) { + Map expectedTasksRestarts = new HashMap<>(); + for (int i = 0; i < NUM_TASKS; i++) { + String taskId = taskId(i); + expectedTasksRestarts.put(taskId, tasksToFail.contains(taskId) ? 1 : 0); + } + return expectedTasksRestarts; + } + + private String taskId(int i) { + return connectorName + "-" + i; + } + + private Map defaultSourceConnectorProps(String topic) { + // setup up props for the source connector + Map props = new HashMap<>(); + props.put(CONNECTOR_CLASS_CONFIG, MonitorableSourceConnector.class.getSimpleName()); + props.put(TASKS_MAX_CONFIG, String.valueOf(NUM_TASKS)); + props.put(TOPIC_CONFIG, topic); + props.put("throughput", "10"); + props.put("messages.per.poll", String.valueOf(MESSAGES_PER_POLL)); + props.put(KEY_CONVERTER_CLASS_CONFIG, StringConverter.class.getName()); + props.put(VALUE_CONVERTER_CLASS_CONFIG, StringConverter.class.getName()); + props.put(DEFAULT_TOPIC_CREATION_PREFIX + REPLICATION_FACTOR_CONFIG, String.valueOf(1)); + props.put(DEFAULT_TOPIC_CREATION_PREFIX + PARTITIONS_CONFIG, String.valueOf(1)); + return props; + } +} diff --git a/connect/runtime/src/test/java/org/apache/kafka/connect/integration/ConnectorTopicsIntegrationTest.java b/connect/runtime/src/test/java/org/apache/kafka/connect/integration/ConnectorTopicsIntegrationTest.java new file mode 100644 index 0000000..8c4e156 --- /dev/null +++ b/connect/runtime/src/test/java/org/apache/kafka/connect/integration/ConnectorTopicsIntegrationTest.java @@ -0,0 +1,327 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.integration; + +import org.apache.kafka.clients.consumer.Consumer; +import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.common.KafkaException; +import org.apache.kafka.common.PartitionInfo; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.connect.runtime.distributed.DistributedConfig; +import org.apache.kafka.connect.runtime.rest.errors.ConnectRestException; +import org.apache.kafka.connect.storage.KafkaStatusBackingStore; +import org.apache.kafka.connect.storage.StringConverter; +import org.apache.kafka.connect.util.clusters.EmbeddedConnectCluster; +import org.apache.kafka.test.IntegrationTest; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.junit.experimental.categories.Category; + +import java.nio.charset.StandardCharsets; +import java.time.Duration; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Properties; +import java.util.Set; +import java.util.stream.Collectors; +import java.util.stream.StreamSupport; + +import static org.apache.kafka.connect.integration.MonitorableSourceConnector.TOPIC_CONFIG; +import static org.apache.kafka.connect.runtime.ConnectorConfig.CONNECTOR_CLASS_CONFIG; +import static org.apache.kafka.connect.runtime.ConnectorConfig.KEY_CONVERTER_CLASS_CONFIG; +import static org.apache.kafka.connect.runtime.ConnectorConfig.TASKS_MAX_CONFIG; +import static org.apache.kafka.connect.runtime.ConnectorConfig.VALUE_CONVERTER_CLASS_CONFIG; +import static org.apache.kafka.connect.runtime.TopicCreationConfig.DEFAULT_TOPIC_CREATION_PREFIX; +import static org.apache.kafka.connect.runtime.TopicCreationConfig.PARTITIONS_CONFIG; +import static org.apache.kafka.connect.runtime.TopicCreationConfig.REPLICATION_FACTOR_CONFIG; +import static org.apache.kafka.connect.runtime.WorkerConfig.CONNECTOR_CLIENT_POLICY_CLASS_CONFIG; +import static org.apache.kafka.connect.runtime.WorkerConfig.TOPIC_TRACKING_ALLOW_RESET_CONFIG; +import static org.apache.kafka.connect.runtime.WorkerConfig.TOPIC_TRACKING_ENABLE_CONFIG; +import static org.apache.kafka.connect.sink.SinkConnector.TOPICS_CONFIG; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; + +/** + * Integration test for the endpoints that offer topic tracking of a connector's active + * topics. 
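+ *
+ * <p>The tests drive these endpoints through the {@link EmbeddedConnectCluster} helpers
+ * {@code connectorTopics} and {@code resetConnectorTopics}; over plain REST the equivalent
+ * requests are roughly (an illustrative sketch, per KIP-558):
+ * <pre>{@code
+ * GET /connectors/{name}/topics
+ * PUT /connectors/{name}/topics/reset
+ * }</pre>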
+ */ +@Category(IntegrationTest.class) +public class ConnectorTopicsIntegrationTest { + + private static final int NUM_WORKERS = 5; + private static final int NUM_TASKS = 1; + private static final String FOO_TOPIC = "foo-topic"; + private static final String FOO_CONNECTOR = "foo-source"; + private static final String BAR_TOPIC = "bar-topic"; + private static final String BAR_CONNECTOR = "bar-source"; + private static final String SINK_CONNECTOR = "baz-sink"; + private static final int NUM_TOPIC_PARTITIONS = 3; + + private EmbeddedConnectCluster.Builder connectBuilder; + private EmbeddedConnectCluster connect; + Map workerProps = new HashMap<>(); + Properties brokerProps = new Properties(); + + @Before + public void setup() { + // setup Connect worker properties + workerProps.put(CONNECTOR_CLIENT_POLICY_CLASS_CONFIG, "All"); + + // setup Kafka broker properties + brokerProps.put("auto.create.topics.enable", String.valueOf(false)); + + // build a Connect cluster backed by Kafka and Zk + connectBuilder = new EmbeddedConnectCluster.Builder() + .name("connect-cluster") + .numWorkers(NUM_WORKERS) + .workerProps(workerProps) + .brokerProps(brokerProps) + .maskExitProcedures(true); // true is the default, setting here as example + } + + @After + public void close() { + // stop all Connect, Kafka and Zk threads. + connect.stop(); + } + + @Test + public void testGetActiveTopics() throws InterruptedException { + connect = connectBuilder.build(); + // start the clusters + connect.start(); + + // create test topic + connect.kafka().createTopic(FOO_TOPIC, NUM_TOPIC_PARTITIONS); + connect.kafka().createTopic(BAR_TOPIC, NUM_TOPIC_PARTITIONS); + + connect.assertions().assertAtLeastNumWorkersAreUp(NUM_WORKERS, "Initial group of workers did not start in time."); + + connect.assertions().assertConnectorActiveTopics(FOO_CONNECTOR, Collections.emptyList(), + "Active topic set is not empty for connector: " + FOO_CONNECTOR); + + // start a source connector + connect.configureConnector(FOO_CONNECTOR, defaultSourceConnectorProps(FOO_TOPIC)); + + connect.assertions().assertConnectorAndAtLeastNumTasksAreRunning(FOO_CONNECTOR, NUM_TASKS, + "Connector tasks did not start in time."); + + connect.assertions().assertConnectorActiveTopics(FOO_CONNECTOR, Collections.singletonList(FOO_TOPIC), + "Active topic set is not: " + Collections.singletonList(FOO_TOPIC) + " for connector: " + FOO_CONNECTOR); + + // start another source connector + connect.configureConnector(BAR_CONNECTOR, defaultSourceConnectorProps(BAR_TOPIC)); + + connect.assertions().assertConnectorAndAtLeastNumTasksAreRunning(BAR_CONNECTOR, NUM_TASKS, + "Connector tasks did not start in time."); + + connect.assertions().assertConnectorActiveTopics(BAR_CONNECTOR, Collections.singletonList(BAR_TOPIC), + "Active topic set is not: " + Collections.singletonList(BAR_TOPIC) + " for connector: " + BAR_CONNECTOR); + + // start a sink connector + connect.configureConnector(SINK_CONNECTOR, defaultSinkConnectorProps(FOO_TOPIC, BAR_TOPIC)); + + connect.assertions().assertConnectorAndAtLeastNumTasksAreRunning(SINK_CONNECTOR, NUM_TASKS, + "Connector tasks did not start in time."); + + connect.assertions().assertConnectorActiveTopics(SINK_CONNECTOR, Arrays.asList(FOO_TOPIC, BAR_TOPIC), + "Active topic set is not: " + Arrays.asList(FOO_TOPIC, BAR_TOPIC) + " for connector: " + SINK_CONNECTOR); + + // deleting a connector resets its active topics + connect.deleteConnector(BAR_CONNECTOR); + + connect.assertions().assertConnectorAndTasksAreStopped(BAR_CONNECTOR, + "Connector tasks did 
not stop in time."); + + connect.assertions().assertConnectorActiveTopics(BAR_CONNECTOR, Collections.emptyList(), + "Active topic set is not empty for deleted connector: " + BAR_CONNECTOR); + + // Unfortunately there's currently no easy way to know when the consumer caught up with + // the last records that the producer of the stopped connector managed to produce. + // Repeated runs show that this amount of time is sufficient for the consumer to catch up. + Thread.sleep(5000); + + // reset active topics for the sink connector after one of the topics has become idle + connect.resetConnectorTopics(SINK_CONNECTOR); + + connect.assertions().assertConnectorActiveTopics(SINK_CONNECTOR, Collections.singletonList(FOO_TOPIC), + "Active topic set is not: " + Collections.singletonList(FOO_TOPIC) + " for connector: " + SINK_CONNECTOR); + } + + @Test + public void testTopicTrackingResetIsDisabled() throws InterruptedException { + workerProps.put(TOPIC_TRACKING_ALLOW_RESET_CONFIG, "false"); + connect = connectBuilder.build(); + // start the clusters + connect.start(); + + // create test topic + connect.kafka().createTopic(FOO_TOPIC, NUM_TOPIC_PARTITIONS); + connect.kafka().createTopic(BAR_TOPIC, NUM_TOPIC_PARTITIONS); + + connect.assertions().assertAtLeastNumWorkersAreUp(NUM_WORKERS, "Initial group of workers did not start in time."); + + connect.assertions().assertConnectorActiveTopics(FOO_CONNECTOR, Collections.emptyList(), + "Active topic set is not empty for connector: " + FOO_CONNECTOR); + + // start a source connector + connect.configureConnector(FOO_CONNECTOR, defaultSourceConnectorProps(FOO_TOPIC)); + + connect.assertions().assertConnectorAndAtLeastNumTasksAreRunning(FOO_CONNECTOR, NUM_TASKS, + "Connector tasks did not start in time."); + + connect.assertions().assertConnectorActiveTopics(FOO_CONNECTOR, Collections.singletonList(FOO_TOPIC), + "Active topic set is not: " + Collections.singletonList(FOO_TOPIC) + " for connector: " + FOO_CONNECTOR); + + // start a sink connector + connect.configureConnector(SINK_CONNECTOR, defaultSinkConnectorProps(FOO_TOPIC)); + + connect.assertions().assertConnectorAndAtLeastNumTasksAreRunning(SINK_CONNECTOR, NUM_TASKS, + "Connector tasks did not start in time."); + + connect.assertions().assertConnectorActiveTopics(SINK_CONNECTOR, Arrays.asList(FOO_TOPIC), + "Active topic set is not: " + Arrays.asList(FOO_TOPIC) + " for connector: " + SINK_CONNECTOR); + + // deleting a connector resets its active topics + connect.deleteConnector(FOO_CONNECTOR); + + connect.assertions().assertConnectorAndTasksAreStopped(FOO_CONNECTOR, + "Connector tasks did not stop in time."); + + connect.assertions().assertConnectorActiveTopics(FOO_CONNECTOR, Collections.emptyList(), + "Active topic set is not empty for deleted connector: " + FOO_CONNECTOR); + + // Unfortunately there's currently no easy way to know when the consumer caught up with + // the last records that the producer of the stopped connector managed to produce. + // Repeated runs show that this amount of time is sufficient for the consumer to catch up. 
+ Thread.sleep(5000); + + // resetting active topics for the sink connector won't work when the config is disabled + Exception e = assertThrows(ConnectRestException.class, () -> connect.resetConnectorTopics(SINK_CONNECTOR)); + assertTrue(e.getMessage().contains("Topic tracking reset is disabled.")); + + connect.assertions().assertConnectorActiveTopics(SINK_CONNECTOR, Collections.singletonList(FOO_TOPIC), + "Active topic set is not: " + Collections.singletonList(FOO_TOPIC) + " for connector: " + SINK_CONNECTOR); + } + + @Test + public void testTopicTrackingIsDisabled() throws InterruptedException { + workerProps.put(TOPIC_TRACKING_ENABLE_CONFIG, "false"); + connect = connectBuilder.build(); + // start the clusters + connect.start(); + + // create test topic + connect.kafka().createTopic(FOO_TOPIC, NUM_TOPIC_PARTITIONS); + connect.kafka().createTopic(BAR_TOPIC, NUM_TOPIC_PARTITIONS); + + connect.assertions().assertAtLeastNumWorkersAreUp(NUM_WORKERS, "Initial group of workers did not start in time."); + + // start a source connector + connect.configureConnector(FOO_CONNECTOR, defaultSourceConnectorProps(FOO_TOPIC)); + connect.assertions().assertConnectorAndAtLeastNumTasksAreRunning(FOO_CONNECTOR, NUM_TASKS, + "Connector tasks did not start in time."); + + // resetting active topics for the sink connector won't work when the config is disabled + Exception e = assertThrows(ConnectRestException.class, () -> connect.resetConnectorTopics(SINK_CONNECTOR)); + assertTrue(e.getMessage().contains("Topic tracking is disabled.")); + + e = assertThrows(ConnectRestException.class, () -> connect.connectorTopics(SINK_CONNECTOR)); + assertTrue(e.getMessage().contains("Topic tracking is disabled.")); + + // Wait for tasks to produce a few records + Thread.sleep(5000); + + assertNoTopicStatusInStatusTopic(); + } + + public void assertNoTopicStatusInStatusTopic() { + String statusTopic = workerProps.get(DistributedConfig.STATUS_STORAGE_TOPIC_CONFIG); + Consumer verifiableConsumer = connect.kafka().createConsumer( + Collections.singletonMap("group.id", "verifiable-consumer-group-0")); + + List partitionInfos = verifiableConsumer.partitionsFor(statusTopic); + if (partitionInfos.isEmpty()) { + throw new AssertionError("Unable to retrieve partitions info for status topic"); + } + List partitions = partitionInfos.stream() + .map(info -> new TopicPartition(info.topic(), info.partition())) + .collect(Collectors.toList()); + verifiableConsumer.assign(partitions); + + // Based on the implementation of {@link org.apache.kafka.connect.util.KafkaBasedLog#readToLogEnd} + Set assignment = verifiableConsumer.assignment(); + verifiableConsumer.seekToBeginning(assignment); + Map endOffsets = verifiableConsumer.endOffsets(assignment); + while (!endOffsets.isEmpty()) { + Iterator> it = endOffsets.entrySet().iterator(); + while (it.hasNext()) { + Map.Entry entry = it.next(); + if (verifiableConsumer.position(entry.getKey()) >= entry.getValue()) + it.remove(); + else { + try { + StreamSupport.stream(verifiableConsumer.poll(Duration.ofMillis(Integer.MAX_VALUE)).spliterator(), false) + .map(ConsumerRecord::key) + .filter(Objects::nonNull) + .filter(key -> new String(key, StandardCharsets.UTF_8).startsWith(KafkaStatusBackingStore.TOPIC_STATUS_PREFIX)) + .findFirst() + .ifPresent(key -> { + throw new AssertionError("Found unexpected key: " + new String(key, StandardCharsets.UTF_8) + " in status topic"); + }); + } catch (KafkaException e) { + throw new AssertionError("Error while reading to the end of status topic", e); + } + break; + } 
+ } + } + } + + private Map defaultSourceConnectorProps(String topic) { + // setup up props for the source connector + Map props = new HashMap<>(); + props.put(CONNECTOR_CLASS_CONFIG, MonitorableSourceConnector.class.getSimpleName()); + props.put(TASKS_MAX_CONFIG, String.valueOf(NUM_TASKS)); + props.put(TOPIC_CONFIG, topic); + props.put("throughput", String.valueOf(10)); + props.put("messages.per.poll", String.valueOf(10)); + props.put(KEY_CONVERTER_CLASS_CONFIG, StringConverter.class.getName()); + props.put(VALUE_CONVERTER_CLASS_CONFIG, StringConverter.class.getName()); + props.put(DEFAULT_TOPIC_CREATION_PREFIX + REPLICATION_FACTOR_CONFIG, String.valueOf(1)); + props.put(DEFAULT_TOPIC_CREATION_PREFIX + PARTITIONS_CONFIG, String.valueOf(1)); + return props; + } + + private Map defaultSinkConnectorProps(String... topics) { + // setup up props for the sink connector + Map props = new HashMap<>(); + props.put(CONNECTOR_CLASS_CONFIG, MonitorableSinkConnector.class.getSimpleName()); + props.put(TASKS_MAX_CONFIG, String.valueOf(NUM_TASKS)); + props.put(TOPICS_CONFIG, String.join(",", topics)); + props.put(KEY_CONVERTER_CLASS_CONFIG, StringConverter.class.getName()); + props.put(VALUE_CONVERTER_CLASS_CONFIG, StringConverter.class.getName()); + return props; + } + +} diff --git a/connect/runtime/src/test/java/org/apache/kafka/connect/integration/ErrantRecordSinkConnector.java b/connect/runtime/src/test/java/org/apache/kafka/connect/integration/ErrantRecordSinkConnector.java new file mode 100644 index 0000000..251c67c --- /dev/null +++ b/connect/runtime/src/test/java/org/apache/kafka/connect/integration/ErrantRecordSinkConnector.java @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.connect.integration; + +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.connect.connector.Task; +import org.apache.kafka.connect.sink.ErrantRecordReporter; +import org.apache.kafka.connect.sink.SinkRecord; + +import java.util.Collection; +import java.util.HashMap; +import java.util.Map; + +public class ErrantRecordSinkConnector extends MonitorableSinkConnector { + + @Override + public Class taskClass() { + return ErrantRecordSinkTask.class; + } + + public static class ErrantRecordSinkTask extends MonitorableSinkTask { + private ErrantRecordReporter reporter; + + public ErrantRecordSinkTask() { + super(); + } + + @Override + public void start(Map props) { + super.start(props); + reporter = context.errantRecordReporter(); + } + + @Override + public void put(Collection records) { + for (SinkRecord rec : records) { + taskHandle.record(); + TopicPartition tp = cachedTopicPartitions + .computeIfAbsent(rec.topic(), v -> new HashMap<>()) + .computeIfAbsent(rec.kafkaPartition(), v -> new TopicPartition(rec.topic(), rec.kafkaPartition())); + committedOffsets.put(tp, committedOffsets.getOrDefault(tp, 0) + 1); + reporter.report(rec, new Throwable()); + } + } + } +} diff --git a/connect/runtime/src/test/java/org/apache/kafka/connect/integration/ErrorHandlingIntegrationTest.java b/connect/runtime/src/test/java/org/apache/kafka/connect/integration/ErrorHandlingIntegrationTest.java new file mode 100644 index 0000000..b3dd9a0 --- /dev/null +++ b/connect/runtime/src/test/java/org/apache/kafka/connect/integration/ErrorHandlingIntegrationTest.java @@ -0,0 +1,328 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.connect.integration; + +import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.clients.consumer.ConsumerRecords; +import org.apache.kafka.common.config.ConfigDef; +import org.apache.kafka.common.header.Headers; +import org.apache.kafka.connect.connector.ConnectRecord; +import org.apache.kafka.connect.errors.RetriableException; +import org.apache.kafka.connect.runtime.rest.entities.ConnectorStateInfo; +import org.apache.kafka.connect.storage.StringConverter; +import org.apache.kafka.connect.transforms.Transformation; +import org.apache.kafka.connect.util.clusters.EmbeddedConnectCluster; +import org.apache.kafka.test.IntegrationTest; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.TimeUnit; + +import static org.apache.kafka.connect.runtime.ConnectorConfig.CONNECTOR_CLASS_CONFIG; +import static org.apache.kafka.connect.runtime.ConnectorConfig.ERRORS_LOG_ENABLE_CONFIG; +import static org.apache.kafka.connect.runtime.ConnectorConfig.ERRORS_LOG_INCLUDE_MESSAGES_CONFIG; +import static org.apache.kafka.connect.runtime.ConnectorConfig.ERRORS_RETRY_TIMEOUT_CONFIG; +import static org.apache.kafka.connect.runtime.ConnectorConfig.ERRORS_TOLERANCE_CONFIG; +import static org.apache.kafka.connect.runtime.ConnectorConfig.KEY_CONVERTER_CLASS_CONFIG; +import static org.apache.kafka.connect.runtime.ConnectorConfig.TASKS_MAX_CONFIG; +import static org.apache.kafka.connect.runtime.ConnectorConfig.TRANSFORMS_CONFIG; +import static org.apache.kafka.connect.runtime.ConnectorConfig.VALUE_CONVERTER_CLASS_CONFIG; +import static org.apache.kafka.connect.runtime.SinkConnectorConfig.DLQ_CONTEXT_HEADERS_ENABLE_CONFIG; +import static org.apache.kafka.connect.runtime.SinkConnectorConfig.DLQ_TOPIC_NAME_CONFIG; +import static org.apache.kafka.connect.runtime.SinkConnectorConfig.DLQ_TOPIC_REPLICATION_FACTOR_CONFIG; +import static org.apache.kafka.connect.runtime.SinkConnectorConfig.TOPICS_CONFIG; +import static org.apache.kafka.connect.runtime.errors.DeadLetterQueueReporter.ERROR_HEADER_EXCEPTION; +import static org.apache.kafka.connect.runtime.errors.DeadLetterQueueReporter.ERROR_HEADER_EXCEPTION_MESSAGE; +import static org.apache.kafka.connect.runtime.errors.DeadLetterQueueReporter.ERROR_HEADER_ORIG_TOPIC; +import static org.apache.kafka.test.TestUtils.waitForCondition; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +/** + * Integration test for the different error handling policies in Connect (namely, retry policies, skipping bad records, + * and dead letter queues). 
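+ *
+ * <p>For orientation, the settings exercised below correspond to connector configuration keys along
+ * these lines (a sketch; the imported *_CONFIG constants are the authoritative names):
+ * <pre>{@code
+ * errors.tolerance=all
+ * errors.retry.timeout=1000
+ * errors.log.enable=true
+ * errors.log.include.messages=true
+ * errors.deadletterqueue.topic.name=my-connector-errors
+ * errors.deadletterqueue.context.headers.enable=true
+ * }</pre>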
+ */ +@Category(IntegrationTest.class) +public class ErrorHandlingIntegrationTest { + + private static final Logger log = LoggerFactory.getLogger(ErrorHandlingIntegrationTest.class); + + private static final int NUM_WORKERS = 1; + private static final String DLQ_TOPIC = "my-connector-errors"; + private static final String CONNECTOR_NAME = "error-conn"; + private static final String TASK_ID = "error-conn-0"; + private static final int NUM_RECORDS_PRODUCED = 20; + private static final int EXPECTED_CORRECT_RECORDS = 19; + private static final int EXPECTED_INCORRECT_RECORDS = 1; + private static final int NUM_TASKS = 1; + private static final long CONNECTOR_SETUP_DURATION_MS = TimeUnit.SECONDS.toMillis(60); + private static final long CONSUME_MAX_DURATION_MS = TimeUnit.SECONDS.toMillis(30); + + private EmbeddedConnectCluster connect; + private ConnectorHandle connectorHandle; + + @Before + public void setup() throws InterruptedException { + // setup Connect cluster with defaults + connect = new EmbeddedConnectCluster.Builder().build(); + + // start Connect cluster + connect.start(); + connect.assertions().assertAtLeastNumWorkersAreUp(NUM_WORKERS, + "Initial group of workers did not start in time."); + + // get connector handles before starting test. + connectorHandle = RuntimeHandles.get().connectorHandle(CONNECTOR_NAME); + } + + @After + public void close() { + RuntimeHandles.get().deleteConnector(CONNECTOR_NAME); + connect.stop(); + } + + @Test + public void testSkipRetryAndDLQWithHeaders() throws Exception { + // create test topic + connect.kafka().createTopic("test-topic"); + + // setup connector config + Map props = new HashMap<>(); + props.put(CONNECTOR_CLASS_CONFIG, MonitorableSinkConnector.class.getSimpleName()); + props.put(TASKS_MAX_CONFIG, String.valueOf(NUM_TASKS)); + props.put(TOPICS_CONFIG, "test-topic"); + props.put(KEY_CONVERTER_CLASS_CONFIG, StringConverter.class.getName()); + props.put(VALUE_CONVERTER_CLASS_CONFIG, StringConverter.class.getName()); + props.put(TRANSFORMS_CONFIG, "failing_transform"); + props.put("transforms.failing_transform.type", FaultyPassthrough.class.getName()); + + // log all errors, along with message metadata + props.put(ERRORS_LOG_ENABLE_CONFIG, "true"); + props.put(ERRORS_LOG_INCLUDE_MESSAGES_CONFIG, "true"); + + // produce bad messages into dead letter queue + props.put(DLQ_TOPIC_NAME_CONFIG, DLQ_TOPIC); + props.put(DLQ_CONTEXT_HEADERS_ENABLE_CONFIG, "true"); + props.put(DLQ_TOPIC_REPLICATION_FACTOR_CONFIG, "1"); + + // tolerate all erros + props.put(ERRORS_TOLERANCE_CONFIG, "all"); + + // retry for up to one second + props.put(ERRORS_RETRY_TIMEOUT_CONFIG, "1000"); + + // set expected records to successfully reach the task + connectorHandle.taskHandle(TASK_ID).expectedRecords(EXPECTED_CORRECT_RECORDS); + + connect.configureConnector(CONNECTOR_NAME, props); + connect.assertions().assertConnectorAndAtLeastNumTasksAreRunning(CONNECTOR_NAME, NUM_TASKS, + "Connector tasks did not start in time."); + + waitForCondition(this::checkForPartitionAssignment, + CONNECTOR_SETUP_DURATION_MS, + "Connector task was not assigned a partition."); + + // produce some strings into test topic + for (int i = 0; i < NUM_RECORDS_PRODUCED; i++) { + connect.kafka().produce("test-topic", "key-" + i, "value-" + i); + } + + // consume all records from test topic + log.info("Consuming records from test topic"); + int i = 0; + for (ConsumerRecord rec : connect.kafka().consume(NUM_RECORDS_PRODUCED, CONSUME_MAX_DURATION_MS, "test-topic")) { + String k = new String(rec.key()); + 
String v = new String(rec.value()); + log.debug("Consumed record (key='{}', value='{}') from topic {}", k, v, rec.topic()); + assertEquals("Unexpected key", k, "key-" + i); + assertEquals("Unexpected value", v, "value-" + i); + i++; + } + + // wait for records to reach the task + connectorHandle.taskHandle(TASK_ID).awaitRecords(CONSUME_MAX_DURATION_MS); + + // consume failed records from dead letter queue topic + log.info("Consuming records from test topic"); + ConsumerRecords messages = connect.kafka().consume(EXPECTED_INCORRECT_RECORDS, CONSUME_MAX_DURATION_MS, DLQ_TOPIC); + for (ConsumerRecord recs : messages) { + log.debug("Consumed record (key={}, value={}) from dead letter queue topic {}", + new String(recs.key()), new String(recs.value()), DLQ_TOPIC); + assertTrue(recs.headers().toArray().length > 0); + assertValue("test-topic", recs.headers(), ERROR_HEADER_ORIG_TOPIC); + assertValue(RetriableException.class.getName(), recs.headers(), ERROR_HEADER_EXCEPTION); + assertValue("Error when value='value-7'", recs.headers(), ERROR_HEADER_EXCEPTION_MESSAGE); + } + + connect.deleteConnector(CONNECTOR_NAME); + connect.assertions().assertConnectorAndTasksAreStopped(CONNECTOR_NAME, + "Connector tasks did not stop in time."); + + } + + @Test + public void testErrantRecordReporter() throws Exception { + // create test topic + connect.kafka().createTopic("test-topic"); + + // setup connector config + Map props = new HashMap<>(); + props.put(CONNECTOR_CLASS_CONFIG, ErrantRecordSinkConnector.class.getSimpleName()); + props.put(TASKS_MAX_CONFIG, String.valueOf(NUM_TASKS)); + props.put(TOPICS_CONFIG, "test-topic"); + props.put(KEY_CONVERTER_CLASS_CONFIG, StringConverter.class.getName()); + props.put(VALUE_CONVERTER_CLASS_CONFIG, StringConverter.class.getName()); + + // log all errors, along with message metadata + props.put(ERRORS_LOG_ENABLE_CONFIG, "true"); + props.put(ERRORS_LOG_INCLUDE_MESSAGES_CONFIG, "true"); + + // produce bad messages into dead letter queue + props.put(DLQ_TOPIC_NAME_CONFIG, DLQ_TOPIC); + props.put(DLQ_CONTEXT_HEADERS_ENABLE_CONFIG, "true"); + props.put(DLQ_TOPIC_REPLICATION_FACTOR_CONFIG, "1"); + + // tolerate all erros + props.put(ERRORS_TOLERANCE_CONFIG, "all"); + + // retry for up to one second + props.put(ERRORS_RETRY_TIMEOUT_CONFIG, "1000"); + + // set expected records to successfully reach the task + connectorHandle.taskHandle(TASK_ID).expectedRecords(EXPECTED_CORRECT_RECORDS); + + connect.configureConnector(CONNECTOR_NAME, props); + connect.assertions().assertConnectorAndAtLeastNumTasksAreRunning(CONNECTOR_NAME, NUM_TASKS, + "Connector tasks did not start in time."); + + waitForCondition(this::checkForPartitionAssignment, + CONNECTOR_SETUP_DURATION_MS, + "Connector task was not assigned a partition."); + + // produce some strings into test topic + for (int i = 0; i < NUM_RECORDS_PRODUCED; i++) { + connect.kafka().produce("test-topic", "key-" + i, "value-" + i); + } + + // consume all records from test topic + log.info("Consuming records from test topic"); + int i = 0; + for (ConsumerRecord rec : connect.kafka().consume(NUM_RECORDS_PRODUCED, CONSUME_MAX_DURATION_MS, "test-topic")) { + String k = new String(rec.key()); + String v = new String(rec.value()); + log.debug("Consumed record (key='{}', value='{}') from topic {}", k, v, rec.topic()); + assertEquals("Unexpected key", k, "key-" + i); + assertEquals("Unexpected value", v, "value-" + i); + i++; + } + + // wait for records to reach the task + connectorHandle.taskHandle(TASK_ID).awaitRecords(CONSUME_MAX_DURATION_MS); + 
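+ // Note: ErrantRecordSinkTask reports every record it receives via context.errantRecordReporter(),
+ // so records land in the DLQ topic even though processing itself succeeds; the consume below simply
+ // checks that at least EXPECTED_INCORRECT_RECORDS such records arrive.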
+ // consume failed records from dead letter queue topic + log.info("Consuming records from test topic"); + ConsumerRecords messages = connect.kafka().consume(EXPECTED_INCORRECT_RECORDS, CONSUME_MAX_DURATION_MS, DLQ_TOPIC); + + connect.deleteConnector(CONNECTOR_NAME); + connect.assertions().assertConnectorAndTasksAreStopped(CONNECTOR_NAME, + "Connector tasks did not stop in time."); + } + + /** + * Check if a partition was assigned to each task. This method swallows exceptions since it is invoked from a + * {@link org.apache.kafka.test.TestUtils#waitForCondition} that will throw an error if this method continued + * to return false after the specified duration has elapsed. + * + * @return true if each task was assigned a partition each, false if this was not true or an error occurred when + * executing this operation. + */ + private boolean checkForPartitionAssignment() { + try { + ConnectorStateInfo info = connect.connectorStatus(CONNECTOR_NAME); + return info != null && info.tasks().size() == NUM_TASKS + && connectorHandle.taskHandle(TASK_ID).numPartitionsAssigned() == 1; + } catch (Exception e) { + // Log the exception and return that the partitions were not assigned + log.error("Could not check connector state info.", e); + return false; + } + } + + private void assertValue(String expected, Headers headers, String headerKey) { + byte[] actual = headers.lastHeader(headerKey).value(); + if (expected == null && actual == null) { + return; + } + if (expected == null || actual == null) { + fail(); + } + assertEquals(expected, new String(actual)); + } + + public static class FaultyPassthrough> implements Transformation { + + static final ConfigDef CONFIG_DEF = new ConfigDef(); + + /** + * An arbitrary id which causes this transformation to fail with a {@link RetriableException}, but succeeds + * on subsequent attempt. + */ + static final int BAD_RECORD_VAL_RETRIABLE = 4; + + /** + * An arbitrary id which causes this transformation to fail with a {@link RetriableException}. + */ + static final int BAD_RECORD_VAL = 7; + + private boolean shouldFail = true; + + @Override + public R apply(R record) { + String badValRetriable = "value-" + BAD_RECORD_VAL_RETRIABLE; + if (badValRetriable.equals(record.value()) && shouldFail) { + shouldFail = false; + throw new RetriableException("Error when value='" + badValRetriable + + "'. A reattempt with this record will succeed."); + } + String badVal = "value-" + BAD_RECORD_VAL; + if (badVal.equals(record.value())) { + throw new RetriableException("Error when value='" + badVal + "'"); + } + return record; + } + + @Override + public ConfigDef config() { + return CONFIG_DEF; + } + + @Override + public void close() { + } + + @Override + public void configure(Map configs) { + } + } +} diff --git a/connect/runtime/src/test/java/org/apache/kafka/connect/integration/ExampleConnectIntegrationTest.java b/connect/runtime/src/test/java/org/apache/kafka/connect/integration/ExampleConnectIntegrationTest.java new file mode 100644 index 0000000..23a87c2 --- /dev/null +++ b/connect/runtime/src/test/java/org/apache/kafka/connect/integration/ExampleConnectIntegrationTest.java @@ -0,0 +1,246 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.integration; + +import org.apache.kafka.connect.runtime.rest.entities.ConnectorStateInfo; +import org.apache.kafka.connect.storage.StringConverter; +import org.apache.kafka.connect.util.clusters.EmbeddedConnectCluster; +import org.apache.kafka.test.IntegrationTest; +import org.junit.After; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.rules.TestRule; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.HashMap; +import java.util.Map; +import java.util.Properties; +import java.util.concurrent.TimeUnit; + +import static org.apache.kafka.connect.runtime.ConnectorConfig.CONNECTOR_CLASS_CONFIG; +import static org.apache.kafka.connect.runtime.ConnectorConfig.KEY_CONVERTER_CLASS_CONFIG; +import static org.apache.kafka.connect.runtime.ConnectorConfig.TASKS_MAX_CONFIG; +import static org.apache.kafka.connect.runtime.ConnectorConfig.VALUE_CONVERTER_CLASS_CONFIG; +import static org.apache.kafka.connect.runtime.SinkConnectorConfig.TOPICS_CONFIG; +import static org.apache.kafka.connect.runtime.TopicCreationConfig.DEFAULT_TOPIC_CREATION_PREFIX; +import static org.apache.kafka.connect.runtime.TopicCreationConfig.PARTITIONS_CONFIG; +import static org.apache.kafka.connect.runtime.TopicCreationConfig.REPLICATION_FACTOR_CONFIG; +import static org.apache.kafka.connect.runtime.WorkerConfig.OFFSET_COMMIT_INTERVAL_MS_CONFIG; +import static org.apache.kafka.test.TestUtils.waitForCondition; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +/** + * An example integration test that demonstrates how to setup an integration test for Connect. + *
            + * The following test configures and executes up a sink connector pipeline in a worker, produces messages into + * the source topic-partitions, and demonstrates how to check the overall behavior of the pipeline. + */ +@Category(IntegrationTest.class) +public class ExampleConnectIntegrationTest { + + private static final Logger log = LoggerFactory.getLogger(ExampleConnectIntegrationTest.class); + + private static final int NUM_RECORDS_PRODUCED = 2000; + private static final int NUM_TOPIC_PARTITIONS = 3; + private static final long RECORD_TRANSFER_DURATION_MS = TimeUnit.SECONDS.toMillis(30); + private static final long CONNECTOR_SETUP_DURATION_MS = TimeUnit.SECONDS.toMillis(60); + private static final int NUM_TASKS = 3; + private static final int NUM_WORKERS = 3; + private static final String CONNECTOR_NAME = "simple-conn"; + private static final String SINK_CONNECTOR_CLASS_NAME = MonitorableSinkConnector.class.getSimpleName(); + private static final String SOURCE_CONNECTOR_CLASS_NAME = MonitorableSourceConnector.class.getSimpleName(); + + private EmbeddedConnectCluster connect; + private ConnectorHandle connectorHandle; + + @Rule + public TestRule watcher = ConnectIntegrationTestUtils.newTestWatcher(log); + + @Before + public void setup() { + // setup Connect worker properties + Map exampleWorkerProps = new HashMap<>(); + exampleWorkerProps.put(OFFSET_COMMIT_INTERVAL_MS_CONFIG, String.valueOf(5_000)); + + // setup Kafka broker properties + Properties exampleBrokerProps = new Properties(); + exampleBrokerProps.put("auto.create.topics.enable", "false"); + + // build a Connect cluster backed by Kafka and Zk + connect = new EmbeddedConnectCluster.Builder() + .name("connect-cluster") + .numWorkers(NUM_WORKERS) + .numBrokers(1) + .workerProps(exampleWorkerProps) + .brokerProps(exampleBrokerProps) + .build(); + + // start the clusters + connect.start(); + + // get a handle to the connector + connectorHandle = RuntimeHandles.get().connectorHandle(CONNECTOR_NAME); + } + + @After + public void close() { + // delete connector handle + RuntimeHandles.get().deleteConnector(CONNECTOR_NAME); + + // stop all Connect, Kafka and Zk threads. + connect.stop(); + } + + /** + * Simple test case to configure and execute an embedded Connect cluster. The test will produce and consume + * records, and start up a sink connector which will consume these records. 
+ */ + @Test + public void testSinkConnector() throws Exception { + // create test topic + connect.kafka().createTopic("test-topic", NUM_TOPIC_PARTITIONS); + + // setup up props for the sink connector + Map props = new HashMap<>(); + props.put(CONNECTOR_CLASS_CONFIG, SINK_CONNECTOR_CLASS_NAME); + props.put(TASKS_MAX_CONFIG, String.valueOf(NUM_TASKS)); + props.put(TOPICS_CONFIG, "test-topic"); + props.put(KEY_CONVERTER_CLASS_CONFIG, StringConverter.class.getName()); + props.put(VALUE_CONVERTER_CLASS_CONFIG, StringConverter.class.getName()); + + // expect all records to be consumed by the connector + connectorHandle.expectedRecords(NUM_RECORDS_PRODUCED); + + // expect all records to be consumed by the connector + connectorHandle.expectedCommits(NUM_RECORDS_PRODUCED); + + // validate the intended connector configuration, a config that errors + connect.assertions().assertExactlyNumErrorsOnConnectorConfigValidation(SINK_CONNECTOR_CLASS_NAME, props, 1, + "Validating connector configuration produced an unexpected number or errors."); + + // add missing configuration to make the config valid + props.put("name", CONNECTOR_NAME); + + // validate the intended connector configuration, a valid config + connect.assertions().assertExactlyNumErrorsOnConnectorConfigValidation(SINK_CONNECTOR_CLASS_NAME, props, 0, + "Validating connector configuration produced an unexpected number or errors."); + + // start a sink connector + connect.configureConnector(CONNECTOR_NAME, props); + + waitForCondition(this::checkForPartitionAssignment, + CONNECTOR_SETUP_DURATION_MS, + "Connector tasks were not assigned a partition each."); + + // produce some messages into source topic partitions + for (int i = 0; i < NUM_RECORDS_PRODUCED; i++) { + connect.kafka().produce("test-topic", i % NUM_TOPIC_PARTITIONS, "key", "simple-message-value-" + i); + } + + // consume all records from the source topic or fail, to ensure that they were correctly produced. + assertEquals("Unexpected number of records consumed", NUM_RECORDS_PRODUCED, + connect.kafka().consume(NUM_RECORDS_PRODUCED, RECORD_TRANSFER_DURATION_MS, "test-topic").count()); + + // wait for the connector tasks to consume all records. + connectorHandle.awaitRecords(RECORD_TRANSFER_DURATION_MS); + + // wait for the connector tasks to commit all records. + connectorHandle.awaitCommits(RECORD_TRANSFER_DURATION_MS); + + // delete connector + connect.deleteConnector(CONNECTOR_NAME); + } + + /** + * Simple test case to configure and execute an embedded Connect cluster. The test will produce and consume + * records, and start up a sink connector which will consume these records. 
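// Illustrative sketch: the assertExactlyNumErrorsOnConnectorConfigValidation(...) checks used in
// these tests are assumed to be backed by Connect's REST validation endpoint,
// PUT /connector-plugins/{pluginName}/config/validate, which reports configuration errors without
// creating the connector. Invoked directly it might look like this (worker URL and plugin name
// are hypothetical; a config missing "name" should report one error):
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;

final class ValidateConfigSketch {
    public static void main(String[] args) throws Exception {
        String body = "{\"connector.class\":\"MonitorableSinkConnector\",\"tasks.max\":\"1\",\"topics\":\"test-topic\"}";
        HttpRequest request = HttpRequest.newBuilder()
                .uri(URI.create("http://localhost:8083/connector-plugins/MonitorableSinkConnector/config/validate"))
                .header("Content-Type", "application/json")
                .PUT(HttpRequest.BodyPublishers.ofString(body))
                .build();
        HttpResponse<String> response = HttpClient.newHttpClient()
                .send(request, HttpResponse.BodyHandlers.ofString());
        // The JSON response carries an overall error count alongside per-field validation results.
        System.out.println(response.body());
    }
}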
+ */ + @Test + public void testSourceConnector() throws Exception { + // create test topic + connect.kafka().createTopic("test-topic", NUM_TOPIC_PARTITIONS); + + // setup up props for the source connector + Map props = new HashMap<>(); + props.put(CONNECTOR_CLASS_CONFIG, SOURCE_CONNECTOR_CLASS_NAME); + props.put(TASKS_MAX_CONFIG, String.valueOf(NUM_TASKS)); + props.put("topic", "test-topic"); + props.put("throughput", String.valueOf(500)); + props.put(KEY_CONVERTER_CLASS_CONFIG, StringConverter.class.getName()); + props.put(VALUE_CONVERTER_CLASS_CONFIG, StringConverter.class.getName()); + props.put(DEFAULT_TOPIC_CREATION_PREFIX + REPLICATION_FACTOR_CONFIG, String.valueOf(1)); + props.put(DEFAULT_TOPIC_CREATION_PREFIX + PARTITIONS_CONFIG, String.valueOf(1)); + + // expect all records to be produced by the connector + connectorHandle.expectedRecords(NUM_RECORDS_PRODUCED); + + // expect all records to be produced by the connector + connectorHandle.expectedCommits(NUM_RECORDS_PRODUCED); + + // validate the intended connector configuration, a config that errors + connect.assertions().assertExactlyNumErrorsOnConnectorConfigValidation(SOURCE_CONNECTOR_CLASS_NAME, props, 1, + "Validating connector configuration produced an unexpected number or errors."); + + // add missing configuration to make the config valid + props.put("name", CONNECTOR_NAME); + + // validate the intended connector configuration, a valid config + connect.assertions().assertExactlyNumErrorsOnConnectorConfigValidation(SOURCE_CONNECTOR_CLASS_NAME, props, 0, + "Validating connector configuration produced an unexpected number or errors."); + + // start a source connector + connect.configureConnector(CONNECTOR_NAME, props); + + // wait for the connector tasks to produce enough records + connectorHandle.awaitRecords(RECORD_TRANSFER_DURATION_MS); + + // wait for the connector tasks to commit enough records + connectorHandle.awaitCommits(RECORD_TRANSFER_DURATION_MS); + + // consume all records from the source topic or fail, to ensure that they were correctly produced + int recordNum = connect.kafka().consume(NUM_RECORDS_PRODUCED, RECORD_TRANSFER_DURATION_MS, "test-topic").count(); + assertTrue("Not enough records produced by source connector. Expected at least: " + NUM_RECORDS_PRODUCED + " + but got " + recordNum, + recordNum >= NUM_RECORDS_PRODUCED); + + // delete connector + connect.deleteConnector(CONNECTOR_NAME); + } + + /** + * Check if a partition was assigned to each task. This method swallows exceptions since it is invoked from a + * {@link org.apache.kafka.test.TestUtils#waitForCondition} that will throw an error if this method continued + * to return false after the specified duration has elapsed. + * + * @return true if each task was assigned a partition each, false if this was not true or an error occurred when + * executing this operation. 
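// Illustrative sketch: the helper documented above is meant to be polled by
// org.apache.kafka.test.TestUtils#waitForCondition, whose behaviour (as used throughout these
// tests) is to re-invoke the supplied check until it returns true or the timeout elapses, then
// fail with the supplied message. Swallowing exceptions inside the check keeps a transient error
// from aborting the wait early; the probe below is a placeholder.
import org.apache.kafka.test.TestUtils;

final class WaitForConditionSketch {
    static void awaitReady(long timeoutMs) throws InterruptedException {
        TestUtils.waitForCondition(
                () -> {
                    try {
                        return probeClusterState();   // hypothetical status probe
                    } catch (Exception e) {
                        return false;                 // report "not ready yet" instead of failing hard
                    }
                },
                timeoutMs,
                "Cluster did not become ready in time.");
    }

    private static boolean probeClusterState() {
        return true;                                  // placeholder for a real check
    }
}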
+ */ + private boolean checkForPartitionAssignment() { + try { + ConnectorStateInfo info = connect.connectorStatus(CONNECTOR_NAME); + return info != null && info.tasks().size() == NUM_TASKS + && connectorHandle.tasks().stream().allMatch(th -> th.numPartitionsAssigned() == 1); + } catch (Exception e) { + // Log the exception and return that the partitions were not assigned + log.error("Could not check connector state info.", e); + return false; + } + } +} diff --git a/connect/runtime/src/test/java/org/apache/kafka/connect/integration/InternalTopicsIntegrationTest.java b/connect/runtime/src/test/java/org/apache/kafka/connect/integration/InternalTopicsIntegrationTest.java new file mode 100644 index 0000000..d73d1c4 --- /dev/null +++ b/connect/runtime/src/test/java/org/apache/kafka/connect/integration/InternalTopicsIntegrationTest.java @@ -0,0 +1,315 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.integration; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.Properties; + +import org.apache.kafka.common.config.TopicConfig; +import org.apache.kafka.connect.runtime.distributed.DistributedConfig; +import org.apache.kafka.connect.util.clusters.EmbeddedConnectCluster; +import org.apache.kafka.connect.util.clusters.WorkerHandle; +import org.apache.kafka.test.IntegrationTest; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import static org.junit.Assert.assertFalse; + +/** + * Integration test for the creation of internal topics. + */ +@Category(IntegrationTest.class) +public class InternalTopicsIntegrationTest { + + private static final Logger log = LoggerFactory.getLogger(InternalTopicsIntegrationTest.class); + + private EmbeddedConnectCluster connect; + Map workerProps = new HashMap<>(); + Properties brokerProps = new Properties(); + + @Before + public void setup() { + // setup Kafka broker properties + brokerProps.put("auto.create.topics.enable", String.valueOf(false)); + } + + @After + public void close() { + // stop all Connect, Kafka and Zk threads. 
+ connect.stop(); + } + + @Test + public void testCreateInternalTopicsWithDefaultSettings() throws InterruptedException { + int numWorkers = 1; + int numBrokers = 3; + connect = new EmbeddedConnectCluster.Builder().name("connect-cluster-1") + .workerProps(workerProps) + .numWorkers(numWorkers) + .numBrokers(numBrokers) + .brokerProps(brokerProps) + .build(); + + // Start the Connect cluster + connect.start(); + connect.assertions().assertExactlyNumBrokersAreUp(numBrokers, "Brokers did not start in time."); + connect.assertions().assertExactlyNumWorkersAreUp(numWorkers, "Worker did not start in time."); + log.info("Completed startup of {} Kafka brokers and {} Connect workers", numBrokers, numWorkers); + + // Check the topics + log.info("Verifying the internal topics for Connect"); + connect.assertions().assertTopicsExist(configTopic(), offsetTopic(), statusTopic()); + assertInternalTopicSettings(); + + // Remove the Connect worker + log.info("Stopping the Connect worker"); + connect.removeWorker(); + + // And restart + log.info("Starting the Connect worker"); + connect.startConnect(); + + // Check the topics + log.info("Verifying the internal topics for Connect"); + connect.assertions().assertTopicsExist(configTopic(), offsetTopic(), statusTopic()); + assertInternalTopicSettings(); + } + + @Test + public void testCreateInternalTopicsWithFewerReplicasThanBrokers() throws InterruptedException { + workerProps.put(DistributedConfig.CONFIG_STORAGE_REPLICATION_FACTOR_CONFIG, "1"); + workerProps.put(DistributedConfig.OFFSET_STORAGE_REPLICATION_FACTOR_CONFIG, "2"); + workerProps.put(DistributedConfig.STATUS_STORAGE_REPLICATION_FACTOR_CONFIG, "1"); + int numWorkers = 1; + int numBrokers = 2; + connect = new EmbeddedConnectCluster.Builder().name("connect-cluster-1") + .workerProps(workerProps) + .numWorkers(numWorkers) + .numBrokers(numBrokers) + .brokerProps(brokerProps) + .build(); + + // Start the Connect cluster + connect.start(); + connect.assertions().assertExactlyNumBrokersAreUp(numBrokers, "Broker did not start in time."); + connect.assertions().assertAtLeastNumWorkersAreUp(numWorkers, "Worker did not start in time."); + log.info("Completed startup of {} Kafka brokers and {} Connect workers", numBrokers, numWorkers); + + // Check the topics + log.info("Verifying the internal topics for Connect"); + connect.assertions().assertTopicsExist(configTopic(), offsetTopic(), statusTopic()); + assertInternalTopicSettings(); + } + + @Test + public void testFailToCreateInternalTopicsWithMoreReplicasThanBrokers() throws InterruptedException { + workerProps.put(DistributedConfig.CONFIG_STORAGE_REPLICATION_FACTOR_CONFIG, "3"); + workerProps.put(DistributedConfig.OFFSET_STORAGE_REPLICATION_FACTOR_CONFIG, "2"); + workerProps.put(DistributedConfig.STATUS_STORAGE_REPLICATION_FACTOR_CONFIG, "1"); + int numWorkers = 1; + int numBrokers = 1; + connect = new EmbeddedConnectCluster.Builder().name("connect-cluster-1") + .workerProps(workerProps) + .numWorkers(numWorkers) + .numBrokers(numBrokers) + .brokerProps(brokerProps) + .build(); + + // Start the brokers and Connect, but Connect should fail to create config and offset topic + connect.start(); + connect.assertions().assertExactlyNumBrokersAreUp(numBrokers, "Broker did not start in time."); + log.info("Completed startup of {} Kafka broker. 
Expected Connect worker to fail", numBrokers); + + // Verify that the offset and config topic don't exist; + // the status topic may have been created if timing was right but we don't care + log.info("Verifying the internal topics for Connect"); + connect.assertions().assertTopicsDoNotExist(configTopic(), offsetTopic()); + } + + @Test + public void testFailToStartWhenInternalTopicsAreNotCompacted() throws InterruptedException { + // Change the broker default cleanup policy to something Connect doesn't like + brokerProps.put("log." + TopicConfig.CLEANUP_POLICY_CONFIG, TopicConfig.CLEANUP_POLICY_DELETE); + // Start out using the improperly configured topics + workerProps.put(DistributedConfig.CONFIG_TOPIC_CONFIG, "bad-config"); + workerProps.put(DistributedConfig.OFFSET_STORAGE_TOPIC_CONFIG, "bad-offset"); + workerProps.put(DistributedConfig.STATUS_STORAGE_TOPIC_CONFIG, "bad-status"); + workerProps.put(DistributedConfig.CONFIG_STORAGE_REPLICATION_FACTOR_CONFIG, "1"); + workerProps.put(DistributedConfig.OFFSET_STORAGE_REPLICATION_FACTOR_CONFIG, "1"); + workerProps.put(DistributedConfig.STATUS_STORAGE_REPLICATION_FACTOR_CONFIG, "1"); + int numWorkers = 0; + int numBrokers = 1; + connect = new EmbeddedConnectCluster.Builder().name("connect-cluster-1") + .workerProps(workerProps) + .numWorkers(numWorkers) + .numBrokers(numBrokers) + .brokerProps(brokerProps) + .build(); + + // Start the brokers but not Connect + log.info("Starting {} Kafka brokers, but no Connect workers yet", numBrokers); + connect.start(); + connect.assertions().assertExactlyNumBrokersAreUp(numBrokers, "Broker did not start in time."); + log.info("Completed startup of {} Kafka broker. Expected Connect worker to fail", numBrokers); + + // Create the good topics + connect.kafka().createTopic("good-config", 1, 1, compactCleanupPolicy()); + connect.kafka().createTopic("good-offset", 1, 1, compactCleanupPolicy()); + connect.kafka().createTopic("good-status", 1, 1, compactCleanupPolicy()); + + // Create the poorly-configured topics + connect.kafka().createTopic("bad-config", 1, 1, deleteCleanupPolicy()); + connect.kafka().createTopic("bad-offset", 1, 1, compactAndDeleteCleanupPolicy()); + connect.kafka().createTopic("bad-status", 1, 1, noTopicSettings()); + + // Check the topics + log.info("Verifying the internal topics for Connect were manually created"); + connect.assertions().assertTopicsExist("good-config", "good-offset", "good-status", "bad-config", "bad-offset", "bad-status"); + + // Try to start one worker, with three bad topics + WorkerHandle worker = connect.addWorker(); // should have failed to start before returning + assertFalse(worker.isRunning()); + assertFalse(connect.allWorkersRunning()); + assertFalse(connect.anyWorkersRunning()); + connect.removeWorker(worker); + + // We rely upon the fact that we can change the worker properties before the workers are started + workerProps.put(DistributedConfig.CONFIG_TOPIC_CONFIG, "good-config"); + + // Try to start one worker, with two bad topics remaining + worker = connect.addWorker(); // should have failed to start before returning + assertFalse(worker.isRunning()); + assertFalse(connect.allWorkersRunning()); + assertFalse(connect.anyWorkersRunning()); + connect.removeWorker(worker); + + // We rely upon the fact that we can change the worker properties before the workers are started + workerProps.put(DistributedConfig.OFFSET_STORAGE_TOPIC_CONFIG, "good-offset"); + + // Try to start one worker, with one bad topic remaining + worker = connect.addWorker(); // should have failed 
to start before returning + assertFalse(worker.isRunning()); + assertFalse(connect.allWorkersRunning()); + assertFalse(connect.anyWorkersRunning()); + connect.removeWorker(worker); + // We rely upon the fact that we can change the worker properties before the workers are started + workerProps.put(DistributedConfig.STATUS_STORAGE_TOPIC_CONFIG, "good-status"); + + // Try to start one worker, now using all good internal topics + connect.addWorker(); + connect.assertions().assertAtLeastNumWorkersAreUp(1, "Worker did not start in time."); + } + + @Test + public void testStartWhenInternalTopicsCreatedManuallyWithCompactForBrokersDefaultCleanupPolicy() throws InterruptedException { + // Change the broker default cleanup policy to compact, which is good for Connect + brokerProps.put("log." + TopicConfig.CLEANUP_POLICY_CONFIG, TopicConfig.CLEANUP_POLICY_COMPACT); + // Start out using the properly configured topics + workerProps.put(DistributedConfig.CONFIG_TOPIC_CONFIG, "config-topic"); + workerProps.put(DistributedConfig.OFFSET_STORAGE_TOPIC_CONFIG, "offset-topic"); + workerProps.put(DistributedConfig.STATUS_STORAGE_TOPIC_CONFIG, "status-topic"); + workerProps.put(DistributedConfig.CONFIG_STORAGE_REPLICATION_FACTOR_CONFIG, "1"); + workerProps.put(DistributedConfig.OFFSET_STORAGE_REPLICATION_FACTOR_CONFIG, "1"); + workerProps.put(DistributedConfig.STATUS_STORAGE_REPLICATION_FACTOR_CONFIG, "1"); + int numWorkers = 0; + int numBrokers = 1; + connect = new EmbeddedConnectCluster.Builder().name("connect-cluster-1") + .workerProps(workerProps) + .numWorkers(numWorkers) + .numBrokers(numBrokers) + .brokerProps(brokerProps) + .build(); + + // Start the brokers but not Connect + log.info("Starting {} Kafka brokers, but no Connect workers yet", numBrokers); + connect.start(); + connect.assertions().assertExactlyNumBrokersAreUp(numBrokers, "Broker did not start in time."); + log.info("Completed startup of {} Kafka broker. 
Expected Connect worker to fail", numBrokers); + + // Create the valid internal topics w/o topic settings, so these will use the broker's + // broker's log.cleanup.policy=compact (set above) + connect.kafka().createTopic("config-topic", 1, 1, noTopicSettings()); + connect.kafka().createTopic("offset-topic", 1, 1, noTopicSettings()); + connect.kafka().createTopic("status-topic", 1, 1, noTopicSettings()); + + // Check the topics + log.info("Verifying the internal topics for Connect were manually created"); + connect.assertions().assertTopicsExist("config-topic", "offset-topic", "status-topic"); + + // Try to start one worker using valid internal topics + connect.addWorker(); + connect.assertions().assertAtLeastNumWorkersAreUp(1, "Worker did not start in time."); + } + + protected Map compactCleanupPolicy() { + return Collections.singletonMap(TopicConfig.CLEANUP_POLICY_CONFIG, TopicConfig.CLEANUP_POLICY_COMPACT); + } + + protected Map deleteCleanupPolicy() { + return Collections.singletonMap(TopicConfig.CLEANUP_POLICY_CONFIG, TopicConfig.CLEANUP_POLICY_DELETE); + } + + protected Map noTopicSettings() { + return Collections.emptyMap(); + } + + protected Map compactAndDeleteCleanupPolicy() { + Map config = new HashMap<>(); + config.put(TopicConfig.CLEANUP_POLICY_CONFIG, TopicConfig.CLEANUP_POLICY_DELETE + "," + TopicConfig.CLEANUP_POLICY_COMPACT); + return config; + } + + protected void assertInternalTopicSettings() throws InterruptedException { + DistributedConfig config = new DistributedConfig(workerProps); + connect.assertions().assertTopicSettings( + configTopic(), + config.getShort(DistributedConfig.CONFIG_STORAGE_REPLICATION_FACTOR_CONFIG), + 1, + "Config topic does not have the expected settings" + ); + connect.assertions().assertTopicSettings( + statusTopic(), + config.getShort(DistributedConfig.STATUS_STORAGE_REPLICATION_FACTOR_CONFIG), + config.getInt(DistributedConfig.STATUS_STORAGE_PARTITIONS_CONFIG), + "Status topic does not have the expected settings" + ); + connect.assertions().assertTopicSettings( + offsetTopic(), + config.getShort(DistributedConfig.OFFSET_STORAGE_REPLICATION_FACTOR_CONFIG), + config.getInt(DistributedConfig.OFFSET_STORAGE_PARTITIONS_CONFIG), + "Offset topic does not have the expected settings" + ); + } + + protected String configTopic() { + return workerProps.get(DistributedConfig.CONFIG_TOPIC_CONFIG); + } + + protected String offsetTopic() { + return workerProps.get(DistributedConfig.OFFSET_STORAGE_TOPIC_CONFIG); + } + + protected String statusTopic() { + return workerProps.get(DistributedConfig.STATUS_STORAGE_TOPIC_CONFIG); + } +} diff --git a/connect/runtime/src/test/java/org/apache/kafka/connect/integration/MonitorableSinkConnector.java b/connect/runtime/src/test/java/org/apache/kafka/connect/integration/MonitorableSinkConnector.java new file mode 100644 index 0000000..5733199 --- /dev/null +++ b/connect/runtime/src/test/java/org/apache/kafka/connect/integration/MonitorableSinkConnector.java @@ -0,0 +1,159 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
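// Illustrative sketch: the internal-topics tests above hinge on Connect requiring
// cleanup.policy=compact for its config, offset, and status topics. Creating such a topic
// manually with the Kafka Admin client might look like this (topic name and bootstrap address
// are hypothetical):
import java.util.Collections;
import java.util.Properties;
import org.apache.kafka.clients.admin.Admin;
import org.apache.kafka.clients.admin.AdminClientConfig;
import org.apache.kafka.clients.admin.NewTopic;
import org.apache.kafka.common.config.TopicConfig;

final class CompactedTopicSketch {
    public static void main(String[] args) throws Exception {
        Properties props = new Properties();
        props.put(AdminClientConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9092");  // hypothetical broker
        try (Admin admin = Admin.create(props)) {
            NewTopic configTopic = new NewTopic("connect-configs", 1, (short) 1)
                    .configs(Collections.singletonMap(
                            TopicConfig.CLEANUP_POLICY_CONFIG, TopicConfig.CLEANUP_POLICY_COMPACT));
            admin.createTopics(Collections.singleton(configTopic)).all().get();   // block until created
        }
    }
}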
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.integration; + +import org.apache.kafka.clients.consumer.OffsetAndMetadata; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.config.ConfigDef; +import org.apache.kafka.connect.connector.Task; +import org.apache.kafka.connect.runtime.TestSinkConnector; +import org.apache.kafka.connect.sink.SinkRecord; +import org.apache.kafka.connect.sink.SinkTask; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** + * A sink connector that is used in Apache Kafka integration tests to verify the behavior of the + * Connect framework, but that can be used in other integration tests as a simple connector that + * consumes and counts records. This class provides methods to find task instances + * which are initiated by the embedded connector, and wait for them to consume a desired number of + * messages. + */ +public class MonitorableSinkConnector extends TestSinkConnector { + + private static final Logger log = LoggerFactory.getLogger(MonitorableSinkConnector.class); + + private String connectorName; + private Map commonConfigs; + private ConnectorHandle connectorHandle; + + @Override + public void start(Map props) { + connectorHandle = RuntimeHandles.get().connectorHandle(props.get("name")); + connectorName = props.get("name"); + commonConfigs = props; + log.info("Starting connector {}", props.get("name")); + connectorHandle.recordConnectorStart(); + } + + @Override + public Class taskClass() { + return MonitorableSinkTask.class; + } + + @Override + public List> taskConfigs(int maxTasks) { + List> configs = new ArrayList<>(); + for (int i = 0; i < maxTasks; i++) { + Map config = new HashMap<>(commonConfigs); + config.put("connector.name", connectorName); + config.put("task.id", connectorName + "-" + i); + configs.add(config); + } + return configs; + } + + @Override + public void stop() { + log.info("Stopped {} connector {}", this.getClass().getSimpleName(), connectorName); + connectorHandle.recordConnectorStop(); + } + + @Override + public ConfigDef config() { + return new ConfigDef(); + } + + public static class MonitorableSinkTask extends SinkTask { + + private String connectorName; + private String taskId; + TaskHandle taskHandle; + Map committedOffsets; + Map> cachedTopicPartitions; + + public MonitorableSinkTask() { + this.committedOffsets = new HashMap<>(); + this.cachedTopicPartitions = new HashMap<>(); + } + + @Override + public String version() { + return "unknown"; + } + + @Override + public void start(Map props) { + taskId = props.get("task.id"); + connectorName = props.get("connector.name"); + taskHandle = RuntimeHandles.get().connectorHandle(connectorName).taskHandle(taskId); + log.debug("Starting task {}", taskId); + taskHandle.recordTaskStart(); + } + + @Override + public void open(Collection partitions) { + log.debug("Opening partitions {}", partitions); + taskHandle.partitionsAssigned(partitions); + } + + @Override + public void close(Collection partitions) { + 
log.debug("Closing partitions {}", partitions); + taskHandle.partitionsRevoked(partitions); + partitions.forEach(committedOffsets::remove); + } + + @Override + public void put(Collection records) { + for (SinkRecord rec : records) { + taskHandle.record(rec); + TopicPartition tp = cachedTopicPartitions + .computeIfAbsent(rec.topic(), v -> new HashMap<>()) + .computeIfAbsent(rec.kafkaPartition(), v -> new TopicPartition(rec.topic(), rec.kafkaPartition())); + committedOffsets.put(tp, committedOffsets.getOrDefault(tp, 0) + 1); + log.trace("Task {} obtained record (key='{}' value='{}')", taskId, rec.key(), rec.value()); + } + } + + @Override + public Map preCommit(Map offsets) { + taskHandle.partitionsCommitted(offsets.keySet()); + offsets.forEach((tp, offset) -> { + int recordsSinceLastCommit = committedOffsets.getOrDefault(tp, 0); + if (recordsSinceLastCommit != 0) { + taskHandle.commit(recordsSinceLastCommit); + log.debug("Forwarding to framework request to commit {} records for {}", recordsSinceLastCommit, tp); + committedOffsets.put(tp, 0); + } + }); + return offsets; + } + + @Override + public void stop() { + log.info("Stopped {} task {}", this.getClass().getSimpleName(), taskId); + taskHandle.recordTaskStop(); + } + } +} diff --git a/connect/runtime/src/test/java/org/apache/kafka/connect/integration/MonitorableSourceConnector.java b/connect/runtime/src/test/java/org/apache/kafka/connect/integration/MonitorableSourceConnector.java new file mode 100644 index 0000000..afd9325 --- /dev/null +++ b/connect/runtime/src/test/java/org/apache/kafka/connect/integration/MonitorableSourceConnector.java @@ -0,0 +1,176 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.integration; + +import org.apache.kafka.common.config.ConfigDef; +import org.apache.kafka.clients.producer.RecordMetadata; +import org.apache.kafka.connect.connector.Task; +import org.apache.kafka.connect.data.Schema; +import org.apache.kafka.connect.header.ConnectHeaders; +import org.apache.kafka.connect.runtime.TestSourceConnector; +import org.apache.kafka.connect.source.SourceRecord; +import org.apache.kafka.connect.source.SourceTask; +import org.apache.kafka.tools.ThroughputThrottler; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.stream.Collectors; +import java.util.stream.LongStream; + +/** + * A source connector that is used in Apache Kafka integration tests to verify the behavior of + * the Connect framework, but that can be used in other integration tests as a simple connector + * that generates records of a fixed structure. 
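// Illustrative sketch: the source task below caps its record rate with ThroughputThrottler.
// The general pattern, independent of Connect, is to ask the throttler before each send whether
// the records-per-second budget is already spent and, if so, sleep until it is not. The budget
// below is a hypothetical value mirroring the 'throughput' config.
import org.apache.kafka.tools.ThroughputThrottler;

final class ThrottleLoopSketch {
    public static void main(String[] args) {
        long targetRecordsPerSec = 500;
        long start = System.currentTimeMillis();
        ThroughputThrottler throttler = new ThroughputThrottler(targetRecordsPerSec, start);
        for (long sent = 0; sent < 5_000; sent++) {
            if (throttler.shouldThrottle(sent, System.currentTimeMillis())) {
                throttler.throttle();   // sleeps just long enough to stay on budget
            }
            // emit one record here
        }
    }
}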
The rate of record production can be adjusted + * through the configs 'throughput' and 'messages.per.poll' + */ +public class MonitorableSourceConnector extends TestSourceConnector { + private static final Logger log = LoggerFactory.getLogger(MonitorableSourceConnector.class); + + public static final String TOPIC_CONFIG = "topic"; + private String connectorName; + private ConnectorHandle connectorHandle; + private Map commonConfigs; + + @Override + public void start(Map props) { + connectorHandle = RuntimeHandles.get().connectorHandle(props.get("name")); + connectorName = connectorHandle.name(); + commonConfigs = props; + log.info("Started {} connector {}", this.getClass().getSimpleName(), connectorName); + connectorHandle.recordConnectorStart(); + if (Boolean.parseBoolean(props.getOrDefault("connector.start.inject.error", "false"))) { + throw new RuntimeException("Injecting errors during connector start"); + } + } + + @Override + public Class taskClass() { + return MonitorableSourceTask.class; + } + + @Override + public List> taskConfigs(int maxTasks) { + List> configs = new ArrayList<>(); + for (int i = 0; i < maxTasks; i++) { + Map config = new HashMap<>(commonConfigs); + config.put("connector.name", connectorName); + config.put("task.id", connectorName + "-" + i); + configs.add(config); + } + return configs; + } + + @Override + public void stop() { + log.info("Stopped {} connector {}", this.getClass().getSimpleName(), connectorName); + connectorHandle.recordConnectorStop(); + } + + @Override + public ConfigDef config() { + log.info("Configured {} connector {}", this.getClass().getSimpleName(), connectorName); + return new ConfigDef(); + } + + public static class MonitorableSourceTask extends SourceTask { + private String connectorName; + private String taskId; + private String topicName; + private TaskHandle taskHandle; + private volatile boolean stopped; + private long startingSeqno; + private long seqno; + private long throughput; + private int batchSize; + private ThroughputThrottler throttler; + + @Override + public String version() { + return "unknown"; + } + + @Override + public void start(Map props) { + taskId = props.get("task.id"); + connectorName = props.get("connector.name"); + topicName = props.getOrDefault(TOPIC_CONFIG, "sequential-topic"); + throughput = Long.valueOf(props.getOrDefault("throughput", "-1")); + batchSize = Integer.valueOf(props.getOrDefault("messages.per.poll", "1")); + taskHandle = RuntimeHandles.get().connectorHandle(connectorName).taskHandle(taskId); + Map offset = Optional.ofNullable( + context.offsetStorageReader().offset(Collections.singletonMap("task.id", taskId))) + .orElse(Collections.emptyMap()); + startingSeqno = Optional.ofNullable((Long) offset.get("saved")).orElse(0L); + log.info("Started {} task {} with properties {}", this.getClass().getSimpleName(), taskId, props); + throttler = new ThroughputThrottler(throughput, System.currentTimeMillis()); + taskHandle.recordTaskStart(); + if (Boolean.parseBoolean(props.getOrDefault("task-" + taskId + ".start.inject.error", "false"))) { + throw new RuntimeException("Injecting errors during task start"); + } + } + + @Override + public List poll() { + if (!stopped) { + if (throttler.shouldThrottle(seqno - startingSeqno, System.currentTimeMillis())) { + throttler.throttle(); + } + taskHandle.record(batchSize); + log.info("Returning batch of {} records", batchSize); + return LongStream.range(0, batchSize) + .mapToObj(i -> new SourceRecord( + Collections.singletonMap("task.id", taskId), + 
Collections.singletonMap("saved", ++seqno), + topicName, + null, + Schema.STRING_SCHEMA, + "key-" + taskId + "-" + seqno, + Schema.STRING_SCHEMA, + "value-" + taskId + "-" + seqno, + null, + new ConnectHeaders().addLong("header-" + seqno, seqno))) + .collect(Collectors.toList()); + } + return null; + } + + @Override + public void commit() { + log.info("Task {} committing offsets", taskId); + //TODO: save progress outside the offset topic, potentially in the task handle + } + + @Override + public void commitRecord(SourceRecord record, RecordMetadata metadata) { + log.trace("Committing record: {}", record); + taskHandle.commit(); + } + + @Override + public void stop() { + log.info("Stopped {} task {}", this.getClass().getSimpleName(), taskId); + stopped = true; + taskHandle.recordTaskStop(); + } + } +} diff --git a/connect/runtime/src/test/java/org/apache/kafka/connect/integration/RebalanceSourceConnectorsIntegrationTest.java b/connect/runtime/src/test/java/org/apache/kafka/connect/integration/RebalanceSourceConnectorsIntegrationTest.java new file mode 100644 index 0000000..b56a8fa --- /dev/null +++ b/connect/runtime/src/test/java/org/apache/kafka/connect/integration/RebalanceSourceConnectorsIntegrationTest.java @@ -0,0 +1,382 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
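// Illustrative sketch: the poll() implementation above builds each record with SourceRecord's
// long positional constructor. Written out with its parameters labelled (topic, key, and values
// here are hypothetical), the same call reads:
import java.util.Collections;
import org.apache.kafka.connect.data.Schema;
import org.apache.kafka.connect.header.ConnectHeaders;
import org.apache.kafka.connect.source.SourceRecord;

final class SourceRecordSketch {
    static SourceRecord example(long seqno) {
        return new SourceRecord(
                Collections.singletonMap("task.id", "my-task-0"),  // sourcePartition: identifies this task's stream
                Collections.singletonMap("saved", seqno),          // sourceOffset: resume point stored by the framework
                "sequential-topic",                                // topic to produce to
                null,                                              // partition: null lets the producer/partitioner decide
                Schema.STRING_SCHEMA,                              // key schema
                "key-" + seqno,                                    // key
                Schema.STRING_SCHEMA,                              // value schema
                "value-" + seqno,                                  // value
                null,                                              // timestamp: null defers to the producer/broker
                new ConnectHeaders().addLong("header-" + seqno, seqno));  // optional record headers
    }
}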
+ */ +package org.apache.kafka.connect.integration; + +import org.apache.kafka.connect.runtime.rest.entities.ConnectorStateInfo; +import org.apache.kafka.connect.storage.StringConverter; +import org.apache.kafka.connect.util.clusters.EmbeddedConnectCluster; +import org.apache.kafka.test.IntegrationTest; +import org.junit.After; +import org.junit.Before; +import org.junit.Ignore; +import org.junit.Rule; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.rules.TestRule; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.Properties; +import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +import static org.apache.kafka.connect.integration.MonitorableSourceConnector.TOPIC_CONFIG; +import static org.apache.kafka.connect.runtime.ConnectorConfig.CONNECTOR_CLASS_CONFIG; +import static org.apache.kafka.connect.runtime.ConnectorConfig.KEY_CONVERTER_CLASS_CONFIG; +import static org.apache.kafka.connect.runtime.ConnectorConfig.TASKS_MAX_CONFIG; +import static org.apache.kafka.connect.runtime.ConnectorConfig.VALUE_CONVERTER_CLASS_CONFIG; +import static org.apache.kafka.connect.runtime.TopicCreationConfig.DEFAULT_TOPIC_CREATION_PREFIX; +import static org.apache.kafka.connect.runtime.TopicCreationConfig.PARTITIONS_CONFIG; +import static org.apache.kafka.connect.runtime.TopicCreationConfig.REPLICATION_FACTOR_CONFIG; +import static org.apache.kafka.connect.runtime.WorkerConfig.OFFSET_COMMIT_INTERVAL_MS_CONFIG; +import static org.apache.kafka.connect.runtime.distributed.ConnectProtocolCompatibility.COMPATIBLE; +import static org.apache.kafka.connect.runtime.distributed.DistributedConfig.CONNECT_PROTOCOL_CONFIG; +import static org.apache.kafka.connect.runtime.distributed.DistributedConfig.SCHEDULED_REBALANCE_MAX_DELAY_MS_CONFIG; +import static org.apache.kafka.test.TestUtils.waitForCondition; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotEquals; +import static org.junit.Assert.assertTrue; + +/** + * Integration tests for incremental cooperative rebalancing between Connect workers + */ +@Category(IntegrationTest.class) +public class RebalanceSourceConnectorsIntegrationTest { + + private static final Logger log = LoggerFactory.getLogger(RebalanceSourceConnectorsIntegrationTest.class); + + private static final int NUM_TOPIC_PARTITIONS = 3; + private static final long CONNECTOR_SETUP_DURATION_MS = TimeUnit.SECONDS.toMillis(30); + private static final long WORKER_SETUP_DURATION_MS = TimeUnit.SECONDS.toMillis(60); + private static final int NUM_WORKERS = 3; + private static final int NUM_TASKS = 4; + private static final String CONNECTOR_NAME = "seq-source1"; + private static final String TOPIC_NAME = "sequential-topic"; + + private EmbeddedConnectCluster connect; + + @Rule + public TestRule watcher = ConnectIntegrationTestUtils.newTestWatcher(log); + + @Before + public void setup() { + // setup Connect worker properties + Map workerProps = new HashMap<>(); + workerProps.put(CONNECT_PROTOCOL_CONFIG, COMPATIBLE.toString()); + workerProps.put(OFFSET_COMMIT_INTERVAL_MS_CONFIG, String.valueOf(TimeUnit.SECONDS.toMillis(30))); + workerProps.put(SCHEDULED_REBALANCE_MAX_DELAY_MS_CONFIG, String.valueOf(TimeUnit.SECONDS.toMillis(30))); + + // setup Kafka broker properties + Properties brokerProps = new Properties(); + 
brokerProps.put("auto.create.topics.enable", "false"); + + // build a Connect cluster backed by Kafka and Zk + connect = new EmbeddedConnectCluster.Builder() + .name("connect-cluster") + .numWorkers(NUM_WORKERS) + .numBrokers(1) + .workerProps(workerProps) + .brokerProps(brokerProps) + .build(); + + // start the clusters + connect.start(); + } + + @After + public void close() { + // stop all Connect, Kafka and Zk threads. + connect.stop(); + } + + @Test + public void testStartTwoConnectors() throws Exception { + // create test topic + connect.kafka().createTopic(TOPIC_NAME, NUM_TOPIC_PARTITIONS); + + // setup up props for the source connector + Map props = defaultSourceConnectorProps(TOPIC_NAME); + + connect.assertions().assertAtLeastNumWorkersAreUp(NUM_WORKERS, + "Connect workers did not start in time."); + + // start a source connector + connect.configureConnector(CONNECTOR_NAME, props); + + connect.assertions().assertConnectorAndAtLeastNumTasksAreRunning(CONNECTOR_NAME, NUM_TASKS, + "Connector tasks did not start in time."); + + // start a source connector + connect.configureConnector("another-source", props); + + connect.assertions().assertConnectorAndAtLeastNumTasksAreRunning(CONNECTOR_NAME, NUM_TASKS, + "Connector tasks did not start in time."); + + connect.assertions().assertConnectorAndAtLeastNumTasksAreRunning("another-source", 4, + "Connector tasks did not start in time."); + } + + @Test + public void testReconfigConnector() throws Exception { + ConnectorHandle connectorHandle = RuntimeHandles.get().connectorHandle(CONNECTOR_NAME); + + // create test topic + String anotherTopic = "another-topic"; + connect.kafka().createTopic(TOPIC_NAME, NUM_TOPIC_PARTITIONS); + connect.kafka().createTopic(anotherTopic, NUM_TOPIC_PARTITIONS); + + // setup up props for the source connector + Map props = defaultSourceConnectorProps(TOPIC_NAME); + + connect.assertions().assertAtLeastNumWorkersAreUp(NUM_WORKERS, + "Connect workers did not start in time."); + + // start a source connector + connect.configureConnector(CONNECTOR_NAME, props); + + connect.assertions().assertConnectorAndAtLeastNumTasksAreRunning(CONNECTOR_NAME, NUM_TASKS, + "Connector tasks did not start in time."); + + int numRecordsProduced = 100; + long recordTransferDurationMs = TimeUnit.SECONDS.toMillis(30); + + // consume all records from the source topic or fail, to ensure that they were correctly produced + int recordNum = connect.kafka().consume(numRecordsProduced, recordTransferDurationMs, TOPIC_NAME).count(); + assertTrue("Not enough records produced by source connector. 
Expected at least: " + numRecordsProduced + " + but got " + recordNum, + recordNum >= numRecordsProduced); + + // expect that we're going to restart the connector and its tasks + StartAndStopLatch restartLatch = connectorHandle.expectedStarts(1); + + // Reconfigure the source connector by changing the Kafka topic used as output + props.put(TOPIC_CONFIG, anotherTopic); + connect.configureConnector(CONNECTOR_NAME, props); + + // Wait for the connector *and tasks* to be restarted + assertTrue("Failed to alter connector configuration and see connector and tasks restart " + + "within " + CONNECTOR_SETUP_DURATION_MS + "ms", + restartLatch.await(CONNECTOR_SETUP_DURATION_MS, TimeUnit.MILLISECONDS)); + + // And wait for the Connect to show the connectors and tasks are running + connect.assertions().assertConnectorAndAtLeastNumTasksAreRunning(CONNECTOR_NAME, NUM_TASKS, + "Connector tasks did not start in time."); + + // consume all records from the source topic or fail, to ensure that they were correctly produced + recordNum = connect.kafka().consume(numRecordsProduced, recordTransferDurationMs, anotherTopic).count(); + assertTrue("Not enough records produced by source connector. Expected at least: " + numRecordsProduced + " + but got " + recordNum, + recordNum >= numRecordsProduced); + } + + @Test + public void testDeleteConnector() throws Exception { + // create test topic + connect.kafka().createTopic(TOPIC_NAME, NUM_TOPIC_PARTITIONS); + + // setup up props for the source connector + Map props = defaultSourceConnectorProps(TOPIC_NAME); + + connect.assertions().assertAtLeastNumWorkersAreUp(NUM_WORKERS, + "Connect workers did not start in time."); + + // start several source connectors + IntStream.range(0, 4).forEachOrdered(i -> connect.configureConnector(CONNECTOR_NAME + i, props)); + + connect.assertions().assertConnectorAndAtLeastNumTasksAreRunning(CONNECTOR_NAME + 3, NUM_TASKS, + "Connector tasks did not start in time."); + + // delete connector + connect.deleteConnector(CONNECTOR_NAME + 3); + + connect.assertions().assertConnectorAndTasksAreStopped(CONNECTOR_NAME + 3, + "Connector tasks did not stop in time."); + + waitForCondition(this::assertConnectorAndTasksAreUniqueAndBalanced, + WORKER_SETUP_DURATION_MS, "Connect and tasks are imbalanced between the workers."); + } + + @Test + public void testAddingWorker() throws Exception { + // create test topic + connect.kafka().createTopic(TOPIC_NAME, NUM_TOPIC_PARTITIONS); + + // setup up props for the source connector + Map props = defaultSourceConnectorProps(TOPIC_NAME); + + connect.assertions().assertAtLeastNumWorkersAreUp(NUM_WORKERS, + "Connect workers did not start in time."); + + // start a source connector + IntStream.range(0, 4).forEachOrdered(i -> connect.configureConnector(CONNECTOR_NAME + i, props)); + + connect.assertions().assertConnectorAndAtLeastNumTasksAreRunning(CONNECTOR_NAME + 3, NUM_TASKS, + "Connector tasks did not start in time."); + + connect.addWorker(); + + connect.assertions().assertAtLeastNumWorkersAreUp(NUM_WORKERS + 1, + "Connect workers did not start in time."); + + connect.assertions().assertConnectorAndAtLeastNumTasksAreRunning(CONNECTOR_NAME + 3, NUM_TASKS, + "Connector tasks did not start in time."); + + waitForCondition(this::assertConnectorAndTasksAreUniqueAndBalanced, + WORKER_SETUP_DURATION_MS, "Connect and tasks are imbalanced between the workers."); + } + + @Test + public void testRemovingWorker() throws Exception { + // create test topic + connect.kafka().createTopic(TOPIC_NAME, NUM_TOPIC_PARTITIONS); + + // 
setup up props for the source connector + Map props = defaultSourceConnectorProps(TOPIC_NAME); + + connect.assertions().assertExactlyNumWorkersAreUp(NUM_WORKERS, + "Connect workers did not start in time."); + + // start a source connector + IntStream.range(0, 4).forEachOrdered(i -> connect.configureConnector(CONNECTOR_NAME + i, props)); + + connect.assertions().assertConnectorAndAtLeastNumTasksAreRunning(CONNECTOR_NAME + 3, NUM_TASKS, + "Connector tasks did not start in time."); + + connect.removeWorker(); + + connect.assertions().assertExactlyNumWorkersAreUp(NUM_WORKERS - 1, + "Connect workers did not start in time."); + + waitForCondition(this::assertConnectorAndTasksAreUniqueAndBalanced, + WORKER_SETUP_DURATION_MS, "Connect and tasks are imbalanced between the workers."); + } + + // should enable it after KAFKA-12495 fixed + @Ignore + @Test + public void testMultipleWorkersRejoining() throws Exception { + // create test topic + connect.kafka().createTopic(TOPIC_NAME, NUM_TOPIC_PARTITIONS); + + // setup up props for the source connector + Map props = defaultSourceConnectorProps(TOPIC_NAME); + + connect.assertions().assertExactlyNumWorkersAreUp(NUM_WORKERS, + "Connect workers did not start in time."); + + // start a source connector + IntStream.range(0, 4).forEachOrdered(i -> connect.configureConnector(CONNECTOR_NAME + i, props)); + + connect.assertions().assertConnectorAndAtLeastNumTasksAreRunning(CONNECTOR_NAME + 3, NUM_TASKS, + "Connector tasks did not start in time."); + + waitForCondition(this::assertConnectorAndTasksAreUniqueAndBalanced, + WORKER_SETUP_DURATION_MS, "Connect and tasks are imbalanced between the workers."); + + Thread.sleep(TimeUnit.SECONDS.toMillis(10)); + + connect.removeWorker(); + connect.removeWorker(); + + connect.assertions().assertExactlyNumWorkersAreUp(NUM_WORKERS - 2, + "Connect workers did not stop in time."); + + Thread.sleep(TimeUnit.SECONDS.toMillis(10)); + + connect.addWorker(); + connect.addWorker(); + + connect.assertions().assertExactlyNumWorkersAreUp(NUM_WORKERS, + "Connect workers did not start in time."); + + Thread.sleep(TimeUnit.SECONDS.toMillis(10)); + + for (int i = 0; i < 4; ++i) { + connect.assertions().assertConnectorAndAtLeastNumTasksAreRunning(CONNECTOR_NAME + i, NUM_TASKS, "Connector tasks did not start in time."); + } + + waitForCondition(this::assertConnectorAndTasksAreUniqueAndBalanced, + WORKER_SETUP_DURATION_MS, "Connect and tasks are imbalanced between the workers."); + } + + private Map defaultSourceConnectorProps(String topic) { + // setup up props for the source connector + Map props = new HashMap<>(); + props.put(CONNECTOR_CLASS_CONFIG, MonitorableSourceConnector.class.getSimpleName()); + props.put(TASKS_MAX_CONFIG, String.valueOf(NUM_TASKS)); + props.put(TOPIC_CONFIG, topic); + props.put("throughput", String.valueOf(10)); + props.put("messages.per.poll", String.valueOf(10)); + props.put(KEY_CONVERTER_CLASS_CONFIG, StringConverter.class.getName()); + props.put(VALUE_CONVERTER_CLASS_CONFIG, StringConverter.class.getName()); + props.put(DEFAULT_TOPIC_CREATION_PREFIX + REPLICATION_FACTOR_CONFIG, String.valueOf(1)); + props.put(DEFAULT_TOPIC_CREATION_PREFIX + PARTITIONS_CONFIG, String.valueOf(1)); + return props; + } + + private boolean assertConnectorAndTasksAreUniqueAndBalanced() { + try { + Map> connectors = new HashMap<>(); + Map> tasks = new HashMap<>(); + for (String connector : connect.connectors()) { + ConnectorStateInfo info = connect.connectorStatus(connector); + connectors.computeIfAbsent(info.connector().workerId(), k 
-> new ArrayList<>()) + .add(connector); + info.tasks().forEach( + t -> tasks.computeIfAbsent(t.workerId(), k -> new ArrayList<>()) + .add(connector + "-" + t.id())); + } + + int maxConnectors = connectors.values().stream().mapToInt(Collection::size).max().orElse(0); + int minConnectors = connectors.values().stream().mapToInt(Collection::size).min().orElse(0); + int maxTasks = tasks.values().stream().mapToInt(Collection::size).max().orElse(0); + int minTasks = tasks.values().stream().mapToInt(Collection::size).min().orElse(0); + + log.debug("Connector balance: {}", formatAssignment(connectors)); + log.debug("Task balance: {}", formatAssignment(tasks)); + + assertNotEquals("Found no connectors running!", maxConnectors, 0); + assertNotEquals("Found no tasks running!", maxTasks, 0); + assertEquals("Connector assignments are not unique: " + connectors, + connectors.values().size(), + connectors.values().stream().distinct().collect(Collectors.toList()).size()); + assertEquals("Task assignments are not unique: " + tasks, + tasks.values().size(), + tasks.values().stream().distinct().collect(Collectors.toList()).size()); + assertTrue("Connectors are imbalanced: " + formatAssignment(connectors), maxConnectors - minConnectors < 2); + assertTrue("Tasks are imbalanced: " + formatAssignment(tasks), maxTasks - minTasks < 2); + return true; + } catch (Exception e) { + log.error("Could not check connector state info.", e); + return false; + } + } + + private static String formatAssignment(Map> assignment) { + StringBuilder result = new StringBuilder(); + for (String worker : assignment.keySet().stream().sorted().collect(Collectors.toList())) { + result.append(String.format("\n%s=%s", worker, assignment.getOrDefault(worker, + Collections.emptyList()))); + } + return result.toString(); + } + +} diff --git a/connect/runtime/src/test/java/org/apache/kafka/connect/integration/RestExtensionIntegrationTest.java b/connect/runtime/src/test/java/org/apache/kafka/connect/integration/RestExtensionIntegrationTest.java new file mode 100644 index 0000000..6ec86bd --- /dev/null +++ b/connect/runtime/src/test/java/org/apache/kafka/connect/integration/RestExtensionIntegrationTest.java @@ -0,0 +1,217 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
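// Illustrative sketch: the rebalance tests above opt into incremental cooperative rebalancing
// through two worker settings, shown here with their raw key names and the values used above.
import java.util.HashMap;
import java.util.Map;

final class CooperativeRebalanceWorkerPropsSketch {
    static Map<String, String> workerProps() {
        Map<String, String> workerProps = new HashMap<>();
        // "compatible" lets workers fall back to eager rebalancing if an older worker joins the
        // group; "sessioned" (the default in newer releases) also rebalances cooperatively.
        workerProps.put("connect.protocol", "compatible");
        // How long a departed worker's connectors and tasks stay unassigned before being
        // redistributed to the remaining workers.
        workerProps.put("scheduled.rebalance.max.delay.ms", String.valueOf(30_000));
        return workerProps;
    }
}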
+ */ +package org.apache.kafka.connect.integration; + +import org.apache.kafka.connect.errors.ConnectException; +import org.apache.kafka.connect.health.ConnectClusterState; +import org.apache.kafka.connect.health.ConnectorHealth; +import org.apache.kafka.connect.health.ConnectorState; +import org.apache.kafka.connect.health.ConnectorType; +import org.apache.kafka.connect.health.TaskState; +import org.apache.kafka.connect.rest.ConnectRestExtension; +import org.apache.kafka.connect.rest.ConnectRestExtensionContext; +import org.apache.kafka.connect.util.clusters.EmbeddedConnectCluster; +import org.apache.kafka.connect.util.clusters.WorkerHandle; +import org.apache.kafka.test.IntegrationTest; +import org.junit.After; +import org.junit.Test; +import org.junit.experimental.categories.Category; + +import javax.ws.rs.GET; +import javax.ws.rs.Path; +import javax.ws.rs.core.Response; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.TimeUnit; + +import static javax.ws.rs.core.Response.Status.BAD_REQUEST; +import static org.apache.kafka.connect.runtime.ConnectorConfig.CONNECTOR_CLASS_CONFIG; +import static org.apache.kafka.connect.runtime.ConnectorConfig.NAME_CONFIG; +import static org.apache.kafka.connect.runtime.ConnectorConfig.TASKS_MAX_CONFIG; +import static org.apache.kafka.connect.runtime.SinkConnectorConfig.TOPICS_CONFIG; +import static org.apache.kafka.connect.runtime.WorkerConfig.REST_EXTENSION_CLASSES_CONFIG; +import static org.apache.kafka.test.TestUtils.waitForCondition; +import static org.junit.Assert.assertEquals; + +/** + * A simple integration test to ensure that REST extensions are registered correctly. + */ +@Category(IntegrationTest.class) +public class RestExtensionIntegrationTest { + + private static final long REST_EXTENSION_REGISTRATION_TIMEOUT_MS = TimeUnit.MINUTES.toMillis(1); + private static final long CONNECTOR_HEALTH_AND_CONFIG_TIMEOUT_MS = TimeUnit.MINUTES.toMillis(1); + private static final int NUM_WORKERS = 1; + + private EmbeddedConnectCluster connect; + + @Test + public void testRestExtensionApi() throws InterruptedException { + // setup Connect worker properties + Map workerProps = new HashMap<>(); + workerProps.put(REST_EXTENSION_CLASSES_CONFIG, IntegrationTestRestExtension.class.getName()); + + // build a Connect cluster backed by Kafka and Zk + connect = new EmbeddedConnectCluster.Builder() + .name("connect-cluster") + .numWorkers(NUM_WORKERS) + .numBrokers(1) + .workerProps(workerProps) + .build(); + + // start the clusters + connect.start(); + + connect.assertions().assertAtLeastNumWorkersAreUp(NUM_WORKERS, + "Initial group of workers did not start in time."); + + WorkerHandle worker = connect.workers().stream() + .findFirst() + .orElseThrow(() -> new AssertionError("At least one worker handle should be available")); + + waitForCondition( + this::extensionIsRegistered, + REST_EXTENSION_REGISTRATION_TIMEOUT_MS, + "REST extension was never registered" + ); + + ConnectorHandle connectorHandle = RuntimeHandles.get().connectorHandle("test-conn"); + try { + // setup up props for the connector + Map connectorProps = new HashMap<>(); + connectorProps.put(CONNECTOR_CLASS_CONFIG, MonitorableSinkConnector.class.getSimpleName()); + connectorProps.put(TASKS_MAX_CONFIG, String.valueOf(1)); + connectorProps.put(TOPICS_CONFIG, "test-topic"); + + // start a connector + connectorHandle.taskHandle(connectorHandle.name() + "-0"); + StartAndStopLatch connectorStartLatch = connectorHandle.expectedStarts(1); + 
connect.configureConnector(connectorHandle.name(), connectorProps); + connect.assertions().assertConnectorAndAtLeastNumTasksAreRunning(connectorHandle.name(), 1, + "Connector tasks did not start in time."); + connectorStartLatch.await(CONNECTOR_HEALTH_AND_CONFIG_TIMEOUT_MS, TimeUnit.MILLISECONDS); + + String workerId = String.format("%s:%d", worker.url().getHost(), worker.url().getPort()); + ConnectorHealth expectedHealth = new ConnectorHealth( + connectorHandle.name(), + new ConnectorState( + "RUNNING", + workerId, + null + ), + Collections.singletonMap( + 0, + new TaskState(0, "RUNNING", workerId, null) + ), + ConnectorType.SINK + ); + + connectorProps.put(NAME_CONFIG, connectorHandle.name()); + + // Test the REST extension API; specifically, that the connector's health and configuration + // are available to the REST extension we registered and that they contain expected values + waitForCondition( + () -> verifyConnectorHealthAndConfig(connectorHandle.name(), expectedHealth, connectorProps), + CONNECTOR_HEALTH_AND_CONFIG_TIMEOUT_MS, + "Connector health and/or config was never accessible by the REST extension" + ); + } finally { + RuntimeHandles.get().deleteConnector(connectorHandle.name()); + } + } + + @After + public void close() { + // stop all Connect, Kafka and Zk threads. + connect.stop(); + IntegrationTestRestExtension.instance = null; + } + + private boolean extensionIsRegistered() { + try { + String extensionUrl = connect.endpointForResource("integration-test-rest-extension/registered"); + Response response = connect.requestGet(extensionUrl); + return response.getStatus() < BAD_REQUEST.getStatusCode(); + } catch (ConnectException e) { + return false; + } + } + + private boolean verifyConnectorHealthAndConfig( + String connectorName, + ConnectorHealth expectedHealth, + Map expectedConfig + ) { + ConnectClusterState clusterState = + IntegrationTestRestExtension.instance.restPluginContext.clusterState(); + + ConnectorHealth actualHealth = clusterState.connectorHealth(connectorName); + if (actualHealth.tasksState().isEmpty()) { + // Happens if the task has been started but its status has not yet been picked up from + // the status topic by the worker. 
+ return false; + } + Map actualConfig = clusterState.connectorConfig(connectorName); + + assertEquals(expectedConfig, actualConfig); + assertEquals(expectedHealth, actualHealth); + + return true; + } + + public static class IntegrationTestRestExtension implements ConnectRestExtension { + private static IntegrationTestRestExtension instance; + + public ConnectRestExtensionContext restPluginContext; + + @Override + public void register(ConnectRestExtensionContext restPluginContext) { + instance = this; + this.restPluginContext = restPluginContext; + // Immediately request a list of connectors to confirm that the context and its fields + // has been fully initialized and there is no risk of deadlock + restPluginContext.clusterState().connectors(); + // Install a new REST resource that can be used to confirm that the extension has been + // successfully registered + restPluginContext.configurable().register(new IntegrationTestRestExtensionResource()); + } + + @Override + public void close() { + } + + @Override + public void configure(Map configs) { + } + + @Override + public String version() { + return "test"; + } + + @Path("integration-test-rest-extension") + public static class IntegrationTestRestExtensionResource { + + @GET + @Path("/registered") + public boolean isRegistered() { + return true; + } + } + } +} diff --git a/connect/runtime/src/test/java/org/apache/kafka/connect/integration/RuntimeHandles.java b/connect/runtime/src/test/java/org/apache/kafka/connect/integration/RuntimeHandles.java new file mode 100644 index 0000000..c9900f3 --- /dev/null +++ b/connect/runtime/src/test/java/org/apache/kafka/connect/integration/RuntimeHandles.java @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.integration; + +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +/** + * A singleton class which provides a shared class for {@link ConnectorHandle}s and {@link TaskHandle}s that are + * required for integration tests. + */ +public class RuntimeHandles { + + private static final RuntimeHandles INSTANCE = new RuntimeHandles(); + + private final Map connectorHandles = new ConcurrentHashMap<>(); + + private RuntimeHandles() { + } + + /** + * @return the shared {@link RuntimeHandles} instance. + */ + public static RuntimeHandles get() { + return INSTANCE; + } + + /** + * Get or create a connector handle for a given connector name. The connector need not be running at the time + * this method is called. Once the connector is created, it will bind to this handle. Binding happens with the + * connectorName. 
+ * + * @param connectorName the name of the connector + * @return a non-null {@link ConnectorHandle} + */ + public ConnectorHandle connectorHandle(String connectorName) { + return connectorHandles.computeIfAbsent(connectorName, k -> new ConnectorHandle(connectorName)); + } + + /** + * Delete the connector handle for this connector name. + * + * @param connectorName name of the connector + */ + public void deleteConnector(String connectorName) { + connectorHandles.remove(connectorName); + } + +} diff --git a/connect/runtime/src/test/java/org/apache/kafka/connect/integration/SessionedProtocolIntegrationTest.java b/connect/runtime/src/test/java/org/apache/kafka/connect/integration/SessionedProtocolIntegrationTest.java new file mode 100644 index 0000000..8956a86 --- /dev/null +++ b/connect/runtime/src/test/java/org/apache/kafka/connect/integration/SessionedProtocolIntegrationTest.java @@ -0,0 +1,167 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.integration; + +import org.apache.kafka.connect.runtime.distributed.ConnectProtocolCompatibility; +import org.apache.kafka.connect.storage.StringConverter; +import org.apache.kafka.connect.util.clusters.EmbeddedConnectCluster; +import org.apache.kafka.test.IntegrationTest; +import org.junit.After; +import org.junit.Before; +import org.junit.Ignore; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.TimeUnit; + +import static javax.ws.rs.core.Response.Status.BAD_REQUEST; +import static javax.ws.rs.core.Response.Status.FORBIDDEN; +import static org.apache.kafka.connect.runtime.ConnectorConfig.CONNECTOR_CLASS_CONFIG; +import static org.apache.kafka.connect.runtime.ConnectorConfig.KEY_CONVERTER_CLASS_CONFIG; +import static org.apache.kafka.connect.runtime.ConnectorConfig.TASKS_MAX_CONFIG; +import static org.apache.kafka.connect.runtime.ConnectorConfig.VALUE_CONVERTER_CLASS_CONFIG; +import static org.apache.kafka.connect.runtime.SinkConnectorConfig.TOPICS_CONFIG; +import static org.apache.kafka.connect.runtime.distributed.DistributedConfig.CONNECT_PROTOCOL_CONFIG; +import static org.apache.kafka.connect.runtime.rest.InternalRequestSignature.SIGNATURE_ALGORITHM_HEADER; +import static org.apache.kafka.connect.runtime.rest.InternalRequestSignature.SIGNATURE_HEADER; +import static org.junit.Assert.assertEquals; + +/** + * A simple integration test to ensure that internal request validation becomes enabled with the + * "sessioned" protocol. 
+ */ +@Category(IntegrationTest.class) +public class SessionedProtocolIntegrationTest { + + private static final Logger log = LoggerFactory.getLogger(SessionedProtocolIntegrationTest.class); + + private static final String CONNECTOR_NAME = "connector"; + private static final long CONNECTOR_SETUP_DURATION_MS = 60000; + + private EmbeddedConnectCluster connect; + private ConnectorHandle connectorHandle; + + @Before + public void setup() { + // setup Connect worker properties + Map workerProps = new HashMap<>(); + workerProps.put(CONNECT_PROTOCOL_CONFIG, ConnectProtocolCompatibility.SESSIONED.protocol()); + + // build a Connect cluster backed by Kafka and Zk + connect = new EmbeddedConnectCluster.Builder() + .name("connect-cluster") + .numWorkers(2) + .numBrokers(1) + .workerProps(workerProps) + .build(); + + // start the clusters + connect.start(); + + // get a handle to the connector + connectorHandle = RuntimeHandles.get().connectorHandle(CONNECTOR_NAME); + } + + @After + public void close() { + // stop all Connect, Kafka and Zk threads. + connect.stop(); + } + + @Test + @Ignore + // TODO: This test runs fine locally but fails on Jenkins. Ignoring for now, should revisit when + // possible. + public void ensureInternalEndpointIsSecured() throws Throwable { + final String connectorTasksEndpoint = connect.endpointForResource(String.format( + "connectors/%s/tasks", + CONNECTOR_NAME + )); + final Map emptyHeaders = new HashMap<>(); + final Map invalidSignatureHeaders = new HashMap<>(); + invalidSignatureHeaders.put(SIGNATURE_HEADER, "S2Fma2Flc3F1ZQ=="); + invalidSignatureHeaders.put(SIGNATURE_ALGORITHM_HEADER, "HmacSHA256"); + + // We haven't created the connector yet, but this should still return a 400 instead of a 404 + // if the endpoint is secured + log.info( + "Making a POST request to the {} endpoint with no connector started and no signature header; " + + "expecting 400 error response", + connectorTasksEndpoint + ); + assertEquals( + BAD_REQUEST.getStatusCode(), + connect.requestPost(connectorTasksEndpoint, "[]", emptyHeaders).getStatus() + ); + + // Try again, but with an invalid signature + log.info( + "Making a POST request to the {} endpoint with no connector started and an invalid signature header; " + + "expecting 403 error response", + connectorTasksEndpoint + ); + assertEquals( + FORBIDDEN.getStatusCode(), + connect.requestPost(connectorTasksEndpoint, "[]", invalidSignatureHeaders).getStatus() + ); + + // Create the connector now + // setup up props for the sink connector + Map connectorProps = new HashMap<>(); + connectorProps.put(CONNECTOR_CLASS_CONFIG, MonitorableSinkConnector.class.getSimpleName()); + connectorProps.put(TASKS_MAX_CONFIG, String.valueOf(1)); + connectorProps.put(TOPICS_CONFIG, "test-topic"); + connectorProps.put(KEY_CONVERTER_CLASS_CONFIG, StringConverter.class.getName()); + connectorProps.put(VALUE_CONVERTER_CLASS_CONFIG, StringConverter.class.getName()); + + // start a sink connector + log.info("Starting the {} connector", CONNECTOR_NAME); + StartAndStopLatch startLatch = connectorHandle.expectedStarts(1); + connect.configureConnector(CONNECTOR_NAME, connectorProps); + startLatch.await(CONNECTOR_SETUP_DURATION_MS, TimeUnit.MILLISECONDS); + + + // Verify the exact same behavior, after starting the connector + + // We haven't created the connector yet, but this should still return a 400 instead of a 404 + // if the endpoint is secured + log.info( + "Making a POST request to the {} endpoint with the connector started and no signature header; " + + "expecting 
400 error response", + connectorTasksEndpoint + ); + assertEquals( + BAD_REQUEST.getStatusCode(), + connect.requestPost(connectorTasksEndpoint, "[]", emptyHeaders).getStatus() + ); + + // Try again, but with an invalid signature + log.info( + "Making a POST request to the {} endpoint with the connector started and an invalid signature header; " + + "expecting 403 error response", + connectorTasksEndpoint + ); + assertEquals( + FORBIDDEN.getStatusCode(), + connect.requestPost(connectorTasksEndpoint, "[]", invalidSignatureHeaders).getStatus() + ); + } +} diff --git a/connect/runtime/src/test/java/org/apache/kafka/connect/integration/SinkConnectorsIntegrationTest.java b/connect/runtime/src/test/java/org/apache/kafka/connect/integration/SinkConnectorsIntegrationTest.java new file mode 100644 index 0000000..a8bfbb2 --- /dev/null +++ b/connect/runtime/src/test/java/org/apache/kafka/connect/integration/SinkConnectorsIntegrationTest.java @@ -0,0 +1,321 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.integration; + +import org.apache.kafka.clients.consumer.CooperativeStickyAssignor; +import org.apache.kafka.clients.consumer.RoundRobinAssignor; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.connect.sink.SinkRecord; +import org.apache.kafka.connect.storage.StringConverter; +import org.apache.kafka.connect.util.clusters.EmbeddedConnectCluster; +import org.apache.kafka.test.IntegrationTest; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.junit.experimental.categories.Category; + +import java.util.Arrays; +import java.util.Collection; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; +import java.util.Objects; +import java.util.Properties; +import java.util.Set; +import java.util.function.Consumer; + +import static org.apache.kafka.clients.consumer.ConsumerConfig.DEFAULT_API_TIMEOUT_MS_CONFIG; +import static org.apache.kafka.clients.consumer.ConsumerConfig.PARTITION_ASSIGNMENT_STRATEGY_CONFIG; +import static org.apache.kafka.connect.runtime.ConnectorConfig.CONNECTOR_CLASS_CONFIG; +import static org.apache.kafka.connect.runtime.ConnectorConfig.CONNECTOR_CLIENT_CONSUMER_OVERRIDES_PREFIX; +import static org.apache.kafka.connect.runtime.ConnectorConfig.TASKS_MAX_CONFIG; +import static org.apache.kafka.connect.runtime.SinkConnectorConfig.TOPICS_CONFIG; +import static org.apache.kafka.connect.runtime.WorkerConfig.CONNECTOR_CLIENT_POLICY_CLASS_CONFIG; +import static org.apache.kafka.connect.runtime.WorkerConfig.KEY_CONVERTER_CLASS_CONFIG; +import static org.apache.kafka.connect.runtime.WorkerConfig.VALUE_CONVERTER_CLASS_CONFIG; +import static org.apache.kafka.test.TestUtils.waitForCondition; +import static 
org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +/** + * Integration test for sink connectors + */ +@Category(IntegrationTest.class) +public class SinkConnectorsIntegrationTest { + + private static final int NUM_TASKS = 1; + private static final int NUM_WORKERS = 1; + private static final String CONNECTOR_NAME = "connect-integration-test-sink"; + private static final long TASK_CONSUME_TIMEOUT_MS = 10_000L; + + private EmbeddedConnectCluster connect; + + @Before + public void setup() throws Exception { + Map workerProps = new HashMap<>(); + // permit all Kafka client overrides; required for testing different consumer partition assignment strategies + workerProps.put(CONNECTOR_CLIENT_POLICY_CLASS_CONFIG, "All"); + + // setup Kafka broker properties + Properties brokerProps = new Properties(); + brokerProps.put("auto.create.topics.enable", "false"); + brokerProps.put("delete.topic.enable", "true"); + + // build a Connect cluster backed by Kafka and Zk + connect = new EmbeddedConnectCluster.Builder() + .name("connect-cluster") + .numWorkers(NUM_WORKERS) + .workerProps(workerProps) + .brokerProps(brokerProps) + .build(); + connect.start(); + connect.assertions().assertAtLeastNumWorkersAreUp(NUM_WORKERS, "Initial group of workers did not start in time."); + } + + @After + public void close() { + // delete connector handle + RuntimeHandles.get().deleteConnector(CONNECTOR_NAME); + + // stop all Connect, Kafka and Zk threads. + connect.stop(); + } + + @Test + public void testEagerConsumerPartitionAssignment() throws Exception { + final String topic1 = "topic1", topic2 = "topic2", topic3 = "topic3"; + final TopicPartition tp1 = new TopicPartition(topic1, 0), tp2 = new TopicPartition(topic2, 0), tp3 = new TopicPartition(topic3, 0); + final Collection topics = Arrays.asList(topic1, topic2, topic3); + + Map connectorProps = baseSinkConnectorProps(String.join(",", topics)); + // Need an eager assignor here; round robin is as good as any + connectorProps.put( + CONNECTOR_CLIENT_CONSUMER_OVERRIDES_PREFIX + PARTITION_ASSIGNMENT_STRATEGY_CONFIG, + RoundRobinAssignor.class.getName()); + // After deleting a topic, offset commits will fail for it; reduce the timeout here so that the test doesn't take forever to proceed past that point + connectorProps.put( + CONNECTOR_CLIENT_CONSUMER_OVERRIDES_PREFIX + DEFAULT_API_TIMEOUT_MS_CONFIG, + "5000"); + + final Set consumedRecordValues = new HashSet<>(); + Consumer onPut = record -> assertTrue("Task received duplicate record from Connect", consumedRecordValues.add(Objects.toString(record.value()))); + ConnectorHandle connector = RuntimeHandles.get().connectorHandle(CONNECTOR_NAME); + TaskHandle task = connector.taskHandle(CONNECTOR_NAME + "-0", onPut); + + connect.configureConnector(CONNECTOR_NAME, connectorProps); + connect.assertions().assertConnectorAndAtLeastNumTasksAreRunning(CONNECTOR_NAME, NUM_TASKS, "Connector tasks did not start in time."); + + // None of the topics has been created yet; the task shouldn't be assigned any partitions + assertEquals(0, task.numPartitionsAssigned()); + + Set expectedRecordValues = new HashSet<>(); + Set expectedAssignment = new HashSet<>(); + + connect.kafka().createTopic(topic1, 1); + expectedAssignment.add(tp1); + connect.kafka().produce(topic1, "t1v1"); + expectedRecordValues.add("t1v1"); + + waitForCondition( + () -> expectedRecordValues.equals(consumedRecordValues), + TASK_CONSUME_TIMEOUT_MS, + "Task did not receive records in time"); + assertEquals(1, task.timesAssigned(tp1)); + assertEquals(0, 
task.timesRevoked(tp1)); + assertEquals(expectedAssignment, task.assignment()); + + connect.kafka().createTopic(topic2, 1); + expectedAssignment.add(tp2); + connect.kafka().produce(topic2, "t2v1"); + expectedRecordValues.add("t2v1"); + connect.kafka().produce(topic2, "t1v2"); + expectedRecordValues.add("t1v2"); + + waitForCondition( + () -> expectedRecordValues.equals(consumedRecordValues), + TASK_CONSUME_TIMEOUT_MS, + "Task did not receive records in time"); + assertEquals(2, task.timesAssigned(tp1)); + assertEquals(1, task.timesRevoked(tp1)); + assertEquals(1, task.timesCommitted(tp1)); + assertEquals(1, task.timesAssigned(tp2)); + assertEquals(0, task.timesRevoked(tp2)); + assertEquals(expectedAssignment, task.assignment()); + + connect.kafka().createTopic(topic3, 1); + expectedAssignment.add(tp3); + connect.kafka().produce(topic3, "t3v1"); + expectedRecordValues.add("t3v1"); + connect.kafka().produce(topic2, "t2v2"); + expectedRecordValues.add("t2v2"); + connect.kafka().produce(topic2, "t1v3"); + expectedRecordValues.add("t1v3"); + + expectedAssignment.add(tp3); + waitForCondition( + () -> expectedRecordValues.equals(consumedRecordValues), + TASK_CONSUME_TIMEOUT_MS, + "Task did not receive records in time"); + assertEquals(3, task.timesAssigned(tp1)); + assertEquals(2, task.timesRevoked(tp1)); + assertEquals(2, task.timesCommitted(tp1)); + assertEquals(2, task.timesAssigned(tp2)); + assertEquals(1, task.timesRevoked(tp2)); + assertEquals(1, task.timesCommitted(tp2)); + assertEquals(1, task.timesAssigned(tp3)); + assertEquals(0, task.timesRevoked(tp3)); + assertEquals(expectedAssignment, task.assignment()); + + connect.kafka().deleteTopic(topic1); + expectedAssignment.remove(tp1); + connect.kafka().produce(topic3, "t3v2"); + expectedRecordValues.add("t3v2"); + connect.kafka().produce(topic2, "t2v3"); + expectedRecordValues.add("t2v3"); + + waitForCondition( + () -> expectedRecordValues.equals(consumedRecordValues) && expectedAssignment.equals(task.assignment()), + TASK_CONSUME_TIMEOUT_MS, + "Timed out while waiting for task to receive records and updated topic partition assignment"); + assertEquals(3, task.timesAssigned(tp1)); + assertEquals(3, task.timesRevoked(tp1)); + assertEquals(3, task.timesCommitted(tp1)); + assertEquals(3, task.timesAssigned(tp2)); + assertEquals(2, task.timesRevoked(tp2)); + assertEquals(2, task.timesCommitted(tp2)); + assertEquals(2, task.timesAssigned(tp3)); + assertEquals(1, task.timesRevoked(tp3)); + assertEquals(1, task.timesCommitted(tp3)); + } + + @Test + public void testCooperativeConsumerPartitionAssignment() throws Exception { + final String topic1 = "topic1", topic2 = "topic2", topic3 = "topic3"; + final TopicPartition tp1 = new TopicPartition(topic1, 0), tp2 = new TopicPartition(topic2, 0), tp3 = new TopicPartition(topic3, 0); + final Collection topics = Arrays.asList(topic1, topic2, topic3); + + Map connectorProps = baseSinkConnectorProps(String.join(",", topics)); + // Need an eager assignor here; round robin is as good as any + connectorProps.put( + CONNECTOR_CLIENT_CONSUMER_OVERRIDES_PREFIX + PARTITION_ASSIGNMENT_STRATEGY_CONFIG, + CooperativeStickyAssignor.class.getName()); + // After deleting a topic, offset commits will fail for it; reduce the timeout here so that the test doesn't take forever to proceed past that point + connectorProps.put( + CONNECTOR_CLIENT_CONSUMER_OVERRIDES_PREFIX + DEFAULT_API_TIMEOUT_MS_CONFIG, + "5000"); + + final Set consumedRecordValues = new HashSet<>(); + Consumer onPut = record -> assertTrue("Task received 
duplicate record from Connect", consumedRecordValues.add(Objects.toString(record.value()))); + ConnectorHandle connector = RuntimeHandles.get().connectorHandle(CONNECTOR_NAME); + TaskHandle task = connector.taskHandle(CONNECTOR_NAME + "-0", onPut); + + connect.configureConnector(CONNECTOR_NAME, connectorProps); + connect.assertions().assertConnectorAndAtLeastNumTasksAreRunning(CONNECTOR_NAME, NUM_TASKS, "Connector tasks did not start in time."); + + // None of the topics has been created yet; the task shouldn't be assigned any partitions + assertEquals(0, task.numPartitionsAssigned()); + + Set expectedRecordValues = new HashSet<>(); + Set expectedAssignment = new HashSet<>(); + + connect.kafka().createTopic(topic1, 1); + expectedAssignment.add(tp1); + connect.kafka().produce(topic1, "t1v1"); + expectedRecordValues.add("t1v1"); + + waitForCondition( + () -> expectedRecordValues.equals(consumedRecordValues), + TASK_CONSUME_TIMEOUT_MS, + "Task did not receive records in time"); + assertEquals(1, task.timesAssigned(tp1)); + assertEquals(0, task.timesRevoked(tp1)); + assertEquals(expectedAssignment, task.assignment()); + + connect.kafka().createTopic(topic2, 1); + expectedAssignment.add(tp2); + connect.kafka().produce(topic2, "t2v1"); + expectedRecordValues.add("t2v1"); + connect.kafka().produce(topic2, "t1v2"); + expectedRecordValues.add("t1v2"); + + waitForCondition( + () -> expectedRecordValues.equals(consumedRecordValues), + TASK_CONSUME_TIMEOUT_MS, + "Task did not receive records in time"); + assertEquals(1, task.timesAssigned(tp1)); + assertEquals(0, task.timesRevoked(tp1)); + assertEquals(0, task.timesCommitted(tp1)); + assertEquals(1, task.timesAssigned(tp2)); + assertEquals(0, task.timesRevoked(tp2)); + assertEquals(expectedAssignment, task.assignment()); + + connect.kafka().createTopic(topic3, 1); + expectedAssignment.add(tp3); + connect.kafka().produce(topic3, "t3v1"); + expectedRecordValues.add("t3v1"); + connect.kafka().produce(topic2, "t2v2"); + expectedRecordValues.add("t2v2"); + connect.kafka().produce(topic2, "t1v3"); + expectedRecordValues.add("t1v3"); + + expectedAssignment.add(tp3); + waitForCondition( + () -> expectedRecordValues.equals(consumedRecordValues), + TASK_CONSUME_TIMEOUT_MS, + "Task did not receive records in time"); + assertEquals(1, task.timesAssigned(tp1)); + assertEquals(0, task.timesRevoked(tp1)); + assertEquals(0, task.timesCommitted(tp1)); + assertEquals(1, task.timesAssigned(tp2)); + assertEquals(0, task.timesRevoked(tp2)); + assertEquals(0, task.timesCommitted(tp2)); + assertEquals(1, task.timesAssigned(tp3)); + assertEquals(0, task.timesRevoked(tp3)); + assertEquals(expectedAssignment, task.assignment()); + + connect.kafka().deleteTopic(topic1); + expectedAssignment.remove(tp1); + connect.kafka().produce(topic3, "t3v2"); + expectedRecordValues.add("t3v2"); + connect.kafka().produce(topic2, "t2v3"); + expectedRecordValues.add("t2v3"); + + waitForCondition( + () -> expectedRecordValues.equals(consumedRecordValues) && expectedAssignment.equals(task.assignment()), + TASK_CONSUME_TIMEOUT_MS, + "Timed out while waiting for task to receive records and updated topic partition assignment"); + assertEquals(1, task.timesAssigned(tp1)); + assertEquals(1, task.timesRevoked(tp1)); + assertEquals(1, task.timesCommitted(tp1)); + assertEquals(1, task.timesAssigned(tp2)); + assertEquals(0, task.timesRevoked(tp2)); + assertEquals(0, task.timesCommitted(tp2)); + assertEquals(1, task.timesAssigned(tp3)); + assertEquals(0, task.timesRevoked(tp3)); + assertEquals(0, 
task.timesCommitted(tp3)); + } + + private Map baseSinkConnectorProps(String topics) { + Map props = new HashMap<>(); + props.put(CONNECTOR_CLASS_CONFIG, MonitorableSinkConnector.class.getSimpleName()); + props.put(TASKS_MAX_CONFIG, String.valueOf(NUM_TASKS)); + props.put(TOPICS_CONFIG, topics); + props.put(KEY_CONVERTER_CLASS_CONFIG, StringConverter.class.getName()); + props.put(VALUE_CONVERTER_CLASS_CONFIG, StringConverter.class.getName()); + return props; + } +} diff --git a/connect/runtime/src/test/java/org/apache/kafka/connect/integration/SourceConnectorsIntegrationTest.java b/connect/runtime/src/test/java/org/apache/kafka/connect/integration/SourceConnectorsIntegrationTest.java new file mode 100644 index 0000000..b35b072 --- /dev/null +++ b/connect/runtime/src/test/java/org/apache/kafka/connect/integration/SourceConnectorsIntegrationTest.java @@ -0,0 +1,239 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.integration; + +import org.apache.kafka.connect.runtime.SourceConnectorConfig; +import org.apache.kafka.connect.storage.StringConverter; +import org.apache.kafka.connect.util.clusters.EmbeddedConnectCluster; +import org.apache.kafka.test.IntegrationTest; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.junit.experimental.categories.Category; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.Properties; +import java.util.stream.IntStream; + +import static org.apache.kafka.connect.integration.MonitorableSourceConnector.TOPIC_CONFIG; +import static org.apache.kafka.connect.runtime.ConnectorConfig.CONNECTOR_CLASS_CONFIG; +import static org.apache.kafka.connect.runtime.ConnectorConfig.KEY_CONVERTER_CLASS_CONFIG; +import static org.apache.kafka.connect.runtime.ConnectorConfig.NAME_CONFIG; +import static org.apache.kafka.connect.runtime.ConnectorConfig.TASKS_MAX_CONFIG; +import static org.apache.kafka.connect.runtime.ConnectorConfig.VALUE_CONVERTER_CLASS_CONFIG; +import static org.apache.kafka.connect.runtime.SourceConnectorConfig.TOPIC_CREATION_GROUPS_CONFIG; +import static org.apache.kafka.connect.runtime.TopicCreationConfig.DEFAULT_TOPIC_CREATION_PREFIX; +import static org.apache.kafka.connect.runtime.TopicCreationConfig.EXCLUDE_REGEX_CONFIG; +import static org.apache.kafka.connect.runtime.TopicCreationConfig.INCLUDE_REGEX_CONFIG; +import static org.apache.kafka.connect.runtime.TopicCreationConfig.PARTITIONS_CONFIG; +import static org.apache.kafka.connect.runtime.TopicCreationConfig.REPLICATION_FACTOR_CONFIG; +import static org.apache.kafka.connect.runtime.WorkerConfig.CONNECTOR_CLIENT_POLICY_CLASS_CONFIG; +import static org.apache.kafka.connect.runtime.WorkerConfig.TOPIC_CREATION_ENABLE_CONFIG; +import static 
org.apache.kafka.connect.util.clusters.EmbeddedConnectCluster.DEFAULT_NUM_BROKERS; + +/** + * Integration test for source connectors with a focus on topic creation with custom properties by + * the connector tasks. + */ +@Category(IntegrationTest.class) +public class SourceConnectorsIntegrationTest { + + private static final int NUM_WORKERS = 3; + private static final int NUM_TASKS = 1; + private static final String FOO_TOPIC = "foo-topic"; + private static final String FOO_CONNECTOR = "foo-source"; + private static final String BAR_TOPIC = "bar-topic"; + private static final String BAR_CONNECTOR = "bar-source"; + private static final String FOO_GROUP = "foo"; + private static final String BAR_GROUP = "bar"; + private static final int DEFAULT_REPLICATION_FACTOR = DEFAULT_NUM_BROKERS; + private static final int DEFAULT_PARTITIONS = 1; + private static final int FOO_GROUP_REPLICATION_FACTOR = DEFAULT_NUM_BROKERS; + private static final int FOO_GROUP_PARTITIONS = 9; + + private EmbeddedConnectCluster.Builder connectBuilder; + private EmbeddedConnectCluster connect; + Map workerProps = new HashMap<>(); + Properties brokerProps = new Properties(); + + @Before + public void setup() { + // setup Connect worker properties + workerProps.put(CONNECTOR_CLIENT_POLICY_CLASS_CONFIG, "All"); + + // setup Kafka broker properties + brokerProps.put("auto.create.topics.enable", String.valueOf(false)); + + // build a Connect cluster backed by Kafka and Zk + connectBuilder = new EmbeddedConnectCluster.Builder() + .name("connect-cluster") + .numWorkers(NUM_WORKERS) + .workerProps(workerProps) + .brokerProps(brokerProps) + .maskExitProcedures(true); // true is the default, setting here as example + } + + @After + public void close() { + // stop all Connect, Kafka and Zk threads. 
+ connect.stop(); + } + + @Test + public void testTopicsAreCreatedWhenAutoCreateTopicsIsEnabledAtTheBroker() throws InterruptedException { + brokerProps.put("auto.create.topics.enable", String.valueOf(true)); + workerProps.put(TOPIC_CREATION_ENABLE_CONFIG, String.valueOf(false)); + connect = connectBuilder.brokerProps(brokerProps).workerProps(workerProps).build(); + // start the clusters + connect.start(); + + connect.assertions().assertAtLeastNumWorkersAreUp(NUM_WORKERS, "Initial group of workers did not start in time."); + + Map fooProps = sourceConnectorPropsWithGroups(FOO_TOPIC); + + // start a source connector + connect.configureConnector(FOO_CONNECTOR, fooProps); + fooProps.put(NAME_CONFIG, FOO_CONNECTOR); + + connect.assertions().assertExactlyNumErrorsOnConnectorConfigValidation(fooProps.get(CONNECTOR_CLASS_CONFIG), fooProps, 0, + "Validating connector configuration produced an unexpected number or errors."); + + connect.assertions().assertConnectorAndAtLeastNumTasksAreRunning(FOO_CONNECTOR, NUM_TASKS, + "Connector tasks did not start in time."); + + connect.assertions().assertTopicsExist(FOO_TOPIC); + connect.assertions().assertTopicSettings(FOO_TOPIC, DEFAULT_REPLICATION_FACTOR, + DEFAULT_PARTITIONS, "Topic " + FOO_TOPIC + " does not have the expected settings"); + } + + @Test + public void testTopicsAreCreatedWhenTopicCreationIsEnabled() throws InterruptedException { + connect = connectBuilder.build(); + // start the clusters + connect.start(); + + connect.assertions().assertAtLeastNumWorkersAreUp(NUM_WORKERS, "Initial group of workers did not start in time."); + + Map fooProps = sourceConnectorPropsWithGroups(FOO_TOPIC); + + // start a source connector + connect.configureConnector(FOO_CONNECTOR, fooProps); + fooProps.put(NAME_CONFIG, FOO_CONNECTOR); + + connect.assertions().assertExactlyNumErrorsOnConnectorConfigValidation(fooProps.get(CONNECTOR_CLASS_CONFIG), fooProps, 0, + "Validating connector configuration produced an unexpected number or errors."); + + connect.assertions().assertConnectorAndAtLeastNumTasksAreRunning(FOO_CONNECTOR, NUM_TASKS, + "Connector tasks did not start in time."); + + connect.assertions().assertTopicsExist(FOO_TOPIC); + connect.assertions().assertTopicSettings(FOO_TOPIC, FOO_GROUP_REPLICATION_FACTOR, + FOO_GROUP_PARTITIONS, "Topic " + FOO_TOPIC + " does not have the expected settings"); + } + + @Test + public void testSwitchingToTopicCreationEnabled() throws InterruptedException { + workerProps.put(TOPIC_CREATION_ENABLE_CONFIG, String.valueOf(false)); + connect = connectBuilder.build(); + // start the clusters + connect.start(); + + connect.kafka().createTopic(BAR_TOPIC, DEFAULT_PARTITIONS, DEFAULT_REPLICATION_FACTOR, Collections.emptyMap()); + + connect.assertions().assertTopicsExist(BAR_TOPIC); + connect.assertions().assertTopicSettings(BAR_TOPIC, DEFAULT_REPLICATION_FACTOR, + DEFAULT_PARTITIONS, "Topic " + BAR_TOPIC + " does not have the expected settings"); + + connect.assertions().assertAtLeastNumWorkersAreUp(NUM_WORKERS, "Initial group of workers did not start in time."); + + Map barProps = defaultSourceConnectorProps(BAR_TOPIC); + // start a source connector with topic creation properties + connect.configureConnector(BAR_CONNECTOR, barProps); + barProps.put(NAME_CONFIG, BAR_CONNECTOR); + + Map fooProps = sourceConnectorPropsWithGroups(FOO_TOPIC); + // start a source connector without topic creation properties + connect.configureConnector(FOO_CONNECTOR, fooProps); + fooProps.put(NAME_CONFIG, FOO_CONNECTOR); + + 
connect.assertions().assertExactlyNumErrorsOnConnectorConfigValidation(fooProps.get(CONNECTOR_CLASS_CONFIG), fooProps, 0, + "Validating connector configuration produced an unexpected number or errors."); + + connect.assertions().assertExactlyNumErrorsOnConnectorConfigValidation(barProps.get(CONNECTOR_CLASS_CONFIG), barProps, 0, + "Validating connector configuration produced an unexpected number or errors."); + + connect.assertions().assertConnectorAndAtLeastNumTasksAreRunning(FOO_CONNECTOR, NUM_TASKS, + "Connector tasks did not start in time."); + + connect.assertions().assertConnectorAndAtLeastNumTasksAreRunning(BAR_CONNECTOR, NUM_TASKS, + "Connector tasks did not start in time."); + + connect.assertions().assertTopicsExist(BAR_TOPIC); + connect.assertions().assertTopicSettings(BAR_TOPIC, DEFAULT_REPLICATION_FACTOR, + DEFAULT_PARTITIONS, "Topic " + BAR_TOPIC + " does not have the expected settings"); + + connect.assertions().assertTopicsDoNotExist(FOO_TOPIC); + + connect.activeWorkers().forEach(w -> connect.removeWorker(w)); + + workerProps.put(TOPIC_CREATION_ENABLE_CONFIG, String.valueOf(true)); + + IntStream.range(0, 3).forEach(i -> connect.addWorker()); + connect.assertions().assertAtLeastNumWorkersAreUp(NUM_WORKERS, "Initial group of workers did not start in time."); + + connect.assertions().assertConnectorAndAtLeastNumTasksAreRunning(FOO_CONNECTOR, NUM_TASKS, + "Connector tasks did not start in time."); + + connect.assertions().assertConnectorAndAtLeastNumTasksAreRunning(BAR_CONNECTOR, NUM_TASKS, + "Connector tasks did not start in time."); + + connect.assertions().assertTopicsExist(FOO_TOPIC); + connect.assertions().assertTopicSettings(FOO_TOPIC, FOO_GROUP_REPLICATION_FACTOR, + FOO_GROUP_PARTITIONS, "Topic " + FOO_TOPIC + " does not have the expected settings"); + connect.assertions().assertTopicsExist(BAR_TOPIC); + connect.assertions().assertTopicSettings(BAR_TOPIC, DEFAULT_REPLICATION_FACTOR, + DEFAULT_PARTITIONS, "Topic " + BAR_TOPIC + " does not have the expected settings"); + } + + private Map defaultSourceConnectorProps(String topic) { + // setup up props for the source connector + Map props = new HashMap<>(); + props.put(CONNECTOR_CLASS_CONFIG, MonitorableSourceConnector.class.getSimpleName()); + props.put(TASKS_MAX_CONFIG, String.valueOf(NUM_TASKS)); + props.put(TOPIC_CONFIG, topic); + props.put("throughput", String.valueOf(10)); + props.put("messages.per.poll", String.valueOf(10)); + props.put(KEY_CONVERTER_CLASS_CONFIG, StringConverter.class.getName()); + props.put(VALUE_CONVERTER_CLASS_CONFIG, StringConverter.class.getName()); + return props; + } + + private Map sourceConnectorPropsWithGroups(String topic) { + // setup up props for the source connector + Map props = defaultSourceConnectorProps(topic); + props.put(TOPIC_CREATION_GROUPS_CONFIG, String.join(",", FOO_GROUP, BAR_GROUP)); + props.put(DEFAULT_TOPIC_CREATION_PREFIX + REPLICATION_FACTOR_CONFIG, String.valueOf(DEFAULT_REPLICATION_FACTOR)); + props.put(DEFAULT_TOPIC_CREATION_PREFIX + PARTITIONS_CONFIG, String.valueOf(DEFAULT_PARTITIONS)); + props.put(SourceConnectorConfig.TOPIC_CREATION_PREFIX + FOO_GROUP + "." + INCLUDE_REGEX_CONFIG, FOO_TOPIC); + props.put(SourceConnectorConfig.TOPIC_CREATION_PREFIX + FOO_GROUP + "." + EXCLUDE_REGEX_CONFIG, BAR_TOPIC); + props.put(SourceConnectorConfig.TOPIC_CREATION_PREFIX + FOO_GROUP + "." 
+ PARTITIONS_CONFIG, + String.valueOf(FOO_GROUP_PARTITIONS)); + return props; + } +} diff --git a/connect/runtime/src/test/java/org/apache/kafka/connect/integration/StartAndStopCounter.java b/connect/runtime/src/test/java/org/apache/kafka/connect/integration/StartAndStopCounter.java new file mode 100644 index 0000000..9ed5b06 --- /dev/null +++ b/connect/runtime/src/test/java/org/apache/kafka/connect/integration/StartAndStopCounter.java @@ -0,0 +1,183 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.connect.integration; + +import java.util.List; +import java.util.concurrent.CopyOnWriteArrayList; +import java.util.concurrent.atomic.AtomicInteger; + +import org.apache.kafka.common.utils.Time; + +public class StartAndStopCounter { + + private final AtomicInteger startCounter = new AtomicInteger(0); + private final AtomicInteger stopCounter = new AtomicInteger(0); + private final List restartLatches = new CopyOnWriteArrayList<>(); + private final Time clock; + + public StartAndStopCounter() { + this(Time.SYSTEM); + } + + public StartAndStopCounter(Time clock) { + this.clock = clock != null ? clock : Time.SYSTEM; + } + + /** + * Record a start. + */ + public void recordStart() { + startCounter.incrementAndGet(); + restartLatches.forEach(StartAndStopLatch::recordStart); + } + + /** + * Record a stop. + */ + public void recordStop() { + stopCounter.incrementAndGet(); + restartLatches.forEach(StartAndStopLatch::recordStop); + } + + /** + * Get the number of starts. + * + * @return the number of starts + */ + public int starts() { + return startCounter.get(); + } + + /** + * Get the number of stops. + * + * @return the number of stops + */ + public int stops() { + return stopCounter.get(); + } + + public StartsAndStops countsSnapshot() { + return new StartsAndStops(starts(), stops()); + } + + /** + * Obtain a {@link StartAndStopLatch} that can be used to wait until the expected number of restarts + * has been completed. + * + * @param expectedStarts the expected number of starts; may be 0 + * @param expectedStops the expected number of stops; may be 0 + * @return the latch; never null + */ + public StartAndStopLatch expectedRestarts(int expectedStarts, int expectedStops) { + return expectedRestarts(expectedStarts, expectedStops, null); + } + + /** + * Obtain a {@link StartAndStopLatch} that can be used to wait until the expected number of restarts + * has been completed. 
+ * + * @param expectedStarts the expected number of starts; may be 0 + * @param expectedStops the expected number of stops; may be 0 + * @param dependents any dependent latches that must also complete in order for the + * resulting latch to complete + * @return the latch; never null + */ + public StartAndStopLatch expectedRestarts(int expectedStarts, int expectedStops, List dependents) { + StartAndStopLatch latch = new StartAndStopLatch(expectedStarts, expectedStops, this::remove, dependents, clock); + restartLatches.add(latch); + return latch; + } + + /** + * Obtain a {@link StartAndStopLatch} that can be used to wait until the expected number of restarts + * has been completed. + * + * @param expectedRestarts the expected number of restarts + * @return the latch; never null + */ + public StartAndStopLatch expectedRestarts(int expectedRestarts) { + return expectedRestarts(expectedRestarts, expectedRestarts); + } + + /** + * Obtain a {@link StartAndStopLatch} that can be used to wait until the expected number of restarts + * has been completed. + * + * @param expectedRestarts the expected number of restarts + * @param dependents any dependent latches that must also complete in order for the + * resulting latch to complete + * @return the latch; never null + */ + public StartAndStopLatch expectedRestarts(int expectedRestarts, List dependents) { + return expectedRestarts(expectedRestarts, expectedRestarts, dependents); + } + + /** + * Obtain a {@link StartAndStopLatch} that can be used to wait until the expected number of starts + * has been completed. + * + * @param expectedStarts the expected number of starts + * @return the latch; never null + */ + public StartAndStopLatch expectedStarts(int expectedStarts) { + return expectedRestarts(expectedStarts, 0); + } + + /** + * Obtain a {@link StartAndStopLatch} that can be used to wait until the expected number of starts + * has been completed. + * + * @param expectedStarts the expected number of starts + * @param dependents any dependent latches that must also complete in order for the + * resulting latch to complete + * @return the latch; never null + */ + public StartAndStopLatch expectedStarts(int expectedStarts, List dependents) { + return expectedRestarts(expectedStarts, 0, dependents); + } + + + /** + * Obtain a {@link StartAndStopLatch} that can be used to wait until the expected number of + * stops has been completed. + * + * @param expectedStops the expected number of stops + * @return the latch; never null + */ + public StartAndStopLatch expectedStops(int expectedStops) { + return expectedRestarts(0, expectedStops); + } + + /** + * Obtain a {@link StartAndStopLatch} that can be used to wait until the expected number of + * stops has been completed. 
+ * + * @param expectedStops the expected number of stops + * @param dependents any dependent latches that must also complete in order for the + * resulting latch to complete + * @return the latch; never null + */ + public StartAndStopLatch expectedStops(int expectedStops, List dependents) { + return expectedRestarts(0, expectedStops, dependents); + } + + protected void remove(StartAndStopLatch restartLatch) { + restartLatches.remove(restartLatch); + } +} diff --git a/connect/runtime/src/test/java/org/apache/kafka/connect/integration/StartAndStopCounterTest.java b/connect/runtime/src/test/java/org/apache/kafka/connect/integration/StartAndStopCounterTest.java new file mode 100644 index 0000000..7820a6d --- /dev/null +++ b/connect/runtime/src/test/java/org/apache/kafka/connect/integration/StartAndStopCounterTest.java @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.connect.integration; + +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; + +import org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.common.utils.Time; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +public class StartAndStopCounterTest { + + private StartAndStopCounter counter; + private Time clock; + private ExecutorService waiters; + private StartAndStopLatch latch; + + @Before + public void setup() { + clock = new MockTime(); + counter = new StartAndStopCounter(clock); + } + + @After + public void teardown() { + if (waiters != null) { + try { + waiters.shutdownNow(); + } finally { + waiters = null; + } + } + } + + @Test + public void shouldRecordStarts() { + assertEquals(0, counter.starts()); + counter.recordStart(); + assertEquals(1, counter.starts()); + counter.recordStart(); + assertEquals(2, counter.starts()); + assertEquals(2, counter.starts()); + } + + @Test + public void shouldRecordStops() { + assertEquals(0, counter.stops()); + counter.recordStop(); + assertEquals(1, counter.stops()); + counter.recordStop(); + assertEquals(2, counter.stops()); + assertEquals(2, counter.stops()); + } + + @Test + public void shouldExpectRestarts() throws Exception { + waiters = Executors.newSingleThreadExecutor(); + + latch = counter.expectedRestarts(1); + Future future = asyncAwait(100, TimeUnit.MILLISECONDS); + + clock.sleep(1000); + counter.recordStop(); + counter.recordStart(); + assertTrue(future.get(200, TimeUnit.MILLISECONDS)); + assertTrue(future.isDone()); + } + @Test + public void shouldFailToWaitForRestartThatNeverHappens() throws Exception { + 
waiters = Executors.newSingleThreadExecutor(); + + latch = counter.expectedRestarts(1); + Future future = asyncAwait(100, TimeUnit.MILLISECONDS); + + clock.sleep(1000); + // Record a stop but NOT a start + counter.recordStop(); + assertFalse(future.get(200, TimeUnit.MILLISECONDS)); + assertTrue(future.isDone()); + } + + private Future asyncAwait(long duration, TimeUnit unit) { + return waiters.submit(() -> { + try { + return latch.await(duration, unit); + } catch (InterruptedException e) { + Thread.interrupted(); + return false; + } + }); + } +} diff --git a/connect/runtime/src/test/java/org/apache/kafka/connect/integration/StartAndStopLatch.java b/connect/runtime/src/test/java/org/apache/kafka/connect/integration/StartAndStopLatch.java new file mode 100644 index 0000000..b77007c --- /dev/null +++ b/connect/runtime/src/test/java/org/apache/kafka/connect/integration/StartAndStopLatch.java @@ -0,0 +1,118 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.connect.integration; + +import java.util.List; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.function.Consumer; + +import org.apache.kafka.common.utils.Time; + +/** + * A latch that can be used to count down the number of times a connector and/or tasks have + * been started and stopped. + */ +public class StartAndStopLatch { + private final CountDownLatch startLatch; + private final CountDownLatch stopLatch; + private final List dependents; + private final Consumer uponCompletion; + private final Time clock; + + StartAndStopLatch(int expectedStarts, int expectedStops, Consumer uponCompletion, + List dependents, Time clock) { + this.startLatch = new CountDownLatch(expectedStarts < 0 ? 0 : expectedStarts); + this.stopLatch = new CountDownLatch(expectedStops < 0 ? 0 : expectedStops); + this.dependents = dependents; + this.uponCompletion = uponCompletion; + this.clock = clock; + } + + protected void recordStart() { + startLatch.countDown(); + } + + protected void recordStop() { + stopLatch.countDown(); + } + + /** + * Causes the current thread to wait until the latch has counted down the starts and + * stops to zero, unless the thread is {@linkplain Thread#interrupt interrupted}, + * or the specified waiting time elapses. + * + *

<p>If the current counts are zero then this method returns immediately + * with the value {@code true}. + * + * <p>If the current count is greater than zero then the current + * thread becomes disabled for thread scheduling purposes and lies + * dormant until one of three things happen: + * <ul> + * <li>The counts reach zero due to invocations of the {@link #recordStart()} and + * {@link #recordStop()} methods; or + * <li>Some other thread {@linkplain Thread#interrupt interrupts} + * the current thread; or + * <li>The specified waiting time elapses. + * </ul> + * + * <p>If the count reaches zero then the method returns with the + * value {@code true}. + * + * <p>If the current thread: + * <ul> + * <li>has its interrupted status set on entry to this method; or + * <li>is {@linkplain Thread#interrupt interrupted} while waiting, + * </ul> + * then {@link InterruptedException} is thrown and the current thread's + * interrupted status is cleared. + * + * <p>
            If the specified waiting time elapses then the value {@code false} + * is returned. If the time is less than or equal to zero, the method + * will not wait at all. + * + * @param timeout the maximum time to wait + * @param unit the time unit of the {@code timeout} argument + * @return {@code true} if the counts reached zero and {@code false} + * if the waiting time elapsed before the counts reached zero + * @throws InterruptedException if the current thread is interrupted + * while waiting + */ + public boolean await(long timeout, TimeUnit unit) throws InterruptedException { + final long start = clock.milliseconds(); + final long end = start + unit.toMillis(timeout); + if (!startLatch.await(end - start, TimeUnit.MILLISECONDS)) { + return false; + } + if (!stopLatch.await(end - clock.milliseconds(), TimeUnit.MILLISECONDS)) { + return false; + } + + if (dependents != null) { + for (StartAndStopLatch dependent : dependents) { + if (!dependent.await(end - clock.milliseconds(), TimeUnit.MILLISECONDS)) { + return false; + } + } + } + if (uponCompletion != null) { + uponCompletion.accept(this); + } + return true; + } +} diff --git a/connect/runtime/src/test/java/org/apache/kafka/connect/integration/StartAndStopLatchTest.java b/connect/runtime/src/test/java/org/apache/kafka/connect/integration/StartAndStopLatchTest.java new file mode 100644 index 0000000..d2732ea --- /dev/null +++ b/connect/runtime/src/test/java/org/apache/kafka/connect/integration/StartAndStopLatchTest.java @@ -0,0 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.connect.integration; + +import java.util.Collections; +import java.util.List; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; + +import org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.common.utils.Time; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +public class StartAndStopLatchTest { + + private Time clock; + private StartAndStopLatch latch; + private List dependents; + private AtomicBoolean completed = new AtomicBoolean(); + private ExecutorService waiters; + private Future future; + + @Before + public void setup() { + clock = new MockTime(); + waiters = Executors.newSingleThreadExecutor(); + } + + @After + public void teardown() { + if (waiters != null) { + waiters.shutdownNow(); + } + } + + @Test + public void shouldReturnFalseWhenAwaitingForStartToNeverComplete() throws Throwable { + latch = new StartAndStopLatch(1, 1, this::complete, dependents, clock); + future = asyncAwait(100); + clock.sleep(10); + assertFalse(future.get(200, TimeUnit.MILLISECONDS)); + assertTrue(future.isDone()); + } + + @Test + public void shouldReturnFalseWhenAwaitingForStopToNeverComplete() throws Throwable { + latch = new StartAndStopLatch(1, 1, this::complete, dependents, clock); + future = asyncAwait(100); + latch.recordStart(); + clock.sleep(10); + assertFalse(future.get(200, TimeUnit.MILLISECONDS)); + assertTrue(future.isDone()); + } + + @Test + public void shouldReturnTrueWhenAwaitingForStartAndStopToComplete() throws Throwable { + latch = new StartAndStopLatch(1, 1, this::complete, dependents, clock); + future = asyncAwait(100); + latch.recordStart(); + latch.recordStop(); + clock.sleep(10); + assertTrue(future.get(200, TimeUnit.MILLISECONDS)); + assertTrue(future.isDone()); + } + + @Test + public void shouldReturnFalseWhenAwaitingForDependentLatchToComplete() throws Throwable { + StartAndStopLatch depLatch = new StartAndStopLatch(1, 1, this::complete, null, clock); + dependents = Collections.singletonList(depLatch); + latch = new StartAndStopLatch(1, 1, this::complete, dependents, clock); + + future = asyncAwait(100); + latch.recordStart(); + latch.recordStop(); + clock.sleep(10); + assertFalse(future.get(200, TimeUnit.MILLISECONDS)); + assertTrue(future.isDone()); + } + + @Test + public void shouldReturnTrueWhenAwaitingForStartAndStopAndDependentLatch() throws Throwable { + StartAndStopLatch depLatch = new StartAndStopLatch(1, 1, this::complete, null, clock); + dependents = Collections.singletonList(depLatch); + latch = new StartAndStopLatch(1, 1, this::complete, dependents, clock); + + future = asyncAwait(100); + latch.recordStart(); + latch.recordStop(); + depLatch.recordStart(); + depLatch.recordStop(); + clock.sleep(10); + assertTrue(future.get(200, TimeUnit.MILLISECONDS)); + assertTrue(future.isDone()); + } + + private Future asyncAwait(long duration) { + return asyncAwait(duration, TimeUnit.MILLISECONDS); + } + + private Future asyncAwait(long duration, TimeUnit unit) { + return waiters.submit(() -> { + try { + return latch.await(duration, unit); + } catch (InterruptedException e) { + Thread.interrupted(); + return false; + } + }); + } + + private void complete(StartAndStopLatch latch) { + completed.set(true); + } +} diff --git 
a/connect/runtime/src/test/java/org/apache/kafka/connect/integration/StartsAndStops.java b/connect/runtime/src/test/java/org/apache/kafka/connect/integration/StartsAndStops.java new file mode 100644 index 0000000..25bc748 --- /dev/null +++ b/connect/runtime/src/test/java/org/apache/kafka/connect/integration/StartsAndStops.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.connect.integration; + +public class StartsAndStops { + private final int starts; + private final int stops; + + public StartsAndStops(int starts, int stops) { + this.starts = starts; + this.stops = stops; + } + + public int starts() { + return starts; + } + + public int stops() { + return stops; + } + +} diff --git a/connect/runtime/src/test/java/org/apache/kafka/connect/integration/TaskHandle.java b/connect/runtime/src/test/java/org/apache/kafka/connect/integration/TaskHandle.java new file mode 100644 index 0000000..ab5b711 --- /dev/null +++ b/connect/runtime/src/test/java/org/apache/kafka/connect/integration/TaskHandle.java @@ -0,0 +1,374 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.integration; + +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.connect.errors.DataException; +import org.apache.kafka.connect.sink.SinkRecord; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Collection; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.function.Consumer; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +/** + * A handle to an executing task in a worker. Use this class to record progress, for example: number of records seen + * by the task using so far, or waiting for partitions to be assigned to the task. 
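A hedged sketch of the record-counting flow this javadoc describes, using only methods that appear in this patch (the connector name, task id suffix, and counts are illustrative; the handle lookup mirrors RuntimeHandles/ConnectorHandle usage in the integration tests below):

    ConnectorHandle connector = RuntimeHandles.get().connectorHandle("my-connector");
    TaskHandle task = connector.taskHandle("my-connector-0", record -> { });
    task.expectedRecords(100);                          // arm the latch for 100 records
    // ... the monitored task calls task.record(...) for each record it processes ...
    task.awaitRecords(TimeUnit.SECONDS.toMillis(30));   // throws DataException if fewer arrive in time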
+ */ +public class TaskHandle { + + private static final Logger log = LoggerFactory.getLogger(TaskHandle.class); + + private final String taskId; + private final ConnectorHandle connectorHandle; + private final ConcurrentMap partitions = new ConcurrentHashMap<>(); + private final StartAndStopCounter startAndStopCounter = new StartAndStopCounter(); + private final Consumer consumer; + + private CountDownLatch recordsRemainingLatch; + private CountDownLatch recordsToCommitLatch; + private int expectedRecords = -1; + private int expectedCommits = -1; + + public TaskHandle(ConnectorHandle connectorHandle, String taskId, Consumer consumer) { + this.taskId = taskId; + this.connectorHandle = connectorHandle; + this.consumer = consumer; + } + + public String taskId() { + return taskId; + } + + public void record() { + record(null); + } + + /** + * Record a message arrival at the task and the connector overall. + */ + public void record(SinkRecord record) { + if (consumer != null && record != null) { + consumer.accept(record); + } + if (recordsRemainingLatch != null) { + recordsRemainingLatch.countDown(); + } + connectorHandle.record(); + } + + /** + * Record arrival of a batch of messages at the task and the connector overall. + * + * @param batchSize the number of messages + */ + public void record(int batchSize) { + if (recordsRemainingLatch != null) { + IntStream.range(0, batchSize).forEach(i -> recordsRemainingLatch.countDown()); + } + connectorHandle.record(batchSize); + } + + /** + * Record a message commit from the task and the connector overall. + */ + public void commit() { + if (recordsToCommitLatch != null) { + recordsToCommitLatch.countDown(); + } + connectorHandle.commit(); + } + + /** + * Record commit on a batch of messages from the task and the connector overall. + * + * @param batchSize the number of messages + */ + public void commit(int batchSize) { + if (recordsToCommitLatch != null) { + IntStream.range(0, batchSize).forEach(i -> recordsToCommitLatch.countDown()); + } + connectorHandle.commit(batchSize); + } + + /** + * Set the number of expected records for this task. + * + * @param expected number of records + */ + public void expectedRecords(int expected) { + expectedRecords = expected; + recordsRemainingLatch = new CountDownLatch(expected); + } + + /** + * Set the number of expected record commits performed by this task. 
+ * + * @param expected number of commits + */ + public void expectedCommits(int expected) { + expectedCommits = expected; + recordsToCommitLatch = new CountDownLatch(expected); + } + + /** + * Adds a set of partitions to the (sink) task's assignment + * + * @param partitions the newly-assigned partitions + */ + public void partitionsAssigned(Collection<TopicPartition> partitions) { + partitions.forEach(partition -> this.partitions.computeIfAbsent(partition, PartitionHistory::new).assigned()); + } + + /** + * Removes a set of partitions from the (sink) task's assignment + * + * @param partitions the newly-revoked partitions + */ + public void partitionsRevoked(Collection<TopicPartition> partitions) { + partitions.forEach(partition -> this.partitions.computeIfAbsent(partition, PartitionHistory::new).revoked()); + } + + /** + * Records offset commits for a (sink) task's partitions + * + * @param partitions the committed partitions + */ + public void partitionsCommitted(Collection<TopicPartition> partitions) { + partitions.forEach(partition -> this.partitions.computeIfAbsent(partition, PartitionHistory::new).committed()); + } + + /** + * @return the complete set of partitions currently assigned to this (sink) task + */ + public Collection<TopicPartition> assignment() { + return partitions.values().stream() + .filter(PartitionHistory::isAssigned) + .map(PartitionHistory::topicPartition) + .collect(Collectors.toSet()); + } + + /** + * @return the number of topic partitions assigned to this (sink) task. + */ + public int numPartitionsAssigned() { + return assignment().size(); + } + + /** + * Returns the number of times the partition has been assigned to this (sink) task. + * @param partition the partition + * @return the number of times it has been assigned; may be 0 if never assigned + */ + public int timesAssigned(TopicPartition partition) { + return partitions.computeIfAbsent(partition, PartitionHistory::new).timesAssigned(); + } + + /** + * Returns the number of times the partition has been revoked from this (sink) task. + * @param partition the partition + * @return the number of times it has been revoked; may be 0 if never revoked + */ + public int timesRevoked(TopicPartition partition) { + return partitions.computeIfAbsent(partition, PartitionHistory::new).timesRevoked(); + } + + /** + * Returns the number of times the framework has committed offsets for this partition + * @param partition the partition + * @return the number of times it has been committed; may be 0 if never committed + */ + public int timesCommitted(TopicPartition partition) { + return partitions.computeIfAbsent(partition, PartitionHistory::new).timesCommitted(); + } + + /** + * Wait up to the specified number of milliseconds for this task to meet the expected number of + * records as defined by {@code expectedRecords}. + * + * @param timeoutMillis number of milliseconds to wait for records + * @throws InterruptedException if another thread interrupts this one while waiting for records + */ + public void awaitRecords(long timeoutMillis) throws InterruptedException { + awaitRecords(timeoutMillis, TimeUnit.MILLISECONDS); + } + + /** + * Wait up to the specified timeout for this task to meet the expected number of records as + * defined by {@code expectedRecords}.
+ * + * @param timeout duration to wait for records + * @param unit the unit of duration; may not be null + * @throws InterruptedException if another thread interrupts this one while waiting for records + */ + public void awaitRecords(long timeout, TimeUnit unit) throws InterruptedException { + if (recordsRemainingLatch == null) { + throw new IllegalStateException("Illegal state encountered. expectedRecords() was not set for this task?"); + } + if (!recordsRemainingLatch.await(timeout, unit)) { + String msg = String.format( + "Insufficient records seen by task %s in %d millis. Records expected=%d, actual=%d", + taskId, + unit.toMillis(timeout), + expectedRecords, + expectedRecords - recordsRemainingLatch.getCount()); + throw new DataException(msg); + } + log.debug("Task {} saw {} records, expected {} records", + taskId, expectedRecords - recordsRemainingLatch.getCount(), expectedRecords); + } + + /** + * Wait up to the specified timeout in milliseconds for this task to meet the expected number + * of commits as defined by {@code expectedCommits}. + * + * @param timeoutMillis number of milliseconds to wait for commits + * @throws InterruptedException if another thread interrupts this one while waiting for commits + */ + public void awaitCommits(long timeoutMillis) throws InterruptedException { + awaitCommits(timeoutMillis, TimeUnit.MILLISECONDS); + } + + /** + * Wait up to the specified timeout for this task to meet the expected number of commits as + * defined by {@code expectedCommits}. + * + * @param timeout duration to wait for commits + * @param unit the unit of duration; may not be null + * @throws InterruptedException if another thread interrupts this one while waiting for commits + */ + public void awaitCommits(long timeout, TimeUnit unit) throws InterruptedException { + if (recordsToCommitLatch == null) { + throw new IllegalStateException("Illegal state encountered. expectedCommits() was not set for this task?"); + } + if (!recordsToCommitLatch.await(timeout, unit)) { + String msg = String.format( + "Insufficient commits seen by task %s in %d millis. Commits expected=%d, actual=%d", + taskId, + unit.toMillis(timeout), + expectedCommits, + expectedCommits - recordsToCommitLatch.getCount()); + throw new DataException(msg); + } + log.debug("Task {} saw {} commits, expected {} commits", + taskId, expectedCommits - recordsToCommitLatch.getCount(), expectedCommits); + } + + /** + * Gets the start and stop counter corresponding to this handle. + * + * @return the start and stop counter + */ + public StartAndStopCounter startAndStopCounter() { + return startAndStopCounter; + } + + /** + * Record that this task has been started. This should be called by the task. + */ + public void recordTaskStart() { + startAndStopCounter.recordStart(); + } + + /** + * Record that this task has been stopped. This should be called by the task. + */ + public void recordTaskStop() { + startAndStopCounter.recordStop(); + } + + /** + * Obtain a {@link StartAndStopLatch} that can be used to wait until this task has completed the + * expected number of starts. + * + * @param expectedStarts the expected number of starts + * @return the latch; never null + */ + public StartAndStopLatch expectedStarts(int expectedStarts) { + return startAndStopCounter.expectedStarts(expectedStarts); + } + + /** + * Obtain a {@link StartAndStopLatch} that can be used to wait until this task has completed the + * expected number of stops.
+ * + * @param expectedStops the expected number of stops + * @return the latch; never null + */ + public StartAndStopLatch expectedStops(int expectedStops) { + return startAndStopCounter.expectedStops(expectedStops); + } + + @Override + public String toString() { + return "Handle{" + + "taskId='" + taskId + '\'' + + '}'; + } + + private static class PartitionHistory { + private final TopicPartition topicPartition; + private boolean assigned = false; + private int timesAssigned = 0; + private int timesRevoked = 0; + private int timesCommitted = 0; + + public PartitionHistory(TopicPartition topicPartition) { + this.topicPartition = topicPartition; + } + + public void assigned() { + timesAssigned++; + assigned = true; + } + + public void revoked() { + timesRevoked++; + assigned = false; + } + + public void committed() { + timesCommitted++; + } + + public TopicPartition topicPartition() { + return topicPartition; + } + + public boolean isAssigned() { + return assigned; + } + + public int timesAssigned() { + return timesAssigned; + } + + public int timesRevoked() { + return timesRevoked; + } + + public int timesCommitted() { + return timesCommitted; + } + } +} diff --git a/connect/runtime/src/test/java/org/apache/kafka/connect/integration/TransformationIntegrationTest.java b/connect/runtime/src/test/java/org/apache/kafka/connect/integration/TransformationIntegrationTest.java new file mode 100644 index 0000000..02d8c7f --- /dev/null +++ b/connect/runtime/src/test/java/org/apache/kafka/connect/integration/TransformationIntegrationTest.java @@ -0,0 +1,327 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
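The start/stop bookkeeping in TaskHandle above is normally consumed through the latches it hands out; a brief sketch (assuming a taskHandle obtained as in the surrounding integration tests, and an arbitrary 60-second timeout):

    StartAndStopLatch restarted = taskHandle.expectedStarts(1);
    // ... trigger whatever should restart the task, e.g. reconfigure the connector ...
    assertTrue("Task did not restart in time", restarted.await(60, TimeUnit.SECONDS));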
+ */ +package org.apache.kafka.connect.integration; + +import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.connect.storage.StringConverter; +import org.apache.kafka.connect.transforms.Filter; +import org.apache.kafka.connect.transforms.predicates.HasHeaderKey; +import org.apache.kafka.connect.transforms.predicates.RecordIsTombstone; +import org.apache.kafka.connect.transforms.predicates.TopicNameMatches; +import org.apache.kafka.connect.util.clusters.EmbeddedConnectCluster; +import org.apache.kafka.test.IntegrationTest; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.HashMap; +import java.util.Map; +import java.util.Properties; +import java.util.concurrent.TimeUnit; + +import static java.util.Collections.singletonMap; +import static org.apache.kafka.connect.runtime.ConnectorConfig.CONNECTOR_CLASS_CONFIG; +import static org.apache.kafka.connect.runtime.ConnectorConfig.KEY_CONVERTER_CLASS_CONFIG; +import static org.apache.kafka.connect.runtime.ConnectorConfig.PREDICATES_CONFIG; +import static org.apache.kafka.connect.runtime.ConnectorConfig.TASKS_MAX_CONFIG; +import static org.apache.kafka.connect.runtime.ConnectorConfig.TRANSFORMS_CONFIG; +import static org.apache.kafka.connect.runtime.ConnectorConfig.VALUE_CONVERTER_CLASS_CONFIG; +import static org.apache.kafka.connect.runtime.SinkConnectorConfig.TOPICS_CONFIG; +import static org.apache.kafka.connect.runtime.TopicCreationConfig.DEFAULT_TOPIC_CREATION_PREFIX; +import static org.apache.kafka.connect.runtime.TopicCreationConfig.PARTITIONS_CONFIG; +import static org.apache.kafka.connect.runtime.TopicCreationConfig.REPLICATION_FACTOR_CONFIG; +import static org.apache.kafka.connect.runtime.WorkerConfig.OFFSET_COMMIT_INTERVAL_MS_CONFIG; +import static org.apache.kafka.test.TestUtils.waitForCondition; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; + +/** + * An integration test for connectors with transformations + */ +@Category(IntegrationTest.class) +public class TransformationIntegrationTest { + + private static final Logger log = LoggerFactory.getLogger(TransformationIntegrationTest.class); + + private static final int NUM_RECORDS_PRODUCED = 2000; + private static final int NUM_TOPIC_PARTITIONS = 3; + private static final long RECORD_TRANSFER_DURATION_MS = TimeUnit.SECONDS.toMillis(30); + private static final long OBSERVED_RECORDS_DURATION_MS = TimeUnit.SECONDS.toMillis(60); + private static final int NUM_TASKS = 1; + private static final int NUM_WORKERS = 3; + private static final String CONNECTOR_NAME = "simple-conn"; + private static final String SINK_CONNECTOR_CLASS_NAME = MonitorableSinkConnector.class.getSimpleName(); + private static final String SOURCE_CONNECTOR_CLASS_NAME = MonitorableSourceConnector.class.getSimpleName(); + + private EmbeddedConnectCluster connect; + private ConnectorHandle connectorHandle; + + @Before + public void setup() { + // setup Connect worker properties + Map workerProps = new HashMap<>(); + workerProps.put(OFFSET_COMMIT_INTERVAL_MS_CONFIG, String.valueOf(5_000)); + + // setup Kafka broker properties + Properties brokerProps = new Properties(); + // This is required because tests in this class also test per-connector topic creation with transformations + brokerProps.put("auto.create.topics.enable", "false"); + + // build a Connect cluster backed by Kafka and Zk + connect = new 
EmbeddedConnectCluster.Builder() + .name("connect-cluster") + .numWorkers(NUM_WORKERS) + .numBrokers(1) + .workerProps(workerProps) + .brokerProps(brokerProps) + .build(); + + // start the clusters + connect.start(); + + // get a handle to the connector + connectorHandle = RuntimeHandles.get().connectorHandle(CONNECTOR_NAME); + } + + @After + public void close() { + // delete connector handle + RuntimeHandles.get().deleteConnector(CONNECTOR_NAME); + + // stop all Connect, Kafka and Zk threads. + connect.stop(); + } + + /** + * Test the {@link Filter} transformer with a + * {@link TopicNameMatches} predicate on a sink connector. + */ + @Test + public void testFilterOnTopicNameWithSinkConnector() throws Exception { + assertConnectReady(); + + Map observedRecords = observeRecords(); + + // create test topics + String fooTopic = "foo-topic"; + String barTopic = "bar-topic"; + int numFooRecords = NUM_RECORDS_PRODUCED; + int numBarRecords = NUM_RECORDS_PRODUCED; + connect.kafka().createTopic(fooTopic, NUM_TOPIC_PARTITIONS); + connect.kafka().createTopic(barTopic, NUM_TOPIC_PARTITIONS); + + // setup up props for the sink connector + Map props = new HashMap<>(); + props.put("name", CONNECTOR_NAME); + props.put(CONNECTOR_CLASS_CONFIG, SINK_CONNECTOR_CLASS_NAME); + props.put(TASKS_MAX_CONFIG, String.valueOf(NUM_TASKS)); + props.put(TOPICS_CONFIG, String.join(",", fooTopic, barTopic)); + props.put(KEY_CONVERTER_CLASS_CONFIG, StringConverter.class.getName()); + props.put(VALUE_CONVERTER_CLASS_CONFIG, StringConverter.class.getName()); + props.put(TRANSFORMS_CONFIG, "filter"); + props.put(TRANSFORMS_CONFIG + ".filter.type", Filter.class.getSimpleName()); + props.put(TRANSFORMS_CONFIG + ".filter.predicate", "barPredicate"); + props.put(PREDICATES_CONFIG, "barPredicate"); + props.put(PREDICATES_CONFIG + ".barPredicate.type", TopicNameMatches.class.getSimpleName()); + props.put(PREDICATES_CONFIG + ".barPredicate.pattern", "bar-.*"); + + // expect all records to be consumed by the connector + connectorHandle.expectedRecords(numFooRecords); + + // expect all records to be consumed by the connector + connectorHandle.expectedCommits(numFooRecords); + + // start a sink connector + connect.configureConnector(CONNECTOR_NAME, props); + assertConnectorRunning(); + + // produce some messages into source topic partitions + for (int i = 0; i < numBarRecords; i++) { + connect.kafka().produce(barTopic, i % NUM_TOPIC_PARTITIONS, "key", "simple-message-value-" + i); + } + for (int i = 0; i < numFooRecords; i++) { + connect.kafka().produce(fooTopic, i % NUM_TOPIC_PARTITIONS, "key", "simple-message-value-" + i); + } + + // consume all records from the source topic or fail, to ensure that they were correctly produced. + assertEquals("Unexpected number of records consumed", numFooRecords, + connect.kafka().consume(numFooRecords, RECORD_TRANSFER_DURATION_MS, fooTopic).count()); + assertEquals("Unexpected number of records consumed", numBarRecords, + connect.kafka().consume(numBarRecords, RECORD_TRANSFER_DURATION_MS, barTopic).count()); + + // wait for the connector tasks to consume all records. + connectorHandle.awaitRecords(RECORD_TRANSFER_DURATION_MS); + + // wait for the connector tasks to commit all records. 
+ connectorHandle.awaitCommits(RECORD_TRANSFER_DURATION_MS); + + // Assert that we didn't see any baz + Map expectedRecordCounts = singletonMap(fooTopic, (long) numFooRecords); + assertObservedRecords(observedRecords, expectedRecordCounts); + + // delete connector + connect.deleteConnector(CONNECTOR_NAME); + } + + private void assertConnectReady() throws InterruptedException { + connect.assertions().assertExactlyNumBrokersAreUp(1, "Brokers did not start in time."); + connect.assertions().assertExactlyNumWorkersAreUp(NUM_WORKERS, "Worker did not start in time."); + log.info("Completed startup of {} Kafka brokers and {} Connect workers", 1, NUM_WORKERS); + } + + private void assertConnectorRunning() throws InterruptedException { + connect.assertions().assertConnectorAndAtLeastNumTasksAreRunning(CONNECTOR_NAME, NUM_TASKS, + "Connector tasks did not start in time."); + } + + private void assertObservedRecords(Map observedRecords, Map expectedRecordCounts) throws InterruptedException { + waitForCondition(() -> expectedRecordCounts.equals(observedRecords), + OBSERVED_RECORDS_DURATION_MS, + () -> "The observed records should be " + expectedRecordCounts + " but was " + observedRecords); + } + + private Map observeRecords() { + Map observedRecords = new HashMap<>(); + // record all the record we see + connectorHandle.taskHandle(CONNECTOR_NAME + "-0", + record -> observedRecords.compute(record.topic(), + (key, value) -> value == null ? 1 : value + 1)); + return observedRecords; + } + + /** + * Test the {@link Filter} transformer with a + * {@link RecordIsTombstone} predicate on a sink connector. + */ + @Test + public void testFilterOnTombstonesWithSinkConnector() throws Exception { + assertConnectReady(); + + Map observedRecords = observeRecords(); + + // create test topics + String topic = "foo-topic"; + int numRecords = NUM_RECORDS_PRODUCED; + connect.kafka().createTopic(topic, NUM_TOPIC_PARTITIONS); + + // setup up props for the sink connector + Map props = new HashMap<>(); + props.put("name", CONNECTOR_NAME); + props.put(CONNECTOR_CLASS_CONFIG, SINK_CONNECTOR_CLASS_NAME); + props.put(TASKS_MAX_CONFIG, String.valueOf(NUM_TASKS)); + props.put(TOPICS_CONFIG, String.join(",", topic)); + props.put(KEY_CONVERTER_CLASS_CONFIG, StringConverter.class.getName()); + props.put(VALUE_CONVERTER_CLASS_CONFIG, StringConverter.class.getName()); + props.put(TRANSFORMS_CONFIG, "filter"); + props.put(TRANSFORMS_CONFIG + ".filter.type", Filter.class.getSimpleName()); + props.put(TRANSFORMS_CONFIG + ".filter.predicate", "barPredicate"); + props.put(PREDICATES_CONFIG, "barPredicate"); + props.put(PREDICATES_CONFIG + ".barPredicate.type", RecordIsTombstone.class.getSimpleName()); + + // expect only half the records to be consumed by the connector + connectorHandle.expectedCommits(numRecords / 2); + connectorHandle.expectedRecords(numRecords / 2); + + // start a sink connector + connect.configureConnector(CONNECTOR_NAME, props); + assertConnectorRunning(); + + // produce some messages into source topic partitions + for (int i = 0; i < numRecords; i++) { + connect.kafka().produce(topic, i % NUM_TOPIC_PARTITIONS, "key", i % 2 == 0 ? "simple-message-value-" + i : null); + } + + // consume all records from the source topic or fail, to ensure that they were correctly produced. + assertEquals("Unexpected number of records consumed", numRecords, + connect.kafka().consume(numRecords, RECORD_TRANSFER_DURATION_MS, topic).count()); + + // wait for the connector tasks to consume all records. 
+ connectorHandle.awaitRecords(RECORD_TRANSFER_DURATION_MS); + + // wait for the connector tasks to commit all records. + connectorHandle.awaitCommits(RECORD_TRANSFER_DURATION_MS); + + Map expectedRecordCounts = singletonMap(topic, (long) (numRecords / 2)); + assertObservedRecords(observedRecords, expectedRecordCounts); + + // delete connector + connect.deleteConnector(CONNECTOR_NAME); + } + + /** + * Test the {@link Filter} transformer with a {@link HasHeaderKey} predicate on a source connector. + * Note that this test uses topic creation configs to allow the source connector to create + * the topic when it tries to produce the first source record, instead of requiring the topic + * to exist before the connector starts. + */ + @Test + public void testFilterOnHasHeaderKeyWithSourceConnectorAndTopicCreation() throws Exception { + assertConnectReady(); + + // setup up props for the sink connector + Map props = new HashMap<>(); + props.put("name", CONNECTOR_NAME); + props.put(CONNECTOR_CLASS_CONFIG, SOURCE_CONNECTOR_CLASS_NAME); + props.put(TASKS_MAX_CONFIG, String.valueOf(NUM_TASKS)); + props.put("topic", "test-topic"); + props.put("throughput", String.valueOf(500)); + props.put(KEY_CONVERTER_CLASS_CONFIG, StringConverter.class.getName()); + props.put(VALUE_CONVERTER_CLASS_CONFIG, StringConverter.class.getName()); + props.put(TRANSFORMS_CONFIG, "filter"); + props.put(TRANSFORMS_CONFIG + ".filter.type", Filter.class.getSimpleName()); + props.put(TRANSFORMS_CONFIG + ".filter.predicate", "headerPredicate"); + props.put(TRANSFORMS_CONFIG + ".filter.negate", "true"); + props.put(PREDICATES_CONFIG, "headerPredicate"); + props.put(PREDICATES_CONFIG + ".headerPredicate.type", HasHeaderKey.class.getSimpleName()); + props.put(PREDICATES_CONFIG + ".headerPredicate.name", "header-8"); + // custom topic creation is used, so there's no need to proactively create the test topic + props.put(DEFAULT_TOPIC_CREATION_PREFIX + REPLICATION_FACTOR_CONFIG, String.valueOf(-1)); + props.put(DEFAULT_TOPIC_CREATION_PREFIX + PARTITIONS_CONFIG, String.valueOf(NUM_TOPIC_PARTITIONS)); + + // expect all records to be produced by the connector + connectorHandle.expectedRecords(NUM_RECORDS_PRODUCED); + + // expect all records to be produced by the connector + connectorHandle.expectedCommits(NUM_RECORDS_PRODUCED); + + // validate the intended connector configuration, a valid config + connect.assertions().assertExactlyNumErrorsOnConnectorConfigValidation(SOURCE_CONNECTOR_CLASS_NAME, props, 0, + "Validating connector configuration produced an unexpected number or errors."); + + // start a source connector + connect.configureConnector(CONNECTOR_NAME, props); + assertConnectorRunning(); + + // wait for the connector tasks to produce enough records + connectorHandle.awaitRecords(RECORD_TRANSFER_DURATION_MS); + + // wait for the connector tasks to commit enough records + connectorHandle.awaitCommits(RECORD_TRANSFER_DURATION_MS); + + // consume all records from the source topic or fail, to ensure that they were correctly produced + for (ConsumerRecord record : connect.kafka().consume(1, RECORD_TRANSFER_DURATION_MS, "test-topic")) { + assertNotNull("Expected header to exist", + record.headers().lastHeader("header-8")); + } + + // delete connector + connect.deleteConnector(CONNECTOR_NAME); + } +} diff --git a/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/AbstractHerderTest.java b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/AbstractHerderTest.java new file mode 100644 index 0000000..8c8d00d --- 
/dev/null +++ b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/AbstractHerderTest.java @@ -0,0 +1,996 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.runtime; + +import org.apache.kafka.clients.CommonClientConfigs; +import org.apache.kafka.clients.producer.ProducerConfig; +import org.apache.kafka.common.config.ConfigDef; +import org.apache.kafka.common.config.ConfigException; +import org.apache.kafka.common.config.ConfigTransformer; +import org.apache.kafka.common.config.ConfigValue; +import org.apache.kafka.common.config.SaslConfigs; +import org.apache.kafka.common.security.oauthbearer.internals.unsecured.OAuthBearerUnsecuredLoginCallbackHandler; +import org.apache.kafka.connect.connector.ConnectRecord; +import org.apache.kafka.connect.connector.Connector; +import org.apache.kafka.connect.connector.policy.AllConnectorClientConfigOverridePolicy; +import org.apache.kafka.connect.connector.policy.ConnectorClientConfigOverridePolicy; +import org.apache.kafka.connect.connector.policy.NoneConnectorClientConfigOverridePolicy; +import org.apache.kafka.connect.connector.policy.PrincipalConnectorClientConfigOverridePolicy; +import org.apache.kafka.connect.runtime.distributed.ClusterConfigState; +import org.apache.kafka.connect.runtime.isolation.PluginDesc; +import org.apache.kafka.connect.runtime.isolation.Plugins; +import org.apache.kafka.connect.runtime.rest.entities.ConfigInfo; +import org.apache.kafka.connect.runtime.rest.entities.ConfigInfos; +import org.apache.kafka.connect.runtime.rest.entities.ConfigValueInfo; +import org.apache.kafka.connect.runtime.rest.entities.ConnectorStateInfo; +import org.apache.kafka.connect.runtime.rest.entities.ConnectorType; +import org.apache.kafka.connect.runtime.rest.errors.BadRequestException; +import org.apache.kafka.connect.source.SourceConnector; +import org.apache.kafka.connect.source.SourceTask; +import org.apache.kafka.connect.storage.ConfigBackingStore; +import org.apache.kafka.connect.storage.StatusBackingStore; +import org.apache.kafka.connect.transforms.Transformation; +import org.apache.kafka.connect.transforms.predicates.Predicate; +import org.apache.kafka.connect.util.ConnectorTaskId; +import org.easymock.Capture; +import org.easymock.EasyMock; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.powermock.api.easymock.PowerMock; +import org.powermock.api.easymock.annotation.MockStrict; +import org.powermock.core.classloader.annotations.PrepareForTest; +import org.powermock.modules.junit4.PowerMockRunner; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import 
java.util.Set; +import java.util.function.Function; +import java.util.stream.Collectors; + +import static org.apache.kafka.connect.runtime.AbstractHerder.keysWithVariableValues; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThrows; +import static org.powermock.api.easymock.PowerMock.verifyAll; +import static org.powermock.api.easymock.PowerMock.replayAll; +import static org.easymock.EasyMock.strictMock; +import static org.easymock.EasyMock.partialMockBuilder; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; + +@RunWith(PowerMockRunner.class) +@PrepareForTest({AbstractHerder.class}) +public class AbstractHerderTest { + + private static final String CONN1 = "sourceA"; + private static final ConnectorTaskId TASK0 = new ConnectorTaskId(CONN1, 0); + private static final ConnectorTaskId TASK1 = new ConnectorTaskId(CONN1, 1); + private static final ConnectorTaskId TASK2 = new ConnectorTaskId(CONN1, 2); + private static final Integer MAX_TASKS = 3; + private static final Map CONN1_CONFIG = new HashMap<>(); + private static final String TEST_KEY = "testKey"; + private static final String TEST_KEY2 = "testKey2"; + private static final String TEST_KEY3 = "testKey3"; + private static final String TEST_VAL = "testVal"; + private static final String TEST_VAL2 = "testVal2"; + private static final String TEST_REF = "${file:/tmp/somefile.txt:somevar}"; + private static final String TEST_REF2 = "${file:/tmp/somefile2.txt:somevar2}"; + private static final String TEST_REF3 = "${file:/tmp/somefile3.txt:somevar3}"; + static { + CONN1_CONFIG.put(ConnectorConfig.NAME_CONFIG, CONN1); + CONN1_CONFIG.put(ConnectorConfig.TASKS_MAX_CONFIG, MAX_TASKS.toString()); + CONN1_CONFIG.put(SinkConnectorConfig.TOPICS_CONFIG, "foo,bar"); + CONN1_CONFIG.put(ConnectorConfig.CONNECTOR_CLASS_CONFIG, BogusSourceConnector.class.getName()); + CONN1_CONFIG.put(TEST_KEY, TEST_REF); + CONN1_CONFIG.put(TEST_KEY2, TEST_REF2); + CONN1_CONFIG.put(TEST_KEY3, TEST_REF3); + } + private static final Map TASK_CONFIG = new HashMap<>(); + static { + TASK_CONFIG.put(TaskConfig.TASK_CLASS_CONFIG, BogusSourceTask.class.getName()); + TASK_CONFIG.put(TEST_KEY, TEST_REF); + } + private static final List> TASK_CONFIGS = new ArrayList<>(); + static { + TASK_CONFIGS.add(TASK_CONFIG); + TASK_CONFIGS.add(TASK_CONFIG); + TASK_CONFIGS.add(TASK_CONFIG); + } + private static final HashMap> TASK_CONFIGS_MAP = new HashMap<>(); + static { + TASK_CONFIGS_MAP.put(TASK0, TASK_CONFIG); + TASK_CONFIGS_MAP.put(TASK1, TASK_CONFIG); + TASK_CONFIGS_MAP.put(TASK2, TASK_CONFIG); + } + private static final ClusterConfigState SNAPSHOT = new ClusterConfigState(1, null, Collections.singletonMap(CONN1, 3), + Collections.singletonMap(CONN1, CONN1_CONFIG), Collections.singletonMap(CONN1, TargetState.STARTED), + TASK_CONFIGS_MAP, Collections.emptySet()); + private static final ClusterConfigState SNAPSHOT_NO_TASKS = new ClusterConfigState(1, null, Collections.singletonMap(CONN1, 3), + Collections.singletonMap(CONN1, CONN1_CONFIG), Collections.singletonMap(CONN1, TargetState.STARTED), + Collections.emptyMap(), Collections.emptySet()); + + private final String workerId = "workerId"; + private final String kafkaClusterId = "I4ZmrWqfT2e-upky_4fdPA"; + private final int generation = 5; + private final String connector = "connector"; + private final ConnectorClientConfigOverridePolicy noneConnectorClientConfigOverridePolicy = 
new NoneConnectorClientConfigOverridePolicy(); + + @MockStrict private Worker worker; + @MockStrict private WorkerConfigTransformer transformer; + @MockStrict private Plugins plugins; + @MockStrict private ClassLoader classLoader; + @MockStrict private ConfigBackingStore configStore; + @MockStrict private StatusBackingStore statusStore; + + @Test + public void testConnectors() { + AbstractHerder herder = partialMockBuilder(AbstractHerder.class) + .withConstructor( + Worker.class, + String.class, + String.class, + StatusBackingStore.class, + ConfigBackingStore.class, + ConnectorClientConfigOverridePolicy.class + ) + .withArgs(worker, workerId, kafkaClusterId, statusStore, configStore, noneConnectorClientConfigOverridePolicy) + .addMockedMethod("generation") + .createMock(); + + EasyMock.expect(herder.generation()).andStubReturn(generation); + EasyMock.expect(herder.rawConfig(connector)).andReturn(null); + EasyMock.expect(configStore.snapshot()).andReturn(SNAPSHOT); + replayAll(); + assertEquals(Collections.singleton(CONN1), new HashSet<>(herder.connectors())); + PowerMock.verifyAll(); + } + + @Test + public void testConnectorStatus() { + ConnectorTaskId taskId = new ConnectorTaskId(connector, 0); + AbstractHerder herder = partialMockBuilder(AbstractHerder.class) + .withConstructor( + Worker.class, + String.class, + String.class, + StatusBackingStore.class, + ConfigBackingStore.class, + ConnectorClientConfigOverridePolicy.class + ) + .withArgs(worker, workerId, kafkaClusterId, statusStore, configStore, noneConnectorClientConfigOverridePolicy) + .addMockedMethod("generation") + .createMock(); + + EasyMock.expect(herder.generation()).andStubReturn(generation); + EasyMock.expect(herder.rawConfig(connector)).andReturn(null); + EasyMock.expect(statusStore.get(connector)) + .andReturn(new ConnectorStatus(connector, AbstractStatus.State.RUNNING, workerId, generation)); + EasyMock.expect(statusStore.getAll(connector)) + .andReturn(Collections.singletonList( + new TaskStatus(taskId, AbstractStatus.State.UNASSIGNED, workerId, generation))); + + replayAll(); + ConnectorStateInfo csi = herder.connectorStatus(connector); + PowerMock.verifyAll(); + } + + @Test + public void connectorStatus() { + ConnectorTaskId taskId = new ConnectorTaskId(connector, 0); + + AbstractHerder herder = partialMockBuilder(AbstractHerder.class) + .withConstructor(Worker.class, String.class, String.class, StatusBackingStore.class, ConfigBackingStore.class, + ConnectorClientConfigOverridePolicy.class) + .withArgs(worker, workerId, kafkaClusterId, statusStore, configStore, noneConnectorClientConfigOverridePolicy) + .addMockedMethod("generation") + .createMock(); + + EasyMock.expect(herder.generation()).andStubReturn(generation); + EasyMock.expect(herder.rawConfig(connector)).andReturn(null); + + EasyMock.expect(statusStore.get(connector)) + .andReturn(new ConnectorStatus(connector, AbstractStatus.State.RUNNING, workerId, generation)); + + EasyMock.expect(statusStore.getAll(connector)) + .andReturn(Collections.singletonList( + new TaskStatus(taskId, AbstractStatus.State.UNASSIGNED, workerId, generation))); + EasyMock.expect(worker.getPlugins()).andStubReturn(plugins); + + replayAll(); + + + ConnectorStateInfo state = herder.connectorStatus(connector); + + assertEquals(connector, state.name()); + assertEquals("RUNNING", state.connector().state()); + assertEquals(1, state.tasks().size()); + assertEquals(workerId, state.connector().workerId()); + + ConnectorStateInfo.TaskState taskState = state.tasks().get(0); + assertEquals(0, 
taskState.id()); + assertEquals("UNASSIGNED", taskState.state()); + assertEquals(workerId, taskState.workerId()); + + PowerMock.verifyAll(); + } + + @Test + public void taskStatus() { + ConnectorTaskId taskId = new ConnectorTaskId("connector", 0); + String workerId = "workerId"; + + AbstractHerder herder = partialMockBuilder(AbstractHerder.class) + .withConstructor(Worker.class, String.class, String.class, StatusBackingStore.class, ConfigBackingStore.class, + ConnectorClientConfigOverridePolicy.class) + .withArgs(worker, workerId, kafkaClusterId, statusStore, configStore, noneConnectorClientConfigOverridePolicy) + .addMockedMethod("generation") + .createMock(); + + EasyMock.expect(herder.generation()).andStubReturn(5); + + final Capture statusCapture = EasyMock.newCapture(); + statusStore.putSafe(EasyMock.capture(statusCapture)); + EasyMock.expectLastCall(); + + EasyMock.expect(statusStore.get(taskId)).andAnswer(statusCapture::getValue); + + replayAll(); + + herder.onFailure(taskId, new RuntimeException()); + + ConnectorStateInfo.TaskState taskState = herder.taskStatus(taskId); + assertEquals(workerId, taskState.workerId()); + assertEquals("FAILED", taskState.state()); + assertEquals(0, taskState.id()); + assertNotNull(taskState.trace()); + + verifyAll(); + } + + @Test + public void testBuildRestartPlanForUnknownConnector() { + String connectorName = "UnknownConnector"; + RestartRequest restartRequest = new RestartRequest(connectorName, false, true); + AbstractHerder herder = partialMockBuilder(AbstractHerder.class) + .withConstructor(Worker.class, String.class, String.class, StatusBackingStore.class, ConfigBackingStore.class, + ConnectorClientConfigOverridePolicy.class) + .withArgs(worker, workerId, kafkaClusterId, statusStore, configStore, noneConnectorClientConfigOverridePolicy) + .addMockedMethod("generation") + .createMock(); + + EasyMock.expect(herder.generation()).andStubReturn(generation); + + EasyMock.expect(statusStore.get(connectorName)).andReturn(null); + replayAll(); + + Optional mayBeRestartPlan = herder.buildRestartPlan(restartRequest); + + assertFalse(mayBeRestartPlan.isPresent()); + } + + @Test + public void testBuildRestartPlanForConnectorAndTasks() { + RestartRequest restartRequest = new RestartRequest(connector, false, true); + + ConnectorTaskId taskId1 = new ConnectorTaskId(connector, 1); + ConnectorTaskId taskId2 = new ConnectorTaskId(connector, 2); + List taskStatuses = new ArrayList<>(); + taskStatuses.add(new TaskStatus(taskId1, AbstractStatus.State.RUNNING, workerId, generation)); + taskStatuses.add(new TaskStatus(taskId2, AbstractStatus.State.FAILED, workerId, generation)); + + AbstractHerder herder = partialMockBuilder(AbstractHerder.class) + .withConstructor(Worker.class, String.class, String.class, StatusBackingStore.class, ConfigBackingStore.class, + ConnectorClientConfigOverridePolicy.class) + .withArgs(worker, workerId, kafkaClusterId, statusStore, configStore, noneConnectorClientConfigOverridePolicy) + .addMockedMethod("generation") + .createMock(); + + EasyMock.expect(herder.generation()).andStubReturn(generation); + EasyMock.expect(herder.rawConfig(connector)).andReturn(null); + + EasyMock.expect(statusStore.get(connector)) + .andReturn(new ConnectorStatus(connector, AbstractStatus.State.RUNNING, workerId, generation)); + + EasyMock.expect(statusStore.getAll(connector)) + .andReturn(taskStatuses); + EasyMock.expect(worker.getPlugins()).andStubReturn(plugins); + + replayAll(); + + Optional mayBeRestartPlan = herder.buildRestartPlan(restartRequest); + + 
assertTrue(mayBeRestartPlan.isPresent()); + RestartPlan restartPlan = mayBeRestartPlan.get(); + assertTrue(restartPlan.shouldRestartConnector()); + assertTrue(restartPlan.shouldRestartTasks()); + assertEquals(2, restartPlan.taskIdsToRestart().size()); + assertTrue(restartPlan.taskIdsToRestart().contains(taskId1)); + assertTrue(restartPlan.taskIdsToRestart().contains(taskId2)); + + PowerMock.verifyAll(); + } + + @Test + public void testBuildRestartPlanForNoRestart() { + RestartRequest restartRequest = new RestartRequest(connector, true, false); + + ConnectorTaskId taskId1 = new ConnectorTaskId(connector, 1); + ConnectorTaskId taskId2 = new ConnectorTaskId(connector, 2); + List taskStatuses = new ArrayList<>(); + taskStatuses.add(new TaskStatus(taskId1, AbstractStatus.State.RUNNING, workerId, generation)); + taskStatuses.add(new TaskStatus(taskId2, AbstractStatus.State.FAILED, workerId, generation)); + + AbstractHerder herder = partialMockBuilder(AbstractHerder.class) + .withConstructor(Worker.class, String.class, String.class, StatusBackingStore.class, ConfigBackingStore.class, + ConnectorClientConfigOverridePolicy.class) + .withArgs(worker, workerId, kafkaClusterId, statusStore, configStore, noneConnectorClientConfigOverridePolicy) + .addMockedMethod("generation") + .createMock(); + + EasyMock.expect(herder.generation()).andStubReturn(generation); + EasyMock.expect(herder.rawConfig(connector)).andReturn(null); + + EasyMock.expect(statusStore.get(connector)) + .andReturn(new ConnectorStatus(connector, AbstractStatus.State.RUNNING, workerId, generation)); + + EasyMock.expect(statusStore.getAll(connector)) + .andReturn(taskStatuses); + EasyMock.expect(worker.getPlugins()).andStubReturn(plugins); + + replayAll(); + + Optional mayBeRestartPlan = herder.buildRestartPlan(restartRequest); + + assertTrue(mayBeRestartPlan.isPresent()); + RestartPlan restartPlan = mayBeRestartPlan.get(); + assertFalse(restartPlan.shouldRestartConnector()); + assertFalse(restartPlan.shouldRestartTasks()); + assertTrue(restartPlan.taskIdsToRestart().isEmpty()); + + PowerMock.verifyAll(); + } + + @Test + public void testConfigValidationEmptyConfig() { + AbstractHerder herder = createConfigValidationHerder(TestSourceConnector.class, noneConnectorClientConfigOverridePolicy, 0); + replayAll(); + + assertThrows(BadRequestException.class, () -> herder.validateConnectorConfig(Collections.emptyMap(), false)); + + verifyAll(); + } + + @Test() + public void testConfigValidationMissingName() throws Throwable { + AbstractHerder herder = createConfigValidationHerder(TestSourceConnector.class, noneConnectorClientConfigOverridePolicy); + replayAll(); + + Map config = Collections.singletonMap(ConnectorConfig.CONNECTOR_CLASS_CONFIG, TestSourceConnector.class.getName()); + ConfigInfos result = herder.validateConnectorConfig(config, false); + + // We expect there to be errors due to the missing name and .... Note that these assertions depend heavily on + // the config fields for SourceConnectorConfig, but we expect these to change rarely. 
+ assertEquals(TestSourceConnector.class.getName(), result.name()); + assertEquals(Arrays.asList(ConnectorConfig.COMMON_GROUP, ConnectorConfig.TRANSFORMS_GROUP, + ConnectorConfig.PREDICATES_GROUP, ConnectorConfig.ERROR_GROUP, SourceConnectorConfig.TOPIC_CREATION_GROUP), result.groups()); + assertEquals(2, result.errorCount()); + Map infos = result.values().stream() + .collect(Collectors.toMap(info -> info.configKey().name(), Function.identity())); + // Base connector config has 14 fields, connector's configs add 2 + assertEquals(17, infos.size()); + // Missing name should generate an error + assertEquals(ConnectorConfig.NAME_CONFIG, + infos.get(ConnectorConfig.NAME_CONFIG).configValue().name()); + assertEquals(1, infos.get(ConnectorConfig.NAME_CONFIG).configValue().errors().size()); + // "required" config from connector should generate an error + assertEquals("required", infos.get("required").configValue().name()); + assertEquals(1, infos.get("required").configValue().errors().size()); + + verifyAll(); + } + + @Test + public void testConfigValidationInvalidTopics() { + AbstractHerder herder = createConfigValidationHerder(TestSinkConnector.class, noneConnectorClientConfigOverridePolicy); + replayAll(); + + Map config = new HashMap<>(); + config.put(ConnectorConfig.CONNECTOR_CLASS_CONFIG, TestSinkConnector.class.getName()); + config.put(SinkConnectorConfig.TOPICS_CONFIG, "topic1,topic2"); + config.put(SinkConnectorConfig.TOPICS_REGEX_CONFIG, "topic.*"); + + assertThrows(ConfigException.class, () -> herder.validateConnectorConfig(config, false)); + + verifyAll(); + } + + @Test + public void testConfigValidationTopicsWithDlq() { + AbstractHerder herder = createConfigValidationHerder(TestSinkConnector.class, noneConnectorClientConfigOverridePolicy); + replayAll(); + + Map config = new HashMap<>(); + config.put(ConnectorConfig.CONNECTOR_CLASS_CONFIG, TestSinkConnector.class.getName()); + config.put(SinkConnectorConfig.TOPICS_CONFIG, "topic1"); + config.put(SinkConnectorConfig.DLQ_TOPIC_NAME_CONFIG, "topic1"); + + assertThrows(ConfigException.class, () -> herder.validateConnectorConfig(config, false)); + + verifyAll(); + } + + @Test + public void testConfigValidationTopicsRegexWithDlq() { + AbstractHerder herder = createConfigValidationHerder(TestSinkConnector.class, noneConnectorClientConfigOverridePolicy); + replayAll(); + + Map config = new HashMap<>(); + config.put(ConnectorConfig.CONNECTOR_CLASS_CONFIG, TestSinkConnector.class.getName()); + config.put(SinkConnectorConfig.TOPICS_REGEX_CONFIG, "topic.*"); + config.put(SinkConnectorConfig.DLQ_TOPIC_NAME_CONFIG, "topic1"); + + assertThrows(ConfigException.class, () -> herder.validateConnectorConfig(config, false)); + + verifyAll(); + } + + @SuppressWarnings({"rawtypes", "unchecked"}) + @Test() + public void testConfigValidationTransformsExtendResults() throws Throwable { + AbstractHerder herder = createConfigValidationHerder(TestSourceConnector.class, noneConnectorClientConfigOverridePolicy); + + // 2 transform aliases defined -> 2 plugin lookups + Set>> transformations = new HashSet<>(); + transformations.add(transformationPluginDesc()); + EasyMock.expect(plugins.transformations()).andReturn(transformations).times(2); + + replayAll(); + + // Define 2 transformations. One has a class defined and so can get embedded configs, the other is missing + // class info that should generate an error. 
+ Map config = new HashMap<>(); + config.put(ConnectorConfig.CONNECTOR_CLASS_CONFIG, TestSourceConnector.class.getName()); + config.put(ConnectorConfig.NAME_CONFIG, "connector-name"); + config.put(ConnectorConfig.TRANSFORMS_CONFIG, "xformA,xformB"); + config.put(ConnectorConfig.TRANSFORMS_CONFIG + ".xformA.type", SampleTransformation.class.getName()); + config.put("required", "value"); // connector required config + ConfigInfos result = herder.validateConnectorConfig(config, false); + assertEquals(herder.connectorTypeForClass(config.get(ConnectorConfig.CONNECTOR_CLASS_CONFIG)), ConnectorType.SOURCE); + + // We expect there to be errors due to the missing name and .... Note that these assertions depend heavily on + // the config fields for SourceConnectorConfig, but we expect these to change rarely. + assertEquals(TestSourceConnector.class.getName(), result.name()); + // Each transform also gets its own group + List expectedGroups = Arrays.asList( + ConnectorConfig.COMMON_GROUP, + ConnectorConfig.TRANSFORMS_GROUP, + ConnectorConfig.PREDICATES_GROUP, + ConnectorConfig.ERROR_GROUP, + SourceConnectorConfig.TOPIC_CREATION_GROUP, + "Transforms: xformA", + "Transforms: xformB" + ); + assertEquals(expectedGroups, result.groups()); + assertEquals(2, result.errorCount()); + Map infos = result.values().stream() + .collect(Collectors.toMap(info -> info.configKey().name(), Function.identity())); + assertEquals(22, infos.size()); + // Should get 2 type fields from the transforms, first adds its own config since it has a valid class + assertEquals("transforms.xformA.type", + infos.get("transforms.xformA.type").configValue().name()); + assertTrue(infos.get("transforms.xformA.type").configValue().errors().isEmpty()); + assertEquals("transforms.xformA.subconfig", + infos.get("transforms.xformA.subconfig").configValue().name()); + assertEquals("transforms.xformB.type", infos.get("transforms.xformB.type").configValue().name()); + assertFalse(infos.get("transforms.xformB.type").configValue().errors().isEmpty()); + + verifyAll(); + } + + @Test() + public void testConfigValidationPredicatesExtendResults() { + AbstractHerder herder = createConfigValidationHerder(TestSourceConnector.class, noneConnectorClientConfigOverridePolicy); + + // 2 transform aliases defined -> 2 plugin lookups + Set>> transformations = new HashSet<>(); + transformations.add(transformationPluginDesc()); + EasyMock.expect(plugins.transformations()).andReturn(transformations).times(1); + + Set>> predicates = new HashSet<>(); + predicates.add(predicatePluginDesc()); + EasyMock.expect(plugins.predicates()).andReturn(predicates).times(2); + + replayAll(); + + // Define 2 transformations. One has a class defined and so can get embedded configs, the other is missing + // class info that should generate an error. 
+ Map config = new HashMap<>(); + config.put(ConnectorConfig.CONNECTOR_CLASS_CONFIG, TestSourceConnector.class.getName()); + config.put(ConnectorConfig.NAME_CONFIG, "connector-name"); + config.put(ConnectorConfig.TRANSFORMS_CONFIG, "xformA"); + config.put(ConnectorConfig.TRANSFORMS_CONFIG + ".xformA.type", SampleTransformation.class.getName()); + config.put(ConnectorConfig.TRANSFORMS_CONFIG + ".xformA.predicate", "predX"); + config.put(ConnectorConfig.PREDICATES_CONFIG, "predX,predY"); + config.put(ConnectorConfig.PREDICATES_CONFIG + ".predX.type", SamplePredicate.class.getName()); + config.put("required", "value"); // connector required config + ConfigInfos result = herder.validateConnectorConfig(config, false); + assertEquals(herder.connectorTypeForClass(config.get(ConnectorConfig.CONNECTOR_CLASS_CONFIG)), ConnectorType.SOURCE); + + // We expect there to be errors due to the missing name and .... Note that these assertions depend heavily on + // the config fields for SourceConnectorConfig, but we expect these to change rarely. + assertEquals(TestSourceConnector.class.getName(), result.name()); + // Each transform also gets its own group + List expectedGroups = Arrays.asList( + ConnectorConfig.COMMON_GROUP, + ConnectorConfig.TRANSFORMS_GROUP, + ConnectorConfig.PREDICATES_GROUP, + ConnectorConfig.ERROR_GROUP, + SourceConnectorConfig.TOPIC_CREATION_GROUP, + "Transforms: xformA", + "Predicates: predX", + "Predicates: predY" + ); + assertEquals(expectedGroups, result.groups()); + assertEquals(2, result.errorCount()); + Map infos = result.values().stream() + .collect(Collectors.toMap(info -> info.configKey().name(), Function.identity())); + assertEquals(24, infos.size()); + // Should get 2 type fields from the transforms, first adds its own config since it has a valid class + assertEquals("transforms.xformA.type", + infos.get("transforms.xformA.type").configValue().name()); + assertTrue(infos.get("transforms.xformA.type").configValue().errors().isEmpty()); + assertEquals("transforms.xformA.subconfig", + infos.get("transforms.xformA.subconfig").configValue().name()); + assertEquals("transforms.xformA.predicate", + infos.get("transforms.xformA.predicate").configValue().name()); + assertTrue(infos.get("transforms.xformA.predicate").configValue().errors().isEmpty()); + assertEquals("transforms.xformA.negate", + infos.get("transforms.xformA.negate").configValue().name()); + assertTrue(infos.get("transforms.xformA.negate").configValue().errors().isEmpty()); + assertEquals("predicates.predX.type", + infos.get("predicates.predX.type").configValue().name()); + assertEquals("predicates.predX.predconfig", + infos.get("predicates.predX.predconfig").configValue().name()); + assertEquals("predicates.predY.type", + infos.get("predicates.predY.type").configValue().name()); + assertFalse( + infos.get("predicates.predY.type").configValue().errors().isEmpty()); + + verifyAll(); + } + + @SuppressWarnings({"rawtypes", "unchecked"}) + private PluginDesc> predicatePluginDesc() { + return new PluginDesc(SamplePredicate.class, "1.0", classLoader); + } + + @SuppressWarnings({"rawtypes", "unchecked"}) + private PluginDesc> transformationPluginDesc() { + return new PluginDesc(SampleTransformation.class, "1.0", classLoader); + } + + @Test() + public void testConfigValidationPrincipalOnlyOverride() throws Throwable { + AbstractHerder herder = createConfigValidationHerder(TestSourceConnector.class, new PrincipalConnectorClientConfigOverridePolicy()); + replayAll(); + + Map config = new HashMap<>(); + 
config.put(ConnectorConfig.CONNECTOR_CLASS_CONFIG, TestSourceConnector.class.getName()); + config.put(ConnectorConfig.NAME_CONFIG, "connector-name"); + config.put("required", "value"); // connector required config + String ackConfigKey = producerOverrideKey(ProducerConfig.ACKS_CONFIG); + String saslConfigKey = producerOverrideKey(SaslConfigs.SASL_JAAS_CONFIG); + config.put(ackConfigKey, "none"); + config.put(saslConfigKey, "jaas_config"); + + ConfigInfos result = herder.validateConnectorConfig(config, false); + assertEquals(herder.connectorTypeForClass(config.get(ConnectorConfig.CONNECTOR_CLASS_CONFIG)), ConnectorType.SOURCE); + + // We expect there to be errors due to now allowed override policy for ACKS.... Note that these assertions depend heavily on + // the config fields for SourceConnectorConfig, but we expect these to change rarely. + assertEquals(TestSourceConnector.class.getName(), result.name()); + // Each transform also gets its own group + List expectedGroups = Arrays.asList( + ConnectorConfig.COMMON_GROUP, + ConnectorConfig.TRANSFORMS_GROUP, + ConnectorConfig.PREDICATES_GROUP, + ConnectorConfig.ERROR_GROUP, + SourceConnectorConfig.TOPIC_CREATION_GROUP + ); + assertEquals(expectedGroups, result.groups()); + assertEquals(1, result.errorCount()); + // Base connector config has 14 fields, connector's configs add 2, and 2 producer overrides + assertEquals(19, result.values().size()); + assertTrue(result.values().stream().anyMatch( + configInfo -> ackConfigKey.equals(configInfo.configValue().name()) && !configInfo.configValue().errors().isEmpty())); + assertTrue(result.values().stream().anyMatch( + configInfo -> saslConfigKey.equals(configInfo.configValue().name()) && configInfo.configValue().errors().isEmpty())); + + verifyAll(); + } + + @Test + public void testConfigValidationAllOverride() throws Throwable { + AbstractHerder herder = createConfigValidationHerder(TestSourceConnector.class, new AllConnectorClientConfigOverridePolicy()); + replayAll(); + + Map config = new HashMap<>(); + config.put(ConnectorConfig.CONNECTOR_CLASS_CONFIG, TestSourceConnector.class.getName()); + config.put(ConnectorConfig.NAME_CONFIG, "connector-name"); + config.put("required", "value"); // connector required config + // Try to test a variety of configuration types: string, int, long, boolean, list, class + String protocolConfigKey = producerOverrideKey(CommonClientConfigs.SECURITY_PROTOCOL_CONFIG); + config.put(protocolConfigKey, "SASL_PLAINTEXT"); + String maxRequestSizeConfigKey = producerOverrideKey(ProducerConfig.MAX_REQUEST_SIZE_CONFIG); + config.put(maxRequestSizeConfigKey, "420"); + String maxBlockConfigKey = producerOverrideKey(ProducerConfig.MAX_BLOCK_MS_CONFIG); + config.put(maxBlockConfigKey, "28980"); + String idempotenceConfigKey = producerOverrideKey(ProducerConfig.ENABLE_IDEMPOTENCE_CONFIG); + config.put(idempotenceConfigKey, "true"); + String bootstrapServersConfigKey = producerOverrideKey(CommonClientConfigs.BOOTSTRAP_SERVERS_CONFIG); + config.put(bootstrapServersConfigKey, "SASL_PLAINTEXT://localhost:12345,SASL_PLAINTEXT://localhost:23456"); + String loginCallbackHandlerConfigKey = producerOverrideKey(SaslConfigs.SASL_LOGIN_CALLBACK_HANDLER_CLASS); + config.put(loginCallbackHandlerConfigKey, OAuthBearerUnsecuredLoginCallbackHandler.class.getName()); + + final Set overriddenClientConfigs = new HashSet<>(); + overriddenClientConfigs.add(protocolConfigKey); + overriddenClientConfigs.add(maxRequestSizeConfigKey); + overriddenClientConfigs.add(maxBlockConfigKey); + 
overriddenClientConfigs.add(idempotenceConfigKey); + overriddenClientConfigs.add(bootstrapServersConfigKey); + overriddenClientConfigs.add(loginCallbackHandlerConfigKey); + + ConfigInfos result = herder.validateConnectorConfig(config, false); + assertEquals(herder.connectorTypeForClass(config.get(ConnectorConfig.CONNECTOR_CLASS_CONFIG)), ConnectorType.SOURCE); + + Map validatedOverriddenClientConfigs = new HashMap<>(); + for (ConfigInfo configInfo : result.values()) { + String configName = configInfo.configKey().name(); + if (overriddenClientConfigs.contains(configName)) { + validatedOverriddenClientConfigs.put(configName, configInfo.configValue().value()); + } + } + Map rawOverriddenClientConfigs = config.entrySet().stream() + .filter(e -> overriddenClientConfigs.contains(e.getKey())) + .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); + + assertEquals(rawOverriddenClientConfigs, validatedOverriddenClientConfigs); + verifyAll(); + } + + @Test + public void testReverseTransformConfigs() { + // Construct a task config with constant values for TEST_KEY and TEST_KEY2 + Map newTaskConfig = new HashMap<>(); + newTaskConfig.put(TaskConfig.TASK_CLASS_CONFIG, BogusSourceTask.class.getName()); + newTaskConfig.put(TEST_KEY, TEST_VAL); + newTaskConfig.put(TEST_KEY2, TEST_VAL2); + List> newTaskConfigs = new ArrayList<>(); + newTaskConfigs.add(newTaskConfig); + + // The SNAPSHOT has a task config with TEST_KEY and TEST_REF + List> reverseTransformed = AbstractHerder.reverseTransform(CONN1, SNAPSHOT, newTaskConfigs); + assertEquals(TEST_REF, reverseTransformed.get(0).get(TEST_KEY)); + + // The SNAPSHOT has no task configs but does have a connector config with TEST_KEY2 and TEST_REF2 + reverseTransformed = AbstractHerder.reverseTransform(CONN1, SNAPSHOT_NO_TASKS, newTaskConfigs); + assertEquals(TEST_REF2, reverseTransformed.get(0).get(TEST_KEY2)); + + // The reverseTransformed result should not have TEST_KEY3 since newTaskConfigs does not have TEST_KEY3 + reverseTransformed = AbstractHerder.reverseTransform(CONN1, SNAPSHOT_NO_TASKS, newTaskConfigs); + assertFalse(reverseTransformed.get(0).containsKey(TEST_KEY3)); + } + + @Test + public void testConfigProviderRegex() { + testConfigProviderRegex("\"${::}\""); + testConfigProviderRegex("${::}"); + testConfigProviderRegex("\"${:/a:somevar}\""); + testConfigProviderRegex("\"${file::somevar}\""); + testConfigProviderRegex("${file:/a/b/c:}"); + testConfigProviderRegex("${file:/tmp/somefile.txt:somevar}"); + testConfigProviderRegex("\"${file:/tmp/somefile.txt:somevar}\""); + testConfigProviderRegex("plain.PlainLoginModule required username=\"${file:/tmp/somefile.txt:somevar}\""); + testConfigProviderRegex("plain.PlainLoginModule required username=${file:/tmp/somefile.txt:somevar}"); + testConfigProviderRegex("plain.PlainLoginModule required username=${file:/tmp/somefile.txt:somevar} not null"); + testConfigProviderRegex("plain.PlainLoginModule required username=${file:/tmp/somefile.txt:somevar} password=${file:/tmp/somefile.txt:othervar}"); + testConfigProviderRegex("plain.PlainLoginModule required username", false); + } + + @Test + public void testGenerateResultWithConfigValuesAllUsingConfigKeysAndWithNoErrors() { + String name = "com.acme.connector.MyConnector"; + Map keys = new HashMap<>(); + addConfigKey(keys, "config.a1", null); + addConfigKey(keys, "config.b1", "group B"); + addConfigKey(keys, "config.b2", "group B"); + addConfigKey(keys, "config.c1", "group C"); + + List groups = Arrays.asList("groupB", "group C"); + List values = new 
ArrayList<>(); + addValue(values, "config.a1", "value.a1"); + addValue(values, "config.b1", "value.b1"); + addValue(values, "config.b2", "value.b2"); + addValue(values, "config.c1", "value.c1"); + + ConfigInfos infos = AbstractHerder.generateResult(name, keys, values, groups); + assertEquals(name, infos.name()); + assertEquals(groups, infos.groups()); + assertEquals(values.size(), infos.values().size()); + assertEquals(0, infos.errorCount()); + assertInfoKey(infos, "config.a1", null); + assertInfoKey(infos, "config.b1", "group B"); + assertInfoKey(infos, "config.b2", "group B"); + assertInfoKey(infos, "config.c1", "group C"); + assertInfoValue(infos, "config.a1", "value.a1"); + assertInfoValue(infos, "config.b1", "value.b1"); + assertInfoValue(infos, "config.b2", "value.b2"); + assertInfoValue(infos, "config.c1", "value.c1"); + } + + @Test + public void testGenerateResultWithConfigValuesAllUsingConfigKeysAndWithSomeErrors() { + String name = "com.acme.connector.MyConnector"; + Map keys = new HashMap<>(); + addConfigKey(keys, "config.a1", null); + addConfigKey(keys, "config.b1", "group B"); + addConfigKey(keys, "config.b2", "group B"); + addConfigKey(keys, "config.c1", "group C"); + + List groups = Arrays.asList("groupB", "group C"); + List values = new ArrayList<>(); + addValue(values, "config.a1", "value.a1"); + addValue(values, "config.b1", "value.b1"); + addValue(values, "config.b2", "value.b2"); + addValue(values, "config.c1", "value.c1", "error c1"); + + ConfigInfos infos = AbstractHerder.generateResult(name, keys, values, groups); + assertEquals(name, infos.name()); + assertEquals(groups, infos.groups()); + assertEquals(values.size(), infos.values().size()); + assertEquals(1, infos.errorCount()); + assertInfoKey(infos, "config.a1", null); + assertInfoKey(infos, "config.b1", "group B"); + assertInfoKey(infos, "config.b2", "group B"); + assertInfoKey(infos, "config.c1", "group C"); + assertInfoValue(infos, "config.a1", "value.a1"); + assertInfoValue(infos, "config.b1", "value.b1"); + assertInfoValue(infos, "config.b2", "value.b2"); + assertInfoValue(infos, "config.c1", "value.c1", "error c1"); + } + + @Test + public void testGenerateResultWithConfigValuesMoreThanConfigKeysAndWithSomeErrors() { + String name = "com.acme.connector.MyConnector"; + Map keys = new HashMap<>(); + addConfigKey(keys, "config.a1", null); + addConfigKey(keys, "config.b1", "group B"); + addConfigKey(keys, "config.b2", "group B"); + addConfigKey(keys, "config.c1", "group C"); + + List groups = Arrays.asList("groupB", "group C"); + List values = new ArrayList<>(); + addValue(values, "config.a1", "value.a1"); + addValue(values, "config.b1", "value.b1"); + addValue(values, "config.b2", "value.b2"); + addValue(values, "config.c1", "value.c1", "error c1"); + addValue(values, "config.extra1", "value.extra1"); + addValue(values, "config.extra2", "value.extra2", "error extra2"); + + ConfigInfos infos = AbstractHerder.generateResult(name, keys, values, groups); + assertEquals(name, infos.name()); + assertEquals(groups, infos.groups()); + assertEquals(values.size(), infos.values().size()); + assertEquals(2, infos.errorCount()); + assertInfoKey(infos, "config.a1", null); + assertInfoKey(infos, "config.b1", "group B"); + assertInfoKey(infos, "config.b2", "group B"); + assertInfoKey(infos, "config.c1", "group C"); + assertNoInfoKey(infos, "config.extra1"); + assertNoInfoKey(infos, "config.extra2"); + assertInfoValue(infos, "config.a1", "value.a1"); + assertInfoValue(infos, "config.b1", "value.b1"); + assertInfoValue(infos, 
"config.b2", "value.b2"); + assertInfoValue(infos, "config.c1", "value.c1", "error c1"); + assertInfoValue(infos, "config.extra1", "value.extra1"); + assertInfoValue(infos, "config.extra2", "value.extra2", "error extra2"); + } + + @Test + public void testGenerateResultWithConfigValuesWithNoConfigKeysAndWithSomeErrors() { + String name = "com.acme.connector.MyConnector"; + Map keys = new HashMap<>(); + + List groups = new ArrayList<>(); + List values = new ArrayList<>(); + addValue(values, "config.a1", "value.a1"); + addValue(values, "config.b1", "value.b1"); + addValue(values, "config.b2", "value.b2"); + addValue(values, "config.c1", "value.c1", "error c1"); + addValue(values, "config.extra1", "value.extra1"); + addValue(values, "config.extra2", "value.extra2", "error extra2"); + + ConfigInfos infos = AbstractHerder.generateResult(name, keys, values, groups); + assertEquals(name, infos.name()); + assertEquals(groups, infos.groups()); + assertEquals(values.size(), infos.values().size()); + assertEquals(2, infos.errorCount()); + assertNoInfoKey(infos, "config.a1"); + assertNoInfoKey(infos, "config.b1"); + assertNoInfoKey(infos, "config.b2"); + assertNoInfoKey(infos, "config.c1"); + assertNoInfoKey(infos, "config.extra1"); + assertNoInfoKey(infos, "config.extra2"); + assertInfoValue(infos, "config.a1", "value.a1"); + assertInfoValue(infos, "config.b1", "value.b1"); + assertInfoValue(infos, "config.b2", "value.b2"); + assertInfoValue(infos, "config.c1", "value.c1", "error c1"); + assertInfoValue(infos, "config.extra1", "value.extra1"); + assertInfoValue(infos, "config.extra2", "value.extra2", "error extra2"); + } + + protected void addConfigKey(Map keys, String name, String group) { + keys.put(name, new ConfigDef.ConfigKey(name, ConfigDef.Type.STRING, null, null, + ConfigDef.Importance.HIGH, "doc", group, 10, + ConfigDef.Width.MEDIUM, "display name", Collections.emptyList(), null, false)); + } + + protected void addValue(List values, String name, String value, String...errors) { + values.add(new ConfigValue(name, value, new ArrayList<>(), Arrays.asList(errors))); + } + + protected void assertInfoKey(ConfigInfos infos, String name, String group) { + ConfigInfo info = findInfo(infos, name); + assertEquals(name, info.configKey().name()); + assertEquals(group, info.configKey().group()); + } + + protected void assertNoInfoKey(ConfigInfos infos, String name) { + ConfigInfo info = findInfo(infos, name); + assertNull(info.configKey()); + } + + protected void assertInfoValue(ConfigInfos infos, String name, String value, String...errors) { + ConfigValueInfo info = findInfo(infos, name).configValue(); + assertEquals(name, info.name()); + assertEquals(value, info.value()); + assertEquals(Arrays.asList(errors), info.errors()); + } + + protected ConfigInfo findInfo(ConfigInfos infos, String name) { + return infos.values() + .stream() + .filter(i -> i.configValue().name().equals(name)) + .findFirst() + .orElse(null); + } + + private void testConfigProviderRegex(String rawConnConfig) { + testConfigProviderRegex(rawConnConfig, true); + } + + private void testConfigProviderRegex(String rawConnConfig, boolean expected) { + Set keys = keysWithVariableValues(Collections.singletonMap("key", rawConnConfig), ConfigTransformer.DEFAULT_PATTERN); + boolean actual = keys != null && !keys.isEmpty() && keys.contains("key"); + assertEquals(String.format("%s should have matched regex", rawConnConfig), expected, actual); + } + + private AbstractHerder createConfigValidationHerder(Class connectorClass, + 
ConnectorClientConfigOverridePolicy connectorClientConfigOverridePolicy) { + return createConfigValidationHerder(connectorClass, connectorClientConfigOverridePolicy, 1); + } + + private AbstractHerder createConfigValidationHerder(Class connectorClass, + ConnectorClientConfigOverridePolicy connectorClientConfigOverridePolicy, + int countOfCallingNewConnector) { + + + ConfigBackingStore configStore = strictMock(ConfigBackingStore.class); + StatusBackingStore statusStore = strictMock(StatusBackingStore.class); + + AbstractHerder herder = partialMockBuilder(AbstractHerder.class) + .withConstructor(Worker.class, String.class, String.class, StatusBackingStore.class, ConfigBackingStore.class, + ConnectorClientConfigOverridePolicy.class) + .withArgs(worker, workerId, kafkaClusterId, statusStore, configStore, connectorClientConfigOverridePolicy) + .addMockedMethod("generation") + .createMock(); + EasyMock.expect(herder.generation()).andStubReturn(generation); + + // Call to validateConnectorConfig + EasyMock.expect(worker.configTransformer()).andReturn(transformer).times(2); + final Capture> configCapture = EasyMock.newCapture(); + EasyMock.expect(transformer.transform(EasyMock.capture(configCapture))).andAnswer(configCapture::getValue); + EasyMock.expect(worker.getPlugins()).andStubReturn(plugins); + final Connector connector; + try { + connector = connectorClass.getConstructor().newInstance(); + } catch (ReflectiveOperationException e) { + throw new RuntimeException("Couldn't create connector", e); + } + if (countOfCallingNewConnector > 0) { + EasyMock.expect(plugins.newConnector(connectorClass.getName())).andReturn(connector).times(countOfCallingNewConnector); + EasyMock.expect(plugins.compareAndSwapLoaders(connector)).andReturn(classLoader).times(countOfCallingNewConnector); + } + + return herder; + } + + public static class SampleTransformation> implements Transformation { + @Override + public void configure(Map configs) { + + } + + @Override + public R apply(R record) { + return record; + } + + @Override + public ConfigDef config() { + return new ConfigDef() + .define("subconfig", ConfigDef.Type.STRING, "default", ConfigDef.Importance.LOW, "docs"); + } + + @Override + public void close() { + + } + } + + public static class SamplePredicate> implements Predicate { + + @Override + public ConfigDef config() { + return new ConfigDef() + .define("predconfig", ConfigDef.Type.STRING, "default", ConfigDef.Importance.LOW, "docs"); + } + + @Override + public boolean test(R record) { + return false; + } + + @Override + public void close() { + + } + + @Override + public void configure(Map configs) { + + } + } + + // We need to use a real class here due to some issue with mocking java.lang.Class + private abstract class BogusSourceConnector extends SourceConnector { + } + + private abstract class BogusSourceTask extends SourceTask { + } + + private static String producerOverrideKey(String config) { + return ConnectorConfig.CONNECTOR_CLIENT_PRODUCER_OVERRIDES_PREFIX + config; + } +} diff --git a/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/ConnectMetricsTest.java b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/ConnectMetricsTest.java new file mode 100644 index 0000000..98bf8e0 --- /dev/null +++ b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/ConnectMetricsTest.java @@ -0,0 +1,168 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.runtime; + +import org.apache.kafka.common.MetricName; +import org.apache.kafka.common.metrics.Sensor; +import org.apache.kafka.common.metrics.stats.Avg; +import org.apache.kafka.common.metrics.stats.Max; +import org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.connect.runtime.ConnectMetrics.MetricGroup; +import org.apache.kafka.connect.runtime.ConnectMetrics.MetricGroupId; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNotSame; +import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertThrows; + +@SuppressWarnings("deprecation") +public class ConnectMetricsTest { + + private static final Map DEFAULT_WORKER_CONFIG = new HashMap<>(); + + static { + DEFAULT_WORKER_CONFIG.put(WorkerConfig.KEY_CONVERTER_CLASS_CONFIG, "org.apache.kafka.connect.json.JsonConverter"); + DEFAULT_WORKER_CONFIG.put(WorkerConfig.VALUE_CONVERTER_CLASS_CONFIG, "org.apache.kafka.connect.json.JsonConverter"); + } + + private ConnectMetrics metrics; + + @Before + public void setUp() { + metrics = new ConnectMetrics("worker1", new WorkerConfig(WorkerConfig.baseConfigDef(), DEFAULT_WORKER_CONFIG), new MockTime(), "cluster-1"); + } + + @After + public void tearDown() { + if (metrics != null) + metrics.stop(); + } + + @Test + public void testKafkaMetricsNotNull() { + assertNotNull(metrics.metrics()); + } + + @Test + public void testGettingGroupWithOddNumberOfTags() { + assertThrows(IllegalArgumentException.class, + () -> metrics.group("name", "k1", "v1", "k2", "v2", "extra")); + } + + @Test + public void testGettingGroupWithTags() { + MetricGroup group1 = metrics.group("name", "k1", "v1", "k2", "v2"); + assertEquals("v1", group1.tags().get("k1")); + assertEquals("v2", group1.tags().get("k2")); + assertEquals(2, group1.tags().size()); + } + + @Test + public void testGettingGroupMultipleTimes() { + MetricGroup group1 = metrics.group("name"); + MetricGroup group2 = metrics.group("name"); + assertNotNull(group1); + assertSame(group1, group2); + MetricGroup group3 = metrics.group("other"); + assertNotNull(group3); + assertNotSame(group1, group3); + + // Now with tags + MetricGroup group4 = metrics.group("name", "k1", "v1"); + assertNotNull(group4); + assertNotSame(group1, group4); + assertNotSame(group2, group4); + assertNotSame(group3, group4); + MetricGroup group5 = metrics.group("name", "k1", "v1"); + assertSame(group4, group5); + } + + @Test + public void testMetricGroupIdIdentity() { + MetricGroupId id1 = metrics.groupId("name", "k1", "v1"); + MetricGroupId id2 = 
metrics.groupId("name", "k1", "v1"); + MetricGroupId id3 = metrics.groupId("name", "k1", "v1", "k2", "v2"); + + assertEquals(id1.hashCode(), id2.hashCode()); + assertEquals(id1, id2); + assertEquals(id1.toString(), id2.toString()); + assertEquals(id1.groupName(), id2.groupName()); + assertEquals(id1.tags(), id2.tags()); + assertNotNull(id1.tags()); + + assertNotEquals(id1, id3); + } + + @Test + public void testMetricGroupIdWithoutTags() { + MetricGroupId id1 = metrics.groupId("name"); + MetricGroupId id2 = metrics.groupId("name"); + + assertEquals(id1.hashCode(), id2.hashCode()); + assertEquals(id1, id2); + assertEquals(id1.toString(), id2.toString()); + assertEquals(id1.groupName(), id2.groupName()); + assertEquals(id1.tags(), id2.tags()); + assertNotNull(id1.tags()); + assertNotNull(id2.tags()); + } + + @Test + public void testRecreateWithClose() { + final Sensor originalSensor = addToGroup(metrics, false); + final Sensor recreatedSensor = addToGroup(metrics, true); + // because we closed the metricGroup, we get a brand-new sensor + assertNotSame(originalSensor, recreatedSensor); + } + + @Test + public void testRecreateWithoutClose() { + final Sensor originalSensor = addToGroup(metrics, false); + final Sensor recreatedSensor = addToGroup(metrics, false); + // since we didn't close the group, the second addToGroup is idempotent + assertSame(originalSensor, recreatedSensor); + } + + private Sensor addToGroup(ConnectMetrics connectMetrics, boolean shouldClose) { + ConnectMetricsRegistry registry = connectMetrics.registry(); + ConnectMetrics.MetricGroup metricGroup = connectMetrics.group(registry.taskGroupName(), + registry.connectorTagName(), "conn_name"); + + if (shouldClose) { + metricGroup.close(); + } + + Sensor sensor = metricGroup.sensor("my_sensor"); + sensor.add(metricName("x1"), new Max()); + sensor.add(metricName("y2"), new Avg()); + + return sensor; + } + + static MetricName metricName(String name) { + return new MetricName(name, "test_group", "metrics for testing", Collections.emptyMap()); + } +} diff --git a/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/ConnectorConfigTest.java b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/ConnectorConfigTest.java new file mode 100644 index 0000000..4abdbea --- /dev/null +++ b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/ConnectorConfigTest.java @@ -0,0 +1,487 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.connect.runtime; + +import org.apache.kafka.common.config.ConfigDef; +import org.apache.kafka.common.config.ConfigException; +import org.apache.kafka.connect.connector.ConnectRecord; +import org.apache.kafka.connect.connector.Connector; +import org.apache.kafka.connect.runtime.isolation.PluginDesc; +import org.apache.kafka.connect.runtime.isolation.Plugins; +import org.apache.kafka.connect.transforms.Transformation; +import org.apache.kafka.connect.transforms.predicates.Predicate; +import org.junit.Test; + +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +public class ConnectorConfigTest> { + + public static final Plugins MOCK_PLUGINS = new Plugins(new HashMap<>()) { + @Override + public Set>> transformations() { + return Collections.emptySet(); + } + }; + + public static abstract class TestConnector extends Connector { + } + + public static class SimpleTransformation> implements Transformation { + + int magicNumber = 0; + + @Override + public void configure(Map props) { + magicNumber = Integer.parseInt((String) props.get("magic.number")); + } + + @Override + public R apply(R record) { + return null; + } + + @Override + public void close() { + magicNumber = 0; + } + + @Override + public ConfigDef config() { + return new ConfigDef() + .define("magic.number", ConfigDef.Type.INT, ConfigDef.NO_DEFAULT_VALUE, ConfigDef.Range.atLeast(42), ConfigDef.Importance.HIGH, ""); + } + } + + @Test + public void noTransforms() { + Map props = new HashMap<>(); + props.put("name", "test"); + props.put("connector.class", TestConnector.class.getName()); + new ConnectorConfig(MOCK_PLUGINS, props); + } + + @Test + public void danglingTransformAlias() { + Map props = new HashMap<>(); + props.put("name", "test"); + props.put("connector.class", TestConnector.class.getName()); + props.put("transforms", "dangler"); + ConfigException e = assertThrows(ConfigException.class, () -> new ConnectorConfig(MOCK_PLUGINS, props)); + assertTrue(e.getMessage().contains("Not a Transformation")); + } + + @Test + public void emptyConnectorName() { + Map props = new HashMap<>(); + props.put("name", ""); + props.put("connector.class", TestConnector.class.getName()); + ConfigException e = assertThrows(ConfigException.class, () -> new ConnectorConfig(MOCK_PLUGINS, props)); + assertTrue(e.getMessage().contains("String may not be empty")); + } + + @Test + public void wrongTransformationType() { + Map props = new HashMap<>(); + props.put("name", "test"); + props.put("connector.class", TestConnector.class.getName()); + props.put("transforms", "a"); + props.put("transforms.a.type", "uninstantiable"); + ConfigException e = assertThrows(ConfigException.class, () -> new ConnectorConfig(MOCK_PLUGINS, props)); + assertTrue(e.getMessage().contains("Class uninstantiable could not be found")); + } + + @Test + public void unconfiguredTransform() { + Map props = new HashMap<>(); + props.put("name", "test"); + props.put("connector.class", TestConnector.class.getName()); + props.put("transforms", "a"); + props.put("transforms.a.type", SimpleTransformation.class.getName()); + ConfigException e = assertThrows(ConfigException.class, () -> new ConnectorConfig(MOCK_PLUGINS, props)); + 
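+ // SimpleTransformation declares magic.number with ConfigDef.NO_DEFAULT_VALUE, so leaving
+ // transforms.a.magic.number unset is expected to surface as a missing-required-config error.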
assertTrue(e.getMessage().contains("Missing required configuration \"transforms.a.magic.number\" which")); + } + + @Test + public void misconfiguredTransform() { + Map props = new HashMap<>(); + props.put("name", "test"); + props.put("connector.class", TestConnector.class.getName()); + props.put("transforms", "a"); + props.put("transforms.a.type", SimpleTransformation.class.getName()); + props.put("transforms.a.magic.number", "40"); + ConfigException e = assertThrows(ConfigException.class, () -> new ConnectorConfig(MOCK_PLUGINS, props)); + assertTrue(e.getMessage().contains("Value must be at least 42")); + } + + @Test + public void singleTransform() { + Map props = new HashMap<>(); + props.put("name", "test"); + props.put("connector.class", TestConnector.class.getName()); + props.put("transforms", "a"); + props.put("transforms.a.type", SimpleTransformation.class.getName()); + props.put("transforms.a.magic.number", "42"); + final ConnectorConfig config = new ConnectorConfig(MOCK_PLUGINS, props); + final List> transformations = config.transformations(); + assertEquals(1, transformations.size()); + final SimpleTransformation xform = (SimpleTransformation) transformations.get(0); + assertEquals(42, xform.magicNumber); + } + + @Test + public void multipleTransformsOneDangling() { + Map props = new HashMap<>(); + props.put("name", "test"); + props.put("connector.class", TestConnector.class.getName()); + props.put("transforms", "a, b"); + props.put("transforms.a.type", SimpleTransformation.class.getName()); + props.put("transforms.a.magic.number", "42"); + assertThrows(ConfigException.class, () -> new ConnectorConfig(MOCK_PLUGINS, props)); + } + + @Test + public void multipleTransforms() { + Map props = new HashMap<>(); + props.put("name", "test"); + props.put("connector.class", TestConnector.class.getName()); + props.put("transforms", "a, b"); + props.put("transforms.a.type", SimpleTransformation.class.getName()); + props.put("transforms.a.magic.number", "42"); + props.put("transforms.b.type", SimpleTransformation.class.getName()); + props.put("transforms.b.magic.number", "84"); + final ConnectorConfig config = new ConnectorConfig(MOCK_PLUGINS, props); + final List> transformations = config.transformations(); + assertEquals(2, transformations.size()); + assertEquals(42, ((SimpleTransformation) transformations.get(0)).magicNumber); + assertEquals(84, ((SimpleTransformation) transformations.get(1)).magicNumber); + } + + @Test + public void abstractTransform() { + Map props = new HashMap<>(); + props.put("name", "test"); + props.put("connector.class", TestConnector.class.getName()); + props.put("transforms", "a"); + props.put("transforms.a.type", AbstractTransformation.class.getName()); + try { + new ConnectorConfig(MOCK_PLUGINS, props); + } catch (ConfigException ex) { + assertTrue( + ex.getMessage().contains("Transformation is abstract and cannot be created.") + ); + } + } + @Test + public void abstractKeyValueTransform() { + Map props = new HashMap<>(); + props.put("name", "test"); + props.put("connector.class", TestConnector.class.getName()); + props.put("transforms", "a"); + props.put("transforms.a.type", AbstractKeyValueTransformation.class.getName()); + try { + new ConnectorConfig(MOCK_PLUGINS, props); + } catch (ConfigException ex) { + assertTrue( + ex.getMessage().contains("Transformation is abstract and cannot be created.") + ); + assertTrue( + ex.getMessage().contains(AbstractKeyValueTransformation.Key.class.getName()) + ); + assertTrue( + 
ex.getMessage().contains(AbstractKeyValueTransformation.Value.class.getName()) + ); + } + } + + @Test + public void wrongPredicateType() { + Map props = new HashMap<>(); + props.put("name", "test"); + props.put("connector.class", TestConnector.class.getName()); + props.put("transforms", "a"); + props.put("transforms.a.type", SimpleTransformation.class.getName()); + props.put("transforms.a.magic.number", "42"); + props.put("transforms.a.predicate", "my-pred"); + props.put("predicates", "my-pred"); + props.put("predicates.my-pred.type", TestConnector.class.getName()); + ConfigException e = assertThrows(ConfigException.class, () -> new ConnectorConfig(MOCK_PLUGINS, props)); + assertTrue(e.getMessage().contains("Not a Predicate")); + } + + @Test + public void singleConditionalTransform() { + Map props = new HashMap<>(); + props.put("name", "test"); + props.put("connector.class", TestConnector.class.getName()); + props.put("transforms", "a"); + props.put("transforms.a.type", SimpleTransformation.class.getName()); + props.put("transforms.a.magic.number", "42"); + props.put("transforms.a.predicate", "my-pred"); + props.put("transforms.a.negate", "true"); + props.put("predicates", "my-pred"); + props.put("predicates.my-pred.type", TestPredicate.class.getName()); + props.put("predicates.my-pred.int", "84"); + assertPredicatedTransform(props, true); + } + + @Test + public void predicateNegationDefaultsToFalse() { + Map props = new HashMap<>(); + props.put("name", "test"); + props.put("connector.class", TestConnector.class.getName()); + props.put("transforms", "a"); + props.put("transforms.a.type", SimpleTransformation.class.getName()); + props.put("transforms.a.magic.number", "42"); + props.put("transforms.a.predicate", "my-pred"); + props.put("predicates", "my-pred"); + props.put("predicates.my-pred.type", TestPredicate.class.getName()); + props.put("predicates.my-pred.int", "84"); + assertPredicatedTransform(props, false); + } + + @Test + public void abstractPredicate() { + Map props = new HashMap<>(); + props.put("name", "test"); + props.put("connector.class", TestConnector.class.getName()); + props.put("transforms", "a"); + props.put("transforms.a.type", SimpleTransformation.class.getName()); + props.put("transforms.a.magic.number", "42"); + props.put("transforms.a.predicate", "my-pred"); + props.put("predicates", "my-pred"); + props.put("predicates.my-pred.type", AbstractTestPredicate.class.getName()); + props.put("predicates.my-pred.int", "84"); + ConfigException e = assertThrows(ConfigException.class, () -> new ConnectorConfig(MOCK_PLUGINS, props)); + assertTrue(e.getMessage().contains("Predicate is abstract and cannot be created")); + } + + private void assertPredicatedTransform(Map props, boolean expectedNegated) { + final ConnectorConfig config = new ConnectorConfig(MOCK_PLUGINS, props); + final List> transformations = config.transformations(); + assertEquals(1, transformations.size()); + assertTrue(transformations.get(0) instanceof PredicatedTransformation); + PredicatedTransformation predicated = (PredicatedTransformation) transformations.get(0); + + assertEquals(expectedNegated, predicated.negate); + + assertTrue(predicated.delegate instanceof ConnectorConfigTest.SimpleTransformation); + assertEquals(42, ((SimpleTransformation) predicated.delegate).magicNumber); + + assertTrue(predicated.predicate instanceof ConnectorConfigTest.TestPredicate); + assertEquals(84, ((TestPredicate) predicated.predicate).param); + + predicated.close(); + + assertEquals(0, ((SimpleTransformation) 
predicated.delegate).magicNumber); + assertEquals(0, ((TestPredicate) predicated.predicate).param); + } + + @Test + public void misconfiguredPredicate() { + Map props = new HashMap<>(); + props.put("name", "test"); + props.put("connector.class", TestConnector.class.getName()); + props.put("transforms", "a"); + props.put("transforms.a.type", SimpleTransformation.class.getName()); + props.put("transforms.a.magic.number", "42"); + props.put("transforms.a.predicate", "my-pred"); + props.put("transforms.a.negate", "true"); + props.put("predicates", "my-pred"); + props.put("predicates.my-pred.type", TestPredicate.class.getName()); + props.put("predicates.my-pred.int", "79"); + try { + new ConnectorConfig(MOCK_PLUGINS, props); + fail(); + } catch (ConfigException e) { + assertTrue(e.getMessage().contains("Value must be at least 80")); + } + } + + @Test + public void missingPredicateAliasProperty() { + Map props = new HashMap<>(); + props.put("name", "test"); + props.put("connector.class", TestConnector.class.getName()); + props.put("transforms", "a"); + props.put("transforms.a.type", SimpleTransformation.class.getName()); + props.put("transforms.a.magic.number", "42"); + props.put("transforms.a.predicate", "my-pred"); + // technically not needed + //props.put("predicates", "my-pred"); + props.put("predicates.my-pred.type", TestPredicate.class.getName()); + props.put("predicates.my-pred.int", "84"); + new ConnectorConfig(MOCK_PLUGINS, props); + } + + @Test + public void missingPredicateConfig() { + Map props = new HashMap<>(); + props.put("name", "test"); + props.put("connector.class", TestConnector.class.getName()); + props.put("transforms", "a"); + props.put("transforms.a.type", SimpleTransformation.class.getName()); + props.put("transforms.a.magic.number", "42"); + props.put("transforms.a.predicate", "my-pred"); + props.put("predicates", "my-pred"); + //props.put("predicates.my-pred.type", TestPredicate.class.getName()); + //props.put("predicates.my-pred.int", "84"); + ConfigException e = assertThrows(ConfigException.class, () -> new ConnectorConfig(MOCK_PLUGINS, props)); + assertTrue(e.getMessage().contains("Not a Predicate")); + } + + @Test + public void negatedButNoPredicate() { + Map props = new HashMap<>(); + props.put("name", "test"); + props.put("connector.class", TestConnector.class.getName()); + props.put("transforms", "a"); + props.put("transforms.a.type", SimpleTransformation.class.getName()); + props.put("transforms.a.magic.number", "42"); + props.put("transforms.a.negate", "true"); + ConfigException e = assertThrows(ConfigException.class, () -> new ConnectorConfig(MOCK_PLUGINS, props)); + assertTrue(e.getMessage().contains("there is no config 'transforms.a.predicate' defining a predicate to be negated")); + } + + public static class TestPredicate> implements Predicate { + + int param; + + public TestPredicate() { } + + @Override + public ConfigDef config() { + return new ConfigDef().define("int", ConfigDef.Type.INT, 80, ConfigDef.Range.atLeast(80), ConfigDef.Importance.MEDIUM, + "A test parameter"); + } + + @Override + public boolean test(R record) { + return false; + } + + @Override + public void close() { + param = 0; + } + + @Override + public void configure(Map configs) { + param = Integer.parseInt((String) configs.get("int")); + } + } + + public static abstract class AbstractTestPredicate> implements Predicate { + + public AbstractTestPredicate() { } + + } + + public static abstract class AbstractTransformation> implements Transformation { + + } + + public static abstract 
class AbstractKeyValueTransformation> implements Transformation { + @Override + public R apply(R record) { + return null; + } + + @Override + public ConfigDef config() { + return new ConfigDef(); + } + + @Override + public void close() { + + } + + @Override + public void configure(Map configs) { + + } + + + public static class Key> extends AbstractKeyValueTransformation { + + + } + public static class Value> extends AbstractKeyValueTransformation { + + } + } + + @Test + public void testEnrichedConfigDef() { + String alias = "hdt"; + String prefix = ConnectorConfig.TRANSFORMS_CONFIG + "." + alias + "."; + Map props = new HashMap<>(); + props.put(ConnectorConfig.TRANSFORMS_CONFIG, alias); + props.put(prefix + "type", HasDuplicateConfigTransformation.class.getName()); + ConfigDef def = ConnectorConfig.enrich(MOCK_PLUGINS, new ConfigDef(), props, false); + assertEnrichedConfigDef(def, prefix, HasDuplicateConfigTransformation.MUST_EXIST_KEY, ConfigDef.Type.BOOLEAN); + assertEnrichedConfigDef(def, prefix, PredicatedTransformation.PREDICATE_CONFIG, ConfigDef.Type.STRING); + assertEnrichedConfigDef(def, prefix, PredicatedTransformation.NEGATE_CONFIG, ConfigDef.Type.BOOLEAN); + } + + private static void assertEnrichedConfigDef(ConfigDef def, String prefix, String keyName, ConfigDef.Type expectedType) { + assertNull(def.configKeys().get(keyName)); + ConfigDef.ConfigKey configKey = def.configKeys().get(prefix + keyName); + assertNotNull(prefix + keyName + "' config must be present", configKey); + assertEquals(prefix + keyName + "' config should be a " + expectedType, expectedType, configKey.type); + } + + public static class HasDuplicateConfigTransformation> implements Transformation { + private static final String MUST_EXIST_KEY = "must.exist.key"; + private static final ConfigDef CONFIG_DEF = new ConfigDef() + // this configDef is duplicate. It should be removed automatically so as to avoid duplicate config error. + .define(PredicatedTransformation.PREDICATE_CONFIG, ConfigDef.Type.INT, ConfigDef.NO_DEFAULT_VALUE, ConfigDef.Importance.MEDIUM, "fake") + // this configDef is duplicate. It should be removed automatically so as to avoid duplicate config error. + .define(PredicatedTransformation.NEGATE_CONFIG, ConfigDef.Type.INT, 123, ConfigDef.Importance.MEDIUM, "fake") + // this configDef should appear if above duplicate configDef is removed without any error + .define(MUST_EXIST_KEY, ConfigDef.Type.BOOLEAN, true, ConfigDef.Importance.MEDIUM, "this key must exist"); + + @Override + public R apply(R record) { + return record; + } + + @Override + public ConfigDef config() { + return CONFIG_DEF; + } + + @Override + public void close() { + } + + @Override + public void configure(Map configs) { + } + } +} diff --git a/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/ErrorHandlingTaskTest.java b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/ErrorHandlingTaskTest.java new file mode 100644 index 0000000..4743dce --- /dev/null +++ b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/ErrorHandlingTaskTest.java @@ -0,0 +1,668 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.runtime; + +import org.apache.kafka.clients.admin.NewTopic; +import org.apache.kafka.clients.consumer.ConsumerRebalanceListener; +import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.clients.consumer.ConsumerRecords; +import org.apache.kafka.clients.consumer.KafkaConsumer; +import org.apache.kafka.clients.producer.KafkaProducer; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.config.ConfigDef; +import org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.connect.connector.ConnectRecord; +import org.apache.kafka.connect.data.Schema; +import org.apache.kafka.connect.data.SchemaBuilder; +import org.apache.kafka.connect.data.Struct; +import org.apache.kafka.connect.errors.RetriableException; +import org.apache.kafka.connect.integration.MonitorableSourceConnector; +import org.apache.kafka.connect.json.JsonConverter; +import org.apache.kafka.connect.runtime.distributed.ClusterConfigState; +import org.apache.kafka.connect.runtime.errors.ErrorHandlingMetrics; +import org.apache.kafka.connect.runtime.errors.ErrorReporter; +import org.apache.kafka.connect.runtime.errors.LogReporter; +import org.apache.kafka.connect.runtime.errors.RetryWithToleranceOperator; +import org.apache.kafka.connect.runtime.errors.ToleranceType; +import org.apache.kafka.connect.runtime.errors.WorkerErrantRecordReporter; +import org.apache.kafka.connect.runtime.isolation.PluginClassLoader; +import org.apache.kafka.connect.runtime.isolation.Plugins; +import org.apache.kafka.connect.runtime.standalone.StandaloneConfig; +import org.apache.kafka.connect.sink.SinkConnector; +import org.apache.kafka.connect.sink.SinkRecord; +import org.apache.kafka.connect.sink.SinkTask; +import org.apache.kafka.connect.source.SourceRecord; +import org.apache.kafka.connect.source.SourceTask; +import org.apache.kafka.connect.storage.Converter; +import org.apache.kafka.connect.storage.HeaderConverter; +import org.apache.kafka.connect.storage.OffsetStorageReaderImpl; +import org.apache.kafka.connect.storage.OffsetStorageWriter; +import org.apache.kafka.connect.storage.StatusBackingStore; +import org.apache.kafka.connect.storage.StringConverter; +import org.apache.kafka.connect.transforms.Transformation; +import org.apache.kafka.connect.transforms.util.SimpleConfig; +import org.apache.kafka.connect.util.ConnectorTaskId; +import org.apache.kafka.connect.util.ParameterizedTest; +import org.apache.kafka.connect.util.TopicAdmin; +import org.apache.kafka.connect.util.TopicCreationGroup; +import org.easymock.Capture; +import org.easymock.EasyMock; +import org.easymock.IExpectationSetters; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.powermock.api.easymock.PowerMock; +import org.powermock.api.easymock.annotation.Mock; +import org.powermock.core.classloader.annotations.PowerMockIgnore; +import org.powermock.core.classloader.annotations.PrepareForTest; +import org.powermock.modules.junit4.PowerMockRunner; +import 
org.powermock.modules.junit4.PowerMockRunnerDelegate; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.time.Duration; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.Executor; + +import static java.util.Collections.emptyMap; +import static java.util.Collections.singletonList; +import static org.apache.kafka.common.utils.Time.SYSTEM; +import static org.apache.kafka.connect.integration.MonitorableSourceConnector.TOPIC_CONFIG; +import static org.apache.kafka.connect.runtime.ConnectorConfig.CONNECTOR_CLASS_CONFIG; +import static org.apache.kafka.connect.runtime.ConnectorConfig.KEY_CONVERTER_CLASS_CONFIG; +import static org.apache.kafka.connect.runtime.ConnectorConfig.TASKS_MAX_CONFIG; +import static org.apache.kafka.connect.runtime.ConnectorConfig.VALUE_CONVERTER_CLASS_CONFIG; +import static org.apache.kafka.connect.runtime.SourceConnectorConfig.TOPIC_CREATION_GROUPS_CONFIG; +import static org.apache.kafka.connect.runtime.TopicCreationConfig.DEFAULT_TOPIC_CREATION_PREFIX; +import static org.apache.kafka.connect.runtime.TopicCreationConfig.INCLUDE_REGEX_CONFIG; +import static org.apache.kafka.connect.runtime.TopicCreationConfig.PARTITIONS_CONFIG; +import static org.apache.kafka.connect.runtime.TopicCreationConfig.REPLICATION_FACTOR_CONFIG; +import static org.apache.kafka.connect.runtime.WorkerConfig.TOPIC_CREATION_ENABLE_CONFIG; +import static org.junit.Assert.assertEquals; + +@RunWith(PowerMockRunner.class) +@PowerMockRunnerDelegate(ParameterizedTest.class) +@PrepareForTest({WorkerSinkTask.class, WorkerSourceTask.class}) +@PowerMockIgnore("javax.management.*") +public class ErrorHandlingTaskTest { + + private static final String TOPIC = "test"; + private static final int PARTITION1 = 12; + private static final int PARTITION2 = 13; + private static final long FIRST_OFFSET = 45; + + @Mock Plugins plugins; + + private static final Map TASK_PROPS = new HashMap<>(); + + static { + TASK_PROPS.put(SinkConnector.TOPICS_CONFIG, TOPIC); + TASK_PROPS.put(TaskConfig.TASK_CLASS_CONFIG, TestSinkTask.class.getName()); + } + + public static final long OPERATOR_RETRY_TIMEOUT_MILLIS = 60000; + public static final long OPERATOR_RETRY_MAX_DELAY_MILLIS = 5000; + public static final ToleranceType OPERATOR_TOLERANCE_TYPE = ToleranceType.ALL; + + private static final TaskConfig TASK_CONFIG = new TaskConfig(TASK_PROPS); + + private ConnectorTaskId taskId = new ConnectorTaskId("job", 0); + private TargetState initialState = TargetState.STARTED; + private Time time; + private MockConnectMetrics metrics; + @SuppressWarnings("unused") + @Mock + private SinkTask sinkTask; + @SuppressWarnings("unused") + @Mock + private SourceTask sourceTask; + private Capture sinkTaskContext = EasyMock.newCapture(); + private WorkerConfig workerConfig; + private SourceConnectorConfig sourceConfig; + @Mock + private PluginClassLoader pluginLoader; + @SuppressWarnings("unused") + @Mock + private HeaderConverter headerConverter; + private WorkerSinkTask workerSinkTask; + private WorkerSourceTask workerSourceTask; + @SuppressWarnings("unused") + @Mock + private KafkaConsumer consumer; + @SuppressWarnings("unused") + @Mock + private KafkaProducer producer; + @SuppressWarnings("unused") + @Mock private TopicAdmin admin; + + @Mock + OffsetStorageReaderImpl offsetReader; + @Mock + OffsetStorageWriter offsetWriter; + + private Capture rebalanceListener = EasyMock.newCapture(); + 
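+ // For orientation (a sketch; the operator is built directly from the constants above rather than
+ // parsed from a config): those values correspond to the connector-level error handling properties
+ //   errors.retry.timeout      = 60000
+ //   errors.retry.delay.max.ms = 5000
+ //   errors.tolerance          = all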
@SuppressWarnings("unused") + @Mock + private TaskStatus.Listener statusListener; + @SuppressWarnings("unused") + @Mock private StatusBackingStore statusBackingStore; + + @Mock + private WorkerErrantRecordReporter workerErrantRecordReporter; + + private ErrorHandlingMetrics errorHandlingMetrics; + + private boolean enableTopicCreation; + + @ParameterizedTest.Parameters + public static Collection parameters() { + return Arrays.asList(false, true); + } + + public ErrorHandlingTaskTest(boolean enableTopicCreation) { + this.enableTopicCreation = enableTopicCreation; + } + + @Before + public void setup() { + time = new MockTime(0, 0, 0); + metrics = new MockConnectMetrics(); + Map workerProps = new HashMap<>(); + workerProps.put("key.converter", "org.apache.kafka.connect.json.JsonConverter"); + workerProps.put("value.converter", "org.apache.kafka.connect.json.JsonConverter"); + workerProps.put("offset.storage.file.filename", "/tmp/connect.offsets"); + workerProps.put(TOPIC_CREATION_ENABLE_CONFIG, String.valueOf(enableTopicCreation)); + pluginLoader = PowerMock.createMock(PluginClassLoader.class); + workerConfig = new StandaloneConfig(workerProps); + sourceConfig = new SourceConnectorConfig(plugins, sourceConnectorProps(TOPIC), true); + errorHandlingMetrics = new ErrorHandlingMetrics(taskId, metrics); + } + + private Map sourceConnectorProps(String topic) { + // setup up props for the source connector + Map props = new HashMap<>(); + props.put("name", "foo-connector"); + props.put(CONNECTOR_CLASS_CONFIG, MonitorableSourceConnector.class.getSimpleName()); + props.put(TASKS_MAX_CONFIG, String.valueOf(1)); + props.put(TOPIC_CONFIG, topic); + props.put(KEY_CONVERTER_CLASS_CONFIG, StringConverter.class.getName()); + props.put(VALUE_CONVERTER_CLASS_CONFIG, StringConverter.class.getName()); + props.put(TOPIC_CREATION_GROUPS_CONFIG, String.join(",", "foo", "bar")); + props.put(DEFAULT_TOPIC_CREATION_PREFIX + REPLICATION_FACTOR_CONFIG, String.valueOf(1)); + props.put(DEFAULT_TOPIC_CREATION_PREFIX + PARTITIONS_CONFIG, String.valueOf(1)); + props.put(SourceConnectorConfig.TOPIC_CREATION_PREFIX + "foo" + "." 
+ INCLUDE_REGEX_CONFIG, topic); + return props; + } + + @After + public void tearDown() { + if (metrics != null) { + metrics.stop(); + } + } + + @Test + public void testSinkTasksCloseErrorReporters() throws Exception { + ErrorReporter reporter = EasyMock.mock(ErrorReporter.class); + + RetryWithToleranceOperator retryWithToleranceOperator = operator(); + retryWithToleranceOperator.metrics(errorHandlingMetrics); + retryWithToleranceOperator.reporters(singletonList(reporter)); + + createSinkTask(initialState, retryWithToleranceOperator); + + expectInitializeTask(); + reporter.close(); + EasyMock.expectLastCall(); + sinkTask.stop(); + EasyMock.expectLastCall(); + + consumer.close(); + EasyMock.expectLastCall(); + + PowerMock.replayAll(); + + workerSinkTask.initialize(TASK_CONFIG); + workerSinkTask.initializeAndStart(); + workerSinkTask.close(); + + PowerMock.verifyAll(); + } + + @Test + public void testSourceTasksCloseErrorReporters() { + ErrorReporter reporter = EasyMock.mock(ErrorReporter.class); + + RetryWithToleranceOperator retryWithToleranceOperator = operator(); + retryWithToleranceOperator.metrics(errorHandlingMetrics); + retryWithToleranceOperator.reporters(singletonList(reporter)); + + createSourceTask(initialState, retryWithToleranceOperator); + + expectClose(); + + reporter.close(); + EasyMock.expectLastCall(); + + PowerMock.replayAll(); + + workerSourceTask.initialize(TASK_CONFIG); + workerSourceTask.close(); + + PowerMock.verifyAll(); + } + + @Test + public void testCloseErrorReportersExceptionPropagation() { + ErrorReporter reporterA = EasyMock.mock(ErrorReporter.class); + ErrorReporter reporterB = EasyMock.mock(ErrorReporter.class); + + RetryWithToleranceOperator retryWithToleranceOperator = operator(); + retryWithToleranceOperator.metrics(errorHandlingMetrics); + retryWithToleranceOperator.reporters(Arrays.asList(reporterA, reporterB)); + + createSourceTask(initialState, retryWithToleranceOperator); + + expectClose(); + + // Even though the reporters throw exceptions, they should both still be closed. 
+ reporterA.close(); + EasyMock.expectLastCall().andThrow(new RuntimeException()); + + reporterB.close(); + EasyMock.expectLastCall().andThrow(new RuntimeException()); + + PowerMock.replayAll(); + + workerSourceTask.initialize(TASK_CONFIG); + workerSourceTask.close(); + + PowerMock.verifyAll(); + } + + @Test + public void testErrorHandlingInSinkTasks() throws Exception { + Map reportProps = new HashMap<>(); + reportProps.put(ConnectorConfig.ERRORS_LOG_ENABLE_CONFIG, "true"); + reportProps.put(ConnectorConfig.ERRORS_LOG_INCLUDE_MESSAGES_CONFIG, "true"); + LogReporter reporter = new LogReporter(taskId, connConfig(reportProps), errorHandlingMetrics); + + RetryWithToleranceOperator retryWithToleranceOperator = operator(); + retryWithToleranceOperator.metrics(errorHandlingMetrics); + retryWithToleranceOperator.reporters(singletonList(reporter)); + createSinkTask(initialState, retryWithToleranceOperator); + + expectInitializeTask(); + expectTaskGetTopic(true); + + // valid json + ConsumerRecord record1 = new ConsumerRecord<>(TOPIC, PARTITION1, FIRST_OFFSET, null, "{\"a\": 10}".getBytes()); + // bad json + ConsumerRecord record2 = new ConsumerRecord<>(TOPIC, PARTITION2, FIRST_OFFSET, null, "{\"a\" 10}".getBytes()); + + EasyMock.expect(consumer.poll(Duration.ofMillis(EasyMock.anyLong()))).andReturn(records(record1)); + EasyMock.expect(consumer.poll(Duration.ofMillis(EasyMock.anyLong()))).andReturn(records(record2)); + + sinkTask.put(EasyMock.anyObject()); + EasyMock.expectLastCall().times(2); + + PowerMock.replayAll(); + + workerSinkTask.initialize(TASK_CONFIG); + workerSinkTask.initializeAndStart(); + workerSinkTask.iteration(); + + workerSinkTask.iteration(); + + // two records were consumed from Kafka + assertSinkMetricValue("sink-record-read-total", 2.0); + // only one was written to the task + assertSinkMetricValue("sink-record-send-total", 1.0); + // one record completely failed (converter issues) + assertErrorHandlingMetricValue("total-record-errors", 1.0); + // 2 failures in the transformation, and 1 in the converter + assertErrorHandlingMetricValue("total-record-failures", 3.0); + // one record completely failed (converter issues), and thus was skipped + assertErrorHandlingMetricValue("total-records-skipped", 1.0); + + PowerMock.verifyAll(); + } + + private RetryWithToleranceOperator operator() { + return new RetryWithToleranceOperator(OPERATOR_RETRY_TIMEOUT_MILLIS, OPERATOR_RETRY_MAX_DELAY_MILLIS, OPERATOR_TOLERANCE_TYPE, SYSTEM); + } + + @Test + public void testErrorHandlingInSourceTasks() throws Exception { + Map reportProps = new HashMap<>(); + reportProps.put(ConnectorConfig.ERRORS_LOG_ENABLE_CONFIG, "true"); + reportProps.put(ConnectorConfig.ERRORS_LOG_INCLUDE_MESSAGES_CONFIG, "true"); + LogReporter reporter = new LogReporter(taskId, connConfig(reportProps), errorHandlingMetrics); + + RetryWithToleranceOperator retryWithToleranceOperator = operator(); + retryWithToleranceOperator.metrics(errorHandlingMetrics); + retryWithToleranceOperator.reporters(singletonList(reporter)); + createSourceTask(initialState, retryWithToleranceOperator); + + // valid json + Schema valSchema = SchemaBuilder.struct().field("val", Schema.INT32_SCHEMA).build(); + Struct struct1 = new Struct(valSchema).put("val", 1234); + SourceRecord record1 = new SourceRecord(emptyMap(), emptyMap(), TOPIC, PARTITION1, valSchema, struct1); + Struct struct2 = new Struct(valSchema).put("val", 6789); + SourceRecord record2 = new SourceRecord(emptyMap(), emptyMap(), TOPIC, PARTITION1, valSchema, struct2); + + 
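+ // isStopping() is stubbed to return false twice and then true, so execute() performs exactly two
+ // poll/send iterations against the mocked source task before committing offsets and shutting down.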
EasyMock.expect(workerSourceTask.isStopping()).andReturn(false); + EasyMock.expect(workerSourceTask.isStopping()).andReturn(false); + EasyMock.expect(workerSourceTask.isStopping()).andReturn(true); + + EasyMock.expect(workerSourceTask.commitOffsets()).andReturn(true); + + sourceTask.initialize(EasyMock.anyObject()); + EasyMock.expectLastCall(); + + sourceTask.start(EasyMock.anyObject()); + EasyMock.expectLastCall(); + + EasyMock.expect(sourceTask.poll()).andReturn(singletonList(record1)); + EasyMock.expect(sourceTask.poll()).andReturn(singletonList(record2)); + expectTopicCreation(TOPIC); + EasyMock.expect(producer.send(EasyMock.anyObject(), EasyMock.anyObject())).andReturn(null).times(2); + + PowerMock.replayAll(); + + workerSourceTask.initialize(TASK_CONFIG); + workerSourceTask.initializeAndStart(); + workerSourceTask.execute(); + + // two records were consumed from Kafka + assertSourceMetricValue("source-record-poll-total", 2.0); + // only one was written to the task + assertSourceMetricValue("source-record-write-total", 0.0); + // one record completely failed (converter issues) + assertErrorHandlingMetricValue("total-record-errors", 0.0); + // 2 failures in the transformation, and 1 in the converter + assertErrorHandlingMetricValue("total-record-failures", 4.0); + // one record completely failed (converter issues), and thus was skipped + assertErrorHandlingMetricValue("total-records-skipped", 0.0); + + PowerMock.verifyAll(); + } + + private ConnectorConfig connConfig(Map connProps) { + Map props = new HashMap<>(); + props.put(ConnectorConfig.NAME_CONFIG, "test"); + props.put(ConnectorConfig.CONNECTOR_CLASS_CONFIG, SinkTask.class.getName()); + props.putAll(connProps); + return new ConnectorConfig(plugins, props); + } + + @Test + public void testErrorHandlingInSourceTasksWthBadConverter() throws Exception { + Map reportProps = new HashMap<>(); + reportProps.put(ConnectorConfig.ERRORS_LOG_ENABLE_CONFIG, "true"); + reportProps.put(ConnectorConfig.ERRORS_LOG_INCLUDE_MESSAGES_CONFIG, "true"); + LogReporter reporter = new LogReporter(taskId, connConfig(reportProps), errorHandlingMetrics); + + RetryWithToleranceOperator retryWithToleranceOperator = operator(); + retryWithToleranceOperator.metrics(errorHandlingMetrics); + retryWithToleranceOperator.reporters(singletonList(reporter)); + createSourceTask(initialState, retryWithToleranceOperator, badConverter()); + + // valid json + Schema valSchema = SchemaBuilder.struct().field("val", Schema.INT32_SCHEMA).build(); + Struct struct1 = new Struct(valSchema).put("val", 1234); + SourceRecord record1 = new SourceRecord(emptyMap(), emptyMap(), TOPIC, PARTITION1, valSchema, struct1); + Struct struct2 = new Struct(valSchema).put("val", 6789); + SourceRecord record2 = new SourceRecord(emptyMap(), emptyMap(), TOPIC, PARTITION1, valSchema, struct2); + + EasyMock.expect(workerSourceTask.isStopping()).andReturn(false); + EasyMock.expect(workerSourceTask.isStopping()).andReturn(false); + EasyMock.expect(workerSourceTask.isStopping()).andReturn(true); + + EasyMock.expect(workerSourceTask.commitOffsets()).andReturn(true); + + sourceTask.initialize(EasyMock.anyObject()); + EasyMock.expectLastCall(); + + sourceTask.start(EasyMock.anyObject()); + EasyMock.expectLastCall(); + + EasyMock.expect(sourceTask.poll()).andReturn(singletonList(record1)); + EasyMock.expect(sourceTask.poll()).andReturn(singletonList(record2)); + expectTopicCreation(TOPIC); + EasyMock.expect(producer.send(EasyMock.anyObject(), EasyMock.anyObject())).andReturn(null).times(2); + + 
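+ // Relative to testErrorHandlingInSourceTasks above, the FaultyConverter's retriable failures raise
+ // total-record-failures from 4.0 to 8.0 while total-records-skipped remains 0.0 (see asserts below).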
PowerMock.replayAll(); + + workerSourceTask.initialize(TASK_CONFIG); + workerSourceTask.initializeAndStart(); + workerSourceTask.execute(); + + // two records were consumed from Kafka + assertSourceMetricValue("source-record-poll-total", 2.0); + // only one was written to the task + assertSourceMetricValue("source-record-write-total", 0.0); + // one record completely failed (converter issues) + assertErrorHandlingMetricValue("total-record-errors", 0.0); + // 2 failures in the transformation, and 1 in the converter + assertErrorHandlingMetricValue("total-record-failures", 8.0); + // one record completely failed (converter issues), and thus was skipped + assertErrorHandlingMetricValue("total-records-skipped", 0.0); + + PowerMock.verifyAll(); + } + + private void assertSinkMetricValue(String name, double expected) { + ConnectMetrics.MetricGroup sinkTaskGroup = workerSinkTask.sinkTaskMetricsGroup().metricGroup(); + double measured = metrics.currentMetricValueAsDouble(sinkTaskGroup, name); + assertEquals(expected, measured, 0.001d); + } + + private void assertSourceMetricValue(String name, double expected) { + ConnectMetrics.MetricGroup sinkTaskGroup = workerSourceTask.sourceTaskMetricsGroup().metricGroup(); + double measured = metrics.currentMetricValueAsDouble(sinkTaskGroup, name); + assertEquals(expected, measured, 0.001d); + } + + private void assertErrorHandlingMetricValue(String name, double expected) { + ConnectMetrics.MetricGroup sinkTaskGroup = errorHandlingMetrics.metricGroup(); + double measured = metrics.currentMetricValueAsDouble(sinkTaskGroup, name); + assertEquals(expected, measured, 0.001d); + } + + private void expectInitializeTask() throws Exception { + consumer.subscribe(EasyMock.eq(singletonList(TOPIC)), EasyMock.capture(rebalanceListener)); + PowerMock.expectLastCall(); + + sinkTask.initialize(EasyMock.capture(sinkTaskContext)); + PowerMock.expectLastCall(); + sinkTask.start(TASK_PROPS); + PowerMock.expectLastCall(); + } + + private void expectTaskGetTopic(boolean anyTimes) { + final Capture connectorCapture = EasyMock.newCapture(); + final Capture topicCapture = EasyMock.newCapture(); + IExpectationSetters expect = EasyMock.expect(statusBackingStore.getTopic( + EasyMock.capture(connectorCapture), + EasyMock.capture(topicCapture))); + if (anyTimes) { + expect.andStubAnswer(() -> new TopicStatus( + topicCapture.getValue(), + new ConnectorTaskId(connectorCapture.getValue(), 0), + Time.SYSTEM.milliseconds())); + } else { + expect.andAnswer(() -> new TopicStatus( + topicCapture.getValue(), + new ConnectorTaskId(connectorCapture.getValue(), 0), + Time.SYSTEM.milliseconds())); + } + if (connectorCapture.hasCaptured() && topicCapture.hasCaptured()) { + assertEquals("job", connectorCapture.getValue()); + assertEquals(TOPIC, topicCapture.getValue()); + } + } + + private void expectClose() { + producer.close(EasyMock.anyObject(Duration.class)); + EasyMock.expectLastCall(); + + admin.close(EasyMock.anyObject(Duration.class)); + EasyMock.expectLastCall(); + } + + private void expectTopicCreation(String topic) { + if (workerConfig.topicCreationEnable()) { + EasyMock.expect(admin.describeTopics(topic)).andReturn(Collections.emptyMap()); + Capture newTopicCapture = EasyMock.newCapture(); + + if (enableTopicCreation) { + Set created = Collections.singleton(topic); + Set existing = Collections.emptySet(); + TopicAdmin.TopicCreationResponse response = new TopicAdmin.TopicCreationResponse(created, existing); + 
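+ // When topic creation is enabled the test expects admin.createOrFindTopics(...); otherwise the
+ // else branch below stubs the single-topic admin.createTopic(...) call instead.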
EasyMock.expect(admin.createOrFindTopics(EasyMock.capture(newTopicCapture))).andReturn(response); + } else { + EasyMock.expect(admin.createTopic(EasyMock.capture(newTopicCapture))).andReturn(true); + } + } + } + + private void createSinkTask(TargetState initialState, RetryWithToleranceOperator retryWithToleranceOperator) { + JsonConverter converter = new JsonConverter(); + Map oo = workerConfig.originalsWithPrefix("value.converter."); + oo.put("converter.type", "value"); + oo.put("schemas.enable", "false"); + converter.configure(oo); + + TransformationChain sinkTransforms = new TransformationChain<>(singletonList(new FaultyPassthrough()), retryWithToleranceOperator); + + workerSinkTask = new WorkerSinkTask( + taskId, sinkTask, statusListener, initialState, workerConfig, + ClusterConfigState.EMPTY, metrics, converter, converter, + headerConverter, sinkTransforms, consumer, pluginLoader, time, + retryWithToleranceOperator, workerErrantRecordReporter, statusBackingStore); + } + + private void createSourceTask(TargetState initialState, RetryWithToleranceOperator retryWithToleranceOperator) { + JsonConverter converter = new JsonConverter(); + Map oo = workerConfig.originalsWithPrefix("value.converter."); + oo.put("converter.type", "value"); + oo.put("schemas.enable", "false"); + converter.configure(oo); + + createSourceTask(initialState, retryWithToleranceOperator, converter); + } + + private Converter badConverter() { + FaultyConverter converter = new FaultyConverter(); + Map oo = workerConfig.originalsWithPrefix("value.converter."); + oo.put("converter.type", "value"); + oo.put("schemas.enable", "false"); + converter.configure(oo); + return converter; + } + + private void createSourceTask(TargetState initialState, RetryWithToleranceOperator retryWithToleranceOperator, Converter converter) { + TransformationChain sourceTransforms = new TransformationChain<>(singletonList(new FaultyPassthrough()), retryWithToleranceOperator); + + workerSourceTask = PowerMock.createPartialMock( + WorkerSourceTask.class, new String[]{"commitOffsets", "isStopping"}, + taskId, sourceTask, statusListener, initialState, converter, converter, headerConverter, sourceTransforms, + producer, admin, TopicCreationGroup.configuredGroups(sourceConfig), + offsetReader, offsetWriter, workerConfig, + ClusterConfigState.EMPTY, metrics, pluginLoader, time, retryWithToleranceOperator, + statusBackingStore, (Executor) Runnable::run); + } + + private ConsumerRecords records(ConsumerRecord record) { + return new ConsumerRecords<>(Collections.singletonMap( + new TopicPartition(record.topic(), record.partition()), singletonList(record))); + } + + private abstract static class TestSinkTask extends SinkTask { + } + + static class FaultyConverter extends JsonConverter { + private static final Logger log = LoggerFactory.getLogger(FaultyConverter.class); + private int invocations = 0; + + public byte[] fromConnectData(String topic, Schema schema, Object value) { + if (value == null) { + return super.fromConnectData(topic, schema, null); + } + invocations++; + if (invocations % 3 == 0) { + log.debug("Succeeding record: {} where invocations={}", value, invocations); + return super.fromConnectData(topic, schema, value); + } else { + log.debug("Failing record: {} at invocations={}", value, invocations); + throw new RetriableException("Bad invocations " + invocations + " for mod 3"); + } + } + } + + static class FaultyPassthrough> implements Transformation { + + private static final Logger log = LoggerFactory.getLogger(FaultyPassthrough.class); 
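+ // Passes a record through only when the invocation count is a multiple of mod; every other call throws RetriableException so tests can exercise retry and tolerance handling.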
+ + private static final String MOD_CONFIG = "mod"; + private static final int MOD_CONFIG_DEFAULT = 3; + + public static final ConfigDef CONFIG_DEF = new ConfigDef() + .define(MOD_CONFIG, ConfigDef.Type.INT, MOD_CONFIG_DEFAULT, ConfigDef.Importance.MEDIUM, "Pass records without failure only if timestamp % mod == 0"); + + private int mod = MOD_CONFIG_DEFAULT; + + private int invocations = 0; + + @Override + public R apply(R record) { + invocations++; + if (invocations % mod == 0) { + log.debug("Succeeding record: {} where invocations={}", record, invocations); + return record; + } else { + log.debug("Failing record: {} at invocations={}", record, invocations); + throw new RetriableException("Bad invocations " + invocations + " for mod " + mod); + } + } + + @Override + public ConfigDef config() { + return CONFIG_DEF; + } + + @Override + public void close() { + log.info("Shutting down transform"); + } + + @Override + public void configure(Map configs) { + final SimpleConfig config = new SimpleConfig(CONFIG_DEF, configs); + mod = Math.max(config.getInt(MOD_CONFIG), 2); + log.info("Configuring {}. Setting mod to {}", this.getClass(), mod); + } + } +} diff --git a/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/MockConnectMetrics.java b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/MockConnectMetrics.java new file mode 100644 index 0000000..4abbc64 --- /dev/null +++ b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/MockConnectMetrics.java @@ -0,0 +1,208 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.runtime; + +import org.apache.kafka.clients.CommonClientConfigs; +import org.apache.kafka.common.MetricName; +import org.apache.kafka.common.metrics.KafkaMetric; +import org.apache.kafka.common.metrics.MetricsContext; +import org.apache.kafka.common.metrics.MetricsReporter; +import org.apache.kafka.common.utils.MockTime; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** + * A specialization of {@link ConnectMetrics} that uses a custom {@link MetricsReporter} to capture the metrics + * that were created, and makes those metrics available even after the metrics were removed from the + * {@link org.apache.kafka.common.metrics.Metrics} registry. + * + * This is needed because many of the Connect metric groups are specific to connectors and/or tasks, and therefore + * their metrics are removed from the {@link org.apache.kafka.common.metrics.Metrics} registry when the connector + * and tasks are closed. This instance keeps track of the metrics that were created so that it is possible for + * tests to {@link #currentMetricValue(MetricGroup, String) read the metrics' value} even after the connector + * and/or tasks have been closed. 
+ * + * If the same metric is created a second time (e.g., a worker task is re-created), the new metric will replace + * the previous metric in the custom reporter. + */ +@SuppressWarnings("deprecation") +public class MockConnectMetrics extends ConnectMetrics { + + private static final Map DEFAULT_WORKER_CONFIG = new HashMap<>(); + + static { + DEFAULT_WORKER_CONFIG.put(WorkerConfig.KEY_CONVERTER_CLASS_CONFIG, "org.apache.kafka.connect.json.JsonConverter"); + DEFAULT_WORKER_CONFIG.put(WorkerConfig.VALUE_CONVERTER_CLASS_CONFIG, "org.apache.kafka.connect.json.JsonConverter"); + DEFAULT_WORKER_CONFIG.put(CommonClientConfigs.METRIC_REPORTER_CLASSES_CONFIG, MockMetricsReporter.class.getName()); + } + + public MockConnectMetrics() { + this(new MockTime()); + } + + public MockConnectMetrics(MockTime time) { + super("mock", new WorkerConfig(WorkerConfig.baseConfigDef(), DEFAULT_WORKER_CONFIG), time, "cluster-1"); + } + + @Override + public MockTime time() { + return (MockTime) super.time(); + } + + /** + * Get the current value of the named metric, which may have already been removed from the + * {@link org.apache.kafka.common.metrics.Metrics} but will have been captured before it was removed. + * + * @param metricGroup the metric metricGroup that contained the metric + * @param name the name of the metric + * @return the current value of the metric + */ + public Object currentMetricValue(MetricGroup metricGroup, String name) { + return currentMetricValue(this, metricGroup, name); + } + + /** + * Get the current value of the named metric, which may have already been removed from the + * {@link org.apache.kafka.common.metrics.Metrics} but will have been captured before it was removed. + * + * @param metricGroup the metric metricGroup that contained the metric + * @param name the name of the metric + * @return the current value of the metric + */ + public double currentMetricValueAsDouble(MetricGroup metricGroup, String name) { + Object value = currentMetricValue(metricGroup, name); + return value instanceof Double ? (Double) value : Double.NaN; + } + + /** + * Get the current value of the named metric, which may have already been removed from the + * {@link org.apache.kafka.common.metrics.Metrics} but will have been captured before it was removed. + * + * @param metricGroup the metric metricGroup that contained the metric + * @param name the name of the metric + * @return the current value of the metric + */ + public String currentMetricValueAsString(MetricGroup metricGroup, String name) { + Object value = currentMetricValue(metricGroup, name); + return value instanceof String ? (String) value : null; + } + + /** + * Get the current value of the named metric, which may have already been removed from the + * {@link org.apache.kafka.common.metrics.Metrics} but will have been captured before it was removed. 
+ * + * @param metrics the {@link ConnectMetrics} instance + * @param metricGroup the metric metricGroup that contained the metric + * @param name the name of the metric + * @return the current value of the metric + */ + public static Object currentMetricValue(ConnectMetrics metrics, MetricGroup metricGroup, String name) { + MetricName metricName = metricGroup.metricName(name); + for (MetricsReporter reporter : metrics.metrics().reporters()) { + if (reporter instanceof MockMetricsReporter) { + return ((MockMetricsReporter) reporter).currentMetricValue(metricName); + } + } + return null; + } + + /** + * Get the current value of the named metric, which may have already been removed from the + * {@link org.apache.kafka.common.metrics.Metrics} but will have been captured before it was removed. + * + * @param metrics the {@link ConnectMetrics} instance + * @param metricGroup the metric metricGroup that contained the metric + * @param name the name of the metric + * @return the current value of the metric + */ + public static double currentMetricValueAsDouble(ConnectMetrics metrics, MetricGroup metricGroup, String name) { + Object value = currentMetricValue(metrics, metricGroup, name); + return value instanceof Double ? (Double) value : Double.NaN; + } + + /** + * Get the current value of the named metric, which may have already been removed from the + * {@link org.apache.kafka.common.metrics.Metrics} but will have been captured before it was removed. + * + * @param metrics the {@link ConnectMetrics} instance + * @param metricGroup the metric metricGroup that contained the metric + * @param name the name of the metric + * @return the current value of the metric + */ + public static String currentMetricValueAsString(ConnectMetrics metrics, MetricGroup metricGroup, String name) { + Object value = currentMetricValue(metrics, metricGroup, name); + return value instanceof String ? (String) value : null; + } + + public static class MockMetricsReporter implements MetricsReporter { + private Map metricsByName = new HashMap<>(); + + private MetricsContext metricsContext; + + public MockMetricsReporter() { + } + + @Override + public void configure(Map configs) { + // do nothing + } + + @Override + public void init(List metrics) { + for (KafkaMetric metric : metrics) { + metricsByName.put(metric.metricName(), metric); + } + } + + @Override + public void metricChange(KafkaMetric metric) { + metricsByName.put(metric.metricName(), metric); + } + + @Override + public void metricRemoval(KafkaMetric metric) { + // don't remove metrics, or else we won't be able to access them after the metric metricGroup is closed + } + + @Override + public void close() { + // do nothing + } + + /** + * Get the current value of the metric. + * + * @param metricName the name of the metric that was registered most recently + * @return the current value of the metric + */ + public Object currentMetricValue(MetricName metricName) { + KafkaMetric metric = metricsByName.get(metricName); + return metric != null ? 
metric.metricValue() : null; + } + + @Override + public void contextChange(MetricsContext metricsContext) { + this.metricsContext = metricsContext; + } + + public MetricsContext getMetricsContext() { + return this.metricsContext; + } + } +} diff --git a/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/PredicatedTransformationTest.java b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/PredicatedTransformationTest.java new file mode 100644 index 0000000..75542e7 --- /dev/null +++ b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/PredicatedTransformationTest.java @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.runtime; + +import java.util.Map; + +import org.apache.kafka.common.config.ConfigDef; +import org.apache.kafka.connect.source.SourceRecord; +import org.apache.kafka.connect.transforms.Transformation; +import org.apache.kafka.connect.transforms.predicates.Predicate; +import org.junit.Test; + +import static java.util.Collections.singletonMap; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +public class PredicatedTransformationTest { + + private final SourceRecord initial = new SourceRecord(singletonMap("initial", 1), null, null, null, null); + private final SourceRecord transformed = new SourceRecord(singletonMap("transformed", 2), null, null, null, null); + + @Test + public void apply() { + applyAndAssert(true, false, transformed); + applyAndAssert(true, true, initial); + applyAndAssert(false, false, initial); + applyAndAssert(false, true, transformed); + } + + private void applyAndAssert(boolean predicateResult, boolean negate, + SourceRecord expectedResult) { + class TestTransformation implements Transformation { + + private boolean closed = false; + private SourceRecord transformedRecord; + + private TestTransformation(SourceRecord transformedRecord) { + this.transformedRecord = transformedRecord; + } + + @Override + public SourceRecord apply(SourceRecord record) { + return transformedRecord; + } + + @Override + public ConfigDef config() { + return null; + } + + @Override + public void close() { + closed = true; + } + + @Override + public void configure(Map configs) { + + } + + private void assertClosed() { + assertTrue("Transformer should be closed", closed); + } + } + + class TestPredicate implements Predicate { + + private boolean testResult; + private boolean closed = false; + + private TestPredicate(boolean testResult) { + this.testResult = testResult; + } + + @Override + public ConfigDef config() { + return null; + } + + @Override + public boolean test(SourceRecord record) { + return testResult; + } + + @Override + public void close() { + closed = true; + } + + @Override + public void 
configure(Map configs) { + + } + + private void assertClosed() { + assertTrue("Predicate should be closed", closed); + } + } + TestPredicate predicate = new TestPredicate(predicateResult); + TestTransformation predicatedTransform = new TestTransformation(transformed); + PredicatedTransformation pt = new PredicatedTransformation<>( + predicate, + negate, + predicatedTransform); + + assertEquals(expectedResult, pt.apply(initial)); + + pt.close(); + predicate.assertClosed(); + predicatedTransform.assertClosed(); + } +} diff --git a/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/RestartPlanTest.java b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/RestartPlanTest.java new file mode 100644 index 0000000..480ba2b --- /dev/null +++ b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/RestartPlanTest.java @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.runtime; + +import org.apache.kafka.connect.runtime.rest.entities.ConnectorStateInfo; +import org.apache.kafka.connect.runtime.rest.entities.ConnectorStateInfo.TaskState; +import org.apache.kafka.connect.runtime.rest.entities.ConnectorType; +import org.junit.Test; + +import java.util.ArrayList; +import java.util.List; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +public class RestartPlanTest { + private static final String CONNECTOR_NAME = "foo"; + + @Test + public void testRestartPlan() { + ConnectorStateInfo.ConnectorState state = new ConnectorStateInfo.ConnectorState( + AbstractStatus.State.RESTARTING.name(), + "foo", + null + ); + List tasks = new ArrayList<>(); + tasks.add(new TaskState(1, AbstractStatus.State.RUNNING.name(), "worker1", null)); + tasks.add(new TaskState(2, AbstractStatus.State.PAUSED.name(), "worker1", null)); + tasks.add(new TaskState(3, AbstractStatus.State.RESTARTING.name(), "worker1", null)); + tasks.add(new TaskState(4, AbstractStatus.State.DESTROYED.name(), "worker1", null)); + tasks.add(new TaskState(5, AbstractStatus.State.RUNNING.name(), "worker1", null)); + tasks.add(new TaskState(6, AbstractStatus.State.RUNNING.name(), "worker1", null)); + ConnectorStateInfo connectorStateInfo = new ConnectorStateInfo(CONNECTOR_NAME, state, tasks, ConnectorType.SOURCE); + + RestartRequest restartRequest = new RestartRequest(CONNECTOR_NAME, false, true); + RestartPlan restartPlan = new RestartPlan(restartRequest, connectorStateInfo); + + assertTrue(restartPlan.shouldRestartConnector()); + assertTrue(restartPlan.shouldRestartTasks()); + assertEquals(1, restartPlan.taskIdsToRestart().size()); + assertEquals(3, restartPlan.taskIdsToRestart().iterator().next().task()); + 
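+ // Only task 3 is in the RESTARTING state, so it is the single task targeted by the restart plan.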
assertTrue(restartPlan.toString().contains("plan to restart connector")); + } + + @Test + public void testNoRestartsPlan() { + ConnectorStateInfo.ConnectorState state = new ConnectorStateInfo.ConnectorState( + AbstractStatus.State.RUNNING.name(), + "foo", + null + ); + List tasks = new ArrayList<>(); + tasks.add(new TaskState(1, AbstractStatus.State.RUNNING.name(), "worker1", null)); + tasks.add(new TaskState(2, AbstractStatus.State.PAUSED.name(), "worker1", null)); + ConnectorStateInfo connectorStateInfo = new ConnectorStateInfo(CONNECTOR_NAME, state, tasks, ConnectorType.SOURCE); + RestartRequest restartRequest = new RestartRequest(CONNECTOR_NAME, false, true); + RestartPlan restartPlan = new RestartPlan(restartRequest, connectorStateInfo); + + assertFalse(restartPlan.shouldRestartConnector()); + assertFalse(restartPlan.shouldRestartTasks()); + assertEquals(0, restartPlan.taskIdsToRestart().size()); + assertTrue(restartPlan.toString().contains("plan to restart 0 of")); + } + + @Test + public void testRestartsOnlyConnector() { + ConnectorStateInfo.ConnectorState state = new ConnectorStateInfo.ConnectorState( + AbstractStatus.State.RESTARTING.name(), + "foo", + null + ); + List tasks = new ArrayList<>(); + tasks.add(new TaskState(1, AbstractStatus.State.RUNNING.name(), "worker1", null)); + tasks.add(new TaskState(2, AbstractStatus.State.PAUSED.name(), "worker1", null)); + ConnectorStateInfo connectorStateInfo = new ConnectorStateInfo(CONNECTOR_NAME, state, tasks, ConnectorType.SOURCE); + RestartRequest restartRequest = new RestartRequest(CONNECTOR_NAME, false, true); + RestartPlan restartPlan = new RestartPlan(restartRequest, connectorStateInfo); + + assertTrue(restartPlan.shouldRestartConnector()); + assertFalse(restartPlan.shouldRestartTasks()); + assertEquals(0, restartPlan.taskIdsToRestart().size()); + } +} diff --git a/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/RestartRequestTest.java b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/RestartRequestTest.java new file mode 100644 index 0000000..c4be5ca --- /dev/null +++ b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/RestartRequestTest.java @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.connect.runtime; + +import org.junit.Test; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +public class RestartRequestTest { + private static final String CONNECTOR_NAME = "foo"; + + @Test + public void forciblyRestartConnectorOnly() { + RestartRequest restartRequest = new RestartRequest(CONNECTOR_NAME, false, false); + assertTrue(restartRequest.forceRestartConnectorOnly()); + restartRequest = new RestartRequest(CONNECTOR_NAME, false, true); + assertFalse(restartRequest.forceRestartConnectorOnly()); + restartRequest = new RestartRequest(CONNECTOR_NAME, true, false); + assertFalse(restartRequest.forceRestartConnectorOnly()); + restartRequest = new RestartRequest(CONNECTOR_NAME, true, true); + assertFalse(restartRequest.forceRestartConnectorOnly()); + } + + @Test + public void restartOnlyFailedConnector() { + RestartRequest restartRequest = new RestartRequest(CONNECTOR_NAME, true, false); + assertTrue(restartRequest.shouldRestartConnector(createConnectorStatus(AbstractStatus.State.FAILED))); + assertFalse(restartRequest.shouldRestartConnector(createConnectorStatus(AbstractStatus.State.RUNNING))); + assertFalse(restartRequest.shouldRestartConnector(createConnectorStatus(AbstractStatus.State.PAUSED))); + } + + @Test + public void restartAnyStatusConnector() { + RestartRequest restartRequest = new RestartRequest(CONNECTOR_NAME, false, false); + assertTrue(restartRequest.shouldRestartConnector(createConnectorStatus(AbstractStatus.State.FAILED))); + assertTrue(restartRequest.shouldRestartConnector(createConnectorStatus(AbstractStatus.State.RUNNING))); + assertTrue(restartRequest.shouldRestartConnector(createConnectorStatus(AbstractStatus.State.PAUSED))); + } + + @Test + public void restartOnlyFailedTasks() { + RestartRequest restartRequest = new RestartRequest(CONNECTOR_NAME, true, true); + assertTrue(restartRequest.shouldRestartTask(createTaskStatus(AbstractStatus.State.FAILED))); + assertFalse(restartRequest.shouldRestartTask(createTaskStatus(AbstractStatus.State.RUNNING))); + assertFalse(restartRequest.shouldRestartTask(createTaskStatus(AbstractStatus.State.PAUSED))); + } + + @Test + public void restartAnyStatusTasks() { + RestartRequest restartRequest = new RestartRequest(CONNECTOR_NAME, false, true); + assertTrue(restartRequest.shouldRestartTask(createTaskStatus(AbstractStatus.State.FAILED))); + assertTrue(restartRequest.shouldRestartTask(createTaskStatus(AbstractStatus.State.RUNNING))); + assertTrue(restartRequest.shouldRestartTask(createTaskStatus(AbstractStatus.State.PAUSED))); + } + + @Test + public void doNotRestartTasks() { + RestartRequest restartRequest = new RestartRequest(CONNECTOR_NAME, false, false); + assertFalse(restartRequest.shouldRestartTask(createTaskStatus(AbstractStatus.State.FAILED))); + assertFalse(restartRequest.shouldRestartTask(createTaskStatus(AbstractStatus.State.RUNNING))); + + restartRequest = new RestartRequest(CONNECTOR_NAME, true, false); + assertFalse(restartRequest.shouldRestartTask(createTaskStatus(AbstractStatus.State.FAILED))); + assertFalse(restartRequest.shouldRestartTask(createTaskStatus(AbstractStatus.State.RUNNING))); + } + + @Test + public void compareImpact() { + RestartRequest onlyFailedConnector = new RestartRequest(CONNECTOR_NAME, true, false); + RestartRequest failedConnectorAndTasks = new RestartRequest(CONNECTOR_NAME, true, true); + 
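+ // Sorting below orders requests by increasing impact: requests limited to failed instances rank before those that restart regardless of status, and connector-only requests before connector-plus-tasks requests.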
RestartRequest onlyConnector = new RestartRequest(CONNECTOR_NAME, false, false); + RestartRequest connectorAndTasks = new RestartRequest(CONNECTOR_NAME, false, true); + List restartRequests = Arrays.asList(connectorAndTasks, onlyConnector, onlyFailedConnector, failedConnectorAndTasks); + Collections.sort(restartRequests); + assertEquals(onlyFailedConnector, restartRequests.get(0)); + assertEquals(failedConnectorAndTasks, restartRequests.get(1)); + assertEquals(onlyConnector, restartRequests.get(2)); + assertEquals(connectorAndTasks, restartRequests.get(3)); + + RestartRequest onlyFailedDiffConnector = new RestartRequest(CONNECTOR_NAME + "foo", true, false); + assertTrue(onlyFailedConnector.compareTo(onlyFailedDiffConnector) != 0); + } + + private TaskStatus createTaskStatus(AbstractStatus.State state) { + return new TaskStatus(null, state, null, 0); + } + + private ConnectorStatus createConnectorStatus(AbstractStatus.State state) { + return new ConnectorStatus(null, state, null, 0); + } +} diff --git a/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/SourceConnectorConfigTest.java b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/SourceConnectorConfigTest.java new file mode 100644 index 0000000..1972b62 --- /dev/null +++ b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/SourceConnectorConfigTest.java @@ -0,0 +1,150 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.connect.runtime; + +import org.apache.kafka.common.config.ConfigException; +import org.junit.Test; + +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; +import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; + +import static org.apache.kafka.common.config.TopicConfig.CLEANUP_POLICY_COMPACT; +import static org.apache.kafka.common.config.TopicConfig.CLEANUP_POLICY_CONFIG; +import static org.apache.kafka.common.config.TopicConfig.COMPRESSION_TYPE_CONFIG; +import static org.apache.kafka.common.config.TopicConfig.RETENTION_MS_CONFIG; +import static org.apache.kafka.connect.runtime.ConnectorConfig.CONNECTOR_CLASS_CONFIG; +import static org.apache.kafka.connect.runtime.ConnectorConfig.NAME_CONFIG; +import static org.apache.kafka.connect.runtime.ConnectorConfigTest.MOCK_PLUGINS; +import static org.apache.kafka.connect.runtime.TopicCreationConfig.DEFAULT_TOPIC_CREATION_GROUP; +import static org.apache.kafka.connect.runtime.TopicCreationConfig.DEFAULT_TOPIC_CREATION_PREFIX; +import static org.apache.kafka.connect.runtime.TopicCreationConfig.PARTITIONS_CONFIG; +import static org.apache.kafka.connect.runtime.TopicCreationConfig.REPLICATION_FACTOR_CONFIG; +import static org.hamcrest.CoreMatchers.containsString; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; + +public class SourceConnectorConfigTest { + + private static final String FOO_CONNECTOR = "foo-source"; + private static final short DEFAULT_REPLICATION_FACTOR = -1; + private static final int DEFAULT_PARTITIONS = -1; + + public Map defaultConnectorProps() { + Map props = new HashMap<>(); + props.put(NAME_CONFIG, FOO_CONNECTOR); + props.put(CONNECTOR_CLASS_CONFIG, ConnectorConfigTest.TestConnector.class.getName()); + return props; + } + + public Map defaultConnectorPropsWithTopicCreation() { + Map props = defaultConnectorProps(); + props.put(DEFAULT_TOPIC_CREATION_PREFIX + REPLICATION_FACTOR_CONFIG, String.valueOf(DEFAULT_REPLICATION_FACTOR)); + props.put(DEFAULT_TOPIC_CREATION_PREFIX + PARTITIONS_CONFIG, String.valueOf(DEFAULT_PARTITIONS)); + return props; + } + + @Test + public void noTopicCreation() { + Map props = defaultConnectorProps(); + SourceConnectorConfig config = new SourceConnectorConfig(MOCK_PLUGINS, props, false); + assertFalse(config.usesTopicCreation()); + } + + @Test + public void shouldNotAllowZeroPartitionsOrReplicationFactor() { + Map props = defaultConnectorPropsWithTopicCreation(); + props.put(DEFAULT_TOPIC_CREATION_PREFIX + PARTITIONS_CONFIG, String.valueOf(0)); + Exception e = assertThrows(ConfigException.class, () -> new SourceConnectorConfig(MOCK_PLUGINS, props, true)); + assertThat(e.getMessage(), containsString("Number of partitions must be positive, or -1")); + + props.put(DEFAULT_TOPIC_CREATION_PREFIX + PARTITIONS_CONFIG, String.valueOf(DEFAULT_PARTITIONS)); + props.put(DEFAULT_TOPIC_CREATION_PREFIX + REPLICATION_FACTOR_CONFIG, String.valueOf(0)); + + e = assertThrows(ConfigException.class, () -> new SourceConnectorConfig(MOCK_PLUGINS, props, true)); + assertThat(e.getMessage(), containsString("Replication factor must be positive and not " + + "larger than the number of brokers in the Kafka cluster, or -1 to use the " + + "broker's default")); + } + + @Test + public void shouldNotAllowPartitionsOrReplicationFactorLessThanNegativeOne() { + Map props = 
defaultConnectorPropsWithTopicCreation(); + for (int i = -2; i > -100; --i) { + props.put(DEFAULT_TOPIC_CREATION_PREFIX + PARTITIONS_CONFIG, String.valueOf(i)); + props.put(DEFAULT_TOPIC_CREATION_PREFIX + REPLICATION_FACTOR_CONFIG, String.valueOf(DEFAULT_REPLICATION_FACTOR)); + Exception e = assertThrows(ConfigException.class, () -> new SourceConnectorConfig(MOCK_PLUGINS, props, true)); + assertThat(e.getMessage(), containsString("Number of partitions must be positive, or -1")); + + props.put(DEFAULT_TOPIC_CREATION_PREFIX + PARTITIONS_CONFIG, String.valueOf(DEFAULT_PARTITIONS)); + props.put(DEFAULT_TOPIC_CREATION_PREFIX + REPLICATION_FACTOR_CONFIG, String.valueOf(i)); + e = assertThrows(ConfigException.class, () -> new SourceConnectorConfig(MOCK_PLUGINS, props, true)); + assertThat(e.getMessage(), containsString("Replication factor must be positive and not " + + "larger than the number of brokers in the Kafka cluster, or -1 to use the " + + "broker's default")); + } + } + + @Test + public void shouldAllowNegativeOneAndPositiveForReplicationFactor() { + Map props = defaultConnectorPropsWithTopicCreation(); + SourceConnectorConfig config = new SourceConnectorConfig(MOCK_PLUGINS, props, true); + assertTrue(config.usesTopicCreation()); + + for (int i = 1; i <= 100; ++i) { + props.put(DEFAULT_TOPIC_CREATION_PREFIX + PARTITIONS_CONFIG, String.valueOf(i)); + props.put(DEFAULT_TOPIC_CREATION_PREFIX + REPLICATION_FACTOR_CONFIG, String.valueOf(DEFAULT_REPLICATION_FACTOR)); + config = new SourceConnectorConfig(MOCK_PLUGINS, props, true); + assertTrue(config.usesTopicCreation()); + + props.put(DEFAULT_TOPIC_CREATION_PREFIX + PARTITIONS_CONFIG, String.valueOf(DEFAULT_PARTITIONS)); + props.put(DEFAULT_TOPIC_CREATION_PREFIX + REPLICATION_FACTOR_CONFIG, String.valueOf(i)); + config = new SourceConnectorConfig(MOCK_PLUGINS, props, true); + assertTrue(config.usesTopicCreation()); + } + } + + @Test + public void shouldAllowSettingTopicProperties() { + Map topicProps = new HashMap<>(); + topicProps.put(CLEANUP_POLICY_CONFIG, CLEANUP_POLICY_COMPACT); + topicProps.put(COMPRESSION_TYPE_CONFIG, "lz4"); + topicProps.put(RETENTION_MS_CONFIG, String.valueOf(TimeUnit.DAYS.toMillis(30))); + + Map props = defaultConnectorPropsWithTopicCreation(); + topicProps.forEach((k, v) -> props.put(DEFAULT_TOPIC_CREATION_PREFIX + k, v)); + + SourceConnectorConfig config = new SourceConnectorConfig(MOCK_PLUGINS, props, true); + assertEquals(topicProps, + convertToStringValues(config.topicCreationOtherConfigs(DEFAULT_TOPIC_CREATION_GROUP))); + } + + private static Map convertToStringValues(Map config) { + // null values are not allowed + return config.entrySet().stream() + .collect(Collectors.toMap(Map.Entry::getKey, e -> { + Objects.requireNonNull(e.getValue()); + return e.getValue().toString(); + })); + } +} diff --git a/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/SourceTaskOffsetCommitterTest.java b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/SourceTaskOffsetCommitterTest.java new file mode 100644 index 0000000..278a73d --- /dev/null +++ b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/SourceTaskOffsetCommitterTest.java @@ -0,0 +1,193 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.runtime; + +import org.apache.kafka.connect.errors.ConnectException; +import org.apache.kafka.connect.runtime.standalone.StandaloneConfig; +import org.apache.kafka.connect.util.ConnectorTaskId; +import org.apache.kafka.connect.util.ThreadedTest; +import org.easymock.Capture; +import org.easymock.EasyMock; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.powermock.api.easymock.PowerMock; +import org.powermock.api.easymock.annotation.Mock; +import org.powermock.modules.junit4.PowerMockRunner; +import org.powermock.reflect.Whitebox; +import org.slf4j.Logger; + +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.CancellationException; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.ScheduledFuture; +import java.util.concurrent.TimeUnit; + +import static java.util.Collections.singletonMap; +import static org.easymock.EasyMock.eq; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +@RunWith(PowerMockRunner.class) +public class SourceTaskOffsetCommitterTest extends ThreadedTest { + + private final ConcurrentHashMap> committers = new ConcurrentHashMap<>(); + + @Mock private ScheduledExecutorService executor; + @Mock private Logger mockLog; + @Mock private ScheduledFuture commitFuture; + @Mock private ScheduledFuture taskFuture; + @Mock private ConnectorTaskId taskId; + @Mock private WorkerSourceTask task; + + private SourceTaskOffsetCommitter committer; + + private static final long DEFAULT_OFFSET_COMMIT_INTERVAL_MS = 1000; + + @Override + public void setup() { + super.setup(); + Map workerProps = new HashMap<>(); + workerProps.put("key.converter", "org.apache.kafka.connect.json.JsonConverter"); + workerProps.put("value.converter", "org.apache.kafka.connect.json.JsonConverter"); + workerProps.put("offset.storage.file.filename", "/tmp/connect.offsets"); + workerProps.put("offset.flush.interval.ms", + Long.toString(DEFAULT_OFFSET_COMMIT_INTERVAL_MS)); + WorkerConfig config = new StandaloneConfig(workerProps); + committer = new SourceTaskOffsetCommitter(config, executor, committers); + Whitebox.setInternalState(SourceTaskOffsetCommitter.class, "log", mockLog); + } + + @SuppressWarnings("unchecked") + @Test + public void testSchedule() { + Capture taskWrapper = EasyMock.newCapture(); + + EasyMock.expect(executor.scheduleWithFixedDelay( + EasyMock.capture(taskWrapper), eq(DEFAULT_OFFSET_COMMIT_INTERVAL_MS), + eq(DEFAULT_OFFSET_COMMIT_INTERVAL_MS), eq(TimeUnit.MILLISECONDS)) + ).andReturn((ScheduledFuture) commitFuture); + + PowerMock.replayAll(); + + committer.schedule(taskId, task); + assertTrue(taskWrapper.hasCaptured()); + assertNotNull(taskWrapper.getValue()); + assertEquals(singletonMap(taskId, commitFuture), committers); + + 
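+ // With the schedule registered, verifying the mocks confirms scheduleWithFixedDelay was invoked with the configured offset commit interval.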
PowerMock.verifyAll(); + } + + @Test + public void testClose() throws Exception { + long timeoutMs = 1000; + + // Normal termination, where termination times out. + executor.shutdown(); + PowerMock.expectLastCall(); + + EasyMock.expect(executor.awaitTermination(eq(timeoutMs), eq(TimeUnit.MILLISECONDS))) + .andReturn(false); + mockLog.error(EasyMock.anyString()); + PowerMock.expectLastCall(); + PowerMock.replayAll(); + + committer.close(timeoutMs); + + PowerMock.verifyAll(); + PowerMock.resetAll(); + + // Termination interrupted + executor.shutdown(); + PowerMock.expectLastCall(); + + EasyMock.expect(executor.awaitTermination(eq(timeoutMs), eq(TimeUnit.MILLISECONDS))) + .andThrow(new InterruptedException()); + PowerMock.replayAll(); + + committer.close(timeoutMs); + + PowerMock.verifyAll(); + } + + @Test + public void testRemove() throws Exception { + // Try to remove a non-existing task + PowerMock.replayAll(); + + assertTrue(committers.isEmpty()); + committer.remove(taskId); + assertTrue(committers.isEmpty()); + + PowerMock.verifyAll(); + PowerMock.resetAll(); + + // Try to remove an existing task + EasyMock.expect(taskFuture.cancel(eq(false))).andReturn(false); + EasyMock.expect(taskFuture.isDone()).andReturn(false); + EasyMock.expect(taskFuture.get()).andReturn(null); + EasyMock.expect(taskId.connector()).andReturn("MyConnector"); + EasyMock.expect(taskId.task()).andReturn(1); + PowerMock.replayAll(); + + committers.put(taskId, taskFuture); + committer.remove(taskId); + assertTrue(committers.isEmpty()); + + PowerMock.verifyAll(); + PowerMock.resetAll(); + + // Try to remove a cancelled task + EasyMock.expect(taskFuture.cancel(eq(false))).andReturn(false); + EasyMock.expect(taskFuture.isDone()).andReturn(false); + EasyMock.expect(taskFuture.get()).andThrow(new CancellationException()); + EasyMock.expect(taskId.connector()).andReturn("MyConnector"); + EasyMock.expect(taskId.task()).andReturn(1); + mockLog.trace(EasyMock.anyString(), EasyMock.anyObject()); + PowerMock.expectLastCall(); + PowerMock.replayAll(); + + committers.put(taskId, taskFuture); + committer.remove(taskId); + assertTrue(committers.isEmpty()); + + PowerMock.verifyAll(); + PowerMock.resetAll(); + + // Try to remove an interrupted task + EasyMock.expect(taskFuture.cancel(eq(false))).andReturn(false); + EasyMock.expect(taskFuture.isDone()).andReturn(false); + EasyMock.expect(taskFuture.get()).andThrow(new InterruptedException()); + EasyMock.expect(taskId.connector()).andReturn("MyConnector"); + EasyMock.expect(taskId.task()).andReturn(1); + PowerMock.replayAll(); + + try { + committers.put(taskId, taskFuture); + committer.remove(taskId); + fail("Expected ConnectException to be raised"); + } catch (ConnectException e) { + //ignore + } + + PowerMock.verifyAll(); + } + +} diff --git a/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/StateTrackerTest.java b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/StateTrackerTest.java new file mode 100644 index 0000000..eb12b4d --- /dev/null +++ b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/StateTrackerTest.java @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.runtime; + +import org.apache.kafka.connect.runtime.AbstractStatus.State; +import org.apache.kafka.common.utils.MockTime; +import org.junit.Before; +import org.junit.Test; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; + +public class StateTrackerTest { + + private static final double DELTA = 0.000001d; + + private StateTracker tracker; + private MockTime time; + private State state; + + @Before + public void setUp() { + time = new MockTime(); + time.sleep(1000L); + tracker = new StateTracker(); + state = State.UNASSIGNED; + } + + @Test + public void currentStateIsNullWhenNotInitialized() { + assertNull(tracker.currentState()); + } + + @Test + public void currentState() { + for (State state : State.values()) { + tracker.changeState(state, time.milliseconds()); + assertEquals(state, tracker.currentState()); + } + } + + @Test + public void calculateDurations() { + tracker.changeState(State.UNASSIGNED, time.milliseconds()); + time.sleep(1000L); + assertEquals(1.0d, tracker.durationRatio(State.UNASSIGNED, time.milliseconds()), DELTA); + assertEquals(0.0d, tracker.durationRatio(State.RUNNING, time.milliseconds()), DELTA); + assertEquals(0.0d, tracker.durationRatio(State.PAUSED, time.milliseconds()), DELTA); + assertEquals(0.0d, tracker.durationRatio(State.FAILED, time.milliseconds()), DELTA); + assertEquals(0.0d, tracker.durationRatio(State.DESTROYED, time.milliseconds()), DELTA); + assertEquals(0.0d, tracker.durationRatio(State.RESTARTING, time.milliseconds()), DELTA); + + tracker.changeState(State.RUNNING, time.milliseconds()); + time.sleep(3000L); + assertEquals(0.25d, tracker.durationRatio(State.UNASSIGNED, time.milliseconds()), DELTA); + assertEquals(0.75d, tracker.durationRatio(State.RUNNING, time.milliseconds()), DELTA); + assertEquals(0.0d, tracker.durationRatio(State.PAUSED, time.milliseconds()), DELTA); + assertEquals(0.0d, tracker.durationRatio(State.FAILED, time.milliseconds()), DELTA); + assertEquals(0.0d, tracker.durationRatio(State.DESTROYED, time.milliseconds()), DELTA); + assertEquals(0.0d, tracker.durationRatio(State.RESTARTING, time.milliseconds()), DELTA); + + tracker.changeState(State.PAUSED, time.milliseconds()); + time.sleep(4000L); + assertEquals(0.125d, tracker.durationRatio(State.UNASSIGNED, time.milliseconds()), DELTA); + assertEquals(0.375d, tracker.durationRatio(State.RUNNING, time.milliseconds()), DELTA); + assertEquals(0.500d, tracker.durationRatio(State.PAUSED, time.milliseconds()), DELTA); + assertEquals(0.0d, tracker.durationRatio(State.FAILED, time.milliseconds()), DELTA); + assertEquals(0.0d, tracker.durationRatio(State.DESTROYED, time.milliseconds()), DELTA); + assertEquals(0.0d, tracker.durationRatio(State.RESTARTING, time.milliseconds()), DELTA); + + tracker.changeState(State.RUNNING, time.milliseconds()); + time.sleep(8000L); + assertEquals(0.0625d, tracker.durationRatio(State.UNASSIGNED, time.milliseconds()), DELTA); + assertEquals(0.6875d, tracker.durationRatio(State.RUNNING, time.milliseconds()), DELTA); + assertEquals(0.2500d, 
tracker.durationRatio(State.PAUSED, time.milliseconds()), DELTA); + assertEquals(0.0d, tracker.durationRatio(State.FAILED, time.milliseconds()), DELTA); + assertEquals(0.0d, tracker.durationRatio(State.DESTROYED, time.milliseconds()), DELTA); + assertEquals(0.0d, tracker.durationRatio(State.RESTARTING, time.milliseconds()), DELTA); + + tracker.changeState(State.FAILED, time.milliseconds()); + time.sleep(16000L); + assertEquals(0.03125d, tracker.durationRatio(State.UNASSIGNED, time.milliseconds()), DELTA); + assertEquals(0.34375d, tracker.durationRatio(State.RUNNING, time.milliseconds()), DELTA); + assertEquals(0.12500d, tracker.durationRatio(State.PAUSED, time.milliseconds()), DELTA); + assertEquals(0.50000d, tracker.durationRatio(State.FAILED, time.milliseconds()), DELTA); + assertEquals(0.0d, tracker.durationRatio(State.DESTROYED, time.milliseconds()), DELTA); + assertEquals(0.0d, tracker.durationRatio(State.RESTARTING, time.milliseconds()), DELTA); + + } + +} \ No newline at end of file diff --git a/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/SubmittedRecordsTest.java b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/SubmittedRecordsTest.java new file mode 100644 index 0000000..4028249 --- /dev/null +++ b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/SubmittedRecordsTest.java @@ -0,0 +1,381 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.connect.runtime; + +import org.apache.kafka.connect.runtime.SubmittedRecords.SubmittedRecord; +import org.junit.Before; +import org.junit.Test; + +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; + +import static org.apache.kafka.connect.runtime.SubmittedRecords.CommittableOffsets; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +public class SubmittedRecordsTest { + + private static final Map PARTITION1 = Collections.singletonMap("subreddit", "apachekafka"); + private static final Map PARTITION2 = Collections.singletonMap("subreddit", "adifferentvalue"); + private static final Map PARTITION3 = Collections.singletonMap("subreddit", "asdfqweoicus"); + + private AtomicInteger offset; + + SubmittedRecords submittedRecords; + + @Before + public void setup() { + submittedRecords = new SubmittedRecords(); + offset = new AtomicInteger(); + } + + @Test + public void testNoRecords() { + CommittableOffsets committableOffsets = submittedRecords.committableOffsets(); + assertTrue(committableOffsets.isEmpty()); + + committableOffsets = submittedRecords.committableOffsets(); + assertTrue(committableOffsets.isEmpty()); + + committableOffsets = submittedRecords.committableOffsets(); + assertTrue(committableOffsets.isEmpty()); + + assertNoRemainingDeques(); + } + + @Test + public void testNoCommittedRecords() { + for (int i = 0; i < 3; i++) { + for (Map partition : Arrays.asList(PARTITION1, PARTITION2, PARTITION3)) { + submittedRecords.submit(partition, newOffset()); + } + } + + CommittableOffsets committableOffsets = submittedRecords.committableOffsets(); + assertMetadata(committableOffsets, 0, 9, 3, 3, PARTITION1, PARTITION2, PARTITION3); + assertEquals(Collections.emptyMap(), committableOffsets.offsets()); + + committableOffsets = submittedRecords.committableOffsets(); + assertMetadata(committableOffsets, 0, 9, 3, 3, PARTITION1, PARTITION2, PARTITION3); + assertEquals(Collections.emptyMap(), committableOffsets.offsets()); + + committableOffsets = submittedRecords.committableOffsets(); + assertMetadata(committableOffsets, 0, 9, 3, 3, PARTITION1, PARTITION2, PARTITION3); + assertEquals(Collections.emptyMap(), committableOffsets.offsets()); + } + + @Test + public void testSingleAck() { + Map offset = newOffset(); + + SubmittedRecord submittedRecord = submittedRecords.submit(PARTITION1, offset); + CommittableOffsets committableOffsets = submittedRecords.committableOffsets(); + // Record has been submitted but not yet acked; cannot commit offsets for it yet + assertFalse(committableOffsets.isEmpty()); + assertEquals(Collections.emptyMap(), committableOffsets.offsets()); + assertMetadata(committableOffsets, 0, 1, 1, 1, PARTITION1); + assertNoEmptyDeques(); + + submittedRecord.ack(); + committableOffsets = submittedRecords.committableOffsets(); + // Record has been acked; can commit offsets for it + assertFalse(committableOffsets.isEmpty()); + assertEquals(Collections.singletonMap(PARTITION1, offset), committableOffsets.offsets()); + assertMetadataNoPending(committableOffsets, 1); + + // Everything has been ack'd and consumed; make sure that it's been cleaned up to avoid memory leaks + assertNoRemainingDeques(); + + committableOffsets = 
submittedRecords.committableOffsets(); + // Old offsets should be wiped + assertEquals(Collections.emptyMap(), committableOffsets.offsets()); + assertTrue(committableOffsets.isEmpty()); + } + + @Test + public void testMultipleAcksAcrossMultiplePartitions() { + Map partition1Offset1 = newOffset(); + Map partition1Offset2 = newOffset(); + Map partition2Offset1 = newOffset(); + Map partition2Offset2 = newOffset(); + + SubmittedRecord partition1Record1 = submittedRecords.submit(PARTITION1, partition1Offset1); + SubmittedRecord partition1Record2 = submittedRecords.submit(PARTITION1, partition1Offset2); + SubmittedRecord partition2Record1 = submittedRecords.submit(PARTITION2, partition2Offset1); + SubmittedRecord partition2Record2 = submittedRecords.submit(PARTITION2, partition2Offset2); + + CommittableOffsets committableOffsets = submittedRecords.committableOffsets(); + // No records ack'd yet; can't commit any offsets + assertEquals(Collections.emptyMap(), committableOffsets.offsets()); + assertMetadata(committableOffsets, 0, 4, 2, 2, PARTITION1, PARTITION2); + assertNoEmptyDeques(); + + partition1Record2.ack(); + committableOffsets = submittedRecords.committableOffsets(); + // One record has been ack'd, but a record that comes before it and corresponds to the same source partition hasn't been + assertEquals(Collections.emptyMap(), committableOffsets.offsets()); + assertMetadata(committableOffsets, 0, 4, 2, 2, PARTITION1, PARTITION2); + assertNoEmptyDeques(); + + partition2Record1.ack(); + committableOffsets = submittedRecords.committableOffsets(); + // We can commit the first offset for the second partition + assertEquals(Collections.singletonMap(PARTITION2, partition2Offset1), committableOffsets.offsets()); + assertMetadata(committableOffsets, 1, 3, 2, 2, PARTITION1); + assertNoEmptyDeques(); + + committableOffsets = submittedRecords.committableOffsets(); + // No new offsets to commit + assertEquals(Collections.emptyMap(), committableOffsets.offsets()); + assertMetadata(committableOffsets, 0, 3, 2, 2, PARTITION1); + assertNoEmptyDeques(); + + partition1Record1.ack(); + partition2Record2.ack(); + + committableOffsets = submittedRecords.committableOffsets(); + // We can commit new offsets for both partitions now + Map, Map> expectedOffsets = new HashMap<>(); + expectedOffsets.put(PARTITION1, partition1Offset2); + expectedOffsets.put(PARTITION2, partition2Offset2); + assertEquals(expectedOffsets, committableOffsets.offsets()); + assertMetadataNoPending(committableOffsets, 3); + + // Everything has been ack'd and consumed; make sure that it's been cleaned up to avoid memory leaks + assertNoRemainingDeques(); + + committableOffsets = submittedRecords.committableOffsets(); + // No new offsets to commit + assertTrue(committableOffsets.isEmpty()); + } + + @Test + public void testRemoveLastSubmittedRecord() { + SubmittedRecord submittedRecord = submittedRecords.submit(PARTITION1, newOffset()); + + CommittableOffsets committableOffsets = submittedRecords.committableOffsets(); + assertEquals(Collections.emptyMap(), committableOffsets.offsets()); + assertMetadata(committableOffsets, 0, 1, 1, 1, PARTITION1); + + assertTrue("First attempt to remove record from submitted queue should succeed", submittedRecords.removeLastOccurrence(submittedRecord)); + assertFalse("Attempt to remove already-removed record from submitted queue should fail", submittedRecords.removeLastOccurrence(submittedRecord)); + + committableOffsets = submittedRecords.committableOffsets(); + // Even if SubmittedRecords::remove is broken, 
we haven't ack'd anything yet, so there should be no committable offsets + assertTrue(committableOffsets.isEmpty()); + + submittedRecord.ack(); + committableOffsets = submittedRecords.committableOffsets(); + // Even though the record has somehow been acknowledged, it should not be counted when collecting committable offsets + assertTrue(committableOffsets.isEmpty()); + } + + @Test + public void testRemoveNotLastSubmittedRecord() { + Map partition1Offset = newOffset(); + Map partition2Offset = newOffset(); + + SubmittedRecord recordToRemove = submittedRecords.submit(PARTITION1, partition1Offset); + SubmittedRecord lastSubmittedRecord = submittedRecords.submit(PARTITION2, partition2Offset); + + CommittableOffsets committableOffsets = submittedRecords.committableOffsets(); + assertMetadata(committableOffsets, 0, 2, 2, 1, PARTITION1, PARTITION2); + assertNoEmptyDeques(); + + assertTrue("First attempt to remove record from submitted queue should succeed", submittedRecords.removeLastOccurrence(recordToRemove)); + + committableOffsets = submittedRecords.committableOffsets(); + // Even if SubmittedRecords::remove is broken, we haven't ack'd anything yet, so there should be no committable offsets + assertEquals(Collections.emptyMap(), committableOffsets.offsets()); + assertMetadata(committableOffsets, 0, 1, 1, 1, PARTITION2); + assertNoEmptyDeques(); + // The only record for this partition has been removed; we shouldn't be tracking a deque for it anymore + assertRemovedDeques(PARTITION1); + + recordToRemove.ack(); + committableOffsets = submittedRecords.committableOffsets(); + // Even though the record has somehow been acknowledged, it should not be counted when collecting committable offsets + assertEquals(Collections.emptyMap(), committableOffsets.offsets()); + assertMetadata(committableOffsets, 0, 1, 1, 1, PARTITION2); + assertNoEmptyDeques(); + + lastSubmittedRecord.ack(); + committableOffsets = submittedRecords.committableOffsets(); + // Now that the last-submitted record has been ack'd, we should be able to commit its offset + assertEquals(Collections.singletonMap(PARTITION2, partition2Offset), committableOffsets.offsets()); + assertMetadata(committableOffsets, 1, 0, 0, 0, (Map) null); + assertFalse(committableOffsets.hasPending()); + + // Everything has been ack'd and consumed; make sure that it's been cleaned up to avoid memory leaks + assertNoRemainingDeques(); + committableOffsets = submittedRecords.committableOffsets(); + assertTrue(committableOffsets.isEmpty()); + } + + @Test + public void testNullPartitionAndOffset() { + SubmittedRecord submittedRecord = submittedRecords.submit(null, null); + CommittableOffsets committableOffsets = submittedRecords.committableOffsets(); + assertMetadata(committableOffsets, 0, 1, 1, 1, (Map) null); + + submittedRecord.ack(); + committableOffsets = submittedRecords.committableOffsets(); + assertEquals(Collections.singletonMap(null, null), committableOffsets.offsets()); + assertMetadataNoPending(committableOffsets, 1); + + assertNoEmptyDeques(); + } + + @Test + public void testAwaitMessagesNoneSubmitted() { + assertTrue(submittedRecords.awaitAllMessages(0, TimeUnit.MILLISECONDS)); + } + + @Test + public void testAwaitMessagesAfterAllAcknowledged() { + SubmittedRecord recordToAck = submittedRecords.submit(PARTITION1, newOffset()); + assertFalse(submittedRecords.awaitAllMessages(0, TimeUnit.MILLISECONDS)); + recordToAck.ack(); + assertTrue(submittedRecords.awaitAllMessages(0, TimeUnit.MILLISECONDS)); + } + + @Test + public void 
testAwaitMessagesAfterAllRemoved() { + SubmittedRecord recordToRemove1 = submittedRecords.submit(PARTITION1, newOffset()); + SubmittedRecord recordToRemove2 = submittedRecords.submit(PARTITION1, newOffset()); + assertFalse( + "Await should fail since neither of the in-flight records has been removed so far", + submittedRecords.awaitAllMessages(0, TimeUnit.MILLISECONDS) + ); + + submittedRecords.removeLastOccurrence(recordToRemove1); + assertFalse( + "Await should fail since only one of the two submitted records has been removed so far", + submittedRecords.awaitAllMessages(0, TimeUnit.MILLISECONDS) + ); + + submittedRecords.removeLastOccurrence(recordToRemove1); + assertFalse( + "Await should fail since only one of the two submitted records has been removed so far, " + + "even though that record has been removed twice", + submittedRecords.awaitAllMessages(0, TimeUnit.MILLISECONDS) + ); + + submittedRecords.removeLastOccurrence(recordToRemove2); + assertTrue( + "Await should succeed since both submitted records have now been removed", + submittedRecords.awaitAllMessages(0, TimeUnit.MILLISECONDS) + ); + } + + @Test + public void testAwaitMessagesTimesOut() { + submittedRecords.submit(PARTITION1, newOffset()); + assertFalse(submittedRecords.awaitAllMessages(10, TimeUnit.MILLISECONDS)); + } + + @Test + public void testAwaitMessagesReturnsAfterAsynchronousAck() throws Exception { + SubmittedRecord inFlightRecord1 = submittedRecords.submit(PARTITION1, newOffset()); + SubmittedRecord inFlightRecord2 = submittedRecords.submit(PARTITION2, newOffset()); + + AtomicBoolean awaitResult = new AtomicBoolean(); + CountDownLatch awaitComplete = new CountDownLatch(1); + new Thread(() -> { + awaitResult.set(submittedRecords.awaitAllMessages(5, TimeUnit.SECONDS)); + awaitComplete.countDown(); + }).start(); + + assertTrue( + "Should not have finished awaiting message delivery before either in-flight record was acknowledged", + awaitComplete.getCount() > 0 + ); + + inFlightRecord1.ack(); + assertTrue( + "Should not have finished awaiting message delivery before one in-flight record was acknowledged", + awaitComplete.getCount() > 0 + ); + + inFlightRecord1.ack(); + assertTrue( + "Should not have finished awaiting message delivery before one in-flight record was acknowledged, " + + "even though the other record has been acknowledged twice", + awaitComplete.getCount() > 0 + ); + + inFlightRecord2.ack(); + assertTrue( + "Should have finished awaiting message delivery after both in-flight records were acknowledged", + awaitComplete.await(1, TimeUnit.SECONDS) + ); + assertTrue( + "Await of in-flight messages should have succeeded", + awaitResult.get() + ); + } + + private void assertNoRemainingDeques() { + assertEquals("Internal records map should be completely empty", Collections.emptyMap(), submittedRecords.records); + } + + @SafeVarargs + private final void assertRemovedDeques(Map... 
partitions) { + for (Map partition : partitions) { + assertFalse("Deque for partition " + partition + " should have been cleaned up from internal records map", submittedRecords.records.containsKey(partition)); + } + } + + private void assertNoEmptyDeques() { + submittedRecords.records.forEach((partition, deque) -> + assertFalse("Empty deque for partition " + partition + " should have been cleaned up from internal records map", deque.isEmpty()) + ); + } + + private Map newOffset() { + return Collections.singletonMap("timestamp", offset.getAndIncrement()); + } + + private void assertMetadataNoPending(CommittableOffsets committableOffsets, int committableMessages) { + assertEquals(committableMessages, committableOffsets.numCommittableMessages()); + assertFalse(committableOffsets.hasPending()); + } + + @SafeVarargs + @SuppressWarnings("varargs") + private final void assertMetadata( + CommittableOffsets committableOffsets, + int committableMessages, + int uncommittableMessages, + int numDeques, + int largestDequeSize, + Map... largestDequePartitions + ) { + assertEquals(committableMessages, committableOffsets.numCommittableMessages()); + assertEquals(uncommittableMessages, committableOffsets.numUncommittableMessages()); + assertEquals(numDeques, committableOffsets.numDeques()); + assertEquals(largestDequeSize, committableOffsets.largestDequeSize()); + assertTrue(Arrays.asList(largestDequePartitions).contains(committableOffsets.largestDequePartition())); + } +} diff --git a/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/TestConverterWithHeaders.java b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/TestConverterWithHeaders.java new file mode 100644 index 0000000..91e0999 --- /dev/null +++ b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/TestConverterWithHeaders.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.connect.runtime; + +import java.io.UnsupportedEncodingException; +import java.util.Map; +import org.apache.kafka.common.header.Header; +import org.apache.kafka.common.header.Headers; +import org.apache.kafka.connect.data.Schema; +import org.apache.kafka.connect.data.SchemaAndValue; +import org.apache.kafka.connect.errors.DataException; +import org.apache.kafka.connect.storage.Converter; + +/** + * This is a simple Converter implementation that uses "encoding" header to encode/decode strings via provided charset name + */ +public class TestConverterWithHeaders implements Converter { + private static final String HEADER_ENCODING = "encoding"; + + @Override + public void configure(Map configs, boolean isKey) { + + } + + @Override + public SchemaAndValue toConnectData(String topic, Headers headers, byte[] value) { + String encoding = extractEncoding(headers); + + try { + return new SchemaAndValue(Schema.STRING_SCHEMA, new String(value, encoding)); + } catch (UnsupportedEncodingException e) { + throw new DataException("Unsupported encoding: " + encoding, e); + } + } + + @Override + public byte[] fromConnectData(String topic, Headers headers, Schema schema, Object value) { + String encoding = extractEncoding(headers); + + try { + return ((String) value).getBytes(encoding); + } catch (UnsupportedEncodingException e) { + throw new DataException("Unsupported encoding: " + encoding, e); + } + } + + private String extractEncoding(Headers headers) { + Header header = headers.lastHeader(HEADER_ENCODING); + if (header == null) { + throw new DataException("Header '" + HEADER_ENCODING + "' is required!"); + } + + return new String(header.value()); + } + + + @Override + public SchemaAndValue toConnectData(String topic, byte[] value) { + throw new DataException("Headers are required for this converter!"); + } + + @Override + public byte[] fromConnectData(String topic, Schema schema, Object value) { + throw new DataException("Headers are required for this converter!"); + } +} diff --git a/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/TestSinkConnector.java b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/TestSinkConnector.java new file mode 100644 index 0000000..a6e3bb1 --- /dev/null +++ b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/TestSinkConnector.java @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.connect.runtime; + +import org.apache.kafka.common.config.ConfigDef; +import org.apache.kafka.connect.connector.Task; +import org.apache.kafka.connect.sink.SinkConnector; + +import java.util.List; +import java.util.Map; + +public class TestSinkConnector extends SinkConnector { + + public static final String VERSION = "some great version"; + + @Override + public String version() { + return VERSION; + } + + @Override + public void start(Map props) { + + } + + @Override + public Class taskClass() { + return null; + } + + @Override + public List> taskConfigs(int maxTasks) { + return null; + } + + @Override + public void stop() { + + } + + @Override + public ConfigDef config() { + return new ConfigDef() + .define("required", ConfigDef.Type.STRING, ConfigDef.Importance.HIGH, "required docs") + .define("optional", ConfigDef.Type.STRING, "defaultVal", ConfigDef.Importance.HIGH, "optional docs"); + } +} diff --git a/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/TestSourceConnector.java b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/TestSourceConnector.java new file mode 100644 index 0000000..5f754e2 --- /dev/null +++ b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/TestSourceConnector.java @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.runtime; + +import org.apache.kafka.common.config.ConfigDef; +import org.apache.kafka.connect.connector.Task; +import org.apache.kafka.connect.source.SourceConnector; + +import java.util.List; +import java.util.Map; + +public class TestSourceConnector extends SourceConnector { + + public static final String VERSION = "an entirely different version"; + + @Override + public String version() { + return VERSION; + } + + @Override + public void start(Map props) { + + } + + @Override + public Class taskClass() { + return null; + } + + @Override + public List> taskConfigs(int maxTasks) { + return null; + } + + @Override + public void stop() { + + } + + @Override + public ConfigDef config() { + return new ConfigDef() + .define("required", ConfigDef.Type.STRING, ConfigDef.Importance.HIGH, "required docs") + .define("optional", ConfigDef.Type.STRING, "defaultVal", ConfigDef.Importance.HIGH, "optional docs"); + } +} diff --git a/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/TransformationConfigTest.java b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/TransformationConfigTest.java new file mode 100644 index 0000000..1d63d7d --- /dev/null +++ b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/TransformationConfigTest.java @@ -0,0 +1,211 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.runtime; + +import org.apache.kafka.connect.runtime.isolation.Plugins; +import org.apache.kafka.connect.tools.MockConnector; +import org.apache.kafka.connect.transforms.Cast; +import org.apache.kafka.connect.transforms.ExtractField; +import org.apache.kafka.connect.transforms.Flatten; +import org.apache.kafka.connect.transforms.HoistField; +import org.apache.kafka.connect.transforms.InsertField; +import org.apache.kafka.connect.transforms.MaskField; +import org.apache.kafka.connect.transforms.RegexRouter; +import org.apache.kafka.connect.transforms.ReplaceField; +import org.apache.kafka.connect.transforms.SetSchemaMetadata; +import org.apache.kafka.connect.transforms.TimestampConverter; +import org.apache.kafka.connect.transforms.TimestampRouter; +import org.apache.kafka.connect.transforms.ValueToKey; +import org.junit.Test; + +import java.util.HashMap; + +/** + * Tests that transformations' configs can be composed with ConnectorConfig during its construction, ensuring no + * conflicting fields or other issues. + * + * This test appears here simply because it requires both connect-runtime and connect-transforms and connect-runtime + * already depends on connect-transforms. 
+ */ +public class TransformationConfigTest { + + @Test + public void testEmbeddedConfigCast() { + // Validate that we can construct a Connector config containing the extended config for the transform + HashMap connProps = new HashMap<>(); + connProps.put("name", "foo"); + connProps.put("connector.class", MockConnector.class.getName()); + connProps.put("transforms", "example"); + connProps.put("transforms.example.type", Cast.Value.class.getName()); + connProps.put("transforms.example.spec", "int8"); + + Plugins plugins = null; // Safe when we're only constructing the config + new ConnectorConfig(plugins, connProps); + } + + @Test + public void testEmbeddedConfigExtractField() { + // Validate that we can construct a Connector config containing the extended config for the transform + HashMap connProps = new HashMap<>(); + connProps.put("name", "foo"); + connProps.put("connector.class", MockConnector.class.getName()); + connProps.put("transforms", "example"); + connProps.put("transforms.example.type", ExtractField.Value.class.getName()); + connProps.put("transforms.example.field", "field"); + + Plugins plugins = null; // Safe when we're only constructing the config + new ConnectorConfig(plugins, connProps); + } + + @Test + public void testEmbeddedConfigFlatten() { + // Validate that we can construct a Connector config containing the extended config for the transform + HashMap connProps = new HashMap<>(); + connProps.put("name", "foo"); + connProps.put("connector.class", MockConnector.class.getName()); + connProps.put("transforms", "example"); + connProps.put("transforms.example.type", Flatten.Value.class.getName()); + + Plugins plugins = null; // Safe when we're only constructing the config + new ConnectorConfig(plugins, connProps); + } + + @Test + public void testEmbeddedConfigHoistField() { + // Validate that we can construct a Connector config containing the extended config for the transform + HashMap connProps = new HashMap<>(); + connProps.put("name", "foo"); + connProps.put("connector.class", MockConnector.class.getName()); + connProps.put("transforms", "example"); + connProps.put("transforms.example.type", HoistField.Value.class.getName()); + connProps.put("transforms.example.field", "field"); + + Plugins plugins = null; // Safe when we're only constructing the config + new ConnectorConfig(plugins, connProps); + } + + @Test + public void testEmbeddedConfigInsertField() { + // Validate that we can construct a Connector config containing the extended config for the transform + HashMap connProps = new HashMap<>(); + connProps.put("name", "foo"); + connProps.put("connector.class", MockConnector.class.getName()); + connProps.put("transforms", "example"); + connProps.put("transforms.example.type", InsertField.Value.class.getName()); + + Plugins plugins = null; // Safe when we're only constructing the config + new ConnectorConfig(plugins, connProps); + } + + @Test + public void testEmbeddedConfigMaskField() { + // Validate that we can construct a Connector config containing the extended config for the transform + HashMap connProps = new HashMap<>(); + connProps.put("name", "foo"); + connProps.put("connector.class", MockConnector.class.getName()); + connProps.put("transforms", "example"); + connProps.put("transforms.example.type", MaskField.Value.class.getName()); + connProps.put("transforms.example.fields", "field"); + connProps.put("transforms.example.replacement", "nothing"); + + Plugins plugins = null; // Safe when we're only constructing the config + new ConnectorConfig(plugins, 
connProps); + } + + @Test + public void testEmbeddedConfigRegexRouter() { + // Validate that we can construct a Connector config containing the extended config for the transform + HashMap connProps = new HashMap<>(); + connProps.put("name", "foo"); + connProps.put("connector.class", MockConnector.class.getName()); + connProps.put("transforms", "example"); + connProps.put("transforms.example.type", RegexRouter.class.getName()); + connProps.put("transforms.example.regex", "(.*)"); + connProps.put("transforms.example.replacement", "prefix-$1"); + + Plugins plugins = null; // Safe when we're only constructing the config + new ConnectorConfig(plugins, connProps); + } + + @Test + public void testEmbeddedConfigReplaceField() { + // Validate that we can construct a Connector config containing the extended config for the transform + HashMap connProps = new HashMap<>(); + connProps.put("name", "foo"); + connProps.put("connector.class", MockConnector.class.getName()); + connProps.put("transforms", "example"); + connProps.put("transforms.example.type", ReplaceField.Value.class.getName()); + + Plugins plugins = null; // Safe when we're only constructing the config + new ConnectorConfig(plugins, connProps); + } + + @Test + public void testEmbeddedConfigSetSchemaMetadata() { + // Validate that we can construct a Connector config containing the extended config for the transform + HashMap connProps = new HashMap<>(); + connProps.put("name", "foo"); + connProps.put("connector.class", MockConnector.class.getName()); + connProps.put("transforms", "example"); + connProps.put("transforms.example.type", SetSchemaMetadata.Value.class.getName()); + + Plugins plugins = null; // Safe when we're only constructing the config + new ConnectorConfig(plugins, connProps); + } + + @Test + public void testEmbeddedConfigTimestampConverter() { + // Validate that we can construct a Connector config containing the extended config for the transform + HashMap connProps = new HashMap<>(); + connProps.put("name", "foo"); + connProps.put("connector.class", MockConnector.class.getName()); + connProps.put("transforms", "example"); + connProps.put("transforms.example.type", TimestampConverter.Value.class.getName()); + connProps.put("transforms.example.target.type", "unix"); + + Plugins plugins = null; // Safe when we're only constructing the config + new ConnectorConfig(plugins, connProps); + } + + @Test + public void testEmbeddedConfigTimestampRouter() { + // Validate that we can construct a Connector config containing the extended config for the transform + HashMap connProps = new HashMap<>(); + connProps.put("name", "foo"); + connProps.put("connector.class", MockConnector.class.getName()); + connProps.put("transforms", "example"); + connProps.put("transforms.example.type", TimestampRouter.class.getName()); + + Plugins plugins = null; // Safe when we're only constructing the config + new ConnectorConfig(plugins, connProps); + } + + @Test + public void testEmbeddedConfigValueToKey() { + // Validate that we can construct a Connector config containing the extended config for the transform + HashMap connProps = new HashMap<>(); + connProps.put("name", "foo"); + connProps.put("connector.class", MockConnector.class.getName()); + connProps.put("transforms", "example"); + connProps.put("transforms.example.type", ValueToKey.class.getName()); + connProps.put("transforms.example.fields", "field"); + + Plugins plugins = null; // Safe when we're only constructing the config + new ConnectorConfig(plugins, connProps); + } + +} diff --git 
a/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/WorkerConfigTest.java b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/WorkerConfigTest.java new file mode 100644 index 0000000..fbe6800 --- /dev/null +++ b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/WorkerConfigTest.java @@ -0,0 +1,166 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.runtime; + +import org.apache.kafka.clients.CommonClientConfigs; +import org.apache.kafka.common.config.ConfigException; +import org.junit.Test; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.Map; +import java.util.List; + +import static org.apache.kafka.connect.runtime.WorkerConfig.LISTENERS_DEFAULT; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.assertThrows; + +public class WorkerConfigTest { + private static final List VALID_HEADER_CONFIGS = Arrays.asList( + "add \t Cache-Control: no-cache, no-store, must-revalidate", + "add \r X-XSS-Protection: 1; mode=block", + "\n add Strict-Transport-Security: max-age=31536000; includeSubDomains", + "AdD Strict-Transport-Security: \r max-age=31536000; includeSubDomains", + "AdD \t Strict-Transport-Security : \n max-age=31536000; includeSubDomains", + "add X-Content-Type-Options: \r nosniff", + "Set \t X-Frame-Options: \t Deny\n ", + "seT \t X-Cache-Info: \t not cacheable\n ", + "seTDate \t Expires: \r 31540000000", + "adDdate \n Last-Modified: \t 0" + ); + + private static final List INVALID_HEADER_CONFIGS = Arrays.asList( + "set \t", + "badaction \t X-Frame-Options:DENY", + "set add X-XSS-Protection:1", + "addX-XSS-Protection", + "X-XSS-Protection:", + "add set X-XSS-Protection: 1", + "add X-XSS-Protection:1 X-XSS-Protection:1 ", + "add X-XSS-Protection", + "set X-Frame-Options:DENY, add :no-cache, no-store, must-revalidate " + ); + + @Test + public void testListenersConfigAllowedValues() { + Map props = baseProps(); + + // no value set for "listeners" + WorkerConfig config = new WorkerConfig(WorkerConfig.baseConfigDef(), props); + assertEquals(LISTENERS_DEFAULT, config.getList(WorkerConfig.LISTENERS_CONFIG)); + + props.put(WorkerConfig.LISTENERS_CONFIG, "http://a.b:9999"); + config = new WorkerConfig(WorkerConfig.baseConfigDef(), props); + assertEquals(Arrays.asList("http://a.b:9999"), config.getList(WorkerConfig.LISTENERS_CONFIG)); + + props.put(WorkerConfig.LISTENERS_CONFIG, "http://a.b:9999, https://a.b:7812"); + config = new WorkerConfig(WorkerConfig.baseConfigDef(), props); + assertEquals(Arrays.asList("http://a.b:9999", "https://a.b:7812"), config.getList(WorkerConfig.LISTENERS_CONFIG)); + + new 
WorkerConfig(WorkerConfig.baseConfigDef(), props); + } + + @Test + public void testListenersConfigNotAllowedValues() { + Map props = baseProps(); + assertEquals(LISTENERS_DEFAULT, new WorkerConfig(WorkerConfig.baseConfigDef(), props).getList(WorkerConfig.LISTENERS_CONFIG)); + + props.put(WorkerConfig.LISTENERS_CONFIG, ""); + ConfigException ce = assertThrows(ConfigException.class, () -> new WorkerConfig(WorkerConfig.baseConfigDef(), props)); + assertTrue(ce.getMessage().contains(" listeners")); + + props.put(WorkerConfig.LISTENERS_CONFIG, ",,,"); + ce = assertThrows(ConfigException.class, () -> new WorkerConfig(WorkerConfig.baseConfigDef(), props)); + assertTrue(ce.getMessage().contains(" listeners")); + + props.put(WorkerConfig.LISTENERS_CONFIG, "http://a.b:9999,"); + ce = assertThrows(ConfigException.class, () -> new WorkerConfig(WorkerConfig.baseConfigDef(), props)); + assertTrue(ce.getMessage().contains(" listeners")); + + props.put(WorkerConfig.LISTENERS_CONFIG, "http://a.b:9999, ,https://a.b:9999"); + ce = assertThrows(ConfigException.class, () -> new WorkerConfig(WorkerConfig.baseConfigDef(), props)); + assertTrue(ce.getMessage().contains(" listeners")); + } + + @Test + public void testAdminListenersConfigAllowedValues() { + Map props = baseProps(); + + // no value set for "admin.listeners" + WorkerConfig config = new WorkerConfig(WorkerConfig.baseConfigDef(), props); + assertNull("Default value should be null.", config.getList(WorkerConfig.ADMIN_LISTENERS_CONFIG)); + + props.put(WorkerConfig.ADMIN_LISTENERS_CONFIG, ""); + config = new WorkerConfig(WorkerConfig.baseConfigDef(), props); + assertTrue(config.getList(WorkerConfig.ADMIN_LISTENERS_CONFIG).isEmpty()); + + props.put(WorkerConfig.ADMIN_LISTENERS_CONFIG, "http://a.b:9999, https://a.b:7812"); + config = new WorkerConfig(WorkerConfig.baseConfigDef(), props); + assertEquals(Arrays.asList("http://a.b:9999", "https://a.b:7812"), config.getList(WorkerConfig.ADMIN_LISTENERS_CONFIG)); + + new WorkerConfig(WorkerConfig.baseConfigDef(), props); + } + + @Test + public void testAdminListenersNotAllowingEmptyStrings() { + Map props = baseProps(); + + props.put(WorkerConfig.ADMIN_LISTENERS_CONFIG, "http://a.b:9999,"); + ConfigException ce = assertThrows(ConfigException.class, () -> new WorkerConfig(WorkerConfig.baseConfigDef(), props)); + assertTrue(ce.getMessage().contains(" admin.listeners")); + } + + @Test + public void testAdminListenersNotAllowingBlankStrings() { + Map props = baseProps(); + props.put(WorkerConfig.ADMIN_LISTENERS_CONFIG, "http://a.b:9999, ,https://a.b:9999"); + assertThrows(ConfigException.class, () -> new WorkerConfig(WorkerConfig.baseConfigDef(), props)); + } + + @Test + public void testInvalidHeaderConfigs() { + for (String config : INVALID_HEADER_CONFIGS) { + assertInvalidHeaderConfig(config); + } + } + + @Test + public void testValidHeaderConfigs() { + for (String config : VALID_HEADER_CONFIGS) { + assertValidHeaderConfig(config); + } + } + + private void assertInvalidHeaderConfig(String config) { + assertThrows(ConfigException.class, () -> WorkerConfig.validateHttpResponseHeaderConfig(config)); + } + + private void assertValidHeaderConfig(String config) { + WorkerConfig.validateHttpResponseHeaderConfig(config); + } + + private Map baseProps() { + Map props = new HashMap<>(); + props.put(CommonClientConfigs.BOOTSTRAP_SERVERS_CONFIG, "localhost:9092"); + props.put(WorkerConfig.KEY_CONVERTER_CLASS_CONFIG, "org.apache.kafka.connect.json.JsonConverter"); + props.put(WorkerConfig.VALUE_CONVERTER_CLASS_CONFIG, 
"org.apache.kafka.connect.json.JsonConverter"); + return props; + } + +} diff --git a/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/WorkerConfigTransformerTest.java b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/WorkerConfigTransformerTest.java new file mode 100644 index 0000000..6f4bda6 --- /dev/null +++ b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/WorkerConfigTransformerTest.java @@ -0,0 +1,154 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.runtime; + +import org.apache.kafka.common.config.ConfigChangeCallback; +import org.apache.kafka.common.config.ConfigData; +import org.apache.kafka.common.config.provider.ConfigProvider; +import org.easymock.EasyMock; +import static org.easymock.EasyMock.eq; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.powermock.api.easymock.PowerMock; +import org.powermock.api.easymock.annotation.Mock; +import org.powermock.modules.junit4.PowerMockRunner; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.Set; + +import static org.apache.kafka.connect.runtime.ConnectorConfig.CONFIG_RELOAD_ACTION_CONFIG; +import static org.apache.kafka.connect.runtime.ConnectorConfig.CONFIG_RELOAD_ACTION_NONE; +import static org.easymock.EasyMock.notNull; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; +import static org.powermock.api.easymock.PowerMock.replayAll; + +@RunWith(PowerMockRunner.class) +public class WorkerConfigTransformerTest { + + public static final String MY_KEY = "myKey"; + public static final String MY_CONNECTOR = "myConnector"; + public static final String TEST_KEY = "testKey"; + public static final String TEST_PATH = "testPath"; + public static final String TEST_KEY_WITH_TTL = "testKeyWithTTL"; + public static final String TEST_KEY_WITH_LONGER_TTL = "testKeyWithLongerTTL"; + public static final String TEST_RESULT = "testResult"; + public static final String TEST_RESULT_WITH_TTL = "testResultWithTTL"; + public static final String TEST_RESULT_WITH_LONGER_TTL = "testResultWithLongerTTL"; + + @Mock private Herder herder; + @Mock private Worker worker; + @Mock private HerderRequest requestId; + private WorkerConfigTransformer configTransformer; + + @Before + public void setup() { + worker = PowerMock.createMock(Worker.class); + herder = PowerMock.createMock(Herder.class); + configTransformer = new WorkerConfigTransformer(worker, Collections.singletonMap("test", new TestConfigProvider())); + } + + @Test + public void testReplaceVariable() { + Map result = configTransformer.transform(MY_CONNECTOR, Collections.singletonMap(MY_KEY, "${test:testPath:testKey}")); + assertEquals(TEST_RESULT, 
result.get(MY_KEY)); + } + + @Test + public void testReplaceVariableWithTTL() { + EasyMock.expect(worker.herder()).andReturn(herder); + + replayAll(); + + Map props = new HashMap<>(); + props.put(MY_KEY, "${test:testPath:testKeyWithTTL}"); + props.put(CONFIG_RELOAD_ACTION_CONFIG, CONFIG_RELOAD_ACTION_NONE); + Map result = configTransformer.transform(MY_CONNECTOR, props); + } + + @Test + public void testReplaceVariableWithTTLAndScheduleRestart() { + EasyMock.expect(worker.herder()).andReturn(herder); + EasyMock.expect(herder.restartConnector(eq(1L), eq(MY_CONNECTOR), notNull())).andReturn(requestId); + replayAll(); + + Map result = configTransformer.transform(MY_CONNECTOR, Collections.singletonMap(MY_KEY, "${test:testPath:testKeyWithTTL}")); + assertEquals(TEST_RESULT_WITH_TTL, result.get(MY_KEY)); + } + + @Test + public void testReplaceVariableWithTTLFirstCancelThenScheduleRestart() { + EasyMock.expect(worker.herder()).andReturn(herder); + EasyMock.expect(herder.restartConnector(eq(1L), eq(MY_CONNECTOR), notNull())).andReturn(requestId); + + EasyMock.expect(worker.herder()).andReturn(herder); + EasyMock.expectLastCall(); + requestId.cancel(); + EasyMock.expectLastCall(); + EasyMock.expect(herder.restartConnector(eq(10L), eq(MY_CONNECTOR), notNull())).andReturn(requestId); + + replayAll(); + + Map result = configTransformer.transform(MY_CONNECTOR, Collections.singletonMap(MY_KEY, "${test:testPath:testKeyWithTTL}")); + assertEquals(TEST_RESULT_WITH_TTL, result.get(MY_KEY)); + + result = configTransformer.transform(MY_CONNECTOR, Collections.singletonMap(MY_KEY, "${test:testPath:testKeyWithLongerTTL}")); + assertEquals(TEST_RESULT_WITH_LONGER_TTL, result.get(MY_KEY)); + } + + @Test + public void testTransformNullConfiguration() { + assertNull(configTransformer.transform(MY_CONNECTOR, null)); + } + + public static class TestConfigProvider implements ConfigProvider { + + public void configure(Map configs) { + } + + public ConfigData get(String path) { + return null; + } + + public ConfigData get(String path, Set keys) { + if (path.equals(TEST_PATH)) { + if (keys.contains(TEST_KEY)) { + return new ConfigData(Collections.singletonMap(TEST_KEY, TEST_RESULT)); + } else if (keys.contains(TEST_KEY_WITH_TTL)) { + return new ConfigData(Collections.singletonMap(TEST_KEY_WITH_TTL, TEST_RESULT_WITH_TTL), 1L); + } else if (keys.contains(TEST_KEY_WITH_LONGER_TTL)) { + return new ConfigData(Collections.singletonMap(TEST_KEY_WITH_LONGER_TTL, TEST_RESULT_WITH_LONGER_TTL), 10L); + } + } + return new ConfigData(Collections.emptyMap()); + } + + public void subscribe(String path, Set keys, ConfigChangeCallback callback) { + throw new UnsupportedOperationException(); + } + + public void unsubscribe(String path, Set keys) { + throw new UnsupportedOperationException(); + } + + public void close() { + } + } +} diff --git a/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/WorkerConnectorTest.java b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/WorkerConnectorTest.java new file mode 100644 index 0000000..f99b4c1 --- /dev/null +++ b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/WorkerConnectorTest.java @@ -0,0 +1,597 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.runtime; + +import org.apache.kafka.connect.connector.Connector; +import org.apache.kafka.connect.errors.ConnectException; +import org.apache.kafka.connect.runtime.ConnectMetrics.MetricGroup; +import org.apache.kafka.connect.runtime.isolation.Plugins; +import org.apache.kafka.connect.sink.SinkConnector; +import org.apache.kafka.connect.sink.SinkConnectorContext; +import org.apache.kafka.connect.source.SourceConnector; +import org.apache.kafka.connect.source.SourceConnectorContext; +import org.apache.kafka.connect.storage.OffsetStorageReader; +import org.easymock.Capture; +import org.apache.kafka.connect.util.Callback; +import org.easymock.EasyMock; +import org.easymock.EasyMockRunner; +import org.easymock.EasyMockSupport; +import org.easymock.Mock; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; + +import java.util.HashMap; +import java.util.Map; + +import static org.easymock.EasyMock.expectLastCall; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; + +@RunWith(EasyMockRunner.class) +public class WorkerConnectorTest extends EasyMockSupport { + + private static final String VERSION = "1.1"; + public static final String CONNECTOR = "connector"; + public static final Map CONFIG = new HashMap<>(); + static { + CONFIG.put(ConnectorConfig.CONNECTOR_CLASS_CONFIG, TestConnector.class.getName()); + CONFIG.put(ConnectorConfig.NAME_CONFIG, CONNECTOR); + CONFIG.put(SinkConnectorConfig.TOPICS_CONFIG, "my-topic"); + } + public ConnectorConfig connectorConfig; + public MockConnectMetrics metrics; + + @Mock Plugins plugins; + @Mock SourceConnector sourceConnector; + @Mock SinkConnector sinkConnector; + @Mock Connector connector; + @Mock CloseableConnectorContext ctx; + @Mock ConnectorStatus.Listener listener; + @Mock OffsetStorageReader offsetStorageReader; + @Mock ClassLoader classLoader; + + @Before + public void setup() { + connectorConfig = new ConnectorConfig(plugins, CONFIG); + metrics = new MockConnectMetrics(); + } + + @After + public void tearDown() { + if (metrics != null) metrics.stop(); + } + + @Test + public void testInitializeFailure() throws InterruptedException { + RuntimeException exception = new RuntimeException(); + connector = sourceConnector; + + connector.version(); + expectLastCall().andReturn(VERSION); + + connector.initialize(EasyMock.notNull(SourceConnectorContext.class)); + expectLastCall().andThrow(exception); + + listener.onFailure(CONNECTOR, exception); + expectLastCall(); + + listener.onShutdown(CONNECTOR); + expectLastCall(); + + ctx.close(); + expectLastCall(); + + replayAll(); + + WorkerConnector workerConnector = new WorkerConnector(CONNECTOR, connector, connectorConfig, ctx, metrics, listener, offsetStorageReader, classLoader); + + workerConnector.initialize(); + 
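// initialize() is expected to throw here, so the status metric should already report the connector as failed before shutdown +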
assertFailedMetric(workerConnector); + workerConnector.shutdown(); + workerConnector.doShutdown(); + assertStoppedMetric(workerConnector); + + verifyAll(); + } + + @Test + public void testFailureIsFinalState() { + RuntimeException exception = new RuntimeException(); + connector = sinkConnector; + + connector.version(); + expectLastCall().andReturn(VERSION); + + connector.initialize(EasyMock.notNull(SinkConnectorContext.class)); + expectLastCall().andThrow(exception); + + listener.onFailure(CONNECTOR, exception); + expectLastCall(); + + // expect no call to onStartup() after failure + + listener.onShutdown(CONNECTOR); + expectLastCall(); + + ctx.close(); + expectLastCall(); + + Callback onStateChange = createStrictMock(Callback.class); + onStateChange.onCompletion(EasyMock.anyObject(Exception.class), EasyMock.isNull()); + expectLastCall(); + + replayAll(); + + WorkerConnector workerConnector = new WorkerConnector(CONNECTOR, connector, connectorConfig, ctx, metrics, listener, offsetStorageReader, classLoader); + + workerConnector.initialize(); + assertFailedMetric(workerConnector); + workerConnector.doTransitionTo(TargetState.STARTED, onStateChange); + assertFailedMetric(workerConnector); + workerConnector.shutdown(); + workerConnector.doShutdown(); + assertStoppedMetric(workerConnector); + + verifyAll(); + } + + @Test + public void testStartupAndShutdown() { + connector = sourceConnector; + connector.version(); + expectLastCall().andReturn(VERSION); + + connector.initialize(EasyMock.notNull(SourceConnectorContext.class)); + expectLastCall(); + + connector.start(CONFIG); + expectLastCall(); + + listener.onStartup(CONNECTOR); + expectLastCall(); + + connector.stop(); + expectLastCall(); + + listener.onShutdown(CONNECTOR); + expectLastCall(); + + ctx.close(); + expectLastCall(); + + Callback onStateChange = createStrictMock(Callback.class); + onStateChange.onCompletion(EasyMock.isNull(), EasyMock.eq(TargetState.STARTED)); + expectLastCall(); + + replayAll(); + + WorkerConnector workerConnector = new WorkerConnector(CONNECTOR, connector, connectorConfig, ctx, metrics, listener, offsetStorageReader, classLoader); + + workerConnector.initialize(); + assertInitializedSourceMetric(workerConnector); + workerConnector.doTransitionTo(TargetState.STARTED, onStateChange); + assertRunningMetric(workerConnector); + workerConnector.shutdown(); + workerConnector.doShutdown(); + assertStoppedMetric(workerConnector); + + verifyAll(); + } + + @Test + public void testStartupAndPause() { + connector = sinkConnector; + connector.version(); + expectLastCall().andReturn(VERSION); + + connector.initialize(EasyMock.notNull(SinkConnectorContext.class)); + expectLastCall(); + + connector.start(CONFIG); + expectLastCall(); + + listener.onStartup(CONNECTOR); + expectLastCall(); + + connector.stop(); + expectLastCall(); + + listener.onPause(CONNECTOR); + expectLastCall(); + + listener.onShutdown(CONNECTOR); + expectLastCall(); + + ctx.close(); + expectLastCall(); + + Callback onStateChange = createStrictMock(Callback.class); + onStateChange.onCompletion(EasyMock.isNull(), EasyMock.eq(TargetState.STARTED)); + expectLastCall(); + onStateChange.onCompletion(EasyMock.isNull(), EasyMock.eq(TargetState.PAUSED)); + expectLastCall(); + + replayAll(); + + WorkerConnector workerConnector = new WorkerConnector(CONNECTOR, connector, connectorConfig, ctx, metrics, listener, offsetStorageReader, classLoader); + + workerConnector.initialize(); + assertInitializedSinkMetric(workerConnector); + 
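// Drive the connector from STARTED to PAUSED and verify the status metric after each transition +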
workerConnector.doTransitionTo(TargetState.STARTED, onStateChange); + assertRunningMetric(workerConnector); + workerConnector.doTransitionTo(TargetState.PAUSED, onStateChange); + assertPausedMetric(workerConnector); + workerConnector.shutdown(); + workerConnector.doShutdown(); + assertStoppedMetric(workerConnector); + + verifyAll(); + } + + @Test + public void testOnResume() { + connector = sourceConnector; + connector.version(); + expectLastCall().andReturn(VERSION); + + connector.initialize(EasyMock.notNull(SourceConnectorContext.class)); + expectLastCall(); + + listener.onPause(CONNECTOR); + expectLastCall(); + + connector.start(CONFIG); + expectLastCall(); + + listener.onResume(CONNECTOR); + expectLastCall(); + + connector.stop(); + expectLastCall(); + + listener.onShutdown(CONNECTOR); + expectLastCall(); + + ctx.close(); + expectLastCall(); + + Callback onStateChange = createStrictMock(Callback.class); + onStateChange.onCompletion(EasyMock.isNull(), EasyMock.eq(TargetState.PAUSED)); + expectLastCall(); + onStateChange.onCompletion(EasyMock.isNull(), EasyMock.eq(TargetState.STARTED)); + expectLastCall(); + + replayAll(); + + WorkerConnector workerConnector = new WorkerConnector(CONNECTOR, connector, connectorConfig, ctx, metrics, listener, offsetStorageReader, classLoader); + + workerConnector.initialize(); + assertInitializedSourceMetric(workerConnector); + workerConnector.doTransitionTo(TargetState.PAUSED, onStateChange); + assertPausedMetric(workerConnector); + workerConnector.doTransitionTo(TargetState.STARTED, onStateChange); + assertRunningMetric(workerConnector); + workerConnector.shutdown(); + workerConnector.doShutdown(); + assertStoppedMetric(workerConnector); + + verifyAll(); + } + + @Test + public void testStartupPaused() { + connector = sinkConnector; + connector.version(); + expectLastCall().andReturn(VERSION); + + connector.initialize(EasyMock.notNull(SinkConnectorContext.class)); + expectLastCall(); + + // connector never gets started + + listener.onPause(CONNECTOR); + expectLastCall(); + + listener.onShutdown(CONNECTOR); + expectLastCall(); + + ctx.close(); + expectLastCall(); + + Callback onStateChange = createStrictMock(Callback.class); + onStateChange.onCompletion(EasyMock.isNull(), EasyMock.eq(TargetState.PAUSED)); + expectLastCall(); + + replayAll(); + + WorkerConnector workerConnector = new WorkerConnector(CONNECTOR, connector, connectorConfig, ctx, metrics, listener, offsetStorageReader, classLoader); + + workerConnector.initialize(); + assertInitializedSinkMetric(workerConnector); + workerConnector.doTransitionTo(TargetState.PAUSED, onStateChange); + assertPausedMetric(workerConnector); + workerConnector.shutdown(); + workerConnector.doShutdown(); + assertStoppedMetric(workerConnector); + + verifyAll(); + } + + @Test + public void testStartupFailure() { + RuntimeException exception = new RuntimeException(); + + connector = sinkConnector; + connector.version(); + expectLastCall().andReturn(VERSION); + + connector.initialize(EasyMock.notNull(SinkConnectorContext.class)); + expectLastCall(); + + connector.start(CONFIG); + expectLastCall().andThrow(exception); + + listener.onFailure(CONNECTOR, exception); + expectLastCall(); + + listener.onShutdown(CONNECTOR); + expectLastCall(); + + ctx.close(); + expectLastCall(); + + Callback onStateChange = createStrictMock(Callback.class); + onStateChange.onCompletion(EasyMock.anyObject(Exception.class), EasyMock.isNull()); + expectLastCall(); + + replayAll(); + + WorkerConnector workerConnector = new 
WorkerConnector(CONNECTOR, connector, connectorConfig, ctx, metrics, listener, offsetStorageReader, classLoader); + + workerConnector.initialize(); + assertInitializedSinkMetric(workerConnector); + workerConnector.doTransitionTo(TargetState.STARTED, onStateChange); + assertFailedMetric(workerConnector); + workerConnector.shutdown(); + workerConnector.doShutdown(); + assertStoppedMetric(workerConnector); + + verifyAll(); + } + + @Test + public void testShutdownFailure() { + RuntimeException exception = new RuntimeException(); + connector = sourceConnector; + + connector.version(); + expectLastCall().andReturn(VERSION); + + connector.initialize(EasyMock.notNull(SourceConnectorContext.class)); + expectLastCall(); + + connector.start(CONFIG); + expectLastCall(); + + listener.onStartup(CONNECTOR); + expectLastCall(); + + connector.stop(); + expectLastCall().andThrow(exception); + + Callback onStateChange = createStrictMock(Callback.class); + onStateChange.onCompletion(EasyMock.isNull(), EasyMock.eq(TargetState.STARTED)); + expectLastCall(); + + listener.onFailure(CONNECTOR, exception); + expectLastCall(); + + ctx.close(); + expectLastCall(); + + replayAll(); + + WorkerConnector workerConnector = new WorkerConnector(CONNECTOR, connector, connectorConfig, ctx, metrics, listener, offsetStorageReader, classLoader); + + workerConnector.initialize(); + assertInitializedSourceMetric(workerConnector); + workerConnector.doTransitionTo(TargetState.STARTED, onStateChange); + assertRunningMetric(workerConnector); + workerConnector.shutdown(); + workerConnector.doShutdown(); + assertFailedMetric(workerConnector); + + verifyAll(); + } + + @Test + public void testTransitionStartedToStarted() { + connector = sourceConnector; + connector.version(); + expectLastCall().andReturn(VERSION); + + connector.initialize(EasyMock.notNull(SourceConnectorContext.class)); + expectLastCall(); + + connector.start(CONFIG); + expectLastCall(); + + // expect only one call to onStartup() + listener.onStartup(CONNECTOR); + expectLastCall(); + + connector.stop(); + expectLastCall(); + + listener.onShutdown(CONNECTOR); + expectLastCall(); + + ctx.close(); + expectLastCall(); + + Callback onStateChange = createStrictMock(Callback.class); + onStateChange.onCompletion(EasyMock.isNull(), EasyMock.eq(TargetState.STARTED)); + expectLastCall().times(2); + + replayAll(); + + WorkerConnector workerConnector = new WorkerConnector(CONNECTOR, connector, connectorConfig, ctx, metrics, listener, offsetStorageReader, classLoader); + + workerConnector.initialize(); + assertInitializedSourceMetric(workerConnector); + workerConnector.doTransitionTo(TargetState.STARTED, onStateChange); + assertRunningMetric(workerConnector); + workerConnector.doTransitionTo(TargetState.STARTED, onStateChange); + assertRunningMetric(workerConnector); + workerConnector.shutdown(); + workerConnector.doShutdown(); + assertStoppedMetric(workerConnector); + + verifyAll(); + } + + @Test + public void testTransitionPausedToPaused() { + connector = sourceConnector; + connector.version(); + expectLastCall().andReturn(VERSION); + + connector.initialize(EasyMock.notNull(SourceConnectorContext.class)); + expectLastCall(); + + connector.start(CONFIG); + expectLastCall(); + + listener.onStartup(CONNECTOR); + expectLastCall(); + + connector.stop(); + expectLastCall(); + + listener.onPause(CONNECTOR); + expectLastCall(); + + listener.onShutdown(CONNECTOR); + expectLastCall(); + + ctx.close(); + expectLastCall(); + + Callback onStateChange = createStrictMock(Callback.class); + 
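// The strict mock should see exactly one STARTED completion, then two PAUSED completions for the repeated pause request +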
onStateChange.onCompletion(EasyMock.isNull(), EasyMock.eq(TargetState.STARTED)); + expectLastCall(); + onStateChange.onCompletion(EasyMock.isNull(), EasyMock.eq(TargetState.PAUSED)); + expectLastCall().times(2); + + replayAll(); + + WorkerConnector workerConnector = new WorkerConnector(CONNECTOR, connector, connectorConfig, ctx, metrics, listener, offsetStorageReader, classLoader); + + workerConnector.initialize(); + assertInitializedSourceMetric(workerConnector); + workerConnector.doTransitionTo(TargetState.STARTED, onStateChange); + assertRunningMetric(workerConnector); + workerConnector.doTransitionTo(TargetState.PAUSED, onStateChange); + assertPausedMetric(workerConnector); + workerConnector.doTransitionTo(TargetState.PAUSED, onStateChange); + assertPausedMetric(workerConnector); + workerConnector.shutdown(); + workerConnector.doShutdown(); + assertStoppedMetric(workerConnector); + + verifyAll(); + } + + @Test + public void testFailConnectorThatIsNeitherSourceNorSink() { + connector.version(); + expectLastCall().andReturn(VERSION); + + Capture exceptionCapture = Capture.newInstance(); + listener.onFailure(EasyMock.eq(CONNECTOR), EasyMock.capture(exceptionCapture)); + expectLastCall(); + + replayAll(); + + WorkerConnector workerConnector = new WorkerConnector(CONNECTOR, connector, connectorConfig, ctx, metrics, listener, offsetStorageReader, classLoader); + + workerConnector.initialize(); + Throwable e = exceptionCapture.getValue(); + assertTrue(e instanceof ConnectException); + assertTrue(e.getMessage().contains("must be a subclass of")); + + verifyAll(); + } + + protected void assertFailedMetric(WorkerConnector workerConnector) { + assertFalse(workerConnector.metrics().isUnassigned()); + assertTrue(workerConnector.metrics().isFailed()); + assertFalse(workerConnector.metrics().isPaused()); + assertFalse(workerConnector.metrics().isRunning()); + } + + protected void assertPausedMetric(WorkerConnector workerConnector) { + assertFalse(workerConnector.metrics().isUnassigned()); + assertFalse(workerConnector.metrics().isFailed()); + assertTrue(workerConnector.metrics().isPaused()); + assertFalse(workerConnector.metrics().isRunning()); + } + + protected void assertRunningMetric(WorkerConnector workerConnector) { + assertFalse(workerConnector.metrics().isUnassigned()); + assertFalse(workerConnector.metrics().isFailed()); + assertFalse(workerConnector.metrics().isPaused()); + assertTrue(workerConnector.metrics().isRunning()); + } + + protected void assertStoppedMetric(WorkerConnector workerConnector) { + assertTrue(workerConnector.metrics().isUnassigned()); + assertFalse(workerConnector.metrics().isFailed()); + assertFalse(workerConnector.metrics().isPaused()); + assertFalse(workerConnector.metrics().isRunning()); + } + + protected void assertInitializedSinkMetric(WorkerConnector workerConnector) { + assertInitializedMetric(workerConnector, "sink"); + } + + protected void assertInitializedSourceMetric(WorkerConnector workerConnector) { + assertInitializedMetric(workerConnector, "source"); + } + + protected void assertInitializedMetric(WorkerConnector workerConnector, String expectedType) { + assertTrue(workerConnector.metrics().isUnassigned()); + assertFalse(workerConnector.metrics().isFailed()); + assertFalse(workerConnector.metrics().isPaused()); + assertFalse(workerConnector.metrics().isRunning()); + MetricGroup metricGroup = workerConnector.metrics().metricGroup(); + String status = metrics.currentMetricValueAsString(metricGroup, "status"); + String type = 
metrics.currentMetricValueAsString(metricGroup, "connector-type"); + String clazz = metrics.currentMetricValueAsString(metricGroup, "connector-class"); + String version = metrics.currentMetricValueAsString(metricGroup, "connector-version"); + assertEquals(expectedType, type); + assertNotNull(clazz); + assertEquals(VERSION, version); + } + + private static abstract class TestConnector extends Connector { + } +} diff --git a/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/WorkerMetricsGroupTest.java b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/WorkerMetricsGroupTest.java new file mode 100644 index 0000000..2eb20ce --- /dev/null +++ b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/WorkerMetricsGroupTest.java @@ -0,0 +1,249 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.runtime; + +import org.apache.kafka.common.MetricName; +import org.apache.kafka.common.MetricNameTemplate; +import org.apache.kafka.common.metrics.CompoundStat; +import org.apache.kafka.common.metrics.Sensor; +import org.apache.kafka.common.metrics.stats.CumulativeSum; +import org.apache.kafka.connect.util.ConnectorTaskId; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.powermock.api.easymock.PowerMock; +import org.powermock.core.classloader.annotations.PrepareForTest; +import org.powermock.modules.junit4.PowerMockRunner; + +import java.util.HashMap; + +import static org.easymock.EasyMock.anyObject; +import static org.easymock.EasyMock.anyString; +import static org.easymock.EasyMock.eq; +import static org.powermock.api.easymock.PowerMock.expectLastCall; + +@RunWith(PowerMockRunner.class) +@PrepareForTest({Sensor.class, MetricName.class}) +public class WorkerMetricsGroupTest { + private final String connector = "org.FakeConnector"; + private final ConnectorTaskId task = new ConnectorTaskId(connector, 0); + private final RuntimeException exception = new RuntimeException(); + + private ConnectMetrics connectMetrics; + + private Sensor connectorStartupResults; + private Sensor connectorStartupAttempts; + private Sensor connectorStartupSuccesses; + private Sensor connectorStartupFailures; + + private Sensor taskStartupResults; + private Sensor taskStartupAttempts; + private Sensor taskStartupSuccesses; + private Sensor taskStartupFailures; + + private ConnectorStatus.Listener delegateConnectorListener; + private TaskStatus.Listener delegateTaskListener; + + @Before + public void setup() { + connectMetrics = PowerMock.createMock(ConnectMetrics.class); + ConnectMetricsRegistry connectMetricsRegistry = PowerMock.createNiceMock(ConnectMetricsRegistry.class); + ConnectMetrics.MetricGroup metricGroup = 
PowerMock.createNiceMock(ConnectMetrics.MetricGroup.class); + + connectMetrics.registry(); + expectLastCall().andReturn(connectMetricsRegistry); + + connectMetrics.group(anyString()); + expectLastCall().andReturn(metricGroup); + + MetricName metricName = PowerMock.createMock(MetricName.class); + metricGroup.metricName(anyObject(MetricNameTemplate.class)); + expectLastCall().andStubReturn(metricName); + + connectorStartupResults = mockSensor(metricGroup, "connector-startup-results"); + connectorStartupAttempts = mockSensor(metricGroup, "connector-startup-attempts"); + connectorStartupSuccesses = mockSensor(metricGroup, "connector-startup-successes"); + connectorStartupFailures = mockSensor(metricGroup, "connector-startup-failures"); + + taskStartupResults = mockSensor(metricGroup, "task-startup-results"); + taskStartupAttempts = mockSensor(metricGroup, "task-startup-attempts"); + taskStartupSuccesses = mockSensor(metricGroup, "task-startup-successes"); + taskStartupFailures = mockSensor(metricGroup, "task-startup-failures"); + + delegateConnectorListener = PowerMock.createStrictMock(ConnectorStatus.Listener.class); + delegateTaskListener = PowerMock.createStrictMock(TaskStatus.Listener.class); + } + + private Sensor mockSensor(ConnectMetrics.MetricGroup metricGroup, String name) { + Sensor sensor = PowerMock.createMock(Sensor.class); + metricGroup.sensor(eq(name)); + expectLastCall().andReturn(sensor); + + sensor.add(anyObject(CompoundStat.class)); + expectLastCall().andStubReturn(true); + + sensor.add(anyObject(MetricName.class), anyObject(CumulativeSum.class)); + expectLastCall().andStubReturn(true); + + return sensor; + } + + @Test + public void testConnectorStartupRecordedMetrics() { + delegateConnectorListener.onStartup(eq(connector)); + expectLastCall(); + + connectorStartupAttempts.record(eq(1.0)); + expectLastCall(); + connectorStartupSuccesses.record(eq(1.0)); + expectLastCall(); + connectorStartupResults.record(eq(1.0)); + expectLastCall(); + + PowerMock.replayAll(); + + WorkerMetricsGroup workerMetricsGroup = new WorkerMetricsGroup(new HashMap<>(), new HashMap<>(), connectMetrics); + final ConnectorStatus.Listener connectorListener = workerMetricsGroup.wrapStatusListener(delegateConnectorListener); + + connectorListener.onStartup(connector); + + PowerMock.verifyAll(); + } + + @Test + public void testConnectorFailureAfterStartupRecordedMetrics() { + delegateConnectorListener.onStartup(eq(connector)); + expectLastCall(); + + connectorStartupAttempts.record(eq(1.0)); + expectLastCall(); + connectorStartupSuccesses.record(eq(1.0)); + expectLastCall(); + connectorStartupResults.record(eq(1.0)); + expectLastCall(); + + delegateConnectorListener.onFailure(eq(connector), eq(exception)); + expectLastCall(); + + // recordConnectorStartupFailure() should not be called if failure happens after a successful startup + + PowerMock.replayAll(); + + WorkerMetricsGroup workerMetricsGroup = new WorkerMetricsGroup(new HashMap<>(), new HashMap<>(), connectMetrics); + final ConnectorStatus.Listener connectorListener = workerMetricsGroup.wrapStatusListener(delegateConnectorListener); + + connectorListener.onStartup(connector); + connectorListener.onFailure(connector, exception); + + PowerMock.verifyAll(); + } + + @Test + public void testConnectorFailureBeforeStartupRecordedMetrics() { + delegateConnectorListener.onFailure(eq(connector), eq(exception)); + expectLastCall(); + + connectorStartupAttempts.record(eq(1.0)); + expectLastCall(); + connectorStartupFailures.record(eq(1.0)); + expectLastCall(); + 
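// A failure before startup still counts as an attempt and a failure, and the startup result is recorded as 0.0 +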
connectorStartupResults.record(eq(0.0)); + expectLastCall(); + + PowerMock.replayAll(); + + WorkerMetricsGroup workerMetricsGroup = new WorkerMetricsGroup(new HashMap<>(), new HashMap<>(), connectMetrics); + final ConnectorStatus.Listener connectorListener = workerMetricsGroup.wrapStatusListener(delegateConnectorListener); + + connectorListener.onFailure(connector, exception); + + PowerMock.verifyAll(); + } + + @Test + public void testTaskStartupRecordedMetrics() { + delegateTaskListener.onStartup(eq(task)); + expectLastCall(); + + taskStartupAttempts.record(eq(1.0)); + expectLastCall(); + taskStartupSuccesses.record(eq(1.0)); + expectLastCall(); + taskStartupResults.record(eq(1.0)); + expectLastCall(); + + PowerMock.replayAll(); + + WorkerMetricsGroup workerMetricsGroup = new WorkerMetricsGroup(new HashMap<>(), new HashMap<>(), connectMetrics); + final TaskStatus.Listener taskListener = workerMetricsGroup.wrapStatusListener(delegateTaskListener); + + taskListener.onStartup(task); + + PowerMock.verifyAll(); + } + + @Test + public void testTaskFailureAfterStartupRecordedMetrics() { + delegateTaskListener.onStartup(eq(task)); + expectLastCall(); + + taskStartupAttempts.record(eq(1.0)); + expectLastCall(); + taskStartupSuccesses.record(eq(1.0)); + expectLastCall(); + taskStartupResults.record(eq(1.0)); + expectLastCall(); + + delegateTaskListener.onFailure(eq(task), eq(exception)); + expectLastCall(); + + // recordTaskFailure() should not be called if failure happens after a successful startup + + PowerMock.replayAll(); + + WorkerMetricsGroup workerMetricsGroup = new WorkerMetricsGroup(new HashMap<>(), new HashMap<>(), connectMetrics); + final TaskStatus.Listener taskListener = workerMetricsGroup.wrapStatusListener(delegateTaskListener); + + taskListener.onStartup(task); + taskListener.onFailure(task, exception); + + PowerMock.verifyAll(); + } + + @Test + public void testTaskFailureBeforeStartupRecordedMetrics() { + delegateTaskListener.onFailure(eq(task), eq(exception)); + expectLastCall(); + + taskStartupAttempts.record(eq(1.0)); + expectLastCall(); + taskStartupFailures.record(eq(1.0)); + expectLastCall(); + taskStartupResults.record(eq(0.0)); + expectLastCall(); + + PowerMock.replayAll(); + + WorkerMetricsGroup workerMetricsGroup = new WorkerMetricsGroup(new HashMap<>(), new HashMap<>(), connectMetrics); + final TaskStatus.Listener taskListener = workerMetricsGroup.wrapStatusListener(delegateTaskListener); + + taskListener.onFailure(task, exception); + + PowerMock.verifyAll(); + } + +} diff --git a/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/WorkerSinkTaskTest.java b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/WorkerSinkTaskTest.java new file mode 100644 index 0000000..1600dcf --- /dev/null +++ b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/WorkerSinkTaskTest.java @@ -0,0 +1,2104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.connect.runtime;
+
+import java.util.Arrays;
+import java.util.Iterator;
+import org.apache.kafka.clients.consumer.ConsumerRebalanceListener;
+import org.apache.kafka.clients.consumer.ConsumerRecord;
+import org.apache.kafka.clients.consumer.ConsumerRecords;
+import org.apache.kafka.clients.consumer.KafkaConsumer;
+import org.apache.kafka.clients.consumer.OffsetAndMetadata;
+import org.apache.kafka.clients.consumer.OffsetCommitCallback;
+import org.apache.kafka.common.MetricName;
+import org.apache.kafka.common.TopicPartition;
+import org.apache.kafka.common.errors.WakeupException;
+import org.apache.kafka.common.header.Header;
+import org.apache.kafka.common.header.Headers;
+import org.apache.kafka.common.header.internals.RecordHeaders;
+import org.apache.kafka.common.record.RecordBatch;
+import org.apache.kafka.common.record.TimestampType;
+import org.apache.kafka.common.utils.MockTime;
+import org.apache.kafka.common.utils.Time;
+import org.apache.kafka.connect.data.Schema;
+import org.apache.kafka.connect.data.SchemaAndValue;
+import org.apache.kafka.connect.errors.ConnectException;
+import org.apache.kafka.connect.errors.RetriableException;
+import org.apache.kafka.connect.runtime.ConnectMetrics.MetricGroup;
+import org.apache.kafka.connect.runtime.distributed.ClusterConfigState;
+import org.apache.kafka.connect.runtime.WorkerSinkTask.SinkTaskMetricsGroup;
+import org.apache.kafka.connect.runtime.errors.RetryWithToleranceOperatorTest;
+import org.apache.kafka.connect.runtime.isolation.PluginClassLoader;
+import org.apache.kafka.connect.runtime.standalone.StandaloneConfig;
+import org.apache.kafka.connect.sink.SinkConnector;
+import org.apache.kafka.connect.sink.SinkRecord;
+import org.apache.kafka.connect.sink.SinkTask;
+import org.apache.kafka.connect.storage.Converter;
+import org.apache.kafka.connect.storage.HeaderConverter;
+import org.apache.kafka.connect.storage.StatusBackingStore;
+import org.apache.kafka.connect.storage.StringConverter;
+import org.apache.kafka.connect.util.ConnectorTaskId;
+import org.easymock.Capture;
+import org.easymock.CaptureType;
+import org.easymock.EasyMock;
+import org.easymock.IExpectationSetters;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.powermock.api.easymock.PowerMock;
+import org.powermock.api.easymock.annotation.Mock;
+import org.powermock.core.classloader.annotations.PowerMockIgnore;
+import org.powermock.core.classloader.annotations.PrepareForTest;
+import org.powermock.modules.junit4.PowerMockRunner;
+import org.powermock.reflect.Whitebox;
+
+import java.time.Duration;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.Set;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicReference;
+import java.util.function.Function;
+import java.util.regex.Pattern;
+import java.util.stream.Collectors;
+
+import static java.util.Arrays.asList;
+import static java.util.Collections.singleton;
+import static org.junit.Assert.assertSame;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertNotEquals;
+import static org.junit.Assert.assertNull;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
+
+@RunWith(PowerMockRunner.class)
+@PrepareForTest(WorkerSinkTask.class)
+@PowerMockIgnore("javax.management.*")
+public class WorkerSinkTaskTest {
+    // These are fixed to keep this code simpler. In this example we assume byte[] raw values
+    // with mix of integer/string in Connect
+    private static final String TOPIC = "test";
+    private static final int PARTITION = 12;
+    private static final int PARTITION2 = 13;
+    private static final int PARTITION3 = 14;
+    private static final long FIRST_OFFSET = 45;
+    private static final Schema KEY_SCHEMA = Schema.INT32_SCHEMA;
+    private static final int KEY = 12;
+    private static final Schema VALUE_SCHEMA = Schema.STRING_SCHEMA;
+    private static final String VALUE = "VALUE";
+    private static final byte[] RAW_KEY = "key".getBytes();
+    private static final byte[] RAW_VALUE = "value".getBytes();
+
+    private static final TopicPartition TOPIC_PARTITION = new TopicPartition(TOPIC, PARTITION);
+    private static final TopicPartition TOPIC_PARTITION2 = new TopicPartition(TOPIC, PARTITION2);
+    private static final TopicPartition TOPIC_PARTITION3 = new TopicPartition(TOPIC, PARTITION3);
+
+    private static final Set<TopicPartition> INITIAL_ASSIGNMENT =
+        new HashSet<>(Arrays.asList(TOPIC_PARTITION, TOPIC_PARTITION2));
+
+    private static final Map<String, String> TASK_PROPS = new HashMap<>();
+    static {
+        TASK_PROPS.put(SinkConnector.TOPICS_CONFIG, TOPIC);
+        TASK_PROPS.put(TaskConfig.TASK_CLASS_CONFIG, TestSinkTask.class.getName());
+    }
+    private static final TaskConfig TASK_CONFIG = new TaskConfig(TASK_PROPS);
+
+    private ConnectorTaskId taskId = new ConnectorTaskId("job", 0);
+    private ConnectorTaskId taskId1 = new ConnectorTaskId("job", 1);
+    private TargetState initialState = TargetState.STARTED;
+    private MockTime time;
+    private WorkerSinkTask workerTask;
+    @Mock
+    private SinkTask sinkTask;
+    private Capture<WorkerSinkTaskContext> sinkTaskContext = EasyMock.newCapture();
+    private WorkerConfig workerConfig;
+    private MockConnectMetrics metrics;
+    @Mock
+    private PluginClassLoader pluginLoader;
+    @Mock
+    private Converter keyConverter;
+    @Mock
+    private Converter valueConverter;
+    @Mock
+    private HeaderConverter headerConverter;
+    @Mock
+    private TransformationChain<SinkRecord> transformationChain;
+    @Mock
+    private TaskStatus.Listener statusListener;
+    @Mock
+    private StatusBackingStore statusBackingStore;
+    @Mock
+    private KafkaConsumer<byte[], byte[]> consumer;
+    private Capture<ConsumerRebalanceListener> rebalanceListener = EasyMock.newCapture();
+    private Capture<Pattern> topicsRegex = EasyMock.newCapture();
+
+    private long recordsReturnedTp1;
+    private long recordsReturnedTp3;
+
+    @Before
+    public void setUp() {
+        time = new MockTime();
+        Map<String, String> workerProps = new HashMap<>();
+        workerProps.put("key.converter", "org.apache.kafka.connect.json.JsonConverter");
+        workerProps.put("value.converter", "org.apache.kafka.connect.json.JsonConverter");
+        workerProps.put("offset.storage.file.filename", "/tmp/connect.offsets");
+        workerConfig = new StandaloneConfig(workerProps);
+        pluginLoader
= PowerMock.createMock(PluginClassLoader.class); + metrics = new MockConnectMetrics(time); + recordsReturnedTp1 = 0; + recordsReturnedTp3 = 0; + } + + private void createTask(TargetState initialState) { + createTask(initialState, keyConverter, valueConverter, headerConverter); + } + + private void createTask(TargetState initialState, Converter keyConverter, Converter valueConverter, HeaderConverter headerConverter) { + workerTask = new WorkerSinkTask( + taskId, sinkTask, statusListener, initialState, workerConfig, ClusterConfigState.EMPTY, metrics, + keyConverter, valueConverter, headerConverter, + transformationChain, consumer, pluginLoader, time, + RetryWithToleranceOperatorTest.NOOP_OPERATOR, null, statusBackingStore); + } + + @After + public void tearDown() { + if (metrics != null) metrics.stop(); + } + + @Test + public void testStartPaused() throws Exception { + createTask(TargetState.PAUSED); + + expectInitializeTask(); + expectTaskGetTopic(true); + expectPollInitialAssignment(); + + EasyMock.expect(consumer.assignment()).andReturn(INITIAL_ASSIGNMENT); + consumer.pause(INITIAL_ASSIGNMENT); + PowerMock.expectLastCall(); + + PowerMock.replayAll(); + + workerTask.initialize(TASK_CONFIG); + workerTask.initializeAndStart(); + workerTask.iteration(); + time.sleep(10000L); + + assertSinkMetricValue("partition-count", 2); + assertTaskMetricValue("status", "paused"); + assertTaskMetricValue("running-ratio", 0.0); + assertTaskMetricValue("pause-ratio", 1.0); + assertTaskMetricValue("offset-commit-max-time-ms", Double.NaN); + + PowerMock.verifyAll(); + } + + @Test + public void testPause() throws Exception { + createTask(initialState); + + expectInitializeTask(); + expectTaskGetTopic(true); + expectPollInitialAssignment(); + + expectConsumerPoll(1); + expectConversionAndTransformation(1); + sinkTask.put(EasyMock.anyObject()); + EasyMock.expectLastCall(); + + // Pause + statusListener.onPause(taskId); + EasyMock.expectLastCall(); + expectConsumerWakeup(); + EasyMock.expect(consumer.assignment()).andReturn(INITIAL_ASSIGNMENT); + consumer.pause(INITIAL_ASSIGNMENT); + PowerMock.expectLastCall(); + + // Offset commit as requested when pausing; No records returned by consumer.poll() + sinkTask.preCommit(EasyMock.anyObject()); + EasyMock.expectLastCall().andStubReturn(Collections.emptyMap()); + expectConsumerPoll(0); + sinkTask.put(Collections.emptyList()); + EasyMock.expectLastCall(); + + // And unpause + statusListener.onResume(taskId); + EasyMock.expectLastCall(); + expectConsumerWakeup(); + EasyMock.expect(consumer.assignment()).andReturn(INITIAL_ASSIGNMENT).times(2); + INITIAL_ASSIGNMENT.forEach(tp -> { + consumer.resume(Collections.singleton(tp)); + PowerMock.expectLastCall(); + }); + + expectConsumerPoll(1); + expectConversionAndTransformation(1); + sinkTask.put(EasyMock.anyObject()); + EasyMock.expectLastCall(); + + PowerMock.replayAll(); + + workerTask.initialize(TASK_CONFIG); + workerTask.initializeAndStart(); + workerTask.iteration(); // initial assignment + workerTask.iteration(); // fetch some data + workerTask.transitionTo(TargetState.PAUSED); + time.sleep(10000L); + + assertSinkMetricValue("partition-count", 2); + assertSinkMetricValue("sink-record-read-total", 1.0); + assertSinkMetricValue("sink-record-send-total", 1.0); + assertSinkMetricValue("sink-record-active-count", 1.0); + assertSinkMetricValue("sink-record-active-count-max", 1.0); + assertSinkMetricValue("sink-record-active-count-avg", 0.333333); + assertSinkMetricValue("offset-commit-seq-no", 0.0); + 
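Illustrative aside, not code from this patch: the metric assertions in these tests only work because the task is driven by MockTime. Rate-style metrics such as offset-commit-completion-rate are computed over a time window read from the injected clock, so the tests advance time explicitly with time.sleep(...) instead of waiting. A minimal, self-contained sketch of that interaction with the public Kafka Metrics API (class and metric names here are invented for illustration):

import org.apache.kafka.common.MetricName;
import org.apache.kafka.common.metrics.Metrics;
import org.apache.kafka.common.metrics.Sensor;
import org.apache.kafka.common.metrics.stats.CumulativeCount;
import org.apache.kafka.common.metrics.stats.Rate;
import org.apache.kafka.common.utils.MockTime;

public class MockTimeRateSketch {
    public static void main(String[] args) {
        MockTime time = new MockTime();       // deterministic clock under test control
        Metrics metrics = new Metrics(time);  // all stats read this clock, not the wall clock

        Sensor commits = metrics.sensor("offset-commit-completion");
        MetricName rateName = metrics.metricName("offset-commit-completion-rate", "sketch-group");
        MetricName totalName = metrics.metricName("offset-commit-completion-total", "sketch-group");
        commits.add(rateName, new Rate());
        commits.add(totalName, new CumulativeCount());

        commits.record();        // one "completed commit"
        time.sleep(30_000L);     // advance the mocked clock by 30s; no real waiting

        // The rate is events per second over the sampled window, so it shrinks as mocked
        // time advances, while the cumulative total stays at 1.0 -- the effect the
        // surrounding assertions rely on.
        System.out.println("rate  = " + metrics.metric(rateName).metricValue());
        System.out.println("total = " + metrics.metric(totalName).metricValue());
    }
}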
assertSinkMetricValue("offset-commit-completion-rate", 0.0); + assertSinkMetricValue("offset-commit-completion-total", 0.0); + assertSinkMetricValue("offset-commit-skip-rate", 0.0); + assertSinkMetricValue("offset-commit-skip-total", 0.0); + assertTaskMetricValue("status", "running"); + assertTaskMetricValue("running-ratio", 1.0); + assertTaskMetricValue("pause-ratio", 0.0); + assertTaskMetricValue("batch-size-max", 1.0); + assertTaskMetricValue("batch-size-avg", 0.5); + assertTaskMetricValue("offset-commit-max-time-ms", Double.NaN); + assertTaskMetricValue("offset-commit-failure-percentage", 0.0); + assertTaskMetricValue("offset-commit-success-percentage", 0.0); + + workerTask.iteration(); // wakeup + workerTask.iteration(); // now paused + time.sleep(30000L); + + assertSinkMetricValue("offset-commit-seq-no", 1.0); + assertSinkMetricValue("offset-commit-completion-rate", 0.0333); + assertSinkMetricValue("offset-commit-completion-total", 1.0); + assertSinkMetricValue("offset-commit-skip-rate", 0.0); + assertSinkMetricValue("offset-commit-skip-total", 0.0); + assertTaskMetricValue("status", "paused"); + assertTaskMetricValue("running-ratio", 0.25); + assertTaskMetricValue("pause-ratio", 0.75); + + workerTask.transitionTo(TargetState.STARTED); + workerTask.iteration(); // wakeup + workerTask.iteration(); // now unpaused + //printMetrics(); + + PowerMock.verifyAll(); + } + + @Test + public void testShutdown() throws Exception { + createTask(initialState); + + expectInitializeTask(); + expectTaskGetTopic(true); + + // first iteration + expectPollInitialAssignment(); + + // second iteration + EasyMock.expect(sinkTask.preCommit(EasyMock.anyObject())).andReturn(Collections.emptyMap()); + expectConsumerPoll(1); + expectConversionAndTransformation(1); + sinkTask.put(EasyMock.anyObject()); + EasyMock.expectLastCall(); + + // WorkerSinkTask::stop + consumer.wakeup(); + PowerMock.expectLastCall(); + sinkTask.stop(); + PowerMock.expectLastCall(); + + EasyMock.expect(consumer.assignment()).andReturn(INITIAL_ASSIGNMENT); + // WorkerSinkTask::close + consumer.close(); + PowerMock.expectLastCall().andAnswer(() -> { + rebalanceListener.getValue().onPartitionsRevoked( + INITIAL_ASSIGNMENT + ); + return null; + }); + transformationChain.close(); + PowerMock.expectLastCall(); + + PowerMock.replayAll(); + + workerTask.initialize(TASK_CONFIG); + workerTask.initializeAndStart(); + workerTask.iteration(); + sinkTaskContext.getValue().requestCommit(); // Force an offset commit + workerTask.iteration(); + workerTask.stop(); + workerTask.close(); + + PowerMock.verifyAll(); + } + + @Test + public void testPollRedelivery() throws Exception { + createTask(initialState); + + expectInitializeTask(); + expectTaskGetTopic(true); + expectPollInitialAssignment(); + + // If a retriable exception is thrown, we should redeliver the same batch, pausing the consumer in the meantime + expectConsumerPoll(1); + expectConversionAndTransformation(1); + Capture> records = EasyMock.newCapture(CaptureType.ALL); + sinkTask.put(EasyMock.capture(records)); + EasyMock.expectLastCall().andThrow(new RetriableException("retry")); + // Pause + EasyMock.expect(consumer.assignment()).andReturn(INITIAL_ASSIGNMENT); + consumer.pause(INITIAL_ASSIGNMENT); + PowerMock.expectLastCall(); + + // Retry delivery should succeed + expectConsumerPoll(0); + sinkTask.put(EasyMock.capture(records)); + EasyMock.expectLastCall(); + // And unpause + EasyMock.expect(consumer.assignment()).andReturn(INITIAL_ASSIGNMENT); + INITIAL_ASSIGNMENT.forEach(tp -> { + 
consumer.resume(singleton(tp)); + PowerMock.expectLastCall(); + }); + + PowerMock.replayAll(); + + workerTask.initialize(TASK_CONFIG); + workerTask.initializeAndStart(); + workerTask.iteration(); + time.sleep(10000L); + + assertSinkMetricValue("partition-count", 2); + assertSinkMetricValue("sink-record-read-total", 0.0); + assertSinkMetricValue("sink-record-send-total", 0.0); + assertSinkMetricValue("sink-record-active-count", 0.0); + assertSinkMetricValue("sink-record-active-count-max", 0.0); + assertSinkMetricValue("sink-record-active-count-avg", 0.0); + assertSinkMetricValue("offset-commit-seq-no", 0.0); + assertSinkMetricValue("offset-commit-completion-rate", 0.0); + assertSinkMetricValue("offset-commit-completion-total", 0.0); + assertSinkMetricValue("offset-commit-skip-rate", 0.0); + assertSinkMetricValue("offset-commit-skip-total", 0.0); + assertTaskMetricValue("status", "running"); + assertTaskMetricValue("running-ratio", 1.0); + assertTaskMetricValue("pause-ratio", 0.0); + assertTaskMetricValue("batch-size-max", 0.0); + assertTaskMetricValue("batch-size-avg", 0.0); + assertTaskMetricValue("offset-commit-max-time-ms", Double.NaN); + assertTaskMetricValue("offset-commit-failure-percentage", 0.0); + assertTaskMetricValue("offset-commit-success-percentage", 0.0); + + workerTask.iteration(); + workerTask.iteration(); + time.sleep(30000L); + + assertSinkMetricValue("sink-record-read-total", 1.0); + assertSinkMetricValue("sink-record-send-total", 1.0); + assertSinkMetricValue("sink-record-active-count", 1.0); + assertSinkMetricValue("sink-record-active-count-max", 1.0); + assertSinkMetricValue("sink-record-active-count-avg", 0.5); + assertTaskMetricValue("status", "running"); + assertTaskMetricValue("running-ratio", 1.0); + assertTaskMetricValue("batch-size-max", 1.0); + assertTaskMetricValue("batch-size-avg", 0.5); + + PowerMock.verifyAll(); + } + + @Test + public void testPollRedeliveryWithConsumerRebalance() throws Exception { + createTask(initialState); + + expectInitializeTask(); + expectTaskGetTopic(true); + expectPollInitialAssignment(); + + // If a retriable exception is thrown, we should redeliver the same batch, pausing the consumer in the meantime + expectConsumerPoll(1); + expectConversionAndTransformation(1); + sinkTask.put(EasyMock.anyObject()); + EasyMock.expectLastCall().andThrow(new RetriableException("retry")); + // Pause + EasyMock.expect(consumer.assignment()).andReturn(INITIAL_ASSIGNMENT); + consumer.pause(INITIAL_ASSIGNMENT); + PowerMock.expectLastCall(); + + // Empty consumer poll (all partitions are paused) with rebalance; one new partition is assigned + EasyMock.expect(consumer.poll(Duration.ofMillis(EasyMock.anyLong()))).andAnswer( + () -> { + rebalanceListener.getValue().onPartitionsRevoked(Collections.emptySet()); + rebalanceListener.getValue().onPartitionsAssigned(Collections.singleton(TOPIC_PARTITION3)); + return ConsumerRecords.empty(); + }); + Set newAssignment = new HashSet<>(Arrays.asList(TOPIC_PARTITION, TOPIC_PARTITION2, TOPIC_PARTITION3)); + EasyMock.expect(consumer.assignment()).andReturn(newAssignment).times(3); + EasyMock.expect(consumer.position(TOPIC_PARTITION3)).andReturn(FIRST_OFFSET); + sinkTask.open(Collections.singleton(TOPIC_PARTITION3)); + EasyMock.expectLastCall(); + // All partitions are re-paused in order to pause any newly-assigned partitions so that redelivery efforts can continue + consumer.pause(newAssignment); + EasyMock.expectLastCall(); + sinkTask.put(EasyMock.anyObject()); + EasyMock.expectLastCall().andThrow(new 
RetriableException("retry")); + + // Next delivery attempt fails again + expectConsumerPoll(0); + sinkTask.put(EasyMock.anyObject()); + EasyMock.expectLastCall().andThrow(new RetriableException("retry")); + + // Non-empty consumer poll; all initially-assigned partitions are revoked in rebalance, and new partitions are allowed to resume + ConsumerRecord newRecord = new ConsumerRecord<>(TOPIC, PARTITION3, FIRST_OFFSET, RAW_KEY, RAW_VALUE); + EasyMock.expect(consumer.poll(Duration.ofMillis(EasyMock.anyLong()))).andAnswer( + () -> { + rebalanceListener.getValue().onPartitionsRevoked(INITIAL_ASSIGNMENT); + rebalanceListener.getValue().onPartitionsAssigned(Collections.emptyList()); + return new ConsumerRecords<>(Collections.singletonMap(TOPIC_PARTITION3, Collections.singletonList(newRecord))); + }); + newAssignment = Collections.singleton(TOPIC_PARTITION3); + EasyMock.expect(consumer.assignment()).andReturn(new HashSet<>(newAssignment)).times(3); + final Map offsets = INITIAL_ASSIGNMENT.stream() + .collect(Collectors.toMap(Function.identity(), tp -> new OffsetAndMetadata(FIRST_OFFSET))); + sinkTask.preCommit(offsets); + EasyMock.expectLastCall().andReturn(offsets); + sinkTask.close(INITIAL_ASSIGNMENT); + EasyMock.expectLastCall(); + // All partitions are resumed, as all previously paused-for-redelivery partitions were revoked + newAssignment.forEach(tp -> { + consumer.resume(Collections.singleton(tp)); + EasyMock.expectLastCall(); + }); + expectConversionAndTransformation(1); + sinkTask.put(EasyMock.anyObject()); + EasyMock.expectLastCall(); + + PowerMock.replayAll(); + + workerTask.initialize(TASK_CONFIG); + workerTask.initializeAndStart(); + workerTask.iteration(); + workerTask.iteration(); + workerTask.iteration(); + workerTask.iteration(); + workerTask.iteration(); + + PowerMock.verifyAll(); + } + + @Test + public void testErrorInRebalancePartitionLoss() throws Exception { + RuntimeException exception = new RuntimeException("Revocation error"); + + createTask(initialState); + + expectInitializeTask(); + expectTaskGetTopic(true); + expectPollInitialAssignment(); + expectRebalanceLossError(exception); + + PowerMock.replayAll(); + + workerTask.initialize(TASK_CONFIG); + workerTask.initializeAndStart(); + workerTask.iteration(); + try { + workerTask.iteration(); + fail("Poll should have raised the rebalance exception"); + } catch (RuntimeException e) { + assertEquals(exception, e); + } + + PowerMock.verifyAll(); + } + + @Test + public void testErrorInRebalancePartitionRevocation() throws Exception { + RuntimeException exception = new RuntimeException("Revocation error"); + + createTask(initialState); + + expectInitializeTask(); + expectTaskGetTopic(true); + expectPollInitialAssignment(); + expectRebalanceRevocationError(exception); + + PowerMock.replayAll(); + + workerTask.initialize(TASK_CONFIG); + workerTask.initializeAndStart(); + workerTask.iteration(); + try { + workerTask.iteration(); + fail("Poll should have raised the rebalance exception"); + } catch (RuntimeException e) { + assertEquals(exception, e); + } + + PowerMock.verifyAll(); + } + + @Test + public void testErrorInRebalancePartitionAssignment() throws Exception { + RuntimeException exception = new RuntimeException("Assignment error"); + + createTask(initialState); + + expectInitializeTask(); + expectTaskGetTopic(true); + expectPollInitialAssignment(); + expectRebalanceAssignmentError(exception); + + PowerMock.replayAll(); + + workerTask.initialize(TASK_CONFIG); + workerTask.initializeAndStart(); + workerTask.iteration(); + try { + 
workerTask.iteration(); + fail("Poll should have raised the rebalance exception"); + } catch (RuntimeException e) { + assertEquals(exception, e); + } + + PowerMock.verifyAll(); + } + + @Test + public void testPartialRevocationAndAssignment() throws Exception { + createTask(initialState); + + expectInitializeTask(); + expectTaskGetTopic(true); + expectPollInitialAssignment(); + + EasyMock.expect(consumer.poll(Duration.ofMillis(EasyMock.anyLong()))).andAnswer( + () -> { + rebalanceListener.getValue().onPartitionsRevoked(Collections.singleton(TOPIC_PARTITION)); + rebalanceListener.getValue().onPartitionsAssigned(Collections.emptySet()); + return ConsumerRecords.empty(); + }); + EasyMock.expect(consumer.assignment()).andReturn(Collections.singleton(TOPIC_PARTITION)).times(2); + final Map offsets = new HashMap<>(); + offsets.put(TOPIC_PARTITION, new OffsetAndMetadata(FIRST_OFFSET)); + sinkTask.preCommit(offsets); + EasyMock.expectLastCall().andReturn(offsets); + sinkTask.close(Collections.singleton(TOPIC_PARTITION)); + EasyMock.expectLastCall(); + sinkTask.put(Collections.emptyList()); + EasyMock.expectLastCall(); + + EasyMock.expect(consumer.poll(Duration.ofMillis(EasyMock.anyLong()))).andAnswer( + () -> { + rebalanceListener.getValue().onPartitionsRevoked(Collections.emptySet()); + rebalanceListener.getValue().onPartitionsAssigned(Collections.singleton(TOPIC_PARTITION3)); + return ConsumerRecords.empty(); + }); + EasyMock.expect(consumer.assignment()).andReturn(new HashSet<>(Arrays.asList(TOPIC_PARTITION2, TOPIC_PARTITION3))).times(2); + EasyMock.expect(consumer.position(TOPIC_PARTITION3)).andReturn(FIRST_OFFSET); + sinkTask.open(Collections.singleton(TOPIC_PARTITION3)); + EasyMock.expectLastCall(); + sinkTask.put(Collections.emptyList()); + EasyMock.expectLastCall(); + + EasyMock.expect(consumer.poll(Duration.ofMillis(EasyMock.anyLong()))).andAnswer( + () -> { + rebalanceListener.getValue().onPartitionsLost(Collections.singleton(TOPIC_PARTITION3)); + rebalanceListener.getValue().onPartitionsAssigned(Collections.singleton(TOPIC_PARTITION)); + return ConsumerRecords.empty(); + }); + EasyMock.expect(consumer.assignment()).andReturn(INITIAL_ASSIGNMENT).times(4); + sinkTask.close(Collections.singleton(TOPIC_PARTITION3)); + EasyMock.expectLastCall(); + EasyMock.expect(consumer.position(TOPIC_PARTITION)).andReturn(FIRST_OFFSET); + sinkTask.open(Collections.singleton(TOPIC_PARTITION)); + EasyMock.expectLastCall(); + sinkTask.put(Collections.emptyList()); + EasyMock.expectLastCall(); + + PowerMock.replayAll(); + + workerTask.initialize(TASK_CONFIG); + workerTask.initializeAndStart(); + // First iteration--first call to poll, first consumer assignment + workerTask.iteration(); + // Second iteration--second call to poll, partial consumer revocation + workerTask.iteration(); + // Third iteration--third call to poll, partial consumer assignment + workerTask.iteration(); + // Fourth iteration--fourth call to poll, one partition lost; can't commit offsets for it, one new partition assigned + workerTask.iteration(); + + PowerMock.verifyAll(); + } + + @Test + public void testPreCommitFailureAfterPartialRevocationAndAssignment() throws Exception { + createTask(initialState); + + // First poll; assignment is [TP1, TP2] + expectInitializeTask(); + expectTaskGetTopic(true); + expectPollInitialAssignment(); + + // Second poll; a single record is delivered from TP1 + expectConsumerPoll(1); + expectConversionAndTransformation(1); + sinkTask.put(EasyMock.anyObject()); + EasyMock.expectLastCall(); + + // Third poll; 
assignment changes to [TP2] + EasyMock.expect(consumer.poll(Duration.ofMillis(EasyMock.anyLong()))).andAnswer( + () -> { + rebalanceListener.getValue().onPartitionsRevoked(Collections.singleton(TOPIC_PARTITION)); + rebalanceListener.getValue().onPartitionsAssigned(Collections.emptySet()); + return ConsumerRecords.empty(); + }); + EasyMock.expect(consumer.assignment()).andReturn(Collections.singleton(TOPIC_PARTITION)).times(2); + final Map offsets = new HashMap<>(); + offsets.put(TOPIC_PARTITION, new OffsetAndMetadata(FIRST_OFFSET + 1)); + sinkTask.preCommit(offsets); + EasyMock.expectLastCall().andReturn(offsets); + consumer.commitSync(offsets); + EasyMock.expectLastCall(); + sinkTask.close(Collections.singleton(TOPIC_PARTITION)); + EasyMock.expectLastCall(); + sinkTask.put(Collections.emptyList()); + EasyMock.expectLastCall(); + + // Fourth poll; assignment changes to [TP2, TP3] + EasyMock.expect(consumer.poll(Duration.ofMillis(EasyMock.anyLong()))).andAnswer( + () -> { + rebalanceListener.getValue().onPartitionsRevoked(Collections.emptySet()); + rebalanceListener.getValue().onPartitionsAssigned(Collections.singleton(TOPIC_PARTITION3)); + return ConsumerRecords.empty(); + }); + EasyMock.expect(consumer.assignment()).andReturn(new HashSet<>(Arrays.asList(TOPIC_PARTITION2, TOPIC_PARTITION3))).times(2); + EasyMock.expect(consumer.position(TOPIC_PARTITION3)).andReturn(FIRST_OFFSET); + sinkTask.open(Collections.singleton(TOPIC_PARTITION3)); + EasyMock.expectLastCall(); + sinkTask.put(Collections.emptyList()); + EasyMock.expectLastCall(); + + // Fifth poll; an offset commit takes place + EasyMock.expect(consumer.assignment()).andReturn(new HashSet<>(Arrays.asList(TOPIC_PARTITION2, TOPIC_PARTITION3))).times(2); + final Map workerCurrentOffsets = new HashMap<>(); + workerCurrentOffsets.put(TOPIC_PARTITION2, new OffsetAndMetadata(FIRST_OFFSET)); + workerCurrentOffsets.put(TOPIC_PARTITION3, new OffsetAndMetadata(FIRST_OFFSET)); + sinkTask.preCommit(workerCurrentOffsets); + EasyMock.expectLastCall().andThrow(new ConnectException("Failed to flush")); + + consumer.seek(TOPIC_PARTITION2, FIRST_OFFSET); + EasyMock.expectLastCall(); + consumer.seek(TOPIC_PARTITION3, FIRST_OFFSET); + EasyMock.expectLastCall(); + + expectConsumerPoll(0); + sinkTask.put(EasyMock.eq(Collections.emptyList())); + EasyMock.expectLastCall(); + + PowerMock.replayAll(); + + workerTask.initialize(TASK_CONFIG); + workerTask.initializeAndStart(); + // First iteration--first call to poll, first consumer assignment + workerTask.iteration(); + // Second iteration--second call to poll, delivery of one record + workerTask.iteration(); + // Third iteration--third call to poll, partial consumer revocation + workerTask.iteration(); + // Fourth iteration--fourth call to poll, partial consumer assignment + workerTask.iteration(); + // Fifth iteration--task-requested offset commit with failure in SinkTask::preCommit + sinkTaskContext.getValue().requestCommit(); + workerTask.iteration(); + + PowerMock.verifyAll(); + } + + @Test + public void testWakeupInCommitSyncCausesRetry() throws Exception { + createTask(initialState); + + expectInitializeTask(); + expectTaskGetTopic(true); + expectPollInitialAssignment(); + + expectConsumerPoll(1); + expectConversionAndTransformation(1); + sinkTask.put(EasyMock.anyObject()); + EasyMock.expectLastCall(); + + final Map offsets = new HashMap<>(); + offsets.put(TOPIC_PARTITION, new OffsetAndMetadata(FIRST_OFFSET + 1)); + offsets.put(TOPIC_PARTITION2, new OffsetAndMetadata(FIRST_OFFSET)); + 
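Illustrative aside, not code from this patch: the expectations being set up here encode the behaviour testWakeupInCommitSyncCausesRetry verifies. consumer.wakeup() makes the next blocking consumer call throw WakeupException, and the worker is expected to swallow the first one raised from commitSync(...) and simply retry the commit. A minimal sketch of that retry shape, using a hypothetical commit Runnable rather than the worker's real code path:

import org.apache.kafka.common.errors.WakeupException;

public class WakeupRetrySketch {
    /**
     * Runs a synchronous commit, tolerating a single WakeupException.
     * The wakeup is reported to the caller so it can still be acted on afterwards,
     * which is roughly what the test asserts: the commit completes and the task keeps running.
     */
    static boolean commitTolerantOfWakeup(Runnable commitSync) {
        try {
            commitSync.run();
            return false;            // no wakeup observed
        } catch (WakeupException e) {
            commitSync.run();        // retry once; a second wakeup would propagate
            return true;             // signal that a wakeup was swallowed
        }
    }

    public static void main(String[] args) {
        int[] calls = {0};
        boolean wokenUp = commitTolerantOfWakeup(() -> {
            if (calls[0]++ == 0) {
                throw new WakeupException();   // first attempt interrupted, as in the test
            }
        });
        System.out.println("commit attempts=" + calls[0] + ", wakeup swallowed=" + wokenUp);
    }
}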
sinkTask.preCommit(offsets); + EasyMock.expectLastCall().andReturn(offsets); + + // first one raises wakeup + consumer.commitSync(EasyMock.>anyObject()); + EasyMock.expectLastCall().andThrow(new WakeupException()); + + // we should retry and complete the commit + consumer.commitSync(EasyMock.>anyObject()); + EasyMock.expectLastCall(); + + sinkTask.close(INITIAL_ASSIGNMENT); + EasyMock.expectLastCall(); + + INITIAL_ASSIGNMENT.forEach(tp -> EasyMock.expect(consumer.position(tp)).andReturn(FIRST_OFFSET)); + + sinkTask.open(INITIAL_ASSIGNMENT); + EasyMock.expectLastCall(); + + EasyMock.expect(consumer.assignment()).andReturn(INITIAL_ASSIGNMENT).times(5); + EasyMock.expect(consumer.poll(Duration.ofMillis(EasyMock.anyLong()))).andAnswer( + () -> { + rebalanceListener.getValue().onPartitionsRevoked(INITIAL_ASSIGNMENT); + rebalanceListener.getValue().onPartitionsAssigned(INITIAL_ASSIGNMENT); + return ConsumerRecords.empty(); + }); + + INITIAL_ASSIGNMENT.forEach(tp -> { + consumer.resume(Collections.singleton(tp)); + EasyMock.expectLastCall(); + }); + + statusListener.onResume(taskId); + EasyMock.expectLastCall(); + + PowerMock.replayAll(); + + workerTask.initialize(TASK_CONFIG); + time.sleep(30000L); + workerTask.initializeAndStart(); + time.sleep(30000L); + + workerTask.iteration(); // poll for initial assignment + time.sleep(30000L); + workerTask.iteration(); // first record delivered + workerTask.iteration(); // now rebalance with the wakeup triggered + time.sleep(30000L); + + assertSinkMetricValue("partition-count", 2); + assertSinkMetricValue("sink-record-read-total", 1.0); + assertSinkMetricValue("sink-record-send-total", 1.0); + assertSinkMetricValue("sink-record-active-count", 0.0); + assertSinkMetricValue("sink-record-active-count-max", 1.0); + assertSinkMetricValue("sink-record-active-count-avg", 0.33333); + assertSinkMetricValue("offset-commit-seq-no", 1.0); + assertSinkMetricValue("offset-commit-completion-total", 1.0); + assertSinkMetricValue("offset-commit-skip-total", 0.0); + assertTaskMetricValue("status", "running"); + assertTaskMetricValue("running-ratio", 1.0); + assertTaskMetricValue("pause-ratio", 0.0); + assertTaskMetricValue("batch-size-max", 1.0); + assertTaskMetricValue("batch-size-avg", 1.0); + assertTaskMetricValue("offset-commit-max-time-ms", 0.0); + assertTaskMetricValue("offset-commit-avg-time-ms", 0.0); + assertTaskMetricValue("offset-commit-failure-percentage", 0.0); + assertTaskMetricValue("offset-commit-success-percentage", 1.0); + + PowerMock.verifyAll(); + } + + @Test + public void testWakeupNotThrownDuringShutdown() throws Exception { + createTask(initialState); + + expectInitializeTask(); + expectTaskGetTopic(true); + expectPollInitialAssignment(); + + expectConsumerPoll(1); + expectConversionAndTransformation(1); + sinkTask.put(EasyMock.anyObject()); + EasyMock.expectLastCall(); + + EasyMock.expect(consumer.poll(Duration.ofMillis(EasyMock.anyLong()))).andAnswer(() -> { + // stop the task during its second iteration + workerTask.stop(); + return new ConsumerRecords<>(Collections.emptyMap()); + }); + consumer.wakeup(); + EasyMock.expectLastCall(); + + sinkTask.put(EasyMock.eq(Collections.emptyList())); + EasyMock.expectLastCall(); + + EasyMock.expect(consumer.assignment()).andReturn(INITIAL_ASSIGNMENT).times(1); + + final Map offsets = new HashMap<>(); + offsets.put(TOPIC_PARTITION, new OffsetAndMetadata(FIRST_OFFSET + 1)); + offsets.put(TOPIC_PARTITION2, new OffsetAndMetadata(FIRST_OFFSET)); + sinkTask.preCommit(offsets); + 
EasyMock.expectLastCall().andReturn(offsets); + + sinkTask.close(EasyMock.anyObject()); + PowerMock.expectLastCall(); + + // fail the first time + consumer.commitSync(EasyMock.eq(offsets)); + EasyMock.expectLastCall().andThrow(new WakeupException()); + + // and succeed the second time + consumer.commitSync(EasyMock.eq(offsets)); + EasyMock.expectLastCall(); + + PowerMock.replayAll(); + + workerTask.initialize(TASK_CONFIG); + workerTask.initializeAndStart(); + workerTask.execute(); + + assertEquals(0, workerTask.commitFailures()); + + PowerMock.verifyAll(); + } + + @Test + public void testRequestCommit() throws Exception { + createTask(initialState); + + expectInitializeTask(); + expectTaskGetTopic(true); + expectPollInitialAssignment(); + + expectConsumerPoll(1); + expectConversionAndTransformation(1); + sinkTask.put(EasyMock.anyObject()); + EasyMock.expectLastCall(); + + final Map offsets = new HashMap<>(); + offsets.put(TOPIC_PARTITION, new OffsetAndMetadata(FIRST_OFFSET + 1)); + offsets.put(TOPIC_PARTITION2, new OffsetAndMetadata(FIRST_OFFSET)); + sinkTask.preCommit(offsets); + EasyMock.expectLastCall().andReturn(offsets); + + EasyMock.expect(consumer.assignment()).andReturn(INITIAL_ASSIGNMENT).times(2); + + final Capture callback = EasyMock.newCapture(); + consumer.commitAsync(EasyMock.eq(offsets), EasyMock.capture(callback)); + EasyMock.expectLastCall().andAnswer(() -> { + callback.getValue().onComplete(offsets, null); + return null; + }); + + expectConsumerPoll(0); + sinkTask.put(Collections.emptyList()); + EasyMock.expectLastCall(); + + PowerMock.replayAll(); + + workerTask.initialize(TASK_CONFIG); + workerTask.initializeAndStart(); + + // Initial assignment + time.sleep(30000L); + workerTask.iteration(); + assertSinkMetricValue("partition-count", 2); + + // First record delivered + workerTask.iteration(); + assertSinkMetricValue("partition-count", 2); + assertSinkMetricValue("sink-record-read-total", 1.0); + assertSinkMetricValue("sink-record-send-total", 1.0); + assertSinkMetricValue("sink-record-active-count", 1.0); + assertSinkMetricValue("sink-record-active-count-max", 1.0); + assertSinkMetricValue("sink-record-active-count-avg", 0.333333); + assertSinkMetricValue("offset-commit-seq-no", 0.0); + assertSinkMetricValue("offset-commit-completion-total", 0.0); + assertSinkMetricValue("offset-commit-skip-total", 0.0); + assertTaskMetricValue("status", "running"); + assertTaskMetricValue("running-ratio", 1.0); + assertTaskMetricValue("pause-ratio", 0.0); + assertTaskMetricValue("batch-size-max", 1.0); + assertTaskMetricValue("batch-size-avg", 0.5); + assertTaskMetricValue("offset-commit-failure-percentage", 0.0); + assertTaskMetricValue("offset-commit-success-percentage", 0.0); + + // Grab the commit time prior to requesting a commit. + // This time should advance slightly after committing. 
+ // KAFKA-8229 + final long previousCommitValue = workerTask.getNextCommit(); + sinkTaskContext.getValue().requestCommit(); + assertTrue(sinkTaskContext.getValue().isCommitRequested()); + assertNotEquals(offsets, Whitebox.>getInternalState(workerTask, "lastCommittedOffsets")); + time.sleep(10000L); + workerTask.iteration(); // triggers the commit + time.sleep(10000L); + assertFalse(sinkTaskContext.getValue().isCommitRequested()); // should have been cleared + assertEquals(offsets, Whitebox.>getInternalState(workerTask, "lastCommittedOffsets")); + assertEquals(0, workerTask.commitFailures()); + // Assert the next commit time advances slightly, the amount it advances + // is the normal commit time less the two sleeps since it started each + // of those sleeps were 10 seconds. + // KAFKA-8229 + assertEquals("Should have only advanced by 40 seconds", + previousCommitValue + + (WorkerConfig.OFFSET_COMMIT_INTERVAL_MS_DEFAULT - 10000L * 2), + workerTask.getNextCommit()); + + assertSinkMetricValue("partition-count", 2); + assertSinkMetricValue("sink-record-read-total", 1.0); + assertSinkMetricValue("sink-record-send-total", 1.0); + assertSinkMetricValue("sink-record-active-count", 0.0); + assertSinkMetricValue("sink-record-active-count-max", 1.0); + assertSinkMetricValue("sink-record-active-count-avg", 0.2); + assertSinkMetricValue("offset-commit-seq-no", 1.0); + assertSinkMetricValue("offset-commit-completion-total", 1.0); + assertSinkMetricValue("offset-commit-skip-total", 0.0); + assertTaskMetricValue("status", "running"); + assertTaskMetricValue("running-ratio", 1.0); + assertTaskMetricValue("pause-ratio", 0.0); + assertTaskMetricValue("batch-size-max", 1.0); + assertTaskMetricValue("batch-size-avg", 0.33333); + assertTaskMetricValue("offset-commit-max-time-ms", 0.0); + assertTaskMetricValue("offset-commit-avg-time-ms", 0.0); + assertTaskMetricValue("offset-commit-failure-percentage", 0.0); + assertTaskMetricValue("offset-commit-success-percentage", 1.0); + + PowerMock.verifyAll(); + } + + @Test + public void testPreCommit() throws Exception { + createTask(initialState); + + expectInitializeTask(); + expectTaskGetTopic(true); + + // iter 1 + expectPollInitialAssignment(); + + // iter 2 + expectConsumerPoll(2); + expectConversionAndTransformation(2); + sinkTask.put(EasyMock.anyObject()); + EasyMock.expectLastCall(); + + final Map workerStartingOffsets = new HashMap<>(); + workerStartingOffsets.put(TOPIC_PARTITION, new OffsetAndMetadata(FIRST_OFFSET)); + workerStartingOffsets.put(TOPIC_PARTITION2, new OffsetAndMetadata(FIRST_OFFSET)); + + final Map workerCurrentOffsets = new HashMap<>(); + workerCurrentOffsets.put(TOPIC_PARTITION, new OffsetAndMetadata(FIRST_OFFSET + 2)); + workerCurrentOffsets.put(TOPIC_PARTITION2, new OffsetAndMetadata(FIRST_OFFSET)); + + final Map taskOffsets = new HashMap<>(); + taskOffsets.put(TOPIC_PARTITION, new OffsetAndMetadata(FIRST_OFFSET + 1)); // act like FIRST_OFFSET+2 has not yet been flushed by the task + taskOffsets.put(TOPIC_PARTITION2, new OffsetAndMetadata(FIRST_OFFSET + 1)); // should be ignored because > current offset + taskOffsets.put(new TopicPartition(TOPIC, 3), new OffsetAndMetadata(FIRST_OFFSET)); // should be ignored because this partition is not assigned + + final Map committableOffsets = new HashMap<>(); + committableOffsets.put(TOPIC_PARTITION, new OffsetAndMetadata(FIRST_OFFSET + 1)); + committableOffsets.put(TOPIC_PARTITION2, new OffsetAndMetadata(FIRST_OFFSET)); + + sinkTask.preCommit(workerCurrentOffsets); + 
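Illustrative aside, not code from this patch: the three comments on taskOffsets spell out the reconciliation rule that testPreCommit pins down. Whatever SinkTask.preCommit(...) returns is checked against what the worker has actually consumed: entries for unassigned partitions are dropped, and an offset ahead of the current consumed position falls back to that position. A small sketch of that reconciliation against the public TopicPartition/OffsetAndMetadata types (the method name is invented; this is not the worker's API):

import java.util.HashMap;
import java.util.Map;
import org.apache.kafka.clients.consumer.OffsetAndMetadata;
import org.apache.kafka.common.TopicPartition;

public class CommittableOffsetsSketch {
    /** Mirrors the committableOffsets expectation above: skip unknown partitions, cap at consumed position. */
    static Map<TopicPartition, OffsetAndMetadata> committable(
            Map<TopicPartition, OffsetAndMetadata> taskProvided,
            Map<TopicPartition, OffsetAndMetadata> consumedPositions) {
        Map<TopicPartition, OffsetAndMetadata> result = new HashMap<>();
        taskProvided.forEach((tp, taskOffset) -> {
            OffsetAndMetadata consumed = consumedPositions.get(tp);
            if (consumed == null) {
                return;                                  // partition not assigned: skip
            }
            result.put(tp, taskOffset.offset() <= consumed.offset() ? taskOffset : consumed);
        });
        return result;
    }

    public static void main(String[] args) {
        TopicPartition tp1 = new TopicPartition("test", 12);
        TopicPartition tp2 = new TopicPartition("test", 13);
        TopicPartition tp4 = new TopicPartition("test", 3);

        Map<TopicPartition, OffsetAndMetadata> consumed = new HashMap<>();
        consumed.put(tp1, new OffsetAndMetadata(47));    // FIRST_OFFSET + 2
        consumed.put(tp2, new OffsetAndMetadata(45));    // FIRST_OFFSET

        Map<TopicPartition, OffsetAndMetadata> fromTask = new HashMap<>();
        fromTask.put(tp1, new OffsetAndMetadata(46));    // behind consumed position: kept as-is
        fromTask.put(tp2, new OffsetAndMetadata(46));    // ahead of consumed position: capped at 45
        fromTask.put(tp4, new OffsetAndMetadata(45));    // not assigned: dropped

        System.out.println(committable(fromTask, consumed));
    }
}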
EasyMock.expectLastCall().andReturn(taskOffsets); + // Expect extra invalid topic partition to be filtered, which causes the consumer assignment to be logged + EasyMock.expect(consumer.assignment()).andReturn(INITIAL_ASSIGNMENT).times(2); + final Capture callback = EasyMock.newCapture(); + consumer.commitAsync(EasyMock.eq(committableOffsets), EasyMock.capture(callback)); + EasyMock.expectLastCall().andAnswer(() -> { + callback.getValue().onComplete(committableOffsets, null); + return null; + }); + expectConsumerPoll(0); + sinkTask.put(EasyMock.anyObject()); + EasyMock.expectLastCall(); + + PowerMock.replayAll(); + + workerTask.initialize(TASK_CONFIG); + workerTask.initializeAndStart(); + workerTask.iteration(); // iter 1 -- initial assignment + + assertEquals(workerStartingOffsets, Whitebox.>getInternalState(workerTask, "currentOffsets")); + workerTask.iteration(); // iter 2 -- deliver 2 records + + assertEquals(workerCurrentOffsets, Whitebox.>getInternalState(workerTask, "currentOffsets")); + assertEquals(workerStartingOffsets, Whitebox.>getInternalState(workerTask, "lastCommittedOffsets")); + sinkTaskContext.getValue().requestCommit(); + workerTask.iteration(); // iter 3 -- commit + assertEquals(committableOffsets, Whitebox.>getInternalState(workerTask, "lastCommittedOffsets")); + + PowerMock.verifyAll(); + } + + @Test + public void testPreCommitFailure() throws Exception { + createTask(initialState); + + expectInitializeTask(); + expectTaskGetTopic(true); + EasyMock.expect(consumer.assignment()).andStubReturn(INITIAL_ASSIGNMENT); + + // iter 1 + expectPollInitialAssignment(); + + // iter 2 + expectConsumerPoll(2); + expectConversionAndTransformation(2); + sinkTask.put(EasyMock.anyObject()); + EasyMock.expectLastCall(); + + // iter 3 + final Map workerCurrentOffsets = new HashMap<>(); + workerCurrentOffsets.put(TOPIC_PARTITION, new OffsetAndMetadata(FIRST_OFFSET + 2)); + workerCurrentOffsets.put(TOPIC_PARTITION2, new OffsetAndMetadata(FIRST_OFFSET)); + sinkTask.preCommit(workerCurrentOffsets); + EasyMock.expectLastCall().andThrow(new ConnectException("Failed to flush")); + + consumer.seek(TOPIC_PARTITION, FIRST_OFFSET); + EasyMock.expectLastCall(); + consumer.seek(TOPIC_PARTITION2, FIRST_OFFSET); + EasyMock.expectLastCall(); + + expectConsumerPoll(0); + sinkTask.put(EasyMock.eq(Collections.emptyList())); + EasyMock.expectLastCall(); + + PowerMock.replayAll(); + + workerTask.initialize(TASK_CONFIG); + workerTask.initializeAndStart(); + workerTask.iteration(); // iter 1 -- initial assignment + workerTask.iteration(); // iter 2 -- deliver 2 records + sinkTaskContext.getValue().requestCommit(); + workerTask.iteration(); // iter 3 -- commit + + PowerMock.verifyAll(); + } + + @Test + public void testIgnoredCommit() throws Exception { + createTask(initialState); + + expectInitializeTask(); + expectTaskGetTopic(true); + + // iter 1 + expectPollInitialAssignment(); + + // iter 2 + expectConsumerPoll(1); + expectConversionAndTransformation(1); + sinkTask.put(EasyMock.anyObject()); + EasyMock.expectLastCall(); + + final Map workerStartingOffsets = new HashMap<>(); + workerStartingOffsets.put(TOPIC_PARTITION, new OffsetAndMetadata(FIRST_OFFSET)); + workerStartingOffsets.put(TOPIC_PARTITION2, new OffsetAndMetadata(FIRST_OFFSET)); + + final Map workerCurrentOffsets = new HashMap<>(); + workerCurrentOffsets.put(TOPIC_PARTITION, new OffsetAndMetadata(FIRST_OFFSET + 1)); + workerCurrentOffsets.put(TOPIC_PARTITION2, new OffsetAndMetadata(FIRST_OFFSET)); + + 
EasyMock.expect(consumer.assignment()).andReturn(INITIAL_ASSIGNMENT).times(2); + + // iter 3 + sinkTask.preCommit(workerCurrentOffsets); + EasyMock.expectLastCall().andReturn(workerStartingOffsets); + // no actual consumer.commit() triggered + expectConsumerPoll(0); + sinkTask.put(EasyMock.anyObject()); + EasyMock.expectLastCall(); + + PowerMock.replayAll(); + + workerTask.initialize(TASK_CONFIG); + workerTask.initializeAndStart(); + workerTask.iteration(); // iter 1 -- initial assignment + + assertEquals(workerStartingOffsets, Whitebox.>getInternalState(workerTask, "currentOffsets")); + assertEquals(workerStartingOffsets, Whitebox.>getInternalState(workerTask, "lastCommittedOffsets")); + + workerTask.iteration(); // iter 2 -- deliver 2 records + + sinkTaskContext.getValue().requestCommit(); + workerTask.iteration(); // iter 3 -- commit + + PowerMock.verifyAll(); + } + + // Test that the commitTimeoutMs timestamp is correctly computed and checked in WorkerSinkTask.iteration() + // when there is a long running commit in process. See KAFKA-4942 for more information. + @Test + public void testLongRunningCommitWithoutTimeout() throws Exception { + createTask(initialState); + + expectInitializeTask(); + expectTaskGetTopic(true); + + // iter 1 + expectPollInitialAssignment(); + + // iter 2 + expectConsumerPoll(1); + expectConversionAndTransformation(1); + sinkTask.put(EasyMock.anyObject()); + EasyMock.expectLastCall(); + + final Map workerStartingOffsets = new HashMap<>(); + workerStartingOffsets.put(TOPIC_PARTITION, new OffsetAndMetadata(FIRST_OFFSET)); + workerStartingOffsets.put(TOPIC_PARTITION2, new OffsetAndMetadata(FIRST_OFFSET)); + + final Map workerCurrentOffsets = new HashMap<>(); + workerCurrentOffsets.put(TOPIC_PARTITION, new OffsetAndMetadata(FIRST_OFFSET + 1)); + workerCurrentOffsets.put(TOPIC_PARTITION2, new OffsetAndMetadata(FIRST_OFFSET)); + + EasyMock.expect(consumer.assignment()).andReturn(INITIAL_ASSIGNMENT).times(2); + + // iter 3 - note that we return the current offset to indicate they should be committed + sinkTask.preCommit(workerCurrentOffsets); + EasyMock.expectLastCall().andReturn(workerCurrentOffsets); + + // We need to delay the result of trying to commit offsets to Kafka via the consumer.commitAsync + // method. We do this so that we can test that we do not erroneously mark a commit as timed out + // while it is still running and under time. To fake this for tests we have the commit run in a + // separate thread and wait for a latch which we control back in the main thread. 
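Illustrative aside, not code from this patch: the comment above describes the trick this test relies on. The answer installed for consumer.commitAsync(...) hands the callback to a single-thread executor that blocks on a CountDownLatch, so the main thread can assert that the worker still considers the commit in flight before letting it complete. A stripped-down, self-contained version of that pattern, using only the JDK plus the consumer callback type (the class itself is hypothetical):

import java.util.Collections;
import java.util.Map;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import org.apache.kafka.clients.consumer.OffsetAndMetadata;
import org.apache.kafka.clients.consumer.OffsetCommitCallback;
import org.apache.kafka.common.TopicPartition;

public class DeferredCommitCallbackSketch {
    public static void main(String[] args) throws InterruptedException {
        ExecutorService executor = Executors.newSingleThreadExecutor();
        CountDownLatch latch = new CountDownLatch(1);
        AtomicBoolean committing = new AtomicBoolean(true);   // stands in for the worker's "committing" state

        Map<TopicPartition, OffsetAndMetadata> offsets =
            Collections.singletonMap(new TopicPartition("test", 12), new OffsetAndMetadata(46));
        OffsetCommitCallback callback = (committed, error) -> committing.set(false);

        // "commitAsync": complete the callback only once the latch is released.
        executor.execute(() -> {
            try {
                latch.await();
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
            }
            callback.onComplete(offsets, null);
        });

        // While the latch is held, the commit must still look in-flight.
        System.out.println("still committing? " + committing.get());   // true

        latch.countDown();                 // let the "async commit" finish
        executor.shutdown();
        executor.awaitTermination(30, TimeUnit.SECONDS);
        System.out.println("still committing? " + committing.get());   // false
    }
}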
+ final ExecutorService executor = Executors.newSingleThreadExecutor(); + final CountDownLatch latch = new CountDownLatch(1); + + consumer.commitAsync(EasyMock.eq(workerCurrentOffsets), EasyMock.anyObject()); + EasyMock.expectLastCall().andAnswer(() -> { + // Grab the arguments passed to the consumer.commitAsync method + final Object[] args = EasyMock.getCurrentArguments(); + @SuppressWarnings("unchecked") + final Map offsets = (Map) args[0]; + final OffsetCommitCallback callback = (OffsetCommitCallback) args[1]; + + executor.execute(() -> { + try { + latch.await(); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + + callback.onComplete(offsets, null); + }); + + return null; + }); + + // no actual consumer.commit() triggered + expectConsumerPoll(0); + + sinkTask.put(EasyMock.anyObject()); + EasyMock.expectLastCall(); + + PowerMock.replayAll(); + + workerTask.initialize(TASK_CONFIG); + workerTask.initializeAndStart(); + workerTask.iteration(); // iter 1 -- initial assignment + + assertEquals(workerStartingOffsets, Whitebox.>getInternalState(workerTask, "currentOffsets")); + assertEquals(workerStartingOffsets, Whitebox.>getInternalState(workerTask, "lastCommittedOffsets")); + + time.sleep(WorkerConfig.OFFSET_COMMIT_TIMEOUT_MS_DEFAULT); + workerTask.iteration(); // iter 2 -- deliver 2 records + + sinkTaskContext.getValue().requestCommit(); + workerTask.iteration(); // iter 3 -- commit in progress + + // Make sure the "committing" flag didn't immediately get flipped back to false due to an incorrect timeout + assertTrue("Expected worker to be in the process of committing offsets", workerTask.isCommitting()); + + // Let the async commit finish and wait for it to end + latch.countDown(); + executor.shutdown(); + executor.awaitTermination(30, TimeUnit.SECONDS); + + assertEquals(workerCurrentOffsets, Whitebox.>getInternalState(workerTask, "currentOffsets")); + assertEquals(workerCurrentOffsets, Whitebox.>getInternalState(workerTask, "lastCommittedOffsets")); + + PowerMock.verifyAll(); + } + + @Test + public void testSinkTasksHandleCloseErrors() throws Exception { + createTask(initialState); + expectInitializeTask(); + expectTaskGetTopic(true); + + expectPollInitialAssignment(); + + // Put one message through the task to get some offsets to commit + expectConsumerPoll(1); + expectConversionAndTransformation(1); + sinkTask.put(EasyMock.anyObject()); + PowerMock.expectLastCall().andVoid(); + + // Stop the task during the next put + expectConsumerPoll(1); + expectConversionAndTransformation(1); + sinkTask.put(EasyMock.anyObject()); + PowerMock.expectLastCall().andAnswer(() -> { + workerTask.stop(); + return null; + }); + + consumer.wakeup(); + PowerMock.expectLastCall(); + + // Throw another exception while closing the task's assignment + EasyMock.expect(sinkTask.preCommit(EasyMock.anyObject())) + .andStubReturn(Collections.emptyMap()); + Throwable closeException = new RuntimeException(); + sinkTask.close(EasyMock.anyObject()); + PowerMock.expectLastCall().andThrow(closeException); + + PowerMock.replayAll(); + + workerTask.initialize(TASK_CONFIG); + workerTask.initializeAndStart(); + try { + workerTask.execute(); + fail("workerTask.execute should have thrown an exception"); + } catch (RuntimeException e) { + PowerMock.verifyAll(); + assertSame("Exception from close should propagate as-is", closeException, e); + } + } + + @Test + public void testSuppressCloseErrors() throws Exception { + createTask(initialState); + expectInitializeTask(); + expectTaskGetTopic(true); + + 
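Illustrative aside, not code from this patch: testSinkTasksHandleCloseErrors and testSuppressCloseErrors together pin down an error-propagation contract. If only close(...) fails, that exception propagates as-is; if put(...) already failed, a later close failure must not mask it and is instead attached as a suppressed exception on a ConnectException whose cause is the put failure. A small sketch of that wrapping (the helper method is hypothetical, not the worker's actual code):

import org.apache.kafka.connect.errors.ConnectException;

public class SuppressedCloseErrorSketch {
    /** Closes resources after a failed delivery without losing the original failure. */
    static RuntimeException deliveryFailure(Throwable putException, Runnable closeResources) {
        ConnectException wrapped = new ConnectException("Task failed during put", putException);
        try {
            closeResources.run();
        } catch (Throwable closeException) {
            wrapped.addSuppressed(closeException);   // keep the close error without masking the cause
        }
        return wrapped;
    }

    public static void main(String[] args) {
        RuntimeException putException = new RuntimeException("put failed");
        RuntimeException closeException = new RuntimeException("close failed");

        RuntimeException e = deliveryFailure(putException, () -> {
            throw closeException;
        });

        System.out.println("cause      = " + e.getCause().getMessage());           // put failed
        System.out.println("suppressed = " + e.getSuppressed()[0].getMessage());   // close failed
    }
}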
expectPollInitialAssignment(); + + // Put one message through the task to get some offsets to commit + expectConsumerPoll(1); + expectConversionAndTransformation(1); + sinkTask.put(EasyMock.anyObject()); + PowerMock.expectLastCall().andVoid(); + + // Throw an exception on the next put to trigger shutdown behavior + // This exception is the true "cause" of the failure + expectConsumerPoll(1); + expectConversionAndTransformation(1); + Throwable putException = new RuntimeException(); + sinkTask.put(EasyMock.anyObject()); + PowerMock.expectLastCall().andThrow(putException); + + // Throw another exception while closing the task's assignment + EasyMock.expect(sinkTask.preCommit(EasyMock.anyObject())) + .andStubReturn(Collections.emptyMap()); + Throwable closeException = new RuntimeException(); + sinkTask.close(EasyMock.anyObject()); + PowerMock.expectLastCall().andThrow(closeException); + + PowerMock.replayAll(); + + workerTask.initialize(TASK_CONFIG); + workerTask.initializeAndStart(); + try { + workerTask.execute(); + fail("workerTask.execute should have thrown an exception"); + } catch (ConnectException e) { + PowerMock.verifyAll(); + assertSame("Exception from put should be the cause", putException, e.getCause()); + assertTrue("Exception from close should be suppressed", e.getSuppressed().length > 0); + assertSame(closeException, e.getSuppressed()[0]); + } + } + + // Verify that when commitAsync is called but the supplied callback is not called by the consumer before a + // rebalance occurs, the async callback does not reset the last committed offset from the rebalance. + // See KAFKA-5731 for more information. + @Test + public void testCommitWithOutOfOrderCallback() throws Exception { + createTask(initialState); + + expectInitializeTask(); + expectTaskGetTopic(true); + + // iter 1 + expectPollInitialAssignment(); + + // iter 2 + expectConsumerPoll(1); + expectConversionAndTransformation(4); + sinkTask.put(EasyMock.anyObject()); + EasyMock.expectLastCall(); + + final Map workerStartingOffsets = new HashMap<>(); + workerStartingOffsets.put(TOPIC_PARTITION, new OffsetAndMetadata(FIRST_OFFSET)); + workerStartingOffsets.put(TOPIC_PARTITION2, new OffsetAndMetadata(FIRST_OFFSET)); + + final Map workerCurrentOffsets = new HashMap<>(); + workerCurrentOffsets.put(TOPIC_PARTITION, new OffsetAndMetadata(FIRST_OFFSET + 1)); + workerCurrentOffsets.put(TOPIC_PARTITION2, new OffsetAndMetadata(FIRST_OFFSET)); + + final List originalPartitions = new ArrayList<>(INITIAL_ASSIGNMENT); + final List rebalancedPartitions = asList(TOPIC_PARTITION, TOPIC_PARTITION2, TOPIC_PARTITION3); + final Map rebalanceOffsets = new HashMap<>(); + rebalanceOffsets.put(TOPIC_PARTITION, workerCurrentOffsets.get(TOPIC_PARTITION)); + rebalanceOffsets.put(TOPIC_PARTITION2, workerCurrentOffsets.get(TOPIC_PARTITION2)); + rebalanceOffsets.put(TOPIC_PARTITION3, new OffsetAndMetadata(FIRST_OFFSET)); + + final Map postRebalanceCurrentOffsets = new HashMap<>(); + postRebalanceCurrentOffsets.put(TOPIC_PARTITION, new OffsetAndMetadata(FIRST_OFFSET + 3)); + postRebalanceCurrentOffsets.put(TOPIC_PARTITION2, new OffsetAndMetadata(FIRST_OFFSET)); + postRebalanceCurrentOffsets.put(TOPIC_PARTITION3, new OffsetAndMetadata(FIRST_OFFSET + 2)); + + EasyMock.expect(consumer.assignment()).andReturn(new HashSet<>(originalPartitions)).times(2); + + // iter 3 - note that we return the current offset to indicate they should be committed + sinkTask.preCommit(workerCurrentOffsets); + EasyMock.expectLastCall().andReturn(workerCurrentOffsets); + + // We need to 
delay the result of trying to commit offsets to Kafka via the consumer.commitAsync + // method. We do this so that we can test that the callback is not called until after the rebalance + // changes the lastCommittedOffsets. To fake this for tests we have the commitAsync build a function + // that will call the callback with the appropriate parameters, and we'll run that function later. + final AtomicReference asyncCallbackRunner = new AtomicReference<>(); + final AtomicBoolean asyncCallbackRan = new AtomicBoolean(); + + consumer.commitAsync(EasyMock.eq(workerCurrentOffsets), EasyMock.anyObject()); + EasyMock.expectLastCall().andAnswer(() -> { + // Grab the arguments passed to the consumer.commitAsync method + final Object[] args = EasyMock.getCurrentArguments(); + @SuppressWarnings("unchecked") + final Map offsets = (Map) args[0]; + final OffsetCommitCallback callback = (OffsetCommitCallback) args[1]; + asyncCallbackRunner.set(() -> { + callback.onComplete(offsets, null); + asyncCallbackRan.set(true); + }); + return null; + }); + + // Expect the next poll to discover and perform the rebalance, THEN complete the previous callback handler, + // and then return one record for TP1 and one for TP3. + final AtomicBoolean rebalanced = new AtomicBoolean(); + EasyMock.expect(consumer.poll(Duration.ofMillis(EasyMock.anyLong()))).andAnswer( + () -> { + // Rebalance always begins with revoking current partitions ... + rebalanceListener.getValue().onPartitionsRevoked(originalPartitions); + // Respond to the rebalance + Map offsets = new HashMap<>(); + offsets.put(TOPIC_PARTITION, rebalanceOffsets.get(TOPIC_PARTITION).offset()); + offsets.put(TOPIC_PARTITION2, rebalanceOffsets.get(TOPIC_PARTITION2).offset()); + offsets.put(TOPIC_PARTITION3, rebalanceOffsets.get(TOPIC_PARTITION3).offset()); + sinkTaskContext.getValue().offset(offsets); + rebalanceListener.getValue().onPartitionsAssigned(rebalancedPartitions); + rebalanced.set(true); + + // Run the previous async commit handler + asyncCallbackRunner.get().run(); + + // And prep the two records to return + long timestamp = RecordBatch.NO_TIMESTAMP; + TimestampType timestampType = TimestampType.NO_TIMESTAMP_TYPE; + List> records = new ArrayList<>(); + records.add(new ConsumerRecord<>(TOPIC, PARTITION, FIRST_OFFSET + recordsReturnedTp1 + 1, timestamp, timestampType, + 0, 0, RAW_KEY, RAW_VALUE, new RecordHeaders(), Optional.empty())); + records.add(new ConsumerRecord<>(TOPIC, PARTITION3, FIRST_OFFSET + recordsReturnedTp3 + 1, timestamp, timestampType, + 0, 0, RAW_KEY, RAW_VALUE, new RecordHeaders(), Optional.empty())); + recordsReturnedTp1 += 1; + recordsReturnedTp3 += 1; + return new ConsumerRecords<>(Collections.singletonMap(new TopicPartition(TOPIC, PARTITION), records)); + }); + + // onPartitionsRevoked + sinkTask.preCommit(workerCurrentOffsets); + EasyMock.expectLastCall().andReturn(workerCurrentOffsets); + sinkTask.put(EasyMock.anyObject()); + EasyMock.expectLastCall(); + sinkTask.close(new ArrayList<>(workerCurrentOffsets.keySet())); + EasyMock.expectLastCall(); + consumer.commitSync(workerCurrentOffsets); + EasyMock.expectLastCall(); + + // onPartitionsAssigned - step 1 + final long offsetTp1 = rebalanceOffsets.get(TOPIC_PARTITION).offset(); + final long offsetTp2 = rebalanceOffsets.get(TOPIC_PARTITION2).offset(); + final long offsetTp3 = rebalanceOffsets.get(TOPIC_PARTITION3).offset(); + EasyMock.expect(consumer.position(TOPIC_PARTITION)).andReturn(offsetTp1); + EasyMock.expect(consumer.position(TOPIC_PARTITION2)).andReturn(offsetTp2); + 
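Illustrative aside, not code from this patch: the long comment and the AtomicReference above describe the KAFKA-5731 scenario this test reproduces. The commitAsync callback is held back until after a rebalance has already updated the last-committed offsets, and the stale callback must not overwrite them. One common way to express such a guard is a commit sequence number that the callback has to match; a minimal sketch of the idea follows (field and method names are illustrative, not WorkerSinkTask's internals):

import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import org.apache.kafka.clients.consumer.OffsetAndMetadata;
import org.apache.kafka.common.TopicPartition;

public class StaleCommitCallbackSketch {
    private long commitSeqno = 0;                       // incremented for every commit that is started
    private final Map<TopicPartition, OffsetAndMetadata> lastCommitted = new HashMap<>();

    /** Starts a commit and returns the sequence number the eventual callback must present. */
    long beginCommit() {
        return ++commitSeqno;
    }

    /** Applies a completed commit only if no newer commit superseded it in the meantime. */
    void onCommitCompleted(long seqno, Map<TopicPartition, OffsetAndMetadata> offsets) {
        if (seqno == commitSeqno) {
            lastCommitted.putAll(offsets);
        }
        // else: a later commit (e.g. the synchronous commit during a rebalance) already
        // advanced the state, so the stale result is dropped.
    }

    public static void main(String[] args) {
        StaleCommitCallbackSketch task = new StaleCommitCallbackSketch();
        TopicPartition tp = new TopicPartition("test", 12);

        long staleSeqno = task.beginCommit();            // async commit is started ...
        task.beginCommit();                              // ... but a newer commit happens first
        task.lastCommitted.put(tp, new OffsetAndMetadata(46));

        // The out-of-order callback finally arrives with an older offset and is ignored.
        task.onCommitCompleted(staleSeqno, Collections.singletonMap(tp, new OffsetAndMetadata(45)));
        System.out.println(task.lastCommitted);          // still offset 46
    }
}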
EasyMock.expect(consumer.position(TOPIC_PARTITION3)).andReturn(offsetTp3); + EasyMock.expect(consumer.assignment()).andReturn(new HashSet<>(rebalancedPartitions)).times(6); + + // onPartitionsAssigned - step 2 + sinkTask.open(EasyMock.eq(rebalancedPartitions)); + EasyMock.expectLastCall(); + + // onPartitionsAssigned - step 3 rewind + consumer.seek(TOPIC_PARTITION, offsetTp1); + EasyMock.expectLastCall(); + consumer.seek(TOPIC_PARTITION2, offsetTp2); + EasyMock.expectLastCall(); + consumer.seek(TOPIC_PARTITION3, offsetTp3); + EasyMock.expectLastCall(); + + // iter 4 - note that we return the current offset to indicate they should be committed + sinkTask.preCommit(postRebalanceCurrentOffsets); + EasyMock.expectLastCall().andReturn(postRebalanceCurrentOffsets); + + final Capture callback = EasyMock.newCapture(); + consumer.commitAsync(EasyMock.eq(postRebalanceCurrentOffsets), EasyMock.capture(callback)); + EasyMock.expectLastCall().andAnswer(() -> { + callback.getValue().onComplete(postRebalanceCurrentOffsets, null); + return null; + }); + + // no actual consumer.commit() triggered + expectConsumerPoll(1); + + sinkTask.put(EasyMock.anyObject()); + EasyMock.expectLastCall(); + + PowerMock.replayAll(); + + workerTask.initialize(TASK_CONFIG); + workerTask.initializeAndStart(); + workerTask.iteration(); // iter 1 -- initial assignment + + assertEquals(workerStartingOffsets, Whitebox.getInternalState(workerTask, "currentOffsets")); + assertEquals(workerStartingOffsets, Whitebox.getInternalState(workerTask, "lastCommittedOffsets")); + + time.sleep(WorkerConfig.OFFSET_COMMIT_TIMEOUT_MS_DEFAULT); + workerTask.iteration(); // iter 2 -- deliver 2 records + + sinkTaskContext.getValue().requestCommit(); + workerTask.iteration(); // iter 3 -- commit in progress + + assertSinkMetricValue("partition-count", 3); + assertSinkMetricValue("sink-record-read-total", 3.0); + assertSinkMetricValue("sink-record-send-total", 3.0); + assertSinkMetricValue("sink-record-active-count", 4.0); + assertSinkMetricValue("sink-record-active-count-max", 4.0); + assertSinkMetricValue("sink-record-active-count-avg", 0.71429); + assertSinkMetricValue("offset-commit-seq-no", 2.0); + assertSinkMetricValue("offset-commit-completion-total", 1.0); + assertSinkMetricValue("offset-commit-skip-total", 1.0); + assertTaskMetricValue("status", "running"); + assertTaskMetricValue("running-ratio", 1.0); + assertTaskMetricValue("pause-ratio", 0.0); + assertTaskMetricValue("batch-size-max", 2.0); + assertTaskMetricValue("batch-size-avg", 1.0); + assertTaskMetricValue("offset-commit-max-time-ms", 0.0); + assertTaskMetricValue("offset-commit-avg-time-ms", 0.0); + assertTaskMetricValue("offset-commit-failure-percentage", 0.0); + assertTaskMetricValue("offset-commit-success-percentage", 1.0); + + assertTrue(asyncCallbackRan.get()); + assertTrue(rebalanced.get()); + + // Check that the offsets were not reset by the out-of-order async commit callback + assertEquals(postRebalanceCurrentOffsets, Whitebox.getInternalState(workerTask, "currentOffsets")); + assertEquals(rebalanceOffsets, Whitebox.getInternalState(workerTask, "lastCommittedOffsets")); + + time.sleep(WorkerConfig.OFFSET_COMMIT_TIMEOUT_MS_DEFAULT); + sinkTaskContext.getValue().requestCommit(); + workerTask.iteration(); // iter 4 -- commit in progress + + // Check that the offsets were not reset by the out-of-order async commit callback + assertEquals(postRebalanceCurrentOffsets, Whitebox.getInternalState(workerTask, "currentOffsets")); + assertEquals(postRebalanceCurrentOffsets, 
Whitebox.getInternalState(workerTask, "lastCommittedOffsets"));
+
+        assertSinkMetricValue("partition-count", 3);
+        assertSinkMetricValue("sink-record-read-total", 4.0);
+        assertSinkMetricValue("sink-record-send-total", 4.0);
+        assertSinkMetricValue("sink-record-active-count", 0.0);
+        assertSinkMetricValue("sink-record-active-count-max", 4.0);
+        assertSinkMetricValue("sink-record-active-count-avg", 0.5555555);
+        assertSinkMetricValue("offset-commit-seq-no", 3.0);
+        assertSinkMetricValue("offset-commit-completion-total", 2.0);
+        assertSinkMetricValue("offset-commit-skip-total", 1.0);
+        assertTaskMetricValue("status", "running");
+        assertTaskMetricValue("running-ratio", 1.0);
+        assertTaskMetricValue("pause-ratio", 0.0);
+        assertTaskMetricValue("batch-size-max", 2.0);
+        assertTaskMetricValue("batch-size-avg", 1.0);
+        assertTaskMetricValue("offset-commit-max-time-ms", 0.0);
+        assertTaskMetricValue("offset-commit-avg-time-ms", 0.0);
+        assertTaskMetricValue("offset-commit-failure-percentage", 0.0);
+        assertTaskMetricValue("offset-commit-success-percentage", 1.0);
+
+        PowerMock.verifyAll();
+    }
+
+    @Test
+    public void testDeliveryWithMutatingTransform() throws Exception {
+        createTask(initialState);
+
+        expectInitializeTask();
+        expectTaskGetTopic(true);
+
+        expectPollInitialAssignment();
+
+        expectConsumerPoll(1);
+        expectConversionAndTransformation(1, "newtopic_");
+        sinkTask.put(EasyMock.anyObject());
+        EasyMock.expectLastCall();
+
+        final Map<TopicPartition, OffsetAndMetadata> offsets = new HashMap<>();
+        offsets.put(TOPIC_PARTITION, new OffsetAndMetadata(FIRST_OFFSET + 1));
+        offsets.put(TOPIC_PARTITION2, new OffsetAndMetadata(FIRST_OFFSET));
+        sinkTask.preCommit(offsets);
+        EasyMock.expectLastCall().andReturn(offsets);
+
+        EasyMock.expect(consumer.assignment()).andReturn(INITIAL_ASSIGNMENT).times(2);
+
+        final Capture<OffsetCommitCallback> callback = EasyMock.newCapture();
+        consumer.commitAsync(EasyMock.eq(offsets), EasyMock.capture(callback));
+        EasyMock.expectLastCall().andAnswer(() -> {
+            callback.getValue().onComplete(offsets, null);
+            return null;
+        });
+
+        expectConsumerPoll(0);
+        sinkTask.put(Collections.emptyList());
+        EasyMock.expectLastCall();
+
+        PowerMock.replayAll();
+
+        workerTask.initialize(TASK_CONFIG);
+        workerTask.initializeAndStart();
+
+        workerTask.iteration(); // initial assignment
+
+        workerTask.iteration(); // first record delivered
+
+        sinkTaskContext.getValue().requestCommit();
+        assertTrue(sinkTaskContext.getValue().isCommitRequested());
+        assertNotEquals(offsets, Whitebox.<Map<TopicPartition, OffsetAndMetadata>>getInternalState(workerTask, "lastCommittedOffsets"));
+        workerTask.iteration(); // triggers the commit
+        assertFalse(sinkTaskContext.getValue().isCommitRequested()); // should have been cleared
+        assertEquals(offsets, Whitebox.<Map<TopicPartition, OffsetAndMetadata>>getInternalState(workerTask, "lastCommittedOffsets"));
+        assertEquals(0, workerTask.commitFailures());
+        assertEquals(1.0, metrics.currentMetricValueAsDouble(workerTask.taskMetricsGroup().metricGroup(), "batch-size-max"), 0.0001);
+
+        PowerMock.verifyAll();
+    }
+
+    @Test
+    public void testMissingTimestampPropagation() throws Exception {
+        createTask(initialState);
+
+        expectInitializeTask();
+        expectTaskGetTopic(true);
+        expectPollInitialAssignment();
+        expectConsumerPoll(1, RecordBatch.NO_TIMESTAMP, TimestampType.CREATE_TIME);
+        expectConversionAndTransformation(1);
+
+        Capture<Collection<SinkRecord>> records = EasyMock.newCapture(CaptureType.ALL);
+
+        sinkTask.put(EasyMock.capture(records));
+
+        PowerMock.replayAll();
+
+        workerTask.initialize(TASK_CONFIG);
+        workerTask.initializeAndStart();
+        workerTask.iteration(); // iter 1 -- initial
assignment + workerTask.iteration(); // iter 2 -- deliver 1 record + + SinkRecord record = records.getValue().iterator().next(); + + // we expect null for missing timestamp, the sentinel value of Record.NO_TIMESTAMP is Kafka's API + assertNull(record.timestamp()); + assertEquals(TimestampType.CREATE_TIME, record.timestampType()); + + PowerMock.verifyAll(); + } + + @Test + public void testTimestampPropagation() throws Exception { + final Long timestamp = System.currentTimeMillis(); + final TimestampType timestampType = TimestampType.CREATE_TIME; + + createTask(initialState); + + expectInitializeTask(); + expectTaskGetTopic(true); + expectPollInitialAssignment(); + expectConsumerPoll(1, timestamp, timestampType); + expectConversionAndTransformation(1); + + Capture> records = EasyMock.newCapture(CaptureType.ALL); + sinkTask.put(EasyMock.capture(records)); + + PowerMock.replayAll(); + + workerTask.initialize(TASK_CONFIG); + workerTask.initializeAndStart(); + workerTask.iteration(); // iter 1 -- initial assignment + workerTask.iteration(); // iter 2 -- deliver 1 record + + SinkRecord record = records.getValue().iterator().next(); + + assertEquals(timestamp, record.timestamp()); + assertEquals(timestampType, record.timestampType()); + + PowerMock.verifyAll(); + } + + @Test + public void testTopicsRegex() throws Exception { + Map props = new HashMap<>(TASK_PROPS); + props.remove("topics"); + props.put("topics.regex", "te.*"); + TaskConfig taskConfig = new TaskConfig(props); + + createTask(TargetState.PAUSED); + + consumer.subscribe(EasyMock.capture(topicsRegex), EasyMock.capture(rebalanceListener)); + PowerMock.expectLastCall(); + + sinkTask.initialize(EasyMock.capture(sinkTaskContext)); + PowerMock.expectLastCall(); + sinkTask.start(props); + PowerMock.expectLastCall(); + + expectPollInitialAssignment(); + + EasyMock.expect(consumer.assignment()).andReturn(INITIAL_ASSIGNMENT); + consumer.pause(INITIAL_ASSIGNMENT); + PowerMock.expectLastCall(); + + PowerMock.replayAll(); + + workerTask.initialize(taskConfig); + workerTask.initializeAndStart(); + workerTask.iteration(); + time.sleep(10000L); + + PowerMock.verifyAll(); + } + + @Test + public void testMetricsGroup() { + SinkTaskMetricsGroup group = new SinkTaskMetricsGroup(taskId, metrics); + SinkTaskMetricsGroup group1 = new SinkTaskMetricsGroup(taskId1, metrics); + for (int i = 0; i != 10; ++i) { + group.recordRead(1); + group.recordSend(2); + group.recordPut(3); + group.recordPartitionCount(4); + group.recordOffsetSequenceNumber(5); + } + Map committedOffsets = new HashMap<>(); + committedOffsets.put(TOPIC_PARTITION, new OffsetAndMetadata(FIRST_OFFSET + 1)); + group.recordCommittedOffsets(committedOffsets); + Map consumedOffsets = new HashMap<>(); + consumedOffsets.put(TOPIC_PARTITION, new OffsetAndMetadata(FIRST_OFFSET + 10)); + group.recordConsumedOffsets(consumedOffsets); + + for (int i = 0; i != 20; ++i) { + group1.recordRead(1); + group1.recordSend(2); + group1.recordPut(30); + group1.recordPartitionCount(40); + group1.recordOffsetSequenceNumber(50); + } + committedOffsets = new HashMap<>(); + committedOffsets.put(TOPIC_PARTITION2, new OffsetAndMetadata(FIRST_OFFSET + 2)); + committedOffsets.put(TOPIC_PARTITION3, new OffsetAndMetadata(FIRST_OFFSET + 3)); + group1.recordCommittedOffsets(committedOffsets); + consumedOffsets = new HashMap<>(); + consumedOffsets.put(TOPIC_PARTITION2, new OffsetAndMetadata(FIRST_OFFSET + 20)); + consumedOffsets.put(TOPIC_PARTITION3, new OffsetAndMetadata(FIRST_OFFSET + 30)); + 
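+        // The assertions that follow check that the two metric groups (group and group1) track their sensors independently of each other.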
group1.recordConsumedOffsets(consumedOffsets); + + assertEquals(0.333, metrics.currentMetricValueAsDouble(group.metricGroup(), "sink-record-read-rate"), 0.001d); + assertEquals(0.667, metrics.currentMetricValueAsDouble(group.metricGroup(), "sink-record-send-rate"), 0.001d); + assertEquals(9, metrics.currentMetricValueAsDouble(group.metricGroup(), "sink-record-active-count"), 0.001d); + assertEquals(4, metrics.currentMetricValueAsDouble(group.metricGroup(), "partition-count"), 0.001d); + assertEquals(5, metrics.currentMetricValueAsDouble(group.metricGroup(), "offset-commit-seq-no"), 0.001d); + assertEquals(3, metrics.currentMetricValueAsDouble(group.metricGroup(), "put-batch-max-time-ms"), 0.001d); + + // Close the group + group.close(); + + for (MetricName metricName : group.metricGroup().metrics().metrics().keySet()) { + // Metrics for this group should no longer exist + assertFalse(group.metricGroup().groupId().includes(metricName)); + } + // Sensors for this group should no longer exist + assertNull(group.metricGroup().metrics().getSensor("source-record-poll")); + assertNull(group.metricGroup().metrics().getSensor("source-record-write")); + assertNull(group.metricGroup().metrics().getSensor("poll-batch-time")); + + assertEquals(0.667, metrics.currentMetricValueAsDouble(group1.metricGroup(), "sink-record-read-rate"), 0.001d); + assertEquals(1.333, metrics.currentMetricValueAsDouble(group1.metricGroup(), "sink-record-send-rate"), 0.001d); + assertEquals(45, metrics.currentMetricValueAsDouble(group1.metricGroup(), "sink-record-active-count"), 0.001d); + assertEquals(40, metrics.currentMetricValueAsDouble(group1.metricGroup(), "partition-count"), 0.001d); + assertEquals(50, metrics.currentMetricValueAsDouble(group1.metricGroup(), "offset-commit-seq-no"), 0.001d); + assertEquals(30, metrics.currentMetricValueAsDouble(group1.metricGroup(), "put-batch-max-time-ms"), 0.001d); + } + + @Test + public void testHeaders() throws Exception { + Headers headers = new RecordHeaders(); + headers.add("header_key", "header_value".getBytes()); + + createTask(initialState); + + expectInitializeTask(); + expectTaskGetTopic(true); + expectPollInitialAssignment(); + + expectConsumerPoll(1, headers); + expectConversionAndTransformation(1, null, headers); + sinkTask.put(EasyMock.anyObject()); + EasyMock.expectLastCall(); + + PowerMock.replayAll(); + + workerTask.initialize(TASK_CONFIG); + workerTask.initializeAndStart(); + workerTask.iteration(); // iter 1 -- initial assignment + workerTask.iteration(); // iter 2 -- deliver 1 record + + PowerMock.verifyAll(); + } + + @Test + public void testHeadersWithCustomConverter() throws Exception { + StringConverter stringConverter = new StringConverter(); + TestConverterWithHeaders testConverter = new TestConverterWithHeaders(); + + createTask(initialState, stringConverter, testConverter, stringConverter); + + expectInitializeTask(); + expectTaskGetTopic(true); + expectPollInitialAssignment(); + + String keyA = "a"; + String valueA = "Árvíztűrő tükörfúrógép"; + Headers headersA = new RecordHeaders(); + String encodingA = "latin2"; + headersA.add("encoding", encodingA.getBytes()); + + String keyB = "b"; + String valueB = "Тестовое сообщение"; + Headers headersB = new RecordHeaders(); + String encodingB = "koi8_r"; + headersB.add("encoding", encodingB.getBytes()); + + expectConsumerPoll(Arrays.asList( + new ConsumerRecord<>(TOPIC, PARTITION, FIRST_OFFSET + recordsReturnedTp1 + 1, RecordBatch.NO_TIMESTAMP, TimestampType.NO_TIMESTAMP_TYPE, + 0, 0, keyA.getBytes(), 
valueA.getBytes(encodingA), headersA, Optional.empty()), + new ConsumerRecord<>(TOPIC, PARTITION, FIRST_OFFSET + recordsReturnedTp1 + 2, RecordBatch.NO_TIMESTAMP, TimestampType.NO_TIMESTAMP_TYPE, + 0, 0, keyB.getBytes(), valueB.getBytes(encodingB), headersB, Optional.empty()) + )); + + expectTransformation(2, null); + + Capture> records = EasyMock.newCapture(CaptureType.ALL); + sinkTask.put(EasyMock.capture(records)); + + PowerMock.replayAll(); + + workerTask.initialize(TASK_CONFIG); + workerTask.initializeAndStart(); + workerTask.iteration(); // iter 1 -- initial assignment + workerTask.iteration(); // iter 2 -- deliver 1 record + + Iterator iterator = records.getValue().iterator(); + + SinkRecord recordA = iterator.next(); + assertEquals(keyA, recordA.key()); + assertEquals(valueA, recordA.value()); + + SinkRecord recordB = iterator.next(); + assertEquals(keyB, recordB.key()); + assertEquals(valueB, recordB.value()); + + PowerMock.verifyAll(); + } + + private void expectInitializeTask() throws Exception { + consumer.subscribe(EasyMock.eq(asList(TOPIC)), EasyMock.capture(rebalanceListener)); + PowerMock.expectLastCall(); + + sinkTask.initialize(EasyMock.capture(sinkTaskContext)); + PowerMock.expectLastCall(); + sinkTask.start(TASK_PROPS); + PowerMock.expectLastCall(); + } + + private void expectRebalanceLossError(RuntimeException e) { + sinkTask.close(new HashSet<>(INITIAL_ASSIGNMENT)); + EasyMock.expectLastCall().andThrow(e); + + EasyMock.expect(consumer.poll(Duration.ofMillis(EasyMock.anyLong()))).andAnswer( + () -> { + rebalanceListener.getValue().onPartitionsLost(INITIAL_ASSIGNMENT); + return ConsumerRecords.empty(); + }); + } + + private void expectRebalanceRevocationError(RuntimeException e) { + sinkTask.close(new HashSet<>(INITIAL_ASSIGNMENT)); + EasyMock.expectLastCall().andThrow(e); + + sinkTask.preCommit(EasyMock.anyObject()); + EasyMock.expectLastCall().andReturn(Collections.emptyMap()); + + EasyMock.expect(consumer.poll(Duration.ofMillis(EasyMock.anyLong()))).andAnswer( + () -> { + rebalanceListener.getValue().onPartitionsRevoked(INITIAL_ASSIGNMENT); + return ConsumerRecords.empty(); + }); + } + + private void expectRebalanceAssignmentError(RuntimeException e) { + sinkTask.close(INITIAL_ASSIGNMENT); + EasyMock.expectLastCall(); + + sinkTask.preCommit(EasyMock.anyObject()); + EasyMock.expectLastCall().andReturn(Collections.emptyMap()); + + EasyMock.expect(consumer.position(TOPIC_PARTITION)).andReturn(FIRST_OFFSET); + EasyMock.expect(consumer.position(TOPIC_PARTITION2)).andReturn(FIRST_OFFSET); + + sinkTask.open(INITIAL_ASSIGNMENT); + EasyMock.expectLastCall().andThrow(e); + + EasyMock.expect(consumer.assignment()).andReturn(INITIAL_ASSIGNMENT).times(3); + EasyMock.expect(consumer.poll(Duration.ofMillis(EasyMock.anyLong()))).andAnswer( + () -> { + rebalanceListener.getValue().onPartitionsRevoked(INITIAL_ASSIGNMENT); + rebalanceListener.getValue().onPartitionsAssigned(INITIAL_ASSIGNMENT); + return ConsumerRecords.empty(); + }); + } + + private void expectPollInitialAssignment() { + sinkTask.open(INITIAL_ASSIGNMENT); + EasyMock.expectLastCall(); + + EasyMock.expect(consumer.assignment()).andReturn(INITIAL_ASSIGNMENT).times(2); + + EasyMock.expect(consumer.poll(Duration.ofMillis(EasyMock.anyLong()))).andAnswer(() -> { + rebalanceListener.getValue().onPartitionsAssigned(INITIAL_ASSIGNMENT); + return ConsumerRecords.empty(); + }); + INITIAL_ASSIGNMENT.forEach(tp -> EasyMock.expect(consumer.position(tp)).andReturn(FIRST_OFFSET)); + + sinkTask.put(Collections.emptyList()); + 
EasyMock.expectLastCall(); + } + + private void expectConsumerWakeup() { + consumer.wakeup(); + EasyMock.expectLastCall(); + EasyMock.expect(consumer.poll(Duration.ofMillis(EasyMock.anyLong()))).andThrow(new WakeupException()); + } + + private void expectConsumerPoll(final int numMessages) { + expectConsumerPoll(numMessages, RecordBatch.NO_TIMESTAMP, TimestampType.NO_TIMESTAMP_TYPE, emptyHeaders()); + } + + private void expectConsumerPoll(final int numMessages, Headers headers) { + expectConsumerPoll(numMessages, RecordBatch.NO_TIMESTAMP, TimestampType.NO_TIMESTAMP_TYPE, headers); + } + + private void expectConsumerPoll(final int numMessages, final long timestamp, final TimestampType timestampType) { + expectConsumerPoll(numMessages, timestamp, timestampType, emptyHeaders()); + } + + private void expectConsumerPoll(final int numMessages, final long timestamp, final TimestampType timestampType, Headers headers) { + EasyMock.expect(consumer.poll(Duration.ofMillis(EasyMock.anyLong()))).andAnswer( + () -> { + List> records = new ArrayList<>(); + for (int i = 0; i < numMessages; i++) + records.add(new ConsumerRecord<>(TOPIC, PARTITION, FIRST_OFFSET + recordsReturnedTp1 + i, timestamp, timestampType, + 0, 0, RAW_KEY, RAW_VALUE, headers, Optional.empty())); + recordsReturnedTp1 += numMessages; + return new ConsumerRecords<>( + numMessages > 0 ? + Collections.singletonMap(new TopicPartition(TOPIC, PARTITION), records) : + Collections.emptyMap() + ); + }); + } + + private void expectConsumerPoll(List> records) { + EasyMock.expect(consumer.poll(Duration.ofMillis(EasyMock.anyLong()))).andAnswer( + () -> new ConsumerRecords<>( + records.isEmpty() ? + Collections.emptyMap() : + Collections.singletonMap(new TopicPartition(TOPIC, PARTITION), records) + )); + } + + private void expectConversionAndTransformation(final int numMessages) { + expectConversionAndTransformation(numMessages, null); + } + + private void expectConversionAndTransformation(final int numMessages, final String topicPrefix) { + expectConversionAndTransformation(numMessages, topicPrefix, emptyHeaders()); + } + + private void expectConversionAndTransformation(final int numMessages, final String topicPrefix, final Headers headers) { + EasyMock.expect(keyConverter.toConnectData(TOPIC, headers, RAW_KEY)).andReturn(new SchemaAndValue(KEY_SCHEMA, KEY)).times(numMessages); + EasyMock.expect(valueConverter.toConnectData(TOPIC, headers, RAW_VALUE)).andReturn(new SchemaAndValue(VALUE_SCHEMA, VALUE)).times(numMessages); + + for (Header header : headers) { + EasyMock.expect(headerConverter.toConnectHeader(TOPIC, header.key(), header.value())).andReturn(new SchemaAndValue(VALUE_SCHEMA, new String(header.value()))).times(1); + } + + expectTransformation(numMessages, topicPrefix); + } + + private void expectTransformation(final int numMessages, final String topicPrefix) { + final Capture recordCapture = EasyMock.newCapture(); + EasyMock.expect(transformationChain.apply(EasyMock.capture(recordCapture))) + .andAnswer(() -> { + SinkRecord origRecord = recordCapture.getValue(); + return topicPrefix != null && !topicPrefix.isEmpty() + ? 
origRecord.newRecord( + topicPrefix + origRecord.topic(), + origRecord.kafkaPartition(), + origRecord.keySchema(), + origRecord.key(), + origRecord.valueSchema(), + origRecord.value(), + origRecord.timestamp(), + origRecord.headers() + ) + : origRecord; + }).times(numMessages); + } + + private void expectTaskGetTopic(boolean anyTimes) { + final Capture connectorCapture = EasyMock.newCapture(); + final Capture topicCapture = EasyMock.newCapture(); + IExpectationSetters expect = EasyMock.expect(statusBackingStore.getTopic( + EasyMock.capture(connectorCapture), + EasyMock.capture(topicCapture))); + if (anyTimes) { + expect.andStubAnswer(() -> new TopicStatus( + topicCapture.getValue(), + new ConnectorTaskId(connectorCapture.getValue(), 0), + Time.SYSTEM.milliseconds())); + } else { + expect.andAnswer(() -> new TopicStatus( + topicCapture.getValue(), + new ConnectorTaskId(connectorCapture.getValue(), 0), + Time.SYSTEM.milliseconds())); + } + if (connectorCapture.hasCaptured() && topicCapture.hasCaptured()) { + assertEquals("job", connectorCapture.getValue()); + assertEquals(TOPIC, topicCapture.getValue()); + } + } + + private void assertSinkMetricValue(String name, double expected) { + MetricGroup sinkTaskGroup = workerTask.sinkTaskMetricsGroup().metricGroup(); + double measured = metrics.currentMetricValueAsDouble(sinkTaskGroup, name); + assertEquals(expected, measured, 0.001d); + } + + private void assertTaskMetricValue(String name, double expected) { + MetricGroup taskGroup = workerTask.taskMetricsGroup().metricGroup(); + double measured = metrics.currentMetricValueAsDouble(taskGroup, name); + assertEquals(expected, measured, 0.001d); + } + + private void assertTaskMetricValue(String name, String expected) { + MetricGroup taskGroup = workerTask.taskMetricsGroup().metricGroup(); + String measured = metrics.currentMetricValueAsString(taskGroup, name); + assertEquals(expected, measured); + } + + private void printMetrics() { + System.out.println(); + sinkMetricValue("sink-record-read-rate"); + sinkMetricValue("sink-record-read-total"); + sinkMetricValue("sink-record-send-rate"); + sinkMetricValue("sink-record-send-total"); + sinkMetricValue("sink-record-active-count"); + sinkMetricValue("sink-record-active-count-max"); + sinkMetricValue("sink-record-active-count-avg"); + sinkMetricValue("partition-count"); + sinkMetricValue("offset-commit-seq-no"); + sinkMetricValue("offset-commit-completion-rate"); + sinkMetricValue("offset-commit-completion-total"); + sinkMetricValue("offset-commit-skip-rate"); + sinkMetricValue("offset-commit-skip-total"); + sinkMetricValue("put-batch-max-time-ms"); + sinkMetricValue("put-batch-avg-time-ms"); + + taskMetricValue("status-unassigned"); + taskMetricValue("status-running"); + taskMetricValue("status-paused"); + taskMetricValue("status-failed"); + taskMetricValue("status-destroyed"); + taskMetricValue("running-ratio"); + taskMetricValue("pause-ratio"); + taskMetricValue("offset-commit-max-time-ms"); + taskMetricValue("offset-commit-avg-time-ms"); + taskMetricValue("batch-size-max"); + taskMetricValue("batch-size-avg"); + taskMetricValue("offset-commit-failure-percentage"); + taskMetricValue("offset-commit-success-percentage"); + } + + private double sinkMetricValue(String metricName) { + MetricGroup sinkTaskGroup = workerTask.sinkTaskMetricsGroup().metricGroup(); + double value = metrics.currentMetricValueAsDouble(sinkTaskGroup, metricName); + System.out.println("** " + metricName + "=" + value); + return value; + } + + private double taskMetricValue(String 
metricName) { + MetricGroup taskGroup = workerTask.taskMetricsGroup().metricGroup(); + double value = metrics.currentMetricValueAsDouble(taskGroup, metricName); + System.out.println("** " + metricName + "=" + value); + return value; + } + + + private void assertMetrics(int minimumPollCountExpected) { + MetricGroup sinkTaskGroup = workerTask.sinkTaskMetricsGroup().metricGroup(); + MetricGroup taskGroup = workerTask.taskMetricsGroup().metricGroup(); + double readRate = metrics.currentMetricValueAsDouble(sinkTaskGroup, "sink-record-read-rate"); + double readTotal = metrics.currentMetricValueAsDouble(sinkTaskGroup, "sink-record-read-total"); + double sendRate = metrics.currentMetricValueAsDouble(sinkTaskGroup, "sink-record-send-rate"); + double sendTotal = metrics.currentMetricValueAsDouble(sinkTaskGroup, "sink-record-send-total"); + } + + private RecordHeaders emptyHeaders() { + return new RecordHeaders(); + } + + private abstract static class TestSinkTask extends SinkTask { + } +} diff --git a/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/WorkerSinkTaskThreadedTest.java b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/WorkerSinkTaskThreadedTest.java new file mode 100644 index 0000000..a7c6a8a --- /dev/null +++ b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/WorkerSinkTaskThreadedTest.java @@ -0,0 +1,714 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.connect.runtime; + +import org.apache.kafka.clients.consumer.ConsumerRebalanceListener; +import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.clients.consumer.ConsumerRecords; +import org.apache.kafka.clients.consumer.KafkaConsumer; +import org.apache.kafka.clients.consumer.OffsetAndMetadata; +import org.apache.kafka.clients.consumer.OffsetCommitCallback; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.header.internals.RecordHeaders; +import org.apache.kafka.common.record.TimestampType; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.connect.data.Schema; +import org.apache.kafka.connect.data.SchemaAndValue; +import org.apache.kafka.connect.errors.ConnectException; +import org.apache.kafka.connect.runtime.distributed.ClusterConfigState; +import org.apache.kafka.connect.runtime.errors.RetryWithToleranceOperatorTest; +import org.apache.kafka.connect.runtime.isolation.PluginClassLoader; +import org.apache.kafka.connect.runtime.standalone.StandaloneConfig; +import org.apache.kafka.connect.sink.SinkConnector; +import org.apache.kafka.connect.sink.SinkRecord; +import org.apache.kafka.connect.sink.SinkTask; +import org.apache.kafka.connect.storage.Converter; +import org.apache.kafka.connect.storage.HeaderConverter; +import org.apache.kafka.connect.storage.StatusBackingStore; +import org.apache.kafka.connect.util.ConnectorTaskId; +import org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.connect.util.ThreadedTest; +import org.easymock.Capture; +import org.easymock.CaptureType; +import org.easymock.EasyMock; +import org.easymock.IExpectationSetters; +import org.junit.After; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.powermock.api.easymock.PowerMock; +import org.powermock.api.easymock.annotation.Mock; +import org.powermock.core.classloader.annotations.PowerMockIgnore; +import org.powermock.core.classloader.annotations.PrepareForTest; +import org.powermock.modules.junit4.PowerMockRunner; +import org.powermock.reflect.Whitebox; + +import java.time.Duration; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.fail; + +@RunWith(PowerMockRunner.class) +@PrepareForTest(WorkerSinkTask.class) +@PowerMockIgnore("javax.management.*") +public class WorkerSinkTaskThreadedTest extends ThreadedTest { + + // These are fixed to keep this code simpler. 
In this example we assume byte[] raw values + // with mix of integer/string in Connect + private static final String TOPIC = "test"; + private static final int PARTITION = 12; + private static final int PARTITION2 = 13; + private static final int PARTITION3 = 14; + private static final long FIRST_OFFSET = 45; + private static final Schema KEY_SCHEMA = Schema.INT32_SCHEMA; + private static final int KEY = 12; + private static final Schema VALUE_SCHEMA = Schema.STRING_SCHEMA; + private static final String VALUE = "VALUE"; + private static final byte[] RAW_KEY = "key".getBytes(); + private static final byte[] RAW_VALUE = "value".getBytes(); + + private static final TopicPartition TOPIC_PARTITION = new TopicPartition(TOPIC, PARTITION); + private static final TopicPartition TOPIC_PARTITION2 = new TopicPartition(TOPIC, PARTITION2); + private static final TopicPartition TOPIC_PARTITION3 = new TopicPartition(TOPIC, PARTITION3); + private static final TopicPartition UNASSIGNED_TOPIC_PARTITION = new TopicPartition(TOPIC, 200); + private static final Set INITIAL_ASSIGNMENT = new HashSet<>(Arrays.asList( + TOPIC_PARTITION, TOPIC_PARTITION2, TOPIC_PARTITION3)); + + private static final Map TASK_PROPS = new HashMap<>(); + private static final long TIMESTAMP = 42L; + private static final TimestampType TIMESTAMP_TYPE = TimestampType.CREATE_TIME; + + static { + TASK_PROPS.put(SinkConnector.TOPICS_CONFIG, TOPIC); + TASK_PROPS.put(TaskConfig.TASK_CLASS_CONFIG, TestSinkTask.class.getName()); + } + private static final TaskConfig TASK_CONFIG = new TaskConfig(TASK_PROPS); + + private ConnectorTaskId taskId = new ConnectorTaskId("job", 0); + private TargetState initialState = TargetState.STARTED; + private Time time; + private ConnectMetrics metrics; + @Mock private SinkTask sinkTask; + private Capture sinkTaskContext = EasyMock.newCapture(); + private WorkerConfig workerConfig; + @Mock + private PluginClassLoader pluginLoader; + @Mock private Converter keyConverter; + @Mock private Converter valueConverter; + @Mock private HeaderConverter headerConverter; + @Mock private TransformationChain transformationChain; + private WorkerSinkTask workerTask; + @Mock private KafkaConsumer consumer; + private Capture rebalanceListener = EasyMock.newCapture(); + @Mock private TaskStatus.Listener statusListener; + @Mock private StatusBackingStore statusBackingStore; + + private long recordsReturned; + + + @Override + public void setup() { + super.setup(); + time = new MockTime(); + metrics = new MockConnectMetrics(); + Map workerProps = new HashMap<>(); + workerProps.put("key.converter", "org.apache.kafka.connect.json.JsonConverter"); + workerProps.put("value.converter", "org.apache.kafka.connect.json.JsonConverter"); + workerProps.put("offset.storage.file.filename", "/tmp/connect.offsets"); + pluginLoader = PowerMock.createMock(PluginClassLoader.class); + workerConfig = new StandaloneConfig(workerProps); + workerTask = new WorkerSinkTask( + taskId, sinkTask, statusListener, initialState, workerConfig, ClusterConfigState.EMPTY, metrics, keyConverter, + valueConverter, headerConverter, + new TransformationChain<>(Collections.emptyList(), RetryWithToleranceOperatorTest.NOOP_OPERATOR), + consumer, pluginLoader, time, RetryWithToleranceOperatorTest.NOOP_OPERATOR, null, statusBackingStore); + + recordsReturned = 0; + } + + @After + public void tearDown() { + if (metrics != null) metrics.stop(); + } + + @Test + public void testPollsInBackground() throws Exception { + expectInitializeTask(); + expectTaskGetTopic(true); + 
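+        // The initial poll delivers only the partition assignment and no records; the stubbed polls that follow each return a single record.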
expectPollInitialAssignment(); + + Capture> capturedRecords = expectPolls(1L); + expectStopTask(); + + PowerMock.replayAll(); + + workerTask.initialize(TASK_CONFIG); + workerTask.initializeAndStart(); + + // First iteration initializes partition assignment + workerTask.iteration(); + + // Then we iterate to fetch data + for (int i = 0; i < 10; i++) { + workerTask.iteration(); + } + workerTask.stop(); + workerTask.close(); + + // Verify contents match expected values, i.e. that they were translated properly. With max + // batch size 1 and poll returns 1 message at a time, we should have a matching # of batches + assertEquals(10, capturedRecords.getValues().size()); + int offset = 0; + for (Collection recs : capturedRecords.getValues()) { + assertEquals(1, recs.size()); + for (SinkRecord rec : recs) { + SinkRecord referenceSinkRecord + = new SinkRecord(TOPIC, PARTITION, KEY_SCHEMA, KEY, VALUE_SCHEMA, VALUE, FIRST_OFFSET + offset, TIMESTAMP, TIMESTAMP_TYPE); + InternalSinkRecord referenceInternalSinkRecord = + new InternalSinkRecord(null, referenceSinkRecord); + assertEquals(referenceInternalSinkRecord, rec); + offset++; + } + } + + PowerMock.verifyAll(); + } + + @Test + public void testCommit() throws Exception { + expectInitializeTask(); + expectTaskGetTopic(true); + expectPollInitialAssignment(); + expectConsumerAssignment(INITIAL_ASSIGNMENT).times(2); + + // Make each poll() take the offset commit interval + Capture> capturedRecords + = expectPolls(WorkerConfig.OFFSET_COMMIT_INTERVAL_MS_DEFAULT); + expectOffsetCommit(1L, null, null, 0, true); + expectStopTask(); + + PowerMock.replayAll(); + + workerTask.initialize(TASK_CONFIG); + workerTask.initializeAndStart(); + + // Initialize partition assignment + workerTask.iteration(); + // Fetch one record + workerTask.iteration(); + // Trigger the commit + workerTask.iteration(); + + // Commit finishes synchronously for testing so we can check this immediately + assertEquals(0, workerTask.commitFailures()); + workerTask.stop(); + workerTask.close(); + + assertEquals(2, capturedRecords.getValues().size()); + + PowerMock.verifyAll(); + } + + @Test + public void testCommitFailure() throws Exception { + expectInitializeTask(); + expectTaskGetTopic(true); + expectPollInitialAssignment(); + expectConsumerAssignment(INITIAL_ASSIGNMENT); + + Capture> capturedRecords = expectPolls(WorkerConfig.OFFSET_COMMIT_INTERVAL_MS_DEFAULT); + expectOffsetCommit(1L, new RuntimeException(), null, 0, true); + // Should rewind to last known good positions, which in this case will be the offsets loaded during initialization + // for all topic partitions + consumer.seek(TOPIC_PARTITION, FIRST_OFFSET); + PowerMock.expectLastCall(); + consumer.seek(TOPIC_PARTITION2, FIRST_OFFSET); + PowerMock.expectLastCall(); + consumer.seek(TOPIC_PARTITION3, FIRST_OFFSET); + PowerMock.expectLastCall(); + expectStopTask(); + + PowerMock.replayAll(); + + workerTask.initialize(TASK_CONFIG); + workerTask.initializeAndStart(); + + // Initialize partition assignment + workerTask.iteration(); + // Fetch some data + workerTask.iteration(); + // Trigger the commit + workerTask.iteration(); + + assertEquals(1, workerTask.commitFailures()); + assertEquals(false, Whitebox.getInternalState(workerTask, "committing")); + workerTask.stop(); + workerTask.close(); + + PowerMock.verifyAll(); + } + + @Test + public void testCommitSuccessFollowedByFailure() throws Exception { + // Validate that we rewind to the correct offsets if a task's preCommit() method throws an exception + + expectInitializeTask(); + 
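+        // The first commit below succeeds and advances TOPIC_PARTITION to FIRST_OFFSET + 1; the second fails in preCommit, so the consumer is rewound to the last committed positions.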
expectTaskGetTopic(true); + expectPollInitialAssignment(); + expectConsumerAssignment(INITIAL_ASSIGNMENT).times(3); + Capture> capturedRecords = expectPolls(WorkerConfig.OFFSET_COMMIT_INTERVAL_MS_DEFAULT); + expectOffsetCommit(1L, null, null, 0, true); + expectOffsetCommit(2L, new RuntimeException(), null, 0, true); + // Should rewind to last known committed positions + consumer.seek(TOPIC_PARTITION, FIRST_OFFSET + 1); + PowerMock.expectLastCall(); + consumer.seek(TOPIC_PARTITION2, FIRST_OFFSET); + PowerMock.expectLastCall(); + consumer.seek(TOPIC_PARTITION3, FIRST_OFFSET); + PowerMock.expectLastCall(); + expectStopTask(); + + PowerMock.replayAll(); + + workerTask.initialize(TASK_CONFIG); + workerTask.initializeAndStart(); + + // Initialize partition assignment + workerTask.iteration(); + // Fetch some data + workerTask.iteration(); + // Trigger first commit, + workerTask.iteration(); + // Trigger second (failing) commit + workerTask.iteration(); + + assertEquals(1, workerTask.commitFailures()); + assertEquals(false, Whitebox.getInternalState(workerTask, "committing")); + workerTask.stop(); + workerTask.close(); + + PowerMock.verifyAll(); + } + + @Test + public void testCommitConsumerFailure() throws Exception { + expectInitializeTask(); + expectTaskGetTopic(true); + expectPollInitialAssignment(); + expectConsumerAssignment(INITIAL_ASSIGNMENT).times(2); + + Capture> capturedRecords + = expectPolls(WorkerConfig.OFFSET_COMMIT_INTERVAL_MS_DEFAULT); + expectOffsetCommit(1L, null, new Exception(), 0, true); + expectStopTask(); + + PowerMock.replayAll(); + + workerTask.initialize(TASK_CONFIG); + workerTask.initializeAndStart(); + + // Initialize partition assignment + workerTask.iteration(); + // Fetch some data + workerTask.iteration(); + // Trigger commit + workerTask.iteration(); + + // TODO Response to consistent failures? + assertEquals(1, workerTask.commitFailures()); + assertEquals(false, Whitebox.getInternalState(workerTask, "committing")); + workerTask.stop(); + workerTask.close(); + + PowerMock.verifyAll(); + } + + @Test + public void testCommitTimeout() throws Exception { + expectInitializeTask(); + expectTaskGetTopic(true); + expectPollInitialAssignment(); + expectConsumerAssignment(INITIAL_ASSIGNMENT).times(2); + + // Cut down amount of time to pass in each poll so we trigger exactly 1 offset commit + Capture> capturedRecords + = expectPolls(WorkerConfig.OFFSET_COMMIT_INTERVAL_MS_DEFAULT / 2); + expectOffsetCommit(2L, null, null, WorkerConfig.OFFSET_COMMIT_TIMEOUT_MS_DEFAULT, false); + expectStopTask(); + + PowerMock.replayAll(); + + workerTask.initialize(TASK_CONFIG); + workerTask.initializeAndStart(); + + // Initialize partition assignment + workerTask.iteration(); + // Fetch some data + workerTask.iteration(); + workerTask.iteration(); + // Trigger the commit + workerTask.iteration(); + // Trigger the timeout without another commit + workerTask.iteration(); + + // TODO Response to consistent failures? 
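+        // The async commit callback is never invoked and the mock clock advances by the full OFFSET_COMMIT_TIMEOUT_MS_DEFAULT, so the worker counts the commit as failed and clears its committing flag.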
+ assertEquals(1, workerTask.commitFailures()); + assertEquals(false, Whitebox.getInternalState(workerTask, "committing")); + workerTask.stop(); + workerTask.close(); + + PowerMock.verifyAll(); + } + + @Test + public void testAssignmentPauseResume() throws Exception { + // Just validate that the calls are passed through to the consumer, and that where appropriate errors are + // converted + expectInitializeTask(); + expectTaskGetTopic(true); + + expectPollInitialAssignment(); + expectOnePoll().andAnswer(() -> { + assertEquals(new HashSet<>(Arrays.asList(TOPIC_PARTITION, TOPIC_PARTITION2, TOPIC_PARTITION3)), + sinkTaskContext.getValue().assignment()); + return null; + }); + EasyMock.expect(consumer.assignment()).andReturn(new HashSet<>(Arrays.asList(TOPIC_PARTITION, TOPIC_PARTITION2, TOPIC_PARTITION3))); + + expectOnePoll().andAnswer(() -> { + try { + sinkTaskContext.getValue().pause(UNASSIGNED_TOPIC_PARTITION); + fail("Trying to pause unassigned partition should have thrown an Connect exception"); + } catch (ConnectException e) { + // expected + } + sinkTaskContext.getValue().pause(TOPIC_PARTITION, TOPIC_PARTITION2); + return null; + }); + consumer.pause(Arrays.asList(UNASSIGNED_TOPIC_PARTITION)); + PowerMock.expectLastCall().andThrow(new IllegalStateException("unassigned topic partition")); + consumer.pause(Arrays.asList(TOPIC_PARTITION, TOPIC_PARTITION2)); + PowerMock.expectLastCall(); + + expectOnePoll().andAnswer(() -> { + try { + sinkTaskContext.getValue().resume(UNASSIGNED_TOPIC_PARTITION); + fail("Trying to resume unassigned partition should have thrown an Connect exception"); + } catch (ConnectException e) { + // expected + } + + sinkTaskContext.getValue().resume(TOPIC_PARTITION, TOPIC_PARTITION2); + return null; + }); + consumer.resume(Arrays.asList(UNASSIGNED_TOPIC_PARTITION)); + PowerMock.expectLastCall().andThrow(new IllegalStateException("unassigned topic partition")); + consumer.resume(Arrays.asList(TOPIC_PARTITION, TOPIC_PARTITION2)); + PowerMock.expectLastCall(); + + expectStopTask(); + + PowerMock.replayAll(); + + workerTask.initialize(TASK_CONFIG); + workerTask.initializeAndStart(); + workerTask.iteration(); + workerTask.iteration(); + workerTask.iteration(); + workerTask.iteration(); + workerTask.stop(); + workerTask.close(); + + PowerMock.verifyAll(); + } + + @Test + public void testRewind() throws Exception { + expectInitializeTask(); + expectTaskGetTopic(true); + expectPollInitialAssignment(); + + final long startOffset = 40L; + final Map offsets = new HashMap<>(); + + expectOnePoll().andAnswer(() -> { + offsets.put(TOPIC_PARTITION, startOffset); + sinkTaskContext.getValue().offset(offsets); + return null; + }); + + consumer.seek(TOPIC_PARTITION, startOffset); + EasyMock.expectLastCall(); + + expectOnePoll().andAnswer(() -> { + Map offsets1 = sinkTaskContext.getValue().offsets(); + assertEquals(0, offsets1.size()); + return null; + }); + + expectStopTask(); + PowerMock.replayAll(); + + workerTask.initialize(TASK_CONFIG); + workerTask.initializeAndStart(); + workerTask.iteration(); + workerTask.iteration(); + workerTask.iteration(); + workerTask.stop(); + workerTask.close(); + + PowerMock.verifyAll(); + } + + @Test + public void testRewindOnRebalanceDuringPoll() throws Exception { + expectInitializeTask(); + expectTaskGetTopic(true); + expectPollInitialAssignment(); + expectConsumerAssignment(INITIAL_ASSIGNMENT).times(2); + + expectRebalanceDuringPoll().andAnswer(() -> { + Map offsets = sinkTaskContext.getValue().offsets(); + assertEquals(0, offsets.size()); + return 
null; + }); + + expectStopTask(); + PowerMock.replayAll(); + + workerTask.initialize(TASK_CONFIG); + workerTask.initializeAndStart(); + workerTask.iteration(); + workerTask.iteration(); + workerTask.stop(); + workerTask.close(); + + PowerMock.verifyAll(); + } + + private void expectInitializeTask() throws Exception { + + consumer.subscribe(EasyMock.eq(Arrays.asList(TOPIC)), EasyMock.capture(rebalanceListener)); + PowerMock.expectLastCall(); + + sinkTask.initialize(EasyMock.capture(sinkTaskContext)); + PowerMock.expectLastCall(); + sinkTask.start(TASK_PROPS); + PowerMock.expectLastCall(); + } + + private void expectPollInitialAssignment() throws Exception { + expectConsumerAssignment(INITIAL_ASSIGNMENT).times(2); + + sinkTask.open(INITIAL_ASSIGNMENT); + EasyMock.expectLastCall(); + + EasyMock.expect(consumer.poll(Duration.ofMillis(EasyMock.anyLong()))).andAnswer(() -> { + rebalanceListener.getValue().onPartitionsAssigned(INITIAL_ASSIGNMENT); + return ConsumerRecords.empty(); + }); + EasyMock.expect(consumer.position(TOPIC_PARTITION)).andReturn(FIRST_OFFSET); + EasyMock.expect(consumer.position(TOPIC_PARTITION2)).andReturn(FIRST_OFFSET); + EasyMock.expect(consumer.position(TOPIC_PARTITION3)).andReturn(FIRST_OFFSET); + + sinkTask.put(Collections.emptyList()); + EasyMock.expectLastCall(); + } + + private IExpectationSetters> expectConsumerAssignment(Set assignment) { + return EasyMock.expect(consumer.assignment()).andReturn(assignment); + } + + private void expectStopTask() throws Exception { + sinkTask.stop(); + PowerMock.expectLastCall(); + + // No offset commit since it happens in the mocked worker thread, but the main thread does need to wake up the + // consumer so it exits quickly + consumer.wakeup(); + PowerMock.expectLastCall(); + + consumer.close(); + PowerMock.expectLastCall(); + } + + // Note that this can only be called once per test currently + private Capture> expectPolls(final long pollDelayMs) throws Exception { + // Stub out all the consumer stream/iterator responses, which we just want to verify occur, + // but don't care about the exact details here. + EasyMock.expect(consumer.poll(Duration.ofMillis(EasyMock.anyLong()))).andStubAnswer( + () -> { + // "Sleep" so time will progress + time.sleep(pollDelayMs); + ConsumerRecords records = new ConsumerRecords<>( + Collections.singletonMap( + new TopicPartition(TOPIC, PARTITION), + Arrays.asList(new ConsumerRecord<>(TOPIC, PARTITION, FIRST_OFFSET + recordsReturned, TIMESTAMP, TIMESTAMP_TYPE, + 0, 0, RAW_KEY, RAW_VALUE, new RecordHeaders(), Optional.empty())))); + recordsReturned++; + return records; + }); + EasyMock.expect(keyConverter.toConnectData(TOPIC, emptyHeaders(), RAW_KEY)).andReturn(new SchemaAndValue(KEY_SCHEMA, KEY)).anyTimes(); + EasyMock.expect(valueConverter.toConnectData(TOPIC, emptyHeaders(), RAW_VALUE)).andReturn(new SchemaAndValue(VALUE_SCHEMA, VALUE)).anyTimes(); + + final Capture recordCapture = EasyMock.newCapture(); + EasyMock.expect(transformationChain.apply(EasyMock.capture(recordCapture))).andAnswer( + recordCapture::getValue).anyTimes(); + + Capture> capturedRecords = EasyMock.newCapture(CaptureType.ALL); + sinkTask.put(EasyMock.capture(capturedRecords)); + EasyMock.expectLastCall().anyTimes(); + return capturedRecords; + } + + @SuppressWarnings("unchecked") + private IExpectationSetters expectOnePoll() { + // Currently the SinkTask's put() method will not be invoked unless we provide some data, so instead of + // returning empty data, we return one record. 
The expectation is that the data will be ignored by the + // response behavior specified using the return value of this method. + EasyMock.expect(consumer.poll(Duration.ofMillis(EasyMock.anyLong()))).andAnswer( + () -> { + // "Sleep" so time will progress + time.sleep(1L); + ConsumerRecords records = new ConsumerRecords<>( + Collections.singletonMap( + new TopicPartition(TOPIC, PARTITION), + Arrays.asList(new ConsumerRecord<>(TOPIC, PARTITION, FIRST_OFFSET + recordsReturned, TIMESTAMP, TIMESTAMP_TYPE, + 0, 0, RAW_KEY, RAW_VALUE, new RecordHeaders(), Optional.empty())))); + recordsReturned++; + return records; + }); + EasyMock.expect(keyConverter.toConnectData(TOPIC, emptyHeaders(), RAW_KEY)).andReturn(new SchemaAndValue(KEY_SCHEMA, KEY)); + EasyMock.expect(valueConverter.toConnectData(TOPIC, emptyHeaders(), RAW_VALUE)).andReturn(new SchemaAndValue(VALUE_SCHEMA, VALUE)); + sinkTask.put(EasyMock.anyObject(Collection.class)); + return EasyMock.expectLastCall(); + } + + @SuppressWarnings("unchecked") + private IExpectationSetters expectRebalanceDuringPoll() throws Exception { + final List partitions = Arrays.asList(TOPIC_PARTITION, TOPIC_PARTITION2, TOPIC_PARTITION3); + + final long startOffset = 40L; + final Map offsets = new HashMap<>(); + offsets.put(TOPIC_PARTITION, startOffset); + + EasyMock.expect(consumer.poll(Duration.ofMillis(EasyMock.anyLong()))).andAnswer( + () -> { + // "Sleep" so time will progress + time.sleep(1L); + + sinkTaskContext.getValue().offset(offsets); + rebalanceListener.getValue().onPartitionsAssigned(partitions); + + ConsumerRecords records = new ConsumerRecords<>( + Collections.singletonMap( + new TopicPartition(TOPIC, PARTITION), + Arrays.asList(new ConsumerRecord<>(TOPIC, PARTITION, FIRST_OFFSET + recordsReturned, TIMESTAMP, TIMESTAMP_TYPE, + 0, 0, RAW_KEY, RAW_VALUE, new RecordHeaders(), Optional.empty()) + ))); + recordsReturned++; + return records; + }); + + EasyMock.expect(consumer.position(TOPIC_PARTITION)).andReturn(FIRST_OFFSET); + EasyMock.expect(consumer.position(TOPIC_PARTITION2)).andReturn(FIRST_OFFSET); + EasyMock.expect(consumer.position(TOPIC_PARTITION3)).andReturn(FIRST_OFFSET); + + sinkTask.open(partitions); + EasyMock.expectLastCall(); + + consumer.seek(TOPIC_PARTITION, startOffset); + EasyMock.expectLastCall(); + + EasyMock.expect(keyConverter.toConnectData(TOPIC, emptyHeaders(), RAW_KEY)).andReturn(new SchemaAndValue(KEY_SCHEMA, KEY)); + EasyMock.expect(valueConverter.toConnectData(TOPIC, emptyHeaders(), RAW_VALUE)).andReturn(new SchemaAndValue(VALUE_SCHEMA, VALUE)); + sinkTask.put(EasyMock.anyObject(Collection.class)); + return EasyMock.expectLastCall(); + } + + private Capture expectOffsetCommit(final long expectedMessages, + final RuntimeException error, + final Exception consumerCommitError, + final long consumerCommitDelayMs, + final boolean invokeCallback) + throws Exception { + final long finalOffset = FIRST_OFFSET + expectedMessages; + + // All assigned partitions will have offsets committed, but we've only processed messages/updated offsets for one + final Map offsetsToCommit = new HashMap<>(); + offsetsToCommit.put(TOPIC_PARTITION, new OffsetAndMetadata(finalOffset)); + offsetsToCommit.put(TOPIC_PARTITION2, new OffsetAndMetadata(FIRST_OFFSET)); + offsetsToCommit.put(TOPIC_PARTITION3, new OffsetAndMetadata(FIRST_OFFSET)); + sinkTask.preCommit(offsetsToCommit); + IExpectationSetters expectation = PowerMock.expectLastCall(); + if (error != null) { + expectation.andThrow(error).once(); + return null; + } else { + 
expectation.andReturn(offsetsToCommit); + } + + final Capture capturedCallback = EasyMock.newCapture(); + consumer.commitAsync(EasyMock.eq(offsetsToCommit), + EasyMock.capture(capturedCallback)); + PowerMock.expectLastCall().andAnswer(() -> { + time.sleep(consumerCommitDelayMs); + if (invokeCallback) + capturedCallback.getValue().onComplete(offsetsToCommit, consumerCommitError); + return null; + }); + return capturedCallback; + } + + private void expectTaskGetTopic(boolean anyTimes) { + final Capture connectorCapture = EasyMock.newCapture(); + final Capture topicCapture = EasyMock.newCapture(); + IExpectationSetters expect = EasyMock.expect(statusBackingStore.getTopic( + EasyMock.capture(connectorCapture), + EasyMock.capture(topicCapture))); + if (anyTimes) { + expect.andStubAnswer(() -> new TopicStatus( + topicCapture.getValue(), + new ConnectorTaskId(connectorCapture.getValue(), 0), + Time.SYSTEM.milliseconds())); + } else { + expect.andAnswer(() -> new TopicStatus( + topicCapture.getValue(), + new ConnectorTaskId(connectorCapture.getValue(), 0), + Time.SYSTEM.milliseconds())); + } + if (connectorCapture.hasCaptured() && topicCapture.hasCaptured()) { + assertEquals("job", connectorCapture.getValue()); + assertEquals(TOPIC, topicCapture.getValue()); + } + } + + private RecordHeaders emptyHeaders() { + return new RecordHeaders(); + } + + private static abstract class TestSinkTask extends SinkTask { + } +} diff --git a/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/WorkerSourceTaskTest.java b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/WorkerSourceTaskTest.java new file mode 100644 index 0000000..fcd657f --- /dev/null +++ b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/WorkerSourceTaskTest.java @@ -0,0 +1,1596 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.connect.runtime; + +import java.util.Collection; +import org.apache.kafka.clients.admin.NewTopic; +import org.apache.kafka.clients.admin.TopicDescription; +import org.apache.kafka.clients.producer.KafkaProducer; +import org.apache.kafka.clients.producer.ProducerRecord; +import org.apache.kafka.clients.producer.RecordMetadata; +import org.apache.kafka.common.InvalidRecordException; +import org.apache.kafka.common.KafkaException; +import org.apache.kafka.common.MetricName; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.TopicPartitionInfo; +import org.apache.kafka.common.errors.InvalidTopicException; +import org.apache.kafka.common.errors.TopicAuthorizationException; +import org.apache.kafka.common.header.Header; +import org.apache.kafka.common.header.Headers; +import org.apache.kafka.common.header.internals.RecordHeaders; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.connect.data.Schema; +import org.apache.kafka.connect.data.SchemaAndValue; +import org.apache.kafka.connect.errors.ConnectException; +import org.apache.kafka.connect.errors.RetriableException; +import org.apache.kafka.connect.header.ConnectHeaders; +import org.apache.kafka.connect.integration.MonitorableSourceConnector; +import org.apache.kafka.connect.runtime.ConnectMetrics.MetricGroup; +import org.apache.kafka.connect.runtime.WorkerSourceTask.SourceTaskMetricsGroup; +import org.apache.kafka.connect.runtime.distributed.ClusterConfigState; +import org.apache.kafka.connect.runtime.errors.RetryWithToleranceOperatorTest; +import org.apache.kafka.connect.runtime.isolation.Plugins; +import org.apache.kafka.connect.runtime.standalone.StandaloneConfig; +import org.apache.kafka.connect.source.SourceRecord; +import org.apache.kafka.connect.source.SourceTask; +import org.apache.kafka.connect.source.SourceTaskContext; +import org.apache.kafka.connect.storage.CloseableOffsetStorageReader; +import org.apache.kafka.connect.storage.Converter; +import org.apache.kafka.connect.storage.HeaderConverter; +import org.apache.kafka.connect.storage.OffsetStorageWriter; +import org.apache.kafka.connect.storage.StatusBackingStore; +import org.apache.kafka.connect.storage.StringConverter; +import org.apache.kafka.connect.util.Callback; +import org.apache.kafka.connect.util.ConnectorTaskId; +import org.apache.kafka.connect.util.ParameterizedTest; +import org.apache.kafka.connect.util.ThreadedTest; +import org.apache.kafka.connect.util.TopicAdmin; +import org.apache.kafka.connect.util.TopicCreationGroup; +import org.easymock.Capture; +import org.easymock.EasyMock; +import org.easymock.IAnswer; +import org.easymock.IExpectationSetters; +import org.junit.After; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.powermock.api.easymock.PowerMock; +import org.powermock.api.easymock.annotation.Mock; +import org.powermock.api.easymock.annotation.MockStrict; +import org.powermock.core.classloader.annotations.PowerMockIgnore; +import org.powermock.modules.junit4.PowerMockRunner; +import org.powermock.modules.junit4.PowerMockRunnerDelegate; +import org.powermock.reflect.Whitebox; + +import java.nio.ByteBuffer; +import java.time.Duration; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import 
java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import java.util.concurrent.atomic.AtomicInteger; + +import static org.apache.kafka.connect.integration.MonitorableSourceConnector.TOPIC_CONFIG; +import static org.apache.kafka.connect.runtime.ConnectorConfig.CONNECTOR_CLASS_CONFIG; +import static org.apache.kafka.connect.runtime.ConnectorConfig.KEY_CONVERTER_CLASS_CONFIG; +import static org.apache.kafka.connect.runtime.ConnectorConfig.TASKS_MAX_CONFIG; +import static org.apache.kafka.connect.runtime.ConnectorConfig.VALUE_CONVERTER_CLASS_CONFIG; +import static org.apache.kafka.connect.runtime.SourceConnectorConfig.TOPIC_CREATION_GROUPS_CONFIG; +import static org.apache.kafka.connect.runtime.TopicCreationConfig.DEFAULT_TOPIC_CREATION_PREFIX; +import static org.apache.kafka.connect.runtime.TopicCreationConfig.EXCLUDE_REGEX_CONFIG; +import static org.apache.kafka.connect.runtime.TopicCreationConfig.INCLUDE_REGEX_CONFIG; +import static org.apache.kafka.connect.runtime.TopicCreationConfig.PARTITIONS_CONFIG; +import static org.apache.kafka.connect.runtime.TopicCreationConfig.REPLICATION_FACTOR_CONFIG; +import static org.apache.kafka.connect.runtime.WorkerConfig.TOPIC_CREATION_ENABLE_CONFIG; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; + +@PowerMockIgnore({"javax.management.*", + "org.apache.log4j.*"}) +@RunWith(PowerMockRunner.class) +@PowerMockRunnerDelegate(ParameterizedTest.class) +public class WorkerSourceTaskTest extends ThreadedTest { + private static final String TOPIC = "topic"; + private static final String OTHER_TOPIC = "other-topic"; + private static final Map PARTITION = Collections.singletonMap("key", "partition".getBytes()); + private static final Map OFFSET = Collections.singletonMap("key", 12); + + // Connect-format data + private static final Schema KEY_SCHEMA = Schema.INT32_SCHEMA; + private static final Integer KEY = -1; + private static final Schema RECORD_SCHEMA = Schema.INT64_SCHEMA; + private static final Long RECORD = 12L; + // Serialized data. The actual format of this data doesn't matter -- we just want to see that the right version + // is used in the right place. 
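+    // (The key and value converters are mocked in this test, so these serialized byte arrays are never actually parsed.)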
+ private static final byte[] SERIALIZED_KEY = "converted-key".getBytes(); + private static final byte[] SERIALIZED_RECORD = "converted-record".getBytes(); + + private ExecutorService executor = Executors.newSingleThreadExecutor(); + private ConnectorTaskId taskId = new ConnectorTaskId("job", 0); + private ConnectorTaskId taskId1 = new ConnectorTaskId("job", 1); + private WorkerConfig config; + private SourceConnectorConfig sourceConfig; + private Plugins plugins; + private MockConnectMetrics metrics; + @Mock private SourceTask sourceTask; + @Mock private Converter keyConverter; + @Mock private Converter valueConverter; + @Mock private HeaderConverter headerConverter; + @Mock private TransformationChain transformationChain; + @Mock private KafkaProducer producer; + @Mock private TopicAdmin admin; + @Mock private CloseableOffsetStorageReader offsetReader; + @Mock private OffsetStorageWriter offsetWriter; + @Mock private ClusterConfigState clusterConfigState; + private WorkerSourceTask workerTask; + @Mock private Future sendFuture; + @MockStrict private TaskStatus.Listener statusListener; + @Mock private StatusBackingStore statusBackingStore; + + private Capture producerCallbacks; + + private static final Map TASK_PROPS = new HashMap<>(); + static { + TASK_PROPS.put(TaskConfig.TASK_CLASS_CONFIG, TestSourceTask.class.getName()); + } + private static final TaskConfig TASK_CONFIG = new TaskConfig(TASK_PROPS); + + private static final List RECORDS = Arrays.asList( + new SourceRecord(PARTITION, OFFSET, "topic", null, KEY_SCHEMA, KEY, RECORD_SCHEMA, RECORD) + ); + + private boolean enableTopicCreation; + + @ParameterizedTest.Parameters + public static Collection parameters() { + return Arrays.asList(false, true); + } + + public WorkerSourceTaskTest(boolean enableTopicCreation) { + this.enableTopicCreation = enableTopicCreation; + } + + @Override + public void setup() { + super.setup(); + Map workerProps = workerProps(); + plugins = new Plugins(workerProps); + config = new StandaloneConfig(workerProps); + sourceConfig = new SourceConnectorConfig(plugins, sourceConnectorPropsWithGroups(TOPIC), true); + producerCallbacks = EasyMock.newCapture(); + metrics = new MockConnectMetrics(); + } + + private Map workerProps() { + Map props = new HashMap<>(); + props.put("key.converter", "org.apache.kafka.connect.json.JsonConverter"); + props.put("value.converter", "org.apache.kafka.connect.json.JsonConverter"); + props.put("offset.storage.file.filename", "/tmp/connect.offsets"); + props.put(TOPIC_CREATION_ENABLE_CONFIG, String.valueOf(enableTopicCreation)); + return props; + } + + private Map sourceConnectorPropsWithGroups(String topic) { + // setup up props for the source connector + Map props = new HashMap<>(); + props.put("name", "foo-connector"); + props.put(CONNECTOR_CLASS_CONFIG, MonitorableSourceConnector.class.getSimpleName()); + props.put(TASKS_MAX_CONFIG, String.valueOf(1)); + props.put(TOPIC_CONFIG, topic); + props.put(KEY_CONVERTER_CLASS_CONFIG, StringConverter.class.getName()); + props.put(VALUE_CONVERTER_CLASS_CONFIG, StringConverter.class.getName()); + props.put(TOPIC_CREATION_GROUPS_CONFIG, String.join(",", "foo", "bar")); + props.put(DEFAULT_TOPIC_CREATION_PREFIX + REPLICATION_FACTOR_CONFIG, String.valueOf(1)); + props.put(DEFAULT_TOPIC_CREATION_PREFIX + PARTITIONS_CONFIG, String.valueOf(1)); + props.put(SourceConnectorConfig.TOPIC_CREATION_PREFIX + "foo" + "." + INCLUDE_REGEX_CONFIG, topic); + props.put(SourceConnectorConfig.TOPIC_CREATION_PREFIX + "bar" + "." 
+ INCLUDE_REGEX_CONFIG, ".*"); + props.put(SourceConnectorConfig.TOPIC_CREATION_PREFIX + "bar" + "." + EXCLUDE_REGEX_CONFIG, topic); + return props; + } + + @After + public void tearDown() { + if (metrics != null) metrics.stop(); + } + + private void createWorkerTask() { + createWorkerTask(TargetState.STARTED); + } + + private void createWorkerTask(TargetState initialState) { + createWorkerTask(initialState, keyConverter, valueConverter, headerConverter); + } + + private void createWorkerTask(TargetState initialState, Converter keyConverter, Converter valueConverter, HeaderConverter headerConverter) { + workerTask = new WorkerSourceTask(taskId, sourceTask, statusListener, initialState, keyConverter, valueConverter, headerConverter, + transformationChain, producer, admin, TopicCreationGroup.configuredGroups(sourceConfig), + offsetReader, offsetWriter, config, clusterConfigState, metrics, plugins.delegatingLoader(), Time.SYSTEM, + RetryWithToleranceOperatorTest.NOOP_OPERATOR, statusBackingStore, Runnable::run); + } + + @Test + public void testStartPaused() throws Exception { + final CountDownLatch pauseLatch = new CountDownLatch(1); + + createWorkerTask(TargetState.PAUSED); + + statusListener.onPause(taskId); + EasyMock.expectLastCall().andAnswer(() -> { + pauseLatch.countDown(); + return null; + }); + + expectClose(); + + statusListener.onShutdown(taskId); + EasyMock.expectLastCall(); + + PowerMock.replayAll(); + + workerTask.initialize(TASK_CONFIG); + Future taskFuture = executor.submit(workerTask); + + assertTrue(pauseLatch.await(5, TimeUnit.SECONDS)); + workerTask.stop(); + assertTrue(workerTask.awaitStop(1000)); + + taskFuture.get(); + + PowerMock.verifyAll(); + } + + @Test + public void testPause() throws Exception { + createWorkerTask(); + + sourceTask.initialize(EasyMock.anyObject(SourceTaskContext.class)); + EasyMock.expectLastCall(); + sourceTask.start(TASK_PROPS); + EasyMock.expectLastCall(); + statusListener.onStartup(taskId); + EasyMock.expectLastCall(); + + AtomicInteger count = new AtomicInteger(0); + CountDownLatch pollLatch = expectPolls(10, count); + // In this test, we don't flush, so nothing goes any further than the offset writer + + expectTopicCreation(TOPIC); + + statusListener.onPause(taskId); + EasyMock.expectLastCall(); + + sourceTask.stop(); + EasyMock.expectLastCall(); + + offsetWriter.offset(PARTITION, OFFSET); + PowerMock.expectLastCall(); + expectOffsetFlush(true); + + statusListener.onShutdown(taskId); + EasyMock.expectLastCall(); + + expectClose(); + + PowerMock.replayAll(); + + workerTask.initialize(TASK_CONFIG); + Future taskFuture = executor.submit(workerTask); + assertTrue(awaitLatch(pollLatch)); + + workerTask.transitionTo(TargetState.PAUSED); + + int priorCount = count.get(); + Thread.sleep(100); + + // since the transition is observed asynchronously, the count could be off by one loop iteration + assertTrue(count.get() - priorCount <= 1); + + workerTask.stop(); + assertTrue(workerTask.awaitStop(1000)); + + taskFuture.get(); + + PowerMock.verifyAll(); + } + + @Test + public void testPollsInBackground() throws Exception { + createWorkerTask(); + + sourceTask.initialize(EasyMock.anyObject(SourceTaskContext.class)); + EasyMock.expectLastCall(); + sourceTask.start(TASK_PROPS); + EasyMock.expectLastCall(); + statusListener.onStartup(taskId); + EasyMock.expectLastCall(); + + final CountDownLatch pollLatch = expectPolls(10); + // In this test, we don't flush, so nothing goes any further than the offset writer + + expectTopicCreation(TOPIC); + + 
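+        // Shutdown expectations: stop the task, write and flush the outstanding offset, and report shutdown to the status listener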
sourceTask.stop(); + EasyMock.expectLastCall(); + + offsetWriter.offset(PARTITION, OFFSET); + PowerMock.expectLastCall(); + expectOffsetFlush(true); + + statusListener.onShutdown(taskId); + EasyMock.expectLastCall(); + + expectClose(); + + PowerMock.replayAll(); + + workerTask.initialize(TASK_CONFIG); + Future taskFuture = executor.submit(workerTask); + + assertTrue(awaitLatch(pollLatch)); + workerTask.stop(); + assertTrue(workerTask.awaitStop(1000)); + + taskFuture.get(); + assertPollMetrics(10); + + PowerMock.verifyAll(); + } + + @Test + public void testFailureInPoll() throws Exception { + createWorkerTask(); + + sourceTask.initialize(EasyMock.anyObject(SourceTaskContext.class)); + EasyMock.expectLastCall(); + sourceTask.start(TASK_PROPS); + EasyMock.expectLastCall(); + statusListener.onStartup(taskId); + EasyMock.expectLastCall(); + + final CountDownLatch pollLatch = new CountDownLatch(1); + final RuntimeException exception = new RuntimeException(); + EasyMock.expect(sourceTask.poll()).andAnswer(() -> { + pollLatch.countDown(); + throw exception; + }); + + statusListener.onFailure(taskId, exception); + EasyMock.expectLastCall(); + + sourceTask.stop(); + EasyMock.expectLastCall(); + expectOffsetFlush(true); + + expectClose(); + + PowerMock.replayAll(); + + workerTask.initialize(TASK_CONFIG); + Future taskFuture = executor.submit(workerTask); + + assertTrue(awaitLatch(pollLatch)); + //Failure in poll should trigger automatic stop of the worker + assertTrue(workerTask.awaitStop(1000)); + + taskFuture.get(); + assertPollMetrics(0); + + PowerMock.verifyAll(); + } + + @Test + public void testFailureInPollAfterCancel() throws Exception { + createWorkerTask(); + + sourceTask.initialize(EasyMock.anyObject(SourceTaskContext.class)); + EasyMock.expectLastCall(); + sourceTask.start(TASK_PROPS); + EasyMock.expectLastCall(); + statusListener.onStartup(taskId); + EasyMock.expectLastCall(); + + final CountDownLatch pollLatch = new CountDownLatch(1); + final CountDownLatch workerCancelLatch = new CountDownLatch(1); + final RuntimeException exception = new RuntimeException(); + EasyMock.expect(sourceTask.poll()).andAnswer(() -> { + pollLatch.countDown(); + assertTrue(awaitLatch(workerCancelLatch)); + throw exception; + }); + + offsetReader.close(); + PowerMock.expectLastCall(); + + producer.close(Duration.ZERO); + PowerMock.expectLastCall(); + + sourceTask.stop(); + EasyMock.expectLastCall(); + expectOffsetFlush(true); + + expectClose(); + + PowerMock.replayAll(); + + workerTask.initialize(TASK_CONFIG); + Future taskFuture = executor.submit(workerTask); + + assertTrue(awaitLatch(pollLatch)); + workerTask.cancel(); + workerCancelLatch.countDown(); + assertTrue(workerTask.awaitStop(1000)); + + taskFuture.get(); + assertPollMetrics(0); + + PowerMock.verifyAll(); + } + + @Test + public void testFailureInPollAfterStop() throws Exception { + createWorkerTask(); + + sourceTask.initialize(EasyMock.anyObject(SourceTaskContext.class)); + EasyMock.expectLastCall(); + sourceTask.start(TASK_PROPS); + EasyMock.expectLastCall(); + statusListener.onStartup(taskId); + EasyMock.expectLastCall(); + + final CountDownLatch pollLatch = new CountDownLatch(1); + final CountDownLatch workerStopLatch = new CountDownLatch(1); + final RuntimeException exception = new RuntimeException(); + EasyMock.expect(sourceTask.poll()).andAnswer(() -> { + pollLatch.countDown(); + assertTrue(awaitLatch(workerStopLatch)); + throw exception; + }); + + statusListener.onShutdown(taskId); + EasyMock.expectLastCall(); + + sourceTask.stop(); + 
EasyMock.expectLastCall(); + expectOffsetFlush(true); + + expectClose(); + + PowerMock.replayAll(); + + workerTask.initialize(TASK_CONFIG); + Future taskFuture = executor.submit(workerTask); + + assertTrue(awaitLatch(pollLatch)); + workerTask.stop(); + workerStopLatch.countDown(); + assertTrue(workerTask.awaitStop(1000)); + + taskFuture.get(); + assertPollMetrics(0); + + PowerMock.verifyAll(); + } + + @Test + public void testPollReturnsNoRecords() throws Exception { + // Test that the task handles an empty list of records + createWorkerTask(); + + sourceTask.initialize(EasyMock.anyObject(SourceTaskContext.class)); + EasyMock.expectLastCall(); + sourceTask.start(TASK_PROPS); + EasyMock.expectLastCall(); + statusListener.onStartup(taskId); + EasyMock.expectLastCall(); + + // We'll wait for some data, then trigger a flush + final CountDownLatch pollLatch = expectEmptyPolls(1, new AtomicInteger()); + expectOffsetFlush(true); + + sourceTask.stop(); + EasyMock.expectLastCall(); + expectOffsetFlush(true); + + statusListener.onShutdown(taskId); + EasyMock.expectLastCall(); + + expectClose(); + + PowerMock.replayAll(); + + workerTask.initialize(TASK_CONFIG); + Future taskFuture = executor.submit(workerTask); + + assertTrue(awaitLatch(pollLatch)); + assertTrue(workerTask.commitOffsets()); + workerTask.stop(); + assertTrue(workerTask.awaitStop(1000)); + + taskFuture.get(); + assertPollMetrics(0); + + PowerMock.verifyAll(); + } + + @Test + public void testCommit() throws Exception { + // Test that the task commits properly when prompted + createWorkerTask(); + + sourceTask.initialize(EasyMock.anyObject(SourceTaskContext.class)); + EasyMock.expectLastCall(); + sourceTask.start(TASK_PROPS); + EasyMock.expectLastCall(); + statusListener.onStartup(taskId); + EasyMock.expectLastCall(); + + // We'll wait for some data, then trigger a flush + final CountDownLatch pollLatch = expectPolls(1); + expectOffsetFlush(true); + + offsetWriter.offset(PARTITION, OFFSET); + PowerMock.expectLastCall().atLeastOnce(); + + expectTopicCreation(TOPIC); + + sourceTask.stop(); + EasyMock.expectLastCall(); + expectOffsetFlush(true); + + statusListener.onShutdown(taskId); + EasyMock.expectLastCall(); + + expectClose(); + + PowerMock.replayAll(); + + workerTask.initialize(TASK_CONFIG); + Future taskFuture = executor.submit(workerTask); + + assertTrue(awaitLatch(pollLatch)); + assertTrue(workerTask.commitOffsets()); + workerTask.stop(); + assertTrue(workerTask.awaitStop(1000)); + + taskFuture.get(); + assertPollMetrics(1); + + PowerMock.verifyAll(); + } + + @Test + public void testCommitFailure() throws Exception { + // Test that the task commits properly when prompted + createWorkerTask(); + + sourceTask.initialize(EasyMock.anyObject(SourceTaskContext.class)); + EasyMock.expectLastCall(); + sourceTask.start(TASK_PROPS); + EasyMock.expectLastCall(); + statusListener.onStartup(taskId); + EasyMock.expectLastCall(); + + // We'll wait for some data, then trigger a flush + final CountDownLatch pollLatch = expectPolls(1); + expectOffsetFlush(true); + + offsetWriter.offset(PARTITION, OFFSET); + PowerMock.expectLastCall().atLeastOnce(); + + expectTopicCreation(TOPIC); + + sourceTask.stop(); + EasyMock.expectLastCall(); + expectOffsetFlush(false); + + statusListener.onShutdown(taskId); + EasyMock.expectLastCall(); + + expectClose(); + + PowerMock.replayAll(); + + workerTask.initialize(TASK_CONFIG); + Future taskFuture = executor.submit(workerTask); + + assertTrue(awaitLatch(pollLatch)); + assertTrue(workerTask.commitOffsets()); + 
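+        // This explicit commit consumes the successful flush expectation above; the flush that is expected to fail is the one attempted later during shutdown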
+        workerTask.stop();
+        assertTrue(workerTask.awaitStop(1000));
+
+        taskFuture.get();
+        assertPollMetrics(1);
+
+        PowerMock.verifyAll();
+    }
+
+    @Test
+    public void testSendRecordsConvertsData() throws Exception {
+        createWorkerTask();
+
+        List<SourceRecord> records = new ArrayList<>();
+        // Can just use the same record for key and value
+        records.add(new SourceRecord(PARTITION, OFFSET, "topic", null, KEY_SCHEMA, KEY, RECORD_SCHEMA, RECORD));
+
+        Capture<ProducerRecord<byte[], byte[]>> sent = expectSendRecordAnyTimes();
+
+        expectTopicCreation(TOPIC);
+
+        PowerMock.replayAll();
+
+        Whitebox.setInternalState(workerTask, "toSend", records);
+        Whitebox.invokeMethod(workerTask, "sendRecords");
+        assertEquals(SERIALIZED_KEY, sent.getValue().key());
+        assertEquals(SERIALIZED_RECORD, sent.getValue().value());
+
+        PowerMock.verifyAll();
+    }
+
+    @Test
+    public void testSendRecordsPropagatesTimestamp() throws Exception {
+        final Long timestamp = System.currentTimeMillis();
+
+        createWorkerTask();
+
+        List<SourceRecord> records = Collections.singletonList(
+                new SourceRecord(PARTITION, OFFSET, "topic", null, KEY_SCHEMA, KEY, RECORD_SCHEMA, RECORD, timestamp)
+        );
+
+        Capture<ProducerRecord<byte[], byte[]>> sent = expectSendRecordAnyTimes();
+
+        expectTopicCreation(TOPIC);
+
+        PowerMock.replayAll();
+
+        Whitebox.setInternalState(workerTask, "toSend", records);
+        Whitebox.invokeMethod(workerTask, "sendRecords");
+        assertEquals(timestamp, sent.getValue().timestamp());
+
+        PowerMock.verifyAll();
+    }
+
+    @Test
+    public void testSendRecordsCorruptTimestamp() throws Exception {
+        final Long timestamp = -3L;
+        createWorkerTask();
+
+        List<SourceRecord> records = Collections.singletonList(
+                new SourceRecord(PARTITION, OFFSET, "topic", null, KEY_SCHEMA, KEY, RECORD_SCHEMA, RECORD, timestamp)
+        );
+
+        Capture<ProducerRecord<byte[], byte[]>> sent = expectSendRecordAnyTimes();
+
+        PowerMock.replayAll();
+
+        Whitebox.setInternalState(workerTask, "toSend", records);
+        assertThrows(InvalidRecordException.class, () -> Whitebox.invokeMethod(workerTask, "sendRecords"));
+        assertFalse(sent.hasCaptured());
+
+        PowerMock.verifyAll();
+    }
+
+    @Test
+    public void testSendRecordsNoTimestamp() throws Exception {
+        final Long timestamp = -1L;
+        createWorkerTask();
+
+        List<SourceRecord> records = Collections.singletonList(
+                new SourceRecord(PARTITION, OFFSET, "topic", null, KEY_SCHEMA, KEY, RECORD_SCHEMA, RECORD, timestamp)
+        );
+
+        Capture<ProducerRecord<byte[], byte[]>> sent = expectSendRecordAnyTimes();
+
+        expectTopicCreation(TOPIC);
+
+        PowerMock.replayAll();
+
+        Whitebox.setInternalState(workerTask, "toSend", records);
+        Whitebox.invokeMethod(workerTask, "sendRecords");
+        assertNull(sent.getValue().timestamp());
+
+        PowerMock.verifyAll();
+    }
+
+    @Test
+    public void testSendRecordsRetries() throws Exception {
+        createWorkerTask();
+
+        // Differentiate only by Kafka partition so we can reuse conversion expectations
+        SourceRecord record1 = new SourceRecord(PARTITION, OFFSET, "topic", 1, KEY_SCHEMA, KEY, RECORD_SCHEMA, RECORD);
+        SourceRecord record2 = new SourceRecord(PARTITION, OFFSET, "topic", 2, KEY_SCHEMA, KEY, RECORD_SCHEMA, RECORD);
+        SourceRecord record3 = new SourceRecord(PARTITION, OFFSET, "topic", 3, KEY_SCHEMA, KEY, RECORD_SCHEMA, RECORD);
+
+        expectTopicCreation(TOPIC);
+
+        // First round
+        expectSendRecordOnce();
+        // Any Producer retriable exception should work here
+        expectSendRecordSyncFailure(new org.apache.kafka.common.errors.TimeoutException("retriable sync failure"));
+
+        // Second round
+        expectSendRecordOnce();
+        expectSendRecordOnce();
+
+        PowerMock.replayAll();
+
+        // Try to send 3, make first pass, second fail.
Should save last two + Whitebox.setInternalState(workerTask, "toSend", Arrays.asList(record1, record2, record3)); + Whitebox.invokeMethod(workerTask, "sendRecords"); + assertEquals(Arrays.asList(record2, record3), Whitebox.getInternalState(workerTask, "toSend")); + + // Next they all succeed + Whitebox.invokeMethod(workerTask, "sendRecords"); + assertNull(Whitebox.getInternalState(workerTask, "toSend")); + + PowerMock.verifyAll(); + } + + @Test + public void testSendRecordsProducerCallbackFail() throws Exception { + createWorkerTask(); + + SourceRecord record1 = new SourceRecord(PARTITION, OFFSET, "topic", 1, KEY_SCHEMA, KEY, RECORD_SCHEMA, RECORD); + SourceRecord record2 = new SourceRecord(PARTITION, OFFSET, "topic", 2, KEY_SCHEMA, KEY, RECORD_SCHEMA, RECORD); + + expectTopicCreation(TOPIC); + + expectSendRecordProducerCallbackFail(); + + PowerMock.replayAll(); + + Whitebox.setInternalState(workerTask, "toSend", Arrays.asList(record1, record2)); + assertThrows(ConnectException.class, () -> Whitebox.invokeMethod(workerTask, "sendRecords")); + } + + @Test + public void testSendRecordsProducerSendFailsImmediately() { + if (!enableTopicCreation) + // should only test with topic creation enabled + return; + + createWorkerTask(); + + SourceRecord record1 = new SourceRecord(PARTITION, OFFSET, TOPIC, 1, KEY_SCHEMA, KEY, RECORD_SCHEMA, RECORD); + SourceRecord record2 = new SourceRecord(PARTITION, OFFSET, TOPIC, 2, KEY_SCHEMA, KEY, RECORD_SCHEMA, RECORD); + + expectPreliminaryCalls(); + expectTopicCreation(TOPIC); + + EasyMock.expect(producer.send(EasyMock.anyObject(), EasyMock.anyObject())) + .andThrow(new KafkaException("Producer closed while send in progress", new InvalidTopicException(TOPIC))); + + PowerMock.replayAll(); + + Whitebox.setInternalState(workerTask, "toSend", Arrays.asList(record1, record2)); + assertThrows(ConnectException.class, () -> Whitebox.invokeMethod(workerTask, "sendRecords")); + } + + @Test + public void testSendRecordsTaskCommitRecordFail() throws Exception { + createWorkerTask(); + + // Differentiate only by Kafka partition so we can reuse conversion expectations + SourceRecord record1 = new SourceRecord(PARTITION, OFFSET, "topic", 1, KEY_SCHEMA, KEY, RECORD_SCHEMA, RECORD); + SourceRecord record2 = new SourceRecord(PARTITION, OFFSET, "topic", 2, KEY_SCHEMA, KEY, RECORD_SCHEMA, RECORD); + SourceRecord record3 = new SourceRecord(PARTITION, OFFSET, "topic", 3, KEY_SCHEMA, KEY, RECORD_SCHEMA, RECORD); + + expectTopicCreation(TOPIC); + + // Source task commit record failure will not cause the task to abort + expectSendRecordOnce(); + expectSendRecordTaskCommitRecordFail(false); + expectSendRecordOnce(); + + PowerMock.replayAll(); + + Whitebox.setInternalState(workerTask, "toSend", Arrays.asList(record1, record2, record3)); + Whitebox.invokeMethod(workerTask, "sendRecords"); + assertNull(Whitebox.getInternalState(workerTask, "toSend")); + + PowerMock.verifyAll(); + } + + @Test + public void testSlowTaskStart() throws Exception { + final CountDownLatch startupLatch = new CountDownLatch(1); + final CountDownLatch finishStartupLatch = new CountDownLatch(1); + + createWorkerTask(); + + sourceTask.initialize(EasyMock.anyObject(SourceTaskContext.class)); + EasyMock.expectLastCall(); + sourceTask.start(TASK_PROPS); + EasyMock.expectLastCall().andAnswer(() -> { + startupLatch.countDown(); + assertTrue(awaitLatch(finishStartupLatch)); + return null; + }); + + statusListener.onStartup(taskId); + EasyMock.expectLastCall(); + + sourceTask.stop(); + EasyMock.expectLastCall(); + 
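+        // Even though the task is stopped before any records are polled, the shutdown path still performs a final offset flush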
expectOffsetFlush(true); + + statusListener.onShutdown(taskId); + EasyMock.expectLastCall(); + + expectClose(); + + PowerMock.replayAll(); + + workerTask.initialize(TASK_CONFIG); + Future workerTaskFuture = executor.submit(workerTask); + + // Stopping immediately while the other thread has work to do should result in no polling, no offset commits, + // exiting the work thread immediately, and the stop() method will be invoked in the background thread since it + // cannot be invoked immediately in the thread trying to stop the task. + assertTrue(awaitLatch(startupLatch)); + workerTask.stop(); + finishStartupLatch.countDown(); + assertTrue(workerTask.awaitStop(1000)); + + workerTaskFuture.get(); + + PowerMock.verifyAll(); + } + + @Test + public void testCancel() { + createWorkerTask(); + + offsetReader.close(); + PowerMock.expectLastCall(); + + producer.close(Duration.ZERO); + PowerMock.expectLastCall(); + + PowerMock.replayAll(); + + workerTask.cancel(); + + PowerMock.verifyAll(); + } + + @Test + public void testMetricsGroup() { + SourceTaskMetricsGroup group = new SourceTaskMetricsGroup(taskId, metrics); + SourceTaskMetricsGroup group1 = new SourceTaskMetricsGroup(taskId1, metrics); + for (int i = 0; i != 10; ++i) { + group.recordPoll(100, 1000 + i * 100); + group.recordWrite(10); + } + for (int i = 0; i != 20; ++i) { + group1.recordPoll(100, 1000 + i * 100); + group1.recordWrite(10); + } + assertEquals(1900.0, metrics.currentMetricValueAsDouble(group.metricGroup(), "poll-batch-max-time-ms"), 0.001d); + assertEquals(1450.0, metrics.currentMetricValueAsDouble(group.metricGroup(), "poll-batch-avg-time-ms"), 0.001d); + assertEquals(33.333, metrics.currentMetricValueAsDouble(group.metricGroup(), "source-record-poll-rate"), 0.001d); + assertEquals(1000, metrics.currentMetricValueAsDouble(group.metricGroup(), "source-record-poll-total"), 0.001d); + assertEquals(3.3333, metrics.currentMetricValueAsDouble(group.metricGroup(), "source-record-write-rate"), 0.001d); + assertEquals(100, metrics.currentMetricValueAsDouble(group.metricGroup(), "source-record-write-total"), 0.001d); + assertEquals(900.0, metrics.currentMetricValueAsDouble(group.metricGroup(), "source-record-active-count"), 0.001d); + + // Close the group + group.close(); + + for (MetricName metricName : group.metricGroup().metrics().metrics().keySet()) { + // Metrics for this group should no longer exist + assertFalse(group.metricGroup().groupId().includes(metricName)); + } + // Sensors for this group should no longer exist + assertNull(group.metricGroup().metrics().getSensor("sink-record-read")); + assertNull(group.metricGroup().metrics().getSensor("sink-record-send")); + assertNull(group.metricGroup().metrics().getSensor("sink-record-active-count")); + assertNull(group.metricGroup().metrics().getSensor("partition-count")); + assertNull(group.metricGroup().metrics().getSensor("offset-seq-number")); + assertNull(group.metricGroup().metrics().getSensor("offset-commit-completion")); + assertNull(group.metricGroup().metrics().getSensor("offset-commit-completion-skip")); + assertNull(group.metricGroup().metrics().getSensor("put-batch-time")); + + assertEquals(2900.0, metrics.currentMetricValueAsDouble(group1.metricGroup(), "poll-batch-max-time-ms"), 0.001d); + assertEquals(1950.0, metrics.currentMetricValueAsDouble(group1.metricGroup(), "poll-batch-avg-time-ms"), 0.001d); + assertEquals(66.667, metrics.currentMetricValueAsDouble(group1.metricGroup(), "source-record-poll-rate"), 0.001d); + assertEquals(2000, 
metrics.currentMetricValueAsDouble(group1.metricGroup(), "source-record-poll-total"), 0.001d);
+        assertEquals(6.667, metrics.currentMetricValueAsDouble(group1.metricGroup(), "source-record-write-rate"), 0.001d);
+        assertEquals(200, metrics.currentMetricValueAsDouble(group1.metricGroup(), "source-record-write-total"), 0.001d);
+        assertEquals(1800.0, metrics.currentMetricValueAsDouble(group1.metricGroup(), "source-record-active-count"), 0.001d);
+    }
+
+    @Test
+    public void testHeaders() throws Exception {
+        Headers headers = new RecordHeaders();
+        headers.add("header_key", "header_value".getBytes());
+
+        org.apache.kafka.connect.header.Headers connectHeaders = new ConnectHeaders();
+        connectHeaders.add("header_key", new SchemaAndValue(Schema.STRING_SCHEMA, "header_value"));
+
+        createWorkerTask();
+
+        List<SourceRecord> records = new ArrayList<>();
+        records.add(new SourceRecord(PARTITION, OFFSET, TOPIC, null, KEY_SCHEMA, KEY, RECORD_SCHEMA, RECORD, null, connectHeaders));
+
+        expectTopicCreation(TOPIC);
+
+        Capture<ProducerRecord<byte[], byte[]>> sent = expectSendRecord(TOPIC, true, true, true, true, headers);
+
+        PowerMock.replayAll();
+
+        Whitebox.setInternalState(workerTask, "toSend", records);
+        Whitebox.invokeMethod(workerTask, "sendRecords");
+        assertEquals(SERIALIZED_KEY, sent.getValue().key());
+        assertEquals(SERIALIZED_RECORD, sent.getValue().value());
+        assertEquals(headers, sent.getValue().headers());
+
+        PowerMock.verifyAll();
+    }
+
+    @Test
+    public void testHeadersWithCustomConverter() throws Exception {
+        StringConverter stringConverter = new StringConverter();
+        TestConverterWithHeaders testConverter = new TestConverterWithHeaders();
+
+        createWorkerTask(TargetState.STARTED, stringConverter, testConverter, stringConverter);
+
+        List<SourceRecord> records = new ArrayList<>();
+
+        String stringA = "Árvíztűrő tükörfúrógép";
+        org.apache.kafka.connect.header.Headers headersA = new ConnectHeaders();
+        String encodingA = "latin2";
+        headersA.addString("encoding", encodingA);
+
+        records.add(new SourceRecord(PARTITION, OFFSET, "topic", null, Schema.STRING_SCHEMA, "a", Schema.STRING_SCHEMA, stringA, null, headersA));
+
+        String stringB = "Тестовое сообщение";
+        org.apache.kafka.connect.header.Headers headersB = new ConnectHeaders();
+        String encodingB = "koi8_r";
+        headersB.addString("encoding", encodingB);
+
+        records.add(new SourceRecord(PARTITION, OFFSET, "topic", null, Schema.STRING_SCHEMA, "b", Schema.STRING_SCHEMA, stringB, null, headersB));
+
+        expectTopicCreation(TOPIC);
+
+        Capture<ProducerRecord<byte[], byte[]>> sentRecordA = expectSendRecord(TOPIC, false, true, true, false, null);
+        Capture<ProducerRecord<byte[], byte[]>> sentRecordB = expectSendRecord(TOPIC, false, true, true, false, null);
+
+        PowerMock.replayAll();
+
+        Whitebox.setInternalState(workerTask, "toSend", records);
+        Whitebox.invokeMethod(workerTask, "sendRecords");
+
+        assertEquals(ByteBuffer.wrap("a".getBytes()), ByteBuffer.wrap(sentRecordA.getValue().key()));
+        assertEquals(
+            ByteBuffer.wrap(stringA.getBytes(encodingA)),
+            ByteBuffer.wrap(sentRecordA.getValue().value())
+        );
+        assertEquals(encodingA, new String(sentRecordA.getValue().headers().lastHeader("encoding").value()));
+
+        assertEquals(ByteBuffer.wrap("b".getBytes()), ByteBuffer.wrap(sentRecordB.getValue().key()));
+        assertEquals(
+            ByteBuffer.wrap(stringB.getBytes(encodingB)),
+            ByteBuffer.wrap(sentRecordB.getValue().value())
+        );
+        assertEquals(encodingB, new String(sentRecordB.getValue().headers().lastHeader("encoding").value()));
+
+        PowerMock.verifyAll();
+    }
+
+    @Test
+    public void testTopicCreateWhenTopicExists() throws Exception {
+        if (!enableTopicCreation)
+ // should only test with topic creation enabled + return; + + createWorkerTask(); + + SourceRecord record1 = new SourceRecord(PARTITION, OFFSET, TOPIC, 1, KEY_SCHEMA, KEY, RECORD_SCHEMA, RECORD); + SourceRecord record2 = new SourceRecord(PARTITION, OFFSET, TOPIC, 2, KEY_SCHEMA, KEY, RECORD_SCHEMA, RECORD); + + expectPreliminaryCalls(); + TopicPartitionInfo topicPartitionInfo = new TopicPartitionInfo(0, null, Collections.emptyList(), Collections.emptyList()); + TopicDescription topicDesc = new TopicDescription(TOPIC, false, Collections.singletonList(topicPartitionInfo)); + EasyMock.expect(admin.describeTopics(TOPIC)).andReturn(Collections.singletonMap(TOPIC, topicDesc)); + + expectSendRecordTaskCommitRecordSucceed(false); + expectSendRecordTaskCommitRecordSucceed(false); + + PowerMock.replayAll(); + + Whitebox.setInternalState(workerTask, "toSend", Arrays.asList(record1, record2)); + Whitebox.invokeMethod(workerTask, "sendRecords"); + } + + @Test + public void testSendRecordsTopicDescribeRetries() throws Exception { + if (!enableTopicCreation) + // should only test with topic creation enabled + return; + + createWorkerTask(); + + SourceRecord record1 = new SourceRecord(PARTITION, OFFSET, TOPIC, 1, KEY_SCHEMA, KEY, RECORD_SCHEMA, RECORD); + SourceRecord record2 = new SourceRecord(PARTITION, OFFSET, TOPIC, 2, KEY_SCHEMA, KEY, RECORD_SCHEMA, RECORD); + + expectPreliminaryCalls(); + // First round - call to describe the topic times out + EasyMock.expect(admin.describeTopics(TOPIC)) + .andThrow(new RetriableException(new TimeoutException("timeout"))); + + // Second round - calls to describe and create succeed + expectTopicCreation(TOPIC); + // Exactly two records are sent + expectSendRecordTaskCommitRecordSucceed(false); + expectSendRecordTaskCommitRecordSucceed(false); + + PowerMock.replayAll(); + + Whitebox.setInternalState(workerTask, "toSend", Arrays.asList(record1, record2)); + Whitebox.invokeMethod(workerTask, "sendRecords"); + assertEquals(Arrays.asList(record1, record2), Whitebox.getInternalState(workerTask, "toSend")); + + // Next they all succeed + Whitebox.invokeMethod(workerTask, "sendRecords"); + assertNull(Whitebox.getInternalState(workerTask, "toSend")); + } + + @Test + public void testSendRecordsTopicCreateRetries() throws Exception { + if (!enableTopicCreation) + // should only test with topic creation enabled + return; + + createWorkerTask(); + + SourceRecord record1 = new SourceRecord(PARTITION, OFFSET, TOPIC, 1, KEY_SCHEMA, KEY, RECORD_SCHEMA, RECORD); + SourceRecord record2 = new SourceRecord(PARTITION, OFFSET, TOPIC, 2, KEY_SCHEMA, KEY, RECORD_SCHEMA, RECORD); + + // First call to describe the topic times out + expectPreliminaryCalls(); + EasyMock.expect(admin.describeTopics(TOPIC)).andReturn(Collections.emptyMap()); + Capture newTopicCapture = EasyMock.newCapture(); + EasyMock.expect(admin.createOrFindTopics(EasyMock.capture(newTopicCapture))) + .andThrow(new RetriableException(new TimeoutException("timeout"))); + + // Second round + expectTopicCreation(TOPIC); + expectSendRecordTaskCommitRecordSucceed(false); + expectSendRecordTaskCommitRecordSucceed(false); + + PowerMock.replayAll(); + + Whitebox.setInternalState(workerTask, "toSend", Arrays.asList(record1, record2)); + Whitebox.invokeMethod(workerTask, "sendRecords"); + assertEquals(Arrays.asList(record1, record2), Whitebox.getInternalState(workerTask, "toSend")); + + // Next they all succeed + Whitebox.invokeMethod(workerTask, "sendRecords"); + assertNull(Whitebox.getInternalState(workerTask, "toSend")); + } + + 
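+    // The two "midway" tests below exercise a retriable describe/create failure on the second topic of a batch:
+    // records for the first (already created) topic are sent, and only the remaining records are kept for retry.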
@Test + public void testSendRecordsTopicDescribeRetriesMidway() throws Exception { + if (!enableTopicCreation) + // should only test with topic creation enabled + return; + + createWorkerTask(); + + // Differentiate only by Kafka partition so we can reuse conversion expectations + SourceRecord record1 = new SourceRecord(PARTITION, OFFSET, TOPIC, 1, KEY_SCHEMA, KEY, RECORD_SCHEMA, RECORD); + SourceRecord record2 = new SourceRecord(PARTITION, OFFSET, TOPIC, 2, KEY_SCHEMA, KEY, RECORD_SCHEMA, RECORD); + SourceRecord record3 = new SourceRecord(PARTITION, OFFSET, OTHER_TOPIC, 3, KEY_SCHEMA, KEY, RECORD_SCHEMA, RECORD); + + // First round + expectPreliminaryCalls(OTHER_TOPIC); + expectTopicCreation(TOPIC); + expectSendRecordTaskCommitRecordSucceed(false); + expectSendRecordTaskCommitRecordSucceed(false); + + // First call to describe the topic times out + EasyMock.expect(admin.describeTopics(OTHER_TOPIC)) + .andThrow(new RetriableException(new TimeoutException("timeout"))); + + // Second round + expectTopicCreation(OTHER_TOPIC); + expectSendRecord(OTHER_TOPIC, false, true, true, true, emptyHeaders()); + + PowerMock.replayAll(); + + // Try to send 3, make first pass, second fail. Should save last two + Whitebox.setInternalState(workerTask, "toSend", Arrays.asList(record1, record2, record3)); + Whitebox.invokeMethod(workerTask, "sendRecords"); + assertEquals(Arrays.asList(record3), Whitebox.getInternalState(workerTask, "toSend")); + + // Next they all succeed + Whitebox.invokeMethod(workerTask, "sendRecords"); + assertNull(Whitebox.getInternalState(workerTask, "toSend")); + + PowerMock.verifyAll(); + } + + @Test + public void testSendRecordsTopicCreateRetriesMidway() throws Exception { + if (!enableTopicCreation) + // should only test with topic creation enabled + return; + + createWorkerTask(); + + // Differentiate only by Kafka partition so we can reuse conversion expectations + SourceRecord record1 = new SourceRecord(PARTITION, OFFSET, TOPIC, 1, KEY_SCHEMA, KEY, RECORD_SCHEMA, RECORD); + SourceRecord record2 = new SourceRecord(PARTITION, OFFSET, TOPIC, 2, KEY_SCHEMA, KEY, RECORD_SCHEMA, RECORD); + SourceRecord record3 = new SourceRecord(PARTITION, OFFSET, OTHER_TOPIC, 3, KEY_SCHEMA, KEY, RECORD_SCHEMA, RECORD); + + // First round + expectPreliminaryCalls(OTHER_TOPIC); + expectTopicCreation(TOPIC); + expectSendRecordTaskCommitRecordSucceed(false); + expectSendRecordTaskCommitRecordSucceed(false); + + EasyMock.expect(admin.describeTopics(OTHER_TOPIC)).andReturn(Collections.emptyMap()); + // First call to create the topic times out + Capture newTopicCapture = EasyMock.newCapture(); + EasyMock.expect(admin.createOrFindTopics(EasyMock.capture(newTopicCapture))) + .andThrow(new RetriableException(new TimeoutException("timeout"))); + + // Second round + expectTopicCreation(OTHER_TOPIC); + expectSendRecord(OTHER_TOPIC, false, true, true, true, emptyHeaders()); + + PowerMock.replayAll(); + + // Try to send 3, make first pass, second fail. 
Should save last two + Whitebox.setInternalState(workerTask, "toSend", Arrays.asList(record1, record2, record3)); + Whitebox.invokeMethod(workerTask, "sendRecords"); + assertEquals(Arrays.asList(record3), Whitebox.getInternalState(workerTask, "toSend")); + + // Next they all succeed + Whitebox.invokeMethod(workerTask, "sendRecords"); + assertNull(Whitebox.getInternalState(workerTask, "toSend")); + + PowerMock.verifyAll(); + } + + @Test + public void testTopicDescribeFails() { + if (!enableTopicCreation) + // should only test with topic creation enabled + return; + + createWorkerTask(); + + SourceRecord record1 = new SourceRecord(PARTITION, OFFSET, TOPIC, 1, KEY_SCHEMA, KEY, RECORD_SCHEMA, RECORD); + SourceRecord record2 = new SourceRecord(PARTITION, OFFSET, TOPIC, 2, KEY_SCHEMA, KEY, RECORD_SCHEMA, RECORD); + + expectPreliminaryCalls(); + EasyMock.expect(admin.describeTopics(TOPIC)) + .andThrow(new ConnectException(new TopicAuthorizationException("unauthorized"))); + + PowerMock.replayAll(); + + Whitebox.setInternalState(workerTask, "toSend", Arrays.asList(record1, record2)); + assertThrows(ConnectException.class, () -> Whitebox.invokeMethod(workerTask, "sendRecords")); + } + + @Test + public void testTopicCreateFails() throws Exception { + if (!enableTopicCreation) + // should only test with topic creation enabled + return; + + createWorkerTask(); + + SourceRecord record1 = new SourceRecord(PARTITION, OFFSET, TOPIC, 1, KEY_SCHEMA, KEY, RECORD_SCHEMA, RECORD); + SourceRecord record2 = new SourceRecord(PARTITION, OFFSET, TOPIC, 2, KEY_SCHEMA, KEY, RECORD_SCHEMA, RECORD); + + expectPreliminaryCalls(); + EasyMock.expect(admin.describeTopics(TOPIC)).andReturn(Collections.emptyMap()); + + Capture newTopicCapture = EasyMock.newCapture(); + EasyMock.expect(admin.createOrFindTopics(EasyMock.capture(newTopicCapture))) + .andThrow(new ConnectException(new TopicAuthorizationException("unauthorized"))); + + PowerMock.replayAll(); + + Whitebox.setInternalState(workerTask, "toSend", Arrays.asList(record1, record2)); + assertThrows(ConnectException.class, () -> Whitebox.invokeMethod(workerTask, "sendRecords")); + assertTrue(newTopicCapture.hasCaptured()); + } + + @Test + public void testTopicCreateFailsWithExceptionWhenCreateReturnsTopicNotCreatedOrFound() throws Exception { + if (!enableTopicCreation) + // should only test with topic creation enabled + return; + + createWorkerTask(); + + SourceRecord record1 = new SourceRecord(PARTITION, OFFSET, TOPIC, 1, KEY_SCHEMA, KEY, RECORD_SCHEMA, RECORD); + SourceRecord record2 = new SourceRecord(PARTITION, OFFSET, TOPIC, 2, KEY_SCHEMA, KEY, RECORD_SCHEMA, RECORD); + + expectPreliminaryCalls(); + EasyMock.expect(admin.describeTopics(TOPIC)).andReturn(Collections.emptyMap()); + + Capture newTopicCapture = EasyMock.newCapture(); + EasyMock.expect(admin.createOrFindTopics(EasyMock.capture(newTopicCapture))).andReturn(TopicAdmin.EMPTY_CREATION); + + PowerMock.replayAll(); + + Whitebox.setInternalState(workerTask, "toSend", Arrays.asList(record1, record2)); + assertThrows(ConnectException.class, () -> Whitebox.invokeMethod(workerTask, "sendRecords")); + assertTrue(newTopicCapture.hasCaptured()); + } + + @Test + public void testTopicCreateSucceedsWhenCreateReturnsExistingTopicFound() throws Exception { + if (!enableTopicCreation) + // should only test with topic creation enabled + return; + + createWorkerTask(); + + SourceRecord record1 = new SourceRecord(PARTITION, OFFSET, TOPIC, 1, KEY_SCHEMA, KEY, RECORD_SCHEMA, RECORD); + SourceRecord record2 = new 
SourceRecord(PARTITION, OFFSET, TOPIC, 2, KEY_SCHEMA, KEY, RECORD_SCHEMA, RECORD); + + expectPreliminaryCalls(); + EasyMock.expect(admin.describeTopics(TOPIC)).andReturn(Collections.emptyMap()); + + Capture newTopicCapture = EasyMock.newCapture(); + EasyMock.expect(admin.createOrFindTopics(EasyMock.capture(newTopicCapture))).andReturn(foundTopic(TOPIC)); + + expectSendRecordTaskCommitRecordSucceed(false); + expectSendRecordTaskCommitRecordSucceed(false); + + PowerMock.replayAll(); + + Whitebox.setInternalState(workerTask, "toSend", Arrays.asList(record1, record2)); + Whitebox.invokeMethod(workerTask, "sendRecords"); + } + + @Test + public void testTopicCreateSucceedsWhenCreateReturnsNewTopicFound() throws Exception { + if (!enableTopicCreation) + // should only test with topic creation enabled + return; + + createWorkerTask(); + + SourceRecord record1 = new SourceRecord(PARTITION, OFFSET, TOPIC, 1, KEY_SCHEMA, KEY, RECORD_SCHEMA, RECORD); + SourceRecord record2 = new SourceRecord(PARTITION, OFFSET, TOPIC, 2, KEY_SCHEMA, KEY, RECORD_SCHEMA, RECORD); + + expectPreliminaryCalls(); + EasyMock.expect(admin.describeTopics(TOPIC)).andReturn(Collections.emptyMap()); + + Capture newTopicCapture = EasyMock.newCapture(); + EasyMock.expect(admin.createOrFindTopics(EasyMock.capture(newTopicCapture))).andReturn(createdTopic(TOPIC)); + + expectSendRecordTaskCommitRecordSucceed(false); + expectSendRecordTaskCommitRecordSucceed(false); + + PowerMock.replayAll(); + + Whitebox.setInternalState(workerTask, "toSend", Arrays.asList(record1, record2)); + Whitebox.invokeMethod(workerTask, "sendRecords"); + } + + private TopicAdmin.TopicCreationResponse createdTopic(String topic) { + Set created = Collections.singleton(topic); + Set existing = Collections.emptySet(); + return new TopicAdmin.TopicCreationResponse(created, existing); + } + + private TopicAdmin.TopicCreationResponse foundTopic(String topic) { + Set created = Collections.emptySet(); + Set existing = Collections.singleton(topic); + return new TopicAdmin.TopicCreationResponse(created, existing); + } + + private void expectPreliminaryCalls() { + expectPreliminaryCalls(TOPIC); + } + + private void expectPreliminaryCalls(String topic) { + expectConvertHeadersAndKeyValue(topic, true, emptyHeaders()); + expectApplyTransformationChain(false); + } + + private CountDownLatch expectEmptyPolls(int minimum, final AtomicInteger count) throws InterruptedException { + final CountDownLatch latch = new CountDownLatch(minimum); + // Note that we stub these to allow any number of calls because the thread will continue to + // run. The count passed in + latch returned just makes sure we get *at least* that number of + // calls + EasyMock.expect(sourceTask.poll()) + .andStubAnswer(() -> { + count.incrementAndGet(); + latch.countDown(); + Thread.sleep(10); + return Collections.emptyList(); + }); + return latch; + } + + private CountDownLatch expectPolls(int minimum, final AtomicInteger count) throws InterruptedException { + final CountDownLatch latch = new CountDownLatch(minimum); + // Note that we stub these to allow any number of calls because the thread will continue to + // run. 
+        // The count passed in + latch returned just makes sure we get *at least* that number of
+        // calls
+        EasyMock.expect(sourceTask.poll())
+                .andStubAnswer(() -> {
+                    count.incrementAndGet();
+                    latch.countDown();
+                    Thread.sleep(10);
+                    return RECORDS;
+                });
+        // Fallout of the poll() call
+        expectSendRecordAnyTimes();
+        return latch;
+    }
+
+    private CountDownLatch expectPolls(int count) throws InterruptedException {
+        return expectPolls(count, new AtomicInteger());
+    }
+
+    @SuppressWarnings("unchecked")
+    private void expectSendRecordSyncFailure(Throwable error) throws InterruptedException {
+        expectConvertHeadersAndKeyValue(false);
+        expectApplyTransformationChain(false);
+
+        EasyMock.expect(
+            producer.send(EasyMock.anyObject(ProducerRecord.class),
+                    EasyMock.anyObject(org.apache.kafka.clients.producer.Callback.class)))
+                .andThrow(error);
+    }
+
+    private Capture<ProducerRecord<byte[], byte[]>> expectSendRecordAnyTimes() throws InterruptedException {
+        return expectSendRecordTaskCommitRecordSucceed(true);
+    }
+
+    private Capture<ProducerRecord<byte[], byte[]>> expectSendRecordOnce() throws InterruptedException {
+        return expectSendRecordTaskCommitRecordSucceed(false);
+    }
+
+    private Capture<ProducerRecord<byte[], byte[]>> expectSendRecordProducerCallbackFail() throws InterruptedException {
+        return expectSendRecord(TOPIC, false, false, false, true, emptyHeaders());
+    }
+
+    private Capture<ProducerRecord<byte[], byte[]>> expectSendRecordTaskCommitRecordSucceed(boolean anyTimes) throws InterruptedException {
+        return expectSendRecord(TOPIC, anyTimes, true, true, true, emptyHeaders());
+    }
+
+    private Capture<ProducerRecord<byte[], byte[]>> expectSendRecordTaskCommitRecordFail(boolean anyTimes) throws InterruptedException {
+        return expectSendRecord(TOPIC, anyTimes, true, false, true, emptyHeaders());
+    }
+
+    private Capture<ProducerRecord<byte[], byte[]>> expectSendRecord(
+        String topic,
+        boolean anyTimes,
+        boolean sendSuccess,
+        boolean commitSuccess,
+        boolean isMockedConverters,
+        Headers headers
+    ) throws InterruptedException {
+        if (isMockedConverters) {
+            expectConvertHeadersAndKeyValue(topic, anyTimes, headers);
+        }
+
+        expectApplyTransformationChain(anyTimes);
+
+        Capture<ProducerRecord<byte[], byte[]>> sent = EasyMock.newCapture();
+
+        // 1. Converted data passed to the producer, which will need callbacks invoked for flush to work
+        IExpectationSetters<Future<RecordMetadata>> expect = EasyMock.expect(
+                producer.send(EasyMock.capture(sent),
+                        EasyMock.capture(producerCallbacks)));
+        IAnswer<Future<RecordMetadata>> expectResponse = () -> {
+            synchronized (producerCallbacks) {
+                for (org.apache.kafka.clients.producer.Callback cb : producerCallbacks.getValues()) {
+                    if (sendSuccess) {
+                        cb.onCompletion(new RecordMetadata(new TopicPartition("foo", 0), 0, 0,
+                                0L, 0, 0), null);
+                    } else {
+                        cb.onCompletion(null, new TopicAuthorizationException("foo"));
+                    }
+                }
+                producerCallbacks.reset();
+            }
+            return sendFuture;
+        };
+        if (anyTimes)
+            expect.andStubAnswer(expectResponse);
+        else
+            expect.andAnswer(expectResponse);
+
+        if (sendSuccess) {
+            // 2.
As a result of a successful producer send callback, we'll notify the source task of the record commit + expectTaskCommitRecordWithOffset(anyTimes, commitSuccess); + expectTaskGetTopic(anyTimes); + } + + return sent; + } + + private void expectConvertHeadersAndKeyValue(boolean anyTimes) { + expectConvertHeadersAndKeyValue(TOPIC, anyTimes, emptyHeaders()); + } + + private void expectConvertHeadersAndKeyValue(String topic, boolean anyTimes, Headers headers) { + for (Header header : headers) { + IExpectationSetters convertHeaderExpect = EasyMock.expect(headerConverter.fromConnectHeader(topic, header.key(), Schema.STRING_SCHEMA, new String(header.value()))); + if (anyTimes) + convertHeaderExpect.andStubReturn(header.value()); + else + convertHeaderExpect.andReturn(header.value()); + } + IExpectationSetters convertKeyExpect = EasyMock.expect(keyConverter.fromConnectData(topic, headers, KEY_SCHEMA, KEY)); + if (anyTimes) + convertKeyExpect.andStubReturn(SERIALIZED_KEY); + else + convertKeyExpect.andReturn(SERIALIZED_KEY); + IExpectationSetters convertValueExpect = EasyMock.expect(valueConverter.fromConnectData(topic, headers, RECORD_SCHEMA, RECORD)); + if (anyTimes) + convertValueExpect.andStubReturn(SERIALIZED_RECORD); + else + convertValueExpect.andReturn(SERIALIZED_RECORD); + } + + private void expectApplyTransformationChain(boolean anyTimes) { + final Capture recordCapture = EasyMock.newCapture(); + IExpectationSetters convertKeyExpect = EasyMock.expect(transformationChain.apply(EasyMock.capture(recordCapture))); + if (anyTimes) + convertKeyExpect.andStubAnswer(recordCapture::getValue); + else + convertKeyExpect.andAnswer(recordCapture::getValue); + } + + private void expectTaskCommitRecordWithOffset(boolean anyTimes, boolean succeed) throws InterruptedException { + sourceTask.commitRecord(EasyMock.anyObject(SourceRecord.class), EasyMock.anyObject(RecordMetadata.class)); + IExpectationSetters expect = EasyMock.expectLastCall(); + if (!succeed) { + expect = expect.andThrow(new RuntimeException("Error committing record in source task")); + } + if (anyTimes) { + expect.anyTimes(); + } + } + + private void expectTaskGetTopic(boolean anyTimes) { + final Capture connectorCapture = EasyMock.newCapture(); + final Capture topicCapture = EasyMock.newCapture(); + IExpectationSetters expect = EasyMock.expect(statusBackingStore.getTopic( + EasyMock.capture(connectorCapture), + EasyMock.capture(topicCapture))); + if (anyTimes) { + expect.andStubAnswer(() -> new TopicStatus( + topicCapture.getValue(), + new ConnectorTaskId(connectorCapture.getValue(), 0), + Time.SYSTEM.milliseconds())); + } else { + expect.andAnswer(() -> new TopicStatus( + topicCapture.getValue(), + new ConnectorTaskId(connectorCapture.getValue(), 0), + Time.SYSTEM.milliseconds())); + } + if (connectorCapture.hasCaptured() && topicCapture.hasCaptured()) { + assertEquals("job", connectorCapture.getValue()); + assertEquals(TOPIC, topicCapture.getValue()); + } + } + + private boolean awaitLatch(CountDownLatch latch) { + try { + return latch.await(5000, TimeUnit.MILLISECONDS); + } catch (InterruptedException e) { + // ignore + } + return false; + } + + @SuppressWarnings("unchecked") + private void expectOffsetFlush(boolean succeed) throws Exception { + EasyMock.expect(offsetWriter.beginFlush()).andReturn(true); + Future flushFuture = PowerMock.createMock(Future.class); + EasyMock.expect(offsetWriter.doFlush(EasyMock.anyObject(Callback.class))).andReturn(flushFuture); + // Should throw for failure + IExpectationSetters futureGetExpect = 
EasyMock.expect( + flushFuture.get(EasyMock.anyLong(), EasyMock.anyObject(TimeUnit.class))); + if (succeed) { + sourceTask.commit(); + EasyMock.expectLastCall(); + futureGetExpect.andReturn(null); + } else { + futureGetExpect.andThrow(new TimeoutException()); + offsetWriter.cancelFlush(); + PowerMock.expectLastCall(); + } + } + + private void assertPollMetrics(int minimumPollCountExpected) { + MetricGroup sourceTaskGroup = workerTask.sourceTaskMetricsGroup().metricGroup(); + MetricGroup taskGroup = workerTask.taskMetricsGroup().metricGroup(); + double pollRate = metrics.currentMetricValueAsDouble(sourceTaskGroup, "source-record-poll-rate"); + double pollTotal = metrics.currentMetricValueAsDouble(sourceTaskGroup, "source-record-poll-total"); + if (minimumPollCountExpected > 0) { + assertEquals(RECORDS.size(), metrics.currentMetricValueAsDouble(taskGroup, "batch-size-max"), 0.000001d); + assertEquals(RECORDS.size(), metrics.currentMetricValueAsDouble(taskGroup, "batch-size-avg"), 0.000001d); + assertTrue(pollRate > 0.0d); + } else { + assertTrue(pollRate == 0.0d); + } + assertTrue(pollTotal >= minimumPollCountExpected); + + double writeRate = metrics.currentMetricValueAsDouble(sourceTaskGroup, "source-record-write-rate"); + double writeTotal = metrics.currentMetricValueAsDouble(sourceTaskGroup, "source-record-write-total"); + if (minimumPollCountExpected > 0) { + assertTrue(writeRate > 0.0d); + } else { + assertTrue(writeRate == 0.0d); + } + assertTrue(writeTotal >= minimumPollCountExpected); + + double pollBatchTimeMax = metrics.currentMetricValueAsDouble(sourceTaskGroup, "poll-batch-max-time-ms"); + double pollBatchTimeAvg = metrics.currentMetricValueAsDouble(sourceTaskGroup, "poll-batch-avg-time-ms"); + if (minimumPollCountExpected > 0) { + assertTrue(pollBatchTimeMax >= 0.0d); + } + assertTrue(Double.isNaN(pollBatchTimeAvg) || pollBatchTimeAvg > 0.0d); + double activeCount = metrics.currentMetricValueAsDouble(sourceTaskGroup, "source-record-active-count"); + double activeCountMax = metrics.currentMetricValueAsDouble(sourceTaskGroup, "source-record-active-count-max"); + assertEquals(0, activeCount, 0.000001d); + if (minimumPollCountExpected > 0) { + assertEquals(RECORDS.size(), activeCountMax, 0.000001d); + } + } + + private RecordHeaders emptyHeaders() { + return new RecordHeaders(); + } + + private abstract static class TestSourceTask extends SourceTask { + } + + private void expectClose() { + producer.close(EasyMock.anyObject(Duration.class)); + EasyMock.expectLastCall(); + + admin.close(EasyMock.anyObject(Duration.class)); + EasyMock.expectLastCall(); + + transformationChain.close(); + EasyMock.expectLastCall(); + } + + private void expectTopicCreation(String topic) { + if (config.topicCreationEnable()) { + EasyMock.expect(admin.describeTopics(topic)).andReturn(Collections.emptyMap()); + Capture newTopicCapture = EasyMock.newCapture(); + EasyMock.expect(admin.createOrFindTopics(EasyMock.capture(newTopicCapture))).andReturn(createdTopic(topic)); + } + } +} diff --git a/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/WorkerTaskTest.java b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/WorkerTaskTest.java new file mode 100644 index 0000000..890c0f7 --- /dev/null +++ b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/WorkerTaskTest.java @@ -0,0 +1,350 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.runtime; + +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.connect.errors.ConnectException; +import org.apache.kafka.connect.runtime.WorkerTask.TaskMetricsGroup; +import org.apache.kafka.connect.runtime.errors.RetryWithToleranceOperator; +import org.apache.kafka.connect.runtime.errors.RetryWithToleranceOperatorTest; +import org.apache.kafka.connect.sink.SinkTask; +import org.apache.kafka.connect.storage.StatusBackingStore; +import org.apache.kafka.connect.util.ConnectorTaskId; +import org.apache.kafka.common.utils.MockTime; +import org.easymock.EasyMock; +import org.easymock.Mock; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.powermock.core.classloader.annotations.PowerMockIgnore; +import org.powermock.core.classloader.annotations.PrepareForTest; +import org.powermock.modules.junit4.PowerMockRunner; + +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.CountDownLatch; + +import static org.easymock.EasyMock.expectLastCall; +import static org.easymock.EasyMock.partialMockBuilder; +import static org.easymock.EasyMock.replay; +import static org.easymock.EasyMock.verify; +import static org.junit.Assert.assertEquals; + +@RunWith(PowerMockRunner.class) +@PrepareForTest({WorkerTask.class}) +@PowerMockIgnore("javax.management.*") +public class WorkerTaskTest { + + private static final Map TASK_PROPS = new HashMap<>(); + static { + TASK_PROPS.put(TaskConfig.TASK_CLASS_CONFIG, TestSinkTask.class.getName()); + } + private static final TaskConfig TASK_CONFIG = new TaskConfig(TASK_PROPS); + + private ConnectMetrics metrics; + @Mock private TaskStatus.Listener statusListener; + @Mock private ClassLoader loader; + RetryWithToleranceOperator retryWithToleranceOperator; + @Mock + StatusBackingStore statusBackingStore; + + @Before + public void setup() { + metrics = new MockConnectMetrics(); + retryWithToleranceOperator = RetryWithToleranceOperatorTest.NOOP_OPERATOR; + } + + @After + public void tearDown() { + if (metrics != null) metrics.stop(); + } + + @Test + public void standardStartup() { + ConnectorTaskId taskId = new ConnectorTaskId("foo", 0); + + WorkerTask workerTask = partialMockBuilder(WorkerTask.class) + .withConstructor( + ConnectorTaskId.class, + TaskStatus.Listener.class, + TargetState.class, + ClassLoader.class, + ConnectMetrics.class, + RetryWithToleranceOperator.class, + Time.class, + StatusBackingStore.class + ) + .withArgs(taskId, statusListener, TargetState.STARTED, loader, metrics, + retryWithToleranceOperator, Time.SYSTEM, statusBackingStore) + .addMockedMethod("initialize") + .addMockedMethod("initializeAndStart") + .addMockedMethod("execute") + .addMockedMethod("close") + .createStrictMock(); + + workerTask.initialize(TASK_CONFIG); + 
expectLastCall(); + + workerTask.initializeAndStart(); + expectLastCall(); + + workerTask.execute(); + expectLastCall(); + + statusListener.onStartup(taskId); + expectLastCall(); + + workerTask.close(); + expectLastCall(); + + statusListener.onShutdown(taskId); + expectLastCall(); + + replay(workerTask); + + workerTask.initialize(TASK_CONFIG); + workerTask.run(); + workerTask.stop(); + workerTask.awaitStop(1000L); + + verify(workerTask); + } + + @Test + public void stopBeforeStarting() { + ConnectorTaskId taskId = new ConnectorTaskId("foo", 0); + + WorkerTask workerTask = partialMockBuilder(WorkerTask.class) + .withConstructor( + ConnectorTaskId.class, + TaskStatus.Listener.class, + TargetState.class, + ClassLoader.class, + ConnectMetrics.class, + RetryWithToleranceOperator.class, + Time.class, + StatusBackingStore.class + ) + .withArgs(taskId, statusListener, TargetState.STARTED, loader, metrics, + retryWithToleranceOperator, Time.SYSTEM, statusBackingStore) + .addMockedMethod("initialize") + .addMockedMethod("execute") + .addMockedMethod("close") + .createStrictMock(); + + workerTask.initialize(TASK_CONFIG); + EasyMock.expectLastCall(); + + workerTask.close(); + EasyMock.expectLastCall(); + + replay(workerTask); + + workerTask.initialize(TASK_CONFIG); + workerTask.stop(); + workerTask.awaitStop(1000L); + + // now run should not do anything + workerTask.run(); + + verify(workerTask); + } + + @Test + public void cancelBeforeStopping() throws Exception { + ConnectorTaskId taskId = new ConnectorTaskId("foo", 0); + + WorkerTask workerTask = partialMockBuilder(WorkerTask.class) + .withConstructor( + ConnectorTaskId.class, + TaskStatus.Listener.class, + TargetState.class, + ClassLoader.class, + ConnectMetrics.class, + RetryWithToleranceOperator.class, + Time.class, + StatusBackingStore.class + ) + .withArgs(taskId, statusListener, TargetState.STARTED, loader, metrics, + retryWithToleranceOperator, Time.SYSTEM, statusBackingStore) + .addMockedMethod("initialize") + .addMockedMethod("initializeAndStart") + .addMockedMethod("execute") + .addMockedMethod("close") + .createStrictMock(); + + final CountDownLatch stopped = new CountDownLatch(1); + final Thread thread = new Thread(() -> { + try { + stopped.await(); + } catch (Exception e) { + } + }); + + workerTask.initialize(TASK_CONFIG); + EasyMock.expectLastCall(); + + workerTask.initializeAndStart(); + EasyMock.expectLastCall(); + + workerTask.execute(); + expectLastCall().andAnswer(() -> { + thread.start(); + return null; + }); + + statusListener.onStartup(taskId); + expectLastCall(); + + workerTask.close(); + expectLastCall(); + + // there should be no call to onShutdown() + + replay(workerTask); + + workerTask.initialize(TASK_CONFIG); + workerTask.run(); + + workerTask.stop(); + workerTask.cancel(); + stopped.countDown(); + thread.join(); + + verify(workerTask); + } + + @Test + public void updateMetricsOnListenerEventsForStartupPauseResumeAndShutdown() { + ConnectorTaskId taskId = new ConnectorTaskId("foo", 0); + ConnectMetrics metrics = new MockConnectMetrics(); + TaskMetricsGroup group = new TaskMetricsGroup(taskId, metrics, statusListener); + + statusListener.onStartup(taskId); + expectLastCall(); + + statusListener.onPause(taskId); + expectLastCall(); + + statusListener.onResume(taskId); + expectLastCall(); + + statusListener.onShutdown(taskId); + expectLastCall(); + + replay(statusListener); + + group.onStartup(taskId); + assertRunningMetric(group); + group.onPause(taskId); + assertPausedMetric(group); + group.onResume(taskId); + 
assertRunningMetric(group); + group.onShutdown(taskId); + assertStoppedMetric(group); + + verify(statusListener); + } + + @Test + public void updateMetricsOnListenerEventsForStartupPauseResumeAndFailure() { + ConnectorTaskId taskId = new ConnectorTaskId("foo", 0); + MockConnectMetrics metrics = new MockConnectMetrics(); + MockTime time = metrics.time(); + ConnectException error = new ConnectException("error"); + TaskMetricsGroup group = new TaskMetricsGroup(taskId, metrics, statusListener); + + statusListener.onStartup(taskId); + expectLastCall(); + + statusListener.onPause(taskId); + expectLastCall(); + + statusListener.onResume(taskId); + expectLastCall(); + + statusListener.onPause(taskId); + expectLastCall(); + + statusListener.onResume(taskId); + expectLastCall(); + + statusListener.onFailure(taskId, error); + expectLastCall(); + + statusListener.onShutdown(taskId); + expectLastCall(); + + replay(statusListener); + + time.sleep(1000L); + group.onStartup(taskId); + assertRunningMetric(group); + + time.sleep(2000L); + group.onPause(taskId); + assertPausedMetric(group); + + time.sleep(3000L); + group.onResume(taskId); + assertRunningMetric(group); + + time.sleep(4000L); + group.onPause(taskId); + assertPausedMetric(group); + + time.sleep(5000L); + group.onResume(taskId); + assertRunningMetric(group); + + time.sleep(6000L); + group.onFailure(taskId, error); + assertFailedMetric(group); + + time.sleep(7000L); + group.onShutdown(taskId); + assertStoppedMetric(group); + + verify(statusListener); + + long totalTime = 27000L; + double pauseTimeRatio = (double) (3000L + 5000L) / totalTime; + double runningTimeRatio = (double) (2000L + 4000L + 6000L) / totalTime; + assertEquals(pauseTimeRatio, metrics.currentMetricValueAsDouble(group.metricGroup(), "pause-ratio"), 0.000001d); + assertEquals(runningTimeRatio, metrics.currentMetricValueAsDouble(group.metricGroup(), "running-ratio"), 0.000001d); + } + + private static abstract class TestSinkTask extends SinkTask { + } + + protected void assertFailedMetric(TaskMetricsGroup metricsGroup) { + assertEquals(AbstractStatus.State.FAILED, metricsGroup.state()); + } + + protected void assertPausedMetric(TaskMetricsGroup metricsGroup) { + assertEquals(AbstractStatus.State.PAUSED, metricsGroup.state()); + } + + protected void assertRunningMetric(TaskMetricsGroup metricsGroup) { + assertEquals(AbstractStatus.State.RUNNING, metricsGroup.state()); + } + + protected void assertStoppedMetric(TaskMetricsGroup metricsGroup) { + assertEquals(AbstractStatus.State.UNASSIGNED, metricsGroup.state()); + } +} diff --git a/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/WorkerTest.java b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/WorkerTest.java new file mode 100644 index 0000000..046a9d9 --- /dev/null +++ b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/WorkerTest.java @@ -0,0 +1,1609 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.runtime; + +import java.util.Collection; +import org.apache.kafka.clients.CommonClientConfigs; +import org.apache.kafka.clients.consumer.ConsumerConfig; +import org.apache.kafka.clients.producer.KafkaProducer; +import org.apache.kafka.clients.producer.ProducerConfig; +import org.apache.kafka.common.Configurable; +import org.apache.kafka.common.MetricName; +import org.apache.kafka.common.config.AbstractConfig; +import org.apache.kafka.common.config.ConfigDef; +import org.apache.kafka.common.config.ConfigException; +import org.apache.kafka.common.config.provider.MockFileConfigProvider; +import org.apache.kafka.common.metrics.MetricsReporter; +import org.apache.kafka.common.metrics.stats.Avg; +import org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.connect.connector.ConnectorContext; +import org.apache.kafka.connect.connector.Task; +import org.apache.kafka.connect.connector.policy.AllConnectorClientConfigOverridePolicy; +import org.apache.kafka.connect.connector.policy.ConnectorClientConfigOverridePolicy; +import org.apache.kafka.connect.connector.policy.NoneConnectorClientConfigOverridePolicy; +import org.apache.kafka.connect.data.Schema; +import org.apache.kafka.connect.data.SchemaAndValue; +import org.apache.kafka.connect.errors.ConnectException; +import org.apache.kafka.connect.json.JsonConverter; +import org.apache.kafka.connect.json.JsonConverterConfig; +import org.apache.kafka.connect.runtime.ConnectMetrics.MetricGroup; +import org.apache.kafka.connect.runtime.MockConnectMetrics.MockMetricsReporter; +import org.apache.kafka.connect.runtime.distributed.ClusterConfigState; +import org.apache.kafka.connect.runtime.errors.RetryWithToleranceOperator; +import org.apache.kafka.connect.runtime.isolation.DelegatingClassLoader; +import org.apache.kafka.connect.runtime.isolation.PluginClassLoader; +import org.apache.kafka.connect.runtime.isolation.Plugins; +import org.apache.kafka.connect.runtime.isolation.Plugins.ClassLoaderUsage; +import org.apache.kafka.connect.runtime.rest.entities.ConnectorStateInfo; +import org.apache.kafka.connect.runtime.standalone.StandaloneConfig; +import org.apache.kafka.connect.sink.SinkConnector; +import org.apache.kafka.connect.sink.SinkTask; +import org.apache.kafka.connect.source.SourceConnector; +import org.apache.kafka.connect.source.SourceRecord; +import org.apache.kafka.connect.source.SourceTask; +import org.apache.kafka.connect.storage.Converter; +import org.apache.kafka.connect.storage.HeaderConverter; +import org.apache.kafka.connect.storage.OffsetBackingStore; +import org.apache.kafka.connect.storage.OffsetStorageReader; +import org.apache.kafka.connect.storage.OffsetStorageWriter; +import org.apache.kafka.connect.storage.StatusBackingStore; +import org.apache.kafka.connect.util.ConnectUtils; +import org.apache.kafka.connect.util.ConnectorTaskId; +import org.apache.kafka.connect.util.FutureCallback; +import org.apache.kafka.connect.util.ParameterizedTest; +import org.apache.kafka.connect.util.ThreadedTest; +import 
org.apache.kafka.connect.util.TopicAdmin; +import org.apache.kafka.connect.util.TopicCreationGroup; +import org.easymock.EasyMock; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.powermock.api.easymock.PowerMock; +import org.powermock.api.easymock.annotation.Mock; +import org.powermock.api.easymock.annotation.MockNice; +import org.powermock.api.easymock.annotation.MockStrict; +import org.powermock.core.classloader.annotations.PowerMockIgnore; +import org.powermock.core.classloader.annotations.PrepareForTest; +import org.powermock.modules.junit4.PowerMockRunner; + +import javax.management.MBeanServer; +import javax.management.ObjectInstance; +import javax.management.ObjectName; +import java.lang.management.ManagementFactory; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.UUID; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Executor; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.TimeUnit; +import org.powermock.modules.junit4.PowerMockRunnerDelegate; + +import static org.apache.kafka.connect.runtime.TopicCreationConfig.DEFAULT_TOPIC_CREATION_PREFIX; +import static org.apache.kafka.connect.runtime.TopicCreationConfig.PARTITIONS_CONFIG; +import static org.apache.kafka.connect.runtime.TopicCreationConfig.REPLICATION_FACTOR_CONFIG; +import static org.apache.kafka.connect.runtime.WorkerConfig.TOPIC_CREATION_ENABLE_CONFIG; +import static org.apache.kafka.connect.runtime.errors.RetryWithToleranceOperatorTest.NOOP_OPERATOR; +import static org.easymock.EasyMock.anyObject; +import static org.easymock.EasyMock.eq; +import static org.easymock.EasyMock.expectLastCall; +import static org.hamcrest.CoreMatchers.instanceOf; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.fail; + +@RunWith(PowerMockRunner.class) +@PowerMockRunnerDelegate(ParameterizedTest.class) +@PrepareForTest({Worker.class, Plugins.class, ConnectUtils.class}) +@PowerMockIgnore("javax.management.*") +public class WorkerTest extends ThreadedTest { + + private static final String CONNECTOR_ID = "test-connector"; + private static final ConnectorTaskId TASK_ID = new ConnectorTaskId("job", 0); + private static final String WORKER_ID = "localhost:8083"; + private static final String CLUSTER_ID = "test-cluster"; + private final ConnectorClientConfigOverridePolicy noneConnectorClientConfigOverridePolicy = new NoneConnectorClientConfigOverridePolicy(); + private final ConnectorClientConfigOverridePolicy allConnectorClientConfigOverridePolicy = new AllConnectorClientConfigOverridePolicy(); + + private Map workerProps = new HashMap<>(); + private WorkerConfig config; + private Worker worker; + + private Map defaultProducerConfigs = new HashMap<>(); + private Map defaultConsumerConfigs = new HashMap<>(); + + @Mock + private Plugins plugins; + @Mock + private PluginClassLoader pluginLoader; + @Mock + private DelegatingClassLoader delegatingLoader; + @Mock + private OffsetBackingStore offsetBackingStore; + @MockStrict + private TaskStatus.Listener 
taskStatusListener; + @MockStrict + private ConnectorStatus.Listener connectorStatusListener; + + @Mock private Herder herder; + @Mock private StatusBackingStore statusBackingStore; + @Mock private SourceConnector sourceConnector; + @Mock private SinkConnector sinkConnector; + @Mock private CloseableConnectorContext ctx; + @Mock private TestSourceTask task; + @Mock private WorkerSourceTask workerTask; + @Mock private Converter keyConverter; + @Mock private Converter valueConverter; + @Mock private Converter taskKeyConverter; + @Mock private Converter taskValueConverter; + @Mock private HeaderConverter taskHeaderConverter; + @Mock private ExecutorService executorService; + @MockNice private ConnectorConfig connectorConfig; + private String mockFileProviderTestId; + private Map connectorProps; + + private boolean enableTopicCreation; + + @ParameterizedTest.Parameters + public static Collection parameters() { + return Arrays.asList(false, true); + } + + public WorkerTest(boolean enableTopicCreation) { + this.enableTopicCreation = enableTopicCreation; + } + + @Before + public void setup() { + super.setup(); + workerProps.put("key.converter", "org.apache.kafka.connect.json.JsonConverter"); + workerProps.put("value.converter", "org.apache.kafka.connect.json.JsonConverter"); + workerProps.put("offset.storage.file.filename", "/tmp/connect.offsets"); + workerProps.put(CommonClientConfigs.METRIC_REPORTER_CLASSES_CONFIG, MockMetricsReporter.class.getName()); + workerProps.put("config.providers", "file"); + workerProps.put("config.providers.file.class", MockFileConfigProvider.class.getName()); + mockFileProviderTestId = UUID.randomUUID().toString(); + workerProps.put("config.providers.file.param.testId", mockFileProviderTestId); + workerProps.put(TOPIC_CREATION_ENABLE_CONFIG, String.valueOf(enableTopicCreation)); + config = new StandaloneConfig(workerProps); + + defaultProducerConfigs.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9092"); + defaultProducerConfigs.put( + ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG, "org.apache.kafka.common.serialization.ByteArraySerializer"); + defaultProducerConfigs.put( + ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG, "org.apache.kafka.common.serialization.ByteArraySerializer"); + defaultProducerConfigs.put(ProducerConfig.MAX_BLOCK_MS_CONFIG, Long.toString(Long.MAX_VALUE)); + defaultProducerConfigs.put(ProducerConfig.ACKS_CONFIG, "all"); + defaultProducerConfigs.put(ProducerConfig.MAX_IN_FLIGHT_REQUESTS_PER_CONNECTION, "1"); + defaultProducerConfigs.put(ProducerConfig.DELIVERY_TIMEOUT_MS_CONFIG, Integer.toString(Integer.MAX_VALUE)); + + defaultConsumerConfigs.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9092"); + defaultConsumerConfigs.put(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG, "false"); + defaultConsumerConfigs.put(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "earliest"); + defaultConsumerConfigs + .put(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, "org.apache.kafka.common.serialization.ByteArrayDeserializer"); + defaultConsumerConfigs + .put(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, "org.apache.kafka.common.serialization.ByteArrayDeserializer"); + + // Some common defaults. 
They might change on individual tests + connectorProps = anyConnectorConfigMap(); + PowerMock.mockStatic(Plugins.class); + } + + @Test + public void testStartAndStopConnector() throws Throwable { + expectConverters(); + expectStartStorage(); + + final String connectorClass = WorkerTestConnector.class.getName(); + + // Create + EasyMock.expect(plugins.currentThreadLoader()).andReturn(delegatingLoader).times(2); + EasyMock.expect(plugins.delegatingLoader()).andReturn(delegatingLoader); + EasyMock.expect(delegatingLoader.connectorLoader(connectorClass)).andReturn(pluginLoader); + EasyMock.expect(plugins.newConnector(connectorClass)) + .andReturn(sourceConnector); + EasyMock.expect(sourceConnector.version()).andReturn("1.0"); + + connectorProps.put(ConnectorConfig.CONNECTOR_CLASS_CONFIG, connectorClass); + + EasyMock.expect(sourceConnector.version()).andReturn("1.0"); + + expectFileConfigProvider(); + EasyMock.expect(Plugins.compareAndSwapLoaders(pluginLoader)) + .andReturn(delegatingLoader) + .times(3); + sourceConnector.initialize(anyObject(ConnectorContext.class)); + EasyMock.expectLastCall(); + sourceConnector.start(connectorProps); + EasyMock.expectLastCall(); + + EasyMock.expect(Plugins.compareAndSwapLoaders(delegatingLoader)) + .andReturn(pluginLoader).times(3); + + connectorStatusListener.onStartup(CONNECTOR_ID); + EasyMock.expectLastCall(); + + // Remove + sourceConnector.stop(); + EasyMock.expectLastCall(); + + connectorStatusListener.onShutdown(CONNECTOR_ID); + EasyMock.expectLastCall(); + + ctx.close(); + expectLastCall(); + + expectStopStorage(); + expectClusterId(); + + PowerMock.replayAll(); + + worker = new Worker(WORKER_ID, new MockTime(), plugins, config, offsetBackingStore, noneConnectorClientConfigOverridePolicy); + worker.herder = herder; + worker.start(); + + assertEquals(Collections.emptySet(), worker.connectorNames()); + + FutureCallback onFirstStart = new FutureCallback<>(); + worker.startConnector(CONNECTOR_ID, connectorProps, ctx, connectorStatusListener, TargetState.STARTED, onFirstStart); + // Wait for the connector to actually start + assertEquals(TargetState.STARTED, onFirstStart.get(1000, TimeUnit.MILLISECONDS)); + assertEquals(new HashSet<>(Arrays.asList(CONNECTOR_ID)), worker.connectorNames()); + + FutureCallback onSecondStart = new FutureCallback<>(); + worker.startConnector(CONNECTOR_ID, connectorProps, ctx, connectorStatusListener, TargetState.STARTED, onSecondStart); + try { + onSecondStart.get(0, TimeUnit.MILLISECONDS); + fail("Should have failed while trying to start second connector with same name"); + } catch (ExecutionException e) { + assertThat(e.getCause(), instanceOf(ConnectException.class)); + } + + assertStatistics(worker, 1, 0); + assertStartupStatistics(worker, 1, 0, 0, 0); + worker.stopAndAwaitConnector(CONNECTOR_ID); + assertStatistics(worker, 0, 0); + assertStartupStatistics(worker, 1, 0, 0, 0); + assertEquals(Collections.emptySet(), worker.connectorNames()); + // Nothing should be left, so this should effectively be a nop + worker.stop(); + assertStatistics(worker, 0, 0); + + PowerMock.verifyAll(); + MockFileConfigProvider.assertClosed(mockFileProviderTestId); + } + + private void expectFileConfigProvider() { + EasyMock.expect(plugins.newConfigProvider(EasyMock.anyObject(), + EasyMock.eq("config.providers.file"), EasyMock.anyObject())) + .andAnswer(() -> { + MockFileConfigProvider mockFileConfigProvider = new MockFileConfigProvider(); + mockFileConfigProvider.configure(Collections.singletonMap("testId", mockFileProviderTestId)); + return 
mockFileConfigProvider; + }); + } + + @Test + public void testStartConnectorFailure() throws Exception { + expectConverters(); + expectStartStorage(); + expectFileConfigProvider(); + + final String nonConnectorClass = "java.util.HashMap"; + connectorProps.put(ConnectorConfig.CONNECTOR_CLASS_CONFIG, nonConnectorClass); // Bad connector class name + + Exception exception = new ConnectException("Failed to find Connector"); + EasyMock.expect(plugins.currentThreadLoader()).andReturn(delegatingLoader); + EasyMock.expect(plugins.delegatingLoader()).andReturn(delegatingLoader); + EasyMock.expect(delegatingLoader.connectorLoader(nonConnectorClass)).andReturn(delegatingLoader); + EasyMock.expect(Plugins.compareAndSwapLoaders(delegatingLoader)) + .andReturn(delegatingLoader).times(2); + EasyMock.expect(plugins.newConnector(EasyMock.anyString())) + .andThrow(exception); + + connectorStatusListener.onFailure( + EasyMock.eq(CONNECTOR_ID), + EasyMock.anyObject() + ); + EasyMock.expectLastCall(); + + expectClusterId(); + + PowerMock.replayAll(); + + worker = new Worker(WORKER_ID, new MockTime(), plugins, config, offsetBackingStore, noneConnectorClientConfigOverridePolicy); + worker.herder = herder; + worker.start(); + + assertStatistics(worker, 0, 0); + FutureCallback onStart = new FutureCallback<>(); + worker.startConnector(CONNECTOR_ID, connectorProps, ctx, connectorStatusListener, TargetState.STARTED, onStart); + try { + onStart.get(0, TimeUnit.MILLISECONDS); + fail("Should have failed to start connector"); + } catch (ExecutionException e) { + assertEquals(exception, e.getCause()); + } + + assertStartupStatistics(worker, 1, 1, 0, 0); + assertEquals(Collections.emptySet(), worker.connectorNames()); + + assertStatistics(worker, 0, 0); + assertStartupStatistics(worker, 1, 1, 0, 0); + worker.stopAndAwaitConnector(CONNECTOR_ID); + assertStatistics(worker, 0, 0); + assertStartupStatistics(worker, 1, 1, 0, 0); + + PowerMock.verifyAll(); + } + + @Test + public void testAddConnectorByAlias() throws Throwable { + expectConverters(); + expectStartStorage(); + expectFileConfigProvider(); + + final String connectorAlias = "WorkerTestConnector"; + + EasyMock.expect(plugins.currentThreadLoader()).andReturn(delegatingLoader).times(2); + EasyMock.expect(plugins.delegatingLoader()).andReturn(delegatingLoader); + EasyMock.expect(delegatingLoader.connectorLoader(connectorAlias)).andReturn(pluginLoader); + EasyMock.expect(plugins.newConnector(connectorAlias)).andReturn(sinkConnector); + EasyMock.expect(sinkConnector.version()).andReturn("1.0"); + + connectorProps.put(ConnectorConfig.CONNECTOR_CLASS_CONFIG, connectorAlias); + connectorProps.put(SinkConnectorConfig.TOPICS_CONFIG, "gfieyls, wfru"); + + EasyMock.expect(sinkConnector.version()).andReturn("1.0"); + EasyMock.expect(Plugins.compareAndSwapLoaders(pluginLoader)) + .andReturn(delegatingLoader) + .times(3); + sinkConnector.initialize(anyObject(ConnectorContext.class)); + EasyMock.expectLastCall(); + sinkConnector.start(connectorProps); + EasyMock.expectLastCall(); + + EasyMock.expect(Plugins.compareAndSwapLoaders(delegatingLoader)) + .andReturn(pluginLoader) + .times(3); + + connectorStatusListener.onStartup(CONNECTOR_ID); + EasyMock.expectLastCall(); + + // Remove + sinkConnector.stop(); + EasyMock.expectLastCall(); + + connectorStatusListener.onShutdown(CONNECTOR_ID); + EasyMock.expectLastCall(); + + ctx.close(); + expectLastCall(); + + expectStopStorage(); + expectClusterId(); + + PowerMock.replayAll(); + + worker = new Worker(WORKER_ID, new MockTime(), plugins, 
config, offsetBackingStore, noneConnectorClientConfigOverridePolicy); + worker.herder = herder; + worker.start(); + + assertStatistics(worker, 0, 0); + assertEquals(Collections.emptySet(), worker.connectorNames()); + FutureCallback onStart = new FutureCallback<>(); + worker.startConnector(CONNECTOR_ID, connectorProps, ctx, connectorStatusListener, TargetState.STARTED, onStart); + // Wait for the connector to actually start + assertEquals(TargetState.STARTED, onStart.get(1000, TimeUnit.MILLISECONDS)); + assertEquals(new HashSet<>(Arrays.asList(CONNECTOR_ID)), worker.connectorNames()); + assertStatistics(worker, 1, 0); + assertStartupStatistics(worker, 1, 0, 0, 0); + + worker.stopAndAwaitConnector(CONNECTOR_ID); + assertStatistics(worker, 0, 0); + assertStartupStatistics(worker, 1, 0, 0, 0); + assertEquals(Collections.emptySet(), worker.connectorNames()); + // Nothing should be left, so this should effectively be a nop + worker.stop(); + assertStatistics(worker, 0, 0); + assertStartupStatistics(worker, 1, 0, 0, 0); + + PowerMock.verifyAll(); + } + + @Test + public void testAddConnectorByShortAlias() throws Throwable { + expectConverters(); + expectStartStorage(); + expectFileConfigProvider(); + + final String shortConnectorAlias = "WorkerTest"; + + EasyMock.expect(plugins.currentThreadLoader()).andReturn(delegatingLoader).times(2); + EasyMock.expect(plugins.delegatingLoader()).andReturn(delegatingLoader); + EasyMock.expect(delegatingLoader.connectorLoader(shortConnectorAlias)).andReturn(pluginLoader); + EasyMock.expect(plugins.newConnector(shortConnectorAlias)).andReturn(sinkConnector); + EasyMock.expect(sinkConnector.version()).andReturn("1.0"); + + connectorProps.put(ConnectorConfig.CONNECTOR_CLASS_CONFIG, shortConnectorAlias); + connectorProps.put(SinkConnectorConfig.TOPICS_CONFIG, "gfieyls, wfru"); + + EasyMock.expect(sinkConnector.version()).andReturn("1.0"); + EasyMock.expect(Plugins.compareAndSwapLoaders(pluginLoader)) + .andReturn(delegatingLoader) + .times(3); + sinkConnector.initialize(anyObject(ConnectorContext.class)); + EasyMock.expectLastCall(); + sinkConnector.start(connectorProps); + EasyMock.expectLastCall(); + + EasyMock.expect(Plugins.compareAndSwapLoaders(delegatingLoader)) + .andReturn(pluginLoader) + .times(3); + + connectorStatusListener.onStartup(CONNECTOR_ID); + EasyMock.expectLastCall(); + + // Remove + sinkConnector.stop(); + EasyMock.expectLastCall(); + + connectorStatusListener.onShutdown(CONNECTOR_ID); + EasyMock.expectLastCall(); + + ctx.close(); + expectLastCall(); + + expectStopStorage(); + expectClusterId(); + + PowerMock.replayAll(); + + worker = new Worker(WORKER_ID, new MockTime(), plugins, config, offsetBackingStore, noneConnectorClientConfigOverridePolicy); + worker.herder = herder; + worker.start(); + + assertStatistics(worker, 0, 0); + assertEquals(Collections.emptySet(), worker.connectorNames()); + FutureCallback onStart = new FutureCallback<>(); + worker.startConnector(CONNECTOR_ID, connectorProps, ctx, connectorStatusListener, TargetState.STARTED, onStart); + // Wait for the connector to actually start + assertEquals(TargetState.STARTED, onStart.get(1000, TimeUnit.MILLISECONDS)); + assertEquals(new HashSet<>(Arrays.asList(CONNECTOR_ID)), worker.connectorNames()); + assertStatistics(worker, 1, 0); + + worker.stopAndAwaitConnector(CONNECTOR_ID); + assertStatistics(worker, 0, 0); + assertEquals(Collections.emptySet(), worker.connectorNames()); + // Nothing should be left, so this should effectively be a nop + worker.stop(); + assertStatistics(worker, 
0, 0); + + PowerMock.verifyAll(); + } + + @Test + public void testStopInvalidConnector() { + expectConverters(); + expectStartStorage(); + expectFileConfigProvider(); + expectClusterId(); + + PowerMock.replayAll(); + + worker = new Worker(WORKER_ID, new MockTime(), plugins, config, offsetBackingStore, noneConnectorClientConfigOverridePolicy); + worker.herder = herder; + worker.start(); + + worker.stopAndAwaitConnector(CONNECTOR_ID); + + PowerMock.verifyAll(); + } + + @Test + public void testReconfigureConnectorTasks() throws Throwable { + expectConverters(); + expectStartStorage(); + expectFileConfigProvider(); + + final String connectorClass = WorkerTestConnector.class.getName(); + + EasyMock.expect(plugins.currentThreadLoader()).andReturn(delegatingLoader).times(3); + EasyMock.expect(plugins.delegatingLoader()).andReturn(delegatingLoader).times(1); + EasyMock.expect(delegatingLoader.connectorLoader(connectorClass)).andReturn(pluginLoader); + EasyMock.expect(plugins.newConnector(connectorClass)) + .andReturn(sinkConnector); + EasyMock.expect(sinkConnector.version()).andReturn("1.0"); + + connectorProps.put(SinkConnectorConfig.TOPICS_CONFIG, "foo,bar"); + connectorProps.put(ConnectorConfig.CONNECTOR_CLASS_CONFIG, connectorClass); + + EasyMock.expect(sinkConnector.version()).andReturn("1.0"); + EasyMock.expect(Plugins.compareAndSwapLoaders(pluginLoader)) + .andReturn(delegatingLoader) + .times(4); + sinkConnector.initialize(anyObject(ConnectorContext.class)); + EasyMock.expectLastCall(); + sinkConnector.start(connectorProps); + EasyMock.expectLastCall(); + + EasyMock.expect(Plugins.compareAndSwapLoaders(delegatingLoader)) + .andReturn(pluginLoader) + .times(4); + + connectorStatusListener.onStartup(CONNECTOR_ID); + EasyMock.expectLastCall(); + + // Reconfigure + EasyMock.>expect(sinkConnector.taskClass()).andReturn(TestSourceTask.class); + Map taskProps = new HashMap<>(); + taskProps.put("foo", "bar"); + EasyMock.expect(sinkConnector.taskConfigs(2)).andReturn(Arrays.asList(taskProps, taskProps)); + + // Remove + sinkConnector.stop(); + EasyMock.expectLastCall(); + + connectorStatusListener.onShutdown(CONNECTOR_ID); + EasyMock.expectLastCall(); + + ctx.close(); + expectLastCall(); + + expectStopStorage(); + expectClusterId(); + + PowerMock.replayAll(); + + worker = new Worker(WORKER_ID, new MockTime(), plugins, config, offsetBackingStore, noneConnectorClientConfigOverridePolicy); + worker.herder = herder; + worker.start(); + + assertStatistics(worker, 0, 0); + assertEquals(Collections.emptySet(), worker.connectorNames()); + FutureCallback onFirstStart = new FutureCallback<>(); + worker.startConnector(CONNECTOR_ID, connectorProps, ctx, connectorStatusListener, TargetState.STARTED, onFirstStart); + // Wait for the connector to actually start + assertEquals(TargetState.STARTED, onFirstStart.get(1000, TimeUnit.MILLISECONDS)); + assertStatistics(worker, 1, 0); + assertEquals(new HashSet<>(Arrays.asList(CONNECTOR_ID)), worker.connectorNames()); + + FutureCallback onSecondStart = new FutureCallback<>(); + worker.startConnector(CONNECTOR_ID, connectorProps, ctx, connectorStatusListener, TargetState.STARTED, onSecondStart); + try { + onSecondStart.get(0, TimeUnit.MILLISECONDS); + fail("Should have failed while trying to start second connector with same name"); + } catch (ExecutionException e) { + assertThat(e.getCause(), instanceOf(ConnectException.class)); + } + + Map connProps = new HashMap<>(connectorProps); + connProps.put(ConnectorConfig.TASKS_MAX_CONFIG, "2"); + ConnectorConfig connConfig = new 
SinkConnectorConfig(plugins, connProps); + List> taskConfigs = worker.connectorTaskConfigs(CONNECTOR_ID, connConfig); + Map expectedTaskProps = new HashMap<>(); + expectedTaskProps.put("foo", "bar"); + expectedTaskProps.put(TaskConfig.TASK_CLASS_CONFIG, TestSourceTask.class.getName()); + expectedTaskProps.put(SinkTask.TOPICS_CONFIG, "foo,bar"); + assertEquals(2, taskConfigs.size()); + assertEquals(expectedTaskProps, taskConfigs.get(0)); + assertEquals(expectedTaskProps, taskConfigs.get(1)); + assertStatistics(worker, 1, 0); + assertStartupStatistics(worker, 1, 0, 0, 0); + worker.stopAndAwaitConnector(CONNECTOR_ID); + assertStatistics(worker, 0, 0); + assertStartupStatistics(worker, 1, 0, 0, 0); + assertEquals(Collections.emptySet(), worker.connectorNames()); + // Nothing should be left, so this should effectively be a nop + worker.stop(); + assertStatistics(worker, 0, 0); + + PowerMock.verifyAll(); + } + + @Test + public void testAddRemoveTask() throws Exception { + expectConverters(); + expectStartStorage(); + expectFileConfigProvider(); + + EasyMock.expect(workerTask.id()).andStubReturn(TASK_ID); + + EasyMock.expect(plugins.currentThreadLoader()).andReturn(delegatingLoader).times(2); + expectNewWorkerTask(); + Map origProps = new HashMap<>(); + origProps.put(TaskConfig.TASK_CLASS_CONFIG, TestSourceTask.class.getName()); + + TaskConfig taskConfig = new TaskConfig(origProps); + // We should expect this call, but the pluginLoader being swapped in is only mocked. + // EasyMock.expect(pluginLoader.loadClass(TestSourceTask.class.getName())) + // .andReturn((Class) TestSourceTask.class); + EasyMock.expect(plugins.newTask(TestSourceTask.class)).andReturn(task); + EasyMock.expect(task.version()).andReturn("1.0"); + + workerTask.initialize(taskConfig); + EasyMock.expectLastCall(); + + // Expect that the worker will create converters and will find them using the current classloader ... 
+ assertNotNull(taskKeyConverter); + assertNotNull(taskValueConverter); + assertNotNull(taskHeaderConverter); + expectTaskKeyConverters(ClassLoaderUsage.CURRENT_CLASSLOADER, taskKeyConverter); + expectTaskValueConverters(ClassLoaderUsage.CURRENT_CLASSLOADER, taskValueConverter); + expectTaskHeaderConverter(ClassLoaderUsage.CURRENT_CLASSLOADER, taskHeaderConverter); + + EasyMock.expect(executorService.submit(workerTask)).andReturn(null); + + EasyMock.expect(plugins.delegatingLoader()).andReturn(delegatingLoader); + EasyMock.expect(delegatingLoader.connectorLoader(WorkerTestConnector.class.getName())) + .andReturn(pluginLoader); + EasyMock.expect(Plugins.compareAndSwapLoaders(pluginLoader)).andReturn(delegatingLoader) + .times(2); + + EasyMock.expect(workerTask.loader()).andReturn(pluginLoader); + + EasyMock.expect(Plugins.compareAndSwapLoaders(delegatingLoader)).andReturn(pluginLoader) + .times(2); + plugins.connectorClass(WorkerTestConnector.class.getName()); + EasyMock.expectLastCall().andReturn(WorkerTestConnector.class); + // Remove + workerTask.stop(); + EasyMock.expectLastCall(); + EasyMock.expect(workerTask.awaitStop(EasyMock.anyLong())).andStubReturn(true); + EasyMock.expectLastCall(); + + workerTask.removeMetrics(); + EasyMock.expectLastCall(); + + expectStopStorage(); + expectClusterId(); + + PowerMock.replayAll(); + + worker = new Worker(WORKER_ID, new MockTime(), plugins, config, offsetBackingStore, executorService, + noneConnectorClientConfigOverridePolicy); + worker.herder = herder; + worker.start(); + assertStatistics(worker, 0, 0); + assertEquals(Collections.emptySet(), worker.taskIds()); + worker.startTask(TASK_ID, ClusterConfigState.EMPTY, anyConnectorConfigMap(), origProps, taskStatusListener, TargetState.STARTED); + assertStatistics(worker, 0, 1); + assertEquals(new HashSet<>(Arrays.asList(TASK_ID)), worker.taskIds()); + worker.stopAndAwaitTask(TASK_ID); + assertStatistics(worker, 0, 0); + assertEquals(Collections.emptySet(), worker.taskIds()); + // Nothing should be left, so this should effectively be a nop + worker.stop(); + assertStatistics(worker, 0, 0); + + PowerMock.verifyAll(); + } + + @Test + public void testTaskStatusMetricsStatuses() throws Exception { + expectConverters(); + expectStartStorage(); + expectFileConfigProvider(); + + EasyMock.expect(workerTask.id()).andStubReturn(TASK_ID); + + EasyMock.expect(plugins.currentThreadLoader()).andReturn(delegatingLoader).times(2); + expectNewWorkerTask(); + Map origProps = new HashMap<>(); + origProps.put(TaskConfig.TASK_CLASS_CONFIG, TestSourceTask.class.getName()); + + TaskConfig taskConfig = new TaskConfig(origProps); + // We should expect this call, but the pluginLoader being swapped in is only mocked. + // EasyMock.expect(pluginLoader.loadClass(TestSourceTask.class.getName())) + // .andReturn((Class) TestSourceTask.class); + EasyMock.expect(plugins.newTask(TestSourceTask.class)).andReturn(task); + EasyMock.expect(task.version()).andReturn("1.0"); + + workerTask.initialize(taskConfig); + EasyMock.expectLastCall(); + + // Expect that the worker will create converters and will find them using the current classloader ... 
+ assertNotNull(taskKeyConverter); + assertNotNull(taskValueConverter); + assertNotNull(taskHeaderConverter); + expectTaskKeyConverters(ClassLoaderUsage.CURRENT_CLASSLOADER, taskKeyConverter); + expectTaskValueConverters(ClassLoaderUsage.CURRENT_CLASSLOADER, taskValueConverter); + expectTaskHeaderConverter(ClassLoaderUsage.CURRENT_CLASSLOADER, taskHeaderConverter); + + EasyMock.expect(executorService.submit(workerTask)).andReturn(null); + + EasyMock.expect(plugins.delegatingLoader()).andReturn(delegatingLoader); + EasyMock.expect(delegatingLoader.connectorLoader(WorkerTestConnector.class.getName())) + .andReturn(pluginLoader); + EasyMock.expect(Plugins.compareAndSwapLoaders(pluginLoader)).andReturn(delegatingLoader) + .times(2); + + EasyMock.expect(Plugins.compareAndSwapLoaders(delegatingLoader)).andReturn(pluginLoader) + .times(2); + plugins.connectorClass(WorkerTestConnector.class.getName()); + EasyMock.expectLastCall().andReturn(WorkerTestConnector.class); + + EasyMock.expect(workerTask.awaitStop(EasyMock.anyLong())).andStubReturn(true); + EasyMock.expectLastCall(); + + workerTask.removeMetrics(); + EasyMock.expectLastCall(); + + // Each time we check the task metrics, the worker will call the herder + herder.taskStatus(TASK_ID); + EasyMock.expectLastCall() + .andReturn(new ConnectorStateInfo.TaskState(0, "RUNNING", "worker", "msg")); + + herder.taskStatus(TASK_ID); + EasyMock.expectLastCall() + .andReturn(new ConnectorStateInfo.TaskState(0, "PAUSED", "worker", "msg")); + + herder.taskStatus(TASK_ID); + EasyMock.expectLastCall() + .andReturn(new ConnectorStateInfo.TaskState(0, "FAILED", "worker", "msg")); + + herder.taskStatus(TASK_ID); + EasyMock.expectLastCall() + .andReturn(new ConnectorStateInfo.TaskState(0, "DESTROYED", "worker", "msg")); + + herder.taskStatus(TASK_ID); + EasyMock.expectLastCall() + .andReturn(new ConnectorStateInfo.TaskState(0, "UNASSIGNED", "worker", "msg")); + + // Called when we stop the worker + EasyMock.expect(workerTask.loader()).andReturn(pluginLoader); + workerTask.stop(); + EasyMock.expectLastCall(); + + expectClusterId(); + + PowerMock.replayAll(); + + worker = new Worker(WORKER_ID, + new MockTime(), + plugins, + config, + offsetBackingStore, + executorService, + noneConnectorClientConfigOverridePolicy); + + worker.herder = herder; + + worker.start(); + assertStatistics(worker, 0, 0); + assertStartupStatistics(worker, 0, 0, 0, 0); + assertEquals(Collections.emptySet(), worker.taskIds()); + worker.startTask( + TASK_ID, + ClusterConfigState.EMPTY, + anyConnectorConfigMap(), + origProps, + taskStatusListener, + TargetState.STARTED); + + assertStatusMetrics(1L, "connector-running-task-count"); + assertStatusMetrics(1L, "connector-paused-task-count"); + assertStatusMetrics(1L, "connector-failed-task-count"); + assertStatusMetrics(1L, "connector-destroyed-task-count"); + assertStatusMetrics(1L, "connector-unassigned-task-count"); + + worker.stopAndAwaitTask(TASK_ID); + assertStatusMetrics(0L, "connector-running-task-count"); + assertStatusMetrics(0L, "connector-paused-task-count"); + assertStatusMetrics(0L, "connector-failed-task-count"); + assertStatusMetrics(0L, "connector-destroyed-task-count"); + assertStatusMetrics(0L, "connector-unassigned-task-count"); + + PowerMock.verifyAll(); + } + + @Test + public void testConnectorStatusMetricsGroup_taskStatusCounter() { + ConcurrentMap tasks = new ConcurrentHashMap<>(); + tasks.put(new ConnectorTaskId("c1", 0), workerTask); + tasks.put(new ConnectorTaskId("c1", 1), workerTask); + tasks.put(new 
ConnectorTaskId("c2", 0), workerTask); + + expectConverters(); + expectStartStorage(); + expectFileConfigProvider(); + + EasyMock.expect(Plugins.compareAndSwapLoaders(pluginLoader)).andReturn(delegatingLoader); + EasyMock.expect(Plugins.compareAndSwapLoaders(pluginLoader)).andReturn(delegatingLoader); + + EasyMock.expect(Plugins.compareAndSwapLoaders(delegatingLoader)).andReturn(pluginLoader); + + taskStatusListener.onFailure(EasyMock.eq(TASK_ID), EasyMock.anyObject()); + EasyMock.expectLastCall(); + + expectClusterId(); + + PowerMock.replayAll(); + + worker = new Worker(WORKER_ID, + new MockTime(), + plugins, + config, + offsetBackingStore, + noneConnectorClientConfigOverridePolicy); + worker.herder = herder; + + Worker.ConnectorStatusMetricsGroup metricGroup = new Worker.ConnectorStatusMetricsGroup( + worker.metrics(), tasks, herder + ); + assertEquals(2L, (long) metricGroup.taskCounter("c1").metricValue(0L)); + assertEquals(1L, (long) metricGroup.taskCounter("c2").metricValue(0L)); + assertEquals(0L, (long) metricGroup.taskCounter("fakeConnector").metricValue(0L)); + } + + @Test + public void testStartTaskFailure() { + expectConverters(); + expectStartStorage(); + expectFileConfigProvider(); + + Map origProps = new HashMap<>(); + origProps.put(TaskConfig.TASK_CLASS_CONFIG, "missing.From.This.Workers.Classpath"); + + EasyMock.expect(plugins.currentThreadLoader()).andReturn(delegatingLoader); + EasyMock.expect(plugins.delegatingLoader()).andReturn(delegatingLoader); + EasyMock.expect(delegatingLoader.connectorLoader(WorkerTestConnector.class.getName())) + .andReturn(pluginLoader); + + // We would normally expect this since the plugin loader would have been swapped in. However, since we mock out + // all classloader changes, the call actually goes to the normal default classloader. However, this works out + // fine since we just wanted a ClassNotFoundException anyway. 
+ // EasyMock.expect(pluginLoader.loadClass(origProps.get(TaskConfig.TASK_CLASS_CONFIG))) + // .andThrow(new ClassNotFoundException()); + + EasyMock.expect(Plugins.compareAndSwapLoaders(pluginLoader)) + .andReturn(delegatingLoader); + + EasyMock.expect(Plugins.compareAndSwapLoaders(delegatingLoader)) + .andReturn(pluginLoader); + + taskStatusListener.onFailure(EasyMock.eq(TASK_ID), EasyMock.anyObject()); + EasyMock.expectLastCall(); + + expectClusterId(); + + PowerMock.replayAll(); + + worker = new Worker(WORKER_ID, new MockTime(), plugins, config, offsetBackingStore, noneConnectorClientConfigOverridePolicy); + worker.herder = herder; + worker.start(); + assertStatistics(worker, 0, 0); + assertStartupStatistics(worker, 0, 0, 0, 0); + + assertFalse(worker.startTask(TASK_ID, ClusterConfigState.EMPTY, anyConnectorConfigMap(), origProps, taskStatusListener, TargetState.STARTED)); + assertStartupStatistics(worker, 0, 0, 1, 1); + + assertStatistics(worker, 0, 0); + assertStartupStatistics(worker, 0, 0, 1, 1); + assertEquals(Collections.emptySet(), worker.taskIds()); + + PowerMock.verifyAll(); + } + + @Test + public void testCleanupTasksOnStop() throws Exception { + expectConverters(); + expectStartStorage(); + expectFileConfigProvider(); + + EasyMock.expect(workerTask.id()).andStubReturn(TASK_ID); + + EasyMock.expect(plugins.currentThreadLoader()).andReturn(delegatingLoader).times(2); + expectNewWorkerTask(); + Map origProps = new HashMap<>(); + origProps.put(TaskConfig.TASK_CLASS_CONFIG, TestSourceTask.class.getName()); + + TaskConfig taskConfig = new TaskConfig(origProps); + // We should expect this call, but the pluginLoader being swapped in is only mocked. + // EasyMock.expect(pluginLoader.loadClass(TestSourceTask.class.getName())) + // .andReturn((Class) TestSourceTask.class); + EasyMock.expect(plugins.newTask(TestSourceTask.class)).andReturn(task); + EasyMock.expect(task.version()).andReturn("1.0"); + + workerTask.initialize(taskConfig); + EasyMock.expectLastCall(); + + // Expect that the worker will create converters and will not initially find them using the current classloader ... 
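Editor's note (not part of the patch): the paired CURRENT_CLASSLOADER/PLUGINS converter expectations that follow encode a two-stage lookup, try the current classloader first and, if that yields nothing, fall back to the plugin search path. A tiny generic sketch of that fallback shape (a hypothetical helper, not the Connect API):

import java.util.function.Supplier;

// Hypothetical illustration of a "try A, otherwise B" lookup.
public final class FallbackLookup {

    public static <T> T firstNonNull(Supplier<T> fromCurrentClassloader, Supplier<T> fromPlugins) {
        T found = fromCurrentClassloader.get();      // the tests stub this stage to return null
        return found != null ? found : fromPlugins.get();
    }
}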
+ assertNotNull(taskKeyConverter); + assertNotNull(taskValueConverter); + assertNotNull(taskHeaderConverter); + expectTaskKeyConverters(ClassLoaderUsage.CURRENT_CLASSLOADER, null); + expectTaskKeyConverters(ClassLoaderUsage.PLUGINS, taskKeyConverter); + expectTaskValueConverters(ClassLoaderUsage.CURRENT_CLASSLOADER, null); + expectTaskValueConverters(ClassLoaderUsage.PLUGINS, taskValueConverter); + expectTaskHeaderConverter(ClassLoaderUsage.CURRENT_CLASSLOADER, null); + expectTaskHeaderConverter(ClassLoaderUsage.PLUGINS, taskHeaderConverter); + + EasyMock.expect(executorService.submit(workerTask)).andReturn(null); + + EasyMock.expect(plugins.delegatingLoader()).andReturn(delegatingLoader); + EasyMock.expect(delegatingLoader.connectorLoader(WorkerTestConnector.class.getName())) + .andReturn(pluginLoader); + + EasyMock.expect(Plugins.compareAndSwapLoaders(pluginLoader)).andReturn(delegatingLoader) + .times(2); + + EasyMock.expect(workerTask.loader()).andReturn(pluginLoader); + + EasyMock.expect(Plugins.compareAndSwapLoaders(delegatingLoader)).andReturn(pluginLoader) + .times(2); + plugins.connectorClass(WorkerTestConnector.class.getName()); + EasyMock.expectLastCall().andReturn(WorkerTestConnector.class); + // Remove on Worker.stop() + workerTask.stop(); + EasyMock.expectLastCall(); + + EasyMock.expect(workerTask.awaitStop(EasyMock.anyLong())).andReturn(true); + // Note that in this case we *do not* commit offsets since it's an unclean shutdown + EasyMock.expectLastCall(); + + workerTask.removeMetrics(); + EasyMock.expectLastCall(); + + expectStopStorage(); + expectClusterId(); + + PowerMock.replayAll(); + + worker = new Worker(WORKER_ID, new MockTime(), plugins, config, offsetBackingStore, executorService, + noneConnectorClientConfigOverridePolicy); + worker.herder = herder; + worker.start(); + assertStatistics(worker, 0, 0); + worker.startTask(TASK_ID, ClusterConfigState.EMPTY, anyConnectorConfigMap(), origProps, taskStatusListener, TargetState.STARTED); + assertStatistics(worker, 0, 1); + worker.stop(); + assertStatistics(worker, 0, 0); + + PowerMock.verifyAll(); + } + + @Test + public void testConverterOverrides() throws Exception { + expectConverters(); + expectStartStorage(); + expectFileConfigProvider(); + + EasyMock.expect(workerTask.id()).andStubReturn(TASK_ID); + + EasyMock.expect(plugins.currentThreadLoader()).andReturn(delegatingLoader).times(2); + expectNewWorkerTask(); + Map origProps = new HashMap<>(); + origProps.put(TaskConfig.TASK_CLASS_CONFIG, TestSourceTask.class.getName()); + + TaskConfig taskConfig = new TaskConfig(origProps); + // We should expect this call, but the pluginLoader being swapped in is only mocked. + // EasyMock.expect(pluginLoader.loadClass(TestSourceTask.class.getName())) + // .andReturn((Class) TestSourceTask.class); + EasyMock.expect(plugins.newTask(TestSourceTask.class)).andReturn(task); + EasyMock.expect(task.version()).andReturn("1.0"); + + workerTask.initialize(taskConfig); + EasyMock.expectLastCall(); + + // Expect that the worker will create converters and will not initially find them using the current classloader ... 
+ assertNotNull(taskKeyConverter); + assertNotNull(taskValueConverter); + assertNotNull(taskHeaderConverter); + expectTaskKeyConverters(ClassLoaderUsage.CURRENT_CLASSLOADER, null); + expectTaskKeyConverters(ClassLoaderUsage.PLUGINS, taskKeyConverter); + expectTaskValueConverters(ClassLoaderUsage.CURRENT_CLASSLOADER, null); + expectTaskValueConverters(ClassLoaderUsage.PLUGINS, taskValueConverter); + expectTaskHeaderConverter(ClassLoaderUsage.CURRENT_CLASSLOADER, null); + expectTaskHeaderConverter(ClassLoaderUsage.PLUGINS, taskHeaderConverter); + + EasyMock.expect(executorService.submit(workerTask)).andReturn(null); + + EasyMock.expect(plugins.delegatingLoader()).andReturn(delegatingLoader); + EasyMock.expect(delegatingLoader.connectorLoader(WorkerTestConnector.class.getName())) + .andReturn(pluginLoader); + + EasyMock.expect(Plugins.compareAndSwapLoaders(pluginLoader)).andReturn(delegatingLoader) + .times(2); + + EasyMock.expect(workerTask.loader()).andReturn(pluginLoader); + + EasyMock.expect(Plugins.compareAndSwapLoaders(delegatingLoader)).andReturn(pluginLoader) + .times(2); + plugins.connectorClass(WorkerTestConnector.class.getName()); + EasyMock.expectLastCall().andReturn(WorkerTestConnector.class); + + // Remove + workerTask.stop(); + EasyMock.expectLastCall(); + EasyMock.expect(workerTask.awaitStop(EasyMock.anyLong())).andStubReturn(true); + EasyMock.expectLastCall(); + + workerTask.removeMetrics(); + EasyMock.expectLastCall(); + + expectStopStorage(); + expectClusterId(); + + PowerMock.replayAll(); + + worker = new Worker(WORKER_ID, new MockTime(), plugins, config, offsetBackingStore, executorService, + noneConnectorClientConfigOverridePolicy); + worker.herder = herder; + worker.start(); + assertStatistics(worker, 0, 0); + assertEquals(Collections.emptySet(), worker.taskIds()); + Map connProps = anyConnectorConfigMap(); + connProps.put(ConnectorConfig.KEY_CONVERTER_CLASS_CONFIG, TestConverter.class.getName()); + connProps.put("key.converter.extra.config", "foo"); + connProps.put(ConnectorConfig.VALUE_CONVERTER_CLASS_CONFIG, TestConfigurableConverter.class.getName()); + connProps.put("value.converter.extra.config", "bar"); + worker.startTask(TASK_ID, ClusterConfigState.EMPTY, connProps, origProps, taskStatusListener, TargetState.STARTED); + assertStatistics(worker, 0, 1); + assertEquals(new HashSet<>(Arrays.asList(TASK_ID)), worker.taskIds()); + worker.stopAndAwaitTask(TASK_ID); + assertStatistics(worker, 0, 0); + assertEquals(Collections.emptySet(), worker.taskIds()); + // Nothing should be left, so this should effectively be a nop + worker.stop(); + assertStatistics(worker, 0, 0); + + // We've mocked the Plugin.newConverter method, so we don't currently configure the converters + + PowerMock.verifyAll(); + } + + @Test + public void testProducerConfigsWithoutOverrides() { + EasyMock.expect(connectorConfig.originalsWithPrefix(ConnectorConfig.CONNECTOR_CLIENT_PRODUCER_OVERRIDES_PREFIX)).andReturn( + new HashMap<>()); + PowerMock.replayAll(); + Map expectedConfigs = new HashMap<>(defaultProducerConfigs); + expectedConfigs.put("client.id", "connector-producer-job-0"); + expectedConfigs.put("metrics.context.connect.kafka.cluster.id", CLUSTER_ID); + assertEquals(expectedConfigs, + Worker.producerConfigs(TASK_ID, "connector-producer-" + TASK_ID, config, connectorConfig, null, noneConnectorClientConfigOverridePolicy, CLUSTER_ID)); + } + + @Test + public void testProducerConfigsWithOverrides() { + Map props = new HashMap<>(workerProps); + props.put("producer.acks", "-1"); + 
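Editor's note (not part of the patch): the producer-config tests check that worker properties carrying a "producer." prefix are passed through to the producer with the prefix stripped (for example, producer.acks becomes acks). The helper below is a hypothetical restatement of that prefix handling in plain Java, shown for illustration only; it is not the WorkerConfig/AbstractConfig API itself.

import java.util.HashMap;
import java.util.Map;

public final class PrefixedConfigs {

    // Keep only entries whose key starts with the prefix, with the prefix removed.
    static Map<String, String> withPrefixStripped(Map<String, String> props, String prefix) {
        Map<String, String> out = new HashMap<>();
        for (Map.Entry<String, String> e : props.entrySet()) {
            if (e.getKey().startsWith(prefix)) {
                out.put(e.getKey().substring(prefix.length()), e.getValue());
            }
        }
        return out;
    }

    public static void main(String[] args) {
        Map<String, String> workerProps = new HashMap<>();
        workerProps.put("producer.acks", "-1");
        workerProps.put("producer.linger.ms", "1000");
        // Prints the stripped keys (acks, linger.ms) that the expectedConfigs maps assert on.
        System.out.println(withPrefixStripped(workerProps, "producer."));
    }
}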
props.put("producer.linger.ms", "1000"); + props.put("producer.client.id", "producer-test-id"); + WorkerConfig configWithOverrides = new StandaloneConfig(props); + + Map expectedConfigs = new HashMap<>(defaultProducerConfigs); + expectedConfigs.put("acks", "-1"); + expectedConfigs.put("linger.ms", "1000"); + expectedConfigs.put("client.id", "producer-test-id"); + expectedConfigs.put("metrics.context.connect.kafka.cluster.id", CLUSTER_ID); + + EasyMock.expect(connectorConfig.originalsWithPrefix(ConnectorConfig.CONNECTOR_CLIENT_PRODUCER_OVERRIDES_PREFIX)).andReturn( + new HashMap<>()); + PowerMock.replayAll(); + assertEquals(expectedConfigs, + Worker.producerConfigs(TASK_ID, "connector-producer-" + TASK_ID, configWithOverrides, connectorConfig, null, allConnectorClientConfigOverridePolicy, CLUSTER_ID)); + } + + @Test + public void testProducerConfigsWithClientOverrides() { + Map props = new HashMap<>(workerProps); + props.put("producer.acks", "-1"); + props.put("producer.linger.ms", "1000"); + props.put("producer.client.id", "producer-test-id"); + WorkerConfig configWithOverrides = new StandaloneConfig(props); + + Map expectedConfigs = new HashMap<>(defaultProducerConfigs); + expectedConfigs.put("acks", "-1"); + expectedConfigs.put("linger.ms", "5000"); + expectedConfigs.put("batch.size", "1000"); + expectedConfigs.put("client.id", "producer-test-id"); + expectedConfigs.put("metrics.context.connect.kafka.cluster.id", CLUSTER_ID); + + Map connConfig = new HashMap<>(); + connConfig.put("linger.ms", "5000"); + connConfig.put("batch.size", "1000"); + EasyMock.expect(connectorConfig.originalsWithPrefix(ConnectorConfig.CONNECTOR_CLIENT_PRODUCER_OVERRIDES_PREFIX)) + .andReturn(connConfig); + PowerMock.replayAll(); + assertEquals(expectedConfigs, + Worker.producerConfigs(TASK_ID, "connector-producer-" + TASK_ID, configWithOverrides, connectorConfig, null, allConnectorClientConfigOverridePolicy, CLUSTER_ID)); + } + + @Test + public void testConsumerConfigsWithoutOverrides() { + Map expectedConfigs = new HashMap<>(defaultConsumerConfigs); + expectedConfigs.put("group.id", "connect-test"); + expectedConfigs.put("client.id", "connector-consumer-test-1"); + expectedConfigs.put("metrics.context.connect.kafka.cluster.id", CLUSTER_ID); + + EasyMock.expect(connectorConfig.originalsWithPrefix(ConnectorConfig.CONNECTOR_CLIENT_CONSUMER_OVERRIDES_PREFIX)).andReturn(new HashMap<>()); + PowerMock.replayAll(); + assertEquals(expectedConfigs, Worker.consumerConfigs(new ConnectorTaskId("test", 1), config, connectorConfig, + null, noneConnectorClientConfigOverridePolicy, CLUSTER_ID)); + } + + @Test + public void testConsumerConfigsWithOverrides() { + Map props = new HashMap<>(workerProps); + props.put("consumer.auto.offset.reset", "latest"); + props.put("consumer.max.poll.records", "1000"); + props.put("consumer.client.id", "consumer-test-id"); + WorkerConfig configWithOverrides = new StandaloneConfig(props); + + Map expectedConfigs = new HashMap<>(defaultConsumerConfigs); + expectedConfigs.put("group.id", "connect-test"); + expectedConfigs.put("auto.offset.reset", "latest"); + expectedConfigs.put("max.poll.records", "1000"); + expectedConfigs.put("client.id", "consumer-test-id"); + expectedConfigs.put("metrics.context.connect.kafka.cluster.id", CLUSTER_ID); + + EasyMock.expect(connectorConfig.originalsWithPrefix(ConnectorConfig.CONNECTOR_CLIENT_CONSUMER_OVERRIDES_PREFIX)).andReturn(new HashMap<>()); + PowerMock.replayAll(); + assertEquals(expectedConfigs, Worker.consumerConfigs(new ConnectorTaskId("test", 1), 
configWithOverrides, connectorConfig, + null, noneConnectorClientConfigOverridePolicy, CLUSTER_ID)); + + } + + @Test + public void testConsumerConfigsWithClientOverrides() { + Map props = new HashMap<>(workerProps); + props.put("consumer.auto.offset.reset", "latest"); + props.put("consumer.max.poll.records", "5000"); + WorkerConfig configWithOverrides = new StandaloneConfig(props); + + Map expectedConfigs = new HashMap<>(defaultConsumerConfigs); + expectedConfigs.put("group.id", "connect-test"); + expectedConfigs.put("auto.offset.reset", "latest"); + expectedConfigs.put("max.poll.records", "5000"); + expectedConfigs.put("max.poll.interval.ms", "1000"); + expectedConfigs.put("client.id", "connector-consumer-test-1"); + expectedConfigs.put("metrics.context.connect.kafka.cluster.id", CLUSTER_ID); + + Map connConfig = new HashMap<>(); + connConfig.put("max.poll.records", "5000"); + connConfig.put("max.poll.interval.ms", "1000"); + EasyMock.expect(connectorConfig.originalsWithPrefix(ConnectorConfig.CONNECTOR_CLIENT_CONSUMER_OVERRIDES_PREFIX)) + .andReturn(connConfig); + PowerMock.replayAll(); + assertEquals(expectedConfigs, Worker.consumerConfigs(new ConnectorTaskId("test", 1), configWithOverrides, connectorConfig, + null, allConnectorClientConfigOverridePolicy, CLUSTER_ID)); + } + + @Test + public void testConsumerConfigsClientOverridesWithNonePolicy() { + Map props = new HashMap<>(workerProps); + props.put("consumer.auto.offset.reset", "latest"); + props.put("consumer.max.poll.records", "5000"); + WorkerConfig configWithOverrides = new StandaloneConfig(props); + + Map connConfig = new HashMap<>(); + connConfig.put("max.poll.records", "5000"); + connConfig.put("max.poll.interval.ms", "1000"); + EasyMock.expect(connectorConfig.originalsWithPrefix(ConnectorConfig.CONNECTOR_CLIENT_CONSUMER_OVERRIDES_PREFIX)) + .andReturn(connConfig); + PowerMock.replayAll(); + assertThrows(ConnectException.class, () -> Worker.consumerConfigs(new ConnectorTaskId("test", 1), + configWithOverrides, connectorConfig, null, noneConnectorClientConfigOverridePolicy, CLUSTER_ID)); + } + + @Test + public void testAdminConfigsClientOverridesWithAllPolicy() { + Map props = new HashMap<>(workerProps); + props.put("admin.client.id", "testid"); + props.put("admin.metadata.max.age.ms", "5000"); + props.put("producer.bootstrap.servers", "cbeauho.com"); + props.put("consumer.bootstrap.servers", "localhost:4761"); + WorkerConfig configWithOverrides = new StandaloneConfig(props); + + Map connConfig = new HashMap<>(); + connConfig.put("metadata.max.age.ms", "10000"); + + Map expectedConfigs = new HashMap<>(workerProps); + + expectedConfigs.put("bootstrap.servers", "localhost:9092"); + expectedConfigs.put("client.id", "testid"); + expectedConfigs.put("metadata.max.age.ms", "10000"); + //we added a config on the fly + expectedConfigs.put("metrics.context.connect.kafka.cluster.id", CLUSTER_ID); + + EasyMock.expect(connectorConfig.originalsWithPrefix(ConnectorConfig.CONNECTOR_CLIENT_ADMIN_OVERRIDES_PREFIX)) + .andReturn(connConfig); + PowerMock.replayAll(); + assertEquals(expectedConfigs, Worker.adminConfigs(new ConnectorTaskId("test", 1), "", configWithOverrides, connectorConfig, + null, allConnectorClientConfigOverridePolicy, CLUSTER_ID)); + } + + @Test + public void testAdminConfigsClientOverridesWithNonePolicy() { + Map props = new HashMap<>(workerProps); + props.put("admin.client.id", "testid"); + props.put("admin.metadata.max.age.ms", "5000"); + WorkerConfig configWithOverrides = new StandaloneConfig(props); + + Map connConfig = 
new HashMap<>(); + connConfig.put("metadata.max.age.ms", "10000"); + + EasyMock.expect(connectorConfig.originalsWithPrefix(ConnectorConfig.CONNECTOR_CLIENT_ADMIN_OVERRIDES_PREFIX)) + .andReturn(connConfig); + PowerMock.replayAll(); + assertThrows(ConnectException.class, () -> Worker.adminConfigs(new ConnectorTaskId("test", 1), + "", configWithOverrides, connectorConfig, null, noneConnectorClientConfigOverridePolicy, CLUSTER_ID)); + + } + + @Test + public void testWorkerMetrics() throws Exception { + expectConverters(); + expectStartStorage(); + expectFileConfigProvider(); + + // Create + EasyMock.expect(plugins.currentThreadLoader()).andReturn(delegatingLoader).times(2); + EasyMock.expect(plugins.newConnector(WorkerTestConnector.class.getName())) + .andReturn(sourceConnector); + EasyMock.expect(sourceConnector.version()).andReturn("1.0"); + + Map props = new HashMap<>(); + props.put(SinkConnectorConfig.TOPICS_CONFIG, "foo,bar"); + props.put(ConnectorConfig.TASKS_MAX_CONFIG, "1"); + props.put(ConnectorConfig.NAME_CONFIG, CONNECTOR_ID); + props.put(ConnectorConfig.CONNECTOR_CLASS_CONFIG, WorkerTestConnector.class.getName()); + + EasyMock.expect(sourceConnector.version()).andReturn("1.0"); + + EasyMock.expect(plugins.compareAndSwapLoaders(sourceConnector)) + .andReturn(delegatingLoader) + .times(2); + sourceConnector.initialize(anyObject(ConnectorContext.class)); + EasyMock.expectLastCall(); + sourceConnector.start(props); + EasyMock.expectLastCall(); + + EasyMock.expect(Plugins.compareAndSwapLoaders(delegatingLoader)) + .andReturn(pluginLoader).times(2); + + connectorStatusListener.onStartup(CONNECTOR_ID); + EasyMock.expectLastCall(); + + // Remove + sourceConnector.stop(); + EasyMock.expectLastCall(); + + connectorStatusListener.onShutdown(CONNECTOR_ID); + EasyMock.expectLastCall(); + + expectStopStorage(); + expectClusterId(); + + PowerMock.replayAll(); + + Worker worker = new Worker("worker-1", + Time.SYSTEM, + plugins, + config, + offsetBackingStore, + noneConnectorClientConfigOverridePolicy + ); + MetricName name = worker.metrics().metrics().metricName("test.avg", "grp1"); + worker.metrics().metrics().addMetric(name, new Avg()); + MBeanServer server = ManagementFactory.getPlatformMBeanServer(); + Set ret = server.queryMBeans(null, null); + + List list = worker.metrics().metrics().reporters(); + for (MetricsReporter reporter : list) { + if (reporter instanceof MockMetricsReporter) { + MockMetricsReporter mockMetricsReporter = (MockMetricsReporter) reporter; + //verify connect cluster is set in MetricsContext + assertEquals(CLUSTER_ID, mockMetricsReporter.getMetricsContext().contextLabels().get(WorkerConfig.CONNECT_KAFKA_CLUSTER_ID)); + } + } + //verify metric is created with correct jmx prefix + assertNotNull(server.getObjectInstance(new ObjectName("kafka.connect:type=grp1"))); + } + + private void assertStatusMetrics(long expected, String metricName) { + MetricGroup statusMetrics = worker.connectorStatusMetricsGroup().metricGroup(TASK_ID.connector()); + if (expected == 0L) { + assertNull(statusMetrics); + return; + } + assertEquals(expected, MockConnectMetrics.currentMetricValue(worker.metrics(), statusMetrics, metricName)); + } + + private void assertStatistics(Worker worker, int connectors, int tasks) { + assertStatusMetrics(tasks, "connector-total-task-count"); + MetricGroup workerMetrics = worker.workerMetricsGroup().metricGroup(); + assertEquals(connectors, MockConnectMetrics.currentMetricValueAsDouble(worker.metrics(), workerMetrics, "connector-count"), 0.0001d); + 
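Editor's note (not part of the patch): testWorkerMetrics verifies that worker metrics are registered as JMX MBeans by querying the platform MBeanServer with an ObjectName. A standalone sketch of that lookup follows; it queries a bean every JVM registers so the snippet runs outside Connect, whereas the test itself looks up kafka.connect:type=grp1.

import java.lang.management.ManagementFactory;

import javax.management.MBeanServer;
import javax.management.ObjectName;

public final class JmxLookup {

    public static void main(String[] args) throws Exception {
        MBeanServer server = ManagementFactory.getPlatformMBeanServer();
        // Connect metric groups register under the "kafka.connect" JMX domain;
        // here we look up a bean that is always present so the example is self-contained.
        ObjectName name = new ObjectName("java.lang:type=Runtime");
        System.out.println(server.getObjectInstance(name).getClassName());
    }
}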
assertEquals(tasks, MockConnectMetrics.currentMetricValueAsDouble(worker.metrics(), workerMetrics, "task-count"), 0.0001d); + assertEquals(tasks, MockConnectMetrics.currentMetricValueAsDouble(worker.metrics(), workerMetrics, "task-count"), 0.0001d); + } + + private void assertStartupStatistics(Worker worker, int connectorStartupAttempts, int connectorStartupFailures, int taskStartupAttempts, int taskStartupFailures) { + double connectStartupSuccesses = connectorStartupAttempts - connectorStartupFailures; + double taskStartupSuccesses = taskStartupAttempts - taskStartupFailures; + double connectStartupSuccessPct = 0.0d; + double connectStartupFailurePct = 0.0d; + double taskStartupSuccessPct = 0.0d; + double taskStartupFailurePct = 0.0d; + if (connectorStartupAttempts != 0) { + connectStartupSuccessPct = connectStartupSuccesses / connectorStartupAttempts; + connectStartupFailurePct = (double) connectorStartupFailures / connectorStartupAttempts; + } + if (taskStartupAttempts != 0) { + taskStartupSuccessPct = taskStartupSuccesses / taskStartupAttempts; + taskStartupFailurePct = (double) taskStartupFailures / taskStartupAttempts; + } + MetricGroup workerMetrics = worker.workerMetricsGroup().metricGroup(); + assertEquals(connectorStartupAttempts, MockConnectMetrics.currentMetricValueAsDouble(worker.metrics(), workerMetrics, "connector-startup-attempts-total"), 0.0001d); + assertEquals(connectStartupSuccesses, MockConnectMetrics.currentMetricValueAsDouble(worker.metrics(), workerMetrics, "connector-startup-success-total"), 0.0001d); + assertEquals(connectorStartupFailures, MockConnectMetrics.currentMetricValueAsDouble(worker.metrics(), workerMetrics, "connector-startup-failure-total"), 0.0001d); + assertEquals(connectStartupSuccessPct, MockConnectMetrics.currentMetricValueAsDouble(worker.metrics(), workerMetrics, "connector-startup-success-percentage"), 0.0001d); + assertEquals(connectStartupFailurePct, MockConnectMetrics.currentMetricValueAsDouble(worker.metrics(), workerMetrics, "connector-startup-failure-percentage"), 0.0001d); + assertEquals(taskStartupAttempts, MockConnectMetrics.currentMetricValueAsDouble(worker.metrics(), workerMetrics, "task-startup-attempts-total"), 0.0001d); + assertEquals(taskStartupSuccesses, MockConnectMetrics.currentMetricValueAsDouble(worker.metrics(), workerMetrics, "task-startup-success-total"), 0.0001d); + assertEquals(taskStartupFailures, MockConnectMetrics.currentMetricValueAsDouble(worker.metrics(), workerMetrics, "task-startup-failure-total"), 0.0001d); + assertEquals(taskStartupSuccessPct, MockConnectMetrics.currentMetricValueAsDouble(worker.metrics(), workerMetrics, "task-startup-success-percentage"), 0.0001d); + assertEquals(taskStartupFailurePct, MockConnectMetrics.currentMetricValueAsDouble(worker.metrics(), workerMetrics, "task-startup-failure-percentage"), 0.0001d); + } + + private void expectStartStorage() { + offsetBackingStore.configure(anyObject(WorkerConfig.class)); + EasyMock.expectLastCall(); + offsetBackingStore.start(); + EasyMock.expectLastCall(); + EasyMock.expect(herder.statusBackingStore()) + .andReturn(statusBackingStore).anyTimes(); + } + + private void expectStopStorage() { + offsetBackingStore.stop(); + EasyMock.expectLastCall(); + } + + private void expectConverters() { + expectConverters(JsonConverter.class, false); + } + + private void expectConverters(Boolean expectDefaultConverters) { + expectConverters(JsonConverter.class, expectDefaultConverters); + } + + @SuppressWarnings("deprecation") + private void expectConverters(Class 
converterClass, Boolean expectDefaultConverters) { + // As default converters are instantiated when a task starts, they are expected only if the `startTask` method is called + if (expectDefaultConverters) { + + // Instantiate and configure default + EasyMock.expect(plugins.newConverter(config, WorkerConfig.KEY_CONVERTER_CLASS_CONFIG, ClassLoaderUsage.PLUGINS)) + .andReturn(keyConverter); + EasyMock.expect(plugins.newConverter(config, WorkerConfig.VALUE_CONVERTER_CLASS_CONFIG, ClassLoaderUsage.PLUGINS)) + .andReturn(valueConverter); + EasyMock.expectLastCall(); + } + + //internal + Converter internalKeyConverter = PowerMock.createMock(converterClass); + Converter internalValueConverter = PowerMock.createMock(converterClass); + + // Instantiate and configure internal + EasyMock.expect( + plugins.newInternalConverter( + EasyMock.eq(true), + EasyMock.anyString(), + EasyMock.anyObject() + ) + ).andReturn(internalKeyConverter); + EasyMock.expect( + plugins.newInternalConverter( + EasyMock.eq(false), + EasyMock.anyString(), + EasyMock.anyObject() + ) + ).andReturn(internalValueConverter); + EasyMock.expectLastCall(); + } + + private void expectTaskKeyConverters(ClassLoaderUsage classLoaderUsage, Converter returning) { + EasyMock.expect( + plugins.newConverter( + anyObject(AbstractConfig.class), + eq(WorkerConfig.KEY_CONVERTER_CLASS_CONFIG), + eq(classLoaderUsage))) + .andReturn(returning); + } + + private void expectTaskValueConverters(ClassLoaderUsage classLoaderUsage, Converter returning) { + EasyMock.expect( + plugins.newConverter( + anyObject(AbstractConfig.class), + eq(WorkerConfig.VALUE_CONVERTER_CLASS_CONFIG), + eq(classLoaderUsage))) + .andReturn(returning); + } + + private void expectTaskHeaderConverter(ClassLoaderUsage classLoaderUsage, HeaderConverter returning) { + EasyMock.expect( + plugins.newHeaderConverter( + anyObject(AbstractConfig.class), + eq(WorkerConfig.HEADER_CONVERTER_CLASS_CONFIG), + eq(classLoaderUsage))) + .andReturn(returning); + } + + private Map anyConnectorConfigMap() { + Map props = new HashMap<>(); + props.put(ConnectorConfig.NAME_CONFIG, CONNECTOR_ID); + props.put(ConnectorConfig.CONNECTOR_CLASS_CONFIG, WorkerTestConnector.class.getName()); + props.put(ConnectorConfig.TASKS_MAX_CONFIG, "1"); + props.put(DEFAULT_TOPIC_CREATION_PREFIX + REPLICATION_FACTOR_CONFIG, String.valueOf(1)); + props.put(DEFAULT_TOPIC_CREATION_PREFIX + PARTITIONS_CONFIG, String.valueOf(1)); + return props; + } + + private void expectClusterId() { + PowerMock.mockStaticPartial(ConnectUtils.class, "lookupKafkaClusterId"); + EasyMock.expect(ConnectUtils.lookupKafkaClusterId(EasyMock.anyObject())).andReturn("test-cluster").anyTimes(); + } + + private void expectNewWorkerTask() throws Exception { + PowerMock.expectNew( + WorkerSourceTask.class, EasyMock.eq(TASK_ID), + EasyMock.eq(task), + anyObject(TaskStatus.Listener.class), + EasyMock.eq(TargetState.STARTED), + anyObject(JsonConverter.class), + anyObject(JsonConverter.class), + anyObject(JsonConverter.class), + EasyMock.eq(new TransformationChain<>(Collections.emptyList(), NOOP_OPERATOR)), + anyObject(KafkaProducer.class), + anyObject(TopicAdmin.class), + EasyMock.>anyObject(), + anyObject(OffsetStorageReader.class), + anyObject(OffsetStorageWriter.class), + EasyMock.eq(config), + anyObject(ClusterConfigState.class), + anyObject(ConnectMetrics.class), + EasyMock.eq(pluginLoader), + anyObject(Time.class), + anyObject(RetryWithToleranceOperator.class), + anyObject(StatusBackingStore.class), + anyObject(Executor.class)) + .andReturn(workerTask); + 
} + /* Name here needs to be unique as we are testing the aliasing mechanism */ + public static class WorkerTestConnector extends SourceConnector { + + private static final ConfigDef CONFIG_DEF = new ConfigDef() + .define("configName", ConfigDef.Type.STRING, ConfigDef.Importance.HIGH, "Test configName."); + + @Override + public String version() { + return "1.0"; + } + + @Override + public void start(Map props) { + + } + + @Override + public Class taskClass() { + return null; + } + + @Override + public List> taskConfigs(int maxTasks) { + return null; + } + + @Override + public void stop() { + + } + + @Override + public ConfigDef config() { + return CONFIG_DEF; + } + } + + private static class TestSourceTask extends SourceTask { + public TestSourceTask() { + } + + @Override + public String version() { + return "1.0"; + } + + @Override + public void start(Map props) { + } + + @Override + public List poll() throws InterruptedException { + return null; + } + + @Override + public void stop() { + } + } + + public static class TestConverter implements Converter { + public Map configs; + + @Override + public void configure(Map configs, boolean isKey) { + this.configs = configs; + } + + @Override + public byte[] fromConnectData(String topic, Schema schema, Object value) { + return new byte[0]; + } + + @Override + public SchemaAndValue toConnectData(String topic, byte[] value) { + return null; + } + } + + public static class TestConfigurableConverter implements Converter, Configurable { + public Map configs; + + public ConfigDef config() { + return JsonConverterConfig.configDef(); + } + + @Override + public void configure(Map configs) { + this.configs = configs; + new JsonConverterConfig(configs); // requires the `converter.type` config be set + } + + @Override + public void configure(Map configs, boolean isKey) { + this.configs = configs; + } + + @Override + public byte[] fromConnectData(String topic, Schema schema, Object value) { + return new byte[0]; + } + + @Override + public SchemaAndValue toConnectData(String topic, byte[] value) { + return null; + } + } +} diff --git a/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/WorkerTestUtils.java b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/WorkerTestUtils.java new file mode 100644 index 0000000..ed77018 --- /dev/null +++ b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/WorkerTestUtils.java @@ -0,0 +1,193 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+package org.apache.kafka.connect.runtime;
+
+import org.apache.kafka.connect.runtime.distributed.ClusterConfigState;
+import org.apache.kafka.connect.runtime.distributed.ExtendedAssignment;
+import org.apache.kafka.connect.runtime.distributed.ExtendedWorkerState;
+import org.apache.kafka.connect.util.ConnectorTaskId;
+
+import java.util.AbstractMap.SimpleEntry;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+
+import static org.apache.kafka.connect.runtime.distributed.WorkerCoordinator.WorkerLoad;
+import static org.hamcrest.CoreMatchers.is;
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
+
+public class WorkerTestUtils {
+
+    public static WorkerLoad emptyWorkerLoad(String worker) {
+        return new WorkerLoad.Builder(worker).build();
+    }
+
+    public WorkerLoad workerLoad(String worker, int connectorStart, int connectorNum,
+                                 int taskStart, int taskNum) {
+        return new WorkerLoad.Builder(worker).with(
+                newConnectors(connectorStart, connectorStart + connectorNum),
+                newTasks(taskStart, taskStart + taskNum)).build();
+    }
+
+    public static List<String> newConnectors(int start, int end) {
+        return IntStream.range(start, end)
+                .mapToObj(i -> "connector" + i)
+                .collect(Collectors.toList());
+    }
+
+    public static List<ConnectorTaskId> newTasks(int start, int end) {
+        return IntStream.range(start, end)
+                .mapToObj(i -> new ConnectorTaskId("task", i))
+                .collect(Collectors.toList());
+    }
+
+    public static ClusterConfigState clusterConfigState(long offset,
+                                                        int connectorNum,
+                                                        int taskNum) {
+        return new ClusterConfigState(
+                offset,
+                null,
+                connectorTaskCounts(1, connectorNum, taskNum),
+                connectorConfigs(1, connectorNum),
+                connectorTargetStates(1, connectorNum, TargetState.STARTED),
+                taskConfigs(0, connectorNum, connectorNum * taskNum),
+                Collections.emptySet());
+    }
+
+    public static Map<String, ExtendedWorkerState> memberConfigs(String givenLeader,
+                                                                 long givenOffset,
+                                                                 Map<String, ExtendedAssignment> givenAssignments) {
+        return givenAssignments.entrySet().stream()
+                .collect(Collectors.toMap(
+                    Map.Entry::getKey,
+                    e -> new ExtendedWorkerState(expectedLeaderUrl(givenLeader), givenOffset, e.getValue())));
+    }
+
+    public static Map<String, ExtendedWorkerState> memberConfigs(String givenLeader,
+                                                                 long givenOffset,
+                                                                 int start,
+                                                                 int connectorNum) {
+        return IntStream.range(start, connectorNum + 1)
+                .mapToObj(i -> new SimpleEntry<>("worker" + i, new ExtendedWorkerState(expectedLeaderUrl(givenLeader), givenOffset, null)))
+                .collect(Collectors.toMap(SimpleEntry::getKey, SimpleEntry::getValue));
+    }
+
+    public static Map<String, Integer> connectorTaskCounts(int start,
+                                                           int connectorNum,
+                                                           int taskCounts) {
+        return IntStream.range(start, connectorNum + 1)
+                .mapToObj(i -> new SimpleEntry<>("connector" + i, taskCounts))
+                .collect(Collectors.toMap(SimpleEntry::getKey, SimpleEntry::getValue));
+    }
+
+    public static Map<String, Map<String, String>> connectorConfigs(int start, int connectorNum) {
+        return IntStream.range(start, connectorNum + 1)
+                .mapToObj(i -> new SimpleEntry<>("connector" + i, new HashMap<String, String>()))
+                .collect(Collectors.toMap(SimpleEntry::getKey, SimpleEntry::getValue));
+    }
+
+    public static Map<String, TargetState> connectorTargetStates(int start,
+                                                                 int connectorNum,
+                                                                 TargetState state) {
+        return IntStream.range(start, connectorNum + 1)
+                .mapToObj(i -> new SimpleEntry<>("connector" + i, state))
+                .collect(Collectors.toMap(SimpleEntry::getKey, SimpleEntry::getValue));
+    }
+
+    public static Map<ConnectorTaskId, Map<String, String>> taskConfigs(int start,
+                                                                        int
connectorNum, + int taskNum) { + return IntStream.range(start, taskNum + 1) + .mapToObj(i -> new SimpleEntry<>( + new ConnectorTaskId("connector" + i / connectorNum + 1, i), + new HashMap()) + ).collect(Collectors.toMap(SimpleEntry::getKey, SimpleEntry::getValue)); + } + + public static String expectedLeaderUrl(String givenLeader) { + return "http://" + givenLeader + ":8083"; + } + + public static void assertAssignment(String expectedLeader, + long expectedOffset, + List expectedAssignedConnectors, + int expectedAssignedTaskNum, + List expectedRevokedConnectors, + int expectedRevokedTaskNum, + ExtendedAssignment assignment) { + assertAssignment(false, expectedLeader, expectedOffset, + expectedAssignedConnectors, expectedAssignedTaskNum, + expectedRevokedConnectors, expectedRevokedTaskNum, + 0, + assignment); + } + + public static void assertAssignment(String expectedLeader, + long expectedOffset, + List expectedAssignedConnectors, + int expectedAssignedTaskNum, + List expectedRevokedConnectors, + int expectedRevokedTaskNum, + int expectedDelay, + ExtendedAssignment assignment) { + assertAssignment(false, expectedLeader, expectedOffset, + expectedAssignedConnectors, expectedAssignedTaskNum, + expectedRevokedConnectors, expectedRevokedTaskNum, + expectedDelay, + assignment); + } + + public static void assertAssignment(boolean expectFailed, + String expectedLeader, + long expectedOffset, + List expectedAssignedConnectors, + int expectedAssignedTaskNum, + List expectedRevokedConnectors, + int expectedRevokedTaskNum, + int expectedDelay, + ExtendedAssignment assignment) { + assertNotNull("Assignment can't be null", assignment); + + assertEquals("Wrong status in " + assignment, expectFailed, assignment.failed()); + + assertEquals("Wrong leader in " + assignment, expectedLeader, assignment.leader()); + + assertEquals("Wrong leaderUrl in " + assignment, expectedLeaderUrl(expectedLeader), + assignment.leaderUrl()); + + assertEquals("Wrong offset in " + assignment, expectedOffset, assignment.offset()); + + assertThat("Wrong set of assigned connectors in " + assignment, + assignment.connectors(), is(expectedAssignedConnectors)); + + assertEquals("Wrong number of assigned tasks in " + assignment, + expectedAssignedTaskNum, assignment.tasks().size()); + + assertThat("Wrong set of revoked connectors in " + assignment, + assignment.revokedConnectors(), is(expectedRevokedConnectors)); + + assertEquals("Wrong number of revoked tasks in " + assignment, + expectedRevokedTaskNum, assignment.revokedTasks().size()); + + assertEquals("Wrong rebalance delay in " + assignment, expectedDelay, assignment.delay()); + } +} diff --git a/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/distributed/ConnectProtocolCompatibilityTest.java b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/distributed/ConnectProtocolCompatibilityTest.java new file mode 100644 index 0000000..a3144c0 --- /dev/null +++ b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/distributed/ConnectProtocolCompatibilityTest.java @@ -0,0 +1,259 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.runtime.distributed; + +import org.apache.kafka.connect.runtime.TargetState; +import org.apache.kafka.connect.storage.KafkaConfigBackingStore; +import org.apache.kafka.connect.util.ConnectorTaskId; +import org.junit.After; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnit; +import org.mockito.junit.MockitoRule; + +import java.nio.ByteBuffer; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; + +import static org.apache.kafka.connect.runtime.distributed.IncrementalCooperativeConnectProtocol.CONNECT_PROTOCOL_V1; +import static org.junit.Assert.assertEquals; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.when; + +public class ConnectProtocolCompatibilityTest { + private static final String LEADER_URL = "leaderUrl:8083"; + + private String connectorId1 = "connector1"; + private String connectorId2 = "connector2"; + private String connectorId3 = "connector3"; + private ConnectorTaskId taskId1x0 = new ConnectorTaskId(connectorId1, 0); + private ConnectorTaskId taskId1x1 = new ConnectorTaskId(connectorId1, 1); + private ConnectorTaskId taskId2x0 = new ConnectorTaskId(connectorId2, 0); + private ConnectorTaskId taskId3x0 = new ConnectorTaskId(connectorId3, 0); + + @Rule + public MockitoRule rule = MockitoJUnit.rule(); + + @Mock + private KafkaConfigBackingStore configStorage; + private ClusterConfigState configState; + + @Before + public void setup() { + configStorage = mock(KafkaConfigBackingStore.class); + configState = new ClusterConfigState( + 1L, + null, + Collections.singletonMap(connectorId1, 1), + Collections.singletonMap(connectorId1, new HashMap<>()), + Collections.singletonMap(connectorId1, TargetState.STARTED), + Collections.singletonMap(taskId1x0, new HashMap<>()), + Collections.emptySet()); + } + + @After + public void teardown() { + verifyNoMoreInteractions(configStorage); + } + + @Test + public void testEagerToEagerMetadata() { + when(configStorage.snapshot()).thenReturn(configState); + ExtendedWorkerState workerState = new ExtendedWorkerState(LEADER_URL, configStorage.snapshot().offset(), null); + ByteBuffer metadata = ConnectProtocol.serializeMetadata(workerState); + ConnectProtocol.WorkerState state = ConnectProtocol.deserializeMetadata(metadata); + assertEquals(LEADER_URL, state.url()); + assertEquals(1, state.offset()); + verify(configStorage).snapshot(); + } + + @Test + public void testCoopToCoopMetadata() { + when(configStorage.snapshot()).thenReturn(configState); + ExtendedWorkerState workerState = new ExtendedWorkerState(LEADER_URL, configStorage.snapshot().offset(), null); + ByteBuffer metadata = IncrementalCooperativeConnectProtocol.serializeMetadata(workerState, false); + ExtendedWorkerState state = IncrementalCooperativeConnectProtocol.deserializeMetadata(metadata); + assertEquals(LEADER_URL, state.url()); + assertEquals(1, state.offset()); + 
verify(configStorage).snapshot(); + } + + @Test + public void testSessionedToCoopMetadata() { + when(configStorage.snapshot()).thenReturn(configState); + ExtendedWorkerState workerState = new ExtendedWorkerState(LEADER_URL, configStorage.snapshot().offset(), null); + ByteBuffer metadata = IncrementalCooperativeConnectProtocol.serializeMetadata(workerState, true); + ExtendedWorkerState state = IncrementalCooperativeConnectProtocol.deserializeMetadata(metadata); + assertEquals(LEADER_URL, state.url()); + assertEquals(1, state.offset()); + verify(configStorage).snapshot(); + } + + @Test + public void testSessionedToEagerMetadata() { + when(configStorage.snapshot()).thenReturn(configState); + ExtendedWorkerState workerState = new ExtendedWorkerState(LEADER_URL, configStorage.snapshot().offset(), null); + ByteBuffer metadata = IncrementalCooperativeConnectProtocol.serializeMetadata(workerState, true); + ConnectProtocol.WorkerState state = ConnectProtocol.deserializeMetadata(metadata); + assertEquals(LEADER_URL, state.url()); + assertEquals(1, state.offset()); + verify(configStorage).snapshot(); + } + + @Test + public void testCoopToEagerMetadata() { + when(configStorage.snapshot()).thenReturn(configState); + ExtendedWorkerState workerState = new ExtendedWorkerState(LEADER_URL, configStorage.snapshot().offset(), null); + ByteBuffer metadata = IncrementalCooperativeConnectProtocol.serializeMetadata(workerState, false); + ConnectProtocol.WorkerState state = ConnectProtocol.deserializeMetadata(metadata); + assertEquals(LEADER_URL, state.url()); + assertEquals(1, state.offset()); + verify(configStorage).snapshot(); + } + + @Test + public void testEagerToCoopMetadata() { + when(configStorage.snapshot()).thenReturn(configState); + ConnectProtocol.WorkerState workerState = new ConnectProtocol.WorkerState(LEADER_URL, configStorage.snapshot().offset()); + ByteBuffer metadata = ConnectProtocol.serializeMetadata(workerState); + ConnectProtocol.WorkerState state = IncrementalCooperativeConnectProtocol.deserializeMetadata(metadata); + assertEquals(LEADER_URL, state.url()); + assertEquals(1, state.offset()); + verify(configStorage).snapshot(); + } + + @Test + public void testEagerToEagerAssignment() { + ConnectProtocol.Assignment assignment = new ConnectProtocol.Assignment( + ConnectProtocol.Assignment.NO_ERROR, "leader", LEADER_URL, 1L, + Arrays.asList(connectorId1, connectorId3), Arrays.asList(taskId2x0)); + + ByteBuffer leaderBuf = ConnectProtocol.serializeAssignment(assignment); + ConnectProtocol.Assignment leaderAssignment = ConnectProtocol.deserializeAssignment(leaderBuf); + assertEquals(false, leaderAssignment.failed()); + assertEquals("leader", leaderAssignment.leader()); + assertEquals(1, leaderAssignment.offset()); + assertEquals(Arrays.asList(connectorId1, connectorId3), leaderAssignment.connectors()); + assertEquals(Collections.singletonList(taskId2x0), leaderAssignment.tasks()); + + ConnectProtocol.Assignment assignment2 = new ConnectProtocol.Assignment( + ConnectProtocol.Assignment.NO_ERROR, "member", LEADER_URL, 1L, + Arrays.asList(connectorId2), Arrays.asList(taskId1x0, taskId3x0)); + + ByteBuffer memberBuf = ConnectProtocol.serializeAssignment(assignment2); + ConnectProtocol.Assignment memberAssignment = ConnectProtocol.deserializeAssignment(memberBuf); + assertEquals(false, memberAssignment.failed()); + assertEquals("member", memberAssignment.leader()); + assertEquals(1, memberAssignment.offset()); + assertEquals(Collections.singletonList(connectorId2), memberAssignment.connectors()); + 
assertEquals(Arrays.asList(taskId1x0, taskId3x0), memberAssignment.tasks()); + } + + @Test + public void testCoopToCoopAssignment() { + ExtendedAssignment assignment = new ExtendedAssignment( + CONNECT_PROTOCOL_V1, ConnectProtocol.Assignment.NO_ERROR, "leader", LEADER_URL, 1L, + Arrays.asList(connectorId1, connectorId3), Arrays.asList(taskId2x0), + Collections.emptyList(), Collections.emptyList(), 0); + + ByteBuffer leaderBuf = IncrementalCooperativeConnectProtocol.serializeAssignment(assignment); + ConnectProtocol.Assignment leaderAssignment = ConnectProtocol.deserializeAssignment(leaderBuf); + assertEquals(false, leaderAssignment.failed()); + assertEquals("leader", leaderAssignment.leader()); + assertEquals(1, leaderAssignment.offset()); + assertEquals(Arrays.asList(connectorId1, connectorId3), leaderAssignment.connectors()); + assertEquals(Collections.singletonList(taskId2x0), leaderAssignment.tasks()); + + ExtendedAssignment assignment2 = new ExtendedAssignment( + CONNECT_PROTOCOL_V1, ConnectProtocol.Assignment.NO_ERROR, "member", LEADER_URL, 1L, + Arrays.asList(connectorId2), Arrays.asList(taskId1x0, taskId3x0), + Collections.emptyList(), Collections.emptyList(), 0); + + ByteBuffer memberBuf = ConnectProtocol.serializeAssignment(assignment2); + ConnectProtocol.Assignment memberAssignment = + IncrementalCooperativeConnectProtocol.deserializeAssignment(memberBuf); + assertEquals(false, memberAssignment.failed()); + assertEquals("member", memberAssignment.leader()); + assertEquals(1, memberAssignment.offset()); + assertEquals(Collections.singletonList(connectorId2), memberAssignment.connectors()); + assertEquals(Arrays.asList(taskId1x0, taskId3x0), memberAssignment.tasks()); + } + + @Test + public void testEagerToCoopAssignment() { + ConnectProtocol.Assignment assignment = new ConnectProtocol.Assignment( + ConnectProtocol.Assignment.NO_ERROR, "leader", LEADER_URL, 1L, + Arrays.asList(connectorId1, connectorId3), Arrays.asList(taskId2x0)); + + ByteBuffer leaderBuf = ConnectProtocol.serializeAssignment(assignment); + ConnectProtocol.Assignment leaderAssignment = + IncrementalCooperativeConnectProtocol.deserializeAssignment(leaderBuf); + assertEquals(false, leaderAssignment.failed()); + assertEquals("leader", leaderAssignment.leader()); + assertEquals(1, leaderAssignment.offset()); + assertEquals(Arrays.asList(connectorId1, connectorId3), leaderAssignment.connectors()); + assertEquals(Collections.singletonList(taskId2x0), leaderAssignment.tasks()); + + ConnectProtocol.Assignment assignment2 = new ConnectProtocol.Assignment( + ConnectProtocol.Assignment.NO_ERROR, "member", LEADER_URL, 1L, + Arrays.asList(connectorId2), Arrays.asList(taskId1x0, taskId3x0)); + + ByteBuffer memberBuf = ConnectProtocol.serializeAssignment(assignment2); + ConnectProtocol.Assignment memberAssignment = + IncrementalCooperativeConnectProtocol.deserializeAssignment(memberBuf); + assertEquals(false, memberAssignment.failed()); + assertEquals("member", memberAssignment.leader()); + assertEquals(1, memberAssignment.offset()); + assertEquals(Collections.singletonList(connectorId2), memberAssignment.connectors()); + assertEquals(Arrays.asList(taskId1x0, taskId3x0), memberAssignment.tasks()); + } + + @Test + public void testCoopToEagerAssignment() { + ExtendedAssignment assignment = new ExtendedAssignment( + CONNECT_PROTOCOL_V1, ConnectProtocol.Assignment.NO_ERROR, "leader", LEADER_URL, 1L, + Arrays.asList(connectorId1, connectorId3), Arrays.asList(taskId2x0), + Collections.emptyList(), Collections.emptyList(), 0); + + 
ByteBuffer leaderBuf = IncrementalCooperativeConnectProtocol.serializeAssignment(assignment); + ConnectProtocol.Assignment leaderAssignment = ConnectProtocol.deserializeAssignment(leaderBuf); + assertEquals(false, leaderAssignment.failed()); + assertEquals("leader", leaderAssignment.leader()); + assertEquals(1, leaderAssignment.offset()); + assertEquals(Arrays.asList(connectorId1, connectorId3), leaderAssignment.connectors()); + assertEquals(Collections.singletonList(taskId2x0), leaderAssignment.tasks()); + + ExtendedAssignment assignment2 = new ExtendedAssignment( + CONNECT_PROTOCOL_V1, ConnectProtocol.Assignment.NO_ERROR, "member", LEADER_URL, 1L, + Arrays.asList(connectorId2), Arrays.asList(taskId1x0, taskId3x0), + Collections.emptyList(), Collections.emptyList(), 0); + + ByteBuffer memberBuf = IncrementalCooperativeConnectProtocol.serializeAssignment(assignment2); + ConnectProtocol.Assignment memberAssignment = ConnectProtocol.deserializeAssignment(memberBuf); + assertEquals(false, memberAssignment.failed()); + assertEquals("member", memberAssignment.leader()); + assertEquals(1, memberAssignment.offset()); + assertEquals(Collections.singletonList(connectorId2), memberAssignment.connectors()); + assertEquals(Arrays.asList(taskId1x0, taskId3x0), memberAssignment.tasks()); + } + +} diff --git a/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/distributed/DistributedConfigTest.java b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/distributed/DistributedConfigTest.java new file mode 100644 index 0000000..e952327 --- /dev/null +++ b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/distributed/DistributedConfigTest.java @@ -0,0 +1,297 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.kafka.connect.runtime.distributed;
+
+import org.apache.kafka.common.config.ConfigException;
+import org.junit.Test;
+
+import javax.crypto.KeyGenerator;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotEquals;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertThrows;
+
+public class DistributedConfigTest {
+
+    public Map<String, String> configs() {
+        Map<String, String> result = new HashMap<>();
+        result.put(DistributedConfig.GROUP_ID_CONFIG, "connect-cluster");
+        result.put(DistributedConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9092");
+        result.put(DistributedConfig.CONFIG_TOPIC_CONFIG, "connect-configs");
+        result.put(DistributedConfig.OFFSET_STORAGE_TOPIC_CONFIG, "connect-offsets");
+        result.put(DistributedConfig.STATUS_STORAGE_TOPIC_CONFIG, "connect-status");
+        result.put(DistributedConfig.KEY_CONVERTER_CLASS_CONFIG, "org.apache.kafka.connect.json.JsonConverter");
+        result.put(DistributedConfig.VALUE_CONVERTER_CLASS_CONFIG, "org.apache.kafka.connect.json.JsonConverter");
+        return result;
+    }
+
+    @Test
+    public void shouldCreateKeyGeneratorWithDefaultSettings() {
+        DistributedConfig config = new DistributedConfig(configs());
+        assertNotNull(config.getInternalRequestKeyGenerator());
+    }
+
+    @Test
+    public void shouldCreateKeyGeneratorWithSpecificSettings() {
+        final String algorithm = "HmacSHA1";
+        Map<String, String> configs = configs();
+        configs.put(DistributedConfig.INTER_WORKER_KEY_GENERATION_ALGORITHM_CONFIG, algorithm);
+        configs.put(DistributedConfig.INTER_WORKER_KEY_SIZE_CONFIG, "512");
+        configs.put(DistributedConfig.INTER_WORKER_VERIFICATION_ALGORITHMS_CONFIG, algorithm);
+        DistributedConfig config = new DistributedConfig(configs);
+        KeyGenerator keyGenerator = config.getInternalRequestKeyGenerator();
+        assertNotNull(keyGenerator);
+        assertEquals(algorithm, keyGenerator.getAlgorithm());
+        assertEquals(512 / 8, keyGenerator.generateKey().getEncoded().length);
+    }
+
+    @Test
+    public void shouldFailWithEmptyListOfVerificationAlgorithms() {
+        Map<String, String> configs = configs();
+        configs.put(DistributedConfig.INTER_WORKER_VERIFICATION_ALGORITHMS_CONFIG, "");
+        assertThrows(ConfigException.class, () -> new DistributedConfig(configs));
+    }
+
+    @Test
+    public void shouldFailIfKeyAlgorithmNotInVerificationAlgorithmsList() {
+        Map<String, String> configs = configs();
+        configs.put(DistributedConfig.INTER_WORKER_KEY_GENERATION_ALGORITHM_CONFIG, "HmacSHA1");
+        configs.put(DistributedConfig.INTER_WORKER_VERIFICATION_ALGORITHMS_CONFIG, "HmacSHA256");
+        assertThrows(ConfigException.class, () -> new DistributedConfig(configs));
+    }
+
+    @Test
+    public void shouldFailWithInvalidKeyAlgorithm() {
+        Map<String, String> configs = configs();
+        configs.put(DistributedConfig.INTER_WORKER_KEY_GENERATION_ALGORITHM_CONFIG, "not-actually-a-key-algorithm");
+        assertThrows(ConfigException.class, () -> new DistributedConfig(configs));
+    }
+
+    @Test
+    public void shouldFailWithInvalidKeySize() {
+        Map<String, String> configs = configs();
+        configs.put(DistributedConfig.INTER_WORKER_KEY_SIZE_CONFIG, "0");
+        assertThrows(ConfigException.class, () -> new DistributedConfig(configs));
+    }
+
+    @Test
+    public void shouldValidateAllVerificationAlgorithms() {
+        List<String> algorithms =
+            new ArrayList<>(Arrays.asList("HmacSHA1", "HmacSHA256", "HmacMD5", "bad-algorithm"));
+        Map<String, String> configs = configs();
+        for (int i = 0; i < algorithms.size(); i++) {
+            configs.put(DistributedConfig.INTER_WORKER_VERIFICATION_ALGORITHMS_CONFIG, String.join(",", algorithms));
+            assertThrows(ConfigException.class, () -> new DistributedConfig(configs));
+            algorithms.add(algorithms.remove(0));
+        }
+    }
+
+    @Test
+    public void shouldAllowNegativeOneAndPositiveForPartitions() {
+        Map<String, String> settings = configs();
+        settings.put(DistributedConfig.OFFSET_STORAGE_PARTITIONS_CONFIG, "-1");
+        settings.put(DistributedConfig.STATUS_STORAGE_PARTITIONS_CONFIG, "-1");
+        new DistributedConfig(configs());
+        settings.remove(DistributedConfig.OFFSET_STORAGE_PARTITIONS_CONFIG);
+        settings.remove(DistributedConfig.STATUS_STORAGE_PARTITIONS_CONFIG);
+
+        for (int i = 1; i != 100; ++i) {
+            settings.put(DistributedConfig.OFFSET_STORAGE_PARTITIONS_CONFIG, Integer.toString(i));
+            new DistributedConfig(settings);
+            settings.remove(DistributedConfig.OFFSET_STORAGE_PARTITIONS_CONFIG);
+
+            settings.put(DistributedConfig.STATUS_STORAGE_PARTITIONS_CONFIG, Integer.toString(i));
+            new DistributedConfig(settings);
+        }
+    }
+
+    @Test
+    public void shouldNotAllowZeroPartitions() {
+        Map<String, String> settings = configs();
+        settings.put(DistributedConfig.OFFSET_STORAGE_PARTITIONS_CONFIG, "0");
+        assertThrows(ConfigException.class, () -> new DistributedConfig(settings));
+        settings.remove(DistributedConfig.OFFSET_STORAGE_PARTITIONS_CONFIG);
+
+        settings.put(DistributedConfig.STATUS_STORAGE_PARTITIONS_CONFIG, "0");
+        assertThrows(ConfigException.class, () -> new DistributedConfig(settings));
+    }
+
+    @Test
+    public void shouldNotAllowNegativePartitionsLessThanNegativeOne() {
+        Map<String, String> settings = configs();
+        for (int i = -2; i > -100; --i) {
+            settings.put(DistributedConfig.OFFSET_STORAGE_PARTITIONS_CONFIG, Integer.toString(i));
+            assertThrows(ConfigException.class, () -> new DistributedConfig(settings));
+            settings.remove(DistributedConfig.OFFSET_STORAGE_PARTITIONS_CONFIG);
+
+            settings.put(DistributedConfig.STATUS_STORAGE_PARTITIONS_CONFIG, Integer.toString(i));
+            assertThrows(ConfigException.class, () -> new DistributedConfig(settings));
+        }
+    }
+
+    @Test
+    public void shouldAllowNegativeOneAndPositiveForReplicationFactor() {
+        Map<String, String> settings = configs();
+        settings.put(DistributedConfig.CONFIG_STORAGE_REPLICATION_FACTOR_CONFIG, "-1");
+        settings.put(DistributedConfig.OFFSET_STORAGE_REPLICATION_FACTOR_CONFIG, "-1");
+        settings.put(DistributedConfig.STATUS_STORAGE_REPLICATION_FACTOR_CONFIG, "-1");
+        new DistributedConfig(configs());
+        settings.remove(DistributedConfig.CONFIG_STORAGE_REPLICATION_FACTOR_CONFIG);
+        settings.remove(DistributedConfig.OFFSET_STORAGE_PARTITIONS_CONFIG);
+        settings.remove(DistributedConfig.STATUS_STORAGE_PARTITIONS_CONFIG);
+
+        for (int i = 1; i != 100; ++i) {
+            settings.put(DistributedConfig.CONFIG_STORAGE_REPLICATION_FACTOR_CONFIG, Integer.toString(i));
+            new DistributedConfig(settings);
+            settings.remove(DistributedConfig.CONFIG_STORAGE_REPLICATION_FACTOR_CONFIG);
+
+            settings.put(DistributedConfig.OFFSET_STORAGE_PARTITIONS_CONFIG, Integer.toString(i));
+            new DistributedConfig(settings);
+            settings.remove(DistributedConfig.OFFSET_STORAGE_PARTITIONS_CONFIG);
+
+            settings.put(DistributedConfig.STATUS_STORAGE_PARTITIONS_CONFIG, Integer.toString(i));
+            new DistributedConfig(settings);
+        }
+    }
+
+    @Test
+    public void shouldNotAllowZeroReplicationFactor() {
+        Map<String, String> settings = configs();
+        settings.put(DistributedConfig.CONFIG_STORAGE_REPLICATION_FACTOR_CONFIG, "0");
+        assertThrows(ConfigException.class, () -> new DistributedConfig(settings));
settings.remove(DistributedConfig.CONFIG_STORAGE_REPLICATION_FACTOR_CONFIG); + + settings.put(DistributedConfig.OFFSET_STORAGE_REPLICATION_FACTOR_CONFIG, "0"); + assertThrows(ConfigException.class, () -> new DistributedConfig(settings)); + settings.remove(DistributedConfig.OFFSET_STORAGE_REPLICATION_FACTOR_CONFIG); + + settings.put(DistributedConfig.STATUS_STORAGE_REPLICATION_FACTOR_CONFIG, "0"); + assertThrows(ConfigException.class, () -> new DistributedConfig(settings)); + } + + @Test + public void shouldNotAllowNegativeReplicationFactorLessThanNegativeOne() { + Map settings = configs(); + for (int i = -2; i > -100; --i) { + settings.put(DistributedConfig.CONFIG_STORAGE_REPLICATION_FACTOR_CONFIG, Integer.toString(i)); + assertThrows(ConfigException.class, () -> new DistributedConfig(settings)); + settings.remove(DistributedConfig.CONFIG_STORAGE_REPLICATION_FACTOR_CONFIG); + + settings.put(DistributedConfig.OFFSET_STORAGE_REPLICATION_FACTOR_CONFIG, Integer.toString(i)); + assertThrows(ConfigException.class, () -> new DistributedConfig(settings)); + settings.remove(DistributedConfig.OFFSET_STORAGE_REPLICATION_FACTOR_CONFIG); + + settings.put(DistributedConfig.STATUS_STORAGE_REPLICATION_FACTOR_CONFIG, Integer.toString(i)); + assertThrows(ConfigException.class, () -> new DistributedConfig(settings)); + } + } + + @Test + public void shouldAllowSettingConfigTopicSettings() { + Map topicSettings = new HashMap<>(); + topicSettings.put("foo", "foo value"); + topicSettings.put("bar", "bar value"); + topicSettings.put("baz.bim", "100"); + Map settings = configs(); + topicSettings.forEach((k, v) -> settings.put(DistributedConfig.CONFIG_STORAGE_PREFIX + k, v)); + DistributedConfig config = new DistributedConfig(settings); + assertEquals(topicSettings, config.configStorageTopicSettings()); + } + + @Test + public void shouldAllowSettingOffsetTopicSettings() { + Map topicSettings = new HashMap<>(); + topicSettings.put("foo", "foo value"); + topicSettings.put("bar", "bar value"); + topicSettings.put("baz.bim", "100"); + Map settings = configs(); + topicSettings.forEach((k, v) -> settings.put(DistributedConfig.OFFSET_STORAGE_PREFIX + k, v)); + DistributedConfig config = new DistributedConfig(settings); + assertEquals(topicSettings, config.offsetStorageTopicSettings()); + } + + @Test + public void shouldAllowSettingStatusTopicSettings() { + Map topicSettings = new HashMap<>(); + topicSettings.put("foo", "foo value"); + topicSettings.put("bar", "bar value"); + topicSettings.put("baz.bim", "100"); + Map settings = configs(); + topicSettings.forEach((k, v) -> settings.put(DistributedConfig.STATUS_STORAGE_PREFIX + k, v)); + DistributedConfig config = new DistributedConfig(settings); + assertEquals(topicSettings, config.statusStorageTopicSettings()); + } + + @Test + public void shouldRemoveCompactionFromConfigTopicSettings() { + Map expectedTopicSettings = new HashMap<>(); + expectedTopicSettings.put("foo", "foo value"); + expectedTopicSettings.put("bar", "bar value"); + expectedTopicSettings.put("baz.bim", "100"); + Map topicSettings = new HashMap<>(expectedTopicSettings); + topicSettings.put("cleanup.policy", "something-else"); + topicSettings.put("partitions", "3"); + + Map settings = configs(); + topicSettings.forEach((k, v) -> settings.put(DistributedConfig.CONFIG_STORAGE_PREFIX + k, v)); + DistributedConfig config = new DistributedConfig(settings); + Map actual = config.configStorageTopicSettings(); + assertEquals(expectedTopicSettings, actual); + assertNotEquals(topicSettings, actual); + } + + @Test + 
public void shouldRemoveCompactionFromOffsetTopicSettings() { + Map expectedTopicSettings = new HashMap<>(); + expectedTopicSettings.put("foo", "foo value"); + expectedTopicSettings.put("bar", "bar value"); + expectedTopicSettings.put("baz.bim", "100"); + Map topicSettings = new HashMap<>(expectedTopicSettings); + topicSettings.put("cleanup.policy", "something-else"); + + Map settings = configs(); + topicSettings.forEach((k, v) -> settings.put(DistributedConfig.OFFSET_STORAGE_PREFIX + k, v)); + DistributedConfig config = new DistributedConfig(settings); + Map actual = config.offsetStorageTopicSettings(); + assertEquals(expectedTopicSettings, actual); + assertNotEquals(topicSettings, actual); + } + + @Test + public void shouldRemoveCompactionFromStatusTopicSettings() { + Map expectedTopicSettings = new HashMap<>(); + expectedTopicSettings.put("foo", "foo value"); + expectedTopicSettings.put("bar", "bar value"); + expectedTopicSettings.put("baz.bim", "100"); + Map topicSettings = new HashMap<>(expectedTopicSettings); + topicSettings.put("cleanup.policy", "something-else"); + + Map settings = configs(); + topicSettings.forEach((k, v) -> settings.put(DistributedConfig.STATUS_STORAGE_PREFIX + k, v)); + DistributedConfig config = new DistributedConfig(settings); + Map actual = config.statusStorageTopicSettings(); + assertEquals(expectedTopicSettings, actual); + assertNotEquals(topicSettings, actual); + } +} diff --git a/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/distributed/DistributedHerderTest.java b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/distributed/DistributedHerderTest.java new file mode 100644 index 0000000..245bb75 --- /dev/null +++ b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/distributed/DistributedHerderTest.java @@ -0,0 +1,2854 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.connect.runtime.distributed; + +import org.apache.kafka.clients.CommonClientConfigs; +import org.apache.kafka.common.config.ConfigValue; +import org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.connect.connector.Connector; +import org.apache.kafka.connect.connector.policy.ConnectorClientConfigOverridePolicy; +import org.apache.kafka.connect.connector.policy.NoneConnectorClientConfigOverridePolicy; +import org.apache.kafka.connect.errors.AlreadyExistsException; +import org.apache.kafka.connect.errors.ConnectException; +import org.apache.kafka.connect.errors.NotFoundException; +import org.apache.kafka.connect.runtime.ConnectMetrics.MetricGroup; +import org.apache.kafka.connect.runtime.ConnectorConfig; +import org.apache.kafka.connect.runtime.Herder; +import org.apache.kafka.connect.runtime.MockConnectMetrics; +import org.apache.kafka.connect.runtime.RestartPlan; +import org.apache.kafka.connect.runtime.RestartRequest; +import org.apache.kafka.connect.runtime.SessionKey; +import org.apache.kafka.connect.runtime.SinkConnectorConfig; +import org.apache.kafka.connect.runtime.TargetState; +import org.apache.kafka.connect.runtime.TaskConfig; +import org.apache.kafka.connect.runtime.TopicStatus; +import org.apache.kafka.connect.runtime.Worker; +import org.apache.kafka.connect.runtime.WorkerConfig; +import org.apache.kafka.connect.runtime.WorkerConfigTransformer; +import org.apache.kafka.connect.runtime.distributed.DistributedHerder.HerderMetrics; +import org.apache.kafka.connect.runtime.isolation.DelegatingClassLoader; +import org.apache.kafka.connect.runtime.isolation.PluginClassLoader; +import org.apache.kafka.connect.runtime.isolation.Plugins; +import org.apache.kafka.connect.runtime.rest.InternalRequestSignature; +import org.apache.kafka.connect.runtime.rest.entities.ConfigInfos; +import org.apache.kafka.connect.runtime.rest.entities.ConnectorInfo; +import org.apache.kafka.connect.runtime.rest.entities.ConnectorStateInfo; +import org.apache.kafka.connect.runtime.rest.entities.ConnectorType; +import org.apache.kafka.connect.runtime.rest.entities.TaskInfo; +import org.apache.kafka.connect.runtime.rest.errors.BadRequestException; +import org.apache.kafka.connect.runtime.rest.errors.ConnectRestException; +import org.apache.kafka.connect.sink.SinkConnector; +import org.apache.kafka.connect.source.SourceConnector; +import org.apache.kafka.connect.source.SourceTask; +import org.apache.kafka.connect.storage.ConfigBackingStore; +import org.apache.kafka.connect.storage.StatusBackingStore; +import org.apache.kafka.connect.util.Callback; +import org.apache.kafka.connect.util.ConnectorTaskId; +import org.apache.kafka.connect.util.FutureCallback; +import org.easymock.Capture; +import org.easymock.EasyMock; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.powermock.api.easymock.PowerMock; +import org.powermock.api.easymock.annotation.Mock; +import org.powermock.core.classloader.annotations.PowerMockIgnore; +import org.powermock.core.classloader.annotations.PrepareForTest; +import org.powermock.modules.junit4.PowerMockRunner; +import org.powermock.reflect.Whitebox; + +import javax.crypto.SecretKey; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.concurrent.CountDownLatch; +import 
java.util.concurrent.ExecutionException; +import java.util.concurrent.ThreadPoolExecutor; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + +import static java.util.Collections.singletonList; +import static javax.ws.rs.core.Response.Status.FORBIDDEN; +import static org.apache.kafka.connect.runtime.distributed.ConnectProtocol.CONNECT_PROTOCOL_V0; +import static org.apache.kafka.connect.runtime.distributed.DistributedConfig.INTER_WORKER_KEY_GENERATION_ALGORITHM_DEFAULT; +import static org.apache.kafka.connect.runtime.distributed.IncrementalCooperativeConnectProtocol.CONNECT_PROTOCOL_V1; +import static org.apache.kafka.connect.runtime.distributed.IncrementalCooperativeConnectProtocol.CONNECT_PROTOCOL_V2; +import static org.easymock.EasyMock.anyObject; +import static org.easymock.EasyMock.capture; +import static org.easymock.EasyMock.newCapture; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +@SuppressWarnings("deprecation") +@RunWith(PowerMockRunner.class) +@PrepareForTest({DistributedHerder.class, Plugins.class}) +@PowerMockIgnore({"javax.management.*", "javax.crypto.*"}) +public class DistributedHerderTest { + private static final Map HERDER_CONFIG = new HashMap<>(); + static { + HERDER_CONFIG.put(DistributedConfig.STATUS_STORAGE_TOPIC_CONFIG, "status-topic"); + HERDER_CONFIG.put(DistributedConfig.CONFIG_TOPIC_CONFIG, "config-topic"); + HERDER_CONFIG.put(CommonClientConfigs.BOOTSTRAP_SERVERS_CONFIG, "localhost:9092"); + HERDER_CONFIG.put(DistributedConfig.GROUP_ID_CONFIG, "connect-test-group"); + // The WorkerConfig base class has some required settings without defaults + HERDER_CONFIG.put(WorkerConfig.KEY_CONVERTER_CLASS_CONFIG, "org.apache.kafka.connect.json.JsonConverter"); + HERDER_CONFIG.put(WorkerConfig.VALUE_CONVERTER_CLASS_CONFIG, "org.apache.kafka.connect.json.JsonConverter"); + HERDER_CONFIG.put(DistributedConfig.OFFSET_STORAGE_TOPIC_CONFIG, "connect-offsets"); + } + private static final String MEMBER_URL = "memberUrl"; + + private static final String CONN1 = "sourceA"; + private static final String CONN2 = "sourceB"; + private static final ConnectorTaskId TASK0 = new ConnectorTaskId(CONN1, 0); + private static final ConnectorTaskId TASK1 = new ConnectorTaskId(CONN1, 1); + private static final ConnectorTaskId TASK2 = new ConnectorTaskId(CONN1, 2); + private static final Integer MAX_TASKS = 3; + private static final Map CONN1_CONFIG = new HashMap<>(); + private static final String FOO_TOPIC = "foo"; + private static final String BAR_TOPIC = "bar"; + private static final String BAZ_TOPIC = "baz"; + static { + CONN1_CONFIG.put(ConnectorConfig.NAME_CONFIG, CONN1); + CONN1_CONFIG.put(ConnectorConfig.TASKS_MAX_CONFIG, MAX_TASKS.toString()); + CONN1_CONFIG.put(SinkConnectorConfig.TOPICS_CONFIG, String.join(",", FOO_TOPIC, BAR_TOPIC)); + CONN1_CONFIG.put(ConnectorConfig.CONNECTOR_CLASS_CONFIG, BogusSourceConnector.class.getName()); + } + private static final Map CONN1_CONFIG_UPDATED = new HashMap<>(CONN1_CONFIG); + static { + CONN1_CONFIG_UPDATED.put(SinkConnectorConfig.TOPICS_CONFIG, String.join(",", FOO_TOPIC, BAR_TOPIC, BAZ_TOPIC)); + } + private static final ConfigInfos CONN1_CONFIG_INFOS = + new ConfigInfos(CONN1, 0, Collections.emptyList(), Collections.emptyList()); + private static final Map CONN2_CONFIG = new HashMap<>(); + 
static { + CONN2_CONFIG.put(ConnectorConfig.NAME_CONFIG, CONN2); + CONN2_CONFIG.put(ConnectorConfig.TASKS_MAX_CONFIG, MAX_TASKS.toString()); + CONN2_CONFIG.put(SinkConnectorConfig.TOPICS_CONFIG, String.join(",", FOO_TOPIC, BAR_TOPIC)); + CONN2_CONFIG.put(ConnectorConfig.CONNECTOR_CLASS_CONFIG, BogusSourceConnector.class.getName()); + } + private static final ConfigInfos CONN2_CONFIG_INFOS = + new ConfigInfos(CONN2, 0, Collections.emptyList(), Collections.emptyList()); + private static final ConfigInfos CONN2_INVALID_CONFIG_INFOS = + new ConfigInfos(CONN2, 1, Collections.emptyList(), Collections.emptyList()); + private static final Map TASK_CONFIG = new HashMap<>(); + static { + TASK_CONFIG.put(TaskConfig.TASK_CLASS_CONFIG, BogusSourceTask.class.getName()); + } + private static final List> TASK_CONFIGS = new ArrayList<>(); + static { + TASK_CONFIGS.add(TASK_CONFIG); + TASK_CONFIGS.add(TASK_CONFIG); + TASK_CONFIGS.add(TASK_CONFIG); + } + private static final HashMap> TASK_CONFIGS_MAP = new HashMap<>(); + static { + TASK_CONFIGS_MAP.put(TASK0, TASK_CONFIG); + TASK_CONFIGS_MAP.put(TASK1, TASK_CONFIG); + TASK_CONFIGS_MAP.put(TASK2, TASK_CONFIG); + } + private static final ClusterConfigState SNAPSHOT = new ClusterConfigState(1, null, Collections.singletonMap(CONN1, 3), + Collections.singletonMap(CONN1, CONN1_CONFIG), Collections.singletonMap(CONN1, TargetState.STARTED), + TASK_CONFIGS_MAP, Collections.emptySet()); + private static final ClusterConfigState SNAPSHOT_PAUSED_CONN1 = new ClusterConfigState(1, null, Collections.singletonMap(CONN1, 3), + Collections.singletonMap(CONN1, CONN1_CONFIG), Collections.singletonMap(CONN1, TargetState.PAUSED), + TASK_CONFIGS_MAP, Collections.emptySet()); + private static final ClusterConfigState SNAPSHOT_UPDATED_CONN1_CONFIG = new ClusterConfigState(1, null, Collections.singletonMap(CONN1, 3), + Collections.singletonMap(CONN1, CONN1_CONFIG_UPDATED), Collections.singletonMap(CONN1, TargetState.STARTED), + TASK_CONFIGS_MAP, Collections.emptySet()); + + private static final String WORKER_ID = "localhost:8083"; + private static final String KAFKA_CLUSTER_ID = "I4ZmrWqfT2e-upky_4fdPA"; + private static final Runnable EMPTY_RUNNABLE = () -> { + }; + + @Mock private ConfigBackingStore configBackingStore; + @Mock private StatusBackingStore statusBackingStore; + @Mock private WorkerGroupMember member; + private MockTime time; + private DistributedHerder herder; + private MockConnectMetrics metrics; + @Mock private Worker worker; + @Mock private WorkerConfigTransformer transformer; + @Mock private Callback> putConnectorCallback; + @Mock private Plugins plugins; + @Mock private PluginClassLoader pluginLoader; + @Mock private DelegatingClassLoader delegatingLoader; + private CountDownLatch shutdownCalled = new CountDownLatch(1); + + private ConfigBackingStore.UpdateListener configUpdateListener; + private WorkerRebalanceListener rebalanceListener; + + private SinkConnectorConfig conn1SinkConfig; + private SinkConnectorConfig conn1SinkConfigUpdated; + private short connectProtocolVersion; + private final ConnectorClientConfigOverridePolicy + noneConnectorClientConfigOverridePolicy = new NoneConnectorClientConfigOverridePolicy(); + + + @Before + public void setUp() throws Exception { + time = new MockTime(); + metrics = new MockConnectMetrics(time); + worker = PowerMock.createMock(Worker.class); + EasyMock.expect(worker.isSinkConnector(CONN1)).andStubReturn(Boolean.TRUE); + AutoCloseable uponShutdown = () -> shutdownCalled.countDown(); + + // Default to the old protocol 
unless specified otherwise + connectProtocolVersion = CONNECT_PROTOCOL_V0; + + herder = PowerMock.createPartialMock(DistributedHerder.class, + new String[]{"connectorTypeForClass", "updateDeletedConnectorStatus", "updateDeletedTaskStatus", "validateConnectorConfig", "buildRestartPlan", "recordRestarting"}, + new DistributedConfig(HERDER_CONFIG), worker, WORKER_ID, KAFKA_CLUSTER_ID, + statusBackingStore, configBackingStore, member, MEMBER_URL, metrics, time, noneConnectorClientConfigOverridePolicy, + new AutoCloseable[]{uponShutdown}); + + configUpdateListener = herder.new ConfigUpdateListener(); + rebalanceListener = herder.new RebalanceListener(time); + plugins = PowerMock.createMock(Plugins.class); + conn1SinkConfig = new SinkConnectorConfig(plugins, CONN1_CONFIG); + conn1SinkConfigUpdated = new SinkConnectorConfig(plugins, CONN1_CONFIG_UPDATED); + EasyMock.expect(herder.connectorTypeForClass(BogusSourceConnector.class.getName())).andReturn(ConnectorType.SOURCE).anyTimes(); + pluginLoader = PowerMock.createMock(PluginClassLoader.class); + delegatingLoader = PowerMock.createMock(DelegatingClassLoader.class); + PowerMock.mockStatic(Plugins.class); + PowerMock.expectPrivate(herder, "updateDeletedConnectorStatus").andVoid().anyTimes(); + PowerMock.expectPrivate(herder, "updateDeletedTaskStatus").andVoid().anyTimes(); + } + + @After + public void tearDown() { + if (metrics != null) metrics.stop(); + } + + @Test + public void testJoinAssignment() throws Exception { + // Join group and get assignment + EasyMock.expect(member.memberId()).andStubReturn("member"); + EasyMock.expect(member.currentProtocolVersion()).andStubReturn(CONNECT_PROTOCOL_V0); + EasyMock.expect(worker.getPlugins()).andReturn(plugins); + expectRebalance(1, Arrays.asList(CONN1), Arrays.asList(TASK1)); + expectPostRebalanceCatchup(SNAPSHOT); + Capture> onStart = newCapture(); + worker.startConnector(EasyMock.eq(CONN1), EasyMock.anyObject(), EasyMock.anyObject(), + EasyMock.eq(herder), EasyMock.eq(TargetState.STARTED), capture(onStart)); + PowerMock.expectLastCall().andAnswer(() -> { + onStart.getValue().onCompletion(null, TargetState.STARTED); + return true; + }); + member.wakeup(); + PowerMock.expectLastCall(); + EasyMock.expect(worker.isRunning(CONN1)).andReturn(true); + PowerMock.expectLastCall(); + EasyMock.expect(worker.connectorTaskConfigs(CONN1, conn1SinkConfig)).andReturn(TASK_CONFIGS); + worker.startTask(EasyMock.eq(TASK1), EasyMock.anyObject(), EasyMock.anyObject(), EasyMock.anyObject(), + EasyMock.eq(herder), EasyMock.eq(TargetState.STARTED)); + PowerMock.expectLastCall().andReturn(true); + member.poll(EasyMock.anyInt()); + PowerMock.expectLastCall(); + + PowerMock.replayAll(); + + herder.tick(); + time.sleep(1000L); + assertStatistics(3, 1, 100, 1000L); + + PowerMock.verifyAll(); + } + + @Test + public void testRebalance() throws Exception { + // Join group and get assignment + EasyMock.expect(member.memberId()).andStubReturn("member"); + EasyMock.expect(member.currentProtocolVersion()).andStubReturn(CONNECT_PROTOCOL_V0); + EasyMock.expect(worker.getPlugins()).andReturn(plugins); + expectRebalance(1, Arrays.asList(CONN1), Arrays.asList(TASK1)); + expectPostRebalanceCatchup(SNAPSHOT); + Capture> onFirstStart = newCapture(); + worker.startConnector(EasyMock.eq(CONN1), EasyMock.anyObject(), EasyMock.anyObject(), + EasyMock.eq(herder), EasyMock.eq(TargetState.STARTED), capture(onFirstStart)); + PowerMock.expectLastCall().andAnswer(() -> { + onFirstStart.getValue().onCompletion(null, TargetState.STARTED); + return true; 
+ }); + member.wakeup(); + PowerMock.expectLastCall(); + EasyMock.expect(worker.isRunning(CONN1)).andReturn(true); + EasyMock.expect(worker.connectorTaskConfigs(CONN1, conn1SinkConfig)).andReturn(TASK_CONFIGS); + worker.startTask(EasyMock.eq(TASK1), EasyMock.anyObject(), EasyMock.anyObject(), EasyMock.anyObject(), + EasyMock.eq(herder), EasyMock.eq(TargetState.STARTED)); + PowerMock.expectLastCall().andReturn(true); + member.poll(EasyMock.anyInt()); + PowerMock.expectLastCall(); + + EasyMock.expect(worker.getPlugins()).andReturn(plugins); + expectRebalance(Arrays.asList(CONN1), Arrays.asList(TASK1), ConnectProtocol.Assignment.NO_ERROR, + 1, Arrays.asList(CONN1), Arrays.asList()); + + // and the new assignment started + Capture> onSecondStart = newCapture(); + worker.startConnector(EasyMock.eq(CONN1), EasyMock.anyObject(), EasyMock.anyObject(), + EasyMock.eq(herder), EasyMock.eq(TargetState.STARTED), capture(onSecondStart)); + PowerMock.expectLastCall().andAnswer(() -> { + onSecondStart.getValue().onCompletion(null, TargetState.STARTED); + return true; + }); + member.wakeup(); + PowerMock.expectLastCall(); + EasyMock.expect(worker.isRunning(CONN1)).andReturn(true); + EasyMock.expect(worker.connectorTaskConfigs(CONN1, conn1SinkConfig)).andReturn(TASK_CONFIGS); + member.poll(EasyMock.anyInt()); + PowerMock.expectLastCall(); + + PowerMock.replayAll(); + + time.sleep(1000L); + assertStatistics(0, 0, 0, Double.POSITIVE_INFINITY); + herder.tick(); + + time.sleep(2000L); + assertStatistics(3, 1, 100, 2000); + herder.tick(); + + time.sleep(3000L); + assertStatistics(3, 2, 100, 3000); + + PowerMock.verifyAll(); + } + + @Test + public void testIncrementalCooperativeRebalanceForNewMember() throws Exception { + connectProtocolVersion = CONNECT_PROTOCOL_V1; + // Join group. First rebalance contains revocations from other members. 
For the new + // member the assignment should be empty + EasyMock.expect(member.memberId()).andStubReturn("member"); + EasyMock.expect(member.currentProtocolVersion()).andStubReturn(CONNECT_PROTOCOL_V1); + expectRebalance(1, Collections.emptyList(), Collections.emptyList()); + expectPostRebalanceCatchup(SNAPSHOT); + + member.poll(EasyMock.anyInt()); + PowerMock.expectLastCall(); + + // The new member got its assignment + expectRebalance(Collections.emptyList(), Collections.emptyList(), + ConnectProtocol.Assignment.NO_ERROR, + 1, Arrays.asList(CONN1), Arrays.asList(TASK1), 0); + + EasyMock.expect(worker.getPlugins()).andReturn(plugins); + // and the new assignment started + Capture> onStart = newCapture(); + worker.startConnector(EasyMock.eq(CONN1), EasyMock.anyObject(), EasyMock.anyObject(), + EasyMock.eq(herder), EasyMock.eq(TargetState.STARTED), capture(onStart)); + PowerMock.expectLastCall().andAnswer(() -> { + onStart.getValue().onCompletion(null, TargetState.STARTED); + return true; + }); + member.wakeup(); + PowerMock.expectLastCall(); + EasyMock.expect(worker.isRunning(CONN1)).andReturn(true); + EasyMock.expect(worker.connectorTaskConfigs(CONN1, conn1SinkConfig)).andReturn(TASK_CONFIGS); + + worker.startTask(EasyMock.eq(TASK1), EasyMock.anyObject(), EasyMock.anyObject(), EasyMock.anyObject(), + EasyMock.eq(herder), EasyMock.eq(TargetState.STARTED)); + PowerMock.expectLastCall().andReturn(true); + member.poll(EasyMock.anyInt()); + PowerMock.expectLastCall(); + + PowerMock.replayAll(); + + time.sleep(1000L); + assertStatistics(0, 0, 0, Double.POSITIVE_INFINITY); + herder.tick(); + + time.sleep(2000L); + assertStatistics(3, 1, 100, 2000); + herder.tick(); + + time.sleep(3000L); + assertStatistics(3, 2, 100, 3000); + + PowerMock.verifyAll(); + } + + @Test + public void testIncrementalCooperativeRebalanceForExistingMember() throws Exception { + connectProtocolVersion = CONNECT_PROTOCOL_V1; + // Join group. First rebalance contains revocations because a new member joined. + EasyMock.expect(member.memberId()).andStubReturn("member"); + EasyMock.expect(member.currentProtocolVersion()).andStubReturn(CONNECT_PROTOCOL_V1); + expectRebalance(Arrays.asList(CONN1), Arrays.asList(TASK1), + ConnectProtocol.Assignment.NO_ERROR, 1, + Collections.emptyList(), Collections.emptyList(), 0); + member.requestRejoin(); + PowerMock.expectLastCall(); + + // In the second rebalance the new member gets its assignment and this member has no + // assignments or revocations + expectRebalance(1, Collections.emptyList(), Collections.emptyList()); + + member.poll(EasyMock.anyInt()); + PowerMock.expectLastCall(); + + PowerMock.replayAll(); + + herder.configState = SNAPSHOT; + time.sleep(1000L); + assertStatistics(0, 0, 0, Double.POSITIVE_INFINITY); + herder.tick(); + + time.sleep(2000L); + assertStatistics(3, 1, 100, 2000); + herder.tick(); + + time.sleep(3000L); + assertStatistics(3, 2, 100, 3000); + + PowerMock.verifyAll(); + } + + @Test + public void testIncrementalCooperativeRebalanceWithDelay() throws Exception { + connectProtocolVersion = CONNECT_PROTOCOL_V1; + // Join group. 
First rebalance contains some assignments but also a delay, because a + // member was detected missing + int rebalanceDelay = 10_000; + + EasyMock.expect(member.memberId()).andStubReturn("member"); + EasyMock.expect(member.currentProtocolVersion()).andStubReturn(CONNECT_PROTOCOL_V1); + expectRebalance(Collections.emptyList(), Collections.emptyList(), + ConnectProtocol.Assignment.NO_ERROR, 1, + Collections.emptyList(), Arrays.asList(TASK2), + rebalanceDelay); + expectPostRebalanceCatchup(SNAPSHOT); + + worker.startTask(EasyMock.eq(TASK2), EasyMock.anyObject(), EasyMock.anyObject(), EasyMock.anyObject(), + EasyMock.eq(herder), EasyMock.eq(TargetState.STARTED)); + PowerMock.expectLastCall().andReturn(true); + member.poll(EasyMock.anyInt()); + PowerMock.expectLastCall().andAnswer(() -> { + time.sleep(9900L); + return null; + }); + + // Request to re-join because the scheduled rebalance delay has been reached + member.requestRejoin(); + PowerMock.expectLastCall(); + + // The member got its assignment and revocation + expectRebalance(Collections.emptyList(), Collections.emptyList(), + ConnectProtocol.Assignment.NO_ERROR, + 1, Arrays.asList(CONN1), Arrays.asList(TASK1), 0); + + EasyMock.expect(worker.getPlugins()).andReturn(plugins); + // and the new assignment started + Capture> onStart = newCapture(); + worker.startConnector(EasyMock.eq(CONN1), EasyMock.anyObject(), EasyMock.anyObject(), + EasyMock.eq(herder), EasyMock.eq(TargetState.STARTED), capture(onStart)); + PowerMock.expectLastCall().andAnswer(() -> { + onStart.getValue().onCompletion(null, TargetState.STARTED); + return true; + }); + member.wakeup(); + PowerMock.expectLastCall(); + EasyMock.expect(worker.isRunning(CONN1)).andReturn(true); + EasyMock.expect(worker.connectorTaskConfigs(CONN1, conn1SinkConfig)).andReturn(TASK_CONFIGS); + + worker.startTask(EasyMock.eq(TASK1), EasyMock.anyObject(), EasyMock.anyObject(), EasyMock.anyObject(), + EasyMock.eq(herder), EasyMock.eq(TargetState.STARTED)); + PowerMock.expectLastCall().andReturn(true); + member.poll(EasyMock.anyInt()); + PowerMock.expectLastCall(); + + PowerMock.replayAll(); + + time.sleep(1000L); + assertStatistics(0, 0, 0, Double.POSITIVE_INFINITY); + herder.tick(); + + herder.tick(); + + time.sleep(2000L); + assertStatistics(3, 2, 100, 2000); + + PowerMock.verifyAll(); + } + + @Test + public void testRebalanceFailedConnector() throws Exception { + // Join group and get assignment + EasyMock.expect(member.memberId()).andStubReturn("member"); + EasyMock.expect(member.currentProtocolVersion()).andStubReturn(CONNECT_PROTOCOL_V0); + EasyMock.expect(worker.getPlugins()).andReturn(plugins); + expectRebalance(1, Arrays.asList(CONN1), Arrays.asList(TASK1)); + expectPostRebalanceCatchup(SNAPSHOT); + Capture> onFirstStart = newCapture(); + worker.startConnector(EasyMock.eq(CONN1), EasyMock.anyObject(), EasyMock.anyObject(), + EasyMock.eq(herder), EasyMock.eq(TargetState.STARTED), capture(onFirstStart)); + PowerMock.expectLastCall().andAnswer(() -> { + onFirstStart.getValue().onCompletion(null, TargetState.STARTED); + return true; + }); + member.wakeup(); + PowerMock.expectLastCall(); + EasyMock.expect(worker.isRunning(CONN1)).andReturn(true); + EasyMock.expect(worker.connectorTaskConfigs(CONN1, conn1SinkConfig)).andReturn(TASK_CONFIGS); + worker.startTask(EasyMock.eq(TASK1), EasyMock.anyObject(), EasyMock.anyObject(), EasyMock.anyObject(), + EasyMock.eq(herder), EasyMock.eq(TargetState.STARTED)); + PowerMock.expectLastCall().andReturn(true); + member.poll(EasyMock.anyInt()); + 
PowerMock.expectLastCall(); + + expectRebalance(Arrays.asList(CONN1), Arrays.asList(TASK1), ConnectProtocol.Assignment.NO_ERROR, + 1, Arrays.asList(CONN1), Arrays.asList()); + + // and the new assignment started + Capture> onSecondStart = newCapture(); + worker.startConnector(EasyMock.eq(CONN1), EasyMock.anyObject(), EasyMock.anyObject(), + EasyMock.eq(herder), EasyMock.eq(TargetState.STARTED), capture(onSecondStart)); + PowerMock.expectLastCall().andAnswer(() -> { + onSecondStart.getValue().onCompletion(null, TargetState.STARTED); + return true; + }); + member.wakeup(); + PowerMock.expectLastCall(); + EasyMock.expect(worker.isRunning(CONN1)).andReturn(false); + + // worker is not running, so we should see no call to connectorTaskConfigs() + + member.poll(EasyMock.anyInt()); + PowerMock.expectLastCall(); + + PowerMock.replayAll(); + + herder.tick(); + time.sleep(1000L); + assertStatistics(3, 1, 100, 1000L); + + herder.tick(); + time.sleep(2000L); + assertStatistics(3, 2, 100, 2000L); + + PowerMock.verifyAll(); + } + + @Test + public void testRevoke() throws TimeoutException { + revokeAndReassign(false); + } + + @Test + public void testIncompleteRebalanceBeforeRevoke() throws TimeoutException { + revokeAndReassign(true); + } + + public void revokeAndReassign(boolean incompleteRebalance) throws TimeoutException { + connectProtocolVersion = CONNECT_PROTOCOL_V1; + int configOffset = 1; + + // Join group and get initial assignment + EasyMock.expect(member.memberId()).andStubReturn("member"); + EasyMock.expect(member.currentProtocolVersion()).andStubReturn(connectProtocolVersion); + EasyMock.expect(worker.getPlugins()).andReturn(plugins); + // The lists need to be mutable because assignments might be removed + expectRebalance(configOffset, new ArrayList<>(singletonList(CONN1)), new ArrayList<>(singletonList(TASK1))); + expectPostRebalanceCatchup(SNAPSHOT); + Capture> onFirstStart = newCapture(); + worker.startConnector(EasyMock.eq(CONN1), EasyMock.anyObject(), EasyMock.anyObject(), + EasyMock.eq(herder), EasyMock.eq(TargetState.STARTED), capture(onFirstStart)); + PowerMock.expectLastCall().andAnswer(() -> { + onFirstStart.getValue().onCompletion(null, TargetState.STARTED); + return true; + }); + member.wakeup(); + PowerMock.expectLastCall(); + EasyMock.expect(worker.isRunning(CONN1)).andReturn(true); + EasyMock.expect(worker.connectorTaskConfigs(CONN1, conn1SinkConfig)).andReturn(TASK_CONFIGS); + worker.startTask(EasyMock.eq(TASK1), EasyMock.anyObject(), EasyMock.anyObject(), EasyMock.anyObject(), + EasyMock.eq(herder), EasyMock.eq(TargetState.STARTED)); + PowerMock.expectLastCall().andReturn(true); + member.poll(EasyMock.anyInt()); + PowerMock.expectLastCall(); + + // worker is stable with an existing set of tasks + + if (incompleteRebalance) { + // Perform a partial re-balance just prior to the revocation + // bump the configOffset to trigger reading the config topic to the end + configOffset++; + expectRebalance(configOffset, Arrays.asList(), Arrays.asList()); + // give it the wrong snapshot, as if we're out of sync/can't reach the broker + expectPostRebalanceCatchup(SNAPSHOT); + member.requestRejoin(); + PowerMock.expectLastCall(); + // tick exits early because we failed, and doesn't do the poll at the end of the method + // the worker did not startWork or reset the rebalanceResolved flag + } + + // Revoke the connector in the next rebalance + expectRebalance(Arrays.asList(CONN1), Arrays.asList(), + ConnectProtocol.Assignment.NO_ERROR, configOffset, Arrays.asList(), + Arrays.asList()); + + 
if (incompleteRebalance) { + // Same as SNAPSHOT, except with an updated offset + // Allow the task to read to the end of the topic and complete the rebalance + ClusterConfigState secondSnapshot = new ClusterConfigState( + configOffset, null, Collections.singletonMap(CONN1, 3), + Collections.singletonMap(CONN1, CONN1_CONFIG), Collections.singletonMap(CONN1, TargetState.STARTED), + TASK_CONFIGS_MAP, Collections.emptySet()); + expectPostRebalanceCatchup(secondSnapshot); + } + member.requestRejoin(); + PowerMock.expectLastCall(); + + // re-assign the connector back to the same worker to ensure state was cleaned up + expectRebalance(configOffset, Arrays.asList(CONN1), Arrays.asList()); + EasyMock.expect(worker.getPlugins()).andReturn(plugins); + Capture> onSecondStart = newCapture(); + worker.startConnector(EasyMock.eq(CONN1), EasyMock.anyObject(), + EasyMock.anyObject(), + EasyMock.eq(herder), EasyMock.eq(TargetState.STARTED), capture(onSecondStart)); + PowerMock.expectLastCall().andAnswer(() -> { + onSecondStart.getValue().onCompletion(null, TargetState.STARTED); + return true; + }); + member.wakeup(); + PowerMock.expectLastCall(); + EasyMock.expect(worker.isRunning(CONN1)).andReturn(true); + EasyMock.expect(worker.connectorTaskConfigs(CONN1, conn1SinkConfig)) + .andReturn(TASK_CONFIGS); + member.poll(EasyMock.anyInt()); + PowerMock.expectLastCall(); + + PowerMock.replayAll(); + + herder.tick(); + if (incompleteRebalance) { + herder.tick(); + } + herder.tick(); + herder.tick(); + + PowerMock.verifyAll(); + } + + @Test + public void testHaltCleansUpWorker() { + EasyMock.expect(worker.connectorNames()).andReturn(Collections.singleton(CONN1)); + worker.stopAndAwaitConnector(CONN1); + PowerMock.expectLastCall(); + EasyMock.expect(worker.taskIds()).andReturn(Collections.singleton(TASK1)); + worker.stopAndAwaitTask(TASK1); + PowerMock.expectLastCall(); + member.stop(); + PowerMock.expectLastCall(); + configBackingStore.stop(); + PowerMock.expectLastCall(); + statusBackingStore.stop(); + PowerMock.expectLastCall(); + worker.stop(); + PowerMock.expectLastCall(); + + PowerMock.replayAll(); + + herder.halt(); + + PowerMock.verifyAll(); + } + + @Test + public void testCreateConnector() throws Exception { + EasyMock.expect(member.memberId()).andStubReturn("leader"); + EasyMock.expect(member.currentProtocolVersion()).andStubReturn(CONNECT_PROTOCOL_V0); + expectRebalance(1, Collections.emptyList(), Collections.emptyList()); + expectPostRebalanceCatchup(SNAPSHOT); + + member.wakeup(); + PowerMock.expectLastCall(); + + // mock the actual validation since its asynchronous nature is difficult to test and should + // be covered sufficiently by the unit tests for the AbstractHerder class + Capture> validateCallback = newCapture(); + herder.validateConnectorConfig(EasyMock.eq(CONN2_CONFIG), capture(validateCallback)); + PowerMock.expectLastCall().andAnswer(() -> { + validateCallback.getValue().onCompletion(null, CONN2_CONFIG_INFOS); + return null; + }); + + // CONN2 is new, should succeed + configBackingStore.putConnectorConfig(CONN2, CONN2_CONFIG); + PowerMock.expectLastCall(); + ConnectorInfo info = new ConnectorInfo(CONN2, CONN2_CONFIG, Collections.emptyList(), + ConnectorType.SOURCE); + putConnectorCallback.onCompletion(null, new Herder.Created<>(true, info)); + PowerMock.expectLastCall(); + member.poll(EasyMock.anyInt()); + PowerMock.expectLastCall(); + // These will occur just before/during the second tick + member.wakeup(); + PowerMock.expectLastCall(); + member.ensureActive(); + 
PowerMock.expectLastCall();
+        member.poll(EasyMock.anyInt());
+        PowerMock.expectLastCall();
+
+        // No immediate action besides this -- change will be picked up via the config log
+
+        PowerMock.replayAll();
+
+        herder.putConnectorConfig(CONN2, CONN2_CONFIG, false, putConnectorCallback);
+        // First tick runs the initial herder request, which issues an asynchronous request for
+        // connector validation
+        herder.tick();
+
+        // Once that validation is complete, another request is added to the herder request queue
+        // for actually performing the config write; this tick is for that request
+        herder.tick();
+
+        time.sleep(1000L);
+        assertStatistics(3, 1, 100, 1000L);
+
+        PowerMock.verifyAll();
+    }
+
+    @Test
+    public void testCreateConnectorFailedValidation() throws Exception {
+        EasyMock.expect(member.memberId()).andStubReturn("leader");
+        EasyMock.expect(member.currentProtocolVersion()).andStubReturn(CONNECT_PROTOCOL_V0);
+        expectRebalance(1, Collections.emptyList(), Collections.emptyList());
+        expectPostRebalanceCatchup(SNAPSHOT);
+
+        HashMap<String, String> config = new HashMap<>(CONN2_CONFIG);
+        config.remove(ConnectorConfig.NAME_CONFIG);
+
+        member.wakeup();
+        PowerMock.expectLastCall();
+
+        // mock the actual validation since its asynchronous nature is difficult to test and should
+        // be covered sufficiently by the unit tests for the AbstractHerder class
+        Capture<Callback<ConfigInfos>> validateCallback = newCapture();
+        herder.validateConnectorConfig(EasyMock.eq(config), capture(validateCallback));
+        PowerMock.expectLastCall().andAnswer(() -> {
+            // CONN2 creation should fail
+            validateCallback.getValue().onCompletion(null, CONN2_INVALID_CONFIG_INFOS);
+            return null;
+        });
+
+        Capture<Throwable> error = newCapture();
+        putConnectorCallback.onCompletion(capture(error), EasyMock.isNull());
+        PowerMock.expectLastCall();
+
+        member.poll(EasyMock.anyInt());
+        PowerMock.expectLastCall();
+
+        // These will occur just before/during the second tick
+        member.wakeup();
+        PowerMock.expectLastCall();
+
+        member.ensureActive();
+        PowerMock.expectLastCall();
+        member.poll(EasyMock.anyInt());
+        PowerMock.expectLastCall();
+
+        // No immediate action besides this -- change will be picked up via the config log
+
+        PowerMock.replayAll();
+
+        herder.putConnectorConfig(CONN2, config, false, putConnectorCallback);
+        herder.tick();
+        herder.tick();
+
+        assertTrue(error.hasCaptured());
+        assertTrue(error.getValue() instanceof BadRequestException);
+
+        time.sleep(1000L);
+        assertStatistics(3, 1, 100, 1000L);
+
+        PowerMock.verifyAll();
+    }
+
+    @SuppressWarnings("unchecked")
+    @Test
+    public void testConnectorNameConflictsWithWorkerGroupId() throws Exception {
+        Map<String, String> config = new HashMap<>(CONN2_CONFIG);
+        config.put(ConnectorConfig.NAME_CONFIG, "test-group");
+
+        Connector connectorMock = PowerMock.createMock(SinkConnector.class);
+
+        // CONN2 creation should fail because the worker group id (connect-test-group) conflicts with
+        // the consumer group id we would use for this sink
+        Map<String, ConfigValue> validatedConfigs =
+                herder.validateBasicConnectorConfig(connectorMock, ConnectorConfig.configDef(), config);
+
+        ConfigValue nameConfig = validatedConfigs.get(ConnectorConfig.NAME_CONFIG);
+        assertNotNull(nameConfig.errorMessages());
+        assertFalse(nameConfig.errorMessages().isEmpty());
+    }
+
+    @Test
+    public void testCreateConnectorAlreadyExists() throws Exception {
+        EasyMock.expect(member.memberId()).andStubReturn("leader");
+        EasyMock.expect(member.currentProtocolVersion()).andStubReturn(CONNECT_PROTOCOL_V0);
+
+        // mock the actual validation since its asynchronous nature is
difficult to test and should + // be covered sufficiently by the unit tests for the AbstractHerder class + Capture> validateCallback = newCapture(); + herder.validateConnectorConfig(EasyMock.eq(CONN1_CONFIG), capture(validateCallback)); + PowerMock.expectLastCall().andAnswer(() -> { + validateCallback.getValue().onCompletion(null, CONN1_CONFIG_INFOS); + return null; + }); + + expectRebalance(1, Collections.emptyList(), Collections.emptyList()); + expectPostRebalanceCatchup(SNAPSHOT); + + member.wakeup(); + PowerMock.expectLastCall(); + // CONN1 already exists + putConnectorCallback.onCompletion(EasyMock.anyObject(), EasyMock.isNull()); + PowerMock.expectLastCall(); + member.poll(EasyMock.anyInt()); + PowerMock.expectLastCall(); + + // These will occur just before/during the second tick + member.wakeup(); + PowerMock.expectLastCall(); + member.ensureActive(); + PowerMock.expectLastCall(); + member.poll(EasyMock.anyInt()); + PowerMock.expectLastCall(); + + // No immediate action besides this -- change will be picked up via the config log + + PowerMock.replayAll(); + + herder.putConnectorConfig(CONN1, CONN1_CONFIG, false, putConnectorCallback); + herder.tick(); + herder.tick(); + + time.sleep(1000L); + assertStatistics(3, 1, 100, 1000L); + + PowerMock.verifyAll(); + } + + @Test + public void testDestroyConnector() throws Exception { + EasyMock.expect(member.memberId()).andStubReturn("leader"); + EasyMock.expect(member.currentProtocolVersion()).andStubReturn(CONNECT_PROTOCOL_V0); + // Start with one connector + EasyMock.expect(worker.getPlugins()).andReturn(plugins); + expectRebalance(1, Arrays.asList(CONN1), Collections.emptyList()); + expectPostRebalanceCatchup(SNAPSHOT); + Capture> onStart = newCapture(); + worker.startConnector(EasyMock.eq(CONN1), EasyMock.anyObject(), EasyMock.anyObject(), + EasyMock.eq(herder), EasyMock.eq(TargetState.STARTED), capture(onStart)); + PowerMock.expectLastCall().andAnswer(() -> { + onStart.getValue().onCompletion(null, TargetState.STARTED); + return true; + }); + EasyMock.expect(worker.isRunning(CONN1)).andReturn(true); + EasyMock.expect(worker.connectorTaskConfigs(CONN1, conn1SinkConfig)).andReturn(TASK_CONFIGS); + + // And delete the connector + member.wakeup(); + PowerMock.expectLastCall(); + configBackingStore.removeConnectorConfig(CONN1); + PowerMock.expectLastCall(); + putConnectorCallback.onCompletion(null, new Herder.Created<>(false, null)); + PowerMock.expectLastCall(); + member.poll(EasyMock.anyInt()); + PowerMock.expectLastCall(); + + // The change eventually is reflected to the config topic and the deleted connector and + // tasks are revoked + member.wakeup(); + PowerMock.expectLastCall(); + TopicStatus fooStatus = new TopicStatus(FOO_TOPIC, CONN1, 0, time.milliseconds()); + TopicStatus barStatus = new TopicStatus(BAR_TOPIC, CONN1, 0, time.milliseconds()); + EasyMock.expect(statusBackingStore.getAllTopics(EasyMock.eq(CONN1))).andReturn(new HashSet<>(Arrays.asList(fooStatus, barStatus))).times(2); + statusBackingStore.deleteTopic(EasyMock.eq(CONN1), EasyMock.eq(FOO_TOPIC)); + PowerMock.expectLastCall().times(2); + statusBackingStore.deleteTopic(EasyMock.eq(CONN1), EasyMock.eq(BAR_TOPIC)); + PowerMock.expectLastCall().times(2); + expectRebalance(Arrays.asList(CONN1), Arrays.asList(TASK1), + ConnectProtocol.Assignment.NO_ERROR, 2, + Collections.emptyList(), Collections.emptyList(), 0); + expectPostRebalanceCatchup(ClusterConfigState.EMPTY); + member.requestRejoin(); + PowerMock.expectLastCall(); + PowerMock.replayAll(); + + 
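+        // The first tick below joins the group and services the delete request (removing the connector
+        // config and completing the callback); once the removal is read back via onConnectorConfigRemove,
+        // the second tick rebalances, revoking the deleted connector and task and cleaning up its topic status records.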
herder.deleteConnectorConfig(CONN1, putConnectorCallback);
+        herder.tick();
+
+        time.sleep(1000L);
+        assertStatistics("leaderUrl", false, 3, 1, 100, 1000L);
+
+        configUpdateListener.onConnectorConfigRemove(CONN1); // read updated config that removes the connector
+        herder.configState = ClusterConfigState.EMPTY;
+        herder.tick();
+        time.sleep(1000L);
+        assertStatistics("leaderUrl", true, 3, 1, 100, 2100L);
+
+        PowerMock.verifyAll();
+    }
+
+    @Test
+    public void testRestartConnector() throws Exception {
+        EasyMock.expect(worker.connectorTaskConfigs(CONN1, conn1SinkConfig)).andStubReturn(TASK_CONFIGS);
+
+        // get the initial assignment
+        EasyMock.expect(member.memberId()).andStubReturn("leader");
+        EasyMock.expect(member.currentProtocolVersion()).andStubReturn(CONNECT_PROTOCOL_V0);
+        EasyMock.expect(worker.getPlugins()).andReturn(plugins);
+        expectRebalance(1, singletonList(CONN1), Collections.emptyList());
+        expectPostRebalanceCatchup(SNAPSHOT);
+        member.poll(EasyMock.anyInt());
+        PowerMock.expectLastCall();
+        Capture<Callback<TargetState>> onFirstStart = newCapture();
+        worker.startConnector(EasyMock.eq(CONN1), EasyMock.anyObject(), EasyMock.anyObject(),
+                EasyMock.eq(herder), EasyMock.eq(TargetState.STARTED), capture(onFirstStart));
+        PowerMock.expectLastCall().andAnswer(() -> {
+            onFirstStart.getValue().onCompletion(null, TargetState.STARTED);
+            return true;
+        });
+        member.wakeup();
+        PowerMock.expectLastCall();
+        EasyMock.expect(worker.isRunning(CONN1)).andReturn(true);
+
+        // now handle the connector restart
+        member.wakeup();
+        PowerMock.expectLastCall();
+        member.ensureActive();
+        PowerMock.expectLastCall();
+        member.poll(EasyMock.anyInt());
+        PowerMock.expectLastCall();
+
+        worker.stopAndAwaitConnector(CONN1);
+        PowerMock.expectLastCall();
+        EasyMock.expect(worker.getPlugins()).andReturn(plugins);
+        Capture<Callback<TargetState>> onSecondStart = newCapture();
+        worker.startConnector(EasyMock.eq(CONN1), EasyMock.anyObject(), EasyMock.anyObject(),
+                EasyMock.eq(herder), EasyMock.eq(TargetState.STARTED), capture(onSecondStart));
+        PowerMock.expectLastCall().andAnswer(() -> {
+            onSecondStart.getValue().onCompletion(null, TargetState.STARTED);
+            return true;
+        });
+        member.wakeup();
+        PowerMock.expectLastCall();
+        EasyMock.expect(worker.isRunning(CONN1)).andReturn(true);
+
+        PowerMock.replayAll();
+
+        herder.tick();
+        FutureCallback<Void> callback = new FutureCallback<>();
+        herder.restartConnector(CONN1, callback);
+        herder.tick();
+        callback.get(1000L, TimeUnit.MILLISECONDS);
+
+        PowerMock.verifyAll();
+    }
+
+    @Test
+    public void testRestartUnknownConnector() throws Exception {
+        // get the initial assignment
+        EasyMock.expect(member.memberId()).andStubReturn("leader");
+        EasyMock.expect(member.currentProtocolVersion()).andStubReturn(CONNECT_PROTOCOL_V0);
+        expectRebalance(1, Collections.emptyList(), Collections.emptyList());
+        expectPostRebalanceCatchup(SNAPSHOT);
+        member.poll(EasyMock.anyInt());
+        PowerMock.expectLastCall();
+
+        // now handle the connector restart
+        member.wakeup();
+        PowerMock.expectLastCall();
+        member.ensureActive();
+        PowerMock.expectLastCall();
+        member.poll(EasyMock.anyInt());
+        PowerMock.expectLastCall();
+
+        PowerMock.replayAll();
+
+        herder.tick();
+        FutureCallback<Void> callback = new FutureCallback<>();
+        herder.restartConnector(CONN2, callback);
+        herder.tick();
+        try {
+            callback.get(1000L, TimeUnit.MILLISECONDS);
+            fail("Expected NotFoundException to be raised");
+        } catch (ExecutionException e) {
+            assertTrue(e.getCause() instanceof NotFoundException);
+        }
+
+        PowerMock.verifyAll();
+    }
+
+    @Test
+    public
void testRestartConnectorRedirectToLeader() throws Exception { + // get the initial assignment + EasyMock.expect(member.memberId()).andStubReturn("member"); + EasyMock.expect(member.currentProtocolVersion()).andStubReturn(CONNECT_PROTOCOL_V0); + expectRebalance(1, Collections.emptyList(), Collections.emptyList()); + expectPostRebalanceCatchup(SNAPSHOT); + member.poll(EasyMock.anyInt()); + PowerMock.expectLastCall(); + + // now handle the connector restart + member.wakeup(); + PowerMock.expectLastCall(); + member.ensureActive(); + PowerMock.expectLastCall(); + member.poll(EasyMock.anyInt()); + PowerMock.expectLastCall(); + + PowerMock.replayAll(); + + herder.tick(); + FutureCallback callback = new FutureCallback<>(); + herder.restartConnector(CONN1, callback); + herder.tick(); + + try { + callback.get(1000L, TimeUnit.MILLISECONDS); + fail("Expected NotLeaderException to be raised"); + } catch (ExecutionException e) { + assertTrue(e.getCause() instanceof NotLeaderException); + } + + PowerMock.verifyAll(); + } + + @Test + public void testRestartConnectorRedirectToOwner() throws Exception { + // get the initial assignment + EasyMock.expect(member.memberId()).andStubReturn("leader"); + EasyMock.expect(member.currentProtocolVersion()).andStubReturn(CONNECT_PROTOCOL_V0); + expectRebalance(1, Collections.emptyList(), Collections.emptyList()); + expectPostRebalanceCatchup(SNAPSHOT); + member.poll(EasyMock.anyInt()); + PowerMock.expectLastCall(); + + // now handle the connector restart + member.wakeup(); + PowerMock.expectLastCall(); + member.ensureActive(); + PowerMock.expectLastCall(); + member.poll(EasyMock.anyInt()); + PowerMock.expectLastCall(); + + String ownerUrl = "ownerUrl"; + EasyMock.expect(member.ownerUrl(CONN1)).andReturn(ownerUrl); + + PowerMock.replayAll(); + + herder.tick(); + time.sleep(1000L); + assertStatistics(3, 1, 100, 1000L); + + FutureCallback callback = new FutureCallback<>(); + herder.restartConnector(CONN1, callback); + herder.tick(); + + time.sleep(2000L); + assertStatistics(3, 1, 100, 3000L); + + try { + callback.get(1000L, TimeUnit.MILLISECONDS); + fail("Expected NotLeaderException to be raised"); + } catch (ExecutionException e) { + assertTrue(e.getCause() instanceof NotAssignedException); + NotAssignedException notAssignedException = (NotAssignedException) e.getCause(); + assertEquals(ownerUrl, notAssignedException.forwardUrl()); + } + + PowerMock.verifyAll(); + } + + @Test + public void testRestartConnectorAndTasksUnknownConnector() throws Exception { + String connectorName = "UnknownConnector"; + RestartRequest restartRequest = new RestartRequest(connectorName, false, true); + + // get the initial assignment + EasyMock.expect(member.memberId()).andStubReturn("leader"); + EasyMock.expect(member.currentProtocolVersion()).andStubReturn(CONNECT_PROTOCOL_V0); + expectRebalance(1, Collections.emptyList(), Collections.emptyList()); + expectPostRebalanceCatchup(SNAPSHOT); + member.poll(EasyMock.anyInt()); + PowerMock.expectLastCall(); + + // now handle the connector restart + member.wakeup(); + PowerMock.expectLastCall(); + member.ensureActive(); + PowerMock.expectLastCall(); + member.poll(EasyMock.anyInt()); + PowerMock.expectLastCall(); + + PowerMock.replayAll(); + + herder.tick(); + FutureCallback callback = new FutureCallback<>(); + herder.restartConnectorAndTasks(restartRequest, callback); + herder.tick(); + ExecutionException ee = assertThrows(ExecutionException.class, () -> callback.get(1000L, TimeUnit.MILLISECONDS)); + assertTrue(ee.getCause() instanceof 
NotFoundException); + assertTrue(ee.getMessage().contains("Unknown connector:")); + + PowerMock.verifyAll(); + } + + @Test + public void testRestartConnectorAndTasksNotLeader() throws Exception { + RestartRequest restartRequest = new RestartRequest(CONN1, false, true); + + // get the initial assignment + EasyMock.expect(member.memberId()).andStubReturn("member"); + EasyMock.expect(member.currentProtocolVersion()).andStubReturn(CONNECT_PROTOCOL_V0); + expectRebalance(1, Collections.emptyList(), Collections.emptyList()); + expectPostRebalanceCatchup(SNAPSHOT); + member.poll(EasyMock.anyInt()); + PowerMock.expectLastCall(); + + // now handle the connector restart + member.wakeup(); + PowerMock.expectLastCall(); + member.ensureActive(); + PowerMock.expectLastCall(); + member.poll(EasyMock.anyInt()); + PowerMock.expectLastCall(); + + PowerMock.replayAll(); + + herder.tick(); + FutureCallback callback = new FutureCallback<>(); + herder.restartConnectorAndTasks(restartRequest, callback); + herder.tick(); + ExecutionException ee = assertThrows(ExecutionException.class, () -> callback.get(1000L, TimeUnit.MILLISECONDS)); + assertTrue(ee.getCause() instanceof NotLeaderException); + + PowerMock.verifyAll(); + } + + @Test + public void testRestartConnectorAndTasksUnknownStatus() throws Exception { + RestartRequest restartRequest = new RestartRequest(CONN1, false, true); + EasyMock.expect(herder.buildRestartPlan(restartRequest)).andReturn(Optional.empty()).anyTimes(); + + configBackingStore.putRestartRequest(restartRequest); + PowerMock.expectLastCall(); + + // get the initial assignment + EasyMock.expect(member.memberId()).andStubReturn("leader"); + EasyMock.expect(member.currentProtocolVersion()).andStubReturn(CONNECT_PROTOCOL_V0); + expectRebalance(1, Collections.emptyList(), Collections.emptyList()); + expectPostRebalanceCatchup(SNAPSHOT); + member.poll(EasyMock.anyInt()); + PowerMock.expectLastCall(); + + // now handle the connector restart + member.wakeup(); + PowerMock.expectLastCall(); + member.ensureActive(); + PowerMock.expectLastCall(); + member.poll(EasyMock.anyInt()); + PowerMock.expectLastCall(); + + PowerMock.replayAll(); + + herder.tick(); + FutureCallback callback = new FutureCallback<>(); + herder.restartConnectorAndTasks(restartRequest, callback); + herder.tick(); + ExecutionException ee = assertThrows(ExecutionException.class, () -> callback.get(1000L, TimeUnit.MILLISECONDS)); + assertTrue(ee.getCause() instanceof NotFoundException); + assertTrue(ee.getMessage().contains("Status for connector")); + PowerMock.verifyAll(); + } + + @Test + public void testRestartConnectorAndTasksSuccess() throws Exception { + RestartPlan restartPlan = PowerMock.createMock(RestartPlan.class); + ConnectorStateInfo connectorStateInfo = PowerMock.createMock(ConnectorStateInfo.class); + EasyMock.expect(restartPlan.restartConnectorStateInfo()).andReturn(connectorStateInfo).anyTimes(); + + RestartRequest restartRequest = new RestartRequest(CONN1, false, true); + EasyMock.expect(herder.buildRestartPlan(restartRequest)).andReturn(Optional.of(restartPlan)).anyTimes(); + + configBackingStore.putRestartRequest(restartRequest); + PowerMock.expectLastCall(); + + // get the initial assignment + EasyMock.expect(member.memberId()).andStubReturn("leader"); + EasyMock.expect(member.currentProtocolVersion()).andStubReturn(CONNECT_PROTOCOL_V0); + expectRebalance(1, Collections.emptyList(), Collections.emptyList()); + expectPostRebalanceCatchup(SNAPSHOT); + member.poll(EasyMock.anyInt()); + PowerMock.expectLastCall(); + + // 
now handle the connector restart + member.wakeup(); + PowerMock.expectLastCall(); + member.ensureActive(); + PowerMock.expectLastCall(); + member.poll(EasyMock.anyInt()); + PowerMock.expectLastCall(); + + PowerMock.replayAll(); + + herder.tick(); + FutureCallback callback = new FutureCallback<>(); + herder.restartConnectorAndTasks(restartRequest, callback); + herder.tick(); + assertEquals(connectorStateInfo, callback.get(1000L, TimeUnit.MILLISECONDS)); + PowerMock.verifyAll(); + } + + @Test + public void testDoRestartConnectorAndTasksEmptyPlan() throws Exception { + RestartRequest restartRequest = new RestartRequest(CONN1, false, true); + EasyMock.expect(herder.buildRestartPlan(restartRequest)).andReturn(Optional.empty()).anyTimes(); + + PowerMock.replayAll(); + + herder.doRestartConnectorAndTasks(restartRequest); + PowerMock.verifyAll(); + } + + @Test + public void testDoRestartConnectorAndTasksNoAssignments() throws Exception { + ConnectorTaskId taskId = new ConnectorTaskId(CONN1, 0); + RestartRequest restartRequest = new RestartRequest(CONN1, false, true); + RestartPlan restartPlan = PowerMock.createMock(RestartPlan.class); + EasyMock.expect(restartPlan.shouldRestartConnector()).andReturn(true).anyTimes(); + EasyMock.expect(restartPlan.shouldRestartTasks()).andReturn(true).anyTimes(); + EasyMock.expect(restartPlan.taskIdsToRestart()).andReturn(Collections.singletonList(taskId)).anyTimes(); + + EasyMock.expect(herder.buildRestartPlan(restartRequest)).andReturn(Optional.of(restartPlan)).anyTimes(); + + PowerMock.replayAll(); + herder.assignment = ExtendedAssignment.empty(); + herder.doRestartConnectorAndTasks(restartRequest); + PowerMock.verifyAll(); + } + + @Test + public void testDoRestartConnectorAndTasksOnlyConnector() throws Exception { + ConnectorTaskId taskId = new ConnectorTaskId(CONN1, 0); + RestartRequest restartRequest = new RestartRequest(CONN1, false, true); + RestartPlan restartPlan = PowerMock.createMock(RestartPlan.class); + EasyMock.expect(restartPlan.shouldRestartConnector()).andReturn(true).anyTimes(); + EasyMock.expect(restartPlan.shouldRestartTasks()).andReturn(true).anyTimes(); + EasyMock.expect(restartPlan.taskIdsToRestart()).andReturn(Collections.singletonList(taskId)).anyTimes(); + + EasyMock.expect(herder.buildRestartPlan(restartRequest)).andReturn(Optional.of(restartPlan)).anyTimes(); + + herder.assignment = PowerMock.createMock(ExtendedAssignment.class); + EasyMock.expect(herder.assignment.connectors()).andReturn(Collections.singletonList(CONN1)).anyTimes(); + EasyMock.expect(herder.assignment.tasks()).andReturn(Collections.emptyList()).anyTimes(); + + worker.stopAndAwaitConnector(CONN1); + PowerMock.expectLastCall(); + + Capture> stateCallback = newCapture(); + worker.startConnector(EasyMock.eq(CONN1), EasyMock.anyObject(), EasyMock.anyObject(), + EasyMock.eq(herder), EasyMock.anyObject(TargetState.class), capture(stateCallback)); + + + herder.onRestart(CONN1); + EasyMock.expectLastCall(); + + PowerMock.replayAll(); + herder.doRestartConnectorAndTasks(restartRequest); + PowerMock.verifyAll(); + } + + @Test + public void testDoRestartConnectorAndTasksOnlyTasks() throws Exception { + ConnectorTaskId taskId = new ConnectorTaskId(CONN1, 0); + RestartRequest restartRequest = new RestartRequest(CONN1, false, true); + RestartPlan restartPlan = PowerMock.createMock(RestartPlan.class); + EasyMock.expect(restartPlan.shouldRestartConnector()).andReturn(true).anyTimes(); + EasyMock.expect(restartPlan.shouldRestartTasks()).andReturn(true).anyTimes(); + 
EasyMock.expect(restartPlan.taskIdsToRestart()).andReturn(Collections.singletonList(taskId)).anyTimes(); + EasyMock.expect(restartPlan.restartTaskCount()).andReturn(1).anyTimes(); + EasyMock.expect(restartPlan.totalTaskCount()).andReturn(1).anyTimes(); + EasyMock.expect(herder.buildRestartPlan(restartRequest)).andReturn(Optional.of(restartPlan)).anyTimes(); + + herder.assignment = PowerMock.createMock(ExtendedAssignment.class); + EasyMock.expect(herder.assignment.connectors()).andReturn(Collections.emptyList()).anyTimes(); + EasyMock.expect(herder.assignment.tasks()).andReturn(Collections.singletonList(taskId)).anyTimes(); + + worker.stopAndAwaitTasks(Collections.singletonList(taskId)); + PowerMock.expectLastCall(); + + herder.onRestart(taskId); + EasyMock.expectLastCall(); + + worker.startTask(EasyMock.eq(TASK0), EasyMock.anyObject(), EasyMock.anyObject(), EasyMock.anyObject(), + EasyMock.eq(herder), EasyMock.anyObject(TargetState.class)); + PowerMock.expectLastCall().andReturn(true); + + PowerMock.replayAll(); + herder.doRestartConnectorAndTasks(restartRequest); + PowerMock.verifyAll(); + } + + @Test + public void testDoRestartConnectorAndTasksBoth() throws Exception { + ConnectorTaskId taskId = new ConnectorTaskId(CONN1, 0); + RestartRequest restartRequest = new RestartRequest(CONN1, false, true); + RestartPlan restartPlan = PowerMock.createMock(RestartPlan.class); + EasyMock.expect(restartPlan.shouldRestartConnector()).andReturn(true).anyTimes(); + EasyMock.expect(restartPlan.shouldRestartTasks()).andReturn(true).anyTimes(); + EasyMock.expect(restartPlan.taskIdsToRestart()).andReturn(Collections.singletonList(taskId)).anyTimes(); + EasyMock.expect(restartPlan.restartTaskCount()).andReturn(1).anyTimes(); + EasyMock.expect(restartPlan.totalTaskCount()).andReturn(1).anyTimes(); + EasyMock.expect(herder.buildRestartPlan(restartRequest)).andReturn(Optional.of(restartPlan)).anyTimes(); + + herder.assignment = PowerMock.createMock(ExtendedAssignment.class); + EasyMock.expect(herder.assignment.connectors()).andReturn(Collections.singletonList(CONN1)).anyTimes(); + EasyMock.expect(herder.assignment.tasks()).andReturn(Collections.singletonList(taskId)).anyTimes(); + + worker.stopAndAwaitConnector(CONN1); + PowerMock.expectLastCall(); + + Capture> stateCallback = newCapture(); + worker.startConnector(EasyMock.eq(CONN1), EasyMock.anyObject(), EasyMock.anyObject(), + EasyMock.eq(herder), EasyMock.anyObject(TargetState.class), capture(stateCallback)); + + + herder.onRestart(CONN1); + EasyMock.expectLastCall(); + + worker.stopAndAwaitTasks(Collections.singletonList(taskId)); + PowerMock.expectLastCall(); + + herder.onRestart(taskId); + EasyMock.expectLastCall(); + + worker.startTask(EasyMock.eq(TASK0), EasyMock.anyObject(), EasyMock.anyObject(), EasyMock.anyObject(), + EasyMock.eq(herder), EasyMock.anyObject(TargetState.class)); + PowerMock.expectLastCall().andReturn(true); + + PowerMock.replayAll(); + herder.doRestartConnectorAndTasks(restartRequest); + PowerMock.verifyAll(); + } + + @Test + public void testRestartTask() throws Exception { + EasyMock.expect(worker.connectorTaskConfigs(CONN1, conn1SinkConfig)).andStubReturn(TASK_CONFIGS); + + // get the initial assignment + EasyMock.expect(member.memberId()).andStubReturn("leader"); + EasyMock.expect(member.currentProtocolVersion()).andStubReturn(CONNECT_PROTOCOL_V0); + expectRebalance(1, Collections.emptyList(), singletonList(TASK0)); + expectPostRebalanceCatchup(SNAPSHOT); + member.poll(EasyMock.anyInt()); + PowerMock.expectLastCall(); + 
worker.startTask(EasyMock.eq(TASK0), EasyMock.anyObject(), EasyMock.anyObject(), EasyMock.anyObject(), + EasyMock.eq(herder), EasyMock.eq(TargetState.STARTED)); + PowerMock.expectLastCall().andReturn(true); + + // now handle the task restart + member.wakeup(); + PowerMock.expectLastCall(); + member.ensureActive(); + PowerMock.expectLastCall(); + member.poll(EasyMock.anyInt()); + PowerMock.expectLastCall(); + + worker.stopAndAwaitTask(TASK0); + PowerMock.expectLastCall(); + worker.startTask(EasyMock.eq(TASK0), EasyMock.anyObject(), EasyMock.anyObject(), EasyMock.anyObject(), + EasyMock.eq(herder), EasyMock.eq(TargetState.STARTED)); + PowerMock.expectLastCall().andReturn(true); + + PowerMock.replayAll(); + + herder.tick(); + FutureCallback callback = new FutureCallback<>(); + herder.restartTask(TASK0, callback); + herder.tick(); + callback.get(1000L, TimeUnit.MILLISECONDS); + + PowerMock.verifyAll(); + } + + @Test + public void testRestartUnknownTask() throws Exception { + // get the initial assignment + EasyMock.expect(member.memberId()).andStubReturn("member"); + EasyMock.expect(member.currentProtocolVersion()).andStubReturn(CONNECT_PROTOCOL_V0); + expectRebalance(1, Collections.emptyList(), Collections.emptyList()); + expectPostRebalanceCatchup(SNAPSHOT); + member.poll(EasyMock.anyInt()); + PowerMock.expectLastCall(); + + member.wakeup(); + PowerMock.expectLastCall(); + member.ensureActive(); + PowerMock.expectLastCall(); + member.poll(EasyMock.anyInt()); + PowerMock.expectLastCall(); + + PowerMock.replayAll(); + + FutureCallback callback = new FutureCallback<>(); + herder.tick(); + herder.restartTask(new ConnectorTaskId("blah", 0), callback); + herder.tick(); + + try { + callback.get(1000L, TimeUnit.MILLISECONDS); + fail("Expected NotLeaderException to be raised"); + } catch (ExecutionException e) { + assertTrue(e.getCause() instanceof NotFoundException); + } + + PowerMock.verifyAll(); + } + + @Test + public void testRequestProcessingOrder() { + final DistributedHerder.DistributedHerderRequest req1 = herder.addRequest(100, null, null); + final DistributedHerder.DistributedHerderRequest req2 = herder.addRequest(10, null, null); + final DistributedHerder.DistributedHerderRequest req3 = herder.addRequest(200, null, null); + final DistributedHerder.DistributedHerderRequest req4 = herder.addRequest(200, null, null); + + assertEquals(req2, herder.requests.pollFirst()); // lowest delay + assertEquals(req1, herder.requests.pollFirst()); // next lowest delay + assertEquals(req3, herder.requests.pollFirst()); // same delay as req4, but added first + assertEquals(req4, herder.requests.pollFirst()); + } + + @Test + public void testRestartTaskRedirectToLeader() throws Exception { + // get the initial assignment + EasyMock.expect(member.memberId()).andStubReturn("member"); + EasyMock.expect(member.currentProtocolVersion()).andStubReturn(CONNECT_PROTOCOL_V0); + expectRebalance(1, Collections.emptyList(), Collections.emptyList()); + expectPostRebalanceCatchup(SNAPSHOT); + member.poll(EasyMock.anyInt()); + PowerMock.expectLastCall(); + + // now handle the task restart + member.wakeup(); + PowerMock.expectLastCall(); + member.ensureActive(); + PowerMock.expectLastCall(); + member.poll(EasyMock.anyInt()); + PowerMock.expectLastCall(); + + PowerMock.replayAll(); + + herder.tick(); + FutureCallback callback = new FutureCallback<>(); + herder.restartTask(TASK0, callback); + herder.tick(); + + try { + callback.get(1000L, TimeUnit.MILLISECONDS); + fail("Expected NotLeaderException to be raised"); + } catch 
(ExecutionException e) { + assertTrue(e.getCause() instanceof NotLeaderException); + } + + PowerMock.verifyAll(); + } + + @Test + public void testRestartTaskRedirectToOwner() throws Exception { + // get the initial assignment + EasyMock.expect(member.memberId()).andStubReturn("leader"); + EasyMock.expect(member.currentProtocolVersion()).andStubReturn(CONNECT_PROTOCOL_V0); + expectRebalance(1, Collections.emptyList(), Collections.emptyList()); + expectPostRebalanceCatchup(SNAPSHOT); + member.poll(EasyMock.anyInt()); + PowerMock.expectLastCall(); + + // now handle the task restart + String ownerUrl = "ownerUrl"; + EasyMock.expect(member.ownerUrl(TASK0)).andReturn(ownerUrl); + member.wakeup(); + PowerMock.expectLastCall(); + member.ensureActive(); + PowerMock.expectLastCall(); + member.poll(EasyMock.anyInt()); + PowerMock.expectLastCall(); + + PowerMock.replayAll(); + + herder.tick(); + FutureCallback callback = new FutureCallback<>(); + herder.restartTask(TASK0, callback); + herder.tick(); + + try { + callback.get(1000L, TimeUnit.MILLISECONDS); + fail("Expected NotLeaderException to be raised"); + } catch (ExecutionException e) { + assertTrue(e.getCause() instanceof NotAssignedException); + NotAssignedException notAssignedException = (NotAssignedException) e.getCause(); + assertEquals(ownerUrl, notAssignedException.forwardUrl()); + } + + PowerMock.verifyAll(); + } + + @Test + public void testConnectorConfigAdded() { + // If a connector was added, we need to rebalance + EasyMock.expect(member.memberId()).andStubReturn("member"); + EasyMock.expect(member.currentProtocolVersion()).andStubReturn(CONNECT_PROTOCOL_V0); + + // join, no configs so no need to catch up on config topic + expectRebalance(-1, Collections.emptyList(), Collections.emptyList()); + member.poll(EasyMock.anyInt()); + PowerMock.expectLastCall(); + + // apply config + member.wakeup(); + PowerMock.expectLastCall(); + member.ensureActive(); + PowerMock.expectLastCall(); + // Checks for config updates and starts rebalance + EasyMock.expect(configBackingStore.snapshot()).andReturn(SNAPSHOT); + member.requestRejoin(); + PowerMock.expectLastCall(); + // Performs rebalance and gets new assignment + expectRebalance(Collections.emptyList(), Collections.emptyList(), + ConnectProtocol.Assignment.NO_ERROR, 1, Arrays.asList(CONN1), Collections.emptyList()); + EasyMock.expect(member.currentProtocolVersion()).andStubReturn(CONNECT_PROTOCOL_V0); + Capture> onStart = newCapture(); + worker.startConnector(EasyMock.eq(CONN1), EasyMock.anyObject(), EasyMock.anyObject(), + EasyMock.eq(herder), EasyMock.eq(TargetState.STARTED), capture(onStart)); + PowerMock.expectLastCall().andAnswer(() -> { + onStart.getValue().onCompletion(null, TargetState.STARTED); + return true; + }); + member.wakeup(); + PowerMock.expectLastCall(); + EasyMock.expect(worker.getPlugins()).andReturn(plugins); + EasyMock.expect(worker.isRunning(CONN1)).andReturn(true); + EasyMock.expect(worker.connectorTaskConfigs(CONN1, conn1SinkConfig)).andReturn(TASK_CONFIGS); + member.poll(EasyMock.anyInt()); + PowerMock.expectLastCall(); + + PowerMock.replayAll(); + + herder.tick(); // join + configUpdateListener.onConnectorConfigUpdate(CONN1); // read updated config + herder.tick(); // apply config + herder.tick(); // do rebalance + + PowerMock.verifyAll(); + } + + @Test + public void testConnectorConfigUpdate() throws Exception { + // Connector config can be applied without any rebalance + + EasyMock.expect(member.memberId()).andStubReturn("member"); + 
EasyMock.expect(member.currentProtocolVersion()).andStubReturn(CONNECT_PROTOCOL_V0); + EasyMock.expect(worker.connectorNames()).andStubReturn(Collections.singleton(CONN1)); + + // join + expectRebalance(1, Arrays.asList(CONN1), Collections.emptyList()); + expectPostRebalanceCatchup(SNAPSHOT); + EasyMock.expect(member.currentProtocolVersion()).andStubReturn(CONNECT_PROTOCOL_V0); + Capture> onFirstStart = newCapture(); + worker.startConnector(EasyMock.eq(CONN1), EasyMock.anyObject(), EasyMock.anyObject(), + EasyMock.eq(herder), EasyMock.eq(TargetState.STARTED), capture(onFirstStart)); + PowerMock.expectLastCall().andAnswer(() -> { + onFirstStart.getValue().onCompletion(null, TargetState.STARTED); + return true; + }); + member.wakeup(); + PowerMock.expectLastCall(); + EasyMock.expect(worker.getPlugins()).andReturn(plugins); + EasyMock.expect(worker.isRunning(CONN1)).andReturn(true); + EasyMock.expect(worker.connectorTaskConfigs(CONN1, conn1SinkConfig)).andReturn(TASK_CONFIGS); + member.poll(EasyMock.anyInt()); + PowerMock.expectLastCall(); + + // apply config + member.wakeup(); + PowerMock.expectLastCall(); + member.ensureActive(); + PowerMock.expectLastCall(); + EasyMock.expect(configBackingStore.snapshot()).andReturn(SNAPSHOT); // for this test, it doesn't matter if we use the same config snapshot + worker.stopAndAwaitConnector(CONN1); + PowerMock.expectLastCall(); + EasyMock.expect(member.currentProtocolVersion()).andStubReturn(CONNECT_PROTOCOL_V0); + Capture> onSecondStart = newCapture(); + worker.startConnector(EasyMock.eq(CONN1), EasyMock.anyObject(), EasyMock.anyObject(), + EasyMock.eq(herder), EasyMock.eq(TargetState.STARTED), capture(onSecondStart)); + PowerMock.expectLastCall().andAnswer(() -> { + onSecondStart.getValue().onCompletion(null, TargetState.STARTED); + return true; + }); + member.wakeup(); + PowerMock.expectLastCall(); + EasyMock.expect(worker.getPlugins()).andReturn(plugins); + EasyMock.expect(worker.isRunning(CONN1)).andReturn(true); + EasyMock.expect(worker.connectorTaskConfigs(CONN1, conn1SinkConfig)).andReturn(TASK_CONFIGS); + member.poll(EasyMock.anyInt()); + PowerMock.expectLastCall(); + + // These will occur just before/during the third tick + member.ensureActive(); + PowerMock.expectLastCall(); + member.poll(EasyMock.anyInt()); + PowerMock.expectLastCall(); + + PowerMock.replayAll(); + + herder.tick(); // join + configUpdateListener.onConnectorConfigUpdate(CONN1); // read updated config + herder.tick(); // apply config + herder.tick(); + + PowerMock.verifyAll(); + } + + @Test + public void testConnectorPaused() throws Exception { + // ensure that target state changes are propagated to the worker + + EasyMock.expect(member.memberId()).andStubReturn("member"); + EasyMock.expect(member.currentProtocolVersion()).andStubReturn(CONNECT_PROTOCOL_V0); + EasyMock.expect(worker.connectorNames()).andStubReturn(Collections.singleton(CONN1)); + + // join + expectRebalance(1, Arrays.asList(CONN1), Collections.emptyList()); + expectPostRebalanceCatchup(SNAPSHOT); + EasyMock.expect(member.currentProtocolVersion()).andStubReturn(CONNECT_PROTOCOL_V0); + Capture> onStart = newCapture(); + worker.startConnector(EasyMock.eq(CONN1), EasyMock.anyObject(), EasyMock.anyObject(), + EasyMock.eq(herder), EasyMock.eq(TargetState.STARTED), capture(onStart)); + PowerMock.expectLastCall().andAnswer(() -> { + onStart.getValue().onCompletion(null, TargetState.STARTED); + return true; + }); + member.wakeup(); + PowerMock.expectLastCall(); + EasyMock.expect(worker.getPlugins()).andReturn(plugins); 
+ EasyMock.expect(worker.isRunning(CONN1)).andReturn(true); + EasyMock.expect(worker.connectorTaskConfigs(CONN1, conn1SinkConfig)).andReturn(TASK_CONFIGS); + member.poll(EasyMock.anyInt()); + PowerMock.expectLastCall(); + + // handle the state change + member.wakeup(); + PowerMock.expectLastCall(); + member.ensureActive(); + PowerMock.expectLastCall(); + + EasyMock.expect(configBackingStore.snapshot()).andReturn(SNAPSHOT_PAUSED_CONN1); + PowerMock.expectLastCall(); + + Capture> onPause = newCapture(); + worker.setTargetState(EasyMock.eq(CONN1), EasyMock.eq(TargetState.PAUSED), capture(onPause)); + PowerMock.expectLastCall().andAnswer(() -> { + onStart.getValue().onCompletion(null, TargetState.PAUSED); + return null; + }); + + member.poll(EasyMock.anyInt()); + PowerMock.expectLastCall(); + + // These will occur just before/during the third tick + member.ensureActive(); + PowerMock.expectLastCall(); + member.poll(EasyMock.anyInt()); + PowerMock.expectLastCall(); + + PowerMock.replayAll(); + + herder.tick(); // join + configUpdateListener.onConnectorTargetStateChange(CONN1); // state changes to paused + herder.tick(); // worker should apply the state change + herder.tick(); + + PowerMock.verifyAll(); + } + + @Test + public void testConnectorResumed() throws Exception { + EasyMock.expect(member.memberId()).andStubReturn("member"); + EasyMock.expect(member.currentProtocolVersion()).andStubReturn(CONNECT_PROTOCOL_V0); + EasyMock.expect(worker.connectorNames()).andStubReturn(Collections.singleton(CONN1)); + + // start with the connector paused + expectRebalance(1, Arrays.asList(CONN1), Collections.emptyList()); + expectPostRebalanceCatchup(SNAPSHOT_PAUSED_CONN1); + EasyMock.expect(member.currentProtocolVersion()).andStubReturn(CONNECT_PROTOCOL_V0); + Capture> onStart = newCapture(); + worker.startConnector(EasyMock.eq(CONN1), EasyMock.anyObject(), EasyMock.anyObject(), + EasyMock.eq(herder), EasyMock.eq(TargetState.PAUSED), capture(onStart)); + PowerMock.expectLastCall().andAnswer(() -> { + onStart.getValue().onCompletion(null, TargetState.PAUSED); + return true; + }); + EasyMock.expect(worker.getPlugins()).andReturn(plugins); + + member.poll(EasyMock.anyInt()); + PowerMock.expectLastCall(); + + // handle the state change + member.wakeup(); + PowerMock.expectLastCall(); + member.ensureActive(); + PowerMock.expectLastCall(); + + EasyMock.expect(configBackingStore.snapshot()).andReturn(SNAPSHOT); + PowerMock.expectLastCall(); + + // we expect reconfiguration after resuming + EasyMock.expect(worker.isRunning(CONN1)).andReturn(true); + EasyMock.expect(worker.connectorTaskConfigs(CONN1, conn1SinkConfig)).andReturn(TASK_CONFIGS); + + Capture> onResume = newCapture(); + worker.setTargetState(EasyMock.eq(CONN1), EasyMock.eq(TargetState.STARTED), capture(onResume)); + PowerMock.expectLastCall().andAnswer(() -> { + onResume.getValue().onCompletion(null, TargetState.STARTED); + return null; + }); + member.wakeup(); + PowerMock.expectLastCall(); + + member.poll(EasyMock.anyInt()); + PowerMock.expectLastCall(); + + // These will occur just before/during the third tick + member.ensureActive(); + PowerMock.expectLastCall(); + member.poll(EasyMock.anyInt()); + PowerMock.expectLastCall(); + + PowerMock.replayAll(); + + herder.tick(); // join + configUpdateListener.onConnectorTargetStateChange(CONN1); // state changes to started + herder.tick(); // apply state change + herder.tick(); + + PowerMock.verifyAll(); + } + + @Test + public void testUnknownConnectorPaused() throws Exception { + 
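+        // A target state change for a connector that does not appear in the current config snapshot
+        // should be ignored: no worker state transition is requested and the tick loop simply continues.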
EasyMock.expect(member.memberId()).andStubReturn("member"); + EasyMock.expect(member.currentProtocolVersion()).andStubReturn(CONNECT_PROTOCOL_V0); + EasyMock.expect(worker.connectorNames()).andStubReturn(Collections.singleton(CONN1)); + + // join + expectRebalance(1, Collections.emptyList(), singletonList(TASK0)); + expectPostRebalanceCatchup(SNAPSHOT); + worker.startTask(EasyMock.eq(TASK0), EasyMock.anyObject(), EasyMock.anyObject(), EasyMock.anyObject(), + EasyMock.eq(herder), EasyMock.eq(TargetState.STARTED)); + PowerMock.expectLastCall().andReturn(true); + member.poll(EasyMock.anyInt()); + PowerMock.expectLastCall(); + + // state change is ignored since we have no target state + member.wakeup(); + PowerMock.expectLastCall(); + member.ensureActive(); + PowerMock.expectLastCall(); + + EasyMock.expect(configBackingStore.snapshot()).andReturn(SNAPSHOT); + PowerMock.expectLastCall(); + + member.poll(EasyMock.anyInt()); + PowerMock.expectLastCall(); + + PowerMock.replayAll(); + + herder.tick(); // join + configUpdateListener.onConnectorTargetStateChange("unknown-connector"); + herder.tick(); // continue + + PowerMock.verifyAll(); + } + + @Test + public void testConnectorPausedRunningTaskOnly() throws Exception { + // even if we don't own the connector, we should still propagate target state + // changes to the worker so that tasks will transition correctly + + EasyMock.expect(member.memberId()).andStubReturn("member"); + EasyMock.expect(member.currentProtocolVersion()).andStubReturn(CONNECT_PROTOCOL_V0); + EasyMock.expect(worker.connectorNames()).andStubReturn(Collections.emptySet()); + + // join + expectRebalance(1, Collections.emptyList(), singletonList(TASK0)); + expectPostRebalanceCatchup(SNAPSHOT); + worker.startTask(EasyMock.eq(TASK0), EasyMock.anyObject(), EasyMock.anyObject(), EasyMock.anyObject(), + EasyMock.eq(herder), EasyMock.eq(TargetState.STARTED)); + PowerMock.expectLastCall().andReturn(true); + member.poll(EasyMock.anyInt()); + PowerMock.expectLastCall(); + + // handle the state change + member.wakeup(); + PowerMock.expectLastCall(); + member.ensureActive(); + PowerMock.expectLastCall(); + + EasyMock.expect(configBackingStore.snapshot()).andReturn(SNAPSHOT_PAUSED_CONN1); + PowerMock.expectLastCall(); + + Capture> onPause = newCapture(); + worker.setTargetState(EasyMock.eq(CONN1), EasyMock.eq(TargetState.PAUSED), capture(onPause)); + PowerMock.expectLastCall().andAnswer(() -> { + onPause.getValue().onCompletion(null, TargetState.STARTED); + return null; + }); + member.wakeup(); + PowerMock.expectLastCall(); + + member.poll(EasyMock.anyInt()); + PowerMock.expectLastCall(); + + PowerMock.replayAll(); + + herder.tick(); // join + configUpdateListener.onConnectorTargetStateChange(CONN1); // state changes to paused + herder.tick(); // apply state change + + PowerMock.verifyAll(); + } + + @Test + public void testConnectorResumedRunningTaskOnly() throws Exception { + // even if we don't own the connector, we should still propagate target state + // changes to the worker so that tasks will transition correctly + + EasyMock.expect(member.memberId()).andStubReturn("member"); + EasyMock.expect(member.currentProtocolVersion()).andStubReturn(CONNECT_PROTOCOL_V0); + EasyMock.expect(worker.connectorNames()).andStubReturn(Collections.emptySet()); + + // join + expectRebalance(1, Collections.emptyList(), singletonList(TASK0)); + expectPostRebalanceCatchup(SNAPSHOT_PAUSED_CONN1); + worker.startTask(EasyMock.eq(TASK0), EasyMock.anyObject(), EasyMock.anyObject(), EasyMock.anyObject(), + 
EasyMock.eq(herder), EasyMock.eq(TargetState.PAUSED)); + PowerMock.expectLastCall().andReturn(true); + member.poll(EasyMock.anyInt()); + PowerMock.expectLastCall(); + + // handle the state change + member.wakeup(); + PowerMock.expectLastCall(); + member.ensureActive(); + PowerMock.expectLastCall(); + + EasyMock.expect(configBackingStore.snapshot()).andReturn(SNAPSHOT); + PowerMock.expectLastCall(); + + Capture> onStart = newCapture(); + worker.setTargetState(EasyMock.eq(CONN1), EasyMock.eq(TargetState.STARTED), capture(onStart)); + PowerMock.expectLastCall().andAnswer(() -> { + onStart.getValue().onCompletion(null, TargetState.STARTED); + return null; + }); + member.wakeup(); + PowerMock.expectLastCall(); + + EasyMock.expect(worker.isRunning(CONN1)).andReturn(false); + + member.poll(EasyMock.anyInt()); + PowerMock.expectLastCall(); + + // These will occur just before/during the third tick + member.ensureActive(); + PowerMock.expectLastCall(); + member.poll(EasyMock.anyInt()); + PowerMock.expectLastCall(); + + PowerMock.replayAll(); + + herder.tick(); // join + configUpdateListener.onConnectorTargetStateChange(CONN1); // state changes to paused + herder.tick(); // apply state change + herder.tick(); + + PowerMock.verifyAll(); + } + + @Test + public void testTaskConfigAdded() { + // Task config always requires rebalance + EasyMock.expect(member.memberId()).andStubReturn("member"); + EasyMock.expect(member.currentProtocolVersion()).andStubReturn(CONNECT_PROTOCOL_V0); + + // join + expectRebalance(-1, Collections.emptyList(), Collections.emptyList()); + member.poll(EasyMock.anyInt()); + PowerMock.expectLastCall(); + + // apply config + member.wakeup(); + PowerMock.expectLastCall(); + member.ensureActive(); + PowerMock.expectLastCall(); + // Checks for config updates and starts rebalance + EasyMock.expect(configBackingStore.snapshot()).andReturn(SNAPSHOT); + member.requestRejoin(); + PowerMock.expectLastCall(); + // Performs rebalance and gets new assignment + expectRebalance(Collections.emptyList(), Collections.emptyList(), + ConnectProtocol.Assignment.NO_ERROR, 1, Collections.emptyList(), + Arrays.asList(TASK0)); + worker.startTask(EasyMock.eq(TASK0), EasyMock.anyObject(), EasyMock.anyObject(), EasyMock.anyObject(), + EasyMock.eq(herder), EasyMock.eq(TargetState.STARTED)); + PowerMock.expectLastCall().andReturn(true); + member.poll(EasyMock.anyInt()); + PowerMock.expectLastCall(); + + PowerMock.replayAll(); + + herder.tick(); // join + configUpdateListener.onTaskConfigUpdate(Arrays.asList(TASK0, TASK1, TASK2)); // read updated config + herder.tick(); // apply config + herder.tick(); // do rebalance + + PowerMock.verifyAll(); + } + + @Test + public void testJoinLeaderCatchUpFails() throws Exception { + // Join group and as leader fail to do assignment + EasyMock.expect(member.memberId()).andStubReturn("leader"); + EasyMock.expect(member.currentProtocolVersion()).andStubReturn(CONNECT_PROTOCOL_V0); + expectRebalance(Collections.emptyList(), Collections.emptyList(), + ConnectProtocol.Assignment.CONFIG_MISMATCH, 1, Collections.emptyList(), + Collections.emptyList()); + // Reading to end of log times out + configBackingStore.refresh(EasyMock.anyLong(), EasyMock.anyObject(TimeUnit.class)); + EasyMock.expectLastCall().andThrow(new TimeoutException()); + member.maybeLeaveGroup(EasyMock.eq("taking too long to read the log")); + EasyMock.expectLastCall(); + member.requestRejoin(); + + // After backoff, restart the process and this time succeed + expectRebalance(1, Arrays.asList(CONN1), 
Arrays.asList(TASK1)); + expectPostRebalanceCatchup(SNAPSHOT); + + EasyMock.expect(member.currentProtocolVersion()).andStubReturn(CONNECT_PROTOCOL_V0); + Capture> onStart = newCapture(); + worker.startConnector(EasyMock.eq(CONN1), EasyMock.anyObject(), EasyMock.anyObject(), + EasyMock.eq(herder), EasyMock.eq(TargetState.STARTED), capture(onStart)); + PowerMock.expectLastCall().andAnswer(() -> { + onStart.getValue().onCompletion(null, TargetState.STARTED); + return true; + }); + member.wakeup(); + PowerMock.expectLastCall(); + EasyMock.expect(worker.getPlugins()).andReturn(plugins); + EasyMock.expect(worker.connectorTaskConfigs(CONN1, conn1SinkConfig)).andReturn(TASK_CONFIGS); + worker.startTask(EasyMock.eq(TASK1), EasyMock.anyObject(), EasyMock.anyObject(), EasyMock.anyObject(), + EasyMock.eq(herder), EasyMock.eq(TargetState.STARTED)); + PowerMock.expectLastCall().andReturn(true); + EasyMock.expect(worker.isRunning(CONN1)).andReturn(true); + member.poll(EasyMock.anyInt()); + PowerMock.expectLastCall(); + + // one more tick, to make sure we don't keep trying to read to the config topic unnecessarily + expectRebalance(1, Collections.emptyList(), Collections.emptyList()); + member.poll(EasyMock.anyInt()); + PowerMock.expectLastCall(); + + PowerMock.replayAll(); + + long before = time.milliseconds(); + int workerUnsyncBackoffMs = DistributedConfig.WORKER_UNSYNC_BACKOFF_MS_DEFAULT; + int coordinatorDiscoveryTimeoutMs = 100; + herder.tick(); + assertEquals(before + coordinatorDiscoveryTimeoutMs + workerUnsyncBackoffMs, time.milliseconds()); + + time.sleep(1000L); + assertStatistics("leaderUrl", true, 3, 0, Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY); + + before = time.milliseconds(); + herder.tick(); + assertEquals(before + coordinatorDiscoveryTimeoutMs, time.milliseconds()); + time.sleep(2000L); + assertStatistics("leaderUrl", false, 3, 1, 100, 2000L); + + // tick once more to ensure that the successful read to the end of the config topic was + // tracked and no further unnecessary attempts were made + herder.tick(); + + PowerMock.verifyAll(); + } + + @Test + public void testJoinLeaderCatchUpRetriesForIncrementalCooperative() throws Exception { + connectProtocolVersion = CONNECT_PROTOCOL_V1; + + // Join group and as leader fail to do assignment + EasyMock.expect(member.memberId()).andStubReturn("leader"); + EasyMock.expect(member.currentProtocolVersion()).andStubReturn(CONNECT_PROTOCOL_V1); + expectRebalance(1, Arrays.asList(CONN1), Arrays.asList(TASK1)); + expectPostRebalanceCatchup(SNAPSHOT); + + member.poll(EasyMock.anyInt()); + PowerMock.expectLastCall(); + + // The leader got its assignment + expectRebalance(Collections.emptyList(), Collections.emptyList(), + ConnectProtocol.Assignment.NO_ERROR, + 1, Arrays.asList(CONN1), Arrays.asList(TASK1), 0); + + EasyMock.expect(worker.getPlugins()).andReturn(plugins); + Capture> onStart = newCapture(); + worker.startConnector(EasyMock.eq(CONN1), EasyMock.anyObject(), EasyMock.anyObject(), + EasyMock.eq(herder), EasyMock.eq(TargetState.STARTED), capture(onStart)); + PowerMock.expectLastCall().andAnswer(() -> { + onStart.getValue().onCompletion(null, TargetState.STARTED); + return true; + }); + member.wakeup(); + PowerMock.expectLastCall(); + EasyMock.expect(worker.isRunning(CONN1)).andReturn(true); + EasyMock.expect(worker.connectorTaskConfigs(CONN1, conn1SinkConfig)).andReturn(TASK_CONFIGS); + + worker.startTask(EasyMock.eq(TASK1), EasyMock.anyObject(), EasyMock.anyObject(), EasyMock.anyObject(), + EasyMock.eq(herder), 
EasyMock.eq(TargetState.STARTED)); + PowerMock.expectLastCall().andReturn(true); + member.poll(EasyMock.anyInt()); + PowerMock.expectLastCall(); + + // Another rebalance is triggered but this time it fails to read to the max offset and + // triggers a re-sync + expectRebalance(Collections.emptyList(), Collections.emptyList(), + ConnectProtocol.Assignment.CONFIG_MISMATCH, 1, Collections.emptyList(), + Collections.emptyList()); + + // The leader will retry a few times to read to the end of the config log + int retries = 2; + member.requestRejoin(); + for (int i = retries; i >= 0; --i) { + // Reading to end of log times out + configBackingStore.refresh(EasyMock.anyLong(), EasyMock.anyObject(TimeUnit.class)); + EasyMock.expectLastCall().andThrow(new TimeoutException()); + member.maybeLeaveGroup(EasyMock.eq("taking too long to read the log")); + EasyMock.expectLastCall(); + } + + // After a few retries succeed to read the log to the end + expectRebalance(Collections.emptyList(), Collections.emptyList(), + ConnectProtocol.Assignment.NO_ERROR, + 1, Arrays.asList(CONN1), Arrays.asList(TASK1), 0); + expectPostRebalanceCatchup(SNAPSHOT); + member.poll(EasyMock.anyInt()); + PowerMock.expectLastCall(); + + PowerMock.replayAll(); + + assertStatistics(0, 0, 0, Double.POSITIVE_INFINITY); + herder.tick(); + + time.sleep(2000L); + assertStatistics(3, 1, 100, 2000); + herder.tick(); + + long before; + int coordinatorDiscoveryTimeoutMs = 100; + int maxRetries = 5; + for (int i = maxRetries; i >= maxRetries - retries; --i) { + before = time.milliseconds(); + int workerUnsyncBackoffMs = + DistributedConfig.SCHEDULED_REBALANCE_MAX_DELAY_MS_DEFAULT / 10 / i; + herder.tick(); + assertEquals(before + coordinatorDiscoveryTimeoutMs + workerUnsyncBackoffMs, time.milliseconds()); + coordinatorDiscoveryTimeoutMs = 0; + } + + before = time.milliseconds(); + coordinatorDiscoveryTimeoutMs = 100; + herder.tick(); + assertEquals(before + coordinatorDiscoveryTimeoutMs, time.milliseconds()); + + PowerMock.verifyAll(); + } + + @Test + public void testJoinLeaderCatchUpFailsForIncrementalCooperative() throws Exception { + connectProtocolVersion = CONNECT_PROTOCOL_V1; + + // Join group and as leader fail to do assignment + EasyMock.expect(member.memberId()).andStubReturn("leader"); + EasyMock.expect(member.currentProtocolVersion()).andStubReturn(CONNECT_PROTOCOL_V1); + expectRebalance(1, Arrays.asList(CONN1), Arrays.asList(TASK1)); + expectPostRebalanceCatchup(SNAPSHOT); + + member.poll(EasyMock.anyInt()); + PowerMock.expectLastCall(); + + // The leader got its assignment + expectRebalance(Collections.emptyList(), Collections.emptyList(), + ConnectProtocol.Assignment.NO_ERROR, + 1, Arrays.asList(CONN1), Arrays.asList(TASK1), 0); + + EasyMock.expect(worker.getPlugins()).andReturn(plugins); + // and the new assignment started + Capture> onStart = newCapture(); + worker.startConnector(EasyMock.eq(CONN1), EasyMock.anyObject(), EasyMock.anyObject(), + EasyMock.eq(herder), EasyMock.eq(TargetState.STARTED), capture(onStart)); + PowerMock.expectLastCall().andAnswer(() -> { + onStart.getValue().onCompletion(null, TargetState.STARTED); + return true; + }); + member.wakeup(); + PowerMock.expectLastCall(); + EasyMock.expect(worker.isRunning(CONN1)).andReturn(true); + EasyMock.expect(worker.connectorTaskConfigs(CONN1, conn1SinkConfig)).andReturn(TASK_CONFIGS); + + worker.startTask(EasyMock.eq(TASK1), EasyMock.anyObject(), EasyMock.anyObject(), EasyMock.anyObject(), + EasyMock.eq(herder), EasyMock.eq(TargetState.STARTED)); + 
PowerMock.expectLastCall().andReturn(true); + member.poll(EasyMock.anyInt()); + PowerMock.expectLastCall(); + + // Another rebalance is triggered but this time it fails to read to the max offset and + // triggers a re-sync + expectRebalance(Collections.emptyList(), Collections.emptyList(), + ConnectProtocol.Assignment.CONFIG_MISMATCH, 1, Collections.emptyList(), + Collections.emptyList()); + + // The leader will exhaust the retries while trying to read to the end of the config log + int maxRetries = 5; + member.requestRejoin(); + for (int i = maxRetries; i >= 0; --i) { + // Reading to end of log times out + configBackingStore.refresh(EasyMock.anyLong(), EasyMock.anyObject(TimeUnit.class)); + EasyMock.expectLastCall().andThrow(new TimeoutException()); + member.maybeLeaveGroup(EasyMock.eq("taking too long to read the log")); + EasyMock.expectLastCall(); + } + + Capture assignmentCapture = newCapture(); + member.revokeAssignment(capture(assignmentCapture)); + PowerMock.expectLastCall(); + + // After a complete backoff and a revocation of running tasks rejoin and this time succeed + // The worker gets back the assignment that had given up + expectRebalance(Collections.emptyList(), Collections.emptyList(), + ConnectProtocol.Assignment.NO_ERROR, + 1, Arrays.asList(CONN1), Arrays.asList(TASK1), 0); + expectPostRebalanceCatchup(SNAPSHOT); + member.poll(EasyMock.anyInt()); + PowerMock.expectLastCall(); + + PowerMock.replayAll(); + + assertStatistics(0, 0, 0, Double.POSITIVE_INFINITY); + herder.tick(); + + time.sleep(2000L); + assertStatistics(3, 1, 100, 2000); + herder.tick(); + + long before; + int coordinatorDiscoveryTimeoutMs = 100; + for (int i = maxRetries; i > 0; --i) { + before = time.milliseconds(); + int workerUnsyncBackoffMs = + DistributedConfig.SCHEDULED_REBALANCE_MAX_DELAY_MS_DEFAULT / 10 / i; + herder.tick(); + assertEquals(before + coordinatorDiscoveryTimeoutMs + workerUnsyncBackoffMs, time.milliseconds()); + coordinatorDiscoveryTimeoutMs = 0; + } + + before = time.milliseconds(); + herder.tick(); + assertEquals(before, time.milliseconds()); + assertEquals(Collections.singleton(CONN1), assignmentCapture.getValue().connectors()); + assertEquals(Collections.singleton(TASK1), assignmentCapture.getValue().tasks()); + herder.tick(); + + PowerMock.verifyAll(); + } + + @Test + public void testAccessors() throws Exception { + EasyMock.expect(member.memberId()).andStubReturn("leader"); + EasyMock.expect(member.currentProtocolVersion()).andStubReturn(CONNECT_PROTOCOL_V0); + EasyMock.expect(worker.getPlugins()).andReturn(plugins).anyTimes(); + expectRebalance(1, Collections.emptyList(), Collections.emptyList()); + EasyMock.expect(configBackingStore.snapshot()).andReturn(SNAPSHOT).times(2); + + WorkerConfigTransformer configTransformer = EasyMock.mock(WorkerConfigTransformer.class); + EasyMock.expect(configTransformer.transform(EasyMock.eq(CONN1), EasyMock.anyObject())) + .andThrow(new AssertionError("Config transformation should not occur when requesting connector or task info")); + EasyMock.replay(configTransformer); + ClusterConfigState snapshotWithTransform = new ClusterConfigState(1, null, Collections.singletonMap(CONN1, 3), + Collections.singletonMap(CONN1, CONN1_CONFIG), Collections.singletonMap(CONN1, TargetState.STARTED), + TASK_CONFIGS_MAP, Collections.emptySet(), configTransformer); + + expectPostRebalanceCatchup(snapshotWithTransform); + + + member.wakeup(); + PowerMock.expectLastCall().anyTimes(); + // list connectors, get connector info, get connector config, get task configs + 
member.poll(EasyMock.anyInt()); + PowerMock.expectLastCall(); + + + EasyMock.expect(member.currentProtocolVersion()).andStubReturn(CONNECT_PROTOCOL_V0); + PowerMock.replayAll(); + + FutureCallback> listConnectorsCb = new FutureCallback<>(); + herder.connectors(listConnectorsCb); + FutureCallback connectorInfoCb = new FutureCallback<>(); + herder.connectorInfo(CONN1, connectorInfoCb); + FutureCallback> connectorConfigCb = new FutureCallback<>(); + herder.connectorConfig(CONN1, connectorConfigCb); + FutureCallback> taskConfigsCb = new FutureCallback<>(); + herder.taskConfigs(CONN1, taskConfigsCb); + + herder.tick(); + assertTrue(listConnectorsCb.isDone()); + assertEquals(Collections.singleton(CONN1), listConnectorsCb.get()); + assertTrue(connectorInfoCb.isDone()); + ConnectorInfo info = new ConnectorInfo(CONN1, CONN1_CONFIG, Arrays.asList(TASK0, TASK1, TASK2), + ConnectorType.SOURCE); + assertEquals(info, connectorInfoCb.get()); + assertTrue(connectorConfigCb.isDone()); + assertEquals(CONN1_CONFIG, connectorConfigCb.get()); + assertTrue(taskConfigsCb.isDone()); + assertEquals(Arrays.asList( + new TaskInfo(TASK0, TASK_CONFIG), + new TaskInfo(TASK1, TASK_CONFIG), + new TaskInfo(TASK2, TASK_CONFIG)), + taskConfigsCb.get()); + + PowerMock.verifyAll(); + } + + @Test + public void testPutConnectorConfig() throws Exception { + EasyMock.expect(member.memberId()).andStubReturn("leader"); + expectRebalance(1, Arrays.asList(CONN1), Collections.emptyList()); + expectPostRebalanceCatchup(SNAPSHOT); + EasyMock.expect(member.currentProtocolVersion()).andStubReturn(CONNECT_PROTOCOL_V0); + Capture> onFirstStart = newCapture(); + worker.startConnector(EasyMock.eq(CONN1), EasyMock.anyObject(), EasyMock.anyObject(), + EasyMock.eq(herder), EasyMock.eq(TargetState.STARTED), capture(onFirstStart)); + PowerMock.expectLastCall().andAnswer(() -> { + onFirstStart.getValue().onCompletion(null, TargetState.STARTED); + return true; + }); + EasyMock.expect(worker.isRunning(CONN1)).andReturn(true); + EasyMock.expect(worker.connectorTaskConfigs(CONN1, conn1SinkConfig)).andReturn(TASK_CONFIGS); + + // list connectors, get connector info, get connector config, get task configs + member.wakeup(); + PowerMock.expectLastCall().anyTimes(); + member.poll(EasyMock.anyInt()); + PowerMock.expectLastCall(); + + // Poll loop for second round of calls + member.ensureActive(); + PowerMock.expectLastCall(); + + EasyMock.expect(worker.getPlugins()).andReturn(plugins).anyTimes(); + EasyMock.expect(configBackingStore.snapshot()).andReturn(SNAPSHOT); + + Capture> validateCallback = newCapture(); + herder.validateConnectorConfig(EasyMock.eq(CONN1_CONFIG_UPDATED), capture(validateCallback)); + PowerMock.expectLastCall().andAnswer(() -> { + validateCallback.getValue().onCompletion(null, CONN1_CONFIG_INFOS); + return null; + }); + configBackingStore.putConnectorConfig(CONN1, CONN1_CONFIG_UPDATED); + PowerMock.expectLastCall().andAnswer(() -> { + // Simulate response to writing config + waiting until end of log to be read + configUpdateListener.onConnectorConfigUpdate(CONN1); + return null; + }); + // As a result of reconfig, should need to update snapshot. 
With only connector updates, we'll just restart + // connector without rebalance + EasyMock.expect(configBackingStore.snapshot()).andReturn(SNAPSHOT_UPDATED_CONN1_CONFIG).times(2); + worker.stopAndAwaitConnector(CONN1); + PowerMock.expectLastCall(); + EasyMock.expect(member.currentProtocolVersion()).andStubReturn(CONNECT_PROTOCOL_V0); + Capture> onSecondStart = newCapture(); + worker.startConnector(EasyMock.eq(CONN1), EasyMock.anyObject(), EasyMock.anyObject(), + EasyMock.eq(herder), EasyMock.eq(TargetState.STARTED), capture(onSecondStart)); + PowerMock.expectLastCall().andAnswer(() -> { + onSecondStart.getValue().onCompletion(null, TargetState.STARTED); + return true; + }); + EasyMock.expect(worker.isRunning(CONN1)).andReturn(true); + EasyMock.expect(worker.connectorTaskConfigs(CONN1, conn1SinkConfigUpdated)).andReturn(TASK_CONFIGS); + + member.poll(EasyMock.anyInt()); + PowerMock.expectLastCall(); + + // Third tick just to read the config + member.ensureActive(); + PowerMock.expectLastCall(); + member.poll(EasyMock.anyInt()); + PowerMock.expectLastCall(); + + PowerMock.replayAll(); + + // Should pick up original config + FutureCallback> connectorConfigCb = new FutureCallback<>(); + herder.connectorConfig(CONN1, connectorConfigCb); + herder.tick(); + assertTrue(connectorConfigCb.isDone()); + assertEquals(CONN1_CONFIG, connectorConfigCb.get()); + + // Apply new config. + FutureCallback> putConfigCb = new FutureCallback<>(); + herder.putConnectorConfig(CONN1, CONN1_CONFIG_UPDATED, true, putConfigCb); + herder.tick(); + assertTrue(putConfigCb.isDone()); + ConnectorInfo updatedInfo = new ConnectorInfo(CONN1, CONN1_CONFIG_UPDATED, Arrays.asList(TASK0, TASK1, TASK2), + ConnectorType.SOURCE); + assertEquals(new Herder.Created<>(false, updatedInfo), putConfigCb.get()); + + // Check config again to validate change + connectorConfigCb = new FutureCallback<>(); + herder.connectorConfig(CONN1, connectorConfigCb); + herder.tick(); + assertTrue(connectorConfigCb.isDone()); + assertEquals(CONN1_CONFIG_UPDATED, connectorConfigCb.get()); + + PowerMock.verifyAll(); + } + + @Test + public void testKeyRotationWhenWorkerBecomesLeader() throws Exception { + EasyMock.expect(member.memberId()).andStubReturn("member"); + EasyMock.expect(member.currentProtocolVersion()).andStubReturn(CONNECT_PROTOCOL_V2); + + expectRebalance(1, Collections.emptyList(), Collections.emptyList()); + expectPostRebalanceCatchup(SNAPSHOT); + // First rebalance: poll indefinitely as no key has been read yet, so expiration doesn't come into play + member.poll(Long.MAX_VALUE); + EasyMock.expectLastCall(); + + expectRebalance(2, Collections.emptyList(), Collections.emptyList()); + SessionKey initialKey = new SessionKey(EasyMock.mock(SecretKey.class), 0); + ClusterConfigState snapshotWithKey = new ClusterConfigState(2, initialKey, Collections.singletonMap(CONN1, 3), + Collections.singletonMap(CONN1, CONN1_CONFIG), Collections.singletonMap(CONN1, TargetState.STARTED), + TASK_CONFIGS_MAP, Collections.emptySet()); + expectPostRebalanceCatchup(snapshotWithKey); + // Second rebalance: poll indefinitely as worker is follower, so expiration still doesn't come into play + member.poll(Long.MAX_VALUE); + EasyMock.expectLastCall(); + + expectRebalance(2, Collections.emptyList(), Collections.emptyList(), "member", MEMBER_URL); + Capture updatedKey = EasyMock.newCapture(); + configBackingStore.putSessionKey(EasyMock.capture(updatedKey)); + EasyMock.expectLastCall().andAnswer(() -> { + configUpdateListener.onSessionKeyUpdate(updatedKey.getValue()); + 
return null; + }); + // Third rebalance: poll for a limited time as worker has become leader and must wake up for key expiration + Capture pollTimeout = EasyMock.newCapture(); + member.poll(EasyMock.captureLong(pollTimeout)); + EasyMock.expectLastCall(); + + PowerMock.replayAll(); + + herder.tick(); + configUpdateListener.onSessionKeyUpdate(initialKey); + herder.tick(); + herder.tick(); + + assertTrue(pollTimeout.getValue() <= DistributedConfig.INTER_WORKER_KEY_TTL_MS_MS_DEFAULT); + + PowerMock.verifyAll(); + } + + @Test + public void testKeyRotationDisabledWhenWorkerBecomesFollower() throws Exception { + EasyMock.expect(member.memberId()).andStubReturn("member"); + EasyMock.expect(member.currentProtocolVersion()).andStubReturn(CONNECT_PROTOCOL_V2); + + expectRebalance(1, Collections.emptyList(), Collections.emptyList(), "member", MEMBER_URL); + SecretKey initialSecretKey = EasyMock.mock(SecretKey.class); + EasyMock.expect(initialSecretKey.getAlgorithm()).andReturn(DistributedConfig.INTER_WORKER_KEY_GENERATION_ALGORITHM_DEFAULT).anyTimes(); + EasyMock.expect(initialSecretKey.getEncoded()).andReturn(new byte[32]).anyTimes(); + SessionKey initialKey = new SessionKey(initialSecretKey, time.milliseconds()); + ClusterConfigState snapshotWithKey = new ClusterConfigState(1, initialKey, Collections.singletonMap(CONN1, 3), + Collections.singletonMap(CONN1, CONN1_CONFIG), Collections.singletonMap(CONN1, TargetState.STARTED), + TASK_CONFIGS_MAP, Collections.emptySet()); + expectPostRebalanceCatchup(snapshotWithKey); + // First rebalance: poll for a limited time as worker is leader and must wake up for key expiration + Capture firstPollTimeout = EasyMock.newCapture(); + member.poll(EasyMock.captureLong(firstPollTimeout)); + EasyMock.expectLastCall(); + + expectRebalance(1, Collections.emptyList(), Collections.emptyList()); + // Second rebalance: poll indefinitely as worker is no longer leader, so key expiration doesn't come into play + member.poll(Long.MAX_VALUE); + EasyMock.expectLastCall(); + + PowerMock.replayAll(initialSecretKey); + + configUpdateListener.onSessionKeyUpdate(initialKey); + herder.tick(); + assertTrue(firstPollTimeout.getValue() <= DistributedConfig.INTER_WORKER_KEY_TTL_MS_MS_DEFAULT); + herder.tick(); + + PowerMock.verifyAll(); + } + + @Test + public void testPutTaskConfigsSignatureNotRequiredV0() { + Callback taskConfigCb = EasyMock.mock(Callback.class); + + member.wakeup(); + EasyMock.expectLastCall().once(); + EasyMock.expect(member.currentProtocolVersion()).andReturn(CONNECT_PROTOCOL_V0).anyTimes(); + PowerMock.replayAll(taskConfigCb); + + herder.putTaskConfigs(CONN1, TASK_CONFIGS, taskConfigCb, null); + + PowerMock.verifyAll(); + } + @Test + public void testPutTaskConfigsSignatureNotRequiredV1() { + Callback taskConfigCb = EasyMock.mock(Callback.class); + + member.wakeup(); + EasyMock.expectLastCall().once(); + EasyMock.expect(member.currentProtocolVersion()).andReturn(CONNECT_PROTOCOL_V1).anyTimes(); + PowerMock.replayAll(taskConfigCb); + + herder.putTaskConfigs(CONN1, TASK_CONFIGS, taskConfigCb, null); + + PowerMock.verifyAll(); + } + + @Test + public void testPutTaskConfigsMissingRequiredSignature() { + Callback taskConfigCb = EasyMock.mock(Callback.class); + Capture errorCapture = Capture.newInstance(); + taskConfigCb.onCompletion(capture(errorCapture), EasyMock.eq(null)); + EasyMock.expectLastCall().once(); + + EasyMock.expect(member.currentProtocolVersion()).andReturn(CONNECT_PROTOCOL_V2).anyTimes(); + PowerMock.replayAll(taskConfigCb); + + herder.putTaskConfigs(CONN1, 
TASK_CONFIGS, taskConfigCb, null); + + PowerMock.verifyAll(); + assertTrue(errorCapture.getValue() instanceof BadRequestException); + } + + @Test + public void testPutTaskConfigsDisallowedSignatureAlgorithm() { + Callback taskConfigCb = EasyMock.mock(Callback.class); + Capture errorCapture = Capture.newInstance(); + taskConfigCb.onCompletion(capture(errorCapture), EasyMock.eq(null)); + EasyMock.expectLastCall().once(); + + EasyMock.expect(member.currentProtocolVersion()).andReturn(CONNECT_PROTOCOL_V2).anyTimes(); + + InternalRequestSignature signature = EasyMock.mock(InternalRequestSignature.class); + EasyMock.expect(signature.keyAlgorithm()).andReturn("HmacSHA489").anyTimes(); + + PowerMock.replayAll(taskConfigCb, signature); + + herder.putTaskConfigs(CONN1, TASK_CONFIGS, taskConfigCb, signature); + + PowerMock.verifyAll(); + assertTrue(errorCapture.getValue() instanceof BadRequestException); + } + + @Test + public void testPutTaskConfigsInvalidSignature() { + Callback taskConfigCb = EasyMock.mock(Callback.class); + Capture errorCapture = Capture.newInstance(); + taskConfigCb.onCompletion(capture(errorCapture), EasyMock.eq(null)); + EasyMock.expectLastCall().once(); + + EasyMock.expect(member.currentProtocolVersion()).andReturn(CONNECT_PROTOCOL_V2).anyTimes(); + + InternalRequestSignature signature = EasyMock.mock(InternalRequestSignature.class); + EasyMock.expect(signature.keyAlgorithm()).andReturn("HmacSHA256").anyTimes(); + EasyMock.expect(signature.isValid(EasyMock.anyObject())).andReturn(false).anyTimes(); + + PowerMock.replayAll(taskConfigCb, signature); + + herder.putTaskConfigs(CONN1, TASK_CONFIGS, taskConfigCb, signature); + + PowerMock.verifyAll(); + assertTrue(errorCapture.getValue() instanceof ConnectRestException); + assertEquals(FORBIDDEN.getStatusCode(), ((ConnectRestException) errorCapture.getValue()).statusCode()); + } + + @Test + public void testPutTaskConfigsValidRequiredSignature() { + Callback taskConfigCb = EasyMock.mock(Callback.class); + + member.wakeup(); + EasyMock.expectLastCall().once(); + EasyMock.expect(member.currentProtocolVersion()).andReturn(CONNECT_PROTOCOL_V2).anyTimes(); + + InternalRequestSignature signature = EasyMock.mock(InternalRequestSignature.class); + EasyMock.expect(signature.keyAlgorithm()).andReturn("HmacSHA256").anyTimes(); + EasyMock.expect(signature.isValid(EasyMock.anyObject())).andReturn(true).anyTimes(); + + PowerMock.replayAll(taskConfigCb, signature); + + herder.putTaskConfigs(CONN1, TASK_CONFIGS, taskConfigCb, signature); + + PowerMock.verifyAll(); + } + + @Test + public void testFailedToWriteSessionKey() throws Exception { + // First tick -- after joining the group, we try to write a new + // session key to the config topic, and fail + EasyMock.expect(member.memberId()).andStubReturn("leader"); + EasyMock.expect(member.currentProtocolVersion()).andStubReturn(CONNECT_PROTOCOL_V2); + expectRebalance(1, Collections.emptyList(), Collections.emptyList()); + expectPostRebalanceCatchup(SNAPSHOT); + configBackingStore.putSessionKey(anyObject(SessionKey.class)); + EasyMock.expectLastCall().andThrow(new ConnectException("Oh no!")); + + // Second tick -- we read to the end of the config topic first, + // then ensure we're still active in the group + // then try a second time to write a new session key, + // then finally begin polling for group activity + expectPostRebalanceCatchup(SNAPSHOT); + member.ensureActive(); + PowerMock.expectLastCall(); + configBackingStore.putSessionKey(anyObject(SessionKey.class)); + EasyMock.expectLastCall(); + 
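+ // With the session key written successfully on the second attempt, the herder resumes polling for group activity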
member.poll(EasyMock.anyInt()); + PowerMock.expectLastCall(); + + PowerMock.replayAll(); + + herder.tick(); + herder.tick(); + + PowerMock.verifyAll(); + } + + @Test + public void testFailedToReadBackNewlyWrittenSessionKey() throws Exception { + SecretKey secretKey = EasyMock.niceMock(SecretKey.class); + EasyMock.expect(secretKey.getAlgorithm()).andReturn(INTER_WORKER_KEY_GENERATION_ALGORITHM_DEFAULT); + EasyMock.expect(secretKey.getEncoded()).andReturn(new byte[32]); + SessionKey sessionKey = new SessionKey(secretKey, time.milliseconds()); + ClusterConfigState snapshotWithSessionKey = new ClusterConfigState(1, sessionKey, Collections.singletonMap(CONN1, 3), + Collections.singletonMap(CONN1, CONN1_CONFIG), Collections.singletonMap(CONN1, TargetState.STARTED), + TASK_CONFIGS_MAP, Collections.emptySet()); + + // First tick -- after joining the group, we try to write a new session key to + // the config topic, and fail (in this case, we're trying to simulate that we've + // actually written the key successfully, but haven't been able to read it back + // from the config topic, so to the herder it looks the same as if it'd just failed + // to write the key) + EasyMock.expect(member.memberId()).andStubReturn("leader"); + EasyMock.expect(member.currentProtocolVersion()).andStubReturn(CONNECT_PROTOCOL_V2); + expectRebalance(1, Collections.emptyList(), Collections.emptyList()); + expectPostRebalanceCatchup(SNAPSHOT); + configBackingStore.putSessionKey(anyObject(SessionKey.class)); + EasyMock.expectLastCall().andThrow(new ConnectException("Oh no!")); + + // Second tick -- we read to the end of the config topic first, and pick up + // the session key that we were able to write the last time, + // then ensure we're still active in the group + // then finally begin polling for group activity + // Importantly, we do not try to write a new session key this time around + configBackingStore.refresh(EasyMock.anyLong(), EasyMock.anyObject(TimeUnit.class)); + EasyMock.expectLastCall().andAnswer(() -> { + configUpdateListener.onSessionKeyUpdate(sessionKey); + return null; + }); + EasyMock.expect(configBackingStore.snapshot()).andReturn(snapshotWithSessionKey); + member.ensureActive(); + PowerMock.expectLastCall(); + member.poll(EasyMock.anyInt()); + PowerMock.expectLastCall(); + + PowerMock.replayAll(secretKey); + + herder.tick(); + herder.tick(); + + PowerMock.verifyAll(); + } + + @Test + public void testKeyExceptionDetection() { + assertFalse(herder.isPossibleExpiredKeyException( + time.milliseconds(), + new RuntimeException() + )); + assertFalse(herder.isPossibleExpiredKeyException( + time.milliseconds(), + new BadRequestException("") + )); + assertFalse(herder.isPossibleExpiredKeyException( + time.milliseconds() - TimeUnit.MINUTES.toMillis(2), + new ConnectRestException(FORBIDDEN.getStatusCode(), "") + )); + assertTrue(herder.isPossibleExpiredKeyException( + time.milliseconds(), + new ConnectRestException(FORBIDDEN.getStatusCode(), "") + )); + } + + @Test + public void testInconsistentConfigs() { + // FIXME: if we have inconsistent configs, we need to request forced reconfig + write of the connector's task configs + // This requires inter-worker communication, so needs the REST API + } + + + @Test + public void testThreadNames() { + assertTrue(Whitebox.getInternalState(herder, "herderExecutor"). + getThreadFactory().newThread(EMPTY_RUNNABLE).getName().startsWith(DistributedHerder.class.getSimpleName())); + + assertTrue(Whitebox.getInternalState(herder, "forwardRequestExecutor"). 
+        getThreadFactory().newThread(EMPTY_RUNNABLE).getName().startsWith("ForwardRequestExecutor"));
+
+        assertTrue(Whitebox.getInternalState(herder, "startAndStopExecutor").
+                getThreadFactory().newThread(EMPTY_RUNNABLE).getName().startsWith("StartAndStopExecutor"));
+    }
+
+    @Test
+    public void testHerderStopServicesClosesUponShutdown() {
+        assertEquals(1, shutdownCalled.getCount());
+        herder.stopServices();
+        assertEquals(0, shutdownCalled.getCount());
+    }
+
+    private void expectRebalance(final long offset,
+                                 final List<String> assignedConnectors,
+                                 final List<ConnectorTaskId> assignedTasks) {
+        expectRebalance(Collections.emptyList(), Collections.emptyList(),
+                ConnectProtocol.Assignment.NO_ERROR, offset, assignedConnectors, assignedTasks, 0);
+    }
+
+    private void expectRebalance(final long offset,
+                                 final List<String> assignedConnectors,
+                                 final List<ConnectorTaskId> assignedTasks,
+                                 String leader, String leaderUrl) {
+        expectRebalance(Collections.emptyList(), Collections.emptyList(),
+                ConnectProtocol.Assignment.NO_ERROR, offset, leader, leaderUrl, assignedConnectors, assignedTasks, 0);
+    }
+
+    // Handles common initial part of rebalance callback. Does not handle instantiation of connectors and tasks.
+    private void expectRebalance(final Collection<String> revokedConnectors,
+                                 final List<ConnectorTaskId> revokedTasks,
+                                 final short error,
+                                 final long offset,
+                                 final List<String> assignedConnectors,
+                                 final List<ConnectorTaskId> assignedTasks) {
+        expectRebalance(revokedConnectors, revokedTasks, error, offset, assignedConnectors, assignedTasks, 0);
+    }
+
+    // Handles common initial part of rebalance callback. Does not handle instantiation of connectors and tasks.
+    private void expectRebalance(final Collection<String> revokedConnectors,
+                                 final List<ConnectorTaskId> revokedTasks,
+                                 final short error,
+                                 final long offset,
+                                 final List<String> assignedConnectors,
+                                 final List<ConnectorTaskId> assignedTasks,
+                                 int delay) {
+        expectRebalance(revokedConnectors, revokedTasks, error, offset, "leader", "leaderUrl", assignedConnectors, assignedTasks, delay);
+    }
+
+    // Handles common initial part of rebalance callback. Does not handle instantiation of connectors and tasks.
+    private void expectRebalance(final Collection<String> revokedConnectors,
+                                 final List<ConnectorTaskId> revokedTasks,
+                                 final short error,
+                                 final long offset,
+                                 String leader,
+                                 String leaderUrl,
+                                 final List<String> assignedConnectors,
+                                 final List<ConnectorTaskId> assignedTasks,
+                                 int delay) {
+        member.ensureActive();
+        PowerMock.expectLastCall().andAnswer(() -> {
+            ExtendedAssignment assignment;
+            if (!revokedConnectors.isEmpty() || !revokedTasks.isEmpty()) {
+                rebalanceListener.onRevoked(leader, revokedConnectors, revokedTasks);
+            }
+
+            if (connectProtocolVersion == CONNECT_PROTOCOL_V0) {
+                assignment = new ExtendedAssignment(
+                        connectProtocolVersion, error, leader, leaderUrl, offset,
+                        assignedConnectors, assignedTasks,
+                        Collections.emptyList(), Collections.emptyList(), 0);
+            } else {
+                assignment = new ExtendedAssignment(
+                        connectProtocolVersion, error, leader, leaderUrl, offset,
+                        assignedConnectors, assignedTasks,
+                        new ArrayList<>(revokedConnectors), new ArrayList<>(revokedTasks), delay);
+            }
+            rebalanceListener.onAssigned(assignment, 3);
+            time.sleep(100L);
+            return null;
+        });
+
+        if (!revokedConnectors.isEmpty()) {
+            for (String connector : revokedConnectors) {
+                worker.stopAndAwaitConnector(connector);
+                PowerMock.expectLastCall();
+            }
+        }
+
+        if (!revokedTasks.isEmpty()) {
+            worker.stopAndAwaitTask(EasyMock.anyObject(ConnectorTaskId.class));
+            PowerMock.expectLastCall();
+        }
+
+        if (!revokedConnectors.isEmpty()) {
+            statusBackingStore.flush();
+            PowerMock.expectLastCall();
+        }
+
+        member.wakeup();
+        PowerMock.expectLastCall();
+    }
+
+    private void expectPostRebalanceCatchup(final ClusterConfigState readToEndSnapshot) throws TimeoutException {
+        configBackingStore.refresh(EasyMock.anyLong(), EasyMock.anyObject(TimeUnit.class));
+        EasyMock.expectLastCall();
+        EasyMock.expect(configBackingStore.snapshot()).andReturn(readToEndSnapshot);
+    }
+
+    private void assertStatistics(int expectedEpoch, int completedRebalances, double rebalanceTime, double millisSinceLastRebalance) {
+        String expectedLeader = completedRebalances <= 0 ? null : "leaderUrl";
+        assertStatistics(expectedLeader, false, expectedEpoch, completedRebalances, rebalanceTime, millisSinceLastRebalance);
+    }
+
+    private void assertStatistics(String expectedLeader, boolean isRebalancing, int expectedEpoch, int completedRebalances, double rebalanceTime, double millisSinceLastRebalance) {
+        HerderMetrics herderMetrics = herder.herderMetrics();
+        MetricGroup group = herderMetrics.metricGroup();
+        double epoch = MockConnectMetrics.currentMetricValueAsDouble(metrics, group, "epoch");
+        String leader = MockConnectMetrics.currentMetricValueAsString(metrics, group, "leader-name");
+        double rebalanceCompletedTotal = MockConnectMetrics.currentMetricValueAsDouble(metrics, group, "completed-rebalances-total");
+        double rebalancing = MockConnectMetrics.currentMetricValueAsDouble(metrics, group, "rebalancing");
+        double rebalanceTimeMax = MockConnectMetrics.currentMetricValueAsDouble(metrics, group, "rebalance-max-time-ms");
+        double rebalanceTimeAvg = MockConnectMetrics.currentMetricValueAsDouble(metrics, group, "rebalance-avg-time-ms");
+        double rebalanceTimeSinceLast = MockConnectMetrics.currentMetricValueAsDouble(metrics, group, "time-since-last-rebalance-ms");
+
+        assertEquals(expectedEpoch, epoch, 0.0001d);
+        assertEquals(expectedLeader, leader);
+        assertEquals(completedRebalances, rebalanceCompletedTotal, 0.0001d);
+        assertEquals(isRebalancing ?
1.0d : 0.0d, rebalancing, 0.0001d); + assertEquals(millisSinceLastRebalance, rebalanceTimeSinceLast, 0.0001d); + if (rebalanceTime <= 0L) { + assertEquals(Double.NaN, rebalanceTimeMax, 0.0001d); + assertEquals(Double.NaN, rebalanceTimeAvg, 0.0001d); + } else { + assertEquals(rebalanceTime, rebalanceTimeMax, 0.0001d); + assertEquals(rebalanceTime, rebalanceTimeAvg, 0.0001d); + } + } + + @Test + public void processRestartRequestsFailureSuppression() { + member.wakeup(); + PowerMock.expectLastCall().anyTimes(); + + final String connectorName = "foo"; + RestartRequest restartRequest = new RestartRequest(connectorName, false, false); + EasyMock.expect(herder.buildRestartPlan(restartRequest)).andThrow(new RuntimeException()).anyTimes(); + + PowerMock.replayAll(); + + configUpdateListener.onRestartRequest(restartRequest); + assertEquals(1, herder.pendingRestartRequests.size()); + herder.processRestartRequests(); + assertTrue(herder.pendingRestartRequests.isEmpty()); + } + + @Test + public void processRestartRequestsDequeue() { + member.wakeup(); + PowerMock.expectLastCall().anyTimes(); + + EasyMock.expect(herder.buildRestartPlan(EasyMock.anyObject(RestartRequest.class))).andReturn(Optional.empty()).anyTimes(); + + PowerMock.replayAll(); + + RestartRequest restartRequest = new RestartRequest("foo", false, false); + configUpdateListener.onRestartRequest(restartRequest); + restartRequest = new RestartRequest("bar", false, false); + configUpdateListener.onRestartRequest(restartRequest); + assertEquals(2, herder.pendingRestartRequests.size()); + herder.processRestartRequests(); + assertTrue(herder.pendingRestartRequests.isEmpty()); + } + + @Test + public void preserveHighestImpactRestartRequest() { + member.wakeup(); + PowerMock.expectLastCall().anyTimes(); + PowerMock.replayAll(); + + final String connectorName = "foo"; + RestartRequest restartRequest = new RestartRequest(connectorName, false, false); + configUpdateListener.onRestartRequest(restartRequest); + + //will overwrite as this is higher impact + restartRequest = new RestartRequest(connectorName, false, true); + configUpdateListener.onRestartRequest(restartRequest); + assertEquals(1, herder.pendingRestartRequests.size()); + assertFalse(herder.pendingRestartRequests.get(connectorName).onlyFailed()); + assertTrue(herder.pendingRestartRequests.get(connectorName).includeTasks()); + + //will be ignored as the existing request has higher impact + restartRequest = new RestartRequest(connectorName, true, false); + configUpdateListener.onRestartRequest(restartRequest); + assertEquals(1, herder.pendingRestartRequests.size()); + //compare against existing request + assertFalse(herder.pendingRestartRequests.get(connectorName).onlyFailed()); + assertTrue(herder.pendingRestartRequests.get(connectorName).includeTasks()); + } + + // We need to use a real class here due to some issue with mocking java.lang.Class + private abstract class BogusSourceConnector extends SourceConnector { + } + + private abstract class BogusSourceTask extends SourceTask { + } + +} diff --git a/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/distributed/IncrementalCooperativeAssignorTest.java b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/distributed/IncrementalCooperativeAssignorTest.java new file mode 100644 index 0000000..0fe1531 --- /dev/null +++ b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/distributed/IncrementalCooperativeAssignorTest.java @@ -0,0 +1,1469 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one 
or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.runtime.distributed; + +import org.apache.kafka.clients.consumer.internals.RequestFuture; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.connect.runtime.TargetState; +import org.apache.kafka.connect.runtime.distributed.WorkerCoordinator.ConnectorsAndTasks; +import org.apache.kafka.connect.util.ConnectorTaskId; +import org.junit.After; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.mockito.ArgumentCaptor; +import org.mockito.Captor; +import org.mockito.Mock; +import org.mockito.Mockito; +import org.mockito.junit.MockitoJUnit; +import org.mockito.junit.MockitoRule; + +import java.util.AbstractMap.SimpleEntry; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +import static org.apache.kafka.connect.runtime.distributed.IncrementalCooperativeConnectProtocol.CONNECT_PROTOCOL_V1; +import static org.apache.kafka.connect.runtime.distributed.IncrementalCooperativeConnectProtocol.CONNECT_PROTOCOL_V2; +import static org.apache.kafka.connect.runtime.distributed.WorkerCoordinator.WorkerLoad; +import static org.hamcrest.CoreMatchers.hasItems; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; +import static org.junit.runners.Parameterized.Parameter; +import static org.junit.runners.Parameterized.Parameters; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.when; + +public class IncrementalCooperativeAssignorTest { + @Rule + public MockitoRule rule = MockitoJUnit.rule(); + + @Mock + private WorkerCoordinator coordinator; + + @Captor + ArgumentCaptor> assignmentsCapture; + + @Parameters + public static Iterable mode() { + return Arrays.asList(new Object[][] {{CONNECT_PROTOCOL_V1, CONNECT_PROTOCOL_V2}}); + } + + @Parameter + public short protocolVersion; + + private ClusterConfigState configState; + private Map memberConfigs; + private Map expectedMemberConfigs; + private long offset; + private String leader; + private String leaderUrl; + private Time time; + private int rebalanceDelay; + private IncrementalCooperativeAssignor assignor; + private int 
rebalanceNum; + Map assignments; + Map returnedAssignments; + + @Before + public void setup() { + leader = "worker1"; + leaderUrl = expectedLeaderUrl(leader); + offset = 10; + configState = clusterConfigState(offset, 2, 4); + memberConfigs = memberConfigs(leader, offset, 1, 1); + time = Time.SYSTEM; + rebalanceDelay = DistributedConfig.SCHEDULED_REBALANCE_MAX_DELAY_MS_DEFAULT; + assignments = new HashMap<>(); + initAssignor(); + } + + @After + public void teardown() { + verifyNoMoreInteractions(coordinator); + } + + public void initAssignor() { + assignor = Mockito.spy(new IncrementalCooperativeAssignor( + new LogContext(), + time, + rebalanceDelay)); + assignor.previousGenerationId = 1000; + } + + @Test + public void testTaskAssignmentWhenWorkerJoins() { + when(coordinator.configSnapshot()).thenReturn(configState); + doReturn(Collections.EMPTY_MAP).when(assignor).serializeAssignments(assignmentsCapture.capture()); + + // First assignment with 1 worker and 2 connectors configured but not yet assigned + expectGeneration(); + assignor.performTaskAssignment(leader, offset, memberConfigs, coordinator, protocolVersion); + ++rebalanceNum; + returnedAssignments = assignmentsCapture.getValue(); + assertDelay(0, returnedAssignments); + expectedMemberConfigs = memberConfigs(leader, offset, returnedAssignments); + assertNoReassignments(memberConfigs, expectedMemberConfigs); + assertAssignment(2, 8, 0, 0, "worker1"); + + // Second assignment with a second worker joining and all connectors running on previous worker + applyAssignments(returnedAssignments); + memberConfigs = memberConfigs(leader, offset, assignments); + memberConfigs.put("worker2", new ExtendedWorkerState(leaderUrl, offset, null)); + expectGeneration(); + assignor.performTaskAssignment(leader, offset, memberConfigs, coordinator, protocolVersion); + ++rebalanceNum; + returnedAssignments = assignmentsCapture.getValue(); + assertDelay(0, returnedAssignments); + expectedMemberConfigs = memberConfigs(leader, offset, returnedAssignments); + assertNoReassignments(memberConfigs, expectedMemberConfigs); + assertAssignment(0, 0, 1, 4, "worker1", "worker2"); + + // Third assignment after revocations + applyAssignments(returnedAssignments); + memberConfigs = memberConfigs(leader, offset, assignments); + expectGeneration(); + assignor.performTaskAssignment(leader, offset, memberConfigs, coordinator, protocolVersion); + ++rebalanceNum; + returnedAssignments = assignmentsCapture.getValue(); + assertDelay(0, returnedAssignments); + expectedMemberConfigs = memberConfigs(leader, offset, returnedAssignments); + assertNoReassignments(memberConfigs, expectedMemberConfigs); + assertAssignment(1, 4, 0, 0, "worker1", "worker2"); + + // A fourth rebalance should not change assignments + applyAssignments(returnedAssignments); + memberConfigs = memberConfigs(leader, offset, assignments); + expectGeneration(); + assignor.performTaskAssignment(leader, offset, memberConfigs, coordinator, protocolVersion); + ++rebalanceNum; + returnedAssignments = assignmentsCapture.getValue(); + assertDelay(0, returnedAssignments); + expectedMemberConfigs = memberConfigs(leader, offset, returnedAssignments); + assertNoReassignments(memberConfigs, expectedMemberConfigs); + assertAssignment(0, 0, 0, 0, "worker1", "worker2"); + + verify(coordinator, times(rebalanceNum)).configSnapshot(); + verify(coordinator, times(rebalanceNum)).leaderState(any()); + verify(coordinator, times(2 * rebalanceNum)).generationId(); + verify(coordinator, times(rebalanceNum)).memberId(); + 
verify(coordinator, times(rebalanceNum)).lastCompletedGenerationId(); + } + + @Test + public void testTaskAssignmentWhenWorkerLeavesPermanently() { + // Customize assignor for this test case + time = new MockTime(); + initAssignor(); + + when(coordinator.configSnapshot()).thenReturn(configState); + doReturn(Collections.EMPTY_MAP).when(assignor).serializeAssignments(assignmentsCapture.capture()); + + // First assignment with 2 workers and 2 connectors configured but not yet assigned + memberConfigs.put("worker2", new ExtendedWorkerState(leaderUrl, offset, null)); + expectGeneration(); + assignor.performTaskAssignment(leader, offset, memberConfigs, coordinator, protocolVersion); + ++rebalanceNum; + returnedAssignments = assignmentsCapture.getValue(); + assertDelay(0, returnedAssignments); + expectedMemberConfigs = memberConfigs(leader, offset, returnedAssignments); + assertNoReassignments(memberConfigs, expectedMemberConfigs); + assertAssignment(2, 8, 0, 0, "worker1", "worker2"); + + // Second assignment with only one worker remaining in the group. The worker that left the + // group was a follower. No re-assignments take place immediately and the count + // down for the rebalance delay starts + applyAssignments(returnedAssignments); + assignments.remove("worker2"); + memberConfigs = memberConfigs(leader, offset, assignments); + expectGeneration(); + assignor.performTaskAssignment(leader, offset, memberConfigs, coordinator, protocolVersion); + ++rebalanceNum; + returnedAssignments = assignmentsCapture.getValue(); + assertDelay(rebalanceDelay, returnedAssignments); + expectedMemberConfigs = memberConfigs(leader, offset, returnedAssignments); + assertNoReassignments(memberConfigs, expectedMemberConfigs); + assertAssignment(0, 0, 0, 0, "worker1"); + + time.sleep(rebalanceDelay / 2); + + // Third (incidental) assignment with still only one worker in the group. 
Max delay has not + // been reached yet + applyAssignments(returnedAssignments); + memberConfigs = memberConfigs(leader, offset, assignments); + expectGeneration(); + assignor.performTaskAssignment(leader, offset, memberConfigs, coordinator, protocolVersion); + ++rebalanceNum; + returnedAssignments = assignmentsCapture.getValue(); + assertDelay(rebalanceDelay / 2, returnedAssignments); + expectedMemberConfigs = memberConfigs(leader, offset, returnedAssignments); + assertNoReassignments(memberConfigs, expectedMemberConfigs); + assertAssignment(0, 0, 0, 0, "worker1"); + + time.sleep(rebalanceDelay / 2 + 1); + + // Fourth assignment after delay expired + applyAssignments(returnedAssignments); + memberConfigs = memberConfigs(leader, offset, assignments); + expectGeneration(); + assignor.performTaskAssignment(leader, offset, memberConfigs, coordinator, protocolVersion); + ++rebalanceNum; + returnedAssignments = assignmentsCapture.getValue(); + assertDelay(0, returnedAssignments); + expectedMemberConfigs = memberConfigs(leader, offset, returnedAssignments); + assertNoReassignments(memberConfigs, expectedMemberConfigs); + assertAssignment(1, 4, 0, 0, "worker1"); + + verify(coordinator, times(rebalanceNum)).configSnapshot(); + verify(coordinator, times(rebalanceNum)).leaderState(any()); + verify(coordinator, times(2 * rebalanceNum)).generationId(); + verify(coordinator, times(rebalanceNum)).memberId(); + verify(coordinator, times(rebalanceNum)).lastCompletedGenerationId(); + } + + @Test + public void testTaskAssignmentWhenWorkerBounces() { + // Customize assignor for this test case + time = new MockTime(); + initAssignor(); + + when(coordinator.configSnapshot()).thenReturn(configState); + doReturn(Collections.EMPTY_MAP).when(assignor).serializeAssignments(assignmentsCapture.capture()); + + // First assignment with 2 workers and 2 connectors configured but not yet assigned + memberConfigs.put("worker2", new ExtendedWorkerState(leaderUrl, offset, null)); + expectGeneration(); + assignor.performTaskAssignment(leader, offset, memberConfigs, coordinator, protocolVersion); + ++rebalanceNum; + returnedAssignments = assignmentsCapture.getValue(); + assertDelay(0, returnedAssignments); + expectedMemberConfigs = memberConfigs(leader, offset, returnedAssignments); + assertNoReassignments(memberConfigs, expectedMemberConfigs); + assertAssignment(2, 8, 0, 0, "worker1", "worker2"); + + // Second assignment with only one worker remaining in the group. The worker that left the + // group was a follower. No re-assignments take place immediately and the count + // down for the rebalance delay starts + applyAssignments(returnedAssignments); + assignments.remove("worker2"); + memberConfigs = memberConfigs(leader, offset, assignments); + expectGeneration(); + assignor.performTaskAssignment(leader, offset, memberConfigs, coordinator, protocolVersion); + ++rebalanceNum; + returnedAssignments = assignmentsCapture.getValue(); + assertDelay(rebalanceDelay, returnedAssignments); + expectedMemberConfigs = memberConfigs(leader, offset, returnedAssignments); + assertNoReassignments(memberConfigs, expectedMemberConfigs); + assertAssignment(0, 0, 0, 0, "worker1"); + + time.sleep(rebalanceDelay / 2); + + // Third (incidental) assignment with still only one worker in the group. 
Max delay has not + // been reached yet + applyAssignments(returnedAssignments); + memberConfigs = memberConfigs(leader, offset, assignments); + expectGeneration(); + assignor.performTaskAssignment(leader, offset, memberConfigs, coordinator, protocolVersion); + ++rebalanceNum; + returnedAssignments = assignmentsCapture.getValue(); + assertDelay(rebalanceDelay / 2, returnedAssignments); + expectedMemberConfigs = memberConfigs(leader, offset, returnedAssignments); + assertNoReassignments(memberConfigs, expectedMemberConfigs); + assertAssignment(0, 0, 0, 0, "worker1"); + + time.sleep(rebalanceDelay / 4); + + // Fourth assignment with the second worker returning before the delay expires + // Since the delay is still active, lost assignments are not reassigned yet + applyAssignments(returnedAssignments); + memberConfigs = memberConfigs(leader, offset, assignments); + memberConfigs.put("worker2", new ExtendedWorkerState(leaderUrl, offset, null)); + expectGeneration(); + assignor.performTaskAssignment(leader, offset, memberConfigs, coordinator, protocolVersion); + ++rebalanceNum; + returnedAssignments = assignmentsCapture.getValue(); + assertDelay(rebalanceDelay / 4, returnedAssignments); + expectedMemberConfigs = memberConfigs(leader, offset, returnedAssignments); + assertNoReassignments(memberConfigs, expectedMemberConfigs); + assertAssignment(0, 0, 0, 0, "worker1", "worker2"); + + time.sleep(rebalanceDelay / 4); + + // Fifth assignment with the same two workers. The delay has expired, so the lost + // assignments ought to be assigned to the worker that has appeared as returned. + applyAssignments(returnedAssignments); + memberConfigs = memberConfigs(leader, offset, assignments); + expectGeneration(); + assignor.performTaskAssignment(leader, offset, memberConfigs, coordinator, protocolVersion); + ++rebalanceNum; + returnedAssignments = assignmentsCapture.getValue(); + assertDelay(0, returnedAssignments); + expectedMemberConfigs = memberConfigs(leader, offset, returnedAssignments); + assertNoReassignments(memberConfigs, expectedMemberConfigs); + assertAssignment(1, 4, 0, 0, "worker1", "worker2"); + + verify(coordinator, times(rebalanceNum)).configSnapshot(); + verify(coordinator, times(rebalanceNum)).leaderState(any()); + verify(coordinator, times(2 * rebalanceNum)).generationId(); + verify(coordinator, times(rebalanceNum)).memberId(); + verify(coordinator, times(rebalanceNum)).lastCompletedGenerationId(); + } + + @Test + public void testTaskAssignmentWhenLeaderLeavesPermanently() { + // Customize assignor for this test case + time = new MockTime(); + initAssignor(); + + when(coordinator.configSnapshot()).thenReturn(configState); + doReturn(Collections.EMPTY_MAP).when(assignor).serializeAssignments(assignmentsCapture.capture()); + + // First assignment with 3 workers and 2 connectors configured but not yet assigned + memberConfigs.put("worker2", new ExtendedWorkerState(leaderUrl, offset, null)); + memberConfigs.put("worker3", new ExtendedWorkerState(leaderUrl, offset, null)); + expectGeneration(); + assignor.performTaskAssignment(leader, offset, memberConfigs, coordinator, protocolVersion); + ++rebalanceNum; + returnedAssignments = assignmentsCapture.getValue(); + assertDelay(0, returnedAssignments); + expectedMemberConfigs = memberConfigs(leader, offset, returnedAssignments); + assertNoReassignments(memberConfigs, expectedMemberConfigs); + assertAssignment(2, 8, 0, 0, "worker1", "worker2", "worker3"); + + // Second assignment with two workers remaining in the group. 
The worker that left the + // group was the leader. The new leader has no previous assignments and is not tracking a + // delay upon a leader's exit + applyAssignments(returnedAssignments); + assignments.remove("worker1"); + leader = "worker2"; + leaderUrl = expectedLeaderUrl(leader); + memberConfigs = memberConfigs(leader, offset, assignments); + // The fact that the leader bounces means that the assignor starts from a clean slate + initAssignor(); + + // Capture needs to be reset to point to the new assignor + doReturn(Collections.EMPTY_MAP).when(assignor).serializeAssignments(assignmentsCapture.capture()); + expectGeneration(); + assignor.performTaskAssignment(leader, offset, memberConfigs, coordinator, protocolVersion); + ++rebalanceNum; + returnedAssignments = assignmentsCapture.getValue(); + assertDelay(0, returnedAssignments); + expectedMemberConfigs = memberConfigs(leader, offset, returnedAssignments); + assertNoReassignments(memberConfigs, expectedMemberConfigs); + assertAssignment(1, 3, 0, 0, "worker2", "worker3"); + + // Third (incidental) assignment with still only one worker in the group. + applyAssignments(returnedAssignments); + memberConfigs = memberConfigs(leader, offset, assignments); + expectGeneration(); + assignor.performTaskAssignment(leader, offset, memberConfigs, coordinator, protocolVersion); + ++rebalanceNum; + returnedAssignments = assignmentsCapture.getValue(); + expectedMemberConfigs = memberConfigs(leader, offset, returnedAssignments); + assertNoReassignments(memberConfigs, expectedMemberConfigs); + assertAssignment(0, 0, 0, 0, "worker2", "worker3"); + + verify(coordinator, times(rebalanceNum)).configSnapshot(); + verify(coordinator, times(rebalanceNum)).leaderState(any()); + verify(coordinator, times(2 * rebalanceNum)).generationId(); + verify(coordinator, times(rebalanceNum)).memberId(); + verify(coordinator, times(rebalanceNum)).lastCompletedGenerationId(); + } + + @Test + public void testTaskAssignmentWhenLeaderBounces() { + // Customize assignor for this test case + time = new MockTime(); + initAssignor(); + + when(coordinator.configSnapshot()).thenReturn(configState); + doReturn(Collections.EMPTY_MAP).when(assignor).serializeAssignments(assignmentsCapture.capture()); + + // First assignment with 3 workers and 2 connectors configured but not yet assigned + memberConfigs.put("worker2", new ExtendedWorkerState(leaderUrl, offset, null)); + memberConfigs.put("worker3", new ExtendedWorkerState(leaderUrl, offset, null)); + expectGeneration(); + assignor.performTaskAssignment(leader, offset, memberConfigs, coordinator, protocolVersion); + ++rebalanceNum; + returnedAssignments = assignmentsCapture.getValue(); + assertDelay(0, returnedAssignments); + expectedMemberConfigs = memberConfigs(leader, offset, returnedAssignments); + assertNoReassignments(memberConfigs, expectedMemberConfigs); + assertAssignment(2, 8, 0, 0, "worker1", "worker2", "worker3"); + + // Second assignment with two workers remaining in the group. The worker that left the + // group was the leader. 
The new leader has no previous assignments and is not tracking a + // delay upon a leader's exit + applyAssignments(returnedAssignments); + assignments.remove("worker1"); + leader = "worker2"; + leaderUrl = expectedLeaderUrl(leader); + memberConfigs = memberConfigs(leader, offset, assignments); + // The fact that the leader bounces means that the assignor starts from a clean slate + initAssignor(); + + // Capture needs to be reset to point to the new assignor + doReturn(Collections.EMPTY_MAP).when(assignor).serializeAssignments(assignmentsCapture.capture()); + expectGeneration(); + assignor.performTaskAssignment(leader, offset, memberConfigs, coordinator, protocolVersion); + ++rebalanceNum; + returnedAssignments = assignmentsCapture.getValue(); + assertDelay(0, returnedAssignments); + expectedMemberConfigs = memberConfigs(leader, offset, returnedAssignments); + assertNoReassignments(memberConfigs, expectedMemberConfigs); + assertAssignment(1, 3, 0, 0, "worker2", "worker3"); + + // Third assignment with the previous leader returning as a follower. In this case, the + // arrival of the previous leader is treated as an arrival of a new worker. Reassignment + // happens immediately, first with a revocation + applyAssignments(returnedAssignments); + memberConfigs = memberConfigs(leader, offset, assignments); + memberConfigs.put("worker1", new ExtendedWorkerState(leaderUrl, offset, null)); + expectGeneration(); + assignor.performTaskAssignment(leader, offset, memberConfigs, coordinator, protocolVersion); + ++rebalanceNum; + returnedAssignments = assignmentsCapture.getValue(); + expectedMemberConfigs = memberConfigs(leader, offset, returnedAssignments); + assertNoReassignments(memberConfigs, expectedMemberConfigs); + assertAssignment(0, 0, 0, 2, "worker1", "worker2", "worker3"); + + // Fourth assignment after revocations + applyAssignments(returnedAssignments); + memberConfigs = memberConfigs(leader, offset, assignments); + expectGeneration(); + assignor.performTaskAssignment(leader, offset, memberConfigs, coordinator, protocolVersion); + ++rebalanceNum; + returnedAssignments = assignmentsCapture.getValue(); + assertDelay(0, returnedAssignments); + expectedMemberConfigs = memberConfigs(leader, offset, returnedAssignments); + assertNoReassignments(memberConfigs, expectedMemberConfigs); + assertAssignment(0, 2, 0, 0, "worker1", "worker2", "worker3"); + + verify(coordinator, times(rebalanceNum)).configSnapshot(); + verify(coordinator, times(rebalanceNum)).leaderState(any()); + verify(coordinator, times(2 * rebalanceNum)).generationId(); + verify(coordinator, times(rebalanceNum)).memberId(); + verify(coordinator, times(rebalanceNum)).lastCompletedGenerationId(); + } + + @Test + public void testTaskAssignmentWhenFirstAssignmentAttemptFails() { + // Customize assignor for this test case + time = new MockTime(); + initAssignor(); + + when(coordinator.configSnapshot()).thenReturn(configState); + doThrow(new RuntimeException("Unable to send computed assignment with SyncGroupRequest")) + .when(assignor).serializeAssignments(assignmentsCapture.capture()); + + // First assignment with 2 workers and 2 connectors configured but not yet assigned + memberConfigs.put("worker2", new ExtendedWorkerState(leaderUrl, offset, null)); + try { + expectGeneration(); + assignor.performTaskAssignment(leader, offset, memberConfigs, coordinator, protocolVersion); + } catch (RuntimeException e) { + RequestFuture.failure(e); + } + ++rebalanceNum; + returnedAssignments = assignmentsCapture.getValue(); + expectedMemberConfigs = 
memberConfigs(leader, offset, returnedAssignments); + // This was the assignment that should have been sent, but didn't make it all the way + assertDelay(0, returnedAssignments); + assertNoReassignments(memberConfigs, expectedMemberConfigs); + assertAssignment(2, 8, 0, 0, "worker1", "worker2"); + + // Second assignment happens with members returning the same assignments (memberConfigs) + // as the first time. The assignor detects that the number of members did not change and + // avoids the rebalance delay, treating the lost assignments as new assignments. + doReturn(Collections.EMPTY_MAP).when(assignor).serializeAssignments(assignmentsCapture.capture()); + expectGeneration(); + assignor.performTaskAssignment(leader, offset, memberConfigs, coordinator, protocolVersion); + ++rebalanceNum; + returnedAssignments = assignmentsCapture.getValue(); + assertDelay(0, returnedAssignments); + expectedMemberConfigs = memberConfigs(leader, offset, returnedAssignments); + assertNoReassignments(memberConfigs, expectedMemberConfigs); + assertAssignment(2, 8, 0, 0, "worker1", "worker2"); + + verify(coordinator, times(rebalanceNum)).configSnapshot(); + verify(coordinator, times(rebalanceNum)).leaderState(any()); + verify(coordinator, times(2 * rebalanceNum)).generationId(); + verify(coordinator, times(rebalanceNum)).memberId(); + verify(coordinator, times(rebalanceNum)).lastCompletedGenerationId(); + } + + @Test + public void testTaskAssignmentWhenSubsequentAssignmentAttemptFails() { + // Customize assignor for this test case + time = new MockTime(); + initAssignor(); + + when(coordinator.configSnapshot()).thenReturn(configState); + doReturn(Collections.EMPTY_MAP).when(assignor).serializeAssignments(assignmentsCapture.capture()); + + // First assignment with 2 workers and 2 connectors configured but not yet assigned + memberConfigs.put("worker2", new ExtendedWorkerState(leaderUrl, offset, null)); + expectGeneration(); + assignor.performTaskAssignment(leader, offset, memberConfigs, coordinator, protocolVersion); + ++rebalanceNum; + returnedAssignments = assignmentsCapture.getValue(); + assertDelay(0, returnedAssignments); + expectedMemberConfigs = memberConfigs(leader, offset, returnedAssignments); + assertNoReassignments(memberConfigs, expectedMemberConfigs); + assertAssignment(2, 8, 0, 0, "worker1", "worker2"); + + when(coordinator.configSnapshot()).thenReturn(configState); + doThrow(new RuntimeException("Unable to send computed assignment with SyncGroupRequest")) + .when(assignor).serializeAssignments(assignmentsCapture.capture()); + + // Second assignment triggered by a third worker joining. The computed assignment should + // revoke tasks from the existing group. But the assignment won't be correctly delivered. 
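+ // serializeAssignments is stubbed above to throw, so the leader computes this assignment
+ // (assignmentsCapture still records it) but it is never serialized for the SyncGroup response.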
+ applyAssignments(returnedAssignments);
+ memberConfigs = memberConfigs(leader, offset, assignments);
+ memberConfigs.put("worker3", new ExtendedWorkerState(leaderUrl, offset, null));
+ try {
+ expectGeneration();
+ assignor.performTaskAssignment(leader, offset, memberConfigs, coordinator, protocolVersion);
+ } catch (RuntimeException e) {
+ RequestFuture.failure(e);
+ }
+ ++rebalanceNum;
+ returnedAssignments = assignmentsCapture.getValue();
+ expectedMemberConfigs = memberConfigs(leader, offset, returnedAssignments);
+ // This was the assignment that should have been sent, but didn't make it all the way
+ assertDelay(0, returnedAssignments);
+ assertNoReassignments(memberConfigs, expectedMemberConfigs);
+ assertAssignment(0, 0, 0, 2, "worker1", "worker2", "worker3");
+
+ // Third assignment happens with members returning the same assignments (memberConfigs)
+ // as the first time.
+ doReturn(Collections.EMPTY_MAP).when(assignor).serializeAssignments(assignmentsCapture.capture());
+ expectGeneration();
+ assignor.performTaskAssignment(leader, offset, memberConfigs, coordinator, protocolVersion);
+ ++rebalanceNum;
+ returnedAssignments = assignmentsCapture.getValue();
+ expectedMemberConfigs = memberConfigs(leader, offset, returnedAssignments);
+ assertDelay(0, returnedAssignments);
+ assertNoReassignments(memberConfigs, expectedMemberConfigs);
+ assertAssignment(0, 0, 0, 2, "worker1", "worker2", "worker3");
+
+ verify(coordinator, times(rebalanceNum)).configSnapshot();
+ verify(coordinator, times(rebalanceNum)).leaderState(any());
+ verify(coordinator, times(2 * rebalanceNum)).generationId();
+ verify(coordinator, times(rebalanceNum)).memberId();
+ verify(coordinator, times(rebalanceNum)).lastCompletedGenerationId();
+ }
+
+ @Test
+ public void testTaskAssignmentWhenSubsequentAssignmentAttemptFailsOutsideTheAssignor() {
+ // Customize assignor for this test case
+ time = new MockTime();
+ initAssignor();
+
+ expectGeneration();
+ when(coordinator.configSnapshot()).thenReturn(configState);
+ doReturn(Collections.EMPTY_MAP).when(assignor).serializeAssignments(assignmentsCapture.capture());
+
+ // First assignment with 2 workers and 2 connectors configured but not yet assigned
+ memberConfigs.put("worker2", new ExtendedWorkerState(leaderUrl, offset, null));
+ expectGeneration();
+ assignor.performTaskAssignment(leader, offset, memberConfigs, coordinator, protocolVersion);
+ ++rebalanceNum;
+ returnedAssignments = assignmentsCapture.getValue();
+ assertDelay(0, returnedAssignments);
+ expectedMemberConfigs = memberConfigs(leader, offset, returnedAssignments);
+ assertNoReassignments(memberConfigs, expectedMemberConfigs);
+ assertAssignment(2, 8, 0, 0, "worker1", "worker2");
+
+ // Second assignment triggered by a third worker joining. The computed assignment should
+ // revoke tasks from the existing group. But the assignment won't be correctly delivered
+ // and the sync group will fail on the leader worker.
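+ // In this variant the failure is simulated outside the assignor: the generation ids stubbed on
+ // the coordinator below report a last completed generation older than the assignor's previous one.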
+ applyAssignments(returnedAssignments); + memberConfigs = memberConfigs(leader, offset, assignments); + memberConfigs.put("worker3", new ExtendedWorkerState(leaderUrl, offset, null)); + when(coordinator.generationId()) + .thenReturn(assignor.previousGenerationId + 1) + .thenReturn(assignor.previousGenerationId + 1); + when(coordinator.lastCompletedGenerationId()).thenReturn(assignor.previousGenerationId - 1); + assignor.performTaskAssignment(leader, offset, memberConfigs, coordinator, protocolVersion); + ++rebalanceNum; + returnedAssignments = assignmentsCapture.getValue(); + expectedMemberConfigs = memberConfigs(leader, offset, returnedAssignments); + // This was the assignment that should have been sent, but didn't make it all the way + assertDelay(0, returnedAssignments); + assertNoReassignments(memberConfigs, expectedMemberConfigs); + assertAssignment(0, 0, 0, 2, "worker1", "worker2", "worker3"); + + // Third assignment happens with members returning the same assignments (memberConfigs) + // as the first time. + when(coordinator.lastCompletedGenerationId()).thenReturn(assignor.previousGenerationId - 1); + doReturn(Collections.EMPTY_MAP).when(assignor).serializeAssignments(assignmentsCapture.capture()); + expectGeneration(); + assignor.performTaskAssignment(leader, offset, memberConfigs, coordinator, protocolVersion); + ++rebalanceNum; + returnedAssignments = assignmentsCapture.getValue(); + expectedMemberConfigs = memberConfigs(leader, offset, returnedAssignments); + assertDelay(0, returnedAssignments); + assertNoReassignments(memberConfigs, expectedMemberConfigs); + assertAssignment(0, 0, 0, 2, "worker1", "worker2", "worker3"); + + verify(coordinator, times(rebalanceNum)).configSnapshot(); + verify(coordinator, times(rebalanceNum)).leaderState(any()); + verify(coordinator, times(2 * rebalanceNum)).generationId(); + verify(coordinator, times(rebalanceNum)).memberId(); + verify(coordinator, times(rebalanceNum)).lastCompletedGenerationId(); + } + + @Test + public void testTaskAssignmentWhenConnectorsAreDeleted() { + configState = clusterConfigState(offset, 3, 4); + when(coordinator.configSnapshot()).thenReturn(configState); + doReturn(Collections.EMPTY_MAP).when(assignor).serializeAssignments(assignmentsCapture.capture()); + + // First assignment with 1 worker and 2 connectors configured but not yet assigned + memberConfigs.put("worker2", new ExtendedWorkerState(leaderUrl, offset, null)); + expectGeneration(); + assignor.performTaskAssignment(leader, offset, memberConfigs, coordinator, protocolVersion); + ++rebalanceNum; + returnedAssignments = assignmentsCapture.getValue(); + assertDelay(0, returnedAssignments); + expectedMemberConfigs = memberConfigs(leader, offset, returnedAssignments); + assertNoReassignments(memberConfigs, expectedMemberConfigs); + assertAssignment(3, 12, 0, 0, "worker1", "worker2"); + + // Second assignment with an updated config state that reflects removal of a connector + configState = clusterConfigState(offset + 1, 2, 4); + when(coordinator.configSnapshot()).thenReturn(configState); + applyAssignments(returnedAssignments); + memberConfigs = memberConfigs(leader, offset, assignments); + expectGeneration(); + assignor.performTaskAssignment(leader, offset, memberConfigs, coordinator, protocolVersion); + ++rebalanceNum; + returnedAssignments = assignmentsCapture.getValue(); + assertDelay(0, returnedAssignments); + expectedMemberConfigs = memberConfigs(leader, offset, returnedAssignments); + assertNoReassignments(memberConfigs, expectedMemberConfigs); + 
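+ // assertAssignment arguments are: assigned connectors, assigned tasks, revoked connectors,
+ // revoked tasks, followed by the workers expected in the assignment.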
assertAssignment(0, 0, 1, 4, "worker1", "worker2"); + + verify(coordinator, times(rebalanceNum)).configSnapshot(); + verify(coordinator, times(rebalanceNum)).leaderState(any()); + verify(coordinator, times(2 * rebalanceNum)).generationId(); + verify(coordinator, times(rebalanceNum)).memberId(); + verify(coordinator, times(rebalanceNum)).lastCompletedGenerationId(); + } + + @Test + public void testAssignConnectorsWhenBalanced() { + int num = 2; + List existingAssignment = IntStream.range(0, 3) + .mapToObj(i -> workerLoad("worker" + i, i * num, num, i * num, num)) + .collect(Collectors.toList()); + + List expectedAssignment = existingAssignment.stream() + .map(wl -> new WorkerLoad.Builder(wl.worker()).withCopies(wl.connectors(), wl.tasks()).build()) + .collect(Collectors.toList()); + expectedAssignment.get(0).connectors().addAll(Arrays.asList("connector6", "connector9")); + expectedAssignment.get(1).connectors().addAll(Arrays.asList("connector7", "connector10")); + expectedAssignment.get(2).connectors().addAll(Arrays.asList("connector8")); + + List newConnectors = newConnectors(6, 11); + assignor.assignConnectors(existingAssignment, newConnectors); + assertEquals(expectedAssignment, existingAssignment); + } + + @Test + public void testAssignTasksWhenBalanced() { + int num = 2; + List existingAssignment = IntStream.range(0, 3) + .mapToObj(i -> workerLoad("worker" + i, i * num, num, i * num, num)) + .collect(Collectors.toList()); + + List expectedAssignment = existingAssignment.stream() + .map(wl -> new WorkerLoad.Builder(wl.worker()).withCopies(wl.connectors(), wl.tasks()).build()) + .collect(Collectors.toList()); + + expectedAssignment.get(0).connectors().addAll(Arrays.asList("connector6", "connector9")); + expectedAssignment.get(1).connectors().addAll(Arrays.asList("connector7", "connector10")); + expectedAssignment.get(2).connectors().addAll(Arrays.asList("connector8")); + + expectedAssignment.get(0).tasks().addAll(Arrays.asList(new ConnectorTaskId("task", 6), new ConnectorTaskId("task", 9))); + expectedAssignment.get(1).tasks().addAll(Arrays.asList(new ConnectorTaskId("task", 7), new ConnectorTaskId("task", 10))); + expectedAssignment.get(2).tasks().addAll(Arrays.asList(new ConnectorTaskId("task", 8))); + + List newConnectors = newConnectors(6, 11); + assignor.assignConnectors(existingAssignment, newConnectors); + List newTasks = newTasks(6, 11); + assignor.assignTasks(existingAssignment, newTasks); + assertEquals(expectedAssignment, existingAssignment); + } + + @Test + public void testAssignConnectorsWhenImbalanced() { + List existingAssignment = new ArrayList<>(); + existingAssignment.add(workerLoad("worker0", 0, 2, 0, 2)); + existingAssignment.add(workerLoad("worker1", 2, 3, 2, 3)); + existingAssignment.add(workerLoad("worker2", 5, 4, 5, 4)); + existingAssignment.add(emptyWorkerLoad("worker3")); + + List newConnectors = newConnectors(9, 24); + List newTasks = newTasks(9, 24); + assignor.assignConnectors(existingAssignment, newConnectors); + assignor.assignTasks(existingAssignment, newTasks); + for (WorkerLoad worker : existingAssignment) { + assertEquals(6, worker.connectorsSize()); + assertEquals(6, worker.tasksSize()); + } + } + + @Test + public void testLostAssignmentHandlingWhenWorkerBounces() { + // Customize assignor for this test case + time = new MockTime(); + initAssignor(); + + assertTrue(assignor.candidateWorkersForReassignment.isEmpty()); + assertEquals(0, assignor.scheduledRebalance); + assertEquals(0, assignor.delay); + + Map configuredAssignment = new HashMap<>(); + 
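+ // workerLoad(worker, connectorStart, connectorNum, taskStart, taskNum): each worker below
+ // starts out owning two connectors and four tasks taken from contiguous ranges.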
configuredAssignment.put("worker0", workerLoad("worker0", 0, 2, 0, 4)); + configuredAssignment.put("worker1", workerLoad("worker1", 2, 2, 4, 4)); + configuredAssignment.put("worker2", workerLoad("worker2", 4, 2, 8, 4)); + memberConfigs = memberConfigs(leader, offset, 0, 2); + + ConnectorsAndTasks newSubmissions = new ConnectorsAndTasks.Builder().build(); + + // No lost assignments + assignor.handleLostAssignments(new ConnectorsAndTasks.Builder().build(), + newSubmissions, + new ArrayList<>(configuredAssignment.values()), + memberConfigs); + + assertThat("Wrong set of workers for reassignments", + Collections.emptySet(), + is(assignor.candidateWorkersForReassignment)); + assertEquals(0, assignor.scheduledRebalance); + assertEquals(0, assignor.delay); + + assignor.previousMembers = new HashSet<>(memberConfigs.keySet()); + String flakyWorker = "worker1"; + WorkerLoad lostLoad = workerLoad(flakyWorker, 2, 2, 4, 4); + memberConfigs.remove(flakyWorker); + + ConnectorsAndTasks lostAssignments = new ConnectorsAndTasks.Builder() + .withCopies(lostLoad.connectors(), lostLoad.tasks()).build(); + + // Lost assignments detected - No candidate worker has appeared yet (worker with no assignments) + assignor.handleLostAssignments(lostAssignments, newSubmissions, + new ArrayList<>(configuredAssignment.values()), memberConfigs); + + assertThat("Wrong set of workers for reassignments", + Collections.emptySet(), + is(assignor.candidateWorkersForReassignment)); + assertEquals(time.milliseconds() + rebalanceDelay, assignor.scheduledRebalance); + assertEquals(rebalanceDelay, assignor.delay); + + assignor.previousMembers = new HashSet<>(memberConfigs.keySet()); + time.sleep(rebalanceDelay / 2); + rebalanceDelay /= 2; + + // A new worker (probably returning worker) has joined + configuredAssignment.put(flakyWorker, new WorkerLoad.Builder(flakyWorker).build()); + memberConfigs.put(flakyWorker, new ExtendedWorkerState(leaderUrl, offset, null)); + assignor.handleLostAssignments(lostAssignments, newSubmissions, + new ArrayList<>(configuredAssignment.values()), memberConfigs); + + assertThat("Wrong set of workers for reassignments", + Collections.singleton(flakyWorker), + is(assignor.candidateWorkersForReassignment)); + assertEquals(time.milliseconds() + rebalanceDelay, assignor.scheduledRebalance); + assertEquals(rebalanceDelay, assignor.delay); + + assignor.previousMembers = new HashSet<>(memberConfigs.keySet()); + time.sleep(rebalanceDelay); + + // The new worker has still no assignments + assignor.handleLostAssignments(lostAssignments, newSubmissions, + new ArrayList<>(configuredAssignment.values()), memberConfigs); + + assertTrue("Wrong assignment of lost connectors", + configuredAssignment.getOrDefault(flakyWorker, new WorkerLoad.Builder(flakyWorker).build()) + .connectors() + .containsAll(lostAssignments.connectors())); + assertTrue("Wrong assignment of lost tasks", + configuredAssignment.getOrDefault(flakyWorker, new WorkerLoad.Builder(flakyWorker).build()) + .tasks() + .containsAll(lostAssignments.tasks())); + assertThat("Wrong set of workers for reassignments", + Collections.emptySet(), + is(assignor.candidateWorkersForReassignment)); + assertEquals(0, assignor.scheduledRebalance); + assertEquals(0, assignor.delay); + } + + @Test + public void testLostAssignmentHandlingWhenWorkerLeavesPermanently() { + // Customize assignor for this test case + time = new MockTime(); + initAssignor(); + + assertTrue(assignor.candidateWorkersForReassignment.isEmpty()); + assertEquals(0, assignor.scheduledRebalance); + 
assertEquals(0, assignor.delay); + + Map configuredAssignment = new HashMap<>(); + configuredAssignment.put("worker0", workerLoad("worker0", 0, 2, 0, 4)); + configuredAssignment.put("worker1", workerLoad("worker1", 2, 2, 4, 4)); + configuredAssignment.put("worker2", workerLoad("worker2", 4, 2, 8, 4)); + memberConfigs = memberConfigs(leader, offset, 0, 2); + + ConnectorsAndTasks newSubmissions = new ConnectorsAndTasks.Builder().build(); + + // No lost assignments + assignor.handleLostAssignments(new ConnectorsAndTasks.Builder().build(), + newSubmissions, + new ArrayList<>(configuredAssignment.values()), + memberConfigs); + + assertThat("Wrong set of workers for reassignments", + Collections.emptySet(), + is(assignor.candidateWorkersForReassignment)); + assertEquals(0, assignor.scheduledRebalance); + assertEquals(0, assignor.delay); + + assignor.previousMembers = new HashSet<>(memberConfigs.keySet()); + String removedWorker = "worker1"; + WorkerLoad lostLoad = workerLoad(removedWorker, 2, 2, 4, 4); + memberConfigs.remove(removedWorker); + + ConnectorsAndTasks lostAssignments = new ConnectorsAndTasks.Builder() + .withCopies(lostLoad.connectors(), lostLoad.tasks()).build(); + + // Lost assignments detected - No candidate worker has appeared yet (worker with no assignments) + assignor.handleLostAssignments(lostAssignments, newSubmissions, + new ArrayList<>(configuredAssignment.values()), memberConfigs); + + assertThat("Wrong set of workers for reassignments", + Collections.emptySet(), + is(assignor.candidateWorkersForReassignment)); + assertEquals(time.milliseconds() + rebalanceDelay, assignor.scheduledRebalance); + assertEquals(rebalanceDelay, assignor.delay); + + assignor.previousMembers = new HashSet<>(memberConfigs.keySet()); + time.sleep(rebalanceDelay / 2); + rebalanceDelay /= 2; + + // No new worker has joined + assignor.handleLostAssignments(lostAssignments, newSubmissions, + new ArrayList<>(configuredAssignment.values()), memberConfigs); + + assertThat("Wrong set of workers for reassignments", + Collections.emptySet(), + is(assignor.candidateWorkersForReassignment)); + assertEquals(time.milliseconds() + rebalanceDelay, assignor.scheduledRebalance); + assertEquals(rebalanceDelay, assignor.delay); + + time.sleep(rebalanceDelay); + + assignor.handleLostAssignments(lostAssignments, newSubmissions, + new ArrayList<>(configuredAssignment.values()), memberConfigs); + + assertTrue("Wrong assignment of lost connectors", + newSubmissions.connectors().containsAll(lostAssignments.connectors())); + assertTrue("Wrong assignment of lost tasks", + newSubmissions.tasks().containsAll(lostAssignments.tasks())); + assertThat("Wrong set of workers for reassignments", + Collections.emptySet(), + is(assignor.candidateWorkersForReassignment)); + assertEquals(0, assignor.scheduledRebalance); + assertEquals(0, assignor.delay); + } + + @Test + public void testLostAssignmentHandlingWithMoreThanOneCandidates() { + // Customize assignor for this test case + time = new MockTime(); + initAssignor(); + + assertTrue(assignor.candidateWorkersForReassignment.isEmpty()); + assertEquals(0, assignor.scheduledRebalance); + assertEquals(0, assignor.delay); + + Map configuredAssignment = new HashMap<>(); + configuredAssignment.put("worker0", workerLoad("worker0", 0, 2, 0, 4)); + configuredAssignment.put("worker1", workerLoad("worker1", 2, 2, 4, 4)); + configuredAssignment.put("worker2", workerLoad("worker2", 4, 2, 8, 4)); + memberConfigs = memberConfigs(leader, offset, 0, 2); + + ConnectorsAndTasks newSubmissions = new 
ConnectorsAndTasks.Builder().build(); + + // No lost assignments + assignor.handleLostAssignments(new ConnectorsAndTasks.Builder().build(), + newSubmissions, + new ArrayList<>(configuredAssignment.values()), + memberConfigs); + + assertThat("Wrong set of workers for reassignments", + Collections.emptySet(), + is(assignor.candidateWorkersForReassignment)); + assertEquals(0, assignor.scheduledRebalance); + assertEquals(0, assignor.delay); + + assignor.previousMembers = new HashSet<>(memberConfigs.keySet()); + String flakyWorker = "worker1"; + WorkerLoad lostLoad = workerLoad(flakyWorker, 2, 2, 4, 4); + memberConfigs.remove(flakyWorker); + String newWorker = "worker3"; + + ConnectorsAndTasks lostAssignments = new ConnectorsAndTasks.Builder() + .withCopies(lostLoad.connectors(), lostLoad.tasks()).build(); + + // Lost assignments detected - A new worker also has joined that is not the returning worker + configuredAssignment.put(newWorker, new WorkerLoad.Builder(newWorker).build()); + memberConfigs.put(newWorker, new ExtendedWorkerState(leaderUrl, offset, null)); + assignor.handleLostAssignments(lostAssignments, newSubmissions, + new ArrayList<>(configuredAssignment.values()), memberConfigs); + + assertThat("Wrong set of workers for reassignments", + Collections.singleton(newWorker), + is(assignor.candidateWorkersForReassignment)); + assertEquals(time.milliseconds() + rebalanceDelay, assignor.scheduledRebalance); + assertEquals(rebalanceDelay, assignor.delay); + + assignor.previousMembers = new HashSet<>(memberConfigs.keySet()); + time.sleep(rebalanceDelay / 2); + rebalanceDelay /= 2; + + // Now two new workers have joined + configuredAssignment.put(flakyWorker, new WorkerLoad.Builder(flakyWorker).build()); + memberConfigs.put(flakyWorker, new ExtendedWorkerState(leaderUrl, offset, null)); + assignor.handleLostAssignments(lostAssignments, newSubmissions, + new ArrayList<>(configuredAssignment.values()), memberConfigs); + + Set expectedWorkers = new HashSet<>(); + expectedWorkers.addAll(Arrays.asList(newWorker, flakyWorker)); + assertThat("Wrong set of workers for reassignments", + expectedWorkers, + is(assignor.candidateWorkersForReassignment)); + assertEquals(time.milliseconds() + rebalanceDelay, assignor.scheduledRebalance); + assertEquals(rebalanceDelay, assignor.delay); + + assignor.previousMembers = new HashSet<>(memberConfigs.keySet()); + time.sleep(rebalanceDelay); + + // The new workers have new assignments, other than the lost ones + configuredAssignment.put(flakyWorker, workerLoad(flakyWorker, 6, 2, 8, 4)); + configuredAssignment.put(newWorker, workerLoad(newWorker, 8, 2, 12, 4)); + // we don't reflect these new assignments in memberConfigs currently because they are not + // used in handleLostAssignments method + assignor.handleLostAssignments(lostAssignments, newSubmissions, + new ArrayList<>(configuredAssignment.values()), memberConfigs); + + // both the newWorkers would need to be considered for re assignment of connectors and tasks + List listOfConnectorsInLast2Workers = new ArrayList<>(); + listOfConnectorsInLast2Workers.addAll(configuredAssignment.getOrDefault(newWorker, new WorkerLoad.Builder(flakyWorker).build()) + .connectors()); + listOfConnectorsInLast2Workers.addAll(configuredAssignment.getOrDefault(flakyWorker, new WorkerLoad.Builder(flakyWorker).build()) + .connectors()); + List listOfTasksInLast2Workers = new ArrayList<>(); + listOfTasksInLast2Workers.addAll(configuredAssignment.getOrDefault(newWorker, new WorkerLoad.Builder(flakyWorker).build()) + .tasks()); + 
listOfTasksInLast2Workers.addAll(configuredAssignment.getOrDefault(flakyWorker, new WorkerLoad.Builder(flakyWorker).build()) + .tasks()); + assertTrue("Wrong assignment of lost connectors", + listOfConnectorsInLast2Workers.containsAll(lostAssignments.connectors())); + assertTrue("Wrong assignment of lost tasks", + listOfTasksInLast2Workers.containsAll(lostAssignments.tasks())); + assertThat("Wrong set of workers for reassignments", + Collections.emptySet(), + is(assignor.candidateWorkersForReassignment)); + assertEquals(0, assignor.scheduledRebalance); + assertEquals(0, assignor.delay); + } + + @Test + public void testLostAssignmentHandlingWhenWorkerBouncesBackButFinallyLeaves() { + // Customize assignor for this test case + time = new MockTime(); + initAssignor(); + + assertTrue(assignor.candidateWorkersForReassignment.isEmpty()); + assertEquals(0, assignor.scheduledRebalance); + assertEquals(0, assignor.delay); + + Map configuredAssignment = new HashMap<>(); + configuredAssignment.put("worker0", workerLoad("worker0", 0, 2, 0, 4)); + configuredAssignment.put("worker1", workerLoad("worker1", 2, 2, 4, 4)); + configuredAssignment.put("worker2", workerLoad("worker2", 4, 2, 8, 4)); + memberConfigs = memberConfigs(leader, offset, 0, 2); + + ConnectorsAndTasks newSubmissions = new ConnectorsAndTasks.Builder().build(); + + // No lost assignments + assignor.handleLostAssignments(new ConnectorsAndTasks.Builder().build(), + newSubmissions, + new ArrayList<>(configuredAssignment.values()), + memberConfigs); + + assertThat("Wrong set of workers for reassignments", + Collections.emptySet(), + is(assignor.candidateWorkersForReassignment)); + assertEquals(0, assignor.scheduledRebalance); + assertEquals(0, assignor.delay); + + assignor.previousMembers = new HashSet<>(memberConfigs.keySet()); + String veryFlakyWorker = "worker1"; + WorkerLoad lostLoad = workerLoad(veryFlakyWorker, 2, 2, 4, 4); + memberConfigs.remove(veryFlakyWorker); + + ConnectorsAndTasks lostAssignments = new ConnectorsAndTasks.Builder() + .withCopies(lostLoad.connectors(), lostLoad.tasks()).build(); + + // Lost assignments detected - No candidate worker has appeared yet (worker with no assignments) + assignor.handleLostAssignments(lostAssignments, newSubmissions, + new ArrayList<>(configuredAssignment.values()), memberConfigs); + + assertThat("Wrong set of workers for reassignments", + Collections.emptySet(), + is(assignor.candidateWorkersForReassignment)); + assertEquals(time.milliseconds() + rebalanceDelay, assignor.scheduledRebalance); + assertEquals(rebalanceDelay, assignor.delay); + + assignor.previousMembers = new HashSet<>(memberConfigs.keySet()); + time.sleep(rebalanceDelay / 2); + rebalanceDelay /= 2; + + // A new worker (probably returning worker) has joined + configuredAssignment.put(veryFlakyWorker, new WorkerLoad.Builder(veryFlakyWorker).build()); + memberConfigs.put(veryFlakyWorker, new ExtendedWorkerState(leaderUrl, offset, null)); + assignor.handleLostAssignments(lostAssignments, newSubmissions, + new ArrayList<>(configuredAssignment.values()), memberConfigs); + + assertThat("Wrong set of workers for reassignments", + Collections.singleton(veryFlakyWorker), + is(assignor.candidateWorkersForReassignment)); + assertEquals(time.milliseconds() + rebalanceDelay, assignor.scheduledRebalance); + assertEquals(rebalanceDelay, assignor.delay); + + assignor.previousMembers = new HashSet<>(memberConfigs.keySet()); + time.sleep(rebalanceDelay); + + // The returning worker leaves permanently after joining briefly during the delay + 
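+ // With the candidate worker gone when the delay expires, the lost connectors and tasks are
+ // expected to be handed back through newSubmissions for regular assignment, as asserted below.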
configuredAssignment.remove(veryFlakyWorker); + memberConfigs.remove(veryFlakyWorker); + assignor.handleLostAssignments(lostAssignments, newSubmissions, + new ArrayList<>(configuredAssignment.values()), memberConfigs); + + assertTrue("Wrong assignment of lost connectors", + newSubmissions.connectors().containsAll(lostAssignments.connectors())); + assertTrue("Wrong assignment of lost tasks", + newSubmissions.tasks().containsAll(lostAssignments.tasks())); + assertThat("Wrong set of workers for reassignments", + Collections.emptySet(), + is(assignor.candidateWorkersForReassignment)); + assertEquals(0, assignor.scheduledRebalance); + assertEquals(0, assignor.delay); + } + + @Test + public void testTaskAssignmentWhenTasksDuplicatedInWorkerAssignment() { + when(coordinator.configSnapshot()).thenReturn(configState); + doReturn(Collections.EMPTY_MAP).when(assignor).serializeAssignments(assignmentsCapture.capture()); + + // First assignment with 1 worker and 2 connectors configured but not yet assigned + assignor.performTaskAssignment(leader, offset, memberConfigs, coordinator, protocolVersion); + ++rebalanceNum; + returnedAssignments = assignmentsCapture.getValue(); + assertDelay(0, returnedAssignments); + expectedMemberConfigs = memberConfigs(leader, offset, returnedAssignments); + assertNoReassignments(memberConfigs, expectedMemberConfigs); + assertAssignment(2, 8, 0, 0, "worker1"); + + // Second assignment with a second worker with duplicate assignment joining and all connectors running on previous worker + applyAssignments(returnedAssignments); + memberConfigs = memberConfigs(leader, offset, assignments); + ExtendedAssignment duplicatedWorkerAssignment = newExpandableAssignment(); + duplicatedWorkerAssignment.connectors().addAll(newConnectors(1, 2)); + duplicatedWorkerAssignment.tasks().addAll(newTasks("connector1", 0, 4)); + memberConfigs.put("worker2", new ExtendedWorkerState(leaderUrl, offset, duplicatedWorkerAssignment)); + assignor.performTaskAssignment(leader, offset, memberConfigs, coordinator, protocolVersion); + ++rebalanceNum; + returnedAssignments = assignmentsCapture.getValue(); + assertDelay(0, returnedAssignments); + expectedMemberConfigs = memberConfigs(leader, offset, returnedAssignments); + assertNoReassignments(memberConfigs, expectedMemberConfigs); + assertAssignment(0, 0, 2, 8, "worker1", "worker2"); + + // Third assignment after revocations + applyAssignments(returnedAssignments); + memberConfigs = memberConfigs(leader, offset, assignments); + assignor.performTaskAssignment(leader, offset, memberConfigs, coordinator, protocolVersion); + ++rebalanceNum; + returnedAssignments = assignmentsCapture.getValue(); + assertDelay(0, returnedAssignments); + expectedMemberConfigs = memberConfigs(leader, offset, returnedAssignments); + assertNoReassignments(memberConfigs, expectedMemberConfigs); + assertAssignment(1, 4, 0, 2, "worker1", "worker2"); + + // fourth rebalance after revocations + applyAssignments(returnedAssignments); + memberConfigs = memberConfigs(leader, offset, assignments); + assignor.performTaskAssignment(leader, offset, memberConfigs, coordinator, protocolVersion); + ++rebalanceNum; + returnedAssignments = assignmentsCapture.getValue(); + assertDelay(0, returnedAssignments); + expectedMemberConfigs = memberConfigs(leader, offset, returnedAssignments); + assertNoReassignments(memberConfigs, expectedMemberConfigs); + assertAssignment(0, 2, 0, 0, "worker1", "worker2"); + + // Fifth rebalance should not change assignments + applyAssignments(returnedAssignments); + 
memberConfigs = memberConfigs(leader, offset, assignments); + assignor.performTaskAssignment(leader, offset, memberConfigs, coordinator, protocolVersion); + ++rebalanceNum; + returnedAssignments = assignmentsCapture.getValue(); + assertDelay(0, returnedAssignments); + expectedMemberConfigs = memberConfigs(leader, offset, returnedAssignments); + assertNoReassignments(memberConfigs, expectedMemberConfigs); + assertAssignment(0, 0, 0, 0, "worker1", "worker2"); + + verify(coordinator, times(rebalanceNum)).configSnapshot(); + verify(coordinator, times(rebalanceNum)).leaderState(any()); + verify(coordinator, times(2 * rebalanceNum)).generationId(); + verify(coordinator, times(rebalanceNum)).memberId(); + verify(coordinator, times(rebalanceNum)).lastCompletedGenerationId(); + } + + @Test + public void testDuplicatedAssignmentHandleWhenTheDuplicatedAssignmentsDeleted() { + when(coordinator.configSnapshot()).thenReturn(configState); + doReturn(Collections.EMPTY_MAP).when(assignor).serializeAssignments(assignmentsCapture.capture()); + + // First assignment with 1 worker and 2 connectors configured but not yet assigned + assignor.performTaskAssignment(leader, offset, memberConfigs, coordinator, protocolVersion); + ++rebalanceNum; + returnedAssignments = assignmentsCapture.getValue(); + assertDelay(0, returnedAssignments); + expectedMemberConfigs = memberConfigs(leader, offset, returnedAssignments); + assertNoReassignments(memberConfigs, expectedMemberConfigs); + assertAssignment(2, 8, 0, 0, "worker1"); + + //delete connector1 + configState = clusterConfigState(offset, 2, 1, 4); + when(coordinator.configSnapshot()).thenReturn(configState); + + // Second assignment with a second worker with duplicate assignment joining and the duplicated assignment is deleted at the same time + applyAssignments(returnedAssignments); + memberConfigs = memberConfigs(leader, offset, assignments); + ExtendedAssignment duplicatedWorkerAssignment = newExpandableAssignment(); + duplicatedWorkerAssignment.connectors().addAll(newConnectors(1, 2)); + duplicatedWorkerAssignment.tasks().addAll(newTasks("connector1", 0, 4)); + memberConfigs.put("worker2", new ExtendedWorkerState(leaderUrl, offset, duplicatedWorkerAssignment)); + assignor.performTaskAssignment(leader, offset, memberConfigs, coordinator, protocolVersion); + ++rebalanceNum; + returnedAssignments = assignmentsCapture.getValue(); + assertDelay(0, returnedAssignments); + expectedMemberConfigs = memberConfigs(leader, offset, returnedAssignments); + assertNoReassignments(memberConfigs, expectedMemberConfigs); + assertAssignment(0, 0, 2, 8, "worker1", "worker2"); + + // Third assignment after revocations + applyAssignments(returnedAssignments); + memberConfigs = memberConfigs(leader, offset, assignments); + assignor.performTaskAssignment(leader, offset, memberConfigs, coordinator, protocolVersion); + ++rebalanceNum; + returnedAssignments = assignmentsCapture.getValue(); + assertDelay(0, returnedAssignments); + expectedMemberConfigs = memberConfigs(leader, offset, returnedAssignments); + assertNoReassignments(memberConfigs, expectedMemberConfigs); + assertAssignment(0, 0, 0, 2, "worker1", "worker2"); + + // fourth rebalance after revocations + applyAssignments(returnedAssignments); + memberConfigs = memberConfigs(leader, offset, assignments); + assignor.performTaskAssignment(leader, offset, memberConfigs, coordinator, protocolVersion); + ++rebalanceNum; + returnedAssignments = assignmentsCapture.getValue(); + assertDelay(0, returnedAssignments); + expectedMemberConfigs = 
memberConfigs(leader, offset, returnedAssignments); + assertNoReassignments(memberConfigs, expectedMemberConfigs); + assertAssignment(0, 2, 0, 0, "worker1", "worker2"); + + // Fifth rebalance should not change assignments + applyAssignments(returnedAssignments); + memberConfigs = memberConfigs(leader, offset, assignments); + assignor.performTaskAssignment(leader, offset, memberConfigs, coordinator, protocolVersion); + ++rebalanceNum; + returnedAssignments = assignmentsCapture.getValue(); + assertDelay(0, returnedAssignments); + expectedMemberConfigs = memberConfigs(leader, offset, returnedAssignments); + assertNoReassignments(memberConfigs, expectedMemberConfigs); + assertAssignment(0, 0, 0, 0, "worker1", "worker2"); + + verify(coordinator, times(rebalanceNum)).configSnapshot(); + verify(coordinator, times(rebalanceNum)).leaderState(any()); + verify(coordinator, times(2 * rebalanceNum)).generationId(); + verify(coordinator, times(rebalanceNum)).memberId(); + verify(coordinator, times(rebalanceNum)).lastCompletedGenerationId(); + } + + private WorkerLoad emptyWorkerLoad(String worker) { + return new WorkerLoad.Builder(worker).build(); + } + + private WorkerLoad workerLoad(String worker, int connectorStart, int connectorNum, + int taskStart, int taskNum) { + return new WorkerLoad.Builder(worker).with( + newConnectors(connectorStart, connectorStart + connectorNum), + newTasks(taskStart, taskStart + taskNum)).build(); + } + + private static List newConnectors(int start, int end) { + return IntStream.range(start, end) + .mapToObj(i -> "connector" + i) + .collect(Collectors.toList()); + } + + private static List newTasks(int start, int end) { + return newTasks("task", start, end); + } + + private static List newTasks(String connectorName, int start, int end) { + return IntStream.range(start, end) + .mapToObj(i -> new ConnectorTaskId(connectorName, i)) + .collect(Collectors.toList()); + } + + private static ClusterConfigState clusterConfigState(long offset, + int connectorNum, + int taskNum) { + return clusterConfigState(offset, 1, connectorNum, taskNum); + } + + private static ClusterConfigState clusterConfigState(long offset, + int connectorStart, + int connectorNum, + int taskNum) { + int connectorNumEnd = connectorStart + connectorNum - 1; + return new ClusterConfigState( + offset, + null, + connectorTaskCounts(connectorStart, connectorNumEnd, taskNum), + connectorConfigs(connectorStart, connectorNumEnd), + connectorTargetStates(connectorStart, connectorNumEnd, TargetState.STARTED), + taskConfigs(0, connectorNum, connectorNum * taskNum), + Collections.emptySet()); + } + + private static Map memberConfigs(String givenLeader, + long givenOffset, + Map givenAssignments) { + return givenAssignments.entrySet().stream() + .collect(Collectors.toMap( + Map.Entry::getKey, + e -> new ExtendedWorkerState(expectedLeaderUrl(givenLeader), givenOffset, e.getValue()))); + } + + private static Map memberConfigs(String givenLeader, + long givenOffset, + int start, + int connectorNum) { + return IntStream.range(start, connectorNum + 1) + .mapToObj(i -> new SimpleEntry<>("worker" + i, new ExtendedWorkerState(expectedLeaderUrl(givenLeader), givenOffset, null))) + .collect(Collectors.toMap(SimpleEntry::getKey, SimpleEntry::getValue)); + } + + private static Map connectorTaskCounts(int start, + int connectorNum, + int taskCounts) { + return IntStream.range(start, connectorNum + 1) + .mapToObj(i -> new SimpleEntry<>("connector" + i, taskCounts)) + .collect(Collectors.toMap(SimpleEntry::getKey, 
SimpleEntry::getValue)); + } + + private static Map> connectorConfigs(int start, int connectorNum) { + return IntStream.range(start, connectorNum + 1) + .mapToObj(i -> new SimpleEntry<>("connector" + i, new HashMap())) + .collect(Collectors.toMap(SimpleEntry::getKey, SimpleEntry::getValue)); + } + + private static Map connectorTargetStates(int start, + int connectorNum, + TargetState state) { + return IntStream.range(start, connectorNum + 1) + .mapToObj(i -> new SimpleEntry<>("connector" + i, state)) + .collect(Collectors.toMap(SimpleEntry::getKey, SimpleEntry::getValue)); + } + + private static Map> taskConfigs(int start, + int connectorNum, + int taskNum) { + return IntStream.range(start, taskNum + 1) + .mapToObj(i -> new SimpleEntry<>( + new ConnectorTaskId("connector" + i / connectorNum + 1, i), + new HashMap()) + ).collect(Collectors.toMap(SimpleEntry::getKey, SimpleEntry::getValue)); + } + + private void applyAssignments(Map newAssignments) { + newAssignments.forEach((k, v) -> { + assignments.computeIfAbsent(k, noop -> newExpandableAssignment()) + .connectors() + .removeAll(v.revokedConnectors()); + assignments.computeIfAbsent(k, noop -> newExpandableAssignment()) + .connectors() + .addAll(v.connectors()); + assignments.computeIfAbsent(k, noop -> newExpandableAssignment()) + .tasks() + .removeAll(v.revokedTasks()); + assignments.computeIfAbsent(k, noop -> newExpandableAssignment()) + .tasks() + .addAll(v.tasks()); + }); + } + + private ExtendedAssignment newExpandableAssignment() { + return new ExtendedAssignment( + protocolVersion, + ConnectProtocol.Assignment.NO_ERROR, + leader, + leaderUrl, + offset, + new ArrayList<>(), + new ArrayList<>(), + new ArrayList<>(), + new ArrayList<>(), + 0); + } + + private static String expectedLeaderUrl(String givenLeader) { + return "http://" + givenLeader + ":8083"; + } + + private void assertAssignment(int connectorNum, int taskNum, + int revokedConnectorNum, int revokedTaskNum, + String... workers) { + assertAssignment(leader, connectorNum, taskNum, revokedConnectorNum, revokedTaskNum, workers); + } + + private void assertAssignment(String expectedLeader, int connectorNum, int taskNum, + int revokedConnectorNum, int revokedTaskNum, + String... 
workers) {
+ assertThat("Wrong number of workers",
+ expectedMemberConfigs.keySet().size(),
+ is(workers.length));
+ assertThat("Wrong set of workers",
+ new ArrayList<>(expectedMemberConfigs.keySet()), hasItems(workers));
+ assertThat("Wrong number of assigned connectors",
+ expectedMemberConfigs.values().stream().map(v -> v.assignment().connectors().size()).reduce(0, Integer::sum),
+ is(connectorNum));
+ assertThat("Wrong number of assigned tasks",
+ expectedMemberConfigs.values().stream().map(v -> v.assignment().tasks().size()).reduce(0, Integer::sum),
+ is(taskNum));
+ assertThat("Wrong number of revoked connectors",
+ expectedMemberConfigs.values().stream().map(v -> v.assignment().revokedConnectors().size()).reduce(0, Integer::sum),
+ is(revokedConnectorNum));
+ assertThat("Wrong number of revoked tasks",
+ expectedMemberConfigs.values().stream().map(v -> v.assignment().revokedTasks().size()).reduce(0, Integer::sum),
+ is(revokedTaskNum));
+ assertThat("Wrong leader in assignments",
+ expectedMemberConfigs.values().stream().map(v -> v.assignment().leader()).distinct().collect(Collectors.joining(", ")),
+ is(expectedLeader));
+ assertThat("Wrong leaderUrl in assignments",
+ expectedMemberConfigs.values().stream().map(v -> v.assignment().leaderUrl()).distinct().collect(Collectors.joining(", ")),
+ is(expectedLeaderUrl(expectedLeader)));
+ }
+
+ private void assertDelay(int expectedDelay, Map<String, ExtendedAssignment> newAssignments) {
+ newAssignments.values().stream()
+ .forEach(a -> assertEquals(
+ "Wrong rebalance delay in " + a, expectedDelay, a.delay()));
+ }
+
+ private void assertNoReassignments(Map<String, ExtendedWorkerState> existingAssignments,
+ Map<String, ExtendedWorkerState> newAssignments) {
+ assertNoDuplicateInAssignment(existingAssignments);
+ assertNoDuplicateInAssignment(newAssignments);
+
+ List<String> existingConnectors = existingAssignments.values().stream()
+ .flatMap(a -> a.assignment().connectors().stream())
+ .collect(Collectors.toList());
+ List<String> newConnectors = newAssignments.values().stream()
+ .flatMap(a -> a.assignment().connectors().stream())
+ .collect(Collectors.toList());
+
+ List<ConnectorTaskId> existingTasks = existingAssignments.values().stream()
+ .flatMap(a -> a.assignment().tasks().stream())
+ .collect(Collectors.toList());
+
+ List<ConnectorTaskId> newTasks = newAssignments.values().stream()
+ .flatMap(a -> a.assignment().tasks().stream())
+ .collect(Collectors.toList());
+
+ existingConnectors.retainAll(newConnectors);
+ assertThat("Found connectors in new assignment that already exist in current assignment",
+ Collections.emptyList(),
+ is(existingConnectors));
+ existingTasks.retainAll(newTasks);
+ assertThat("Found tasks in new assignment that already exist in current assignment",
+ Collections.emptyList(),
+ is(existingTasks));
+ }
+
+ private void assertNoDuplicateInAssignment(Map<String, ExtendedWorkerState> existingAssignment) {
+ List<String> existingConnectors = existingAssignment.values().stream()
+ .flatMap(a -> a.assignment().connectors().stream())
+ .collect(Collectors.toList());
+ Set<String> existingUniqueConnectors = new HashSet<>(existingConnectors);
+ existingConnectors.removeAll(existingUniqueConnectors);
+ assertThat("Connectors should be unique in assignments but duplicates were found",
+ Collections.emptyList(),
+ is(existingConnectors));
+
+ List<ConnectorTaskId> existingTasks = existingAssignment.values().stream()
+ .flatMap(a -> a.assignment().tasks().stream())
+ .collect(Collectors.toList());
+ Set<ConnectorTaskId> existingUniqueTasks = new HashSet<>(existingTasks);
+ existingTasks.removeAll(existingUniqueTasks);
+ assertThat("Tasks should be unique in assignments but duplicates were found",
Collections.emptyList(), + is(existingTasks)); + } + + private void expectGeneration() { + when(coordinator.generationId()) + .thenReturn(assignor.previousGenerationId + 1) + .thenReturn(assignor.previousGenerationId + 1); + when(coordinator.lastCompletedGenerationId()).thenReturn(assignor.previousGenerationId); + } +} diff --git a/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/distributed/WorkerCoordinatorIncrementalTest.java b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/distributed/WorkerCoordinatorIncrementalTest.java new file mode 100644 index 0000000..8b0f57e --- /dev/null +++ b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/distributed/WorkerCoordinatorIncrementalTest.java @@ -0,0 +1,584 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.runtime.distributed; + +import org.apache.kafka.clients.GroupRebalanceConfig; +import org.apache.kafka.clients.Metadata; +import org.apache.kafka.clients.MockClient; +import org.apache.kafka.clients.consumer.internals.ConsumerNetworkClient; +import org.apache.kafka.common.Node; +import org.apache.kafka.common.internals.ClusterResourceListeners; +import org.apache.kafka.common.metrics.Metrics; +import org.apache.kafka.common.requests.RequestTestUtils; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.connect.storage.KafkaConfigBackingStore; +import org.apache.kafka.connect.util.ConnectorTaskId; +import org.junit.After; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnit; +import org.mockito.junit.MockitoRule; + +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +import static org.apache.kafka.common.message.JoinGroupRequestData.JoinGroupRequestProtocol; +import static org.apache.kafka.common.message.JoinGroupRequestData.JoinGroupRequestProtocolCollection; +import static org.apache.kafka.common.message.JoinGroupResponseData.JoinGroupResponseMember; +import static org.apache.kafka.connect.runtime.WorkerTestUtils.assertAssignment; +import static org.apache.kafka.connect.runtime.WorkerTestUtils.clusterConfigState; +import static org.apache.kafka.connect.runtime.distributed.ConnectProtocol.WorkerState; +import static org.apache.kafka.connect.runtime.distributed.ConnectProtocolCompatibility.COMPATIBLE; +import static org.apache.kafka.connect.runtime.distributed.ConnectProtocolCompatibility.EAGER; 
+import static org.apache.kafka.connect.runtime.distributed.ConnectProtocolCompatibility.SESSIONED; +import static org.apache.kafka.connect.runtime.distributed.IncrementalCooperativeConnectProtocol.CONNECT_PROTOCOL_V1; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotEquals; +import static org.junit.Assert.assertTrue; +import static org.junit.runners.Parameterized.Parameter; +import static org.junit.runners.Parameterized.Parameters; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.when; + +@RunWith(value = Parameterized.class) +public class WorkerCoordinatorIncrementalTest { + @Rule + public MockitoRule rule = MockitoJUnit.rule(); + + private String connectorId1 = "connector1"; + private String connectorId2 = "connector2"; + private String connectorId3 = "connector3"; + private ConnectorTaskId taskId1x0 = new ConnectorTaskId(connectorId1, 0); + private ConnectorTaskId taskId1x1 = new ConnectorTaskId(connectorId1, 1); + private ConnectorTaskId taskId2x0 = new ConnectorTaskId(connectorId2, 0); + private ConnectorTaskId taskId3x0 = new ConnectorTaskId(connectorId3, 0); + + private String groupId = "test-group"; + private int sessionTimeoutMs = 10; + private int rebalanceTimeoutMs = 60; + private int heartbeatIntervalMs = 2; + private long retryBackoffMs = 100; + private int requestTimeoutMs = 1000; + private MockTime time; + private MockClient client; + private Node node; + private Metadata metadata; + private Metrics metrics; + private ConsumerNetworkClient consumerClient; + private MockRebalanceListener rebalanceListener; + @Mock + private KafkaConfigBackingStore configStorage; + private GroupRebalanceConfig rebalanceConfig; + private WorkerCoordinator coordinator; + private int rebalanceDelay = DistributedConfig.SCHEDULED_REBALANCE_MAX_DELAY_MS_DEFAULT; + + private String leaderId; + private String memberId; + private String anotherMemberId; + private String leaderUrl; + private String memberUrl; + private String anotherMemberUrl; + private int generationId; + private long offset; + + private int configStorageCalls; + + private ClusterConfigState configState1; + private ClusterConfigState configState2; + private ClusterConfigState configStateSingleTaskConnectors; + + // Arguments are: + // - Protocol type + // - Expected metadata size + @Parameters + public static Iterable mode() { + return Arrays.asList(new Object[][]{{COMPATIBLE, 2}, {SESSIONED, 3}}); + } + + @Parameter + public ConnectProtocolCompatibility compatibility; + + @Parameter(1) + public int expectedMetadataSize; + + @Before + public void setup() { + LogContext loggerFactory = new LogContext(); + + this.time = new MockTime(); + this.metadata = new Metadata(0, Long.MAX_VALUE, loggerFactory, new ClusterResourceListeners()); + this.client = new MockClient(time, metadata); + this.client.updateMetadata(RequestTestUtils.metadataUpdateWith(1, Collections.singletonMap("topic", 1))); + this.node = metadata.fetch().nodes().get(0); + this.consumerClient = new ConsumerNetworkClient(loggerFactory, client, metadata, time, + retryBackoffMs, requestTimeoutMs, heartbeatIntervalMs); + this.metrics = new Metrics(time); + this.rebalanceListener = new MockRebalanceListener(); + + this.leaderId = "worker1"; + this.memberId = "worker2"; + this.anotherMemberId = "worker3"; + this.leaderUrl = expectedUrl(leaderId); + this.memberUrl = expectedUrl(memberId); + this.anotherMemberUrl = 
expectedUrl(anotherMemberId); + this.generationId = 3; + this.offset = 10L; + + this.configStorageCalls = 0; + + this.rebalanceConfig = new GroupRebalanceConfig(sessionTimeoutMs, + rebalanceTimeoutMs, + heartbeatIntervalMs, + groupId, + Optional.empty(), + retryBackoffMs, + true); + this.coordinator = new WorkerCoordinator(rebalanceConfig, + loggerFactory, + consumerClient, + metrics, + "worker" + groupId, + time, + expectedUrl(leaderId), + configStorage, + rebalanceListener, + compatibility, + rebalanceDelay); + + configState1 = clusterConfigState(offset, 2, 4); + } + + @After + public void teardown() { + this.metrics.close(); + verifyNoMoreInteractions(configStorage); + } + + private static String expectedUrl(String member) { + return "http://" + member + ":8083"; + } + + // We only test functionality unique to WorkerCoordinator. Other functionality is already + // well tested via the tests that cover AbstractCoordinator & ConsumerCoordinator. + + @Test + public void testMetadata() { + when(configStorage.snapshot()).thenReturn(configState1); + + JoinGroupRequestProtocolCollection serialized = coordinator.metadata(); + assertEquals(expectedMetadataSize, serialized.size()); + + Iterator protocolIterator = serialized.iterator(); + assertTrue(protocolIterator.hasNext()); + JoinGroupRequestProtocol defaultMetadata = protocolIterator.next(); + assertEquals(compatibility.protocol(), defaultMetadata.name()); + WorkerState state = IncrementalCooperativeConnectProtocol + .deserializeMetadata(ByteBuffer.wrap(defaultMetadata.metadata())); + assertEquals(offset, state.offset()); + + verify(configStorage, times(1)).snapshot(); + } + + @Test + public void testMetadataWithExistingAssignment() { + when(configStorage.snapshot()).thenReturn(configState1); + + ExtendedAssignment assignment = new ExtendedAssignment( + CONNECT_PROTOCOL_V1, ExtendedAssignment.NO_ERROR, leaderId, leaderUrl, configState1.offset(), + Collections.singletonList(connectorId1), Arrays.asList(taskId1x0, taskId2x0), + Collections.emptyList(), Collections.emptyList(), 0); + ByteBuffer buf = IncrementalCooperativeConnectProtocol.serializeAssignment(assignment); + // Using onJoinComplete to register the protocol selection decided by the broker + // coordinator as well as an existing previous assignment that the call to metadata will + // include with v1 but not with v0 + coordinator.onJoinComplete(generationId, memberId, compatibility.protocol(), buf); + + JoinGroupRequestProtocolCollection serialized = coordinator.metadata(); + assertEquals(expectedMetadataSize, serialized.size()); + + Iterator protocolIterator = serialized.iterator(); + assertTrue(protocolIterator.hasNext()); + JoinGroupRequestProtocol selectedMetadata = protocolIterator.next(); + assertEquals(compatibility.protocol(), selectedMetadata.name()); + ExtendedWorkerState state = IncrementalCooperativeConnectProtocol + .deserializeMetadata(ByteBuffer.wrap(selectedMetadata.metadata())); + assertEquals(offset, state.offset()); + assertNotEquals(ExtendedAssignment.empty(), state.assignment()); + assertEquals(Collections.singletonList(connectorId1), state.assignment().connectors()); + assertEquals(Arrays.asList(taskId1x0, taskId2x0), state.assignment().tasks()); + + verify(configStorage, times(1)).snapshot(); + } + + @Test + public void testMetadataWithExistingAssignmentButOlderProtocolSelection() { + when(configStorage.snapshot()).thenReturn(configState1); + + ExtendedAssignment assignment = new ExtendedAssignment( + CONNECT_PROTOCOL_V1, ExtendedAssignment.NO_ERROR, leaderId, 
leaderUrl, configState1.offset(),
+                Collections.singletonList(connectorId1), Arrays.asList(taskId1x0, taskId2x0),
+                Collections.emptyList(), Collections.emptyList(), 0);
+        ByteBuffer buf = IncrementalCooperativeConnectProtocol.serializeAssignment(assignment);
+        // Using onJoinComplete to register the protocol selection decided by the broker
+        // coordinator as well as an existing previous assignment that the call to metadata will
+        // include with v1 but not with v0
+        coordinator.onJoinComplete(generationId, memberId, EAGER.protocol(), buf);
+
+        JoinGroupRequestProtocolCollection serialized = coordinator.metadata();
+        assertEquals(expectedMetadataSize, serialized.size());
+
+        Iterator<JoinGroupRequestProtocol> protocolIterator = serialized.iterator();
+        assertTrue(protocolIterator.hasNext());
+        JoinGroupRequestProtocol selectedMetadata = protocolIterator.next();
+        assertEquals(compatibility.protocol(), selectedMetadata.name());
+        ExtendedWorkerState state = IncrementalCooperativeConnectProtocol
+                .deserializeMetadata(ByteBuffer.wrap(selectedMetadata.metadata()));
+        assertEquals(offset, state.offset());
+        assertNotEquals(ExtendedAssignment.empty(), state.assignment());
+
+        verify(configStorage, times(1)).snapshot();
+    }
+
+    @Test
+    public void testTaskAssignmentWhenWorkerJoins() {
+        when(configStorage.snapshot()).thenReturn(configState1);
+
+        coordinator.metadata();
+        ++configStorageCalls;
+
+        List<JoinGroupResponseMember> responseMembers = new ArrayList<>();
+        addJoinGroupResponseMember(responseMembers, leaderId, offset, null);
+        addJoinGroupResponseMember(responseMembers, memberId, offset, null);
+
+        Map<String, ByteBuffer> result = coordinator.performAssignment(leaderId, compatibility.protocol(), responseMembers);
+
+        ExtendedAssignment leaderAssignment = deserializeAssignment(result, leaderId);
+        assertAssignment(leaderId, offset,
+                Collections.singletonList(connectorId1), 4,
+                Collections.emptyList(), 0,
+                leaderAssignment);
+
+        ExtendedAssignment memberAssignment = deserializeAssignment(result, memberId);
+        assertAssignment(leaderId, offset,
+                Collections.singletonList(connectorId2), 4,
+                Collections.emptyList(), 0,
+                memberAssignment);
+
+        coordinator.metadata();
+        ++configStorageCalls;
+
+        responseMembers = new ArrayList<>();
+        addJoinGroupResponseMember(responseMembers, leaderId, offset, leaderAssignment);
+        addJoinGroupResponseMember(responseMembers, memberId, offset, memberAssignment);
+        addJoinGroupResponseMember(responseMembers, anotherMemberId, offset, null);
+
+        result = coordinator.performAssignment(leaderId, compatibility.protocol(), responseMembers);
+
+        // Equally distributing tasks across members
+        leaderAssignment = deserializeAssignment(result, leaderId);
+        assertAssignment(leaderId, offset,
+                Collections.emptyList(), 0,
+                Collections.emptyList(), 1,
+                leaderAssignment);
+
+        memberAssignment = deserializeAssignment(result, memberId);
+        assertAssignment(leaderId, offset,
+                Collections.emptyList(), 0,
+                Collections.emptyList(), 1,
+                memberAssignment);
+
+        ExtendedAssignment anotherMemberAssignment = deserializeAssignment(result, anotherMemberId);
+        assertAssignment(leaderId, offset,
+                Collections.emptyList(), 0,
+                Collections.emptyList(), 0,
+                anotherMemberAssignment);
+
+        verify(configStorage, times(configStorageCalls)).snapshot();
+    }
+
+    @Test
+    public void testTaskAssignmentWhenWorkerLeavesPermanently() {
+        when(configStorage.snapshot()).thenReturn(configState1);
+
+        // First assignment distributes configured connectors and tasks
+        coordinator.metadata();
+        ++configStorageCalls;
+
+        List<JoinGroupResponseMember> responseMembers = new ArrayList<>();
+
addJoinGroupResponseMember(responseMembers, leaderId, offset, null); + addJoinGroupResponseMember(responseMembers, memberId, offset, null); + addJoinGroupResponseMember(responseMembers, anotherMemberId, offset, null); + + Map result = coordinator.performAssignment(leaderId, compatibility.protocol(), responseMembers); + + ExtendedAssignment leaderAssignment = deserializeAssignment(result, leaderId); + assertAssignment(leaderId, offset, + Collections.singletonList(connectorId1), 3, + Collections.emptyList(), 0, + leaderAssignment); + + ExtendedAssignment memberAssignment = deserializeAssignment(result, memberId); + assertAssignment(leaderId, offset, + Collections.singletonList(connectorId2), 3, + Collections.emptyList(), 0, + memberAssignment); + + ExtendedAssignment anotherMemberAssignment = deserializeAssignment(result, anotherMemberId); + assertAssignment(leaderId, offset, + Collections.emptyList(), 2, + Collections.emptyList(), 0, + anotherMemberAssignment); + + // Second rebalance detects a worker is missing + coordinator.metadata(); + ++configStorageCalls; + + // Mark everyone as in sync with configState1 + responseMembers = new ArrayList<>(); + addJoinGroupResponseMember(responseMembers, leaderId, offset, leaderAssignment); + addJoinGroupResponseMember(responseMembers, memberId, offset, memberAssignment); + + result = coordinator.performAssignment(leaderId, compatibility.protocol(), responseMembers); + + leaderAssignment = deserializeAssignment(result, leaderId); + assertAssignment(leaderId, offset, + Collections.emptyList(), 0, + Collections.emptyList(), 0, + rebalanceDelay, + leaderAssignment); + + memberAssignment = deserializeAssignment(result, memberId); + assertAssignment(leaderId, offset, + Collections.emptyList(), 0, + Collections.emptyList(), 0, + rebalanceDelay, + memberAssignment); + + rebalanceDelay /= 2; + time.sleep(rebalanceDelay); + + // A third rebalance before the delay expires won't change the assignments + result = coordinator.performAssignment(leaderId, compatibility.protocol(), responseMembers); + + leaderAssignment = deserializeAssignment(result, leaderId); + assertAssignment(leaderId, offset, + Collections.emptyList(), 0, + Collections.emptyList(), 0, + rebalanceDelay, + leaderAssignment); + + memberAssignment = deserializeAssignment(result, memberId); + assertAssignment(leaderId, offset, + Collections.emptyList(), 0, + Collections.emptyList(), 0, + rebalanceDelay, + memberAssignment); + + time.sleep(rebalanceDelay + 1); + + // A rebalance after the delay expires re-assigns the lost tasks + result = coordinator.performAssignment(leaderId, compatibility.protocol(), responseMembers); + + leaderAssignment = deserializeAssignment(result, leaderId); + assertAssignment(leaderId, offset, + Collections.emptyList(), 1, + Collections.emptyList(), 0, + leaderAssignment); + + memberAssignment = deserializeAssignment(result, memberId); + assertAssignment(leaderId, offset, + Collections.emptyList(), 1, + Collections.emptyList(), 0, + memberAssignment); + + verify(configStorage, times(configStorageCalls)).snapshot(); + } + + @Test + public void testTaskAssignmentWhenWorkerBounces() { + when(configStorage.snapshot()).thenReturn(configState1); + + // First assignment distributes configured connectors and tasks + coordinator.metadata(); + ++configStorageCalls; + + List responseMembers = new ArrayList<>(); + addJoinGroupResponseMember(responseMembers, leaderId, offset, null); + addJoinGroupResponseMember(responseMembers, memberId, offset, null); + 
addJoinGroupResponseMember(responseMembers, anotherMemberId, offset, null); + + Map result = coordinator.performAssignment(leaderId, compatibility.protocol(), responseMembers); + + ExtendedAssignment leaderAssignment = deserializeAssignment(result, leaderId); + assertAssignment(leaderId, offset, + Collections.singletonList(connectorId1), 3, + Collections.emptyList(), 0, + leaderAssignment); + + ExtendedAssignment memberAssignment = deserializeAssignment(result, memberId); + assertAssignment(leaderId, offset, + Collections.singletonList(connectorId2), 3, + Collections.emptyList(), 0, + memberAssignment); + + ExtendedAssignment anotherMemberAssignment = deserializeAssignment(result, anotherMemberId); + assertAssignment(leaderId, offset, + Collections.emptyList(), 2, + Collections.emptyList(), 0, + anotherMemberAssignment); + + // Second rebalance detects a worker is missing + coordinator.metadata(); + ++configStorageCalls; + + responseMembers = new ArrayList<>(); + addJoinGroupResponseMember(responseMembers, leaderId, offset, leaderAssignment); + addJoinGroupResponseMember(responseMembers, memberId, offset, memberAssignment); + + result = coordinator.performAssignment(leaderId, compatibility.protocol(), responseMembers); + + leaderAssignment = deserializeAssignment(result, leaderId); + assertAssignment(leaderId, offset, + Collections.emptyList(), 0, + Collections.emptyList(), 0, + rebalanceDelay, + leaderAssignment); + + memberAssignment = deserializeAssignment(result, memberId); + assertAssignment(leaderId, offset, + Collections.emptyList(), 0, + Collections.emptyList(), 0, + rebalanceDelay, + memberAssignment); + + rebalanceDelay /= 2; + time.sleep(rebalanceDelay); + + // A third rebalance before the delay expires won't change the assignments even if the + // member returns in the meantime + addJoinGroupResponseMember(responseMembers, anotherMemberId, offset, null); + result = coordinator.performAssignment(leaderId, compatibility.protocol(), responseMembers); + + leaderAssignment = deserializeAssignment(result, leaderId); + assertAssignment(leaderId, offset, + Collections.emptyList(), 0, + Collections.emptyList(), 0, + rebalanceDelay, + leaderAssignment); + + memberAssignment = deserializeAssignment(result, memberId); + assertAssignment(leaderId, offset, + Collections.emptyList(), 0, + Collections.emptyList(), 0, + rebalanceDelay, + memberAssignment); + + anotherMemberAssignment = deserializeAssignment(result, anotherMemberId); + assertAssignment(leaderId, offset, + Collections.emptyList(), 0, + Collections.emptyList(), 0, + rebalanceDelay, + anotherMemberAssignment); + + time.sleep(rebalanceDelay + 1); + + result = coordinator.performAssignment(leaderId, compatibility.protocol(), responseMembers); + + // A rebalance after the delay expires re-assigns the lost tasks to the returning member + leaderAssignment = deserializeAssignment(result, leaderId); + assertAssignment(leaderId, offset, + Collections.emptyList(), 0, + Collections.emptyList(), 0, + leaderAssignment); + + memberAssignment = deserializeAssignment(result, memberId); + assertAssignment(leaderId, offset, + Collections.emptyList(), 0, + Collections.emptyList(), 0, + memberAssignment); + + anotherMemberAssignment = deserializeAssignment(result, anotherMemberId); + assertAssignment(leaderId, offset, + Collections.emptyList(), 2, + Collections.emptyList(), 0, + anotherMemberAssignment); + + verify(configStorage, times(configStorageCalls)).snapshot(); + } + + private static class MockRebalanceListener implements 
WorkerRebalanceListener {
+        public ExtendedAssignment assignment = null;
+
+        public String revokedLeader;
+        public Collection<String> revokedConnectors = Collections.emptyList();
+        public Collection<ConnectorTaskId> revokedTasks = Collections.emptyList();
+
+        public int revokedCount = 0;
+        public int assignedCount = 0;
+
+        @Override
+        public void onAssigned(ExtendedAssignment assignment, int generation) {
+            this.assignment = assignment;
+            assignedCount++;
+        }
+
+        @Override
+        public void onRevoked(String leader, Collection<String> connectors, Collection<ConnectorTaskId> tasks) {
+            if (connectors.isEmpty() && tasks.isEmpty()) {
+                return;
+            }
+            this.revokedLeader = leader;
+            this.revokedConnectors = connectors;
+            this.revokedTasks = tasks;
+            revokedCount++;
+        }
+    }
+
+    private static ExtendedAssignment deserializeAssignment(Map<String, ByteBuffer> assignment,
+                                                            String member) {
+        return IncrementalCooperativeConnectProtocol.deserializeAssignment(assignment.get(member));
+    }
+
+    private void addJoinGroupResponseMember(List<JoinGroupResponseMember> responseMembers,
+                                            String member,
+                                            long offset,
+                                            ExtendedAssignment assignment) {
+        responseMembers.add(new JoinGroupResponseMember()
+                .setMemberId(member)
+                .setMetadata(
+                        IncrementalCooperativeConnectProtocol.serializeMetadata(
+                                new ExtendedWorkerState(expectedUrl(member), offset, assignment),
+                                compatibility != COMPATIBLE
+                        ).array()
+                )
+        );
+    }
+}
diff --git a/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/distributed/WorkerCoordinatorTest.java b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/distributed/WorkerCoordinatorTest.java
new file mode 100644
index 0000000..51232b6
--- /dev/null
+++ b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/distributed/WorkerCoordinatorTest.java
@@ -0,0 +1,586 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +package org.apache.kafka.connect.runtime.distributed; + +import org.apache.kafka.clients.GroupRebalanceConfig; +import org.apache.kafka.clients.Metadata; +import org.apache.kafka.clients.MockClient; +import org.apache.kafka.clients.consumer.internals.ConsumerNetworkClient; +import org.apache.kafka.common.Node; +import org.apache.kafka.common.internals.ClusterResourceListeners; +import org.apache.kafka.common.message.JoinGroupRequestData; +import org.apache.kafka.common.message.JoinGroupResponseData; +import org.apache.kafka.common.message.SyncGroupResponseData; +import org.apache.kafka.common.metrics.Metrics; +import org.apache.kafka.common.protocol.Errors; +import org.apache.kafka.common.requests.FindCoordinatorResponse; +import org.apache.kafka.common.requests.JoinGroupResponse; +import org.apache.kafka.common.requests.RequestTestUtils; +import org.apache.kafka.common.requests.SyncGroupRequest; +import org.apache.kafka.common.requests.SyncGroupResponse; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.connect.runtime.TargetState; +import org.apache.kafka.connect.storage.KafkaConfigBackingStore; +import org.apache.kafka.connect.util.ConnectorTaskId; +import org.easymock.EasyMock; +import org.easymock.Mock; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import org.powermock.api.easymock.PowerMock; + +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +import static org.apache.kafka.connect.runtime.distributed.ConnectProtocolCompatibility.COMPATIBLE; +import static org.apache.kafka.connect.runtime.distributed.ConnectProtocolCompatibility.EAGER; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; +import static org.junit.runners.Parameterized.Parameter; +import static org.junit.runners.Parameterized.Parameters; + +@RunWith(value = Parameterized.class) +public class WorkerCoordinatorTest { + + private static final String LEADER_URL = "leaderUrl:8083"; + private static final String MEMBER_URL = "memberUrl:8083"; + + private String connectorId1 = "connector1"; + private String connectorId2 = "connector2"; + private String connectorId3 = "connector3"; + private ConnectorTaskId taskId1x0 = new ConnectorTaskId(connectorId1, 0); + private ConnectorTaskId taskId1x1 = new ConnectorTaskId(connectorId1, 1); + private ConnectorTaskId taskId2x0 = new ConnectorTaskId(connectorId2, 0); + private ConnectorTaskId taskId3x0 = new ConnectorTaskId(connectorId3, 0); + + private String groupId = "test-group"; + private int sessionTimeoutMs = 10; + private int rebalanceTimeoutMs = 60; + private int heartbeatIntervalMs = 2; + private long retryBackoffMs = 100; + private MockTime time; + private MockClient client; + private Node node; + private Metadata metadata; + private Metrics metrics; + private ConsumerNetworkClient consumerClient; + private MockRebalanceListener rebalanceListener; + @Mock private KafkaConfigBackingStore configStorage; + private GroupRebalanceConfig rebalanceConfig; + private WorkerCoordinator coordinator; + + private ClusterConfigState configState1; + private 
ClusterConfigState configState2; + private ClusterConfigState configStateSingleTaskConnectors; + + // Arguments are: + // - Protocol type + // - Expected metadata size + @Parameters + public static Iterable mode() { + return Arrays.asList(new Object[][]{ + {EAGER, 1}, + {COMPATIBLE, 2}}); + } + + @Parameter + public ConnectProtocolCompatibility compatibility; + + @Parameter(1) + public int expectedMetadataSize; + + @Before + public void setup() { + LogContext logContext = new LogContext(); + + this.time = new MockTime(); + this.metadata = new Metadata(0, Long.MAX_VALUE, logContext, new ClusterResourceListeners()); + this.client = new MockClient(time, metadata); + this.client.updateMetadata(RequestTestUtils.metadataUpdateWith(1, Collections.singletonMap("topic", 1))); + this.node = metadata.fetch().nodes().get(0); + this.consumerClient = new ConsumerNetworkClient(logContext, client, metadata, time, 100, 1000, heartbeatIntervalMs); + this.metrics = new Metrics(time); + this.rebalanceListener = new MockRebalanceListener(); + this.configStorage = PowerMock.createMock(KafkaConfigBackingStore.class); + this.rebalanceConfig = new GroupRebalanceConfig(sessionTimeoutMs, + rebalanceTimeoutMs, + heartbeatIntervalMs, + groupId, + Optional.empty(), + retryBackoffMs, + true); + this.coordinator = new WorkerCoordinator(rebalanceConfig, + logContext, + consumerClient, + metrics, + "consumer" + groupId, + time, + LEADER_URL, + configStorage, + rebalanceListener, + compatibility, + 0); + + configState1 = new ClusterConfigState( + 1L, + null, + Collections.singletonMap(connectorId1, 1), + Collections.singletonMap(connectorId1, new HashMap()), + Collections.singletonMap(connectorId1, TargetState.STARTED), + Collections.singletonMap(taskId1x0, new HashMap()), + Collections.emptySet() + ); + + Map configState2ConnectorTaskCounts = new HashMap<>(); + configState2ConnectorTaskCounts.put(connectorId1, 2); + configState2ConnectorTaskCounts.put(connectorId2, 1); + Map> configState2ConnectorConfigs = new HashMap<>(); + configState2ConnectorConfigs.put(connectorId1, new HashMap<>()); + configState2ConnectorConfigs.put(connectorId2, new HashMap<>()); + Map configState2TargetStates = new HashMap<>(); + configState2TargetStates.put(connectorId1, TargetState.STARTED); + configState2TargetStates.put(connectorId2, TargetState.STARTED); + Map> configState2TaskConfigs = new HashMap<>(); + configState2TaskConfigs.put(taskId1x0, new HashMap<>()); + configState2TaskConfigs.put(taskId1x1, new HashMap<>()); + configState2TaskConfigs.put(taskId2x0, new HashMap<>()); + configState2 = new ClusterConfigState( + 2L, + null, + configState2ConnectorTaskCounts, + configState2ConnectorConfigs, + configState2TargetStates, + configState2TaskConfigs, + Collections.emptySet() + ); + + Map configStateSingleTaskConnectorsConnectorTaskCounts = new HashMap<>(); + configStateSingleTaskConnectorsConnectorTaskCounts.put(connectorId1, 1); + configStateSingleTaskConnectorsConnectorTaskCounts.put(connectorId2, 1); + configStateSingleTaskConnectorsConnectorTaskCounts.put(connectorId3, 1); + Map> configStateSingleTaskConnectorsConnectorConfigs = new HashMap<>(); + configStateSingleTaskConnectorsConnectorConfigs.put(connectorId1, new HashMap<>()); + configStateSingleTaskConnectorsConnectorConfigs.put(connectorId2, new HashMap<>()); + configStateSingleTaskConnectorsConnectorConfigs.put(connectorId3, new HashMap<>()); + Map configStateSingleTaskConnectorsTargetStates = new HashMap<>(); + configStateSingleTaskConnectorsTargetStates.put(connectorId1, 
TargetState.STARTED); + configStateSingleTaskConnectorsTargetStates.put(connectorId2, TargetState.STARTED); + configStateSingleTaskConnectorsTargetStates.put(connectorId3, TargetState.STARTED); + Map> configStateSingleTaskConnectorsTaskConfigs = new HashMap<>(); + configStateSingleTaskConnectorsTaskConfigs.put(taskId1x0, new HashMap<>()); + configStateSingleTaskConnectorsTaskConfigs.put(taskId2x0, new HashMap<>()); + configStateSingleTaskConnectorsTaskConfigs.put(taskId3x0, new HashMap<>()); + configStateSingleTaskConnectors = new ClusterConfigState( + 2L, + null, + configStateSingleTaskConnectorsConnectorTaskCounts, + configStateSingleTaskConnectorsConnectorConfigs, + configStateSingleTaskConnectorsTargetStates, + configStateSingleTaskConnectorsTaskConfigs, + Collections.emptySet() + ); + } + + @After + public void teardown() { + this.metrics.close(); + } + + // We only test functionality unique to WorkerCoordinator. Most functionality is already well tested via the tests + // that cover AbstractCoordinator & ConsumerCoordinator. + + @Test + public void testMetadata() { + EasyMock.expect(configStorage.snapshot()).andReturn(configState1); + + PowerMock.replayAll(); + + JoinGroupRequestData.JoinGroupRequestProtocolCollection serialized = coordinator.metadata(); + assertEquals(expectedMetadataSize, serialized.size()); + + Iterator protocolIterator = serialized.iterator(); + assertTrue(protocolIterator.hasNext()); + JoinGroupRequestData.JoinGroupRequestProtocol defaultMetadata = protocolIterator.next(); + assertEquals(compatibility.protocol(), defaultMetadata.name()); + ConnectProtocol.WorkerState state = ConnectProtocol.deserializeMetadata( + ByteBuffer.wrap(defaultMetadata.metadata())); + assertEquals(1, state.offset()); + + PowerMock.verifyAll(); + } + + @Test + public void testNormalJoinGroupLeader() { + EasyMock.expect(configStorage.snapshot()).andReturn(configState1); + + PowerMock.replayAll(); + + final String consumerId = "leader"; + + client.prepareResponse(FindCoordinatorResponse.prepareResponse(Errors.NONE, groupId, node)); + coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE)); + + // normal join group + Map memberConfigOffsets = new HashMap<>(); + memberConfigOffsets.put("leader", 1L); + memberConfigOffsets.put("member", 1L); + client.prepareResponse(joinGroupLeaderResponse(1, consumerId, memberConfigOffsets, Errors.NONE)); + client.prepareResponse(body -> { + SyncGroupRequest sync = (SyncGroupRequest) body; + return sync.data().memberId().equals(consumerId) && + sync.data().generationId() == 1 && + sync.groupAssignments().containsKey(consumerId); + }, syncGroupResponse(ConnectProtocol.Assignment.NO_ERROR, "leader", 1L, Collections.singletonList(connectorId1), + Collections.emptyList(), Errors.NONE)); + coordinator.ensureActiveGroup(); + + assertFalse(coordinator.rejoinNeededOrPending()); + assertEquals(0, rebalanceListener.revokedCount); + assertEquals(1, rebalanceListener.assignedCount); + assertFalse(rebalanceListener.assignment.failed()); + assertEquals(1L, rebalanceListener.assignment.offset()); + assertEquals("leader", rebalanceListener.assignment.leader()); + assertEquals(Collections.singletonList(connectorId1), rebalanceListener.assignment.connectors()); + assertEquals(Collections.emptyList(), rebalanceListener.assignment.tasks()); + + PowerMock.verifyAll(); + } + + @Test + public void testNormalJoinGroupFollower() { + EasyMock.expect(configStorage.snapshot()).andReturn(configState1); + + PowerMock.replayAll(); + + final String memberId = "member"; + + 
client.prepareResponse(FindCoordinatorResponse.prepareResponse(Errors.NONE, groupId, node)); + coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE)); + + // normal join group + client.prepareResponse(joinGroupFollowerResponse(1, memberId, "leader", Errors.NONE)); + client.prepareResponse(body -> { + SyncGroupRequest sync = (SyncGroupRequest) body; + return sync.data().memberId().equals(memberId) && + sync.data().generationId() == 1 && + sync.data().assignments().isEmpty(); + }, syncGroupResponse(ConnectProtocol.Assignment.NO_ERROR, "leader", 1L, Collections.emptyList(), + Collections.singletonList(taskId1x0), Errors.NONE)); + coordinator.ensureActiveGroup(); + + assertFalse(coordinator.rejoinNeededOrPending()); + assertEquals(0, rebalanceListener.revokedCount); + assertEquals(1, rebalanceListener.assignedCount); + assertFalse(rebalanceListener.assignment.failed()); + assertEquals(1L, rebalanceListener.assignment.offset()); + assertEquals(Collections.emptyList(), rebalanceListener.assignment.connectors()); + assertEquals(Collections.singletonList(taskId1x0), rebalanceListener.assignment.tasks()); + + PowerMock.verifyAll(); + } + + @Test + public void testJoinLeaderCannotAssign() { + // If the selected leader can't get up to the maximum offset, it will fail to assign and we should immediately + // need to retry the join. + + // When the first round fails, we'll take an updated config snapshot + EasyMock.expect(configStorage.snapshot()).andReturn(configState1); + EasyMock.expect(configStorage.snapshot()).andReturn(configState2); + + PowerMock.replayAll(); + + final String memberId = "member"; + + client.prepareResponse(FindCoordinatorResponse.prepareResponse(Errors.NONE, groupId, node)); + coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE)); + + // config mismatch results in assignment error + client.prepareResponse(joinGroupFollowerResponse(1, memberId, "leader", Errors.NONE)); + MockClient.RequestMatcher matcher = body -> { + SyncGroupRequest sync = (SyncGroupRequest) body; + return sync.data().memberId().equals(memberId) && + sync.data().generationId() == 1 && + sync.data().assignments().isEmpty(); + }; + client.prepareResponse(matcher, syncGroupResponse(ConnectProtocol.Assignment.CONFIG_MISMATCH, "leader", 10L, + Collections.emptyList(), Collections.emptyList(), Errors.NONE)); + client.prepareResponse(joinGroupFollowerResponse(1, memberId, "leader", Errors.NONE)); + client.prepareResponse(matcher, syncGroupResponse(ConnectProtocol.Assignment.NO_ERROR, "leader", 1L, + Collections.emptyList(), Collections.singletonList(taskId1x0), Errors.NONE)); + coordinator.ensureActiveGroup(); + + PowerMock.verifyAll(); + } + + @Test + public void testRejoinGroup() { + EasyMock.expect(configStorage.snapshot()).andReturn(configState1); + EasyMock.expect(configStorage.snapshot()).andReturn(configState1); + + PowerMock.replayAll(); + + client.prepareResponse(FindCoordinatorResponse.prepareResponse(Errors.NONE, groupId, node)); + coordinator.ensureCoordinatorReady(time.timer(Long.MAX_VALUE)); + + // join the group once + client.prepareResponse(joinGroupFollowerResponse(1, "consumer", "leader", Errors.NONE)); + client.prepareResponse(syncGroupResponse(ConnectProtocol.Assignment.NO_ERROR, "leader", 1L, Collections.emptyList(), + Collections.singletonList(taskId1x0), Errors.NONE)); + coordinator.ensureActiveGroup(); + + assertEquals(0, rebalanceListener.revokedCount); + assertEquals(1, rebalanceListener.assignedCount); + assertFalse(rebalanceListener.assignment.failed()); + assertEquals(1L, 
rebalanceListener.assignment.offset()); + assertEquals(Collections.emptyList(), rebalanceListener.assignment.connectors()); + assertEquals(Collections.singletonList(taskId1x0), rebalanceListener.assignment.tasks()); + + // and join the group again + coordinator.requestRejoin("test"); + client.prepareResponse(joinGroupFollowerResponse(1, "consumer", "leader", Errors.NONE)); + client.prepareResponse(syncGroupResponse(ConnectProtocol.Assignment.NO_ERROR, "leader", 1L, Collections.singletonList(connectorId1), + Collections.emptyList(), Errors.NONE)); + coordinator.ensureActiveGroup(); + + assertEquals(1, rebalanceListener.revokedCount); + assertEquals(Collections.emptyList(), rebalanceListener.revokedConnectors); + assertEquals(Collections.singletonList(taskId1x0), rebalanceListener.revokedTasks); + assertEquals(2, rebalanceListener.assignedCount); + assertFalse(rebalanceListener.assignment.failed()); + assertEquals(1L, rebalanceListener.assignment.offset()); + assertEquals(Collections.singletonList(connectorId1), rebalanceListener.assignment.connectors()); + assertEquals(Collections.emptyList(), rebalanceListener.assignment.tasks()); + + PowerMock.verifyAll(); + } + + @Test + public void testLeaderPerformAssignment1() throws Exception { + // Since all the protocol responses are mocked, the other tests validate doSync runs, but don't validate its + // output. So we test it directly here. + + EasyMock.expect(configStorage.snapshot()).andReturn(configState1); + + PowerMock.replayAll(); + + // Prime the current configuration state + coordinator.metadata(); + + // Mark everyone as in sync with configState1 + List responseMembers = new ArrayList<>(); + responseMembers.add(new JoinGroupResponseData.JoinGroupResponseMember() + .setMemberId("leader") + .setMetadata(ConnectProtocol.serializeMetadata(new ConnectProtocol.WorkerState(LEADER_URL, 1L)).array()) + ); + responseMembers.add(new JoinGroupResponseData.JoinGroupResponseMember() + .setMemberId("member") + .setMetadata(ConnectProtocol.serializeMetadata(new ConnectProtocol.WorkerState(MEMBER_URL, 1L)).array()) + ); + Map result = coordinator.performAssignment("leader", EAGER.protocol(), responseMembers); + + // configState1 has 1 connector, 1 task + ConnectProtocol.Assignment leaderAssignment = ConnectProtocol.deserializeAssignment(result.get("leader")); + assertEquals(false, leaderAssignment.failed()); + assertEquals("leader", leaderAssignment.leader()); + assertEquals(1, leaderAssignment.offset()); + assertEquals(Collections.singletonList(connectorId1), leaderAssignment.connectors()); + assertEquals(Collections.emptyList(), leaderAssignment.tasks()); + + ConnectProtocol.Assignment memberAssignment = ConnectProtocol.deserializeAssignment(result.get("member")); + assertEquals(false, memberAssignment.failed()); + assertEquals("leader", memberAssignment.leader()); + assertEquals(1, memberAssignment.offset()); + assertEquals(Collections.emptyList(), memberAssignment.connectors()); + assertEquals(Collections.singletonList(taskId1x0), memberAssignment.tasks()); + + PowerMock.verifyAll(); + } + + @Test + public void testLeaderPerformAssignment2() throws Exception { + // Since all the protocol responses are mocked, the other tests validate doSync runs, but don't validate its + // output. So we test it directly here. 
+ + EasyMock.expect(configStorage.snapshot()).andReturn(configState2); + + PowerMock.replayAll(); + + // Prime the current configuration state + coordinator.metadata(); + + // Mark everyone as in sync with configState1 + List responseMembers = new ArrayList<>(); + responseMembers.add(new JoinGroupResponseData.JoinGroupResponseMember() + .setMemberId("leader") + .setMetadata(ConnectProtocol.serializeMetadata(new ConnectProtocol.WorkerState(LEADER_URL, 1L)).array()) + ); + responseMembers.add(new JoinGroupResponseData.JoinGroupResponseMember() + .setMemberId("member") + .setMetadata(ConnectProtocol.serializeMetadata(new ConnectProtocol.WorkerState(MEMBER_URL, 1L)).array()) + ); + + Map result = coordinator.performAssignment("leader", EAGER.protocol(), responseMembers); + + // configState2 has 2 connector, 3 tasks and should trigger round robin assignment + ConnectProtocol.Assignment leaderAssignment = ConnectProtocol.deserializeAssignment(result.get("leader")); + assertEquals(false, leaderAssignment.failed()); + assertEquals("leader", leaderAssignment.leader()); + assertEquals(1, leaderAssignment.offset()); + assertEquals(Collections.singletonList(connectorId1), leaderAssignment.connectors()); + assertEquals(Arrays.asList(taskId1x0, taskId2x0), leaderAssignment.tasks()); + + ConnectProtocol.Assignment memberAssignment = ConnectProtocol.deserializeAssignment(result.get("member")); + assertEquals(false, memberAssignment.failed()); + assertEquals("leader", memberAssignment.leader()); + assertEquals(1, memberAssignment.offset()); + assertEquals(Collections.singletonList(connectorId2), memberAssignment.connectors()); + assertEquals(Collections.singletonList(taskId1x1), memberAssignment.tasks()); + + PowerMock.verifyAll(); + } + + @Test + public void testLeaderPerformAssignmentSingleTaskConnectors() throws Exception { + // Since all the protocol responses are mocked, the other tests validate doSync runs, but don't validate its + // output. So we test it directly here. + + EasyMock.expect(configStorage.snapshot()).andReturn(configStateSingleTaskConnectors); + + PowerMock.replayAll(); + + // Prime the current configuration state + coordinator.metadata(); + + // Mark everyone as in sync with configState1 + List responseMembers = new ArrayList<>(); + responseMembers.add(new JoinGroupResponseData.JoinGroupResponseMember() + .setMemberId("leader") + .setMetadata(ConnectProtocol.serializeMetadata(new ConnectProtocol.WorkerState(LEADER_URL, 1L)).array()) + ); + responseMembers.add(new JoinGroupResponseData.JoinGroupResponseMember() + .setMemberId("member") + .setMetadata(ConnectProtocol.serializeMetadata(new ConnectProtocol.WorkerState(MEMBER_URL, 1L)).array()) + ); + + Map result = coordinator.performAssignment("leader", EAGER.protocol(), responseMembers); + + // Round robin assignment when there are the same number of connectors and tasks should result in each being + // evenly distributed across the workers, i.e. 
round robin assignment of connectors first, then followed by tasks + ConnectProtocol.Assignment leaderAssignment = ConnectProtocol.deserializeAssignment(result.get("leader")); + assertEquals(false, leaderAssignment.failed()); + assertEquals("leader", leaderAssignment.leader()); + assertEquals(1, leaderAssignment.offset()); + assertEquals(Arrays.asList(connectorId1, connectorId3), leaderAssignment.connectors()); + assertEquals(Arrays.asList(taskId2x0), leaderAssignment.tasks()); + + ConnectProtocol.Assignment memberAssignment = ConnectProtocol.deserializeAssignment(result.get("member")); + assertEquals(false, memberAssignment.failed()); + assertEquals("leader", memberAssignment.leader()); + assertEquals(1, memberAssignment.offset()); + assertEquals(Collections.singletonList(connectorId2), memberAssignment.connectors()); + assertEquals(Arrays.asList(taskId1x0, taskId3x0), memberAssignment.tasks()); + + PowerMock.verifyAll(); + } + + private JoinGroupResponse joinGroupLeaderResponse(int generationId, String memberId, + Map configOffsets, Errors error) { + List metadata = new ArrayList<>(); + for (Map.Entry configStateEntry : configOffsets.entrySet()) { + // We need a member URL, but it doesn't matter for the purposes of this test. Just set it to the member ID + String memberUrl = configStateEntry.getKey(); + long configOffset = configStateEntry.getValue(); + ByteBuffer buf = ConnectProtocol.serializeMetadata(new ConnectProtocol.WorkerState(memberUrl, configOffset)); + metadata.add(new JoinGroupResponseData.JoinGroupResponseMember() + .setMemberId(configStateEntry.getKey()) + .setMetadata(buf.array()) + ); + } + return new JoinGroupResponse( + new JoinGroupResponseData().setErrorCode(error.code()) + .setGenerationId(generationId) + .setProtocolName(EAGER.protocol()) + .setLeader(memberId) + .setMemberId(memberId) + .setMembers(metadata) + ); + } + + private JoinGroupResponse joinGroupFollowerResponse(int generationId, String memberId, String leaderId, Errors error) { + return new JoinGroupResponse( + new JoinGroupResponseData().setErrorCode(error.code()) + .setGenerationId(generationId) + .setProtocolName(EAGER.protocol()) + .setLeader(leaderId) + .setMemberId(memberId) + .setMembers(Collections.emptyList()) + ); + } + + private SyncGroupResponse syncGroupResponse(short assignmentError, String leader, long configOffset, List connectorIds, + List taskIds, Errors error) { + ConnectProtocol.Assignment assignment = new ConnectProtocol.Assignment(assignmentError, leader, LEADER_URL, configOffset, connectorIds, taskIds); + ByteBuffer buf = ConnectProtocol.serializeAssignment(assignment); + return new SyncGroupResponse( + new SyncGroupResponseData() + .setErrorCode(error.code()) + .setAssignment(Utils.toArray(buf)) + ); + } + + private static class MockRebalanceListener implements WorkerRebalanceListener { + public ExtendedAssignment assignment = null; + + public String revokedLeader; + public Collection revokedConnectors; + public Collection revokedTasks; + + public int revokedCount = 0; + public int assignedCount = 0; + + @Override + public void onAssigned(ExtendedAssignment assignment, int generation) { + this.assignment = assignment; + assignedCount++; + } + + @Override + public void onRevoked(String leader, Collection connectors, Collection tasks) { + if (connectors.isEmpty() && tasks.isEmpty()) { + return; + } + this.revokedLeader = leader; + this.revokedConnectors = connectors; + this.revokedTasks = tasks; + revokedCount++; + } + } +} diff --git 
a/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/distributed/WorkerGroupMemberTest.java b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/distributed/WorkerGroupMemberTest.java new file mode 100644 index 0000000..05cd017 --- /dev/null +++ b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/distributed/WorkerGroupMemberTest.java @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.runtime.distributed; + +import org.apache.kafka.clients.CommonClientConfigs; +import org.apache.kafka.common.MetricName; +import org.apache.kafka.common.metrics.MetricsReporter; +import org.apache.kafka.common.metrics.stats.Avg; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.connect.runtime.MockConnectMetrics; +import org.apache.kafka.connect.runtime.WorkerConfig; +import org.apache.kafka.connect.storage.ConfigBackingStore; +import org.apache.kafka.connect.storage.StatusBackingStore; +import org.apache.kafka.connect.util.ConnectUtils; +import org.easymock.EasyMock; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.powermock.api.easymock.PowerMock; +import org.powermock.api.easymock.annotation.Mock; +import org.powermock.core.classloader.annotations.PowerMockIgnore; +import org.powermock.core.classloader.annotations.PrepareForTest; +import org.powermock.modules.junit4.PowerMockRunner; + +import javax.management.MBeanServer; +import javax.management.ObjectName; +import java.lang.management.ManagementFactory; +import java.util.HashMap; +import java.util.Map; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; + +@RunWith(PowerMockRunner.class) +@PrepareForTest({ConnectUtils.class}) +@PowerMockIgnore({"javax.management.*", "javax.crypto.*"}) +public class WorkerGroupMemberTest { + @Mock + private ConfigBackingStore configBackingStore; + @Mock + private StatusBackingStore statusBackingStore; + + @Test + public void testMetrics() throws Exception { + WorkerGroupMember member; + Map workerProps = new HashMap<>(); + workerProps.put("key.converter", "org.apache.kafka.connect.json.JsonConverter"); + workerProps.put("value.converter", "org.apache.kafka.connect.json.JsonConverter"); + workerProps.put("offset.storage.file.filename", "/tmp/connect.offsets"); + workerProps.put("group.id", "group-1"); + workerProps.put("offset.storage.topic", "topic-1"); + workerProps.put("config.storage.topic", "topic-1"); + workerProps.put("status.storage.topic", "topic-1"); + workerProps.put(CommonClientConfigs.METRIC_REPORTER_CLASSES_CONFIG, MockConnectMetrics.MockMetricsReporter.class.getName()); + DistributedConfig config = 
new DistributedConfig(workerProps); + + + LogContext logContext = new LogContext("[Worker clientId=client-1 + groupId= group-1]"); + + expectClusterId(); + + member = new WorkerGroupMember(config, "", configBackingStore, + null, Time.SYSTEM, "client-1", logContext); + + boolean entered = false; + for (MetricsReporter reporter : member.metrics().reporters()) { + if (reporter instanceof MockConnectMetrics.MockMetricsReporter) { + entered = true; + MockConnectMetrics.MockMetricsReporter mockMetricsReporter = (MockConnectMetrics.MockMetricsReporter) reporter; + assertEquals("cluster-1", mockMetricsReporter.getMetricsContext().contextLabels().get(WorkerConfig.CONNECT_KAFKA_CLUSTER_ID)); + assertEquals("group-1", mockMetricsReporter.getMetricsContext().contextLabels().get(WorkerConfig.CONNECT_GROUP_ID)); + } + } + assertTrue("Failed to verify MetricsReporter", entered); + + MetricName name = member.metrics().metricName("test.avg", "grp1"); + member.metrics().addMetric(name, new Avg()); + MBeanServer server = ManagementFactory.getPlatformMBeanServer(); + //verify metric exists with correct prefix + assertNotNull(server.getObjectInstance(new ObjectName("kafka.connect:type=grp1,client-id=client-1"))); + } + private void expectClusterId() { + PowerMock.mockStaticPartial(ConnectUtils.class, "lookupKafkaClusterId"); + EasyMock.expect(ConnectUtils.lookupKafkaClusterId(EasyMock.anyObject())).andReturn("cluster-1").anyTimes(); + PowerMock.replay(ConnectUtils.class); + } + +} diff --git a/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/errors/ErrorReporterTest.java b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/errors/ErrorReporterTest.java new file mode 100644 index 0000000..11e72c2 --- /dev/null +++ b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/errors/ErrorReporterTest.java @@ -0,0 +1,367 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+package org.apache.kafka.connect.runtime.errors;
+
+import org.apache.kafka.clients.consumer.ConsumerRecord;
+import org.apache.kafka.clients.producer.KafkaProducer;
+import org.apache.kafka.clients.producer.ProducerRecord;
+import org.apache.kafka.clients.producer.RecordMetadata;
+import org.apache.kafka.common.header.Header;
+import org.apache.kafka.connect.errors.ConnectException;
+import org.apache.kafka.connect.json.JsonConverter;
+import org.apache.kafka.connect.runtime.ConnectMetrics;
+import org.apache.kafka.connect.runtime.ConnectorConfig;
+import org.apache.kafka.connect.runtime.MockConnectMetrics;
+import org.apache.kafka.connect.runtime.SinkConnectorConfig;
+import org.apache.kafka.connect.runtime.isolation.Plugins;
+import org.apache.kafka.connect.sink.SinkTask;
+import org.apache.kafka.connect.transforms.Transformation;
+import org.apache.kafka.connect.util.ConnectorTaskId;
+import org.easymock.EasyMock;
+import org.easymock.Mock;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.powermock.api.easymock.PowerMock;
+import org.powermock.core.classloader.annotations.PowerMockIgnore;
+import org.powermock.modules.junit4.PowerMockRunner;
+
+import java.util.HashMap;
+import java.util.Map;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.Future;
+
+import static java.util.Collections.emptyMap;
+import static java.util.Collections.singletonMap;
+import static org.apache.kafka.connect.runtime.errors.DeadLetterQueueReporter.ERROR_HEADER_CONNECTOR_NAME;
+import static org.apache.kafka.connect.runtime.errors.DeadLetterQueueReporter.ERROR_HEADER_EXCEPTION;
+import static org.apache.kafka.connect.runtime.errors.DeadLetterQueueReporter.ERROR_HEADER_EXCEPTION_MESSAGE;
+import static org.apache.kafka.connect.runtime.errors.DeadLetterQueueReporter.ERROR_HEADER_EXCEPTION_STACK_TRACE;
+import static org.apache.kafka.connect.runtime.errors.DeadLetterQueueReporter.ERROR_HEADER_EXECUTING_CLASS;
+import static org.apache.kafka.connect.runtime.errors.DeadLetterQueueReporter.ERROR_HEADER_ORIG_OFFSET;
+import static org.apache.kafka.connect.runtime.errors.DeadLetterQueueReporter.ERROR_HEADER_ORIG_PARTITION;
+import static org.apache.kafka.connect.runtime.errors.DeadLetterQueueReporter.ERROR_HEADER_ORIG_TOPIC;
+import static org.apache.kafka.connect.runtime.errors.DeadLetterQueueReporter.ERROR_HEADER_STAGE;
+import static org.apache.kafka.connect.runtime.errors.DeadLetterQueueReporter.ERROR_HEADER_TASK_ID;
+import static org.easymock.EasyMock.replay;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNull;
+import static org.junit.Assert.assertThrows;
+import static org.junit.Assert.assertTrue;
+
+@RunWith(PowerMockRunner.class)
+@PowerMockIgnore("javax.management.*")
+public class ErrorReporterTest {
+
+    private static final String TOPIC = "test-topic";
+    private static final String DLQ_TOPIC = "test-topic-errors";
+    private static final ConnectorTaskId TASK_ID = new ConnectorTaskId("job", 0);
+
+    @Mock
+    KafkaProducer<byte[], byte[]> producer;
+
+    @Mock
+    Future<RecordMetadata> metadata;
+
+    @Mock
+    Plugins plugins;
+
+    private ErrorHandlingMetrics errorHandlingMetrics;
+    private MockConnectMetrics metrics;
+
+    @Before
+    public void setup() {
+        metrics = new MockConnectMetrics();
+        errorHandlingMetrics = new ErrorHandlingMetrics(new ConnectorTaskId("connector-", 1), metrics);
+    }
+
+    @After
+    public void tearDown() {
+        if (metrics != null) {
+            metrics.stop();
+        }
+    }
+
+    @Test
+    public void
initializeDLQWithNullMetrics() { + assertThrows(NullPointerException.class, () -> new DeadLetterQueueReporter(producer, config(emptyMap()), TASK_ID, null)); + } + + @Test + public void testDLQConfigWithEmptyTopicName() { + DeadLetterQueueReporter deadLetterQueueReporter = new DeadLetterQueueReporter( + producer, config(emptyMap()), TASK_ID, errorHandlingMetrics); + + ProcessingContext context = processingContext(); + + EasyMock.expect(producer.send(EasyMock.anyObject(), EasyMock.anyObject())).andThrow(new RuntimeException()); + replay(producer); + + // since topic name is empty, this method should be a NOOP. + // if it attempts to log to the DLQ via the producer, the send mock will throw a RuntimeException. + deadLetterQueueReporter.report(context); + } + + @Test + public void testDLQConfigWithValidTopicName() { + DeadLetterQueueReporter deadLetterQueueReporter = new DeadLetterQueueReporter( + producer, config(singletonMap(SinkConnectorConfig.DLQ_TOPIC_NAME_CONFIG, DLQ_TOPIC)), TASK_ID, errorHandlingMetrics); + + ProcessingContext context = processingContext(); + + EasyMock.expect(producer.send(EasyMock.anyObject(), EasyMock.anyObject())).andReturn(metadata); + replay(producer); + + deadLetterQueueReporter.report(context); + + PowerMock.verifyAll(); + } + + @Test + public void testReportDLQTwice() { + DeadLetterQueueReporter deadLetterQueueReporter = new DeadLetterQueueReporter( + producer, config(singletonMap(SinkConnectorConfig.DLQ_TOPIC_NAME_CONFIG, DLQ_TOPIC)), TASK_ID, errorHandlingMetrics); + + ProcessingContext context = processingContext(); + + EasyMock.expect(producer.send(EasyMock.anyObject(), EasyMock.anyObject())).andReturn(metadata).times(2); + replay(producer); + + deadLetterQueueReporter.report(context); + deadLetterQueueReporter.report(context); + + PowerMock.verifyAll(); + } + + @Test + public void testDLQReportAndReturnFuture() { + DeadLetterQueueReporter deadLetterQueueReporter = new DeadLetterQueueReporter( + producer, config(singletonMap(SinkConnectorConfig.DLQ_TOPIC_NAME_CONFIG, DLQ_TOPIC)), TASK_ID, errorHandlingMetrics); + + ProcessingContext context = processingContext(); + + EasyMock.expect(producer.send(EasyMock.anyObject(), EasyMock.anyObject())).andReturn(metadata); + replay(producer); + + deadLetterQueueReporter.report(context); + + PowerMock.verifyAll(); + } + + @Test + public void testCloseDLQ() { + DeadLetterQueueReporter deadLetterQueueReporter = new DeadLetterQueueReporter( + producer, config(singletonMap(SinkConnectorConfig.DLQ_TOPIC_NAME_CONFIG, DLQ_TOPIC)), TASK_ID, errorHandlingMetrics); + + producer.close(); + EasyMock.expectLastCall(); + replay(producer); + + deadLetterQueueReporter.close(); + + PowerMock.verifyAll(); + } + + @Test + public void testLogOnDisabledLogReporter() { + LogReporter logReporter = new LogReporter(TASK_ID, config(emptyMap()), errorHandlingMetrics); + + ProcessingContext context = processingContext(); + context.error(new RuntimeException()); + + // reporting a context without an error should not cause any errors. + logReporter.report(context); + assertErrorHandlingMetricValue("total-errors-logged", 0.0); + } + + @Test + public void testLogOnEnabledLogReporter() { + LogReporter logReporter = new LogReporter(TASK_ID, config(singletonMap(ConnectorConfig.ERRORS_LOG_ENABLE_CONFIG, "true")), errorHandlingMetrics); + + ProcessingContext context = processingContext(); + context.error(new RuntimeException()); + + // reporting a context without an error should not cause any errors. 
+ logReporter.report(context); + assertErrorHandlingMetricValue("total-errors-logged", 1.0); + } + + @Test + public void testLogMessageWithNoRecords() { + LogReporter logReporter = new LogReporter(TASK_ID, config(singletonMap(ConnectorConfig.ERRORS_LOG_ENABLE_CONFIG, "true")), errorHandlingMetrics); + + ProcessingContext context = processingContext(); + + String msg = logReporter.message(context); + assertEquals("Error encountered in task job-0. Executing stage 'KEY_CONVERTER' with class " + + "'org.apache.kafka.connect.json.JsonConverter'.", msg); + } + + @Test + public void testLogMessageWithSinkRecords() { + Map props = new HashMap<>(); + props.put(ConnectorConfig.ERRORS_LOG_ENABLE_CONFIG, "true"); + props.put(ConnectorConfig.ERRORS_LOG_INCLUDE_MESSAGES_CONFIG, "true"); + + LogReporter logReporter = new LogReporter(TASK_ID, config(props), errorHandlingMetrics); + + ProcessingContext context = processingContext(); + + String msg = logReporter.message(context); + assertEquals("Error encountered in task job-0. Executing stage 'KEY_CONVERTER' with class " + + "'org.apache.kafka.connect.json.JsonConverter', where consumed record is {topic='test-topic', " + + "partition=5, offset=100}.", msg); + } + + @Test + public void testLogReportAndReturnFuture() { + Map props = new HashMap<>(); + props.put(ConnectorConfig.ERRORS_LOG_ENABLE_CONFIG, "true"); + props.put(ConnectorConfig.ERRORS_LOG_INCLUDE_MESSAGES_CONFIG, "true"); + + LogReporter logReporter = new LogReporter(TASK_ID, config(props), errorHandlingMetrics); + + ProcessingContext context = processingContext(); + + String msg = logReporter.message(context); + assertEquals("Error encountered in task job-0. Executing stage 'KEY_CONVERTER' with class " + + "'org.apache.kafka.connect.json.JsonConverter', where consumed record is {topic='test-topic', " + + "partition=5, offset=100}.", msg); + + Future future = logReporter.report(context); + assertTrue(future instanceof CompletableFuture); + } + + @Test + public void testSetDLQConfigs() { + SinkConnectorConfig configuration = config(singletonMap(SinkConnectorConfig.DLQ_TOPIC_NAME_CONFIG, DLQ_TOPIC)); + assertEquals(configuration.dlqTopicName(), DLQ_TOPIC); + + configuration = config(singletonMap(SinkConnectorConfig.DLQ_TOPIC_REPLICATION_FACTOR_CONFIG, "7")); + assertEquals(configuration.dlqTopicReplicationFactor(), 7); + } + + @Test + public void testDlqHeaderConsumerRecord() { + Map props = new HashMap<>(); + props.put(SinkConnectorConfig.DLQ_TOPIC_NAME_CONFIG, DLQ_TOPIC); + props.put(SinkConnectorConfig.DLQ_CONTEXT_HEADERS_ENABLE_CONFIG, "true"); + DeadLetterQueueReporter deadLetterQueueReporter = new DeadLetterQueueReporter(producer, config(props), TASK_ID, errorHandlingMetrics); + + ProcessingContext context = new ProcessingContext(); + context.consumerRecord(new ConsumerRecord<>("source-topic", 7, 10, "source-key".getBytes(), "source-value".getBytes())); + context.currentContext(Stage.TRANSFORMATION, Transformation.class); + context.error(new ConnectException("Test Exception")); + + ProducerRecord producerRecord = new ProducerRecord<>(DLQ_TOPIC, "source-key".getBytes(), "source-value".getBytes()); + + deadLetterQueueReporter.populateContextHeaders(producerRecord, context); + assertEquals("source-topic", headerValue(producerRecord, ERROR_HEADER_ORIG_TOPIC)); + assertEquals("7", headerValue(producerRecord, ERROR_HEADER_ORIG_PARTITION)); + assertEquals("10", headerValue(producerRecord, ERROR_HEADER_ORIG_OFFSET)); + assertEquals(TASK_ID.connector(), headerValue(producerRecord, 
ERROR_HEADER_CONNECTOR_NAME)); + assertEquals(String.valueOf(TASK_ID.task()), headerValue(producerRecord, ERROR_HEADER_TASK_ID)); + assertEquals(Stage.TRANSFORMATION.name(), headerValue(producerRecord, ERROR_HEADER_STAGE)); + assertEquals(Transformation.class.getName(), headerValue(producerRecord, ERROR_HEADER_EXECUTING_CLASS)); + assertEquals(ConnectException.class.getName(), headerValue(producerRecord, ERROR_HEADER_EXCEPTION)); + assertEquals("Test Exception", headerValue(producerRecord, ERROR_HEADER_EXCEPTION_MESSAGE)); + assertTrue(headerValue(producerRecord, ERROR_HEADER_EXCEPTION_STACK_TRACE).length() > 0); + assertTrue(headerValue(producerRecord, ERROR_HEADER_EXCEPTION_STACK_TRACE).startsWith("org.apache.kafka.connect.errors.ConnectException: Test Exception")); + } + + @Test + public void testDlqHeaderOnNullExceptionMessage() { + Map props = new HashMap<>(); + props.put(SinkConnectorConfig.DLQ_TOPIC_NAME_CONFIG, DLQ_TOPIC); + props.put(SinkConnectorConfig.DLQ_CONTEXT_HEADERS_ENABLE_CONFIG, "true"); + DeadLetterQueueReporter deadLetterQueueReporter = new DeadLetterQueueReporter(producer, config(props), TASK_ID, errorHandlingMetrics); + + ProcessingContext context = new ProcessingContext(); + context.consumerRecord(new ConsumerRecord<>("source-topic", 7, 10, "source-key".getBytes(), "source-value".getBytes())); + context.currentContext(Stage.TRANSFORMATION, Transformation.class); + context.error(new NullPointerException()); + + ProducerRecord producerRecord = new ProducerRecord<>(DLQ_TOPIC, "source-key".getBytes(), "source-value".getBytes()); + + deadLetterQueueReporter.populateContextHeaders(producerRecord, context); + assertEquals("source-topic", headerValue(producerRecord, ERROR_HEADER_ORIG_TOPIC)); + assertEquals("7", headerValue(producerRecord, ERROR_HEADER_ORIG_PARTITION)); + assertEquals("10", headerValue(producerRecord, ERROR_HEADER_ORIG_OFFSET)); + assertEquals(TASK_ID.connector(), headerValue(producerRecord, ERROR_HEADER_CONNECTOR_NAME)); + assertEquals(String.valueOf(TASK_ID.task()), headerValue(producerRecord, ERROR_HEADER_TASK_ID)); + assertEquals(Stage.TRANSFORMATION.name(), headerValue(producerRecord, ERROR_HEADER_STAGE)); + assertEquals(Transformation.class.getName(), headerValue(producerRecord, ERROR_HEADER_EXECUTING_CLASS)); + assertEquals(NullPointerException.class.getName(), headerValue(producerRecord, ERROR_HEADER_EXCEPTION)); + assertNull(producerRecord.headers().lastHeader(ERROR_HEADER_EXCEPTION_MESSAGE).value()); + assertTrue(headerValue(producerRecord, ERROR_HEADER_EXCEPTION_STACK_TRACE).length() > 0); + assertTrue(headerValue(producerRecord, ERROR_HEADER_EXCEPTION_STACK_TRACE).startsWith("java.lang.NullPointerException")); + } + + @Test + public void testDlqHeaderIsAppended() { + Map props = new HashMap<>(); + props.put(SinkConnectorConfig.DLQ_TOPIC_NAME_CONFIG, DLQ_TOPIC); + props.put(SinkConnectorConfig.DLQ_CONTEXT_HEADERS_ENABLE_CONFIG, "true"); + DeadLetterQueueReporter deadLetterQueueReporter = new DeadLetterQueueReporter(producer, config(props), TASK_ID, errorHandlingMetrics); + + ProcessingContext context = new ProcessingContext(); + context.consumerRecord(new ConsumerRecord<>("source-topic", 7, 10, "source-key".getBytes(), "source-value".getBytes())); + context.currentContext(Stage.TRANSFORMATION, Transformation.class); + context.error(new ConnectException("Test Exception")); + + ProducerRecord producerRecord = new ProducerRecord<>(DLQ_TOPIC, "source-key".getBytes(), "source-value".getBytes()); + producerRecord.headers().add(ERROR_HEADER_ORIG_TOPIC, 
"dummy".getBytes()); + + deadLetterQueueReporter.populateContextHeaders(producerRecord, context); + int appearances = 0; + for (Header header: producerRecord.headers()) { + if (ERROR_HEADER_ORIG_TOPIC.equalsIgnoreCase(header.key())) { + appearances++; + } + } + + assertEquals("source-topic", headerValue(producerRecord, ERROR_HEADER_ORIG_TOPIC)); + assertEquals(2, appearances); + } + + private String headerValue(ProducerRecord producerRecord, String headerSuffix) { + return new String(producerRecord.headers().lastHeader(headerSuffix).value()); + } + + private ProcessingContext processingContext() { + ProcessingContext context = new ProcessingContext(); + context.consumerRecord(new ConsumerRecord<>(TOPIC, 5, 100, new byte[]{'a', 'b'}, new byte[]{'x'})); + context.currentContext(Stage.KEY_CONVERTER, JsonConverter.class); + return context; + } + + private SinkConnectorConfig config(Map configProps) { + Map props = new HashMap<>(); + props.put(ConnectorConfig.NAME_CONFIG, "test"); + props.put(ConnectorConfig.CONNECTOR_CLASS_CONFIG, SinkTask.class.getName()); + props.putAll(configProps); + return new SinkConnectorConfig(plugins, props); + } + + private void assertErrorHandlingMetricValue(String name, double expected) { + ConnectMetrics.MetricGroup sinkTaskGroup = errorHandlingMetrics.metricGroup(); + double measured = metrics.currentMetricValueAsDouble(sinkTaskGroup, name); + assertEquals(expected, measured, 0.001d); + } + +} diff --git a/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/errors/ProcessingContextTest.java b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/errors/ProcessingContextTest.java new file mode 100644 index 0000000..89f1013 --- /dev/null +++ b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/errors/ProcessingContextTest.java @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+package org.apache.kafka.connect.runtime.errors;
+
+import org.apache.kafka.clients.producer.RecordMetadata;
+import org.apache.kafka.common.TopicPartition;
+import org.junit.Test;
+
+import java.util.List;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.Future;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
+
+public class ProcessingContextTest {
+
+    @Test
+    public void testReportWithSingleReporter() {
+        testReport(1);
+    }
+
+    @Test
+    public void testReportWithMultipleReporters() {
+        testReport(2);
+    }
+
+    private void testReport(int numberOfReports) {
+        ProcessingContext context = new ProcessingContext();
+        List<CompletableFuture<RecordMetadata>> fs = IntStream.range(0, numberOfReports).mapToObj(i -> new CompletableFuture<RecordMetadata>()).collect(Collectors.toList());
+        context.reporters(IntStream.range(0, numberOfReports).mapToObj(i -> (ErrorReporter) c -> fs.get(i)).collect(Collectors.toList()));
+        Future<Void> result = context.report();
+        fs.forEach(f -> {
+            assertFalse(result.isDone());
+            f.complete(new RecordMetadata(new TopicPartition("t", 0), 0, 0, 0, 0, 0));
+        });
+        assertTrue(result.isDone());
+    }
+}
diff --git a/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/errors/RetryWithToleranceOperatorTest.java b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/errors/RetryWithToleranceOperatorTest.java
new file mode 100644
index 0000000..68f8afc
--- /dev/null
+++ b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/errors/RetryWithToleranceOperatorTest.java
@@ -0,0 +1,457 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +package org.apache.kafka.connect.runtime.errors; + +import org.apache.kafka.clients.CommonClientConfigs; +import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.common.metrics.Sensor; +import org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.common.utils.SystemTime; +import org.apache.kafka.connect.errors.ConnectException; +import org.apache.kafka.connect.errors.RetriableException; +import org.apache.kafka.connect.runtime.ConnectMetrics; +import org.apache.kafka.connect.runtime.ConnectorConfig; +import org.apache.kafka.connect.runtime.WorkerConfig; +import org.apache.kafka.connect.runtime.isolation.Plugins; +import org.apache.kafka.connect.runtime.isolation.PluginsTest.TestConverter; +import org.apache.kafka.connect.runtime.isolation.PluginsTest.TestableWorkerConfig; +import org.apache.kafka.connect.sink.SinkTask; +import org.apache.kafka.connect.util.ConnectorTaskId; +import org.easymock.EasyMock; +import org.easymock.Mock; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.powermock.api.easymock.PowerMock; +import org.powermock.core.classloader.annotations.PowerMockIgnore; +import org.powermock.core.classloader.annotations.PrepareForTest; +import org.powermock.modules.junit4.PowerMockRunner; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +import static java.util.Collections.emptyMap; +import static java.util.Collections.singletonMap; +import static org.apache.kafka.common.utils.Time.SYSTEM; +import static org.apache.kafka.connect.runtime.ConnectorConfig.ERRORS_RETRY_MAX_DELAY_CONFIG; +import static org.apache.kafka.connect.runtime.ConnectorConfig.ERRORS_RETRY_MAX_DELAY_DEFAULT; +import static org.apache.kafka.connect.runtime.ConnectorConfig.ERRORS_RETRY_TIMEOUT_CONFIG; +import static org.apache.kafka.connect.runtime.ConnectorConfig.ERRORS_RETRY_TIMEOUT_DEFAULT; +import static org.apache.kafka.connect.runtime.ConnectorConfig.ERRORS_TOLERANCE_CONFIG; +import static org.apache.kafka.connect.runtime.ConnectorConfig.ERRORS_TOLERANCE_DEFAULT; +import static org.apache.kafka.connect.runtime.errors.ToleranceType.ALL; +import static org.apache.kafka.connect.runtime.errors.ToleranceType.NONE; +import static org.easymock.EasyMock.replay; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; + +@RunWith(PowerMockRunner.class) +@PrepareForTest({ProcessingContext.class}) +@PowerMockIgnore("javax.management.*") +public class RetryWithToleranceOperatorTest { + + public static final RetryWithToleranceOperator NOOP_OPERATOR = new RetryWithToleranceOperator( + ERRORS_RETRY_TIMEOUT_DEFAULT, ERRORS_RETRY_MAX_DELAY_DEFAULT, NONE, SYSTEM); + static { + Map properties = new HashMap<>(); + properties.put(CommonClientConfigs.METRICS_NUM_SAMPLES_CONFIG, Objects.toString(2)); + properties.put(CommonClientConfigs.METRICS_SAMPLE_WINDOW_MS_CONFIG, Objects.toString(3000)); + properties.put(CommonClientConfigs.METRICS_RECORDING_LEVEL_CONFIG, Sensor.RecordingLevel.INFO.toString()); + + // define 
required properties + properties.put(WorkerConfig.KEY_CONVERTER_CLASS_CONFIG, TestConverter.class.getName()); + properties.put(WorkerConfig.VALUE_CONVERTER_CLASS_CONFIG, TestConverter.class.getName()); + + NOOP_OPERATOR.metrics(new ErrorHandlingMetrics( + new ConnectorTaskId("noop-connector", -1), + new ConnectMetrics("noop-worker", new TestableWorkerConfig(properties), new SystemTime(), "test-cluster")) + ); + } + + @SuppressWarnings("unused") + @Mock + private Operation mockOperation; + + @Mock + private ConsumerRecord consumerRecord; + + @Mock + ErrorHandlingMetrics errorHandlingMetrics; + + @Mock + Plugins plugins; + + @Test + public void testExecuteFailed() { + RetryWithToleranceOperator retryWithToleranceOperator = new RetryWithToleranceOperator(0, + ERRORS_RETRY_MAX_DELAY_DEFAULT, ALL, SYSTEM); + retryWithToleranceOperator.metrics(errorHandlingMetrics); + + retryWithToleranceOperator.executeFailed(Stage.TASK_PUT, + SinkTask.class, consumerRecord, new Throwable()); + } + + @Test + public void testExecuteFailedNoTolerance() { + RetryWithToleranceOperator retryWithToleranceOperator = new RetryWithToleranceOperator(0, + ERRORS_RETRY_MAX_DELAY_DEFAULT, NONE, SYSTEM); + retryWithToleranceOperator.metrics(errorHandlingMetrics); + + assertThrows(ConnectException.class, () -> retryWithToleranceOperator.executeFailed(Stage.TASK_PUT, + SinkTask.class, consumerRecord, new Throwable())); + } + + @Test + public void testHandleExceptionInTransformations() { + testHandleExceptionInStage(Stage.TRANSFORMATION, new Exception()); + } + + @Test + public void testHandleExceptionInHeaderConverter() { + testHandleExceptionInStage(Stage.HEADER_CONVERTER, new Exception()); + } + + @Test + public void testHandleExceptionInValueConverter() { + testHandleExceptionInStage(Stage.VALUE_CONVERTER, new Exception()); + } + + @Test + public void testHandleExceptionInKeyConverter() { + testHandleExceptionInStage(Stage.KEY_CONVERTER, new Exception()); + } + + @Test + public void testHandleExceptionInTaskPut() { + testHandleExceptionInStage(Stage.TASK_PUT, new org.apache.kafka.connect.errors.RetriableException("Test")); + } + + @Test + public void testHandleExceptionInTaskPoll() { + testHandleExceptionInStage(Stage.TASK_POLL, new org.apache.kafka.connect.errors.RetriableException("Test")); + } + + @Test + public void testThrowExceptionInTaskPut() { + assertThrows(ConnectException.class, () -> testHandleExceptionInStage(Stage.TASK_PUT, new Exception())); + } + + @Test + public void testThrowExceptionInTaskPoll() { + assertThrows(ConnectException.class, () -> testHandleExceptionInStage(Stage.TASK_POLL, new Exception())); + } + + @Test + public void testThrowExceptionInKafkaConsume() { + assertThrows(ConnectException.class, () -> testHandleExceptionInStage(Stage.KAFKA_CONSUME, new Exception())); + } + + @Test + public void testThrowExceptionInKafkaProduce() { + assertThrows(ConnectException.class, () -> testHandleExceptionInStage(Stage.KAFKA_PRODUCE, new Exception())); + } + + private void testHandleExceptionInStage(Stage type, Exception ex) { + RetryWithToleranceOperator retryWithToleranceOperator = setupExecutor(); + retryWithToleranceOperator.execute(new ExceptionThrower(ex), type, ExceptionThrower.class); + assertTrue(retryWithToleranceOperator.failed()); + PowerMock.verifyAll(); + } + + private RetryWithToleranceOperator setupExecutor() { + RetryWithToleranceOperator retryWithToleranceOperator = new RetryWithToleranceOperator(0, ERRORS_RETRY_MAX_DELAY_DEFAULT, ALL, SYSTEM); + 
retryWithToleranceOperator.metrics(errorHandlingMetrics); + return retryWithToleranceOperator; + } + + @Test + public void testExecAndHandleRetriableErrorOnce() throws Exception { + execAndHandleRetriableError(1, 300, new RetriableException("Test")); + } + + @Test + public void testExecAndHandleRetriableErrorThrice() throws Exception { + execAndHandleRetriableError(3, 2100, new RetriableException("Test")); + } + + @Test + public void testExecAndHandleNonRetriableErrorOnce() throws Exception { + execAndHandleNonRetriableError(1, 0, new Exception("Non Retriable Test")); + } + + @Test + public void testExecAndHandleNonRetriableErrorThrice() throws Exception { + execAndHandleNonRetriableError(3, 0, new Exception("Non Retriable Test")); + } + + public void execAndHandleRetriableError(int numRetriableExceptionsThrown, long expectedWait, Exception e) throws Exception { + MockTime time = new MockTime(0, 0, 0); + RetryWithToleranceOperator retryWithToleranceOperator = new RetryWithToleranceOperator(6000, ERRORS_RETRY_MAX_DELAY_DEFAULT, ALL, time); + retryWithToleranceOperator.metrics(errorHandlingMetrics); + + EasyMock.expect(mockOperation.call()).andThrow(e).times(numRetriableExceptionsThrown); + EasyMock.expect(mockOperation.call()).andReturn("Success"); + + replay(mockOperation); + + String result = retryWithToleranceOperator.execAndHandleError(mockOperation, Exception.class); + assertFalse(retryWithToleranceOperator.failed()); + assertEquals("Success", result); + assertEquals(expectedWait, time.hiResClockMs()); + + PowerMock.verifyAll(); + } + + public void execAndHandleNonRetriableError(int numRetriableExceptionsThrown, long expectedWait, Exception e) throws Exception { + MockTime time = new MockTime(0, 0, 0); + RetryWithToleranceOperator retryWithToleranceOperator = new RetryWithToleranceOperator(6000, ERRORS_RETRY_MAX_DELAY_DEFAULT, ALL, time); + retryWithToleranceOperator.metrics(errorHandlingMetrics); + + EasyMock.expect(mockOperation.call()).andThrow(e).times(numRetriableExceptionsThrown); + EasyMock.expect(mockOperation.call()).andReturn("Success"); + + replay(mockOperation); + + String result = retryWithToleranceOperator.execAndHandleError(mockOperation, Exception.class); + assertTrue(retryWithToleranceOperator.failed()); + assertNull(result); + assertEquals(expectedWait, time.hiResClockMs()); + + PowerMock.verifyAll(); + } + + @Test + public void testCheckRetryLimit() { + MockTime time = new MockTime(0, 0, 0); + RetryWithToleranceOperator retryWithToleranceOperator = new RetryWithToleranceOperator(500, 100, NONE, time); + + time.setCurrentTimeMs(100); + assertTrue(retryWithToleranceOperator.checkRetry(0)); + + time.setCurrentTimeMs(200); + assertTrue(retryWithToleranceOperator.checkRetry(0)); + + time.setCurrentTimeMs(400); + assertTrue(retryWithToleranceOperator.checkRetry(0)); + + time.setCurrentTimeMs(499); + assertTrue(retryWithToleranceOperator.checkRetry(0)); + + time.setCurrentTimeMs(501); + assertFalse(retryWithToleranceOperator.checkRetry(0)); + + time.setCurrentTimeMs(600); + assertFalse(retryWithToleranceOperator.checkRetry(0)); + } + + @Test + public void testBackoffLimit() { + MockTime time = new MockTime(0, 0, 0); + RetryWithToleranceOperator retryWithToleranceOperator = new RetryWithToleranceOperator(5, 5000, NONE, time); + + long prevTs = time.hiResClockMs(); + retryWithToleranceOperator.backoff(1, 5000); + assertEquals(300, time.hiResClockMs() - prevTs); + + prevTs = time.hiResClockMs(); + retryWithToleranceOperator.backoff(2, 5000); + assertEquals(600, 
time.hiResClockMs() - prevTs); + + prevTs = time.hiResClockMs(); + retryWithToleranceOperator.backoff(3, 5000); + assertEquals(1200, time.hiResClockMs() - prevTs); + + prevTs = time.hiResClockMs(); + retryWithToleranceOperator.backoff(4, 5000); + assertEquals(2400, time.hiResClockMs() - prevTs); + + prevTs = time.hiResClockMs(); + retryWithToleranceOperator.backoff(5, 5000); + assertEquals(500, time.hiResClockMs() - prevTs); + + prevTs = time.hiResClockMs(); + retryWithToleranceOperator.backoff(6, 5000); + assertEquals(0, time.hiResClockMs() - prevTs); + + PowerMock.verifyAll(); + } + + @Test + public void testToleranceLimit() { + RetryWithToleranceOperator retryWithToleranceOperator = new RetryWithToleranceOperator(ERRORS_RETRY_TIMEOUT_DEFAULT, ERRORS_RETRY_MAX_DELAY_DEFAULT, NONE, SYSTEM); + retryWithToleranceOperator.metrics(errorHandlingMetrics); + retryWithToleranceOperator.markAsFailed(); + assertFalse("should not tolerate any errors", retryWithToleranceOperator.withinToleranceLimits()); + + retryWithToleranceOperator = new RetryWithToleranceOperator(ERRORS_RETRY_TIMEOUT_DEFAULT, ERRORS_RETRY_MAX_DELAY_DEFAULT, ALL, SYSTEM); + retryWithToleranceOperator.metrics(errorHandlingMetrics); + retryWithToleranceOperator.markAsFailed(); + retryWithToleranceOperator.markAsFailed(); + assertTrue("should tolerate all errors", retryWithToleranceOperator.withinToleranceLimits()); + + retryWithToleranceOperator = new RetryWithToleranceOperator(ERRORS_RETRY_TIMEOUT_DEFAULT, ERRORS_RETRY_MAX_DELAY_DEFAULT, NONE, SYSTEM); + assertTrue("no tolerance is within limits if no failures", retryWithToleranceOperator.withinToleranceLimits()); + } + + @Test + public void testDefaultConfigs() { + ConnectorConfig configuration = config(emptyMap()); + assertEquals(configuration.errorRetryTimeout(), ERRORS_RETRY_TIMEOUT_DEFAULT); + assertEquals(configuration.errorMaxDelayInMillis(), ERRORS_RETRY_MAX_DELAY_DEFAULT); + assertEquals(configuration.errorToleranceType(), ERRORS_TOLERANCE_DEFAULT); + + PowerMock.verifyAll(); + } + + ConnectorConfig config(Map connProps) { + Map props = new HashMap<>(); + props.put(ConnectorConfig.NAME_CONFIG, "test"); + props.put(ConnectorConfig.CONNECTOR_CLASS_CONFIG, SinkTask.class.getName()); + props.putAll(connProps); + return new ConnectorConfig(plugins, props); + } + + @Test + public void testSetConfigs() { + ConnectorConfig configuration; + configuration = config(singletonMap(ERRORS_RETRY_TIMEOUT_CONFIG, "100")); + assertEquals(configuration.errorRetryTimeout(), 100); + + configuration = config(singletonMap(ERRORS_RETRY_MAX_DELAY_CONFIG, "100")); + assertEquals(configuration.errorMaxDelayInMillis(), 100); + + configuration = config(singletonMap(ERRORS_TOLERANCE_CONFIG, "none")); + assertEquals(configuration.errorToleranceType(), ToleranceType.NONE); + + PowerMock.verifyAll(); + } + + @Test + public void testThreadSafety() throws Throwable { + long runtimeMs = 5_000; + int numThreads = 10; + // Check that multiple threads using RetryWithToleranceOperator concurrently + // can't corrupt the state of the ProcessingContext + AtomicReference failed = new AtomicReference<>(null); + RetryWithToleranceOperator retryWithToleranceOperator = new RetryWithToleranceOperator(0, + ERRORS_RETRY_MAX_DELAY_DEFAULT, ALL, SYSTEM, new ProcessingContext() { + private AtomicInteger count = new AtomicInteger(); + private AtomicInteger attempt = new AtomicInteger(); + + @Override + public void error(Throwable error) { + if (count.getAndIncrement() > 0) { + failed.compareAndSet(null, new 
AssertionError("Concurrent call to error()")); + } + super.error(error); + } + + @Override + public Future report() { + if (count.getAndSet(0) > 1) { + failed.compareAndSet(null, new AssertionError("Concurrent call to error() in report()")); + } + + return super.report(); + } + + @Override + public void currentContext(Stage stage, Class klass) { + this.attempt.set(0); + super.currentContext(stage, klass); + } + + @Override + public void attempt(int attempt) { + if (!this.attempt.compareAndSet(attempt - 1, attempt)) { + failed.compareAndSet(null, new AssertionError( + "Concurrent call to attempt(): Attempts should increase monotonically " + + "within the scope of a given currentContext()")); + } + super.attempt(attempt); + } + }); + retryWithToleranceOperator.metrics(errorHandlingMetrics); + + ExecutorService pool = Executors.newFixedThreadPool(numThreads); + List> futures = IntStream.range(0, numThreads).boxed() + .map(id -> + pool.submit(() -> { + long t0 = System.currentTimeMillis(); + long i = 0; + while (true) { + if (++i % 10000 == 0 && System.currentTimeMillis() > t0 + runtimeMs) { + break; + } + if (failed.get() != null) { + break; + } + try { + if (id < numThreads / 2) { + retryWithToleranceOperator.executeFailed(Stage.TASK_PUT, + SinkTask.class, consumerRecord, new Throwable()).get(); + } else { + retryWithToleranceOperator.execute(() -> null, Stage.TRANSFORMATION, + SinkTask.class); + } + } catch (Exception e) { + failed.compareAndSet(null, e); + } + } + })) + .collect(Collectors.toList()); + pool.shutdown(); + pool.awaitTermination((long) (1.5 * runtimeMs), TimeUnit.MILLISECONDS); + futures.forEach(future -> { + try { + future.get(); + } catch (Exception e) { + failed.compareAndSet(null, e); + } + }); + Throwable exception = failed.get(); + if (exception != null) { + throw exception; + } + } + + + private static class ExceptionThrower implements Operation { + private Exception e; + + public ExceptionThrower(Exception e) { + this.e = e; + } + + @Override + public Object call() throws Exception { + throw e; + } + } +} diff --git a/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/errors/WorkerErrantRecordReporterTest.java b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/errors/WorkerErrantRecordReporterTest.java new file mode 100644 index 0000000..2d78297 --- /dev/null +++ b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/errors/WorkerErrantRecordReporterTest.java @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.kafka.connect.runtime.errors;
+
+import org.apache.kafka.common.TopicPartition;
+import org.apache.kafka.connect.sink.SinkRecord;
+import org.apache.kafka.connect.storage.Converter;
+import org.apache.kafka.connect.storage.HeaderConverter;
+import org.easymock.Mock;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.powermock.core.classloader.annotations.PowerMockIgnore;
+import org.powermock.modules.junit4.PowerMockRunner;
+
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.concurrent.CompletableFuture;
+
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
+
+
+@RunWith(PowerMockRunner.class)
+@PowerMockIgnore("javax.management.*")
+public class WorkerErrantRecordReporterTest {
+
+    private WorkerErrantRecordReporter reporter;
+
+    @Mock
+    private RetryWithToleranceOperator retryWithToleranceOperator;
+
+    @Mock
+    private Converter converter;
+
+    @Mock
+    private HeaderConverter headerConverter;
+
+    @Mock
+    private SinkRecord record;
+
+    @Before
+    public void setup() {
+        reporter = new WorkerErrantRecordReporter(
+            retryWithToleranceOperator,
+            converter,
+            converter,
+            headerConverter
+        );
+    }
+
+    @Test
+    public void testGetFutures() {
+        Collection<TopicPartition> topicPartitions = new ArrayList<>();
+        assertTrue(reporter.futures.isEmpty());
+        for (int i = 0; i < 4; i++) {
+            TopicPartition topicPartition = new TopicPartition("topic", i);
+            topicPartitions.add(topicPartition);
+            reporter.futures.put(topicPartition, Collections.singletonList(CompletableFuture.completedFuture(null)));
+        }
+        assertFalse(reporter.futures.isEmpty());
+        reporter.awaitFutures(topicPartitions);
+        assertTrue(reporter.futures.isEmpty());
+    }
+}
diff --git a/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/health/ConnectClusterStateImplTest.java b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/health/ConnectClusterStateImplTest.java
new file mode 100644
index 0000000..58eb5a9
--- /dev/null
+++ b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/health/ConnectClusterStateImplTest.java
@@ -0,0 +1,110 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +package org.apache.kafka.connect.runtime.health; + +import org.apache.kafka.connect.errors.ConnectException; +import org.apache.kafka.connect.runtime.Herder; +import org.apache.kafka.connect.util.Callback; +import org.easymock.Capture; +import org.easymock.EasyMock; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.powermock.api.easymock.annotation.Mock; +import org.powermock.modules.junit4.PowerMockRunner; + +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.Map; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotSame; +import static org.junit.Assert.assertThrows; + +@RunWith(PowerMockRunner.class) +public class ConnectClusterStateImplTest { + protected static final String KAFKA_CLUSTER_ID = "franzwashere"; + + @Mock + protected Herder herder; + protected ConnectClusterStateImpl connectClusterState; + protected long herderRequestTimeoutMs = TimeUnit.SECONDS.toMillis(10); + protected Collection expectedConnectors; + + @Before + public void setUp() { + expectedConnectors = Arrays.asList("sink1", "source1", "source2"); + connectClusterState = new ConnectClusterStateImpl( + herderRequestTimeoutMs, + new ConnectClusterDetailsImpl(KAFKA_CLUSTER_ID), + herder + ); + } + + @Test + public void connectors() { + Capture>> callback = EasyMock.newCapture(); + herder.connectors(EasyMock.capture(callback)); + EasyMock.expectLastCall().andAnswer(() -> { + callback.getValue().onCompletion(null, expectedConnectors); + return null; + }); + EasyMock.replay(herder); + assertEquals(expectedConnectors, connectClusterState.connectors()); + } + + @Test + public void connectorConfig() { + final String connName = "sink6"; + final Map expectedConfig = Collections.singletonMap("key", "value"); + Capture>> callback = EasyMock.newCapture(); + herder.connectorConfig(EasyMock.eq(connName), EasyMock.capture(callback)); + EasyMock.expectLastCall().andAnswer(() -> { + callback.getValue().onCompletion(null, expectedConfig); + return null; + }); + EasyMock.replay(herder); + Map actualConfig = connectClusterState.connectorConfig(connName); + assertEquals(expectedConfig, actualConfig); + assertNotSame( + "Config should be copied in order to avoid mutation by REST extensions", + expectedConfig, + actualConfig + ); + } + + @Test + public void kafkaClusterId() { + assertEquals(KAFKA_CLUSTER_ID, connectClusterState.clusterDetails().kafkaClusterId()); + } + + @Test + public void connectorsFailure() { + Capture>> callback = EasyMock.newCapture(); + herder.connectors(EasyMock.capture(callback)); + EasyMock.expectLastCall().andAnswer(() -> { + Throwable timeout = new TimeoutException(); + callback.getValue().onCompletion(timeout, null); + return null; + }); + EasyMock.replay(herder); + assertThrows(ConnectException.class, connectClusterState::connectors); + } +} diff --git a/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/isolation/DelegatingClassLoaderTest.java b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/isolation/DelegatingClassLoaderTest.java new file mode 100644 index 0000000..447eab6 --- /dev/null +++ b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/isolation/DelegatingClassLoaderTest.java @@ -0,0 +1,132 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.connect.runtime.isolation; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +import java.io.File; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.Collections; + +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; + +public class DelegatingClassLoaderTest { + + @Rule + public TemporaryFolder pluginDir = new TemporaryFolder(); + + @Test + public void testPermittedManifestResources() { + assertTrue( + DelegatingClassLoader.serviceLoaderManifestForPlugin("META-INF/services/org.apache.kafka.connect.rest.ConnectRestExtension")); + assertTrue( + DelegatingClassLoader.serviceLoaderManifestForPlugin("META-INF/services/org.apache.kafka.common.config.provider.ConfigProvider")); + } + + @Test + public void testOtherResources() { + assertFalse( + DelegatingClassLoader.serviceLoaderManifestForPlugin("META-INF/services/org.apache.kafka.connect.transforms.Transformation")); + assertFalse(DelegatingClassLoader.serviceLoaderManifestForPlugin("resource/version.properties")); + } + + @Test + public void testLoadingUnloadedPluginClass() throws ClassNotFoundException { + TestPlugins.assertAvailable(); + DelegatingClassLoader classLoader = new DelegatingClassLoader(Collections.emptyList()); + classLoader.initLoaders(); + for (String pluginClassName : TestPlugins.pluginClasses()) { + assertThrows(ClassNotFoundException.class, () -> classLoader.loadClass(pluginClassName)); + } + } + + @Test + public void testLoadingPluginClass() throws ClassNotFoundException { + TestPlugins.assertAvailable(); + DelegatingClassLoader classLoader = new DelegatingClassLoader(TestPlugins.pluginPath()); + classLoader.initLoaders(); + for (String pluginClassName : TestPlugins.pluginClasses()) { + assertNotNull(classLoader.loadClass(pluginClassName)); + assertNotNull(classLoader.pluginClassLoader(pluginClassName)); + } + } + + @Test + public void testLoadingInvalidUberJar() throws Exception { + pluginDir.newFile("invalid.jar"); + + DelegatingClassLoader classLoader = new DelegatingClassLoader( + Collections.singletonList(pluginDir.getRoot().getAbsolutePath())); + classLoader.initLoaders(); + } + + @Test + public void testLoadingPluginDirContainsInvalidJarsOnly() throws Exception { + pluginDir.newFolder("my-plugin"); + pluginDir.newFile("my-plugin/invalid.jar"); + + DelegatingClassLoader classLoader = new DelegatingClassLoader( + Collections.singletonList(pluginDir.getRoot().getAbsolutePath())); + classLoader.initLoaders(); + } + + @Test + public void testLoadingNoPlugins() throws Exception { + DelegatingClassLoader classLoader = new DelegatingClassLoader( + Collections.singletonList(pluginDir.getRoot().getAbsolutePath())); + 
classLoader.initLoaders(); + } + + @Test + public void testLoadingPluginDirEmpty() throws Exception { + pluginDir.newFolder("my-plugin"); + + DelegatingClassLoader classLoader = new DelegatingClassLoader( + Collections.singletonList(pluginDir.getRoot().getAbsolutePath())); + classLoader.initLoaders(); + } + + @Test + public void testLoadingMixOfValidAndInvalidPlugins() throws Exception { + TestPlugins.assertAvailable(); + + pluginDir.newFile("invalid.jar"); + pluginDir.newFolder("my-plugin"); + pluginDir.newFile("my-plugin/invalid.jar"); + Path pluginPath = this.pluginDir.getRoot().toPath(); + + for (String sourceJar : TestPlugins.pluginPath()) { + Path source = new File(sourceJar).toPath(); + Files.copy(source, pluginPath.resolve(source.getFileName())); + } + + DelegatingClassLoader classLoader = new DelegatingClassLoader( + Collections.singletonList(pluginDir.getRoot().getAbsolutePath())); + classLoader.initLoaders(); + for (String pluginClassName : TestPlugins.pluginClasses()) { + assertNotNull(classLoader.loadClass(pluginClassName)); + assertNotNull(classLoader.pluginClassLoader(pluginClassName)); + } + } +} diff --git a/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/isolation/PluginDescTest.java b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/isolation/PluginDescTest.java new file mode 100644 index 0000000..72a2493 --- /dev/null +++ b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/isolation/PluginDescTest.java @@ -0,0 +1,245 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.connect.runtime.isolation; + +import org.apache.kafka.connect.connector.Connector; +import org.apache.kafka.connect.sink.SinkConnector; +import org.apache.kafka.connect.source.SourceConnector; +import org.apache.kafka.connect.storage.Converter; +import org.apache.kafka.connect.transforms.Transformation; +import org.junit.Before; +import org.junit.Test; + +import java.net.URL; +import java.nio.file.Paths; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotEquals; +import static org.junit.Assert.assertTrue; + +public class PluginDescTest { + private final ClassLoader systemLoader = ClassLoader.getSystemClassLoader(); + private final String regularVersion = "1.0.0"; + private final String newerVersion = "1.0.1"; + private final String snaphotVersion = "1.0.0-SNAPSHOT"; + private final String noVersion = "undefined"; + private PluginClassLoader pluginLoader; + + @Before + public void setUp() throws Exception { + // Fairly simple use case, thus no need to create a random directory here yet. + URL location = Paths.get("/tmp").toUri().toURL(); + // Normally parent will be a DelegatingClassLoader. 
+ pluginLoader = new PluginClassLoader(location, new URL[0], systemLoader); + } + + @SuppressWarnings("rawtypes") + @Test + public void testRegularPluginDesc() { + PluginDesc connectorDesc = new PluginDesc<>( + Connector.class, + regularVersion, + pluginLoader + ); + + assertPluginDesc(connectorDesc, Connector.class, regularVersion, pluginLoader.location()); + + PluginDesc converterDesc = new PluginDesc<>( + Converter.class, + snaphotVersion, + pluginLoader + ); + + assertPluginDesc(converterDesc, Converter.class, snaphotVersion, pluginLoader.location()); + + PluginDesc transformDesc = new PluginDesc<>( + Transformation.class, + noVersion, + pluginLoader + ); + + assertPluginDesc(transformDesc, Transformation.class, noVersion, pluginLoader.location()); + } + + @SuppressWarnings("rawtypes") + @Test + public void testPluginDescWithSystemClassLoader() { + String location = "classpath"; + PluginDesc connectorDesc = new PluginDesc<>( + SinkConnector.class, + regularVersion, + systemLoader + ); + + assertPluginDesc(connectorDesc, SinkConnector.class, regularVersion, location); + + PluginDesc converterDesc = new PluginDesc<>( + Converter.class, + snaphotVersion, + systemLoader + ); + + assertPluginDesc(converterDesc, Converter.class, snaphotVersion, location); + + PluginDesc transformDesc = new PluginDesc<>( + Transformation.class, + noVersion, + systemLoader + ); + + assertPluginDesc(transformDesc, Transformation.class, noVersion, location); + } + + @Test + public void testPluginDescWithNullVersion() { + String nullVersion = "null"; + PluginDesc connectorDesc = new PluginDesc<>( + SourceConnector.class, + null, + pluginLoader + ); + + assertPluginDesc( + connectorDesc, + SourceConnector.class, + nullVersion, + pluginLoader.location() + ); + + String location = "classpath"; + PluginDesc converterDesc = new PluginDesc<>( + Converter.class, + null, + systemLoader + ); + + assertPluginDesc(converterDesc, Converter.class, nullVersion, location); + } + + @SuppressWarnings("rawtypes") + @Test + public void testPluginDescEquality() { + PluginDesc connectorDescPluginPath = new PluginDesc<>( + Connector.class, + snaphotVersion, + pluginLoader + ); + + PluginDesc connectorDescClasspath = new PluginDesc<>( + Connector.class, + snaphotVersion, + systemLoader + ); + + assertEquals(connectorDescPluginPath, connectorDescClasspath); + assertEquals(connectorDescPluginPath.hashCode(), connectorDescClasspath.hashCode()); + + PluginDesc converterDescPluginPath = new PluginDesc<>( + Converter.class, + noVersion, + pluginLoader + ); + + PluginDesc converterDescClasspath = new PluginDesc<>( + Converter.class, + noVersion, + systemLoader + ); + + assertEquals(converterDescPluginPath, converterDescClasspath); + assertEquals(converterDescPluginPath.hashCode(), converterDescClasspath.hashCode()); + + PluginDesc transformDescPluginPath = new PluginDesc<>( + Transformation.class, + null, + pluginLoader + ); + + PluginDesc transformDescClasspath = new PluginDesc<>( + Transformation.class, + noVersion, + pluginLoader + ); + + assertNotEquals(transformDescPluginPath, transformDescClasspath); + } + + @SuppressWarnings("rawtypes") + @Test + public void testPluginDescComparison() { + PluginDesc connectorDescPluginPath = new PluginDesc<>( + Connector.class, + regularVersion, + pluginLoader + ); + + PluginDesc connectorDescClasspath = new PluginDesc<>( + Connector.class, + newerVersion, + systemLoader + ); + + assertNewer(connectorDescPluginPath, connectorDescClasspath); + + PluginDesc converterDescPluginPath = new PluginDesc<>( + 
Converter.class, + noVersion, + pluginLoader + ); + + PluginDesc converterDescClasspath = new PluginDesc<>( + Converter.class, + snaphotVersion, + systemLoader + ); + + assertNewer(converterDescPluginPath, converterDescClasspath); + + PluginDesc transformDescPluginPath = new PluginDesc<>( + Transformation.class, + null, + pluginLoader + ); + + PluginDesc transformDescClasspath = new PluginDesc<>( + Transformation.class, + regularVersion, + systemLoader + ); + + assertNewer(transformDescPluginPath, transformDescClasspath); + } + + private static void assertPluginDesc( + PluginDesc desc, + Class klass, + String version, + String location + ) { + assertEquals(desc.pluginClass(), klass); + assertEquals(desc.className(), klass.getName()); + assertEquals(desc.version(), version); + assertEquals(desc.type(), PluginType.from(klass)); + assertEquals(desc.typeName(), PluginType.from(klass).toString()); + assertEquals(desc.location(), location); + } + + private static void assertNewer(PluginDesc older, PluginDesc newer) { + assertTrue(newer + " should be newer than " + older, older.compareTo(newer) < 0); + } +} diff --git a/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/isolation/PluginUtilsTest.java b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/isolation/PluginUtilsTest.java new file mode 100644 index 0000000..1976698 --- /dev/null +++ b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/isolation/PluginUtilsTest.java @@ -0,0 +1,517 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.connect.runtime.isolation; + +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +public class PluginUtilsTest { + @Rule + public TemporaryFolder rootDir = new TemporaryFolder(); + private Path pluginPath; + + @Before + public void setUp() throws Exception { + pluginPath = rootDir.newFolder("plugins").toPath().toRealPath(); + } + + @Test + public void testJavaLibraryClasses() { + assertFalse(PluginUtils.shouldLoadInIsolation("java.")); + assertFalse(PluginUtils.shouldLoadInIsolation("java.lang.Object")); + assertFalse(PluginUtils.shouldLoadInIsolation("java.lang.String")); + assertFalse(PluginUtils.shouldLoadInIsolation("java.util.HashMap$Entry")); + assertFalse(PluginUtils.shouldLoadInIsolation("java.io.Serializable")); + assertFalse(PluginUtils.shouldLoadInIsolation("javax.rmi.")); + assertFalse(PluginUtils.shouldLoadInIsolation( + "javax.management.loading.ClassLoaderRepository") + ); + assertFalse(PluginUtils.shouldLoadInIsolation("org.omg.CORBA.")); + assertFalse(PluginUtils.shouldLoadInIsolation("org.omg.CORBA.Object")); + assertFalse(PluginUtils.shouldLoadInIsolation("org.w3c.dom.")); + assertFalse(PluginUtils.shouldLoadInIsolation("org.w3c.dom.traversal.TreeWalker")); + assertFalse(PluginUtils.shouldLoadInIsolation("org.xml.sax.")); + assertFalse(PluginUtils.shouldLoadInIsolation("org.xml.sax.EntityResolver")); + } + + @Test + public void testThirdPartyClasses() { + assertFalse(PluginUtils.shouldLoadInIsolation("org.slf4j.")); + assertFalse(PluginUtils.shouldLoadInIsolation("org.slf4j.LoggerFactory")); + } + + @Test + public void testKafkaDependencyClasses() { + assertFalse(PluginUtils.shouldLoadInIsolation("org.apache.kafka.common.")); + assertFalse(PluginUtils.shouldLoadInIsolation( + "org.apache.kafka.common.config.AbstractConfig") + ); + assertFalse(PluginUtils.shouldLoadInIsolation( + "org.apache.kafka.common.config.ConfigDef$Type") + ); + assertFalse(PluginUtils.shouldLoadInIsolation( + "org.apache.kafka.common.serialization.Deserializer") + ); + assertFalse(PluginUtils.shouldLoadInIsolation( + "org.apache.kafka.clients.producer.ProducerConfig") + ); + assertFalse(PluginUtils.shouldLoadInIsolation( + "org.apache.kafka.clients.consumer.ConsumerConfig") + ); + assertFalse(PluginUtils.shouldLoadInIsolation( + "org.apache.kafka.clients.admin.KafkaAdminClient") + ); + } + + @Test + public void testConnectApiClasses() { + List apiClasses = Arrays.asList( + // Enumerate all packages and classes + "org.apache.kafka.connect.", + "org.apache.kafka.connect.components.", + "org.apache.kafka.connect.components.Versioned", + //"org.apache.kafka.connect.connector.policy.", isolated by default + "org.apache.kafka.connect.connector.policy.ConnectorClientConfigOverridePolicy", + "org.apache.kafka.connect.connector.policy.ConnectorClientConfigRequest", + "org.apache.kafka.connect.connector.policy.ConnectorClientConfigRequest$ClientType", + "org.apache.kafka.connect.connector.", + "org.apache.kafka.connect.connector.Connector", + "org.apache.kafka.connect.connector.ConnectorContext", + "org.apache.kafka.connect.connector.ConnectRecord", + 
"org.apache.kafka.connect.connector.Task", + "org.apache.kafka.connect.data.", + "org.apache.kafka.connect.data.ConnectSchema", + "org.apache.kafka.connect.data.Date", + "org.apache.kafka.connect.data.Decimal", + "org.apache.kafka.connect.data.Field", + "org.apache.kafka.connect.data.Schema", + "org.apache.kafka.connect.data.SchemaAndValue", + "org.apache.kafka.connect.data.SchemaBuilder", + "org.apache.kafka.connect.data.SchemaProjector", + "org.apache.kafka.connect.data.Struct", + "org.apache.kafka.connect.data.Time", + "org.apache.kafka.connect.data.Timestamp", + "org.apache.kafka.connect.data.Values", + "org.apache.kafka.connect.errors.", + "org.apache.kafka.connect.errors.AlreadyExistsException", + "org.apache.kafka.connect.errors.ConnectException", + "org.apache.kafka.connect.errors.DataException", + "org.apache.kafka.connect.errors.IllegalWorkerStateException", + "org.apache.kafka.connect.errors.NotFoundException", + "org.apache.kafka.connect.errors.RetriableException", + "org.apache.kafka.connect.errors.SchemaBuilderException", + "org.apache.kafka.connect.errors.SchemaProjectorException", + "org.apache.kafka.connect.header.", + "org.apache.kafka.connect.header.ConnectHeader", + "org.apache.kafka.connect.header.ConnectHeaders", + "org.apache.kafka.connect.header.Header", + "org.apache.kafka.connect.header.Headers", + "org.apache.kafka.connect.health.", + "org.apache.kafka.connect.health.AbstractState", + "org.apache.kafka.connect.health.ConnectClusterDetails", + "org.apache.kafka.connect.health.ConnectClusterState", + "org.apache.kafka.connect.health.ConnectorHealth", + "org.apache.kafka.connect.health.ConnectorState", + "org.apache.kafka.connect.health.ConnectorType", + "org.apache.kafka.connect.health.TaskState", + "org.apache.kafka.connect.rest.", + "org.apache.kafka.connect.rest.ConnectRestExtension", + "org.apache.kafka.connect.rest.ConnectRestExtensionContext", + "org.apache.kafka.connect.sink.", + "org.apache.kafka.connect.sink.SinkConnector", + "org.apache.kafka.connect.sink.SinkRecord", + "org.apache.kafka.connect.sink.SinkTask", + "org.apache.kafka.connect.sink.SinkTaskContext", + "org.apache.kafka.connect.sink.ErrantRecordReporter", + "org.apache.kafka.connect.source.", + "org.apache.kafka.connect.source.SourceConnector", + "org.apache.kafka.connect.source.SourceRecord", + "org.apache.kafka.connect.source.SourceTask", + "org.apache.kafka.connect.source.SourceTaskContext", + "org.apache.kafka.connect.storage.", + "org.apache.kafka.connect.storage.Converter", + "org.apache.kafka.connect.storage.ConverterConfig", + "org.apache.kafka.connect.storage.ConverterType", + "org.apache.kafka.connect.storage.HeaderConverter", + "org.apache.kafka.connect.storage.OffsetStorageReader", + //"org.apache.kafka.connect.storage.SimpleHeaderConverter", explicitly isolated + //"org.apache.kafka.connect.storage.StringConverter", explicitly isolated + "org.apache.kafka.connect.storage.StringConverterConfig", + //"org.apache.kafka.connect.transforms.", isolated by default + "org.apache.kafka.connect.transforms.Transformation", + "org.apache.kafka.connect.transforms.predicates.Predicate", + "org.apache.kafka.connect.util.", + "org.apache.kafka.connect.util.ConnectorUtils" + ); + // Classes in the API should never be loaded in isolation. 
+ for (String clazz : apiClasses) { + assertFalse( + clazz + " from 'api' is loaded in isolation but should not be", + PluginUtils.shouldLoadInIsolation(clazz) + ); + } + } + + @Test + public void testConnectRuntimeClasses() { + // Only list packages, because there are too many classes. + List runtimeClasses = Arrays.asList( + "org.apache.kafka.connect.cli.", + //"org.apache.kafka.connect.connector.policy.", isolated by default + //"org.apache.kafka.connect.converters.", isolated by default + "org.apache.kafka.connect.runtime.", + "org.apache.kafka.connect.runtime.distributed", + "org.apache.kafka.connect.runtime.errors", + "org.apache.kafka.connect.runtime.health", + "org.apache.kafka.connect.runtime.isolation", + "org.apache.kafka.connect.runtime.rest.", + "org.apache.kafka.connect.runtime.rest.entities.", + "org.apache.kafka.connect.runtime.rest.errors.", + "org.apache.kafka.connect.runtime.rest.resources.", + "org.apache.kafka.connect.runtime.rest.util.", + "org.apache.kafka.connect.runtime.standalone.", + "org.apache.kafka.connect.runtime.rest.", + "org.apache.kafka.connect.storage.", + "org.apache.kafka.connect.tools.", + "org.apache.kafka.connect.util." + ); + for (String clazz : runtimeClasses) { + assertFalse( + clazz + " from 'runtime' is loaded in isolation but should not be", + PluginUtils.shouldLoadInIsolation(clazz) + ); + } + } + + @Test + public void testAllowedRuntimeClasses() { + List jsonConverterClasses = Arrays.asList( + "org.apache.kafka.connect.connector.policy.", + "org.apache.kafka.connect.connector.policy.AbstractConnectorClientConfigOverridePolicy", + "org.apache.kafka.connect.connector.policy.AllConnectorClientConfigOverridePolicy", + "org.apache.kafka.connect.connector.policy.NoneConnectorClientConfigOverridePolicy", + "org.apache.kafka.connect.connector.policy.PrincipalConnectorClientConfigOverridePolicy", + "org.apache.kafka.connect.converters.", + "org.apache.kafka.connect.converters.ByteArrayConverter", + "org.apache.kafka.connect.converters.DoubleConverter", + "org.apache.kafka.connect.converters.FloatConverter", + "org.apache.kafka.connect.converters.IntegerConverter", + "org.apache.kafka.connect.converters.LongConverter", + "org.apache.kafka.connect.converters.NumberConverter", + "org.apache.kafka.connect.converters.NumberConverterConfig", + "org.apache.kafka.connect.converters.ShortConverter", + //"org.apache.kafka.connect.storage.", not isolated by default + "org.apache.kafka.connect.storage.StringConverter", + "org.apache.kafka.connect.storage.SimpleHeaderConverter" + ); + for (String clazz : jsonConverterClasses) { + assertTrue( + clazz + " from 'runtime' is not loaded in isolation but should be", + PluginUtils.shouldLoadInIsolation(clazz) + ); + } + } + + @Test + public void testTransformsClasses() { + List transformsClasses = Arrays.asList( + "org.apache.kafka.connect.transforms.", + "org.apache.kafka.connect.transforms.util.", + "org.apache.kafka.connect.transforms.util.NonEmptyListValidator", + "org.apache.kafka.connect.transforms.util.RegexValidator", + "org.apache.kafka.connect.transforms.util.Requirements", + "org.apache.kafka.connect.transforms.util.SchemaUtil", + "org.apache.kafka.connect.transforms.util.SimpleConfig", + "org.apache.kafka.connect.transforms.Cast", + "org.apache.kafka.connect.transforms.Cast$Key", + "org.apache.kafka.connect.transforms.Cast$Value", + "org.apache.kafka.connect.transforms.ExtractField", + "org.apache.kafka.connect.transforms.ExtractField$Key", + "org.apache.kafka.connect.transforms.ExtractField$Value", + 
"org.apache.kafka.connect.transforms.Flatten", + "org.apache.kafka.connect.transforms.Flatten$Key", + "org.apache.kafka.connect.transforms.Flatten$Value", + "org.apache.kafka.connect.transforms.HoistField", + "org.apache.kafka.connect.transforms.HoistField$Key", + "org.apache.kafka.connect.transforms.HoistField$Key", + "org.apache.kafka.connect.transforms.InsertField", + "org.apache.kafka.connect.transforms.InsertField$Key", + "org.apache.kafka.connect.transforms.InsertField$Value", + "org.apache.kafka.connect.transforms.MaskField", + "org.apache.kafka.connect.transforms.MaskField$Key", + "org.apache.kafka.connect.transforms.MaskField$Value", + "org.apache.kafka.connect.transforms.RegexRouter", + "org.apache.kafka.connect.transforms.ReplaceField", + "org.apache.kafka.connect.transforms.ReplaceField$Key", + "org.apache.kafka.connect.transforms.ReplaceField$Value", + "org.apache.kafka.connect.transforms.SetSchemaMetadata", + "org.apache.kafka.connect.transforms.SetSchemaMetadata$Key", + "org.apache.kafka.connect.transforms.SetSchemaMetadata$Value", + "org.apache.kafka.connect.transforms.TimestampConverter", + "org.apache.kafka.connect.transforms.TimestampConverter$Key", + "org.apache.kafka.connect.transforms.TimestampConverter$Value", + "org.apache.kafka.connect.transforms.TimestampRouter", + "org.apache.kafka.connect.transforms.TimestampRouter$Key", + "org.apache.kafka.connect.transforms.TimestampRouter$Value", + "org.apache.kafka.connect.transforms.ValueToKey", + "org.apache.kafka.connect.transforms.predicates.", + "org.apache.kafka.connect.transforms.predicates.HasHeaderKey", + "org.apache.kafka.connect.transforms.predicates.RecordIsTombstone", + "org.apache.kafka.connect.transforms.predicates.TopicNameMatches" + ); + for (String clazz : transformsClasses) { + assertTrue( + clazz + " from 'transforms' is not loaded in isolation but should be", + PluginUtils.shouldLoadInIsolation(clazz) + ); + } + } + + @Test + public void testAllowedJsonConverterClasses() { + List jsonConverterClasses = Arrays.asList( + "org.apache.kafka.connect.json.", + "org.apache.kafka.connect.json.DecimalFormat", + "org.apache.kafka.connect.json.JsonConverter", + "org.apache.kafka.connect.json.JsonConverterConfig", + "org.apache.kafka.connect.json.JsonDeserializer", + "org.apache.kafka.connect.json.JsonSchema", + "org.apache.kafka.connect.json.JsonSerializer" + ); + for (String clazz : jsonConverterClasses) { + assertTrue( + clazz + " from 'json' is not loaded in isolation but should be", + PluginUtils.shouldLoadInIsolation(clazz) + ); + } + } + + @Test + public void testAllowedFileConnectors() { + List jsonConverterClasses = Arrays.asList( + "org.apache.kafka.connect.file.", + "org.apache.kafka.connect.file.FileStreamSinkConnector", + "org.apache.kafka.connect.file.FileStreamSinkTask", + "org.apache.kafka.connect.file.FileStreamSourceConnector", + "org.apache.kafka.connect.file.FileStreamSourceTask" + ); + for (String clazz : jsonConverterClasses) { + assertTrue( + clazz + " from 'file' is not loaded in isolation but should be", + PluginUtils.shouldLoadInIsolation(clazz) + ); + } + } + + @Test + public void testAllowedBasicAuthExtensionClasses() { + List basicAuthExtensionClasses = Arrays.asList( + "org.apache.kafka.connect.rest.basic.auth.extension.BasicAuthSecurityRestExtension" + //"org.apache.kafka.connect.rest.basic.auth.extension.JaasBasicAuthFilter", TODO fix? + //"org.apache.kafka.connect.rest.basic.auth.extension.PropertyFileLoginModule" TODO fix? 
+ ); + for (String clazz : basicAuthExtensionClasses) { + assertTrue( + clazz + " from 'basic-auth-extension' is not loaded in isolation but should be", + PluginUtils.shouldLoadInIsolation(clazz) + ); + } + } + + @Test + public void testMirrorClasses() { + assertTrue(PluginUtils.shouldLoadInIsolation( + "org.apache.kafka.connect.mirror.MirrorSourceTask") + ); + assertTrue(PluginUtils.shouldLoadInIsolation( + "org.apache.kafka.connect.mirror.MirrorSourceConnector") + ); + } + + @Test + public void testClientConfigProvider() { + assertFalse(PluginUtils.shouldLoadInIsolation( + "org.apache.kafka.common.config.provider.ConfigProvider") + ); + assertTrue(PluginUtils.shouldLoadInIsolation( + "org.apache.kafka.common.config.provider.FileConfigProvider") + ); + assertTrue(PluginUtils.shouldLoadInIsolation( + "org.apache.kafka.common.config.provider.FutureConfigProvider") + ); + } + + @Test + public void testEmptyPluginUrls() throws Exception { + assertEquals(Collections.emptyList(), PluginUtils.pluginUrls(pluginPath)); + } + + @Test + public void testEmptyStructurePluginUrls() throws Exception { + createBasicDirectoryLayout(); + assertEquals(Collections.emptyList(), PluginUtils.pluginUrls(pluginPath)); + } + + @Test + public void testPluginUrlsWithJars() throws Exception { + createBasicDirectoryLayout(); + + List expectedUrls = createBasicExpectedUrls(); + + assertUrls(expectedUrls, PluginUtils.pluginUrls(pluginPath)); + } + + @Test + public void testOrderOfPluginUrlsWithJars() throws Exception { + createBasicDirectoryLayout(); + // Here this method is just used to create the files. The result is not used. + createBasicExpectedUrls(); + + List actual = PluginUtils.pluginUrls(pluginPath); + // 'simple-transform.jar' is created first. In many cases, without sorting within the + // PluginUtils, this jar will be placed before 'another-transform.jar'. However this is + // not guaranteed because a DirectoryStream does not maintain a certain order in its + // results. Besides this test case, sorted order in every call to assertUrls below. 
+ int i = Arrays.toString(actual.toArray()).indexOf("another-transform.jar"); + int j = Arrays.toString(actual.toArray()).indexOf("simple-transform.jar"); + assertTrue(i < j); + } + + @Test + public void testPluginUrlsWithZips() throws Exception { + createBasicDirectoryLayout(); + + List expectedUrls = new ArrayList<>(); + expectedUrls.add(Files.createFile(pluginPath.resolve("connectorA/my-sink.zip"))); + expectedUrls.add(Files.createFile(pluginPath.resolve("connectorB/a-source.zip"))); + expectedUrls.add(Files.createFile(pluginPath.resolve("transformC/simple-transform.zip"))); + expectedUrls.add(Files.createFile( + pluginPath.resolve("transformC/deps/another-transform.zip")) + ); + + assertUrls(expectedUrls, PluginUtils.pluginUrls(pluginPath)); + } + + @Test + public void testPluginUrlsWithClasses() throws Exception { + Files.createDirectories(pluginPath.resolve("org/apache/kafka/converters")); + Files.createDirectories(pluginPath.resolve("com/mycompany/transforms")); + Files.createDirectories(pluginPath.resolve("edu/research/connectors")); + Files.createFile(pluginPath.resolve("org/apache/kafka/converters/README.txt")); + Files.createFile(pluginPath.resolve("org/apache/kafka/converters/AlienFormat.class")); + Files.createDirectories(pluginPath.resolve("com/mycompany/transforms/Blackhole.class")); + Files.createDirectories(pluginPath.resolve("edu/research/connectors/HalSink.class")); + + List expectedUrls = new ArrayList<>(); + expectedUrls.add(pluginPath); + + assertUrls(expectedUrls, PluginUtils.pluginUrls(pluginPath)); + } + + @Test + public void testPluginUrlsWithAbsoluteSymlink() throws Exception { + createBasicDirectoryLayout(); + + Path anotherPath = rootDir.newFolder("moreplugins").toPath().toRealPath(); + Files.createDirectories(anotherPath.resolve("connectorB-deps")); + Files.createSymbolicLink( + pluginPath.resolve("connectorB/deps/symlink"), + anotherPath.resolve("connectorB-deps") + ); + + List expectedUrls = createBasicExpectedUrls(); + expectedUrls.add(Files.createFile(anotherPath.resolve("connectorB-deps/converter.jar"))); + + assertUrls(expectedUrls, PluginUtils.pluginUrls(pluginPath)); + } + + @Test + public void testPluginUrlsWithRelativeSymlinkBackwards() throws Exception { + createBasicDirectoryLayout(); + + Path anotherPath = rootDir.newFolder("moreplugins").toPath().toRealPath(); + Files.createDirectories(anotherPath.resolve("connectorB-deps")); + Files.createSymbolicLink( + pluginPath.resolve("connectorB/deps/symlink"), + Paths.get("../../../moreplugins/connectorB-deps") + ); + + List expectedUrls = createBasicExpectedUrls(); + expectedUrls.add(Files.createFile(anotherPath.resolve("connectorB-deps/converter.jar"))); + + assertUrls(expectedUrls, PluginUtils.pluginUrls(pluginPath)); + } + + @Test + public void testPluginUrlsWithRelativeSymlinkForwards() throws Exception { + // Since this test case defines a relative symlink within an already included path, the main + // assertion of this test is absence of exceptions and correct resolution of paths. 
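The symlink test cases in this file lean on java.nio's symbolic-link support. A small self-contained sketch of the same mechanics follows, assuming the platform permits creating symlinks (on Windows this may require elevated privileges); the class and directory names are illustrative only:

```java
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;

public class SymlinkLayoutSketch {
    public static void main(String[] args) throws IOException {
        Path pluginPath = Files.createTempDirectory("plugins"); // stand-in for the test's plugin path
        Path deps = Files.createDirectories(pluginPath.resolve("connectorB/deps"));
        Files.createDirectories(deps.resolve("more"));

        // Relative link that points "forwards" into a directory already under the plugin path.
        Path link = Files.createSymbolicLink(deps.resolve("symlink"), Paths.get("more"));

        System.out.println(Files.isSymbolicLink(link)); // true
        System.out.println(link.toRealPath());          // resolves to .../connectorB/deps/more
    }
}
```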
+ createBasicDirectoryLayout(); + Files.createDirectories(pluginPath.resolve("connectorB/deps/more")); + Files.createSymbolicLink( + pluginPath.resolve("connectorB/deps/symlink"), + Paths.get("more") + ); + + List expectedUrls = createBasicExpectedUrls(); + expectedUrls.add( + Files.createFile(pluginPath.resolve("connectorB/deps/more/converter.jar")) + ); + + assertUrls(expectedUrls, PluginUtils.pluginUrls(pluginPath)); + } + + private void createBasicDirectoryLayout() throws IOException { + Files.createDirectories(pluginPath.resolve("connectorA")); + Files.createDirectories(pluginPath.resolve("connectorB/deps")); + Files.createDirectories(pluginPath.resolve("transformC/deps")); + Files.createDirectories(pluginPath.resolve("transformC/more-deps")); + Files.createFile(pluginPath.resolve("transformC/more-deps/README.txt")); + } + + private List createBasicExpectedUrls() throws IOException { + List expectedUrls = new ArrayList<>(); + expectedUrls.add(Files.createFile(pluginPath.resolve("connectorA/my-sink.jar"))); + expectedUrls.add(Files.createFile(pluginPath.resolve("connectorB/a-source.jar"))); + expectedUrls.add(Files.createFile(pluginPath.resolve("transformC/simple-transform.jar"))); + expectedUrls.add(Files.createFile( + pluginPath.resolve("transformC/deps/another-transform.jar")) + ); + return expectedUrls; + } + + private void assertUrls(List expected, List actual) { + Collections.sort(expected); + // not sorting 'actual' because it should be returned sorted from withing the PluginUtils. + assertEquals(expected, actual); + } +} diff --git a/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/isolation/PluginsTest.java b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/isolation/PluginsTest.java new file mode 100644 index 0000000..5083a2d --- /dev/null +++ b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/isolation/PluginsTest.java @@ -0,0 +1,472 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.connect.runtime.isolation; + +import java.util.Collections; +import java.util.Map.Entry; +import org.apache.kafka.common.Configurable; +import org.apache.kafka.common.config.AbstractConfig; +import org.apache.kafka.common.config.ConfigDef; +import org.apache.kafka.common.config.ConfigException; +import org.apache.kafka.common.config.provider.ConfigProvider; +import org.apache.kafka.connect.data.Schema; +import org.apache.kafka.connect.data.SchemaAndValue; +import org.apache.kafka.connect.errors.ConnectException; +import org.apache.kafka.connect.json.JsonConverter; +import org.apache.kafka.connect.json.JsonConverterConfig; +import org.apache.kafka.connect.rest.ConnectRestExtension; +import org.apache.kafka.connect.rest.ConnectRestExtensionContext; +import org.apache.kafka.connect.runtime.WorkerConfig; +import org.apache.kafka.connect.runtime.isolation.Plugins.ClassLoaderUsage; +import org.apache.kafka.connect.storage.Converter; +import org.apache.kafka.connect.storage.ConverterConfig; +import org.apache.kafka.connect.storage.ConverterType; +import org.apache.kafka.connect.storage.HeaderConverter; +import org.apache.kafka.connect.storage.SimpleHeaderConverter; +import org.junit.Before; +import org.junit.Test; + +import java.io.IOException; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; + +public class PluginsTest { + + private Plugins plugins; + private Map props; + private AbstractConfig config; + private TestConverter converter; + private TestHeaderConverter headerConverter; + private TestInternalConverter internalConverter; + + @SuppressWarnings("deprecation") + @Before + public void setup() { + Map pluginProps = new HashMap<>(); + + // Set up the plugins with some test plugins to test isolation + pluginProps.put(WorkerConfig.PLUGIN_PATH_CONFIG, String.join(",", TestPlugins.pluginPath())); + plugins = new Plugins(pluginProps); + props = new HashMap<>(pluginProps); + props.put(WorkerConfig.KEY_CONVERTER_CLASS_CONFIG, TestConverter.class.getName()); + props.put(WorkerConfig.VALUE_CONVERTER_CLASS_CONFIG, TestConverter.class.getName()); + props.put("key.converter." + JsonConverterConfig.SCHEMAS_ENABLE_CONFIG, "true"); + props.put("value.converter." 
+ JsonConverterConfig.SCHEMAS_ENABLE_CONFIG, "true"); + props.put("key.converter.extra.config", "foo1"); + props.put("value.converter.extra.config", "foo2"); + props.put(WorkerConfig.HEADER_CONVERTER_CLASS_CONFIG, TestHeaderConverter.class.getName()); + props.put("header.converter.extra.config", "baz"); + + createConfig(); + } + + protected void createConfig() { + this.config = new TestableWorkerConfig(props); + } + + @Test + public void shouldInstantiateAndConfigureConverters() { + instantiateAndConfigureConverter(WorkerConfig.KEY_CONVERTER_CLASS_CONFIG, ClassLoaderUsage.CURRENT_CLASSLOADER); + // Validate extra configs got passed through to overridden converters + assertEquals("true", converter.configs.get(JsonConverterConfig.SCHEMAS_ENABLE_CONFIG)); + assertEquals("foo1", converter.configs.get("extra.config")); + + instantiateAndConfigureConverter(WorkerConfig.VALUE_CONVERTER_CLASS_CONFIG, ClassLoaderUsage.PLUGINS); + // Validate extra configs got passed through to overridden converters + assertEquals("true", converter.configs.get(JsonConverterConfig.SCHEMAS_ENABLE_CONFIG)); + assertEquals("foo2", converter.configs.get("extra.config")); + } + + @SuppressWarnings("deprecation") + @Test + public void shouldInstantiateAndConfigureInternalConverters() { + instantiateAndConfigureInternalConverter(true, Collections.singletonMap(JsonConverterConfig.SCHEMAS_ENABLE_CONFIG, "false")); + // Validate schemas.enable is set to false + assertEquals("false", internalConverter.configs.get(JsonConverterConfig.SCHEMAS_ENABLE_CONFIG)); + } + + @Test + public void shouldInstantiateAndConfigureExplicitlySetHeaderConverterWithCurrentClassLoader() { + assertNotNull(props.get(WorkerConfig.HEADER_CONVERTER_CLASS_CONFIG)); + HeaderConverter headerConverter = plugins.newHeaderConverter(config, + WorkerConfig.HEADER_CONVERTER_CLASS_CONFIG, + ClassLoaderUsage.CURRENT_CLASSLOADER); + assertNotNull(headerConverter); + assertTrue(headerConverter instanceof TestHeaderConverter); + this.headerConverter = (TestHeaderConverter) headerConverter; + + // Validate extra configs got passed through to overridden converters + assertConverterType(ConverterType.HEADER, this.headerConverter.configs); + assertEquals("baz", this.headerConverter.configs.get("extra.config")); + + headerConverter = plugins.newHeaderConverter(config, + WorkerConfig.HEADER_CONVERTER_CLASS_CONFIG, + ClassLoaderUsage.PLUGINS); + assertNotNull(headerConverter); + assertTrue(headerConverter instanceof TestHeaderConverter); + this.headerConverter = (TestHeaderConverter) headerConverter; + + // Validate extra configs got passed through to overridden converters + assertConverterType(ConverterType.HEADER, this.headerConverter.configs); + assertEquals("baz", this.headerConverter.configs.get("extra.config")); + } + + @Test + public void shouldInstantiateAndConfigureConnectRestExtension() { + props.put(WorkerConfig.REST_EXTENSION_CLASSES_CONFIG, + TestConnectRestExtension.class.getName()); + createConfig(); + + List connectRestExtensions = + plugins.newPlugins(config.getList(WorkerConfig.REST_EXTENSION_CLASSES_CONFIG), + config, + ConnectRestExtension.class); + assertNotNull(connectRestExtensions); + assertEquals("One Rest Extension expected", 1, connectRestExtensions.size()); + assertNotNull(connectRestExtensions.get(0)); + assertTrue("Should be instance of TestConnectRestExtension", + connectRestExtensions.get(0) instanceof TestConnectRestExtension); + assertNotNull(((TestConnectRestExtension) connectRestExtensions.get(0)).configs); + 
assertEquals(config.originals(), + ((TestConnectRestExtension) connectRestExtensions.get(0)).configs); + } + + @Test + public void shouldInstantiateAndConfigureDefaultHeaderConverter() { + props.remove(WorkerConfig.HEADER_CONVERTER_CLASS_CONFIG); + createConfig(); + + // Because it's not explicitly set on the supplied configuration, the logic to use the current classloader for the connector + // will exit immediately, and so this method always returns null + HeaderConverter headerConverter = plugins.newHeaderConverter(config, + WorkerConfig.HEADER_CONVERTER_CLASS_CONFIG, + ClassLoaderUsage.CURRENT_CLASSLOADER); + assertNull(headerConverter); + // But we should always find it (or the worker's default) when using the plugins classloader ... + headerConverter = plugins.newHeaderConverter(config, + WorkerConfig.HEADER_CONVERTER_CLASS_CONFIG, + ClassLoaderUsage.PLUGINS); + assertNotNull(headerConverter); + assertTrue(headerConverter instanceof SimpleHeaderConverter); + } + + @Test + public void shouldThrowIfPluginThrows() { + TestPlugins.assertAvailable(); + + assertThrows(ConnectException.class, () -> plugins.newPlugin( + TestPlugins.ALWAYS_THROW_EXCEPTION, + new AbstractConfig(new ConfigDef(), Collections.emptyMap()), + Converter.class + )); + } + + @Test + public void shouldShareStaticValuesBetweenSamePlugin() { + // Plugins are not isolated from other instances of their own class. + TestPlugins.assertAvailable(); + Converter firstPlugin = plugins.newPlugin( + TestPlugins.ALIASED_STATIC_FIELD, + new AbstractConfig(new ConfigDef(), Collections.emptyMap()), + Converter.class + ); + + assertInstanceOf(SamplingTestPlugin.class, firstPlugin, "Cannot collect samples"); + + Converter secondPlugin = plugins.newPlugin( + TestPlugins.ALIASED_STATIC_FIELD, + new AbstractConfig(new ConfigDef(), Collections.emptyMap()), + Converter.class + ); + + assertInstanceOf(SamplingTestPlugin.class, secondPlugin, "Cannot collect samples"); + assertSame( + ((SamplingTestPlugin) firstPlugin).otherSamples(), + ((SamplingTestPlugin) secondPlugin).otherSamples() + ); + } + + @Test + public void newPluginShouldServiceLoadWithPluginClassLoader() { + TestPlugins.assertAvailable(); + Converter plugin = plugins.newPlugin( + TestPlugins.SERVICE_LOADER, + new AbstractConfig(new ConfigDef(), Collections.emptyMap()), + Converter.class + ); + + assertInstanceOf(SamplingTestPlugin.class, plugin, "Cannot collect samples"); + Map samples = ((SamplingTestPlugin) plugin).flatten(); + // Assert that the service loaded subclass is found in both environments + assertTrue(samples.containsKey("ServiceLoadedSubclass.static")); + assertTrue(samples.containsKey("ServiceLoadedSubclass.dynamic")); + assertPluginClassLoaderAlwaysActive(samples); + } + + @Test + public void newPluginShouldInstantiateWithPluginClassLoader() { + TestPlugins.assertAvailable(); + Converter plugin = plugins.newPlugin( + TestPlugins.ALIASED_STATIC_FIELD, + new AbstractConfig(new ConfigDef(), Collections.emptyMap()), + Converter.class + ); + + assertInstanceOf(SamplingTestPlugin.class, plugin, "Cannot collect samples"); + Map samples = ((SamplingTestPlugin) plugin).flatten(); + assertPluginClassLoaderAlwaysActive(samples); + } + + @Test + public void shouldFailToFindConverterInCurrentClassloader() { + TestPlugins.assertAvailable(); + props.put(WorkerConfig.KEY_CONVERTER_CLASS_CONFIG, TestPlugins.SAMPLING_CONVERTER); + assertThrows(ConfigException.class, this::createConfig); + } + + @Test + public void newConverterShouldConfigureWithPluginClassLoader() { + 
TestPlugins.assertAvailable(); + props.put(WorkerConfig.KEY_CONVERTER_CLASS_CONFIG, TestPlugins.SAMPLING_CONVERTER); + ClassLoader classLoader = plugins.delegatingLoader().pluginClassLoader(TestPlugins.SAMPLING_CONVERTER); + ClassLoader savedLoader = Plugins.compareAndSwapLoaders(classLoader); + createConfig(); + Plugins.compareAndSwapLoaders(savedLoader); + + Converter plugin = plugins.newConverter( + config, + WorkerConfig.KEY_CONVERTER_CLASS_CONFIG, + ClassLoaderUsage.PLUGINS + ); + + assertInstanceOf(SamplingTestPlugin.class, plugin, "Cannot collect samples"); + Map samples = ((SamplingTestPlugin) plugin).flatten(); + assertTrue(samples.containsKey("configure")); + assertPluginClassLoaderAlwaysActive(samples); + } + + @Test + public void newConfigProviderShouldConfigureWithPluginClassLoader() { + TestPlugins.assertAvailable(); + String providerPrefix = "some.provider"; + props.put(providerPrefix + ".class", TestPlugins.SAMPLING_CONFIG_PROVIDER); + + PluginClassLoader classLoader = plugins.delegatingLoader().pluginClassLoader(TestPlugins.SAMPLING_CONFIG_PROVIDER); + assertNotNull(classLoader); + ClassLoader savedLoader = Plugins.compareAndSwapLoaders(classLoader); + createConfig(); + Plugins.compareAndSwapLoaders(savedLoader); + + ConfigProvider plugin = plugins.newConfigProvider( + config, + providerPrefix, + ClassLoaderUsage.PLUGINS + ); + + assertInstanceOf(SamplingTestPlugin.class, plugin, "Cannot collect samples"); + Map samples = ((SamplingTestPlugin) plugin).flatten(); + assertTrue(samples.containsKey("configure")); + assertPluginClassLoaderAlwaysActive(samples); + } + + @Test + public void newHeaderConverterShouldConfigureWithPluginClassLoader() { + TestPlugins.assertAvailable(); + props.put(WorkerConfig.HEADER_CONVERTER_CLASS_CONFIG, TestPlugins.SAMPLING_HEADER_CONVERTER); + ClassLoader classLoader = plugins.delegatingLoader().pluginClassLoader(TestPlugins.SAMPLING_HEADER_CONVERTER); + ClassLoader savedLoader = Plugins.compareAndSwapLoaders(classLoader); + createConfig(); + Plugins.compareAndSwapLoaders(savedLoader); + + HeaderConverter plugin = plugins.newHeaderConverter( + config, + WorkerConfig.HEADER_CONVERTER_CLASS_CONFIG, + ClassLoaderUsage.PLUGINS + ); + + assertInstanceOf(SamplingTestPlugin.class, plugin, "Cannot collect samples"); + Map samples = ((SamplingTestPlugin) plugin).flatten(); + assertTrue(samples.containsKey("configure")); // HeaderConverter::configure was called + assertPluginClassLoaderAlwaysActive(samples); + } + + @Test + public void newPluginsShouldConfigureWithPluginClassLoader() { + TestPlugins.assertAvailable(); + List configurables = plugins.newPlugins( + Collections.singletonList(TestPlugins.SAMPLING_CONFIGURABLE), + config, + Configurable.class + ); + assertEquals(1, configurables.size()); + Configurable plugin = configurables.get(0); + + assertInstanceOf(SamplingTestPlugin.class, plugin, "Cannot collect samples"); + Map samples = ((SamplingTestPlugin) plugin).flatten(); + assertTrue(samples.containsKey("configure")); // Configurable::configure was called + assertPluginClassLoaderAlwaysActive(samples); + } + + public static void assertPluginClassLoaderAlwaysActive(Map samples) { + for (Entry e : samples.entrySet()) { + String sampleName = "\"" + e.getKey() + "\" (" + e.getValue() + ")"; + assertInstanceOf( + PluginClassLoader.class, + e.getValue().staticClassloader(), + sampleName + " has incorrect static classloader" + ); + assertInstanceOf( + PluginClassLoader.class, + e.getValue().classloader(), + sampleName + " has incorrect dynamic 
classloader" + ); + } + } + + public static void assertInstanceOf(Class expected, Object actual, String message) { + assertTrue( + "Expected an instance of " + expected.getSimpleName() + ", found " + actual + " instead: " + message, + expected.isInstance(actual) + ); + } + + protected void instantiateAndConfigureConverter(String configPropName, ClassLoaderUsage classLoaderUsage) { + converter = (TestConverter) plugins.newConverter(config, configPropName, classLoaderUsage); + assertNotNull(converter); + } + + protected void instantiateAndConfigureHeaderConverter(String configPropName) { + headerConverter = (TestHeaderConverter) plugins.newHeaderConverter(config, configPropName, ClassLoaderUsage.CURRENT_CLASSLOADER); + assertNotNull(headerConverter); + } + + protected void instantiateAndConfigureInternalConverter(boolean isKey, Map config) { + internalConverter = (TestInternalConverter) plugins.newInternalConverter(isKey, TestInternalConverter.class.getName(), config); + assertNotNull(internalConverter); + } + + protected void assertConverterType(ConverterType type, Map props) { + assertEquals(type.getName(), props.get(ConverterConfig.TYPE_CONFIG)); + } + + public static class TestableWorkerConfig extends WorkerConfig { + public TestableWorkerConfig(Map props) { + super(WorkerConfig.baseConfigDef(), props); + } + } + + public static class TestConverter implements Converter, Configurable { + public Map configs; + + public ConfigDef config() { + return JsonConverterConfig.configDef(); + } + + @Override + public void configure(Map configs) { + this.configs = configs; + new JsonConverterConfig(configs); // requires the `converter.type` config be set + } + + @Override + public void configure(Map configs, boolean isKey) { + this.configs = configs; + } + + @Override + public byte[] fromConnectData(String topic, Schema schema, Object value) { + return new byte[0]; + } + + @Override + public SchemaAndValue toConnectData(String topic, byte[] value) { + return null; + } + } + + public static class TestHeaderConverter implements HeaderConverter { + public Map configs; + + @Override + public ConfigDef config() { + return JsonConverterConfig.configDef(); + } + + @Override + public void configure(Map configs) { + this.configs = configs; + new JsonConverterConfig(configs); // requires the `converter.type` config be set + } + + @Override + public byte[] fromConnectHeader(String topic, String headerKey, Schema schema, Object value) { + return new byte[0]; + } + + @Override + public SchemaAndValue toConnectHeader(String topic, String headerKey, byte[] value) { + return null; + } + + @Override + public void close() throws IOException { + } + } + + + public static class TestConnectRestExtension implements ConnectRestExtension { + + public Map configs; + + @Override + public void register(ConnectRestExtensionContext restPluginContext) { + } + + @Override + public void close() throws IOException { + } + + @Override + public void configure(Map configs) { + this.configs = configs; + } + + @Override + public String version() { + return "test"; + } + } + + public static class TestInternalConverter extends JsonConverter { + public Map configs; + + @Override + public void configure(Map configs) { + this.configs = configs; + super.configure(configs); + } + } +} diff --git a/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/isolation/SamplingTestPlugin.java b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/isolation/SamplingTestPlugin.java new file mode 100644 index 0000000..bcf8881 --- 
/dev/null +++ b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/isolation/SamplingTestPlugin.java @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.connect.runtime.isolation; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.Map.Entry; + +/** + * Base class for plugins so we can sample information about their initialization + */ +public abstract class SamplingTestPlugin { + + /** + * @return the ClassLoader used to statically initialize this plugin class + */ + public abstract ClassLoader staticClassloader(); + + /** + * @return the ClassLoader used to initialize this plugin instance + */ + public abstract ClassLoader classloader(); + + /** + * @return a group of other SamplingTestPlugin instances known by this plugin + * This should only return direct children, and not reference this instance directly + */ + public Map otherSamples() { + return Collections.emptyMap(); + } + + /** + * @return a flattened list of child samples including this entry keyed as "this" + */ + public Map flatten() { + Map out = new HashMap<>(); + Map otherSamples = otherSamples(); + if (otherSamples != null) { + for (Entry child : otherSamples.entrySet()) { + for (Entry flattened : child.getValue().flatten().entrySet()) { + String key = child.getKey(); + if (flattened.getKey().length() > 0) { + key += "." + flattened.getKey(); + } + out.put(key, flattened.getValue()); + } + } + } + out.put("", this); + return out; + } + + /** + * Log the parent method call as a child sample. + * Stores only the last invocation of each method if there are multiple invocations. 
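logMethodCall below identifies its caller by walking the current thread's stack trace (frame 0 is getStackTrace itself, frame 1 the helper, frame 2 the caller). A standalone sketch of that lookup, independent of the test classes and using a made-up class name:

```java
public class CallerLookupSketch {
    // Return the name of the method that invoked this helper.
    public static String callingMethodName() {
        StackTraceElement[] frames = Thread.currentThread().getStackTrace();
        // frames[0] is getStackTrace, frames[1] is this helper, frames[2] is the caller
        return frames.length > 2 ? frames[2].getMethodName() : "<unknown>";
    }

    public static void main(String[] args) {
        System.out.println(callingMethodName()); // prints "main"
    }
}
```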
+ * @param samples The collection of samples to which this method call should be added + */ + public void logMethodCall(Map samples) { + StackTraceElement[] stackTraces = Thread.currentThread().getStackTrace(); + if (stackTraces.length < 2) { + return; + } + // 0 is inside getStackTrace + // 1 is this method + // 2 is our caller method + StackTraceElement caller = stackTraces[2]; + + samples.put(caller.getMethodName(), new MethodCallSample( + caller, + Thread.currentThread().getContextClassLoader(), + getClass().getClassLoader() + )); + } + + public static class MethodCallSample extends SamplingTestPlugin { + + private final StackTraceElement caller; + private final ClassLoader staticClassLoader; + private final ClassLoader dynamicClassLoader; + + public MethodCallSample( + StackTraceElement caller, + ClassLoader staticClassLoader, + ClassLoader dynamicClassLoader + ) { + this.caller = caller; + this.staticClassLoader = staticClassLoader; + this.dynamicClassLoader = dynamicClassLoader; + } + + @Override + public ClassLoader staticClassloader() { + return staticClassLoader; + } + + @Override + public ClassLoader classloader() { + return dynamicClassLoader; + } + + @Override + public String toString() { + return caller.toString(); + } + } +} diff --git a/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/isolation/SynchronizationTest.java b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/isolation/SynchronizationTest.java new file mode 100644 index 0000000..d23ada5 --- /dev/null +++ b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/isolation/SynchronizationTest.java @@ -0,0 +1,471 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.connect.runtime.isolation; + +import static org.junit.Assert.fail; + +import java.lang.management.LockInfo; +import java.lang.management.ManagementFactory; +import java.lang.management.MonitorInfo; +import java.lang.management.ThreadInfo; +import java.net.URL; +import java.security.AccessController; +import java.security.PrivilegedAction; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.concurrent.BrokenBarrierException; +import java.util.concurrent.CyclicBarrier; +import java.util.concurrent.LinkedBlockingDeque; +import java.util.concurrent.ThreadFactory; +import java.util.concurrent.ThreadPoolExecutor; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Predicate; +import java.util.stream.Collectors; +import org.apache.kafka.common.config.AbstractConfig; +import org.apache.kafka.common.config.ConfigDef; +import org.apache.kafka.common.config.ConfigDef.Importance; +import org.apache.kafka.common.config.ConfigDef.Type; +import org.apache.kafka.connect.json.JsonConverter; +import org.apache.kafka.connect.runtime.WorkerConfig; +import org.junit.After; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TestName; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class SynchronizationTest { + + public static final Logger log = LoggerFactory.getLogger(SynchronizationTest.class); + + @Rule + public final TestName testName = new TestName(); + + private String threadPrefix; + private Plugins plugins; + private ThreadPoolExecutor exec; + private Breakpoint dclBreakpoint; + private Breakpoint pclBreakpoint; + + @Before + public void setup() { + TestPlugins.assertAvailable(); + Map pluginProps = Collections.singletonMap( + WorkerConfig.PLUGIN_PATH_CONFIG, + String.join(",", TestPlugins.pluginPath()) + ); + threadPrefix = SynchronizationTest.class.getSimpleName() + + "." + testName.getMethodName() + "-"; + dclBreakpoint = new Breakpoint<>(); + pclBreakpoint = new Breakpoint<>(); + plugins = new Plugins(pluginProps) { + @Override + protected DelegatingClassLoader newDelegatingClassLoader(List paths) { + return AccessController.doPrivileged( + (PrivilegedAction) () -> + new SynchronizedDelegatingClassLoader(paths) + ); + } + }; + exec = new ThreadPoolExecutor( + 2, + 2, + 1000L, + TimeUnit.MILLISECONDS, + new LinkedBlockingDeque<>(), + threadFactoryWithNamedThreads(threadPrefix) + ); + + } + + @After + public void tearDown() throws InterruptedException { + dclBreakpoint.clear(); + pclBreakpoint.clear(); + exec.shutdown(); + exec.awaitTermination(1L, TimeUnit.SECONDS); + } + + private static class Breakpoint { + + private Predicate predicate; + private CyclicBarrier barrier; + + public synchronized void clear() { + if (barrier != null) { + barrier.reset(); + } + predicate = null; + barrier = null; + } + + public synchronized void set(Predicate predicate) { + clear(); + this.predicate = predicate; + // As soon as the barrier is tripped, the barrier will be reset for the next round. 
+ barrier = new CyclicBarrier(2); + } + + /** + * From a thread under test, await for the test orchestrator to continue execution + * @param obj Object to test with the breakpoint's current predicate + */ + public void await(T obj) { + Predicate predicate; + CyclicBarrier barrier; + synchronized (this) { + predicate = this.predicate; + barrier = this.barrier; + } + if (predicate != null && !predicate.test(obj)) { + return; + } + if (barrier != null) { + try { + barrier.await(); + } catch (InterruptedException | BrokenBarrierException e) { + throw new RuntimeException("Interrupted while waiting for load gate", e); + } + } + } + + /** + * From the test orchestrating thread, await for the test thread to continue execution + * @throws InterruptedException If the current thread is interrupted while waiting + * @throws BrokenBarrierException If the test thread is interrupted while waiting + * @throws TimeoutException If the barrier is not reached before 1s passes. + */ + public void testAwait() + throws InterruptedException, BrokenBarrierException, TimeoutException { + CyclicBarrier barrier; + synchronized (this) { + barrier = this.barrier; + } + Objects.requireNonNull(barrier, "Barrier must be set up before awaiting"); + barrier.await(1L, TimeUnit.SECONDS); + } + } + + private class SynchronizedDelegatingClassLoader extends DelegatingClassLoader { + { + ClassLoader.registerAsParallelCapable(); + } + + public SynchronizedDelegatingClassLoader(List pluginPaths) { + super(pluginPaths); + } + + @Override + protected PluginClassLoader newPluginClassLoader( + URL pluginLocation, + URL[] urls, + ClassLoader parent + ) { + return AccessController.doPrivileged( + (PrivilegedAction) () -> + new SynchronizedPluginClassLoader(pluginLocation, urls, parent) + ); + } + + @Override + public PluginClassLoader pluginClassLoader(String name) { + dclBreakpoint.await(name); + dclBreakpoint.await(name); + return super.pluginClassLoader(name); + } + } + + private class SynchronizedPluginClassLoader extends PluginClassLoader { + { + ClassLoader.registerAsParallelCapable(); + } + + + public SynchronizedPluginClassLoader(URL pluginLocation, URL[] urls, ClassLoader parent) { + super(pluginLocation, urls, parent); + } + + @Override + protected Object getClassLoadingLock(String className) { + pclBreakpoint.await(className); + return super.getClassLoadingLock(className); + } + } + + // If the test times out, then there's a deadlock in the test but not necessarily the code + @Test(timeout = 15000L) + public void testSimultaneousUpwardAndDownwardDelegating() throws Exception { + String t1Class = TestPlugins.SAMPLING_CONVERTER; + // Grab a reference to the target PluginClassLoader before activating breakpoints + ClassLoader connectorLoader = plugins.delegatingLoader().connectorLoader(t1Class); + + // THREAD 1: loads a class by delegating downward starting from the DelegatingClassLoader + // DelegatingClassLoader breakpoint will only trigger on this thread + dclBreakpoint.set(t1Class::equals); + Runnable thread1 = () -> { + // Use the DelegatingClassLoader as the current context loader + ClassLoader savedLoader = Plugins.compareAndSwapLoaders(plugins.delegatingLoader()); + + // Load an isolated plugin from the delegating classloader, which will + // 1. Enter the DelegatingClassLoader + // 2. Wait on dclBreakpoint for test to continue + // 3. Enter the PluginClassLoader + // 4. 
Load the isolated plugin class and return + new AbstractConfig( + new ConfigDef().define("a.class", Type.CLASS, Importance.HIGH, ""), + Collections.singletonMap("a.class", t1Class)); + Plugins.compareAndSwapLoaders(savedLoader); + }; + + // THREAD 2: loads a class by delegating upward starting from the PluginClassLoader + String t2Class = JsonConverter.class.getName(); + // PluginClassLoader breakpoint will only trigger on this thread + pclBreakpoint.set(t2Class::equals); + Runnable thread2 = () -> { + // Use the PluginClassLoader as the current context loader + ClassLoader savedLoader = Plugins.compareAndSwapLoaders(connectorLoader); + // Load a non-isolated class from the plugin classloader, which will + // 1. Enter the PluginClassLoader + // 2. Wait for the test to continue + // 3. Enter the DelegatingClassLoader + // 4. Load the non-isolated class and return + new AbstractConfig(new ConfigDef().define("a.class", Type.CLASS, Importance.HIGH, ""), + Collections.singletonMap("a.class", t2Class)); + Plugins.compareAndSwapLoaders(savedLoader); + }; + + // STEP 1: Have T1 enter the DelegatingClassLoader and pause + exec.submit(thread1); + // T1 enters ConfigDef::parseType + // T1 enters DelegatingClassLoader::loadClass + dclBreakpoint.testAwait(); + dclBreakpoint.testAwait(); + // T1 exits DelegatingClassLoader::loadClass + // T1 enters Class::forName + // T1 enters DelegatingClassLoader::loadClass + dclBreakpoint.testAwait(); + // T1 waits in the delegating classloader while we set up the other thread + dumpThreads("step 1, T1 waiting in DelegatingClassLoader"); + + // STEP 2: Have T2 enter PluginClassLoader, delegate upward to the Delegating classloader + exec.submit(thread2); + // T2 enters PluginClassLoader::loadClass + pclBreakpoint.testAwait(); + // T2 falls through to ClassLoader::loadClass + pclBreakpoint.testAwait(); + // T2 delegates upwards to DelegatingClassLoader::loadClass + // T2 enters ClassLoader::loadClass and loads the class from the parent (CLASSPATH) + dumpThreads("step 2, T2 entered DelegatingClassLoader and is loading class from parent"); + + // STEP 3: Resume T1 and have it enter the PluginClassLoader + dclBreakpoint.testAwait(); + // T1 enters PluginClassLoader::loadClass + dumpThreads("step 3, T1 entered PluginClassLoader and is/was loading class from isolated jar"); + + // If the DelegatingClassLoader and PluginClassLoader are both not parallel capable, then this test will deadlock + // Otherwise, T1 should be able to complete it's load from the PluginClassLoader concurrently with T2, + // before releasing the DelegatingClassLoader and allowing T2 to complete. + // As the DelegatingClassLoader is not parallel capable, it must be the case that PluginClassLoader is. 
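The conclusion above rests on the PluginClassLoader registering itself as parallel capable, which moves class-loading synchronization from the loader's own monitor to per-class-name locks. A minimal, hypothetical loader showing only the registration pattern (not Connect's actual implementation):

```java
public class ParallelCapableLoaderSketch extends ClassLoader {
    static {
        // Registration must happen in the subclass, typically in a static initializer,
        // before any instance is constructed.
        ClassLoader.registerAsParallelCapable();
    }

    public ParallelCapableLoaderSketch(ClassLoader parent) {
        super(parent);
    }

    @Override
    protected Class<?> findClass(String name) throws ClassNotFoundException {
        // Purely illustrative: a real plugin loader would define classes from its own jars here.
        throw new ClassNotFoundException(name);
    }
}
```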
+ assertNoDeadlocks(); + } + + // If the test times out, then there's a deadlock in the test but not necessarily the code + @Test(timeout = 15000L) + // Ensure the PluginClassLoader is parallel capable and not synchronized on its monitor lock + public void testPluginClassLoaderDoesntHoldMonitorLock() + throws InterruptedException, TimeoutException, BrokenBarrierException { + String t1Class = TestPlugins.SAMPLING_CONVERTER; + ClassLoader connectorLoader = plugins.delegatingLoader().connectorLoader(t1Class); + + Object externalTestLock = new Object(); + Breakpoint testBreakpoint = new Breakpoint<>(); + Breakpoint progress = new Breakpoint<>(); + + // THREAD 1: hold the PluginClassLoader's monitor lock, and attempt to grab the external lock + testBreakpoint.set(null); + Runnable thread1 = () -> { + synchronized (connectorLoader) { + testBreakpoint.await(null); + testBreakpoint.await(null); + synchronized (externalTestLock) { + } + } + }; + + // THREAD 2: load a class via forName while holding some external lock + progress.set(null); + Runnable thread2 = () -> { + synchronized (externalTestLock) { + try { + progress.await(null); + Class.forName(TestPlugins.SAMPLING_CONVERTER, true, connectorLoader); + } catch (ClassNotFoundException e) { + throw new RuntimeException("Failed to load test plugin", e); + } + } + }; + + // STEP 1: Have T1 hold the PluginClassLoader's monitor lock + exec.submit(thread1); + // LOCK the classloader monitor lock + testBreakpoint.testAwait(); + dumpThreads("step 1, T1 holding classloader monitor lock"); + + // STEP 2: Have T2 hold the external lock, and proceed to perform classloading + exec.submit(thread2); + // LOCK the external lock + progress.testAwait(); + // perform class loading + dumpThreads("step 2, T2 holding external lock"); + + // STEP 3: Have T1 grab the external lock, and then release both locks + testBreakpoint.testAwait(); + // LOCK the external lock + dumpThreads("step 3, T1 grabbed external lock"); + + // If the PluginClassLoader was not parallel capable, then these threads should deadlock + // Otherwise, classloading should proceed without grabbing the monitor lock, and complete before T1 grabs the external lock from T2. 
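assertNoDeadlocks() below is built on the JDK's standard ThreadMXBean deadlock detector. A self-contained sketch of that check (class name chosen for illustration):

```java
import java.lang.management.ManagementFactory;
import java.lang.management.ThreadInfo;
import java.lang.management.ThreadMXBean;

public class DeadlockCheckSketch {
    public static void main(String[] args) {
        ThreadMXBean threads = ManagementFactory.getThreadMXBean();
        long[] deadlocked = threads.findDeadlockedThreads(); // null when no deadlock exists
        if (deadlocked == null) {
            System.out.println("No deadlocked threads");
        } else {
            for (ThreadInfo info : threads.getThreadInfo(deadlocked)) {
                System.err.println("Deadlocked: " + info.getThreadName()
                        + " waiting on " + info.getLockName());
            }
        }
    }
}
```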
+ assertNoDeadlocks(); + } + + private boolean threadFromCurrentTest(ThreadInfo threadInfo) { + return threadInfo.getThreadName().startsWith(threadPrefix); + } + + private void assertNoDeadlocks() { + long[] deadlockedThreads = ManagementFactory.getThreadMXBean().findDeadlockedThreads(); + if (deadlockedThreads != null && deadlockedThreads.length > 0) { + final String threads = Arrays + .stream(ManagementFactory.getThreadMXBean().getThreadInfo(deadlockedThreads)) + .filter(this::threadFromCurrentTest) + .map(SynchronizationTest::threadInfoToString) + .collect(Collectors.joining("")); + if (!threads.isEmpty()) { + fail("Found deadlocked threads while classloading\n" + threads); + } + } + } + + private void dumpThreads(String msg) throws InterruptedException { + if (log.isDebugEnabled()) { + log.debug("{}:\n{}", + msg, + Arrays.stream(ManagementFactory.getThreadMXBean().dumpAllThreads(true, true)) + .filter(this::threadFromCurrentTest) + .map(SynchronizationTest::threadInfoToString) + .collect(Collectors.joining("\n")) + ); + } + } + + private static String threadInfoToString(ThreadInfo info) { + StringBuilder sb = new StringBuilder("\"" + info.getThreadName() + "\"" + + " Id=" + info.getThreadId() + " " + + info.getThreadState()); + if (info.getLockName() != null) { + sb.append(" on " + info.getLockName()); + } + if (info.getLockOwnerName() != null) { + sb.append(" owned by \"" + info.getLockOwnerName() + + "\" Id=" + info.getLockOwnerId()); + } + if (info.isSuspended()) { + sb.append(" (suspended)"); + } + if (info.isInNative()) { + sb.append(" (in native)"); + } + sb.append('\n'); + // this has been refactored for checkstyle + printStacktrace(info, sb); + LockInfo[] locks = info.getLockedSynchronizers(); + if (locks.length > 0) { + sb.append("\n\tNumber of locked synchronizers = " + locks.length); + sb.append('\n'); + for (LockInfo li : locks) { + sb.append("\t- " + li); + sb.append('\n'); + } + } + sb.append('\n'); + return sb.toString(); + } + + private static void printStacktrace(ThreadInfo info, StringBuilder sb) { + StackTraceElement[] stackTrace = info.getStackTrace(); + int i = 0; + // This is a copy of ThreadInfo::toString but with an unlimited number of frames shown. + for (; i < stackTrace.length; i++) { + StackTraceElement ste = stackTrace[i]; + sb.append("\tat " + ste.toString()); + sb.append('\n'); + if (i == 0 && info.getLockInfo() != null) { + Thread.State ts = info.getThreadState(); + switch (ts) { + case BLOCKED: + sb.append("\t- blocked on " + info.getLockInfo()); + sb.append('\n'); + break; + case WAITING: + sb.append("\t- waiting on " + info.getLockInfo()); + sb.append('\n'); + break; + case TIMED_WAITING: + sb.append("\t- waiting on " + info.getLockInfo()); + sb.append('\n'); + break; + default: + } + } + + for (MonitorInfo mi : info.getLockedMonitors()) { + if (mi.getLockedStackDepth() == i) { + sb.append("\t- locked " + mi); + sb.append('\n'); + } + } + } + } + + private static ThreadFactory threadFactoryWithNamedThreads(String threadPrefix) { + AtomicInteger threadNumber = new AtomicInteger(1); + return r -> { + // This is essentially Executors.defaultThreadFactory except with + // custom thread names so in order to filter by thread names when debugging + SecurityManager s = System.getSecurityManager(); + Thread t = new Thread((s != null) ? 
s.getThreadGroup() : + Thread.currentThread().getThreadGroup(), r, + threadPrefix + threadNumber.getAndIncrement(), + 0); + if (t.isDaemon()) { + t.setDaemon(false); + } + if (t.getPriority() != Thread.NORM_PRIORITY) { + t.setPriority(Thread.NORM_PRIORITY); + } + return t; + }; + } + +} diff --git a/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/isolation/TestPlugins.java b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/isolation/TestPlugins.java new file mode 100644 index 0000000..9561ffb --- /dev/null +++ b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/isolation/TestPlugins.java @@ -0,0 +1,267 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.connect.runtime.isolation; + +import java.io.BufferedInputStream; +import java.io.File; +import java.io.FileInputStream; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.StringWriter; +import java.net.URL; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.Comparator; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.jar.Attributes; +import java.util.jar.JarEntry; +import java.util.jar.JarOutputStream; +import java.util.jar.Manifest; +import java.util.stream.Collectors; +import javax.tools.JavaCompiler; +import javax.tools.StandardJavaFileManager; +import javax.tools.ToolProvider; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Utility class for constructing test plugins for Connect. + * + *

<p>Plugins are built from their source under resources/test-plugins/ and placed into temporary + * jar files that are deleted when the process exits. + * + * <p>To add a plugin, create the source files in the resource tree, and edit this class to build + * that plugin during initialization. For example, the plugin class {@literal package.Class} should + * be placed in {@literal resources/test-plugins/something/package/Class.java} and loaded using + * {@code createPluginJar("something")}. The class name, contents, and plugin directory can take + * any value you need for testing. + * + * <p>
            To use this class in your tests, make sure to first call + * {@link TestPlugins#assertAvailable()} to verify that the plugins initialized correctly. + * Otherwise, exceptions during the plugin build are not propagated, and may invalidate your test. + * You can access the list of plugin jars for assembling a {@literal plugin.path}, and reference + * the names of the different plugins directly via the exposed constants. + */ +public class TestPlugins { + + /** + * Class name of a plugin which will always throw an exception during loading + */ + public static final String ALWAYS_THROW_EXCEPTION = "test.plugins.AlwaysThrowException"; + /** + * Class name of a plugin which samples information about its initialization. + */ + public static final String ALIASED_STATIC_FIELD = "test.plugins.AliasedStaticField"; + /** + * Class name of a {@link org.apache.kafka.connect.storage.Converter} + * which samples information about its method calls. + */ + public static final String SAMPLING_CONVERTER = "test.plugins.SamplingConverter"; + /** + * Class name of a {@link org.apache.kafka.common.Configurable} + * which samples information about its method calls. + */ + public static final String SAMPLING_CONFIGURABLE = "test.plugins.SamplingConfigurable"; + /** + * Class name of a {@link org.apache.kafka.connect.storage.HeaderConverter} + * which samples information about its method calls. + */ + public static final String SAMPLING_HEADER_CONVERTER = "test.plugins.SamplingHeaderConverter"; + /** + * Class name of a {@link org.apache.kafka.common.config.provider.ConfigProvider} + * which samples information about its method calls. + */ + public static final String SAMPLING_CONFIG_PROVIDER = "test.plugins.SamplingConfigProvider"; + /** + * Class name of a plugin which uses a {@link java.util.ServiceLoader} + * to load internal classes, and samples information about their initialization. + */ + public static final String SERVICE_LOADER = "test.plugins.ServiceLoaderPlugin"; + + private static final Logger log = LoggerFactory.getLogger(TestPlugins.class); + private static final Map PLUGIN_JARS; + private static final Throwable INITIALIZATION_EXCEPTION; + + static { + Throwable err = null; + HashMap pluginJars = new HashMap<>(); + try { + pluginJars.put(ALWAYS_THROW_EXCEPTION, createPluginJar("always-throw-exception")); + pluginJars.put(ALIASED_STATIC_FIELD, createPluginJar("aliased-static-field")); + pluginJars.put(SAMPLING_CONVERTER, createPluginJar("sampling-converter")); + pluginJars.put(SAMPLING_CONFIGURABLE, createPluginJar("sampling-configurable")); + pluginJars.put(SAMPLING_HEADER_CONVERTER, createPluginJar("sampling-header-converter")); + pluginJars.put(SAMPLING_CONFIG_PROVIDER, createPluginJar("sampling-config-provider")); + pluginJars.put(SERVICE_LOADER, createPluginJar("service-loader")); + } catch (Throwable e) { + log.error("Could not set up plugin test jars", e); + err = e; + } + PLUGIN_JARS = Collections.unmodifiableMap(pluginJars); + INITIALIZATION_EXCEPTION = err; + } + + /** + * Ensure that the test plugin JARs were assembled without error before continuing. + * @throws AssertionError if any plugin failed to load, or no plugins were loaded. 
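The plugin jars described above are produced by compiling the plugin sources at test time with the JDK's built-in compiler API. A simplified sketch of that step, assuming a full JDK (not a bare JRE) is running and using a hypothetical one-file source tree:

```java
import java.io.IOException;
import java.io.StringWriter;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Arrays;
import javax.tools.JavaCompiler;
import javax.tools.StandardJavaFileManager;
import javax.tools.ToolProvider;

public class CompileSketch {
    public static void main(String[] args) throws IOException {
        Path srcDir = Files.createTempDirectory("src");
        Path binDir = Files.createTempDirectory("bin");
        Path source = srcDir.resolve("Hello.java");
        Files.write(source, Arrays.asList("public class Hello {}"));

        JavaCompiler compiler = ToolProvider.getSystemJavaCompiler(); // null on a plain JRE
        StringWriter diagnostics = new StringWriter();
        try (StandardJavaFileManager fileManager = compiler.getStandardFileManager(null, null, null)) {
            boolean ok = compiler.getTask(
                    diagnostics,
                    fileManager,
                    null,
                    Arrays.asList("-d", binDir.toString()), // write .class output to a separate directory
                    null,
                    fileManager.getJavaFileObjects(source.toFile())
            ).call();
            System.out.println(ok ? "compiled to " + binDir : diagnostics);
        }
    }
}
```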
+ */ + public static void assertAvailable() throws AssertionError { + if (INITIALIZATION_EXCEPTION != null) { + throw new AssertionError("TestPlugins did not initialize completely", + INITIALIZATION_EXCEPTION); + } + if (PLUGIN_JARS.isEmpty()) { + throw new AssertionError("No test plugins loaded"); + } + } + + /** + * A list of jar files containing test plugins + * @return A list of plugin jar filenames + */ + public static List pluginPath() { + return PLUGIN_JARS.values() + .stream() + .map(File::getPath) + .collect(Collectors.toList()); + } + + /** + * Get all of the classes that were successfully built by this class + * @return A list of plugin class names + */ + public static List pluginClasses() { + return new ArrayList<>(PLUGIN_JARS.keySet()); + } + + private static File createPluginJar(String resourceDir) throws IOException { + Path inputDir = resourceDirectoryPath("test-plugins/" + resourceDir); + Path binDir = Files.createTempDirectory(resourceDir + ".bin."); + compileJavaSources(inputDir, binDir); + File jarFile = Files.createTempFile(resourceDir + ".", ".jar").toFile(); + try (JarOutputStream jar = openJarFile(jarFile)) { + writeJar(jar, inputDir); + writeJar(jar, binDir); + } + removeDirectory(binDir); + jarFile.deleteOnExit(); + return jarFile; + } + + private static Path resourceDirectoryPath(String resourceDir) throws IOException { + URL resource = Thread.currentThread() + .getContextClassLoader() + .getResource(resourceDir); + if (resource == null) { + throw new IOException("Could not find test plugin resource: " + resourceDir); + } + File file = new File(resource.getFile()); + if (!file.isDirectory()) { + throw new IOException("Resource is not a directory: " + resourceDir); + } + if (!file.canRead()) { + throw new IOException("Resource directory is not readable: " + resourceDir); + } + return file.toPath(); + } + + private static JarOutputStream openJarFile(File jarFile) throws IOException { + Manifest manifest = new Manifest(); + manifest.getMainAttributes().put(Attributes.Name.MANIFEST_VERSION, "1.0"); + return new JarOutputStream(new FileOutputStream(jarFile), manifest); + } + + private static void removeDirectory(Path binDir) throws IOException { + List classFiles = Files.walk(binDir) + .sorted(Comparator.reverseOrder()) + .map(Path::toFile) + .collect(Collectors.toList()); + for (File classFile : classFiles) { + if (!classFile.delete()) { + throw new IOException("Could not delete: " + classFile); + } + } + } + + /** + * Compile a directory of .java source files into .class files + * .class files are placed into the same directory as their sources. + * + *

            Dependencies between source files in this directory are resolved against one another + * and the classes present in the test environment. + * See https://stackoverflow.com/questions/1563909/ for more information. + * Additional dependencies in your plugins should be added as test scope to :connect:runtime. + * @param sourceDir Directory containing java source files + * @throws IOException if the files cannot be compiled + */ + private static void compileJavaSources(Path sourceDir, Path binDir) throws IOException { + JavaCompiler compiler = ToolProvider.getSystemJavaCompiler(); + List sourceFiles = Files.walk(sourceDir) + .filter(Files::isRegularFile) + .filter(path -> path.toFile().getName().endsWith(".java")) + .map(Path::toFile) + .collect(Collectors.toList()); + StringWriter writer = new StringWriter(); + List options = Arrays.asList( + "-d", binDir.toString() // Write class output to a different directory. + ); + + try (StandardJavaFileManager fileManager = compiler.getStandardFileManager(null, null, null)) { + boolean success = compiler.getTask( + writer, + fileManager, + null, + options, + null, + fileManager.getJavaFileObjectsFromFiles(sourceFiles) + ).call(); + if (!success) { + throw new RuntimeException("Failed to compile test plugin:\n" + writer); + } + } + } + + private static void writeJar(JarOutputStream jar, Path inputDir) throws IOException { + List paths = Files.walk(inputDir) + .filter(Files::isRegularFile) + .filter(path -> !path.toFile().getName().endsWith(".java")) + .collect(Collectors.toList()); + for (Path path : paths) { + try (InputStream in = new BufferedInputStream(new FileInputStream(path.toFile()))) { + jar.putNextEntry(new JarEntry( + inputDir.relativize(path) + .toFile() + .getPath() + .replace(File.separator, "/") + )); + byte[] buffer = new byte[1024]; + for (int count; (count = in.read(buffer)) != -1; ) { + jar.write(buffer, 0, count); + } + jar.closeEntry(); + } + } + } + +} diff --git a/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/rest/InternalRequestSignatureTest.java b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/rest/InternalRequestSignatureTest.java new file mode 100644 index 0000000..f60ad35 --- /dev/null +++ b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/rest/InternalRequestSignatureTest.java @@ -0,0 +1,154 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.connect.runtime.rest; + +import org.apache.kafka.connect.errors.ConnectException; +import org.apache.kafka.connect.runtime.rest.errors.BadRequestException; +import org.eclipse.jetty.client.api.Request; +import org.junit.Test; +import org.mockito.ArgumentCaptor; + +import javax.crypto.Mac; +import javax.crypto.SecretKey; +import javax.crypto.spec.SecretKeySpec; +import javax.ws.rs.core.HttpHeaders; + +import java.util.Base64; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class InternalRequestSignatureTest { + + private static final byte[] REQUEST_BODY = + "[{\"config\":\"value\"},{\"config\":\"other_value\"}]".getBytes(); + private static final String SIGNATURE_ALGORITHM = "HmacSHA256"; + private static final SecretKey KEY = new SecretKeySpec( + new byte[] { + 109, 116, -111, 49, -94, 25, -103, 44, -99, -118, 53, -69, 87, -124, 5, 48, + 89, -105, -2, 58, -92, 87, 67, 49, -125, -79, -39, -126, -51, -53, -85, 57 + }, "HmacSHA256" + ); + private static final byte[] SIGNATURE = new byte[] { + 42, -3, 127, 57, 43, 49, -51, -43, 72, -62, -10, 120, 123, 125, 26, -65, + 36, 72, 86, -71, -32, 13, -8, 115, 85, 73, -65, -112, 6, 68, 41, -50 + }; + private static final String ENCODED_SIGNATURE = Base64.getEncoder().encodeToString(SIGNATURE); + + @Test + public void fromHeadersShouldReturnNullOnNullHeaders() { + assertNull(InternalRequestSignature.fromHeaders(REQUEST_BODY, null)); + } + + @Test + public void fromHeadersShouldReturnNullIfSignatureHeaderMissing() { + assertNull(InternalRequestSignature.fromHeaders(REQUEST_BODY, internalRequestHeaders(null, SIGNATURE_ALGORITHM))); + } + + @Test + public void fromHeadersShouldReturnNullIfSignatureAlgorithmHeaderMissing() { + assertNull(InternalRequestSignature.fromHeaders(REQUEST_BODY, internalRequestHeaders(ENCODED_SIGNATURE, null))); + } + + @Test + public void fromHeadersShouldThrowExceptionOnInvalidSignatureAlgorithm() { + assertThrows(BadRequestException.class, () -> InternalRequestSignature.fromHeaders(REQUEST_BODY, + internalRequestHeaders(ENCODED_SIGNATURE, "doesn'texist"))); + } + + @Test + public void fromHeadersShouldThrowExceptionOnInvalidBase64Signature() { + assertThrows(BadRequestException.class, () -> InternalRequestSignature.fromHeaders(REQUEST_BODY, + internalRequestHeaders("not valid base 64", SIGNATURE_ALGORITHM))); + } + + @Test + public void fromHeadersShouldReturnNonNullResultOnValidSignatureAndSignatureAlgorithm() { + InternalRequestSignature signature = + InternalRequestSignature.fromHeaders(REQUEST_BODY, internalRequestHeaders(ENCODED_SIGNATURE, SIGNATURE_ALGORITHM)); + assertNotNull(signature); + assertNotNull(signature.keyAlgorithm()); + } + + @Test + public void addToRequestShouldThrowExceptionOnInvalidSignatureAlgorithm() { + Request request = mock(Request.class); + assertThrows(ConnectException.class, () -> InternalRequestSignature.addToRequest(KEY, REQUEST_BODY, "doesn'texist", request)); + } + + @Test + public void addToRequestShouldAddHeadersOnValidSignatureAlgorithm() { + Request request = mock(Request.class); + ArgumentCaptor signatureCapture = ArgumentCaptor.forClass(String.class); + ArgumentCaptor signatureAlgorithmCapture = 
ArgumentCaptor.forClass(String.class); + when(request.header( + eq(InternalRequestSignature.SIGNATURE_HEADER), + signatureCapture.capture() + )).thenReturn(request); + when(request.header( + eq(InternalRequestSignature.SIGNATURE_ALGORITHM_HEADER), + signatureAlgorithmCapture.capture() + )).thenReturn(request); + + InternalRequestSignature.addToRequest(KEY, REQUEST_BODY, SIGNATURE_ALGORITHM, request); + + assertEquals( + "Request should have valid base 64-encoded signature added as header", + ENCODED_SIGNATURE, + signatureCapture.getValue() + ); + assertEquals( + "Request should have provided signature algorithm added as header", + SIGNATURE_ALGORITHM, + signatureAlgorithmCapture.getValue() + ); + } + + @Test + public void testSignatureValidation() throws Exception { + Mac mac = Mac.getInstance(SIGNATURE_ALGORITHM); + + InternalRequestSignature signature = new InternalRequestSignature(REQUEST_BODY, mac, SIGNATURE); + assertTrue(signature.isValid(KEY)); + + signature = InternalRequestSignature.fromHeaders(REQUEST_BODY, internalRequestHeaders(ENCODED_SIGNATURE, SIGNATURE_ALGORITHM)); + assertTrue(signature.isValid(KEY)); + + signature = new InternalRequestSignature("[{\"different_config\":\"different_value\"}]".getBytes(), mac, SIGNATURE); + assertFalse(signature.isValid(KEY)); + + signature = new InternalRequestSignature(REQUEST_BODY, mac, "bad signature".getBytes()); + assertFalse(signature.isValid(KEY)); + } + + private static HttpHeaders internalRequestHeaders(String signature, String signatureAlgorithm) { + HttpHeaders result = mock(HttpHeaders.class); + when(result.getHeaderString(eq(InternalRequestSignature.SIGNATURE_HEADER))) + .thenReturn(signature); + when(result.getHeaderString(eq(InternalRequestSignature.SIGNATURE_ALGORITHM_HEADER))) + .thenReturn(signatureAlgorithm); + return result; + } +} diff --git a/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/rest/RestServerTest.java b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/rest/RestServerTest.java new file mode 100644 index 0000000..a8e6fa8 --- /dev/null +++ b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/rest/RestServerTest.java @@ -0,0 +1,421 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.connect.runtime.rest; + +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; +import org.apache.http.HttpHost; +import org.apache.http.HttpRequest; +import org.apache.http.client.methods.CloseableHttpResponse; +import org.apache.http.client.methods.HttpGet; +import org.apache.http.client.methods.HttpOptions; +import org.apache.http.client.methods.HttpPut; +import org.apache.http.entity.StringEntity; +import org.apache.http.impl.client.BasicResponseHandler; +import org.apache.http.impl.client.CloseableHttpClient; +import org.apache.http.impl.client.HttpClients; +import org.apache.kafka.clients.CommonClientConfigs; +import org.apache.kafka.connect.rest.ConnectRestExtension; +import org.apache.kafka.connect.runtime.Herder; +import org.apache.kafka.connect.runtime.WorkerConfig; +import org.apache.kafka.connect.runtime.distributed.DistributedConfig; +import org.apache.kafka.connect.runtime.isolation.Plugins; +import org.apache.kafka.connect.runtime.standalone.StandaloneConfig; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.slf4j.LoggerFactory; + +import javax.ws.rs.core.MediaType; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +import static org.apache.kafka.connect.runtime.WorkerConfig.ADMIN_LISTENERS_CONFIG; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.mock; + +public class RestServerTest { + + private Herder herder; + private Plugins plugins; + private RestServer server; + + protected static final String KAFKA_CLUSTER_ID = "Xbafgnagvar"; + + @Before + public void setUp() { + herder = mock(Herder.class); + plugins = mock(Plugins.class); + } + + @After + public void tearDown() { + if (server != null) { + server.stop(); + } + } + + @SuppressWarnings("deprecation") + private Map baseWorkerProps() { + Map workerProps = new HashMap<>(); + workerProps.put(DistributedConfig.STATUS_STORAGE_TOPIC_CONFIG, "status-topic"); + workerProps.put(DistributedConfig.CONFIG_TOPIC_CONFIG, "config-topic"); + workerProps.put(CommonClientConfigs.BOOTSTRAP_SERVERS_CONFIG, "localhost:9092"); + workerProps.put(DistributedConfig.GROUP_ID_CONFIG, "connect-test-group"); + workerProps.put(WorkerConfig.KEY_CONVERTER_CLASS_CONFIG, "org.apache.kafka.connect.json.JsonConverter"); + workerProps.put(WorkerConfig.VALUE_CONVERTER_CLASS_CONFIG, "org.apache.kafka.connect.json.JsonConverter"); + workerProps.put(DistributedConfig.OFFSET_STORAGE_TOPIC_CONFIG, "connect-offsets"); + workerProps.put(WorkerConfig.LISTENERS_CONFIG, "HTTP://localhost:0"); + + return workerProps; + } + + @Test + public void testCORSEnabled() throws IOException { + checkCORSRequest("*", "http://bar.com", "http://bar.com", "PUT"); + } + + @Test + public void testCORSDisabled() throws IOException { + checkCORSRequest("", "http://bar.com", null, null); + } + + @Test + public void testAdvertisedUri() { + // Advertised URI from listeners without protocol + Map configMap = new HashMap<>(baseWorkerProps()); + configMap.put(WorkerConfig.LISTENERS_CONFIG, 
"http://localhost:8080,https://localhost:8443"); + DistributedConfig config = new DistributedConfig(configMap); + + server = new RestServer(config); + Assert.assertEquals("http://localhost:8080/", server.advertisedUrl().toString()); + + // Advertised URI from listeners with protocol + configMap = new HashMap<>(baseWorkerProps()); + configMap.put(WorkerConfig.LISTENERS_CONFIG, "http://localhost:8080,https://localhost:8443"); + configMap.put(WorkerConfig.REST_ADVERTISED_LISTENER_CONFIG, "https"); + config = new DistributedConfig(configMap); + + server = new RestServer(config); + Assert.assertEquals("https://localhost:8443/", server.advertisedUrl().toString()); + + // Advertised URI from listeners with only SSL available + configMap = new HashMap<>(baseWorkerProps()); + configMap.put(WorkerConfig.LISTENERS_CONFIG, "https://localhost:8443"); + config = new DistributedConfig(configMap); + + server = new RestServer(config); + Assert.assertEquals("https://localhost:8443/", server.advertisedUrl().toString()); + + // Listener is overriden by advertised values + configMap = new HashMap<>(baseWorkerProps()); + configMap.put(WorkerConfig.LISTENERS_CONFIG, "https://localhost:8443"); + configMap.put(WorkerConfig.REST_ADVERTISED_LISTENER_CONFIG, "http"); + configMap.put(WorkerConfig.REST_ADVERTISED_HOST_NAME_CONFIG, "somehost"); + configMap.put(WorkerConfig.REST_ADVERTISED_PORT_CONFIG, "10000"); + config = new DistributedConfig(configMap); + + server = new RestServer(config); + Assert.assertEquals("http://somehost:10000/", server.advertisedUrl().toString()); + + // correct listener is chosen when https listener is configured before http listener and advertised listener is http + configMap = new HashMap<>(baseWorkerProps()); + configMap.put(WorkerConfig.LISTENERS_CONFIG, "https://encrypted-localhost:42069,http://plaintext-localhost:4761"); + configMap.put(WorkerConfig.REST_ADVERTISED_LISTENER_CONFIG, "http"); + config = new DistributedConfig(configMap); + server = new RestServer(config); + Assert.assertEquals("http://plaintext-localhost:4761/", server.advertisedUrl().toString()); + } + + @Test + public void testOptionsDoesNotIncludeWadlOutput() throws IOException { + Map configMap = new HashMap<>(baseWorkerProps()); + DistributedConfig workerConfig = new DistributedConfig(configMap); + + doReturn(KAFKA_CLUSTER_ID).when(herder).kafkaClusterId(); + doReturn(plugins).when(herder).plugins(); + doReturn(Collections.emptyList()).when(plugins).newPlugins(Collections.emptyList(), workerConfig, ConnectRestExtension.class); + + server = new RestServer(workerConfig); + server.initializeServer(); + server.initializeResources(herder); + + HttpOptions request = new HttpOptions("/connectors"); + request.addHeader("Content-Type", MediaType.WILDCARD); + CloseableHttpClient httpClient = HttpClients.createMinimal(); + HttpHost httpHost = new HttpHost( + server.advertisedUrl().getHost(), + server.advertisedUrl().getPort() + ); + CloseableHttpResponse response = httpClient.execute(httpHost, request); + Assert.assertEquals(MediaType.TEXT_PLAIN, response.getEntity().getContentType().getValue()); + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + response.getEntity().writeTo(baos); + Assert.assertArrayEquals( + request.getAllowedMethods(response).toArray(), + new String(baos.toByteArray(), StandardCharsets.UTF_8).split(", ") + ); + } + + public void checkCORSRequest(String corsDomain, String origin, String expectedHeader, String method) + throws IOException { + Map workerProps = baseWorkerProps(); + 
workerProps.put(WorkerConfig.ACCESS_CONTROL_ALLOW_ORIGIN_CONFIG, corsDomain); + workerProps.put(WorkerConfig.ACCESS_CONTROL_ALLOW_METHODS_CONFIG, method); + WorkerConfig workerConfig = new DistributedConfig(workerProps); + + doReturn(KAFKA_CLUSTER_ID).when(herder).kafkaClusterId(); + doReturn(plugins).when(herder).plugins(); + doReturn(Collections.emptyList()).when(plugins).newPlugins(Collections.emptyList(), workerConfig, ConnectRestExtension.class); + doReturn(Arrays.asList("a", "b")).when(herder).connectors(); + + server = new RestServer(workerConfig); + server.initializeServer(); + server.initializeResources(herder); + HttpRequest request = new HttpGet("/connectors"); + request.addHeader("Referer", origin + "/page"); + request.addHeader("Origin", origin); + CloseableHttpClient httpClient = HttpClients.createMinimal(); + HttpHost httpHost = new HttpHost( + server.advertisedUrl().getHost(), + server.advertisedUrl().getPort() + ); + CloseableHttpResponse response = httpClient.execute(httpHost, request); + + Assert.assertEquals(200, response.getStatusLine().getStatusCode()); + + if (expectedHeader != null) { + Assert.assertEquals(expectedHeader, + response.getFirstHeader("Access-Control-Allow-Origin").getValue()); + } + + request = new HttpOptions("/connector-plugins/FileStreamSource/validate"); + request.addHeader("Referer", origin + "/page"); + request.addHeader("Origin", origin); + request.addHeader("Access-Control-Request-Method", method); + response = httpClient.execute(httpHost, request); + Assert.assertEquals(404, response.getStatusLine().getStatusCode()); + if (expectedHeader != null) { + Assert.assertEquals(expectedHeader, + response.getFirstHeader("Access-Control-Allow-Origin").getValue()); + } + if (method != null) { + Assert.assertEquals(method, + response.getFirstHeader("Access-Control-Allow-Methods").getValue()); + } + } + + @Test + public void testStandaloneConfig() throws IOException { + Map workerProps = baseWorkerProps(); + workerProps.put("offset.storage.file.filename", "/tmp"); + WorkerConfig workerConfig = new StandaloneConfig(workerProps); + + doReturn(KAFKA_CLUSTER_ID).when(herder).kafkaClusterId(); + doReturn(plugins).when(herder).plugins(); + doReturn(Collections.emptyList()).when(plugins).newPlugins(Collections.emptyList(), workerConfig, ConnectRestExtension.class); + doReturn(Arrays.asList("a", "b")).when(herder).connectors(); + + server = new RestServer(workerConfig); + server.initializeServer(); + server.initializeResources(herder); + HttpRequest request = new HttpGet("/connectors"); + CloseableHttpClient httpClient = HttpClients.createMinimal(); + HttpHost httpHost = new HttpHost(server.advertisedUrl().getHost(), server.advertisedUrl().getPort()); + CloseableHttpResponse response = httpClient.execute(httpHost, request); + + Assert.assertEquals(200, response.getStatusLine().getStatusCode()); + } + + @Test + public void testLoggersEndpointWithDefaults() throws IOException { + Map configMap = new HashMap<>(baseWorkerProps()); + DistributedConfig workerConfig = new DistributedConfig(configMap); + + doReturn(KAFKA_CLUSTER_ID).when(herder).kafkaClusterId(); + doReturn(plugins).when(herder).plugins(); + doReturn(Collections.emptyList()).when(plugins).newPlugins(Collections.emptyList(), workerConfig, ConnectRestExtension.class); + + // create some loggers in the process + LoggerFactory.getLogger("a.b.c.s.W"); + + server = new RestServer(workerConfig); + server.initializeServer(); + server.initializeResources(herder); + + ObjectMapper mapper = new ObjectMapper(); + + 
String host = server.advertisedUrl().getHost(); + int port = server.advertisedUrl().getPort(); + + executePut(host, port, "/admin/loggers/a.b.c.s.W", "{\"level\": \"INFO\"}"); + + String responseStr = executeGet(host, port, "/admin/loggers"); + Map<String, Map<String, String>> loggers = mapper.readValue(responseStr, new TypeReference<Map<String, Map<String, String>>>() { + }); + assertNotNull("expected non null response for /admin/loggers" + prettyPrint(loggers), loggers); + assertTrue("expect at least 1 logger. instead found " + prettyPrint(loggers), loggers.size() >= 1); + assertEquals("expected to find logger a.b.c.s.W set to INFO level", loggers.get("a.b.c.s.W").get("level"), "INFO"); + } + + @Test + public void testIndependentAdminEndpoint() throws IOException { + Map<String, String> configMap = new HashMap<>(baseWorkerProps()); + configMap.put(ADMIN_LISTENERS_CONFIG, "http://localhost:0"); + + DistributedConfig workerConfig = new DistributedConfig(configMap); + + doReturn(KAFKA_CLUSTER_ID).when(herder).kafkaClusterId(); + doReturn(plugins).when(herder).plugins(); + doReturn(Collections.emptyList()).when(plugins).newPlugins(Collections.emptyList(), workerConfig, ConnectRestExtension.class); + + // create some loggers in the process + LoggerFactory.getLogger("a.b.c.s.W"); + LoggerFactory.getLogger("a.b.c.p.X"); + LoggerFactory.getLogger("a.b.c.p.Y"); + LoggerFactory.getLogger("a.b.c.p.Z"); + + server = new RestServer(workerConfig); + server.initializeServer(); + server.initializeResources(herder); + + assertNotEquals(server.advertisedUrl(), server.adminUrl()); + + executeGet(server.adminUrl().getHost(), server.adminUrl().getPort(), "/admin/loggers"); + + HttpRequest request = new HttpGet("/admin/loggers"); + CloseableHttpClient httpClient = HttpClients.createMinimal(); + HttpHost httpHost = new HttpHost(server.advertisedUrl().getHost(), server.advertisedUrl().getPort()); + CloseableHttpResponse response = httpClient.execute(httpHost, request); + Assert.assertEquals(404, response.getStatusLine().getStatusCode()); + } + + @Test + public void testDisableAdminEndpoint() throws IOException { + Map<String, String> configMap = new HashMap<>(baseWorkerProps()); + configMap.put(ADMIN_LISTENERS_CONFIG, ""); + + DistributedConfig workerConfig = new DistributedConfig(configMap); + + doReturn(KAFKA_CLUSTER_ID).when(herder).kafkaClusterId(); + doReturn(plugins).when(herder).plugins(); + doReturn(Collections.emptyList()).when(plugins).newPlugins(Collections.emptyList(), workerConfig, ConnectRestExtension.class); + + server = new RestServer(workerConfig); + server.initializeServer(); + server.initializeResources(herder); + + assertNull(server.adminUrl()); + + HttpRequest request = new HttpGet("/admin/loggers"); + CloseableHttpClient httpClient = HttpClients.createMinimal(); + HttpHost httpHost = new HttpHost(server.advertisedUrl().getHost(), server.advertisedUrl().getPort()); + CloseableHttpResponse response = httpClient.execute(httpHost, request); + Assert.assertEquals(404, response.getStatusLine().getStatusCode()); + } + + @Test + public void testValidCustomizedHttpResponseHeaders() throws IOException { + String headerConfig = + "add X-XSS-Protection: 1; mode=block, \"add Cache-Control: no-cache, no-store, must-revalidate\""; + Map<String, String> expectedHeaders = new HashMap<>(); + expectedHeaders.put("X-XSS-Protection", "1; mode=block"); + expectedHeaders.put("Cache-Control", "no-cache, no-store, must-revalidate"); + checkCustomizedHttpResponseHeaders(headerConfig, expectedHeaders); + } + + @Test + public void testDefaultCustomizedHttpResponseHeaders() throws IOException { + String headerConfig = ""; + 
Map expectedHeaders = new HashMap<>(); + checkCustomizedHttpResponseHeaders(headerConfig, expectedHeaders); + } + + private void checkCustomizedHttpResponseHeaders(String headerConfig, Map expectedHeaders) + throws IOException { + Map workerProps = baseWorkerProps(); + workerProps.put("offset.storage.file.filename", "/tmp"); + workerProps.put(WorkerConfig.RESPONSE_HTTP_HEADERS_CONFIG, headerConfig); + WorkerConfig workerConfig = new DistributedConfig(workerProps); + + doReturn(KAFKA_CLUSTER_ID).when(herder).kafkaClusterId(); + doReturn(plugins).when(herder).plugins(); + doReturn(Collections.emptyList()).when(plugins).newPlugins(Collections.emptyList(), workerConfig, ConnectRestExtension.class); + doReturn(Arrays.asList("a", "b")).when(herder).connectors(); + + server = new RestServer(workerConfig); + try { + server.initializeServer(); + server.initializeResources(herder); + HttpRequest request = new HttpGet("/connectors"); + try (CloseableHttpClient httpClient = HttpClients.createMinimal()) { + HttpHost httpHost = new HttpHost(server.advertisedUrl().getHost(), server.advertisedUrl().getPort()); + try (CloseableHttpResponse response = httpClient.execute(httpHost, request)) { + Assert.assertEquals(200, response.getStatusLine().getStatusCode()); + if (!headerConfig.isEmpty()) { + expectedHeaders.forEach((k, v) -> + Assert.assertEquals(response.getFirstHeader(k).getValue(), v)); + } else { + Assert.assertNull(response.getFirstHeader("X-Frame-Options")); + } + } + } + } finally { + server.stop(); + server = null; + } + } + + private String executeGet(String host, int port, String endpoint) throws IOException { + HttpRequest request = new HttpGet(endpoint); + CloseableHttpClient httpClient = HttpClients.createMinimal(); + HttpHost httpHost = new HttpHost(host, port); + CloseableHttpResponse response = httpClient.execute(httpHost, request); + + Assert.assertEquals(200, response.getStatusLine().getStatusCode()); + return new BasicResponseHandler().handleResponse(response); + } + + private String executePut(String host, int port, String endpoint, String jsonBody) throws IOException { + HttpPut request = new HttpPut(endpoint); + StringEntity entity = new StringEntity(jsonBody, StandardCharsets.UTF_8.name()); + entity.setContentType("application/json"); + request.setEntity(entity); + CloseableHttpClient httpClient = HttpClients.createMinimal(); + HttpHost httpHost = new HttpHost(host, port); + CloseableHttpResponse response = httpClient.execute(httpHost, request); + + Assert.assertEquals(200, response.getStatusLine().getStatusCode()); + return new BasicResponseHandler().handleResponse(response); + } + + private static String prettyPrint(Map map) throws IOException { + ObjectMapper mapper = new ObjectMapper(); + return mapper.writerWithDefaultPrettyPrinter().writeValueAsString(map); + } +} diff --git a/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/rest/entities/ConnectorTypeTest.java b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/rest/entities/ConnectorTypeTest.java new file mode 100644 index 0000000..cd07bd8 --- /dev/null +++ b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/rest/entities/ConnectorTypeTest.java @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.runtime.rest.entities; + +import org.junit.Test; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +public class ConnectorTypeTest { + + @Test + public void testToStringIsLowerCase() { + for (ConnectorType ct : ConnectorType.values()) { + String shouldBeLower = ct.toString(); + assertFalse(shouldBeLower.isEmpty()); + for (Character c : shouldBeLower.toCharArray()) { + assertTrue(Character.isLowerCase(c)); + } + } + } + + @Test + public void testForValue() { + for (ConnectorType ct : ConnectorType.values()) { + assertEquals(ct, ConnectorType.forValue(ct.toString())); + } + } +} diff --git a/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/rest/resources/ConnectorPluginsResourceTest.java b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/rest/resources/ConnectorPluginsResourceTest.java new file mode 100644 index 0000000..c1d06f9 --- /dev/null +++ b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/rest/resources/ConnectorPluginsResourceTest.java @@ -0,0 +1,513 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.connect.runtime.rest.resources; + +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; +import javax.ws.rs.core.HttpHeaders; +import org.apache.kafka.common.config.Config; +import org.apache.kafka.common.config.ConfigDef; +import org.apache.kafka.common.config.ConfigDef.Importance; +import org.apache.kafka.common.config.ConfigDef.Recommender; +import org.apache.kafka.common.config.ConfigDef.Type; +import org.apache.kafka.common.config.ConfigDef.Width; +import org.apache.kafka.common.config.ConfigValue; +import org.apache.kafka.connect.connector.Connector; +import org.apache.kafka.connect.connector.Task; +import org.apache.kafka.connect.runtime.AbstractHerder; +import org.apache.kafka.connect.runtime.ConnectorConfig; +import org.apache.kafka.connect.runtime.Herder; +import org.apache.kafka.connect.runtime.TestSinkConnector; +import org.apache.kafka.connect.runtime.TestSourceConnector; +import org.apache.kafka.connect.runtime.WorkerConfig; +import org.apache.kafka.connect.runtime.isolation.PluginClassLoader; +import org.apache.kafka.connect.runtime.isolation.PluginDesc; +import org.apache.kafka.connect.runtime.isolation.Plugins; +import org.apache.kafka.connect.runtime.rest.RestClient; +import org.apache.kafka.connect.runtime.rest.entities.ConfigInfo; +import org.apache.kafka.connect.runtime.rest.entities.ConfigInfos; +import org.apache.kafka.connect.runtime.rest.entities.ConfigKeyInfo; +import org.apache.kafka.connect.runtime.rest.entities.ConfigValueInfo; +import org.apache.kafka.connect.runtime.rest.entities.ConnectorPluginInfo; +import org.apache.kafka.connect.runtime.rest.entities.ConnectorType; +import org.apache.kafka.connect.sink.SinkConnector; +import org.apache.kafka.connect.source.SourceConnector; +import org.apache.kafka.connect.tools.MockConnector; +import org.apache.kafka.connect.tools.MockSinkConnector; +import org.apache.kafka.connect.tools.MockSourceConnector; +import org.apache.kafka.connect.tools.SchemaSourceConnector; +import org.apache.kafka.connect.tools.VerifiableSinkConnector; +import org.apache.kafka.connect.tools.VerifiableSourceConnector; +import org.apache.kafka.connect.util.Callback; +import org.easymock.Capture; +import org.easymock.EasyMock; +import org.easymock.IAnswer; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.powermock.api.easymock.PowerMock; +import org.powermock.api.easymock.annotation.Mock; +import org.powermock.core.classloader.annotations.PowerMockIgnore; +import org.powermock.core.classloader.annotations.PrepareForTest; +import org.powermock.modules.junit4.PowerMockRunner; + +import javax.ws.rs.BadRequestException; +import java.net.URL; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.TreeSet; + +import static java.util.Arrays.asList; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; + +@RunWith(PowerMockRunner.class) +@PrepareForTest(RestClient.class) +@PowerMockIgnore("javax.management.*") +public class ConnectorPluginsResourceTest { + + private static Map props; + private static Map partialProps = new HashMap<>(); + static { + partialProps.put("name", "test"); + partialProps.put("test.string.config", "testString"); + 
partialProps.put("test.int.config", "1"); + partialProps.put("test.list.config", "a,b"); + + props = new HashMap<>(partialProps); + props.put("connector.class", ConnectorPluginsResourceTestConnector.class.getSimpleName()); + props.put("plugin.path", null); + } + + private static final ConfigInfos CONFIG_INFOS; + private static final ConfigInfos PARTIAL_CONFIG_INFOS; + private static final int ERROR_COUNT = 0; + private static final int PARTIAL_CONFIG_ERROR_COUNT = 1; + private static final Set> CONNECTOR_PLUGINS = new TreeSet<>(); + + static { + List configs = new LinkedList<>(); + List partialConfigs = new LinkedList<>(); + + ConfigDef connectorConfigDef = ConnectorConfig.configDef(); + List connectorConfigValues = connectorConfigDef.validate(props); + List partialConnectorConfigValues = connectorConfigDef.validate(partialProps); + ConfigInfos result = AbstractHerder.generateResult(ConnectorPluginsResourceTestConnector.class.getName(), connectorConfigDef.configKeys(), connectorConfigValues, Collections.emptyList()); + ConfigInfos partialResult = AbstractHerder.generateResult(ConnectorPluginsResourceTestConnector.class.getName(), connectorConfigDef.configKeys(), partialConnectorConfigValues, Collections.emptyList()); + configs.addAll(result.values()); + partialConfigs.addAll(partialResult.values()); + + ConfigKeyInfo configKeyInfo = new ConfigKeyInfo("test.string.config", "STRING", true, null, "HIGH", "Test configuration for string type.", null, -1, "NONE", "test.string.config", Collections.emptyList()); + ConfigValueInfo configValueInfo = new ConfigValueInfo("test.string.config", "testString", Collections.emptyList(), Collections.emptyList(), true); + ConfigInfo configInfo = new ConfigInfo(configKeyInfo, configValueInfo); + configs.add(configInfo); + partialConfigs.add(configInfo); + + configKeyInfo = new ConfigKeyInfo("test.int.config", "INT", true, null, "MEDIUM", "Test configuration for integer type.", "Test", 1, "MEDIUM", "test.int.config", Collections.emptyList()); + configValueInfo = new ConfigValueInfo("test.int.config", "1", asList("1", "2", "3"), Collections.emptyList(), true); + configInfo = new ConfigInfo(configKeyInfo, configValueInfo); + configs.add(configInfo); + partialConfigs.add(configInfo); + + configKeyInfo = new ConfigKeyInfo("test.string.config.default", "STRING", false, "", "LOW", "Test configuration with default value.", null, -1, "NONE", "test.string.config.default", Collections.emptyList()); + configValueInfo = new ConfigValueInfo("test.string.config.default", "", Collections.emptyList(), Collections.emptyList(), true); + configInfo = new ConfigInfo(configKeyInfo, configValueInfo); + configs.add(configInfo); + partialConfigs.add(configInfo); + + configKeyInfo = new ConfigKeyInfo("test.list.config", "LIST", true, null, "HIGH", "Test configuration for list type.", "Test", 2, "LONG", "test.list.config", Collections.emptyList()); + configValueInfo = new ConfigValueInfo("test.list.config", "a,b", asList("a", "b", "c"), Collections.emptyList(), true); + configInfo = new ConfigInfo(configKeyInfo, configValueInfo); + configs.add(configInfo); + partialConfigs.add(configInfo); + + CONFIG_INFOS = new ConfigInfos(ConnectorPluginsResourceTestConnector.class.getName(), ERROR_COUNT, Collections.singletonList("Test"), configs); + PARTIAL_CONFIG_INFOS = new ConfigInfos(ConnectorPluginsResourceTestConnector.class.getName(), PARTIAL_CONFIG_ERROR_COUNT, Collections.singletonList("Test"), partialConfigs); + + List> abstractConnectorClasses = asList( + Connector.class, + 
SourceConnector.class, + SinkConnector.class + ); + + List> connectorClasses = asList( + VerifiableSourceConnector.class, + VerifiableSinkConnector.class, + MockSourceConnector.class, + MockSinkConnector.class, + MockConnector.class, + SchemaSourceConnector.class, + ConnectorPluginsResourceTestConnector.class + ); + + try { + for (Class klass : abstractConnectorClasses) { + @SuppressWarnings("unchecked") + MockConnectorPluginDesc pluginDesc = new MockConnectorPluginDesc(klass, "0.0.0"); + CONNECTOR_PLUGINS.add(pluginDesc); + } + for (Class klass : connectorClasses) { + @SuppressWarnings("unchecked") + MockConnectorPluginDesc pluginDesc = new MockConnectorPluginDesc(klass); + CONNECTOR_PLUGINS.add(pluginDesc); + } + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + @Mock + private Herder herder; + @Mock + private Plugins plugins; + private ConnectorPluginsResource connectorPluginsResource; + + @Before + public void setUp() throws Exception { + PowerMock.mockStatic(RestClient.class, + RestClient.class.getMethod("httpRequest", String.class, String.class, HttpHeaders.class, Object.class, TypeReference.class, WorkerConfig.class)); + + plugins = PowerMock.createMock(Plugins.class); + herder = PowerMock.createMock(AbstractHerder.class); + connectorPluginsResource = new ConnectorPluginsResource(herder); + } + + private void expectPlugins() { + EasyMock.expect(herder.plugins()).andReturn(plugins); + EasyMock.expect(plugins.connectors()).andReturn(CONNECTOR_PLUGINS); + PowerMock.replayAll(); + } + + @Test + public void testValidateConfigWithSingleErrorDueToMissingConnectorClassname() throws Throwable { + Capture> configInfosCallback = EasyMock.newCapture(); + herder.validateConnectorConfig(EasyMock.eq(partialProps), EasyMock.capture(configInfosCallback), EasyMock.anyBoolean()); + + PowerMock.expectLastCall().andAnswer((IAnswer) () -> { + ConfigDef connectorConfigDef = ConnectorConfig.configDef(); + List connectorConfigValues = connectorConfigDef.validate(partialProps); + + Connector connector = new ConnectorPluginsResourceTestConnector(); + Config config = connector.validate(partialProps); + ConfigDef configDef = connector.config(); + Map configKeys = configDef.configKeys(); + List configValues = config.configValues(); + + Map resultConfigKeys = new HashMap<>(configKeys); + resultConfigKeys.putAll(connectorConfigDef.configKeys()); + configValues.addAll(connectorConfigValues); + + ConfigInfos configInfos = AbstractHerder.generateResult( + ConnectorPluginsResourceTestConnector.class.getName(), + resultConfigKeys, + configValues, + Collections.singletonList("Test") + ); + configInfosCallback.getValue().onCompletion(null, configInfos); + return null; + }); + + PowerMock.replayAll(); + + // This call to validateConfigs does not throw a BadRequestException because we've mocked + // validateConnectorConfig. 
+ ConfigInfos configInfos = connectorPluginsResource.validateConfigs( + ConnectorPluginsResourceTestConnector.class.getSimpleName(), + partialProps + ); + assertEquals(PARTIAL_CONFIG_INFOS.name(), configInfos.name()); + assertEquals(PARTIAL_CONFIG_INFOS.errorCount(), configInfos.errorCount()); + assertEquals(PARTIAL_CONFIG_INFOS.groups(), configInfos.groups()); + assertEquals( + new HashSet<>(PARTIAL_CONFIG_INFOS.values()), + new HashSet<>(configInfos.values()) + ); + + PowerMock.verifyAll(); + } + + @Test + public void testValidateConfigWithSimpleName() throws Throwable { + Capture> configInfosCallback = EasyMock.newCapture(); + herder.validateConnectorConfig(EasyMock.eq(props), EasyMock.capture(configInfosCallback), EasyMock.anyBoolean()); + + PowerMock.expectLastCall().andAnswer((IAnswer) () -> { + ConfigDef connectorConfigDef = ConnectorConfig.configDef(); + List connectorConfigValues = connectorConfigDef.validate(props); + + Connector connector = new ConnectorPluginsResourceTestConnector(); + Config config = connector.validate(props); + ConfigDef configDef = connector.config(); + Map configKeys = configDef.configKeys(); + List configValues = config.configValues(); + + Map resultConfigKeys = new HashMap<>(configKeys); + resultConfigKeys.putAll(connectorConfigDef.configKeys()); + configValues.addAll(connectorConfigValues); + + ConfigInfos configInfos = AbstractHerder.generateResult( + ConnectorPluginsResourceTestConnector.class.getName(), + resultConfigKeys, + configValues, + Collections.singletonList("Test") + ); + configInfosCallback.getValue().onCompletion(null, configInfos); + return null; + }); + + PowerMock.replayAll(); + + // make a request to connector-plugins resource using just the simple class name. + ConfigInfos configInfos = connectorPluginsResource.validateConfigs( + ConnectorPluginsResourceTestConnector.class.getSimpleName(), + props + ); + assertEquals(CONFIG_INFOS.name(), configInfos.name()); + assertEquals(0, configInfos.errorCount()); + assertEquals(CONFIG_INFOS.groups(), configInfos.groups()); + assertEquals(new HashSet<>(CONFIG_INFOS.values()), new HashSet<>(configInfos.values())); + + PowerMock.verifyAll(); + } + + @Test + public void testValidateConfigWithAlias() throws Throwable { + Capture> configInfosCallback = EasyMock.newCapture(); + herder.validateConnectorConfig(EasyMock.eq(props), EasyMock.capture(configInfosCallback), EasyMock.anyBoolean()); + + PowerMock.expectLastCall().andAnswer((IAnswer) () -> { + ConfigDef connectorConfigDef = ConnectorConfig.configDef(); + List connectorConfigValues = connectorConfigDef.validate(props); + + Connector connector = new ConnectorPluginsResourceTestConnector(); + Config config = connector.validate(props); + ConfigDef configDef = connector.config(); + Map configKeys = configDef.configKeys(); + List configValues = config.configValues(); + + Map resultConfigKeys = new HashMap<>(configKeys); + resultConfigKeys.putAll(connectorConfigDef.configKeys()); + configValues.addAll(connectorConfigValues); + + ConfigInfos configInfos = AbstractHerder.generateResult( + ConnectorPluginsResourceTestConnector.class.getName(), + resultConfigKeys, + configValues, + Collections.singletonList("Test") + ); + configInfosCallback.getValue().onCompletion(null, configInfos); + return null; + }); + + PowerMock.replayAll(); + + // make a request to connector-plugins resource using a valid alias. 
+ ConfigInfos configInfos = connectorPluginsResource.validateConfigs( + "ConnectorPluginsResourceTest", + props + ); + assertEquals(CONFIG_INFOS.name(), configInfos.name()); + assertEquals(0, configInfos.errorCount()); + assertEquals(CONFIG_INFOS.groups(), configInfos.groups()); + assertEquals(new HashSet<>(CONFIG_INFOS.values()), new HashSet<>(configInfos.values())); + + PowerMock.verifyAll(); + } + + @Test + public void testValidateConfigWithNonExistentName() { + // make a request to connector-plugins resource using a non-loaded connector with the same + // simple name but different package. + String customClassname = "com.custom.package." + + ConnectorPluginsResourceTestConnector.class.getSimpleName(); + assertThrows(BadRequestException.class, () -> connectorPluginsResource.validateConfigs(customClassname, props)); + } + + @Test + public void testValidateConfigWithNonExistentAlias() { + assertThrows(BadRequestException.class, () -> connectorPluginsResource.validateConfigs("ConnectorPluginsTest", props)); + } + + @Test + public void testListConnectorPlugins() throws Exception { + expectPlugins(); + Set connectorPlugins = new HashSet<>(connectorPluginsResource.listConnectorPlugins()); + assertFalse(connectorPlugins.contains(newInfo(Connector.class, "0.0"))); + assertFalse(connectorPlugins.contains(newInfo(SourceConnector.class, "0.0"))); + assertFalse(connectorPlugins.contains(newInfo(SinkConnector.class, "0.0"))); + assertFalse(connectorPlugins.contains(newInfo(VerifiableSourceConnector.class))); + assertFalse(connectorPlugins.contains(newInfo(VerifiableSinkConnector.class))); + assertFalse(connectorPlugins.contains(newInfo(MockSourceConnector.class))); + assertFalse(connectorPlugins.contains(newInfo(MockSinkConnector.class))); + assertFalse(connectorPlugins.contains(newInfo(MockConnector.class))); + assertFalse(connectorPlugins.contains(newInfo(SchemaSourceConnector.class))); + assertTrue(connectorPlugins.contains(newInfo(ConnectorPluginsResourceTestConnector.class))); + PowerMock.verifyAll(); + } + + @Test + public void testConnectorPluginsIncludesTypeAndVersionInformation() throws Exception { + expectPlugins(); + ConnectorPluginInfo sinkInfo = newInfo(TestSinkConnector.class); + ConnectorPluginInfo sourceInfo = + newInfo(TestSourceConnector.class); + ConnectorPluginInfo unknownInfo = + newInfo(ConnectorPluginsResourceTestConnector.class); + assertEquals(ConnectorType.SINK, sinkInfo.type()); + assertEquals(ConnectorType.SOURCE, sourceInfo.type()); + assertEquals(ConnectorType.UNKNOWN, unknownInfo.type()); + assertEquals(TestSinkConnector.VERSION, sinkInfo.version()); + assertEquals(TestSourceConnector.VERSION, sourceInfo.version()); + + final ObjectMapper objectMapper = new ObjectMapper(); + String serializedSink = objectMapper.writeValueAsString(ConnectorType.SINK); + String serializedSource = objectMapper.writeValueAsString(ConnectorType.SOURCE); + String serializedUnknown = objectMapper.writeValueAsString(ConnectorType.UNKNOWN); + assertTrue(serializedSink.contains("sink")); + assertTrue(serializedSource.contains("source")); + assertTrue(serializedUnknown.contains("unknown")); + assertEquals( + ConnectorType.SINK, + objectMapper.readValue(serializedSink, ConnectorType.class) + ); + assertEquals( + ConnectorType.SOURCE, + objectMapper.readValue(serializedSource, ConnectorType.class) + ); + assertEquals( + ConnectorType.UNKNOWN, + objectMapper.readValue(serializedUnknown, ConnectorType.class) + ); + } + + protected static ConnectorPluginInfo newInfo(Class klass, String version) + 
throws Exception { + return new ConnectorPluginInfo(new MockConnectorPluginDesc(klass, version)); + } + + protected static ConnectorPluginInfo newInfo(Class klass) + throws Exception { + return new ConnectorPluginInfo(new MockConnectorPluginDesc(klass)); + } + + public static class MockPluginClassLoader extends PluginClassLoader { + public MockPluginClassLoader(URL pluginLocation, URL[] urls, ClassLoader parent) { + super(pluginLocation, urls, parent); + } + + public MockPluginClassLoader(URL pluginLocation, URL[] urls) { + super(pluginLocation, urls); + } + + @Override + public String location() { + return "/tmp/mockpath"; + } + } + + public static class MockConnectorPluginDesc extends PluginDesc { + public MockConnectorPluginDesc(Class klass, String version) { + super(klass, version, new MockPluginClassLoader(null, new URL[0])); + } + + public MockConnectorPluginDesc(Class klass) throws Exception { + super( + klass, + klass.getConstructor().newInstance().version(), + new MockPluginClassLoader(null, new URL[0]) + ); + } + } + + /* Name here needs to be unique as we are testing the aliasing mechanism */ + public static class ConnectorPluginsResourceTestConnector extends Connector { + + private static final String TEST_STRING_CONFIG = "test.string.config"; + private static final String TEST_INT_CONFIG = "test.int.config"; + private static final String TEST_STRING_CONFIG_DEFAULT = "test.string.config.default"; + private static final String TEST_LIST_CONFIG = "test.list.config"; + private static final String GROUP = "Test"; + + private static final ConfigDef CONFIG_DEF = new ConfigDef() + .define(TEST_STRING_CONFIG, Type.STRING, Importance.HIGH, "Test configuration for string type.") + .define(TEST_INT_CONFIG, Type.INT, Importance.MEDIUM, "Test configuration for integer type.", GROUP, 1, Width.MEDIUM, TEST_INT_CONFIG, new IntegerRecommender()) + .define(TEST_STRING_CONFIG_DEFAULT, Type.STRING, "", Importance.LOW, "Test configuration with default value.") + .define(TEST_LIST_CONFIG, Type.LIST, Importance.HIGH, "Test configuration for list type.", GROUP, 2, Width.LONG, TEST_LIST_CONFIG, new ListRecommender()); + + @Override + public String version() { + return "1.0"; + } + + @Override + public void start(Map props) { + + } + + @Override + public Class taskClass() { + return null; + } + + @Override + public List> taskConfigs(int maxTasks) { + return null; + } + + @Override + public void stop() { + + } + + @Override + public ConfigDef config() { + return CONFIG_DEF; + } + } + + private static class IntegerRecommender implements Recommender { + + @Override + public List validValues(String name, Map parsedConfig) { + return asList(1, 2, 3); + } + + @Override + public boolean visible(String name, Map parsedConfig) { + return true; + } + } + + private static class ListRecommender implements Recommender { + @Override + public List validValues(String name, Map parsedConfig) { + return asList("a", "b", "c"); + } + + @Override + public boolean visible(String name, Map parsedConfig) { + return true; + } + } +} diff --git a/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/rest/resources/ConnectorsResourceTest.java b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/rest/resources/ConnectorsResourceTest.java new file mode 100644 index 0000000..3a419b8 --- /dev/null +++ b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/rest/resources/ConnectorsResourceTest.java @@ -0,0 +1,1078 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * 
contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.runtime.rest.resources; + +import com.fasterxml.jackson.core.type.TypeReference; + +import javax.crypto.Mac; +import javax.ws.rs.core.HttpHeaders; + +import com.fasterxml.jackson.databind.ObjectMapper; +import org.apache.kafka.connect.errors.AlreadyExistsException; +import org.apache.kafka.connect.errors.NotFoundException; +import org.apache.kafka.connect.runtime.AbstractStatus; +import org.apache.kafka.connect.runtime.ConnectorConfig; +import org.apache.kafka.connect.runtime.Herder; +import org.apache.kafka.connect.runtime.RestartRequest; +import org.apache.kafka.connect.runtime.WorkerConfig; +import org.apache.kafka.connect.runtime.distributed.NotAssignedException; +import org.apache.kafka.connect.runtime.distributed.NotLeaderException; +import org.apache.kafka.connect.runtime.distributed.RebalanceNeededException; +import org.apache.kafka.connect.runtime.rest.InternalRequestSignature; +import org.apache.kafka.connect.runtime.rest.RestClient; +import org.apache.kafka.connect.runtime.rest.entities.ActiveTopicsInfo; +import org.apache.kafka.connect.runtime.rest.entities.ConnectorInfo; +import org.apache.kafka.connect.runtime.rest.entities.ConnectorStateInfo; +import org.apache.kafka.connect.runtime.rest.entities.ConnectorType; +import org.apache.kafka.connect.runtime.rest.entities.CreateConnectorRequest; +import org.apache.kafka.connect.runtime.rest.entities.TaskInfo; +import org.apache.kafka.connect.runtime.rest.errors.ConnectRestException; +import org.apache.kafka.connect.util.Callback; +import org.apache.kafka.connect.util.ConnectorTaskId; +import org.easymock.Capture; +import org.easymock.EasyMock; +import org.easymock.IAnswer; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.powermock.api.easymock.PowerMock; +import org.powermock.api.easymock.annotation.Mock; +import org.powermock.core.classloader.annotations.PowerMockIgnore; +import org.powermock.core.classloader.annotations.PrepareForTest; +import org.powermock.modules.junit4.PowerMockRunner; + +import javax.ws.rs.BadRequestException; +import javax.ws.rs.core.MultivaluedHashMap; +import javax.ws.rs.core.MultivaluedMap; +import javax.ws.rs.core.Response; +import javax.ws.rs.core.UriInfo; + +import java.io.IOException; +import java.net.URI; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Base64; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import static org.apache.kafka.connect.runtime.WorkerConfig.TOPIC_TRACKING_ALLOW_RESET_CONFIG; +import static org.apache.kafka.connect.runtime.WorkerConfig.TOPIC_TRACKING_ENABLE_CONFIG; +import static 
org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; + +@RunWith(PowerMockRunner.class) +@PrepareForTest(RestClient.class) +@PowerMockIgnore({"javax.management.*", "javax.crypto.*"}) +@SuppressWarnings("unchecked") +public class ConnectorsResourceTest { + // Note trailing / and that we do *not* use LEADER_URL to construct our reference values. This checks that we handle + // URL construction properly, avoiding //, which will mess up routing in the REST server + private static final String LEADER_URL = "http://leader:8083/"; + private static final String CONNECTOR_NAME = "test"; + private static final String CONNECTOR_NAME_SPECIAL_CHARS = "ta/b&c=d//\\rx=1þ.1>< `'\" x%y+z!ሴ#$&'(æ)*+,:;=?ñ@[]ÿ"; + private static final String CONNECTOR_NAME_CONTROL_SEQUENCES1 = "ta/b&c=drx=1\n.1>< `'\" x%y+z!#$&'()*+,:;=?@[]"; + private static final String CONNECTOR2_NAME = "test2"; + private static final String CONNECTOR_NAME_ALL_WHITESPACES = " \t\n \b"; + private static final String CONNECTOR_NAME_PADDING_WHITESPACES = " " + CONNECTOR_NAME + " \n "; + private static final Boolean FORWARD = true; + private static final Map CONNECTOR_CONFIG_SPECIAL_CHARS = new HashMap<>(); + private static final HttpHeaders NULL_HEADERS = null; + static { + CONNECTOR_CONFIG_SPECIAL_CHARS.put("name", CONNECTOR_NAME_SPECIAL_CHARS); + CONNECTOR_CONFIG_SPECIAL_CHARS.put("sample_config", "test_config"); + } + + private static final Map CONNECTOR_CONFIG = new HashMap<>(); + static { + CONNECTOR_CONFIG.put("name", CONNECTOR_NAME); + CONNECTOR_CONFIG.put("sample_config", "test_config"); + } + + private static final Map CONNECTOR_CONFIG_CONTROL_SEQUENCES = new HashMap<>(); + static { + CONNECTOR_CONFIG_CONTROL_SEQUENCES.put("name", CONNECTOR_NAME_CONTROL_SEQUENCES1); + CONNECTOR_CONFIG_CONTROL_SEQUENCES.put("sample_config", "test_config"); + } + + private static final Map CONNECTOR_CONFIG_WITHOUT_NAME = new HashMap<>(); + static { + CONNECTOR_CONFIG_WITHOUT_NAME.put("sample_config", "test_config"); + } + + private static final Map CONNECTOR_CONFIG_WITH_EMPTY_NAME = new HashMap<>(); + + static { + CONNECTOR_CONFIG_WITH_EMPTY_NAME.put(ConnectorConfig.NAME_CONFIG, ""); + CONNECTOR_CONFIG_WITH_EMPTY_NAME.put("sample_config", "test_config"); + } + private static final List CONNECTOR_TASK_NAMES = Arrays.asList( + new ConnectorTaskId(CONNECTOR_NAME, 0), + new ConnectorTaskId(CONNECTOR_NAME, 1) + ); + private static final List> TASK_CONFIGS = new ArrayList<>(); + static { + TASK_CONFIGS.add(Collections.singletonMap("config", "value")); + TASK_CONFIGS.add(Collections.singletonMap("config", "other_value")); + } + private static final List TASK_INFOS = new ArrayList<>(); + static { + TASK_INFOS.add(new TaskInfo(new ConnectorTaskId(CONNECTOR_NAME, 0), TASK_CONFIGS.get(0))); + TASK_INFOS.add(new TaskInfo(new ConnectorTaskId(CONNECTOR_NAME, 1), TASK_CONFIGS.get(1))); + } + + private static final Set CONNECTOR_ACTIVE_TOPICS = new HashSet<>( + Arrays.asList("foo_topic", "bar_topic")); + private static final Set CONNECTOR2_ACTIVE_TOPICS = new HashSet<>( + Arrays.asList("foo_topic", "baz_topic")); + + @Mock + private Herder herder; + private ConnectorsResource connectorsResource; + private UriInfo forward; + @Mock + private WorkerConfig workerConfig; + + @Before + public void setUp() throws NoSuchMethodException { + PowerMock.mockStatic(RestClient.class, + RestClient.class.getMethod("httpRequest", String.class, String.class, HttpHeaders.class, Object.class, TypeReference.class, 
WorkerConfig.class)); + EasyMock.expect(workerConfig.getBoolean(TOPIC_TRACKING_ENABLE_CONFIG)).andReturn(true); + EasyMock.expect(workerConfig.getBoolean(TOPIC_TRACKING_ALLOW_RESET_CONFIG)).andReturn(true); + PowerMock.replay(workerConfig); + connectorsResource = new ConnectorsResource(herder, workerConfig); + forward = EasyMock.mock(UriInfo.class); + MultivaluedMap queryParams = new MultivaluedHashMap<>(); + queryParams.putSingle("forward", "true"); + EasyMock.expect(forward.getQueryParameters()).andReturn(queryParams).anyTimes(); + EasyMock.replay(forward); + } + + private static final Map getConnectorConfig(Map mapToClone) { + Map result = new HashMap<>(mapToClone); + return result; + } + + @Test + public void testListConnectors() throws Throwable { + final Capture>> cb = Capture.newInstance(); + EasyMock.expect(herder.connectors()).andReturn(Arrays.asList(CONNECTOR2_NAME, CONNECTOR_NAME)); + + PowerMock.replayAll(); + + Collection connectors = (Collection) connectorsResource.listConnectors(forward, NULL_HEADERS).getEntity(); + // Ordering isn't guaranteed, compare sets + assertEquals(new HashSet<>(Arrays.asList(CONNECTOR_NAME, CONNECTOR2_NAME)), new HashSet<>(connectors)); + + PowerMock.verifyAll(); + } + + @Test + public void testExpandConnectorsStatus() throws Throwable { + EasyMock.expect(herder.connectors()).andReturn(Arrays.asList(CONNECTOR2_NAME, CONNECTOR_NAME)); + ConnectorStateInfo connector = EasyMock.mock(ConnectorStateInfo.class); + ConnectorStateInfo connector2 = EasyMock.mock(ConnectorStateInfo.class); + EasyMock.expect(herder.connectorStatus(CONNECTOR2_NAME)).andReturn(connector2); + EasyMock.expect(herder.connectorStatus(CONNECTOR_NAME)).andReturn(connector); + + forward = EasyMock.mock(UriInfo.class); + MultivaluedMap queryParams = new MultivaluedHashMap<>(); + queryParams.putSingle("expand", "status"); + EasyMock.expect(forward.getQueryParameters()).andReturn(queryParams).anyTimes(); + EasyMock.replay(forward); + + PowerMock.replayAll(); + + Map> expanded = (Map>) connectorsResource.listConnectors(forward, NULL_HEADERS).getEntity(); + // Ordering isn't guaranteed, compare sets + assertEquals(new HashSet<>(Arrays.asList(CONNECTOR_NAME, CONNECTOR2_NAME)), expanded.keySet()); + assertEquals(connector2, expanded.get(CONNECTOR2_NAME).get("status")); + assertEquals(connector, expanded.get(CONNECTOR_NAME).get("status")); + PowerMock.verifyAll(); + } + + @Test + public void testExpandConnectorsInfo() throws Throwable { + EasyMock.expect(herder.connectors()).andReturn(Arrays.asList(CONNECTOR2_NAME, CONNECTOR_NAME)); + ConnectorInfo connector = EasyMock.mock(ConnectorInfo.class); + ConnectorInfo connector2 = EasyMock.mock(ConnectorInfo.class); + EasyMock.expect(herder.connectorInfo(CONNECTOR2_NAME)).andReturn(connector2); + EasyMock.expect(herder.connectorInfo(CONNECTOR_NAME)).andReturn(connector); + + forward = EasyMock.mock(UriInfo.class); + MultivaluedMap queryParams = new MultivaluedHashMap<>(); + queryParams.putSingle("expand", "info"); + EasyMock.expect(forward.getQueryParameters()).andReturn(queryParams).anyTimes(); + EasyMock.replay(forward); + + PowerMock.replayAll(); + + Map> expanded = (Map>) connectorsResource.listConnectors(forward, NULL_HEADERS).getEntity(); + // Ordering isn't guaranteed, compare sets + assertEquals(new HashSet<>(Arrays.asList(CONNECTOR_NAME, CONNECTOR2_NAME)), expanded.keySet()); + assertEquals(connector2, expanded.get(CONNECTOR2_NAME).get("info")); + assertEquals(connector, expanded.get(CONNECTOR_NAME).get("info")); + PowerMock.verifyAll(); + } 
+ + @Test + public void testFullExpandConnectors() throws Throwable { + EasyMock.expect(herder.connectors()).andReturn(Arrays.asList(CONNECTOR2_NAME, CONNECTOR_NAME)); + ConnectorInfo connectorInfo = EasyMock.mock(ConnectorInfo.class); + ConnectorInfo connectorInfo2 = EasyMock.mock(ConnectorInfo.class); + EasyMock.expect(herder.connectorInfo(CONNECTOR2_NAME)).andReturn(connectorInfo2); + EasyMock.expect(herder.connectorInfo(CONNECTOR_NAME)).andReturn(connectorInfo); + ConnectorStateInfo connector = EasyMock.mock(ConnectorStateInfo.class); + ConnectorStateInfo connector2 = EasyMock.mock(ConnectorStateInfo.class); + EasyMock.expect(herder.connectorStatus(CONNECTOR2_NAME)).andReturn(connector2); + EasyMock.expect(herder.connectorStatus(CONNECTOR_NAME)).andReturn(connector); + + forward = EasyMock.mock(UriInfo.class); + MultivaluedMap queryParams = new MultivaluedHashMap<>(); + queryParams.put("expand", Arrays.asList("info", "status")); + EasyMock.expect(forward.getQueryParameters()).andReturn(queryParams).anyTimes(); + EasyMock.replay(forward); + + PowerMock.replayAll(); + + Map> expanded = (Map>) connectorsResource.listConnectors(forward, NULL_HEADERS).getEntity(); + // Ordering isn't guaranteed, compare sets + assertEquals(new HashSet<>(Arrays.asList(CONNECTOR_NAME, CONNECTOR2_NAME)), expanded.keySet()); + assertEquals(connectorInfo2, expanded.get(CONNECTOR2_NAME).get("info")); + assertEquals(connectorInfo, expanded.get(CONNECTOR_NAME).get("info")); + assertEquals(connector2, expanded.get(CONNECTOR2_NAME).get("status")); + assertEquals(connector, expanded.get(CONNECTOR_NAME).get("status")); + PowerMock.verifyAll(); + } + + @Test + public void testExpandConnectorsWithConnectorNotFound() throws Throwable { + EasyMock.expect(herder.connectors()).andReturn(Arrays.asList(CONNECTOR2_NAME, CONNECTOR_NAME)); + ConnectorStateInfo connector = EasyMock.mock(ConnectorStateInfo.class); + ConnectorStateInfo connector2 = EasyMock.mock(ConnectorStateInfo.class); + EasyMock.expect(herder.connectorStatus(CONNECTOR2_NAME)).andReturn(connector2); + EasyMock.expect(herder.connectorStatus(CONNECTOR_NAME)).andThrow(EasyMock.mock(NotFoundException.class)); + + forward = EasyMock.mock(UriInfo.class); + MultivaluedMap queryParams = new MultivaluedHashMap<>(); + queryParams.putSingle("expand", "status"); + EasyMock.expect(forward.getQueryParameters()).andReturn(queryParams).anyTimes(); + EasyMock.replay(forward); + + PowerMock.replayAll(); + + Map> expanded = (Map>) connectorsResource.listConnectors(forward, NULL_HEADERS).getEntity(); + // Ordering isn't guaranteed, compare sets + assertEquals(Collections.singleton(CONNECTOR2_NAME), expanded.keySet()); + assertEquals(connector2, expanded.get(CONNECTOR2_NAME).get("status")); + PowerMock.verifyAll(); + } + + + @Test + public void testCreateConnector() throws Throwable { + CreateConnectorRequest body = new CreateConnectorRequest(CONNECTOR_NAME, Collections.singletonMap(ConnectorConfig.NAME_CONFIG, CONNECTOR_NAME)); + + final Capture>> cb = Capture.newInstance(); + herder.putConnectorConfig(EasyMock.eq(CONNECTOR_NAME), EasyMock.eq(body.config()), EasyMock.eq(false), EasyMock.capture(cb)); + expectAndCallbackResult(cb, new Herder.Created<>(true, new ConnectorInfo(CONNECTOR_NAME, CONNECTOR_CONFIG, + CONNECTOR_TASK_NAMES, ConnectorType.SOURCE))); + + PowerMock.replayAll(); + + connectorsResource.createConnector(FORWARD, NULL_HEADERS, body); + + PowerMock.verifyAll(); + } + + @Test + public void testCreateConnectorNotLeader() throws Throwable { + CreateConnectorRequest body 
= new CreateConnectorRequest(CONNECTOR_NAME, Collections.singletonMap(ConnectorConfig.NAME_CONFIG, CONNECTOR_NAME)); + + final Capture>> cb = Capture.newInstance(); + herder.putConnectorConfig(EasyMock.eq(CONNECTOR_NAME), EasyMock.eq(body.config()), EasyMock.eq(false), EasyMock.capture(cb)); + expectAndCallbackNotLeaderException(cb); + // Should forward request + EasyMock.expect(RestClient.httpRequest(EasyMock.eq("http://leader:8083/connectors?forward=false"), EasyMock.eq("POST"), EasyMock.isNull(), EasyMock.eq(body), EasyMock.anyObject(), EasyMock.anyObject(WorkerConfig.class))) + .andReturn(new RestClient.HttpResponse<>(201, new HashMap<>(), new ConnectorInfo(CONNECTOR_NAME, CONNECTOR_CONFIG, CONNECTOR_TASK_NAMES, + ConnectorType.SOURCE))); + + PowerMock.replayAll(); + + connectorsResource.createConnector(FORWARD, NULL_HEADERS, body); + + PowerMock.verifyAll(); + + + } + + @Test + public void testCreateConnectorWithHeaderAuthorization() throws Throwable { + CreateConnectorRequest body = new CreateConnectorRequest(CONNECTOR_NAME, Collections.singletonMap(ConnectorConfig.NAME_CONFIG, CONNECTOR_NAME)); + final Capture>> cb = Capture.newInstance(); + HttpHeaders httpHeaders = EasyMock.mock(HttpHeaders.class); + EasyMock.expect(httpHeaders.getHeaderString("Authorization")).andReturn("Basic YWxhZGRpbjpvcGVuc2VzYW1l").times(1); + EasyMock.replay(httpHeaders); + herder.putConnectorConfig(EasyMock.eq(CONNECTOR_NAME), EasyMock.eq(body.config()), EasyMock.eq(false), EasyMock.capture(cb)); + expectAndCallbackResult(cb, new Herder.Created<>(true, new ConnectorInfo(CONNECTOR_NAME, CONNECTOR_CONFIG, + CONNECTOR_TASK_NAMES, ConnectorType.SOURCE))); + + PowerMock.replayAll(); + + connectorsResource.createConnector(FORWARD, httpHeaders, body); + + PowerMock.verifyAll(); + } + + + + @Test + public void testCreateConnectorWithoutHeaderAuthorization() throws Throwable { + CreateConnectorRequest body = new CreateConnectorRequest(CONNECTOR_NAME, Collections.singletonMap(ConnectorConfig.NAME_CONFIG, CONNECTOR_NAME)); + final Capture>> cb = Capture.newInstance(); + HttpHeaders httpHeaders = EasyMock.mock(HttpHeaders.class); + EasyMock.expect(httpHeaders.getHeaderString("Authorization")).andReturn(null).times(1); + EasyMock.replay(httpHeaders); + herder.putConnectorConfig(EasyMock.eq(CONNECTOR_NAME), EasyMock.eq(body.config()), EasyMock.eq(false), EasyMock.capture(cb)); + expectAndCallbackResult(cb, new Herder.Created<>(true, new ConnectorInfo(CONNECTOR_NAME, CONNECTOR_CONFIG, + CONNECTOR_TASK_NAMES, ConnectorType.SOURCE))); + + PowerMock.replayAll(); + + connectorsResource.createConnector(FORWARD, httpHeaders, body); + + PowerMock.verifyAll(); + } + + @Test + public void testCreateConnectorExists() { + CreateConnectorRequest body = new CreateConnectorRequest(CONNECTOR_NAME, Collections.singletonMap(ConnectorConfig.NAME_CONFIG, CONNECTOR_NAME)); + + final Capture>> cb = Capture.newInstance(); + herder.putConnectorConfig(EasyMock.eq(CONNECTOR_NAME), EasyMock.eq(body.config()), EasyMock.eq(false), EasyMock.capture(cb)); + expectAndCallbackException(cb, new AlreadyExistsException("already exists")); + + PowerMock.replayAll(); + + assertThrows(AlreadyExistsException.class, () -> connectorsResource.createConnector(FORWARD, NULL_HEADERS, body)); + + PowerMock.verifyAll(); + } + + @Test + public void testCreateConnectorNameTrimWhitespaces() throws Throwable { + // Clone CONNECTOR_CONFIG_WITHOUT_NAME Map, as createConnector changes it (puts the name in it) and this + // will affect later tests + Map inputConfig = 
getConnectorConfig(CONNECTOR_CONFIG_WITHOUT_NAME); + final CreateConnectorRequest bodyIn = new CreateConnectorRequest(CONNECTOR_NAME_PADDING_WHITESPACES, inputConfig); + final CreateConnectorRequest bodyOut = new CreateConnectorRequest(CONNECTOR_NAME, CONNECTOR_CONFIG); + + final Capture>> cb = Capture.newInstance(); + herder.putConnectorConfig(EasyMock.eq(bodyOut.name()), EasyMock.eq(bodyOut.config()), EasyMock.eq(false), EasyMock.capture(cb)); + expectAndCallbackResult(cb, new Herder.Created<>(true, new ConnectorInfo(bodyOut.name(), bodyOut.config(), CONNECTOR_TASK_NAMES, ConnectorType.SOURCE))); + + PowerMock.replayAll(); + + connectorsResource.createConnector(FORWARD, NULL_HEADERS, bodyIn); + + PowerMock.verifyAll(); + } + + @Test + public void testCreateConnectorNameAllWhitespaces() throws Throwable { + // Clone CONNECTOR_CONFIG_WITHOUT_NAME Map, as createConnector changes it (puts the name in it) and this + // will affect later tests + Map inputConfig = getConnectorConfig(CONNECTOR_CONFIG_WITHOUT_NAME); + final CreateConnectorRequest bodyIn = new CreateConnectorRequest(CONNECTOR_NAME_ALL_WHITESPACES, inputConfig); + final CreateConnectorRequest bodyOut = new CreateConnectorRequest("", CONNECTOR_CONFIG_WITH_EMPTY_NAME); + + final Capture>> cb = Capture.newInstance(); + herder.putConnectorConfig(EasyMock.eq(bodyOut.name()), EasyMock.eq(bodyOut.config()), EasyMock.eq(false), EasyMock.capture(cb)); + expectAndCallbackResult(cb, new Herder.Created<>(true, new ConnectorInfo(bodyOut.name(), bodyOut.config(), CONNECTOR_TASK_NAMES, ConnectorType.SOURCE))); + + PowerMock.replayAll(); + + connectorsResource.createConnector(FORWARD, NULL_HEADERS, bodyIn); + + PowerMock.verifyAll(); + } + + @Test + public void testCreateConnectorNoName() throws Throwable { + // Clone CONNECTOR_CONFIG_WITHOUT_NAME Map, as createConnector changes it (puts the name in it) and this + // will affect later tests + Map inputConfig = getConnectorConfig(CONNECTOR_CONFIG_WITHOUT_NAME); + final CreateConnectorRequest bodyIn = new CreateConnectorRequest(null, inputConfig); + final CreateConnectorRequest bodyOut = new CreateConnectorRequest("", CONNECTOR_CONFIG_WITH_EMPTY_NAME); + + final Capture>> cb = Capture.newInstance(); + herder.putConnectorConfig(EasyMock.eq(bodyOut.name()), EasyMock.eq(bodyOut.config()), EasyMock.eq(false), EasyMock.capture(cb)); + expectAndCallbackResult(cb, new Herder.Created<>(true, new ConnectorInfo(bodyOut.name(), bodyOut.config(), CONNECTOR_TASK_NAMES, ConnectorType.SOURCE))); + + PowerMock.replayAll(); + + connectorsResource.createConnector(FORWARD, NULL_HEADERS, bodyIn); + + PowerMock.verifyAll(); + } + + @Test + public void testDeleteConnector() throws Throwable { + final Capture>> cb = Capture.newInstance(); + herder.deleteConnectorConfig(EasyMock.eq(CONNECTOR_NAME), EasyMock.capture(cb)); + expectAndCallbackResult(cb, null); + + PowerMock.replayAll(); + + connectorsResource.destroyConnector(CONNECTOR_NAME, NULL_HEADERS, FORWARD); + + PowerMock.verifyAll(); + } + + @Test + public void testDeleteConnectorNotLeader() throws Throwable { + final Capture>> cb = Capture.newInstance(); + herder.deleteConnectorConfig(EasyMock.eq(CONNECTOR_NAME), EasyMock.capture(cb)); + expectAndCallbackNotLeaderException(cb); + // Should forward request + EasyMock.expect(RestClient.httpRequest("http://leader:8083/connectors/" + CONNECTOR_NAME + "?forward=false", "DELETE", NULL_HEADERS, null, null, workerConfig)) + .andReturn(new RestClient.HttpResponse<>(204, new HashMap<>(), null)); + + PowerMock.replayAll(); + + 
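+ // The DELETE below should be transparently forwarded to the leader endpoint mocked above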
connectorsResource.destroyConnector(CONNECTOR_NAME, NULL_HEADERS, FORWARD); + + PowerMock.verifyAll(); + } + + // Not found exceptions should pass through to caller so they can be processed for 404s + @Test + public void testDeleteConnectorNotFound() throws Throwable { + final Capture>> cb = Capture.newInstance(); + herder.deleteConnectorConfig(EasyMock.eq(CONNECTOR_NAME), EasyMock.capture(cb)); + expectAndCallbackException(cb, new NotFoundException("not found")); + + PowerMock.replayAll(); + + assertThrows(NotFoundException.class, () -> connectorsResource.destroyConnector(CONNECTOR_NAME, NULL_HEADERS, FORWARD)); + + PowerMock.verifyAll(); + } + + @Test + public void testGetConnector() throws Throwable { + final Capture> cb = Capture.newInstance(); + herder.connectorInfo(EasyMock.eq(CONNECTOR_NAME), EasyMock.capture(cb)); + expectAndCallbackResult(cb, new ConnectorInfo(CONNECTOR_NAME, CONNECTOR_CONFIG, CONNECTOR_TASK_NAMES, + ConnectorType.SOURCE)); + + PowerMock.replayAll(); + + ConnectorInfo connInfo = connectorsResource.getConnector(CONNECTOR_NAME, NULL_HEADERS, FORWARD); + assertEquals(new ConnectorInfo(CONNECTOR_NAME, CONNECTOR_CONFIG, CONNECTOR_TASK_NAMES, ConnectorType.SOURCE), + connInfo); + + PowerMock.verifyAll(); + } + + @Test + public void testGetConnectorConfig() throws Throwable { + final Capture>> cb = Capture.newInstance(); + herder.connectorConfig(EasyMock.eq(CONNECTOR_NAME), EasyMock.capture(cb)); + expectAndCallbackResult(cb, CONNECTOR_CONFIG); + + PowerMock.replayAll(); + + Map connConfig = connectorsResource.getConnectorConfig(CONNECTOR_NAME, NULL_HEADERS, FORWARD); + assertEquals(CONNECTOR_CONFIG, connConfig); + + PowerMock.verifyAll(); + } + + @Test + public void testGetConnectorConfigConnectorNotFound() { + final Capture>> cb = Capture.newInstance(); + herder.connectorConfig(EasyMock.eq(CONNECTOR_NAME), EasyMock.capture(cb)); + expectAndCallbackException(cb, new NotFoundException("not found")); + + PowerMock.replayAll(); + + assertThrows(NotFoundException.class, () -> connectorsResource.getConnectorConfig(CONNECTOR_NAME, NULL_HEADERS, FORWARD)); + + PowerMock.verifyAll(); + } + + @Test + public void testGetTasksConfig() throws Throwable { + final ConnectorTaskId connectorTask0 = new ConnectorTaskId(CONNECTOR_NAME, 0); + final Map connectorTask0Configs = new HashMap<>(); + connectorTask0Configs.put("connector-task0-config0", "123"); + connectorTask0Configs.put("connector-task0-config1", "456"); + final ConnectorTaskId connectorTask1 = new ConnectorTaskId(CONNECTOR_NAME, 1); + final Map connectorTask1Configs = new HashMap<>(); + connectorTask0Configs.put("connector-task1-config0", "321"); + connectorTask0Configs.put("connector-task1-config1", "654"); + final ConnectorTaskId connector2Task0 = new ConnectorTaskId(CONNECTOR2_NAME, 0); + final Map connector2Task0Configs = Collections.singletonMap("connector2-task0-config0", "789"); + + final Map> expectedTasksConnector = new HashMap<>(); + expectedTasksConnector.put(connectorTask0, connectorTask0Configs); + expectedTasksConnector.put(connectorTask1, connectorTask1Configs); + final Map> expectedTasksConnector2 = new HashMap<>(); + expectedTasksConnector2.put(connector2Task0, connector2Task0Configs); + + final Capture>>> cb1 = Capture.newInstance(); + herder.tasksConfig(EasyMock.eq(CONNECTOR_NAME), EasyMock.capture(cb1)); + expectAndCallbackResult(cb1, expectedTasksConnector); + final Capture>>> cb2 = Capture.newInstance(); + herder.tasksConfig(EasyMock.eq(CONNECTOR2_NAME), EasyMock.capture(cb2)); + 
expectAndCallbackResult(cb2, expectedTasksConnector2); + + PowerMock.replayAll(); + + Map> tasksConfig = connectorsResource.getTasksConfig(CONNECTOR_NAME, NULL_HEADERS, FORWARD); + assertEquals(expectedTasksConnector, tasksConfig); + Map> tasksConfig2 = connectorsResource.getTasksConfig(CONNECTOR2_NAME, NULL_HEADERS, FORWARD); + assertEquals(expectedTasksConnector2, tasksConfig2); + + PowerMock.verifyAll(); + } + + @Test(expected = NotFoundException.class) + public void testGetTasksConfigConnectorNotFound() throws Throwable { + final Capture>>> cb = Capture.newInstance(); + herder.tasksConfig(EasyMock.eq(CONNECTOR_NAME), EasyMock.capture(cb)); + expectAndCallbackException(cb, new NotFoundException("not found")); + + PowerMock.replayAll(); + + connectorsResource.getTasksConfig(CONNECTOR_NAME, NULL_HEADERS, FORWARD); + + PowerMock.verifyAll(); + } + + @Test + public void testPutConnectorConfig() throws Throwable { + final Capture>> cb = Capture.newInstance(); + herder.putConnectorConfig(EasyMock.eq(CONNECTOR_NAME), EasyMock.eq(CONNECTOR_CONFIG), EasyMock.eq(true), EasyMock.capture(cb)); + expectAndCallbackResult(cb, new Herder.Created<>(false, new ConnectorInfo(CONNECTOR_NAME, CONNECTOR_CONFIG, CONNECTOR_TASK_NAMES, + ConnectorType.SINK))); + + PowerMock.replayAll(); + + connectorsResource.putConnectorConfig(CONNECTOR_NAME, NULL_HEADERS, FORWARD, CONNECTOR_CONFIG); + + PowerMock.verifyAll(); + } + + @Test + public void testCreateConnectorWithSpecialCharsInName() throws Throwable { + CreateConnectorRequest body = new CreateConnectorRequest(CONNECTOR_NAME_SPECIAL_CHARS, Collections.singletonMap(ConnectorConfig.NAME_CONFIG, CONNECTOR_NAME_SPECIAL_CHARS)); + + final Capture>> cb = Capture.newInstance(); + herder.putConnectorConfig(EasyMock.eq(CONNECTOR_NAME_SPECIAL_CHARS), EasyMock.eq(body.config()), EasyMock.eq(false), EasyMock.capture(cb)); + expectAndCallbackResult(cb, new Herder.Created<>(true, new ConnectorInfo(CONNECTOR_NAME_SPECIAL_CHARS, CONNECTOR_CONFIG, + CONNECTOR_TASK_NAMES, ConnectorType.SOURCE))); + + PowerMock.replayAll(); + + String rspLocation = connectorsResource.createConnector(FORWARD, NULL_HEADERS, body).getLocation().toString(); + String decoded = new URI(rspLocation).getPath(); + Assert.assertEquals("/connectors/" + CONNECTOR_NAME_SPECIAL_CHARS, decoded); + + PowerMock.verifyAll(); + } + + @Test + public void testCreateConnectorWithControlSequenceInName() throws Throwable { + CreateConnectorRequest body = new CreateConnectorRequest(CONNECTOR_NAME_CONTROL_SEQUENCES1, Collections.singletonMap(ConnectorConfig.NAME_CONFIG, CONNECTOR_NAME_CONTROL_SEQUENCES1)); + + final Capture>> cb = Capture.newInstance(); + herder.putConnectorConfig(EasyMock.eq(CONNECTOR_NAME_CONTROL_SEQUENCES1), EasyMock.eq(body.config()), EasyMock.eq(false), EasyMock.capture(cb)); + expectAndCallbackResult(cb, new Herder.Created<>(true, new ConnectorInfo(CONNECTOR_NAME_CONTROL_SEQUENCES1, CONNECTOR_CONFIG, + CONNECTOR_TASK_NAMES, ConnectorType.SOURCE))); + + PowerMock.replayAll(); + + String rspLocation = connectorsResource.createConnector(FORWARD, NULL_HEADERS, body).getLocation().toString(); + String decoded = new URI(rspLocation).getPath(); + Assert.assertEquals("/connectors/" + CONNECTOR_NAME_CONTROL_SEQUENCES1, decoded); + + PowerMock.verifyAll(); + } + + @Test + public void testPutConnectorConfigWithSpecialCharsInName() throws Throwable { + final Capture>> cb = Capture.newInstance(); + + herder.putConnectorConfig(EasyMock.eq(CONNECTOR_NAME_SPECIAL_CHARS), EasyMock.eq(CONNECTOR_CONFIG_SPECIAL_CHARS), 
EasyMock.eq(true), EasyMock.capture(cb)); + expectAndCallbackResult(cb, new Herder.Created<>(true, new ConnectorInfo(CONNECTOR_NAME_SPECIAL_CHARS, CONNECTOR_CONFIG_SPECIAL_CHARS, CONNECTOR_TASK_NAMES, + ConnectorType.SINK))); + + PowerMock.replayAll(); + + String rspLocation = connectorsResource.putConnectorConfig(CONNECTOR_NAME_SPECIAL_CHARS, NULL_HEADERS, FORWARD, CONNECTOR_CONFIG_SPECIAL_CHARS).getLocation().toString(); + String decoded = new URI(rspLocation).getPath(); + Assert.assertEquals("/connectors/" + CONNECTOR_NAME_SPECIAL_CHARS, decoded); + + PowerMock.verifyAll(); + } + + @Test + public void testPutConnectorConfigWithControlSequenceInName() throws Throwable { + final Capture>> cb = Capture.newInstance(); + + herder.putConnectorConfig(EasyMock.eq(CONNECTOR_NAME_CONTROL_SEQUENCES1), EasyMock.eq(CONNECTOR_CONFIG_CONTROL_SEQUENCES), EasyMock.eq(true), EasyMock.capture(cb)); + expectAndCallbackResult(cb, new Herder.Created<>(true, new ConnectorInfo(CONNECTOR_NAME_CONTROL_SEQUENCES1, CONNECTOR_CONFIG_CONTROL_SEQUENCES, CONNECTOR_TASK_NAMES, + ConnectorType.SINK))); + + PowerMock.replayAll(); + + String rspLocation = connectorsResource.putConnectorConfig(CONNECTOR_NAME_CONTROL_SEQUENCES1, NULL_HEADERS, FORWARD, CONNECTOR_CONFIG_CONTROL_SEQUENCES).getLocation().toString(); + String decoded = new URI(rspLocation).getPath(); + Assert.assertEquals("/connectors/" + CONNECTOR_NAME_CONTROL_SEQUENCES1, decoded); + + PowerMock.verifyAll(); + } + + @Test + public void testPutConnectorConfigNameMismatch() { + Map connConfig = new HashMap<>(CONNECTOR_CONFIG); + connConfig.put(ConnectorConfig.NAME_CONFIG, "mismatched-name"); + assertThrows(BadRequestException.class, () -> connectorsResource.putConnectorConfig(CONNECTOR_NAME, + NULL_HEADERS, FORWARD, connConfig)); + } + + @Test + public void testCreateConnectorConfigNameMismatch() { + Map connConfig = new HashMap<>(); + connConfig.put(ConnectorConfig.NAME_CONFIG, "mismatched-name"); + CreateConnectorRequest request = new CreateConnectorRequest(CONNECTOR_NAME, connConfig); + assertThrows(BadRequestException.class, () -> connectorsResource.createConnector(FORWARD, NULL_HEADERS, request)); + } + + @Test + public void testGetConnectorTaskConfigs() throws Throwable { + final Capture>> cb = Capture.newInstance(); + herder.taskConfigs(EasyMock.eq(CONNECTOR_NAME), EasyMock.capture(cb)); + expectAndCallbackResult(cb, TASK_INFOS); + + PowerMock.replayAll(); + + List taskInfos = connectorsResource.getTaskConfigs(CONNECTOR_NAME, NULL_HEADERS, FORWARD); + assertEquals(TASK_INFOS, taskInfos); + + PowerMock.verifyAll(); + } + + @Test + public void testGetConnectorTaskConfigsConnectorNotFound() throws Throwable { + final Capture>> cb = Capture.newInstance(); + herder.taskConfigs(EasyMock.eq(CONNECTOR_NAME), EasyMock.capture(cb)); + expectAndCallbackException(cb, new NotFoundException("connector not found")); + + PowerMock.replayAll(); + + assertThrows(NotFoundException.class, () -> connectorsResource.getTaskConfigs(CONNECTOR_NAME, NULL_HEADERS, FORWARD)); + + PowerMock.verifyAll(); + } + + @Test + public void testPutConnectorTaskConfigsNoInternalRequestSignature() throws Throwable { + final Capture> cb = Capture.newInstance(); + herder.putTaskConfigs( + EasyMock.eq(CONNECTOR_NAME), + EasyMock.eq(TASK_CONFIGS), + EasyMock.capture(cb), + EasyMock.anyObject(InternalRequestSignature.class) + ); + expectAndCallbackResult(cb, null); + + PowerMock.replayAll(); + + connectorsResource.putTaskConfigs(CONNECTOR_NAME, NULL_HEADERS, FORWARD, serializeAsBytes(TASK_CONFIGS)); + 
+ PowerMock.verifyAll(); + } + + @Test + public void testPutConnectorTaskConfigsWithInternalRequestSignature() throws Throwable { + final String signatureAlgorithm = "HmacSHA256"; + final String encodedSignature = "Kv1/OSsxzdVIwvZ4e30avyRIVrngDfhzVUm/kAZEKc4="; + + final Capture> cb = Capture.newInstance(); + final Capture signatureCapture = Capture.newInstance(); + herder.putTaskConfigs( + EasyMock.eq(CONNECTOR_NAME), + EasyMock.eq(TASK_CONFIGS), + EasyMock.capture(cb), + EasyMock.capture(signatureCapture) + ); + expectAndCallbackResult(cb, null); + + HttpHeaders headers = EasyMock.mock(HttpHeaders.class); + EasyMock.expect(headers.getHeaderString(InternalRequestSignature.SIGNATURE_ALGORITHM_HEADER)) + .andReturn(signatureAlgorithm) + .once(); + EasyMock.expect(headers.getHeaderString(InternalRequestSignature.SIGNATURE_HEADER)) + .andReturn(encodedSignature) + .once(); + + PowerMock.replayAll(headers); + + connectorsResource.putTaskConfigs(CONNECTOR_NAME, headers, FORWARD, serializeAsBytes(TASK_CONFIGS)); + + PowerMock.verifyAll(); + InternalRequestSignature expectedSignature = new InternalRequestSignature( + serializeAsBytes(TASK_CONFIGS), + Mac.getInstance(signatureAlgorithm), + Base64.getDecoder().decode(encodedSignature) + ); + assertEquals( + expectedSignature, + signatureCapture.getValue() + ); + } + + @Test + public void testPutConnectorTaskConfigsConnectorNotFound() throws Throwable { + final Capture> cb = Capture.newInstance(); + herder.putTaskConfigs( + EasyMock.eq(CONNECTOR_NAME), + EasyMock.eq(TASK_CONFIGS), + EasyMock.capture(cb), + EasyMock.anyObject(InternalRequestSignature.class) + ); + expectAndCallbackException(cb, new NotFoundException("not found")); + + PowerMock.replayAll(); + + assertThrows(NotFoundException.class, () -> connectorsResource.putTaskConfigs(CONNECTOR_NAME, NULL_HEADERS, + FORWARD, serializeAsBytes(TASK_CONFIGS))); + + PowerMock.verifyAll(); + } + + @Test + public void testRestartConnectorAndTasksConnectorNotFound() { + RestartRequest restartRequest = new RestartRequest(CONNECTOR_NAME, true, false); + final Capture> cb = Capture.newInstance(); + herder.restartConnectorAndTasks(EasyMock.eq(restartRequest), EasyMock.capture(cb)); + expectAndCallbackException(cb, new NotFoundException("not found")); + + PowerMock.replayAll(); + + assertThrows(NotFoundException.class, () -> + connectorsResource.restartConnector(CONNECTOR_NAME, NULL_HEADERS, restartRequest.includeTasks(), restartRequest.onlyFailed(), FORWARD) + ); + + PowerMock.verifyAll(); + } + + @Test + public void testRestartConnectorAndTasksLeaderRedirect() throws Throwable { + RestartRequest restartRequest = new RestartRequest(CONNECTOR_NAME, true, false); + final Capture> cb = Capture.newInstance(); + herder.restartConnectorAndTasks(EasyMock.eq(restartRequest), EasyMock.capture(cb)); + expectAndCallbackNotLeaderException(cb); + + EasyMock.expect(RestClient.httpRequest(EasyMock.eq("http://leader:8083/connectors/" + CONNECTOR_NAME + "/restart?forward=true&includeTasks=" + restartRequest.includeTasks() + "&onlyFailed=" + restartRequest.onlyFailed()), + EasyMock.eq("POST"), EasyMock.isNull(), EasyMock.isNull(), EasyMock.anyObject(), EasyMock.anyObject(WorkerConfig.class))) + .andReturn(new RestClient.HttpResponse<>(202, new HashMap<>(), null)); + + PowerMock.replayAll(); + + Response response = connectorsResource.restartConnector(CONNECTOR_NAME, NULL_HEADERS, restartRequest.includeTasks(), restartRequest.onlyFailed(), null); + assertEquals(Response.Status.ACCEPTED.getStatusCode(), response.getStatus()); + 
+ PowerMock.verifyAll(); + } + + @Test + public void testRestartConnectorAndTasksRebalanceNeeded() { + RestartRequest restartRequest = new RestartRequest(CONNECTOR_NAME, true, false); + final Capture> cb = Capture.newInstance(); + herder.restartConnectorAndTasks(EasyMock.eq(restartRequest), EasyMock.capture(cb)); + expectAndCallbackException(cb, new RebalanceNeededException("Request cannot be completed because a rebalance is expected")); + + PowerMock.replayAll(); + + ConnectRestException ex = assertThrows(ConnectRestException.class, () -> + connectorsResource.restartConnector(CONNECTOR_NAME, NULL_HEADERS, restartRequest.includeTasks(), restartRequest.onlyFailed(), FORWARD) + ); + assertEquals(Response.Status.CONFLICT.getStatusCode(), ex.statusCode()); + PowerMock.verifyAll(); + } + + @Test + public void testRestartConnectorAndTasksRequestAccepted() throws Throwable { + ConnectorStateInfo.ConnectorState state = new ConnectorStateInfo.ConnectorState( + AbstractStatus.State.RESTARTING.name(), + "foo", + null + ); + ConnectorStateInfo connectorStateInfo = new ConnectorStateInfo(CONNECTOR_NAME, state, Collections.emptyList(), ConnectorType.SOURCE); + + RestartRequest restartRequest = new RestartRequest(CONNECTOR_NAME, true, false); + final Capture> cb = Capture.newInstance(); + herder.restartConnectorAndTasks(EasyMock.eq(restartRequest), EasyMock.capture(cb)); + expectAndCallbackResult(cb, connectorStateInfo); + + PowerMock.replayAll(); + + Response response = connectorsResource.restartConnector(CONNECTOR_NAME, NULL_HEADERS, restartRequest.includeTasks(), restartRequest.onlyFailed(), FORWARD); + assertEquals(CONNECTOR_NAME, ((ConnectorStateInfo) response.getEntity()).name()); + assertEquals(state.state(), ((ConnectorStateInfo) response.getEntity()).connector().state()); + assertEquals(Response.Status.ACCEPTED.getStatusCode(), response.getStatus()); + PowerMock.verifyAll(); + } + + @Test + public void testRestartConnectorNotFound() { + final Capture> cb = Capture.newInstance(); + herder.restartConnector(EasyMock.eq(CONNECTOR_NAME), EasyMock.capture(cb)); + expectAndCallbackException(cb, new NotFoundException("not found")); + + PowerMock.replayAll(); + + assertThrows(NotFoundException.class, () -> + connectorsResource.restartConnector(CONNECTOR_NAME, NULL_HEADERS, false, false, FORWARD) + ); + + PowerMock.verifyAll(); + } + + @Test + public void testRestartConnectorLeaderRedirect() throws Throwable { + final Capture> cb = Capture.newInstance(); + herder.restartConnector(EasyMock.eq(CONNECTOR_NAME), EasyMock.capture(cb)); + expectAndCallbackNotLeaderException(cb); + + EasyMock.expect(RestClient.httpRequest(EasyMock.eq("http://leader:8083/connectors/" + CONNECTOR_NAME + "/restart?forward=true"), + EasyMock.eq("POST"), EasyMock.isNull(), EasyMock.isNull(), EasyMock.anyObject(), EasyMock.anyObject(WorkerConfig.class))) + .andReturn(new RestClient.HttpResponse<>(202, new HashMap<>(), null)); + + PowerMock.replayAll(); + + Response response = connectorsResource.restartConnector(CONNECTOR_NAME, NULL_HEADERS, false, false, null); + assertEquals(Response.Status.NO_CONTENT.getStatusCode(), response.getStatus()); + PowerMock.verifyAll(); + + PowerMock.verifyAll(); + } + + @Test + public void testRestartConnectorOwnerRedirect() throws Throwable { + final Capture> cb = Capture.newInstance(); + herder.restartConnector(EasyMock.eq(CONNECTOR_NAME), EasyMock.capture(cb)); + String ownerUrl = "http://owner:8083"; + expectAndCallbackException(cb, new NotAssignedException("not owner test", ownerUrl)); + + 
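+ // NotAssignedException should cause the restart to be re-routed to the owning worker rather than the leader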
EasyMock.expect(RestClient.httpRequest(EasyMock.eq("http://owner:8083/connectors/" + CONNECTOR_NAME + "/restart?forward=false"), + EasyMock.eq("POST"), EasyMock.isNull(), EasyMock.isNull(), EasyMock.anyObject(), EasyMock.anyObject(WorkerConfig.class))) + .andReturn(new RestClient.HttpResponse<>(202, new HashMap<>(), null)); + + PowerMock.replayAll(); + + Response response = connectorsResource.restartConnector(CONNECTOR_NAME, NULL_HEADERS, false, false, true); + assertEquals(Response.Status.NO_CONTENT.getStatusCode(), response.getStatus()); + PowerMock.verifyAll(); + + PowerMock.verifyAll(); + } + + @Test + public void testRestartTaskNotFound() throws Throwable { + ConnectorTaskId taskId = new ConnectorTaskId(CONNECTOR_NAME, 0); + final Capture> cb = Capture.newInstance(); + herder.restartTask(EasyMock.eq(taskId), EasyMock.capture(cb)); + expectAndCallbackException(cb, new NotFoundException("not found")); + + PowerMock.replayAll(); + + assertThrows(NotFoundException.class, () -> connectorsResource.restartTask(CONNECTOR_NAME, 0, NULL_HEADERS, FORWARD)); + + PowerMock.verifyAll(); + } + + @Test + public void testRestartTaskLeaderRedirect() throws Throwable { + ConnectorTaskId taskId = new ConnectorTaskId(CONNECTOR_NAME, 0); + + final Capture> cb = Capture.newInstance(); + herder.restartTask(EasyMock.eq(taskId), EasyMock.capture(cb)); + expectAndCallbackNotLeaderException(cb); + + EasyMock.expect(RestClient.httpRequest(EasyMock.eq("http://leader:8083/connectors/" + CONNECTOR_NAME + "/tasks/0/restart?forward=true"), + EasyMock.eq("POST"), EasyMock.isNull(), EasyMock.isNull(), EasyMock.anyObject(), EasyMock.anyObject(WorkerConfig.class))) + .andReturn(new RestClient.HttpResponse<>(202, new HashMap<>(), null)); + + PowerMock.replayAll(); + + connectorsResource.restartTask(CONNECTOR_NAME, 0, NULL_HEADERS, null); + + PowerMock.verifyAll(); + } + + @Test + public void testRestartTaskOwnerRedirect() throws Throwable { + ConnectorTaskId taskId = new ConnectorTaskId(CONNECTOR_NAME, 0); + + final Capture> cb = Capture.newInstance(); + herder.restartTask(EasyMock.eq(taskId), EasyMock.capture(cb)); + String ownerUrl = "http://owner:8083"; + expectAndCallbackException(cb, new NotAssignedException("not owner test", ownerUrl)); + + EasyMock.expect(RestClient.httpRequest(EasyMock.eq("http://owner:8083/connectors/" + CONNECTOR_NAME + "/tasks/0/restart?forward=false"), + EasyMock.eq("POST"), EasyMock.isNull(), EasyMock.isNull(), EasyMock.anyObject(), EasyMock.anyObject(WorkerConfig.class))) + .andReturn(new RestClient.HttpResponse<>(202, new HashMap<>(), null)); + + PowerMock.replayAll(); + + connectorsResource.restartTask(CONNECTOR_NAME, 0, NULL_HEADERS, true); + + PowerMock.verifyAll(); + } + + @Test + public void testConnectorActiveTopicsWithTopicTrackingDisabled() { + PowerMock.reset(workerConfig); + EasyMock.expect(workerConfig.getBoolean(TOPIC_TRACKING_ENABLE_CONFIG)).andReturn(false); + EasyMock.expect(workerConfig.getBoolean(TOPIC_TRACKING_ALLOW_RESET_CONFIG)).andReturn(false); + PowerMock.replay(workerConfig); + connectorsResource = new ConnectorsResource(herder, workerConfig); + PowerMock.replayAll(); + + Exception e = assertThrows(ConnectRestException.class, + () -> connectorsResource.getConnectorActiveTopics(CONNECTOR_NAME)); + assertEquals("Topic tracking is disabled.", e.getMessage()); + PowerMock.verifyAll(); + } + + @Test + public void testResetConnectorActiveTopicsWithTopicTrackingDisabled() { + PowerMock.reset(workerConfig); + 
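+ // Re-stub the worker config with topic tracking disabled but reset allowed, so the reset call must fail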
EasyMock.expect(workerConfig.getBoolean(TOPIC_TRACKING_ENABLE_CONFIG)).andReturn(false); + EasyMock.expect(workerConfig.getBoolean(TOPIC_TRACKING_ALLOW_RESET_CONFIG)).andReturn(true); + HttpHeaders headers = EasyMock.mock(HttpHeaders.class); + PowerMock.replay(workerConfig); + connectorsResource = new ConnectorsResource(herder, workerConfig); + PowerMock.replayAll(); + + Exception e = assertThrows(ConnectRestException.class, + () -> connectorsResource.resetConnectorActiveTopics(CONNECTOR_NAME, headers)); + assertEquals("Topic tracking is disabled.", e.getMessage()); + PowerMock.verifyAll(); + } + + @Test + public void testResetConnectorActiveTopicsWithTopicTrackingEnabled() { + PowerMock.reset(workerConfig); + EasyMock.expect(workerConfig.getBoolean(TOPIC_TRACKING_ENABLE_CONFIG)).andReturn(true); + EasyMock.expect(workerConfig.getBoolean(TOPIC_TRACKING_ALLOW_RESET_CONFIG)).andReturn(false); + HttpHeaders headers = EasyMock.mock(HttpHeaders.class); + PowerMock.replay(workerConfig); + connectorsResource = new ConnectorsResource(herder, workerConfig); + PowerMock.replayAll(); + + Exception e = assertThrows(ConnectRestException.class, + () -> connectorsResource.resetConnectorActiveTopics(CONNECTOR_NAME, headers)); + assertEquals("Topic tracking reset is disabled.", e.getMessage()); + PowerMock.verifyAll(); + } + + @Test + public void testConnectorActiveTopics() { + PowerMock.reset(workerConfig); + EasyMock.expect(workerConfig.getBoolean(TOPIC_TRACKING_ENABLE_CONFIG)).andReturn(true); + EasyMock.expect(workerConfig.getBoolean(TOPIC_TRACKING_ALLOW_RESET_CONFIG)).andReturn(true); + EasyMock.expect(herder.connectorActiveTopics(CONNECTOR_NAME)) + .andReturn(new ActiveTopicsInfo(CONNECTOR_NAME, CONNECTOR_ACTIVE_TOPICS)); + PowerMock.replay(workerConfig); + connectorsResource = new ConnectorsResource(herder, workerConfig); + PowerMock.replayAll(); + + Response response = connectorsResource.getConnectorActiveTopics(CONNECTOR_NAME); + assertEquals(Response.Status.OK.getStatusCode(), response.getStatus()); + Map> body = (Map>) response.getEntity(); + assertEquals(CONNECTOR_NAME, ((ActiveTopicsInfo) body.get(CONNECTOR_NAME)).connector()); + assertEquals(new HashSet<>(CONNECTOR_ACTIVE_TOPICS), + ((ActiveTopicsInfo) body.get(CONNECTOR_NAME)).topics()); + PowerMock.verifyAll(); + } + + @Test + public void testResetConnectorActiveTopics() { + PowerMock.reset(workerConfig); + EasyMock.expect(workerConfig.getBoolean(TOPIC_TRACKING_ENABLE_CONFIG)).andReturn(true); + EasyMock.expect(workerConfig.getBoolean(TOPIC_TRACKING_ALLOW_RESET_CONFIG)).andReturn(true); + HttpHeaders headers = EasyMock.mock(HttpHeaders.class); + herder.resetConnectorActiveTopics(CONNECTOR_NAME); + EasyMock.expectLastCall(); + PowerMock.replay(workerConfig); + connectorsResource = new ConnectorsResource(herder, workerConfig); + PowerMock.replayAll(); + + Response response = connectorsResource.resetConnectorActiveTopics(CONNECTOR_NAME, headers); + assertEquals(Response.Status.ACCEPTED.getStatusCode(), response.getStatus()); + PowerMock.verifyAll(); + } + + @Test + public void testCompleteOrForwardWithErrorAndNoForwardUrl() throws Throwable { + final Capture>> cb = Capture.newInstance(); + herder.deleteConnectorConfig(EasyMock.eq(CONNECTOR_NAME), EasyMock.capture(cb)); + String leaderUrl = null; + expectAndCallbackException(cb, new NotLeaderException("not leader", leaderUrl)); + + PowerMock.replayAll(); + + ConnectRestException e = assertThrows(ConnectRestException.class, () -> + connectorsResource.destroyConnector(CONNECTOR_NAME, 
NULL_HEADERS, FORWARD)); + assertTrue(e.getMessage().contains("no known leader URL")); + PowerMock.verifyAll(); + } + + private <T> byte[] serializeAsBytes(final T value) throws IOException { + return new ObjectMapper().writeValueAsBytes(value); + } + + private <T> void expectAndCallbackResult(final Capture<Callback<T>> cb, final T value) { + PowerMock.expectLastCall().andAnswer(() -> { + cb.getValue().onCompletion(null, value); + return null; + }); + } + + private <T> void expectAndCallbackException(final Capture<Callback<T>> cb, final Throwable t) { + PowerMock.expectLastCall().andAnswer((IAnswer<Void>) () -> { + cb.getValue().onCompletion(t, null); + return null; + }); + } + + private <T> void expectAndCallbackNotLeaderException(final Capture<Callback<T>> cb) { + expectAndCallbackException(cb, new NotLeaderException("not leader test", LEADER_URL)); + } +} diff --git a/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/rest/resources/LoggingResourceTest.java b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/rest/resources/LoggingResourceTest.java new file mode 100644 index 0000000..63814cd --- /dev/null +++ b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/rest/resources/LoggingResourceTest.java @@ -0,0 +1,196 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ +package org.apache.kafka.connect.runtime.rest.resources; + +import org.apache.kafka.connect.errors.NotFoundException; +import org.apache.kafka.connect.runtime.rest.errors.BadRequestException; +import org.apache.log4j.Hierarchy; +import org.apache.log4j.Level; +import org.apache.log4j.Logger; +import org.junit.Test; + +import java.util.Arrays; +import java.util.Collections; +import java.util.Enumeration; +import java.util.List; +import java.util.Map; +import java.util.Vector; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThrows; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +@SuppressWarnings("unchecked") +public class LoggingResourceTest { + + @Test + public void getLoggersIgnoresNullLevelsTest() { + LoggingResource loggingResource = mock(LoggingResource.class); + Logger root = new Logger("root") { + }; + Logger a = new Logger("a") { + }; + a.setLevel(null); + Logger b = new Logger("b") { + }; + b.setLevel(Level.INFO); + when(loggingResource.currentLoggers()).thenReturn(loggers(a, b)); + when(loggingResource.rootLogger()).thenReturn(root); + when(loggingResource.listLoggers()).thenCallRealMethod(); + Map> loggers = (Map>) loggingResource.listLoggers().getEntity(); + assertEquals(1, loggers.size()); + assertEquals("INFO", loggers.get("b").get("level")); + } + + @Test + public void getLoggerFallsbackToEffectiveLogLevelTest() { + LoggingResource loggingResource = mock(LoggingResource.class); + Logger root = new Logger("root") { + }; + root.setLevel(Level.ERROR); + Hierarchy hierarchy = new Hierarchy(root); + Logger a = hierarchy.getLogger("a"); + a.setLevel(null); + Logger b = hierarchy.getLogger("b"); + b.setLevel(Level.INFO); + when(loggingResource.currentLoggers()).thenReturn(loggers(a, b)); + when(loggingResource.rootLogger()).thenReturn(root); + when(loggingResource.getLogger(any())).thenCallRealMethod(); + Map level = (Map) loggingResource.getLogger("a").getEntity(); + assertEquals(1, level.size()); + assertEquals("ERROR", level.get("level")); + } + + @Test + public void getUnknownLoggerTest() { + LoggingResource loggingResource = mock(LoggingResource.class); + Logger root = new Logger("root") { + }; + root.setLevel(Level.ERROR); + Hierarchy hierarchy = new Hierarchy(root); + Logger a = hierarchy.getLogger("a"); + a.setLevel(null); + Logger b = hierarchy.getLogger("b"); + b.setLevel(Level.INFO); + when(loggingResource.currentLoggers()).thenReturn(loggers(a, b)); + when(loggingResource.rootLogger()).thenReturn(root); + when(loggingResource.getLogger(any())).thenCallRealMethod(); + assertThrows(NotFoundException.class, () -> loggingResource.getLogger("c")); + } + + @Test + public void setLevelTest() { + LoggingResource loggingResource = mock(LoggingResource.class); + Logger root = new Logger("root") { + }; + root.setLevel(Level.ERROR); + Hierarchy hierarchy = new Hierarchy(root); + Logger p = hierarchy.getLogger("a.b.c.p"); + Logger x = hierarchy.getLogger("a.b.c.p.X"); + Logger y = hierarchy.getLogger("a.b.c.p.Y"); + Logger z = hierarchy.getLogger("a.b.c.p.Z"); + Logger w = hierarchy.getLogger("a.b.c.s.W"); + x.setLevel(Level.INFO); + y.setLevel(Level.INFO); + z.setLevel(Level.INFO); + w.setLevel(Level.INFO); + when(loggingResource.currentLoggers()).thenReturn(loggers(x, y, z, w)); + when(loggingResource.lookupLogger("a.b.c.p")).thenReturn(p); + when(loggingResource.rootLogger()).thenReturn(root); + 
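+ // Only setLevel() runs its real implementation; the logger lookups above remain stubbed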
when(loggingResource.setLevel(any(), any())).thenCallRealMethod(); + List<String> modified = (List<String>) loggingResource.setLevel("a.b.c.p", Collections.singletonMap("level", "DEBUG")).getEntity(); + assertEquals(4, modified.size()); + assertEquals(Arrays.asList("a.b.c.p", "a.b.c.p.X", "a.b.c.p.Y", "a.b.c.p.Z"), modified); + assertEquals(p.getLevel(), Level.DEBUG); + assertEquals(x.getLevel(), Level.DEBUG); + assertEquals(y.getLevel(), Level.DEBUG); + assertEquals(z.getLevel(), Level.DEBUG); + } + + @Test + public void setRootLevelTest() { + LoggingResource loggingResource = mock(LoggingResource.class); + Logger root = new Logger("root") { + }; + root.setLevel(Level.ERROR); + Hierarchy hierarchy = new Hierarchy(root); + Logger p = hierarchy.getLogger("a.b.c.p"); + Logger x = hierarchy.getLogger("a.b.c.p.X"); + Logger y = hierarchy.getLogger("a.b.c.p.Y"); + Logger z = hierarchy.getLogger("a.b.c.p.Z"); + Logger w = hierarchy.getLogger("a.b.c.s.W"); + x.setLevel(Level.INFO); + y.setLevel(Level.INFO); + z.setLevel(Level.INFO); + w.setLevel(Level.INFO); + when(loggingResource.currentLoggers()).thenReturn(loggers(x, y, z, w)); + when(loggingResource.lookupLogger("a.b.c.p")).thenReturn(p); + when(loggingResource.rootLogger()).thenReturn(root); + when(loggingResource.setLevel(any(), any())).thenCallRealMethod(); + List<String> modified = (List<String>) loggingResource.setLevel("root", Collections.singletonMap("level", "DEBUG")).getEntity(); + assertEquals(5, modified.size()); + assertEquals(Arrays.asList("a.b.c.p.X", "a.b.c.p.Y", "a.b.c.p.Z", "a.b.c.s.W", "root"), modified); + assertNull(p.getLevel()); + assertEquals(root.getLevel(), Level.DEBUG); + assertEquals(w.getLevel(), Level.DEBUG); + assertEquals(x.getLevel(), Level.DEBUG); + assertEquals(y.getLevel(), Level.DEBUG); + assertEquals(z.getLevel(), Level.DEBUG); + } + + @Test + public void setLevelWithEmptyArgTest() { + LoggingResource loggingResource = mock(LoggingResource.class); + Logger root = new Logger("root") { + }; + root.setLevel(Level.ERROR); + Hierarchy hierarchy = new Hierarchy(root); + Logger a = hierarchy.getLogger("a"); + a.setLevel(null); + Logger b = hierarchy.getLogger("b"); + b.setLevel(Level.INFO); + when(loggingResource.currentLoggers()).thenReturn(loggers(a, b)); + when(loggingResource.rootLogger()).thenReturn(root); + when(loggingResource.setLevel(any(), any())).thenCallRealMethod(); + assertThrows(BadRequestException.class, () -> loggingResource.setLevel("@root", Collections.emptyMap())); + } + + @Test + public void setLevelWithInvalidArgTest() { + LoggingResource loggingResource = mock(LoggingResource.class); + Logger root = new Logger("root") { + }; + root.setLevel(Level.ERROR); + Hierarchy hierarchy = new Hierarchy(root); + Logger a = hierarchy.getLogger("a"); + a.setLevel(null); + Logger b = hierarchy.getLogger("b"); + b.setLevel(Level.INFO); + when(loggingResource.currentLoggers()).thenReturn(loggers(a, b)); + when(loggingResource.rootLogger()).thenReturn(root); + when(loggingResource.setLevel(any(), any())).thenCallRealMethod(); + assertThrows(NotFoundException.class, () -> loggingResource.setLevel("@root", Collections.singletonMap("level", "HIGH"))); + } + + private Enumeration<Logger> loggers(Logger...
loggers) { + return new Vector<>(Arrays.asList(loggers)).elements(); + } + +} diff --git a/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/rest/resources/RootResourceTest.java b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/rest/resources/RootResourceTest.java new file mode 100644 index 0000000..4e928a3 --- /dev/null +++ b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/rest/resources/RootResourceTest.java @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.runtime.rest.resources; + +import org.apache.kafka.clients.admin.MockAdminClient; +import org.apache.kafka.common.utils.AppInfoParser; +import org.apache.kafka.connect.runtime.Herder; +import org.apache.kafka.connect.runtime.rest.entities.ServerInfo; +import org.easymock.EasyMock; +import org.easymock.EasyMockRunner; +import org.easymock.EasyMockSupport; +import org.easymock.Mock; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; + +import static org.junit.Assert.assertEquals; + +@RunWith(EasyMockRunner.class) +public class RootResourceTest extends EasyMockSupport { + + @Mock + private Herder herder; + private RootResource rootResource; + + @Before + public void setUp() { + rootResource = new RootResource(herder); + } + + @Test + public void testRootGet() { + EasyMock.expect(herder.kafkaClusterId()).andReturn(MockAdminClient.DEFAULT_CLUSTER_ID); + + replayAll(); + + ServerInfo info = rootResource.serverInfo(); + assertEquals(AppInfoParser.getVersion(), info.version()); + assertEquals(AppInfoParser.getCommitId(), info.commit()); + assertEquals(MockAdminClient.DEFAULT_CLUSTER_ID, info.clusterId()); + + verifyAll(); + } +} diff --git a/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/rest/util/SSLUtilsTest.java b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/rest/util/SSLUtilsTest.java new file mode 100644 index 0000000..b8ffbcf --- /dev/null +++ b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/rest/util/SSLUtilsTest.java @@ -0,0 +1,190 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.runtime.rest.util; + +import org.apache.kafka.clients.CommonClientConfigs; +import org.apache.kafka.common.config.SslConfigs; +import org.apache.kafka.connect.runtime.WorkerConfig; +import org.apache.kafka.connect.runtime.distributed.DistributedConfig; +import org.apache.kafka.connect.runtime.standalone.StandaloneConfig; +import org.eclipse.jetty.util.ssl.SslContextFactory; +import org.junit.Assert; +import org.junit.Test; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.Map; + +@SuppressWarnings("deprecation") +public class SSLUtilsTest { + private static final Map DEFAULT_CONFIG = new HashMap<>(); + static { + // The WorkerConfig base class has some required settings without defaults + DEFAULT_CONFIG.put(DistributedConfig.STATUS_STORAGE_TOPIC_CONFIG, "status-topic"); + DEFAULT_CONFIG.put(DistributedConfig.CONFIG_TOPIC_CONFIG, "config-topic"); + DEFAULT_CONFIG.put(CommonClientConfigs.BOOTSTRAP_SERVERS_CONFIG, "localhost:9092"); + DEFAULT_CONFIG.put(DistributedConfig.GROUP_ID_CONFIG, "connect-test-group"); + DEFAULT_CONFIG.put(WorkerConfig.KEY_CONVERTER_CLASS_CONFIG, "org.apache.kafka.connect.json.JsonConverter"); + DEFAULT_CONFIG.put(WorkerConfig.VALUE_CONVERTER_CLASS_CONFIG, "org.apache.kafka.connect.json.JsonConverter"); + DEFAULT_CONFIG.put(DistributedConfig.OFFSET_STORAGE_TOPIC_CONFIG, "connect-offsets"); + } + + @Test + public void testGetOrDefault() { + String existingKey = "exists"; + String missingKey = "missing"; + String value = "value"; + String defaultValue = "default"; + Map map = new HashMap<>(); + map.put("exists", "value"); + + Assert.assertEquals(SSLUtils.getOrDefault(map, existingKey, defaultValue), value); + Assert.assertEquals(SSLUtils.getOrDefault(map, missingKey, defaultValue), defaultValue); + } + + @Test + public void testCreateServerSideSslContextFactory() { + Map configMap = new HashMap<>(DEFAULT_CONFIG); + configMap.put("ssl.keystore.location", "/path/to/keystore"); + configMap.put("ssl.keystore.password", "123456"); + configMap.put("ssl.key.password", "123456"); + configMap.put("ssl.truststore.location", "/path/to/truststore"); + configMap.put("ssl.truststore.password", "123456"); + configMap.put("ssl.provider", "SunJSSE"); + configMap.put("ssl.cipher.suites", "SSL_RSA_WITH_RC4_128_SHA,SSL_RSA_WITH_RC4_128_MD5"); + configMap.put("ssl.secure.random.implementation", "SHA1PRNG"); + configMap.put("ssl.client.auth", "required"); + configMap.put("ssl.endpoint.identification.algorithm", "HTTPS"); + configMap.put("ssl.keystore.type", "JKS"); + configMap.put("ssl.protocol", "TLS"); + configMap.put("ssl.truststore.type", "JKS"); + configMap.put("ssl.enabled.protocols", "TLSv1.2,TLSv1.1,TLSv1"); + configMap.put("ssl.keymanager.algorithm", "SunX509"); + configMap.put("ssl.trustmanager.algorithm", "PKIX"); + + DistributedConfig config = new DistributedConfig(configMap); + SslContextFactory ssl = SSLUtils.createServerSideSslContextFactory(config); + + Assert.assertEquals("file:///path/to/keystore", ssl.getKeyStorePath()); + Assert.assertEquals("file:///path/to/truststore", ssl.getTrustStorePath()); + 
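+ // Each remaining ssl.* worker setting should be passed straight through to the Jetty SslContextFactory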
Assert.assertEquals("SunJSSE", ssl.getProvider()); + Assert.assertArrayEquals(new String[] {"SSL_RSA_WITH_RC4_128_SHA", "SSL_RSA_WITH_RC4_128_MD5"}, ssl.getIncludeCipherSuites()); + Assert.assertEquals("SHA1PRNG", ssl.getSecureRandomAlgorithm()); + Assert.assertTrue(ssl.getNeedClientAuth()); + Assert.assertFalse(ssl.getWantClientAuth()); + Assert.assertEquals("JKS", ssl.getKeyStoreType()); + Assert.assertEquals("JKS", ssl.getTrustStoreType()); + Assert.assertEquals("TLS", ssl.getProtocol()); + Assert.assertArrayEquals(new String[] {"TLSv1.2", "TLSv1.1", "TLSv1"}, ssl.getIncludeProtocols()); + Assert.assertEquals("SunX509", ssl.getKeyManagerFactoryAlgorithm()); + Assert.assertEquals("PKIX", ssl.getTrustManagerFactoryAlgorithm()); + } + + @Test + public void testCreateClientSideSslContextFactory() { + Map configMap = new HashMap<>(DEFAULT_CONFIG); + configMap.put("ssl.keystore.location", "/path/to/keystore"); + configMap.put("ssl.keystore.password", "123456"); + configMap.put("ssl.key.password", "123456"); + configMap.put("ssl.truststore.location", "/path/to/truststore"); + configMap.put("ssl.truststore.password", "123456"); + configMap.put("ssl.provider", "SunJSSE"); + configMap.put("ssl.cipher.suites", "SSL_RSA_WITH_RC4_128_SHA,SSL_RSA_WITH_RC4_128_MD5"); + configMap.put("ssl.secure.random.implementation", "SHA1PRNG"); + configMap.put("ssl.client.auth", "required"); + configMap.put("ssl.endpoint.identification.algorithm", "HTTPS"); + configMap.put("ssl.keystore.type", "JKS"); + configMap.put("ssl.protocol", "TLS"); + configMap.put("ssl.truststore.type", "JKS"); + configMap.put("ssl.enabled.protocols", "TLSv1.2,TLSv1.1,TLSv1"); + configMap.put("ssl.keymanager.algorithm", "SunX509"); + configMap.put("ssl.trustmanager.algorithm", "PKIX"); + + DistributedConfig config = new DistributedConfig(configMap); + SslContextFactory ssl = SSLUtils.createClientSideSslContextFactory(config); + + Assert.assertEquals("file:///path/to/keystore", ssl.getKeyStorePath()); + Assert.assertEquals("file:///path/to/truststore", ssl.getTrustStorePath()); + Assert.assertEquals("SunJSSE", ssl.getProvider()); + Assert.assertArrayEquals(new String[] {"SSL_RSA_WITH_RC4_128_SHA", "SSL_RSA_WITH_RC4_128_MD5"}, ssl.getIncludeCipherSuites()); + Assert.assertEquals("SHA1PRNG", ssl.getSecureRandomAlgorithm()); + Assert.assertFalse(ssl.getNeedClientAuth()); + Assert.assertFalse(ssl.getWantClientAuth()); + Assert.assertEquals("JKS", ssl.getKeyStoreType()); + Assert.assertEquals("JKS", ssl.getTrustStoreType()); + Assert.assertEquals("TLS", ssl.getProtocol()); + Assert.assertArrayEquals(new String[] {"TLSv1.2", "TLSv1.1", "TLSv1"}, ssl.getIncludeProtocols()); + Assert.assertEquals("SunX509", ssl.getKeyManagerFactoryAlgorithm()); + Assert.assertEquals("PKIX", ssl.getTrustManagerFactoryAlgorithm()); + } + + @Test + public void testCreateServerSideSslContextFactoryDefaultValues() { + Map configMap = new HashMap<>(DEFAULT_CONFIG); + configMap.put(StandaloneConfig.OFFSET_STORAGE_FILE_FILENAME_CONFIG, "/tmp/offset/file"); + configMap.put(WorkerConfig.KEY_CONVERTER_CLASS_CONFIG, "org.apache.kafka.connect.json.JsonConverter"); + configMap.put(WorkerConfig.VALUE_CONVERTER_CLASS_CONFIG, "org.apache.kafka.connect.json.JsonConverter"); + configMap.put("ssl.keystore.location", "/path/to/keystore"); + configMap.put("ssl.keystore.password", "123456"); + configMap.put("ssl.key.password", "123456"); + configMap.put("ssl.truststore.location", "/path/to/truststore"); + configMap.put("ssl.truststore.password", "123456"); + configMap.put("ssl.provider", 
"SunJSSE"); + configMap.put("ssl.cipher.suites", "SSL_RSA_WITH_RC4_128_SHA,SSL_RSA_WITH_RC4_128_MD5"); + configMap.put("ssl.secure.random.implementation", "SHA1PRNG"); + + DistributedConfig config = new DistributedConfig(configMap); + SslContextFactory ssl = SSLUtils.createServerSideSslContextFactory(config); + + Assert.assertEquals(SslConfigs.DEFAULT_SSL_KEYSTORE_TYPE, ssl.getKeyStoreType()); + Assert.assertEquals(SslConfigs.DEFAULT_SSL_TRUSTSTORE_TYPE, ssl.getTrustStoreType()); + Assert.assertEquals(SslConfigs.DEFAULT_SSL_PROTOCOL, ssl.getProtocol()); + Assert.assertArrayEquals(Arrays.asList(SslConfigs.DEFAULT_SSL_ENABLED_PROTOCOLS.split("\\s*,\\s*")).toArray(), ssl.getIncludeProtocols()); + Assert.assertEquals(SslConfigs.DEFAULT_SSL_KEYMANGER_ALGORITHM, ssl.getKeyManagerFactoryAlgorithm()); + Assert.assertEquals(SslConfigs.DEFAULT_SSL_TRUSTMANAGER_ALGORITHM, ssl.getTrustManagerFactoryAlgorithm()); + Assert.assertFalse(ssl.getNeedClientAuth()); + Assert.assertFalse(ssl.getWantClientAuth()); + } + + @Test + public void testCreateClientSideSslContextFactoryDefaultValues() { + Map configMap = new HashMap<>(DEFAULT_CONFIG); + configMap.put(StandaloneConfig.OFFSET_STORAGE_FILE_FILENAME_CONFIG, "/tmp/offset/file"); + configMap.put(WorkerConfig.KEY_CONVERTER_CLASS_CONFIG, "org.apache.kafka.connect.json.JsonConverter"); + configMap.put(WorkerConfig.VALUE_CONVERTER_CLASS_CONFIG, "org.apache.kafka.connect.json.JsonConverter"); + configMap.put("ssl.keystore.location", "/path/to/keystore"); + configMap.put("ssl.keystore.password", "123456"); + configMap.put("ssl.key.password", "123456"); + configMap.put("ssl.truststore.location", "/path/to/truststore"); + configMap.put("ssl.truststore.password", "123456"); + configMap.put("ssl.provider", "SunJSSE"); + configMap.put("ssl.cipher.suites", "SSL_RSA_WITH_RC4_128_SHA,SSL_RSA_WITH_RC4_128_MD5"); + configMap.put("ssl.secure.random.implementation", "SHA1PRNG"); + + DistributedConfig config = new DistributedConfig(configMap); + SslContextFactory ssl = SSLUtils.createClientSideSslContextFactory(config); + + Assert.assertEquals(SslConfigs.DEFAULT_SSL_KEYSTORE_TYPE, ssl.getKeyStoreType()); + Assert.assertEquals(SslConfigs.DEFAULT_SSL_TRUSTSTORE_TYPE, ssl.getTrustStoreType()); + Assert.assertEquals(SslConfigs.DEFAULT_SSL_PROTOCOL, ssl.getProtocol()); + Assert.assertArrayEquals(Arrays.asList(SslConfigs.DEFAULT_SSL_ENABLED_PROTOCOLS.split("\\s*,\\s*")).toArray(), ssl.getIncludeProtocols()); + Assert.assertEquals(SslConfigs.DEFAULT_SSL_KEYMANGER_ALGORITHM, ssl.getKeyManagerFactoryAlgorithm()); + Assert.assertEquals(SslConfigs.DEFAULT_SSL_TRUSTMANAGER_ALGORITHM, ssl.getTrustManagerFactoryAlgorithm()); + Assert.assertFalse(ssl.getNeedClientAuth()); + Assert.assertFalse(ssl.getWantClientAuth()); + } +} diff --git a/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/standalone/StandaloneConfigTest.java b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/standalone/StandaloneConfigTest.java new file mode 100644 index 0000000..e2e886f --- /dev/null +++ b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/standalone/StandaloneConfigTest.java @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.runtime.standalone; + +import org.apache.kafka.common.config.ConfigDef; +import org.apache.kafka.common.config.SslConfigs; +import org.apache.kafka.common.config.types.Password; +import org.apache.kafka.connect.runtime.WorkerConfig; +import org.junit.Test; + +import java.util.HashMap; +import java.util.Map; +import java.util.stream.Collectors; + +import static org.junit.Assert.assertEquals; + +public class StandaloneConfigTest { + + private static final String HTTPS_LISTENER_PREFIX = "listeners.https."; + + private Map<String, Object> sslProps() { + return new HashMap<String, Object>() { + { + put(SslConfigs.SSL_KEY_PASSWORD_CONFIG, new Password("ssl_key_password")); + put(SslConfigs.SSL_KEYSTORE_LOCATION_CONFIG, "ssl_keystore"); + put(SslConfigs.SSL_KEYSTORE_PASSWORD_CONFIG, new Password("ssl_keystore_password")); + put(SslConfigs.SSL_TRUSTSTORE_LOCATION_CONFIG, "ssl_truststore"); + put(SslConfigs.SSL_TRUSTSTORE_PASSWORD_CONFIG, new Password("ssl_truststore_password")); + } + }; + } + + private Map<String, String> baseWorkerProps() { + return new HashMap<String, String>() { + { + put(WorkerConfig.KEY_CONVERTER_CLASS_CONFIG, "org.apache.kafka.connect.json.JsonConverter"); + put(WorkerConfig.VALUE_CONVERTER_CLASS_CONFIG, "org.apache.kafka.connect.json.JsonConverter"); + put(StandaloneConfig.OFFSET_STORAGE_FILE_FILENAME_CONFIG, "/tmp/foo"); + } + }; + } + + private static Map<String, String> withStringValues(Map<String, Object> inputs, String prefix) { + return ConfigDef.convertToStringMapWithPasswordValues(inputs).entrySet().stream() + .collect(Collectors.toMap( + entry -> prefix + entry.getKey(), + Map.Entry::getValue + )); + } + + @Test + public void testRestServerPrefixedSslConfigs() { + Map<String, String> workerProps = baseWorkerProps(); + Map<String, Object> expectedSslProps = sslProps(); + workerProps.putAll(withStringValues(expectedSslProps, HTTPS_LISTENER_PREFIX)); + + StandaloneConfig config = new StandaloneConfig(workerProps); + assertEquals(expectedSslProps, config.valuesWithPrefixAllOrNothing(HTTPS_LISTENER_PREFIX)); + } + + @Test + public void testRestServerNonPrefixedSslConfigs() { + Map<String, String> props = baseWorkerProps(); + Map<String, Object> expectedSslProps = sslProps(); + props.putAll(withStringValues(expectedSslProps, "")); + + StandaloneConfig config = new StandaloneConfig(props); + Map<String, Object> actualProps = config.valuesWithPrefixAllOrNothing(HTTPS_LISTENER_PREFIX) + .entrySet().stream() + .filter(entry -> expectedSslProps.containsKey(entry.getKey())) + .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); + assertEquals(expectedSslProps, actualProps); + } +} diff --git a/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/standalone/StandaloneHerderTest.java b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/standalone/StandaloneHerderTest.java new file mode 100644 index 0000000..2504d98 --- /dev/null +++ b/connect/runtime/src/test/java/org/apache/kafka/connect/runtime/standalone/StandaloneHerderTest.java @@ -0,0 +1,997 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements.
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.runtime.standalone; + +import org.apache.kafka.common.config.Config; +import org.apache.kafka.common.config.ConfigDef; +import org.apache.kafka.common.config.ConfigValue; +import org.apache.kafka.connect.connector.Connector; +import org.apache.kafka.connect.connector.Task; +import org.apache.kafka.connect.connector.policy.ConnectorClientConfigOverridePolicy; +import org.apache.kafka.connect.connector.policy.NoneConnectorClientConfigOverridePolicy; +import org.apache.kafka.connect.errors.AlreadyExistsException; +import org.apache.kafka.connect.errors.ConnectException; +import org.apache.kafka.connect.errors.NotFoundException; +import org.apache.kafka.connect.runtime.AbstractStatus; +import org.apache.kafka.connect.runtime.ConnectorConfig; +import org.apache.kafka.connect.runtime.ConnectorStatus; +import org.apache.kafka.connect.runtime.Herder; +import org.apache.kafka.connect.runtime.HerderConnectorContext; +import org.apache.kafka.connect.runtime.RestartPlan; +import org.apache.kafka.connect.runtime.RestartRequest; +import org.apache.kafka.connect.runtime.SinkConnectorConfig; +import org.apache.kafka.connect.runtime.SourceConnectorConfig; +import org.apache.kafka.connect.runtime.TargetState; +import org.apache.kafka.connect.runtime.TaskConfig; +import org.apache.kafka.connect.runtime.TaskStatus; +import org.apache.kafka.connect.runtime.Worker; +import org.apache.kafka.connect.runtime.WorkerConfigTransformer; +import org.apache.kafka.connect.runtime.WorkerConnector; +import org.apache.kafka.connect.runtime.distributed.ClusterConfigState; +import org.apache.kafka.connect.runtime.isolation.DelegatingClassLoader; +import org.apache.kafka.connect.runtime.isolation.PluginClassLoader; +import org.apache.kafka.connect.runtime.isolation.Plugins; +import org.apache.kafka.connect.runtime.rest.entities.ConnectorInfo; +import org.apache.kafka.connect.runtime.rest.entities.ConnectorStateInfo; +import org.apache.kafka.connect.runtime.rest.entities.ConnectorType; +import org.apache.kafka.connect.runtime.rest.entities.TaskInfo; +import org.apache.kafka.connect.runtime.rest.errors.BadRequestException; +import org.apache.kafka.connect.sink.SinkConnector; +import org.apache.kafka.connect.sink.SinkTask; +import org.apache.kafka.connect.source.SourceConnector; +import org.apache.kafka.connect.source.SourceTask; +import org.apache.kafka.connect.storage.MemoryConfigBackingStore; +import org.apache.kafka.connect.storage.StatusBackingStore; +import org.apache.kafka.connect.util.Callback; +import org.apache.kafka.connect.util.ConnectorTaskId; +import org.apache.kafka.connect.util.FutureCallback; +import org.easymock.Capture; +import org.easymock.EasyMock; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.powermock.api.easymock.PowerMock; +import 
org.powermock.api.easymock.annotation.Mock; +import org.powermock.core.classloader.annotations.PrepareForTest; +import org.powermock.modules.junit4.PowerMockRunner; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; + +import static java.util.Collections.singleton; +import static java.util.Collections.singletonList; +import static java.util.Collections.singletonMap; +import static org.apache.kafka.connect.runtime.TopicCreationConfig.DEFAULT_TOPIC_CREATION_PREFIX; +import static org.apache.kafka.connect.runtime.TopicCreationConfig.PARTITIONS_CONFIG; +import static org.apache.kafka.connect.runtime.TopicCreationConfig.REPLICATION_FACTOR_CONFIG; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +@RunWith(PowerMockRunner.class) +@SuppressWarnings("unchecked") +@PrepareForTest({StandaloneHerder.class, Plugins.class, WorkerConnector.class}) +public class StandaloneHerderTest { + private static final String CONNECTOR_NAME = "test"; + private static final String TOPICS_LIST_STR = "topic1,topic2"; + private static final String WORKER_ID = "localhost:8083"; + private static final String KAFKA_CLUSTER_ID = "I4ZmrWqfT2e-upky_4fdPA"; + + private enum SourceSink { + SOURCE, SINK + } + + private StandaloneHerder herder; + + private Connector connector; + @Mock protected Worker worker; + @Mock protected WorkerConfigTransformer transformer; + @Mock private Plugins plugins; + @Mock + private PluginClassLoader pluginLoader; + @Mock + private DelegatingClassLoader delegatingLoader; + protected FutureCallback> createCallback; + @Mock protected StatusBackingStore statusBackingStore; + + private final ConnectorClientConfigOverridePolicy + noneConnectorClientConfigOverridePolicy = new NoneConnectorClientConfigOverridePolicy(); + + + @Before + public void setup() { + worker = PowerMock.createMock(Worker.class); + String[] methodNames = new String[]{"connectorTypeForClass"/*, "validateConnectorConfig"*/, "buildRestartPlan", "recordRestarting"}; + herder = PowerMock.createPartialMock(StandaloneHerder.class, methodNames, + worker, WORKER_ID, KAFKA_CLUSTER_ID, statusBackingStore, new MemoryConfigBackingStore(transformer), noneConnectorClientConfigOverridePolicy); + createCallback = new FutureCallback<>(); + plugins = PowerMock.createMock(Plugins.class); + pluginLoader = PowerMock.createMock(PluginClassLoader.class); + delegatingLoader = PowerMock.createMock(DelegatingClassLoader.class); + PowerMock.mockStatic(Plugins.class); + PowerMock.mockStatic(WorkerConnector.class); + Capture> configCapture = Capture.newInstance(); + EasyMock.expect(transformer.transform(EasyMock.eq(CONNECTOR_NAME), EasyMock.capture(configCapture))).andAnswer(configCapture::getValue).anyTimes(); + } + + @Test + public void testCreateSourceConnector() throws Exception { + connector = PowerMock.createMock(BogusSourceConnector.class); + expectAdd(SourceSink.SOURCE); + + Map config = connectorConfig(SourceSink.SOURCE); + Connector connectorMock = PowerMock.createMock(SourceConnector.class); + expectConfigValidation(connectorMock, true, config); + + PowerMock.replayAll(); + + 
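+        // The expectation recorded in setup() relies on EasyMock's capture-and-answer idiom:
+        // whatever map is passed to transformer.transform() is captured and answered back, so
+        // the herder effectively sees an identity transform of the connector config. A minimal,
+        // hedged sketch of the same idiom, using illustrative names that are not part of this test:
+        //
+        //   WorkerConfigTransformer echo = PowerMock.createMock(WorkerConfigTransformer.class);
+        //   Capture<Map<String, String>> passed = EasyMock.newCapture();
+        //   EasyMock.expect(echo.transform(EasyMock.eq("demo-connector"), EasyMock.capture(passed)))
+        //           .andAnswer(passed::getValue)
+        //           .anyTimes();
+        //   PowerMock.replay(echo);
+        //   // echo.transform("demo-connector", cfg) now returns cfg unchanged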
herder.putConnectorConfig(CONNECTOR_NAME, config, false, createCallback); + Herder.Created connectorInfo = createCallback.get(1000L, TimeUnit.SECONDS); + assertEquals(createdInfo(SourceSink.SOURCE), connectorInfo.result()); + + PowerMock.verifyAll(); + } + + @Test + public void testCreateConnectorFailedValidation() throws Throwable { + // Basic validation should be performed and return an error, but should still evaluate the connector's config + connector = PowerMock.createMock(BogusSourceConnector.class); + + Map config = connectorConfig(SourceSink.SOURCE); + config.remove(ConnectorConfig.NAME_CONFIG); + + Connector connectorMock = PowerMock.createMock(SourceConnector.class); + EasyMock.expect(worker.configTransformer()).andReturn(transformer).times(2); + final Capture> configCapture = EasyMock.newCapture(); + EasyMock.expect(transformer.transform(EasyMock.capture(configCapture))).andAnswer(configCapture::getValue); + EasyMock.expect(worker.getPlugins()).andReturn(plugins).times(3); + EasyMock.expect(plugins.compareAndSwapLoaders(connectorMock)).andReturn(delegatingLoader); + EasyMock.expect(plugins.newConnector(EasyMock.anyString())).andReturn(connectorMock); + + EasyMock.expect(connectorMock.config()).andStubReturn(new ConfigDef()); + + ConfigValue validatedValue = new ConfigValue("foo.bar"); + EasyMock.expect(connectorMock.validate(config)).andReturn(new Config(singletonList(validatedValue))); + EasyMock.expect(Plugins.compareAndSwapLoaders(delegatingLoader)).andReturn(pluginLoader); + + PowerMock.replayAll(); + + herder.putConnectorConfig(CONNECTOR_NAME, config, false, createCallback); + + ExecutionException exception = assertThrows(ExecutionException.class, () -> createCallback.get(1000L, TimeUnit.SECONDS)); + assertEquals(BadRequestException.class, exception.getCause().getClass()); + PowerMock.verifyAll(); + } + + @Test + public void testCreateConnectorAlreadyExists() throws Throwable { + connector = PowerMock.createMock(BogusSourceConnector.class); + // First addition should succeed + expectAdd(SourceSink.SOURCE); + + Map config = connectorConfig(SourceSink.SOURCE); + Connector connectorMock = PowerMock.createMock(SourceConnector.class); + expectConfigValidation(connectorMock, true, config, config); + + EasyMock.expect(worker.configTransformer()).andReturn(transformer).times(2); + final Capture> configCapture = EasyMock.newCapture(); + EasyMock.expect(transformer.transform(EasyMock.capture(configCapture))).andAnswer(configCapture::getValue); + EasyMock.expect(worker.getPlugins()).andReturn(plugins).times(2); + EasyMock.expect(plugins.compareAndSwapLoaders(connectorMock)).andReturn(delegatingLoader); + // No new connector is created + EasyMock.expect(Plugins.compareAndSwapLoaders(delegatingLoader)).andReturn(pluginLoader); + + PowerMock.replayAll(); + + herder.putConnectorConfig(CONNECTOR_NAME, config, false, createCallback); + Herder.Created connectorInfo = createCallback.get(1000L, TimeUnit.SECONDS); + assertEquals(createdInfo(SourceSink.SOURCE), connectorInfo.result()); + + // Second should fail + FutureCallback> failedCreateCallback = new FutureCallback<>(); + herder.putConnectorConfig(CONNECTOR_NAME, config, false, failedCreateCallback); + ExecutionException exception = assertThrows(ExecutionException.class, () -> failedCreateCallback.get(1000L, TimeUnit.SECONDS)); + assertEquals(AlreadyExistsException.class, exception.getCause().getClass()); + PowerMock.verifyAll(); + } + + @Test + public void testCreateSinkConnector() throws Exception { + connector = 
PowerMock.createMock(BogusSinkConnector.class); + expectAdd(SourceSink.SINK); + + Map config = connectorConfig(SourceSink.SINK); + Connector connectorMock = PowerMock.createMock(SinkConnector.class); + expectConfigValidation(connectorMock, true, config); + PowerMock.replayAll(); + + herder.putConnectorConfig(CONNECTOR_NAME, config, false, createCallback); + Herder.Created connectorInfo = createCallback.get(1000L, TimeUnit.SECONDS); + assertEquals(createdInfo(SourceSink.SINK), connectorInfo.result()); + + PowerMock.verifyAll(); + } + + @Test + public void testDestroyConnector() throws Exception { + connector = PowerMock.createMock(BogusSourceConnector.class); + expectAdd(SourceSink.SOURCE); + + Map config = connectorConfig(SourceSink.SOURCE); + Connector connectorMock = PowerMock.createMock(SourceConnector.class); + expectConfigValidation(connectorMock, true, config); + + EasyMock.expect(statusBackingStore.getAll(CONNECTOR_NAME)).andReturn(Collections.emptyList()); + statusBackingStore.put(new ConnectorStatus(CONNECTOR_NAME, AbstractStatus.State.DESTROYED, WORKER_ID, 0)); + statusBackingStore.put(new TaskStatus(new ConnectorTaskId(CONNECTOR_NAME, 0), TaskStatus.State.DESTROYED, WORKER_ID, 0)); + + expectDestroy(); + + PowerMock.replayAll(); + + herder.putConnectorConfig(CONNECTOR_NAME, config, false, createCallback); + Herder.Created connectorInfo = createCallback.get(1000L, TimeUnit.SECONDS); + assertEquals(createdInfo(SourceSink.SOURCE), connectorInfo.result()); + + FutureCallback> deleteCallback = new FutureCallback<>(); + herder.deleteConnectorConfig(CONNECTOR_NAME, deleteCallback); + deleteCallback.get(1000L, TimeUnit.MILLISECONDS); + + // Second deletion should fail since the connector is gone + FutureCallback> failedDeleteCallback = new FutureCallback<>(); + herder.deleteConnectorConfig(CONNECTOR_NAME, failedDeleteCallback); + try { + failedDeleteCallback.get(1000L, TimeUnit.MILLISECONDS); + fail("Should have thrown NotFoundException"); + } catch (ExecutionException e) { + assertTrue(e.getCause() instanceof NotFoundException); + } + + PowerMock.verifyAll(); + } + + @Test + public void testRestartConnector() throws Exception { + expectAdd(SourceSink.SOURCE); + + Map config = connectorConfig(SourceSink.SOURCE); + Connector connectorMock = PowerMock.createMock(SourceConnector.class); + expectConfigValidation(connectorMock, true, config); + + worker.stopAndAwaitConnector(CONNECTOR_NAME); + EasyMock.expectLastCall(); + + Capture> onStart = EasyMock.newCapture(); + worker.startConnector(EasyMock.eq(CONNECTOR_NAME), EasyMock.eq(config), EasyMock.anyObject(HerderConnectorContext.class), + EasyMock.eq(herder), EasyMock.eq(TargetState.STARTED), EasyMock.capture(onStart)); + EasyMock.expectLastCall().andAnswer(() -> { + onStart.getValue().onCompletion(null, TargetState.STARTED); + return true; + }); + + PowerMock.replayAll(); + + herder.putConnectorConfig(CONNECTOR_NAME, config, false, createCallback); + Herder.Created connectorInfo = createCallback.get(1000L, TimeUnit.SECONDS); + assertEquals(createdInfo(SourceSink.SOURCE), connectorInfo.result()); + + FutureCallback restartCallback = new FutureCallback<>(); + herder.restartConnector(CONNECTOR_NAME, restartCallback); + restartCallback.get(1000L, TimeUnit.MILLISECONDS); + + PowerMock.verifyAll(); + } + + @Test + public void testRestartConnectorFailureOnStart() throws Exception { + expectAdd(SourceSink.SOURCE); + + Map config = connectorConfig(SourceSink.SOURCE); + Connector connectorMock = PowerMock.createMock(SourceConnector.class); + 
expectConfigValidation(connectorMock, true, config); + + worker.stopAndAwaitConnector(CONNECTOR_NAME); + EasyMock.expectLastCall(); + + Capture> onStart = EasyMock.newCapture(); + worker.startConnector(EasyMock.eq(CONNECTOR_NAME), EasyMock.eq(config), EasyMock.anyObject(HerderConnectorContext.class), + EasyMock.eq(herder), EasyMock.eq(TargetState.STARTED), EasyMock.capture(onStart)); + Exception exception = new ConnectException("Failed to start connector"); + EasyMock.expectLastCall().andAnswer(() -> { + onStart.getValue().onCompletion(exception, null); + return true; + }); + + PowerMock.replayAll(); + + herder.putConnectorConfig(CONNECTOR_NAME, config, false, createCallback); + Herder.Created connectorInfo = createCallback.get(1000L, TimeUnit.SECONDS); + assertEquals(createdInfo(SourceSink.SOURCE), connectorInfo.result()); + + FutureCallback restartCallback = new FutureCallback<>(); + herder.restartConnector(CONNECTOR_NAME, restartCallback); + try { + restartCallback.get(1000L, TimeUnit.MILLISECONDS); + fail(); + } catch (ExecutionException e) { + assertEquals(exception, e.getCause()); + } + + PowerMock.verifyAll(); + } + + @Test + public void testRestartTask() throws Exception { + ConnectorTaskId taskId = new ConnectorTaskId(CONNECTOR_NAME, 0); + expectAdd(SourceSink.SOURCE); + + Map connectorConfig = connectorConfig(SourceSink.SOURCE); + Connector connectorMock = PowerMock.createMock(SourceConnector.class); + expectConfigValidation(connectorMock, true, connectorConfig); + + worker.stopAndAwaitTask(taskId); + EasyMock.expectLastCall(); + + ClusterConfigState configState = new ClusterConfigState( + -1, + null, + Collections.singletonMap(CONNECTOR_NAME, 1), + Collections.singletonMap(CONNECTOR_NAME, connectorConfig), + Collections.singletonMap(CONNECTOR_NAME, TargetState.STARTED), + Collections.singletonMap(taskId, taskConfig(SourceSink.SOURCE)), + new HashSet<>(), + transformer); + worker.startTask(taskId, configState, connectorConfig, taskConfig(SourceSink.SOURCE), herder, TargetState.STARTED); + EasyMock.expectLastCall().andReturn(true); + + PowerMock.replayAll(); + + herder.putConnectorConfig(CONNECTOR_NAME, connectorConfig, false, createCallback); + Herder.Created connectorInfo = createCallback.get(1000L, TimeUnit.SECONDS); + createCallback.get(1000L, TimeUnit.SECONDS); + assertEquals(createdInfo(SourceSink.SOURCE), connectorInfo.result()); + + FutureCallback restartTaskCallback = new FutureCallback<>(); + herder.restartTask(taskId, restartTaskCallback); + restartTaskCallback.get(1000L, TimeUnit.MILLISECONDS); + + PowerMock.verifyAll(); + } + + @Test + public void testRestartTaskFailureOnStart() throws Exception { + ConnectorTaskId taskId = new ConnectorTaskId(CONNECTOR_NAME, 0); + expectAdd(SourceSink.SOURCE); + + Map connectorConfig = connectorConfig(SourceSink.SOURCE); + Connector connectorMock = PowerMock.createMock(SourceConnector.class); + expectConfigValidation(connectorMock, true, connectorConfig); + + worker.stopAndAwaitTask(taskId); + EasyMock.expectLastCall(); + + ClusterConfigState configState = new ClusterConfigState( + -1, + null, + Collections.singletonMap(CONNECTOR_NAME, 1), + Collections.singletonMap(CONNECTOR_NAME, connectorConfig), + Collections.singletonMap(CONNECTOR_NAME, TargetState.STARTED), + Collections.singletonMap(new ConnectorTaskId(CONNECTOR_NAME, 0), taskConfig(SourceSink.SOURCE)), + new HashSet<>(), + transformer); + worker.startTask(taskId, configState, connectorConfig, taskConfig(SourceSink.SOURCE), herder, TargetState.STARTED); + 
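+        // The startTask() call recorded just above is completed on the next line with
+        // expectLastCall().andReturn(false), i.e. the worker reports that the task failed to
+        // start; the herder is then expected to fail the restart callback with a ConnectException,
+        // which the assertions further down verify. An equivalent way to stub the same call,
+        // shown only as a sketch, would be:
+        //
+        //   EasyMock.expect(worker.startTask(taskId, configState, connectorConfig,
+        //           taskConfig(SourceSink.SOURCE), herder, TargetState.STARTED))
+        //           .andReturn(false);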
EasyMock.expectLastCall().andReturn(false); + + PowerMock.replayAll(); + + herder.putConnectorConfig(CONNECTOR_NAME, connectorConfig, false, createCallback); + Herder.Created connectorInfo = createCallback.get(1000L, TimeUnit.MILLISECONDS); + assertEquals(createdInfo(SourceSink.SOURCE), connectorInfo.result()); + + FutureCallback cb = new FutureCallback<>(); + herder.restartTask(taskId, cb); + try { + cb.get(1000L, TimeUnit.MILLISECONDS); + fail("Expected restart callback to raise an exception"); + } catch (ExecutionException exception) { + assertEquals(ConnectException.class, exception.getCause().getClass()); + } + + PowerMock.verifyAll(); + } + + @Test + public void testRestartConnectorAndTasksUnknownConnector() throws Exception { + PowerMock.replayAll(); + + FutureCallback restartCallback = new FutureCallback<>(); + RestartRequest restartRequest = new RestartRequest("UnknownConnector", false, true); + herder.restartConnectorAndTasks(restartRequest, restartCallback); + ExecutionException ee = assertThrows(ExecutionException.class, () -> restartCallback.get(1000L, TimeUnit.MILLISECONDS)); + assertTrue(ee.getCause() instanceof NotFoundException); + + PowerMock.verifyAll(); + } + + @Test + public void testRestartConnectorAndTasksNoStatus() throws Exception { + RestartRequest restartRequest = new RestartRequest(CONNECTOR_NAME, false, true); + EasyMock.expect(herder.buildRestartPlan(restartRequest)).andReturn(Optional.empty()).anyTimes(); + + connector = PowerMock.createMock(BogusSinkConnector.class); + expectAdd(SourceSink.SINK); + + Map connectorConfig = connectorConfig(SourceSink.SINK); + Connector connectorMock = PowerMock.createMock(SinkConnector.class); + expectConfigValidation(connectorMock, true, connectorConfig); + PowerMock.replayAll(); + + herder.putConnectorConfig(CONNECTOR_NAME, connectorConfig, false, createCallback); + Herder.Created connectorInfo = createCallback.get(1000L, TimeUnit.SECONDS); + assertEquals(createdInfo(SourceSink.SINK), connectorInfo.result()); + + FutureCallback restartCallback = new FutureCallback<>(); + herder.restartConnectorAndTasks(restartRequest, restartCallback); + ExecutionException ee = assertThrows(ExecutionException.class, () -> restartCallback.get(1000L, TimeUnit.MILLISECONDS)); + assertTrue(ee.getCause() instanceof NotFoundException); + assertTrue(ee.getMessage().contains("Status for connector")); + PowerMock.verifyAll(); + } + + @Test + public void testRestartConnectorAndTasksNoRestarts() throws Exception { + RestartRequest restartRequest = new RestartRequest(CONNECTOR_NAME, false, true); + RestartPlan restartPlan = PowerMock.createMock(RestartPlan.class); + ConnectorStateInfo connectorStateInfo = PowerMock.createMock(ConnectorStateInfo.class); + EasyMock.expect(restartPlan.shouldRestartConnector()).andReturn(false).anyTimes(); + EasyMock.expect(restartPlan.shouldRestartTasks()).andReturn(false).anyTimes(); + EasyMock.expect(restartPlan.restartConnectorStateInfo()).andReturn(connectorStateInfo).anyTimes(); + EasyMock.expect(herder.buildRestartPlan(restartRequest)) + .andReturn(Optional.of(restartPlan)).anyTimes(); + + connector = PowerMock.createMock(BogusSinkConnector.class); + expectAdd(SourceSink.SINK); + + Map connectorConfig = connectorConfig(SourceSink.SINK); + Connector connectorMock = PowerMock.createMock(SinkConnector.class); + expectConfigValidation(connectorMock, true, connectorConfig); + + PowerMock.replayAll(); + + herder.putConnectorConfig(CONNECTOR_NAME, connectorConfig, false, createCallback); + Herder.Created connectorInfo = 
createCallback.get(1000L, TimeUnit.SECONDS); + assertEquals(createdInfo(SourceSink.SINK), connectorInfo.result()); + + FutureCallback restartCallback = new FutureCallback<>(); + herder.restartConnectorAndTasks(restartRequest, restartCallback); + assertEquals(connectorStateInfo, restartCallback.get(1000L, TimeUnit.MILLISECONDS)); + PowerMock.verifyAll(); + } + + @Test + public void testRestartConnectorAndTasksOnlyConnector() throws Exception { + RestartRequest restartRequest = new RestartRequest(CONNECTOR_NAME, false, true); + RestartPlan restartPlan = PowerMock.createMock(RestartPlan.class); + ConnectorStateInfo connectorStateInfo = PowerMock.createMock(ConnectorStateInfo.class); + EasyMock.expect(restartPlan.shouldRestartConnector()).andReturn(true).anyTimes(); + EasyMock.expect(restartPlan.shouldRestartTasks()).andReturn(false).anyTimes(); + EasyMock.expect(restartPlan.restartConnectorStateInfo()).andReturn(connectorStateInfo).anyTimes(); + EasyMock.expect(herder.buildRestartPlan(restartRequest)) + .andReturn(Optional.of(restartPlan)).anyTimes(); + + herder.onRestart(CONNECTOR_NAME); + EasyMock.expectLastCall(); + + connector = PowerMock.createMock(BogusSinkConnector.class); + expectAdd(SourceSink.SINK); + + Map connectorConfig = connectorConfig(SourceSink.SINK); + Connector connectorMock = PowerMock.createMock(SinkConnector.class); + expectConfigValidation(connectorMock, true, connectorConfig); + + worker.stopAndAwaitConnector(CONNECTOR_NAME); + EasyMock.expectLastCall(); + + Capture> onStart = EasyMock.newCapture(); + worker.startConnector(EasyMock.eq(CONNECTOR_NAME), EasyMock.eq(connectorConfig), EasyMock.anyObject(HerderConnectorContext.class), + EasyMock.eq(herder), EasyMock.eq(TargetState.STARTED), EasyMock.capture(onStart)); + EasyMock.expectLastCall().andAnswer(() -> { + onStart.getValue().onCompletion(null, TargetState.STARTED); + return true; + }); + + PowerMock.replayAll(); + + herder.putConnectorConfig(CONNECTOR_NAME, connectorConfig, false, createCallback); + Herder.Created connectorInfo = createCallback.get(1000L, TimeUnit.SECONDS); + assertEquals(createdInfo(SourceSink.SINK), connectorInfo.result()); + + FutureCallback restartCallback = new FutureCallback<>(); + herder.restartConnectorAndTasks(restartRequest, restartCallback); + assertEquals(connectorStateInfo, restartCallback.get(1000L, TimeUnit.MILLISECONDS)); + PowerMock.verifyAll(); + } + + @Test + public void testRestartConnectorAndTasksOnlyTasks() throws Exception { + ConnectorTaskId taskId = new ConnectorTaskId(CONNECTOR_NAME, 0); + RestartRequest restartRequest = new RestartRequest(CONNECTOR_NAME, false, true); + RestartPlan restartPlan = PowerMock.createMock(RestartPlan.class); + ConnectorStateInfo connectorStateInfo = PowerMock.createMock(ConnectorStateInfo.class); + EasyMock.expect(restartPlan.shouldRestartConnector()).andReturn(false).anyTimes(); + EasyMock.expect(restartPlan.shouldRestartTasks()).andReturn(true).anyTimes(); + EasyMock.expect(restartPlan.restartTaskCount()).andReturn(1).anyTimes(); + EasyMock.expect(restartPlan.totalTaskCount()).andReturn(1).anyTimes(); + EasyMock.expect(restartPlan.taskIdsToRestart()).andReturn(Collections.singletonList(taskId)).anyTimes(); + EasyMock.expect(restartPlan.restartConnectorStateInfo()).andReturn(connectorStateInfo).anyTimes(); + EasyMock.expect(herder.buildRestartPlan(restartRequest)) + .andReturn(Optional.of(restartPlan)).anyTimes(); + + herder.onRestart(taskId); + EasyMock.expectLastCall(); + + connector = PowerMock.createMock(BogusSinkConnector.class); + 
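+        // buildRestartPlan() is stubbed above to return a plan that restarts the tasks but not
+        // the connector, so only stopAndAwaitTasks()/startTask() expectations follow below, and
+        // herder.onRestart(taskId) is expected to be invoked to record the restart. For
+        // reference, the restart request itself is built as in the sketch below; the two boolean
+        // flags appear to correspond to "only failed instances" and "include tasks", though that
+        // mapping is an assumption here and the plan is mocked independently of it:
+        //
+        //   RestartRequest request = new RestartRequest(CONNECTOR_NAME, false, true);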
expectAdd(SourceSink.SINK); + + Map connectorConfig = connectorConfig(SourceSink.SINK); + Connector connectorMock = PowerMock.createMock(SinkConnector.class); + expectConfigValidation(connectorMock, true, connectorConfig); + + worker.stopAndAwaitTasks(Collections.singletonList(taskId)); + EasyMock.expectLastCall(); + + ClusterConfigState configState = new ClusterConfigState( + -1, + null, + Collections.singletonMap(CONNECTOR_NAME, 1), + Collections.singletonMap(CONNECTOR_NAME, connectorConfig), + Collections.singletonMap(CONNECTOR_NAME, TargetState.STARTED), + Collections.singletonMap(taskId, taskConfig(SourceSink.SINK)), + new HashSet<>(), + transformer); + worker.startTask(taskId, configState, connectorConfig, taskConfig(SourceSink.SINK), herder, TargetState.STARTED); + EasyMock.expectLastCall().andReturn(true); + PowerMock.replayAll(); + + herder.putConnectorConfig(CONNECTOR_NAME, connectorConfig, false, createCallback); + Herder.Created connectorInfo = createCallback.get(1000L, TimeUnit.SECONDS); + assertEquals(createdInfo(SourceSink.SINK), connectorInfo.result()); + + FutureCallback restartCallback = new FutureCallback<>(); + herder.restartConnectorAndTasks(restartRequest, restartCallback); + assertEquals(connectorStateInfo, restartCallback.get(1000L, TimeUnit.MILLISECONDS)); + PowerMock.verifyAll(); + } + + @Test + public void testRestartConnectorAndTasksBoth() throws Exception { + ConnectorTaskId taskId = new ConnectorTaskId(CONNECTOR_NAME, 0); + RestartRequest restartRequest = new RestartRequest(CONNECTOR_NAME, false, true); + RestartPlan restartPlan = PowerMock.createMock(RestartPlan.class); + ConnectorStateInfo connectorStateInfo = PowerMock.createMock(ConnectorStateInfo.class); + EasyMock.expect(restartPlan.shouldRestartConnector()).andReturn(true).anyTimes(); + EasyMock.expect(restartPlan.shouldRestartTasks()).andReturn(true).anyTimes(); + EasyMock.expect(restartPlan.restartTaskCount()).andReturn(1).anyTimes(); + EasyMock.expect(restartPlan.totalTaskCount()).andReturn(1).anyTimes(); + EasyMock.expect(restartPlan.taskIdsToRestart()).andReturn(Collections.singletonList(taskId)).anyTimes(); + EasyMock.expect(restartPlan.restartConnectorStateInfo()).andReturn(connectorStateInfo).anyTimes(); + EasyMock.expect(herder.buildRestartPlan(restartRequest)) + .andReturn(Optional.of(restartPlan)).anyTimes(); + + herder.onRestart(CONNECTOR_NAME); + EasyMock.expectLastCall(); + herder.onRestart(taskId); + EasyMock.expectLastCall(); + + connector = PowerMock.createMock(BogusSinkConnector.class); + expectAdd(SourceSink.SINK); + + Map connectorConfig = connectorConfig(SourceSink.SINK); + Connector connectorMock = PowerMock.createMock(SinkConnector.class); + expectConfigValidation(connectorMock, true, connectorConfig); + + worker.stopAndAwaitConnector(CONNECTOR_NAME); + EasyMock.expectLastCall(); + worker.stopAndAwaitTasks(Collections.singletonList(taskId)); + EasyMock.expectLastCall(); + + Capture> onStart = EasyMock.newCapture(); + worker.startConnector(EasyMock.eq(CONNECTOR_NAME), EasyMock.eq(connectorConfig), EasyMock.anyObject(HerderConnectorContext.class), + EasyMock.eq(herder), EasyMock.eq(TargetState.STARTED), EasyMock.capture(onStart)); + EasyMock.expectLastCall().andAnswer(() -> { + onStart.getValue().onCompletion(null, TargetState.STARTED); + return true; + }); + + ClusterConfigState configState = new ClusterConfigState( + -1, + null, + Collections.singletonMap(CONNECTOR_NAME, 1), + Collections.singletonMap(CONNECTOR_NAME, connectorConfig), + Collections.singletonMap(CONNECTOR_NAME, 
TargetState.STARTED), + Collections.singletonMap(taskId, taskConfig(SourceSink.SINK)), + new HashSet<>(), + transformer); + worker.startTask(taskId, configState, connectorConfig, taskConfig(SourceSink.SINK), herder, TargetState.STARTED); + EasyMock.expectLastCall().andReturn(true); + PowerMock.replayAll(); + + herder.putConnectorConfig(CONNECTOR_NAME, connectorConfig, false, createCallback); + Herder.Created connectorInfo = createCallback.get(1000L, TimeUnit.SECONDS); + assertEquals(createdInfo(SourceSink.SINK), connectorInfo.result()); + + FutureCallback restartCallback = new FutureCallback<>(); + herder.restartConnectorAndTasks(restartRequest, restartCallback); + assertEquals(connectorStateInfo, restartCallback.get(1000L, TimeUnit.MILLISECONDS)); + PowerMock.verifyAll(); + } + + @Test + public void testCreateAndStop() throws Exception { + connector = PowerMock.createMock(BogusSourceConnector.class); + expectAdd(SourceSink.SOURCE); + + Map connectorConfig = connectorConfig(SourceSink.SOURCE); + Connector connectorMock = PowerMock.createMock(SourceConnector.class); + expectConfigValidation(connectorMock, true, connectorConfig); + + // herder.stop() should stop any running connectors and tasks even if destroyConnector was not invoked + expectStop(); + + statusBackingStore.put(new TaskStatus(new ConnectorTaskId(CONNECTOR_NAME, 0), AbstractStatus.State.DESTROYED, WORKER_ID, 0)); + + statusBackingStore.stop(); + EasyMock.expectLastCall(); + worker.stop(); + EasyMock.expectLastCall(); + + PowerMock.replayAll(); + + herder.putConnectorConfig(CONNECTOR_NAME, connectorConfig, false, createCallback); + Herder.Created connectorInfo = createCallback.get(1000L, TimeUnit.SECONDS); + assertEquals(createdInfo(SourceSink.SOURCE), connectorInfo.result()); + + herder.stop(); + + PowerMock.verifyAll(); + } + + @Test + public void testAccessors() throws Exception { + Map connConfig = connectorConfig(SourceSink.SOURCE); + System.out.println(connConfig); + + Callback> listConnectorsCb = PowerMock.createMock(Callback.class); + Callback connectorInfoCb = PowerMock.createMock(Callback.class); + Callback> connectorConfigCb = PowerMock.createMock(Callback.class); + Callback> taskConfigsCb = PowerMock.createMock(Callback.class); + + // Check accessors with empty worker + listConnectorsCb.onCompletion(null, Collections.EMPTY_SET); + EasyMock.expectLastCall(); + connectorInfoCb.onCompletion(EasyMock.anyObject(), EasyMock.isNull()); + EasyMock.expectLastCall(); + connectorConfigCb.onCompletion(EasyMock.anyObject(), EasyMock.isNull()); + EasyMock.expectLastCall(); + taskConfigsCb.onCompletion(EasyMock.anyObject(), EasyMock.isNull()); + EasyMock.expectLastCall(); + + // Create connector + connector = PowerMock.createMock(BogusSourceConnector.class); + expectAdd(SourceSink.SOURCE); + expectConfigValidation(connector, true, connConfig); + + // Validate accessors with 1 connector + listConnectorsCb.onCompletion(null, singleton(CONNECTOR_NAME)); + EasyMock.expectLastCall(); + ConnectorInfo connInfo = new ConnectorInfo(CONNECTOR_NAME, connConfig, Arrays.asList(new ConnectorTaskId(CONNECTOR_NAME, 0)), + ConnectorType.SOURCE); + connectorInfoCb.onCompletion(null, connInfo); + EasyMock.expectLastCall(); + connectorConfigCb.onCompletion(null, connConfig); + EasyMock.expectLastCall(); + + TaskInfo taskInfo = new TaskInfo(new ConnectorTaskId(CONNECTOR_NAME, 0), taskConfig(SourceSink.SOURCE)); + taskConfigsCb.onCompletion(null, Arrays.asList(taskInfo)); + EasyMock.expectLastCall(); + + + PowerMock.replayAll(); + + // All operations 
are synchronous for StandaloneHerder, so we don't need to actually wait after making each call + herder.connectors(listConnectorsCb); + herder.connectorInfo(CONNECTOR_NAME, connectorInfoCb); + herder.connectorConfig(CONNECTOR_NAME, connectorConfigCb); + herder.taskConfigs(CONNECTOR_NAME, taskConfigsCb); + + herder.putConnectorConfig(CONNECTOR_NAME, connConfig, false, createCallback); + Herder.Created connectorInfo = createCallback.get(1000L, TimeUnit.SECONDS); + assertEquals(createdInfo(SourceSink.SOURCE), connectorInfo.result()); + + EasyMock.reset(transformer); + EasyMock.expect(transformer.transform(EasyMock.eq(CONNECTOR_NAME), EasyMock.anyObject())) + .andThrow(new AssertionError("Config transformation should not occur when requesting connector or task info")) + .anyTimes(); + EasyMock.replay(transformer); + + herder.connectors(listConnectorsCb); + herder.connectorInfo(CONNECTOR_NAME, connectorInfoCb); + herder.connectorConfig(CONNECTOR_NAME, connectorConfigCb); + herder.taskConfigs(CONNECTOR_NAME, taskConfigsCb); + + PowerMock.verifyAll(); + } + + @Test + public void testPutConnectorConfig() throws Exception { + Map connConfig = connectorConfig(SourceSink.SOURCE); + Map newConnConfig = new HashMap<>(connConfig); + newConnConfig.put("foo", "bar"); + + Callback> connectorConfigCb = PowerMock.createMock(Callback.class); + // Callback> putConnectorConfigCb = PowerMock.createMock(Callback.class); + + // Create + connector = PowerMock.createMock(BogusSourceConnector.class); + expectAdd(SourceSink.SOURCE); + Connector connectorMock = PowerMock.createMock(SourceConnector.class); + expectConfigValidation(connectorMock, true, connConfig); + + // Should get first config + connectorConfigCb.onCompletion(null, connConfig); + EasyMock.expectLastCall(); + // Update config, which requires stopping and restarting + worker.stopAndAwaitConnector(CONNECTOR_NAME); + EasyMock.expectLastCall(); + Capture> capturedConfig = EasyMock.newCapture(); + Capture> onStart = EasyMock.newCapture(); + worker.startConnector(EasyMock.eq(CONNECTOR_NAME), EasyMock.capture(capturedConfig), EasyMock.anyObject(), + EasyMock.eq(herder), EasyMock.eq(TargetState.STARTED), EasyMock.capture(onStart)); + EasyMock.expectLastCall().andAnswer(() -> { + onStart.getValue().onCompletion(null, TargetState.STARTED); + return true; + }); + EasyMock.expect(worker.isRunning(CONNECTOR_NAME)).andReturn(true); + EasyMock.expect(worker.isTopicCreationEnabled()).andReturn(true); + // Generate same task config, which should result in no additional action to restart tasks + EasyMock.expect(worker.connectorTaskConfigs(CONNECTOR_NAME, new SourceConnectorConfig(plugins, newConnConfig, true))) + .andReturn(singletonList(taskConfig(SourceSink.SOURCE))); + worker.isSinkConnector(CONNECTOR_NAME); + EasyMock.expectLastCall().andReturn(false); + + expectConfigValidation(connectorMock, false, newConnConfig); + connectorConfigCb.onCompletion(null, newConnConfig); + EasyMock.expectLastCall(); + EasyMock.expect(worker.getPlugins()).andReturn(plugins).anyTimes(); + + PowerMock.replayAll(); + + herder.putConnectorConfig(CONNECTOR_NAME, connConfig, false, createCallback); + Herder.Created connectorInfo = createCallback.get(1000L, TimeUnit.SECONDS); + assertEquals(createdInfo(SourceSink.SOURCE), connectorInfo.result()); + + herder.connectorConfig(CONNECTOR_NAME, connectorConfigCb); + + FutureCallback> reconfigureCallback = new FutureCallback<>(); + herder.putConnectorConfig(CONNECTOR_NAME, newConnConfig, true, reconfigureCallback); + Herder.Created newConnectorInfo 
= reconfigureCallback.get(1000L, TimeUnit.SECONDS); + ConnectorInfo newConnInfo = new ConnectorInfo(CONNECTOR_NAME, newConnConfig, Arrays.asList(new ConnectorTaskId(CONNECTOR_NAME, 0)), + ConnectorType.SOURCE); + assertEquals(newConnInfo, newConnectorInfo.result()); + + assertEquals("bar", capturedConfig.getValue().get("foo")); + herder.connectorConfig(CONNECTOR_NAME, connectorConfigCb); + + PowerMock.verifyAll(); + } + + @Test + public void testPutTaskConfigs() { + Callback cb = PowerMock.createMock(Callback.class); + + PowerMock.replayAll(); + + assertThrows(UnsupportedOperationException.class, () -> herder.putTaskConfigs(CONNECTOR_NAME, + singletonList(singletonMap("config", "value")), cb, null)); + PowerMock.verifyAll(); + } + + @Test + public void testCorruptConfig() throws Throwable { + Map config = new HashMap<>(); + config.put(ConnectorConfig.NAME_CONFIG, CONNECTOR_NAME); + config.put(ConnectorConfig.CONNECTOR_CLASS_CONFIG, BogusSinkConnector.class.getName()); + config.put(SinkConnectorConfig.TOPICS_CONFIG, TOPICS_LIST_STR); + Connector connectorMock = PowerMock.createMock(SinkConnector.class); + String error = "This is an error in your config!"; + List errors = new ArrayList<>(singletonList(error)); + String key = "foo.invalid.key"; + EasyMock.expect(connectorMock.validate(config)).andReturn( + new Config( + Arrays.asList(new ConfigValue(key, null, Collections.emptyList(), errors)) + ) + ); + ConfigDef configDef = new ConfigDef(); + configDef.define(key, ConfigDef.Type.STRING, ConfigDef.Importance.HIGH, ""); + EasyMock.expect(worker.configTransformer()).andReturn(transformer).times(2); + final Capture> configCapture = EasyMock.newCapture(); + EasyMock.expect(transformer.transform(EasyMock.capture(configCapture))).andAnswer(configCapture::getValue); + EasyMock.expect(worker.getPlugins()).andReturn(plugins).times(3); + EasyMock.expect(plugins.compareAndSwapLoaders(connectorMock)).andReturn(delegatingLoader); + EasyMock.expect(worker.getPlugins()).andStubReturn(plugins); + EasyMock.expect(plugins.newConnector(EasyMock.anyString())).andReturn(connectorMock); + EasyMock.expect(connectorMock.config()).andStubReturn(configDef); + EasyMock.expect(Plugins.compareAndSwapLoaders(delegatingLoader)).andReturn(pluginLoader); + + PowerMock.replayAll(); + + herder.putConnectorConfig(CONNECTOR_NAME, config, true, createCallback); + try { + createCallback.get(1000L, TimeUnit.SECONDS); + fail("Should have failed to configure connector"); + } catch (ExecutionException e) { + assertNotNull(e.getCause()); + Throwable cause = e.getCause(); + assertTrue(cause instanceof BadRequestException); + assertEquals( + cause.getMessage(), + "Connector configuration is invalid and contains the following 1 error(s):\n" + + error + "\n" + + "You can also find the above list of errors at the endpoint `/connector-plugins/{connectorType}/config/validate`" + ); + } + + PowerMock.verifyAll(); + } + + private void expectAdd(SourceSink sourceSink) { + Map connectorProps = connectorConfig(sourceSink); + ConnectorConfig connConfig = sourceSink == SourceSink.SOURCE ? 
+ new SourceConnectorConfig(plugins, connectorProps, true) : + new SinkConnectorConfig(plugins, connectorProps); + + Capture> onStart = EasyMock.newCapture(); + worker.startConnector(EasyMock.eq(CONNECTOR_NAME), EasyMock.eq(connectorProps), EasyMock.anyObject(HerderConnectorContext.class), + EasyMock.eq(herder), EasyMock.eq(TargetState.STARTED), EasyMock.capture(onStart)); + // EasyMock.expectLastCall().andReturn(true); + EasyMock.expectLastCall().andAnswer(() -> { + onStart.getValue().onCompletion(null, TargetState.STARTED); + return true; + }); + EasyMock.expect(worker.isRunning(CONNECTOR_NAME)).andReturn(true); + if (sourceSink == SourceSink.SOURCE) { + EasyMock.expect(worker.isTopicCreationEnabled()).andReturn(true); + } + + // And we should instantiate the tasks. For a sink task, we should see added properties for the input topic partitions + + Map generatedTaskProps = taskConfig(sourceSink); + + EasyMock.expect(worker.connectorTaskConfigs(CONNECTOR_NAME, connConfig)) + .andReturn(singletonList(generatedTaskProps)); + + ClusterConfigState configState = new ClusterConfigState( + -1, + null, + Collections.singletonMap(CONNECTOR_NAME, 1), + Collections.singletonMap(CONNECTOR_NAME, connectorConfig(sourceSink)), + Collections.singletonMap(CONNECTOR_NAME, TargetState.STARTED), + Collections.singletonMap(new ConnectorTaskId(CONNECTOR_NAME, 0), generatedTaskProps), + new HashSet<>(), + transformer); + worker.startTask(new ConnectorTaskId(CONNECTOR_NAME, 0), configState, connectorConfig(sourceSink), generatedTaskProps, herder, TargetState.STARTED); + EasyMock.expectLastCall().andReturn(true); + + EasyMock.expect(herder.connectorTypeForClass(BogusSourceConnector.class.getName())) + .andReturn(ConnectorType.SOURCE).anyTimes(); + EasyMock.expect(herder.connectorTypeForClass(BogusSinkConnector.class.getName())) + .andReturn(ConnectorType.SINK).anyTimes(); + worker.isSinkConnector(CONNECTOR_NAME); + PowerMock.expectLastCall().andReturn(sourceSink == SourceSink.SINK); + } + + private ConnectorInfo createdInfo(SourceSink sourceSink) { + return new ConnectorInfo(CONNECTOR_NAME, connectorConfig(sourceSink), + Arrays.asList(new ConnectorTaskId(CONNECTOR_NAME, 0)), + SourceSink.SOURCE == sourceSink ? ConnectorType.SOURCE : ConnectorType.SINK); + } + + private void expectStop() { + ConnectorTaskId task = new ConnectorTaskId(CONNECTOR_NAME, 0); + worker.stopAndAwaitTasks(singletonList(task)); + EasyMock.expectLastCall(); + worker.stopAndAwaitConnector(CONNECTOR_NAME); + EasyMock.expectLastCall(); + } + + private void expectDestroy() { + expectStop(); + } + + private static Map connectorConfig(SourceSink sourceSink) { + Map props = new HashMap<>(); + props.put(ConnectorConfig.NAME_CONFIG, CONNECTOR_NAME); + Class connectorClass = sourceSink == SourceSink.SINK ? BogusSinkConnector.class : BogusSourceConnector.class; + props.put(ConnectorConfig.CONNECTOR_CLASS_CONFIG, connectorClass.getName()); + props.put(ConnectorConfig.TASKS_MAX_CONFIG, "1"); + if (sourceSink == SourceSink.SINK) { + props.put(SinkTask.TOPICS_CONFIG, TOPICS_LIST_STR); + } else { + props.put(DEFAULT_TOPIC_CREATION_PREFIX + REPLICATION_FACTOR_CONFIG, String.valueOf(1)); + props.put(DEFAULT_TOPIC_CREATION_PREFIX + PARTITIONS_CONFIG, String.valueOf(1)); + } + return props; + } + + private static Map taskConfig(SourceSink sourceSink) { + HashMap generatedTaskProps = new HashMap<>(); + // Connectors can add any settings, so these are arbitrary + generatedTaskProps.put("foo", "bar"); + Class taskClass = sourceSink == SourceSink.SINK ? 
BogusSinkTask.class : BogusSourceTask.class; + generatedTaskProps.put(TaskConfig.TASK_CLASS_CONFIG, taskClass.getName()); + if (sourceSink == SourceSink.SINK) + generatedTaskProps.put(SinkTask.TOPICS_CONFIG, TOPICS_LIST_STR); + return generatedTaskProps; + } + + private void expectConfigValidation( + Connector connectorMock, + boolean shouldCreateConnector, + Map... configs + ) { + // config validation + EasyMock.expect(worker.configTransformer()).andReturn(transformer).times(2); + final Capture> configCapture = EasyMock.newCapture(); + EasyMock.expect(transformer.transform(EasyMock.capture(configCapture))).andAnswer(configCapture::getValue); + EasyMock.expect(worker.getPlugins()).andReturn(plugins).times(3); + EasyMock.expect(plugins.compareAndSwapLoaders(connectorMock)).andReturn(delegatingLoader); + if (shouldCreateConnector) { + EasyMock.expect(worker.getPlugins()).andReturn(plugins); + EasyMock.expect(plugins.newConnector(EasyMock.anyString())).andReturn(connectorMock); + } + EasyMock.expect(connectorMock.config()).andStubReturn(new ConfigDef()); + + for (Map config : configs) + EasyMock.expect(connectorMock.validate(config)).andReturn(new Config(Collections.emptyList())); + EasyMock.expect(Plugins.compareAndSwapLoaders(delegatingLoader)).andReturn(pluginLoader); + } + + // We need to use a real class here due to some issue with mocking java.lang.Class + private abstract class BogusSourceConnector extends SourceConnector { + } + + private abstract class BogusSourceTask extends SourceTask { + } + + private abstract class BogusSinkConnector extends SinkConnector { + } + + private abstract class BogusSinkTask extends SourceTask { + } + +} diff --git a/connect/runtime/src/test/java/org/apache/kafka/connect/storage/FileOffsetBackingStoreTest.java b/connect/runtime/src/test/java/org/apache/kafka/connect/storage/FileOffsetBackingStoreTest.java new file mode 100644 index 0000000..9944a5d --- /dev/null +++ b/connect/runtime/src/test/java/org/apache/kafka/connect/storage/FileOffsetBackingStoreTest.java @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.connect.storage; + +import org.apache.kafka.connect.runtime.standalone.StandaloneConfig; +import org.apache.kafka.connect.util.Callback; +import org.easymock.EasyMock; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.powermock.api.easymock.PowerMock; +import org.powermock.modules.junit4.PowerMockRunner; + +import java.io.File; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.Arrays; +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.ThreadPoolExecutor; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; + +@RunWith(PowerMockRunner.class) +public class FileOffsetBackingStoreTest { + + FileOffsetBackingStore store; + Map props; + StandaloneConfig config; + File tempFile; + + private static Map firstSet = new HashMap<>(); + private static final Runnable EMPTY_RUNNABLE = () -> { + }; + + static { + firstSet.put(buffer("key"), buffer("value")); + firstSet.put(null, null); + } + + @SuppressWarnings("deprecation") + @Before + public void setup() throws IOException { + store = new FileOffsetBackingStore(); + tempFile = File.createTempFile("fileoffsetbackingstore", null); + props = new HashMap<>(); + props.put(StandaloneConfig.OFFSET_STORAGE_FILE_FILENAME_CONFIG, tempFile.getAbsolutePath()); + props.put(StandaloneConfig.KEY_CONVERTER_CLASS_CONFIG, "org.apache.kafka.connect.json.JsonConverter"); + props.put(StandaloneConfig.VALUE_CONVERTER_CLASS_CONFIG, "org.apache.kafka.connect.json.JsonConverter"); + config = new StandaloneConfig(props); + store.configure(config); + store.start(); + } + + @After + public void teardown() { + tempFile.delete(); + } + + @Test + public void testGetSet() throws Exception { + Callback setCallback = expectSuccessfulSetCallback(); + PowerMock.replayAll(); + + store.set(firstSet, setCallback).get(); + + Map values = store.get(Arrays.asList(buffer("key"), buffer("bad"))).get(); + assertEquals(buffer("value"), values.get(buffer("key"))); + assertNull(values.get(buffer("bad"))); + + PowerMock.verifyAll(); + } + + @Test + public void testSaveRestore() throws Exception { + Callback setCallback = expectSuccessfulSetCallback(); + PowerMock.replayAll(); + + store.set(firstSet, setCallback).get(); + store.stop(); + + // Restore into a new store to ensure correct reload from scratch + FileOffsetBackingStore restore = new FileOffsetBackingStore(); + restore.configure(config); + restore.start(); + Map values = restore.get(Arrays.asList(buffer("key"))).get(); + assertEquals(buffer("value"), values.get(buffer("key"))); + + PowerMock.verifyAll(); + } + + @Test + public void testThreadName() { + assertTrue(((ThreadPoolExecutor) store.executor).getThreadFactory() + .newThread(EMPTY_RUNNABLE).getName().startsWith(FileOffsetBackingStore.class.getSimpleName())); + } + + private static ByteBuffer buffer(String v) { + return ByteBuffer.wrap(v.getBytes()); + } + + private Callback expectSuccessfulSetCallback() { + @SuppressWarnings("unchecked") + Callback setCallback = PowerMock.createMock(Callback.class); + setCallback.onCompletion(EasyMock.isNull(Throwable.class), EasyMock.isNull(Void.class)); + PowerMock.expectLastCall(); + return setCallback; + } +} diff --git a/connect/runtime/src/test/java/org/apache/kafka/connect/storage/KafkaConfigBackingStoreTest.java b/connect/runtime/src/test/java/org/apache/kafka/connect/storage/KafkaConfigBackingStoreTest.java 
new file mode 100644 index 0000000..18d92bf --- /dev/null +++ b/connect/runtime/src/test/java/org/apache/kafka/connect/storage/KafkaConfigBackingStoreTest.java @@ -0,0 +1,1195 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.storage; + +import org.apache.kafka.clients.CommonClientConfigs; +import org.apache.kafka.clients.admin.NewTopic; +import org.apache.kafka.clients.consumer.ConsumerConfig; +import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.clients.producer.ProducerConfig; +import org.apache.kafka.common.header.internals.RecordHeaders; +import org.apache.kafka.common.record.TimestampType; +import org.apache.kafka.common.config.ConfigException; +import org.apache.kafka.connect.data.Field; +import org.apache.kafka.connect.data.Schema; +import org.apache.kafka.connect.data.SchemaAndValue; +import org.apache.kafka.connect.data.Struct; +import org.apache.kafka.connect.runtime.RestartRequest; +import org.apache.kafka.connect.runtime.TargetState; +import org.apache.kafka.connect.runtime.distributed.ClusterConfigState; +import org.apache.kafka.connect.runtime.distributed.DistributedConfig; +import org.apache.kafka.connect.util.Callback; +import org.apache.kafka.connect.util.ConnectUtils; +import org.apache.kafka.connect.util.ConnectorTaskId; +import org.apache.kafka.connect.util.KafkaBasedLog; +import org.apache.kafka.connect.util.TestFuture; +import org.apache.kafka.connect.util.TopicAdmin; +import org.easymock.Capture; +import org.easymock.EasyMock; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.powermock.api.easymock.PowerMock; +import org.powermock.api.easymock.annotation.Mock; +import org.powermock.core.classloader.annotations.PowerMockIgnore; +import org.powermock.core.classloader.annotations.PrepareForTest; +import org.powermock.modules.junit4.PowerMockRunner; +import org.powermock.reflect.Whitebox; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.concurrent.TimeUnit; +import java.util.function.Supplier; + +import static org.apache.kafka.connect.storage.KafkaConfigBackingStore.INCLUDE_TASKS_FIELD_NAME; +import static org.apache.kafka.connect.storage.KafkaConfigBackingStore.ONLY_FAILED_FIELD_NAME; +import static org.apache.kafka.connect.storage.KafkaConfigBackingStore.RESTART_KEY; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; + +@RunWith(PowerMockRunner.class) 
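+// @PrepareForTest below lists the classes PowerMock must instrument: KafkaConfigBackingStore is
+// partially mocked (createKafkaBasedLog) and ConnectUtils has a static method stubbed in setUp(),
+// while @PowerMockIgnore defers loading of the javax.management/javax.crypto packages to the
+// system classloader, a common workaround for PowerMock classloader clashes. The static stubbing
+// pattern used in setUp() looks like this (sketch of the same calls made there):
+//
+//   PowerMock.mockStaticPartial(ConnectUtils.class, "lookupKafkaClusterId");
+//   EasyMock.expect(ConnectUtils.lookupKafkaClusterId(EasyMock.anyObject()))
+//           .andReturn("test-cluster").anyTimes();
+//   PowerMock.replay(ConnectUtils.class);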
+@PrepareForTest({KafkaConfigBackingStore.class, ConnectUtils.class}) +@PowerMockIgnore({"javax.management.*", "javax.crypto.*"}) +@SuppressWarnings({"unchecked", "deprecation"}) +public class KafkaConfigBackingStoreTest { + private static final String TOPIC = "connect-configs"; + private static final short TOPIC_REPLICATION_FACTOR = 5; + private static final Map DEFAULT_CONFIG_STORAGE_PROPS = new HashMap<>(); + private static final DistributedConfig DEFAULT_DISTRIBUTED_CONFIG; + + static { + DEFAULT_CONFIG_STORAGE_PROPS.put(DistributedConfig.CONFIG_TOPIC_CONFIG, TOPIC); + DEFAULT_CONFIG_STORAGE_PROPS.put(DistributedConfig.OFFSET_STORAGE_TOPIC_CONFIG, "connect-offsets"); + DEFAULT_CONFIG_STORAGE_PROPS.put(DistributedConfig.CONFIG_STORAGE_REPLICATION_FACTOR_CONFIG, Short.toString(TOPIC_REPLICATION_FACTOR)); + DEFAULT_CONFIG_STORAGE_PROPS.put(DistributedConfig.GROUP_ID_CONFIG, "connect"); + DEFAULT_CONFIG_STORAGE_PROPS.put(DistributedConfig.STATUS_STORAGE_TOPIC_CONFIG, "status-topic"); + DEFAULT_CONFIG_STORAGE_PROPS.put(CommonClientConfigs.BOOTSTRAP_SERVERS_CONFIG, "broker1:9092,broker2:9093"); + DEFAULT_CONFIG_STORAGE_PROPS.put(DistributedConfig.KEY_CONVERTER_CLASS_CONFIG, "org.apache.kafka.connect.json.JsonConverter"); + DEFAULT_CONFIG_STORAGE_PROPS.put(DistributedConfig.VALUE_CONVERTER_CLASS_CONFIG, "org.apache.kafka.connect.json.JsonConverter"); + DEFAULT_DISTRIBUTED_CONFIG = new DistributedConfig(DEFAULT_CONFIG_STORAGE_PROPS); + } + + private static final List CONNECTOR_IDS = Arrays.asList("connector1", "connector2"); + private static final List CONNECTOR_CONFIG_KEYS = Arrays.asList("connector-connector1", "connector-connector2"); + private static final List COMMIT_TASKS_CONFIG_KEYS = Arrays.asList("commit-connector1", "commit-connector2"); + private static final List TARGET_STATE_KEYS = Arrays.asList("target-state-connector1", "target-state-connector2"); + + private static final String CONNECTOR_1_NAME = "connector1"; + private static final String CONNECTOR_2_NAME = "connector2"; + private static final List RESTART_CONNECTOR_KEYS = Arrays.asList(RESTART_KEY(CONNECTOR_1_NAME), RESTART_KEY(CONNECTOR_2_NAME)); + + // Need a) connector with multiple tasks and b) multiple connectors + private static final List TASK_IDS = Arrays.asList( + new ConnectorTaskId("connector1", 0), + new ConnectorTaskId("connector1", 1), + new ConnectorTaskId("connector2", 0) + ); + private static final List TASK_CONFIG_KEYS = Arrays.asList("task-connector1-0", "task-connector1-1", "task-connector2-0"); + + // Need some placeholders -- the contents don't matter here, just that they are restored properly + private static final List> SAMPLE_CONFIGS = Arrays.asList( + Collections.singletonMap("config-key-one", "config-value-one"), + Collections.singletonMap("config-key-two", "config-value-two"), + Collections.singletonMap("config-key-three", "config-value-three") + ); + private static final List CONNECTOR_CONFIG_STRUCTS = Arrays.asList( + new Struct(KafkaConfigBackingStore.CONNECTOR_CONFIGURATION_V0).put("properties", SAMPLE_CONFIGS.get(0)), + new Struct(KafkaConfigBackingStore.CONNECTOR_CONFIGURATION_V0).put("properties", SAMPLE_CONFIGS.get(1)), + new Struct(KafkaConfigBackingStore.CONNECTOR_CONFIGURATION_V0).put("properties", SAMPLE_CONFIGS.get(2)) + ); + private static final List TASK_CONFIG_STRUCTS = Arrays.asList( + new Struct(KafkaConfigBackingStore.TASK_CONFIGURATION_V0).put("properties", SAMPLE_CONFIGS.get(0)), + new Struct(KafkaConfigBackingStore.TASK_CONFIGURATION_V0).put("properties", 
SAMPLE_CONFIGS.get(1)) + ); + private static final Struct TARGET_STATE_PAUSED = new Struct(KafkaConfigBackingStore.TARGET_STATE_V0).put("state", "PAUSED"); + + private static final Struct TASKS_COMMIT_STRUCT_TWO_TASK_CONNECTOR + = new Struct(KafkaConfigBackingStore.CONNECTOR_TASKS_COMMIT_V0).put("tasks", 2); + + private static final Struct TASKS_COMMIT_STRUCT_ZERO_TASK_CONNECTOR + = new Struct(KafkaConfigBackingStore.CONNECTOR_TASKS_COMMIT_V0).put("tasks", 0); + + private static final Struct ONLY_FAILED_MISSING_STRUCT = new Struct(KafkaConfigBackingStore.RESTART_REQUEST_V0).put(INCLUDE_TASKS_FIELD_NAME, false); + private static final Struct INLUDE_TASKS_MISSING_STRUCT = new Struct(KafkaConfigBackingStore.RESTART_REQUEST_V0).put(ONLY_FAILED_FIELD_NAME, true); + private static final List RESTART_REQUEST_STRUCTS = Arrays.asList( + new Struct(KafkaConfigBackingStore.RESTART_REQUEST_V0).put(ONLY_FAILED_FIELD_NAME, true).put(INCLUDE_TASKS_FIELD_NAME, false), + ONLY_FAILED_MISSING_STRUCT, + INLUDE_TASKS_MISSING_STRUCT); + + // The exact format doesn't matter here since both conversions are mocked + private static final List CONFIGS_SERIALIZED = Arrays.asList( + "config-bytes-1".getBytes(), "config-bytes-2".getBytes(), "config-bytes-3".getBytes(), + "config-bytes-4".getBytes(), "config-bytes-5".getBytes(), "config-bytes-6".getBytes(), + "config-bytes-7".getBytes(), "config-bytes-8".getBytes(), "config-bytes-9".getBytes() + ); + + @Mock + private Converter converter; + @Mock + private ConfigBackingStore.UpdateListener configUpdateListener; + @Mock + KafkaBasedLog storeLog; + private KafkaConfigBackingStore configStorage; + + private Capture capturedTopic = EasyMock.newCapture(); + private Capture> capturedProducerProps = EasyMock.newCapture(); + private Capture> capturedConsumerProps = EasyMock.newCapture(); + private Capture> capturedAdminSupplier = EasyMock.newCapture(); + private Capture capturedNewTopic = EasyMock.newCapture(); + private Capture>> capturedConsumedCallback = EasyMock.newCapture(); + + private long logOffset = 0; + + @Before + public void setUp() { + PowerMock.mockStaticPartial(ConnectUtils.class, "lookupKafkaClusterId"); + EasyMock.expect(ConnectUtils.lookupKafkaClusterId(EasyMock.anyObject())).andReturn("test-cluster").anyTimes(); + PowerMock.replay(ConnectUtils.class); + + configStorage = PowerMock.createPartialMock(KafkaConfigBackingStore.class, new String[]{"createKafkaBasedLog"}, converter, DEFAULT_DISTRIBUTED_CONFIG, null); + Whitebox.setInternalState(configStorage, "configLog", storeLog); + configStorage.setUpdateListener(configUpdateListener); + } + + @Test + public void testStartStop() throws Exception { + expectConfigure(); + expectStart(Collections.emptyList(), Collections.emptyMap()); + expectPartitionCount(1); + expectStop(); + PowerMock.replayAll(); + + Map settings = new HashMap<>(DEFAULT_CONFIG_STORAGE_PROPS); + settings.put("config.storage.min.insync.replicas", "3"); + settings.put("config.storage.max.message.bytes", "1001"); + configStorage.setupAndCreateKafkaBasedLog(TOPIC, new DistributedConfig(settings)); + + assertEquals(TOPIC, capturedTopic.getValue()); + assertEquals("org.apache.kafka.common.serialization.StringSerializer", capturedProducerProps.getValue().get(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG)); + assertEquals("org.apache.kafka.common.serialization.ByteArraySerializer", capturedProducerProps.getValue().get(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG)); + assertEquals("org.apache.kafka.common.serialization.StringDeserializer", 
capturedConsumerProps.getValue().get(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG)); + assertEquals("org.apache.kafka.common.serialization.ByteArrayDeserializer", capturedConsumerProps.getValue().get(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG)); + + assertEquals(TOPIC, capturedNewTopic.getValue().name()); + assertEquals(1, capturedNewTopic.getValue().numPartitions()); + assertEquals(TOPIC_REPLICATION_FACTOR, capturedNewTopic.getValue().replicationFactor()); + assertEquals("3", capturedNewTopic.getValue().configs().get("min.insync.replicas")); + assertEquals("1001", capturedNewTopic.getValue().configs().get("max.message.bytes")); + configStorage.start(); + configStorage.stop(); + + PowerMock.verifyAll(); + } + + @Test + public void testPutConnectorConfig() throws Exception { + expectConfigure(); + expectStart(Collections.emptyList(), Collections.emptyMap()); + + expectConvertWriteAndRead( + CONNECTOR_CONFIG_KEYS.get(0), KafkaConfigBackingStore.CONNECTOR_CONFIGURATION_V0, CONFIGS_SERIALIZED.get(0), + "properties", SAMPLE_CONFIGS.get(0)); + configUpdateListener.onConnectorConfigUpdate(CONNECTOR_IDS.get(0)); + EasyMock.expectLastCall(); + + expectConvertWriteAndRead( + CONNECTOR_CONFIG_KEYS.get(1), KafkaConfigBackingStore.CONNECTOR_CONFIGURATION_V0, CONFIGS_SERIALIZED.get(1), + "properties", SAMPLE_CONFIGS.get(1)); + configUpdateListener.onConnectorConfigUpdate(CONNECTOR_IDS.get(1)); + EasyMock.expectLastCall(); + + // Config deletion + expectConnectorRemoval(CONNECTOR_CONFIG_KEYS.get(1), TARGET_STATE_KEYS.get(1)); + configUpdateListener.onConnectorConfigRemove(CONNECTOR_IDS.get(1)); + EasyMock.expectLastCall(); + + expectPartitionCount(1); + expectStop(); + + PowerMock.replayAll(); + + configStorage.setupAndCreateKafkaBasedLog(TOPIC, DEFAULT_DISTRIBUTED_CONFIG); + configStorage.start(); + + // Null before writing + ClusterConfigState configState = configStorage.snapshot(); + assertEquals(-1, configState.offset()); + assertNull(configState.connectorConfig(CONNECTOR_IDS.get(0))); + assertNull(configState.connectorConfig(CONNECTOR_IDS.get(1))); + + // Writing should block until it is written and read back from Kafka + configStorage.putConnectorConfig(CONNECTOR_IDS.get(0), SAMPLE_CONFIGS.get(0)); + configState = configStorage.snapshot(); + assertEquals(1, configState.offset()); + assertEquals(SAMPLE_CONFIGS.get(0), configState.connectorConfig(CONNECTOR_IDS.get(0))); + assertNull(configState.connectorConfig(CONNECTOR_IDS.get(1))); + + // Second should also block and all configs should still be available + configStorage.putConnectorConfig(CONNECTOR_IDS.get(1), SAMPLE_CONFIGS.get(1)); + configState = configStorage.snapshot(); + assertEquals(2, configState.offset()); + assertEquals(SAMPLE_CONFIGS.get(0), configState.connectorConfig(CONNECTOR_IDS.get(0))); + assertEquals(SAMPLE_CONFIGS.get(1), configState.connectorConfig(CONNECTOR_IDS.get(1))); + + // Deletion should remove the second one we added + configStorage.removeConnectorConfig(CONNECTOR_IDS.get(1)); + configState = configStorage.snapshot(); + assertEquals(4, configState.offset()); + assertEquals(SAMPLE_CONFIGS.get(0), configState.connectorConfig(CONNECTOR_IDS.get(0))); + assertNull(configState.connectorConfig(CONNECTOR_IDS.get(1))); + assertNull(configState.targetState(CONNECTOR_IDS.get(1))); + + configStorage.stop(); + + PowerMock.verifyAll(); + } + + @Test + public void testPutTaskConfigs() throws Exception { + expectConfigure(); + expectStart(Collections.emptyList(), Collections.emptyMap()); + + // Task configs should read to end, write 
to the log, read to end, write root, then read to end again + expectReadToEnd(new LinkedHashMap<>()); + expectConvertWriteRead( + TASK_CONFIG_KEYS.get(0), KafkaConfigBackingStore.TASK_CONFIGURATION_V0, CONFIGS_SERIALIZED.get(0), + "properties", SAMPLE_CONFIGS.get(0)); + expectConvertWriteRead( + TASK_CONFIG_KEYS.get(1), KafkaConfigBackingStore.TASK_CONFIGURATION_V0, CONFIGS_SERIALIZED.get(1), + "properties", SAMPLE_CONFIGS.get(1)); + expectReadToEnd(new LinkedHashMap<>()); + expectConvertWriteRead( + COMMIT_TASKS_CONFIG_KEYS.get(0), KafkaConfigBackingStore.CONNECTOR_TASKS_COMMIT_V0, CONFIGS_SERIALIZED.get(2), + "tasks", 2); // Starts with 0 tasks, after update has 2 + // As soon as root is rewritten, we should see a callback notifying us that we reconfigured some tasks + configUpdateListener.onTaskConfigUpdate(Arrays.asList(TASK_IDS.get(0), TASK_IDS.get(1))); + EasyMock.expectLastCall(); + + // Records to be read by consumer as it reads to the end of the log + LinkedHashMap serializedConfigs = new LinkedHashMap<>(); + serializedConfigs.put(TASK_CONFIG_KEYS.get(0), CONFIGS_SERIALIZED.get(0)); + serializedConfigs.put(TASK_CONFIG_KEYS.get(1), CONFIGS_SERIALIZED.get(1)); + serializedConfigs.put(COMMIT_TASKS_CONFIG_KEYS.get(0), CONFIGS_SERIALIZED.get(2)); + expectReadToEnd(serializedConfigs); + + expectPartitionCount(1); + expectStop(); + + PowerMock.replayAll(); + + configStorage.setupAndCreateKafkaBasedLog(TOPIC, DEFAULT_DISTRIBUTED_CONFIG); + configStorage.start(); + + // Bootstrap as if we had already added the connector, but no tasks had been added yet + whiteboxAddConnector(CONNECTOR_IDS.get(0), SAMPLE_CONFIGS.get(0), Collections.emptyList()); + + // Null before writing + ClusterConfigState configState = configStorage.snapshot(); + assertEquals(-1, configState.offset()); + assertNull(configState.taskConfig(TASK_IDS.get(0))); + assertNull(configState.taskConfig(TASK_IDS.get(1))); + + // Writing task configs should block until all the writes have been performed and the root record update + // has completed + List> taskConfigs = Arrays.asList(SAMPLE_CONFIGS.get(0), SAMPLE_CONFIGS.get(1)); + configStorage.putTaskConfigs("connector1", taskConfigs); + + // Validate root config by listing all connectors and tasks + configState = configStorage.snapshot(); + assertEquals(3, configState.offset()); + String connectorName = CONNECTOR_IDS.get(0); + assertEquals(Arrays.asList(connectorName), new ArrayList<>(configState.connectors())); + assertEquals(Arrays.asList(TASK_IDS.get(0), TASK_IDS.get(1)), configState.tasks(connectorName)); + assertEquals(SAMPLE_CONFIGS.get(0), configState.taskConfig(TASK_IDS.get(0))); + assertEquals(SAMPLE_CONFIGS.get(1), configState.taskConfig(TASK_IDS.get(1))); + assertEquals(Collections.EMPTY_SET, configState.inconsistentConnectors()); + + configStorage.stop(); + + PowerMock.verifyAll(); + } + + @Test + public void testPutTaskConfigsStartsOnlyReconfiguredTasks() throws Exception { + expectConfigure(); + expectStart(Collections.emptyList(), Collections.emptyMap()); + + // Task configs should read to end, write to the log, read to end, write root, then read to end again + expectReadToEnd(new LinkedHashMap<>()); + expectConvertWriteRead( + TASK_CONFIG_KEYS.get(0), KafkaConfigBackingStore.TASK_CONFIGURATION_V0, CONFIGS_SERIALIZED.get(0), + "properties", SAMPLE_CONFIGS.get(0)); + expectConvertWriteRead( + TASK_CONFIG_KEYS.get(1), KafkaConfigBackingStore.TASK_CONFIGURATION_V0, CONFIGS_SERIALIZED.get(1), + "properties", SAMPLE_CONFIGS.get(1)); + expectReadToEnd(new 
LinkedHashMap<>()); + expectConvertWriteRead( + COMMIT_TASKS_CONFIG_KEYS.get(0), KafkaConfigBackingStore.CONNECTOR_TASKS_COMMIT_V0, CONFIGS_SERIALIZED.get(2), + "tasks", 2); // Starts with 0 tasks, after update has 2 + // As soon as root is rewritten, we should see a callback notifying us that we reconfigured some tasks + configUpdateListener.onTaskConfigUpdate(Arrays.asList(TASK_IDS.get(0), TASK_IDS.get(1))); + EasyMock.expectLastCall(); + + // Records to be read by consumer as it reads to the end of the log + LinkedHashMap serializedConfigs = new LinkedHashMap<>(); + serializedConfigs.put(TASK_CONFIG_KEYS.get(0), CONFIGS_SERIALIZED.get(0)); + serializedConfigs.put(TASK_CONFIG_KEYS.get(1), CONFIGS_SERIALIZED.get(1)); + serializedConfigs.put(COMMIT_TASKS_CONFIG_KEYS.get(0), CONFIGS_SERIALIZED.get(2)); + expectReadToEnd(serializedConfigs); + + // Task configs should read to end, write to the log, read to end, write root, then read to end again + expectReadToEnd(new LinkedHashMap<>()); + expectConvertWriteRead( + TASK_CONFIG_KEYS.get(2), KafkaConfigBackingStore.TASK_CONFIGURATION_V0, CONFIGS_SERIALIZED.get(3), + "properties", SAMPLE_CONFIGS.get(2)); + expectReadToEnd(new LinkedHashMap<>()); + expectConvertWriteRead( + COMMIT_TASKS_CONFIG_KEYS.get(1), KafkaConfigBackingStore.CONNECTOR_TASKS_COMMIT_V0, CONFIGS_SERIALIZED.get(4), + "tasks", 1); // Starts with 2 tasks, after update has 3 + + // As soon as root is rewritten, we should see a callback notifying us that we reconfigured some tasks + configUpdateListener.onTaskConfigUpdate(Arrays.asList(TASK_IDS.get(2))); + EasyMock.expectLastCall(); + + // Records to be read by consumer as it reads to the end of the log + serializedConfigs = new LinkedHashMap<>(); + serializedConfigs.put(TASK_CONFIG_KEYS.get(2), CONFIGS_SERIALIZED.get(3)); + serializedConfigs.put(COMMIT_TASKS_CONFIG_KEYS.get(1), CONFIGS_SERIALIZED.get(4)); + expectReadToEnd(serializedConfigs); + + expectPartitionCount(1); + expectStop(); + + PowerMock.replayAll(); + + configStorage.setupAndCreateKafkaBasedLog(TOPIC, DEFAULT_DISTRIBUTED_CONFIG); + configStorage.start(); + + // Bootstrap as if we had already added the connector, but no tasks had been added yet + whiteboxAddConnector(CONNECTOR_IDS.get(0), SAMPLE_CONFIGS.get(0), Collections.emptyList()); + whiteboxAddConnector(CONNECTOR_IDS.get(1), SAMPLE_CONFIGS.get(1), Collections.emptyList()); + + // Null before writing + ClusterConfigState configState = configStorage.snapshot(); + assertEquals(-1, configState.offset()); + assertNull(configState.taskConfig(TASK_IDS.get(0))); + assertNull(configState.taskConfig(TASK_IDS.get(1))); + + // Writing task configs should block until all the writes have been performed and the root record update + // has completed + List> taskConfigs = Arrays.asList(SAMPLE_CONFIGS.get(0), SAMPLE_CONFIGS.get(1)); + configStorage.putTaskConfigs("connector1", taskConfigs); + taskConfigs = Collections.singletonList(SAMPLE_CONFIGS.get(2)); + configStorage.putTaskConfigs("connector2", taskConfigs); + + // Validate root config by listing all connectors and tasks + configState = configStorage.snapshot(); + assertEquals(5, configState.offset()); + String connectorName1 = CONNECTOR_IDS.get(0); + String connectorName2 = CONNECTOR_IDS.get(1); + assertEquals(Arrays.asList(connectorName1, connectorName2), new ArrayList<>(configState.connectors())); + assertEquals(Arrays.asList(TASK_IDS.get(0), TASK_IDS.get(1)), configState.tasks(connectorName1)); + assertEquals(Collections.singletonList(TASK_IDS.get(2)), 
configState.tasks(connectorName2)); + assertEquals(SAMPLE_CONFIGS.get(0), configState.taskConfig(TASK_IDS.get(0))); + assertEquals(SAMPLE_CONFIGS.get(1), configState.taskConfig(TASK_IDS.get(1))); + assertEquals(SAMPLE_CONFIGS.get(2), configState.taskConfig(TASK_IDS.get(2))); + assertEquals(Collections.EMPTY_SET, configState.inconsistentConnectors()); + + configStorage.stop(); + + PowerMock.verifyAll(); + } + + @Test + public void testPutTaskConfigsZeroTasks() throws Exception { + expectConfigure(); + expectStart(Collections.emptyList(), Collections.emptyMap()); + + // Task configs should read to end, write to the log, read to end, write root. + expectReadToEnd(new LinkedHashMap<>()); + expectConvertWriteRead( + COMMIT_TASKS_CONFIG_KEYS.get(0), KafkaConfigBackingStore.CONNECTOR_TASKS_COMMIT_V0, CONFIGS_SERIALIZED.get(0), + "tasks", 0); // We have 0 tasks + // As soon as root is rewritten, we should see a callback notifying us that we reconfigured some tasks + configUpdateListener.onTaskConfigUpdate(Collections.emptyList()); + EasyMock.expectLastCall(); + + // Records to be read by consumer as it reads to the end of the log + LinkedHashMap serializedConfigs = new LinkedHashMap<>(); + serializedConfigs.put(COMMIT_TASKS_CONFIG_KEYS.get(0), CONFIGS_SERIALIZED.get(0)); + expectReadToEnd(serializedConfigs); + + expectPartitionCount(1); + expectStop(); + + PowerMock.replayAll(); + + configStorage.setupAndCreateKafkaBasedLog(TOPIC, DEFAULT_DISTRIBUTED_CONFIG); + configStorage.start(); + + // Bootstrap as if we had already added the connector, but no tasks had been added yet + whiteboxAddConnector(CONNECTOR_IDS.get(0), SAMPLE_CONFIGS.get(0), Collections.emptyList()); + + // Null before writing + ClusterConfigState configState = configStorage.snapshot(); + assertEquals(-1, configState.offset()); + + // Writing task configs should block until all the writes have been performed and the root record update + // has completed + List> taskConfigs = Collections.emptyList(); + configStorage.putTaskConfigs("connector1", taskConfigs); + + // Validate root config by listing all connectors and tasks + configState = configStorage.snapshot(); + assertEquals(1, configState.offset()); + String connectorName = CONNECTOR_IDS.get(0); + assertEquals(Arrays.asList(connectorName), new ArrayList<>(configState.connectors())); + assertEquals(Collections.emptyList(), configState.tasks(connectorName)); + assertEquals(Collections.EMPTY_SET, configState.inconsistentConnectors()); + + configStorage.stop(); + + PowerMock.verifyAll(); + } + + @Test + public void testRestoreTargetState() throws Exception { + expectConfigure(); + List> existingRecords = Arrays.asList( + new ConsumerRecord<>(TOPIC, 0, 0, 0L, TimestampType.CREATE_TIME, 0, 0, CONNECTOR_CONFIG_KEYS.get(0), + CONFIGS_SERIALIZED.get(0), new RecordHeaders(), Optional.empty()), + new ConsumerRecord<>(TOPIC, 0, 1, 0L, TimestampType.CREATE_TIME, 0, 0, TASK_CONFIG_KEYS.get(0), + CONFIGS_SERIALIZED.get(1), new RecordHeaders(), Optional.empty()), + new ConsumerRecord<>(TOPIC, 0, 2, 0L, TimestampType.CREATE_TIME, 0, 0, TASK_CONFIG_KEYS.get(1), + CONFIGS_SERIALIZED.get(2), new RecordHeaders(), Optional.empty()), + new ConsumerRecord<>(TOPIC, 0, 3, 0L, TimestampType.CREATE_TIME, 0, 0, TARGET_STATE_KEYS.get(0), + CONFIGS_SERIALIZED.get(3), new RecordHeaders(), Optional.empty()), + new ConsumerRecord<>(TOPIC, 0, 4, 0L, TimestampType.CREATE_TIME, 0, 0, COMMIT_TASKS_CONFIG_KEYS.get(0), + CONFIGS_SERIALIZED.get(4), new RecordHeaders(), Optional.empty())); + LinkedHashMap deserialized = 
new LinkedHashMap<>();
+        deserialized.put(CONFIGS_SERIALIZED.get(0), CONNECTOR_CONFIG_STRUCTS.get(0));
+        deserialized.put(CONFIGS_SERIALIZED.get(1), TASK_CONFIG_STRUCTS.get(0));
+        deserialized.put(CONFIGS_SERIALIZED.get(2), TASK_CONFIG_STRUCTS.get(0));
+        deserialized.put(CONFIGS_SERIALIZED.get(3), TARGET_STATE_PAUSED);
+        deserialized.put(CONFIGS_SERIALIZED.get(4), TASKS_COMMIT_STRUCT_TWO_TASK_CONNECTOR);
+        logOffset = 5;
+
+        expectStart(existingRecords, deserialized);
+
+        // Shouldn't see any callbacks since this is during startup
+
+        expectPartitionCount(1);
+        expectStop();
+
+        PowerMock.replayAll();
+
+        configStorage.setupAndCreateKafkaBasedLog(TOPIC, DEFAULT_DISTRIBUTED_CONFIG);
+        configStorage.start();
+
+        // Should see a single connector with initial state paused
+        ClusterConfigState configState = configStorage.snapshot();
+        assertEquals(5, configState.offset()); // Should always be next to be read, even if uncommitted
+        assertEquals(Arrays.asList(CONNECTOR_IDS.get(0)), new ArrayList<>(configState.connectors()));
+        assertEquals(TargetState.PAUSED, configState.targetState(CONNECTOR_IDS.get(0)));
+
+        configStorage.stop();
+
+        PowerMock.verifyAll();
+    }
+
+    @Test
+    public void testBackgroundUpdateTargetState() throws Exception {
+        // verify that we handle target state changes correctly when they come up through the log
+
+        expectConfigure();
+        List<ConsumerRecord<String, byte[]>> existingRecords = Arrays.asList(
+                new ConsumerRecord<>(TOPIC, 0, 0, 0L, TimestampType.CREATE_TIME, 0, 0, CONNECTOR_CONFIG_KEYS.get(0),
+                        CONFIGS_SERIALIZED.get(0), new RecordHeaders(), Optional.empty()),
+                new ConsumerRecord<>(TOPIC, 0, 1, 0L, TimestampType.CREATE_TIME, 0, 0, TASK_CONFIG_KEYS.get(0),
+                        CONFIGS_SERIALIZED.get(1), new RecordHeaders(), Optional.empty()),
+                new ConsumerRecord<>(TOPIC, 0, 2, 0L, TimestampType.CREATE_TIME, 0, 0, TASK_CONFIG_KEYS.get(1),
+                        CONFIGS_SERIALIZED.get(2), new RecordHeaders(), Optional.empty()),
+                new ConsumerRecord<>(TOPIC, 0, 3, 0L, TimestampType.CREATE_TIME, 0, 0, COMMIT_TASKS_CONFIG_KEYS.get(0),
+                        CONFIGS_SERIALIZED.get(3), new RecordHeaders(), Optional.empty()));
+        LinkedHashMap<byte[], Struct> deserialized = new LinkedHashMap<>();
+        deserialized.put(CONFIGS_SERIALIZED.get(0), CONNECTOR_CONFIG_STRUCTS.get(0));
+        deserialized.put(CONFIGS_SERIALIZED.get(1), TASK_CONFIG_STRUCTS.get(0));
+        deserialized.put(CONFIGS_SERIALIZED.get(2), TASK_CONFIG_STRUCTS.get(0));
+        deserialized.put(CONFIGS_SERIALIZED.get(3), TASKS_COMMIT_STRUCT_TWO_TASK_CONNECTOR);
+        logOffset = 5;
+
+        expectStart(existingRecords, deserialized);
+
+        expectRead(TARGET_STATE_KEYS.get(0), CONFIGS_SERIALIZED.get(0), TARGET_STATE_PAUSED);
+
+        configUpdateListener.onConnectorTargetStateChange(CONNECTOR_IDS.get(0));
+        EasyMock.expectLastCall();
+
+        expectPartitionCount(1);
+        expectStop();
+
+        PowerMock.replayAll();
+
+        configStorage.setupAndCreateKafkaBasedLog(TOPIC, DEFAULT_DISTRIBUTED_CONFIG);
+        configStorage.start();
+
+        // No target state record was restored, so the connector should start in the default STARTED state;
+        // the PAUSED record is only picked up by the background refresh below
+        ClusterConfigState configState = configStorage.snapshot();
+        assertEquals(TargetState.STARTED, configState.targetState(CONNECTOR_IDS.get(0)));
+
+        configStorage.refresh(0, TimeUnit.SECONDS);
+
+        configStorage.stop();
+
+        PowerMock.verifyAll();
+    }
+
+    @Test
+    public void testBackgroundConnectorDeletion() throws Exception {
+        // verify that we handle connector deletions correctly when they come up through the log
+
+        expectConfigure();
+        List<ConsumerRecord<String, byte[]>> existingRecords = Arrays.asList(
+                new ConsumerRecord<>(TOPIC, 0, 0, 0L, TimestampType.CREATE_TIME, 0, 0, CONNECTOR_CONFIG_KEYS.get(0),
+
CONFIGS_SERIALIZED.get(0), new RecordHeaders(), Optional.empty()), + new ConsumerRecord<>(TOPIC, 0, 1, 0L, TimestampType.CREATE_TIME, 0, 0, TASK_CONFIG_KEYS.get(0), + CONFIGS_SERIALIZED.get(1), new RecordHeaders(), Optional.empty()), + new ConsumerRecord<>(TOPIC, 0, 2, 0L, TimestampType.CREATE_TIME, 0, 0, TASK_CONFIG_KEYS.get(1), + CONFIGS_SERIALIZED.get(2), new RecordHeaders(), Optional.empty()), + new ConsumerRecord<>(TOPIC, 0, 3, 0L, TimestampType.CREATE_TIME, 0, 0, COMMIT_TASKS_CONFIG_KEYS.get(0), + CONFIGS_SERIALIZED.get(3), new RecordHeaders(), Optional.empty())); + LinkedHashMap deserialized = new LinkedHashMap<>(); + deserialized.put(CONFIGS_SERIALIZED.get(0), CONNECTOR_CONFIG_STRUCTS.get(0)); + deserialized.put(CONFIGS_SERIALIZED.get(1), TASK_CONFIG_STRUCTS.get(0)); + deserialized.put(CONFIGS_SERIALIZED.get(2), TASK_CONFIG_STRUCTS.get(1)); + deserialized.put(CONFIGS_SERIALIZED.get(3), TASKS_COMMIT_STRUCT_TWO_TASK_CONNECTOR); + logOffset = 5; + + expectStart(existingRecords, deserialized); + + LinkedHashMap serializedData = new LinkedHashMap<>(); + serializedData.put(CONNECTOR_CONFIG_KEYS.get(0), CONFIGS_SERIALIZED.get(0)); + serializedData.put(TARGET_STATE_KEYS.get(0), CONFIGS_SERIALIZED.get(1)); + + Map deserializedData = new HashMap<>(); + deserializedData.put(CONNECTOR_CONFIG_KEYS.get(0), null); + deserializedData.put(TARGET_STATE_KEYS.get(0), null); + + expectRead(serializedData, deserializedData); + + configUpdateListener.onConnectorConfigRemove(CONNECTOR_IDS.get(0)); + EasyMock.expectLastCall(); + + expectPartitionCount(1); + expectStop(); + + PowerMock.replayAll(); + + configStorage.setupAndCreateKafkaBasedLog(TOPIC, DEFAULT_DISTRIBUTED_CONFIG); + configStorage.start(); + + // Should see a single connector with initial state paused + ClusterConfigState configState = configStorage.snapshot(); + assertEquals(TargetState.STARTED, configState.targetState(CONNECTOR_IDS.get(0))); + assertEquals(SAMPLE_CONFIGS.get(0), configState.connectorConfig(CONNECTOR_IDS.get(0))); + assertEquals(SAMPLE_CONFIGS.subList(0, 2), configState.allTaskConfigs(CONNECTOR_IDS.get(0))); + assertEquals(2, configState.taskCount(CONNECTOR_IDS.get(0))); + + configStorage.refresh(0, TimeUnit.SECONDS); + configState = configStorage.snapshot(); + // Connector should now be removed from the snapshot + assertFalse(configState.contains(CONNECTOR_IDS.get(0))); + // Task configs for the deleted connector should also be removed from the snapshot + assertEquals(Collections.emptyList(), configState.allTaskConfigs(CONNECTOR_IDS.get(0))); + assertEquals(0, configState.taskCount(CONNECTOR_IDS.get(0))); + + configStorage.stop(); + + PowerMock.verifyAll(); + } + + @Test + public void testRestoreTargetStateUnexpectedDeletion() throws Exception { + expectConfigure(); + List> existingRecords = Arrays.asList( + new ConsumerRecord<>(TOPIC, 0, 0, 0L, TimestampType.CREATE_TIME, 0, 0, CONNECTOR_CONFIG_KEYS.get(0), + CONFIGS_SERIALIZED.get(0), new RecordHeaders(), Optional.empty()), + new ConsumerRecord<>(TOPIC, 0, 1, 0L, TimestampType.CREATE_TIME, 0, 0, TASK_CONFIG_KEYS.get(0), + CONFIGS_SERIALIZED.get(1), new RecordHeaders(), Optional.empty()), + new ConsumerRecord<>(TOPIC, 0, 2, 0L, TimestampType.CREATE_TIME, 0, 0, TASK_CONFIG_KEYS.get(1), + CONFIGS_SERIALIZED.get(2), new RecordHeaders(), Optional.empty()), + new ConsumerRecord<>(TOPIC, 0, 3, 0L, TimestampType.CREATE_TIME, 0, 0, TARGET_STATE_KEYS.get(0), + CONFIGS_SERIALIZED.get(3), new RecordHeaders(), Optional.empty()), + new ConsumerRecord<>(TOPIC, 0, 4, 0L, 
TimestampType.CREATE_TIME, 0, 0, COMMIT_TASKS_CONFIG_KEYS.get(0), + CONFIGS_SERIALIZED.get(4), new RecordHeaders(), Optional.empty())); + LinkedHashMap deserialized = new LinkedHashMap<>(); + deserialized.put(CONFIGS_SERIALIZED.get(0), CONNECTOR_CONFIG_STRUCTS.get(0)); + deserialized.put(CONFIGS_SERIALIZED.get(1), TASK_CONFIG_STRUCTS.get(0)); + deserialized.put(CONFIGS_SERIALIZED.get(2), TASK_CONFIG_STRUCTS.get(0)); + deserialized.put(CONFIGS_SERIALIZED.get(3), null); + deserialized.put(CONFIGS_SERIALIZED.get(4), TASKS_COMMIT_STRUCT_TWO_TASK_CONNECTOR); + logOffset = 5; + + expectStart(existingRecords, deserialized); + expectPartitionCount(1); + + // Shouldn't see any callbacks since this is during startup + + expectStop(); + + PowerMock.replayAll(); + + configStorage.setupAndCreateKafkaBasedLog(TOPIC, DEFAULT_DISTRIBUTED_CONFIG); + configStorage.start(); + + // The target state deletion should reset the state to STARTED + ClusterConfigState configState = configStorage.snapshot(); + assertEquals(5, configState.offset()); // Should always be next to be read, even if uncommitted + assertEquals(Arrays.asList(CONNECTOR_IDS.get(0)), new ArrayList<>(configState.connectors())); + assertEquals(TargetState.STARTED, configState.targetState(CONNECTOR_IDS.get(0))); + + configStorage.stop(); + + PowerMock.verifyAll(); + } + + @Test + public void testRestore() throws Exception { + // Restoring data should notify only of the latest values after loading is complete. This also validates + // that inconsistent state is ignored. + + expectConfigure(); + // Overwrite each type at least once to ensure we see the latest data after loading + List> existingRecords = Arrays.asList( + new ConsumerRecord<>(TOPIC, 0, 0, 0L, TimestampType.CREATE_TIME, 0, 0, CONNECTOR_CONFIG_KEYS.get(0), + CONFIGS_SERIALIZED.get(0), new RecordHeaders(), Optional.empty()), + new ConsumerRecord<>(TOPIC, 0, 1, 0L, TimestampType.CREATE_TIME, 0, 0, TASK_CONFIG_KEYS.get(0), + CONFIGS_SERIALIZED.get(1), new RecordHeaders(), Optional.empty()), + new ConsumerRecord<>(TOPIC, 0, 2, 0L, TimestampType.CREATE_TIME, 0, 0, TASK_CONFIG_KEYS.get(1), + CONFIGS_SERIALIZED.get(2), new RecordHeaders(), Optional.empty()), + new ConsumerRecord<>(TOPIC, 0, 3, 0L, TimestampType.CREATE_TIME, 0, 0, CONNECTOR_CONFIG_KEYS.get(0), + CONFIGS_SERIALIZED.get(3), new RecordHeaders(), Optional.empty()), + new ConsumerRecord<>(TOPIC, 0, 4, 0L, TimestampType.CREATE_TIME, 0, 0, COMMIT_TASKS_CONFIG_KEYS.get(0), + CONFIGS_SERIALIZED.get(4), new RecordHeaders(), Optional.empty()), + // Connector after root update should make it through, task update shouldn't + new ConsumerRecord<>(TOPIC, 0, 5, 0L, TimestampType.CREATE_TIME, 0, 0, CONNECTOR_CONFIG_KEYS.get(0), + CONFIGS_SERIALIZED.get(5), new RecordHeaders(), Optional.empty()), + new ConsumerRecord<>(TOPIC, 0, 6, 0L, TimestampType.CREATE_TIME, 0, 0, TASK_CONFIG_KEYS.get(0), + CONFIGS_SERIALIZED.get(6), new RecordHeaders(), Optional.empty())); + LinkedHashMap deserialized = new LinkedHashMap<>(); + deserialized.put(CONFIGS_SERIALIZED.get(0), CONNECTOR_CONFIG_STRUCTS.get(0)); + deserialized.put(CONFIGS_SERIALIZED.get(1), TASK_CONFIG_STRUCTS.get(0)); + deserialized.put(CONFIGS_SERIALIZED.get(2), TASK_CONFIG_STRUCTS.get(0)); + deserialized.put(CONFIGS_SERIALIZED.get(3), CONNECTOR_CONFIG_STRUCTS.get(1)); + deserialized.put(CONFIGS_SERIALIZED.get(4), TASKS_COMMIT_STRUCT_TWO_TASK_CONNECTOR); + deserialized.put(CONFIGS_SERIALIZED.get(5), CONNECTOR_CONFIG_STRUCTS.get(2)); + deserialized.put(CONFIGS_SERIALIZED.get(6), 
TASK_CONFIG_STRUCTS.get(1)); + logOffset = 7; + expectStart(existingRecords, deserialized); + expectPartitionCount(1); + + // Shouldn't see any callbacks since this is during startup + + expectStop(); + + PowerMock.replayAll(); + + configStorage.setupAndCreateKafkaBasedLog(TOPIC, DEFAULT_DISTRIBUTED_CONFIG); + configStorage.start(); + + // Should see a single connector and its config should be the last one seen anywhere in the log + ClusterConfigState configState = configStorage.snapshot(); + assertEquals(7, configState.offset()); // Should always be next to be read, even if uncommitted + assertEquals(Arrays.asList(CONNECTOR_IDS.get(0)), new ArrayList<>(configState.connectors())); + assertEquals(TargetState.STARTED, configState.targetState(CONNECTOR_IDS.get(0))); + // CONNECTOR_CONFIG_STRUCTS[2] -> SAMPLE_CONFIGS[2] + assertEquals(SAMPLE_CONFIGS.get(2), configState.connectorConfig(CONNECTOR_IDS.get(0))); + // Should see 2 tasks for that connector. Only config updates before the root key update should be reflected + assertEquals(Arrays.asList(TASK_IDS.get(0), TASK_IDS.get(1)), configState.tasks(CONNECTOR_IDS.get(0))); + // Both TASK_CONFIG_STRUCTS[0] -> SAMPLE_CONFIGS[0] + assertEquals(SAMPLE_CONFIGS.get(0), configState.taskConfig(TASK_IDS.get(0))); + assertEquals(SAMPLE_CONFIGS.get(0), configState.taskConfig(TASK_IDS.get(1))); + assertEquals(Collections.EMPTY_SET, configState.inconsistentConnectors()); + + configStorage.stop(); + + PowerMock.verifyAll(); + } + + @Test + public void testRestoreConnectorDeletion() throws Exception { + // Restoring data should notify only of the latest values after loading is complete. This also validates + // that inconsistent state is ignored. + + expectConfigure(); + // Overwrite each type at least once to ensure we see the latest data after loading + List> existingRecords = Arrays.asList( + new ConsumerRecord<>(TOPIC, 0, 0, 0L, TimestampType.CREATE_TIME, 0, 0, CONNECTOR_CONFIG_KEYS.get(0), + CONFIGS_SERIALIZED.get(0), new RecordHeaders(), Optional.empty()), + new ConsumerRecord<>(TOPIC, 0, 1, 0L, TimestampType.CREATE_TIME, 0, 0, TASK_CONFIG_KEYS.get(0), + CONFIGS_SERIALIZED.get(1), new RecordHeaders(), Optional.empty()), + new ConsumerRecord<>(TOPIC, 0, 2, 0L, TimestampType.CREATE_TIME, 0, 0, TASK_CONFIG_KEYS.get(1), + CONFIGS_SERIALIZED.get(2), new RecordHeaders(), Optional.empty()), + new ConsumerRecord<>(TOPIC, 0, 3, 0L, TimestampType.CREATE_TIME, 0, 0, CONNECTOR_CONFIG_KEYS.get(0), + CONFIGS_SERIALIZED.get(3), new RecordHeaders(), Optional.empty()), + new ConsumerRecord<>(TOPIC, 0, 4, 0L, TimestampType.CREATE_TIME, 0, 0, TARGET_STATE_KEYS.get(0), + CONFIGS_SERIALIZED.get(4), new RecordHeaders(), Optional.empty()), + new ConsumerRecord<>(TOPIC, 0, 5, 0L, TimestampType.CREATE_TIME, 0, 0, COMMIT_TASKS_CONFIG_KEYS.get(0), + CONFIGS_SERIALIZED.get(5), new RecordHeaders(), Optional.empty())); + + LinkedHashMap deserialized = new LinkedHashMap<>(); + deserialized.put(CONFIGS_SERIALIZED.get(0), CONNECTOR_CONFIG_STRUCTS.get(0)); + deserialized.put(CONFIGS_SERIALIZED.get(1), TASK_CONFIG_STRUCTS.get(0)); + deserialized.put(CONFIGS_SERIALIZED.get(2), TASK_CONFIG_STRUCTS.get(0)); + deserialized.put(CONFIGS_SERIALIZED.get(3), null); + deserialized.put(CONFIGS_SERIALIZED.get(4), null); + deserialized.put(CONFIGS_SERIALIZED.get(5), TASKS_COMMIT_STRUCT_TWO_TASK_CONNECTOR); + + logOffset = 6; + expectStart(existingRecords, deserialized); + expectPartitionCount(1); + + // Shouldn't see any callbacks since this is during startup + + expectStop(); + + 
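+        // Both the connector config and its target state were tombstoned (null) in the restored records above,
+        // so the connector should be gone from the snapshot once startup completes.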
PowerMock.replayAll(); + + configStorage.setupAndCreateKafkaBasedLog(TOPIC, DEFAULT_DISTRIBUTED_CONFIG); + configStorage.start(); + + // Should see a single connector and its config should be the last one seen anywhere in the log + ClusterConfigState configState = configStorage.snapshot(); + assertEquals(6, configState.offset()); // Should always be next to be read, even if uncommitted + assertTrue(configState.connectors().isEmpty()); + + configStorage.stop(); + + PowerMock.verifyAll(); + } + + @Test + public void testRestoreZeroTasks() throws Exception { + // Restoring data should notify only of the latest values after loading is complete. This also validates + // that inconsistent state is ignored. + expectConfigure(); + // Overwrite each type at least once to ensure we see the latest data after loading + List> existingRecords = Arrays.asList( + new ConsumerRecord<>(TOPIC, 0, 0, 0L, TimestampType.CREATE_TIME, 0, 0, CONNECTOR_CONFIG_KEYS.get(0), + CONFIGS_SERIALIZED.get(0), new RecordHeaders(), Optional.empty()), + new ConsumerRecord<>(TOPIC, 0, 1, 0L, TimestampType.CREATE_TIME, 0, 0, TASK_CONFIG_KEYS.get(0), + CONFIGS_SERIALIZED.get(1), new RecordHeaders(), Optional.empty()), + new ConsumerRecord<>(TOPIC, 0, 2, 0L, TimestampType.CREATE_TIME, 0, 0, TASK_CONFIG_KEYS.get(1), + CONFIGS_SERIALIZED.get(2), new RecordHeaders(), Optional.empty()), + new ConsumerRecord<>(TOPIC, 0, 3, 0L, TimestampType.CREATE_TIME, 0, 0, CONNECTOR_CONFIG_KEYS.get(0), + CONFIGS_SERIALIZED.get(3), new RecordHeaders(), Optional.empty()), + new ConsumerRecord<>(TOPIC, 0, 4, 0L, TimestampType.CREATE_TIME, 0, 0, COMMIT_TASKS_CONFIG_KEYS.get(0), + CONFIGS_SERIALIZED.get(4), new RecordHeaders(), Optional.empty()), + // Connector after root update should make it through, task update shouldn't + new ConsumerRecord<>(TOPIC, 0, 5, 0L, TimestampType.CREATE_TIME, 0, 0, CONNECTOR_CONFIG_KEYS.get(0), + CONFIGS_SERIALIZED.get(5), new RecordHeaders(), Optional.empty()), + new ConsumerRecord<>(TOPIC, 0, 6, 0L, TimestampType.CREATE_TIME, 0, 0, TASK_CONFIG_KEYS.get(0), + CONFIGS_SERIALIZED.get(6), new RecordHeaders(), Optional.empty()), + new ConsumerRecord<>(TOPIC, 0, 7, 0L, TimestampType.CREATE_TIME, 0, 0, COMMIT_TASKS_CONFIG_KEYS.get(0), + CONFIGS_SERIALIZED.get(7), new RecordHeaders(), Optional.empty())); + LinkedHashMap deserialized = new LinkedHashMap<>(); + deserialized.put(CONFIGS_SERIALIZED.get(0), CONNECTOR_CONFIG_STRUCTS.get(0)); + deserialized.put(CONFIGS_SERIALIZED.get(1), TASK_CONFIG_STRUCTS.get(0)); + deserialized.put(CONFIGS_SERIALIZED.get(2), TASK_CONFIG_STRUCTS.get(0)); + deserialized.put(CONFIGS_SERIALIZED.get(3), CONNECTOR_CONFIG_STRUCTS.get(1)); + deserialized.put(CONFIGS_SERIALIZED.get(4), TASKS_COMMIT_STRUCT_TWO_TASK_CONNECTOR); + deserialized.put(CONFIGS_SERIALIZED.get(5), CONNECTOR_CONFIG_STRUCTS.get(2)); + deserialized.put(CONFIGS_SERIALIZED.get(6), TASK_CONFIG_STRUCTS.get(1)); + deserialized.put(CONFIGS_SERIALIZED.get(7), TASKS_COMMIT_STRUCT_ZERO_TASK_CONNECTOR); + logOffset = 8; + expectStart(existingRecords, deserialized); + expectPartitionCount(1); + + // Shouldn't see any callbacks since this is during startup + + expectStop(); + + PowerMock.replayAll(); + + configStorage.setupAndCreateKafkaBasedLog(TOPIC, DEFAULT_DISTRIBUTED_CONFIG); + configStorage.start(); + + // Should see a single connector and its config should be the last one seen anywhere in the log + ClusterConfigState configState = configStorage.snapshot(); + assertEquals(8, configState.offset()); // Should always be next to be read, even if 
uncommitted + assertEquals(Arrays.asList(CONNECTOR_IDS.get(0)), new ArrayList<>(configState.connectors())); + // CONNECTOR_CONFIG_STRUCTS[2] -> SAMPLE_CONFIGS[2] + assertEquals(SAMPLE_CONFIGS.get(2), configState.connectorConfig(CONNECTOR_IDS.get(0))); + // Should see 0 tasks for that connector. + assertEquals(Collections.emptyList(), configState.tasks(CONNECTOR_IDS.get(0))); + // Both TASK_CONFIG_STRUCTS[0] -> SAMPLE_CONFIGS[0] + assertEquals(Collections.EMPTY_SET, configState.inconsistentConnectors()); + + configStorage.stop(); + + PowerMock.verifyAll(); + } + + @Test + public void testPutTaskConfigsDoesNotResolveAllInconsistencies() throws Exception { + // Test a case where a failure and compaction has left us in an inconsistent state when reading the log. + // We start out by loading an initial configuration where we started to write a task update, and then + // compaction cleaned up the earlier record. + + expectConfigure(); + List> existingRecords = Arrays.asList( + new ConsumerRecord<>(TOPIC, 0, 0, 0L, TimestampType.CREATE_TIME, 0, 0, CONNECTOR_CONFIG_KEYS.get(0), + CONFIGS_SERIALIZED.get(0), new RecordHeaders(), Optional.empty()), + // This is the record that has been compacted: + //new ConsumerRecord<>(TOPIC, 0, 1, TASK_CONFIG_KEYS.get(0), CONFIGS_SERIALIZED.get(1)), + new ConsumerRecord<>(TOPIC, 0, 2, 0L, TimestampType.CREATE_TIME, 0, 0, TASK_CONFIG_KEYS.get(1), + CONFIGS_SERIALIZED.get(2), new RecordHeaders(), Optional.empty()), + new ConsumerRecord<>(TOPIC, 0, 4, 0L, TimestampType.CREATE_TIME, 0, 0, COMMIT_TASKS_CONFIG_KEYS.get(0), + CONFIGS_SERIALIZED.get(4), new RecordHeaders(), Optional.empty()), + new ConsumerRecord<>(TOPIC, 0, 5, 0L, TimestampType.CREATE_TIME, 0, 0, TASK_CONFIG_KEYS.get(0), + CONFIGS_SERIALIZED.get(5), new RecordHeaders(), Optional.empty())); + LinkedHashMap deserialized = new LinkedHashMap<>(); + deserialized.put(CONFIGS_SERIALIZED.get(0), CONNECTOR_CONFIG_STRUCTS.get(0)); + deserialized.put(CONFIGS_SERIALIZED.get(2), TASK_CONFIG_STRUCTS.get(0)); + deserialized.put(CONFIGS_SERIALIZED.get(4), TASKS_COMMIT_STRUCT_TWO_TASK_CONNECTOR); + deserialized.put(CONFIGS_SERIALIZED.get(5), TASK_CONFIG_STRUCTS.get(1)); + logOffset = 6; + expectStart(existingRecords, deserialized); + expectPartitionCount(1); + + // Successful attempt to write new task config + expectReadToEnd(new LinkedHashMap<>()); + expectConvertWriteRead( + TASK_CONFIG_KEYS.get(0), KafkaConfigBackingStore.TASK_CONFIGURATION_V0, CONFIGS_SERIALIZED.get(0), + "properties", SAMPLE_CONFIGS.get(0)); + expectReadToEnd(new LinkedHashMap<>()); + expectConvertWriteRead( + COMMIT_TASKS_CONFIG_KEYS.get(0), KafkaConfigBackingStore.CONNECTOR_TASKS_COMMIT_V0, CONFIGS_SERIALIZED.get(2), + "tasks", 1); // Updated to just 1 task + // As soon as root is rewritten, we should see a callback notifying us that we reconfigured some tasks + configUpdateListener.onTaskConfigUpdate(Arrays.asList(TASK_IDS.get(0))); + EasyMock.expectLastCall(); + // Records to be read by consumer as it reads to the end of the log + LinkedHashMap serializedConfigs = new LinkedHashMap<>(); + serializedConfigs.put(TASK_CONFIG_KEYS.get(0), CONFIGS_SERIALIZED.get(0)); + serializedConfigs.put(COMMIT_TASKS_CONFIG_KEYS.get(0), CONFIGS_SERIALIZED.get(2)); + expectReadToEnd(serializedConfigs); + + expectStop(); + + PowerMock.replayAll(); + + configStorage.setupAndCreateKafkaBasedLog(TOPIC, DEFAULT_DISTRIBUTED_CONFIG); + configStorage.start(); + // After reading the log, it should have been in an inconsistent state + ClusterConfigState configState = 
configStorage.snapshot(); + assertEquals(6, configState.offset()); // Should always be next to be read, not last committed + assertEquals(Arrays.asList(CONNECTOR_IDS.get(0)), new ArrayList<>(configState.connectors())); + // Inconsistent data should leave us with no tasks listed for the connector and an entry in the inconsistent list + assertEquals(Collections.emptyList(), configState.tasks(CONNECTOR_IDS.get(0))); + // Both TASK_CONFIG_STRUCTS[0] -> SAMPLE_CONFIGS[0] + assertNull(configState.taskConfig(TASK_IDS.get(0))); + assertNull(configState.taskConfig(TASK_IDS.get(1))); + assertEquals(Collections.singleton(CONNECTOR_IDS.get(0)), configState.inconsistentConnectors()); + + // Next, issue a write that has everything that is needed and it should be accepted. Note that in this case + // we are going to shrink the number of tasks to 1 + configStorage.putTaskConfigs("connector1", Collections.singletonList(SAMPLE_CONFIGS.get(0))); + // Validate updated config + configState = configStorage.snapshot(); + // This is only two more ahead of the last one because multiple calls fail, and so their configs are not written + // to the topic. Only the last call with 1 task config + 1 commit actually gets written. + assertEquals(8, configState.offset()); + assertEquals(Arrays.asList(CONNECTOR_IDS.get(0)), new ArrayList<>(configState.connectors())); + assertEquals(Arrays.asList(TASK_IDS.get(0)), configState.tasks(CONNECTOR_IDS.get(0))); + assertEquals(SAMPLE_CONFIGS.get(0), configState.taskConfig(TASK_IDS.get(0))); + assertEquals(Collections.EMPTY_SET, configState.inconsistentConnectors()); + + configStorage.stop(); + + PowerMock.verifyAll(); + } + + @Test + public void testPutRestartRequestOnlyFailed() throws Exception { + RestartRequest restartRequest = new RestartRequest(CONNECTOR_IDS.get(0), true, false); + testPutRestartRequest(restartRequest); + } + + @Test + public void testPutRestartRequestOnlyFailedIncludingTasks() throws Exception { + RestartRequest restartRequest = new RestartRequest(CONNECTOR_IDS.get(0), true, true); + testPutRestartRequest(restartRequest); + } + + private void testPutRestartRequest(RestartRequest restartRequest) throws Exception { + expectConfigure(); + expectStart(Collections.emptyList(), Collections.emptyMap()); + + expectConvertWriteAndRead( + RESTART_CONNECTOR_KEYS.get(0), KafkaConfigBackingStore.RESTART_REQUEST_V0, CONFIGS_SERIALIZED.get(0), + ONLY_FAILED_FIELD_NAME, restartRequest.onlyFailed()); + final Capture capturedRestartRequest = EasyMock.newCapture(); + configUpdateListener.onRestartRequest(EasyMock.capture(capturedRestartRequest)); + EasyMock.expectLastCall(); + + expectPartitionCount(1); + expectStop(); + + PowerMock.replayAll(); + + configStorage.setupAndCreateKafkaBasedLog(TOPIC, DEFAULT_DISTRIBUTED_CONFIG); + configStorage.start(); + + // Writing should block until it is written and read back from Kafka + configStorage.putRestartRequest(restartRequest); + + assertEquals(restartRequest.connectorName(), capturedRestartRequest.getValue().connectorName()); + assertEquals(restartRequest.onlyFailed(), capturedRestartRequest.getValue().onlyFailed()); + assertEquals(restartRequest.includeTasks(), capturedRestartRequest.getValue().includeTasks()); + + configStorage.stop(); + + PowerMock.verifyAll(); + } + + @Test + public void testRecordToRestartRequest() throws Exception { + ConsumerRecord record = new ConsumerRecord<>(TOPIC, 0, 0, 0L, TimestampType.CREATE_TIME, 0, 0, RESTART_CONNECTOR_KEYS.get(0), + CONFIGS_SERIALIZED.get(0), new RecordHeaders(), Optional.empty()); 
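+        // The record key encodes the connector name (RESTART_KEY) while the value struct carries the
+        // only-failed / include-tasks flags; recordToRestartRequest must recover all three.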
+        Struct struct = RESTART_REQUEST_STRUCTS.get(0);
+        SchemaAndValue schemaAndValue = new SchemaAndValue(struct.schema(), structToMap(struct));
+        RestartRequest restartRequest = configStorage.recordToRestartRequest(record, schemaAndValue);
+        assertEquals(CONNECTOR_1_NAME, restartRequest.connectorName());
+        assertEquals(struct.getBoolean(INCLUDE_TASKS_FIELD_NAME), restartRequest.includeTasks());
+        assertEquals(struct.getBoolean(ONLY_FAILED_FIELD_NAME), restartRequest.onlyFailed());
+    }
+
+    @Test
+    public void testRecordToRestartRequestOnlyFailedInconsistent() throws Exception {
+        ConsumerRecord<String, byte[]> record = new ConsumerRecord<>(TOPIC, 0, 0, 0L, TimestampType.CREATE_TIME, 0, 0, RESTART_CONNECTOR_KEYS.get(0),
+                CONFIGS_SERIALIZED.get(0), new RecordHeaders(), Optional.empty());
+        Struct struct = ONLY_FAILED_MISSING_STRUCT;
+        SchemaAndValue schemaAndValue = new SchemaAndValue(struct.schema(), structToMap(struct));
+        RestartRequest restartRequest = configStorage.recordToRestartRequest(record, schemaAndValue);
+        assertEquals(CONNECTOR_1_NAME, restartRequest.connectorName());
+        assertEquals(struct.getBoolean(INCLUDE_TASKS_FIELD_NAME), restartRequest.includeTasks());
+        assertFalse(restartRequest.onlyFailed());
+    }
+
+    @Test
+    public void testRecordToRestartRequestIncludeTasksInconsistent() throws Exception {
+        ConsumerRecord<String, byte[]> record = new ConsumerRecord<>(TOPIC, 0, 0, 0L, TimestampType.CREATE_TIME, 0, 0, RESTART_CONNECTOR_KEYS.get(0),
+                CONFIGS_SERIALIZED.get(0), new RecordHeaders(), Optional.empty());
+        Struct struct = INCLUDE_TASKS_MISSING_STRUCT;
+        SchemaAndValue schemaAndValue = new SchemaAndValue(struct.schema(), structToMap(struct));
+        RestartRequest restartRequest = configStorage.recordToRestartRequest(record, schemaAndValue);
+        assertEquals(CONNECTOR_1_NAME, restartRequest.connectorName());
+        assertFalse(restartRequest.includeTasks());
+        assertEquals(struct.getBoolean(ONLY_FAILED_FIELD_NAME), restartRequest.onlyFailed());
+    }
+
+    @Test
+    public void testRestoreRestartRequestInconsistentState() throws Exception {
+        // Restoring data should notify only of the latest values after loading is complete. This also validates
+        // that inconsistent state doesn't prevent startup.
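+        // Here "inconsistent" covers both structs with missing fields and a null (tombstoned) value.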
+ + expectConfigure(); + // Overwrite each type at least once to ensure we see the latest data after loading + List> existingRecords = Arrays.asList( + new ConsumerRecord<>(TOPIC, 0, 0, 0L, TimestampType.CREATE_TIME, 0, 0, RESTART_CONNECTOR_KEYS.get(0), + CONFIGS_SERIALIZED.get(0), new RecordHeaders(), Optional.empty()), + new ConsumerRecord<>(TOPIC, 0, 1, 0L, TimestampType.CREATE_TIME, 0, 0, RESTART_CONNECTOR_KEYS.get(1), + CONFIGS_SERIALIZED.get(1), new RecordHeaders(), Optional.empty()), + new ConsumerRecord<>(TOPIC, 0, 2, 0L, TimestampType.CREATE_TIME, 0, 0, RESTART_CONNECTOR_KEYS.get(1), + CONFIGS_SERIALIZED.get(2), new RecordHeaders(), Optional.empty()), + new ConsumerRecord<>(TOPIC, 0, 3, 0L, TimestampType.CREATE_TIME, 0, 0, RESTART_CONNECTOR_KEYS.get(1), + CONFIGS_SERIALIZED.get(3), new RecordHeaders(), Optional.empty())); + LinkedHashMap deserialized = new LinkedHashMap<>(); + deserialized.put(CONFIGS_SERIALIZED.get(0), RESTART_REQUEST_STRUCTS.get(0)); + deserialized.put(CONFIGS_SERIALIZED.get(1), RESTART_REQUEST_STRUCTS.get(1)); + deserialized.put(CONFIGS_SERIALIZED.get(2), RESTART_REQUEST_STRUCTS.get(2)); + deserialized.put(CONFIGS_SERIALIZED.get(3), null); + logOffset = 4; + expectStart(existingRecords, deserialized); + expectPartitionCount(1); + + // Shouldn't see any callbacks since this is during startup + + expectStop(); + + PowerMock.replayAll(); + + configStorage.setupAndCreateKafkaBasedLog(TOPIC, DEFAULT_DISTRIBUTED_CONFIG); + configStorage.start(); + + configStorage.stop(); + + PowerMock.verifyAll(); + } + + @Test + public void testExceptionOnStartWhenConfigTopicHasMultiplePartitions() throws Exception { + expectConfigure(); + expectStart(Collections.emptyList(), Collections.emptyMap()); + + expectPartitionCount(2); + + PowerMock.replayAll(); + + configStorage.setupAndCreateKafkaBasedLog(TOPIC, DEFAULT_DISTRIBUTED_CONFIG); + ConfigException e = assertThrows(ConfigException.class, () -> configStorage.start()); + assertTrue(e.getMessage().contains("required to have a single partition")); + + PowerMock.verifyAll(); + } + + private void expectConfigure() throws Exception { + PowerMock.expectPrivate(configStorage, "createKafkaBasedLog", + EasyMock.capture(capturedTopic), EasyMock.capture(capturedProducerProps), + EasyMock.capture(capturedConsumerProps), EasyMock.capture(capturedConsumedCallback), + EasyMock.capture(capturedNewTopic), EasyMock.capture(capturedAdminSupplier)) + .andReturn(storeLog); + } + + private void expectPartitionCount(int partitionCount) { + EasyMock.expect(storeLog.partitionCount()) + .andReturn(partitionCount); + } + + // If non-empty, deserializations should be a LinkedHashMap + private void expectStart(final List> preexistingRecords, + final Map deserializations) { + storeLog.start(); + PowerMock.expectLastCall().andAnswer(() -> { + for (ConsumerRecord rec : preexistingRecords) + capturedConsumedCallback.getValue().onCompletion(null, rec); + return null; + }); + for (Map.Entry deserializationEntry : deserializations.entrySet()) { + // Note null schema because default settings for internal serialization are schema-less + EasyMock.expect(converter.toConnectData(EasyMock.eq(TOPIC), EasyMock.aryEq(deserializationEntry.getKey()))) + .andReturn(new SchemaAndValue(null, structToMap(deserializationEntry.getValue()))); + } + } + + private void expectStop() { + storeLog.stop(); + PowerMock.expectLastCall(); + } + + private void expectRead(LinkedHashMap serializedValues, + Map deserializedValues) { + expectReadToEnd(serializedValues); + for (Map.Entry 
deserializedValueEntry : deserializedValues.entrySet()) { + byte[] serializedValue = serializedValues.get(deserializedValueEntry.getKey()); + EasyMock.expect(converter.toConnectData(EasyMock.eq(TOPIC), EasyMock.aryEq(serializedValue))) + .andReturn(new SchemaAndValue(null, structToMap(deserializedValueEntry.getValue()))); + } + } + + private void expectRead(final String key, final byte[] serializedValue, Struct deserializedValue) { + LinkedHashMap serializedData = new LinkedHashMap<>(); + serializedData.put(key, serializedValue); + expectRead(serializedData, Collections.singletonMap(key, deserializedValue)); + } + + // Expect a conversion & write to the underlying log, followed by a subsequent read when the data is consumed back + // from the log. Validate the data that is captured when the conversion is performed matches the specified data + // (by checking a single field's value) + private void expectConvertWriteRead(final String configKey, final Schema valueSchema, final byte[] serialized, + final String dataFieldName, final Object dataFieldValue) { + final Capture capturedRecord = EasyMock.newCapture(); + if (serialized != null) + EasyMock.expect(converter.fromConnectData(EasyMock.eq(TOPIC), EasyMock.eq(valueSchema), EasyMock.capture(capturedRecord))) + .andReturn(serialized); + storeLog.send(EasyMock.eq(configKey), EasyMock.aryEq(serialized)); + PowerMock.expectLastCall(); + EasyMock.expect(converter.toConnectData(EasyMock.eq(TOPIC), EasyMock.aryEq(serialized))) + .andAnswer(() -> { + if (dataFieldName != null) + assertEquals(dataFieldValue, capturedRecord.getValue().get(dataFieldName)); + // Note null schema because default settings for internal serialization are schema-less + return new SchemaAndValue(null, serialized == null ? null : structToMap(capturedRecord.getValue())); + }); + } + + // This map needs to maintain ordering + private void expectReadToEnd(final LinkedHashMap serializedConfigs) { + EasyMock.expect(storeLog.readToEnd()) + .andAnswer(() -> { + TestFuture future = new TestFuture<>(); + for (Map.Entry entry : serializedConfigs.entrySet()) { + capturedConsumedCallback.getValue().onCompletion(null, + new ConsumerRecord<>(TOPIC, 0, logOffset++, 0L, TimestampType.CREATE_TIME, 0, 0, + entry.getKey(), entry.getValue(), new RecordHeaders(), Optional.empty())); + } + future.resolveOnGet((Void) null); + return future; + }); + } + + private void expectConnectorRemoval(String configKey, String targetStateKey) { + expectConvertWriteRead(configKey, KafkaConfigBackingStore.CONNECTOR_CONFIGURATION_V0, null, null, null); + expectConvertWriteRead(targetStateKey, KafkaConfigBackingStore.TARGET_STATE_V0, null, null, null); + + LinkedHashMap recordsToRead = new LinkedHashMap<>(); + recordsToRead.put(configKey, null); + recordsToRead.put(targetStateKey, null); + expectReadToEnd(recordsToRead); + } + + private void expectConvertWriteAndRead(final String configKey, final Schema valueSchema, final byte[] serialized, + final String dataFieldName, final Object dataFieldValue) { + expectConvertWriteRead(configKey, valueSchema, serialized, dataFieldName, dataFieldValue); + LinkedHashMap recordsToRead = new LinkedHashMap<>(); + recordsToRead.put(configKey, serialized); + expectReadToEnd(recordsToRead); + } + + // Manually insert a connector into config storage, updating the task configs, connector config, and root config + private void whiteboxAddConnector(String connectorName, Map connectorConfig, List> taskConfigs) { + Map> storageTaskConfigs = Whitebox.getInternalState(configStorage, 
"taskConfigs"); + for (int i = 0; i < taskConfigs.size(); i++) + storageTaskConfigs.put(new ConnectorTaskId(connectorName, i), taskConfigs.get(i)); + + Map> connectorConfigs = Whitebox.getInternalState(configStorage, "connectorConfigs"); + connectorConfigs.put(connectorName, connectorConfig); + + Whitebox.>getInternalState(configStorage, "connectorTaskCounts").put(connectorName, taskConfigs.size()); + } + + // Generates a Map representation of Struct. Only does shallow traversal, so nested structs are not converted + private Map structToMap(Struct struct) { + if (struct == null) + return null; + + HashMap result = new HashMap<>(); + for (Field field : struct.schema().fields()) + result.put(field.name(), struct.get(field)); + return result; + } +} diff --git a/connect/runtime/src/test/java/org/apache/kafka/connect/storage/KafkaOffsetBackingStoreTest.java b/connect/runtime/src/test/java/org/apache/kafka/connect/storage/KafkaOffsetBackingStoreTest.java new file mode 100644 index 0000000..cdce2a1 --- /dev/null +++ b/connect/runtime/src/test/java/org/apache/kafka/connect/storage/KafkaOffsetBackingStoreTest.java @@ -0,0 +1,418 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.connect.storage; + +import org.apache.kafka.clients.CommonClientConfigs; +import org.apache.kafka.clients.admin.NewTopic; +import org.apache.kafka.clients.consumer.ConsumerConfig; +import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.clients.producer.ProducerConfig; +import org.apache.kafka.common.KafkaException; +import org.apache.kafka.common.header.internals.RecordHeaders; +import org.apache.kafka.common.record.TimestampType; +import org.apache.kafka.connect.runtime.distributed.DistributedConfig; +import org.apache.kafka.connect.util.Callback; +import org.apache.kafka.connect.util.ConnectUtils; +import org.apache.kafka.connect.util.KafkaBasedLog; +import org.apache.kafka.connect.util.TopicAdmin; +import org.easymock.Capture; +import org.easymock.EasyMock; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.powermock.api.easymock.PowerMock; +import org.powermock.api.easymock.annotation.Mock; +import org.powermock.core.classloader.annotations.PowerMockIgnore; +import org.powermock.core.classloader.annotations.PrepareForTest; +import org.powermock.modules.junit4.PowerMockRunner; +import org.powermock.reflect.Whitebox; + +import java.nio.ByteBuffer; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.Supplier; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +@RunWith(PowerMockRunner.class) +@PrepareForTest({KafkaOffsetBackingStore.class, ConnectUtils.class}) +@PowerMockIgnore({"javax.management.*", "javax.crypto.*"}) +@SuppressWarnings({"unchecked", "deprecation"}) +public class KafkaOffsetBackingStoreTest { + private static final String TOPIC = "connect-offsets"; + private static final short TOPIC_PARTITIONS = 2; + private static final short TOPIC_REPLICATION_FACTOR = 5; + private static final Map DEFAULT_PROPS = new HashMap<>(); + private static final DistributedConfig DEFAULT_DISTRIBUTED_CONFIG; + static { + DEFAULT_PROPS.put(CommonClientConfigs.BOOTSTRAP_SERVERS_CONFIG, "broker1:9092,broker2:9093"); + DEFAULT_PROPS.put(DistributedConfig.OFFSET_STORAGE_TOPIC_CONFIG, TOPIC); + DEFAULT_PROPS.put(DistributedConfig.OFFSET_STORAGE_REPLICATION_FACTOR_CONFIG, Short.toString(TOPIC_REPLICATION_FACTOR)); + DEFAULT_PROPS.put(DistributedConfig.OFFSET_STORAGE_PARTITIONS_CONFIG, Integer.toString(TOPIC_PARTITIONS)); + DEFAULT_PROPS.put(DistributedConfig.CONFIG_TOPIC_CONFIG, "connect-configs"); + DEFAULT_PROPS.put(DistributedConfig.CONFIG_STORAGE_REPLICATION_FACTOR_CONFIG, Short.toString(TOPIC_REPLICATION_FACTOR)); + DEFAULT_PROPS.put(DistributedConfig.GROUP_ID_CONFIG, "connect"); + DEFAULT_PROPS.put(DistributedConfig.STATUS_STORAGE_TOPIC_CONFIG, "status-topic"); + DEFAULT_PROPS.put(DistributedConfig.KEY_CONVERTER_CLASS_CONFIG, "org.apache.kafka.connect.json.JsonConverter"); + DEFAULT_PROPS.put(DistributedConfig.VALUE_CONVERTER_CLASS_CONFIG, "org.apache.kafka.connect.json.JsonConverter"); + DEFAULT_DISTRIBUTED_CONFIG = new DistributedConfig(DEFAULT_PROPS); + + } + private static 
final Map FIRST_SET = new HashMap<>(); + static { + FIRST_SET.put(buffer("key"), buffer("value")); + FIRST_SET.put(null, null); + } + + private static final ByteBuffer TP0_KEY = buffer("TP0KEY"); + private static final ByteBuffer TP1_KEY = buffer("TP1KEY"); + private static final ByteBuffer TP2_KEY = buffer("TP2KEY"); + private static final ByteBuffer TP0_VALUE = buffer("VAL0"); + private static final ByteBuffer TP1_VALUE = buffer("VAL1"); + private static final ByteBuffer TP2_VALUE = buffer("VAL2"); + private static final ByteBuffer TP0_VALUE_NEW = buffer("VAL0_NEW"); + private static final ByteBuffer TP1_VALUE_NEW = buffer("VAL1_NEW"); + + @Mock + KafkaBasedLog storeLog; + private KafkaOffsetBackingStore store; + + private Capture capturedTopic = EasyMock.newCapture(); + private Capture> capturedProducerProps = EasyMock.newCapture(); + private Capture> capturedConsumerProps = EasyMock.newCapture(); + private Capture> capturedAdminSupplier = EasyMock.newCapture(); + private Capture capturedNewTopic = EasyMock.newCapture(); + private Capture>> capturedConsumedCallback = EasyMock.newCapture(); + + @Before + public void setUp() throws Exception { + store = PowerMock.createPartialMockAndInvokeDefaultConstructor(KafkaOffsetBackingStore.class, "createKafkaBasedLog"); + } + + @Test + public void testStartStop() throws Exception { + expectConfigure(); + expectStart(Collections.emptyList()); + expectStop(); + expectClusterId(); + + PowerMock.replayAll(); + + Map settings = new HashMap<>(DEFAULT_PROPS); + settings.put("offset.storage.min.insync.replicas", "3"); + settings.put("offset.storage.max.message.bytes", "1001"); + store.configure(new DistributedConfig(settings)); + assertEquals(TOPIC, capturedTopic.getValue()); + assertEquals("org.apache.kafka.common.serialization.ByteArraySerializer", capturedProducerProps.getValue().get(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG)); + assertEquals("org.apache.kafka.common.serialization.ByteArraySerializer", capturedProducerProps.getValue().get(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG)); + assertEquals("org.apache.kafka.common.serialization.ByteArrayDeserializer", capturedConsumerProps.getValue().get(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG)); + assertEquals("org.apache.kafka.common.serialization.ByteArrayDeserializer", capturedConsumerProps.getValue().get(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG)); + + assertEquals(TOPIC, capturedNewTopic.getValue().name()); + assertEquals(TOPIC_PARTITIONS, capturedNewTopic.getValue().numPartitions()); + assertEquals(TOPIC_REPLICATION_FACTOR, capturedNewTopic.getValue().replicationFactor()); + + store.start(); + store.stop(); + + PowerMock.verifyAll(); + } + + @Test + public void testReloadOnStart() throws Exception { + expectConfigure(); + expectStart(Arrays.asList( + new ConsumerRecord<>(TOPIC, 0, 0, 0L, TimestampType.CREATE_TIME, 0, 0, TP0_KEY.array(), TP0_VALUE.array(), + new RecordHeaders(), Optional.empty()), + new ConsumerRecord<>(TOPIC, 1, 0, 0L, TimestampType.CREATE_TIME, 0, 0, TP1_KEY.array(), TP1_VALUE.array(), + new RecordHeaders(), Optional.empty()), + new ConsumerRecord<>(TOPIC, 0, 1, 0L, TimestampType.CREATE_TIME, 0, 0, TP0_KEY.array(), TP0_VALUE_NEW.array(), + new RecordHeaders(), Optional.empty()), + new ConsumerRecord<>(TOPIC, 1, 1, 0L, TimestampType.CREATE_TIME, 0, 0, TP1_KEY.array(), TP1_VALUE_NEW.array(), + new RecordHeaders(), Optional.empty()) + )); + expectStop(); + expectClusterId(); + + PowerMock.replayAll(); + + store.configure(DEFAULT_DISTRIBUTED_CONFIG); + store.start(); + HashMap 
data = Whitebox.getInternalState(store, "data"); + assertEquals(TP0_VALUE_NEW, data.get(TP0_KEY)); + assertEquals(TP1_VALUE_NEW, data.get(TP1_KEY)); + + store.stop(); + + PowerMock.verifyAll(); + } + + @Test + public void testGetSet() throws Exception { + expectConfigure(); + expectStart(Collections.emptyList()); + expectStop(); + + // First get() against an empty store + final Capture> firstGetReadToEndCallback = EasyMock.newCapture(); + storeLog.readToEnd(EasyMock.capture(firstGetReadToEndCallback)); + PowerMock.expectLastCall().andAnswer(() -> { + firstGetReadToEndCallback.getValue().onCompletion(null, null); + return null; + }); + + // Set offsets + Capture callback0 = EasyMock.newCapture(); + storeLog.send(EasyMock.aryEq(TP0_KEY.array()), EasyMock.aryEq(TP0_VALUE.array()), EasyMock.capture(callback0)); + PowerMock.expectLastCall(); + Capture callback1 = EasyMock.newCapture(); + storeLog.send(EasyMock.aryEq(TP1_KEY.array()), EasyMock.aryEq(TP1_VALUE.array()), EasyMock.capture(callback1)); + PowerMock.expectLastCall(); + + // Second get() should get the produced data and return the new values + final Capture> secondGetReadToEndCallback = EasyMock.newCapture(); + storeLog.readToEnd(EasyMock.capture(secondGetReadToEndCallback)); + PowerMock.expectLastCall().andAnswer(() -> { + capturedConsumedCallback.getValue().onCompletion(null, + new ConsumerRecord<>(TOPIC, 0, 0, 0L, TimestampType.CREATE_TIME, 0, 0, TP0_KEY.array(), TP0_VALUE.array(), + new RecordHeaders(), Optional.empty())); + capturedConsumedCallback.getValue().onCompletion(null, + new ConsumerRecord<>(TOPIC, 1, 0, 0L, TimestampType.CREATE_TIME, 0, 0, TP1_KEY.array(), TP1_VALUE.array(), + new RecordHeaders(), Optional.empty())); + secondGetReadToEndCallback.getValue().onCompletion(null, null); + return null; + }); + + // Third get() should pick up data produced by someone else and return those values + final Capture> thirdGetReadToEndCallback = EasyMock.newCapture(); + storeLog.readToEnd(EasyMock.capture(thirdGetReadToEndCallback)); + PowerMock.expectLastCall().andAnswer(() -> { + capturedConsumedCallback.getValue().onCompletion(null, + new ConsumerRecord<>(TOPIC, 0, 1, 0L, TimestampType.CREATE_TIME, 0, 0, TP0_KEY.array(), TP0_VALUE_NEW.array(), + new RecordHeaders(), Optional.empty())); + capturedConsumedCallback.getValue().onCompletion(null, + new ConsumerRecord<>(TOPIC, 1, 1, 0L, TimestampType.CREATE_TIME, 0, 0, TP1_KEY.array(), TP1_VALUE_NEW.array(), + new RecordHeaders(), Optional.empty())); + thirdGetReadToEndCallback.getValue().onCompletion(null, null); + return null; + }); + + expectClusterId(); + PowerMock.replayAll(); + + store.configure(DEFAULT_DISTRIBUTED_CONFIG); + store.start(); + + // Getting from empty store should return nulls + Map offsets = store.get(Arrays.asList(TP0_KEY, TP1_KEY)).get(10000, TimeUnit.MILLISECONDS); + // Since we didn't read them yet, these will be null + assertNull(offsets.get(TP0_KEY)); + assertNull(offsets.get(TP1_KEY)); + + // Set some offsets + Map toSet = new HashMap<>(); + toSet.put(TP0_KEY, TP0_VALUE); + toSet.put(TP1_KEY, TP1_VALUE); + final AtomicBoolean invoked = new AtomicBoolean(false); + Future setFuture = store.set(toSet, (error, result) -> invoked.set(true)); + assertFalse(setFuture.isDone()); + // Out of order callbacks shouldn't matter, should still require all to be invoked before invoking the callback + // for the store's set callback + callback1.getValue().onCompletion(null, null); + assertFalse(invoked.get()); + callback0.getValue().onCompletion(null, null); + 
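+        // Both per-record send callbacks have now completed (albeit out of order), so the
+        // aggregate future returned by set() should finish and the store-level callback should fire.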
setFuture.get(10000, TimeUnit.MILLISECONDS); + assertTrue(invoked.get()); + + // Getting data should read to end of our published data and return it + offsets = store.get(Arrays.asList(TP0_KEY, TP1_KEY)).get(10000, TimeUnit.MILLISECONDS); + assertEquals(TP0_VALUE, offsets.get(TP0_KEY)); + assertEquals(TP1_VALUE, offsets.get(TP1_KEY)); + + // Getting data should read to end of our published data and return it + offsets = store.get(Arrays.asList(TP0_KEY, TP1_KEY)).get(10000, TimeUnit.MILLISECONDS); + assertEquals(TP0_VALUE_NEW, offsets.get(TP0_KEY)); + assertEquals(TP1_VALUE_NEW, offsets.get(TP1_KEY)); + + store.stop(); + + PowerMock.verifyAll(); + } + + @Test + public void testGetSetNull() throws Exception { + expectConfigure(); + expectStart(Collections.emptyList()); + + // Set offsets + Capture callback0 = EasyMock.newCapture(); + storeLog.send(EasyMock.isNull(byte[].class), EasyMock.aryEq(TP0_VALUE.array()), EasyMock.capture(callback0)); + PowerMock.expectLastCall(); + Capture callback1 = EasyMock.newCapture(); + storeLog.send(EasyMock.aryEq(TP1_KEY.array()), EasyMock.isNull(byte[].class), EasyMock.capture(callback1)); + PowerMock.expectLastCall(); + + // Second get() should get the produced data and return the new values + final Capture> secondGetReadToEndCallback = EasyMock.newCapture(); + storeLog.readToEnd(EasyMock.capture(secondGetReadToEndCallback)); + PowerMock.expectLastCall().andAnswer(() -> { + capturedConsumedCallback.getValue().onCompletion(null, + new ConsumerRecord<>(TOPIC, 0, 0, 0L, TimestampType.CREATE_TIME, 0, 0, null, TP0_VALUE.array(), + new RecordHeaders(), Optional.empty())); + capturedConsumedCallback.getValue().onCompletion(null, + new ConsumerRecord<>(TOPIC, 1, 0, 0L, TimestampType.CREATE_TIME, 0, 0, TP1_KEY.array(), null, + new RecordHeaders(), Optional.empty())); + secondGetReadToEndCallback.getValue().onCompletion(null, null); + return null; + }); + + expectStop(); + expectClusterId(); + + PowerMock.replayAll(); + + store.configure(DEFAULT_DISTRIBUTED_CONFIG); + store.start(); + + // Set offsets using null keys and values + Map toSet = new HashMap<>(); + toSet.put(null, TP0_VALUE); + toSet.put(TP1_KEY, null); + final AtomicBoolean invoked = new AtomicBoolean(false); + Future setFuture = store.set(toSet, (error, result) -> invoked.set(true)); + assertFalse(setFuture.isDone()); + // Out of order callbacks shouldn't matter, should still require all to be invoked before invoking the callback + // for the store's set callback + callback1.getValue().onCompletion(null, null); + assertFalse(invoked.get()); + callback0.getValue().onCompletion(null, null); + setFuture.get(10000, TimeUnit.MILLISECONDS); + assertTrue(invoked.get()); + + // Getting data should read to end of our published data and return it + Map offsets = store.get(Arrays.asList(null, TP1_KEY)).get(10000, TimeUnit.MILLISECONDS); + assertEquals(TP0_VALUE, offsets.get(null)); + assertNull(offsets.get(TP1_KEY)); + + store.stop(); + + PowerMock.verifyAll(); + } + + @Test + public void testSetFailure() throws Exception { + expectConfigure(); + expectStart(Collections.emptyList()); + expectStop(); + + // Set offsets + Capture callback0 = EasyMock.newCapture(); + storeLog.send(EasyMock.aryEq(TP0_KEY.array()), EasyMock.aryEq(TP0_VALUE.array()), EasyMock.capture(callback0)); + PowerMock.expectLastCall(); + Capture callback1 = EasyMock.newCapture(); + storeLog.send(EasyMock.aryEq(TP1_KEY.array()), EasyMock.aryEq(TP1_VALUE.array()), EasyMock.capture(callback1)); + PowerMock.expectLastCall(); + Capture callback2 = 
EasyMock.newCapture(); + storeLog.send(EasyMock.aryEq(TP2_KEY.array()), EasyMock.aryEq(TP2_VALUE.array()), EasyMock.capture(callback2)); + PowerMock.expectLastCall(); + + expectClusterId(); + + PowerMock.replayAll(); + + store.configure(DEFAULT_DISTRIBUTED_CONFIG); + store.start(); + + // Set some offsets + Map toSet = new HashMap<>(); + toSet.put(TP0_KEY, TP0_VALUE); + toSet.put(TP1_KEY, TP1_VALUE); + toSet.put(TP2_KEY, TP2_VALUE); + final AtomicBoolean invoked = new AtomicBoolean(false); + final AtomicBoolean invokedFailure = new AtomicBoolean(false); + Future setFuture = store.set(toSet, (error, result) -> { + invoked.set(true); + if (error != null) + invokedFailure.set(true); + }); + assertFalse(setFuture.isDone()); + // Out of order callbacks shouldn't matter, should still require all to be invoked before invoking the callback + // for the store's set callback + callback1.getValue().onCompletion(null, null); + assertFalse(invoked.get()); + callback2.getValue().onCompletion(null, new KafkaException("bogus error")); + assertTrue(invoked.get()); + assertTrue(invokedFailure.get()); + callback0.getValue().onCompletion(null, null); + try { + setFuture.get(10000, TimeUnit.MILLISECONDS); + fail("Should have seen KafkaException thrown when waiting on KafkaOffsetBackingStore.set() future"); + } catch (ExecutionException e) { + // expected + assertNotNull(e.getCause()); + assertTrue(e.getCause() instanceof KafkaException); + } + + store.stop(); + + PowerMock.verifyAll(); + } + + private void expectConfigure() throws Exception { + PowerMock.expectPrivate(store, "createKafkaBasedLog", EasyMock.capture(capturedTopic), EasyMock.capture(capturedProducerProps), + EasyMock.capture(capturedConsumerProps), EasyMock.capture(capturedConsumedCallback), + EasyMock.capture(capturedNewTopic), EasyMock.capture(capturedAdminSupplier)) + .andReturn(storeLog); + } + + private void expectStart(final List> preexistingRecords) throws Exception { + storeLog.start(); + PowerMock.expectLastCall().andAnswer(() -> { + for (ConsumerRecord rec : preexistingRecords) + capturedConsumedCallback.getValue().onCompletion(null, rec); + return null; + }); + } + + private void expectStop() { + storeLog.stop(); + PowerMock.expectLastCall(); + } + + private void expectClusterId() { + PowerMock.mockStaticPartial(ConnectUtils.class, "lookupKafkaClusterId"); + EasyMock.expect(ConnectUtils.lookupKafkaClusterId(EasyMock.anyObject())).andReturn("test-cluster").anyTimes(); + } + + private static ByteBuffer buffer(String v) { + return ByteBuffer.wrap(v.getBytes()); + } + +} diff --git a/connect/runtime/src/test/java/org/apache/kafka/connect/storage/KafkaStatusBackingStoreFormatTest.java b/connect/runtime/src/test/java/org/apache/kafka/connect/storage/KafkaStatusBackingStoreFormatTest.java new file mode 100644 index 0000000..ec8ea85 --- /dev/null +++ b/connect/runtime/src/test/java/org/apache/kafka/connect/storage/KafkaStatusBackingStoreFormatTest.java @@ -0,0 +1,303 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.storage; + +import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.clients.producer.Callback; +import org.apache.kafka.common.errors.TimeoutException; +import org.apache.kafka.common.errors.UnknownServerException; +import org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.connect.json.JsonConverter; +import org.apache.kafka.connect.runtime.TopicStatus; +import org.apache.kafka.connect.util.ConnectorTaskId; +import org.apache.kafka.connect.util.KafkaBasedLog; +import org.easymock.Capture; +import org.easymock.EasyMockSupport; +import org.easymock.Mock; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.powermock.modules.junit4.PowerMockRunner; + +import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; +import java.util.concurrent.ConcurrentHashMap; + +import static org.apache.kafka.connect.json.JsonConverterConfig.SCHEMAS_ENABLE_CONFIG; +import static org.apache.kafka.connect.storage.KafkaStatusBackingStore.CONNECTOR_STATUS_PREFIX; +import static org.apache.kafka.connect.storage.KafkaStatusBackingStore.TASK_STATUS_PREFIX; +import static org.apache.kafka.connect.storage.KafkaStatusBackingStore.TOPIC_STATUS_PREFIX; +import static org.apache.kafka.connect.storage.KafkaStatusBackingStore.TOPIC_STATUS_SEPARATOR; +import static org.easymock.EasyMock.capture; +import static org.easymock.EasyMock.eq; +import static org.easymock.EasyMock.expectLastCall; +import static org.easymock.EasyMock.newCapture; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; + +@RunWith(PowerMockRunner.class) +public class KafkaStatusBackingStoreFormatTest extends EasyMockSupport { + + private static final String STATUS_TOPIC = "status-topic"; + private static final String FOO_TOPIC = "foo-topic"; + private static final String FOO_CONNECTOR = "foo-source"; + private static final String BAR_TOPIC = "bar-topic"; + + private Time time; + private KafkaStatusBackingStore store; + private JsonConverter converter; + @Mock + private KafkaBasedLog kafkaBasedLog; + + @Before + public void setup() { + time = new MockTime(); + converter = new JsonConverter(); + converter.configure(Collections.singletonMap(SCHEMAS_ENABLE_CONFIG, false), false); + store = new KafkaStatusBackingStore(new MockTime(), converter, STATUS_TOPIC, kafkaBasedLog); + } + + @Test + public void readInvalidStatus() { + String key = "status-unknown"; + byte[] value = new byte[0]; + ConsumerRecord statusRecord = new ConsumerRecord<>(STATUS_TOPIC, 0, 0, key, value); + assertTrue(store.connectors().isEmpty()); + assertTrue(store.tasks.isEmpty()); + assertTrue(store.topics.isEmpty()); + store.read(statusRecord); + assertTrue(store.connectors().isEmpty()); + assertTrue(store.tasks.isEmpty()); + assertTrue(store.topics.isEmpty()); + + key = CONNECTOR_STATUS_PREFIX; + statusRecord = new ConsumerRecord<>(STATUS_TOPIC, 0, 0, key, value); + 
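+        // A key that is only the connector-status prefix carries no connector name, so the
+        // record should be ignored and the set of known connectors should stay empty.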
assertTrue(store.connectors().isEmpty()); + store.read(statusRecord); + assertTrue(store.connectors().isEmpty()); + + key = TASK_STATUS_PREFIX; + statusRecord = new ConsumerRecord<>(STATUS_TOPIC, 0, 0, key, value); + assertTrue(store.tasks.isEmpty()); + store.read(statusRecord); + assertTrue(store.tasks.isEmpty()); + + key = TASK_STATUS_PREFIX + FOO_CONNECTOR + "-#"; + statusRecord = new ConsumerRecord<>(STATUS_TOPIC, 0, 0, key, value); + assertTrue(store.tasks.isEmpty()); + store.read(statusRecord); + assertTrue(store.tasks.isEmpty()); + + key = TOPIC_STATUS_PREFIX; + statusRecord = new ConsumerRecord<>(STATUS_TOPIC, 0, 0, key, value); + assertTrue(store.topics.isEmpty()); + store.read(statusRecord); + assertTrue(store.topics.isEmpty()); + + key = TOPIC_STATUS_PREFIX + TOPIC_STATUS_SEPARATOR; + statusRecord = new ConsumerRecord<>(STATUS_TOPIC, 0, 0, key, value); + assertTrue(store.topics.isEmpty()); + store.read(statusRecord); + assertTrue(store.topics.isEmpty()); + + key = TOPIC_STATUS_PREFIX + FOO_TOPIC + ":"; + statusRecord = new ConsumerRecord<>(STATUS_TOPIC, 0, 0, key, value); + assertTrue(store.topics.isEmpty()); + store.read(statusRecord); + assertTrue(store.topics.isEmpty()); + + key = TOPIC_STATUS_PREFIX + FOO_TOPIC + TOPIC_STATUS_SEPARATOR; + statusRecord = new ConsumerRecord<>(STATUS_TOPIC, 0, 0, key, value); + assertTrue(store.topics.isEmpty()); + store.read(statusRecord); + assertTrue(store.topics.isEmpty()); + } + + @Test + public void readInvalidStatusValue() { + String key = CONNECTOR_STATUS_PREFIX + FOO_CONNECTOR; + byte[] value = "invalid".getBytes(); + ConsumerRecord statusRecord = new ConsumerRecord<>(STATUS_TOPIC, 0, 0, key, value); + assertTrue(store.connectors().isEmpty()); + store.read(statusRecord); + assertTrue(store.connectors().isEmpty()); + + key = TASK_STATUS_PREFIX + FOO_CONNECTOR + "-0"; + statusRecord = new ConsumerRecord<>(STATUS_TOPIC, 0, 0, key, value); + assertTrue(store.tasks.isEmpty()); + store.read(statusRecord); + assertTrue(store.tasks.isEmpty()); + + key = TOPIC_STATUS_PREFIX + FOO_TOPIC + TOPIC_STATUS_SEPARATOR + FOO_CONNECTOR; + statusRecord = new ConsumerRecord<>(STATUS_TOPIC, 0, 0, key, value); + assertTrue(store.topics.isEmpty()); + store.read(statusRecord); + assertTrue(store.topics.isEmpty()); + } + + @Test + public void readTopicStatus() { + TopicStatus topicStatus = new TopicStatus(FOO_TOPIC, new ConnectorTaskId(FOO_CONNECTOR, 0), Time.SYSTEM.milliseconds()); + String key = TOPIC_STATUS_PREFIX + FOO_TOPIC + TOPIC_STATUS_SEPARATOR + FOO_CONNECTOR; + byte[] value = store.serializeTopicStatus(topicStatus); + ConsumerRecord statusRecord = new ConsumerRecord<>(STATUS_TOPIC, 0, 0, key, value); + store.read(statusRecord); + assertTrue(store.topics.containsKey(FOO_CONNECTOR)); + assertTrue(store.topics.get(FOO_CONNECTOR).containsKey(FOO_TOPIC)); + assertEquals(topicStatus, store.topics.get(FOO_CONNECTOR).get(FOO_TOPIC)); + } + + @Test + public void deleteTopicStatus() { + TopicStatus topicStatus = new TopicStatus("foo", new ConnectorTaskId("bar", 0), Time.SYSTEM.milliseconds()); + store.topics.computeIfAbsent("bar", k -> new ConcurrentHashMap<>()).put("foo", topicStatus); + assertTrue(store.topics.containsKey("bar")); + assertTrue(store.topics.get("bar").containsKey("foo")); + assertEquals(topicStatus, store.topics.get("bar").get("foo")); + // should return null + byte[] value = store.serializeTopicStatus(null); + ConsumerRecord statusRecord = new ConsumerRecord<>(STATUS_TOPIC, 0, 0, "status-topic-foo:connector-bar", value); + 
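+        // A record whose serialized value is null acts as a tombstone: reading it should remove
+        // the "foo" topic entry for connector "bar", leaving that connector's topic map empty.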
store.read(statusRecord); + assertTrue(store.topics.containsKey("bar")); + assertFalse(store.topics.get("bar").containsKey("foo")); + assertEquals(Collections.emptyMap(), store.topics.get("bar")); + } + + @Test + public void putTopicState() { + TopicStatus topicStatus = new TopicStatus(FOO_TOPIC, new ConnectorTaskId(FOO_CONNECTOR, 0), time.milliseconds()); + String key = TOPIC_STATUS_PREFIX + FOO_TOPIC + TOPIC_STATUS_SEPARATOR + FOO_CONNECTOR; + Capture valueCapture = newCapture(); + Capture callbackCapture = newCapture(); + kafkaBasedLog.send(eq(key), capture(valueCapture), capture(callbackCapture)); + expectLastCall() + .andAnswer(() -> { + callbackCapture.getValue().onCompletion(null, null); + return null; + }); + replayAll(); + + store.put(topicStatus); + // check capture state + assertEquals(topicStatus, store.parseTopicStatus(valueCapture.getValue())); + // state is not visible until read back from the log + assertNull(store.getTopic(FOO_CONNECTOR, FOO_TOPIC)); + + ConsumerRecord statusRecord = new ConsumerRecord<>(STATUS_TOPIC, 0, 0, key, valueCapture.getValue()); + store.read(statusRecord); + assertEquals(topicStatus, store.getTopic(FOO_CONNECTOR, FOO_TOPIC)); + assertEquals(new HashSet<>(Collections.singletonList(topicStatus)), new HashSet<>(store.getAllTopics(FOO_CONNECTOR))); + + verifyAll(); + } + + @Test + public void putTopicStateRetriableFailure() { + TopicStatus topicStatus = new TopicStatus(FOO_TOPIC, new ConnectorTaskId(FOO_CONNECTOR, 0), time.milliseconds()); + String key = TOPIC_STATUS_PREFIX + FOO_TOPIC + TOPIC_STATUS_SEPARATOR + FOO_CONNECTOR; + Capture valueCapture = newCapture(); + Capture callbackCapture = newCapture(); + kafkaBasedLog.send(eq(key), capture(valueCapture), capture(callbackCapture)); + expectLastCall() + .andAnswer(() -> { + callbackCapture.getValue().onCompletion(null, new TimeoutException()); + return null; + }) + .andAnswer(() -> { + callbackCapture.getValue().onCompletion(null, null); + return null; + }); + + replayAll(); + store.put(topicStatus); + + // check capture state + assertEquals(topicStatus, store.parseTopicStatus(valueCapture.getValue())); + // state is not visible until read back from the log + assertNull(store.getTopic(FOO_CONNECTOR, FOO_TOPIC)); + + verifyAll(); + } + + @Test + public void putTopicStateNonRetriableFailure() { + TopicStatus topicStatus = new TopicStatus(FOO_TOPIC, new ConnectorTaskId(FOO_CONNECTOR, 0), time.milliseconds()); + String key = TOPIC_STATUS_PREFIX + FOO_TOPIC + TOPIC_STATUS_SEPARATOR + FOO_CONNECTOR; + Capture valueCapture = newCapture(); + Capture callbackCapture = newCapture(); + kafkaBasedLog.send(eq(key), capture(valueCapture), capture(callbackCapture)); + expectLastCall() + .andAnswer(() -> { + callbackCapture.getValue().onCompletion(null, new UnknownServerException()); + return null; + }); + + replayAll(); + + // the error is logged and ignored + store.put(topicStatus); + + // check capture state + assertEquals(topicStatus, store.parseTopicStatus(valueCapture.getValue())); + // state is not visible until read back from the log + assertNull(store.getTopic(FOO_CONNECTOR, FOO_TOPIC)); + + verifyAll(); + } + + @Test + public void putTopicStateShouldOverridePreviousState() { + TopicStatus firstTopicStatus = new TopicStatus(FOO_TOPIC, new ConnectorTaskId(FOO_CONNECTOR, + 0), time.milliseconds()); + time.sleep(1000); + TopicStatus secondTopicStatus = new TopicStatus(BAR_TOPIC, new ConnectorTaskId(FOO_CONNECTOR, + 0), time.milliseconds()); + String firstKey = TOPIC_STATUS_PREFIX + FOO_TOPIC + 
TOPIC_STATUS_SEPARATOR + FOO_CONNECTOR; + String secondKey = TOPIC_STATUS_PREFIX + BAR_TOPIC + TOPIC_STATUS_SEPARATOR + FOO_CONNECTOR; + Capture valueCapture = newCapture(); + Capture callbackCapture = newCapture(); + kafkaBasedLog.send(eq(secondKey), capture(valueCapture), capture(callbackCapture)); + expectLastCall() + .andAnswer(() -> { + callbackCapture.getValue().onCompletion(null, null); + // The second status record is read soon after it's persisted in the status topic + ConsumerRecord statusRecord = new ConsumerRecord<>(STATUS_TOPIC, 0, 0, secondKey, valueCapture.getValue()); + store.read(statusRecord); + return null; + }); + replayAll(); + + byte[] value = store.serializeTopicStatus(firstTopicStatus); + ConsumerRecord statusRecord = new ConsumerRecord<>(STATUS_TOPIC, 0, 0, firstKey, value); + store.read(statusRecord); + store.put(secondTopicStatus); + + // check capture state + assertEquals(secondTopicStatus, store.parseTopicStatus(valueCapture.getValue())); + assertEquals(firstTopicStatus, store.getTopic(FOO_CONNECTOR, FOO_TOPIC)); + assertEquals(secondTopicStatus, store.getTopic(FOO_CONNECTOR, BAR_TOPIC)); + assertEquals(new HashSet<>(Arrays.asList(firstTopicStatus, secondTopicStatus)), + new HashSet<>(store.getAllTopics(FOO_CONNECTOR))); + + verifyAll(); + } + +} diff --git a/connect/runtime/src/test/java/org/apache/kafka/connect/storage/KafkaStatusBackingStoreTest.java b/connect/runtime/src/test/java/org/apache/kafka/connect/storage/KafkaStatusBackingStoreTest.java new file mode 100644 index 0000000..f19be75 --- /dev/null +++ b/connect/runtime/src/test/java/org/apache/kafka/connect/storage/KafkaStatusBackingStoreTest.java @@ -0,0 +1,444 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.connect.storage; + +import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.clients.producer.Callback; +import org.apache.kafka.common.config.ConfigException; +import org.apache.kafka.common.errors.TimeoutException; +import org.apache.kafka.common.errors.UnknownServerException; +import org.apache.kafka.common.header.internals.RecordHeaders; +import org.apache.kafka.common.record.TimestampType; +import org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.connect.data.Schema; +import org.apache.kafka.connect.data.SchemaAndValue; +import org.apache.kafka.connect.data.Struct; +import org.apache.kafka.connect.runtime.ConnectorStatus; +import org.apache.kafka.connect.runtime.TaskStatus; +import org.apache.kafka.connect.runtime.WorkerConfig; +import org.apache.kafka.connect.runtime.distributed.DistributedConfig; +import org.apache.kafka.connect.util.ConnectorTaskId; +import org.apache.kafka.connect.util.KafkaBasedLog; +import org.easymock.Capture; +import org.easymock.EasyMock; +import org.easymock.EasyMockSupport; +import org.easymock.Mock; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.powermock.modules.junit4.PowerMockRunner; + +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; +import java.util.Optional; + +import static org.easymock.EasyMock.anyObject; +import static org.easymock.EasyMock.capture; +import static org.easymock.EasyMock.eq; +import static org.easymock.EasyMock.expect; +import static org.easymock.EasyMock.expectLastCall; +import static org.easymock.EasyMock.newCapture; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; + +@SuppressWarnings("unchecked") +@RunWith(PowerMockRunner.class) +public class KafkaStatusBackingStoreTest extends EasyMockSupport { + + private static final String STATUS_TOPIC = "status-topic"; + private static final String WORKER_ID = "localhost:8083"; + private static final String CONNECTOR = "conn"; + private static final ConnectorTaskId TASK = new ConnectorTaskId(CONNECTOR, 0); + + private KafkaStatusBackingStore store; + @Mock + Converter converter; + @Mock + private KafkaBasedLog kafkaBasedLog; + @Mock + WorkerConfig workerConfig; + + @Before + public void setup() { + store = new KafkaStatusBackingStore(new MockTime(), converter, STATUS_TOPIC, kafkaBasedLog); + } + + @Test + public void misconfigurationOfStatusBackingStore() { + expect(workerConfig.getString(DistributedConfig.STATUS_STORAGE_TOPIC_CONFIG)).andReturn(null); + expect(workerConfig.getString(DistributedConfig.STATUS_STORAGE_TOPIC_CONFIG)).andReturn(" "); + replayAll(); + + Exception e = assertThrows(ConfigException.class, () -> store.configure(workerConfig)); + assertEquals("Must specify topic for connector status.", e.getMessage()); + e = assertThrows(ConfigException.class, () -> store.configure(workerConfig)); + assertEquals("Must specify topic for connector status.", e.getMessage()); + verifyAll(); + } + + @Test + public void putConnectorState() { + byte[] value = new byte[0]; + expect(converter.fromConnectData(eq(STATUS_TOPIC), anyObject(Schema.class), anyObject(Struct.class))) + .andStubReturn(value); + + final Capture callbackCapture = newCapture(); + kafkaBasedLog.send(eq("status-connector-conn"), eq(value), capture(callbackCapture)); + expectLastCall() + .andAnswer(() -> { + 
callbackCapture.getValue().onCompletion(null, null); + return null; + }); + replayAll(); + + ConnectorStatus status = new ConnectorStatus(CONNECTOR, ConnectorStatus.State.RUNNING, WORKER_ID, 0); + store.put(status); + + // state is not visible until read back from the log + assertNull(store.get(CONNECTOR)); + + verifyAll(); + } + + @Test + public void putConnectorStateRetriableFailure() { + byte[] value = new byte[0]; + expect(converter.fromConnectData(eq(STATUS_TOPIC), anyObject(Schema.class), anyObject(Struct.class))) + .andStubReturn(value); + + final Capture callbackCapture = newCapture(); + kafkaBasedLog.send(eq("status-connector-conn"), eq(value), capture(callbackCapture)); + expectLastCall() + .andAnswer(() -> { + callbackCapture.getValue().onCompletion(null, new TimeoutException()); + return null; + }) + .andAnswer(() -> { + callbackCapture.getValue().onCompletion(null, null); + return null; + }); + replayAll(); + + ConnectorStatus status = new ConnectorStatus(CONNECTOR, ConnectorStatus.State.RUNNING, WORKER_ID, 0); + store.put(status); + + // state is not visible until read back from the log + assertNull(store.get(CONNECTOR)); + + verifyAll(); + } + + @Test + public void putConnectorStateNonRetriableFailure() { + byte[] value = new byte[0]; + expect(converter.fromConnectData(eq(STATUS_TOPIC), anyObject(Schema.class), anyObject(Struct.class))) + .andStubReturn(value); + + final Capture callbackCapture = newCapture(); + kafkaBasedLog.send(eq("status-connector-conn"), eq(value), capture(callbackCapture)); + expectLastCall() + .andAnswer(() -> { + callbackCapture.getValue().onCompletion(null, new UnknownServerException()); + return null; + }); + replayAll(); + + // the error is logged and ignored + ConnectorStatus status = new ConnectorStatus(CONNECTOR, ConnectorStatus.State.RUNNING, WORKER_ID, 0); + store.put(status); + + // state is not visible until read back from the log + assertNull(store.get(CONNECTOR)); + + verifyAll(); + } + + @Test + public void putSafeConnectorIgnoresStaleStatus() { + byte[] value = new byte[0]; + String otherWorkerId = "anotherhost:8083"; + + // the persisted came from a different host and has a newer generation + Map statusMap = new HashMap<>(); + statusMap.put("worker_id", otherWorkerId); + statusMap.put("state", "RUNNING"); + statusMap.put("generation", 1L); + + expect(converter.toConnectData(STATUS_TOPIC, value)) + .andReturn(new SchemaAndValue(null, statusMap)); + + // we're verifying that there is no call to KafkaBasedLog.send + + replayAll(); + + store.read(consumerRecord(0, "status-connector-conn", value)); + store.putSafe(new ConnectorStatus(CONNECTOR, ConnectorStatus.State.UNASSIGNED, WORKER_ID, 0)); + + ConnectorStatus status = new ConnectorStatus(CONNECTOR, ConnectorStatus.State.RUNNING, otherWorkerId, 1); + assertEquals(status, store.get(CONNECTOR)); + + verifyAll(); + } + + @Test + public void putSafeWithNoPreviousValueIsPropagated() { + final byte[] value = new byte[0]; + + final Capture statusValueStruct = newCapture(); + converter.fromConnectData(eq(STATUS_TOPIC), anyObject(Schema.class), capture(statusValueStruct)); + EasyMock.expectLastCall().andReturn(value); + + kafkaBasedLog.send(eq("status-connector-" + CONNECTOR), eq(value), anyObject(Callback.class)); + expectLastCall(); + + replayAll(); + + final ConnectorStatus status = new ConnectorStatus(CONNECTOR, ConnectorStatus.State.FAILED, WORKER_ID, 0); + store.putSafe(status); + + verifyAll(); + + assertEquals(status.state().toString(), 
statusValueStruct.getValue().get(KafkaStatusBackingStore.STATE_KEY_NAME)); + assertEquals(status.workerId(), statusValueStruct.getValue().get(KafkaStatusBackingStore.WORKER_ID_KEY_NAME)); + assertEquals(status.generation(), statusValueStruct.getValue().get(KafkaStatusBackingStore.GENERATION_KEY_NAME)); + } + + @Test + public void putSafeOverridesValueSetBySameWorker() { + final byte[] value = new byte[0]; + + // the persisted came from the same host, but has a newer generation + Map firstStatusRead = new HashMap<>(); + firstStatusRead.put("worker_id", WORKER_ID); + firstStatusRead.put("state", "RUNNING"); + firstStatusRead.put("generation", 1L); + + Map secondStatusRead = new HashMap<>(); + secondStatusRead.put("worker_id", WORKER_ID); + secondStatusRead.put("state", "UNASSIGNED"); + secondStatusRead.put("generation", 0L); + + expect(converter.toConnectData(STATUS_TOPIC, value)) + .andReturn(new SchemaAndValue(null, firstStatusRead)) + .andReturn(new SchemaAndValue(null, secondStatusRead)); + + expect(converter.fromConnectData(eq(STATUS_TOPIC), anyObject(Schema.class), anyObject(Struct.class))) + .andStubReturn(value); + + final Capture callbackCapture = newCapture(); + kafkaBasedLog.send(eq("status-connector-conn"), eq(value), capture(callbackCapture)); + expectLastCall() + .andAnswer(() -> { + callbackCapture.getValue().onCompletion(null, null); + store.read(consumerRecord(1, "status-connector-conn", value)); + return null; + }); + + replayAll(); + + store.read(consumerRecord(0, "status-connector-conn", value)); + store.putSafe(new ConnectorStatus(CONNECTOR, ConnectorStatus.State.UNASSIGNED, WORKER_ID, 0)); + + ConnectorStatus status = new ConnectorStatus(CONNECTOR, ConnectorStatus.State.UNASSIGNED, WORKER_ID, 0); + assertEquals(status, store.get(CONNECTOR)); + + verifyAll(); + } + + @Test + public void putConnectorStateShouldOverride() { + final byte[] value = new byte[0]; + String otherWorkerId = "anotherhost:8083"; + + // the persisted came from a different host and has a newer generation + Map firstStatusRead = new HashMap<>(); + firstStatusRead.put("worker_id", otherWorkerId); + firstStatusRead.put("state", "RUNNING"); + firstStatusRead.put("generation", 1L); + + Map secondStatusRead = new HashMap<>(); + secondStatusRead.put("worker_id", WORKER_ID); + secondStatusRead.put("state", "UNASSIGNED"); + secondStatusRead.put("generation", 0L); + + expect(converter.toConnectData(STATUS_TOPIC, value)) + .andReturn(new SchemaAndValue(null, firstStatusRead)) + .andReturn(new SchemaAndValue(null, secondStatusRead)); + + expect(converter.fromConnectData(eq(STATUS_TOPIC), anyObject(Schema.class), anyObject(Struct.class))) + .andStubReturn(value); + + final Capture callbackCapture = newCapture(); + kafkaBasedLog.send(eq("status-connector-conn"), eq(value), capture(callbackCapture)); + expectLastCall() + .andAnswer(() -> { + callbackCapture.getValue().onCompletion(null, null); + store.read(consumerRecord(1, "status-connector-conn", value)); + return null; + }); + replayAll(); + + store.read(consumerRecord(0, "status-connector-conn", value)); + + ConnectorStatus status = new ConnectorStatus(CONNECTOR, ConnectorStatus.State.UNASSIGNED, WORKER_ID, 0); + store.put(status); + assertEquals(status, store.get(CONNECTOR)); + + verifyAll(); + } + + @Test + public void readConnectorState() { + byte[] value = new byte[0]; + + Map statusMap = new HashMap<>(); + statusMap.put("worker_id", WORKER_ID); + statusMap.put("state", "RUNNING"); + statusMap.put("generation", 0L); + + 
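+        // The mocked converter returns this map for the status record read below, which the store
+        // should translate into a RUNNING ConnectorStatus at generation 0.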
expect(converter.toConnectData(STATUS_TOPIC, value)) + .andReturn(new SchemaAndValue(null, statusMap)); + + replayAll(); + + store.read(consumerRecord(0, "status-connector-conn", value)); + + ConnectorStatus status = new ConnectorStatus(CONNECTOR, ConnectorStatus.State.RUNNING, WORKER_ID, 0); + assertEquals(status, store.get(CONNECTOR)); + + verifyAll(); + } + + @Test + public void putTaskState() { + byte[] value = new byte[0]; + expect(converter.fromConnectData(eq(STATUS_TOPIC), anyObject(Schema.class), anyObject(Struct.class))) + .andStubReturn(value); + + final Capture callbackCapture = newCapture(); + kafkaBasedLog.send(eq("status-task-conn-0"), eq(value), capture(callbackCapture)); + expectLastCall() + .andAnswer(() -> { + callbackCapture.getValue().onCompletion(null, null); + return null; + }); + replayAll(); + + TaskStatus status = new TaskStatus(TASK, TaskStatus.State.RUNNING, WORKER_ID, 0); + store.put(status); + + // state is not visible until read back from the log + assertNull(store.get(TASK)); + + verifyAll(); + } + + @Test + public void readTaskState() { + byte[] value = new byte[0]; + + Map statusMap = new HashMap<>(); + statusMap.put("worker_id", WORKER_ID); + statusMap.put("state", "RUNNING"); + statusMap.put("generation", 0L); + + expect(converter.toConnectData(STATUS_TOPIC, value)) + .andReturn(new SchemaAndValue(null, statusMap)); + + replayAll(); + + store.read(consumerRecord(0, "status-task-conn-0", value)); + + TaskStatus status = new TaskStatus(TASK, TaskStatus.State.RUNNING, WORKER_ID, 0); + assertEquals(status, store.get(TASK)); + + verifyAll(); + } + + @Test + public void deleteConnectorState() { + final byte[] value = new byte[0]; + Map statusMap = new HashMap<>(); + statusMap.put("worker_id", WORKER_ID); + statusMap.put("state", "RUNNING"); + statusMap.put("generation", 0L); + + converter.fromConnectData(eq(STATUS_TOPIC), anyObject(Schema.class), anyObject(Struct.class)); + EasyMock.expectLastCall().andReturn(value); + kafkaBasedLog.send(eq("status-connector-" + CONNECTOR), eq(value), anyObject(Callback.class)); + expectLastCall(); + + converter.fromConnectData(eq(STATUS_TOPIC), anyObject(Schema.class), anyObject(Struct.class)); + EasyMock.expectLastCall().andReturn(value); + kafkaBasedLog.send(eq("status-task-conn-0"), eq(value), anyObject(Callback.class)); + expectLastCall(); + + expect(converter.toConnectData(STATUS_TOPIC, value)).andReturn(new SchemaAndValue(null, statusMap)); + + replayAll(); + + ConnectorStatus connectorStatus = new ConnectorStatus(CONNECTOR, ConnectorStatus.State.RUNNING, WORKER_ID, 0); + store.put(connectorStatus); + TaskStatus taskStatus = new TaskStatus(TASK, TaskStatus.State.RUNNING, WORKER_ID, 0); + store.put(taskStatus); + store.read(consumerRecord(0, "status-task-conn-0", value)); + + assertEquals(new HashSet<>(Collections.singletonList(CONNECTOR)), store.connectors()); + assertEquals(new HashSet<>(Collections.singletonList(taskStatus)), new HashSet<>(store.getAll(CONNECTOR))); + store.read(consumerRecord(0, "status-connector-conn", null)); + assertTrue(store.connectors().isEmpty()); + assertTrue(store.getAll(CONNECTOR).isEmpty()); + verifyAll(); + } + + @Test + public void deleteTaskState() { + final byte[] value = new byte[0]; + Map statusMap = new HashMap<>(); + statusMap.put("worker_id", WORKER_ID); + statusMap.put("state", "RUNNING"); + statusMap.put("generation", 0L); + + converter.fromConnectData(eq(STATUS_TOPIC), anyObject(Schema.class), anyObject(Struct.class)); + EasyMock.expectLastCall().andReturn(value); + 
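+        // put() should serialize the task status with the converter and send it to the status
+        // topic under the task-specific key "status-task-conn-0".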
kafkaBasedLog.send(eq("status-task-conn-0"), eq(value), anyObject(Callback.class)); + expectLastCall(); + + expect(converter.toConnectData(STATUS_TOPIC, value)).andReturn(new SchemaAndValue(null, statusMap)); + + replayAll(); + + TaskStatus taskStatus = new TaskStatus(TASK, TaskStatus.State.RUNNING, WORKER_ID, 0); + store.put(taskStatus); + store.read(consumerRecord(0, "status-task-conn-0", value)); + + assertEquals(new HashSet<>(Collections.singletonList(taskStatus)), new HashSet<>(store.getAll(CONNECTOR))); + store.read(consumerRecord(0, "status-task-conn-0", null)); + assertTrue(store.getAll(CONNECTOR).isEmpty()); + verifyAll(); + } + + private static ConsumerRecord consumerRecord(long offset, String key, byte[] value) { + return new ConsumerRecord<>(STATUS_TOPIC, 0, offset, System.currentTimeMillis(), + TimestampType.CREATE_TIME, 0, 0, key, value, new RecordHeaders(), Optional.empty()); + } + +} diff --git a/connect/runtime/src/test/java/org/apache/kafka/connect/storage/MemoryStatusBackingStoreTest.java b/connect/runtime/src/test/java/org/apache/kafka/connect/storage/MemoryStatusBackingStoreTest.java new file mode 100644 index 0000000..a31915a --- /dev/null +++ b/connect/runtime/src/test/java/org/apache/kafka/connect/storage/MemoryStatusBackingStoreTest.java @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.connect.storage; + +import org.apache.kafka.connect.runtime.ConnectorStatus; +import org.apache.kafka.connect.runtime.TaskStatus; +import org.apache.kafka.connect.util.ConnectorTaskId; +import org.junit.Test; + +import java.util.Collections; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; + +public class MemoryStatusBackingStoreTest { + + @Test + public void putAndGetConnectorStatus() { + MemoryStatusBackingStore store = new MemoryStatusBackingStore(); + ConnectorStatus status = new ConnectorStatus("connector", ConnectorStatus.State.RUNNING, "localhost:8083", 0); + store.put(status); + assertEquals(status, store.get("connector")); + } + + @Test + public void putAndGetTaskStatus() { + MemoryStatusBackingStore store = new MemoryStatusBackingStore(); + ConnectorTaskId taskId = new ConnectorTaskId("connector", 0); + TaskStatus status = new TaskStatus(taskId, ConnectorStatus.State.RUNNING, "localhost:8083", 0); + store.put(status); + assertEquals(status, store.get(taskId)); + assertEquals(Collections.singleton(status), store.getAll("connector")); + } + + @Test + public void deleteConnectorStatus() { + MemoryStatusBackingStore store = new MemoryStatusBackingStore(); + store.put(new ConnectorStatus("connector", ConnectorStatus.State.RUNNING, "localhost:8083", 0)); + store.put(new ConnectorStatus("connector", ConnectorStatus.State.DESTROYED, "localhost:8083", 0)); + assertNull(store.get("connector")); + } + + @Test + public void deleteTaskStatus() { + MemoryStatusBackingStore store = new MemoryStatusBackingStore(); + ConnectorTaskId taskId = new ConnectorTaskId("connector", 0); + store.put(new TaskStatus(taskId, ConnectorStatus.State.RUNNING, "localhost:8083", 0)); + store.put(new TaskStatus(taskId, ConnectorStatus.State.DESTROYED, "localhost:8083", 0)); + assertNull(store.get(taskId)); + } + +} diff --git a/connect/runtime/src/test/java/org/apache/kafka/connect/storage/OffsetStorageWriterTest.java b/connect/runtime/src/test/java/org/apache/kafka/connect/storage/OffsetStorageWriterTest.java new file mode 100644 index 0000000..b442bca --- /dev/null +++ b/connect/runtime/src/test/java/org/apache/kafka/connect/storage/OffsetStorageWriterTest.java @@ -0,0 +1,269 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.connect.storage; + +import org.apache.kafka.connect.errors.ConnectException; +import org.apache.kafka.connect.util.Callback; +import org.easymock.Capture; +import org.easymock.EasyMock; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.powermock.api.easymock.PowerMock; +import org.powermock.api.easymock.annotation.Mock; +import org.powermock.modules.junit4.PowerMockRunner; + +import java.nio.ByteBuffer; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; + +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; + +@RunWith(PowerMockRunner.class) +public class OffsetStorageWriterTest { + private static final String NAMESPACE = "namespace"; + // Connect format - any types should be accepted here + private static final Map OFFSET_KEY = Collections.singletonMap("key", "key"); + private static final Map OFFSET_VALUE = Collections.singletonMap("key", 12); + + // Serialized + private static final byte[] OFFSET_KEY_SERIALIZED = "key-serialized".getBytes(); + private static final byte[] OFFSET_VALUE_SERIALIZED = "value-serialized".getBytes(); + + @Mock private OffsetBackingStore store; + @Mock private Converter keyConverter; + @Mock private Converter valueConverter; + private OffsetStorageWriter writer; + + private static Exception exception = new RuntimeException("error"); + + private ExecutorService service; + + @Before + public void setup() { + writer = new OffsetStorageWriter(store, NAMESPACE, keyConverter, valueConverter); + service = Executors.newFixedThreadPool(1); + } + + @After + public void teardown() { + service.shutdownNow(); + } + + @Test + public void testWriteFlush() throws Exception { + @SuppressWarnings("unchecked") + Callback callback = PowerMock.createMock(Callback.class); + expectStore(OFFSET_KEY, OFFSET_KEY_SERIALIZED, OFFSET_VALUE, OFFSET_VALUE_SERIALIZED, callback, false, null); + + PowerMock.replayAll(); + + writer.offset(OFFSET_KEY, OFFSET_VALUE); + + assertTrue(writer.beginFlush()); + writer.doFlush(callback).get(1000, TimeUnit.MILLISECONDS); + + PowerMock.verifyAll(); + } + + // It should be possible to set offset values to null + @Test + public void testWriteNullValueFlush() throws Exception { + @SuppressWarnings("unchecked") + Callback callback = PowerMock.createMock(Callback.class); + expectStore(OFFSET_KEY, OFFSET_KEY_SERIALIZED, null, null, callback, false, null); + + PowerMock.replayAll(); + + writer.offset(OFFSET_KEY, null); + + assertTrue(writer.beginFlush()); + writer.doFlush(callback).get(1000, TimeUnit.MILLISECONDS); + + PowerMock.verifyAll(); + } + + // It should be possible to use null keys. 
These aren't actually stored as null since the key is wrapped to include + // info about the namespace (connector) + @Test + public void testWriteNullKeyFlush() throws Exception { + @SuppressWarnings("unchecked") + Callback callback = PowerMock.createMock(Callback.class); + expectStore(null, null, OFFSET_VALUE, OFFSET_VALUE_SERIALIZED, callback, false, null); + + PowerMock.replayAll(); + + writer.offset(null, OFFSET_VALUE); + + assertTrue(writer.beginFlush()); + writer.doFlush(callback).get(1000, TimeUnit.MILLISECONDS); + + PowerMock.verifyAll(); + } + + @Test + public void testNoOffsetsToFlush() { + // If no offsets are flushed, we should finish immediately and not have made any calls to the + // underlying storage layer + + PowerMock.replayAll(); + + // Should not return a future + assertFalse(writer.beginFlush()); + + PowerMock.verifyAll(); + } + + @Test + public void testFlushFailureReplacesOffsets() throws Exception { + // When a flush fails, we shouldn't just lose the offsets. Instead, they should be restored + // such that a subsequent flush will write them. + + @SuppressWarnings("unchecked") + final Callback callback = PowerMock.createMock(Callback.class); + // First time the write fails + expectStore(OFFSET_KEY, OFFSET_KEY_SERIALIZED, OFFSET_VALUE, OFFSET_VALUE_SERIALIZED, callback, true, null); + // Second time it succeeds + expectStore(OFFSET_KEY, OFFSET_KEY_SERIALIZED, OFFSET_VALUE, OFFSET_VALUE_SERIALIZED, callback, false, null); + // Third time it has no data to flush so we won't get past beginFlush() + + PowerMock.replayAll(); + + writer.offset(OFFSET_KEY, OFFSET_VALUE); + assertTrue(writer.beginFlush()); + writer.doFlush(callback).get(1000, TimeUnit.MILLISECONDS); + assertTrue(writer.beginFlush()); + writer.doFlush(callback).get(1000, TimeUnit.MILLISECONDS); + assertFalse(writer.beginFlush()); + + PowerMock.verifyAll(); + } + + @Test + public void testAlreadyFlushing() { + @SuppressWarnings("unchecked") + final Callback callback = PowerMock.createMock(Callback.class); + // Trigger the send, but don't invoke the callback so we'll still be mid-flush + CountDownLatch allowStoreCompleteCountdown = new CountDownLatch(1); + expectStore(OFFSET_KEY, OFFSET_KEY_SERIALIZED, OFFSET_VALUE, OFFSET_VALUE_SERIALIZED, null, false, allowStoreCompleteCountdown); + + PowerMock.replayAll(); + + writer.offset(OFFSET_KEY, OFFSET_VALUE); + assertTrue(writer.beginFlush()); + writer.doFlush(callback); + assertThrows(ConnectException.class, writer::beginFlush); + + PowerMock.verifyAll(); + } + + @Test + public void testCancelBeforeAwaitFlush() { + PowerMock.replayAll(); + + writer.offset(OFFSET_KEY, OFFSET_VALUE); + assertTrue(writer.beginFlush()); + writer.cancelFlush(); + + PowerMock.verifyAll(); + } + + @Test + public void testCancelAfterAwaitFlush() throws Exception { + @SuppressWarnings("unchecked") + Callback callback = PowerMock.createMock(Callback.class); + CountDownLatch allowStoreCompleteCountdown = new CountDownLatch(1); + // In this test, the write should be cancelled so the callback will not be invoked and is not + // passed to the expectStore call + expectStore(OFFSET_KEY, OFFSET_KEY_SERIALIZED, OFFSET_VALUE, OFFSET_VALUE_SERIALIZED, null, false, allowStoreCompleteCountdown); + + PowerMock.replayAll(); + + writer.offset(OFFSET_KEY, OFFSET_VALUE); + assertTrue(writer.beginFlush()); + // Start the flush, then immediately cancel before allowing the mocked store request to finish + Future flushFuture = writer.doFlush(callback); + writer.cancelFlush(); + 
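+        // Let the mocked store request complete only after the flush has been cancelled; the
+        // callback passed to doFlush() should consequently never be invoked.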
allowStoreCompleteCountdown.countDown(); + flushFuture.get(1000, TimeUnit.MILLISECONDS); + + PowerMock.verifyAll(); + } + + /** + * Expect a request to store data to the underlying OffsetBackingStore. + * + * @param key the key for the offset + * @param keySerialized serialized version of the key + * @param value the value for the offset + * @param valueSerialized serialized version of the value + * @param callback the callback to invoke when completed, or null if the callback isn't + * expected to be invoked + * @param fail if true, treat + * @param waitForCompletion if non-null, a CountDownLatch that should be awaited on before + * invoking the callback. A (generous) timeout is still imposed to + * ensure tests complete. + * @return the captured set of ByteBuffer key-value pairs passed to the storage layer + */ + private void expectStore(Map key, byte[] keySerialized, + Map value, byte[] valueSerialized, + final Callback callback, + final boolean fail, + final CountDownLatch waitForCompletion) { + List keyWrapped = Arrays.asList(NAMESPACE, key); + EasyMock.expect(keyConverter.fromConnectData(NAMESPACE, null, keyWrapped)).andReturn(keySerialized); + EasyMock.expect(valueConverter.fromConnectData(NAMESPACE, null, value)).andReturn(valueSerialized); + + final Capture> storeCallback = Capture.newInstance(); + final Map offsetsSerialized = Collections.singletonMap( + keySerialized == null ? null : ByteBuffer.wrap(keySerialized), + valueSerialized == null ? null : ByteBuffer.wrap(valueSerialized)); + EasyMock.expect(store.set(EasyMock.eq(offsetsSerialized), EasyMock.capture(storeCallback))) + .andAnswer(() -> + service.submit(() -> { + if (waitForCompletion != null) + assertTrue(waitForCompletion.await(10000, TimeUnit.MILLISECONDS)); + + if (fail) { + storeCallback.getValue().onCompletion(exception, null); + } else { + storeCallback.getValue().onCompletion(null, null); + } + return null; + }) + ); + if (callback != null) { + if (fail) { + callback.onCompletion(EasyMock.eq(exception), EasyMock.eq(null)); + } else { + callback.onCompletion(null, null); + } + } + PowerMock.expectLastCall(); + } + +} diff --git a/connect/runtime/src/test/java/org/apache/kafka/connect/util/ByteArrayProducerRecordEquals.java b/connect/runtime/src/test/java/org/apache/kafka/connect/util/ByteArrayProducerRecordEquals.java new file mode 100644 index 0000000..a6a155f --- /dev/null +++ b/connect/runtime/src/test/java/org/apache/kafka/connect/util/ByteArrayProducerRecordEquals.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.connect.util; + +import org.apache.kafka.clients.producer.ProducerRecord; +import org.easymock.EasyMock; +import org.easymock.IArgumentMatcher; + +import java.util.Arrays; + +public class ByteArrayProducerRecordEquals implements IArgumentMatcher { + private ProducerRecord record; + + public static ProducerRecord eqProducerRecord(ProducerRecord in) { + EasyMock.reportMatcher(new ByteArrayProducerRecordEquals(in)); + return null; + } + + public ByteArrayProducerRecordEquals(ProducerRecord record) { + this.record = record; + } + + @Override + @SuppressWarnings("unchecked") + public boolean matches(Object argument) { + if (!(argument instanceof ProducerRecord)) + return false; + ProducerRecord other = (ProducerRecord) argument; + return record.topic().equals(other.topic()) && + record.partition() != null ? record.partition().equals(other.partition()) : other.partition() == null && + record.key() != null ? Arrays.equals(record.key(), other.key()) : other.key() == null && + record.value() != null ? Arrays.equals(record.value(), other.value()) : other.value() == null; + } + + @Override + public void appendTo(StringBuffer buffer) { + buffer.append(record.toString()); + } +} diff --git a/connect/runtime/src/test/java/org/apache/kafka/connect/util/ConnectUtilsTest.java b/connect/runtime/src/test/java/org/apache/kafka/connect/util/ConnectUtilsTest.java new file mode 100644 index 0000000..62a01b7 --- /dev/null +++ b/connect/runtime/src/test/java/org/apache/kafka/connect/util/ConnectUtilsTest.java @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.util; + +import org.apache.kafka.clients.CommonClientConfigs; +import org.apache.kafka.clients.admin.MockAdminClient; +import org.apache.kafka.common.Node; +import org.apache.kafka.connect.errors.ConnectException; +import org.apache.kafka.connect.runtime.WorkerConfig; +import org.apache.kafka.connect.runtime.distributed.DistributedConfig; +import org.apache.kafka.connect.runtime.standalone.StandaloneConfig; +import org.junit.Test; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThrows; + +public class ConnectUtilsTest { + + @Test + public void testLookupKafkaClusterId() { + final Node broker1 = new Node(0, "dummyHost-1", 1234); + final Node broker2 = new Node(1, "dummyHost-2", 1234); + List cluster = Arrays.asList(broker1, broker2); + MockAdminClient adminClient = new MockAdminClient.Builder(). 
+ brokers(cluster).build(); + assertEquals(MockAdminClient.DEFAULT_CLUSTER_ID, ConnectUtils.lookupKafkaClusterId(adminClient)); + } + + @Test + public void testLookupNullKafkaClusterId() { + final Node broker1 = new Node(0, "dummyHost-1", 1234); + final Node broker2 = new Node(1, "dummyHost-2", 1234); + List cluster = Arrays.asList(broker1, broker2); + MockAdminClient adminClient = new MockAdminClient.Builder(). + brokers(cluster).clusterId(null).build(); + assertNull(ConnectUtils.lookupKafkaClusterId(adminClient)); + } + + @Test + public void testLookupKafkaClusterIdTimeout() { + final Node broker1 = new Node(0, "dummyHost-1", 1234); + final Node broker2 = new Node(1, "dummyHost-2", 1234); + List cluster = Arrays.asList(broker1, broker2); + MockAdminClient adminClient = new MockAdminClient.Builder(). + brokers(cluster).build(); + adminClient.timeoutNextRequest(1); + + assertThrows(ConnectException.class, () -> ConnectUtils.lookupKafkaClusterId(adminClient)); + } + + @Test + public void testAddMetricsContextPropertiesDistributed() { + Map props = new HashMap<>(); + props.put(DistributedConfig.GROUP_ID_CONFIG, "connect-cluster"); + props.put(DistributedConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9092"); + props.put(DistributedConfig.CONFIG_TOPIC_CONFIG, "connect-configs"); + props.put(DistributedConfig.OFFSET_STORAGE_TOPIC_CONFIG, "connect-offsets"); + props.put(DistributedConfig.STATUS_STORAGE_TOPIC_CONFIG, "connect-status"); + props.put(DistributedConfig.KEY_CONVERTER_CLASS_CONFIG, "org.apache.kafka.connect.json.JsonConverter"); + props.put(DistributedConfig.VALUE_CONVERTER_CLASS_CONFIG, "org.apache.kafka.connect.json.JsonConverter"); + DistributedConfig config = new DistributedConfig(props); + + Map prop = new HashMap<>(); + ConnectUtils.addMetricsContextProperties(prop, config, "cluster-1"); + assertEquals("connect-cluster", prop.get(CommonClientConfigs.METRICS_CONTEXT_PREFIX + WorkerConfig.CONNECT_GROUP_ID)); + assertEquals("cluster-1", prop.get(CommonClientConfigs.METRICS_CONTEXT_PREFIX + WorkerConfig.CONNECT_KAFKA_CLUSTER_ID)); + } + + @Test + public void testAddMetricsContextPropertiesStandalone() { + Map props = new HashMap<>(); + props.put(StandaloneConfig.OFFSET_STORAGE_FILE_FILENAME_CONFIG, "offsetStorageFile"); + props.put(StandaloneConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9092"); + props.put(StandaloneConfig.KEY_CONVERTER_CLASS_CONFIG, "org.apache.kafka.connect.json.JsonConverter"); + props.put(StandaloneConfig.VALUE_CONVERTER_CLASS_CONFIG, "org.apache.kafka.connect.json.JsonConverter"); + StandaloneConfig config = new StandaloneConfig(props); + + Map prop = new HashMap<>(); + ConnectUtils.addMetricsContextProperties(prop, config, "cluster-1"); + assertNull(prop.get(CommonClientConfigs.METRICS_CONTEXT_PREFIX + WorkerConfig.CONNECT_GROUP_ID)); + assertEquals("cluster-1", prop.get(CommonClientConfigs.METRICS_CONTEXT_PREFIX + WorkerConfig.CONNECT_KAFKA_CLUSTER_ID)); + + } +} diff --git a/connect/runtime/src/test/java/org/apache/kafka/connect/util/ConvertingFutureCallbackTest.java b/connect/runtime/src/test/java/org/apache/kafka/connect/util/ConvertingFutureCallbackTest.java new file mode 100644 index 0000000..7977a29 --- /dev/null +++ b/connect/runtime/src/test/java/org/apache/kafka/connect/util/ConvertingFutureCallbackTest.java @@ -0,0 +1,238 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.util; + +import org.junit.Before; +import org.junit.Test; + +import java.util.concurrent.CancellationException; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +public class ConvertingFutureCallbackTest { + + private ExecutorService executor; + + @Before + public void setup() { + executor = Executors.newSingleThreadExecutor(); + } + + @Test + public void shouldConvertBeforeGetOnSuccessfulCompletion() throws Exception { + final Object expectedConversion = new Object(); + TestConvertingFutureCallback testCallback = new TestConvertingFutureCallback(); + testCallback.onCompletion(null, expectedConversion); + assertEquals(1, testCallback.numberOfConversions()); + assertEquals(expectedConversion, testCallback.get()); + } + + @Test + public void shouldConvertOnlyOnceBeforeGetOnSuccessfulCompletion() throws Exception { + final Object expectedConversion = new Object(); + TestConvertingFutureCallback testCallback = new TestConvertingFutureCallback(); + testCallback.onCompletion(null, expectedConversion); + testCallback.onCompletion(null, 69); + testCallback.cancel(true); + testCallback.onCompletion(new RuntimeException(), null); + assertEquals(1, testCallback.numberOfConversions()); + assertEquals(expectedConversion, testCallback.get()); + } + + @Test + public void shouldNotConvertBeforeGetOnFailedCompletion() throws Exception { + final Throwable expectedError = new Throwable(); + TestConvertingFutureCallback testCallback = new TestConvertingFutureCallback(); + testCallback.onCompletion(expectedError, null); + assertEquals(0, testCallback.numberOfConversions()); + try { + testCallback.get(); + fail("Expected ExecutionException"); + } catch (ExecutionException e) { + assertEquals(expectedError, e.getCause()); + } + } + + @Test + public void shouldRecordOnlyFirstErrorBeforeGetOnFailedCompletion() throws Exception { + final Throwable expectedError = new Throwable(); + TestConvertingFutureCallback testCallback = new TestConvertingFutureCallback(); + testCallback.onCompletion(expectedError, null); + testCallback.onCompletion(new RuntimeException(), null); + testCallback.cancel(true); + testCallback.onCompletion(null, "420"); + assertEquals(0, testCallback.numberOfConversions()); + try { + testCallback.get(); + fail("Expected ExecutionException"); + } catch (ExecutionException e) { + assertEquals(expectedError, e.getCause()); + } + } + + @Test + public void 
shouldCancelBeforeGetIfMayCancelWhileRunning() { + TestConvertingFutureCallback testCallback = new TestConvertingFutureCallback(); + assertTrue(testCallback.cancel(true)); + assertThrows(CancellationException.class, testCallback::get); + } + + @Test + public void shouldBlockUntilSuccessfulCompletion() throws Exception { + AtomicReference testThreadException = new AtomicReference<>(); + TestConvertingFutureCallback testCallback = new TestConvertingFutureCallback(); + final Object expectedConversion = new Object(); + executor.submit(() -> { + try { + testCallback.waitForGet(); + testCallback.onCompletion(null, expectedConversion); + } catch (Exception e) { + testThreadException.compareAndSet(null, e); + } + }); + assertFalse(testCallback.isDone()); + assertEquals(expectedConversion, testCallback.get()); + assertEquals(1, testCallback.numberOfConversions()); + assertTrue(testCallback.isDone()); + if (testThreadException.get() != null) { + throw testThreadException.get(); + } + } + + @Test + public void shouldBlockUntilFailedCompletion() throws Exception { + AtomicReference testThreadException = new AtomicReference<>(); + TestConvertingFutureCallback testCallback = new TestConvertingFutureCallback(); + final Throwable expectedError = new Throwable(); + executor.submit(() -> { + try { + testCallback.waitForGet(); + testCallback.onCompletion(expectedError, null); + } catch (Exception e) { + testThreadException.compareAndSet(null, e); + } + }); + assertFalse(testCallback.isDone()); + try { + testCallback.get(); + fail("Expected ExecutionException"); + } catch (ExecutionException e) { + assertEquals(expectedError, e.getCause()); + } + assertEquals(0, testCallback.numberOfConversions()); + assertTrue(testCallback.isDone()); + if (testThreadException.get() != null) { + throw testThreadException.get(); + } + } + + @Test + public void shouldBlockUntilCancellation() { + AtomicReference testThreadException = new AtomicReference<>(); + TestConvertingFutureCallback testCallback = new TestConvertingFutureCallback(); + executor.submit(() -> { + try { + testCallback.waitForGet(); + testCallback.cancel(true); + } catch (Exception e) { + testThreadException.compareAndSet(null, e); + } + }); + assertFalse(testCallback.isDone()); + assertThrows(CancellationException.class, testCallback::get); + } + + @Test + public void shouldNotCancelIfMayNotCancelWhileRunning() throws Exception { + AtomicReference testThreadException = new AtomicReference<>(); + TestConvertingFutureCallback testCallback = new TestConvertingFutureCallback(); + final Object expectedConversion = new Object(); + executor.submit(() -> { + try { + testCallback.waitForCancel(); + testCallback.onCompletion(null, expectedConversion); + } catch (Exception e) { + testThreadException.compareAndSet(null, e); + } + }); + assertFalse(testCallback.isCancelled()); + assertFalse(testCallback.isDone()); + testCallback.cancel(false); + assertFalse(testCallback.isCancelled()); + assertTrue(testCallback.isDone()); + assertEquals(expectedConversion, testCallback.get()); + assertEquals(1, testCallback.numberOfConversions()); + if (testThreadException.get() != null) assertThrows(CancellationException.class, testThreadException::get); + } + + protected static class TestConvertingFutureCallback extends ConvertingFutureCallback { + private AtomicInteger numberOfConversions = new AtomicInteger(); + private CountDownLatch getInvoked = new CountDownLatch(1); + private CountDownLatch cancelInvoked = new CountDownLatch(1); + + public int numberOfConversions() { + return 
numberOfConversions.get(); + } + + public void waitForGet() throws InterruptedException { + getInvoked.await(); + } + + public void waitForCancel() throws InterruptedException { + cancelInvoked.await(); + } + + @Override + public Object convert(Object result) { + numberOfConversions.incrementAndGet(); + return result; + } + + @Override + public Object get() throws InterruptedException, ExecutionException { + getInvoked.countDown(); + return super.get(); + } + + @Override + public Object get( + long duration, + TimeUnit unit + ) throws InterruptedException, ExecutionException, TimeoutException { + getInvoked.countDown(); + return super.get(duration, unit); + } + + @Override + public boolean cancel(boolean mayInterruptIfRunning) { + cancelInvoked.countDown(); + return super.cancel(mayInterruptIfRunning); + } + } +} diff --git a/connect/runtime/src/test/java/org/apache/kafka/connect/util/KafkaBasedLogTest.java b/connect/runtime/src/test/java/org/apache/kafka/connect/util/KafkaBasedLogTest.java new file mode 100644 index 0000000..8fae57e --- /dev/null +++ b/connect/runtime/src/test/java/org/apache/kafka/connect/util/KafkaBasedLogTest.java @@ -0,0 +1,580 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
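The class under test here, ConvertingFutureCallback, acts as both a Callback and a Future: the raw result handed to onCompletion() is converted exactly once and then served to every get() caller. A minimal usage sketch, assuming the two-type-parameter shape ConvertingFutureCallback<U, T> with an abstract T convert(U result) that TestConvertingFutureCallback above implies:

import java.nio.charset.StandardCharsets;
import java.util.concurrent.ExecutionException;
import org.apache.kafka.connect.util.ConvertingFutureCallback;

public class ConvertingFutureCallbackSketch {
    public static void main(String[] args) throws InterruptedException, ExecutionException {
        // Converts the raw byte[] passed to onCompletion() into a String for callers of get().
        ConvertingFutureCallback<byte[], String> decoded = new ConvertingFutureCallback<byte[], String>() {
            @Override
            public String convert(byte[] result) {
                return result == null ? null : new String(result, StandardCharsets.UTF_8);
            }
        };
        decoded.onCompletion(null, "hello".getBytes(StandardCharsets.UTF_8)); // success: convert() runs once
        System.out.println(decoded.get()); // prints "hello"; a completion error would surface as ExecutionException
    }
}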
+ */ +package org.apache.kafka.connect.util; + +import org.apache.kafka.clients.CommonClientConfigs; +import org.apache.kafka.clients.consumer.ConsumerConfig; +import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.clients.consumer.MockConsumer; +import org.apache.kafka.clients.consumer.OffsetResetStrategy; +import org.apache.kafka.clients.producer.KafkaProducer; +import org.apache.kafka.clients.producer.ProducerConfig; +import org.apache.kafka.clients.producer.ProducerRecord; +import org.apache.kafka.clients.producer.RecordMetadata; +import org.apache.kafka.common.KafkaException; +import org.apache.kafka.common.Node; +import org.apache.kafka.common.PartitionInfo; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.errors.LeaderNotAvailableException; +import org.apache.kafka.common.errors.TimeoutException; +import org.apache.kafka.common.errors.UnsupportedVersionException; +import org.apache.kafka.common.errors.WakeupException; +import org.apache.kafka.common.header.internals.RecordHeaders; +import org.apache.kafka.common.protocol.Errors; +import org.apache.kafka.common.record.TimestampType; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.common.utils.MockTime; +import org.easymock.Capture; +import org.easymock.EasyMock; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.powermock.api.easymock.PowerMock; +import org.powermock.api.easymock.annotation.Mock; +import org.powermock.core.classloader.annotations.PowerMockIgnore; +import org.powermock.core.classloader.annotations.PrepareForTest; +import org.powermock.modules.junit4.PowerMockRunner; +import org.powermock.reflect.Whitebox; + +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Supplier; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; + +@RunWith(PowerMockRunner.class) +@PrepareForTest(KafkaBasedLog.class) +@PowerMockIgnore("javax.management.*") +public class KafkaBasedLogTest { + + private static final String TOPIC = "connect-log"; + private static final TopicPartition TP0 = new TopicPartition(TOPIC, 0); + private static final TopicPartition TP1 = new TopicPartition(TOPIC, 1); + private static final Map PRODUCER_PROPS = new HashMap<>(); + static { + PRODUCER_PROPS.put(CommonClientConfigs.BOOTSTRAP_SERVERS_CONFIG, "broker1:9092,broker2:9093"); + PRODUCER_PROPS.put(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG, "org.apache.kafka.common.serialization.StringSerializer"); + PRODUCER_PROPS.put(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG, "org.apache.kafka.common.serialization.StringSerializer"); + } + private static final Map CONSUMER_PROPS = new HashMap<>(); + static { + CONSUMER_PROPS.put(CommonClientConfigs.BOOTSTRAP_SERVERS_CONFIG, "broker1:9092,broker2:9093"); + CONSUMER_PROPS.put(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, 
"org.apache.kafka.common.serialization.StringDeserializer"); + CONSUMER_PROPS.put(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, "org.apache.kafka.common.serialization.StringDeserializer"); + } + + private static final Set CONSUMER_ASSIGNMENT = new HashSet<>(Arrays.asList(TP0, TP1)); + private static final Map FIRST_SET = new HashMap<>(); + static { + FIRST_SET.put("key", "value"); + FIRST_SET.put(null, null); + } + + private static final Node LEADER = new Node(1, "broker1", 9092); + private static final Node REPLICA = new Node(1, "broker2", 9093); + + private static final PartitionInfo TPINFO0 = new PartitionInfo(TOPIC, 0, LEADER, new Node[]{REPLICA}, new Node[]{REPLICA}); + private static final PartitionInfo TPINFO1 = new PartitionInfo(TOPIC, 1, LEADER, new Node[]{REPLICA}, new Node[]{REPLICA}); + + private static final String TP0_KEY = "TP0KEY"; + private static final String TP1_KEY = "TP1KEY"; + private static final String TP0_VALUE = "VAL0"; + private static final String TP1_VALUE = "VAL1"; + private static final String TP0_VALUE_NEW = "VAL0_NEW"; + private static final String TP1_VALUE_NEW = "VAL1_NEW"; + + private Time time = new MockTime(); + private KafkaBasedLog store; + + @Mock + private Runnable initializer; + @Mock + private KafkaProducer producer; + private MockConsumer consumer; + @Mock + private TopicAdmin admin; + + private Map>> consumedRecords = new HashMap<>(); + private Callback> consumedCallback = (error, record) -> { + TopicPartition partition = new TopicPartition(record.topic(), record.partition()); + List> records = consumedRecords.get(partition); + if (records == null) { + records = new ArrayList<>(); + consumedRecords.put(partition, records); + } + records.add(record); + }; + + @SuppressWarnings("unchecked") + @Before + public void setUp() { + store = PowerMock.createPartialMock(KafkaBasedLog.class, new String[]{"createConsumer", "createProducer"}, + TOPIC, PRODUCER_PROPS, CONSUMER_PROPS, consumedCallback, time, initializer); + consumer = new MockConsumer<>(OffsetResetStrategy.EARLIEST); + consumer.updatePartitions(TOPIC, Arrays.asList(TPINFO0, TPINFO1)); + Map beginningOffsets = new HashMap<>(); + beginningOffsets.put(TP0, 0L); + beginningOffsets.put(TP1, 0L); + consumer.updateBeginningOffsets(beginningOffsets); + } + + @Test + public void testStartStop() throws Exception { + expectStart(); + expectStop(); + + PowerMock.replayAll(); + + Map endOffsets = new HashMap<>(); + endOffsets.put(TP0, 0L); + endOffsets.put(TP1, 0L); + consumer.updateEndOffsets(endOffsets); + store.start(); + assertEquals(CONSUMER_ASSIGNMENT, consumer.assignment()); + + store.stop(); + + assertFalse(Whitebox.getInternalState(store, "thread").isAlive()); + assertTrue(consumer.closed()); + PowerMock.verifyAll(); + } + + @Test + public void testReloadOnStart() throws Exception { + expectStart(); + expectStop(); + + PowerMock.replayAll(); + + Map endOffsets = new HashMap<>(); + endOffsets.put(TP0, 1L); + endOffsets.put(TP1, 1L); + consumer.updateEndOffsets(endOffsets); + final CountDownLatch finishedLatch = new CountDownLatch(1); + consumer.schedulePollTask(() -> { + // Use first poll task to setup sequence of remaining responses to polls + // Should keep polling until it reaches current log end offset for all partitions. 
Should handle + // as many empty polls as needed + consumer.scheduleNopPollTask(); + consumer.scheduleNopPollTask(); + consumer.schedulePollTask(() -> + consumer.addRecord(new ConsumerRecord<>(TOPIC, 0, 0, 0L, TimestampType.CREATE_TIME, 0, 0, TP0_KEY, TP0_VALUE, + new RecordHeaders(), Optional.empty())) + ); + consumer.scheduleNopPollTask(); + consumer.scheduleNopPollTask(); + consumer.schedulePollTask(() -> + consumer.addRecord(new ConsumerRecord<>(TOPIC, 1, 0, 0L, TimestampType.CREATE_TIME, 0, 0, TP1_KEY, TP1_VALUE, + new RecordHeaders(), Optional.empty())) + ); + consumer.schedulePollTask(finishedLatch::countDown); + }); + store.start(); + assertTrue(finishedLatch.await(10000, TimeUnit.MILLISECONDS)); + + assertEquals(CONSUMER_ASSIGNMENT, consumer.assignment()); + assertEquals(2, consumedRecords.size()); + + assertEquals(TP0_VALUE, consumedRecords.get(TP0).get(0).value()); + assertEquals(TP1_VALUE, consumedRecords.get(TP1).get(0).value()); + + store.stop(); + + assertFalse(Whitebox.getInternalState(store, "thread").isAlive()); + assertTrue(consumer.closed()); + PowerMock.verifyAll(); + } + + @Test + public void testReloadOnStartWithNoNewRecordsPresent() throws Exception { + expectStart(); + expectStop(); + + PowerMock.replayAll(); + + Map endOffsets = new HashMap<>(); + endOffsets.put(TP0, 7L); + endOffsets.put(TP1, 7L); + consumer.updateEndOffsets(endOffsets); + // Better test with an advanced offset other than just 0L + consumer.updateBeginningOffsets(endOffsets); + + consumer.schedulePollTask(() -> { + // Throw an exception that will not be ignored or handled by Connect framework. In + // reality a misplaced call to poll blocks indefinitely and connect aborts due to + // time outs (for instance via ConnectRestException) + throw new WakeupException(); + }); + + store.start(); + + assertEquals(CONSUMER_ASSIGNMENT, consumer.assignment()); + assertEquals(7L, consumer.position(TP0)); + assertEquals(7L, consumer.position(TP1)); + + store.stop(); + + assertFalse(Whitebox.getInternalState(store, "thread").isAlive()); + assertTrue(consumer.closed()); + PowerMock.verifyAll(); + } + + @Test + public void testSendAndReadToEnd() throws Exception { + expectStart(); + TestFuture tp0Future = new TestFuture<>(); + ProducerRecord tp0Record = new ProducerRecord<>(TOPIC, TP0_KEY, TP0_VALUE); + Capture callback0 = EasyMock.newCapture(); + EasyMock.expect(producer.send(EasyMock.eq(tp0Record), EasyMock.capture(callback0))).andReturn(tp0Future); + TestFuture tp1Future = new TestFuture<>(); + ProducerRecord tp1Record = new ProducerRecord<>(TOPIC, TP1_KEY, TP1_VALUE); + Capture callback1 = EasyMock.newCapture(); + EasyMock.expect(producer.send(EasyMock.eq(tp1Record), EasyMock.capture(callback1))).andReturn(tp1Future); + + // Producer flushes when read to log end is called + producer.flush(); + PowerMock.expectLastCall(); + + expectStop(); + + PowerMock.replayAll(); + + Map endOffsets = new HashMap<>(); + endOffsets.put(TP0, 0L); + endOffsets.put(TP1, 0L); + consumer.updateEndOffsets(endOffsets); + store.start(); + assertEquals(CONSUMER_ASSIGNMENT, consumer.assignment()); + assertEquals(0L, consumer.position(TP0)); + assertEquals(0L, consumer.position(TP1)); + + // Set some keys + final AtomicInteger invoked = new AtomicInteger(0); + org.apache.kafka.clients.producer.Callback producerCallback = (metadata, exception) -> invoked.incrementAndGet(); + store.send(TP0_KEY, TP0_VALUE, producerCallback); + store.send(TP1_KEY, TP1_VALUE, producerCallback); + assertEquals(0, invoked.get()); + 
tp1Future.resolve((RecordMetadata) null); // Output not used, so safe to not return a real value for testing + callback1.getValue().onCompletion(null, null); + assertEquals(1, invoked.get()); + tp0Future.resolve((RecordMetadata) null); + callback0.getValue().onCompletion(null, null); + assertEquals(2, invoked.get()); + + // Now we should have to wait for the records to be read back when we call readToEnd() + final AtomicBoolean getInvoked = new AtomicBoolean(false); + final FutureCallback readEndFutureCallback = new FutureCallback<>((error, result) -> getInvoked.set(true)); + consumer.schedulePollTask(() -> { + // Once we're synchronized in a poll, start the read to end and schedule the exact set of poll events + // that should follow. This readToEnd call will immediately wakeup this consumer.poll() call without + // returning any data. + Map newEndOffsets = new HashMap<>(); + newEndOffsets.put(TP0, 2L); + newEndOffsets.put(TP1, 2L); + consumer.updateEndOffsets(newEndOffsets); + store.readToEnd(readEndFutureCallback); + + // Should keep polling until it reaches current log end offset for all partitions + consumer.scheduleNopPollTask(); + consumer.scheduleNopPollTask(); + consumer.scheduleNopPollTask(); + consumer.schedulePollTask(() -> { + consumer.addRecord(new ConsumerRecord<>(TOPIC, 0, 0, 0L, TimestampType.CREATE_TIME, 0, 0, TP0_KEY, TP0_VALUE, + new RecordHeaders(), Optional.empty())); + consumer.addRecord(new ConsumerRecord<>(TOPIC, 0, 1, 0L, TimestampType.CREATE_TIME, 0, 0, TP0_KEY, TP0_VALUE_NEW, + new RecordHeaders(), Optional.empty())); + consumer.addRecord(new ConsumerRecord<>(TOPIC, 1, 0, 0L, TimestampType.CREATE_TIME, 0, 0, TP1_KEY, TP1_VALUE, + new RecordHeaders(), Optional.empty())); + }); + + consumer.schedulePollTask(() -> + consumer.addRecord(new ConsumerRecord<>(TOPIC, 1, 1, 0L, TimestampType.CREATE_TIME, 0, 0, TP1_KEY, TP1_VALUE_NEW, + new RecordHeaders(), Optional.empty()))); + + // Already have FutureCallback that should be invoked/awaited, so no need for follow up finishedLatch + }); + readEndFutureCallback.get(10000, TimeUnit.MILLISECONDS); + assertTrue(getInvoked.get()); + assertEquals(2, consumedRecords.size()); + + assertEquals(2, consumedRecords.get(TP0).size()); + assertEquals(TP0_VALUE, consumedRecords.get(TP0).get(0).value()); + assertEquals(TP0_VALUE_NEW, consumedRecords.get(TP0).get(1).value()); + + assertEquals(2, consumedRecords.get(TP1).size()); + assertEquals(TP1_VALUE, consumedRecords.get(TP1).get(0).value()); + assertEquals(TP1_VALUE_NEW, consumedRecords.get(TP1).get(1).value()); + + // Cleanup + store.stop(); + + assertFalse(Whitebox.getInternalState(store, "thread").isAlive()); + assertTrue(consumer.closed()); + PowerMock.verifyAll(); + } + + @Test + public void testPollConsumerError() throws Exception { + expectStart(); + expectStop(); + + PowerMock.replayAll(); + + final CountDownLatch finishedLatch = new CountDownLatch(1); + Map endOffsets = new HashMap<>(); + endOffsets.put(TP0, 1L); + endOffsets.put(TP1, 1L); + consumer.updateEndOffsets(endOffsets); + consumer.schedulePollTask(() -> { + // Trigger exception + consumer.schedulePollTask(() -> + consumer.setPollException(Errors.COORDINATOR_NOT_AVAILABLE.exception())); + + // Should keep polling until it reaches current log end offset for all partitions + consumer.scheduleNopPollTask(); + consumer.scheduleNopPollTask(); + consumer.schedulePollTask(() -> { + consumer.addRecord(new ConsumerRecord<>(TOPIC, 0, 0, 0L, TimestampType.CREATE_TIME, 0, 0, TP0_KEY, TP0_VALUE_NEW, + new RecordHeaders(), 
Optional.empty())); + consumer.addRecord(new ConsumerRecord<>(TOPIC, 1, 0, 0L, TimestampType.CREATE_TIME, 0, 0, TP0_KEY, TP0_VALUE_NEW, + new RecordHeaders(), Optional.empty())); + }); + + consumer.schedulePollTask(finishedLatch::countDown); + }); + store.start(); + assertTrue(finishedLatch.await(10000, TimeUnit.MILLISECONDS)); + assertEquals(CONSUMER_ASSIGNMENT, consumer.assignment()); + assertEquals(1L, consumer.position(TP0)); + + store.stop(); + + assertFalse(Whitebox.getInternalState(store, "thread").isAlive()); + assertTrue(consumer.closed()); + PowerMock.verifyAll(); + } + + @Test + public void testGetOffsetsConsumerErrorOnReadToEnd() throws Exception { + expectStart(); + + // Producer flushes when read to log end is called + producer.flush(); + PowerMock.expectLastCall(); + + expectStop(); + + PowerMock.replayAll(); + final CountDownLatch finishedLatch = new CountDownLatch(1); + Map endOffsets = new HashMap<>(); + endOffsets.put(TP0, 0L); + endOffsets.put(TP1, 0L); + consumer.updateEndOffsets(endOffsets); + store.start(); + final AtomicBoolean getInvoked = new AtomicBoolean(false); + final FutureCallback readEndFutureCallback = new FutureCallback<>((error, result) -> getInvoked.set(true)); + consumer.schedulePollTask(() -> { + // Once we're synchronized in a poll, start the read to end and schedule the exact set of poll events + // that should follow. This readToEnd call will immediately wakeup this consumer.poll() call without + // returning any data. + Map newEndOffsets = new HashMap<>(); + newEndOffsets.put(TP0, 1L); + newEndOffsets.put(TP1, 1L); + consumer.updateEndOffsets(newEndOffsets); + // Set exception to occur when getting offsets to read log to end. It'll be caught in the work thread, + // which will retry and eventually get the correct offsets and read log to end. 
+ consumer.setOffsetsException(new TimeoutException("Failed to get offsets by times")); + store.readToEnd(readEndFutureCallback); + + // Should keep polling until it reaches current log end offset for all partitions + consumer.scheduleNopPollTask(); + consumer.scheduleNopPollTask(); + consumer.schedulePollTask(() -> { + consumer.addRecord(new ConsumerRecord<>(TOPIC, 0, 0, 0L, TimestampType.CREATE_TIME, 0, 0, TP0_KEY, TP0_VALUE, + new RecordHeaders(), Optional.empty())); + consumer.addRecord(new ConsumerRecord<>(TOPIC, 1, 0, 0L, TimestampType.CREATE_TIME, 0, 0, TP0_KEY, TP0_VALUE_NEW, + new RecordHeaders(), Optional.empty())); + }); + + consumer.schedulePollTask(finishedLatch::countDown); + }); + readEndFutureCallback.get(10000, TimeUnit.MILLISECONDS); + assertTrue(getInvoked.get()); + assertTrue(finishedLatch.await(10000, TimeUnit.MILLISECONDS)); + assertEquals(CONSUMER_ASSIGNMENT, consumer.assignment()); + assertEquals(1L, consumer.position(TP0)); + + store.stop(); + + assertFalse(Whitebox.getInternalState(store, "thread").isAlive()); + assertTrue(consumer.closed()); + PowerMock.verifyAll(); + } + + @Test + public void testProducerError() throws Exception { + expectStart(); + TestFuture tp0Future = new TestFuture<>(); + ProducerRecord tp0Record = new ProducerRecord<>(TOPIC, TP0_KEY, TP0_VALUE); + Capture callback0 = EasyMock.newCapture(); + EasyMock.expect(producer.send(EasyMock.eq(tp0Record), EasyMock.capture(callback0))).andReturn(tp0Future); + + expectStop(); + + PowerMock.replayAll(); + + Map endOffsets = new HashMap<>(); + endOffsets.put(TP0, 0L); + endOffsets.put(TP1, 0L); + consumer.updateEndOffsets(endOffsets); + store.start(); + assertEquals(CONSUMER_ASSIGNMENT, consumer.assignment()); + assertEquals(0L, consumer.position(TP0)); + assertEquals(0L, consumer.position(TP1)); + + final AtomicReference setException = new AtomicReference<>(); + store.send(TP0_KEY, TP0_VALUE, (metadata, exception) -> { + assertNull(setException.get()); // Should only be invoked once + setException.set(exception); + }); + KafkaException exc = new LeaderNotAvailableException("Error"); + tp0Future.resolve(exc); + callback0.getValue().onCompletion(null, exc); + assertNotNull(setException.get()); + + store.stop(); + + assertFalse(Whitebox.getInternalState(store, "thread").isAlive()); + assertTrue(consumer.closed()); + PowerMock.verifyAll(); + } + + @Test + public void testReadEndOffsetsUsingAdmin() throws Exception { + // Create a log that uses the admin supplier + setupWithAdmin(); + expectProducerAndConsumerCreate(); + + Set tps = new HashSet<>(Arrays.asList(TP0, TP1)); + Map endOffsets = new HashMap<>(); + endOffsets.put(TP0, 0L); + endOffsets.put(TP1, 0L); + admin.endOffsets(EasyMock.eq(tps)); + PowerMock.expectLastCall().andReturn(endOffsets).times(2); + + PowerMock.replayAll(); + + store.start(); + assertEquals(endOffsets, store.readEndOffsets(tps)); + } + + @Test + public void testReadEndOffsetsUsingAdminThatFailsWithUnsupported() throws Exception { + // Create a log that uses the admin supplier + setupWithAdmin(); + expectProducerAndConsumerCreate(); + + Set tps = new HashSet<>(Arrays.asList(TP0, TP1)); + // Getting end offsets using the admin client should fail with unsupported version + admin.endOffsets(EasyMock.eq(tps)); + PowerMock.expectLastCall().andThrow(new UnsupportedVersionException("too old")); + + // Falls back to the consumer + Map endOffsets = new HashMap<>(); + endOffsets.put(TP0, 0L); + endOffsets.put(TP1, 0L); + consumer.updateEndOffsets(endOffsets); + + PowerMock.replayAll(); + + 
store.start(); + assertEquals(endOffsets, store.readEndOffsets(tps)); + } + + @Test + public void testReadEndOffsetsUsingAdminThatFailsWithRetriable() throws Exception { + // Create a log that uses the admin supplier + setupWithAdmin(); + expectProducerAndConsumerCreate(); + + Set tps = new HashSet<>(Arrays.asList(TP0, TP1)); + Map endOffsets = new HashMap<>(); + endOffsets.put(TP0, 0L); + endOffsets.put(TP1, 0L); + // Getting end offsets upon startup should work fine + admin.endOffsets(EasyMock.eq(tps)); + PowerMock.expectLastCall().andReturn(endOffsets).times(1); + // Getting end offsets using the admin client should fail with leader not available + admin.endOffsets(EasyMock.eq(tps)); + PowerMock.expectLastCall().andThrow(new LeaderNotAvailableException("retry")); + + PowerMock.replayAll(); + + store.start(); + assertThrows(LeaderNotAvailableException.class, () -> store.readEndOffsets(tps)); + } + + @SuppressWarnings("unchecked") + private void setupWithAdmin() { + Supplier adminSupplier = () -> admin; + java.util.function.Consumer initializer = admin -> { }; + store = PowerMock.createPartialMock(KafkaBasedLog.class, new String[]{"createConsumer", "createProducer"}, + TOPIC, PRODUCER_PROPS, CONSUMER_PROPS, adminSupplier, consumedCallback, time, initializer); + } + + private void expectProducerAndConsumerCreate() throws Exception { + PowerMock.expectPrivate(store, "createProducer") + .andReturn(producer); + PowerMock.expectPrivate(store, "createConsumer") + .andReturn(consumer); + } + + private void expectStart() throws Exception { + initializer.run(); + EasyMock.expectLastCall().times(1); + + expectProducerAndConsumerCreate(); + } + + private void expectStop() { + producer.close(); + PowerMock.expectLastCall(); + // MockConsumer close is checked after test. + } + + private static ByteBuffer buffer(String v) { + return ByteBuffer.wrap(v.getBytes()); + } + +} diff --git a/connect/runtime/src/test/java/org/apache/kafka/connect/util/LoggingContextTest.java b/connect/runtime/src/test/java/org/apache/kafka/connect/util/LoggingContextTest.java new file mode 100644 index 0000000..a4e6088 --- /dev/null +++ b/connect/runtime/src/test/java/org/apache/kafka/connect/util/LoggingContextTest.java @@ -0,0 +1,208 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
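Outside of the mocks, the lifecycle these tests walk through looks roughly like the sketch below; the constructor mirrors the arguments passed in setUp() above, and the generic signature is an assumption since this diff does not show KafkaBasedLog itself:

import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.TimeUnit;
import org.apache.kafka.common.utils.Time;
import org.apache.kafka.connect.util.FutureCallback;
import org.apache.kafka.connect.util.KafkaBasedLog;

public class KafkaBasedLogSketch {
    public static void main(String[] args) throws Exception {
        Map<String, Object> producerProps = new HashMap<>();
        producerProps.put("bootstrap.servers", "localhost:9092");
        producerProps.put("key.serializer", "org.apache.kafka.common.serialization.StringSerializer");
        producerProps.put("value.serializer", "org.apache.kafka.common.serialization.StringSerializer");
        Map<String, Object> consumerProps = new HashMap<>();
        consumerProps.put("bootstrap.servers", "localhost:9092");
        consumerProps.put("key.deserializer", "org.apache.kafka.common.serialization.StringDeserializer");
        consumerProps.put("value.deserializer", "org.apache.kafka.common.serialization.StringDeserializer");

        KafkaBasedLog<String, String> log = new KafkaBasedLog<>(
                "connect-log", producerProps, consumerProps,
                (error, record) -> { /* consumed callback: fold each record into in-memory state */ },
                Time.SYSTEM,
                () -> { /* initializer: e.g. make sure the topic exists */ });

        log.start();                                     // replays the existing topic through the consumed callback
        log.send("key", "value", (metadata, e) -> { });  // asynchronous produce, as in testSendAndReadToEnd
        FutureCallback<Void> readToEnd = new FutureCallback<>((error, ignored) -> { });
        log.readToEnd(readToEnd);                        // flushes the producer, then consumes up to the end offsets
        readToEnd.get(30, TimeUnit.SECONDS);
        log.stop();
    }
}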
+ */ +package org.apache.kafka.connect.util; + +import org.apache.kafka.connect.util.LoggingContext.Scope; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.slf4j.MDC; + +import java.util.HashMap; +import java.util.Map; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; + +public class LoggingContextTest { + + private static final Logger log = LoggerFactory.getLogger(LoggingContextTest.class); + + private static final String CONNECTOR_NAME = "MyConnector"; + private static final ConnectorTaskId TASK_ID1 = new ConnectorTaskId(CONNECTOR_NAME, 1); + private static final String EXTRA_KEY1 = "extra.key.1"; + private static final String EXTRA_VALUE1 = "value1"; + private static final String EXTRA_KEY2 = "extra.key.2"; + private static final String EXTRA_VALUE2 = "value2"; + private static final String EXTRA_KEY3 = "extra.key.3"; + private static final String EXTRA_VALUE3 = "value3"; + + private Map mdc; + + @Before + public void setup() { + mdc = new HashMap<>(); + Map existing = MDC.getCopyOfContextMap(); + if (existing != null) { + mdc.putAll(existing); + } + MDC.put(EXTRA_KEY1, EXTRA_VALUE1); + MDC.put(EXTRA_KEY2, EXTRA_VALUE2); + } + + @After + public void tearDown() { + MDC.clear(); + MDC.setContextMap(mdc); + } + + @Test + public void shouldNotAllowNullConnectorNameForConnectorContext() { + assertThrows(NullPointerException.class, () -> LoggingContext.forConnector(null)); + } + + @Test + public void shouldNotAllowNullTaskIdForTaskContext() { + assertThrows(NullPointerException.class, () -> LoggingContext.forTask(null)); + } + + @Test + public void shouldNotAllowNullTaskIdForOffsetContext() { + assertThrows(NullPointerException.class, () -> LoggingContext.forOffsets(null)); + } + + @Test + public void shouldCreateAndCloseLoggingContextEvenWithNullContextMap() { + MDC.clear(); + assertMdc(null, null, null); + try (LoggingContext loggingContext = LoggingContext.forConnector(CONNECTOR_NAME)) { + assertMdc(CONNECTOR_NAME, null, Scope.WORKER); + log.info("Starting Connector"); + } + assertMdc(null, null, null); + } + + @Test + public void shouldCreateConnectorLoggingContext() { + assertMdcExtrasUntouched(); + assertMdc(null, null, null); + + try (LoggingContext loggingContext = LoggingContext.forConnector(CONNECTOR_NAME)) { + assertMdc(CONNECTOR_NAME, null, Scope.WORKER); + log.info("Starting Connector"); + } + + assertMdcExtrasUntouched(); + assertMdc(null, null, null); + } + + @Test + public void shouldCreateTaskLoggingContext() { + assertMdcExtrasUntouched(); + try (LoggingContext loggingContext = LoggingContext.forTask(TASK_ID1)) { + assertMdc(TASK_ID1.connector(), TASK_ID1.task(), Scope.TASK); + log.info("Running task"); + } + + assertMdcExtrasUntouched(); + assertMdc(null, null, null); + } + + @Test + public void shouldCreateOffsetsLoggingContext() { + assertMdcExtrasUntouched(); + try (LoggingContext loggingContext = LoggingContext.forOffsets(TASK_ID1)) { + assertMdc(TASK_ID1.connector(), TASK_ID1.task(), Scope.OFFSETS); + log.info("Running task"); + } + + assertMdcExtrasUntouched(); + assertMdc(null, null, null); + } + + @Test + public void shouldAllowNestedLoggingContexts() { + assertMdcExtrasUntouched(); + assertMdc(null, null, null); + try (LoggingContext loggingContext1 = LoggingContext.forConnector(CONNECTOR_NAME)) { + assertMdc(CONNECTOR_NAME, null, 
Scope.WORKER); + log.info("Starting Connector"); + // Set the extra MDC parameter, as if the connector were + MDC.put(EXTRA_KEY3, EXTRA_VALUE3); + assertConnectorMdcSet(); + + try (LoggingContext loggingContext2 = LoggingContext.forTask(TASK_ID1)) { + assertMdc(TASK_ID1.connector(), TASK_ID1.task(), Scope.TASK); + log.info("Starting task"); + // The extra connector-specific MDC parameter should still be set + assertConnectorMdcSet(); + + try (LoggingContext loggingContext3 = LoggingContext.forOffsets(TASK_ID1)) { + assertMdc(TASK_ID1.connector(), TASK_ID1.task(), Scope.OFFSETS); + assertConnectorMdcSet(); + log.info("Offsets for task"); + } + + assertMdc(TASK_ID1.connector(), TASK_ID1.task(), Scope.TASK); + log.info("Stopping task"); + // The extra connector-specific MDC parameter should still be set + assertConnectorMdcSet(); + } + + assertMdc(CONNECTOR_NAME, null, Scope.WORKER); + log.info("Stopping Connector"); + // The extra connector-specific MDC parameter should still be set + assertConnectorMdcSet(); + } + assertMdcExtrasUntouched(); + assertMdc(null, null, null); + + // The extra connector-specific MDC parameter should still be set + assertConnectorMdcSet(); + + LoggingContext.clear(); + assertConnectorMdcUnset(); + } + + protected void assertMdc(String connectorName, Integer taskId, Scope scope) { + String context = MDC.get(LoggingContext.CONNECTOR_CONTEXT); + if (context != null) { + assertEquals( + "Context should begin with connector name when the connector name is non-null", + connectorName != null, + context.startsWith("[" + connectorName) + ); + if (scope != null) { + assertTrue("Context should contain the scope", context.contains(scope.toString())); + } + if (taskId != null) { + assertTrue("Context should contain the taskId", context.contains(taskId.toString())); + } + } else { + assertNull("No logging context found, expected null connector name", connectorName); + assertNull("No logging context found, expected null task ID", taskId); + assertNull("No logging context found, expected null scope", scope); + } + } + + protected void assertMdcExtrasUntouched() { + assertEquals(EXTRA_VALUE1, MDC.get(EXTRA_KEY1)); + assertEquals(EXTRA_VALUE2, MDC.get(EXTRA_KEY2)); + } + + protected void assertConnectorMdcSet() { + assertEquals(EXTRA_VALUE3, MDC.get(EXTRA_KEY3)); + } + + protected void assertConnectorMdcUnset() { + assertNull(MDC.get(EXTRA_KEY3)); + } +} \ No newline at end of file diff --git a/connect/runtime/src/test/java/org/apache/kafka/connect/util/ParameterizedTest.java b/connect/runtime/src/test/java/org/apache/kafka/connect/util/ParameterizedTest.java new file mode 100644 index 0000000..5f16891 --- /dev/null +++ b/connect/runtime/src/test/java/org/apache/kafka/connect/util/ParameterizedTest.java @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
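These assertions guard the connector-scoped MDC entry that the worker wraps around connector and task code. A hedged sketch of how that is used in practice; the MDC key name and the log4j pattern mentioned in the comment are assumptions, not shown in this diff:

import org.apache.kafka.connect.util.ConnectorTaskId;
import org.apache.kafka.connect.util.LoggingContext;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class LoggingContextSketch {
    private static final Logger log = LoggerFactory.getLogger(LoggingContextSketch.class);

    public static void main(String[] args) {
        ConnectorTaskId taskId = new ConnectorTaskId("MyConnector", 0);
        // With a layout pattern that includes the connector context MDC value (for example
        // "%X{connector.context}%m%n"), the line below gets a prefix such as "[MyConnector|task-0] ".
        try (LoggingContext context = LoggingContext.forTask(taskId)) {
            log.info("polling records");
        }
        log.info("no connector prefix here"); // the MDC entry is removed/restored when the context closes
    }
}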
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.util; + +import java.lang.annotation.Annotation; +import org.junit.runner.Description; +import org.junit.runner.manipulation.Filter; +import org.junit.runner.manipulation.NoTestsRemainException; +import org.junit.runners.Parameterized; + +/** + * Running a single parameterized test causes issue as explained in + * http://youtrack.jetbrains.com/issue/IDEA-65966 and + * https://stackoverflow.com/questions/12798079/initializationerror-with-eclipse-and-junit4-when-executing-a-single-test/18438718#18438718 + * + * As a workaround, the original filter needs to be wrapped and then pass it a deparameterized + * description which removes the parameter part (See deparametrizeName) + */ +public class ParameterizedTest extends Parameterized { + + public ParameterizedTest(Class klass) throws Throwable { + super(klass); + } + + @Override + public void filter(Filter filter) throws NoTestsRemainException { + super.filter(new FilterDecorator(filter)); + } + + private static String deparametrizeName(String name) { + //Each parameter is named as [0], [1] etc + if (name.startsWith("[")) { + return name; + } + + //Convert methodName[index](className) to methodName(className) + int indexOfOpenBracket = name.indexOf('['); + int indexOfCloseBracket = name.indexOf(']') + 1; + return name.substring(0, indexOfOpenBracket).concat(name.substring(indexOfCloseBracket)); + } + + private static Description wrap(Description description) { + String fixedName = deparametrizeName(description.getDisplayName()); + Description clonedDescription = Description.createSuiteDescription( + fixedName, + description.getAnnotations().toArray(new Annotation[0]) + ); + description.getChildren().forEach(child -> clonedDescription.addChild(wrap(child))); + return clonedDescription; + } + + private static class FilterDecorator extends Filter { + private final Filter delegate; + + private FilterDecorator(Filter delegate) { + this.delegate = delegate; + } + + @Override + public boolean shouldRun(Description description) { + return delegate.shouldRun(wrap(description)); + } + + @Override + public String describe() { + return delegate.describe(); + } + } +} \ No newline at end of file diff --git a/connect/runtime/src/test/java/org/apache/kafka/connect/util/SharedTopicAdminTest.java b/connect/runtime/src/test/java/org/apache/kafka/connect/util/SharedTopicAdminTest.java new file mode 100644 index 0000000..f5ac6a7 --- /dev/null +++ b/connect/runtime/src/test/java/org/apache/kafka/connect/util/SharedTopicAdminTest.java @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
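For context, a test class that would run under this wrapper looks like the sketch below (the class and its parameters are illustrative, not part of this change); the runner only matters when an IDE re-runs a single method and filters by the displayed, parameterized name:

import java.util.Arrays;
import java.util.Collection;
import org.apache.kafka.connect.util.ParameterizedTest;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized.Parameter;
import org.junit.runners.Parameterized.Parameters;
import static org.junit.Assert.assertNotNull;

@RunWith(ParameterizedTest.class)          // the wrapper above, instead of Parameterized directly
public class ExampleCompressionTest {
    @Parameters(name = "compression={0}")  // display names become e.g. testRoundTrip[compression=gzip]
    public static Collection<Object[]> parameters() {
        return Arrays.asList(new Object[][]{{"none"}, {"gzip"}, {"snappy"}});
    }

    @Parameter
    public String compression;

    @Test
    public void testRoundTrip() {
        // Re-running only testRoundTrip works because FilterDecorator strips the
        // "[compression=...]" suffix before consulting the IDE-supplied filter.
        assertNotNull(compression);
    }
}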
+ */ +package org.apache.kafka.connect.util; + +import java.time.Duration; +import java.util.Collections; +import java.util.Map; +import java.util.function.Function; + +import org.apache.kafka.connect.errors.ConnectException; +import org.junit.Rule; +import org.mockito.Mock; +import org.junit.Before; +import org.junit.Test; +import org.mockito.junit.MockitoJUnit; +import org.mockito.junit.MockitoRule; + +import static org.apache.kafka.connect.util.SharedTopicAdmin.DEFAULT_CLOSE_DURATION; +import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertThrows; +import static org.mockito.ArgumentMatchers.anyMap; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +public class SharedTopicAdminTest { + + private static final Map EMPTY_CONFIG = Collections.emptyMap(); + + @Rule + public MockitoRule rule = MockitoJUnit.rule(); + + @Mock private TopicAdmin mockTopicAdmin; + @Mock private Function, TopicAdmin> factory; + private SharedTopicAdmin sharedAdmin; + + @Before + public void beforeEach() { + when(factory.apply(anyMap())).thenReturn(mockTopicAdmin); + sharedAdmin = new SharedTopicAdmin(EMPTY_CONFIG, factory::apply); + } + + @Test + public void shouldCloseWithoutBeingUsed() { + // When closed before being used + sharedAdmin.close(); + // Then should not create or close admin + verifyTopicAdminCreatesAndCloses(0); + } + + @Test + public void shouldCloseAfterTopicAdminUsed() { + // When used and then closed + assertSame(mockTopicAdmin, sharedAdmin.topicAdmin()); + sharedAdmin.close(); + // Then should have created and closed just one admin + verifyTopicAdminCreatesAndCloses(1); + } + + @Test + public void shouldCloseAfterTopicAdminUsedMultipleTimes() { + // When used many times and then closed + for (int i = 0; i != 10; ++i) { + assertSame(mockTopicAdmin, sharedAdmin.topicAdmin()); + } + sharedAdmin.close(); + // Then should have created and closed just one admin + verifyTopicAdminCreatesAndCloses(1); + } + + @Test + public void shouldCloseWithDurationAfterTopicAdminUsed() { + // When used and then closed with a custom timeout + Duration timeout = Duration.ofSeconds(1); + assertSame(mockTopicAdmin, sharedAdmin.topicAdmin()); + sharedAdmin.close(timeout); + // Then should have created and closed just one admin using the supplied timeout + verifyTopicAdminCreatesAndCloses(1, timeout); + } + + @Test + public void shouldFailToGetTopicAdminAfterClose() { + // When closed + sharedAdmin.close(); + // Then using the admin should fail + assertThrows(ConnectException.class, () -> sharedAdmin.topicAdmin()); + } + + private void verifyTopicAdminCreatesAndCloses(int count) { + verifyTopicAdminCreatesAndCloses(count, DEFAULT_CLOSE_DURATION); + } + + private void verifyTopicAdminCreatesAndCloses(int count, Duration expectedDuration) { + verify(factory, times(count)).apply(anyMap()); + verify(mockTopicAdmin, times(count)).close(eq(expectedDuration)); + } +} \ No newline at end of file diff --git a/connect/runtime/src/test/java/org/apache/kafka/connect/util/ShutdownableThreadTest.java b/connect/runtime/src/test/java/org/apache/kafka/connect/util/ShutdownableThreadTest.java new file mode 100644 index 0000000..a72937d --- /dev/null +++ b/connect/runtime/src/test/java/org/apache/kafka/connect/util/ShutdownableThreadTest.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.util; + +import org.junit.Test; + +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; + +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +public class ShutdownableThreadTest { + + @Test + public void testGracefulShutdown() throws InterruptedException { + ShutdownableThread thread = new ShutdownableThread("graceful") { + @Override + public void execute() { + while (getRunning()) { + try { + Thread.sleep(1); + } catch (InterruptedException e) { + // Ignore + } + } + } + }; + thread.start(); + Thread.sleep(10); + assertTrue(thread.gracefulShutdown(1000, TimeUnit.MILLISECONDS)); + } + + @Test + public void testForcibleShutdown() throws InterruptedException { + final CountDownLatch startedLatch = new CountDownLatch(1); + ShutdownableThread thread = new ShutdownableThread("forcible") { + @Override + public void execute() { + try { + startedLatch.countDown(); + Thread.sleep(100000); + } catch (InterruptedException e) { + // Ignore + } + } + }; + thread.start(); + startedLatch.await(); + thread.forceShutdown(); + // Not all threads can be forcibly stopped since interrupt() doesn't work on threads in + // certain conditions, but in this case we know the thread is interruptible so we should be + // able join() it + thread.join(1000); + assertFalse(thread.isAlive()); + } +} diff --git a/connect/runtime/src/test/java/org/apache/kafka/connect/util/TableTest.java b/connect/runtime/src/test/java/org/apache/kafka/connect/util/TableTest.java new file mode 100644 index 0000000..6d41a9d --- /dev/null +++ b/connect/runtime/src/test/java/org/apache/kafka/connect/util/TableTest.java @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.connect.util; + +import org.junit.Test; + +import java.util.Map; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; + +public class TableTest { + + @Test + public void basicOperations() { + Table<String, Integer, String> table = new Table<>(); + table.put("foo", 5, "bar"); + table.put("foo", 6, "baz"); + assertEquals("bar", table.get("foo", 5)); + assertEquals("baz", table.get("foo", 6)); + + Map<Integer, String> row = table.row("foo"); + assertEquals("bar", row.get(5)); + assertEquals("baz", row.get(6)); + + assertEquals("bar", table.remove("foo", 5)); + assertNull(table.get("foo", 5)); + assertEquals("baz", table.remove("foo", 6)); + assertNull(table.get("foo", 6)); + assertTrue(table.row("foo").isEmpty()); + } + +} diff --git a/connect/runtime/src/test/java/org/apache/kafka/connect/util/TestBackgroundThreadExceptionHandler.java b/connect/runtime/src/test/java/org/apache/kafka/connect/util/TestBackgroundThreadExceptionHandler.java new file mode 100644 index 0000000..8726d5c --- /dev/null +++ b/connect/runtime/src/test/java/org/apache/kafka/connect/util/TestBackgroundThreadExceptionHandler.java @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.util; + +/** + * An UncaughtExceptionHandler that can be registered with one or more threads which tracks the + * first exception so the main thread can check for uncaught exceptions. + */ +public class TestBackgroundThreadExceptionHandler implements Thread.UncaughtExceptionHandler { + private Throwable firstException = null; + + @Override + public void uncaughtException(Thread t, Throwable e) { + if (this.firstException == null) + this.firstException = e; + } + + public void verifyNoExceptions() { + if (this.firstException != null) + throw new AssertionError(this.firstException); + } +} diff --git a/connect/runtime/src/test/java/org/apache/kafka/connect/util/TestFuture.java b/connect/runtime/src/test/java/org/apache/kafka/connect/util/TestFuture.java new file mode 100644 index 0000000..0883040 --- /dev/null +++ b/connect/runtime/src/test/java/org/apache/kafka/connect/util/TestFuture.java @@ -0,0 +1,160 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License.
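The Javadoc of TestBackgroundThreadExceptionHandler above describes the intended flow; a small sketch of it (the failing worker thread is illustrative):

import org.apache.kafka.connect.util.TestBackgroundThreadExceptionHandler;

public class BackgroundExceptionCheckSketch {
    public static void main(String[] args) throws InterruptedException {
        TestBackgroundThreadExceptionHandler handler = new TestBackgroundThreadExceptionHandler();
        Thread worker = new Thread(() -> {
            throw new IllegalStateException("boom"); // escapes run() and reaches the handler
        });
        worker.setUncaughtExceptionHandler(handler); // register before starting the thread
        worker.start();
        worker.join();
        handler.verifyNoExceptions(); // rethrows the captured failure as an AssertionError
    }
}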
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.util; + +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + +public class TestFuture implements Future { + private volatile boolean resolved; + private T result; + private Throwable exception; + private CountDownLatch getCalledLatch; + + private volatile boolean resolveOnGet; + private T resolveOnGetResult; + private Throwable resolveOnGetException; + + public TestFuture() { + resolved = false; + getCalledLatch = new CountDownLatch(1); + + resolveOnGet = false; + resolveOnGetResult = null; + resolveOnGetException = null; + } + + public void resolve(T val) { + this.result = val; + resolved = true; + synchronized (this) { + this.notifyAll(); + } + } + + public void resolve(Throwable t) { + exception = t; + resolved = true; + synchronized (this) { + this.notifyAll(); + } + } + + @Override + public boolean cancel(boolean mayInterruptIfRunning) { + return false; + } + + @Override + public boolean isCancelled() { + return false; + } + + @Override + public boolean isDone() { + return resolved; + } + + @Override + public T get() throws InterruptedException, ExecutionException { + getCalledLatch.countDown(); + while (true) { + try { + return get(Integer.MAX_VALUE, TimeUnit.DAYS); + } catch (TimeoutException e) { + // ignore + } + } + } + + @Override + public T get(long timeout, TimeUnit unit) throws InterruptedException, ExecutionException, TimeoutException { + getCalledLatch.countDown(); + + if (resolveOnGet) { + if (resolveOnGetException != null) + resolve(resolveOnGetException); + else + resolve(resolveOnGetResult); + } + + synchronized (this) { + while (!resolved) { + this.wait(TimeUnit.MILLISECONDS.convert(timeout, unit)); + } + } + + if (exception != null) { + if (exception instanceof TimeoutException) + throw (TimeoutException) exception; + else if (exception instanceof InterruptedException) + throw (InterruptedException) exception; + else + throw new ExecutionException(exception); + } + return result; + } + + /** + * Set a flag to resolve the future as soon as one of the get() methods has been called. Returns immediately. + * @param val the value to return from the future + */ + public void resolveOnGet(T val) { + resolveOnGet = true; + resolveOnGetResult = val; + } + + /** + * Set a flag to resolve the future as soon as one of the get() methods has been called. Returns immediately. + * @param t the exception to return from the future + */ + public void resolveOnGet(Throwable t) { + resolveOnGet = true; + resolveOnGetException = t; + } + + /** + * Block, waiting for another thread to call one of the get() methods, and then immediately resolve the future with + * the specified value. + * @param val the value to return from the future + */ + public void waitForGetAndResolve(T val) { + waitForGet(); + resolve(val); + } + + /** + * Block, waiting for another thread to call one of the get() methods, and then immediately resolve the future with + * the specified value. 
+ * @param t the exception to use to resolve the future + */ + public void waitForGetAndResolve(Throwable t) { + waitForGet(); + resolve(t); + } + + private void waitForGet() { + try { + getCalledLatch.await(); + } catch (InterruptedException e) { + throw new RuntimeException("Unexpected interruption: ", e); + } + } +} diff --git a/connect/runtime/src/test/java/org/apache/kafka/connect/util/ThreadedTest.java b/connect/runtime/src/test/java/org/apache/kafka/connect/util/ThreadedTest.java new file mode 100644 index 0000000..dd367dd --- /dev/null +++ b/connect/runtime/src/test/java/org/apache/kafka/connect/util/ThreadedTest.java @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.util; + +import org.junit.After; +import org.junit.Before; + +/** + * Base class for tests that use threads. It sets up uncaught exception handlers for all known + * thread classes and checks for errors at the end of the test so that failures in background + * threads will cause the test to fail. + */ +public class ThreadedTest { + + protected TestBackgroundThreadExceptionHandler backgroundThreadExceptionHandler; + + @Before + public void setup() { + backgroundThreadExceptionHandler = new TestBackgroundThreadExceptionHandler(); + ShutdownableThread.funcaughtExceptionHandler = backgroundThreadExceptionHandler; + } + + @After + public void teardown() { + backgroundThreadExceptionHandler.verifyNoExceptions(); + ShutdownableThread.funcaughtExceptionHandler = null; + } +} diff --git a/connect/runtime/src/test/java/org/apache/kafka/connect/util/TopicAdminTest.java b/connect/runtime/src/test/java/org/apache/kafka/connect/util/TopicAdminTest.java new file mode 100644 index 0000000..dc25129 --- /dev/null +++ b/connect/runtime/src/test/java/org/apache/kafka/connect/util/TopicAdminTest.java @@ -0,0 +1,856 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
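Putting the two resolution styles described in these Javadocs side by side, using only the methods shown above:

import java.util.concurrent.ExecutionException;
import org.apache.kafka.connect.util.TestFuture;

public class TestFutureSketch {
    public static void main(String[] args) throws InterruptedException, ExecutionException {
        // Style 1: pre-arm the future so the first get() call resolves it immediately.
        TestFuture<String> preArmed = new TestFuture<>();
        preArmed.resolveOnGet("ready");
        System.out.println(preArmed.get()); // prints "ready"

        // Style 2: let another thread block in get() and resolve it from the test thread.
        TestFuture<String> blocking = new TestFuture<>();
        Thread caller = new Thread(() -> {
            try {
                System.out.println(blocking.get()); // prints "done" once resolved
            } catch (InterruptedException | ExecutionException e) {
                // a Throwable passed to resolve(...) would surface here as an ExecutionException
            }
        });
        caller.start();
        blocking.waitForGetAndResolve("done"); // waits for the caller to reach get(), then resolves it
        caller.join();
    }
}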
+ */ +package org.apache.kafka.connect.util; + +import org.apache.kafka.clients.NodeApiVersions; +import org.apache.kafka.clients.admin.AdminClientUnitTestEnv; +import org.apache.kafka.clients.admin.Config; +import org.apache.kafka.clients.admin.DescribeTopicsResult; +import org.apache.kafka.clients.admin.MockAdminClient; +import org.apache.kafka.clients.admin.NewTopic; +import org.apache.kafka.clients.admin.TopicDescription; +import org.apache.kafka.common.Cluster; +import org.apache.kafka.common.KafkaFuture; +import org.apache.kafka.common.Node; +import org.apache.kafka.common.PartitionInfo; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.TopicPartitionInfo; +import org.apache.kafka.common.config.ConfigException; +import org.apache.kafka.common.config.ConfigResource; +import org.apache.kafka.common.config.TopicConfig; +import org.apache.kafka.common.errors.ClusterAuthorizationException; +import org.apache.kafka.common.errors.TimeoutException; +import org.apache.kafka.common.errors.TopicAuthorizationException; +import org.apache.kafka.common.errors.UnsupportedVersionException; +import org.apache.kafka.common.message.CreateTopicsResponseData; +import org.apache.kafka.common.message.CreateTopicsResponseData.CreatableTopicResult; +import org.apache.kafka.common.message.DescribeConfigsResponseData; +import org.apache.kafka.common.message.ListOffsetsResponseData; +import org.apache.kafka.common.message.ListOffsetsResponseData.ListOffsetsTopicResponse; +import org.apache.kafka.common.message.MetadataResponseData; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.Errors; +import org.apache.kafka.common.requests.ApiError; +import org.apache.kafka.common.requests.CreateTopicsResponse; +import org.apache.kafka.common.requests.DescribeConfigsResponse; +import org.apache.kafka.common.requests.ListOffsetsResponse; +import org.apache.kafka.common.requests.MetadataResponse; +import org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.connect.errors.ConnectException; +import org.junit.Test; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.ExecutionException; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import static org.apache.kafka.common.message.MetadataResponseData.MetadataResponseTopic; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +public class TopicAdminTest { + + /** + * 0.11.0.0 clients can talk with older brokers, but the CREATE_TOPIC API was added in 0.10.1.0. That means, + * if our TopicAdmin talks to a pre 0.10.1 broker, it should receive an UnsupportedVersionException, should + * create no topics, and return false. 
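+ * (Concretely, the test asserts that createOrFindTopics returns an empty result rather than throwing.)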
+ */ + @Test + public void returnEmptyWithApiVersionMismatchOnCreate() { + final NewTopic newTopic = TopicAdmin.defineTopic("myTopic").partitions(1).compacted().build(); + Cluster cluster = createCluster(1); + try (AdminClientUnitTestEnv env = new AdminClientUnitTestEnv(new MockTime(), cluster)) { + env.kafkaClient().setNodeApiVersions(NodeApiVersions.create()); + env.kafkaClient().prepareResponse(createTopicResponseWithUnsupportedVersion(newTopic)); + TopicAdmin admin = new TopicAdmin(null, env.adminClient()); + assertTrue(admin.createOrFindTopics(newTopic).isEmpty()); + } + } + + /** + * 0.11.0.0 clients can talk with older brokers, but the DESCRIBE_TOPIC API was added in 0.10.0.0. That means, + * if our TopicAdmin talks to a pre 0.10.0 broker, it should receive an UnsupportedVersionException, should + * create no topics, and return false. + */ + @Test + public void throwsWithApiVersionMismatchOnDescribe() { + final NewTopic newTopic = TopicAdmin.defineTopic("myTopic").partitions(1).compacted().build(); + Cluster cluster = createCluster(1); + try (AdminClientUnitTestEnv env = new AdminClientUnitTestEnv(new MockTime(), cluster)) { + env.kafkaClient().setNodeApiVersions(NodeApiVersions.create()); + env.kafkaClient().prepareResponse(describeTopicResponseWithUnsupportedVersion(newTopic)); + TopicAdmin admin = new TopicAdmin(null, env.adminClient()); + Exception e = assertThrows(ConnectException.class, () -> admin.describeTopics(newTopic.name())); + assertTrue(e.getCause() instanceof UnsupportedVersionException); + } + } + + @Test + public void returnEmptyWithClusterAuthorizationFailureOnCreate() { + final NewTopic newTopic = TopicAdmin.defineTopic("myTopic").partitions(1).compacted().build(); + Cluster cluster = createCluster(1); + try (AdminClientUnitTestEnv env = new AdminClientUnitTestEnv(new MockTime(), cluster)) { + env.kafkaClient().prepareResponse(createTopicResponseWithClusterAuthorizationException(newTopic)); + TopicAdmin admin = new TopicAdmin(null, env.adminClient()); + assertFalse(admin.createTopic(newTopic)); + + env.kafkaClient().prepareResponse(createTopicResponseWithClusterAuthorizationException(newTopic)); + assertTrue(admin.createOrFindTopics(newTopic).isEmpty()); + } + } + + @Test + public void throwsWithClusterAuthorizationFailureOnDescribe() { + final NewTopic newTopic = TopicAdmin.defineTopic("myTopic").partitions(1).compacted().build(); + Cluster cluster = createCluster(1); + try (AdminClientUnitTestEnv env = new AdminClientUnitTestEnv(new MockTime(), cluster)) { + env.kafkaClient().prepareResponse(describeTopicResponseWithClusterAuthorizationException(newTopic)); + TopicAdmin admin = new TopicAdmin(null, env.adminClient()); + Exception e = assertThrows(ConnectException.class, () -> admin.describeTopics(newTopic.name())); + assertTrue(e.getCause() instanceof ClusterAuthorizationException); + } + } + + @Test + public void returnEmptyWithTopicAuthorizationFailureOnCreate() { + final NewTopic newTopic = TopicAdmin.defineTopic("myTopic").partitions(1).compacted().build(); + Cluster cluster = createCluster(1); + try (AdminClientUnitTestEnv env = new AdminClientUnitTestEnv(new MockTime(), cluster)) { + env.kafkaClient().prepareResponse(createTopicResponseWithTopicAuthorizationException(newTopic)); + TopicAdmin admin = new TopicAdmin(null, env.adminClient()); + assertFalse(admin.createTopic(newTopic)); + + env.kafkaClient().prepareResponse(createTopicResponseWithTopicAuthorizationException(newTopic)); + assertTrue(admin.createOrFindTopics(newTopic).isEmpty()); + } + } + + 
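/*
+     * A minimal sketch of the pattern shared by the error-path tests in this class; the method and
+     * topic names here are illustrative only. Each test queues a canned broker response on the mock
+     * client, invokes the TopicAdmin method under test, and then asserts either the "soft" outcome on
+     * the create path or the wrapped cause on the describe path.
+     */
+    @Test
+    public void sketchOfErrorPathTestPattern() {
+        NewTopic newTopic = TopicAdmin.defineTopic("sketch-topic").partitions(1).compacted().build();
+        Cluster cluster = createCluster(1);
+        try (AdminClientUnitTestEnv env = new AdminClientUnitTestEnv(new MockTime(), cluster)) {
+            TopicAdmin admin = new TopicAdmin(null, env.adminClient());
+
+            // Create path: an authorization error is swallowed and reported as "nothing created".
+            env.kafkaClient().prepareResponse(createTopicResponseWithTopicAuthorizationException(newTopic));
+            assertTrue(admin.createOrFindTopics(newTopic).isEmpty());
+
+            // Describe path: the same error surfaces as the cause of a ConnectException.
+            env.kafkaClient().prepareResponse(describeTopicResponseWithTopicAuthorizationException(newTopic));
+            Exception e = assertThrows(ConnectException.class, () -> admin.describeTopics(newTopic.name()));
+            assertTrue(e.getCause() instanceof TopicAuthorizationException);
+        }
+    }
+ +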
@Test + public void throwsWithTopicAuthorizationFailureOnDescribe() { + final NewTopic newTopic = TopicAdmin.defineTopic("myTopic").partitions(1).compacted().build(); + Cluster cluster = createCluster(1); + try (AdminClientUnitTestEnv env = new AdminClientUnitTestEnv(new MockTime(), cluster)) { + env.kafkaClient().prepareResponse(describeTopicResponseWithTopicAuthorizationException(newTopic)); + TopicAdmin admin = new TopicAdmin(null, env.adminClient()); + Exception e = assertThrows(ConnectException.class, () -> admin.describeTopics(newTopic.name())); + assertTrue(e.getCause() instanceof TopicAuthorizationException); + } + } + + @Test + public void shouldNotCreateTopicWhenItAlreadyExists() { + NewTopic newTopic = TopicAdmin.defineTopic("myTopic").partitions(1).compacted().build(); + Cluster cluster = createCluster(1); + try (MockAdminClient mockAdminClient = new MockAdminClient(cluster.nodes(), cluster.nodeById(0))) { + TopicPartitionInfo topicPartitionInfo = new TopicPartitionInfo(0, cluster.nodeById(0), cluster.nodes(), Collections.emptyList()); + mockAdminClient.addTopic(false, "myTopic", Collections.singletonList(topicPartitionInfo), null); + TopicAdmin admin = new TopicAdmin(null, mockAdminClient); + assertFalse(admin.createTopic(newTopic)); + assertTrue(admin.createTopics(newTopic).isEmpty()); + assertTrue(admin.createOrFindTopic(newTopic)); + TopicAdmin.TopicCreationResponse response = admin.createOrFindTopics(newTopic); + assertTrue(response.isCreatedOrExisting(newTopic.name())); + assertTrue(response.isExisting(newTopic.name())); + assertFalse(response.isCreated(newTopic.name())); + } + } + + @Test + public void shouldCreateTopicWithPartitionsWhenItDoesNotExist() { + for (int numBrokers = 1; numBrokers < 10; ++numBrokers) { + int expectedReplicas = Math.min(3, numBrokers); + int maxDefaultRf = Math.min(numBrokers, 5); + for (int numPartitions = 1; numPartitions < 30; ++numPartitions) { + NewTopic newTopic = TopicAdmin.defineTopic("myTopic").partitions(numPartitions).compacted().build(); + + // Try clusters with no default replication factor or default partitions + assertTopicCreation(numBrokers, newTopic, null, null, expectedReplicas, numPartitions); + + // Try clusters with different default partitions + for (int defaultPartitions = 1; defaultPartitions < 20; ++defaultPartitions) { + assertTopicCreation(numBrokers, newTopic, defaultPartitions, null, expectedReplicas, numPartitions); + } + + // Try clusters with different default replication factors + for (int defaultRF = 1; defaultRF < maxDefaultRf; ++defaultRF) { + assertTopicCreation(numBrokers, newTopic, null, defaultRF, defaultRF, numPartitions); + } + } + } + } + + @Test + public void shouldCreateTopicWithReplicationFactorWhenItDoesNotExist() { + for (int numBrokers = 1; numBrokers < 10; ++numBrokers) { + int maxRf = Math.min(numBrokers, 5); + int maxDefaultRf = Math.min(numBrokers, 5); + for (short rf = 1; rf < maxRf; ++rf) { + NewTopic newTopic = TopicAdmin.defineTopic("myTopic").replicationFactor(rf).compacted().build(); + + // Try clusters with no default replication factor or default partitions + assertTopicCreation(numBrokers, newTopic, null, null, rf, 1); + + // Try clusters with different default partitions + for (int numPartitions = 1; numPartitions < 30; ++numPartitions) { + assertTopicCreation(numBrokers, newTopic, numPartitions, null, rf, numPartitions); + } + + // Try clusters with different default replication factors + for (int defaultRF = 1; defaultRF < maxDefaultRf; ++defaultRF) { + 
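// The topic's explicit replication factor should take precedence over the cluster default (defaultRF), so rf replicas are still expected. +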
assertTopicCreation(numBrokers, newTopic, null, defaultRF, rf, 1); + } + } + } + } + + @Test + public void shouldCreateTopicWithDefaultPartitionsAndReplicationFactorWhenItDoesNotExist() { + NewTopic newTopic = TopicAdmin.defineTopic("my-topic") + .defaultPartitions() + .defaultReplicationFactor() + .compacted() + .build(); + + for (int numBrokers = 1; numBrokers < 10; ++numBrokers) { + int expectedReplicas = Math.min(3, numBrokers); + assertTopicCreation(numBrokers, newTopic, null, null, expectedReplicas, 1); + assertTopicCreation(numBrokers, newTopic, 30, null, expectedReplicas, 30); + } + } + + @Test + public void shouldCreateOneTopicWhenProvidedMultipleDefinitionsWithSameTopicName() { + NewTopic newTopic1 = TopicAdmin.defineTopic("myTopic").partitions(1).compacted().build(); + NewTopic newTopic2 = TopicAdmin.defineTopic("myTopic").partitions(1).compacted().build(); + Cluster cluster = createCluster(1); + try (TopicAdmin admin = new TopicAdmin(null, new MockAdminClient(cluster.nodes(), cluster.nodeById(0)))) { + Set newTopicNames = admin.createTopics(newTopic1, newTopic2); + assertEquals(1, newTopicNames.size()); + assertEquals(newTopic2.name(), newTopicNames.iterator().next()); + } + } + + @Test + public void createShouldReturnFalseWhenSuppliedNullTopicDescription() { + Cluster cluster = createCluster(1); + try (TopicAdmin admin = new TopicAdmin(null, new MockAdminClient(cluster.nodes(), cluster.nodeById(0)))) { + boolean created = admin.createTopic(null); + assertFalse(created); + } + } + + @Test + public void describeShouldReturnEmptyWhenTopicDoesNotExist() { + NewTopic newTopic = TopicAdmin.defineTopic("myTopic").partitions(1).compacted().build(); + Cluster cluster = createCluster(1); + try (TopicAdmin admin = new TopicAdmin(null, new MockAdminClient(cluster.nodes(), cluster.nodeById(0)))) { + assertTrue(admin.describeTopics(newTopic.name()).isEmpty()); + } + } + + @Test + public void describeShouldReturnTopicDescriptionWhenTopicExists() { + String topicName = "myTopic"; + NewTopic newTopic = TopicAdmin.defineTopic(topicName).partitions(1).compacted().build(); + Cluster cluster = createCluster(1); + try (MockAdminClient mockAdminClient = new MockAdminClient(cluster.nodes(), cluster.nodeById(0))) { + TopicPartitionInfo topicPartitionInfo = new TopicPartitionInfo(0, cluster.nodeById(0), cluster.nodes(), Collections.emptyList()); + mockAdminClient.addTopic(false, topicName, Collections.singletonList(topicPartitionInfo), null); + TopicAdmin admin = new TopicAdmin(null, mockAdminClient); + Map desc = admin.describeTopics(newTopic.name()); + assertFalse(desc.isEmpty()); + TopicDescription topicDesc = new TopicDescription(topicName, false, Collections.singletonList(topicPartitionInfo)); + assertEquals(desc.get("myTopic"), topicDesc); + } + } + + @Test + public void describeTopicConfigShouldReturnEmptyMapWhenNoTopicsAreSpecified() { + final NewTopic newTopic = TopicAdmin.defineTopic("myTopic").partitions(1).compacted().build(); + Cluster cluster = createCluster(1); + try (AdminClientUnitTestEnv env = new AdminClientUnitTestEnv(new MockTime(), cluster)) { + env.kafkaClient().prepareResponse(describeConfigsResponseWithUnsupportedVersion(newTopic)); + TopicAdmin admin = new TopicAdmin(null, env.adminClient()); + Map results = admin.describeTopicConfigs(); + assertTrue(results.isEmpty()); + } + } + + @Test + public void describeTopicConfigShouldReturnEmptyMapWhenUnsupportedVersionFailure() { + final NewTopic newTopic = TopicAdmin.defineTopic("myTopic").partitions(1).compacted().build(); + 
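// The canned DescribeConfigs response prepared below returns UNSUPPORTED_VERSION; describeTopicConfigs should swallow it and return an empty map. +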
Cluster cluster = createCluster(1); + try (AdminClientUnitTestEnv env = new AdminClientUnitTestEnv(new MockTime(), cluster)) { + env.kafkaClient().prepareResponse(describeConfigsResponseWithUnsupportedVersion(newTopic)); + TopicAdmin admin = new TopicAdmin(null, env.adminClient()); + Map results = admin.describeTopicConfigs(newTopic.name()); + assertTrue(results.isEmpty()); + } + } + + @Test + public void describeTopicConfigShouldReturnEmptyMapWhenClusterAuthorizationFailure() { + final NewTopic newTopic = TopicAdmin.defineTopic("myTopic").partitions(1).compacted().build(); + Cluster cluster = createCluster(1); + try (AdminClientUnitTestEnv env = new AdminClientUnitTestEnv(new MockTime(), cluster)) { + env.kafkaClient().prepareResponse(describeConfigsResponseWithClusterAuthorizationException(newTopic)); + TopicAdmin admin = new TopicAdmin(null, env.adminClient()); + Map results = admin.describeTopicConfigs(newTopic.name()); + assertTrue(results.isEmpty()); + } + } + + @Test + public void describeTopicConfigShouldReturnEmptyMapWhenTopicAuthorizationFailure() { + final NewTopic newTopic = TopicAdmin.defineTopic("myTopic").partitions(1).compacted().build(); + Cluster cluster = createCluster(1); + try (AdminClientUnitTestEnv env = new AdminClientUnitTestEnv(new MockTime(), cluster)) { + env.kafkaClient().prepareResponse(describeConfigsResponseWithTopicAuthorizationException(newTopic)); + TopicAdmin admin = new TopicAdmin(null, env.adminClient()); + Map results = admin.describeTopicConfigs(newTopic.name()); + assertTrue(results.isEmpty()); + } + } + + @Test + public void describeTopicConfigShouldReturnMapWithNullValueWhenTopicDoesNotExist() { + NewTopic newTopic = TopicAdmin.defineTopic("myTopic").partitions(1).compacted().build(); + Cluster cluster = createCluster(1); + try (TopicAdmin admin = new TopicAdmin(null, new MockAdminClient(cluster.nodes(), cluster.nodeById(0)))) { + Map results = admin.describeTopicConfigs(newTopic.name()); + assertFalse(results.isEmpty()); + assertEquals(1, results.size()); + assertNull(results.get("myTopic")); + } + } + + @Test + public void describeTopicConfigShouldReturnTopicConfigWhenTopicExists() { + String topicName = "myTopic"; + NewTopic newTopic = TopicAdmin.defineTopic(topicName) + .config(Collections.singletonMap("foo", "bar")) + .partitions(1) + .compacted() + .build(); + Cluster cluster = createCluster(1); + try (MockAdminClient mockAdminClient = new MockAdminClient(cluster.nodes(), cluster.nodeById(0))) { + TopicPartitionInfo topicPartitionInfo = new TopicPartitionInfo(0, cluster.nodeById(0), cluster.nodes(), Collections.emptyList()); + mockAdminClient.addTopic(false, topicName, Collections.singletonList(topicPartitionInfo), null); + TopicAdmin admin = new TopicAdmin(null, mockAdminClient); + Map result = admin.describeTopicConfigs(newTopic.name()); + assertFalse(result.isEmpty()); + assertEquals(1, result.size()); + Config config = result.get("myTopic"); + assertNotNull(config); + config.entries().forEach(entry -> assertEquals(newTopic.configs().get(entry.name()), entry.value())); + } + } + + @Test + public void verifyingTopicCleanupPolicyShouldReturnFalseWhenBrokerVersionIsUnsupported() { + final NewTopic newTopic = TopicAdmin.defineTopic("myTopic").partitions(1).compacted().build(); + Cluster cluster = createCluster(1); + try (AdminClientUnitTestEnv env = new AdminClientUnitTestEnv(new MockTime(), cluster)) { + env.kafkaClient().prepareResponse(describeConfigsResponseWithUnsupportedVersion(newTopic)); + TopicAdmin admin = new TopicAdmin(null, 
env.adminClient()); + boolean result = admin.verifyTopicCleanupPolicyOnlyCompact("myTopic", "worker.topic", "purpose"); + assertFalse(result); + } + } + + @Test + public void verifyingTopicCleanupPolicyShouldReturnFalseWhenClusterAuthorizationError() { + final NewTopic newTopic = TopicAdmin.defineTopic("myTopic").partitions(1).compacted().build(); + Cluster cluster = createCluster(1); + try (AdminClientUnitTestEnv env = new AdminClientUnitTestEnv(new MockTime(), cluster)) { + env.kafkaClient().prepareResponse(describeConfigsResponseWithClusterAuthorizationException(newTopic)); + TopicAdmin admin = new TopicAdmin(null, env.adminClient()); + boolean result = admin.verifyTopicCleanupPolicyOnlyCompact("myTopic", "worker.topic", "purpose"); + assertFalse(result); + } + } + + @Test + public void verifyingTopicCleanupPolicyShouldReturnFalseWhenTopicAuthorizationError() { + final NewTopic newTopic = TopicAdmin.defineTopic("myTopic").partitions(1).compacted().build(); + Cluster cluster = createCluster(1); + try (AdminClientUnitTestEnv env = new AdminClientUnitTestEnv(new MockTime(), cluster)) { + env.kafkaClient().prepareResponse(describeConfigsResponseWithTopicAuthorizationException(newTopic)); + TopicAdmin admin = new TopicAdmin(null, env.adminClient()); + boolean result = admin.verifyTopicCleanupPolicyOnlyCompact("myTopic", "worker.topic", "purpose"); + assertFalse(result); + } + } + + @Test + public void verifyingTopicCleanupPolicyShouldReturnTrueWhenTopicHasCorrectPolicy() { + String topicName = "myTopic"; + Map topicConfigs = Collections.singletonMap("cleanup.policy", "compact"); + Cluster cluster = createCluster(1); + try (MockAdminClient mockAdminClient = new MockAdminClient(cluster.nodes(), cluster.nodeById(0))) { + TopicPartitionInfo topicPartitionInfo = new TopicPartitionInfo(0, cluster.nodeById(0), cluster.nodes(), Collections.emptyList()); + mockAdminClient.addTopic(false, topicName, Collections.singletonList(topicPartitionInfo), topicConfigs); + TopicAdmin admin = new TopicAdmin(null, mockAdminClient); + boolean result = admin.verifyTopicCleanupPolicyOnlyCompact("myTopic", "worker.topic", "purpose"); + assertTrue(result); + } + } + + @Test + public void verifyingTopicCleanupPolicyShouldFailWhenTopicHasDeletePolicy() { + String topicName = "myTopic"; + Map topicConfigs = Collections.singletonMap("cleanup.policy", "delete"); + Cluster cluster = createCluster(1); + try (MockAdminClient mockAdminClient = new MockAdminClient(cluster.nodes(), cluster.nodeById(0))) { + TopicPartitionInfo topicPartitionInfo = new TopicPartitionInfo(0, cluster.nodeById(0), cluster.nodes(), Collections.emptyList()); + mockAdminClient.addTopic(false, topicName, Collections.singletonList(topicPartitionInfo), topicConfigs); + TopicAdmin admin = new TopicAdmin(null, mockAdminClient); + ConfigException e = assertThrows(ConfigException.class, () -> admin.verifyTopicCleanupPolicyOnlyCompact("myTopic", "worker.topic", "purpose")); + assertTrue(e.getMessage().contains("to guarantee consistency and durability")); + } + } + + @Test + public void verifyingTopicCleanupPolicyShouldFailWhenTopicHasDeleteAndCompactPolicy() { + String topicName = "myTopic"; + Map topicConfigs = Collections.singletonMap("cleanup.policy", "delete,compact"); + Cluster cluster = createCluster(1); + try (MockAdminClient mockAdminClient = new MockAdminClient(cluster.nodes(), cluster.nodeById(0))) { + TopicPartitionInfo topicPartitionInfo = new TopicPartitionInfo(0, cluster.nodeById(0), cluster.nodes(), Collections.emptyList()); + 
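// Even the mixed "delete,compact" policy must be rejected: only compact-only cleanup is acceptable here, so a ConfigException is expected. +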
mockAdminClient.addTopic(false, topicName, Collections.singletonList(topicPartitionInfo), topicConfigs); + TopicAdmin admin = new TopicAdmin(null, mockAdminClient); + ConfigException e = assertThrows(ConfigException.class, () -> admin.verifyTopicCleanupPolicyOnlyCompact("myTopic", "worker.topic", "purpose")); + assertTrue(e.getMessage().contains("to guarantee consistency and durability")); + } + } + + @Test + public void verifyingGettingTopicCleanupPolicies() { + String topicName = "myTopic"; + Map topicConfigs = Collections.singletonMap("cleanup.policy", "compact"); + Cluster cluster = createCluster(1); + try (MockAdminClient mockAdminClient = new MockAdminClient(cluster.nodes(), cluster.nodeById(0))) { + TopicPartitionInfo topicPartitionInfo = new TopicPartitionInfo(0, cluster.nodeById(0), cluster.nodes(), Collections.emptyList()); + mockAdminClient.addTopic(false, topicName, Collections.singletonList(topicPartitionInfo), topicConfigs); + TopicAdmin admin = new TopicAdmin(null, mockAdminClient); + Set policies = admin.topicCleanupPolicy("myTopic"); + assertEquals(1, policies.size()); + assertEquals(TopicConfig.CLEANUP_POLICY_COMPACT, policies.iterator().next()); + } + } + + @Test + public void endOffsetsShouldFailWithNonRetriableWhenAuthorizationFailureOccurs() { + String topicName = "myTopic"; + TopicPartition tp1 = new TopicPartition(topicName, 0); + Set tps = Collections.singleton(tp1); + Long offset = null; // response should use error + Cluster cluster = createCluster(1, topicName, 1); + try (AdminClientUnitTestEnv env = new AdminClientUnitTestEnv(new MockTime(), cluster)) { + env.kafkaClient().setNodeApiVersions(NodeApiVersions.create()); + env.kafkaClient().prepareResponse(prepareMetadataResponse(cluster, Errors.NONE)); + env.kafkaClient().prepareResponse(listOffsetsResultWithClusterAuthorizationException(tp1, offset)); + TopicAdmin admin = new TopicAdmin(null, env.adminClient()); + ConnectException e = assertThrows(ConnectException.class, () -> { + admin.endOffsets(tps); + }); + assertTrue(e.getMessage().contains("Not authorized to get the end offsets")); + } + } + + @Test + public void endOffsetsShouldFailWithUnsupportedVersionWhenVersionUnsupportedErrorOccurs() { + String topicName = "myTopic"; + TopicPartition tp1 = new TopicPartition(topicName, 0); + Set tps = Collections.singleton(tp1); + Long offset = null; // response should use error + Cluster cluster = createCluster(1, topicName, 1); + try (AdminClientUnitTestEnv env = new AdminClientUnitTestEnv(new MockTime(), cluster)) { + env.kafkaClient().setNodeApiVersions(NodeApiVersions.create()); + env.kafkaClient().prepareResponse(prepareMetadataResponse(cluster, Errors.NONE)); + env.kafkaClient().prepareResponse(listOffsetsResultWithUnsupportedVersion(tp1, offset)); + TopicAdmin admin = new TopicAdmin(null, env.adminClient()); + UnsupportedVersionException e = assertThrows(UnsupportedVersionException.class, () -> { + admin.endOffsets(tps); + }); + } + } + + @Test + public void endOffsetsShouldFailWithTimeoutExceptionWhenTimeoutErrorOccurs() { + String topicName = "myTopic"; + TopicPartition tp1 = new TopicPartition(topicName, 0); + Set tps = Collections.singleton(tp1); + Long offset = null; // response should use error + Cluster cluster = createCluster(1, topicName, 1); + try (AdminClientUnitTestEnv env = new AdminClientUnitTestEnv(new MockTime(), cluster)) { + env.kafkaClient().setNodeApiVersions(NodeApiVersions.create()); + env.kafkaClient().prepareResponse(prepareMetadataResponse(cluster, Errors.NONE)); + 
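// Second canned response: a ListOffsets reply carrying REQUEST_TIMED_OUT, which endOffsets is expected to surface as a TimeoutException. +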
env.kafkaClient().prepareResponse(listOffsetsResultWithTimeout(tp1, offset)); + TopicAdmin admin = new TopicAdmin(null, env.adminClient()); + TimeoutException e = assertThrows(TimeoutException.class, () -> { + admin.endOffsets(tps); + }); + } + } + + @Test + public void endOffsetsShouldFailWithNonRetriableWhenUnknownErrorOccurs() { + String topicName = "myTopic"; + TopicPartition tp1 = new TopicPartition(topicName, 0); + Set tps = Collections.singleton(tp1); + Long offset = null; // response should use error + Cluster cluster = createCluster(1, topicName, 1); + try (AdminClientUnitTestEnv env = new AdminClientUnitTestEnv(new MockTime(), cluster)) { + env.kafkaClient().setNodeApiVersions(NodeApiVersions.create()); + env.kafkaClient().prepareResponse(prepareMetadataResponse(cluster, Errors.NONE)); + env.kafkaClient().prepareResponse(listOffsetsResultWithUnknownError(tp1, offset)); + TopicAdmin admin = new TopicAdmin(null, env.adminClient()); + ConnectException e = assertThrows(ConnectException.class, () -> { + admin.endOffsets(tps); + }); + assertTrue(e.getMessage().contains("Error while getting end offsets for topic")); + } + } + + @Test + public void endOffsetsShouldReturnEmptyMapWhenPartitionsSetIsNull() { + String topicName = "myTopic"; + Cluster cluster = createCluster(1, topicName, 1); + try (AdminClientUnitTestEnv env = new AdminClientUnitTestEnv(new MockTime(), cluster)) { + TopicAdmin admin = new TopicAdmin(null, env.adminClient()); + Map offsets = admin.endOffsets(Collections.emptySet()); + assertTrue(offsets.isEmpty()); + } + } + + @Test + public void endOffsetsShouldReturnOffsetsForOnePartition() { + String topicName = "myTopic"; + TopicPartition tp1 = new TopicPartition(topicName, 0); + Set tps = Collections.singleton(tp1); + long offset = 1000L; + Cluster cluster = createCluster(1, topicName, 1); + try (AdminClientUnitTestEnv env = new AdminClientUnitTestEnv(new MockTime(), cluster)) { + env.kafkaClient().setNodeApiVersions(NodeApiVersions.create()); + env.kafkaClient().prepareResponse(prepareMetadataResponse(cluster, Errors.NONE)); + env.kafkaClient().prepareResponse(listOffsetsResult(tp1, offset)); + TopicAdmin admin = new TopicAdmin(null, env.adminClient()); + Map offsets = admin.endOffsets(tps); + assertEquals(1, offsets.size()); + assertEquals(Long.valueOf(offset), offsets.get(tp1)); + } + } + + @Test + public void endOffsetsShouldReturnOffsetsForMultiplePartitions() { + String topicName = "myTopic"; + TopicPartition tp1 = new TopicPartition(topicName, 0); + TopicPartition tp2 = new TopicPartition(topicName, 1); + Set tps = new HashSet<>(Arrays.asList(tp1, tp2)); + long offset1 = 1001; + long offset2 = 1002; + Cluster cluster = createCluster(1, topicName, 2); + try (AdminClientUnitTestEnv env = new AdminClientUnitTestEnv(new MockTime(), cluster)) { + env.kafkaClient().setNodeApiVersions(NodeApiVersions.create()); + env.kafkaClient().prepareResponse(prepareMetadataResponse(cluster, Errors.NONE)); + env.kafkaClient().prepareResponse(listOffsetsResult(tp1, offset1, tp2, offset2)); + TopicAdmin admin = new TopicAdmin(null, env.adminClient()); + Map offsets = admin.endOffsets(tps); + assertEquals(2, offsets.size()); + assertEquals(Long.valueOf(offset1), offsets.get(tp1)); + assertEquals(Long.valueOf(offset2), offsets.get(tp2)); + } + } + + @Test + public void endOffsetsShouldFailWhenAnyTopicPartitionHasError() { + String topicName = "myTopic"; + TopicPartition tp1 = new TopicPartition(topicName, 0); + Set tps = Collections.singleton(tp1); + long offset = 1000; + Cluster cluster 
= createCluster(1, topicName, 1); + try (AdminClientUnitTestEnv env = new AdminClientUnitTestEnv(new MockTime(), cluster)) { + env.kafkaClient().setNodeApiVersions(NodeApiVersions.create()); + env.kafkaClient().prepareResponse(prepareMetadataResponse(cluster, Errors.NONE)); + env.kafkaClient().prepareResponse(listOffsetsResultWithClusterAuthorizationException(tp1, null)); + TopicAdmin admin = new TopicAdmin(null, env.adminClient()); + ConnectException e = assertThrows(ConnectException.class, () -> { + admin.endOffsets(tps); + }); + assertTrue(e.getMessage().contains("Not authorized to get the end offsets")); + } + } + + private Cluster createCluster(int numNodes) { + return createCluster(numNodes, "unused", 0); + } + + private Cluster createCluster(int numNodes, String topicName, int partitions) { + Node[] nodeArray = new Node[numNodes]; + HashMap nodes = new HashMap<>(); + for (int i = 0; i < numNodes; ++i) { + nodeArray[i] = new Node(i, "localhost", 8121 + i); + nodes.put(i, nodeArray[i]); + } + Node leader = nodeArray[0]; + List pInfos = new ArrayList<>(); + for (int i = 0; i < partitions; ++i) { + pInfos.add(new PartitionInfo(topicName, i, leader, nodeArray, nodeArray)); + } + Cluster cluster = new Cluster( + "mockClusterId", + nodes.values(), + pInfos, + Collections.emptySet(), + Collections.emptySet(), + leader); + return cluster; + } + + private MetadataResponse prepareMetadataResponse(Cluster cluster, Errors error) { + List metadata = new ArrayList<>(); + for (String topic : cluster.topics()) { + List pms = new ArrayList<>(); + for (PartitionInfo pInfo : cluster.availablePartitionsForTopic(topic)) { + MetadataResponseData.MetadataResponsePartition pm = new MetadataResponseData.MetadataResponsePartition() + .setErrorCode(error.code()) + .setPartitionIndex(pInfo.partition()) + .setLeaderId(pInfo.leader().id()) + .setLeaderEpoch(234) + .setReplicaNodes(Arrays.stream(pInfo.replicas()).map(Node::id).collect(Collectors.toList())) + .setIsrNodes(Arrays.stream(pInfo.inSyncReplicas()).map(Node::id).collect(Collectors.toList())) + .setOfflineReplicas(Arrays.stream(pInfo.offlineReplicas()).map(Node::id).collect(Collectors.toList())); + pms.add(pm); + } + MetadataResponseTopic tm = new MetadataResponseTopic() + .setErrorCode(error.code()) + .setName(topic) + .setIsInternal(false) + .setPartitions(pms); + metadata.add(tm); + } + return MetadataResponse.prepareResponse(true, + 0, + cluster.nodes(), + cluster.clusterResource().clusterId(), + cluster.controller().id(), + metadata, + MetadataResponse.AUTHORIZED_OPERATIONS_OMITTED); + } + + private ListOffsetsResponse listOffsetsResultWithUnknownError(TopicPartition tp1, Long offset1) { + return listOffsetsResult( + new ApiError(Errors.UNKNOWN_SERVER_ERROR, "Unknown error"), + Collections.singletonMap(tp1, offset1) + ); + } + + private ListOffsetsResponse listOffsetsResultWithTimeout(TopicPartition tp1, Long offset1) { + return listOffsetsResult( + new ApiError(Errors.REQUEST_TIMED_OUT, "Request timed out"), + Collections.singletonMap(tp1, offset1) + ); + } + + private ListOffsetsResponse listOffsetsResultWithUnsupportedVersion(TopicPartition tp1, Long offset1) { + return listOffsetsResult( + new ApiError(Errors.UNSUPPORTED_VERSION, "This version of the API is not supported"), + Collections.singletonMap(tp1, offset1) + ); + } + + private ListOffsetsResponse listOffsetsResultWithClusterAuthorizationException(TopicPartition tp1, Long offset1) { + return listOffsetsResult( + new ApiError(Errors.CLUSTER_AUTHORIZATION_FAILED, "Not authorized to create 
topic(s)"), + Collections.singletonMap(tp1, offset1) + ); + } + + private ListOffsetsResponse listOffsetsResult(TopicPartition tp1, Long offset1) { + return listOffsetsResult(null, Collections.singletonMap(tp1, offset1)); + } + + private ListOffsetsResponse listOffsetsResult(TopicPartition tp1, Long offset1, TopicPartition tp2, Long offset2) { + Map offsetsByPartitions = new HashMap<>(); + offsetsByPartitions.put(tp1, offset1); + offsetsByPartitions.put(tp2, offset2); + return listOffsetsResult(null, offsetsByPartitions); + } + + /** + * Create a ListOffsetResponse that exposes the supplied error and includes offsets for the supplied partitions. + * @param error the error; may be null if an unknown error should be used + * @param offsetsByPartitions offset for each partition, where offset is null signals the error should be used + * @return the response + */ + private ListOffsetsResponse listOffsetsResult(ApiError error, Map offsetsByPartitions) { + if (error == null) error = new ApiError(Errors.UNKNOWN_TOPIC_OR_PARTITION, "unknown topic"); + List tpResponses = new ArrayList<>(); + for (TopicPartition partition : offsetsByPartitions.keySet()) { + Long offset = offsetsByPartitions.get(partition); + ListOffsetsTopicResponse topicResponse; + if (offset == null) { + topicResponse = ListOffsetsResponse.singletonListOffsetsTopicResponse(partition, error.error(), -1L, 0, 321); + } else { + topicResponse = ListOffsetsResponse.singletonListOffsetsTopicResponse(partition, Errors.NONE, -1L, offset, 321); + } + tpResponses.add(topicResponse); + } + ListOffsetsResponseData responseData = new ListOffsetsResponseData() + .setThrottleTimeMs(0) + .setTopics(tpResponses); + + return new ListOffsetsResponse(responseData); + } + + private CreateTopicsResponse createTopicResponseWithUnsupportedVersion(NewTopic... topics) { + return createTopicResponse(new ApiError(Errors.UNSUPPORTED_VERSION, "This version of the API is not supported"), topics); + } + + private CreateTopicsResponse createTopicResponseWithClusterAuthorizationException(NewTopic... topics) { + return createTopicResponse(new ApiError(Errors.CLUSTER_AUTHORIZATION_FAILED, "Not authorized to create topic(s)"), topics); + } + + private CreateTopicsResponse createTopicResponseWithTopicAuthorizationException(NewTopic... topics) { + return createTopicResponse(new ApiError(Errors.TOPIC_AUTHORIZATION_FAILED, "Not authorized to create topic(s)"), topics); + } + + private CreateTopicsResponse createTopicResponse(ApiError error, NewTopic... topics) { + if (error == null) error = new ApiError(Errors.NONE, ""); + CreateTopicsResponseData response = new CreateTopicsResponseData(); + for (NewTopic topic : topics) { + response.topics().add(new CreatableTopicResult(). + setName(topic.name()). + setErrorCode(error.error().code()). 
+ setErrorMessage(error.message())); + } + return new CreateTopicsResponse(response); + } + + protected void assertTopicCreation( + int brokers, + NewTopic newTopic, + Integer defaultPartitions, + Integer defaultReplicationFactor, + int expectedReplicas, + int expectedPartitions + ) { + Cluster cluster = createCluster(brokers); + MockAdminClient.Builder clientBuilder = MockAdminClient.create(); + if (defaultPartitions != null) { + clientBuilder.defaultPartitions(defaultPartitions.shortValue()); + } + if (defaultReplicationFactor != null) { + clientBuilder.defaultReplicationFactor(defaultReplicationFactor); + } + clientBuilder.brokers(cluster.nodes()); + clientBuilder.controller(0); + try (MockAdminClient admin = clientBuilder.build()) { + TopicAdmin topicClient = new TopicAdmin(null, admin, false); + TopicAdmin.TopicCreationResponse response = topicClient.createOrFindTopics(newTopic); + assertTrue(response.isCreated(newTopic.name())); + assertFalse(response.isExisting(newTopic.name())); + assertTopic(admin, newTopic.name(), expectedPartitions, expectedReplicas); + } + } + + protected void assertTopic(MockAdminClient admin, String topicName, int expectedPartitions, int expectedReplicas) { + TopicDescription desc = null; + try { + desc = topicDescription(admin, topicName); + } catch (Throwable t) { + fail("Failed to find topic description for topic '" + topicName + "'"); + } + assertEquals(expectedPartitions, desc.partitions().size()); + for (TopicPartitionInfo tp : desc.partitions()) { + assertEquals(expectedReplicas, tp.replicas().size()); + } + } + + protected TopicDescription topicDescription(MockAdminClient admin, String topicName) + throws ExecutionException, InterruptedException { + DescribeTopicsResult result = admin.describeTopics(Collections.singleton(topicName)); + Map> byName = result.topicNameValues(); + return byName.get(topicName).get(); + } + + private MetadataResponse describeTopicResponseWithUnsupportedVersion(NewTopic... topics) { + return describeTopicResponse(new ApiError(Errors.UNSUPPORTED_VERSION, "This version of the API is not supported"), topics); + } + + private MetadataResponse describeTopicResponseWithClusterAuthorizationException(NewTopic... topics) { + return describeTopicResponse(new ApiError(Errors.CLUSTER_AUTHORIZATION_FAILED, "Not authorized to create topic(s)"), topics); + } + + private MetadataResponse describeTopicResponseWithTopicAuthorizationException(NewTopic... topics) { + return describeTopicResponse(new ApiError(Errors.TOPIC_AUTHORIZATION_FAILED, "Not authorized to create topic(s)"), topics); + } + + private MetadataResponse describeTopicResponse(ApiError error, NewTopic... topics) { + if (error == null) error = new ApiError(Errors.NONE, ""); + MetadataResponseData response = new MetadataResponseData(); + for (NewTopic topic : topics) { + response.topics().add(new MetadataResponseTopic() + .setName(topic.name()) + .setErrorCode(error.error().code())); + } + return new MetadataResponse(response, ApiKeys.METADATA.latestVersion()); + } + + private DescribeConfigsResponse describeConfigsResponseWithUnsupportedVersion(NewTopic... topics) { + return describeConfigsResponse(new ApiError(Errors.UNSUPPORTED_VERSION, "This version of the API is not supported"), topics); + } + + private DescribeConfigsResponse describeConfigsResponseWithClusterAuthorizationException(NewTopic... 
topics) { + return describeConfigsResponse(new ApiError(Errors.CLUSTER_AUTHORIZATION_FAILED, "Not authorized to create topic(s)"), topics); + } + + private DescribeConfigsResponse describeConfigsResponseWithTopicAuthorizationException(NewTopic... topics) { + return describeConfigsResponse(new ApiError(Errors.TOPIC_AUTHORIZATION_FAILED, "Not authorized to create topic(s)"), topics); + } + + private DescribeConfigsResponse describeConfigsResponse(ApiError error, NewTopic... topics) { + List results = Stream.of(topics) + .map(topic -> new DescribeConfigsResponseData.DescribeConfigsResult() + .setErrorCode(error.error().code()) + .setErrorMessage(error.message()) + .setResourceType(ConfigResource.Type.TOPIC.id()) + .setResourceName(topic.name()) + .setConfigs(topic.configs().entrySet() + .stream() + .map(e -> new DescribeConfigsResponseData.DescribeConfigsResourceResult() + .setName(e.getKey()) + .setValue(e.getValue())) + .collect(Collectors.toList()))) + .collect(Collectors.toList()); + return new DescribeConfigsResponse(new DescribeConfigsResponseData().setThrottleTimeMs(1000).setResults(results)); + } + +} diff --git a/connect/runtime/src/test/java/org/apache/kafka/connect/util/TopicCreationTest.java b/connect/runtime/src/test/java/org/apache/kafka/connect/util/TopicCreationTest.java new file mode 100644 index 0000000..feb0e5f --- /dev/null +++ b/connect/runtime/src/test/java/org/apache/kafka/connect/util/TopicCreationTest.java @@ -0,0 +1,638 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.connect.util; + +import org.apache.kafka.clients.admin.NewTopic; +import org.apache.kafka.connect.data.Schema; +import org.apache.kafka.connect.runtime.SourceConnectorConfig; +import org.apache.kafka.connect.runtime.WorkerConfig; +import org.apache.kafka.connect.runtime.distributed.DistributedConfig; +import org.apache.kafka.connect.source.SourceRecord; +import org.apache.kafka.connect.storage.StringConverter; +import org.apache.kafka.connect.transforms.Cast; +import org.apache.kafka.connect.transforms.RegexRouter; +import org.apache.kafka.connect.transforms.Transformation; +import org.junit.Before; +import org.junit.Test; + +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.TimeUnit; + +import static org.apache.kafka.common.config.TopicConfig.CLEANUP_POLICY_COMPACT; +import static org.apache.kafka.common.config.TopicConfig.CLEANUP_POLICY_CONFIG; +import static org.apache.kafka.common.config.TopicConfig.COMPRESSION_TYPE_CONFIG; +import static org.apache.kafka.common.config.TopicConfig.RETENTION_MS_CONFIG; +import static org.apache.kafka.connect.runtime.ConnectorConfig.CONNECTOR_CLASS_CONFIG; +import static org.apache.kafka.connect.runtime.ConnectorConfig.NAME_CONFIG; +import static org.apache.kafka.connect.runtime.ConnectorConfigTest.MOCK_PLUGINS; +import static org.apache.kafka.connect.runtime.SourceConnectorConfig.TOPIC_CREATION_GROUPS_CONFIG; +import static org.apache.kafka.connect.runtime.SourceConnectorConfig.TOPIC_CREATION_PREFIX; +import static org.apache.kafka.connect.runtime.TopicCreationConfig.DEFAULT_TOPIC_CREATION_GROUP; +import static org.apache.kafka.connect.runtime.TopicCreationConfig.DEFAULT_TOPIC_CREATION_PREFIX; +import static org.apache.kafka.connect.runtime.TopicCreationConfig.EXCLUDE_REGEX_CONFIG; +import static org.apache.kafka.connect.runtime.TopicCreationConfig.INCLUDE_REGEX_CONFIG; +import static org.apache.kafka.connect.runtime.TopicCreationConfig.PARTITIONS_CONFIG; +import static org.apache.kafka.connect.runtime.TopicCreationConfig.REPLICATION_FACTOR_CONFIG; +import static org.apache.kafka.connect.runtime.WorkerConfig.BOOTSTRAP_SERVERS_CONFIG; +import static org.apache.kafka.connect.runtime.WorkerConfig.KEY_CONVERTER_CLASS_CONFIG; +import static org.apache.kafka.connect.runtime.WorkerConfig.TOPIC_CREATION_ENABLE_CONFIG; +import static org.apache.kafka.connect.runtime.WorkerConfig.VALUE_CONVERTER_CLASS_CONFIG; +import static org.apache.kafka.connect.runtime.distributed.DistributedConfig.CONFIG_TOPIC_CONFIG; +import static org.apache.kafka.connect.runtime.distributed.DistributedConfig.GROUP_ID_CONFIG; +import static org.apache.kafka.connect.runtime.distributed.DistributedConfig.OFFSET_STORAGE_TOPIC_CONFIG; +import static org.apache.kafka.connect.runtime.distributed.DistributedConfig.STATUS_STORAGE_TOPIC_CONFIG; +import static org.hamcrest.CoreMatchers.hasItem; +import static org.hamcrest.CoreMatchers.hasItems; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; + +public class TopicCreationTest { + + private static final String FOO_CONNECTOR = "foo-source"; + private static final String FOO_GROUP = "foo"; + private static final String FOO_TOPIC = "foo-topic"; + private static final String FOO_REGEX = ".*foo.*"; + + private static 
final String BAR_GROUP = "bar"; + private static final String BAR_TOPIC = "bar-topic"; + private static final String BAR_REGEX = ".*bar.*"; + + private static final short DEFAULT_REPLICATION_FACTOR = -1; + private static final int DEFAULT_PARTITIONS = -1; + + Map workerProps; + WorkerConfig workerConfig; + Map sourceProps; + SourceConnectorConfig sourceConfig; + + @Before + public void setup() { + workerProps = defaultWorkerProps(); + workerConfig = new DistributedConfig(workerProps); + } + + public Map defaultWorkerProps() { + Map props = new HashMap<>(); + props.put(GROUP_ID_CONFIG, "connect-cluster"); + props.put(BOOTSTRAP_SERVERS_CONFIG, "localhost:9092"); + props.put(CONFIG_TOPIC_CONFIG, "connect-configs"); + props.put(OFFSET_STORAGE_TOPIC_CONFIG, "connect-offsets"); + props.put(STATUS_STORAGE_TOPIC_CONFIG, "connect-status"); + props.put(KEY_CONVERTER_CLASS_CONFIG, StringConverter.class.getName()); + props.put(VALUE_CONVERTER_CLASS_CONFIG, StringConverter.class.getName()); + props.put(TOPIC_CREATION_ENABLE_CONFIG, String.valueOf(true)); + return props; + } + + public Map defaultConnectorProps() { + Map props = new HashMap<>(); + props.put(NAME_CONFIG, FOO_CONNECTOR); + props.put(CONNECTOR_CLASS_CONFIG, "TestConnector"); + return props; + } + + public Map defaultConnectorPropsWithTopicCreation() { + Map props = defaultConnectorProps(); + props.put(DEFAULT_TOPIC_CREATION_PREFIX + REPLICATION_FACTOR_CONFIG, String.valueOf(DEFAULT_REPLICATION_FACTOR)); + props.put(DEFAULT_TOPIC_CREATION_PREFIX + PARTITIONS_CONFIG, String.valueOf(DEFAULT_PARTITIONS)); + return props; + } + + @Test + public void testTopicCreationWhenTopicCreationIsEnabled() { + sourceProps = defaultConnectorPropsWithTopicCreation(); + sourceProps.put(TOPIC_CREATION_GROUPS_CONFIG, String.join(",", FOO_GROUP, BAR_GROUP)); + sourceConfig = new SourceConnectorConfig(MOCK_PLUGINS, sourceProps, true); + + Map groups = TopicCreationGroup.configuredGroups(sourceConfig); + TopicCreation topicCreation = TopicCreation.newTopicCreation(workerConfig, groups); + + assertTrue(topicCreation.isTopicCreationEnabled()); + assertTrue(topicCreation.isTopicCreationRequired(FOO_TOPIC)); + assertThat(topicCreation.defaultTopicGroup(), is(groups.get(DEFAULT_TOPIC_CREATION_GROUP))); + assertEquals(2, topicCreation.topicGroups().size()); + assertThat(topicCreation.topicGroups().keySet(), hasItems(FOO_GROUP, BAR_GROUP)); + assertEquals(topicCreation.defaultTopicGroup(), topicCreation.findFirstGroup(FOO_TOPIC)); + topicCreation.addTopic(FOO_TOPIC); + assertFalse(topicCreation.isTopicCreationRequired(FOO_TOPIC)); + } + + @Test + public void testTopicCreationWhenTopicCreationIsDisabled() { + workerProps.put(TOPIC_CREATION_ENABLE_CONFIG, String.valueOf(false)); + workerConfig = new DistributedConfig(workerProps); + sourceProps = defaultConnectorPropsWithTopicCreation(); + sourceConfig = new SourceConnectorConfig(MOCK_PLUGINS, sourceProps, true); + + TopicCreation topicCreation = TopicCreation.newTopicCreation(workerConfig, + TopicCreationGroup.configuredGroups(sourceConfig)); + + assertFalse(topicCreation.isTopicCreationEnabled()); + assertFalse(topicCreation.isTopicCreationRequired(FOO_TOPIC)); + assertNull(topicCreation.defaultTopicGroup()); + assertThat(topicCreation.topicGroups(), is(Collections.emptyMap())); + assertNull(topicCreation.findFirstGroup(FOO_TOPIC)); + topicCreation.addTopic(FOO_TOPIC); + assertFalse(topicCreation.isTopicCreationRequired(FOO_TOPIC)); + } + + @Test + public void testEmptyTopicCreation() { + TopicCreation topicCreation = 
TopicCreation.newTopicCreation(workerConfig, null); + + assertEquals(TopicCreation.empty(), topicCreation); + assertFalse(topicCreation.isTopicCreationEnabled()); + assertFalse(topicCreation.isTopicCreationRequired(FOO_TOPIC)); + assertNull(topicCreation.defaultTopicGroup()); + assertEquals(0, topicCreation.topicGroups().size()); + assertThat(topicCreation.topicGroups(), is(Collections.emptyMap())); + assertNull(topicCreation.findFirstGroup(FOO_TOPIC)); + topicCreation.addTopic(FOO_TOPIC); + assertFalse(topicCreation.isTopicCreationRequired(FOO_TOPIC)); + } + + @Test + public void withDefaultTopicCreation() { + sourceProps = defaultConnectorPropsWithTopicCreation(); + // Setting here but they should be ignored for the default group + sourceProps.put(TOPIC_CREATION_PREFIX + DEFAULT_TOPIC_CREATION_GROUP + "." + INCLUDE_REGEX_CONFIG, FOO_REGEX); + sourceProps.put(TOPIC_CREATION_PREFIX + DEFAULT_TOPIC_CREATION_GROUP + "." + EXCLUDE_REGEX_CONFIG, BAR_REGEX); + + // verify config creation + sourceConfig = new SourceConnectorConfig(MOCK_PLUGINS, sourceProps, true); + assertTrue(sourceConfig.usesTopicCreation()); + assertEquals(DEFAULT_REPLICATION_FACTOR, (short) sourceConfig.topicCreationReplicationFactor(DEFAULT_TOPIC_CREATION_GROUP)); + assertEquals(DEFAULT_PARTITIONS, (int) sourceConfig.topicCreationPartitions(DEFAULT_TOPIC_CREATION_GROUP)); + assertThat(sourceConfig.topicCreationInclude(DEFAULT_TOPIC_CREATION_GROUP), is(Collections.singletonList(".*"))); + assertThat(sourceConfig.topicCreationExclude(DEFAULT_TOPIC_CREATION_GROUP), is(Collections.emptyList())); + assertThat(sourceConfig.topicCreationOtherConfigs(DEFAULT_TOPIC_CREATION_GROUP), is(Collections.emptyMap())); + + // verify topic creation group is instantiated correctly + Map groups = TopicCreationGroup.configuredGroups(sourceConfig); + assertEquals(1, groups.size()); + assertThat(groups.keySet(), hasItem(DEFAULT_TOPIC_CREATION_GROUP)); + + // verify topic creation + TopicCreation topicCreation = TopicCreation.newTopicCreation(workerConfig, groups); + TopicCreationGroup group = topicCreation.defaultTopicGroup(); + // Default group will match all topics besides empty string + assertTrue(group.matches(" ")); + assertTrue(group.matches(FOO_TOPIC)); + assertEquals(DEFAULT_TOPIC_CREATION_GROUP, group.name()); + assertTrue(topicCreation.isTopicCreationEnabled()); + assertTrue(topicCreation.isTopicCreationRequired(FOO_TOPIC)); + assertThat(topicCreation.topicGroups(), is(Collections.emptyMap())); + assertEquals(topicCreation.defaultTopicGroup(), topicCreation.findFirstGroup(FOO_TOPIC)); + topicCreation.addTopic(FOO_TOPIC); + assertFalse(topicCreation.isTopicCreationRequired(FOO_TOPIC)); + + // verify new topic properties + NewTopic topicSpec = topicCreation.findFirstGroup(FOO_TOPIC).newTopic(FOO_TOPIC); + assertEquals(FOO_TOPIC, topicSpec.name()); + assertEquals(DEFAULT_REPLICATION_FACTOR, topicSpec.replicationFactor()); + assertEquals(DEFAULT_PARTITIONS, topicSpec.numPartitions()); + assertThat(topicSpec.configs(), is(Collections.emptyMap())); + } + + @Test + public void topicCreationWithDefaultGroupAndCustomProps() { + short replicas = 3; + int partitions = 5; + long retentionMs = TimeUnit.DAYS.toMillis(30); + String compressionType = "lz4"; + Map topicProps = new HashMap<>(); + topicProps.put(COMPRESSION_TYPE_CONFIG, compressionType); + topicProps.put(RETENTION_MS_CONFIG, String.valueOf(retentionMs)); + + sourceProps = defaultConnectorPropsWithTopicCreation(); + sourceProps.put(DEFAULT_TOPIC_CREATION_PREFIX + REPLICATION_FACTOR_CONFIG, 
String.valueOf(replicas)); + sourceProps.put(DEFAULT_TOPIC_CREATION_PREFIX + PARTITIONS_CONFIG, String.valueOf(partitions)); + topicProps.forEach((k, v) -> sourceProps.put(DEFAULT_TOPIC_CREATION_PREFIX + k, v)); + // Setting here but they should be ignored for the default group + sourceProps.put(TOPIC_CREATION_PREFIX + DEFAULT_TOPIC_CREATION_GROUP + "." + INCLUDE_REGEX_CONFIG, FOO_REGEX); + sourceProps.put(TOPIC_CREATION_PREFIX + DEFAULT_TOPIC_CREATION_GROUP + "." + EXCLUDE_REGEX_CONFIG, BAR_REGEX); + + // verify config creation + sourceConfig = new SourceConnectorConfig(MOCK_PLUGINS, sourceProps, true); + assertTrue(sourceConfig.usesTopicCreation()); + assertEquals(replicas, (short) sourceConfig.topicCreationReplicationFactor(DEFAULT_TOPIC_CREATION_GROUP)); + assertEquals(partitions, (int) sourceConfig.topicCreationPartitions(DEFAULT_TOPIC_CREATION_GROUP)); + assertThat(sourceConfig.topicCreationInclude(DEFAULT_TOPIC_CREATION_GROUP), is(Collections.singletonList(".*"))); + assertThat(sourceConfig.topicCreationExclude(DEFAULT_TOPIC_CREATION_GROUP), is(Collections.emptyList())); + assertThat(sourceConfig.topicCreationOtherConfigs(DEFAULT_TOPIC_CREATION_GROUP), is(topicProps)); + + // verify topic creation group is instantiated correctly + Map groups = TopicCreationGroup.configuredGroups(sourceConfig); + assertEquals(1, groups.size()); + assertThat(groups.keySet(), hasItem(DEFAULT_TOPIC_CREATION_GROUP)); + + // verify topic creation + TopicCreation topicCreation = TopicCreation.newTopicCreation(workerConfig, groups); + TopicCreationGroup group = topicCreation.defaultTopicGroup(); + // Default group will match all topics besides empty string + assertTrue(group.matches(" ")); + assertTrue(group.matches(FOO_TOPIC)); + assertEquals(DEFAULT_TOPIC_CREATION_GROUP, group.name()); + assertTrue(topicCreation.isTopicCreationEnabled()); + assertTrue(topicCreation.isTopicCreationRequired(FOO_TOPIC)); + assertThat(topicCreation.topicGroups(), is(Collections.emptyMap())); + assertEquals(topicCreation.defaultTopicGroup(), topicCreation.findFirstGroup(FOO_TOPIC)); + topicCreation.addTopic(FOO_TOPIC); + assertFalse(topicCreation.isTopicCreationRequired(FOO_TOPIC)); + + // verify new topic properties + NewTopic topicSpec = topicCreation.findFirstGroup(FOO_TOPIC).newTopic(FOO_TOPIC); + assertEquals(FOO_TOPIC, topicSpec.name()); + assertEquals(replicas, topicSpec.replicationFactor()); + assertEquals(partitions, topicSpec.numPartitions()); + assertThat(topicSpec.configs(), is(topicProps)); + } + + @Test + public void topicCreationWithOneGroup() { + short fooReplicas = 3; + int partitions = 5; + sourceProps = defaultConnectorPropsWithTopicCreation(); + sourceProps.put(TOPIC_CREATION_GROUPS_CONFIG, String.join(",", FOO_GROUP)); + sourceProps.put(DEFAULT_TOPIC_CREATION_PREFIX + PARTITIONS_CONFIG, String.valueOf(partitions)); + sourceProps.put(TOPIC_CREATION_PREFIX + FOO_GROUP + "." + INCLUDE_REGEX_CONFIG, FOO_REGEX); + sourceProps.put(TOPIC_CREATION_PREFIX + FOO_GROUP + "." + EXCLUDE_REGEX_CONFIG, BAR_REGEX); + sourceProps.put(TOPIC_CREATION_PREFIX + FOO_GROUP + "." + REPLICATION_FACTOR_CONFIG, String.valueOf(fooReplicas)); + + Map topicProps = new HashMap<>(); + topicProps.put(CLEANUP_POLICY_CONFIG, CLEANUP_POLICY_COMPACT); + topicProps.forEach((k, v) -> sourceProps.put(TOPIC_CREATION_PREFIX + FOO_GROUP + "." 
+ k, v)); + + // verify config creation + sourceConfig = new SourceConnectorConfig(MOCK_PLUGINS, sourceProps, true); + assertTrue(sourceConfig.usesTopicCreation()); + assertEquals(DEFAULT_REPLICATION_FACTOR, (short) sourceConfig.topicCreationReplicationFactor(DEFAULT_TOPIC_CREATION_GROUP)); + assertEquals(partitions, (int) sourceConfig.topicCreationPartitions(DEFAULT_TOPIC_CREATION_GROUP)); + assertThat(sourceConfig.topicCreationInclude(DEFAULT_TOPIC_CREATION_GROUP), is(Collections.singletonList(".*"))); + assertThat(sourceConfig.topicCreationExclude(DEFAULT_TOPIC_CREATION_GROUP), is(Collections.emptyList())); + assertThat(sourceConfig.topicCreationOtherConfigs(DEFAULT_TOPIC_CREATION_GROUP), is(Collections.emptyMap())); + + // verify topic creation group is instantiated correctly + Map groups = TopicCreationGroup.configuredGroups(sourceConfig); + assertEquals(2, groups.size()); + assertThat(groups.keySet(), hasItems(DEFAULT_TOPIC_CREATION_GROUP, FOO_GROUP)); + + // verify topic creation + TopicCreation topicCreation = TopicCreation.newTopicCreation(workerConfig, groups); + TopicCreationGroup defaultGroup = topicCreation.defaultTopicGroup(); + // Default group will match all topics besides empty string + assertTrue(defaultGroup.matches(" ")); + assertTrue(defaultGroup.matches(FOO_TOPIC)); + assertTrue(defaultGroup.matches(BAR_TOPIC)); + assertEquals(DEFAULT_TOPIC_CREATION_GROUP, defaultGroup.name()); + TopicCreationGroup fooGroup = groups.get(FOO_GROUP); + assertFalse(fooGroup.matches(" ")); + assertTrue(fooGroup.matches(FOO_TOPIC)); + assertFalse(fooGroup.matches(BAR_TOPIC)); + assertEquals(FOO_GROUP, fooGroup.name()); + + assertTrue(topicCreation.isTopicCreationEnabled()); + assertTrue(topicCreation.isTopicCreationRequired(FOO_TOPIC)); + assertEquals(1, topicCreation.topicGroups().size()); + assertThat(topicCreation.topicGroups().keySet(), hasItems(FOO_GROUP)); + assertEquals(fooGroup, topicCreation.findFirstGroup(FOO_TOPIC)); + topicCreation.addTopic(FOO_TOPIC); + assertFalse(topicCreation.isTopicCreationRequired(FOO_TOPIC)); + + // verify new topic properties + NewTopic defaultTopicSpec = topicCreation.findFirstGroup(BAR_TOPIC).newTopic(BAR_TOPIC); + assertEquals(BAR_TOPIC, defaultTopicSpec.name()); + assertEquals(DEFAULT_REPLICATION_FACTOR, defaultTopicSpec.replicationFactor()); + assertEquals(partitions, defaultTopicSpec.numPartitions()); + assertThat(defaultTopicSpec.configs(), is(Collections.emptyMap())); + + NewTopic fooTopicSpec = topicCreation.findFirstGroup(FOO_TOPIC).newTopic(FOO_TOPIC); + assertEquals(FOO_TOPIC, fooTopicSpec.name()); + assertEquals(fooReplicas, fooTopicSpec.replicationFactor()); + assertEquals(partitions, fooTopicSpec.numPartitions()); + assertThat(fooTopicSpec.configs(), is(topicProps)); + } + + @Test + public void topicCreationWithOneGroupAndCombinedRegex() { + short fooReplicas = 3; + int partitions = 5; + sourceProps = defaultConnectorPropsWithTopicCreation(); + sourceProps.put(TOPIC_CREATION_GROUPS_CONFIG, String.join(",", FOO_GROUP)); + sourceProps.put(DEFAULT_TOPIC_CREATION_PREFIX + PARTITIONS_CONFIG, String.valueOf(partitions)); + // Setting here but they should be ignored for the default group + sourceProps.put(TOPIC_CREATION_PREFIX + FOO_GROUP + "." + INCLUDE_REGEX_CONFIG, String.join("|", FOO_REGEX, BAR_REGEX)); + sourceProps.put(TOPIC_CREATION_PREFIX + FOO_GROUP + "." 
+ REPLICATION_FACTOR_CONFIG, String.valueOf(fooReplicas)); + + Map topicProps = new HashMap<>(); + topicProps.put(CLEANUP_POLICY_CONFIG, CLEANUP_POLICY_COMPACT); + topicProps.forEach((k, v) -> sourceProps.put(TOPIC_CREATION_PREFIX + FOO_GROUP + "." + k, v)); + + // verify config creation + sourceConfig = new SourceConnectorConfig(MOCK_PLUGINS, sourceProps, true); + assertTrue(sourceConfig.usesTopicCreation()); + assertEquals(DEFAULT_REPLICATION_FACTOR, (short) sourceConfig.topicCreationReplicationFactor(DEFAULT_TOPIC_CREATION_GROUP)); + assertEquals(partitions, (int) sourceConfig.topicCreationPartitions(DEFAULT_TOPIC_CREATION_GROUP)); + assertThat(sourceConfig.topicCreationInclude(DEFAULT_TOPIC_CREATION_GROUP), is(Collections.singletonList(".*"))); + assertThat(sourceConfig.topicCreationExclude(DEFAULT_TOPIC_CREATION_GROUP), is(Collections.emptyList())); + assertThat(sourceConfig.topicCreationOtherConfigs(DEFAULT_TOPIC_CREATION_GROUP), is(Collections.emptyMap())); + + // verify topic creation group is instantiated correctly + Map groups = TopicCreationGroup.configuredGroups(sourceConfig); + assertEquals(2, groups.size()); + assertThat(groups.keySet(), hasItems(DEFAULT_TOPIC_CREATION_GROUP, FOO_GROUP)); + + // verify topic creation + TopicCreation topicCreation = TopicCreation.newTopicCreation(workerConfig, groups); + TopicCreationGroup defaultGroup = topicCreation.defaultTopicGroup(); + // Default group will match all topics besides empty string + assertTrue(defaultGroup.matches(" ")); + assertTrue(defaultGroup.matches(FOO_TOPIC)); + assertTrue(defaultGroup.matches(BAR_TOPIC)); + assertEquals(DEFAULT_TOPIC_CREATION_GROUP, defaultGroup.name()); + TopicCreationGroup fooGroup = groups.get(FOO_GROUP); + assertFalse(fooGroup.matches(" ")); + assertTrue(fooGroup.matches(FOO_TOPIC)); + assertTrue(fooGroup.matches(BAR_TOPIC)); + assertEquals(FOO_GROUP, fooGroup.name()); + + assertTrue(topicCreation.isTopicCreationEnabled()); + assertTrue(topicCreation.isTopicCreationRequired(FOO_TOPIC)); + assertTrue(topicCreation.isTopicCreationRequired(BAR_TOPIC)); + assertEquals(1, topicCreation.topicGroups().size()); + assertThat(topicCreation.topicGroups().keySet(), hasItems(FOO_GROUP)); + assertEquals(fooGroup, topicCreation.findFirstGroup(FOO_TOPIC)); + assertEquals(fooGroup, topicCreation.findFirstGroup(BAR_TOPIC)); + topicCreation.addTopic(FOO_TOPIC); + topicCreation.addTopic(BAR_TOPIC); + assertFalse(topicCreation.isTopicCreationRequired(FOO_TOPIC)); + assertFalse(topicCreation.isTopicCreationRequired(BAR_TOPIC)); + + // verify new topic properties + NewTopic fooTopicSpec = topicCreation.findFirstGroup(FOO_TOPIC).newTopic(FOO_TOPIC); + assertEquals(FOO_TOPIC, fooTopicSpec.name()); + assertEquals(fooReplicas, fooTopicSpec.replicationFactor()); + assertEquals(partitions, fooTopicSpec.numPartitions()); + assertThat(fooTopicSpec.configs(), is(topicProps)); + + NewTopic barTopicSpec = topicCreation.findFirstGroup(BAR_TOPIC).newTopic(BAR_TOPIC); + assertEquals(BAR_TOPIC, barTopicSpec.name()); + assertEquals(fooReplicas, barTopicSpec.replicationFactor()); + assertEquals(partitions, barTopicSpec.numPartitions()); + assertThat(barTopicSpec.configs(), is(topicProps)); + } + + @Test + public void topicCreationWithTwoGroups() { + short fooReplicas = 3; + int partitions = 5; + int barPartitions = 1; + + sourceProps = defaultConnectorPropsWithTopicCreation(); + sourceProps.put(TOPIC_CREATION_GROUPS_CONFIG, String.join(",", FOO_GROUP, BAR_GROUP)); + sourceProps.put(DEFAULT_TOPIC_CREATION_PREFIX + PARTITIONS_CONFIG, 
String.valueOf(partitions)); + // Setting here but they should be ignored for the default group + sourceProps.put(TOPIC_CREATION_PREFIX + FOO_GROUP + "." + INCLUDE_REGEX_CONFIG, FOO_TOPIC); + sourceProps.put(TOPIC_CREATION_PREFIX + FOO_GROUP + "." + REPLICATION_FACTOR_CONFIG, String.valueOf(fooReplicas)); + sourceProps.put(TOPIC_CREATION_PREFIX + BAR_GROUP + "." + INCLUDE_REGEX_CONFIG, BAR_REGEX); + sourceProps.put(TOPIC_CREATION_PREFIX + BAR_GROUP + "." + PARTITIONS_CONFIG, String.valueOf(barPartitions)); + + Map fooTopicProps = new HashMap<>(); + fooTopicProps.put(RETENTION_MS_CONFIG, String.valueOf(TimeUnit.DAYS.toMillis(30))); + fooTopicProps.forEach((k, v) -> sourceProps.put(TOPIC_CREATION_PREFIX + FOO_GROUP + "." + k, v)); + + Map barTopicProps = new HashMap<>(); + barTopicProps.put(CLEANUP_POLICY_CONFIG, CLEANUP_POLICY_COMPACT); + barTopicProps.forEach((k, v) -> sourceProps.put(TOPIC_CREATION_PREFIX + BAR_GROUP + "." + k, v)); + + // verify config creation + sourceConfig = new SourceConnectorConfig(MOCK_PLUGINS, sourceProps, true); + assertTrue(sourceConfig.usesTopicCreation()); + assertEquals(DEFAULT_REPLICATION_FACTOR, (short) sourceConfig.topicCreationReplicationFactor(DEFAULT_TOPIC_CREATION_GROUP)); + assertEquals(partitions, (int) sourceConfig.topicCreationPartitions(DEFAULT_TOPIC_CREATION_GROUP)); + assertThat(sourceConfig.topicCreationInclude(DEFAULT_TOPIC_CREATION_GROUP), is(Collections.singletonList(".*"))); + assertThat(sourceConfig.topicCreationExclude(DEFAULT_TOPIC_CREATION_GROUP), is(Collections.emptyList())); + assertThat(sourceConfig.topicCreationOtherConfigs(DEFAULT_TOPIC_CREATION_GROUP), is(Collections.emptyMap())); + + // verify topic creation group is instantiated correctly + Map groups = TopicCreationGroup.configuredGroups(sourceConfig); + assertEquals(3, groups.size()); + assertThat(groups.keySet(), hasItems(DEFAULT_TOPIC_CREATION_GROUP, FOO_GROUP, BAR_GROUP)); + + // verify topic creation + TopicCreation topicCreation = TopicCreation.newTopicCreation(workerConfig, groups); + TopicCreationGroup defaultGroup = topicCreation.defaultTopicGroup(); + // Default group will match all topics besides empty string + assertTrue(defaultGroup.matches(" ")); + assertTrue(defaultGroup.matches(FOO_TOPIC)); + assertTrue(defaultGroup.matches(BAR_TOPIC)); + assertEquals(DEFAULT_TOPIC_CREATION_GROUP, defaultGroup.name()); + TopicCreationGroup fooGroup = groups.get(FOO_GROUP); + assertFalse(fooGroup.matches(" ")); + assertTrue(fooGroup.matches(FOO_TOPIC)); + assertFalse(fooGroup.matches(BAR_TOPIC)); + assertEquals(FOO_GROUP, fooGroup.name()); + TopicCreationGroup barGroup = groups.get(BAR_GROUP); + assertTrue(barGroup.matches(BAR_TOPIC)); + assertFalse(barGroup.matches(FOO_TOPIC)); + assertEquals(BAR_GROUP, barGroup.name()); + + assertTrue(topicCreation.isTopicCreationEnabled()); + assertTrue(topicCreation.isTopicCreationRequired(FOO_TOPIC)); + assertTrue(topicCreation.isTopicCreationRequired(BAR_TOPIC)); + assertEquals(2, topicCreation.topicGroups().size()); + assertThat(topicCreation.topicGroups().keySet(), hasItems(FOO_GROUP, BAR_GROUP)); + assertEquals(fooGroup, topicCreation.findFirstGroup(FOO_TOPIC)); + assertEquals(barGroup, topicCreation.findFirstGroup(BAR_TOPIC)); + topicCreation.addTopic(FOO_TOPIC); + topicCreation.addTopic(BAR_TOPIC); + assertFalse(topicCreation.isTopicCreationRequired(FOO_TOPIC)); + assertFalse(topicCreation.isTopicCreationRequired(BAR_TOPIC)); + + // verify new topic properties + String otherTopic = "any-other-topic"; + NewTopic defaultTopicSpec = 
topicCreation.findFirstGroup(otherTopic).newTopic(otherTopic); + assertEquals(otherTopic, defaultTopicSpec.name()); + assertEquals(DEFAULT_REPLICATION_FACTOR, defaultTopicSpec.replicationFactor()); + assertEquals(partitions, defaultTopicSpec.numPartitions()); + assertThat(defaultTopicSpec.configs(), is(Collections.emptyMap())); + + NewTopic fooTopicSpec = topicCreation.findFirstGroup(FOO_TOPIC).newTopic(FOO_TOPIC); + assertEquals(FOO_TOPIC, fooTopicSpec.name()); + assertEquals(fooReplicas, fooTopicSpec.replicationFactor()); + assertEquals(partitions, fooTopicSpec.numPartitions()); + assertThat(fooTopicSpec.configs(), is(fooTopicProps)); + + NewTopic barTopicSpec = topicCreation.findFirstGroup(BAR_TOPIC).newTopic(BAR_TOPIC); + assertEquals(BAR_TOPIC, barTopicSpec.name()); + assertEquals(DEFAULT_REPLICATION_FACTOR, barTopicSpec.replicationFactor()); + assertEquals(barPartitions, barTopicSpec.numPartitions()); + assertThat(barTopicSpec.configs(), is(barTopicProps)); + } + + @Test + public void testTopicCreationWithSingleTransformation() { + sourceProps = defaultConnectorPropsWithTopicCreation(); + sourceProps.put(TOPIC_CREATION_GROUPS_CONFIG, String.join(",", FOO_GROUP, BAR_GROUP)); + String xformName = "example"; + String castType = "int8"; + sourceProps.put("transforms", xformName); + sourceProps.put("transforms." + xformName + ".type", Cast.Value.class.getName()); + sourceProps.put("transforms." + xformName + ".spec", castType); + + sourceConfig = new SourceConnectorConfig(MOCK_PLUGINS, sourceProps, true); + + Map groups = TopicCreationGroup.configuredGroups(sourceConfig); + TopicCreation topicCreation = TopicCreation.newTopicCreation(workerConfig, groups); + + assertTrue(topicCreation.isTopicCreationEnabled()); + assertTrue(topicCreation.isTopicCreationRequired(FOO_TOPIC)); + assertThat(topicCreation.defaultTopicGroup(), is(groups.get(DEFAULT_TOPIC_CREATION_GROUP))); + assertEquals(2, topicCreation.topicGroups().size()); + assertThat(topicCreation.topicGroups().keySet(), hasItems(FOO_GROUP, BAR_GROUP)); + assertEquals(topicCreation.defaultTopicGroup(), topicCreation.findFirstGroup(FOO_TOPIC)); + topicCreation.addTopic(FOO_TOPIC); + assertFalse(topicCreation.isTopicCreationRequired(FOO_TOPIC)); + + List> transformations = sourceConfig.transformations(); + assertEquals(1, transformations.size()); + Cast xform = (Cast) transformations.get(0); + SourceRecord transformed = xform.apply(new SourceRecord(null, null, "topic", 0, null, null, Schema.INT8_SCHEMA, 42)); + assertEquals(Schema.Type.INT8, transformed.valueSchema().type()); + assertEquals((byte) 42, transformed.value()); + } + + @Test + public void topicCreationWithTwoGroupsAndTwoTransformations() { + short fooReplicas = 3; + int partitions = 5; + int barPartitions = 1; + + sourceProps = defaultConnectorPropsWithTopicCreation(); + sourceProps.put(TOPIC_CREATION_GROUPS_CONFIG, String.join(",", FOO_GROUP, BAR_GROUP)); + sourceProps.put(DEFAULT_TOPIC_CREATION_PREFIX + PARTITIONS_CONFIG, String.valueOf(partitions)); + // Setting here but they should be ignored for the default group + sourceProps.put(TOPIC_CREATION_PREFIX + FOO_GROUP + "." + INCLUDE_REGEX_CONFIG, FOO_TOPIC); + sourceProps.put(TOPIC_CREATION_PREFIX + FOO_GROUP + "." + REPLICATION_FACTOR_CONFIG, String.valueOf(fooReplicas)); + sourceProps.put(TOPIC_CREATION_PREFIX + BAR_GROUP + "." + INCLUDE_REGEX_CONFIG, BAR_REGEX); + sourceProps.put(TOPIC_CREATION_PREFIX + BAR_GROUP + "." 
+ PARTITIONS_CONFIG, String.valueOf(barPartitions)); + + String castName = "cast"; + String castType = "int8"; + sourceProps.put("transforms." + castName + ".type", Cast.Value.class.getName()); + sourceProps.put("transforms." + castName + ".spec", castType); + + String regexRouterName = "regex"; + sourceProps.put("transforms." + regexRouterName + ".type", RegexRouter.class.getName()); + sourceProps.put("transforms." + regexRouterName + ".regex", "(.*)"); + sourceProps.put("transforms." + regexRouterName + ".replacement", "prefix-$1"); + + sourceProps.put("transforms", String.join(",", castName, regexRouterName)); + + Map fooTopicProps = new HashMap<>(); + fooTopicProps.put(RETENTION_MS_CONFIG, String.valueOf(TimeUnit.DAYS.toMillis(30))); + fooTopicProps.forEach((k, v) -> sourceProps.put(TOPIC_CREATION_PREFIX + FOO_GROUP + "." + k, v)); + + Map barTopicProps = new HashMap<>(); + barTopicProps.put(CLEANUP_POLICY_CONFIG, CLEANUP_POLICY_COMPACT); + barTopicProps.forEach((k, v) -> sourceProps.put(TOPIC_CREATION_PREFIX + BAR_GROUP + "." + k, v)); + + // verify config creation + sourceConfig = new SourceConnectorConfig(MOCK_PLUGINS, sourceProps, true); + assertTrue(sourceConfig.usesTopicCreation()); + assertEquals(DEFAULT_REPLICATION_FACTOR, (short) sourceConfig.topicCreationReplicationFactor(DEFAULT_TOPIC_CREATION_GROUP)); + assertEquals(partitions, (int) sourceConfig.topicCreationPartitions(DEFAULT_TOPIC_CREATION_GROUP)); + assertThat(sourceConfig.topicCreationInclude(DEFAULT_TOPIC_CREATION_GROUP), is(Collections.singletonList(".*"))); + assertThat(sourceConfig.topicCreationExclude(DEFAULT_TOPIC_CREATION_GROUP), is(Collections.emptyList())); + assertThat(sourceConfig.topicCreationOtherConfigs(DEFAULT_TOPIC_CREATION_GROUP), is(Collections.emptyMap())); + + // verify topic creation group is instantiated correctly + Map groups = TopicCreationGroup.configuredGroups(sourceConfig); + assertEquals(3, groups.size()); + assertThat(groups.keySet(), hasItems(DEFAULT_TOPIC_CREATION_GROUP, FOO_GROUP, BAR_GROUP)); + + // verify topic creation + TopicCreation topicCreation = TopicCreation.newTopicCreation(workerConfig, groups); + TopicCreationGroup defaultGroup = topicCreation.defaultTopicGroup(); + // Default group will match all topics besides empty string + assertTrue(defaultGroup.matches(" ")); + assertTrue(defaultGroup.matches(FOO_TOPIC)); + assertTrue(defaultGroup.matches(BAR_TOPIC)); + assertEquals(DEFAULT_TOPIC_CREATION_GROUP, defaultGroup.name()); + TopicCreationGroup fooGroup = groups.get(FOO_GROUP); + assertFalse(fooGroup.matches(" ")); + assertTrue(fooGroup.matches(FOO_TOPIC)); + assertFalse(fooGroup.matches(BAR_TOPIC)); + assertEquals(FOO_GROUP, fooGroup.name()); + TopicCreationGroup barGroup = groups.get(BAR_GROUP); + assertTrue(barGroup.matches(BAR_TOPIC)); + assertFalse(barGroup.matches(FOO_TOPIC)); + assertEquals(BAR_GROUP, barGroup.name()); + + assertTrue(topicCreation.isTopicCreationEnabled()); + assertTrue(topicCreation.isTopicCreationRequired(FOO_TOPIC)); + assertTrue(topicCreation.isTopicCreationRequired(BAR_TOPIC)); + assertEquals(2, topicCreation.topicGroups().size()); + assertThat(topicCreation.topicGroups().keySet(), hasItems(FOO_GROUP, BAR_GROUP)); + assertEquals(fooGroup, topicCreation.findFirstGroup(FOO_TOPIC)); + assertEquals(barGroup, topicCreation.findFirstGroup(BAR_TOPIC)); + topicCreation.addTopic(FOO_TOPIC); + topicCreation.addTopic(BAR_TOPIC); + assertFalse(topicCreation.isTopicCreationRequired(FOO_TOPIC)); + assertFalse(topicCreation.isTopicCreationRequired(BAR_TOPIC)); + 
+ // verify new topic properties + String otherTopic = "any-other-topic"; + NewTopic defaultTopicSpec = topicCreation.findFirstGroup(otherTopic).newTopic(otherTopic); + assertEquals(otherTopic, defaultTopicSpec.name()); + assertEquals(DEFAULT_REPLICATION_FACTOR, defaultTopicSpec.replicationFactor()); + assertEquals(partitions, defaultTopicSpec.numPartitions()); + assertThat(defaultTopicSpec.configs(), is(Collections.emptyMap())); + + NewTopic fooTopicSpec = topicCreation.findFirstGroup(FOO_TOPIC).newTopic(FOO_TOPIC); + assertEquals(FOO_TOPIC, fooTopicSpec.name()); + assertEquals(fooReplicas, fooTopicSpec.replicationFactor()); + assertEquals(partitions, fooTopicSpec.numPartitions()); + assertThat(fooTopicSpec.configs(), is(fooTopicProps)); + + NewTopic barTopicSpec = topicCreation.findFirstGroup(BAR_TOPIC).newTopic(BAR_TOPIC); + assertEquals(BAR_TOPIC, barTopicSpec.name()); + assertEquals(DEFAULT_REPLICATION_FACTOR, barTopicSpec.replicationFactor()); + assertEquals(barPartitions, barTopicSpec.numPartitions()); + assertThat(barTopicSpec.configs(), is(barTopicProps)); + + List> transformations = sourceConfig.transformations(); + assertEquals(2, transformations.size()); + + Cast castXForm = (Cast) transformations.get(0); + SourceRecord transformed = castXForm.apply(new SourceRecord(null, null, "topic", 0, null, null, Schema.INT8_SCHEMA, 42)); + assertEquals(Schema.Type.INT8, transformed.valueSchema().type()); + assertEquals((byte) 42, transformed.value()); + + RegexRouter regexRouterXForm = (RegexRouter) transformations.get(1); + transformed = regexRouterXForm.apply(new SourceRecord(null, null, "topic", 0, null, null, Schema.INT8_SCHEMA, 42)); + assertEquals("prefix-topic", transformed.topic()); + } +} diff --git a/connect/runtime/src/test/java/org/apache/kafka/connect/util/clusters/EmbeddedConnectCluster.java b/connect/runtime/src/test/java/org/apache/kafka/connect/util/clusters/EmbeddedConnectCluster.java new file mode 100644 index 0000000..adcde37 --- /dev/null +++ b/connect/runtime/src/test/java/org/apache/kafka/connect/util/clusters/EmbeddedConnectCluster.java @@ -0,0 +1,858 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.connect.util.clusters; + +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; +import org.apache.kafka.common.utils.Exit; +import org.apache.kafka.connect.errors.ConnectException; +import org.apache.kafka.connect.runtime.isolation.Plugins; +import org.apache.kafka.connect.runtime.rest.entities.ActiveTopicsInfo; +import org.apache.kafka.connect.runtime.rest.entities.ConfigInfos; +import org.apache.kafka.connect.runtime.rest.entities.ConnectorStateInfo; +import org.apache.kafka.connect.runtime.rest.entities.ServerInfo; +import org.apache.kafka.connect.runtime.rest.errors.ConnectRestException; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.ws.rs.core.Response; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStreamWriter; +import java.net.HttpURLConnection; +import java.net.URL; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Iterator; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Properties; +import java.util.Set; +import java.util.UUID; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.stream.Collectors; + +import static org.apache.kafka.clients.consumer.ConsumerConfig.GROUP_ID_CONFIG; +import static org.apache.kafka.connect.runtime.ConnectorConfig.KEY_CONVERTER_CLASS_CONFIG; +import static org.apache.kafka.connect.runtime.ConnectorConfig.VALUE_CONVERTER_CLASS_CONFIG; +import static org.apache.kafka.connect.runtime.WorkerConfig.BOOTSTRAP_SERVERS_CONFIG; +import static org.apache.kafka.connect.runtime.WorkerConfig.LISTENERS_CONFIG; +import static org.apache.kafka.connect.runtime.distributed.DistributedConfig.CONFIG_STORAGE_REPLICATION_FACTOR_CONFIG; +import static org.apache.kafka.connect.runtime.distributed.DistributedConfig.CONFIG_TOPIC_CONFIG; +import static org.apache.kafka.connect.runtime.distributed.DistributedConfig.OFFSET_STORAGE_REPLICATION_FACTOR_CONFIG; +import static org.apache.kafka.connect.runtime.distributed.DistributedConfig.OFFSET_STORAGE_TOPIC_CONFIG; +import static org.apache.kafka.connect.runtime.distributed.DistributedConfig.STATUS_STORAGE_REPLICATION_FACTOR_CONFIG; +import static org.apache.kafka.connect.runtime.distributed.DistributedConfig.STATUS_STORAGE_TOPIC_CONFIG; + +/** + * Start an embedded Connect worker cluster. Internally, this class will spin up a Kafka and Zk cluster, set up any tmp + * directories, and clean them up when the cluster is stopped. Methods on the same {@code EmbeddedConnectCluster} are + * not guaranteed to be thread-safe.
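+ * + * <p>A minimal usage sketch (illustrative only; the connector name and the {@code connectorProps} configuration map are hypothetical): + * <pre>{@code + * EmbeddedConnectCluster connect = new EmbeddedConnectCluster.Builder() + *     .name("example-connect-cluster") + *     .numWorkers(3) + *     .build(); + * connect.start(); + * connect.assertions().assertAtLeastNumWorkersAreUp(3, "Workers did not start in time"); + * connect.configureConnector("example-connector", connectorProps); + * // ... exercise the connector ... + * connect.stop(); + * }</pre>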
+ */ +public class EmbeddedConnectCluster { + + private static final Logger log = LoggerFactory.getLogger(EmbeddedConnectCluster.class); + + public static final int DEFAULT_NUM_BROKERS = 1; + public static final int DEFAULT_NUM_WORKERS = 1; + private static final Properties DEFAULT_BROKER_CONFIG = new Properties(); + private static final String REST_HOST_NAME = "localhost"; + + private static final String DEFAULT_WORKER_NAME_PREFIX = "connect-worker-"; + + private final Set<WorkerHandle> connectCluster; + private final EmbeddedKafkaCluster kafkaCluster; + private final Map<String, String> workerProps; + private final String connectClusterName; + private final int numBrokers; + private final int numInitialWorkers; + private final boolean maskExitProcedures; + private final String workerNamePrefix; + private final AtomicInteger nextWorkerId = new AtomicInteger(0); + private final EmbeddedConnectClusterAssertions assertions; + // Keep the original class loader and restore it after the connector stops, since the connector changes the context class loader; + // otherwise Mockito would use an unexpected class loader to generate the proxy instance, which makes the mocks fail + private final ClassLoader originalClassLoader = Thread.currentThread().getContextClassLoader(); + + private EmbeddedConnectCluster(String name, Map<String, String> workerProps, int numWorkers, + int numBrokers, Properties brokerProps, + boolean maskExitProcedures) { + this.workerProps = workerProps; + this.connectClusterName = name; + this.numBrokers = numBrokers; + this.kafkaCluster = new EmbeddedKafkaCluster(numBrokers, brokerProps); + this.connectCluster = new LinkedHashSet<>(); + this.numInitialWorkers = numWorkers; + this.maskExitProcedures = maskExitProcedures; + // leaving non-configurable for now + this.workerNamePrefix = DEFAULT_WORKER_NAME_PREFIX; + this.assertions = new EmbeddedConnectClusterAssertions(this); + } + + /** + * A more graceful way to handle abnormal exit of services in integration tests. + */ + public Exit.Procedure exitProcedure = (code, message) -> { + if (code != 0) { + String exitMessage = "Abrupt service exit with code " + code + " and message " + message; + log.warn(exitMessage); + throw new UngracefulShutdownException(exitMessage); + } + }; + + /** + * A more graceful way to handle abnormal halt of services in integration tests. + */ + public Exit.Procedure haltProcedure = (code, message) -> { + if (code != 0) { + String haltMessage = "Abrupt service halt with code " + code + " and message " + message; + log.warn(haltMessage); + throw new UngracefulShutdownException(haltMessage); + } + }; + + /** + * Start the connect cluster and the embedded Kafka and Zookeeper cluster. + */ + public void start() { + if (maskExitProcedures) { + Exit.setExitProcedure(exitProcedure); + Exit.setHaltProcedure(haltProcedure); + } + kafkaCluster.start(); + startConnect(); + } + + /** + * Stop the connect cluster and the embedded Kafka and Zookeeper cluster. + * Clean up any temp directories created locally.
+ * + * @throws RuntimeException if Kafka brokers fail to stop + */ + public void stop() { + connectCluster.forEach(this::stopWorker); + try { + kafkaCluster.stop(); + } catch (UngracefulShutdownException e) { + log.warn("Kafka did not shutdown gracefully"); + } catch (Exception e) { + log.error("Could not stop kafka", e); + throw new RuntimeException("Could not stop brokers", e); + } finally { + if (maskExitProcedures) { + Exit.resetExitProcedure(); + Exit.resetHaltProcedure(); + } + Plugins.compareAndSwapLoaders(originalClassLoader); + } + } + + /** + * Provision and start an additional worker to the Connect cluster. + * + * @return the worker handle of the worker that was provisioned + */ + public WorkerHandle addWorker() { + WorkerHandle worker = WorkerHandle.start(workerNamePrefix + nextWorkerId.getAndIncrement(), workerProps); + connectCluster.add(worker); + log.info("Started worker {}", worker); + return worker; + } + + /** + * Decommission one of the workers from this Connect cluster. Which worker is removed is + * implementation dependent and selection is not guaranteed to be consistent. Use this method + * when you don't care which worker stops. + * + * @see #removeWorker(WorkerHandle) + */ + public void removeWorker() { + WorkerHandle toRemove = null; + for (Iterator it = connectCluster.iterator(); it.hasNext(); toRemove = it.next()) { + } + if (toRemove != null) { + removeWorker(toRemove); + } + } + + /** + * Decommission a specific worker from this Connect cluster. + * + * @param worker the handle of the worker to remove from the cluster + * @throws IllegalStateException if the Connect cluster has no workers + */ + public void removeWorker(WorkerHandle worker) { + if (connectCluster.isEmpty()) { + throw new IllegalStateException("Cannot remove worker. Cluster is empty"); + } + stopWorker(worker); + connectCluster.remove(worker); + } + + private void stopWorker(WorkerHandle worker) { + try { + log.info("Stopping worker {}", worker); + worker.stop(); + } catch (UngracefulShutdownException e) { + log.warn("Worker {} did not shutdown gracefully", worker); + } catch (Exception e) { + log.error("Could not stop connect", e); + throw new RuntimeException("Could not stop worker", e); + } + } + + /** + * Determine whether the Connect cluster has any workers running. + * + * @return true if any worker is running, or false otherwise + */ + public boolean anyWorkersRunning() { + return workers().stream().anyMatch(WorkerHandle::isRunning); + } + + /** + * Determine whether the Connect cluster has all workers running. 
+ * + * @return true if all workers are running, or false otherwise + */ + public boolean allWorkersRunning() { + return workers().stream().allMatch(WorkerHandle::isRunning); + } + + public void startConnect() { + log.info("Starting Connect cluster '{}' with {} workers", connectClusterName, numInitialWorkers); + + workerProps.put(BOOTSTRAP_SERVERS_CONFIG, kafka().bootstrapServers()); + // use a random available port + workerProps.put(LISTENERS_CONFIG, "HTTP://" + REST_HOST_NAME + ":0"); + + String internalTopicsReplFactor = String.valueOf(numBrokers); + putIfAbsent(workerProps, GROUP_ID_CONFIG, "connect-integration-test-" + connectClusterName); + putIfAbsent(workerProps, OFFSET_STORAGE_TOPIC_CONFIG, "connect-offset-topic-" + connectClusterName); + putIfAbsent(workerProps, OFFSET_STORAGE_REPLICATION_FACTOR_CONFIG, internalTopicsReplFactor); + putIfAbsent(workerProps, CONFIG_TOPIC_CONFIG, "connect-config-topic-" + connectClusterName); + putIfAbsent(workerProps, CONFIG_STORAGE_REPLICATION_FACTOR_CONFIG, internalTopicsReplFactor); + putIfAbsent(workerProps, STATUS_STORAGE_TOPIC_CONFIG, "connect-storage-topic-" + connectClusterName); + putIfAbsent(workerProps, STATUS_STORAGE_REPLICATION_FACTOR_CONFIG, internalTopicsReplFactor); + putIfAbsent(workerProps, KEY_CONVERTER_CLASS_CONFIG, "org.apache.kafka.connect.storage.StringConverter"); + putIfAbsent(workerProps, VALUE_CONVERTER_CLASS_CONFIG, "org.apache.kafka.connect.storage.StringConverter"); + + for (int i = 0; i < numInitialWorkers; i++) { + addWorker(); + } + } + + @Override + public String toString() { + return String.format("EmbeddedConnectCluster(name= %s, numBrokers= %d, numInitialWorkers= %d, workerProps= %s)", + connectClusterName, + numBrokers, + numInitialWorkers, + workerProps); + } + + public String getName() { + return connectClusterName; + } + + /** + * Get the workers that are up and running. + * + * @return the list of handles of the online workers + */ + public Set activeWorkers() { + ObjectMapper mapper = new ObjectMapper(); + return connectCluster.stream() + .filter(w -> { + try { + mapper.readerFor(ServerInfo.class) + .readValue(responseToString(requestGet(w.url().toString()))); + return true; + } catch (ConnectException | IOException e) { + // Worker failed to respond. Consider it's offline + return false; + } + }) + .collect(Collectors.toSet()); + } + + /** + * Get the provisioned workers. + * + * @return the list of handles of the provisioned workers + */ + public Set workers() { + return new LinkedHashSet<>(connectCluster); + } + + /** + * Configure a connector. If the connector does not already exist, a new one will be created and + * the given configuration will be applied to it. + * + * @param connName the name of the connector + * @param connConfig the intended configuration + * @throws ConnectRestException if the REST api returns error status + * @throws ConnectException if the configuration fails to be serialized or if the request could not be sent + */ + public String configureConnector(String connName, Map connConfig) { + String url = endpointForResource(String.format("connectors/%s/config", connName)); + return putConnectorConfig(url, connConfig); + } + + /** + * Validate a given connector configuration. If the configuration validates or + * has a configuration error, an instance of {@link ConfigInfos} is returned. If the validation fails + * an exception is thrown. 
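+ * + * <p>Illustrative usage (assumes a started cluster {@code connect}; the connector class name is hypothetical): + * <pre>{@code + * ConfigInfos infos = connect.validateConnectorConfig("org.example.MySourceConnector", connConfig); + * int numErrors = infos.errorCount(); + * }</pre>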
+ * + * @param connClassName the name of the connector class + * @param connConfig the intended configuration + * @throws ConnectRestException if the REST API returns error status + * @throws ConnectException if the configuration fails to serialize/deserialize or if the request failed to send + */ + public ConfigInfos validateConnectorConfig(String connClassName, Map<String, String> connConfig) { + String url = endpointForResource(String.format("connector-plugins/%s/config/validate", connClassName)); + String response = putConnectorConfig(url, connConfig); + ConfigInfos configInfos; + try { + configInfos = new ObjectMapper().readValue(response, ConfigInfos.class); + } catch (IOException e) { + throw new ConnectException("Unable to deserialize response into a ConfigInfos object"); + } + return configInfos; + } + + /** + * Execute a PUT request with the given connector configuration on the given URL endpoint. + * + * @param url the full URL of the endpoint that corresponds to the given REST resource + * @param connConfig the intended configuration + * @throws ConnectRestException if the REST API returns error status + * @throws ConnectException if the configuration fails to be serialized or if the request could not be sent + */ + protected String putConnectorConfig(String url, Map<String, String> connConfig) { + ObjectMapper mapper = new ObjectMapper(); + String content; + try { + content = mapper.writeValueAsString(connConfig); + } catch (IOException e) { + throw new ConnectException("Could not serialize connector configuration and execute PUT request"); + } + Response response = requestPut(url, content); + if (response.getStatus() < Response.Status.BAD_REQUEST.getStatusCode()) { + return responseToString(response); + } + throw new ConnectRestException(response.getStatus(), + "Could not execute PUT request. Error response: " + responseToString(response)); + } + + /** + * Delete an existing connector. + * + * @param connName name of the connector to be deleted + * @throws ConnectRestException if the REST API returns error status + * @throws ConnectException for any other error. + */ + public void deleteConnector(String connName) { + String url = endpointForResource(String.format("connectors/%s", connName)); + Response response = requestDelete(url); + if (response.getStatus() >= Response.Status.BAD_REQUEST.getStatusCode()) { + throw new ConnectRestException(response.getStatus(), + "Could not execute DELETE request. Error response: " + responseToString(response)); + } + } + + /** + * Pause an existing connector. + * + * @param connName name of the connector to be paused + * @throws ConnectRestException if the REST API returns error status + * @throws ConnectException for any other error. + */ + public void pauseConnector(String connName) { + String url = endpointForResource(String.format("connectors/%s/pause", connName)); + Response response = requestPut(url, ""); + if (response.getStatus() >= Response.Status.BAD_REQUEST.getStatusCode()) { + throw new ConnectRestException(response.getStatus(), + "Could not execute PUT request. Error response: " + responseToString(response)); + } + } + + /** + * Resume an existing connector. + * + * @param connName name of the connector to be resumed + * @throws ConnectRestException if the REST API returns error status + * @throws ConnectException for any other error.
+ */ + public void resumeConnector(String connName) { + String url = endpointForResource(String.format("connectors/%s/resume", connName)); + Response response = requestPut(url, ""); + if (response.getStatus() >= Response.Status.BAD_REQUEST.getStatusCode()) { + throw new ConnectRestException(response.getStatus(), + "Could not execute PUT request. Error response: " + responseToString(response)); + } + } + + /** + * Restart an existing connector. + * + * @param connName name of the connector to be restarted + * @throws ConnectRestException if the REST API returns error status + * @throws ConnectException for any other error. + */ + public void restartConnector(String connName) { + String url = endpointForResource(String.format("connectors/%s/restart", connName)); + Response response = requestPost(url, "", Collections.emptyMap()); + if (response.getStatus() >= Response.Status.BAD_REQUEST.getStatusCode()) { + throw new ConnectRestException(response.getStatus(), + "Could not execute POST request. Error response: " + responseToString(response)); + } + } + + /** + * Restart an existing connector and its tasks. + * + * @param connName name of the connector to be restarted + * @param onlyFailed true if only failed instances should be restarted + * @param includeTasks true if tasks should be restarted, or false if only the connector should be restarted + * @param onlyCallOnEmptyWorker true if the REST API call should be called on a worker not running this connector or its tasks + * @throws ConnectRestException if the REST API returns error status + * @throws ConnectException for any other error. + */ + public ConnectorStateInfo restartConnectorAndTasks(String connName, boolean onlyFailed, boolean includeTasks, boolean onlyCallOnEmptyWorker) { + ObjectMapper mapper = new ObjectMapper(); + String restartPath = String.format("connectors/%s/restart?onlyFailed=" + onlyFailed + "&includeTasks=" + includeTasks, connName); + String restartEndpoint; + if (onlyCallOnEmptyWorker) { + restartEndpoint = endpointForResourceNotRunningConnector(restartPath, connName); + } else { + restartEndpoint = endpointForResource(restartPath); + } + Response response = requestPost(restartEndpoint, "", Collections.emptyMap()); + try { + if (response.getStatus() < Response.Status.BAD_REQUEST.getStatusCode()) { + // only a 202 (Accepted) status returns a body + if (response.getStatus() == Response.Status.ACCEPTED.getStatusCode()) { + return mapper.readerFor(ConnectorStateInfo.class) + .readValue(responseToString(response)); + } + } + return null; + } catch (IOException e) { + log.error("Could not read connector state from response: {}", + responseToString(response), e); + throw new ConnectException("Could not parse connector state", e); + } + } + + /** + * Get the connector names of the connectors currently running on this cluster. + * + * @return the list of connector names + * @throws ConnectRestException if the HTTP request to the REST API failed with a valid status code. + * @throws ConnectException for any other error.
+ */ + public Collection<String> connectors() { + ObjectMapper mapper = new ObjectMapper(); + String url = endpointForResource("connectors"); + Response response = requestGet(url); + if (response.getStatus() < Response.Status.BAD_REQUEST.getStatusCode()) { + try { + return mapper.readerFor(Collection.class).readValue(responseToString(response)); + } catch (IOException e) { + log.error("Could not parse connector list from response: {}", + responseToString(response), e + ); + throw new ConnectException("Could not parse connector list", e); + } + } + throw new ConnectRestException(response.getStatus(), + "Could not read connector list. Error response: " + responseToString(response)); + } + + /** + * Get the status for a connector running in this cluster. + * + * @param connectorName name of the connector + * @return an instance of {@link ConnectorStateInfo} populated with state information of the connector and its tasks. + * @throws ConnectRestException if the HTTP request to the REST API failed with a valid status code. + * @throws ConnectException for any other error. + */ + public ConnectorStateInfo connectorStatus(String connectorName) { + ObjectMapper mapper = new ObjectMapper(); + String url = endpointForResource(String.format("connectors/%s/status", connectorName)); + Response response = requestGet(url); + try { + if (response.getStatus() < Response.Status.BAD_REQUEST.getStatusCode()) { + return mapper.readerFor(ConnectorStateInfo.class) + .readValue(responseToString(response)); + } + } catch (IOException e) { + log.error("Could not read connector state from response: {}", + responseToString(response), e); + throw new ConnectException("Could not parse connector state", e); + } + throw new ConnectRestException(response.getStatus(), + "Could not read connector state. Error response: " + responseToString(response)); + } + + /** + * Get the active topics of a connector running in this cluster. + * + * @param connectorName name of the connector + * @return an instance of {@link ActiveTopicsInfo} populated with the connector name and its set of active topics. + * @throws ConnectRestException if the HTTP request to the REST API failed with a valid status code. + * @throws ConnectException for any other error. + */ + public ActiveTopicsInfo connectorTopics(String connectorName) { + ObjectMapper mapper = new ObjectMapper(); + String url = endpointForResource(String.format("connectors/%s/topics", connectorName)); + Response response = requestGet(url); + try { + if (response.getStatus() < Response.Status.BAD_REQUEST.getStatusCode()) { + Map<String, Map<String, List<String>>> activeTopics = mapper + .readerFor(new TypeReference<Map<String, Map<String, List<String>>>>() { }) + .readValue(responseToString(response)); + return new ActiveTopicsInfo(connectorName, + activeTopics.get(connectorName).getOrDefault("topics", Collections.emptyList())); + } + } catch (IOException e) { + log.error("Could not read connector state from response: {}", + responseToString(response), e); + throw new ConnectException("Could not parse connector state", e); + } + throw new ConnectRestException(response.getStatus(), + "Could not read connector state. Error response: " + responseToString(response)); + } + + /** + * Reset the set of active topics of a connector running in this cluster. + * + * @param connectorName name of the connector + * @throws ConnectRestException if the HTTP request to the REST API failed with a valid status code. + * @throws ConnectException for any other error.
+ */ + public void resetConnectorTopics(String connectorName) { + String url = endpointForResource(String.format("connectors/%s/topics/reset", connectorName)); + Response response = requestPut(url, null); + if (response.getStatus() >= Response.Status.BAD_REQUEST.getStatusCode()) { + throw new ConnectRestException(response.getStatus(), + "Resetting active topics for connector " + connectorName + " failed. " + + "Error response: " + responseToString(response)); + } + } + + /** + * Get the full URL of the admin endpoint that corresponds to the given REST resource + * + * @param resource the resource under the worker's admin endpoint + * @return the admin endpoint URL + * @throws ConnectException if no admin REST endpoint is available + */ + public String adminEndpoint(String resource) { + String url = connectCluster.stream() + .map(WorkerHandle::adminUrl) + .filter(Objects::nonNull) + .findFirst() + .orElseThrow(() -> new ConnectException("Admin endpoint is disabled.")) + .toString(); + return url + resource; + } + + /** + * Get the full URL of the endpoint that corresponds to the given REST resource + * + * @param resource the resource under the worker's REST endpoint + * @return the REST endpoint URL + * @throws ConnectException if no REST endpoint is available + */ + public String endpointForResource(String resource) { + String url = connectCluster.stream() + .map(WorkerHandle::url) + .filter(Objects::nonNull) + .findFirst() + .orElseThrow(() -> new ConnectException("Connect workers have not been provisioned")) + .toString(); + return url + resource; + } + + /** + * Get the full URL of the endpoint that corresponds to the given REST resource, using a worker + * that is not running any task or connector instance for the given connector name + * + * @param resource the resource under the worker's REST endpoint + * @param connectorName the name of the connector + * @return the REST endpoint URL + * @throws ConnectException if no REST endpoint is available + */ + public String endpointForResourceNotRunningConnector(String resource, String connectorName) { + ConnectorStateInfo info = connectorStatus(connectorName); + Set<String> activeWorkerUrls = new HashSet<>(); + activeWorkerUrls.add(String.format("http://%s/", info.connector().workerId())); + info.tasks().forEach(t -> activeWorkerUrls.add(String.format("http://%s/", t.workerId()))); + String url = connectCluster.stream() + .map(WorkerHandle::url) + .filter(Objects::nonNull) + .filter(workerUrl -> !activeWorkerUrls.contains(workerUrl.toString())) + .findFirst() + .orElseThrow(() -> new ConnectException( + String.format("Connect workers have not been provisioned or no free worker found that is not running this connector(%s) or its tasks", connectorName))) + .toString(); + return url + resource; + } + + private static void putIfAbsent(Map<String, String> props, String propertyKey, String propertyValue) { + if (!props.containsKey(propertyKey)) { + props.put(propertyKey, propertyValue); + } + } + + /** + * Return the handle to the Kafka cluster this Connect cluster connects to. + * + * @return the Kafka cluster handle + */ + public EmbeddedKafkaCluster kafka() { + return kafkaCluster; + } + + /** + * Execute a GET request on the given URL. + * + * @param url the HTTP endpoint + * @return the response to the GET request + * @throws ConnectException if execution of the GET request fails + * @deprecated Use {@link #requestGet(String)} instead.
+ */ + @Deprecated + public String executeGet(String url) { + return responseToString(requestGet(url)); + } + + /** + * Execute a GET request on the given URL. + * + * @param url the HTTP endpoint + * @return the response to the GET request + * @throws ConnectException if execution of the GET request fails + */ + public Response requestGet(String url) { + return requestHttpMethod(url, null, Collections.emptyMap(), "GET"); + } + + /** + * Execute a PUT request on the given URL. + * + * @param url the HTTP endpoint + * @param body the payload of the PUT request + * @return the response to the PUT request + * @throws ConnectException if execution of the PUT request fails + * @deprecated Use {@link #requestPut(String, String)} instead. + */ + @Deprecated + public int executePut(String url, String body) { + return requestPut(url, body).getStatus(); + } + + /** + * Execute a PUT request on the given URL. + * + * @param url the HTTP endpoint + * @param body the payload of the PUT request + * @return the response to the PUT request + * @throws ConnectException if execution of the PUT request fails + */ + public Response requestPut(String url, String body) { + return requestHttpMethod(url, body, Collections.emptyMap(), "PUT"); + } + + /** + * Execute a POST request on the given URL. + * + * @param url the HTTP endpoint + * @param body the payload of the POST request + * @param headers a map that stores the POST request headers + * @return the response to the POST request + * @throws ConnectException if execution of the POST request fails + * @deprecated Use {@link #requestPost(String, String, java.util.Map)} instead. + */ + @Deprecated + public int executePost(String url, String body, Map headers) { + return requestPost(url, body, headers).getStatus(); + } + + /** + * Execute a POST request on the given URL. + * + * @param url the HTTP endpoint + * @param body the payload of the POST request + * @param headers a map that stores the POST request headers + * @return the response to the POST request + * @throws ConnectException if execution of the POST request fails + */ + public Response requestPost(String url, String body, Map headers) { + return requestHttpMethod(url, body, headers, "POST"); + } + + /** + * Execute a DELETE request on the given URL. + * + * @param url the HTTP endpoint + * @return the response to the DELETE request + * @throws ConnectException if execution of the DELETE request fails + * @deprecated Use {@link #requestDelete(String)} instead. + */ + @Deprecated + public int executeDelete(String url) { + return requestDelete(url).getStatus(); + } + + /** + * Execute a DELETE request on the given URL. + * + * @param url the HTTP endpoint + * @return the response to the DELETE request + * @throws ConnectException if execution of the DELETE request fails + */ + public Response requestDelete(String url) { + return requestHttpMethod(url, null, Collections.emptyMap(), "DELETE"); + } + + /** + * A general method that executes an HTTP request on a given URL. + * + * @param url the HTTP endpoint + * @param body the payload of the request; null if there isn't one + * @param headers a map that stores the request headers; empty if there are no headers + * @param httpMethod the name of the HTTP method to execute + * @return the response to the HTTP request + * @throws ConnectException if execution of the HTTP method fails + */ + protected Response requestHttpMethod(String url, String body, Map headers, + String httpMethod) { + log.debug("Executing {} request to URL={}." + (body != null ? 
" Payload={}" : ""), + httpMethod, url, body); + try { + HttpURLConnection httpCon = (HttpURLConnection) new URL(url).openConnection(); + httpCon.setDoOutput(true); + httpCon.setRequestMethod(httpMethod); + if (body != null) { + httpCon.setRequestProperty("Content-Type", "application/json"); + headers.forEach(httpCon::setRequestProperty); + try (OutputStreamWriter out = new OutputStreamWriter(httpCon.getOutputStream())) { + out.write(body); + } + } + try (InputStream is = httpCon.getResponseCode() < HttpURLConnection.HTTP_BAD_REQUEST + ? httpCon.getInputStream() + : httpCon.getErrorStream() + ) { + String responseEntity = responseToString(is); + log.info("{} response for URL={} is {}", + httpMethod, url, responseEntity.isEmpty() ? "empty" : responseEntity); + return Response.status(Response.Status.fromStatusCode(httpCon.getResponseCode())) + .entity(responseEntity) + .build(); + } + } catch (IOException e) { + log.error("Could not execute " + httpMethod + " request to " + url, e); + throw new ConnectException(e); + } + } + + private String responseToString(Response response) { + return response == null ? "empty" : (String) response.getEntity(); + } + + private String responseToString(InputStream stream) throws IOException { + int c; + StringBuilder response = new StringBuilder(); + while ((c = stream.read()) != -1) { + response.append((char) c); + } + return response.toString(); + } + + public static class Builder { + private String name = UUID.randomUUID().toString(); + private Map workerProps = new HashMap<>(); + private int numWorkers = DEFAULT_NUM_WORKERS; + private int numBrokers = DEFAULT_NUM_BROKERS; + private Properties brokerProps = DEFAULT_BROKER_CONFIG; + private boolean maskExitProcedures = true; + + public Builder name(String name) { + this.name = name; + return this; + } + + public Builder workerProps(Map workerProps) { + this.workerProps = workerProps; + return this; + } + + public Builder numWorkers(int numWorkers) { + this.numWorkers = numWorkers; + return this; + } + + public Builder numBrokers(int numBrokers) { + this.numBrokers = numBrokers; + return this; + } + + public Builder brokerProps(Properties brokerProps) { + this.brokerProps = brokerProps; + return this; + } + + /** + * In the event of ungraceful shutdown, embedded clusters call exit or halt with non-zero + * exit statuses. Exiting with a non-zero status forces a test to fail and is hard to + * handle. Because graceful exit is usually not required during a test and because + * depending on such an exit increases flakiness, this setting allows masking + * exit and halt procedures by using a runtime exception instead. Customization of the + * exit and halt procedures is possible through {@code exitProcedure} and {@code + * haltProcedure} respectively. + * + * @param mask if false, exit and halt procedures remain unchanged; true is the default. 
+ * @return the builder for this cluster + */ + public Builder maskExitProcedures(boolean mask) { + this.maskExitProcedures = mask; + return this; + } + + public EmbeddedConnectCluster build() { + return new EmbeddedConnectCluster(name, workerProps, numWorkers, numBrokers, + brokerProps, maskExitProcedures); + } + } + + /** + * Return the available assertions for this Connect cluster + * + * @return the assertions object + */ + public EmbeddedConnectClusterAssertions assertions() { + return assertions; + } + +} diff --git a/connect/runtime/src/test/java/org/apache/kafka/connect/util/clusters/EmbeddedConnectClusterAssertions.java b/connect/runtime/src/test/java/org/apache/kafka/connect/util/clusters/EmbeddedConnectClusterAssertions.java new file mode 100644 index 0000000..edd99c8 --- /dev/null +++ b/connect/runtime/src/test/java/org/apache/kafka/connect/util/clusters/EmbeddedConnectClusterAssertions.java @@ -0,0 +1,549 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.util.clusters; + +import org.apache.kafka.clients.admin.TopicDescription; +import org.apache.kafka.connect.runtime.AbstractStatus; +import org.apache.kafka.connect.runtime.rest.entities.ActiveTopicsInfo; +import org.apache.kafka.connect.runtime.rest.entities.ConnectorStateInfo; +import org.apache.kafka.connect.runtime.rest.errors.ConnectRestException; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.ws.rs.core.Response; +import java.util.Arrays; +import java.util.Collection; +import java.util.HashSet; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.BiFunction; +import java.util.stream.Collectors; + +import static org.apache.kafka.test.TestUtils.waitForCondition; + +/** + * A set of common assertions that can be applied to a Connect cluster during integration testing + */ +public class EmbeddedConnectClusterAssertions { + + private static final Logger log = LoggerFactory.getLogger(EmbeddedConnectClusterAssertions.class); + public static final long WORKER_SETUP_DURATION_MS = TimeUnit.SECONDS.toMillis(60); + public static final long VALIDATION_DURATION_MS = TimeUnit.SECONDS.toMillis(30); + public static final long CONNECTOR_SETUP_DURATION_MS = TimeUnit.SECONDS.toMillis(30); + private static final long CONNECT_INTERNAL_TOPIC_UPDATES_DURATION_MS = TimeUnit.SECONDS.toMillis(60); + + private final EmbeddedConnectCluster connect; + + EmbeddedConnectClusterAssertions(EmbeddedConnectCluster connect) { + this.connect = connect; + } + + /** + * Assert that at least the requested number of workers are up and running. 
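+ * + * <p>Illustrative usage (assumes a started cluster {@code connect}; the detail message is arbitrary): + * {@code connect.assertions().assertAtLeastNumWorkersAreUp(3, "Initial group of workers did not start in time.");}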
+ * + * @param numWorkers the number of online workers + */ + public void assertAtLeastNumWorkersAreUp(int numWorkers, String detailMessage) throws InterruptedException { + try { + waitForCondition( + () -> checkWorkersUp(numWorkers, (actual, expected) -> actual >= expected).orElse(false), + WORKER_SETUP_DURATION_MS, + "Didn't meet the minimum requested number of online workers: " + numWorkers); + } catch (AssertionError e) { + throw new AssertionError(detailMessage, e); + } + } + + /** + * Assert that exactly the requested number of workers are up and running. + * + * @param numWorkers the number of online workers + */ + public void assertExactlyNumWorkersAreUp(int numWorkers, String detailMessage) throws InterruptedException { + try { + waitForCondition( + () -> checkWorkersUp(numWorkers, (actual, expected) -> actual == expected).orElse(false), + WORKER_SETUP_DURATION_MS, + "Didn't meet the exact requested number of online workers: " + numWorkers); + } catch (AssertionError e) { + throw new AssertionError(detailMessage, e); + } + } + + /** + * Confirm that the requested number of workers are up and running. + * + * @param numWorkers the number of online workers + * @return true if the given comparison against {@code numWorkers} holds for the number of online workers; false otherwise + */ + protected Optional<Boolean> checkWorkersUp(int numWorkers, BiFunction<Integer, Integer, Boolean> comp) { + try { + int numUp = connect.activeWorkers().size(); + return Optional.of(comp.apply(numUp, numWorkers)); + } catch (Exception e) { + log.error("Could not check active workers.", e); + return Optional.empty(); + } + } + + /** + * Assert that exactly the requested number of brokers are up and running. + * + * @param numBrokers the number of online brokers + */ + public void assertExactlyNumBrokersAreUp(int numBrokers, String detailMessage) throws InterruptedException { + try { + waitForCondition( + () -> checkBrokersUp(numBrokers, (actual, expected) -> actual == expected).orElse(false), + WORKER_SETUP_DURATION_MS, + "Didn't meet the exact requested number of online brokers: " + numBrokers); + } catch (AssertionError e) { + throw new AssertionError(detailMessage, e); + } + } + + /** + * Confirm that the requested number of brokers are up and running. + * + * @param numBrokers the number of online brokers + * @return true if the given comparison against {@code numBrokers} holds for the number of running brokers; false otherwise + */ + protected Optional<Boolean> checkBrokersUp(int numBrokers, BiFunction<Integer, Integer, Boolean> comp) { + try { + int numRunning = connect.kafka().runningBrokers().size(); + return Optional.of(comp.apply(numRunning, numBrokers)); + } catch (Exception e) { + log.error("Could not check running brokers.", e); + return Optional.empty(); + } + } + + /** + * Assert that the topics with the specified names do not exist. + * + * @param topicNames the names of the topics that are expected to not exist + */ + public void assertTopicsDoNotExist(String... topicNames) throws InterruptedException { + Set<String> topicNameSet = new HashSet<>(Arrays.asList(topicNames)); + AtomicReference<Set<String>> existingTopics = new AtomicReference<>(topicNameSet); + waitForCondition( + () -> checkTopicsExist(topicNameSet, (actual, expected) -> { + existingTopics.set(actual); + return actual.isEmpty(); + }).orElse(false), + CONNECTOR_SETUP_DURATION_MS, + "Unexpectedly found topics " + existingTopics.get()); + } + + /** + * Assert that the topics with the specified names do exist. + * + * @param topicNames the names of the topics that are expected to exist + */ + public void assertTopicsExist(String...
topicNames) throws InterruptedException { + Set topicNameSet = new HashSet<>(Arrays.asList(topicNames)); + AtomicReference> missingTopics = new AtomicReference<>(topicNameSet); + waitForCondition( + () -> checkTopicsExist(topicNameSet, (actual, expected) -> { + Set missing = new HashSet<>(expected); + missing.removeAll(actual); + missingTopics.set(missing); + return missing.isEmpty(); + }).orElse(false), + CONNECTOR_SETUP_DURATION_MS, + "Didn't find the topics " + missingTopics.get()); + } + + protected Optional checkTopicsExist(Set topicNames, BiFunction, Set, Boolean> comp) { + try { + Map> topics = connect.kafka().describeTopics(topicNames); + Set actualExistingTopics = topics.entrySet() + .stream() + .filter(e -> e.getValue().isPresent()) + .map(Map.Entry::getKey) + .collect(Collectors.toSet()); + return Optional.of(comp.apply(actualExistingTopics, topicNames)); + } catch (Exception e) { + log.error("Failed to describe the topic(s): {}.", topicNames, e); + return Optional.empty(); + } + } + + /** + * Assert that the named topic is configured to have the specified replication factor and + * number of partitions. + * + * @param topicName the name of the topic that is expected to exist + * @param replicas the replication factor + * @param partitions the number of partitions + * @param detailMessage the assertion message + */ + public void assertTopicSettings(String topicName, int replicas, int partitions, String detailMessage) + throws InterruptedException { + try { + waitForCondition( + () -> checkTopicSettings( + topicName, + replicas, + partitions + ).orElse(false), + VALIDATION_DURATION_MS, + "Topic " + topicName + " does not exist or does not have exactly " + + partitions + " partitions or at least " + + replicas + " per partition"); + } catch (AssertionError e) { + throw new AssertionError(detailMessage, e); + } + } + + protected Optional checkTopicSettings(String topicName, int replicas, int partitions) { + try { + Map> topics = connect.kafka().describeTopics(topicName); + TopicDescription topicDesc = topics.get(topicName).orElse(null); + boolean result = topicDesc != null + && topicDesc.name().equals(topicName) + && topicDesc.partitions().size() == partitions + && topicDesc.partitions().stream().allMatch(p -> p.replicas().size() >= replicas); + return Optional.of(result); + } catch (Exception e) { + log.error("Failed to describe the topic: {}.", topicName, e); + return Optional.empty(); + } + } + + /** + * Assert that the required number of errors are produced by a connector config validation. + * + * @param connectorClass the class of the connector to validate + * @param connConfig the intended configuration + * @param numErrors the number of errors expected + */ + public void assertExactlyNumErrorsOnConnectorConfigValidation(String connectorClass, Map connConfig, + int numErrors, String detailMessage) throws InterruptedException { + try { + waitForCondition( + () -> checkValidationErrors( + connectorClass, + connConfig, + numErrors, + (actual, expected) -> actual == expected + ).orElse(false), + VALIDATION_DURATION_MS, + "Didn't meet the exact requested number of validation errors: " + numErrors); + } catch (AssertionError e) { + throw new AssertionError(detailMessage, e); + } + } + + /** + * Confirm that the requested number of errors are produced by {@link EmbeddedConnectCluster#validateConnectorConfig}. 
+ * + * @param connectorClass the class of the connector to validate + * @param connConfig the intended configuration + * @param numErrors the number of errors expected + * @return true if exactly {@code numErrors} are produced by the validation; false otherwise + */ + protected Optional checkValidationErrors(String connectorClass, Map connConfig, + int numErrors, BiFunction comp) { + try { + int numErrorsProduced = connect.validateConnectorConfig(connectorClass, connConfig).errorCount(); + return Optional.of(comp.apply(numErrorsProduced, numErrors)); + } catch (Exception e) { + log.error("Could not check config validation error count.", e); + return Optional.empty(); + } + } + + /** + * Assert that a connector is running with at least the given number of tasks all in running state + * + * @param connectorName the connector name + * @param numTasks the number of tasks + * @param detailMessage + * @throws InterruptedException + */ + public void assertConnectorAndAtLeastNumTasksAreRunning(String connectorName, int numTasks, String detailMessage) + throws InterruptedException { + try { + waitForCondition( + () -> checkConnectorState( + connectorName, + AbstractStatus.State.RUNNING, + numTasks, + AbstractStatus.State.RUNNING, + (actual, expected) -> actual >= expected + ).orElse(false), + CONNECTOR_SETUP_DURATION_MS, + "The connector or at least " + numTasks + " of tasks are not running."); + } catch (AssertionError e) { + throw new AssertionError(detailMessage, e); + } + } + + /** + * Assert that a connector is running with at least the given number of tasks all in running state + * + * @param connectorName the connector name + * @param numTasks the number of tasks + * @param detailMessage the assertion message + * @throws InterruptedException + */ + public void assertConnectorAndExactlyNumTasksAreRunning(String connectorName, int numTasks, String detailMessage) + throws InterruptedException { + try { + waitForCondition( + () -> checkConnectorState( + connectorName, + AbstractStatus.State.RUNNING, + numTasks, + AbstractStatus.State.RUNNING, + (actual, expected) -> actual == expected + ).orElse(false), + CONNECTOR_SETUP_DURATION_MS, + "The connector or exactly " + numTasks + " tasks are not running."); + } catch (AssertionError e) { + throw new AssertionError(detailMessage, e); + } + } + + /** + * Assert that a connector is running, that it has a specific number of tasks, and that all of + * its tasks are in the FAILED state. + * + * @param connectorName the connector name + * @param numTasks the number of tasks + * @param detailMessage the assertion message + * @throws InterruptedException + */ + public void assertConnectorIsRunningAndTasksHaveFailed(String connectorName, int numTasks, String detailMessage) + throws InterruptedException { + try { + waitForCondition( + () -> checkConnectorState( + connectorName, + AbstractStatus.State.RUNNING, + numTasks, + AbstractStatus.State.FAILED, + (actual, expected) -> actual >= expected + ).orElse(false), + CONNECTOR_SETUP_DURATION_MS, + "Either the connector is not running or not all the " + numTasks + " tasks have failed."); + } catch (AssertionError e) { + throw new AssertionError(detailMessage, e); + } + } + + /** + * Assert that a connector is running, that it has a specific number of tasks out of that numFailedTasks are in the FAILED state. 
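+ *
+ * <p>A minimal usage sketch (connector name is hypothetical; the {@code connect.assertions()}
+ * accessor is assumed rather than shown in this change):
+ * <pre>{@code
+ * // a connector configured with 3 tasks, 2 of which are expected to fail
+ * connect.assertions().assertConnectorIsRunningAndNumTasksHaveFailed(
+ *     "my-connector", 3, 2, "Expected 2 of the 3 tasks to be FAILED");
+ * }</pre>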
+ * + * @param connectorName the connector name + * @param numTasks the number of tasks + * @param numFailedTasks the number of failed tasks + * @param detailMessage the assertion message + * @throws InterruptedException + */ + public void assertConnectorIsRunningAndNumTasksHaveFailed(String connectorName, int numTasks, int numFailedTasks, String detailMessage) + throws InterruptedException { + try { + waitForCondition( + () -> checkConnectorState( + connectorName, + AbstractStatus.State.RUNNING, + numTasks, + numFailedTasks, + AbstractStatus.State.FAILED, + (actual, expected) -> actual >= expected + ).orElse(false), + CONNECTOR_SETUP_DURATION_MS, + "Either the connector is not running or not all the " + numTasks + " tasks have failed."); + } catch (AssertionError e) { + throw new AssertionError(detailMessage, e); + } + } + + /** + * Assert that a connector is in FAILED state, that it has a specific number of tasks, and that all of + * its tasks are in the FAILED state. + * + * @param connectorName the connector name + * @param numTasks the number of tasks + * @param detailMessage the assertion message + * @throws InterruptedException + */ + public void assertConnectorIsFailedAndTasksHaveFailed(String connectorName, int numTasks, String detailMessage) + throws InterruptedException { + try { + waitForCondition( + () -> checkConnectorState( + connectorName, + AbstractStatus.State.FAILED, + numTasks, + AbstractStatus.State.FAILED, + (actual, expected) -> actual >= expected + ).orElse(false), + CONNECTOR_SETUP_DURATION_MS, + "Either the connector is running or not all the " + numTasks + " tasks have failed."); + } catch (AssertionError e) { + throw new AssertionError(detailMessage, e); + } + } + + /** + * Assert that a connector and its tasks are not running. + * + * @param connectorName the connector name + * @param detailMessage the assertion message + * @throws InterruptedException + */ + public void assertConnectorAndTasksAreStopped(String connectorName, String detailMessage) + throws InterruptedException { + try { + waitForCondition( + () -> checkConnectorAndTasksAreStopped(connectorName), + CONNECTOR_SETUP_DURATION_MS, + "At least the connector or one of its tasks is still running"); + } catch (AssertionError e) { + throw new AssertionError(detailMessage, e); + } + } + + /** + * Check whether the connector or any of its tasks are still in RUNNING state + * + * @param connectorName the connector + * @return true if the connector and all the tasks are not in RUNNING state; false otherwise + */ + protected boolean checkConnectorAndTasksAreStopped(String connectorName) { + ConnectorStateInfo info; + try { + info = connect.connectorStatus(connectorName); + } catch (ConnectRestException e) { + return e.statusCode() == Response.Status.NOT_FOUND.getStatusCode(); + } catch (Exception e) { + log.error("Could not check connector state info.", e); + return false; + } + if (info == null) { + return true; + } + return !info.connector().state().equals(AbstractStatus.State.RUNNING.toString()) + && info.tasks().stream().noneMatch(s -> s.state().equals(AbstractStatus.State.RUNNING.toString())); + } + + /** + * Check whether the given connector state matches the current state of the connector and + * whether it has at least the given number of tasks, with all the tasks matching the given + * task state. 
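+ *
+ * <p>For example, the "connector running with at least {@code numTasks} running tasks" assertion above
+ * boils down to a call of this form (connector name hypothetical, comparator written out):
+ * <pre>{@code
+ * checkConnectorState("my-connector", AbstractStatus.State.RUNNING, 3,
+ *     AbstractStatus.State.RUNNING, (actual, expected) -> actual >= expected);
+ * }</pre>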
+ * @param connectorName the connector + * @param connectorState + * @param numTasks the expected number of tasks + * @param tasksState + * @return true if the connector and tasks are in RUNNING state; false otherwise + */ + protected Optional checkConnectorState( + String connectorName, + AbstractStatus.State connectorState, + int numTasks, + AbstractStatus.State tasksState, + BiFunction comp + ) { + try { + ConnectorStateInfo info = connect.connectorStatus(connectorName); + boolean result = info != null + && comp.apply(info.tasks().size(), numTasks) + && info.connector().state().equals(connectorState.toString()) + && info.tasks().stream().allMatch(s -> s.state().equals(tasksState.toString())); + return Optional.of(result); + } catch (Exception e) { + log.error("Could not check connector state info.", e); + return Optional.empty(); + } + } + + /** + * Check whether the given connector state matches the current state of the connector and + * whether it has at least the given number of tasks, with numTasksInTasksState matching the given + * task state. + * @param connectorName the connector + * @param connectorState + * @param numTasks the expected number of tasks + * @param tasksState + * @return true if the connector and tasks are in RUNNING state; false otherwise + */ + protected Optional checkConnectorState( + String connectorName, + AbstractStatus.State connectorState, + int numTasks, + int numTasksInTasksState, + AbstractStatus.State tasksState, + BiFunction comp + ) { + try { + ConnectorStateInfo info = connect.connectorStatus(connectorName); + boolean result = info != null + && comp.apply(info.tasks().size(), numTasks) + && info.connector().state().equals(connectorState.toString()) + && info.tasks().stream().filter(s -> s.state().equals(tasksState.toString())).count() == numTasksInTasksState; + return Optional.of(result); + } catch (Exception e) { + log.error("Could not check connector state info.", e); + return Optional.empty(); + } + } + /** + * Assert that a connector's set of active topics matches the given collection of topic names. + * + * @param connectorName the connector name + * @param topics a collection of topics to compare against + * @param detailMessage the assertion message + * @throws InterruptedException + */ + public void assertConnectorActiveTopics(String connectorName, Collection topics, String detailMessage) throws InterruptedException { + try { + waitForCondition( + () -> checkConnectorActiveTopics(connectorName, topics).orElse(false), + CONNECT_INTERNAL_TOPIC_UPDATES_DURATION_MS, + "Connector active topics don't match the expected collection"); + } catch (AssertionError e) { + throw new AssertionError(detailMessage, e); + } + } + + /** + * Check whether a connector's set of active topics matches the given collection of topic names. 
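+ *
+ * <p>Usage sketch for the corresponding public assertion (connector and topic names are hypothetical,
+ * and the {@code connect.assertions()} accessor is assumed):
+ * <pre>{@code
+ * connect.assertions().assertConnectorActiveTopics(
+ *     "my-connector", Arrays.asList("orders", "shipments"),
+ *     "Active topics should be exactly the two topics the connector produces to");
+ * }</pre>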
+ * + * @param connectorName the connector name + * @param topics a collection of topics to compare against + * @return true if the connector's active topics matches the given collection; false otherwise + */ + protected Optional checkConnectorActiveTopics(String connectorName, Collection topics) { + try { + ActiveTopicsInfo info = connect.connectorTopics(connectorName); + boolean result = info != null + && topics.size() == info.topics().size() + && topics.containsAll(info.topics()); + log.debug("Found connector {} using topics: {}", connectorName, info.topics()); + return Optional.of(result); + } catch (Exception e) { + log.error("Could not check connector {} state info.", connectorName, e); + return Optional.empty(); + } + } +} diff --git a/connect/runtime/src/test/java/org/apache/kafka/connect/util/clusters/EmbeddedKafkaCluster.java b/connect/runtime/src/test/java/org/apache/kafka/connect/util/clusters/EmbeddedKafkaCluster.java new file mode 100644 index 0000000..cf7fde5 --- /dev/null +++ b/connect/runtime/src/test/java/org/apache/kafka/connect/util/clusters/EmbeddedKafkaCluster.java @@ -0,0 +1,501 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.connect.util.clusters; + +import kafka.cluster.EndPoint; +import kafka.server.KafkaConfig; +import kafka.server.KafkaServer; +import kafka.utils.CoreUtils; +import kafka.utils.TestUtils; +import kafka.zk.EmbeddedZookeeper; +import org.apache.kafka.clients.CommonClientConfigs; +import org.apache.kafka.clients.admin.Admin; +import org.apache.kafka.clients.admin.AdminClientConfig; +import org.apache.kafka.clients.admin.DescribeTopicsResult; +import org.apache.kafka.clients.admin.NewTopic; +import org.apache.kafka.clients.admin.TopicDescription; +import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.clients.consumer.ConsumerRecords; +import org.apache.kafka.clients.consumer.KafkaConsumer; +import org.apache.kafka.clients.producer.KafkaProducer; +import org.apache.kafka.clients.producer.ProducerConfig; +import org.apache.kafka.clients.producer.ProducerRecord; +import org.apache.kafka.common.KafkaException; +import org.apache.kafka.common.KafkaFuture; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.config.SslConfigs; +import org.apache.kafka.common.config.types.Password; +import org.apache.kafka.common.errors.InvalidReplicationFactorException; +import org.apache.kafka.common.errors.UnknownTopicOrPartitionException; +import org.apache.kafka.common.network.ListenerName; +import org.apache.kafka.common.serialization.ByteArraySerializer; +import org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.connect.errors.ConnectException; +import org.apache.kafka.metadata.BrokerState; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.nio.file.Files; +import java.time.Duration; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Properties; +import java.util.Set; +import java.util.UUID; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.function.Predicate; +import java.util.stream.Collectors; + +import static org.apache.kafka.clients.consumer.ConsumerConfig.AUTO_OFFSET_RESET_CONFIG; +import static org.apache.kafka.clients.consumer.ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG; +import static org.apache.kafka.clients.consumer.ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG; +import static org.apache.kafka.clients.consumer.ConsumerConfig.GROUP_ID_CONFIG; +import static org.apache.kafka.clients.consumer.ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG; +import static org.apache.kafka.clients.consumer.ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG; + +/** + * Setup an embedded Kafka cluster with specified number of brokers and specified broker properties. To be used for + * integration tests. 
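+ *
+ * <p>Typical lifecycle from a test, based only on the methods defined below (the broker property
+ * shown is illustrative):
+ * <pre>{@code
+ * Properties brokerProps = new Properties();
+ * brokerProps.put("auto.create.topics.enable", "false");
+ * EmbeddedKafkaCluster kafka = new EmbeddedKafkaCluster(3, brokerProps);
+ * kafka.start();
+ * kafka.createTopic("test-topic", 2, 3, Collections.emptyMap());
+ * kafka.produce("test-topic", "key", "value");
+ * ConsumerRecords<byte[], byte[]> records = kafka.consume(1, 10_000, "test-topic");
+ * kafka.stop();
+ * }</pre>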
+ */ +public class EmbeddedKafkaCluster { + + private static final Logger log = LoggerFactory.getLogger(EmbeddedKafkaCluster.class); + + private static final long DEFAULT_PRODUCE_SEND_DURATION_MS = TimeUnit.SECONDS.toMillis(120); + + // Kafka Config + private final KafkaServer[] brokers; + private final Properties brokerConfig; + private final Time time = new MockTime(); + private final int[] currentBrokerPorts; + private final String[] currentBrokerLogDirs; + private final boolean hasListenerConfig; + + private EmbeddedZookeeper zookeeper = null; + private ListenerName listenerName = new ListenerName("PLAINTEXT"); + private KafkaProducer producer; + + public EmbeddedKafkaCluster(final int numBrokers, + final Properties brokerConfig) { + brokers = new KafkaServer[numBrokers]; + currentBrokerPorts = new int[numBrokers]; + currentBrokerLogDirs = new String[numBrokers]; + this.brokerConfig = brokerConfig; + // Since we support `stop` followed by `startOnlyKafkaOnSamePorts`, we track whether + // a listener config is defined during initialization in order to know if it's + // safe to override it + hasListenerConfig = brokerConfig.get(KafkaConfig.ListenersProp()) != null; + } + + /** + * Starts the Kafka cluster alone using the ports that were assigned during initialization of + * the harness. + * + * @throws ConnectException if a directory to store the data cannot be created + */ + public void startOnlyKafkaOnSamePorts() { + doStart(); + } + + public void start() { + // pick a random port + zookeeper = new EmbeddedZookeeper(); + Arrays.fill(currentBrokerPorts, 0); + Arrays.fill(currentBrokerLogDirs, null); + doStart(); + } + + private void doStart() { + brokerConfig.put(KafkaConfig.ZkConnectProp(), zKConnectString()); + + putIfAbsent(brokerConfig, KafkaConfig.DeleteTopicEnableProp(), true); + putIfAbsent(brokerConfig, KafkaConfig.GroupInitialRebalanceDelayMsProp(), 0); + putIfAbsent(brokerConfig, KafkaConfig.OffsetsTopicReplicationFactorProp(), (short) brokers.length); + putIfAbsent(brokerConfig, KafkaConfig.AutoCreateTopicsEnableProp(), false); + + Object listenerConfig = brokerConfig.get(KafkaConfig.InterBrokerListenerNameProp()); + if (listenerConfig == null) + listenerConfig = brokerConfig.get(KafkaConfig.InterBrokerSecurityProtocolProp()); + if (listenerConfig == null) + listenerConfig = "PLAINTEXT"; + listenerName = new ListenerName(listenerConfig.toString()); + + for (int i = 0; i < brokers.length; i++) { + brokerConfig.put(KafkaConfig.BrokerIdProp(), i); + currentBrokerLogDirs[i] = currentBrokerLogDirs[i] == null ? 
createLogDir() : currentBrokerLogDirs[i]; + brokerConfig.put(KafkaConfig.LogDirProp(), currentBrokerLogDirs[i]); + if (!hasListenerConfig) + brokerConfig.put(KafkaConfig.ListenersProp(), listenerName.value() + "://localhost:" + currentBrokerPorts[i]); + brokers[i] = TestUtils.createServer(new KafkaConfig(brokerConfig, true), time); + currentBrokerPorts[i] = brokers[i].boundPort(listenerName); + } + + Map producerProps = new HashMap<>(); + producerProps.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, bootstrapServers()); + if (sslEnabled()) { + producerProps.put(SslConfigs.SSL_TRUSTSTORE_LOCATION_CONFIG, brokerConfig.get(SslConfigs.SSL_TRUSTSTORE_LOCATION_CONFIG)); + producerProps.put(SslConfigs.SSL_TRUSTSTORE_PASSWORD_CONFIG, brokerConfig.get(SslConfigs.SSL_TRUSTSTORE_PASSWORD_CONFIG)); + producerProps.put(CommonClientConfigs.SECURITY_PROTOCOL_CONFIG, "SSL"); + } + producer = new KafkaProducer<>(producerProps, new ByteArraySerializer(), new ByteArraySerializer()); + } + + public void stopOnlyKafka() { + stop(false, false); + } + + public void stop() { + stop(true, true); + } + + private void stop(boolean deleteLogDirs, boolean stopZK) { + try { + if (producer != null) { + producer.close(); + } + } catch (Exception e) { + log.error("Could not shutdown producer ", e); + throw new RuntimeException("Could not shutdown producer", e); + } + + for (KafkaServer broker : brokers) { + try { + broker.shutdown(); + } catch (Throwable t) { + String msg = String.format("Could not shutdown broker at %s", address(broker)); + log.error(msg, t); + throw new RuntimeException(msg, t); + } + } + + if (deleteLogDirs) { + for (KafkaServer broker : brokers) { + try { + log.info("Cleaning up kafka log dirs at {}", broker.config().logDirs()); + CoreUtils.delete(broker.config().logDirs()); + } catch (Throwable t) { + String msg = String.format("Could not clean up log dirs for broker at %s", + address(broker)); + log.error(msg, t); + throw new RuntimeException(msg, t); + } + } + } + + try { + if (stopZK) { + zookeeper.shutdown(); + } + } catch (Throwable t) { + String msg = String.format("Could not shutdown zookeeper at %s", zKConnectString()); + log.error(msg, t); + throw new RuntimeException(msg, t); + } + } + + private static void putIfAbsent(final Properties props, final String propertyKey, final Object propertyValue) { + if (!props.containsKey(propertyKey)) { + props.put(propertyKey, propertyValue); + } + } + + private String createLogDir() { + try { + return Files.createTempDirectory(getClass().getSimpleName()).toString(); + } catch (IOException e) { + log.error("Unable to create temporary log directory", e); + throw new ConnectException("Unable to create temporary log directory", e); + } + } + + public String bootstrapServers() { + return Arrays.stream(brokers) + .map(this::address) + .collect(Collectors.joining(",")); + } + + public String address(KafkaServer server) { + final EndPoint endPoint = server.advertisedListeners().head(); + return endPoint.host() + ":" + endPoint.port(); + } + + public String zKConnectString() { + return "127.0.0.1:" + zookeeper.port(); + } + + /** + * Get the brokers that have a {@link BrokerState#RUNNING} state. + * + * @return the list of {@link KafkaServer} instances that are running; + * never null but possibly empty + */ + public Set runningBrokers() { + return brokersInState(state -> state == BrokerState.RUNNING); + } + + /** + * Get the brokers whose state match the given predicate. 
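+ *
+ * <p>For example (sketch; {@code cluster} is an already started instance of this class):
+ * <pre>{@code
+ * Set<KafkaServer> notRunning = cluster.brokersInState(state -> state != BrokerState.RUNNING);
+ * }</pre>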
+ * + * @return the list of {@link KafkaServer} instances with states that match the predicate; + * never null but possibly empty + */ + public Set brokersInState(Predicate desiredState) { + return Arrays.stream(brokers) + .filter(b -> hasState(b, desiredState)) + .collect(Collectors.toSet()); + } + + protected boolean hasState(KafkaServer server, Predicate desiredState) { + try { + return desiredState.test(server.brokerState()); + } catch (Throwable e) { + // Broker failed to respond. + return false; + } + } + + public boolean sslEnabled() { + final String listeners = brokerConfig.getProperty(KafkaConfig.ListenersProp()); + return listeners != null && listeners.contains("SSL"); + } + + /** + * Get the topic descriptions of the named topics. The value of the map entry will be empty + * if the topic does not exist. + * + * @param topicNames the names of the topics to describe + * @return the map of optional {@link TopicDescription} keyed by the topic name + */ + public Map> describeTopics(String... topicNames) { + return describeTopics(new HashSet<>(Arrays.asList(topicNames))); + } + + /** + * Get the topic descriptions of the named topics. The value of the map entry will be empty + * if the topic does not exist. + * + * @param topicNames the names of the topics to describe + * @return the map of optional {@link TopicDescription} keyed by the topic name + */ + public Map> describeTopics(Set topicNames) { + Map> results = new HashMap<>(); + log.info("Describing topics {}", topicNames); + try (Admin admin = createAdminClient()) { + DescribeTopicsResult result = admin.describeTopics(topicNames); + Map> byName = result.topicNameValues(); + for (Map.Entry> entry : byName.entrySet()) { + String topicName = entry.getKey(); + try { + TopicDescription desc = entry.getValue().get(); + results.put(topicName, Optional.of(desc)); + log.info("Found topic {} : {}", topicName, desc); + } catch (ExecutionException e) { + Throwable cause = e.getCause(); + if (cause instanceof UnknownTopicOrPartitionException) { + results.put(topicName, Optional.empty()); + log.info("Found non-existant topic {}", topicName); + continue; + } + throw new AssertionError("Could not describe topic(s)" + topicNames, e); + } + } + } catch (Exception e) { + throw new AssertionError("Could not describe topic(s) " + topicNames, e); + } + log.info("Found topics {}", results); + return results; + } + + /** + * Create a Kafka topic with 1 partition and a replication factor of 1. + * + * @param topic The name of the topic. + */ + public void createTopic(String topic) { + createTopic(topic, 1); + } + + /** + * Create a Kafka topic with given partition and a replication factor of 1. + * + * @param topic The name of the topic. + */ + public void createTopic(String topic, int partitions) { + createTopic(topic, partitions, 1, Collections.emptyMap()); + } + + /** + * Create a Kafka topic with given partition, replication factor, and topic config. + * + * @param topic The name of the topic. + */ + public void createTopic(String topic, int partitions, int replication, Map topicConfig) { + createTopic(topic, partitions, replication, topicConfig, new Properties()); + } + + /** + * Create a Kafka topic with the given parameters. + * + * @param topic The name of the topic. + * @param partitions The number of partitions for this topic. + * @param replication The replication factor for (partitions of) this topic. + * @param topicConfig Additional topic-level configuration settings. 
+ * @param adminClientConfig Additional admin client configuration settings. + */ + public void createTopic(String topic, int partitions, int replication, Map topicConfig, Properties adminClientConfig) { + if (replication > brokers.length) { + throw new InvalidReplicationFactorException("Insufficient brokers (" + + brokers.length + ") for desired replication (" + replication + ")"); + } + + log.info("Creating topic { name: {}, partitions: {}, replication: {}, config: {} }", + topic, partitions, replication, topicConfig); + final NewTopic newTopic = new NewTopic(topic, partitions, (short) replication); + newTopic.configs(topicConfig); + + try (final Admin adminClient = createAdminClient(adminClientConfig)) { + adminClient.createTopics(Collections.singletonList(newTopic)).all().get(); + } catch (final InterruptedException | ExecutionException e) { + throw new RuntimeException(e); + } + } + + /** + * Delete a Kafka topic. + * + * @param topic the topic to delete; may not be null + */ + public void deleteTopic(String topic) { + try (final Admin adminClient = createAdminClient()) { + adminClient.deleteTopics(Collections.singleton(topic)).all().get(); + } catch (final InterruptedException | ExecutionException e) { + throw new RuntimeException(e); + } + } + + public void produce(String topic, String value) { + produce(topic, null, null, value); + } + + public void produce(String topic, String key, String value) { + produce(topic, null, key, value); + } + + public void produce(String topic, Integer partition, String key, String value) { + ProducerRecord msg = new ProducerRecord<>(topic, partition, key == null ? null : key.getBytes(), value == null ? null : value.getBytes()); + try { + producer.send(msg).get(DEFAULT_PRODUCE_SEND_DURATION_MS, TimeUnit.MILLISECONDS); + } catch (Exception e) { + throw new KafkaException("Could not produce message: " + msg, e); + } + } + + public Admin createAdminClient(Properties adminClientConfig) { + adminClientConfig.put(AdminClientConfig.BOOTSTRAP_SERVERS_CONFIG, bootstrapServers()); + final Object listeners = brokerConfig.get(KafkaConfig.ListenersProp()); + if (listeners != null && listeners.toString().contains("SSL")) { + adminClientConfig.put(SslConfigs.SSL_TRUSTSTORE_LOCATION_CONFIG, brokerConfig.get(SslConfigs.SSL_TRUSTSTORE_LOCATION_CONFIG)); + adminClientConfig.put(SslConfigs.SSL_TRUSTSTORE_PASSWORD_CONFIG, ((Password) brokerConfig.get(SslConfigs.SSL_TRUSTSTORE_PASSWORD_CONFIG)).value()); + adminClientConfig.put(CommonClientConfigs.SECURITY_PROTOCOL_CONFIG, "SSL"); + } + return Admin.create(adminClientConfig); + } + + public Admin createAdminClient() { + return createAdminClient(new Properties()); + } + + /** + * Consume at least n records in a given duration or throw an exception. + * + * @param n the number of expected records in this topic. + * @param maxDuration the max duration to wait for these records (in milliseconds). + * @param topics the topics to subscribe and consume records from. + * @return a {@link ConsumerRecords} collection containing at least n records. + */ + public ConsumerRecords consume(int n, long maxDuration, String... 
topics) { + Map>> records = new HashMap<>(); + int consumedRecords = 0; + try (KafkaConsumer consumer = createConsumerAndSubscribeTo(Collections.emptyMap(), topics)) { + final long startMillis = System.currentTimeMillis(); + long allowedDuration = maxDuration; + while (allowedDuration > 0) { + log.debug("Consuming from {} for {} millis.", Arrays.toString(topics), allowedDuration); + ConsumerRecords rec = consumer.poll(Duration.ofMillis(allowedDuration)); + if (rec.isEmpty()) { + allowedDuration = maxDuration - (System.currentTimeMillis() - startMillis); + continue; + } + for (TopicPartition partition: rec.partitions()) { + final List> r = rec.records(partition); + records.computeIfAbsent(partition, t -> new ArrayList<>()).addAll(r); + consumedRecords += r.size(); + } + if (consumedRecords >= n) { + return new ConsumerRecords<>(records); + } + allowedDuration = maxDuration - (System.currentTimeMillis() - startMillis); + } + } + + throw new RuntimeException("Could not find enough records. found " + consumedRecords + ", expected " + n); + } + + public KafkaConsumer createConsumer(Map consumerProps) { + Map props = new HashMap<>(consumerProps); + + putIfAbsent(props, GROUP_ID_CONFIG, UUID.randomUUID().toString()); + putIfAbsent(props, BOOTSTRAP_SERVERS_CONFIG, bootstrapServers()); + putIfAbsent(props, ENABLE_AUTO_COMMIT_CONFIG, "false"); + putIfAbsent(props, AUTO_OFFSET_RESET_CONFIG, "earliest"); + putIfAbsent(props, KEY_DESERIALIZER_CLASS_CONFIG, "org.apache.kafka.common.serialization.ByteArrayDeserializer"); + putIfAbsent(props, VALUE_DESERIALIZER_CLASS_CONFIG, "org.apache.kafka.common.serialization.ByteArrayDeserializer"); + if (sslEnabled()) { + putIfAbsent(props, SslConfigs.SSL_TRUSTSTORE_LOCATION_CONFIG, brokerConfig.get(SslConfigs.SSL_TRUSTSTORE_LOCATION_CONFIG)); + putIfAbsent(props, SslConfigs.SSL_TRUSTSTORE_PASSWORD_CONFIG, brokerConfig.get(SslConfigs.SSL_TRUSTSTORE_PASSWORD_CONFIG)); + putIfAbsent(props, CommonClientConfigs.SECURITY_PROTOCOL_CONFIG, "SSL"); + } + KafkaConsumer consumer; + try { + consumer = new KafkaConsumer<>(props); + } catch (Throwable t) { + throw new ConnectException("Failed to create consumer", t); + } + return consumer; + } + + public KafkaConsumer createConsumerAndSubscribeTo(Map consumerProps, String... topics) { + KafkaConsumer consumer = createConsumer(consumerProps); + consumer.subscribe(Arrays.asList(topics)); + return consumer; + } + + private static void putIfAbsent(final Map props, final String propertyKey, final Object propertyValue) { + if (!props.containsKey(propertyKey)) { + props.put(propertyKey, propertyValue); + } + } +} diff --git a/connect/runtime/src/test/java/org/apache/kafka/connect/util/clusters/UngracefulShutdownException.java b/connect/runtime/src/test/java/org/apache/kafka/connect/util/clusters/UngracefulShutdownException.java new file mode 100644 index 0000000..2f2b030 --- /dev/null +++ b/connect/runtime/src/test/java/org/apache/kafka/connect/util/clusters/UngracefulShutdownException.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.util.clusters; + +import org.apache.kafka.common.KafkaException; + +/** + * An exception that can be used from within an {@code Exit.Procedure} to mask exit or halt calls + * and signify that the service terminated abruptly. It's intended to be used only from within + * integration tests. + */ +public class UngracefulShutdownException extends KafkaException { + public UngracefulShutdownException(String s) { + super(s); + } + + public UngracefulShutdownException(String s, Throwable throwable) { + super(s, throwable); + } + + public UngracefulShutdownException(Throwable throwable) { + super(throwable); + } +} diff --git a/connect/runtime/src/test/java/org/apache/kafka/connect/util/clusters/WorkerHandle.java b/connect/runtime/src/test/java/org/apache/kafka/connect/util/clusters/WorkerHandle.java new file mode 100644 index 0000000..4d94794 --- /dev/null +++ b/connect/runtime/src/test/java/org/apache/kafka/connect/util/clusters/WorkerHandle.java @@ -0,0 +1,121 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.util.clusters; + +import org.apache.kafka.connect.cli.ConnectDistributed; +import org.apache.kafka.connect.runtime.Connect; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.net.URI; +import java.util.Map; +import java.util.Objects; + +/** + * A handle to a worker executing in a Connect cluster. + */ +public class WorkerHandle { + private static final Logger log = LoggerFactory.getLogger(WorkerHandle.class); + + private final String workerName; + private final Connect worker; + + protected WorkerHandle(String workerName, Connect worker) { + this.workerName = workerName; + this.worker = worker; + } + + /** + * Create and start a new worker with the given properties. + * + * @param name a name for this worker + * @param workerProperties the worker properties + * @return the worker's handle + */ + public static WorkerHandle start(String name, Map workerProperties) { + return new WorkerHandle(name, new ConnectDistributed().startConnect(workerProperties)); + } + + /** + * Stop this worker. + */ + public void stop() { + worker.stop(); + } + + /** + * Determine if this worker is running. 
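+ *
+ * <p>Sketch of the expected lifecycle (worker name and properties are hypothetical):
+ * <pre>{@code
+ * WorkerHandle worker = WorkerHandle.start("worker-1", workerProps);
+ * assertTrue(worker.isRunning());
+ * worker.stop();
+ * }</pre>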
+ * + * @return true if the worker is running, or false otherwise + */ + public boolean isRunning() { + return worker.isRunning(); + } + + /** + * Get the workers's name corresponding to this handle. + * + * @return the worker's name + */ + public String name() { + return workerName; + } + + /** + * Get the workers's url that accepts requests to its REST endpoint. + * + * @return the worker's url + */ + public URI url() { + return worker.restUrl(); + } + + /** + * Get the workers's url that accepts requests to its Admin REST endpoint. + * + * @return the worker's admin url + */ + public URI adminUrl() { + return worker.adminUrl(); + } + + @Override + public String toString() { + return "WorkerHandle{" + + "workerName='" + workerName + '\'' + + "workerURL='" + worker.restUrl() + '\'' + + '}'; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (!(o instanceof WorkerHandle)) { + return false; + } + WorkerHandle that = (WorkerHandle) o; + return Objects.equals(workerName, that.workerName) && + Objects.equals(worker, that.worker); + } + + @Override + public int hashCode() { + return Objects.hash(workerName, worker); + } +} diff --git a/connect/runtime/src/test/resources/META-INF/services/org.apache.kafka.connect.rest.ConnectRestExtension b/connect/runtime/src/test/resources/META-INF/services/org.apache.kafka.connect.rest.ConnectRestExtension new file mode 100644 index 0000000..0a1ef88 --- /dev/null +++ b/connect/runtime/src/test/resources/META-INF/services/org.apache.kafka.connect.rest.ConnectRestExtension @@ -0,0 +1,16 @@ + # Licensed to the Apache Software Foundation (ASF) under one or more + # contributor license agreements. See the NOTICE file distributed with + # this work for additional information regarding copyright ownership. + # The ASF licenses this file to You under the Apache License, Version 2.0 + # (the "License"); you may not use this file except in compliance with + # the License. You may obtain a copy of the License at + # + # http://www.apache.org/licenses/LICENSE-2.0 + # + # Unless required by applicable law or agreed to in writing, software + # distributed under the License is distributed on an "AS IS" BASIS, + # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + # See the License for the specific language governing permissions and + # limitations under the License. + +org.apache.kafka.connect.runtime.isolation.PluginsTest$TestConnectRestExtension \ No newline at end of file diff --git a/connect/runtime/src/test/resources/log4j.properties b/connect/runtime/src/test/resources/log4j.properties new file mode 100644 index 0000000..176692d --- /dev/null +++ b/connect/runtime/src/test/resources/log4j.properties @@ -0,0 +1,35 @@ +## +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+## +log4j.rootLogger=INFO, stdout + +log4j.appender.stdout=org.apache.log4j.ConsoleAppender +log4j.appender.stdout.layout=org.apache.log4j.PatternLayout +# +# The `%X{connector.context}` parameter in the layout includes connector-specific and task-specific information +# in the log message, where appropriate. This makes it easier to identify those log messages that apply to a +# specific connector. Simply add this parameter to the log layout configuration below to include the contextual information. +# +log4j.appender.stdout.layout.ConversionPattern=[%d] %p %X{connector.context}%m (%c:%L)%n +# +# The following line includes no MDC context parameters: +#log4j.appender.stdout.layout.ConversionPattern=[%d] %p %m (%c:%L)%n (%t) + +log4j.logger.org.reflections=ERROR +log4j.logger.kafka=WARN +log4j.logger.org.apache.kafka.connect=DEBUG +log4j.logger.org.apache.kafka.connect.runtime.distributed=DEBUG +log4j.logger.org.apache.kafka.connect.integration=DEBUG diff --git a/connect/runtime/src/test/resources/test-plugins/aliased-static-field/test/plugins/AliasedStaticField.java b/connect/runtime/src/test/resources/test-plugins/aliased-static-field/test/plugins/AliasedStaticField.java new file mode 100644 index 0000000..d865f4e --- /dev/null +++ b/connect/runtime/src/test/resources/test-plugins/aliased-static-field/test/plugins/AliasedStaticField.java @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package test.plugins; + +import java.util.Map; +import java.util.HashMap; +import org.apache.kafka.connect.data.Schema; +import org.apache.kafka.connect.data.SchemaAndValue; +import org.apache.kafka.connect.storage.Converter; +import org.apache.kafka.connect.runtime.isolation.SamplingTestPlugin; + +/** + * Samples data about its initialization environment for later analysis + * Samples are shared between instances of the same class in a static variable + */ +public class AliasedStaticField extends SamplingTestPlugin implements Converter { + + private static final Map SAMPLES; + private static final ClassLoader STATIC_CLASS_LOADER; + private final ClassLoader classloader; + + static { + SAMPLES = new HashMap<>(); + STATIC_CLASS_LOADER = Thread.currentThread().getContextClassLoader(); + } + + { + classloader = Thread.currentThread().getContextClassLoader(); + } + + @Override + public void configure(final Map configs, final boolean isKey) { + + } + + @Override + public byte[] fromConnectData(final String topic, final Schema schema, final Object value) { + return new byte[0]; + } + + @Override + public SchemaAndValue toConnectData(final String topic, final byte[] value) { + return null; + } + + @Override + public ClassLoader staticClassloader() { + return STATIC_CLASS_LOADER; + } + + @Override + public ClassLoader classloader() { + return classloader; + } + + @Override + public Map otherSamples() { + return SAMPLES; + } +} diff --git a/connect/runtime/src/test/resources/test-plugins/always-throw-exception/test/plugins/AlwaysThrowException.java b/connect/runtime/src/test/resources/test-plugins/always-throw-exception/test/plugins/AlwaysThrowException.java new file mode 100644 index 0000000..858f3ed --- /dev/null +++ b/connect/runtime/src/test/resources/test-plugins/always-throw-exception/test/plugins/AlwaysThrowException.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package test.plugins; + +import java.util.Map; +import org.apache.kafka.connect.data.Schema; +import org.apache.kafka.connect.data.SchemaAndValue; +import org.apache.kafka.connect.runtime.isolation.SamplingTestPlugin; +import org.apache.kafka.connect.storage.Converter; + +/** + * Unconditionally throw an exception during static initialization. 
+ */ +public class AlwaysThrowException implements Converter { + + static { + setup(); + } + + public static void setup() { + throw new RuntimeException("I always throw an exception"); + } + + @Override + public void configure(final Map configs, final boolean isKey) { + + } + + @Override + public byte[] fromConnectData(final String topic, final Schema schema, final Object value) { + return new byte[0]; + } + + @Override + public SchemaAndValue toConnectData(final String topic, final byte[] value) { + return null; + } +} diff --git a/connect/runtime/src/test/resources/test-plugins/sampling-config-provider/META-INF/services/org.apache.kafka.common.config.provider.ConfigProvider b/connect/runtime/src/test/resources/test-plugins/sampling-config-provider/META-INF/services/org.apache.kafka.common.config.provider.ConfigProvider new file mode 100644 index 0000000..62d8df2 --- /dev/null +++ b/connect/runtime/src/test/resources/test-plugins/sampling-config-provider/META-INF/services/org.apache.kafka.common.config.provider.ConfigProvider @@ -0,0 +1,16 @@ + # Licensed to the Apache Software Foundation (ASF) under one or more + # contributor license agreements. See the NOTICE file distributed with + # this work for additional information regarding copyright ownership. + # The ASF licenses this file to You under the Apache License, Version 2.0 + # (the "License"); you may not use this file except in compliance with + # the License. You may obtain a copy of the License at + # + # http://www.apache.org/licenses/LICENSE-2.0 + # + # Unless required by applicable law or agreed to in writing, software + # distributed under the License is distributed on an "AS IS" BASIS, + # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + # See the License for the specific language governing permissions and + # limitations under the License. + +test.plugins.SamplingConfigProvider diff --git a/connect/runtime/src/test/resources/test-plugins/sampling-config-provider/test/plugins/SamplingConfigProvider.java b/connect/runtime/src/test/resources/test-plugins/sampling-config-provider/test/plugins/SamplingConfigProvider.java new file mode 100644 index 0000000..df8285e --- /dev/null +++ b/connect/runtime/src/test/resources/test-plugins/sampling-config-provider/test/plugins/SamplingConfigProvider.java @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
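+ *
+ * <p>Because the exception escapes the static initializer, merely initializing the class fails; a
+ * caller would observe something like (sketch):
+ * <pre>{@code
+ * // throws ExceptionInInitializerError wrapping the RuntimeException from setup()
+ * Class.forName("test.plugins.AlwaysThrowException");
+ * }</pre>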
+ */ + +package test.plugins; + +import java.util.Set; +import java.util.Map; +import java.util.HashMap; +import org.apache.kafka.common.config.provider.ConfigProvider; +import org.apache.kafka.common.config.ConfigData; +import org.apache.kafka.common.config.ConfigChangeCallback; +import org.apache.kafka.connect.data.Schema; +import org.apache.kafka.connect.data.SchemaAndValue; +import org.apache.kafka.connect.storage.Converter; +import org.apache.kafka.connect.runtime.isolation.SamplingTestPlugin; +import org.apache.kafka.connect.storage.HeaderConverter; + +/** + * Samples data about its initialization environment for later analysis + */ +public class SamplingConfigProvider extends SamplingTestPlugin implements ConfigProvider { + + private static final ClassLoader STATIC_CLASS_LOADER; + private final ClassLoader classloader; + private Map samples; + + static { + STATIC_CLASS_LOADER = Thread.currentThread().getContextClassLoader(); + } + + { + samples = new HashMap<>(); + classloader = Thread.currentThread().getContextClassLoader(); + } + + @Override + public ConfigData get(String path) { + logMethodCall(samples); + return null; + } + + @Override + public ConfigData get(String path, Set keys) { + logMethodCall(samples); + return null; + } + + @Override + public void subscribe(String path, Set keys, ConfigChangeCallback callback) { + logMethodCall(samples); + } + + @Override + public void unsubscribe(String path, Set keys, ConfigChangeCallback callback) { + logMethodCall(samples); + } + + @Override + public void unsubscribeAll() { + logMethodCall(samples); + } + + @Override + public void configure(final Map configs) { + logMethodCall(samples); + } + + @Override + public void close() { + logMethodCall(samples); + } + + @Override + public ClassLoader staticClassloader() { + return STATIC_CLASS_LOADER; + } + + @Override + public ClassLoader classloader() { + return classloader; + } + + @Override + public Map otherSamples() { + return samples; + } +} diff --git a/connect/runtime/src/test/resources/test-plugins/sampling-configurable/test/plugins/SamplingConfigurable.java b/connect/runtime/src/test/resources/test-plugins/sampling-configurable/test/plugins/SamplingConfigurable.java new file mode 100644 index 0000000..a917f2f --- /dev/null +++ b/connect/runtime/src/test/resources/test-plugins/sampling-configurable/test/plugins/SamplingConfigurable.java @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package test.plugins; + +import java.util.Map; +import java.util.HashMap; +import org.apache.kafka.common.Configurable; +import org.apache.kafka.connect.data.Schema; +import org.apache.kafka.connect.data.SchemaAndValue; +import org.apache.kafka.connect.storage.Converter; +import org.apache.kafka.connect.runtime.isolation.SamplingTestPlugin; + +/** + * Samples data about its initialization environment for later analysis + */ +public class SamplingConfigurable extends SamplingTestPlugin implements Converter, Configurable { + + private static final ClassLoader STATIC_CLASS_LOADER; + private final ClassLoader classloader; + private Map samples; + + static { + STATIC_CLASS_LOADER = Thread.currentThread().getContextClassLoader(); + } + + { + samples = new HashMap<>(); + classloader = Thread.currentThread().getContextClassLoader(); + } + + @Override + public void configure(final Map configs) { + logMethodCall(samples); + } + + @Override + public void configure(final Map configs, final boolean isKey) { + } + + @Override + public byte[] fromConnectData(final String topic, final Schema schema, final Object value) { + return new byte[0]; + } + + @Override + public SchemaAndValue toConnectData(final String topic, final byte[] value) { + return null; + } + + @Override + public ClassLoader staticClassloader() { + return STATIC_CLASS_LOADER; + } + + @Override + public ClassLoader classloader() { + return classloader; + } + + @Override + public Map otherSamples() { + return samples; + } +} diff --git a/connect/runtime/src/test/resources/test-plugins/sampling-converter/test/plugins/SamplingConverter.java b/connect/runtime/src/test/resources/test-plugins/sampling-converter/test/plugins/SamplingConverter.java new file mode 100644 index 0000000..39109a1 --- /dev/null +++ b/connect/runtime/src/test/resources/test-plugins/sampling-converter/test/plugins/SamplingConverter.java @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package test.plugins; + +import java.util.Map; +import java.util.HashMap; +import org.apache.kafka.connect.data.Schema; +import org.apache.kafka.connect.data.SchemaAndValue; +import org.apache.kafka.connect.storage.Converter; +import org.apache.kafka.connect.runtime.isolation.SamplingTestPlugin; + +/** + * Samples data about its initialization environment for later analysis + */ +public class SamplingConverter extends SamplingTestPlugin implements Converter { + + private static final ClassLoader STATIC_CLASS_LOADER; + private final ClassLoader classloader; + private Map samples; + + static { + STATIC_CLASS_LOADER = Thread.currentThread().getContextClassLoader(); + } + + { + samples = new HashMap<>(); + classloader = Thread.currentThread().getContextClassLoader(); + } + + @Override + public void configure(final Map configs, final boolean isKey) { + logMethodCall(samples); + } + + @Override + public byte[] fromConnectData(final String topic, final Schema schema, final Object value) { + logMethodCall(samples); + return new byte[0]; + } + + @Override + public SchemaAndValue toConnectData(final String topic, final byte[] value) { + logMethodCall(samples); + return null; + } + + @Override + public ClassLoader staticClassloader() { + return STATIC_CLASS_LOADER; + } + + @Override + public ClassLoader classloader() { + return classloader; + } + + @Override + public Map otherSamples() { + return samples; + } +} diff --git a/connect/runtime/src/test/resources/test-plugins/sampling-header-converter/test/plugins/SamplingHeaderConverter.java b/connect/runtime/src/test/resources/test-plugins/sampling-header-converter/test/plugins/SamplingHeaderConverter.java new file mode 100644 index 0000000..11a1e28 --- /dev/null +++ b/connect/runtime/src/test/resources/test-plugins/sampling-header-converter/test/plugins/SamplingHeaderConverter.java @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package test.plugins; + +import java.util.Map; +import java.util.HashMap; +import org.apache.kafka.common.config.ConfigDef; +import org.apache.kafka.connect.data.Schema; +import org.apache.kafka.connect.data.SchemaAndValue; +import org.apache.kafka.connect.storage.Converter; +import org.apache.kafka.connect.runtime.isolation.SamplingTestPlugin; +import org.apache.kafka.connect.storage.HeaderConverter; + +/** + * Samples data about its initialization environment for later analysis + */ +public class SamplingHeaderConverter extends SamplingTestPlugin implements HeaderConverter { + + private static final ClassLoader STATIC_CLASS_LOADER; + private final ClassLoader classloader; + private Map samples; + + static { + STATIC_CLASS_LOADER = Thread.currentThread().getContextClassLoader(); + } + + { + samples = new HashMap<>(); + classloader = Thread.currentThread().getContextClassLoader(); + } + + @Override + public SchemaAndValue toConnectHeader(String topic, String headerKey, byte[] value) { + logMethodCall(samples); + return null; + } + + @Override + public byte[] fromConnectHeader(String topic, String headerKey, Schema schema, Object value) { + logMethodCall(samples); + return new byte[0]; + } + + @Override + public ConfigDef config() { + logMethodCall(samples); + return null; + } + + @Override + public void configure(final Map configs) { + logMethodCall(samples); + } + + @Override + public void close() { + logMethodCall(samples); + } + + @Override + public ClassLoader staticClassloader() { + return STATIC_CLASS_LOADER; + } + + @Override + public ClassLoader classloader() { + return classloader; + } + + @Override + public Map otherSamples() { + return samples; + } +} diff --git a/connect/runtime/src/test/resources/test-plugins/service-loader/META-INF/services/test.plugins.ServiceLoadedClass b/connect/runtime/src/test/resources/test-plugins/service-loader/META-INF/services/test.plugins.ServiceLoadedClass new file mode 100644 index 0000000..b8db865 --- /dev/null +++ b/connect/runtime/src/test/resources/test-plugins/service-loader/META-INF/services/test.plugins.ServiceLoadedClass @@ -0,0 +1,16 @@ + # Licensed to the Apache Software Foundation (ASF) under one or more + # contributor license agreements. See the NOTICE file distributed with + # this work for additional information regarding copyright ownership. + # The ASF licenses this file to You under the Apache License, Version 2.0 + # (the "License"); you may not use this file except in compliance with + # the License. You may obtain a copy of the License at + # + # http://www.apache.org/licenses/LICENSE-2.0 + # + # Unless required by applicable law or agreed to in writing, software + # distributed under the License is distributed on an "AS IS" BASIS, + # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + # See the License for the specific language governing permissions and + # limitations under the License. + +test.plugins.ServiceLoadedSubclass \ No newline at end of file diff --git a/connect/runtime/src/test/resources/test-plugins/service-loader/test/plugins/ServiceLoadedClass.java b/connect/runtime/src/test/resources/test-plugins/service-loader/test/plugins/ServiceLoadedClass.java new file mode 100644 index 0000000..98677ed --- /dev/null +++ b/connect/runtime/src/test/resources/test-plugins/service-loader/test/plugins/ServiceLoadedClass.java @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package test.plugins; + +import org.apache.kafka.connect.runtime.isolation.SamplingTestPlugin; + +/** + * Superclass for service loaded classes + */ +public class ServiceLoadedClass extends SamplingTestPlugin { + + private static final ClassLoader STATIC_CLASS_LOADER; + private final ClassLoader classloader; + + static { + STATIC_CLASS_LOADER = Thread.currentThread().getContextClassLoader(); + } + + { + classloader = Thread.currentThread().getContextClassLoader(); + } + + @Override + public ClassLoader staticClassloader() { + return STATIC_CLASS_LOADER; + } + + @Override + public ClassLoader classloader() { + return classloader; + } + +} diff --git a/connect/runtime/src/test/resources/test-plugins/service-loader/test/plugins/ServiceLoadedSubclass.java b/connect/runtime/src/test/resources/test-plugins/service-loader/test/plugins/ServiceLoadedSubclass.java new file mode 100644 index 0000000..cfc6b6f --- /dev/null +++ b/connect/runtime/src/test/resources/test-plugins/service-loader/test/plugins/ServiceLoadedSubclass.java @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package test.plugins; + +/** + * Instance of a service loaded class + */ +public class ServiceLoadedSubclass extends ServiceLoadedClass { + + private static final ClassLoader STATIC_CLASS_LOADER; + private final ClassLoader classloader; + + static { + STATIC_CLASS_LOADER = Thread.currentThread().getContextClassLoader(); + } + + { + classloader = Thread.currentThread().getContextClassLoader(); + } + + @Override + public ClassLoader staticClassloader() { + return STATIC_CLASS_LOADER; + } + + @Override + public ClassLoader classloader() { + return classloader; + } + +} diff --git a/connect/runtime/src/test/resources/test-plugins/service-loader/test/plugins/ServiceLoaderPlugin.java b/connect/runtime/src/test/resources/test-plugins/service-loader/test/plugins/ServiceLoaderPlugin.java new file mode 100644 index 0000000..e6371ba --- /dev/null +++ b/connect/runtime/src/test/resources/test-plugins/service-loader/test/plugins/ServiceLoaderPlugin.java @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package test.plugins; + +import java.util.Map; +import java.util.HashMap; +import java.util.ServiceLoader; +import java.util.Iterator; +import org.apache.kafka.connect.data.Schema; +import org.apache.kafka.connect.data.SchemaAndValue; +import org.apache.kafka.connect.storage.Converter; +import org.apache.kafka.connect.runtime.isolation.SamplingTestPlugin; + +/** + * Samples data about its initialization environment for later analysis + */ +public class ServiceLoaderPlugin extends SamplingTestPlugin implements Converter { + + private static final ClassLoader STATIC_CLASS_LOADER; + private static final Map SAMPLES; + private final ClassLoader classloader; + + static { + STATIC_CLASS_LOADER = Thread.currentThread().getContextClassLoader(); + SAMPLES = new HashMap<>(); + Iterator it = ServiceLoader.load(ServiceLoadedClass.class).iterator(); + while (it.hasNext()) { + ServiceLoadedClass loaded = it.next(); + SAMPLES.put(loaded.getClass().getSimpleName() + ".static", loaded); + } + } + + { + classloader = Thread.currentThread().getContextClassLoader(); + Iterator it = ServiceLoader.load(ServiceLoadedClass.class).iterator(); + while (it.hasNext()) { + ServiceLoadedClass loaded = it.next(); + SAMPLES.put(loaded.getClass().getSimpleName() + ".dynamic", loaded); + } + } + + @Override + public void configure(final Map configs, final boolean isKey) { + } + + @Override + public byte[] fromConnectData(final String topic, final Schema schema, final Object value) { + return new byte[0]; + } + + @Override + public SchemaAndValue toConnectData(final String topic, final byte[] value) { + return null; + } + + @Override + public ClassLoader staticClassloader() { + return STATIC_CLASS_LOADER; + } + + @Override + public ClassLoader classloader() { + return classloader; + } + + @Override + public Map otherSamples() { + return SAMPLES; + } +} diff --git a/connect/transforms/src/main/java/org/apache/kafka/connect/transforms/Cast.java b/connect/transforms/src/main/java/org/apache/kafka/connect/transforms/Cast.java new file mode 100644 index 0000000..cf31a00 --- /dev/null +++ b/connect/transforms/src/main/java/org/apache/kafka/connect/transforms/Cast.java @@ -0,0 +1,477 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.connect.transforms; + +import org.apache.kafka.common.cache.Cache; +import org.apache.kafka.common.cache.LRUCache; +import org.apache.kafka.common.cache.SynchronizedCache; +import org.apache.kafka.common.config.ConfigDef; +import org.apache.kafka.common.config.ConfigException; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.connect.connector.ConnectRecord; +import org.apache.kafka.connect.data.ConnectSchema; +import org.apache.kafka.connect.data.Date; +import org.apache.kafka.connect.data.Field; +import org.apache.kafka.connect.data.Schema; +import org.apache.kafka.connect.data.Schema.Type; +import org.apache.kafka.connect.data.SchemaBuilder; +import org.apache.kafka.connect.data.Struct; +import org.apache.kafka.connect.data.Time; +import org.apache.kafka.connect.data.Timestamp; +import org.apache.kafka.connect.data.Values; +import org.apache.kafka.connect.errors.DataException; +import org.apache.kafka.connect.transforms.util.SchemaUtil; +import org.apache.kafka.connect.transforms.util.SimpleConfig; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.nio.ByteBuffer; +import java.util.Base64; +import java.util.EnumSet; +import java.util.HashMap; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Set; + +import static org.apache.kafka.connect.transforms.util.Requirements.requireMap; +import static org.apache.kafka.connect.transforms.util.Requirements.requireStruct; + +public abstract class Cast> implements Transformation { + private static final Logger log = LoggerFactory.getLogger(Cast.class); + + // TODO: Currently we only support top-level field casting. Ideally we could use a dotted notation in the spec to + // allow casting nested fields. + public static final String OVERVIEW_DOC = + "Cast fields or the entire key or value to a specific type, e.g. to force an integer field to a smaller " + + "width. Cast from integers, floats, boolean and string to any other type, " + + "and cast binary to string (base64 encoded)." + + "
<p/>
            Use the concrete transformation type designed for the record key (" + Key.class.getName() + ") " + + "or value (" + Value.class.getName() + ")."; + + public static final String SPEC_CONFIG = "spec"; + + public static final ConfigDef CONFIG_DEF = new ConfigDef() + .define(SPEC_CONFIG, ConfigDef.Type.LIST, ConfigDef.NO_DEFAULT_VALUE, new ConfigDef.Validator() { + @SuppressWarnings("unchecked") + @Override + public void ensureValid(String name, Object valueObject) { + List value = (List) valueObject; + if (value == null || value.isEmpty()) { + throw new ConfigException("Must specify at least one field to cast."); + } + parseFieldTypes(value); + } + + @Override + public String toString() { + return "list of colon-delimited pairs, e.g. foo:bar,abc:xyz"; + } + }, + ConfigDef.Importance.HIGH, + "List of fields and the type to cast them to of the form field1:type,field2:type to cast fields of " + + "Maps or Structs. A single type to cast the entire value. Valid types are int8, int16, int32, " + + "int64, float32, float64, boolean, and string. Note that binary fields can only be cast to string."); + + private static final String PURPOSE = "cast types"; + + private static final Set SUPPORTED_CAST_INPUT_TYPES = EnumSet.of( + Schema.Type.INT8, Schema.Type.INT16, Schema.Type.INT32, Schema.Type.INT64, + Schema.Type.FLOAT32, Schema.Type.FLOAT64, Schema.Type.BOOLEAN, + Schema.Type.STRING, Schema.Type.BYTES + ); + + private static final Set SUPPORTED_CAST_OUTPUT_TYPES = EnumSet.of( + Schema.Type.INT8, Schema.Type.INT16, Schema.Type.INT32, Schema.Type.INT64, + Schema.Type.FLOAT32, Schema.Type.FLOAT64, Schema.Type.BOOLEAN, + Schema.Type.STRING + ); + + // As a special case for casting the entire value (e.g. the incoming key is a int64 but you know it could be an + // int32 and want the smaller width), we use an otherwise invalid field name in the cast spec to track this. 
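    // ------------------------------------------------------------------------------------
    // Illustrative usage sketch (editorial note, not part of the original change): how the
    // per-field spec and the whole-value spec behave on schemaless data. The topic and field
    // names are hypothetical; the snippet assumes org.apache.kafka.connect.source.SourceRecord,
    // java.util.Collections and java.util.Map are available, e.g. in a unit test.
    //
    //   Cast.Value<SourceRecord> xform = new Cast.Value<>();
    //   xform.configure(Collections.singletonMap("spec", "age:int8"));
    //   SourceRecord in = new SourceRecord(null, null, "topic", 0, null,
    //           Collections.singletonMap("age", 42));
    //   Map<?, ?> fieldCast = (Map<?, ?>) xform.apply(in).value();    // {"age" = (byte) 42}
    //
    //   xform.configure(Collections.singletonMap("spec", "string"));  // whole-value cast
    //   Object wholeCast = xform.apply(
    //           new SourceRecord(null, null, "topic", 0, null, 42)).value();  // "42"
    //   xform.close();
    // ------------------------------------------------------------------------------------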
+ private static final String WHOLE_VALUE_CAST = null; + + private Map casts; + private Schema.Type wholeValueCastType; + private Cache schemaUpdateCache; + + @Override + public void configure(Map props) { + final SimpleConfig config = new SimpleConfig(CONFIG_DEF, props); + casts = parseFieldTypes(config.getList(SPEC_CONFIG)); + wholeValueCastType = casts.get(WHOLE_VALUE_CAST); + schemaUpdateCache = new SynchronizedCache<>(new LRUCache<>(16)); + } + + @Override + public R apply(R record) { + if (operatingValue(record) == null) { + return record; + } + + if (operatingSchema(record) == null) { + return applySchemaless(record); + } else { + return applyWithSchema(record); + } + } + + @Override + public ConfigDef config() { + return CONFIG_DEF; + } + + @Override + public void close() { + } + + + private R applySchemaless(R record) { + if (wholeValueCastType != null) { + return newRecord(record, null, castValueToType(null, operatingValue(record), wholeValueCastType)); + } + + final Map value = requireMap(operatingValue(record), PURPOSE); + final HashMap updatedValue = new HashMap<>(value); + for (Map.Entry fieldSpec : casts.entrySet()) { + String field = fieldSpec.getKey(); + updatedValue.put(field, castValueToType(null, value.get(field), fieldSpec.getValue())); + } + return newRecord(record, null, updatedValue); + } + + private R applyWithSchema(R record) { + Schema valueSchema = operatingSchema(record); + Schema updatedSchema = getOrBuildSchema(valueSchema); + + // Whole-record casting + if (wholeValueCastType != null) + return newRecord(record, updatedSchema, castValueToType(valueSchema, operatingValue(record), wholeValueCastType)); + + // Casting within a struct + final Struct value = requireStruct(operatingValue(record), PURPOSE); + + final Struct updatedValue = new Struct(updatedSchema); + for (Field field : value.schema().fields()) { + final Object origFieldValue = value.get(field); + final Schema.Type targetType = casts.get(field.name()); + final Object newFieldValue = targetType != null ? 
castValueToType(field.schema(), origFieldValue, targetType) : origFieldValue; + log.trace("Cast field '{}' from '{}' to '{}'", field.name(), origFieldValue, newFieldValue); + updatedValue.put(updatedSchema.field(field.name()), newFieldValue); + } + return newRecord(record, updatedSchema, updatedValue); + } + + private Schema getOrBuildSchema(Schema valueSchema) { + Schema updatedSchema = schemaUpdateCache.get(valueSchema); + if (updatedSchema != null) + return updatedSchema; + + final SchemaBuilder builder; + if (wholeValueCastType != null) { + builder = SchemaUtil.copySchemaBasics(valueSchema, convertFieldType(wholeValueCastType)); + } else { + builder = SchemaUtil.copySchemaBasics(valueSchema, SchemaBuilder.struct()); + for (Field field : valueSchema.fields()) { + if (casts.containsKey(field.name())) { + SchemaBuilder fieldBuilder = convertFieldType(casts.get(field.name())); + if (field.schema().isOptional()) + fieldBuilder.optional(); + if (field.schema().defaultValue() != null) { + Schema fieldSchema = field.schema(); + fieldBuilder.defaultValue(castValueToType(fieldSchema, fieldSchema.defaultValue(), fieldBuilder.type())); + } + builder.field(field.name(), fieldBuilder.build()); + } else { + builder.field(field.name(), field.schema()); + } + } + } + + if (valueSchema.isOptional()) + builder.optional(); + if (valueSchema.defaultValue() != null) + builder.defaultValue(castValueToType(valueSchema, valueSchema.defaultValue(), builder.type())); + + updatedSchema = builder.build(); + schemaUpdateCache.put(valueSchema, updatedSchema); + return updatedSchema; + } + + private SchemaBuilder convertFieldType(Schema.Type type) { + switch (type) { + case INT8: + return SchemaBuilder.int8(); + case INT16: + return SchemaBuilder.int16(); + case INT32: + return SchemaBuilder.int32(); + case INT64: + return SchemaBuilder.int64(); + case FLOAT32: + return SchemaBuilder.float32(); + case FLOAT64: + return SchemaBuilder.float64(); + case BOOLEAN: + return SchemaBuilder.bool(); + case STRING: + return SchemaBuilder.string(); + default: + throw new DataException("Unexpected type in Cast transformation: " + type); + } + } + + private static Object encodeLogicalType(Schema schema, Object value) { + switch (schema.name()) { + case Date.LOGICAL_NAME: + return Date.fromLogical(schema, (java.util.Date) value); + case Time.LOGICAL_NAME: + return Time.fromLogical(schema, (java.util.Date) value); + case Timestamp.LOGICAL_NAME: + return Timestamp.fromLogical(schema, (java.util.Date) value); + } + return value; + } + + private static Object castValueToType(Schema schema, Object value, Schema.Type targetType) { + try { + if (value == null) return null; + + Schema.Type inferredType = schema == null ? ConnectSchema.schemaType(value.getClass()) : + schema.type(); + if (inferredType == null) { + throw new DataException("Cast transformation was passed a value of type " + value.getClass() + + " which is not supported by Connect's data API"); + } + // Ensure the type we are trying to cast from is supported + validCastType(inferredType, FieldType.INPUT); + + // Perform logical type encoding to their internal representation. 
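            // Illustrative sketch (editorial note, not part of the original change): for a field
            // whose schema is Timestamp.SCHEMA and whose value is new java.util.Date(1000L),
            //   - a cast to int64 goes through encodeLogicalType() below and yields 1000L,
            //     the epoch-millis internal representation, while
            //   - a cast to string skips this branch and is rendered later by castToString()
            //     via Values.dateFormatFor(), producing an ISO-8601 style timestamp.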
+ if (schema != null && schema.name() != null && targetType != Type.STRING) { + value = encodeLogicalType(schema, value); + } + + switch (targetType) { + case INT8: + return castToInt8(value); + case INT16: + return castToInt16(value); + case INT32: + return castToInt32(value); + case INT64: + return castToInt64(value); + case FLOAT32: + return castToFloat32(value); + case FLOAT64: + return castToFloat64(value); + case BOOLEAN: + return castToBoolean(value); + case STRING: + return castToString(value); + default: + throw new DataException(targetType.toString() + " is not supported in the Cast transformation."); + } + } catch (NumberFormatException e) { + throw new DataException("Value (" + value.toString() + ") was out of range for requested data type", e); + } + } + + private static byte castToInt8(Object value) { + if (value instanceof Number) + return ((Number) value).byteValue(); + else if (value instanceof Boolean) + return ((boolean) value) ? (byte) 1 : (byte) 0; + else if (value instanceof String) + return Byte.parseByte((String) value); + else + throw new DataException("Unexpected type in Cast transformation: " + value.getClass()); + } + + private static short castToInt16(Object value) { + if (value instanceof Number) + return ((Number) value).shortValue(); + else if (value instanceof Boolean) + return ((boolean) value) ? (short) 1 : (short) 0; + else if (value instanceof String) + return Short.parseShort((String) value); + else + throw new DataException("Unexpected type in Cast transformation: " + value.getClass()); + } + + private static int castToInt32(Object value) { + if (value instanceof Number) + return ((Number) value).intValue(); + else if (value instanceof Boolean) + return ((boolean) value) ? 1 : 0; + else if (value instanceof String) + return Integer.parseInt((String) value); + else + throw new DataException("Unexpected type in Cast transformation: " + value.getClass()); + } + + private static long castToInt64(Object value) { + if (value instanceof Number) + return ((Number) value).longValue(); + else if (value instanceof Boolean) + return ((boolean) value) ? (long) 1 : (long) 0; + else if (value instanceof String) + return Long.parseLong((String) value); + else + throw new DataException("Unexpected type in Cast transformation: " + value.getClass()); + } + + private static float castToFloat32(Object value) { + if (value instanceof Number) + return ((Number) value).floatValue(); + else if (value instanceof Boolean) + return ((boolean) value) ? 1.f : 0.f; + else if (value instanceof String) + return Float.parseFloat((String) value); + else + throw new DataException("Unexpected type in Cast transformation: " + value.getClass()); + } + + private static double castToFloat64(Object value) { + if (value instanceof Number) + return ((Number) value).doubleValue(); + else if (value instanceof Boolean) + return ((boolean) value) ? 1. 
: 0.; + else if (value instanceof String) + return Double.parseDouble((String) value); + else + throw new DataException("Unexpected type in Cast transformation: " + value.getClass()); + } + + private static boolean castToBoolean(Object value) { + if (value instanceof Number) + return ((Number) value).longValue() != 0L; + else if (value instanceof Boolean) + return (Boolean) value; + else if (value instanceof String) + return Boolean.parseBoolean((String) value); + else + throw new DataException("Unexpected type in Cast transformation: " + value.getClass()); + } + + private static String castToString(Object value) { + if (value instanceof java.util.Date) { + java.util.Date dateValue = (java.util.Date) value; + return Values.dateFormatFor(dateValue).format(dateValue); + } else if (value instanceof ByteBuffer) { + ByteBuffer byteBuffer = (ByteBuffer) value; + return Base64.getEncoder().encodeToString(Utils.readBytes(byteBuffer)); + } else if (value instanceof byte[]) { + byte[] rawBytes = (byte[]) value; + return Base64.getEncoder().encodeToString(rawBytes); + } else { + return value.toString(); + } + } + + protected abstract Schema operatingSchema(R record); + + protected abstract Object operatingValue(R record); + + protected abstract R newRecord(R record, Schema updatedSchema, Object updatedValue); + + private static Map parseFieldTypes(List mappings) { + final Map m = new HashMap<>(); + boolean isWholeValueCast = false; + for (String mapping : mappings) { + final String[] parts = mapping.split(":"); + if (parts.length > 2) { + throw new ConfigException(ReplaceField.ConfigName.RENAME, mappings, "Invalid rename mapping: " + mapping); + } + if (parts.length == 1) { + Schema.Type targetType = Schema.Type.valueOf(parts[0].trim().toUpperCase(Locale.ROOT)); + m.put(WHOLE_VALUE_CAST, validCastType(targetType, FieldType.OUTPUT)); + isWholeValueCast = true; + } else { + Schema.Type type; + try { + type = Schema.Type.valueOf(parts[1].trim().toUpperCase(Locale.ROOT)); + } catch (IllegalArgumentException e) { + throw new ConfigException("Invalid type found in casting spec: " + parts[1].trim(), e); + } + m.put(parts[0].trim(), validCastType(type, FieldType.OUTPUT)); + } + } + if (isWholeValueCast && mappings.size() > 1) { + throw new ConfigException("Cast transformations that specify a type to cast the entire value to " + + "may ony specify a single cast in their spec"); + } + return m; + } + + private enum FieldType { + INPUT, OUTPUT + } + + private static Schema.Type validCastType(Schema.Type type, FieldType fieldType) { + switch (fieldType) { + case INPUT: + if (!SUPPORTED_CAST_INPUT_TYPES.contains(type)) { + throw new DataException("Cast transformation does not support casting from " + + type + "; supported types are " + SUPPORTED_CAST_INPUT_TYPES); + } + break; + case OUTPUT: + if (!SUPPORTED_CAST_OUTPUT_TYPES.contains(type)) { + throw new ConfigException("Cast transformation does not support casting to " + + type + "; supported types are " + SUPPORTED_CAST_OUTPUT_TYPES); + } + break; + } + return type; + } + + public static final class Key> extends Cast { + @Override + protected Schema operatingSchema(R record) { + return record.keySchema(); + } + + @Override + protected Object operatingValue(R record) { + return record.key(); + } + + @Override + protected R newRecord(R record, Schema updatedSchema, Object updatedValue) { + return record.newRecord(record.topic(), record.kafkaPartition(), updatedSchema, updatedValue, record.valueSchema(), record.value(), record.timestamp()); + } + } + + public 
static final class Value> extends Cast { + @Override + protected Schema operatingSchema(R record) { + return record.valueSchema(); + } + + @Override + protected Object operatingValue(R record) { + return record.value(); + } + + @Override + protected R newRecord(R record, Schema updatedSchema, Object updatedValue) { + return record.newRecord(record.topic(), record.kafkaPartition(), record.keySchema(), record.key(), updatedSchema, updatedValue, record.timestamp()); + } + } + +} diff --git a/connect/transforms/src/main/java/org/apache/kafka/connect/transforms/DropHeaders.java b/connect/transforms/src/main/java/org/apache/kafka/connect/transforms/DropHeaders.java new file mode 100644 index 0000000..6d1e1a4 --- /dev/null +++ b/connect/transforms/src/main/java/org/apache/kafka/connect/transforms/DropHeaders.java @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.transforms; + +import org.apache.kafka.common.config.ConfigDef; +import org.apache.kafka.connect.connector.ConnectRecord; +import org.apache.kafka.connect.header.ConnectHeaders; +import org.apache.kafka.connect.header.Header; +import org.apache.kafka.connect.header.Headers; +import org.apache.kafka.connect.transforms.util.NonEmptyListValidator; +import org.apache.kafka.connect.transforms.util.SimpleConfig; + +import java.util.HashSet; +import java.util.Map; +import java.util.Set; + +import static org.apache.kafka.common.config.ConfigDef.NO_DEFAULT_VALUE; + +public class DropHeaders> implements Transformation { + + public static final String OVERVIEW_DOC = + "Removes one or more headers from each record."; + + public static final String HEADERS_FIELD = "headers"; + + public static final ConfigDef CONFIG_DEF = new ConfigDef() + .define(HEADERS_FIELD, ConfigDef.Type.LIST, + NO_DEFAULT_VALUE, new NonEmptyListValidator(), + ConfigDef.Importance.HIGH, + "The name of the headers to be removed."); + + private Set headers; + + @Override + public R apply(R record) { + Headers updatedHeaders = new ConnectHeaders(); + for (Header header : record.headers()) { + if (!headers.contains(header.key())) { + updatedHeaders.add(header); + } + } + return record.newRecord(record.topic(), record.kafkaPartition(), record.keySchema(), record.key(), + record.valueSchema(), record.value(), record.timestamp(), updatedHeaders); + } + + @Override + public ConfigDef config() { + return CONFIG_DEF; + } + + @Override + public void close() { + } + + @Override + public void configure(Map props) { + final SimpleConfig config = new SimpleConfig(CONFIG_DEF, props); + headers = new HashSet<>(config.getList(HEADERS_FIELD)); + } +} diff --git a/connect/transforms/src/main/java/org/apache/kafka/connect/transforms/ExtractField.java 
b/connect/transforms/src/main/java/org/apache/kafka/connect/transforms/ExtractField.java new file mode 100644 index 0000000..bd3cbd9 --- /dev/null +++ b/connect/transforms/src/main/java/org/apache/kafka/connect/transforms/ExtractField.java @@ -0,0 +1,121 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.transforms; + +import org.apache.kafka.common.config.ConfigDef; +import org.apache.kafka.connect.connector.ConnectRecord; +import org.apache.kafka.connect.data.Field; +import org.apache.kafka.connect.data.Schema; +import org.apache.kafka.connect.data.Struct; +import org.apache.kafka.connect.transforms.util.SimpleConfig; + +import java.util.Map; + +import static org.apache.kafka.connect.transforms.util.Requirements.requireMapOrNull; +import static org.apache.kafka.connect.transforms.util.Requirements.requireStructOrNull; + +public abstract class ExtractField> implements Transformation { + + public static final String OVERVIEW_DOC = + "Extract the specified field from a Struct when schema present, or a Map in the case of schemaless data. " + + "Any null values are passed through unmodified." + + "
<p/>
            Use the concrete transformation type designed for the record key (" + Key.class.getName() + ") " + + "or value (" + Value.class.getName() + ")."; + + private static final String FIELD_CONFIG = "field"; + + public static final ConfigDef CONFIG_DEF = new ConfigDef() + .define(FIELD_CONFIG, ConfigDef.Type.STRING, ConfigDef.NO_DEFAULT_VALUE, ConfigDef.Importance.MEDIUM, "Field name to extract."); + + private static final String PURPOSE = "field extraction"; + + private String fieldName; + + @Override + public void configure(Map props) { + final SimpleConfig config = new SimpleConfig(CONFIG_DEF, props); + fieldName = config.getString(FIELD_CONFIG); + } + + @Override + public R apply(R record) { + final Schema schema = operatingSchema(record); + if (schema == null) { + final Map value = requireMapOrNull(operatingValue(record), PURPOSE); + return newRecord(record, null, value == null ? null : value.get(fieldName)); + } else { + final Struct value = requireStructOrNull(operatingValue(record), PURPOSE); + Field field = schema.field(fieldName); + + if (field == null) { + throw new IllegalArgumentException("Unknown field: " + fieldName); + } + + return newRecord(record, field.schema(), value == null ? null : value.get(fieldName)); + } + } + + @Override + public void close() { + } + + @Override + public ConfigDef config() { + return CONFIG_DEF; + } + + protected abstract Schema operatingSchema(R record); + + protected abstract Object operatingValue(R record); + + protected abstract R newRecord(R record, Schema updatedSchema, Object updatedValue); + + public static class Key> extends ExtractField { + @Override + protected Schema operatingSchema(R record) { + return record.keySchema(); + } + + @Override + protected Object operatingValue(R record) { + return record.key(); + } + + @Override + protected R newRecord(R record, Schema updatedSchema, Object updatedValue) { + return record.newRecord(record.topic(), record.kafkaPartition(), updatedSchema, updatedValue, record.valueSchema(), record.value(), record.timestamp()); + } + } + + public static class Value> extends ExtractField { + @Override + protected Schema operatingSchema(R record) { + return record.valueSchema(); + } + + @Override + protected Object operatingValue(R record) { + return record.value(); + } + + @Override + protected R newRecord(R record, Schema updatedSchema, Object updatedValue) { + return record.newRecord(record.topic(), record.kafkaPartition(), record.keySchema(), record.key(), updatedSchema, updatedValue, record.timestamp()); + } + } + +} diff --git a/connect/transforms/src/main/java/org/apache/kafka/connect/transforms/Filter.java b/connect/transforms/src/main/java/org/apache/kafka/connect/transforms/Filter.java new file mode 100644 index 0000000..d7fb54e --- /dev/null +++ b/connect/transforms/src/main/java/org/apache/kafka/connect/transforms/Filter.java @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.transforms; + +import java.util.Map; + +import org.apache.kafka.common.config.ConfigDef; +import org.apache.kafka.connect.connector.ConnectRecord; + +/** + * Drops all records, filtering them from subsequent transformations in the chain. + * This is intended to be used conditionally to filter out records matching (or not matching) + * a particular {@link org.apache.kafka.connect.transforms.predicates.Predicate}. + * @param The type of record. + */ +public class Filter> implements Transformation { + + public static final String OVERVIEW_DOC = "Drops all records, filtering them from subsequent transformations in the chain. " + + "This is intended to be used conditionally to filter out records matching (or not matching) " + + "a particular Predicate."; + public static final ConfigDef CONFIG_DEF = new ConfigDef(); + + @Override + public R apply(R record) { + return null; + } + + @Override + public ConfigDef config() { + return CONFIG_DEF; + } + + @Override + public void close() { + + } + + @Override + public void configure(Map configs) { + + } +} diff --git a/connect/transforms/src/main/java/org/apache/kafka/connect/transforms/Flatten.java b/connect/transforms/src/main/java/org/apache/kafka/connect/transforms/Flatten.java new file mode 100644 index 0000000..35a57dd --- /dev/null +++ b/connect/transforms/src/main/java/org/apache/kafka/connect/transforms/Flatten.java @@ -0,0 +1,293 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.connect.transforms; + +import org.apache.kafka.common.cache.Cache; +import org.apache.kafka.common.cache.LRUCache; +import org.apache.kafka.common.cache.SynchronizedCache; +import org.apache.kafka.common.config.ConfigDef; +import org.apache.kafka.connect.connector.ConnectRecord; +import org.apache.kafka.connect.data.ConnectSchema; +import org.apache.kafka.connect.data.Field; +import org.apache.kafka.connect.data.Schema; +import org.apache.kafka.connect.data.SchemaBuilder; +import org.apache.kafka.connect.data.Struct; +import org.apache.kafka.connect.errors.DataException; +import org.apache.kafka.connect.transforms.util.SchemaUtil; +import org.apache.kafka.connect.transforms.util.SimpleConfig; + +import java.util.LinkedHashMap; +import java.util.Map; + +import static org.apache.kafka.connect.transforms.util.Requirements.requireMap; +import static org.apache.kafka.connect.transforms.util.Requirements.requireStructOrNull; + +public abstract class Flatten> implements Transformation { + + public static final String OVERVIEW_DOC = + "Flatten a nested data structure, generating names for each field by concatenating the field names at each " + + "level with a configurable delimiter character. Applies to Struct when schema present, or a Map " + + "in the case of schemaless data. Array fields and their contents are not modified. The default delimiter is '.'." + + "
<p/>
            Use the concrete transformation type designed for the record key (" + Key.class.getName() + ") " + + "or value (" + Value.class.getName() + ")."; + + private static final String DELIMITER_CONFIG = "delimiter"; + private static final String DELIMITER_DEFAULT = "."; + + public static final ConfigDef CONFIG_DEF = new ConfigDef() + .define(DELIMITER_CONFIG, ConfigDef.Type.STRING, DELIMITER_DEFAULT, ConfigDef.Importance.MEDIUM, + "Delimiter to insert between field names from the input record when generating field names for the " + + "output record"); + + private static final String PURPOSE = "flattening"; + + private String delimiter; + + private Cache schemaUpdateCache; + + @Override + public void configure(Map props) { + final SimpleConfig config = new SimpleConfig(CONFIG_DEF, props); + delimiter = config.getString(DELIMITER_CONFIG); + schemaUpdateCache = new SynchronizedCache<>(new LRUCache<>(16)); + } + + @Override + public R apply(R record) { + if (operatingValue(record) == null) { + return record; + } else if (operatingSchema(record) == null) { + return applySchemaless(record); + } else { + return applyWithSchema(record); + } + } + + @Override + public void close() { + } + + @Override + public ConfigDef config() { + return CONFIG_DEF; + } + + protected abstract Schema operatingSchema(R record); + + protected abstract Object operatingValue(R record); + + protected abstract R newRecord(R record, Schema updatedSchema, Object updatedValue); + + private R applySchemaless(R record) { + final Map value = requireMap(operatingValue(record), PURPOSE); + final Map newValue = new LinkedHashMap<>(); + applySchemaless(value, "", newValue); + return newRecord(record, null, newValue); + } + + private void applySchemaless(Map originalRecord, String fieldNamePrefix, Map newRecord) { + for (Map.Entry entry : originalRecord.entrySet()) { + final String fieldName = fieldName(fieldNamePrefix, entry.getKey()); + Object value = entry.getValue(); + if (value == null) { + newRecord.put(fieldName(fieldNamePrefix, entry.getKey()), null); + continue; + } + + Schema.Type inferredType = ConnectSchema.schemaType(value.getClass()); + if (inferredType == null) { + throw new DataException("Flatten transformation was passed a value of type " + value.getClass() + + " which is not supported by Connect's data API"); + } + switch (inferredType) { + case INT8: + case INT16: + case INT32: + case INT64: + case FLOAT32: + case FLOAT64: + case BOOLEAN: + case STRING: + case BYTES: + case ARRAY: + newRecord.put(fieldName(fieldNamePrefix, entry.getKey()), entry.getValue()); + break; + case MAP: + final Map fieldValue = requireMap(entry.getValue(), PURPOSE); + applySchemaless(fieldValue, fieldName, newRecord); + break; + default: + throw new DataException("Flatten transformation does not support " + entry.getValue().getClass() + + " for record without schemas (for field " + fieldName + ")."); + } + } + } + + private R applyWithSchema(R record) { + final Struct value = requireStructOrNull(operatingValue(record), PURPOSE); + + Schema schema = operatingSchema(record); + Schema updatedSchema = schemaUpdateCache.get(schema); + if (updatedSchema == null) { + final SchemaBuilder builder = SchemaUtil.copySchemaBasics(schema, SchemaBuilder.struct()); + Struct defaultValue = (Struct) schema.defaultValue(); + buildUpdatedSchema(schema, "", builder, schema.isOptional(), defaultValue); + updatedSchema = builder.build(); + schemaUpdateCache.put(schema, updatedSchema); + } + if (value == null) { + return newRecord(record, updatedSchema, 
null); + } else { + final Struct updatedValue = new Struct(updatedSchema); + buildWithSchema(value, "", updatedValue); + return newRecord(record, updatedSchema, updatedValue); + } + } + + /** + * Build an updated Struct Schema which flattens all nested fields into a single struct, handling cases where + * optionality and default values of the flattened fields are affected by the optionality and default values of + * parent/ancestor schemas (e.g. flattened field is optional because the parent schema was optional, even if the + * schema itself is marked as required). + * @param schema the schema to translate + * @param fieldNamePrefix the prefix to use on field names, i.e. the delimiter-joined set of ancestor field names + * @param newSchema the flattened schema being built + * @param optional true if any ancestor schema is optional + * @param defaultFromParent the default value, if any, included via the parent/ancestor schemas + */ + private void buildUpdatedSchema(Schema schema, String fieldNamePrefix, SchemaBuilder newSchema, boolean optional, Struct defaultFromParent) { + for (Field field : schema.fields()) { + final String fieldName = fieldName(fieldNamePrefix, field.name()); + final boolean fieldIsOptional = optional || field.schema().isOptional(); + Object fieldDefaultValue = null; + if (field.schema().defaultValue() != null) { + fieldDefaultValue = field.schema().defaultValue(); + } else if (defaultFromParent != null) { + fieldDefaultValue = defaultFromParent.get(field); + } + switch (field.schema().type()) { + case INT8: + case INT16: + case INT32: + case INT64: + case FLOAT32: + case FLOAT64: + case BOOLEAN: + case STRING: + case BYTES: + case ARRAY: + newSchema.field(fieldName, convertFieldSchema(field.schema(), fieldIsOptional, fieldDefaultValue)); + break; + case STRUCT: + buildUpdatedSchema(field.schema(), fieldName, newSchema, fieldIsOptional, (Struct) fieldDefaultValue); + break; + default: + throw new DataException("Flatten transformation does not support " + field.schema().type() + + " for record with schemas (for field " + fieldName + ")."); + } + } + } + + /** + * Convert the schema for a field of a Struct with a primitive schema to the schema to be used for the flattened + * version, taking into account that we may need to override optionality and default values in the flattened version + * to take into account the optionality and default values of parent/ancestor schemas + * @param orig the original schema for the field + * @param optional whether the new flattened field should be optional + * @param defaultFromParent the default value either taken from the existing field or provided by the parent + */ + private Schema convertFieldSchema(Schema orig, boolean optional, Object defaultFromParent) { + // Note that we don't use the schema translation cache here. It might save us a bit of effort, but we really + // only care about caching top-level schema translations. 
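        // Illustrative sketch (editorial note, not part of the original change): this is where a
        // flattened field inherits optionality and defaults from its ancestors. With the default
        // "." delimiter and a hypothetical nested schema such as
        //
        //   Schema nested = SchemaBuilder.struct()
        //           .field("child", SchemaBuilder.struct()
        //                   .field("count", Schema.INT32_SCHEMA)
        //                   .optional()
        //                   .build())
        //           .build();
        //
        // the flattened schema gains a field "child.count" whose schema is an optional INT32,
        // because the required "count" field sits inside an optional "child" struct.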
+ + final SchemaBuilder builder = SchemaUtil.copySchemaBasics(orig); + if (optional) + builder.optional(); + if (defaultFromParent != null) + builder.defaultValue(defaultFromParent); + return builder.build(); + } + + private void buildWithSchema(Struct record, String fieldNamePrefix, Struct newRecord) { + if (record == null) { + return; + } + for (Field field : record.schema().fields()) { + final String fieldName = fieldName(fieldNamePrefix, field.name()); + switch (field.schema().type()) { + case INT8: + case INT16: + case INT32: + case INT64: + case FLOAT32: + case FLOAT64: + case BOOLEAN: + case STRING: + case BYTES: + case ARRAY: + newRecord.put(fieldName, record.get(field)); + break; + case STRUCT: + buildWithSchema(record.getStruct(field.name()), fieldName, newRecord); + break; + default: + throw new DataException("Flatten transformation does not support " + field.schema().type() + + " for record with schemas (for field " + fieldName + ")."); + } + } + } + + private String fieldName(String prefix, String fieldName) { + return prefix.isEmpty() ? fieldName : (prefix + delimiter + fieldName); + } + + public static class Key> extends Flatten { + @Override + protected Schema operatingSchema(R record) { + return record.keySchema(); + } + + @Override + protected Object operatingValue(R record) { + return record.key(); + } + + @Override + protected R newRecord(R record, Schema updatedSchema, Object updatedValue) { + return record.newRecord(record.topic(), record.kafkaPartition(), updatedSchema, updatedValue, record.valueSchema(), record.value(), record.timestamp()); + } + } + + public static class Value> extends Flatten { + @Override + protected Schema operatingSchema(R record) { + return record.valueSchema(); + } + + @Override + protected Object operatingValue(R record) { + return record.value(); + } + + @Override + protected R newRecord(R record, Schema updatedSchema, Object updatedValue) { + return record.newRecord(record.topic(), record.kafkaPartition(), record.keySchema(), record.key(), updatedSchema, updatedValue, record.timestamp()); + } + } + +} diff --git a/connect/transforms/src/main/java/org/apache/kafka/connect/transforms/HeaderFrom.java b/connect/transforms/src/main/java/org/apache/kafka/connect/transforms/HeaderFrom.java new file mode 100644 index 0000000..b32ad56 --- /dev/null +++ b/connect/transforms/src/main/java/org/apache/kafka/connect/transforms/HeaderFrom.java @@ -0,0 +1,238 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.connect.transforms; + +import org.apache.kafka.common.cache.Cache; +import org.apache.kafka.common.cache.LRUCache; +import org.apache.kafka.common.cache.SynchronizedCache; +import org.apache.kafka.common.config.ConfigDef; +import org.apache.kafka.common.config.ConfigException; +import org.apache.kafka.connect.connector.ConnectRecord; +import org.apache.kafka.connect.data.Field; +import org.apache.kafka.connect.data.Schema; +import org.apache.kafka.connect.data.SchemaBuilder; +import org.apache.kafka.connect.data.Struct; +import org.apache.kafka.connect.header.Header; +import org.apache.kafka.connect.header.Headers; +import org.apache.kafka.connect.transforms.util.NonEmptyListValidator; +import org.apache.kafka.connect.transforms.util.Requirements; +import org.apache.kafka.connect.transforms.util.SchemaUtil; +import org.apache.kafka.connect.transforms.util.SimpleConfig; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static java.lang.String.format; +import static org.apache.kafka.common.config.ConfigDef.NO_DEFAULT_VALUE; + +public abstract class HeaderFrom> implements Transformation { + + public static final String FIELDS_FIELD = "fields"; + public static final String HEADERS_FIELD = "headers"; + public static final String OPERATION_FIELD = "operation"; + private static final String MOVE_OPERATION = "move"; + private static final String COPY_OPERATION = "copy"; + + public static final String OVERVIEW_DOC = + "Moves or copies fields in the key/value of a record into that record's headers. " + + "Corresponding elements of " + FIELDS_FIELD + " and " + + "" + HEADERS_FIELD + " together identify a field and the header it should be " + + "moved or copied to. " + + "Use the concrete transformation type designed for the record " + + "key (" + Key.class.getName() + ") or value (" + Value.class.getName() + ")."; + + public static final ConfigDef CONFIG_DEF = new ConfigDef() + .define(FIELDS_FIELD, ConfigDef.Type.LIST, + NO_DEFAULT_VALUE, new NonEmptyListValidator(), + ConfigDef.Importance.HIGH, + "Field names in the record whose values are to be copied or moved to headers.") + .define(HEADERS_FIELD, ConfigDef.Type.LIST, + NO_DEFAULT_VALUE, new NonEmptyListValidator(), + ConfigDef.Importance.HIGH, + "Header names, in the same order as the field names listed in the fields configuration property.") + .define(OPERATION_FIELD, ConfigDef.Type.STRING, NO_DEFAULT_VALUE, + ConfigDef.ValidString.in(MOVE_OPERATION, COPY_OPERATION), ConfigDef.Importance.HIGH, + "Either move if the fields are to be moved to the headers (removed from the key/value), " + + "or copy if the fields are to be copied to the headers (retained in the key/value)."); + + enum Operation { + MOVE(MOVE_OPERATION), + COPY(COPY_OPERATION); + + private final String name; + + Operation(String name) { + this.name = name; + } + + static Operation fromName(String name) { + switch (name) { + case MOVE_OPERATION: + return MOVE; + case COPY_OPERATION: + return COPY; + default: + throw new IllegalArgumentException(); + } + } + + public String toString() { + return name; + } + } + + private List fields; + + private List headers; + + private Operation operation; + + private Cache moveSchemaCache = new SynchronizedCache<>(new LRUCache<>(16)); + + @Override + public R apply(R record) { + Object operatingValue = operatingValue(record); + Schema operatingSchema = operatingSchema(record); + + if (operatingSchema == null) { + return applySchemaless(record, operatingValue); + } else { + return 
applyWithSchema(record, operatingValue, operatingSchema); + } + } + + private R applyWithSchema(R record, Object operatingValue, Schema operatingSchema) { + Headers updatedHeaders = record.headers().duplicate(); + Struct value = Requirements.requireStruct(operatingValue, "header " + operation); + final Schema updatedSchema; + final Struct updatedValue; + if (operation == Operation.MOVE) { + updatedSchema = moveSchema(operatingSchema); + updatedValue = new Struct(updatedSchema); + for (Field field : updatedSchema.fields()) { + updatedValue.put(field, value.get(field.name())); + } + } else { + updatedSchema = operatingSchema; + updatedValue = value; + } + for (int i = 0; i < fields.size(); i++) { + String fieldName = fields.get(i); + String headerName = headers.get(i); + Object fieldValue = value.schema().field(fieldName) != null ? value.get(fieldName) : null; + Schema fieldSchema = operatingSchema.field(fieldName).schema(); + updatedHeaders.add(headerName, fieldValue, fieldSchema); + } + return newRecord(record, updatedSchema, updatedValue, updatedHeaders); + } + + private Schema moveSchema(Schema operatingSchema) { + Schema moveSchema = this.moveSchemaCache.get(operatingSchema); + if (moveSchema == null) { + final SchemaBuilder builder = SchemaUtil.copySchemaBasics(operatingSchema, SchemaBuilder.struct()); + for (Field field : operatingSchema.fields()) { + if (!fields.contains(field.name())) { + builder.field(field.name(), field.schema()); + } + } + moveSchema = builder.build(); + moveSchemaCache.put(operatingSchema, moveSchema); + } + return moveSchema; + } + + private R applySchemaless(R record, Object operatingValue) { + Headers updatedHeaders = record.headers().duplicate(); + Map value = Requirements.requireMap(operatingValue, "header " + operation); + Map updatedValue = new HashMap<>(value); + for (int i = 0; i < fields.size(); i++) { + String fieldName = fields.get(i); + Object fieldValue = value.get(fieldName); + String headerName = headers.get(i); + if (operation == Operation.MOVE) { + updatedValue.remove(fieldName); + } + updatedHeaders.add(headerName, fieldValue, null); + } + return newRecord(record, null, updatedValue, updatedHeaders); + } + + protected abstract Object operatingValue(R record); + protected abstract Schema operatingSchema(R record); + protected abstract R newRecord(R record, Schema updatedSchema, Object updatedValue, Iterable
<Header>
            updatedHeaders); + + public static class Key> extends HeaderFrom { + + @Override + public Object operatingValue(R record) { + return record.key(); + } + + @Override + protected Schema operatingSchema(R record) { + return record.keySchema(); + } + + @Override + protected R newRecord(R record, Schema updatedSchema, Object updatedValue, Iterable
            updatedHeaders) { + return record.newRecord(record.topic(), record.kafkaPartition(), updatedSchema, updatedValue, + record.valueSchema(), record.value(), record.timestamp(), updatedHeaders); + } + } + + public static class Value> extends HeaderFrom { + + @Override + public Object operatingValue(R record) { + return record.value(); + } + + @Override + protected Schema operatingSchema(R record) { + return record.valueSchema(); + } + + @Override + protected R newRecord(R record, Schema updatedSchema, Object updatedValue, Iterable
            updatedHeaders) { + return record.newRecord(record.topic(), record.kafkaPartition(), record.keySchema(), record.key(), + updatedSchema, updatedValue, record.timestamp(), updatedHeaders); + } + } + + @Override + public ConfigDef config() { + return CONFIG_DEF; + } + + @Override + public void close() { + + } + + @Override + public void configure(Map props) { + final SimpleConfig config = new SimpleConfig(CONFIG_DEF, props); + fields = config.getList(FIELDS_FIELD); + headers = config.getList(HEADERS_FIELD); + if (headers.size() != fields.size()) { + throw new ConfigException(format("'%s' config must have the same number of elements as '%s' config.", + FIELDS_FIELD, HEADERS_FIELD)); + } + operation = Operation.fromName(config.getString(OPERATION_FIELD)); + } +} diff --git a/connect/transforms/src/main/java/org/apache/kafka/connect/transforms/HoistField.java b/connect/transforms/src/main/java/org/apache/kafka/connect/transforms/HoistField.java new file mode 100644 index 0000000..3614104 --- /dev/null +++ b/connect/transforms/src/main/java/org/apache/kafka/connect/transforms/HoistField.java @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.transforms; + +import org.apache.kafka.common.cache.Cache; +import org.apache.kafka.common.cache.LRUCache; +import org.apache.kafka.common.cache.SynchronizedCache; +import org.apache.kafka.common.config.ConfigDef; +import org.apache.kafka.connect.connector.ConnectRecord; +import org.apache.kafka.connect.data.Schema; +import org.apache.kafka.connect.data.SchemaBuilder; +import org.apache.kafka.connect.data.Struct; +import org.apache.kafka.connect.transforms.util.SimpleConfig; + +import java.util.Collections; +import java.util.Map; + +public abstract class HoistField> implements Transformation { + + public static final String OVERVIEW_DOC = + "Wrap data using the specified field name in a Struct when schema present, or a Map in the case of schemaless data." + + "
<p/>
            Use the concrete transformation type designed for the record key (" + Key.class.getName() + ") " + + "or value (" + Value.class.getName() + ")."; + + private static final String FIELD_CONFIG = "field"; + + public static final ConfigDef CONFIG_DEF = new ConfigDef() + .define(FIELD_CONFIG, ConfigDef.Type.STRING, ConfigDef.NO_DEFAULT_VALUE, ConfigDef.Importance.MEDIUM, + "Field name for the single field that will be created in the resulting Struct or Map."); + + private Cache schemaUpdateCache; + + private String fieldName; + + @Override + public void configure(Map props) { + final SimpleConfig config = new SimpleConfig(CONFIG_DEF, props); + fieldName = config.getString("field"); + schemaUpdateCache = new SynchronizedCache<>(new LRUCache<>(16)); + } + + @Override + public R apply(R record) { + final Schema schema = operatingSchema(record); + final Object value = operatingValue(record); + + if (schema == null) { + return newRecord(record, null, Collections.singletonMap(fieldName, value)); + } else { + Schema updatedSchema = schemaUpdateCache.get(schema); + if (updatedSchema == null) { + updatedSchema = SchemaBuilder.struct().field(fieldName, schema).build(); + schemaUpdateCache.put(schema, updatedSchema); + } + + final Struct updatedValue = new Struct(updatedSchema).put(fieldName, value); + + return newRecord(record, updatedSchema, updatedValue); + } + } + + @Override + public void close() { + schemaUpdateCache = null; + } + + @Override + public ConfigDef config() { + return CONFIG_DEF; + } + + protected abstract Schema operatingSchema(R record); + + protected abstract Object operatingValue(R record); + + protected abstract R newRecord(R record, Schema updatedSchema, Object updatedValue); + + public static class Key> extends HoistField { + @Override + protected Schema operatingSchema(R record) { + return record.keySchema(); + } + + @Override + protected Object operatingValue(R record) { + return record.key(); + } + + @Override + protected R newRecord(R record, Schema updatedSchema, Object updatedValue) { + return record.newRecord(record.topic(), record.kafkaPartition(), updatedSchema, updatedValue, record.valueSchema(), record.value(), record.timestamp()); + } + } + + public static class Value> extends HoistField { + @Override + protected Schema operatingSchema(R record) { + return record.valueSchema(); + } + + @Override + protected Object operatingValue(R record) { + return record.value(); + } + + @Override + protected R newRecord(R record, Schema updatedSchema, Object updatedValue) { + return record.newRecord(record.topic(), record.kafkaPartition(), record.keySchema(), record.key(), updatedSchema, updatedValue, record.timestamp()); + } + } + +} diff --git a/connect/transforms/src/main/java/org/apache/kafka/connect/transforms/InsertField.java b/connect/transforms/src/main/java/org/apache/kafka/connect/transforms/InsertField.java new file mode 100644 index 0000000..cbc820b --- /dev/null +++ b/connect/transforms/src/main/java/org/apache/kafka/connect/transforms/InsertField.java @@ -0,0 +1,277 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.transforms; + +import org.apache.kafka.common.cache.Cache; +import org.apache.kafka.common.cache.LRUCache; +import org.apache.kafka.common.cache.SynchronizedCache; +import org.apache.kafka.common.config.ConfigDef; +import org.apache.kafka.common.config.ConfigException; +import org.apache.kafka.connect.connector.ConnectRecord; +import org.apache.kafka.connect.data.Field; +import org.apache.kafka.connect.data.Schema; +import org.apache.kafka.connect.data.SchemaBuilder; +import org.apache.kafka.connect.data.Struct; +import org.apache.kafka.connect.data.Timestamp; +import org.apache.kafka.connect.transforms.util.SimpleConfig; +import org.apache.kafka.connect.transforms.util.SchemaUtil; + +import java.util.Date; +import java.util.HashMap; +import java.util.Map; + +import static org.apache.kafka.connect.transforms.util.Requirements.requireMap; +import static org.apache.kafka.connect.transforms.util.Requirements.requireSinkRecord; +import static org.apache.kafka.connect.transforms.util.Requirements.requireStruct; + +public abstract class InsertField> implements Transformation { + + public static final String OVERVIEW_DOC = + "Insert field(s) using attributes from the record metadata or a configured static value." + + "
<p/>
            Use the concrete transformation type designed for the record key (" + Key.class.getName() + ") " + + "or value (" + Value.class.getName() + ")."; + + private interface ConfigName { + String TOPIC_FIELD = "topic.field"; + String PARTITION_FIELD = "partition.field"; + String OFFSET_FIELD = "offset.field"; + String TIMESTAMP_FIELD = "timestamp.field"; + String STATIC_FIELD = "static.field"; + String STATIC_VALUE = "static.value"; + } + + private static final String OPTIONALITY_DOC = "Suffix with ! to make this a required field, or ? to keep it optional (the default)."; + + public static final ConfigDef CONFIG_DEF = new ConfigDef() + .define(ConfigName.TOPIC_FIELD, ConfigDef.Type.STRING, null, ConfigDef.Importance.MEDIUM, + "Field name for Kafka topic. " + OPTIONALITY_DOC) + .define(ConfigName.PARTITION_FIELD, ConfigDef.Type.STRING, null, ConfigDef.Importance.MEDIUM, + "Field name for Kafka partition. " + OPTIONALITY_DOC) + .define(ConfigName.OFFSET_FIELD, ConfigDef.Type.STRING, null, ConfigDef.Importance.MEDIUM, + "Field name for Kafka offset - only applicable to sink connectors.
            " + OPTIONALITY_DOC) + .define(ConfigName.TIMESTAMP_FIELD, ConfigDef.Type.STRING, null, ConfigDef.Importance.MEDIUM, + "Field name for record timestamp. " + OPTIONALITY_DOC) + .define(ConfigName.STATIC_FIELD, ConfigDef.Type.STRING, null, ConfigDef.Importance.MEDIUM, + "Field name for static data field. " + OPTIONALITY_DOC) + .define(ConfigName.STATIC_VALUE, ConfigDef.Type.STRING, null, ConfigDef.Importance.MEDIUM, + "Static field value, if field name configured."); + + private static final String PURPOSE = "field insertion"; + + private static final Schema OPTIONAL_TIMESTAMP_SCHEMA = Timestamp.builder().optional().build(); + + private static final class InsertionSpec { + final String name; + final boolean optional; + + private InsertionSpec(String name, boolean optional) { + this.name = name; + this.optional = optional; + } + + public static InsertionSpec parse(String spec) { + if (spec == null) return null; + if (spec.endsWith("?")) { + return new InsertionSpec(spec.substring(0, spec.length() - 1), true); + } + if (spec.endsWith("!")) { + return new InsertionSpec(spec.substring(0, spec.length() - 1), false); + } + return new InsertionSpec(spec, true); + } + } + + private InsertionSpec topicField; + private InsertionSpec partitionField; + private InsertionSpec offsetField; + private InsertionSpec timestampField; + private InsertionSpec staticField; + private String staticValue; + + private Cache schemaUpdateCache; + + @Override + public void configure(Map props) { + final SimpleConfig config = new SimpleConfig(CONFIG_DEF, props); + topicField = InsertionSpec.parse(config.getString(ConfigName.TOPIC_FIELD)); + partitionField = InsertionSpec.parse(config.getString(ConfigName.PARTITION_FIELD)); + offsetField = InsertionSpec.parse(config.getString(ConfigName.OFFSET_FIELD)); + timestampField = InsertionSpec.parse(config.getString(ConfigName.TIMESTAMP_FIELD)); + staticField = InsertionSpec.parse(config.getString(ConfigName.STATIC_FIELD)); + staticValue = config.getString(ConfigName.STATIC_VALUE); + + if (topicField == null && partitionField == null && offsetField == null && timestampField == null && staticField == null) { + throw new ConfigException("No field insertion configured"); + } + + if (staticField != null && staticValue == null) { + throw new ConfigException(ConfigName.STATIC_VALUE, null, "No value specified for static field: " + staticField); + } + + schemaUpdateCache = new SynchronizedCache<>(new LRUCache<>(16)); + } + + @Override + public R apply(R record) { + if (operatingValue(record) == null) { + return record; + } else if (operatingSchema(record) == null) { + return applySchemaless(record); + } else { + return applyWithSchema(record); + } + } + + private R applySchemaless(R record) { + final Map value = requireMap(operatingValue(record), PURPOSE); + + final Map updatedValue = new HashMap<>(value); + + if (topicField != null) { + updatedValue.put(topicField.name, record.topic()); + } + if (partitionField != null && record.kafkaPartition() != null) { + updatedValue.put(partitionField.name, record.kafkaPartition()); + } + if (offsetField != null) { + updatedValue.put(offsetField.name, requireSinkRecord(record, PURPOSE).kafkaOffset()); + } + if (timestampField != null && record.timestamp() != null) { + updatedValue.put(timestampField.name, record.timestamp()); + } + if (staticField != null && staticValue != null) { + updatedValue.put(staticField.name, staticValue); + } + + return newRecord(record, null, updatedValue); + } + + private R applyWithSchema(R record) { + final 
Struct value = requireStruct(operatingValue(record), PURPOSE); + + Schema updatedSchema = schemaUpdateCache.get(value.schema()); + if (updatedSchema == null) { + updatedSchema = makeUpdatedSchema(value.schema()); + schemaUpdateCache.put(value.schema(), updatedSchema); + } + + final Struct updatedValue = new Struct(updatedSchema); + + for (Field field : value.schema().fields()) { + updatedValue.put(field.name(), value.get(field)); + } + + if (topicField != null) { + updatedValue.put(topicField.name, record.topic()); + } + if (partitionField != null && record.kafkaPartition() != null) { + updatedValue.put(partitionField.name, record.kafkaPartition()); + } + if (offsetField != null) { + updatedValue.put(offsetField.name, requireSinkRecord(record, PURPOSE).kafkaOffset()); + } + if (timestampField != null && record.timestamp() != null) { + updatedValue.put(timestampField.name, new Date(record.timestamp())); + } + if (staticField != null && staticValue != null) { + updatedValue.put(staticField.name, staticValue); + } + + return newRecord(record, updatedSchema, updatedValue); + } + + private Schema makeUpdatedSchema(Schema schema) { + final SchemaBuilder builder = SchemaUtil.copySchemaBasics(schema, SchemaBuilder.struct()); + + for (Field field : schema.fields()) { + builder.field(field.name(), field.schema()); + } + + if (topicField != null) { + builder.field(topicField.name, topicField.optional ? Schema.OPTIONAL_STRING_SCHEMA : Schema.STRING_SCHEMA); + } + if (partitionField != null) { + builder.field(partitionField.name, partitionField.optional ? Schema.OPTIONAL_INT32_SCHEMA : Schema.INT32_SCHEMA); + } + if (offsetField != null) { + builder.field(offsetField.name, offsetField.optional ? Schema.OPTIONAL_INT64_SCHEMA : Schema.INT64_SCHEMA); + } + if (timestampField != null) { + builder.field(timestampField.name, timestampField.optional ? OPTIONAL_TIMESTAMP_SCHEMA : Timestamp.SCHEMA); + } + if (staticField != null) { + builder.field(staticField.name, staticField.optional ? 
Schema.OPTIONAL_STRING_SCHEMA : Schema.STRING_SCHEMA); + } + + return builder.build(); + } + + @Override + public void close() { + schemaUpdateCache = null; + } + + @Override + public ConfigDef config() { + return CONFIG_DEF; + } + + protected abstract Schema operatingSchema(R record); + + protected abstract Object operatingValue(R record); + + protected abstract R newRecord(R record, Schema updatedSchema, Object updatedValue); + + public static class Key> extends InsertField { + + @Override + protected Schema operatingSchema(R record) { + return record.keySchema(); + } + + @Override + protected Object operatingValue(R record) { + return record.key(); + } + + @Override + protected R newRecord(R record, Schema updatedSchema, Object updatedValue) { + return record.newRecord(record.topic(), record.kafkaPartition(), updatedSchema, updatedValue, record.valueSchema(), record.value(), record.timestamp()); + } + + } + + public static class Value> extends InsertField { + + @Override + protected Schema operatingSchema(R record) { + return record.valueSchema(); + } + + @Override + protected Object operatingValue(R record) { + return record.value(); + } + + @Override + protected R newRecord(R record, Schema updatedSchema, Object updatedValue) { + return record.newRecord(record.topic(), record.kafkaPartition(), record.keySchema(), record.key(), updatedSchema, updatedValue, record.timestamp()); + } + + } + +} diff --git a/connect/transforms/src/main/java/org/apache/kafka/connect/transforms/InsertHeader.java b/connect/transforms/src/main/java/org/apache/kafka/connect/transforms/InsertHeader.java new file mode 100644 index 0000000..88b2002 --- /dev/null +++ b/connect/transforms/src/main/java/org/apache/kafka/connect/transforms/InsertHeader.java @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.connect.transforms; + +import org.apache.kafka.common.config.ConfigDef; +import org.apache.kafka.connect.connector.ConnectRecord; +import org.apache.kafka.connect.data.SchemaAndValue; +import org.apache.kafka.connect.data.Values; +import org.apache.kafka.connect.header.Headers; +import org.apache.kafka.connect.transforms.util.SimpleConfig; + +import java.util.Map; + +import static org.apache.kafka.common.config.ConfigDef.NO_DEFAULT_VALUE; + +public class InsertHeader> implements Transformation { + + public static final String OVERVIEW_DOC = + "Add a header to each record."; + + public static final String HEADER_FIELD = "header"; + public static final String VALUE_LITERAL_FIELD = "value.literal"; + + public static final ConfigDef CONFIG_DEF = new ConfigDef() + .define(HEADER_FIELD, ConfigDef.Type.STRING, + NO_DEFAULT_VALUE, new ConfigDef.NonNullValidator(), + ConfigDef.Importance.HIGH, + "The name of the header.") + .define(VALUE_LITERAL_FIELD, ConfigDef.Type.STRING, + NO_DEFAULT_VALUE, new ConfigDef.NonNullValidator(), + ConfigDef.Importance.HIGH, + "The literal value that is to be set as the header value on all records."); + + private String header; + + private SchemaAndValue literalValue; + + @Override + public R apply(R record) { + Headers updatedHeaders = record.headers().duplicate(); + updatedHeaders.add(header, literalValue); + return record.newRecord(record.topic(), record.kafkaPartition(), record.keySchema(), record.key(), + record.valueSchema(), record.value(), record.timestamp(), updatedHeaders); + } + + + @Override + public ConfigDef config() { + return CONFIG_DEF; + } + + @Override + public void close() { + + } + + @Override + public void configure(Map props) { + final SimpleConfig config = new SimpleConfig(CONFIG_DEF, props); + header = config.getString(HEADER_FIELD); + literalValue = Values.parseString(config.getString(VALUE_LITERAL_FIELD)); + } +} diff --git a/connect/transforms/src/main/java/org/apache/kafka/connect/transforms/MaskField.java b/connect/transforms/src/main/java/org/apache/kafka/connect/transforms/MaskField.java new file mode 100644 index 0000000..0d61ac1 --- /dev/null +++ b/connect/transforms/src/main/java/org/apache/kafka/connect/transforms/MaskField.java @@ -0,0 +1,209 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.connect.transforms; + +import org.apache.kafka.common.config.ConfigDef; +import org.apache.kafka.connect.connector.ConnectRecord; +import org.apache.kafka.connect.data.Field; +import org.apache.kafka.connect.data.Schema; +import org.apache.kafka.connect.data.Struct; +import org.apache.kafka.connect.data.Values; +import org.apache.kafka.connect.errors.DataException; +import org.apache.kafka.connect.transforms.util.NonEmptyListValidator; +import org.apache.kafka.connect.transforms.util.SimpleConfig; + +import java.math.BigDecimal; +import java.math.BigInteger; +import java.util.Collections; +import java.util.Date; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.function.Function; + +import static org.apache.kafka.connect.transforms.util.Requirements.requireMap; +import static org.apache.kafka.connect.transforms.util.Requirements.requireStruct; + +public abstract class MaskField> implements Transformation { + + public static final String OVERVIEW_DOC = + "Mask specified fields with a valid null value for the field type (i.e. 0, false, empty string, and so on)." + + "
<p/>
            For numeric and string fields, an optional replacement value can be specified that is converted to the correct type." + + "
<p/>
            Use the concrete transformation type designed for the record key (" + Key.class.getName() + + ") or value (" + Value.class.getName() + ")."; + + private static final String FIELDS_CONFIG = "fields"; + private static final String REPLACEMENT_CONFIG = "replacement"; + + public static final ConfigDef CONFIG_DEF = new ConfigDef() + .define(FIELDS_CONFIG, ConfigDef.Type.LIST, ConfigDef.NO_DEFAULT_VALUE, new NonEmptyListValidator(), + ConfigDef.Importance.HIGH, "Names of fields to mask.") + .define(REPLACEMENT_CONFIG, ConfigDef.Type.STRING, null, new ConfigDef.NonEmptyString(), + ConfigDef.Importance.LOW, "Custom value replacement, that will be applied to all" + + " 'fields' values (numeric or non-empty string values only)."); + + private static final String PURPOSE = "mask fields"; + + private static final Map, Function> REPLACEMENT_MAPPING_FUNC = new HashMap<>(); + private static final Map, Object> PRIMITIVE_VALUE_MAPPING = new HashMap<>(); + + static { + PRIMITIVE_VALUE_MAPPING.put(Boolean.class, Boolean.FALSE); + PRIMITIVE_VALUE_MAPPING.put(Byte.class, (byte) 0); + PRIMITIVE_VALUE_MAPPING.put(Short.class, (short) 0); + PRIMITIVE_VALUE_MAPPING.put(Integer.class, 0); + PRIMITIVE_VALUE_MAPPING.put(Long.class, 0L); + PRIMITIVE_VALUE_MAPPING.put(Float.class, 0f); + PRIMITIVE_VALUE_MAPPING.put(Double.class, 0d); + PRIMITIVE_VALUE_MAPPING.put(BigInteger.class, BigInteger.ZERO); + PRIMITIVE_VALUE_MAPPING.put(BigDecimal.class, BigDecimal.ZERO); + PRIMITIVE_VALUE_MAPPING.put(Date.class, new Date(0)); + PRIMITIVE_VALUE_MAPPING.put(String.class, ""); + + REPLACEMENT_MAPPING_FUNC.put(Byte.class, v -> Values.convertToByte(null, v)); + REPLACEMENT_MAPPING_FUNC.put(Short.class, v -> Values.convertToShort(null, v)); + REPLACEMENT_MAPPING_FUNC.put(Integer.class, v -> Values.convertToInteger(null, v)); + REPLACEMENT_MAPPING_FUNC.put(Long.class, v -> Values.convertToLong(null, v)); + REPLACEMENT_MAPPING_FUNC.put(Float.class, v -> Values.convertToFloat(null, v)); + REPLACEMENT_MAPPING_FUNC.put(Double.class, v -> Values.convertToDouble(null, v)); + REPLACEMENT_MAPPING_FUNC.put(String.class, Function.identity()); + REPLACEMENT_MAPPING_FUNC.put(BigDecimal.class, BigDecimal::new); + REPLACEMENT_MAPPING_FUNC.put(BigInteger.class, BigInteger::new); + } + + private Set maskedFields; + private String replacement; + + @Override + public void configure(Map props) { + final SimpleConfig config = new SimpleConfig(CONFIG_DEF, props); + maskedFields = new HashSet<>(config.getList(FIELDS_CONFIG)); + replacement = config.getString(REPLACEMENT_CONFIG); + } + + @Override + public R apply(R record) { + if (operatingSchema(record) == null) { + return applySchemaless(record); + } else { + return applyWithSchema(record); + } + } + + private R applySchemaless(R record) { + final Map value = requireMap(operatingValue(record), PURPOSE); + final HashMap updatedValue = new HashMap<>(value); + for (String field : maskedFields) { + updatedValue.put(field, masked(value.get(field))); + } + return newRecord(record, updatedValue); + } + + private R applyWithSchema(R record) { + final Struct value = requireStruct(operatingValue(record), PURPOSE); + final Struct updatedValue = new Struct(value.schema()); + for (Field field : value.schema().fields()) { + final Object origFieldValue = value.get(field); + updatedValue.put(field, maskedFields.contains(field.name()) ? 
masked(origFieldValue) : origFieldValue); + } + return newRecord(record, updatedValue); + } + + private Object masked(Object value) { + if (value == null) { + return null; + } + return replacement == null ? maskWithNullValue(value) : maskWithCustomReplacement(value, replacement); + } + + private static Object maskWithCustomReplacement(Object value, String replacement) { + Function replacementMapper = REPLACEMENT_MAPPING_FUNC.get(value.getClass()); + if (replacementMapper == null) { + throw new DataException("Cannot mask value of type " + value.getClass() + " with custom replacement."); + } + try { + return replacementMapper.apply(replacement); + } catch (NumberFormatException ex) { + throw new DataException("Unable to convert " + replacement + " (" + replacement.getClass() + ") to number", ex); + } + } + + private static Object maskWithNullValue(Object value) { + Object maskedValue = PRIMITIVE_VALUE_MAPPING.get(value.getClass()); + if (maskedValue == null) { + if (value instanceof List) + maskedValue = Collections.emptyList(); + else if (value instanceof Map) + maskedValue = Collections.emptyMap(); + else + throw new DataException("Cannot mask value of type: " + value.getClass()); + } + return maskedValue; + } + + @Override + public ConfigDef config() { + return CONFIG_DEF; + } + + @Override + public void close() { + } + + protected abstract Schema operatingSchema(R record); + + protected abstract Object operatingValue(R record); + + protected abstract R newRecord(R base, Object value); + + public static final class Key> extends MaskField { + @Override + protected Schema operatingSchema(R record) { + return record.keySchema(); + } + + @Override + protected Object operatingValue(R record) { + return record.key(); + } + + @Override + protected R newRecord(R record, Object updatedValue) { + return record.newRecord(record.topic(), record.kafkaPartition(), record.keySchema(), updatedValue, record.valueSchema(), record.value(), record.timestamp()); + } + } + + public static final class Value> extends MaskField { + @Override + protected Schema operatingSchema(R record) { + return record.valueSchema(); + } + + @Override + protected Object operatingValue(R record) { + return record.value(); + } + + @Override + protected R newRecord(R record, Object updatedValue) { + return record.newRecord(record.topic(), record.kafkaPartition(), record.keySchema(), record.key(), record.valueSchema(), updatedValue, record.timestamp()); + } + } + +} diff --git a/connect/transforms/src/main/java/org/apache/kafka/connect/transforms/RegexRouter.java b/connect/transforms/src/main/java/org/apache/kafka/connect/transforms/RegexRouter.java new file mode 100644 index 0000000..74a19cd --- /dev/null +++ b/connect/transforms/src/main/java/org/apache/kafka/connect/transforms/RegexRouter.java @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.transforms; + +import org.apache.kafka.common.config.ConfigDef; +import org.apache.kafka.connect.connector.ConnectRecord; +import org.apache.kafka.connect.transforms.util.RegexValidator; +import org.apache.kafka.connect.transforms.util.SimpleConfig; + +import java.util.Map; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +public class RegexRouter> implements Transformation { + + public static final String OVERVIEW_DOC = "Update the record topic using the configured regular expression and replacement string." + + "
<p/>
            Under the hood, the regex is compiled to a java.util.regex.Pattern. " + + "If the pattern matches the input topic, java.util.regex.Matcher#replaceFirst() is used with the replacement string to obtain the new topic."; + + public static final ConfigDef CONFIG_DEF = new ConfigDef() + .define(ConfigName.REGEX, ConfigDef.Type.STRING, ConfigDef.NO_DEFAULT_VALUE, new RegexValidator(), ConfigDef.Importance.HIGH, + "Regular expression to use for matching.") + .define(ConfigName.REPLACEMENT, ConfigDef.Type.STRING, ConfigDef.NO_DEFAULT_VALUE, ConfigDef.Importance.HIGH, + "Replacement string."); + + private interface ConfigName { + String REGEX = "regex"; + String REPLACEMENT = "replacement"; + } + + private Pattern regex; + private String replacement; + + @Override + public void configure(Map props) { + final SimpleConfig config = new SimpleConfig(CONFIG_DEF, props); + regex = Pattern.compile(config.getString(ConfigName.REGEX)); + replacement = config.getString(ConfigName.REPLACEMENT); + } + + @Override + public R apply(R record) { + final Matcher matcher = regex.matcher(record.topic()); + if (matcher.matches()) { + final String topic = matcher.replaceFirst(replacement); + return record.newRecord(topic, record.kafkaPartition(), record.keySchema(), record.key(), record.valueSchema(), record.value(), record.timestamp()); + } + return record; + } + + @Override + public void close() { + } + + @Override + public ConfigDef config() { + return CONFIG_DEF; + } + +} diff --git a/connect/transforms/src/main/java/org/apache/kafka/connect/transforms/ReplaceField.java b/connect/transforms/src/main/java/org/apache/kafka/connect/transforms/ReplaceField.java new file mode 100644 index 0000000..fb02577 --- /dev/null +++ b/connect/transforms/src/main/java/org/apache/kafka/connect/transforms/ReplaceField.java @@ -0,0 +1,249 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.connect.transforms; + +import org.apache.kafka.common.cache.Cache; +import org.apache.kafka.common.cache.LRUCache; +import org.apache.kafka.common.cache.SynchronizedCache; +import org.apache.kafka.common.config.ConfigDef; +import org.apache.kafka.common.config.ConfigDef.Importance; +import org.apache.kafka.common.config.ConfigException; +import org.apache.kafka.common.utils.ConfigUtils; +import org.apache.kafka.connect.connector.ConnectRecord; +import org.apache.kafka.connect.data.Field; +import org.apache.kafka.connect.data.Schema; +import org.apache.kafka.connect.data.SchemaBuilder; +import org.apache.kafka.connect.data.Struct; +import org.apache.kafka.connect.transforms.util.SchemaUtil; +import org.apache.kafka.connect.transforms.util.SimpleConfig; + +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.apache.kafka.connect.transforms.util.Requirements.requireMap; +import static org.apache.kafka.connect.transforms.util.Requirements.requireStruct; + +public abstract class ReplaceField> implements Transformation { + + public static final String OVERVIEW_DOC = "Filter or rename fields." + + "
<p/>
            Use the concrete transformation type designed for the record key (" + Key.class.getName() + ") " + + "or value (" + Value.class.getName() + ")."; + + interface ConfigName { + String EXCLUDE = "exclude"; + String INCLUDE = "include"; + + // for backwards compatibility + String INCLUDE_ALIAS = "whitelist"; + String EXCLUDE_ALIAS = "blacklist"; + + String RENAME = "renames"; + } + + public static final ConfigDef CONFIG_DEF = new ConfigDef() + .define(ConfigName.EXCLUDE, ConfigDef.Type.LIST, Collections.emptyList(), ConfigDef.Importance.MEDIUM, + "Fields to exclude. This takes precedence over the fields to include.") + .define("blacklist", ConfigDef.Type.LIST, null, Importance.LOW, + "Deprecated. Use " + ConfigName.EXCLUDE + " instead.") + .define(ConfigName.INCLUDE, ConfigDef.Type.LIST, Collections.emptyList(), ConfigDef.Importance.MEDIUM, + "Fields to include. If specified, only these fields will be used.") + .define("whitelist", ConfigDef.Type.LIST, null, Importance.LOW, + "Deprecated. Use " + ConfigName.INCLUDE + " instead.") + .define(ConfigName.RENAME, ConfigDef.Type.LIST, Collections.emptyList(), new ConfigDef.Validator() { + @SuppressWarnings("unchecked") + @Override + public void ensureValid(String name, Object value) { + parseRenameMappings((List) value); + } + + @Override + public String toString() { + return "list of colon-delimited pairs, e.g. foo:bar,abc:xyz"; + } + }, ConfigDef.Importance.MEDIUM, "Field rename mappings."); + + private static final String PURPOSE = "field replacement"; + + private List exclude; + private List include; + private Map renames; + private Map reverseRenames; + + private Cache schemaUpdateCache; + + @Override + public void configure(Map configs) { + final SimpleConfig config = new SimpleConfig(CONFIG_DEF, ConfigUtils.translateDeprecatedConfigs(configs, new String[][]{ + {ConfigName.INCLUDE, "whitelist"}, + {ConfigName.EXCLUDE, "blacklist"}, + })); + + exclude = config.getList(ConfigName.EXCLUDE); + include = config.getList(ConfigName.INCLUDE); + renames = parseRenameMappings(config.getList(ConfigName.RENAME)); + reverseRenames = invert(renames); + + schemaUpdateCache = new SynchronizedCache<>(new LRUCache<>(16)); + } + + static Map parseRenameMappings(List mappings) { + final Map m = new HashMap<>(); + for (String mapping : mappings) { + final String[] parts = mapping.split(":"); + if (parts.length != 2) { + throw new ConfigException(ConfigName.RENAME, mappings, "Invalid rename mapping: " + mapping); + } + m.put(parts[0], parts[1]); + } + return m; + } + + static Map invert(Map source) { + final Map m = new HashMap<>(); + for (Map.Entry e : source.entrySet()) { + m.put(e.getValue(), e.getKey()); + } + return m; + } + + boolean filter(String fieldName) { + return !exclude.contains(fieldName) && (include.isEmpty() || include.contains(fieldName)); + } + + String renamed(String fieldName) { + final String mapping = renames.get(fieldName); + return mapping == null ? fieldName : mapping; + } + + String reverseRenamed(String fieldName) { + final String mapping = reverseRenames.get(fieldName); + return mapping == null ? 
fieldName : mapping; + } + + @Override + public R apply(R record) { + if (operatingValue(record) == null) { + return record; + } else if (operatingSchema(record) == null) { + return applySchemaless(record); + } else { + return applyWithSchema(record); + } + } + + private R applySchemaless(R record) { + final Map value = requireMap(operatingValue(record), PURPOSE); + + final Map updatedValue = new HashMap<>(value.size()); + + for (Map.Entry e : value.entrySet()) { + final String fieldName = e.getKey(); + if (filter(fieldName)) { + final Object fieldValue = e.getValue(); + updatedValue.put(renamed(fieldName), fieldValue); + } + } + + return newRecord(record, null, updatedValue); + } + + private R applyWithSchema(R record) { + final Struct value = requireStruct(operatingValue(record), PURPOSE); + + Schema updatedSchema = schemaUpdateCache.get(value.schema()); + if (updatedSchema == null) { + updatedSchema = makeUpdatedSchema(value.schema()); + schemaUpdateCache.put(value.schema(), updatedSchema); + } + + final Struct updatedValue = new Struct(updatedSchema); + + for (Field field : updatedSchema.fields()) { + final Object fieldValue = value.get(reverseRenamed(field.name())); + updatedValue.put(field.name(), fieldValue); + } + + return newRecord(record, updatedSchema, updatedValue); + } + + private Schema makeUpdatedSchema(Schema schema) { + final SchemaBuilder builder = SchemaUtil.copySchemaBasics(schema, SchemaBuilder.struct()); + for (Field field : schema.fields()) { + if (filter(field.name())) { + builder.field(renamed(field.name()), field.schema()); + } + } + return builder.build(); + } + + @Override + public ConfigDef config() { + return CONFIG_DEF; + } + + @Override + public void close() { + schemaUpdateCache = null; + } + + protected abstract Schema operatingSchema(R record); + + protected abstract Object operatingValue(R record); + + protected abstract R newRecord(R record, Schema updatedSchema, Object updatedValue); + + public static class Key> extends ReplaceField { + + @Override + protected Schema operatingSchema(R record) { + return record.keySchema(); + } + + @Override + protected Object operatingValue(R record) { + return record.key(); + } + + @Override + protected R newRecord(R record, Schema updatedSchema, Object updatedValue) { + return record.newRecord(record.topic(), record.kafkaPartition(), updatedSchema, updatedValue, record.valueSchema(), record.value(), record.timestamp()); + } + + } + + public static class Value> extends ReplaceField { + + @Override + protected Schema operatingSchema(R record) { + return record.valueSchema(); + } + + @Override + protected Object operatingValue(R record) { + return record.value(); + } + + @Override + protected R newRecord(R record, Schema updatedSchema, Object updatedValue) { + return record.newRecord(record.topic(), record.kafkaPartition(), record.keySchema(), record.key(), updatedSchema, updatedValue, record.timestamp()); + } + + } + +} diff --git a/connect/transforms/src/main/java/org/apache/kafka/connect/transforms/SetSchemaMetadata.java b/connect/transforms/src/main/java/org/apache/kafka/connect/transforms/SetSchemaMetadata.java new file mode 100644 index 0000000..fd3cbf3 --- /dev/null +++ b/connect/transforms/src/main/java/org/apache/kafka/connect/transforms/SetSchemaMetadata.java @@ -0,0 +1,157 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.transforms; + +import org.apache.kafka.common.config.ConfigDef; +import org.apache.kafka.common.config.ConfigException; +import org.apache.kafka.connect.connector.ConnectRecord; +import org.apache.kafka.connect.data.ConnectSchema; +import org.apache.kafka.connect.data.Field; +import org.apache.kafka.connect.data.Schema; +import org.apache.kafka.connect.data.Struct; +import org.apache.kafka.connect.transforms.util.SimpleConfig; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Map; + +import static org.apache.kafka.connect.transforms.util.Requirements.requireSchema; + +public abstract class SetSchemaMetadata> implements Transformation { + private static final Logger log = LoggerFactory.getLogger(SetSchemaMetadata.class); + + public static final String OVERVIEW_DOC = + "Set the schema name, version or both on the record's key (" + Key.class.getName() + ")" + + " or value (" + Value.class.getName() + ") schema."; + + private interface ConfigName { + String SCHEMA_NAME = "schema.name"; + String SCHEMA_VERSION = "schema.version"; + } + + public static final ConfigDef CONFIG_DEF = new ConfigDef() + .define(ConfigName.SCHEMA_NAME, ConfigDef.Type.STRING, null, ConfigDef.Importance.HIGH, "Schema name to set.") + .define(ConfigName.SCHEMA_VERSION, ConfigDef.Type.INT, null, ConfigDef.Importance.HIGH, "Schema version to set."); + + private String schemaName; + private Integer schemaVersion; + + @Override + public void configure(Map configs) { + final SimpleConfig config = new SimpleConfig(CONFIG_DEF, configs); + schemaName = config.getString(ConfigName.SCHEMA_NAME); + schemaVersion = config.getInt(ConfigName.SCHEMA_VERSION); + + if (schemaName == null && schemaVersion == null) { + throw new ConfigException("Neither schema name nor version configured"); + } + } + + @Override + public R apply(R record) { + final Schema schema = operatingSchema(record); + requireSchema(schema, "updating schema metadata"); + final boolean isArray = schema.type() == Schema.Type.ARRAY; + final boolean isMap = schema.type() == Schema.Type.MAP; + final Schema updatedSchema = new ConnectSchema( + schema.type(), + schema.isOptional(), + schema.defaultValue(), + schemaName != null ? schemaName : schema.name(), + schemaVersion != null ? schemaVersion : schema.version(), + schema.doc(), + schema.parameters(), + schema.fields(), + isMap ? schema.keySchema() : null, + isMap || isArray ? schema.valueSchema() : null + ); + log.trace("Applying SetSchemaMetadata SMT. Original schema: {}, updated schema: {}", + schema, updatedSchema); + return newRecord(record, updatedSchema); + } + + @Override + public ConfigDef config() { + return CONFIG_DEF; + } + + @Override + public void close() { + } + + protected abstract Schema operatingSchema(R record); + + protected abstract R newRecord(R record, Schema updatedSchema); + + /** + * Set the schema name, version or both on the record's key schema. 
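+     * <p>
+     * A hypothetical connector configuration using this transformation (the alias and values are illustrative only;
+     * the schema.name and schema.version property names are the ones defined in this class):
+     * <pre>
+     *     transforms=setKeySchema
+     *     transforms.setKeySchema.type=org.apache.kafka.connect.transforms.SetSchemaMetadata$Key
+     *     transforms.setKeySchema.schema.name=com.example.OrderKey
+     *     transforms.setKeySchema.schema.version=2
+     * </pre>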
+ */ + public static class Key> extends SetSchemaMetadata { + @Override + protected Schema operatingSchema(R record) { + return record.keySchema(); + } + + @Override + protected R newRecord(R record, Schema updatedSchema) { + Object updatedKey = updateSchemaIn(record.key(), updatedSchema); + return record.newRecord(record.topic(), record.kafkaPartition(), updatedSchema, updatedKey, record.valueSchema(), record.value(), record.timestamp()); + } + } + + /** + * Set the schema name, version or both on the record's value schema. + */ + public static class Value> extends SetSchemaMetadata { + @Override + protected Schema operatingSchema(R record) { + return record.valueSchema(); + } + + @Override + protected R newRecord(R record, Schema updatedSchema) { + Object updatedValue = updateSchemaIn(record.value(), updatedSchema); + return record.newRecord(record.topic(), record.kafkaPartition(), record.keySchema(), record.key(), updatedSchema, updatedValue, record.timestamp()); + } + } + + /** + * Utility to check the supplied key or value for references to the old Schema, + * and if so to return an updated key or value object that references the new Schema. + * Note that this method assumes that the new Schema may have a different name and/or version, + * but has fields that exactly match those of the old Schema. + *
<p>
            + * Currently only {@link Struct} objects have references to the {@link Schema}. + * + * @param keyOrValue the key or value object; may be null + * @param updatedSchema the updated schema that has been potentially renamed + * @return the original key or value object if it does not reference the old schema, or + * a copy of the key or value object with updated references to the new schema. + */ + protected static Object updateSchemaIn(Object keyOrValue, Schema updatedSchema) { + if (keyOrValue instanceof Struct) { + Struct origStruct = (Struct) keyOrValue; + Struct newStruct = new Struct(updatedSchema); + for (Field field : updatedSchema.fields()) { + // assume both schemas have exact same fields with same names and schemas ... + newStruct.put(field, origStruct.get(field)); + } + return newStruct; + } + return keyOrValue; + } +} diff --git a/connect/transforms/src/main/java/org/apache/kafka/connect/transforms/TimestampConverter.java b/connect/transforms/src/main/java/org/apache/kafka/connect/transforms/TimestampConverter.java new file mode 100644 index 0000000..a8d5cec --- /dev/null +++ b/connect/transforms/src/main/java/org/apache/kafka/connect/transforms/TimestampConverter.java @@ -0,0 +1,462 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.connect.transforms; + +import org.apache.kafka.common.cache.Cache; +import org.apache.kafka.common.cache.LRUCache; +import org.apache.kafka.common.cache.SynchronizedCache; +import org.apache.kafka.common.config.ConfigDef; +import org.apache.kafka.common.config.ConfigException; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.connect.connector.ConnectRecord; +import org.apache.kafka.connect.data.Field; +import org.apache.kafka.connect.data.Schema; +import org.apache.kafka.connect.data.SchemaBuilder; +import org.apache.kafka.connect.data.Struct; +import org.apache.kafka.connect.data.Time; +import org.apache.kafka.connect.data.Timestamp; +import org.apache.kafka.connect.errors.ConnectException; +import org.apache.kafka.connect.errors.DataException; +import org.apache.kafka.connect.transforms.util.SchemaUtil; +import org.apache.kafka.connect.transforms.util.SimpleConfig; + +import java.text.ParseException; +import java.text.SimpleDateFormat; +import java.util.Arrays; +import java.util.Calendar; +import java.util.Date; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; +import java.util.TimeZone; + +import static org.apache.kafka.connect.transforms.util.Requirements.requireMap; +import static org.apache.kafka.connect.transforms.util.Requirements.requireStructOrNull; + +public abstract class TimestampConverter> implements Transformation { + + public static final String OVERVIEW_DOC = + "Convert timestamps between different formats such as Unix epoch, strings, and Connect Date/Timestamp types." + + "Applies to individual fields or to the entire value." + + "
<p/>
            Use the concrete transformation type designed for the record key (" + TimestampConverter.Key.class.getName() + ") " + + "or value (" + TimestampConverter.Value.class.getName() + ")."; + + public static final String FIELD_CONFIG = "field"; + private static final String FIELD_DEFAULT = ""; + + public static final String TARGET_TYPE_CONFIG = "target.type"; + + public static final String FORMAT_CONFIG = "format"; + private static final String FORMAT_DEFAULT = ""; + + public static final ConfigDef CONFIG_DEF = new ConfigDef() + .define(FIELD_CONFIG, ConfigDef.Type.STRING, FIELD_DEFAULT, ConfigDef.Importance.HIGH, + "The field containing the timestamp, or empty if the entire value is a timestamp") + .define(TARGET_TYPE_CONFIG, ConfigDef.Type.STRING, ConfigDef.Importance.HIGH, + "The desired timestamp representation: string, unix, Date, Time, or Timestamp") + .define(FORMAT_CONFIG, ConfigDef.Type.STRING, FORMAT_DEFAULT, ConfigDef.Importance.MEDIUM, + "A SimpleDateFormat-compatible format for the timestamp. Used to generate the output when type=string " + + "or used to parse the input if the input is a string."); + + private static final String PURPOSE = "converting timestamp formats"; + + private static final String TYPE_STRING = "string"; + private static final String TYPE_UNIX = "unix"; + private static final String TYPE_DATE = "Date"; + private static final String TYPE_TIME = "Time"; + private static final String TYPE_TIMESTAMP = "Timestamp"; + private static final Set VALID_TYPES = new HashSet<>(Arrays.asList(TYPE_STRING, TYPE_UNIX, TYPE_DATE, TYPE_TIME, TYPE_TIMESTAMP)); + + private static final TimeZone UTC = TimeZone.getTimeZone("UTC"); + + public static final Schema OPTIONAL_DATE_SCHEMA = org.apache.kafka.connect.data.Date.builder().optional().schema(); + public static final Schema OPTIONAL_TIMESTAMP_SCHEMA = Timestamp.builder().optional().schema(); + public static final Schema OPTIONAL_TIME_SCHEMA = Time.builder().optional().schema(); + + private interface TimestampTranslator { + /** + * Convert from the type-specific format to the universal java.util.Date format + */ + Date toRaw(Config config, Object orig); + + /** + * Get the schema for this format. + */ + Schema typeSchema(boolean isOptional); + + /** + * Convert from the universal java.util.Date format to the type-specific format + */ + Object toType(Config config, Date orig); + } + + private static final Map TRANSLATORS = new HashMap<>(); + static { + TRANSLATORS.put(TYPE_STRING, new TimestampTranslator() { + @Override + public Date toRaw(Config config, Object orig) { + if (!(orig instanceof String)) + throw new DataException("Expected string timestamp to be a String, but found " + orig.getClass()); + try { + return config.format.parse((String) orig); + } catch (ParseException e) { + throw new DataException("Could not parse timestamp: value (" + orig + ") does not match pattern (" + + config.format.toPattern() + ")", e); + } + } + + @Override + public Schema typeSchema(boolean isOptional) { + return isOptional ? 
Schema.OPTIONAL_STRING_SCHEMA : Schema.STRING_SCHEMA; + } + + @Override + public String toType(Config config, Date orig) { + synchronized (config.format) { + return config.format.format(orig); + } + } + }); + + TRANSLATORS.put(TYPE_UNIX, new TimestampTranslator() { + @Override + public Date toRaw(Config config, Object orig) { + if (!(orig instanceof Long)) + throw new DataException("Expected Unix timestamp to be a Long, but found " + orig.getClass()); + return Timestamp.toLogical(Timestamp.SCHEMA, (Long) orig); + } + + @Override + public Schema typeSchema(boolean isOptional) { + return isOptional ? Schema.OPTIONAL_INT64_SCHEMA : Schema.INT64_SCHEMA; + } + + @Override + public Long toType(Config config, Date orig) { + return Timestamp.fromLogical(Timestamp.SCHEMA, orig); + } + }); + + TRANSLATORS.put(TYPE_DATE, new TimestampTranslator() { + @Override + public Date toRaw(Config config, Object orig) { + if (!(orig instanceof Date)) + throw new DataException("Expected Date to be a java.util.Date, but found " + orig.getClass()); + // Already represented as a java.util.Date and Connect Dates are a subset of valid java.util.Date values + return (Date) orig; + } + + @Override + public Schema typeSchema(boolean isOptional) { + return isOptional ? OPTIONAL_DATE_SCHEMA : org.apache.kafka.connect.data.Date.SCHEMA; + } + + @Override + public Date toType(Config config, Date orig) { + Calendar result = Calendar.getInstance(UTC); + result.setTime(orig); + result.set(Calendar.HOUR_OF_DAY, 0); + result.set(Calendar.MINUTE, 0); + result.set(Calendar.SECOND, 0); + result.set(Calendar.MILLISECOND, 0); + return result.getTime(); + } + }); + + TRANSLATORS.put(TYPE_TIME, new TimestampTranslator() { + @Override + public Date toRaw(Config config, Object orig) { + if (!(orig instanceof Date)) + throw new DataException("Expected Time to be a java.util.Date, but found " + orig.getClass()); + // Already represented as a java.util.Date and Connect Times are a subset of valid java.util.Date values + return (Date) orig; + } + + @Override + public Schema typeSchema(boolean isOptional) { + return isOptional ? OPTIONAL_TIME_SCHEMA : Time.SCHEMA; + } + + @Override + public Date toType(Config config, Date orig) { + Calendar origCalendar = Calendar.getInstance(UTC); + origCalendar.setTime(orig); + Calendar result = Calendar.getInstance(UTC); + result.setTimeInMillis(0L); + result.set(Calendar.HOUR_OF_DAY, origCalendar.get(Calendar.HOUR_OF_DAY)); + result.set(Calendar.MINUTE, origCalendar.get(Calendar.MINUTE)); + result.set(Calendar.SECOND, origCalendar.get(Calendar.SECOND)); + result.set(Calendar.MILLISECOND, origCalendar.get(Calendar.MILLISECOND)); + return result.getTime(); + } + }); + + TRANSLATORS.put(TYPE_TIMESTAMP, new TimestampTranslator() { + @Override + public Date toRaw(Config config, Object orig) { + if (!(orig instanceof Date)) + throw new DataException("Expected Timestamp to be a java.util.Date, but found " + orig.getClass()); + return (Date) orig; + } + + @Override + public Schema typeSchema(boolean isOptional) { + return isOptional ? 
OPTIONAL_TIMESTAMP_SCHEMA : Timestamp.SCHEMA; + } + + @Override + public Date toType(Config config, Date orig) { + return orig; + } + }); + } + + // This is a bit unusual, but allows the transformation config to be passed to static anonymous classes to customize + // their behavior + private static class Config { + Config(String field, String type, SimpleDateFormat format) { + this.field = field; + this.type = type; + this.format = format; + } + String field; + String type; + SimpleDateFormat format; + } + private Config config; + private Cache schemaUpdateCache; + + + @Override + public void configure(Map configs) { + final SimpleConfig simpleConfig = new SimpleConfig(CONFIG_DEF, configs); + final String field = simpleConfig.getString(FIELD_CONFIG); + final String type = simpleConfig.getString(TARGET_TYPE_CONFIG); + String formatPattern = simpleConfig.getString(FORMAT_CONFIG); + schemaUpdateCache = new SynchronizedCache<>(new LRUCache<>(16)); + + if (!VALID_TYPES.contains(type)) { + throw new ConfigException("Unknown timestamp type in TimestampConverter: " + type + ". Valid values are " + + Utils.join(VALID_TYPES, ", ") + "."); + } + if (type.equals(TYPE_STRING) && Utils.isBlank(formatPattern)) { + throw new ConfigException("TimestampConverter requires format option to be specified when using string timestamps"); + } + SimpleDateFormat format = null; + if (!Utils.isBlank(formatPattern)) { + try { + format = new SimpleDateFormat(formatPattern); + format.setTimeZone(UTC); + } catch (IllegalArgumentException e) { + throw new ConfigException("TimestampConverter requires a SimpleDateFormat-compatible pattern for string timestamps: " + + formatPattern, e); + } + } + config = new Config(field, type, format); + } + + @Override + public R apply(R record) { + if (operatingSchema(record) == null) { + return applySchemaless(record); + } else { + return applyWithSchema(record); + } + } + + @Override + public ConfigDef config() { + return CONFIG_DEF; + } + + @Override + public void close() { + } + + public static class Key> extends TimestampConverter { + @Override + protected Schema operatingSchema(R record) { + return record.keySchema(); + } + + @Override + protected Object operatingValue(R record) { + return record.key(); + } + + @Override + protected R newRecord(R record, Schema updatedSchema, Object updatedValue) { + return record.newRecord(record.topic(), record.kafkaPartition(), updatedSchema, updatedValue, record.valueSchema(), record.value(), record.timestamp()); + } + } + + public static class Value> extends TimestampConverter { + @Override + protected Schema operatingSchema(R record) { + return record.valueSchema(); + } + + @Override + protected Object operatingValue(R record) { + return record.value(); + } + + @Override + protected R newRecord(R record, Schema updatedSchema, Object updatedValue) { + return record.newRecord(record.topic(), record.kafkaPartition(), record.keySchema(), record.key(), updatedSchema, updatedValue, record.timestamp()); + } + } + + protected abstract Schema operatingSchema(R record); + + protected abstract Object operatingValue(R record); + + protected abstract R newRecord(R record, Schema updatedSchema, Object updatedValue); + + private R applyWithSchema(R record) { + final Schema schema = operatingSchema(record); + if (config.field.isEmpty()) { + Object value = operatingValue(record); + // New schema is determined by the requested target timestamp type + Schema updatedSchema = TRANSLATORS.get(config.type).typeSchema(schema.isOptional()); + return newRecord(record, 
updatedSchema, convertTimestamp(value, timestampTypeFromSchema(schema))); + } else { + final Struct value = requireStructOrNull(operatingValue(record), PURPOSE); + Schema updatedSchema = schemaUpdateCache.get(schema); + if (updatedSchema == null) { + SchemaBuilder builder = SchemaUtil.copySchemaBasics(schema, SchemaBuilder.struct()); + for (Field field : schema.fields()) { + if (field.name().equals(config.field)) { + builder.field(field.name(), TRANSLATORS.get(config.type).typeSchema(field.schema().isOptional())); + } else { + builder.field(field.name(), field.schema()); + } + } + if (schema.isOptional()) + builder.optional(); + if (schema.defaultValue() != null) { + Struct updatedDefaultValue = applyValueWithSchema((Struct) schema.defaultValue(), builder); + builder.defaultValue(updatedDefaultValue); + } + + updatedSchema = builder.build(); + schemaUpdateCache.put(schema, updatedSchema); + } + + Struct updatedValue = applyValueWithSchema(value, updatedSchema); + return newRecord(record, updatedSchema, updatedValue); + } + } + + private Struct applyValueWithSchema(Struct value, Schema updatedSchema) { + if (value == null) { + return null; + } + Struct updatedValue = new Struct(updatedSchema); + for (Field field : value.schema().fields()) { + final Object updatedFieldValue; + if (field.name().equals(config.field)) { + updatedFieldValue = convertTimestamp(value.get(field), timestampTypeFromSchema(field.schema())); + } else { + updatedFieldValue = value.get(field); + } + updatedValue.put(field.name(), updatedFieldValue); + } + return updatedValue; + } + + private R applySchemaless(R record) { + Object rawValue = operatingValue(record); + if (rawValue == null || config.field.isEmpty()) { + return newRecord(record, null, convertTimestamp(rawValue)); + } else { + final Map value = requireMap(rawValue, PURPOSE); + final HashMap updatedValue = new HashMap<>(value); + updatedValue.put(config.field, convertTimestamp(value.get(config.field))); + return newRecord(record, null, updatedValue); + } + } + + /** + * Determine the type/format of the timestamp based on the schema + */ + private String timestampTypeFromSchema(Schema schema) { + if (Timestamp.LOGICAL_NAME.equals(schema.name())) { + return TYPE_TIMESTAMP; + } else if (org.apache.kafka.connect.data.Date.LOGICAL_NAME.equals(schema.name())) { + return TYPE_DATE; + } else if (Time.LOGICAL_NAME.equals(schema.name())) { + return TYPE_TIME; + } else if (schema.type().equals(Schema.Type.STRING)) { + // If not otherwise specified, string == user-specified string format for timestamps + return TYPE_STRING; + } else if (schema.type().equals(Schema.Type.INT64)) { + // If not otherwise specified, long == unix time + return TYPE_UNIX; + } + throw new ConnectException("Schema " + schema + " does not correspond to a known timestamp type format"); + } + + /** + * Infer the type/format of the timestamp based on the raw Java type + */ + private String inferTimestampType(Object timestamp) { + // Note that we can't infer all types, e.g. Date/Time/Timestamp all have the same runtime representation as a + // java.util.Date + if (timestamp instanceof Date) { + return TYPE_TIMESTAMP; + } else if (timestamp instanceof Long) { + return TYPE_UNIX; + } else if (timestamp instanceof String) { + return TYPE_STRING; + } + throw new DataException("TimestampConverter does not support " + timestamp.getClass() + " objects as timestamps"); + } + + /** + * Convert the given timestamp to the target timestamp format. 
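+     * For example (illustrative): with target.type=string and format=yyyy-MM-dd configured, a schemaless Long input
+     * of 0L is inferred to be Unix epoch milliseconds and is rendered as "1970-01-01" in UTC.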
+ * @param timestamp the input timestamp, may be null + * @param timestampFormat the format of the timestamp, or null if the format should be inferred + * @return the converted timestamp + */ + private Object convertTimestamp(Object timestamp, String timestampFormat) { + if (timestamp == null) { + return null; + } + if (timestampFormat == null) { + timestampFormat = inferTimestampType(timestamp); + } + + TimestampTranslator sourceTranslator = TRANSLATORS.get(timestampFormat); + if (sourceTranslator == null) { + throw new ConnectException("Unsupported timestamp type: " + timestampFormat); + } + Date rawTimestamp = sourceTranslator.toRaw(config, timestamp); + + TimestampTranslator targetTranslator = TRANSLATORS.get(config.type); + if (targetTranslator == null) { + throw new ConnectException("Unsupported timestamp type: " + config.type); + } + return targetTranslator.toType(config, rawTimestamp); + } + + private Object convertTimestamp(Object timestamp) { + return convertTimestamp(timestamp, null); + } +} diff --git a/connect/transforms/src/main/java/org/apache/kafka/connect/transforms/TimestampRouter.java b/connect/transforms/src/main/java/org/apache/kafka/connect/transforms/TimestampRouter.java new file mode 100644 index 0000000..f7b1e58 --- /dev/null +++ b/connect/transforms/src/main/java/org/apache/kafka/connect/transforms/TimestampRouter.java @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.transforms; + +import org.apache.kafka.common.config.ConfigDef; +import org.apache.kafka.connect.connector.ConnectRecord; +import org.apache.kafka.connect.errors.DataException; +import org.apache.kafka.connect.transforms.util.SimpleConfig; + +import java.text.SimpleDateFormat; +import java.util.Date; +import java.util.Map; +import java.util.TimeZone; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +public class TimestampRouter> implements Transformation, AutoCloseable { + + private static final Pattern TOPIC = Pattern.compile("${topic}", Pattern.LITERAL); + + private static final Pattern TIMESTAMP = Pattern.compile("${timestamp}", Pattern.LITERAL); + + public static final String OVERVIEW_DOC = + "Update the record's topic field as a function of the original topic value and the record timestamp." + + "
<p/>
            " + + "This is mainly useful for sink connectors, since the topic field is often used to determine the equivalent entity name in the destination system" + + "(e.g. database table or search index name)."; + + public static final ConfigDef CONFIG_DEF = new ConfigDef() + .define(ConfigName.TOPIC_FORMAT, ConfigDef.Type.STRING, "${topic}-${timestamp}", ConfigDef.Importance.HIGH, + "Format string which can contain ${topic} and ${timestamp} as placeholders for the topic and timestamp, respectively.") + .define(ConfigName.TIMESTAMP_FORMAT, ConfigDef.Type.STRING, "yyyyMMdd", ConfigDef.Importance.HIGH, + "Format string for the timestamp that is compatible with java.text.SimpleDateFormat."); + + private interface ConfigName { + String TOPIC_FORMAT = "topic.format"; + String TIMESTAMP_FORMAT = "timestamp.format"; + } + + private String topicFormat; + private ThreadLocal timestampFormat; + + @Override + public void configure(Map props) { + final SimpleConfig config = new SimpleConfig(CONFIG_DEF, props); + + topicFormat = config.getString(ConfigName.TOPIC_FORMAT); + + final String timestampFormatStr = config.getString(ConfigName.TIMESTAMP_FORMAT); + timestampFormat = ThreadLocal.withInitial(() -> { + final SimpleDateFormat fmt = new SimpleDateFormat(timestampFormatStr); + fmt.setTimeZone(TimeZone.getTimeZone("UTC")); + return fmt; + }); + } + + @Override + public R apply(R record) { + final Long timestamp = record.timestamp(); + if (timestamp == null) { + throw new DataException("Timestamp missing on record: " + record); + } + final String formattedTimestamp = timestampFormat.get().format(new Date(timestamp)); + + final String replace1 = TOPIC.matcher(topicFormat).replaceAll(Matcher.quoteReplacement(record.topic())); + final String updatedTopic = TIMESTAMP.matcher(replace1).replaceAll(Matcher.quoteReplacement(formattedTimestamp)); + return record.newRecord( + updatedTopic, record.kafkaPartition(), + record.keySchema(), record.key(), + record.valueSchema(), record.value(), + record.timestamp() + ); + } + + @Override + public void close() { + timestampFormat.remove(); + } + + @Override + public ConfigDef config() { + return CONFIG_DEF; + } + +} diff --git a/connect/transforms/src/main/java/org/apache/kafka/connect/transforms/ValueToKey.java b/connect/transforms/src/main/java/org/apache/kafka/connect/transforms/ValueToKey.java new file mode 100644 index 0000000..8f843f4 --- /dev/null +++ b/connect/transforms/src/main/java/org/apache/kafka/connect/transforms/ValueToKey.java @@ -0,0 +1,115 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.connect.transforms; + +import org.apache.kafka.common.cache.Cache; +import org.apache.kafka.common.cache.LRUCache; +import org.apache.kafka.common.cache.SynchronizedCache; +import org.apache.kafka.common.config.ConfigDef; +import org.apache.kafka.connect.connector.ConnectRecord; +import org.apache.kafka.connect.data.Field; +import org.apache.kafka.connect.data.Schema; +import org.apache.kafka.connect.data.SchemaBuilder; +import org.apache.kafka.connect.data.Struct; +import org.apache.kafka.connect.errors.DataException; +import org.apache.kafka.connect.transforms.util.NonEmptyListValidator; +import org.apache.kafka.connect.transforms.util.SimpleConfig; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.apache.kafka.connect.transforms.util.Requirements.requireMap; +import static org.apache.kafka.connect.transforms.util.Requirements.requireStruct; + +public class ValueToKey> implements Transformation { + + public static final String OVERVIEW_DOC = "Replace the record key with a new key formed from a subset of fields in the record value."; + + public static final String FIELDS_CONFIG = "fields"; + + public static final ConfigDef CONFIG_DEF = new ConfigDef() + .define(FIELDS_CONFIG, ConfigDef.Type.LIST, ConfigDef.NO_DEFAULT_VALUE, new NonEmptyListValidator(), ConfigDef.Importance.HIGH, + "Field names on the record value to extract as the record key."); + + private static final String PURPOSE = "copying fields from value to key"; + + private List fields; + + private Cache valueToKeySchemaCache; + + @Override + public void configure(Map configs) { + final SimpleConfig config = new SimpleConfig(CONFIG_DEF, configs); + fields = config.getList(FIELDS_CONFIG); + valueToKeySchemaCache = new SynchronizedCache<>(new LRUCache<>(16)); + } + + @Override + public R apply(R record) { + if (record.valueSchema() == null) { + return applySchemaless(record); + } else { + return applyWithSchema(record); + } + } + + private R applySchemaless(R record) { + final Map value = requireMap(record.value(), PURPOSE); + final Map key = new HashMap<>(fields.size()); + for (String field : fields) { + key.put(field, value.get(field)); + } + return record.newRecord(record.topic(), record.kafkaPartition(), null, key, record.valueSchema(), record.value(), record.timestamp()); + } + + private R applyWithSchema(R record) { + final Struct value = requireStruct(record.value(), PURPOSE); + + Schema keySchema = valueToKeySchemaCache.get(value.schema()); + if (keySchema == null) { + final SchemaBuilder keySchemaBuilder = SchemaBuilder.struct(); + for (String field : fields) { + final Field fieldFromValue = value.schema().field(field); + if (fieldFromValue == null) { + throw new DataException("Field does not exist: " + field); + } + keySchemaBuilder.field(field, fieldFromValue.schema()); + } + keySchema = keySchemaBuilder.build(); + valueToKeySchemaCache.put(value.schema(), keySchema); + } + + final Struct key = new Struct(keySchema); + for (String field : fields) { + key.put(field, value.get(field)); + } + + return record.newRecord(record.topic(), record.kafkaPartition(), keySchema, key, value.schema(), value, record.timestamp()); + } + + @Override + public ConfigDef config() { + return CONFIG_DEF; + } + + @Override + public void close() { + valueToKeySchemaCache = null; + } + +} diff --git a/connect/transforms/src/main/java/org/apache/kafka/connect/transforms/predicates/HasHeaderKey.java 
b/connect/transforms/src/main/java/org/apache/kafka/connect/transforms/predicates/HasHeaderKey.java new file mode 100644 index 0000000..f15d426 --- /dev/null +++ b/connect/transforms/src/main/java/org/apache/kafka/connect/transforms/predicates/HasHeaderKey.java @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.transforms.predicates; + +import java.util.Iterator; +import java.util.Map; + +import org.apache.kafka.common.config.ConfigDef; +import org.apache.kafka.connect.connector.ConnectRecord; +import org.apache.kafka.connect.header.Header; +import org.apache.kafka.connect.transforms.util.SimpleConfig; + +/** + * A predicate which is true for records with at least one header with the configured name. + * @param The type of connect record. + */ +public class HasHeaderKey> implements Predicate { + + private static final String NAME_CONFIG = "name"; + public static final String OVERVIEW_DOC = "A predicate which is true for records with at least one header with the configured name."; + public static final ConfigDef CONFIG_DEF = new ConfigDef() + .define(NAME_CONFIG, ConfigDef.Type.STRING, ConfigDef.NO_DEFAULT_VALUE, + new ConfigDef.NonEmptyString(), ConfigDef.Importance.MEDIUM, + "The header name."); + private String name; + + @Override + public ConfigDef config() { + return CONFIG_DEF; + } + + @Override + public boolean test(R record) { + Iterator
<Header>
            headerIterator = record.headers().allWithName(name); + return headerIterator != null && headerIterator.hasNext(); + } + + @Override + public void close() { + + } + + @Override + public void configure(Map configs) { + this.name = new SimpleConfig(config(), configs).getString(NAME_CONFIG); + } + + @Override + public String toString() { + return "HasHeaderKey{" + + "name='" + name + '\'' + + '}'; + } +} diff --git a/connect/transforms/src/main/java/org/apache/kafka/connect/transforms/predicates/RecordIsTombstone.java b/connect/transforms/src/main/java/org/apache/kafka/connect/transforms/predicates/RecordIsTombstone.java new file mode 100644 index 0000000..4a21eac --- /dev/null +++ b/connect/transforms/src/main/java/org/apache/kafka/connect/transforms/predicates/RecordIsTombstone.java @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.transforms.predicates; + +import java.util.Map; + +import org.apache.kafka.common.config.ConfigDef; +import org.apache.kafka.connect.connector.ConnectRecord; + +/** + * A predicate which is true for records which are tombstones (i.e. have null value). + * @param The type of connect record. + */ +public class RecordIsTombstone> implements Predicate { + + public static final String OVERVIEW_DOC = "A predicate which is true for records which are tombstones (i.e. have null value)."; + public static final ConfigDef CONFIG_DEF = new ConfigDef(); + + @Override + public ConfigDef config() { + return CONFIG_DEF; + } + + @Override + public boolean test(R record) { + return record.value() == null; + } + + @Override + public void close() { + + } + + @Override + public void configure(Map configs) { + + } + + @Override + public String toString() { + return "RecordIsTombstone{}"; + } +} diff --git a/connect/transforms/src/main/java/org/apache/kafka/connect/transforms/predicates/TopicNameMatches.java b/connect/transforms/src/main/java/org/apache/kafka/connect/transforms/predicates/TopicNameMatches.java new file mode 100644 index 0000000..3ea8f1a --- /dev/null +++ b/connect/transforms/src/main/java/org/apache/kafka/connect/transforms/predicates/TopicNameMatches.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
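// Illustrative sketch (not part of the original patch): evaluating the two predicates above
// directly against a record. The record contents and the header name "origin" are assumptions
// made for the example.
import java.util.Collections;

import org.apache.kafka.connect.source.SourceRecord;
import org.apache.kafka.connect.transforms.predicates.HasHeaderKey;
import org.apache.kafka.connect.transforms.predicates.RecordIsTombstone;

public class PredicateSketch {
    public static void main(String[] args) {
        // A record with a null value and no headers.
        SourceRecord tombstone = new SourceRecord(null, null, "topic", 0, null, null, null, null);

        RecordIsTombstone<SourceRecord> isTombstone = new RecordIsTombstone<>();
        isTombstone.configure(Collections.emptyMap());
        System.out.println(isTombstone.test(tombstone)); // true: the record value is null

        HasHeaderKey<SourceRecord> hasOrigin = new HasHeaderKey<>();
        hasOrigin.configure(Collections.singletonMap("name", "origin"));
        System.out.println(hasOrigin.test(tombstone)); // false: no header named "origin"
    }
}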
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.transforms.predicates; + +import java.util.Map; +import java.util.regex.Pattern; +import java.util.regex.PatternSyntaxException; + +import org.apache.kafka.common.config.ConfigDef; +import org.apache.kafka.common.config.ConfigException; +import org.apache.kafka.connect.connector.ConnectRecord; +import org.apache.kafka.connect.transforms.util.RegexValidator; +import org.apache.kafka.connect.transforms.util.SimpleConfig; + +/** + * A predicate which is true for records with a topic name that matches the configured regular expression. + * @param The type of connect record. + */ +public class TopicNameMatches> implements Predicate { + + private static final String PATTERN_CONFIG = "pattern"; + + public static final String OVERVIEW_DOC = "A predicate which is true for records with a topic name that matches the configured regular expression."; + + public static final ConfigDef CONFIG_DEF = new ConfigDef() + .define(PATTERN_CONFIG, ConfigDef.Type.STRING, ConfigDef.NO_DEFAULT_VALUE, + ConfigDef.CompositeValidator.of(new ConfigDef.NonEmptyString(), new RegexValidator()), + ConfigDef.Importance.MEDIUM, + "A Java regular expression for matching against the name of a record's topic."); + private Pattern pattern; + + @Override + public ConfigDef config() { + return CONFIG_DEF; + } + + @Override + public boolean test(R record) { + return record.topic() != null && pattern.matcher(record.topic()).matches(); + } + + @Override + public void close() { + + } + + @Override + public void configure(Map configs) { + SimpleConfig simpleConfig = new SimpleConfig(config(), configs); + Pattern result; + String value = simpleConfig.getString(PATTERN_CONFIG); + try { + result = Pattern.compile(value); + } catch (PatternSyntaxException e) { + throw new ConfigException(PATTERN_CONFIG, value, "entry must be a Java-compatible regular expression: " + e.getMessage()); + } + this.pattern = result; + } + + @Override + public String toString() { + return "TopicNameMatches{" + + "pattern=" + pattern + + '}'; + } +} diff --git a/connect/transforms/src/main/java/org/apache/kafka/connect/transforms/util/NonEmptyListValidator.java b/connect/transforms/src/main/java/org/apache/kafka/connect/transforms/util/NonEmptyListValidator.java new file mode 100644 index 0000000..dacb07b --- /dev/null +++ b/connect/transforms/src/main/java/org/apache/kafka/connect/transforms/util/NonEmptyListValidator.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
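// Illustrative sketch (not part of the original patch): TopicNameMatches as defined above, used to
// gate processing by topic name. The pattern "metrics-.*" and the records are assumptions made for
// the example.
import java.util.Collections;

import org.apache.kafka.connect.source.SourceRecord;
import org.apache.kafka.connect.transforms.predicates.TopicNameMatches;

public class TopicNameMatchesSketch {
    public static void main(String[] args) {
        TopicNameMatches<SourceRecord> predicate = new TopicNameMatches<>();
        predicate.configure(Collections.singletonMap("pattern", "metrics-.*"));

        SourceRecord metrics = new SourceRecord(null, null, "metrics-cpu", 0, null, "v");
        SourceRecord logs = new SourceRecord(null, null, "app-logs", 0, null, "v");

        System.out.println(predicate.test(metrics)); // true: "metrics-cpu" matches the whole pattern
        System.out.println(predicate.test(logs));    // false
    }
}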
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.transforms.util; + +import org.apache.kafka.common.config.ConfigDef; +import org.apache.kafka.common.config.ConfigException; + +import java.util.List; + +public class NonEmptyListValidator implements ConfigDef.Validator { + + @Override + public void ensureValid(String name, Object value) { + if (value == null || ((List) value).isEmpty()) { + throw new ConfigException(name, value, "Empty list"); + } + } + + @Override + public String toString() { + return "non-empty list"; + } + +} diff --git a/connect/transforms/src/main/java/org/apache/kafka/connect/transforms/util/RegexValidator.java b/connect/transforms/src/main/java/org/apache/kafka/connect/transforms/util/RegexValidator.java new file mode 100644 index 0000000..d451f00 --- /dev/null +++ b/connect/transforms/src/main/java/org/apache/kafka/connect/transforms/util/RegexValidator.java @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.transforms.util; + +import org.apache.kafka.common.config.ConfigDef; +import org.apache.kafka.common.config.ConfigException; + +import java.util.regex.Pattern; + +public class RegexValidator implements ConfigDef.Validator { + + @Override + public void ensureValid(String name, Object value) { + try { + Pattern.compile((String) value); + } catch (Exception e) { + throw new ConfigException(name, value, "Invalid regex: " + e.getMessage()); + } + } + + @Override + public String toString() { + return "valid regex"; + } + +} diff --git a/connect/transforms/src/main/java/org/apache/kafka/connect/transforms/util/Requirements.java b/connect/transforms/src/main/java/org/apache/kafka/connect/transforms/util/Requirements.java new file mode 100644 index 0000000..6d1cd78 --- /dev/null +++ b/connect/transforms/src/main/java/org/apache/kafka/connect/transforms/util/Requirements.java @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
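// Illustrative sketch (not part of the original patch): wiring the two validators above into a
// ConfigDef. The config names "fields" and "pattern" and their docs are assumptions made for the
// example.
import java.util.HashMap;
import java.util.Map;

import org.apache.kafka.common.config.ConfigDef;
import org.apache.kafka.connect.transforms.util.NonEmptyListValidator;
import org.apache.kafka.connect.transforms.util.RegexValidator;

public class ValidatorSketch {
    public static void main(String[] args) {
        ConfigDef def = new ConfigDef()
                .define("fields", ConfigDef.Type.LIST, ConfigDef.NO_DEFAULT_VALUE,
                        new NonEmptyListValidator(), ConfigDef.Importance.HIGH, "Fields to operate on.")
                .define("pattern", ConfigDef.Type.STRING, ".*",
                        new RegexValidator(), ConfigDef.Importance.MEDIUM, "Topic name regex.");

        Map<String, String> props = new HashMap<>();
        props.put("fields", "a,b,c");      // parsed to a non-empty list, so NonEmptyListValidator passes
        props.put("pattern", "prefix-.*"); // compiles, so RegexValidator passes
        System.out.println(def.parse(props)); // an empty "fields" or a broken regex would throw ConfigException
    }
}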
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.transforms.util; + +import org.apache.kafka.connect.connector.ConnectRecord; +import org.apache.kafka.connect.data.Schema; +import org.apache.kafka.connect.data.Struct; +import org.apache.kafka.connect.errors.DataException; +import org.apache.kafka.connect.sink.SinkRecord; + +import java.util.Map; + +public class Requirements { + + public static void requireSchema(Schema schema, String purpose) { + if (schema == null) { + throw new DataException("Schema required for [" + purpose + "]"); + } + } + + @SuppressWarnings("unchecked") + public static Map requireMap(Object value, String purpose) { + if (!(value instanceof Map)) { + throw new DataException("Only Map objects supported in absence of schema for [" + purpose + "], found: " + nullSafeClassName(value)); + } + return (Map) value; + } + + public static Map requireMapOrNull(Object value, String purpose) { + if (value == null) { + return null; + } + return requireMap(value, purpose); + } + + public static Struct requireStruct(Object value, String purpose) { + if (!(value instanceof Struct)) { + throw new DataException("Only Struct objects supported for [" + purpose + "], found: " + nullSafeClassName(value)); + } + return (Struct) value; + } + + public static Struct requireStructOrNull(Object value, String purpose) { + if (value == null) { + return null; + } + return requireStruct(value, purpose); + } + + public static SinkRecord requireSinkRecord(ConnectRecord record, String purpose) { + if (!(record instanceof SinkRecord)) { + throw new DataException("Only SinkRecord supported for [" + purpose + "], found: " + nullSafeClassName(record)); + } + return (SinkRecord) record; + } + + private static String nullSafeClassName(Object x) { + return x == null ? "null" : x.getClass().getName(); + } + +} diff --git a/connect/transforms/src/main/java/org/apache/kafka/connect/transforms/util/SchemaUtil.java b/connect/transforms/src/main/java/org/apache/kafka/connect/transforms/util/SchemaUtil.java new file mode 100644 index 0000000..c7c3f1e --- /dev/null +++ b/connect/transforms/src/main/java/org/apache/kafka/connect/transforms/util/SchemaUtil.java @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
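// Illustrative sketch (not part of the original patch): the Requirements helpers above give
// transforms a consistent DataException when a record does not have the expected shape. The sample
// value and purpose string are assumptions made for the example.
import java.util.Collections;
import java.util.Map;

import org.apache.kafka.connect.errors.DataException;
import org.apache.kafka.connect.transforms.util.Requirements;

public class RequirementsSketch {
    public static void main(String[] args) {
        Object schemalessValue = Collections.singletonMap("id", 42);

        // Succeeds: the schemaless value really is a Map.
        Map<String, Object> map = Requirements.requireMap(schemalessValue, "field lookup");
        System.out.println(map.get("id")); // 42

        try {
            // Fails: a String is neither a Struct nor null.
            Requirements.requireStructOrNull("not a struct", "field lookup");
        } catch (DataException e) {
            System.out.println(e.getMessage()); // message names the purpose and the offending class
        }
    }
}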
+ */ +package org.apache.kafka.connect.transforms.util; + +import org.apache.kafka.connect.data.Schema; +import org.apache.kafka.connect.data.SchemaBuilder; + +import java.util.Map; + +public class SchemaUtil { + + public static SchemaBuilder copySchemaBasics(Schema source) { + SchemaBuilder builder; + if (source.type() == Schema.Type.ARRAY) { + builder = SchemaBuilder.array(source.valueSchema()); + } else { + builder = new SchemaBuilder(source.type()); + } + return copySchemaBasics(source, builder); + } + + public static SchemaBuilder copySchemaBasics(Schema source, SchemaBuilder builder) { + builder.name(source.name()); + builder.version(source.version()); + builder.doc(source.doc()); + + final Map params = source.parameters(); + if (params != null) { + builder.parameters(params); + } + + return builder; + } + +} diff --git a/connect/transforms/src/main/java/org/apache/kafka/connect/transforms/util/SimpleConfig.java b/connect/transforms/src/main/java/org/apache/kafka/connect/transforms/util/SimpleConfig.java new file mode 100644 index 0000000..7629922 --- /dev/null +++ b/connect/transforms/src/main/java/org/apache/kafka/connect/transforms/util/SimpleConfig.java @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.transforms.util; + +import org.apache.kafka.common.config.AbstractConfig; +import org.apache.kafka.common.config.ConfigDef; + +import java.util.Map; + +/** + * A barebones concrete implementation of {@link AbstractConfig}. + */ +public class SimpleConfig extends AbstractConfig { + + public SimpleConfig(ConfigDef configDef, Map originals) { + super(configDef, originals, false); + } + +} diff --git a/connect/transforms/src/test/java/org/apache/kafka/connect/transforms/CastTest.java b/connect/transforms/src/test/java/org/apache/kafka/connect/transforms/CastTest.java new file mode 100644 index 0000000..60744b2 --- /dev/null +++ b/connect/transforms/src/test/java/org/apache/kafka/connect/transforms/CastTest.java @@ -0,0 +1,576 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
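// Illustrative sketch (not part of the original patch): SchemaUtil.copySchemaBasics above lets a
// transform change a schema's type while carrying over its name, version, doc and parameters. The
// sample schema is an assumption made for the example.
import org.apache.kafka.connect.data.Schema;
import org.apache.kafka.connect.data.SchemaBuilder;
import org.apache.kafka.connect.transforms.util.SchemaUtil;

public class SchemaUtilSketch {
    public static void main(String[] args) {
        Schema source = SchemaBuilder.int32()
                .name("com.example.Count")
                .version(2)
                .doc("A counter")
                .build();

        // Rebuild as INT64 but keep the metadata from the INT32 source schema.
        Schema widened = SchemaUtil.copySchemaBasics(source, SchemaBuilder.int64()).build();
        System.out.println(widened.type() + " " + widened.name() + " v" + widened.version());
    }
}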
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.connect.transforms; + +import java.nio.ByteBuffer; +import java.util.Arrays; +import java.util.List; +import java.util.concurrent.TimeUnit; + +import org.apache.kafka.common.config.ConfigException; +import org.apache.kafka.connect.data.Decimal; +import org.apache.kafka.connect.data.Schema; +import org.apache.kafka.connect.data.Schema.Type; +import org.apache.kafka.connect.data.SchemaBuilder; +import org.apache.kafka.connect.data.Struct; +import org.apache.kafka.connect.data.Time; +import org.apache.kafka.connect.data.Timestamp; +import org.apache.kafka.connect.data.Values; +import org.apache.kafka.connect.errors.DataException; +import org.apache.kafka.connect.source.SourceRecord; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Test; + +import java.math.BigDecimal; +import java.util.Collections; +import java.util.Date; +import java.util.HashMap; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class CastTest { + private final Cast xformKey = new Cast.Key<>(); + private final Cast xformValue = new Cast.Value<>(); + private static final long MILLIS_PER_HOUR = TimeUnit.HOURS.toMillis(1); + private static final long MILLIS_PER_DAY = TimeUnit.DAYS.toMillis(1); + + @AfterEach + public void teardown() { + xformKey.close(); + xformValue.close(); + } + + @Test + public void testConfigEmpty() { + assertThrows(ConfigException.class, () -> xformKey.configure(Collections.singletonMap(Cast.SPEC_CONFIG, ""))); + } + + @Test + public void testConfigInvalidSchemaType() { + assertThrows(ConfigException.class, () -> xformKey.configure(Collections.singletonMap(Cast.SPEC_CONFIG, "foo:faketype"))); + } + + @Test + public void testConfigInvalidTargetType() { + assertThrows(ConfigException.class, () -> xformKey.configure(Collections.singletonMap(Cast.SPEC_CONFIG, "foo:array"))); + } + + @Test + public void testUnsupportedTargetType() { + assertThrows(ConfigException.class, () -> xformKey.configure(Collections.singletonMap(Cast.SPEC_CONFIG, "foo:bytes"))); + } + + @Test + public void testConfigInvalidMap() { + assertThrows(ConfigException.class, () -> xformKey.configure(Collections.singletonMap(Cast.SPEC_CONFIG, "foo:int8:extra"))); + } + + @Test + public void testConfigMixWholeAndFieldTransformation() { + assertThrows(ConfigException.class, () -> xformKey.configure(Collections.singletonMap(Cast.SPEC_CONFIG, "foo:int8,int32"))); + } + + @Test + public void castNullValueRecordWithSchema() { + xformValue.configure(Collections.singletonMap(Cast.SPEC_CONFIG, "foo:int64")); + SourceRecord original = new SourceRecord(null, null, "topic", 0, + Schema.STRING_SCHEMA, "key", Schema.STRING_SCHEMA, null); + SourceRecord transformed = xformValue.apply(original); + assertEquals(original, transformed); + } + + @Test + public void castNullValueRecordSchemaless() { + xformValue.configure(Collections.singletonMap(Cast.SPEC_CONFIG, "foo:int64")); + SourceRecord original = new SourceRecord(null, null, "topic", 0, + Schema.STRING_SCHEMA, "key", null, null); + SourceRecord transformed = xformValue.apply(original); + assertEquals(original, transformed); + } + + @Test + public void castNullKeyRecordWithSchema() { + 
xformKey.configure(Collections.singletonMap(Cast.SPEC_CONFIG, "foo:int64")); + SourceRecord original = new SourceRecord(null, null, "topic", 0, + Schema.STRING_SCHEMA, null, Schema.STRING_SCHEMA, "value"); + SourceRecord transformed = xformKey.apply(original); + assertEquals(original, transformed); + } + + @Test + public void castNullKeyRecordSchemaless() { + xformKey.configure(Collections.singletonMap(Cast.SPEC_CONFIG, "foo:int64")); + SourceRecord original = new SourceRecord(null, null, "topic", 0, + null, null, Schema.STRING_SCHEMA, "value"); + SourceRecord transformed = xformKey.apply(original); + assertEquals(original, transformed); + } + + @Test + public void castWholeRecordKeyWithSchema() { + xformKey.configure(Collections.singletonMap(Cast.SPEC_CONFIG, "int8")); + SourceRecord transformed = xformKey.apply(new SourceRecord(null, null, "topic", 0, + Schema.INT32_SCHEMA, 42, Schema.STRING_SCHEMA, "bogus")); + + assertEquals(Schema.Type.INT8, transformed.keySchema().type()); + assertEquals((byte) 42, transformed.key()); + } + + @Test + public void castWholeRecordValueWithSchemaInt8() { + xformValue.configure(Collections.singletonMap(Cast.SPEC_CONFIG, "int8")); + SourceRecord transformed = xformValue.apply(new SourceRecord(null, null, "topic", 0, + Schema.INT32_SCHEMA, 42)); + + assertEquals(Schema.Type.INT8, transformed.valueSchema().type()); + assertEquals((byte) 42, transformed.value()); + } + + @Test + public void castWholeRecordValueWithSchemaInt16() { + xformValue.configure(Collections.singletonMap(Cast.SPEC_CONFIG, "int16")); + SourceRecord transformed = xformValue.apply(new SourceRecord(null, null, "topic", 0, + Schema.INT32_SCHEMA, 42)); + + assertEquals(Schema.Type.INT16, transformed.valueSchema().type()); + assertEquals((short) 42, transformed.value()); + } + + @Test + public void castWholeRecordValueWithSchemaInt32() { + xformValue.configure(Collections.singletonMap(Cast.SPEC_CONFIG, "int32")); + SourceRecord transformed = xformValue.apply(new SourceRecord(null, null, "topic", 0, + Schema.INT32_SCHEMA, 42)); + + assertEquals(Schema.Type.INT32, transformed.valueSchema().type()); + assertEquals(42, transformed.value()); + } + + @Test + public void castWholeRecordValueWithSchemaInt64() { + xformValue.configure(Collections.singletonMap(Cast.SPEC_CONFIG, "int64")); + SourceRecord transformed = xformValue.apply(new SourceRecord(null, null, "topic", 0, + Schema.INT32_SCHEMA, 42)); + + assertEquals(Schema.Type.INT64, transformed.valueSchema().type()); + assertEquals((long) 42, transformed.value()); + } + + @Test + public void castWholeRecordValueWithSchemaFloat32() { + xformValue.configure(Collections.singletonMap(Cast.SPEC_CONFIG, "float32")); + SourceRecord transformed = xformValue.apply(new SourceRecord(null, null, "topic", 0, + Schema.INT32_SCHEMA, 42)); + + assertEquals(Schema.Type.FLOAT32, transformed.valueSchema().type()); + assertEquals(42.f, transformed.value()); + } + + @Test + public void castWholeRecordValueWithSchemaFloat64() { + xformValue.configure(Collections.singletonMap(Cast.SPEC_CONFIG, "float64")); + SourceRecord transformed = xformValue.apply(new SourceRecord(null, null, "topic", 0, + Schema.INT32_SCHEMA, 42)); + + assertEquals(Schema.Type.FLOAT64, transformed.valueSchema().type()); + assertEquals(42., transformed.value()); + } + + @Test + public void castWholeRecordValueWithSchemaBooleanTrue() { + xformValue.configure(Collections.singletonMap(Cast.SPEC_CONFIG, "boolean")); + SourceRecord transformed = xformValue.apply(new SourceRecord(null, null, "topic", 0, + 
Schema.INT32_SCHEMA, 42)); + + assertEquals(Schema.Type.BOOLEAN, transformed.valueSchema().type()); + assertEquals(true, transformed.value()); + } + + @Test + public void castWholeRecordValueWithSchemaBooleanFalse() { + xformValue.configure(Collections.singletonMap(Cast.SPEC_CONFIG, "boolean")); + SourceRecord transformed = xformValue.apply(new SourceRecord(null, null, "topic", 0, + Schema.INT32_SCHEMA, 0)); + + assertEquals(Schema.Type.BOOLEAN, transformed.valueSchema().type()); + assertEquals(false, transformed.value()); + } + + @Test + public void castWholeRecordValueWithSchemaString() { + xformValue.configure(Collections.singletonMap(Cast.SPEC_CONFIG, "string")); + SourceRecord transformed = xformValue.apply(new SourceRecord(null, null, "topic", 0, + Schema.INT32_SCHEMA, 42)); + + assertEquals(Schema.Type.STRING, transformed.valueSchema().type()); + assertEquals("42", transformed.value()); + } + + @Test + public void castWholeBigDecimalRecordValueWithSchemaString() { + BigDecimal bigDecimal = new BigDecimal(42); + xformValue.configure(Collections.singletonMap(Cast.SPEC_CONFIG, "string")); + SourceRecord transformed = xformValue.apply(new SourceRecord(null, null, "topic", 0, + Decimal.schema(bigDecimal.scale()), bigDecimal)); + + assertEquals(Schema.Type.STRING, transformed.valueSchema().type()); + assertEquals("42", transformed.value()); + } + + @Test + public void castWholeDateRecordValueWithSchemaString() { + Date timestamp = new Date(MILLIS_PER_DAY + 1); // day + 1msec to get a timestamp formatting. + xformValue.configure(Collections.singletonMap(Cast.SPEC_CONFIG, "string")); + SourceRecord transformed = xformValue.apply(new SourceRecord(null, null, "topic", 0, + Timestamp.SCHEMA, timestamp)); + + assertEquals(Schema.Type.STRING, transformed.valueSchema().type()); + assertEquals(Values.dateFormatFor(timestamp).format(timestamp), transformed.value()); + } + + @Test + public void castWholeRecordDefaultValue() { + // Validate default value in schema is correctly converted + xformValue.configure(Collections.singletonMap(Cast.SPEC_CONFIG, "int32")); + SourceRecord transformed = xformValue.apply(new SourceRecord(null, null, "topic", 0, + SchemaBuilder.float32().defaultValue(-42.125f).build(), 42.125f)); + + assertEquals(Schema.Type.INT32, transformed.valueSchema().type()); + assertEquals(42, transformed.value()); + assertEquals(-42, transformed.valueSchema().defaultValue()); + } + + @Test + public void castWholeRecordKeySchemaless() { + xformKey.configure(Collections.singletonMap(Cast.SPEC_CONFIG, "int8")); + SourceRecord transformed = xformKey.apply(new SourceRecord(null, null, "topic", 0, + null, 42, Schema.STRING_SCHEMA, "bogus")); + + assertNull(transformed.keySchema()); + assertEquals((byte) 42, transformed.key()); + } + + @Test + public void castWholeRecordValueSchemalessInt8() { + xformValue.configure(Collections.singletonMap(Cast.SPEC_CONFIG, "int8")); + SourceRecord transformed = xformValue.apply(new SourceRecord(null, null, "topic", 0, + null, 42)); + + assertNull(transformed.valueSchema()); + assertEquals((byte) 42, transformed.value()); + } + + @Test + public void castWholeRecordValueSchemalessInt16() { + xformValue.configure(Collections.singletonMap(Cast.SPEC_CONFIG, "int16")); + SourceRecord transformed = xformValue.apply(new SourceRecord(null, null, "topic", 0, + null, 42)); + + assertNull(transformed.valueSchema()); + assertEquals((short) 42, transformed.value()); + } + + @Test + public void castWholeRecordValueSchemalessInt32() { + 
xformValue.configure(Collections.singletonMap(Cast.SPEC_CONFIG, "int32")); + SourceRecord transformed = xformValue.apply(new SourceRecord(null, null, "topic", 0, + null, 42)); + + assertNull(transformed.valueSchema()); + assertEquals(42, transformed.value()); + } + + @Test + public void castWholeRecordValueSchemalessInt64() { + xformValue.configure(Collections.singletonMap(Cast.SPEC_CONFIG, "int64")); + SourceRecord transformed = xformValue.apply(new SourceRecord(null, null, "topic", 0, + null, 42)); + + assertNull(transformed.valueSchema()); + assertEquals((long) 42, transformed.value()); + } + + @Test + public void castWholeRecordValueSchemalessFloat32() { + xformValue.configure(Collections.singletonMap(Cast.SPEC_CONFIG, "float32")); + SourceRecord transformed = xformValue.apply(new SourceRecord(null, null, "topic", 0, + null, 42)); + + assertNull(transformed.valueSchema()); + assertEquals(42.f, transformed.value()); + } + + @Test + public void castWholeRecordValueSchemalessFloat64() { + xformValue.configure(Collections.singletonMap(Cast.SPEC_CONFIG, "float64")); + SourceRecord transformed = xformValue.apply(new SourceRecord(null, null, "topic", 0, + null, 42)); + + assertNull(transformed.valueSchema()); + assertEquals(42., transformed.value()); + } + + @Test + public void castWholeRecordValueSchemalessBooleanTrue() { + xformValue.configure(Collections.singletonMap(Cast.SPEC_CONFIG, "boolean")); + SourceRecord transformed = xformValue.apply(new SourceRecord(null, null, "topic", 0, + null, 42)); + + assertNull(transformed.valueSchema()); + assertEquals(true, transformed.value()); + } + + @Test + public void castWholeRecordValueSchemalessBooleanFalse() { + xformValue.configure(Collections.singletonMap(Cast.SPEC_CONFIG, "boolean")); + SourceRecord transformed = xformValue.apply(new SourceRecord(null, null, "topic", 0, + null, 0)); + + assertNull(transformed.valueSchema()); + assertEquals(false, transformed.value()); + } + + @Test + public void castWholeRecordValueSchemalessString() { + xformValue.configure(Collections.singletonMap(Cast.SPEC_CONFIG, "string")); + SourceRecord transformed = xformValue.apply(new SourceRecord(null, null, "topic", 0, + null, 42)); + + assertNull(transformed.valueSchema()); + assertEquals("42", transformed.value()); + } + + @Test + public void castWholeRecordValueSchemalessUnsupportedType() { + xformValue.configure(Collections.singletonMap(Cast.SPEC_CONFIG, "int8")); + assertThrows(DataException.class, + () -> xformValue.apply(new SourceRecord(null, null, "topic", 0, + null, Collections.singletonList("foo")))); + } + + @Test + public void castLogicalToPrimitive() { + List specParts = Arrays.asList( + "date_to_int32:int32", // Cast to underlying representation + "timestamp_to_int64:int64", // Cast to underlying representation + "time_to_int64:int64", // Cast to wider datatype than underlying representation + "decimal_to_int32:int32", // Cast to narrower datatype with data loss + "timestamp_to_float64:float64", // loss of precision casting to double + "null_timestamp_to_int32:int32" + ); + + Date day = new Date(MILLIS_PER_DAY); + xformValue.configure(Collections.singletonMap(Cast.SPEC_CONFIG, + String.join(",", specParts))); + + SchemaBuilder builder = SchemaBuilder.struct(); + builder.field("date_to_int32", org.apache.kafka.connect.data.Date.SCHEMA); + builder.field("timestamp_to_int64", Timestamp.SCHEMA); + builder.field("time_to_int64", Time.SCHEMA); + builder.field("decimal_to_int32", Decimal.schema(new BigDecimal((long) Integer.MAX_VALUE + 1).scale())); + 
builder.field("timestamp_to_float64", Timestamp.SCHEMA); + builder.field("null_timestamp_to_int32", Timestamp.builder().optional().build()); + + Schema supportedTypesSchema = builder.build(); + + Struct recordValue = new Struct(supportedTypesSchema); + recordValue.put("date_to_int32", day); + recordValue.put("timestamp_to_int64", new Date(0)); + recordValue.put("time_to_int64", new Date(1)); + recordValue.put("decimal_to_int32", new BigDecimal((long) Integer.MAX_VALUE + 1)); + recordValue.put("timestamp_to_float64", new Date(Long.MAX_VALUE)); + recordValue.put("null_timestamp_to_int32", null); + + SourceRecord transformed = xformValue.apply( + new SourceRecord(null, null, "topic", 0, + supportedTypesSchema, recordValue)); + + assertEquals(1, ((Struct) transformed.value()).get("date_to_int32")); + assertEquals(0L, ((Struct) transformed.value()).get("timestamp_to_int64")); + assertEquals(1L, ((Struct) transformed.value()).get("time_to_int64")); + assertEquals(Integer.MIN_VALUE, ((Struct) transformed.value()).get("decimal_to_int32")); + assertEquals(9.223372036854776E18, ((Struct) transformed.value()).get("timestamp_to_float64")); + assertNull(((Struct) transformed.value()).get("null_timestamp_to_int32")); + + Schema transformedSchema = ((Struct) transformed.value()).schema(); + assertEquals(Type.INT32, transformedSchema.field("date_to_int32").schema().type()); + assertEquals(Type.INT64, transformedSchema.field("timestamp_to_int64").schema().type()); + assertEquals(Type.INT64, transformedSchema.field("time_to_int64").schema().type()); + assertEquals(Type.INT32, transformedSchema.field("decimal_to_int32").schema().type()); + assertEquals(Type.FLOAT64, transformedSchema.field("timestamp_to_float64").schema().type()); + assertEquals(Type.INT32, transformedSchema.field("null_timestamp_to_int32").schema().type()); + } + + @Test + public void castLogicalToString() { + Date date = new Date(MILLIS_PER_DAY); + Date time = new Date(MILLIS_PER_HOUR); + Date timestamp = new Date(); + + xformValue.configure(Collections.singletonMap(Cast.SPEC_CONFIG, + "date:string,decimal:string,time:string,timestamp:string")); + + SchemaBuilder builder = SchemaBuilder.struct(); + builder.field("date", org.apache.kafka.connect.data.Date.SCHEMA); + builder.field("decimal", Decimal.schema(new BigDecimal(1982).scale())); + builder.field("time", Time.SCHEMA); + builder.field("timestamp", Timestamp.SCHEMA); + + Schema supportedTypesSchema = builder.build(); + + Struct recordValue = new Struct(supportedTypesSchema); + recordValue.put("date", date); + recordValue.put("decimal", new BigDecimal(1982)); + recordValue.put("time", time); + recordValue.put("timestamp", timestamp); + + SourceRecord transformed = xformValue.apply( + new SourceRecord(null, null, "topic", 0, + supportedTypesSchema, recordValue)); + + assertEquals(Values.dateFormatFor(date).format(date), ((Struct) transformed.value()).get("date")); + assertEquals("1982", ((Struct) transformed.value()).get("decimal")); + assertEquals(Values.dateFormatFor(time).format(time), ((Struct) transformed.value()).get("time")); + assertEquals(Values.dateFormatFor(timestamp).format(timestamp), ((Struct) transformed.value()).get("timestamp")); + + Schema transformedSchema = ((Struct) transformed.value()).schema(); + assertEquals(Type.STRING, transformedSchema.field("date").schema().type()); + assertEquals(Type.STRING, transformedSchema.field("decimal").schema().type()); + assertEquals(Type.STRING, transformedSchema.field("time").schema().type()); + assertEquals(Type.STRING, 
transformedSchema.field("timestamp").schema().type()); + } + + @Test + public void castFieldsWithSchema() { + Date day = new Date(MILLIS_PER_DAY); + byte[] byteArray = new byte[] {(byte) 0xFE, (byte) 0xDC, (byte) 0xBA, (byte) 0x98, 0x76, 0x54, 0x32, 0x10}; + ByteBuffer byteBuffer = ByteBuffer.wrap(Arrays.copyOf(byteArray, byteArray.length)); + + xformValue.configure(Collections.singletonMap(Cast.SPEC_CONFIG, + "int8:int16,int16:int32,int32:int64,int64:boolean,float32:float64,float64:boolean,boolean:int8,string:int32,bigdecimal:string,date:string,optional:int32,bytes:string,byteArray:string")); + + // Include an optional fields and fields with defaults to validate their values are passed through properly + SchemaBuilder builder = SchemaBuilder.struct(); + builder.field("int8", Schema.INT8_SCHEMA); + builder.field("int16", Schema.OPTIONAL_INT16_SCHEMA); + builder.field("int32", SchemaBuilder.int32().defaultValue(2).build()); + builder.field("int64", Schema.INT64_SCHEMA); + builder.field("float32", Schema.FLOAT32_SCHEMA); + // Default value here ensures we correctly convert default values + builder.field("float64", SchemaBuilder.float64().defaultValue(-1.125).build()); + builder.field("boolean", Schema.BOOLEAN_SCHEMA); + builder.field("string", Schema.STRING_SCHEMA); + builder.field("bigdecimal", Decimal.schema(new BigDecimal(42).scale())); + builder.field("date", org.apache.kafka.connect.data.Date.SCHEMA); + builder.field("optional", Schema.OPTIONAL_FLOAT32_SCHEMA); + builder.field("timestamp", Timestamp.SCHEMA); + builder.field("bytes", Schema.BYTES_SCHEMA); + builder.field("byteArray", Schema.BYTES_SCHEMA); + + Schema supportedTypesSchema = builder.build(); + + Struct recordValue = new Struct(supportedTypesSchema); + recordValue.put("int8", (byte) 8); + recordValue.put("int16", (short) 16); + recordValue.put("int32", 32); + recordValue.put("int64", (long) 64); + recordValue.put("float32", 32.f); + recordValue.put("float64", -64.); + recordValue.put("boolean", true); + recordValue.put("bigdecimal", new BigDecimal(42)); + recordValue.put("date", day); + recordValue.put("string", "42"); + recordValue.put("timestamp", new Date(0)); + recordValue.put("bytes", byteBuffer); + recordValue.put("byteArray", byteArray); + + // optional field intentionally omitted + + SourceRecord transformed = xformValue.apply(new SourceRecord(null, null, "topic", 0, + supportedTypesSchema, recordValue)); + + assertEquals((short) 8, ((Struct) transformed.value()).get("int8")); + assertTrue(((Struct) transformed.value()).schema().field("int16").schema().isOptional()); + assertEquals(16, ((Struct) transformed.value()).get("int16")); + assertEquals((long) 32, ((Struct) transformed.value()).get("int32")); + assertEquals(2L, ((Struct) transformed.value()).schema().field("int32").schema().defaultValue()); + assertEquals(true, ((Struct) transformed.value()).get("int64")); + assertEquals(32., ((Struct) transformed.value()).get("float32")); + assertEquals(true, ((Struct) transformed.value()).get("float64")); + assertEquals(true, ((Struct) transformed.value()).schema().field("float64").schema().defaultValue()); + assertEquals((byte) 1, ((Struct) transformed.value()).get("boolean")); + assertEquals(42, ((Struct) transformed.value()).get("string")); + assertEquals("42", ((Struct) transformed.value()).get("bigdecimal")); + assertEquals(Values.dateFormatFor(day).format(day), ((Struct) transformed.value()).get("date")); + assertEquals(new Date(0), ((Struct) transformed.value()).get("timestamp")); + assertEquals("/ty6mHZUMhA=", 
((Struct) transformed.value()).get("bytes")); + assertEquals("/ty6mHZUMhA=", ((Struct) transformed.value()).get("byteArray")); + + assertNull(((Struct) transformed.value()).get("optional")); + + Schema transformedSchema = ((Struct) transformed.value()).schema(); + assertEquals(Schema.INT16_SCHEMA.type(), transformedSchema.field("int8").schema().type()); + assertEquals(Schema.OPTIONAL_INT32_SCHEMA.type(), transformedSchema.field("int16").schema().type()); + assertEquals(Schema.INT64_SCHEMA.type(), transformedSchema.field("int32").schema().type()); + assertEquals(Schema.BOOLEAN_SCHEMA.type(), transformedSchema.field("int64").schema().type()); + assertEquals(Schema.FLOAT64_SCHEMA.type(), transformedSchema.field("float32").schema().type()); + assertEquals(Schema.BOOLEAN_SCHEMA.type(), transformedSchema.field("float64").schema().type()); + assertEquals(Schema.INT8_SCHEMA.type(), transformedSchema.field("boolean").schema().type()); + assertEquals(Schema.INT32_SCHEMA.type(), transformedSchema.field("string").schema().type()); + assertEquals(Schema.STRING_SCHEMA.type(), transformedSchema.field("bigdecimal").schema().type()); + assertEquals(Schema.STRING_SCHEMA.type(), transformedSchema.field("date").schema().type()); + assertEquals(Schema.OPTIONAL_INT32_SCHEMA.type(), transformedSchema.field("optional").schema().type()); + assertEquals(Schema.STRING_SCHEMA.type(), transformedSchema.field("bytes").schema().type()); + assertEquals(Schema.STRING_SCHEMA.type(), transformedSchema.field("byteArray").schema().type()); + + // The following fields are not changed + assertEquals(Timestamp.SCHEMA.type(), transformedSchema.field("timestamp").schema().type()); + } + + @SuppressWarnings("unchecked") + @Test + public void castFieldsSchemaless() { + xformValue.configure(Collections.singletonMap(Cast.SPEC_CONFIG, "int8:int16,int16:int32,int32:int64,int64:boolean,float32:float64,float64:boolean,boolean:int8,string:int32")); + Map recordValue = new HashMap<>(); + recordValue.put("int8", (byte) 8); + recordValue.put("int16", (short) 16); + recordValue.put("int32", 32); + recordValue.put("int64", (long) 64); + recordValue.put("float32", 32.f); + recordValue.put("float64", -64.); + recordValue.put("boolean", true); + recordValue.put("string", "42"); + SourceRecord transformed = xformValue.apply(new SourceRecord(null, null, "topic", 0, + null, recordValue)); + + assertNull(transformed.valueSchema()); + assertEquals((short) 8, ((Map) transformed.value()).get("int8")); + assertEquals(16, ((Map) transformed.value()).get("int16")); + assertEquals((long) 32, ((Map) transformed.value()).get("int32")); + assertEquals(true, ((Map) transformed.value()).get("int64")); + assertEquals(32., ((Map) transformed.value()).get("float32")); + assertEquals(true, ((Map) transformed.value()).get("float64")); + assertEquals((byte) 1, ((Map) transformed.value()).get("boolean")); + assertEquals(42, ((Map) transformed.value()).get("string")); + } + +} diff --git a/connect/transforms/src/test/java/org/apache/kafka/connect/transforms/DropHeadersTest.java b/connect/transforms/src/test/java/org/apache/kafka/connect/transforms/DropHeadersTest.java new file mode 100644 index 0000000..7d20c38 --- /dev/null +++ b/connect/transforms/src/test/java/org/apache/kafka/connect/transforms/DropHeadersTest.java @@ -0,0 +1,117 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
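// Illustrative sketch (not part of the original patch): the Cast transform exercised by CastTest
// above, applied to a schemaless value. The spec string and record contents are assumptions made
// for the example.
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;

import org.apache.kafka.connect.source.SourceRecord;
import org.apache.kafka.connect.transforms.Cast;

public class CastSketch {
    public static void main(String[] args) {
        try (Cast<SourceRecord> cast = new Cast.Value<>()) {
            cast.configure(Collections.singletonMap(Cast.SPEC_CONFIG, "count:int64,flag:boolean"));

            Map<String, Object> value = new HashMap<>();
            value.put("count", 7);     // Integer, cast to Long
            value.put("flag", 1);      // non-zero number, cast to Boolean true
            value.put("note", "kept"); // not named in the spec, passed through unchanged

            SourceRecord record = new SourceRecord(null, null, "topic", 0, null, value);
            Map<?, ?> casted = (Map<?, ?>) cast.apply(record).value();
            System.out.println(casted.get("count").getClass() + " " + casted.get("flag"));
        }
    }
}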
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.transforms; + +import org.apache.kafka.common.config.ConfigException; +import org.apache.kafka.connect.data.Schema; +import org.apache.kafka.connect.header.ConnectHeaders; +import org.apache.kafka.connect.source.SourceRecord; +import org.junit.jupiter.api.Test; + +import java.util.HashMap; +import java.util.Map; + +import static java.util.Arrays.asList; +import static java.util.Collections.singletonMap; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +public class DropHeadersTest { + + private DropHeaders xform = new DropHeaders<>(); + + private Map config(String... headers) { + Map result = new HashMap<>(); + result.put(DropHeaders.HEADERS_FIELD, asList(headers)); + return result; + } + + @Test + public void dropExistingHeader() { + xform.configure(config("to-drop")); + ConnectHeaders expected = new ConnectHeaders(); + expected.addString("existing", "existing-value"); + ConnectHeaders headers = expected.duplicate(); + headers.addString("to-drop", "existing-value"); + SourceRecord original = sourceRecord(headers); + SourceRecord xformed = xform.apply(original); + assertNonHeaders(original, xformed); + assertEquals(expected, xformed.headers()); + } + + @Test + public void dropExistingHeaderWithMultipleValues() { + xform.configure(config("to-drop")); + ConnectHeaders expected = new ConnectHeaders(); + expected.addString("existing", "existing-value"); + ConnectHeaders headers = expected.duplicate(); + headers.addString("to-drop", "existing-value"); + headers.addString("to-drop", "existing-other-value"); + + SourceRecord original = sourceRecord(headers); + SourceRecord xformed = xform.apply(original); + assertNonHeaders(original, xformed); + assertEquals(expected, xformed.headers()); + } + + @Test + public void dropNonExistingHeader() { + xform.configure(config("to-drop")); + ConnectHeaders expected = new ConnectHeaders(); + expected.addString("existing", "existing-value"); + ConnectHeaders headers = expected.duplicate(); + + SourceRecord original = sourceRecord(headers); + SourceRecord xformed = xform.apply(original); + assertNonHeaders(original, xformed); + assertEquals(expected, xformed.headers()); + } + + @Test + public void configRejectsEmptyList() { + assertThrows(ConfigException.class, () -> xform.configure(config())); + } + + private void assertNonHeaders(SourceRecord original, SourceRecord xformed) { + assertEquals(original.sourcePartition(), xformed.sourcePartition()); + assertEquals(original.sourceOffset(), xformed.sourceOffset()); + assertEquals(original.topic(), xformed.topic()); + assertEquals(original.kafkaPartition(), xformed.kafkaPartition()); + assertEquals(original.keySchema(), xformed.keySchema()); + assertEquals(original.key(), xformed.key()); + assertEquals(original.valueSchema(), xformed.valueSchema()); + assertEquals(original.value(), xformed.value()); + 
assertEquals(original.timestamp(), xformed.timestamp()); + } + + private SourceRecord sourceRecord(ConnectHeaders headers) { + Map sourcePartition = singletonMap("foo", "bar"); + Map sourceOffset = singletonMap("baz", "quxx"); + String topic = "topic"; + Integer partition = 0; + Schema keySchema = null; + Object key = "key"; + Schema valueSchema = null; + Object value = "value"; + Long timestamp = 0L; + + SourceRecord record = new SourceRecord(sourcePartition, sourceOffset, topic, partition, + keySchema, key, valueSchema, value, timestamp, headers); + return record; + } +} + diff --git a/connect/transforms/src/test/java/org/apache/kafka/connect/transforms/ExtractFieldTest.java b/connect/transforms/src/test/java/org/apache/kafka/connect/transforms/ExtractFieldTest.java new file mode 100644 index 0000000..ce776f9 --- /dev/null +++ b/connect/transforms/src/test/java/org/apache/kafka/connect/transforms/ExtractFieldTest.java @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.transforms; + +import org.apache.kafka.connect.data.Schema; +import org.apache.kafka.connect.data.SchemaBuilder; +import org.apache.kafka.connect.data.Struct; +import org.apache.kafka.connect.sink.SinkRecord; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Test; + +import java.util.Collections; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.fail; + +public class ExtractFieldTest { + private final ExtractField xform = new ExtractField.Key<>(); + + @AfterEach + public void teardown() { + xform.close(); + } + + @Test + public void schemaless() { + xform.configure(Collections.singletonMap("field", "magic")); + + final SinkRecord record = new SinkRecord("test", 0, null, Collections.singletonMap("magic", 42), null, null, 0); + final SinkRecord transformedRecord = xform.apply(record); + + assertNull(transformedRecord.keySchema()); + assertEquals(42, transformedRecord.key()); + } + + @Test + public void testNullSchemaless() { + xform.configure(Collections.singletonMap("field", "magic")); + + final Map key = null; + final SinkRecord record = new SinkRecord("test", 0, null, key, null, null, 0); + final SinkRecord transformedRecord = xform.apply(record); + + assertNull(transformedRecord.keySchema()); + assertNull(transformedRecord.key()); + } + + @Test + public void withSchema() { + xform.configure(Collections.singletonMap("field", "magic")); + + final Schema keySchema = SchemaBuilder.struct().field("magic", Schema.INT32_SCHEMA).build(); + final Struct key = new Struct(keySchema).put("magic", 42); + final SinkRecord record = new SinkRecord("test", 0, keySchema, key, 
null, null, 0); + final SinkRecord transformedRecord = xform.apply(record); + + assertEquals(Schema.INT32_SCHEMA, transformedRecord.keySchema()); + assertEquals(42, transformedRecord.key()); + } + + @Test + public void testNullWithSchema() { + xform.configure(Collections.singletonMap("field", "magic")); + + final Schema keySchema = SchemaBuilder.struct().field("magic", Schema.INT32_SCHEMA).optional().build(); + final Struct key = null; + final SinkRecord record = new SinkRecord("test", 0, keySchema, key, null, null, 0); + final SinkRecord transformedRecord = xform.apply(record); + + assertEquals(Schema.INT32_SCHEMA, transformedRecord.keySchema()); + assertNull(transformedRecord.key()); + } + + @Test + public void nonExistentFieldSchemalessShouldReturnNull() { + xform.configure(Collections.singletonMap("field", "nonexistent")); + + final SinkRecord record = new SinkRecord("test", 0, null, Collections.singletonMap("magic", 42), null, null, 0); + final SinkRecord transformedRecord = xform.apply(record); + + assertNull(transformedRecord.keySchema()); + assertNull(transformedRecord.key()); + } + + @Test + public void nonExistentFieldWithSchemaShouldFail() { + xform.configure(Collections.singletonMap("field", "nonexistent")); + + final Schema keySchema = SchemaBuilder.struct().field("magic", Schema.INT32_SCHEMA).build(); + final Struct key = new Struct(keySchema).put("magic", 42); + final SinkRecord record = new SinkRecord("test", 0, keySchema, key, null, null, 0); + + try { + xform.apply(record); + fail("Expected exception wasn't raised"); + } catch (IllegalArgumentException iae) { + assertEquals("Unknown field: nonexistent", iae.getMessage()); + } + } +} diff --git a/connect/transforms/src/test/java/org/apache/kafka/connect/transforms/FlattenTest.java b/connect/transforms/src/test/java/org/apache/kafka/connect/transforms/FlattenTest.java new file mode 100644 index 0000000..90d1724 --- /dev/null +++ b/connect/transforms/src/test/java/org/apache/kafka/connect/transforms/FlattenTest.java @@ -0,0 +1,391 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.connect.transforms; + +import org.apache.kafka.connect.data.Schema; +import org.apache.kafka.connect.data.SchemaBuilder; +import org.apache.kafka.connect.data.Struct; +import org.apache.kafka.connect.errors.DataException; +import org.apache.kafka.connect.source.SourceRecord; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Test; + +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class FlattenTest { + private final Flatten xformKey = new Flatten.Key<>(); + private final Flatten xformValue = new Flatten.Value<>(); + + @AfterEach + public void teardown() { + xformKey.close(); + xformValue.close(); + } + + @Test + public void topLevelStructRequired() { + xformValue.configure(Collections.emptyMap()); + assertThrows(DataException.class, () -> xformValue.apply(new SourceRecord(null, null, + "topic", 0, Schema.INT32_SCHEMA, 42))); + } + + @Test + public void topLevelMapRequired() { + xformValue.configure(Collections.emptyMap()); + assertThrows(DataException.class, () -> xformValue.apply(new SourceRecord(null, null, + "topic", 0, null, 42))); + } + + @Test + public void testNestedStruct() { + xformValue.configure(Collections.emptyMap()); + + SchemaBuilder builder = SchemaBuilder.struct(); + builder.field("int8", Schema.INT8_SCHEMA); + builder.field("int16", Schema.INT16_SCHEMA); + builder.field("int32", Schema.INT32_SCHEMA); + builder.field("int64", Schema.INT64_SCHEMA); + builder.field("float32", Schema.FLOAT32_SCHEMA); + builder.field("float64", Schema.FLOAT64_SCHEMA); + builder.field("boolean", Schema.BOOLEAN_SCHEMA); + builder.field("string", Schema.STRING_SCHEMA); + builder.field("bytes", Schema.BYTES_SCHEMA); + Schema supportedTypesSchema = builder.build(); + + builder = SchemaBuilder.struct(); + builder.field("B", supportedTypesSchema); + Schema oneLevelNestedSchema = builder.build(); + + builder = SchemaBuilder.struct(); + builder.field("A", oneLevelNestedSchema); + Schema twoLevelNestedSchema = builder.build(); + + Struct supportedTypes = new Struct(supportedTypesSchema); + supportedTypes.put("int8", (byte) 8); + supportedTypes.put("int16", (short) 16); + supportedTypes.put("int32", 32); + supportedTypes.put("int64", (long) 64); + supportedTypes.put("float32", 32.f); + supportedTypes.put("float64", 64.); + supportedTypes.put("boolean", true); + supportedTypes.put("string", "stringy"); + supportedTypes.put("bytes", "bytes".getBytes()); + + Struct oneLevelNestedStruct = new Struct(oneLevelNestedSchema); + oneLevelNestedStruct.put("B", supportedTypes); + + Struct twoLevelNestedStruct = new Struct(twoLevelNestedSchema); + twoLevelNestedStruct.put("A", oneLevelNestedStruct); + + SourceRecord transformed = xformValue.apply(new SourceRecord(null, null, + "topic", 0, + twoLevelNestedSchema, twoLevelNestedStruct)); + + assertEquals(Schema.Type.STRUCT, transformed.valueSchema().type()); + Struct transformedStruct = (Struct) transformed.value(); + assertEquals(9, transformedStruct.schema().fields().size()); + assertEquals(8, (byte) transformedStruct.getInt8("A.B.int8")); + 
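// every nested primitive is reachable under its dot-delimited flattened name, e.g. "A.B.int16" +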
assertEquals(16, (short) transformedStruct.getInt16("A.B.int16")); + assertEquals(32, (int) transformedStruct.getInt32("A.B.int32")); + assertEquals(64L, (long) transformedStruct.getInt64("A.B.int64")); + assertEquals(32.f, transformedStruct.getFloat32("A.B.float32"), 0.f); + assertEquals(64., transformedStruct.getFloat64("A.B.float64"), 0.); + assertEquals(true, transformedStruct.getBoolean("A.B.boolean")); + assertEquals("stringy", transformedStruct.getString("A.B.string")); + assertArrayEquals("bytes".getBytes(), transformedStruct.getBytes("A.B.bytes")); + } + + @Test + public void testNestedMapWithDelimiter() { + xformValue.configure(Collections.singletonMap("delimiter", "#")); + + Map supportedTypes = new HashMap<>(); + supportedTypes.put("int8", (byte) 8); + supportedTypes.put("int16", (short) 16); + supportedTypes.put("int32", 32); + supportedTypes.put("int64", (long) 64); + supportedTypes.put("float32", 32.f); + supportedTypes.put("float64", 64.); + supportedTypes.put("boolean", true); + supportedTypes.put("string", "stringy"); + supportedTypes.put("bytes", "bytes".getBytes()); + + Map oneLevelNestedMap = Collections.singletonMap("B", supportedTypes); + Map twoLevelNestedMap = Collections.singletonMap("A", oneLevelNestedMap); + + SourceRecord transformed = xformValue.apply(new SourceRecord(null, null, + "topic", 0, + null, twoLevelNestedMap)); + + assertNull(transformed.valueSchema()); + assertTrue(transformed.value() instanceof Map); + @SuppressWarnings("unchecked") + Map transformedMap = (Map) transformed.value(); + assertEquals(9, transformedMap.size()); + assertEquals((byte) 8, transformedMap.get("A#B#int8")); + assertEquals((short) 16, transformedMap.get("A#B#int16")); + assertEquals(32, transformedMap.get("A#B#int32")); + assertEquals((long) 64, transformedMap.get("A#B#int64")); + assertEquals(32.f, (float) transformedMap.get("A#B#float32"), 0.f); + assertEquals(64., (double) transformedMap.get("A#B#float64"), 0.); + assertEquals(true, transformedMap.get("A#B#boolean")); + assertEquals("stringy", transformedMap.get("A#B#string")); + assertArrayEquals("bytes".getBytes(), (byte[]) transformedMap.get("A#B#bytes")); + } + + @Test + public void testOptionalFieldStruct() { + xformValue.configure(Collections.emptyMap()); + + SchemaBuilder builder = SchemaBuilder.struct(); + builder.field("opt_int32", Schema.OPTIONAL_INT32_SCHEMA); + Schema supportedTypesSchema = builder.build(); + + builder = SchemaBuilder.struct(); + builder.field("B", supportedTypesSchema); + Schema oneLevelNestedSchema = builder.build(); + + Struct supportedTypes = new Struct(supportedTypesSchema); + supportedTypes.put("opt_int32", null); + + Struct oneLevelNestedStruct = new Struct(oneLevelNestedSchema); + oneLevelNestedStruct.put("B", supportedTypes); + + SourceRecord transformed = xformValue.apply(new SourceRecord(null, null, + "topic", 0, + oneLevelNestedSchema, oneLevelNestedStruct)); + + assertEquals(Schema.Type.STRUCT, transformed.valueSchema().type()); + Struct transformedStruct = (Struct) transformed.value(); + assertNull(transformedStruct.get("B.opt_int32")); + } + + @Test + public void testOptionalStruct() { + xformValue.configure(Collections.emptyMap()); + + SchemaBuilder builder = SchemaBuilder.struct().optional(); + builder.field("opt_int32", Schema.OPTIONAL_INT32_SCHEMA); + Schema schema = builder.build(); + + SourceRecord transformed = xformValue.apply(new SourceRecord(null, null, + "topic", 0, + schema, null)); + + assertEquals(Schema.Type.STRUCT, transformed.valueSchema().type()); + 
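// an optional top-level struct that is null keeps its STRUCT schema but flattens to a null value +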
assertNull(transformed.value()); + } + + @Test + public void testOptionalNestedStruct() { + xformValue.configure(Collections.emptyMap()); + + SchemaBuilder builder = SchemaBuilder.struct().optional(); + builder.field("opt_int32", Schema.OPTIONAL_INT32_SCHEMA); + Schema supportedTypesSchema = builder.build(); + + builder = SchemaBuilder.struct(); + builder.field("B", supportedTypesSchema); + Schema oneLevelNestedSchema = builder.build(); + + Struct oneLevelNestedStruct = new Struct(oneLevelNestedSchema); + oneLevelNestedStruct.put("B", null); + + SourceRecord transformed = xformValue.apply(new SourceRecord(null, null, + "topic", 0, + oneLevelNestedSchema, oneLevelNestedStruct)); + + assertEquals(Schema.Type.STRUCT, transformed.valueSchema().type()); + Struct transformedStruct = (Struct) transformed.value(); + assertNull(transformedStruct.get("B.opt_int32")); + } + + @Test + public void testOptionalFieldMap() { + xformValue.configure(Collections.emptyMap()); + + Map<String, Object> supportedTypes = new HashMap<>(); + supportedTypes.put("opt_int32", null); + + Map<String, Object> oneLevelNestedMap = Collections.singletonMap("B", supportedTypes); + + SourceRecord transformed = xformValue.apply(new SourceRecord(null, null, + "topic", 0, + null, oneLevelNestedMap)); + + assertNull(transformed.valueSchema()); + assertTrue(transformed.value() instanceof Map); + @SuppressWarnings("unchecked") + Map<String, Object> transformedMap = (Map<String, Object>) transformed.value(); + + assertNull(transformedMap.get("B.opt_int32")); + } + + @Test + public void testKey() { + xformKey.configure(Collections.emptyMap()); + + Map<String, Map<String, Integer>> key = Collections.singletonMap("A", Collections.singletonMap("B", 12)); + SourceRecord src = new SourceRecord(null, null, "topic", null, key, null, null); + SourceRecord transformed = xformKey.apply(src); + + assertNull(transformed.keySchema()); + assertTrue(transformed.key() instanceof Map); + @SuppressWarnings("unchecked") + Map<String, Object> transformedMap = (Map<String, Object>) transformed.key(); + assertEquals(12, transformedMap.get("A.B")); + } + + @Test + public void testSchemalessArray() { + xformValue.configure(Collections.emptyMap()); + Object value = Collections.singletonMap("foo", Arrays.asList("bar", Collections.singletonMap("baz", Collections.singletonMap("lfg", "lfg")))); + assertEquals(value, xformValue.apply(new SourceRecord(null, null, "topic", null, null, null, value)).value()); + } + + @Test + public void testArrayWithSchema() { + xformValue.configure(Collections.emptyMap()); + Schema nestedStructSchema = SchemaBuilder.struct().field("lfg", Schema.STRING_SCHEMA).build(); + Schema innerStructSchema = SchemaBuilder.struct().field("baz", nestedStructSchema).build(); + Schema structSchema = SchemaBuilder.struct() + .field("foo", SchemaBuilder.array(innerStructSchema).doc("durk").build()) + .build(); + Struct nestedValue = new Struct(nestedStructSchema); + nestedValue.put("lfg", "lfg"); + Struct innerValue = new Struct(innerStructSchema); + innerValue.put("baz", nestedValue); + Struct value = new Struct(structSchema); + value.put("foo", Collections.singletonList(innerValue)); + SourceRecord transformed = xformValue.apply(new SourceRecord(null, null, "topic", null, null, structSchema, value)); + assertEquals(value, transformed.value()); + assertEquals(structSchema, transformed.valueSchema()); + } + + @Test + public void testOptionalAndDefaultValuesNested() { + // If we have a nested structure where an entire sub-Struct is optional, all flattened fields generated from its + // children should also be optional.
Similarly, if the parent Struct has a default value, the default value for + // the flattened field + + xformValue.configure(Collections.emptyMap()); + + SchemaBuilder builder = SchemaBuilder.struct().optional(); + builder.field("req_field", Schema.STRING_SCHEMA); + builder.field("opt_field", SchemaBuilder.string().optional().defaultValue("child_default").build()); + Struct childDefaultValue = new Struct(builder); + childDefaultValue.put("req_field", "req_default"); + builder.defaultValue(childDefaultValue); + Schema schema = builder.build(); + // Intentionally leave this entire value empty since it is optional + Struct value = new Struct(schema); + + SourceRecord transformed = xformValue.apply(new SourceRecord(null, null, "topic", 0, schema, value)); + + assertNotNull(transformed); + Schema transformedSchema = transformed.valueSchema(); + assertEquals(Schema.Type.STRUCT, transformedSchema.type()); + assertEquals(2, transformedSchema.fields().size()); + // Required field should pick up both being optional and the default value from the parent + Schema transformedReqFieldSchema = SchemaBuilder.string().optional().defaultValue("req_default").build(); + assertEquals(transformedReqFieldSchema, transformedSchema.field("req_field").schema()); + // The optional field should still be optional but should have picked up the default value. However, since + // the parent didn't specify the default explicitly, we should still be using the field's normal default + Schema transformedOptFieldSchema = SchemaBuilder.string().optional().defaultValue("child_default").build(); + assertEquals(transformedOptFieldSchema, transformedSchema.field("opt_field").schema()); + } + + @Test + public void tombstoneEventWithoutSchemaShouldPassThrough() { + xformValue.configure(Collections.emptyMap()); + + final SourceRecord record = new SourceRecord(null, null, "test", 0, + null, null); + final SourceRecord transformedRecord = xformValue.apply(record); + + assertNull(transformedRecord.value()); + assertNull(transformedRecord.valueSchema()); + } + + @Test + public void tombstoneEventWithSchemaShouldPassThrough() { + xformValue.configure(Collections.emptyMap()); + + final Schema simpleStructSchema = SchemaBuilder.struct().name("name").version(1).doc("doc").field("magic", Schema.OPTIONAL_INT64_SCHEMA).build(); + final SourceRecord record = new SourceRecord(null, null, "test", 0, + simpleStructSchema, null); + final SourceRecord transformedRecord = xformValue.apply(record); + + assertNull(transformedRecord.value()); + assertEquals(simpleStructSchema, transformedRecord.valueSchema()); + } + + @Test + public void testMapWithNullFields() { + xformValue.configure(Collections.emptyMap()); + + // Use a LinkedHashMap to ensure the SMT sees entries in a specific order + Map value = new LinkedHashMap<>(); + value.put("firstNull", null); + value.put("firstNonNull", "nonNull"); + value.put("secondNull", null); + value.put("secondNonNull", "alsoNonNull"); + value.put("thirdNonNull", null); + + final SourceRecord record = new SourceRecord(null, null, "test", 0, null, value); + final SourceRecord transformedRecord = xformValue.apply(record); + + assertEquals(value, transformedRecord.value()); + } + + @Test + public void testStructWithNullFields() { + xformValue.configure(Collections.emptyMap()); + + final Schema structSchema = SchemaBuilder.struct() + .field("firstNull", Schema.OPTIONAL_STRING_SCHEMA) + .field("firstNonNull", Schema.OPTIONAL_STRING_SCHEMA) + .field("secondNull", Schema.OPTIONAL_STRING_SCHEMA) + .field("secondNonNull", 
Schema.OPTIONAL_STRING_SCHEMA) + .field("thirdNonNull", Schema.OPTIONAL_STRING_SCHEMA) + .build(); + + final Struct value = new Struct(structSchema); + value.put("firstNull", null); + value.put("firstNonNull", "nonNull"); + value.put("secondNull", null); + value.put("secondNonNull", "alsoNonNull"); + value.put("thirdNonNull", null); + + final SourceRecord record = new SourceRecord(null, null, "test", 0, structSchema, value); + final SourceRecord transformedRecord = xformValue.apply(record); + + assertEquals(value, transformedRecord.value()); + } +} diff --git a/connect/transforms/src/test/java/org/apache/kafka/connect/transforms/HeaderFromTest.java b/connect/transforms/src/test/java/org/apache/kafka/connect/transforms/HeaderFromTest.java new file mode 100644 index 0000000..61e0575 --- /dev/null +++ b/connect/transforms/src/test/java/org/apache/kafka/connect/transforms/HeaderFromTest.java @@ -0,0 +1,359 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.transforms; + +import org.apache.kafka.common.config.ConfigException; +import org.apache.kafka.connect.data.Schema; +import org.apache.kafka.connect.data.SchemaAndValue; +import org.apache.kafka.connect.data.SchemaBuilder; +import org.apache.kafka.connect.data.Struct; +import org.apache.kafka.connect.header.ConnectHeaders; +import org.apache.kafka.connect.header.Header; +import org.apache.kafka.connect.header.Headers; +import org.apache.kafka.connect.source.SourceRecord; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import org.junit.jupiter.params.provider.ValueSource; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static java.util.Arrays.asList; +import static java.util.Collections.emptyList; +import static java.util.Collections.singletonList; +import static java.util.Collections.singletonMap; +import static org.apache.kafka.connect.data.Schema.STRING_SCHEMA; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +public class HeaderFromTest { + + static class RecordBuilder { + private final List fields = new ArrayList<>(2); + private final List fieldSchemas = new ArrayList<>(2); + private final List fieldValues = new ArrayList<>(2); + private final ConnectHeaders headers = new ConnectHeaders(); + + public RecordBuilder() { + } + + public RecordBuilder withField(String name, Schema schema, Object value) { + fields.add(name); + fieldSchemas.add(schema); + fieldValues.add(value); + return this; + } + + public RecordBuilder addHeader(String name, Schema schema, Object value) { + headers.add(name, new 
SchemaAndValue(schema, value)); + return this; + } + + public SourceRecord schemaless(boolean keyTransform) { + Map map = new HashMap<>(); + for (int i = 0; i < this.fields.size(); i++) { + String fieldName = this.fields.get(i); + map.put(fieldName, this.fieldValues.get(i)); + + } + return sourceRecord(keyTransform, null, map); + } + + private Schema schema() { + SchemaBuilder schemaBuilder = new SchemaBuilder(Schema.Type.STRUCT); + for (int i = 0; i < this.fields.size(); i++) { + String fieldName = this.fields.get(i); + schemaBuilder.field(fieldName, this.fieldSchemas.get(i)); + + } + return schemaBuilder.build(); + } + + private Struct struct(Schema schema) { + Struct struct = new Struct(schema); + for (int i = 0; i < this.fields.size(); i++) { + String fieldName = this.fields.get(i); + struct.put(fieldName, this.fieldValues.get(i)); + } + return struct; + } + + public SourceRecord withSchema(boolean keyTransform) { + Schema schema = schema(); + Struct struct = struct(schema); + return sourceRecord(keyTransform, schema, struct); + } + + private SourceRecord sourceRecord(boolean keyTransform, Schema keyOrValueSchema, Object keyOrValue) { + Map sourcePartition = singletonMap("foo", "bar"); + Map sourceOffset = singletonMap("baz", "quxx"); + String topic = "topic"; + Integer partition = 0; + Long timestamp = 0L; + + ConnectHeaders headers = this.headers; + if (keyOrValueSchema == null) { + // When doing a schemaless transformation we don't expect the header to have a schema + headers = new ConnectHeaders(); + for (Header header : this.headers) { + headers.add(header.key(), new SchemaAndValue(null, header.value())); + } + } + return new SourceRecord(sourcePartition, sourceOffset, topic, partition, + keyTransform ? keyOrValueSchema : null, + keyTransform ? keyOrValue : "key", + !keyTransform ? keyOrValueSchema : null, + !keyTransform ? 
keyOrValue : "value", + timestamp, headers); + } + + @Override + public String toString() { + return "RecordBuilder(" + + "fields=" + fields + + ", fieldSchemas=" + fieldSchemas + + ", fieldValues=" + fieldValues + + ", headers=" + headers + + ')'; + } + } + + public static List data() { + + List result = new ArrayList<>(); + + for (Boolean testKeyTransform : asList(true, false)) { + result.add( + Arguments.of( + "basic copy", + testKeyTransform, + new RecordBuilder() + .withField("field1", STRING_SCHEMA, "field1-value") + .withField("field2", STRING_SCHEMA, "field2-value") + .addHeader("header1", STRING_SCHEMA, "existing-value"), + singletonList("field1"), singletonList("inserted1"), HeaderFrom.Operation.COPY, + new RecordBuilder() + .withField("field1", STRING_SCHEMA, "field1-value") + .withField("field2", STRING_SCHEMA, "field2-value") + .addHeader("header1", STRING_SCHEMA, "existing-value") + .addHeader("inserted1", STRING_SCHEMA, "field1-value") + )); + result.add( + Arguments.of( + "basic move", + testKeyTransform, + new RecordBuilder() + .withField("field1", STRING_SCHEMA, "field1-value") + .withField("field2", STRING_SCHEMA, "field2-value") + .addHeader("header1", STRING_SCHEMA, "existing-value"), + singletonList("field1"), singletonList("inserted1"), HeaderFrom.Operation.MOVE, + new RecordBuilder() + // field1 got moved + .withField("field2", STRING_SCHEMA, "field2-value") + .addHeader("header1", STRING_SCHEMA, "existing-value") + .addHeader("inserted1", STRING_SCHEMA, "field1-value") + )); + result.add( + Arguments.of( + "copy with preexisting header", + testKeyTransform, + new RecordBuilder() + .withField("field1", STRING_SCHEMA, "field1-value") + .withField("field2", STRING_SCHEMA, "field2-value") + .addHeader("inserted1", STRING_SCHEMA, "existing-value"), + singletonList("field1"), singletonList("inserted1"), HeaderFrom.Operation.COPY, + new RecordBuilder() + .withField("field1", STRING_SCHEMA, "field1-value") + .withField("field2", STRING_SCHEMA, "field2-value") + .addHeader("inserted1", STRING_SCHEMA, "existing-value") + .addHeader("inserted1", STRING_SCHEMA, "field1-value") + )); + result.add( + Arguments.of( + "move with preexisting header", + testKeyTransform, + new RecordBuilder() + .withField("field1", STRING_SCHEMA, "field1-value") + .withField("field2", STRING_SCHEMA, "field2-value") + .addHeader("inserted1", STRING_SCHEMA, "existing-value"), + singletonList("field1"), singletonList("inserted1"), HeaderFrom.Operation.MOVE, + new RecordBuilder() + // field1 got moved + .withField("field2", STRING_SCHEMA, "field2-value") + .addHeader("inserted1", STRING_SCHEMA, "existing-value") + .addHeader("inserted1", STRING_SCHEMA, "field1-value") + )); + Schema schema = new SchemaBuilder(Schema.Type.STRUCT).field("foo", STRING_SCHEMA).build(); + Struct struct = new Struct(schema).put("foo", "foo-value"); + result.add( + Arguments.of( + "copy with struct value", + testKeyTransform, + new RecordBuilder() + .withField("field1", schema, struct) + .withField("field2", STRING_SCHEMA, "field2-value") + .addHeader("header1", STRING_SCHEMA, "existing-value"), + singletonList("field1"), singletonList("inserted1"), HeaderFrom.Operation.COPY, + new RecordBuilder() + .withField("field1", schema, struct) + .withField("field2", STRING_SCHEMA, "field2-value") + .addHeader("header1", STRING_SCHEMA, "existing-value") + .addHeader("inserted1", schema, struct) + )); + result.add( + Arguments.of( + "move with struct value", + testKeyTransform, + new RecordBuilder() + .withField("field1", schema, 
struct) + .withField("field2", STRING_SCHEMA, "field2-value") + .addHeader("header1", STRING_SCHEMA, "existing-value"), + singletonList("field1"), singletonList("inserted1"), HeaderFrom.Operation.MOVE, + new RecordBuilder() + // field1 got moved + .withField("field2", STRING_SCHEMA, "field2-value") + .addHeader("header1", STRING_SCHEMA, "existing-value") + .addHeader("inserted1", schema, struct) + )); + result.add( + Arguments.of( + "two headers from same field", + testKeyTransform, + new RecordBuilder() + .withField("field1", STRING_SCHEMA, "field1-value") + .withField("field2", STRING_SCHEMA, "field2-value") + .addHeader("header1", STRING_SCHEMA, "existing-value"), + // two headers from the same field + asList("field1", "field1"), asList("inserted1", "inserted2"), HeaderFrom.Operation.MOVE, + new RecordBuilder() + // field1 got moved + .withField("field2", STRING_SCHEMA, "field2-value") + .addHeader("header1", STRING_SCHEMA, "existing-value") + .addHeader("inserted1", STRING_SCHEMA, "field1-value") + .addHeader("inserted2", STRING_SCHEMA, "field1-value") + )); + result.add( + Arguments.of( + "two fields to same header", + testKeyTransform, + new RecordBuilder() + .withField("field1", STRING_SCHEMA, "field1-value") + .withField("field2", STRING_SCHEMA, "field2-value") + .addHeader("header1", STRING_SCHEMA, "existing-value"), + // two headers from the same field + asList("field1", "field2"), asList("inserted1", "inserted1"), HeaderFrom.Operation.MOVE, + new RecordBuilder() + // field1 and field2 got moved + .addHeader("header1", STRING_SCHEMA, "existing-value") + .addHeader("inserted1", STRING_SCHEMA, "field1-value") + .addHeader("inserted1", STRING_SCHEMA, "field2-value") + )); + } + return result; + } + + private Map config(List headers, List transformFields, HeaderFrom.Operation operation) { + Map result = new HashMap<>(); + result.put(HeaderFrom.HEADERS_FIELD, headers); + result.put(HeaderFrom.FIELDS_FIELD, transformFields); + result.put(HeaderFrom.OPERATION_FIELD, operation.toString()); + return result; + } + + @ParameterizedTest + @MethodSource("data") + public void schemaless(String description, + boolean keyTransform, + RecordBuilder originalBuilder, + List transformFields, List headers1, HeaderFrom.Operation operation, + RecordBuilder expectedBuilder) { + HeaderFrom xform = keyTransform ? new HeaderFrom.Key<>() : new HeaderFrom.Value<>(); + + xform.configure(config(headers1, transformFields, operation)); + ConnectHeaders headers = new ConnectHeaders(); + headers.addString("existing", "existing-value"); + + SourceRecord originalRecord = originalBuilder.schemaless(keyTransform); + SourceRecord expectedRecord = expectedBuilder.schemaless(keyTransform); + SourceRecord xformed = xform.apply(originalRecord); + assertSameRecord(expectedRecord, xformed); + } + + @ParameterizedTest + @MethodSource("data") + public void withSchema(String description, + boolean keyTransform, + RecordBuilder originalBuilder, + List transformFields, List headers1, HeaderFrom.Operation operation, + RecordBuilder expectedBuilder) { + HeaderFrom xform = keyTransform ? 
new HeaderFrom.Key<>() : new HeaderFrom.Value<>(); + xform.configure(config(headers1, transformFields, operation)); + ConnectHeaders headers = new ConnectHeaders(); + headers.addString("existing", "existing-value"); + Headers expect = headers.duplicate(); + for (int i = 0; i < headers1.size(); i++) { + expect.add(headers1.get(i), originalBuilder.fieldValues.get(i), originalBuilder.fieldSchemas.get(i)); + } + + SourceRecord originalRecord = originalBuilder.withSchema(keyTransform); + SourceRecord expectedRecord = expectedBuilder.withSchema(keyTransform); + SourceRecord xformed = xform.apply(originalRecord); + assertSameRecord(expectedRecord, xformed); + } + + @ParameterizedTest + @ValueSource(booleans = {true, false}) + public void invalidConfigExtraHeaderConfig(boolean keyTransform) { + Map config = config(singletonList("foo"), asList("foo", "bar"), HeaderFrom.Operation.COPY); + HeaderFrom xform = keyTransform ? new HeaderFrom.Key<>() : new HeaderFrom.Value<>(); + assertThrows(ConfigException.class, () -> xform.configure(config)); + } + + @ParameterizedTest + @ValueSource(booleans = {true, false}) + public void invalidConfigExtraFieldConfig(boolean keyTransform) { + Map config = config(asList("foo", "bar"), singletonList("foo"), HeaderFrom.Operation.COPY); + HeaderFrom xform = keyTransform ? new HeaderFrom.Key<>() : new HeaderFrom.Value<>(); + assertThrows(ConfigException.class, () -> xform.configure(config)); + } + + @ParameterizedTest + @ValueSource(booleans = {true, false}) + public void invalidConfigEmptyHeadersAndFieldsConfig(boolean keyTransform) { + Map config = config(emptyList(), emptyList(), HeaderFrom.Operation.COPY); + HeaderFrom xform = keyTransform ? new HeaderFrom.Key<>() : new HeaderFrom.Value<>(); + assertThrows(ConfigException.class, () -> xform.configure(config)); + } + + private static void assertSameRecord(SourceRecord expected, SourceRecord xformed) { + assertEquals(expected.sourcePartition(), xformed.sourcePartition()); + assertEquals(expected.sourceOffset(), xformed.sourceOffset()); + assertEquals(expected.topic(), xformed.topic()); + assertEquals(expected.kafkaPartition(), xformed.kafkaPartition()); + assertEquals(expected.keySchema(), xformed.keySchema()); + assertEquals(expected.key(), xformed.key()); + assertEquals(expected.valueSchema(), xformed.valueSchema()); + assertEquals(expected.value(), xformed.value()); + assertEquals(expected.timestamp(), xformed.timestamp()); + assertEquals(expected.headers(), xformed.headers()); + } + +} + diff --git a/connect/transforms/src/test/java/org/apache/kafka/connect/transforms/HoistFieldTest.java b/connect/transforms/src/test/java/org/apache/kafka/connect/transforms/HoistFieldTest.java new file mode 100644 index 0000000..ab601b8 --- /dev/null +++ b/connect/transforms/src/test/java/org/apache/kafka/connect/transforms/HoistFieldTest.java @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.transforms; + +import org.apache.kafka.connect.data.Schema; +import org.apache.kafka.connect.data.Struct; +import org.apache.kafka.connect.sink.SinkRecord; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Test; + +import java.util.Collections; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNull; + +public class HoistFieldTest { + private final HoistField xform = new HoistField.Key<>(); + + @AfterEach + public void teardown() { + xform.close(); + } + + @Test + public void schemaless() { + xform.configure(Collections.singletonMap("field", "magic")); + + final SinkRecord record = new SinkRecord("test", 0, null, 42, null, null, 0); + final SinkRecord transformedRecord = xform.apply(record); + + assertNull(transformedRecord.keySchema()); + assertEquals(Collections.singletonMap("magic", 42), transformedRecord.key()); + } + + @Test + public void withSchema() { + xform.configure(Collections.singletonMap("field", "magic")); + + final SinkRecord record = new SinkRecord("test", 0, Schema.INT32_SCHEMA, 42, null, null, 0); + final SinkRecord transformedRecord = xform.apply(record); + + assertEquals(Schema.Type.STRUCT, transformedRecord.keySchema().type()); + assertEquals(record.keySchema(), transformedRecord.keySchema().field("magic").schema()); + assertEquals(42, ((Struct) transformedRecord.key()).get("magic")); + } + +} diff --git a/connect/transforms/src/test/java/org/apache/kafka/connect/transforms/InsertFieldTest.java b/connect/transforms/src/test/java/org/apache/kafka/connect/transforms/InsertFieldTest.java new file mode 100644 index 0000000..d5cc6bf --- /dev/null +++ b/connect/transforms/src/test/java/org/apache/kafka/connect/transforms/InsertFieldTest.java @@ -0,0 +1,203 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.connect.transforms; + +import org.apache.kafka.connect.data.Schema; +import org.apache.kafka.connect.data.SchemaBuilder; +import org.apache.kafka.connect.data.Struct; +import org.apache.kafka.connect.data.Timestamp; +import org.apache.kafka.connect.errors.DataException; +import org.apache.kafka.connect.source.SourceRecord; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Test; + +import java.util.Collections; +import java.util.Date; +import java.util.HashMap; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertSame; +import static org.junit.jupiter.api.Assertions.assertThrows; + +public class InsertFieldTest { + private InsertField xformKey = new InsertField.Key<>(); + private InsertField xformValue = new InsertField.Value<>(); + + @AfterEach + public void teardown() { + xformValue.close(); + } + + @Test + public void topLevelStructRequired() { + xformValue.configure(Collections.singletonMap("topic.field", "topic_field")); + assertThrows(DataException.class, + () -> xformValue.apply(new SourceRecord(null, null, "", 0, Schema.INT32_SCHEMA, 42))); + } + + @Test + public void copySchemaAndInsertConfiguredFields() { + final Map props = new HashMap<>(); + props.put("topic.field", "topic_field!"); + props.put("partition.field", "partition_field"); + props.put("timestamp.field", "timestamp_field?"); + props.put("static.field", "instance_id"); + props.put("static.value", "my-instance-id"); + + xformValue.configure(props); + + final Schema simpleStructSchema = SchemaBuilder.struct().name("name").version(1).doc("doc").field("magic", Schema.OPTIONAL_INT64_SCHEMA).build(); + final Struct simpleStruct = new Struct(simpleStructSchema).put("magic", 42L); + + final SourceRecord record = new SourceRecord(null, null, "test", 0, null, null, simpleStructSchema, simpleStruct, 789L); + final SourceRecord transformedRecord = xformValue.apply(record); + + assertEquals(simpleStructSchema.name(), transformedRecord.valueSchema().name()); + assertEquals(simpleStructSchema.version(), transformedRecord.valueSchema().version()); + assertEquals(simpleStructSchema.doc(), transformedRecord.valueSchema().doc()); + + assertEquals(Schema.OPTIONAL_INT64_SCHEMA, transformedRecord.valueSchema().field("magic").schema()); + assertEquals(42L, ((Struct) transformedRecord.value()).getInt64("magic").longValue()); + + assertEquals(Schema.STRING_SCHEMA, transformedRecord.valueSchema().field("topic_field").schema()); + assertEquals("test", ((Struct) transformedRecord.value()).getString("topic_field")); + + assertEquals(Schema.OPTIONAL_INT32_SCHEMA, transformedRecord.valueSchema().field("partition_field").schema()); + assertEquals(0, ((Struct) transformedRecord.value()).getInt32("partition_field").intValue()); + + assertEquals(Timestamp.builder().optional().build(), transformedRecord.valueSchema().field("timestamp_field").schema()); + assertEquals(789L, ((Date) ((Struct) transformedRecord.value()).get("timestamp_field")).getTime()); + + assertEquals(Schema.OPTIONAL_STRING_SCHEMA, transformedRecord.valueSchema().field("instance_id").schema()); + assertEquals("my-instance-id", ((Struct) transformedRecord.value()).getString("instance_id")); + + // Exercise caching + final SourceRecord transformedRecord2 = xformValue.apply( + new SourceRecord(null, null, "test", 1, simpleStructSchema, new Struct(simpleStructSchema))); + 
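// a second record with the same value schema should hit the schema cache and reuse the identical schema instance +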
assertSame(transformedRecord.valueSchema(), transformedRecord2.valueSchema()); + } + + @Test + public void schemalessInsertConfiguredFields() { + final Map props = new HashMap<>(); + props.put("topic.field", "topic_field!"); + props.put("partition.field", "partition_field"); + props.put("timestamp.field", "timestamp_field?"); + props.put("static.field", "instance_id"); + props.put("static.value", "my-instance-id"); + + xformValue.configure(props); + + final SourceRecord record = new SourceRecord(null, null, "test", 0, + null, null, null, Collections.singletonMap("magic", 42L), 123L); + + final SourceRecord transformedRecord = xformValue.apply(record); + + assertEquals(42L, ((Map) transformedRecord.value()).get("magic")); + assertEquals("test", ((Map) transformedRecord.value()).get("topic_field")); + assertEquals(0, ((Map) transformedRecord.value()).get("partition_field")); + assertEquals(123L, ((Map) transformedRecord.value()).get("timestamp_field")); + assertEquals("my-instance-id", ((Map) transformedRecord.value()).get("instance_id")); + } + + @Test + public void insertConfiguredFieldsIntoTombstoneEventWithoutSchemaLeavesValueUnchanged() { + final Map props = new HashMap<>(); + props.put("topic.field", "topic_field!"); + props.put("partition.field", "partition_field"); + props.put("timestamp.field", "timestamp_field?"); + props.put("static.field", "instance_id"); + props.put("static.value", "my-instance-id"); + + xformValue.configure(props); + + final SourceRecord record = new SourceRecord(null, null, "test", 0, + null, null); + + final SourceRecord transformedRecord = xformValue.apply(record); + + assertNull(transformedRecord.value()); + assertNull(transformedRecord.valueSchema()); + } + + @Test + public void insertConfiguredFieldsIntoTombstoneEventWithSchemaLeavesValueUnchanged() { + final Map props = new HashMap<>(); + props.put("topic.field", "topic_field!"); + props.put("partition.field", "partition_field"); + props.put("timestamp.field", "timestamp_field?"); + props.put("static.field", "instance_id"); + props.put("static.value", "my-instance-id"); + + xformValue.configure(props); + + final Schema simpleStructSchema = SchemaBuilder.struct().name("name").version(1).doc("doc").field("magic", Schema.OPTIONAL_INT64_SCHEMA).build(); + + final SourceRecord record = new SourceRecord(null, null, "test", 0, + simpleStructSchema, null); + + final SourceRecord transformedRecord = xformValue.apply(record); + + assertNull(transformedRecord.value()); + assertEquals(simpleStructSchema, transformedRecord.valueSchema()); + } + + @Test + public void insertKeyFieldsIntoTombstoneEvent() { + final Map props = new HashMap<>(); + props.put("topic.field", "topic_field!"); + props.put("partition.field", "partition_field"); + props.put("timestamp.field", "timestamp_field?"); + props.put("static.field", "instance_id"); + props.put("static.value", "my-instance-id"); + + xformKey.configure(props); + + final SourceRecord record = new SourceRecord(null, null, "test", 0, + null, Collections.singletonMap("magic", 42L), null, null); + + final SourceRecord transformedRecord = xformKey.apply(record); + + assertEquals(42L, ((Map) transformedRecord.key()).get("magic")); + assertEquals("test", ((Map) transformedRecord.key()).get("topic_field")); + assertEquals(0, ((Map) transformedRecord.key()).get("partition_field")); + assertNull(((Map) transformedRecord.key()).get("timestamp_field")); + assertEquals("my-instance-id", ((Map) transformedRecord.key()).get("instance_id")); + assertNull(transformedRecord.value()); + } + + 
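// a record with a null key should pass through the Key transform completely unchanged +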
@Test + public void insertIntoNullKeyLeavesRecordUnchanged() { + final Map props = new HashMap<>(); + props.put("topic.field", "topic_field!"); + props.put("partition.field", "partition_field"); + props.put("timestamp.field", "timestamp_field?"); + props.put("static.field", "instance_id"); + props.put("static.value", "my-instance-id"); + + xformKey.configure(props); + + final SourceRecord record = new SourceRecord(null, null, "test", 0, + null, null, null, Collections.singletonMap("magic", 42L)); + + final SourceRecord transformedRecord = xformKey.apply(record); + + assertSame(record, transformedRecord); + } +} diff --git a/connect/transforms/src/test/java/org/apache/kafka/connect/transforms/InsertHeaderTest.java b/connect/transforms/src/test/java/org/apache/kafka/connect/transforms/InsertHeaderTest.java new file mode 100644 index 0000000..97cbe5d --- /dev/null +++ b/connect/transforms/src/test/java/org/apache/kafka/connect/transforms/InsertHeaderTest.java @@ -0,0 +1,121 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.connect.transforms; + +import org.apache.kafka.common.config.ConfigException; +import org.apache.kafka.connect.data.Schema; +import org.apache.kafka.connect.header.ConnectHeaders; +import org.apache.kafka.connect.header.Headers; +import org.apache.kafka.connect.source.SourceRecord; +import org.junit.jupiter.api.Test; + +import java.util.HashMap; +import java.util.Map; + +import static java.util.Collections.singletonMap; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +public class InsertHeaderTest { + + private InsertHeader xform = new InsertHeader<>(); + + private Map config(String header, String valueLiteral) { + Map result = new HashMap<>(); + result.put(InsertHeader.HEADER_FIELD, header); + result.put(InsertHeader.VALUE_LITERAL_FIELD, valueLiteral); + return result; + } + + @Test + public void insertionWithExistingOtherHeader() { + xform.configure(config("inserted", "inserted-value")); + ConnectHeaders headers = new ConnectHeaders(); + headers.addString("existing", "existing-value"); + Headers expect = headers.duplicate().addString("inserted", "inserted-value"); + + SourceRecord original = sourceRecord(headers); + SourceRecord xformed = xform.apply(original); + assertNonHeaders(original, xformed); + assertEquals(expect, xformed.headers()); + } + + @Test + public void insertionWithExistingSameHeader() { + xform.configure(config("existing", "inserted-value")); + ConnectHeaders headers = new ConnectHeaders(); + headers.addString("existing", "preexisting-value"); + Headers expect = headers.duplicate().addString("existing", "inserted-value"); + + SourceRecord original = sourceRecord(headers); + SourceRecord xformed = xform.apply(original); + assertNonHeaders(original, xformed); + assertEquals(expect, xformed.headers()); + } + + @Test + public void insertionWithByteHeader() { + xform.configure(config("inserted", "1")); + ConnectHeaders headers = new ConnectHeaders(); + headers.addString("existing", "existing-value"); + Headers expect = headers.duplicate().addByte("inserted", (byte) 1); + + SourceRecord original = sourceRecord(headers); + SourceRecord xformed = xform.apply(original); + assertNonHeaders(original, xformed); + assertEquals(expect, xformed.headers()); + } + + @Test + public void configRejectsNullHeaderKey() { + assertThrows(ConfigException.class, () -> xform.configure(config(null, "1"))); + } + + @Test + public void configRejectsNullHeaderValue() { + assertThrows(ConfigException.class, () -> xform.configure(config("inserted", null))); + } + + private void assertNonHeaders(SourceRecord original, SourceRecord xformed) { + assertEquals(original.sourcePartition(), xformed.sourcePartition()); + assertEquals(original.sourceOffset(), xformed.sourceOffset()); + assertEquals(original.topic(), xformed.topic()); + assertEquals(original.kafkaPartition(), xformed.kafkaPartition()); + assertEquals(original.keySchema(), xformed.keySchema()); + assertEquals(original.key(), xformed.key()); + assertEquals(original.valueSchema(), xformed.valueSchema()); + assertEquals(original.value(), xformed.value()); + assertEquals(original.timestamp(), xformed.timestamp()); + } + + private SourceRecord sourceRecord(ConnectHeaders headers) { + Map sourcePartition = singletonMap("foo", "bar"); + Map sourceOffset = singletonMap("baz", "quxx"); + String topic = "topic"; + Integer partition = 0; + Schema keySchema = null; + Object key = "key"; + Schema valueSchema = null; + Object value = "value"; + Long 
timestamp = 0L; + + SourceRecord record = new SourceRecord(sourcePartition, sourceOffset, topic, partition, + keySchema, key, valueSchema, value, timestamp, headers); + return record; + } +} + diff --git a/connect/transforms/src/test/java/org/apache/kafka/connect/transforms/MaskFieldTest.java b/connect/transforms/src/test/java/org/apache/kafka/connect/transforms/MaskFieldTest.java new file mode 100644 index 0000000..5eaa31b --- /dev/null +++ b/connect/transforms/src/test/java/org/apache/kafka/connect/transforms/MaskFieldTest.java @@ -0,0 +1,253 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.transforms; + +import org.apache.kafka.common.config.ConfigException; +import org.apache.kafka.connect.data.Decimal; +import org.apache.kafka.connect.data.Field; +import org.apache.kafka.connect.data.Schema; +import org.apache.kafka.connect.data.SchemaBuilder; +import org.apache.kafka.connect.data.Struct; +import org.apache.kafka.connect.data.Time; +import org.apache.kafka.connect.data.Timestamp; +import org.apache.kafka.connect.errors.DataException; +import org.apache.kafka.connect.sink.SinkRecord; +import org.junit.jupiter.api.Test; + +import java.math.BigDecimal; +import java.math.BigInteger; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.Date; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static java.util.Collections.singletonList; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +public class MaskFieldTest { + + private static final Schema SCHEMA = SchemaBuilder.struct() + .field("magic", Schema.INT32_SCHEMA) + .field("bool", Schema.BOOLEAN_SCHEMA) + .field("byte", Schema.INT8_SCHEMA) + .field("short", Schema.INT16_SCHEMA) + .field("int", Schema.INT32_SCHEMA) + .field("long", Schema.INT64_SCHEMA) + .field("float", Schema.FLOAT32_SCHEMA) + .field("double", Schema.FLOAT64_SCHEMA) + .field("string", Schema.STRING_SCHEMA) + .field("date", org.apache.kafka.connect.data.Date.SCHEMA) + .field("time", Time.SCHEMA) + .field("timestamp", Timestamp.SCHEMA) + .field("decimal", Decimal.schema(0)) + .field("array", SchemaBuilder.array(Schema.INT32_SCHEMA)) + .field("map", SchemaBuilder.map(Schema.STRING_SCHEMA, Schema.STRING_SCHEMA)) + .build(); + private static final Map VALUES = new HashMap<>(); + private static final Struct VALUES_WITH_SCHEMA = new Struct(SCHEMA); + + static { + VALUES.put("magic", 42); + VALUES.put("bool", true); + VALUES.put("byte", (byte) 42); + VALUES.put("short", (short) 42); + VALUES.put("int", 42); + VALUES.put("long", 42L); + VALUES.put("float", 42f); + VALUES.put("double", 42d); + VALUES.put("string", "55.121.20.20"); + 
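// the remaining schemaless fixture values cover temporal, arbitrary-precision numeric, and collection types +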
VALUES.put("date", new Date()); + VALUES.put("bigint", new BigInteger("42")); + VALUES.put("bigdec", new BigDecimal("42.0")); + VALUES.put("list", singletonList(42)); + VALUES.put("map", Collections.singletonMap("key", "value")); + + VALUES_WITH_SCHEMA.put("magic", 42); + VALUES_WITH_SCHEMA.put("bool", true); + VALUES_WITH_SCHEMA.put("byte", (byte) 42); + VALUES_WITH_SCHEMA.put("short", (short) 42); + VALUES_WITH_SCHEMA.put("int", 42); + VALUES_WITH_SCHEMA.put("long", 42L); + VALUES_WITH_SCHEMA.put("float", 42f); + VALUES_WITH_SCHEMA.put("double", 42d); + VALUES_WITH_SCHEMA.put("string", "hmm"); + VALUES_WITH_SCHEMA.put("date", new Date()); + VALUES_WITH_SCHEMA.put("time", new Date()); + VALUES_WITH_SCHEMA.put("timestamp", new Date()); + VALUES_WITH_SCHEMA.put("decimal", new BigDecimal(42)); + VALUES_WITH_SCHEMA.put("array", Arrays.asList(1, 2, 3)); + VALUES_WITH_SCHEMA.put("map", Collections.singletonMap("what", "what")); + } + + private static MaskField transform(List fields, String replacement) { + final MaskField xform = new MaskField.Value<>(); + Map props = new HashMap<>(); + props.put("fields", fields); + props.put("replacement", replacement); + xform.configure(props); + return xform; + } + + private static SinkRecord record(Schema schema, Object value) { + return new SinkRecord("", 0, null, null, schema, value, 0); + } + + private static void checkReplacementWithSchema(String maskField, Object replacement) { + SinkRecord record = record(SCHEMA, VALUES_WITH_SCHEMA); + final Struct updatedValue = (Struct) transform(singletonList(maskField), String.valueOf(replacement)).apply(record).value(); + assertEquals(replacement, updatedValue.get(maskField), "Invalid replacement for " + maskField + " value"); + } + + private static void checkReplacementSchemaless(String maskField, Object replacement) { + checkReplacementSchemaless(singletonList(maskField), replacement); + } + + @SuppressWarnings("unchecked") + private static void checkReplacementSchemaless(List maskFields, Object replacement) { + SinkRecord record = record(null, VALUES); + final Map updatedValue = (Map) transform(maskFields, String.valueOf(replacement)) + .apply(record) + .value(); + for (String maskField : maskFields) { + assertEquals(replacement, updatedValue.get(maskField), "Invalid replacement for " + maskField + " value"); + } + } + + @Test + public void testSchemaless() { + final List maskFields = new ArrayList<>(VALUES.keySet()); + maskFields.remove("magic"); + @SuppressWarnings("unchecked") final Map updatedValue = (Map) transform(maskFields, null).apply(record(null, VALUES)).value(); + + assertEquals(42, updatedValue.get("magic")); + assertEquals(false, updatedValue.get("bool")); + assertEquals((byte) 0, updatedValue.get("byte")); + assertEquals((short) 0, updatedValue.get("short")); + assertEquals(0, updatedValue.get("int")); + assertEquals(0L, updatedValue.get("long")); + assertEquals(0f, updatedValue.get("float")); + assertEquals(0d, updatedValue.get("double")); + assertEquals("", updatedValue.get("string")); + assertEquals(new Date(0), updatedValue.get("date")); + assertEquals(BigInteger.ZERO, updatedValue.get("bigint")); + assertEquals(BigDecimal.ZERO, updatedValue.get("bigdec")); + assertEquals(Collections.emptyList(), updatedValue.get("list")); + assertEquals(Collections.emptyMap(), updatedValue.get("map")); + } + + @Test + public void testWithSchema() { + final List maskFields = new ArrayList<>(SCHEMA.fields().size()); + for (Field field : SCHEMA.fields()) { + if (!field.name().equals("magic")) { + 
maskFields.add(field.name()); + } + } + + final Struct updatedValue = (Struct) transform(maskFields, null).apply(record(SCHEMA, VALUES_WITH_SCHEMA)).value(); + + assertEquals(42, updatedValue.get("magic")); + assertEquals(false, updatedValue.get("bool")); + assertEquals((byte) 0, updatedValue.get("byte")); + assertEquals((short) 0, updatedValue.get("short")); + assertEquals(0, updatedValue.get("int")); + assertEquals(0L, updatedValue.get("long")); + assertEquals(0f, updatedValue.get("float")); + assertEquals(0d, updatedValue.get("double")); + assertEquals("", updatedValue.get("string")); + assertEquals(new Date(0), updatedValue.get("date")); + assertEquals(new Date(0), updatedValue.get("time")); + assertEquals(new Date(0), updatedValue.get("timestamp")); + assertEquals(BigDecimal.ZERO, updatedValue.get("decimal")); + assertEquals(Collections.emptyList(), updatedValue.get("array")); + assertEquals(Collections.emptyMap(), updatedValue.get("map")); + } + + @Test + public void testSchemalessWithReplacement() { + checkReplacementSchemaless("short", (short) 123); + checkReplacementSchemaless("byte", (byte) 123); + checkReplacementSchemaless("int", 123); + checkReplacementSchemaless("long", 123L); + checkReplacementSchemaless("float", 123.0f); + checkReplacementSchemaless("double", 123.0); + checkReplacementSchemaless("string", "123"); + checkReplacementSchemaless("bigint", BigInteger.valueOf(123L)); + checkReplacementSchemaless("bigdec", BigDecimal.valueOf(123.0)); + } + + @Test + public void testSchemalessUnsupportedReplacementType() { + String exMessage = "Cannot mask value of type"; + Class exClass = DataException.class; + + assertThrows(exClass, () -> checkReplacementSchemaless("date", new Date()), exMessage); + assertThrows(exClass, () -> checkReplacementSchemaless(Arrays.asList("int", "date"), new Date()), exMessage); + assertThrows(exClass, () -> checkReplacementSchemaless("bool", false), exMessage); + assertThrows(exClass, () -> checkReplacementSchemaless("list", singletonList("123")), exMessage); + assertThrows(exClass, () -> checkReplacementSchemaless("map", Collections.singletonMap("123", "321")), exMessage); + } + + @Test + public void testWithSchemaAndReplacement() { + checkReplacementWithSchema("short", (short) 123); + checkReplacementWithSchema("byte", (byte) 123); + checkReplacementWithSchema("int", 123); + checkReplacementWithSchema("long", 123L); + checkReplacementWithSchema("float", 123.0f); + checkReplacementWithSchema("double", 123.0); + checkReplacementWithSchema("string", "123"); + checkReplacementWithSchema("decimal", BigDecimal.valueOf(123.0)); + } + + @Test + public void testWithSchemaUnsupportedReplacementType() { + String exMessage = "Cannot mask value of type"; + Class exClass = DataException.class; + + assertThrows(exClass, () -> checkReplacementWithSchema("time", new Date()), exMessage); + assertThrows(exClass, () -> checkReplacementWithSchema("timestamp", new Date()), exMessage); + assertThrows(exClass, () -> checkReplacementWithSchema("array", singletonList(123)), exMessage); + } + + @Test + public void testReplacementTypeMismatch() { + String exMessage = "Invalid value for configuration replacement"; + Class exClass = DataException.class; + + assertThrows(exClass, () -> checkReplacementSchemaless("byte", "foo"), exMessage); + assertThrows(exClass, () -> checkReplacementSchemaless("short", "foo"), exMessage); + assertThrows(exClass, () -> checkReplacementSchemaless("int", "foo"), exMessage); + assertThrows(exClass, () -> checkReplacementSchemaless("long", "foo"), 
exMessage); + assertThrows(exClass, () -> checkReplacementSchemaless("float", "foo"), exMessage); + assertThrows(exClass, () -> checkReplacementSchemaless("double", "foo"), exMessage); + assertThrows(exClass, () -> checkReplacementSchemaless("bigint", "foo"), exMessage); + assertThrows(exClass, () -> checkReplacementSchemaless("bigdec", "foo"), exMessage); + assertThrows(exClass, () -> checkReplacementSchemaless("int", new Date()), exMessage); + assertThrows(exClass, () -> checkReplacementSchemaless("int", new Object()), exMessage); + assertThrows(exClass, () -> checkReplacementSchemaless(Arrays.asList("string", "int"), "foo"), exMessage); + } + + @Test + public void testEmptyStringReplacementValue() { + assertThrows(ConfigException.class, () -> checkReplacementSchemaless("short", ""), "String must be non-empty"); + } +} diff --git a/connect/transforms/src/test/java/org/apache/kafka/connect/transforms/RegexRouterTest.java b/connect/transforms/src/test/java/org/apache/kafka/connect/transforms/RegexRouterTest.java new file mode 100644 index 0000000..cef82d2 --- /dev/null +++ b/connect/transforms/src/test/java/org/apache/kafka/connect/transforms/RegexRouterTest.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
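A minimal usage sketch, not part of the patch, showing how the MaskField.Value transform exercised by these tests might be driven directly. The `replacement` config name is taken from the error messages asserted above; the `fields` key, the sample values, and the sketch class name are assumptions for illustration.

    import org.apache.kafka.connect.sink.SinkRecord;
    import org.apache.kafka.connect.transforms.MaskField;

    import java.util.HashMap;
    import java.util.Map;

    public class MaskFieldSketch {
        public static void main(String[] args) {
            MaskField.Value<SinkRecord> xform = new MaskField.Value<>();
            Map<String, String> props = new HashMap<>();
            props.put("fields", "ssn");                // assumed config key
            props.put("replacement", "000-00-0000");   // config name seen in the test error messages
            xform.configure(props);

            Map<String, Object> value = new HashMap<>();
            value.put("ssn", "123-45-6789");
            value.put("name", "alice");
            SinkRecord masked = xform.apply(new SinkRecord("topic", 0, null, null, null, value, 0));
            System.out.println(((Map<?, ?>) masked.value()).get("ssn")); // 000-00-0000
            xform.close();
        }
    }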
+ */ +package org.apache.kafka.connect.transforms; + +import org.apache.kafka.connect.sink.SinkRecord; +import org.junit.jupiter.api.Test; + +import java.util.HashMap; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class RegexRouterTest { + + private static String apply(String regex, String replacement, String topic) { + final Map props = new HashMap<>(); + props.put("regex", regex); + props.put("replacement", replacement); + final RegexRouter router = new RegexRouter<>(); + router.configure(props); + String sinkTopic = router.apply(new SinkRecord(topic, 0, null, null, null, null, 0)).topic(); + router.close(); + return sinkTopic; + } + + @Test + public void staticReplacement() { + assertEquals("bar", apply("foo", "bar", "foo")); + } + + @Test + public void doesntMatch() { + assertEquals("orig", apply("foo", "bar", "orig")); + } + + @Test + public void identity() { + assertEquals("orig", apply("(.*)", "$1", "orig")); + } + + @Test + public void addPrefix() { + assertEquals("prefix-orig", apply("(.*)", "prefix-$1", "orig")); + } + + @Test + public void addSuffix() { + assertEquals("orig-suffix", apply("(.*)", "$1-suffix", "orig")); + } + + @Test + public void slice() { + assertEquals("index", apply("(.*)-(\\d\\d\\d\\d\\d\\d\\d\\d)", "$1", "index-20160117")); + } + +} diff --git a/connect/transforms/src/test/java/org/apache/kafka/connect/transforms/ReplaceFieldTest.java b/connect/transforms/src/test/java/org/apache/kafka/connect/transforms/ReplaceFieldTest.java new file mode 100644 index 0000000..f8641a7 --- /dev/null +++ b/connect/transforms/src/test/java/org/apache/kafka/connect/transforms/ReplaceFieldTest.java @@ -0,0 +1,172 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
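For context, a small standalone sketch of the java.util.regex backreference behaviour the routing expectations in RegexRouterTest rely on; this is not necessarily how RegexRouter itself is implemented, only the standard-library semantics behind `$1`-style replacements.

    import java.util.regex.Matcher;
    import java.util.regex.Pattern;

    public class RegexRouteSketch {
        public static void main(String[] args) {
            // Same pattern and input as the "slice" test above.
            Matcher m = Pattern.compile("(.*)-(\\d\\d\\d\\d\\d\\d\\d\\d)").matcher("index-20160117");
            if (m.matches()) {
                // "$1" refers back to the first capture group, so the date suffix is dropped.
                System.out.println(m.replaceFirst("$1")); // index
            }
        }
    }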
+ */ +package org.apache.kafka.connect.transforms; + +import org.apache.kafka.connect.data.Schema; +import org.apache.kafka.connect.data.SchemaBuilder; +import org.apache.kafka.connect.data.Struct; +import org.apache.kafka.connect.sink.SinkRecord; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Test; + +import java.util.HashMap; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNull; + +public class ReplaceFieldTest { + private ReplaceField xform = new ReplaceField.Value<>(); + + @AfterEach + public void teardown() { + xform.close(); + } + + @Test + public void tombstoneSchemaless() { + final Map props = new HashMap<>(); + props.put("include", "abc,foo"); + props.put("renames", "abc:xyz,foo:bar"); + + xform.configure(props); + + final SinkRecord record = new SinkRecord("test", 0, null, null, null, null, 0); + final SinkRecord transformedRecord = xform.apply(record); + + assertNull(transformedRecord.value()); + assertNull(transformedRecord.valueSchema()); + } + + @Test + public void tombstoneWithSchema() { + final Map props = new HashMap<>(); + props.put("include", "abc,foo"); + props.put("renames", "abc:xyz,foo:bar"); + + xform.configure(props); + + final Schema schema = SchemaBuilder.struct() + .field("dont", Schema.STRING_SCHEMA) + .field("abc", Schema.INT32_SCHEMA) + .field("foo", Schema.BOOLEAN_SCHEMA) + .field("etc", Schema.STRING_SCHEMA) + .build(); + + final SinkRecord record = new SinkRecord("test", 0, null, null, schema, null, 0); + final SinkRecord transformedRecord = xform.apply(record); + + assertNull(transformedRecord.value()); + assertEquals(schema, transformedRecord.valueSchema()); + } + + @SuppressWarnings("unchecked") + @Test + public void schemaless() { + final Map props = new HashMap<>(); + props.put("exclude", "dont"); + props.put("renames", "abc:xyz,foo:bar"); + + xform.configure(props); + + final Map value = new HashMap<>(); + value.put("dont", "whatever"); + value.put("abc", 42); + value.put("foo", true); + value.put("etc", "etc"); + + final SinkRecord record = new SinkRecord("test", 0, null, null, null, value, 0); + final SinkRecord transformedRecord = xform.apply(record); + + final Map updatedValue = (Map) transformedRecord.value(); + assertEquals(3, updatedValue.size()); + assertEquals(42, updatedValue.get("xyz")); + assertEquals(true, updatedValue.get("bar")); + assertEquals("etc", updatedValue.get("etc")); + } + + @Test + public void withSchema() { + final Map props = new HashMap<>(); + props.put("include", "abc,foo"); + props.put("renames", "abc:xyz,foo:bar"); + + xform.configure(props); + + final Schema schema = SchemaBuilder.struct() + .field("dont", Schema.STRING_SCHEMA) + .field("abc", Schema.INT32_SCHEMA) + .field("foo", Schema.BOOLEAN_SCHEMA) + .field("etc", Schema.STRING_SCHEMA) + .build(); + + final Struct value = new Struct(schema); + value.put("dont", "whatever"); + value.put("abc", 42); + value.put("foo", true); + value.put("etc", "etc"); + + final SinkRecord record = new SinkRecord("test", 0, null, null, schema, value, 0); + final SinkRecord transformedRecord = xform.apply(record); + + final Struct updatedValue = (Struct) transformedRecord.value(); + + assertEquals(2, updatedValue.schema().fields().size()); + assertEquals(Integer.valueOf(42), updatedValue.getInt32("xyz")); + assertEquals(true, updatedValue.getBoolean("bar")); + } + + @Test + public void testIncludeBackwardsCompatibility() { + final Map props = new HashMap<>(); + 
props.put("whitelist", "abc,foo"); + props.put("renames", "abc:xyz,foo:bar"); + + xform.configure(props); + + final SinkRecord record = new SinkRecord("test", 0, null, null, null, null, 0); + final SinkRecord transformedRecord = xform.apply(record); + + assertNull(transformedRecord.value()); + assertNull(transformedRecord.valueSchema()); + } + + @SuppressWarnings("unchecked") + @Test + public void testExcludeBackwardsCompatibility() { + final Map props = new HashMap<>(); + props.put("blacklist", "dont"); + props.put("renames", "abc:xyz,foo:bar"); + + xform.configure(props); + + final Map value = new HashMap<>(); + value.put("dont", "whatever"); + value.put("abc", 42); + value.put("foo", true); + value.put("etc", "etc"); + + final SinkRecord record = new SinkRecord("test", 0, null, null, null, value, 0); + final SinkRecord transformedRecord = xform.apply(record); + + final Map updatedValue = (Map) transformedRecord.value(); + assertEquals(3, updatedValue.size()); + assertEquals(42, updatedValue.get("xyz")); + assertEquals(true, updatedValue.get("bar")); + assertEquals("etc", updatedValue.get("etc")); + } +} diff --git a/connect/transforms/src/test/java/org/apache/kafka/connect/transforms/SetSchemaMetadataTest.java b/connect/transforms/src/test/java/org/apache/kafka/connect/transforms/SetSchemaMetadataTest.java new file mode 100644 index 0000000..04a35ca --- /dev/null +++ b/connect/transforms/src/test/java/org/apache/kafka/connect/transforms/SetSchemaMetadataTest.java @@ -0,0 +1,150 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
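A throwaway sketch, not ReplaceField's actual parser, illustrating the `old:new,old2:new2` syntax these tests pass through the `renames` property.

    import java.util.HashMap;
    import java.util.Map;

    public class RenamesSketch {
        public static void main(String[] args) {
            // Split the comma-separated list, then split each mapping on the colon.
            Map<String, String> renames = new HashMap<>();
            for (String mapping : "abc:xyz,foo:bar".split(",")) {
                String[] parts = mapping.split(":");
                renames.put(parts[0], parts[1]);
            }
            System.out.println(renames); // {abc=xyz, foo=bar} (iteration order may vary)
        }
    }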
+ */ +package org.apache.kafka.connect.transforms; + +import org.apache.kafka.connect.data.Field; +import org.apache.kafka.connect.data.Schema; +import org.apache.kafka.connect.data.SchemaBuilder; +import org.apache.kafka.connect.data.Struct; +import org.apache.kafka.connect.sink.SinkRecord; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Test; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertSame; + +public class SetSchemaMetadataTest { + private final SetSchemaMetadata xform = new SetSchemaMetadata.Value<>(); + + @AfterEach + public void teardown() { + xform.close(); + } + + @Test + public void schemaNameUpdate() { + xform.configure(Collections.singletonMap("schema.name", "foo")); + final SinkRecord record = new SinkRecord("", 0, null, null, SchemaBuilder.struct().build(), null, 0); + final SinkRecord updatedRecord = xform.apply(record); + assertEquals("foo", updatedRecord.valueSchema().name()); + } + + @Test + public void schemaVersionUpdate() { + xform.configure(Collections.singletonMap("schema.version", 42)); + final SinkRecord record = new SinkRecord("", 0, null, null, SchemaBuilder.struct().build(), null, 0); + final SinkRecord updatedRecord = xform.apply(record); + assertEquals(Integer.valueOf(42), updatedRecord.valueSchema().version()); + } + + @Test + public void schemaNameAndVersionUpdate() { + final Map props = new HashMap<>(); + props.put("schema.name", "foo"); + props.put("schema.version", "42"); + + xform.configure(props); + + final SinkRecord record = new SinkRecord("", 0, null, null, SchemaBuilder.struct().build(), null, 0); + + final SinkRecord updatedRecord = xform.apply(record); + + assertEquals("foo", updatedRecord.valueSchema().name()); + assertEquals(Integer.valueOf(42), updatedRecord.valueSchema().version()); + } + + @Test + public void schemaNameAndVersionUpdateWithStruct() { + final String fieldName1 = "f1"; + final String fieldName2 = "f2"; + final String fieldValue1 = "value1"; + final int fieldValue2 = 1; + final Schema schema = SchemaBuilder.struct() + .name("my.orig.SchemaDefn") + .field(fieldName1, Schema.STRING_SCHEMA) + .field(fieldName2, Schema.INT32_SCHEMA) + .build(); + final Struct value = new Struct(schema).put(fieldName1, fieldValue1).put(fieldName2, fieldValue2); + + final Map props = new HashMap<>(); + props.put("schema.name", "foo"); + props.put("schema.version", "42"); + xform.configure(props); + + final SinkRecord record = new SinkRecord("", 0, null, null, schema, value, 0); + + final SinkRecord updatedRecord = xform.apply(record); + + assertEquals("foo", updatedRecord.valueSchema().name()); + assertEquals(Integer.valueOf(42), updatedRecord.valueSchema().version()); + + // Make sure the struct's schema and fields all point to the new schema + assertMatchingSchema((Struct) updatedRecord.value(), updatedRecord.valueSchema()); + } + + @Test + public void updateSchemaOfStruct() { + final String fieldName1 = "f1"; + final String fieldName2 = "f2"; + final String fieldValue1 = "value1"; + final int fieldValue2 = 1; + final Schema schema = SchemaBuilder.struct() + .name("my.orig.SchemaDefn") + .field(fieldName1, Schema.STRING_SCHEMA) + .field(fieldName2, Schema.INT32_SCHEMA) + .build(); + final Struct value = new Struct(schema).put(fieldName1, fieldValue1).put(fieldName2, fieldValue2); + + final Schema newSchema = 
SchemaBuilder.struct() + .name("my.updated.SchemaDefn") + .field(fieldName1, Schema.STRING_SCHEMA) + .field(fieldName2, Schema.INT32_SCHEMA) + .build(); + + Struct newValue = (Struct) SetSchemaMetadata.updateSchemaIn(value, newSchema); + assertMatchingSchema(newValue, newSchema); + } + + @Test + public void updateSchemaOfNonStruct() { + Object value = 1; + Object updatedValue = SetSchemaMetadata.updateSchemaIn(value, Schema.INT32_SCHEMA); + assertSame(value, updatedValue); + } + + @Test + public void updateSchemaOfNull() { + Object updatedValue = SetSchemaMetadata.updateSchemaIn(null, Schema.INT32_SCHEMA); + assertNull(updatedValue); + } + + protected void assertMatchingSchema(Struct value, Schema schema) { + assertSame(schema, value.schema()); + assertEquals(schema.name(), value.schema().name()); + for (Field field : schema.fields()) { + String fieldName = field.name(); + assertEquals(schema.field(fieldName).name(), value.schema().field(fieldName).name()); + assertEquals(schema.field(fieldName).index(), value.schema().field(fieldName).index()); + assertSame(schema.field(fieldName).schema(), value.schema().field(fieldName).schema()); + } + } +} diff --git a/connect/transforms/src/test/java/org/apache/kafka/connect/transforms/TimestampConverterTest.java b/connect/transforms/src/test/java/org/apache/kafka/connect/transforms/TimestampConverterTest.java new file mode 100644 index 0000000..212b9ee --- /dev/null +++ b/connect/transforms/src/test/java/org/apache/kafka/connect/transforms/TimestampConverterTest.java @@ -0,0 +1,548 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
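As a rough illustration of what "updating the schema in" a Struct involves, here is a naive re-wrap that copies every field value into a Struct built against the new schema. It is a sketch for clarity, not necessarily how SetSchemaMetadata.updateSchemaIn is implemented; the helper and class names are made up.

    import org.apache.kafka.connect.data.Field;
    import org.apache.kafka.connect.data.Schema;
    import org.apache.kafka.connect.data.SchemaBuilder;
    import org.apache.kafka.connect.data.Struct;

    public class UpdateSchemaSketch {
        // Copy each field's value into a new Struct so struct.schema() reports the new metadata.
        static Struct copyWithSchema(Struct original, Schema newSchema) {
            Struct updated = new Struct(newSchema);
            for (Field field : newSchema.fields()) {
                updated.put(field.name(), original.get(field.name()));
            }
            return updated;
        }

        public static void main(String[] args) {
            Schema oldSchema = SchemaBuilder.struct().name("my.orig.SchemaDefn")
                    .field("f1", Schema.STRING_SCHEMA).build();
            Schema newSchema = SchemaBuilder.struct().name("my.updated.SchemaDefn")
                    .field("f1", Schema.STRING_SCHEMA).build();
            Struct updated = copyWithSchema(new Struct(oldSchema).put("f1", "value1"), newSchema);
            System.out.println(updated.schema().name()); // my.updated.SchemaDefn
        }
    }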
+ */ + +package org.apache.kafka.connect.transforms; + +import org.apache.kafka.common.config.ConfigException; +import org.apache.kafka.connect.data.Date; +import org.apache.kafka.connect.data.Schema; +import org.apache.kafka.connect.data.SchemaBuilder; +import org.apache.kafka.connect.data.Struct; +import org.apache.kafka.connect.data.Time; +import org.apache.kafka.connect.data.Timestamp; +import org.apache.kafka.connect.source.SourceRecord; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Test; + +import java.util.Calendar; +import java.util.Collections; +import java.util.GregorianCalendar; +import java.util.HashMap; +import java.util.Map; +import java.util.TimeZone; + +import static org.apache.kafka.connect.transforms.util.Requirements.requireStruct; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; + +public class TimestampConverterTest { + private static final TimeZone UTC = TimeZone.getTimeZone("UTC"); + private static final Calendar EPOCH; + private static final Calendar TIME; + private static final Calendar DATE; + private static final Calendar DATE_PLUS_TIME; + private static final long DATE_PLUS_TIME_UNIX; + private static final String STRING_DATE_FMT = "yyyy MM dd HH mm ss SSS z"; + private static final String DATE_PLUS_TIME_STRING; + + private final TimestampConverter xformKey = new TimestampConverter.Key<>(); + private final TimestampConverter xformValue = new TimestampConverter.Value<>(); + + static { + EPOCH = GregorianCalendar.getInstance(UTC); + EPOCH.setTimeInMillis(0L); + + TIME = GregorianCalendar.getInstance(UTC); + TIME.setTimeInMillis(0L); + TIME.add(Calendar.MILLISECOND, 1234); + + DATE = GregorianCalendar.getInstance(UTC); + DATE.setTimeInMillis(0L); + DATE.set(1970, Calendar.JANUARY, 1, 0, 0, 0); + DATE.add(Calendar.DATE, 1); + + DATE_PLUS_TIME = GregorianCalendar.getInstance(UTC); + DATE_PLUS_TIME.setTimeInMillis(0L); + DATE_PLUS_TIME.add(Calendar.DATE, 1); + DATE_PLUS_TIME.add(Calendar.MILLISECOND, 1234); + + DATE_PLUS_TIME_UNIX = DATE_PLUS_TIME.getTime().getTime(); + DATE_PLUS_TIME_STRING = "1970 01 02 00 00 01 234 UTC"; + } + + + // Configuration + + @AfterEach + public void teardown() { + xformKey.close(); + xformValue.close(); + } + + @Test + public void testConfigNoTargetType() { + assertThrows(ConfigException.class, () -> xformValue.configure(Collections.emptyMap())); + } + + @Test + public void testConfigInvalidTargetType() { + assertThrows(ConfigException.class, + () -> xformValue.configure(Collections.singletonMap(TimestampConverter.TARGET_TYPE_CONFIG, "invalid"))); + } + + @Test + public void testConfigMissingFormat() { + assertThrows(ConfigException.class, + () -> xformValue.configure(Collections.singletonMap(TimestampConverter.TARGET_TYPE_CONFIG, "string"))); + } + + @Test + public void testConfigInvalidFormat() { + Map config = new HashMap<>(); + config.put(TimestampConverter.TARGET_TYPE_CONFIG, "string"); + config.put(TimestampConverter.FORMAT_CONFIG, "bad-format"); + assertThrows(ConfigException.class, () -> xformValue.configure(config)); + } + + // Conversions without schemas (most flexible Timestamp -> other types) + + @Test + public void testSchemalessIdentity() { + xformValue.configure(Collections.singletonMap(TimestampConverter.TARGET_TYPE_CONFIG, "Timestamp")); + SourceRecord transformed = xformValue.apply(createRecordSchemaless(DATE_PLUS_TIME.getTime())); + + 
assertNull(transformed.valueSchema()); + assertEquals(DATE_PLUS_TIME.getTime(), transformed.value()); + } + + @Test + public void testSchemalessTimestampToDate() { + xformValue.configure(Collections.singletonMap(TimestampConverter.TARGET_TYPE_CONFIG, "Date")); + SourceRecord transformed = xformValue.apply(createRecordSchemaless(DATE_PLUS_TIME.getTime())); + + assertNull(transformed.valueSchema()); + assertEquals(DATE.getTime(), transformed.value()); + } + + @Test + public void testSchemalessTimestampToTime() { + xformValue.configure(Collections.singletonMap(TimestampConverter.TARGET_TYPE_CONFIG, "Time")); + SourceRecord transformed = xformValue.apply(createRecordSchemaless(DATE_PLUS_TIME.getTime())); + + assertNull(transformed.valueSchema()); + assertEquals(TIME.getTime(), transformed.value()); + } + + @Test + public void testSchemalessTimestampToUnix() { + xformValue.configure(Collections.singletonMap(TimestampConverter.TARGET_TYPE_CONFIG, "unix")); + SourceRecord transformed = xformValue.apply(createRecordSchemaless(DATE_PLUS_TIME.getTime())); + + assertNull(transformed.valueSchema()); + assertEquals(DATE_PLUS_TIME_UNIX, transformed.value()); + } + + @Test + public void testSchemalessTimestampToString() { + Map config = new HashMap<>(); + config.put(TimestampConverter.TARGET_TYPE_CONFIG, "string"); + config.put(TimestampConverter.FORMAT_CONFIG, STRING_DATE_FMT); + xformValue.configure(config); + SourceRecord transformed = xformValue.apply(createRecordSchemaless(DATE_PLUS_TIME.getTime())); + + assertNull(transformed.valueSchema()); + assertEquals(DATE_PLUS_TIME_STRING, transformed.value()); + } + + + // Conversions without schemas (core types -> most flexible Timestamp format) + + @Test + public void testSchemalessDateToTimestamp() { + xformValue.configure(Collections.singletonMap(TimestampConverter.TARGET_TYPE_CONFIG, "Timestamp")); + SourceRecord transformed = xformValue.apply(createRecordSchemaless(DATE.getTime())); + + assertNull(transformed.valueSchema()); + // No change expected since the source type is coarser-grained + assertEquals(DATE.getTime(), transformed.value()); + } + + @Test + public void testSchemalessTimeToTimestamp() { + xformValue.configure(Collections.singletonMap(TimestampConverter.TARGET_TYPE_CONFIG, "Timestamp")); + SourceRecord transformed = xformValue.apply(createRecordSchemaless(TIME.getTime())); + + assertNull(transformed.valueSchema()); + // No change expected since the source type is coarser-grained + assertEquals(TIME.getTime(), transformed.value()); + } + + @Test + public void testSchemalessUnixToTimestamp() { + xformValue.configure(Collections.singletonMap(TimestampConverter.TARGET_TYPE_CONFIG, "Timestamp")); + SourceRecord transformed = xformValue.apply(createRecordSchemaless(DATE_PLUS_TIME_UNIX)); + + assertNull(transformed.valueSchema()); + assertEquals(DATE_PLUS_TIME.getTime(), transformed.value()); + } + + @Test + public void testSchemalessStringToTimestamp() { + Map config = new HashMap<>(); + config.put(TimestampConverter.TARGET_TYPE_CONFIG, "Timestamp"); + config.put(TimestampConverter.FORMAT_CONFIG, STRING_DATE_FMT); + xformValue.configure(config); + SourceRecord transformed = xformValue.apply(createRecordSchemaless(DATE_PLUS_TIME_STRING)); + + assertNull(transformed.valueSchema()); + assertEquals(DATE_PLUS_TIME.getTime(), transformed.value()); + } + + + // Conversions with schemas (most flexible Timestamp -> other types) + + @Test + public void testWithSchemaIdentity() { + 
xformValue.configure(Collections.singletonMap(TimestampConverter.TARGET_TYPE_CONFIG, "Timestamp")); + SourceRecord transformed = xformValue.apply(createRecordWithSchema(Timestamp.SCHEMA, DATE_PLUS_TIME.getTime())); + + assertEquals(Timestamp.SCHEMA, transformed.valueSchema()); + assertEquals(DATE_PLUS_TIME.getTime(), transformed.value()); + } + + @Test + public void testWithSchemaTimestampToDate() { + xformValue.configure(Collections.singletonMap(TimestampConverter.TARGET_TYPE_CONFIG, "Date")); + SourceRecord transformed = xformValue.apply(createRecordWithSchema(Timestamp.SCHEMA, DATE_PLUS_TIME.getTime())); + + assertEquals(Date.SCHEMA, transformed.valueSchema()); + assertEquals(DATE.getTime(), transformed.value()); + } + + @Test + public void testWithSchemaTimestampToTime() { + xformValue.configure(Collections.singletonMap(TimestampConverter.TARGET_TYPE_CONFIG, "Time")); + SourceRecord transformed = xformValue.apply(createRecordWithSchema(Timestamp.SCHEMA, DATE_PLUS_TIME.getTime())); + + assertEquals(Time.SCHEMA, transformed.valueSchema()); + assertEquals(TIME.getTime(), transformed.value()); + } + + @Test + public void testWithSchemaTimestampToUnix() { + xformValue.configure(Collections.singletonMap(TimestampConverter.TARGET_TYPE_CONFIG, "unix")); + SourceRecord transformed = xformValue.apply(createRecordWithSchema(Timestamp.SCHEMA, DATE_PLUS_TIME.getTime())); + + assertEquals(Schema.INT64_SCHEMA, transformed.valueSchema()); + assertEquals(DATE_PLUS_TIME_UNIX, transformed.value()); + } + + @Test + public void testWithSchemaTimestampToString() { + Map config = new HashMap<>(); + config.put(TimestampConverter.TARGET_TYPE_CONFIG, "string"); + config.put(TimestampConverter.FORMAT_CONFIG, STRING_DATE_FMT); + xformValue.configure(config); + SourceRecord transformed = xformValue.apply(createRecordWithSchema(Timestamp.SCHEMA, DATE_PLUS_TIME.getTime())); + + assertEquals(Schema.STRING_SCHEMA, transformed.valueSchema()); + assertEquals(DATE_PLUS_TIME_STRING, transformed.value()); + } + + // Null-value conversions schemaless + + @Test + public void testSchemalessNullValueToString() { + testSchemalessNullValueConversion("string"); + testSchemalessNullFieldConversion("string"); + } + @Test + public void testSchemalessNullValueToDate() { + testSchemalessNullValueConversion("Date"); + testSchemalessNullFieldConversion("Date"); + } + @Test + public void testSchemalessNullValueToTimestamp() { + testSchemalessNullValueConversion("Timestamp"); + testSchemalessNullFieldConversion("Timestamp"); + } + @Test + public void testSchemalessNullValueToUnix() { + testSchemalessNullValueConversion("unix"); + testSchemalessNullFieldConversion("unix"); + } + + @Test + public void testSchemalessNullValueToTime() { + testSchemalessNullValueConversion("Time"); + testSchemalessNullFieldConversion("Time"); + } + + private void testSchemalessNullValueConversion(String targetType) { + Map config = new HashMap<>(); + config.put(TimestampConverter.TARGET_TYPE_CONFIG, targetType); + config.put(TimestampConverter.FORMAT_CONFIG, STRING_DATE_FMT); + xformValue.configure(config); + SourceRecord transformed = xformValue.apply(createRecordSchemaless(null)); + + assertNull(transformed.valueSchema()); + assertNull(transformed.value()); + } + + private void testSchemalessNullFieldConversion(String targetType) { + Map config = new HashMap<>(); + config.put(TimestampConverter.TARGET_TYPE_CONFIG, targetType); + config.put(TimestampConverter.FORMAT_CONFIG, STRING_DATE_FMT); + config.put(TimestampConverter.FIELD_CONFIG, "ts"); + 
xformValue.configure(config); + SourceRecord transformed = xformValue.apply(createRecordSchemaless(null)); + + assertNull(transformed.valueSchema()); + assertNull(transformed.value()); + } + + // Conversions with schemas (core types -> most flexible Timestamp format) + + @Test + public void testWithSchemaDateToTimestamp() { + xformValue.configure(Collections.singletonMap(TimestampConverter.TARGET_TYPE_CONFIG, "Timestamp")); + SourceRecord transformed = xformValue.apply(createRecordWithSchema(Date.SCHEMA, DATE.getTime())); + + assertEquals(Timestamp.SCHEMA, transformed.valueSchema()); + // No change expected since the source type is coarser-grained + assertEquals(DATE.getTime(), transformed.value()); + } + + @Test + public void testWithSchemaTimeToTimestamp() { + xformValue.configure(Collections.singletonMap(TimestampConverter.TARGET_TYPE_CONFIG, "Timestamp")); + SourceRecord transformed = xformValue.apply(createRecordWithSchema(Time.SCHEMA, TIME.getTime())); + + assertEquals(Timestamp.SCHEMA, transformed.valueSchema()); + // No change expected since the source type is coarser-grained + assertEquals(TIME.getTime(), transformed.value()); + } + + @Test + public void testWithSchemaUnixToTimestamp() { + xformValue.configure(Collections.singletonMap(TimestampConverter.TARGET_TYPE_CONFIG, "Timestamp")); + SourceRecord transformed = xformValue.apply(createRecordWithSchema(Schema.INT64_SCHEMA, DATE_PLUS_TIME_UNIX)); + + assertEquals(Timestamp.SCHEMA, transformed.valueSchema()); + assertEquals(DATE_PLUS_TIME.getTime(), transformed.value()); + } + + @Test + public void testWithSchemaStringToTimestamp() { + Map config = new HashMap<>(); + config.put(TimestampConverter.TARGET_TYPE_CONFIG, "Timestamp"); + config.put(TimestampConverter.FORMAT_CONFIG, STRING_DATE_FMT); + xformValue.configure(config); + SourceRecord transformed = xformValue.apply(createRecordWithSchema(Schema.STRING_SCHEMA, DATE_PLUS_TIME_STRING)); + + assertEquals(Timestamp.SCHEMA, transformed.valueSchema()); + assertEquals(DATE_PLUS_TIME.getTime(), transformed.value()); + } + + // Null-value conversions with schema + + @Test + public void testWithSchemaNullValueToTimestamp() { + testWithSchemaNullValueConversion("Timestamp", Schema.OPTIONAL_INT64_SCHEMA, TimestampConverter.OPTIONAL_TIMESTAMP_SCHEMA); + testWithSchemaNullValueConversion("Timestamp", TimestampConverter.OPTIONAL_TIME_SCHEMA, TimestampConverter.OPTIONAL_TIMESTAMP_SCHEMA); + testWithSchemaNullValueConversion("Timestamp", TimestampConverter.OPTIONAL_DATE_SCHEMA, TimestampConverter.OPTIONAL_TIMESTAMP_SCHEMA); + testWithSchemaNullValueConversion("Timestamp", Schema.OPTIONAL_STRING_SCHEMA, TimestampConverter.OPTIONAL_TIMESTAMP_SCHEMA); + testWithSchemaNullValueConversion("Timestamp", TimestampConverter.OPTIONAL_TIMESTAMP_SCHEMA, TimestampConverter.OPTIONAL_TIMESTAMP_SCHEMA); + } + + @Test + public void testWithSchemaNullFieldToTimestamp() { + testWithSchemaNullFieldConversion("Timestamp", Schema.OPTIONAL_INT64_SCHEMA, TimestampConverter.OPTIONAL_TIMESTAMP_SCHEMA); + testWithSchemaNullFieldConversion("Timestamp", TimestampConverter.OPTIONAL_TIME_SCHEMA, TimestampConverter.OPTIONAL_TIMESTAMP_SCHEMA); + testWithSchemaNullFieldConversion("Timestamp", TimestampConverter.OPTIONAL_DATE_SCHEMA, TimestampConverter.OPTIONAL_TIMESTAMP_SCHEMA); + testWithSchemaNullFieldConversion("Timestamp", Schema.OPTIONAL_STRING_SCHEMA, TimestampConverter.OPTIONAL_TIMESTAMP_SCHEMA); + testWithSchemaNullFieldConversion("Timestamp", TimestampConverter.OPTIONAL_TIMESTAMP_SCHEMA, 
TimestampConverter.OPTIONAL_TIMESTAMP_SCHEMA); + } + + @Test + public void testWithSchemaNullValueToUnix() { + testWithSchemaNullValueConversion("unix", Schema.OPTIONAL_INT64_SCHEMA, Schema.OPTIONAL_INT64_SCHEMA); + testWithSchemaNullValueConversion("unix", TimestampConverter.OPTIONAL_TIME_SCHEMA, Schema.OPTIONAL_INT64_SCHEMA); + testWithSchemaNullValueConversion("unix", TimestampConverter.OPTIONAL_DATE_SCHEMA, Schema.OPTIONAL_INT64_SCHEMA); + testWithSchemaNullValueConversion("unix", Schema.OPTIONAL_STRING_SCHEMA, Schema.OPTIONAL_INT64_SCHEMA); + testWithSchemaNullValueConversion("unix", TimestampConverter.OPTIONAL_TIMESTAMP_SCHEMA, Schema.OPTIONAL_INT64_SCHEMA); + } + + @Test + public void testWithSchemaNullFieldToUnix() { + testWithSchemaNullFieldConversion("unix", Schema.OPTIONAL_INT64_SCHEMA, Schema.OPTIONAL_INT64_SCHEMA); + testWithSchemaNullFieldConversion("unix", TimestampConverter.OPTIONAL_TIME_SCHEMA, Schema.OPTIONAL_INT64_SCHEMA); + testWithSchemaNullFieldConversion("unix", TimestampConverter.OPTIONAL_DATE_SCHEMA, Schema.OPTIONAL_INT64_SCHEMA); + testWithSchemaNullFieldConversion("unix", Schema.OPTIONAL_STRING_SCHEMA, Schema.OPTIONAL_INT64_SCHEMA); + testWithSchemaNullFieldConversion("unix", TimestampConverter.OPTIONAL_TIMESTAMP_SCHEMA, Schema.OPTIONAL_INT64_SCHEMA); + } + + @Test + public void testWithSchemaNullValueToTime() { + testWithSchemaNullValueConversion("Time", Schema.OPTIONAL_INT64_SCHEMA, TimestampConverter.OPTIONAL_TIME_SCHEMA); + testWithSchemaNullValueConversion("Time", TimestampConverter.OPTIONAL_TIME_SCHEMA, TimestampConverter.OPTIONAL_TIME_SCHEMA); + testWithSchemaNullValueConversion("Time", TimestampConverter.OPTIONAL_DATE_SCHEMA, TimestampConverter.OPTIONAL_TIME_SCHEMA); + testWithSchemaNullValueConversion("Time", Schema.OPTIONAL_STRING_SCHEMA, TimestampConverter.OPTIONAL_TIME_SCHEMA); + testWithSchemaNullValueConversion("Time", TimestampConverter.OPTIONAL_TIMESTAMP_SCHEMA, TimestampConverter.OPTIONAL_TIME_SCHEMA); + } + + @Test + public void testWithSchemaNullFieldToTime() { + testWithSchemaNullFieldConversion("Time", Schema.OPTIONAL_INT64_SCHEMA, TimestampConverter.OPTIONAL_TIME_SCHEMA); + testWithSchemaNullFieldConversion("Time", TimestampConverter.OPTIONAL_TIME_SCHEMA, TimestampConverter.OPTIONAL_TIME_SCHEMA); + testWithSchemaNullFieldConversion("Time", TimestampConverter.OPTIONAL_DATE_SCHEMA, TimestampConverter.OPTIONAL_TIME_SCHEMA); + testWithSchemaNullFieldConversion("Time", Schema.OPTIONAL_STRING_SCHEMA, TimestampConverter.OPTIONAL_TIME_SCHEMA); + testWithSchemaNullFieldConversion("Time", TimestampConverter.OPTIONAL_TIMESTAMP_SCHEMA, TimestampConverter.OPTIONAL_TIME_SCHEMA); + } + + @Test + public void testWithSchemaNullValueToDate() { + testWithSchemaNullValueConversion("Date", Schema.OPTIONAL_INT64_SCHEMA, TimestampConverter.OPTIONAL_DATE_SCHEMA); + testWithSchemaNullValueConversion("Date", TimestampConverter.OPTIONAL_TIME_SCHEMA, TimestampConverter.OPTIONAL_DATE_SCHEMA); + testWithSchemaNullValueConversion("Date", TimestampConverter.OPTIONAL_DATE_SCHEMA, TimestampConverter.OPTIONAL_DATE_SCHEMA); + testWithSchemaNullValueConversion("Date", Schema.OPTIONAL_STRING_SCHEMA, TimestampConverter.OPTIONAL_DATE_SCHEMA); + testWithSchemaNullValueConversion("Date", TimestampConverter.OPTIONAL_TIMESTAMP_SCHEMA, TimestampConverter.OPTIONAL_DATE_SCHEMA); + } + + @Test + public void testWithSchemaNullFieldToDate() { + testWithSchemaNullFieldConversion("Date", Schema.OPTIONAL_INT64_SCHEMA, TimestampConverter.OPTIONAL_DATE_SCHEMA); + 
testWithSchemaNullFieldConversion("Date", TimestampConverter.OPTIONAL_TIME_SCHEMA, TimestampConverter.OPTIONAL_DATE_SCHEMA); + testWithSchemaNullFieldConversion("Date", TimestampConverter.OPTIONAL_DATE_SCHEMA, TimestampConverter.OPTIONAL_DATE_SCHEMA); + testWithSchemaNullFieldConversion("Date", Schema.OPTIONAL_STRING_SCHEMA, TimestampConverter.OPTIONAL_DATE_SCHEMA); + testWithSchemaNullFieldConversion("Date", TimestampConverter.OPTIONAL_TIMESTAMP_SCHEMA, TimestampConverter.OPTIONAL_DATE_SCHEMA); + } + + @Test + public void testWithSchemaNullValueToString() { + testWithSchemaNullValueConversion("string", Schema.OPTIONAL_INT64_SCHEMA, Schema.OPTIONAL_STRING_SCHEMA); + testWithSchemaNullValueConversion("string", TimestampConverter.OPTIONAL_TIME_SCHEMA, Schema.OPTIONAL_STRING_SCHEMA); + testWithSchemaNullValueConversion("string", TimestampConverter.OPTIONAL_DATE_SCHEMA, Schema.OPTIONAL_STRING_SCHEMA); + testWithSchemaNullValueConversion("string", Schema.OPTIONAL_STRING_SCHEMA, Schema.OPTIONAL_STRING_SCHEMA); + testWithSchemaNullValueConversion("string", TimestampConverter.OPTIONAL_TIMESTAMP_SCHEMA, Schema.OPTIONAL_STRING_SCHEMA); + } + + @Test + public void testWithSchemaNullFieldToString() { + testWithSchemaNullFieldConversion("string", Schema.OPTIONAL_INT64_SCHEMA, Schema.OPTIONAL_STRING_SCHEMA); + testWithSchemaNullFieldConversion("string", TimestampConverter.OPTIONAL_TIME_SCHEMA, Schema.OPTIONAL_STRING_SCHEMA); + testWithSchemaNullFieldConversion("string", TimestampConverter.OPTIONAL_DATE_SCHEMA, Schema.OPTIONAL_STRING_SCHEMA); + testWithSchemaNullFieldConversion("string", Schema.OPTIONAL_STRING_SCHEMA, Schema.OPTIONAL_STRING_SCHEMA); + testWithSchemaNullFieldConversion("string", TimestampConverter.OPTIONAL_TIMESTAMP_SCHEMA, Schema.OPTIONAL_STRING_SCHEMA); + } + + private void testWithSchemaNullValueConversion(String targetType, Schema originalSchema, Schema expectedSchema) { + Map config = new HashMap<>(); + config.put(TimestampConverter.TARGET_TYPE_CONFIG, targetType); + config.put(TimestampConverter.FORMAT_CONFIG, STRING_DATE_FMT); + xformValue.configure(config); + SourceRecord transformed = xformValue.apply(createRecordWithSchema(originalSchema, null)); + + assertEquals(expectedSchema, transformed.valueSchema()); + assertNull(transformed.value()); + } + + private void testWithSchemaNullFieldConversion(String targetType, Schema originalSchema, Schema expectedSchema) { + Map config = new HashMap<>(); + config.put(TimestampConverter.TARGET_TYPE_CONFIG, targetType); + config.put(TimestampConverter.FORMAT_CONFIG, STRING_DATE_FMT); + config.put(TimestampConverter.FIELD_CONFIG, "ts"); + xformValue.configure(config); + SchemaBuilder structSchema = SchemaBuilder.struct() + .field("ts", originalSchema) + .field("other", Schema.STRING_SCHEMA); + + SchemaBuilder expectedStructSchema = SchemaBuilder.struct() + .field("ts", expectedSchema) + .field("other", Schema.STRING_SCHEMA); + + Struct original = new Struct(structSchema); + original.put("ts", null); + original.put("other", "test"); + + // Struct field is null + SourceRecord transformed = xformValue.apply(createRecordWithSchema(structSchema.build(), original)); + + assertEquals(expectedStructSchema.build(), transformed.valueSchema()); + assertNull(requireStruct(transformed.value(), "").get("ts")); + + // entire Struct is null + transformed = xformValue.apply(createRecordWithSchema(structSchema.optional().build(), null)); + + assertEquals(expectedStructSchema.optional().build(), transformed.valueSchema()); + assertNull(transformed.value()); + } + 
+ // Convert field instead of entire key/value + + @Test + public void testSchemalessFieldConversion() { + Map config = new HashMap<>(); + config.put(TimestampConverter.TARGET_TYPE_CONFIG, "Date"); + config.put(TimestampConverter.FIELD_CONFIG, "ts"); + xformValue.configure(config); + + Object value = Collections.singletonMap("ts", DATE_PLUS_TIME.getTime()); + SourceRecord transformed = xformValue.apply(createRecordSchemaless(value)); + + assertNull(transformed.valueSchema()); + assertEquals(Collections.singletonMap("ts", DATE.getTime()), transformed.value()); + } + + @Test + public void testWithSchemaFieldConversion() { + Map config = new HashMap<>(); + config.put(TimestampConverter.TARGET_TYPE_CONFIG, "Timestamp"); + config.put(TimestampConverter.FIELD_CONFIG, "ts"); + xformValue.configure(config); + + // ts field is a unix timestamp + Schema structWithTimestampFieldSchema = SchemaBuilder.struct() + .field("ts", Schema.INT64_SCHEMA) + .field("other", Schema.STRING_SCHEMA) + .build(); + Struct original = new Struct(structWithTimestampFieldSchema); + original.put("ts", DATE_PLUS_TIME_UNIX); + original.put("other", "test"); + + SourceRecord transformed = xformValue.apply(createRecordWithSchema(structWithTimestampFieldSchema, original)); + + Schema expectedSchema = SchemaBuilder.struct() + .field("ts", Timestamp.SCHEMA) + .field("other", Schema.STRING_SCHEMA) + .build(); + assertEquals(expectedSchema, transformed.valueSchema()); + assertEquals(DATE_PLUS_TIME.getTime(), ((Struct) transformed.value()).get("ts")); + assertEquals("test", ((Struct) transformed.value()).get("other")); + } + + + // Validate Key implementation in addition to Value + + @Test + public void testKey() { + xformKey.configure(Collections.singletonMap(TimestampConverter.TARGET_TYPE_CONFIG, "Timestamp")); + SourceRecord transformed = xformKey.apply(new SourceRecord(null, null, "topic", 0, null, DATE_PLUS_TIME.getTime(), null, null)); + + assertNull(transformed.keySchema()); + assertEquals(DATE_PLUS_TIME.getTime(), transformed.key()); + } + + private SourceRecord createRecordWithSchema(Schema schema, Object value) { + return new SourceRecord(null, null, "topic", 0, schema, value); + } + + private SourceRecord createRecordSchemaless(Object value) { + return createRecordWithSchema(null, value); + } +} diff --git a/connect/transforms/src/test/java/org/apache/kafka/connect/transforms/TimestampRouterTest.java b/connect/transforms/src/test/java/org/apache/kafka/connect/transforms/TimestampRouterTest.java new file mode 100644 index 0000000..5fa87ba --- /dev/null +++ b/connect/transforms/src/test/java/org/apache/kafka/connect/transforms/TimestampRouterTest.java @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
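A quick standalone check, not part of the patch, reproducing the DATE_PLUS_TIME_STRING fixture above with java.text.SimpleDateFormat; the output shown is what a typical JDK produces for the UTC time zone.

    import java.text.SimpleDateFormat;
    import java.util.Date;
    import java.util.TimeZone;

    public class TimestampFixtureSketch {
        public static void main(String[] args) {
            SimpleDateFormat fmt = new SimpleDateFormat("yyyy MM dd HH mm ss SSS z");
            fmt.setTimeZone(TimeZone.getTimeZone("UTC"));
            // Epoch + 1 day + 1234 ms, i.e. the DATE_PLUS_TIME calendar in the test.
            long epochPlusOneDayPlus1234Ms = 24L * 60 * 60 * 1000 + 1234L;
            System.out.println(fmt.format(new Date(epochPlusOneDayPlus1234Ms)));
            // 1970 01 02 00 00 01 234 UTC
        }
    }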
+ */ +package org.apache.kafka.connect.transforms; + +import org.apache.kafka.connect.source.SourceRecord; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Test; + +import java.util.Collections; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class TimestampRouterTest { + private final TimestampRouter xform = new TimestampRouter<>(); + + @AfterEach + public void teardown() { + xform.close(); + } + + @Test + public void defaultConfiguration() { + xform.configure(Collections.emptyMap()); // defaults + final SourceRecord record = new SourceRecord( + null, null, + "test", 0, + null, null, + null, null, + 1483425001864L + ); + assertEquals("test-20170103", xform.apply(record).topic()); + } + +} diff --git a/connect/transforms/src/test/java/org/apache/kafka/connect/transforms/ValueToKeyTest.java b/connect/transforms/src/test/java/org/apache/kafka/connect/transforms/ValueToKeyTest.java new file mode 100644 index 0000000..94fa85c --- /dev/null +++ b/connect/transforms/src/test/java/org/apache/kafka/connect/transforms/ValueToKeyTest.java @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
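A small sketch of why the default configuration is expected to map timestamp 1483425001864L on topic "test" to "test-20170103": the record timestamp formatted as `yyyyMMdd` in UTC and appended to the topic name, which appears to be TimestampRouter's default behaviour. Not part of the patch.

    import java.text.SimpleDateFormat;
    import java.util.Date;
    import java.util.TimeZone;

    public class TimestampRouteSketch {
        public static void main(String[] args) {
            SimpleDateFormat fmt = new SimpleDateFormat("yyyyMMdd");
            fmt.setTimeZone(TimeZone.getTimeZone("UTC"));
            String suffix = fmt.format(new Date(1483425001864L)); // 2017-01-03T06:30:01.864Z
            System.out.println("test-" + suffix); // test-20170103
        }
    }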
+ */ +package org.apache.kafka.connect.transforms; + +import org.apache.kafka.connect.data.Schema; +import org.apache.kafka.connect.data.SchemaBuilder; +import org.apache.kafka.connect.data.Struct; +import org.apache.kafka.connect.errors.DataException; +import org.apache.kafka.connect.sink.SinkRecord; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Test; + +import java.util.Collections; +import java.util.HashMap; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; + +public class ValueToKeyTest { + private final ValueToKey xform = new ValueToKey<>(); + + @AfterEach + public void teardown() { + xform.close(); + } + + @Test + public void schemaless() { + xform.configure(Collections.singletonMap("fields", "a,b")); + + final HashMap value = new HashMap<>(); + value.put("a", 1); + value.put("b", 2); + value.put("c", 3); + + final SinkRecord record = new SinkRecord("", 0, null, null, null, value, 0); + final SinkRecord transformedRecord = xform.apply(record); + + final HashMap expectedKey = new HashMap<>(); + expectedKey.put("a", 1); + expectedKey.put("b", 2); + + assertNull(transformedRecord.keySchema()); + assertEquals(expectedKey, transformedRecord.key()); + } + + @Test + public void withSchema() { + xform.configure(Collections.singletonMap("fields", "a,b")); + + final Schema valueSchema = SchemaBuilder.struct() + .field("a", Schema.INT32_SCHEMA) + .field("b", Schema.INT32_SCHEMA) + .field("c", Schema.INT32_SCHEMA) + .build(); + + final Struct value = new Struct(valueSchema); + value.put("a", 1); + value.put("b", 2); + value.put("c", 3); + + final SinkRecord record = new SinkRecord("", 0, null, null, valueSchema, value, 0); + final SinkRecord transformedRecord = xform.apply(record); + + final Schema expectedKeySchema = SchemaBuilder.struct() + .field("a", Schema.INT32_SCHEMA) + .field("b", Schema.INT32_SCHEMA) + .build(); + + final Struct expectedKey = new Struct(expectedKeySchema) + .put("a", 1) + .put("b", 2); + + assertEquals(expectedKeySchema, transformedRecord.keySchema()); + assertEquals(expectedKey, transformedRecord.key()); + } + + @Test + public void nonExistingField() { + xform.configure(Collections.singletonMap("fields", "not_exist")); + + final Schema valueSchema = SchemaBuilder.struct() + .field("a", Schema.INT32_SCHEMA) + .build(); + + final Struct value = new Struct(valueSchema); + value.put("a", 1); + + final SinkRecord record = new SinkRecord("", 0, null, null, valueSchema, value, 0); + + DataException actual = assertThrows(DataException.class, () -> xform.apply(record)); + assertEquals("Field does not exist: not_exist", actual.getMessage()); + } +} diff --git a/connect/transforms/src/test/java/org/apache/kafka/connect/transforms/predicates/HasHeaderKeyTest.java b/connect/transforms/src/test/java/org/apache/kafka/connect/transforms/predicates/HasHeaderKeyTest.java new file mode 100644 index 0000000..d21c98f --- /dev/null +++ b/connect/transforms/src/test/java/org/apache/kafka/connect/transforms/predicates/HasHeaderKeyTest.java @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
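For clarity, a sketch of the key-schema projection the withSchema test above asserts on: the selected value fields are copied into a new struct schema. This mirrors the expected result, not ValueToKey's internal code; names are illustrative.

    import org.apache.kafka.connect.data.Schema;
    import org.apache.kafka.connect.data.SchemaBuilder;

    import java.util.Arrays;

    public class KeyProjectionSketch {
        public static void main(String[] args) {
            Schema valueSchema = SchemaBuilder.struct()
                    .field("a", Schema.INT32_SCHEMA)
                    .field("b", Schema.INT32_SCHEMA)
                    .field("c", Schema.INT32_SCHEMA)
                    .build();
            // Project only the configured fields ("a,b") into the key schema.
            SchemaBuilder keySchemaBuilder = SchemaBuilder.struct();
            for (String fieldName : Arrays.asList("a", "b")) {
                keySchemaBuilder.field(fieldName, valueSchema.field(fieldName).schema());
            }
            Schema keySchema = keySchemaBuilder.build();
            System.out.println(keySchema.fields()); // fields a and b only
        }
    }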
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.transforms.predicates; + +import org.apache.kafka.common.config.ConfigException; +import org.apache.kafka.common.config.ConfigValue; +import org.apache.kafka.connect.data.Schema; +import org.apache.kafka.connect.header.Header; +import org.apache.kafka.connect.source.SourceRecord; +import org.apache.kafka.connect.transforms.util.SimpleConfig; +import org.junit.jupiter.api.Test; + +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +import static java.util.Collections.singletonList; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class HasHeaderKeyTest { + + @Test + public void testNameRequiredInConfig() { + Map props = new HashMap<>(); + ConfigException e = assertThrows(ConfigException.class, () -> config(props)); + assertTrue(e.getMessage().contains("Missing required configuration \"name\"")); + } + + @Test + public void testNameMayNotBeEmptyInConfig() { + Map props = new HashMap<>(); + props.put("name", ""); + ConfigException e = assertThrows(ConfigException.class, () -> config(props)); + assertTrue(e.getMessage().contains("String must be non-empty")); + } + + @Test + public void testConfig() { + HasHeaderKey predicate = new HasHeaderKey<>(); + predicate.config().validate(Collections.singletonMap("name", "foo")); + + List configs = predicate.config().validate(Collections.singletonMap("name", "")); + assertEquals(singletonList("Invalid value for configuration name: String must be non-empty"), configs.get(0).errorMessages()); + } + + @Test + public void testTest() { + HasHeaderKey predicate = new HasHeaderKey<>(); + predicate.configure(Collections.singletonMap("name", "foo")); + + assertTrue(predicate.test(recordWithHeaders("foo"))); + assertTrue(predicate.test(recordWithHeaders("foo", "bar"))); + assertTrue(predicate.test(recordWithHeaders("bar", "foo", "bar", "foo"))); + assertFalse(predicate.test(recordWithHeaders("bar"))); + assertFalse(predicate.test(recordWithHeaders("bar", "bar"))); + assertFalse(predicate.test(recordWithHeaders())); + assertFalse(predicate.test(new SourceRecord(null, null, null, null, null))); + + } + + private SimpleConfig config(Map props) { + return new SimpleConfig(new HasHeaderKey<>().config(), props); + } + + private SourceRecord recordWithHeaders(String... 
headers) { + return new SourceRecord(null, null, null, null, null, null, null, null, null, + Arrays.stream(headers).map(TestHeader::new).collect(Collectors.toList())); + } + + private static class TestHeader implements Header { + + private final String key; + + public TestHeader(String key) { + this.key = key; + } + + @Override + public String key() { + return key; + } + + @Override + public Schema schema() { + return null; + } + + @Override + public Object value() { + return null; + } + + @Override + public Header with(Schema schema, Object value) { + return null; + } + + @Override + public Header rename(String key) { + return null; + } + } +} diff --git a/connect/transforms/src/test/java/org/apache/kafka/connect/transforms/predicates/TopicNameMatchesTest.java b/connect/transforms/src/test/java/org/apache/kafka/connect/transforms/predicates/TopicNameMatchesTest.java new file mode 100644 index 0000000..0640803 --- /dev/null +++ b/connect/transforms/src/test/java/org/apache/kafka/connect/transforms/predicates/TopicNameMatchesTest.java @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
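A hypothetical helper, not taken from the patch, showing the kind of header-key check this predicate performs against the Connect Headers API; the method name is made up.

    import org.apache.kafka.connect.connector.ConnectRecord;
    import org.apache.kafka.connect.header.Header;

    public class HasHeaderSketch {
        // True if any header on the record carries the configured key.
        static boolean hasHeaderKey(ConnectRecord<?> record, String name) {
            for (Header header : record.headers()) {
                if (name.equals(header.key())) {
                    return true;
                }
            }
            return false;
        }
    }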
+ */ +package org.apache.kafka.connect.transforms.predicates; + +import org.apache.kafka.common.config.ConfigException; +import org.apache.kafka.common.config.ConfigValue; +import org.apache.kafka.connect.source.SourceRecord; +import org.apache.kafka.connect.transforms.util.SimpleConfig; +import org.junit.jupiter.api.Test; + +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class TopicNameMatchesTest { + + @Test + public void testPatternRequiredInConfig() { + Map props = new HashMap<>(); + ConfigException e = assertThrows(ConfigException.class, () -> config(props)); + assertTrue(e.getMessage().contains("Missing required configuration \"pattern\"")); + } + + @Test + public void testPatternMayNotBeEmptyInConfig() { + Map props = new HashMap<>(); + props.put("pattern", ""); + ConfigException e = assertThrows(ConfigException.class, () -> config(props)); + assertTrue(e.getMessage().contains("String must be non-empty")); + } + + @Test + public void testPatternIsValidRegexInConfig() { + Map props = new HashMap<>(); + props.put("pattern", "["); + ConfigException e = assertThrows(ConfigException.class, () -> config(props)); + assertTrue(e.getMessage().contains("Invalid regex")); + } + + @Test + public void testConfig() { + TopicNameMatches predicate = new TopicNameMatches<>(); + predicate.config().validate(Collections.singletonMap("pattern", "my-prefix-.*")); + + List configs = predicate.config().validate(Collections.singletonMap("pattern", "*")); + List errorMsgs = configs.get(0).errorMessages(); + assertEquals(1, errorMsgs.size()); + assertTrue(errorMsgs.get(0).contains("Invalid regex")); + } + + @Test + public void testTest() { + TopicNameMatches predicate = new TopicNameMatches<>(); + predicate.configure(Collections.singletonMap("pattern", "my-prefix-.*")); + + assertTrue(predicate.test(recordWithTopicName("my-prefix-"))); + assertTrue(predicate.test(recordWithTopicName("my-prefix-foo"))); + assertFalse(predicate.test(recordWithTopicName("x-my-prefix-"))); + assertFalse(predicate.test(recordWithTopicName("x-my-prefix-foo"))); + assertFalse(predicate.test(recordWithTopicName("your-prefix-"))); + assertFalse(predicate.test(recordWithTopicName("your-prefix-foo"))); + assertFalse(predicate.test(new SourceRecord(null, null, null, null, null))); + + } + + private SimpleConfig config(Map props) { + return new SimpleConfig(TopicNameMatches.CONFIG_DEF, props); + } + + private SourceRecord recordWithTopicName(String topicName) { + return new SourceRecord(null, null, topicName, null, null); + } +} diff --git a/connect/transforms/src/test/java/org/apache/kafka/connect/transforms/util/NonEmptyListValidatorTest.java b/connect/transforms/src/test/java/org/apache/kafka/connect/transforms/util/NonEmptyListValidatorTest.java new file mode 100644 index 0000000..ff3c4f7 --- /dev/null +++ b/connect/transforms/src/test/java/org/apache/kafka/connect/transforms/util/NonEmptyListValidatorTest.java @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
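A standalone sketch of the regex validation failure these config tests expect: an invalid pattern such as `*` is rejected by Pattern.compile itself, which is the kind of error the "Invalid regex" message surfaces.

    import java.util.regex.Pattern;
    import java.util.regex.PatternSyntaxException;

    public class RegexValidationSketch {
        public static void main(String[] args) {
            try {
                Pattern.compile("*"); // dangling quantifier, not a valid regex
            } catch (PatternSyntaxException e) {
                System.out.println("Invalid regex: " + e.getDescription());
            }
        }
    }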
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.connect.transforms.util; + +import org.apache.kafka.common.config.ConfigException; +import org.junit.jupiter.api.Test; + +import java.util.Collections; + +import static org.junit.jupiter.api.Assertions.assertThrows; + +public class NonEmptyListValidatorTest { + + @Test + public void testNullList() { + assertThrows(ConfigException.class, () -> new NonEmptyListValidator().ensureValid("foo", null)); + } + + @Test + public void testEmptyList() { + assertThrows(ConfigException.class, + () -> new NonEmptyListValidator().ensureValid("foo", Collections.emptyList())); + } + + @Test + public void testValidList() { + new NonEmptyListValidator().ensureValid("foo", Collections.singletonList("foo")); + } +} diff --git a/core/.gitignore b/core/.gitignore new file mode 100644 index 0000000..0d7e8b0 --- /dev/null +++ b/core/.gitignore @@ -0,0 +1,3 @@ +.cache-main +.cache-tests +/bin/ diff --git a/core/src/main/java/kafka/metrics/FilteringJmxReporter.java b/core/src/main/java/kafka/metrics/FilteringJmxReporter.java new file mode 100644 index 0000000..3794448 --- /dev/null +++ b/core/src/main/java/kafka/metrics/FilteringJmxReporter.java @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package kafka.metrics;
+
+import com.yammer.metrics.core.Metric;
+import com.yammer.metrics.core.MetricName;
+import com.yammer.metrics.core.MetricsRegistry;
+import com.yammer.metrics.reporting.JmxReporter;
+
+import java.util.function.Predicate;
+
+public class FilteringJmxReporter extends JmxReporter {
+
+    private volatile Predicate<MetricName> metricPredicate;
+
+    public FilteringJmxReporter(MetricsRegistry registry, Predicate<MetricName> metricPredicate) {
+        super(registry);
+        this.metricPredicate = metricPredicate;
+    }
+
+    @Override
+    public void onMetricAdded(MetricName name, Metric metric) {
+        if (metricPredicate.test(name)) {
+            super.onMetricAdded(name, metric);
+        }
+    }
+
+    public void updatePredicate(Predicate<MetricName> predicate) {
+        this.metricPredicate = predicate;
+        // re-register metrics on update
+        getMetricsRegistry()
+            .allMetrics()
+            .forEach((name, metric) -> {
+                    if (metricPredicate.test(name)) {
+                        super.onMetricAdded(name, metric);
+                    } else {
+                        super.onMetricRemoved(name);
+                    }
+                }
+            );
+    }
+}
diff --git a/core/src/main/java/kafka/metrics/KafkaYammerMetrics.java b/core/src/main/java/kafka/metrics/KafkaYammerMetrics.java
new file mode 100644
index 0000000..dd650fd
--- /dev/null
+++ b/core/src/main/java/kafka/metrics/KafkaYammerMetrics.java
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package kafka.metrics;
+
+import com.yammer.metrics.core.MetricsRegistry;
+
+import org.apache.kafka.common.Reconfigurable;
+import org.apache.kafka.common.config.ConfigException;
+import org.apache.kafka.common.metrics.JmxReporter;
+
+import java.util.Map;
+import java.util.Set;
+import java.util.function.Predicate;
+
+/**
+ * This class encapsulates the default yammer metrics registry for Kafka server,
+ * and configures the set of exported JMX metrics for Yammer metrics.
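+ *
+ * <p>A minimal, illustrative sketch of registering a gauge against this registry (the
+ * group/type/name used here are hypothetical examples, not real Kafka metrics):
+ * <pre>{@code
+ * MetricsRegistry registry = KafkaYammerMetrics.defaultRegistry();
+ * registry.newGauge(new MetricName("kafka.example", "ExampleType", "example-value"),
+ *     new Gauge<Integer>() {
+ *         public Integer value() {
+ *             return 42;
+ *         }
+ *     });
+ * }</pre>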
+ *
+ * KafkaYammerMetrics.defaultRegistry() should always be used instead of Metrics.defaultRegistry()
+ */
+public class KafkaYammerMetrics implements Reconfigurable {
+
+    public static final KafkaYammerMetrics INSTANCE = new KafkaYammerMetrics();
+
+    /**
+     * convenience method to replace {@link com.yammer.metrics.Metrics#defaultRegistry()}
+     */
+    public static MetricsRegistry defaultRegistry() {
+        return INSTANCE.metricsRegistry;
+    }
+
+    private final MetricsRegistry metricsRegistry = new MetricsRegistry();
+    private final FilteringJmxReporter jmxReporter = new FilteringJmxReporter(metricsRegistry,
+        metricName -> true);
+
+    private KafkaYammerMetrics() {
+        jmxReporter.start();
+        Runtime.getRuntime().addShutdownHook(new Thread(jmxReporter::shutdown));
+    }
+
+    @Override
+    public void configure(Map<String, ?> configs) {
+        reconfigure(configs);
+    }
+
+    @Override
+    public Set<String> reconfigurableConfigs() {
+        return JmxReporter.RECONFIGURABLE_CONFIGS;
+    }
+
+    @Override
+    public void validateReconfiguration(Map<String, ?> configs) throws ConfigException {
+        JmxReporter.compilePredicate(configs);
+    }
+
+    @Override
+    public void reconfigure(Map<String, ?> configs) {
+        Predicate<String> mBeanPredicate = JmxReporter.compilePredicate(configs);
+        jmxReporter.updatePredicate(metricName -> mBeanPredicate.test(metricName.getMBeanName()));
+    }
+}
diff --git a/core/src/main/java/kafka/server/builders/KafkaApisBuilder.java b/core/src/main/java/kafka/server/builders/KafkaApisBuilder.java
new file mode 100644
index 0000000..971d733
--- /dev/null
+++ b/core/src/main/java/kafka/server/builders/KafkaApisBuilder.java
@@ -0,0 +1,197 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +package kafka.server.builders; + +import kafka.coordinator.group.GroupCoordinator; +import kafka.coordinator.transaction.TransactionCoordinator; +import kafka.network.RequestChannel; +import kafka.server.ApiVersionManager; +import kafka.server.AutoTopicCreationManager; +import kafka.server.BrokerTopicStats; +import kafka.server.DelegationTokenManager; +import kafka.server.FetchManager; +import kafka.server.KafkaApis; +import kafka.server.KafkaConfig; +import kafka.server.MetadataCache; +import kafka.server.MetadataSupport; +import kafka.server.QuotaFactory.QuotaManagers; +import kafka.server.ReplicaManager; +import kafka.server.metadata.ConfigRepository; +import org.apache.kafka.common.metrics.Metrics; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.server.authorizer.Authorizer; + +import java.util.Collections; +import java.util.Optional; +import scala.compat.java8.OptionConverters; + + +public class KafkaApisBuilder { + private RequestChannel requestChannel = null; + private MetadataSupport metadataSupport = null; + private ReplicaManager replicaManager = null; + private GroupCoordinator groupCoordinator = null; + private TransactionCoordinator txnCoordinator = null; + private AutoTopicCreationManager autoTopicCreationManager = null; + private int brokerId = 0; + private KafkaConfig config = null; + private ConfigRepository configRepository = null; + private MetadataCache metadataCache = null; + private Metrics metrics = null; + private Optional authorizer = Optional.empty(); + private QuotaManagers quotas = null; + private FetchManager fetchManager = null; + private BrokerTopicStats brokerTopicStats = null; + private String clusterId = "clusterId"; + private Time time = Time.SYSTEM; + private DelegationTokenManager tokenManager = null; + private ApiVersionManager apiVersionManager = null; + + public KafkaApisBuilder setRequestChannel(RequestChannel requestChannel) { + this.requestChannel = requestChannel; + return this; + } + + public KafkaApisBuilder setMetadataSupport(MetadataSupport metadataSupport) { + this.metadataSupport = metadataSupport; + return this; + } + + public KafkaApisBuilder setReplicaManager(ReplicaManager replicaManager) { + this.replicaManager = replicaManager; + return this; + } + + public KafkaApisBuilder setGroupCoordinator(GroupCoordinator groupCoordinator) { + this.groupCoordinator = groupCoordinator; + return this; + } + + public KafkaApisBuilder setTxnCoordinator(TransactionCoordinator txnCoordinator) { + this.txnCoordinator = txnCoordinator; + return this; + } + + public KafkaApisBuilder setAutoTopicCreationManager(AutoTopicCreationManager autoTopicCreationManager) { + this.autoTopicCreationManager = autoTopicCreationManager; + return this; + } + + public KafkaApisBuilder setBrokerId(int brokerId) { + this.brokerId = brokerId; + return this; + } + + public KafkaApisBuilder setConfig(KafkaConfig config) { + this.config = config; + return this; + } + + public KafkaApisBuilder setConfigRepository(ConfigRepository configRepository) { + this.configRepository = configRepository; + return this; + } + + public KafkaApisBuilder setMetadataCache(MetadataCache metadataCache) { + this.metadataCache = metadataCache; + return this; + } + + public KafkaApisBuilder setMetrics(Metrics metrics) { + this.metrics = metrics; + return this; + } + + public KafkaApisBuilder setAuthorizer(Optional authorizer) { + this.authorizer = authorizer; + return this; + } + + public KafkaApisBuilder setQuotas(QuotaManagers quotas) { + this.quotas = quotas; + return 
this; + } + + public KafkaApisBuilder setFetchManager(FetchManager fetchManager) { + this.fetchManager = fetchManager; + return this; + } + + public KafkaApisBuilder setBrokerTopicStats(BrokerTopicStats brokerTopicStats) { + this.brokerTopicStats = brokerTopicStats; + return this; + } + + public KafkaApisBuilder setClusterId(String clusterId) { + this.clusterId = clusterId; + return this; + } + + public KafkaApisBuilder setTime(Time time) { + this.time = time; + return this; + } + + public KafkaApisBuilder setTokenManager(DelegationTokenManager tokenManager) { + this.tokenManager = tokenManager; + return this; + } + + public KafkaApisBuilder setApiVersionManager(ApiVersionManager apiVersionManager) { + this.apiVersionManager = apiVersionManager; + return this; + } + + public KafkaApis build() { + if (requestChannel == null) throw new RuntimeException("you must set requestChannel"); + if (metadataSupport == null) throw new RuntimeException("you must set metadataSupport"); + if (replicaManager == null) throw new RuntimeException("You must set replicaManager"); + if (groupCoordinator == null) throw new RuntimeException("You must set groupCoordinator"); + if (txnCoordinator == null) throw new RuntimeException("You must set txnCoordinator"); + if (autoTopicCreationManager == null) + throw new RuntimeException("You must set autoTopicCreationManager"); + if (config == null) config = new KafkaConfig(Collections.emptyMap()); + if (configRepository == null) throw new RuntimeException("You must set configRepository"); + if (metadataCache == null) throw new RuntimeException("You must set metadataCache"); + if (metrics == null) throw new RuntimeException("You must set metrics"); + if (quotas == null) throw new RuntimeException("You must set quotas"); + if (fetchManager == null) throw new RuntimeException("You must set fetchManager"); + if (brokerTopicStats == null) brokerTopicStats = new BrokerTopicStats(); + if (apiVersionManager == null) throw new RuntimeException("You must set apiVersionManager"); + + return new KafkaApis(requestChannel, + metadataSupport, + replicaManager, + groupCoordinator, + txnCoordinator, + autoTopicCreationManager, + brokerId, + config, + configRepository, + metadataCache, + metrics, + OptionConverters.toScala(authorizer), + quotas, + fetchManager, + brokerTopicStats, + clusterId, + time, + tokenManager, + apiVersionManager); + } +} diff --git a/core/src/main/java/kafka/server/builders/LogManagerBuilder.java b/core/src/main/java/kafka/server/builders/LogManagerBuilder.java new file mode 100644 index 0000000..0082040 --- /dev/null +++ b/core/src/main/java/kafka/server/builders/LogManagerBuilder.java @@ -0,0 +1,167 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package kafka.server.builders; + +import kafka.api.ApiVersion; +import kafka.log.CleanerConfig; +import kafka.log.LogConfig; +import kafka.log.LogManager; +import kafka.server.BrokerTopicStats; +import kafka.server.LogDirFailureChannel; +import kafka.server.metadata.ConfigRepository; +import kafka.utils.Scheduler; +import org.apache.kafka.common.utils.Time; + +import java.io.File; +import java.util.Collections; +import java.util.List; +import scala.collection.JavaConverters; + + +public class LogManagerBuilder { + private List logDirs = null; + private List initialOfflineDirs = Collections.emptyList(); + private ConfigRepository configRepository = null; + private LogConfig initialDefaultConfig = null; + private CleanerConfig cleanerConfig = null; + private int recoveryThreadsPerDataDir = 1; + private long flushCheckMs = 1000L; + private long flushRecoveryOffsetCheckpointMs = 10000L; + private long flushStartOffsetCheckpointMs = 10000L; + private long retentionCheckMs = 1000L; + private int maxPidExpirationMs = 60000; + private ApiVersion interBrokerProtocolVersion = ApiVersion.latestVersion(); + private Scheduler scheduler = null; + private BrokerTopicStats brokerTopicStats = null; + private LogDirFailureChannel logDirFailureChannel = null; + private Time time = Time.SYSTEM; + private boolean keepPartitionMetadataFile = true; + + public LogManagerBuilder setLogDirs(List logDirs) { + this.logDirs = logDirs; + return this; + } + + public LogManagerBuilder setInitialOfflineDirs(List initialOfflineDirs) { + this.initialOfflineDirs = initialOfflineDirs; + return this; + } + + public LogManagerBuilder setConfigRepository(ConfigRepository configRepository) { + this.configRepository = configRepository; + return this; + } + + public LogManagerBuilder setInitialDefaultConfig(LogConfig initialDefaultConfig) { + this.initialDefaultConfig = initialDefaultConfig; + return this; + } + + public LogManagerBuilder setCleanerConfig(CleanerConfig cleanerConfig) { + this.cleanerConfig = cleanerConfig; + return this; + } + + public LogManagerBuilder setRecoveryThreadsPerDataDir(int recoveryThreadsPerDataDir) { + this.recoveryThreadsPerDataDir = recoveryThreadsPerDataDir; + return this; + } + + public LogManagerBuilder setFlushCheckMs(long flushCheckMs) { + this.flushCheckMs = flushCheckMs; + return this; + } + + public LogManagerBuilder setFlushRecoveryOffsetCheckpointMs(long flushRecoveryOffsetCheckpointMs) { + this.flushRecoveryOffsetCheckpointMs = flushRecoveryOffsetCheckpointMs; + return this; + } + + public LogManagerBuilder setFlushStartOffsetCheckpointMs(long flushStartOffsetCheckpointMs) { + this.flushStartOffsetCheckpointMs = flushStartOffsetCheckpointMs; + return this; + } + + public LogManagerBuilder setRetentionCheckMs(long retentionCheckMs) { + this.retentionCheckMs = retentionCheckMs; + return this; + } + + public LogManagerBuilder setMaxPidExpirationMs(int maxPidExpirationMs) { + this.maxPidExpirationMs = maxPidExpirationMs; + return this; + } + + public LogManagerBuilder setInterBrokerProtocolVersion(ApiVersion interBrokerProtocolVersion) { + this.interBrokerProtocolVersion = interBrokerProtocolVersion; + return this; + } + + public LogManagerBuilder setScheduler(Scheduler scheduler) { + this.scheduler = scheduler; + return this; + } + + public LogManagerBuilder setBrokerTopicStats(BrokerTopicStats brokerTopicStats) { + this.brokerTopicStats = brokerTopicStats; + return this; + } + + public LogManagerBuilder setLogDirFailureChannel(LogDirFailureChannel logDirFailureChannel) { + 
this.logDirFailureChannel = logDirFailureChannel; + return this; + } + + public LogManagerBuilder setTime(Time time) { + this.time = time; + return this; + } + + public LogManagerBuilder setKeepPartitionMetadataFile(boolean keepPartitionMetadataFile) { + this.keepPartitionMetadataFile = keepPartitionMetadataFile; + return this; + } + + public LogManager build() { + if (logDirs == null) throw new RuntimeException("you must set logDirs"); + if (configRepository == null) throw new RuntimeException("you must set configRepository"); + if (initialDefaultConfig == null) throw new RuntimeException("you must set initialDefaultConfig"); + if (cleanerConfig == null) throw new RuntimeException("you must set cleanerConfig"); + if (scheduler == null) throw new RuntimeException("you must set scheduler"); + if (brokerTopicStats == null) throw new RuntimeException("you must set brokerTopicStats"); + if (logDirFailureChannel == null) throw new RuntimeException("you must set logDirFailureChannel"); + + return new LogManager(JavaConverters.asScalaIteratorConverter(logDirs.iterator()).asScala().toSeq(), + JavaConverters.asScalaIteratorConverter(initialOfflineDirs.iterator()).asScala().toSeq(), + configRepository, + initialDefaultConfig, + cleanerConfig, + recoveryThreadsPerDataDir, + flushCheckMs, + flushRecoveryOffsetCheckpointMs, + flushStartOffsetCheckpointMs, + retentionCheckMs, + maxPidExpirationMs, + interBrokerProtocolVersion, + scheduler, + brokerTopicStats, + logDirFailureChannel, + time, + keepPartitionMetadataFile); + } +} diff --git a/core/src/main/java/kafka/server/builders/ReplicaManagerBuilder.java b/core/src/main/java/kafka/server/builders/ReplicaManagerBuilder.java new file mode 100644 index 0000000..a005178 --- /dev/null +++ b/core/src/main/java/kafka/server/builders/ReplicaManagerBuilder.java @@ -0,0 +1,173 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package kafka.server.builders;
+
+import kafka.log.LogManager;
+import kafka.server.AlterIsrManager;
+import kafka.server.BrokerTopicStats;
+import kafka.server.DelayedDeleteRecords;
+import kafka.server.DelayedElectLeader;
+import kafka.server.DelayedFetch;
+import kafka.server.DelayedOperationPurgatory;
+import kafka.server.DelayedProduce;
+import kafka.server.KafkaConfig;
+import kafka.server.LogDirFailureChannel;
+import kafka.server.MetadataCache;
+import kafka.server.QuotaFactory.QuotaManagers;
+import kafka.server.ReplicaManager;
+import kafka.utils.Scheduler;
+import kafka.zk.KafkaZkClient;
+import org.apache.kafka.common.metrics.Metrics;
+import org.apache.kafka.common.utils.Time;
+import scala.compat.java8.OptionConverters;
+
+import java.util.Collections;
+import java.util.Optional;
+import java.util.concurrent.atomic.AtomicBoolean;
+
+
+public class ReplicaManagerBuilder {
+    private KafkaConfig config = null;
+    private Metrics metrics = null;
+    private Time time = Time.SYSTEM;
+    private Scheduler scheduler = null;
+    private LogManager logManager = null;
+    private QuotaManagers quotaManagers = null;
+    private MetadataCache metadataCache = null;
+    private LogDirFailureChannel logDirFailureChannel = null;
+    private AlterIsrManager alterIsrManager = null;
+    private BrokerTopicStats brokerTopicStats = new BrokerTopicStats();
+    private AtomicBoolean isShuttingDown = new AtomicBoolean(false);
+    private Optional<KafkaZkClient> zkClient = Optional.empty();
+    private Optional<DelayedOperationPurgatory<DelayedProduce>> delayedProducePurgatory = Optional.empty();
+    private Optional<DelayedOperationPurgatory<DelayedFetch>> delayedFetchPurgatory = Optional.empty();
+    private Optional<DelayedOperationPurgatory<DelayedDeleteRecords>> delayedDeleteRecordsPurgatory = Optional.empty();
+    private Optional<DelayedOperationPurgatory<DelayedElectLeader>> delayedElectLeaderPurgatory = Optional.empty();
+    private Optional<String> threadNamePrefix = Optional.empty();
+
+    public ReplicaManagerBuilder setConfig(KafkaConfig config) {
+        this.config = config;
+        return this;
+    }
+
+    public ReplicaManagerBuilder setMetrics(Metrics metrics) {
+        this.metrics = metrics;
+        return this;
+    }
+
+    public ReplicaManagerBuilder setTime(Time time) {
+        this.time = time;
+        return this;
+    }
+
+    public ReplicaManagerBuilder setScheduler(Scheduler scheduler) {
+        this.scheduler = scheduler;
+        return this;
+    }
+
+    public ReplicaManagerBuilder setLogManager(LogManager logManager) {
+        this.logManager = logManager;
+        return this;
+    }
+
+    public ReplicaManagerBuilder setQuotaManagers(QuotaManagers quotaManagers) {
+        this.quotaManagers = quotaManagers;
+        return this;
+    }
+
+    public ReplicaManagerBuilder setMetadataCache(MetadataCache metadataCache) {
+        this.metadataCache = metadataCache;
+        return this;
+    }
+
+    public ReplicaManagerBuilder setLogDirFailureChannel(LogDirFailureChannel logDirFailureChannel) {
+        this.logDirFailureChannel = logDirFailureChannel;
+        return this;
+    }
+
+    public ReplicaManagerBuilder setAlterIsrManager(AlterIsrManager alterIsrManager) {
+        this.alterIsrManager = alterIsrManager;
+        return this;
+    }
+
+    public ReplicaManagerBuilder setBrokerTopicStats(BrokerTopicStats brokerTopicStats) {
+        this.brokerTopicStats = brokerTopicStats;
+        return this;
+    }
+
+    public ReplicaManagerBuilder setIsShuttingDown(AtomicBoolean isShuttingDown) {
+        this.isShuttingDown = isShuttingDown;
+        return this;
+    }
+
+    public ReplicaManagerBuilder setZkClient(KafkaZkClient zkClient) {
+        this.zkClient = Optional.of(zkClient);
+        return this;
+    }
+
+    public ReplicaManagerBuilder setDelayedProducePurgatory(DelayedOperationPurgatory<DelayedProduce> delayedProducePurgatory) {
+        this.delayedProducePurgatory = Optional.of(delayedProducePurgatory);
+ return this; + } + + public ReplicaManagerBuilder setDelayedFetchPurgatory(DelayedOperationPurgatory delayedFetchPurgatory) { + this.delayedFetchPurgatory = Optional.of(delayedFetchPurgatory); + return this; + } + + public ReplicaManagerBuilder setDelayedDeleteRecordsPurgatory(DelayedOperationPurgatory delayedDeleteRecordsPurgatory) { + this.delayedDeleteRecordsPurgatory = Optional.of(delayedDeleteRecordsPurgatory); + return this; + } + + public ReplicaManagerBuilder setDelayedElectLeaderPurgatoryParam(DelayedOperationPurgatory delayedElectLeaderPurgatory) { + this.delayedElectLeaderPurgatory = Optional.of(delayedElectLeaderPurgatory); + return this; + } + + public ReplicaManagerBuilder setThreadNamePrefix(String threadNamePrefix) { + this.threadNamePrefix = Optional.of(threadNamePrefix); + return this; + } + + public ReplicaManager build() { + if (config == null) config = new KafkaConfig(Collections.emptyMap()); + if (metrics == null) metrics = new Metrics(); + if (logManager == null) throw new RuntimeException("You must set logManager"); + if (metadataCache == null) throw new RuntimeException("You must set metadataCache"); + if (logDirFailureChannel == null) throw new RuntimeException("You must set logDirFailureChannel"); + if (alterIsrManager == null) throw new RuntimeException("You must set alterIsrManager"); + return new ReplicaManager(config, + metrics, + time, + scheduler, + logManager, + quotaManagers, + metadataCache, + logDirFailureChannel, + alterIsrManager, + brokerTopicStats, + isShuttingDown, + OptionConverters.toScala(zkClient), + OptionConverters.toScala(delayedProducePurgatory), + OptionConverters.toScala(delayedFetchPurgatory), + OptionConverters.toScala(delayedDeleteRecordsPurgatory), + OptionConverters.toScala(delayedElectLeaderPurgatory), + OptionConverters.toScala(threadNamePrefix)); + } +} diff --git a/core/src/main/resources/common/message/GroupMetadataKey.json b/core/src/main/resources/common/message/GroupMetadataKey.json new file mode 100644 index 0000000..fa0c9ff --- /dev/null +++ b/core/src/main/resources/common/message/GroupMetadataKey.json @@ -0,0 +1,24 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "type": "data", + "name": "GroupMetadataKey", + "validVersions": "2", + "flexibleVersions": "none", + "fields": [ + { "name": "group", "type": "string", "versions": "2" } + ] +} diff --git a/core/src/main/resources/common/message/GroupMetadataValue.json b/core/src/main/resources/common/message/GroupMetadataValue.json new file mode 100644 index 0000000..826a7c8 --- /dev/null +++ b/core/src/main/resources/common/message/GroupMetadataValue.json @@ -0,0 +1,45 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. 
See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "type": "data", + "name": "GroupMetadataValue", + "validVersions": "0-3", + "flexibleVersions": "none", + "fields": [ + { "name": "protocolType", "versions": "0+", "type": "string"}, + { "name": "generation", "versions": "0+", "type": "int32" }, + { "name": "protocol", "versions": "0+", "type": "string", "nullableVersions": "0+" }, + { "name": "leader", "versions": "0+", "type": "string", "nullableVersions": "0+" }, + { "name": "currentStateTimestamp", "versions": "2+", "type": "int64", "default": -1, "ignorable": true}, + { "name": "members", "versions": "0+", "type": "[]MemberMetadata" } + ], + "commonStructs": [ + { + "name": "MemberMetadata", + "versions": "0-3", + "fields": [ + { "name": "memberId", "versions": "0+", "type": "string" }, + { "name": "groupInstanceId", "versions": "3+", "type": "string", "default": "null", "nullableVersions": "3+", "ignorable": true}, + { "name": "clientId", "versions": "0+", "type": "string" }, + { "name": "clientHost", "versions": "0+", "type": "string" }, + { "name": "rebalanceTimeout", "versions": "1+", "type": "int32", "ignorable": true}, + { "name": "sessionTimeout", "versions": "0+", "type": "int32" }, + { "name": "subscription", "versions": "0+", "type": "bytes" }, + { "name": "assignment", "versions": "0+", "type": "bytes" } + ] + } + ] +} diff --git a/core/src/main/resources/common/message/OffsetCommitKey.json b/core/src/main/resources/common/message/OffsetCommitKey.json new file mode 100644 index 0000000..a9d1bc3 --- /dev/null +++ b/core/src/main/resources/common/message/OffsetCommitKey.json @@ -0,0 +1,26 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +{ + "type": "data", + "name": "OffsetCommitKey", + "validVersions": "0-1", + "flexibleVersions": "none", + "fields": [ + { "name": "group", "type": "string", "versions": "0-1" }, + { "name": "topic", "type": "string", "versions": "0-1" }, + { "name": "partition", "type": "int32", "versions": "0-1" } + ] +} diff --git a/core/src/main/resources/common/message/OffsetCommitValue.json b/core/src/main/resources/common/message/OffsetCommitValue.json new file mode 100644 index 0000000..db8a628 --- /dev/null +++ b/core/src/main/resources/common/message/OffsetCommitValue.json @@ -0,0 +1,28 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "type": "data", + "name": "OffsetCommitValue", + "validVersions": "0-3", + "flexibleVersions": "none", + "fields": [ + { "name": "offset", "type": "int64", "versions": "0+" }, + { "name": "leaderEpoch", "type": "int32", "versions": "3+", "default": -1, "ignorable": true}, + { "name": "metadata", "type": "string", "versions": "0+" }, + { "name": "commitTimestamp", "type": "int64", "versions": "0+" }, + { "name": "expireTimestamp", "type": "int64", "versions": "1", "default": -1, "ignorable": true} + ] +} diff --git a/core/src/main/resources/common/message/TransactionLogKey.json b/core/src/main/resources/common/message/TransactionLogKey.json new file mode 100644 index 0000000..7a5d3e5 --- /dev/null +++ b/core/src/main/resources/common/message/TransactionLogKey.json @@ -0,0 +1,24 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +{ + "type": "data", + "name": "TransactionLogKey", + "validVersions": "0", + "flexibleVersions": "none", + "fields": [ + { "name": "TransactionalId", "type": "string", "versions": "0"} + ] +} diff --git a/core/src/main/resources/common/message/TransactionLogValue.json b/core/src/main/resources/common/message/TransactionLogValue.json new file mode 100644 index 0000000..7915c3d --- /dev/null +++ b/core/src/main/resources/common/message/TransactionLogValue.json @@ -0,0 +1,39 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "type": "data", + "name": "TransactionLogValue", + "validVersions": "0", + "flexibleVersions": "none", + "fields": [ + { "name": "ProducerId", "type": "int64", "versions": "0", + "about": "Producer id in use by the transactional id"}, + { "name": "ProducerEpoch", "type": "int16", "versions": "0", + "about": "Epoch associated with the producer id"}, + { "name": "TransactionTimeoutMs", "type": "int32", "versions": "0", + "about": "Transaction timeout in milliseconds"}, + { "name": "TransactionStatus", "type": "int8", "versions": "0", + "about": "TransactionState the transaction is in"}, + { "name": "TransactionPartitions", "type": "[]PartitionsSchema", "versions": "0", "nullableVersions": "0", + "about": "Set of partitions involved in the transaction", "fields": [ + { "name": "Topic", "type": "string", "versions": "0"}, + { "name": "PartitionIds", "type": "[]int32", "versions": "0"}]}, + { "name": "TransactionLastUpdateTimestampMs", "type": "int64", "versions": "0", + "about": "Time the transaction was last updated"}, + { "name": "TransactionStartTimestampMs", "type": "int64", "versions": "0", + "about": "Time the transaction was started"} + ] +} diff --git a/core/src/main/scala/kafka/Kafka.scala b/core/src/main/scala/kafka/Kafka.scala new file mode 100755 index 0000000..4e278c9 --- /dev/null +++ b/core/src/main/scala/kafka/Kafka.scala @@ -0,0 +1,126 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package kafka + +import java.util.Properties + +import joptsimple.OptionParser +import kafka.server.{KafkaConfig, KafkaRaftServer, KafkaServer, Server} +import kafka.utils.Implicits._ +import kafka.utils.{CommandLineUtils, Exit, Logging} +import org.apache.kafka.common.utils.{Java, LoggingSignalHandler, OperatingSystem, Time, Utils} + +import scala.jdk.CollectionConverters._ + +object Kafka extends Logging { + + def getPropsFromArgs(args: Array[String]): Properties = { + val optionParser = new OptionParser(false) + val overrideOpt = optionParser.accepts("override", "Optional property that should override values set in server.properties file") + .withRequiredArg() + .ofType(classOf[String]) + // This is just to make the parameter show up in the help output, we are not actually using this due the + // fact that this class ignores the first parameter which is interpreted as positional and mandatory + // but would not be mandatory if --version is specified + // This is a bit of an ugly crutch till we get a chance to rework the entire command line parsing + optionParser.accepts("version", "Print version information and exit.") + + if (args.length == 0 || args.contains("--help")) { + CommandLineUtils.printUsageAndDie(optionParser, + "USAGE: java [options] %s server.properties [--override property=value]*".format(this.getClass.getCanonicalName.split('$').head)) + } + + if (args.contains("--version")) { + CommandLineUtils.printVersionAndDie() + } + + val props = Utils.loadProps(args(0)) + + if (args.length > 1) { + val options = optionParser.parse(args.slice(1, args.length): _*) + + if (options.nonOptionArguments().size() > 0) { + CommandLineUtils.printUsageAndDie(optionParser, "Found non argument parameters: " + options.nonOptionArguments().toArray.mkString(",")) + } + + props ++= CommandLineUtils.parseKeyValueArgs(options.valuesOf(overrideOpt).asScala) + } + props + } + + private def buildServer(props: Properties): Server = { + val config = KafkaConfig.fromProps(props, false) + if (config.requiresZookeeper) { + new KafkaServer( + config, + Time.SYSTEM, + threadNamePrefix = None, + enableForwarding = false + ) + } else { + new KafkaRaftServer( + config, + Time.SYSTEM, + threadNamePrefix = None + ) + } + } + + def main(args: Array[String]): Unit = { + try { + val serverProps = getPropsFromArgs(args) + val server = buildServer(serverProps) + + try { + if (!OperatingSystem.IS_WINDOWS && !Java.isIbmJdk) + new LoggingSignalHandler().register() + } catch { + case e: ReflectiveOperationException => + warn("Failed to register optional signal handler that logs a message when the process is terminated " + + s"by a signal. Reason for registration failure is: $e", e) + } + + // attach shutdown handler to catch terminating signals as well as normal termination + Exit.addShutdownHook("kafka-shutdown-hook", { + try server.shutdown() + catch { + case _: Throwable => + fatal("Halting Kafka.") + // Calling exit() can lead to deadlock as exit() can be called multiple times. Force exit. 
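+          // halt() is used here because, unlike exit(), it does not run shutdown hooks, so it cannot re-enter this hook and block.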
+ Exit.halt(1) + } + }) + + try server.startup() + catch { + case _: Throwable => + // KafkaServer.startup() calls shutdown() in case of exceptions, so we invoke `exit` to set the status code + fatal("Exiting Kafka.") + Exit.exit(1) + } + + server.awaitShutdown() + } + catch { + case e: Throwable => + fatal("Exiting Kafka due to fatal exception", e) + Exit.exit(1) + } + Exit.exit(0) + } +} diff --git a/core/src/main/scala/kafka/admin/AclCommand.scala b/core/src/main/scala/kafka/admin/AclCommand.scala new file mode 100644 index 0000000..116ca24 --- /dev/null +++ b/core/src/main/scala/kafka/admin/AclCommand.scala @@ -0,0 +1,679 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.admin + +import java.util.Properties + +import joptsimple._ +import joptsimple.util.EnumConverter +import kafka.security.authorizer.{AclAuthorizer, AclEntry, AuthorizerUtils} +import kafka.server.KafkaConfig +import kafka.utils._ +import org.apache.kafka.clients.admin.{Admin, AdminClientConfig} +import org.apache.kafka.common.acl._ +import org.apache.kafka.common.acl.AclOperation._ +import org.apache.kafka.common.acl.AclPermissionType.{ALLOW, DENY} +import org.apache.kafka.common.resource.{PatternType, ResourcePattern, ResourcePatternFilter, Resource => JResource, ResourceType => JResourceType} +import org.apache.kafka.common.security.JaasUtils +import org.apache.kafka.common.security.auth.KafkaPrincipal +import org.apache.kafka.common.utils.{Utils, SecurityUtils => JSecurityUtils} +import org.apache.kafka.server.authorizer.Authorizer + +import scala.jdk.CollectionConverters._ +import scala.collection.mutable +import scala.io.StdIn + +object AclCommand extends Logging { + + val AuthorizerDeprecationMessage: String = "Warning: support for ACL configuration directly " + + "through the authorizer is deprecated and will be removed in a future release. Please use " + + "--bootstrap-server instead to set ACLs through the admin client." 
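+
+  // Illustrative invocation of the recommended admin-client path (broker address, principal and
+  // topic below are placeholders, not defaults):
+  //   bin/kafka-acls.sh --bootstrap-server localhost:9092 --add \
+  //     --allow-principal User:alice --producer --topic my-topic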
+ val ClusterResourceFilter = new ResourcePatternFilter(JResourceType.CLUSTER, JResource.CLUSTER_NAME, PatternType.LITERAL) + + private val Newline = scala.util.Properties.lineSeparator + + def main(args: Array[String]): Unit = { + + val opts = new AclCommandOptions(args) + + CommandLineUtils.printHelpAndExitIfNeeded(opts, "This tool helps to manage acls on kafka.") + + opts.checkArgs() + + val aclCommandService = { + if (opts.options.has(opts.bootstrapServerOpt)) { + new AdminClientService(opts) + } else { + val authorizerClassName = if (opts.options.has(opts.authorizerOpt)) + opts.options.valueOf(opts.authorizerOpt) + else + classOf[AclAuthorizer].getName + + new AuthorizerService(authorizerClassName, opts) + } + } + + try { + if (opts.options.has(opts.addOpt)) + aclCommandService.addAcls() + else if (opts.options.has(opts.removeOpt)) + aclCommandService.removeAcls() + else if (opts.options.has(opts.listOpt)) + aclCommandService.listAcls() + } catch { + case e: Throwable => + println(s"Error while executing ACL command: ${e.getMessage}") + println(Utils.stackTrace(e)) + Exit.exit(1) + } + } + + sealed trait AclCommandService { + def addAcls(): Unit + def removeAcls(): Unit + def listAcls(): Unit + } + + class AdminClientService(val opts: AclCommandOptions) extends AclCommandService with Logging { + + private def withAdminClient(opts: AclCommandOptions)(f: Admin => Unit): Unit = { + val props = if (opts.options.has(opts.commandConfigOpt)) + Utils.loadProps(opts.options.valueOf(opts.commandConfigOpt)) + else + new Properties() + props.put(AdminClientConfig.BOOTSTRAP_SERVERS_CONFIG, opts.options.valueOf(opts.bootstrapServerOpt)) + val adminClient = Admin.create(props) + + try { + f(adminClient) + } finally { + adminClient.close() + } + } + + def addAcls(): Unit = { + val resourceToAcl = getResourceToAcls(opts) + withAdminClient(opts) { adminClient => + for ((resource, acls) <- resourceToAcl) { + println(s"Adding ACLs for resource `$resource`: $Newline ${acls.map("\t" + _).mkString(Newline)} $Newline") + val aclBindings = acls.map(acl => new AclBinding(resource, acl)).asJavaCollection + adminClient.createAcls(aclBindings).all().get() + } + + listAcls(adminClient) + } + } + + def removeAcls(): Unit = { + withAdminClient(opts) { adminClient => + val filterToAcl = getResourceFilterToAcls(opts) + + for ((filter, acls) <- filterToAcl) { + if (acls.isEmpty) { + if (confirmAction(opts, s"Are you sure you want to delete all ACLs for resource filter `$filter`? (y/n)")) + removeAcls(adminClient, acls, filter) + } else { + if (confirmAction(opts, s"Are you sure you want to remove ACLs: $Newline ${acls.map("\t" + _).mkString(Newline)} $Newline from resource filter `$filter`? 
(y/n)")) + removeAcls(adminClient, acls, filter) + } + } + + listAcls(adminClient) + } + } + + def listAcls(): Unit = { + withAdminClient(opts) { adminClient => + listAcls(adminClient) + } + } + + private def listAcls(adminClient: Admin): Unit = { + val filters = getResourceFilter(opts, dieIfNoResourceFound = false) + val listPrincipals = getPrincipals(opts, opts.listPrincipalsOpt) + val resourceToAcls = getAcls(adminClient, filters) + + if (listPrincipals.isEmpty) { + printResourceAcls(resourceToAcls) + } else { + listPrincipals.foreach{principal => + println(s"ACLs for principal `$principal`") + val filteredResourceToAcls = resourceToAcls.map { case (resource, acls) => + resource -> acls.filter(acl => principal.toString.equals(acl.principal)) + }.filter { case (_, acls) => acls.nonEmpty } + printResourceAcls(filteredResourceToAcls) + } + } + } + + private def printResourceAcls(resourceToAcls: Map[ResourcePattern, Set[AccessControlEntry]]): Unit = { + for ((resource, acls) <- resourceToAcls) + println(s"Current ACLs for resource `$resource`: $Newline ${acls.map("\t" + _).mkString(Newline)} $Newline") + } + + private def removeAcls(adminClient: Admin, acls: Set[AccessControlEntry], filter: ResourcePatternFilter): Unit = { + if (acls.isEmpty) + adminClient.deleteAcls(List(new AclBindingFilter(filter, AccessControlEntryFilter.ANY)).asJava).all().get() + else { + val aclBindingFilters = acls.map(acl => new AclBindingFilter(filter, acl.toFilter)).toList.asJava + adminClient.deleteAcls(aclBindingFilters).all().get() + } + } + + private def getAcls(adminClient: Admin, filters: Set[ResourcePatternFilter]): Map[ResourcePattern, Set[AccessControlEntry]] = { + val aclBindings = + if (filters.isEmpty) adminClient.describeAcls(AclBindingFilter.ANY).values().get().asScala.toList + else { + val results = for (filter <- filters) yield { + adminClient.describeAcls(new AclBindingFilter(filter, AccessControlEntryFilter.ANY)).values().get().asScala.toList + } + results.reduceLeft(_ ++ _) + } + + val resourceToAcls = mutable.Map[ResourcePattern, Set[AccessControlEntry]]().withDefaultValue(Set()) + + aclBindings.foreach(aclBinding => resourceToAcls(aclBinding.pattern()) = resourceToAcls(aclBinding.pattern()) + aclBinding.entry()) + resourceToAcls.toMap + } + } + + class AuthorizerService(val authorizerClassName: String, val opts: AclCommandOptions) extends AclCommandService with Logging { + + private def withAuthorizer()(f: Authorizer => Unit): Unit = { + // It is possible that zookeeper.set.acl could be true without SASL if mutual certificate authentication is configured. + // We will default the value of zookeeper.set.acl to true or false based on whether SASL is configured, + // but if SASL is not configured and zookeeper.set.acl is supposed to be true due to mutual certificate authentication + // then it will be up to the user to explicitly specify zookeeper.set.acl=true in the authorizer-properties. + val defaultProps = Map(KafkaConfig.ZkEnableSecureAclsProp -> JaasUtils.isZkSaslEnabled) + val authorizerPropertiesWithoutTls = + if (opts.options.has(opts.authorizerPropertiesOpt)) { + val authorizerProperties = opts.options.valuesOf(opts.authorizerPropertiesOpt).asScala + defaultProps ++ CommandLineUtils.parseKeyValueArgs(authorizerProperties, acceptMissingValue = false).asScala + } else { + defaultProps + } + val authorizerProperties = + if (opts.options.has(opts.zkTlsConfigFile)) { + // load in TLS configs both with and without the "authorizer." 
prefix + val validKeys = (KafkaConfig.ZkSslConfigToSystemPropertyMap.keys.toList ++ KafkaConfig.ZkSslConfigToSystemPropertyMap.keys.map("authorizer." + _).toList).asJava + authorizerPropertiesWithoutTls ++ Utils.loadProps(opts.options.valueOf(opts.zkTlsConfigFile), validKeys).asInstanceOf[java.util.Map[String, Any]].asScala + } + else + authorizerPropertiesWithoutTls + + val authZ = AuthorizerUtils.createAuthorizer(authorizerClassName) + try { + authZ.configure(authorizerProperties.asJava) + f(authZ) + } + finally CoreUtils.swallow(authZ.close(), this) + } + + def addAcls(): Unit = { + val resourceToAcl = getResourceToAcls(opts) + withAuthorizer() { authorizer => + for ((resource, acls) <- resourceToAcl) { + println(s"Adding ACLs for resource `$resource`: $Newline ${acls.map("\t" + _).mkString(Newline)} $Newline") + val aclBindings = acls.map(acl => new AclBinding(resource, acl)) + authorizer.createAcls(null,aclBindings.toList.asJava).asScala.map(_.toCompletableFuture.get).foreach { result => + result.exception.ifPresent { exception => + println(s"Error while adding ACLs: ${exception.getMessage}") + println(Utils.stackTrace(exception)) + } + } + } + + listAcls() + } + } + + def removeAcls(): Unit = { + withAuthorizer() { authorizer => + val filterToAcl = getResourceFilterToAcls(opts) + + for ((filter, acls) <- filterToAcl) { + if (acls.isEmpty) { + if (confirmAction(opts, s"Are you sure you want to delete all ACLs for resource filter `$filter`? (y/n)")) + removeAcls(authorizer, acls, filter) + } else { + if (confirmAction(opts, s"Are you sure you want to remove ACLs: $Newline ${acls.map("\t" + _).mkString(Newline)} $Newline from resource filter `$filter`? (y/n)")) + removeAcls(authorizer, acls, filter) + } + } + + listAcls() + } + } + + def listAcls(): Unit = { + withAuthorizer() { authorizer => + val filters = getResourceFilter(opts, dieIfNoResourceFound = false) + val listPrincipals = getPrincipals(opts, opts.listPrincipalsOpt) + val resourceToAcls = getAcls(authorizer, filters) + + if (listPrincipals.isEmpty) { + for ((resource, acls) <- resourceToAcls) + println(s"Current ACLs for resource `$resource`: $Newline ${acls.map("\t" + _).mkString(Newline)} $Newline") + } else { + listPrincipals.foreach(principal => { + println(s"ACLs for principal `$principal`") + val filteredResourceToAcls = resourceToAcls.map { case (resource, acls) => + resource -> acls.filter(acl => principal.toString.equals(acl.principal)) + }.filter { case (_, acls) => acls.nonEmpty } + + for ((resource, acls) <- filteredResourceToAcls) + println(s"Current ACLs for resource `$resource`: $Newline ${acls.map("\t" + _).mkString(Newline)} $Newline") + }) + } + } + } + + private def removeAcls(authorizer: Authorizer, acls: Set[AccessControlEntry], filter: ResourcePatternFilter): Unit = { + val result = if (acls.isEmpty) + authorizer.deleteAcls(null, List(new AclBindingFilter(filter, AccessControlEntryFilter.ANY)).asJava) + else { + val aclBindingFilters = acls.map(acl => new AclBindingFilter(filter, acl.toFilter)).toList.asJava + authorizer.deleteAcls(null, aclBindingFilters) + } + result.asScala.map(_.toCompletableFuture.get).foreach { result => + result.exception.ifPresent { exception => + println(s"Error while removing ACLs: ${exception.getMessage}") + println(Utils.stackTrace(exception)) + } + result.aclBindingDeleteResults.forEach { deleteResult => + deleteResult.exception.ifPresent { exception => + println(s"Error while removing ACLs: ${exception.getMessage}") + println(Utils.stackTrace(exception)) + } + } + } + } + + 
private def getAcls(authorizer: Authorizer, filters: Set[ResourcePatternFilter]): Map[ResourcePattern, Set[AccessControlEntry]] = { + val aclBindings = + if (filters.isEmpty) authorizer.acls(AclBindingFilter.ANY).asScala + else { + val results = for (filter <- filters) yield { + authorizer.acls(new AclBindingFilter(filter, AccessControlEntryFilter.ANY)).asScala + } + results.reduceLeft(_ ++ _) + } + + val resourceToAcls = mutable.Map[ResourcePattern, Set[AccessControlEntry]]().withDefaultValue(Set()) + + aclBindings.foreach(aclBinding => resourceToAcls(aclBinding.pattern()) = resourceToAcls(aclBinding.pattern()) + aclBinding.entry()) + resourceToAcls.toMap + } + } + + private def getResourceToAcls(opts: AclCommandOptions): Map[ResourcePattern, Set[AccessControlEntry]] = { + val patternType = opts.options.valueOf(opts.resourcePatternType) + if (!patternType.isSpecific) + CommandLineUtils.printUsageAndDie(opts.parser, s"A '--resource-pattern-type' value of '$patternType' is not valid when adding acls.") + + val resourceToAcl = getResourceFilterToAcls(opts).map { + case (filter, acls) => + new ResourcePattern(filter.resourceType(), filter.name(), filter.patternType()) -> acls + } + + if (resourceToAcl.values.exists(_.isEmpty)) + CommandLineUtils.printUsageAndDie(opts.parser, "You must specify one of: --allow-principal, --deny-principal when trying to add ACLs.") + + resourceToAcl + } + + private def getResourceFilterToAcls(opts: AclCommandOptions): Map[ResourcePatternFilter, Set[AccessControlEntry]] = { + var resourceToAcls = Map.empty[ResourcePatternFilter, Set[AccessControlEntry]] + + //if none of the --producer or --consumer options are specified , just construct ACLs from CLI options. + if (!opts.options.has(opts.producerOpt) && !opts.options.has(opts.consumerOpt)) { + resourceToAcls ++= getCliResourceFilterToAcls(opts) + } + + //users are allowed to specify both --producer and --consumer options in a single command. 
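+    //e.g. (illustrative): --add --producer --consumer --allow-principal User:alice --topic my-topic --group my-group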
+ if (opts.options.has(opts.producerOpt)) + resourceToAcls ++= getProducerResourceFilterToAcls(opts) + + if (opts.options.has(opts.consumerOpt)) + resourceToAcls ++= getConsumerResourceFilterToAcls(opts).map { case (k, v) => k -> (v ++ resourceToAcls.getOrElse(k, Set.empty[AccessControlEntry])) } + + validateOperation(opts, resourceToAcls) + + resourceToAcls + } + + private def getProducerResourceFilterToAcls(opts: AclCommandOptions): Map[ResourcePatternFilter, Set[AccessControlEntry]] = { + val filters = getResourceFilter(opts) + + val topics = filters.filter(_.resourceType == JResourceType.TOPIC) + val transactionalIds = filters.filter(_.resourceType == JResourceType.TRANSACTIONAL_ID) + val enableIdempotence = opts.options.has(opts.idempotentOpt) + + val topicAcls = getAcl(opts, Set(WRITE, DESCRIBE, CREATE)) + val transactionalIdAcls = getAcl(opts, Set(WRITE, DESCRIBE)) + + //Write, Describe, Create permission on topics, Write, Describe on transactionalIds + topics.map(_ -> topicAcls).toMap ++ + transactionalIds.map(_ -> transactionalIdAcls).toMap ++ + (if (enableIdempotence) + Map(ClusterResourceFilter -> getAcl(opts, Set(IDEMPOTENT_WRITE))) + else + Map.empty) + } + + private def getConsumerResourceFilterToAcls(opts: AclCommandOptions): Map[ResourcePatternFilter, Set[AccessControlEntry]] = { + val filters = getResourceFilter(opts) + + val topics = filters.filter(_.resourceType == JResourceType.TOPIC) + val groups = filters.filter(_.resourceType == JResourceType.GROUP) + + //Read, Describe on topic, Read on consumerGroup + + val acls = getAcl(opts, Set(READ, DESCRIBE)) + + topics.map(_ -> acls).toMap[ResourcePatternFilter, Set[AccessControlEntry]] ++ + groups.map(_ -> getAcl(opts, Set(READ))).toMap[ResourcePatternFilter, Set[AccessControlEntry]] + } + + private def getCliResourceFilterToAcls(opts: AclCommandOptions): Map[ResourcePatternFilter, Set[AccessControlEntry]] = { + val acls = getAcl(opts) + val filters = getResourceFilter(opts) + filters.map(_ -> acls).toMap + } + + private def getAcl(opts: AclCommandOptions, operations: Set[AclOperation]): Set[AccessControlEntry] = { + val allowedPrincipals = getPrincipals(opts, opts.allowPrincipalsOpt) + + val deniedPrincipals = getPrincipals(opts, opts.denyPrincipalsOpt) + + val allowedHosts = getHosts(opts, opts.allowHostsOpt, opts.allowPrincipalsOpt) + + val deniedHosts = getHosts(opts, opts.denyHostsOpt, opts.denyPrincipalsOpt) + + val acls = new collection.mutable.HashSet[AccessControlEntry] + if (allowedHosts.nonEmpty && allowedPrincipals.nonEmpty) + acls ++= getAcls(allowedPrincipals, ALLOW, operations, allowedHosts) + + if (deniedHosts.nonEmpty && deniedPrincipals.nonEmpty) + acls ++= getAcls(deniedPrincipals, DENY, operations, deniedHosts) + + acls.toSet + } + + private def getAcl(opts: AclCommandOptions): Set[AccessControlEntry] = { + val operations = opts.options.valuesOf(opts.operationsOpt).asScala + .map(operation => JSecurityUtils.operation(operation.trim)).toSet + getAcl(opts, operations) + } + + def getAcls(principals: Set[KafkaPrincipal], permissionType: AclPermissionType, operations: Set[AclOperation], + hosts: Set[String]): Set[AccessControlEntry] = { + for { + principal <- principals + operation <- operations + host <- hosts + } yield new AccessControlEntry(principal.toString, host, operation, permissionType) + } + + private def getHosts(opts: AclCommandOptions, hostOptionSpec: ArgumentAcceptingOptionSpec[String], + principalOptionSpec: ArgumentAcceptingOptionSpec[String]): Set[String] = { + if 
(opts.options.has(hostOptionSpec)) + opts.options.valuesOf(hostOptionSpec).asScala.map(_.trim).toSet + else if (opts.options.has(principalOptionSpec)) + Set[String](AclEntry.WildcardHost) + else + Set.empty[String] + } + + private def getPrincipals(opts: AclCommandOptions, principalOptionSpec: ArgumentAcceptingOptionSpec[String]): Set[KafkaPrincipal] = { + if (opts.options.has(principalOptionSpec)) + opts.options.valuesOf(principalOptionSpec).asScala.map(s => JSecurityUtils.parseKafkaPrincipal(s.trim)).toSet + else + Set.empty[KafkaPrincipal] + } + + private def getResourceFilter(opts: AclCommandOptions, dieIfNoResourceFound: Boolean = true): Set[ResourcePatternFilter] = { + val patternType = opts.options.valueOf(opts.resourcePatternType) + + var resourceFilters = Set.empty[ResourcePatternFilter] + if (opts.options.has(opts.topicOpt)) + opts.options.valuesOf(opts.topicOpt).forEach(topic => resourceFilters += new ResourcePatternFilter(JResourceType.TOPIC, topic.trim, patternType)) + + if (patternType == PatternType.LITERAL && (opts.options.has(opts.clusterOpt) || opts.options.has(opts.idempotentOpt))) + resourceFilters += ClusterResourceFilter + + if (opts.options.has(opts.groupOpt)) + opts.options.valuesOf(opts.groupOpt).forEach(group => resourceFilters += new ResourcePatternFilter(JResourceType.GROUP, group.trim, patternType)) + + if (opts.options.has(opts.transactionalIdOpt)) + opts.options.valuesOf(opts.transactionalIdOpt).forEach(transactionalId => + resourceFilters += new ResourcePatternFilter(JResourceType.TRANSACTIONAL_ID, transactionalId, patternType)) + + if (opts.options.has(opts.delegationTokenOpt)) + opts.options.valuesOf(opts.delegationTokenOpt).forEach(token => resourceFilters += new ResourcePatternFilter(JResourceType.DELEGATION_TOKEN, token.trim, patternType)) + + if (resourceFilters.isEmpty && dieIfNoResourceFound) + CommandLineUtils.printUsageAndDie(opts.parser, "You must provide at least one resource: --topic or --cluster or --group or --delegation-token ") + + resourceFilters + } + + private def confirmAction(opts: AclCommandOptions, msg: String): Boolean = { + if (opts.options.has(opts.forceOpt)) + return true + println(msg) + StdIn.readLine().equalsIgnoreCase("y") + } + + private def validateOperation(opts: AclCommandOptions, resourceToAcls: Map[ResourcePatternFilter, Set[AccessControlEntry]]): Unit = { + for ((resource, acls) <- resourceToAcls) { + val validOps = AclEntry.supportedOperations(resource.resourceType) + AclOperation.ALL + if ((acls.map(_.operation) -- validOps).nonEmpty) + CommandLineUtils.printUsageAndDie(opts.parser, s"ResourceType ${resource.resourceType} only supports operations ${validOps.mkString(",")}") + } + } + + class AclCommandOptions(args: Array[String]) extends CommandDefaultOptions(args) { + val CommandConfigDoc = "A property file containing configs to be passed to Admin Client." + + val bootstrapServerOpt = parser.accepts("bootstrap-server", "A list of host/port pairs to use for establishing the connection to the Kafka cluster." + + " This list should be in the form host1:port1,host2:port2,... 
This config is required for acl management using admin client API.") + .withRequiredArg + .describedAs("server to connect to") + .ofType(classOf[String]) + + val commandConfigOpt = parser.accepts("command-config", CommandConfigDoc) + .withOptionalArg() + .describedAs("command-config") + .ofType(classOf[String]) + + val authorizerOpt = parser.accepts("authorizer", "DEPRECATED: Fully qualified class name of " + + "the authorizer, which defaults to kafka.security.authorizer.AclAuthorizer if --bootstrap-server is not provided. " + + AclCommand.AuthorizerDeprecationMessage) + .withRequiredArg + .describedAs("authorizer") + .ofType(classOf[String]) + + val authorizerPropertiesOpt = parser.accepts("authorizer-properties", "DEPRECATED: The " + + "properties required to configure an instance of the Authorizer specified by --authorizer. " + + "These are key=val pairs. For the default authorizer, example values are: zookeeper.connect=localhost:2181. " + + AclCommand.AuthorizerDeprecationMessage) + .withRequiredArg + .describedAs("authorizer-properties") + .ofType(classOf[String]) + + val topicOpt = parser.accepts("topic", "topic to which ACLs should be added or removed. " + + "A value of * indicates ACL should apply to all topics.") + .withRequiredArg + .describedAs("topic") + .ofType(classOf[String]) + + val clusterOpt = parser.accepts("cluster", "Add/Remove cluster ACLs.") + val groupOpt = parser.accepts("group", "Consumer Group to which the ACLs should be added or removed. " + + "A value of * indicates the ACLs should apply to all groups.") + .withRequiredArg + .describedAs("group") + .ofType(classOf[String]) + + val transactionalIdOpt = parser.accepts("transactional-id", "The transactionalId to which ACLs should " + + "be added or removed. A value of * indicates the ACLs should apply to all transactionalIds.") + .withRequiredArg + .describedAs("transactional-id") + .ofType(classOf[String]) + + val idempotentOpt = parser.accepts("idempotent", "Enable idempotence for the producer. This should be " + + "used in combination with the --producer option. Note that idempotence is enabled automatically if " + + "the producer is authorized to a particular transactional-id.") + + val delegationTokenOpt = parser.accepts("delegation-token", "Delegation token to which ACLs should be added or removed. " + + "A value of * indicates ACL should apply to all tokens.") + .withRequiredArg + .describedAs("delegation-token") + .ofType(classOf[String]) + + val resourcePatternType = parser.accepts("resource-pattern-type", "The type of the resource pattern or pattern filter. " + + "When adding acls, this should be a specific pattern type, e.g. 'literal' or 'prefixed'. " + + "When listing or removing acls, a specific pattern type can be used to list or remove acls from specific resource patterns, " + + "or use the filter values of 'any' or 'match', where 'any' will match any pattern type, but will match the resource name exactly, " + + "where as 'match' will perform pattern matching to list or remove all acls that affect the supplied resource(s). 
" + + "WARNING: 'match', when used in combination with the '--remove' switch, should be used with care.") + .withRequiredArg() + .ofType(classOf[String]) + .withValuesConvertedBy(new PatternTypeConverter()) + .defaultsTo(PatternType.LITERAL) + + val addOpt = parser.accepts("add", "Indicates you are trying to add ACLs.") + val removeOpt = parser.accepts("remove", "Indicates you are trying to remove ACLs.") + val listOpt = parser.accepts("list", "List ACLs for the specified resource, use --topic or --group or --cluster to specify a resource.") + + val operationsOpt = parser.accepts("operation", "Operation that is being allowed or denied. Valid operation names are: " + Newline + + AclEntry.AclOperations.map("\t" + JSecurityUtils.operationName(_)).mkString(Newline) + Newline) + .withRequiredArg + .ofType(classOf[String]) + .defaultsTo(JSecurityUtils.operationName(AclOperation.ALL)) + + val allowPrincipalsOpt = parser.accepts("allow-principal", "principal is in principalType:name format." + + " Note that principalType must be supported by the Authorizer being used." + + " For example, User:* is the wild card indicating all users.") + .withRequiredArg + .describedAs("allow-principal") + .ofType(classOf[String]) + + val denyPrincipalsOpt = parser.accepts("deny-principal", "principal is in principalType:name format. " + + "By default anyone not added through --allow-principal is denied access. " + + "You only need to use this option as negation to already allowed set. " + + "Note that principalType must be supported by the Authorizer being used. " + + "For example if you wanted to allow access to all users in the system but not test-user you can define an ACL that " + + "allows access to User:* and specify --deny-principal=User:test@EXAMPLE.COM. " + + "AND PLEASE REMEMBER DENY RULES TAKES PRECEDENCE OVER ALLOW RULES.") + .withRequiredArg + .describedAs("deny-principal") + .ofType(classOf[String]) + + val listPrincipalsOpt = parser.accepts("principal", "List ACLs for the specified principal. principal is in principalType:name format." + + " Note that principalType must be supported by the Authorizer being used. Multiple --principal option can be passed.") + .withOptionalArg() + .describedAs("principal") + .ofType(classOf[String]) + + val allowHostsOpt = parser.accepts("allow-host", "Host from which principals listed in --allow-principal will have access. " + + "If you have specified --allow-principal then the default for this option will be set to * which allows access from all hosts.") + .withRequiredArg + .describedAs("allow-host") + .ofType(classOf[String]) + + val denyHostsOpt = parser.accepts("deny-host", "Host from which principals listed in --deny-principal will be denied access. " + + "If you have specified --deny-principal then the default for this option will be set to * which denies access from all hosts.") + .withRequiredArg + .describedAs("deny-host") + .ofType(classOf[String]) + + val producerOpt = parser.accepts("producer", "Convenience option to add/remove ACLs for producer role. " + + "This will generate ACLs that allows WRITE,DESCRIBE and CREATE on topic.") + + val consumerOpt = parser.accepts("consumer", "Convenience option to add/remove ACLs for consumer role. 
" + + "This will generate ACLs that allows READ,DESCRIBE on topic and READ on group.") + + val forceOpt = parser.accepts("force", "Assume Yes to all queries and do not prompt.") + + val zkTlsConfigFile = parser.accepts("zk-tls-config-file", + "DEPRECATED: Identifies the file where ZooKeeper client TLS connectivity properties are defined for" + + " the default authorizer kafka.security.authorizer.AclAuthorizer." + + " Any properties other than the following (with or without an \"authorizer.\" prefix) are ignored: " + + KafkaConfig.ZkSslConfigToSystemPropertyMap.keys.toList.sorted.mkString(", ") + + ". Note that if SASL is not configured and zookeeper.set.acl is supposed to be true due to mutual certificate authentication being used" + + " then it is necessary to explicitly specify --authorizer-properties zookeeper.set.acl=true. " + + AclCommand.AuthorizerDeprecationMessage) + .withRequiredArg().describedAs("Authorizer ZooKeeper TLS configuration").ofType(classOf[String]) + + options = parser.parse(args: _*) + + def checkArgs(): Unit = { + if (options.has(bootstrapServerOpt) && options.has(authorizerOpt)) + CommandLineUtils.printUsageAndDie(parser, "Only one of --bootstrap-server or --authorizer must be specified") + + if (!options.has(bootstrapServerOpt)) { + CommandLineUtils.checkRequiredArgs(parser, options, authorizerPropertiesOpt) + System.err.println(AclCommand.AuthorizerDeprecationMessage) + } + + if (options.has(commandConfigOpt) && !options.has(bootstrapServerOpt)) + CommandLineUtils.printUsageAndDie(parser, "The --command-config option can only be used with --bootstrap-server option") + + if (options.has(authorizerPropertiesOpt) && options.has(bootstrapServerOpt)) + CommandLineUtils.printUsageAndDie(parser, "The --authorizer-properties option can only be used with --authorizer option") + + val actions = Seq(addOpt, removeOpt, listOpt).count(options.has) + if (actions != 1) + CommandLineUtils.printUsageAndDie(parser, "Command must include exactly one action: --list, --add, --remove. ") + + CommandLineUtils.checkInvalidArgs(parser, options, listOpt, Set(producerOpt, consumerOpt, allowHostsOpt, allowPrincipalsOpt, denyHostsOpt, denyPrincipalsOpt)) + + //when --producer or --consumer is specified , user should not specify operations as they are inferred and we also disallow --deny-principals and --deny-hosts. 
+ CommandLineUtils.checkInvalidArgs(parser, options, producerOpt, Set(operationsOpt, denyPrincipalsOpt, denyHostsOpt)) + CommandLineUtils.checkInvalidArgs(parser, options, consumerOpt, Set(operationsOpt, denyPrincipalsOpt, denyHostsOpt)) + + if (options.has(listPrincipalsOpt) && !options.has(listOpt)) + CommandLineUtils.printUsageAndDie(parser, "The --principal option is only available if --list is set") + + if (options.has(producerOpt) && !options.has(topicOpt)) + CommandLineUtils.printUsageAndDie(parser, "With --producer you must specify a --topic") + + if (options.has(idempotentOpt) && !options.has(producerOpt)) + CommandLineUtils.printUsageAndDie(parser, "The --idempotent option is only available if --producer is set") + + if (options.has(consumerOpt) && (!options.has(topicOpt) || !options.has(groupOpt) || (!options.has(producerOpt) && (options.has(clusterOpt) || options.has(transactionalIdOpt))))) + CommandLineUtils.printUsageAndDie(parser, "With --consumer you must specify a --topic and a --group and no --cluster or --transactional-id option should be specified.") + } + } +} + +class PatternTypeConverter extends EnumConverter[PatternType](classOf[PatternType]) { + + override def convert(value: String): PatternType = { + val patternType = super.convert(value) + if (patternType.isUnknown) + throw new ValueConversionException("Unknown resource-pattern-type: " + value) + + patternType + } + + override def valuePattern: String = PatternType.values + .filter(_ != PatternType.UNKNOWN) + .mkString("|") +} diff --git a/core/src/main/scala/kafka/admin/AdminOperationException.scala b/core/src/main/scala/kafka/admin/AdminOperationException.scala new file mode 100644 index 0000000..a45b3f7 --- /dev/null +++ b/core/src/main/scala/kafka/admin/AdminOperationException.scala @@ -0,0 +1,23 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.admin + +class AdminOperationException(val error: String, cause: Throwable) extends RuntimeException(error, cause) { + def this(error: Throwable) = this(error.getMessage, error) + def this(msg: String) = this(msg, null) +} \ No newline at end of file diff --git a/core/src/main/scala/kafka/admin/AdminUtils.scala b/core/src/main/scala/kafka/admin/AdminUtils.scala new file mode 100644 index 0000000..a37e3d6 --- /dev/null +++ b/core/src/main/scala/kafka/admin/AdminUtils.scala @@ -0,0 +1,239 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.admin + +import java.util.Random + +import kafka.utils.Logging +import org.apache.kafka.common.errors.{InvalidPartitionsException, InvalidReplicationFactorException} + +import collection.{Map, mutable, _} + +object AdminUtils extends Logging { + val rand = new Random + val AdminClientId = "__admin_client" + + /** + * There are 3 goals of replica assignment: + * + *
+ *   1. Spread the replicas evenly among brokers.
+ *   2. For partitions assigned to a particular broker, their other replicas are spread over the other brokers.
+ *   3. If all brokers have rack information, assign the replicas for each partition to different racks if possible.
+ *
+ * To achieve this goal for replica assignment without considering racks, we:
+ *   1. Assign the first replica of each partition by round-robin, starting from a random position in the broker list.
+ *   2. Assign the remaining replicas of each partition with an increasing shift.
+ *
+ * Here is an example of assigning:
+ *   broker-0   broker-1   broker-2   broker-3   broker-4
+ *   p0         p1         p2         p3         p4         (1st replica)
+ *   p5         p6         p7         p8         p9         (1st replica)
+ *   p4         p0         p1         p2         p3         (2nd replica)
+ *   p8         p9         p5         p6         p7         (2nd replica)
+ *   p3         p4         p0         p1         p2         (3rd replica)
+ *   p7         p8         p9         p5         p6         (3rd replica)
+ *
+ * To create a rack aware assignment, this API will first create a rack alternated broker list. For example,
+ * from this brokerID -> rack mapping:
+ *   0 -> "rack1", 1 -> "rack3", 2 -> "rack3", 3 -> "rack2", 4 -> "rack2", 5 -> "rack1"
+ * the rack alternated list will be:
+ *   0, 3, 1, 5, 4, 2
+ *
+ * Then an easy round-robin assignment can be applied. Assuming 6 partitions with a replication factor of 3, the
+ * assignment will be:
+ *   0 -> 0,3,1
+ *   1 -> 3,1,5
+ *   2 -> 1,5,4
+ *   3 -> 5,4,2
+ *   4 -> 4,2,0
+ *   5 -> 2,0,3
+ *
+ * Once it has completed the first round-robin, if there are more partitions to assign, the algorithm will start
+ * shifting the followers. This is to ensure we will not always get the same set of sequences. In this case, if there
+ * is another partition to assign (partition #6), the assignment will be:
+ *   6 -> 0,4,2 (instead of repeating 0,3,1 as for partition 0)
+ *
+ * The rack aware assignment always chooses the 1st replica of the partition using round robin on the rack alternated
+ * broker list. For the rest of the replicas, it is biased towards brokers on racks that do not yet have any replica
+ * assigned, until every rack has a replica. Then the assignment goes back to round-robin on the broker list.
+ *
+ * As a result, if the number of replicas is equal to or greater than the number of racks, it ensures that each rack
+ * gets at least one replica. Otherwise, each rack gets at most one replica. In a perfect situation where the number
+ * of replicas is the same as the number of racks and each rack has the same number of brokers, the replica
+ * distribution is guaranteed to be even across both brokers and racks.
+ *
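+ * A minimal usage sketch (broker ids and racks below are made up purely for illustration):
+ * {{{
+ * val brokers = Seq(BrokerMetadata(0, Some("rack1")), BrokerMetadata(1, Some("rack2")), BrokerMetadata(2, Some("rack1")))
+ * // All brokers carry rack information, so the rack aware path is used; replication factor 2 <= 3 available brokers.
+ * val assignment = AdminUtils.assignReplicasToBrokers(brokers, nPartitions = 6, replicationFactor = 2)
+ * // assignment maps each partition id (0 to 5) to the ordered broker ids that should host its replicas.
+ * }}}
+ *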
            + * @return a Map from partition id to replica ids + * @throws AdminOperationException If rack information is supplied but it is incomplete, or if it is not possible to + * assign each replica to a unique rack. + * + */ + def assignReplicasToBrokers(brokerMetadatas: Iterable[BrokerMetadata], + nPartitions: Int, + replicationFactor: Int, + fixedStartIndex: Int = -1, + startPartitionId: Int = -1): Map[Int, Seq[Int]] = { + if (nPartitions <= 0) + throw new InvalidPartitionsException("Number of partitions must be larger than 0.") + if (replicationFactor <= 0) + throw new InvalidReplicationFactorException("Replication factor must be larger than 0.") + if (replicationFactor > brokerMetadatas.size) + throw new InvalidReplicationFactorException(s"Replication factor: $replicationFactor larger than available brokers: ${brokerMetadatas.size}.") + if (brokerMetadatas.forall(_.rack.isEmpty)) + assignReplicasToBrokersRackUnaware(nPartitions, replicationFactor, brokerMetadatas.map(_.id), fixedStartIndex, + startPartitionId) + else { + if (brokerMetadatas.exists(_.rack.isEmpty)) + throw new AdminOperationException("Not all brokers have rack information for replica rack aware assignment.") + assignReplicasToBrokersRackAware(nPartitions, replicationFactor, brokerMetadatas, fixedStartIndex, + startPartitionId) + } + } + + private def assignReplicasToBrokersRackUnaware(nPartitions: Int, + replicationFactor: Int, + brokerList: Iterable[Int], + fixedStartIndex: Int, + startPartitionId: Int): Map[Int, Seq[Int]] = { + val ret = mutable.Map[Int, Seq[Int]]() + val brokerArray = brokerList.toArray + val startIndex = if (fixedStartIndex >= 0) fixedStartIndex else rand.nextInt(brokerArray.length) + var currentPartitionId = math.max(0, startPartitionId) + var nextReplicaShift = if (fixedStartIndex >= 0) fixedStartIndex else rand.nextInt(brokerArray.length) + for (_ <- 0 until nPartitions) { + if (currentPartitionId > 0 && (currentPartitionId % brokerArray.length == 0)) + nextReplicaShift += 1 + val firstReplicaIndex = (currentPartitionId + startIndex) % brokerArray.length + val replicaBuffer = mutable.ArrayBuffer(brokerArray(firstReplicaIndex)) + for (j <- 0 until replicationFactor - 1) + replicaBuffer += brokerArray(replicaIndex(firstReplicaIndex, nextReplicaShift, j, brokerArray.length)) + ret.put(currentPartitionId, replicaBuffer) + currentPartitionId += 1 + } + ret + } + + private def assignReplicasToBrokersRackAware(nPartitions: Int, + replicationFactor: Int, + brokerMetadatas: Iterable[BrokerMetadata], + fixedStartIndex: Int, + startPartitionId: Int): Map[Int, Seq[Int]] = { + val brokerRackMap = brokerMetadatas.collect { case BrokerMetadata(id, Some(rack)) => + id -> rack + }.toMap + val numRacks = brokerRackMap.values.toSet.size + val arrangedBrokerList = getRackAlternatedBrokerList(brokerRackMap) + val numBrokers = arrangedBrokerList.size + val ret = mutable.Map[Int, Seq[Int]]() + val startIndex = if (fixedStartIndex >= 0) fixedStartIndex else rand.nextInt(arrangedBrokerList.size) + var currentPartitionId = math.max(0, startPartitionId) + var nextReplicaShift = if (fixedStartIndex >= 0) fixedStartIndex else rand.nextInt(arrangedBrokerList.size) + for (_ <- 0 until nPartitions) { + if (currentPartitionId > 0 && (currentPartitionId % arrangedBrokerList.size == 0)) + nextReplicaShift += 1 + val firstReplicaIndex = (currentPartitionId + startIndex) % arrangedBrokerList.size + val leader = arrangedBrokerList(firstReplicaIndex) + val replicaBuffer = mutable.ArrayBuffer(leader) + val racksWithReplicas = 
mutable.Set(brokerRackMap(leader)) + val brokersWithReplicas = mutable.Set(leader) + var k = 0 + for (_ <- 0 until replicationFactor - 1) { + var done = false + while (!done) { + val broker = arrangedBrokerList(replicaIndex(firstReplicaIndex, nextReplicaShift * numRacks, k, arrangedBrokerList.size)) + val rack = brokerRackMap(broker) + // Skip this broker if + // 1. there is already a broker in the same rack that has assigned a replica AND there is one or more racks + // that do not have any replica, or + // 2. the broker has already assigned a replica AND there is one or more brokers that do not have replica assigned + if ((!racksWithReplicas.contains(rack) || racksWithReplicas.size == numRacks) + && (!brokersWithReplicas.contains(broker) || brokersWithReplicas.size == numBrokers)) { + replicaBuffer += broker + racksWithReplicas += rack + brokersWithReplicas += broker + done = true + } + k += 1 + } + } + ret.put(currentPartitionId, replicaBuffer) + currentPartitionId += 1 + } + ret + } + + /** + * Given broker and rack information, returns a list of brokers alternated by the rack. Assume + * this is the rack and its brokers: + * + * rack1: 0, 1, 2 + * rack2: 3, 4, 5 + * rack3: 6, 7, 8 + * + * This API would return the list of 0, 3, 6, 1, 4, 7, 2, 5, 8 + * + * This is essential to make sure that the assignReplicasToBrokers API can use such list and + * assign replicas to brokers in a simple round-robin fashion, while ensuring an even + * distribution of leader and replica counts on each broker and that replicas are + * distributed to all racks. + */ + private[admin] def getRackAlternatedBrokerList(brokerRackMap: Map[Int, String]): IndexedSeq[Int] = { + val brokersIteratorByRack = getInverseMap(brokerRackMap).map { case (rack, brokers) => + (rack, brokers.iterator) + } + val racks = brokersIteratorByRack.keys.toArray.sorted + val result = new mutable.ArrayBuffer[Int] + var rackIndex = 0 + while (result.size < brokerRackMap.size) { + val rackIterator = brokersIteratorByRack(racks(rackIndex)) + if (rackIterator.hasNext) + result += rackIterator.next() + rackIndex = (rackIndex + 1) % racks.length + } + result + } + + private[admin] def getInverseMap(brokerRackMap: Map[Int, String]): Map[String, Seq[Int]] = { + brokerRackMap.toSeq.map { case (id, rack) => (rack, id) } + .groupBy { case (rack, _) => rack } + .map { case (rack, rackAndIdList) => (rack, rackAndIdList.map { case (_, id) => id }.sorted) } + } + + private def replicaIndex(firstReplicaIndex: Int, secondReplicaShift: Int, replicaIndex: Int, nBrokers: Int): Int = { + val shift = 1 + (secondReplicaShift + replicaIndex) % (nBrokers - 1) + (firstReplicaIndex + shift) % nBrokers + } + +} diff --git a/core/src/main/scala/kafka/admin/BrokerApiVersionsCommand.scala b/core/src/main/scala/kafka/admin/BrokerApiVersionsCommand.scala new file mode 100644 index 0000000..f6f8706 --- /dev/null +++ b/core/src/main/scala/kafka/admin/BrokerApiVersionsCommand.scala @@ -0,0 +1,330 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.admin + +import java.io.PrintStream +import java.io.IOException +import java.util.Properties +import java.util.concurrent.atomic.AtomicInteger +import java.util.concurrent.{ConcurrentLinkedQueue, TimeUnit} + +import kafka.utils.{CommandDefaultOptions, CommandLineUtils} +import kafka.utils.Implicits._ +import kafka.utils.Logging +import org.apache.kafka.common.utils.Utils +import org.apache.kafka.clients.{ApiVersions, ClientDnsLookup, ClientResponse, ClientUtils, CommonClientConfigs, Metadata, NetworkClient, NodeApiVersions} +import org.apache.kafka.clients.consumer.internals.{ConsumerNetworkClient, RequestFuture} +import org.apache.kafka.common.config.ConfigDef.ValidString._ +import org.apache.kafka.common.config.ConfigDef.{Importance, Type} +import org.apache.kafka.common.config.{AbstractConfig, ConfigDef} +import org.apache.kafka.common.errors.AuthenticationException +import org.apache.kafka.common.internals.ClusterResourceListeners +import org.apache.kafka.common.metrics.Metrics +import org.apache.kafka.common.network.Selector +import org.apache.kafka.common.protocol.Errors +import org.apache.kafka.common.utils.LogContext +import org.apache.kafka.common.utils.{KafkaThread, Time} +import org.apache.kafka.common.Node +import org.apache.kafka.common.message.ApiVersionsResponseData.ApiVersionCollection +import org.apache.kafka.common.requests.{AbstractRequest, AbstractResponse, ApiVersionsRequest, ApiVersionsResponse, MetadataRequest, MetadataResponse} + +import scala.jdk.CollectionConverters._ +import scala.util.{Failure, Success, Try} + +/** + * A command for retrieving broker version information. + */ +object BrokerApiVersionsCommand { + + def main(args: Array[String]): Unit = { + execute(args, System.out) + } + + def execute(args: Array[String], out: PrintStream): Unit = { + val opts = new BrokerVersionCommandOptions(args) + val adminClient = createAdminClient(opts) + adminClient.awaitBrokers() + val brokerMap = adminClient.listAllBrokerVersionInfo() + brokerMap.forKeyValue { (broker, versionInfoOrError) => + versionInfoOrError match { + case Success(v) => out.print(s"${broker} -> ${v.toString(true)}\n") + case Failure(v) => out.print(s"${broker} -> ERROR: ${v}\n") + } + } + adminClient.close() + } + + private def createAdminClient(opts: BrokerVersionCommandOptions): AdminClient = { + val props = if (opts.options.has(opts.commandConfigOpt)) + Utils.loadProps(opts.options.valueOf(opts.commandConfigOpt)) + else + new Properties() + props.put(CommonClientConfigs.BOOTSTRAP_SERVERS_CONFIG, opts.options.valueOf(opts.bootstrapServerOpt)) + AdminClient.create(props) + } + + class BrokerVersionCommandOptions(args: Array[String]) extends CommandDefaultOptions(args) { + val BootstrapServerDoc = "REQUIRED: The server to connect to." + val CommandConfigDoc = "A property file containing configs to be passed to Admin Client." 
+ + val commandConfigOpt = parser.accepts("command-config", CommandConfigDoc) + .withRequiredArg + .describedAs("command config property file") + .ofType(classOf[String]) + val bootstrapServerOpt = parser.accepts("bootstrap-server", BootstrapServerDoc) + .withRequiredArg + .describedAs("server(s) to use for bootstrapping") + .ofType(classOf[String]) + options = parser.parse(args : _*) + checkArgs() + + def checkArgs(): Unit = { + CommandLineUtils.printHelpAndExitIfNeeded(this, "This tool helps to retrieve broker version information.") + // check required args + CommandLineUtils.checkRequiredArgs(parser, options, bootstrapServerOpt) + } + } + + // org.apache.kafka.clients.admin.AdminClient doesn't currently expose a way to retrieve the supported api versions. + // We inline the bits we need from kafka.admin.AdminClient so that we can delete it. + private class AdminClient(val time: Time, + val client: ConsumerNetworkClient, + val bootstrapBrokers: List[Node]) extends Logging { + + @volatile var running = true + val pendingFutures = new ConcurrentLinkedQueue[RequestFuture[ClientResponse]]() + + val networkThread = new KafkaThread("admin-client-network-thread", () => { + try { + while (running) + client.poll(time.timer(Long.MaxValue)) + } catch { + case t: Throwable => + error("admin-client-network-thread exited", t) + } finally { + pendingFutures.forEach { future => + try { + future.raise(Errors.UNKNOWN_SERVER_ERROR) + } catch { + case _: IllegalStateException => // It is OK if the future has been completed + } + } + pendingFutures.clear() + } + }, true) + + networkThread.start() + + private def send(target: Node, + request: AbstractRequest.Builder[_ <: AbstractRequest]): AbstractResponse = { + val future = client.send(target, request) + pendingFutures.add(future) + future.awaitDone(Long.MaxValue, TimeUnit.MILLISECONDS) + pendingFutures.remove(future) + if (future.succeeded()) + future.value().responseBody() + else + throw future.exception() + } + + private def sendAnyNode(request: AbstractRequest.Builder[_ <: AbstractRequest]): AbstractResponse = { + bootstrapBrokers.foreach { broker => + try { + return send(broker, request) + } catch { + case e: AuthenticationException => + throw e + case e: Exception => + debug(s"Request ${request.apiKey()} failed against node $broker", e) + } + } + throw new RuntimeException(s"Request ${request.apiKey()} failed on brokers $bootstrapBrokers") + } + + private def getApiVersions(node: Node): ApiVersionCollection = { + val response = send(node, new ApiVersionsRequest.Builder()).asInstanceOf[ApiVersionsResponse] + Errors.forCode(response.data.errorCode).maybeThrow() + response.data.apiKeys + } + + /** + * Wait until there is a non-empty list of brokers in the cluster. 
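+   * It simply polls the cluster metadata (via findAllBrokers below) every 50 ms until at least one broker is returned.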
+ */ + def awaitBrokers(): Unit = { + var nodes = List[Node]() + do { + nodes = findAllBrokers() + if (nodes.isEmpty) + Thread.sleep(50) + } while (nodes.isEmpty) + } + + private def findAllBrokers(): List[Node] = { + val request = MetadataRequest.Builder.allTopics() + val response = sendAnyNode(request).asInstanceOf[MetadataResponse] + val errors = response.errors + if (!errors.isEmpty) + debug(s"Metadata request contained errors: $errors") + response.buildCluster.nodes.asScala.toList + } + + def listAllBrokerVersionInfo(): Map[Node, Try[NodeApiVersions]] = + findAllBrokers().map { broker => + broker -> Try[NodeApiVersions](new NodeApiVersions(getApiVersions(broker))) + }.toMap + + def close(): Unit = { + running = false + try { + client.close() + } catch { + case e: IOException => + error("Exception closing nioSelector:", e) + } + } + + } + + private object AdminClient { + val DefaultConnectionMaxIdleMs = 9 * 60 * 1000 + val DefaultRequestTimeoutMs = 5000 + val DefaultSocketConnectionSetupMs = CommonClientConfigs.SOCKET_CONNECTION_SETUP_TIMEOUT_MS_CONFIG + val DefaultSocketConnectionSetupMaxMs = CommonClientConfigs.SOCKET_CONNECTION_SETUP_TIMEOUT_MAX_MS_CONFIG + val DefaultMaxInFlightRequestsPerConnection = 100 + val DefaultReconnectBackoffMs = 50 + val DefaultReconnectBackoffMax = 50 + val DefaultSendBufferBytes = 128 * 1024 + val DefaultReceiveBufferBytes = 32 * 1024 + val DefaultRetryBackoffMs = 100 + + val AdminClientIdSequence = new AtomicInteger(1) + val AdminConfigDef = { + val config = new ConfigDef() + .define( + CommonClientConfigs.BOOTSTRAP_SERVERS_CONFIG, + Type.LIST, + Importance.HIGH, + CommonClientConfigs.BOOTSTRAP_SERVERS_DOC) + .define(CommonClientConfigs.CLIENT_DNS_LOOKUP_CONFIG, + Type.STRING, + ClientDnsLookup.USE_ALL_DNS_IPS.toString, + in(ClientDnsLookup.USE_ALL_DNS_IPS.toString, + ClientDnsLookup.RESOLVE_CANONICAL_BOOTSTRAP_SERVERS_ONLY.toString), + Importance.MEDIUM, + CommonClientConfigs.CLIENT_DNS_LOOKUP_DOC) + .define( + CommonClientConfigs.SECURITY_PROTOCOL_CONFIG, + ConfigDef.Type.STRING, + CommonClientConfigs.DEFAULT_SECURITY_PROTOCOL, + ConfigDef.Importance.MEDIUM, + CommonClientConfigs.SECURITY_PROTOCOL_DOC) + .define( + CommonClientConfigs.REQUEST_TIMEOUT_MS_CONFIG, + ConfigDef.Type.INT, + DefaultRequestTimeoutMs, + ConfigDef.Importance.MEDIUM, + CommonClientConfigs.REQUEST_TIMEOUT_MS_DOC) + .define( + CommonClientConfigs.SOCKET_CONNECTION_SETUP_TIMEOUT_MS_CONFIG, + ConfigDef.Type.LONG, + CommonClientConfigs.DEFAULT_SOCKET_CONNECTION_SETUP_TIMEOUT_MS, + ConfigDef.Importance.MEDIUM, + CommonClientConfigs.SOCKET_CONNECTION_SETUP_TIMEOUT_MS_DOC) + .define( + CommonClientConfigs.SOCKET_CONNECTION_SETUP_TIMEOUT_MAX_MS_CONFIG, + ConfigDef.Type.LONG, + CommonClientConfigs.DEFAULT_SOCKET_CONNECTION_SETUP_TIMEOUT_MAX_MS, + ConfigDef.Importance.MEDIUM, + CommonClientConfigs.SOCKET_CONNECTION_SETUP_TIMEOUT_MAX_MS_DOC) + .define( + CommonClientConfigs.RETRY_BACKOFF_MS_CONFIG, + ConfigDef.Type.LONG, + DefaultRetryBackoffMs, + ConfigDef.Importance.MEDIUM, + CommonClientConfigs.RETRY_BACKOFF_MS_DOC) + .withClientSslSupport() + .withClientSaslSupport() + config + } + + class AdminConfig(originals: Map[_,_]) extends AbstractConfig(AdminConfigDef, originals.asJava, false) + + def create(props: Properties): AdminClient = create(props.asScala.toMap) + + def create(props: Map[String, _]): AdminClient = create(new AdminConfig(props)) + + def create(config: AdminConfig): AdminClient = { + val clientId = "admin-" + AdminClientIdSequence.getAndIncrement() + val logContext = new 
LogContext(s"[LegacyAdminClient clientId=$clientId] ") + val time = Time.SYSTEM + val metrics = new Metrics(time) + val metadata = new Metadata(100L, 60 * 60 * 1000L, logContext, + new ClusterResourceListeners) + val channelBuilder = ClientUtils.createChannelBuilder(config, time, logContext) + val requestTimeoutMs = config.getInt(CommonClientConfigs.REQUEST_TIMEOUT_MS_CONFIG) + val connectionSetupTimeoutMs = config.getLong(CommonClientConfigs.SOCKET_CONNECTION_SETUP_TIMEOUT_MS_CONFIG) + val connectionSetupTimeoutMaxMs = config.getLong(CommonClientConfigs.SOCKET_CONNECTION_SETUP_TIMEOUT_MAX_MS_CONFIG) + val retryBackoffMs = config.getLong(CommonClientConfigs.RETRY_BACKOFF_MS_CONFIG) + + val brokerUrls = config.getList(CommonClientConfigs.BOOTSTRAP_SERVERS_CONFIG) + val clientDnsLookup = config.getString(CommonClientConfigs.CLIENT_DNS_LOOKUP_CONFIG) + val brokerAddresses = ClientUtils.parseAndValidateAddresses(brokerUrls, clientDnsLookup) + metadata.bootstrap(brokerAddresses) + + val selector = new Selector( + DefaultConnectionMaxIdleMs, + metrics, + time, + "admin", + channelBuilder, + logContext) + + val networkClient = new NetworkClient( + selector, + metadata, + clientId, + DefaultMaxInFlightRequestsPerConnection, + DefaultReconnectBackoffMs, + DefaultReconnectBackoffMax, + DefaultSendBufferBytes, + DefaultReceiveBufferBytes, + requestTimeoutMs, + connectionSetupTimeoutMs, + connectionSetupTimeoutMaxMs, + time, + true, + new ApiVersions, + logContext) + + val highLevelClient = new ConsumerNetworkClient( + logContext, + networkClient, + metadata, + time, + retryBackoffMs, + requestTimeoutMs, + Integer.MAX_VALUE) + + new AdminClient( + time, + highLevelClient, + metadata.fetch.nodes.asScala.toList) + } + } + +} diff --git a/core/src/main/scala/kafka/admin/BrokerMetadata.scala b/core/src/main/scala/kafka/admin/BrokerMetadata.scala new file mode 100644 index 0000000..86831e3 --- /dev/null +++ b/core/src/main/scala/kafka/admin/BrokerMetadata.scala @@ -0,0 +1,23 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more contributor license agreements. See the NOTICE + * file distributed with this work for additional information regarding copyright ownership. The ASF licenses this file + * to You under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the + * License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package kafka.admin + +/** + * Broker metadata used by admin tools. + * + * @param id an integer that uniquely identifies this broker + * @param rack the rack of the broker, which is used to in rack aware partition assignment for fault tolerance. + * Examples: "RACK1", "us-east-1d" + */ +case class BrokerMetadata(id: Int, rack: Option[String]) diff --git a/core/src/main/scala/kafka/admin/ConfigCommand.scala b/core/src/main/scala/kafka/admin/ConfigCommand.scala new file mode 100644 index 0000000..5e5ccef --- /dev/null +++ b/core/src/main/scala/kafka/admin/ConfigCommand.scala @@ -0,0 +1,952 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.admin + +import java.nio.charset.StandardCharsets +import java.util.concurrent.TimeUnit +import java.util.{Collections, Properties} + +import joptsimple._ +import kafka.common.Config +import kafka.log.LogConfig +import kafka.server.DynamicConfig.QuotaConfigs +import kafka.server.{ConfigEntityName, ConfigType, Defaults, DynamicBrokerConfig, DynamicConfig, KafkaConfig} +import kafka.utils.{CommandDefaultOptions, CommandLineUtils, Exit, PasswordEncoder} +import kafka.utils.Implicits._ +import kafka.zk.{AdminZkClient, KafkaZkClient} +import org.apache.kafka.clients.admin.{Admin, AlterClientQuotasOptions, AlterConfigOp, AlterConfigsOptions, ConfigEntry, DescribeClusterOptions, DescribeConfigsOptions, ListTopicsOptions, ScramCredentialInfo, UserScramCredentialDeletion, UserScramCredentialUpsertion, Config => JConfig, ScramMechanism => PublicScramMechanism} +import org.apache.kafka.clients.CommonClientConfigs +import org.apache.kafka.common.config.ConfigResource +import org.apache.kafka.common.config.types.Password +import org.apache.kafka.common.errors.InvalidConfigurationException +import org.apache.kafka.common.internals.Topic +import org.apache.kafka.common.quota.{ClientQuotaAlteration, ClientQuotaEntity, ClientQuotaFilter, ClientQuotaFilterComponent} +import org.apache.kafka.common.security.JaasUtils +import org.apache.kafka.common.security.scram.internals.{ScramCredentialUtils, ScramFormatter, ScramMechanism} +import org.apache.kafka.common.utils.{Sanitizer, Time, Utils} +import org.apache.zookeeper.client.ZKClientConfig + +import scala.annotation.nowarn +import scala.jdk.CollectionConverters._ +import scala.collection._ + +/** + * This script can be used to change configs for topics/clients/users/brokers/ips dynamically + * An entity described or altered by the command may be one of: + *
+ *   - topic: --topic <topic> OR --entity-type topics --entity-name <topic>
+ *   - client: --client <client-id> OR --entity-type clients --entity-name <client-id>
+ *   - user: --user <user-principal> OR --entity-type users --entity-name <user-principal>
+ *   - <user, client>: --user <user-principal> --client <client-id> OR
+ *     --entity-type users --entity-name <user-principal> --entity-type clients --entity-name <client-id>
+ *   - broker: --broker <broker-id> OR --entity-type brokers --entity-name <broker-id>
+ *   - broker-logger: --broker-logger <broker-id> OR --entity-type broker-loggers --entity-name <broker-id>
+ *   - ip: --ip <ip> OR --entity-type ips --entity-name <ip>
+ *
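+ * For example (illustrative names and values), a topic override can be set with either of the equivalent forms:
+ *   --bootstrap-server localhost:9092 --alter --topic my-topic --add-config retention.ms=86400000
+ *   --bootstrap-server localhost:9092 --alter --entity-type topics --entity-name my-topic --add-config retention.ms=86400000
+ *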
            + * --entity-type --entity-default may be specified in place of --entity-type --entity-name + * when describing or altering default configuration for users, clients, brokers, or ips, respectively. + * Alternatively, --user-defaults, --client-defaults, --broker-defaults, or --ip-defaults may be specified in place of + * --entity-type --entity-default, respectively. + * + * For most use cases, this script communicates with a kafka cluster (specified via the + * `--bootstrap-server` option). There are three exceptions where direct communication with a + * ZooKeeper ensemble (specified via the `--zookeeper` option) is allowed: + * + * 1. Describe/alter user configs where the config is a SCRAM mechanism name (i.e. a SCRAM credential for a user) + * 2. Describe/alter broker configs for a particular broker when that broker is down + * 3. Describe/alter broker default configs when all brokers are down + * + * For example, this allows password configs to be stored encrypted in ZK before brokers are started, + * avoiding cleartext passwords in `server.properties`. + */ +object ConfigCommand extends Config { + + val BrokerDefaultEntityName = "" + val BrokerLoggerConfigType = "broker-loggers" + val BrokerSupportedConfigTypes = ConfigType.all :+ BrokerLoggerConfigType + val ZkSupportedConfigTypes = Seq(ConfigType.User, ConfigType.Broker) + val DefaultScramIterations = 4096 + + def main(args: Array[String]): Unit = { + try { + val opts = new ConfigCommandOptions(args) + + CommandLineUtils.printHelpAndExitIfNeeded(opts, "This tool helps to manipulate and describe entity config for a topic, client, user, broker or ip") + + opts.checkArgs() + + if (opts.options.has(opts.zkConnectOpt)) { + println(s"Warning: --zookeeper is deprecated and will be removed in a future version of Kafka.") + println(s"Use --bootstrap-server instead to specify a broker to connect to.") + processCommandWithZk(opts.options.valueOf(opts.zkConnectOpt), opts) + } else { + processCommand(opts) + } + } catch { + case e @ (_: IllegalArgumentException | _: InvalidConfigurationException | _: OptionException) => + logger.debug(s"Failed config command with args '${args.mkString(" ")}'", e) + System.err.println(e.getMessage) + Exit.exit(1) + + case t: Throwable => + logger.debug(s"Error while executing config command with args '${args.mkString(" ")}'", t) + System.err.println(s"Error while executing config command with args '${args.mkString(" ")}'") + t.printStackTrace(System.err) + Exit.exit(1) + } + } + + private def processCommandWithZk(zkConnectString: String, opts: ConfigCommandOptions): Unit = { + val zkClientConfig = ZkSecurityMigrator.createZkClientConfigFromOption(opts.options, opts.zkTlsConfigFile) + .getOrElse(new ZKClientConfig()) + val zkClient = KafkaZkClient(zkConnectString, JaasUtils.isZkSaslEnabled || KafkaConfig.zkTlsClientAuthEnabled(zkClientConfig), 30000, 30000, + Int.MaxValue, Time.SYSTEM, zkClientConfig = zkClientConfig, name = "ConfigCommand") + val adminZkClient = new AdminZkClient(zkClient) + try { + if (opts.options.has(opts.alterOpt)) + alterConfigWithZk(zkClient, opts, adminZkClient) + else if (opts.options.has(opts.describeOpt)) + describeConfigWithZk(zkClient, opts, adminZkClient) + } finally { + zkClient.close() + } + } + + private[admin] def alterConfigWithZk(zkClient: KafkaZkClient, opts: ConfigCommandOptions, adminZkClient: AdminZkClient): Unit = { + val configsToBeAdded = parseConfigsToBeAdded(opts) + val configsToBeDeleted = parseConfigsToBeDeleted(opts) + val entity = parseEntity(opts) + val 
entityType = entity.root.entityType + val entityName = entity.fullSanitizedName + val errorMessage = s"--bootstrap-server option must be specified to update $entityType configs: {add: $configsToBeAdded, delete: $configsToBeDeleted}" + + if (entityType == ConfigType.User) { + if (!configsToBeAdded.isEmpty || !configsToBeDeleted.isEmpty) { + val info = "User configuration updates using ZooKeeper are only supported for SCRAM credential updates." + val scramMechanismNames = ScramMechanism.values.map(_.mechanismName) + // make sure every added/deleted configs are SCRAM related, other configs are not supported using zookeeper + require(configsToBeAdded.stringPropertyNames.asScala.forall(scramMechanismNames.contains), + s"$errorMessage. $info") + require(configsToBeDeleted.forall(scramMechanismNames.contains), s"$errorMessage. $info") + } + preProcessScramCredentials(configsToBeAdded) + } else if (entityType == ConfigType.Broker) { + // Dynamic broker configs can be updated using ZooKeeper only if the corresponding broker is not running. + if (!configsToBeAdded.isEmpty || !configsToBeDeleted.isEmpty) { + validateBrokersNotRunning(entityName, adminZkClient, zkClient, errorMessage) + + val perBrokerConfig = entityName != ConfigEntityName.Default + preProcessBrokerConfigs(configsToBeAdded, perBrokerConfig) + } + } + + // compile the final set of configs + val configs = adminZkClient.fetchEntityConfig(entityType, entityName) + + // fail the command if any of the configs to be deleted does not exist + val invalidConfigs = configsToBeDeleted.filterNot(configs.containsKey(_)) + if (invalidConfigs.nonEmpty) + throw new InvalidConfigurationException(s"Invalid config(s): ${invalidConfigs.mkString(",")}") + + configs ++= configsToBeAdded + configsToBeDeleted.foreach(configs.remove(_)) + + adminZkClient.changeConfigs(entityType, entityName, configs) + + println(s"Completed updating config for entity: $entity.") + } + + private def validateBrokersNotRunning(entityName: String, + adminZkClient: AdminZkClient, + zkClient: KafkaZkClient, + errorMessage: String): Unit = { + val perBrokerConfig = entityName != ConfigEntityName.Default + val info = "Broker configuration operations using ZooKeeper are only supported if the affected broker(s) are not running." + if (perBrokerConfig) { + adminZkClient.parseBroker(entityName).foreach { brokerId => + require(zkClient.getBroker(brokerId).isEmpty, s"$errorMessage - broker $brokerId is running. $info") + } + } else { + val runningBrokersCount = zkClient.getAllBrokersInCluster.size + require(runningBrokersCount == 0, s"$errorMessage - $runningBrokersCount brokers are running. 
$info") + } + } + + private def preProcessScramCredentials(configsToBeAdded: Properties): Unit = { + def scramCredential(mechanism: ScramMechanism, credentialStr: String): String = { + val pattern = "(?:iterations=([0-9]*),)?password=(.*)".r + val (iterations, password) = credentialStr match { + case pattern(iterations, password) => (if (iterations != null) iterations.toInt else DefaultScramIterations, password) + case _ => throw new IllegalArgumentException(s"Invalid credential property $mechanism=$credentialStr") + } + if (iterations < mechanism.minIterations()) + throw new IllegalArgumentException(s"Iterations $iterations is less than the minimum ${mechanism.minIterations()} required for $mechanism") + val credential = new ScramFormatter(mechanism).generateCredential(password, iterations) + ScramCredentialUtils.credentialToString(credential) + } + for (mechanism <- ScramMechanism.values) { + configsToBeAdded.getProperty(mechanism.mechanismName) match { + case null => + case value => + configsToBeAdded.setProperty(mechanism.mechanismName, scramCredential(mechanism, value)) + } + } + } + + private[admin] def createPasswordEncoder(encoderConfigs: Map[String, String]): PasswordEncoder = { + encoderConfigs.get(KafkaConfig.PasswordEncoderSecretProp) + val encoderSecret = encoderConfigs.getOrElse(KafkaConfig.PasswordEncoderSecretProp, + throw new IllegalArgumentException("Password encoder secret not specified")) + new PasswordEncoder(new Password(encoderSecret), + None, + encoderConfigs.get(KafkaConfig.PasswordEncoderCipherAlgorithmProp).getOrElse(Defaults.PasswordEncoderCipherAlgorithm), + encoderConfigs.get(KafkaConfig.PasswordEncoderKeyLengthProp).map(_.toInt).getOrElse(Defaults.PasswordEncoderKeyLength), + encoderConfigs.get(KafkaConfig.PasswordEncoderIterationsProp).map(_.toInt).getOrElse(Defaults.PasswordEncoderIterations)) + } + + /** + * Pre-process broker configs provided to convert them to persistent format. + * Password configs are encrypted using the secret `KafkaConfig.PasswordEncoderSecretProp`. + * The secret is removed from `configsToBeAdded` and will not be persisted in ZooKeeper. + */ + private def preProcessBrokerConfigs(configsToBeAdded: Properties, perBrokerConfig: Boolean): Unit = { + val passwordEncoderConfigs = new Properties + passwordEncoderConfigs ++= configsToBeAdded.asScala.filter { case (key, _) => key.startsWith("password.encoder.") } + if (!passwordEncoderConfigs.isEmpty) { + info(s"Password encoder configs ${passwordEncoderConfigs.keySet} will be used for encrypting" + + " passwords, but will not be stored in ZooKeeper.") + passwordEncoderConfigs.asScala.keySet.foreach(configsToBeAdded.remove) + } + + DynamicBrokerConfig.validateConfigs(configsToBeAdded, perBrokerConfig) + val passwordConfigs = configsToBeAdded.asScala.keySet.filter(DynamicBrokerConfig.isPasswordConfig) + if (passwordConfigs.nonEmpty) { + require(passwordEncoderConfigs.containsKey(KafkaConfig.PasswordEncoderSecretProp), + s"${KafkaConfig.PasswordEncoderSecretProp} must be specified to update $passwordConfigs." + + " Other password encoder configs like cipher algorithm and iterations may also be specified" + + " to override the default encoding parameters. Password encoder configs will not be persisted" + + " in ZooKeeper." 
+ ) + + val passwordEncoder = createPasswordEncoder(passwordEncoderConfigs.asScala) + passwordConfigs.foreach { configName => + val encodedValue = passwordEncoder.encode(new Password(configsToBeAdded.getProperty(configName))) + configsToBeAdded.setProperty(configName, encodedValue) + } + } + } + + private[admin] def describeConfigWithZk(zkClient: KafkaZkClient, opts: ConfigCommandOptions, adminZkClient: AdminZkClient): Unit = { + val configEntity = parseEntity(opts) + val entityType = configEntity.root.entityType + val describeAllUsers = entityType == ConfigType.User && !configEntity.root.sanitizedName.isDefined && !configEntity.child.isDefined + val entityName = configEntity.fullSanitizedName + val errorMessage = s"--bootstrap-server option must be specified to describe $entityType" + if (entityType == ConfigType.Broker) { + // Dynamic broker configs can be described using ZooKeeper only if the corresponding broker is not running. + validateBrokersNotRunning(entityName, adminZkClient, zkClient, errorMessage) + } + + val entities = configEntity.getAllEntities(zkClient) + for (entity <- entities) { + val configs = adminZkClient.fetchEntityConfig(entity.root.entityType, entity.fullSanitizedName) + // When describing all users, don't include empty user nodes with only quota overrides. + if (!configs.isEmpty || !describeAllUsers) { + println("Configs for %s are %s" + .format(entity, configs.asScala.map(kv => kv._1 + "=" + kv._2).mkString(","))) + } + } + } + + @nowarn("cat=deprecation") + private[admin] def parseConfigsToBeAdded(opts: ConfigCommandOptions): Properties = { + val props = new Properties + if (opts.options.has(opts.addConfigFile)) { + val file = opts.options.valueOf(opts.addConfigFile) + props ++= Utils.loadProps(file) + } + if (opts.options.has(opts.addConfig)) { + // Split list by commas, but avoid those in [], then into KV pairs + // Each KV pair is of format key=value, split them into key and value, using -1 as the limit for split() to + // include trailing empty strings. This is to support empty value (e.g. 'ssl.endpoint.identification.algorithm=') + val pattern = "(?=[^\\]]*(?:\\[|$))" + val configsToBeAdded = opts.options.valueOf(opts.addConfig) + .split("," + pattern) + .map(_.split("""\s*=\s*""" + pattern, -1)) + require(configsToBeAdded.forall(config => config.length == 2), "Invalid entity config: all configs to be added must be in the format \"key=val\".") + //Create properties, parsing square brackets from values if necessary + configsToBeAdded.foreach(pair => props.setProperty(pair(0).trim, pair(1).replaceAll("\\[?\\]?", "").trim)) + } + if (props.containsKey(LogConfig.MessageFormatVersionProp)) { + println(s"WARNING: The configuration ${LogConfig.MessageFormatVersionProp}=${props.getProperty(LogConfig.MessageFormatVersionProp)} is specified. " + + "This configuration will be ignored if the version is newer than the inter.broker.protocol.version specified in the broker or " + + "if the inter.broker.protocol.version is 3.0 or newer. 
This configuration is deprecated and it will be removed in Apache Kafka 4.0.") + } + props + } + + private[admin] def parseConfigsToBeDeleted(opts: ConfigCommandOptions): Seq[String] = { + if (opts.options.has(opts.deleteConfig)) { + val configsToBeDeleted = opts.options.valuesOf(opts.deleteConfig).asScala.map(_.trim()) + val propsToBeDeleted = new Properties + configsToBeDeleted.foreach(propsToBeDeleted.setProperty(_, "")) + configsToBeDeleted + } + else + Seq.empty + } + + private def processCommand(opts: ConfigCommandOptions): Unit = { + val props = if (opts.options.has(opts.commandConfigOpt)) + Utils.loadProps(opts.options.valueOf(opts.commandConfigOpt)) + else + new Properties() + props.put(CommonClientConfigs.BOOTSTRAP_SERVERS_CONFIG, opts.options.valueOf(opts.bootstrapServerOpt)) + val adminClient = Admin.create(props) + + if (opts.options.has(opts.alterOpt) && opts.entityTypes.size != opts.entityNames.size) + throw new IllegalArgumentException(s"An entity name must be specified for every entity type") + + try { + if (opts.options.has(opts.alterOpt)) + alterConfig(adminClient, opts) + else if (opts.options.has(opts.describeOpt)) + describeConfig(adminClient, opts) + } finally { + adminClient.close() + } + } + + @nowarn("cat=deprecation") + private[admin] def alterConfig(adminClient: Admin, opts: ConfigCommandOptions): Unit = { + val entityTypes = opts.entityTypes + val entityNames = opts.entityNames + val entityTypeHead = entityTypes.head + val entityNameHead = entityNames.head + val configsToBeAddedMap = parseConfigsToBeAdded(opts).asScala.toMap // no need for mutability + val configsToBeAdded = configsToBeAddedMap.map { case (k, v) => (k, new ConfigEntry(k, v)) } + val configsToBeDeleted = parseConfigsToBeDeleted(opts) + + entityTypeHead match { + case ConfigType.Topic => + val oldConfig = getResourceConfig(adminClient, entityTypeHead, entityNameHead, includeSynonyms = false, describeAll = false) + .map { entry => (entry.name, entry) }.toMap + + // fail the command if any of the configs to be deleted does not exist + val invalidConfigs = configsToBeDeleted.filterNot(oldConfig.contains) + if (invalidConfigs.nonEmpty) + throw new InvalidConfigurationException(s"Invalid config(s): ${invalidConfigs.mkString(",")}") + + val configResource = new ConfigResource(ConfigResource.Type.TOPIC, entityNameHead) + val alterOptions = new AlterConfigsOptions().timeoutMs(30000).validateOnly(false) + val alterEntries = (configsToBeAdded.values.map(new AlterConfigOp(_, AlterConfigOp.OpType.SET)) + ++ configsToBeDeleted.map { k => new AlterConfigOp(new ConfigEntry(k, ""), AlterConfigOp.OpType.DELETE) } + ).asJavaCollection + adminClient.incrementalAlterConfigs(Map(configResource -> alterEntries).asJava, alterOptions).all().get(60, TimeUnit.SECONDS) + + case ConfigType.Broker => + val oldConfig = getResourceConfig(adminClient, entityTypeHead, entityNameHead, includeSynonyms = false, describeAll = false) + .map { entry => (entry.name, entry) }.toMap + + // fail the command if any of the configs to be deleted does not exist + val invalidConfigs = configsToBeDeleted.filterNot(oldConfig.contains) + if (invalidConfigs.nonEmpty) + throw new InvalidConfigurationException(s"Invalid config(s): ${invalidConfigs.mkString(",")}") + + val newEntries = oldConfig ++ configsToBeAdded -- configsToBeDeleted + val sensitiveEntries = newEntries.filter(_._2.value == null) + if (sensitiveEntries.nonEmpty) + throw new InvalidConfigurationException(s"All sensitive broker config entries must be specified for --alter, missing 
entries: ${sensitiveEntries.keySet}") + val newConfig = new JConfig(newEntries.asJava.values) + + val configResource = new ConfigResource(ConfigResource.Type.BROKER, entityNameHead) + val alterOptions = new AlterConfigsOptions().timeoutMs(30000).validateOnly(false) + adminClient.alterConfigs(Map(configResource -> newConfig).asJava, alterOptions).all().get(60, TimeUnit.SECONDS) + + case BrokerLoggerConfigType => + val validLoggers = getResourceConfig(adminClient, entityTypeHead, entityNameHead, includeSynonyms = true, describeAll = false).map(_.name) + // fail the command if any of the configured broker loggers do not exist + val invalidBrokerLoggers = configsToBeDeleted.filterNot(validLoggers.contains) ++ configsToBeAdded.keys.filterNot(validLoggers.contains) + if (invalidBrokerLoggers.nonEmpty) + throw new InvalidConfigurationException(s"Invalid broker logger(s): ${invalidBrokerLoggers.mkString(",")}") + + val configResource = new ConfigResource(ConfigResource.Type.BROKER_LOGGER, entityNameHead) + val alterOptions = new AlterConfigsOptions().timeoutMs(30000).validateOnly(false) + val alterLogLevelEntries = (configsToBeAdded.values.map(new AlterConfigOp(_, AlterConfigOp.OpType.SET)) + ++ configsToBeDeleted.map { k => new AlterConfigOp(new ConfigEntry(k, ""), AlterConfigOp.OpType.DELETE) } + ).asJavaCollection + adminClient.incrementalAlterConfigs(Map(configResource -> alterLogLevelEntries).asJava, alterOptions).all().get(60, TimeUnit.SECONDS) + + case ConfigType.User | ConfigType.Client => + val hasQuotaConfigsToAdd = configsToBeAdded.keys.exists(QuotaConfigs.isClientOrUserQuotaConfig) + val scramConfigsToAddMap = configsToBeAdded.filter(entry => ScramMechanism.isScram(entry._1)) + val unknownConfigsToAdd = configsToBeAdded.keys.filterNot(key => ScramMechanism.isScram(key) || QuotaConfigs.isClientOrUserQuotaConfig(key)) + val hasQuotaConfigsToDelete = configsToBeDeleted.exists(QuotaConfigs.isClientOrUserQuotaConfig) + val scramConfigsToDelete = configsToBeDeleted.filter(ScramMechanism.isScram) + val unknownConfigsToDelete = configsToBeDeleted.filterNot(key => ScramMechanism.isScram(key) || QuotaConfigs.isClientOrUserQuotaConfig(key)) + if (entityTypeHead == ConfigType.Client || entityTypes.size == 2) { // size==2 for case where users is specified first on the command line, before clients + // either just a client or both a user and a client + if (unknownConfigsToAdd.nonEmpty || scramConfigsToAddMap.nonEmpty) + throw new IllegalArgumentException(s"Only quota configs can be added for '${ConfigType.Client}' using --bootstrap-server. Unexpected config names: ${unknownConfigsToAdd ++ scramConfigsToAddMap.keys}") + if (unknownConfigsToDelete.nonEmpty || scramConfigsToDelete.nonEmpty) + throw new IllegalArgumentException(s"Only quota configs can be deleted for '${ConfigType.Client}' using --bootstrap-server. Unexpected config names: ${unknownConfigsToDelete ++ scramConfigsToDelete}") + } else { // ConfigType.User + if (unknownConfigsToAdd.nonEmpty) + throw new IllegalArgumentException(s"Only quota and SCRAM credential configs can be added for '${ConfigType.User}' using --bootstrap-server. Unexpected config names: $unknownConfigsToAdd") + if (unknownConfigsToDelete.nonEmpty) + throw new IllegalArgumentException(s"Only quota and SCRAM credential configs can be deleted for '${ConfigType.User}' using --bootstrap-server. 
Unexpected config names: $unknownConfigsToDelete") + if (scramConfigsToAddMap.nonEmpty || scramConfigsToDelete.nonEmpty) { + if (entityNames.exists(_.isEmpty)) // either --entity-type users --entity-default or --user-defaults + throw new IllegalArgumentException("The use of --entity-default or --user-defaults is not allowed with User SCRAM Credentials using --bootstrap-server.") + if (hasQuotaConfigsToAdd || hasQuotaConfigsToDelete) + throw new IllegalArgumentException(s"Cannot alter both quota and SCRAM credential configs simultaneously for '${ConfigType.User}' using --bootstrap-server.") + } + } + + if (hasQuotaConfigsToAdd || hasQuotaConfigsToDelete) { + alterQuotaConfigs(adminClient, entityTypes, entityNames, configsToBeAddedMap, configsToBeDeleted) + } else { + // handle altering user SCRAM credential configs + if (entityNames.size != 1) + // should never happen, if we get here then it is a bug + throw new IllegalStateException(s"Altering user SCRAM credentials should never occur for more zero or multiple users: $entityNames") + alterUserScramCredentialConfigs(adminClient, entityNames.head, scramConfigsToAddMap, scramConfigsToDelete) + } + case ConfigType.Ip => + val unknownConfigs = (configsToBeAdded.keys ++ configsToBeDeleted).filterNot(key => DynamicConfig.Ip.names.contains(key)) + if (unknownConfigs.nonEmpty) + throw new IllegalArgumentException(s"Only connection quota configs can be added for '${ConfigType.Ip}' using --bootstrap-server. Unexpected config names: ${unknownConfigs.mkString(",")}") + alterQuotaConfigs(adminClient, entityTypes, entityNames, configsToBeAddedMap, configsToBeDeleted) + case _ => throw new IllegalArgumentException(s"Unsupported entity type: $entityTypeHead") + } + + if (entityNameHead.nonEmpty) + println(s"Completed updating config for ${entityTypeHead.dropRight(1)} $entityNameHead.") + else + println(s"Completed updating default config for $entityTypeHead in the cluster.") + } + + private def alterUserScramCredentialConfigs(adminClient: Admin, user: String, scramConfigsToAddMap: Map[String, ConfigEntry], scramConfigsToDelete: Seq[String]) = { + val deletions = scramConfigsToDelete.map(mechanismName => + new UserScramCredentialDeletion(user, PublicScramMechanism.fromMechanismName(mechanismName))) + + def iterationsAndPasswordBytes(mechanism: ScramMechanism, credentialStr: String): (Integer, Array[Byte]) = { + val pattern = "(?:iterations=(\\-?[0-9]*),)?password=(.*)".r + val (iterations, password) = credentialStr match { + case pattern(iterations, password) => (if (iterations != null && iterations != "-1") iterations.toInt else DefaultScramIterations, password) + case _ => throw new IllegalArgumentException(s"Invalid credential property $mechanism=$credentialStr") + } + if (iterations < mechanism.minIterations) + throw new IllegalArgumentException(s"Iterations $iterations is less than the minimum ${mechanism.minIterations} required for ${mechanism.mechanismName}") + (iterations, password.getBytes(StandardCharsets.UTF_8)) + } + + val upsertions = scramConfigsToAddMap.map { case (mechanismName, configEntry) => + val (iterations, passwordBytes) = iterationsAndPasswordBytes(ScramMechanism.forMechanismName(mechanismName), configEntry.value) + new UserScramCredentialUpsertion(user, new ScramCredentialInfo(PublicScramMechanism.fromMechanismName(mechanismName), iterations), passwordBytes) + } + // we are altering only a single user by definition, so we don't have to worry about one user succeeding and another + // failing; therefore just check the success of all 
the futures (since there will only be 1) + adminClient.alterUserScramCredentials((deletions ++ upsertions).toList.asJava).all.get(60, TimeUnit.SECONDS) + } + + private def alterQuotaConfigs(adminClient: Admin, entityTypes: List[String], entityNames: List[String], configsToBeAddedMap: Map[String, String], configsToBeDeleted: Seq[String]) = { + // handle altering client/user quota configs + val oldConfig = getClientQuotasConfig(adminClient, entityTypes, entityNames) + + val invalidConfigs = configsToBeDeleted.filterNot(oldConfig.contains) + if (invalidConfigs.nonEmpty) + throw new InvalidConfigurationException(s"Invalid config(s): ${invalidConfigs.mkString(",")}") + + val alterEntityTypes = entityTypes.map { + case ConfigType.User => ClientQuotaEntity.USER + case ConfigType.Client => ClientQuotaEntity.CLIENT_ID + case ConfigType.Ip => ClientQuotaEntity.IP + case entType => throw new IllegalArgumentException(s"Unexpected entity type: $entType") + } + val alterEntityNames = entityNames.map(en => if (en.nonEmpty) en else null) + + // Explicitly populate a HashMap to ensure nulls are recorded properly. + val alterEntityMap = new java.util.HashMap[String, String] + alterEntityTypes.zip(alterEntityNames).foreach { case (k, v) => alterEntityMap.put(k, v) } + val entity = new ClientQuotaEntity(alterEntityMap) + + val alterOptions = new AlterClientQuotasOptions().validateOnly(false) + val alterOps = (configsToBeAddedMap.map { case (key, value) => + val doubleValue = try value.toDouble catch { + case _: NumberFormatException => + throw new IllegalArgumentException(s"Cannot parse quota configuration value for $key: $value") + } + new ClientQuotaAlteration.Op(key, doubleValue) + } ++ configsToBeDeleted.map(key => new ClientQuotaAlteration.Op(key, null))).asJavaCollection + + adminClient.alterClientQuotas(Collections.singleton(new ClientQuotaAlteration(entity, alterOps)), alterOptions) + .all().get(60, TimeUnit.SECONDS) + } + + private[admin] def describeConfig(adminClient: Admin, opts: ConfigCommandOptions): Unit = { + val entityTypes = opts.entityTypes + val entityNames = opts.entityNames + val describeAll = opts.options.has(opts.allOpt) + + entityTypes.head match { + case ConfigType.Topic | ConfigType.Broker | BrokerLoggerConfigType => + describeResourceConfig(adminClient, entityTypes.head, entityNames.headOption, describeAll) + case ConfigType.User | ConfigType.Client => + describeClientQuotaAndUserScramCredentialConfigs(adminClient, entityTypes, entityNames) + case ConfigType.Ip => + describeQuotaConfigs(adminClient, entityTypes, entityNames) + case entityType => throw new IllegalArgumentException(s"Invalid entity type: $entityType") + } + } + + private def describeResourceConfig(adminClient: Admin, entityType: String, entityName: Option[String], describeAll: Boolean): Unit = { + val entities = entityName + .map(name => List(name)) + .getOrElse(entityType match { + case ConfigType.Topic => + adminClient.listTopics(new ListTopicsOptions().listInternal(true)).names().get().asScala.toSeq + case ConfigType.Broker | BrokerLoggerConfigType => + adminClient.describeCluster(new DescribeClusterOptions()).nodes().get().asScala.map(_.idString).toSeq :+ BrokerDefaultEntityName + case entityType => throw new IllegalArgumentException(s"Invalid entity type: $entityType") + }) + + entities.foreach { entity => + entity match { + case BrokerDefaultEntityName => + println(s"Default configs for $entityType in the cluster are:") + case _ => + val configSourceStr = if (describeAll) "All" else "Dynamic" + 
println(s"$configSourceStr configs for ${entityType.dropRight(1)} $entity are:") + } + getResourceConfig(adminClient, entityType, entity, includeSynonyms = true, describeAll).foreach { entry => + val synonyms = entry.synonyms.asScala.map(synonym => s"${synonym.source}:${synonym.name}=${synonym.value}").mkString(", ") + println(s" ${entry.name}=${entry.value} sensitive=${entry.isSensitive} synonyms={$synonyms}") + } + } + } + + private def getResourceConfig(adminClient: Admin, entityType: String, entityName: String, includeSynonyms: Boolean, describeAll: Boolean) = { + def validateBrokerId(): Unit = try entityName.toInt catch { + case _: NumberFormatException => + throw new IllegalArgumentException(s"The entity name for $entityType must be a valid integer broker id, found: $entityName") + } + + val (configResourceType, dynamicConfigSource) = entityType match { + case ConfigType.Topic => + if (!entityName.isEmpty) + Topic.validate(entityName) + (ConfigResource.Type.TOPIC, Some(ConfigEntry.ConfigSource.DYNAMIC_TOPIC_CONFIG)) + case ConfigType.Broker => entityName match { + case BrokerDefaultEntityName => + (ConfigResource.Type.BROKER, Some(ConfigEntry.ConfigSource.DYNAMIC_DEFAULT_BROKER_CONFIG)) + case _ => + validateBrokerId() + (ConfigResource.Type.BROKER, Some(ConfigEntry.ConfigSource.DYNAMIC_BROKER_CONFIG)) + } + case BrokerLoggerConfigType => + if (!entityName.isEmpty) + validateBrokerId() + (ConfigResource.Type.BROKER_LOGGER, None) + case entityType => throw new IllegalArgumentException(s"Invalid entity type: $entityType") + } + + val configSourceFilter = if (describeAll) + None + else + dynamicConfigSource + + val configResource = new ConfigResource(configResourceType, entityName) + val describeOptions = new DescribeConfigsOptions().includeSynonyms(includeSynonyms) + val configs = adminClient.describeConfigs(Collections.singleton(configResource), describeOptions) + .all.get(30, TimeUnit.SECONDS) + configs.get(configResource).entries.asScala + .filter(entry => configSourceFilter match { + case Some(configSource) => entry.source == configSource + case None => true + }).toSeq + } + + private def describeQuotaConfigs(adminClient: Admin, entityTypes: List[String], entityNames: List[String]) = { + val quotaConfigs = getAllClientQuotasConfigs(adminClient, entityTypes, entityNames) + quotaConfigs.forKeyValue { (entity, entries) => + val entityEntries = entity.entries.asScala + + def entitySubstr(entityType: String): Option[String] = + entityEntries.get(entityType).map { name => + val typeStr = entityType match { + case ClientQuotaEntity.USER => "user-principal" + case ClientQuotaEntity.CLIENT_ID => "client-id" + case ClientQuotaEntity.IP => "ip" + } + if (name != null) s"$typeStr '$name'" + else s"the default $typeStr" + } + + val entityStr = (entitySubstr(ClientQuotaEntity.USER) ++ + entitySubstr(ClientQuotaEntity.CLIENT_ID) ++ + entitySubstr(ClientQuotaEntity.IP)).mkString(", ") + val entriesStr = entries.asScala.map(e => s"${e._1}=${e._2}").mkString(", ") + println(s"Quota configs for $entityStr are $entriesStr") + } + } + + private def describeClientQuotaAndUserScramCredentialConfigs(adminClient: Admin, entityTypes: List[String], entityNames: List[String]) = { + describeQuotaConfigs(adminClient, entityTypes, entityNames) + // we describe user SCRAM credentials only when we are not describing client information + // and we are not given either --entity-default or --user-defaults + if (!entityTypes.contains(ConfigType.Client) && !entityNames.contains("")) { + val result = 
adminClient.describeUserScramCredentials(entityNames.asJava)
+      result.users.get(30, TimeUnit.SECONDS).asScala.foreach(user => {
+        try {
+          val description = result.description(user).get(30, TimeUnit.SECONDS)
+          val descriptionText = description.credentialInfos.asScala.map(info => s"${info.mechanism.mechanismName}=iterations=${info.iterations}").mkString(", ")
+          println(s"SCRAM credential configs for user-principal '$user' are $descriptionText")
+        } catch {
+          case e: Exception => println(s"Error retrieving SCRAM credential configs for user-principal '$user': ${e.getClass.getSimpleName}: ${e.getMessage}")
+        }
+      })
+    }
+  }
+
+  private def getClientQuotasConfig(adminClient: Admin, entityTypes: List[String], entityNames: List[String]): Map[String, java.lang.Double] = {
+    if (entityTypes.size != entityNames.size)
+      throw new IllegalArgumentException("Exactly one entity name must be specified for every entity type")
+    getAllClientQuotasConfigs(adminClient, entityTypes, entityNames).headOption.map(_._2.asScala).getOrElse(Map.empty)
+  }
+
+  private def getAllClientQuotasConfigs(adminClient: Admin, entityTypes: List[String], entityNames: List[String]) = {
+    val components = entityTypes.map(Some(_)).zipAll(entityNames.map(Some(_)), None, None).map { case (entityTypeOpt, entityNameOpt) =>
+      val entityType = entityTypeOpt match {
+        case Some(ConfigType.User) => ClientQuotaEntity.USER
+        case Some(ConfigType.Client) => ClientQuotaEntity.CLIENT_ID
+        case Some(ConfigType.Ip) => ClientQuotaEntity.IP
+        case Some(_) => throw new IllegalArgumentException(s"Unexpected entity type ${entityTypeOpt.get}")
+        case None => throw new IllegalArgumentException("More entity names specified than entity types")
+      }
+      entityNameOpt match {
+        case Some("") => ClientQuotaFilterComponent.ofDefaultEntity(entityType)
+        case Some(name) => ClientQuotaFilterComponent.ofEntity(entityType, name)
+        case None => ClientQuotaFilterComponent.ofEntityType(entityType)
+      }
+    }
+
+    adminClient.describeClientQuotas(ClientQuotaFilter.containsOnly(components.asJava)).entities.get(30, TimeUnit.SECONDS).asScala
+  }
+
+  case class Entity(entityType: String, sanitizedName: Option[String]) {
+    val entityPath = sanitizedName match {
+      case Some(n) => entityType + "/" + n
+      case None => entityType
+    }
+    override def toString: String = {
+      val typeName = entityType match {
+        case ConfigType.User => "user-principal"
+        case ConfigType.Client => "client-id"
+        case ConfigType.Topic => "topic"
+        case t => t
+      }
+      sanitizedName match {
+        case Some(ConfigEntityName.Default) => "default " + typeName
+        case Some(n) =>
+          val desanitized = if (entityType == ConfigType.User || entityType == ConfigType.Client) Sanitizer.desanitize(n) else n
+          s"$typeName '$desanitized'"
+        case None => entityType
+      }
+    }
+  }
+
+  case class ConfigEntity(root: Entity, child: Option[Entity]) {
+    val fullSanitizedName = root.sanitizedName.getOrElse("") + child.map(s => "/" + s.entityPath).getOrElse("")
+
+    def getAllEntities(zkClient: KafkaZkClient) : Seq[ConfigEntity] = {
+      // Describe option examples:
+      //   Describe entity with specified name:
+      //     --entity-type topics --entity-name topic1 (topic1)
+      //   Describe all entities of a type (topics/brokers/users/clients):
+      //     --entity-type topics (all topics)
+      //   Describe <user, client> quotas:
+      //     --entity-type users --entity-name user1 --entity-type clients --entity-name client2 (<user1, client2>)
+      //     --entity-type users --entity-name userA --entity-type clients (all clients of userA)
+      //     --entity-type users --entity-type clients (all <user, client>s)
+      //   Describe default quotas:
+      //     --entity-type users --entity-default (Default user)
+      //     --entity-type users --entity-default --entity-type clients --entity-default (Default <user, client>)
+      (root.sanitizedName, child) match {
+        case (None, _) =>
+          val rootEntities = zkClient.getAllEntitiesWithConfig(root.entityType)
+            .map(name => ConfigEntity(Entity(root.entityType, Some(name)), child))
+          child match {
+            case Some(s) =>
+              rootEntities.flatMap(rootEntity =>
+                ConfigEntity(rootEntity.root, Some(Entity(s.entityType, None))).getAllEntities(zkClient))
+            case None => rootEntities
+          }
+        case (_, Some(childEntity)) =>
+          childEntity.sanitizedName match {
+            case Some(_) => Seq(this)
+            case None =>
+              zkClient.getAllEntitiesWithConfig(root.entityPath + "/" + childEntity.entityType)
+                .map(name => ConfigEntity(root, Some(Entity(childEntity.entityType, Some(name)))))
+
+          }
+        case (_, None) =>
+          Seq(this)
+      }
+    }
+
+    override def toString: String = {
+      root.toString + child.map(s => ", " + s.toString).getOrElse("")
+    }
+  }
+
+  private[admin] def parseEntity(opts: ConfigCommandOptions): ConfigEntity = {
+    val entityTypes = opts.entityTypes
+    val entityNames = opts.entityNames
+    if (entityTypes.head == ConfigType.User || entityTypes.head == ConfigType.Client)
+      parseClientQuotaEntity(opts, entityTypes, entityNames)
+    else {
+      // Exactly one entity type and at-most one entity name expected for other entities
+      val name = entityNames.headOption match {
+        case Some("") => Some(ConfigEntityName.Default)
+        case v => v
+      }
+      ConfigEntity(Entity(entityTypes.head, name), None)
+    }
+  }
+
+  private def parseClientQuotaEntity(opts: ConfigCommandOptions, types: List[String], names: List[String]): ConfigEntity = {
+    if (opts.options.has(opts.alterOpt) && names.size != types.size)
+      throw new IllegalArgumentException("--entity-name or --entity-default must be specified with each --entity-type for --alter")
+
+    val reverse = types.size == 2 && types.head == ConfigType.Client
+    val entityTypes = if (reverse) types.reverse else types
+    val sortedNames = (if (reverse && names.length == 2) names.reverse else names).iterator
+
+    def sanitizeName(entityType: String, name: String) = {
+      if (name.isEmpty)
+        ConfigEntityName.Default
+      else {
+        entityType match {
+          case ConfigType.User | ConfigType.Client => Sanitizer.sanitize(name)
+          case _ => throw new IllegalArgumentException("Invalid entity type " + entityType)
+        }
+      }
+    }
+
+    val entities = entityTypes.map(t => Entity(t, if (sortedNames.hasNext) Some(sanitizeName(t, sortedNames.next())) else None))
+    ConfigEntity(entities.head, if (entities.size > 1) Some(entities(1)) else None)
+  }
+
+  class ConfigCommandOptions(args: Array[String]) extends CommandDefaultOptions(args) {
+
+    val zkConnectOpt = parser.accepts("zookeeper", "DEPRECATED. The connection string for the zookeeper connection in the form host:port. " +
+      "Multiple URLS can be given to allow fail-over. Required when configuring SCRAM credentials for users or " +
+      "dynamic broker configs when the relevant broker(s) are down. Not allowed otherwise.")
+      .withRequiredArg
+      .describedAs("urls")
+      .ofType(classOf[String])
+    val bootstrapServerOpt = parser.accepts("bootstrap-server", "The Kafka server to connect to. " +
+      "This is required for describing and altering broker configs.")
+      .withRequiredArg
+      .describedAs("server to connect to")
+      .ofType(classOf[String])
+    val commandConfigOpt = parser.accepts("command-config", "Property file containing configs to be passed to Admin Client. 
" + + "This is used only with --bootstrap-server option for describing and altering broker configs.") + .withRequiredArg + .describedAs("command config property file") + .ofType(classOf[String]) + val alterOpt = parser.accepts("alter", "Alter the configuration for the entity.") + val describeOpt = parser.accepts("describe", "List configs for the given entity.") + val allOpt = parser.accepts("all", "List all configs for the given topic, broker, or broker-logger entity (includes static configuration when the entity type is brokers)") + + val entityType = parser.accepts("entity-type", "Type of entity (topics/clients/users/brokers/broker-loggers/ips)") + .withRequiredArg + .ofType(classOf[String]) + val entityName = parser.accepts("entity-name", "Name of entity (topic name/client id/user principal name/broker id/ip)") + .withRequiredArg + .ofType(classOf[String]) + val entityDefault = parser.accepts("entity-default", "Default entity name for clients/users/brokers/ips (applies to corresponding entity type in command line)") + + val nl = System.getProperty("line.separator") + val addConfig = parser.accepts("add-config", "Key Value pairs of configs to add. Square brackets can be used to group values which contain commas: 'k1=v1,k2=[v1,v2,v2],k3=v3'. The following is a list of valid configurations: " + + "For entity-type '" + ConfigType.Topic + "': " + LogConfig.configNames.map("\t" + _).mkString(nl, nl, nl) + + "For entity-type '" + ConfigType.Broker + "': " + DynamicConfig.Broker.names.asScala.toSeq.sorted.map("\t" + _).mkString(nl, nl, nl) + + "For entity-type '" + ConfigType.User + "': " + DynamicConfig.User.names.asScala.toSeq.sorted.map("\t" + _).mkString(nl, nl, nl) + + "For entity-type '" + ConfigType.Client + "': " + DynamicConfig.Client.names.asScala.toSeq.sorted.map("\t" + _).mkString(nl, nl, nl) + + "For entity-type '" + ConfigType.Ip + "': " + DynamicConfig.Ip.names.asScala.toSeq.sorted.map("\t" + _).mkString(nl, nl, nl) + + s"Entity types '${ConfigType.User}' and '${ConfigType.Client}' may be specified together to update config for clients of a specific user.") + .withRequiredArg + .ofType(classOf[String]) + val addConfigFile = parser.accepts("add-config-file", "Path to a properties file with configs to add. 
See add-config for a list of valid configurations.") + .withRequiredArg + .ofType(classOf[String]) + val deleteConfig = parser.accepts("delete-config", "config keys to remove 'k1,k2'") + .withRequiredArg + .ofType(classOf[String]) + .withValuesSeparatedBy(',') + val forceOpt = parser.accepts("force", "Suppress console prompts") + val topic = parser.accepts("topic", "The topic's name.") + .withRequiredArg + .ofType(classOf[String]) + val client = parser.accepts("client", "The client's ID.") + .withRequiredArg + .ofType(classOf[String]) + val clientDefaults = parser.accepts("client-defaults", "The config defaults for all clients.") + val user = parser.accepts("user", "The user's principal name.") + .withRequiredArg + .ofType(classOf[String]) + val userDefaults = parser.accepts("user-defaults", "The config defaults for all users.") + val broker = parser.accepts("broker", "The broker's ID.") + .withRequiredArg + .ofType(classOf[String]) + val brokerDefaults = parser.accepts("broker-defaults", "The config defaults for all brokers.") + val brokerLogger = parser.accepts("broker-logger", "The broker's ID for its logger config.") + .withRequiredArg + .ofType(classOf[String]) + val ipDefaults = parser.accepts("ip-defaults", "The config defaults for all IPs.") + val ip = parser.accepts("ip", "The IP address.") + .withRequiredArg + .ofType(classOf[String]) + val zkTlsConfigFile = parser.accepts("zk-tls-config-file", + "Identifies the file where ZooKeeper client TLS connectivity properties are defined. Any properties other than " + + KafkaConfig.ZkSslConfigToSystemPropertyMap.keys.toList.sorted.mkString(", ") + " are ignored.") + .withRequiredArg().describedAs("ZooKeeper TLS configuration").ofType(classOf[String]) + options = parser.parse(args : _*) + + private val entityFlags = List((topic, ConfigType.Topic), + (client, ConfigType.Client), + (user, ConfigType.User), + (broker, ConfigType.Broker), + (brokerLogger, BrokerLoggerConfigType), + (ip, ConfigType.Ip)) + + private val entityDefaultsFlags = List((clientDefaults, ConfigType.Client), + (userDefaults, ConfigType.User), + (brokerDefaults, ConfigType.Broker), + (ipDefaults, ConfigType.Ip)) + + private[admin] def entityTypes: List[String] = { + options.valuesOf(entityType).asScala.toList ++ + (entityFlags ++ entityDefaultsFlags).filter(entity => options.has(entity._1)).map(_._2) + } + + private[admin] def entityNames: List[String] = { + val namesIterator = options.valuesOf(entityName).iterator + options.specs.asScala + .filter(spec => spec.options.contains("entity-name") || spec.options.contains("entity-default")) + .map(spec => if (spec.options.contains("entity-name")) namesIterator.next else "").toList ++ + entityFlags + .filter(entity => options.has(entity._1)) + .map(entity => options.valueOf(entity._1)) ++ + entityDefaultsFlags + .filter(entity => options.has(entity._1)) + .map(_ => "") + } + + def checkArgs(): Unit = { + // should have exactly one action + val actions = Seq(alterOpt, describeOpt).count(options.has _) + if (actions != 1) + CommandLineUtils.printUsageAndDie(parser, "Command must include exactly one action: --describe, --alter") + // check required args + CommandLineUtils.checkInvalidArgs(parser, options, alterOpt, Set(describeOpt)) + CommandLineUtils.checkInvalidArgs(parser, options, describeOpt, Set(alterOpt, addConfig, deleteConfig)) + + val entityTypeVals = entityTypes + if (entityTypeVals.size != entityTypeVals.distinct.size) + throw new IllegalArgumentException(s"Duplicate entity type(s) specified: 
${entityTypeVals.diff(entityTypeVals.distinct).mkString(",")}") + + val (allowedEntityTypes, connectOptString) = if (options.has(bootstrapServerOpt)) + (BrokerSupportedConfigTypes, "--bootstrap-server") + else + (ZkSupportedConfigTypes, "--zookeeper") + + entityTypeVals.foreach(entityTypeVal => + if (!allowedEntityTypes.contains(entityTypeVal)) + throw new IllegalArgumentException(s"Invalid entity type $entityTypeVal, the entity type must be one of ${allowedEntityTypes.mkString(", ")} with the $connectOptString argument") + ) + if (entityTypeVals.isEmpty) + throw new IllegalArgumentException("At least one entity type must be specified") + else if (entityTypeVals.size > 1 && !entityTypeVals.toSet.equals(Set(ConfigType.User, ConfigType.Client))) + throw new IllegalArgumentException(s"Only '${ConfigType.User}' and '${ConfigType.Client}' entity types may be specified together") + + if ((options.has(entityName) || options.has(entityType) || options.has(entityDefault)) && + (entityFlags ++ entityDefaultsFlags).exists(entity => options.has(entity._1))) + throw new IllegalArgumentException("--entity-{type,name,default} should not be used in conjunction with specific entity flags") + + val hasEntityName = entityNames.exists(!_.isEmpty) + val hasEntityDefault = entityNames.exists(_.isEmpty) + + if (!options.has(bootstrapServerOpt) && !options.has(zkConnectOpt)) + throw new IllegalArgumentException("One of the required --bootstrap-server or --zookeeper arguments must be specified") + else if (options.has(bootstrapServerOpt) && options.has(zkConnectOpt)) + throw new IllegalArgumentException("Only one of --bootstrap-server or --zookeeper must be specified") + + if (options.has(allOpt) && options.has(zkConnectOpt)) { + throw new IllegalArgumentException(s"--bootstrap-server must be specified for --all") + } + + if (options.has(zkTlsConfigFile) && options.has(bootstrapServerOpt)) { + throw new IllegalArgumentException("--bootstrap-server doesn't support --zk-tls-config-file option. " + + "If you intend the command to communicate directly with ZooKeeper, please use the option --zookeeper instead of --bootstrap-server. 
" + + "Otherwise, remove the --zk-tls-config-file option.") + } + + if (hasEntityName && (entityTypeVals.contains(ConfigType.Broker) || entityTypeVals.contains(BrokerLoggerConfigType))) { + Seq(entityName, broker, brokerLogger).filter(options.has(_)).map(options.valueOf(_)).foreach { brokerId => + try brokerId.toInt catch { + case _: NumberFormatException => + throw new IllegalArgumentException(s"The entity name for ${entityTypeVals.head} must be a valid integer broker id, but it is: $brokerId") + } + } + } + + if (hasEntityName && entityTypeVals.contains(ConfigType.Ip)) { + Seq(entityName, ip).filter(options.has(_)).map(options.valueOf(_)).foreach { ipEntity => + if (!DynamicConfig.Ip.isValidIpEntity(ipEntity)) + throw new IllegalArgumentException(s"The entity name for ${entityTypeVals.head} must be a valid IP or resolvable host, but it is: $ipEntity") + } + } + + if (options.has(describeOpt) && entityTypeVals.contains(BrokerLoggerConfigType) && !hasEntityName) + throw new IllegalArgumentException(s"an entity name must be specified with --describe of ${entityTypeVals.mkString(",")}") + + if (options.has(alterOpt)) { + if (entityTypeVals.contains(ConfigType.User) || + entityTypeVals.contains(ConfigType.Client) || + entityTypeVals.contains(ConfigType.Broker) || + entityTypeVals.contains(ConfigType.Ip)) { + if (!hasEntityName && !hasEntityDefault) + throw new IllegalArgumentException("an entity-name or default entity must be specified with --alter of users, clients, brokers or ips") + } else if (!hasEntityName) + throw new IllegalArgumentException(s"an entity name must be specified with --alter of ${entityTypeVals.mkString(",")}") + + val isAddConfigPresent = options.has(addConfig) + val isAddConfigFilePresent = options.has(addConfigFile) + val isDeleteConfigPresent = options.has(deleteConfig) + + if(isAddConfigPresent && isAddConfigFilePresent) + throw new IllegalArgumentException("Only one of --add-config or --add-config-file must be specified") + + if(!isAddConfigPresent && !isAddConfigFilePresent && !isDeleteConfigPresent) + throw new IllegalArgumentException("At least one of --add-config, --add-config-file, or --delete-config must be specified with --alter") + } + } + } +} diff --git a/core/src/main/scala/kafka/admin/ConsumerGroupCommand.scala b/core/src/main/scala/kafka/admin/ConsumerGroupCommand.scala new file mode 100755 index 0000000..47c1d17 --- /dev/null +++ b/core/src/main/scala/kafka/admin/ConsumerGroupCommand.scala @@ -0,0 +1,1157 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package kafka.admin + +import java.time.{Duration, Instant} +import java.util.Properties +import com.fasterxml.jackson.dataformat.csv.CsvMapper +import com.fasterxml.jackson.module.scala.DefaultScalaModule +import kafka.utils._ +import kafka.utils.Implicits._ +import org.apache.kafka.clients.admin._ +import org.apache.kafka.clients.consumer.OffsetAndMetadata +import org.apache.kafka.clients.CommonClientConfigs +import org.apache.kafka.common.utils.Utils +import org.apache.kafka.common.{KafkaException, Node, TopicPartition} + +import scala.jdk.CollectionConverters._ +import scala.collection.mutable.ListBuffer +import scala.collection.{Map, Seq, immutable, mutable} +import scala.util.{Failure, Success, Try} +import joptsimple.OptionSpec +import org.apache.kafka.common.protocol.Errors + +import scala.collection.immutable.TreeMap +import scala.reflect.ClassTag +import org.apache.kafka.common.ConsumerGroupState +import joptsimple.OptionException +import org.apache.kafka.common.requests.ListOffsetsResponse + +object ConsumerGroupCommand extends Logging { + + def main(args: Array[String]): Unit = { + + val opts = new ConsumerGroupCommandOptions(args) + try { + opts.checkArgs() + CommandLineUtils.printHelpAndExitIfNeeded(opts, "This tool helps to list all consumer groups, describe a consumer group, delete consumer group info, or reset consumer group offsets.") + + // should have exactly one action + val actions = Seq(opts.listOpt, opts.describeOpt, opts.deleteOpt, opts.resetOffsetsOpt, opts.deleteOffsetsOpt).count(opts.options.has) + if (actions != 1) + CommandLineUtils.printUsageAndDie(opts.parser, "Command must include exactly one action: --list, --describe, --delete, --reset-offsets, --delete-offsets") + + run(opts) + } catch { + case e: OptionException => + CommandLineUtils.printUsageAndDie(opts.parser, e.getMessage) + } + } + + def run(opts: ConsumerGroupCommandOptions): Unit = { + val consumerGroupService = new ConsumerGroupService(opts) + try { + if (opts.options.has(opts.listOpt)) + consumerGroupService.listGroups() + else if (opts.options.has(opts.describeOpt)) + consumerGroupService.describeGroups() + else if (opts.options.has(opts.deleteOpt)) + consumerGroupService.deleteGroups() + else if (opts.options.has(opts.resetOffsetsOpt)) { + val offsetsToReset = consumerGroupService.resetOffsets() + if (opts.options.has(opts.exportOpt)) { + val exported = consumerGroupService.exportOffsetsToCsv(offsetsToReset) + println(exported) + } else + printOffsetsToReset(offsetsToReset) + } + else if (opts.options.has(opts.deleteOffsetsOpt)) { + consumerGroupService.deleteOffsets() + } + } catch { + case e: IllegalArgumentException => + CommandLineUtils.printUsageAndDie(opts.parser, e.getMessage) + case e: Throwable => + printError(s"Executing consumer group command failed due to ${e.getMessage}", Some(e)) + } finally { + consumerGroupService.close() + } + } + + def consumerGroupStatesFromString(input: String): Set[ConsumerGroupState] = { + val parsedStates = input.split(',').map(s => ConsumerGroupState.parse(s.trim)).toSet + if (parsedStates.contains(ConsumerGroupState.UNKNOWN)) { + val validStates = ConsumerGroupState.values().filter(_ != ConsumerGroupState.UNKNOWN) + throw new IllegalArgumentException(s"Invalid state list '$input'. 
Valid states are: ${validStates.mkString(", ")}") + } + parsedStates + } + + val MISSING_COLUMN_VALUE = "-" + + def printError(msg: String, e: Option[Throwable] = None): Unit = { + println(s"\nError: $msg") + e.foreach(_.printStackTrace()) + } + + def printOffsetsToReset(groupAssignmentsToReset: Map[String, Map[TopicPartition, OffsetAndMetadata]]): Unit = { + if (groupAssignmentsToReset.nonEmpty) + println("\n%-30s %-30s %-10s %-15s".format("GROUP", "TOPIC", "PARTITION", "NEW-OFFSET")) + for { + (groupId, assignment) <- groupAssignmentsToReset + (consumerAssignment, offsetAndMetadata) <- assignment + } { + println("%-30s %-30s %-10s %-15s".format( + groupId, + consumerAssignment.topic, + consumerAssignment.partition, + offsetAndMetadata.offset)) + } + } + + private[admin] case class PartitionAssignmentState(group: String, coordinator: Option[Node], topic: Option[String], + partition: Option[Int], offset: Option[Long], lag: Option[Long], + consumerId: Option[String], host: Option[String], + clientId: Option[String], logEndOffset: Option[Long]) + + private[admin] case class MemberAssignmentState(group: String, consumerId: String, host: String, clientId: String, groupInstanceId: String, + numPartitions: Int, assignment: List[TopicPartition]) + + private[admin] case class GroupState(group: String, coordinator: Node, assignmentStrategy: String, state: String, numMembers: Int) + + private[admin] sealed trait CsvRecord + private[admin] case class CsvRecordWithGroup(group: String, topic: String, partition: Int, offset: Long) extends CsvRecord + private[admin] case class CsvRecordNoGroup(topic: String, partition: Int, offset: Long) extends CsvRecord + private[admin] object CsvRecordWithGroup { + val fields = Array("group", "topic", "partition", "offset") + } + private[admin] object CsvRecordNoGroup { + val fields = Array("topic", "partition", "offset") + } + // Example: CsvUtils().readerFor[CsvRecordWithoutGroup] + private[admin] case class CsvUtils() { + val mapper = new CsvMapper + mapper.registerModule(DefaultScalaModule) + def readerFor[T <: CsvRecord : ClassTag] = { + val schema = getSchema[T] + val clazz = implicitly[ClassTag[T]].runtimeClass + mapper.readerFor(clazz).`with`(schema) + } + def writerFor[T <: CsvRecord : ClassTag] = { + val schema = getSchema[T] + val clazz = implicitly[ClassTag[T]].runtimeClass + mapper.writerFor(clazz).`with`(schema) + } + private def getSchema[T <: CsvRecord : ClassTag] = { + val clazz = implicitly[ClassTag[T]].runtimeClass + + val fields = + if (classOf[CsvRecordWithGroup] == clazz) CsvRecordWithGroup.fields + else if (classOf[CsvRecordNoGroup] == clazz) CsvRecordNoGroup.fields + else throw new IllegalStateException(s"Unhandled class $clazz") + + val schema = mapper.schemaFor(clazz).sortedBy(fields: _*) + schema + } + } + + class ConsumerGroupService(val opts: ConsumerGroupCommandOptions, + private[admin] val configOverrides: Map[String, String] = Map.empty) { + + private val adminClient = createAdminClient(configOverrides) + + // We have to make sure it is evaluated once and available + private lazy val resetPlanFromFile: Option[Map[String, Map[TopicPartition, OffsetAndMetadata]]] = { + if (opts.options.has(opts.resetFromFileOpt)) { + val resetPlanPath = opts.options.valueOf(opts.resetFromFileOpt) + val resetPlanCsv = Utils.readFileAsString(resetPlanPath) + val resetPlan = parseResetPlan(resetPlanCsv) + Some(resetPlan) + } else None + } + + def listGroups(): Unit = { + if (opts.options.has(opts.stateOpt)) { + val stateValue = 
opts.options.valueOf(opts.stateOpt) + val states = if (stateValue == null || stateValue.isEmpty) + Set[ConsumerGroupState]() + else + consumerGroupStatesFromString(stateValue) + val listings = listConsumerGroupsWithState(states) + printGroupStates(listings.map(e => (e.groupId, e.state.get.toString))) + } else + listConsumerGroups().foreach(println(_)) + } + + def listConsumerGroups(): List[String] = { + val result = adminClient.listConsumerGroups(withTimeoutMs(new ListConsumerGroupsOptions)) + val listings = result.all.get.asScala + listings.map(_.groupId).toList + } + + def listConsumerGroupsWithState(states: Set[ConsumerGroupState]): List[ConsumerGroupListing] = { + val listConsumerGroupsOptions = withTimeoutMs(new ListConsumerGroupsOptions()) + listConsumerGroupsOptions.inStates(states.asJava) + val result = adminClient.listConsumerGroups(listConsumerGroupsOptions) + result.all.get.asScala.toList + } + + private def printGroupStates(groupsAndStates: List[(String, String)]): Unit = { + // find proper columns width + var maxGroupLen = 15 + for ((groupId, state) <- groupsAndStates) { + maxGroupLen = Math.max(maxGroupLen, groupId.length) + } + println(s"%${-maxGroupLen}s %s".format("GROUP", "STATE")) + for ((groupId, state) <- groupsAndStates) { + println(s"%${-maxGroupLen}s %s".format(groupId, state)) + } + } + + private def shouldPrintMemberState(group: String, state: Option[String], numRows: Option[Int]): Boolean = { + // numRows contains the number of data rows, if any, compiled from the API call in the caller method. + // if it's undefined or 0, there is no relevant group information to display. + numRows match { + case None => + printError(s"The consumer group '$group' does not exist.") + false + case Some(num) => state match { + case Some("Dead") => + printError(s"Consumer group '$group' does not exist.") + case Some("Empty") => + Console.err.println(s"\nConsumer group '$group' has no active members.") + case Some("PreparingRebalance") | Some("CompletingRebalance") => + Console.err.println(s"\nWarning: Consumer group '$group' is rebalancing.") + case Some("Stable") => + case other => + // the control should never reach here + throw new KafkaException(s"Expected a valid consumer group state, but found '${other.getOrElse("NONE")}'.") + } + !state.contains("Dead") && num > 0 + } + } + + private def size(colOpt: Option[Seq[Object]]): Option[Int] = colOpt.map(_.size) + + private def printOffsets(offsets: Map[String, (Option[String], Option[Seq[PartitionAssignmentState]])]): Unit = { + for ((groupId, (state, assignments)) <- offsets) { + if (shouldPrintMemberState(groupId, state, size(assignments))) { + // find proper columns width + var (maxGroupLen, maxTopicLen, maxConsumerIdLen, maxHostLen) = (15, 15, 15, 15) + assignments match { + case None => // do nothing + case Some(consumerAssignments) => + consumerAssignments.foreach { consumerAssignment => + maxGroupLen = Math.max(maxGroupLen, consumerAssignment.group.length) + maxTopicLen = Math.max(maxTopicLen, consumerAssignment.topic.getOrElse(MISSING_COLUMN_VALUE).length) + maxConsumerIdLen = Math.max(maxConsumerIdLen, consumerAssignment.consumerId.getOrElse(MISSING_COLUMN_VALUE).length) + maxHostLen = Math.max(maxHostLen, consumerAssignment.host.getOrElse(MISSING_COLUMN_VALUE).length) + } + } + + println(s"\n%${-maxGroupLen}s %${-maxTopicLen}s %-10s %-15s %-15s %-15s %${-maxConsumerIdLen}s %${-maxHostLen}s %s" + .format("GROUP", "TOPIC", "PARTITION", "CURRENT-OFFSET", "LOG-END-OFFSET", "LAG", "CONSUMER-ID", "HOST", "CLIENT-ID")) + + 
assignments match { + case None => // do nothing + case Some(consumerAssignments) => + consumerAssignments.foreach { consumerAssignment => + println(s"%${-maxGroupLen}s %${-maxTopicLen}s %-10s %-15s %-15s %-15s %${-maxConsumerIdLen}s %${-maxHostLen}s %s".format( + consumerAssignment.group, + consumerAssignment.topic.getOrElse(MISSING_COLUMN_VALUE), consumerAssignment.partition.getOrElse(MISSING_COLUMN_VALUE), + consumerAssignment.offset.getOrElse(MISSING_COLUMN_VALUE), consumerAssignment.logEndOffset.getOrElse(MISSING_COLUMN_VALUE), + consumerAssignment.lag.getOrElse(MISSING_COLUMN_VALUE), consumerAssignment.consumerId.getOrElse(MISSING_COLUMN_VALUE), + consumerAssignment.host.getOrElse(MISSING_COLUMN_VALUE), consumerAssignment.clientId.getOrElse(MISSING_COLUMN_VALUE)) + ) + } + } + } + } + } + + private def printMembers(members: Map[String, (Option[String], Option[Seq[MemberAssignmentState]])], verbose: Boolean): Unit = { + for ((groupId, (state, assignments)) <- members) { + if (shouldPrintMemberState(groupId, state, size(assignments))) { + // find proper columns width + var (maxGroupLen, maxConsumerIdLen, maxGroupInstanceIdLen, maxHostLen, maxClientIdLen, includeGroupInstanceId) = (15, 15, 17, 15, 15, false) + assignments match { + case None => // do nothing + case Some(memberAssignments) => + memberAssignments.foreach { memberAssignment => + maxGroupLen = Math.max(maxGroupLen, memberAssignment.group.length) + maxConsumerIdLen = Math.max(maxConsumerIdLen, memberAssignment.consumerId.length) + maxGroupInstanceIdLen = Math.max(maxGroupInstanceIdLen, memberAssignment.groupInstanceId.length) + maxHostLen = Math.max(maxHostLen, memberAssignment.host.length) + maxClientIdLen = Math.max(maxClientIdLen, memberAssignment.clientId.length) + includeGroupInstanceId = includeGroupInstanceId || memberAssignment.groupInstanceId.length > 0 + } + } + + if (includeGroupInstanceId) { + print(s"\n%${-maxGroupLen}s %${-maxConsumerIdLen}s %${-maxGroupInstanceIdLen}s %${-maxHostLen}s %${-maxClientIdLen}s %-15s " + .format("GROUP", "CONSUMER-ID", "GROUP-INSTANCE-ID", "HOST", "CLIENT-ID", "#PARTITIONS")) + } else { + print(s"\n%${-maxGroupLen}s %${-maxConsumerIdLen}s %${-maxHostLen}s %${-maxClientIdLen}s %-15s " + .format("GROUP", "CONSUMER-ID", "HOST", "CLIENT-ID", "#PARTITIONS")) + } + if (verbose) + print(s"%s".format("ASSIGNMENT")) + println() + + assignments match { + case None => // do nothing + case Some(memberAssignments) => + memberAssignments.foreach { memberAssignment => + if (includeGroupInstanceId) { + print(s"%${-maxGroupLen}s %${-maxConsumerIdLen}s %${-maxGroupInstanceIdLen}s %${-maxHostLen}s %${-maxClientIdLen}s %-15s ".format( + memberAssignment.group, memberAssignment.consumerId, memberAssignment.groupInstanceId, memberAssignment.host, + memberAssignment.clientId, memberAssignment.numPartitions)) + } else { + print(s"%${-maxGroupLen}s %${-maxConsumerIdLen}s %${-maxHostLen}s %${-maxClientIdLen}s %-15s ".format( + memberAssignment.group, memberAssignment.consumerId, memberAssignment.host, memberAssignment.clientId, memberAssignment.numPartitions)) + } + if (verbose) { + val partitions = memberAssignment.assignment match { + case List() => MISSING_COLUMN_VALUE + case assignment => + assignment.groupBy(_.topic).map { + case (topic, partitionList) => topic + partitionList.map(_.partition).sorted.mkString("(", ",", ")") + }.toList.sorted.mkString(", ") + } + print(s"%s".format(partitions)) + } + println() + } + } + } + } + } + + private def printStates(states: Map[String, GroupState]): Unit = { + 
for ((groupId, state) <- states) { + if (shouldPrintMemberState(groupId, Some(state.state), Some(1))) { + val coordinator = s"${state.coordinator.host}:${state.coordinator.port} (${state.coordinator.idString})" + val coordinatorColLen = Math.max(25, coordinator.length) + print(s"\n%${-coordinatorColLen}s %-25s %-20s %-15s %s".format("GROUP", "COORDINATOR (ID)", "ASSIGNMENT-STRATEGY", "STATE", "#MEMBERS")) + print(s"\n%${-coordinatorColLen}s %-25s %-20s %-15s %s".format(state.group, coordinator, state.assignmentStrategy, state.state, state.numMembers)) + println() + } + } + } + + def describeGroups(): Unit = { + val groupIds = + if (opts.options.has(opts.allGroupsOpt)) listConsumerGroups() + else opts.options.valuesOf(opts.groupOpt).asScala + val membersOptPresent = opts.options.has(opts.membersOpt) + val stateOptPresent = opts.options.has(opts.stateOpt) + val offsetsOptPresent = opts.options.has(opts.offsetsOpt) + val subActions = Seq(membersOptPresent, offsetsOptPresent, stateOptPresent).count(_ == true) + + if (subActions == 0 || offsetsOptPresent) { + val offsets = collectGroupsOffsets(groupIds) + printOffsets(offsets) + } else if (membersOptPresent) { + val members = collectGroupsMembers(groupIds, opts.options.has(opts.verboseOpt)) + printMembers(members, opts.options.has(opts.verboseOpt)) + } else { + val states = collectGroupsState(groupIds) + printStates(states) + } + } + + private def collectConsumerAssignment(group: String, + coordinator: Option[Node], + topicPartitions: Seq[TopicPartition], + getPartitionOffset: TopicPartition => Option[Long], + consumerIdOpt: Option[String], + hostOpt: Option[String], + clientIdOpt: Option[String]): Array[PartitionAssignmentState] = { + if (topicPartitions.isEmpty) { + Array[PartitionAssignmentState]( + PartitionAssignmentState(group, coordinator, None, None, None, getLag(None, None), consumerIdOpt, hostOpt, clientIdOpt, None) + ) + } + else + describePartitions(group, coordinator, topicPartitions.sortBy(_.partition), getPartitionOffset, consumerIdOpt, hostOpt, clientIdOpt) + } + + private def getLag(offset: Option[Long], logEndOffset: Option[Long]): Option[Long] = + offset.filter(_ != -1).flatMap(offset => logEndOffset.map(_ - offset)) + + private def describePartitions(group: String, + coordinator: Option[Node], + topicPartitions: Seq[TopicPartition], + getPartitionOffset: TopicPartition => Option[Long], + consumerIdOpt: Option[String], + hostOpt: Option[String], + clientIdOpt: Option[String]): Array[PartitionAssignmentState] = { + + def getDescribePartitionResult(topicPartition: TopicPartition, logEndOffsetOpt: Option[Long]): PartitionAssignmentState = { + val offset = getPartitionOffset(topicPartition) + PartitionAssignmentState(group, coordinator, Option(topicPartition.topic), Option(topicPartition.partition), offset, + getLag(offset, logEndOffsetOpt), consumerIdOpt, hostOpt, clientIdOpt, logEndOffsetOpt) + } + + getLogEndOffsets(group, topicPartitions).map { + logEndOffsetResult => + logEndOffsetResult._2 match { + case LogOffsetResult.LogOffset(logEndOffset) => getDescribePartitionResult(logEndOffsetResult._1, Some(logEndOffset)) + case LogOffsetResult.Unknown => getDescribePartitionResult(logEndOffsetResult._1, None) + case LogOffsetResult.Ignore => null + } + }.toArray + } + + def resetOffsets(): Map[String, Map[TopicPartition, OffsetAndMetadata]] = { + val groupIds = + if (opts.options.has(opts.allGroupsOpt)) listConsumerGroups() + else opts.options.valuesOf(opts.groupOpt).asScala + + val consumerGroups = 
adminClient.describeConsumerGroups( + groupIds.asJava, + withTimeoutMs(new DescribeConsumerGroupsOptions) + ).describedGroups() + + val result = + consumerGroups.asScala.foldLeft(immutable.Map[String, Map[TopicPartition, OffsetAndMetadata]]()) { + case (acc, (groupId, groupDescription)) => + groupDescription.get.state().toString match { + case "Empty" | "Dead" => + val partitionsToReset = getPartitionsToReset(groupId) + val preparedOffsets = prepareOffsetsToReset(groupId, partitionsToReset) + + // Dry-run is the default behavior if --execute is not specified + val dryRun = opts.options.has(opts.dryRunOpt) || !opts.options.has(opts.executeOpt) + if (!dryRun) { + adminClient.alterConsumerGroupOffsets( + groupId, + preparedOffsets.asJava, + withTimeoutMs(new AlterConsumerGroupOffsetsOptions) + ).all.get + } + acc.updated(groupId, preparedOffsets) + case currentState => + printError(s"Assignments can only be reset if the group '$groupId' is inactive, but the current state is $currentState.") + acc.updated(groupId, Map.empty) + } + } + result + } + + def deleteOffsets(groupId: String, topics: List[String]): (Errors, Map[TopicPartition, Throwable]) = { + val partitionLevelResult = mutable.Map[TopicPartition, Throwable]() + + val (topicWithPartitions, topicWithoutPartitions) = topics.partition(_.contains(":")) + val knownPartitions = topicWithPartitions.flatMap(parseTopicsWithPartitions) + + // Get the partitions of topics that the user did not explicitly specify the partitions + val describeTopicsResult = adminClient.describeTopics( + topicWithoutPartitions.asJava, + withTimeoutMs(new DescribeTopicsOptions)) + + val unknownPartitions = describeTopicsResult.topicNameValues().asScala.flatMap { case (topic, future) => + Try(future.get()) match { + case Success(description) => description.partitions().asScala.map { partition => + new TopicPartition(topic, partition.partition()) + } + case Failure(e) => + partitionLevelResult += new TopicPartition(topic, -1) -> e + List.empty + } + } + + val partitions = knownPartitions ++ unknownPartitions + + val deleteResult = adminClient.deleteConsumerGroupOffsets( + groupId, + partitions.toSet.asJava, + withTimeoutMs(new DeleteConsumerGroupOffsetsOptions) + ) + + var topLevelException = Errors.NONE + Try(deleteResult.all.get) match { + case Success(_) => + case Failure(e) => topLevelException = Errors.forException(e.getCause) + } + + partitions.foreach { partition => + Try(deleteResult.partitionResult(partition).get()) match { + case Success(_) => partitionLevelResult += partition -> null + case Failure(e) => partitionLevelResult += partition -> e + } + } + + (topLevelException, partitionLevelResult) + } + + def deleteOffsets(): Unit = { + val groupId = opts.options.valueOf(opts.groupOpt) + val topics = opts.options.valuesOf(opts.topicOpt).asScala.toList + + val (topLevelResult, partitionLevelResult) = deleteOffsets(groupId, topics) + + topLevelResult match { + case Errors.NONE => + println(s"Request succeed for deleting offsets with topic ${topics.mkString(", ")} group $groupId") + case Errors.INVALID_GROUP_ID => + printError(s"'$groupId' is not valid.") + case Errors.GROUP_ID_NOT_FOUND => + printError(s"'$groupId' does not exist.") + case Errors.GROUP_AUTHORIZATION_FAILED => + printError(s"Access to '$groupId' is not authorized.") + case Errors.NON_EMPTY_GROUP => + printError(s"Deleting offsets of a consumer group '$groupId' is forbidden if the group is not empty.") + case Errors.GROUP_SUBSCRIBED_TO_TOPIC | + Errors.TOPIC_AUTHORIZATION_FAILED | + 
Errors.UNKNOWN_TOPIC_OR_PARTITION =>
+          printError(s"Encountered some partition-level errors, see the following details:")
+        case _ =>
+          printError(s"Encountered an unknown error: $topLevelResult")
+      }
+
+      println("\n%-30s %-15s %-15s".format("TOPIC", "PARTITION", "STATUS"))
+      partitionLevelResult.toList.sortBy(t => t._1.topic + t._1.partition.toString).foreach { case (tp, error) =>
+        println("%-30s %-15s %-15s".format(
+          tp.topic,
+          if (tp.partition >= 0) tp.partition else "Not Provided",
+          if (error != null) s"Error: ${error.getMessage}" else "Successful"
+        ))
+      }
+    }
+
+    private[admin] def describeConsumerGroups(groupIds: Seq[String]): mutable.Map[String, ConsumerGroupDescription] = {
+      adminClient.describeConsumerGroups(
+        groupIds.asJava,
+        withTimeoutMs(new DescribeConsumerGroupsOptions)
+      ).describedGroups().asScala.map {
+        case (groupId, groupDescriptionFuture) => (groupId, groupDescriptionFuture.get())
+      }
+    }
+
+    /**
+      * Returns the state of the specified consumer group and partition assignment states
+      */
+    def collectGroupOffsets(groupId: String): (Option[String], Option[Seq[PartitionAssignmentState]]) = {
+      collectGroupsOffsets(List(groupId)).getOrElse(groupId, (None, None))
+    }
+
+    /**
+      * Returns states of the specified consumer groups and partition assignment states
+      */
+    def collectGroupsOffsets(groupIds: Seq[String]): TreeMap[String, (Option[String], Option[Seq[PartitionAssignmentState]])] = {
+      val consumerGroups = describeConsumerGroups(groupIds)
+
+      val groupOffsets = TreeMap[String, (Option[String], Option[Seq[PartitionAssignmentState]])]() ++ (for ((groupId, consumerGroup) <- consumerGroups) yield {
+        val state = consumerGroup.state
+        val committedOffsets = getCommittedOffsets(groupId)
+        // The admin client returns `null` as a value to indicate that there is no committed offset for a partition.
+ def getPartitionOffset(tp: TopicPartition): Option[Long] = committedOffsets.get(tp).filter(_ != null).map(_.offset) + var assignedTopicPartitions = ListBuffer[TopicPartition]() + val rowsWithConsumer = consumerGroup.members.asScala.filter(!_.assignment.topicPartitions.isEmpty).toSeq + .sortWith(_.assignment.topicPartitions.size > _.assignment.topicPartitions.size).flatMap { consumerSummary => + val topicPartitions = consumerSummary.assignment.topicPartitions.asScala + assignedTopicPartitions = assignedTopicPartitions ++ topicPartitions + collectConsumerAssignment(groupId, Option(consumerGroup.coordinator), topicPartitions.toList, + getPartitionOffset, Some(s"${consumerSummary.consumerId}"), Some(s"${consumerSummary.host}"), + Some(s"${consumerSummary.clientId}")) + } + val unassignedPartitions = committedOffsets.filterNot { case (tp, _) => assignedTopicPartitions.contains(tp) } + val rowsWithoutConsumer = if (unassignedPartitions.nonEmpty) { + collectConsumerAssignment( + groupId, + Option(consumerGroup.coordinator), + unassignedPartitions.keySet.toSeq, + getPartitionOffset, + Some(MISSING_COLUMN_VALUE), + Some(MISSING_COLUMN_VALUE), + Some(MISSING_COLUMN_VALUE)).toSeq + } else + Seq.empty + + groupId -> (Some(state.toString), Some(rowsWithConsumer ++ rowsWithoutConsumer)) + }).toMap + + groupOffsets + } + + private[admin] def collectGroupMembers(groupId: String, verbose: Boolean): (Option[String], Option[Seq[MemberAssignmentState]]) = { + collectGroupsMembers(Seq(groupId), verbose)(groupId) + } + + private[admin] def collectGroupsMembers(groupIds: Seq[String], verbose: Boolean): TreeMap[String, (Option[String], Option[Seq[MemberAssignmentState]])] = { + val consumerGroups = describeConsumerGroups(groupIds) + TreeMap[String, (Option[String], Option[Seq[MemberAssignmentState]])]() ++ (for ((groupId, consumerGroup) <- consumerGroups) yield { + val state = consumerGroup.state.toString + val memberAssignmentStates = consumerGroup.members().asScala.map(consumer => + MemberAssignmentState( + groupId, + consumer.consumerId, + consumer.host, + consumer.clientId, + consumer.groupInstanceId.orElse(""), + consumer.assignment.topicPartitions.size(), + if (verbose) consumer.assignment.topicPartitions.asScala.toList else List() + )).toList + groupId -> (Some(state), Option(memberAssignmentStates)) + }).toMap + } + + private[admin] def collectGroupState(groupId: String): GroupState = { + collectGroupsState(Seq(groupId))(groupId) + } + + private[admin] def collectGroupsState(groupIds: Seq[String]): TreeMap[String, GroupState] = { + val consumerGroups = describeConsumerGroups(groupIds) + TreeMap[String, GroupState]() ++ (for ((groupId, groupDescription) <- consumerGroups) yield { + groupId -> GroupState( + groupId, + groupDescription.coordinator, + groupDescription.partitionAssignor(), + groupDescription.state.toString, + groupDescription.members().size + ) + }).toMap + } + + private def getLogEndOffsets(groupId: String, topicPartitions: Seq[TopicPartition]): Map[TopicPartition, LogOffsetResult] = { + val endOffsets = topicPartitions.map { topicPartition => + topicPartition -> OffsetSpec.latest + }.toMap + val offsets = adminClient.listOffsets( + endOffsets.asJava, + withTimeoutMs(new ListOffsetsOptions) + ).all.get + topicPartitions.map { topicPartition => + Option(offsets.get(topicPartition)) match { + case Some(listOffsetsResultInfo) => topicPartition -> LogOffsetResult.LogOffset(listOffsetsResultInfo.offset) + case _ => topicPartition -> LogOffsetResult.Unknown + } + }.toMap + } + + private def 
getLogStartOffsets(groupId: String, topicPartitions: Seq[TopicPartition]): Map[TopicPartition, LogOffsetResult] = { + val startOffsets = topicPartitions.map { topicPartition => + topicPartition -> OffsetSpec.earliest + }.toMap + val offsets = adminClient.listOffsets( + startOffsets.asJava, + withTimeoutMs(new ListOffsetsOptions) + ).all.get + topicPartitions.map { topicPartition => + Option(offsets.get(topicPartition)) match { + case Some(listOffsetsResultInfo) => topicPartition -> LogOffsetResult.LogOffset(listOffsetsResultInfo.offset) + case _ => topicPartition -> LogOffsetResult.Unknown + } + }.toMap + } + + private def getLogTimestampOffsets(groupId: String, topicPartitions: Seq[TopicPartition], timestamp: java.lang.Long): Map[TopicPartition, LogOffsetResult] = { + val timestampOffsets = topicPartitions.map { topicPartition => + topicPartition -> OffsetSpec.forTimestamp(timestamp) + }.toMap + val offsets = adminClient.listOffsets( + timestampOffsets.asJava, + withTimeoutMs(new ListOffsetsOptions) + ).all.get + val (successfulOffsetsForTimes, unsuccessfulOffsetsForTimes) = + offsets.asScala.partition(_._2.offset != ListOffsetsResponse.UNKNOWN_OFFSET) + + val successfulLogTimestampOffsets = successfulOffsetsForTimes.map { + case (topicPartition, listOffsetsResultInfo) => topicPartition -> LogOffsetResult.LogOffset(listOffsetsResultInfo.offset) + }.toMap + + unsuccessfulOffsetsForTimes.foreach { entry => + println(s"\nWarn: Partition " + entry._1.partition() + " from topic " + entry._1.topic() + + " is empty. Falling back to latest known offset.") + } + + successfulLogTimestampOffsets ++ getLogEndOffsets(groupId, unsuccessfulOffsetsForTimes.keySet.toSeq) + } + + def close(): Unit = { + adminClient.close() + } + + // Visibility for testing + protected def createAdminClient(configOverrides: Map[String, String]): Admin = { + val props = if (opts.options.has(opts.commandConfigOpt)) Utils.loadProps(opts.options.valueOf(opts.commandConfigOpt)) else new Properties() + props.put(CommonClientConfigs.BOOTSTRAP_SERVERS_CONFIG, opts.options.valueOf(opts.bootstrapServerOpt)) + configOverrides.forKeyValue { (k, v) => props.put(k, v)} + Admin.create(props) + } + + private def withTimeoutMs [T <: AbstractOptions[T]] (options : T) = { + val t = opts.options.valueOf(opts.timeoutMsOpt).intValue() + options.timeoutMs(t) + } + + private def parseTopicsWithPartitions(topicArg: String): Seq[TopicPartition] = { + def partitionNum(partition: String): Int = { + try { + partition.toInt + } catch { + case _: NumberFormatException => + throw new IllegalArgumentException(s"Invalid partition '$partition' specified in topic arg '$topicArg''") + } + } + topicArg.split(":") match { + case Array(topic, partitions) => + partitions.split(",").map(partition => new TopicPartition(topic, partitionNum(partition))) + case _ => + throw new IllegalArgumentException(s"Invalid topic arg '$topicArg', expected topic name and partitions") + } + } + + private def parseTopicPartitionsToReset(topicArgs: Seq[String]): Seq[TopicPartition] = { + val (topicsWithPartitions, topics) = topicArgs.partition(_.contains(":")) + val specifiedPartitions = topicsWithPartitions.flatMap(parseTopicsWithPartitions) + + val unspecifiedPartitions = if (topics.nonEmpty) { + val descriptionMap = adminClient.describeTopics( + topics.asJava, + withTimeoutMs(new DescribeTopicsOptions) + ).allTopicNames().get.asScala + descriptionMap.flatMap { case (topic, description) => + description.partitions().asScala.map { tpInfo => + new TopicPartition(topic, 
tpInfo.partition) + } + } + } else + Seq.empty + specifiedPartitions ++ unspecifiedPartitions + } + + private def getPartitionsToReset(groupId: String): Seq[TopicPartition] = { + if (opts.options.has(opts.allTopicsOpt)) { + getCommittedOffsets(groupId).keys.toSeq + } else if (opts.options.has(opts.topicOpt)) { + val topics = opts.options.valuesOf(opts.topicOpt).asScala + parseTopicPartitionsToReset(topics) + } else { + if (opts.options.has(opts.resetFromFileOpt)) + Nil + else + CommandLineUtils.printUsageAndDie(opts.parser, "One of the reset scopes should be defined: --all-topics, --topic.") + } + } + + private def getCommittedOffsets(groupId: String): Map[TopicPartition, OffsetAndMetadata] = { + adminClient.listConsumerGroupOffsets( + groupId, + withTimeoutMs(new ListConsumerGroupOffsetsOptions) + ).partitionsToOffsetAndMetadata.get.asScala + } + + type GroupMetadata = immutable.Map[String, immutable.Map[TopicPartition, OffsetAndMetadata]] + private def parseResetPlan(resetPlanCsv: String): GroupMetadata = { + def updateGroupMetadata(group: String, topic: String, partition: Int, offset: Long, acc: GroupMetadata) = { + val topicPartition = new TopicPartition(topic, partition) + val offsetAndMetadata = new OffsetAndMetadata(offset) + val dataMap = acc.getOrElse(group, immutable.Map()).updated(topicPartition, offsetAndMetadata) + acc.updated(group, dataMap) + } + val csvReader = CsvUtils().readerFor[CsvRecordNoGroup] + val lines = resetPlanCsv.split("\n") + val isSingleGroupQuery = opts.options.valuesOf(opts.groupOpt).size() == 1 + val isOldCsvFormat = lines.headOption.flatMap(line => + Try(csvReader.readValue[CsvRecordNoGroup](line)).toOption).nonEmpty + // Single group CSV format: "topic,partition,offset" + val dataMap = if (isSingleGroupQuery && isOldCsvFormat) { + val group = opts.options.valueOf(opts.groupOpt) + lines.foldLeft(immutable.Map[String, immutable.Map[TopicPartition, OffsetAndMetadata]]()) { (acc, line) => + val CsvRecordNoGroup(topic, partition, offset) = csvReader.readValue[CsvRecordNoGroup](line) + updateGroupMetadata(group, topic, partition, offset, acc) + } + // Multiple group CSV format: "group,topic,partition,offset" + } else { + val csvReader = CsvUtils().readerFor[CsvRecordWithGroup] + lines.foldLeft(immutable.Map[String, immutable.Map[TopicPartition, OffsetAndMetadata]]()) { (acc, line) => + val CsvRecordWithGroup(group, topic, partition, offset) = csvReader.readValue[CsvRecordWithGroup](line) + updateGroupMetadata(group, topic, partition, offset, acc) + } + } + dataMap + } + + private def prepareOffsetsToReset(groupId: String, + partitionsToReset: Seq[TopicPartition]): Map[TopicPartition, OffsetAndMetadata] = { + if (opts.options.has(opts.resetToOffsetOpt)) { + val offset = opts.options.valueOf(opts.resetToOffsetOpt) + checkOffsetsRange(groupId, partitionsToReset.map((_, offset)).toMap).map { + case (topicPartition, newOffset) => (topicPartition, new OffsetAndMetadata(newOffset)) + } + } else if (opts.options.has(opts.resetToEarliestOpt)) { + val logStartOffsets = getLogStartOffsets(groupId, partitionsToReset) + partitionsToReset.map { topicPartition => + logStartOffsets.get(topicPartition) match { + case Some(LogOffsetResult.LogOffset(offset)) => (topicPartition, new OffsetAndMetadata(offset)) + case _ => CommandLineUtils.printUsageAndDie(opts.parser, s"Error getting starting offset of topic partition: $topicPartition") + } + }.toMap + } else if (opts.options.has(opts.resetToLatestOpt)) { + val logEndOffsets = getLogEndOffsets(groupId, partitionsToReset) + 
partitionsToReset.map { topicPartition => + logEndOffsets.get(topicPartition) match { + case Some(LogOffsetResult.LogOffset(offset)) => (topicPartition, new OffsetAndMetadata(offset)) + case _ => CommandLineUtils.printUsageAndDie(opts.parser, s"Error getting ending offset of topic partition: $topicPartition") + } + }.toMap + } else if (opts.options.has(opts.resetShiftByOpt)) { + val currentCommittedOffsets = getCommittedOffsets(groupId) + val requestedOffsets = partitionsToReset.map { topicPartition => + val shiftBy = opts.options.valueOf(opts.resetShiftByOpt) + val currentOffset = currentCommittedOffsets.getOrElse(topicPartition, + throw new IllegalArgumentException(s"Cannot shift offset for partition $topicPartition since there is no current committed offset")).offset + (topicPartition, currentOffset + shiftBy) + }.toMap + checkOffsetsRange(groupId, requestedOffsets).map { + case (topicPartition, newOffset) => (topicPartition, new OffsetAndMetadata(newOffset)) + } + } else if (opts.options.has(opts.resetToDatetimeOpt)) { + val timestamp = Utils.getDateTime(opts.options.valueOf(opts.resetToDatetimeOpt)) + val logTimestampOffsets = getLogTimestampOffsets(groupId, partitionsToReset, timestamp) + partitionsToReset.map { topicPartition => + val logTimestampOffset = logTimestampOffsets.get(topicPartition) + logTimestampOffset match { + case Some(LogOffsetResult.LogOffset(offset)) => (topicPartition, new OffsetAndMetadata(offset)) + case _ => CommandLineUtils.printUsageAndDie(opts.parser, s"Error getting offset by timestamp of topic partition: $topicPartition") + } + }.toMap + } else if (opts.options.has(opts.resetByDurationOpt)) { + val duration = opts.options.valueOf(opts.resetByDurationOpt) + val durationParsed = Duration.parse(duration) + val now = Instant.now() + durationParsed.negated().addTo(now) + val timestamp = now.minus(durationParsed).toEpochMilli + val logTimestampOffsets = getLogTimestampOffsets(groupId, partitionsToReset, timestamp) + partitionsToReset.map { topicPartition => + val logTimestampOffset = logTimestampOffsets.get(topicPartition) + logTimestampOffset match { + case Some(LogOffsetResult.LogOffset(offset)) => (topicPartition, new OffsetAndMetadata(offset)) + case _ => CommandLineUtils.printUsageAndDie(opts.parser, s"Error getting offset by timestamp of topic partition: $topicPartition") + } + }.toMap + } else if (resetPlanFromFile.isDefined) { + resetPlanFromFile.map(resetPlan => resetPlan.get(groupId).map { resetPlanForGroup => + val requestedOffsets = resetPlanForGroup.keySet.map { topicPartition => + topicPartition -> resetPlanForGroup(topicPartition).offset + }.toMap + checkOffsetsRange(groupId, requestedOffsets).map { + case (topicPartition, newOffset) => (topicPartition, new OffsetAndMetadata(newOffset)) + } + } match { + case Some(resetPlanForGroup) => resetPlanForGroup + case None => + printError(s"No reset plan for group $groupId found") + Map[TopicPartition, OffsetAndMetadata]() + }).getOrElse(Map.empty) + } else if (opts.options.has(opts.resetToCurrentOpt)) { + val currentCommittedOffsets = getCommittedOffsets(groupId) + val (partitionsToResetWithCommittedOffset, partitionsToResetWithoutCommittedOffset) = + partitionsToReset.partition(currentCommittedOffsets.keySet.contains(_)) + + val preparedOffsetsForPartitionsWithCommittedOffset = partitionsToResetWithCommittedOffset.map { topicPartition => + (topicPartition, new OffsetAndMetadata(currentCommittedOffsets.get(topicPartition) match { + case Some(offset) => offset.offset + case None => throw new 
IllegalStateException(s"Expected a valid current offset for topic partition: $topicPartition") + })) + }.toMap + + val preparedOffsetsForPartitionsWithoutCommittedOffset = getLogEndOffsets(groupId, partitionsToResetWithoutCommittedOffset).map { + case (topicPartition, LogOffsetResult.LogOffset(offset)) => (topicPartition, new OffsetAndMetadata(offset)) + case (topicPartition, _) => CommandLineUtils.printUsageAndDie(opts.parser, s"Error getting ending offset of topic partition: $topicPartition") + } + + preparedOffsetsForPartitionsWithCommittedOffset ++ preparedOffsetsForPartitionsWithoutCommittedOffset + } else { + CommandLineUtils.printUsageAndDie(opts.parser, "Option '%s' requires one of the following scenarios: %s".format(opts.resetOffsetsOpt, opts.allResetOffsetScenarioOpts) ) + } + } + + private def checkOffsetsRange(groupId: String, requestedOffsets: Map[TopicPartition, Long]) = { + val logStartOffsets = getLogStartOffsets(groupId, requestedOffsets.keySet.toSeq) + val logEndOffsets = getLogEndOffsets(groupId, requestedOffsets.keySet.toSeq) + requestedOffsets.map { case (topicPartition, offset) => (topicPartition, + logEndOffsets.get(topicPartition) match { + case Some(LogOffsetResult.LogOffset(endOffset)) if offset > endOffset => + warn(s"New offset ($offset) is higher than latest offset for topic partition $topicPartition. Value will be set to $endOffset") + endOffset + + case Some(_) => logStartOffsets.get(topicPartition) match { + case Some(LogOffsetResult.LogOffset(startOffset)) if offset < startOffset => + warn(s"New offset ($offset) is lower than earliest offset for topic partition $topicPartition. Value will be set to $startOffset") + startOffset + + case _ => offset + } + + case None => // the control should not reach here + throw new IllegalStateException(s"Unexpected non-existing offset value for topic partition $topicPartition") + }) + } + } + + def exportOffsetsToCsv(assignments: Map[String, Map[TopicPartition, OffsetAndMetadata]]): String = { + val isSingleGroupQuery = opts.options.valuesOf(opts.groupOpt).size() == 1 + val csvWriter = + if (isSingleGroupQuery) CsvUtils().writerFor[CsvRecordNoGroup] + else CsvUtils().writerFor[CsvRecordWithGroup] + val rows = assignments.flatMap { case (groupId, partitionInfo) => + partitionInfo.map { case (k: TopicPartition, v: OffsetAndMetadata) => + val csvRecord = + if (isSingleGroupQuery) CsvRecordNoGroup(k.topic, k.partition, v.offset) + else CsvRecordWithGroup(groupId, k.topic, k.partition, v.offset) + csvWriter.writeValueAsString(csvRecord) + } + } + rows.mkString("") + } + + def deleteGroups(): Map[String, Throwable] = { + val groupIds = + if (opts.options.has(opts.allGroupsOpt)) listConsumerGroups() + else opts.options.valuesOf(opts.groupOpt).asScala + + val groupsToDelete = adminClient.deleteConsumerGroups( + groupIds.asJava, + withTimeoutMs(new DeleteConsumerGroupsOptions) + ).deletedGroups().asScala + + val result = groupsToDelete.map { case (g, f) => + Try(f.get) match { + case Success(_) => g -> null + case Failure(e) => g -> e + } + } + + val (success, failed) = result.partition { + case (_, error) => error == null + } + + if (failed.isEmpty) { + println(s"Deletion of requested consumer groups (${success.keySet.mkString("'", "', '", "'")}) was successful.") + } + else { + printError("Deletion of some consumer groups failed:") + failed.foreach { + case (group, error) => println(s"* Group '$group' could not be deleted due to: ${error.toString}") + } + if (success.nonEmpty) + println(s"\nThese consumer groups were deleted 
successfully: ${success.keySet.mkString("'", "', '", "'")}")
+      }
+
+      result.toMap
+    }
+  }
+
+  sealed trait LogOffsetResult
+
+  object LogOffsetResult {
+    case class LogOffset(value: Long) extends LogOffsetResult
+    case object Unknown extends LogOffsetResult
+    case object Ignore extends LogOffsetResult
+  }
+
+  class ConsumerGroupCommandOptions(args: Array[String]) extends CommandDefaultOptions(args) {
+    val BootstrapServerDoc = "REQUIRED: The server(s) to connect to."
+    val GroupDoc = "The consumer group we wish to act on."
+    val TopicDoc = "The topic whose consumer group information should be deleted, or the topic which should be included in the reset offset process. " +
+      "In `reset-offsets` case, partitions can be specified using this format: `topic1:0,1,2`, where 0,1,2 are the partitions to be included in the process. " +
+      "Reset-offsets also supports multiple topic inputs."
+    val AllTopicsDoc = "Consider all topics assigned to a group in the `reset-offsets` process."
+    val ListDoc = "List all consumer groups."
+    val DescribeDoc = "Describe consumer group and list offset lag (number of messages not yet processed) related to given group."
+    val AllGroupsDoc = "Apply to all consumer groups."
+    val nl = System.getProperty("line.separator")
+    val DeleteDoc = "Pass in groups to delete topic partition offsets and ownership information " +
+      "over the entire consumer group. For instance --group g1 --group g2"
+    val TimeoutMsDoc = "The timeout that can be set for some use cases. For example, it can be used when describing the group " +
+      "to specify the maximum amount of time in milliseconds to wait before the group stabilizes (when the group is just created, " +
+      "or is going through some changes)."
+    val CommandConfigDoc = "Property file containing configs to be passed to Admin Client and Consumer."
+    val ResetOffsetsDoc = "Reset offsets of consumer group. Supports one consumer group at a time, and instances should be inactive" + nl +
+      "Has 2 execution options: --dry-run (the default) to plan which offsets to reset, and --execute to update the offsets. " +
+      "Additionally, the --export option is used to export the results to a CSV format." + nl +
+      "You must choose one of the following reset specifications: --to-datetime, --by-duration, --to-earliest, " +
+      "--to-latest, --shift-by, --from-file, --to-current." + nl +
+      "To define the scope use --all-topics or --topic. One scope must be specified unless you use '--from-file'."
+    val DryRunDoc = "Only show results without executing changes on Consumer Groups. Supported operations: reset-offsets."
+    val ExecuteDoc = "Execute operation. Supported operations: reset-offsets."
+    val ExportDoc = "Export operation execution to a CSV file. Supported operations: reset-offsets."
+    val ResetToOffsetDoc = "Reset offsets to a specific offset."
+    val ResetFromFileDoc = "Reset offsets to values defined in CSV file."
+    val ResetToDatetimeDoc = "Reset offsets to offset from datetime. Format: 'YYYY-MM-DDTHH:mm:SS.sss'"
+    val ResetByDurationDoc = "Reset offsets to offset by duration from current timestamp. Format: 'PnDTnHnMnS'"
+    val ResetToEarliestDoc = "Reset offsets to earliest offset."
+    val ResetToLatestDoc = "Reset offsets to latest offset."
+    val ResetToCurrentDoc = "Reset offsets to current offset."
+    val ResetShiftByDoc = "Reset offsets shifting current offset by 'n', where 'n' can be positive or negative."
+    val MembersDoc = "Describe members of the group. This option may be used with '--describe' and '--bootstrap-server' options only." + nl +
+      "Example: --bootstrap-server localhost:9092 --describe --group group1 --members"
+    val VerboseDoc = "Provide additional information, if any, when describing the group. This option may be used " +
+      "with '--offsets'/'--members'/'--state' and '--bootstrap-server' options only." + nl +
+      "Example: --bootstrap-server localhost:9092 --describe --group group1 --members --verbose"
+    val OffsetsDoc = "Describe the group and list all topic partitions in the group along with their offset lag. " +
+      "This is the default sub-action and may be used with '--describe' and '--bootstrap-server' options only." + nl +
+      "Example: --bootstrap-server localhost:9092 --describe --group group1 --offsets"
+    val StateDoc = "When specified with '--describe', includes the state of the group." + nl +
+      "Example: --bootstrap-server localhost:9092 --describe --group group1 --state" + nl +
+      "When specified with '--list', it displays the state of all groups. It can also be used to list groups with specific states." + nl +
+      "Example: --bootstrap-server localhost:9092 --list --state stable,empty" + nl +
+      "This option may be used with '--describe', '--list' and '--bootstrap-server' options only."
+    val DeleteOffsetsDoc = "Delete offsets of consumer group. Supports one consumer group at a time, and multiple topics."
+
+    val bootstrapServerOpt = parser.accepts("bootstrap-server", BootstrapServerDoc)
+      .withRequiredArg
+      .describedAs("server to connect to")
+      .ofType(classOf[String])
+    val groupOpt = parser.accepts("group", GroupDoc)
+      .withRequiredArg
+      .describedAs("consumer group")
+      .ofType(classOf[String])
+    val topicOpt = parser.accepts("topic", TopicDoc)
+      .withRequiredArg
+      .describedAs("topic")
+      .ofType(classOf[String])
+    val allTopicsOpt = parser.accepts("all-topics", AllTopicsDoc)
+    val listOpt = parser.accepts("list", ListDoc)
+    val describeOpt = parser.accepts("describe", DescribeDoc)
+    val allGroupsOpt = parser.accepts("all-groups", AllGroupsDoc)
+    val deleteOpt = parser.accepts("delete", DeleteDoc)
+    val timeoutMsOpt = parser.accepts("timeout", TimeoutMsDoc)
+      .withRequiredArg
+      .describedAs("timeout (ms)")
+      .ofType(classOf[Long])
+      .defaultsTo(5000)
+    val commandConfigOpt = parser.accepts("command-config", CommandConfigDoc)
+      .withRequiredArg
+      .describedAs("command config property file")
+      .ofType(classOf[String])
+    val resetOffsetsOpt = parser.accepts("reset-offsets", ResetOffsetsDoc)
+    val deleteOffsetsOpt = parser.accepts("delete-offsets", DeleteOffsetsDoc)
+    val dryRunOpt = parser.accepts("dry-run", DryRunDoc)
+    val executeOpt = parser.accepts("execute", ExecuteDoc)
+    val exportOpt = parser.accepts("export", ExportDoc)
+    val resetToOffsetOpt = parser.accepts("to-offset", ResetToOffsetDoc)
+      .withRequiredArg()
+      .describedAs("offset")
+      .ofType(classOf[Long])
+    val resetFromFileOpt = parser.accepts("from-file", ResetFromFileDoc)
+      .withRequiredArg()
+      .describedAs("path to CSV file")
+      .ofType(classOf[String])
+    val resetToDatetimeOpt = parser.accepts("to-datetime", ResetToDatetimeDoc)
+      .withRequiredArg()
+      .describedAs("datetime")
+      .ofType(classOf[String])
+    val resetByDurationOpt = parser.accepts("by-duration", ResetByDurationDoc)
+      .withRequiredArg()
+      .describedAs("duration")
+      .ofType(classOf[String])
+    val resetToEarliestOpt = parser.accepts("to-earliest", ResetToEarliestDoc)
+    val resetToLatestOpt = parser.accepts("to-latest", ResetToLatestDoc)
+    val resetToCurrentOpt = parser.accepts("to-current", ResetToCurrentDoc)
+    val resetShiftByOpt = parser.accepts("shift-by", ResetShiftByDoc)
+      .withRequiredArg()
+      .describedAs("number-of-offsets")
+      .ofType(classOf[Long])
+    val membersOpt = parser.accepts("members", MembersDoc)
+      .availableIf(describeOpt)
+    val verboseOpt = parser.accepts("verbose", VerboseDoc)
+      .availableIf(describeOpt)
+    val offsetsOpt = parser.accepts("offsets", OffsetsDoc)
+      .availableIf(describeOpt)
+    val stateOpt = parser.accepts("state", StateDoc)
+      .availableIf(describeOpt, listOpt)
+      .withOptionalArg()
+      .ofType(classOf[String])
+
+    options = parser.parse(args : _*)
+
+    val allGroupSelectionScopeOpts = immutable.Set[OptionSpec[_]](groupOpt, allGroupsOpt)
+    val allConsumerGroupLevelOpts = immutable.Set[OptionSpec[_]](listOpt, describeOpt, deleteOpt, resetOffsetsOpt)
+    val allResetOffsetScenarioOpts = immutable.Set[OptionSpec[_]](resetToOffsetOpt, resetShiftByOpt,
+      resetToDatetimeOpt, resetByDurationOpt, resetToEarliestOpt, resetToLatestOpt, resetToCurrentOpt, resetFromFileOpt)
+    val allDeleteOffsetsOpts = immutable.Set[OptionSpec[_]](groupOpt, topicOpt)
+
+    def checkArgs(): Unit = {
+
+      CommandLineUtils.checkRequiredArgs(parser, options, bootstrapServerOpt)
+
+      if (options.has(describeOpt)) {
+        if (!options.has(groupOpt) && !options.has(allGroupsOpt))
+          CommandLineUtils.printUsageAndDie(parser,
+            s"Option $describeOpt takes one of these options: ${allGroupSelectionScopeOpts.mkString(", ")}")
+        val mutuallyExclusiveOpts: Set[OptionSpec[_]] = Set(membersOpt, offsetsOpt, stateOpt)
+        if (mutuallyExclusiveOpts.toList.map(o => if (options.has(o)) 1 else 0).sum > 1) {
+          CommandLineUtils.printUsageAndDie(parser,
+            s"Option $describeOpt takes at most one of these options: ${mutuallyExclusiveOpts.mkString(", ")}")
+        }
+        if (options.has(stateOpt) && options.valueOf(stateOpt) != null)
+          CommandLineUtils.printUsageAndDie(parser,
+            s"Option $describeOpt does not take a value for $stateOpt")
+      } else {
+        if (options.has(timeoutMsOpt))
+          debug(s"Option $timeoutMsOpt is applicable only when $describeOpt is used.")
+      }
+
+      if (options.has(deleteOpt)) {
+        if (!options.has(groupOpt) && !options.has(allGroupsOpt))
+          CommandLineUtils.printUsageAndDie(parser,
+            s"Option $deleteOpt takes one of these options: ${allGroupSelectionScopeOpts.mkString(", ")}")
+        if (options.has(topicOpt))
+          CommandLineUtils.printUsageAndDie(parser, s"The consumer does not support topic-specific offset " +
+            "deletion from a consumer group.")
+      }
+
+      if (options.has(deleteOffsetsOpt)) {
+        if (!options.has(groupOpt) || !options.has(topicOpt))
+          CommandLineUtils.printUsageAndDie(parser,
+            s"Option $deleteOffsetsOpt takes the following options: ${allDeleteOffsetsOpts.mkString(", ")}")
+      }
+
+      if (options.has(resetOffsetsOpt)) {
+        if (options.has(dryRunOpt) && options.has(executeOpt))
+          CommandLineUtils.printUsageAndDie(parser, s"Option $resetOffsetsOpt only accepts one of $executeOpt and $dryRunOpt")
+
+        if (!options.has(dryRunOpt) && !options.has(executeOpt)) {
+          Console.err.println("WARN: No action will be performed as the --execute option is missing. " +
+            "In a future major release, the default behavior of this command will be to prompt the user before " +
+            "executing the reset rather than doing a dry run. 
You should add the --dry-run option explicitly " + + "if you are scripting this command and want to keep the current default behavior without prompting.") + } + + if (!options.has(groupOpt) && !options.has(allGroupsOpt)) + CommandLineUtils.printUsageAndDie(parser, + s"Option $resetOffsetsOpt takes one of these options: ${allGroupSelectionScopeOpts.mkString(", ")}") + CommandLineUtils.checkInvalidArgs(parser, options, resetToOffsetOpt, allResetOffsetScenarioOpts - resetToOffsetOpt) + CommandLineUtils.checkInvalidArgs(parser, options, resetToDatetimeOpt, allResetOffsetScenarioOpts - resetToDatetimeOpt) + CommandLineUtils.checkInvalidArgs(parser, options, resetByDurationOpt, allResetOffsetScenarioOpts - resetByDurationOpt) + CommandLineUtils.checkInvalidArgs(parser, options, resetToEarliestOpt, allResetOffsetScenarioOpts - resetToEarliestOpt) + CommandLineUtils.checkInvalidArgs(parser, options, resetToLatestOpt, allResetOffsetScenarioOpts - resetToLatestOpt) + CommandLineUtils.checkInvalidArgs(parser, options, resetToCurrentOpt, allResetOffsetScenarioOpts - resetToCurrentOpt) + CommandLineUtils.checkInvalidArgs(parser, options, resetShiftByOpt, allResetOffsetScenarioOpts - resetShiftByOpt) + CommandLineUtils.checkInvalidArgs(parser, options, resetFromFileOpt, allResetOffsetScenarioOpts - resetFromFileOpt) + } + + CommandLineUtils.checkInvalidArgs(parser, options, groupOpt, allGroupSelectionScopeOpts - groupOpt) + CommandLineUtils.checkInvalidArgs(parser, options, groupOpt, allConsumerGroupLevelOpts - describeOpt - deleteOpt - resetOffsetsOpt) + CommandLineUtils.checkInvalidArgs(parser, options, topicOpt, allConsumerGroupLevelOpts - deleteOpt - resetOffsetsOpt) + } + } +} diff --git a/core/src/main/scala/kafka/admin/DelegationTokenCommand.scala b/core/src/main/scala/kafka/admin/DelegationTokenCommand.scala new file mode 100644 index 0000000..6465b14 --- /dev/null +++ b/core/src/main/scala/kafka/admin/DelegationTokenCommand.scala @@ -0,0 +1,219 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package kafka.admin + +import java.text.SimpleDateFormat +import java.util +import java.util.Base64 + +import joptsimple.ArgumentAcceptingOptionSpec +import kafka.utils.{CommandDefaultOptions, CommandLineUtils, Exit, Logging} +import org.apache.kafka.clients.CommonClientConfigs +import org.apache.kafka.clients.admin.{Admin, CreateDelegationTokenOptions, DescribeDelegationTokenOptions, ExpireDelegationTokenOptions, RenewDelegationTokenOptions} +import org.apache.kafka.common.security.auth.KafkaPrincipal +import org.apache.kafka.common.security.token.delegation.DelegationToken +import org.apache.kafka.common.utils.{SecurityUtils, Utils} + +import scala.jdk.CollectionConverters._ +import scala.collection.Set + +/** + * A command to manage delegation token. + */ +object DelegationTokenCommand extends Logging { + + def main(args: Array[String]): Unit = { + val opts = new DelegationTokenCommandOptions(args) + + CommandLineUtils.printHelpAndExitIfNeeded(opts, "This tool helps to create, renew, expire, or describe delegation tokens.") + + // should have exactly one action + val actions = Seq(opts.createOpt, opts.renewOpt, opts.expiryOpt, opts.describeOpt).count(opts.options.has _) + if(actions != 1) + CommandLineUtils.printUsageAndDie(opts.parser, "Command must include exactly one action: --create, --renew, --expire or --describe") + + opts.checkArgs() + + val adminClient = createAdminClient(opts) + + var exitCode = 0 + try { + if(opts.options.has(opts.createOpt)) + createToken(adminClient, opts) + else if(opts.options.has(opts.renewOpt)) + renewToken(adminClient, opts) + else if(opts.options.has(opts.expiryOpt)) + expireToken(adminClient, opts) + else if(opts.options.has(opts.describeOpt)) + describeToken(adminClient, opts) + } catch { + case e: Throwable => + println("Error while executing delegation token command : " + e.getMessage) + error(Utils.stackTrace(e)) + exitCode = 1 + } finally { + adminClient.close() + Exit.exit(exitCode) + } + } + + def createToken(adminClient: Admin, opts: DelegationTokenCommandOptions): DelegationToken = { + val renewerPrincipals = getPrincipals(opts, opts.renewPrincipalsOpt).getOrElse(new util.LinkedList[KafkaPrincipal]()) + val maxLifeTimeMs = opts.options.valueOf(opts.maxLifeTimeOpt).longValue + + println("Calling create token operation with renewers :" + renewerPrincipals +" , max-life-time-period :"+ maxLifeTimeMs) + val createDelegationTokenOptions = new CreateDelegationTokenOptions().maxlifeTimeMs(maxLifeTimeMs).renewers(renewerPrincipals) + val createResult = adminClient.createDelegationToken(createDelegationTokenOptions) + val token = createResult.delegationToken().get() + println("Created delegation token with tokenId : %s".format(token.tokenInfo.tokenId)); printToken(List(token)) + token + } + + def printToken(tokens: List[DelegationToken]): Unit = { + val dateFormat = new SimpleDateFormat("yyyy-MM-dd'T'HH:mm") + print("\n%-15s %-30s %-15s %-25s %-15s %-15s %-15s".format("TOKENID", "HMAC", "OWNER", "RENEWERS", "ISSUEDATE", "EXPIRYDATE", "MAXDATE")) + for (token <- tokens) { + val tokenInfo = token.tokenInfo + print("\n%-15s %-30s %-15s %-25s %-15s %-15s %-15s".format( + tokenInfo.tokenId, + token.hmacAsBase64String, + tokenInfo.owner, + tokenInfo.renewersAsString, + dateFormat.format(tokenInfo.issueTimestamp), + dateFormat.format(tokenInfo.expiryTimestamp), + dateFormat.format(tokenInfo.maxTimestamp))) + println() + } + } + + private def getPrincipals(opts: DelegationTokenCommandOptions, principalOptionSpec: ArgumentAcceptingOptionSpec[String]): 
Option[util.List[KafkaPrincipal]] = { + if (opts.options.has(principalOptionSpec)) + Some(opts.options.valuesOf(principalOptionSpec).asScala.map(s => SecurityUtils.parseKafkaPrincipal(s.trim)).toList.asJava) + else + None + } + + def renewToken(adminClient: Admin, opts: DelegationTokenCommandOptions): Long = { + val hmac = opts.options.valueOf(opts.hmacOpt) + val renewTimePeriodMs = opts.options.valueOf(opts.renewTimePeriodOpt).longValue() + println("Calling renew token operation with hmac :" + hmac +" , renew-time-period :"+ renewTimePeriodMs) + val renewResult = adminClient.renewDelegationToken(Base64.getDecoder.decode(hmac), new RenewDelegationTokenOptions().renewTimePeriodMs(renewTimePeriodMs)) + val expiryTimeStamp = renewResult.expiryTimestamp().get() + val dateFormat = new SimpleDateFormat("yyyy-MM-dd'T'HH:mm") + println("Completed renew operation. New expiry date : %s".format(dateFormat.format(expiryTimeStamp))) + expiryTimeStamp + } + + def expireToken(adminClient: Admin, opts: DelegationTokenCommandOptions): Long = { + val hmac = opts.options.valueOf(opts.hmacOpt) + val expiryTimePeriodMs = opts.options.valueOf(opts.expiryTimePeriodOpt).longValue() + println("Calling expire token operation with hmac :" + hmac +" , expire-time-period : "+ expiryTimePeriodMs) + val expireResult = adminClient.expireDelegationToken(Base64.getDecoder.decode(hmac), new ExpireDelegationTokenOptions().expiryTimePeriodMs(expiryTimePeriodMs)) + val expiryTimeStamp = expireResult.expiryTimestamp().get() + val dateFormat = new SimpleDateFormat("yyyy-MM-dd'T'HH:mm") + println("Completed expire operation. New expiry date : %s".format(dateFormat.format(expiryTimeStamp))) + expiryTimeStamp + } + + def describeToken(adminClient: Admin, opts: DelegationTokenCommandOptions): List[DelegationToken] = { + val ownerPrincipals = getPrincipals(opts, opts.ownerPrincipalsOpt) + if (ownerPrincipals.isEmpty) + println("Calling describe token operation for current user.") + else + println("Calling describe token operation for owners :" + ownerPrincipals.get) + + val describeResult = adminClient.describeDelegationToken(new DescribeDelegationTokenOptions().owners(ownerPrincipals.orNull)) + val tokens = describeResult.delegationTokens().get().asScala.toList + println("Total number of tokens : %s".format(tokens.size)); printToken(tokens) + tokens + } + + private def createAdminClient(opts: DelegationTokenCommandOptions): Admin = { + val props = Utils.loadProps(opts.options.valueOf(opts.commandConfigOpt)) + props.put(CommonClientConfigs.BOOTSTRAP_SERVERS_CONFIG, opts.options.valueOf(opts.bootstrapServerOpt)) + Admin.create(props) + } + + class DelegationTokenCommandOptions(args: Array[String]) extends CommandDefaultOptions(args) { + val BootstrapServerDoc = "REQUIRED: server(s) to use for bootstrapping." + val CommandConfigDoc = "REQUIRED: A property file containing configs to be passed to Admin Client. Token management" + + " operations are allowed in secure mode only. This config file is used to pass security related configs." + + val bootstrapServerOpt = parser.accepts("bootstrap-server", BootstrapServerDoc) + .withRequiredArg + .ofType(classOf[String]) + val commandConfigOpt = parser.accepts("command-config", CommandConfigDoc) + .withRequiredArg + .ofType(classOf[String]) + + val createOpt = parser.accepts("create", "Create a new delegation token. Use --renewer-principal option to pass renewers principals.") + val renewOpt = parser.accepts("renew", "Renew delegation token. 
Use --renew-time-period option to set renew time period.")
+    val expiryOpt = parser.accepts("expire", "Expire delegation token. Use --expiry-time-period option to expire the token.")
+    val describeOpt = parser.accepts("describe", "Describe delegation tokens for the given principals. Use --owner-principal to pass owner/renewer principals." +
+      " If --owner-principal option is not supplied, all the tokens owned by the user and tokens where the user has Describe permission will be returned.")
+
+    val ownerPrincipalsOpt = parser.accepts("owner-principal", "owner is a kafka principal. It should be in principalType:name format.")
+      .withOptionalArg()
+      .ofType(classOf[String])
+
+    val renewPrincipalsOpt = parser.accepts("renewer-principal", "renewer is a kafka principal. It should be in principalType:name format.")
+      .withOptionalArg()
+      .ofType(classOf[String])
+
+    val maxLifeTimeOpt = parser.accepts("max-life-time-period", "Max life period for the token in milliseconds. If the value is -1," +
+      " then token max life time will default to a server side config value (delegation.token.max.lifetime.ms).")
+      .withOptionalArg()
+      .ofType(classOf[Long])
+
+    val renewTimePeriodOpt = parser.accepts("renew-time-period", "Renew time period in milliseconds. If the value is -1, then the" +
+      " renew time period will default to a server side config value (delegation.token.expiry.time.ms).")
+      .withOptionalArg()
+      .ofType(classOf[Long])
+
+    val expiryTimePeriodOpt = parser.accepts("expiry-time-period", "Expiry time period in milliseconds. If the value is -1, then the" +
+      " token will get invalidated immediately.")
+      .withOptionalArg()
+      .ofType(classOf[Long])
+
+    val hmacOpt = parser.accepts("hmac", "HMAC of the delegation token")
+      .withOptionalArg
+      .ofType(classOf[String])
+
+    options = parser.parse(args : _*)
+
+    def checkArgs(): Unit = {
+      // check required args
+      CommandLineUtils.checkRequiredArgs(parser, options, bootstrapServerOpt, commandConfigOpt)
+
+      if (options.has(createOpt))
+        CommandLineUtils.checkRequiredArgs(parser, options, maxLifeTimeOpt)
+
+      if (options.has(renewOpt))
+        CommandLineUtils.checkRequiredArgs(parser, options, hmacOpt, renewTimePeriodOpt)
+
+      if (options.has(expiryOpt))
+        CommandLineUtils.checkRequiredArgs(parser, options, hmacOpt, expiryTimePeriodOpt)
+
+      // check invalid args
+      CommandLineUtils.checkInvalidArgs(parser, options, createOpt, Set(hmacOpt, renewTimePeriodOpt, expiryTimePeriodOpt, ownerPrincipalsOpt))
+      CommandLineUtils.checkInvalidArgs(parser, options, renewOpt, Set(renewPrincipalsOpt, maxLifeTimeOpt, expiryTimePeriodOpt, ownerPrincipalsOpt))
+      CommandLineUtils.checkInvalidArgs(parser, options, expiryOpt, Set(renewOpt, maxLifeTimeOpt, renewTimePeriodOpt, ownerPrincipalsOpt))
+      CommandLineUtils.checkInvalidArgs(parser, options, describeOpt, Set(renewTimePeriodOpt, maxLifeTimeOpt, hmacOpt, expiryTimePeriodOpt))
+    }
+  }
+}
diff --git a/core/src/main/scala/kafka/admin/DeleteRecordsCommand.scala b/core/src/main/scala/kafka/admin/DeleteRecordsCommand.scala
new file mode 100644
index 0000000..71ef6fd
--- /dev/null
+++ b/core/src/main/scala/kafka/admin/DeleteRecordsCommand.scala
@@ -0,0 +1,137 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.admin + +import java.io.PrintStream +import java.util.Properties + +import kafka.common.AdminCommandFailedException +import kafka.utils.json.JsonValue +import kafka.utils.{CommandDefaultOptions, CommandLineUtils, CoreUtils, Json} +import org.apache.kafka.clients.admin.{Admin, RecordsToDelete} +import org.apache.kafka.clients.CommonClientConfigs +import org.apache.kafka.common.TopicPartition +import org.apache.kafka.common.utils.Utils + +import scala.jdk.CollectionConverters._ +import scala.collection.Seq + +/** + * A command for delete records of the given partitions down to the specified offset. + */ +object DeleteRecordsCommand { + + private[admin] val EarliestVersion = 1 + + def main(args: Array[String]): Unit = { + execute(args, System.out) + } + + def parseOffsetJsonStringWithoutDedup(jsonData: String): Seq[(TopicPartition, Long)] = { + Json.parseFull(jsonData) match { + case Some(js) => + val version = js.asJsonObject.get("version") match { + case Some(jsonValue) => jsonValue.to[Int] + case None => EarliestVersion + } + parseJsonData(version, js) + case None => throw new AdminOperationException("The input string is not a valid JSON") + } + } + + def parseJsonData(version: Int, js: JsonValue): Seq[(TopicPartition, Long)] = { + version match { + case 1 => + js.asJsonObject.get("partitions") match { + case Some(partitions) => + partitions.asJsonArray.iterator.map(_.asJsonObject).map { partitionJs => + val topic = partitionJs("topic").to[String] + val partition = partitionJs("partition").to[Int] + val offset = partitionJs("offset").to[Long] + new TopicPartition(topic, partition) -> offset + }.toBuffer + case _ => throw new AdminOperationException("Missing partitions field"); + } + case _ => throw new AdminOperationException(s"Not supported version field value $version") + } + } + + def execute(args: Array[String], out: PrintStream): Unit = { + val opts = new DeleteRecordsCommandOptions(args) + val adminClient = createAdminClient(opts) + val offsetJsonFile = opts.options.valueOf(opts.offsetJsonFileOpt) + val offsetJsonString = Utils.readFileAsString(offsetJsonFile) + val offsetSeq = parseOffsetJsonStringWithoutDedup(offsetJsonString) + + val duplicatePartitions = CoreUtils.duplicates(offsetSeq.map { case (tp, _) => tp }) + if (duplicatePartitions.nonEmpty) + throw new AdminCommandFailedException("Offset json file contains duplicate topic partitions: %s".format(duplicatePartitions.mkString(","))) + + val recordsToDelete = offsetSeq.map { case (topicPartition, offset) => + (topicPartition, RecordsToDelete.beforeOffset(offset)) + }.toMap.asJava + + out.println("Executing records delete operation") + val deleteRecordsResult = adminClient.deleteRecords(recordsToDelete) + out.println("Records delete operation completed:") + + deleteRecordsResult.lowWatermarks.forEach { (tp, partitionResult) => + try out.println(s"partition: $tp\tlow_watermark: ${partitionResult.get.lowWatermark}") + catch { + case 
e: Exception => out.println(s"partition: $tp\terror: ${e.getMessage}") + } + } + + adminClient.close() + } + + private def createAdminClient(opts: DeleteRecordsCommandOptions): Admin = { + val props = if (opts.options.has(opts.commandConfigOpt)) + Utils.loadProps(opts.options.valueOf(opts.commandConfigOpt)) + else + new Properties() + props.put(CommonClientConfigs.BOOTSTRAP_SERVERS_CONFIG, opts.options.valueOf(opts.bootstrapServerOpt)) + Admin.create(props) + } + + class DeleteRecordsCommandOptions(args: Array[String]) extends CommandDefaultOptions(args) { + val BootstrapServerDoc = "REQUIRED: The server to connect to." + val offsetJsonFileDoc = "REQUIRED: The JSON file with offset per partition. The format to use is:\n" + + "{\"partitions\":\n [{\"topic\": \"foo\", \"partition\": 1, \"offset\": 1}],\n \"version\":1\n}" + val CommandConfigDoc = "A property file containing configs to be passed to Admin Client." + + val bootstrapServerOpt = parser.accepts("bootstrap-server", BootstrapServerDoc) + .withRequiredArg + .describedAs("server(s) to use for bootstrapping") + .ofType(classOf[String]) + val offsetJsonFileOpt = parser.accepts("offset-json-file", offsetJsonFileDoc) + .withRequiredArg + .describedAs("Offset json file path") + .ofType(classOf[String]) + val commandConfigOpt = parser.accepts("command-config", CommandConfigDoc) + .withRequiredArg + .describedAs("command config property file path") + .ofType(classOf[String]) + + options = parser.parse(args : _*) + + CommandLineUtils.printHelpAndExitIfNeeded(this, "This tool helps to delete records of the given partitions down to the specified offset.") + + CommandLineUtils.checkRequiredArgs(parser, options, bootstrapServerOpt, offsetJsonFileOpt) + } +} diff --git a/core/src/main/scala/kafka/admin/FeatureCommand.scala b/core/src/main/scala/kafka/admin/FeatureCommand.scala new file mode 100644 index 0000000..4b29965 --- /dev/null +++ b/core/src/main/scala/kafka/admin/FeatureCommand.scala @@ -0,0 +1,390 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package kafka.admin + +import kafka.server.BrokerFeatures +import kafka.utils.{CommandDefaultOptions, CommandLineUtils, Exit} +import org.apache.kafka.clients.CommonClientConfigs +import org.apache.kafka.clients.admin.{Admin, FeatureUpdate, UpdateFeaturesOptions} +import org.apache.kafka.common.feature.{Features, SupportedVersionRange} +import org.apache.kafka.common.utils.Utils +import java.util.Properties + +import scala.collection.Seq +import scala.collection.immutable.ListMap +import scala.jdk.CollectionConverters._ +import joptsimple.OptionSpec + +import scala.concurrent.ExecutionException + +object FeatureCommand { + + def main(args: Array[String]): Unit = { + val opts = new FeatureCommandOptions(args) + val featureApis = new FeatureApis(opts) + var exitCode = 0 + try { + featureApis.execute() + } catch { + case e: IllegalArgumentException => + printException(e) + opts.parser.printHelpOn(System.err) + exitCode = 1 + case _: UpdateFeaturesException => + exitCode = 1 + case e: ExecutionException => + val cause = if (e.getCause == null) e else e.getCause + printException(cause) + exitCode = 1 + case e: Throwable => + printException(e) + exitCode = 1 + } finally { + featureApis.close() + Exit.exit(exitCode) + } + } + + private def printException(exception: Throwable): Unit = { + System.err.println("\nError encountered when executing command: " + Utils.stackTrace(exception)) + } +} + +class UpdateFeaturesException(message: String) extends RuntimeException(message) + +/** + * A class that provides necessary APIs to bridge feature APIs provided by the Admin client with + * the requirements of the CLI tool. + * + * @param opts the CLI options + */ +class FeatureApis(private var opts: FeatureCommandOptions) { + private var supportedFeatures = BrokerFeatures.createDefault().supportedFeatures + private var adminClient = FeatureApis.createAdminClient(opts) + + private def pad(op: String): String = { + f"$op%11s" + } + + private val addOp = pad("[Add]") + private val upgradeOp = pad("[Upgrade]") + private val deleteOp = pad("[Delete]") + private val downgradeOp = pad("[Downgrade]") + + // For testing only. + private[admin] def setSupportedFeatures(newFeatures: Features[SupportedVersionRange]): Unit = { + supportedFeatures = newFeatures + } + + // For testing only. + private[admin] def setOptions(newOpts: FeatureCommandOptions): Unit = { + adminClient.close() + adminClient = FeatureApis.createAdminClient(newOpts) + opts = newOpts + } + + /** + * Describes the supported and finalized features. The request is issued to any of the provided + * bootstrap servers. 
+ */ + def describeFeatures(): Unit = { + val result = adminClient.describeFeatures.featureMetadata.get + val features = result.supportedFeatures.asScala.keys.toSet ++ result.finalizedFeatures.asScala.keys.toSet + + features.toList.sorted.foreach { + feature => + val output = new StringBuilder() + output.append(s"Feature: $feature") + + val (supportedMinVersion, supportedMaxVersion) = { + val supportedVersionRange = result.supportedFeatures.get(feature) + if (supportedVersionRange == null) { + ("-", "-") + } else { + (supportedVersionRange.minVersion, supportedVersionRange.maxVersion) + } + } + output.append(s"\tSupportedMinVersion: $supportedMinVersion") + output.append(s"\tSupportedMaxVersion: $supportedMaxVersion") + + val (finalizedMinVersionLevel, finalizedMaxVersionLevel) = { + val finalizedVersionRange = result.finalizedFeatures.get(feature) + if (finalizedVersionRange == null) { + ("-", "-") + } else { + (finalizedVersionRange.minVersionLevel, finalizedVersionRange.maxVersionLevel) + } + } + output.append(s"\tFinalizedMinVersionLevel: $finalizedMinVersionLevel") + output.append(s"\tFinalizedMaxVersionLevel: $finalizedMaxVersionLevel") + + val epoch = { + if (result.finalizedFeaturesEpoch.isPresent) { + result.finalizedFeaturesEpoch.get.toString + } else { + "-" + } + } + output.append(s"\tEpoch: $epoch") + + println(output) + } + } + + /** + * Upgrades all features known to this tool to their highest max version levels. The method may + * add new finalized features if they were not finalized previously, but it does not delete + * any existing finalized feature. The results of the feature updates are written to STDOUT. + * + * NOTE: if the --dry-run CLI option is provided, this method only prints the expected feature + * updates to STDOUT, without applying them. + * + * @throws UpdateFeaturesException if at least one of the feature updates failed + */ + def upgradeAllFeatures(): Unit = { + val metadata = adminClient.describeFeatures.featureMetadata.get + val existingFinalizedFeatures = metadata.finalizedFeatures + val updates = supportedFeatures.features.asScala.map { + case (feature, targetVersionRange) => + val existingVersionRange = existingFinalizedFeatures.get(feature) + if (existingVersionRange == null) { + val updateStr = + addOp + + s"\tFeature: $feature" + + s"\tExistingFinalizedMaxVersion: -" + + s"\tNewFinalizedMaxVersion: ${targetVersionRange.max}" + (feature, Some((updateStr, new FeatureUpdate(targetVersionRange.max, false)))) + } else { + if (targetVersionRange.max > existingVersionRange.maxVersionLevel) { + val updateStr = + upgradeOp + + s"\tFeature: $feature" + + s"\tExistingFinalizedMaxVersion: ${existingVersionRange.maxVersionLevel}" + + s"\tNewFinalizedMaxVersion: ${targetVersionRange.max}" + (feature, Some((updateStr, new FeatureUpdate(targetVersionRange.max, false)))) + } else { + (feature, Option.empty) + } + } + }.filter { + case(_, updateInfo) => updateInfo.isDefined + }.map { + case(feature, updateInfo) => (feature, updateInfo.get) + }.toMap + + if (updates.nonEmpty) { + maybeApplyFeatureUpdates(updates) + } + } + + /** + * Downgrades existing finalized features to the highest max version levels known to this tool. + * The method may delete existing finalized features if they are no longer seen to be supported, + * but it does not add a feature that was not finalized previously. The results of the feature + * updates are written to STDOUT. 
+ * + * NOTE: if the --dry-run CLI option is provided, this method only prints the expected feature + * updates to STDOUT, without applying them. + * + * @throws UpdateFeaturesException if at least one of the feature updates failed + */ + def downgradeAllFeatures(): Unit = { + val metadata = adminClient.describeFeatures.featureMetadata.get + val existingFinalizedFeatures = metadata.finalizedFeatures + val supportedFeaturesMap = supportedFeatures.features + val updates = existingFinalizedFeatures.asScala.map { + case (feature, existingVersionRange) => + val targetVersionRange = supportedFeaturesMap.get(feature) + if (targetVersionRange == null) { + val updateStr = + deleteOp + + s"\tFeature: $feature" + + s"\tExistingFinalizedMaxVersion: ${existingVersionRange.maxVersionLevel}" + + s"\tNewFinalizedMaxVersion: -" + (feature, Some(updateStr, new FeatureUpdate(0, true))) + } else { + if (targetVersionRange.max < existingVersionRange.maxVersionLevel) { + val updateStr = + downgradeOp + + s"\tFeature: $feature" + + s"\tExistingFinalizedMaxVersion: ${existingVersionRange.maxVersionLevel}" + + s"\tNewFinalizedMaxVersion: ${targetVersionRange.max}" + (feature, Some(updateStr, new FeatureUpdate(targetVersionRange.max, true))) + } else { + (feature, Option.empty) + } + } + }.filter { + case(_, updateInfo) => updateInfo.isDefined + }.map { + case(feature, updateInfo) => (feature, updateInfo.get) + }.toMap + + if (updates.nonEmpty) { + maybeApplyFeatureUpdates(updates) + } + } + + /** + * Applies the provided feature updates. If the --dry-run CLI option is provided, the method + * only prints the expected feature updates to STDOUT without applying them. + * + * @param updates the feature updates to be applied via the admin client + * + * @throws UpdateFeaturesException if at least one of the feature updates failed + */ + private def maybeApplyFeatureUpdates(updates: Map[String, (String, FeatureUpdate)]): Unit = { + if (opts.hasDryRunOption) { + println("Expected feature updates:" + ListMap( + updates + .toSeq + .sortBy { case(feature, _) => feature} :_*) + .map { case(_, (updateStr, _)) => updateStr} + .mkString("\n")) + } else { + val result = adminClient.updateFeatures( + updates + .map { case(feature, (_, update)) => (feature, update)} + .asJava, + new UpdateFeaturesOptions()) + val resultSortedByFeature = ListMap( + result + .values + .asScala + .toSeq + .sortBy { case(feature, _) => feature} :_*) + val failures = resultSortedByFeature.map { + case (feature, updateFuture) => + val (updateStr, _) = updates(feature) + try { + updateFuture.get + println(updateStr + "\tResult: OK") + 0 + } catch { + case e: ExecutionException => + val cause = if (e.getCause == null) e else e.getCause + println(updateStr + "\tResult: FAILED due to " + cause) + 1 + case e: Throwable => + println(updateStr + "\tResult: FAILED due to " + e) + 1 + } + }.sum + if (failures > 0) { + throw new UpdateFeaturesException(s"$failures feature updates failed!") + } + } + } + + def execute(): Unit = { + if (opts.hasDescribeOption) { + describeFeatures() + } else if (opts.hasUpgradeAllOption) { + upgradeAllFeatures() + } else if (opts.hasDowngradeAllOption) { + downgradeAllFeatures() + } else { + throw new IllegalStateException("Unexpected state: no CLI command could be executed.") + } + } + + def close(): Unit = { + adminClient.close() + } +} + +class FeatureCommandOptions(args: Array[String]) extends CommandDefaultOptions(args) { + private val bootstrapServerOpt = parser.accepts( + "bootstrap-server", + "REQUIRED: A comma-separated 
list of host:port pairs to use for establishing the connection" + + " to the Kafka cluster.") + .withRequiredArg + .describedAs("server to connect to") + .ofType(classOf[String]) + private val commandConfigOpt = parser.accepts( + "command-config", + "Property file containing configs to be passed to Admin Client." + + " This is used with --bootstrap-server option when required.") + .withOptionalArg + .describedAs("command config property file") + .ofType(classOf[String]) + private val describeOpt = parser.accepts( + "describe", + "Describe supported and finalized features from a random broker.") + private val upgradeAllOpt = parser.accepts( + "upgrade-all", + "Upgrades all finalized features to the maximum version levels known to the tool." + + " This command finalizes new features known to the tool that were never finalized" + + " previously in the cluster, but it is guaranteed to not delete any existing feature.") + private val downgradeAllOpt = parser.accepts( + "downgrade-all", + "Downgrades all finalized features to the maximum version levels known to the tool." + + " This command deletes unknown features from the list of finalized features in the" + + " cluster, but it is guaranteed to not add a new feature.") + private val dryRunOpt = parser.accepts( + "dry-run", + "Performs a dry-run of upgrade/downgrade mutations to finalized feature without applying them.") + + options = parser.parse(args : _*) + + checkArgs() + + def has(builder: OptionSpec[_]): Boolean = options.has(builder) + + def hasDescribeOption: Boolean = has(describeOpt) + + def hasDryRunOption: Boolean = has(dryRunOpt) + + def hasUpgradeAllOption: Boolean = has(upgradeAllOpt) + + def hasDowngradeAllOption: Boolean = has(downgradeAllOpt) + + def commandConfig: Properties = { + if (has(commandConfigOpt)) + Utils.loadProps(options.valueOf(commandConfigOpt)) + else + new Properties() + } + + def bootstrapServers: String = options.valueOf(bootstrapServerOpt) + + def checkArgs(): Unit = { + CommandLineUtils.printHelpAndExitIfNeeded(this, "This tool describes and updates finalized features.") + val numActions = Seq(describeOpt, upgradeAllOpt, downgradeAllOpt).count(has) + if (numActions != 1) { + CommandLineUtils.printUsageAndDie( + parser, + "Command must include exactly one action: --describe, --upgrade-all, --downgrade-all.") + } + CommandLineUtils.checkRequiredArgs(parser, options, bootstrapServerOpt) + if (hasDryRunOption && !hasUpgradeAllOption && !hasDowngradeAllOption) { + CommandLineUtils.printUsageAndDie( + parser, + "Command can contain --dry-run option only when either --upgrade-all or --downgrade-all actions are provided.") + } + } +} + +object FeatureApis { + private def createAdminClient(opts: FeatureCommandOptions): Admin = { + val props = new Properties() + props.putAll(opts.commandConfig) + props.put(CommonClientConfigs.BOOTSTRAP_SERVERS_CONFIG, opts.bootstrapServers) + Admin.create(props) + } +} diff --git a/core/src/main/scala/kafka/admin/LeaderElectionCommand.scala b/core/src/main/scala/kafka/admin/LeaderElectionCommand.scala new file mode 100644 index 0000000..92edcad --- /dev/null +++ b/core/src/main/scala/kafka/admin/LeaderElectionCommand.scala @@ -0,0 +1,289 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package kafka.admin + +import java.util.Properties +import java.util.concurrent.ExecutionException +import joptsimple.util.EnumConverter +import kafka.common.AdminCommandFailedException +import kafka.utils.CommandDefaultOptions +import kafka.utils.CommandLineUtils +import kafka.utils.CoreUtils +import kafka.utils.Implicits._ +import kafka.utils.Json +import kafka.utils.Logging +import org.apache.kafka.clients.admin.{Admin, AdminClientConfig} +import org.apache.kafka.common.ElectionType +import org.apache.kafka.common.TopicPartition +import org.apache.kafka.common.errors.ClusterAuthorizationException +import org.apache.kafka.common.errors.ElectionNotNeededException +import org.apache.kafka.common.errors.TimeoutException +import org.apache.kafka.common.utils.Utils +import scala.jdk.CollectionConverters._ +import scala.collection.mutable +import scala.concurrent.duration._ + +object LeaderElectionCommand extends Logging { + def main(args: Array[String]): Unit = { + run(args, 30.second) + } + + def run(args: Array[String], timeout: Duration): Unit = { + val commandOptions = new LeaderElectionCommandOptions(args) + CommandLineUtils.printHelpAndExitIfNeeded( + commandOptions, + "This tool attempts to elect a new leader for a set of topic partitions. The type of elections supported are preferred replicas and unclean replicas." + ) + + validate(commandOptions) + + val electionType = commandOptions.options.valueOf(commandOptions.electionType) + + val jsonFileTopicPartitions = Option(commandOptions.options.valueOf(commandOptions.pathToJsonFile)).map { path => + parseReplicaElectionData(Utils.readFileAsString(path)) + } + + val singleTopicPartition = ( + Option(commandOptions.options.valueOf(commandOptions.topic)), + Option(commandOptions.options.valueOf(commandOptions.partition)) + ) match { + case (Some(topic), Some(partition)) => Some(Set(new TopicPartition(topic, partition))) + case _ => None + } + + /* Note: No need to look at --all-topic-partitions as we want this to be None if it is use. + * The validate function should be checking that this option is required if the --topic and --path-to-json-file + * are not specified. 
+ */ + val topicPartitions = jsonFileTopicPartitions.orElse(singleTopicPartition) + + val adminClient = { + val props = Option(commandOptions.options.valueOf(commandOptions.adminClientConfig)).map { config => + Utils.loadProps(config) + }.getOrElse(new Properties()) + + props.setProperty( + AdminClientConfig.BOOTSTRAP_SERVERS_CONFIG, + commandOptions.options.valueOf(commandOptions.bootstrapServer) + ) + props.setProperty(AdminClientConfig.DEFAULT_API_TIMEOUT_MS_CONFIG, timeout.toMillis.toString) + props.setProperty(AdminClientConfig.REQUEST_TIMEOUT_MS_CONFIG, (timeout.toMillis / 2).toString) + + Admin.create(props) + } + + try { + electLeaders(adminClient, electionType, topicPartitions) + } finally { + adminClient.close() + } + } + + private[this] def parseReplicaElectionData(jsonString: String): Set[TopicPartition] = { + Json.parseFull(jsonString) match { + case Some(js) => + js.asJsonObject.get("partitions") match { + case Some(partitionsList) => + val partitionsRaw = partitionsList.asJsonArray.iterator.map(_.asJsonObject) + val partitions = partitionsRaw.map { p => + val topic = p("topic").to[String] + val partition = p("partition").to[Int] + new TopicPartition(topic, partition) + }.toBuffer + val duplicatePartitions = CoreUtils.duplicates(partitions) + if (duplicatePartitions.nonEmpty) { + throw new AdminOperationException( + s"Replica election data contains duplicate partitions: ${duplicatePartitions.mkString(",")}" + ) + } + partitions.toSet + case None => throw new AdminOperationException("Replica election data is missing \"partitions\" field") + } + case None => throw new AdminOperationException("Replica election data is empty") + } + } + + private[this] def electLeaders( + client: Admin, + electionType: ElectionType, + topicPartitions: Option[Set[TopicPartition]] + ): Unit = { + val electionResults = try { + val partitions = topicPartitions.map(_.asJava).orNull + debug(s"Calling AdminClient.electLeaders($electionType, $partitions)") + client.electLeaders(electionType, partitions).partitions.get.asScala + } catch { + case e: ExecutionException => + e.getCause match { + case cause: TimeoutException => + val message = "Timeout waiting for election results" + println(message) + throw new AdminCommandFailedException(message, cause) + case cause: ClusterAuthorizationException => + val message = "Not authorized to perform leader election" + println(message) + throw new AdminCommandFailedException(message, cause) + case _ => + throw e + } + case e: Throwable => + println("Error while making request") + throw e + } + + val succeeded = mutable.Set.empty[TopicPartition] + val noop = mutable.Set.empty[TopicPartition] + val failed = mutable.Map.empty[TopicPartition, Throwable] + + electionResults.foreach[Unit] { case (topicPartition, error) => + if (error.isPresent) { + error.get match { + case _: ElectionNotNeededException => noop += topicPartition + case _ => failed += topicPartition -> error.get + } + } else { + succeeded += topicPartition + } + } + + if (succeeded.nonEmpty) { + val partitions = succeeded.mkString(", ") + println(s"Successfully completed leader election ($electionType) for partitions $partitions") + } + + if (noop.nonEmpty) { + val partitions = noop.mkString(", ") + println(s"Valid replica already elected for partitions $partitions") + } + + if (failed.nonEmpty) { + val rootException = new AdminCommandFailedException(s"${failed.size} replica(s) could not be elected") + failed.forKeyValue { (topicPartition, exception) => + println(s"Error completing leader election 
($electionType) for partition: $topicPartition: $exception") + rootException.addSuppressed(exception) + } + throw rootException + } + } + + private[this] def validate(commandOptions: LeaderElectionCommandOptions): Unit = { + // required options: --bootstrap-server and --election-type + var missingOptions = List.empty[String] + if (!commandOptions.options.has(commandOptions.bootstrapServer)) { + missingOptions = commandOptions.bootstrapServer.options().get(0) :: missingOptions + } + + if (!commandOptions.options.has(commandOptions.electionType)) { + missingOptions = commandOptions.electionType.options().get(0) :: missingOptions + } + + if (missingOptions.nonEmpty) { + throw new AdminCommandFailedException(s"Missing required option(s): ${missingOptions.mkString(", ")}") + } + + // One and only one is required: --topic, --all-topic-partitions or --path-to-json-file + val mutuallyExclusiveOptions = Seq( + commandOptions.topic, + commandOptions.allTopicPartitions, + commandOptions.pathToJsonFile + ) + + mutuallyExclusiveOptions.count(commandOptions.options.has) match { + case 1 => // This is the only correct configuration, don't throw an exception + case _ => + throw new AdminCommandFailedException( + "One and only one of the following options is required: " + + s"${mutuallyExclusiveOptions.map(_.options.get(0)).mkString(", ")}" + ) + } + + // --partition if and only if --topic is used + ( + commandOptions.options.has(commandOptions.topic), + commandOptions.options.has(commandOptions.partition) + ) match { + case (true, false) => + throw new AdminCommandFailedException( + s"Missing required option(s): ${commandOptions.partition.options.get(0)}" + ) + case (false, true) => + throw new AdminCommandFailedException( + s"Option ${commandOptions.partition.options.get(0)} is only allowed if " + + s"${commandOptions.topic.options.get(0)} is used" + ) + case _ => // Ignore; we have a valid configuration + } + } +} + +private final class LeaderElectionCommandOptions(args: Array[String]) extends CommandDefaultOptions(args) { + val bootstrapServer = parser + .accepts( + "bootstrap-server", + "A hostname and port for the broker to connect to, in the form host:port. Multiple comma separated URLs can be given. REQUIRED.") + .withRequiredArg + .describedAs("host:port") + .ofType(classOf[String]) + val adminClientConfig = parser + .accepts( + "admin.config", + "Configuration properties files to pass to the admin client") + .withRequiredArg + .describedAs("config file") + .ofType(classOf[String]) + + val pathToJsonFile = parser + .accepts( + "path-to-json-file", + "The JSON file with the list of partition for which leader elections should be performed. This is an example format. \n{\"partitions\":\n\t[{\"topic\": \"foo\", \"partition\": 1},\n\t {\"topic\": \"foobar\", \"partition\": 2}]\n}\nNot allowed if --all-topic-partitions or --topic flags are specified.") + .withRequiredArg + .describedAs("Path to JSON file") + .ofType(classOf[String]) + + val topic = parser + .accepts( + "topic", + "Name of topic for which to perform an election. Not allowed if --path-to-json-file or --all-topic-partitions is specified.") + .withRequiredArg + .describedAs("topic name") + .ofType(classOf[String]) + + val partition = parser + .accepts( + "partition", + "Partition id for which to perform an election. 
REQUIRED if --topic is specified.") + .withRequiredArg + .describedAs("partition id") + .ofType(classOf[Integer]) + + val allTopicPartitions = parser + .accepts( + "all-topic-partitions", + "Perform election on all of the eligible topic partitions based on the type of election (see the --election-type flag). Not allowed if --topic or --path-to-json-file is specified.") + + val electionType = parser + .accepts( + "election-type", + "Type of election to attempt. Possible values are \"preferred\" for preferred leader election or \"unclean\" for unclean leader election. If preferred election is selection, the election is only performed if the current leader is not the preferred leader for the topic partition. If unclean election is selected, the election is only performed if there are no leader for the topic partition. REQUIRED.") + .withRequiredArg + .describedAs("election type") + .withValuesConvertedBy(ElectionTypeConverter) + + options = parser.parse(args: _*) +} + +final object ElectionTypeConverter extends EnumConverter[ElectionType](classOf[ElectionType]) { } diff --git a/core/src/main/scala/kafka/admin/LogDirsCommand.scala b/core/src/main/scala/kafka/admin/LogDirsCommand.scala new file mode 100644 index 0000000..d8c802e --- /dev/null +++ b/core/src/main/scala/kafka/admin/LogDirsCommand.scala @@ -0,0 +1,133 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.admin + +import java.io.PrintStream +import java.util.Properties + +import kafka.utils.{CommandDefaultOptions, CommandLineUtils, Json} +import org.apache.kafka.clients.admin.{Admin, AdminClientConfig, LogDirDescription} +import org.apache.kafka.common.utils.Utils + +import scala.jdk.CollectionConverters._ +import scala.collection.Map + +/** + * A command for querying log directory usage on the specified brokers + */ +object LogDirsCommand { + + def main(args: Array[String]): Unit = { + describe(args, System.out) + } + + def describe(args: Array[String], out: PrintStream): Unit = { + val opts = new LogDirsCommandOptions(args) + val adminClient = createAdminClient(opts) + try { + val topicList = opts.options.valueOf(opts.topicListOpt).split(",").filter(_.nonEmpty) + val clusterBrokers = adminClient.describeCluster().nodes().get().asScala.map(_.id()).toSet + val (existingBrokers, nonExistingBrokers) = Option(opts.options.valueOf(opts.brokerListOpt)) match { + case Some(brokerListStr) => + val inputBrokers = brokerListStr.split(',').filter(_.nonEmpty).map(_.toInt).toSet + (inputBrokers.intersect(clusterBrokers), inputBrokers.diff(clusterBrokers)) + case None => (clusterBrokers, Set.empty) + } + + if (nonExistingBrokers.nonEmpty) { + out.println(s"ERROR: The given brokers do not exist from --broker-list: ${nonExistingBrokers.mkString(",")}." 
+ + s" Current existent brokers: ${clusterBrokers.mkString(",")}") + } else { + out.println("Querying brokers for log directories information") + val describeLogDirsResult = adminClient.describeLogDirs(existingBrokers.map(Integer.valueOf).toSeq.asJava) + val logDirInfosByBroker = describeLogDirsResult.allDescriptions.get().asScala.map { case (k, v) => k -> v.asScala } + + out.println(s"Received log directory information from brokers ${existingBrokers.mkString(",")}") + out.println(formatAsJson(logDirInfosByBroker, topicList.toSet)) + } + } finally { + adminClient.close() + } + } + + private def formatAsJson(logDirInfosByBroker: Map[Integer, Map[String, LogDirDescription]], topicSet: Set[String]): String = { + Json.encodeAsString(Map( + "version" -> 1, + "brokers" -> logDirInfosByBroker.map { case (broker, logDirInfos) => + Map( + "broker" -> broker, + "logDirs" -> logDirInfos.map { case (logDir, logDirInfo) => + Map( + "logDir" -> logDir, + "error" -> Option(logDirInfo.error).map(ex => ex.getClass.getName).orNull, + "partitions" -> logDirInfo.replicaInfos.asScala.filter { case (topicPartition, _) => + topicSet.isEmpty || topicSet.contains(topicPartition.topic) + }.map { case (topicPartition, replicaInfo) => + Map( + "partition" -> topicPartition.toString, + "size" -> replicaInfo.size, + "offsetLag" -> replicaInfo.offsetLag, + "isFuture" -> replicaInfo.isFuture + ).asJava + }.asJava + ).asJava + }.asJava + ).asJava + }.asJava + ).asJava) + } + + private def createAdminClient(opts: LogDirsCommandOptions): Admin = { + val props = if (opts.options.has(opts.commandConfigOpt)) + Utils.loadProps(opts.options.valueOf(opts.commandConfigOpt)) + else + new Properties() + props.put(AdminClientConfig.BOOTSTRAP_SERVERS_CONFIG, opts.options.valueOf(opts.bootstrapServerOpt)) + props.putIfAbsent(AdminClientConfig.CLIENT_ID_CONFIG, "log-dirs-tool") + Admin.create(props) + } + + class LogDirsCommandOptions(args: Array[String]) extends CommandDefaultOptions(args){ + val bootstrapServerOpt = parser.accepts("bootstrap-server", "REQUIRED: the server(s) to use for bootstrapping") + .withRequiredArg + .describedAs("The server(s) to use for bootstrapping") + .ofType(classOf[String]) + val commandConfigOpt = parser.accepts("command-config", "Property file containing configs to be passed to Admin Client.") + .withRequiredArg + .describedAs("Admin client property file") + .ofType(classOf[String]) + val describeOpt = parser.accepts("describe", "Describe the specified log directories on the specified brokers.") + val topicListOpt = parser.accepts("topic-list", "The list of topics to be queried in the form \"topic1,topic2,topic3\". " + + "All topics will be queried if no topic list is specified") + .withRequiredArg + .describedAs("Topic list") + .defaultsTo("") + .ofType(classOf[String]) + val brokerListOpt = parser.accepts("broker-list", "The list of brokers to be queried in the form \"0,1,2\". 
" + + "All brokers in the cluster will be queried if no broker list is specified") + .withRequiredArg + .describedAs("Broker list") + .ofType(classOf[String]) + + options = parser.parse(args : _*) + + CommandLineUtils.printHelpAndExitIfNeeded(this, "This tool helps to query log directory usage on the specified brokers.") + + CommandLineUtils.checkRequiredArgs(parser, options, bootstrapServerOpt, describeOpt) + } +} diff --git a/core/src/main/scala/kafka/admin/RackAwareMode.scala b/core/src/main/scala/kafka/admin/RackAwareMode.scala new file mode 100644 index 0000000..45555b6 --- /dev/null +++ b/core/src/main/scala/kafka/admin/RackAwareMode.scala @@ -0,0 +1,42 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package kafka.admin + +/** + * Mode to control how rack aware replica assignment will be executed + */ +object RackAwareMode { + + /** + * Ignore all rack information in replica assignment. This is an optional mode used in command line. + */ + case object Disabled extends RackAwareMode + + /** + * Assume every broker has rack, or none of the brokers has rack. If only partial brokers have rack, fail fast + * in replica assignment. This is the default mode in command line tools (TopicCommand and ReassignPartitionsCommand). + */ + case object Enforced extends RackAwareMode + + /** + * Use rack information if every broker has a rack. Otherwise, fallback to Disabled mode. This is used in auto topic + * creation. + */ + case object Safe extends RackAwareMode +} + +sealed trait RackAwareMode diff --git a/core/src/main/scala/kafka/admin/ReassignPartitionsCommand.scala b/core/src/main/scala/kafka/admin/ReassignPartitionsCommand.scala new file mode 100755 index 0000000..ac6304b --- /dev/null +++ b/core/src/main/scala/kafka/admin/ReassignPartitionsCommand.scala @@ -0,0 +1,1500 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package kafka.admin + +import java.util +import java.util.Optional +import java.util.concurrent.ExecutionException + +import kafka.common.AdminCommandFailedException +import kafka.log.LogConfig +import kafka.server.DynamicConfig +import kafka.utils.{CommandDefaultOptions, CommandLineUtils, CoreUtils, Exit, Json, Logging} +import kafka.utils.Implicits._ +import kafka.utils.json.JsonValue +import org.apache.kafka.clients.admin.AlterConfigOp.OpType +import org.apache.kafka.clients.admin.{Admin, AdminClientConfig, AlterConfigOp, ConfigEntry, NewPartitionReassignment, PartitionReassignment, TopicDescription} +import org.apache.kafka.common.config.ConfigResource +import org.apache.kafka.common.errors.{ReplicaNotAvailableException, UnknownTopicOrPartitionException} +import org.apache.kafka.common.utils.{Time, Utils} +import org.apache.kafka.common.{KafkaException, KafkaFuture, TopicPartition, TopicPartitionReplica} + +import scala.jdk.CollectionConverters._ +import scala.collection.{Map, Seq, mutable} +import scala.math.Ordered.orderingToOrdered + + +object ReassignPartitionsCommand extends Logging { + private[admin] val AnyLogDir = "any" + + val helpText = "This tool helps to move topic partitions between replicas." + + /** + * The earliest version of the partition reassignment JSON. We will default to this + * version if no other version number is given. + */ + private[admin] val EarliestVersion = 1 + + /** + * The earliest version of the JSON for each partition reassignment topic. We will + * default to this version if no other version number is given. + */ + private[admin] val EarliestTopicsJsonVersion = 1 + + // Throttles that are set at the level of an individual broker. + private[admin] val brokerLevelLeaderThrottle = + DynamicConfig.Broker.LeaderReplicationThrottledRateProp + private[admin] val brokerLevelFollowerThrottle = + DynamicConfig.Broker.FollowerReplicationThrottledRateProp + private[admin] val brokerLevelLogDirThrottle = + DynamicConfig.Broker.ReplicaAlterLogDirsIoMaxBytesPerSecondProp + private[admin] val brokerLevelThrottles = Seq( + brokerLevelLeaderThrottle, + brokerLevelFollowerThrottle, + brokerLevelLogDirThrottle + ) + + // Throttles that are set at the level of an individual topic. + private[admin] val topicLevelLeaderThrottle = + LogConfig.LeaderReplicationThrottledReplicasProp + private[admin] val topicLevelFollowerThrottle = + LogConfig.FollowerReplicationThrottledReplicasProp + private[admin] val topicLevelThrottles = Seq( + topicLevelLeaderThrottle, + topicLevelFollowerThrottle + ) + + private[admin] val cannotExecuteBecauseOfExistingMessage = "Cannot execute because " + + "there is an existing partition assignment. Use --additional to override this and " + + "create a new partition assignment in addition to the existing one. The --additional " + + "flag can also be used to change the throttle by resubmitting the current reassignment." + + private[admin] val youMustRunVerifyPeriodicallyMessage = "Warning: You must run " + + "--verify periodically, until the reassignment completes, to ensure the throttle " + + "is removed." + + /** + * A map from topic names to partition movements. + */ + type MoveMap = mutable.Map[String, mutable.Map[Int, PartitionMove]] + + /** + * A partition movement. The source and destination brokers may overlap. + * + * @param sources The source brokers. + * @param destinations The destination brokers. 
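+   *
+   * For example, reassigning a partition whose current replicas are (1, 2, 3) to (1, 2, 4)
+   * is represented by calculateProposedMoveMap below as
+   * PartitionMove(sources = Set(1, 2, 3), destinations = Set(4)).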
+ */ + sealed case class PartitionMove(sources: mutable.Set[Int], + destinations: mutable.Set[Int]) { } + + /** + * The state of a partition reassignment. The current replicas and target replicas + * may overlap. + * + * @param currentReplicas The current replicas. + * @param targetReplicas The target replicas. + * @param done True if the reassignment is done. + */ + sealed case class PartitionReassignmentState(currentReplicas: Seq[Int], + targetReplicas: Seq[Int], + done: Boolean) {} + + /** + * The state of a replica log directory movement. + */ + sealed trait LogDirMoveState { + /** + * True if the move is done without errors. + */ + def done: Boolean + } + + /** + * A replica log directory move state where the source log directory is missing. + * + * @param targetLogDir The log directory that we wanted the replica to move to. + */ + sealed case class MissingReplicaMoveState(targetLogDir: String) + extends LogDirMoveState { + override def done = false + } + + /** + * A replica log directory move state where the source replica is missing. + * + * @param targetLogDir The log directory that we wanted the replica to move to. + */ + sealed case class MissingLogDirMoveState(targetLogDir: String) + extends LogDirMoveState { + override def done = false + } + + /** + * A replica log directory move state where the move is in progress. + * + * @param currentLogDir The current log directory. + * @param futureLogDir The log directory that the replica is moving to. + * @param targetLogDir The log directory that we wanted the replica to move to. + */ + sealed case class ActiveMoveState(currentLogDir: String, + targetLogDir: String, + futureLogDir: String) + extends LogDirMoveState { + override def done = false + } + + /** + * A replica log directory move state where there is no move in progress, but we did not + * reach the target log directory. + * + * @param currentLogDir The current log directory. + * @param targetLogDir The log directory that we wanted the replica to move to. + */ + sealed case class CancelledMoveState(currentLogDir: String, + targetLogDir: String) + extends LogDirMoveState { + override def done = true + } + + /** + * The completed replica log directory move state. + * + * @param targetLogDir The log directory that we wanted the replica to move to. + */ + sealed case class CompletedMoveState(targetLogDir: String) + extends LogDirMoveState { + override def done = true + } + + /** + * An exception thrown to indicate that the command has failed, but we don't want to + * print a stack trace. + * + * @param message The message to print out before exiting. A stack trace will not + * be printed. 
+ */ + class TerseReassignmentFailureException(message: String) extends KafkaException(message) { + } + + def main(args: Array[String]): Unit = { + val opts = validateAndParseArgs(args) + var failed = true + var adminClient: Admin = null + + try { + val props = if (opts.options.has(opts.commandConfigOpt)) + Utils.loadProps(opts.options.valueOf(opts.commandConfigOpt)) + else + new util.Properties() + props.put(AdminClientConfig.BOOTSTRAP_SERVERS_CONFIG, opts.options.valueOf(opts.bootstrapServerOpt)) + props.putIfAbsent(AdminClientConfig.CLIENT_ID_CONFIG, "reassign-partitions-tool") + adminClient = Admin.create(props) + handleAction(adminClient, opts) + failed = false + } catch { + case e: TerseReassignmentFailureException => + println(e.getMessage) + case e: Throwable => + println("Error: " + e.getMessage) + println(Utils.stackTrace(e)) + } finally { + // It's good to do this after printing any error stack trace. + if (adminClient != null) { + adminClient.close() + } + } + // If the command failed, exit with a non-zero exit code. + if (failed) { + Exit.exit(1) + } + } + + private def handleAction(adminClient: Admin, + opts: ReassignPartitionsCommandOptions): Unit = { + if (opts.options.has(opts.verifyOpt)) { + verifyAssignment(adminClient, + Utils.readFileAsString(opts.options.valueOf(opts.reassignmentJsonFileOpt)), + opts.options.has(opts.preserveThrottlesOpt)) + } else if (opts.options.has(opts.generateOpt)) { + generateAssignment(adminClient, + Utils.readFileAsString(opts.options.valueOf(opts.topicsToMoveJsonFileOpt)), + opts.options.valueOf(opts.brokerListOpt), + !opts.options.has(opts.disableRackAware)) + } else if (opts.options.has(opts.executeOpt)) { + executeAssignment(adminClient, + opts.options.has(opts.additionalOpt), + Utils.readFileAsString(opts.options.valueOf(opts.reassignmentJsonFileOpt)), + opts.options.valueOf(opts.interBrokerThrottleOpt), + opts.options.valueOf(opts.replicaAlterLogDirsThrottleOpt), + opts.options.valueOf(opts.timeoutOpt)) + } else if (opts.options.has(opts.cancelOpt)) { + cancelAssignment(adminClient, + Utils.readFileAsString(opts.options.valueOf(opts.reassignmentJsonFileOpt)), + opts.options.has(opts.preserveThrottlesOpt), + opts.options.valueOf(opts.timeoutOpt)) + } else if (opts.options.has(opts.listOpt)) { + listReassignments(adminClient) + } else { + throw new RuntimeException("Unsupported action.") + } + } + + /** + * A result returned from verifyAssignment. + * + * @param partStates A map from partitions to reassignment states. + * @param partsOngoing True if there are any ongoing partition reassignments. + * @param moveStates A map from log directories to movement states. + * @param movesOngoing True if there are any ongoing moves that we know about. + */ + case class VerifyAssignmentResult(partStates: Map[TopicPartition, PartitionReassignmentState], + partsOngoing: Boolean = false, + moveStates: Map[TopicPartitionReplica, LogDirMoveState] = Map.empty, + movesOngoing: Boolean = false) + + /** + * The entry point for the --verify command. + * + * @param adminClient The AdminClient to use. + * @param jsonString The JSON string to use for the topics and partitions to verify. + * @param preserveThrottles True if we should avoid changing topic or broker throttles. + * + * @return A result that is useful for testing. 
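+   *
+   * For illustration, the JSON accepted here is the same reassignment format emitted by
+   * formatAsReassignmentJson below, e.g. (topic name and broker ids are hypothetical):
+   *   {"version":1,"partitions":[{"topic":"foo","partition":0,"replicas":[1,2],"log_dirs":["any","any"]}]}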
+ */ + def verifyAssignment(adminClient: Admin, jsonString: String, preserveThrottles: Boolean) + : VerifyAssignmentResult = { + val (targetParts, targetLogDirs) = parsePartitionReassignmentData(jsonString) + val (partStates, partsOngoing) = verifyPartitionAssignments(adminClient, targetParts) + val (moveStates, movesOngoing) = verifyReplicaMoves(adminClient, targetLogDirs) + if (!partsOngoing && !movesOngoing && !preserveThrottles) { + // If the partition assignments and replica assignments are done, clear any throttles + // that were set. We have to clear all throttles, because we don't have enough + // information to know all of the source brokers that might have been involved in the + // previous reassignments. + clearAllThrottles(adminClient, targetParts) + } + VerifyAssignmentResult(partStates, partsOngoing, moveStates, movesOngoing) + } + + /** + * Verify the partition reassignments specified by the user. + * + * @param adminClient The AdminClient to use. + * @param targets The partition reassignments specified by the user. + * + * @return A tuple of the partition reassignment states, and a + * boolean which is true if there are no ongoing + * reassignments (including reassignments not described + * in the JSON file.) + */ + def verifyPartitionAssignments(adminClient: Admin, + targets: Seq[(TopicPartition, Seq[Int])]) + : (Map[TopicPartition, PartitionReassignmentState], Boolean) = { + val (partStates, partsOngoing) = findPartitionReassignmentStates(adminClient, targets) + println(partitionReassignmentStatesToString(partStates)) + (partStates, partsOngoing) + } + + def compareTopicPartitions(a: TopicPartition, b: TopicPartition): Boolean = { + (a.topic(), a.partition()) < (b.topic(), b.partition()) + } + + def compareTopicPartitionReplicas(a: TopicPartitionReplica, b: TopicPartitionReplica): Boolean = { + (a.brokerId(), a.topic(), a.partition()) < (b.brokerId(), b.topic(), b.partition()) + } + + /** + * Convert partition reassignment states to a human-readable string. + * + * @param states A map from topic partitions to states. + * @return A string summarizing the partition reassignment states. + */ + def partitionReassignmentStatesToString(states: Map[TopicPartition, PartitionReassignmentState]) + : String = { + val bld = new mutable.ArrayBuffer[String]() + bld.append("Status of partition reassignment:") + states.keySet.toBuffer.sortWith(compareTopicPartitions).foreach { topicPartition => + val state = states(topicPartition) + if (state.done) { + if (state.currentReplicas.equals(state.targetReplicas)) { + bld.append("Reassignment of partition %s is complete.". + format(topicPartition.toString)) + } else { + bld.append(s"There is no active reassignment of partition ${topicPartition}, " + + s"but replica set is ${state.currentReplicas.mkString(",")} rather than " + + s"${state.targetReplicas.mkString(",")}.") + } + } else { + bld.append("Reassignment of partition %s is still in progress.".format(topicPartition)) + } + } + bld.mkString(System.lineSeparator()) + } + + /** + * Find the state of the specified partition reassignments. + * + * @param adminClient The Admin client to use. + * @param targetReassignments The reassignments we want to learn about. + * + * @return A tuple containing the reassignment states for each topic + * partition, plus whether there are any ongoing reassignments. 
+ */ + def findPartitionReassignmentStates(adminClient: Admin, + targetReassignments: Seq[(TopicPartition, Seq[Int])]) + : (Map[TopicPartition, PartitionReassignmentState], Boolean) = { + val currentReassignments = adminClient. + listPartitionReassignments.reassignments.get().asScala + val (foundReassignments, notFoundReassignments) = targetReassignments.partition { + case (part, _) => currentReassignments.contains(part) + } + val foundResults = foundReassignments.map { + case (part, targetReplicas) => (part, + PartitionReassignmentState( + currentReassignments(part).replicas. + asScala.map(i => i.asInstanceOf[Int]), + targetReplicas, + false)) + } + val topicNamesToLookUp = new mutable.HashSet[String]() + notFoundReassignments.foreach { case (part, _) => + if (!currentReassignments.contains(part)) + topicNamesToLookUp.add(part.topic) + } + val topicDescriptions = adminClient. + describeTopics(topicNamesToLookUp.asJava).topicNameValues().asScala + val notFoundResults = notFoundReassignments.map { + case (part, targetReplicas) => + currentReassignments.get(part) match { + case Some(reassignment) => (part, + PartitionReassignmentState( + reassignment.replicas.asScala.map(_.asInstanceOf[Int]), + targetReplicas, + false)) + case None => + (part, topicDescriptionFutureToState(part.partition, + topicDescriptions(part.topic), targetReplicas)) + } + } + val allResults = foundResults ++ notFoundResults + (allResults.toMap, currentReassignments.nonEmpty) + } + + private def topicDescriptionFutureToState(partition: Int, + future: KafkaFuture[TopicDescription], + targetReplicas: Seq[Int]): PartitionReassignmentState = { + try { + val topicDescription = future.get() + if (topicDescription.partitions().size() < partition) { + throw new ExecutionException("Too few partitions found", new UnknownTopicOrPartitionException()) + } + PartitionReassignmentState( + topicDescription.partitions.get(partition).replicas.asScala.map(_.id), + targetReplicas, + true) + } catch { + case t: ExecutionException if t.getCause.isInstanceOf[UnknownTopicOrPartitionException] => + PartitionReassignmentState(Seq(), targetReplicas, true) + } + } + + /** + * Verify the replica reassignments specified by the user. + * + * @param adminClient The AdminClient to use. + * @param targetReassignments The replica reassignments specified by the user. + * + * @return A tuple of the replica states, and a boolean which is true + * if there are any ongoing replica moves. + * + * Note: Unlike in verifyPartitionAssignments, we will + * return false here even if there are unrelated ongoing + * reassignments. (We don't have an efficient API that + * returns all ongoing replica reassignments.) + */ + def verifyReplicaMoves(adminClient: Admin, + targetReassignments: Map[TopicPartitionReplica, String]) + : (Map[TopicPartitionReplica, LogDirMoveState], Boolean) = { + val moveStates = findLogDirMoveStates(adminClient, targetReassignments) + println(replicaMoveStatesToString(moveStates)) + (moveStates, !moveStates.values.forall(_.done)) + } + + /** + * Find the state of the specified partition reassignments. + * + * @param adminClient The AdminClient to use. + * @param targetMoves The movements we want to learn about. The map is keyed + * by TopicPartitionReplica, and its values are target log + * directories. + * + * @return The states for each replica movement. 
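+   *
+   * For illustration (directory names are hypothetical): if a broker reports a current log
+   * dir of "/data/d1" and no future log dir for a replica, then a target of "/data/d1" maps
+   * to CompletedMoveState and any other target maps to CancelledMoveState; a non-null future
+   * log dir maps to ActiveMoveState, and a replica missing from the result maps to
+   * MissingReplicaMoveState.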
+ */ + def findLogDirMoveStates(adminClient: Admin, + targetMoves: Map[TopicPartitionReplica, String]) + : Map[TopicPartitionReplica, LogDirMoveState] = { + val replicaLogDirInfos = adminClient.describeReplicaLogDirs( + targetMoves.keySet.asJava).all().get().asScala + targetMoves.map { case (replica, targetLogDir) => + val moveState = replicaLogDirInfos.get(replica) match { + case None => MissingReplicaMoveState(targetLogDir) + case Some(info) => if (info.getCurrentReplicaLogDir == null) { + MissingLogDirMoveState(targetLogDir) + } else if (info.getFutureReplicaLogDir == null) { + if (info.getCurrentReplicaLogDir.equals(targetLogDir)) { + CompletedMoveState(targetLogDir) + } else { + CancelledMoveState(info.getCurrentReplicaLogDir, targetLogDir) + } + } else { + ActiveMoveState(info.getCurrentReplicaLogDir(), + targetLogDir, + info.getFutureReplicaLogDir) + } + } + (replica, moveState) + } + } + + /** + * Convert replica move states to a human-readable string. + * + * @param states A map from topic partition replicas to states. + * @return A tuple of a summary string, and a boolean describing + * whether there are any active replica moves. + */ + def replicaMoveStatesToString(states: Map[TopicPartitionReplica, LogDirMoveState]) + : String = { + val bld = new mutable.ArrayBuffer[String] + states.keySet.toBuffer.sortWith(compareTopicPartitionReplicas).foreach { replica => + val state = states(replica) + state match { + case MissingLogDirMoveState(_) => + bld.append(s"Partition ${replica.topic}-${replica.partition} is not found " + + s"in any live log dir on broker ${replica.brokerId}. There is likely an " + + s"offline log directory on the broker.") + case MissingReplicaMoveState(_) => + bld.append(s"Partition ${replica.topic}-${replica.partition} cannot be found " + + s"in any live log directory on broker ${replica.brokerId}.") + case ActiveMoveState(_, targetLogDir, futureLogDir) => + if (targetLogDir.equals(futureLogDir)) { + bld.append(s"Reassignment of replica $replica is still in progress.") + } else { + bld.append(s"Partition ${replica.topic}-${replica.partition} on broker " + + s"${replica.brokerId} is being moved to log dir $futureLogDir " + + s"instead of $targetLogDir.") + } + case CancelledMoveState(currentLogDir, targetLogDir) => + bld.append(s"Partition ${replica.topic}-${replica.partition} on broker " + + s"${replica.brokerId} is not being moved from log dir $currentLogDir to " + + s"$targetLogDir.") + case CompletedMoveState(_) => + bld.append(s"Reassignment of replica $replica completed successfully.") + } + } + bld.mkString(System.lineSeparator()) + } + + /** + * Clear all topic-level and broker-level throttles. + * + * @param adminClient The AdminClient to use. + * @param targetParts The target partitions loaded from the JSON file. + */ + def clearAllThrottles(adminClient: Admin, + targetParts: Seq[(TopicPartition, Seq[Int])]): Unit = { + val activeBrokers = adminClient.describeCluster().nodes().get().asScala.map(_.id()).toSet + val brokers = activeBrokers ++ targetParts.flatMap(_._2).toSet + println("Clearing broker-level throttles on broker%s %s".format( + if (brokers.size == 1) "" else "s", brokers.mkString(","))) + clearBrokerLevelThrottles(adminClient, brokers) + + val topics = targetParts.map(_._1.topic()).toSet + println("Clearing topic-level throttles on topic%s %s".format( + if (topics.size == 1) "" else "s", topics.mkString(","))) + clearTopicLevelThrottles(adminClient, topics) + } + + /** + * Clear all throttles which have been set at the broker level. 
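+   *
+   * For example, clearing broker 1 issues one incrementalAlterConfigs request that DELETEs
+   * each entry in brokerLevelThrottles on the ConfigResource(BROKER, "1") resource.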
+ * + * @param adminClient The AdminClient to use. + * @param brokers The brokers to clear the throttles for. + */ + def clearBrokerLevelThrottles(adminClient: Admin, brokers: Set[Int]): Unit = { + val configOps = new util.HashMap[ConfigResource, util.Collection[AlterConfigOp]]() + brokers.foreach { brokerId => + configOps.put( + new ConfigResource(ConfigResource.Type.BROKER, brokerId.toString), + brokerLevelThrottles.map(throttle => new AlterConfigOp( + new ConfigEntry(throttle, null), OpType.DELETE)).asJava) + } + adminClient.incrementalAlterConfigs(configOps).all().get() + } + + /** + * Clear the reassignment throttles for the specified topics. + * + * @param adminClient The AdminClient to use. + * @param topics The topics to clear the throttles for. + */ + def clearTopicLevelThrottles(adminClient: Admin, topics: Set[String]): Unit = { + val configOps = new util.HashMap[ConfigResource, util.Collection[AlterConfigOp]]() + topics.foreach { + topicName => configOps.put( + new ConfigResource(ConfigResource.Type.TOPIC, topicName), + topicLevelThrottles.map(throttle => new AlterConfigOp(new ConfigEntry(throttle, null), + OpType.DELETE)).asJava) + } + adminClient.incrementalAlterConfigs(configOps).all().get() + } + + /** + * The entry point for the --generate command. + * + * @param adminClient The AdminClient to use. + * @param reassignmentJson The JSON string to use for the topics to reassign. + * @param brokerListString The comma-separated string of broker IDs to use. + * @param enableRackAwareness True if rack-awareness should be enabled. + * + * @return A tuple containing the proposed assignment and the + * current assignment. + */ + def generateAssignment(adminClient: Admin, + reassignmentJson: String, + brokerListString: String, + enableRackAwareness: Boolean) + : (Map[TopicPartition, Seq[Int]], Map[TopicPartition, Seq[Int]]) = { + val (brokersToReassign, topicsToReassign) = + parseGenerateAssignmentArgs(reassignmentJson, brokerListString) + val currentAssignments = getReplicaAssignmentForTopics(adminClient, topicsToReassign) + val brokerMetadatas = getBrokerMetadata(adminClient, brokersToReassign, enableRackAwareness) + val proposedAssignments = calculateAssignment(currentAssignments, brokerMetadatas) + println("Current partition replica assignment\n%s\n". + format(formatAsReassignmentJson(currentAssignments, Map.empty))) + println("Proposed partition reassignment configuration\n%s". + format(formatAsReassignmentJson(proposedAssignments, Map.empty))) + (proposedAssignments, currentAssignments) + } + + /** + * Calculate the new partition assignments to suggest in --generate. + * + * @param currentAssignment The current partition assignments. + * @param brokerMetadatas The rack information for each broker. + * + * @return A map from partitions to the proposed assignments for each. + */ + def calculateAssignment(currentAssignment: Map[TopicPartition, Seq[Int]], + brokerMetadatas: Seq[BrokerMetadata]) + : Map[TopicPartition, Seq[Int]] = { + val groupedByTopic = currentAssignment.groupBy { case (tp, _) => tp.topic } + val proposedAssignments = mutable.Map[TopicPartition, Seq[Int]]() + groupedByTopic.forKeyValue { (topic, assignment) => + val (_, replicas) = assignment.head + val assignedReplicas = AdminUtils. 
+ assignReplicasToBrokers(brokerMetadatas, assignment.size, replicas.size) + proposedAssignments ++= assignedReplicas.map { case (partition, replicas) => + new TopicPartition(topic, partition) -> replicas + } + } + proposedAssignments + } + + private def describeTopics(adminClient: Admin, + topics: Set[String]) + : Map[String, TopicDescription] = { + adminClient.describeTopics(topics.asJava).topicNameValues().asScala.map { case (topicName, topicDescriptionFuture) => + try topicName -> topicDescriptionFuture.get + catch { + case t: ExecutionException if t.getCause.isInstanceOf[UnknownTopicOrPartitionException] => + throw new ExecutionException( + new UnknownTopicOrPartitionException(s"Topic $topicName not found.")) + } + } + } + + /** + * Get the current replica assignments for some topics. + * + * @param adminClient The AdminClient to use. + * @param topics The topics to get information about. + * @return A map from partitions to broker assignments. + * If any topic can't be found, an exception will be thrown. + */ + def getReplicaAssignmentForTopics(adminClient: Admin, + topics: Seq[String]) + : Map[TopicPartition, Seq[Int]] = { + describeTopics(adminClient, topics.toSet).flatMap { + case (topicName, topicDescription) => topicDescription.partitions.asScala.map { info => + (new TopicPartition(topicName, info.partition), info.replicas.asScala.map(_.id)) + } + } + } + + /** + * Get the current replica assignments for some partitions. + * + * @param adminClient The AdminClient to use. + * @param partitions The partitions to get information about. + * @return A map from partitions to broker assignments. + * If any topic can't be found, an exception will be thrown. + */ + def getReplicaAssignmentForPartitions(adminClient: Admin, + partitions: Set[TopicPartition]) + : Map[TopicPartition, Seq[Int]] = { + describeTopics(adminClient, partitions.map(_.topic)).flatMap { + case (topicName, topicDescription) => topicDescription.partitions.asScala.flatMap { info => + val tp = new TopicPartition(topicName, info.partition) + if (partitions.contains(tp)) { + Some(tp, info.replicas.asScala.map(_.id)) + } else { + None + } + } + } + } + + /** + * Find the rack information for some brokers. + * + * @param adminClient The AdminClient object. + * @param brokers The brokers to gather metadata about. + * @param enableRackAwareness True if we should return rack information, and throw an + * exception if it is inconsistent. + * + * @return The metadata for each broker that was found. + * Brokers that were not found will be omitted. + */ + def getBrokerMetadata(adminClient: Admin, + brokers: Seq[Int], + enableRackAwareness: Boolean): Seq[BrokerMetadata] = { + val brokerSet = brokers.toSet + val results = adminClient.describeCluster().nodes.get().asScala. + filter(node => brokerSet.contains(node.id)). + map { + node => if (enableRackAwareness && node.rack != null) { + BrokerMetadata(node.id, Some(node.rack)) + } else { + BrokerMetadata(node.id, None) + } + }.toSeq + val numRackless = results.count(_.rack.isEmpty) + if (enableRackAwareness && numRackless != 0 && numRackless != results.size) { + throw new AdminOperationException("Not all brokers have rack information. Add " + + "--disable-rack-aware in command line to make replica assignment without rack " + + "information.") + } + results + } + + /** + * Parse and validate data gathered from the command-line for --generate + * In particular, we parse the JSON and validate that duplicate brokers and + * topics don't appear. 
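+   *
+   * For illustration, the topics JSON has the shape parsed by parseTopicsData, e.g.
+   * (topic names are hypothetical):
+   *   {"topics": [{"topic": "foo"}, {"topic": "bar"}], "version": 1}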
+ * + * @param reassignmentJson The JSON passed to --generate . + * @param brokerList A list of brokers passed to --generate. + * + * @return A tuple of brokers to reassign, topics to reassign + */ + def parseGenerateAssignmentArgs(reassignmentJson: String, + brokerList: String): (Seq[Int], Seq[String]) = { + val brokerListToReassign = brokerList.split(',').map(_.toInt) + val duplicateReassignments = CoreUtils.duplicates(brokerListToReassign) + if (duplicateReassignments.nonEmpty) + throw new AdminCommandFailedException("Broker list contains duplicate entries: %s". + format(duplicateReassignments.mkString(","))) + val topicsToReassign = parseTopicsData(reassignmentJson) + val duplicateTopicsToReassign = CoreUtils.duplicates(topicsToReassign) + if (duplicateTopicsToReassign.nonEmpty) + throw new AdminCommandFailedException("List of topics to reassign contains duplicate entries: %s". + format(duplicateTopicsToReassign.mkString(","))) + (brokerListToReassign, topicsToReassign) + } + + /** + * The entry point for the --execute and --execute-additional commands. + * + * @param adminClient The AdminClient to use. + * @param additional Whether --additional was passed. + * @param reassignmentJson The JSON string to use for the topics to reassign. + * @param interBrokerThrottle The inter-broker throttle to use, or a negative + * number to skip using a throttle. + * @param logDirThrottle The replica log directory throttle to use, or a + * negative number to skip using a throttle. + * @param timeoutMs The maximum time in ms to wait for log directory + * replica assignment to begin. + * @param time The Time object to use. + */ + def executeAssignment(adminClient: Admin, + additional: Boolean, + reassignmentJson: String, + interBrokerThrottle: Long = -1L, + logDirThrottle: Long = -1L, + timeoutMs: Long = 10000L, + time: Time = Time.SYSTEM): Unit = { + val (proposedParts, proposedReplicas) = parseExecuteAssignmentArgs(reassignmentJson) + val currentReassignments = adminClient. + listPartitionReassignments().reassignments().get().asScala + // If there is an existing assignment, check for --additional before proceeding. + // This helps avoid surprising users. + if (!additional && currentReassignments.nonEmpty) { + throw new TerseReassignmentFailureException(cannotExecuteBecauseOfExistingMessage) + } + verifyBrokerIds(adminClient, proposedParts.values.flatten.toSet) + val currentParts = getReplicaAssignmentForPartitions(adminClient, proposedParts.keySet.toSet) + println(currentPartitionReplicaAssignmentToString(proposedParts, currentParts)) + + if (interBrokerThrottle >= 0 || logDirThrottle >= 0) { + println(youMustRunVerifyPeriodicallyMessage) + + if (interBrokerThrottle >= 0) { + val moveMap = calculateProposedMoveMap(currentReassignments, proposedParts, currentParts) + modifyReassignmentThrottle(adminClient, moveMap, interBrokerThrottle) + } + + if (logDirThrottle >= 0) { + val movingBrokers = calculateMovingBrokers(proposedReplicas.keySet.toSet) + modifyLogDirThrottle(adminClient, movingBrokers, logDirThrottle) + } + } + + // Execute the partition reassignments. 
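+    // For example, a proposedParts entry of foo-0 -> Seq(1, 2) is submitted as
+    //   new TopicPartition("foo", 0) -> Optional.of(new NewPartitionReassignment(Seq(1, 2).map(Integer.valueOf).asJava))
+    // by alterPartitionReassignments below; per-partition failures are collected into `errors`.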
+ val errors = alterPartitionReassignments(adminClient, proposedParts) + if (errors.nonEmpty) { + throw new TerseReassignmentFailureException( + "Error reassigning partition(s):%n%s".format( + errors.keySet.toBuffer.sortWith(compareTopicPartitions).map { part => + s"$part: ${errors(part).getMessage}" + }.mkString(System.lineSeparator()))) + } + println("Successfully started partition reassignment%s for %s".format( + if (proposedParts.size == 1) "" else "s", + proposedParts.keySet.toBuffer.sortWith(compareTopicPartitions).mkString(","))) + if (proposedReplicas.nonEmpty) { + executeMoves(adminClient, proposedReplicas, timeoutMs, time) + } + } + + /** + * Execute some partition log directory movements. + * + * @param adminClient The AdminClient to use. + * @param proposedReplicas A map from TopicPartitionReplicas to the + * directories to move them to. + * @param timeoutMs The maximum time in ms to wait for log directory + * replica assignment to begin. + * @param time The Time object to use. + */ + def executeMoves(adminClient: Admin, + proposedReplicas: Map[TopicPartitionReplica, String], + timeoutMs: Long, + time: Time): Unit = { + val startTimeMs = time.milliseconds() + val pendingReplicas = new mutable.HashMap[TopicPartitionReplica, String]() + pendingReplicas ++= proposedReplicas + var done = false + do { + val completed = alterReplicaLogDirs(adminClient, pendingReplicas) + if (completed.nonEmpty) { + println("Successfully started log directory move%s for: %s".format( + if (completed.size == 1) "" else "s", + completed.toBuffer.sortWith(compareTopicPartitionReplicas).mkString(","))) + } + pendingReplicas --= completed + if (pendingReplicas.isEmpty) { + done = true + } else if (time.milliseconds() >= startTimeMs + timeoutMs) { + throw new TerseReassignmentFailureException( + "Timed out before log directory move%s could be started for: %s".format( + if (pendingReplicas.size == 1) "" else "s", + pendingReplicas.keySet.toBuffer.sortWith(compareTopicPartitionReplicas). + mkString(","))) + } else { + // If a replica has been moved to a new host and we also specified a particular + // log directory, we will have to keep retrying the alterReplicaLogDirs + // call. It can't take effect until the replica is moved to that host. + time.sleep(100) + } + } while (!done) + } + + /** + * Entry point for the --list command. + * + * @param adminClient The AdminClient to use. + */ + def listReassignments(adminClient: Admin): Unit = { + println(curReassignmentsToString(adminClient)) + } + + /** + * Convert the current partition reassignments to text. + * + * @param adminClient The AdminClient to use. + * @return A string describing the current partition reassignments. + */ + def curReassignmentsToString(adminClient: Admin): String = { + val currentReassignments = adminClient. + listPartitionReassignments().reassignments().get().asScala + val text = currentReassignments.keySet.toBuffer.sortWith(compareTopicPartitions).map { part => + val reassignment = currentReassignments(part) + val replicas = reassignment.replicas.asScala + val addingReplicas = reassignment.addingReplicas.asScala + val removingReplicas = reassignment.removingReplicas.asScala + "%s: replicas: %s.%s%s".format(part, replicas.mkString(","), + if (addingReplicas.isEmpty) "" else + " adding: %s.".format(addingReplicas.mkString(",")), + if (removingReplicas.isEmpty) "" else + " removing: %s.".format(removingReplicas.mkString(","))) + }.mkString(System.lineSeparator()) + if (text.isEmpty) { + "No partition reassignments found." 
+ } else { + "Current partition reassignments:%n%s".format(text) + } + } + + /** + * Verify that all the brokers in an assignment exist. + * + * @param adminClient The AdminClient to use. + * @param brokers The broker IDs to verify. + */ + def verifyBrokerIds(adminClient: Admin, brokers: Set[Int]): Unit = { + val allNodeIds = adminClient.describeCluster().nodes().get().asScala.map(_.id).toSet + brokers.find(!allNodeIds.contains(_)).map { + id => throw new AdminCommandFailedException(s"Unknown broker id ${id}") + } + } + + /** + * Return the string which we want to print to describe the current partition assignment. + * + * @param proposedParts The proposed partition assignment. + * @param currentParts The current partition assignment. + * + * @return The string to print. We will only print information about + * partitions that appear in the proposed partition assignment. + */ + def currentPartitionReplicaAssignmentToString(proposedParts: Map[TopicPartition, Seq[Int]], + currentParts: Map[TopicPartition, Seq[Int]]): String = { + "Current partition replica assignment%n%n%s%n%nSave this to use as the %s". + format(formatAsReassignmentJson(currentParts.filter { case (k, _) => proposedParts.contains(k) }.toMap, Map.empty), + "--reassignment-json-file option during rollback") + } + + /** + * Execute the given partition reassignments. + * + * @param adminClient The admin client object to use. + * @param reassignments A map from topic names to target replica assignments. + * @return A map from partition objects to error strings. + */ + def alterPartitionReassignments(adminClient: Admin, + reassignments: Map[TopicPartition, Seq[Int]]): Map[TopicPartition, Throwable] = { + val results = adminClient.alterPartitionReassignments(reassignments.map { case (part, replicas) => + (part, Optional.of(new NewPartitionReassignment(replicas.map(Integer.valueOf).asJava))) + }.asJava).values().asScala + results.flatMap { + case (part, future) => { + try { + future.get() + None + } catch { + case t: ExecutionException => Some(part, t.getCause()) + } + } + } + } + + /** + * Cancel the given partition reassignments. + * + * @param adminClient The admin client object to use. + * @param reassignments The partition reassignments to cancel. + * @return A map from partition objects to error strings. + */ + def cancelPartitionReassignments(adminClient: Admin, + reassignments: Set[TopicPartition]) + : Map[TopicPartition, Throwable] = { + val results = adminClient.alterPartitionReassignments(reassignments.map { + (_, Optional.empty[NewPartitionReassignment]()) + }.toMap.asJava).values().asScala + results.flatMap { case (part, future) => + try { + future.get() + None + } catch { + case t: ExecutionException => Some(part, t.getCause()) + } + } + } + + /** + * Compute the in progress partition move from the current reassignments. + * @param currentReassignments All replicas, adding replicas and removing replicas of target partitions + */ + private def calculateCurrentMoveMap(currentReassignments: Map[TopicPartition, PartitionReassignment]): MoveMap = { + val moveMap = new mutable.HashMap[String, mutable.Map[Int, PartitionMove]]() + // Add the current reassignments to the move map. 
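+    // For example, an in-progress reassignment of foo-0 reporting replicas (1, 2, 3) and
+    // addingReplicas (3) is recorded as PartitionMove(sources = Set(1, 2), destinations = Set(3)).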
+ currentReassignments.forKeyValue { (part, reassignment) => + val allReplicas = reassignment.replicas().asScala.map(Int.unbox) + val addingReplicas = reassignment.addingReplicas.asScala.map(Int.unbox) + + // The addingReplicas is included in the replicas during reassignment + val sources = mutable.Set[Int]() ++ allReplicas.diff(addingReplicas) + val destinations = mutable.Set[Int]() ++ addingReplicas + + val partMoves = moveMap.getOrElseUpdate(part.topic, new mutable.HashMap[Int, PartitionMove]) + partMoves.put(part.partition, PartitionMove(sources, destinations)) + } + moveMap + } + + /** + * Calculate the global map of all partitions that are moving. + * + * @param currentReassignments The currently active reassignments. + * @param proposedParts The proposed location of the partitions (destinations replicas only). + * @param currentParts The current location of the partitions that we are + * proposing to move. + * @return A map from topic name to partition map. + * The partition map is keyed on partition index and contains + * the movements for that partition. + */ + def calculateProposedMoveMap(currentReassignments: Map[TopicPartition, PartitionReassignment], + proposedParts: Map[TopicPartition, Seq[Int]], + currentParts: Map[TopicPartition, Seq[Int]]): MoveMap = { + val moveMap = calculateCurrentMoveMap(currentReassignments) + + proposedParts.forKeyValue { (part, replicas) => + val partMoves = moveMap.getOrElseUpdate(part.topic, new mutable.HashMap[Int, PartitionMove]) + + // If there is a reassignment in progress, use the sources from moveMap, otherwise + // use the sources from currentParts + val sources = mutable.Set[Int]() ++ (partMoves.get(part.partition) match { + case Some(move) => move.sources.toSeq + case None => currentParts.getOrElse(part, + throw new RuntimeException(s"Trying to reassign a topic partition $part with 0 replicas")) + }) + val destinations = mutable.Set[Int]() ++ replicas.diff(sources.toSeq) + + partMoves.put(part.partition, + PartitionMove(sources, destinations)) + } + moveMap + } + + /** + * Calculate the leader throttle configurations to use. + * + * @param moveMap The movements. + * @return A map from topic names to leader throttle configurations. + */ + def calculateLeaderThrottles(moveMap: MoveMap): Map[String, String] = { + moveMap.map { + case (topicName, partMoveMap) => { + val components = new mutable.TreeSet[String] + partMoveMap.forKeyValue { (partId, move) => + move.sources.foreach(source => components.add("%d:%d".format(partId, source))) + } + (topicName, components.mkString(",")) + } + } + } + + /** + * Calculate the follower throttle configurations to use. + * + * @param moveMap The movements. + * @return A map from topic names to follower throttle configurations. + */ + def calculateFollowerThrottles(moveMap: MoveMap): Map[String, String] = { + moveMap.map { + case (topicName, partMoveMap) => { + val components = new mutable.TreeSet[String] + partMoveMap.forKeyValue { (partId, move) => + move.destinations.foreach(destination => + if (!move.sources.contains(destination)) { + components.add("%d:%d".format(partId, destination)) + }) + } + (topicName, components.mkString(",")) + } + } + } + + /** + * Calculate all the brokers which are involved in the given partition reassignments. + * + * @param moveMap The partition movements. + * @return A set of all the brokers involved. 
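+   *         For example, a moveMap containing only
+   *         PartitionMove(sources = Set(1, 2), destinations = Set(3)) yields Set(1, 2, 3).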
+ */ + def calculateReassigningBrokers(moveMap: MoveMap): Set[Int] = { + val reassigningBrokers = new mutable.TreeSet[Int] + moveMap.values.foreach { + _.values.foreach { + partMove => + partMove.sources.foreach(reassigningBrokers.add) + partMove.destinations.foreach(reassigningBrokers.add) + } + } + reassigningBrokers.toSet + } + + /** + * Calculate all the brokers which are involved in the given directory movements. + * + * @param replicaMoves The replica movements. + * @return A set of all the brokers involved. + */ + def calculateMovingBrokers(replicaMoves: Set[TopicPartitionReplica]): Set[Int] = { + replicaMoves.map(_.brokerId()) + } + + /** + * Modify the topic configurations that control inter-broker throttling. + * + * @param adminClient The adminClient object to use. + * @param leaderThrottles A map from topic names to leader throttle configurations. + * @param followerThrottles A map from topic names to follower throttle configurations. + */ + def modifyTopicThrottles(adminClient: Admin, + leaderThrottles: Map[String, String], + followerThrottles: Map[String, String]): Unit = { + val configs = new util.HashMap[ConfigResource, util.Collection[AlterConfigOp]]() + val topicNames = leaderThrottles.keySet ++ followerThrottles.keySet + topicNames.foreach { topicName => + val ops = new util.ArrayList[AlterConfigOp] + leaderThrottles.get(topicName).foreach { value => + ops.add(new AlterConfigOp(new ConfigEntry(topicLevelLeaderThrottle, value), OpType.SET)) + } + followerThrottles.get(topicName).foreach { value => + ops.add(new AlterConfigOp(new ConfigEntry(topicLevelFollowerThrottle, value), OpType.SET)) + } + if (!ops.isEmpty) { + configs.put(new ConfigResource(ConfigResource.Type.TOPIC, topicName), ops) + } + } + adminClient.incrementalAlterConfigs(configs).all().get() + } + + private def modifyReassignmentThrottle(admin: Admin, moveMap: MoveMap, interBrokerThrottle: Long): Unit = { + val leaderThrottles = calculateLeaderThrottles(moveMap) + val followerThrottles = calculateFollowerThrottles(moveMap) + modifyTopicThrottles(admin, leaderThrottles, followerThrottles) + + val reassigningBrokers = calculateReassigningBrokers(moveMap) + modifyInterBrokerThrottle(admin, reassigningBrokers, interBrokerThrottle) + } + + /** + * Modify the leader/follower replication throttles for a set of brokers. + * + * @param adminClient The Admin instance to use + * @param reassigningBrokers The set of brokers involved in the reassignment + * @param interBrokerThrottle The new throttle (ignored if less than 0) + */ + def modifyInterBrokerThrottle(adminClient: Admin, + reassigningBrokers: Set[Int], + interBrokerThrottle: Long): Unit = { + if (interBrokerThrottle >= 0) { + val configs = new util.HashMap[ConfigResource, util.Collection[AlterConfigOp]]() + reassigningBrokers.foreach { brokerId => + val ops = new util.ArrayList[AlterConfigOp] + ops.add(new AlterConfigOp(new ConfigEntry(brokerLevelLeaderThrottle, + interBrokerThrottle.toString), OpType.SET)) + ops.add(new AlterConfigOp(new ConfigEntry(brokerLevelFollowerThrottle, + interBrokerThrottle.toString), OpType.SET)) + configs.put(new ConfigResource(ConfigResource.Type.BROKER, brokerId.toString), ops) + } + adminClient.incrementalAlterConfigs(configs).all().get() + println(s"The inter-broker throttle limit was set to $interBrokerThrottle B/s") + } + } + + /** + * Modify the log dir reassignment throttle for a set of brokers. 
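+   *
+   * For example, calling this with movingBrokers = Set(0, 1) and logDirThrottle = 50000000
+   * sets the broker-level log-dir throttle to 50000000 B/s on brokers 0 and 1; a negative
+   * throttle leaves the configuration untouched.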
+ * + * @param admin The Admin instance to use + * @param movingBrokers The set of broker to alter the throttle of + * @param logDirThrottle The new throttle (ignored if less than 0) + */ + def modifyLogDirThrottle(admin: Admin, + movingBrokers: Set[Int], + logDirThrottle: Long): Unit = { + if (logDirThrottle >= 0) { + val configs = new util.HashMap[ConfigResource, util.Collection[AlterConfigOp]]() + movingBrokers.foreach { brokerId => + val ops = new util.ArrayList[AlterConfigOp] + ops.add(new AlterConfigOp(new ConfigEntry(brokerLevelLogDirThrottle, logDirThrottle.toString), OpType.SET)) + configs.put(new ConfigResource(ConfigResource.Type.BROKER, brokerId.toString), ops) + } + admin.incrementalAlterConfigs(configs).all().get() + println(s"The replica-alter-dir throttle limit was set to $logDirThrottle B/s") + } + } + + /** + * Parse the reassignment JSON string passed to the --execute command. + * + * @param reassignmentJson The JSON string. + * @return A tuple of the partitions to be reassigned and the replicas + * to be reassigned. + */ + def parseExecuteAssignmentArgs(reassignmentJson: String) + : (Map[TopicPartition, Seq[Int]], Map[TopicPartitionReplica, String]) = { + val (partitionsToBeReassigned, replicaAssignment) = parsePartitionReassignmentData(reassignmentJson) + if (partitionsToBeReassigned.isEmpty) + throw new AdminCommandFailedException("Partition reassignment list cannot be empty") + if (partitionsToBeReassigned.exists(_._2.isEmpty)) { + throw new AdminCommandFailedException("Partition replica list cannot be empty") + } + val duplicateReassignedPartitions = CoreUtils.duplicates(partitionsToBeReassigned.map { case (tp, _) => tp }) + if (duplicateReassignedPartitions.nonEmpty) + throw new AdminCommandFailedException("Partition reassignment contains duplicate topic partitions: %s".format(duplicateReassignedPartitions.mkString(","))) + val duplicateEntries = partitionsToBeReassigned + .map { case (tp, replicas) => (tp, CoreUtils.duplicates(replicas))} + .filter { case (_, duplicatedReplicas) => duplicatedReplicas.nonEmpty } + if (duplicateEntries.nonEmpty) { + val duplicatesMsg = duplicateEntries + .map { case (tp, duplicateReplicas) => "%s contains multiple entries for %s".format(tp, duplicateReplicas.mkString(",")) } + .mkString(". ") + throw new AdminCommandFailedException("Partition replica lists may not contain duplicate entries: %s".format(duplicatesMsg)) + } + (partitionsToBeReassigned.toMap, replicaAssignment) + } + + /** + * The entry point for the --cancel command. + * + * @param adminClient The AdminClient to use. + * @param jsonString The JSON string to use for the topics and partitions to cancel. + * @param preserveThrottles True if we should avoid changing topic or broker throttles. + * @param timeoutMs The maximum time in ms to wait for log directory + * replica assignment to begin. + * @param time The Time object to use. + * + * @return A tuple of the partition reassignments that were cancelled, + * and the replica movements that were cancelled. + */ + def cancelAssignment(adminClient: Admin, + jsonString: String, + preserveThrottles: Boolean, + timeoutMs: Long = 10000L, + time: Time = Time.SYSTEM) + : (Set[TopicPartition], Set[TopicPartitionReplica]) = { + val (targetParts, targetReplicas) = parsePartitionReassignmentData(jsonString) + val targetPartsSet = targetParts.map(_._1).toSet + val curReassigningParts = adminClient.listPartitionReassignments(targetPartsSet.asJava). 
+ reassignments().get().asScala.flatMap { + case (part, reassignment) => if (!reassignment.addingReplicas().isEmpty || + !reassignment.removingReplicas().isEmpty) { + Some(part) + } else { + None + } + }.toSet + if (curReassigningParts.nonEmpty) { + val errors = cancelPartitionReassignments(adminClient, curReassigningParts) + if (errors.nonEmpty) { + throw new TerseReassignmentFailureException( + "Error cancelling partition reassignment%s for:%n%s".format( + if (errors.size == 1) "" else "s", + errors.keySet.toBuffer.sortWith(compareTopicPartitions).map { + part => s"${part}: ${errors(part).getMessage}" + }.mkString(System.lineSeparator()))) + } + println("Successfully cancelled partition reassignment%s for: %s".format( + if (curReassigningParts.size == 1) "" else "s", + s"${curReassigningParts.toBuffer.sortWith(compareTopicPartitions).mkString(",")}")) + } else { + println("None of the specified partition reassignments are active.") + } + val curMovingParts = findLogDirMoveStates(adminClient, targetReplicas).flatMap { + case (part, moveState) => moveState match { + case state: ActiveMoveState => Some(part, state.currentLogDir) + case _ => None + } + }.toMap + if (curMovingParts.isEmpty) { + println("None of the specified partition moves are active.") + } else { + executeMoves(adminClient, curMovingParts, timeoutMs, time) + } + if (!preserveThrottles) { + clearAllThrottles(adminClient, targetParts) + } + (curReassigningParts, curMovingParts.keySet) + } + + def formatAsReassignmentJson(partitionsToBeReassigned: Map[TopicPartition, Seq[Int]], + replicaLogDirAssignment: Map[TopicPartitionReplica, String]): String = { + Json.encodeAsString(Map( + "version" -> 1, + "partitions" -> partitionsToBeReassigned.keySet.toBuffer.sortWith(compareTopicPartitions).map { + tp => + val replicas = partitionsToBeReassigned(tp) + Map( + "topic" -> tp.topic, + "partition" -> tp.partition, + "replicas" -> replicas.asJava, + "log_dirs" -> replicas.map(r => replicaLogDirAssignment.getOrElse(new TopicPartitionReplica(tp.topic, tp.partition, r), AnyLogDir)).asJava + ).asJava + }.asJava + ).asJava) + } + + def parseTopicsData(jsonData: String): Seq[String] = { + Json.parseFull(jsonData) match { + case Some(js) => + val version = js.asJsonObject.get("version") match { + case Some(jsonValue) => jsonValue.to[Int] + case None => EarliestTopicsJsonVersion + } + parseTopicsData(version, js) + case None => throw new AdminOperationException("The input string is not a valid JSON") + } + } + + def parseTopicsData(version: Int, js: JsonValue): Seq[String] = { + version match { + case 1 => + for { + partitionsSeq <- js.asJsonObject.get("topics").toSeq + p <- partitionsSeq.asJsonArray.iterator + } yield p.asJsonObject("topic").to[String] + case _ => throw new AdminOperationException(s"Not supported version field value $version") + } + } + + def parsePartitionReassignmentData(jsonData: String): (Seq[(TopicPartition, Seq[Int])], Map[TopicPartitionReplica, String]) = { + Json.tryParseFull(jsonData) match { + case Right(js) => + val version = js.asJsonObject.get("version") match { + case Some(jsonValue) => jsonValue.to[Int] + case None => EarliestVersion + } + parsePartitionReassignmentData(version, js) + case Left(f) => + throw new AdminOperationException(f) + } + } + + // Parses without deduplicating keys so the data can be checked before allowing reassignment to proceed + def parsePartitionReassignmentData(version:Int, jsonData: JsonValue): (Seq[(TopicPartition, Seq[Int])], Map[TopicPartitionReplica, String]) = { + version match { 
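+ /* A version 1 document, as described in the --reassignment-json-file help text below, looks like:
+ * {"version":1,"partitions":[{"topic":"foo","partition":1,"replicas":[1,2,3],"log_dirs":["any","/d1","/d2"]}]}
+ * "log_dirs" is optional; when present it must contain one entry ("any" or an absolute path) per replica. */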
+ case 1 => + val partitionAssignment = mutable.ListBuffer.empty[(TopicPartition, Seq[Int])] + val replicaAssignment = mutable.Map.empty[TopicPartitionReplica, String] + for { + partitionsSeq <- jsonData.asJsonObject.get("partitions").toSeq + p <- partitionsSeq.asJsonArray.iterator + } { + val partitionFields = p.asJsonObject + val topic = partitionFields("topic").to[String] + val partition = partitionFields("partition").to[Int] + val newReplicas = partitionFields("replicas").to[Seq[Int]] + val newLogDirs = partitionFields.get("log_dirs") match { + case Some(jsonValue) => jsonValue.to[Seq[String]] + case None => newReplicas.map(_ => AnyLogDir) + } + if (newReplicas.size != newLogDirs.size) + throw new AdminCommandFailedException(s"Size of replicas list $newReplicas is different from " + + s"size of log dirs list $newLogDirs for partition ${new TopicPartition(topic, partition)}") + partitionAssignment += (new TopicPartition(topic, partition) -> newReplicas) + replicaAssignment ++= newReplicas.zip(newLogDirs).map { case (replica, logDir) => + new TopicPartitionReplica(topic, partition, replica) -> logDir + }.filter(_._2 != AnyLogDir) + } + (partitionAssignment, replicaAssignment) + case _ => throw new AdminOperationException(s"Not supported version field value $version") + } + } + + def validateAndParseArgs(args: Array[String]): ReassignPartitionsCommandOptions = { + val opts = new ReassignPartitionsCommandOptions(args) + + CommandLineUtils.printHelpAndExitIfNeeded(opts, helpText) + + // Determine which action we should perform. + val validActions = Seq(opts.generateOpt, opts.executeOpt, opts.verifyOpt, + opts.cancelOpt, opts.listOpt) + val allActions = validActions.filter(opts.options.has _) + if (allActions.size != 1) { + CommandLineUtils.printUsageAndDie(opts.parser, "Command must include exactly one action: %s".format( + validActions.map("--" + _.options().get(0)).mkString(", "))) + } + val action = allActions(0) + + if (!opts.options.has(opts.bootstrapServerOpt)) + CommandLineUtils.printUsageAndDie(opts.parser, "Please specify --bootstrap-server") + + // Make sure that we have all the required arguments for our action. + val requiredArgs = Map( + opts.verifyOpt -> collection.immutable.Seq( + opts.reassignmentJsonFileOpt + ), + opts.generateOpt -> collection.immutable.Seq( + opts.topicsToMoveJsonFileOpt, + opts.brokerListOpt + ), + opts.executeOpt -> collection.immutable.Seq( + opts.reassignmentJsonFileOpt + ), + opts.cancelOpt -> collection.immutable.Seq( + opts.reassignmentJsonFileOpt + ), + opts.listOpt -> collection.immutable.Seq.empty + ) + CommandLineUtils.checkRequiredArgs(opts.parser, opts.options, requiredArgs(action): _*) + + // Make sure that we didn't specify any arguments that are incompatible with our chosen action. 
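+ /* For example, --throttle and --additional are only permitted together with --execute,
+ * while --preserve-throttles is only permitted with --verify and --cancel. */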
+ val permittedArgs = Map(
+ opts.verifyOpt -> Seq(
+ opts.bootstrapServerOpt,
+ opts.commandConfigOpt,
+ opts.preserveThrottlesOpt,
+ ),
+ opts.generateOpt -> Seq(
+ opts.bootstrapServerOpt,
+ opts.brokerListOpt,
+ opts.commandConfigOpt,
+ opts.disableRackAware,
+ ),
+ opts.executeOpt -> Seq(
+ opts.additionalOpt,
+ opts.bootstrapServerOpt,
+ opts.commandConfigOpt,
+ opts.interBrokerThrottleOpt,
+ opts.replicaAlterLogDirsThrottleOpt,
+ opts.timeoutOpt,
+ ),
+ opts.cancelOpt -> Seq(
+ opts.bootstrapServerOpt,
+ opts.commandConfigOpt,
+ opts.preserveThrottlesOpt,
+ opts.timeoutOpt
+ ),
+ opts.listOpt -> Seq(
+ opts.bootstrapServerOpt,
+ opts.commandConfigOpt
+ )
+ )
+ opts.options.specs.forEach(opt => {
+ if (!opt.equals(action) &&
+ !requiredArgs(action).contains(opt) &&
+ !permittedArgs(action).contains(opt)) {
+ CommandLineUtils.printUsageAndDie(opts.parser,
+ """Option "%s" can't be used with action "%s"""".format(opt, action))
+ }
+ })
+
+ opts
+ }
+
+ def alterReplicaLogDirs(adminClient: Admin,
+ assignment: Map[TopicPartitionReplica, String])
+ : Set[TopicPartitionReplica] = {
+ adminClient.alterReplicaLogDirs(assignment.asJava).values().asScala.flatMap {
+ case (replica, future) => {
+ try {
+ future.get()
+ Some(replica)
+ } catch {
+ case t: ExecutionException =>
+ t.getCause match {
+ // Ignore ReplicaNotAvailableException. It is OK if the replica is not
+ // available at this moment.
+ case _: ReplicaNotAvailableException => None
+ case e: Throwable =>
+ throw new AdminCommandFailedException(s"Failed to alter dir for $replica", e)
+ }
+ }
+ }
+ }.toSet
+ }
+
+ sealed class ReassignPartitionsCommandOptions(args: Array[String]) extends CommandDefaultOptions(args) {
+ // Actions
+ val verifyOpt = parser.accepts("verify", "Verify if the reassignment completed as specified by the " +
+ "--reassignment-json-file option. If there is a throttle engaged for the replicas specified, and the rebalance has completed, the throttle will be removed")
+ val generateOpt = parser.accepts("generate", "Generate a candidate partition reassignment configuration." +
+ " Note that this only generates a candidate assignment, it does not execute it.")
+ val executeOpt = parser.accepts("execute", "Kick off the reassignment as specified by the --reassignment-json-file option.")
+ val cancelOpt = parser.accepts("cancel", "Cancel an active reassignment.")
+ val listOpt = parser.accepts("list", "List all active partition reassignments.")
+
+ // Arguments
+ val bootstrapServerOpt = parser.accepts("bootstrap-server", "REQUIRED: the server(s) to use for bootstrapping.")
+ .withRequiredArg
+ .describedAs("Server(s) to use for bootstrapping")
+ .ofType(classOf[String])
+
+ val commandConfigOpt = parser.accepts("command-config", "Property file containing configs to be passed to Admin Client.")
+ .withRequiredArg
+ .describedAs("Admin client property file")
+ .ofType(classOf[String])
+
+ val reassignmentJsonFileOpt = parser.accepts("reassignment-json-file", "The JSON file with the partition reassignment configuration. " +
+ "The format to use is - \n" +
+ "{\"partitions\":\n\t[{\"topic\": \"foo\",\n\t \"partition\": 1,\n\t \"replicas\": [1,2,3],\n\t \"log_dirs\": [\"dir1\",\"dir2\",\"dir3\"] }],\n\"version\":1\n}\n" +
+ "Note that \"log_dirs\" is optional. When it is specified, its length must equal the length of the replicas list. The value in this list " +
+ "can be either \"any\" or the absolute path of the log directory on the broker. 
If absolute log directory path is specified, the replica will be moved to the specified log directory on the broker.") + .withRequiredArg + .describedAs("manual assignment json file path") + .ofType(classOf[String]) + val topicsToMoveJsonFileOpt = parser.accepts("topics-to-move-json-file", "Generate a reassignment configuration to move the partitions" + + " of the specified topics to the list of brokers specified by the --broker-list option. The format to use is - \n" + + "{\"topics\":\n\t[{\"topic\": \"foo\"},{\"topic\": \"foo1\"}],\n\"version\":1\n}") + .withRequiredArg + .describedAs("topics to reassign json file path") + .ofType(classOf[String]) + val brokerListOpt = parser.accepts("broker-list", "The list of brokers to which the partitions need to be reassigned" + + " in the form \"0,1,2\". This is required if --topics-to-move-json-file is used to generate reassignment configuration") + .withRequiredArg + .describedAs("brokerlist") + .ofType(classOf[String]) + val disableRackAware = parser.accepts("disable-rack-aware", "Disable rack aware replica assignment") + val interBrokerThrottleOpt = parser.accepts("throttle", "The movement of partitions between brokers will be throttled to this value (bytes/sec). " + + "This option can be included with --execute when a reassignment is started, and it can be altered by resubmitting the current reassignment " + + "along with the --additional flag. The throttle rate should be at least 1 KB/s.") + .withRequiredArg() + .describedAs("throttle") + .ofType(classOf[Long]) + .defaultsTo(-1) + val replicaAlterLogDirsThrottleOpt = parser.accepts("replica-alter-log-dirs-throttle", + "The movement of replicas between log directories on the same broker will be throttled to this value (bytes/sec). " + + "This option can be included with --execute when a reassignment is started, and it can be altered by resubmitting the current reassignment " + + "along with the --additional flag. The throttle rate should be at least 1 KB/s.") + .withRequiredArg() + .describedAs("replicaAlterLogDirsThrottle") + .ofType(classOf[Long]) + .defaultsTo(-1) + val timeoutOpt = parser.accepts("timeout", "The maximum time in ms to wait for log directory replica assignment to begin.") + .withRequiredArg() + .describedAs("timeout") + .ofType(classOf[Long]) + .defaultsTo(10000) + val additionalOpt = parser.accepts("additional", "Execute this reassignment in addition to any " + + "other ongoing ones. This option can also be used to change the throttle of an ongoing reassignment.") + val preserveThrottlesOpt = parser.accepts("preserve-throttles", "Do not modify broker or topic throttles.") + options = parser.parse(args : _*) + } +} diff --git a/core/src/main/scala/kafka/admin/TopicCommand.scala b/core/src/main/scala/kafka/admin/TopicCommand.scala new file mode 100755 index 0000000..5e7d98c --- /dev/null +++ b/core/src/main/scala/kafka/admin/TopicCommand.scala @@ -0,0 +1,658 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.admin + +import java.util +import java.util.{Collections, Properties} +import joptsimple._ +import kafka.common.AdminCommandFailedException +import kafka.log.LogConfig +import kafka.utils._ +import org.apache.kafka.clients.CommonClientConfigs +import org.apache.kafka.clients.admin.CreatePartitionsOptions +import org.apache.kafka.clients.admin.CreateTopicsOptions +import org.apache.kafka.clients.admin.DeleteTopicsOptions +import org.apache.kafka.clients.admin.{Admin, ListTopicsOptions, NewPartitions, NewTopic, PartitionReassignment, Config => JConfig} +import org.apache.kafka.common.{TopicCollection, TopicPartition, TopicPartitionInfo, Uuid} +import org.apache.kafka.common.config.ConfigResource.Type +import org.apache.kafka.common.config.{ConfigResource, TopicConfig} +import org.apache.kafka.common.errors.{ClusterAuthorizationException, TopicExistsException, UnsupportedVersionException} +import org.apache.kafka.common.internals.Topic +import org.apache.kafka.common.utils.Utils + +import scala.annotation.nowarn +import scala.jdk.CollectionConverters._ +import scala.collection._ +import scala.compat.java8.OptionConverters._ +import scala.concurrent.ExecutionException + +object TopicCommand extends Logging { + + def main(args: Array[String]): Unit = { + val opts = new TopicCommandOptions(args) + opts.checkArgs() + + val topicService = TopicService(opts.commandConfig, opts.bootstrapServer) + + var exitCode = 0 + try { + if (opts.hasCreateOption) + topicService.createTopic(opts) + else if (opts.hasAlterOption) + topicService.alterTopic(opts) + else if (opts.hasListOption) + topicService.listTopics(opts) + else if (opts.hasDescribeOption) + topicService.describeTopic(opts) + else if (opts.hasDeleteOption) + topicService.deleteTopic(opts) + } catch { + case e: ExecutionException => + if (e.getCause != null) + printException(e.getCause) + else + printException(e) + exitCode = 1 + case e: Throwable => + printException(e) + exitCode = 1 + } finally { + topicService.close() + Exit.exit(exitCode) + } + } + + private def printException(e: Throwable): Unit = { + println("Error while executing topic command : " + e.getMessage) + error(Utils.stackTrace(e)) + } + + class CommandTopicPartition(opts: TopicCommandOptions) { + val name = opts.topic.get + val partitions = opts.partitions + val replicationFactor = opts.replicationFactor + val replicaAssignment = opts.replicaAssignment + val configsToAdd = parseTopicConfigsToBeAdded(opts) + val configsToDelete = parseTopicConfigsToBeDeleted(opts) + val rackAwareMode = opts.rackAwareMode + + def hasReplicaAssignment: Boolean = replicaAssignment.isDefined + def hasPartitions: Boolean = partitions.isDefined + def ifTopicDoesntExist(): Boolean = opts.ifNotExists + } + + case class TopicDescription(topic: String, + topicId: Uuid, + numPartitions: Int, + replicationFactor: Int, + config: JConfig, + markedForDeletion: Boolean) { + + def printDescription(): Unit = { + val configsAsString = config.entries.asScala.filter(!_.isDefault).map { ce => s"${ce.name}=${ce.value}" }.mkString(",") + print(s"Topic: $topic") + if(topicId != 
Uuid.ZERO_UUID) print(s"\tTopicId: $topicId") + print(s"\tPartitionCount: $numPartitions") + print(s"\tReplicationFactor: $replicationFactor") + print(s"\tConfigs: $configsAsString") + print(if (markedForDeletion) "\tMarkedForDeletion: true" else "") + println() + } + } + + case class PartitionDescription(topic: String, + info: TopicPartitionInfo, + config: Option[JConfig], + markedForDeletion: Boolean, + reassignment: Option[PartitionReassignment]) { + + private def minIsrCount: Option[Int] = { + config.map(_.get(TopicConfig.MIN_IN_SYNC_REPLICAS_CONFIG).value.toInt) + } + + def isUnderReplicated: Boolean = { + getReplicationFactor(info, reassignment) - info.isr.size > 0 + } + + private def hasLeader: Boolean = { + info.leader != null + } + + def isUnderMinIsr: Boolean = { + !hasLeader || minIsrCount.exists(info.isr.size < _) + } + + def isAtMinIsrPartitions: Boolean = { + minIsrCount.contains(info.isr.size) + } + + def hasUnavailablePartitions(liveBrokers: Set[Int]): Boolean = { + !hasLeader || !liveBrokers.contains(info.leader.id) + } + + def printDescription(): Unit = { + print("\tTopic: " + topic) + print("\tPartition: " + info.partition) + print("\tLeader: " + (if (hasLeader) info.leader.id else "none")) + print("\tReplicas: " + info.replicas.asScala.map(_.id).mkString(",")) + print("\tIsr: " + info.isr.asScala.map(_.id).mkString(",")) + if (reassignment.nonEmpty) { + print("\tAdding Replicas: " + reassignment.get.addingReplicas().asScala.mkString(",")) + print("\tRemoving Replicas: " + reassignment.get.removingReplicas().asScala.mkString(",")) + } + print(if (markedForDeletion) "\tMarkedForDeletion: true" else "") + println() + } + + } + + class DescribeOptions(opts: TopicCommandOptions, liveBrokers: Set[Int]) { + val describeConfigs = + !opts.reportUnavailablePartitions && + !opts.reportUnderReplicatedPartitions && + !opts.reportUnderMinIsrPartitions && + !opts.reportAtMinIsrPartitions + val describePartitions = !opts.reportOverriddenConfigs + + private def shouldPrintUnderReplicatedPartitions(partitionDescription: PartitionDescription): Boolean = { + opts.reportUnderReplicatedPartitions && partitionDescription.isUnderReplicated + } + private def shouldPrintUnavailablePartitions(partitionDescription: PartitionDescription): Boolean = { + opts.reportUnavailablePartitions && partitionDescription.hasUnavailablePartitions(liveBrokers) + } + private def shouldPrintUnderMinIsrPartitions(partitionDescription: PartitionDescription): Boolean = { + opts.reportUnderMinIsrPartitions && partitionDescription.isUnderMinIsr + } + private def shouldPrintAtMinIsrPartitions(partitionDescription: PartitionDescription): Boolean = { + opts.reportAtMinIsrPartitions && partitionDescription.isAtMinIsrPartitions + } + + private def shouldPrintTopicPartition(partitionDesc: PartitionDescription): Boolean = { + describeConfigs || + shouldPrintUnderReplicatedPartitions(partitionDesc) || + shouldPrintUnavailablePartitions(partitionDesc) || + shouldPrintUnderMinIsrPartitions(partitionDesc) || + shouldPrintAtMinIsrPartitions(partitionDesc) + } + + def maybePrintPartitionDescription(desc: PartitionDescription): Unit = { + if (shouldPrintTopicPartition(desc)) + desc.printDescription() + } + } + + object TopicService { + def createAdminClient(commandConfig: Properties, bootstrapServer: Option[String]): Admin = { + bootstrapServer match { + case Some(serverList) => commandConfig.put(CommonClientConfigs.BOOTSTRAP_SERVERS_CONFIG, serverList) + case None => + } + Admin.create(commandConfig) + } + + def apply(commandConfig: 
Properties, bootstrapServer: Option[String]): TopicService = + new TopicService(createAdminClient(commandConfig, bootstrapServer)) + } + + case class TopicService private (adminClient: Admin) extends AutoCloseable { + + def createTopic(opts: TopicCommandOptions): Unit = { + val topic = new CommandTopicPartition(opts) + if (Topic.hasCollisionChars(topic.name)) + println("WARNING: Due to limitations in metric names, topics with a period ('.') or underscore ('_') could " + + "collide. To avoid issues it is best to use either, but not both.") + createTopic(topic) + } + + def createTopic(topic: CommandTopicPartition): Unit = { + if (topic.replicationFactor.exists(rf => rf > Short.MaxValue || rf < 1)) + throw new IllegalArgumentException(s"The replication factor must be between 1 and ${Short.MaxValue} inclusive") + if (topic.partitions.exists(partitions => partitions < 1)) + throw new IllegalArgumentException(s"The partitions must be greater than 0") + + try { + val newTopic = if (topic.hasReplicaAssignment) + new NewTopic(topic.name, asJavaReplicaReassignment(topic.replicaAssignment.get)) + else { + new NewTopic( + topic.name, + topic.partitions.asJava, + topic.replicationFactor.map(_.toShort).map(Short.box).asJava) + } + + val configsMap = topic.configsToAdd.stringPropertyNames() + .asScala + .map(name => name -> topic.configsToAdd.getProperty(name)) + .toMap.asJava + + newTopic.configs(configsMap) + val createResult = adminClient.createTopics(Collections.singleton(newTopic), + new CreateTopicsOptions().retryOnQuotaViolation(false)) + createResult.all().get() + println(s"Created topic ${topic.name}.") + } catch { + case e : ExecutionException => + if (e.getCause == null) + throw e + if (!(e.getCause.isInstanceOf[TopicExistsException] && topic.ifTopicDoesntExist())) + throw e.getCause + } + } + + def listTopics(opts: TopicCommandOptions): Unit = { + println(getTopics(opts.topic, opts.excludeInternalTopics).mkString("\n")) + } + + def alterTopic(opts: TopicCommandOptions): Unit = { + val topic = new CommandTopicPartition(opts) + val topics = getTopics(opts.topic, opts.excludeInternalTopics) + ensureTopicExists(topics, opts.topic, !opts.ifExists) + + if (topics.nonEmpty) { + val topicsInfo = adminClient.describeTopics(topics.asJavaCollection).topicNameValues() + val newPartitions = topics.map { topicName => + if (topic.hasReplicaAssignment) { + val startPartitionId = topicsInfo.get(topicName).get().partitions().size() + val newAssignment = { + val replicaMap = topic.replicaAssignment.get.drop(startPartitionId) + new util.ArrayList(replicaMap.map(p => p._2.asJava).asJavaCollection).asInstanceOf[util.List[util.List[Integer]]] + } + topicName -> NewPartitions.increaseTo(topic.partitions.get, newAssignment) + } else { + topicName -> NewPartitions.increaseTo(topic.partitions.get) + } + }.toMap + adminClient.createPartitions(newPartitions.asJava, + new CreatePartitionsOptions().retryOnQuotaViolation(false)).all().get() + } + } + + def listAllReassignments(topicPartitions: util.Set[TopicPartition]): Map[TopicPartition, PartitionReassignment] = { + try { + adminClient.listPartitionReassignments(topicPartitions).reassignments().get().asScala + } catch { + case e: ExecutionException => + e.getCause match { + case ex @ (_: UnsupportedVersionException | _: ClusterAuthorizationException) => + logger.debug(s"Couldn't query reassignments through the AdminClient API: ${ex.getMessage}", ex) + Map() + case t => throw t + } + } + } + + def describeTopic(opts: TopicCommandOptions): Unit = { + // If topicId is 
provided and not zero, will use topicId regardless of topic name + val inputTopicId = opts.topicId.map(Uuid.fromString).filter(uuid => uuid != Uuid.ZERO_UUID) + val useTopicId = inputTopicId.nonEmpty + + val (topicIds, topics) = if (useTopicId) + (getTopicIds(inputTopicId, opts.excludeInternalTopics), Seq()) + else + (Seq(), getTopics(opts.topic, opts.excludeInternalTopics)) + + // Only check topic name when topicId is not provided + if (useTopicId) + ensureTopicIdExists(topicIds, inputTopicId, !opts.ifExists) + else + ensureTopicExists(topics, opts.topic, !opts.ifExists) + + val topicDescriptions = if (topicIds.nonEmpty) { + adminClient.describeTopics(TopicCollection.ofTopicIds(topicIds.toSeq.asJavaCollection)).allTopicIds().get().values().asScala + } else if (topics.nonEmpty) { + adminClient.describeTopics(TopicCollection.ofTopicNames(topics.asJavaCollection)).allTopicNames().get().values().asScala + } else { + Seq() + } + + val topicNames = topicDescriptions.map(_.name()) + val allConfigs = adminClient.describeConfigs(topicNames.map(new ConfigResource(Type.TOPIC, _)).asJavaCollection).values() + val liveBrokers = adminClient.describeCluster().nodes().get().asScala.map(_.id()) + val describeOptions = new DescribeOptions(opts, liveBrokers.toSet) + val topicPartitions = topicDescriptions + .flatMap(td => td.partitions.iterator().asScala.map(p => new TopicPartition(td.name(), p.partition()))) + .toSet.asJava + val reassignments = listAllReassignments(topicPartitions) + + for (td <- topicDescriptions) { + val topicName = td.name + val topicId = td.topicId() + val config = allConfigs.get(new ConfigResource(Type.TOPIC, topicName)).get() + val sortedPartitions = td.partitions.asScala.sortBy(_.partition) + + if (describeOptions.describeConfigs) { + val hasNonDefault = config.entries().asScala.exists(!_.isDefault) + if (!opts.reportOverriddenConfigs || hasNonDefault) { + val numPartitions = td.partitions().size + val firstPartition = td.partitions.iterator.next() + val reassignment = reassignments.get(new TopicPartition(td.name, firstPartition.partition)) + val topicDesc = TopicDescription(topicName, topicId, numPartitions, getReplicationFactor(firstPartition, reassignment), config, markedForDeletion = false) + topicDesc.printDescription() + } + } + + if (describeOptions.describePartitions) { + for (partition <- sortedPartitions) { + val reassignment = reassignments.get(new TopicPartition(td.name, partition.partition)) + val partitionDesc = PartitionDescription(topicName, partition, Some(config), markedForDeletion = false, reassignment) + describeOptions.maybePrintPartitionDescription(partitionDesc) + } + } + } + } + + def deleteTopic(opts: TopicCommandOptions): Unit = { + val topics = getTopics(opts.topic, opts.excludeInternalTopics) + ensureTopicExists(topics, opts.topic, !opts.ifExists) + adminClient.deleteTopics(topics.asJavaCollection, new DeleteTopicsOptions().retryOnQuotaViolation(false)) + .all().get() + } + + def getTopics(topicIncludeList: Option[String], excludeInternalTopics: Boolean = false): Seq[String] = { + val allTopics = if (excludeInternalTopics) { + adminClient.listTopics() + } else { + adminClient.listTopics(new ListTopicsOptions().listInternal(true)) + } + doGetTopics(allTopics.names().get().asScala.toSeq.sorted, topicIncludeList, excludeInternalTopics) + } + + def getTopicIds(topicIdIncludeList: Option[Uuid], excludeInternalTopics: Boolean = false): Seq[Uuid] = { + val allTopics = if (excludeInternalTopics) { + adminClient.listTopics() + } else { + 
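+ /* listInternal(true) includes internal topics such as __consumer_offsets, which the default ListTopicsOptions excludes. */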
adminClient.listTopics(new ListTopicsOptions().listInternal(true)) + } + val allTopicIds = allTopics.listings().get().asScala.map(_.topicId()).toSeq.sorted + topicIdIncludeList.filter(allTopicIds.contains).toSeq + } + + def close(): Unit = adminClient.close() + } + + /** + * ensures topic existence and throws exception if topic doesn't exist + * + * @param foundTopics Topics that were found to match the requested topic name. + * @param requestedTopic Name of the topic that was requested. + * @param requireTopicExists Indicates if the topic needs to exist for the operation to be successful. + * If set to true, the command will throw an exception if the topic with the + * requested name does not exist. + */ + private def ensureTopicExists(foundTopics: Seq[String], requestedTopic: Option[String], requireTopicExists: Boolean): Unit = { + // If no topic name was mentioned, do not need to throw exception. + if (requestedTopic.isDefined && requireTopicExists && foundTopics.isEmpty) { + // If given topic doesn't exist then throw exception + throw new IllegalArgumentException(s"Topic '${requestedTopic.get}' does not exist as expected") + } + } + + /** + * ensures topic existence and throws exception if topic doesn't exist + * + * @param foundTopicIds Topics that were found to match the requested topic id. + * @param requestedTopicId Id of the topic that was requested. + * @param requireTopicIdExists Indicates if the topic needs to exist for the operation to be successful. + * If set to true, the command will throw an exception if the topic with the + * requested id does not exist. + */ + private def ensureTopicIdExists(foundTopicIds: Seq[Uuid], requestedTopicId: Option[Uuid], requireTopicIdExists: Boolean): Unit = { + // If no topic id was mentioned, do not need to throw exception. + if (requestedTopicId.isDefined && requireTopicIdExists && foundTopicIds.isEmpty) { + // If given topicId doesn't exist then throw exception + throw new IllegalArgumentException(s"TopicId '${requestedTopicId.get}' does not exist as expected") + } + } + + private def doGetTopics(allTopics: Seq[String], topicIncludeList: Option[String], excludeInternalTopics: Boolean): Seq[String] = { + if (topicIncludeList.isDefined) { + val topicsFilter = IncludeList(topicIncludeList.get) + allTopics.filter(topicsFilter.isTopicAllowed(_, excludeInternalTopics)) + } else + allTopics.filterNot(Topic.isInternal(_) && excludeInternalTopics) + } + + @nowarn("cat=deprecation") + def parseTopicConfigsToBeAdded(opts: TopicCommandOptions): Properties = { + val configsToBeAdded = opts.topicConfig.getOrElse(Collections.emptyList()).asScala.map(_.split("""\s*=\s*""")) + require(configsToBeAdded.forall(config => config.length == 2), + "Invalid topic config: all configs to be added must be in the format \"key=val\".") + val props = new Properties + configsToBeAdded.foreach(pair => props.setProperty(pair(0).trim, pair(1).trim)) + LogConfig.validate(props) + if (props.containsKey(LogConfig.MessageFormatVersionProp)) { + println(s"WARNING: The configuration ${LogConfig.MessageFormatVersionProp}=${props.getProperty(LogConfig.MessageFormatVersionProp)} is specified. " + + "This configuration will be ignored if the version is newer than the inter.broker.protocol.version specified in the broker or " + + "if the inter.broker.protocol.version is 3.0 or newer. 
This configuration is deprecated and it will be removed in Apache Kafka 4.0.")
+ }
+ props
+ }
+
+ def parseTopicConfigsToBeDeleted(opts: TopicCommandOptions): Seq[String] = {
+ val configsToBeDeleted = opts.configsToDelete.getOrElse(Collections.emptyList()).asScala.map(_.trim())
+ val propsToBeDeleted = new Properties
+ configsToBeDeleted.foreach(propsToBeDeleted.setProperty(_, ""))
+ LogConfig.validateNames(propsToBeDeleted)
+ configsToBeDeleted
+ }
+
+ def parseReplicaAssignment(replicaAssignmentList: String): Map[Int, List[Int]] = {
+ val partitionList = replicaAssignmentList.split(",")
+ val ret = new mutable.LinkedHashMap[Int, List[Int]]()
+ for (i <- 0 until partitionList.size) {
+ val brokerList = partitionList(i).split(":").map(s => s.trim().toInt)
+ val duplicateBrokers = CoreUtils.duplicates(brokerList)
+ if (duplicateBrokers.nonEmpty)
+ throw new AdminCommandFailedException(s"Partition replica lists may not contain duplicate entries: ${duplicateBrokers.mkString(",")}")
+ ret.put(i, brokerList.toList)
+ if (ret(i).size != ret(0).size)
+ throw new AdminOperationException("Partition " + i + " has different replication factor: " + brokerList)
+ }
+ ret
+ }
+
+ def asJavaReplicaReassignment(original: Map[Int, List[Int]]): util.Map[Integer, util.List[Integer]] = {
+ original.map(f => Integer.valueOf(f._1) -> f._2.map(e => Integer.valueOf(e)).asJava).asJava
+ }
+
+ private def getReplicationFactor(tpi: TopicPartitionInfo, reassignment: Option[PartitionReassignment]): Int = {
+ // It is possible for a reassignment to complete between the time we have fetched its state and the time
+ // we fetch partition metadata. In this case, we ignore the reassignment when determining replication factor.
+ def isReassignmentInProgress(ra: PartitionReassignment): Boolean = {
+ // Reassignment is still in progress as long as the removing and adding replicas are still present
+ val allReplicaIds = tpi.replicas.asScala.map(_.id).toSet
+ val changingReplicaIds = ra.removingReplicas.asScala.map(_.intValue).toSet ++ ra.addingReplicas.asScala.map(_.intValue).toSet
+ allReplicaIds.exists(changingReplicaIds.contains)
+ }
+
+ reassignment match {
+ case Some(ra) if isReassignmentInProgress(ra) => ra.replicas.asScala.diff(ra.addingReplicas.asScala).size
+ case _ => tpi.replicas.size
+ }
+ }
+
+ class TopicCommandOptions(args: Array[String]) extends CommandDefaultOptions(args) {
+ private val bootstrapServerOpt = parser.accepts("bootstrap-server", "REQUIRED: The Kafka server to connect to.")
+ .withRequiredArg
+ .describedAs("server to connect to")
+ .ofType(classOf[String])
+
+ private val commandConfigOpt = parser.accepts("command-config", "Property file containing configs to be passed to Admin Client. " +
+ "This is used only with --bootstrap-server option for describing and altering broker configs.")
+ .withRequiredArg
+ .describedAs("command config property file")
+ .ofType(classOf[String])
+
+ private val listOpt = parser.accepts("list", "List all available topics.")
+ private val createOpt = parser.accepts("create", "Create a new topic.")
+ private val deleteOpt = parser.accepts("delete", "Delete a topic")
+ private val alterOpt = parser.accepts("alter", "Alter the number of partitions, replica assignment, and/or configuration for the topic.")
+ private val describeOpt = parser.accepts("describe", "List details for the given topics.")
+ private val topicOpt = parser.accepts("topic", "The topic to create, alter, describe or delete. It also accepts a regular " +
+ "expression, except for --create option. Put topic name in double quotes and use the '\\' prefix " +
+ "to escape regular expression symbols; e.g. \"test\\.topic\".")
+ .withRequiredArg
+ .describedAs("topic")
+ .ofType(classOf[String])
+ private val topicIdOpt = parser.accepts("topic-id", "The topic-id to describe. " +
+ "This is used only with --bootstrap-server option for describing topics.")
+ .withRequiredArg
+ .describedAs("topic-id")
+ .ofType(classOf[String])
+ private val nl = System.getProperty("line.separator")
+ private val kafkaConfigsCanAlterTopicConfigsViaBootstrapServer =
+ " (the kafka-configs CLI supports altering topic configs with a --bootstrap-server option)"
+ private val configOpt = parser.accepts("config", "A topic configuration override for the topic being created or altered." +
+ " The following is a list of valid configurations: " + nl + LogConfig.configNames.map("\t" + _).mkString(nl) + nl +
+ "See the Kafka documentation for full details on the topic configs." +
+ " It is supported only in combination with --create if --bootstrap-server option is used" +
+ kafkaConfigsCanAlterTopicConfigsViaBootstrapServer + ".")
+ .withRequiredArg
+ .describedAs("name=value")
+ .ofType(classOf[String])
+ private val deleteConfigOpt = parser.accepts("delete-config", "A topic configuration override to be removed for an existing topic (see the list of configurations under the --config option). " +
+ "Not supported with the --bootstrap-server option.")
+ .withRequiredArg
+ .describedAs("name")
+ .ofType(classOf[String])
+ private val partitionsOpt = parser.accepts("partitions", "The number of partitions for the topic being created or " +
+ "altered (WARNING: If partitions are increased for a topic that has a key, the partition logic or ordering of the messages will be affected). If not supplied for create, defaults to the cluster default.")
+ .withRequiredArg
+ .describedAs("# of partitions")
+ .ofType(classOf[java.lang.Integer])
+ private val replicationFactorOpt = parser.accepts("replication-factor", "The replication factor for each partition in the topic being created. 
If not supplied, defaults to the cluster default.") + .withRequiredArg + .describedAs("replication factor") + .ofType(classOf[java.lang.Integer]) + private val replicaAssignmentOpt = parser.accepts("replica-assignment", "A list of manual partition-to-broker assignments for the topic being created or altered.") + .withRequiredArg + .describedAs("broker_id_for_part1_replica1 : broker_id_for_part1_replica2 , " + + "broker_id_for_part2_replica1 : broker_id_for_part2_replica2 , ...") + .ofType(classOf[String]) + private val reportUnderReplicatedPartitionsOpt = parser.accepts("under-replicated-partitions", + "if set when describing topics, only show under replicated partitions") + private val reportUnavailablePartitionsOpt = parser.accepts("unavailable-partitions", + "if set when describing topics, only show partitions whose leader is not available") + private val reportUnderMinIsrPartitionsOpt = parser.accepts("under-min-isr-partitions", + "if set when describing topics, only show partitions whose isr count is less than the configured minimum.") + private val reportAtMinIsrPartitionsOpt = parser.accepts("at-min-isr-partitions", + "if set when describing topics, only show partitions whose isr count is equal to the configured minimum.") + private val topicsWithOverridesOpt = parser.accepts("topics-with-overrides", + "if set when describing topics, only show topics that have overridden configs") + private val ifExistsOpt = parser.accepts("if-exists", + "if set when altering or deleting or describing topics, the action will only execute if the topic exists.") + private val ifNotExistsOpt = parser.accepts("if-not-exists", + "if set when creating topics, the action will only execute if the topic does not already exist.") + + private val disableRackAware = parser.accepts("disable-rack-aware", "Disable rack aware replica assignment") + + private val excludeInternalTopicOpt = parser.accepts("exclude-internal", + "exclude internal topics when running list or describe command. 
The internal topics will be listed by default") + + options = parser.parse(args : _*) + + private val allTopicLevelOpts = immutable.Set[OptionSpec[_]](alterOpt, createOpt, describeOpt, listOpt, deleteOpt) + + private val allReplicationReportOpts = Set(reportUnderReplicatedPartitionsOpt, reportUnderMinIsrPartitionsOpt, reportAtMinIsrPartitionsOpt, reportUnavailablePartitionsOpt) + + def has(builder: OptionSpec[_]): Boolean = options.has(builder) + def valueAsOption[A](option: OptionSpec[A], defaultValue: Option[A] = None): Option[A] = if (has(option)) Some(options.valueOf(option)) else defaultValue + def valuesAsOption[A](option: OptionSpec[A], defaultValue: Option[util.List[A]] = None): Option[util.List[A]] = if (has(option)) Some(options.valuesOf(option)) else defaultValue + + def hasCreateOption: Boolean = has(createOpt) + def hasAlterOption: Boolean = has(alterOpt) + def hasListOption: Boolean = has(listOpt) + def hasDescribeOption: Boolean = has(describeOpt) + def hasDeleteOption: Boolean = has(deleteOpt) + + def bootstrapServer: Option[String] = valueAsOption(bootstrapServerOpt) + def commandConfig: Properties = if (has(commandConfigOpt)) Utils.loadProps(options.valueOf(commandConfigOpt)) else new Properties() + def topic: Option[String] = valueAsOption(topicOpt) + def topicId: Option[String] = valueAsOption(topicIdOpt) + def partitions: Option[Integer] = valueAsOption(partitionsOpt) + def replicationFactor: Option[Integer] = valueAsOption(replicationFactorOpt) + def replicaAssignment: Option[Map[Int, List[Int]]] = + if (has(replicaAssignmentOpt) && !Option(options.valueOf(replicaAssignmentOpt)).getOrElse("").isEmpty) + Some(parseReplicaAssignment(options.valueOf(replicaAssignmentOpt))) + else + None + def rackAwareMode: RackAwareMode = if (has(disableRackAware)) RackAwareMode.Disabled else RackAwareMode.Enforced + def reportUnderReplicatedPartitions: Boolean = has(reportUnderReplicatedPartitionsOpt) + def reportUnavailablePartitions: Boolean = has(reportUnavailablePartitionsOpt) + def reportUnderMinIsrPartitions: Boolean = has(reportUnderMinIsrPartitionsOpt) + def reportAtMinIsrPartitions: Boolean = has(reportAtMinIsrPartitionsOpt) + def reportOverriddenConfigs: Boolean = has(topicsWithOverridesOpt) + def ifExists: Boolean = has(ifExistsOpt) + def ifNotExists: Boolean = has(ifNotExistsOpt) + def excludeInternalTopics: Boolean = has(excludeInternalTopicOpt) + def topicConfig: Option[util.List[String]] = valuesAsOption(configOpt) + def configsToDelete: Option[util.List[String]] = valuesAsOption(deleteConfigOpt) + + def checkArgs(): Unit = { + if (args.length == 0) + CommandLineUtils.printUsageAndDie(parser, "Create, delete, describe, or change a topic.") + + CommandLineUtils.printHelpAndExitIfNeeded(this, "This tool helps to create, delete, describe, or change a topic.") + + // should have exactly one action + val actions = Seq(createOpt, listOpt, alterOpt, describeOpt, deleteOpt).count(options.has) + if (actions != 1) + CommandLineUtils.printUsageAndDie(parser, "Command must include exactly one action: --list, --describe, --create, --alter or --delete") + + // check required args + if (!has(bootstrapServerOpt)) + throw new IllegalArgumentException("--bootstrap-server must be specified") + if (has(describeOpt) && has(ifExistsOpt)) { + if (!has(topicOpt) && !has(topicIdOpt)) + CommandLineUtils.printUsageAndDie(parser, "--topic or --topic-id is required to describe a topic") + if (has(topicOpt) && has(topicIdOpt)) + println("Only topic id will be used when both --topic and --topic-id are 
specified and topicId is not Uuid.ZERO_UUID") + } + if (!has(listOpt) && !has(describeOpt)) + CommandLineUtils.checkRequiredArgs(parser, options, topicOpt) + if (has(alterOpt)) { + CommandLineUtils.checkInvalidArgsSet(parser, options, Set(bootstrapServerOpt, configOpt), Set(alterOpt), + Some(kafkaConfigsCanAlterTopicConfigsViaBootstrapServer)) + CommandLineUtils.checkRequiredArgs(parser, options, partitionsOpt) + } + + // check invalid args + CommandLineUtils.checkInvalidArgs(parser, options, configOpt, allTopicLevelOpts -- Set(alterOpt, createOpt)) + CommandLineUtils.checkInvalidArgs(parser, options, deleteConfigOpt, allTopicLevelOpts -- Set(alterOpt) ++ Set(bootstrapServerOpt)) + CommandLineUtils.checkInvalidArgs(parser, options, partitionsOpt, allTopicLevelOpts -- Set(alterOpt, createOpt)) + CommandLineUtils.checkInvalidArgs(parser, options, replicationFactorOpt, allTopicLevelOpts -- Set(createOpt)) + CommandLineUtils.checkInvalidArgs(parser, options, replicaAssignmentOpt, allTopicLevelOpts -- Set(createOpt,alterOpt)) + if(options.has(createOpt)) + CommandLineUtils.checkInvalidArgs(parser, options, replicaAssignmentOpt, Set(partitionsOpt, replicationFactorOpt)) + CommandLineUtils.checkInvalidArgs(parser, options, reportUnderReplicatedPartitionsOpt, + allTopicLevelOpts -- Set(describeOpt) ++ allReplicationReportOpts - reportUnderReplicatedPartitionsOpt + topicsWithOverridesOpt) + CommandLineUtils.checkInvalidArgs(parser, options, reportUnderMinIsrPartitionsOpt, + allTopicLevelOpts -- Set(describeOpt) ++ allReplicationReportOpts - reportUnderMinIsrPartitionsOpt + topicsWithOverridesOpt) + CommandLineUtils.checkInvalidArgs(parser, options, reportAtMinIsrPartitionsOpt, + allTopicLevelOpts -- Set(describeOpt) ++ allReplicationReportOpts - reportAtMinIsrPartitionsOpt + topicsWithOverridesOpt) + CommandLineUtils.checkInvalidArgs(parser, options, reportUnavailablePartitionsOpt, + allTopicLevelOpts -- Set(describeOpt) ++ allReplicationReportOpts - reportUnavailablePartitionsOpt + topicsWithOverridesOpt) + CommandLineUtils.checkInvalidArgs(parser, options, topicsWithOverridesOpt, + allTopicLevelOpts -- Set(describeOpt) ++ allReplicationReportOpts) + CommandLineUtils.checkInvalidArgs(parser, options, ifExistsOpt, allTopicLevelOpts -- Set(alterOpt, deleteOpt, describeOpt)) + CommandLineUtils.checkInvalidArgs(parser, options, ifNotExistsOpt, allTopicLevelOpts -- Set(createOpt)) + CommandLineUtils.checkInvalidArgs(parser, options, excludeInternalTopicOpt, allTopicLevelOpts -- Set(listOpt, describeOpt)) + } + } +} + diff --git a/core/src/main/scala/kafka/admin/ZkSecurityMigrator.scala b/core/src/main/scala/kafka/admin/ZkSecurityMigrator.scala new file mode 100644 index 0000000..1263195 --- /dev/null +++ b/core/src/main/scala/kafka/admin/ZkSecurityMigrator.scala @@ -0,0 +1,307 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.admin + +import joptsimple.{ArgumentAcceptingOptionSpec, OptionSet} +import kafka.server.KafkaConfig +import kafka.utils.{CommandDefaultOptions, CommandLineUtils, Exit, Logging} +import kafka.utils.Implicits._ +import kafka.zk.{ControllerZNode, KafkaZkClient, ZkData, ZkSecurityMigratorUtils} +import org.apache.kafka.common.security.JaasUtils +import org.apache.kafka.common.utils.{Time, Utils} +import org.apache.zookeeper.AsyncCallback.{ChildrenCallback, StatCallback} +import org.apache.zookeeper.KeeperException +import org.apache.zookeeper.KeeperException.Code +import org.apache.zookeeper.client.ZKClientConfig +import org.apache.zookeeper.data.Stat + +import scala.annotation.tailrec +import scala.jdk.CollectionConverters._ +import scala.collection.mutable.Queue +import scala.concurrent._ +import scala.concurrent.duration._ + +/** + * This tool is to be used when making access to ZooKeeper authenticated or + * the other way around, when removing authenticated access. The exact steps + * to migrate a Kafka cluster from unsecure to secure with respect to ZooKeeper + * access are the following: + * + * 1- Perform a rolling upgrade of Kafka servers, setting zookeeper.set.acl to false + * and passing a valid JAAS login file via the system property + * java.security.auth.login.config + * 2- Perform a second rolling upgrade keeping the system property for the login file + * and now setting zookeeper.set.acl to true + * 3- Finally run this tool. There is a script under ./bin. Run + * ./bin/zookeeper-security-migration.sh --help + * to see the configuration parameters. An example of running it is the following: + * ./bin/zookeeper-security-migration.sh --zookeeper.acl=secure --zookeeper.connect=localhost:2181 + * + * To convert a cluster from secure to unsecure, we need to perform the following + * steps: + * 1- Perform a rolling upgrade setting zookeeper.set.acl to false for each server + * 2- Run this migration tool, setting zookeeper.acl to unsecure + * 3- Perform another rolling upgrade to remove the system property setting the + * login file (java.security.auth.login.config). + */ + +object ZkSecurityMigrator extends Logging { + val usageMessage = ("ZooKeeper Migration Tool Help. This tool updates the ACLs of " + + "znodes as part of the process of setting up ZooKeeper " + + "authentication.") + val tlsConfigFileOption = "zk-tls-config-file" + + def run(args: Array[String]): Unit = { + val jaasFile = System.getProperty(JaasUtils.JAVA_LOGIN_CONFIG_PARAM) + val opts = new ZkSecurityMigratorOptions(args) + + CommandLineUtils.printHelpAndExitIfNeeded(opts, usageMessage) + + // Must have either SASL or TLS mutual authentication enabled to use this tool. + // Instantiate the client config we will use so that we take into account config provided via the CLI option + // and system properties passed via -D parameters if no CLI option is given. + val zkClientConfig = createZkClientConfigFromOption(opts.options, opts.zkTlsConfigFile).getOrElse(new ZKClientConfig()) + val tlsClientAuthEnabled = KafkaConfig.zkTlsClientAuthEnabled(zkClientConfig) + if (jaasFile == null && !tlsClientAuthEnabled) { + val errorMsg = s"No JAAS configuration file has been specified and no TLS client certificate has been specified. 
Please make sure that you set " + + s"the system property ${JaasUtils.JAVA_LOGIN_CONFIG_PARAM} or provide a ZooKeeper client TLS configuration via --$tlsConfigFileOption " + + s"identifying at least ${KafkaConfig.ZkSslClientEnableProp}, ${KafkaConfig.ZkClientCnxnSocketProp}, and ${KafkaConfig.ZkSslKeyStoreLocationProp}" + System.err.println("ERROR: %s".format(errorMsg)) + throw new IllegalArgumentException("Incorrect configuration") + } + + if (!tlsClientAuthEnabled && !JaasUtils.isZkSaslEnabled()) { + val errorMsg = "Security isn't enabled, most likely the file isn't set properly: %s".format(jaasFile) + System.out.println("ERROR: %s".format(errorMsg)) + throw new IllegalArgumentException("Incorrect configuration") + } + + val zkAcl = opts.options.valueOf(opts.zkAclOpt) match { + case "secure" => + info("zookeeper.acl option is secure") + true + case "unsecure" => + info("zookeeper.acl option is unsecure") + false + case _ => + CommandLineUtils.printUsageAndDie(opts.parser, usageMessage) + } + val zkUrl = opts.options.valueOf(opts.zkUrlOpt) + val zkSessionTimeout = opts.options.valueOf(opts.zkSessionTimeoutOpt).intValue + val zkConnectionTimeout = opts.options.valueOf(opts.zkConnectionTimeoutOpt).intValue + val zkClient = KafkaZkClient(zkUrl, zkAcl, zkSessionTimeout, zkConnectionTimeout, + Int.MaxValue, Time.SYSTEM, zkClientConfig = zkClientConfig, name = "ZkSecurityMigrator") + val enablePathCheck = opts.options.has(opts.enablePathCheckOpt) + val migrator = new ZkSecurityMigrator(zkClient) + migrator.run(enablePathCheck) + } + + def main(args: Array[String]): Unit = { + try { + run(args) + } catch { + case e: Exception => { + e.printStackTrace() + // must exit with non-zero status so system tests will know we failed + Exit.exit(1) + } + } + } + + def createZkClientConfigFromFile(filename: String) : ZKClientConfig = { + val zkTlsConfigFileProps = Utils.loadProps(filename, KafkaConfig.ZkSslConfigToSystemPropertyMap.keys.toList.asJava) + val zkClientConfig = new ZKClientConfig() // Initializes based on any system properties that have been set + // Now override any set system properties with explicitly-provided values from the config file + // Emit INFO logs due to camel-case property names encouraging mistakes -- help people see mistakes they make + info(s"Found ${zkTlsConfigFileProps.size()} ZooKeeper client configuration properties in file $filename") + zkTlsConfigFileProps.asScala.forKeyValue { (key, value) => + info(s"Setting $key") + KafkaConfig.setZooKeeperClientProperty(zkClientConfig, key, value) + } + zkClientConfig + } + + private[admin] def createZkClientConfigFromOption(options: OptionSet, option: ArgumentAcceptingOptionSpec[String]) : Option[ZKClientConfig] = + if (!options.has(option)) + None + else + Some(createZkClientConfigFromFile(options.valueOf(option))) + + class ZkSecurityMigratorOptions(args: Array[String]) extends CommandDefaultOptions(args) { + val zkAclOpt = parser.accepts("zookeeper.acl", "Indicates whether to make the Kafka znodes in ZooKeeper secure or unsecure." + + " The options are 'secure' and 'unsecure'").withRequiredArg().ofType(classOf[String]) + val zkUrlOpt = parser.accepts("zookeeper.connect", "Sets the ZooKeeper connect string (ensemble). This parameter " + + "takes a comma-separated list of host:port pairs.").withRequiredArg().defaultsTo("localhost:2181"). + ofType(classOf[String]) + val zkSessionTimeoutOpt = parser.accepts("zookeeper.session.timeout", "Sets the ZooKeeper session timeout."). 
+ withRequiredArg().ofType(classOf[java.lang.Integer]).defaultsTo(30000)
+ val zkConnectionTimeoutOpt = parser.accepts("zookeeper.connection.timeout", "Sets the ZooKeeper connection timeout.").
+ withRequiredArg().ofType(classOf[java.lang.Integer]).defaultsTo(30000)
+ val enablePathCheckOpt = parser.accepts("enable.path.check", "Checks if all the root paths exist in ZooKeeper " +
+ "before migration. If not, exit the command.")
+ val zkTlsConfigFile = parser.accepts(tlsConfigFileOption,
+ "Identifies the file where ZooKeeper client TLS connectivity properties are defined. Any properties other than " +
+ KafkaConfig.ZkSslConfigToSystemPropertyMap.keys.mkString(", ") + " are ignored.")
+ .withRequiredArg().describedAs("ZooKeeper TLS configuration").ofType(classOf[String])
+ options = parser.parse(args : _*)
+ }
+}
+
+class ZkSecurityMigrator(zkClient: KafkaZkClient) extends Logging {
+ private val zkSecurityMigratorUtils = new ZkSecurityMigratorUtils(zkClient)
+ private val futures = new Queue[Future[String]]
+
+ private def setAcl(path: String, setPromise: Promise[String]): Unit = {
+ info("Setting ACL for path %s".format(path))
+ zkSecurityMigratorUtils.currentZooKeeper.setACL(path, zkClient.defaultAcls(path).asJava, -1, SetACLCallback, setPromise)
+ }
+
+ private def getChildren(path: String, childrenPromise: Promise[String]): Unit = {
+ info("Getting children to set ACLs for path %s".format(path))
+ zkSecurityMigratorUtils.currentZooKeeper.getChildren(path, false, GetChildrenCallback, childrenPromise)
+ }
+
+ private def setAclIndividually(path: String): Unit = {
+ val setPromise = Promise[String]()
+ futures.synchronized {
+ futures += setPromise.future
+ }
+ setAcl(path, setPromise)
+ }
+
+ private def setAclsRecursively(path: String): Unit = {
+ val setPromise = Promise[String]()
+ val childrenPromise = Promise[String]()
+ futures.synchronized {
+ futures += setPromise.future
+ futures += childrenPromise.future
+ }
+ setAcl(path, setPromise)
+ getChildren(path, childrenPromise)
+ }
+
+ private object GetChildrenCallback extends ChildrenCallback {
+ def processResult(rc: Int,
+ path: String,
+ ctx: Object,
+ children: java.util.List[String]): Unit = {
+ val zkHandle = zkSecurityMigratorUtils.currentZooKeeper
+ val promise = ctx.asInstanceOf[Promise[String]]
+ Code.get(rc) match {
+ case Code.OK =>
+ // Set ACL for each child
+ children.asScala.map { child =>
+ path match {
+ case "/" => s"/$child"
+ case path => s"$path/$child"
+ }
+ }.foreach(setAclsRecursively)
+ promise success "done"
+ case Code.CONNECTIONLOSS =>
+ zkHandle.getChildren(path, false, GetChildrenCallback, ctx)
+ case Code.NONODE =>
+ warn("Node is gone, it could have been legitimately deleted: %s".format(path))
+ promise success "done"
+ case Code.SESSIONEXPIRED =>
+ // Starting a new session isn't really a problem, but it'd complicate
+ // the logic of the tool, so we quit and let the user re-run it.
+ System.out.println("ZooKeeper session expired while changing ACLs")
+ promise failure KeeperException.create(Code.get(rc))
+ case _ =>
+ System.out.println("Unexpected return code: %d".format(rc))
+ promise failure KeeperException.create(Code.get(rc))
+ }
+ }
+ }
+
+ private object SetACLCallback extends StatCallback {
+ def processResult(rc: Int,
+ path: String,
+ ctx: Object,
+ stat: Stat): Unit = {
+ val zkHandle = zkSecurityMigratorUtils.currentZooKeeper
+ val promise = ctx.asInstanceOf[Promise[String]]
+
+ Code.get(rc) match {
+ case Code.OK =>
+ info("Successfully set ACLs for %s".format(path))
+ promise success "done"
+ case Code.CONNECTIONLOSS =>
+ zkHandle.setACL(path, zkClient.defaultAcls(path).asJava, -1, SetACLCallback, ctx)
+ case Code.NONODE =>
+ warn("Znode is gone, it could have been legitimately deleted: %s".format(path))
+ promise success "done"
+ case Code.SESSIONEXPIRED =>
+ // Starting a new session isn't really a problem, but it'd complicate
+ // the logic of the tool, so we quit and let the user re-run it.
+ System.out.println("ZooKeeper session expired while changing ACLs")
+ promise failure KeeperException.create(Code.get(rc))
+ case _ =>
+ System.out.println("Unexpected return code: %d".format(rc))
+ promise failure KeeperException.create(Code.get(rc))
+ }
+ }
+ }
+
+ private def run(enablePathCheck: Boolean): Unit = {
+ try {
+ setAclIndividually("/")
+ checkPathExistenceAndMaybeExit(enablePathCheck)
+ for (path <- ZkData.SecureRootPaths) {
+ debug("Going to set ACL for %s".format(path))
+ if (path == ControllerZNode.path && !zkClient.pathExists(path)) {
+ debug("Not setting ACL for %s, because it doesn't exist".format(path))
+ } else {
+ zkClient.makeSurePersistentPathExists(path)
+ setAclsRecursively(path)
+ }
+ }
+
+ @tailrec
+ def recurse(): Unit = {
+ val future = futures.synchronized {
+ futures.headOption
+ }
+ future match {
+ case Some(a) =>
+ Await.result(a, 6000 millis)
+ futures.synchronized { futures.dequeue() }
+ recurse()
+ case None =>
+ }
+ }
+ recurse()
+
+ } finally {
+ zkClient.close()
+ }
+ }
+
+ private def checkPathExistenceAndMaybeExit(enablePathCheck: Boolean): Unit = {
+ val nonExistingSecureRootPaths = ZkData.SecureRootPaths.filterNot(zkClient.pathExists)
+ if (nonExistingSecureRootPaths.nonEmpty) {
+ println(s"Warning: The following secure root paths do not exist in ZooKeeper: ${nonExistingSecureRootPaths.mkString(",")}")
+ println("That might be due to an incorrect chroot being specified when executing the command.")
+ if (enablePathCheck) {
+ println("Exit the command.")
+ // must exit with non-zero status so system tests will know we failed
+ Exit.exit(1)
+ }
+ }
+ }
+}
diff --git a/core/src/main/scala/kafka/api/ApiUtils.scala b/core/src/main/scala/kafka/api/ApiUtils.scala
new file mode 100644
index 0000000..9be1e4b
--- /dev/null
+++ b/core/src/main/scala/kafka/api/ApiUtils.scala
@@ -0,0 +1,78 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package kafka.api + +import java.nio.ByteBuffer +import java.nio.charset.StandardCharsets + +import org.apache.kafka.common.KafkaException + +/** + * Helper functions specific to parsing or serializing requests and responses + */ +object ApiUtils { + + /** + * Read size prefixed string where the size is stored as a 2 byte short. + * @param buffer The buffer to read from + */ + def readShortString(buffer: ByteBuffer): String = { + val size: Int = buffer.getShort() + if(size < 0) + return null + val bytes = new Array[Byte](size) + buffer.get(bytes) + new String(bytes, StandardCharsets.UTF_8) + } + + /** + * Write a size prefixed string where the size is stored as a 2 byte short + * @param buffer The buffer to write to + * @param string The string to write + */ + def writeShortString(buffer: ByteBuffer, string: String): Unit = { + if(string == null) { + buffer.putShort(-1) + } else { + val encodedString = string.getBytes(StandardCharsets.UTF_8) + if(encodedString.length > Short.MaxValue) { + throw new KafkaException("String exceeds the maximum size of " + Short.MaxValue + ".") + } else { + buffer.putShort(encodedString.length.asInstanceOf[Short]) + buffer.put(encodedString) + } + } + } + + /** + * Return size of a size prefixed string where the size is stored as a 2 byte short + * @param string The string to write + */ + def shortStringLength(string: String): Int = { + if(string == null) { + 2 + } else { + val encodedString = string.getBytes(StandardCharsets.UTF_8) + if(encodedString.length > Short.MaxValue) { + throw new KafkaException("String exceeds the maximum size of " + Short.MaxValue + ".") + } else { + 2 + encodedString.length + } + } + } + +} diff --git a/core/src/main/scala/kafka/api/ApiVersion.scala b/core/src/main/scala/kafka/api/ApiVersion.scala new file mode 100644 index 0000000..8165e6c --- /dev/null +++ b/core/src/main/scala/kafka/api/ApiVersion.scala @@ -0,0 +1,491 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package kafka.api + +import org.apache.kafka.clients.NodeApiVersions +import org.apache.kafka.common.config.ConfigDef.Validator +import org.apache.kafka.common.config.ConfigException +import org.apache.kafka.common.feature.{Features, FinalizedVersionRange, SupportedVersionRange} +import org.apache.kafka.common.message.ApiMessageType.ListenerType +import org.apache.kafka.common.record.RecordVersion +import org.apache.kafka.common.requests.ApiVersionsResponse + +/** + * This class contains the different Kafka versions. + * Right now, we use them for upgrades - users can configure the version of the API brokers will use to communicate between themselves. + * This is only for inter-broker communications - when communicating with clients, the client decides on the API version. + * + * Note that the ID we initialize for each version is important. + * We consider a version newer than another, if it has a higher ID (to avoid depending on lexicographic order) + * + * Since the api protocol may change more than once within the same release and to facilitate people deploying code from + * trunk, we have the concept of internal versions (first introduced during the 0.10.0 development cycle). For example, + * the first time we introduce a version change in a release, say 0.10.0, we will add a config value "0.10.0-IV0" and a + * corresponding case object KAFKA_0_10_0-IV0. We will also add a config value "0.10.0" that will be mapped to the + * latest internal version object, which is KAFKA_0_10_0-IV0. When we change the protocol a second time while developing + * 0.10.0, we will add a new config value "0.10.0-IV1" and a corresponding case object KAFKA_0_10_0-IV1. We will change + * the config value "0.10.0" to map to the latest internal version object KAFKA_0_10_0-IV1. The config value of + * "0.10.0-IV0" is still mapped to KAFKA_0_10_0-IV0. This way, if people are deploying from trunk, they can use + * "0.10.0-IV0" and "0.10.0-IV1" to upgrade one internal version at a time. For most people who just want to use + * released version, they can use "0.10.0" when upgrading to the 0.10.0 release. + */ +object ApiVersion { + // This implicit is necessary due to: https://issues.scala-lang.org/browse/SI-8541 + implicit def orderingByVersion[A <: ApiVersion]: Ordering[A] = Ordering.by(_.id) + + val allVersions: Seq[ApiVersion] = Seq( + KAFKA_0_8_0, + KAFKA_0_8_1, + KAFKA_0_8_2, + KAFKA_0_9_0, + // 0.10.0-IV0 is introduced for KIP-31/32 which changes the message format. + KAFKA_0_10_0_IV0, + // 0.10.0-IV1 is introduced for KIP-36(rack awareness) and KIP-43(SASL handshake). + KAFKA_0_10_0_IV1, + // introduced for JoinGroup protocol change in KIP-62 + KAFKA_0_10_1_IV0, + // 0.10.1-IV1 is introduced for KIP-74(fetch response size limit). + KAFKA_0_10_1_IV1, + // introduced ListOffsetRequest v1 in KIP-79 + KAFKA_0_10_1_IV2, + // introduced UpdateMetadataRequest v3 in KIP-103 + KAFKA_0_10_2_IV0, + // KIP-98 (idempotent and transactional producer support) + KAFKA_0_11_0_IV0, + // introduced DeleteRecordsRequest v0 and FetchRequest v4 in KIP-107 + KAFKA_0_11_0_IV1, + // Introduced leader epoch fetches to the replica fetcher via KIP-101 + KAFKA_0_11_0_IV2, + // Introduced LeaderAndIsrRequest V1, UpdateMetadataRequest V4 and FetchRequest V6 via KIP-112 + KAFKA_1_0_IV0, + // Introduced DeleteGroupsRequest V0 via KIP-229, plus KIP-227 incremental fetch requests, + // and KafkaStorageException for fetch requests. 
+ KAFKA_1_1_IV0, + // Introduced OffsetsForLeaderEpochRequest V1 via KIP-279 (Fix log divergence between leader and follower after fast leader fail over) + KAFKA_2_0_IV0, + // Several request versions were bumped due to KIP-219 (Improve quota communication) + KAFKA_2_0_IV1, + // Introduced new schemas for group offset (v2) and group metadata (v2) (KIP-211) + KAFKA_2_1_IV0, + // New Fetch, OffsetsForLeaderEpoch, and ListOffsets schemas (KIP-320) + KAFKA_2_1_IV1, + // Support ZStandard Compression Codec (KIP-110) + KAFKA_2_1_IV2, + // Introduced broker generation (KIP-380), and + // LeaderAdnIsrRequest V2, UpdateMetadataRequest V5, StopReplicaRequest V1 + KAFKA_2_2_IV0, + // New error code for ListOffsets when a new leader is lagging behind former HW (KIP-207) + KAFKA_2_2_IV1, + // Introduced static membership. + KAFKA_2_3_IV0, + // Add rack_id to FetchRequest, preferred_read_replica to FetchResponse, and replica_id to OffsetsForLeaderRequest + KAFKA_2_3_IV1, + // Add adding_replicas and removing_replicas fields to LeaderAndIsrRequest + KAFKA_2_4_IV0, + // Flexible version support in inter-broker APIs + KAFKA_2_4_IV1, + // No new APIs, equivalent to 2.4-IV1 + KAFKA_2_5_IV0, + // Introduced StopReplicaRequest V3 containing the leader epoch for each partition (KIP-570) + KAFKA_2_6_IV0, + // Introduced feature versioning support (KIP-584) + KAFKA_2_7_IV0, + // Bup Fetch protocol for Raft protocol (KIP-595) + KAFKA_2_7_IV1, + // Introduced AlterIsr (KIP-497) + KAFKA_2_7_IV2, + // Flexible versioning on ListOffsets, WriteTxnMarkers and OffsetsForLeaderEpoch. Also adds topic IDs (KIP-516) + KAFKA_2_8_IV0, + // Introduced topic IDs to LeaderAndIsr and UpdateMetadata requests/responses (KIP-516) + KAFKA_2_8_IV1, + // Introduce AllocateProducerIds (KIP-730) + KAFKA_3_0_IV0, + // Introduce ListOffsets V7 which supports listing offsets by max timestamp (KIP-734) + // Assume message format version is 3.0 (KIP-724) + KAFKA_3_0_IV1, + // Adds topic IDs to Fetch requests/responses (KIP-516) + KAFKA_3_1_IV0 + ) + + // Map keys are the union of the short and full versions + private val versionMap: Map[String, ApiVersion] = + allVersions.map(v => v.version -> v).toMap ++ allVersions.groupBy(_.shortVersion).map { case (k, v) => k -> v.last } + + /** + * Return an `ApiVersion` instance for `versionString`, which can be in a variety of formats (e.g. "0.8.0", "0.8.0.x", + * "0.10.0", "0.10.0-IV1"). `IllegalArgumentException` is thrown if `versionString` cannot be mapped to an `ApiVersion`. + */ + def apply(versionString: String): ApiVersion = { + val versionSegments = versionString.split('.').toSeq + val numSegments = if (versionString.startsWith("0.")) 3 else 2 + val key = versionSegments.take(numSegments).mkString(".") + versionMap.getOrElse(key, throw new IllegalArgumentException(s"Version `$versionString` is not a valid version")) + } + + val latestVersion: ApiVersion = allVersions.last + + def isTruncationOnFetchSupported(version: ApiVersion): Boolean = version >= KAFKA_2_7_IV1 + + /** + * Return the minimum `ApiVersion` that supports `RecordVersion`. 
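+ *
+ * For example:
+ * {{{
+ * ApiVersion.minSupportedFor(RecordVersion.V1) // KAFKA_0_10_0_IV0, the first version using message format v1
+ * ApiVersion.minSupportedFor(RecordVersion.V2) // KAFKA_0_11_0_IV0, the first version using message format v2
+ * }}}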
+ */ + def minSupportedFor(recordVersion: RecordVersion): ApiVersion = { + recordVersion match { + case RecordVersion.V0 => KAFKA_0_8_0 + case RecordVersion.V1 => KAFKA_0_10_0_IV0 + case RecordVersion.V2 => KAFKA_0_11_0_IV0 + case _ => throw new IllegalArgumentException(s"Invalid message format version $recordVersion") + } + } + + def apiVersionsResponse( + throttleTimeMs: Int, + minRecordVersion: RecordVersion, + latestSupportedFeatures: Features[SupportedVersionRange], + controllerApiVersions: Option[NodeApiVersions], + listenerType: ListenerType + ): ApiVersionsResponse = { + apiVersionsResponse( + throttleTimeMs, + minRecordVersion, + latestSupportedFeatures, + Features.emptyFinalizedFeatures, + ApiVersionsResponse.UNKNOWN_FINALIZED_FEATURES_EPOCH, + controllerApiVersions, + listenerType + ) + } + + def apiVersionsResponse( + throttleTimeMs: Int, + minRecordVersion: RecordVersion, + latestSupportedFeatures: Features[SupportedVersionRange], + finalizedFeatures: Features[FinalizedVersionRange], + finalizedFeaturesEpoch: Long, + controllerApiVersions: Option[NodeApiVersions], + listenerType: ListenerType + ): ApiVersionsResponse = { + val apiKeys = controllerApiVersions match { + case None => ApiVersionsResponse.filterApis(minRecordVersion, listenerType) + case Some(controllerApiVersion) => ApiVersionsResponse.intersectForwardableApis( + listenerType, minRecordVersion, controllerApiVersion.allSupportedApiVersions()) + } + + ApiVersionsResponse.createApiVersionsResponse( + throttleTimeMs, + apiKeys, + latestSupportedFeatures, + finalizedFeatures, + finalizedFeaturesEpoch + ) + } +} + +sealed trait ApiVersion extends Ordered[ApiVersion] { + def version: String + def shortVersion: String + def recordVersion: RecordVersion + def id: Int + + def isAlterIsrSupported: Boolean = this >= KAFKA_2_7_IV2 + + def isAllocateProducerIdsSupported: Boolean = this >= KAFKA_3_0_IV0 + + override def compare(that: ApiVersion): Int = + ApiVersion.orderingByVersion.compare(this, that) + + override def toString: String = version +} + +/** + * For versions before 0.10.0, `version` and `shortVersion` were the same. + */ +sealed trait LegacyApiVersion extends ApiVersion { + def version = shortVersion +} + +/** + * From 0.10.0 onwards, each version has a sub-version. For example, IV0 is the sub-version of 0.10.0-IV0. 
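+ * For example:
+ * {{{
+ * KAFKA_2_8_IV1.shortVersion // "2.8"
+ * KAFKA_2_8_IV1.version      // "2.8-IV1"
+ * }}}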
+ */ +sealed trait DefaultApiVersion extends ApiVersion { + lazy val version = shortVersion + "-" + subVersion + protected def subVersion: String +} + +// Keep the IDs in order of versions +case object KAFKA_0_8_0 extends LegacyApiVersion { + val shortVersion = "0.8.0" + val recordVersion = RecordVersion.V0 + val id: Int = 0 +} + +case object KAFKA_0_8_1 extends LegacyApiVersion { + val shortVersion = "0.8.1" + val recordVersion = RecordVersion.V0 + val id: Int = 1 +} + +case object KAFKA_0_8_2 extends LegacyApiVersion { + val shortVersion = "0.8.2" + val recordVersion = RecordVersion.V0 + val id: Int = 2 +} + +case object KAFKA_0_9_0 extends LegacyApiVersion { + val shortVersion = "0.9.0" + val subVersion = "" + val recordVersion = RecordVersion.V0 + val id: Int = 3 +} + +case object KAFKA_0_10_0_IV0 extends DefaultApiVersion { + val shortVersion = "0.10.0" + val subVersion = "IV0" + val recordVersion = RecordVersion.V1 + val id: Int = 4 +} + +case object KAFKA_0_10_0_IV1 extends DefaultApiVersion { + val shortVersion = "0.10.0" + val subVersion = "IV1" + val recordVersion = RecordVersion.V1 + val id: Int = 5 +} + +case object KAFKA_0_10_1_IV0 extends DefaultApiVersion { + val shortVersion = "0.10.1" + val subVersion = "IV0" + val recordVersion = RecordVersion.V1 + val id: Int = 6 +} + +case object KAFKA_0_10_1_IV1 extends DefaultApiVersion { + val shortVersion = "0.10.1" + val subVersion = "IV1" + val recordVersion = RecordVersion.V1 + val id: Int = 7 +} + +case object KAFKA_0_10_1_IV2 extends DefaultApiVersion { + val shortVersion = "0.10.1" + val subVersion = "IV2" + val recordVersion = RecordVersion.V1 + val id: Int = 8 +} + +case object KAFKA_0_10_2_IV0 extends DefaultApiVersion { + val shortVersion = "0.10.2" + val subVersion = "IV0" + val recordVersion = RecordVersion.V1 + val id: Int = 9 +} + +case object KAFKA_0_11_0_IV0 extends DefaultApiVersion { + val shortVersion = "0.11.0" + val subVersion = "IV0" + val recordVersion = RecordVersion.V2 + val id: Int = 10 +} + +case object KAFKA_0_11_0_IV1 extends DefaultApiVersion { + val shortVersion = "0.11.0" + val subVersion = "IV1" + val recordVersion = RecordVersion.V2 + val id: Int = 11 +} + +case object KAFKA_0_11_0_IV2 extends DefaultApiVersion { + val shortVersion = "0.11.0" + val subVersion = "IV2" + val recordVersion = RecordVersion.V2 + val id: Int = 12 +} + +case object KAFKA_1_0_IV0 extends DefaultApiVersion { + val shortVersion = "1.0" + val subVersion = "IV0" + val recordVersion = RecordVersion.V2 + val id: Int = 13 +} + +case object KAFKA_1_1_IV0 extends DefaultApiVersion { + val shortVersion = "1.1" + val subVersion = "IV0" + val recordVersion = RecordVersion.V2 + val id: Int = 14 +} + +case object KAFKA_2_0_IV0 extends DefaultApiVersion { + val shortVersion: String = "2.0" + val subVersion = "IV0" + val recordVersion = RecordVersion.V2 + val id: Int = 15 +} + +case object KAFKA_2_0_IV1 extends DefaultApiVersion { + val shortVersion: String = "2.0" + val subVersion = "IV1" + val recordVersion = RecordVersion.V2 + val id: Int = 16 +} + +case object KAFKA_2_1_IV0 extends DefaultApiVersion { + val shortVersion: String = "2.1" + val subVersion = "IV0" + val recordVersion = RecordVersion.V2 + val id: Int = 17 +} + +case object KAFKA_2_1_IV1 extends DefaultApiVersion { + val shortVersion: String = "2.1" + val subVersion = "IV1" + val recordVersion = RecordVersion.V2 + val id: Int = 18 +} + +case object KAFKA_2_1_IV2 extends DefaultApiVersion { + val shortVersion: String = "2.1" + val subVersion = "IV2" + val recordVersion = 
RecordVersion.V2 + val id: Int = 19 +} + +case object KAFKA_2_2_IV0 extends DefaultApiVersion { + val shortVersion: String = "2.2" + val subVersion = "IV0" + val recordVersion = RecordVersion.V2 + val id: Int = 20 +} + +case object KAFKA_2_2_IV1 extends DefaultApiVersion { + val shortVersion: String = "2.2" + val subVersion = "IV1" + val recordVersion = RecordVersion.V2 + val id: Int = 21 +} + +case object KAFKA_2_3_IV0 extends DefaultApiVersion { + val shortVersion: String = "2.3" + val subVersion = "IV0" + val recordVersion = RecordVersion.V2 + val id: Int = 22 +} + +case object KAFKA_2_3_IV1 extends DefaultApiVersion { + val shortVersion: String = "2.3" + val subVersion = "IV1" + val recordVersion = RecordVersion.V2 + val id: Int = 23 +} + +case object KAFKA_2_4_IV0 extends DefaultApiVersion { + val shortVersion: String = "2.4" + val subVersion = "IV0" + val recordVersion = RecordVersion.V2 + val id: Int = 24 +} + +case object KAFKA_2_4_IV1 extends DefaultApiVersion { + val shortVersion: String = "2.4" + val subVersion = "IV1" + val recordVersion = RecordVersion.V2 + val id: Int = 25 +} + +case object KAFKA_2_5_IV0 extends DefaultApiVersion { + val shortVersion: String = "2.5" + val subVersion = "IV0" + val recordVersion = RecordVersion.V2 + val id: Int = 26 +} + +case object KAFKA_2_6_IV0 extends DefaultApiVersion { + val shortVersion: String = "2.6" + val subVersion = "IV0" + val recordVersion = RecordVersion.V2 + val id: Int = 27 +} + +case object KAFKA_2_7_IV0 extends DefaultApiVersion { + val shortVersion: String = "2.7" + val subVersion = "IV0" + val recordVersion = RecordVersion.V2 + val id: Int = 28 +} + +case object KAFKA_2_7_IV1 extends DefaultApiVersion { + val shortVersion: String = "2.7" + val subVersion = "IV1" + val recordVersion = RecordVersion.V2 + val id: Int = 29 +} + +case object KAFKA_2_7_IV2 extends DefaultApiVersion { + val shortVersion: String = "2.7" + val subVersion = "IV2" + val recordVersion = RecordVersion.V2 + val id: Int = 30 +} + +case object KAFKA_2_8_IV0 extends DefaultApiVersion { + val shortVersion: String = "2.8" + val subVersion = "IV0" + val recordVersion = RecordVersion.V2 + val id: Int = 31 +} + +case object KAFKA_2_8_IV1 extends DefaultApiVersion { + val shortVersion: String = "2.8" + val subVersion = "IV1" + val recordVersion = RecordVersion.V2 + val id: Int = 32 +} + +case object KAFKA_3_0_IV0 extends DefaultApiVersion { + val shortVersion: String = "3.0" + val subVersion = "IV0" + val recordVersion = RecordVersion.V2 + val id: Int = 33 +} + +case object KAFKA_3_0_IV1 extends DefaultApiVersion { + val shortVersion: String = "3.0" + val subVersion = "IV1" + val recordVersion = RecordVersion.V2 + val id: Int = 34 +} + +case object KAFKA_3_1_IV0 extends DefaultApiVersion { + val shortVersion: String = "3.1" + val subVersion = "IV0" + val recordVersion = RecordVersion.V2 + val id: Int = 35 +} + +object ApiVersionValidator extends Validator { + + override def ensureValid(name: String, value: Any): Unit = { + try { + ApiVersion(value.toString) + } catch { + case e: IllegalArgumentException => throw new ConfigException(name, value.toString, e.getMessage) + } + } + + override def toString: String = "[" + ApiVersion.allVersions.map(_.version).distinct.mkString(", ") + "]" +} diff --git a/core/src/main/scala/kafka/api/LeaderAndIsr.scala b/core/src/main/scala/kafka/api/LeaderAndIsr.scala new file mode 100644 index 0000000..05952aa --- /dev/null +++ b/core/src/main/scala/kafka/api/LeaderAndIsr.scala @@ -0,0 +1,62 @@ +/** + * Licensed to the Apache 
Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.api + +object LeaderAndIsr { + val initialLeaderEpoch: Int = 0 + val initialZKVersion: Int = 0 + val NoLeader: Int = -1 + val NoEpoch: Int = -1 + val LeaderDuringDelete: Int = -2 + val EpochDuringDelete: Int = -2 + + def apply(leader: Int, isr: List[Int]): LeaderAndIsr = LeaderAndIsr(leader, initialLeaderEpoch, isr, initialZKVersion) + + def duringDelete(isr: List[Int]): LeaderAndIsr = LeaderAndIsr(LeaderDuringDelete, isr) +} + +case class LeaderAndIsr(leader: Int, + leaderEpoch: Int, + isr: List[Int], + zkVersion: Int) { + def withZkVersion(zkVersion: Int) = copy(zkVersion = zkVersion) + + def newLeader(leader: Int) = newLeaderAndIsr(leader, isr) + + def newLeaderAndIsr(leader: Int, isr: List[Int]) = LeaderAndIsr(leader, leaderEpoch + 1, isr, zkVersion) + + def newEpochAndZkVersion = newLeaderAndIsr(leader, isr) + + def leaderOpt: Option[Int] = { + if (leader == LeaderAndIsr.NoLeader) None else Some(leader) + } + + def equalsIgnoreZk(other: LeaderAndIsr): Boolean = { + if (this == other) { + true + } else if (other == null) { + false + } else { + leader == other.leader && leaderEpoch == other.leaderEpoch && isr.equals(other.isr) + } + } + + override def toString: String = { + s"LeaderAndIsr(leader=$leader, leaderEpoch=$leaderEpoch, isr=$isr, zkVersion=$zkVersion)" + } +} diff --git a/core/src/main/scala/kafka/api/Request.scala b/core/src/main/scala/kafka/api/Request.scala new file mode 100644 index 0000000..653b5f6 --- /dev/null +++ b/core/src/main/scala/kafka/api/Request.scala @@ -0,0 +1,37 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.api + +object Request { + val OrdinaryConsumerId: Int = -1 + val DebuggingConsumerId: Int = -2 + val FutureLocalReplicaId: Int = -3 + + // Broker ids are non-negative int. 
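+ // Negative values are sentinels, e.g. describeReplicaId(OrdinaryConsumerId) yields "consumer"
+ // while describeReplicaId(2) yields "replica [2]".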
+ def isValidBrokerId(brokerId: Int): Boolean = brokerId >= 0 + + def describeReplicaId(replicaId: Int): String = { + replicaId match { + case OrdinaryConsumerId => "consumer" + case DebuggingConsumerId => "debug consumer" + case FutureLocalReplicaId => "future local replica" + case id if isValidBrokerId(id) => s"replica [$id]" + case id => s"invalid replica [$id]" + } + } +} diff --git a/core/src/main/scala/kafka/api/package.scala b/core/src/main/scala/kafka/api/package.scala new file mode 100644 index 0000000..e0678f8 --- /dev/null +++ b/core/src/main/scala/kafka/api/package.scala @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package kafka + +import org.apache.kafka.common.ElectionType +import org.apache.kafka.common.TopicPartition +import org.apache.kafka.common.requests.ElectLeadersRequest +import scala.jdk.CollectionConverters._ + +package object api { + implicit final class ElectLeadersRequestOps(val self: ElectLeadersRequest) extends AnyVal { + def topicPartitions: Set[TopicPartition] = { + if (self.data.topicPartitions == null) { + Set.empty + } else { + self.data.topicPartitions.asScala.iterator.flatMap { topicPartition => + topicPartition.partitions.asScala.map { partitionId => + new TopicPartition(topicPartition.topic, partitionId) + } + }.toSet + } + } + + def electionType: ElectionType = { + if (self.version == 0) { + ElectionType.PREFERRED + } else { + ElectionType.valueOf(self.data.electionType) + } + } + } +} diff --git a/core/src/main/scala/kafka/cluster/Broker.scala b/core/src/main/scala/kafka/cluster/Broker.scala new file mode 100755 index 0000000..657d89b --- /dev/null +++ b/core/src/main/scala/kafka/cluster/Broker.scala @@ -0,0 +1,98 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package kafka.cluster + +import java.util + +import kafka.common.BrokerEndPointNotAvailableException +import kafka.server.KafkaConfig +import org.apache.kafka.common.feature.{Features, SupportedVersionRange} +import org.apache.kafka.common.feature.Features._ +import org.apache.kafka.common.{ClusterResource, Endpoint, Node} +import org.apache.kafka.common.network.ListenerName +import org.apache.kafka.common.security.auth.SecurityProtocol +import org.apache.kafka.server.authorizer.AuthorizerServerInfo + +import scala.collection.Seq +import scala.jdk.CollectionConverters._ + +object Broker { + private[kafka] case class ServerInfo(clusterResource: ClusterResource, + brokerId: Int, + endpoints: util.List[Endpoint], + interBrokerEndpoint: Endpoint) extends AuthorizerServerInfo + + def apply(id: Int, endPoints: Seq[EndPoint], rack: Option[String]): Broker = { + new Broker(id, endPoints, rack, emptySupportedFeatures) + } +} + +/** + * A Kafka broker. + * + * @param id a broker id + * @param endPoints a collection of EndPoint. Each end-point is (host, port, listener name, security protocol). + * @param rack an optional rack + * @param features supported features + */ +case class Broker(id: Int, endPoints: Seq[EndPoint], rack: Option[String], features: Features[SupportedVersionRange]) { + + private val endPointsMap = endPoints.map { endPoint => + endPoint.listenerName -> endPoint + }.toMap + + if (endPointsMap.size != endPoints.size) + throw new IllegalArgumentException(s"There is more than one end point with the same listener name: ${endPoints.mkString(",")}") + + override def toString: String = + s"$id : ${endPointsMap.values.mkString("(",",",")")} : ${rack.orNull} : $features" + + def this(id: Int, host: String, port: Int, listenerName: ListenerName, protocol: SecurityProtocol) = { + this(id, Seq(EndPoint(host, port, listenerName, protocol)), None, emptySupportedFeatures) + } + + def this(bep: BrokerEndPoint, listenerName: ListenerName, protocol: SecurityProtocol) = { + this(bep.id, bep.host, bep.port, listenerName, protocol) + } + + def node(listenerName: ListenerName): Node = + getNode(listenerName).getOrElse { + throw new BrokerEndPointNotAvailableException(s"End point with listener name ${listenerName.value} not found " + + s"for broker $id") + } + + def getNode(listenerName: ListenerName): Option[Node] = + endPointsMap.get(listenerName).map(endpoint => new Node(id, endpoint.host, endpoint.port, rack.orNull)) + + def brokerEndPoint(listenerName: ListenerName): BrokerEndPoint = { + val endpoint = endPoint(listenerName) + new BrokerEndPoint(id, endpoint.host, endpoint.port) + } + + def endPoint(listenerName: ListenerName): EndPoint = { + endPointsMap.getOrElse(listenerName, + throw new BrokerEndPointNotAvailableException(s"End point with listener name ${listenerName.value} not found for broker $id")) + } + + def toServerInfo(clusterId: String, config: KafkaConfig): AuthorizerServerInfo = { + val clusterResource: ClusterResource = new ClusterResource(clusterId) + val interBrokerEndpoint: Endpoint = endPoint(config.interBrokerListenerName).toJava + val brokerEndpoints: util.List[Endpoint] = endPoints.toList.map(_.toJava).asJava + Broker.ServerInfo(clusterResource, id, brokerEndpoints, interBrokerEndpoint) + } +} diff --git a/core/src/main/scala/kafka/cluster/BrokerEndPoint.scala b/core/src/main/scala/kafka/cluster/BrokerEndPoint.scala new file mode 100644 index 0000000..b2b36af --- /dev/null +++ b/core/src/main/scala/kafka/cluster/BrokerEndPoint.scala @@ -0,0 +1,83 @@ +/** + * Licensed to 
the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package kafka.cluster + +import java.nio.ByteBuffer + +import kafka.api.ApiUtils._ +import org.apache.kafka.common.KafkaException +import org.apache.kafka.common.utils.Utils._ + +object BrokerEndPoint { + + private val uriParseExp = """\[?([0-9a-zA-Z\-%._:]*)\]?:([0-9]+)""".r + + /** + * BrokerEndPoint URI is host:port or [ipv6_host]:port + * Note that unlike EndPoint (or listener) this URI has no security information. + */ + def parseHostPort(connectionString: String): Option[(String, Int)] = { + connectionString match { + case uriParseExp(host, port) => try Some(host, port.toInt) catch { case _: NumberFormatException => None } + case _ => None + } + } + + /** + * BrokerEndPoint URI is host:port or [ipv6_host]:port + * Note that unlike EndPoint (or listener) this URI has no security information. + */ + def createBrokerEndPoint(brokerId: Int, connectionString: String): BrokerEndPoint = { + parseHostPort(connectionString).map { case (host, port) => new BrokerEndPoint(brokerId, host, port) }.getOrElse { + throw new KafkaException("Unable to parse " + connectionString + " to a broker endpoint") + } + } + + def readFrom(buffer: ByteBuffer): BrokerEndPoint = { + val brokerId = buffer.getInt() + val host = readShortString(buffer) + val port = buffer.getInt() + BrokerEndPoint(brokerId, host, port) + } +} + +/** + * BrokerEndpoint is used to connect to specific host:port pair. + * It is typically used by clients (or brokers when connecting to other brokers) + * and contains no information about the security protocol used on the connection. + * Clients should know which security protocol to use from configuration. + * This allows us to keep the wire protocol with the clients unchanged where the protocol is not needed. + */ +case class BrokerEndPoint(id: Int, host: String, port: Int) { + + def connectionString(): String = formatAddress(host, port) + + def writeTo(buffer: ByteBuffer): Unit = { + buffer.putInt(id) + writeShortString(buffer, host) + buffer.putInt(port) + } + + def sizeInBytes: Int = + 4 + /* broker Id */ + 4 + /* port */ + shortStringLength(host) + + override def toString: String = { + s"BrokerEndPoint(id=$id, host=$host:$port)" + } +} diff --git a/core/src/main/scala/kafka/cluster/EndPoint.scala b/core/src/main/scala/kafka/cluster/EndPoint.scala new file mode 100644 index 0000000..3e84f9e --- /dev/null +++ b/core/src/main/scala/kafka/cluster/EndPoint.scala @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.cluster + +import org.apache.kafka.common.{KafkaException, Endpoint => JEndpoint} +import org.apache.kafka.common.network.ListenerName +import org.apache.kafka.common.security.auth.SecurityProtocol +import org.apache.kafka.common.utils.Utils + +import java.util.Locale +import scala.collection.Map + +object EndPoint { + + private val uriParseExp = """^(.*)://\[?([0-9a-zA-Z\-%._:]*)\]?:(-?[0-9]+)""".r + + private[kafka] val DefaultSecurityProtocolMap: Map[ListenerName, SecurityProtocol] = + SecurityProtocol.values.map(sp => ListenerName.forSecurityProtocol(sp) -> sp).toMap + + /** + * Create EndPoint object from `connectionString` and optional `securityProtocolMap`. If the latter is not provided, + * we fallback to the default behaviour where listener names are the same as security protocols. + * + * @param connectionString the format is listener_name://host:port or listener_name://[ipv6 host]:port + * for example: PLAINTEXT://myhost:9092, CLIENT://myhost:9092 or REPLICATION://[::1]:9092 + * Host can be empty (PLAINTEXT://:9092) in which case we'll bind to default interface + * Negative ports are also accepted, since they are used in some unit tests + */ + def createEndPoint(connectionString: String, securityProtocolMap: Option[Map[ListenerName, SecurityProtocol]]): EndPoint = { + val protocolMap = securityProtocolMap.getOrElse(DefaultSecurityProtocolMap) + + def securityProtocol(listenerName: ListenerName): SecurityProtocol = + protocolMap.getOrElse(listenerName, + throw new IllegalArgumentException(s"No security protocol defined for listener ${listenerName.value}")) + + connectionString match { + case uriParseExp(listenerNameString, "", port) => + val listenerName = ListenerName.normalised(listenerNameString) + new EndPoint(null, port.toInt, listenerName, securityProtocol(listenerName)) + case uriParseExp(listenerNameString, host, port) => + val listenerName = ListenerName.normalised(listenerNameString) + new EndPoint(host, port.toInt, listenerName, securityProtocol(listenerName)) + case _ => throw new KafkaException(s"Unable to parse $connectionString to a broker endpoint") + } + } + + def parseListenerName(connectionString: String): String = { + connectionString match { + case uriParseExp(listenerNameString, _, _) => listenerNameString.toUpperCase(Locale.ROOT) + case _ => throw new KafkaException(s"Unable to parse a listener name from $connectionString") + } + } +} + +/** + * Part of the broker definition - matching host/port pair to a protocol + */ +case class EndPoint(host: String, port: Int, listenerName: ListenerName, securityProtocol: SecurityProtocol) { + def connectionString: String = { + val hostport = + if (host == null) + ":"+port + else + Utils.formatAddress(host, port) + listenerName.value + "://" + hostport + } + + def toJava: JEndpoint = { + new JEndpoint(listenerName.value, securityProtocol, host, port) + } +} diff --git a/core/src/main/scala/kafka/cluster/Partition.scala 
b/core/src/main/scala/kafka/cluster/Partition.scala new file mode 100755 index 0000000..b8e0ce4 --- /dev/null +++ b/core/src/main/scala/kafka/cluster/Partition.scala @@ -0,0 +1,1484 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package kafka.cluster + +import java.util.concurrent.locks.ReentrantReadWriteLock +import java.util.Optional +import java.util.concurrent.CompletableFuture + +import kafka.api.{ApiVersion, LeaderAndIsr} +import kafka.common.UnexpectedAppendOffsetException +import kafka.controller.{KafkaController, StateChangeLogger} +import kafka.log._ +import kafka.metrics.KafkaMetricsGroup +import kafka.server._ +import kafka.server.checkpoints.OffsetCheckpoints +import kafka.utils.CoreUtils.{inReadLock, inWriteLock} +import kafka.utils._ +import kafka.zookeeper.ZooKeeperClientException +import org.apache.kafka.common.errors._ +import org.apache.kafka.common.message.{DescribeProducersResponseData, FetchResponseData} +import org.apache.kafka.common.message.LeaderAndIsrRequestData.LeaderAndIsrPartitionState +import org.apache.kafka.common.message.OffsetForLeaderEpochResponseData.EpochEndOffset +import org.apache.kafka.common.protocol.Errors +import org.apache.kafka.common.record.FileRecords.TimestampAndOffset +import org.apache.kafka.common.record.{MemoryRecords, RecordBatch} +import org.apache.kafka.common.requests._ +import org.apache.kafka.common.requests.OffsetsForLeaderEpochResponse.{UNDEFINED_EPOCH, UNDEFINED_EPOCH_OFFSET} +import org.apache.kafka.common.utils.Time +import org.apache.kafka.common.{IsolationLevel, TopicPartition, Uuid} + +import scala.collection.{Map, Seq} +import scala.jdk.CollectionConverters._ + +trait IsrChangeListener { + def markExpand(): Unit + def markShrink(): Unit + def markFailed(): Unit +} + +class DelayedOperations(topicPartition: TopicPartition, + produce: DelayedOperationPurgatory[DelayedProduce], + fetch: DelayedOperationPurgatory[DelayedFetch], + deleteRecords: DelayedOperationPurgatory[DelayedDeleteRecords]) { + + def checkAndCompleteAll(): Unit = { + val requestKey = TopicPartitionOperationKey(topicPartition) + fetch.checkAndComplete(requestKey) + produce.checkAndComplete(requestKey) + deleteRecords.checkAndComplete(requestKey) + } + + def numDelayedDelete: Int = deleteRecords.numDelayed +} + +object Partition extends KafkaMetricsGroup { + def apply(topicPartition: TopicPartition, + time: Time, + replicaManager: ReplicaManager): Partition = { + + val isrChangeListener = new IsrChangeListener { + override def markExpand(): Unit = { + replicaManager.isrExpandRate.mark() + } + + override def markShrink(): Unit = { + replicaManager.isrShrinkRate.mark() + } + + override def markFailed(): Unit = replicaManager.failedIsrUpdatesRate.mark() + } + + val delayedOperations = new DelayedOperations( 
+ topicPartition, + replicaManager.delayedProducePurgatory, + replicaManager.delayedFetchPurgatory, + replicaManager.delayedDeleteRecordsPurgatory) + + new Partition(topicPartition, + replicaLagTimeMaxMs = replicaManager.config.replicaLagTimeMaxMs, + interBrokerProtocolVersion = replicaManager.config.interBrokerProtocolVersion, + localBrokerId = replicaManager.config.brokerId, + time = time, + isrChangeListener = isrChangeListener, + delayedOperations = delayedOperations, + metadataCache = replicaManager.metadataCache, + logManager = replicaManager.logManager, + alterIsrManager = replicaManager.alterIsrManager) + } + + def removeMetrics(topicPartition: TopicPartition): Unit = { + val tags = Map("topic" -> topicPartition.topic, "partition" -> topicPartition.partition.toString) + removeMetric("UnderReplicated", tags) + removeMetric("UnderMinIsr", tags) + removeMetric("InSyncReplicasCount", tags) + removeMetric("ReplicasCount", tags) + removeMetric("LastStableOffsetLag", tags) + removeMetric("AtMinIsr", tags) + } +} + + +sealed trait AssignmentState { + def replicas: Seq[Int] + def replicationFactor: Int = replicas.size + def isAddingReplica(brokerId: Int): Boolean = false +} + +case class OngoingReassignmentState(addingReplicas: Seq[Int], + removingReplicas: Seq[Int], + replicas: Seq[Int]) extends AssignmentState { + + override def replicationFactor: Int = replicas.diff(addingReplicas).size // keep the size of the original replicas + override def isAddingReplica(replicaId: Int): Boolean = addingReplicas.contains(replicaId) +} + +case class SimpleAssignmentState(replicas: Seq[Int]) extends AssignmentState + + + +sealed trait IsrState { + /** + * Includes only the in-sync replicas which have been committed to ZK. + */ + def isr: Set[Int] + + /** + * This set may include un-committed ISR members following an expansion. This "effective" ISR is used for advancing + * the high watermark as well as determining which replicas are required for acks=all produce requests. + * + * Only applicable as of IBP 2.7-IV2, for older versions this will return the committed ISR + * + */ + def maximalIsr: Set[Int] + + /** + * Indicates if we have an AlterIsr request inflight. + */ + def isInflight: Boolean +} + +sealed trait PendingIsrChange extends IsrState { + def sentLeaderAndIsr: LeaderAndIsr +} + +case class PendingExpandIsr( + isr: Set[Int], + newInSyncReplicaId: Int, + sentLeaderAndIsr: LeaderAndIsr +) extends PendingIsrChange { + val maximalIsr = isr + newInSyncReplicaId + val isInflight = true + + override def toString: String = { + s"PendingExpandIsr(isr=$isr" + + s", newInSyncReplicaId=$newInSyncReplicaId" + + s", sentLeaderAndIsr=$sentLeaderAndIsr" + + ")" + } +} + +case class PendingShrinkIsr( + isr: Set[Int], + outOfSyncReplicaIds: Set[Int], + sentLeaderAndIsr: LeaderAndIsr +) extends PendingIsrChange { + val maximalIsr = isr + val isInflight = true + + override def toString: String = { + s"PendingShrinkIsr(isr=$isr" + + s", outOfSyncReplicaIds=$outOfSyncReplicaIds" + + s", sentLeaderAndIsr=$sentLeaderAndIsr" + + ")" + } +} + +case class CommittedIsr( + isr: Set[Int] +) extends IsrState { + val maximalIsr = isr + val isInflight = false + + override def toString: String = { + s"CommittedIsr(isr=$isr" + + ")" + } +} + + +/** + * Data structure that represents a topic partition. The leader maintains the AR, ISR, CUR, RAR + * + * Concurrency notes: + * 1) Partition is thread-safe. 
Operations on partitions may be invoked concurrently from different + * request handler threads + * 2) ISR updates are synchronized using a read-write lock. Read lock is used to check if an update + * is required to avoid acquiring write lock in the common case of replica fetch when no update + * is performed. ISR update condition is checked a second time under write lock before performing + * the update + * 3) Various other operations like leader changes are processed while holding the ISR write lock. + * This can introduce delays in produce and replica fetch requests, but these operations are typically + * infrequent. + * 4) HW updates are synchronized using ISR read lock. @Log lock is acquired during the update with + * locking order Partition lock -> Log lock. + * 5) lock is used to prevent the follower replica from being updated while ReplicaAlterDirThread is + * executing maybeReplaceCurrentWithFutureReplica() to replace follower replica with the future replica. + */ +class Partition(val topicPartition: TopicPartition, + val replicaLagTimeMaxMs: Long, + interBrokerProtocolVersion: ApiVersion, + localBrokerId: Int, + time: Time, + isrChangeListener: IsrChangeListener, + delayedOperations: DelayedOperations, + metadataCache: MetadataCache, + logManager: LogManager, + alterIsrManager: AlterIsrManager) extends Logging with KafkaMetricsGroup { + + def topic: String = topicPartition.topic + def partitionId: Int = topicPartition.partition + + private val stateChangeLogger = new StateChangeLogger(localBrokerId, inControllerContext = false, None) + private val remoteReplicasMap = new Pool[Int, Replica] + // The read lock is only required when multiple reads are executed and needs to be in a consistent manner + private val leaderIsrUpdateLock = new ReentrantReadWriteLock + + // lock to prevent the follower replica log update while checking if the log dir could be replaced with future log. + private val futureLogLock = new Object() + private var zkVersion: Int = LeaderAndIsr.initialZKVersion + @volatile private var leaderEpoch: Int = LeaderAndIsr.initialLeaderEpoch - 1 + // start offset for 'leaderEpoch' above (leader epoch of the current leader for this partition), + // defined when this broker is leader for partition + @volatile private var leaderEpochStartOffsetOpt: Option[Long] = None + @volatile var leaderReplicaIdOpt: Option[Int] = None + @volatile private[cluster] var isrState: IsrState = CommittedIsr(Set.empty) + @volatile var assignmentState: AssignmentState = SimpleAssignmentState(Seq.empty) + + // Logs belonging to this partition. Majority of time it will be only one log, but if log directory + // is getting changed (as a result of ReplicaAlterLogDirs command), we may have two logs until copy + // completes and a switch to new location is performed. + // log and futureLog variables defined below are used to capture this + @volatile var log: Option[UnifiedLog] = None + // If ReplicaAlterLogDir command is in progress, this is future location of the log + @volatile var futureLog: Option[UnifiedLog] = None + + /* Epoch of the controller that last changed the leader. This needs to be initialized correctly upon broker startup. + * One way of doing that is through the controller's start replica state change command. When a new broker starts up + * the controller sends it a start replica command containing the leader for each partition that the broker hosts. + * In addition to the leader, the controller can also send the epoch of the controller that elected the leader for + * each partition. 
*/ + private var controllerEpoch: Int = KafkaController.InitialControllerEpoch + this.logIdent = s"[Partition $topicPartition broker=$localBrokerId] " + + private val tags = Map("topic" -> topic, "partition" -> partitionId.toString) + + newGauge("UnderReplicated", () => if (isUnderReplicated) 1 else 0, tags) + newGauge("InSyncReplicasCount", () => if (isLeader) isrState.isr.size else 0, tags) + newGauge("UnderMinIsr", () => if (isUnderMinIsr) 1 else 0, tags) + newGauge("AtMinIsr", () => if (isAtMinIsr) 1 else 0, tags) + newGauge("ReplicasCount", () => if (isLeader) assignmentState.replicationFactor else 0, tags) + newGauge("LastStableOffsetLag", () => log.map(_.lastStableOffsetLag).getOrElse(0), tags) + + def isUnderReplicated: Boolean = isLeader && (assignmentState.replicationFactor - isrState.isr.size) > 0 + + def isUnderMinIsr: Boolean = leaderLogIfLocal.exists { isrState.isr.size < _.config.minInSyncReplicas } + + def isAtMinIsr: Boolean = leaderLogIfLocal.exists { isrState.isr.size == _.config.minInSyncReplicas } + + def isReassigning: Boolean = assignmentState.isInstanceOf[OngoingReassignmentState] + + def isAddingLocalReplica: Boolean = assignmentState.isAddingReplica(localBrokerId) + + def isAddingReplica(replicaId: Int): Boolean = assignmentState.isAddingReplica(replicaId) + + def inSyncReplicaIds: Set[Int] = isrState.isr + + /** + * Create the future replica if 1) the current replica is not in the given log directory and 2) the future replica + * does not exist. This method assumes that the current replica has already been created. + * + * @param logDir log directory + * @param highWatermarkCheckpoints Checkpoint to load initial high watermark from + * @return true iff the future replica is created + */ + def maybeCreateFutureReplica(logDir: String, highWatermarkCheckpoints: OffsetCheckpoints): Boolean = { + // The writeLock is needed to make sure that while the caller checks the log directory of the + // current replica and the existence of the future replica, no other thread can update the log directory of the + // current replica or remove the future replica. + inWriteLock(leaderIsrUpdateLock) { + val currentLogDir = localLogOrException.parentDir + if (currentLogDir == logDir) { + info(s"Current log directory $currentLogDir is same as requested log dir $logDir. 
" + + s"Skipping future replica creation.") + false + } else { + futureLog match { + case Some(partitionFutureLog) => + val futureLogDir = partitionFutureLog.parentDir + if (futureLogDir != logDir) + throw new IllegalStateException(s"The future log dir $futureLogDir of $topicPartition is " + + s"different from the requested log dir $logDir") + false + case None => + createLogIfNotExists(isNew = false, isFutureReplica = true, highWatermarkCheckpoints, topicId) + true + } + } + } + } + + def createLogIfNotExists(isNew: Boolean, isFutureReplica: Boolean, offsetCheckpoints: OffsetCheckpoints, topicId: Option[Uuid]): Unit = { + def maybeCreate(logOpt: Option[UnifiedLog]): UnifiedLog = { + logOpt match { + case Some(log) => + trace(s"${if (isFutureReplica) "Future UnifiedLog" else "UnifiedLog"} already exists.") + if (log.topicId.isEmpty) + topicId.foreach(log.assignTopicId) + log + case None => + createLog(isNew, isFutureReplica, offsetCheckpoints, topicId) + } + } + + if (isFutureReplica) { + this.futureLog = Some(maybeCreate(this.futureLog)) + } else { + this.log = Some(maybeCreate(this.log)) + } + } + + // Visible for testing + private[cluster] def createLog(isNew: Boolean, isFutureReplica: Boolean, offsetCheckpoints: OffsetCheckpoints, topicId: Option[Uuid]): UnifiedLog = { + def updateHighWatermark(log: UnifiedLog) = { + val checkpointHighWatermark = offsetCheckpoints.fetch(log.parentDir, topicPartition).getOrElse { + info(s"No checkpointed highwatermark is found for partition $topicPartition") + 0L + } + val initialHighWatermark = log.updateHighWatermark(checkpointHighWatermark) + info(s"Log loaded for partition $topicPartition with initial high watermark $initialHighWatermark") + } + + logManager.initializingLog(topicPartition) + var maybeLog: Option[UnifiedLog] = None + try { + val log = logManager.getOrCreateLog(topicPartition, isNew, isFutureReplica, topicId) + maybeLog = Some(log) + updateHighWatermark(log) + log + } finally { + logManager.finishedInitializingLog(topicPartition, maybeLog) + } + } + + def getReplica(replicaId: Int): Option[Replica] = Option(remoteReplicasMap.get(replicaId)) + + private def checkCurrentLeaderEpoch(remoteLeaderEpochOpt: Optional[Integer]): Errors = { + if (!remoteLeaderEpochOpt.isPresent) { + Errors.NONE + } else { + val remoteLeaderEpoch = remoteLeaderEpochOpt.get + val localLeaderEpoch = leaderEpoch + if (localLeaderEpoch > remoteLeaderEpoch) + Errors.FENCED_LEADER_EPOCH + else if (localLeaderEpoch < remoteLeaderEpoch) + Errors.UNKNOWN_LEADER_EPOCH + else + Errors.NONE + } + } + + private def getLocalLog(currentLeaderEpoch: Optional[Integer], + requireLeader: Boolean): Either[UnifiedLog, Errors] = { + checkCurrentLeaderEpoch(currentLeaderEpoch) match { + case Errors.NONE => + if (requireLeader && !isLeader) { + Right(Errors.NOT_LEADER_OR_FOLLOWER) + } else { + log match { + case Some(partitionLog) => + Left(partitionLog) + case _ => + Right(Errors.NOT_LEADER_OR_FOLLOWER) + } + } + case error => + Right(error) + } + } + + def localLogOrException: UnifiedLog = log.getOrElse { + throw new NotLeaderOrFollowerException(s"Log for partition $topicPartition is not available " + + s"on broker $localBrokerId") + } + + def futureLocalLogOrException: UnifiedLog = futureLog.getOrElse { + throw new NotLeaderOrFollowerException(s"Future log for partition $topicPartition is not available " + + s"on broker $localBrokerId") + } + + def leaderLogIfLocal: Option[UnifiedLog] = { + log.filter(_ => isLeader) + } + + /** + * Returns true if this node is currently leader for 
the Partition. + */ + def isLeader: Boolean = leaderReplicaIdOpt.contains(localBrokerId) + + private def localLogWithEpochOrException(currentLeaderEpoch: Optional[Integer], + requireLeader: Boolean): UnifiedLog = { + getLocalLog(currentLeaderEpoch, requireLeader) match { + case Left(localLog) => localLog + case Right(error) => + throw error.exception(s"Failed to find ${if (requireLeader) "leader" else ""} log for " + + s"partition $topicPartition with leader epoch $currentLeaderEpoch. The current leader " + + s"is $leaderReplicaIdOpt and the current epoch $leaderEpoch") + } + } + + // Visible for testing -- Used by unit tests to set log for this partition + def setLog(log: UnifiedLog, isFutureLog: Boolean): Unit = { + if (isFutureLog) + futureLog = Some(log) + else + this.log = Some(log) + } + + /** + * @return the topic ID for the log or None if the log or the topic ID does not exist. + */ + def topicId: Option[Uuid] = { + val log = this.log.orElse(logManager.getLog(topicPartition)) + log.flatMap(_.topicId) + } + + // remoteReplicas will be called in the hot path, and must be inexpensive + def remoteReplicas: Iterable[Replica] = + remoteReplicasMap.values + + def futureReplicaDirChanged(newDestinationDir: String): Boolean = { + inReadLock(leaderIsrUpdateLock) { + futureLog.exists(_.parentDir != newDestinationDir) + } + } + + def removeFutureLocalReplica(deleteFromLogDir: Boolean = true): Unit = { + inWriteLock(leaderIsrUpdateLock) { + futureLog = None + if (deleteFromLogDir) + logManager.asyncDelete(topicPartition, isFuture = true) + } + } + + // Return true if the future replica exists and it has caught up with the current replica for this partition + // Only ReplicaAlterDirThread will call this method and ReplicaAlterDirThread should remove the partition + // from its partitionStates if this method returns true + def maybeReplaceCurrentWithFutureReplica(): Boolean = { + // lock to prevent the log append by followers while checking if the log dir could be replaced with future log. + futureLogLock.synchronized { + val localReplicaLEO = localLogOrException.logEndOffset + val futureReplicaLEO = futureLog.map(_.logEndOffset) + if (futureReplicaLEO.contains(localReplicaLEO)) { + // The write lock is needed to make sure that while ReplicaAlterDirThread checks the LEO of the + // current replica, no other thread can update LEO of the current replica via log truncation or log append operation. + inWriteLock(leaderIsrUpdateLock) { + futureLog match { + case Some(futurePartitionLog) => + if (log.exists(_.logEndOffset == futurePartitionLog.logEndOffset)) { + logManager.replaceCurrentWithFutureLog(topicPartition) + log = futureLog + removeFutureLocalReplica(false) + true + } else false + case None => + // Future replica is removed by a non-ReplicaAlterLogDirsThread before this method is called + // In this case the partition should have been removed from state of the ReplicaAlterLogDirsThread + // Return false so that ReplicaAlterLogDirsThread does not have to remove this partition from the + // state again to avoid race condition + false + } + } + } else false + } + } + + /** + * Delete the partition. Note that deleting the partition does not delete the underlying logs. + * The logs are deleted by the ReplicaManager after having deleted the partition. 
+ */ + def delete(): Unit = { + // need to hold the lock to prevent appendMessagesToLeader() from hitting I/O exceptions due to log being deleted + inWriteLock(leaderIsrUpdateLock) { + remoteReplicasMap.clear() + assignmentState = SimpleAssignmentState(Seq.empty) + log = None + futureLog = None + isrState = CommittedIsr(Set.empty) + leaderReplicaIdOpt = None + leaderEpochStartOffsetOpt = None + Partition.removeMetrics(topicPartition) + } + } + + def getLeaderEpoch: Int = this.leaderEpoch + + def getZkVersion: Int = this.zkVersion + + /** + * Make the local replica the leader by resetting LogEndOffset for remote replicas (there could be old LogEndOffset + * from the time when this broker was the leader last time) and setting the new leader and ISR. + * If the leader replica id does not change, return false to indicate the replica manager. + */ + def makeLeader(partitionState: LeaderAndIsrPartitionState, + highWatermarkCheckpoints: OffsetCheckpoints, + topicId: Option[Uuid]): Boolean = { + val (leaderHWIncremented, isNewLeader) = inWriteLock(leaderIsrUpdateLock) { + // record the epoch of the controller that made the leadership decision. This is useful while updating the isr + // to maintain the decision maker controller's epoch in the zookeeper path + controllerEpoch = partitionState.controllerEpoch + + val isr = partitionState.isr.asScala.map(_.toInt).toSet + val addingReplicas = partitionState.addingReplicas.asScala.map(_.toInt) + val removingReplicas = partitionState.removingReplicas.asScala.map(_.toInt) + + updateAssignmentAndIsr( + assignment = partitionState.replicas.asScala.map(_.toInt), + isr = isr, + addingReplicas = addingReplicas, + removingReplicas = removingReplicas + ) + try { + createLogIfNotExists(partitionState.isNew, isFutureReplica = false, highWatermarkCheckpoints, topicId) + } catch { + case e: ZooKeeperClientException => + stateChangeLogger.error(s"A ZooKeeper client exception has occurred and makeLeader will be skipping the " + + s"state change for the partition $topicPartition with leader epoch: $leaderEpoch ", e) + + return false + } + + val leaderLog = localLogOrException + val leaderEpochStartOffset = leaderLog.logEndOffset + stateChangeLogger.info(s"Leader $topicPartition starts at leader epoch ${partitionState.leaderEpoch} from " + + s"offset $leaderEpochStartOffset with high watermark ${leaderLog.highWatermark} " + + s"ISR ${isr.mkString("[", ",", "]")} addingReplicas ${addingReplicas.mkString("[", ",", "]")} " + + s"removingReplicas ${removingReplicas.mkString("[", ",", "]")}. Previous leader epoch was $leaderEpoch.") + + //We cache the leader epoch here, persisting it only if it's local (hence having a log dir) + leaderEpoch = partitionState.leaderEpoch + leaderEpochStartOffsetOpt = Some(leaderEpochStartOffset) + zkVersion = partitionState.zkVersion + + // In the case of successive leader elections in a short time period, a follower may have + // entries in its log from a later epoch than any entry in the new leader's log. In order + // to ensure that these followers can truncate to the right offset, we must cache the new + // leader epoch and the start offset since it should be larger than any epoch that a follower + // would try to query. + leaderLog.maybeAssignEpochStartOffset(leaderEpoch, leaderEpochStartOffset) + + val isNewLeader = !isLeader + val curTimeMs = time.milliseconds + // initialize lastCaughtUpTime of replicas as well as their lastFetchTimeMs and lastFetchLeaderLogEndOffset. 
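+ // Replicas already in the ISR are treated as caught up as of curTimeMs; all other remote
+ // replicas start with lastCaughtUpTimeMs = 0L.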
+ remoteReplicas.foreach { replica => + val lastCaughtUpTimeMs = if (isrState.isr.contains(replica.brokerId)) curTimeMs else 0L + replica.resetLastCaughtUpTime(leaderEpochStartOffset, curTimeMs, lastCaughtUpTimeMs) + } + + if (isNewLeader) { + // mark local replica as the leader after converting hw + leaderReplicaIdOpt = Some(localBrokerId) + // reset log end offset for remote replicas + remoteReplicas.foreach { replica => + replica.updateFetchState( + followerFetchOffsetMetadata = LogOffsetMetadata.UnknownOffsetMetadata, + followerStartOffset = UnifiedLog.UnknownOffset, + followerFetchTimeMs = 0L, + leaderEndOffset = UnifiedLog.UnknownOffset) + } + } + // we may need to increment high watermark since ISR could be down to 1 + (maybeIncrementLeaderHW(leaderLog), isNewLeader) + } + // some delayed operations may be unblocked after HW changed + if (leaderHWIncremented) + tryCompleteDelayedRequests() + isNewLeader + } + + /** + * Make the local replica the follower by setting the new leader and ISR to empty + * If the leader replica id does not change and the new epoch is equal or one + * greater (that is, no updates have been missed), return false to indicate to the + * replica manager that state is already correct and the become-follower steps can be skipped + */ + def makeFollower(partitionState: LeaderAndIsrPartitionState, + highWatermarkCheckpoints: OffsetCheckpoints, + topicId: Option[Uuid]): Boolean = { + inWriteLock(leaderIsrUpdateLock) { + val newLeaderBrokerId = partitionState.leader + val oldLeaderEpoch = leaderEpoch + // record the epoch of the controller that made the leadership decision. This is useful while updating the isr + // to maintain the decision maker controller's epoch in the zookeeper path + controllerEpoch = partitionState.controllerEpoch + + updateAssignmentAndIsr( + assignment = partitionState.replicas.asScala.iterator.map(_.toInt).toSeq, + isr = Set.empty[Int], + addingReplicas = partitionState.addingReplicas.asScala.map(_.toInt), + removingReplicas = partitionState.removingReplicas.asScala.map(_.toInt) + ) + try { + createLogIfNotExists(partitionState.isNew, isFutureReplica = false, highWatermarkCheckpoints, topicId) + } catch { + case e: ZooKeeperClientException => + stateChangeLogger.error(s"A ZooKeeper client exception has occurred. makeFollower will be skipping the " + + s"state change for the partition $topicPartition with leader epoch: $leaderEpoch.", e) + + return false + } + + val followerLog = localLogOrException + val leaderEpochEndOffset = followerLog.logEndOffset + stateChangeLogger.info(s"Follower $topicPartition starts at leader epoch ${partitionState.leaderEpoch} from " + + s"offset $leaderEpochEndOffset with high watermark ${followerLog.highWatermark}. " + + s"Previous leader epoch was $leaderEpoch.") + + leaderEpoch = partitionState.leaderEpoch + leaderEpochStartOffsetOpt = None + zkVersion = partitionState.zkVersion + + if (leaderReplicaIdOpt.contains(newLeaderBrokerId) && leaderEpoch == oldLeaderEpoch) { + false + } else { + leaderReplicaIdOpt = Some(newLeaderBrokerId) + true + } + } + } + + /** + * Update the follower's state in the leader based on the last fetch request. See + * [[Replica.updateFetchState()]] for details. 
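+   * As a side effect this may expand the ISR, advance the leader's high watermark and low watermark, and
+   * complete delayed operations that were waiting on them.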
+ * + * @return true if the follower's fetch state was updated, false if the followerId is not recognized + */ + def updateFollowerFetchState(followerId: Int, + followerFetchOffsetMetadata: LogOffsetMetadata, + followerStartOffset: Long, + followerFetchTimeMs: Long, + leaderEndOffset: Long): Boolean = { + getReplica(followerId) match { + case Some(followerReplica) => + // No need to calculate low watermark if there is no delayed DeleteRecordsRequest + val oldLeaderLW = if (delayedOperations.numDelayedDelete > 0) lowWatermarkIfLeader else -1L + val prevFollowerEndOffset = followerReplica.logEndOffset + followerReplica.updateFetchState( + followerFetchOffsetMetadata, + followerStartOffset, + followerFetchTimeMs, + leaderEndOffset) + + val newLeaderLW = if (delayedOperations.numDelayedDelete > 0) lowWatermarkIfLeader else -1L + // check if the LW of the partition has incremented + // since the replica's logStartOffset may have incremented + val leaderLWIncremented = newLeaderLW > oldLeaderLW + + // Check if this in-sync replica needs to be added to the ISR. + maybeExpandIsr(followerReplica) + + // check if the HW of the partition can now be incremented + // since the replica may already be in the ISR and its LEO has just incremented + val leaderHWIncremented = if (prevFollowerEndOffset != followerReplica.logEndOffset) { + // the leader log may be updated by ReplicaAlterLogDirsThread so the following method must be in lock of + // leaderIsrUpdateLock to prevent adding new hw to invalid log. + inReadLock(leaderIsrUpdateLock) { + leaderLogIfLocal.exists(leaderLog => maybeIncrementLeaderHW(leaderLog, followerFetchTimeMs)) + } + } else { + false + } + + // some delayed operations may be unblocked after HW or LW changed + if (leaderLWIncremented || leaderHWIncremented) + tryCompleteDelayedRequests() + + debug(s"Recorded replica $followerId log end offset (LEO) position " + + s"${followerFetchOffsetMetadata.messageOffset} and log start offset $followerStartOffset.") + true + + case None => + false + } + } + + /** + * Stores the topic partition assignment and ISR. + * It creates a new Replica object for any new remote broker. The isr parameter is + * expected to be a subset of the assignment parameter. + * + * Note: public visibility for tests. 
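+   * Replica objects for brokers that are no longer part of the assignment are removed from the remote replica map.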
+ * + * @param assignment An ordered sequence of all the broker ids that were assigned to this + * topic partition + * @param isr The set of broker ids that are known to be insync with the leader + * @param addingReplicas An ordered sequence of all broker ids that will be added to the + * assignment + * @param removingReplicas An ordered sequence of all broker ids that will be removed from + * the assignment + */ + def updateAssignmentAndIsr(assignment: Seq[Int], + isr: Set[Int], + addingReplicas: Seq[Int], + removingReplicas: Seq[Int]): Unit = { + val newRemoteReplicas = assignment.filter(_ != localBrokerId) + val removedReplicas = remoteReplicasMap.keys.filter(!newRemoteReplicas.contains(_)) + + // due to code paths accessing remoteReplicasMap without a lock, + // first add the new replicas and then remove the old ones + newRemoteReplicas.foreach(id => remoteReplicasMap.getAndMaybePut(id, new Replica(id, topicPartition))) + remoteReplicasMap.removeAll(removedReplicas) + + if (addingReplicas.nonEmpty || removingReplicas.nonEmpty) + assignmentState = OngoingReassignmentState(addingReplicas, removingReplicas, assignment) + else + assignmentState = SimpleAssignmentState(assignment) + isrState = CommittedIsr(isr) + } + + /** + * Check and maybe expand the ISR of the partition. + * A replica will be added to ISR if its LEO >= current hw of the partition and it is caught up to + * an offset within the current leader epoch. A replica must be caught up to the current leader + * epoch before it can join ISR, because otherwise, if there is committed data between current + * leader's HW and LEO, the replica may become the leader before it fetches the committed data + * and the data will be lost. + * + * Technically, a replica shouldn't be in ISR if it hasn't caught up for longer than replicaLagTimeMaxMs, + * even if its log end offset is >= HW. However, to be consistent with how the follower determines + * whether a replica is in-sync, we only check HW. + * + * This function can be triggered when a replica's LEO has incremented. + */ + private def maybeExpandIsr(followerReplica: Replica): Unit = { + val needsIsrUpdate = !isrState.isInflight && canAddReplicaToIsr(followerReplica.brokerId) && inReadLock(leaderIsrUpdateLock) { + needsExpandIsr(followerReplica) + } + if (needsIsrUpdate) { + val alterIsrUpdateOpt = inWriteLock(leaderIsrUpdateLock) { + // check if this replica needs to be added to the ISR + if (!isrState.isInflight && needsExpandIsr(followerReplica)) { + Some(prepareIsrExpand(followerReplica.brokerId)) + } else { + None + } + } + // Send the AlterIsr request outside of the LeaderAndIsr lock since the completion logic + // may increment the high watermark (and consequently complete delayed operations). 
+ alterIsrUpdateOpt.foreach(submitAlterIsr) + } + } + + private def needsExpandIsr(followerReplica: Replica): Boolean = { + canAddReplicaToIsr(followerReplica.brokerId) && isFollowerAtHighwatermark(followerReplica) + } + + private def canAddReplicaToIsr(followerReplicaId: Int): Boolean = { + val current = isrState + !current.isInflight && !current.isr.contains(followerReplicaId) + } + + private def isFollowerAtHighwatermark(followerReplica: Replica): Boolean = { + leaderLogIfLocal.exists { leaderLog => + val followerEndOffset = followerReplica.logEndOffset + followerEndOffset >= leaderLog.highWatermark && leaderEpochStartOffsetOpt.exists(followerEndOffset >= _) + } + } + + /* + * Returns a tuple where the first element is a boolean indicating whether enough replicas reached `requiredOffset` + * and the second element is an error (which would be `Errors.NONE` for no error). + * + * Note that this method will only be called if requiredAcks = -1 and we are waiting for all replicas in ISR to be + * fully caught up to the (local) leader's offset corresponding to this produce request before we acknowledge the + * produce request. + */ + def checkEnoughReplicasReachOffset(requiredOffset: Long): (Boolean, Errors) = { + leaderLogIfLocal match { + case Some(leaderLog) => + // keep the current immutable replica list reference + val curMaximalIsr = isrState.maximalIsr + + if (isTraceEnabled) { + def logEndOffsetString: ((Int, Long)) => String = { + case (brokerId, logEndOffset) => s"broker $brokerId: $logEndOffset" + } + + val curInSyncReplicaObjects = (curMaximalIsr - localBrokerId).flatMap(getReplica) + val replicaInfo = curInSyncReplicaObjects.map(replica => (replica.brokerId, replica.logEndOffset)) + val localLogInfo = (localBrokerId, localLogOrException.logEndOffset) + val (ackedReplicas, awaitingReplicas) = (replicaInfo + localLogInfo).partition { _._2 >= requiredOffset} + + trace(s"Progress awaiting ISR acks for offset $requiredOffset: " + + s"acked: ${ackedReplicas.map(logEndOffsetString)}, " + + s"awaiting ${awaitingReplicas.map(logEndOffsetString)}") + } + + val minIsr = leaderLog.config.minInSyncReplicas + if (leaderLog.highWatermark >= requiredOffset) { + /* + * The topic may be configured not to accept messages if there are not enough replicas in ISR + * in this scenario the request was already appended locally and then added to the purgatory before the ISR was shrunk + */ + if (minIsr <= curMaximalIsr.size) + (true, Errors.NONE) + else + (true, Errors.NOT_ENOUGH_REPLICAS_AFTER_APPEND) + } else + (false, Errors.NONE) + case None => + (false, Errors.NOT_LEADER_OR_FOLLOWER) + } + } + + /** + * Check and maybe increment the high watermark of the partition; + * this function can be triggered when + * + * 1. Partition ISR changed + * 2. Any replica's LEO changed + * + * The HW is determined by the smallest log end offset among all replicas that are in sync or are considered caught-up. + * This way, if a replica is considered caught-up, but its log end offset is smaller than HW, we will wait for this + * replica to catch up to the HW before advancing the HW. This helps the situation when the ISR only includes the + * leader replica and a follower tries to catch up. If we don't wait for the follower when advancing the HW, the + * follower's log end offset may keep falling behind the HW (determined by the leader's log end offset) and therefore + * will never be added to ISR. + * + * With the addition of AlterIsr, we also consider newly added replicas as part of the ISR when advancing + * the HW. 
These replicas have not yet been committed to the ISR by the controller, so we could revert to the previously
+   * committed ISR. However, adding additional replicas to the ISR makes it more restrictive and therefore safe. We call
+   * this set the "maximal" ISR. See KIP-497 for more details.
+   *
+   * Note: there is no need to acquire the leaderIsrUpdateLock here since all callers of this private API already hold that lock.
+   *
+   * @return true if the HW was incremented, and false otherwise.
+   */
+  private def maybeIncrementLeaderHW(leaderLog: UnifiedLog, curTime: Long = time.milliseconds): Boolean = {
+    // maybeIncrementLeaderHW is in the hot path, the following code is written to
+    // avoid unnecessary collection generation
+    var newHighWatermark = leaderLog.logEndOffsetMetadata
+    remoteReplicasMap.values.foreach { replica =>
+      // Note here we are using the "maximal", see explanation above
+      if (replica.logEndOffsetMetadata.messageOffset < newHighWatermark.messageOffset &&
+        (curTime - replica.lastCaughtUpTimeMs <= replicaLagTimeMaxMs || isrState.maximalIsr.contains(replica.brokerId))) {
+        newHighWatermark = replica.logEndOffsetMetadata
+      }
+    }
+
+    leaderLog.maybeIncrementHighWatermark(newHighWatermark) match {
+      case Some(oldHighWatermark) =>
+        debug(s"High watermark updated from $oldHighWatermark to $newHighWatermark")
+        true
+
+      case None =>
+        def logEndOffsetString: ((Int, LogOffsetMetadata)) => String = {
+          case (brokerId, logEndOffsetMetadata) => s"replica $brokerId: $logEndOffsetMetadata"
+        }
+
+        if (isTraceEnabled) {
+          val replicaInfo = remoteReplicas.map(replica => (replica.brokerId, replica.logEndOffsetMetadata)).toSet
+          val localLogInfo = (localBrokerId, localLogOrException.logEndOffsetMetadata)
+          trace(s"Skipping update high watermark since new hw $newHighWatermark is not larger than old value. " +
+            s"All current LEOs are ${(replicaInfo + localLogInfo).map(logEndOffsetString)}")
+        }
+        false
+    }
+  }
+
+  /**
+   * The low watermark offset value, calculated only if the local replica is the partition leader.
+   * It is only used by the leader broker to decide when a DeleteRecordsRequest is satisfied. Its value is the minimum logStartOffset of all live replicas.
+   * The low watermark will increase when the leader broker receives either a FetchRequest or a DeleteRecordsRequest.
+   */
+  def lowWatermarkIfLeader: Long = {
+    if (!isLeader)
+      throw new NotLeaderOrFollowerException(s"Leader not local for partition $topicPartition on broker $localBrokerId")
+
+    // lowWatermarkIfLeader may be called many times when a DeleteRecordsRequest is outstanding,
+    // so care has been taken to avoid generating unnecessary collections in this code
+    var lowWaterMark = localLogOrException.logStartOffset
+    remoteReplicas.foreach { replica =>
+      if (metadataCache.hasAliveBroker(replica.brokerId) && replica.logStartOffset < lowWaterMark) {
+        lowWaterMark = replica.logStartOffset
+      }
+    }
+
+    futureLog match {
+      case Some(partitionFutureLog) =>
+        Math.min(lowWaterMark, partitionFutureLog.logStartOffset)
+      case None =>
+        lowWaterMark
+    }
+  }
+
+  /**
+   * Try to complete any pending requests. This should be called without holding the leaderIsrUpdateLock.
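+   * Completing a delayed operation may itself need to read partition state under the leaderIsrUpdateLock
+   * (for example, [[fetchOffsetSnapshot]] takes the read lock), so calling it while holding the lock can lead
+   * to deadlocks or unnecessary lock contention.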
+ */ + private def tryCompleteDelayedRequests(): Unit = delayedOperations.checkAndCompleteAll() + + def maybeShrinkIsr(): Unit = { + def needsIsrUpdate: Boolean = { + !isrState.isInflight && inReadLock(leaderIsrUpdateLock) { + needsShrinkIsr() + } + } + + if (needsIsrUpdate) { + val alterIsrUpdateOpt = inWriteLock(leaderIsrUpdateLock) { + leaderLogIfLocal.flatMap { leaderLog => + val outOfSyncReplicaIds = getOutOfSyncReplicas(replicaLagTimeMaxMs) + if (!isrState.isInflight && outOfSyncReplicaIds.nonEmpty) { + val outOfSyncReplicaLog = outOfSyncReplicaIds.map { replicaId => + val logEndOffsetMessage = getReplica(replicaId) + .map(_.logEndOffset.toString) + .getOrElse("unknown") + s"(brokerId: $replicaId, endOffset: $logEndOffsetMessage)" + }.mkString(" ") + val newIsrLog = (isrState.isr -- outOfSyncReplicaIds).mkString(",") + info(s"Shrinking ISR from ${isrState.isr.mkString(",")} to $newIsrLog. " + + s"Leader: (highWatermark: ${leaderLog.highWatermark}, " + + s"endOffset: ${leaderLog.logEndOffset}). " + + s"Out of sync replicas: $outOfSyncReplicaLog.") + Some(prepareIsrShrink(outOfSyncReplicaIds)) + } else { + None + } + } + } + // Send the AlterIsr request outside of the LeaderAndIsr lock since the completion logic + // may increment the high watermark (and consequently complete delayed operations). + alterIsrUpdateOpt.foreach(submitAlterIsr) + } + } + + private def needsShrinkIsr(): Boolean = { + leaderLogIfLocal.exists { _ => getOutOfSyncReplicas(replicaLagTimeMaxMs).nonEmpty } + } + + private def isFollowerOutOfSync(replicaId: Int, + leaderEndOffset: Long, + currentTimeMs: Long, + maxLagMs: Long): Boolean = { + getReplica(replicaId).fold(true) { followerReplica => + followerReplica.logEndOffset != leaderEndOffset && + (currentTimeMs - followerReplica.lastCaughtUpTimeMs) > maxLagMs + } + } + + /** + * If the follower already has the same leo as the leader, it will not be considered as out-of-sync, + * otherwise there are two cases that will be handled here - + * 1. Stuck followers: If the leo of the replica hasn't been updated for maxLagMs ms, + * the follower is stuck and should be removed from the ISR + * 2. Slow followers: If the replica has not read up to the leo within the last maxLagMs ms, + * then the follower is lagging and should be removed from the ISR + * Both these cases are handled by checking the lastCaughtUpTimeMs which represents + * the last time when the replica was fully caught up. If either of the above conditions + * is violated, that replica is considered to be out of sync + * + * If an ISR update is in-flight, we will return an empty set here + **/ + def getOutOfSyncReplicas(maxLagMs: Long): Set[Int] = { + val current = isrState + if (!current.isInflight) { + val candidateReplicaIds = current.isr - localBrokerId + val currentTimeMs = time.milliseconds() + val leaderEndOffset = localLogOrException.logEndOffset + candidateReplicaIds.filter(replicaId => isFollowerOutOfSync(replicaId, leaderEndOffset, currentTimeMs, maxLagMs)) + } else { + Set.empty + } + } + + private def doAppendRecordsToFollowerOrFutureReplica(records: MemoryRecords, isFuture: Boolean): Option[LogAppendInfo] = { + if (isFuture) { + // The read lock is needed to handle race condition if request handler thread tries to + // remove future replica after receiving AlterReplicaLogDirsRequest. 
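+      // Holding the read lock also ensures the future log is not removed concurrently, since
+      // removeFutureLocalReplica() acquires the write lock on leaderIsrUpdateLock.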
+ inReadLock(leaderIsrUpdateLock) { + // Note the replica may be undefined if it is removed by a non-ReplicaAlterLogDirsThread before + // this method is called + futureLog.map { _.appendAsFollower(records) } + } + } else { + // The lock is needed to prevent the follower replica from being updated while ReplicaAlterDirThread + // is executing maybeReplaceCurrentWithFutureReplica() to replace follower replica with the future replica. + futureLogLock.synchronized { + Some(localLogOrException.appendAsFollower(records)) + } + } + } + + def appendRecordsToFollowerOrFutureReplica(records: MemoryRecords, isFuture: Boolean): Option[LogAppendInfo] = { + try { + doAppendRecordsToFollowerOrFutureReplica(records, isFuture) + } catch { + case e: UnexpectedAppendOffsetException => + val log = if (isFuture) futureLocalLogOrException else localLogOrException + val logEndOffset = log.logEndOffset + if (logEndOffset == log.logStartOffset && + e.firstOffset < logEndOffset && e.lastOffset >= logEndOffset) { + // This may happen if the log start offset on the leader (or current replica) falls in + // the middle of the batch due to delete records request and the follower tries to + // fetch its first offset from the leader. + // We handle this case here instead of Log#append() because we will need to remove the + // segment that start with log start offset and create a new one with earlier offset + // (base offset of the batch), which will move recoveryPoint backwards, so we will need + // to checkpoint the new recovery point before we append + val replicaName = if (isFuture) "future replica" else "follower" + info(s"Unexpected offset in append to $topicPartition. First offset ${e.firstOffset} is less than log start offset ${log.logStartOffset}." + + s" Since this is the first record to be appended to the $replicaName's log, will start the log from offset ${e.firstOffset}.") + truncateFullyAndStartAt(e.firstOffset, isFuture) + doAppendRecordsToFollowerOrFutureReplica(records, isFuture) + } else + throw e + } + } + + def appendRecordsToLeader(records: MemoryRecords, origin: AppendOrigin, requiredAcks: Int, + requestLocal: RequestLocal): LogAppendInfo = { + val (info, leaderHWIncremented) = inReadLock(leaderIsrUpdateLock) { + leaderLogIfLocal match { + case Some(leaderLog) => + val minIsr = leaderLog.config.minInSyncReplicas + val inSyncSize = isrState.isr.size + + // Avoid writing to leader if there are not enough insync replicas to make it safe + if (inSyncSize < minIsr && requiredAcks == -1) { + throw new NotEnoughReplicasException(s"The size of the current ISR ${isrState.isr} " + + s"is insufficient to satisfy the min.isr requirement of $minIsr for partition $topicPartition") + } + + val info = leaderLog.appendAsLeader(records, leaderEpoch = this.leaderEpoch, origin, + interBrokerProtocolVersion, requestLocal) + + // we may need to increment high watermark since ISR could be down to 1 + (info, maybeIncrementLeaderHW(leaderLog)) + + case None => + throw new NotLeaderOrFollowerException("Leader not local for partition %s on broker %d" + .format(topicPartition, localBrokerId)) + } + } + + info.copy(leaderHwChange = if (leaderHWIncremented) LeaderHwChange.Increased else LeaderHwChange.Same) + } + + def readRecords(lastFetchedEpoch: Optional[Integer], + fetchOffset: Long, + currentLeaderEpoch: Optional[Integer], + maxBytes: Int, + fetchIsolation: FetchIsolation, + fetchOnlyFromLeader: Boolean, + minOneMessage: Boolean): LogReadInfo = inReadLock(leaderIsrUpdateLock) { + // decide whether to only fetch from leader + 
val localLog = localLogWithEpochOrException(currentLeaderEpoch, fetchOnlyFromLeader) + + // Note we use the log end offset prior to the read. This ensures that any appends following + // the fetch do not prevent a follower from coming into sync. + val initialHighWatermark = localLog.highWatermark + val initialLogStartOffset = localLog.logStartOffset + val initialLogEndOffset = localLog.logEndOffset + val initialLastStableOffset = localLog.lastStableOffset + + lastFetchedEpoch.ifPresent { fetchEpoch => + val epochEndOffset = lastOffsetForLeaderEpoch(currentLeaderEpoch, fetchEpoch, fetchOnlyFromLeader = false) + val error = Errors.forCode(epochEndOffset.errorCode) + if (error != Errors.NONE) { + throw error.exception() + } + + if (epochEndOffset.endOffset == UNDEFINED_EPOCH_OFFSET || epochEndOffset.leaderEpoch == UNDEFINED_EPOCH) { + throw new OffsetOutOfRangeException("Could not determine the end offset of the last fetched epoch " + + s"$lastFetchedEpoch from the request") + } + + // If fetch offset is less than log start, fail with OffsetOutOfRangeException, regardless of whether epochs are diverging + if (fetchOffset < initialLogStartOffset) { + throw new OffsetOutOfRangeException(s"Received request for offset $fetchOffset for partition $topicPartition, " + + s"but we only have log segments in the range $initialLogStartOffset to $initialLogEndOffset.") + } + + if (epochEndOffset.leaderEpoch < fetchEpoch || epochEndOffset.endOffset < fetchOffset) { + val emptyFetchData = FetchDataInfo( + fetchOffsetMetadata = LogOffsetMetadata(fetchOffset), + records = MemoryRecords.EMPTY, + firstEntryIncomplete = false, + abortedTransactions = None + ) + + val divergingEpoch = new FetchResponseData.EpochEndOffset() + .setEpoch(epochEndOffset.leaderEpoch) + .setEndOffset(epochEndOffset.endOffset) + + return LogReadInfo( + fetchedData = emptyFetchData, + divergingEpoch = Some(divergingEpoch), + highWatermark = initialHighWatermark, + logStartOffset = initialLogStartOffset, + logEndOffset = initialLogEndOffset, + lastStableOffset = initialLastStableOffset) + } + } + + val fetchedData = localLog.read(fetchOffset, maxBytes, fetchIsolation, minOneMessage) + LogReadInfo( + fetchedData = fetchedData, + divergingEpoch = None, + highWatermark = initialHighWatermark, + logStartOffset = initialLogStartOffset, + logEndOffset = initialLogEndOffset, + lastStableOffset = initialLastStableOffset) + } + + def fetchOffsetForTimestamp(timestamp: Long, + isolationLevel: Option[IsolationLevel], + currentLeaderEpoch: Optional[Integer], + fetchOnlyFromLeader: Boolean): Option[TimestampAndOffset] = inReadLock(leaderIsrUpdateLock) { + // decide whether to only fetch from leader + val localLog = localLogWithEpochOrException(currentLeaderEpoch, fetchOnlyFromLeader) + + val lastFetchableOffset = isolationLevel match { + case Some(IsolationLevel.READ_COMMITTED) => localLog.lastStableOffset + case Some(IsolationLevel.READ_UNCOMMITTED) => localLog.highWatermark + case None => localLog.logEndOffset + } + + val epochLogString = if (currentLeaderEpoch.isPresent) { + s"epoch ${currentLeaderEpoch.get}" + } else { + "unknown epoch" + } + + // Only consider throwing an error if we get a client request (isolationLevel is defined) and the start offset + // is lagging behind the high watermark + val maybeOffsetsError: Option[ApiException] = leaderEpochStartOffsetOpt + .filter(epochStart => isolationLevel.isDefined && epochStart > localLog.highWatermark) + .map(epochStart => Errors.OFFSET_NOT_AVAILABLE.exception(s"Failed to fetch offsets for " + + 
s"partition $topicPartition with leader $epochLogString as this partition's " + + s"high watermark (${localLog.highWatermark}) is lagging behind the " + + s"start offset from the beginning of this epoch ($epochStart).")) + + def getOffsetByTimestamp: Option[TimestampAndOffset] = { + logManager.getLog(topicPartition).flatMap(log => log.fetchOffsetByTimestamp(timestamp)) + } + + // If we're in the lagging HW state after a leader election, throw OffsetNotAvailable for "latest" offset + // or for a timestamp lookup that is beyond the last fetchable offset. + timestamp match { + case ListOffsetsRequest.LATEST_TIMESTAMP => + maybeOffsetsError.map(e => throw e) + .orElse(Some(new TimestampAndOffset(RecordBatch.NO_TIMESTAMP, lastFetchableOffset, Optional.of(leaderEpoch)))) + case ListOffsetsRequest.EARLIEST_TIMESTAMP => + getOffsetByTimestamp + case _ => + getOffsetByTimestamp.filter(timestampAndOffset => timestampAndOffset.offset < lastFetchableOffset) + .orElse(maybeOffsetsError.map(e => throw e)) + } + } + + def activeProducerState: DescribeProducersResponseData.PartitionResponse = { + val producerState = new DescribeProducersResponseData.PartitionResponse() + .setPartitionIndex(topicPartition.partition()) + + log.map(_.activeProducers) match { + case Some(producers) => + producerState + .setErrorCode(Errors.NONE.code) + .setActiveProducers(producers.asJava) + case None => + producerState + .setErrorCode(Errors.NOT_LEADER_OR_FOLLOWER.code) + } + + producerState + } + + def fetchOffsetSnapshot(currentLeaderEpoch: Optional[Integer], + fetchOnlyFromLeader: Boolean): LogOffsetSnapshot = inReadLock(leaderIsrUpdateLock) { + // decide whether to only fetch from leader + val localLog = localLogWithEpochOrException(currentLeaderEpoch, fetchOnlyFromLeader) + localLog.fetchOffsetSnapshot + } + + def legacyFetchOffsetsForTimestamp(timestamp: Long, + maxNumOffsets: Int, + isFromConsumer: Boolean, + fetchOnlyFromLeader: Boolean): Seq[Long] = inReadLock(leaderIsrUpdateLock) { + val localLog = localLogWithEpochOrException(Optional.empty(), fetchOnlyFromLeader) + val allOffsets = localLog.legacyFetchOffsetsBefore(timestamp, maxNumOffsets) + + if (!isFromConsumer) { + allOffsets + } else { + val hw = localLog.highWatermark + if (allOffsets.exists(_ > hw)) + hw +: allOffsets.dropWhile(_ > hw) + else + allOffsets + } + } + + def logStartOffset: Long = { + inReadLock(leaderIsrUpdateLock) { + leaderLogIfLocal.map(_.logStartOffset).getOrElse(-1) + } + } + + /** + * Update logStartOffset and low watermark if 1) offset <= highWatermark and 2) it is the leader replica. + * This function can trigger log segment deletion and log rolling. + * + * Return low watermark of the partition. 
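+   * A NotLeaderOrFollowerException is thrown if this broker is not the leader for the partition, and a
+   * PolicyViolationException if the topic's cleanup policy does not permit deletes.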
+ */ + def deleteRecordsOnLeader(offset: Long): LogDeleteRecordsResult = inReadLock(leaderIsrUpdateLock) { + leaderLogIfLocal match { + case Some(leaderLog) => + if (!leaderLog.config.delete) + throw new PolicyViolationException(s"Records of partition $topicPartition can not be deleted due to the configured policy") + + val convertedOffset = if (offset == DeleteRecordsRequest.HIGH_WATERMARK) + leaderLog.highWatermark + else + offset + + if (convertedOffset < 0) + throw new OffsetOutOfRangeException(s"The offset $convertedOffset for partition $topicPartition is not valid") + + leaderLog.maybeIncrementLogStartOffset(convertedOffset, ClientRecordDeletion) + LogDeleteRecordsResult( + requestedOffset = convertedOffset, + lowWatermark = lowWatermarkIfLeader) + case None => + throw new NotLeaderOrFollowerException(s"Leader not local for partition $topicPartition on broker $localBrokerId") + } + } + + /** + * Truncate the local log of this partition to the specified offset and checkpoint the recovery point to this offset + * + * @param offset offset to be used for truncation + * @param isFuture True iff the truncation should be performed on the future log of this partition + */ + def truncateTo(offset: Long, isFuture: Boolean): Unit = { + // The read lock is needed to prevent the follower replica from being truncated while ReplicaAlterDirThread + // is executing maybeReplaceCurrentWithFutureReplica() to replace follower replica with the future replica. + inReadLock(leaderIsrUpdateLock) { + logManager.truncateTo(Map(topicPartition -> offset), isFuture = isFuture) + } + } + + /** + * Delete all data in the local log of this partition and start the log at the new offset + * + * @param newOffset The new offset to start the log with + * @param isFuture True iff the truncation should be performed on the future log of this partition + */ + def truncateFullyAndStartAt(newOffset: Long, isFuture: Boolean): Unit = { + // The read lock is needed to prevent the follower replica from being truncated while ReplicaAlterDirThread + // is executing maybeReplaceCurrentWithFutureReplica() to replace follower replica with the future replica. + inReadLock(leaderIsrUpdateLock) { + logManager.truncateFullyAndStartAt(topicPartition, newOffset, isFuture = isFuture) + } + } + + /** + * Find the (exclusive) last offset of the largest epoch less than or equal to the requested epoch. + * + * @param currentLeaderEpoch The expected epoch of the current leader (if known) + * @param leaderEpoch Requested leader epoch + * @param fetchOnlyFromLeader Whether or not to require servicing only from the leader + * + * @return The requested leader epoch and the end offset of this leader epoch, or if the requested + * leader epoch is unknown, the leader epoch less than the requested leader epoch and the end offset + * of this leader epoch. The end offset of a leader epoch is defined as the start + * offset of the first leader epoch larger than the leader epoch, or else the log end + * offset if the leader epoch is the latest leader epoch. 
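+   * If the local log cannot resolve an end offset for the requested epoch, an EpochEndOffset with error NONE
+   * but no leader epoch or end offset set is returned.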
+ */ + def lastOffsetForLeaderEpoch(currentLeaderEpoch: Optional[Integer], + leaderEpoch: Int, + fetchOnlyFromLeader: Boolean): EpochEndOffset = { + inReadLock(leaderIsrUpdateLock) { + val localLogOrError = getLocalLog(currentLeaderEpoch, fetchOnlyFromLeader) + localLogOrError match { + case Left(localLog) => + localLog.endOffsetForEpoch(leaderEpoch) match { + case Some(epochAndOffset) => new EpochEndOffset() + .setPartition(partitionId) + .setErrorCode(Errors.NONE.code) + .setLeaderEpoch(epochAndOffset.leaderEpoch) + .setEndOffset(epochAndOffset.offset) + case None => new EpochEndOffset() + .setPartition(partitionId) + .setErrorCode(Errors.NONE.code) + } + case Right(error) => new EpochEndOffset() + .setPartition(partitionId) + .setErrorCode(error.code) + } + } + } + + private def prepareIsrExpand(newInSyncReplicaId: Int): PendingExpandIsr = { + // When expanding the ISR, we assume that the new replica will make it into the ISR + // before we receive confirmation that it has. This ensures that the HW will already + // reflect the updated ISR even if there is a delay before we receive the confirmation. + // Alternatively, if the update fails, no harm is done since the expanded ISR puts + // a stricter requirement for advancement of the HW. + val isrToSend = isrState.isr + newInSyncReplicaId + val newLeaderAndIsr = new LeaderAndIsr(localBrokerId, leaderEpoch, isrToSend.toList, zkVersion) + val updatedState = PendingExpandIsr(isrState.isr, newInSyncReplicaId, newLeaderAndIsr) + isrState = updatedState + updatedState + } + + private[cluster] def prepareIsrShrink(outOfSyncReplicaIds: Set[Int]): PendingShrinkIsr = { + // When shrinking the ISR, we cannot assume that the update will succeed as this could + // erroneously advance the HW if the `AlterIsr` were to fail. Hence the "maximal ISR" + // for `PendingShrinkIsr` is the the current ISR. + val isrToSend = isrState.isr -- outOfSyncReplicaIds + val newLeaderAndIsr = new LeaderAndIsr(localBrokerId, leaderEpoch, isrToSend.toList, zkVersion) + val updatedState = PendingShrinkIsr(isrState.isr, outOfSyncReplicaIds, newLeaderAndIsr) + isrState = updatedState + updatedState + } + + private def submitAlterIsr(proposedIsrState: PendingIsrChange): CompletableFuture[LeaderAndIsr] = { + debug(s"Submitting ISR state change $proposedIsrState") + val future = alterIsrManager.submit(topicPartition, proposedIsrState.sentLeaderAndIsr, controllerEpoch) + future.whenComplete { (leaderAndIsr, e) => + var hwIncremented = false + var shouldRetry = false + + inWriteLock(leaderIsrUpdateLock) { + if (isrState != proposedIsrState) { + // This means isrState was updated through leader election or some other mechanism + // before we got the AlterIsr response. We don't know what happened on the controller + // exactly, but we do know this response is out of date so we ignore it. + debug(s"Ignoring failed ISR update to $proposedIsrState since we have already " + + s"updated state to $isrState") + } else if (leaderAndIsr != null) { + hwIncremented = handleAlterIsrUpdate(proposedIsrState, leaderAndIsr) + } else { + shouldRetry = handleAlterIsrError(proposedIsrState, Errors.forException(e)) + } + } + + if (hwIncremented) { + tryCompleteDelayedRequests() + } + + // Send the AlterIsr request outside of the LeaderAndIsr lock since the completion logic + // may increment the high watermark (and consequently complete delayed operations). + if (shouldRetry) { + submitAlterIsr(proposedIsrState) + } + } + } + + /** + * Handle a failed `AlterIsr` request. 
For errors which are non-retriable, we simply give up. + * This leaves [[Partition.isrState]] in a pending state. Since the error was non-retriable, + * we are okay staying in this state until we see new metadata from LeaderAndIsr (or an update + * to the KRaft metadata log). + * + * @param proposedIsrState The ISR state change that was requested + * @param error The error returned from [[AlterIsrManager]] + * @return true if the `AlterIsr` request should be retried, false otherwise + */ + private def handleAlterIsrError( + proposedIsrState: PendingIsrChange, + error: Errors + ): Boolean = { + isrChangeListener.markFailed() + error match { + case Errors.OPERATION_NOT_ATTEMPTED => + // Since the operation was not attempted, it is safe to reset back to the committed state. + isrState = CommittedIsr(proposedIsrState.isr) + debug(s"Failed to update ISR to $proposedIsrState since there is a pending ISR update still inflight. " + + s"ISR state has been reset to the latest committed state $isrState") + false + case Errors.UNKNOWN_TOPIC_OR_PARTITION => + debug(s"Failed to update ISR to $proposedIsrState since the controller doesn't know about " + + "this topic or partition. Giving up.") + false + case Errors.FENCED_LEADER_EPOCH => + debug(s"Failed to update ISR to $proposedIsrState since the leader epoch is old. Giving up.") + false + case Errors.INVALID_UPDATE_VERSION => + debug(s"Failed to update ISR to $proposedIsrState because the version is invalid. Giving up.") + false + case _ => + warn(s"Failed to update ISR to $proposedIsrState due to unexpected $error. Retrying.") + true + } + } + + /** + * Handle a successful `AlterIsr` response. + * + * @param proposedIsrState The ISR state change that was requested + * @param leaderAndIsr The updated LeaderAndIsr state + * @return true if the high watermark was successfully incremented following, false otherwise + */ + private def handleAlterIsrUpdate( + proposedIsrState: PendingIsrChange, + leaderAndIsr: LeaderAndIsr + ): Boolean = { + // Success from controller, still need to check a few things + if (leaderAndIsr.leaderEpoch != leaderEpoch) { + debug(s"Ignoring new ISR $leaderAndIsr since we have a stale leader epoch $leaderEpoch.") + isrChangeListener.markFailed() + false + } else if (leaderAndIsr.zkVersion < zkVersion) { + debug(s"Ignoring new ISR $leaderAndIsr since we have a newer version $zkVersion.") + isrChangeListener.markFailed() + false + } else { + // This is one of two states: + // 1) leaderAndIsr.zkVersion > zkVersion: Controller updated to new version with proposedIsrState. + // 2) leaderAndIsr.zkVersion == zkVersion: No update was performed since proposed and actual state are the same. + // In both cases, we want to move from Pending to Committed state to ensure new updates are processed. 
+ + isrState = CommittedIsr(leaderAndIsr.isr.toSet) + zkVersion = leaderAndIsr.zkVersion + info(s"ISR updated to ${isrState.isr.mkString(",")} and version updated to $zkVersion") + + proposedIsrState match { + case PendingExpandIsr(_, _, _) => isrChangeListener.markExpand() + case PendingShrinkIsr(_, _, _) => isrChangeListener.markShrink() + } + + // we may need to increment high watermark since ISR could be down to 1 + leaderLogIfLocal.exists(log => maybeIncrementLeaderHW(log)) + } + } + + override def equals(that: Any): Boolean = that match { + case other: Partition => partitionId == other.partitionId && topic == other.topic + case _ => false + } + + override def hashCode: Int = + 31 + topic.hashCode + 17 * partitionId + + override def toString: String = { + val partitionString = new StringBuilder + partitionString.append("Topic: " + topic) + partitionString.append("; Partition: " + partitionId) + partitionString.append("; Leader: " + leaderReplicaIdOpt) + partitionString.append("; Replicas: " + assignmentState.replicas.mkString(",")) + partitionString.append("; ISR: " + isrState.isr.mkString(",")) + assignmentState match { + case OngoingReassignmentState(adding, removing, _) => + partitionString.append("; AddingReplicas: " + adding.mkString(",")) + partitionString.append("; RemovingReplicas: " + removing.mkString(",")) + case _ => + } + partitionString.toString + } +} diff --git a/core/src/main/scala/kafka/cluster/Replica.scala b/core/src/main/scala/kafka/cluster/Replica.scala new file mode 100644 index 0000000..921faef --- /dev/null +++ b/core/src/main/scala/kafka/cluster/Replica.scala @@ -0,0 +1,108 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+*/ + +package kafka.cluster + +import kafka.log.UnifiedLog +import kafka.server.LogOffsetMetadata +import kafka.utils.Logging +import org.apache.kafka.common.TopicPartition + +class Replica(val brokerId: Int, val topicPartition: TopicPartition) extends Logging { + // the log end offset value, kept in all replicas; + // for local replica it is the log's end offset, for remote replicas its value is only updated by follower fetch + @volatile private[this] var _logEndOffsetMetadata = LogOffsetMetadata.UnknownOffsetMetadata + // the log start offset value, kept in all replicas; + // for local replica it is the log's start offset, for remote replicas its value is only updated by follower fetch + @volatile private[this] var _logStartOffset = UnifiedLog.UnknownOffset + + // The log end offset value at the time the leader received the last FetchRequest from this follower + // This is used to determine the lastCaughtUpTimeMs of the follower + @volatile private[this] var lastFetchLeaderLogEndOffset = 0L + + // The time when the leader received the last FetchRequest from this follower + // This is used to determine the lastCaughtUpTimeMs of the follower + @volatile private[this] var lastFetchTimeMs = 0L + + // lastCaughtUpTimeMs is the largest time t such that the offset of most recent FetchRequest from this follower >= + // the LEO of leader at time t. This is used to determine the lag of this follower and ISR of this partition. + @volatile private[this] var _lastCaughtUpTimeMs = 0L + + def logStartOffset: Long = _logStartOffset + + def logEndOffsetMetadata: LogOffsetMetadata = _logEndOffsetMetadata + + def logEndOffset: Long = logEndOffsetMetadata.messageOffset + + def lastCaughtUpTimeMs: Long = _lastCaughtUpTimeMs + + /* + * If the FetchRequest reads up to the log end offset of the leader when the current fetch request is received, + * set `lastCaughtUpTimeMs` to the time when the current fetch request was received. + * + * Else if the FetchRequest reads up to the log end offset of the leader when the previous fetch request was received, + * set `lastCaughtUpTimeMs` to the time when the previous fetch request was received. + * + * This is needed to enforce the semantics of ISR, i.e. a replica is in ISR if and only if it lags behind leader's LEO + * by at most `replicaLagTimeMaxMs`. These semantics allow a follower to be added to the ISR even if the offset of its + * fetch request is always smaller than the leader's LEO, which can happen if small produce requests are received at + * high frequency. 
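+   * For example, if the previous fetch was received at time t1 when the leader's LEO was 100, and the current fetch
+   * requests offset 100 after the leader's LEO has advanced to 150, then lastCaughtUpTimeMs advances to t1 rather
+   * than to the current time.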
+ */ + def updateFetchState(followerFetchOffsetMetadata: LogOffsetMetadata, + followerStartOffset: Long, + followerFetchTimeMs: Long, + leaderEndOffset: Long): Unit = { + if (followerFetchOffsetMetadata.messageOffset >= leaderEndOffset) + _lastCaughtUpTimeMs = math.max(_lastCaughtUpTimeMs, followerFetchTimeMs) + else if (followerFetchOffsetMetadata.messageOffset >= lastFetchLeaderLogEndOffset) + _lastCaughtUpTimeMs = math.max(_lastCaughtUpTimeMs, lastFetchTimeMs) + + _logStartOffset = followerStartOffset + _logEndOffsetMetadata = followerFetchOffsetMetadata + lastFetchLeaderLogEndOffset = leaderEndOffset + lastFetchTimeMs = followerFetchTimeMs + } + + def resetLastCaughtUpTime(curLeaderLogEndOffset: Long, curTimeMs: Long, lastCaughtUpTimeMs: Long): Unit = { + lastFetchLeaderLogEndOffset = curLeaderLogEndOffset + lastFetchTimeMs = curTimeMs + _lastCaughtUpTimeMs = lastCaughtUpTimeMs + trace(s"Reset state of replica to $this") + } + + override def toString: String = { + val replicaString = new StringBuilder + replicaString.append("Replica(replicaId=" + brokerId) + replicaString.append(s", topic=${topicPartition.topic}") + replicaString.append(s", partition=${topicPartition.partition}") + replicaString.append(s", lastCaughtUpTimeMs=$lastCaughtUpTimeMs") + replicaString.append(s", logStartOffset=$logStartOffset") + replicaString.append(s", logEndOffset=$logEndOffset") + replicaString.append(s", logEndOffsetMetadata=$logEndOffsetMetadata") + replicaString.append(s", lastFetchLeaderLogEndOffset=$lastFetchLeaderLogEndOffset") + replicaString.append(s", lastFetchTimeMs=$lastFetchTimeMs") + replicaString.append(")") + replicaString.toString + } + + override def equals(that: Any): Boolean = that match { + case other: Replica => brokerId == other.brokerId && topicPartition == other.topicPartition + case _ => false + } + + override def hashCode: Int = 31 + topicPartition.hashCode + 17 * brokerId +} diff --git a/core/src/main/scala/kafka/common/AdminCommandFailedException.scala b/core/src/main/scala/kafka/common/AdminCommandFailedException.scala new file mode 100644 index 0000000..94e2864 --- /dev/null +++ b/core/src/main/scala/kafka/common/AdminCommandFailedException.scala @@ -0,0 +1,23 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package kafka.common + +class AdminCommandFailedException(message: String, cause: Throwable) extends RuntimeException(message, cause) { + def this(message: String) = this(message, null) + def this() = this(null, null) +} diff --git a/core/src/main/scala/kafka/common/BaseEnum.scala b/core/src/main/scala/kafka/common/BaseEnum.scala new file mode 100644 index 0000000..9c39466 --- /dev/null +++ b/core/src/main/scala/kafka/common/BaseEnum.scala @@ -0,0 +1,26 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package kafka.common + +/* + * We inherit from `Product` and `Serializable` because `case` objects and classes inherit from them and if we don't + * do it here, the compiler will infer types that unexpectedly include `Product` and `Serializable`, see + * http://underscore.io/blog/posts/2015/06/04/more-on-sealed.html for more information. + */ +trait BaseEnum extends Product with Serializable { + def name: String +} diff --git a/core/src/main/scala/kafka/common/BrokerEndPointNotAvailableException.scala b/core/src/main/scala/kafka/common/BrokerEndPointNotAvailableException.scala new file mode 100644 index 0000000..455d8c6 --- /dev/null +++ b/core/src/main/scala/kafka/common/BrokerEndPointNotAvailableException.scala @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.common + +class BrokerEndPointNotAvailableException(message: String) extends RuntimeException(message) { + def this() = this(null) +} diff --git a/core/src/main/scala/kafka/common/ClientIdAndBroker.scala b/core/src/main/scala/kafka/common/ClientIdAndBroker.scala new file mode 100644 index 0000000..3b09041 --- /dev/null +++ b/core/src/main/scala/kafka/common/ClientIdAndBroker.scala @@ -0,0 +1,34 @@ +package kafka.common + +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Convenience case class since (clientId, brokerInfo) pairs are used to create + * SyncProducer Request Stats and SimpleConsumer Request and Response Stats. + */ + +trait ClientIdBroker { +} + +case class ClientIdAndBroker(clientId: String, brokerHost: String, brokerPort: Int) extends ClientIdBroker { + override def toString = "%s-%s-%d".format(clientId, brokerHost, brokerPort) +} + +case class ClientIdAllBrokers(clientId: String) extends ClientIdBroker { + override def toString = "%s-%s".format(clientId, "AllBrokers") +} diff --git a/core/src/main/scala/kafka/common/Config.scala b/core/src/main/scala/kafka/common/Config.scala new file mode 100644 index 0000000..f56cca8 --- /dev/null +++ b/core/src/main/scala/kafka/common/Config.scala @@ -0,0 +1,41 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.common + +import util.matching.Regex +import kafka.utils.Logging +import org.apache.kafka.common.errors.InvalidConfigurationException + +trait Config extends Logging { + + def validateChars(prop: String, value: String): Unit = { + val legalChars = "[a-zA-Z0-9\\._\\-]" + val rgx = new Regex(legalChars + "*") + + rgx.findFirstIn(value) match { + case Some(t) => + if (!t.equals(value)) + throw new InvalidConfigurationException(prop + " " + value + " is illegal, contains a character other than ASCII alphanumerics, '.', '_' and '-'") + case None => throw new InvalidConfigurationException(prop + " " + value + " is illegal, contains a character other than ASCII alphanumerics, '.', '_' and '-'") + } + } +} + + + + diff --git a/core/src/main/scala/kafka/common/GenerateBrokerIdException.scala b/core/src/main/scala/kafka/common/GenerateBrokerIdException.scala new file mode 100644 index 0000000..13784fe --- /dev/null +++ b/core/src/main/scala/kafka/common/GenerateBrokerIdException.scala @@ -0,0 +1,27 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.common + +/** + * Thrown when there is a failure to generate a zookeeper sequenceId to use as brokerId + */ +class GenerateBrokerIdException(message: String, cause: Throwable) extends RuntimeException(message, cause) { + def this(message: String) = this(message, null) + def this(cause: Throwable) = this(null, cause) + def this() = this(null, null) +} diff --git a/core/src/main/scala/kafka/common/InconsistentBrokerIdException.scala b/core/src/main/scala/kafka/common/InconsistentBrokerIdException.scala new file mode 100644 index 0000000..0c0d1cd --- /dev/null +++ b/core/src/main/scala/kafka/common/InconsistentBrokerIdException.scala @@ -0,0 +1,27 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.common + +/** + * Indicates the brokerId stored in logDirs is not consistent across logDirs. + */ +class InconsistentBrokerIdException(message: String, cause: Throwable) extends RuntimeException(message, cause) { + def this(message: String) = this(message, null) + def this(cause: Throwable) = this(null, cause) + def this() = this(null, null) +} diff --git a/core/src/main/scala/kafka/common/InconsistentBrokerMetadataException.scala b/core/src/main/scala/kafka/common/InconsistentBrokerMetadataException.scala new file mode 100644 index 0000000..2b11512 --- /dev/null +++ b/core/src/main/scala/kafka/common/InconsistentBrokerMetadataException.scala @@ -0,0 +1,27 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package kafka.common
+
+/**
+ * Indicates the BrokerMetadata stored in logDirs is not consistent across logDirs.
+ */
+class InconsistentBrokerMetadataException(message: String, cause: Throwable) extends RuntimeException(message, cause) {
+  def this(message: String) = this(message, null)
+  def this(cause: Throwable) = this(null, cause)
+  def this() = this(null, null)
+}
diff --git a/core/src/main/scala/kafka/common/InconsistentClusterIdException.scala b/core/src/main/scala/kafka/common/InconsistentClusterIdException.scala
new file mode 100644
index 0000000..6868dd8
--- /dev/null
+++ b/core/src/main/scala/kafka/common/InconsistentClusterIdException.scala
@@ -0,0 +1,27 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package kafka.common
+
+/**
+ * Indicates the clusterId stored in logDirs is not consistent with the clusterId stored in ZK.
+ */
+class InconsistentClusterIdException(message: String, cause: Throwable) extends RuntimeException(message, cause) {
+  def this(message: String) = this(message, null)
+  def this(cause: Throwable) = this(null, cause)
+  def this() = this(null, null)
+}
diff --git a/core/src/main/scala/kafka/common/InconsistentNodeIdException.scala b/core/src/main/scala/kafka/common/InconsistentNodeIdException.scala
new file mode 100644
index 0000000..2fd8f15
--- /dev/null
+++ b/core/src/main/scala/kafka/common/InconsistentNodeIdException.scala
@@ -0,0 +1,22 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +package kafka.common + +class InconsistentNodeIdException(message: String, cause: Throwable) extends RuntimeException(message, cause) { + def this(message: String) = this(message, null) +} diff --git a/core/src/main/scala/kafka/common/IndexOffsetOverflowException.scala b/core/src/main/scala/kafka/common/IndexOffsetOverflowException.scala new file mode 100644 index 0000000..5dd9b43 --- /dev/null +++ b/core/src/main/scala/kafka/common/IndexOffsetOverflowException.scala @@ -0,0 +1,25 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.common + +/** + * Indicates that an attempt was made to append a message whose offset could cause the index offset to overflow. + */ +class IndexOffsetOverflowException(message: String, cause: Throwable) extends org.apache.kafka.common.KafkaException(message, cause) { + def this(message: String) = this(message, null) +} diff --git a/core/src/main/scala/kafka/common/InterBrokerSendThread.scala b/core/src/main/scala/kafka/common/InterBrokerSendThread.scala new file mode 100644 index 0000000..c2724e2 --- /dev/null +++ b/core/src/main/scala/kafka/common/InterBrokerSendThread.scala @@ -0,0 +1,214 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package kafka.common + +import java.util.Map.Entry +import java.util.{ArrayDeque, ArrayList, Collection, Collections, HashMap, Iterator} +import kafka.utils.ShutdownableThread +import org.apache.kafka.clients.{ClientRequest, ClientResponse, KafkaClient, RequestCompletionHandler} +import org.apache.kafka.common.Node +import org.apache.kafka.common.errors.{AuthenticationException, DisconnectException} +import org.apache.kafka.common.internals.FatalExitError +import org.apache.kafka.common.requests.AbstractRequest +import org.apache.kafka.common.utils.Time + +import scala.jdk.CollectionConverters._ + +/** + * Class for inter-broker send thread that utilize a non-blocking network client. 
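+ *
+ * For illustration, a minimal sketch of a concrete subclass; the class name, the fixed destination
+ * node and the request builder passed in below are assumptions made for the example, not part of
+ * this class:
+ * {{{
+ * class SingleNodeSendThread(
+ *   client: KafkaClient,
+ *   node: Node,
+ *   builder: AbstractRequest.Builder[_ <: AbstractRequest],
+ *   time: Time
+ * ) extends InterBrokerSendThread("single-node-send-thread", client, 30000 /* requestTimeoutMs */, time) {
+ *
+ *   override def generateRequests(): Iterable[RequestAndCompletionHandler] = {
+ *     // queue one request per poll and log the response once it arrives
+ *     val handler: RequestCompletionHandler = (response: ClientResponse) => info(s"Received $response from node ${node.id}")
+ *     Seq(RequestAndCompletionHandler(time.milliseconds(), node, builder, handler))
+ *   }
+ * }
+ * }}}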
+ */ +abstract class InterBrokerSendThread( + name: String, + networkClient: KafkaClient, + requestTimeoutMs: Int, + time: Time, + isInterruptible: Boolean = true +) extends ShutdownableThread(name, isInterruptible) { + + private val unsentRequests = new UnsentRequests + + def generateRequests(): Iterable[RequestAndCompletionHandler] + + def hasUnsentRequests: Boolean = unsentRequests.iterator().hasNext + + override def shutdown(): Unit = { + initiateShutdown() + networkClient.initiateClose() + awaitShutdown() + networkClient.close() + } + + private def drainGeneratedRequests(): Unit = { + generateRequests().foreach { request => + unsentRequests.put(request.destination, + networkClient.newClientRequest( + request.destination.idString, + request.request, + request.creationTimeMs, + true, + requestTimeoutMs, + request.handler + )) + } + } + + protected def pollOnce(maxTimeoutMs: Long): Unit = { + try { + drainGeneratedRequests() + var now = time.milliseconds() + val timeout = sendRequests(now, maxTimeoutMs) + networkClient.poll(timeout, now) + now = time.milliseconds() + checkDisconnects(now) + failExpiredRequests(now) + unsentRequests.clean() + } catch { + case _: DisconnectException if !networkClient.active() => + // DisconnectException is expected when NetworkClient#initiateClose is called + case e: FatalExitError => throw e + case t: Throwable => + error(s"unhandled exception caught in InterBrokerSendThread", t) + // rethrow any unhandled exceptions as FatalExitError so the JVM will be terminated + // as we will be in an unknown state with potentially some requests dropped and not + // being able to make progress. Known and expected Errors should have been appropriately + // dealt with already. + throw new FatalExitError() + } + } + + override def doWork(): Unit = { + pollOnce(Long.MaxValue) + } + + private def sendRequests(now: Long, maxTimeoutMs: Long): Long = { + var pollTimeout = maxTimeoutMs + for (node <- unsentRequests.nodes.asScala) { + val requestIterator = unsentRequests.requestIterator(node) + while (requestIterator.hasNext) { + val request = requestIterator.next + if (networkClient.ready(node, now)) { + networkClient.send(request, now) + requestIterator.remove() + } else + pollTimeout = Math.min(pollTimeout, networkClient.connectionDelay(node, now)) + } + } + pollTimeout + } + + private def checkDisconnects(now: Long): Unit = { + // any disconnects affecting requests that have already been transmitted will be handled + // by NetworkClient, so we just need to check whether connections for any of the unsent + // requests have been disconnected; if they have, then we complete the corresponding future + // and set the disconnect flag in the ClientResponse + val iterator = unsentRequests.iterator() + while (iterator.hasNext) { + val entry = iterator.next + val (node, requests) = (entry.getKey, entry.getValue) + if (!requests.isEmpty && networkClient.connectionFailed(node)) { + iterator.remove() + for (request <- requests.asScala) { + val authenticationException = networkClient.authenticationException(node) + if (authenticationException != null) + error(s"Failed to send the following request due to authentication error: $request") + completeWithDisconnect(request, now, authenticationException) + } + } + } + } + + private def failExpiredRequests(now: Long): Unit = { + // clear all expired unsent requests + val timedOutRequests = unsentRequests.removeAllTimedOut(now) + for (request <- timedOutRequests.asScala) { + debug(s"Failed to send the following request after 
${request.requestTimeoutMs} ms: $request") + completeWithDisconnect(request, now, null) + } + } + + def completeWithDisconnect(request: ClientRequest, + now: Long, + authenticationException: AuthenticationException): Unit = { + val handler = request.callback + handler.onComplete(new ClientResponse(request.makeHeader(request.requestBuilder().latestAllowedVersion()), + handler, request.destination, now /* createdTimeMs */ , now /* receivedTimeMs */ , true /* disconnected */ , + null /* versionMismatch */ , authenticationException, null)) + } + + def wakeup(): Unit = networkClient.wakeup() +} + +case class RequestAndCompletionHandler( + creationTimeMs: Long, + destination: Node, + request: AbstractRequest.Builder[_ <: AbstractRequest], + handler: RequestCompletionHandler +) + +private class UnsentRequests { + private val unsent = new HashMap[Node, ArrayDeque[ClientRequest]] + + def put(node: Node, request: ClientRequest): Unit = { + var requests = unsent.get(node) + if (requests == null) { + requests = new ArrayDeque[ClientRequest] + unsent.put(node, requests) + } + requests.add(request) + } + + def removeAllTimedOut(now: Long): Collection[ClientRequest] = { + val expiredRequests = new ArrayList[ClientRequest] + for (requests <- unsent.values.asScala) { + val requestIterator = requests.iterator + var foundExpiredRequest = false + while (requestIterator.hasNext && !foundExpiredRequest) { + val request = requestIterator.next + val elapsedMs = Math.max(0, now - request.createdTimeMs) + if (elapsedMs > request.requestTimeoutMs) { + expiredRequests.add(request) + requestIterator.remove() + foundExpiredRequest = true + } + } + } + expiredRequests + } + + def clean(): Unit = { + val iterator = unsent.values.iterator + while (iterator.hasNext) { + val requests = iterator.next + if (requests.isEmpty) + iterator.remove() + } + } + + def iterator(): Iterator[Entry[Node, ArrayDeque[ClientRequest]]] = { + unsent.entrySet().iterator() + } + + def requestIterator(node: Node): Iterator[ClientRequest] = { + val requests = unsent.get(node) + if (requests == null) + Collections.emptyIterator[ClientRequest] + else + requests.iterator + } + + def nodes: java.util.Set[Node] = unsent.keySet +} diff --git a/core/src/main/scala/kafka/common/KafkaException.scala b/core/src/main/scala/kafka/common/KafkaException.scala new file mode 100644 index 0000000..9c34dd9 --- /dev/null +++ b/core/src/main/scala/kafka/common/KafkaException.scala @@ -0,0 +1,27 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. +*/ +package kafka.common + +/** + * Usage of this class is discouraged. Use org.apache.kafka.common.KafkaException instead. + * + * This class will be removed once kafka.security.auth classes are removed. 
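+ *
+ * For new code, prefer throwing the common exception directly, e.g.:
+ * {{{
+ * throw new org.apache.kafka.common.KafkaException("something went wrong")
+ * }}}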
+ */ +class KafkaException(message: String, t: Throwable) extends RuntimeException(message, t) { + def this(message: String) = this(message, null) + def this(t: Throwable) = this("", t) +} diff --git a/core/src/main/scala/kafka/common/LogCleaningAbortedException.scala b/core/src/main/scala/kafka/common/LogCleaningAbortedException.scala new file mode 100644 index 0000000..5ea6632 --- /dev/null +++ b/core/src/main/scala/kafka/common/LogCleaningAbortedException.scala @@ -0,0 +1,24 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.common + +/** + * Thrown when a log cleaning task is requested to be aborted. + */ +class LogCleaningAbortedException() extends RuntimeException() { +} diff --git a/core/src/main/scala/kafka/common/LogSegmentOffsetOverflowException.scala b/core/src/main/scala/kafka/common/LogSegmentOffsetOverflowException.scala new file mode 100644 index 0000000..2de5906 --- /dev/null +++ b/core/src/main/scala/kafka/common/LogSegmentOffsetOverflowException.scala @@ -0,0 +1,30 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.common + +import kafka.log.LogSegment + +/** + * Indicates that the log segment contains one or more messages that overflow the offset (and / or time) index. This is + * not a typical scenario, and could only happen when brokers have log segments that were created before the patch for + * KAFKA-5413. With KAFKA-6264, we have the ability to split such log segments into multiple log segments such that we + * do not have any segments with offset overflow. 
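+ *
+ * The offset index stores each entry's offset relative to the segment's base offset in four bytes,
+ * so, roughly, a record can only live in a segment while its relative offset still fits in an Int.
+ * A simplified version of that check (not the actual LogSegment code):
+ * {{{
+ * def fitsInSegment(baseOffset: Long, offset: Long): Boolean = {
+ *   val relativeOffset = offset - baseOffset
+ *   relativeOffset >= 0 && relativeOffset <= Int.MaxValue
+ * }
+ * }}}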
+ */ +class LogSegmentOffsetOverflowException(val segment: LogSegment, val offset: Long) + extends org.apache.kafka.common.KafkaException(s"Detected offset overflow at offset $offset in segment $segment") { +} diff --git a/core/src/main/scala/kafka/common/LongRef.scala b/core/src/main/scala/kafka/common/LongRef.scala new file mode 100644 index 0000000..f2b1e32 --- /dev/null +++ b/core/src/main/scala/kafka/common/LongRef.scala @@ -0,0 +1,61 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.common + +/** + * A mutable cell that holds a value of type `Long`. One should generally prefer using value-based programming (i.e. + * passing and returning `Long` values), but this class can be useful in some scenarios. + * + * Unlike `AtomicLong`, this class is not thread-safe and there are no atomicity guarantees. + */ +class LongRef(var value: Long) { + + def addAndGet(delta: Long): Long = { + value += delta + value + } + + def getAndAdd(delta: Long): Long = { + val result = value + value += delta + result + } + + def getAndIncrement(): Long = { + val v = value + value += 1 + v + } + + def incrementAndGet(): Long = { + value += 1 + value + } + + def getAndDecrement(): Long = { + val v = value + value -= 1 + v + } + + def decrementAndGet(): Long = { + value -= 1 + value + } + +} diff --git a/core/src/main/scala/kafka/common/MessageReader.scala b/core/src/main/scala/kafka/common/MessageReader.scala new file mode 100644 index 0000000..de456e1 --- /dev/null +++ b/core/src/main/scala/kafka/common/MessageReader.scala @@ -0,0 +1,39 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.common + +import java.io.InputStream +import java.util.Properties + +import org.apache.kafka.clients.producer.ProducerRecord + +/** + * Typical implementations of this interface convert data from an `InputStream` received via `init` into a + * `ProducerRecord` instance on each invocation of `readMessage`. + * + * This is used by the `ConsoleProducer`. 
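+ *
+ * For illustration, a minimal line-oriented implementation; the class name and the "topic" property
+ * key below are assumptions made for the example:
+ * {{{
+ * import java.io.{BufferedReader, InputStream, InputStreamReader}
+ * import java.nio.charset.StandardCharsets.UTF_8
+ * import java.util.Properties
+ *
+ * class PlainLineReader extends MessageReader {
+ *   private var reader: BufferedReader = _
+ *   private var topic: String = _
+ *
+ *   override def init(inputStream: InputStream, props: Properties): Unit = {
+ *     reader = new BufferedReader(new InputStreamReader(inputStream, UTF_8))
+ *     topic = props.getProperty("topic")
+ *   }
+ *
+ *   // one record per line; null signals end of input
+ *   override def readMessage(): ProducerRecord[Array[Byte], Array[Byte]] = {
+ *     val line = reader.readLine()
+ *     if (line == null) null
+ *     else new ProducerRecord[Array[Byte], Array[Byte]](topic, null, line.getBytes(UTF_8))
+ *   }
+ *
+ *   override def close(): Unit = if (reader != null) reader.close()
+ * }
+ * }}}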
+ */ +trait MessageReader { + + def init(inputStream: InputStream, props: Properties): Unit = {} + + def readMessage(): ProducerRecord[Array[Byte], Array[Byte]] + + def close(): Unit = {} + +} diff --git a/core/src/main/scala/kafka/common/OffsetAndMetadata.scala b/core/src/main/scala/kafka/common/OffsetAndMetadata.scala new file mode 100644 index 0000000..632c863 --- /dev/null +++ b/core/src/main/scala/kafka/common/OffsetAndMetadata.scala @@ -0,0 +1,48 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.common + +import java.util.Optional + +case class OffsetAndMetadata(offset: Long, + leaderEpoch: Optional[Integer], + metadata: String, + commitTimestamp: Long, + expireTimestamp: Option[Long]) { + + + override def toString: String = { + s"OffsetAndMetadata(offset=$offset" + + s", leaderEpoch=$leaderEpoch" + + s", metadata=$metadata" + + s", commitTimestamp=$commitTimestamp" + + s", expireTimestamp=$expireTimestamp)" + } +} + +object OffsetAndMetadata { + val NoMetadata: String = "" + + def apply(offset: Long, metadata: String, commitTimestamp: Long): OffsetAndMetadata = { + OffsetAndMetadata(offset, Optional.empty(), metadata, commitTimestamp, None) + } + + def apply(offset: Long, metadata: String, commitTimestamp: Long, expireTimestamp: Long): OffsetAndMetadata = { + OffsetAndMetadata(offset, Optional.empty(), metadata, commitTimestamp, Some(expireTimestamp)) + } +} diff --git a/core/src/main/scala/kafka/common/OffsetsOutOfOrderException.scala b/core/src/main/scala/kafka/common/OffsetsOutOfOrderException.scala new file mode 100644 index 0000000..f8daaa4 --- /dev/null +++ b/core/src/main/scala/kafka/common/OffsetsOutOfOrderException.scala @@ -0,0 +1,25 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package kafka.common + +/** + * Indicates the follower received records with non-monotonically increasing offsets + */ +class OffsetsOutOfOrderException(message: String) extends RuntimeException(message) { +} + diff --git a/core/src/main/scala/kafka/common/RecordValidationException.scala b/core/src/main/scala/kafka/common/RecordValidationException.scala new file mode 100644 index 0000000..baa7d72 --- /dev/null +++ b/core/src/main/scala/kafka/common/RecordValidationException.scala @@ -0,0 +1,28 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.common + +import org.apache.kafka.common.errors.ApiException +import org.apache.kafka.common.requests.ProduceResponse.RecordError + +import scala.collection.Seq + +class RecordValidationException(val invalidException: ApiException, + val recordErrors: Seq[RecordError]) + extends RuntimeException(invalidException) { +} diff --git a/core/src/main/scala/kafka/common/StateChangeFailedException.scala b/core/src/main/scala/kafka/common/StateChangeFailedException.scala new file mode 100644 index 0000000..fd56796 --- /dev/null +++ b/core/src/main/scala/kafka/common/StateChangeFailedException.scala @@ -0,0 +1,23 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.common + +class StateChangeFailedException(message: String, cause: Throwable) extends RuntimeException(message, cause) { + def this(message: String) = this(message, null) + def this() = this(null, null) +} \ No newline at end of file diff --git a/core/src/main/scala/kafka/common/ThreadShutdownException.scala b/core/src/main/scala/kafka/common/ThreadShutdownException.scala new file mode 100644 index 0000000..6554a5e --- /dev/null +++ b/core/src/main/scala/kafka/common/ThreadShutdownException.scala @@ -0,0 +1,24 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.common + +/** + * An exception that indicates a thread is being shut down normally. + */ +class ThreadShutdownException() extends RuntimeException { +} diff --git a/core/src/main/scala/kafka/common/TopicAlreadyMarkedForDeletionException.scala b/core/src/main/scala/kafka/common/TopicAlreadyMarkedForDeletionException.scala new file mode 100644 index 0000000..c83cea9 --- /dev/null +++ b/core/src/main/scala/kafka/common/TopicAlreadyMarkedForDeletionException.scala @@ -0,0 +1,21 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.common + +class TopicAlreadyMarkedForDeletionException(message: String) extends RuntimeException(message) { +} \ No newline at end of file diff --git a/core/src/main/scala/kafka/common/UnexpectedAppendOffsetException.scala b/core/src/main/scala/kafka/common/UnexpectedAppendOffsetException.scala new file mode 100644 index 0000000..e719a93 --- /dev/null +++ b/core/src/main/scala/kafka/common/UnexpectedAppendOffsetException.scala @@ -0,0 +1,29 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.common + +/** + * Indicates the follower or the future replica received records from the leader (or current + * replica) with first offset less than expected next offset. 
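+ *
+ * A simplified sketch of the kind of guard that raises it; the variable names here are illustrative,
+ * not the actual Log code:
+ * {{{
+ * if (firstOffset < expectedNextOffset)
+ *   throw new UnexpectedAppendOffsetException(
+ *     s"Unexpected offset in append: first offset $firstOffset is less than the expected next offset $expectedNextOffset",
+ *     firstOffset, lastOffset)
+ * }}}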
+ * @param firstOffset The first offset of the records to append + * @param lastOffset The last offset of the records to append + */ +class UnexpectedAppendOffsetException(val message: String, + val firstOffset: Long, + val lastOffset: Long) extends RuntimeException(message) { +} diff --git a/core/src/main/scala/kafka/common/UnknownCodecException.scala b/core/src/main/scala/kafka/common/UnknownCodecException.scala new file mode 100644 index 0000000..7e66901 --- /dev/null +++ b/core/src/main/scala/kafka/common/UnknownCodecException.scala @@ -0,0 +1,26 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.common + +/** + * Indicates that an unknown compression codec was specified + */ +class UnknownCodecException(message: String) extends RuntimeException(message) { + def this() = this(null) +} + diff --git a/core/src/main/scala/kafka/common/ZkNodeChangeNotificationListener.scala b/core/src/main/scala/kafka/common/ZkNodeChangeNotificationListener.scala new file mode 100644 index 0000000..c49ec08 --- /dev/null +++ b/core/src/main/scala/kafka/common/ZkNodeChangeNotificationListener.scala @@ -0,0 +1,159 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package kafka.common + +import java.nio.charset.StandardCharsets.UTF_8 +import java.util.concurrent.LinkedBlockingQueue +import java.util.concurrent.atomic.AtomicBoolean + +import kafka.utils.{Logging, ShutdownableThread} +import kafka.zk.{KafkaZkClient, StateChangeHandlers} +import kafka.zookeeper.{StateChangeHandler, ZNodeChildChangeHandler} +import org.apache.kafka.common.utils.Time + +import scala.collection.Seq +import scala.util.{Failure, Try} + +/** + * Handle the notificationMessage. + */ +trait NotificationHandler { + def processNotification(notificationMessage: Array[Byte]): Unit +} + +/** + * A listener that subscribes to seqNodeRoot for any child changes where all children are assumed to be sequence nodes + * with the seqNodePrefix.
When a child is added under seqNodeRoot this class gets notified, it looks at lastExecutedChange + * number to avoid duplicate processing and if it finds an unprocessed child, it reads its data and calls supplied + * notificationHandler's processNotification() method with the child's data as argument. As part of processing these changes it also + * purges any children with currentTime - createTime > changeExpirationMs. + * + * @param zkClient + * @param seqNodeRoot + * @param seqNodePrefix + * @param notificationHandler + * @param changeExpirationMs + * @param time + */ +class ZkNodeChangeNotificationListener(private val zkClient: KafkaZkClient, + private val seqNodeRoot: String, + private val seqNodePrefix: String, + private val notificationHandler: NotificationHandler, + private val changeExpirationMs: Long = 15 * 60 * 1000, + private val time: Time = Time.SYSTEM) extends Logging { + private var lastExecutedChange = -1L + private val queue = new LinkedBlockingQueue[ChangeNotification] + private val thread = new ChangeEventProcessThread(s"$seqNodeRoot-event-process-thread") + private val isClosed = new AtomicBoolean(false) + + def init(): Unit = { + zkClient.registerStateChangeHandler(ZkStateChangeHandler) + zkClient.registerZNodeChildChangeHandler(ChangeNotificationHandler) + addChangeNotification() + thread.start() + } + + def close() = { + isClosed.set(true) + zkClient.unregisterStateChangeHandler(ZkStateChangeHandler.name) + zkClient.unregisterZNodeChildChangeHandler(ChangeNotificationHandler.path) + queue.clear() + thread.shutdown() + } + + /** + * Process notifications + */ + private def processNotifications(): Unit = { + try { + val notifications = zkClient.getChildren(seqNodeRoot).sorted + if (notifications.nonEmpty) { + info(s"Processing notification(s) to $seqNodeRoot") + val now = time.milliseconds + for (notification <- notifications) { + val changeId = changeNumber(notification) + if (changeId > lastExecutedChange) { + processNotification(notification) + lastExecutedChange = changeId + } + } + purgeObsoleteNotifications(now, notifications) + } + } catch { + case e: InterruptedException => if (!isClosed.get) error(s"Error while processing notification change for path = $seqNodeRoot", e) + case e: Exception => error(s"Error while processing notification change for path = $seqNodeRoot", e) + } + } + + private def processNotification(notification: String): Unit = { + val changeZnode = seqNodeRoot + "/" + notification + val (data, _) = zkClient.getDataAndStat(changeZnode) + data match { + case Some(d) => Try(notificationHandler.processNotification(d)) match { + case Failure(e) => error(s"error processing change notification ${new String(d, UTF_8)} from $changeZnode", e) + case _ => + } + case None => warn(s"read null data from $changeZnode") + } + } + + private def addChangeNotification(): Unit = { + if (!isClosed.get && queue.peek() == null) + queue.put(new ChangeNotification) + } + + class ChangeNotification { + def process(): Unit = processNotifications() + } + + /** + * Purges expired notifications. 
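+ * A notification is considered expired once now minus the znode's creation time (ctime) exceeds changeExpirationMs.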
+ * + * @param now + * @param notifications + */ + private def purgeObsoleteNotifications(now: Long, notifications: Seq[String]): Unit = { + for (notification <- notifications.sorted) { + val notificationNode = seqNodeRoot + "/" + notification + val (data, stat) = zkClient.getDataAndStat(notificationNode) + if (data.isDefined) { + if (now - stat.getCtime > changeExpirationMs) { + debug(s"Purging change notification $notificationNode") + zkClient.deletePath(notificationNode) + } + } + } + } + + /* get the change number from a change notification znode */ + private def changeNumber(name: String): Long = name.substring(seqNodePrefix.length).toLong + + class ChangeEventProcessThread(name: String) extends ShutdownableThread(name = name) { + override def doWork(): Unit = queue.take().process() + } + + object ChangeNotificationHandler extends ZNodeChildChangeHandler { + override val path: String = seqNodeRoot + override def handleChildChange(): Unit = addChangeNotification() + } + + object ZkStateChangeHandler extends StateChangeHandler { + override val name: String = StateChangeHandlers.zkNodeChangeListenerHandler(seqNodeRoot) + override def afterInitializingSession(): Unit = addChangeNotification() + } +} + diff --git a/core/src/main/scala/kafka/consumer/BaseConsumerRecord.scala b/core/src/main/scala/kafka/consumer/BaseConsumerRecord.scala new file mode 100644 index 0000000..7628b6b --- /dev/null +++ b/core/src/main/scala/kafka/consumer/BaseConsumerRecord.scala @@ -0,0 +1,33 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.consumer + +import org.apache.kafka.common.header.Headers +import org.apache.kafka.common.header.internals.RecordHeaders +import org.apache.kafka.common.record.{RecordBatch, TimestampType} + +@deprecated("This class has been deprecated and will be removed in a future release. " + + "Please use org.apache.kafka.clients.consumer.ConsumerRecord instead.", "0.11.0.0") +case class BaseConsumerRecord(topic: String, + partition: Int, + offset: Long, + timestamp: Long = RecordBatch.NO_TIMESTAMP, + timestampType: TimestampType = TimestampType.NO_TIMESTAMP_TYPE, + key: Array[Byte], + value: Array[Byte], + headers: Headers = new RecordHeaders()) diff --git a/core/src/main/scala/kafka/controller/ControllerChannelManager.scala b/core/src/main/scala/kafka/controller/ControllerChannelManager.scala new file mode 100755 index 0000000..2f10710 --- /dev/null +++ b/core/src/main/scala/kafka/controller/ControllerChannelManager.scala @@ -0,0 +1,683 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. +*/ +package kafka.controller + +import java.net.SocketTimeoutException +import java.util.concurrent.{BlockingQueue, LinkedBlockingQueue, TimeUnit} + +import com.yammer.metrics.core.{Gauge, Timer} +import kafka.api._ +import kafka.cluster.Broker +import kafka.metrics.KafkaMetricsGroup +import kafka.server.KafkaConfig +import kafka.utils._ +import kafka.utils.Implicits._ +import org.apache.kafka.clients._ +import org.apache.kafka.common.message.LeaderAndIsrRequestData.LeaderAndIsrPartitionState +import org.apache.kafka.common.message.StopReplicaRequestData.{StopReplicaPartitionState, StopReplicaTopicState} +import org.apache.kafka.common.metrics.Metrics +import org.apache.kafka.common.network._ +import org.apache.kafka.common.protocol.{ApiKeys, Errors} +import org.apache.kafka.common.message.UpdateMetadataRequestData.{UpdateMetadataBroker, UpdateMetadataEndpoint, UpdateMetadataPartitionState} +import org.apache.kafka.common.requests._ +import org.apache.kafka.common.security.JaasContext +import org.apache.kafka.common.security.auth.SecurityProtocol +import org.apache.kafka.common.utils.{LogContext, Time} +import org.apache.kafka.common.{KafkaException, Node, Reconfigurable, TopicPartition, Uuid} + +import scala.jdk.CollectionConverters._ +import scala.collection.mutable.HashMap +import scala.collection.{Seq, Set, mutable} + +object ControllerChannelManager { + val QueueSizeMetricName = "QueueSize" + val RequestRateAndQueueTimeMetricName = "RequestRateAndQueueTimeMs" +} + +class ControllerChannelManager(controllerContext: ControllerContext, + config: KafkaConfig, + time: Time, + metrics: Metrics, + stateChangeLogger: StateChangeLogger, + threadNamePrefix: Option[String] = None) extends Logging with KafkaMetricsGroup { + import ControllerChannelManager._ + + protected val brokerStateInfo = new HashMap[Int, ControllerBrokerStateInfo] + private val brokerLock = new Object + this.logIdent = "[Channel manager on controller " + config.brokerId + "]: " + + newGauge("TotalQueueSize", + () => brokerLock synchronized { + brokerStateInfo.values.iterator.map(_.messageQueue.size).sum + } + ) + + def startup() = { + controllerContext.liveOrShuttingDownBrokers.foreach(addNewBroker) + + brokerLock synchronized { + brokerStateInfo.foreach(brokerState => startRequestSendThread(brokerState._1)) + } + } + + def shutdown() = { + brokerLock synchronized { + brokerStateInfo.values.toList.foreach(removeExistingBroker) + } + } + + def sendRequest(brokerId: Int, request: AbstractControlRequest.Builder[_ <: AbstractControlRequest], + callback: AbstractResponse => Unit = null): Unit = { + brokerLock synchronized { + val stateInfoOpt = brokerStateInfo.get(brokerId) + stateInfoOpt match { + case Some(stateInfo) => + stateInfo.messageQueue.put(QueueItem(request.apiKey, request, callback, time.milliseconds())) + case None => + warn(s"Not sending request $request to broker $brokerId, since it is offline.") + } + } + } + + def addBroker(broker: 
Broker): Unit = { + // be careful here. Maybe the startup() API has already started the request send thread + brokerLock synchronized { + if (!brokerStateInfo.contains(broker.id)) { + addNewBroker(broker) + startRequestSendThread(broker.id) + } + } + } + + def removeBroker(brokerId: Int): Unit = { + brokerLock synchronized { + removeExistingBroker(brokerStateInfo(brokerId)) + } + } + + private def addNewBroker(broker: Broker): Unit = { + val messageQueue = new LinkedBlockingQueue[QueueItem] + debug(s"Controller ${config.brokerId} trying to connect to broker ${broker.id}") + val controllerToBrokerListenerName = config.controlPlaneListenerName.getOrElse(config.interBrokerListenerName) + val controllerToBrokerSecurityProtocol = config.controlPlaneSecurityProtocol.getOrElse(config.interBrokerSecurityProtocol) + val brokerNode = broker.node(controllerToBrokerListenerName) + val logContext = new LogContext(s"[Controller id=${config.brokerId}, targetBrokerId=${brokerNode.idString}] ") + val (networkClient, reconfigurableChannelBuilder) = { + val channelBuilder = ChannelBuilders.clientChannelBuilder( + controllerToBrokerSecurityProtocol, + JaasContext.Type.SERVER, + config, + controllerToBrokerListenerName, + config.saslMechanismInterBrokerProtocol, + time, + config.saslInterBrokerHandshakeRequestEnable, + logContext + ) + val reconfigurableChannelBuilder = channelBuilder match { + case reconfigurable: Reconfigurable => + config.addReconfigurable(reconfigurable) + Some(reconfigurable) + case _ => None + } + val selector = new Selector( + NetworkReceive.UNLIMITED, + Selector.NO_IDLE_TIMEOUT_MS, + metrics, + time, + "controller-channel", + Map("broker-id" -> brokerNode.idString).asJava, + false, + channelBuilder, + logContext + ) + val networkClient = new NetworkClient( + selector, + new ManualMetadataUpdater(Seq(brokerNode).asJava), + config.brokerId.toString, + 1, + 0, + 0, + Selectable.USE_DEFAULT_BUFFER_SIZE, + Selectable.USE_DEFAULT_BUFFER_SIZE, + config.requestTimeoutMs, + config.connectionSetupTimeoutMs, + config.connectionSetupTimeoutMaxMs, + time, + false, + new ApiVersions, + logContext + ) + (networkClient, reconfigurableChannelBuilder) + } + val threadName = threadNamePrefix match { + case None => s"Controller-${config.brokerId}-to-broker-${broker.id}-send-thread" + case Some(name) => s"$name:Controller-${config.brokerId}-to-broker-${broker.id}-send-thread" + } + + val requestRateAndQueueTimeMetrics = newTimer( + RequestRateAndQueueTimeMetricName, TimeUnit.MILLISECONDS, TimeUnit.SECONDS, brokerMetricTags(broker.id) + ) + + val requestThread = new RequestSendThread(config.brokerId, controllerContext, messageQueue, networkClient, + brokerNode, config, time, requestRateAndQueueTimeMetrics, stateChangeLogger, threadName) + requestThread.setDaemon(false) + + val queueSizeGauge = newGauge(QueueSizeMetricName, () => messageQueue.size, brokerMetricTags(broker.id)) + + brokerStateInfo.put(broker.id, ControllerBrokerStateInfo(networkClient, brokerNode, messageQueue, + requestThread, queueSizeGauge, requestRateAndQueueTimeMetrics, reconfigurableChannelBuilder)) + } + + private def brokerMetricTags(brokerId: Int) = Map("broker-id" -> brokerId.toString) + + private def removeExistingBroker(brokerState: ControllerBrokerStateInfo): Unit = { + try { + // Shutdown the RequestSendThread before closing the NetworkClient to avoid the concurrent use of the + // non-threadsafe classes as described in KAFKA-4959. 
+ // The call to shutdownLatch.await() in ShutdownableThread.shutdown() serves as a synchronization barrier that + // hands off the NetworkClient from the RequestSendThread to the ZkEventThread. + brokerState.reconfigurableChannelBuilder.foreach(config.removeReconfigurable) + brokerState.requestSendThread.shutdown() + brokerState.networkClient.close() + brokerState.messageQueue.clear() + removeMetric(QueueSizeMetricName, brokerMetricTags(brokerState.brokerNode.id)) + removeMetric(RequestRateAndQueueTimeMetricName, brokerMetricTags(brokerState.brokerNode.id)) + brokerStateInfo.remove(brokerState.brokerNode.id) + } catch { + case e: Throwable => error("Error while removing broker by the controller", e) + } + } + + protected def startRequestSendThread(brokerId: Int): Unit = { + val requestThread = brokerStateInfo(brokerId).requestSendThread + if (requestThread.getState == Thread.State.NEW) + requestThread.start() + } +} + +case class QueueItem(apiKey: ApiKeys, request: AbstractControlRequest.Builder[_ <: AbstractControlRequest], + callback: AbstractResponse => Unit, enqueueTimeMs: Long) + +class RequestSendThread(val controllerId: Int, + val controllerContext: ControllerContext, + val queue: BlockingQueue[QueueItem], + val networkClient: NetworkClient, + val brokerNode: Node, + val config: KafkaConfig, + val time: Time, + val requestRateAndQueueTimeMetrics: Timer, + val stateChangeLogger: StateChangeLogger, + name: String) + extends ShutdownableThread(name = name) { + + logIdent = s"[RequestSendThread controllerId=$controllerId] " + + private val socketTimeoutMs = config.controllerSocketTimeoutMs + + override def doWork(): Unit = { + + def backoff(): Unit = pause(100, TimeUnit.MILLISECONDS) + + val QueueItem(apiKey, requestBuilder, callback, enqueueTimeMs) = queue.take() + requestRateAndQueueTimeMetrics.update(time.milliseconds() - enqueueTimeMs, TimeUnit.MILLISECONDS) + + var clientResponse: ClientResponse = null + try { + var isSendSuccessful = false + while (isRunning && !isSendSuccessful) { + // if a broker goes down for a long time, then at some point the controller's zookeeper listener will trigger a + // removeBroker which will invoke shutdown() on this thread. At that point, we will stop retrying. + try { + if (!brokerReady()) { + isSendSuccessful = false + backoff() + } + else { + val clientRequest = networkClient.newClientRequest(brokerNode.idString, requestBuilder, + time.milliseconds(), true) + clientResponse = NetworkClientUtils.sendAndReceive(networkClient, clientRequest, time) + isSendSuccessful = true + } + } catch { + case e: Throwable => // if the send was not successful, reconnect to broker and resend the message + warn(s"Controller $controllerId epoch ${controllerContext.epoch} fails to send request $requestBuilder " + + s"to broker $brokerNode. 
Reconnecting to broker.", e) + networkClient.close(brokerNode.idString) + isSendSuccessful = false + backoff() + } + } + if (clientResponse != null) { + val requestHeader = clientResponse.requestHeader + val api = requestHeader.apiKey + if (api != ApiKeys.LEADER_AND_ISR && api != ApiKeys.STOP_REPLICA && api != ApiKeys.UPDATE_METADATA) + throw new KafkaException(s"Unexpected apiKey received: $apiKey") + + val response = clientResponse.responseBody + + stateChangeLogger.withControllerEpoch(controllerContext.epoch).trace(s"Received response " + + s"$response for request $api with correlation id " + + s"${requestHeader.correlationId} sent to broker $brokerNode") + + if (callback != null) { + callback(response) + } + } + } catch { + case e: Throwable => + error(s"Controller $controllerId fails to send a request to broker $brokerNode", e) + // If there is any socket error (eg, socket timeout), the connection is no longer usable and needs to be recreated. + networkClient.close(brokerNode.idString) + } + } + + private def brokerReady(): Boolean = { + try { + if (!NetworkClientUtils.isReady(networkClient, brokerNode, time.milliseconds())) { + if (!NetworkClientUtils.awaitReady(networkClient, brokerNode, time, socketTimeoutMs)) + throw new SocketTimeoutException(s"Failed to connect within $socketTimeoutMs ms") + + info(s"Controller $controllerId connected to $brokerNode for sending state change requests") + } + + true + } catch { + case e: Throwable => + warn(s"Controller $controllerId's connection to broker $brokerNode was unsuccessful", e) + networkClient.close(brokerNode.idString) + false + } + } + + override def initiateShutdown(): Boolean = { + if (super.initiateShutdown()) { + networkClient.initiateClose() + true + } else + false + } +} + +class ControllerBrokerRequestBatch(config: KafkaConfig, + controllerChannelManager: ControllerChannelManager, + controllerEventManager: ControllerEventManager, + controllerContext: ControllerContext, + stateChangeLogger: StateChangeLogger) + extends AbstractControllerBrokerRequestBatch(config, controllerContext, stateChangeLogger) { + + def sendEvent(event: ControllerEvent): Unit = { + controllerEventManager.put(event) + } + + def sendRequest(brokerId: Int, + request: AbstractControlRequest.Builder[_ <: AbstractControlRequest], + callback: AbstractResponse => Unit = null): Unit = { + controllerChannelManager.sendRequest(brokerId, request, callback) + } + +} + +abstract class AbstractControllerBrokerRequestBatch(config: KafkaConfig, + controllerContext: ControllerContext, + stateChangeLogger: StateChangeLogger) extends Logging { + val controllerId: Int = config.brokerId + val leaderAndIsrRequestMap = mutable.Map.empty[Int, mutable.Map[TopicPartition, LeaderAndIsrPartitionState]] + val stopReplicaRequestMap = mutable.Map.empty[Int, mutable.Map[TopicPartition, StopReplicaPartitionState]] + val updateMetadataRequestBrokerSet = mutable.Set.empty[Int] + val updateMetadataRequestPartitionInfoMap = mutable.Map.empty[TopicPartition, UpdateMetadataPartitionState] + + def sendEvent(event: ControllerEvent): Unit + + def sendRequest(brokerId: Int, + request: AbstractControlRequest.Builder[_ <: AbstractControlRequest], + callback: AbstractResponse => Unit = null): Unit + + def newBatch(): Unit = { + // raise error if the previous batch is not empty + if (leaderAndIsrRequestMap.nonEmpty) + throw new IllegalStateException("Controller to broker state change requests batch is not empty while creating " + + s"a new one. 
Some LeaderAndIsr state changes $leaderAndIsrRequestMap might be lost ") + if (stopReplicaRequestMap.nonEmpty) + throw new IllegalStateException("Controller to broker state change requests batch is not empty while creating a " + + s"new one. Some StopReplica state changes $stopReplicaRequestMap might be lost ") + if (updateMetadataRequestBrokerSet.nonEmpty) + throw new IllegalStateException("Controller to broker state change requests batch is not empty while creating a " + + s"new one. Some UpdateMetadata state changes to brokers $updateMetadataRequestBrokerSet with partition info " + + s"$updateMetadataRequestPartitionInfoMap might be lost ") + } + + def clear(): Unit = { + leaderAndIsrRequestMap.clear() + stopReplicaRequestMap.clear() + updateMetadataRequestBrokerSet.clear() + updateMetadataRequestPartitionInfoMap.clear() + } + + def addLeaderAndIsrRequestForBrokers(brokerIds: Seq[Int], + topicPartition: TopicPartition, + leaderIsrAndControllerEpoch: LeaderIsrAndControllerEpoch, + replicaAssignment: ReplicaAssignment, + isNew: Boolean): Unit = { + + brokerIds.filter(_ >= 0).foreach { brokerId => + val result = leaderAndIsrRequestMap.getOrElseUpdate(brokerId, mutable.Map.empty) + val alreadyNew = result.get(topicPartition).exists(_.isNew) + val leaderAndIsr = leaderIsrAndControllerEpoch.leaderAndIsr + result.put(topicPartition, new LeaderAndIsrPartitionState() + .setTopicName(topicPartition.topic) + .setPartitionIndex(topicPartition.partition) + .setControllerEpoch(leaderIsrAndControllerEpoch.controllerEpoch) + .setLeader(leaderAndIsr.leader) + .setLeaderEpoch(leaderAndIsr.leaderEpoch) + .setIsr(leaderAndIsr.isr.map(Integer.valueOf).asJava) + .setZkVersion(leaderAndIsr.zkVersion) + .setReplicas(replicaAssignment.replicas.map(Integer.valueOf).asJava) + .setAddingReplicas(replicaAssignment.addingReplicas.map(Integer.valueOf).asJava) + .setRemovingReplicas(replicaAssignment.removingReplicas.map(Integer.valueOf).asJava) + .setIsNew(isNew || alreadyNew)) + } + + addUpdateMetadataRequestForBrokers(controllerContext.liveOrShuttingDownBrokerIds.toSeq, Set(topicPartition)) + } + + def addStopReplicaRequestForBrokers(brokerIds: Seq[Int], + topicPartition: TopicPartition, + deletePartition: Boolean): Unit = { + // A sentinel (-2) is used as an epoch if the topic is queued for deletion. It overrides + // any existing epoch. 
+ val leaderEpoch = if (controllerContext.isTopicQueuedUpForDeletion(topicPartition.topic)) { + LeaderAndIsr.EpochDuringDelete + } else { + controllerContext.partitionLeadershipInfo(topicPartition) + .map(_.leaderAndIsr.leaderEpoch) + .getOrElse(LeaderAndIsr.NoEpoch) + } + + brokerIds.filter(_ >= 0).foreach { brokerId => + val result = stopReplicaRequestMap.getOrElseUpdate(brokerId, mutable.Map.empty) + val alreadyDelete = result.get(topicPartition).exists(_.deletePartition) + result.put(topicPartition, new StopReplicaPartitionState() + .setPartitionIndex(topicPartition.partition()) + .setLeaderEpoch(leaderEpoch) + .setDeletePartition(alreadyDelete || deletePartition)) + } + } + + /** Send UpdateMetadataRequest to the given brokers for the given partitions and partitions that are being deleted */ + def addUpdateMetadataRequestForBrokers(brokerIds: Seq[Int], + partitions: collection.Set[TopicPartition]): Unit = { + + def updateMetadataRequestPartitionInfo(partition: TopicPartition, beingDeleted: Boolean): Unit = { + controllerContext.partitionLeadershipInfo(partition) match { + case Some(LeaderIsrAndControllerEpoch(leaderAndIsr, controllerEpoch)) => + val replicas = controllerContext.partitionReplicaAssignment(partition) + val offlineReplicas = replicas.filter(!controllerContext.isReplicaOnline(_, partition)) + val updatedLeaderAndIsr = + if (beingDeleted) LeaderAndIsr.duringDelete(leaderAndIsr.isr) + else leaderAndIsr + + val partitionStateInfo = new UpdateMetadataPartitionState() + .setTopicName(partition.topic) + .setPartitionIndex(partition.partition) + .setControllerEpoch(controllerEpoch) + .setLeader(updatedLeaderAndIsr.leader) + .setLeaderEpoch(updatedLeaderAndIsr.leaderEpoch) + .setIsr(updatedLeaderAndIsr.isr.map(Integer.valueOf).asJava) + .setZkVersion(updatedLeaderAndIsr.zkVersion) + .setReplicas(replicas.map(Integer.valueOf).asJava) + .setOfflineReplicas(offlineReplicas.map(Integer.valueOf).asJava) + updateMetadataRequestPartitionInfoMap.put(partition, partitionStateInfo) + + case None => + info(s"Leader not yet assigned for partition $partition. 
Skip sending UpdateMetadataRequest.") + } + } + + updateMetadataRequestBrokerSet ++= brokerIds.filter(_ >= 0) + partitions.foreach(partition => updateMetadataRequestPartitionInfo(partition, + beingDeleted = controllerContext.topicsToBeDeleted.contains(partition.topic))) + } + + private def sendLeaderAndIsrRequest(controllerEpoch: Int, stateChangeLog: StateChangeLogger): Unit = { + val leaderAndIsrRequestVersion: Short = + if (config.interBrokerProtocolVersion >= KAFKA_2_8_IV1) 5 + else if (config.interBrokerProtocolVersion >= KAFKA_2_4_IV1) 4 + else if (config.interBrokerProtocolVersion >= KAFKA_2_4_IV0) 3 + else if (config.interBrokerProtocolVersion >= KAFKA_2_2_IV0) 2 + else if (config.interBrokerProtocolVersion >= KAFKA_1_0_IV0) 1 + else 0 + + leaderAndIsrRequestMap.forKeyValue { (broker, leaderAndIsrPartitionStates) => + if (controllerContext.liveOrShuttingDownBrokerIds.contains(broker)) { + val leaderIds = mutable.Set.empty[Int] + var numBecomeLeaders = 0 + leaderAndIsrPartitionStates.forKeyValue { (topicPartition, state) => + leaderIds += state.leader + val typeOfRequest = if (broker == state.leader) { + numBecomeLeaders += 1 + "become-leader" + } else { + "become-follower" + } + if (stateChangeLog.isTraceEnabled) + stateChangeLog.trace(s"Sending $typeOfRequest LeaderAndIsr request $state to broker $broker for partition $topicPartition") + } + stateChangeLog.info(s"Sending LeaderAndIsr request to broker $broker with $numBecomeLeaders become-leader " + + s"and ${leaderAndIsrPartitionStates.size - numBecomeLeaders} become-follower partitions") + val leaders = controllerContext.liveOrShuttingDownBrokers.filter(b => leaderIds.contains(b.id)).map { + _.node(config.interBrokerListenerName) + } + val brokerEpoch = controllerContext.liveBrokerIdAndEpochs(broker) + val topicIds = leaderAndIsrPartitionStates.keys + .map(_.topic) + .toSet[String] + .map(topic => (topic, controllerContext.topicIds.getOrElse(topic, Uuid.ZERO_UUID))) + .toMap + val leaderAndIsrRequestBuilder = new LeaderAndIsrRequest.Builder(leaderAndIsrRequestVersion, controllerId, + controllerEpoch, brokerEpoch, leaderAndIsrPartitionStates.values.toBuffer.asJava, topicIds.asJava, leaders.asJava) + sendRequest(broker, leaderAndIsrRequestBuilder, (r: AbstractResponse) => { + val leaderAndIsrResponse = r.asInstanceOf[LeaderAndIsrResponse] + sendEvent(LeaderAndIsrResponseReceived(leaderAndIsrResponse, broker)) + }) + } + } + leaderAndIsrRequestMap.clear() + } + + private def sendUpdateMetadataRequests(controllerEpoch: Int, stateChangeLog: StateChangeLogger): Unit = { + stateChangeLog.info(s"Sending UpdateMetadata request to brokers $updateMetadataRequestBrokerSet " + + s"for ${updateMetadataRequestPartitionInfoMap.size} partitions") + + val partitionStates = updateMetadataRequestPartitionInfoMap.values.toBuffer + val updateMetadataRequestVersion: Short = + if (config.interBrokerProtocolVersion >= KAFKA_2_8_IV1) 7 + else if (config.interBrokerProtocolVersion >= KAFKA_2_4_IV1) 6 + else if (config.interBrokerProtocolVersion >= KAFKA_2_2_IV0) 5 + else if (config.interBrokerProtocolVersion >= KAFKA_1_0_IV0) 4 + else if (config.interBrokerProtocolVersion >= KAFKA_0_10_2_IV0) 3 + else if (config.interBrokerProtocolVersion >= KAFKA_0_10_0_IV1) 2 + else if (config.interBrokerProtocolVersion >= KAFKA_0_9_0) 1 + else 0 + + val liveBrokers = controllerContext.liveOrShuttingDownBrokers.iterator.map { broker => + val endpoints = if (updateMetadataRequestVersion == 0) { + // Version 0 of UpdateMetadataRequest only supports PLAINTEXT + val 
securityProtocol = SecurityProtocol.PLAINTEXT + val listenerName = ListenerName.forSecurityProtocol(securityProtocol) + val node = broker.node(listenerName) + Seq(new UpdateMetadataEndpoint() + .setHost(node.host) + .setPort(node.port) + .setSecurityProtocol(securityProtocol.id) + .setListener(listenerName.value)) + } else { + broker.endPoints.map { endpoint => + new UpdateMetadataEndpoint() + .setHost(endpoint.host) + .setPort(endpoint.port) + .setSecurityProtocol(endpoint.securityProtocol.id) + .setListener(endpoint.listenerName.value) + } + } + new UpdateMetadataBroker() + .setId(broker.id) + .setEndpoints(endpoints.asJava) + .setRack(broker.rack.orNull) + }.toBuffer + + updateMetadataRequestBrokerSet.intersect(controllerContext.liveOrShuttingDownBrokerIds).foreach { broker => + val brokerEpoch = controllerContext.liveBrokerIdAndEpochs(broker) + val topicIds = partitionStates.map(_.topicName()) + .distinct + .filter(controllerContext.topicIds.contains) + .map(topic => (topic, controllerContext.topicIds(topic))).toMap + val updateMetadataRequestBuilder = new UpdateMetadataRequest.Builder(updateMetadataRequestVersion, + controllerId, controllerEpoch, brokerEpoch, partitionStates.asJava, liveBrokers.asJava, topicIds.asJava) + sendRequest(broker, updateMetadataRequestBuilder, (r: AbstractResponse) => { + val updateMetadataResponse = r.asInstanceOf[UpdateMetadataResponse] + sendEvent(UpdateMetadataResponseReceived(updateMetadataResponse, broker)) + }) + + } + updateMetadataRequestBrokerSet.clear() + updateMetadataRequestPartitionInfoMap.clear() + } + + private def sendStopReplicaRequests(controllerEpoch: Int, stateChangeLog: StateChangeLogger): Unit = { + val traceEnabled = stateChangeLog.isTraceEnabled + val stopReplicaRequestVersion: Short = + if (config.interBrokerProtocolVersion >= KAFKA_2_6_IV0) 3 + else if (config.interBrokerProtocolVersion >= KAFKA_2_4_IV1) 2 + else if (config.interBrokerProtocolVersion >= KAFKA_2_2_IV0) 1 + else 0 + + def responseCallback(brokerId: Int, isPartitionDeleted: TopicPartition => Boolean) + (response: AbstractResponse): Unit = { + val stopReplicaResponse = response.asInstanceOf[StopReplicaResponse] + val partitionErrorsForDeletingTopics = mutable.Map.empty[TopicPartition, Errors] + stopReplicaResponse.partitionErrors.forEach { pe => + val tp = new TopicPartition(pe.topicName, pe.partitionIndex) + if (controllerContext.isTopicDeletionInProgress(pe.topicName) && + isPartitionDeleted(tp)) { + partitionErrorsForDeletingTopics += tp -> Errors.forCode(pe.errorCode) + } + } + if (partitionErrorsForDeletingTopics.nonEmpty) + sendEvent(TopicDeletionStopReplicaResponseReceived(brokerId, stopReplicaResponse.error, + partitionErrorsForDeletingTopics)) + } + + stopReplicaRequestMap.forKeyValue { (brokerId, partitionStates) => + if (controllerContext.liveOrShuttingDownBrokerIds.contains(brokerId)) { + if (traceEnabled) + partitionStates.forKeyValue { (topicPartition, partitionState) => + stateChangeLog.trace(s"Sending StopReplica request $partitionState to " + + s"broker $brokerId for partition $topicPartition") + } + + val brokerEpoch = controllerContext.liveBrokerIdAndEpochs(brokerId) + if (stopReplicaRequestVersion >= 3) { + val stopReplicaTopicState = mutable.Map.empty[String, StopReplicaTopicState] + partitionStates.forKeyValue { (topicPartition, partitionState) => + val topicState = stopReplicaTopicState.getOrElseUpdate(topicPartition.topic, + new StopReplicaTopicState().setTopicName(topicPartition.topic)) + topicState.partitionStates().add(partitionState) + } + + 
stateChangeLog.info(s"Sending StopReplica request for ${partitionStates.size} " + + s"replicas to broker $brokerId") + val stopReplicaRequestBuilder = new StopReplicaRequest.Builder( + stopReplicaRequestVersion, controllerId, controllerEpoch, brokerEpoch, + false, stopReplicaTopicState.values.toBuffer.asJava) + sendRequest(brokerId, stopReplicaRequestBuilder, + responseCallback(brokerId, tp => partitionStates.get(tp).exists(_.deletePartition))) + } else { + var numPartitionStateWithDelete = 0 + var numPartitionStateWithoutDelete = 0 + val topicStatesWithDelete = mutable.Map.empty[String, StopReplicaTopicState] + val topicStatesWithoutDelete = mutable.Map.empty[String, StopReplicaTopicState] + + partitionStates.forKeyValue { (topicPartition, partitionState) => + val topicStates = if (partitionState.deletePartition()) { + numPartitionStateWithDelete += 1 + topicStatesWithDelete + } else { + numPartitionStateWithoutDelete += 1 + topicStatesWithoutDelete + } + val topicState = topicStates.getOrElseUpdate(topicPartition.topic, + new StopReplicaTopicState().setTopicName(topicPartition.topic)) + topicState.partitionStates().add(partitionState) + } + + if (topicStatesWithDelete.nonEmpty) { + stateChangeLog.info(s"Sending StopReplica request (delete = true) for " + + s"$numPartitionStateWithDelete replicas to broker $brokerId") + val stopReplicaRequestBuilder = new StopReplicaRequest.Builder( + stopReplicaRequestVersion, controllerId, controllerEpoch, brokerEpoch, + true, topicStatesWithDelete.values.toBuffer.asJava) + sendRequest(brokerId, stopReplicaRequestBuilder, responseCallback(brokerId, _ => true)) + } + + if (topicStatesWithoutDelete.nonEmpty) { + stateChangeLog.info(s"Sending StopReplica request (delete = false) for " + + s"$numPartitionStateWithoutDelete replicas to broker $brokerId") + val stopReplicaRequestBuilder = new StopReplicaRequest.Builder( + stopReplicaRequestVersion, controllerId, controllerEpoch, brokerEpoch, + false, topicStatesWithoutDelete.values.toBuffer.asJava) + sendRequest(brokerId, stopReplicaRequestBuilder) + } + } + } + } + + stopReplicaRequestMap.clear() + } + + def sendRequestsToBrokers(controllerEpoch: Int): Unit = { + try { + val stateChangeLog = stateChangeLogger.withControllerEpoch(controllerEpoch) + sendLeaderAndIsrRequest(controllerEpoch, stateChangeLog) + sendUpdateMetadataRequests(controllerEpoch, stateChangeLog) + sendStopReplicaRequests(controllerEpoch, stateChangeLog) + } catch { + case e: Throwable => + if (leaderAndIsrRequestMap.nonEmpty) { + error("Haven't been able to send leader and isr requests, current state of " + + s"the map is $leaderAndIsrRequestMap. Exception message: $e") + } + if (updateMetadataRequestBrokerSet.nonEmpty) { + error(s"Haven't been able to send metadata update requests to brokers $updateMetadataRequestBrokerSet, " + + s"current state of the partition info is $updateMetadataRequestPartitionInfoMap. Exception message: $e") + } + if (stopReplicaRequestMap.nonEmpty) { + error("Haven't been able to send stop replica requests, current state of " + + s"the map is $stopReplicaRequestMap. 
Exception message: $e") + } + throw new IllegalStateException(e) + } + } +} + +case class ControllerBrokerStateInfo(networkClient: NetworkClient, + brokerNode: Node, + messageQueue: BlockingQueue[QueueItem], + requestSendThread: RequestSendThread, + queueSizeGauge: Gauge[Int], + requestRateAndTimeMetrics: Timer, + reconfigurableChannelBuilder: Option[Reconfigurable]) + diff --git a/core/src/main/scala/kafka/controller/ControllerContext.scala b/core/src/main/scala/kafka/controller/ControllerContext.scala new file mode 100644 index 0000000..379196a --- /dev/null +++ b/core/src/main/scala/kafka/controller/ControllerContext.scala @@ -0,0 +1,521 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.controller + +import kafka.cluster.Broker +import kafka.utils.Implicits._ +import org.apache.kafka.common.{TopicPartition, Uuid} + +import scala.collection.{Map, Seq, Set, mutable} + +object ReplicaAssignment { + def apply(replicas: Seq[Int]): ReplicaAssignment = { + apply(replicas, Seq.empty, Seq.empty) + } + + val empty: ReplicaAssignment = apply(Seq.empty) +} + + +/** + * @param replicas the sequence of brokers assigned to the partition. It includes the set of brokers + * that were added (`addingReplicas`) and removed (`removingReplicas`). 
+ * @param addingReplicas the replicas that are being added if there is a pending reassignment + * @param removingReplicas the replicas that are being removed if there is a pending reassignment + */ +case class ReplicaAssignment private (replicas: Seq[Int], + addingReplicas: Seq[Int], + removingReplicas: Seq[Int]) { + + lazy val originReplicas: Seq[Int] = replicas.diff(addingReplicas) + lazy val targetReplicas: Seq[Int] = replicas.diff(removingReplicas) + + def isBeingReassigned: Boolean = { + addingReplicas.nonEmpty || removingReplicas.nonEmpty + } + + def reassignTo(target: Seq[Int]): ReplicaAssignment = { + val fullReplicaSet = (target ++ originReplicas).distinct + ReplicaAssignment( + fullReplicaSet, + fullReplicaSet.diff(originReplicas), + fullReplicaSet.diff(target) + ) + } + + def removeReplica(replica: Int): ReplicaAssignment = { + ReplicaAssignment( + replicas.filterNot(_ == replica), + addingReplicas.filterNot(_ == replica), + removingReplicas.filterNot(_ == replica) + ) + } + + override def toString: String = s"ReplicaAssignment(" + + s"replicas=${replicas.mkString(",")}, " + + s"addingReplicas=${addingReplicas.mkString(",")}, " + + s"removingReplicas=${removingReplicas.mkString(",")})" +} + +class ControllerContext { + val stats = new ControllerStats + var offlinePartitionCount = 0 + var preferredReplicaImbalanceCount = 0 + val shuttingDownBrokerIds = mutable.Set.empty[Int] + private val liveBrokers = mutable.Set.empty[Broker] + private val liveBrokerEpochs = mutable.Map.empty[Int, Long] + var epoch: Int = KafkaController.InitialControllerEpoch + var epochZkVersion: Int = KafkaController.InitialControllerEpochZkVersion + + val allTopics = mutable.Set.empty[String] + var topicIds = mutable.Map.empty[String, Uuid] + var topicNames = mutable.Map.empty[Uuid, String] + val partitionAssignments = mutable.Map.empty[String, mutable.Map[Int, ReplicaAssignment]] + private val partitionLeadershipInfo = mutable.Map.empty[TopicPartition, LeaderIsrAndControllerEpoch] + val partitionsBeingReassigned = mutable.Set.empty[TopicPartition] + val partitionStates = mutable.Map.empty[TopicPartition, PartitionState] + val replicaStates = mutable.Map.empty[PartitionAndReplica, ReplicaState] + val replicasOnOfflineDirs = mutable.Map.empty[Int, Set[TopicPartition]] + + val topicsToBeDeleted = mutable.Set.empty[String] + + /** The following topicsWithDeletionStarted variable is used to properly update the offlinePartitionCount metric. + * When a topic is going through deletion, we don't want to keep track of its partition state + * changes in the offlinePartitionCount metric. This means that if some partitions of a topic are already + * in OfflinePartition state when deletion starts, we need to change the corresponding partition + * states to NonExistentPartition first before starting the deletion. + * + * However, we can NOT change partition states to NonExistentPartition at the time of enqueuing topics + * for deletion. The reason is that when a topic is enqueued for deletion, it may be ineligible for + * deletion due to ongoing partition reassignments. Hence there might be a delay between enqueuing + * a topic for deletion and the actual start of deletion. In this delayed interval, partitions may still + * transition to or out of the OfflinePartition state. + * + * Hence we decide to change partition states to NonExistentPartition only when the actual deletion has started. + * For topics whose deletion has actually started, we keep track of them in the following topicsWithDeletionStarted + * variable. 
And once a topic is in the topicsWithDeletionStarted set, we are sure there will no longer + * be partition reassignments to any of its partitions, and only then it's safe to move its partitions to + * NonExistentPartition state. Once a topic is in the topicsWithDeletionStarted set, we will stop monitoring + * its partition state changes in the offlinePartitionCount metric + */ + val topicsWithDeletionStarted = mutable.Set.empty[String] + val topicsIneligibleForDeletion = mutable.Set.empty[String] + + private def clearTopicsState(): Unit = { + allTopics.clear() + topicIds.clear() + topicNames.clear() + partitionAssignments.clear() + partitionLeadershipInfo.clear() + partitionsBeingReassigned.clear() + replicasOnOfflineDirs.clear() + partitionStates.clear() + offlinePartitionCount = 0 + preferredReplicaImbalanceCount = 0 + replicaStates.clear() + } + + def addTopicId(topic: String, id: Uuid): Unit = { + if (!allTopics.contains(topic)) + throw new IllegalStateException(s"topic $topic is not contained in all topics.") + + topicIds.get(topic).foreach { existingId => + if (!existingId.equals(id)) + throw new IllegalStateException(s"topic ID map already contained ID for topic " + + s"$topic and new ID $id did not match existing ID $existingId") + } + topicNames.get(id).foreach { existingName => + if (!existingName.equals(topic)) + throw new IllegalStateException(s"topic name map already contained ID " + + s"$id and new name $topic did not match existing name $existingName") + } + topicIds.put(topic, id) + topicNames.put(id, topic) + } + + def partitionReplicaAssignment(topicPartition: TopicPartition): Seq[Int] = { + partitionAssignments.getOrElse(topicPartition.topic, mutable.Map.empty).get(topicPartition.partition) match { + case Some(partitionAssignment) => partitionAssignment.replicas + case None => Seq.empty + } + } + + def partitionFullReplicaAssignment(topicPartition: TopicPartition): ReplicaAssignment = { + partitionAssignments.getOrElse(topicPartition.topic, mutable.Map.empty) + .getOrElse(topicPartition.partition, ReplicaAssignment.empty) + } + + def updatePartitionFullReplicaAssignment(topicPartition: TopicPartition, newAssignment: ReplicaAssignment): Unit = { + val assignments = partitionAssignments.getOrElseUpdate(topicPartition.topic, mutable.Map.empty) + val previous = assignments.put(topicPartition.partition, newAssignment) + val leadershipInfo = partitionLeadershipInfo.get(topicPartition) + updatePreferredReplicaImbalanceMetric(topicPartition, previous, leadershipInfo, + Some(newAssignment), leadershipInfo) + } + + def partitionReplicaAssignmentForTopic(topic : String): Map[TopicPartition, Seq[Int]] = { + partitionAssignments.getOrElse(topic, Map.empty).map { + case (partition, assignment) => (new TopicPartition(topic, partition), assignment.replicas) + }.toMap + } + + def partitionFullReplicaAssignmentForTopic(topic : String): Map[TopicPartition, ReplicaAssignment] = { + partitionAssignments.getOrElse(topic, Map.empty).map { + case (partition, assignment) => (new TopicPartition(topic, partition), assignment) + }.toMap + } + + def allPartitions: Set[TopicPartition] = { + partitionAssignments.flatMap { + case (topic, topicReplicaAssignment) => topicReplicaAssignment.map { + case (partition, _) => new TopicPartition(topic, partition) + } + }.toSet + } + + def setLiveBrokers(brokerAndEpochs: Map[Broker, Long]): Unit = { + clearLiveBrokers() + addLiveBrokers(brokerAndEpochs) + } + + private def clearLiveBrokers(): Unit = { + liveBrokers.clear() + liveBrokerEpochs.clear() + } + + def 
addLiveBrokers(brokerAndEpochs: Map[Broker, Long]): Unit = { + liveBrokers ++= brokerAndEpochs.keySet + liveBrokerEpochs ++= brokerAndEpochs.map { case (broker, brokerEpoch) => (broker.id, brokerEpoch) } + } + + def removeLiveBrokers(brokerIds: Set[Int]): Unit = { + liveBrokers --= liveBrokers.filter(broker => brokerIds.contains(broker.id)) + liveBrokerEpochs --= brokerIds + } + + def updateBrokerMetadata(oldMetadata: Broker, newMetadata: Broker): Unit = { + liveBrokers -= oldMetadata + liveBrokers += newMetadata + } + + // getter + def liveBrokerIds: Set[Int] = liveBrokerEpochs.keySet.diff(shuttingDownBrokerIds) + def liveOrShuttingDownBrokerIds: Set[Int] = liveBrokerEpochs.keySet + def liveOrShuttingDownBrokers: Set[Broker] = liveBrokers + def liveBrokerIdAndEpochs: Map[Int, Long] = liveBrokerEpochs + def liveOrShuttingDownBroker(brokerId: Int): Option[Broker] = liveOrShuttingDownBrokers.find(_.id == brokerId) + + def partitionsOnBroker(brokerId: Int): Set[TopicPartition] = { + partitionAssignments.flatMap { + case (topic, topicReplicaAssignment) => topicReplicaAssignment.filter { + case (_, partitionAssignment) => partitionAssignment.replicas.contains(brokerId) + }.map { + case (partition, _) => new TopicPartition(topic, partition) + } + }.toSet + } + + def isReplicaOnline(brokerId: Int, topicPartition: TopicPartition, includeShuttingDownBrokers: Boolean = false): Boolean = { + val brokerOnline = { + if (includeShuttingDownBrokers) liveOrShuttingDownBrokerIds.contains(brokerId) + else liveBrokerIds.contains(brokerId) + } + brokerOnline && !replicasOnOfflineDirs.getOrElse(brokerId, Set.empty).contains(topicPartition) + } + + def replicasOnBrokers(brokerIds: Set[Int]): Set[PartitionAndReplica] = { + brokerIds.flatMap { brokerId => + partitionAssignments.flatMap { + case (topic, topicReplicaAssignment) => topicReplicaAssignment.collect { + case (partition, partitionAssignment) if partitionAssignment.replicas.contains(brokerId) => + PartitionAndReplica(new TopicPartition(topic, partition), brokerId) + } + } + } + } + + def replicasForTopic(topic: String): Set[PartitionAndReplica] = { + partitionAssignments.getOrElse(topic, mutable.Map.empty).flatMap { + case (partition, assignment) => assignment.replicas.map { r => + PartitionAndReplica(new TopicPartition(topic, partition), r) + } + }.toSet + } + + def partitionsForTopic(topic: String): collection.Set[TopicPartition] = { + partitionAssignments.getOrElse(topic, mutable.Map.empty).map { + case (partition, _) => new TopicPartition(topic, partition) + }.toSet + } + + /** + * Get all online and offline replicas. 
+ * + * @return a tuple consisting of first the online replicas and followed by the offline replicas + */ + def onlineAndOfflineReplicas: (Set[PartitionAndReplica], Set[PartitionAndReplica]) = { + val onlineReplicas = mutable.Set.empty[PartitionAndReplica] + val offlineReplicas = mutable.Set.empty[PartitionAndReplica] + for ((topic, partitionAssignments) <- partitionAssignments; + (partitionId, assignment) <- partitionAssignments) { + val partition = new TopicPartition(topic, partitionId) + for (replica <- assignment.replicas) { + val partitionAndReplica = PartitionAndReplica(partition, replica) + if (isReplicaOnline(replica, partition)) + onlineReplicas.add(partitionAndReplica) + else + offlineReplicas.add(partitionAndReplica) + } + } + (onlineReplicas, offlineReplicas) + } + + def replicasForPartition(partitions: collection.Set[TopicPartition]): collection.Set[PartitionAndReplica] = { + partitions.flatMap { p => + val replicas = partitionReplicaAssignment(p) + replicas.map(PartitionAndReplica(p, _)) + } + } + + def resetContext(): Unit = { + topicsToBeDeleted.clear() + topicsWithDeletionStarted.clear() + topicsIneligibleForDeletion.clear() + shuttingDownBrokerIds.clear() + epoch = 0 + epochZkVersion = 0 + clearTopicsState() + clearLiveBrokers() + } + + def setAllTopics(topics: Set[String]): Unit = { + allTopics.clear() + allTopics ++= topics + } + + def removeTopic(topic: String): Unit = { + // Metric is cleaned when the topic is queued up for deletion so + // we don't clean it twice. We clean it only if it is deleted + // directly. + if (!topicsToBeDeleted.contains(topic)) + cleanPreferredReplicaImbalanceMetric(topic) + topicsToBeDeleted -= topic + topicsWithDeletionStarted -= topic + allTopics -= topic + topicIds.remove(topic).foreach { topicId => + topicNames.remove(topicId) + } + partitionAssignments.remove(topic).foreach { assignments => + assignments.keys.foreach { partition => + partitionLeadershipInfo.remove(new TopicPartition(topic, partition)) + } + } + } + + def queueTopicDeletion(topics: Set[String]): Unit = { + topicsToBeDeleted ++= topics + topics.foreach(cleanPreferredReplicaImbalanceMetric) + } + + def beginTopicDeletion(topics: Set[String]): Unit = { + topicsWithDeletionStarted ++= topics + } + + def isTopicDeletionInProgress(topic: String): Boolean = { + topicsWithDeletionStarted.contains(topic) + } + + def isTopicQueuedUpForDeletion(topic: String): Boolean = { + topicsToBeDeleted.contains(topic) + } + + def isTopicEligibleForDeletion(topic: String): Boolean = { + topicsToBeDeleted.contains(topic) && !topicsIneligibleForDeletion.contains(topic) + } + + def topicsQueuedForDeletion: Set[String] = { + topicsToBeDeleted + } + + def replicasInState(topic: String, state: ReplicaState): Set[PartitionAndReplica] = { + replicasForTopic(topic).filter(replica => replicaStates(replica) == state).toSet + } + + def areAllReplicasInState(topic: String, state: ReplicaState): Boolean = { + replicasForTopic(topic).forall(replica => replicaStates(replica) == state) + } + + def isAnyReplicaInState(topic: String, state: ReplicaState): Boolean = { + replicasForTopic(topic).exists(replica => replicaStates(replica) == state) + } + + def checkValidReplicaStateChange(replicas: Seq[PartitionAndReplica], targetState: ReplicaState): (Seq[PartitionAndReplica], Seq[PartitionAndReplica]) = { + replicas.partition(replica => isValidReplicaStateTransition(replica, targetState)) + } + + def checkValidPartitionStateChange(partitions: Seq[TopicPartition], targetState: PartitionState): (Seq[TopicPartition], 
Seq[TopicPartition]) = { + partitions.partition(p => isValidPartitionStateTransition(p, targetState)) + } + + def putReplicaState(replica: PartitionAndReplica, state: ReplicaState): Unit = { + replicaStates.put(replica, state) + } + + def removeReplicaState(replica: PartitionAndReplica): Unit = { + replicaStates.remove(replica) + } + + def putReplicaStateIfNotExists(replica: PartitionAndReplica, state: ReplicaState): Unit = { + replicaStates.getOrElseUpdate(replica, state) + } + + def putPartitionState(partition: TopicPartition, targetState: PartitionState): Unit = { + val currentState = partitionStates.put(partition, targetState).getOrElse(NonExistentPartition) + updatePartitionStateMetrics(partition, currentState, targetState) + } + + private def updatePartitionStateMetrics(partition: TopicPartition, + currentState: PartitionState, + targetState: PartitionState): Unit = { + if (!isTopicDeletionInProgress(partition.topic)) { + if (currentState != OfflinePartition && targetState == OfflinePartition) { + offlinePartitionCount = offlinePartitionCount + 1 + } else if (currentState == OfflinePartition && targetState != OfflinePartition) { + offlinePartitionCount = offlinePartitionCount - 1 + } + } + } + + def putPartitionStateIfNotExists(partition: TopicPartition, state: PartitionState): Unit = { + if (partitionStates.getOrElseUpdate(partition, state) == state) + updatePartitionStateMetrics(partition, NonExistentPartition, state) + } + + def replicaState(replica: PartitionAndReplica): ReplicaState = { + replicaStates(replica) + } + + def partitionState(partition: TopicPartition): PartitionState = { + partitionStates(partition) + } + + def partitionsInState(state: PartitionState): Set[TopicPartition] = { + partitionStates.filter { case (_, s) => s == state }.keySet.toSet + } + + def partitionsInStates(states: Set[PartitionState]): Set[TopicPartition] = { + partitionStates.filter { case (_, s) => states.contains(s) }.keySet.toSet + } + + def partitionsInState(topic: String, state: PartitionState): Set[TopicPartition] = { + partitionsForTopic(topic).filter { partition => state == partitionState(partition) }.toSet + } + + def partitionsInStates(topic: String, states: Set[PartitionState]): Set[TopicPartition] = { + partitionsForTopic(topic).filter { partition => states.contains(partitionState(partition)) }.toSet + } + + def putPartitionLeadershipInfo(partition: TopicPartition, + leaderIsrAndControllerEpoch: LeaderIsrAndControllerEpoch): Unit = { + val previous = partitionLeadershipInfo.put(partition, leaderIsrAndControllerEpoch) + val replicaAssignment = partitionFullReplicaAssignment(partition) + updatePreferredReplicaImbalanceMetric(partition, Some(replicaAssignment), previous, + Some(replicaAssignment), Some(leaderIsrAndControllerEpoch)) + } + + def partitionLeadershipInfo(partition: TopicPartition): Option[LeaderIsrAndControllerEpoch] = { + partitionLeadershipInfo.get(partition) + } + + def partitionsLeadershipInfo: Map[TopicPartition, LeaderIsrAndControllerEpoch] = + partitionLeadershipInfo + + def partitionsWithLeaders: Set[TopicPartition] = + partitionLeadershipInfo.keySet.filter(tp => !isTopicQueuedUpForDeletion(tp.topic)) + + def partitionsWithOfflineLeader: Set[TopicPartition] = { + partitionLeadershipInfo.filter { case (topicPartition, leaderIsrAndControllerEpoch) => + !isReplicaOnline(leaderIsrAndControllerEpoch.leaderAndIsr.leader, topicPartition) && + !isTopicQueuedUpForDeletion(topicPartition.topic) + }.keySet + } + + def partitionLeadersOnBroker(brokerId: Int): Set[TopicPartition] = 
{ + partitionLeadershipInfo.filter { case (topicPartition, leaderIsrAndControllerEpoch) => + !isTopicQueuedUpForDeletion(topicPartition.topic) && + leaderIsrAndControllerEpoch.leaderAndIsr.leader == brokerId && + partitionReplicaAssignment(topicPartition).size > 1 + }.keySet + } + + def topicName(topicId: Uuid): Option[String] = { + topicNames.get(topicId) + } + + def clearPartitionLeadershipInfo(): Unit = partitionLeadershipInfo.clear() + + def partitionWithLeadersCount: Int = partitionLeadershipInfo.size + + private def updatePreferredReplicaImbalanceMetric(partition: TopicPartition, + oldReplicaAssignment: Option[ReplicaAssignment], + oldLeadershipInfo: Option[LeaderIsrAndControllerEpoch], + newReplicaAssignment: Option[ReplicaAssignment], + newLeadershipInfo: Option[LeaderIsrAndControllerEpoch]): Unit = { + if (!isTopicQueuedUpForDeletion(partition.topic)) { + oldReplicaAssignment.foreach { replicaAssignment => + oldLeadershipInfo.foreach { leadershipInfo => + if (!hasPreferredLeader(replicaAssignment, leadershipInfo)) + preferredReplicaImbalanceCount -= 1 + } + } + + newReplicaAssignment.foreach { replicaAssignment => + newLeadershipInfo.foreach { leadershipInfo => + if (!hasPreferredLeader(replicaAssignment, leadershipInfo)) + preferredReplicaImbalanceCount += 1 + } + } + } + } + + private def cleanPreferredReplicaImbalanceMetric(topic: String): Unit = { + partitionAssignments.getOrElse(topic, mutable.Map.empty).forKeyValue { (partition, replicaAssignment) => + partitionLeadershipInfo.get(new TopicPartition(topic, partition)).foreach { leadershipInfo => + if (!hasPreferredLeader(replicaAssignment, leadershipInfo)) + preferredReplicaImbalanceCount -= 1 + } + } + } + + private def hasPreferredLeader(replicaAssignment: ReplicaAssignment, + leadershipInfo: LeaderIsrAndControllerEpoch): Boolean = { + val preferredReplica = replicaAssignment.replicas.head + if (replicaAssignment.isBeingReassigned && replicaAssignment.addingReplicas.contains(preferredReplica)) + // reassigning partitions are not counted as imbalanced until the new replica joins the ISR (completes reassignment) + !leadershipInfo.leaderAndIsr.isr.contains(preferredReplica) + else + leadershipInfo.leaderAndIsr.leader == preferredReplica + } + + private def isValidReplicaStateTransition(replica: PartitionAndReplica, targetState: ReplicaState): Boolean = + targetState.validPreviousStates.contains(replicaStates(replica)) + + private def isValidPartitionStateTransition(partition: TopicPartition, targetState: PartitionState): Boolean = + targetState.validPreviousStates.contains(partitionStates(partition)) + +} diff --git a/core/src/main/scala/kafka/controller/ControllerEventManager.scala b/core/src/main/scala/kafka/controller/ControllerEventManager.scala new file mode 100644 index 0000000..b5ae3ff --- /dev/null +++ b/core/src/main/scala/kafka/controller/ControllerEventManager.scala @@ -0,0 +1,160 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.controller + +import java.util.ArrayList +import java.util.concurrent.atomic.AtomicBoolean +import java.util.concurrent.{CountDownLatch, LinkedBlockingQueue, TimeUnit} +import java.util.concurrent.locks.ReentrantLock + +import kafka.metrics.{KafkaMetricsGroup, KafkaTimer} +import kafka.utils.CoreUtils.inLock +import kafka.utils.ShutdownableThread +import org.apache.kafka.common.utils.Time + +import scala.collection._ + +object ControllerEventManager { + val ControllerEventThreadName = "controller-event-thread" + val EventQueueTimeMetricName = "EventQueueTimeMs" + val EventQueueSizeMetricName = "EventQueueSize" +} + +trait ControllerEventProcessor { + def process(event: ControllerEvent): Unit + def preempt(event: ControllerEvent): Unit +} + +class QueuedEvent(val event: ControllerEvent, + val enqueueTimeMs: Long) { + val processingStarted = new CountDownLatch(1) + val spent = new AtomicBoolean(false) + + def process(processor: ControllerEventProcessor): Unit = { + if (spent.getAndSet(true)) + return + processingStarted.countDown() + processor.process(event) + } + + def preempt(processor: ControllerEventProcessor): Unit = { + if (spent.getAndSet(true)) + return + processor.preempt(event) + } + + def awaitProcessing(): Unit = { + processingStarted.await() + } + + override def toString: String = { + s"QueuedEvent(event=$event, enqueueTimeMs=$enqueueTimeMs)" + } +} + +class ControllerEventManager(controllerId: Int, + processor: ControllerEventProcessor, + time: Time, + rateAndTimeMetrics: Map[ControllerState, KafkaTimer], + eventQueueTimeTimeoutMs: Long = 300000) extends KafkaMetricsGroup { + import ControllerEventManager._ + + @volatile private var _state: ControllerState = ControllerState.Idle + private val putLock = new ReentrantLock() + private val queue = new LinkedBlockingQueue[QueuedEvent] + // Visible for test + private[controller] var thread = new ControllerEventThread(ControllerEventThreadName) + + private val eventQueueTimeHist = newHistogram(EventQueueTimeMetricName) + + newGauge(EventQueueSizeMetricName, () => queue.size) + + def state: ControllerState = _state + + def start(): Unit = thread.start() + + def close(): Unit = { + try { + thread.initiateShutdown() + clearAndPut(ShutdownEventThread) + thread.awaitShutdown() + } finally { + removeMetric(EventQueueTimeMetricName) + removeMetric(EventQueueSizeMetricName) + } + } + + def put(event: ControllerEvent): QueuedEvent = inLock(putLock) { + val queuedEvent = new QueuedEvent(event, time.milliseconds()) + queue.put(queuedEvent) + queuedEvent + } + + def clearAndPut(event: ControllerEvent): QueuedEvent = inLock(putLock){ + val preemptedEvents = new ArrayList[QueuedEvent]() + queue.drainTo(preemptedEvents) + preemptedEvents.forEach(_.preempt(processor)) + put(event) + } + + def isEmpty: Boolean = queue.isEmpty + + class ControllerEventThread(name: String) extends ShutdownableThread(name = name, isInterruptible = false) { + logIdent = s"[ControllerEventThread controllerId=$controllerId] " + + override def doWork(): Unit = { + val dequeued = pollFromEventQueue() + dequeued.event match { + case 
ShutdownEventThread => // The shutting down of the thread has been initiated at this point. Ignore this event. + case controllerEvent => + _state = controllerEvent.state + + eventQueueTimeHist.update(time.milliseconds() - dequeued.enqueueTimeMs) + + try { + def process(): Unit = dequeued.process(processor) + + rateAndTimeMetrics.get(state) match { + case Some(timer) => timer.time { process() } + case None => process() + } + } catch { + case e: Throwable => error(s"Uncaught error processing event $controllerEvent", e) + } + + _state = ControllerState.Idle + } + } + } + + private def pollFromEventQueue(): QueuedEvent = { + val count = eventQueueTimeHist.count() + if (count != 0) { + val event = queue.poll(eventQueueTimeTimeoutMs, TimeUnit.MILLISECONDS) + if (event == null) { + eventQueueTimeHist.clear() + queue.take() + } else { + event + } + } else { + queue.take() + } + } + +} diff --git a/core/src/main/scala/kafka/controller/ControllerState.scala b/core/src/main/scala/kafka/controller/ControllerState.scala new file mode 100644 index 0000000..f842405 --- /dev/null +++ b/core/src/main/scala/kafka/controller/ControllerState.scala @@ -0,0 +1,122 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.controller + +import scala.collection.Seq + +sealed abstract class ControllerState { + + def value: Byte + + def rateAndTimeMetricName: Option[String] = + if (hasRateAndTimeMetric) Some(s"${toString}RateAndTimeMs") else None + + protected def hasRateAndTimeMetric: Boolean = true +} + +object ControllerState { + + // Note: `rateAndTimeMetricName` is based on the case object name by default. Changing a name is a breaking change + // unless `rateAndTimeMetricName` is overridden. + + case object Idle extends ControllerState { + def value = 0 + override protected def hasRateAndTimeMetric: Boolean = false + } + + case object ControllerChange extends ControllerState { + def value = 1 + } + + case object BrokerChange extends ControllerState { + def value = 2 + // The LeaderElectionRateAndTimeMs metric existed before `ControllerState` was introduced and we keep the name + // for backwards compatibility. The alternative would be to have the same metric under two different names. 
+ override def rateAndTimeMetricName = Some("LeaderElectionRateAndTimeMs") + } + + case object TopicChange extends ControllerState { + def value = 3 + } + + case object TopicDeletion extends ControllerState { + def value = 4 + } + + case object AlterPartitionReassignment extends ControllerState { + def value = 5 + + override def rateAndTimeMetricName: Option[String] = Some("PartitionReassignmentRateAndTimeMs") + } + + case object AutoLeaderBalance extends ControllerState { + def value = 6 + } + + case object ManualLeaderBalance extends ControllerState { + def value = 7 + } + + case object ControlledShutdown extends ControllerState { + def value = 8 + } + + case object IsrChange extends ControllerState { + def value = 9 + } + + case object LeaderAndIsrResponseReceived extends ControllerState { + def value = 10 + } + + case object LogDirChange extends ControllerState { + def value = 11 + } + + case object ControllerShutdown extends ControllerState { + def value = 12 + } + + case object UncleanLeaderElectionEnable extends ControllerState { + def value = 13 + } + + case object TopicUncleanLeaderElectionEnable extends ControllerState { + def value = 14 + } + + case object ListPartitionReassignment extends ControllerState { + def value = 15 + } + + case object UpdateMetadataResponseReceived extends ControllerState { + def value = 16 + + override protected def hasRateAndTimeMetric: Boolean = false + } + + case object UpdateFeatures extends ControllerState { + def value = 17 + } + + val values: Seq[ControllerState] = Seq(Idle, ControllerChange, BrokerChange, TopicChange, TopicDeletion, + AlterPartitionReassignment, AutoLeaderBalance, ManualLeaderBalance, ControlledShutdown, IsrChange, + LeaderAndIsrResponseReceived, LogDirChange, ControllerShutdown, UncleanLeaderElectionEnable, + TopicUncleanLeaderElectionEnable, ListPartitionReassignment, UpdateMetadataResponseReceived, + UpdateFeatures) +} diff --git a/core/src/main/scala/kafka/controller/Election.scala b/core/src/main/scala/kafka/controller/Election.scala new file mode 100644 index 0000000..dffa888 --- /dev/null +++ b/core/src/main/scala/kafka/controller/Election.scala @@ -0,0 +1,157 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package kafka.controller + +import kafka.api.LeaderAndIsr +import org.apache.kafka.common.TopicPartition + +import scala.collection.Seq + +case class ElectionResult(topicPartition: TopicPartition, leaderAndIsr: Option[LeaderAndIsr], liveReplicas: Seq[Int]) + +object Election { + + private def leaderForOffline(partition: TopicPartition, + leaderAndIsrOpt: Option[LeaderAndIsr], + uncleanLeaderElectionEnabled: Boolean, + controllerContext: ControllerContext): ElectionResult = { + + val assignment = controllerContext.partitionReplicaAssignment(partition) + val liveReplicas = assignment.filter(replica => controllerContext.isReplicaOnline(replica, partition)) + leaderAndIsrOpt match { + case Some(leaderAndIsr) => + val isr = leaderAndIsr.isr + val leaderOpt = PartitionLeaderElectionAlgorithms.offlinePartitionLeaderElection( + assignment, isr, liveReplicas.toSet, uncleanLeaderElectionEnabled, controllerContext) + val newLeaderAndIsrOpt = leaderOpt.map { leader => + val newIsr = if (isr.contains(leader)) isr.filter(replica => controllerContext.isReplicaOnline(replica, partition)) + else List(leader) + leaderAndIsr.newLeaderAndIsr(leader, newIsr) + } + ElectionResult(partition, newLeaderAndIsrOpt, liveReplicas) + + case None => + ElectionResult(partition, None, liveReplicas) + } + } + + /** + * Elect leaders for new or offline partitions. + * + * @param controllerContext Context with the current state of the cluster + * @param partitionsWithUncleanLeaderElectionState A sequence of tuples representing the partitions + * that need election, their leader/ISR state, and whether + * or not unclean leader election is enabled + * + * @return The election results + */ + def leaderForOffline( + controllerContext: ControllerContext, + partitionsWithUncleanLeaderElectionState: Seq[(TopicPartition, Option[LeaderAndIsr], Boolean)] + ): Seq[ElectionResult] = { + partitionsWithUncleanLeaderElectionState.map { + case (partition, leaderAndIsrOpt, uncleanLeaderElectionEnabled) => + leaderForOffline(partition, leaderAndIsrOpt, uncleanLeaderElectionEnabled, controllerContext) + } + } + + private def leaderForReassign(partition: TopicPartition, + leaderAndIsr: LeaderAndIsr, + controllerContext: ControllerContext): ElectionResult = { + val targetReplicas = controllerContext.partitionFullReplicaAssignment(partition).targetReplicas + val liveReplicas = targetReplicas.filter(replica => controllerContext.isReplicaOnline(replica, partition)) + val isr = leaderAndIsr.isr + val leaderOpt = PartitionLeaderElectionAlgorithms.reassignPartitionLeaderElection(targetReplicas, isr, liveReplicas.toSet) + val newLeaderAndIsrOpt = leaderOpt.map(leader => leaderAndIsr.newLeader(leader)) + ElectionResult(partition, newLeaderAndIsrOpt, targetReplicas) + } + + /** + * Elect leaders for partitions that are undergoing reassignment. 
+ * + * @param controllerContext Context with the current state of the cluster + * @param leaderAndIsrs A sequence of tuples representing the partitions that need election + * and their respective leader/ISR states + * + * @return The election results + */ + def leaderForReassign(controllerContext: ControllerContext, + leaderAndIsrs: Seq[(TopicPartition, LeaderAndIsr)]): Seq[ElectionResult] = { + leaderAndIsrs.map { case (partition, leaderAndIsr) => + leaderForReassign(partition, leaderAndIsr, controllerContext) + } + } + + private def leaderForPreferredReplica(partition: TopicPartition, + leaderAndIsr: LeaderAndIsr, + controllerContext: ControllerContext): ElectionResult = { + val assignment = controllerContext.partitionReplicaAssignment(partition) + val liveReplicas = assignment.filter(replica => controllerContext.isReplicaOnline(replica, partition)) + val isr = leaderAndIsr.isr + val leaderOpt = PartitionLeaderElectionAlgorithms.preferredReplicaPartitionLeaderElection(assignment, isr, liveReplicas.toSet) + val newLeaderAndIsrOpt = leaderOpt.map(leader => leaderAndIsr.newLeader(leader)) + ElectionResult(partition, newLeaderAndIsrOpt, assignment) + } + + /** + * Elect preferred leaders. + * + * @param controllerContext Context with the current state of the cluster + * @param leaderAndIsrs A sequence of tuples representing the partitions that need election + * and their respective leader/ISR states + * + * @return The election results + */ + def leaderForPreferredReplica(controllerContext: ControllerContext, + leaderAndIsrs: Seq[(TopicPartition, LeaderAndIsr)]): Seq[ElectionResult] = { + leaderAndIsrs.map { case (partition, leaderAndIsr) => + leaderForPreferredReplica(partition, leaderAndIsr, controllerContext) + } + } + + private def leaderForControlledShutdown(partition: TopicPartition, + leaderAndIsr: LeaderAndIsr, + shuttingDownBrokerIds: Set[Int], + controllerContext: ControllerContext): ElectionResult = { + val assignment = controllerContext.partitionReplicaAssignment(partition) + val liveOrShuttingDownReplicas = assignment.filter(replica => + controllerContext.isReplicaOnline(replica, partition, includeShuttingDownBrokers = true)) + val isr = leaderAndIsr.isr + val leaderOpt = PartitionLeaderElectionAlgorithms.controlledShutdownPartitionLeaderElection(assignment, isr, + liveOrShuttingDownReplicas.toSet, shuttingDownBrokerIds) + val newIsr = isr.filter(replica => !shuttingDownBrokerIds.contains(replica)) + val newLeaderAndIsrOpt = leaderOpt.map(leader => leaderAndIsr.newLeaderAndIsr(leader, newIsr)) + ElectionResult(partition, newLeaderAndIsrOpt, liveOrShuttingDownReplicas) + } + + /** + * Elect leaders for partitions whose current leaders are shutting down. 
+ * + * @param controllerContext Context with the current state of the cluster + * @param leaderAndIsrs A sequence of tuples representing the partitions that need election + * and their respective leader/ISR states + * + * @return The election results + */ + def leaderForControlledShutdown(controllerContext: ControllerContext, + leaderAndIsrs: Seq[(TopicPartition, LeaderAndIsr)]): Seq[ElectionResult] = { + val shuttingDownBrokerIds = controllerContext.shuttingDownBrokerIds.toSet + leaderAndIsrs.map { case (partition, leaderAndIsr) => + leaderForControlledShutdown(partition, leaderAndIsr, shuttingDownBrokerIds, controllerContext) + } + } +} diff --git a/core/src/main/scala/kafka/controller/KafkaController.scala b/core/src/main/scala/kafka/controller/KafkaController.scala new file mode 100644 index 0000000..9afefe3 --- /dev/null +++ b/core/src/main/scala/kafka/controller/KafkaController.scala @@ -0,0 +1,2831 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package kafka.controller + +import java.util +import java.util.concurrent.TimeUnit +import kafka.admin.AdminOperationException +import kafka.api._ +import kafka.common._ +import kafka.controller.KafkaController.AlterIsrCallback +import kafka.cluster.Broker +import kafka.controller.KafkaController.{AlterReassignmentsCallback, ElectLeadersCallback, ListReassignmentsCallback, UpdateFeaturesCallback} +import kafka.coordinator.transaction.ZkProducerIdManager +import kafka.metrics.{KafkaMetricsGroup, KafkaTimer} +import kafka.server._ +import kafka.utils._ +import kafka.utils.Implicits._ +import kafka.zk.KafkaZkClient.UpdateLeaderAndIsrResult +import kafka.zk.TopicZNode.TopicIdReplicaAssignment +import kafka.zk.{FeatureZNodeStatus, _} +import kafka.zookeeper.{StateChangeHandler, ZNodeChangeHandler, ZNodeChildChangeHandler} +import org.apache.kafka.common.ElectionType +import org.apache.kafka.common.KafkaException +import org.apache.kafka.common.TopicPartition +import org.apache.kafka.common.errors.{BrokerNotAvailableException, ControllerMovedException, StaleBrokerEpochException} +import org.apache.kafka.common.message.{AllocateProducerIdsRequestData, AllocateProducerIdsResponseData, AlterIsrRequestData, AlterIsrResponseData, UpdateFeaturesRequestData} +import org.apache.kafka.common.feature.{Features, FinalizedVersionRange} +import org.apache.kafka.common.metrics.Metrics +import org.apache.kafka.common.protocol.Errors +import org.apache.kafka.common.requests.{AbstractControlRequest, ApiError, LeaderAndIsrResponse, UpdateFeaturesRequest, UpdateMetadataResponse} +import org.apache.kafka.common.utils.{Time, Utils} +import org.apache.kafka.server.common.ProducerIdsBlock +import org.apache.zookeeper.KeeperException +import org.apache.zookeeper.KeeperException.Code + +import scala.collection.{Map, 
Seq, Set, immutable, mutable} +import scala.collection.mutable.ArrayBuffer +import scala.jdk.CollectionConverters._ +import scala.util.{Failure, Success, Try} + +sealed trait ElectionTrigger +final case object AutoTriggered extends ElectionTrigger +final case object ZkTriggered extends ElectionTrigger +final case object AdminClientTriggered extends ElectionTrigger + +object KafkaController extends Logging { + val InitialControllerEpoch = 0 + val InitialControllerEpochZkVersion = 0 + + type ElectLeadersCallback = Map[TopicPartition, Either[ApiError, Int]] => Unit + type ListReassignmentsCallback = Either[Map[TopicPartition, ReplicaAssignment], ApiError] => Unit + type AlterReassignmentsCallback = Either[Map[TopicPartition, ApiError], ApiError] => Unit + type AlterIsrCallback = Either[Map[TopicPartition, Either[Errors, LeaderAndIsr]], Errors] => Unit + type UpdateFeaturesCallback = Either[ApiError, Map[String, ApiError]] => Unit +} + +class KafkaController(val config: KafkaConfig, + zkClient: KafkaZkClient, + time: Time, + metrics: Metrics, + initialBrokerInfo: BrokerInfo, + initialBrokerEpoch: Long, + tokenManager: DelegationTokenManager, + brokerFeatures: BrokerFeatures, + featureCache: FinalizedFeatureCache, + threadNamePrefix: Option[String] = None) + extends ControllerEventProcessor with Logging with KafkaMetricsGroup { + + this.logIdent = s"[Controller id=${config.brokerId}] " + + @volatile private var brokerInfo = initialBrokerInfo + @volatile private var _brokerEpoch = initialBrokerEpoch + + private val isAlterIsrEnabled = config.interBrokerProtocolVersion.isAlterIsrSupported + private val stateChangeLogger = new StateChangeLogger(config.brokerId, inControllerContext = true, None) + val controllerContext = new ControllerContext + var controllerChannelManager = new ControllerChannelManager(controllerContext, config, time, metrics, + stateChangeLogger, threadNamePrefix) + + // have a separate scheduler for the controller to be able to start and stop independently of the kafka server + // visible for testing + private[controller] val kafkaScheduler = new KafkaScheduler(1) + + // visible for testing + private[controller] val eventManager = new ControllerEventManager(config.brokerId, this, time, + controllerContext.stats.rateAndTimeMetrics) + + private val brokerRequestBatch = new ControllerBrokerRequestBatch(config, controllerChannelManager, + eventManager, controllerContext, stateChangeLogger) + val replicaStateMachine: ReplicaStateMachine = new ZkReplicaStateMachine(config, stateChangeLogger, controllerContext, zkClient, + new ControllerBrokerRequestBatch(config, controllerChannelManager, eventManager, controllerContext, stateChangeLogger)) + val partitionStateMachine: PartitionStateMachine = new ZkPartitionStateMachine(config, stateChangeLogger, controllerContext, zkClient, + new ControllerBrokerRequestBatch(config, controllerChannelManager, eventManager, controllerContext, stateChangeLogger)) + val topicDeletionManager = new TopicDeletionManager(config, controllerContext, replicaStateMachine, + partitionStateMachine, new ControllerDeletionClient(this, zkClient)) + + private val controllerChangeHandler = new ControllerChangeHandler(eventManager) + private val brokerChangeHandler = new BrokerChangeHandler(eventManager) + private val brokerModificationsHandlers: mutable.Map[Int, BrokerModificationsHandler] = mutable.Map.empty + private val topicChangeHandler = new TopicChangeHandler(eventManager) + private val topicDeletionHandler = new TopicDeletionHandler(eventManager) + private val 
partitionModificationsHandlers: mutable.Map[String, PartitionModificationsHandler] = mutable.Map.empty + private val partitionReassignmentHandler = new PartitionReassignmentHandler(eventManager) + private val preferredReplicaElectionHandler = new PreferredReplicaElectionHandler(eventManager) + private val isrChangeNotificationHandler = new IsrChangeNotificationHandler(eventManager) + private val logDirEventNotificationHandler = new LogDirEventNotificationHandler(eventManager) + + @volatile private var activeControllerId = -1 + @volatile private var offlinePartitionCount = 0 + @volatile private var preferredReplicaImbalanceCount = 0 + @volatile private var globalTopicCount = 0 + @volatile private var globalPartitionCount = 0 + @volatile private var topicsToDeleteCount = 0 + @volatile private var replicasToDeleteCount = 0 + @volatile private var ineligibleTopicsToDeleteCount = 0 + @volatile private var ineligibleReplicasToDeleteCount = 0 + @volatile private var activeBrokerCount = 0 + + /* single-thread scheduler to clean expired tokens */ + private val tokenCleanScheduler = new KafkaScheduler(threads = 1, threadNamePrefix = "delegation-token-cleaner") + + newGauge("ActiveControllerCount", () => if (isActive) 1 else 0) + newGauge("OfflinePartitionsCount", () => offlinePartitionCount) + newGauge("PreferredReplicaImbalanceCount", () => preferredReplicaImbalanceCount) + newGauge("ControllerState", () => state.value) + newGauge("GlobalTopicCount", () => globalTopicCount) + newGauge("GlobalPartitionCount", () => globalPartitionCount) + newGauge("TopicsToDeleteCount", () => topicsToDeleteCount) + newGauge("ReplicasToDeleteCount", () => replicasToDeleteCount) + newGauge("TopicsIneligibleToDeleteCount", () => ineligibleTopicsToDeleteCount) + newGauge("ReplicasIneligibleToDeleteCount", () => ineligibleReplicasToDeleteCount) + newGauge("ActiveBrokerCount", () => activeBrokerCount) + // FencedBrokerCount metric is always 0 in the ZK controller. + newGauge("FencedBrokerCount", () => 0) + + /** + * Returns true if this broker is the current controller. + */ + def isActive: Boolean = activeControllerId == config.brokerId + + def brokerEpoch: Long = _brokerEpoch + + def epoch: Int = controllerContext.epoch + + /** + * Invoked when the controller module of a Kafka server is started up. This does not assume that the current broker + * is the controller. It merely registers the session expiration listener and starts the controller leader + * elector + */ + def startup() = { + zkClient.registerStateChangeHandler(new StateChangeHandler { + override val name: String = StateChangeHandlers.ControllerHandler + override def afterInitializingSession(): Unit = { + eventManager.put(RegisterBrokerAndReelect) + } + override def beforeInitializingSession(): Unit = { + val queuedEvent = eventManager.clearAndPut(Expire) + + // Block initialization of the new session until the expiration event is being handled, + // which ensures that all pending events have been processed before creating the new session + queuedEvent.awaitProcessing() + } + }) + eventManager.put(Startup) + eventManager.start() + } + + /** + * Invoked when the controller module of a Kafka server is shutting down. If the broker was the current controller, + * it shuts down the partition and replica state machines. If not, those are a no-op. In addition to that, it also + * shuts down the controller channel manager, if one exists (i.e. 
if it was the current controller) + */ + def shutdown(): Unit = { + eventManager.close() + onControllerResignation() + } + + /** + * On controlled shutdown, the controller first determines the partitions that the + * shutting down broker leads, and moves leadership of those partitions to another broker + * that is in that partition's ISR. + * + * @param id Id of the broker to shutdown. + * @param brokerEpoch The broker epoch in the controlled shutdown request + * @return The number of partitions that the broker still leads. + */ + def controlledShutdown(id: Int, brokerEpoch: Long, controlledShutdownCallback: Try[Set[TopicPartition]] => Unit): Unit = { + val controlledShutdownEvent = ControlledShutdown(id, brokerEpoch, controlledShutdownCallback) + eventManager.put(controlledShutdownEvent) + } + + private[kafka] def updateBrokerInfo(newBrokerInfo: BrokerInfo): Unit = { + this.brokerInfo = newBrokerInfo + zkClient.updateBrokerInfo(newBrokerInfo) + } + + private[kafka] def enableDefaultUncleanLeaderElection(): Unit = { + eventManager.put(UncleanLeaderElectionEnable) + } + + private[kafka] def enableTopicUncleanLeaderElection(topic: String): Unit = { + if (isActive) { + eventManager.put(TopicUncleanLeaderElectionEnable(topic)) + } + } + + private def state: ControllerState = eventManager.state + + /** + * This callback is invoked by the zookeeper leader elector on electing the current broker as the new controller. + * It does the following things on the become-controller state change - + * 1. Initializes the controller's context object that holds cache objects for current topics, live brokers and + * leaders for all existing partitions. + * 2. Starts the controller's channel manager + * 3. Starts the replica state machine + * 4. Starts the partition state machine + * If it encounters any unexpected exception/error while becoming controller, it resigns as the current controller. + * This ensures another controller election will be triggered and there will always be an actively serving controller + */ + private def onControllerFailover(): Unit = { + maybeSetupFeatureVersioning() + + info("Registering handlers") + + // before reading source of truth from zookeeper, register the listeners to get broker/topic callbacks + val childChangeHandlers = Seq(brokerChangeHandler, topicChangeHandler, topicDeletionHandler, logDirEventNotificationHandler, + isrChangeNotificationHandler) + childChangeHandlers.foreach(zkClient.registerZNodeChildChangeHandler) + + val nodeChangeHandlers = Seq(preferredReplicaElectionHandler, partitionReassignmentHandler) + nodeChangeHandlers.foreach(zkClient.registerZNodeChangeHandlerAndCheckExistence) + + info("Deleting log dir event notifications") + zkClient.deleteLogDirEventNotifications(controllerContext.epochZkVersion) + info("Deleting isr change notifications") + zkClient.deleteIsrChangeNotifications(controllerContext.epochZkVersion) + info("Initializing controller context") + initializeControllerContext() + info("Fetching topic deletions in progress") + val (topicsToBeDeleted, topicsIneligibleForDeletion) = fetchTopicDeletionsInProgress() + info("Initializing topic deletion manager") + topicDeletionManager.init(topicsToBeDeleted, topicsIneligibleForDeletion) + + // We need to send UpdateMetadataRequest after the controller context is initialized and before the state machines + // are started. 
The is because brokers need to receive the list of live brokers from UpdateMetadataRequest before + // they can process the LeaderAndIsrRequests that are generated by replicaStateMachine.startup() and + // partitionStateMachine.startup(). + info("Sending update metadata request") + sendUpdateMetadataRequest(controllerContext.liveOrShuttingDownBrokerIds.toSeq, Set.empty) + + replicaStateMachine.startup() + partitionStateMachine.startup() + + info(s"Ready to serve as the new controller with epoch $epoch") + + initializePartitionReassignments() + topicDeletionManager.tryTopicDeletion() + val pendingPreferredReplicaElections = fetchPendingPreferredReplicaElections() + onReplicaElection(pendingPreferredReplicaElections, ElectionType.PREFERRED, ZkTriggered) + info("Starting the controller scheduler") + kafkaScheduler.startup() + if (config.autoLeaderRebalanceEnable) { + scheduleAutoLeaderRebalanceTask(delay = 5, unit = TimeUnit.SECONDS) + } + + if (config.tokenAuthEnabled) { + info("starting the token expiry check scheduler") + tokenCleanScheduler.startup() + tokenCleanScheduler.schedule(name = "delete-expired-tokens", + fun = () => tokenManager.expireTokens(), + period = config.delegationTokenExpiryCheckIntervalMs, + unit = TimeUnit.MILLISECONDS) + } + } + + private def createFeatureZNode(newNode: FeatureZNode): Int = { + info(s"Creating FeatureZNode at path: ${FeatureZNode.path} with contents: $newNode") + zkClient.createFeatureZNode(newNode) + val (_, newVersion) = zkClient.getDataAndVersion(FeatureZNode.path) + newVersion + } + + private def updateFeatureZNode(updatedNode: FeatureZNode): Int = { + info(s"Updating FeatureZNode at path: ${FeatureZNode.path} with contents: $updatedNode") + zkClient.updateFeatureZNode(updatedNode) + } + + /** + * This method enables the feature versioning system (KIP-584). + * + * Development in Kafka (from a high level) is organized into features. Each feature is tracked by + * a name and a range of version numbers. A feature can be of two types: + * + * 1. Supported feature: + * A supported feature is represented by a name (string) and a range of versions (defined by a + * SupportedVersionRange). It refers to a feature that a particular broker advertises support for. + * Each broker advertises the version ranges of its own supported features in its own + * BrokerIdZNode. The contents of the advertisement are specific to the particular broker and + * do not represent any guarantee of a cluster-wide availability of the feature for any particular + * range of versions. + * + * 2. Finalized feature: + * A finalized feature is represented by a name (string) and a range of version levels (defined + * by a FinalizedVersionRange). Whenever the feature versioning system (KIP-584) is + * enabled, the finalized features are stored in the cluster-wide common FeatureZNode. + * In comparison to a supported feature, the key difference is that a finalized feature exists + * in ZK only when it is guaranteed to be supported by any random broker in the cluster for a + * specified range of version levels. Also, the controller is the only entity modifying the + * information about finalized features. + * + * This method sets up the FeatureZNode with enabled status, which means that the finalized + * features stored in the FeatureZNode are active. The enabled status should be written by the + * controller to the FeatureZNode only when the broker IBP config is greater than or equal to + * KAFKA_2_7_IV0. + * + * There are multiple cases handled here: + * + * 1. 
New cluster bootstrap:
+   *    A new Kafka cluster (i.e. it is deployed for the first time) is almost always started with an IBP config
+   *    setting greater than or equal to KAFKA_2_7_IV0. We would like to start the cluster with all
+   *    the possible supported features finalized immediately. Assuming this is the case, the
+   *    controller will start up, notice that the FeatureZNode is absent in the new cluster,
+   *    and then create a FeatureZNode (with enabled status) containing the entire list of
+   *    supported features as its finalized features.
+   *
+   * 2. Broker binary upgraded, but IBP config set to lower than KAFKA_2_7_IV0:
+   *    Imagine there was an existing Kafka cluster with IBP config less than KAFKA_2_7_IV0, and the
+   *    broker binary has now been upgraded to a newer version that supports the feature versioning
+   *    system (KIP-584). But the IBP config is still set to lower than KAFKA_2_7_IV0, and may be
+   *    set to a higher value later. In this case, we want to start with no finalized features and
+   *    allow the user to finalize them whenever they are ready, i.e. in the future whenever the
+   *    user sets IBP config to be greater than or equal to KAFKA_2_7_IV0, then the user could start
+   *    finalizing the features. This process ensures we do not enable all the possible features
+   *    immediately after an upgrade, which could be harmful to Kafka.
+   *    This is how we handle such a case:
+   *      - Before the IBP config upgrade (i.e. IBP config set to less than KAFKA_2_7_IV0), the
+   *        controller will start up and check if the FeatureZNode is absent.
+   *        - If the node is absent, it will react by creating a FeatureZNode with disabled status
+   *          and empty finalized features.
+   *        - Otherwise, if a node already exists in enabled status then the controller will just
+   *          flip the status to disabled and clear the finalized features.
+   *      - After the IBP config upgrade (i.e. IBP config set to greater than or equal to
+   *        KAFKA_2_7_IV0), when the controller starts up it will check if the FeatureZNode exists
+   *        and whether it is disabled.
+   *        - If the node is in disabled status, the controller won't upgrade all features immediately.
+   *          Instead it will just switch the FeatureZNode status to enabled. This lets the
+   *          user finalize the features later.
+   *        - Otherwise, if a node already exists in enabled status then the controller will leave
+   *          the node unmodified.
+   *
+   * 3. Broker binary upgraded, with existing cluster IBP config >= KAFKA_2_7_IV0:
+   *    Imagine there was an existing Kafka cluster with IBP config >= KAFKA_2_7_IV0, and the broker
+   *    binary has just been upgraded to a newer version (that supports IBP config KAFKA_2_7_IV0 and
+   *    higher). The controller will start up and find that a FeatureZNode is already present with
+   *    enabled status and existing finalized features. In such a case, the controller leaves the node
+   *    unmodified.
+   *
+   * 4. Broker downgrade:
+   *    Imagine that a Kafka cluster exists already and the IBP config is greater than or equal to
+   *    KAFKA_2_7_IV0. Then the user decides to downgrade the cluster by setting IBP config to a
+   *    value less than KAFKA_2_7_IV0. This means the user is also disabling the feature versioning
+   *    system (KIP-584). In this case, when the controller starts up with the lower IBP config, it
+   *    will switch the FeatureZNode status to disabled with empty features.
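The four cases above reduce to a small decision table keyed on whether the FeatureZNode exists, its current status, and whether the broker's IBP enables feature versioning. A standalone sketch of that table, using plain Scala types rather than Kafka's `FeatureZNode`/`FeatureZNodeStatus`, and ignoring the finalized-feature payload and ZK version bookkeeping:

```scala
object FeatureZNodeDecisionSketch extends App {
  sealed trait Status
  case object Enabled  extends Status
  case object Disabled extends Status

  sealed trait Action
  case object CreateEnabledWithAllSupportedFeatures extends Action // case 1: new cluster bootstrap
  case object CreateDisabledEmpty                   extends Action // case 2: IBP still low, node absent
  case object SwitchToDisabledAndClearFeatures      extends Action // case 2 pre-upgrade / case 4 downgrade
  case object SwitchToEnabledWithoutFinalizing      extends Action // case 2 after the IBP upgrade
  case object LeaveUnmodified                       extends Action // cases 2 and 3 steady state

  def decide(existing: Option[Status], ibpSupportsVersioning: Boolean): Action =
    (existing, ibpSupportsVersioning) match {
      case (None, true)            => CreateEnabledWithAllSupportedFeatures
      case (None, false)           => CreateDisabledEmpty
      case (Some(Enabled), false)  => SwitchToDisabledAndClearFeatures
      case (Some(Disabled), false) => LeaveUnmodified // already disabled (features assumed cleared)
      case (Some(Disabled), true)  => SwitchToEnabledWithoutFinalizing
      case (Some(Enabled), true)   => LeaveUnmodified
    }

  println(decide(None, ibpSupportsVersioning = true))           // new cluster bootstrap
  println(decide(Some(Enabled), ibpSupportsVersioning = false)) // downgrade below KAFKA_2_7_IV0
}
```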
+ */ + private def enableFeatureVersioning(): Unit = { + val (mayBeFeatureZNodeBytes, version) = zkClient.getDataAndVersion(FeatureZNode.path) + if (version == ZkVersion.UnknownVersion) { + val newVersion = createFeatureZNode(new FeatureZNode(FeatureZNodeStatus.Enabled, + brokerFeatures.defaultFinalizedFeatures)) + featureCache.waitUntilEpochOrThrow(newVersion, config.zkConnectionTimeoutMs) + } else { + val existingFeatureZNode = FeatureZNode.decode(mayBeFeatureZNodeBytes.get) + val newFeatures = existingFeatureZNode.status match { + case FeatureZNodeStatus.Enabled => existingFeatureZNode.features + case FeatureZNodeStatus.Disabled => + if (!existingFeatureZNode.features.empty()) { + warn(s"FeatureZNode at path: ${FeatureZNode.path} with disabled status" + + s" contains non-empty features: ${existingFeatureZNode.features}") + } + Features.emptyFinalizedFeatures + } + val newFeatureZNode = new FeatureZNode(FeatureZNodeStatus.Enabled, newFeatures) + if (!newFeatureZNode.equals(existingFeatureZNode)) { + val newVersion = updateFeatureZNode(newFeatureZNode) + featureCache.waitUntilEpochOrThrow(newVersion, config.zkConnectionTimeoutMs) + } + } + } + + /** + * Disables the feature versioning system (KIP-584). + * + * Sets up the FeatureZNode with disabled status. This status means the feature versioning system + * (KIP-584) is disabled, and, the finalized features stored in the FeatureZNode are not relevant. + * This status should be written by the controller to the FeatureZNode only when the broker + * IBP config is less than KAFKA_2_7_IV0. + * + * NOTE: + * 1. When this method returns, existing finalized features (if any) will be cleared from the + * FeatureZNode. + * 2. This method, unlike enableFeatureVersioning() need not wait for the FinalizedFeatureCache + * to be updated, because, such updates to the cache (via FinalizedFeatureChangeListener) + * are disabled when IBP config is < than KAFKA_2_7_IV0. + */ + private def disableFeatureVersioning(): Unit = { + val newNode = FeatureZNode(FeatureZNodeStatus.Disabled, Features.emptyFinalizedFeatures()) + val (mayBeFeatureZNodeBytes, version) = zkClient.getDataAndVersion(FeatureZNode.path) + if (version == ZkVersion.UnknownVersion) { + createFeatureZNode(newNode) + } else { + val existingFeatureZNode = FeatureZNode.decode(mayBeFeatureZNodeBytes.get) + if (existingFeatureZNode.status == FeatureZNodeStatus.Disabled && + !existingFeatureZNode.features.empty()) { + warn(s"FeatureZNode at path: ${FeatureZNode.path} with disabled status" + + s" contains non-empty features: ${existingFeatureZNode.features}") + } + if (!newNode.equals(existingFeatureZNode)) { + updateFeatureZNode(newNode) + } + } + } + + private def maybeSetupFeatureVersioning(): Unit = { + if (config.isFeatureVersioningSupported) { + enableFeatureVersioning() + } else { + disableFeatureVersioning() + } + } + + private def scheduleAutoLeaderRebalanceTask(delay: Long, unit: TimeUnit): Unit = { + kafkaScheduler.schedule("auto-leader-rebalance-task", () => eventManager.put(AutoPreferredReplicaLeaderElection), + delay = delay, unit = unit) + } + + /** + * This callback is invoked by the zookeeper leader elector when the current broker resigns as the controller. 
This is + * required to clean up internal controller data structures + */ + private def onControllerResignation(): Unit = { + debug("Resigning") + // de-register listeners + zkClient.unregisterZNodeChildChangeHandler(isrChangeNotificationHandler.path) + zkClient.unregisterZNodeChangeHandler(partitionReassignmentHandler.path) + zkClient.unregisterZNodeChangeHandler(preferredReplicaElectionHandler.path) + zkClient.unregisterZNodeChildChangeHandler(logDirEventNotificationHandler.path) + unregisterBrokerModificationsHandler(brokerModificationsHandlers.keySet) + + // shutdown leader rebalance scheduler + kafkaScheduler.shutdown() + offlinePartitionCount = 0 + preferredReplicaImbalanceCount = 0 + globalTopicCount = 0 + globalPartitionCount = 0 + topicsToDeleteCount = 0 + replicasToDeleteCount = 0 + ineligibleTopicsToDeleteCount = 0 + ineligibleReplicasToDeleteCount = 0 + + // stop token expiry check scheduler + if (tokenCleanScheduler.isStarted) + tokenCleanScheduler.shutdown() + + // de-register partition ISR listener for on-going partition reassignment task + unregisterPartitionReassignmentIsrChangeHandlers() + // shutdown partition state machine + partitionStateMachine.shutdown() + zkClient.unregisterZNodeChildChangeHandler(topicChangeHandler.path) + unregisterPartitionModificationsHandlers(partitionModificationsHandlers.keys.toSeq) + zkClient.unregisterZNodeChildChangeHandler(topicDeletionHandler.path) + // shutdown replica state machine + replicaStateMachine.shutdown() + zkClient.unregisterZNodeChildChangeHandler(brokerChangeHandler.path) + + controllerChannelManager.shutdown() + controllerContext.resetContext() + + info("Resigned") + } + + /* + * This callback is invoked by the controller's LogDirEventNotificationListener with the list of broker ids who + * have experienced new log directory failures. In response the controller should send LeaderAndIsrRequest + * to all these brokers to query the state of their replicas. Replicas with an offline log directory respond with + * KAFKA_STORAGE_ERROR, which will be handled by the LeaderAndIsrResponseReceived event. + */ + private def onBrokerLogDirFailure(brokerIds: Seq[Int]): Unit = { + // send LeaderAndIsrRequest for all replicas on those brokers to see if they are still online. + info(s"Handling log directory failure for brokers ${brokerIds.mkString(",")}") + val replicasOnBrokers = controllerContext.replicasOnBrokers(brokerIds.toSet) + replicaStateMachine.handleStateChanges(replicasOnBrokers.toSeq, OnlineReplica) + } + + /** + * This callback is invoked by the replica state machine's broker change listener, with the list of newly started + * brokers as input. It does the following - + * 1. Sends update metadata request to all live and shutting down brokers + * 2. Triggers the OnlinePartition state change for all new/offline partitions + * 3. It checks whether there are reassigned replicas assigned to any newly started brokers. If + * so, it performs the reassignment logic for each topic/partition. + * + * Note that we don't need to refresh the leader/isr cache for all topic/partitions at this point for two reasons: + * 1. The partition state machine, when triggering online state change, will refresh leader and ISR for only those + * partitions currently new or offline (rather than every partition this controller is aware of) + * 2. Even if we do refresh the cache, there is no guarantee that by the time the leader and ISR request reaches + * every broker that it is still valid. 
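Both the log-dir-failure handling and the broker-startup handling described here rely on deriving, from the assignment cache, the set of replicas hosted on a given set of brokers. A minimal sketch of that derivation over plain collections; the types and the `assignment` map are illustrative stand-ins, not the controller context's API:

```scala
object ReplicasOnBrokersSketch extends App {
  case class TopicPartition(topic: String, partition: Int)
  case class PartitionAndReplica(tp: TopicPartition, replica: Int)

  // Stand-in for the cached assignment: partition -> assigned replica ids.
  val assignment: Map[TopicPartition, Seq[Int]] = Map(
    TopicPartition("payments", 0) -> Seq(1, 2, 3),
    TopicPartition("payments", 1) -> Seq(2, 3, 4),
    TopicPartition("clicks", 0)   -> Seq(3, 4, 1)
  )

  // All replicas hosted by the given brokers, e.g. the newly started or newly failed ones.
  def replicasOnBrokers(brokerIds: Set[Int]): Set[PartitionAndReplica] =
    assignment.iterator.flatMap { case (tp, replicas) =>
      replicas.filter(brokerIds.contains).map(r => PartitionAndReplica(tp, r))
    }.toSet

  println(replicasOnBrokers(Set(4))) // replicas to transition when broker 4 starts up or fails
}
```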
Brokers check the leader epoch to determine validity of the request. + */ + private def onBrokerStartup(newBrokers: Seq[Int]): Unit = { + info(s"New broker startup callback for ${newBrokers.mkString(",")}") + newBrokers.foreach(controllerContext.replicasOnOfflineDirs.remove) + val newBrokersSet = newBrokers.toSet + val existingBrokers = controllerContext.liveOrShuttingDownBrokerIds.diff(newBrokersSet) + // Send update metadata request to all the existing brokers in the cluster so that they know about the new brokers + // via this update. No need to include any partition states in the request since there are no partition state changes. + sendUpdateMetadataRequest(existingBrokers.toSeq, Set.empty) + // Send update metadata request to all the new brokers in the cluster with a full set of partition states for initialization. + // In cases of controlled shutdown leaders will not be elected when a new broker comes up. So at least in the + // common controlled shutdown case, the metadata will reach the new brokers faster. + sendUpdateMetadataRequest(newBrokers, controllerContext.partitionsWithLeaders) + // the very first thing to do when a new broker comes up is send it the entire list of partitions that it is + // supposed to host. Based on that the broker starts the high watermark threads for the input list of partitions + val allReplicasOnNewBrokers = controllerContext.replicasOnBrokers(newBrokersSet) + replicaStateMachine.handleStateChanges(allReplicasOnNewBrokers.toSeq, OnlineReplica) + // when a new broker comes up, the controller needs to trigger leader election for all new and offline partitions + // to see if these brokers can become leaders for some/all of those + partitionStateMachine.triggerOnlinePartitionStateChange() + // check if reassignment of some partitions need to be restarted + maybeResumeReassignments { (_, assignment) => + assignment.targetReplicas.exists(newBrokersSet.contains) + } + // check if topic deletion needs to be resumed. If at least one replica that belongs to the topic being deleted exists + // on the newly restarted brokers, there is a chance that topic deletion can resume + val replicasForTopicsToBeDeleted = allReplicasOnNewBrokers.filter(p => topicDeletionManager.isTopicQueuedUpForDeletion(p.topic)) + if (replicasForTopicsToBeDeleted.nonEmpty) { + info(s"Some replicas ${replicasForTopicsToBeDeleted.mkString(",")} for topics scheduled for deletion " + + s"${controllerContext.topicsToBeDeleted.mkString(",")} are on the newly restarted brokers " + + s"${newBrokers.mkString(",")}. 
Signaling restart of topic deletion for these topics") + topicDeletionManager.resumeDeletionForTopics(replicasForTopicsToBeDeleted.map(_.topic)) + } + registerBrokerModificationsHandler(newBrokers) + } + + private def maybeResumeReassignments(shouldResume: (TopicPartition, ReplicaAssignment) => Boolean): Unit = { + controllerContext.partitionsBeingReassigned.foreach { tp => + val currentAssignment = controllerContext.partitionFullReplicaAssignment(tp) + if (shouldResume(tp, currentAssignment)) + onPartitionReassignment(tp, currentAssignment) + } + } + + private def registerBrokerModificationsHandler(brokerIds: Iterable[Int]): Unit = { + debug(s"Register BrokerModifications handler for $brokerIds") + brokerIds.foreach { brokerId => + val brokerModificationsHandler = new BrokerModificationsHandler(eventManager, brokerId) + zkClient.registerZNodeChangeHandlerAndCheckExistence(brokerModificationsHandler) + brokerModificationsHandlers.put(brokerId, brokerModificationsHandler) + } + } + + private def unregisterBrokerModificationsHandler(brokerIds: Iterable[Int]): Unit = { + debug(s"Unregister BrokerModifications handler for $brokerIds") + brokerIds.foreach { brokerId => + brokerModificationsHandlers.remove(brokerId).foreach(handler => zkClient.unregisterZNodeChangeHandler(handler.path)) + } + } + + /* + * This callback is invoked by the replica state machine's broker change listener with the list of failed brokers + * as input. It will call onReplicaBecomeOffline(...) with the list of replicas on those failed brokers as input. + */ + private def onBrokerFailure(deadBrokers: Seq[Int]): Unit = { + info(s"Broker failure callback for ${deadBrokers.mkString(",")}") + deadBrokers.foreach(controllerContext.replicasOnOfflineDirs.remove) + val deadBrokersThatWereShuttingDown = + deadBrokers.filter(id => controllerContext.shuttingDownBrokerIds.remove(id)) + if (deadBrokersThatWereShuttingDown.nonEmpty) + info(s"Removed ${deadBrokersThatWereShuttingDown.mkString(",")} from list of shutting down brokers.") + val allReplicasOnDeadBrokers = controllerContext.replicasOnBrokers(deadBrokers.toSet) + onReplicasBecomeOffline(allReplicasOnDeadBrokers) + + unregisterBrokerModificationsHandler(deadBrokers) + } + + private def onBrokerUpdate(updatedBrokerId: Int): Unit = { + info(s"Broker info update callback for $updatedBrokerId") + sendUpdateMetadataRequest(controllerContext.liveOrShuttingDownBrokerIds.toSeq, Set.empty) + } + + /** + * This method marks the given replicas as offline. It does the following - + * 1. Marks the given partitions as offline + * 2. Triggers the OnlinePartition state change for all new/offline partitions + * 3. Invokes the OfflineReplica state change on the input list of newly offline replicas + * 4. If no partitions are affected then send UpdateMetadataRequest to live or shutting down brokers + * + * Note that we don't need to refresh the leader/isr cache for all topic/partitions at this point. This is because + * the partition state machine will refresh our cache for us when performing leader election for all new/offline + * partitions coming online. 
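A small standalone sketch of the two decisions just described: splitting the newly offline replicas by whether their topic is queued for deletion, and falling back to an explicit UpdateMetadataRequest when neither state machine produced a change (every election attempt failed). Plain Scala types stand in for the controller's classes, and election results are modelled as `Either[Throwable, Int]` (the new leader id):

```scala
object OfflineReplicaHandlingSketch extends App {
  case class PartitionAndReplica(topic: String, partition: Int, replica: Int)

  val topicsQueuedForDeletion = Set("old-events")
  val newOfflineReplicas = Set(
    PartitionAndReplica("payments", 0, 2),
    PartitionAndReplica("old-events", 1, 2))

  // Replicas of topics being deleted are marked as failed deletions; the rest move to OfflineReplica.
  val (forDeletion, notForDeletion) =
    newOfflineReplicas.partition(r => topicsQueuedForDeletion.contains(r.topic))

  // If there is nothing to transition and every online-election attempt failed, no
  // UpdateMetadataRequest went out via the state machines, so one is sent explicitly.
  val electionResults: Map[String, Either[Throwable, Int]] =
    Map("payments-0" -> Left(new IllegalStateException("no eligible leader")))
  val needExplicitUpdateMetadata =
    notForDeletion.isEmpty && electionResults.values.forall(_.isLeft)

  println(s"fail deletion for: $forDeletion")
  println(s"OfflineReplica transition for: $notForDeletion")
  println(s"send explicit UpdateMetadataRequest: $needExplicitUpdateMetadata")
}
```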
+ */ + private def onReplicasBecomeOffline(newOfflineReplicas: Set[PartitionAndReplica]): Unit = { + val (newOfflineReplicasForDeletion, newOfflineReplicasNotForDeletion) = + newOfflineReplicas.partition(p => topicDeletionManager.isTopicQueuedUpForDeletion(p.topic)) + + val partitionsWithOfflineLeader = controllerContext.partitionsWithOfflineLeader + + // trigger OfflinePartition state for all partitions whose current leader is one amongst the newOfflineReplicas + partitionStateMachine.handleStateChanges(partitionsWithOfflineLeader.toSeq, OfflinePartition) + // trigger OnlinePartition state changes for offline or new partitions + val onlineStateChangeResults = partitionStateMachine.triggerOnlinePartitionStateChange() + // trigger OfflineReplica state change for those newly offline replicas + replicaStateMachine.handleStateChanges(newOfflineReplicasNotForDeletion.toSeq, OfflineReplica) + + // fail deletion of topics that are affected by the offline replicas + if (newOfflineReplicasForDeletion.nonEmpty) { + // it is required to mark the respective replicas in TopicDeletionFailed state since the replica cannot be + // deleted when its log directory is offline. This will prevent the replica from being in TopicDeletionStarted state indefinitely + // since topic deletion cannot be retried until at least one replica is in TopicDeletionStarted state + topicDeletionManager.failReplicaDeletion(newOfflineReplicasForDeletion) + } + + // If no partition has changed leader or ISR, no UpdateMetadataRequest is sent through PartitionStateMachine + // and ReplicaStateMachine. In that case, we want to send an UpdateMetadataRequest explicitly to + // propagate the information about the new offline brokers. + if (newOfflineReplicasNotForDeletion.isEmpty && onlineStateChangeResults.values.forall(_.isLeft)) { + sendUpdateMetadataRequest(controllerContext.liveOrShuttingDownBrokerIds.toSeq, Set.empty) + } + } + + /** + * This callback is invoked by the topic change callback with the list of failed brokers as input. + * It does the following - + * 1. Move the newly created partitions to the NewPartition state + * 2. Move the newly created partitions from NewPartition->OnlinePartition state + */ + private def onNewPartitionCreation(newPartitions: Set[TopicPartition]): Unit = { + info(s"New partition creation callback for ${newPartitions.mkString(",")}") + partitionStateMachine.handleStateChanges(newPartitions.toSeq, NewPartition) + replicaStateMachine.handleStateChanges(controllerContext.replicasForPartition(newPartitions).toSeq, NewReplica) + partitionStateMachine.handleStateChanges( + newPartitions.toSeq, + OnlinePartition, + Some(OfflinePartitionLeaderElectionStrategy(false)) + ) + replicaStateMachine.handleStateChanges(controllerContext.replicasForPartition(newPartitions).toSeq, OnlineReplica) + } + + /** + * This callback is invoked: + * 1. By the AlterPartitionReassignments API + * 2. By the reassigned partitions listener which is triggered when the /admin/reassign/partitions znode is created + * 3. When an ongoing reassignment finishes - this is detected by a change in the partition's ISR znode + * 4. Whenever a new broker comes up which is part of an ongoing reassignment + * 5. On controller startup/failover + * + * Reassigning replicas for a partition goes through a few steps listed in the code. 
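A toy sketch of the two-step transition just described for newly created partitions: partitions and their replicas first enter the New states and are then moved to Online. The state maps below are simple stand-ins for the partition and replica state machines, and the leader election that happens during the online transition is elided:

```scala
object NewPartitionCreationSketch extends App {
  sealed trait PartitionState
  case object NewPartition    extends PartitionState
  case object OnlinePartition extends PartitionState

  sealed trait ReplicaState
  case object NewReplica    extends ReplicaState
  case object OnlineReplica extends ReplicaState

  var partitionState = Map.empty[String, PartitionState]
  var replicaState   = Map.empty[(String, Int), ReplicaState]

  def onNewPartitionCreation(newPartitions: Set[String], assignment: Map[String, Seq[Int]]): Unit = {
    // 1. New partitions and their replicas enter the New* states.
    partitionState ++= newPartitions.map(_ -> NewPartition)
    replicaState   ++= newPartitions.flatMap(p => assignment(p).map(r => (p, r) -> NewReplica))
    // 2. They are then moved to the Online* states (this is where the real controller
    //    runs the offline-partition leader election strategy to pick a leader).
    partitionState ++= newPartitions.map(_ -> OnlinePartition)
    replicaState   ++= newPartitions.flatMap(p => assignment(p).map(r => (p, r) -> OnlineReplica))
  }

  onNewPartitionCreation(Set("clicks-2"), Map("clicks-2" -> Seq(1, 2, 3)))
  println(partitionState) // Map(clicks-2 -> OnlinePartition)
  println(replicaState)   // replicas 1, 2 and 3 of clicks-2 all OnlineReplica
}
```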
+   * RS = current assigned replica set
+   * ORS = Original replica set for partition
+   * TRS = Reassigned (target) replica set
+   * AR = The replicas we are adding as part of this reassignment
+   * RR = The replicas we are removing as part of this reassignment
+   *
+   * A reassignment may have up to three phases, each with its own steps:
+   *
+   * Phase U (Assignment update): Regardless of the trigger, the first step in the reassignment process
+   * is to update the existing assignment state. We always update the state in Zookeeper before
+   * we update memory so that it can be resumed upon controller fail-over.
+   *
+   * U1. Update ZK with RS = ORS + TRS, AR = TRS - ORS, RR = ORS - TRS.
+   * U2. Update memory with RS = ORS + TRS, AR = TRS - ORS and RR = ORS - TRS.
+   * U3. If we are cancelling or replacing an existing reassignment, send StopReplica to all members
+   *     of AR in the original reassignment if they are not in TRS from the new assignment.
+   *
+   * To complete the reassignment, we need to bring the new replicas into sync, so depending on the state
+   * of the ISR, we will execute one of the following steps.
+   *
+   * Phase A (when TRS != ISR): The reassignment is not yet complete
+   *
+   *   A1. Bump the leader epoch for the partition and send LeaderAndIsr updates to RS.
+   *   A2. Start new replicas AR by moving replicas in AR to NewReplica state.
+   *
+   * Phase B (when TRS = ISR): The reassignment is complete
+   *
+   *   B1. Move all replicas in AR to OnlineReplica state.
+   *   B2. Set RS = TRS, AR = [], RR = [] in memory.
+   *   B3. Send a LeaderAndIsr request with RS = TRS. This will prevent the leader from adding any replica in TRS - ORS back in the isr.
+   *       If the current leader is not in TRS or isn't alive, we move the leader to a new replica in TRS.
+   *       We may send the LeaderAndIsr to more than the TRS replicas due to the
+   *       way the partition state machine works (it reads replicas from ZK).
+   *   B4. Move all replicas in RR to OfflineReplica state. As part of the OfflineReplica state change, we shrink the
+   *       isr to remove RR in ZooKeeper and send a LeaderAndIsr ONLY to the Leader to notify it of the shrunk isr.
+   *       After that, we send a StopReplica (delete = false) to the replicas in RR.
+   *   B5. Move all replicas in RR to NonExistentReplica state. This will send a StopReplica (delete = true) to
+   *       the replicas in RR to physically delete the replicas on disk.
+   *   B6. Update ZK with RS = TRS, AR = [], RR = [].
+   *   B7. Remove the ISR reassign listener and maybe update the /admin/reassign_partitions path in ZK to remove this partition from it if present.
+   *   B8. After electing a leader, the replicas and isr information changes, so resend the update metadata request to every broker.
+   *
+   * In general, there are two goals we want to aim for:
+   *  1. Every replica present in the replica set of a LeaderAndIsrRequest gets the request sent to it
+   *  2. Replicas that are removed from a partition's assignment get StopReplica sent to them
+   *
+   * For example, if ORS = {1,2,3} and TRS = {4,5,6}, the values in the topic and leader/isr paths in ZK
+   * may go through the following transitions.
+ * RS AR RR leader isr + * {1,2,3} {} {} 1 {1,2,3} (initial state) + * {4,5,6,1,2,3} {4,5,6} {1,2,3} 1 {1,2,3} (step A2) + * {4,5,6,1,2,3} {4,5,6} {1,2,3} 1 {1,2,3,4,5,6} (phase B) + * {4,5,6,1,2,3} {4,5,6} {1,2,3} 4 {1,2,3,4,5,6} (step B3) + * {4,5,6,1,2,3} {4,5,6} {1,2,3} 4 {4,5,6} (step B4) + * {4,5,6} {} {} 4 {4,5,6} (step B6) + * + * Note that we have to update RS in ZK with TRS last since it's the only place where we store ORS persistently. + * This way, if the controller crashes before that step, we can still recover. + */ + private def onPartitionReassignment(topicPartition: TopicPartition, reassignment: ReplicaAssignment): Unit = { + // While a reassignment is in progress, deletion is not allowed + topicDeletionManager.markTopicIneligibleForDeletion(Set(topicPartition.topic), reason = "topic reassignment in progress") + + updateCurrentReassignment(topicPartition, reassignment) + + val addingReplicas = reassignment.addingReplicas + val removingReplicas = reassignment.removingReplicas + + if (!isReassignmentComplete(topicPartition, reassignment)) { + // A1. Send LeaderAndIsr request to every replica in ORS + TRS (with the new RS, AR and RR). + updateLeaderEpochAndSendRequest(topicPartition, reassignment) + // A2. replicas in AR -> NewReplica + startNewReplicasForReassignedPartition(topicPartition, addingReplicas) + } else { + // B1. replicas in AR -> OnlineReplica + replicaStateMachine.handleStateChanges(addingReplicas.map(PartitionAndReplica(topicPartition, _)), OnlineReplica) + // B2. Set RS = TRS, AR = [], RR = [] in memory. + val completedReassignment = ReplicaAssignment(reassignment.targetReplicas) + controllerContext.updatePartitionFullReplicaAssignment(topicPartition, completedReassignment) + // B3. Send LeaderAndIsr request with a potential new leader (if current leader not in TRS) and + // a new RS (using TRS) and same isr to every broker in ORS + TRS or TRS + moveReassignedPartitionLeaderIfRequired(topicPartition, completedReassignment) + // B4. replicas in RR -> Offline (force those replicas out of isr) + // B5. replicas in RR -> NonExistentReplica (force those replicas to be deleted) + stopRemovedReplicasOfReassignedPartition(topicPartition, removingReplicas) + // B6. Update ZK with RS = TRS, AR = [], RR = []. + updateReplicaAssignmentForPartition(topicPartition, completedReassignment) + // B7. Remove the ISR reassign listener and maybe update the /admin/reassign_partitions path in ZK to remove this partition from it. + removePartitionFromReassigningPartitions(topicPartition, completedReassignment) + // B8. After electing a leader in B3, the replicas and isr information changes, so resend the update metadata request to every broker + sendUpdateMetadataRequest(controllerContext.liveOrShuttingDownBrokerIds.toSeq, Set(topicPartition)) + // signal delete topic thread if reassignment for some partitions belonging to topics being deleted just completed + topicDeletionManager.resumeDeletionForTopics(Set(topicPartition.topic)) + } + } + + /** + * Update the current assignment state in Zookeeper and in memory. If a reassignment is already in + * progress, then the new reassignment will supplant it and some replicas will be shutdown. + * + * Note that due to the way we compute the original replica set, we cannot guarantee that a + * cancellation will restore the original replica order. 
Target replicas are always listed + * first in the replica set in the desired order, which means we have no way to get to the + * original order if the reassignment overlaps with the current assignment. For example, + * with an initial assignment of [1, 2, 3] and a reassignment of [3, 4, 2], then the replicas + * will be encoded as [3, 4, 2, 1] while the reassignment is in progress. If the reassignment + * is cancelled, there is no way to restore the original order. + * + * @param topicPartition The reassigning partition + * @param reassignment The new reassignment + */ + private def updateCurrentReassignment(topicPartition: TopicPartition, reassignment: ReplicaAssignment): Unit = { + val currentAssignment = controllerContext.partitionFullReplicaAssignment(topicPartition) + + if (currentAssignment != reassignment) { + debug(s"Updating assignment of partition $topicPartition from $currentAssignment to $reassignment") + + // U1. Update assignment state in zookeeper + updateReplicaAssignmentForPartition(topicPartition, reassignment) + // U2. Update assignment state in memory + controllerContext.updatePartitionFullReplicaAssignment(topicPartition, reassignment) + + // If there is a reassignment already in progress, then some of the currently adding replicas + // may be eligible for immediate removal, in which case we need to stop the replicas. + val unneededReplicas = currentAssignment.replicas.diff(reassignment.replicas) + if (unneededReplicas.nonEmpty) + stopRemovedReplicasOfReassignedPartition(topicPartition, unneededReplicas) + } + + if (!isAlterIsrEnabled) { + val reassignIsrChangeHandler = new PartitionReassignmentIsrChangeHandler(eventManager, topicPartition) + zkClient.registerZNodeChangeHandler(reassignIsrChangeHandler) + } + + controllerContext.partitionsBeingReassigned.add(topicPartition) + } + + /** + * Trigger a partition reassignment provided that the topic exists and is not being deleted. + * + * This is called when a reassignment is initially received either through Zookeeper or through the + * AlterPartitionReassignments API + * + * The `partitionsBeingReassigned` field in the controller context will be updated by this + * call after the reassignment completes validation and is successfully stored in the topic + * assignment zNode. + * + * @param reassignments The reassignments to begin processing + * @return A map of any errors in the reassignment. If the error is NONE for a given partition, + * then the reassignment was submitted successfully. 
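A worked, standalone example of the set arithmetic from the reassignment description above: phase U's RS/AR/RR computation with target replicas listed first, the TRS-subset-of-ISR check that separates phase A from phase B, and the replicas that may be stopped immediately when a new reassignment supplants an in-flight one. It reproduces the ORS = {1,2,3}, TRS = {4,5,6} transitions and the [1, 2, 3] -> [3, 4, 2] ordering caveat; `phaseU` is an illustrative helper, not Kafka's `ReplicaAssignment`:

```scala
object ReassignmentSetsSketch extends App {
  // Phase U: RS = ORS + TRS with the target replicas listed first, AR = TRS - ORS, RR = ORS - TRS.
  def phaseU(ors: Seq[Int], trs: Seq[Int]): (Seq[Int], Seq[Int], Seq[Int]) =
    (trs ++ ors.diff(trs), trs.diff(ors), ors.diff(trs))

  // Phase B starts once every target replica has joined the ISR.
  def isComplete(trs: Seq[Int], isr: Set[Int]): Boolean = trs.toSet.subsetOf(isr)

  val (rs, ar, rr) = phaseU(ors = Seq(1, 2, 3), trs = Seq(4, 5, 6))
  println(s"RS=$rs AR=$ar RR=$rr")                         // RS=List(4, 5, 6, 1, 2, 3) AR=List(4, 5, 6) RR=List(1, 2, 3)
  println(isComplete(Seq(4, 5, 6), Set(1, 2, 3)))          // false -> still in phase A
  println(isComplete(Seq(4, 5, 6), Set(1, 2, 3, 4, 5, 6))) // true  -> phase B can run

  // Ordering caveat: [1, 2, 3] reassigned to [3, 4, 2] is stored as [3, 4, 2, 1] while in
  // progress, so a cancellation cannot recover the original [1, 2, 3] order.
  val (inProgressRs, _, _) = phaseU(ors = Seq(1, 2, 3), trs = Seq(3, 4, 2))
  println(inProgressRs)                                    // List(3, 4, 2, 1)

  // If a new reassignment to [5, 2] supplants it, replicas absent from the new full
  // assignment can be stopped right away (step U3).
  val (supplantingRs, _, _) = phaseU(ors = Seq(1, 2, 3), trs = Seq(5, 2))
  println(inProgressRs.diff(supplantingRs))                // List(4)
}
```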
+ */ + private def maybeTriggerPartitionReassignment(reassignments: Map[TopicPartition, ReplicaAssignment]): Map[TopicPartition, ApiError] = { + reassignments.map { case (tp, reassignment) => + val topic = tp.topic + + val apiError = if (topicDeletionManager.isTopicQueuedUpForDeletion(topic)) { + info(s"Skipping reassignment of $tp since the topic is currently being deleted") + new ApiError(Errors.UNKNOWN_TOPIC_OR_PARTITION, "The partition does not exist.") + } else { + val assignedReplicas = controllerContext.partitionReplicaAssignment(tp) + if (assignedReplicas.nonEmpty) { + try { + onPartitionReassignment(tp, reassignment) + ApiError.NONE + } catch { + case e: ControllerMovedException => + info(s"Failed completing reassignment of partition $tp because controller has moved to another broker") + throw e + case e: Throwable => + error(s"Error completing reassignment of partition $tp", e) + new ApiError(Errors.UNKNOWN_SERVER_ERROR) + } + } else { + new ApiError(Errors.UNKNOWN_TOPIC_OR_PARTITION, "The partition does not exist.") + } + } + + tp -> apiError + } + } + + /** + * Attempt to elect a replica as leader for each of the given partitions. + * @param partitions The partitions to have a new leader elected + * @param electionType The type of election to perform + * @param electionTrigger The reason for tigger this election + * @return A map of failed and successful elections. The keys are the topic partitions and the corresponding values are + * either the exception that was thrown or new leader & ISR. + */ + private[this] def onReplicaElection( + partitions: Set[TopicPartition], + electionType: ElectionType, + electionTrigger: ElectionTrigger + ): Map[TopicPartition, Either[Throwable, LeaderAndIsr]] = { + info(s"Starting replica leader election ($electionType) for partitions ${partitions.mkString(",")} triggered by $electionTrigger") + try { + val strategy = electionType match { + case ElectionType.PREFERRED => PreferredReplicaPartitionLeaderElectionStrategy + case ElectionType.UNCLEAN => + /* Let's be conservative and only trigger unclean election if the election type is unclean and it was + * triggered by the admin client + */ + OfflinePartitionLeaderElectionStrategy(allowUnclean = electionTrigger == AdminClientTriggered) + } + + val results = partitionStateMachine.handleStateChanges( + partitions.toSeq, + OnlinePartition, + Some(strategy) + ) + if (electionTrigger != AdminClientTriggered) { + results.foreach { + case (tp, Left(throwable)) => + if (throwable.isInstanceOf[ControllerMovedException]) { + info(s"Error completing replica leader election ($electionType) for partition $tp because controller has moved to another broker.", throwable) + throw throwable + } else { + error(s"Error completing replica leader election ($electionType) for partition $tp", throwable) + } + case (_, Right(_)) => // Ignored; No need to log or throw exception for the success cases + } + } + + results + } finally { + if (electionTrigger != AdminClientTriggered) { + removePartitionsFromPreferredReplicaElection(partitions, electionTrigger == AutoTriggered) + } + } + } + + private def initializeControllerContext(): Unit = { + // update controller cache with delete topic information + val curBrokerAndEpochs = zkClient.getAllBrokerAndEpochsInCluster + val (compatibleBrokerAndEpochs, incompatibleBrokerAndEpochs) = partitionOnFeatureCompatibility(curBrokerAndEpochs) + if (!incompatibleBrokerAndEpochs.isEmpty) { + warn("Ignoring registration of new brokers due to incompatibilities with finalized features: " + + 
incompatibleBrokerAndEpochs.map { case (broker, _) => broker.id }.toSeq.sorted.mkString(",")) + } + controllerContext.setLiveBrokers(compatibleBrokerAndEpochs) + info(s"Initialized broker epochs cache: ${controllerContext.liveBrokerIdAndEpochs}") + controllerContext.setAllTopics(zkClient.getAllTopicsInCluster(true)) + registerPartitionModificationsHandlers(controllerContext.allTopics.toSeq) + val replicaAssignmentAndTopicIds = zkClient.getReplicaAssignmentAndTopicIdForTopics(controllerContext.allTopics.toSet) + processTopicIds(replicaAssignmentAndTopicIds) + + replicaAssignmentAndTopicIds.foreach { case TopicIdReplicaAssignment(_, _, assignments) => + assignments.foreach { case (topicPartition, replicaAssignment) => + controllerContext.updatePartitionFullReplicaAssignment(topicPartition, replicaAssignment) + if (replicaAssignment.isBeingReassigned) + controllerContext.partitionsBeingReassigned.add(topicPartition) + } + } + controllerContext.clearPartitionLeadershipInfo() + controllerContext.shuttingDownBrokerIds.clear() + // register broker modifications handlers + registerBrokerModificationsHandler(controllerContext.liveOrShuttingDownBrokerIds) + // update the leader and isr cache for all existing partitions from Zookeeper + updateLeaderAndIsrCache() + // start the channel manager + controllerChannelManager.startup() + info(s"Currently active brokers in the cluster: ${controllerContext.liveBrokerIds}") + info(s"Currently shutting brokers in the cluster: ${controllerContext.shuttingDownBrokerIds}") + info(s"Current list of topics in the cluster: ${controllerContext.allTopics}") + } + + private def fetchPendingPreferredReplicaElections(): Set[TopicPartition] = { + val partitionsUndergoingPreferredReplicaElection = zkClient.getPreferredReplicaElection + // check if they are already completed or topic was deleted + val partitionsThatCompletedPreferredReplicaElection = partitionsUndergoingPreferredReplicaElection.filter { partition => + val replicas = controllerContext.partitionReplicaAssignment(partition) + val topicDeleted = replicas.isEmpty + val successful = + if (!topicDeleted) controllerContext.partitionLeadershipInfo(partition).get.leaderAndIsr.leader == replicas.head else false + successful || topicDeleted + } + val pendingPreferredReplicaElectionsIgnoringTopicDeletion = partitionsUndergoingPreferredReplicaElection -- partitionsThatCompletedPreferredReplicaElection + val pendingPreferredReplicaElectionsSkippedFromTopicDeletion = pendingPreferredReplicaElectionsIgnoringTopicDeletion.filter(partition => topicDeletionManager.isTopicQueuedUpForDeletion(partition.topic)) + val pendingPreferredReplicaElections = pendingPreferredReplicaElectionsIgnoringTopicDeletion -- pendingPreferredReplicaElectionsSkippedFromTopicDeletion + info(s"Partitions undergoing preferred replica election: ${partitionsUndergoingPreferredReplicaElection.mkString(",")}") + info(s"Partitions that completed preferred replica election: ${partitionsThatCompletedPreferredReplicaElection.mkString(",")}") + info(s"Skipping preferred replica election for partitions due to topic deletion: ${pendingPreferredReplicaElectionsSkippedFromTopicDeletion.mkString(",")}") + info(s"Resuming preferred replica election for partitions: ${pendingPreferredReplicaElections.mkString(",")}") + pendingPreferredReplicaElections + } + + /** + * Initialize pending reassignments. This includes reassignments sent through /admin/reassign_partitions, + * which will supplant any API reassignments already in progress. 
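A small sketch of the filtering performed above when resuming preferred replica elections: an election is considered done if the topic is gone or the preferred replica (the head of the assignment) already leads, and what remains is further split by topic deletion. Plain local types stand in for `TopicPartition` and the controller context:

```scala
object PendingPreferredElectionsSketch extends App {
  case class TopicPartition(topic: String, partition: Int)

  // partition -> (assigned replicas, current leader); replicas.head is the preferred replica.
  val state = Map(
    TopicPartition("payments", 0)   -> (Seq(1, 2, 3), 1), // preferred replica already leads
    TopicPartition("payments", 1)   -> (Seq(2, 3, 1), 3), // still pending
    TopicPartition("old-events", 0) -> (Seq(3, 1, 2), 2)  // pending, but its topic is being deleted
  )
  val undergoingElection = state.keySet
  val topicsQueuedForDeletion = Set("old-events")

  // Completed if the topic was deleted (empty assignment) or the preferred replica is the leader.
  val completed = undergoingElection.filter { tp =>
    val (replicas, leader) = state(tp)
    replicas.isEmpty || leader == replicas.head
  }
  val pendingIgnoringDeletion = undergoingElection -- completed
  val skippedForDeletion = pendingIgnoringDeletion.filter(tp => topicsQueuedForDeletion.contains(tp.topic))
  val resuming = pendingIgnoringDeletion -- skippedForDeletion

  println(s"completed=$completed skipped=$skippedForDeletion resuming=$resuming")
}
```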
+ */ + private def initializePartitionReassignments(): Unit = { + // New reassignments may have been submitted through Zookeeper while the controller was failing over + val zkPartitionsResumed = processZkPartitionReassignment() + // We may also have some API-based reassignments that need to be restarted + maybeResumeReassignments { (tp, _) => + !zkPartitionsResumed.contains(tp) + } + } + + private def fetchTopicDeletionsInProgress(): (Set[String], Set[String]) = { + val topicsToBeDeleted = zkClient.getTopicDeletions.toSet + val topicsWithOfflineReplicas = controllerContext.allTopics.filter { topic => { + val replicasForTopic = controllerContext.replicasForTopic(topic) + replicasForTopic.exists(r => !controllerContext.isReplicaOnline(r.replica, r.topicPartition)) + }} + val topicsForWhichPartitionReassignmentIsInProgress = controllerContext.partitionsBeingReassigned.map(_.topic) + val topicsIneligibleForDeletion = topicsWithOfflineReplicas | topicsForWhichPartitionReassignmentIsInProgress + info(s"List of topics to be deleted: ${topicsToBeDeleted.mkString(",")}") + info(s"List of topics ineligible for deletion: ${topicsIneligibleForDeletion.mkString(",")}") + (topicsToBeDeleted, topicsIneligibleForDeletion) + } + + private def updateLeaderAndIsrCache(partitions: Seq[TopicPartition] = controllerContext.allPartitions.toSeq): Unit = { + val leaderIsrAndControllerEpochs = zkClient.getTopicPartitionStates(partitions) + leaderIsrAndControllerEpochs.forKeyValue { (partition, leaderIsrAndControllerEpoch) => + controllerContext.putPartitionLeadershipInfo(partition, leaderIsrAndControllerEpoch) + } + } + + private def isReassignmentComplete(partition: TopicPartition, assignment: ReplicaAssignment): Boolean = { + if (!assignment.isBeingReassigned) { + true + } else { + zkClient.getTopicPartitionStates(Seq(partition)).get(partition).exists { leaderIsrAndControllerEpoch => + val isr = leaderIsrAndControllerEpoch.leaderAndIsr.isr.toSet + val targetReplicas = assignment.targetReplicas.toSet + targetReplicas.subsetOf(isr) + } + } + } + + private def moveReassignedPartitionLeaderIfRequired(topicPartition: TopicPartition, + newAssignment: ReplicaAssignment): Unit = { + val reassignedReplicas = newAssignment.replicas + val currentLeader = controllerContext.partitionLeadershipInfo(topicPartition).get.leaderAndIsr.leader + + if (!reassignedReplicas.contains(currentLeader)) { + info(s"Leader $currentLeader for partition $topicPartition being reassigned, " + + s"is not in the new list of replicas ${reassignedReplicas.mkString(",")}. 
Re-electing leader") + // move the leader to one of the alive and caught up new replicas + partitionStateMachine.handleStateChanges(Seq(topicPartition), OnlinePartition, Some(ReassignPartitionLeaderElectionStrategy)) + } else if (controllerContext.isReplicaOnline(currentLeader, topicPartition)) { + info(s"Leader $currentLeader for partition $topicPartition being reassigned, " + + s"is already in the new list of replicas ${reassignedReplicas.mkString(",")} and is alive") + // shrink replication factor and update the leader epoch in zookeeper to use on the next LeaderAndIsrRequest + updateLeaderEpochAndSendRequest(topicPartition, newAssignment) + } else { + info(s"Leader $currentLeader for partition $topicPartition being reassigned, " + + s"is already in the new list of replicas ${reassignedReplicas.mkString(",")} but is dead") + partitionStateMachine.handleStateChanges(Seq(topicPartition), OnlinePartition, Some(ReassignPartitionLeaderElectionStrategy)) + } + } + + private def stopRemovedReplicasOfReassignedPartition(topicPartition: TopicPartition, + removedReplicas: Seq[Int]): Unit = { + // first move the replica to offline state (the controller removes it from the ISR) + val replicasToBeDeleted = removedReplicas.map(PartitionAndReplica(topicPartition, _)) + replicaStateMachine.handleStateChanges(replicasToBeDeleted, OfflineReplica) + // send stop replica command to the old replicas + replicaStateMachine.handleStateChanges(replicasToBeDeleted, ReplicaDeletionStarted) + // TODO: Eventually partition reassignment could use a callback that does retries if deletion failed + replicaStateMachine.handleStateChanges(replicasToBeDeleted, ReplicaDeletionSuccessful) + replicaStateMachine.handleStateChanges(replicasToBeDeleted, NonExistentReplica) + } + + private def updateReplicaAssignmentForPartition(topicPartition: TopicPartition, assignment: ReplicaAssignment): Unit = { + val topicAssignment = mutable.Map() ++= + controllerContext.partitionFullReplicaAssignmentForTopic(topicPartition.topic) += + (topicPartition -> assignment) + + val setDataResponse = zkClient.setTopicAssignmentRaw(topicPartition.topic, + controllerContext.topicIds.get(topicPartition.topic), + topicAssignment, controllerContext.epochZkVersion) + setDataResponse.resultCode match { + case Code.OK => + info(s"Successfully updated assignment of partition $topicPartition to $assignment") + case Code.NONODE => + throw new IllegalStateException(s"Failed to update assignment for $topicPartition since the topic " + + "has no current assignment") + case _ => throw new KafkaException(setDataResponse.resultException.get) + } + } + + private def startNewReplicasForReassignedPartition(topicPartition: TopicPartition, newReplicas: Seq[Int]): Unit = { + // send the start replica request to the brokers in the reassigned replicas list that are not in the assigned + // replicas list + newReplicas.foreach { replica => + replicaStateMachine.handleStateChanges(Seq(PartitionAndReplica(topicPartition, replica)), NewReplica) + } + } + + private def updateLeaderEpochAndSendRequest(topicPartition: TopicPartition, + assignment: ReplicaAssignment): Unit = { + val stateChangeLog = stateChangeLogger.withControllerEpoch(controllerContext.epoch) + updateLeaderEpoch(topicPartition) match { + case Some(updatedLeaderIsrAndControllerEpoch) => + try { + brokerRequestBatch.newBatch() + // the isNew flag, when set to true, makes sure that when a replica possibly resided + // in a logDir that is offline, we refrain from just creating a new replica in a good + // logDir. 
This is exactly the behavior we want for the original replicas, but not + // for the replicas we add in this reassignment. For new replicas, want to be able + // to assign to one of the good logDirs. + brokerRequestBatch.addLeaderAndIsrRequestForBrokers(assignment.originReplicas, topicPartition, + updatedLeaderIsrAndControllerEpoch, assignment, isNew = false) + brokerRequestBatch.addLeaderAndIsrRequestForBrokers(assignment.addingReplicas, topicPartition, + updatedLeaderIsrAndControllerEpoch, assignment, isNew = true) + brokerRequestBatch.sendRequestsToBrokers(controllerContext.epoch) + } catch { + case e: IllegalStateException => + handleIllegalState(e) + } + stateChangeLog.info(s"Sent LeaderAndIsr request $updatedLeaderIsrAndControllerEpoch with " + + s"new replica assignment $assignment to leader ${updatedLeaderIsrAndControllerEpoch.leaderAndIsr.leader} " + + s"for partition being reassigned $topicPartition") + + case None => // fail the reassignment + stateChangeLog.error(s"Failed to send LeaderAndIsr request with new replica assignment " + + s"$assignment to leader for partition being reassigned $topicPartition") + } + } + + private def registerPartitionModificationsHandlers(topics: Seq[String]) = { + topics.foreach { topic => + val partitionModificationsHandler = new PartitionModificationsHandler(eventManager, topic) + partitionModificationsHandlers.put(topic, partitionModificationsHandler) + } + partitionModificationsHandlers.values.foreach(zkClient.registerZNodeChangeHandler) + } + + private[controller] def unregisterPartitionModificationsHandlers(topics: Seq[String]) = { + topics.foreach { topic => + partitionModificationsHandlers.remove(topic).foreach(handler => zkClient.unregisterZNodeChangeHandler(handler.path)) + } + } + + private def unregisterPartitionReassignmentIsrChangeHandlers(): Unit = { + if (!isAlterIsrEnabled) { + controllerContext.partitionsBeingReassigned.foreach { tp => + val path = TopicPartitionStateZNode.path(tp) + zkClient.unregisterZNodeChangeHandler(path) + } + } + } + + private def removePartitionFromReassigningPartitions(topicPartition: TopicPartition, + assignment: ReplicaAssignment): Unit = { + if (controllerContext.partitionsBeingReassigned.contains(topicPartition)) { + if (!isAlterIsrEnabled) { + val path = TopicPartitionStateZNode.path(topicPartition) + zkClient.unregisterZNodeChangeHandler(path) + } + maybeRemoveFromZkReassignment((tp, replicas) => tp == topicPartition && replicas == assignment.replicas) + controllerContext.partitionsBeingReassigned.remove(topicPartition) + } else { + throw new IllegalStateException("Cannot remove a reassigning partition because it is not present in memory") + } + } + + /** + * Remove partitions from an active zk-based reassignment (if one exists). + * + * @param shouldRemoveReassignment Predicate indicating which partition reassignments should be removed + */ + private def maybeRemoveFromZkReassignment(shouldRemoveReassignment: (TopicPartition, Seq[Int]) => Boolean): Unit = { + if (!zkClient.reassignPartitionsInProgress) + return + + val reassigningPartitions = zkClient.getPartitionReassignment + val (removingPartitions, updatedPartitionsBeingReassigned) = reassigningPartitions.partition { case (tp, replicas) => + shouldRemoveReassignment(tp, replicas) + } + info(s"Removing partitions $removingPartitions from the list of reassigned partitions in zookeeper") + + // write the new list to zookeeper + if (updatedPartitionsBeingReassigned.isEmpty) { + info(s"No more partitions need to be reassigned. 
Deleting zk path ${ReassignPartitionsZNode.path}") + zkClient.deletePartitionReassignment(controllerContext.epochZkVersion) + // Ensure we detect future reassignments + eventManager.put(ZkPartitionReassignment) + } else { + try { + zkClient.setOrCreatePartitionReassignment(updatedPartitionsBeingReassigned, controllerContext.epochZkVersion) + } catch { + case e: KeeperException => throw new AdminOperationException(e) + } + } + } + + private def removePartitionsFromPreferredReplicaElection(partitionsToBeRemoved: Set[TopicPartition], + isTriggeredByAutoRebalance : Boolean): Unit = { + for (partition <- partitionsToBeRemoved) { + // check the status + val currentLeader = controllerContext.partitionLeadershipInfo(partition).get.leaderAndIsr.leader + val preferredReplica = controllerContext.partitionReplicaAssignment(partition).head + if (currentLeader == preferredReplica) { + info(s"Partition $partition completed preferred replica leader election. New leader is $preferredReplica") + } else { + warn(s"Partition $partition failed to complete preferred replica leader election to $preferredReplica. " + + s"Leader is still $currentLeader") + } + } + if (!isTriggeredByAutoRebalance) { + zkClient.deletePreferredReplicaElection(controllerContext.epochZkVersion) + // Ensure we detect future preferred replica leader elections + eventManager.put(ReplicaLeaderElection(None, ElectionType.PREFERRED, ZkTriggered)) + } + } + + /** + * Send the leader information for selected partitions to selected brokers so that they can correctly respond to + * metadata requests + * + * @param brokers The brokers that the update metadata request should be sent to + */ + private[controller] def sendUpdateMetadataRequest(brokers: Seq[Int], partitions: Set[TopicPartition]): Unit = { + try { + brokerRequestBatch.newBatch() + brokerRequestBatch.addUpdateMetadataRequestForBrokers(brokers, partitions) + brokerRequestBatch.sendRequestsToBrokers(epoch) + } catch { + case e: IllegalStateException => + handleIllegalState(e) + } + } + + /** + * Does not change leader or isr, but just increments the leader epoch + * + * @param partition partition + * @return the new leaderAndIsr with an incremented leader epoch, or None if leaderAndIsr is empty. + */ + private def updateLeaderEpoch(partition: TopicPartition): Option[LeaderIsrAndControllerEpoch] = { + debug(s"Updating leader epoch for partition $partition") + var finalLeaderIsrAndControllerEpoch: Option[LeaderIsrAndControllerEpoch] = None + var zkWriteCompleteOrUnnecessary = false + while (!zkWriteCompleteOrUnnecessary) { + // refresh leader and isr from zookeeper again + zkWriteCompleteOrUnnecessary = zkClient.getTopicPartitionStates(Seq(partition)).get(partition) match { + case Some(leaderIsrAndControllerEpoch) => + val leaderAndIsr = leaderIsrAndControllerEpoch.leaderAndIsr + val controllerEpoch = leaderIsrAndControllerEpoch.controllerEpoch + if (controllerEpoch > epoch) + throw new StateChangeFailedException("Leader and isr path written by another controller. This probably " + + s"means the current controller with epoch $epoch went through a soft failure and another " + + s"controller was elected with epoch $controllerEpoch. 
Aborting state change by this controller") + // increment the leader epoch even if there are no leader or isr changes to allow the leader to cache the expanded + // assigned replica list + val newLeaderAndIsr = leaderAndIsr.newEpochAndZkVersion + // update the new leadership decision in zookeeper or retry + val UpdateLeaderAndIsrResult(finishedUpdates, _) = + zkClient.updateLeaderAndIsr(immutable.Map(partition -> newLeaderAndIsr), epoch, controllerContext.epochZkVersion) + + finishedUpdates.get(partition) match { + case Some(Right(leaderAndIsr)) => + val leaderIsrAndControllerEpoch = LeaderIsrAndControllerEpoch(leaderAndIsr, epoch) + controllerContext.putPartitionLeadershipInfo(partition, leaderIsrAndControllerEpoch) + finalLeaderIsrAndControllerEpoch = Some(leaderIsrAndControllerEpoch) + info(s"Updated leader epoch for partition $partition to ${leaderAndIsr.leaderEpoch}, zkVersion=${leaderAndIsr.zkVersion}") + true + case Some(Left(e)) => throw e + case None => false + } + case None => + throw new IllegalStateException(s"Cannot update leader epoch for partition $partition as " + + "leaderAndIsr path is empty. This could mean we somehow tried to reassign a partition that doesn't exist") + } + } + finalLeaderIsrAndControllerEpoch + } + + private def checkAndTriggerAutoLeaderRebalance(): Unit = { + trace("Checking need to trigger auto leader balancing") + val preferredReplicasForTopicsByBrokers: Map[Int, Map[TopicPartition, Seq[Int]]] = + controllerContext.allPartitions.filterNot { + tp => topicDeletionManager.isTopicQueuedUpForDeletion(tp.topic) + }.map { tp => + (tp, controllerContext.partitionReplicaAssignment(tp) ) + }.toMap.groupBy { case (_, assignedReplicas) => assignedReplicas.head } + + // for each broker, check if a preferred replica election needs to be triggered + preferredReplicasForTopicsByBrokers.forKeyValue { (leaderBroker, topicPartitionsForBroker) => + val topicsNotInPreferredReplica = topicPartitionsForBroker.filter { case (topicPartition, _) => + val leadershipInfo = controllerContext.partitionLeadershipInfo(topicPartition) + leadershipInfo.exists(_.leaderAndIsr.leader != leaderBroker) + } + debug(s"Topics not in preferred replica for broker $leaderBroker $topicsNotInPreferredReplica") + + val imbalanceRatio = topicsNotInPreferredReplica.size.toDouble / topicPartitionsForBroker.size + trace(s"Leader imbalance ratio for broker $leaderBroker is $imbalanceRatio") + + // check ratio and if greater than desired ratio, trigger a rebalance for the topic partitions + // that need to be on this broker + if (imbalanceRatio > (config.leaderImbalancePerBrokerPercentage.toDouble / 100)) { + // do this check only if the broker is live and there are no partitions being reassigned currently + // and preferred replica election is not in progress + val candidatePartitions = topicsNotInPreferredReplica.keys.filter(tp => + controllerContext.partitionsBeingReassigned.isEmpty && + !topicDeletionManager.isTopicQueuedUpForDeletion(tp.topic) && + controllerContext.allTopics.contains(tp.topic) && + canPreferredReplicaBeLeader(tp) + ) + onReplicaElection(candidatePartitions.toSet, ElectionType.PREFERRED, AutoTriggered) + } + } + } + + private def canPreferredReplicaBeLeader(tp: TopicPartition): Boolean = { + val assignment = controllerContext.partitionReplicaAssignment(tp) + val liveReplicas = assignment.filter(replica => controllerContext.isReplicaOnline(replica, tp)) + val isr = controllerContext.partitionLeadershipInfo(tp).get.leaderAndIsr.isr + PartitionLeaderElectionAlgorithms + 
.preferredReplicaPartitionLeaderElection(assignment, isr, liveReplicas.toSet) + .nonEmpty + } + + private def processAutoPreferredReplicaLeaderElection(): Unit = { + if (!isActive) return + try { + info("Processing automatic preferred replica leader election") + checkAndTriggerAutoLeaderRebalance() + } finally { + scheduleAutoLeaderRebalanceTask(delay = config.leaderImbalanceCheckIntervalSeconds, unit = TimeUnit.SECONDS) + } + } + + private def processUncleanLeaderElectionEnable(): Unit = { + if (!isActive) return + info("Unclean leader election has been enabled by default") + partitionStateMachine.triggerOnlinePartitionStateChange() + } + + private def processTopicUncleanLeaderElectionEnable(topic: String): Unit = { + if (!isActive) return + info(s"Unclean leader election has been enabled for topic $topic") + partitionStateMachine.triggerOnlinePartitionStateChange(topic) + } + + private def processControlledShutdown(id: Int, brokerEpoch: Long, controlledShutdownCallback: Try[Set[TopicPartition]] => Unit): Unit = { + val controlledShutdownResult = Try { doControlledShutdown(id, brokerEpoch) } + controlledShutdownCallback(controlledShutdownResult) + } + + private def doControlledShutdown(id: Int, brokerEpoch: Long): Set[TopicPartition] = { + if (!isActive) { + throw new ControllerMovedException("Controller moved to another broker. Aborting controlled shutdown") + } + + // broker epoch in the request is unknown if the controller hasn't been upgraded to use KIP-380 + // so we will keep the previous behavior and don't reject the request + if (brokerEpoch != AbstractControlRequest.UNKNOWN_BROKER_EPOCH) { + val cachedBrokerEpoch = controllerContext.liveBrokerIdAndEpochs(id) + if (brokerEpoch < cachedBrokerEpoch) { + val stateBrokerEpochErrorMessage = "Received controlled shutdown request from an old broker epoch " + + s"$brokerEpoch for broker $id. Current broker epoch is $cachedBrokerEpoch." 
+ info(stateBrokerEpochErrorMessage) + throw new StaleBrokerEpochException(stateBrokerEpochErrorMessage) + } + } + + info(s"Shutting down broker $id") + + if (!controllerContext.liveOrShuttingDownBrokerIds.contains(id)) + throw new BrokerNotAvailableException(s"Broker id $id does not exist.") + + controllerContext.shuttingDownBrokerIds.add(id) + debug(s"All shutting down brokers: ${controllerContext.shuttingDownBrokerIds.mkString(",")}") + debug(s"Live brokers: ${controllerContext.liveBrokerIds.mkString(",")}") + + val partitionsToActOn = controllerContext.partitionsOnBroker(id).filter { partition => + controllerContext.partitionReplicaAssignment(partition).size > 1 && + controllerContext.partitionLeadershipInfo(partition).isDefined && + !topicDeletionManager.isTopicQueuedUpForDeletion(partition.topic) + } + val (partitionsLedByBroker, partitionsFollowedByBroker) = partitionsToActOn.partition { partition => + controllerContext.partitionLeadershipInfo(partition).get.leaderAndIsr.leader == id + } + partitionStateMachine.handleStateChanges(partitionsLedByBroker.toSeq, OnlinePartition, Some(ControlledShutdownPartitionLeaderElectionStrategy)) + try { + brokerRequestBatch.newBatch() + partitionsFollowedByBroker.foreach { partition => + brokerRequestBatch.addStopReplicaRequestForBrokers(Seq(id), partition, deletePartition = false) + } + brokerRequestBatch.sendRequestsToBrokers(epoch) + } catch { + case e: IllegalStateException => + handleIllegalState(e) + } + // If the broker is a follower, updates the isr in ZK and notifies the current leader + replicaStateMachine.handleStateChanges(partitionsFollowedByBroker.map(partition => + PartitionAndReplica(partition, id)).toSeq, OfflineReplica) + trace(s"All leaders = ${controllerContext.partitionsLeadershipInfo.mkString(",")}") + controllerContext.partitionLeadersOnBroker(id) + } + + private def processUpdateMetadataResponseReceived(updateMetadataResponse: UpdateMetadataResponse, brokerId: Int): Unit = { + if (!isActive) return + + if (updateMetadataResponse.error != Errors.NONE) { + stateChangeLogger.error(s"Received error ${updateMetadataResponse.error} in UpdateMetadata " + + s"response $updateMetadataResponse from broker $brokerId") + } + } + + private def processLeaderAndIsrResponseReceived(leaderAndIsrResponse: LeaderAndIsrResponse, brokerId: Int): Unit = { + if (!isActive) return + + if (leaderAndIsrResponse.error != Errors.NONE) { + stateChangeLogger.error(s"Received error ${leaderAndIsrResponse.error} in LeaderAndIsr " + + s"response $leaderAndIsrResponse from broker $brokerId") + return + } + + val offlineReplicas = new ArrayBuffer[TopicPartition]() + val onlineReplicas = new ArrayBuffer[TopicPartition]() + + leaderAndIsrResponse.partitionErrors(controllerContext.topicNames.asJava).forEach{ case (tp, error) => + if (error.code() == Errors.KAFKA_STORAGE_ERROR.code) + offlineReplicas += tp + else if (error.code() == Errors.NONE.code) + onlineReplicas += tp + } + + val previousOfflineReplicas = controllerContext.replicasOnOfflineDirs.getOrElse(brokerId, Set.empty[TopicPartition]) + val currentOfflineReplicas = mutable.Set() ++= previousOfflineReplicas --= onlineReplicas ++= offlineReplicas + controllerContext.replicasOnOfflineDirs.put(brokerId, currentOfflineReplicas) + val newOfflineReplicas = currentOfflineReplicas.diff(previousOfflineReplicas) + + if (newOfflineReplicas.nonEmpty) { + stateChangeLogger.info(s"Mark replicas ${newOfflineReplicas.mkString(",")} on broker $brokerId as offline") + 
onReplicasBecomeOffline(newOfflineReplicas.map(PartitionAndReplica(_, brokerId))) + } + } + + private def processTopicDeletionStopReplicaResponseReceived(replicaId: Int, + requestError: Errors, + partitionErrors: Map[TopicPartition, Errors]): Unit = { + if (!isActive) return + debug(s"Delete topic callback invoked on StopReplica response received from broker $replicaId: " + + s"request error = $requestError, partition errors = $partitionErrors") + + val partitionsInError = if (requestError != Errors.NONE) + partitionErrors.keySet + else + partitionErrors.filter { case (_, error) => error != Errors.NONE }.keySet + + val replicasInError = partitionsInError.map(PartitionAndReplica(_, replicaId)) + // move all the failed replicas to ReplicaDeletionIneligible + topicDeletionManager.failReplicaDeletion(replicasInError) + if (replicasInError.size != partitionErrors.size) { + // some replicas could have been successfully deleted + val deletedReplicas = partitionErrors.keySet.diff(partitionsInError) + topicDeletionManager.completeReplicaDeletion(deletedReplicas.map(PartitionAndReplica(_, replicaId))) + } + } + + private def processStartup(): Unit = { + zkClient.registerZNodeChangeHandlerAndCheckExistence(controllerChangeHandler) + elect() + } + + private def updateMetrics(): Unit = { + offlinePartitionCount = + if (!isActive) { + 0 + } else { + controllerContext.offlinePartitionCount + } + + preferredReplicaImbalanceCount = + if (!isActive) { + 0 + } else { + controllerContext.preferredReplicaImbalanceCount + } + + globalTopicCount = if (!isActive) 0 else controllerContext.allTopics.size + + globalPartitionCount = if (!isActive) 0 else controllerContext.partitionWithLeadersCount + + topicsToDeleteCount = if (!isActive) 0 else controllerContext.topicsToBeDeleted.size + + replicasToDeleteCount = if (!isActive) 0 else controllerContext.topicsToBeDeleted.map { topic => + // For each enqueued topic, count the number of replicas that are not yet deleted + controllerContext.replicasForTopic(topic).count { replica => + controllerContext.replicaState(replica) != ReplicaDeletionSuccessful + } + }.sum + + ineligibleTopicsToDeleteCount = if (!isActive) 0 else controllerContext.topicsIneligibleForDeletion.size + + ineligibleReplicasToDeleteCount = if (!isActive) 0 else controllerContext.topicsToBeDeleted.map { topic => + // For each enqueued topic, count the number of replicas that are ineligible + controllerContext.replicasForTopic(topic).count { replica => + controllerContext.replicaState(replica) == ReplicaDeletionIneligible + } + }.sum + + activeBrokerCount = if (isActive) controllerContext.liveOrShuttingDownBrokerIds.size else 0 + } + + // visible for testing + private[controller] def handleIllegalState(e: IllegalStateException): Nothing = { + // Resign if the controller is in an illegal state + error("Forcing the controller to resign") + brokerRequestBatch.clear() + triggerControllerMove() + throw e + } + + private def triggerControllerMove(): Unit = { + activeControllerId = zkClient.getControllerId.getOrElse(-1) + if (!isActive) { + warn("Controller has already moved when trying to trigger controller movement") + return + } + try { + val expectedControllerEpochZkVersion = controllerContext.epochZkVersion + activeControllerId = -1 + onControllerResignation() + zkClient.deleteController(expectedControllerEpochZkVersion) + } catch { + case _: ControllerMovedException => + warn("Controller has already moved when trying to trigger controller movement") + } + } + + private def maybeResign(): Unit = { + val 
wasActiveBeforeChange = isActive + zkClient.registerZNodeChangeHandlerAndCheckExistence(controllerChangeHandler) + activeControllerId = zkClient.getControllerId.getOrElse(-1) + if (wasActiveBeforeChange && !isActive) { + onControllerResignation() + } + } + + private def elect(): Unit = { + activeControllerId = zkClient.getControllerId.getOrElse(-1) + /* + * We can get here during the initial startup and the handleDeleted ZK callback. Because of the potential race condition, + * it's possible that the controller has already been elected when we get here. This check will prevent the following + * createEphemeralPath method from getting into an infinite loop if this broker is already the controller. + */ + if (activeControllerId != -1) { + debug(s"Broker $activeControllerId has been elected as the controller, so stopping the election process.") + return + } + + try { + val (epoch, epochZkVersion) = zkClient.registerControllerAndIncrementControllerEpoch(config.brokerId) + controllerContext.epoch = epoch + controllerContext.epochZkVersion = epochZkVersion + activeControllerId = config.brokerId + + info(s"${config.brokerId} successfully elected as the controller. Epoch incremented to ${controllerContext.epoch} " + + s"and epoch zk version is now ${controllerContext.epochZkVersion}") + + onControllerFailover() + } catch { + case e: ControllerMovedException => + maybeResign() + + if (activeControllerId != -1) + debug(s"Broker $activeControllerId was elected as controller instead of broker ${config.brokerId}", e) + else + warn("A controller has been elected but just resigned, this will result in another round of election", e) + case t: Throwable => + error(s"Error while electing or becoming controller on broker ${config.brokerId}. " + + s"Trigger controller movement immediately", t) + triggerControllerMove() + } + } + + /** + * Partitions the provided map of brokers and epochs into 2 new maps: + * - The first map contains only those brokers whose features were found to be compatible with + * the existing finalized features. + * - The second map contains only those brokers whose features were found to be incompatible with + * the existing finalized features. + * + * @param brokersAndEpochs the map to be partitioned + * @return two maps: first contains compatible brokers and second contains + * incompatible brokers as explained above + */ + private def partitionOnFeatureCompatibility(brokersAndEpochs: Map[Broker, Long]): (Map[Broker, Long], Map[Broker, Long]) = { + // There can not be any feature incompatibilities when the feature versioning system is disabled + // or when the finalized feature cache is empty. Otherwise, we check if the non-empty contents + // of the cache are compatible with the supported features of each broker. 
+ brokersAndEpochs.partition { + case (broker, _) => + !config.isFeatureVersioningSupported || + !featureCache.get.exists( + latestFinalizedFeatures => + BrokerFeatures.hasIncompatibleFeatures(broker.features, latestFinalizedFeatures.features)) + } + } + + private def processBrokerChange(): Unit = { + if (!isActive) return + val curBrokerAndEpochs = zkClient.getAllBrokerAndEpochsInCluster + val curBrokerIdAndEpochs = curBrokerAndEpochs map { case (broker, epoch) => (broker.id, epoch) } + val curBrokerIds = curBrokerIdAndEpochs.keySet + val liveOrShuttingDownBrokerIds = controllerContext.liveOrShuttingDownBrokerIds + val newBrokerIds = curBrokerIds.diff(liveOrShuttingDownBrokerIds) + val deadBrokerIds = liveOrShuttingDownBrokerIds.diff(curBrokerIds) + val bouncedBrokerIds = (curBrokerIds & liveOrShuttingDownBrokerIds) + .filter(brokerId => curBrokerIdAndEpochs(brokerId) > controllerContext.liveBrokerIdAndEpochs(brokerId)) + val newBrokerAndEpochs = curBrokerAndEpochs.filter { case (broker, _) => newBrokerIds.contains(broker.id) } + val bouncedBrokerAndEpochs = curBrokerAndEpochs.filter { case (broker, _) => bouncedBrokerIds.contains(broker.id) } + val newBrokerIdsSorted = newBrokerIds.toSeq.sorted + val deadBrokerIdsSorted = deadBrokerIds.toSeq.sorted + val liveBrokerIdsSorted = curBrokerIds.toSeq.sorted + val bouncedBrokerIdsSorted = bouncedBrokerIds.toSeq.sorted + info(s"Newly added brokers: ${newBrokerIdsSorted.mkString(",")}, " + + s"deleted brokers: ${deadBrokerIdsSorted.mkString(",")}, " + + s"bounced brokers: ${bouncedBrokerIdsSorted.mkString(",")}, " + + s"all live brokers: ${liveBrokerIdsSorted.mkString(",")}") + + newBrokerAndEpochs.keySet.foreach(controllerChannelManager.addBroker) + bouncedBrokerIds.foreach(controllerChannelManager.removeBroker) + bouncedBrokerAndEpochs.keySet.foreach(controllerChannelManager.addBroker) + deadBrokerIds.foreach(controllerChannelManager.removeBroker) + + if (newBrokerIds.nonEmpty) { + val (newCompatibleBrokerAndEpochs, newIncompatibleBrokerAndEpochs) = + partitionOnFeatureCompatibility(newBrokerAndEpochs) + if (!newIncompatibleBrokerAndEpochs.isEmpty) { + warn("Ignoring registration of new brokers due to incompatibilities with finalized features: " + + newIncompatibleBrokerAndEpochs.map { case (broker, _) => broker.id }.toSeq.sorted.mkString(",")) + } + controllerContext.addLiveBrokers(newCompatibleBrokerAndEpochs) + onBrokerStartup(newBrokerIdsSorted) + } + if (bouncedBrokerIds.nonEmpty) { + controllerContext.removeLiveBrokers(bouncedBrokerIds) + onBrokerFailure(bouncedBrokerIdsSorted) + val (bouncedCompatibleBrokerAndEpochs, bouncedIncompatibleBrokerAndEpochs) = + partitionOnFeatureCompatibility(bouncedBrokerAndEpochs) + if (!bouncedIncompatibleBrokerAndEpochs.isEmpty) { + warn("Ignoring registration of bounced brokers due to incompatibilities with finalized features: " + + bouncedIncompatibleBrokerAndEpochs.map { case (broker, _) => broker.id }.toSeq.sorted.mkString(",")) + } + controllerContext.addLiveBrokers(bouncedCompatibleBrokerAndEpochs) + onBrokerStartup(bouncedBrokerIdsSorted) + } + if (deadBrokerIds.nonEmpty) { + controllerContext.removeLiveBrokers(deadBrokerIds) + onBrokerFailure(deadBrokerIdsSorted) + } + + if (newBrokerIds.nonEmpty || deadBrokerIds.nonEmpty || bouncedBrokerIds.nonEmpty) { + info(s"Updated broker epochs cache: ${controllerContext.liveBrokerIdAndEpochs}") + } + } + + private def processBrokerModification(brokerId: Int): Unit = { + if (!isActive) return + val newMetadataOpt = zkClient.getBroker(brokerId) + val 
oldMetadataOpt = controllerContext.liveOrShuttingDownBroker(brokerId) + if (newMetadataOpt.nonEmpty && oldMetadataOpt.nonEmpty) { + val oldMetadata = oldMetadataOpt.get + val newMetadata = newMetadataOpt.get + if (newMetadata.endPoints != oldMetadata.endPoints || !oldMetadata.features.equals(newMetadata.features)) { + info(s"Updated broker metadata: $oldMetadata -> $newMetadata") + controllerContext.updateBrokerMetadata(oldMetadata, newMetadata) + onBrokerUpdate(brokerId) + } + } + } + + private def processTopicChange(): Unit = { + if (!isActive) return + val topics = zkClient.getAllTopicsInCluster(true) + val newTopics = topics -- controllerContext.allTopics + val deletedTopics = controllerContext.allTopics.diff(topics) + controllerContext.setAllTopics(topics) + + registerPartitionModificationsHandlers(newTopics.toSeq) + val addedPartitionReplicaAssignment = zkClient.getReplicaAssignmentAndTopicIdForTopics(newTopics) + deletedTopics.foreach(controllerContext.removeTopic) + processTopicIds(addedPartitionReplicaAssignment) + + addedPartitionReplicaAssignment.foreach { case TopicIdReplicaAssignment(_, _, newAssignments) => + newAssignments.foreach { case (topicAndPartition, newReplicaAssignment) => + controllerContext.updatePartitionFullReplicaAssignment(topicAndPartition, newReplicaAssignment) + } + } + info(s"New topics: [$newTopics], deleted topics: [$deletedTopics], new partition replica assignment " + + s"[$addedPartitionReplicaAssignment]") + if (addedPartitionReplicaAssignment.nonEmpty) { + val partitionAssignments = addedPartitionReplicaAssignment + .map { case TopicIdReplicaAssignment(_, _, partitionsReplicas) => partitionsReplicas.keySet } + .reduce((s1, s2) => s1.union(s2)) + onNewPartitionCreation(partitionAssignments) + } + } + + private def processTopicIds(topicIdAssignments: Set[TopicIdReplicaAssignment]): Unit = { + // Create topic IDs for topics missing them if we are using topic IDs + // Otherwise, maintain what we have in the topicZNode + val updatedTopicIdAssignments = if (config.usesTopicId) { + val (withTopicIds, withoutTopicIds) = topicIdAssignments.partition(_.topicId.isDefined) + withTopicIds ++ zkClient.setTopicIds(withoutTopicIds, controllerContext.epochZkVersion) + } else { + topicIdAssignments + } + + // Add topic IDs to controller context + // If we don't have IBP 2.8, but are running 2.8 code, put any topic IDs from the ZNode in controller context + // This is to avoid losing topic IDs during operations like partition reassignments while the cluster is in a mixed state + updatedTopicIdAssignments.foreach { topicIdAssignment => + topicIdAssignment.topicId.foreach { topicId => + controllerContext.addTopicId(topicIdAssignment.topic, topicId) + } + } + } + + private def processLogDirEventNotification(): Unit = { + if (!isActive) return + val sequenceNumbers = zkClient.getAllLogDirEventNotifications + try { + val brokerIds = zkClient.getBrokerIdsFromLogDirEvents(sequenceNumbers) + onBrokerLogDirFailure(brokerIds) + } finally { + // delete processed children + zkClient.deleteLogDirEventNotifications(sequenceNumbers, controllerContext.epochZkVersion) + } + } + + private def processPartitionModifications(topic: String): Unit = { + def restorePartitionReplicaAssignment( + topic: String, + newPartitionReplicaAssignment: Map[TopicPartition, ReplicaAssignment] + ): Unit = { + info("Restoring the partition replica assignment for topic %s".format(topic)) + + val existingPartitions = zkClient.getChildren(TopicPartitionsZNode.path(topic)) + val 
existingPartitionReplicaAssignment = newPartitionReplicaAssignment + .filter(p => existingPartitions.contains(p._1.partition.toString)) + .map { case (tp, _) => + tp -> controllerContext.partitionFullReplicaAssignment(tp) + }.toMap + + zkClient.setTopicAssignment(topic, + controllerContext.topicIds.get(topic), + existingPartitionReplicaAssignment, + controllerContext.epochZkVersion) + } + + if (!isActive) return + val partitionReplicaAssignment = zkClient.getFullReplicaAssignmentForTopics(immutable.Set(topic)) + val partitionsToBeAdded = partitionReplicaAssignment.filter { case (topicPartition, _) => + controllerContext.partitionReplicaAssignment(topicPartition).isEmpty + } + + if (topicDeletionManager.isTopicQueuedUpForDeletion(topic)) { + if (partitionsToBeAdded.nonEmpty) { + warn("Skipping adding partitions %s for topic %s since it is currently being deleted" + .format(partitionsToBeAdded.map(_._1.partition).mkString(","), topic)) + + restorePartitionReplicaAssignment(topic, partitionReplicaAssignment) + } else { + // This can happen if the existing partition replica assignment is restored to prevent increasing partition count during topic deletion + info("Ignoring partition change during topic deletion as no new partitions are added") + } + } else if (partitionsToBeAdded.nonEmpty) { + info(s"New partitions to be added $partitionsToBeAdded") + partitionsToBeAdded.forKeyValue { (topicPartition, assignedReplicas) => + controllerContext.updatePartitionFullReplicaAssignment(topicPartition, assignedReplicas) + } + onNewPartitionCreation(partitionsToBeAdded.keySet) + } + } + + private def processTopicDeletion(): Unit = { + if (!isActive) return + var topicsToBeDeleted = zkClient.getTopicDeletions.toSet + debug(s"Delete topics listener fired for topics ${topicsToBeDeleted.mkString(",")} to be deleted") + val nonExistentTopics = topicsToBeDeleted -- controllerContext.allTopics + if (nonExistentTopics.nonEmpty) { + warn(s"Ignoring request to delete non-existing topics ${nonExistentTopics.mkString(",")}") + zkClient.deleteTopicDeletions(nonExistentTopics.toSeq, controllerContext.epochZkVersion) + } + topicsToBeDeleted --= nonExistentTopics + if (config.deleteTopicEnable) { + if (topicsToBeDeleted.nonEmpty) { + info(s"Starting topic deletion for topics ${topicsToBeDeleted.mkString(",")}") + // mark topic ineligible for deletion if other state changes are in progress + topicsToBeDeleted.foreach { topic => + val partitionReassignmentInProgress = + controllerContext.partitionsBeingReassigned.map(_.topic).contains(topic) + if (partitionReassignmentInProgress) + topicDeletionManager.markTopicIneligibleForDeletion(Set(topic), + reason = "topic reassignment in progress") + } + // add topic to deletion list + topicDeletionManager.enqueueTopicsForDeletion(topicsToBeDeleted) + } + } else { + // If delete topic is disabled, remove entries under zookeeper path : /admin/delete_topics + info(s"Removing $topicsToBeDeleted since delete topic is disabled") + zkClient.deleteTopicDeletions(topicsToBeDeleted.toSeq, controllerContext.epochZkVersion) + } + } + + private def processZkPartitionReassignment(): Set[TopicPartition] = { + // We need to register the watcher if the path doesn't exist in order to detect future + // reassignments and we get the `path exists` check for free + if (isActive && zkClient.registerZNodeChangeHandlerAndCheckExistence(partitionReassignmentHandler)) { + val reassignmentResults = mutable.Map.empty[TopicPartition, ApiError] + val partitionsToReassign = mutable.Map.empty[TopicPartition, 
ReplicaAssignment] + + zkClient.getPartitionReassignment.forKeyValue { (tp, targetReplicas) => + maybeBuildReassignment(tp, Some(targetReplicas)) match { + case Some(context) => partitionsToReassign.put(tp, context) + case None => reassignmentResults.put(tp, new ApiError(Errors.NO_REASSIGNMENT_IN_PROGRESS)) + } + } + + reassignmentResults ++= maybeTriggerPartitionReassignment(partitionsToReassign) + val (partitionsReassigned, partitionsFailed) = reassignmentResults.partition(_._2.error == Errors.NONE) + if (partitionsFailed.nonEmpty) { + warn(s"Failed reassignment through zk with the following errors: $partitionsFailed") + maybeRemoveFromZkReassignment((tp, _) => partitionsFailed.contains(tp)) + } + partitionsReassigned.keySet + } else { + Set.empty + } + } + + /** + * Process a partition reassignment from the AlterPartitionReassignment API. If there is an + * existing reassignment through zookeeper for any of the requested partitions, they will be + * cancelled prior to beginning the new reassignment. Any zk-based reassignment for partitions + * which are NOT included in this call will not be affected. + * + * @param reassignments Map of reassignments passed through the AlterReassignments API. A null value + * means that we should cancel an in-progress reassignment. + * @param callback Callback to send AlterReassignments response + */ + private def processApiPartitionReassignment(reassignments: Map[TopicPartition, Option[Seq[Int]]], + callback: AlterReassignmentsCallback): Unit = { + if (!isActive) { + callback(Right(new ApiError(Errors.NOT_CONTROLLER))) + } else { + val reassignmentResults = mutable.Map.empty[TopicPartition, ApiError] + val partitionsToReassign = mutable.Map.empty[TopicPartition, ReplicaAssignment] + + reassignments.forKeyValue { (tp, targetReplicas) => + val maybeApiError = targetReplicas.flatMap(validateReplicas(tp, _)) + maybeApiError match { + case None => + maybeBuildReassignment(tp, targetReplicas) match { + case Some(context) => partitionsToReassign.put(tp, context) + case None => reassignmentResults.put(tp, new ApiError(Errors.NO_REASSIGNMENT_IN_PROGRESS)) + } + case Some(err) => + reassignmentResults.put(tp, err) + } + } + + // The latest reassignment (whether by API or through zk) always takes precedence, + // so remove from active zk reassignment (if one exists) + maybeRemoveFromZkReassignment((tp, _) => partitionsToReassign.contains(tp)) + + reassignmentResults ++= maybeTriggerPartitionReassignment(partitionsToReassign) + callback(Left(reassignmentResults)) + } + } + + private def validateReplicas(topicPartition: TopicPartition, replicas: Seq[Int]): Option[ApiError] = { + val replicaSet = replicas.toSet + if (replicas.isEmpty) + Some(new ApiError(Errors.INVALID_REPLICA_ASSIGNMENT, + s"Empty replica list specified in partition reassignment.")) + else if (replicas.size != replicaSet.size) { + Some(new ApiError(Errors.INVALID_REPLICA_ASSIGNMENT, + s"Duplicate replica ids in partition reassignment replica list: $replicas")) + } else if (replicas.exists(_ < 0)) + Some(new ApiError(Errors.INVALID_REPLICA_ASSIGNMENT, + s"Invalid broker id in replica list: $replicas")) + else { + // Ensure that any new replicas are among the live brokers + val currentAssignment = controllerContext.partitionFullReplicaAssignment(topicPartition) + val newAssignment = currentAssignment.reassignTo(replicas) + val areNewReplicasAlive = newAssignment.addingReplicas.toSet.subsetOf(controllerContext.liveBrokerIds) + if (!areNewReplicasAlive) + Some(new 
ApiError(Errors.INVALID_REPLICA_ASSIGNMENT, + s"Replica assignment has brokers that are not alive. Replica list: " + + s"${newAssignment.addingReplicas}, live broker list: ${controllerContext.liveBrokerIds}")) + else None + } + } + + private def maybeBuildReassignment(topicPartition: TopicPartition, + targetReplicasOpt: Option[Seq[Int]]): Option[ReplicaAssignment] = { + val replicaAssignment = controllerContext.partitionFullReplicaAssignment(topicPartition) + if (replicaAssignment.isBeingReassigned) { + val targetReplicas = targetReplicasOpt.getOrElse(replicaAssignment.originReplicas) + Some(replicaAssignment.reassignTo(targetReplicas)) + } else { + targetReplicasOpt.map { targetReplicas => + replicaAssignment.reassignTo(targetReplicas) + } + } + } + + private def processPartitionReassignmentIsrChange(topicPartition: TopicPartition): Unit = { + if (!isActive) return + + if (controllerContext.partitionsBeingReassigned.contains(topicPartition)) { + maybeCompleteReassignment(topicPartition) + } + } + + private def maybeCompleteReassignment(topicPartition: TopicPartition): Unit = { + val reassignment = controllerContext.partitionFullReplicaAssignment(topicPartition) + if (isReassignmentComplete(topicPartition, reassignment)) { + // resume the partition reassignment process + info(s"Target replicas ${reassignment.targetReplicas} have all caught up with the leader for " + + s"reassigning partition $topicPartition") + onPartitionReassignment(topicPartition, reassignment) + } + } + + private def processListPartitionReassignments(partitionsOpt: Option[Set[TopicPartition]], callback: ListReassignmentsCallback): Unit = { + if (!isActive) { + callback(Right(new ApiError(Errors.NOT_CONTROLLER))) + } else { + val results: mutable.Map[TopicPartition, ReplicaAssignment] = mutable.Map.empty + val partitionsToList = partitionsOpt match { + case Some(partitions) => partitions + case None => controllerContext.partitionsBeingReassigned + } + + partitionsToList.foreach { tp => + val assignment = controllerContext.partitionFullReplicaAssignment(tp) + if (assignment.isBeingReassigned) { + results += tp -> assignment + } + } + + callback(Left(results)) + } + } + + /** + * Returns the new FinalizedVersionRange for the feature, if there are no feature + * incompatibilities seen with all known brokers for the provided feature update. + * Otherwise returns an ApiError object containing Errors.INVALID_REQUEST. + * + * @param update the feature update to be processed (this can not be meant to delete the feature) + * + * @return the new FinalizedVersionRange or error, as described above. + */ + private def newFinalizedVersionRangeOrIncompatibilityError(update: UpdateFeaturesRequestData.FeatureUpdateKey): Either[FinalizedVersionRange, ApiError] = { + if (UpdateFeaturesRequest.isDeleteRequest(update)) { + throw new IllegalArgumentException(s"Provided feature update can not be meant to delete the feature: $update") + } + + val supportedVersionRange = brokerFeatures.supportedFeatures.get(update.feature) + if (supportedVersionRange == null) { + Right(new ApiError(Errors.INVALID_REQUEST, + "Could not apply finalized feature update because the provided feature" + + " is not supported.")) + } else { + var newVersionRange: FinalizedVersionRange = null + try { + newVersionRange = new FinalizedVersionRange(supportedVersionRange.min, update.maxVersionLevel) + } catch { + case _: IllegalArgumentException => { + // This exception means the provided maxVersionLevel is invalid. It is handled below + // outside of this catch clause. 
+ } + } + if (newVersionRange == null) { + Right(new ApiError(Errors.INVALID_REQUEST, + "Could not apply finalized feature update because the provided" + + s" maxVersionLevel:${update.maxVersionLevel} is lower than the" + + s" supported minVersion:${supportedVersionRange.min}.")) + } else { + val newFinalizedFeature = + Features.finalizedFeatures(Utils.mkMap(Utils.mkEntry(update.feature, newVersionRange))) + val numIncompatibleBrokers = controllerContext.liveOrShuttingDownBrokers.count(broker => { + BrokerFeatures.hasIncompatibleFeatures(broker.features, newFinalizedFeature) + }) + if (numIncompatibleBrokers == 0) { + Left(newVersionRange) + } else { + Right(new ApiError(Errors.INVALID_REQUEST, + "Could not apply finalized feature update because" + + " brokers were found to have incompatible versions for the feature.")) + } + } + } + + /** + * Validates a feature update on an existing FinalizedVersionRange. + * If the validation succeeds, then the return value contains: + * 1. the new FinalizedVersionRange for the feature, if the feature update was not meant to delete the feature. + * 2. Option.empty, if the feature update was meant to delete the feature. + * + * If the validation fails, then the returned value contains a suitable ApiError. + * + * @param update the feature update to be processed. + * @param existingVersionRange the existing FinalizedVersionRange which can be empty when no + * FinalizedVersionRange exists for the associated feature + * + * @return the new FinalizedVersionRange to be updated into ZK or error + * as described above. + */ + private def validateFeatureUpdate(update: UpdateFeaturesRequestData.FeatureUpdateKey, + existingVersionRange: Option[FinalizedVersionRange]): Either[Option[FinalizedVersionRange], ApiError] = { + def newVersionRangeOrError(update: UpdateFeaturesRequestData.FeatureUpdateKey): Either[Option[FinalizedVersionRange], ApiError] = { + newFinalizedVersionRangeOrIncompatibilityError(update) + .fold(versionRange => Left(Some(versionRange)), error => Right(error)) + } + + if (update.feature.isEmpty) { + // Check that the feature name is not empty. + Right(new ApiError(Errors.INVALID_REQUEST, "Feature name can not be empty.")) + } else { + // We handle deletion requests separately from non-deletion requests. + if (UpdateFeaturesRequest.isDeleteRequest(update)) { + if (existingVersionRange.isEmpty) { + // Disallow deletion of a non-existing finalized feature. + Right(new ApiError(Errors.INVALID_REQUEST, + "Can not delete non-existing finalized feature.")) + } else { + Left(Option.empty) + } + } else if (update.maxVersionLevel() < 1) { + // Disallow deletion of a finalized feature without allowDowngrade flag set. + Right(new ApiError(Errors.INVALID_REQUEST, + s"Can not provide maxVersionLevel: ${update.maxVersionLevel} less" + + s" than 1 without setting the allowDowngrade flag to true in the request.")) + } else { + existingVersionRange.map(existing => + if (update.maxVersionLevel == existing.max) { + // Disallow a case where target maxVersionLevel matches existing maxVersionLevel. + Right(new ApiError(Errors.INVALID_REQUEST, + s"Can not ${if (update.allowDowngrade) "downgrade" else "upgrade"}" + + s" a finalized feature from existing maxVersionLevel:${existing.max}" + + " to the same value.")) + } else if (update.maxVersionLevel < existing.max && !update.allowDowngrade) { + // Disallow downgrade of a finalized feature without the allowDowngrade flag set. 
+ Right(new ApiError(Errors.INVALID_REQUEST, + s"Can not downgrade finalized feature from existing" + + s" maxVersionLevel:${existing.max} to provided" + + s" maxVersionLevel:${update.maxVersionLevel} without setting the" + + " allowDowngrade flag in the request.")) + } else if (update.allowDowngrade && update.maxVersionLevel > existing.max) { + // Disallow a request that sets the allowDowngrade flag without specifying a + // maxVersionLevel that's lower than the existing maxVersionLevel. + Right(new ApiError(Errors.INVALID_REQUEST, + s"When the allowDowngrade flag is set in the request, the provided" + + s" maxVersionLevel:${update.maxVersionLevel} can not be greater than" + + s" existing maxVersionLevel:${existing.max}.")) + } else if (update.maxVersionLevel < existing.min) { + // Disallow downgrade of a finalized feature below the existing finalized + // minVersionLevel. + Right(new ApiError(Errors.INVALID_REQUEST, + s"Can not downgrade finalized feature to maxVersionLevel:${update.maxVersionLevel}" + + s" because it's lower than the existing minVersionLevel:${existing.min}.")) + } else { + newVersionRangeOrError(update) + } + ).getOrElse(newVersionRangeOrError(update)) + } + } + } + + private def processFeatureUpdates(request: UpdateFeaturesRequest, + callback: UpdateFeaturesCallback): Unit = { + if (isActive) { + processFeatureUpdatesWithActiveController(request, callback) + } else { + callback(Left(new ApiError(Errors.NOT_CONTROLLER))) + } + } + + private def processFeatureUpdatesWithActiveController(request: UpdateFeaturesRequest, + callback: UpdateFeaturesCallback): Unit = { + val updates = request.data.featureUpdates + val existingFeatures = featureCache.get + .map(featuresAndEpoch => featuresAndEpoch.features.features().asScala) + .getOrElse(Map[String, FinalizedVersionRange]()) + // A map with key being feature name and value being FinalizedVersionRange. + // This contains the target features to be eventually written to FeatureZNode. + val targetFeatures = scala.collection.mutable.Map[String, FinalizedVersionRange]() ++ existingFeatures + // A map with key being feature name and value being error encountered when the FeatureUpdate + // was applied. + val errors = scala.collection.mutable.Map[String, ApiError]() + + // Below we process each FeatureUpdate using the following logic: + // - If a FeatureUpdate is found to be valid, then: + // - The corresponding entry in errors map would be updated to contain Errors.NONE. + // - If the FeatureUpdate is an add or update request, then the targetFeatures map is updated + // to contain the new FinalizedVersionRange for the feature. + // - Otherwise if the FeatureUpdate is a delete request, then the feature is removed from the + // targetFeatures map. + // - Otherwise if a FeatureUpdate is found to be invalid, then: + // - The corresponding entry in errors map would be updated with the appropriate ApiError. + // - The entry in targetFeatures map is left untouched. 
+ updates.asScala.iterator.foreach { update => + validateFeatureUpdate(update, existingFeatures.get(update.feature())) match { + case Left(newVersionRangeOrNone) => + newVersionRangeOrNone match { + case Some(newVersionRange) => targetFeatures += (update.feature() -> newVersionRange) + case None => targetFeatures -= update.feature() + } + errors += (update.feature() -> new ApiError(Errors.NONE)) + case Right(featureUpdateFailureReason) => + errors += (update.feature() -> featureUpdateFailureReason) + } + } + + // If the existing and target features are the same, then, we skip the update to the + // FeatureZNode as no changes to the node are required. Otherwise, we replace the contents + // of the FeatureZNode with the new features. This may result in partial or full modification + // of the existing finalized features in ZK. + try { + if (!existingFeatures.equals(targetFeatures)) { + val newNode = new FeatureZNode(FeatureZNodeStatus.Enabled, Features.finalizedFeatures(targetFeatures.asJava)) + val newVersion = updateFeatureZNode(newNode) + featureCache.waitUntilEpochOrThrow(newVersion, request.data().timeoutMs()) + } + } catch { + // For all features that correspond to valid FeatureUpdate (i.e. error is Errors.NONE), + // we set the error as Errors.FEATURE_UPDATE_FAILED since the FeatureZNode update has failed + // for these. For the rest, the existing error is left untouched. + case e: Exception => + warn(s"Processing of feature updates: $request failed due to error: $e") + errors.foreach { case (feature, apiError) => + if (apiError.error() == Errors.NONE) { + errors(feature) = new ApiError(Errors.FEATURE_UPDATE_FAILED) + } + } + } finally { + callback(Right(errors)) + } + } + + private def processIsrChangeNotification(): Unit = { + def processUpdateNotifications(partitions: Seq[TopicPartition]): Unit = { + val liveBrokers: Seq[Int] = controllerContext.liveOrShuttingDownBrokerIds.toSeq + debug(s"Sending MetadataRequest to Brokers: $liveBrokers for TopicPartitions: $partitions") + sendUpdateMetadataRequest(liveBrokers, partitions.toSet) + } + + if (!isActive) return + val sequenceNumbers = zkClient.getAllIsrChangeNotifications + try { + val partitions = zkClient.getPartitionsFromIsrChangeNotifications(sequenceNumbers) + if (partitions.nonEmpty) { + updateLeaderAndIsrCache(partitions) + processUpdateNotifications(partitions) + + // During a partial upgrade, the controller may be on an IBP which assumes + // ISR changes through the `AlterIsr` API while some brokers are on an older + // IBP which assumes notification through Zookeeper. In this case, since the + // controller will not have registered watches for reassigning partitions, we + // can still rely on the batch ISR change notification path in order to + // complete the reassignment. 
+ partitions.filter(controllerContext.partitionsBeingReassigned.contains).foreach { topicPartition => + maybeCompleteReassignment(topicPartition) + } + } + } finally { + // delete the notifications + zkClient.deleteIsrChangeNotifications(sequenceNumbers, controllerContext.epochZkVersion) + } + } + + def electLeaders( + partitions: Set[TopicPartition], + electionType: ElectionType, + callback: ElectLeadersCallback + ): Unit = { + eventManager.put(ReplicaLeaderElection(Some(partitions), electionType, AdminClientTriggered, callback)) + } + + def listPartitionReassignments(partitions: Option[Set[TopicPartition]], + callback: ListReassignmentsCallback): Unit = { + eventManager.put(ListPartitionReassignments(partitions, callback)) + } + + def updateFeatures(request: UpdateFeaturesRequest, + callback: UpdateFeaturesCallback): Unit = { + eventManager.put(UpdateFeatures(request, callback)) + } + + def alterPartitionReassignments(partitions: Map[TopicPartition, Option[Seq[Int]]], + callback: AlterReassignmentsCallback): Unit = { + eventManager.put(ApiPartitionReassignment(partitions, callback)) + } + + private def processReplicaLeaderElection( + partitionsFromAdminClientOpt: Option[Set[TopicPartition]], + electionType: ElectionType, + electionTrigger: ElectionTrigger, + callback: ElectLeadersCallback + ): Unit = { + if (!isActive) { + callback(partitionsFromAdminClientOpt.fold(Map.empty[TopicPartition, Either[ApiError, Int]]) { partitions => + partitions.iterator.map(partition => partition -> Left(new ApiError(Errors.NOT_CONTROLLER, null))).toMap + }) + } else { + // We need to register the watcher if the path doesn't exist in order to detect future preferred replica + // leader elections and we get the `path exists` check for free + if (electionTrigger == AdminClientTriggered || zkClient.registerZNodeChangeHandlerAndCheckExistence(preferredReplicaElectionHandler)) { + val partitions = partitionsFromAdminClientOpt match { + case Some(partitions) => partitions + case None => zkClient.getPreferredReplicaElection + } + + val allPartitions = controllerContext.allPartitions + val (knownPartitions, unknownPartitions) = partitions.partition(tp => allPartitions.contains(tp)) + unknownPartitions.foreach { p => + info(s"Skipping replica leader election ($electionType) for partition $p by $electionTrigger since it doesn't exist.") + } + + val (partitionsBeingDeleted, livePartitions) = knownPartitions.partition(partition => + topicDeletionManager.isTopicQueuedUpForDeletion(partition.topic)) + if (partitionsBeingDeleted.nonEmpty) { + warn(s"Skipping replica leader election ($electionType) for partitions $partitionsBeingDeleted " + + s"by $electionTrigger since the respective topics are being deleted") + } + + // partition those that have a valid leader + val (electablePartitions, alreadyValidLeader) = livePartitions.partition { partition => + electionType match { + case ElectionType.PREFERRED => + val assignedReplicas = controllerContext.partitionReplicaAssignment(partition) + val preferredReplica = assignedReplicas.head + val currentLeader = controllerContext.partitionLeadershipInfo(partition).get.leaderAndIsr.leader + currentLeader != preferredReplica + + case ElectionType.UNCLEAN => + val currentLeader = controllerContext.partitionLeadershipInfo(partition).get.leaderAndIsr.leader + currentLeader == LeaderAndIsr.NoLeader || !controllerContext.liveBrokerIds.contains(currentLeader) + } + } + + val results = onReplicaElection(electablePartitions, electionType, electionTrigger).map { + case (k, Left(ex)) => + if 
(ex.isInstanceOf[StateChangeFailedException]) { + val error = if (electionType == ElectionType.PREFERRED) { + Errors.PREFERRED_LEADER_NOT_AVAILABLE + } else { + Errors.ELIGIBLE_LEADERS_NOT_AVAILABLE + } + k -> Left(new ApiError(error, ex.getMessage)) + } else { + k -> Left(ApiError.fromThrowable(ex)) + } + case (k, Right(leaderAndIsr)) => k -> Right(leaderAndIsr.leader) + } ++ + alreadyValidLeader.map(_ -> Left(new ApiError(Errors.ELECTION_NOT_NEEDED))) ++ + partitionsBeingDeleted.map( + _ -> Left(new ApiError(Errors.INVALID_TOPIC_EXCEPTION, "The topic is being deleted")) + ) ++ + unknownPartitions.map( + _ -> Left(new ApiError(Errors.UNKNOWN_TOPIC_OR_PARTITION, "The partition does not exist.")) + ) + + debug(s"Waiting for any successful result for election type ($electionType) by $electionTrigger for partitions: $results") + callback(results) + } + } + } + + def alterIsrs(alterIsrRequest: AlterIsrRequestData, callback: AlterIsrResponseData => Unit): Unit = { + val isrsToAlter = mutable.Map[TopicPartition, LeaderAndIsr]() + + alterIsrRequest.topics.forEach { topicReq => + topicReq.partitions.forEach { partitionReq => + val tp = new TopicPartition(topicReq.name, partitionReq.partitionIndex) + val newIsr = partitionReq.newIsr().asScala.toList.map(_.toInt) + isrsToAlter.put(tp, new LeaderAndIsr(alterIsrRequest.brokerId, partitionReq.leaderEpoch, newIsr, partitionReq.currentIsrVersion)) + } + } + + def responseCallback(results: Either[Map[TopicPartition, Either[Errors, LeaderAndIsr]], Errors]): Unit = { + val resp = new AlterIsrResponseData() + results match { + case Right(error) => + resp.setErrorCode(error.code) + case Left(partitionResults) => + resp.setTopics(new util.ArrayList()) + partitionResults + .groupBy { case (tp, _) => tp.topic } // Group by topic + .foreach { case (topic, partitions) => + // Add each topic part to the response + val topicResp = new AlterIsrResponseData.TopicData() + .setName(topic) + .setPartitions(new util.ArrayList()) + resp.topics.add(topicResp) + partitions.foreach { case (tp, errorOrIsr) => + // Add each partition part to the response (new ISR or error) + errorOrIsr match { + case Left(error) => topicResp.partitions.add( + new AlterIsrResponseData.PartitionData() + .setPartitionIndex(tp.partition) + .setErrorCode(error.code)) + case Right(leaderAndIsr) => topicResp.partitions.add( + new AlterIsrResponseData.PartitionData() + .setPartitionIndex(tp.partition) + .setLeaderId(leaderAndIsr.leader) + .setLeaderEpoch(leaderAndIsr.leaderEpoch) + .setIsr(leaderAndIsr.isr.map(Integer.valueOf).asJava) + .setCurrentIsrVersion(leaderAndIsr.zkVersion)) + } + } + } + } + callback.apply(resp) + } + + eventManager.put(AlterIsrReceived(alterIsrRequest.brokerId, alterIsrRequest.brokerEpoch, isrsToAlter, responseCallback)) + } + + private def processAlterIsr(brokerId: Int, brokerEpoch: Long, + isrsToAlter: Map[TopicPartition, LeaderAndIsr], + callback: AlterIsrCallback): Unit = { + + // Handle a few short-circuits + if (!isActive) { + callback.apply(Right(Errors.NOT_CONTROLLER)) + return + } + + val brokerEpochOpt = controllerContext.liveBrokerIdAndEpochs.get(brokerId) + if (brokerEpochOpt.isEmpty) { + info(s"Ignoring AlterIsr due to unknown broker $brokerId") + callback.apply(Right(Errors.STALE_BROKER_EPOCH)) + return + } + + if (!brokerEpochOpt.contains(brokerEpoch)) { + info(s"Ignoring AlterIsr due to stale broker epoch $brokerEpoch and local broker epoch $brokerEpochOpt for broker $brokerId") + callback.apply(Right(Errors.STALE_BROKER_EPOCH)) + return + } + + val response = 
try { + val partitionResponses = mutable.HashMap[TopicPartition, Either[Errors, LeaderAndIsr]]() + + // Determine which partitions we will accept the new ISR for + val adjustedIsrs: Map[TopicPartition, LeaderAndIsr] = isrsToAlter.flatMap { + case (tp: TopicPartition, newLeaderAndIsr: LeaderAndIsr) => + controllerContext.partitionLeadershipInfo(tp) match { + case Some(leaderIsrAndControllerEpoch) => + val currentLeaderAndIsr = leaderIsrAndControllerEpoch.leaderAndIsr + if (newLeaderAndIsr.leaderEpoch < currentLeaderAndIsr.leaderEpoch) { + partitionResponses(tp) = Left(Errors.FENCED_LEADER_EPOCH) + None + } else if (newLeaderAndIsr.equalsIgnoreZk(currentLeaderAndIsr)) { + // If a partition is already in the desired state, just return it + partitionResponses(tp) = Right(currentLeaderAndIsr) + None + } else { + Some(tp -> newLeaderAndIsr) + } + case None => + partitionResponses(tp) = Left(Errors.UNKNOWN_TOPIC_OR_PARTITION) + None + } + } + + // Do the updates in ZK + debug(s"Updating ISRs for partitions: ${adjustedIsrs.keySet}.") + val UpdateLeaderAndIsrResult(finishedUpdates, badVersionUpdates) = zkClient.updateLeaderAndIsr( + adjustedIsrs, controllerContext.epoch, controllerContext.epochZkVersion) + + val successfulUpdates: Map[TopicPartition, LeaderAndIsr] = finishedUpdates.flatMap { + case (partition: TopicPartition, isrOrError: Either[Throwable, LeaderAndIsr]) => + isrOrError match { + case Right(updatedIsr) => + debug(s"ISR for partition $partition updated to [${updatedIsr.isr.mkString(",")}] and zkVersion updated to [${updatedIsr.zkVersion}]") + partitionResponses(partition) = Right(updatedIsr) + Some(partition -> updatedIsr) + case Left(e) => + error(s"Failed to update ISR for partition $partition", e) + partitionResponses(partition) = Left(Errors.forException(e)) + None + } + } + + badVersionUpdates.foreach { partition => + info(s"Failed to update ISR to ${adjustedIsrs(partition)} for partition $partition, bad ZK version.") + partitionResponses(partition) = Left(Errors.INVALID_UPDATE_VERSION) + } + + def processUpdateNotifications(partitions: Seq[TopicPartition]): Unit = { + val liveBrokers: Seq[Int] = controllerContext.liveOrShuttingDownBrokerIds.toSeq + sendUpdateMetadataRequest(liveBrokers, partitions.toSet) + } + + // Update our cache and send out metadata updates + updateLeaderAndIsrCache(successfulUpdates.keys.toSeq) + processUpdateNotifications(isrsToAlter.keys.toSeq) + + Left(partitionResponses) + } catch { + case e: Throwable => + error(s"Error when processing AlterIsr for partitions: ${isrsToAlter.keys.toSeq}", e) + Right(Errors.UNKNOWN_SERVER_ERROR) + } + + callback.apply(response) + + // After we have returned the result of the `AlterIsr` request, we should check whether + // there are any reassignments which can be completed by a successful ISR expansion. 
+ response.left.foreach { alterIsrResponses => + alterIsrResponses.forKeyValue { (topicPartition, partitionResponse) => + if (controllerContext.partitionsBeingReassigned.contains(topicPartition)) { + val isSuccessfulUpdate = partitionResponse.isRight + if (isSuccessfulUpdate) { + maybeCompleteReassignment(topicPartition) + } + } + } + } + } + + def allocateProducerIds(allocateProducerIdsRequest: AllocateProducerIdsRequestData, + callback: AllocateProducerIdsResponseData => Unit): Unit = { + + def eventManagerCallback(results: Either[Errors, ProducerIdsBlock]): Unit = { + results match { + case Left(error) => callback.apply(new AllocateProducerIdsResponseData().setErrorCode(error.code)) + case Right(pidBlock) => callback.apply( + new AllocateProducerIdsResponseData() + .setProducerIdStart(pidBlock.producerIdStart()) + .setProducerIdLen(pidBlock.producerIdLen())) + } + } + eventManager.put(AllocateProducerIds(allocateProducerIdsRequest.brokerId, + allocateProducerIdsRequest.brokerEpoch, eventManagerCallback)) + } + + def processAllocateProducerIds(brokerId: Int, brokerEpoch: Long, callback: Either[Errors, ProducerIdsBlock] => Unit): Unit = { + // Handle a few short-circuits + if (!isActive) { + callback.apply(Left(Errors.NOT_CONTROLLER)) + return + } + + val brokerEpochOpt = controllerContext.liveBrokerIdAndEpochs.get(brokerId) + if (brokerEpochOpt.isEmpty) { + warn(s"Ignoring AllocateProducerIds due to unknown broker $brokerId") + callback.apply(Left(Errors.BROKER_ID_NOT_REGISTERED)) + return + } + + if (!brokerEpochOpt.contains(brokerEpoch)) { + warn(s"Ignoring AllocateProducerIds due to stale broker epoch $brokerEpoch for broker $brokerId") + callback.apply(Left(Errors.STALE_BROKER_EPOCH)) + return + } + + val maybeNewProducerIdsBlock = try { + Try(ZkProducerIdManager.getNewProducerIdBlock(brokerId, zkClient, this)) + } catch { + case ke: KafkaException => Failure(ke) + } + + maybeNewProducerIdsBlock match { + case Failure(exception) => callback.apply(Left(Errors.forException(exception))) + case Success(newProducerIdBlock) => callback.apply(Right(newProducerIdBlock)) + } + } + + private def processControllerChange(): Unit = { + maybeResign() + } + + private def processReelect(): Unit = { + maybeResign() + elect() + } + + private def processRegisterBrokerAndReelect(): Unit = { + _brokerEpoch = zkClient.registerBroker(brokerInfo) + processReelect() + } + + private def processExpire(): Unit = { + activeControllerId = -1 + onControllerResignation() + } + + + override def process(event: ControllerEvent): Unit = { + try { + event match { + case event: MockEvent => + // Used only in test cases + event.process() + case ShutdownEventThread => + error("Received a ShutdownEventThread event. 
This type of event is supposed to be handled by ControllerEventThread") + case AutoPreferredReplicaLeaderElection => + processAutoPreferredReplicaLeaderElection() + case ReplicaLeaderElection(partitions, electionType, electionTrigger, callback) => + processReplicaLeaderElection(partitions, electionType, electionTrigger, callback) + case UncleanLeaderElectionEnable => + processUncleanLeaderElectionEnable() + case TopicUncleanLeaderElectionEnable(topic) => + processTopicUncleanLeaderElectionEnable(topic) + case ControlledShutdown(id, brokerEpoch, callback) => + processControlledShutdown(id, brokerEpoch, callback) + case LeaderAndIsrResponseReceived(response, brokerId) => + processLeaderAndIsrResponseReceived(response, brokerId) + case UpdateMetadataResponseReceived(response, brokerId) => + processUpdateMetadataResponseReceived(response, brokerId) + case TopicDeletionStopReplicaResponseReceived(replicaId, requestError, partitionErrors) => + processTopicDeletionStopReplicaResponseReceived(replicaId, requestError, partitionErrors) + case BrokerChange => + processBrokerChange() + case BrokerModifications(brokerId) => + processBrokerModification(brokerId) + case ControllerChange => + processControllerChange() + case Reelect => + processReelect() + case RegisterBrokerAndReelect => + processRegisterBrokerAndReelect() + case Expire => + processExpire() + case TopicChange => + processTopicChange() + case LogDirEventNotification => + processLogDirEventNotification() + case PartitionModifications(topic) => + processPartitionModifications(topic) + case TopicDeletion => + processTopicDeletion() + case ApiPartitionReassignment(reassignments, callback) => + processApiPartitionReassignment(reassignments, callback) + case ZkPartitionReassignment => + processZkPartitionReassignment() + case ListPartitionReassignments(partitions, callback) => + processListPartitionReassignments(partitions, callback) + case UpdateFeatures(request, callback) => + processFeatureUpdates(request, callback) + case PartitionReassignmentIsrChange(partition) => + processPartitionReassignmentIsrChange(partition) + case IsrChangeNotification => + processIsrChangeNotification() + case AlterIsrReceived(brokerId, brokerEpoch, isrsToAlter, callback) => + processAlterIsr(brokerId, brokerEpoch, isrsToAlter, callback) + case AllocateProducerIds(brokerId, brokerEpoch, callback) => + processAllocateProducerIds(brokerId, brokerEpoch, callback) + case Startup => + processStartup() + } + } catch { + case e: ControllerMovedException => + info(s"Controller moved to another broker when processing $event.", e) + maybeResign() + case e: Throwable => + error(s"Error processing event $event", e) + } finally { + updateMetrics() + } + } + + override def preempt(event: ControllerEvent): Unit = { + event.preempt() + } +} + +class BrokerChangeHandler(eventManager: ControllerEventManager) extends ZNodeChildChangeHandler { + override val path: String = BrokerIdsZNode.path + + override def handleChildChange(): Unit = { + eventManager.put(BrokerChange) + } +} + +class BrokerModificationsHandler(eventManager: ControllerEventManager, brokerId: Int) extends ZNodeChangeHandler { + override val path: String = BrokerIdZNode.path(brokerId) + + override def handleDataChange(): Unit = { + eventManager.put(BrokerModifications(brokerId)) + } +} + +class TopicChangeHandler(eventManager: ControllerEventManager) extends ZNodeChildChangeHandler { + override val path: String = TopicsZNode.path + + override def handleChildChange(): Unit = eventManager.put(TopicChange) +} + +class 
LogDirEventNotificationHandler(eventManager: ControllerEventManager) extends ZNodeChildChangeHandler { + override val path: String = LogDirEventNotificationZNode.path + + override def handleChildChange(): Unit = eventManager.put(LogDirEventNotification) +} + +object LogDirEventNotificationHandler { + val Version: Long = 1L +} + +class PartitionModificationsHandler(eventManager: ControllerEventManager, topic: String) extends ZNodeChangeHandler { + override val path: String = TopicZNode.path(topic) + + override def handleDataChange(): Unit = eventManager.put(PartitionModifications(topic)) +} + +class TopicDeletionHandler(eventManager: ControllerEventManager) extends ZNodeChildChangeHandler { + override val path: String = DeleteTopicsZNode.path + + override def handleChildChange(): Unit = eventManager.put(TopicDeletion) +} + +class PartitionReassignmentHandler(eventManager: ControllerEventManager) extends ZNodeChangeHandler { + override val path: String = ReassignPartitionsZNode.path + + // Note that the event is also enqueued when the znode is deleted, but we do it explicitly instead of relying on + // handleDeletion(). This approach is more robust as it doesn't depend on the watcher being re-registered after + // it's consumed during data changes (we ensure re-registration when the znode is deleted). + override def handleCreation(): Unit = eventManager.put(ZkPartitionReassignment) +} + +class PartitionReassignmentIsrChangeHandler(eventManager: ControllerEventManager, partition: TopicPartition) extends ZNodeChangeHandler { + override val path: String = TopicPartitionStateZNode.path(partition) + + override def handleDataChange(): Unit = eventManager.put(PartitionReassignmentIsrChange(partition)) +} + +class IsrChangeNotificationHandler(eventManager: ControllerEventManager) extends ZNodeChildChangeHandler { + override val path: String = IsrChangeNotificationZNode.path + + override def handleChildChange(): Unit = eventManager.put(IsrChangeNotification) +} + +object IsrChangeNotificationHandler { + val Version: Long = 1L +} + +class PreferredReplicaElectionHandler(eventManager: ControllerEventManager) extends ZNodeChangeHandler { + override val path: String = PreferredReplicaElectionZNode.path + + override def handleCreation(): Unit = eventManager.put(ReplicaLeaderElection(None, ElectionType.PREFERRED, ZkTriggered)) +} + +class ControllerChangeHandler(eventManager: ControllerEventManager) extends ZNodeChangeHandler { + override val path: String = ControllerZNode.path + + override def handleCreation(): Unit = eventManager.put(ControllerChange) + override def handleDeletion(): Unit = eventManager.put(Reelect) + override def handleDataChange(): Unit = eventManager.put(ControllerChange) +} + +case class PartitionAndReplica(topicPartition: TopicPartition, replica: Int) { + def topic: String = topicPartition.topic + def partition: Int = topicPartition.partition + + override def toString: String = { + s"[Topic=$topic,Partition=$partition,Replica=$replica]" + } +} + +case class LeaderIsrAndControllerEpoch(leaderAndIsr: LeaderAndIsr, controllerEpoch: Int) { + override def toString: String = { + val leaderAndIsrInfo = new StringBuilder + leaderAndIsrInfo.append("(Leader:" + leaderAndIsr.leader) + leaderAndIsrInfo.append(",ISR:" + leaderAndIsr.isr.mkString(",")) + leaderAndIsrInfo.append(",LeaderEpoch:" + leaderAndIsr.leaderEpoch) + leaderAndIsrInfo.append(",ZkVersion:" + leaderAndIsr.zkVersion) + leaderAndIsrInfo.append(",ControllerEpoch:" + controllerEpoch + ")") + leaderAndIsrInfo.toString() + } +} + 
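// The handler classes above and the ControllerEvent hierarchy below share a single shape:
// ZooKeeper watch handlers translate znode changes into ControllerEvent values and enqueue
// them on the ControllerEventManager, and KafkaController.process() dispatches on the event
// type from one thread. What follows is a minimal, self-contained sketch of that
// queue-and-dispatch shape, assuming only the standard library; MiniEvent, MiniEventManager
// and the sample events are illustrative names for this sketch, not Kafka APIs.
object ControllerEventLoopSketch {
  sealed trait MiniEvent
  case object BrokerChanged extends MiniEvent
  final case class TopicCreated(name: String) extends MiniEvent

  final class MiniEventManager {
    private val queue = new java.util.concurrent.LinkedBlockingQueue[MiniEvent]()
    // Watcher threads only enqueue; they never run controller logic inline.
    def put(event: MiniEvent): Unit = queue.put(event)
    // The single event thread takes one event at a time and dispatches it.
    def processOne(handler: MiniEvent => Unit): Unit = handler(queue.take())
  }

  def main(args: Array[String]): Unit = {
    val manager = new MiniEventManager
    manager.put(TopicCreated("payments")) // as if a watcher observed a new topic znode
    manager.processOne {
      case BrokerChanged      => println("refresh broker metadata cache")
      case TopicCreated(name) => println(s"create partitions for topic $name")
    }
  }
}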
+private[controller] class ControllerStats extends KafkaMetricsGroup { + val uncleanLeaderElectionRate = newMeter("UncleanLeaderElectionsPerSec", "elections", TimeUnit.SECONDS) + + val rateAndTimeMetrics: Map[ControllerState, KafkaTimer] = ControllerState.values.flatMap { state => + state.rateAndTimeMetricName.map { metricName => + state -> new KafkaTimer(newTimer(metricName, TimeUnit.MILLISECONDS, TimeUnit.SECONDS)) + } + }.toMap + +} + +sealed trait ControllerEvent { + def state: ControllerState + // preempt() is not executed by `ControllerEventThread` but by the main thread. + def preempt(): Unit +} + +case object ControllerChange extends ControllerEvent { + override def state: ControllerState = ControllerState.ControllerChange + override def preempt(): Unit = {} +} + +case object Reelect extends ControllerEvent { + override def state: ControllerState = ControllerState.ControllerChange + override def preempt(): Unit = {} +} + +case object RegisterBrokerAndReelect extends ControllerEvent { + override def state: ControllerState = ControllerState.ControllerChange + override def preempt(): Unit = {} +} + +case object Expire extends ControllerEvent { + override def state: ControllerState = ControllerState.ControllerChange + override def preempt(): Unit = {} +} + +case object ShutdownEventThread extends ControllerEvent { + override def state: ControllerState = ControllerState.ControllerShutdown + override def preempt(): Unit = {} +} + +case object AutoPreferredReplicaLeaderElection extends ControllerEvent { + override def state: ControllerState = ControllerState.AutoLeaderBalance + override def preempt(): Unit = {} +} + +case object UncleanLeaderElectionEnable extends ControllerEvent { + override def state: ControllerState = ControllerState.UncleanLeaderElectionEnable + override def preempt(): Unit = {} +} + +case class TopicUncleanLeaderElectionEnable(topic: String) extends ControllerEvent { + override def state: ControllerState = ControllerState.TopicUncleanLeaderElectionEnable + override def preempt(): Unit = {} +} + +case class ControlledShutdown(id: Int, brokerEpoch: Long, controlledShutdownCallback: Try[Set[TopicPartition]] => Unit) extends ControllerEvent { + override def state: ControllerState = ControllerState.ControlledShutdown + override def preempt(): Unit = controlledShutdownCallback(Failure(new ControllerMovedException("Controller moved to another broker"))) +} + +case class LeaderAndIsrResponseReceived(leaderAndIsrResponse: LeaderAndIsrResponse, brokerId: Int) extends ControllerEvent { + override def state: ControllerState = ControllerState.LeaderAndIsrResponseReceived + override def preempt(): Unit = {} +} + +case class UpdateMetadataResponseReceived(updateMetadataResponse: UpdateMetadataResponse, brokerId: Int) extends ControllerEvent { + override def state: ControllerState = ControllerState.UpdateMetadataResponseReceived + override def preempt(): Unit = {} +} + +case class TopicDeletionStopReplicaResponseReceived(replicaId: Int, + requestError: Errors, + partitionErrors: Map[TopicPartition, Errors]) extends ControllerEvent { + override def state: ControllerState = ControllerState.TopicDeletion + override def preempt(): Unit = {} +} + +case object Startup extends ControllerEvent { + override def state: ControllerState = ControllerState.ControllerChange + override def preempt(): Unit = {} +} + +case object BrokerChange extends ControllerEvent { + override def state: ControllerState = ControllerState.BrokerChange + override def preempt(): Unit = {} +} + +case class 
BrokerModifications(brokerId: Int) extends ControllerEvent { + override def state: ControllerState = ControllerState.BrokerChange + override def preempt(): Unit = {} +} + +case object TopicChange extends ControllerEvent { + override def state: ControllerState = ControllerState.TopicChange + override def preempt(): Unit = {} +} + +case object LogDirEventNotification extends ControllerEvent { + override def state: ControllerState = ControllerState.LogDirChange + override def preempt(): Unit = {} +} + +case class PartitionModifications(topic: String) extends ControllerEvent { + override def state: ControllerState = ControllerState.TopicChange + override def preempt(): Unit = {} +} + +case object TopicDeletion extends ControllerEvent { + override def state: ControllerState = ControllerState.TopicDeletion + override def preempt(): Unit = {} +} + +case object ZkPartitionReassignment extends ControllerEvent { + override def state: ControllerState = ControllerState.AlterPartitionReassignment + override def preempt(): Unit = {} +} + +case class ApiPartitionReassignment(reassignments: Map[TopicPartition, Option[Seq[Int]]], + callback: AlterReassignmentsCallback) extends ControllerEvent { + override def state: ControllerState = ControllerState.AlterPartitionReassignment + override def preempt(): Unit = callback(Right(new ApiError(Errors.NOT_CONTROLLER))) +} + +case class PartitionReassignmentIsrChange(partition: TopicPartition) extends ControllerEvent { + override def state: ControllerState = ControllerState.AlterPartitionReassignment + override def preempt(): Unit = {} +} + +case object IsrChangeNotification extends ControllerEvent { + override def state: ControllerState = ControllerState.IsrChange + override def preempt(): Unit = {} +} + +case class AlterIsrReceived(brokerId: Int, brokerEpoch: Long, isrsToAlter: Map[TopicPartition, LeaderAndIsr], + callback: AlterIsrCallback) extends ControllerEvent { + override def state: ControllerState = ControllerState.IsrChange + override def preempt(): Unit = {} +} + +case class ReplicaLeaderElection( + partitionsFromAdminClientOpt: Option[Set[TopicPartition]], + electionType: ElectionType, + electionTrigger: ElectionTrigger, + callback: ElectLeadersCallback = _ => {} +) extends ControllerEvent { + override def state: ControllerState = ControllerState.ManualLeaderBalance + + override def preempt(): Unit = callback( + partitionsFromAdminClientOpt.fold(Map.empty[TopicPartition, Either[ApiError, Int]]) { partitions => + partitions.iterator.map(partition => partition -> Left(new ApiError(Errors.NOT_CONTROLLER, null))).toMap + } + ) +} + +/** + * @param partitionsOpt - an Optional set of partitions. 
If not present, all reassigning partitions are to be listed + */ +case class ListPartitionReassignments(partitionsOpt: Option[Set[TopicPartition]], + callback: ListReassignmentsCallback) extends ControllerEvent { + override def state: ControllerState = ControllerState.ListPartitionReassignment + override def preempt(): Unit = callback(Right(new ApiError(Errors.NOT_CONTROLLER, null))) +} + +case class UpdateFeatures(request: UpdateFeaturesRequest, + callback: UpdateFeaturesCallback) extends ControllerEvent { + override def state: ControllerState = ControllerState.UpdateFeatures + override def preempt(): Unit = {} +} + +case class AllocateProducerIds(brokerId: Int, brokerEpoch: Long, callback: Either[Errors, ProducerIdsBlock] => Unit) + extends ControllerEvent { + override def state: ControllerState = ControllerState.Idle + override def preempt(): Unit = {} +} + + +// Used only in test cases +abstract class MockEvent(val state: ControllerState) extends ControllerEvent { + def process(): Unit + def preempt(): Unit +} diff --git a/core/src/main/scala/kafka/controller/PartitionStateMachine.scala b/core/src/main/scala/kafka/controller/PartitionStateMachine.scala new file mode 100755 index 0000000..105e158 --- /dev/null +++ b/core/src/main/scala/kafka/controller/PartitionStateMachine.scala @@ -0,0 +1,578 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. +*/ +package kafka.controller + +import kafka.api.LeaderAndIsr +import kafka.common.StateChangeFailedException +import kafka.controller.Election._ +import kafka.server.KafkaConfig +import kafka.utils.Implicits._ +import kafka.utils.Logging +import kafka.zk.KafkaZkClient +import kafka.zk.KafkaZkClient.UpdateLeaderAndIsrResult +import kafka.zk.TopicPartitionStateZNode +import org.apache.kafka.common.TopicPartition +import org.apache.kafka.common.errors.ControllerMovedException +import org.apache.zookeeper.KeeperException +import org.apache.zookeeper.KeeperException.Code +import scala.collection.{Map, Seq, mutable} + +abstract class PartitionStateMachine(controllerContext: ControllerContext) extends Logging { + /** + * Invoked on successful controller election. + */ + def startup(): Unit = { + info("Initializing partition state") + initializePartitionState() + info("Triggering online partition state changes") + triggerOnlinePartitionStateChange() + debug(s"Started partition state machine with initial state -> ${controllerContext.partitionStates}") + } + + /** + * Invoked on controller shutdown. + */ + def shutdown(): Unit = { + info("Stopped partition state machine") + } + + /** + * This API invokes the OnlinePartition state change on all partitions in either the NewPartition or OfflinePartition + * state. 
This is called on a successful controller election and on broker changes + */ + def triggerOnlinePartitionStateChange(): Map[TopicPartition, Either[Throwable, LeaderAndIsr]] = { + val partitions = controllerContext.partitionsInStates(Set(OfflinePartition, NewPartition)) + triggerOnlineStateChangeForPartitions(partitions) + } + + def triggerOnlinePartitionStateChange(topic: String): Unit = { + val partitions = controllerContext.partitionsInStates(topic, Set(OfflinePartition, NewPartition)) + triggerOnlineStateChangeForPartitions(partitions) + } + + private def triggerOnlineStateChangeForPartitions(partitions: collection.Set[TopicPartition]): Map[TopicPartition, Either[Throwable, LeaderAndIsr]] = { + // try to move all partitions in NewPartition or OfflinePartition state to OnlinePartition state except partitions + // that belong to topics to be deleted + val partitionsToTrigger = partitions.filter { partition => + !controllerContext.isTopicQueuedUpForDeletion(partition.topic) + }.toSeq + + handleStateChanges(partitionsToTrigger, OnlinePartition, Some(OfflinePartitionLeaderElectionStrategy(false))) + // TODO: If handleStateChanges catches an exception, it is not enough to bail out and log an error. + // It is important to trigger leader election for those partitions. + } + + /** + * Invoked on startup of the partition's state machine to set the initial state for all existing partitions in + * zookeeper + */ + private def initializePartitionState(): Unit = { + for (topicPartition <- controllerContext.allPartitions) { + // check if leader and isr path exists for partition. If not, then it is in NEW state + controllerContext.partitionLeadershipInfo(topicPartition) match { + case Some(currentLeaderIsrAndEpoch) => + // else, check if the leader for partition is alive. If yes, it is in Online state, else it is in Offline state + if (controllerContext.isReplicaOnline(currentLeaderIsrAndEpoch.leaderAndIsr.leader, topicPartition)) + // leader is alive + controllerContext.putPartitionState(topicPartition, OnlinePartition) + else + controllerContext.putPartitionState(topicPartition, OfflinePartition) + case None => + controllerContext.putPartitionState(topicPartition, NewPartition) + } + } + } + + def handleStateChanges( + partitions: Seq[TopicPartition], + targetState: PartitionState + ): Map[TopicPartition, Either[Throwable, LeaderAndIsr]] = { + handleStateChanges(partitions, targetState, None) + } + + def handleStateChanges( + partitions: Seq[TopicPartition], + targetState: PartitionState, + leaderElectionStrategy: Option[PartitionLeaderElectionStrategy] + ): Map[TopicPartition, Either[Throwable, LeaderAndIsr]] + +} + +/** + * This class represents the state machine for partitions. It defines the states that a partition can be in, and + * transitions to move the partition to another legal state. The different states that a partition can be in are - + * 1. NonExistentPartition: This state indicates that the partition was either never created or was created and then + * deleted. Valid previous state, if one exists, is OfflinePartition + * 2. NewPartition : After creation, the partition is in the NewPartition state. In this state, the partition should have + * replicas assigned to it, but no leader/isr yet. Valid previous states are NonExistentPartition + * 3. OnlinePartition : Once a leader is elected for a partition, it is in the OnlinePartition state. + * Valid previous states are NewPartition/OfflinePartition + * 4. 
OfflinePartition : If, after successful leader election, the leader for partition dies, then the partition + * moves to the OfflinePartition state. Valid previous states are NewPartition/OnlinePartition + */ +class ZkPartitionStateMachine(config: KafkaConfig, + stateChangeLogger: StateChangeLogger, + controllerContext: ControllerContext, + zkClient: KafkaZkClient, + controllerBrokerRequestBatch: ControllerBrokerRequestBatch) + extends PartitionStateMachine(controllerContext) { + + private val controllerId = config.brokerId + this.logIdent = s"[PartitionStateMachine controllerId=$controllerId] " + + /** + * Try to change the state of the given partitions to the given targetState, using the given + * partitionLeaderElectionStrategyOpt if a leader election is required. + * @param partitions The partitions + * @param targetState The state + * @param partitionLeaderElectionStrategyOpt The leader election strategy if a leader election is required. + * @return A map of failed and successful elections when targetState is OnlinePartitions. The keys are the + * topic partitions and the corresponding values are either the exception that was thrown or new + * leader & ISR. + */ + override def handleStateChanges( + partitions: Seq[TopicPartition], + targetState: PartitionState, + partitionLeaderElectionStrategyOpt: Option[PartitionLeaderElectionStrategy] + ): Map[TopicPartition, Either[Throwable, LeaderAndIsr]] = { + if (partitions.nonEmpty) { + try { + controllerBrokerRequestBatch.newBatch() + val result = doHandleStateChanges( + partitions, + targetState, + partitionLeaderElectionStrategyOpt + ) + controllerBrokerRequestBatch.sendRequestsToBrokers(controllerContext.epoch) + result + } catch { + case e: ControllerMovedException => + error(s"Controller moved to another broker when moving some partitions to $targetState state", e) + throw e + case e: Throwable => + error(s"Error while moving some partitions to $targetState state", e) + partitions.iterator.map(_ -> Left(e)).toMap + } + } else { + Map.empty + } + } + + private def partitionState(partition: TopicPartition): PartitionState = { + controllerContext.partitionState(partition) + } + + /** + * This API exercises the partition's state machine. It ensures that every state transition happens from a legal + * previous state to the target state. Valid state transitions are: + * NonExistentPartition -> NewPartition: + * --load assigned replicas from ZK to controller cache + * + * NewPartition -> OnlinePartition + * --assign first live replica as the leader and all live replicas as the isr; write leader and isr to ZK for this partition + * --send LeaderAndIsr request to every live replica and UpdateMetadata request to every live broker + * + * OnlinePartition,OfflinePartition -> OnlinePartition + * --select new leader and isr for this partition and a set of replicas to receive the LeaderAndIsr request, and write leader and isr to ZK + * --for this partition, send LeaderAndIsr request to every receiving replica and UpdateMetadata request to every live broker + * + * NewPartition,OnlinePartition,OfflinePartition -> OfflinePartition + * --nothing other than marking partition state as Offline + * + * OfflinePartition -> NonExistentPartition + * --nothing other than marking the partition state as NonExistentPartition + * @param partitions The partitions for which the state transition is invoked + * @param targetState The end state that the partition should be moved to + * @return A map of failed and successful elections when targetState is OnlinePartitions. 
The keys are the + * topic partitions and the corresponding values are either the exception that was thrown or new + * leader & ISR. + */ + private def doHandleStateChanges( + partitions: Seq[TopicPartition], + targetState: PartitionState, + partitionLeaderElectionStrategyOpt: Option[PartitionLeaderElectionStrategy] + ): Map[TopicPartition, Either[Throwable, LeaderAndIsr]] = { + val stateChangeLog = stateChangeLogger.withControllerEpoch(controllerContext.epoch) + val traceEnabled = stateChangeLog.isTraceEnabled + partitions.foreach(partition => controllerContext.putPartitionStateIfNotExists(partition, NonExistentPartition)) + val (validPartitions, invalidPartitions) = controllerContext.checkValidPartitionStateChange(partitions, targetState) + invalidPartitions.foreach(partition => logInvalidTransition(partition, targetState)) + + targetState match { + case NewPartition => + validPartitions.foreach { partition => + stateChangeLog.info(s"Changed partition $partition state from ${partitionState(partition)} to $targetState with " + + s"assigned replicas ${controllerContext.partitionReplicaAssignment(partition).mkString(",")}") + controllerContext.putPartitionState(partition, NewPartition) + } + Map.empty + case OnlinePartition => + val uninitializedPartitions = validPartitions.filter(partition => partitionState(partition) == NewPartition) + val partitionsToElectLeader = validPartitions.filter(partition => partitionState(partition) == OfflinePartition || partitionState(partition) == OnlinePartition) + if (uninitializedPartitions.nonEmpty) { + val successfulInitializations = initializeLeaderAndIsrForPartitions(uninitializedPartitions) + successfulInitializations.foreach { partition => + stateChangeLog.info(s"Changed partition $partition from ${partitionState(partition)} to $targetState with state " + + s"${controllerContext.partitionLeadershipInfo(partition).get.leaderAndIsr}") + controllerContext.putPartitionState(partition, OnlinePartition) + } + } + if (partitionsToElectLeader.nonEmpty) { + val electionResults = electLeaderForPartitions( + partitionsToElectLeader, + partitionLeaderElectionStrategyOpt.getOrElse( + throw new IllegalArgumentException("Election strategy is a required field when the target state is OnlinePartition") + ) + ) + + electionResults.foreach { + case (partition, Right(leaderAndIsr)) => + stateChangeLog.info( + s"Changed partition $partition from ${partitionState(partition)} to $targetState with state $leaderAndIsr" + ) + controllerContext.putPartitionState(partition, OnlinePartition) + case (_, Left(_)) => // Ignore; no need to update partition state on election error + } + + electionResults + } else { + Map.empty + } + case OfflinePartition | NonExistentPartition => + validPartitions.foreach { partition => + if (traceEnabled) + stateChangeLog.trace(s"Changed partition $partition state from ${partitionState(partition)} to $targetState") + controllerContext.putPartitionState(partition, targetState) + } + Map.empty + } + } + + /** + * Initialize leader and isr partition state in zookeeper. + * @param partitions The partitions that we're trying to initialize. + * @return The partitions that have been successfully initialized. 
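+ *         For example, a new partition assigned to replicas (3, 1, 2) with only brokers 1 and 2
+ *         alive is initialized with leader 1 (the first live replica in assignment order) and
+ *         ISR (1, 2).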
+ */ + private def initializeLeaderAndIsrForPartitions(partitions: Seq[TopicPartition]): Seq[TopicPartition] = { + val successfulInitializations = mutable.Buffer.empty[TopicPartition] + val replicasPerPartition = partitions.map(partition => partition -> controllerContext.partitionReplicaAssignment(partition)) + val liveReplicasPerPartition = replicasPerPartition.map { case (partition, replicas) => + val liveReplicasForPartition = replicas.filter(replica => controllerContext.isReplicaOnline(replica, partition)) + partition -> liveReplicasForPartition + } + val (partitionsWithoutLiveReplicas, partitionsWithLiveReplicas) = liveReplicasPerPartition.partition { case (_, liveReplicas) => liveReplicas.isEmpty } + + partitionsWithoutLiveReplicas.foreach { case (partition, _) => + val failMsg = s"Controller $controllerId epoch ${controllerContext.epoch} encountered error during state change of " + + s"partition $partition from New to Online, assigned replicas are " + + s"[${controllerContext.partitionReplicaAssignment(partition).mkString(",")}], live brokers are [${controllerContext.liveBrokerIds}]. No assigned " + + "replica is alive." + logFailedStateChange(partition, NewPartition, OnlinePartition, new StateChangeFailedException(failMsg)) + } + val leaderIsrAndControllerEpochs = partitionsWithLiveReplicas.map { case (partition, liveReplicas) => + val leaderAndIsr = LeaderAndIsr(liveReplicas.head, liveReplicas.toList) + val leaderIsrAndControllerEpoch = LeaderIsrAndControllerEpoch(leaderAndIsr, controllerContext.epoch) + partition -> leaderIsrAndControllerEpoch + }.toMap + val createResponses = try { + zkClient.createTopicPartitionStatesRaw(leaderIsrAndControllerEpochs, controllerContext.epochZkVersion) + } catch { + case e: ControllerMovedException => + error("Controller moved to another broker when trying to create the topic partition state znode", e) + throw e + case e: Exception => + partitionsWithLiveReplicas.foreach { case (partition, _) => logFailedStateChange(partition, partitionState(partition), NewPartition, e) } + Seq.empty + } + createResponses.foreach { createResponse => + val code = createResponse.resultCode + val partition = createResponse.ctx.get.asInstanceOf[TopicPartition] + val leaderIsrAndControllerEpoch = leaderIsrAndControllerEpochs(partition) + if (code == Code.OK) { + controllerContext.putPartitionLeadershipInfo(partition, leaderIsrAndControllerEpoch) + controllerBrokerRequestBatch.addLeaderAndIsrRequestForBrokers(leaderIsrAndControllerEpoch.leaderAndIsr.isr, + partition, leaderIsrAndControllerEpoch, controllerContext.partitionFullReplicaAssignment(partition), isNew = true) + successfulInitializations += partition + } else { + logFailedStateChange(partition, NewPartition, OnlinePartition, code) + } + } + successfulInitializations + } + + /** + * Repeatedly attempt to elect leaders for multiple partitions until there are no more remaining partitions to retry. + * @param partitions The partitions that we're trying to elect leaders for. + * @param partitionLeaderElectionStrategy The election strategy to use. + * @return A map of failed and successful elections. The keys are the topic partitions and the corresponding values are + * either the exception that was thrown or new leader & ISR. 
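+ *         For example, with the offline election algorithm (see
+ *         PartitionLeaderElectionAlgorithms.offlinePartitionLeaderElection below), a partition
+ *         assigned to (1, 2, 3) with ISR (3) and live replicas {2, 3} elects 3, the first assigned
+ *         replica that is both alive and in the ISR; if no such replica exists, an unclean election
+ *         (when allowed) falls back to the first live assigned replica.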
+ */ + private def electLeaderForPartitions( + partitions: Seq[TopicPartition], + partitionLeaderElectionStrategy: PartitionLeaderElectionStrategy + ): Map[TopicPartition, Either[Throwable, LeaderAndIsr]] = { + var remaining = partitions + val finishedElections = mutable.Map.empty[TopicPartition, Either[Throwable, LeaderAndIsr]] + + while (remaining.nonEmpty) { + val (finished, updatesToRetry) = doElectLeaderForPartitions(remaining, partitionLeaderElectionStrategy) + remaining = updatesToRetry + + finished.foreach { + case (partition, Left(e)) => + logFailedStateChange(partition, partitionState(partition), OnlinePartition, e) + case (_, Right(_)) => // Ignore; success so no need to log failed state change + } + + finishedElections ++= finished + + if (remaining.nonEmpty) + logger.info(s"Retrying leader election with strategy $partitionLeaderElectionStrategy for partitions $remaining") + } + + finishedElections.toMap + } + + /** + * Try to elect leaders for multiple partitions. + * Electing a leader for a partition updates partition state in zookeeper. + * + * @param partitions The partitions that we're trying to elect leaders for. + * @param partitionLeaderElectionStrategy The election strategy to use. + * @return A tuple of two values: + * 1. The partitions and the expected leader and isr that successfully had a leader elected. And exceptions + * corresponding to failed elections that should not be retried. + * 2. The partitions that we should retry due to a zookeeper BADVERSION conflict. Version conflicts can occur if + * the partition leader updated partition state while the controller attempted to update partition state. + */ + private def doElectLeaderForPartitions( + partitions: Seq[TopicPartition], + partitionLeaderElectionStrategy: PartitionLeaderElectionStrategy + ): (Map[TopicPartition, Either[Exception, LeaderAndIsr]], Seq[TopicPartition]) = { + val getDataResponses = try { + zkClient.getTopicPartitionStatesRaw(partitions) + } catch { + case e: Exception => + return (partitions.iterator.map(_ -> Left(e)).toMap, Seq.empty) + } + val failedElections = mutable.Map.empty[TopicPartition, Either[Exception, LeaderAndIsr]] + val validLeaderAndIsrs = mutable.Buffer.empty[(TopicPartition, LeaderAndIsr)] + + getDataResponses.foreach { getDataResponse => + val partition = getDataResponse.ctx.get.asInstanceOf[TopicPartition] + val currState = partitionState(partition) + if (getDataResponse.resultCode == Code.OK) { + TopicPartitionStateZNode.decode(getDataResponse.data, getDataResponse.stat) match { + case Some(leaderIsrAndControllerEpoch) => + if (leaderIsrAndControllerEpoch.controllerEpoch > controllerContext.epoch) { + val failMsg = s"Aborted leader election for partition $partition since the LeaderAndIsr path was " + + s"already written by another controller. This probably means that the current controller $controllerId went through " + + s"a soft failure and another controller was elected with epoch ${leaderIsrAndControllerEpoch.controllerEpoch}." 
+ failedElections.put(partition, Left(new StateChangeFailedException(failMsg))) + } else { + validLeaderAndIsrs += partition -> leaderIsrAndControllerEpoch.leaderAndIsr + } + + case None => + val exception = new StateChangeFailedException(s"LeaderAndIsr information doesn't exist for partition $partition in $currState state") + failedElections.put(partition, Left(exception)) + } + + } else if (getDataResponse.resultCode == Code.NONODE) { + val exception = new StateChangeFailedException(s"LeaderAndIsr information doesn't exist for partition $partition in $currState state") + failedElections.put(partition, Left(exception)) + } else { + failedElections.put(partition, Left(getDataResponse.resultException.get)) + } + } + + if (validLeaderAndIsrs.isEmpty) { + return (failedElections.toMap, Seq.empty) + } + + val (partitionsWithoutLeaders, partitionsWithLeaders) = partitionLeaderElectionStrategy match { + case OfflinePartitionLeaderElectionStrategy(allowUnclean) => + val partitionsWithUncleanLeaderElectionState = collectUncleanLeaderElectionState( + validLeaderAndIsrs, + allowUnclean + ) + leaderForOffline(controllerContext, partitionsWithUncleanLeaderElectionState).partition(_.leaderAndIsr.isEmpty) + case ReassignPartitionLeaderElectionStrategy => + leaderForReassign(controllerContext, validLeaderAndIsrs).partition(_.leaderAndIsr.isEmpty) + case PreferredReplicaPartitionLeaderElectionStrategy => + leaderForPreferredReplica(controllerContext, validLeaderAndIsrs).partition(_.leaderAndIsr.isEmpty) + case ControlledShutdownPartitionLeaderElectionStrategy => + leaderForControlledShutdown(controllerContext, validLeaderAndIsrs).partition(_.leaderAndIsr.isEmpty) + } + partitionsWithoutLeaders.foreach { electionResult => + val partition = electionResult.topicPartition + val failMsg = s"Failed to elect leader for partition $partition under strategy $partitionLeaderElectionStrategy" + failedElections.put(partition, Left(new StateChangeFailedException(failMsg))) + } + val recipientsPerPartition = partitionsWithLeaders.map(result => result.topicPartition -> result.liveReplicas).toMap + val adjustedLeaderAndIsrs = partitionsWithLeaders.map(result => result.topicPartition -> result.leaderAndIsr.get).toMap + val UpdateLeaderAndIsrResult(finishedUpdates, updatesToRetry) = zkClient.updateLeaderAndIsr( + adjustedLeaderAndIsrs, controllerContext.epoch, controllerContext.epochZkVersion) + finishedUpdates.forKeyValue { (partition, result) => + result.foreach { leaderAndIsr => + val replicaAssignment = controllerContext.partitionFullReplicaAssignment(partition) + val leaderIsrAndControllerEpoch = LeaderIsrAndControllerEpoch(leaderAndIsr, controllerContext.epoch) + controllerContext.putPartitionLeadershipInfo(partition, leaderIsrAndControllerEpoch) + controllerBrokerRequestBatch.addLeaderAndIsrRequestForBrokers(recipientsPerPartition(partition), partition, + leaderIsrAndControllerEpoch, replicaAssignment, isNew = false) + } + } + + if (isDebugEnabled) { + updatesToRetry.foreach { partition => + debug(s"Controller failed to elect leader for partition $partition. " + + s"Attempted to write state ${adjustedLeaderAndIsrs(partition)}, but failed with bad ZK version. This will be retried.") + } + } + + (finishedUpdates ++ failedElections, updatesToRetry) + } + + /* For the provided set of topic partition and partition sync state it attempts to determine if unclean + * leader election should be performed. 
Unclean election should be performed if there are no live + * replica which are in sync and unclean leader election is allowed (allowUnclean parameter is true or + * the topic has been configured to allow unclean election). + * + * @param leaderIsrAndControllerEpochs set of partition to determine if unclean leader election should be + * allowed + * @param allowUnclean whether to allow unclean election without having to read the topic configuration + * @return a sequence of three element tuple: + * 1. topic partition + * 2. leader, isr and controller epoc. Some means election should be performed + * 3. allow unclean + */ + private def collectUncleanLeaderElectionState( + leaderAndIsrs: Seq[(TopicPartition, LeaderAndIsr)], + allowUnclean: Boolean + ): Seq[(TopicPartition, Option[LeaderAndIsr], Boolean)] = { + val (partitionsWithNoLiveInSyncReplicas, partitionsWithLiveInSyncReplicas) = leaderAndIsrs.partition { + case (partition, leaderAndIsr) => + val liveInSyncReplicas = leaderAndIsr.isr.filter(controllerContext.isReplicaOnline(_, partition)) + liveInSyncReplicas.isEmpty + } + + val electionForPartitionWithoutLiveReplicas = if (allowUnclean) { + partitionsWithNoLiveInSyncReplicas.map { case (partition, leaderAndIsr) => + (partition, Option(leaderAndIsr), true) + } + } else { + val (logConfigs, failed) = zkClient.getLogConfigs( + partitionsWithNoLiveInSyncReplicas.iterator.map { case (partition, _) => partition.topic }.toSet, + config.originals() + ) + + partitionsWithNoLiveInSyncReplicas.map { case (partition, leaderAndIsr) => + if (failed.contains(partition.topic)) { + logFailedStateChange(partition, partitionState(partition), OnlinePartition, failed(partition.topic)) + (partition, None, false) + } else { + ( + partition, + Option(leaderAndIsr), + logConfigs(partition.topic).uncleanLeaderElectionEnable.booleanValue() + ) + } + } + } + + electionForPartitionWithoutLiveReplicas ++ + partitionsWithLiveInSyncReplicas.map { case (partition, leaderAndIsr) => + (partition, Option(leaderAndIsr), false) + } + } + + private def logInvalidTransition(partition: TopicPartition, targetState: PartitionState): Unit = { + val currState = partitionState(partition) + val e = new IllegalStateException(s"Partition $partition should be in one of " + + s"${targetState.validPreviousStates.mkString(",")} states before moving to $targetState state. 
Instead it is in " + + s"$currState state") + logFailedStateChange(partition, currState, targetState, e) + } + + private def logFailedStateChange(partition: TopicPartition, currState: PartitionState, targetState: PartitionState, code: Code): Unit = { + logFailedStateChange(partition, currState, targetState, KeeperException.create(code)) + } + + private def logFailedStateChange(partition: TopicPartition, currState: PartitionState, targetState: PartitionState, t: Throwable): Unit = { + stateChangeLogger.withControllerEpoch(controllerContext.epoch) + .error(s"Controller $controllerId epoch ${controllerContext.epoch} failed to change state for partition $partition " + + s"from $currState to $targetState", t) + } +} + +object PartitionLeaderElectionAlgorithms { + def offlinePartitionLeaderElection(assignment: Seq[Int], isr: Seq[Int], liveReplicas: Set[Int], uncleanLeaderElectionEnabled: Boolean, controllerContext: ControllerContext): Option[Int] = { + assignment.find(id => liveReplicas.contains(id) && isr.contains(id)).orElse { + if (uncleanLeaderElectionEnabled) { + val leaderOpt = assignment.find(liveReplicas.contains) + if (leaderOpt.isDefined) + controllerContext.stats.uncleanLeaderElectionRate.mark() + leaderOpt + } else { + None + } + } + } + + def reassignPartitionLeaderElection(reassignment: Seq[Int], isr: Seq[Int], liveReplicas: Set[Int]): Option[Int] = { + reassignment.find(id => liveReplicas.contains(id) && isr.contains(id)) + } + + def preferredReplicaPartitionLeaderElection(assignment: Seq[Int], isr: Seq[Int], liveReplicas: Set[Int]): Option[Int] = { + assignment.headOption.filter(id => liveReplicas.contains(id) && isr.contains(id)) + } + + def controlledShutdownPartitionLeaderElection(assignment: Seq[Int], isr: Seq[Int], liveReplicas: Set[Int], shuttingDownBrokers: Set[Int]): Option[Int] = { + assignment.find(id => liveReplicas.contains(id) && isr.contains(id) && !shuttingDownBrokers.contains(id)) + } +} + +sealed trait PartitionLeaderElectionStrategy +final case class OfflinePartitionLeaderElectionStrategy(allowUnclean: Boolean) extends PartitionLeaderElectionStrategy +final case object ReassignPartitionLeaderElectionStrategy extends PartitionLeaderElectionStrategy +final case object PreferredReplicaPartitionLeaderElectionStrategy extends PartitionLeaderElectionStrategy +final case object ControlledShutdownPartitionLeaderElectionStrategy extends PartitionLeaderElectionStrategy + +sealed trait PartitionState { + def state: Byte + def validPreviousStates: Set[PartitionState] +} + +case object NewPartition extends PartitionState { + val state: Byte = 0 + val validPreviousStates: Set[PartitionState] = Set(NonExistentPartition) +} + +case object OnlinePartition extends PartitionState { + val state: Byte = 1 + val validPreviousStates: Set[PartitionState] = Set(NewPartition, OnlinePartition, OfflinePartition) +} + +case object OfflinePartition extends PartitionState { + val state: Byte = 2 + val validPreviousStates: Set[PartitionState] = Set(NewPartition, OnlinePartition, OfflinePartition) +} + +case object NonExistentPartition extends PartitionState { + val state: Byte = 3 + val validPreviousStates: Set[PartitionState] = Set(OfflinePartition) +} diff --git a/core/src/main/scala/kafka/controller/ReplicaStateMachine.scala b/core/src/main/scala/kafka/controller/ReplicaStateMachine.scala new file mode 100644 index 0000000..5b41e66 --- /dev/null +++ b/core/src/main/scala/kafka/controller/ReplicaStateMachine.scala @@ -0,0 +1,491 @@ +/** + * Licensed to the Apache Software Foundation (ASF) 
under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. +*/ +package kafka.controller + +import kafka.api.LeaderAndIsr +import kafka.common.StateChangeFailedException +import kafka.server.KafkaConfig +import kafka.utils.Implicits._ +import kafka.utils.Logging +import kafka.zk.KafkaZkClient +import kafka.zk.KafkaZkClient.UpdateLeaderAndIsrResult +import kafka.zk.TopicPartitionStateZNode +import org.apache.kafka.common.TopicPartition +import org.apache.kafka.common.errors.ControllerMovedException +import org.apache.zookeeper.KeeperException.Code +import scala.collection.{Seq, mutable} + +abstract class ReplicaStateMachine(controllerContext: ControllerContext) extends Logging { + /** + * Invoked on successful controller election. + */ + def startup(): Unit = { + info("Initializing replica state") + initializeReplicaState() + info("Triggering online replica state changes") + val (onlineReplicas, offlineReplicas) = controllerContext.onlineAndOfflineReplicas + handleStateChanges(onlineReplicas.toSeq, OnlineReplica) + info("Triggering offline replica state changes") + handleStateChanges(offlineReplicas.toSeq, OfflineReplica) + debug(s"Started replica state machine with initial state -> ${controllerContext.replicaStates}") + } + + /** + * Invoked on controller shutdown. + */ + def shutdown(): Unit = { + info("Stopped replica state machine") + } + + /** + * Invoked on startup of the replica's state machine to set the initial state for replicas of all existing partitions + * in zookeeper + */ + private def initializeReplicaState(): Unit = { + controllerContext.allPartitions.foreach { partition => + val replicas = controllerContext.partitionReplicaAssignment(partition) + replicas.foreach { replicaId => + val partitionAndReplica = PartitionAndReplica(partition, replicaId) + if (controllerContext.isReplicaOnline(replicaId, partition)) { + controllerContext.putReplicaState(partitionAndReplica, OnlineReplica) + } else { + // mark replicas on dead brokers as failed for topic deletion, if they belong to a topic to be deleted. + // This is required during controller failover since during controller failover a broker can go down, + // so the replicas on that broker should be moved to ReplicaDeletionIneligible to be on the safer side. + controllerContext.putReplicaState(partitionAndReplica, ReplicaDeletionIneligible) + } + } + } + } + + def handleStateChanges(replicas: Seq[PartitionAndReplica], targetState: ReplicaState): Unit +} + +/** + * This class represents the state machine for replicas. It defines the states that a replica can be in, and + * transitions to move the replica to another legal state. The different states that a replica can be in are - + * 1. NewReplica : The controller can create new replicas during partition reassignment. In this state, a + * replica can only get become follower state change request. 
Valid previous + * state is NonExistentReplica + * 2. OnlineReplica : Once a replica is started and part of the assigned replicas for its partition, it is in this + * state. In this state, it can get either become leader or become follower state change requests. + * Valid previous state are NewReplica, OnlineReplica, OfflineReplica and ReplicaDeletionIneligible + * 3. OfflineReplica : If a replica dies, it moves to this state. This happens when the broker hosting the replica + * is down. Valid previous state are NewReplica, OnlineReplica, OfflineReplica and ReplicaDeletionIneligible + * 4. ReplicaDeletionStarted: If replica deletion starts, it is moved to this state. Valid previous state is OfflineReplica + * 5. ReplicaDeletionSuccessful: If replica responds with no error code in response to a delete replica request, it is + * moved to this state. Valid previous state is ReplicaDeletionStarted + * 6. ReplicaDeletionIneligible: If replica deletion fails, it is moved to this state. Valid previous states are + * ReplicaDeletionStarted and OfflineReplica + * 7. NonExistentReplica: If a replica is deleted successfully, it is moved to this state. Valid previous state is + * ReplicaDeletionSuccessful + */ +class ZkReplicaStateMachine(config: KafkaConfig, + stateChangeLogger: StateChangeLogger, + controllerContext: ControllerContext, + zkClient: KafkaZkClient, + controllerBrokerRequestBatch: ControllerBrokerRequestBatch) + extends ReplicaStateMachine(controllerContext) with Logging { + + private val controllerId = config.brokerId + this.logIdent = s"[ReplicaStateMachine controllerId=$controllerId] " + + override def handleStateChanges(replicas: Seq[PartitionAndReplica], targetState: ReplicaState): Unit = { + if (replicas.nonEmpty) { + try { + controllerBrokerRequestBatch.newBatch() + replicas.groupBy(_.replica).forKeyValue { (replicaId, replicas) => + doHandleStateChanges(replicaId, replicas, targetState) + } + controllerBrokerRequestBatch.sendRequestsToBrokers(controllerContext.epoch) + } catch { + case e: ControllerMovedException => + error(s"Controller moved to another broker when moving some replicas to $targetState state", e) + throw e + case e: Throwable => error(s"Error while moving some replicas to $targetState state", e) + } + } + } + + /** + * This API exercises the replica's state machine. It ensures that every state transition happens from a legal + * previous state to the target state. Valid state transitions are: + * NonExistentReplica --> NewReplica + * --send LeaderAndIsr request with current leader and isr to the new replica and UpdateMetadata request for the + * partition to every live broker + * + * NewReplica -> OnlineReplica + * --add the new replica to the assigned replica list if needed + * + * OnlineReplica,OfflineReplica -> OnlineReplica + * --send LeaderAndIsr request with current leader and isr to the new replica and UpdateMetadata request for the + * partition to every live broker + * + * NewReplica,OnlineReplica,OfflineReplica,ReplicaDeletionIneligible -> OfflineReplica + * --send StopReplicaRequest to the replica (w/o deletion) + * --remove this replica from the isr and send LeaderAndIsr request (with new isr) to the leader replica and + * UpdateMetadata request for the partition to every live broker. 
+ * + * OfflineReplica -> ReplicaDeletionStarted + * --send StopReplicaRequest to the replica (with deletion) + * + * ReplicaDeletionStarted -> ReplicaDeletionSuccessful + * -- mark the state of the replica in the state machine + * + * ReplicaDeletionStarted -> ReplicaDeletionIneligible + * -- mark the state of the replica in the state machine + * + * ReplicaDeletionSuccessful -> NonExistentReplica + * -- remove the replica from the in memory partition replica assignment cache + * + * @param replicaId The replica for which the state transition is invoked + * @param replicas The partitions on this replica for which the state transition is invoked + * @param targetState The end state that the replica should be moved to + */ + private def doHandleStateChanges(replicaId: Int, replicas: Seq[PartitionAndReplica], targetState: ReplicaState): Unit = { + val stateLogger = stateChangeLogger.withControllerEpoch(controllerContext.epoch) + val traceEnabled = stateLogger.isTraceEnabled + replicas.foreach(replica => controllerContext.putReplicaStateIfNotExists(replica, NonExistentReplica)) + val (validReplicas, invalidReplicas) = controllerContext.checkValidReplicaStateChange(replicas, targetState) + invalidReplicas.foreach(replica => logInvalidTransition(replica, targetState)) + + targetState match { + case NewReplica => + validReplicas.foreach { replica => + val partition = replica.topicPartition + val currentState = controllerContext.replicaState(replica) + + controllerContext.partitionLeadershipInfo(partition) match { + case Some(leaderIsrAndControllerEpoch) => + if (leaderIsrAndControllerEpoch.leaderAndIsr.leader == replicaId) { + val exception = new StateChangeFailedException(s"Replica $replicaId for partition $partition cannot be moved to NewReplica state as it is being requested to become leader") + logFailedStateChange(replica, currentState, OfflineReplica, exception) + } else { + controllerBrokerRequestBatch.addLeaderAndIsrRequestForBrokers(Seq(replicaId), + replica.topicPartition, + leaderIsrAndControllerEpoch, + controllerContext.partitionFullReplicaAssignment(replica.topicPartition), + isNew = true) + if (traceEnabled) + logSuccessfulTransition(stateLogger, replicaId, partition, currentState, NewReplica) + controllerContext.putReplicaState(replica, NewReplica) + } + case None => + if (traceEnabled) + logSuccessfulTransition(stateLogger, replicaId, partition, currentState, NewReplica) + controllerContext.putReplicaState(replica, NewReplica) + } + } + case OnlineReplica => + validReplicas.foreach { replica => + val partition = replica.topicPartition + val currentState = controllerContext.replicaState(replica) + + currentState match { + case NewReplica => + val assignment = controllerContext.partitionFullReplicaAssignment(partition) + if (!assignment.replicas.contains(replicaId)) { + error(s"Adding replica ($replicaId) that is not part of the assignment $assignment") + val newAssignment = assignment.copy(replicas = assignment.replicas :+ replicaId) + controllerContext.updatePartitionFullReplicaAssignment(partition, newAssignment) + } + case _ => + controllerContext.partitionLeadershipInfo(partition) match { + case Some(leaderIsrAndControllerEpoch) => + controllerBrokerRequestBatch.addLeaderAndIsrRequestForBrokers(Seq(replicaId), + replica.topicPartition, + leaderIsrAndControllerEpoch, + controllerContext.partitionFullReplicaAssignment(partition), isNew = false) + case None => + } + } + if (traceEnabled) + logSuccessfulTransition(stateLogger, replicaId, partition, currentState, OnlineReplica) + 
controllerContext.putReplicaState(replica, OnlineReplica) + } + case OfflineReplica => + validReplicas.foreach { replica => + controllerBrokerRequestBatch.addStopReplicaRequestForBrokers(Seq(replicaId), replica.topicPartition, deletePartition = false) + } + val (replicasWithLeadershipInfo, replicasWithoutLeadershipInfo) = validReplicas.partition { replica => + controllerContext.partitionLeadershipInfo(replica.topicPartition).isDefined + } + val updatedLeaderIsrAndControllerEpochs = removeReplicasFromIsr(replicaId, replicasWithLeadershipInfo.map(_.topicPartition)) + updatedLeaderIsrAndControllerEpochs.forKeyValue { (partition, leaderIsrAndControllerEpoch) => + stateLogger.info(s"Partition $partition state changed to $leaderIsrAndControllerEpoch after removing replica $replicaId from the ISR as part of transition to $OfflineReplica") + if (!controllerContext.isTopicQueuedUpForDeletion(partition.topic)) { + val recipients = controllerContext.partitionReplicaAssignment(partition).filterNot(_ == replicaId) + controllerBrokerRequestBatch.addLeaderAndIsrRequestForBrokers(recipients, + partition, + leaderIsrAndControllerEpoch, + controllerContext.partitionFullReplicaAssignment(partition), isNew = false) + } + val replica = PartitionAndReplica(partition, replicaId) + val currentState = controllerContext.replicaState(replica) + if (traceEnabled) + logSuccessfulTransition(stateLogger, replicaId, partition, currentState, OfflineReplica) + controllerContext.putReplicaState(replica, OfflineReplica) + } + + replicasWithoutLeadershipInfo.foreach { replica => + val currentState = controllerContext.replicaState(replica) + if (traceEnabled) + logSuccessfulTransition(stateLogger, replicaId, replica.topicPartition, currentState, OfflineReplica) + controllerBrokerRequestBatch.addUpdateMetadataRequestForBrokers(controllerContext.liveOrShuttingDownBrokerIds.toSeq, Set(replica.topicPartition)) + controllerContext.putReplicaState(replica, OfflineReplica) + } + case ReplicaDeletionStarted => + validReplicas.foreach { replica => + val currentState = controllerContext.replicaState(replica) + if (traceEnabled) + logSuccessfulTransition(stateLogger, replicaId, replica.topicPartition, currentState, ReplicaDeletionStarted) + controllerContext.putReplicaState(replica, ReplicaDeletionStarted) + controllerBrokerRequestBatch.addStopReplicaRequestForBrokers(Seq(replicaId), replica.topicPartition, deletePartition = true) + } + case ReplicaDeletionIneligible => + validReplicas.foreach { replica => + val currentState = controllerContext.replicaState(replica) + if (traceEnabled) + logSuccessfulTransition(stateLogger, replicaId, replica.topicPartition, currentState, ReplicaDeletionIneligible) + controllerContext.putReplicaState(replica, ReplicaDeletionIneligible) + } + case ReplicaDeletionSuccessful => + validReplicas.foreach { replica => + val currentState = controllerContext.replicaState(replica) + if (traceEnabled) + logSuccessfulTransition(stateLogger, replicaId, replica.topicPartition, currentState, ReplicaDeletionSuccessful) + controllerContext.putReplicaState(replica, ReplicaDeletionSuccessful) + } + case NonExistentReplica => + validReplicas.foreach { replica => + val currentState = controllerContext.replicaState(replica) + val newAssignedReplicas = controllerContext + .partitionFullReplicaAssignment(replica.topicPartition) + .removeReplica(replica.replica) + + controllerContext.updatePartitionFullReplicaAssignment(replica.topicPartition, newAssignedReplicas) + if (traceEnabled) + logSuccessfulTransition(stateLogger, 
replicaId, replica.topicPartition, currentState, NonExistentReplica) + controllerContext.removeReplicaState(replica) + } + } + } + + /** + * Repeatedly attempt to remove a replica from the isr of multiple partitions until there are no more remaining partitions + * to retry. + * @param replicaId The replica being removed from isr of multiple partitions + * @param partitions The partitions from which we're trying to remove the replica from isr + * @return The updated LeaderIsrAndControllerEpochs of all partitions for which we successfully removed the replica from isr. + */ + private def removeReplicasFromIsr( + replicaId: Int, + partitions: Seq[TopicPartition] + ): Map[TopicPartition, LeaderIsrAndControllerEpoch] = { + var results = Map.empty[TopicPartition, LeaderIsrAndControllerEpoch] + var remaining = partitions + while (remaining.nonEmpty) { + val (finishedRemoval, removalsToRetry) = doRemoveReplicasFromIsr(replicaId, remaining) + remaining = removalsToRetry + + finishedRemoval.foreach { + case (partition, Left(e)) => + val replica = PartitionAndReplica(partition, replicaId) + val currentState = controllerContext.replicaState(replica) + logFailedStateChange(replica, currentState, OfflineReplica, e) + case (partition, Right(leaderIsrAndEpoch)) => + results += partition -> leaderIsrAndEpoch + } + } + results + } + + /** + * Try to remove a replica from the isr of multiple partitions. + * Removing a replica from isr updates partition state in zookeeper. + * + * @param replicaId The replica being removed from isr of multiple partitions + * @param partitions The partitions from which we're trying to remove the replica from isr + * @return A tuple of two elements: + * 1. The updated Right[LeaderIsrAndControllerEpochs] of all partitions for which we successfully + * removed the replica from isr. Or Left[Exception] corresponding to failed removals that should + * not be retried + * 2. The partitions that we should retry due to a zookeeper BADVERSION conflict. Version conflicts can occur if + * the partition leader updated partition state while the controller attempted to update partition state. 
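+ *         For example, removing replica 3 from a partition currently at leader 1, ISR (1, 2, 3)
+ *         writes leader 1, ISR (1, 2); if the removed replica is itself the leader, the new leader
+ *         is set to LeaderAndIsr.NoLeader, and an ISR containing only the removed replica is left
+ *         unchanged.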
+ */ + private def doRemoveReplicasFromIsr( + replicaId: Int, + partitions: Seq[TopicPartition] + ): (Map[TopicPartition, Either[Exception, LeaderIsrAndControllerEpoch]], Seq[TopicPartition]) = { + val (leaderAndIsrs, partitionsWithNoLeaderAndIsrInZk) = getTopicPartitionStatesFromZk(partitions) + val (leaderAndIsrsWithReplica, leaderAndIsrsWithoutReplica) = leaderAndIsrs.partition { case (_, result) => + result.map { leaderAndIsr => + leaderAndIsr.isr.contains(replicaId) + }.getOrElse(false) + } + + val adjustedLeaderAndIsrs: Map[TopicPartition, LeaderAndIsr] = leaderAndIsrsWithReplica.flatMap { + case (partition, result) => + result.toOption.map { leaderAndIsr => + val newLeader = if (replicaId == leaderAndIsr.leader) LeaderAndIsr.NoLeader else leaderAndIsr.leader + val adjustedIsr = if (leaderAndIsr.isr.size == 1) leaderAndIsr.isr else leaderAndIsr.isr.filter(_ != replicaId) + partition -> leaderAndIsr.newLeaderAndIsr(newLeader, adjustedIsr) + } + } + + val UpdateLeaderAndIsrResult(finishedPartitions, updatesToRetry) = zkClient.updateLeaderAndIsr( + adjustedLeaderAndIsrs, controllerContext.epoch, controllerContext.epochZkVersion) + + val exceptionsForPartitionsWithNoLeaderAndIsrInZk: Map[TopicPartition, Either[Exception, LeaderIsrAndControllerEpoch]] = + partitionsWithNoLeaderAndIsrInZk.iterator.flatMap { partition => + if (!controllerContext.isTopicQueuedUpForDeletion(partition.topic)) { + val exception = new StateChangeFailedException( + s"Failed to change state of replica $replicaId for partition $partition since the leader and isr " + + "path in zookeeper is empty" + ) + Option(partition -> Left(exception)) + } else None + }.toMap + + val leaderIsrAndControllerEpochs: Map[TopicPartition, Either[Exception, LeaderIsrAndControllerEpoch]] = + (leaderAndIsrsWithoutReplica ++ finishedPartitions).map { case (partition, result) => + (partition, result.map { leaderAndIsr => + val leaderIsrAndControllerEpoch = LeaderIsrAndControllerEpoch(leaderAndIsr, controllerContext.epoch) + controllerContext.putPartitionLeadershipInfo(partition, leaderIsrAndControllerEpoch) + leaderIsrAndControllerEpoch + }) + } + + if (isDebugEnabled) { + updatesToRetry.foreach { partition => + debug(s"Controller failed to remove replica $replicaId from ISR of partition $partition. " + + s"Attempted to write state ${adjustedLeaderAndIsrs(partition)}, but failed with bad ZK version. This will be retried.") + } + } + + (leaderIsrAndControllerEpochs ++ exceptionsForPartitionsWithNoLeaderAndIsrInZk, updatesToRetry) + } + + /** + * Gets the partition state from zookeeper + * @param partitions the partitions whose state we want from zookeeper + * @return A tuple of two values: + * 1. The Right(LeaderAndIsrs) of partitions whose state we successfully read from zookeeper. + * The Left(Exception) to failed zookeeper lookups or states whose controller epoch exceeds our current epoch + * 2. The partitions that had no leader and isr state in zookeeper. This happens if the controller + * didn't finish partition initialization. 
+ */ + private def getTopicPartitionStatesFromZk( + partitions: Seq[TopicPartition] + ): (Map[TopicPartition, Either[Exception, LeaderAndIsr]], Seq[TopicPartition]) = { + val getDataResponses = try { + zkClient.getTopicPartitionStatesRaw(partitions) + } catch { + case e: Exception => + return (partitions.iterator.map(_ -> Left(e)).toMap, Seq.empty) + } + + val partitionsWithNoLeaderAndIsrInZk = mutable.Buffer.empty[TopicPartition] + val result = mutable.Map.empty[TopicPartition, Either[Exception, LeaderAndIsr]] + + getDataResponses.foreach[Unit] { getDataResponse => + val partition = getDataResponse.ctx.get.asInstanceOf[TopicPartition] + if (getDataResponse.resultCode == Code.OK) { + TopicPartitionStateZNode.decode(getDataResponse.data, getDataResponse.stat) match { + case None => + partitionsWithNoLeaderAndIsrInZk += partition + case Some(leaderIsrAndControllerEpoch) => + if (leaderIsrAndControllerEpoch.controllerEpoch > controllerContext.epoch) { + val exception = new StateChangeFailedException( + "Leader and isr path written by another controller. This probably " + + s"means the current controller with epoch ${controllerContext.epoch} went through a soft failure and " + + s"another controller was elected with epoch ${leaderIsrAndControllerEpoch.controllerEpoch}. Aborting " + + "state change by this controller" + ) + result += (partition -> Left(exception)) + } else { + result += (partition -> Right(leaderIsrAndControllerEpoch.leaderAndIsr)) + } + } + } else if (getDataResponse.resultCode == Code.NONODE) { + partitionsWithNoLeaderAndIsrInZk += partition + } else { + result += (partition -> Left(getDataResponse.resultException.get)) + } + } + + (result.toMap, partitionsWithNoLeaderAndIsrInZk) + } + + private def logSuccessfulTransition(logger: StateChangeLogger, replicaId: Int, partition: TopicPartition, + currState: ReplicaState, targetState: ReplicaState): Unit = { + logger.trace(s"Changed state of replica $replicaId for partition $partition from $currState to $targetState") + } + + private def logInvalidTransition(replica: PartitionAndReplica, targetState: ReplicaState): Unit = { + val currState = controllerContext.replicaState(replica) + val e = new IllegalStateException(s"Replica $replica should be in the ${targetState.validPreviousStates.mkString(",")} " + + s"states before moving to $targetState state. 
Instead it is in $currState state") + logFailedStateChange(replica, currState, targetState, e) + } + + private def logFailedStateChange(replica: PartitionAndReplica, currState: ReplicaState, targetState: ReplicaState, t: Throwable): Unit = { + stateChangeLogger.withControllerEpoch(controllerContext.epoch) + .error(s"Controller $controllerId epoch ${controllerContext.epoch} initiated state change of replica ${replica.replica} " + + s"for partition ${replica.topicPartition} from $currState to $targetState failed", t) + } +} + +sealed trait ReplicaState { + def state: Byte + def validPreviousStates: Set[ReplicaState] +} + +case object NewReplica extends ReplicaState { + val state: Byte = 1 + val validPreviousStates: Set[ReplicaState] = Set(NonExistentReplica) +} + +case object OnlineReplica extends ReplicaState { + val state: Byte = 2 + val validPreviousStates: Set[ReplicaState] = Set(NewReplica, OnlineReplica, OfflineReplica, ReplicaDeletionIneligible) +} + +case object OfflineReplica extends ReplicaState { + val state: Byte = 3 + val validPreviousStates: Set[ReplicaState] = Set(NewReplica, OnlineReplica, OfflineReplica, ReplicaDeletionIneligible) +} + +case object ReplicaDeletionStarted extends ReplicaState { + val state: Byte = 4 + val validPreviousStates: Set[ReplicaState] = Set(OfflineReplica) +} + +case object ReplicaDeletionSuccessful extends ReplicaState { + val state: Byte = 5 + val validPreviousStates: Set[ReplicaState] = Set(ReplicaDeletionStarted) +} + +case object ReplicaDeletionIneligible extends ReplicaState { + val state: Byte = 6 + val validPreviousStates: Set[ReplicaState] = Set(OfflineReplica, ReplicaDeletionStarted) +} + +case object NonExistentReplica extends ReplicaState { + val state: Byte = 7 + val validPreviousStates: Set[ReplicaState] = Set(ReplicaDeletionSuccessful) +} diff --git a/core/src/main/scala/kafka/controller/StateChangeLogger.scala b/core/src/main/scala/kafka/controller/StateChangeLogger.scala new file mode 100644 index 0000000..a1d1bb2 --- /dev/null +++ b/core/src/main/scala/kafka/controller/StateChangeLogger.scala @@ -0,0 +1,50 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.controller + +import com.typesafe.scalalogging.Logger +import kafka.utils.Logging + +object StateChangeLogger { + private val logger = Logger("state.change.logger") +} + +/** + * Simple class that sets `logIdent` appropriately depending on whether the state change logger is being used in the + * context of the KafkaController or not (e.g. ReplicaManager and MetadataCache log to the state change logger + * irrespective of whether the broker is the Controller). 
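+ *
+ * For example, on broker 1 acting as controller in epoch 5 log entries are prefixed with
+ * "[Controller id=1 epoch=5] ", whereas a broker-side instance uses "[Broker id=1] ".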
+ */ +class StateChangeLogger(brokerId: Int, inControllerContext: Boolean, controllerEpoch: Option[Int]) extends Logging { + + if (controllerEpoch.isDefined && !inControllerContext) + throw new IllegalArgumentException("Controller epoch should only be defined if inControllerContext is true") + + override lazy val logger = StateChangeLogger.logger + + locally { + val prefix = if (inControllerContext) "Controller" else "Broker" + val epochEntry = controllerEpoch.fold("")(epoch => s" epoch=$epoch") + logIdent = s"[$prefix id=$brokerId$epochEntry] " + } + + def withControllerEpoch(controllerEpoch: Int): StateChangeLogger = + new StateChangeLogger(brokerId, inControllerContext, Some(controllerEpoch)) + + def messageWithPrefix(message: String): String = msgWithLogIdent(message) + +} diff --git a/core/src/main/scala/kafka/controller/TopicDeletionManager.scala b/core/src/main/scala/kafka/controller/TopicDeletionManager.scala new file mode 100755 index 0000000..0fd7274 --- /dev/null +++ b/core/src/main/scala/kafka/controller/TopicDeletionManager.scala @@ -0,0 +1,358 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package kafka.controller + +import kafka.server.KafkaConfig +import kafka.utils.Logging +import kafka.zk.KafkaZkClient +import org.apache.kafka.common.TopicPartition + +import scala.collection.Set +import scala.collection.mutable + +trait DeletionClient { + def deleteTopic(topic: String, epochZkVersion: Int): Unit + def deleteTopicDeletions(topics: Seq[String], epochZkVersion: Int): Unit + def mutePartitionModifications(topic: String): Unit + def sendMetadataUpdate(partitions: Set[TopicPartition]): Unit +} + +class ControllerDeletionClient(controller: KafkaController, zkClient: KafkaZkClient) extends DeletionClient { + override def deleteTopic(topic: String, epochZkVersion: Int): Unit = { + zkClient.deleteTopicZNode(topic, epochZkVersion) + zkClient.deleteTopicConfigs(Seq(topic), epochZkVersion) + zkClient.deleteTopicDeletions(Seq(topic), epochZkVersion) + } + + override def deleteTopicDeletions(topics: Seq[String], epochZkVersion: Int): Unit = { + zkClient.deleteTopicDeletions(topics, epochZkVersion) + } + + override def mutePartitionModifications(topic: String): Unit = { + controller.unregisterPartitionModificationsHandlers(Seq(topic)) + } + + override def sendMetadataUpdate(partitions: Set[TopicPartition]): Unit = { + controller.sendUpdateMetadataRequest(controller.controllerContext.liveOrShuttingDownBrokerIds.toSeq, partitions) + } +} + +/** + * This manages the state machine for topic deletion. + * 1. TopicCommand issues topic deletion by creating a new admin path /admin/delete_topics/ + * 2. The controller listens for child changes on /admin/delete_topic and starts topic deletion for the respective topics + * 3. 
The controller's ControllerEventThread handles topic deletion. A topic will be ineligible + * for deletion in the following scenarios - + * 3.1 broker hosting one of the replicas for that topic goes down + * 3.2 partition reassignment for partitions of that topic is in progress + * 4. Topic deletion is resumed when - + * 4.1 broker hosting one of the replicas for that topic is started + * 4.2 partition reassignment for partitions of that topic completes + * 5. Every replica for a topic being deleted is in either of the 3 states - + * 5.1 TopicDeletionStarted Replica enters TopicDeletionStarted phase when onPartitionDeletion is invoked. + * This happens when the child change watch for /admin/delete_topics fires on the controller. As part of this state + * change, the controller sends StopReplicaRequests to all replicas. It registers a callback for the + * StopReplicaResponse when deletePartition=true thereby invoking a callback when a response for delete replica + * is received from every replica) + * 5.2 TopicDeletionSuccessful moves replicas from + * TopicDeletionStarted->TopicDeletionSuccessful depending on the error codes in StopReplicaResponse + * 5.3 TopicDeletionFailed moves replicas from + * TopicDeletionStarted->TopicDeletionFailed depending on the error codes in StopReplicaResponse. + * In general, if a broker dies and if it hosted replicas for topics being deleted, the controller marks the + * respective replicas in TopicDeletionFailed state in the onBrokerFailure callback. The reason is that if a + * broker fails before the request is sent and after the replica is in TopicDeletionStarted state, + * it is possible that the replica will mistakenly remain in TopicDeletionStarted state and topic deletion + * will not be retried when the broker comes back up. + * 6. A topic is marked successfully deleted only if all replicas are in TopicDeletionSuccessful + * state. Topic deletion teardown mode deletes all topic state from the controllerContext + * as well as from zookeeper. This is the only time the /brokers/topics/ path gets deleted. On the other hand, + * if no replica is in TopicDeletionStarted state and at least one replica is in TopicDeletionFailed state, then + * it marks the topic for deletion retry. 
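+ * In terms of replica states, a successful deletion moves each replica through OfflineReplica -> ReplicaDeletionStarted -> ReplicaDeletionSuccessful and finally NonExistentReplica once the topic is torn down, while a replica whose StopReplicaResponse carries an error moves to ReplicaDeletionIneligible and is reset to OfflineReplica when deletion is retried.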
+ * @param controller + */ +class TopicDeletionManager(config: KafkaConfig, + controllerContext: ControllerContext, + replicaStateMachine: ReplicaStateMachine, + partitionStateMachine: PartitionStateMachine, + client: DeletionClient) extends Logging { + this.logIdent = s"[Topic Deletion Manager ${config.brokerId}] " + val isDeleteTopicEnabled: Boolean = config.deleteTopicEnable + + def init(initialTopicsToBeDeleted: Set[String], initialTopicsIneligibleForDeletion: Set[String]): Unit = { + info(s"Initializing manager with initial deletions: $initialTopicsToBeDeleted, " + + s"initial ineligible deletions: $initialTopicsIneligibleForDeletion") + + if (isDeleteTopicEnabled) { + controllerContext.queueTopicDeletion(initialTopicsToBeDeleted) + controllerContext.topicsIneligibleForDeletion ++= initialTopicsIneligibleForDeletion & controllerContext.topicsToBeDeleted + } else { + // if delete topic is disabled clean the topic entries under /admin/delete_topics + info(s"Removing $initialTopicsToBeDeleted since delete topic is disabled") + client.deleteTopicDeletions(initialTopicsToBeDeleted.toSeq, controllerContext.epochZkVersion) + } + } + + def tryTopicDeletion(): Unit = { + if (isDeleteTopicEnabled) { + resumeDeletions() + } + } + + /** + * Invoked by the child change listener on /admin/delete_topics to queue up the topics for deletion. The topic gets added + * to the topicsToBeDeleted list and only gets removed from the list when the topic deletion has completed successfully + * i.e. all replicas of all partitions of that topic are deleted successfully. + * @param topics Topics that should be deleted + */ + def enqueueTopicsForDeletion(topics: Set[String]): Unit = { + if (isDeleteTopicEnabled) { + controllerContext.queueTopicDeletion(topics) + resumeDeletions() + } + } + + /** + * Invoked when any event that can possibly resume topic deletion occurs. These events include - + * 1. New broker starts up. Any replicas belonging to topics queued up for deletion can be deleted since the broker is up + * 2. Partition reassignment completes. Any partitions belonging to topics queued up for deletion finished reassignment + * @param topics Topics for which deletion can be resumed + */ + def resumeDeletionForTopics(topics: Set[String] = Set.empty): Unit = { + if (isDeleteTopicEnabled) { + val topicsToResumeDeletion = topics & controllerContext.topicsToBeDeleted + if (topicsToResumeDeletion.nonEmpty) { + controllerContext.topicsIneligibleForDeletion --= topicsToResumeDeletion + resumeDeletions() + } + } + } + + /** + * Invoked when a broker that hosts replicas for topics to be deleted goes down. Also invoked when the callback for + * StopReplicaResponse receives an error code for the replicas of a topic to be deleted. As part of this, the replicas + * are moved from ReplicaDeletionStarted to ReplicaDeletionIneligible state. Also, the topic is added to the list of topics + * ineligible for deletion until further notice. + * @param replicas Replicas for which deletion has failed + */ + def failReplicaDeletion(replicas: Set[PartitionAndReplica]): Unit = { + if (isDeleteTopicEnabled) { + val replicasThatFailedToDelete = replicas.filter(r => isTopicQueuedUpForDeletion(r.topic)) + if (replicasThatFailedToDelete.nonEmpty) { + val topics = replicasThatFailedToDelete.map(_.topic) + debug(s"Deletion failed for replicas ${replicasThatFailedToDelete.mkString(",")}. 
Halting deletion for topics $topics") + replicaStateMachine.handleStateChanges(replicasThatFailedToDelete.toSeq, ReplicaDeletionIneligible) + markTopicIneligibleForDeletion(topics, reason = "replica deletion failure") + resumeDeletions() + } + } + } + + /** + * Halt delete topic if - + * 1. replicas being down + * 2. partition reassignment in progress for some partitions of the topic + * @param topics Topics that should be marked ineligible for deletion. No op if the topic is was not previously queued up for deletion + */ + def markTopicIneligibleForDeletion(topics: Set[String], reason: => String): Unit = { + if (isDeleteTopicEnabled) { + val newTopicsToHaltDeletion = controllerContext.topicsToBeDeleted & topics + controllerContext.topicsIneligibleForDeletion ++= newTopicsToHaltDeletion + if (newTopicsToHaltDeletion.nonEmpty) + info(s"Halted deletion of topics ${newTopicsToHaltDeletion.mkString(",")} due to $reason") + } + } + + private def isTopicIneligibleForDeletion(topic: String): Boolean = { + if (isDeleteTopicEnabled) { + controllerContext.topicsIneligibleForDeletion.contains(topic) + } else + true + } + + private def isTopicDeletionInProgress(topic: String): Boolean = { + if (isDeleteTopicEnabled) { + controllerContext.isAnyReplicaInState(topic, ReplicaDeletionStarted) + } else + false + } + + def isTopicQueuedUpForDeletion(topic: String): Boolean = { + if (isDeleteTopicEnabled) { + controllerContext.isTopicQueuedUpForDeletion(topic) + } else + false + } + + /** + * Invoked by the StopReplicaResponse callback when it receives no error code for a replica of a topic to be deleted. + * As part of this, the replicas are moved from ReplicaDeletionStarted to ReplicaDeletionSuccessful state. Tears down + * the topic if all replicas of a topic have been successfully deleted + * @param replicas Replicas that were successfully deleted by the broker + */ + def completeReplicaDeletion(replicas: Set[PartitionAndReplica]): Unit = { + val successfullyDeletedReplicas = replicas.filter(r => isTopicQueuedUpForDeletion(r.topic)) + debug(s"Deletion successfully completed for replicas ${successfullyDeletedReplicas.mkString(",")}") + replicaStateMachine.handleStateChanges(successfullyDeletedReplicas.toSeq, ReplicaDeletionSuccessful) + resumeDeletions() + } + + /** + * Topic deletion can be retried if - + * 1. Topic deletion is not already complete + * 2. Topic deletion is currently not in progress for that topic + * 3. 
Topic is currently marked ineligible for deletion + * @param topic Topic + * @return Whether or not deletion can be retried for the topic + */ + private def isTopicEligibleForDeletion(topic: String): Boolean = { + controllerContext.isTopicQueuedUpForDeletion(topic) && + !isTopicDeletionInProgress(topic) && + !isTopicIneligibleForDeletion(topic) + } + + /** + * If the topic is queued for deletion but deletion is not currently under progress, then deletion is retried for that topic + * To ensure a successful retry, reset states for respective replicas from ReplicaDeletionIneligible to OfflineReplica state + * @param topics Topics for which deletion should be retried + */ + private def retryDeletionForIneligibleReplicas(topics: Set[String]): Unit = { + // reset replica states from ReplicaDeletionIneligible to OfflineReplica + val failedReplicas = topics.flatMap(controllerContext.replicasInState(_, ReplicaDeletionIneligible)) + debug(s"Retrying deletion of topics ${topics.mkString(",")} since replicas ${failedReplicas.mkString(",")} were not successfully deleted") + replicaStateMachine.handleStateChanges(failedReplicas.toSeq, OfflineReplica) + } + + private def completeDeleteTopic(topic: String): Unit = { + // deregister partition change listener on the deleted topic. This is to prevent the partition change listener + // firing before the new topic listener when a deleted topic gets auto created + client.mutePartitionModifications(topic) + val replicasForDeletedTopic = controllerContext.replicasInState(topic, ReplicaDeletionSuccessful) + // controller will remove this replica from the state machine as well as its partition assignment cache + replicaStateMachine.handleStateChanges(replicasForDeletedTopic.toSeq, NonExistentReplica) + client.deleteTopic(topic, controllerContext.epochZkVersion) + controllerContext.removeTopic(topic) + } + + /** + * Invoked with the list of topics to be deleted + * It invokes onPartitionDeletion for all partitions of a topic. + * The updateMetadataRequest is also going to set the leader for the topics being deleted to + * {@link LeaderAndIsr#LeaderDuringDelete}. This lets each broker know that this topic is being deleted and can be + * removed from their caches. + */ + private def onTopicDeletion(topics: Set[String]): Unit = { + val unseenTopicsForDeletion = topics.diff(controllerContext.topicsWithDeletionStarted) + if (unseenTopicsForDeletion.nonEmpty) { + val unseenPartitionsForDeletion = unseenTopicsForDeletion.flatMap(controllerContext.partitionsForTopic) + partitionStateMachine.handleStateChanges(unseenPartitionsForDeletion.toSeq, OfflinePartition) + partitionStateMachine.handleStateChanges(unseenPartitionsForDeletion.toSeq, NonExistentPartition) + // adding of unseenTopicsForDeletion to topics with deletion started must be done after the partition + // state changes to make sure the offlinePartitionCount metric is properly updated + controllerContext.beginTopicDeletion(unseenTopicsForDeletion) + } + + // send update metadata so that brokers stop serving data for topics to be deleted + client.sendMetadataUpdate(topics.flatMap(controllerContext.partitionsForTopic)) + + onPartitionDeletion(topics) + } + + /** + * Invoked by onTopicDeletion with the list of partitions for topics to be deleted + * It does the following - + * 1. Move all dead replicas directly to ReplicaDeletionIneligible state. Also mark the respective topics ineligible + * for deletion if some replicas are dead since it won't complete successfully anyway + * 2. 
Move all replicas for the partitions to OfflineReplica state. This will send StopReplicaRequest to the replicas + * and LeaderAndIsrRequest to the leader with the shrunk ISR. When the leader replica itself is moved to OfflineReplica state, + * it will skip sending the LeaderAndIsrRequest since the leader will be updated to -1 + * 3. Move all replicas to ReplicaDeletionStarted state. This will send StopReplicaRequest with deletePartition=true. And + * will delete all persistent data from all replicas of the respective partitions + */ + private def onPartitionDeletion(topicsToBeDeleted: Set[String]): Unit = { + val allDeadReplicas = mutable.ListBuffer.empty[PartitionAndReplica] + val allReplicasForDeletionRetry = mutable.ListBuffer.empty[PartitionAndReplica] + val allTopicsIneligibleForDeletion = mutable.Set.empty[String] + + topicsToBeDeleted.foreach { topic => + val (aliveReplicas, deadReplicas) = controllerContext.replicasForTopic(topic).partition { r => + controllerContext.isReplicaOnline(r.replica, r.topicPartition) + } + + val successfullyDeletedReplicas = controllerContext.replicasInState(topic, ReplicaDeletionSuccessful) + val replicasForDeletionRetry = aliveReplicas.diff(successfullyDeletedReplicas) + + allDeadReplicas ++= deadReplicas + allReplicasForDeletionRetry ++= replicasForDeletionRetry + + if (deadReplicas.nonEmpty) { + debug(s"Dead Replicas (${deadReplicas.mkString(",")}) found for topic $topic") + allTopicsIneligibleForDeletion += topic + } + } + + // move dead replicas directly to failed state + replicaStateMachine.handleStateChanges(allDeadReplicas, ReplicaDeletionIneligible) + // send stop replica to all followers that are not in the OfflineReplica state so they stop sending fetch requests to the leader + replicaStateMachine.handleStateChanges(allReplicasForDeletionRetry, OfflineReplica) + replicaStateMachine.handleStateChanges(allReplicasForDeletionRetry, ReplicaDeletionStarted) + + if (allTopicsIneligibleForDeletion.nonEmpty) { + markTopicIneligibleForDeletion(allTopicsIneligibleForDeletion, reason = "offline replicas") + } + } + + private def resumeDeletions(): Unit = { + val topicsQueuedForDeletion = Set.empty[String] ++ controllerContext.topicsToBeDeleted + val topicsEligibleForRetry = mutable.Set.empty[String] + val topicsEligibleForDeletion = mutable.Set.empty[String] + + if (topicsQueuedForDeletion.nonEmpty) + info(s"Handling deletion for topics ${topicsQueuedForDeletion.mkString(",")}") + + topicsQueuedForDeletion.foreach { topic => + // if all replicas are marked as deleted successfully, then topic deletion is done + if (controllerContext.areAllReplicasInState(topic, ReplicaDeletionSuccessful)) { + // clear up all state for this topic from controller cache and zookeeper + completeDeleteTopic(topic) + info(s"Deletion of topic $topic successfully completed") + } else if (!controllerContext.isAnyReplicaInState(topic, ReplicaDeletionStarted)) { + // if you come here, then no replica is in TopicDeletionStarted and all replicas are not in + // TopicDeletionSuccessful. That means, that either given topic haven't initiated deletion + // or there is at least one failed replica (which means topic deletion should be retried). + if (controllerContext.isAnyReplicaInState(topic, ReplicaDeletionIneligible)) { + topicsEligibleForRetry += topic + } + } + + // Add topic to the eligible set if it is eligible for deletion. 
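+ // Eligible here means the topic is still queued for deletion, has no replica currently in ReplicaDeletionStarted, and is not marked ineligible for deletion.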
+ if (isTopicEligibleForDeletion(topic)) { + info(s"Deletion of topic $topic (re)started") + topicsEligibleForDeletion += topic + } + } + + // topic deletion retry will be kicked off + if (topicsEligibleForRetry.nonEmpty) { + retryDeletionForIneligibleReplicas(topicsEligibleForRetry) + } + + // topic deletion will be kicked off + if (topicsEligibleForDeletion.nonEmpty) { + onTopicDeletion(topicsEligibleForDeletion) + } + } +} diff --git a/core/src/main/scala/kafka/coordinator/group/DelayedHeartbeat.scala b/core/src/main/scala/kafka/coordinator/group/DelayedHeartbeat.scala new file mode 100644 index 0000000..55a0496 --- /dev/null +++ b/core/src/main/scala/kafka/coordinator/group/DelayedHeartbeat.scala @@ -0,0 +1,36 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.coordinator.group + +import kafka.server.DelayedOperation + +/** + * Delayed heartbeat operations that are added to the purgatory for session timeout checking. + * Heartbeats are paused during rebalance. + */ +private[group] class DelayedHeartbeat(coordinator: GroupCoordinator, + group: GroupMetadata, + memberId: String, + isPending: Boolean, + timeoutMs: Long) + extends DelayedOperation(timeoutMs, Some(group.lock)) { + + override def tryComplete(): Boolean = coordinator.tryCompleteHeartbeat(group, memberId, isPending, forceComplete _) + override def onExpiration(): Unit = coordinator.onExpireHeartbeat(group, memberId, isPending) + override def onComplete(): Unit = {} +} diff --git a/core/src/main/scala/kafka/coordinator/group/DelayedJoin.scala b/core/src/main/scala/kafka/coordinator/group/DelayedJoin.scala new file mode 100644 index 0000000..22dfa9d --- /dev/null +++ b/core/src/main/scala/kafka/coordinator/group/DelayedJoin.scala @@ -0,0 +1,94 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package kafka.coordinator.group + +import kafka.server.{DelayedOperationPurgatory, GroupJoinKey} + +import scala.math.{max, min} + +/** + * Delayed rebalance operations that are added to the purgatory when group is preparing for rebalance + * + * Whenever a join-group request is received, check if all known group members have requested + * to re-join the group; if yes, complete this operation to proceed rebalance. + * + * When the operation has expired, any known members that have not requested to re-join + * the group are marked as failed, and complete this operation to proceed rebalance with + * the rest of the group. + */ +private[group] class DelayedJoin( + coordinator: GroupCoordinator, + group: GroupMetadata, + rebalanceTimeout: Long +) extends DelayedRebalance( + rebalanceTimeout, + group.lock +) { + override def tryComplete(): Boolean = coordinator.tryCompleteJoin(group, forceComplete _) + + override def onExpiration(): Unit = { + // try to complete delayed actions introduced by coordinator.onCompleteJoin + tryToCompleteDelayedAction() + } + override def onComplete(): Unit = coordinator.onCompleteJoin(group) + + // TODO: remove this ugly chain after we move the action queue to handler thread + private def tryToCompleteDelayedAction(): Unit = coordinator.groupManager.replicaManager.tryCompleteActions() +} + +/** + * Delayed rebalance operation that is added to the purgatory when a group is transitioning from + * Empty to PreparingRebalance + * + * When onComplete is triggered we check if any new members have been added and if there is still time remaining + * before the rebalance timeout. If both are true we then schedule a further delay. Otherwise we complete the + * rebalance. + */ +private[group] class InitialDelayedJoin( + coordinator: GroupCoordinator, + purgatory: DelayedOperationPurgatory[DelayedRebalance], + group: GroupMetadata, + configuredRebalanceDelay: Int, + delayMs: Int, + remainingMs: Int +) extends DelayedJoin( + coordinator, + group, + delayMs +) { + override def tryComplete(): Boolean = false + + override def onComplete(): Unit = { + group.inLock { + if (group.newMemberAdded && remainingMs != 0) { + group.newMemberAdded = false + val delay = min(configuredRebalanceDelay, remainingMs) + val remaining = max(remainingMs - delayMs, 0) + purgatory.tryCompleteElseWatch(new InitialDelayedJoin(coordinator, + purgatory, + group, + configuredRebalanceDelay, + delay, + remaining + ), Seq(GroupJoinKey(group.groupId))) + } else + super.onComplete() + } + } + +} diff --git a/core/src/main/scala/kafka/coordinator/group/DelayedRebalance.scala b/core/src/main/scala/kafka/coordinator/group/DelayedRebalance.scala new file mode 100644 index 0000000..bad109a --- /dev/null +++ b/core/src/main/scala/kafka/coordinator/group/DelayedRebalance.scala @@ -0,0 +1,34 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.coordinator.group + +import kafka.server.DelayedOperation + +import java.util.concurrent.locks.Lock + +/** + * Delayed rebalance operation that is shared by DelayedJoin and DelayedSync + * operations. This allows us to use a common purgatory for both cases. + */ +private[group] abstract class DelayedRebalance( + rebalanceTimeoutMs: Long, + groupLock: Lock +) extends DelayedOperation( + rebalanceTimeoutMs, + Some(groupLock) +) diff --git a/core/src/main/scala/kafka/coordinator/group/DelayedSync.scala b/core/src/main/scala/kafka/coordinator/group/DelayedSync.scala new file mode 100644 index 0000000..a39adef --- /dev/null +++ b/core/src/main/scala/kafka/coordinator/group/DelayedSync.scala @@ -0,0 +1,48 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.coordinator.group + +/** + * Delayed rebalance operation that is added to the purgatory when the group is completing the + * rebalance. + * + * Whenever a SyncGroup is received, checks that we received all the SyncGroup request from + * each member of the group; if yes, complete this operation. + * + * When the operation has expired, any known members that have not sent a SyncGroup requests + * are removed from the group. If any members is removed, the group is rebalanced. + */ +private[group] class DelayedSync( + coordinator: GroupCoordinator, + group: GroupMetadata, + generationId: Int, + rebalanceTimeoutMs: Long +) extends DelayedRebalance( + rebalanceTimeoutMs, + group.lock +) { + override def tryComplete(): Boolean = { + coordinator.tryCompletePendingSync(group, generationId, forceComplete _) + } + + override def onExpiration(): Unit = { + coordinator.onExpirePendingSync(group, generationId) + } + + override def onComplete(): Unit = { } +} diff --git a/core/src/main/scala/kafka/coordinator/group/GroupCoordinator.scala b/core/src/main/scala/kafka/coordinator/group/GroupCoordinator.scala new file mode 100644 index 0000000..0fa1ebb --- /dev/null +++ b/core/src/main/scala/kafka/coordinator/group/GroupCoordinator.scala @@ -0,0 +1,1726 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package kafka.coordinator.group + +import java.util.Properties +import java.util.concurrent.atomic.AtomicBoolean +import kafka.common.OffsetAndMetadata +import kafka.log.LogConfig +import kafka.message.ProducerCompressionCodec +import kafka.server._ +import kafka.utils.Logging +import org.apache.kafka.common.TopicPartition +import org.apache.kafka.common.internals.Topic +import org.apache.kafka.common.message.JoinGroupResponseData.JoinGroupResponseMember +import org.apache.kafka.common.message.LeaveGroupRequestData.MemberIdentity +import org.apache.kafka.common.metrics.Metrics +import org.apache.kafka.common.metrics.stats.Meter +import org.apache.kafka.common.protocol.{ApiKeys, Errors} +import org.apache.kafka.common.requests._ +import org.apache.kafka.common.utils.Time + +import scala.collection.{Map, Seq, Set, immutable, mutable} +import scala.math.max + +/** + * GroupCoordinator handles general group membership and offset management. + * + * Each Kafka server instantiates a coordinator which is responsible for a set of + * groups. Groups are assigned to coordinators based on their group names. + *
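+ * A group is mapped to its coordinator by hashing the group id onto a partition of the internal __consumer_offsets topic; the broker leading that partition acts as the group's coordinator.
+ *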
            + * Delayed operation locking notes: + * Delayed operations in GroupCoordinator use `group` as the delayed operation + * lock. ReplicaManager.appendRecords may be invoked while holding the group lock + * used by its callback. The delayed callback may acquire the group lock + * since the delayed operation is completed only if the group lock can be acquired. + */ +class GroupCoordinator(val brokerId: Int, + val groupConfig: GroupConfig, + val offsetConfig: OffsetConfig, + val groupManager: GroupMetadataManager, + val heartbeatPurgatory: DelayedOperationPurgatory[DelayedHeartbeat], + val rebalancePurgatory: DelayedOperationPurgatory[DelayedRebalance], + time: Time, + metrics: Metrics) extends Logging { + import GroupCoordinator._ + + type JoinCallback = JoinGroupResult => Unit + type SyncCallback = SyncGroupResult => Unit + + /* setup metrics */ + val offsetDeletionSensor = metrics.sensor("OffsetDeletions") + + offsetDeletionSensor.add(new Meter( + metrics.metricName("offset-deletion-rate", + "group-coordinator-metrics", + "The rate of administrative deleted offsets"), + metrics.metricName("offset-deletion-count", + "group-coordinator-metrics", + "The total number of administrative deleted offsets"))) + + val groupCompletedRebalanceSensor = metrics.sensor("CompletedRebalances") + + groupCompletedRebalanceSensor.add(new Meter( + metrics.metricName("group-completed-rebalance-rate", + "group-coordinator-metrics", + "The rate of completed rebalance"), + metrics.metricName("group-completed-rebalance-count", + "group-coordinator-metrics", + "The total number of completed rebalance"))) + + this.logIdent = "[GroupCoordinator " + brokerId + "]: " + + private val isActive = new AtomicBoolean(false) + + def offsetsTopicConfigs: Properties = { + val props = new Properties + props.put(LogConfig.CleanupPolicyProp, LogConfig.Compact) + props.put(LogConfig.SegmentBytesProp, offsetConfig.offsetsTopicSegmentBytes.toString) + props.put(LogConfig.CompressionTypeProp, ProducerCompressionCodec.name) + + props + } + + /** + * NOTE: If a group lock and metadataLock are simultaneously needed, + * be sure to acquire the group lock before metadataLock to prevent deadlock + */ + + /** + * Startup logic executed at the same time when the server starts up. + */ + def startup(retrieveGroupMetadataTopicPartitionCount: () => Int, enableMetadataExpiration: Boolean = true): Unit = { + info("Starting up.") + groupManager.startup(retrieveGroupMetadataTopicPartitionCount, enableMetadataExpiration) + isActive.set(true) + info("Startup complete.") + } + + /** + * Shutdown logic executed at the same time when server shuts down. + * Ordering of actions should be reversed from the startup process. + */ + def shutdown(): Unit = { + info("Shutting down.") + isActive.set(false) + groupManager.shutdown() + heartbeatPurgatory.shutdown() + rebalancePurgatory.shutdown() + info("Shutdown complete.") + } + + /** + * Verify if the group has space to accept the joining member. The various + * criteria are explained below. + */ + private def acceptJoiningMember(group: GroupMetadata, member: String): Boolean = { + group.currentState match { + // Always accept the request when the group is empty or dead + case Empty | Dead => + true + + // An existing member is accepted if it is already awaiting. New members are accepted + // up to the max group size. 
Note that the number of awaiting members is used here + // for two reasons: + // 1) the group size is not reliable as it could already be above the max group size + // if the max group size was reduced. + // 2) using the number of awaiting members allows to kick out the last rejoining + // members of the group. + case PreparingRebalance => + (group.has(member) && group.get(member).isAwaitingJoin) || + group.numAwaiting < groupConfig.groupMaxSize + + // An existing member is accepted. New members are accepted up to the max group size. + // Note that the group size is used here. When the group transitions to CompletingRebalance, + // members which haven't rejoined are removed. + case CompletingRebalance | Stable => + group.has(member) || group.size < groupConfig.groupMaxSize + } + } + + def handleJoinGroup(groupId: String, + memberId: String, + groupInstanceId: Option[String], + requireKnownMemberId: Boolean, + clientId: String, + clientHost: String, + rebalanceTimeoutMs: Int, + sessionTimeoutMs: Int, + protocolType: String, + protocols: List[(String, Array[Byte])], + responseCallback: JoinCallback, + requestLocal: RequestLocal = RequestLocal.NoCaching): Unit = { + validateGroupStatus(groupId, ApiKeys.JOIN_GROUP).foreach { error => + responseCallback(JoinGroupResult(memberId, error)) + return + } + + if (sessionTimeoutMs < groupConfig.groupMinSessionTimeoutMs || + sessionTimeoutMs > groupConfig.groupMaxSessionTimeoutMs) { + responseCallback(JoinGroupResult(memberId, Errors.INVALID_SESSION_TIMEOUT)) + } else { + val isUnknownMember = memberId == JoinGroupRequest.UNKNOWN_MEMBER_ID + // group is created if it does not exist and the member id is UNKNOWN. if member + // is specified but group does not exist, request is rejected with UNKNOWN_MEMBER_ID + groupManager.getOrMaybeCreateGroup(groupId, isUnknownMember) match { + case None => + responseCallback(JoinGroupResult(memberId, Errors.UNKNOWN_MEMBER_ID)) + case Some(group) => + group.inLock { + if (!acceptJoiningMember(group, memberId)) { + group.remove(memberId) + responseCallback(JoinGroupResult(JoinGroupRequest.UNKNOWN_MEMBER_ID, Errors.GROUP_MAX_SIZE_REACHED)) + } else if (isUnknownMember) { + doNewMemberJoinGroup( + group, + groupInstanceId, + requireKnownMemberId, + clientId, + clientHost, + rebalanceTimeoutMs, + sessionTimeoutMs, + protocolType, + protocols, + responseCallback, + requestLocal + ) + } else { + doCurrentMemberJoinGroup( + group, + memberId, + groupInstanceId, + clientId, + clientHost, + rebalanceTimeoutMs, + sessionTimeoutMs, + protocolType, + protocols, + responseCallback + ) + } + + // attempt to complete JoinGroup + if (group.is(PreparingRebalance)) { + rebalancePurgatory.checkAndComplete(GroupJoinKey(group.groupId)) + } + } + } + } + } + + private def doNewMemberJoinGroup( + group: GroupMetadata, + groupInstanceId: Option[String], + requireKnownMemberId: Boolean, + clientId: String, + clientHost: String, + rebalanceTimeoutMs: Int, + sessionTimeoutMs: Int, + protocolType: String, + protocols: List[(String, Array[Byte])], + responseCallback: JoinCallback, + requestLocal: RequestLocal + ): Unit = { + group.inLock { + if (group.is(Dead)) { + // if the group is marked as dead, it means some other thread has just removed the group + // from the coordinator metadata; it is likely that the group has migrated to some other + // coordinator OR the group is in a transient unstable phase. Let the member retry + // finding the correct coordinator and rejoin. 
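+ // A group reaches the Dead state when it is unloaded during coordinator emigration, deleted, or expired while empty.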
+ responseCallback(JoinGroupResult(JoinGroupRequest.UNKNOWN_MEMBER_ID, Errors.COORDINATOR_NOT_AVAILABLE)) + } else if (!group.supportsProtocols(protocolType, MemberMetadata.plainProtocolSet(protocols))) { + responseCallback(JoinGroupResult(JoinGroupRequest.UNKNOWN_MEMBER_ID, Errors.INCONSISTENT_GROUP_PROTOCOL)) + } else { + val newMemberId = group.generateMemberId(clientId, groupInstanceId) + groupInstanceId match { + case Some(instanceId) => + doStaticNewMemberJoinGroup( + group, + instanceId, + newMemberId, + clientId, + clientHost, + rebalanceTimeoutMs, + sessionTimeoutMs, + protocolType, + protocols, + responseCallback, + requestLocal + ) + case None => + doDynamicNewMemberJoinGroup( + group, + requireKnownMemberId, + newMemberId, + clientId, + clientHost, + rebalanceTimeoutMs, + sessionTimeoutMs, + protocolType, + protocols, + responseCallback + ) + } + } + } + } + + private def doStaticNewMemberJoinGroup( + group: GroupMetadata, + groupInstanceId: String, + newMemberId: String, + clientId: String, + clientHost: String, + rebalanceTimeoutMs: Int, + sessionTimeoutMs: Int, + protocolType: String, + protocols: List[(String, Array[Byte])], + responseCallback: JoinCallback, + requestLocal: RequestLocal + ): Unit = { + group.currentStaticMemberId(groupInstanceId) match { + case Some(oldMemberId) => + info(s"Static member with groupInstanceId=$groupInstanceId and unknown member id joins " + + s"group ${group.groupId} in ${group.currentState} state. Replacing previously mapped " + + s"member $oldMemberId with this groupInstanceId.") + updateStaticMemberAndRebalance(group, oldMemberId, newMemberId, groupInstanceId, protocols, responseCallback, requestLocal) + + case None => + info(s"Static member with groupInstanceId=$groupInstanceId and unknown member id joins " + + s"group ${group.groupId} in ${group.currentState} state. Created a new member id $newMemberId " + + s"for this member and add to the group.") + addMemberAndRebalance(rebalanceTimeoutMs, sessionTimeoutMs, newMemberId, Some(groupInstanceId), + clientId, clientHost, protocolType, protocols, group, responseCallback) + } + } + + private def doDynamicNewMemberJoinGroup( + group: GroupMetadata, + requireKnownMemberId: Boolean, + newMemberId: String, + clientId: String, + clientHost: String, + rebalanceTimeoutMs: Int, + sessionTimeoutMs: Int, + protocolType: String, + protocols: List[(String, Array[Byte])], + responseCallback: JoinCallback + ): Unit = { + if (requireKnownMemberId) { + // If member id required, register the member in the pending member list and send + // back a response to call for another join group request with allocated member id. + info(s"Dynamic member with unknown member id joins group ${group.groupId} in " + + s"${group.currentState} state. Created a new member id $newMemberId and request the " + + s"member to rejoin with this id.") + group.addPendingMember(newMemberId) + addPendingMemberExpiration(group, newMemberId, sessionTimeoutMs) + responseCallback(JoinGroupResult(newMemberId, Errors.MEMBER_ID_REQUIRED)) + } else { + info(s"Dynamic Member with unknown member id joins group ${group.groupId} in " + + s"${group.currentState} state. 
Created a new member id $newMemberId for this member " + + s"and add to the group.") + addMemberAndRebalance(rebalanceTimeoutMs, sessionTimeoutMs, newMemberId, None, + clientId, clientHost, protocolType, protocols, group, responseCallback) + } + } + + private def validateCurrentMember( + group: GroupMetadata, + memberId: String, + groupInstanceId: Option[String], + operation: String + ): Option[Errors] = { + // We are validating two things: + // 1. If `groupInstanceId` is present, then it exists and is mapped to `memberId` + // 2. The `memberId` exists in the group + groupInstanceId.flatMap { instanceId => + group.currentStaticMemberId(instanceId) match { + case Some(currentMemberId) if currentMemberId != memberId => + info(s"Request memberId=$memberId for static member with groupInstanceId=$instanceId " + + s"is fenced by current memberId=$currentMemberId during operation $operation") + Some(Errors.FENCED_INSTANCE_ID) + case Some(_) => + None + case None => + Some(Errors.UNKNOWN_MEMBER_ID) + } + }.orElse { + if (!group.has(memberId)) { + Some(Errors.UNKNOWN_MEMBER_ID) + } else { + None + } + } + } + + private def doCurrentMemberJoinGroup( + group: GroupMetadata, + memberId: String, + groupInstanceId: Option[String], + clientId: String, + clientHost: String, + rebalanceTimeoutMs: Int, + sessionTimeoutMs: Int, + protocolType: String, + protocols: List[(String, Array[Byte])], + responseCallback: JoinCallback + ): Unit = { + group.inLock { + if (group.is(Dead)) { + // if the group is marked as dead, it means some other thread has just removed the group + // from the coordinator metadata; it is likely that the group has migrated to some other + // coordinator OR the group is in a transient unstable phase. Let the member retry + // finding the correct coordinator and rejoin. + responseCallback(JoinGroupResult(memberId, Errors.COORDINATOR_NOT_AVAILABLE)) + } else if (!group.supportsProtocols(protocolType, MemberMetadata.plainProtocolSet(protocols))) { + responseCallback(JoinGroupResult(memberId, Errors.INCONSISTENT_GROUP_PROTOCOL)) + } else if (group.isPendingMember(memberId)) { + // A rejoining pending member will be accepted. Note that pending member cannot be a static member. + groupInstanceId.foreach { instanceId => + throw new IllegalStateException(s"Received unexpected JoinGroup with groupInstanceId=$instanceId " + + s"for pending member with memberId=$memberId") + } + + debug(s"Pending dynamic member with id $memberId joins group ${group.groupId} in " + + s"${group.currentState} state. Adding to the group now.") + addMemberAndRebalance(rebalanceTimeoutMs, sessionTimeoutMs, memberId, None, + clientId, clientHost, protocolType, protocols, group, responseCallback) + } else { + val memberErrorOpt = validateCurrentMember( + group, + memberId, + groupInstanceId, + operation = "join-group" + ) + + memberErrorOpt match { + case Some(error) => responseCallback(JoinGroupResult(memberId, error)) + + case None => group.currentState match { + case PreparingRebalance => + val member = group.get(memberId) + updateMemberAndRebalance(group, member, protocols, s"Member ${member.memberId} joining group during ${group.currentState}", responseCallback) + + case CompletingRebalance => + val member = group.get(memberId) + if (member.matches(protocols)) { + // member is joining with the same metadata (which could be because it failed to + // receive the initial JoinGroup response), so just return current group information + // for the current generation. 
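+ // Only the group leader receives the full member metadata in the response; followers get an empty member list.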
+ responseCallback(JoinGroupResult( + members = if (group.isLeader(memberId)) { + group.currentMemberMetadata + } else { + List.empty + }, + memberId = memberId, + generationId = group.generationId, + protocolType = group.protocolType, + protocolName = group.protocolName, + leaderId = group.leaderOrNull, + error = Errors.NONE)) + } else { + // member has changed metadata, so force a rebalance + updateMemberAndRebalance(group, member, protocols, s"Updating metadata for member ${member.memberId} during ${group.currentState}", responseCallback) + } + + case Stable => + val member = group.get(memberId) + if (group.isLeader(memberId)) { + // force a rebalance if the leader sends JoinGroup; + // This allows the leader to trigger rebalances for changes affecting assignment + // which do not affect the member metadata (such as topic metadata changes for the consumer) + updateMemberAndRebalance(group, member, protocols, s"Leader ${member.memberId} re-joining group during ${group.currentState}", responseCallback) + } else if (!member.matches(protocols)) { + updateMemberAndRebalance(group, member, protocols, s"Updating metadata for member ${member.memberId} during ${group.currentState}", responseCallback) + } else { + // for followers with no actual change to their metadata, just return group information + // for the current generation which will allow them to issue SyncGroup + responseCallback(JoinGroupResult( + members = List.empty, + memberId = memberId, + generationId = group.generationId, + protocolType = group.protocolType, + protocolName = group.protocolName, + leaderId = group.leaderOrNull, + error = Errors.NONE)) + } + + case Empty | Dead => + // Group reaches unexpected state. Let the joining member reset their generation and rejoin. + warn(s"Attempt to add rejoining member $memberId of group ${group.groupId} in " + + s"unexpected group state ${group.currentState}") + responseCallback(JoinGroupResult(memberId, Errors.UNKNOWN_MEMBER_ID)) + } + } + } + } + } + + def handleSyncGroup(groupId: String, + generation: Int, + memberId: String, + protocolType: Option[String], + protocolName: Option[String], + groupInstanceId: Option[String], + groupAssignment: Map[String, Array[Byte]], + responseCallback: SyncCallback, + requestLocal: RequestLocal = RequestLocal.NoCaching): Unit = { + validateGroupStatus(groupId, ApiKeys.SYNC_GROUP) match { + case Some(error) if error == Errors.COORDINATOR_LOAD_IN_PROGRESS => + // The coordinator is loading, which means we've lost the state of the active rebalance and the + // group will need to start over at JoinGroup. By returning rebalance in progress, the consumer + // will attempt to rejoin without needing to rediscover the coordinator. Note that we cannot + // return COORDINATOR_LOAD_IN_PROGRESS since older clients do not expect the error. 
+ responseCallback(SyncGroupResult(Errors.REBALANCE_IN_PROGRESS)) + + case Some(error) => responseCallback(SyncGroupResult(error)) + + case None => + groupManager.getGroup(groupId) match { + case None => responseCallback(SyncGroupResult(Errors.UNKNOWN_MEMBER_ID)) + case Some(group) => doSyncGroup(group, generation, memberId, protocolType, protocolName, + groupInstanceId, groupAssignment, requestLocal, responseCallback) + } + } + } + + private def validateSyncGroup( + group: GroupMetadata, + generationId: Int, + memberId: String, + protocolType: Option[String], + protocolName: Option[String], + groupInstanceId: Option[String], + ): Option[Errors] = { + if (group.is(Dead)) { + // if the group is marked as dead, it means some other thread has just removed the group + // from the coordinator metadata; this is likely that the group has migrated to some other + // coordinator OR the group is in a transient unstable phase. Let the member retry + // finding the correct coordinator and rejoin. + Some(Errors.COORDINATOR_NOT_AVAILABLE) + } else { + validateCurrentMember( + group, + memberId, + groupInstanceId, + operation = "sync-group" + ).orElse { + if (generationId != group.generationId) { + Some(Errors.ILLEGAL_GENERATION) + } else if (protocolType.isDefined && !group.protocolType.contains(protocolType.get)) { + Some(Errors.INCONSISTENT_GROUP_PROTOCOL) + } else if (protocolName.isDefined && !group.protocolName.contains(protocolName.get)) { + Some(Errors.INCONSISTENT_GROUP_PROTOCOL) + } else { + None + } + } + } + } + + private def doSyncGroup(group: GroupMetadata, + generationId: Int, + memberId: String, + protocolType: Option[String], + protocolName: Option[String], + groupInstanceId: Option[String], + groupAssignment: Map[String, Array[Byte]], + requestLocal: RequestLocal, + responseCallback: SyncCallback): Unit = { + group.inLock { + val validationErrorOpt = validateSyncGroup( + group, + generationId, + memberId, + protocolType, + protocolName, + groupInstanceId + ) + + validationErrorOpt match { + case Some(error) => responseCallback(SyncGroupResult(error)) + + case None => group.currentState match { + case Empty => + responseCallback(SyncGroupResult(Errors.UNKNOWN_MEMBER_ID)) + + case PreparingRebalance => + responseCallback(SyncGroupResult(Errors.REBALANCE_IN_PROGRESS)) + + case CompletingRebalance => + group.get(memberId).awaitingSyncCallback = responseCallback + removePendingSyncMember(group, memberId) + + // if this is the leader, then we can attempt to persist state and transition to stable + if (group.isLeader(memberId)) { + info(s"Assignment received from leader $memberId for group ${group.groupId} for generation ${group.generationId}. " + + s"The group has ${group.size} members, ${group.allStaticMembers.size} of which are static.") + + // fill any missing members with an empty assignment + val missing = group.allMembers.diff(groupAssignment.keySet) + val assignment = groupAssignment ++ missing.map(_ -> Array.empty[Byte]).toMap + + if (missing.nonEmpty) { + warn(s"Setting empty assignments for members $missing of ${group.groupId} for generation ${group.generationId}") + } + + groupManager.storeGroup(group, assignment, (error: Errors) => { + group.inLock { + // another member may have joined the group while we were awaiting this callback, + // so we must ensure we are still in the CompletingRebalance state and the same generation + // when it gets invoked. 
if we have transitioned to another state, then do nothing + if (group.is(CompletingRebalance) && generationId == group.generationId) { + if (error != Errors.NONE) { + resetAndPropagateAssignmentError(group, error) + maybePrepareRebalance(group, s"Error when storing group assignment during SyncGroup (member: $memberId)") + } else { + setAndPropagateAssignment(group, assignment) + group.transitionTo(Stable) + } + } + } + }, requestLocal) + groupCompletedRebalanceSensor.record() + } + + case Stable => + removePendingSyncMember(group, memberId) + + // if the group is stable, we just return the current assignment + val memberMetadata = group.get(memberId) + responseCallback(SyncGroupResult(group.protocolType, group.protocolName, memberMetadata.assignment, Errors.NONE)) + completeAndScheduleNextHeartbeatExpiration(group, group.get(memberId)) + + case Dead => + throw new IllegalStateException(s"Reached unexpected condition for Dead group ${group.groupId}") + } + } + } + } + + def handleLeaveGroup(groupId: String, + leavingMembers: List[MemberIdentity], + responseCallback: LeaveGroupResult => Unit): Unit = { + + def removeCurrentMemberFromGroup(group: GroupMetadata, memberId: String): Unit = { + val member = group.get(memberId) + removeMemberAndUpdateGroup(group, member, s"Removing member $memberId on LeaveGroup") + removeHeartbeatForLeavingMember(group, member.memberId) + info(s"Member $member has left group $groupId through explicit `LeaveGroup` request") + } + + validateGroupStatus(groupId, ApiKeys.LEAVE_GROUP) match { + case Some(error) => + responseCallback(leaveError(error, List.empty)) + case None => + groupManager.getGroup(groupId) match { + case None => + responseCallback(leaveError(Errors.NONE, leavingMembers.map {leavingMember => + memberLeaveError(leavingMember, Errors.UNKNOWN_MEMBER_ID) + })) + case Some(group) => + group.inLock { + if (group.is(Dead)) { + responseCallback(leaveError(Errors.COORDINATOR_NOT_AVAILABLE, List.empty)) + } else { + val memberErrors = leavingMembers.map { leavingMember => + val memberId = leavingMember.memberId + val groupInstanceId = Option(leavingMember.groupInstanceId) + + // The LeaveGroup API allows administrative removal of members by GroupInstanceId + // in which case we expect the MemberId to be undefined. 
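+ // If the instance id maps to a known static member, that member is removed and its delayed heartbeat completed; otherwise UNKNOWN_MEMBER_ID is returned for this entry.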
+ if (memberId == JoinGroupRequest.UNKNOWN_MEMBER_ID) { + groupInstanceId.flatMap(group.currentStaticMemberId) match { + case Some(currentMemberId) => + removeCurrentMemberFromGroup(group, currentMemberId) + memberLeaveError(leavingMember, Errors.NONE) + case None => + memberLeaveError(leavingMember, Errors.UNKNOWN_MEMBER_ID) + } + } else if (group.isPendingMember(memberId)) { + removePendingMemberAndUpdateGroup(group, memberId) + heartbeatPurgatory.checkAndComplete(MemberKey(group.groupId, memberId)) + info(s"Pending member with memberId=$memberId has left group ${group.groupId} " + + s"through explicit `LeaveGroup` request") + memberLeaveError(leavingMember, Errors.NONE) + } else { + val memberError = validateCurrentMember( + group, + memberId, + groupInstanceId, + operation = "leave-group" + ).getOrElse { + removeCurrentMemberFromGroup(group, memberId) + Errors.NONE + } + memberLeaveError(leavingMember, memberError) + } + } + responseCallback(leaveError(Errors.NONE, memberErrors)) + } + } + } + } + } + + def handleDeleteGroups(groupIds: Set[String], + requestLocal: RequestLocal = RequestLocal.NoCaching): Map[String, Errors] = { + val groupErrors = mutable.Map.empty[String, Errors] + val groupsEligibleForDeletion = mutable.ArrayBuffer[GroupMetadata]() + + groupIds.foreach { groupId => + validateGroupStatus(groupId, ApiKeys.DELETE_GROUPS) match { + case Some(error) => + groupErrors += groupId -> error + + case None => + groupManager.getGroup(groupId) match { + case None => + groupErrors += groupId -> + (if (groupManager.groupNotExists(groupId)) Errors.GROUP_ID_NOT_FOUND else Errors.NOT_COORDINATOR) + case Some(group) => + group.inLock { + group.currentState match { + case Dead => + groupErrors += groupId -> + (if (groupManager.groupNotExists(groupId)) Errors.GROUP_ID_NOT_FOUND else Errors.NOT_COORDINATOR) + case Empty => + group.transitionTo(Dead) + groupsEligibleForDeletion += group + case Stable | PreparingRebalance | CompletingRebalance => + groupErrors(groupId) = Errors.NON_EMPTY_GROUP + } + } + } + } + } + + if (groupsEligibleForDeletion.nonEmpty) { + val offsetsRemoved = groupManager.cleanupGroupMetadata(groupsEligibleForDeletion, requestLocal, + _.removeAllOffsets()) + groupErrors ++= groupsEligibleForDeletion.map(_.groupId -> Errors.NONE).toMap + info(s"The following groups were deleted: ${groupsEligibleForDeletion.map(_.groupId).mkString(", ")}. 
" + + s"A total of $offsetsRemoved offsets were removed.") + } + + groupErrors + } + + def handleDeleteOffsets(groupId: String, partitions: Seq[TopicPartition], + requestLocal: RequestLocal): (Errors, Map[TopicPartition, Errors]) = { + var groupError: Errors = Errors.NONE + var partitionErrors: Map[TopicPartition, Errors] = Map() + var partitionsEligibleForDeletion: Seq[TopicPartition] = Seq() + + validateGroupStatus(groupId, ApiKeys.OFFSET_DELETE) match { + case Some(error) => + groupError = error + + case None => + groupManager.getGroup(groupId) match { + case None => + groupError = if (groupManager.groupNotExists(groupId)) + Errors.GROUP_ID_NOT_FOUND else Errors.NOT_COORDINATOR + + case Some(group) => + group.inLock { + group.currentState match { + case Dead => + groupError = if (groupManager.groupNotExists(groupId)) + Errors.GROUP_ID_NOT_FOUND else Errors.NOT_COORDINATOR + + case Empty => + partitionsEligibleForDeletion = partitions + + case PreparingRebalance | CompletingRebalance | Stable if group.isConsumerGroup => + val (consumed, notConsumed) = + partitions.partition(tp => group.isSubscribedToTopic(tp.topic())) + + partitionsEligibleForDeletion = notConsumed + partitionErrors = consumed.map(_ -> Errors.GROUP_SUBSCRIBED_TO_TOPIC).toMap + + case _ => + groupError = Errors.NON_EMPTY_GROUP + } + } + + if (partitionsEligibleForDeletion.nonEmpty) { + val offsetsRemoved = groupManager.cleanupGroupMetadata(Seq(group), requestLocal, + _.removeOffsets(partitionsEligibleForDeletion)) + + partitionErrors ++= partitionsEligibleForDeletion.map(_ -> Errors.NONE).toMap + + offsetDeletionSensor.record(offsetsRemoved) + + info(s"The following offsets of the group $groupId were deleted: ${partitionsEligibleForDeletion.mkString(", ")}. " + + s"A total of $offsetsRemoved offsets were removed.") + } + } + } + + // If there is a group error, the partition errors is empty + groupError -> partitionErrors + } + + private def validateHeartbeat( + group: GroupMetadata, + generationId: Int, + memberId: String, + groupInstanceId: Option[String] + ): Option[Errors] = { + if (group.is(Dead)) { + Some(Errors.COORDINATOR_NOT_AVAILABLE) + } else { + validateCurrentMember( + group, + memberId, + groupInstanceId, + operation = "heartbeat" + ).orElse { + if (generationId != group.generationId) { + Some(Errors.ILLEGAL_GENERATION) + } else { + None + } + } + } + } + + def handleHeartbeat(groupId: String, + memberId: String, + groupInstanceId: Option[String], + generationId: Int, + responseCallback: Errors => Unit): Unit = { + validateGroupStatus(groupId, ApiKeys.HEARTBEAT).foreach { error => + if (error == Errors.COORDINATOR_LOAD_IN_PROGRESS) + // the group is still loading, so respond just blindly + responseCallback(Errors.NONE) + else + responseCallback(error) + return + } + + val err = groupManager.getGroup(groupId) match { + case None => + Errors.UNKNOWN_MEMBER_ID + + case Some(group) => group.inLock { + val validationErrorOpt = validateHeartbeat( + group, + generationId, + memberId, + groupInstanceId + ) + + if (validationErrorOpt.isDefined) { + validationErrorOpt.get + } else { + group.currentState match { + case Empty => + Errors.UNKNOWN_MEMBER_ID + + case CompletingRebalance => + // consumers may start sending heartbeat after join-group response, in which case + // we should treat them as normal hb request and reset the timer + val member = group.get(memberId) + completeAndScheduleNextHeartbeatExpiration(group, member) + Errors.NONE + + case PreparingRebalance => + val member = group.get(memberId) + 
completeAndScheduleNextHeartbeatExpiration(group, member) + Errors.REBALANCE_IN_PROGRESS + + case Stable => + val member = group.get(memberId) + completeAndScheduleNextHeartbeatExpiration(group, member) + Errors.NONE + + case Dead => + throw new IllegalStateException(s"Reached unexpected condition for Dead group $groupId") + } + } + } + } + responseCallback(err) + } + + def handleTxnCommitOffsets(groupId: String, + producerId: Long, + producerEpoch: Short, + memberId: String, + groupInstanceId: Option[String], + generationId: Int, + offsetMetadata: immutable.Map[TopicPartition, OffsetAndMetadata], + responseCallback: immutable.Map[TopicPartition, Errors] => Unit, + requestLocal: RequestLocal = RequestLocal.NoCaching): Unit = { + validateGroupStatus(groupId, ApiKeys.TXN_OFFSET_COMMIT) match { + case Some(error) => responseCallback(offsetMetadata.map { case (k, _) => k -> error }) + case None => + val group = groupManager.getGroup(groupId).getOrElse { + groupManager.addGroup(new GroupMetadata(groupId, Empty, time)) + } + doTxnCommitOffsets(group, memberId, groupInstanceId, generationId, producerId, producerEpoch, + offsetMetadata, requestLocal, responseCallback) + } + } + + def handleCommitOffsets(groupId: String, + memberId: String, + groupInstanceId: Option[String], + generationId: Int, + offsetMetadata: immutable.Map[TopicPartition, OffsetAndMetadata], + responseCallback: immutable.Map[TopicPartition, Errors] => Unit, + requestLocal: RequestLocal = RequestLocal.NoCaching): Unit = { + validateGroupStatus(groupId, ApiKeys.OFFSET_COMMIT) match { + case Some(error) => responseCallback(offsetMetadata.map { case (k, _) => k -> error }) + case None => + groupManager.getGroup(groupId) match { + case None => + if (generationId < 0) { + // the group is not relying on Kafka for group management, so allow the commit + val group = groupManager.addGroup(new GroupMetadata(groupId, Empty, time)) + doCommitOffsets(group, memberId, groupInstanceId, generationId, offsetMetadata, + responseCallback, requestLocal) + } else { + // or this is a request coming from an older generation. 
either way, reject the commit + responseCallback(offsetMetadata.map { case (k, _) => k -> Errors.ILLEGAL_GENERATION }) + } + + case Some(group) => + doCommitOffsets(group, memberId, groupInstanceId, generationId, offsetMetadata, + responseCallback, requestLocal) + } + } + } + + def scheduleHandleTxnCompletion(producerId: Long, + offsetsPartitions: Iterable[TopicPartition], + transactionResult: TransactionResult): Unit = { + require(offsetsPartitions.forall(_.topic == Topic.GROUP_METADATA_TOPIC_NAME)) + val isCommit = transactionResult == TransactionResult.COMMIT + groupManager.scheduleHandleTxnCompletion(producerId, offsetsPartitions.map(_.partition).toSet, isCommit) + } + + private def doTxnCommitOffsets(group: GroupMetadata, + memberId: String, + groupInstanceId: Option[String], + generationId: Int, + producerId: Long, + producerEpoch: Short, + offsetMetadata: immutable.Map[TopicPartition, OffsetAndMetadata], + requestLocal: RequestLocal, + responseCallback: immutable.Map[TopicPartition, Errors] => Unit): Unit = { + group.inLock { + val validationErrorOpt = validateOffsetCommit( + group, + generationId, + memberId, + groupInstanceId, + isTransactional = true + ) + + if (validationErrorOpt.isDefined) { + responseCallback(offsetMetadata.map { case (k, _) => k -> validationErrorOpt.get }) + } else { + groupManager.storeOffsets(group, memberId, offsetMetadata, responseCallback, producerId, + producerEpoch, requestLocal) + } + } + } + + private def validateOffsetCommit( + group: GroupMetadata, + generationId: Int, + memberId: String, + groupInstanceId: Option[String], + isTransactional: Boolean + ): Option[Errors] = { + if (group.is(Dead)) { + Some(Errors.COORDINATOR_NOT_AVAILABLE) + } else if (generationId >= 0 || memberId != JoinGroupRequest.UNKNOWN_MEMBER_ID || groupInstanceId.isDefined) { + validateCurrentMember( + group, + memberId, + groupInstanceId, + operation = if (isTransactional) "txn-offset-commit" else "offset-commit" + ).orElse { + if (generationId != group.generationId) { + Some(Errors.ILLEGAL_GENERATION) + } else { + None + } + } + } else if (!isTransactional && !group.is(Empty)) { + // When the group is non-empty, only members can commit offsets. + // This does not apply to transactional offset commits, since the + // older versions of this protocol do not require memberId and + // generationId. 
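+ // Rejecting such commits protects an active, Kafka-managed group from having its offsets overwritten by clients that are not members of the group.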
+ Some(Errors.UNKNOWN_MEMBER_ID) + } else { + None + } + } + + private def doCommitOffsets(group: GroupMetadata, + memberId: String, + groupInstanceId: Option[String], + generationId: Int, + offsetMetadata: immutable.Map[TopicPartition, OffsetAndMetadata], + responseCallback: immutable.Map[TopicPartition, Errors] => Unit, + requestLocal: RequestLocal): Unit = { + group.inLock { + val validationErrorOpt = validateOffsetCommit( + group, + generationId, + memberId, + groupInstanceId, + isTransactional = false + ) + + if (validationErrorOpt.isDefined) { + responseCallback(offsetMetadata.map { case (k, _) => k -> validationErrorOpt.get }) + } else { + group.currentState match { + case Empty => + groupManager.storeOffsets(group, memberId, offsetMetadata, responseCallback) + + case Stable | PreparingRebalance => + // During PreparingRebalance phase, we still allow a commit request since we rely + // on heartbeat response to eventually notify the rebalance in progress signal to the consumer + val member = group.get(memberId) + completeAndScheduleNextHeartbeatExpiration(group, member) + groupManager.storeOffsets(group, memberId, offsetMetadata, responseCallback, requestLocal = requestLocal) + + case CompletingRebalance => + // We should not receive a commit request if the group has not completed rebalance; + // but since the consumer's member.id and generation is valid, it means it has received + // the latest group generation information from the JoinResponse. + // So let's return a REBALANCE_IN_PROGRESS to let consumer handle it gracefully. + responseCallback(offsetMetadata.map { case (k, _) => k -> Errors.REBALANCE_IN_PROGRESS }) + + case _ => + throw new RuntimeException(s"Logic error: unexpected group state ${group.currentState}") + } + } + } + } + + def handleFetchOffsets(groupId: String, requireStable: Boolean, partitions: Option[Seq[TopicPartition]] = None): + (Errors, Map[TopicPartition, OffsetFetchResponse.PartitionData]) = { + + validateGroupStatus(groupId, ApiKeys.OFFSET_FETCH) match { + case Some(error) => error -> Map.empty + case None => + // return offsets blindly regardless the current group state since the group may be using + // Kafka commit storage without automatic group management + (Errors.NONE, groupManager.getOffsets(groupId, requireStable, partitions)) + } + } + + def handleListGroups(states: Set[String]): (Errors, List[GroupOverview]) = { + if (!isActive.get) { + (Errors.COORDINATOR_NOT_AVAILABLE, List[GroupOverview]()) + } else { + val errorCode = if (groupManager.isLoading) Errors.COORDINATOR_LOAD_IN_PROGRESS else Errors.NONE + // if states is empty, return all groups + val groups = if (states.isEmpty) + groupManager.currentGroups + else + groupManager.currentGroups.filter(g => states.contains(g.summary.state)) + (errorCode, groups.map(_.overview).toList) + } + } + + def handleDescribeGroup(groupId: String): (Errors, GroupSummary) = { + validateGroupStatus(groupId, ApiKeys.DESCRIBE_GROUPS) match { + case Some(error) => (error, GroupCoordinator.EmptyGroup) + case None => + groupManager.getGroup(groupId) match { + case None => (Errors.NONE, GroupCoordinator.DeadGroup) + case Some(group) => + group.inLock { + (Errors.NONE, group.summary) + } + } + } + } + + def handleDeletedPartitions(topicPartitions: Seq[TopicPartition], requestLocal: RequestLocal): Unit = { + val offsetsRemoved = groupManager.cleanupGroupMetadata(groupManager.currentGroups, requestLocal, + _.removeOffsets(topicPartitions)) + info(s"Removed $offsetsRemoved offsets associated with deleted partitions: 
${topicPartitions.mkString(", ")}.") + } + + private def isValidGroupId(groupId: String, api: ApiKeys): Boolean = { + api match { + case ApiKeys.OFFSET_COMMIT | ApiKeys.OFFSET_FETCH | ApiKeys.DESCRIBE_GROUPS | ApiKeys.DELETE_GROUPS => + // For backwards compatibility, we support the offset commit APIs for the empty groupId, and also + // in DescribeGroups and DeleteGroups so that users can view and delete state of all groups. + groupId != null + case _ => + // The remaining APIs are groups using Kafka for group coordination and must have a non-empty groupId + groupId != null && !groupId.isEmpty + } + } + + /** + * Check that the groupId is valid, assigned to this coordinator and that the group has been loaded. + */ + private def validateGroupStatus(groupId: String, api: ApiKeys): Option[Errors] = { + if (!isValidGroupId(groupId, api)) + Some(Errors.INVALID_GROUP_ID) + else if (!isActive.get) + Some(Errors.COORDINATOR_NOT_AVAILABLE) + else if (isCoordinatorLoadInProgress(groupId)) + Some(Errors.COORDINATOR_LOAD_IN_PROGRESS) + else if (!isCoordinatorForGroup(groupId)) + Some(Errors.NOT_COORDINATOR) + else + None + } + + private def onGroupUnloaded(group: GroupMetadata): Unit = { + group.inLock { + info(s"Unloading group metadata for ${group.groupId} with generation ${group.generationId}") + val previousState = group.currentState + group.transitionTo(Dead) + + previousState match { + case Empty | Dead => + case PreparingRebalance => + for (member <- group.allMemberMetadata) { + group.maybeInvokeJoinCallback(member, JoinGroupResult(member.memberId, Errors.NOT_COORDINATOR)) + } + + rebalancePurgatory.checkAndComplete(GroupJoinKey(group.groupId)) + + case Stable | CompletingRebalance => + for (member <- group.allMemberMetadata) { + group.maybeInvokeSyncCallback(member, SyncGroupResult(Errors.NOT_COORDINATOR)) + heartbeatPurgatory.checkAndComplete(MemberKey(group.groupId, member.memberId)) + } + } + + removeSyncExpiration(group) + } + } + + private def onGroupLoaded(group: GroupMetadata): Unit = { + group.inLock { + info(s"Loading group metadata for ${group.groupId} with generation ${group.generationId}") + assert(group.is(Stable) || group.is(Empty)) + if (groupIsOverCapacity(group)) { + prepareRebalance(group, s"Freshly-loaded group is over capacity (${groupConfig.groupMaxSize}). " + + "Rebalancing in order to give a chance for consumers to commit offsets") + } + + group.allMemberMetadata.foreach(completeAndScheduleNextHeartbeatExpiration(group, _)) + } + } + + /** + * Load cached state from the given partition and begin handling requests for groups which map to it. + * + * @param offsetTopicPartitionId The partition we are now leading + */ + def onElection(offsetTopicPartitionId: Int, coordinatorEpoch: Int): Unit = { + info(s"Elected as the group coordinator for partition $offsetTopicPartitionId in epoch $coordinatorEpoch") + groupManager.scheduleLoadGroupAndOffsets(offsetTopicPartitionId, coordinatorEpoch, onGroupLoaded) + } + + /** + * Unload cached state for the given partition and stop handling requests for groups which map to it. 
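+ * Pending join and sync callbacks for the affected groups are completed with NOT_COORDINATOR via the onGroupUnloaded callback.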
+ * + * @param offsetTopicPartitionId The partition we are no longer leading + */ + def onResignation(offsetTopicPartitionId: Int, coordinatorEpoch: Option[Int]): Unit = { + info(s"Resigned as the group coordinator for partition $offsetTopicPartitionId in epoch $coordinatorEpoch") + groupManager.removeGroupsForPartition(offsetTopicPartitionId, coordinatorEpoch, onGroupUnloaded) + } + + private def setAndPropagateAssignment(group: GroupMetadata, assignment: Map[String, Array[Byte]]): Unit = { + assert(group.is(CompletingRebalance)) + group.allMemberMetadata.foreach(member => member.assignment = assignment(member.memberId)) + propagateAssignment(group, Errors.NONE) + } + + private def resetAndPropagateAssignmentError(group: GroupMetadata, error: Errors): Unit = { + assert(group.is(CompletingRebalance)) + group.allMemberMetadata.foreach(_.assignment = Array.empty) + propagateAssignment(group, error) + } + + private def propagateAssignment(group: GroupMetadata, error: Errors): Unit = { + val (protocolType, protocolName) = if (error == Errors.NONE) + (group.protocolType, group.protocolName) + else + (None, None) + for (member <- group.allMemberMetadata) { + if (member.assignment.isEmpty && error == Errors.NONE) { + warn(s"Sending empty assignment to member ${member.memberId} of ${group.groupId} for generation ${group.generationId} with no errors") + } + + if (group.maybeInvokeSyncCallback(member, SyncGroupResult(protocolType, protocolName, member.assignment, error))) { + // reset the session timeout for members after propagating the member's assignment. + // This is because if any member's session expired while we were still awaiting either + // the leader sync group or the storage callback, its expiration will be ignored and no + // future heartbeat expectations will be scheduled. 
+ completeAndScheduleNextHeartbeatExpiration(group, member) + } + } + } + + /** + * Complete existing DelayedHeartbeats for the given member and schedule the next one + */ + private def completeAndScheduleNextHeartbeatExpiration(group: GroupMetadata, member: MemberMetadata): Unit = { + completeAndScheduleNextExpiration(group, member, member.sessionTimeoutMs) + } + + private def completeAndScheduleNextExpiration(group: GroupMetadata, member: MemberMetadata, timeoutMs: Long): Unit = { + val memberKey = MemberKey(group.groupId, member.memberId) + + // complete current heartbeat expectation + member.heartbeatSatisfied = true + heartbeatPurgatory.checkAndComplete(memberKey) + + // reschedule the next heartbeat expiration deadline + member.heartbeatSatisfied = false + val delayedHeartbeat = new DelayedHeartbeat(this, group, member.memberId, isPending = false, timeoutMs) + heartbeatPurgatory.tryCompleteElseWatch(delayedHeartbeat, Seq(memberKey)) + } + + /** + * Add pending member expiration to heartbeat purgatory + */ + private def addPendingMemberExpiration(group: GroupMetadata, pendingMemberId: String, timeoutMs: Long): Unit = { + val pendingMemberKey = MemberKey(group.groupId, pendingMemberId) + val delayedHeartbeat = new DelayedHeartbeat(this, group, pendingMemberId, isPending = true, timeoutMs) + heartbeatPurgatory.tryCompleteElseWatch(delayedHeartbeat, Seq(pendingMemberKey)) + } + + private def removeHeartbeatForLeavingMember(group: GroupMetadata, memberId: String): Unit = { + val memberKey = MemberKey(group.groupId, memberId) + heartbeatPurgatory.checkAndComplete(memberKey) + } + + private def addMemberAndRebalance(rebalanceTimeoutMs: Int, + sessionTimeoutMs: Int, + memberId: String, + groupInstanceId: Option[String], + clientId: String, + clientHost: String, + protocolType: String, + protocols: List[(String, Array[Byte])], + group: GroupMetadata, + callback: JoinCallback): Unit = { + val member = new MemberMetadata(memberId, groupInstanceId, clientId, clientHost, + rebalanceTimeoutMs, sessionTimeoutMs, protocolType, protocols) + + member.isNew = true + + // update the newMemberAdded flag to indicate that the join group can be further delayed + if (group.is(PreparingRebalance) && group.generationId == 0) + group.newMemberAdded = true + + group.add(member, callback) + + // The session timeout does not affect new members since they do not have their memberId and + // cannot send heartbeats. Furthermore, we cannot detect disconnects because sockets are muted + // while the JoinGroup is in purgatory. If the client does disconnect (e.g. because of a request + // timeout during a long rebalance), they may simply retry which will lead to a lot of defunct + // members in the rebalance. To prevent this going on indefinitely, we timeout JoinGroup requests + // for new members. If the new member is still there, we expect it to retry. + completeAndScheduleNextExpiration(group, member, NewMemberJoinTimeoutMs) + + maybePrepareRebalance(group, s"Adding new member $memberId with group instance id $groupInstanceId") + } + + private def updateStaticMemberAndRebalance(group: GroupMetadata, + oldMemberId: String, + newMemberId: String, + groupInstanceId: String, + protocols: List[(String, Array[Byte])], + responseCallback: JoinCallback, + requestLocal: RequestLocal): Unit = { + val currentLeader = group.leaderOrNull + val member = group.replaceStaticMember(groupInstanceId, oldMemberId, newMemberId) + // Heartbeat of old member id will expire without effect since the group no longer contains that member id. 
+ // New heartbeat shall be scheduled with new member id. + completeAndScheduleNextHeartbeatExpiration(group, member) + + val knownStaticMember = group.get(newMemberId) + group.updateMember(knownStaticMember, protocols, responseCallback) + val oldProtocols = knownStaticMember.supportedProtocols + + group.currentState match { + case Stable => + // check if group's selectedProtocol of next generation will change, if not, simply store group to persist the + // updated static member, if yes, rebalance should be triggered to let the group's assignment and selectProtocol consistent + val selectedProtocolOfNextGeneration = group.selectProtocol + if (group.protocolName.contains(selectedProtocolOfNextGeneration)) { + info(s"Static member which joins during Stable stage and doesn't affect selectProtocol will not trigger rebalance.") + val groupAssignment: Map[String, Array[Byte]] = group.allMemberMetadata.map(member => member.memberId -> member.assignment).toMap + groupManager.storeGroup(group, groupAssignment, error => { + if (error != Errors.NONE) { + warn(s"Failed to persist metadata for group ${group.groupId}: ${error.message}") + + // Failed to persist member.id of the given static member, revert the update of the static member in the group. + group.updateMember(knownStaticMember, oldProtocols, null) + val oldMember = group.replaceStaticMember(groupInstanceId, newMemberId, oldMemberId) + completeAndScheduleNextHeartbeatExpiration(group, oldMember) + responseCallback(JoinGroupResult( + List.empty, + memberId = JoinGroupRequest.UNKNOWN_MEMBER_ID, + generationId = group.generationId, + protocolType = group.protocolType, + protocolName = group.protocolName, + leaderId = currentLeader, + error = error + )) + } else { + group.maybeInvokeJoinCallback(member, JoinGroupResult( + members = List.empty, + memberId = newMemberId, + generationId = group.generationId, + protocolType = group.protocolType, + protocolName = group.protocolName, + // We want to avoid current leader performing trivial assignment while the group + // is in stable stage, because the new assignment in leader's next sync call + // won't be broadcast by a stable group. This could be guaranteed by + // always returning the old leader id so that the current leader won't assume itself + // as a leader based on the returned message, since the new member.id won't match + // returned leader id, therefore no assignment will be performed. + leaderId = currentLeader, + error = Errors.NONE)) + } + }, requestLocal) + } else { + maybePrepareRebalance(group, s"Group's selectedProtocol will change because static member ${member.memberId} with instance id $groupInstanceId joined with change of protocol") + } + case CompletingRebalance => + // if the group is in after-sync stage, upon getting a new join-group of a known static member + // we should still trigger a new rebalance, since the old member may already be sent to the leader + // for assignment, and hence when the assignment gets back there would be a mismatch of the old member id + // with the new replaced member id. As a result the new member id would not get any assignment. 
+ prepareRebalance(group, s"Updating metadata for static member ${member.memberId} with instance id $groupInstanceId") + case Empty | Dead => + throw new IllegalStateException(s"Group ${group.groupId} was not supposed to be " + + s"in the state ${group.currentState} when the unknown static member $groupInstanceId rejoins.") + case PreparingRebalance => + } + } + + private def updateMemberAndRebalance(group: GroupMetadata, + member: MemberMetadata, + protocols: List[(String, Array[Byte])], + reason: String, + callback: JoinCallback): Unit = { + group.updateMember(member, protocols, callback) + maybePrepareRebalance(group, reason) + } + + private def maybePrepareRebalance(group: GroupMetadata, reason: String): Unit = { + group.inLock { + if (group.canRebalance) + prepareRebalance(group, reason) + } + } + + // package private for testing + private[group] def prepareRebalance(group: GroupMetadata, reason: String): Unit = { + // if any members are awaiting sync, cancel their request and have them rejoin + if (group.is(CompletingRebalance)) + resetAndPropagateAssignmentError(group, Errors.REBALANCE_IN_PROGRESS) + + // if a sync expiration is pending, cancel it. + removeSyncExpiration(group) + + val delayedRebalance = if (group.is(Empty)) + new InitialDelayedJoin(this, + rebalancePurgatory, + group, + groupConfig.groupInitialRebalanceDelayMs, + groupConfig.groupInitialRebalanceDelayMs, + max(group.rebalanceTimeoutMs - groupConfig.groupInitialRebalanceDelayMs, 0)) + else + new DelayedJoin(this, group, group.rebalanceTimeoutMs) + + group.transitionTo(PreparingRebalance) + + info(s"Preparing to rebalance group ${group.groupId} in state ${group.currentState} with old generation " + + s"${group.generationId} (${Topic.GROUP_METADATA_TOPIC_NAME}-${partitionFor(group.groupId)}) (reason: $reason)") + + val groupKey = GroupJoinKey(group.groupId) + rebalancePurgatory.tryCompleteElseWatch(delayedRebalance, Seq(groupKey)) + } + + private def removeMemberAndUpdateGroup(group: GroupMetadata, member: MemberMetadata, reason: String): Unit = { + // New members may timeout with a pending JoinGroup while the group is still rebalancing, so we have + // to invoke the callback before removing the member. We return UNKNOWN_MEMBER_ID so that the consumer + // will retry the JoinGroup request if is still active. 
+ group.maybeInvokeJoinCallback(member, JoinGroupResult(JoinGroupRequest.UNKNOWN_MEMBER_ID, Errors.UNKNOWN_MEMBER_ID)) + group.remove(member.memberId) + + group.currentState match { + case Dead | Empty => + case Stable | CompletingRebalance => maybePrepareRebalance(group, reason) + case PreparingRebalance => rebalancePurgatory.checkAndComplete(GroupJoinKey(group.groupId)) + } + } + + private def removePendingMemberAndUpdateGroup(group: GroupMetadata, memberId: String): Unit = { + group.remove(memberId) + + if (group.is(PreparingRebalance)) { + rebalancePurgatory.checkAndComplete(GroupJoinKey(group.groupId)) + } + } + + def tryCompleteJoin(group: GroupMetadata, forceComplete: () => Boolean): Boolean = { + group.inLock { + if (group.hasAllMembersJoined) + forceComplete() + else false + } + } + + def onCompleteJoin(group: GroupMetadata): Unit = { + group.inLock { + val notYetRejoinedDynamicMembers = group.notYetRejoinedMembers.filterNot(_._2.isStaticMember) + if (notYetRejoinedDynamicMembers.nonEmpty) { + info(s"Group ${group.groupId} removed dynamic members " + + s"who haven't joined: ${notYetRejoinedDynamicMembers.keySet}") + + notYetRejoinedDynamicMembers.values.foreach { failedMember => + group.remove(failedMember.memberId) + removeHeartbeatForLeavingMember(group, failedMember.memberId) + } + } + + if (group.is(Dead)) { + info(s"Group ${group.groupId} is dead, skipping rebalance stage") + } else if (!group.maybeElectNewJoinedLeader() && group.allMembers.nonEmpty) { + // If all members are not rejoining, we will postpone the completion + // of rebalance preparing stage, and send out another delayed operation + // until session timeout removes all the non-responsive members. + error(s"Group ${group.groupId} could not complete rebalance because no members rejoined") + rebalancePurgatory.tryCompleteElseWatch( + new DelayedJoin(this, group, group.rebalanceTimeoutMs), + Seq(GroupJoinKey(group.groupId))) + } else { + group.initNextGeneration() + if (group.is(Empty)) { + info(s"Group ${group.groupId} with generation ${group.generationId} is now empty " + + s"(${Topic.GROUP_METADATA_TOPIC_NAME}-${partitionFor(group.groupId)})") + + groupManager.storeGroup(group, Map.empty, error => { + if (error != Errors.NONE) { + // we failed to write the empty group metadata. If the broker fails before another rebalance, + // the previous generation written to the log will become active again (and most likely timeout). + // This should be safe since there are no active members in an empty generation, so we just warn. 
+ warn(s"Failed to write empty metadata for group ${group.groupId}: ${error.message}") + } + }, RequestLocal.NoCaching) + } else { + info(s"Stabilized group ${group.groupId} generation ${group.generationId} " + + s"(${Topic.GROUP_METADATA_TOPIC_NAME}-${partitionFor(group.groupId)}) with ${group.size} members") + + // trigger the awaiting join group response callback for all the members after rebalancing + for (member <- group.allMemberMetadata) { + val joinResult = JoinGroupResult( + members = if (group.isLeader(member.memberId)) { + group.currentMemberMetadata + } else { + List.empty + }, + memberId = member.memberId, + generationId = group.generationId, + protocolType = group.protocolType, + protocolName = group.protocolName, + leaderId = group.leaderOrNull, + error = Errors.NONE) + + group.maybeInvokeJoinCallback(member, joinResult) + completeAndScheduleNextHeartbeatExpiration(group, member) + member.isNew = false + + group.addPendingSyncMember(member.memberId) + } + + schedulePendingSync(group) + } + } + } + } + + private def removePendingSyncMember( + group: GroupMetadata, + memberId: String + ): Unit = { + group.removePendingSyncMember(memberId) + maybeCompleteSyncExpiration(group) + } + + private def removeSyncExpiration( + group: GroupMetadata + ): Unit = { + group.clearPendingSyncMembers() + maybeCompleteSyncExpiration(group) + } + + private def maybeCompleteSyncExpiration( + group: GroupMetadata + ): Unit = { + val groupKey = GroupSyncKey(group.groupId) + rebalancePurgatory.checkAndComplete(groupKey) + } + + private def schedulePendingSync( + group: GroupMetadata + ): Unit = { + val delayedSync = new DelayedSync(this, group, group.generationId, group.rebalanceTimeoutMs) + val groupKey = GroupSyncKey(group.groupId) + rebalancePurgatory.tryCompleteElseWatch(delayedSync, Seq(groupKey)) + } + + def tryCompletePendingSync( + group: GroupMetadata, + generationId: Int, + forceComplete: () => Boolean + ): Boolean = { + group.inLock { + if (generationId != group.generationId) { + forceComplete() + } else { + group.currentState match { + case Dead | Empty | PreparingRebalance => + forceComplete() + case CompletingRebalance | Stable => + if (group.hasReceivedSyncFromAllMembers) + forceComplete() + else false + } + } + } + } + + def onExpirePendingSync( + group: GroupMetadata, + generationId: Int + ): Unit = { + group.inLock { + if (generationId != group.generationId) { + error(s"Received unexpected notification of sync expiration for ${group.groupId} " + + s"with an old generation $generationId while the group has ${group.generationId}.") + } else { + group.currentState match { + case Dead | Empty | PreparingRebalance => + error(s"Received unexpected notification of sync expiration after group ${group.groupId} " + + s"already transitioned to the ${group.currentState} state.") + + case CompletingRebalance | Stable => + if (!group.hasReceivedSyncFromAllMembers) { + val pendingSyncMembers = group.allPendingSyncMembers + + pendingSyncMembers.foreach { memberId => + group.remove(memberId) + removeHeartbeatForLeavingMember(group, memberId) + } + + debug(s"Group ${group.groupId} removed members who haven't " + + s"sent their sync request: $pendingSyncMembers") + + prepareRebalance(group, s"Removing $pendingSyncMembers on pending sync request expiration") + } + } + } + } + } + + def tryCompleteHeartbeat(group: GroupMetadata, + memberId: String, + isPending: Boolean, + forceComplete: () => Boolean): Boolean = { + group.inLock { + // The group has been unloaded and invalid, we should complete the 
heartbeat. + if (group.is(Dead)) { + forceComplete() + } else if (isPending) { + // complete the heartbeat if the member has joined the group + if (group.has(memberId)) { + forceComplete() + } else false + } else if (shouldCompleteNonPendingHeartbeat(group, memberId)) { + forceComplete() + } else false + } + } + + def shouldCompleteNonPendingHeartbeat(group: GroupMetadata, memberId: String): Boolean = { + if (group.has(memberId)) { + val member = group.get(memberId) + member.hasSatisfiedHeartbeat + } else { + debug(s"Member id $memberId was not found in ${group.groupId} during heartbeat completion check") + true + } + } + + def onExpireHeartbeat(group: GroupMetadata, memberId: String, isPending: Boolean): Unit = { + group.inLock { + if (group.is(Dead)) { + info(s"Received notification of heartbeat expiration for member $memberId after group ${group.groupId} had already been unloaded or deleted.") + } else if (isPending) { + info(s"Pending member $memberId in group ${group.groupId} has been removed after session timeout expiration.") + removePendingMemberAndUpdateGroup(group, memberId) + } else if (!group.has(memberId)) { + debug(s"Member $memberId has already been removed from the group.") + } else { + val member = group.get(memberId) + if (!member.hasSatisfiedHeartbeat) { + info(s"Member ${member.memberId} in group ${group.groupId} has failed, removing it from the group") + removeMemberAndUpdateGroup(group, member, s"removing member ${member.memberId} on heartbeat expiration") + } + } + } + } + + def partitionFor(group: String): Int = groupManager.partitionFor(group) + + private def groupIsOverCapacity(group: GroupMetadata): Boolean = { + group.size > groupConfig.groupMaxSize + } + + private def isCoordinatorForGroup(groupId: String) = groupManager.isGroupLocal(groupId) + + private def isCoordinatorLoadInProgress(groupId: String) = groupManager.isGroupLoading(groupId) +} + +object GroupCoordinator { + + val NoState = "" + val NoProtocolType = "" + val NoProtocol = "" + val NoLeader = "" + val NoGeneration = -1 + val NoMembers = List[MemberSummary]() + val EmptyGroup = GroupSummary(NoState, NoProtocolType, NoProtocol, NoMembers) + val DeadGroup = GroupSummary(Dead.toString, NoProtocolType, NoProtocol, NoMembers) + val NewMemberJoinTimeoutMs: Int = 5 * 60 * 1000 + + def apply(config: KafkaConfig, + replicaManager: ReplicaManager, + time: Time, + metrics: Metrics): GroupCoordinator = { + val heartbeatPurgatory = DelayedOperationPurgatory[DelayedHeartbeat]("Heartbeat", config.brokerId) + val rebalancePurgatory = DelayedOperationPurgatory[DelayedRebalance]("Rebalance", config.brokerId) + GroupCoordinator(config, replicaManager, heartbeatPurgatory, rebalancePurgatory, time, metrics) + } + + private[group] def offsetConfig(config: KafkaConfig) = OffsetConfig( + maxMetadataSize = config.offsetMetadataMaxSize, + loadBufferSize = config.offsetsLoadBufferSize, + offsetsRetentionMs = config.offsetsRetentionMinutes * 60L * 1000L, + offsetsRetentionCheckIntervalMs = config.offsetsRetentionCheckIntervalMs, + offsetsTopicNumPartitions = config.offsetsTopicPartitions, + offsetsTopicSegmentBytes = config.offsetsTopicSegmentBytes, + offsetsTopicReplicationFactor = config.offsetsTopicReplicationFactor, + offsetsTopicCompressionCodec = config.offsetsTopicCompressionCodec, + offsetCommitTimeoutMs = config.offsetCommitTimeoutMs, + offsetCommitRequiredAcks = config.offsetCommitRequiredAcks + ) + + def apply(config: KafkaConfig, + replicaManager: ReplicaManager, + heartbeatPurgatory: 
DelayedOperationPurgatory[DelayedHeartbeat], + rebalancePurgatory: DelayedOperationPurgatory[DelayedRebalance], + time: Time, + metrics: Metrics): GroupCoordinator = { + val offsetConfig = this.offsetConfig(config) + val groupConfig = GroupConfig(groupMinSessionTimeoutMs = config.groupMinSessionTimeoutMs, + groupMaxSessionTimeoutMs = config.groupMaxSessionTimeoutMs, + groupMaxSize = config.groupMaxSize, + groupInitialRebalanceDelayMs = config.groupInitialRebalanceDelay) + + val groupMetadataManager = new GroupMetadataManager(config.brokerId, config.interBrokerProtocolVersion, + offsetConfig, replicaManager, time, metrics) + new GroupCoordinator(config.brokerId, groupConfig, offsetConfig, groupMetadataManager, heartbeatPurgatory, + rebalancePurgatory, time, metrics) + } + + private def memberLeaveError(memberIdentity: MemberIdentity, + error: Errors): LeaveMemberResponse = { + LeaveMemberResponse( + memberId = memberIdentity.memberId, + groupInstanceId = Option(memberIdentity.groupInstanceId), + error = error) + } + + private def leaveError(topLevelError: Errors, + memberResponses: List[LeaveMemberResponse]): LeaveGroupResult = { + LeaveGroupResult( + topLevelError = topLevelError, + memberResponses = memberResponses) + } +} + +case class GroupConfig(groupMinSessionTimeoutMs: Int, + groupMaxSessionTimeoutMs: Int, + groupMaxSize: Int, + groupInitialRebalanceDelayMs: Int) + +case class JoinGroupResult(members: List[JoinGroupResponseMember], + memberId: String, + generationId: Int, + protocolType: Option[String], + protocolName: Option[String], + leaderId: String, + error: Errors) + +object JoinGroupResult { + def apply(memberId: String, error: Errors): JoinGroupResult = { + JoinGroupResult( + members = List.empty, + memberId = memberId, + generationId = GroupCoordinator.NoGeneration, + protocolType = None, + protocolName = None, + leaderId = GroupCoordinator.NoLeader, + error = error) + } +} + +case class SyncGroupResult(protocolType: Option[String], + protocolName: Option[String], + memberAssignment: Array[Byte], + error: Errors) + +object SyncGroupResult { + def apply(error: Errors): SyncGroupResult = { + SyncGroupResult(None, None, Array.empty, error) + } +} + +case class LeaveMemberResponse(memberId: String, + groupInstanceId: Option[String], + error: Errors) + +case class LeaveGroupResult(topLevelError: Errors, + memberResponses : List[LeaveMemberResponse]) diff --git a/core/src/main/scala/kafka/coordinator/group/GroupMetadata.scala b/core/src/main/scala/kafka/coordinator/group/GroupMetadata.scala new file mode 100644 index 0000000..5cb8a73 --- /dev/null +++ b/core/src/main/scala/kafka/coordinator/group/GroupMetadata.scala @@ -0,0 +1,832 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package kafka.coordinator.group + +import java.nio.ByteBuffer +import java.util.UUID +import java.util.concurrent.locks.ReentrantLock + +import kafka.common.OffsetAndMetadata +import kafka.utils.{CoreUtils, Logging, nonthreadsafe} +import kafka.utils.Implicits._ +import org.apache.kafka.clients.consumer.internals.ConsumerProtocol +import org.apache.kafka.common.TopicPartition +import org.apache.kafka.common.message.JoinGroupResponseData.JoinGroupResponseMember +import org.apache.kafka.common.protocol.Errors +import org.apache.kafka.common.protocol.types.SchemaException +import org.apache.kafka.common.utils.Time + +import scala.collection.{Seq, immutable, mutable} +import scala.jdk.CollectionConverters._ + +private[group] sealed trait GroupState { + val validPreviousStates: Set[GroupState] +} + +/** + * Group is preparing to rebalance + * + * action: respond to heartbeats with REBALANCE_IN_PROGRESS + * respond to sync group with REBALANCE_IN_PROGRESS + * remove member on leave group request + * park join group requests from new or existing members until all expected members have joined + * allow offset commits from previous generation + * allow offset fetch requests + * transition: some members have joined by the timeout => CompletingRebalance + * all members have left the group => Empty + * group is removed by partition emigration => Dead + */ +private[group] case object PreparingRebalance extends GroupState { + val validPreviousStates: Set[GroupState] = Set(Stable, CompletingRebalance, Empty) +} + +/** + * Group is awaiting state assignment from the leader + * + * action: respond to heartbeats with REBALANCE_IN_PROGRESS + * respond to offset commits with REBALANCE_IN_PROGRESS + * park sync group requests from followers until transition to Stable + * allow offset fetch requests + * transition: sync group with state assignment received from leader => Stable + * join group from new member or existing member with updated metadata => PreparingRebalance + * leave group from existing member => PreparingRebalance + * member failure detected => PreparingRebalance + * group is removed by partition emigration => Dead + */ +private[group] case object CompletingRebalance extends GroupState { + val validPreviousStates: Set[GroupState] = Set(PreparingRebalance) +} + +/** + * Group is stable + * + * action: respond to member heartbeats normally + * respond to sync group from any member with current assignment + * respond to join group from followers with matching metadata with current group metadata + * allow offset commits from member of current generation + * allow offset fetch requests + * transition: member failure detected via heartbeat => PreparingRebalance + * leave group from existing member => PreparingRebalance + * leader join-group received => PreparingRebalance + * follower join-group with new metadata => PreparingRebalance + * group is removed by partition emigration => Dead + */ +private[group] case object Stable extends GroupState { + val validPreviousStates: Set[GroupState] = Set(CompletingRebalance) +} + +/** + * Group has no more members and its metadata is being removed + * + * action: respond to join group with UNKNOWN_MEMBER_ID + * respond to sync group with UNKNOWN_MEMBER_ID + * respond to heartbeat with UNKNOWN_MEMBER_ID + * respond to leave group with UNKNOWN_MEMBER_ID + * respond to offset commit with UNKNOWN_MEMBER_ID + * allow offset fetch requests + * transition: Dead is a final state before group metadata is cleaned up, so there are no transitions + */ +private[group] 
case object Dead extends GroupState { + val validPreviousStates: Set[GroupState] = Set(Stable, PreparingRebalance, CompletingRebalance, Empty, Dead) +} + +/** + * Group has no more members, but lingers until all offsets have expired. This state + * also represents groups which use Kafka only for offset commits and have no members. + * + * action: respond normally to join group from new members + * respond to sync group with UNKNOWN_MEMBER_ID + * respond to heartbeat with UNKNOWN_MEMBER_ID + * respond to leave group with UNKNOWN_MEMBER_ID + * respond to offset commit with UNKNOWN_MEMBER_ID + * allow offset fetch requests + * transition: last offsets removed in periodic expiration task => Dead + * join group from a new member => PreparingRebalance + * group is removed by partition emigration => Dead + * group is removed by expiration => Dead + */ +private[group] case object Empty extends GroupState { + val validPreviousStates: Set[GroupState] = Set(PreparingRebalance) +} + + +private object GroupMetadata extends Logging { + + def loadGroup(groupId: String, + initialState: GroupState, + generationId: Int, + protocolType: String, + protocolName: String, + leaderId: String, + currentStateTimestamp: Option[Long], + members: Iterable[MemberMetadata], + time: Time): GroupMetadata = { + val group = new GroupMetadata(groupId, initialState, time) + group.generationId = generationId + group.protocolType = if (protocolType == null || protocolType.isEmpty) None else Some(protocolType) + group.protocolName = Option(protocolName) + group.leaderId = Option(leaderId) + group.currentStateTimestamp = currentStateTimestamp + members.foreach { member => + group.add(member, null) + info(s"Loaded member $member in group $groupId with generation ${group.generationId}.") + } + group.subscribedTopics = group.computeSubscribedTopics() + group + } + + private val MemberIdDelimiter = "-" +} + +/** + * Case class used to represent group metadata for the ListGroups API + */ +case class GroupOverview(groupId: String, + protocolType: String, + state: String) + +/** + * Case class used to represent group metadata for the DescribeGroup API + */ +case class GroupSummary(state: String, + protocolType: String, + protocol: String, + members: List[MemberSummary]) + +/** + * We cache offset commits along with their commit record offset. This enables us to ensure that the latest offset + * commit is always materialized when we have a mix of transactional and regular offset commits. Without preserving + * information of the commit record offset, compaction of the offsets topic itself may result in the wrong offset commit + * being materialized. + */ +case class CommitRecordMetadataAndOffset(appendedBatchOffset: Option[Long], offsetAndMetadata: OffsetAndMetadata) { + def olderThan(that: CommitRecordMetadataAndOffset): Boolean = appendedBatchOffset.get < that.appendedBatchOffset.get +} + +/** + * Group contains the following metadata: + * + * Membership metadata: + * 1. Members registered in this group + * 2. Current protocol assigned to the group (e.g. partition assignment strategy for consumers) + * 3. Protocol metadata associated with group members + * + * State metadata: + * 1. group state + * 2. generation id + * 3. 
leader id + */ +@nonthreadsafe +private[group] class GroupMetadata(val groupId: String, initialState: GroupState, time: Time) extends Logging { + type JoinCallback = JoinGroupResult => Unit + + private[group] val lock = new ReentrantLock + + private var state: GroupState = initialState + var currentStateTimestamp: Option[Long] = Some(time.milliseconds()) + var protocolType: Option[String] = None + var protocolName: Option[String] = None + var generationId = 0 + private var leaderId: Option[String] = None + + private val members = new mutable.HashMap[String, MemberMetadata] + // Static membership mapping [key: group.instance.id, value: member.id] + private val staticMembers = new mutable.HashMap[String, String] + private val pendingMembers = new mutable.HashSet[String] + private var numMembersAwaitingJoin = 0 + private val supportedProtocols = new mutable.HashMap[String, Integer]().withDefaultValue(0) + private val offsets = new mutable.HashMap[TopicPartition, CommitRecordMetadataAndOffset] + private val pendingOffsetCommits = new mutable.HashMap[TopicPartition, OffsetAndMetadata] + private val pendingTransactionalOffsetCommits = new mutable.HashMap[Long, mutable.Map[TopicPartition, CommitRecordMetadataAndOffset]]() + private var receivedTransactionalOffsetCommits = false + private var receivedConsumerOffsetCommits = false + private val pendingSyncMembers = new mutable.HashSet[String] + + // When protocolType == `consumer`, a set of subscribed topics is maintained. The set is + // computed when a new generation is created or when the group is restored from the log. + private var subscribedTopics: Option[Set[String]] = None + + var newMemberAdded: Boolean = false + + def inLock[T](fun: => T): T = CoreUtils.inLock(lock)(fun) + + def is(groupState: GroupState): Boolean = state == groupState + def has(memberId: String): Boolean = members.contains(memberId) + def get(memberId: String): MemberMetadata = members(memberId) + def size: Int = members.size + + def isLeader(memberId: String): Boolean = leaderId.contains(memberId) + def leaderOrNull: String = leaderId.orNull + def currentStateTimestampOrDefault: Long = currentStateTimestamp.getOrElse(-1) + + def isConsumerGroup: Boolean = protocolType.contains(ConsumerProtocol.PROTOCOL_TYPE) + + def add(member: MemberMetadata, callback: JoinCallback = null): Unit = { + member.groupInstanceId.foreach { instanceId => + if (staticMembers.contains(instanceId)) + throw new IllegalStateException(s"Static member with groupInstanceId=$instanceId " + + s"cannot be added to group $groupId since it is already a member") + staticMembers.put(instanceId, member.memberId) + } + + if (members.isEmpty) + this.protocolType = Some(member.protocolType) + + assert(this.protocolType.orNull == member.protocolType) + assert(supportsProtocols(member.protocolType, MemberMetadata.plainProtocolSet(member.supportedProtocols))) + + if (leaderId.isEmpty) + leaderId = Some(member.memberId) + + members.put(member.memberId, member) + incSupportedProtocols(member) + member.awaitingJoinCallback = callback + + if (member.isAwaitingJoin) + numMembersAwaitingJoin += 1 + + pendingMembers.remove(member.memberId) + } + + def remove(memberId: String): Unit = { + members.remove(memberId).foreach { member => + decSupportedProtocols(member) + if (member.isAwaitingJoin) + numMembersAwaitingJoin -= 1 + + member.groupInstanceId.foreach(staticMembers.remove) + } + + if (isLeader(memberId)) + leaderId = members.keys.headOption + + pendingMembers.remove(memberId) + pendingSyncMembers.remove(memberId) + } 
+ + /** + * Check whether the current leader has rejoined. If not, try to find another joined member to be + * the new leader. Return false if + * 1. the group is currently empty (has no designated leader) + * 2. no member has rejoined + */ + def maybeElectNewJoinedLeader(): Boolean = { + leaderId.exists { currentLeaderId => + val currentLeader = get(currentLeaderId) + if (!currentLeader.isAwaitingJoin) { + members.find(_._2.isAwaitingJoin) match { + case Some((anyJoinedMemberId, anyJoinedMember)) => + leaderId = Option(anyJoinedMemberId) + info(s"Group leader [member.id: ${currentLeader.memberId}, " + + s"group.instance.id: ${currentLeader.groupInstanceId}] failed to join " + + s"before rebalance timeout, while new leader $anyJoinedMember was elected.") + true + + case None => + info(s"Group leader [member.id: ${currentLeader.memberId}, " + + s"group.instance.id: ${currentLeader.groupInstanceId}] failed to join " + + s"before rebalance timeout, and the group couldn't proceed to the next generation " + + s"because no member joined.") + false + } + } else { + true + } + } + } + + /** + * [For static members only]: Replace the old member id with the new one, + * keep everything else unchanged and return the updated member. + */ + def replaceStaticMember( + groupInstanceId: String, + oldMemberId: String, + newMemberId: String + ): MemberMetadata = { + val memberMetadata = members.remove(oldMemberId) + .getOrElse(throw new IllegalArgumentException(s"Cannot replace non-existing member id $oldMemberId")) + + // Fence potential duplicate member immediately if someone awaits join/sync callback. + maybeInvokeJoinCallback(memberMetadata, JoinGroupResult(oldMemberId, Errors.FENCED_INSTANCE_ID)) + maybeInvokeSyncCallback(memberMetadata, SyncGroupResult(Errors.FENCED_INSTANCE_ID)) + + memberMetadata.memberId = newMemberId + members.put(newMemberId, memberMetadata) + + if (isLeader(oldMemberId)) { + leaderId = Some(newMemberId) + } + + staticMembers.put(groupInstanceId, newMemberId) + memberMetadata + } + + def isPendingMember(memberId: String): Boolean = pendingMembers.contains(memberId) + + def addPendingMember(memberId: String): Boolean = { + if (has(memberId)) { + throw new IllegalStateException(s"Attempt to add pending member $memberId which is already " + + s"a stable member of the group") + } + pendingMembers.add(memberId) + } + + def addPendingSyncMember(memberId: String): Boolean = { + if (!has(memberId)) { + throw new IllegalStateException(s"Attempt to add a pending sync for member $memberId which " + + "is not a member of the group") + } + pendingSyncMembers.add(memberId) + } + + def removePendingSyncMember(memberId: String): Boolean = { + if (!has(memberId)) { + throw new IllegalStateException(s"Attempt to remove a pending sync for member $memberId which " + + "is not a member of the group") + } + pendingSyncMembers.remove(memberId) + } + + def hasReceivedSyncFromAllMembers: Boolean = { + pendingSyncMembers.isEmpty + } + + def allPendingSyncMembers: Set[String] = { + pendingSyncMembers.toSet + } + + def clearPendingSyncMembers(): Unit = { + pendingSyncMembers.clear() + } + + def hasStaticMember(groupInstanceId: String): Boolean = { + staticMembers.contains(groupInstanceId) + } + + def currentStaticMemberId(groupInstanceId: String): Option[String] = { + staticMembers.get(groupInstanceId) + } + + def currentState: GroupState = state + + def notYetRejoinedMembers: Map[String, MemberMetadata] = members.filter(!_._2.isAwaitingJoin).toMap + + def hasAllMembersJoined: Boolean = members.size == numMembersAwaitingJoin 
&& pendingMembers.isEmpty + + def allMembers: collection.Set[String] = members.keySet + + def allStaticMembers: collection.Set[String] = staticMembers.keySet + + // For testing only. + private[group] def allDynamicMembers: Set[String] = { + val dynamicMemberSet = new mutable.HashSet[String] + allMembers.foreach(memberId => dynamicMemberSet.add(memberId)) + staticMembers.values.foreach(memberId => dynamicMemberSet.remove(memberId)) + dynamicMemberSet.toSet + } + + def numPending: Int = pendingMembers.size + + def numAwaiting: Int = numMembersAwaitingJoin + + def allMemberMetadata: List[MemberMetadata] = members.values.toList + + def rebalanceTimeoutMs: Int = members.values.foldLeft(0) { (timeout, member) => + timeout.max(member.rebalanceTimeoutMs) + } + + def generateMemberId(clientId: String, + groupInstanceId: Option[String]): String = { + groupInstanceId match { + case None => + clientId + GroupMetadata.MemberIdDelimiter + UUID.randomUUID().toString + case Some(instanceId) => + instanceId + GroupMetadata.MemberIdDelimiter + UUID.randomUUID().toString + } + } + + /** + * Verify the member.id is up to date for static members. Return true if both conditions met: + * 1. given member is a known static member to group + * 2. group stored member.id doesn't match with given member.id + */ + def isStaticMemberFenced( + groupInstanceId: String, + memberId: String + ): Boolean = { + currentStaticMemberId(groupInstanceId).exists(_ != memberId) + } + + def canRebalance: Boolean = PreparingRebalance.validPreviousStates.contains(state) + + def transitionTo(groupState: GroupState): Unit = { + assertValidTransition(groupState) + state = groupState + currentStateTimestamp = Some(time.milliseconds()) + } + + def selectProtocol: String = { + if (members.isEmpty) + throw new IllegalStateException("Cannot select protocol for empty group") + + // select the protocol for this group which is supported by all members + val candidates = candidateProtocols + + // let each member vote for one of the protocols and choose the one with the most votes + val (protocol, _) = allMemberMetadata + .map(_.vote(candidates)) + .groupBy(identity) + .maxBy { case (_, votes) => votes.size } + + protocol + } + + private def incSupportedProtocols(member: MemberMetadata): Unit = { + member.supportedProtocols.foreach { case (protocol, _) => supportedProtocols(protocol) += 1 } + } + + private def decSupportedProtocols(member: MemberMetadata): Unit = { + member.supportedProtocols.foreach { case (protocol, _) => supportedProtocols(protocol) -= 1 } + } + + private def candidateProtocols: Set[String] = { + // get the set of protocols that are commonly supported by all members + val numMembers = members.size + supportedProtocols.filter(_._2 == numMembers).keys.toSet + } + + def supportsProtocols(memberProtocolType: String, memberProtocols: Set[String]): Boolean = { + if (is(Empty)) + memberProtocolType.nonEmpty && memberProtocols.nonEmpty + else + protocolType.contains(memberProtocolType) && memberProtocols.exists(supportedProtocols(_) == members.size) + } + + def getSubscribedTopics: Option[Set[String]] = subscribedTopics + + /** + * Returns true if the consumer group is actively subscribed to the topic. When the consumer + * group does not know, because the information is not available yet or because the it has + * failed to parse the Consumer Protocol, it returns true to be safe. 
+ */ + def isSubscribedToTopic(topic: String): Boolean = subscribedTopics match { + case Some(topics) => topics.contains(topic) + case None => true + } + + /** + * Collects the set of topics that the members are subscribed to when the Protocol Type is equal + * to 'consumer'. None is returned if + * - the protocol type is not equal to 'consumer'; + * - the protocol is not defined yet; or + * - the protocol metadata does not comply with the schema. + */ + private[group] def computeSubscribedTopics(): Option[Set[String]] = { + protocolType match { + case Some(ConsumerProtocol.PROTOCOL_TYPE) if members.nonEmpty && protocolName.isDefined => + try { + Some( + members.map { case (_, member) => + // The consumer protocol is parsed with V0, which is the base prefix of all versions. + // This way the consumer group manager does not depend on any specific existing or + // future versions of the consumer protocol. V0 must prefix all new versions. + val buffer = ByteBuffer.wrap(member.metadata(protocolName.get)) + ConsumerProtocol.deserializeVersion(buffer) + ConsumerProtocol.deserializeSubscription(buffer, 0).topics.asScala.toSet + }.reduceLeft(_ ++ _) + ) + } catch { + case e: SchemaException => + warn(s"Failed to parse Consumer Protocol ${ConsumerProtocol.PROTOCOL_TYPE}:${protocolName.get} " + + s"of group $groupId. Consumer group coordinator is not aware of the subscribed topics.", e) + None + } + + case Some(ConsumerProtocol.PROTOCOL_TYPE) if members.isEmpty => + Option(Set.empty) + + case _ => None + } + } + + def updateMember(member: MemberMetadata, + protocols: List[(String, Array[Byte])], + callback: JoinCallback): Unit = { + decSupportedProtocols(member) + member.supportedProtocols = protocols + incSupportedProtocols(member) + + if (callback != null && !member.isAwaitingJoin) { + numMembersAwaitingJoin += 1 + } else if (callback == null && member.isAwaitingJoin) { + numMembersAwaitingJoin -= 1 + } + member.awaitingJoinCallback = callback + } + + def maybeInvokeJoinCallback(member: MemberMetadata, + joinGroupResult: JoinGroupResult): Unit = { + if (member.isAwaitingJoin) { + member.awaitingJoinCallback(joinGroupResult) + member.awaitingJoinCallback = null + numMembersAwaitingJoin -= 1 + } + } + + /** + * @return true if a sync callback was actually invoked. 
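+ * The callback reference is cleared after it is invoked, so each sync callback fires at most once.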
+ */ + def maybeInvokeSyncCallback(member: MemberMetadata, + syncGroupResult: SyncGroupResult): Boolean = { + if (member.isAwaitingSync) { + member.awaitingSyncCallback(syncGroupResult) + member.awaitingSyncCallback = null + true + } else { + false + } + } + + def initNextGeneration(): Unit = { + if (members.nonEmpty) { + generationId += 1 + protocolName = Some(selectProtocol) + subscribedTopics = computeSubscribedTopics() + transitionTo(CompletingRebalance) + } else { + generationId += 1 + protocolName = None + subscribedTopics = computeSubscribedTopics() + transitionTo(Empty) + } + receivedConsumerOffsetCommits = false + receivedTransactionalOffsetCommits = false + clearPendingSyncMembers() + } + + def currentMemberMetadata: List[JoinGroupResponseMember] = { + if (is(Dead) || is(PreparingRebalance)) + throw new IllegalStateException("Cannot obtain member metadata for group in state %s".format(state)) + members.map{ case (memberId, memberMetadata) => new JoinGroupResponseMember() + .setMemberId(memberId) + .setGroupInstanceId(memberMetadata.groupInstanceId.orNull) + .setMetadata(memberMetadata.metadata(protocolName.get)) + }.toList + } + + def summary: GroupSummary = { + if (is(Stable)) { + val protocol = protocolName.orNull + if (protocol == null) + throw new IllegalStateException("Invalid null group protocol for stable group") + + val members = this.members.values.map { member => member.summary(protocol) } + GroupSummary(state.toString, protocolType.getOrElse(""), protocol, members.toList) + } else { + val members = this.members.values.map{ member => member.summaryNoMetadata() } + GroupSummary(state.toString, protocolType.getOrElse(""), GroupCoordinator.NoProtocol, members.toList) + } + } + + def overview: GroupOverview = { + GroupOverview(groupId, protocolType.getOrElse(""), state.toString) + } + + def initializeOffsets(offsets: collection.Map[TopicPartition, CommitRecordMetadataAndOffset], + pendingTxnOffsets: Map[Long, mutable.Map[TopicPartition, CommitRecordMetadataAndOffset]]): Unit = { + this.offsets ++= offsets + this.pendingTransactionalOffsetCommits ++= pendingTxnOffsets + } + + def onOffsetCommitAppend(topicPartition: TopicPartition, offsetWithCommitRecordMetadata: CommitRecordMetadataAndOffset): Unit = { + if (pendingOffsetCommits.contains(topicPartition)) { + if (offsetWithCommitRecordMetadata.appendedBatchOffset.isEmpty) + throw new IllegalStateException("Cannot complete offset commit write without providing the metadata of the record " + + "in the log.") + if (!offsets.contains(topicPartition) || offsets(topicPartition).olderThan(offsetWithCommitRecordMetadata)) + offsets.put(topicPartition, offsetWithCommitRecordMetadata) + } + + pendingOffsetCommits.get(topicPartition) match { + case Some(stagedOffset) if offsetWithCommitRecordMetadata.offsetAndMetadata == stagedOffset => + pendingOffsetCommits.remove(topicPartition) + case _ => + // The pendingOffsetCommits for this partition could be empty if the topic was deleted, in which case + // its entries would be removed from the cache by the `removeOffsets` method. 
+ } + } + + def failPendingOffsetWrite(topicPartition: TopicPartition, offset: OffsetAndMetadata): Unit = { + pendingOffsetCommits.get(topicPartition) match { + case Some(pendingOffset) if offset == pendingOffset => pendingOffsetCommits.remove(topicPartition) + case _ => + } + } + + def prepareOffsetCommit(offsets: Map[TopicPartition, OffsetAndMetadata]): Unit = { + receivedConsumerOffsetCommits = true + pendingOffsetCommits ++= offsets + } + + def prepareTxnOffsetCommit(producerId: Long, offsets: Map[TopicPartition, OffsetAndMetadata]): Unit = { + trace(s"TxnOffsetCommit for producer $producerId and group $groupId with offsets $offsets is pending") + receivedTransactionalOffsetCommits = true + val producerOffsets = pendingTransactionalOffsetCommits.getOrElseUpdate(producerId, + mutable.Map.empty[TopicPartition, CommitRecordMetadataAndOffset]) + + offsets.forKeyValue { (topicPartition, offsetAndMetadata) => + producerOffsets.put(topicPartition, CommitRecordMetadataAndOffset(None, offsetAndMetadata)) + } + } + + def hasReceivedConsistentOffsetCommits : Boolean = { + !receivedConsumerOffsetCommits || !receivedTransactionalOffsetCommits + } + + /* Remove a pending transactional offset commit if the actual offset commit record was not written to the log. + * We will return an error and the client will retry the request, potentially to a different coordinator. + */ + def failPendingTxnOffsetCommit(producerId: Long, topicPartition: TopicPartition): Unit = { + pendingTransactionalOffsetCommits.get(producerId) match { + case Some(pendingOffsets) => + val pendingOffsetCommit = pendingOffsets.remove(topicPartition) + trace(s"TxnOffsetCommit for producer $producerId and group $groupId with offsets $pendingOffsetCommit failed " + + s"to be appended to the log") + if (pendingOffsets.isEmpty) + pendingTransactionalOffsetCommits.remove(producerId) + case _ => + // We may hit this case if the partition in question has emigrated already. + } + } + + def onTxnOffsetCommitAppend(producerId: Long, topicPartition: TopicPartition, + commitRecordMetadataAndOffset: CommitRecordMetadataAndOffset): Unit = { + pendingTransactionalOffsetCommits.get(producerId) match { + case Some(pendingOffset) => + if (pendingOffset.contains(topicPartition) + && pendingOffset(topicPartition).offsetAndMetadata == commitRecordMetadataAndOffset.offsetAndMetadata) + pendingOffset.update(topicPartition, commitRecordMetadataAndOffset) + case _ => + // We may hit this case if the partition in question has emigrated. + } + } + + /* Complete a pending transactional offset commit. This is called after a commit or abort marker is fully written + * to the log. 
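+ * On commit, each pending offset is moved into the main offsets cache unless the cache already holds a newer commit for that partition; on abort, the pending offsets are discarded.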
+ */ + def completePendingTxnOffsetCommit(producerId: Long, isCommit: Boolean): Unit = { + val pendingOffsetsOpt = pendingTransactionalOffsetCommits.remove(producerId) + if (isCommit) { + pendingOffsetsOpt.foreach { pendingOffsets => + pendingOffsets.forKeyValue { (topicPartition, commitRecordMetadataAndOffset) => + if (commitRecordMetadataAndOffset.appendedBatchOffset.isEmpty) + throw new IllegalStateException(s"Trying to complete a transactional offset commit for producerId $producerId " + + s"and groupId $groupId even though the offset commit record itself hasn't been appended to the log.") + + val currentOffsetOpt = offsets.get(topicPartition) + if (currentOffsetOpt.forall(_.olderThan(commitRecordMetadataAndOffset))) { + trace(s"TxnOffsetCommit for producer $producerId and group $groupId with offset $commitRecordMetadataAndOffset " + + "committed and loaded into the cache.") + offsets.put(topicPartition, commitRecordMetadataAndOffset) + } else { + trace(s"TxnOffsetCommit for producer $producerId and group $groupId with offset $commitRecordMetadataAndOffset " + + s"committed, but not loaded since its offset is older than current offset $currentOffsetOpt.") + } + } + } + } else { + trace(s"TxnOffsetCommit for producer $producerId and group $groupId with offsets $pendingOffsetsOpt aborted") + } + } + + def activeProducers: collection.Set[Long] = pendingTransactionalOffsetCommits.keySet + + def hasPendingOffsetCommitsFromProducer(producerId: Long): Boolean = + pendingTransactionalOffsetCommits.contains(producerId) + + def hasPendingOffsetCommitsForTopicPartition(topicPartition: TopicPartition): Boolean = { + pendingOffsetCommits.contains(topicPartition) || + pendingTransactionalOffsetCommits.exists( + _._2.contains(topicPartition) + ) + } + + def removeAllOffsets(): immutable.Map[TopicPartition, OffsetAndMetadata] = removeOffsets(offsets.keySet.toSeq) + + def removeOffsets(topicPartitions: Seq[TopicPartition]): immutable.Map[TopicPartition, OffsetAndMetadata] = { + topicPartitions.flatMap { topicPartition => + pendingOffsetCommits.remove(topicPartition) + pendingTransactionalOffsetCommits.forKeyValue { (_, pendingOffsets) => + pendingOffsets.remove(topicPartition) + } + val removedOffset = offsets.remove(topicPartition) + removedOffset.map(topicPartition -> _.offsetAndMetadata) + }.toMap + } + + def removeExpiredOffsets(currentTimestamp: Long, offsetRetentionMs: Long): Map[TopicPartition, OffsetAndMetadata] = { + + def getExpiredOffsets(baseTimestamp: CommitRecordMetadataAndOffset => Long, + subscribedTopics: Set[String] = Set.empty): Map[TopicPartition, OffsetAndMetadata] = { + offsets.filter { + case (topicPartition, commitRecordMetadataAndOffset) => + !subscribedTopics.contains(topicPartition.topic()) && + !pendingOffsetCommits.contains(topicPartition) && { + commitRecordMetadataAndOffset.offsetAndMetadata.expireTimestamp match { + case None => + // current version with no per partition retention + currentTimestamp - baseTimestamp(commitRecordMetadataAndOffset) >= offsetRetentionMs + case Some(expireTimestamp) => + // older versions with explicit expire_timestamp field => old expiration semantics is used + currentTimestamp >= expireTimestamp + } + } + }.map { + case (topicPartition, commitRecordOffsetAndMetadata) => + (topicPartition, commitRecordOffsetAndMetadata.offsetAndMetadata) + }.toMap + } + + val expiredOffsets: Map[TopicPartition, OffsetAndMetadata] = protocolType match { + case Some(_) if is(Empty) => + // no consumer exists in the group => + // - if current state timestamp 
exists and retention period has passed since the group became Empty, + // expire all offsets with no pending offset commit; + // - if there is no current state timestamp (old group metadata schema) and the retention period has passed + // since the last commit timestamp, expire the offset + getExpiredOffsets( + commitRecordMetadataAndOffset => currentStateTimestamp + .getOrElse(commitRecordMetadataAndOffset.offsetAndMetadata.commitTimestamp) + ) + + case Some(ConsumerProtocol.PROTOCOL_TYPE) if subscribedTopics.isDefined => + // consumers exist in the group => + // - if the group is aware of the subscribed topics and the retention period has passed since the + // last commit timestamp, expire the offset. Offsets with pending offset commits are not + // expired + getExpiredOffsets( + _.offsetAndMetadata.commitTimestamp, + subscribedTopics.get + ) + + case None => + // protocolType is None => standalone (simple) consumer that uses Kafka for offset storage only + // expire offsets with no pending offset commit for which the retention period has passed since their last commit + getExpiredOffsets(_.offsetAndMetadata.commitTimestamp) + + case _ => + Map() + } + + if (expiredOffsets.nonEmpty) + debug(s"Expired offsets from group '$groupId': ${expiredOffsets.keySet}") + + offsets --= expiredOffsets.keySet + expiredOffsets + } + + def allOffsets: Map[TopicPartition, OffsetAndMetadata] = offsets.map { case (topicPartition, commitRecordMetadataAndOffset) => + (topicPartition, commitRecordMetadataAndOffset.offsetAndMetadata) + }.toMap + + def offset(topicPartition: TopicPartition): Option[OffsetAndMetadata] = offsets.get(topicPartition).map(_.offsetAndMetadata) + + // visible for testing + private[group] def offsetWithRecordMetadata(topicPartition: TopicPartition): Option[CommitRecordMetadataAndOffset] = offsets.get(topicPartition) + + def numOffsets: Int = offsets.size + + def hasOffsets: Boolean = offsets.nonEmpty || pendingOffsetCommits.nonEmpty || pendingTransactionalOffsetCommits.nonEmpty + + private def assertValidTransition(targetState: GroupState): Unit = { + if (!targetState.validPreviousStates.contains(state)) + throw new IllegalStateException("Group %s should be in the %s states before moving to %s state. Instead it is in %s state" + .format(groupId, targetState.validPreviousStates.mkString(","), targetState, state)) + } + + override def toString: String = { + "GroupMetadata(" + + s"groupId=$groupId, " + + s"generation=$generationId, " + + s"protocolType=$protocolType, " + + s"currentState=$currentState, " + + s"members=$members)" + } + +} + diff --git a/core/src/main/scala/kafka/coordinator/group/GroupMetadataManager.scala b/core/src/main/scala/kafka/coordinator/group/GroupMetadataManager.scala new file mode 100644 index 0000000..ac3fc39 --- /dev/null +++ b/core/src/main/scala/kafka/coordinator/group/GroupMetadataManager.scala @@ -0,0 +1,1366 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.coordinator.group + +import java.io.PrintStream +import java.nio.ByteBuffer +import java.nio.charset.StandardCharsets +import java.util.Optional +import java.util.concurrent.atomic.AtomicBoolean +import java.util.concurrent.locks.ReentrantLock +import java.util.concurrent.{ConcurrentHashMap, TimeUnit} + +import com.yammer.metrics.core.Gauge +import kafka.api.{ApiVersion, KAFKA_0_10_1_IV0, KAFKA_2_1_IV0, KAFKA_2_1_IV1, KAFKA_2_3_IV0} +import kafka.common.OffsetAndMetadata +import kafka.internals.generated.{GroupMetadataValue, OffsetCommitKey, OffsetCommitValue, GroupMetadataKey => GroupMetadataKeyData} +import kafka.log.AppendOrigin +import kafka.metrics.KafkaMetricsGroup +import kafka.server.{FetchLogEnd, ReplicaManager, RequestLocal} +import kafka.utils.CoreUtils.inLock +import kafka.utils.Implicits._ +import kafka.utils._ +import org.apache.kafka.clients.consumer.ConsumerRecord +import org.apache.kafka.clients.consumer.internals.ConsumerProtocol +import org.apache.kafka.common.internals.Topic +import org.apache.kafka.common.metrics.{Metrics, Sensor} +import org.apache.kafka.common.metrics.stats.{Avg, Max, Meter} +import org.apache.kafka.common.protocol.{ByteBufferAccessor, Errors, MessageUtil} +import org.apache.kafka.common.record._ +import org.apache.kafka.common.requests.OffsetFetchResponse.PartitionData +import org.apache.kafka.common.requests.ProduceResponse.PartitionResponse +import org.apache.kafka.common.requests.{OffsetCommitRequest, OffsetFetchResponse} +import org.apache.kafka.common.utils.{Time, Utils} +import org.apache.kafka.common.{KafkaException, MessageFormatter, TopicPartition} + +import scala.collection._ +import scala.collection.mutable.ArrayBuffer +import scala.jdk.CollectionConverters._ + +class GroupMetadataManager(brokerId: Int, + interBrokerProtocolVersion: ApiVersion, + config: OffsetConfig, + val replicaManager: ReplicaManager, + time: Time, + metrics: Metrics) extends Logging with KafkaMetricsGroup { + + private val compressionType: CompressionType = CompressionType.forId(config.offsetsTopicCompressionCodec.codec) + + private val groupMetadataCache = new Pool[String, GroupMetadata] + + /* lock protecting access to loading and owned partition sets */ + private val partitionLock = new ReentrantLock() + + /* partitions of consumer groups that are being loaded, its lock should be always called BEFORE the group lock if needed */ + private val loadingPartitions: mutable.Set[Int] = mutable.Set() + + /* partitions of consumer groups that are assigned, using the same loading partition lock */ + private val ownedPartitions: mutable.Set[Int] = mutable.Set() + + /* shutting down flag */ + private val shuttingDown = new AtomicBoolean(false) + + /* number of partitions for the consumer metadata topic */ + @volatile private var groupMetadataTopicPartitionCount: Int = _ + + /* single-thread scheduler to handle offset/group metadata cache loading and unloading */ + private val scheduler = new KafkaScheduler(threads = 1, threadNamePrefix = "group-metadata-manager-") + + /* The groups with open transactional offsets commits per producer. 
We need this because when the commit or abort + * marker comes in for a transaction, it is for a particular partition on the offsets topic and a particular producerId. + * We use this structure to quickly find the groups which need to be updated by the commit/abort marker. */ + private val openGroupsForProducer = mutable.HashMap[Long, mutable.Set[String]]() + + /* Track the epoch in which we (un)loaded group state to detect racing LeaderAndIsr requests */ + private [group] val epochForPartitionId = new ConcurrentHashMap[Int, java.lang.Integer]() + + /* setup metrics*/ + private val partitionLoadSensor = metrics.sensor(GroupMetadataManager.LoadTimeSensor) + + partitionLoadSensor.add(metrics.metricName("partition-load-time-max", + GroupMetadataManager.MetricsGroup, + "The max time it took to load the partitions in the last 30sec"), new Max()) + partitionLoadSensor.add(metrics.metricName("partition-load-time-avg", + GroupMetadataManager.MetricsGroup, + "The avg time it took to load the partitions in the last 30sec"), new Avg()) + + val offsetCommitsSensor: Sensor = metrics.sensor(GroupMetadataManager.OffsetCommitsSensor) + + offsetCommitsSensor.add(new Meter( + metrics.metricName("offset-commit-rate", + "group-coordinator-metrics", + "The rate of committed offsets"), + metrics.metricName("offset-commit-count", + "group-coordinator-metrics", + "The total number of committed offsets"))) + + val offsetExpiredSensor: Sensor = metrics.sensor(GroupMetadataManager.OffsetExpiredSensor) + + offsetExpiredSensor.add(new Meter( + metrics.metricName("offset-expiration-rate", + "group-coordinator-metrics", + "The rate of expired offsets"), + metrics.metricName("offset-expiration-count", + "group-coordinator-metrics", + "The total number of expired offsets"))) + + this.logIdent = s"[GroupMetadataManager brokerId=$brokerId] " + + private def recreateGauge[T](name: String, gauge: Gauge[T]): Gauge[T] = { + removeMetric(name) + newGauge(name, gauge) + } + + recreateGauge("NumOffsets", + () => groupMetadataCache.values.map { group => + group.inLock { group.numOffsets } + }.sum + ) + + recreateGauge("NumGroups", + () => groupMetadataCache.size + ) + + recreateGauge("NumGroupsPreparingRebalance", + () => groupMetadataCache.values.count { group => + group synchronized { + group.is(PreparingRebalance) + } + }) + + recreateGauge("NumGroupsCompletingRebalance", + () => groupMetadataCache.values.count { group => + group synchronized { + group.is(CompletingRebalance) + } + }) + + recreateGauge("NumGroupsStable", + () => groupMetadataCache.values.count { group => + group synchronized { + group.is(Stable) + } + }) + + recreateGauge("NumGroupsDead", + () => groupMetadataCache.values.count { group => + group synchronized { + group.is(Dead) + } + }) + + recreateGauge("NumGroupsEmpty", + () => groupMetadataCache.values.count { group => + group synchronized { + group.is(Empty) + } + }) + + def startup(retrieveGroupMetadataTopicPartitionCount: () => Int, enableMetadataExpiration: Boolean): Unit = { + groupMetadataTopicPartitionCount = retrieveGroupMetadataTopicPartitionCount() + scheduler.startup() + if (enableMetadataExpiration) { + scheduler.schedule(name = "delete-expired-group-metadata", + fun = () => cleanupGroupMetadata(), + period = config.offsetsRetentionCheckIntervalMs, + unit = TimeUnit.MILLISECONDS) + } + } + + def currentGroups: Iterable[GroupMetadata] = groupMetadataCache.values + + def isPartitionOwned(partition: Int): Boolean = inLock(partitionLock) { ownedPartitions.contains(partition) } + + def 
isPartitionLoading(partition: Int): Boolean = inLock(partitionLock) { loadingPartitions.contains(partition) } + + def partitionFor(groupId: String): Int = Utils.abs(groupId.hashCode) % groupMetadataTopicPartitionCount + + def isGroupLocal(groupId: String): Boolean = isPartitionOwned(partitionFor(groupId)) + + def isGroupLoading(groupId: String): Boolean = isPartitionLoading(partitionFor(groupId)) + + def isLoading: Boolean = inLock(partitionLock) { loadingPartitions.nonEmpty } + + // return true iff group is owned and the group doesn't exist + def groupNotExists(groupId: String): Boolean = inLock(partitionLock) { + isGroupLocal(groupId) && getGroup(groupId).forall { group => + group.inLock(group.is(Dead)) + } + } + + // visible for testing + private[group] def isGroupOpenForProducer(producerId: Long, groupId: String) = openGroupsForProducer.get(producerId) match { + case Some(groups) => + groups.contains(groupId) + case None => + false + } + + /** + * Get the group associated with the given groupId or null if not found + */ + def getGroup(groupId: String): Option[GroupMetadata] = { + Option(groupMetadataCache.get(groupId)) + } + + /** + * Get the group associated with the given groupId - the group is created if createIfNotExist + * is true - or null if not found + */ + def getOrMaybeCreateGroup(groupId: String, createIfNotExist: Boolean): Option[GroupMetadata] = { + if (createIfNotExist) + Option(groupMetadataCache.getAndMaybePut(groupId, new GroupMetadata(groupId, Empty, time))) + else + Option(groupMetadataCache.get(groupId)) + } + + /** + * Add a group or get the group associated with the given groupId if it already exists + */ + def addGroup(group: GroupMetadata): GroupMetadata = { + val currentGroup = groupMetadataCache.putIfNotExists(group.groupId, group) + if (currentGroup != null) { + currentGroup + } else { + group + } + } + + def storeGroup(group: GroupMetadata, + groupAssignment: Map[String, Array[Byte]], + responseCallback: Errors => Unit, + requestLocal: RequestLocal = RequestLocal.NoCaching): Unit = { + getMagic(partitionFor(group.groupId)) match { + case Some(magicValue) => + // We always use CREATE_TIME, like the producer. The conversion to LOG_APPEND_TIME (if necessary) happens automatically. 
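+        // The record is keyed by the group id, so log compaction of the offsets topic retains only
+        // the most recent metadata record written for each group.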
+ val timestampType = TimestampType.CREATE_TIME + val timestamp = time.milliseconds() + val key = GroupMetadataManager.groupMetadataKey(group.groupId) + val value = GroupMetadataManager.groupMetadataValue(group, groupAssignment, interBrokerProtocolVersion) + + val records = { + val buffer = ByteBuffer.allocate(AbstractRecords.estimateSizeInBytes(magicValue, compressionType, + Seq(new SimpleRecord(timestamp, key, value)).asJava)) + val builder = MemoryRecords.builder(buffer, magicValue, compressionType, timestampType, 0L) + builder.append(timestamp, key, value) + builder.build() + } + + val groupMetadataPartition = new TopicPartition(Topic.GROUP_METADATA_TOPIC_NAME, partitionFor(group.groupId)) + val groupMetadataRecords = Map(groupMetadataPartition -> records) + val generationId = group.generationId + + // set the callback function to insert the created group into cache after log append completed + def putCacheCallback(responseStatus: Map[TopicPartition, PartitionResponse]): Unit = { + // the append response should only contain the topics partition + if (responseStatus.size != 1 || !responseStatus.contains(groupMetadataPartition)) + throw new IllegalStateException("Append status %s should only have one partition %s" + .format(responseStatus, groupMetadataPartition)) + + // construct the error status in the propagated assignment response in the cache + val status = responseStatus(groupMetadataPartition) + + val responseError = if (status.error == Errors.NONE) { + Errors.NONE + } else { + debug(s"Metadata from group ${group.groupId} with generation $generationId failed when appending to log " + + s"due to ${status.error.exceptionName}") + + // transform the log append error code to the corresponding the commit status error code + status.error match { + case Errors.UNKNOWN_TOPIC_OR_PARTITION + | Errors.NOT_ENOUGH_REPLICAS + | Errors.NOT_ENOUGH_REPLICAS_AFTER_APPEND => + Errors.COORDINATOR_NOT_AVAILABLE + + case Errors.NOT_LEADER_OR_FOLLOWER + | Errors.KAFKA_STORAGE_ERROR => + Errors.NOT_COORDINATOR + + case Errors.REQUEST_TIMED_OUT => + Errors.REBALANCE_IN_PROGRESS + + case Errors.MESSAGE_TOO_LARGE + | Errors.RECORD_LIST_TOO_LARGE + | Errors.INVALID_FETCH_SIZE => + + error(s"Appending metadata message for group ${group.groupId} generation $generationId failed due to " + + s"${status.error.exceptionName}, returning UNKNOWN error code to the client") + + Errors.UNKNOWN_SERVER_ERROR + + case other => + error(s"Appending metadata message for group ${group.groupId} generation $generationId failed " + + s"due to unexpected error: ${status.error.exceptionName}") + + other + } + } + + responseCallback(responseError) + } + appendForGroup(group, groupMetadataRecords, requestLocal, putCacheCallback) + + case None => + responseCallback(Errors.NOT_COORDINATOR) + None + } + } + + private def appendForGroup(group: GroupMetadata, + records: Map[TopicPartition, MemoryRecords], + requestLocal: RequestLocal, + callback: Map[TopicPartition, PartitionResponse] => Unit): Unit = { + // call replica manager to append the group message + replicaManager.appendRecords( + timeout = config.offsetCommitTimeoutMs.toLong, + requiredAcks = config.offsetCommitRequiredAcks, + internalTopicsAllowed = true, + origin = AppendOrigin.Coordinator, + entriesPerPartition = records, + delayedProduceLock = Some(group.lock), + responseCallback = callback, + requestLocal = requestLocal) + } + + /** + * Store offsets by appending it to the replicated log and then inserting to cache + */ + def storeOffsets(group: GroupMetadata, + 
consumerId: String, + offsetMetadata: immutable.Map[TopicPartition, OffsetAndMetadata], + responseCallback: immutable.Map[TopicPartition, Errors] => Unit, + producerId: Long = RecordBatch.NO_PRODUCER_ID, + producerEpoch: Short = RecordBatch.NO_PRODUCER_EPOCH, + requestLocal: RequestLocal = RequestLocal.NoCaching): Unit = { + // first filter out partitions with offset metadata size exceeding limit + val filteredOffsetMetadata = offsetMetadata.filter { case (_, offsetAndMetadata) => + validateOffsetMetadataLength(offsetAndMetadata.metadata) + } + + group.inLock { + if (!group.hasReceivedConsistentOffsetCommits) + warn(s"group: ${group.groupId} with leader: ${group.leaderOrNull} has received offset commits from consumers as well " + + s"as transactional producers. Mixing both types of offset commits will generally result in surprises and " + + s"should be avoided.") + } + + val isTxnOffsetCommit = producerId != RecordBatch.NO_PRODUCER_ID + // construct the message set to append + if (filteredOffsetMetadata.isEmpty) { + // compute the final error codes for the commit response + val commitStatus = offsetMetadata.map { case (k, _) => k -> Errors.OFFSET_METADATA_TOO_LARGE } + responseCallback(commitStatus) + } else { + getMagic(partitionFor(group.groupId)) match { + case Some(magicValue) => + // We always use CREATE_TIME, like the producer. The conversion to LOG_APPEND_TIME (if necessary) happens automatically. + val timestampType = TimestampType.CREATE_TIME + val timestamp = time.milliseconds() + + val records = filteredOffsetMetadata.map { case (topicPartition, offsetAndMetadata) => + val key = GroupMetadataManager.offsetCommitKey(group.groupId, topicPartition) + val value = GroupMetadataManager.offsetCommitValue(offsetAndMetadata, interBrokerProtocolVersion) + new SimpleRecord(timestamp, key, value) + } + val offsetTopicPartition = new TopicPartition(Topic.GROUP_METADATA_TOPIC_NAME, partitionFor(group.groupId)) + val buffer = ByteBuffer.allocate(AbstractRecords.estimateSizeInBytes(magicValue, compressionType, records.asJava)) + + if (isTxnOffsetCommit && magicValue < RecordBatch.MAGIC_VALUE_V2) + throw Errors.UNSUPPORTED_FOR_MESSAGE_FORMAT.exception("Attempting to make a transaction offset commit with an invalid magic: " + magicValue) + + val builder = MemoryRecords.builder(buffer, magicValue, compressionType, timestampType, 0L, time.milliseconds(), + producerId, producerEpoch, 0, isTxnOffsetCommit, RecordBatch.NO_PARTITION_LEADER_EPOCH) + + records.foreach(builder.append) + val entries = Map(offsetTopicPartition -> builder.build()) + + // set the callback function to insert offsets into cache after log append completed + def putCacheCallback(responseStatus: Map[TopicPartition, PartitionResponse]): Unit = { + // the append response should only contain the topics partition + if (responseStatus.size != 1 || !responseStatus.contains(offsetTopicPartition)) + throw new IllegalStateException("Append status %s should only have one partition %s" + .format(responseStatus, offsetTopicPartition)) + + // construct the commit response status and insert + // the offset and metadata to cache if the append status has no error + val status = responseStatus(offsetTopicPartition) + + val responseError = group.inLock { + if (status.error == Errors.NONE) { + if (!group.is(Dead)) { + filteredOffsetMetadata.forKeyValue { (topicPartition, offsetAndMetadata) => + if (isTxnOffsetCommit) + group.onTxnOffsetCommitAppend(producerId, topicPartition, CommitRecordMetadataAndOffset(Some(status.baseOffset), offsetAndMetadata)) 
+ else + group.onOffsetCommitAppend(topicPartition, CommitRecordMetadataAndOffset(Some(status.baseOffset), offsetAndMetadata)) + } + } + + // Record the number of offsets committed to the log + offsetCommitsSensor.record(records.size) + + Errors.NONE + } else { + if (!group.is(Dead)) { + if (!group.hasPendingOffsetCommitsFromProducer(producerId)) + removeProducerGroup(producerId, group.groupId) + filteredOffsetMetadata.forKeyValue { (topicPartition, offsetAndMetadata) => + if (isTxnOffsetCommit) + group.failPendingTxnOffsetCommit(producerId, topicPartition) + else + group.failPendingOffsetWrite(topicPartition, offsetAndMetadata) + } + } + + debug(s"Offset commit $filteredOffsetMetadata from group ${group.groupId}, consumer $consumerId " + + s"with generation ${group.generationId} failed when appending to log due to ${status.error.exceptionName}") + + // transform the log append error code to the corresponding the commit status error code + status.error match { + case Errors.UNKNOWN_TOPIC_OR_PARTITION + | Errors.NOT_ENOUGH_REPLICAS + | Errors.NOT_ENOUGH_REPLICAS_AFTER_APPEND => + Errors.COORDINATOR_NOT_AVAILABLE + + case Errors.NOT_LEADER_OR_FOLLOWER + | Errors.KAFKA_STORAGE_ERROR => + Errors.NOT_COORDINATOR + + case Errors.MESSAGE_TOO_LARGE + | Errors.RECORD_LIST_TOO_LARGE + | Errors.INVALID_FETCH_SIZE => + Errors.INVALID_COMMIT_OFFSET_SIZE + + case other => other + } + } + } + + // compute the final error codes for the commit response + val commitStatus = offsetMetadata.map { case (topicPartition, offsetAndMetadata) => + if (validateOffsetMetadataLength(offsetAndMetadata.metadata)) + (topicPartition, responseError) + else + (topicPartition, Errors.OFFSET_METADATA_TOO_LARGE) + } + + // finally trigger the callback logic passed from the API layer + responseCallback(commitStatus) + } + + if (isTxnOffsetCommit) { + group.inLock { + addProducerGroup(producerId, group.groupId) + group.prepareTxnOffsetCommit(producerId, offsetMetadata) + } + } else { + group.inLock { + group.prepareOffsetCommit(offsetMetadata) + } + } + + appendForGroup(group, entries, requestLocal, putCacheCallback) + + case None => + val commitStatus = offsetMetadata.map { case (topicPartition, _) => + (topicPartition, Errors.NOT_COORDINATOR) + } + responseCallback(commitStatus) + } + } + } + + /** + * The most important guarantee that this API provides is that it should never return a stale offset. i.e., it either + * returns the current offset or it begins to sync the cache from the log (and returns an error code). 
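+   * When `requireStable` is true, partitions that still have pending (in-flight) offset commits are
+   * returned with the UNSTABLE_OFFSET_COMMIT error so that callers can retry once those commits have
+   * been materialized.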
+ */ + def getOffsets(groupId: String, requireStable: Boolean, topicPartitionsOpt: Option[Seq[TopicPartition]]): Map[TopicPartition, PartitionData] = { + trace("Getting offsets of %s for group %s.".format(topicPartitionsOpt.getOrElse("all partitions"), groupId)) + val group = groupMetadataCache.get(groupId) + if (group == null) { + topicPartitionsOpt.getOrElse(Seq.empty[TopicPartition]).map { topicPartition => + val partitionData = new PartitionData(OffsetFetchResponse.INVALID_OFFSET, + Optional.empty(), "", Errors.NONE) + topicPartition -> partitionData + }.toMap + } else { + group.inLock { + if (group.is(Dead)) { + topicPartitionsOpt.getOrElse(Seq.empty[TopicPartition]).map { topicPartition => + val partitionData = new PartitionData(OffsetFetchResponse.INVALID_OFFSET, + Optional.empty(), "", Errors.NONE) + topicPartition -> partitionData + }.toMap + } else { + val topicPartitions = topicPartitionsOpt.getOrElse(group.allOffsets.keySet) + + topicPartitions.map { topicPartition => + if (requireStable && group.hasPendingOffsetCommitsForTopicPartition(topicPartition)) { + topicPartition -> new PartitionData(OffsetFetchResponse.INVALID_OFFSET, + Optional.empty(), "", Errors.UNSTABLE_OFFSET_COMMIT) + } else { + val partitionData = group.offset(topicPartition) match { + case None => + new PartitionData(OffsetFetchResponse.INVALID_OFFSET, + Optional.empty(), "", Errors.NONE) + case Some(offsetAndMetadata) => + new PartitionData(offsetAndMetadata.offset, + offsetAndMetadata.leaderEpoch, offsetAndMetadata.metadata, Errors.NONE) + } + topicPartition -> partitionData + } + }.toMap + } + } + } + } + + /** + * Asynchronously read the partition from the offsets topic and populate the cache + */ + def scheduleLoadGroupAndOffsets(offsetsPartition: Int, coordinatorEpoch: Int, onGroupLoaded: GroupMetadata => Unit): Unit = { + val topicPartition = new TopicPartition(Topic.GROUP_METADATA_TOPIC_NAME, offsetsPartition) + info(s"Scheduling loading of offsets and group metadata from $topicPartition for epoch $coordinatorEpoch") + val startTimeMs = time.milliseconds() + scheduler.schedule(topicPartition.toString, () => loadGroupsAndOffsets(topicPartition, coordinatorEpoch, onGroupLoaded, startTimeMs)) + } + + private[group] def loadGroupsAndOffsets( + topicPartition: TopicPartition, + coordinatorEpoch: Int, + onGroupLoaded: GroupMetadata => Unit, + startTimeMs: java.lang.Long + ): Unit = { + if (!maybeUpdateCoordinatorEpoch(topicPartition.partition, Some(coordinatorEpoch))) { + info(s"Not loading offsets and group metadata for $topicPartition " + + s"in epoch $coordinatorEpoch since current epoch is ${epochForPartitionId.get(topicPartition.partition)}") + } else if (!addLoadingPartition(topicPartition.partition)) { + info(s"Already loading offsets and group metadata from $topicPartition") + } else { + try { + val schedulerTimeMs = time.milliseconds() - startTimeMs + debug(s"Started loading offsets and group metadata from $topicPartition for epoch $coordinatorEpoch") + doLoadGroupsAndOffsets(topicPartition, onGroupLoaded) + val endTimeMs = time.milliseconds() + val totalLoadingTimeMs = endTimeMs - startTimeMs + partitionLoadSensor.record(totalLoadingTimeMs.toDouble, endTimeMs, false) + info(s"Finished loading offsets and group metadata from $topicPartition " + + s"in $totalLoadingTimeMs milliseconds for epoch $coordinatorEpoch, of which " + + s"$schedulerTimeMs milliseconds was spent in the scheduler.") + } catch { + case t: Throwable => error(s"Error loading offsets from $topicPartition", t) + } finally { + 
inLock(partitionLock) { + ownedPartitions.add(topicPartition.partition) + loadingPartitions.remove(topicPartition.partition) + } + } + } + } + + private def doLoadGroupsAndOffsets(topicPartition: TopicPartition, onGroupLoaded: GroupMetadata => Unit): Unit = { + def logEndOffset: Long = replicaManager.getLogEndOffset(topicPartition).getOrElse(-1L) + + replicaManager.getLog(topicPartition) match { + case None => + warn(s"Attempted to load offsets and group metadata from $topicPartition, but found no log") + + case Some(log) => + val loadedOffsets = mutable.Map[GroupTopicPartition, CommitRecordMetadataAndOffset]() + val pendingOffsets = mutable.Map[Long, mutable.Map[GroupTopicPartition, CommitRecordMetadataAndOffset]]() + val loadedGroups = mutable.Map[String, GroupMetadata]() + val removedGroups = mutable.Set[String]() + + // buffer may not be needed if records are read from memory + var buffer = ByteBuffer.allocate(0) + + // loop breaks if leader changes at any time during the load, since logEndOffset is -1 + var currOffset = log.logStartOffset + + // loop breaks if no records have been read, since the end of the log has been reached + var readAtLeastOneRecord = true + + while (currOffset < logEndOffset && readAtLeastOneRecord && !shuttingDown.get()) { + val fetchDataInfo = log.read(currOffset, + maxLength = config.loadBufferSize, + isolation = FetchLogEnd, + minOneMessage = true) + + readAtLeastOneRecord = fetchDataInfo.records.sizeInBytes > 0 + + val memRecords = (fetchDataInfo.records: @unchecked) match { + case records: MemoryRecords => records + case fileRecords: FileRecords => + val sizeInBytes = fileRecords.sizeInBytes + val bytesNeeded = Math.max(config.loadBufferSize, sizeInBytes) + + // minOneMessage = true in the above log.read means that the buffer may need to be grown to ensure progress can be made + if (buffer.capacity < bytesNeeded) { + if (config.loadBufferSize < bytesNeeded) + warn(s"Loaded offsets and group metadata from $topicPartition with buffer larger ($bytesNeeded bytes) than " + + s"configured offsets.load.buffer.size (${config.loadBufferSize} bytes)") + + buffer = ByteBuffer.allocate(bytesNeeded) + } else { + buffer.clear() + } + + fileRecords.readInto(buffer, 0) + MemoryRecords.readableRecords(buffer) + } + + memRecords.batches.forEach { batch => + val isTxnOffsetCommit = batch.isTransactional + if (batch.isControlBatch) { + val recordIterator = batch.iterator + if (recordIterator.hasNext) { + val record = recordIterator.next() + val controlRecord = ControlRecordType.parse(record.key) + if (controlRecord == ControlRecordType.COMMIT) { + pendingOffsets.getOrElse(batch.producerId, mutable.Map[GroupTopicPartition, CommitRecordMetadataAndOffset]()) + .foreach { + case (groupTopicPartition, commitRecordMetadataAndOffset) => + if (!loadedOffsets.contains(groupTopicPartition) || loadedOffsets(groupTopicPartition).olderThan(commitRecordMetadataAndOffset)) + loadedOffsets.put(groupTopicPartition, commitRecordMetadataAndOffset) + } + } + pendingOffsets.remove(batch.producerId) + } + } else { + var batchBaseOffset: Option[Long] = None + for (record <- batch.asScala) { + require(record.hasKey, "Group metadata/offset entry key should not be null") + if (batchBaseOffset.isEmpty) + batchBaseOffset = Some(record.offset) + GroupMetadataManager.readMessageKey(record.key) match { + + case offsetKey: OffsetKey => + if (isTxnOffsetCommit && !pendingOffsets.contains(batch.producerId)) + pendingOffsets.put(batch.producerId, mutable.Map[GroupTopicPartition, 
CommitRecordMetadataAndOffset]()) + + // load offset + val groupTopicPartition = offsetKey.key + if (!record.hasValue) { + if (isTxnOffsetCommit) + pendingOffsets(batch.producerId).remove(groupTopicPartition) + else + loadedOffsets.remove(groupTopicPartition) + } else { + val offsetAndMetadata = GroupMetadataManager.readOffsetMessageValue(record.value) + if (isTxnOffsetCommit) + pendingOffsets(batch.producerId).put(groupTopicPartition, CommitRecordMetadataAndOffset(batchBaseOffset, offsetAndMetadata)) + else + loadedOffsets.put(groupTopicPartition, CommitRecordMetadataAndOffset(batchBaseOffset, offsetAndMetadata)) + } + + case groupMetadataKey: GroupMetadataKey => + // load group metadata + val groupId = groupMetadataKey.key + val groupMetadata = GroupMetadataManager.readGroupMessageValue(groupId, record.value, time) + if (groupMetadata != null) { + removedGroups.remove(groupId) + loadedGroups.put(groupId, groupMetadata) + } else { + loadedGroups.remove(groupId) + removedGroups.add(groupId) + } + + case unknownKey => + throw new IllegalStateException(s"Unexpected message key $unknownKey while loading offsets and group metadata") + } + } + } + currOffset = batch.nextOffset + } + } + + val (groupOffsets, emptyGroupOffsets) = loadedOffsets + .groupBy(_._1.group) + .map { case (k, v) => + k -> v.map { case (groupTopicPartition, offset) => (groupTopicPartition.topicPartition, offset) } + }.partition { case (group, _) => loadedGroups.contains(group) } + + val pendingOffsetsByGroup = mutable.Map[String, mutable.Map[Long, mutable.Map[TopicPartition, CommitRecordMetadataAndOffset]]]() + pendingOffsets.forKeyValue { (producerId, producerOffsets) => + producerOffsets.keySet.map(_.group).foreach(addProducerGroup(producerId, _)) + producerOffsets + .groupBy(_._1.group) + .forKeyValue { (group, offsets) => + val groupPendingOffsets = pendingOffsetsByGroup.getOrElseUpdate(group, mutable.Map.empty[Long, mutable.Map[TopicPartition, CommitRecordMetadataAndOffset]]) + val groupProducerOffsets = groupPendingOffsets.getOrElseUpdate(producerId, mutable.Map.empty[TopicPartition, CommitRecordMetadataAndOffset]) + groupProducerOffsets ++= offsets.map { case (groupTopicPartition, offset) => + (groupTopicPartition.topicPartition, offset) + } + } + } + + val (pendingGroupOffsets, pendingEmptyGroupOffsets) = pendingOffsetsByGroup + .partition { case (group, _) => loadedGroups.contains(group)} + + loadedGroups.values.foreach { group => + val offsets = groupOffsets.getOrElse(group.groupId, Map.empty[TopicPartition, CommitRecordMetadataAndOffset]) + val pendingOffsets = pendingGroupOffsets.getOrElse(group.groupId, Map.empty[Long, mutable.Map[TopicPartition, CommitRecordMetadataAndOffset]]) + debug(s"Loaded group metadata $group with offsets $offsets and pending offsets $pendingOffsets") + loadGroup(group, offsets, pendingOffsets) + onGroupLoaded(group) + } + + // load groups which store offsets in kafka, but which have no active members and thus no group + // metadata stored in the log + (emptyGroupOffsets.keySet ++ pendingEmptyGroupOffsets.keySet).foreach { groupId => + val group = new GroupMetadata(groupId, Empty, time) + val offsets = emptyGroupOffsets.getOrElse(groupId, Map.empty[TopicPartition, CommitRecordMetadataAndOffset]) + val pendingOffsets = pendingEmptyGroupOffsets.getOrElse(groupId, Map.empty[Long, mutable.Map[TopicPartition, CommitRecordMetadataAndOffset]]) + debug(s"Loaded group metadata $group with offsets $offsets and pending offsets $pendingOffsets") + loadGroup(group, offsets, pendingOffsets) + 
onGroupLoaded(group) + } + + removedGroups.foreach { groupId => + // if the cache already contains a group which should be removed, raise an error. Note that it + // is possible (however unlikely) for a consumer group to be removed, and then to be used only for + // offset storage (i.e. by "simple" consumers) + if (groupMetadataCache.contains(groupId) && !emptyGroupOffsets.contains(groupId)) + throw new IllegalStateException(s"Unexpected unload of active group $groupId while " + + s"loading partition $topicPartition") + } + } + } + + private def loadGroup(group: GroupMetadata, offsets: Map[TopicPartition, CommitRecordMetadataAndOffset], + pendingTransactionalOffsets: Map[Long, mutable.Map[TopicPartition, CommitRecordMetadataAndOffset]]): Unit = { + // offsets are initialized prior to loading the group into the cache to ensure that clients see a consistent + // view of the group's offsets + trace(s"Initialized offsets $offsets for group ${group.groupId}") + group.initializeOffsets(offsets, pendingTransactionalOffsets.toMap) + + val currentGroup = addGroup(group) + if (group != currentGroup) + debug(s"Attempt to load group ${group.groupId} from log with generation ${group.generationId} failed " + + s"because there is already a cached group with generation ${currentGroup.generationId}") + } + + /** + * When this broker becomes a follower for an offsets topic partition clear out the cache for groups that belong to + * that partition. + * + * @param offsetsPartition Groups belonging to this partition of the offsets topic will be deleted from the cache. + */ + def removeGroupsForPartition(offsetsPartition: Int, + coordinatorEpoch: Option[Int], + onGroupUnloaded: GroupMetadata => Unit): Unit = { + val topicPartition = new TopicPartition(Topic.GROUP_METADATA_TOPIC_NAME, offsetsPartition) + info(s"Scheduling unloading of offsets and group metadata from $topicPartition") + scheduler.schedule(topicPartition.toString, () => removeGroupsAndOffsets(topicPartition, coordinatorEpoch, onGroupUnloaded)) + } + + private [group] def removeGroupsAndOffsets(topicPartition: TopicPartition, + coordinatorEpoch: Option[Int], + onGroupUnloaded: GroupMetadata => Unit): Unit = { + val offsetsPartition = topicPartition.partition + if (maybeUpdateCoordinatorEpoch(offsetsPartition, coordinatorEpoch)) { + var numOffsetsRemoved = 0 + var numGroupsRemoved = 0 + + debug(s"Started unloading offsets and group metadata for $topicPartition for " + + s"coordinator epoch $coordinatorEpoch") + inLock(partitionLock) { + // we need to guard the group removal in cache in the loading partition lock + // to prevent coordinator's check-and-get-group race condition + ownedPartitions.remove(offsetsPartition) + loadingPartitions.remove(offsetsPartition) + + for (group <- groupMetadataCache.values) { + if (partitionFor(group.groupId) == offsetsPartition) { + onGroupUnloaded(group) + groupMetadataCache.remove(group.groupId, group) + removeGroupFromAllProducers(group.groupId) + numGroupsRemoved += 1 + numOffsetsRemoved += group.numOffsets + } + } + } + info(s"Finished unloading $topicPartition for coordinator epoch $coordinatorEpoch. " + + s"Removed $numOffsetsRemoved cached offsets and $numGroupsRemoved cached groups.") + } else { + info(s"Not removing offsets and group metadata for $topicPartition " + + s"in epoch $coordinatorEpoch since current epoch is ${epochForPartitionId.get(topicPartition.partition)}") + } + } + + /** + * Update the cached coordinator epoch if the new value is larger than the old value. 
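+   * The epoch is tracked per partition of the offsets topic so that group load/unload requests issued
+   * by stale LeaderAndIsr requests can be detected and ignored.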
+ * @return true if `epochOpt` is either empty or contains a value greater than or equal to the current epoch + */ + private def maybeUpdateCoordinatorEpoch( + partitionId: Int, + epochOpt: Option[Int] + ): Boolean = { + val updatedEpoch = epochForPartitionId.compute(partitionId, (_, currentEpoch) => { + if (currentEpoch == null) { + epochOpt.map(Int.box).orNull + } else { + epochOpt match { + case Some(epoch) if epoch > currentEpoch => epoch + case _ => currentEpoch + } + } + }) + epochOpt.forall(_ == updatedEpoch) + } + + // visible for testing + private[group] def cleanupGroupMetadata(): Unit = { + val currentTimestamp = time.milliseconds() + val numOffsetsRemoved = cleanupGroupMetadata(groupMetadataCache.values, RequestLocal.NoCaching, + _.removeExpiredOffsets(currentTimestamp, config.offsetsRetentionMs)) + offsetExpiredSensor.record(numOffsetsRemoved) + if (numOffsetsRemoved > 0) + info(s"Removed $numOffsetsRemoved expired offsets in ${time.milliseconds() - currentTimestamp} milliseconds.") + } + + /** + * This function is used to clean up group offsets given the groups and also a function that performs the offset deletion. + * @param groups Groups whose metadata are to be cleaned up + * @param selector A function that implements deletion of (all or part of) group offsets. This function is called while + * a group lock is held, therefore there is no need for the caller to also obtain a group lock. + * @return The cumulative number of offsets removed + */ + def cleanupGroupMetadata(groups: Iterable[GroupMetadata], requestLocal: RequestLocal, + selector: GroupMetadata => Map[TopicPartition, OffsetAndMetadata]): Int = { + var offsetsRemoved = 0 + + groups.foreach { group => + val groupId = group.groupId + val (removedOffsets, groupIsDead, generation) = group.inLock { + val removedOffsets = selector(group) + if (group.is(Empty) && !group.hasOffsets) { + info(s"Group $groupId transitioned to Dead in generation ${group.generationId}") + group.transitionTo(Dead) + } + (removedOffsets, group.is(Dead), group.generationId) + } + + val offsetsPartition = partitionFor(groupId) + val appendPartition = new TopicPartition(Topic.GROUP_METADATA_TOPIC_NAME, offsetsPartition) + getMagic(offsetsPartition) match { + case Some(magicValue) => + // We always use CREATE_TIME, like the producer. The conversion to LOG_APPEND_TIME (if necessary) happens automatically. + val timestampType = TimestampType.CREATE_TIME + val timestamp = time.milliseconds() + + replicaManager.onlinePartition(appendPartition).foreach { partition => + val tombstones = ArrayBuffer.empty[SimpleRecord] + removedOffsets.forKeyValue { (topicPartition, offsetAndMetadata) => + trace(s"Removing expired/deleted offset and metadata for $groupId, $topicPartition: $offsetAndMetadata") + val commitKey = GroupMetadataManager.offsetCommitKey(groupId, topicPartition) + tombstones += new SimpleRecord(timestamp, commitKey, null) + } + trace(s"Marked ${removedOffsets.size} offsets in $appendPartition for deletion.") + + // We avoid writing the tombstone when the generationId is 0, since this group is only using + // Kafka for offset storage. + if (groupIsDead && groupMetadataCache.remove(groupId, group) && generation > 0) { + // Append the tombstone messages to the partition. It is okay if the replicas don't receive these (say, + // if we crash or leaders move) since the new leaders will still expire the consumers with heartbeat and + // retry removing this group. 
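+              // A record with a null value is a tombstone: log compaction of the offsets topic will
+              // eventually remove the group metadata key altogether.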
+ val groupMetadataKey = GroupMetadataManager.groupMetadataKey(group.groupId) + tombstones += new SimpleRecord(timestamp, groupMetadataKey, null) + trace(s"Group $groupId removed from the metadata cache and marked for deletion in $appendPartition.") + } + + if (tombstones.nonEmpty) { + try { + // do not need to require acks since even if the tombstone is lost, + // it will be appended again in the next purge cycle + val records = MemoryRecords.withRecords(magicValue, 0L, compressionType, timestampType, tombstones.toArray: _*) + partition.appendRecordsToLeader(records, origin = AppendOrigin.Coordinator, requiredAcks = 0, + requestLocal = requestLocal) + + offsetsRemoved += removedOffsets.size + trace(s"Successfully appended ${tombstones.size} tombstones to $appendPartition for expired/deleted " + + s"offsets and/or metadata for group $groupId") + } catch { + case t: Throwable => + error(s"Failed to append ${tombstones.size} tombstones to $appendPartition for expired/deleted " + + s"offsets and/or metadata for group $groupId.", t) + // ignore and continue + } + } + } + + case None => + info(s"BrokerId $brokerId is no longer a coordinator for the group $groupId. Proceeding cleanup for other alive groups") + } + } + + offsetsRemoved + } + + /** + * Complete pending transactional offset commits of the groups of `producerId` from the provided + * `completedPartitions`. This method is invoked when a commit or abort marker is fully written + * to the log. It may be invoked when a group lock is held by the caller, for instance when delayed + * operations are completed while appending offsets for a group. Since we need to acquire one or + * more group metadata locks to handle transaction completion, this operation is scheduled on + * the scheduler thread to avoid deadlocks. + */ + def scheduleHandleTxnCompletion(producerId: Long, completedPartitions: Set[Int], isCommit: Boolean): Unit = { + scheduler.schedule(s"handleTxnCompletion-$producerId", () => + handleTxnCompletion(producerId, completedPartitions, isCommit)) + } + + private[group] def handleTxnCompletion(producerId: Long, completedPartitions: Set[Int], isCommit: Boolean): Unit = { + val pendingGroups = groupsBelongingToPartitions(producerId, completedPartitions) + pendingGroups.foreach { groupId => + getGroup(groupId) match { + case Some(group) => group.inLock { + if (!group.is(Dead)) { + group.completePendingTxnOffsetCommit(producerId, isCommit) + removeProducerGroup(producerId, groupId) + } + } + case _ => + info(s"Group $groupId has moved away from $brokerId after transaction marker was written but before the " + + s"cache was updated. 
The cache on the new group owner will be updated instead.") + } + } + } + + private def addProducerGroup(producerId: Long, groupId: String) = openGroupsForProducer synchronized { + openGroupsForProducer.getOrElseUpdate(producerId, mutable.Set.empty[String]).add(groupId) + } + + private def removeProducerGroup(producerId: Long, groupId: String) = openGroupsForProducer synchronized { + openGroupsForProducer.getOrElseUpdate(producerId, mutable.Set.empty[String]).remove(groupId) + if (openGroupsForProducer(producerId).isEmpty) + openGroupsForProducer.remove(producerId) + } + + private def groupsBelongingToPartitions(producerId: Long, partitions: Set[Int]) = openGroupsForProducer synchronized { + val (ownedGroups, _) = openGroupsForProducer.getOrElse(producerId, mutable.Set.empty[String]) + .partition(group => partitions.contains(partitionFor(group))) + ownedGroups + } + + private def removeGroupFromAllProducers(groupId: String): Unit = openGroupsForProducer synchronized { + openGroupsForProducer.forKeyValue { (_, groups) => + groups.remove(groupId) + } + } + + /* + * Check if the offset metadata length is valid + */ + private def validateOffsetMetadataLength(metadata: String) : Boolean = { + metadata == null || metadata.length() <= config.maxMetadataSize + } + + + def shutdown(): Unit = { + shuttingDown.set(true) + if (scheduler.isStarted) + scheduler.shutdown() + metrics.removeSensor(GroupMetadataManager.LoadTimeSensor) + metrics.removeSensor(GroupMetadataManager.OffsetCommitsSensor) + metrics.removeSensor(GroupMetadataManager.OffsetExpiredSensor) + + // TODO: clear the caches + } + + /** + * Check if the replica is local and return the message format version + * + * @param partition Partition of GroupMetadataTopic + * @return Some(MessageFormatVersion) if replica is local, None otherwise + */ + private def getMagic(partition: Int): Option[Byte] = + replicaManager.getMagic(new TopicPartition(Topic.GROUP_METADATA_TOPIC_NAME, partition)) + + /** + * Add the partition into the owned list + * + * NOTE: this is for test only + */ + private[group] def addPartitionOwnership(partition: Int): Unit = { + inLock(partitionLock) { + ownedPartitions.add(partition) + } + } + + /** + * Add a partition to the loading partitions set. Return true if the partition was not + * already loading. + * + * Visible for testing + */ + private[group] def addLoadingPartition(partition: Int): Boolean = { + inLock(partitionLock) { + if (ownedPartitions.contains(partition)) { + false + } else { + loadingPartitions.add(partition) + } + } + } + +} + +/** + * Messages stored for the group topic has versions for both the key and value fields. 
Key + * version is used to indicate the type of the message (also to differentiate different types + * of messages from being compacted together if they have the same field values); and value + * version is used to evolve the messages within their data types: + * + * key version 0: group consumption offset + * -> value version 0: [offset, metadata, timestamp] + * + * key version 1: group consumption offset + * -> value version 1: [offset, metadata, commit_timestamp, expire_timestamp] + * + * key version 2: group metadata + * -> value version 0: [protocol_type, generation, protocol, leader, members] + */ +object GroupMetadataManager { + // Metrics names + val MetricsGroup: String = "group-coordinator-metrics" + val LoadTimeSensor: String = "GroupPartitionLoadTime" + val OffsetCommitsSensor: String = "OffsetCommits" + val OffsetExpiredSensor: String = "OffsetExpired" + + /** + * Generates the key for offset commit message for given (group, topic, partition) + * + * @param groupId the ID of the group to generate the key + * @param topicPartition the TopicPartition to generate the key + * @return key for offset commit message + */ + def offsetCommitKey(groupId: String, topicPartition: TopicPartition): Array[Byte] = { + MessageUtil.toVersionPrefixedBytes(OffsetCommitKey.HIGHEST_SUPPORTED_VERSION, + new OffsetCommitKey() + .setGroup(groupId) + .setTopic(topicPartition.topic) + .setPartition(topicPartition.partition)) + } + + /** + * Generates the key for group metadata message for given group + * + * @param groupId the ID of the group to generate the key + * @return key bytes for group metadata message + */ + def groupMetadataKey(groupId: String): Array[Byte] = { + MessageUtil.toVersionPrefixedBytes(GroupMetadataKeyData.HIGHEST_SUPPORTED_VERSION, + new GroupMetadataKeyData() + .setGroup(groupId)) + } + + /** + * Generates the payload for offset commit message from given offset and metadata + * + * @param offsetAndMetadata consumer's current offset and metadata + * @param apiVersion the api version + * @return payload for offset commit message + */ + def offsetCommitValue(offsetAndMetadata: OffsetAndMetadata, + apiVersion: ApiVersion): Array[Byte] = { + val version = + if (apiVersion < KAFKA_2_1_IV0 || offsetAndMetadata.expireTimestamp.nonEmpty) 1.toShort + else if (apiVersion < KAFKA_2_1_IV1) 2.toShort + else 3.toShort + MessageUtil.toVersionPrefixedBytes(version, new OffsetCommitValue() + .setOffset(offsetAndMetadata.offset) + .setMetadata(offsetAndMetadata.metadata) + .setCommitTimestamp(offsetAndMetadata.commitTimestamp) + .setLeaderEpoch(offsetAndMetadata.leaderEpoch.orElse(RecordBatch.NO_PARTITION_LEADER_EPOCH)) + // version 1 has a non empty expireTimestamp field + .setExpireTimestamp(offsetAndMetadata.expireTimestamp.getOrElse(OffsetCommitRequest.DEFAULT_TIMESTAMP)) + ) + } + + /** + * Generates the payload for group metadata message from given offset and metadata + * assuming the generation id, selected protocol, leader and member assignment are all available + * + * @param groupMetadata current group metadata + * @param assignment the assignment for the rebalancing generation + * @param apiVersion the api version + * @return payload for offset commit message + */ + def groupMetadataValue(groupMetadata: GroupMetadata, + assignment: Map[String, Array[Byte]], + apiVersion: ApiVersion): Array[Byte] = { + + val version = + if (apiVersion < KAFKA_0_10_1_IV0) 0.toShort + else if (apiVersion < KAFKA_2_1_IV0) 1.toShort + else if (apiVersion < KAFKA_2_3_IV0) 2.toShort + else 3.toShort + + 
MessageUtil.toVersionPrefixedBytes(version, new GroupMetadataValue() + .setProtocolType(groupMetadata.protocolType.getOrElse("")) + .setGeneration(groupMetadata.generationId) + .setProtocol(groupMetadata.protocolName.orNull) + .setLeader(groupMetadata.leaderOrNull) + .setCurrentStateTimestamp(groupMetadata.currentStateTimestampOrDefault) + .setMembers(groupMetadata.allMemberMetadata.map { memberMetadata => + new GroupMetadataValue.MemberMetadata() + .setMemberId(memberMetadata.memberId) + .setClientId(memberMetadata.clientId) + .setClientHost(memberMetadata.clientHost) + .setSessionTimeout(memberMetadata.sessionTimeoutMs) + .setRebalanceTimeout(memberMetadata.rebalanceTimeoutMs) + .setGroupInstanceId(memberMetadata.groupInstanceId.orNull) + // The group is non-empty, so the current protocol must be defined + .setSubscription(groupMetadata.protocolName.map(memberMetadata.metadata) + .getOrElse(throw new IllegalStateException("Attempted to write non-empty group metadata with no defined protocol."))) + .setAssignment(assignment.getOrElse(memberMetadata.memberId, + throw new IllegalStateException(s"Attempted to write member ${memberMetadata.memberId} of group ${groupMetadata.groupId} with no assignment."))) + }.asJava)) + } + + /** + * Decodes the offset messages' key + * + * @param buffer input byte-buffer + * @return an OffsetKey or GroupMetadataKey object from the message + */ + def readMessageKey(buffer: ByteBuffer): BaseKey = { + val version = buffer.getShort + if (version >= OffsetCommitKey.LOWEST_SUPPORTED_VERSION && version <= OffsetCommitKey.HIGHEST_SUPPORTED_VERSION) { + // version 0 and 1 refer to offset + val key = new OffsetCommitKey(new ByteBufferAccessor(buffer), version) + OffsetKey(version, GroupTopicPartition(key.group, new TopicPartition(key.topic, key.partition))) + } else if (version >= GroupMetadataKeyData.LOWEST_SUPPORTED_VERSION && version <= GroupMetadataKeyData.HIGHEST_SUPPORTED_VERSION) { + // version 2 refers to group metadata + val key = new GroupMetadataKeyData(new ByteBufferAccessor(buffer), version) + GroupMetadataKey(version, key.group) + } else throw new IllegalStateException(s"Unknown group metadata message version: $version") + } + + /** + * Decodes the offset messages' payload and retrieves offset and metadata from it + * + * @param buffer input byte-buffer + * @return an offset-metadata object from the message + */ + def readOffsetMessageValue(buffer: ByteBuffer): OffsetAndMetadata = { + // tombstone + if (buffer == null) null + else { + val version = buffer.getShort + if (version >= OffsetCommitValue.LOWEST_SUPPORTED_VERSION && version <= OffsetCommitValue.HIGHEST_SUPPORTED_VERSION) { + val value = new OffsetCommitValue(new ByteBufferAccessor(buffer), version) + OffsetAndMetadata( + offset = value.offset, + leaderEpoch = if (value.leaderEpoch == RecordBatch.NO_PARTITION_LEADER_EPOCH) Optional.empty() else Optional.of(value.leaderEpoch), + metadata = value.metadata, + commitTimestamp = value.commitTimestamp, + expireTimestamp = if (value.expireTimestamp == OffsetCommitRequest.DEFAULT_TIMESTAMP) None else Some(value.expireTimestamp)) + } else throw new IllegalStateException(s"Unknown offset message version: $version") + } + } + + /** + * Decodes the group metadata messages' payload and retrieves its member metadata from it + * + * @param groupId The ID of the group to be read + * @param buffer input byte-buffer + * @param time the time instance to use + * @return a group metadata object from the message + */ + def readGroupMessageValue(groupId: String, 
buffer: ByteBuffer, time: Time): GroupMetadata = { + // tombstone + if (buffer == null) null + else { + val version = buffer.getShort + if (version >= GroupMetadataValue.LOWEST_SUPPORTED_VERSION && version <= GroupMetadataValue.HIGHEST_SUPPORTED_VERSION) { + val value = new GroupMetadataValue(new ByteBufferAccessor(buffer), version) + val members = value.members.asScala.map { memberMetadata => + new MemberMetadata( + memberId = memberMetadata.memberId, + groupInstanceId = Option(memberMetadata.groupInstanceId), + clientId = memberMetadata.clientId, + clientHost = memberMetadata.clientHost, + rebalanceTimeoutMs = if (version == 0) memberMetadata.sessionTimeout else memberMetadata.rebalanceTimeout, + sessionTimeoutMs = memberMetadata.sessionTimeout, + protocolType = value.protocolType, + supportedProtocols = List((value.protocol, memberMetadata.subscription)), + assignment = memberMetadata.assignment) + } + GroupMetadata.loadGroup( + groupId = groupId, + initialState = if (members.isEmpty) Empty else Stable, + generationId = value.generation, + protocolType = value.protocolType, + protocolName = value.protocol, + leaderId = value.leader, + currentStateTimestamp = if (value.currentStateTimestamp == -1) None else Some(value.currentStateTimestamp), + members = members, + time = time) + } else throw new IllegalStateException(s"Unknown group metadata message version: $version") + } + } + + // Formatter for use with tools such as console consumer: Consumer should also set exclude.internal.topics to false. + // (specify --formatter "kafka.coordinator.group.GroupMetadataManager\$OffsetsMessageFormatter" when consuming __consumer_offsets) + class OffsetsMessageFormatter extends MessageFormatter { + def writeTo(consumerRecord: ConsumerRecord[Array[Byte], Array[Byte]], output: PrintStream): Unit = { + Option(consumerRecord.key).map(key => GroupMetadataManager.readMessageKey(ByteBuffer.wrap(key))).foreach { + // Only print if the message is an offset record. + // We ignore the timestamp of the message because GroupMetadataMessage has its own timestamp. + case offsetKey: OffsetKey => + val groupTopicPartition = offsetKey.key + val value = consumerRecord.value + val formattedValue = + if (value == null) "NULL" + else GroupMetadataManager.readOffsetMessageValue(ByteBuffer.wrap(value)).toString + output.write(groupTopicPartition.toString.getBytes(StandardCharsets.UTF_8)) + output.write("::".getBytes(StandardCharsets.UTF_8)) + output.write(formattedValue.getBytes(StandardCharsets.UTF_8)) + output.write("\n".getBytes(StandardCharsets.UTF_8)) + case _ => // no-op + } + } + } + + // Formatter for use with tools to read group metadata history + class GroupMetadataMessageFormatter extends MessageFormatter { + def writeTo(consumerRecord: ConsumerRecord[Array[Byte], Array[Byte]], output: PrintStream): Unit = { + Option(consumerRecord.key).map(key => GroupMetadataManager.readMessageKey(ByteBuffer.wrap(key))).foreach { + // Only print if the message is a group metadata record. + // We ignore the timestamp of the message because GroupMetadataMessage has its own timestamp. 
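+        // Group tombstones (records with a null value) are printed with the literal value "NULL".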
+ case groupMetadataKey: GroupMetadataKey => + val groupId = groupMetadataKey.key + val value = consumerRecord.value + val formattedValue = + if (value == null) "NULL" + else GroupMetadataManager.readGroupMessageValue(groupId, ByteBuffer.wrap(value), Time.SYSTEM).toString + output.write(groupId.getBytes(StandardCharsets.UTF_8)) + output.write("::".getBytes(StandardCharsets.UTF_8)) + output.write(formattedValue.getBytes(StandardCharsets.UTF_8)) + output.write("\n".getBytes(StandardCharsets.UTF_8)) + case _ => // no-op + } + } + } + + /** + * Exposed for printing records using [[kafka.tools.DumpLogSegments]] + */ + def formatRecordKeyAndValue(record: Record): (Option[String], Option[String]) = { + if (!record.hasKey) { + throw new KafkaException("Failed to decode message using offset topic decoder (message had a missing key)") + } else { + GroupMetadataManager.readMessageKey(record.key) match { + case offsetKey: OffsetKey => parseOffsets(offsetKey, record.value) + case groupMetadataKey: GroupMetadataKey => parseGroupMetadata(groupMetadataKey, record.value) + case _ => throw new KafkaException("Failed to decode message using offset topic decoder (message had an invalid key)") + } + } + } + + private def parseOffsets(offsetKey: OffsetKey, payload: ByteBuffer): (Option[String], Option[String]) = { + val groupId = offsetKey.key.group + val topicPartition = offsetKey.key.topicPartition + val keyString = s"offset_commit::group=$groupId,partition=$topicPartition" + + val offset = GroupMetadataManager.readOffsetMessageValue(payload) + val valueString = if (offset == null) { + "" + } else { + if (offset.metadata.isEmpty) + s"offset=${offset.offset}" + else + s"offset=${offset.offset},metadata=${offset.metadata}" + } + + (Some(keyString), Some(valueString)) + } + + private def parseGroupMetadata(groupMetadataKey: GroupMetadataKey, payload: ByteBuffer): (Option[String], Option[String]) = { + val groupId = groupMetadataKey.key + val keyString = s"group_metadata::group=$groupId" + + val group = GroupMetadataManager.readGroupMessageValue(groupId, payload, Time.SYSTEM) + val valueString = if (group == null) + "" + else { + val protocolType = group.protocolType.getOrElse("") + + val assignment = group.allMemberMetadata.map { member => + if (protocolType == ConsumerProtocol.PROTOCOL_TYPE) { + val partitionAssignment = ConsumerProtocol.deserializeAssignment(ByteBuffer.wrap(member.assignment)) + val userData = Option(partitionAssignment.userData) + .map(Utils.toArray) + .map(hex) + .getOrElse("") + + if (userData.isEmpty) + s"${member.memberId}=${partitionAssignment.partitions}" + else + s"${member.memberId}=${partitionAssignment.partitions}:$userData" + } else { + s"${member.memberId}=${hex(member.assignment)}" + } + }.mkString("{", ",", "}") + + Json.encodeAsString(Map( + "protocolType" -> protocolType, + "protocol" -> group.protocolName.orNull, + "generationId" -> group.generationId, + "assignment" -> assignment + ).asJava) + } + (Some(keyString), Some(valueString)) + } + + private def hex(bytes: Array[Byte]): String = { + if (bytes.isEmpty) + "" + else + "%X".format(BigInt(1, bytes)) + } + +} + +case class GroupTopicPartition(group: String, topicPartition: TopicPartition) { + + def this(group: String, topic: String, partition: Int) = + this(group, new TopicPartition(topic, partition)) + + override def toString: String = + "[%s,%s,%d]".format(group, topicPartition.topic, topicPartition.partition) +} + +trait BaseKey{ + def version: Short + def key: Any +} + +case class OffsetKey(version: Short, key: 
GroupTopicPartition) extends BaseKey { + + override def toString: String = key.toString +} + +case class GroupMetadataKey(version: Short, key: String) extends BaseKey { + + override def toString: String = key +} + diff --git a/core/src/main/scala/kafka/coordinator/group/MemberMetadata.scala b/core/src/main/scala/kafka/coordinator/group/MemberMetadata.scala new file mode 100644 index 0000000..514dbfb --- /dev/null +++ b/core/src/main/scala/kafka/coordinator/group/MemberMetadata.scala @@ -0,0 +1,153 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.coordinator.group + +import java.util + +import kafka.utils.nonthreadsafe + +case class MemberSummary(memberId: String, + groupInstanceId: Option[String], + clientId: String, + clientHost: String, + metadata: Array[Byte], + assignment: Array[Byte]) + +private object MemberMetadata { + def plainProtocolSet(supportedProtocols: List[(String, Array[Byte])]) = supportedProtocols.map(_._1).toSet +} + +/** + * Member metadata contains the following metadata: + * + * Heartbeat metadata: + * 1. negotiated heartbeat session timeout + * 2. timestamp of the latest heartbeat + * + * Protocol metadata: + * 1. the list of supported protocols (ordered by preference) + * 2. the metadata associated with each protocol + * + * In addition, it also contains the following state information: + * + * 1. Awaiting rebalance callback: when the group is in the prepare-rebalance state, + * its rebalance callback will be kept in the metadata if the + * member has sent the join group request + * 2. Awaiting sync callback: when the group is in the awaiting-sync state, its sync callback + * is kept in metadata until the leader provides the group assignment + * and the group transitions to stable + */ +@nonthreadsafe +private[group] class MemberMetadata(var memberId: String, + val groupInstanceId: Option[String], + val clientId: String, + val clientHost: String, + val rebalanceTimeoutMs: Int, + val sessionTimeoutMs: Int, + val protocolType: String, + var supportedProtocols: List[(String, Array[Byte])], + var assignment: Array[Byte] = Array.empty[Byte]) { + + var awaitingJoinCallback: JoinGroupResult => Unit = _ + var awaitingSyncCallback: SyncGroupResult => Unit = _ + var isNew: Boolean = false + + def isStaticMember: Boolean = groupInstanceId.isDefined + + // This variable is used to track heartbeat completion through the delayed + // heartbeat purgatory. When scheduling a new heartbeat expiration, we set + // this value to `false`. Upon receiving the heartbeat (or any other event + // indicating the liveness of the client), we set it to `true` so that the + // delayed heartbeat can be completed. 
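+  // Note that members which are awaiting a rebalance (join or sync) are considered alive
+  // regardless of this flag; see hasSatisfiedHeartbeat below.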
+ var heartbeatSatisfied: Boolean = false + + def isAwaitingJoin: Boolean = awaitingJoinCallback != null + def isAwaitingSync: Boolean = awaitingSyncCallback != null + + /** + * Get metadata corresponding to the provided protocol. + */ + def metadata(protocol: String): Array[Byte] = { + supportedProtocols.find(_._1 == protocol) match { + case Some((_, metadata)) => metadata + case None => + throw new IllegalArgumentException("Member does not support protocol") + } + } + + def hasSatisfiedHeartbeat: Boolean = { + if (isNew) { + // New members can be expired while awaiting join, so we have to check this first + heartbeatSatisfied + } else if (isAwaitingJoin || isAwaitingSync) { + // Members that are awaiting a rebalance automatically satisfy expected heartbeats + true + } else { + // Otherwise we require the next heartbeat + heartbeatSatisfied + } + } + + /** + * Check if the provided protocol metadata matches the currently stored metadata. + */ + def matches(protocols: List[(String, Array[Byte])]): Boolean = { + if (protocols.size != this.supportedProtocols.size) + return false + + for (i <- protocols.indices) { + val p1 = protocols(i) + val p2 = supportedProtocols(i) + if (p1._1 != p2._1 || !util.Arrays.equals(p1._2, p2._2)) + return false + } + true + } + + def summary(protocol: String): MemberSummary = { + MemberSummary(memberId, groupInstanceId, clientId, clientHost, metadata(protocol), assignment) + } + + def summaryNoMetadata(): MemberSummary = { + MemberSummary(memberId, groupInstanceId, clientId, clientHost, Array.empty[Byte], Array.empty[Byte]) + } + + /** + * Vote for one of the potential group protocols. This takes into account the protocol preference as + * indicated by the order of supported protocols and returns the first one also contained in the set + */ + def vote(candidates: Set[String]): String = { + supportedProtocols.find({ case (protocol, _) => candidates.contains(protocol)}) match { + case Some((protocol, _)) => protocol + case None => + throw new IllegalArgumentException("Member does not support any of the candidate protocols") + } + } + + override def toString: String = { + "MemberMetadata(" + + s"memberId=$memberId, " + + s"groupInstanceId=$groupInstanceId, " + + s"clientId=$clientId, " + + s"clientHost=$clientHost, " + + s"sessionTimeoutMs=$sessionTimeoutMs, " + + s"rebalanceTimeoutMs=$rebalanceTimeoutMs, " + + s"supportedProtocols=${supportedProtocols.map(_._1)}" + + ")" + } +} diff --git a/core/src/main/scala/kafka/coordinator/group/OffsetConfig.scala b/core/src/main/scala/kafka/coordinator/group/OffsetConfig.scala new file mode 100644 index 0000000..55ec590 --- /dev/null +++ b/core/src/main/scala/kafka/coordinator/group/OffsetConfig.scala @@ -0,0 +1,62 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package kafka.coordinator.group + +import kafka.message.{CompressionCodec, NoCompressionCodec} + +/** + * Configuration settings for in-built offset management + * @param maxMetadataSize The maximum allowed metadata for any offset commit. + * @param loadBufferSize Batch size for reading from the offsets segments when loading offsets into the cache. + * @param offsetsRetentionMs After a consumer group loses all its consumers (i.e. becomes empty) its offsets will be kept for this retention period before getting discarded. + * For standalone consumers (using manual assignment), offsets will be expired after the time of last commit plus this retention period. + * @param offsetsRetentionCheckIntervalMs Frequency at which to check for expired offsets. + * @param offsetsTopicNumPartitions The number of partitions for the offset commit topic (should not change after deployment). + * @param offsetsTopicSegmentBytes The offsets topic segment bytes should be kept relatively small to facilitate faster + * log compaction and faster offset loads + * @param offsetsTopicReplicationFactor The replication factor for the offset commit topic (set higher to ensure availability). + * @param offsetsTopicCompressionCodec Compression codec for the offsets topic - compression should be turned on in + * order to achieve "atomic" commits. + * @param offsetCommitTimeoutMs The offset commit will be delayed until all replicas for the offsets topic receive the + * commit or this timeout is reached. (Similar to the producer request timeout.) + * @param offsetCommitRequiredAcks The required acks before the commit can be accepted. In general, the default (-1) + * should not be overridden. + */ +case class OffsetConfig(maxMetadataSize: Int = OffsetConfig.DefaultMaxMetadataSize, + loadBufferSize: Int = OffsetConfig.DefaultLoadBufferSize, + offsetsRetentionMs: Long = OffsetConfig.DefaultOffsetRetentionMs, + offsetsRetentionCheckIntervalMs: Long = OffsetConfig.DefaultOffsetsRetentionCheckIntervalMs, + offsetsTopicNumPartitions: Int = OffsetConfig.DefaultOffsetsTopicNumPartitions, + offsetsTopicSegmentBytes: Int = OffsetConfig.DefaultOffsetsTopicSegmentBytes, + offsetsTopicReplicationFactor: Short = OffsetConfig.DefaultOffsetsTopicReplicationFactor, + offsetsTopicCompressionCodec: CompressionCodec = OffsetConfig.DefaultOffsetsTopicCompressionCodec, + offsetCommitTimeoutMs: Int = OffsetConfig.DefaultOffsetCommitTimeoutMs, + offsetCommitRequiredAcks: Short = OffsetConfig.DefaultOffsetCommitRequiredAcks) + +object OffsetConfig { + val DefaultMaxMetadataSize = 4096 + val DefaultLoadBufferSize = 5*1024*1024 + val DefaultOffsetRetentionMs = 24*60*60*1000L + val DefaultOffsetsRetentionCheckIntervalMs = 600000L + val DefaultOffsetsTopicNumPartitions = 50 + val DefaultOffsetsTopicSegmentBytes = 100*1024*1024 + val DefaultOffsetsTopicReplicationFactor = 3.toShort + val DefaultOffsetsTopicCompressionCodec = NoCompressionCodec + val DefaultOffsetCommitTimeoutMs = 5000 + val DefaultOffsetCommitRequiredAcks = (-1).toShort +} \ No newline at end of file diff --git a/core/src/main/scala/kafka/coordinator/transaction/ProducerIdManager.scala b/core/src/main/scala/kafka/coordinator/transaction/ProducerIdManager.scala new file mode 100644 index 0000000..b5d419d --- /dev/null +++ b/core/src/main/scala/kafka/coordinator/transaction/ProducerIdManager.scala @@ -0,0 +1,242 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package kafka.coordinator.transaction + +import kafka.server.{BrokerToControllerChannelManager, ControllerRequestCompletionHandler} +import kafka.utils.Logging +import kafka.zk.{KafkaZkClient, ProducerIdBlockZNode} +import org.apache.kafka.clients.ClientResponse +import org.apache.kafka.common.KafkaException +import org.apache.kafka.common.message.AllocateProducerIdsRequestData +import org.apache.kafka.common.protocol.Errors +import org.apache.kafka.common.requests.{AllocateProducerIdsRequest, AllocateProducerIdsResponse} +import org.apache.kafka.server.common.ProducerIdsBlock + +import java.util.concurrent.{ArrayBlockingQueue, TimeUnit} +import java.util.concurrent.atomic.AtomicBoolean +import scala.util.{Failure, Success, Try} + +/** + * ProducerIdManager is the part of the transaction coordinator that provides ProducerIds in a unique way + * such that the same producerId will not be assigned twice across multiple transaction coordinators. + * + * ProducerIds are managed by the controller. When requesting a new range of IDs, we are guaranteed to receive + * a unique block. + */ + +object ProducerIdManager { + // Once we reach this percentage of PIDs consumed from the current block, trigger a fetch of the next block + val PidPrefetchThreshold = 0.90 + + // Creates a ProducerIdGenerate that directly interfaces with ZooKeeper, IBP < 3.0-IV0 + def zk(brokerId: Int, zkClient: KafkaZkClient): ZkProducerIdManager = { + new ZkProducerIdManager(brokerId, zkClient) + } + + // Creates a ProducerIdGenerate that uses AllocateProducerIds RPC, IBP >= 3.0-IV0 + def rpc(brokerId: Int, + brokerEpochSupplier: () => Long, + controllerChannel: BrokerToControllerChannelManager, + maxWaitMs: Int): RPCProducerIdManager = { + new RPCProducerIdManager(brokerId, brokerEpochSupplier, controllerChannel, maxWaitMs) + } +} + +trait ProducerIdManager { + def generateProducerId(): Long + def shutdown() : Unit = {} +} + +object ZkProducerIdManager { + def getNewProducerIdBlock(brokerId: Int, zkClient: KafkaZkClient, logger: Logging): ProducerIdsBlock = { + // Get or create the existing PID block from ZK and attempt to update it. 
We retry in a loop here since other + // brokers may be generating PID blocks during a rolling upgrade + var zkWriteComplete = false + while (!zkWriteComplete) { + // refresh current producerId block from zookeeper again + val (dataOpt, zkVersion) = zkClient.getDataAndVersion(ProducerIdBlockZNode.path) + + // generate the new producerId block + val newProducerIdBlock = dataOpt match { + case Some(data) => + val currProducerIdBlock = ProducerIdBlockZNode.parseProducerIdBlockData(data) + logger.debug(s"Read current producerId block $currProducerIdBlock, Zk path version $zkVersion") + + if (currProducerIdBlock.producerIdEnd > Long.MaxValue - ProducerIdsBlock.PRODUCER_ID_BLOCK_SIZE) { + // we have exhausted all producerIds (wow!), treat it as a fatal error + logger.fatal(s"Exhausted all producerIds as the next block's end producerId is will has exceeded long type limit (current block end producerId is ${currProducerIdBlock.producerIdEnd})") + throw new KafkaException("Have exhausted all producerIds.") + } + + new ProducerIdsBlock(brokerId, currProducerIdBlock.producerIdEnd + 1L, ProducerIdsBlock.PRODUCER_ID_BLOCK_SIZE) + case None => + logger.debug(s"There is no producerId block yet (Zk path version $zkVersion), creating the first block") + new ProducerIdsBlock(brokerId, 0L, ProducerIdsBlock.PRODUCER_ID_BLOCK_SIZE) + } + + val newProducerIdBlockData = ProducerIdBlockZNode.generateProducerIdBlockJson(newProducerIdBlock) + + // try to write the new producerId block into zookeeper + val (succeeded, version) = zkClient.conditionalUpdatePath(ProducerIdBlockZNode.path, newProducerIdBlockData, zkVersion, None) + zkWriteComplete = succeeded + + if (zkWriteComplete) { + logger.info(s"Acquired new producerId block $newProducerIdBlock by writing to Zk with path version $version") + return newProducerIdBlock + } + } + throw new IllegalStateException() + } +} + +class ZkProducerIdManager(brokerId: Int, + zkClient: KafkaZkClient) extends ProducerIdManager with Logging { + + this.logIdent = "[ZK ProducerId Manager " + brokerId + "]: " + + private var currentProducerIdBlock: ProducerIdsBlock = ProducerIdsBlock.EMPTY + private var nextProducerId: Long = _ + + // grab the first block of producerIds + this synchronized { + allocateNewProducerIdBlock() + nextProducerId = currentProducerIdBlock.producerIdStart + } + + private def allocateNewProducerIdBlock(): Unit = { + this synchronized { + currentProducerIdBlock = ZkProducerIdManager.getNewProducerIdBlock(brokerId, zkClient, this) + } + } + + def generateProducerId(): Long = { + this synchronized { + // grab a new block of producerIds if this block has been exhausted + if (nextProducerId > currentProducerIdBlock.producerIdEnd) { + allocateNewProducerIdBlock() + nextProducerId = currentProducerIdBlock.producerIdStart + } + nextProducerId += 1 + nextProducerId - 1 + } + } +} + +class RPCProducerIdManager(brokerId: Int, + brokerEpochSupplier: () => Long, + controllerChannel: BrokerToControllerChannelManager, + maxWaitMs: Int) extends ProducerIdManager with Logging { + + this.logIdent = "[RPC ProducerId Manager " + brokerId + "]: " + + private val nextProducerIdBlock = new ArrayBlockingQueue[Try[ProducerIdsBlock]](1) + private val requestInFlight = new AtomicBoolean(false) + private var currentProducerIdBlock: ProducerIdsBlock = ProducerIdsBlock.EMPTY + private var nextProducerId: Long = -1L + + override def generateProducerId(): Long = { + this synchronized { + if (nextProducerId == -1L) { + // Send an initial request to get the first block + 
maybeRequestNextBlock() + nextProducerId = 0L + } else { + nextProducerId += 1 + + // Check if we need to fetch the next block + if (nextProducerId >= (currentProducerIdBlock.producerIdStart + currentProducerIdBlock.producerIdLen * ProducerIdManager.PidPrefetchThreshold)) { + maybeRequestNextBlock() + } + } + + // If we've exhausted the current block, grab the next block (waiting if necessary) + if (nextProducerId > currentProducerIdBlock.producerIdEnd) { + val block = nextProducerIdBlock.poll(maxWaitMs, TimeUnit.MILLISECONDS) + if (block == null) { + throw Errors.REQUEST_TIMED_OUT.exception("Timed out waiting for next producer ID block") + } else { + block match { + case Success(nextBlock) => + currentProducerIdBlock = nextBlock + nextProducerId = currentProducerIdBlock.producerIdStart + case Failure(t) => throw t + } + } + } + nextProducerId + } + } + + + private def maybeRequestNextBlock(): Unit = { + if (nextProducerIdBlock.isEmpty && requestInFlight.compareAndSet(false, true)) { + sendRequest() + } + } + + private[transaction] def sendRequest(): Unit = { + val message = new AllocateProducerIdsRequestData() + .setBrokerEpoch(brokerEpochSupplier.apply()) + .setBrokerId(brokerId) + + val request = new AllocateProducerIdsRequest.Builder(message) + debug("Requesting next Producer ID block") + controllerChannel.sendRequest(request, new ControllerRequestCompletionHandler() { + override def onComplete(response: ClientResponse): Unit = { + val message = response.responseBody().asInstanceOf[AllocateProducerIdsResponse] + handleAllocateProducerIdsResponse(message) + } + + override def onTimeout(): Unit = handleTimeout() + }) + } + + private[transaction] def handleAllocateProducerIdsResponse(response: AllocateProducerIdsResponse): Unit = { + requestInFlight.set(false) + val data = response.data + Errors.forCode(data.errorCode()) match { + case Errors.NONE => + debug(s"Got next producer ID block from controller $data") + // Do some sanity checks on the response + if (data.producerIdStart() < currentProducerIdBlock.producerIdEnd) { + nextProducerIdBlock.put(Failure(new KafkaException( + s"Producer ID block is not monotonic with current block: current=$currentProducerIdBlock response=$data"))) + } else if (data.producerIdStart() < 0 || data.producerIdLen() < 0 || data.producerIdStart() > Long.MaxValue - data.producerIdLen()) { + nextProducerIdBlock.put(Failure(new KafkaException(s"Producer ID block includes invalid ID range: $data"))) + } else { + nextProducerIdBlock.put( + Success(new ProducerIdsBlock(brokerId, data.producerIdStart(), data.producerIdLen()))) + } + case Errors.STALE_BROKER_EPOCH => + warn("Our broker epoch was stale, trying again.") + maybeRequestNextBlock() + case Errors.BROKER_ID_NOT_REGISTERED => + warn("Our broker ID is not yet known by the controller, trying again.") + maybeRequestNextBlock() + case e: Errors => + warn("Had an unknown error from the controller, giving up.") + nextProducerIdBlock.put(Failure(e.exception())) + } + } + + private[transaction] def handleTimeout(): Unit = { + warn("Timed out when requesting AllocateProducerIds from the controller.") + requestInFlight.set(false) + nextProducerIdBlock.put(Failure(Errors.REQUEST_TIMED_OUT.exception)) + maybeRequestNextBlock() + } +} diff --git a/core/src/main/scala/kafka/coordinator/transaction/TransactionCoordinator.scala b/core/src/main/scala/kafka/coordinator/transaction/TransactionCoordinator.scala new file mode 100644 index 0000000..78983c1 --- /dev/null +++ 
b/core/src/main/scala/kafka/coordinator/transaction/TransactionCoordinator.scala @@ -0,0 +1,694 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package kafka.coordinator.transaction + +import java.util.Properties +import java.util.concurrent.atomic.AtomicBoolean +import kafka.server.{KafkaConfig, MetadataCache, ReplicaManager, RequestLocal} +import kafka.utils.{Logging, Scheduler} +import org.apache.kafka.common.TopicPartition +import org.apache.kafka.common.internals.Topic +import org.apache.kafka.common.message.{DescribeTransactionsResponseData, ListTransactionsResponseData} +import org.apache.kafka.common.metrics.Metrics +import org.apache.kafka.common.protocol.Errors +import org.apache.kafka.common.record.RecordBatch +import org.apache.kafka.common.requests.TransactionResult +import org.apache.kafka.common.utils.{LogContext, ProducerIdAndEpoch, Time} + +object TransactionCoordinator { + + def apply(config: KafkaConfig, + replicaManager: ReplicaManager, + scheduler: Scheduler, + createProducerIdGenerator: () => ProducerIdManager, + metrics: Metrics, + metadataCache: MetadataCache, + time: Time): TransactionCoordinator = { + + val txnConfig = TransactionConfig(config.transactionalIdExpirationMs, + config.transactionMaxTimeoutMs, + config.transactionTopicPartitions, + config.transactionTopicReplicationFactor, + config.transactionTopicSegmentBytes, + config.transactionsLoadBufferSize, + config.transactionTopicMinISR, + config.transactionAbortTimedOutTransactionCleanupIntervalMs, + config.transactionRemoveExpiredTransactionalIdCleanupIntervalMs, + config.requestTimeoutMs) + + val txnStateManager = new TransactionStateManager(config.brokerId, scheduler, replicaManager, txnConfig, + time, metrics) + + val logContext = new LogContext(s"[TransactionCoordinator id=${config.brokerId}] ") + val txnMarkerChannelManager = TransactionMarkerChannelManager(config, metrics, metadataCache, txnStateManager, + time, logContext) + + new TransactionCoordinator(config.brokerId, txnConfig, scheduler, createProducerIdGenerator, txnStateManager, txnMarkerChannelManager, + time, logContext) + } + + private def initTransactionError(error: Errors): InitProducerIdResult = { + InitProducerIdResult(RecordBatch.NO_PRODUCER_ID, RecordBatch.NO_PRODUCER_EPOCH, error) + } + + private def initTransactionMetadata(txnMetadata: TxnTransitMetadata): InitProducerIdResult = { + InitProducerIdResult(txnMetadata.producerId, txnMetadata.producerEpoch, Errors.NONE) + } +} + +/** + * Transaction coordinator handles message transactions sent by producers and communicate with brokers + * to update ongoing transaction's status. + * + * Each Kafka server instantiates a transaction coordinator which is responsible for a set of + * producers. 
Producers with specific transactional ids are assigned to their corresponding coordinators; + * Producers with no specific transactional id may talk to a random broker as their coordinators. + */ +class TransactionCoordinator(brokerId: Int, + txnConfig: TransactionConfig, + scheduler: Scheduler, + createProducerIdManager: () => ProducerIdManager, + txnManager: TransactionStateManager, + txnMarkerChannelManager: TransactionMarkerChannelManager, + time: Time, + logContext: LogContext) extends Logging { + this.logIdent = logContext.logPrefix + + import TransactionCoordinator._ + + type InitProducerIdCallback = InitProducerIdResult => Unit + type AddPartitionsCallback = Errors => Unit + type EndTxnCallback = Errors => Unit + type ApiResult[T] = Either[Errors, T] + + /* Active flag of the coordinator */ + private val isActive = new AtomicBoolean(false) + + val producerIdManager = createProducerIdManager() + + def handleInitProducerId(transactionalId: String, + transactionTimeoutMs: Int, + expectedProducerIdAndEpoch: Option[ProducerIdAndEpoch], + responseCallback: InitProducerIdCallback, + requestLocal: RequestLocal = RequestLocal.NoCaching): Unit = { + + if (transactionalId == null) { + // if the transactional id is null, then always blindly accept the request + // and return a new producerId from the producerId manager + val producerId = producerIdManager.generateProducerId() + responseCallback(InitProducerIdResult(producerId, producerEpoch = 0, Errors.NONE)) + } else if (transactionalId.isEmpty) { + // if transactional id is empty then return error as invalid request. This is + // to make TransactionCoordinator's behavior consistent with producer client + responseCallback(initTransactionError(Errors.INVALID_REQUEST)) + } else if (!txnManager.validateTransactionTimeoutMs(transactionTimeoutMs)) { + // check transactionTimeoutMs is not larger than the broker configured maximum allowed value + responseCallback(initTransactionError(Errors.INVALID_TRANSACTION_TIMEOUT)) + } else { + val coordinatorEpochAndMetadata = txnManager.getTransactionState(transactionalId).flatMap { + case None => + val producerId = producerIdManager.generateProducerId() + val createdMetadata = new TransactionMetadata(transactionalId = transactionalId, + producerId = producerId, + lastProducerId = RecordBatch.NO_PRODUCER_ID, + producerEpoch = RecordBatch.NO_PRODUCER_EPOCH, + lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH, + txnTimeoutMs = transactionTimeoutMs, + state = Empty, + topicPartitions = collection.mutable.Set.empty[TopicPartition], + txnLastUpdateTimestamp = time.milliseconds()) + txnManager.putTransactionStateIfNotExists(createdMetadata) + + case Some(epochAndTxnMetadata) => Right(epochAndTxnMetadata) + } + + val result: ApiResult[(Int, TxnTransitMetadata)] = coordinatorEpochAndMetadata.flatMap { + existingEpochAndMetadata => + val coordinatorEpoch = existingEpochAndMetadata.coordinatorEpoch + val txnMetadata = existingEpochAndMetadata.transactionMetadata + + txnMetadata.inLock { + prepareInitProducerIdTransit(transactionalId, transactionTimeoutMs, coordinatorEpoch, txnMetadata, + expectedProducerIdAndEpoch) + } + } + + result match { + case Left(error) => + responseCallback(initTransactionError(error)) + + case Right((coordinatorEpoch, newMetadata)) => + if (newMetadata.txnState == PrepareEpochFence) { + // abort the ongoing transaction and then return CONCURRENT_TRANSACTIONS to let client wait and retry + def sendRetriableErrorCallback(error: Errors): Unit = { + if (error != Errors.NONE) { + 
responseCallback(initTransactionError(error)) + } else { + responseCallback(initTransactionError(Errors.CONCURRENT_TRANSACTIONS)) + } + } + + endTransaction(transactionalId, + newMetadata.producerId, + newMetadata.producerEpoch, + TransactionResult.ABORT, + isFromClient = false, + sendRetriableErrorCallback, + requestLocal) + } else { + def sendPidResponseCallback(error: Errors): Unit = { + if (error == Errors.NONE) { + info(s"Initialized transactionalId $transactionalId with producerId ${newMetadata.producerId} and producer " + + s"epoch ${newMetadata.producerEpoch} on partition " + + s"${Topic.TRANSACTION_STATE_TOPIC_NAME}-${txnManager.partitionFor(transactionalId)}") + responseCallback(initTransactionMetadata(newMetadata)) + } else { + info(s"Returning $error error code to client for $transactionalId's InitProducerId request") + responseCallback(initTransactionError(error)) + } + } + + txnManager.appendTransactionToLog(transactionalId, coordinatorEpoch, newMetadata, + sendPidResponseCallback, requestLocal = requestLocal) + } + } + } + } + + private def prepareInitProducerIdTransit(transactionalId: String, + transactionTimeoutMs: Int, + coordinatorEpoch: Int, + txnMetadata: TransactionMetadata, + expectedProducerIdAndEpoch: Option[ProducerIdAndEpoch]): ApiResult[(Int, TxnTransitMetadata)] = { + + def isValidProducerId(producerIdAndEpoch: ProducerIdAndEpoch): Boolean = { + // If a producer ID and epoch are provided by the request, fence the producer unless one of the following is true: + // 1. The producer epoch is equal to -1, which implies that the metadata was just created. This is the case of a + // producer recovering from an UNKNOWN_PRODUCER_ID error, and it is safe to return the newly-generated + // producer ID. + // 2. The expected producer ID matches the ID in current metadata (the epoch will be checked when we try to + // increment it) + // 3. 
The expected producer ID matches the previous one and the expected epoch is exhausted, in which case this + // could be a retry after a valid epoch bump that the producer never received the response for + txnMetadata.producerEpoch == RecordBatch.NO_PRODUCER_EPOCH || + producerIdAndEpoch.producerId == txnMetadata.producerId || + (producerIdAndEpoch.producerId == txnMetadata.lastProducerId && TransactionMetadata.isEpochExhausted(producerIdAndEpoch.epoch)) + } + + if (txnMetadata.pendingTransitionInProgress) { + // return a retriable exception to let the client backoff and retry + Left(Errors.CONCURRENT_TRANSACTIONS) + } + else if (!expectedProducerIdAndEpoch.forall(isValidProducerId)) { + Left(Errors.PRODUCER_FENCED) + } else { + // caller should have synchronized on txnMetadata already + txnMetadata.state match { + case PrepareAbort | PrepareCommit => + // reply to client and let it backoff and retry + Left(Errors.CONCURRENT_TRANSACTIONS) + + case CompleteAbort | CompleteCommit | Empty => + val transitMetadataResult = + // If the epoch is exhausted and the expected epoch (if provided) matches it, generate a new producer ID + if (txnMetadata.isProducerEpochExhausted && + expectedProducerIdAndEpoch.forall(_.epoch == txnMetadata.producerEpoch)) { + val newProducerId = producerIdManager.generateProducerId() + Right(txnMetadata.prepareProducerIdRotation(newProducerId, transactionTimeoutMs, time.milliseconds(), + expectedProducerIdAndEpoch.isDefined)) + } else { + txnMetadata.prepareIncrementProducerEpoch(transactionTimeoutMs, expectedProducerIdAndEpoch.map(_.epoch), + time.milliseconds()) + } + + transitMetadataResult match { + case Right(transitMetadata) => Right((coordinatorEpoch, transitMetadata)) + case Left(err) => Left(err) + } + + case Ongoing => + // indicate to abort the current ongoing txn first. Note that this epoch is never returned to the + // user. We will abort the ongoing transaction and return CONCURRENT_TRANSACTIONS to the client. + // This forces the client to retry, which will ensure that the epoch is bumped a second time. In + // particular, if fencing the current producer exhausts the available epochs for the current producerId, + // then when the client retries, we will generate a new producerId. + Right(coordinatorEpoch, txnMetadata.prepareFenceProducerEpoch()) + + case Dead | PrepareEpochFence => + val errorMsg = s"Found transactionalId $transactionalId with state ${txnMetadata.state}. " + + s"This is illegal as we should never have transitioned to this state." 
+ fatal(errorMsg) + throw new IllegalStateException(errorMsg) + } + } + } + + def handleListTransactions( + filteredProducerIds: Set[Long], + filteredStates: Set[String] + ): ListTransactionsResponseData = { + if (!isActive.get()) { + new ListTransactionsResponseData().setErrorCode(Errors.COORDINATOR_NOT_AVAILABLE.code) + } else { + txnManager.listTransactionStates(filteredProducerIds, filteredStates) + } + } + + def handleDescribeTransactions( + transactionalId: String + ): DescribeTransactionsResponseData.TransactionState = { + if (transactionalId == null) { + throw new IllegalArgumentException("Invalid null transactionalId") + } + + val transactionState = new DescribeTransactionsResponseData.TransactionState() + .setTransactionalId(transactionalId) + + if (!isActive.get()) { + transactionState.setErrorCode(Errors.COORDINATOR_NOT_AVAILABLE.code) + } else if (transactionalId.isEmpty) { + transactionState.setErrorCode(Errors.INVALID_REQUEST.code) + } else { + txnManager.getTransactionState(transactionalId) match { + case Left(error) => + transactionState.setErrorCode(error.code) + case Right(None) => + transactionState.setErrorCode(Errors.TRANSACTIONAL_ID_NOT_FOUND.code) + case Right(Some(coordinatorEpochAndMetadata)) => + val txnMetadata = coordinatorEpochAndMetadata.transactionMetadata + txnMetadata.inLock { + if (txnMetadata.state == Dead) { + // The transaction state is being expired, so ignore it + transactionState.setErrorCode(Errors.TRANSACTIONAL_ID_NOT_FOUND.code) + } else { + txnMetadata.topicPartitions.foreach { topicPartition => + var topicData = transactionState.topics.find(topicPartition.topic) + if (topicData == null) { + topicData = new DescribeTransactionsResponseData.TopicData() + .setTopic(topicPartition.topic) + transactionState.topics.add(topicData) + } + topicData.partitions.add(topicPartition.partition) + } + + transactionState + .setErrorCode(Errors.NONE.code) + .setProducerId(txnMetadata.producerId) + .setProducerEpoch(txnMetadata.producerEpoch) + .setTransactionState(txnMetadata.state.name) + .setTransactionTimeoutMs(txnMetadata.txnTimeoutMs) + .setTransactionStartTimeMs(txnMetadata.txnStartTimestamp) + } + } + } + } + } + + def handleAddPartitionsToTransaction(transactionalId: String, + producerId: Long, + producerEpoch: Short, + partitions: collection.Set[TopicPartition], + responseCallback: AddPartitionsCallback, + requestLocal: RequestLocal = RequestLocal.NoCaching): Unit = { + if (transactionalId == null || transactionalId.isEmpty) { + debug(s"Returning ${Errors.INVALID_REQUEST} error code to client for $transactionalId's AddPartitions request") + responseCallback(Errors.INVALID_REQUEST) + } else { + // try to update the transaction metadata and append the updated metadata to txn log; + // if there is no such metadata treat it as invalid producerId mapping error. 
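+      // The checks below are ordered: producer id first, then producer epoch, then pending or conflicting transaction state.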
+ val result: ApiResult[(Int, TxnTransitMetadata)] = txnManager.getTransactionState(transactionalId).flatMap { + case None => Left(Errors.INVALID_PRODUCER_ID_MAPPING) + + case Some(epochAndMetadata) => + val coordinatorEpoch = epochAndMetadata.coordinatorEpoch + val txnMetadata = epochAndMetadata.transactionMetadata + + // generate the new transaction metadata with added partitions + txnMetadata.inLock { + if (txnMetadata.producerId != producerId) { + Left(Errors.INVALID_PRODUCER_ID_MAPPING) + } else if (txnMetadata.producerEpoch != producerEpoch) { + Left(Errors.PRODUCER_FENCED) + } else if (txnMetadata.pendingTransitionInProgress) { + // return a retriable exception to let the client backoff and retry + Left(Errors.CONCURRENT_TRANSACTIONS) + } else if (txnMetadata.state == PrepareCommit || txnMetadata.state == PrepareAbort) { + Left(Errors.CONCURRENT_TRANSACTIONS) + } else if (txnMetadata.state == Ongoing && partitions.subsetOf(txnMetadata.topicPartitions)) { + // this is an optimization: if the partitions are already in the metadata reply OK immediately + Left(Errors.NONE) + } else { + Right(coordinatorEpoch, txnMetadata.prepareAddPartitions(partitions.toSet, time.milliseconds())) + } + } + } + + result match { + case Left(err) => + debug(s"Returning $err error code to client for $transactionalId's AddPartitions request") + responseCallback(err) + + case Right((coordinatorEpoch, newMetadata)) => + txnManager.appendTransactionToLog(transactionalId, coordinatorEpoch, newMetadata, + responseCallback, requestLocal = requestLocal) + } + } + } + + /** + * Load state from the given partition and begin handling requests for groups which map to this partition. + * + * @param txnTopicPartitionId The partition that we are now leading + * @param coordinatorEpoch The partition coordinator (or leader) epoch from the received LeaderAndIsr request + */ + def onElection(txnTopicPartitionId: Int, coordinatorEpoch: Int): Unit = { + info(s"Elected as the txn coordinator for partition $txnTopicPartitionId at epoch $coordinatorEpoch") + // The operations performed during immigration must be resilient to any previous errors we saw or partial state we + // left off during the unloading phase. Ensure we remove all associated state for this partition before we continue + // loading it. + txnMarkerChannelManager.removeMarkersForTxnTopicPartition(txnTopicPartitionId) + + // Now load the partition. + txnManager.loadTransactionsForTxnTopicPartition(txnTopicPartitionId, coordinatorEpoch, + txnMarkerChannelManager.addTxnMarkersToSend) + } + + /** + * Clear coordinator caches for the given partition after giving up leadership. 
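+   * Any transaction markers still queued for this partition are discarded as well.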
+ * + * @param txnTopicPartitionId The partition that we are no longer leading + * @param coordinatorEpoch The partition coordinator (or leader) epoch, which may be absent if we + * are resigning after receiving a StopReplica request from the controller + */ + def onResignation(txnTopicPartitionId: Int, coordinatorEpoch: Option[Int]): Unit = { + info(s"Resigned as the txn coordinator for partition $txnTopicPartitionId at epoch $coordinatorEpoch") + coordinatorEpoch match { + case Some(epoch) => + txnManager.removeTransactionsForTxnTopicPartition(txnTopicPartitionId, epoch) + case None => + txnManager.removeTransactionsForTxnTopicPartition(txnTopicPartitionId) + } + txnMarkerChannelManager.removeMarkersForTxnTopicPartition(txnTopicPartitionId) + } + + private def logInvalidStateTransitionAndReturnError(transactionalId: String, + transactionState: TransactionState, + transactionResult: TransactionResult) = { + debug(s"TransactionalId: $transactionalId's state is $transactionState, but received transaction " + + s"marker result to send: $transactionResult") + Left(Errors.INVALID_TXN_STATE) + } + + def handleEndTransaction(transactionalId: String, + producerId: Long, + producerEpoch: Short, + txnMarkerResult: TransactionResult, + responseCallback: EndTxnCallback, + requestLocal: RequestLocal = RequestLocal.NoCaching): Unit = { + endTransaction(transactionalId, + producerId, + producerEpoch, + txnMarkerResult, + isFromClient = true, + responseCallback, + requestLocal) + } + + private def endTransaction(transactionalId: String, + producerId: Long, + producerEpoch: Short, + txnMarkerResult: TransactionResult, + isFromClient: Boolean, + responseCallback: EndTxnCallback, + requestLocal: RequestLocal): Unit = { + var isEpochFence = false + if (transactionalId == null || transactionalId.isEmpty) + responseCallback(Errors.INVALID_REQUEST) + else { + val preAppendResult: ApiResult[(Int, TxnTransitMetadata)] = txnManager.getTransactionState(transactionalId).flatMap { + case None => + Left(Errors.INVALID_PRODUCER_ID_MAPPING) + + case Some(epochAndTxnMetadata) => + val txnMetadata = epochAndTxnMetadata.transactionMetadata + val coordinatorEpoch = epochAndTxnMetadata.coordinatorEpoch + + txnMetadata.inLock { + if (txnMetadata.producerId != producerId) + Left(Errors.INVALID_PRODUCER_ID_MAPPING) + // Strict equality is enforced on the client side requests, as they shouldn't bump the producer epoch. + else if ((isFromClient && producerEpoch != txnMetadata.producerEpoch) || producerEpoch < txnMetadata.producerEpoch) + Left(Errors.PRODUCER_FENCED) + else if (txnMetadata.pendingTransitionInProgress && txnMetadata.pendingState.get != PrepareEpochFence) + Left(Errors.CONCURRENT_TRANSACTIONS) + else txnMetadata.state match { + case Ongoing => + val nextState = if (txnMarkerResult == TransactionResult.COMMIT) + PrepareCommit + else + PrepareAbort + + if (nextState == PrepareAbort && txnMetadata.pendingState.contains(PrepareEpochFence)) { + // We should clear the pending state to make way for the transition to PrepareAbort and also bump + // the epoch in the transaction metadata we are about to append. 
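+                    // The bumped epoch is written out with the PrepareAbort transition, so requests from a producer
+                    // still using the old epoch will fail the epoch checks above with PRODUCER_FENCED.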
+ isEpochFence = true + txnMetadata.pendingState = None + txnMetadata.producerEpoch = producerEpoch + txnMetadata.lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH + } + + Right(coordinatorEpoch, txnMetadata.prepareAbortOrCommit(nextState, time.milliseconds())) + case CompleteCommit => + if (txnMarkerResult == TransactionResult.COMMIT) + Left(Errors.NONE) + else + logInvalidStateTransitionAndReturnError(transactionalId, txnMetadata.state, txnMarkerResult) + case CompleteAbort => + if (txnMarkerResult == TransactionResult.ABORT) + Left(Errors.NONE) + else + logInvalidStateTransitionAndReturnError(transactionalId, txnMetadata.state, txnMarkerResult) + case PrepareCommit => + if (txnMarkerResult == TransactionResult.COMMIT) + Left(Errors.CONCURRENT_TRANSACTIONS) + else + logInvalidStateTransitionAndReturnError(transactionalId, txnMetadata.state, txnMarkerResult) + case PrepareAbort => + if (txnMarkerResult == TransactionResult.ABORT) + Left(Errors.CONCURRENT_TRANSACTIONS) + else + logInvalidStateTransitionAndReturnError(transactionalId, txnMetadata.state, txnMarkerResult) + case Empty => + logInvalidStateTransitionAndReturnError(transactionalId, txnMetadata.state, txnMarkerResult) + case Dead | PrepareEpochFence => + val errorMsg = s"Found transactionalId $transactionalId with state ${txnMetadata.state}. " + + s"This is illegal as we should never have transitioned to this state." + fatal(errorMsg) + throw new IllegalStateException(errorMsg) + + } + } + } + + preAppendResult match { + case Left(err) => + debug(s"Aborting append of $txnMarkerResult to transaction log with coordinator and returning $err error to client for $transactionalId's EndTransaction request") + responseCallback(err) + + case Right((coordinatorEpoch, newMetadata)) => + def sendTxnMarkersCallback(error: Errors): Unit = { + if (error == Errors.NONE) { + val preSendResult: ApiResult[(TransactionMetadata, TxnTransitMetadata)] = txnManager.getTransactionState(transactionalId).flatMap { + case None => + val errorMsg = s"The coordinator still owns the transaction partition for $transactionalId, but there is " + + s"no metadata in the cache; this is not expected" + fatal(errorMsg) + throw new IllegalStateException(errorMsg) + + case Some(epochAndMetadata) => + if (epochAndMetadata.coordinatorEpoch == coordinatorEpoch) { + val txnMetadata = epochAndMetadata.transactionMetadata + txnMetadata.inLock { + if (txnMetadata.producerId != producerId) + Left(Errors.INVALID_PRODUCER_ID_MAPPING) + else if (txnMetadata.producerEpoch != producerEpoch) + Left(Errors.PRODUCER_FENCED) + else if (txnMetadata.pendingTransitionInProgress) + Left(Errors.CONCURRENT_TRANSACTIONS) + else txnMetadata.state match { + case Empty| Ongoing | CompleteCommit | CompleteAbort => + logInvalidStateTransitionAndReturnError(transactionalId, txnMetadata.state, txnMarkerResult) + case PrepareCommit => + if (txnMarkerResult != TransactionResult.COMMIT) + logInvalidStateTransitionAndReturnError(transactionalId, txnMetadata.state, txnMarkerResult) + else + Right(txnMetadata, txnMetadata.prepareComplete(time.milliseconds())) + case PrepareAbort => + if (txnMarkerResult != TransactionResult.ABORT) + logInvalidStateTransitionAndReturnError(transactionalId, txnMetadata.state, txnMarkerResult) + else + Right(txnMetadata, txnMetadata.prepareComplete(time.milliseconds())) + case Dead | PrepareEpochFence => + val errorMsg = s"Found transactionalId $transactionalId with state ${txnMetadata.state}. " + + s"This is illegal as we should never have transitioned to this state." 
+ fatal(errorMsg) + throw new IllegalStateException(errorMsg) + + } + } + } else { + debug(s"The transaction coordinator epoch has changed to ${epochAndMetadata.coordinatorEpoch} after $txnMarkerResult was " + + s"successfully appended to the log for $transactionalId with old epoch $coordinatorEpoch") + Left(Errors.NOT_COORDINATOR) + } + } + + preSendResult match { + case Left(err) => + info(s"Aborting sending of transaction markers after appended $txnMarkerResult to transaction log and returning $err error to client for $transactionalId's EndTransaction request") + responseCallback(err) + + case Right((txnMetadata, newPreSendMetadata)) => + // we can respond to the client immediately and continue to write the txn markers if + // the log append was successful + responseCallback(Errors.NONE) + + txnMarkerChannelManager.addTxnMarkersToSend(coordinatorEpoch, txnMarkerResult, txnMetadata, newPreSendMetadata) + } + } else { + info(s"Aborting sending of transaction markers and returning $error error to client for $transactionalId's EndTransaction request of $txnMarkerResult, " + + s"since appending $newMetadata to transaction log with coordinator epoch $coordinatorEpoch failed") + + if (isEpochFence) { + txnManager.getTransactionState(transactionalId).foreach { + case None => + warn(s"The coordinator still owns the transaction partition for $transactionalId, but there is " + + s"no metadata in the cache; this is not expected") + + case Some(epochAndMetadata) => + if (epochAndMetadata.coordinatorEpoch == coordinatorEpoch) { + // This was attempted epoch fence that failed, so mark this state on the metadata + epochAndMetadata.transactionMetadata.hasFailedEpochFence = true + warn(s"The coordinator failed to write an epoch fence transition for producer $transactionalId to the transaction log " + + s"with error $error. 
The epoch was increased to ${newMetadata.producerEpoch} but not returned to the client") + } + } + } + + responseCallback(error) + } + } + + txnManager.appendTransactionToLog(transactionalId, coordinatorEpoch, newMetadata, + sendTxnMarkersCallback, requestLocal = requestLocal) + } + } + } + + def transactionTopicConfigs: Properties = txnManager.transactionTopicConfigs + + def partitionFor(transactionalId: String): Int = txnManager.partitionFor(transactionalId) + + private def onEndTransactionComplete(txnIdAndPidEpoch: TransactionalIdAndProducerIdEpoch)(error: Errors): Unit = { + error match { + case Errors.NONE => + info("Completed rollback of ongoing transaction for transactionalId " + + s"${txnIdAndPidEpoch.transactionalId} due to timeout") + + case error@(Errors.INVALID_PRODUCER_ID_MAPPING | + Errors.PRODUCER_FENCED | + Errors.CONCURRENT_TRANSACTIONS) => + debug(s"Rollback of ongoing transaction for transactionalId ${txnIdAndPidEpoch.transactionalId} " + + s"has been cancelled due to error $error") + + case error => + warn(s"Rollback of ongoing transaction for transactionalId ${txnIdAndPidEpoch.transactionalId} " + + s"failed due to error $error") + } + } + + private[transaction] def abortTimedOutTransactions(onComplete: TransactionalIdAndProducerIdEpoch => EndTxnCallback): Unit = { + + txnManager.timedOutTransactions().foreach { txnIdAndPidEpoch => + txnManager.getTransactionState(txnIdAndPidEpoch.transactionalId).foreach { + case None => + error(s"Could not find transaction metadata when trying to timeout transaction for $txnIdAndPidEpoch") + + case Some(epochAndTxnMetadata) => + val txnMetadata = epochAndTxnMetadata.transactionMetadata + val transitMetadataOpt = txnMetadata.inLock { + if (txnMetadata.producerId != txnIdAndPidEpoch.producerId) { + error(s"Found incorrect producerId when expiring transactionalId: ${txnIdAndPidEpoch.transactionalId}. " + + s"Expected producerId: ${txnIdAndPidEpoch.producerId}. Found producerId: " + + s"${txnMetadata.producerId}") + None + } else if (txnMetadata.pendingTransitionInProgress) { + debug(s"Skipping abort of timed out transaction $txnIdAndPidEpoch since there is a " + + "pending state transition") + None + } else { + Some(txnMetadata.prepareFenceProducerEpoch()) + } + } + + transitMetadataOpt.foreach { txnTransitMetadata => + endTransaction(txnMetadata.transactionalId, + txnTransitMetadata.producerId, + txnTransitMetadata.producerEpoch, + TransactionResult.ABORT, + isFromClient = false, + onComplete(txnIdAndPidEpoch), + RequestLocal.NoCaching) + } + } + } + } + + /** + * Startup logic executed at the same time when the server starts up. + */ + def startup(retrieveTransactionTopicPartitionCount: () => Int, enableTransactionalIdExpiration: Boolean = true): Unit = { + info("Starting up.") + scheduler.startup() + scheduler.schedule("transaction-abort", + () => abortTimedOutTransactions(onEndTransactionComplete), + txnConfig.abortTimedOutTransactionsIntervalMs, + txnConfig.abortTimedOutTransactionsIntervalMs + ) + txnManager.startup(retrieveTransactionTopicPartitionCount, enableTransactionalIdExpiration) + txnMarkerChannelManager.start() + isActive.set(true) + + info("Startup complete.") + } + + /** + * Shutdown logic executed at the same time when server shuts down. + * Ordering of actions should be reversed from the startup process. 
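+   * (the coordinator is first marked inactive, then the scheduler, producer id manager,
+   * transaction state manager and marker channel manager are shut down).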
+ */ + def shutdown(): Unit = { + info("Shutting down.") + isActive.set(false) + scheduler.shutdown() + producerIdManager.shutdown() + txnManager.shutdown() + txnMarkerChannelManager.shutdown() + info("Shutdown complete.") + } +} + +case class InitProducerIdResult(producerId: Long, producerEpoch: Short, error: Errors) diff --git a/core/src/main/scala/kafka/coordinator/transaction/TransactionLog.scala b/core/src/main/scala/kafka/coordinator/transaction/TransactionLog.scala new file mode 100644 index 0000000..cb501f7 --- /dev/null +++ b/core/src/main/scala/kafka/coordinator/transaction/TransactionLog.scala @@ -0,0 +1,191 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package kafka.coordinator.transaction + +import java.io.PrintStream +import java.nio.ByteBuffer +import java.nio.charset.StandardCharsets + +import kafka.internals.generated.{TransactionLogKey, TransactionLogValue} +import org.apache.kafka.clients.consumer.ConsumerRecord +import org.apache.kafka.common.protocol.{ByteBufferAccessor, MessageUtil} +import org.apache.kafka.common.record.{CompressionType, Record, RecordBatch} +import org.apache.kafka.common.{MessageFormatter, TopicPartition} + +import scala.collection.mutable +import scala.jdk.CollectionConverters._ + +/** + * Messages stored for the transaction topic represent the producer id and transactional status of the corresponding + * transactional id, which have versions for both the key and value fields. Key and value + * versions are used to evolve the message formats: + * + * key version 0: [transactionalId] + * -> value version 0: [producer_id, producer_epoch, expire_timestamp, status, [topic, [partition] ], timestamp] + */ +object TransactionLog { + + // log-level config default values and enforced values + val DefaultNumPartitions: Int = 50 + val DefaultSegmentBytes: Int = 100 * 1024 * 1024 + val DefaultReplicationFactor: Short = 3.toShort + val DefaultMinInSyncReplicas: Int = 2 + val DefaultLoadBufferSize: Int = 5 * 1024 * 1024 + + // enforce always using + // 1. cleanup policy = compact + // 2. compression = none + // 3. unclean leader election = disabled + // 4. 
required acks = -1 when writing + val EnforcedCompressionType: CompressionType = CompressionType.NONE + val EnforcedRequiredAcks: Short = (-1).toShort + + /** + * Generates the bytes for transaction log message key + * + * @return key bytes + */ + private[transaction] def keyToBytes(transactionalId: String): Array[Byte] = { + MessageUtil.toVersionPrefixedBytes(TransactionLogKey.HIGHEST_SUPPORTED_VERSION, + new TransactionLogKey().setTransactionalId(transactionalId)) + } + + /** + * Generates the payload bytes for transaction log message value + * + * @return value payload bytes + */ + private[transaction] def valueToBytes(txnMetadata: TxnTransitMetadata): Array[Byte] = { + if (txnMetadata.txnState == Empty && txnMetadata.topicPartitions.nonEmpty) + throw new IllegalStateException(s"Transaction is not expected to have any partitions since its state is ${txnMetadata.txnState}: $txnMetadata") + + val transactionPartitions = if (txnMetadata.txnState == Empty) null + else txnMetadata.topicPartitions + .groupBy(_.topic) + .map { case (topic, partitions) => + new TransactionLogValue.PartitionsSchema() + .setTopic(topic) + .setPartitionIds(partitions.map(tp => Integer.valueOf(tp.partition)).toList.asJava) + }.toList.asJava + + MessageUtil.toVersionPrefixedBytes(TransactionLogValue.HIGHEST_SUPPORTED_VERSION, + new TransactionLogValue() + .setProducerId(txnMetadata.producerId) + .setProducerEpoch(txnMetadata.producerEpoch) + .setTransactionTimeoutMs(txnMetadata.txnTimeoutMs) + .setTransactionStatus(txnMetadata.txnState.id) + .setTransactionLastUpdateTimestampMs(txnMetadata.txnLastUpdateTimestamp) + .setTransactionStartTimestampMs(txnMetadata.txnStartTimestamp) + .setTransactionPartitions(transactionPartitions)) + } + + /** + * Decodes the transaction log messages' key + * + * @return the key + */ + def readTxnRecordKey(buffer: ByteBuffer): TxnKey = { + val version = buffer.getShort + if (version >= TransactionLogKey.LOWEST_SUPPORTED_VERSION && version <= TransactionLogKey.HIGHEST_SUPPORTED_VERSION) { + val value = new TransactionLogKey(new ByteBufferAccessor(buffer), version) + TxnKey( + version = version, + transactionalId = value.transactionalId + ) + } else throw new IllegalStateException(s"Unknown version $version from the transaction log message") + } + + /** + * Decodes the transaction log messages' payload and retrieves the transaction metadata from it + * + * @return a transaction metadata object from the message + */ + def readTxnRecordValue(transactionalId: String, buffer: ByteBuffer): Option[TransactionMetadata] = { + // tombstone + if (buffer == null) None + else { + val version = buffer.getShort + if (version >= TransactionLogValue.LOWEST_SUPPORTED_VERSION && version <= TransactionLogValue.HIGHEST_SUPPORTED_VERSION) { + val value = new TransactionLogValue(new ByteBufferAccessor(buffer), version) + val transactionMetadata = new TransactionMetadata( + transactionalId = transactionalId, + producerId = value.producerId, + lastProducerId = RecordBatch.NO_PRODUCER_ID, + producerEpoch = value.producerEpoch, + lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH, + txnTimeoutMs = value.transactionTimeoutMs, + state = TransactionState.fromId(value.transactionStatus), + topicPartitions = mutable.Set.empty[TopicPartition], + txnStartTimestamp = value.transactionStartTimestampMs, + txnLastUpdateTimestamp = value.transactionLastUpdateTimestampMs) + + if (!transactionMetadata.state.equals(Empty)) + value.transactionPartitions.forEach(partitionsSchema => + 
transactionMetadata.addPartitions(partitionsSchema.partitionIds + .asScala + .map(partitionId => new TopicPartition(partitionsSchema.topic, partitionId)) + .toSet) + ) + Some(transactionMetadata) + } else throw new IllegalStateException(s"Unknown version $version from the transaction log message value") + } + } + + // Formatter for use with tools to read transaction log messages + class TransactionLogMessageFormatter extends MessageFormatter { + def writeTo(consumerRecord: ConsumerRecord[Array[Byte], Array[Byte]], output: PrintStream): Unit = { + Option(consumerRecord.key).map(key => readTxnRecordKey(ByteBuffer.wrap(key))).foreach { txnKey => + val transactionalId = txnKey.transactionalId + val value = consumerRecord.value + val producerIdMetadata = if (value == null) + None + else + readTxnRecordValue(transactionalId, ByteBuffer.wrap(value)) + output.write(transactionalId.getBytes(StandardCharsets.UTF_8)) + output.write("::".getBytes(StandardCharsets.UTF_8)) + output.write(producerIdMetadata.getOrElse("NULL").toString.getBytes(StandardCharsets.UTF_8)) + output.write("\n".getBytes(StandardCharsets.UTF_8)) + } + } + } + + /** + * Exposed for printing records using [[kafka.tools.DumpLogSegments]] + */ + def formatRecordKeyAndValue(record: Record): (Option[String], Option[String]) = { + val txnKey = TransactionLog.readTxnRecordKey(record.key) + val keyString = s"transaction_metadata::transactionalId=${txnKey.transactionalId}" + + val valueString = TransactionLog.readTxnRecordValue(txnKey.transactionalId, record.value) match { + case None => "" + + case Some(txnMetadata) => s"producerId:${txnMetadata.producerId}," + + s"producerEpoch:${txnMetadata.producerEpoch}," + + s"state=${txnMetadata.state}," + + s"partitions=${txnMetadata.topicPartitions.mkString("[", ",", "]")}," + + s"txnLastUpdateTimestamp=${txnMetadata.txnLastUpdateTimestamp}," + + s"txnTimeoutMs=${txnMetadata.txnTimeoutMs}" + } + + (Some(keyString), Some(valueString)) + } + +} + +case class TxnKey(version: Short, transactionalId: String) { + override def toString: String = transactionalId +} diff --git a/core/src/main/scala/kafka/coordinator/transaction/TransactionMarkerChannelManager.scala b/core/src/main/scala/kafka/coordinator/transaction/TransactionMarkerChannelManager.scala new file mode 100644 index 0000000..62c70d9 --- /dev/null +++ b/core/src/main/scala/kafka/coordinator/transaction/TransactionMarkerChannelManager.scala @@ -0,0 +1,432 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package kafka.coordinator.transaction + + +import java.util +import java.util.concurrent.{BlockingQueue, ConcurrentHashMap, LinkedBlockingQueue} +import kafka.api.KAFKA_2_8_IV0 +import kafka.common.{InterBrokerSendThread, RequestAndCompletionHandler} +import kafka.metrics.KafkaMetricsGroup +import kafka.server.{KafkaConfig, MetadataCache, RequestLocal} +import kafka.utils.Implicits._ +import kafka.utils.{CoreUtils, Logging} +import org.apache.kafka.clients._ +import org.apache.kafka.common.metrics.Metrics +import org.apache.kafka.common.network._ +import org.apache.kafka.common.protocol.Errors +import org.apache.kafka.common.requests.WriteTxnMarkersRequest.TxnMarkerEntry +import org.apache.kafka.common.requests.{TransactionResult, WriteTxnMarkersRequest} +import org.apache.kafka.common.security.JaasContext +import org.apache.kafka.common.utils.{LogContext, Time} +import org.apache.kafka.common.{Node, Reconfigurable, TopicPartition} + +import scala.collection.{concurrent, immutable} +import scala.jdk.CollectionConverters._ + +object TransactionMarkerChannelManager { + def apply(config: KafkaConfig, + metrics: Metrics, + metadataCache: MetadataCache, + txnStateManager: TransactionStateManager, + time: Time, + logContext: LogContext): TransactionMarkerChannelManager = { + val channelBuilder = ChannelBuilders.clientChannelBuilder( + config.interBrokerSecurityProtocol, + JaasContext.Type.SERVER, + config, + config.interBrokerListenerName, + config.saslMechanismInterBrokerProtocol, + time, + config.saslInterBrokerHandshakeRequestEnable, + logContext + ) + channelBuilder match { + case reconfigurable: Reconfigurable => config.addReconfigurable(reconfigurable) + case _ => + } + val selector = new Selector( + NetworkReceive.UNLIMITED, + config.connectionsMaxIdleMs, + metrics, + time, + "txn-marker-channel", + Map.empty[String, String].asJava, + false, + channelBuilder, + logContext + ) + val networkClient = new NetworkClient( + selector, + new ManualMetadataUpdater(), + s"broker-${config.brokerId}-txn-marker-sender", + 1, + 50, + 50, + Selectable.USE_DEFAULT_BUFFER_SIZE, + config.socketReceiveBufferBytes, + config.requestTimeoutMs, + config.connectionSetupTimeoutMs, + config.connectionSetupTimeoutMaxMs, + time, + false, + new ApiVersions, + logContext + ) + + new TransactionMarkerChannelManager(config, + metadataCache, + networkClient, + txnStateManager, + time + ) + } + +} + +class TxnMarkerQueue(@volatile var destination: Node) { + + // keep track of the requests per txn topic partition so we can easily clear the queue + // during partition emigration + private val markersPerTxnTopicPartition = new ConcurrentHashMap[Int, BlockingQueue[TxnIdAndMarkerEntry]]().asScala + + def removeMarkersForTxnTopicPartition(partition: Int): Option[BlockingQueue[TxnIdAndMarkerEntry]] = { + markersPerTxnTopicPartition.remove(partition) + } + + def addMarkers(txnTopicPartition: Int, txnIdAndMarker: TxnIdAndMarkerEntry): Unit = { + val queue = CoreUtils.atomicGetOrUpdate(markersPerTxnTopicPartition, txnTopicPartition, + new LinkedBlockingQueue[TxnIdAndMarkerEntry]()) + queue.add(txnIdAndMarker) + } + + def forEachTxnTopicPartition[B](f:(Int, BlockingQueue[TxnIdAndMarkerEntry]) => B): Unit = + markersPerTxnTopicPartition.forKeyValue { (partition, queue) => + if (!queue.isEmpty) f(partition, queue) + } + + def totalNumMarkers: Int = markersPerTxnTopicPartition.values.foldLeft(0) { _ + _.size } + + // visible for testing + def totalNumMarkers(txnTopicPartition: Int): Int = 
Int =
markersPerTxnTopicPartition.get(txnTopicPartition).fold(0)(_.size) +} + +class TransactionMarkerChannelManager( + config: KafkaConfig, + metadataCache: MetadataCache, + networkClient: NetworkClient, + txnStateManager: TransactionStateManager, + time: Time +) extends InterBrokerSendThread("TxnMarkerSenderThread-" + config.brokerId, networkClient, config.requestTimeoutMs, time) + with Logging with KafkaMetricsGroup { + + this.logIdent = "[Transaction Marker Channel Manager " + config.brokerId + "]: " + + private val interBrokerListenerName: ListenerName = config.interBrokerListenerName + + private val markersQueuePerBroker: concurrent.Map[Int, TxnMarkerQueue] = new ConcurrentHashMap[Int, TxnMarkerQueue]().asScala + + private val markersQueueForUnknownBroker = new TxnMarkerQueue(Node.noNode) + + private val txnLogAppendRetryQueue = new LinkedBlockingQueue[PendingCompleteTxn]() + + private val transactionsWithPendingMarkers = new ConcurrentHashMap[String, PendingCompleteTxn] + + val writeTxnMarkersRequestVersion: Short = + if (config.interBrokerProtocolVersion >= KAFKA_2_8_IV0) 1 + else 0 + + newGauge("UnknownDestinationQueueSize", () => markersQueueForUnknownBroker.totalNumMarkers) + newGauge("LogAppendRetryQueueSize", () => txnLogAppendRetryQueue.size) + + override def shutdown(): Unit = { + super.shutdown() + markersQueuePerBroker.clear() + } + + // visible for testing + private[transaction] def queueForBroker(brokerId: Int) = { + markersQueuePerBroker.get(brokerId) + } + + // visible for testing + private[transaction] def queueForUnknownBroker = markersQueueForUnknownBroker + + private[transaction] def addMarkersForBroker(broker: Node, txnTopicPartition: Int, txnIdAndMarker: TxnIdAndMarkerEntry): Unit = { + val brokerId = broker.id + + // we do not synchronize on the update of the broker node with the enqueuing, + // since even if there is a race condition we will just retry + val brokerRequestQueue = CoreUtils.atomicGetOrUpdate(markersQueuePerBroker, brokerId, + new TxnMarkerQueue(broker)) + brokerRequestQueue.destination = broker + brokerRequestQueue.addMarkers(txnTopicPartition, txnIdAndMarker) + + trace(s"Added marker ${txnIdAndMarker.txnMarkerEntry} for transactional id ${txnIdAndMarker.txnId} to destination broker $brokerId") + } + + def retryLogAppends(): Unit = { + val txnLogAppendRetries: java.util.List[PendingCompleteTxn] = new util.ArrayList[PendingCompleteTxn]() + txnLogAppendRetryQueue.drainTo(txnLogAppendRetries) + txnLogAppendRetries.forEach { txnLogAppend => + debug(s"Retry appending $txnLogAppend transaction log") + tryAppendToLog(txnLogAppend) + } + } + + override def generateRequests(): Iterable[RequestAndCompletionHandler] = { + retryLogAppends() + val txnIdAndMarkerEntries: java.util.List[TxnIdAndMarkerEntry] = new util.ArrayList[TxnIdAndMarkerEntry]() + markersQueueForUnknownBroker.forEachTxnTopicPartition { case (_, queue) => + queue.drainTo(txnIdAndMarkerEntries) + } + + for (txnIdAndMarker: TxnIdAndMarkerEntry <- txnIdAndMarkerEntries.asScala) { + val transactionalId = txnIdAndMarker.txnId + val producerId = txnIdAndMarker.txnMarkerEntry.producerId + val producerEpoch = txnIdAndMarker.txnMarkerEntry.producerEpoch + val txnResult = txnIdAndMarker.txnMarkerEntry.transactionResult + val coordinatorEpoch = txnIdAndMarker.txnMarkerEntry.coordinatorEpoch + val topicPartitions = txnIdAndMarker.txnMarkerEntry.partitions.asScala.toSet + + addTxnMarkersToBrokerQueue(transactionalId, producerId, producerEpoch, txnResult, coordinatorEpoch, topicPartitions) + } + + val 
currentTimeMs = time.milliseconds() + markersQueuePerBroker.values.map { brokerRequestQueue => + val txnIdAndMarkerEntries = new util.ArrayList[TxnIdAndMarkerEntry]() + brokerRequestQueue.forEachTxnTopicPartition { case (_, queue) => + queue.drainTo(txnIdAndMarkerEntries) + } + (brokerRequestQueue.destination, txnIdAndMarkerEntries) + }.filter { case (_, entries) => !entries.isEmpty }.map { case (node, entries) => + val markersToSend = entries.asScala.map(_.txnMarkerEntry).asJava + val requestCompletionHandler = new TransactionMarkerRequestCompletionHandler(node.id, txnStateManager, this, entries) + val request = new WriteTxnMarkersRequest.Builder(writeTxnMarkersRequestVersion, markersToSend) + + RequestAndCompletionHandler( + currentTimeMs, + node, + request, + requestCompletionHandler + ) + } + } + + private def writeTxnCompletion(pendingCompleteTxn: PendingCompleteTxn): Unit = { + val transactionalId = pendingCompleteTxn.transactionalId + val txnMetadata = pendingCompleteTxn.txnMetadata + val newMetadata = pendingCompleteTxn.newMetadata + val coordinatorEpoch = pendingCompleteTxn.coordinatorEpoch + + trace(s"Completed sending transaction markers for $transactionalId; begin transition " + + s"to ${newMetadata.txnState}") + + txnStateManager.getTransactionState(transactionalId) match { + case Left(Errors.NOT_COORDINATOR) => + info(s"No longer the coordinator for $transactionalId with coordinator epoch " + + s"$coordinatorEpoch; cancel appending $newMetadata to transaction log") + + case Left(Errors.COORDINATOR_LOAD_IN_PROGRESS) => + info(s"Loading the transaction partition that contains $transactionalId while my " + + s"current coordinator epoch is $coordinatorEpoch; so cancel appending $newMetadata to " + + s"transaction log since the loading process will continue the remaining work") + + case Left(unexpectedError) => + throw new IllegalStateException(s"Unhandled error $unexpectedError when fetching current transaction state") + + case Right(Some(epochAndMetadata)) => + if (epochAndMetadata.coordinatorEpoch == coordinatorEpoch) { + debug(s"Sending $transactionalId's transaction markers for $txnMetadata with " + + s"coordinator epoch $coordinatorEpoch succeeded, trying to append complete transaction log now") + tryAppendToLog(PendingCompleteTxn(transactionalId, coordinatorEpoch, txnMetadata, newMetadata)) + } else { + info(s"The cached metadata $txnMetadata has changed to $epochAndMetadata after " + + s"completed sending the markers with coordinator epoch $coordinatorEpoch; abort " + + s"transiting the metadata to $newMetadata as it may have been updated by another process") + } + + case Right(None) => + val errorMsg = s"The coordinator still owns the transaction partition for $transactionalId, " + + s"but there is no metadata in the cache; this is not expected" + fatal(errorMsg) + throw new IllegalStateException(errorMsg) + } + } + + def addTxnMarkersToSend(coordinatorEpoch: Int, + txnResult: TransactionResult, + txnMetadata: TransactionMetadata, + newMetadata: TxnTransitMetadata): Unit = { + val transactionalId = txnMetadata.transactionalId + val pendingCompleteTxn = PendingCompleteTxn( + transactionalId, + coordinatorEpoch, + txnMetadata, + newMetadata) + + transactionsWithPendingMarkers.put(transactionalId, pendingCompleteTxn) + addTxnMarkersToBrokerQueue(transactionalId, txnMetadata.producerId, + txnMetadata.producerEpoch, txnResult, coordinatorEpoch, txnMetadata.topicPartitions.toSet) + maybeWriteTxnCompletion(transactionalId) + } + + def numTxnsWithPendingMarkers: Int = 
transactionsWithPendingMarkers.size + + private def hasPendingMarkersToWrite(txnMetadata: TransactionMetadata): Boolean = { + txnMetadata.inLock { + txnMetadata.topicPartitions.nonEmpty + } + } + + def maybeWriteTxnCompletion(transactionalId: String): Unit = { + Option(transactionsWithPendingMarkers.get(transactionalId)).foreach { pendingCompleteTxn => + if (!hasPendingMarkersToWrite(pendingCompleteTxn.txnMetadata) && + transactionsWithPendingMarkers.remove(transactionalId, pendingCompleteTxn)) { + writeTxnCompletion(pendingCompleteTxn) + } + } + } + + private def tryAppendToLog(txnLogAppend: PendingCompleteTxn): Unit = { + // try to append to the transaction log + def appendCallback(error: Errors): Unit = + error match { + case Errors.NONE => + trace(s"Completed transaction for ${txnLogAppend.transactionalId} with coordinator epoch ${txnLogAppend.coordinatorEpoch}, final state after commit: ${txnLogAppend.txnMetadata.state}") + + case Errors.NOT_COORDINATOR => + info(s"No longer the coordinator for transactionalId: ${txnLogAppend.transactionalId} while trying to append to transaction log, skip writing to transaction log") + + case Errors.COORDINATOR_NOT_AVAILABLE => + info(s"Not available to append $txnLogAppend: possible causes include ${Errors.UNKNOWN_TOPIC_OR_PARTITION}, ${Errors.NOT_ENOUGH_REPLICAS}, " + + s"${Errors.NOT_ENOUGH_REPLICAS_AFTER_APPEND} and ${Errors.REQUEST_TIMED_OUT}; retry appending") + + // enqueue for retry + txnLogAppendRetryQueue.add(txnLogAppend) + + case Errors.COORDINATOR_LOAD_IN_PROGRESS => + info(s"Coordinator is loading the partition ${txnStateManager.partitionFor(txnLogAppend.transactionalId)} and hence cannot complete append of $txnLogAppend; " + + s"skip writing to transaction log as the loading process should complete it") + + case other: Errors => + val errorMsg = s"Unexpected error ${other.exceptionName} while appending to transaction log for ${txnLogAppend.transactionalId}" + fatal(errorMsg) + throw new IllegalStateException(errorMsg) + } + + txnStateManager.appendTransactionToLog(txnLogAppend.transactionalId, txnLogAppend.coordinatorEpoch, + txnLogAppend.newMetadata, appendCallback, _ == Errors.COORDINATOR_NOT_AVAILABLE, RequestLocal.NoCaching) + } + + def addTxnMarkersToBrokerQueue(transactionalId: String, + producerId: Long, + producerEpoch: Short, + result: TransactionResult, + coordinatorEpoch: Int, + topicPartitions: immutable.Set[TopicPartition]): Unit = { + val txnTopicPartition = txnStateManager.partitionFor(transactionalId) + val partitionsByDestination: immutable.Map[Option[Node], immutable.Set[TopicPartition]] = topicPartitions.groupBy { topicPartition: TopicPartition => + metadataCache.getPartitionLeaderEndpoint(topicPartition.topic, topicPartition.partition, interBrokerListenerName) + } + + for ((broker: Option[Node], topicPartitions: immutable.Set[TopicPartition]) <- partitionsByDestination) { + broker match { + case Some(brokerNode) => + val marker = new TxnMarkerEntry(producerId, producerEpoch, coordinatorEpoch, result, topicPartitions.toList.asJava) + val txnIdAndMarker = TxnIdAndMarkerEntry(transactionalId, marker) + + if (brokerNode == Node.noNode) { + // if the leader of the partition is known but node not available, put it into an unknown broker queue + // and let the sender thread to look for its broker and migrate them later + markersQueueForUnknownBroker.addMarkers(txnTopicPartition, txnIdAndMarker) + } else { + addMarkersForBroker(brokerNode, txnTopicPartition, txnIdAndMarker) + } + + case None => + 
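+            // no live leader endpoint is known for these partitions; re-check the cached transaction
+            // metadata before deciding whether they can simply be dropped from the transaction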
txnStateManager.getTransactionState(transactionalId) match { + case Left(error) => + info(s"Encountered $error trying to fetch transaction metadata for $transactionalId with coordinator epoch $coordinatorEpoch; cancel sending markers to its partition leaders") + transactionsWithPendingMarkers.remove(transactionalId) + + case Right(Some(epochAndMetadata)) => + if (epochAndMetadata.coordinatorEpoch != coordinatorEpoch) { + info(s"The cached metadata has changed to $epochAndMetadata (old coordinator epoch is $coordinatorEpoch) since preparing to send markers; cancel sending markers to its partition leaders") + transactionsWithPendingMarkers.remove(transactionalId) + } else { + // if the leader of the partition is unknown, skip sending the txn marker since + // the partition is likely to be deleted already + info(s"Couldn't find leader endpoint for partitions $topicPartitions while trying to send transaction markers for " + + s"$transactionalId, these partitions are likely deleted already and hence can be skipped") + + val txnMetadata = epochAndMetadata.transactionMetadata + + txnMetadata.inLock { + topicPartitions.foreach(txnMetadata.removePartition) + } + + maybeWriteTxnCompletion(transactionalId) + } + + case Right(None) => + val errorMsg = s"The coordinator still owns the transaction partition for $transactionalId, but there is " + + s"no metadata in the cache; this is not expected" + fatal(errorMsg) + throw new IllegalStateException(errorMsg) + + } + } + } + + wakeup() + } + + def removeMarkersForTxnTopicPartition(txnTopicPartitionId: Int): Unit = { + markersQueueForUnknownBroker.removeMarkersForTxnTopicPartition(txnTopicPartitionId).foreach { queue => + for (entry: TxnIdAndMarkerEntry <- queue.asScala) + removeMarkersForTxnId(entry.txnId) + } + + markersQueuePerBroker.foreach { case(_, brokerQueue) => + brokerQueue.removeMarkersForTxnTopicPartition(txnTopicPartitionId).foreach { queue => + for (entry: TxnIdAndMarkerEntry <- queue.asScala) + removeMarkersForTxnId(entry.txnId) + } + } + } + + def removeMarkersForTxnId(transactionalId: String): Unit = { + transactionsWithPendingMarkers.remove(transactionalId) + } +} + +case class TxnIdAndMarkerEntry(txnId: String, txnMarkerEntry: TxnMarkerEntry) + +case class PendingCompleteTxn(transactionalId: String, + coordinatorEpoch: Int, + txnMetadata: TransactionMetadata, + newMetadata: TxnTransitMetadata) { + + override def toString: String = { + "PendingCompleteTxn(" + + s"transactionalId=$transactionalId, " + + s"coordinatorEpoch=$coordinatorEpoch, " + + s"txnMetadata=$txnMetadata, " + + s"newMetadata=$newMetadata)" + } +} diff --git a/core/src/main/scala/kafka/coordinator/transaction/TransactionMarkerRequestCompletionHandler.scala b/core/src/main/scala/kafka/coordinator/transaction/TransactionMarkerRequestCompletionHandler.scala new file mode 100644 index 0000000..848e0fa --- /dev/null +++ b/core/src/main/scala/kafka/coordinator/transaction/TransactionMarkerRequestCompletionHandler.scala @@ -0,0 +1,204 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.coordinator.transaction + +import kafka.utils.Logging +import org.apache.kafka.clients.{ClientResponse, RequestCompletionHandler} +import org.apache.kafka.common.TopicPartition +import org.apache.kafka.common.protocol.Errors +import org.apache.kafka.common.requests.WriteTxnMarkersResponse + +import scala.collection.mutable +import scala.jdk.CollectionConverters._ + +class TransactionMarkerRequestCompletionHandler(brokerId: Int, + txnStateManager: TransactionStateManager, + txnMarkerChannelManager: TransactionMarkerChannelManager, + txnIdAndMarkerEntries: java.util.List[TxnIdAndMarkerEntry]) extends RequestCompletionHandler with Logging { + + this.logIdent = "[Transaction Marker Request Completion Handler " + brokerId + "]: " + + override def onComplete(response: ClientResponse): Unit = { + val requestHeader = response.requestHeader + val correlationId = requestHeader.correlationId + if (response.wasDisconnected) { + trace(s"Cancelled request with header $requestHeader due to node ${response.destination} being disconnected") + + for (txnIdAndMarker <- txnIdAndMarkerEntries.asScala) { + val transactionalId = txnIdAndMarker.txnId + val txnMarker = txnIdAndMarker.txnMarkerEntry + + txnStateManager.getTransactionState(transactionalId) match { + + case Left(Errors.NOT_COORDINATOR) => + info(s"I am no longer the coordinator for $transactionalId; cancel sending transaction markers $txnMarker to the brokers") + + txnMarkerChannelManager.removeMarkersForTxnId(transactionalId) + + case Left(Errors.COORDINATOR_LOAD_IN_PROGRESS) => + info(s"I am loading the transaction partition that contains $transactionalId which means the current markers have to be obsoleted; " + + s"cancel sending transaction markers $txnMarker to the brokers") + + txnMarkerChannelManager.removeMarkersForTxnId(transactionalId) + + case Left(unexpectedError) => + throw new IllegalStateException(s"Unhandled error $unexpectedError when fetching current transaction state") + + case Right(None) => + throw new IllegalStateException(s"The coordinator still owns the transaction partition for $transactionalId, but there is " + + s"no metadata in the cache; this is not expected") + + case Right(Some(epochAndMetadata)) => + if (epochAndMetadata.coordinatorEpoch != txnMarker.coordinatorEpoch) { + // coordinator epoch has changed, just cancel it from the purgatory + info(s"Transaction coordinator epoch for $transactionalId has changed from ${txnMarker.coordinatorEpoch} to " + + s"${epochAndMetadata.coordinatorEpoch}; cancel sending transaction markers $txnMarker to the brokers") + + txnMarkerChannelManager.removeMarkersForTxnId(transactionalId) + } else { + // re-enqueue the markers with possibly new destination brokers + trace(s"Re-enqueuing ${txnMarker.transactionResult} transaction markers for transactional id $transactionalId " + + s"under coordinator epoch ${txnMarker.coordinatorEpoch}") + + txnMarkerChannelManager.addTxnMarkersToBrokerQueue(transactionalId, + txnMarker.producerId, + txnMarker.producerEpoch, + txnMarker.transactionResult, + txnMarker.coordinatorEpoch, + txnMarker.partitions.asScala.toSet) 
+ } + } + } + } else { + debug(s"Received WriteTxnMarker response $response from node ${response.destination} with correlation id $correlationId") + + val writeTxnMarkerResponse = response.responseBody.asInstanceOf[WriteTxnMarkersResponse] + + val responseErrors = writeTxnMarkerResponse.errorsByProducerId; + for (txnIdAndMarker <- txnIdAndMarkerEntries.asScala) { + val transactionalId = txnIdAndMarker.txnId + val txnMarker = txnIdAndMarker.txnMarkerEntry + val errors = responseErrors.get(txnMarker.producerId) + + if (errors == null) + throw new IllegalStateException(s"WriteTxnMarkerResponse does not contain expected error map for producer id ${txnMarker.producerId}") + + txnStateManager.getTransactionState(transactionalId) match { + case Left(Errors.NOT_COORDINATOR) => + info(s"I am no longer the coordinator for $transactionalId; cancel sending transaction markers $txnMarker to the brokers") + + txnMarkerChannelManager.removeMarkersForTxnId(transactionalId) + + case Left(Errors.COORDINATOR_LOAD_IN_PROGRESS) => + info(s"I am loading the transaction partition that contains $transactionalId which means the current markers have to be obsoleted; " + + s"cancel sending transaction markers $txnMarker to the brokers") + + txnMarkerChannelManager.removeMarkersForTxnId(transactionalId) + + case Left(unexpectedError) => + throw new IllegalStateException(s"Unhandled error $unexpectedError when fetching current transaction state") + + case Right(None) => + throw new IllegalStateException(s"The coordinator still owns the transaction partition for $transactionalId, but there is " + + s"no metadata in the cache; this is not expected") + + case Right(Some(epochAndMetadata)) => + val txnMetadata = epochAndMetadata.transactionMetadata + val retryPartitions: mutable.Set[TopicPartition] = mutable.Set.empty[TopicPartition] + var abortSending: Boolean = false + + if (epochAndMetadata.coordinatorEpoch != txnMarker.coordinatorEpoch) { + // coordinator epoch has changed, just cancel it from the purgatory + info(s"Transaction coordinator epoch for $transactionalId has changed from ${txnMarker.coordinatorEpoch} to " + + s"${epochAndMetadata.coordinatorEpoch}; cancel sending transaction markers $txnMarker to the brokers") + + txnMarkerChannelManager.removeMarkersForTxnId(transactionalId) + abortSending = true + } else { + txnMetadata.inLock { + for ((topicPartition, error) <- errors.asScala) { + error match { + case Errors.NONE => + txnMetadata.removePartition(topicPartition) + + case Errors.CORRUPT_MESSAGE | + Errors.MESSAGE_TOO_LARGE | + Errors.RECORD_LIST_TOO_LARGE | + Errors.INVALID_REQUIRED_ACKS => // these are all unexpected and fatal errors + + throw new IllegalStateException(s"Received fatal error ${error.exceptionName} while sending txn marker for $transactionalId") + + case Errors.UNKNOWN_TOPIC_OR_PARTITION | + Errors.NOT_LEADER_OR_FOLLOWER | + Errors.NOT_ENOUGH_REPLICAS | + Errors.NOT_ENOUGH_REPLICAS_AFTER_APPEND | + Errors.REQUEST_TIMED_OUT | + Errors.KAFKA_STORAGE_ERROR => // these are retriable errors + + info(s"Sending $transactionalId's transaction marker for partition $topicPartition has failed with error ${error.exceptionName}, retrying " + + s"with current coordinator epoch ${epochAndMetadata.coordinatorEpoch}") + + retryPartitions += topicPartition + + case Errors.INVALID_PRODUCER_EPOCH | + Errors.TRANSACTION_COORDINATOR_FENCED => // producer or coordinator epoch has changed, this txn can now be ignored + + info(s"Sending $transactionalId's transaction marker for partition $topicPartition has 
permanently failed with error ${error.exceptionName} " + + s"with the current coordinator epoch ${epochAndMetadata.coordinatorEpoch}; cancel sending any more transaction markers $txnMarker to the brokers") + + txnMarkerChannelManager.removeMarkersForTxnId(transactionalId) + abortSending = true + + case Errors.UNSUPPORTED_FOR_MESSAGE_FORMAT | + Errors.UNSUPPORTED_VERSION => + // The producer would have failed to send data to the failed topic so we can safely remove the partition + // from the set waiting for markers + info(s"Sending $transactionalId's transaction marker from partition $topicPartition has failed with " + + s" ${error.name}. This partition will be removed from the set of partitions" + + s" waiting for completion") + txnMetadata.removePartition(topicPartition) + + case other => + throw new IllegalStateException(s"Unexpected error ${other.exceptionName} while sending txn marker for $transactionalId") + } + } + } + } + + if (!abortSending) { + if (retryPartitions.nonEmpty) { + debug(s"Re-enqueuing ${txnMarker.transactionResult} transaction markers for transactional id $transactionalId " + + s"under coordinator epoch ${txnMarker.coordinatorEpoch}") + + // re-enqueue with possible new leaders of the partitions + txnMarkerChannelManager.addTxnMarkersToBrokerQueue( + transactionalId, + txnMarker.producerId, + txnMarker.producerEpoch, + txnMarker.transactionResult, + txnMarker.coordinatorEpoch, + retryPartitions.toSet) + } else { + txnMarkerChannelManager.maybeWriteTxnCompletion(transactionalId) + } + } + } + } + } + } +} diff --git a/core/src/main/scala/kafka/coordinator/transaction/TransactionMetadata.scala b/core/src/main/scala/kafka/coordinator/transaction/TransactionMetadata.scala new file mode 100644 index 0000000..0f6d4b7 --- /dev/null +++ b/core/src/main/scala/kafka/coordinator/transaction/TransactionMetadata.scala @@ -0,0 +1,546 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package kafka.coordinator.transaction + +import java.util.concurrent.locks.ReentrantLock + +import kafka.utils.{CoreUtils, Logging, nonthreadsafe} +import org.apache.kafka.common.TopicPartition +import org.apache.kafka.common.protocol.Errors +import org.apache.kafka.common.record.RecordBatch + +import scala.collection.{immutable, mutable} + + +object TransactionState { + val AllStates = Set( + Empty, + Ongoing, + PrepareCommit, + PrepareAbort, + CompleteCommit, + CompleteAbort, + Dead, + PrepareEpochFence + ) + + def fromName(name: String): Option[TransactionState] = { + AllStates.find(_.name == name) + } + + def fromId(id: Byte): TransactionState = { + id match { + case 0 => Empty + case 1 => Ongoing + case 2 => PrepareCommit + case 3 => PrepareAbort + case 4 => CompleteCommit + case 5 => CompleteAbort + case 6 => Dead + case 7 => PrepareEpochFence + case _ => throw new IllegalStateException(s"Unknown transaction state id $id from the transaction status message") + } + } +} + +private[transaction] sealed trait TransactionState { + def id: Byte + + /** + * Get the name of this state. This is exposed through the `DescribeTransactions` API. + */ + def name: String + + def validPreviousStates: Set[TransactionState] + + def isExpirationAllowed: Boolean = false +} + +/** + * Transaction has not existed yet + * + * transition: received AddPartitionsToTxnRequest => Ongoing + * received AddOffsetsToTxnRequest => Ongoing + */ +private[transaction] case object Empty extends TransactionState { + val id: Byte = 0 + val name: String = "Empty" + val validPreviousStates: Set[TransactionState] = Set(Empty, CompleteCommit, CompleteAbort) + override def isExpirationAllowed: Boolean = true +} + +/** + * Transaction has started and ongoing + * + * transition: received EndTxnRequest with commit => PrepareCommit + * received EndTxnRequest with abort => PrepareAbort + * received AddPartitionsToTxnRequest => Ongoing + * received AddOffsetsToTxnRequest => Ongoing + */ +private[transaction] case object Ongoing extends TransactionState { + val id: Byte = 1 + val name: String = "Ongoing" + val validPreviousStates: Set[TransactionState] = Set(Ongoing, Empty, CompleteCommit, CompleteAbort) +} + +/** + * Group is preparing to commit + * + * transition: received acks from all partitions => CompleteCommit + */ +private[transaction] case object PrepareCommit extends TransactionState { + val id: Byte = 2 + val name: String = "PrepareCommit" + val validPreviousStates: Set[TransactionState] = Set(Ongoing) +} + +/** + * Group is preparing to abort + * + * transition: received acks from all partitions => CompleteAbort + */ +private[transaction] case object PrepareAbort extends TransactionState { + val id: Byte = 3 + val name: String = "PrepareAbort" + val validPreviousStates: Set[TransactionState] = Set(Ongoing, PrepareEpochFence) +} + +/** + * Group has completed commit + * + * Will soon be removed from the ongoing transaction cache + */ +private[transaction] case object CompleteCommit extends TransactionState { + val id: Byte = 4 + val name: String = "CompleteCommit" + val validPreviousStates: Set[TransactionState] = Set(PrepareCommit) + override def isExpirationAllowed: Boolean = true +} + +/** + * Group has completed abort + * + * Will soon be removed from the ongoing transaction cache + */ +private[transaction] case object CompleteAbort extends TransactionState { + val id: Byte = 5 + val name: String = "CompleteAbort" + val validPreviousStates: Set[TransactionState] = Set(PrepareAbort) + override def 
isExpirationAllowed: Boolean = true +} + +/** + * TransactionalId has expired and is about to be removed from the transaction cache + */ +private[transaction] case object Dead extends TransactionState { + val id: Byte = 6 + val name: String = "Dead" + val validPreviousStates: Set[TransactionState] = Set(Empty, CompleteAbort, CompleteCommit) +} + +/** + * We are in the middle of bumping the epoch and fencing out older producers. + */ + +private[transaction] case object PrepareEpochFence extends TransactionState { + val id: Byte = 7 + val name: String = "PrepareEpochFence" + val validPreviousStates: Set[TransactionState] = Set(Ongoing) +} + +private[transaction] object TransactionMetadata { + def apply(transactionalId: String, producerId: Long, producerEpoch: Short, txnTimeoutMs: Int, timestamp: Long) = + new TransactionMetadata(transactionalId, producerId, RecordBatch.NO_PRODUCER_ID, producerEpoch, + RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Empty, collection.mutable.Set.empty[TopicPartition], timestamp, timestamp) + + def apply(transactionalId: String, producerId: Long, producerEpoch: Short, txnTimeoutMs: Int, + state: TransactionState, timestamp: Long) = + new TransactionMetadata(transactionalId, producerId, RecordBatch.NO_PRODUCER_ID, producerEpoch, + RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, state, collection.mutable.Set.empty[TopicPartition], timestamp, timestamp) + + def apply(transactionalId: String, producerId: Long, lastProducerId: Long, producerEpoch: Short, + lastProducerEpoch: Short, txnTimeoutMs: Int, state: TransactionState, timestamp: Long) = + new TransactionMetadata(transactionalId, producerId, lastProducerId, producerEpoch, lastProducerEpoch, + txnTimeoutMs, state, collection.mutable.Set.empty[TopicPartition], timestamp, timestamp) + + def isEpochExhausted(producerEpoch: Short): Boolean = producerEpoch >= Short.MaxValue - 1 +} + +// this is a immutable object representing the target transition of the transaction metadata +private[transaction] case class TxnTransitMetadata(producerId: Long, + lastProducerId: Long, + producerEpoch: Short, + lastProducerEpoch: Short, + txnTimeoutMs: Int, + txnState: TransactionState, + topicPartitions: immutable.Set[TopicPartition], + txnStartTimestamp: Long, + txnLastUpdateTimestamp: Long) { + override def toString: String = { + "TxnTransitMetadata(" + + s"producerId=$producerId, " + + s"lastProducerId=$lastProducerId, " + + s"producerEpoch=$producerEpoch, " + + s"lastProducerEpoch=$lastProducerEpoch, " + + s"txnTimeoutMs=$txnTimeoutMs, " + + s"txnState=$txnState, " + + s"topicPartitions=$topicPartitions, " + + s"txnStartTimestamp=$txnStartTimestamp, " + + s"txnLastUpdateTimestamp=$txnLastUpdateTimestamp)" + } +} + +/** + * + * @param producerId producer id + * @param lastProducerId last producer id assigned to the producer + * @param producerEpoch current epoch of the producer + * @param lastProducerEpoch last epoch of the producer + * @param txnTimeoutMs timeout to be used to abort long running transactions + * @param state current state of the transaction + * @param topicPartitions current set of partitions that are part of this transaction + * @param txnStartTimestamp time the transaction was started, i.e., when first partition is added + * @param txnLastUpdateTimestamp updated when any operation updates the TransactionMetadata. 
To be used for expiration + */ +@nonthreadsafe +private[transaction] class TransactionMetadata(val transactionalId: String, + var producerId: Long, + var lastProducerId: Long, + var producerEpoch: Short, + var lastProducerEpoch: Short, + var txnTimeoutMs: Int, + var state: TransactionState, + val topicPartitions: mutable.Set[TopicPartition], + @volatile var txnStartTimestamp: Long = -1, + @volatile var txnLastUpdateTimestamp: Long) extends Logging { + + // pending state is used to indicate the state that this transaction is going to + // transit to, and for blocking future attempts to transit it again if it is not legal; + // initialized as the same as the current state + var pendingState: Option[TransactionState] = None + + // Indicates that during a previous attempt to fence a producer, the bumped epoch may not have been + // successfully written to the log. If this is true, we will not bump the epoch again when fencing + var hasFailedEpochFence: Boolean = false + + private[transaction] val lock = new ReentrantLock + + def inLock[T](fun: => T): T = CoreUtils.inLock(lock)(fun) + + def addPartitions(partitions: collection.Set[TopicPartition]): Unit = { + topicPartitions ++= partitions + } + + def removePartition(topicPartition: TopicPartition): Unit = { + if (state != PrepareCommit && state != PrepareAbort) + throw new IllegalStateException(s"Transaction metadata's current state is $state, and its pending state is $pendingState " + + s"while trying to remove partitions whose txn marker has been sent, this is not expected") + + topicPartitions -= topicPartition + } + + // this is visible for test only + def prepareNoTransit(): TxnTransitMetadata = { + // do not call transitTo as it will set the pending state, a follow-up call to abort the transaction will set its pending state + TxnTransitMetadata(producerId, lastProducerId, producerEpoch, lastProducerEpoch, txnTimeoutMs, state, topicPartitions.toSet, + txnStartTimestamp, txnLastUpdateTimestamp) + } + + def prepareFenceProducerEpoch(): TxnTransitMetadata = { + if (producerEpoch == Short.MaxValue) + throw new IllegalStateException(s"Cannot fence producer with epoch equal to Short.MaxValue since this would overflow") + + // If we've already failed to fence an epoch (because the write to the log failed), we don't increase it again. 
+ // This is safe because we never return the epoch to client if we fail to fence the epoch + val bumpedEpoch = if (hasFailedEpochFence) producerEpoch else (producerEpoch + 1).toShort + + prepareTransitionTo(PrepareEpochFence, producerId, bumpedEpoch, RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, + topicPartitions.toSet, txnStartTimestamp, txnLastUpdateTimestamp) + } + + def prepareIncrementProducerEpoch(newTxnTimeoutMs: Int, + expectedProducerEpoch: Option[Short], + updateTimestamp: Long): Either[Errors, TxnTransitMetadata] = { + if (isProducerEpochExhausted) + throw new IllegalStateException(s"Cannot allocate any more producer epochs for producerId $producerId") + + val bumpedEpoch = (producerEpoch + 1).toShort + val epochBumpResult: Either[Errors, (Short, Short)] = expectedProducerEpoch match { + case None => + // If no expected epoch was provided by the producer, bump the current epoch and set the last epoch to -1 + // In the case of a new producer, producerEpoch will be -1 and bumpedEpoch will be 0 + Right(bumpedEpoch, RecordBatch.NO_PRODUCER_EPOCH) + + case Some(expectedEpoch) => + if (producerEpoch == RecordBatch.NO_PRODUCER_EPOCH || expectedEpoch == producerEpoch) + // If the expected epoch matches the current epoch, or if there is no current epoch, the producer is attempting + // to continue after an error and no other producer has been initialized. Bump the current and last epochs. + // The no current epoch case means this is a new producer; producerEpoch will be -1 and bumpedEpoch will be 0 + Right(bumpedEpoch, producerEpoch) + else if (expectedEpoch == lastProducerEpoch) + // If the expected epoch matches the previous epoch, it is a retry of a successful call, so just return the + // current epoch without bumping. There is no danger of this producer being fenced, because a new producer + // calling InitProducerId would have caused the last epoch to be set to -1. + // Note that if the IBP is prior to 2.4.IV1, the lastProducerId and lastProducerEpoch will not be written to + // the transaction log, so a retry that spans a coordinator change will fail. We expect this to be a rare case. 
+ Right(producerEpoch, lastProducerEpoch) + else { + // Otherwise, the producer has a fenced epoch and should receive an PRODUCER_FENCED error + info(s"Expected producer epoch $expectedEpoch does not match current " + + s"producer epoch $producerEpoch or previous producer epoch $lastProducerEpoch") + Left(Errors.PRODUCER_FENCED) + } + } + + epochBumpResult match { + case Right((nextEpoch, lastEpoch)) => Right(prepareTransitionTo(Empty, producerId, nextEpoch, lastEpoch, newTxnTimeoutMs, + immutable.Set.empty[TopicPartition], -1, updateTimestamp)) + + case Left(err) => Left(err) + } + } + + def prepareProducerIdRotation(newProducerId: Long, + newTxnTimeoutMs: Int, + updateTimestamp: Long, + recordLastEpoch: Boolean): TxnTransitMetadata = { + if (hasPendingTransaction) + throw new IllegalStateException("Cannot rotate producer ids while a transaction is still pending") + + prepareTransitionTo(Empty, newProducerId, 0, if (recordLastEpoch) producerEpoch else RecordBatch.NO_PRODUCER_EPOCH, + newTxnTimeoutMs, immutable.Set.empty[TopicPartition], -1, updateTimestamp) + } + + def prepareAddPartitions(addedTopicPartitions: immutable.Set[TopicPartition], updateTimestamp: Long): TxnTransitMetadata = { + val newTxnStartTimestamp = state match { + case Empty | CompleteAbort | CompleteCommit => updateTimestamp + case _ => txnStartTimestamp + } + + prepareTransitionTo(Ongoing, producerId, producerEpoch, lastProducerEpoch, txnTimeoutMs, + (topicPartitions ++ addedTopicPartitions).toSet, newTxnStartTimestamp, updateTimestamp) + } + + def prepareAbortOrCommit(newState: TransactionState, updateTimestamp: Long): TxnTransitMetadata = { + prepareTransitionTo(newState, producerId, producerEpoch, lastProducerEpoch, txnTimeoutMs, topicPartitions.toSet, + txnStartTimestamp, updateTimestamp) + } + + def prepareComplete(updateTimestamp: Long): TxnTransitMetadata = { + val newState = if (state == PrepareCommit) CompleteCommit else CompleteAbort + + // Since the state change was successfully written to the log, unset the flag for a failed epoch fence + hasFailedEpochFence = false + prepareTransitionTo(newState, producerId, producerEpoch, lastProducerEpoch, txnTimeoutMs, Set.empty[TopicPartition], + txnStartTimestamp, updateTimestamp) + } + + def prepareDead(): TxnTransitMetadata = { + prepareTransitionTo(Dead, producerId, producerEpoch, lastProducerEpoch, txnTimeoutMs, Set.empty[TopicPartition], + txnStartTimestamp, txnLastUpdateTimestamp) + } + + /** + * Check if the epochs have been exhausted for the current producerId. We do not allow the client to use an + * epoch equal to Short.MaxValue to ensure that the coordinator will always be able to fence an existing producer. 
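+ *
+ * Concretely, since Short.MaxValue is 32767, an epoch of 32766 already counts as exhausted: a
+ * client is never handed an epoch above 32766, which keeps 32767 free for a fencing bump.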
+ */ + def isProducerEpochExhausted: Boolean = TransactionMetadata.isEpochExhausted(producerEpoch) + + private def hasPendingTransaction: Boolean = { + state match { + case Ongoing | PrepareAbort | PrepareCommit => true + case _ => false + } + } + + private def prepareTransitionTo(newState: TransactionState, + newProducerId: Long, + newEpoch: Short, + newLastEpoch: Short, + newTxnTimeoutMs: Int, + newTopicPartitions: immutable.Set[TopicPartition], + newTxnStartTimestamp: Long, + updateTimestamp: Long): TxnTransitMetadata = { + if (pendingState.isDefined) + throw new IllegalStateException(s"Preparing transaction state transition to $newState " + + s"while it already a pending state ${pendingState.get}") + + if (newProducerId < 0) + throw new IllegalArgumentException(s"Illegal new producer id $newProducerId") + + if (newEpoch < 0) + throw new IllegalArgumentException(s"Illegal new producer epoch $newEpoch") + + // check that the new state transition is valid and update the pending state if necessary + if (newState.validPreviousStates.contains(state)) { + val transitMetadata = TxnTransitMetadata(newProducerId, producerId, newEpoch, newLastEpoch, newTxnTimeoutMs, newState, + newTopicPartitions, newTxnStartTimestamp, updateTimestamp) + debug(s"TransactionalId $transactionalId prepare transition from $state to $transitMetadata") + pendingState = Some(newState) + transitMetadata + } else { + throw new IllegalStateException(s"Preparing transaction state transition to $newState failed since the target state" + + s" $newState is not a valid previous state of the current state $state") + } + } + + def completeTransitionTo(transitMetadata: TxnTransitMetadata): Unit = { + // metadata transition is valid only if all the following conditions are met: + // + // 1. the new state is already indicated in the pending state. + // 2. the epoch should be either the same value, the old value + 1, or 0 if we have a new producerId. + // 3. the last update time is no smaller than the old value. + // 4. the old partitions set is a subset of the new partitions set. 
+ // + // plus, we should only try to update the metadata after the corresponding log entry has been successfully + // written and replicated (see TransactionStateManager#appendTransactionToLog) + // + // if valid, transition is done via overwriting the whole object to ensure synchronization + + val toState = pendingState.getOrElse { + fatal(s"$this's transition to $transitMetadata failed since pendingState is not defined: this should not happen") + + throw new IllegalStateException(s"TransactionalId $transactionalId " + + "completing transaction state transition while it does not have a pending state") + } + + if (toState != transitMetadata.txnState) { + throwStateTransitionFailure(transitMetadata) + } else { + toState match { + case Empty => // from initPid + if ((producerEpoch != transitMetadata.producerEpoch && !validProducerEpochBump(transitMetadata)) || + transitMetadata.topicPartitions.nonEmpty || + transitMetadata.txnStartTimestamp != -1) { + + throwStateTransitionFailure(transitMetadata) + } else { + txnTimeoutMs = transitMetadata.txnTimeoutMs + producerEpoch = transitMetadata.producerEpoch + lastProducerEpoch = transitMetadata.lastProducerEpoch + producerId = transitMetadata.producerId + lastProducerId = transitMetadata.lastProducerId + } + + case Ongoing => // from addPartitions + if (!validProducerEpoch(transitMetadata) || + !topicPartitions.subsetOf(transitMetadata.topicPartitions) || + txnTimeoutMs != transitMetadata.txnTimeoutMs) { + + throwStateTransitionFailure(transitMetadata) + } else { + txnStartTimestamp = transitMetadata.txnStartTimestamp + addPartitions(transitMetadata.topicPartitions) + } + + case PrepareAbort | PrepareCommit => // from endTxn + if (!validProducerEpoch(transitMetadata) || + !topicPartitions.toSet.equals(transitMetadata.topicPartitions) || + txnTimeoutMs != transitMetadata.txnTimeoutMs || + txnStartTimestamp != transitMetadata.txnStartTimestamp) { + + throwStateTransitionFailure(transitMetadata) + } + + case CompleteAbort | CompleteCommit => // from write markers + if (!validProducerEpoch(transitMetadata) || + txnTimeoutMs != transitMetadata.txnTimeoutMs || + transitMetadata.txnStartTimestamp == -1) { + + throwStateTransitionFailure(transitMetadata) + } else { + txnStartTimestamp = transitMetadata.txnStartTimestamp + topicPartitions.clear() + } + + case PrepareEpochFence => + // We should never get here, since once we prepare to fence the epoch, we immediately set the pending state + // to PrepareAbort, and then consequently to CompleteAbort after the markers are written.. So we should never + // ever try to complete a transition to PrepareEpochFence, as it is not a valid previous state for any other state, and hence + // can never be transitioned out of. + throwStateTransitionFailure(transitMetadata) + + + case Dead => + // The transactionalId was being expired. The completion of the operation should result in removal of the + // the metadata from the cache, so we should never realistically transition to the dead state. + throw new IllegalStateException(s"TransactionalId $transactionalId is trying to complete a transition to " + + s"$toState. 
This means that the transactionalId was being expired, and the only acceptable completion of " + + s"this operation is to remove the transaction metadata from the cache, not to persist the $toState in the log.") + } + + debug(s"TransactionalId $transactionalId complete transition from $state to $transitMetadata") + txnLastUpdateTimestamp = transitMetadata.txnLastUpdateTimestamp + pendingState = None + state = toState + } + } + + private def validProducerEpoch(transitMetadata: TxnTransitMetadata): Boolean = { + val transitEpoch = transitMetadata.producerEpoch + val transitProducerId = transitMetadata.producerId + transitEpoch == producerEpoch && transitProducerId == producerId + } + + private def validProducerEpochBump(transitMetadata: TxnTransitMetadata): Boolean = { + val transitEpoch = transitMetadata.producerEpoch + val transitProducerId = transitMetadata.producerId + transitEpoch == producerEpoch + 1 || (transitEpoch == 0 && transitProducerId != producerId) + } + + private def throwStateTransitionFailure(txnTransitMetadata: TxnTransitMetadata): Unit = { + fatal(s"${this.toString}'s transition to $txnTransitMetadata failed: this should not happen") + + throw new IllegalStateException(s"TransactionalId $transactionalId failed transition to state $txnTransitMetadata " + + "due to unexpected metadata") + } + + def pendingTransitionInProgress: Boolean = pendingState.isDefined + + override def toString: String = { + "TransactionMetadata(" + + s"transactionalId=$transactionalId, " + + s"producerId=$producerId, " + + s"producerEpoch=$producerEpoch, " + + s"txnTimeoutMs=$txnTimeoutMs, " + + s"state=$state, " + + s"pendingState=$pendingState, " + + s"topicPartitions=$topicPartitions, " + + s"txnStartTimestamp=$txnStartTimestamp, " + + s"txnLastUpdateTimestamp=$txnLastUpdateTimestamp)" + } + + override def equals(that: Any): Boolean = that match { + case other: TransactionMetadata => + transactionalId == other.transactionalId && + producerId == other.producerId && + producerEpoch == other.producerEpoch && + lastProducerEpoch == other.lastProducerEpoch && + txnTimeoutMs == other.txnTimeoutMs && + state.equals(other.state) && + topicPartitions.equals(other.topicPartitions) && + txnStartTimestamp == other.txnStartTimestamp && + txnLastUpdateTimestamp == other.txnLastUpdateTimestamp + case _ => false + } + + override def hashCode(): Int = { + val fields = Seq(transactionalId, producerId, producerEpoch, txnTimeoutMs, state, topicPartitions, + txnStartTimestamp, txnLastUpdateTimestamp) + fields.map(_.hashCode()).foldLeft(0)((a, b) => 31 * a + b) + } +} diff --git a/core/src/main/scala/kafka/coordinator/transaction/TransactionStateManager.scala b/core/src/main/scala/kafka/coordinator/transaction/TransactionStateManager.scala new file mode 100644 index 0000000..217b383 --- /dev/null +++ b/core/src/main/scala/kafka/coordinator/transaction/TransactionStateManager.scala @@ -0,0 +1,824 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package kafka.coordinator.transaction + +import java.nio.ByteBuffer +import java.util.Properties +import java.util.concurrent.TimeUnit +import java.util.concurrent.atomic.AtomicBoolean +import java.util.concurrent.locks.ReentrantReadWriteLock + +import kafka.log.{AppendOrigin, LogConfig} +import kafka.message.UncompressedCodec +import kafka.server.{Defaults, FetchLogEnd, ReplicaManager, RequestLocal} +import kafka.utils.CoreUtils.{inReadLock, inWriteLock} +import kafka.utils.{Logging, Pool, Scheduler} +import kafka.utils.Implicits._ +import org.apache.kafka.common.internals.Topic +import org.apache.kafka.common.message.ListTransactionsResponseData +import org.apache.kafka.common.metrics.Metrics +import org.apache.kafka.common.metrics.stats.{Avg, Max} +import org.apache.kafka.common.protocol.Errors +import org.apache.kafka.common.record.{FileRecords, MemoryRecords, MemoryRecordsBuilder, Record, SimpleRecord, TimestampType} +import org.apache.kafka.common.requests.ProduceResponse.PartitionResponse +import org.apache.kafka.common.requests.TransactionResult +import org.apache.kafka.common.utils.{Time, Utils} +import org.apache.kafka.common.{KafkaException, TopicPartition} + +import scala.jdk.CollectionConverters._ +import scala.collection.mutable + + +object TransactionStateManager { + // default transaction management config values + val DefaultTransactionsMaxTimeoutMs: Int = TimeUnit.MINUTES.toMillis(15).toInt + val DefaultTransactionalIdExpirationMs: Int = TimeUnit.DAYS.toMillis(7).toInt + val DefaultAbortTimedOutTransactionsIntervalMs: Int = TimeUnit.SECONDS.toMillis(10).toInt + val DefaultRemoveExpiredTransactionalIdsIntervalMs: Int = TimeUnit.HOURS.toMillis(1).toInt + + val MetricsGroup: String = "transaction-coordinator-metrics" + val LoadTimeSensor: String = "TransactionsPartitionLoadTime" +} + +/** + * Transaction state manager is part of the transaction coordinator, it manages: + * + * 1. the transaction log, which is a special internal topic. + * 2. the transaction metadata including its ongoing transaction status. + * 3. the background expiration of the transaction as well as the transactional id. + * + * Delayed operation locking notes: + * Delayed operations in TransactionStateManager use individual operation locks. + * Delayed callbacks may acquire `stateLock.readLock` or any of the `txnMetadata` locks, + * but we always require that `stateLock.readLock` be acquired first. In particular: + *

+ * <ul>
+ *   <li>`stateLock.readLock` must never be acquired while holding `txnMetadata` lock.</li>
+ *   <li>`txnMetadata` lock must never be acquired while holding `stateLock.writeLock`.</li>
+ *   <li>`ReplicaManager.appendRecords` should never be invoked while holding a `txnMetadata` lock.</li>
+ * </ul>
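+ *
+ * For example, a callback that needs both locks nests them in this order (a minimal sketch using
+ * the `inReadLock` helper from `CoreUtils` and `TransactionMetadata#inLock`, with `txnMetadata`
+ * standing for any cached entry):
+ * {{{
+ * inReadLock(stateLock) {
+ *   txnMetadata.inLock {
+ *     // read or update the per-transactionalId metadata here
+ *   }
+ * }
+ * }}}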
            + */ +class TransactionStateManager(brokerId: Int, + scheduler: Scheduler, + replicaManager: ReplicaManager, + config: TransactionConfig, + time: Time, + metrics: Metrics) extends Logging { + + this.logIdent = "[Transaction State Manager " + brokerId + "]: " + + type SendTxnMarkersCallback = (Int, TransactionResult, TransactionMetadata, TxnTransitMetadata) => Unit + + /** shutting down flag */ + private val shuttingDown = new AtomicBoolean(false) + + /** lock protecting access to the transactional metadata cache, including loading and leaving partition sets */ + private val stateLock = new ReentrantReadWriteLock() + + /** partitions of transaction topic that are being loaded, state lock should be called BEFORE accessing this set */ + private[transaction] val loadingPartitions: mutable.Set[TransactionPartitionAndLeaderEpoch] = mutable.Set() + + /** transaction metadata cache indexed by assigned transaction topic partition ids */ + private[transaction] val transactionMetadataCache: mutable.Map[Int, TxnMetadataCacheEntry] = mutable.Map() + + /** number of partitions for the transaction log topic */ + private var retrieveTransactionTopicPartitionCount: () => Int = _ + @volatile private var transactionTopicPartitionCount: Int = _ + + /** setup metrics*/ + private val partitionLoadSensor = metrics.sensor(TransactionStateManager.LoadTimeSensor) + + partitionLoadSensor.add(metrics.metricName("partition-load-time-max", + TransactionStateManager.MetricsGroup, + "The max time it took to load the partitions in the last 30sec"), new Max()) + partitionLoadSensor.add(metrics.metricName("partition-load-time-avg", + TransactionStateManager.MetricsGroup, + "The avg time it took to load the partitions in the last 30sec"), new Avg()) + + // visible for testing only + private[transaction] def addLoadingPartition(partitionId: Int, coordinatorEpoch: Int): Unit = { + val partitionAndLeaderEpoch = TransactionPartitionAndLeaderEpoch(partitionId, coordinatorEpoch) + inWriteLock(stateLock) { + loadingPartitions.add(partitionAndLeaderEpoch) + } + } + + // this is best-effort expiration of an ongoing transaction which has been open for more than its + // txn timeout value, we do not need to grab the lock on the metadata object upon checking its state + // since the timestamp is volatile and we will get the lock when actually trying to transit the transaction + // metadata to abort later. 
+ def timedOutTransactions(): Iterable[TransactionalIdAndProducerIdEpoch] = { + val now = time.milliseconds() + inReadLock(stateLock) { + transactionMetadataCache.flatMap { case (_, entry) => + entry.metadataPerTransactionalId.filter { case (_, txnMetadata) => + if (txnMetadata.pendingTransitionInProgress) { + false + } else { + txnMetadata.state match { + case Ongoing => + txnMetadata.txnStartTimestamp + txnMetadata.txnTimeoutMs < now + case _ => false + } + } + }.map { case (txnId, txnMetadata) => + TransactionalIdAndProducerIdEpoch(txnId, txnMetadata.producerId, txnMetadata.producerEpoch) + } + } + } + } + + private def removeExpiredTransactionalIds( + transactionPartition: TopicPartition, + txnMetadataCacheEntry: TxnMetadataCacheEntry, + ): Unit = { + inReadLock(stateLock) { + replicaManager.getLogConfig(transactionPartition) match { + case Some(logConfig) => + val currentTimeMs = time.milliseconds() + val maxBatchSize = logConfig.maxMessageSize + val expired = mutable.ListBuffer.empty[TransactionalIdCoordinatorEpochAndMetadata] + var recordsBuilder: MemoryRecordsBuilder = null + val stateEntries = txnMetadataCacheEntry.metadataPerTransactionalId.values.iterator.buffered + + def flushRecordsBuilder(): Unit = { + writeTombstonesForExpiredTransactionalIds( + transactionPartition, + expired.toSeq, + recordsBuilder.build() + ) + expired.clear() + recordsBuilder = null + } + + while (stateEntries.hasNext) { + val txnMetadata = stateEntries.head + val transactionalId = txnMetadata.transactionalId + var fullBatch = false + + txnMetadata.inLock { + if (txnMetadata.pendingState.isEmpty && shouldExpire(txnMetadata, currentTimeMs)) { + if (recordsBuilder == null) { + recordsBuilder = MemoryRecords.builder( + ByteBuffer.allocate(math.min(16384, maxBatchSize)), + TransactionLog.EnforcedCompressionType, + TimestampType.CREATE_TIME, + 0L, + maxBatchSize + ) + } + + if (maybeAppendExpiration(txnMetadata, recordsBuilder, currentTimeMs)) { + val transitMetadata = txnMetadata.prepareDead() + expired += TransactionalIdCoordinatorEpochAndMetadata( + transactionalId, + txnMetadataCacheEntry.coordinatorEpoch, + transitMetadata + ) + } else { + fullBatch = true + } + } + } + + if (fullBatch) { + flushRecordsBuilder() + } else { + // Advance the iterator if we do not need to retry the append + stateEntries.next() + } + } + + if (expired.nonEmpty) { + flushRecordsBuilder() + } + + case None => + warn(s"Transaction expiration for partition $transactionPartition failed because the log " + + "config was not available, which likely means the partition is not online or is no longer local.") + } + } + } + + private def shouldExpire( + txnMetadata: TransactionMetadata, + currentTimeMs: Long + ): Boolean = { + txnMetadata.state.isExpirationAllowed && + txnMetadata.txnLastUpdateTimestamp <= currentTimeMs - config.transactionalIdExpirationMs + } + + private def maybeAppendExpiration( + txnMetadata: TransactionMetadata, + recordsBuilder: MemoryRecordsBuilder, + currentTimeMs: Long, + ): Boolean = { + val keyBytes = TransactionLog.keyToBytes(txnMetadata.transactionalId) + if (recordsBuilder.hasRoomFor(currentTimeMs, keyBytes, null, Record.EMPTY_HEADERS)) { + recordsBuilder.append(currentTimeMs, keyBytes, null, Record.EMPTY_HEADERS) + true + } else { + false + } + } + + private[transaction] def removeExpiredTransactionalIds(): Unit = { + inReadLock(stateLock) { + transactionMetadataCache.forKeyValue { (partitionId, partitionCacheEntry) => + val transactionPartition = new TopicPartition(Topic.TRANSACTION_STATE_TOPIC_NAME, 
partitionId) + removeExpiredTransactionalIds(transactionPartition, partitionCacheEntry) + } + } + } + + private def writeTombstonesForExpiredTransactionalIds( + transactionPartition: TopicPartition, + expiredForPartition: Iterable[TransactionalIdCoordinatorEpochAndMetadata], + tombstoneRecords: MemoryRecords + ): Unit = { + def removeFromCacheCallback(responses: collection.Map[TopicPartition, PartitionResponse]): Unit = { + responses.forKeyValue { (topicPartition, response) => + inReadLock(stateLock) { + transactionMetadataCache.get(topicPartition.partition).foreach { txnMetadataCacheEntry => + expiredForPartition.foreach { idCoordinatorEpochAndMetadata => + val transactionalId = idCoordinatorEpochAndMetadata.transactionalId + val txnMetadata = txnMetadataCacheEntry.metadataPerTransactionalId.get(transactionalId) + txnMetadata.inLock { + if (txnMetadataCacheEntry.coordinatorEpoch == idCoordinatorEpochAndMetadata.coordinatorEpoch + && txnMetadata.pendingState.contains(Dead) + && txnMetadata.producerEpoch == idCoordinatorEpochAndMetadata.transitMetadata.producerEpoch + && response.error == Errors.NONE) { + txnMetadataCacheEntry.metadataPerTransactionalId.remove(transactionalId) + } else { + warn(s"Failed to remove expired transactionalId: $transactionalId" + + s" from cache. Tombstone append error code: ${response.error}," + + s" pendingState: ${txnMetadata.pendingState}, producerEpoch: ${txnMetadata.producerEpoch}," + + s" expected producerEpoch: ${idCoordinatorEpochAndMetadata.transitMetadata.producerEpoch}," + + s" coordinatorEpoch: ${txnMetadataCacheEntry.coordinatorEpoch}, expected coordinatorEpoch: " + + s"${idCoordinatorEpochAndMetadata.coordinatorEpoch}") + txnMetadata.pendingState = None + } + } + } + } + } + } + } + + inReadLock(stateLock) { + replicaManager.appendRecords( + config.requestTimeoutMs, + TransactionLog.EnforcedRequiredAcks, + internalTopicsAllowed = true, + origin = AppendOrigin.Coordinator, + entriesPerPartition = Map(transactionPartition -> tombstoneRecords), + removeFromCacheCallback, + requestLocal = RequestLocal.NoCaching) + } + } + + def enableTransactionalIdExpiration(): Unit = { + scheduler.schedule( + name = "transactionalId-expiration", + fun = removeExpiredTransactionalIds, + delay = config.removeExpiredTransactionalIdsIntervalMs, + period = config.removeExpiredTransactionalIdsIntervalMs + ) + } + + def getTransactionState(transactionalId: String): Either[Errors, Option[CoordinatorEpochAndTxnMetadata]] = { + getAndMaybeAddTransactionState(transactionalId, None) + } + + def putTransactionStateIfNotExists(txnMetadata: TransactionMetadata): Either[Errors, CoordinatorEpochAndTxnMetadata] = { + getAndMaybeAddTransactionState(txnMetadata.transactionalId, Some(txnMetadata)).map(_.getOrElse( + throw new IllegalStateException(s"Unexpected empty transaction metadata returned while putting $txnMetadata"))) + } + + def listTransactionStates( + filterProducerIds: Set[Long], + filterStateNames: Set[String] + ): ListTransactionsResponseData = { + inReadLock(stateLock) { + val response = new ListTransactionsResponseData() + if (loadingPartitions.nonEmpty) { + response.setErrorCode(Errors.COORDINATOR_LOAD_IN_PROGRESS.code) + } else { + val filterStates = mutable.Set.empty[TransactionState] + filterStateNames.foreach { stateName => + TransactionState.fromName(stateName) match { + case Some(state) => filterStates += state + case None => response.unknownStateFilters.add(stateName) + } + } + + def shouldInclude(txnMetadata: TransactionMetadata): Boolean = { + if 
(txnMetadata.state == Dead) { + // We filter the `Dead` state since it is a transient state which + // indicates that the transactionalId and its metadata are in the + // process of expiration and removal. + false + } else if (filterProducerIds.nonEmpty && !filterProducerIds.contains(txnMetadata.producerId)) { + false + } else if (filterStateNames.nonEmpty && !filterStates.contains(txnMetadata.state)) { + false + } else { + true + } + } + + val states = new java.util.ArrayList[ListTransactionsResponseData.TransactionState] + transactionMetadataCache.forKeyValue { (_, cache) => + cache.metadataPerTransactionalId.values.foreach { txnMetadata => + txnMetadata.inLock { + if (shouldInclude(txnMetadata)) { + states.add(new ListTransactionsResponseData.TransactionState() + .setTransactionalId(txnMetadata.transactionalId) + .setProducerId(txnMetadata.producerId) + .setTransactionState(txnMetadata.state.name) + ) + } + } + } + } + response.setErrorCode(Errors.NONE.code) + .setTransactionStates(states) + } + } + } + + /** + * Get the transaction metadata associated with the given transactional id, or an error if + * the coordinator does not own the transaction partition or is still loading it; if not found + * either return None or create a new metadata and added to the cache + * + * This function is covered by the state read lock + */ + private def getAndMaybeAddTransactionState(transactionalId: String, + createdTxnMetadataOpt: Option[TransactionMetadata]): Either[Errors, Option[CoordinatorEpochAndTxnMetadata]] = { + inReadLock(stateLock) { + val partitionId = partitionFor(transactionalId) + if (loadingPartitions.exists(_.txnPartitionId == partitionId)) + Left(Errors.COORDINATOR_LOAD_IN_PROGRESS) + else { + transactionMetadataCache.get(partitionId) match { + case Some(cacheEntry) => + val txnMetadata = Option(cacheEntry.metadataPerTransactionalId.get(transactionalId)).orElse { + createdTxnMetadataOpt.map { createdTxnMetadata => + Option(cacheEntry.metadataPerTransactionalId.putIfNotExists(transactionalId, createdTxnMetadata)) + .getOrElse(createdTxnMetadata) + } + } + Right(txnMetadata.map(CoordinatorEpochAndTxnMetadata(cacheEntry.coordinatorEpoch, _))) + + case None => + Left(Errors.NOT_COORDINATOR) + } + } + } + } + + /** + * Validate the given transaction timeout value + */ + def validateTransactionTimeoutMs(txnTimeoutMs: Int): Boolean = + txnTimeoutMs <= config.transactionMaxTimeoutMs && txnTimeoutMs > 0 + + def transactionTopicConfigs: Properties = { + val props = new Properties + + // enforce disabled unclean leader election, no compression types, and compact cleanup policy + props.put(LogConfig.UncleanLeaderElectionEnableProp, "false") + props.put(LogConfig.CompressionTypeProp, UncompressedCodec.name) + props.put(LogConfig.CleanupPolicyProp, LogConfig.Compact) + props.put(LogConfig.MinInSyncReplicasProp, config.transactionLogMinInsyncReplicas.toString) + props.put(LogConfig.SegmentBytesProp, config.transactionLogSegmentBytes.toString) + + props + } + + def partitionFor(transactionalId: String): Int = Utils.abs(transactionalId.hashCode) % transactionTopicPartitionCount + + private def loadTransactionMetadata(topicPartition: TopicPartition, coordinatorEpoch: Int): Pool[String, TransactionMetadata] = { + def logEndOffset = replicaManager.getLogEndOffset(topicPartition).getOrElse(-1L) + + val loadedTransactions = new Pool[String, TransactionMetadata] + + replicaManager.getLog(topicPartition) match { + case None => + warn(s"Attempted to load transaction metadata from $topicPartition, but found no 
log") + + case Some(log) => + // buffer may not be needed if records are read from memory + var buffer = ByteBuffer.allocate(0) + + // loop breaks if leader changes at any time during the load, since logEndOffset is -1 + var currOffset = log.logStartOffset + + // loop breaks if no records have been read, since the end of the log has been reached + var readAtLeastOneRecord = true + + try { + while (currOffset < logEndOffset && readAtLeastOneRecord && !shuttingDown.get() && inReadLock(stateLock) { + loadingPartitions.exists { idAndEpoch: TransactionPartitionAndLeaderEpoch => + idAndEpoch.txnPartitionId == topicPartition.partition && idAndEpoch.coordinatorEpoch == coordinatorEpoch}}) { + val fetchDataInfo = log.read(currOffset, + maxLength = config.transactionLogLoadBufferSize, + isolation = FetchLogEnd, + minOneMessage = true) + + readAtLeastOneRecord = fetchDataInfo.records.sizeInBytes > 0 + + val memRecords = (fetchDataInfo.records: @unchecked) match { + case records: MemoryRecords => records + case fileRecords: FileRecords => + val sizeInBytes = fileRecords.sizeInBytes + val bytesNeeded = Math.max(config.transactionLogLoadBufferSize, sizeInBytes) + + // minOneMessage = true in the above log.read means that the buffer may need to be grown to ensure progress can be made + if (buffer.capacity < bytesNeeded) { + if (config.transactionLogLoadBufferSize < bytesNeeded) + warn(s"Loaded transaction metadata from $topicPartition with buffer larger ($bytesNeeded bytes) than " + + s"configured transaction.state.log.load.buffer.size (${config.transactionLogLoadBufferSize} bytes)") + + buffer = ByteBuffer.allocate(bytesNeeded) + } else { + buffer.clear() + } + buffer.clear() + fileRecords.readInto(buffer, 0) + MemoryRecords.readableRecords(buffer) + } + + memRecords.batches.forEach { batch => + for (record <- batch.asScala) { + require(record.hasKey, "Transaction state log's key should not be null") + val txnKey = TransactionLog.readTxnRecordKey(record.key) + // load transaction metadata along with transaction state + val transactionalId = txnKey.transactionalId + TransactionLog.readTxnRecordValue(transactionalId, record.value) match { + case None => + loadedTransactions.remove(transactionalId) + case Some(txnMetadata) => + loadedTransactions.put(transactionalId, txnMetadata) + } + currOffset = batch.nextOffset + } + } + } + } catch { + case t: Throwable => error(s"Error loading transactions from transaction log $topicPartition", t) + } + } + + loadedTransactions + } + + /** + * Add a transaction topic partition into the cache + */ + private[transaction] def addLoadedTransactionsToCache(txnTopicPartition: Int, + coordinatorEpoch: Int, + loadedTransactions: Pool[String, TransactionMetadata]): Unit = { + val txnMetadataCacheEntry = TxnMetadataCacheEntry(coordinatorEpoch, loadedTransactions) + val previousTxnMetadataCacheEntryOpt = transactionMetadataCache.put(txnTopicPartition, txnMetadataCacheEntry) + + previousTxnMetadataCacheEntryOpt.foreach { previousTxnMetadataCacheEntry => + warn(s"Unloaded transaction metadata $previousTxnMetadataCacheEntry from $txnTopicPartition as part of " + + s"loading metadata at epoch $coordinatorEpoch") + } + } + + /** + * When this broker becomes a leader for a transaction log partition, load this partition and populate the transaction + * metadata cache with the transactional ids. This operation must be resilient to any partial state left off from + * the previous loading / unloading operation. 
+ */ + def loadTransactionsForTxnTopicPartition(partitionId: Int, coordinatorEpoch: Int, sendTxnMarkers: SendTxnMarkersCallback): Unit = { + val topicPartition = new TopicPartition(Topic.TRANSACTION_STATE_TOPIC_NAME, partitionId) + val partitionAndLeaderEpoch = TransactionPartitionAndLeaderEpoch(partitionId, coordinatorEpoch) + + inWriteLock(stateLock) { + loadingPartitions.add(partitionAndLeaderEpoch) + } + + def loadTransactions(startTimeMs: java.lang.Long): Unit = { + val schedulerTimeMs = time.milliseconds() - startTimeMs + info(s"Loading transaction metadata from $topicPartition at epoch $coordinatorEpoch") + validateTransactionTopicPartitionCountIsStable() + + val loadedTransactions = loadTransactionMetadata(topicPartition, coordinatorEpoch) + val endTimeMs = time.milliseconds() + val totalLoadingTimeMs = endTimeMs - startTimeMs + partitionLoadSensor.record(totalLoadingTimeMs.toDouble, endTimeMs, false) + info(s"Finished loading ${loadedTransactions.size} transaction metadata from $topicPartition in " + + s"$totalLoadingTimeMs milliseconds, of which $schedulerTimeMs milliseconds was spent in the scheduler.") + + inWriteLock(stateLock) { + if (loadingPartitions.contains(partitionAndLeaderEpoch)) { + addLoadedTransactionsToCache(topicPartition.partition, coordinatorEpoch, loadedTransactions) + + val transactionsPendingForCompletion = new mutable.ListBuffer[TransactionalIdCoordinatorEpochAndTransitMetadata] + loadedTransactions.foreach { + case (transactionalId, txnMetadata) => + txnMetadata.inLock { + // if state is PrepareCommit or PrepareAbort we need to complete the transaction + txnMetadata.state match { + case PrepareAbort => + transactionsPendingForCompletion += + TransactionalIdCoordinatorEpochAndTransitMetadata(transactionalId, coordinatorEpoch, TransactionResult.ABORT, txnMetadata, txnMetadata.prepareComplete(time.milliseconds())) + case PrepareCommit => + transactionsPendingForCompletion += + TransactionalIdCoordinatorEpochAndTransitMetadata(transactionalId, coordinatorEpoch, TransactionResult.COMMIT, txnMetadata, txnMetadata.prepareComplete(time.milliseconds())) + case _ => + // nothing needs to be done + } + } + } + + // we first remove the partition from loading partition then send out the markers for those pending to be + // completed transactions, so that when the markers get sent the attempt of appending the complete transaction + // log would not be blocked by the coordinator loading error + loadingPartitions.remove(partitionAndLeaderEpoch) + + transactionsPendingForCompletion.foreach { txnTransitMetadata => + sendTxnMarkers(txnTransitMetadata.coordinatorEpoch, txnTransitMetadata.result, + txnTransitMetadata.txnMetadata, txnTransitMetadata.transitMetadata) + } + } + } + + info(s"Completed loading transaction metadata from $topicPartition for coordinator epoch $coordinatorEpoch") + } + + val scheduleStartMs = time.milliseconds() + scheduler.schedule(s"load-txns-for-partition-$topicPartition", () => loadTransactions(scheduleStartMs)) + } + + def removeTransactionsForTxnTopicPartition(partitionId: Int): Unit = { + val topicPartition = new TopicPartition(Topic.TRANSACTION_STATE_TOPIC_NAME, partitionId) + inWriteLock(stateLock) { + loadingPartitions --= loadingPartitions.filter(_.txnPartitionId == partitionId) + transactionMetadataCache.remove(partitionId).foreach { txnMetadataCacheEntry => + info(s"Unloaded transaction metadata $txnMetadataCacheEntry for $topicPartition following " + + s"local partition deletion") + } + } + } + + /** + * When this broker becomes a follower 
for a transaction log partition, clear out the cache for corresponding transactional ids + * that belong to that partition. + */ + def removeTransactionsForTxnTopicPartition(partitionId: Int, coordinatorEpoch: Int): Unit = { + val topicPartition = new TopicPartition(Topic.TRANSACTION_STATE_TOPIC_NAME, partitionId) + val partitionAndLeaderEpoch = TransactionPartitionAndLeaderEpoch(partitionId, coordinatorEpoch) + + inWriteLock(stateLock) { + loadingPartitions.remove(partitionAndLeaderEpoch) + transactionMetadataCache.remove(partitionId) match { + case Some(txnMetadataCacheEntry) => + info(s"Unloaded transaction metadata $txnMetadataCacheEntry for $topicPartition on become-follower transition") + + case None => + info(s"No cached transaction metadata found for $topicPartition during become-follower transition") + } + } + } + + private def validateTransactionTopicPartitionCountIsStable(): Unit = { + val previouslyDeterminedPartitionCount = transactionTopicPartitionCount + val curTransactionTopicPartitionCount = retrieveTransactionTopicPartitionCount() + if (previouslyDeterminedPartitionCount != curTransactionTopicPartitionCount) + throw new KafkaException(s"Transaction topic number of partitions has changed from $previouslyDeterminedPartitionCount to $curTransactionTopicPartitionCount") + } + + def appendTransactionToLog(transactionalId: String, + coordinatorEpoch: Int, + newMetadata: TxnTransitMetadata, + responseCallback: Errors => Unit, + retryOnError: Errors => Boolean = _ => false, + requestLocal: RequestLocal): Unit = { + + // generate the message for this transaction metadata + val keyBytes = TransactionLog.keyToBytes(transactionalId) + val valueBytes = TransactionLog.valueToBytes(newMetadata) + val timestamp = time.milliseconds() + + val records = MemoryRecords.withRecords(TransactionLog.EnforcedCompressionType, new SimpleRecord(timestamp, keyBytes, valueBytes)) + val topicPartition = new TopicPartition(Topic.TRANSACTION_STATE_TOPIC_NAME, partitionFor(transactionalId)) + val recordsPerPartition = Map(topicPartition -> records) + + // set the callback function to update transaction status in cache after log append completed + def updateCacheCallback(responseStatus: collection.Map[TopicPartition, PartitionResponse]): Unit = { + // the append response should only contain the topics partition + if (responseStatus.size != 1 || !responseStatus.contains(topicPartition)) + throw new IllegalStateException("Append status %s should only have one partition %s" + .format(responseStatus, topicPartition)) + + val status = responseStatus(topicPartition) + + var responseError = if (status.error == Errors.NONE) { + Errors.NONE + } else { + debug(s"Appending $transactionalId's new metadata $newMetadata failed due to ${status.error.exceptionName}") + + // transform the log append error code to the corresponding coordinator error code + status.error match { + case Errors.UNKNOWN_TOPIC_OR_PARTITION + | Errors.NOT_ENOUGH_REPLICAS + | Errors.NOT_ENOUGH_REPLICAS_AFTER_APPEND + | Errors.REQUEST_TIMED_OUT => // note that for timed out request we return NOT_AVAILABLE error code to let client retry + Errors.COORDINATOR_NOT_AVAILABLE + + case Errors.NOT_LEADER_OR_FOLLOWER + | Errors.KAFKA_STORAGE_ERROR => + Errors.NOT_COORDINATOR + + case Errors.MESSAGE_TOO_LARGE + | Errors.RECORD_LIST_TOO_LARGE => + Errors.UNKNOWN_SERVER_ERROR + + case other => + other + } + } + + if (responseError == Errors.NONE) { + // now try to update the cache: we need to update the status in-place instead of + // overwriting the whole 
object to ensure synchronization + getTransactionState(transactionalId) match { + + case Left(err) => + info(s"Accessing the cached transaction metadata for $transactionalId returns $err error; " + + s"aborting transition to the new metadata and setting the error in the callback") + responseError = err + case Right(Some(epochAndMetadata)) => + val metadata = epochAndMetadata.transactionMetadata + + metadata.inLock { + if (epochAndMetadata.coordinatorEpoch != coordinatorEpoch) { + // the cache may have been changed due to txn topic partition emigration and immigration, + // in this case directly return NOT_COORDINATOR to client and let it to re-discover the transaction coordinator + info(s"The cached coordinator epoch for $transactionalId has changed to ${epochAndMetadata.coordinatorEpoch} after appended its new metadata $newMetadata " + + s"to the transaction log (txn topic partition ${partitionFor(transactionalId)}) while it was $coordinatorEpoch before appending; " + + s"aborting transition to the new metadata and returning ${Errors.NOT_COORDINATOR} in the callback") + responseError = Errors.NOT_COORDINATOR + } else { + metadata.completeTransitionTo(newMetadata) + debug(s"Updating $transactionalId's transaction state to $newMetadata with coordinator epoch $coordinatorEpoch for $transactionalId succeeded") + } + } + + case Right(None) => + // this transactional id no longer exists, maybe the corresponding partition has already been migrated out. + // return NOT_COORDINATOR to let the client re-discover the transaction coordinator + info(s"The cached coordinator metadata does not exist in the cache anymore for $transactionalId after appended its new metadata $newMetadata " + + s"to the transaction log (txn topic partition ${partitionFor(transactionalId)}) while it was $coordinatorEpoch before appending; " + + s"aborting transition to the new metadata and returning ${Errors.NOT_COORDINATOR} in the callback") + responseError = Errors.NOT_COORDINATOR + } + } else { + // Reset the pending state when returning an error, since there is no active transaction for the transactional id at this point. + getTransactionState(transactionalId) match { + case Right(Some(epochAndTxnMetadata)) => + val metadata = epochAndTxnMetadata.transactionMetadata + metadata.inLock { + if (epochAndTxnMetadata.coordinatorEpoch == coordinatorEpoch) { + if (retryOnError(responseError)) { + info(s"TransactionalId ${metadata.transactionalId} append transaction log for $newMetadata transition failed due to $responseError, " + + s"not resetting pending state ${metadata.pendingState} but just returning the error in the callback to let the caller retry") + } else { + info(s"TransactionalId ${metadata.transactionalId} append transaction log for $newMetadata transition failed due to $responseError, " + + s"resetting pending state from ${metadata.pendingState}, aborting state transition and returning $responseError in the callback") + + metadata.pendingState = None + } + } else { + info(s"TransactionalId ${metadata.transactionalId} append transaction log for $newMetadata transition failed due to $responseError, " + + s"aborting state transition and returning the error in the callback since the coordinator epoch has changed from ${epochAndTxnMetadata.coordinatorEpoch} to $coordinatorEpoch") + } + } + + case Right(None) => + // Do nothing here, since we want to return the original append error to the user. 
+ info(s"TransactionalId $transactionalId append transaction log for $newMetadata transition failed due to $responseError, " + + s"aborting state transition and returning the error in the callback since metadata is not available in the cache anymore") + + case Left(error) => + // Do nothing here, since we want to return the original append error to the user. + info(s"TransactionalId $transactionalId append transaction log for $newMetadata transition failed due to $responseError, " + + s"aborting state transition and returning the error in the callback since retrieving metadata returned $error") + } + + } + + responseCallback(responseError) + } + + inReadLock(stateLock) { + // we need to hold the read lock on the transaction metadata cache until appending to local log returns; + // this is to avoid the case where an emigration followed by an immigration could have completed after the check + // returns and before appendRecords() is called, since otherwise entries with a high coordinator epoch could have + // been appended to the log in between these two events, and therefore appendRecords() would append entries with + // an old coordinator epoch that can still be successfully replicated on followers and make the log in a bad state. + getTransactionState(transactionalId) match { + case Left(err) => + responseCallback(err) + + case Right(None) => + // the coordinator metadata has been removed, reply to client immediately with NOT_COORDINATOR + responseCallback(Errors.NOT_COORDINATOR) + + case Right(Some(epochAndMetadata)) => + val metadata = epochAndMetadata.transactionMetadata + + val append: Boolean = metadata.inLock { + if (epochAndMetadata.coordinatorEpoch != coordinatorEpoch) { + // the coordinator epoch has changed, reply to client immediately with NOT_COORDINATOR + responseCallback(Errors.NOT_COORDINATOR) + false + } else { + // do not need to check the metadata object itself since no concurrent thread should be able to modify it + // under the same coordinator epoch, so directly append to txn log now + true + } + } + if (append) { + replicaManager.appendRecords( + newMetadata.txnTimeoutMs.toLong, + TransactionLog.EnforcedRequiredAcks, + internalTopicsAllowed = true, + origin = AppendOrigin.Coordinator, + recordsPerPartition, + updateCacheCallback, + requestLocal = requestLocal) + + trace(s"Appending new metadata $newMetadata for transaction id $transactionalId with coordinator epoch $coordinatorEpoch to the local transaction log") + } + } + } + } + + def startup(retrieveTransactionTopicPartitionCount: () => Int, enableTransactionalIdExpiration: Boolean): Unit = { + this.retrieveTransactionTopicPartitionCount = retrieveTransactionTopicPartitionCount + transactionTopicPartitionCount = retrieveTransactionTopicPartitionCount() + if (enableTransactionalIdExpiration) + this.enableTransactionalIdExpiration() + } + + def shutdown(): Unit = { + shuttingDown.set(true) + loadingPartitions.clear() + transactionMetadataCache.clear() + + info("Shutdown complete") + } +} + + +private[transaction] case class TxnMetadataCacheEntry(coordinatorEpoch: Int, + metadataPerTransactionalId: Pool[String, TransactionMetadata]) { + override def toString: String = { + s"TxnMetadataCacheEntry(coordinatorEpoch=$coordinatorEpoch, numTransactionalEntries=${metadataPerTransactionalId.size})" + } +} + +private[transaction] case class CoordinatorEpochAndTxnMetadata(coordinatorEpoch: Int, + transactionMetadata: TransactionMetadata) + +private[transaction] case class TransactionConfig(transactionalIdExpirationMs: Int = 
TransactionStateManager.DefaultTransactionalIdExpirationMs, + transactionMaxTimeoutMs: Int = TransactionStateManager.DefaultTransactionsMaxTimeoutMs, + transactionLogNumPartitions: Int = TransactionLog.DefaultNumPartitions, + transactionLogReplicationFactor: Short = TransactionLog.DefaultReplicationFactor, + transactionLogSegmentBytes: Int = TransactionLog.DefaultSegmentBytes, + transactionLogLoadBufferSize: Int = TransactionLog.DefaultLoadBufferSize, + transactionLogMinInsyncReplicas: Int = TransactionLog.DefaultMinInSyncReplicas, + abortTimedOutTransactionsIntervalMs: Int = TransactionStateManager.DefaultAbortTimedOutTransactionsIntervalMs, + removeExpiredTransactionalIdsIntervalMs: Int = TransactionStateManager.DefaultRemoveExpiredTransactionalIdsIntervalMs, + requestTimeoutMs: Int = Defaults.RequestTimeoutMs) + +case class TransactionalIdAndProducerIdEpoch(transactionalId: String, producerId: Long, producerEpoch: Short) { + override def toString: String = { + s"(transactionalId=$transactionalId, producerId=$producerId, producerEpoch=$producerEpoch)" + } +} + +case class TransactionPartitionAndLeaderEpoch(txnPartitionId: Int, coordinatorEpoch: Int) + +case class TransactionalIdCoordinatorEpochAndMetadata(transactionalId: String, coordinatorEpoch: Int, transitMetadata: TxnTransitMetadata) + +case class TransactionalIdCoordinatorEpochAndTransitMetadata(transactionalId: String, coordinatorEpoch: Int, result: TransactionResult, txnMetadata: TransactionMetadata, transitMetadata: TxnTransitMetadata) diff --git a/core/src/main/scala/kafka/log/AbstractIndex.scala b/core/src/main/scala/kafka/log/AbstractIndex.scala new file mode 100644 index 0000000..31b9f6d --- /dev/null +++ b/core/src/main/scala/kafka/log/AbstractIndex.scala @@ -0,0 +1,440 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.log + +import java.io.{Closeable, File, RandomAccessFile} +import java.nio.channels.FileChannel +import java.nio.file.Files +import java.nio.{ByteBuffer, MappedByteBuffer} +import java.util.concurrent.locks.{Lock, ReentrantLock} + +import kafka.common.IndexOffsetOverflowException +import kafka.utils.CoreUtils.inLock +import kafka.utils.{CoreUtils, Logging} +import org.apache.kafka.common.utils.{ByteBufferUnmapper, OperatingSystem, Utils} + +/** + * The abstract index class which holds entry format agnostic methods. + * + * @param _file The index file + * @param baseOffset the base offset of the segment that this index is corresponding to. + * @param maxIndexSize The maximum index size in bytes. 
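+ * @param writable Whether the index is opened for writing; when true the file is memory-mapped read-write, otherwise it is mapped read-only.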
+ */ +abstract class AbstractIndex(@volatile private var _file: File, val baseOffset: Long, val maxIndexSize: Int = -1, + val writable: Boolean) extends Closeable { + import AbstractIndex._ + + // Length of the index file + @volatile + private var _length: Long = _ + protected def entrySize: Int + + /* + Kafka mmaps index files into memory, and all read / write operations on the index go through the OS page cache. This + avoids blocked disk I/O in most cases. + + To the best of our knowledge, all modern operating systems use the LRU policy or one of its variants to manage page + cache. Kafka always appends to the end of the index file, and almost all the index lookups (typically from in-sync + followers or consumers) are very close to the end of the index. So, the LRU cache replacement policy should work very + well with Kafka's index access pattern. + + However, when looking up the index, the standard binary search algorithm is not cache friendly, and can cause unnecessary + page faults (the thread blocks waiting to read some index entries from hard disk, as those entries are not + cached in the page cache). + + For example, in an index with 13 pages, to look up an entry in the last page (page #12), the standard binary search + algorithm will read index entries in page #0, 6, 9, 11, and 12. + page number: |0|1|2|3|4|5|6|7|8|9|10|11|12 | + steps: |1| | | | | |3| | |4| |5 |2/6| + Each page holds hundreds of log entries, corresponding to hundreds to thousands of kafka messages. As the + index gradually grows from the 1st entry in page #12 to the last entry in page #12, all the write (append) + operations are in page #12, and all the in-sync follower / consumer lookups read pages #0,6,9,11,12. As these pages + are always used in each in-sync lookup, we can assume these pages are fairly recently used, and are very likely to be + in the page cache. When the index grows to page #13, the pages needed in an in-sync lookup change to #0, 7, 10, 12, + and 13: + page number: |0|1|2|3|4|5|6|7|8|9|10|11|12|13 | + steps: |1| | | | | | |3| | | 4|5 | 6|2/7| + Page #7 and page #10 have not been used for a very long time. They are much less likely to be in the page cache than + the other pages. The 1st lookup, after the 1st index entry in page #13 is appended, is likely to have to read page #7 + and page #10 from disk (page fault), which can take more than a second. In our test, this can cause the + at-least-once produce latency to jump to about 1 second from a few ms. + + Here, we use a more cache-friendly lookup algorithm: + if (target > indexEntry[end - N]) // if the target is in the last N entries of the index + binarySearch(end - N, end) + else + binarySearch(begin, end - N) + + If possible, we only look up in the last N entries of the index. By choosing a proper constant N, all the in-sync + lookups should go to the 1st branch. We call the last N entries the "warm" section. As we frequently look up in this + relatively small section, the pages containing this section are more likely to be in the page cache. + + We size the warm section at 8192 bytes (i.e. N = _warmEntries = 8192 / entrySize), because + 1. This number is small enough to guarantee that all the pages of the "warm" section are touched in every warm-section + lookup, so the entire warm section is really "warm". + When doing a warm-section lookup, the following 3 entries are always touched: indexEntry(end), indexEntry(end-N), + and indexEntry((end*2 -N)/2). If page size >= 4096, all the warm-section pages (3 or fewer) are touched when we + touch those 3 entries.
As of 2018, 4096 is the smallest page size for all the processors (x86-32, x86-64, MIPS, + SPARC, Power, ARM etc.). + 2. This number is large enough to guarantee that most of the in-sync lookups are in the warm section. With default Kafka + settings, an 8KB index corresponds to about 4MB (offset index) or 2.7MB (time index) of log messages. + + We can't make N (_warmEntries) larger than 8192, as there is no simple way to guarantee all the "warm" + section pages are really warm (touched in every lookup) on a typical 4KB-page host. + + In the future, we may use a background thread to periodically touch the entire warm section, so that we can + 1) support a larger warm section + 2) make sure the warm section of low-QPS topic-partitions is really warm. + */ + protected def _warmEntries: Int = 8192 / entrySize + + protected val lock = new ReentrantLock + + @volatile + protected var mmap: MappedByteBuffer = { + val newlyCreated = file.createNewFile() + val raf = if (writable) new RandomAccessFile(file, "rw") else new RandomAccessFile(file, "r") + try { + /* pre-allocate the file if necessary */ + if(newlyCreated) { + if(maxIndexSize < entrySize) + throw new IllegalArgumentException("Invalid max index size: " + maxIndexSize) + raf.setLength(roundDownToExactMultiple(maxIndexSize, entrySize)) + } + + /* memory-map the file */ + _length = raf.length() + val idx = { + if (writable) + raf.getChannel.map(FileChannel.MapMode.READ_WRITE, 0, _length) + else + raf.getChannel.map(FileChannel.MapMode.READ_ONLY, 0, _length) + } + /* set the position in the index for the next entry */ + if(newlyCreated) + idx.position(0) + else + // if this is a pre-existing index, assume it is valid and set position to last entry + idx.position(roundDownToExactMultiple(idx.limit(), entrySize)) + idx + } finally { + CoreUtils.swallow(raf.close(), AbstractIndex) + } + } + + /** + * The maximum number of entries this index can hold + */ + @volatile + private[this] var _maxEntries: Int = mmap.limit() / entrySize + + /** The number of entries in this index */ + @volatile + protected var _entries: Int = mmap.position() / entrySize + + /** + * True iff there are no more slots available in this index + */ + def isFull: Boolean = _entries >= _maxEntries + + def file: File = _file + + def maxEntries: Int = _maxEntries + + def entries: Int = _entries + + def length: Long = _length + + def updateParentDir(parentDir: File): Unit = _file = new File(parentDir, file.getName) + + /** + * Reset the size of the memory map and the underlying file. This is used in two kinds of cases: (1) in + * trimToValidSize(), which is called when the segment is closed or a new segment is rolled; (2) when + * loading segments from disk or truncating back to an old segment where a new log segment became active; + * we want to reset the index size to the maximum index size to avoid rolling a new segment. + * + * @param newSize new size of the index file + * @return a boolean indicating whether the size of the memory map and the underlying file changed or not.
+ */ + def resize(newSize: Int): Boolean = { + inLock(lock) { + val roundedNewSize = roundDownToExactMultiple(newSize, entrySize) + + if (_length == roundedNewSize) { + debug(s"Index ${file.getAbsolutePath} was not resized because it already has size $roundedNewSize") + false + } else { + val raf = new RandomAccessFile(file, "rw") + try { + val position = mmap.position() + + /* Windows or z/OS won't let us modify the file length while the file is mmapped :-( */ + if (OperatingSystem.IS_WINDOWS || OperatingSystem.IS_ZOS) + safeForceUnmap() + raf.setLength(roundedNewSize) + _length = roundedNewSize + mmap = raf.getChannel().map(FileChannel.MapMode.READ_WRITE, 0, roundedNewSize) + _maxEntries = mmap.limit() / entrySize + mmap.position(position) + debug(s"Resized ${file.getAbsolutePath} to $roundedNewSize, position is ${mmap.position()} " + + s"and limit is ${mmap.limit()}") + true + } finally { + CoreUtils.swallow(raf.close(), AbstractIndex) + } + } + } + } + + /** + * Rename the file that backs this offset index + * + * @throws IOException if rename fails + */ + def renameTo(f: File): Unit = { + try Utils.atomicMoveWithFallback(file.toPath, f.toPath, false) + finally _file = f + } + + /** + * Flush the data in the index to disk + */ + def flush(): Unit = { + inLock(lock) { + mmap.force() + } + } + + /** + * Delete this index file. + * + * @throws IOException if deletion fails due to an I/O error + * @return `true` if the file was deleted by this method; `false` if the file could not be deleted because it did + * not exist + */ + def deleteIfExists(): Boolean = { + closeHandler() + Files.deleteIfExists(file.toPath) + } + + /** + * Trim this segment to fit just the valid entries, deleting all trailing unwritten bytes from + * the file. + */ + def trimToValidSize(): Unit = { + inLock(lock) { + resize(entrySize * _entries) + } + } + + /** + * The number of bytes actually used by this index + */ + def sizeInBytes: Int = entrySize * _entries + + /** Close the index */ + def close(): Unit = { + trimToValidSize() + closeHandler() + } + + def closeHandler(): Unit = { + // On JVM, a memory mapping is typically unmapped by garbage collector. + // However, in some cases it can pause application threads(STW) for a long moment reading metadata from a physical disk. + // To prevent this, we forcefully cleanup memory mapping within proper execution which never affects API responsiveness. + // See https://issues.apache.org/jira/browse/KAFKA-4614 for the details. + inLock(lock) { + safeForceUnmap() + } + } + + /** + * Do a basic sanity check on this index to detect obvious problems + * + * @throws CorruptIndexException if any problems are found + */ + def sanityCheck(): Unit + + /** + * Remove all the entries from the index. + */ + protected def truncate(): Unit + + /** + * Remove all entries from the index which have an offset greater than or equal to the given offset. + * Truncating to an offset larger than the largest in the index has no effect. + */ + def truncateTo(offset: Long): Unit + + /** + * Remove all the entries from the index and resize the index to the max index size. 
+ */ + def reset(): Unit = { + truncate() + resize(maxIndexSize) + } + + /** + * Get offset relative to base offset of this index + * @throws IndexOffsetOverflowException + */ + def relativeOffset(offset: Long): Int = { + val relativeOffset = toRelative(offset) + if (relativeOffset.isEmpty) + throw new IndexOffsetOverflowException(s"Integer overflow for offset: $offset (${file.getAbsoluteFile})") + relativeOffset.get + } + + /** + * Check if a particular offset is valid to be appended to this index. + * @param offset The offset to check + * @return true if this offset is valid to be appended to this index; false otherwise + */ + def canAppendOffset(offset: Long): Boolean = { + toRelative(offset).isDefined + } + + protected def safeForceUnmap(): Unit = { + if (mmap != null) { + try forceUnmap() + catch { + case t: Throwable => error(s"Error unmapping index $file", t) + } + } + } + + /** + * Forcefully free the buffer's mmap. + */ + protected[log] def forceUnmap(): Unit = { + try ByteBufferUnmapper.unmap(file.getAbsolutePath, mmap) + finally mmap = null // Accessing unmapped mmap crashes JVM by SEGV so we null it out to be safe + } + + /** + * Execute the given function in a lock only if we are running on windows or z/OS. We do this + * because Windows or z/OS won't let us resize a file while it is mmapped. As a result we have to force unmap it + * and this requires synchronizing reads. + */ + protected def maybeLock[T](lock: Lock)(fun: => T): T = { + if (OperatingSystem.IS_WINDOWS || OperatingSystem.IS_ZOS) + lock.lock() + try fun + finally { + if (OperatingSystem.IS_WINDOWS || OperatingSystem.IS_ZOS) + lock.unlock() + } + } + + /** + * To parse an entry in the index. + * + * @param buffer the buffer of this memory mapped index. + * @param n the slot + * @return the index entry stored in the given slot. + */ + protected def parseEntry(buffer: ByteBuffer, n: Int): IndexEntry + + /** + * Find the slot in which the largest entry less than or equal to the given target key or value is stored. + * The comparison is made using the `IndexEntry.compareTo()` method. + * + * @param idx The index buffer + * @param target The index key to look for + * @return The slot found or -1 if the least entry in the index is larger than the target key or the index is empty + */ + protected def largestLowerBoundSlotFor(idx: ByteBuffer, target: Long, searchEntity: IndexSearchType): Int = + indexSlotRangeFor(idx, target, searchEntity)._1 + + /** + * Find the smallest entry greater than or equal the target key or value. If none can be found, -1 is returned. + */ + protected def smallestUpperBoundSlotFor(idx: ByteBuffer, target: Long, searchEntity: IndexSearchType): Int = + indexSlotRangeFor(idx, target, searchEntity)._2 + + /** + * Lookup lower and upper bounds for the given target. 
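+ * The first element of the returned pair is the slot of the largest entry less than or equal to the target (-1 if the + * index is empty or every entry is greater than the target); the second is the slot of the smallest entry greater than + * or equal to the target (-1 if no such entry exists).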
+ */ + private def indexSlotRangeFor(idx: ByteBuffer, target: Long, searchEntity: IndexSearchType): (Int, Int) = { + // check if the index is empty + if(_entries == 0) + return (-1, -1) + + def binarySearch(begin: Int, end: Int) : (Int, Int) = { + // binary search for the entry + var lo = begin + var hi = end + while(lo < hi) { + val mid = (lo + hi + 1) >>> 1 + val found = parseEntry(idx, mid) + val compareResult = compareIndexEntry(found, target, searchEntity) + if(compareResult > 0) + hi = mid - 1 + else if(compareResult < 0) + lo = mid + else + return (mid, mid) + } + (lo, if (lo == _entries - 1) -1 else lo + 1) + } + + val firstHotEntry = Math.max(0, _entries - 1 - _warmEntries) + // check if the target offset is in the warm section of the index + if(compareIndexEntry(parseEntry(idx, firstHotEntry), target, searchEntity) < 0) { + return binarySearch(firstHotEntry, _entries - 1) + } + + // check if the target offset is smaller than the least offset + if(compareIndexEntry(parseEntry(idx, 0), target, searchEntity) > 0) + return (-1, 0) + + binarySearch(0, firstHotEntry) + } + + private def compareIndexEntry(indexEntry: IndexEntry, target: Long, searchEntity: IndexSearchType): Int = { + searchEntity match { + case IndexSearchType.KEY => java.lang.Long.compare(indexEntry.indexKey, target) + case IndexSearchType.VALUE => java.lang.Long.compare(indexEntry.indexValue, target) + } + } + + /** + * Round a number to the greatest exact multiple of the given factor less than the given number. + * E.g. roundDownToExactMultiple(67, 8) == 64 + */ + private def roundDownToExactMultiple(number: Int, factor: Int) = factor * (number / factor) + + private def toRelative(offset: Long): Option[Int] = { + val relativeOffset = offset - baseOffset + if (relativeOffset < 0 || relativeOffset > Int.MaxValue) + None + else + Some(relativeOffset.toInt) + } + +} + +object AbstractIndex extends Logging { + override val loggerName: String = classOf[AbstractIndex].getName +} + +sealed trait IndexSearchType +object IndexSearchType { + case object KEY extends IndexSearchType + case object VALUE extends IndexSearchType +} diff --git a/core/src/main/scala/kafka/log/CleanerConfig.scala b/core/src/main/scala/kafka/log/CleanerConfig.scala new file mode 100644 index 0000000..782bc9a --- /dev/null +++ b/core/src/main/scala/kafka/log/CleanerConfig.scala @@ -0,0 +1,41 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package kafka.log + +/** + * Configuration parameters for the log cleaner + * + * @param numThreads The number of cleaner threads to run + * @param dedupeBufferSize The total memory used for log deduplication + * @param dedupeBufferLoadFactor The maximum percent full for the deduplication buffer + * @param maxMessageSize The maximum size of a message that can appear in the log + * @param maxIoBytesPerSecond The maximum read and write I/O that all cleaner threads are allowed to do + * @param backOffMs The amount of time to wait before rechecking if no logs are eligible for cleaning + * @param enableCleaner Allows completely disabling the log cleaner + * @param hashAlgorithm The hash algorithm to use in key comparison. + */ +case class CleanerConfig(numThreads: Int = 1, + dedupeBufferSize: Long = 4*1024*1024L, + dedupeBufferLoadFactor: Double = 0.9d, + ioBufferSize: Int = 1024*1024, + maxMessageSize: Int = 32*1024*1024, + maxIoBytesPerSecond: Double = Double.MaxValue, + backOffMs: Long = 15 * 1000, + enableCleaner: Boolean = true, + hashAlgorithm: String = "MD5") { +} diff --git a/core/src/main/scala/kafka/log/CorruptIndexException.scala b/core/src/main/scala/kafka/log/CorruptIndexException.scala new file mode 100644 index 0000000..b39ee5b --- /dev/null +++ b/core/src/main/scala/kafka/log/CorruptIndexException.scala @@ -0,0 +1,20 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.log + +class CorruptIndexException(message: String) extends RuntimeException(message) diff --git a/core/src/main/scala/kafka/log/IndexEntry.scala b/core/src/main/scala/kafka/log/IndexEntry.scala new file mode 100644 index 0000000..705366e --- /dev/null +++ b/core/src/main/scala/kafka/log/IndexEntry.scala @@ -0,0 +1,52 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.log + +import org.apache.kafka.common.requests.ListOffsetsResponse + +sealed trait IndexEntry { + // We always use Long for both key and value to avoid boxing. 
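+ // (For example, OffsetPosition below widens its Int position to a Long in indexValue.)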
+ def indexKey: Long + def indexValue: Long +} + +/** + * The mapping between a logical log offset and the physical position + * in some log file of the beginning of the message set entry with the + * given offset. + */ +case class OffsetPosition(offset: Long, position: Int) extends IndexEntry { + override def indexKey = offset + override def indexValue = position.toLong +} + + +/** + * The mapping from a timestamp to a message offset. The entry means that any message whose timestamp is greater + * than that timestamp must be at or after that offset. + * @param timestamp The max timestamp before the given offset. + * @param offset The message offset. + */ +case class TimestampOffset(timestamp: Long, offset: Long) extends IndexEntry { + override def indexKey = timestamp + override def indexValue = offset +} + +object TimestampOffset { + val Unknown = TimestampOffset(ListOffsetsResponse.UNKNOWN_TIMESTAMP, ListOffsetsResponse.UNKNOWN_OFFSET) +} diff --git a/core/src/main/scala/kafka/log/LazyIndex.scala b/core/src/main/scala/kafka/log/LazyIndex.scala new file mode 100644 index 0000000..5ef1893 --- /dev/null +++ b/core/src/main/scala/kafka/log/LazyIndex.scala @@ -0,0 +1,166 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.log + +import java.io.File +import java.nio.file.{Files, NoSuchFileException} +import java.util.concurrent.locks.ReentrantLock + +import LazyIndex._ +import kafka.utils.CoreUtils.inLock +import kafka.utils.threadsafe +import org.apache.kafka.common.utils.Utils + +/** + * A wrapper over an `AbstractIndex` instance that provides a mechanism to defer loading + * (i.e. memory mapping) the underlying index until it is accessed for the first time via the + * `get` method. + * + * In addition, this class exposes a number of methods (e.g. updateParentDir, renameTo, close, + * etc.) that provide the desired behavior without causing the index to be loaded. If the index + * had previously been loaded, the methods in this class simply delegate to the relevant method in + * the index. + * + * This is an important optimization for broker start-up and shutdown time when a log has a + * large number of segments. + * + * Methods of this class are thread safe. Make sure to check `AbstractIndex` subclasses + * documentation to establish their thread safety. + * + * @param loadIndex A function that takes a `File` pointing to an index and returns a loaded + * `AbstractIndex` instance.
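+ * @param indexWrapper The initial wrapper, an `IndexFile` referencing the on-disk file; it is swapped for an + * `IndexValue` holding the loaded index on first access via `get`.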
+ */ +@threadsafe +class LazyIndex[T <: AbstractIndex] private (@volatile private var indexWrapper: IndexWrapper, loadIndex: File => T) { + + private val lock = new ReentrantLock() + + def file: File = indexWrapper.file + + def get: T = { + indexWrapper match { + case indexValue: IndexValue[T] => indexValue.index + case _: IndexFile => + inLock(lock) { + indexWrapper match { + case indexValue: IndexValue[T] => indexValue.index + case indexFile: IndexFile => + val indexValue = new IndexValue(loadIndex(indexFile.file)) + indexWrapper = indexValue + indexValue.index + } + } + } + } + + def updateParentDir(parentDir: File): Unit = { + inLock(lock) { + indexWrapper.updateParentDir(parentDir) + } + } + + def renameTo(f: File): Unit = { + inLock(lock) { + indexWrapper.renameTo(f) + } + } + + def deleteIfExists(): Boolean = { + inLock(lock) { + indexWrapper.deleteIfExists() + } + } + + def close(): Unit = { + inLock(lock) { + indexWrapper.close() + } + } + + def closeHandler(): Unit = { + inLock(lock) { + indexWrapper.closeHandler() + } + } + +} + +object LazyIndex { + + def forOffset(file: File, baseOffset: Long, maxIndexSize: Int = -1, writable: Boolean = true): LazyIndex[OffsetIndex] = + new LazyIndex(new IndexFile(file), file => new OffsetIndex(file, baseOffset, maxIndexSize, writable)) + + def forTime(file: File, baseOffset: Long, maxIndexSize: Int = -1, writable: Boolean = true): LazyIndex[TimeIndex] = + new LazyIndex(new IndexFile(file), file => new TimeIndex(file, baseOffset, maxIndexSize, writable)) + + private sealed trait IndexWrapper { + + def file: File + + def updateParentDir(f: File): Unit + + def renameTo(f: File): Unit + + def deleteIfExists(): Boolean + + def close(): Unit + + def closeHandler(): Unit + + } + + private class IndexFile(@volatile private var _file: File) extends IndexWrapper { + + def file: File = _file + + def updateParentDir(parentDir: File): Unit = _file = new File(parentDir, file.getName) + + def renameTo(f: File): Unit = { + try Utils.atomicMoveWithFallback(file.toPath, f.toPath, false) + catch { + case _: NoSuchFileException if !file.exists => () + } + finally _file = f + } + + def deleteIfExists(): Boolean = Files.deleteIfExists(file.toPath) + + def close(): Unit = () + + def closeHandler(): Unit = () + + } + + private class IndexValue[T <: AbstractIndex](val index: T) extends IndexWrapper { + + def file: File = index.file + + def updateParentDir(parentDir: File): Unit = index.updateParentDir(parentDir) + + def renameTo(f: File): Unit = index.renameTo(f) + + def deleteIfExists(): Boolean = index.deleteIfExists() + + def close(): Unit = index.close() + + def closeHandler(): Unit = index.closeHandler() + + } + +} + diff --git a/core/src/main/scala/kafka/log/LocalLog.scala b/core/src/main/scala/kafka/log/LocalLog.scala new file mode 100644 index 0000000..04e6152 --- /dev/null +++ b/core/src/main/scala/kafka/log/LocalLog.scala @@ -0,0 +1,1010 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.log + +import java.io.{File, IOException} +import java.nio.file.Files +import java.text.NumberFormat +import java.util.concurrent.atomic.AtomicLong +import java.util.regex.Pattern +import kafka.metrics.KafkaMetricsGroup +import kafka.server.{FetchDataInfo, LogDirFailureChannel, LogOffsetMetadata} +import kafka.utils.{Logging, Scheduler} +import org.apache.kafka.common.{KafkaException, TopicPartition} +import org.apache.kafka.common.errors.{KafkaStorageException, OffsetOutOfRangeException} +import org.apache.kafka.common.message.FetchResponseData +import org.apache.kafka.common.record.MemoryRecords +import org.apache.kafka.common.utils.{Time, Utils} + +import scala.jdk.CollectionConverters._ +import scala.collection.{Seq, immutable} +import scala.collection.mutable.{ArrayBuffer, ListBuffer} + +/** + * Holds the result of splitting a segment into one or more segments, see LocalLog.splitOverflowedSegment(). + * + * @param deletedSegments segments deleted when splitting a segment + * @param newSegments new segments created when splitting a segment + */ +case class SplitSegmentResult(deletedSegments: Iterable[LogSegment], newSegments: Iterable[LogSegment]) + +/** + * An append-only log for storing messages locally. The log is a sequence of LogSegments, each with a base offset. + * New log segments are created according to a configurable policy that controls the size in bytes or time interval + * for a given segment. + * + * NOTE: this class is not thread-safe, and it relies on the thread safety provided by the Log class. + * + * @param _dir The directory in which log segments are created. + * @param config The log configuration settings + * @param segments The non-empty log segments recovered from disk + * @param recoveryPoint The offset at which to begin the next recovery i.e. the first offset which has not been flushed to disk + * @param nextOffsetMetadata The offset where the next message could be appended + * @param scheduler The thread pool scheduler used for background actions + * @param time The time instance used for checking the clock + * @param topicPartition The topic partition associated with this log + * @param logDirFailureChannel The LogDirFailureChannel instance to asynchronously handle Log dir failure + */ +class LocalLog(@volatile private var _dir: File, + @volatile private[log] var config: LogConfig, + private[log] val segments: LogSegments, + @volatile private[log] var recoveryPoint: Long, + @volatile private var nextOffsetMetadata: LogOffsetMetadata, + private[log] val scheduler: Scheduler, + private[log] val time: Time, + private[log] val topicPartition: TopicPartition, + private[log] val logDirFailureChannel: LogDirFailureChannel) extends Logging with KafkaMetricsGroup { + + import kafka.log.LocalLog._ + + this.logIdent = s"[LocalLog partition=$topicPartition, dir=${dir.getParent}] " + + // The memory mapped buffer for index files of this log will be closed with either delete() or closeHandlers() + // After memory mapped buffer is closed, no disk IO operation should be performed for this log. 
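+ // markFlushed(), close() and deleteEmptyDir() below guard against such access, either via checkIfMemoryMappedBufferClosed() + // or by checking this flag directly.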
+ @volatile private[log] var isMemoryMappedBufferClosed = false + + // Cache value of parent directory to avoid allocations in hot paths like ReplicaManager.checkpointHighWatermarks + @volatile private var _parentDir: String = dir.getParent + + // Last time the log was flushed + private val lastFlushedTime = new AtomicLong(time.milliseconds) + + private[log] def dir: File = _dir + + private[log] def name: String = dir.getName() + + private[log] def parentDir: String = _parentDir + + private[log] def parentDirFile: File = new File(_parentDir) + + private[log] def isFuture: Boolean = dir.getName.endsWith(LocalLog.FutureDirSuffix) + + private def maybeHandleIOException[T](msg: => String)(fun: => T): T = { + LocalLog.maybeHandleIOException(logDirFailureChannel, parentDir, msg) { + fun + } + } + + /** + * Rename the directory of the log + * @param name the new dir name + * @throws KafkaStorageException if rename fails + */ + private[log] def renameDir(name: String): Boolean = { + maybeHandleIOException(s"Error while renaming dir for $topicPartition in log dir ${dir.getParent}") { + val renamedDir = new File(dir.getParent, name) + Utils.atomicMoveWithFallback(dir.toPath, renamedDir.toPath) + if (renamedDir != dir) { + _dir = renamedDir + _parentDir = renamedDir.getParent + segments.updateParentDir(renamedDir) + true + } else { + false + } + } + } + + /** + * Update the existing configuration to the new provided configuration. + * @param newConfig the new configuration to be updated to + */ + private[log] def updateConfig(newConfig: LogConfig): Unit = { + val oldConfig = config + config = newConfig + val oldRecordVersion = oldConfig.recordVersion + val newRecordVersion = newConfig.recordVersion + if (newRecordVersion.precedes(oldRecordVersion)) + warn(s"Record format version has been downgraded from $oldRecordVersion to $newRecordVersion.") + } + + private[log] def checkIfMemoryMappedBufferClosed(): Unit = { + if (isMemoryMappedBufferClosed) + throw new KafkaStorageException(s"The memory mapped buffer for log of $topicPartition is already closed") + } + + private[log] def updateRecoveryPoint(newRecoveryPoint: Long): Unit = { + recoveryPoint = newRecoveryPoint + } + + /** + * Update recoveryPoint to provided offset and mark the log as flushed, if the offset is greater + * than the existing recoveryPoint. + * + * @param offset the offset to be updated + */ + private[log] def markFlushed(offset: Long): Unit = { + checkIfMemoryMappedBufferClosed() + if (offset > recoveryPoint) { + updateRecoveryPoint(offset) + lastFlushedTime.set(time.milliseconds) + } + } + + /** + * The number of messages appended to the log since the last flush + */ + private[log] def unflushedMessages: Long = logEndOffset - recoveryPoint + + /** + * Flush local log segments for all offsets up to offset-1. + * Does not update the recovery point. + * + * @param offset The offset to flush up to (non-inclusive) + */ + private[log] def flush(offset: Long): Unit = { + val segmentsToFlush = segments.values(recoveryPoint, offset) + segmentsToFlush.foreach(_.flush()) + // If there are any new segments, we need to flush the parent directory for crash consistency. 
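+ // (A flushed segment whose base offset is at or above the current recovery point starts at or after the
+ // first unflushed offset, i.e. it must have been created since the last flush; that is the check used below.)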
+ segmentsToFlush.lastOption.filter(_.baseOffset >= this.recoveryPoint).foreach(_ => Utils.flushDir(dir.toPath)) + } + + /** + * The time this log is last known to have been fully flushed to disk + */ + private[log] def lastFlushTime: Long = lastFlushedTime.get + + /** + * The offset metadata of the next message that will be appended to the log + */ + private[log] def logEndOffsetMetadata: LogOffsetMetadata = nextOffsetMetadata + + /** + * The offset of the next message that will be appended to the log + */ + private[log] def logEndOffset: Long = nextOffsetMetadata.messageOffset + + /** + * Update end offset of the log, and update the recoveryPoint. + * + * @param endOffset the new end offset of the log + */ + private[log] def updateLogEndOffset(endOffset: Long): Unit = { + nextOffsetMetadata = LogOffsetMetadata(endOffset, segments.activeSegment.baseOffset, segments.activeSegment.size) + if (recoveryPoint > endOffset) { + updateRecoveryPoint(endOffset) + } + } + + /** + * Close file handlers used by log but don't write to disk. + * This is called if the log directory is offline. + */ + private[log] def closeHandlers(): Unit = { + segments.closeHandlers() + isMemoryMappedBufferClosed = true + } + + /** + * Closes the segments of the log. + */ + private[log] def close(): Unit = { + maybeHandleIOException(s"Error while renaming dir for $topicPartition in dir ${dir.getParent}") { + checkIfMemoryMappedBufferClosed() + segments.close() + } + } + + /** + * Completely delete this log directory with no delay. + */ + private[log] def deleteEmptyDir(): Unit = { + maybeHandleIOException(s"Error while deleting dir for $topicPartition in dir ${dir.getParent}") { + if (segments.nonEmpty) { + throw new IllegalStateException(s"Can not delete directory when ${segments.numberOfSegments} segments are still present") + } + if (!isMemoryMappedBufferClosed) { + throw new IllegalStateException(s"Can not delete directory when memory mapped buffer for log of $topicPartition is still open.") + } + Utils.delete(dir) + } + } + + /** + * Completely delete all segments with no delay. + * @return the deleted segments + */ + private[log] def deleteAllSegments(): Iterable[LogSegment] = { + maybeHandleIOException(s"Error while deleting all segments for $topicPartition in dir ${dir.getParent}") { + val deletableSegments = List[LogSegment]() ++ segments.values + removeAndDeleteSegments(segments.values, asyncDelete = false, LogDeletion(this)) + isMemoryMappedBufferClosed = true + deletableSegments + } + } + + /** + * Find segments starting from the oldest until the user-supplied predicate is false. + * A final segment that is empty will never be returned. + * + * @param predicate A function that takes in a candidate log segment, the next higher segment + * (if there is one). It returns true iff the segment is deletable. 
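+ * For example (illustrative only; the real retention predicates are supplied by the caller), deleting
+ * every segment that lies entirely below some start offset could be expressed as
+ * deletableSegments((segment, nextSegmentOpt) => nextSegmentOpt.exists(_.baseOffset <= startOffset)).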
+ * @return the segments ready to be deleted + */ + private[log] def deletableSegments(predicate: (LogSegment, Option[LogSegment]) => Boolean): Iterable[LogSegment] = { + if (segments.isEmpty) { + Seq.empty + } else { + val deletable = ArrayBuffer.empty[LogSegment] + val segmentsIterator = segments.values.iterator + var segmentOpt = nextOption(segmentsIterator) + while (segmentOpt.isDefined) { + val segment = segmentOpt.get + val nextSegmentOpt = nextOption(segmentsIterator) + val isLastSegmentAndEmpty = nextSegmentOpt.isEmpty && segment.size == 0 + if (predicate(segment, nextSegmentOpt) && !isLastSegmentAndEmpty) { + deletable += segment + segmentOpt = nextSegmentOpt + } else { + segmentOpt = Option.empty + } + } + deletable + } + } + + /** + * This method deletes the given log segments by doing the following for each of them: + * - It removes the segment from the segment map so that it will no longer be used for reads. + * - It renames the index and log files by appending .deleted to the respective file name + * - It can either schedule an asynchronous delete operation to occur in the future or perform the deletion synchronously + * + * Asynchronous deletion allows reads to happen concurrently without synchronization and without the possibility of + * physically deleting a file while it is being read. + * + * This method does not convert IOException to KafkaStorageException, the immediate caller + * is expected to catch and handle IOException. + * + * @param segmentsToDelete The log segments to schedule for deletion + * @param asyncDelete Whether the segment files should be deleted asynchronously + * @param reason The reason for the segment deletion + */ + private[log] def removeAndDeleteSegments(segmentsToDelete: Iterable[LogSegment], + asyncDelete: Boolean, + reason: SegmentDeletionReason): Unit = { + if (segmentsToDelete.nonEmpty) { + // Most callers hold an iterator into the `segments` collection and `removeAndDeleteSegment` mutates it by + // removing the deleted segment, we should force materialization of the iterator here, so that results of the + // iteration remain valid and deterministic. We should also pass only the materialized view of the + // iterator to the logic that actually deletes the segments. + val toDelete = segmentsToDelete.toList + reason.logReason(toDelete) + toDelete.foreach { segment => + segments.remove(segment.baseOffset) + } + LocalLog.deleteSegmentFiles(toDelete, asyncDelete, dir, topicPartition, config, scheduler, logDirFailureChannel, logIdent) + } + } + + /** + * Given a message offset, find its corresponding offset metadata in the log. + * If the message offset is out of range, throw an OffsetOutOfRangeException + */ + private[log] def convertToOffsetMetadataOrThrow(offset: Long): LogOffsetMetadata = { + val fetchDataInfo = read(offset, + maxLength = 1, + minOneMessage = false, + maxOffsetMetadata = nextOffsetMetadata, + includeAbortedTxns = false) + fetchDataInfo.fetchOffsetMetadata + } + + /** + * Read messages from the log. 
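+ * For example (illustrative usage; `log` stands for a LocalLog instance in scope),
+ * log.read(startOffset = 100L, maxLength = 64 * 1024, minOneMessage = true,
+ * maxOffsetMetadata = log.logEndOffsetMetadata, includeAbortedTxns = false)
+ * fetches up to 64 KiB of records beginning at offset 100, bounded by the current log end offset.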
+ * + * @param startOffset The offset to begin reading at + * @param maxLength The maximum number of bytes to read + * @param minOneMessage If this is true, the first message will be returned even if it exceeds `maxLength` (if one exists) + * @param maxOffsetMetadata The metadata of the maximum offset to be fetched + * @param includeAbortedTxns If true, aborted transactions are included + * @throws OffsetOutOfRangeException If startOffset is beyond the log end offset + * @return The fetch data information including fetch starting offset metadata and messages read. + */ + def read(startOffset: Long, + maxLength: Int, + minOneMessage: Boolean, + maxOffsetMetadata: LogOffsetMetadata, + includeAbortedTxns: Boolean): FetchDataInfo = { + maybeHandleIOException(s"Exception while reading from $topicPartition in dir ${dir.getParent}") { + trace(s"Reading maximum $maxLength bytes at offset $startOffset from log with " + + s"total length ${segments.sizeInBytes} bytes") + + val endOffsetMetadata = nextOffsetMetadata + val endOffset = endOffsetMetadata.messageOffset + var segmentOpt = segments.floorSegment(startOffset) + + // return error on attempt to read beyond the log end offset + if (startOffset > endOffset || segmentOpt.isEmpty) + throw new OffsetOutOfRangeException(s"Received request for offset $startOffset for partition $topicPartition, " + + s"but we only have log segments upto $endOffset.") + + if (startOffset == maxOffsetMetadata.messageOffset) + emptyFetchDataInfo(maxOffsetMetadata, includeAbortedTxns) + else if (startOffset > maxOffsetMetadata.messageOffset) + emptyFetchDataInfo(convertToOffsetMetadataOrThrow(startOffset), includeAbortedTxns) + else { + // Do the read on the segment with a base offset less than the target offset + // but if that segment doesn't contain any messages with an offset greater than that + // continue to read from successive segments until we get some messages or we reach the end of the log + var fetchDataInfo: FetchDataInfo = null + while (fetchDataInfo == null && segmentOpt.isDefined) { + val segment = segmentOpt.get + val baseOffset = segment.baseOffset + + val maxPosition = + // Use the max offset position if it is on this segment; otherwise, the segment size is the limit. + if (maxOffsetMetadata.segmentBaseOffset == segment.baseOffset) maxOffsetMetadata.relativePositionInSegment + else segment.size + + fetchDataInfo = segment.read(startOffset, maxLength, maxPosition, minOneMessage) + if (fetchDataInfo != null) { + if (includeAbortedTxns) + fetchDataInfo = addAbortedTransactions(startOffset, segment, fetchDataInfo) + } else segmentOpt = segments.higherSegment(baseOffset) + } + + if (fetchDataInfo != null) fetchDataInfo + else { + // okay we are beyond the end of the last segment with no data fetched although the start offset is in range, + // this can happen when all messages with offset larger than start offsets have been deleted. 
+ // In this case, we will return the empty set with log end offset metadata + FetchDataInfo(nextOffsetMetadata, MemoryRecords.EMPTY) + } + } + } + } + + private[log] def append(lastOffset: Long, largestTimestamp: Long, shallowOffsetOfMaxTimestamp: Long, records: MemoryRecords): Unit = { + segments.activeSegment.append(largestOffset = lastOffset, largestTimestamp = largestTimestamp, + shallowOffsetOfMaxTimestamp = shallowOffsetOfMaxTimestamp, records = records) + updateLogEndOffset(lastOffset + 1) + } + + private def addAbortedTransactions(startOffset: Long, segment: LogSegment, + fetchInfo: FetchDataInfo): FetchDataInfo = { + val fetchSize = fetchInfo.records.sizeInBytes + val startOffsetPosition = OffsetPosition(fetchInfo.fetchOffsetMetadata.messageOffset, + fetchInfo.fetchOffsetMetadata.relativePositionInSegment) + val upperBoundOffset = segment.fetchUpperBoundOffset(startOffsetPosition, fetchSize).getOrElse { + segments.higherSegment(segment.baseOffset).map(_.baseOffset).getOrElse(logEndOffset) + } + + val abortedTransactions = ListBuffer.empty[FetchResponseData.AbortedTransaction] + def accumulator(abortedTxns: List[AbortedTxn]): Unit = abortedTransactions ++= abortedTxns.map(_.asAbortedTransaction) + collectAbortedTransactions(startOffset, upperBoundOffset, segment, accumulator) + + FetchDataInfo(fetchOffsetMetadata = fetchInfo.fetchOffsetMetadata, + records = fetchInfo.records, + firstEntryIncomplete = fetchInfo.firstEntryIncomplete, + abortedTransactions = Some(abortedTransactions.toList)) + } + + private def collectAbortedTransactions(startOffset: Long, upperBoundOffset: Long, + startingSegment: LogSegment, + accumulator: List[AbortedTxn] => Unit): Unit = { + val higherSegments = segments.higherSegments(startingSegment.baseOffset).iterator + var segmentEntryOpt = Option(startingSegment) + while (segmentEntryOpt.isDefined) { + val segment = segmentEntryOpt.get + val searchResult = segment.collectAbortedTxns(startOffset, upperBoundOffset) + accumulator(searchResult.abortedTransactions) + if (searchResult.isComplete) + return + segmentEntryOpt = nextOption(higherSegments) + } + } + + private[log] def collectAbortedTransactions(logStartOffset: Long, baseOffset: Long, upperBoundOffset: Long): List[AbortedTxn] = { + val segmentEntry = segments.floorSegment(baseOffset) + val allAbortedTxns = ListBuffer.empty[AbortedTxn] + def accumulator(abortedTxns: List[AbortedTxn]): Unit = allAbortedTxns ++= abortedTxns + segmentEntry.foreach(segment => collectAbortedTransactions(logStartOffset, upperBoundOffset, segment, accumulator)) + allAbortedTxns.toList + } + + /** + * Roll the log over to a new active segment starting with the current logEndOffset. + * This will trim the index to the exact size of the number of entries it currently contains. 
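+ * (Illustrative: roll() rolls at the current logEndOffset, while roll(Some(5000L)) rolls at
+ * max(5000, logEndOffset), so the new segment's base offset never falls below the log end offset.)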
+ * + * @param expectedNextOffset The expected next offset after the segment is rolled + * + * @return The newly rolled segment + */ + private[log] def roll(expectedNextOffset: Option[Long] = None): LogSegment = { + maybeHandleIOException(s"Error while rolling log segment for $topicPartition in dir ${dir.getParent}") { + val start = time.hiResClockMs() + checkIfMemoryMappedBufferClosed() + val newOffset = math.max(expectedNextOffset.getOrElse(0L), logEndOffset) + val logFile = LocalLog.logFile(dir, newOffset) + val activeSegment = segments.activeSegment + if (segments.contains(newOffset)) { + // segment with the same base offset already exists and loaded + if (activeSegment.baseOffset == newOffset && activeSegment.size == 0) { + // We have seen this happen (see KAFKA-6388) after shouldRoll() returns true for an + // active segment of size zero because of one of the indexes is "full" (due to _maxEntries == 0). + warn(s"Trying to roll a new log segment with start offset $newOffset " + + s"=max(provided offset = $expectedNextOffset, LEO = $logEndOffset) while it already " + + s"exists and is active with size 0. Size of time index: ${activeSegment.timeIndex.entries}," + + s" size of offset index: ${activeSegment.offsetIndex.entries}.") + removeAndDeleteSegments(Seq(activeSegment), asyncDelete = true, LogRoll(this)) + } else { + throw new KafkaException(s"Trying to roll a new log segment for topic partition $topicPartition with start offset $newOffset" + + s" =max(provided offset = $expectedNextOffset, LEO = $logEndOffset) while it already exists. Existing " + + s"segment is ${segments.get(newOffset)}.") + } + } else if (!segments.isEmpty && newOffset < activeSegment.baseOffset) { + throw new KafkaException( + s"Trying to roll a new log segment for topic partition $topicPartition with " + + s"start offset $newOffset =max(provided offset = $expectedNextOffset, LEO = $logEndOffset) lower than start offset of the active segment $activeSegment") + } else { + val offsetIdxFile = offsetIndexFile(dir, newOffset) + val timeIdxFile = timeIndexFile(dir, newOffset) + val txnIdxFile = transactionIndexFile(dir, newOffset) + + for (file <- List(logFile, offsetIdxFile, timeIdxFile, txnIdxFile) if file.exists) { + warn(s"Newly rolled segment file ${file.getAbsolutePath} already exists; deleting it first") + Files.delete(file.toPath) + } + + segments.lastSegment.foreach(_.onBecomeInactiveSegment()) + } + + val newSegment = LogSegment.open(dir, + baseOffset = newOffset, + config, + time = time, + initFileSize = config.initFileSize, + preallocate = config.preallocate) + segments.add(newSegment) + + // We need to update the segment base offset and append position data of the metadata when log rolls. + // The next offset should not change. + updateLogEndOffset(nextOffsetMetadata.messageOffset) + + info(s"Rolled new log segment at offset $newOffset in ${time.hiResClockMs() - start} ms.") + + newSegment + } + } + + /** + * Delete all data in the local log and start at the new offset. 
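+ * (Illustrative: after truncateFullyAndStartAt(500L) the previous segments are scheduled for
+ * asynchronous deletion and the log is left with a single empty segment whose base offset and
+ * log end offset are both 500.)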
+ * + * @param newOffset The new offset to start the log with + * @return the list of segments that were scheduled for deletion + */ + private[log] def truncateFullyAndStartAt(newOffset: Long): Iterable[LogSegment] = { + maybeHandleIOException(s"Error while truncating the entire log for $topicPartition in dir ${dir.getParent}") { + debug(s"Truncate and start at offset $newOffset") + checkIfMemoryMappedBufferClosed() + val segmentsToDelete = List[LogSegment]() ++ segments.values + removeAndDeleteSegments(segmentsToDelete, asyncDelete = true, LogTruncation(this)) + segments.add(LogSegment.open(dir, + baseOffset = newOffset, + config = config, + time = time, + initFileSize = config.initFileSize, + preallocate = config.preallocate)) + updateLogEndOffset(newOffset) + segmentsToDelete + } + } + + /** + * Truncate this log so that it ends with the greatest offset < targetOffset. + * + * @param targetOffset The offset to truncate to, an upper bound on all offsets in the log after truncation is complete. + * @return the list of segments that were scheduled for deletion + */ + private[log] def truncateTo(targetOffset: Long): Iterable[LogSegment] = { + val deletableSegments = List[LogSegment]() ++ segments.filter(segment => segment.baseOffset > targetOffset) + removeAndDeleteSegments(deletableSegments, asyncDelete = true, LogTruncation(this)) + segments.activeSegment.truncateTo(targetOffset) + updateLogEndOffset(targetOffset) + deletableSegments + } +} + +/** + * Helper functions for logs + */ +object LocalLog extends Logging { + + /** a log file */ + private[log] val LogFileSuffix = ".log" + + /** an index file */ + private[log] val IndexFileSuffix = ".index" + + /** a time index file */ + private[log] val TimeIndexFileSuffix = ".timeindex" + + /** an (aborted) txn index */ + private[log] val TxnIndexFileSuffix = ".txnindex" + + /** a file that is scheduled to be deleted */ + private[log] val DeletedFileSuffix = ".deleted" + + /** A temporary file that is being used for log cleaning */ + private[log] val CleanedFileSuffix = ".cleaned" + + /** A temporary file used when swapping files into the log */ + private[log] val SwapFileSuffix = ".swap" + + /** a directory that is scheduled to be deleted */ + private[log] val DeleteDirSuffix = "-delete" + + /** a directory that is used for future partition */ + private[log] val FutureDirSuffix = "-future" + + private[log] val DeleteDirPattern = Pattern.compile(s"^(\\S+)-(\\S+)\\.(\\S+)$DeleteDirSuffix") + private[log] val FutureDirPattern = Pattern.compile(s"^(\\S+)-(\\S+)\\.(\\S+)$FutureDirSuffix") + + private[log] val UnknownOffset = -1L + + /** + * Make log segment file name from offset bytes. All this does is pad out the offset number with zeros + * so that ls sorts the files numerically. + * + * @param offset The offset to use in the file name + * @return The filename + */ + private[log] def filenamePrefixFromOffset(offset: Long): String = { + val nf = NumberFormat.getInstance() + nf.setMinimumIntegerDigits(20) + nf.setMaximumFractionDigits(0) + nf.setGroupingUsed(false) + nf.format(offset) + } + + /** + * Construct a log file name in the given dir with the given base offset and the given suffix + * + * @param dir The directory in which the log will reside + * @param offset The base offset of the log file + * @param suffix The suffix to be appended to the file name (e.g. "", ".deleted", ".cleaned", ".swap", etc.) 
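+ *
+ * For example (under these naming rules), logFile(dir, 42) names the file "00000000000000000042.log"
+ * and logFile(dir, 42, DeletedFileSuffix) names "00000000000000000042.log.deleted", since
+ * filenamePrefixFromOffset zero-pads offsets to 20 digits.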
+ */ + private[log] def logFile(dir: File, offset: Long, suffix: String = ""): File = + new File(dir, filenamePrefixFromOffset(offset) + LogFileSuffix + suffix) + + /** + * Return a directory name to rename the log directory to for async deletion. + * The name will be in the following format: "topic-partitionId.uniqueId-delete". + * If the topic name is too long, it will be truncated to prevent the total name + * from exceeding 255 characters. + */ + private[log] def logDeleteDirName(topicPartition: TopicPartition): String = { + val uniqueId = java.util.UUID.randomUUID.toString.replaceAll("-", "") + val suffix = s"-${topicPartition.partition()}.${uniqueId}${DeleteDirSuffix}" + val prefixLength = Math.min(topicPartition.topic().size, 255 - suffix.size) + s"${topicPartition.topic().substring(0, prefixLength)}${suffix}" + } + + /** + * Return a future directory name for the given topic partition. The name will be in the following + * format: topic-partition.uniqueId-future where topic, partition and uniqueId are variables. + */ + private[log] def logFutureDirName(topicPartition: TopicPartition): String = { + logDirNameWithSuffix(topicPartition, FutureDirSuffix) + } + + private[log] def logDirNameWithSuffix(topicPartition: TopicPartition, suffix: String): String = { + val uniqueId = java.util.UUID.randomUUID.toString.replaceAll("-", "") + s"${logDirName(topicPartition)}.$uniqueId$suffix" + } + + /** + * Return a directory name for the given topic partition. The name will be in the following + * format: topic-partition where topic, partition are variables. + */ + private[log] def logDirName(topicPartition: TopicPartition): String = { + s"${topicPartition.topic}-${topicPartition.partition}" + } + + /** + * Construct an index file name in the given dir using the given base offset and the given suffix + * + * @param dir The directory in which the log will reside + * @param offset The base offset of the log file + * @param suffix The suffix to be appended to the file name ("", ".deleted", ".cleaned", ".swap", etc.) + */ + private[log] def offsetIndexFile(dir: File, offset: Long, suffix: String = ""): File = + new File(dir, filenamePrefixFromOffset(offset) + IndexFileSuffix + suffix) + + /** + * Construct a time index file name in the given dir using the given base offset and the given suffix + * + * @param dir The directory in which the log will reside + * @param offset The base offset of the log file + * @param suffix The suffix to be appended to the file name ("", ".deleted", ".cleaned", ".swap", etc.) + */ + private[log] def timeIndexFile(dir: File, offset: Long, suffix: String = ""): File = + new File(dir, filenamePrefixFromOffset(offset) + TimeIndexFileSuffix + suffix) + + /** + * Construct a transaction index file name in the given dir using the given base offset and the given suffix + * + * @param dir The directory in which the log will reside + * @param offset The base offset of the log file + * @param suffix The suffix to be appended to the file name ("", ".deleted", ".cleaned", ".swap", etc.) 
+ */ + private[log] def transactionIndexFile(dir: File, offset: Long, suffix: String = ""): File = + new File(dir, filenamePrefixFromOffset(offset) + TxnIndexFileSuffix + suffix) + + private[log] def offsetFromFileName(filename: String): Long = { + filename.substring(0, filename.indexOf('.')).toLong + } + + private[log] def offsetFromFile(file: File): Long = { + offsetFromFileName(file.getName) + } + + /** + * Parse the topic and partition out of the directory name of a log + */ + private[log] def parseTopicPartitionName(dir: File): TopicPartition = { + if (dir == null) + throw new KafkaException("dir should not be null") + + def exception(dir: File): KafkaException = { + new KafkaException(s"Found directory ${dir.getCanonicalPath}, '${dir.getName}' is not in the form of " + + "topic-partition or topic-partition.uniqueId-delete (if marked for deletion).\n" + + "Kafka's log directories (and children) should only contain Kafka topic data.") + } + + val dirName = dir.getName + if (dirName == null || dirName.isEmpty || !dirName.contains('-')) + throw exception(dir) + if (dirName.endsWith(DeleteDirSuffix) && !DeleteDirPattern.matcher(dirName).matches || + dirName.endsWith(FutureDirSuffix) && !FutureDirPattern.matcher(dirName).matches) + throw exception(dir) + + val name: String = + if (dirName.endsWith(DeleteDirSuffix) || dirName.endsWith(FutureDirSuffix)) dirName.substring(0, dirName.lastIndexOf('.')) + else dirName + + val index = name.lastIndexOf('-') + val topic = name.substring(0, index) + val partitionString = name.substring(index + 1) + if (topic.isEmpty || partitionString.isEmpty) + throw exception(dir) + + val partition = + try partitionString.toInt + catch { case _: NumberFormatException => throw exception(dir) } + + new TopicPartition(topic, partition) + } + + private[log] def isIndexFile(file: File): Boolean = { + val filename = file.getName + filename.endsWith(IndexFileSuffix) || filename.endsWith(TimeIndexFileSuffix) || filename.endsWith(TxnIndexFileSuffix) + } + + private[log] def isLogFile(file: File): Boolean = + file.getPath.endsWith(LogFileSuffix) + + /** + * Invokes the provided function and handles any IOException raised by the function by marking the + * provided directory offline. + * + * @param logDirFailureChannel Used to asynchronously handle log directory failure. + * @param logDir The log directory to be marked offline during an IOException. + * @param errorMsg The error message to be used when marking the log directory offline. + * @param fun The function to be executed. + * @return The value returned by the function after a successful invocation + */ + private[log] def maybeHandleIOException[T](logDirFailureChannel: LogDirFailureChannel, + logDir: String, + errorMsg: => String)(fun: => T): T = { + if (logDirFailureChannel.hasOfflineLogDir(logDir)) { + throw new KafkaStorageException(s"The log dir $logDir is already offline due to a previous IO exception.") + } + try { + fun + } catch { + case e: IOException => + logDirFailureChannel.maybeAddOfflineLogDir(logDir, errorMsg, e) + throw new KafkaStorageException(errorMsg, e) + } + } + + /** + * Split a segment into one or more segments such that there is no offset overflow in any of them. The + * resulting segments will contain the exact same messages that are present in the input segment. On successful + * completion of this method, the input segment will be deleted and will be replaced by the resulting new segments. + * See replaceSegments for recovery logic, in case the broker dies in the middle of this operation. 
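+ * (Roughly speaking, a segment has overflowed when it holds a record whose offset is more than
+ * Int.MaxValue above the segment's base offset, so the relative offset no longer fits into the
+ * 4-byte slot used by index entries; splitting restores that invariant.)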
+ *
+ * Note that this method assumes we have already determined that the segment passed in contains records that cause
+ * offset overflow.
+ *
+ * The split logic overloads the use of .cleaned files that LogCleaner typically uses to make the process of replacing
+ * the input segment with multiple new segments atomic and recoverable in the event of a crash. See replaceSegments
+ * and completeSwapOperations for the implementation to make this operation recoverable on crashes.
            + * + * @param segment Segment to split + * @param existingSegments The existing segments of the log + * @param dir The directory in which the log will reside + * @param topicPartition The topic + * @param config The log configuration settings + * @param scheduler The thread pool scheduler used for background actions + * @param logDirFailureChannel The LogDirFailureChannel to asynchronously handle log dir failure + * @param logPrefix The logging prefix + * @return List of new segments that replace the input segment + */ + private[log] def splitOverflowedSegment(segment: LogSegment, + existingSegments: LogSegments, + dir: File, + topicPartition: TopicPartition, + config: LogConfig, + scheduler: Scheduler, + logDirFailureChannel: LogDirFailureChannel, + logPrefix: String): SplitSegmentResult = { + require(isLogFile(segment.log.file), s"Cannot split file ${segment.log.file.getAbsoluteFile}") + require(segment.hasOverflow, s"Split operation is only permitted for segments with overflow, and the problem path is ${segment.log.file.getAbsoluteFile}") + + info(s"${logPrefix}Splitting overflowed segment $segment") + + val newSegments = ListBuffer[LogSegment]() + try { + var position = 0 + val sourceRecords = segment.log + + while (position < sourceRecords.sizeInBytes) { + val firstBatch = sourceRecords.batchesFrom(position).asScala.head + val newSegment = createNewCleanedSegment(dir, config, firstBatch.baseOffset) + newSegments += newSegment + + val bytesAppended = newSegment.appendFromFile(sourceRecords, position) + if (bytesAppended == 0) + throw new IllegalStateException(s"Failed to append records from position $position in $segment") + + position += bytesAppended + } + + // prepare new segments + var totalSizeOfNewSegments = 0 + newSegments.foreach { splitSegment => + splitSegment.onBecomeInactiveSegment() + splitSegment.flush() + splitSegment.lastModified = segment.lastModified + totalSizeOfNewSegments += splitSegment.log.sizeInBytes + } + // size of all the new segments combined must equal size of the original segment + if (totalSizeOfNewSegments != segment.log.sizeInBytes) + throw new IllegalStateException("Inconsistent segment sizes after split" + + s" before: ${segment.log.sizeInBytes} after: $totalSizeOfNewSegments") + + // replace old segment with new ones + info(s"${logPrefix}Replacing overflowed segment $segment with split segments $newSegments") + val newSegmentsToAdd = newSegments.toSeq + val deletedSegments = LocalLog.replaceSegments(existingSegments, newSegmentsToAdd, List(segment), + dir, topicPartition, config, scheduler, logDirFailureChannel, logPrefix) + SplitSegmentResult(deletedSegments.toSeq, newSegmentsToAdd) + } catch { + case e: Exception => + newSegments.foreach { splitSegment => + splitSegment.close() + splitSegment.deleteIfExists() + } + throw e + } + } + + /** + * Swap one or more new segment in place and delete one or more existing segments in a crash-safe + * manner. The old segments will be asynchronously deleted. + * + * This method does not need to convert IOException to KafkaStorageException because it is either + * called before all logs are loaded or the caller will catch and handle IOException + * + * The sequence of operations is: + * + * - Cleaner creates one or more new segments with suffix .cleaned and invokes replaceSegments() on + * the Log instance. If broker crashes at this point, the clean-and-swap operation is aborted and + * the .cleaned files are deleted on recovery in LogLoader. + * - New segments are renamed .swap. 
If the broker crashes before all segments were renamed to .swap, the + * clean-and-swap operation is aborted - .cleaned as well as .swap files are deleted on recovery in + * in LogLoader. We detect this situation by maintaining a specific order in which files are renamed + * from .cleaned to .swap. Basically, files are renamed in descending order of offsets. On recovery, + * all .swap files whose offset is greater than the minimum-offset .clean file are deleted. + * - If the broker crashes after all new segments were renamed to .swap, the operation is completed, + * the swap operation is resumed on recovery as described in the next step. + * - Old segment files are renamed to .deleted and asynchronous delete is scheduled. If the broker + * crashes, any .deleted files left behind are deleted on recovery in LogLoader. + * replaceSegments() is then invoked to complete the swap with newSegment recreated from the + * .swap file and oldSegments containing segments which were not renamed before the crash. + * - Swap segment(s) are renamed to replace the existing segments, completing this operation. + * If the broker crashes, any .deleted files which may be left behind are deleted + * on recovery in LogLoader. + * + * @param existingSegments The existing segments of the log + * @param newSegments The new log segment to add to the log + * @param oldSegments The old log segments to delete from the log + * @param dir The directory in which the log will reside + * @param topicPartition The topic + * @param config The log configuration settings + * @param scheduler The thread pool scheduler used for background actions + * @param logDirFailureChannel The LogDirFailureChannel to asynchronously handle log dir failure + * @param logPrefix The logging prefix + * @param isRecoveredSwapFile true if the new segment was created from a swap file during recovery after a crash + */ + private[log] def replaceSegments(existingSegments: LogSegments, + newSegments: Seq[LogSegment], + oldSegments: Seq[LogSegment], + dir: File, + topicPartition: TopicPartition, + config: LogConfig, + scheduler: Scheduler, + logDirFailureChannel: LogDirFailureChannel, + logPrefix: String, + isRecoveredSwapFile: Boolean = false): Iterable[LogSegment] = { + val sortedNewSegments = newSegments.sortBy(_.baseOffset) + // Some old segments may have been removed from index and scheduled for async deletion after the caller reads segments + // but before this method is executed. We want to filter out those segments to avoid calling asyncDeleteSegment() + // multiple times for the same segment. 
+ val sortedOldSegments = oldSegments.filter(seg => existingSegments.contains(seg.baseOffset)).sortBy(_.baseOffset) + + // need to do this in two phases to be crash safe AND do the delete asynchronously + // if we crash in the middle of this we complete the swap in loadSegments() + if (!isRecoveredSwapFile) + sortedNewSegments.reverse.foreach(_.changeFileSuffixes(CleanedFileSuffix, SwapFileSuffix)) + sortedNewSegments.reverse.foreach(existingSegments.add(_)) + val newSegmentBaseOffsets = sortedNewSegments.map(_.baseOffset).toSet + + // delete the old files + val deletedNotReplaced = sortedOldSegments.map { seg => + // remove the index entry + if (seg.baseOffset != sortedNewSegments.head.baseOffset) + existingSegments.remove(seg.baseOffset) + deleteSegmentFiles( + List(seg), + asyncDelete = true, + dir, + topicPartition, + config, + scheduler, + logDirFailureChannel, + logPrefix) + if (newSegmentBaseOffsets.contains(seg.baseOffset)) Option.empty else Some(seg) + }.filter(item => item.isDefined).map(item => item.get) + + // okay we are safe now, remove the swap suffix + sortedNewSegments.foreach(_.changeFileSuffixes(SwapFileSuffix, "")) + Utils.flushDir(dir.toPath) + deletedNotReplaced + } + + /** + * Perform physical deletion of the index and log files for the given segment. + * Prior to the deletion, the index and log files are renamed by appending .deleted to the + * respective file name. Allows these files to be optionally deleted asynchronously. + * + * This method assumes that the file exists. It does not need to convert IOException + * (thrown from changeFileSuffixes) to KafkaStorageException because it is either called before + * all logs are loaded or the caller will catch and handle IOException. + * + * @param segmentsToDelete The segments to be deleted + * @param asyncDelete If true, the deletion of the segments is done asynchronously + * @param dir The directory in which the log will reside + * @param topicPartition The topic + * @param config The log configuration settings + * @param scheduler The thread pool scheduler used for background actions + * @param logDirFailureChannel The LogDirFailureChannel to asynchronously handle log dir failure + * @param logPrefix The logging prefix + * @throws IOException if the file can't be renamed and still exists + */ + private[log] def deleteSegmentFiles(segmentsToDelete: immutable.Iterable[LogSegment], + asyncDelete: Boolean, + dir: File, + topicPartition: TopicPartition, + config: LogConfig, + scheduler: Scheduler, + logDirFailureChannel: LogDirFailureChannel, + logPrefix: String): Unit = { + segmentsToDelete.foreach(_.changeFileSuffixes("", DeletedFileSuffix)) + + def deleteSegments(): Unit = { + info(s"${logPrefix}Deleting segment files ${segmentsToDelete.mkString(",")}") + val parentDir = dir.getParent + maybeHandleIOException(logDirFailureChannel, parentDir, s"Error while deleting segments for $topicPartition in dir $parentDir") { + segmentsToDelete.foreach { segment => + segment.deleteIfExists() + } + } + } + + if (asyncDelete) + scheduler.schedule("delete-file", () => deleteSegments(), delay = config.fileDeleteDelayMs) + else + deleteSegments() + } + + private[log] def emptyFetchDataInfo(fetchOffsetMetadata: LogOffsetMetadata, + includeAbortedTxns: Boolean): FetchDataInfo = { + val abortedTransactions = + if (includeAbortedTxns) Some(List.empty[FetchResponseData.AbortedTransaction]) + else None + FetchDataInfo(fetchOffsetMetadata, + MemoryRecords.EMPTY, + abortedTransactions = abortedTransactions) + } + + private[log] def 
createNewCleanedSegment(dir: File, logConfig: LogConfig, baseOffset: Long): LogSegment = { + LogSegment.deleteIfExists(dir, baseOffset, fileSuffix = CleanedFileSuffix) + LogSegment.open(dir, baseOffset, logConfig, Time.SYSTEM, + fileSuffix = CleanedFileSuffix, initFileSize = logConfig.initFileSize, preallocate = logConfig.preallocate) + } + + /** + * Wraps the value of iterator.next() in an option. + * Note: this facility is a part of the Iterator class starting from scala v2.13. + * + * @param iterator + * @tparam T the type of object held within the iterator + * @return Some(iterator.next) if a next element exists, None otherwise. + */ + private def nextOption[T](iterator: Iterator[T]): Option[T] = { + if (iterator.hasNext) + Some(iterator.next()) + else + None + } +} + +trait SegmentDeletionReason { + def logReason(toDelete: List[LogSegment]): Unit +} + +case class LogTruncation(log: LocalLog) extends SegmentDeletionReason { + override def logReason(toDelete: List[LogSegment]): Unit = { + log.info(s"Deleting segments as part of log truncation: ${toDelete.mkString(",")}") + } +} + +case class LogRoll(log: LocalLog) extends SegmentDeletionReason { + override def logReason(toDelete: List[LogSegment]): Unit = { + log.info(s"Deleting segments as part of log roll: ${toDelete.mkString(",")}") + } +} + +case class LogDeletion(log: LocalLog) extends SegmentDeletionReason { + override def logReason(toDelete: List[LogSegment]): Unit = { + log.info(s"Deleting segments as the log has been deleted: ${toDelete.mkString(",")}") + } +} diff --git a/core/src/main/scala/kafka/log/LogCleaner.scala b/core/src/main/scala/kafka/log/LogCleaner.scala new file mode 100644 index 0000000..0d4cab9 --- /dev/null +++ b/core/src/main/scala/kafka/log/LogCleaner.scala @@ -0,0 +1,1190 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package kafka.log + +import java.io.{File, IOException} +import java.nio._ +import java.util.Date +import java.util.concurrent.TimeUnit +import kafka.common._ +import kafka.metrics.KafkaMetricsGroup +import kafka.server.{BrokerReconfigurable, KafkaConfig, LogDirFailureChannel} +import kafka.utils._ +import org.apache.kafka.common.{KafkaException, TopicPartition} +import org.apache.kafka.common.config.ConfigException +import org.apache.kafka.common.errors.{CorruptRecordException, KafkaStorageException} +import org.apache.kafka.common.record.MemoryRecords.RecordFilter +import org.apache.kafka.common.record.MemoryRecords.RecordFilter.BatchRetention +import org.apache.kafka.common.record._ +import org.apache.kafka.common.utils.{BufferSupplier, Time} + +import scala.jdk.CollectionConverters._ +import scala.collection.mutable.ListBuffer +import scala.collection.{Iterable, Seq, Set, mutable} +import scala.util.control.ControlThrowable + +/** + * The cleaner is responsible for removing obsolete records from logs which have the "compact" retention strategy. + * A message with key K and offset O is obsolete if there exists a message with key K and offset O' such that O < O'. + * + * Each log can be thought of being split into two sections of segments: a "clean" section which has previously been cleaned followed by a + * "dirty" section that has not yet been cleaned. The dirty section is further divided into the "cleanable" section followed by an "uncleanable" section. + * The uncleanable section is excluded from cleaning. The active log segment is always uncleanable. If there is a + * compaction lag time set, segments whose largest message timestamp is within the compaction lag time of the cleaning operation are also uncleanable. + * + * The cleaning is carried out by a pool of background threads. Each thread chooses the dirtiest log that has the "compact" retention policy + * and cleans that. The dirtiness of the log is guessed by taking the ratio of bytes in the dirty section of the log to the total bytes in the log. + * + * To clean a log the cleaner first builds a mapping of key=>last_offset for the dirty section of the log. See kafka.log.OffsetMap for details of + * the implementation of the mapping. + * + * Once the key=>last_offset map is built, the log is cleaned by recopying each log segment but omitting any key that appears in the offset map with a + * higher offset than what is found in the segment (i.e. messages with a key that appears in the dirty section of the log). + * + * To avoid segments shrinking to very small sizes with repeated cleanings we implement a rule by which if we will merge successive segments when + * doing a cleaning if their log and index size are less than the maximum log and index size prior to the clean beginning. + * + * Cleaned segments are swapped into the log as they become available. + * + * One nuance that the cleaner must handle is log truncation. If a log is truncated while it is being cleaned the cleaning of that log is aborted. + * + * Messages with null payload are treated as deletes for the purpose of log compaction. This means that they receive special treatment by the cleaner. + * The cleaner will only retain delete records for a period of time to avoid accumulating space indefinitely. This period of time is configurable on a per-topic + * basis and is measured from the time the segment enters the clean portion of the log (at which point any prior message with that key has been removed). 
+ * Delete markers in the clean section of the log that are older than this time will not be retained when log segments are being recopied as part of cleaning. + * This time is tracked by setting the base timestamp of a record batch with delete markers when the batch is recopied in the first cleaning that encounters + * it. The relative timestamps of the records in the batch are also modified when recopied in this cleaning according to the new base timestamp of the batch. + * + * Note that cleaning is more complicated with the idempotent/transactional producer capabilities. The following + * are the key points: + * + * 1. In order to maintain sequence number continuity for active producers, we always retain the last batch + * from each producerId, even if all the records from the batch have been removed. The batch will be removed + * once the producer either writes a new batch or is expired due to inactivity. + * 2. We do not clean beyond the last stable offset. This ensures that all records observed by the cleaner have + * been decided (i.e. committed or aborted). In particular, this allows us to use the transaction index to + * collect the aborted transactions ahead of time. + * 3. Records from aborted transactions are removed by the cleaner immediately without regard to record keys. + * 4. Transaction markers are retained until all record batches from the same transaction have been removed and + * a sufficient amount of time has passed to reasonably ensure that an active consumer wouldn't consume any + * data from the transaction prior to reaching the offset of the marker. This follows the same logic used for + * tombstone deletion. + * + * @param initialConfig Initial configuration parameters for the cleaner. Actual config may be dynamically updated. + * @param logDirs The directories where offset checkpoints reside + * @param logs The pool of logs + * @param time A way to control the passage of time + */ +class LogCleaner(initialConfig: CleanerConfig, + val logDirs: Seq[File], + val logs: Pool[TopicPartition, UnifiedLog], + val logDirFailureChannel: LogDirFailureChannel, + time: Time = Time.SYSTEM) extends Logging with KafkaMetricsGroup with BrokerReconfigurable +{ + + /* Log cleaner configuration which may be dynamically updated */ + @volatile private var config = initialConfig + + /* for managing the state of partitions being cleaned. package-private to allow access in tests */ + private[log] val cleanerManager = new LogCleanerManager(logDirs, logs, logDirFailureChannel) + + /* a throttle used to limit the I/O of all the cleaner threads to a user-specified maximum rate */ + private val throttler = new Throttler(desiredRatePerSec = config.maxIoBytesPerSecond, + checkIntervalMs = 300, + throttleDown = true, + "cleaner-io", + "bytes", + time = time) + + private[log] val cleaners = mutable.ArrayBuffer[CleanerThread]() + + /** + * scala 2.12 does not support maxOption so we handle the empty manually. 
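+ * (Illustrative: maxOverCleanerThreads(_.lastStats.bufferUtilization) yields the largest buffer
+ * utilization reported by any cleaner thread, truncated to an Int, or 0 when there are no cleaner threads.)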
+ * @param f to compute the result + * @return the max value (int value) or 0 if there is no cleaner + */ + private def maxOverCleanerThreads(f: CleanerThread => Double): Int = + cleaners.foldLeft(0.0d)((max: Double, thread: CleanerThread) => math.max(max, f(thread))).toInt + + + /* a metric to track the maximum utilization of any thread's buffer in the last cleaning */ + newGauge("max-buffer-utilization-percent", + () => maxOverCleanerThreads(_.lastStats.bufferUtilization) * 100) + + /* a metric to track the recopy rate of each thread's last cleaning */ + newGauge("cleaner-recopy-percent", () => { + val stats = cleaners.map(_.lastStats) + val recopyRate = stats.iterator.map(_.bytesWritten).sum.toDouble / math.max(stats.iterator.map(_.bytesRead).sum, 1) + (100 * recopyRate).toInt + }) + + /* a metric to track the maximum cleaning time for the last cleaning from each thread */ + newGauge("max-clean-time-secs", + () => maxOverCleanerThreads(_.lastStats.elapsedSecs)) + + + // a metric to track delay between the time when a log is required to be compacted + // as determined by max compaction lag and the time of last cleaner run. + newGauge("max-compaction-delay-secs", + () => maxOverCleanerThreads(_.lastPreCleanStats.maxCompactionDelayMs.toDouble) / 1000) + + newGauge("DeadThreadCount", () => deadThreadCount) + + private[log] def deadThreadCount: Int = cleaners.count(_.isThreadFailed) + + /** + * Start the background cleaning + */ + def startup(): Unit = { + info("Starting the log cleaner") + (0 until config.numThreads).foreach { i => + val cleaner = new CleanerThread(i) + cleaners += cleaner + cleaner.start() + } + } + + /** + * Stop the background cleaning + */ + def shutdown(): Unit = { + info("Shutting down the log cleaner.") + cleaners.foreach(_.shutdown()) + cleaners.clear() + } + + override def reconfigurableConfigs: Set[String] = { + LogCleaner.ReconfigurableConfigs + } + + override def validateReconfiguration(newConfig: KafkaConfig): Unit = { + val newCleanerConfig = LogCleaner.cleanerConfig(newConfig) + val numThreads = newCleanerConfig.numThreads + val currentThreads = config.numThreads + if (numThreads < 1) + throw new ConfigException(s"Log cleaner threads should be at least 1") + if (numThreads < currentThreads / 2) + throw new ConfigException(s"Log cleaner threads cannot be reduced to less than half the current value $currentThreads") + if (numThreads > currentThreads * 2) + throw new ConfigException(s"Log cleaner threads cannot be increased to more than double the current value $currentThreads") + + } + + /** + * Reconfigure log clean config. This simply stops current log cleaners and creates new ones. + * That ensures that if any of the cleaners had failed, new cleaners are created to match the new config. + */ + override def reconfigure(oldConfig: KafkaConfig, newConfig: KafkaConfig): Unit = { + config = LogCleaner.cleanerConfig(newConfig) + shutdown() + startup() + } + + /** + * Abort the cleaning of a particular partition, if it's in progress. This call blocks until the cleaning of + * the partition is aborted. + */ + def abortCleaning(topicPartition: TopicPartition): Unit = { + cleanerManager.abortCleaning(topicPartition) + } + + /** + * Update checkpoint file to remove partitions if necessary. 
+ */ + def updateCheckpoints(dataDir: File, partitionToRemove: Option[TopicPartition] = None): Unit = { + cleanerManager.updateCheckpoints(dataDir, partitionToRemove = partitionToRemove) + } + + /** + * alter the checkpoint directory for the topicPartition, to remove the data in sourceLogDir, and add the data in destLogDir + */ + def alterCheckpointDir(topicPartition: TopicPartition, sourceLogDir: File, destLogDir: File): Unit = { + cleanerManager.alterCheckpointDir(topicPartition, sourceLogDir, destLogDir) + } + + /** + * Stop cleaning logs in the provided directory + * + * @param dir the absolute path of the log dir + */ + def handleLogDirFailure(dir: String): Unit = { + cleanerManager.handleLogDirFailure(dir) + } + + /** + * Truncate cleaner offset checkpoint for the given partition if its checkpointed offset is larger than the given offset + */ + def maybeTruncateCheckpoint(dataDir: File, topicPartition: TopicPartition, offset: Long): Unit = { + cleanerManager.maybeTruncateCheckpoint(dataDir, topicPartition, offset) + } + + /** + * Abort the cleaning of a particular partition if it's in progress, and pause any future cleaning of this partition. + * This call blocks until the cleaning of the partition is aborted and paused. + */ + def abortAndPauseCleaning(topicPartition: TopicPartition): Unit = { + cleanerManager.abortAndPauseCleaning(topicPartition) + } + + /** + * Resume the cleaning of paused partitions. + */ + def resumeCleaning(topicPartitions: Iterable[TopicPartition]): Unit = { + cleanerManager.resumeCleaning(topicPartitions) + } + + /** + * For testing, a way to know when work has completed. This method waits until the + * cleaner has processed up to the given offset on the specified topic/partition + * + * @param topicPartition The topic and partition to be cleaned + * @param offset The first dirty offset that the cleaner doesn't have to clean + * @param maxWaitMs The maximum time in ms to wait for cleaner + * + * @return A boolean indicating whether the work has completed before timeout + */ + def awaitCleaned(topicPartition: TopicPartition, offset: Long, maxWaitMs: Long = 60000L): Boolean = { + def isCleaned = cleanerManager.allCleanerCheckpoints.get(topicPartition).fold(false)(_ >= offset) + var remainingWaitMs = maxWaitMs + while (!isCleaned && remainingWaitMs > 0) { + val sleepTime = math.min(100, remainingWaitMs) + Thread.sleep(sleepTime) + remainingWaitMs -= sleepTime + } + isCleaned + } + + /** + * To prevent race between retention and compaction, + * retention threads need to make this call to obtain: + * @return A list of log partitions that retention threads can safely work on + */ + def pauseCleaningForNonCompactedPartitions(): Iterable[(TopicPartition, UnifiedLog)] = { + cleanerManager.pauseCleaningForNonCompactedPartitions() + } + + // Only for testing + private[kafka] def currentConfig: CleanerConfig = config + + // Only for testing + private[log] def cleanerCount: Int = cleaners.size + + /** + * The cleaner threads do the actual log cleaning. Each thread processes does its cleaning repeatedly by + * choosing the dirtiest log, cleaning it, and then swapping in the cleaned segments. 
+ */ + private[log] class CleanerThread(threadId: Int) + extends ShutdownableThread(name = s"kafka-log-cleaner-thread-$threadId", isInterruptible = false) { + + protected override def loggerName = classOf[LogCleaner].getName + + if (config.dedupeBufferSize / config.numThreads > Int.MaxValue) + warn("Cannot use more than 2G of cleaner buffer space per cleaner thread, ignoring excess buffer space...") + + val cleaner = new Cleaner(id = threadId, + offsetMap = new SkimpyOffsetMap(memory = math.min(config.dedupeBufferSize / config.numThreads, Int.MaxValue).toInt, + hashAlgorithm = config.hashAlgorithm), + ioBufferSize = config.ioBufferSize / config.numThreads / 2, + maxIoBufferSize = config.maxMessageSize, + dupBufferLoadFactor = config.dedupeBufferLoadFactor, + throttler = throttler, + time = time, + checkDone = checkDone) + + @volatile var lastStats: CleanerStats = new CleanerStats() + @volatile var lastPreCleanStats: PreCleanStats = new PreCleanStats() + + private def checkDone(topicPartition: TopicPartition): Unit = { + if (!isRunning) + throw new ThreadShutdownException + cleanerManager.checkCleaningAborted(topicPartition) + } + + /** + * The main loop for the cleaner thread + * Clean a log if there is a dirty log available, otherwise sleep for a bit + */ + override def doWork(): Unit = { + val cleaned = tryCleanFilthiestLog() + if (!cleaned) + pause(config.backOffMs, TimeUnit.MILLISECONDS) + + cleanerManager.maintainUncleanablePartitions() + } + + /** + * Cleans a log if there is a dirty log available + * @return whether a log was cleaned + */ + private def tryCleanFilthiestLog(): Boolean = { + try { + cleanFilthiestLog() + } catch { + case e: LogCleaningException => + warn(s"Unexpected exception thrown when cleaning log ${e.log}. Marking its partition (${e.log.topicPartition}) as uncleanable", e) + cleanerManager.markPartitionUncleanable(e.log.parentDir, e.log.topicPartition) + + false + } + } + + @throws(classOf[LogCleaningException]) + private def cleanFilthiestLog(): Boolean = { + val preCleanStats = new PreCleanStats() + val ltc = cleanerManager.grabFilthiestCompactedLog(time, preCleanStats) + val cleaned = ltc match { + case None => + false + case Some(cleanable) => + // there's a log, clean it + this.lastPreCleanStats = preCleanStats + try { + cleanLog(cleanable) + true + } catch { + case e @ (_: ThreadShutdownException | _: ControlThrowable) => throw e + case e: Exception => throw new LogCleaningException(cleanable.log, e.getMessage, e) + } + } + val deletable: Iterable[(TopicPartition, UnifiedLog)] = cleanerManager.deletableLogs() + try { + deletable.foreach { case (_, log) => + try { + log.deleteOldSegments() + } catch { + case e @ (_: ThreadShutdownException | _: ControlThrowable) => throw e + case e: Exception => throw new LogCleaningException(log, e.getMessage, e) + } + } + } finally { + cleanerManager.doneDeleting(deletable.map(_._1)) + } + + cleaned + } + + private def cleanLog(cleanable: LogToClean): Unit = { + val startOffset = cleanable.firstDirtyOffset + var endOffset = startOffset + try { + val (nextDirtyOffset, cleanerStats) = cleaner.clean(cleanable) + endOffset = nextDirtyOffset + recordStats(cleaner.id, cleanable.log.name, startOffset, endOffset, cleanerStats) + } catch { + case _: LogCleaningAbortedException => // task can be aborted, let it go. + case _: KafkaStorageException => // partition is already offline. let it go. 
+ case e: IOException => + val logDirectory = cleanable.log.parentDir + val msg = s"Failed to clean up log for ${cleanable.topicPartition} in dir $logDirectory due to IOException" + logDirFailureChannel.maybeAddOfflineLogDir(logDirectory, msg, e) + } finally { + cleanerManager.doneCleaning(cleanable.topicPartition, cleanable.log.parentDirFile, endOffset) + } + } + + /** + * Log out statistics on a single run of the cleaner. + */ + def recordStats(id: Int, name: String, from: Long, to: Long, stats: CleanerStats): Unit = { + this.lastStats = stats + def mb(bytes: Double) = bytes / (1024*1024) + val message = + "%n\tLog cleaner thread %d cleaned log %s (dirty section = [%d, %d])%n".format(id, name, from, to) + + "\t%,.1f MB of log processed in %,.1f seconds (%,.1f MB/sec).%n".format(mb(stats.bytesRead.toDouble), + stats.elapsedSecs, + mb(stats.bytesRead.toDouble / stats.elapsedSecs)) + + "\tIndexed %,.1f MB in %.1f seconds (%,.1f Mb/sec, %.1f%% of total time)%n".format(mb(stats.mapBytesRead.toDouble), + stats.elapsedIndexSecs, + mb(stats.mapBytesRead.toDouble) / stats.elapsedIndexSecs, + 100 * stats.elapsedIndexSecs / stats.elapsedSecs) + + "\tBuffer utilization: %.1f%%%n".format(100 * stats.bufferUtilization) + + "\tCleaned %,.1f MB in %.1f seconds (%,.1f Mb/sec, %.1f%% of total time)%n".format(mb(stats.bytesRead.toDouble), + stats.elapsedSecs - stats.elapsedIndexSecs, + mb(stats.bytesRead.toDouble) / (stats.elapsedSecs - stats.elapsedIndexSecs), 100 * (stats.elapsedSecs - stats.elapsedIndexSecs) / stats.elapsedSecs) + + "\tStart size: %,.1f MB (%,d messages)%n".format(mb(stats.bytesRead.toDouble), stats.messagesRead) + + "\tEnd size: %,.1f MB (%,d messages)%n".format(mb(stats.bytesWritten.toDouble), stats.messagesWritten) + + "\t%.1f%% size reduction (%.1f%% fewer messages)%n".format(100.0 * (1.0 - stats.bytesWritten.toDouble/stats.bytesRead), + 100.0 * (1.0 - stats.messagesWritten.toDouble/stats.messagesRead)) + info(message) + if (lastPreCleanStats.delayedPartitions > 0) { + info("\tCleanable partitions: %d, Delayed partitions: %d, max delay: %d".format(lastPreCleanStats.cleanablePartitions, lastPreCleanStats.delayedPartitions, lastPreCleanStats.maxCompactionDelayMs)) + } + if (stats.invalidMessagesRead > 0) { + warn("\tFound %d invalid messages during compaction.".format(stats.invalidMessagesRead)) + } + } + + } +} + +object LogCleaner { + val ReconfigurableConfigs = Set( + KafkaConfig.LogCleanerThreadsProp, + KafkaConfig.LogCleanerDedupeBufferSizeProp, + KafkaConfig.LogCleanerDedupeBufferLoadFactorProp, + KafkaConfig.LogCleanerIoBufferSizeProp, + KafkaConfig.MessageMaxBytesProp, + KafkaConfig.LogCleanerIoMaxBytesPerSecondProp, + KafkaConfig.LogCleanerBackoffMsProp + ) + + def cleanerConfig(config: KafkaConfig): CleanerConfig = { + CleanerConfig(numThreads = config.logCleanerThreads, + dedupeBufferSize = config.logCleanerDedupeBufferSize, + dedupeBufferLoadFactor = config.logCleanerDedupeBufferLoadFactor, + ioBufferSize = config.logCleanerIoBufferSize, + maxMessageSize = config.messageMaxBytes, + maxIoBytesPerSecond = config.logCleanerIoMaxBytesPerSecond, + backOffMs = config.logCleanerBackoffMs, + enableCleaner = config.logCleanerEnable) + + } +} + +/** + * This class holds the actual logic for cleaning a log + * @param id An identifier used for logging + * @param offsetMap The map used for deduplication + * @param ioBufferSize The size of the buffers to use. Memory usage will be 2x this number as there is a read and write buffer. 
+ * @param maxIoBufferSize The maximum size of a message that can appear in the log + * @param dupBufferLoadFactor The maximum percent full for the deduplication buffer + * @param throttler The throttler instance to use for limiting I/O rate. + * @param time The time instance + * @param checkDone Check if the cleaning for a partition is finished or aborted. + */ +private[log] class Cleaner(val id: Int, + val offsetMap: OffsetMap, + ioBufferSize: Int, + maxIoBufferSize: Int, + dupBufferLoadFactor: Double, + throttler: Throttler, + time: Time, + checkDone: TopicPartition => Unit) extends Logging { + + protected override def loggerName = classOf[LogCleaner].getName + + this.logIdent = s"Cleaner $id: " + + /* buffer used for read i/o */ + private var readBuffer = ByteBuffer.allocate(ioBufferSize) + + /* buffer used for write i/o */ + private var writeBuffer = ByteBuffer.allocate(ioBufferSize) + + private val decompressionBufferSupplier = BufferSupplier.create(); + + require(offsetMap.slots * dupBufferLoadFactor > 1, "offset map is too small to fit in even a single message, so log cleaning will never make progress. You can increase log.cleaner.dedupe.buffer.size or decrease log.cleaner.threads") + + /** + * Clean the given log + * + * @param cleanable The log to be cleaned + * + * @return The first offset not cleaned and the statistics for this round of cleaning + */ + private[log] def clean(cleanable: LogToClean): (Long, CleanerStats) = { + doClean(cleanable, time.milliseconds()) + } + + private[log] def doClean(cleanable: LogToClean, currentTime: Long): (Long, CleanerStats) = { + info("Beginning cleaning of log %s".format(cleanable.log.name)) + + // figure out the timestamp below which it is safe to remove delete tombstones + // this position is defined to be a configurable time beneath the last modified time of the last clean segment + // this timestamp is only used on the older message formats older than MAGIC_VALUE_V2 + val legacyDeleteHorizonMs = + cleanable.log.logSegments(0, cleanable.firstDirtyOffset).lastOption match { + case None => 0L + case Some(seg) => seg.lastModified - cleanable.log.config.deleteRetentionMs + } + + val log = cleanable.log + val stats = new CleanerStats() + + // build the offset map + info("Building offset map for %s...".format(cleanable.log.name)) + val upperBoundOffset = cleanable.firstUncleanableOffset + buildOffsetMap(log, cleanable.firstDirtyOffset, upperBoundOffset, offsetMap, stats) + val endOffset = offsetMap.latestOffset + 1 + stats.indexDone() + + // determine the timestamp up to which the log will be cleaned + // this is the lower of the last active segment and the compaction lag + val cleanableHorizonMs = log.logSegments(0, cleanable.firstUncleanableOffset).lastOption.map(_.lastModified).getOrElse(0L) + + // group the segments and clean the groups + info("Cleaning log %s (cleaning prior to %s, discarding tombstones prior to upper bound deletion horizon %s)...".format(log.name, new Date(cleanableHorizonMs), new Date(legacyDeleteHorizonMs))) + val transactionMetadata = new CleanedTransactionMetadata + + val groupedSegments = groupSegmentsBySize(log.logSegments(0, endOffset), log.config.segmentSize, + log.config.maxIndexSize, cleanable.firstUncleanableOffset) + for (group <- groupedSegments) + cleanSegments(log, group, offsetMap, currentTime, stats, transactionMetadata, legacyDeleteHorizonMs) + + // record buffer utilization + stats.bufferUtilization = offsetMap.utilization + + stats.allDone() + + (endOffset, stats) + } + + /** + * Clean a group of 
segments into a single replacement segment + * + * @param log The log being cleaned + * @param segments The group of segments being cleaned + * @param map The offset map to use for cleaning segments + * @param currentTime The current time in milliseconds + * @param stats Collector for cleaning statistics + * @param transactionMetadata State of ongoing transactions which is carried between the cleaning + * of the grouped segments + * @param legacyDeleteHorizonMs The delete horizon used for tombstones whose version is less than 2 + */ + private[log] def cleanSegments(log: UnifiedLog, + segments: Seq[LogSegment], + map: OffsetMap, + currentTime: Long, + stats: CleanerStats, + transactionMetadata: CleanedTransactionMetadata, + legacyDeleteHorizonMs: Long): Unit = { + // create a new segment with a suffix appended to the name of the log and indexes + val cleaned = UnifiedLog.createNewCleanedSegment(log.dir, log.config, segments.head.baseOffset) + transactionMetadata.cleanedIndex = Some(cleaned.txnIndex) + + try { + // clean segments into the new destination segment + val iter = segments.iterator + var currentSegmentOpt: Option[LogSegment] = Some(iter.next()) + val lastOffsetOfActiveProducers = log.lastRecordsOfActiveProducers + + while (currentSegmentOpt.isDefined) { + val currentSegment = currentSegmentOpt.get + val nextSegmentOpt = if (iter.hasNext) Some(iter.next()) else None + + val startOffset = currentSegment.baseOffset + val upperBoundOffset = nextSegmentOpt.map(_.baseOffset).getOrElse(map.latestOffset + 1) + val abortedTransactions = log.collectAbortedTransactions(startOffset, upperBoundOffset) + transactionMetadata.addAbortedTransactions(abortedTransactions) + + val retainLegacyDeletesAndTxnMarkers = currentSegment.lastModified > legacyDeleteHorizonMs + info(s"Cleaning $currentSegment in log ${log.name} into ${cleaned.baseOffset} " + + s"with an upper bound deletion horizon $legacyDeleteHorizonMs computed from " + + s"the segment last modified time of ${currentSegment.lastModified}," + + s"${if(retainLegacyDeletesAndTxnMarkers) "retaining" else "discarding"} deletes.") + + try { + cleanInto(log.topicPartition, currentSegment.log, cleaned, map, retainLegacyDeletesAndTxnMarkers, log.config.deleteRetentionMs, + log.config.maxMessageSize, transactionMetadata, lastOffsetOfActiveProducers, stats, currentTime = currentTime) + } catch { + case e: LogSegmentOffsetOverflowException => + // Split the current segment. It's also safest to abort the current cleaning process, so that we retry from + // scratch once the split is complete. 
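// ---------------------------------------------------------------------------
// [Review note] Illustrative sketch, not part of the patch: the decision above
// about retaining legacy (pre-v2) tombstones reduces to comparing a segment's
// last-modified time against a horizon derived from delete.retention.ms. The
// helper and the timestamps below are hypothetical, standalone stand-ins.
// ---------------------------------------------------------------------------
object LegacyDeleteHorizonSketch {
  /** Tombstones in a pre-v2 segment are kept while the segment is newer than the horizon. */
  def retainLegacyDeletes(segmentLastModifiedMs: Long, legacyDeleteHorizonMs: Long): Boolean =
    segmentLastModifiedMs > legacyDeleteHorizonMs

  def main(args: Array[String]): Unit = {
    val deleteRetentionMs = 24L * 60 * 60 * 1000      // delete.retention.ms, e.g. one day
    val lastCleanSegmentModifiedMs = 1700000000000L   // hypothetical mtime of the last clean segment
    val horizon = lastCleanSegmentModifiedMs - deleteRetentionMs
    println(retainLegacyDeletes(horizon + 1, horizon)) // true:  keep tombstones
    println(retainLegacyDeletes(horizon - 1, horizon)) // false: tombstones may be discarded
  }
}
// ---------------------------------------------------------------------------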
+ info(s"Caught segment overflow error during cleaning: ${e.getMessage}") + log.splitOverflowedSegment(currentSegment) + throw new LogCleaningAbortedException() + } + currentSegmentOpt = nextSegmentOpt + } + + cleaned.onBecomeInactiveSegment() + // flush new segment to disk before swap + cleaned.flush() + + // update the modification date to retain the last modified date of the original files + val modified = segments.last.lastModified + cleaned.lastModified = modified + + // swap in new segment + info(s"Swapping in cleaned segment $cleaned for segment(s) $segments in log $log") + log.replaceSegments(List(cleaned), segments) + } catch { + case e: LogCleaningAbortedException => + try cleaned.deleteIfExists() + catch { + case deleteException: Exception => + e.addSuppressed(deleteException) + } finally throw e + } + } + + /** + * Clean the given source log segment into the destination segment using the key=>offset mapping + * provided + * + * @param topicPartition The topic and partition of the log segment to clean + * @param sourceRecords The dirty log segment + * @param dest The cleaned log segment + * @param map The key=>offset mapping + * @param retainLegacyDeletesAndTxnMarkers Should tombstones (lower than version 2) and markers be retained while cleaning this segment + * @param deleteRetentionMs Defines how long a tombstone should be kept as defined by log configuration + * @param maxLogMessageSize The maximum message size of the corresponding topic + * @param stats Collector for cleaning statistics + * @param currentTime The time at which the clean was initiated + */ + private[log] def cleanInto(topicPartition: TopicPartition, + sourceRecords: FileRecords, + dest: LogSegment, + map: OffsetMap, + retainLegacyDeletesAndTxnMarkers: Boolean, + deleteRetentionMs: Long, + maxLogMessageSize: Int, + transactionMetadata: CleanedTransactionMetadata, + lastRecordsOfActiveProducers: Map[Long, LastRecord], + stats: CleanerStats, + currentTime: Long): Unit = { + val logCleanerFilter: RecordFilter = new RecordFilter(currentTime, deleteRetentionMs) { + var discardBatchRecords: Boolean = _ + + override def checkBatchRetention(batch: RecordBatch): RecordFilter.BatchRetentionResult = { + // we piggy-back on the tombstone retention logic to delay deletion of transaction markers. + // note that we will never delete a marker until all the records from that transaction are removed. + val canDiscardBatch = shouldDiscardBatch(batch, transactionMetadata) + + if (batch.isControlBatch) + discardBatchRecords = canDiscardBatch && batch.deleteHorizonMs().isPresent && batch.deleteHorizonMs().getAsLong <= currentTime + else + discardBatchRecords = canDiscardBatch + + def isBatchLastRecordOfProducer: Boolean = { + // We retain the batch in order to preserve the state of active producers. There are three cases: + // 1) The producer is no longer active, which means we can delete all records for that producer. + // 2) The producer is still active and has a last data offset. We retain the batch that contains + // this offset since it also contains the last sequence number for this producer. + // 3) The last entry in the log is a transaction marker. We retain this marker since it has the + // last producer epoch, which is needed to ensure fencing. 
+ lastRecordsOfActiveProducers.get(batch.producerId).exists { lastRecord => + lastRecord.lastDataOffset match { + case Some(offset) => batch.lastOffset == offset + case None => batch.isControlBatch && batch.producerEpoch == lastRecord.producerEpoch + } + } + } + + val batchRetention: BatchRetention = + if (batch.hasProducerId && isBatchLastRecordOfProducer) + BatchRetention.RETAIN_EMPTY + else if (discardBatchRecords) + BatchRetention.DELETE + else + BatchRetention.DELETE_EMPTY + new RecordFilter.BatchRetentionResult(batchRetention, canDiscardBatch && batch.isControlBatch) + } + + override def shouldRetainRecord(batch: RecordBatch, record: Record): Boolean = { + if (discardBatchRecords) + // The batch is only retained to preserve producer sequence information; the records can be removed + false + else + Cleaner.this.shouldRetainRecord(map, retainLegacyDeletesAndTxnMarkers, batch, record, stats, currentTime = currentTime) + } + } + + var position = 0 + while (position < sourceRecords.sizeInBytes) { + checkDone(topicPartition) + // read a chunk of messages and copy any that are to be retained to the write buffer to be written out + readBuffer.clear() + writeBuffer.clear() + + sourceRecords.readInto(readBuffer, position) + val records = MemoryRecords.readableRecords(readBuffer) + throttler.maybeThrottle(records.sizeInBytes) + val result = records.filterTo(topicPartition, logCleanerFilter, writeBuffer, maxLogMessageSize, decompressionBufferSupplier) + + stats.readMessages(result.messagesRead, result.bytesRead) + stats.recopyMessages(result.messagesRetained, result.bytesRetained) + + position += result.bytesRead + + // if any messages are to be retained, write them out + val outputBuffer = result.outputBuffer + if (outputBuffer.position() > 0) { + outputBuffer.flip() + val retained = MemoryRecords.readableRecords(outputBuffer) + // it's OK not to hold the Log's lock in this case, because this segment is only accessed by other threads + // after `Log.replaceSegments` (which acquires the lock) is called + dest.append(largestOffset = result.maxOffset, + largestTimestamp = result.maxTimestamp, + shallowOffsetOfMaxTimestamp = result.shallowOffsetOfMaxTimestamp, + records = retained) + throttler.maybeThrottle(outputBuffer.limit()) + } + + // if we read bytes but didn't get even one complete batch, our I/O buffer is too small, grow it and try again + // `result.bytesRead` contains bytes from `messagesRead` and any discarded batches. + if (readBuffer.limit() > 0 && result.bytesRead == 0) + growBuffersOrFail(sourceRecords, position, maxLogMessageSize, records) + } + restoreBuffers() + } + + + /** + * Grow buffers to process next batch of records from `sourceRecords.` Buffers are doubled in size + * up to a maximum of `maxLogMessageSize`. In some scenarios, a record could be bigger than the + * current maximum size configured for the log. For example: + * 1. A compacted topic using compression may contain a message set slightly larger than max.message.bytes + * 2. max.message.bytes of a topic could have been reduced after writing larger messages + * In these cases, grow the buffer to hold the next batch. 
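// ---------------------------------------------------------------------------
// [Review note] Illustrative sketch, not part of the patch and not the real
// RecordFilter API: it restates the three-way retention choice made in
// checkBatchRetention above as a pure function over booleans, with a local
// enum standing in for BatchRetention. All names here are hypothetical.
// ---------------------------------------------------------------------------
object BatchRetentionSketch {
  sealed trait Retention
  case object RetainEmpty extends Retention // keep the batch even if all its records are removed (preserves producer state)
  case object Delete      extends Retention // drop the whole batch
  case object DeleteEmpty extends Retention // keep surviving records, drop the batch only if it ends up empty

  def retentionFor(hasProducerId: Boolean,
                   isLastBatchOfActiveProducer: Boolean,
                   recordsCanBeDiscarded: Boolean): Retention =
    if (hasProducerId && isLastBatchOfActiveProducer) RetainEmpty
    else if (recordsCanBeDiscarded) Delete
    else DeleteEmpty

  def main(args: Array[String]): Unit = {
    println(retentionFor(hasProducerId = true,  isLastBatchOfActiveProducer = true,  recordsCanBeDiscarded = true))  // RetainEmpty
    println(retentionFor(hasProducerId = false, isLastBatchOfActiveProducer = false, recordsCanBeDiscarded = true))  // Delete
    println(retentionFor(hasProducerId = true,  isLastBatchOfActiveProducer = false, recordsCanBeDiscarded = false)) // DeleteEmpty
  }
}
// ---------------------------------------------------------------------------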
+   */
+  private def growBuffersOrFail(sourceRecords: FileRecords,
+                                position: Int,
+                                maxLogMessageSize: Int,
+                                memoryRecords: MemoryRecords): Unit = {
+
+    val maxSize = if (readBuffer.capacity >= maxLogMessageSize) {
+      val nextBatchSize = memoryRecords.firstBatchSize
+      val logDesc = s"log segment ${sourceRecords.file} at position $position"
+      if (nextBatchSize == null)
+        throw new IllegalStateException(s"Could not determine next batch size for $logDesc")
+      if (nextBatchSize <= 0)
+        throw new IllegalStateException(s"Invalid batch size $nextBatchSize for $logDesc")
+      if (nextBatchSize <= readBuffer.capacity)
+        throw new IllegalStateException(s"Batch size $nextBatchSize < buffer size ${readBuffer.capacity}, but not processed for $logDesc")
+      val bytesLeft = sourceRecords.channel.size - position
+      if (nextBatchSize > bytesLeft)
+        throw new CorruptRecordException(s"Log segment may be corrupt, batch size $nextBatchSize > $bytesLeft bytes left in segment for $logDesc")
+      nextBatchSize.intValue
+    } else
+      maxLogMessageSize
+
+    growBuffers(maxSize)
+  }
+
+  private def shouldDiscardBatch(batch: RecordBatch,
+                                 transactionMetadata: CleanedTransactionMetadata): Boolean = {
+    if (batch.isControlBatch)
+      transactionMetadata.onControlBatchRead(batch)
+    else
+      transactionMetadata.onBatchRead(batch)
+  }
+
+  private def shouldRetainRecord(map: kafka.log.OffsetMap,
+                                 retainDeletesForLegacyRecords: Boolean,
+                                 batch: RecordBatch,
+                                 record: Record,
+                                 stats: CleanerStats,
+                                 currentTime: Long): Boolean = {
+    val pastLatestOffset = record.offset > map.latestOffset
+    if (pastLatestOffset)
+      return true
+
+    if (record.hasKey) {
+      val key = record.key
+      val foundOffset = map.get(key)
+      /* First, the message must have the latest offset for the key;
+       * then there are two cases in which we can retain a message:
+       *   1) The message has a value
+       *   2) The message doesn't have a value but it can't be deleted yet.
+       */
+      val latestOffsetForKey = record.offset() >= foundOffset
+      val legacyRecord = batch.magic() < RecordBatch.MAGIC_VALUE_V2
+      def shouldRetainDeletes = {
+        if (!legacyRecord)
+          !batch.deleteHorizonMs().isPresent || currentTime < batch.deleteHorizonMs().getAsLong
+        else
+          retainDeletesForLegacyRecords
+      }
+      val isRetainedValue = record.hasValue || shouldRetainDeletes
+      latestOffsetForKey && isRetainedValue
+    } else {
+      stats.invalidMessage()
+      false
+    }
+  }
+
+  /**
+   * Double the I/O buffer capacity
+   */
+  def growBuffers(maxLogMessageSize: Int): Unit = {
+    val maxBufferSize = math.max(maxLogMessageSize, maxIoBufferSize)
+    if(readBuffer.capacity >= maxBufferSize || writeBuffer.capacity >= maxBufferSize)
+      throw new IllegalStateException("This log contains a message larger than maximum allowable size of %s.".format(maxBufferSize))
+    val newSize = math.min(this.readBuffer.capacity * 2, maxBufferSize)
+    info(s"Growing cleaner I/O buffers from ${readBuffer.capacity} bytes to $newSize bytes.")
+    this.readBuffer = ByteBuffer.allocate(newSize)
+    this.writeBuffer = ByteBuffer.allocate(newSize)
+  }
+
+  /**
+   * Restore the I/O buffer capacity to its original size
+   */
+  def restoreBuffers(): Unit = {
+    if(this.readBuffer.capacity > this.ioBufferSize)
+      this.readBuffer = ByteBuffer.allocate(this.ioBufferSize)
+    if(this.writeBuffer.capacity > this.ioBufferSize)
+      this.writeBuffer = ByteBuffer.allocate(this.ioBufferSize)
+  }
+
+  /**
+   * Group the segments in a log into groups totaling less than a given size. The size is enforced separately for the log data and the index data.
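// ---------------------------------------------------------------------------
// [Review note] Illustrative sketch, not part of the patch: it condenses the
// core rule of shouldRetainRecord above, with an immutable Map[String, Long]
// standing in for the OffsetMap. A keyed record survives only if it still
// carries the latest offset seen for its key, and a tombstone additionally
// survives only while its retention window is open. Names are hypothetical.
// ---------------------------------------------------------------------------
object KeyRetentionSketch {
  def shouldRetain(latestOffsetByKey: Map[String, Long],
                   key: String,
                   offset: Long,
                   hasValue: Boolean,
                   tombstoneStillRetained: Boolean): Boolean = {
    val latestForKey = offset >= latestOffsetByKey.getOrElse(key, -1L)
    latestForKey && (hasValue || tombstoneStillRetained)
  }

  def main(args: Array[String]): Unit = {
    val map = Map("k1" -> 10L, "k2" -> 7L)
    println(shouldRetain(map, "k1", offset = 10L, hasValue = true,  tombstoneStillRetained = false)) // true:  latest value
    println(shouldRetain(map, "k1", offset = 3L,  hasValue = true,  tombstoneStillRetained = false)) // false: superseded
    println(shouldRetain(map, "k2", offset = 7L,  hasValue = false, tombstoneStillRetained = false)) // false: expired tombstone
  }
}
// ---------------------------------------------------------------------------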
+ * We collect a group of such segments together into a single + * destination segment. This prevents segment sizes from shrinking too much. + * + * @param segments The log segments to group + * @param maxSize the maximum size in bytes for the total of all log data in a group + * @param maxIndexSize the maximum size in bytes for the total of all index data in a group + * + * @return A list of grouped segments + */ + private[log] def groupSegmentsBySize(segments: Iterable[LogSegment], maxSize: Int, maxIndexSize: Int, firstUncleanableOffset: Long): List[Seq[LogSegment]] = { + var grouped = List[List[LogSegment]]() + var segs = segments.toList + while(segs.nonEmpty) { + var group = List(segs.head) + var logSize = segs.head.size.toLong + var indexSize = segs.head.offsetIndex.sizeInBytes.toLong + var timeIndexSize = segs.head.timeIndex.sizeInBytes.toLong + segs = segs.tail + while(segs.nonEmpty && + logSize + segs.head.size <= maxSize && + indexSize + segs.head.offsetIndex.sizeInBytes <= maxIndexSize && + timeIndexSize + segs.head.timeIndex.sizeInBytes <= maxIndexSize && + //if first segment size is 0, we don't need to do the index offset range check. + //this will avoid empty log left every 2^31 message. + (segs.head.size == 0 || + lastOffsetForFirstSegment(segs, firstUncleanableOffset) - group.last.baseOffset <= Int.MaxValue)) { + group = segs.head :: group + logSize += segs.head.size + indexSize += segs.head.offsetIndex.sizeInBytes + timeIndexSize += segs.head.timeIndex.sizeInBytes + segs = segs.tail + } + grouped ::= group.reverse + } + grouped.reverse + } + + /** + * We want to get the last offset in the first log segment in segs. + * LogSegment.nextOffset() gives the exact last offset in a segment, but can be expensive since it requires + * scanning the segment from the last index entry. + * Therefore, we estimate the last offset of the first log segment by using + * the base offset of the next segment in the list. + * If the next segment doesn't exist, first Uncleanable Offset will be used. + * + * @param segs - remaining segments to group. + * @return The estimated last offset for the first segment in segs + */ + private def lastOffsetForFirstSegment(segs: List[LogSegment], firstUncleanableOffset: Long): Long = { + if (segs.size > 1) { + /* if there is a next segment, use its base offset as the bounding offset to guarantee we know + * the worst case offset */ + segs(1).baseOffset - 1 + } else { + //for the last segment in the list, use the first uncleanable offset. + firstUncleanableOffset - 1 + } + } + + /** + * Build a map of key_hash => offset for the keys in the cleanable dirty portion of the log to use in cleaning. 
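// ---------------------------------------------------------------------------
// [Review note] Illustrative sketch, not part of the patch: a stripped-down
// version of the greedy grouping performed by groupSegmentsBySize, operating
// on plain (logBytes, indexBytes) pairs instead of LogSegments and omitting
// the Int.MaxValue offset-range check. It shows how consecutive small
// segments are batched into one replacement segment while both size ceilings
// are respected. Names and sizes are hypothetical.
// ---------------------------------------------------------------------------
object SegmentGroupingSketch {
  final case class Seg(logBytes: Long, indexBytes: Long)

  def group(segs: List[Seg], maxLogBytes: Long, maxIndexBytes: Long): List[List[Seg]] =
    if (segs.isEmpty) Nil
    else {
      var current   = List(segs.head)
      var logSize   = segs.head.logBytes
      var indexSize = segs.head.indexBytes
      var rest      = segs.tail
      while (rest.nonEmpty &&
             logSize + rest.head.logBytes <= maxLogBytes &&
             indexSize + rest.head.indexBytes <= maxIndexBytes) {
        current = rest.head :: current
        logSize += rest.head.logBytes
        indexSize += rest.head.indexBytes
        rest = rest.tail
      }
      current.reverse :: group(rest, maxLogBytes, maxIndexBytes)
    }

  def main(args: Array[String]): Unit = {
    val segs = List(Seg(400, 4), Seg(300, 3), Seg(500, 5), Seg(100, 1))
    println(group(segs, maxLogBytes = 1000, maxIndexBytes = 100))
    // List(List(Seg(400,4), Seg(300,3)), List(Seg(500,5), Seg(100,1)))
  }
}
// ---------------------------------------------------------------------------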
+ * @param log The log to use + * @param start The offset at which dirty messages begin + * @param end The ending offset for the map that is being built + * @param map The map in which to store the mappings + * @param stats Collector for cleaning statistics + */ + private[log] def buildOffsetMap(log: UnifiedLog, + start: Long, + end: Long, + map: OffsetMap, + stats: CleanerStats): Unit = { + map.clear() + val dirty = log.logSegments(start, end).toBuffer + val nextSegmentStartOffsets = new ListBuffer[Long] + if (dirty.nonEmpty) { + for (nextSegment <- dirty.tail) nextSegmentStartOffsets.append(nextSegment.baseOffset) + nextSegmentStartOffsets.append(end) + } + info("Building offset map for log %s for %d segments in offset range [%d, %d).".format(log.name, dirty.size, start, end)) + + val transactionMetadata = new CleanedTransactionMetadata + val abortedTransactions = log.collectAbortedTransactions(start, end) + transactionMetadata.addAbortedTransactions(abortedTransactions) + + // Add all the cleanable dirty segments. We must take at least map.slots * load_factor, + // but we may be able to fit more (if there is lots of duplication in the dirty section of the log) + var full = false + for ((segment, nextSegmentStartOffset) <- dirty.zip(nextSegmentStartOffsets) if !full) { + checkDone(log.topicPartition) + + full = buildOffsetMapForSegment(log.topicPartition, segment, map, start, nextSegmentStartOffset, log.config.maxMessageSize, + transactionMetadata, stats) + if (full) + debug("Offset map is full, %d segments fully mapped, segment with base offset %d is partially mapped".format(dirty.indexOf(segment), segment.baseOffset)) + } + info("Offset map for log %s complete.".format(log.name)) + } + + /** + * Add the messages in the given segment to the offset map + * + * @param segment The segment to index + * @param map The map in which to store the key=>offset mapping + * @param stats Collector for cleaning statistics + * + * @return If the map was filled whilst loading from this segment + */ + private def buildOffsetMapForSegment(topicPartition: TopicPartition, + segment: LogSegment, + map: OffsetMap, + startOffset: Long, + nextSegmentStartOffset: Long, + maxLogMessageSize: Int, + transactionMetadata: CleanedTransactionMetadata, + stats: CleanerStats): Boolean = { + var position = segment.offsetIndex.lookup(startOffset).position + val maxDesiredMapSize = (map.slots * this.dupBufferLoadFactor).toInt + while (position < segment.log.sizeInBytes) { + checkDone(topicPartition) + readBuffer.clear() + try { + segment.log.readInto(readBuffer, position) + } catch { + case e: Exception => + throw new KafkaException(s"Failed to read from segment $segment of partition $topicPartition " + + "while loading offset map", e) + } + val records = MemoryRecords.readableRecords(readBuffer) + throttler.maybeThrottle(records.sizeInBytes) + + val startPosition = position + for (batch <- records.batches.asScala) { + if (batch.isControlBatch) { + transactionMetadata.onControlBatchRead(batch) + stats.indexMessagesRead(1) + } else { + val isAborted = transactionMetadata.onBatchRead(batch) + if (isAborted) { + // If the batch is aborted, do not bother populating the offset map. + // Note that abort markers are supported in v2 and above, which means count is defined. 
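// ---------------------------------------------------------------------------
// [Review note] Illustrative sketch, not part of the patch: it mimics the
// "stop once map.slots * load factor is reached" behaviour of buildOffsetMap
// with a mutable.Map of plain (key, offset) pairs in place of SkimpyOffsetMap.
// Names and the slot budget are hypothetical.
// ---------------------------------------------------------------------------
import scala.collection.mutable

object OffsetMapBuildSketch {
  /** Returns (latest offset per key, whether the slot budget was exhausted). */
  def build(records: Iterator[(String, Long)], maxSlots: Int): (mutable.Map[String, Long], Boolean) = {
    val map = mutable.Map.empty[String, Long]
    var full = false
    while (!full && records.hasNext) {
      val (key, offset) = records.next()
      if (map.size < maxSlots) map(key) = offset // updating an existing key does not grow the map
      else full = true                           // budget hit: this round cleans only up to here
    }
    (map, full)
  }

  def main(args: Array[String]): Unit = {
    val dirty = Iterator("a" -> 0L, "b" -> 1L, "a" -> 2L, "c" -> 3L, "d" -> 4L)
    println(build(dirty, maxSlots = 3)) // latest offsets for a, b, c; full = true before reaching d
  }
}
// ---------------------------------------------------------------------------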
+ stats.indexMessagesRead(batch.countOrNull) + } else { + val recordsIterator = batch.streamingIterator(decompressionBufferSupplier) + try { + for (record <- recordsIterator.asScala) { + if (record.hasKey && record.offset >= startOffset) { + if (map.size < maxDesiredMapSize) + map.put(record.key, record.offset) + else + return true + } + stats.indexMessagesRead(1) + } + } finally recordsIterator.close() + } + } + + if (batch.lastOffset >= startOffset) + map.updateLatestOffset(batch.lastOffset) + } + val bytesRead = records.validBytes + position += bytesRead + stats.indexBytesRead(bytesRead) + + // if we didn't read even one complete message, our read buffer may be too small + if(position == startPosition) + growBuffersOrFail(segment.log, position, maxLogMessageSize, records) + } + + // In the case of offsets gap, fast forward to latest expected offset in this segment. + map.updateLatestOffset(nextSegmentStartOffset - 1L) + + restoreBuffers() + false + } +} + +/** + * A simple struct for collecting pre-clean stats + */ +private class PreCleanStats() { + var maxCompactionDelayMs = 0L + var delayedPartitions = 0 + var cleanablePartitions = 0 + + def updateMaxCompactionDelay(delayMs: Long): Unit = { + maxCompactionDelayMs = Math.max(maxCompactionDelayMs, delayMs) + if (delayMs > 0) { + delayedPartitions += 1 + } + } + def recordCleanablePartitions(numOfCleanables: Int): Unit = { + cleanablePartitions = numOfCleanables + } +} + +/** + * A simple struct for collecting stats about log cleaning + */ +private class CleanerStats(time: Time = Time.SYSTEM) { + val startTime = time.milliseconds + var mapCompleteTime = -1L + var endTime = -1L + var bytesRead = 0L + var bytesWritten = 0L + var mapBytesRead = 0L + var mapMessagesRead = 0L + var messagesRead = 0L + var invalidMessagesRead = 0L + var messagesWritten = 0L + var bufferUtilization = 0.0d + + def readMessages(messagesRead: Int, bytesRead: Int): Unit = { + this.messagesRead += messagesRead + this.bytesRead += bytesRead + } + + def invalidMessage(): Unit = { + invalidMessagesRead += 1 + } + + def recopyMessages(messagesWritten: Int, bytesWritten: Int): Unit = { + this.messagesWritten += messagesWritten + this.bytesWritten += bytesWritten + } + + def indexMessagesRead(size: Int): Unit = { + mapMessagesRead += size + } + + def indexBytesRead(size: Int): Unit = { + mapBytesRead += size + } + + def indexDone(): Unit = { + mapCompleteTime = time.milliseconds + } + + def allDone(): Unit = { + endTime = time.milliseconds + } + + def elapsedSecs: Double = (endTime - startTime) / 1000.0 + + def elapsedIndexSecs: Double = (mapCompleteTime - startTime) / 1000.0 + +} + +/** + * Helper class for a log, its topic/partition, the first cleanable position, the first uncleanable dirty position, + * and whether it needs compaction immediately. 
+ */ +private case class LogToClean(topicPartition: TopicPartition, + log: UnifiedLog, + firstDirtyOffset: Long, + uncleanableOffset: Long, + needCompactionNow: Boolean = false) extends Ordered[LogToClean] { + val cleanBytes = log.logSegments(-1, firstDirtyOffset).map(_.size.toLong).sum + val (firstUncleanableOffset, cleanableBytes) = LogCleanerManager.calculateCleanableBytes(log, firstDirtyOffset, uncleanableOffset) + val totalBytes = cleanBytes + cleanableBytes + val cleanableRatio = cleanableBytes / totalBytes.toDouble + override def compare(that: LogToClean): Int = math.signum(this.cleanableRatio - that.cleanableRatio).toInt +} + +/** + * This is a helper class to facilitate tracking transaction state while cleaning the log. It maintains a set + * of the ongoing aborted and committed transactions as the cleaner is working its way through the log. This + * class is responsible for deciding when transaction markers can be removed and is therefore also responsible + * for updating the cleaned transaction index accordingly. + */ +private[log] class CleanedTransactionMetadata { + private val ongoingCommittedTxns = mutable.Set.empty[Long] + private val ongoingAbortedTxns = mutable.Map.empty[Long, AbortedTransactionMetadata] + // Minheap of aborted transactions sorted by the transaction first offset + private val abortedTransactions = mutable.PriorityQueue.empty[AbortedTxn](new Ordering[AbortedTxn] { + override def compare(x: AbortedTxn, y: AbortedTxn): Int = x.firstOffset compare y.firstOffset + }.reverse) + + // Output cleaned index to write retained aborted transactions + var cleanedIndex: Option[TransactionIndex] = None + + def addAbortedTransactions(abortedTransactions: List[AbortedTxn]): Unit = { + this.abortedTransactions ++= abortedTransactions + } + + /** + * Update the cleaned transaction state with a control batch that has just been traversed by the cleaner. + * Return true if the control batch can be discarded. + */ + def onControlBatchRead(controlBatch: RecordBatch): Boolean = { + consumeAbortedTxnsUpTo(controlBatch.lastOffset) + + val controlRecordIterator = controlBatch.iterator + if (controlRecordIterator.hasNext) { + val controlRecord = controlRecordIterator.next() + val controlType = ControlRecordType.parse(controlRecord.key) + val producerId = controlBatch.producerId + controlType match { + case ControlRecordType.ABORT => + ongoingAbortedTxns.remove(producerId) match { + // Retain the marker until all batches from the transaction have been removed. + case Some(abortedTxnMetadata) if abortedTxnMetadata.lastObservedBatchOffset.isDefined => + cleanedIndex.foreach(_.append(abortedTxnMetadata.abortedTxn)) + false + case _ => true + } + + case ControlRecordType.COMMIT => + // This marker is eligible for deletion if we didn't traverse any batches from the transaction + !ongoingCommittedTxns.remove(producerId) + + case _ => false + } + } else { + // An empty control batch was already cleaned, so it's safe to discard + true + } + } + + private def consumeAbortedTxnsUpTo(offset: Long): Unit = { + while (abortedTransactions.headOption.exists(_.firstOffset <= offset)) { + val abortedTxn = abortedTransactions.dequeue() + ongoingAbortedTxns.getOrElseUpdate(abortedTxn.producerId, new AbortedTransactionMetadata(abortedTxn)) + } + } + + /** + * Update the transactional state for the incoming non-control batch. If the batch is part of + * an aborted transaction, return true to indicate that it is safe to discard. 
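// ---------------------------------------------------------------------------
// [Review note] Illustrative sketch, not part of the patch: it shows how the
// cleanable ratio defined in LogToClean above orders candidate logs, using
// plain byte counts. The partition names and numbers are made up.
// ---------------------------------------------------------------------------
object CleanableRatioSketch {
  final case class Candidate(name: String, cleanBytes: Long, cleanableBytes: Long) {
    def cleanableRatio: Double = cleanableBytes.toDouble / (cleanBytes + cleanableBytes)
  }

  def main(args: Array[String]): Unit = {
    val candidates = Seq(
      Candidate("orders-0",   cleanBytes = 900, cleanableBytes = 100), // ratio 0.10
      Candidate("payments-1", cleanBytes = 200, cleanableBytes = 800), // ratio 0.80
      Candidate("users-2",    cleanBytes = 500, cleanableBytes = 500)  // ratio 0.50
    )
    val filthiest = candidates.maxBy(_.cleanableRatio)
    println(s"clean ${filthiest.name} first (ratio ${filthiest.cleanableRatio})") // payments-1
  }
}
// ---------------------------------------------------------------------------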
+ */ + def onBatchRead(batch: RecordBatch): Boolean = { + consumeAbortedTxnsUpTo(batch.lastOffset) + if (batch.isTransactional) { + ongoingAbortedTxns.get(batch.producerId) match { + case Some(abortedTransactionMetadata) => + abortedTransactionMetadata.lastObservedBatchOffset = Some(batch.lastOffset) + true + case None => + ongoingCommittedTxns += batch.producerId + false + } + } else { + false + } + } + +} + +private class AbortedTransactionMetadata(val abortedTxn: AbortedTxn) { + var lastObservedBatchOffset: Option[Long] = None + + override def toString: String = s"(txn: $abortedTxn, lastOffset: $lastObservedBatchOffset)" +} diff --git a/core/src/main/scala/kafka/log/LogCleanerManager.scala b/core/src/main/scala/kafka/log/LogCleanerManager.scala new file mode 100755 index 0000000..8b6926b --- /dev/null +++ b/core/src/main/scala/kafka/log/LogCleanerManager.scala @@ -0,0 +1,663 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.log + +import java.io.File +import java.util.concurrent.TimeUnit +import java.util.concurrent.locks.ReentrantLock + +import kafka.common.{KafkaException, LogCleaningAbortedException} +import kafka.metrics.KafkaMetricsGroup +import kafka.server.LogDirFailureChannel +import kafka.server.checkpoints.OffsetCheckpointFile +import kafka.utils.CoreUtils._ +import kafka.utils.{Logging, Pool} +import org.apache.kafka.common.TopicPartition +import org.apache.kafka.common.utils.Time +import org.apache.kafka.common.errors.KafkaStorageException + +import scala.collection.{Iterable, Seq, mutable} + +private[log] sealed trait LogCleaningState +private[log] case object LogCleaningInProgress extends LogCleaningState +private[log] case object LogCleaningAborted extends LogCleaningState +private[log] case class LogCleaningPaused(pausedCount: Int) extends LogCleaningState + +private[log] class LogCleaningException(val log: UnifiedLog, + private val message: String, + private val cause: Throwable) extends KafkaException(message, cause) + +/** + * This class manages the state of each partition being cleaned. + * LogCleaningState defines the cleaning states that a TopicPartition can be in. + * 1. None : No cleaning state in a TopicPartition. In this state, it can become LogCleaningInProgress + * or LogCleaningPaused(1). Valid previous state are LogCleaningInProgress and LogCleaningPaused(1) + * 2. LogCleaningInProgress : The cleaning is currently in progress. In this state, it can become None when log cleaning is finished + * or become LogCleaningAborted. Valid previous state is None. + * 3. LogCleaningAborted : The cleaning abort is requested. In this state, it can become LogCleaningPaused(1). + * Valid previous state is LogCleaningInProgress. + * 4-a. LogCleaningPaused(1) : The cleaning is paused once. 
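// ---------------------------------------------------------------------------
// [Review note] Illustrative sketch, not part of the patch and not the real
// CleanedTransactionMetadata API: it condenses the marker rules above into a
// toy tracker keyed by producerId, ignoring the offset bookkeeping. An ABORT
// marker is kept while data batches from that aborted transaction were
// observed during this pass; a COMMIT marker may be dropped only if no batch
// from the committed transaction was traversed. Names are hypothetical.
// ---------------------------------------------------------------------------
import scala.collection.mutable

object TxnMarkerSketch {
  private val abortedTxnsWithObservedData    = mutable.Set.empty[Long]
  private val committedTxnsWithTraversedData = mutable.Set.empty[Long]

  /** Returns true if the data batch itself may be discarded (it belongs to an aborted txn). */
  def onDataBatch(producerId: Long, isAborted: Boolean): Boolean =
    if (isAborted) { abortedTxnsWithObservedData += producerId; true }
    else { committedTxnsWithTraversedData += producerId; false }

  /** Returns true if the ABORT marker may be discarded. */
  def onAbortMarker(producerId: Long): Boolean = !abortedTxnsWithObservedData.remove(producerId)

  /** Returns true if the COMMIT marker may be discarded. */
  def onCommitMarker(producerId: Long): Boolean = !committedTxnsWithTraversedData.remove(producerId)

  def main(args: Array[String]): Unit = {
    onDataBatch(producerId = 7L, isAborted = true)
    println(onAbortMarker(7L))  // false: keep the marker, its data was seen during this pass
    println(onCommitMarker(9L)) // true:  no surviving data references this committed txn
  }
}
// ---------------------------------------------------------------------------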
No log cleaning can be done in this state. + * In this state, it can become None or LogCleaningPaused(2). + * Valid previous state is None, LogCleaningAborted or LogCleaningPaused(2). + * 4-b. LogCleaningPaused(i) : The cleaning is paused i times where i>= 2. No log cleaning can be done in this state. + * In this state, it can become LogCleaningPaused(i-1) or LogCleaningPaused(i+1). + * Valid previous state is LogCleaningPaused(i-1) or LogCleaningPaused(i+1). + */ +private[log] class LogCleanerManager(val logDirs: Seq[File], + val logs: Pool[TopicPartition, UnifiedLog], + val logDirFailureChannel: LogDirFailureChannel) extends Logging with KafkaMetricsGroup { + import LogCleanerManager._ + + + protected override def loggerName = classOf[LogCleaner].getName + + // package-private for testing + private[log] val offsetCheckpointFile = "cleaner-offset-checkpoint" + + /* the offset checkpoints holding the last cleaned point for each log */ + @volatile private var checkpoints = logDirs.map(dir => + (dir, new OffsetCheckpointFile(new File(dir, offsetCheckpointFile), logDirFailureChannel))).toMap + + /* the set of logs currently being cleaned */ + private val inProgress = mutable.HashMap[TopicPartition, LogCleaningState]() + + /* the set of uncleanable partitions (partitions that have raised an unexpected error during cleaning) + * for each log directory */ + private val uncleanablePartitions = mutable.HashMap[String, mutable.Set[TopicPartition]]() + + /* a global lock used to control all access to the in-progress set and the offset checkpoints */ + private val lock = new ReentrantLock + + /* for coordinating the pausing and the cleaning of a partition */ + private val pausedCleaningCond = lock.newCondition() + + /* gauges for tracking the number of partitions marked as uncleanable for each log directory */ + for (dir <- logDirs) { + newGauge("uncleanable-partitions-count", + () => inLock(lock) { uncleanablePartitions.get(dir.getAbsolutePath).map(_.size).getOrElse(0) }, + Map("logDirectory" -> dir.getAbsolutePath) + ) + } + + /* gauges for tracking the number of uncleanable bytes from uncleanable partitions for each log directory */ + for (dir <- logDirs) { + newGauge("uncleanable-bytes", + () => inLock(lock) { + uncleanablePartitions.get(dir.getAbsolutePath) match { + case Some(partitions) => + val lastClean = allCleanerCheckpoints + val now = Time.SYSTEM.milliseconds + partitions.iterator.map { tp => + Option(logs.get(tp)).map { + log => + val lastCleanOffset = lastClean.get(tp) + val offsetsToClean = cleanableOffsets(log, lastCleanOffset, now) + val (_, uncleanableBytes) = calculateCleanableBytes(log, offsetsToClean.firstDirtyOffset, offsetsToClean.firstUncleanableDirtyOffset) + uncleanableBytes + }.getOrElse(0L) + }.sum + case None => 0 + } + }, + Map("logDirectory" -> dir.getAbsolutePath) + ) + } + + /* a gauge for tracking the cleanable ratio of the dirtiest log */ + @volatile private var dirtiestLogCleanableRatio = 0.0 + newGauge("max-dirty-percent", () => (100 * dirtiestLogCleanableRatio).toInt) + + /* a gauge for tracking the time since the last log cleaner run, in milli seconds */ + @volatile private var timeOfLastRun: Long = Time.SYSTEM.milliseconds + newGauge("time-since-last-run-ms", () => Time.SYSTEM.milliseconds - timeOfLastRun) + + /** + * @return the position processed for all logs. 
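// ---------------------------------------------------------------------------
// [Review note] Illustrative sketch, not part of the patch: it models only the
// pause-count portion of the cleaning state machine described above
// (None <-> LogCleaningPaused(i)), ignoring the in-progress and aborted
// states. Every abort-and-pause adds one pause, every resume removes one.
// Names are hypothetical.
// ---------------------------------------------------------------------------
object PauseCountSketch {
  sealed trait State
  case object NotPaused extends State
  final case class Paused(count: Int) extends State { require(count >= 1) }

  def pause(state: State): State = state match {
    case NotPaused     => Paused(1)
    case Paused(count) => Paused(count + 1)
  }

  def resume(state: State): State = state match {
    case Paused(1)     => NotPaused
    case Paused(count) => Paused(count - 1)
    case NotPaused     => throw new IllegalStateException("cannot resume: not paused")
  }

  def main(args: Array[String]): Unit = {
    val s = pause(pause(NotPaused)) // Paused(2)
    println(resume(resume(s)))      // NotPaused
  }
}
// ---------------------------------------------------------------------------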
+ */ + def allCleanerCheckpoints: Map[TopicPartition, Long] = { + inLock(lock) { + checkpoints.values.flatMap(checkpoint => { + try { + checkpoint.read() + } catch { + case e: KafkaStorageException => + error(s"Failed to access checkpoint file ${checkpoint.file.getName} in dir ${checkpoint.file.getParentFile.getAbsolutePath}", e) + Map.empty[TopicPartition, Long] + } + }).toMap + } + } + + /** + * Package private for unit test. Get the cleaning state of the partition. + */ + private[log] def cleaningState(tp: TopicPartition): Option[LogCleaningState] = { + inLock(lock) { + inProgress.get(tp) + } + } + + /** + * Package private for unit test. Set the cleaning state of the partition. + */ + private[log] def setCleaningState(tp: TopicPartition, state: LogCleaningState): Unit = { + inLock(lock) { + inProgress.put(tp, state) + } + } + + /** + * Choose the log to clean next and add it to the in-progress set. We recompute this + * each time from the full set of logs to allow logs to be dynamically added to the pool of logs + * the log manager maintains. + */ + def grabFilthiestCompactedLog(time: Time, preCleanStats: PreCleanStats = new PreCleanStats()): Option[LogToClean] = { + inLock(lock) { + val now = time.milliseconds + this.timeOfLastRun = now + val lastClean = allCleanerCheckpoints + + val dirtyLogs = logs.filter { + case (_, log) => log.config.compact + }.filterNot { + case (topicPartition, log) => + inProgress.contains(topicPartition) || isUncleanablePartition(log, topicPartition) + }.map { + case (topicPartition, log) => // create a LogToClean instance for each + try { + val lastCleanOffset = lastClean.get(topicPartition) + val offsetsToClean = cleanableOffsets(log, lastCleanOffset, now) + // update checkpoint for logs with invalid checkpointed offsets + if (offsetsToClean.forceUpdateCheckpoint) + updateCheckpoints(log.parentDirFile, partitionToUpdateOrAdd = Option(topicPartition, offsetsToClean.firstDirtyOffset)) + val compactionDelayMs = maxCompactionDelay(log, offsetsToClean.firstDirtyOffset, now) + preCleanStats.updateMaxCompactionDelay(compactionDelayMs) + + LogToClean(topicPartition, log, offsetsToClean.firstDirtyOffset, offsetsToClean.firstUncleanableDirtyOffset, compactionDelayMs > 0) + } catch { + case e: Throwable => throw new LogCleaningException(log, + s"Failed to calculate log cleaning stats for partition $topicPartition", e) + } + }.filter(ltc => ltc.totalBytes > 0) // skip any empty logs + + this.dirtiestLogCleanableRatio = if (dirtyLogs.nonEmpty) dirtyLogs.max.cleanableRatio else 0 + // and must meet the minimum threshold for dirty byte ratio or have some bytes required to be compacted + val cleanableLogs = dirtyLogs.filter { ltc => + (ltc.needCompactionNow && ltc.cleanableBytes > 0) || ltc.cleanableRatio > ltc.log.config.minCleanableRatio + } + + if (cleanableLogs.isEmpty) + None + else { + preCleanStats.recordCleanablePartitions(cleanableLogs.size) + val filthiest = cleanableLogs.max + inProgress.put(filthiest.topicPartition, LogCleaningInProgress) + Some(filthiest) + } + } + } + + /** + * Pause logs cleaning for logs that do not have compaction enabled + * and do not have other deletion or compaction in progress. + * This is to handle potential race between retention and cleaner threads when users + * switch topic configuration between compacted and non-compacted topic. 
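// ---------------------------------------------------------------------------
// [Review note] Illustrative sketch, not part of the patch: it condenses the
// final filter of grabFilthiestCompactedLog above into one predicate. A log is
// picked when compaction is overdue (max.compaction.lag.ms exceeded and there
// is something to clean) or when its dirty ratio crosses
// min.cleanable.dirty.ratio. The inputs below are made up.
// ---------------------------------------------------------------------------
object CleanableSelectionSketch {
  def isCleanable(needCompactionNow: Boolean,
                  cleanableBytes: Long,
                  cleanableRatio: Double,
                  minCleanableRatio: Double): Boolean =
    (needCompactionNow && cleanableBytes > 0) || cleanableRatio > minCleanableRatio

  def main(args: Array[String]): Unit = {
    println(isCleanable(needCompactionNow = false, cleanableBytes = 10,  cleanableRatio = 0.6, minCleanableRatio = 0.5)) // true
    println(isCleanable(needCompactionNow = true,  cleanableBytes = 10,  cleanableRatio = 0.1, minCleanableRatio = 0.5)) // true
    println(isCleanable(needCompactionNow = false, cleanableBytes = 100, cleanableRatio = 0.4, minCleanableRatio = 0.5)) // false
  }
}
// ---------------------------------------------------------------------------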
+ * @return retention logs that have log cleaning successfully paused + */ + def pauseCleaningForNonCompactedPartitions(): Iterable[(TopicPartition, UnifiedLog)] = { + inLock(lock) { + val deletableLogs = logs.filter { + case (_, log) => !log.config.compact // pick non-compacted logs + }.filterNot { + case (topicPartition, _) => inProgress.contains(topicPartition) // skip any logs already in-progress + } + + deletableLogs.foreach { + case (topicPartition, _) => inProgress.put(topicPartition, LogCleaningPaused(1)) + } + deletableLogs + } + } + + /** + * Find any logs that have compaction enabled. Mark them as being cleaned + * Include logs without delete enabled, as they may have segments + * that precede the start offset. + */ + def deletableLogs(): Iterable[(TopicPartition, UnifiedLog)] = { + inLock(lock) { + val toClean = logs.filter { case (topicPartition, log) => + !inProgress.contains(topicPartition) && log.config.compact && + !isUncleanablePartition(log, topicPartition) + } + toClean.foreach { case (tp, _) => inProgress.put(tp, LogCleaningInProgress) } + toClean + } + + } + + /** + * Abort the cleaning of a particular partition, if it's in progress. This call blocks until the cleaning of + * the partition is aborted. + * This is implemented by first abortAndPausing and then resuming the cleaning of the partition. + */ + def abortCleaning(topicPartition: TopicPartition): Unit = { + inLock(lock) { + abortAndPauseCleaning(topicPartition) + resumeCleaning(Seq(topicPartition)) + } + } + + /** + * Abort the cleaning of a particular partition if it's in progress, and pause any future cleaning of this partition. + * This call blocks until the cleaning of the partition is aborted and paused. + * 1. If the partition is not in progress, mark it as paused. + * 2. Otherwise, first mark the state of the partition as aborted. + * 3. The cleaner thread checks the state periodically and if it sees the state of the partition is aborted, it + * throws a LogCleaningAbortedException to stop the cleaning task. + * 4. When the cleaning task is stopped, doneCleaning() is called, which sets the state of the partition as paused. + * 5. abortAndPauseCleaning() waits until the state of the partition is changed to paused. + * 6. If the partition is already paused, a new call to this function + * will increase the paused count by one. + */ + def abortAndPauseCleaning(topicPartition: TopicPartition): Unit = { + inLock(lock) { + inProgress.get(topicPartition) match { + case None => + inProgress.put(topicPartition, LogCleaningPaused(1)) + case Some(LogCleaningInProgress) => + inProgress.put(topicPartition, LogCleaningAborted) + case Some(LogCleaningPaused(count)) => + inProgress.put(topicPartition, LogCleaningPaused(count + 1)) + case Some(s) => + throw new IllegalStateException(s"Compaction for partition $topicPartition cannot be aborted and paused since it is in $s state.") + } + while(!isCleaningInStatePaused(topicPartition)) + pausedCleaningCond.await(100, TimeUnit.MILLISECONDS) + } + } + + /** + * Resume the cleaning of paused partitions. + * Each call of this function will undo one pause. 
+ */ + def resumeCleaning(topicPartitions: Iterable[TopicPartition]): Unit = { + inLock(lock) { + topicPartitions.foreach { + topicPartition => + inProgress.get(topicPartition) match { + case None => + throw new IllegalStateException(s"Compaction for partition $topicPartition cannot be resumed since it is not paused.") + case Some(state) => + state match { + case LogCleaningPaused(count) if count == 1 => + inProgress.remove(topicPartition) + case LogCleaningPaused(count) if count > 1 => + inProgress.put(topicPartition, LogCleaningPaused(count - 1)) + case s => + throw new IllegalStateException(s"Compaction for partition $topicPartition cannot be resumed since it is in $s state.") + } + } + } + } + } + + /** + * Check if the cleaning for a partition is in a particular state. The caller is expected to hold lock while making the call. + */ + private def isCleaningInState(topicPartition: TopicPartition, expectedState: LogCleaningState): Boolean = { + inProgress.get(topicPartition) match { + case None => false + case Some(state) => + if (state == expectedState) + true + else + false + } + } + + /** + * Check if the cleaning for a partition is paused. The caller is expected to hold lock while making the call. + */ + private def isCleaningInStatePaused(topicPartition: TopicPartition): Boolean = { + inProgress.get(topicPartition) match { + case None => false + case Some(state) => + state match { + case _: LogCleaningPaused => + true + case _ => + false + } + } + } + + /** + * Check if the cleaning for a partition is aborted. If so, throw an exception. + */ + def checkCleaningAborted(topicPartition: TopicPartition): Unit = { + inLock(lock) { + if (isCleaningInState(topicPartition, LogCleaningAborted)) + throw new LogCleaningAbortedException() + } + } + + /** + * Update checkpoint file, adding or removing partitions if necessary. + * + * @param dataDir The File object to be updated + * @param partitionToUpdateOrAdd The [TopicPartition, Long] map data to be updated. 
pass "none" if doing remove, not add + * @param partitionToRemove The TopicPartition to be removed + */ + def updateCheckpoints(dataDir: File, + partitionToUpdateOrAdd: Option[(TopicPartition, Long)] = None, + partitionToRemove: Option[TopicPartition] = None): Unit = { + inLock(lock) { + val checkpoint = checkpoints(dataDir) + if (checkpoint != null) { + try { + val currentCheckpoint = checkpoint.read().filter { case (tp, _) => logs.keys.contains(tp) }.toMap + // remove the partition offset if any + var updatedCheckpoint = partitionToRemove match { + case Some(topicPartition) => currentCheckpoint - topicPartition + case None => currentCheckpoint + } + // update or add the partition offset if any + updatedCheckpoint = partitionToUpdateOrAdd match { + case Some(updatedOffset) => updatedCheckpoint + updatedOffset + case None => updatedCheckpoint + } + + checkpoint.write(updatedCheckpoint) + } catch { + case e: KafkaStorageException => + error(s"Failed to access checkpoint file ${checkpoint.file.getName} in dir ${checkpoint.file.getParentFile.getAbsolutePath}", e) + } + } + } + } + + /** + * alter the checkpoint directory for the topicPartition, to remove the data in sourceLogDir, and add the data in destLogDir + */ + def alterCheckpointDir(topicPartition: TopicPartition, sourceLogDir: File, destLogDir: File): Unit = { + inLock(lock) { + try { + checkpoints.get(sourceLogDir).flatMap(_.read().get(topicPartition)) match { + case Some(offset) => + debug(s"Removing the partition offset data in checkpoint file for '${topicPartition}' " + + s"from ${sourceLogDir.getAbsoluteFile} directory.") + updateCheckpoints(sourceLogDir, partitionToRemove = Option(topicPartition)) + + debug(s"Adding the partition offset data in checkpoint file for '${topicPartition}' " + + s"to ${destLogDir.getAbsoluteFile} directory.") + updateCheckpoints(destLogDir, partitionToUpdateOrAdd = Option(topicPartition, offset)) + case None => + } + } catch { + case e: KafkaStorageException => + error(s"Failed to access checkpoint file in dir ${sourceLogDir.getAbsolutePath}", e) + } + + val logUncleanablePartitions = uncleanablePartitions.getOrElse(sourceLogDir.toString, mutable.Set[TopicPartition]()) + if (logUncleanablePartitions.contains(topicPartition)) { + logUncleanablePartitions.remove(topicPartition) + markPartitionUncleanable(destLogDir.toString, topicPartition) + } + } + } + + /** + * Stop cleaning logs in the provided directory + * + * @param dir the absolute path of the log dir + */ + def handleLogDirFailure(dir: String): Unit = { + warn(s"Stopping cleaning logs in dir $dir") + inLock(lock) { + checkpoints = checkpoints.filter { case (k, _) => k.getAbsolutePath != dir } + } + } + + /** + * Truncate the checkpointed offset for the given partition if its checkpointed offset is larger than the given offset + */ + def maybeTruncateCheckpoint(dataDir: File, topicPartition: TopicPartition, offset: Long): Unit = { + inLock(lock) { + if (logs.get(topicPartition).config.compact) { + val checkpoint = checkpoints(dataDir) + if (checkpoint != null) { + val existing = checkpoint.read() + if (existing.getOrElse(topicPartition, 0L) > offset) + checkpoint.write(mutable.Map() ++= existing += topicPartition -> offset) + } + } + } + } + + /** + * Save out the endOffset and remove the given log from the in-progress set, if not aborted. 
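// ---------------------------------------------------------------------------
// [Review note] Illustrative sketch, not part of the patch: the bookkeeping in
// updateCheckpoints above amounts to removing and/or upserting entries in a
// partition -> offset map before the checkpoint file is rewritten. Strings
// stand in for TopicPartition; all names and values are hypothetical.
// ---------------------------------------------------------------------------
object CheckpointUpdateSketch {
  def updated(current: Map[String, Long],
              toRemove: Option[String],
              toUpsert: Option[(String, Long)]): Map[String, Long] = {
    val afterRemove = toRemove.fold(current)(current - _)
    toUpsert.fold(afterRemove)(afterRemove + _)
  }

  def main(args: Array[String]): Unit = {
    val checkpoint = Map("topic-0" -> 100L, "topic-1" -> 40L)
    println(updated(checkpoint, toRemove = Some("topic-1"), toUpsert = Some("topic-2" -> 0L)))
    // Map(topic-0 -> 100, topic-2 -> 0)
  }
}
// ---------------------------------------------------------------------------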
+ */ + def doneCleaning(topicPartition: TopicPartition, dataDir: File, endOffset: Long): Unit = { + inLock(lock) { + inProgress.get(topicPartition) match { + case Some(LogCleaningInProgress) => + updateCheckpoints(dataDir, partitionToUpdateOrAdd = Option(topicPartition, endOffset)) + inProgress.remove(topicPartition) + case Some(LogCleaningAborted) => + inProgress.put(topicPartition, LogCleaningPaused(1)) + pausedCleaningCond.signalAll() + case None => + throw new IllegalStateException(s"State for partition $topicPartition should exist.") + case s => + throw new IllegalStateException(s"In-progress partition $topicPartition cannot be in $s state.") + } + } + } + + def doneDeleting(topicPartitions: Iterable[TopicPartition]): Unit = { + inLock(lock) { + topicPartitions.foreach { + topicPartition => + inProgress.get(topicPartition) match { + case Some(LogCleaningInProgress) => + inProgress.remove(topicPartition) + case Some(LogCleaningAborted) => + inProgress.put(topicPartition, LogCleaningPaused(1)) + pausedCleaningCond.signalAll() + case None => + throw new IllegalStateException(s"State for partition $topicPartition should exist.") + case s => + throw new IllegalStateException(s"In-progress partition $topicPartition cannot be in $s state.") + } + } + } + } + + /** + * Returns an immutable set of the uncleanable partitions for a given log directory + * Only used for testing + */ + private[log] def uncleanablePartitions(logDir: String): Set[TopicPartition] = { + var partitions: Set[TopicPartition] = Set() + inLock(lock) { partitions ++= uncleanablePartitions.getOrElse(logDir, partitions) } + partitions + } + + def markPartitionUncleanable(logDir: String, partition: TopicPartition): Unit = { + inLock(lock) { + uncleanablePartitions.get(logDir) match { + case Some(partitions) => + partitions.add(partition) + case None => + uncleanablePartitions.put(logDir, mutable.Set(partition)) + } + } + } + + private def isUncleanablePartition(log: UnifiedLog, topicPartition: TopicPartition): Boolean = { + inLock(lock) { + uncleanablePartitions.get(log.parentDir).exists(partitions => partitions.contains(topicPartition)) + } + } + + def maintainUncleanablePartitions(): Unit = { + // Remove deleted partitions from uncleanablePartitions + inLock(lock) { + // Note: we don't use retain or filterInPlace method in this function because retain is deprecated in + // scala 2.13 while filterInPlace is not available in scala 2.12. + + // Remove deleted partitions + uncleanablePartitions.values.foreach { + partitions => + val partitionsToRemove = partitions.filterNot(logs.contains(_)).toList + partitionsToRemove.foreach { partitions.remove(_) } + } + + // Remove entries with empty partition set. + val logDirsToRemove = uncleanablePartitions.filter { + case (_, partitions) => partitions.isEmpty + }.map { _._1}.toList + logDirsToRemove.foreach { uncleanablePartitions.remove(_) } + } + } +} + +/** + * Helper class for the range of cleanable dirty offsets of a log and whether to update the checkpoint associated with + * the log + * + * @param firstDirtyOffset the lower (inclusive) offset to begin cleaning from + * @param firstUncleanableDirtyOffset the upper(exclusive) offset to clean to + * @param forceUpdateCheckpoint whether to update the checkpoint associated with this log. 
if true, checkpoint should be + * reset to firstDirtyOffset + */ +private case class OffsetsToClean(firstDirtyOffset: Long, + firstUncleanableDirtyOffset: Long, + forceUpdateCheckpoint: Boolean = false) { +} + +private[log] object LogCleanerManager extends Logging { + + def isCompactAndDelete(log: UnifiedLog): Boolean = { + log.config.compact && log.config.delete + } + + /** + * get max delay between the time when log is required to be compacted as determined + * by maxCompactionLagMs and the current time. + */ + def maxCompactionDelay(log: UnifiedLog, firstDirtyOffset: Long, now: Long) : Long = { + val dirtyNonActiveSegments = log.nonActiveLogSegmentsFrom(firstDirtyOffset) + val firstBatchTimestamps = log.getFirstBatchTimestampForSegments(dirtyNonActiveSegments).filter(_ > 0) + + val earliestDirtySegmentTimestamp = { + if (firstBatchTimestamps.nonEmpty) + firstBatchTimestamps.min + else Long.MaxValue + } + + val maxCompactionLagMs = math.max(log.config.maxCompactionLagMs, 0L) + val cleanUntilTime = now - maxCompactionLagMs + + if (earliestDirtySegmentTimestamp < cleanUntilTime) + cleanUntilTime - earliestDirtySegmentTimestamp + else + 0L + } + + /** + * Returns the range of dirty offsets that can be cleaned. + * + * @param log the log + * @param lastCleanOffset the last checkpointed offset + * @param now the current time in milliseconds of the cleaning operation + * @return OffsetsToClean containing offsets for cleanable portion of log and whether the log checkpoint needs updating + */ + def cleanableOffsets(log: UnifiedLog, lastCleanOffset: Option[Long], now: Long): OffsetsToClean = { + // If the log segments are abnormally truncated and hence the checkpointed offset is no longer valid; + // reset to the log starting offset and log the error + val (firstDirtyOffset, forceUpdateCheckpoint) = { + val logStartOffset = log.logStartOffset + val checkpointDirtyOffset = lastCleanOffset.getOrElse(logStartOffset) + + if (checkpointDirtyOffset < logStartOffset) { + // Don't bother with the warning if compact and delete are enabled. + if (!isCompactAndDelete(log)) + warn(s"Resetting first dirty offset of ${log.name} to log start offset $logStartOffset " + + s"since the checkpointed offset $checkpointDirtyOffset is invalid.") + (logStartOffset, true) + } else if (checkpointDirtyOffset > log.logEndOffset) { + // The dirty offset has gotten ahead of the log end offset. This could happen if there was data + // corruption at the end of the log. We conservatively assume that the full log needs cleaning. + warn(s"The last checkpoint dirty offset for partition ${log.name} is $checkpointDirtyOffset, " + + s"which is larger than the log end offset ${log.logEndOffset}. Resetting to the log start offset $logStartOffset.") + (logStartOffset, true) + } else { + (checkpointDirtyOffset, false) + } + } + + val minCompactionLagMs = math.max(log.config.compactionLagMs, 0L) + + // Find the first segment that cannot be cleaned. We cannot clean past: + // 1. The active segment + // 2. The last stable offset (including the high watermark) + // 3. 
Any segments closer to the head of the log than the minimum compaction lag time + val firstUncleanableDirtyOffset: Long = Seq( + + // we do not clean beyond the last stable offset + Some(log.lastStableOffset), + + // the active segment is always uncleanable + Option(log.activeSegment.baseOffset), + + // the first segment whose largest message timestamp is within a minimum time lag from now + if (minCompactionLagMs > 0) { + // dirty log segments + val dirtyNonActiveSegments = log.nonActiveLogSegmentsFrom(firstDirtyOffset) + dirtyNonActiveSegments.find { s => + val isUncleanable = s.largestTimestamp > now - minCompactionLagMs + debug(s"Checking if log segment may be cleaned: log='${log.name}' segment.baseOffset=${s.baseOffset} " + + s"segment.largestTimestamp=${s.largestTimestamp}; now - compactionLag=${now - minCompactionLagMs}; " + + s"is uncleanable=$isUncleanable") + isUncleanable + }.map(_.baseOffset) + } else None + ).flatten.min + + debug(s"Finding range of cleanable offsets for log=${log.name}. Last clean offset=$lastCleanOffset " + + s"now=$now => firstDirtyOffset=$firstDirtyOffset firstUncleanableOffset=$firstUncleanableDirtyOffset " + + s"activeSegment.baseOffset=${log.activeSegment.baseOffset}") + + OffsetsToClean(firstDirtyOffset, math.max(firstDirtyOffset, firstUncleanableDirtyOffset), forceUpdateCheckpoint) + } + + /** + * Given the first dirty offset and an uncleanable offset, calculates the total cleanable bytes for this log + * @return the biggest uncleanable offset and the total amount of cleanable bytes + */ + def calculateCleanableBytes(log: UnifiedLog, firstDirtyOffset: Long, uncleanableOffset: Long): (Long, Long) = { + val firstUncleanableSegment = log.nonActiveLogSegmentsFrom(uncleanableOffset).headOption.getOrElse(log.activeSegment) + val firstUncleanableOffset = firstUncleanableSegment.baseOffset + val cleanableBytes = log.logSegments(math.min(firstDirtyOffset, firstUncleanableOffset), firstUncleanableOffset).map(_.size.toLong).sum + + (firstUncleanableOffset, cleanableBytes) + } + +} diff --git a/core/src/main/scala/kafka/log/LogConfig.scala b/core/src/main/scala/kafka/log/LogConfig.scala new file mode 100755 index 0000000..845f80a --- /dev/null +++ b/core/src/main/scala/kafka/log/LogConfig.scala @@ -0,0 +1,538 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package kafka.log + +import kafka.api.{ApiVersion, ApiVersionValidator, KAFKA_3_0_IV1} +import kafka.log.LogConfig.configDef +import kafka.message.BrokerCompressionCodec +import kafka.server.{KafkaConfig, ThrottledReplicaListValidator} +import kafka.utils.Implicits._ +import org.apache.kafka.common.config.ConfigDef.{ConfigKey, ValidList, Validator} +import org.apache.kafka.common.config.{AbstractConfig, ConfigDef, ConfigException, TopicConfig} +import org.apache.kafka.common.errors.InvalidConfigurationException +import org.apache.kafka.common.record.{LegacyRecord, RecordVersion, TimestampType} +import org.apache.kafka.common.utils.{ConfigUtils, Utils} + +import java.util.{Collections, Locale, Properties} +import scala.annotation.nowarn +import scala.collection.{Map, mutable} +import scala.jdk.CollectionConverters._ + +object Defaults { + val SegmentSize = kafka.server.Defaults.LogSegmentBytes + val SegmentMs = kafka.server.Defaults.LogRollHours * 60 * 60 * 1000L + val SegmentJitterMs = kafka.server.Defaults.LogRollJitterHours * 60 * 60 * 1000L + val FlushInterval = kafka.server.Defaults.LogFlushIntervalMessages + val FlushMs = kafka.server.Defaults.LogFlushSchedulerIntervalMs + val RetentionSize = kafka.server.Defaults.LogRetentionBytes + val RetentionMs = kafka.server.Defaults.LogRetentionHours * 60 * 60 * 1000L + val RemoteLogStorageEnable = false + val LocalRetentionBytes = -2 // It indicates the value to be derived from RetentionSize + val LocalRetentionMs = -2 // It indicates the value to be derived from RetentionMs + val MaxMessageSize = kafka.server.Defaults.MessageMaxBytes + val MaxIndexSize = kafka.server.Defaults.LogIndexSizeMaxBytes + val IndexInterval = kafka.server.Defaults.LogIndexIntervalBytes + val FileDeleteDelayMs = kafka.server.Defaults.LogDeleteDelayMs + val DeleteRetentionMs = kafka.server.Defaults.LogCleanerDeleteRetentionMs + val MinCompactionLagMs = kafka.server.Defaults.LogCleanerMinCompactionLagMs + val MaxCompactionLagMs = kafka.server.Defaults.LogCleanerMaxCompactionLagMs + val MinCleanableDirtyRatio = kafka.server.Defaults.LogCleanerMinCleanRatio + val CleanupPolicy = kafka.server.Defaults.LogCleanupPolicy + val UncleanLeaderElectionEnable = kafka.server.Defaults.UncleanLeaderElectionEnable + val MinInSyncReplicas = kafka.server.Defaults.MinInSyncReplicas + val CompressionType = kafka.server.Defaults.CompressionType + val PreAllocateEnable = kafka.server.Defaults.LogPreAllocateEnable + + /* See `TopicConfig.MESSAGE_FORMAT_VERSION_CONFIG` for details */ + @deprecated("3.0") + val MessageFormatVersion = kafka.server.Defaults.LogMessageFormatVersion + + val MessageTimestampType = kafka.server.Defaults.LogMessageTimestampType + val MessageTimestampDifferenceMaxMs = kafka.server.Defaults.LogMessageTimestampDifferenceMaxMs + val LeaderReplicationThrottledReplicas = Collections.emptyList[String]() + val FollowerReplicationThrottledReplicas = Collections.emptyList[String]() + val MaxIdMapSnapshots = kafka.server.Defaults.MaxIdMapSnapshots + val MessageDownConversionEnable = kafka.server.Defaults.MessageDownConversionEnable +} + +case class LogConfig(props: java.util.Map[_, _], overriddenConfigs: Set[String] = Set.empty) + extends AbstractConfig(LogConfig.configDef, props, false) { + /** + * Important note: Any configuration parameter that is passed along from KafkaConfig to LogConfig + * should also go in [[LogConfig.extractLogConfigMap()]]. 
+ */ + val segmentSize = getInt(LogConfig.SegmentBytesProp) + val segmentMs = getLong(LogConfig.SegmentMsProp) + val segmentJitterMs = getLong(LogConfig.SegmentJitterMsProp) + val maxIndexSize = getInt(LogConfig.SegmentIndexBytesProp) + val flushInterval = getLong(LogConfig.FlushMessagesProp) + val flushMs = getLong(LogConfig.FlushMsProp) + val retentionSize = getLong(LogConfig.RetentionBytesProp) + val retentionMs = getLong(LogConfig.RetentionMsProp) + val maxMessageSize = getInt(LogConfig.MaxMessageBytesProp) + val indexInterval = getInt(LogConfig.IndexIntervalBytesProp) + val fileDeleteDelayMs = getLong(LogConfig.FileDeleteDelayMsProp) + val deleteRetentionMs = getLong(LogConfig.DeleteRetentionMsProp) + val compactionLagMs = getLong(LogConfig.MinCompactionLagMsProp) + val maxCompactionLagMs = getLong(LogConfig.MaxCompactionLagMsProp) + val minCleanableRatio = getDouble(LogConfig.MinCleanableDirtyRatioProp) + val compact = getList(LogConfig.CleanupPolicyProp).asScala.map(_.toLowerCase(Locale.ROOT)).contains(LogConfig.Compact) + val delete = getList(LogConfig.CleanupPolicyProp).asScala.map(_.toLowerCase(Locale.ROOT)).contains(LogConfig.Delete) + val uncleanLeaderElectionEnable = getBoolean(LogConfig.UncleanLeaderElectionEnableProp) + val minInSyncReplicas = getInt(LogConfig.MinInSyncReplicasProp) + val compressionType = getString(LogConfig.CompressionTypeProp).toLowerCase(Locale.ROOT) + val preallocate = getBoolean(LogConfig.PreAllocateEnableProp) + + /* See `TopicConfig.MESSAGE_FORMAT_VERSION_CONFIG` for details */ + @deprecated("3.0") + val messageFormatVersion = ApiVersion(getString(LogConfig.MessageFormatVersionProp)) + + val messageTimestampType = TimestampType.forName(getString(LogConfig.MessageTimestampTypeProp)) + val messageTimestampDifferenceMaxMs = getLong(LogConfig.MessageTimestampDifferenceMaxMsProp).longValue + val LeaderReplicationThrottledReplicas = getList(LogConfig.LeaderReplicationThrottledReplicasProp) + val FollowerReplicationThrottledReplicas = getList(LogConfig.FollowerReplicationThrottledReplicasProp) + val messageDownConversionEnable = getBoolean(LogConfig.MessageDownConversionEnableProp) + + class RemoteLogConfig { + val remoteStorageEnable = getBoolean(LogConfig.RemoteLogStorageEnableProp) + + val localRetentionMs: Long = { + val localLogRetentionMs = getLong(LogConfig.LocalLogRetentionMsProp) + + // -2 indicates to derive value from retentionMs property. + if(localLogRetentionMs == -2) retentionMs + else { + // Added validation here to check the effective value should not be more than RetentionMs. + if(localLogRetentionMs == -1 && retentionMs != -1) { + throw new ConfigException(LogConfig.LocalLogRetentionMsProp, localLogRetentionMs, s"Value must not be -1 as ${LogConfig.RetentionMsProp} value is set as $retentionMs.") + } + + if (localLogRetentionMs > retentionMs) { + throw new ConfigException(LogConfig.LocalLogRetentionMsProp, localLogRetentionMs, s"Value must not be more than property: ${LogConfig.RetentionMsProp} value.") + } + + localLogRetentionMs + } + } + + val localRetentionBytes: Long = { + val localLogRetentionBytes = getLong(LogConfig.LocalLogRetentionBytesProp) + + // -2 indicates to derive value from retentionSize property. + if(localLogRetentionBytes == -2) retentionSize + else { + // Added validation here to check the effective value should not be more than RetentionBytes. 
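+ // For example (illustrative numbers), with retention.bytes=1073741824: local.retention.bytes=-2 falls back to
+ // the full 1 GiB, local.retention.bytes=-1 ("unbounded") is rejected because retention.bytes is bounded,
+ // and any value larger than 1073741824 is rejected as exceeding retention.bytes.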
+ if(localLogRetentionBytes == -1 && retentionSize != -1) { + throw new ConfigException(LogConfig.LocalLogRetentionBytesProp, localLogRetentionBytes, s"Value must not be -1 as ${LogConfig.RetentionBytesProp} value is set as $retentionSize.") + } + + if (localLogRetentionBytes > retentionSize) { + throw new ConfigException(LogConfig.LocalLogRetentionBytesProp, localLogRetentionBytes, s"Value must not be more than property: ${LogConfig.RetentionBytesProp} value."); + } + + localLogRetentionBytes + } + } + } + + private val _remoteLogConfig = new RemoteLogConfig() + def remoteLogConfig = _remoteLogConfig + + @nowarn("cat=deprecation") + def recordVersion = messageFormatVersion.recordVersion + + def randomSegmentJitter: Long = + if (segmentJitterMs == 0) 0 else Utils.abs(scala.util.Random.nextInt()) % math.min(segmentJitterMs, segmentMs) + + def maxSegmentMs: Long = { + if (compact && maxCompactionLagMs > 0) math.min(maxCompactionLagMs, segmentMs) + else segmentMs + } + + def initFileSize: Int = { + if (preallocate) + segmentSize + else + 0 + } + + def overriddenConfigsAsLoggableString: String = { + val overriddenTopicProps = props.asScala.collect { + case (k: String, v) if overriddenConfigs.contains(k) => (k, v.asInstanceOf[AnyRef]) + } + ConfigUtils.configMapToRedactedString(overriddenTopicProps.asJava, configDef) + } +} + +object LogConfig { + + def main(args: Array[String]): Unit = { + println(configDef.toHtml(4, (config: String) => "topicconfigs_" + config)) + } + + val SegmentBytesProp = TopicConfig.SEGMENT_BYTES_CONFIG + val SegmentMsProp = TopicConfig.SEGMENT_MS_CONFIG + val SegmentJitterMsProp = TopicConfig.SEGMENT_JITTER_MS_CONFIG + val SegmentIndexBytesProp = TopicConfig.SEGMENT_INDEX_BYTES_CONFIG + val FlushMessagesProp = TopicConfig.FLUSH_MESSAGES_INTERVAL_CONFIG + val FlushMsProp = TopicConfig.FLUSH_MS_CONFIG + val RetentionBytesProp = TopicConfig.RETENTION_BYTES_CONFIG + val RetentionMsProp = TopicConfig.RETENTION_MS_CONFIG + val RemoteLogStorageEnableProp = TopicConfig.REMOTE_LOG_STORAGE_ENABLE_CONFIG + val LocalLogRetentionMsProp = TopicConfig.LOCAL_LOG_RETENTION_MS_CONFIG + val LocalLogRetentionBytesProp = TopicConfig.LOCAL_LOG_RETENTION_BYTES_CONFIG + val MaxMessageBytesProp = TopicConfig.MAX_MESSAGE_BYTES_CONFIG + val IndexIntervalBytesProp = TopicConfig.INDEX_INTERVAL_BYTES_CONFIG + val DeleteRetentionMsProp = TopicConfig.DELETE_RETENTION_MS_CONFIG + val MinCompactionLagMsProp = TopicConfig.MIN_COMPACTION_LAG_MS_CONFIG + val MaxCompactionLagMsProp = TopicConfig.MAX_COMPACTION_LAG_MS_CONFIG + val FileDeleteDelayMsProp = TopicConfig.FILE_DELETE_DELAY_MS_CONFIG + val MinCleanableDirtyRatioProp = TopicConfig.MIN_CLEANABLE_DIRTY_RATIO_CONFIG + val CleanupPolicyProp = TopicConfig.CLEANUP_POLICY_CONFIG + val Delete = TopicConfig.CLEANUP_POLICY_DELETE + val Compact = TopicConfig.CLEANUP_POLICY_COMPACT + val UncleanLeaderElectionEnableProp = TopicConfig.UNCLEAN_LEADER_ELECTION_ENABLE_CONFIG + val MinInSyncReplicasProp = TopicConfig.MIN_IN_SYNC_REPLICAS_CONFIG + val CompressionTypeProp = TopicConfig.COMPRESSION_TYPE_CONFIG + val PreAllocateEnableProp = TopicConfig.PREALLOCATE_CONFIG + + /* See `TopicConfig.MESSAGE_FORMAT_VERSION_CONFIG` for details */ + @deprecated("3.0") @nowarn("cat=deprecation") + val MessageFormatVersionProp = TopicConfig.MESSAGE_FORMAT_VERSION_CONFIG + val MessageTimestampTypeProp = TopicConfig.MESSAGE_TIMESTAMP_TYPE_CONFIG + val MessageTimestampDifferenceMaxMsProp = TopicConfig.MESSAGE_TIMESTAMP_DIFFERENCE_MAX_MS_CONFIG + val MessageDownConversionEnableProp = 
TopicConfig.MESSAGE_DOWNCONVERSION_ENABLE_CONFIG + + // Leave these out of TopicConfig for now as they are replication quota configs + val LeaderReplicationThrottledReplicasProp = "leader.replication.throttled.replicas" + val FollowerReplicationThrottledReplicasProp = "follower.replication.throttled.replicas" + + val SegmentSizeDoc = TopicConfig.SEGMENT_BYTES_DOC + val SegmentMsDoc = TopicConfig.SEGMENT_MS_DOC + val SegmentJitterMsDoc = TopicConfig.SEGMENT_JITTER_MS_DOC + val MaxIndexSizeDoc = TopicConfig.SEGMENT_INDEX_BYTES_DOC + val FlushIntervalDoc = TopicConfig.FLUSH_MESSAGES_INTERVAL_DOC + val FlushMsDoc = TopicConfig.FLUSH_MS_DOC + val RetentionSizeDoc = TopicConfig.RETENTION_BYTES_DOC + val RetentionMsDoc = TopicConfig.RETENTION_MS_DOC + val RemoteLogStorageEnableDoc = TopicConfig.REMOTE_LOG_STORAGE_ENABLE_DOC + val LocalLogRetentionMsDoc = TopicConfig.LOCAL_LOG_RETENTION_MS_DOC + val LocalLogRetentionBytesDoc = TopicConfig.LOCAL_LOG_RETENTION_BYTES_DOC + val MaxMessageSizeDoc = TopicConfig.MAX_MESSAGE_BYTES_DOC + val IndexIntervalDoc = TopicConfig.INDEX_INTERVAL_BYTES_DOCS + val FileDeleteDelayMsDoc = TopicConfig.FILE_DELETE_DELAY_MS_DOC + val DeleteRetentionMsDoc = TopicConfig.DELETE_RETENTION_MS_DOC + val MinCompactionLagMsDoc = TopicConfig.MIN_COMPACTION_LAG_MS_DOC + val MaxCompactionLagMsDoc = TopicConfig.MAX_COMPACTION_LAG_MS_DOC + val MinCleanableRatioDoc = TopicConfig.MIN_CLEANABLE_DIRTY_RATIO_DOC + val CompactDoc = TopicConfig.CLEANUP_POLICY_DOC + val UncleanLeaderElectionEnableDoc = TopicConfig.UNCLEAN_LEADER_ELECTION_ENABLE_DOC + val MinInSyncReplicasDoc = TopicConfig.MIN_IN_SYNC_REPLICAS_DOC + val CompressionTypeDoc = TopicConfig.COMPRESSION_TYPE_DOC + val PreAllocateEnableDoc = TopicConfig.PREALLOCATE_DOC + + /* See `TopicConfig.MESSAGE_FORMAT_VERSION_CONFIG` for details */ + @deprecated("3.0") @nowarn("cat=deprecation") + val MessageFormatVersionDoc = TopicConfig.MESSAGE_FORMAT_VERSION_DOC + + val MessageTimestampTypeDoc = TopicConfig.MESSAGE_TIMESTAMP_TYPE_DOC + val MessageTimestampDifferenceMaxMsDoc = TopicConfig.MESSAGE_TIMESTAMP_DIFFERENCE_MAX_MS_DOC + val MessageDownConversionEnableDoc = TopicConfig.MESSAGE_DOWNCONVERSION_ENABLE_DOC + + val LeaderReplicationThrottledReplicasDoc = "A list of replicas for which log replication should be throttled on " + + "the leader side. The list should describe a set of replicas in the form " + + "[PartitionId]:[BrokerId],[PartitionId]:[BrokerId]:... or alternatively the wildcard '*' can be used to throttle " + + "all replicas for this topic." + val FollowerReplicationThrottledReplicasDoc = "A list of replicas for which log replication should be throttled on " + + "the follower side. The list should describe a set of " + "replicas in the form " + + "[PartitionId]:[BrokerId],[PartitionId]:[BrokerId]:... or alternatively the wildcard '*' can be used to throttle " + + "all replicas for this topic." 
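// Illustrative sketch only (hypothetical object and helper, not part of this file; Kafka's real
// validation is done by ThrottledReplicaListValidator): the throttled-replicas value described
// above is either the wildcard "*" or a comma-separated list of partitionId:brokerId pairs.
object ThrottledReplicasFormatSketch {
  def looksValid(value: String): Boolean =
    value.trim == "*" || value.split(",").forall { entry =>
      entry.split(":") match {
        case Array(partition, broker) =>
          partition.nonEmpty && broker.nonEmpty &&
            partition.forall(_.isDigit) && broker.forall(_.isDigit)
        case _ => false
      }
    }

  def main(args: Array[String]): Unit = {
    println(looksValid("0:1,0:2,1:1")) // true
    println(looksValid("*"))           // true
    println(looksValid("0-1"))         // false
  }
}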
+ + private[log] val ServerDefaultHeaderName = "Server Default Property" + + val configsWithNoServerDefaults: Set[String] = Set(RemoteLogStorageEnableProp, LocalLogRetentionMsProp, LocalLogRetentionBytesProp); + + // Package private for testing + private[log] class LogConfigDef(base: ConfigDef) extends ConfigDef(base) { + def this() = this(new ConfigDef) + + private final val serverDefaultConfigNames = mutable.Map[String, String]() + base match { + case b: LogConfigDef => serverDefaultConfigNames ++= b.serverDefaultConfigNames + case _ => + } + + def define(name: String, defType: ConfigDef.Type, defaultValue: Any, validator: Validator, + importance: ConfigDef.Importance, doc: String, serverDefaultConfigName: String): LogConfigDef = { + super.define(name, defType, defaultValue, validator, importance, doc) + serverDefaultConfigNames.put(name, serverDefaultConfigName) + this + } + + def define(name: String, defType: ConfigDef.Type, defaultValue: Any, importance: ConfigDef.Importance, + documentation: String, serverDefaultConfigName: String): LogConfigDef = { + super.define(name, defType, defaultValue, importance, documentation) + serverDefaultConfigNames.put(name, serverDefaultConfigName) + this + } + + def define(name: String, defType: ConfigDef.Type, importance: ConfigDef.Importance, documentation: String, + serverDefaultConfigName: String): LogConfigDef = { + super.define(name, defType, importance, documentation) + serverDefaultConfigNames.put(name, serverDefaultConfigName) + this + } + + override def headers = List("Name", "Description", "Type", "Default", "Valid Values", ServerDefaultHeaderName, + "Importance").asJava + + override def getConfigValue(key: ConfigKey, headerName: String): String = { + headerName match { + case ServerDefaultHeaderName => serverDefaultConfigNames.getOrElse(key.name, null) + case _ => super.getConfigValue(key, headerName) + } + } + + def serverConfigName(configName: String): Option[String] = serverDefaultConfigNames.get(configName) + } + + // Package private for testing, return a copy since it's a mutable global variable + private[kafka] def configDefCopy: LogConfigDef = new LogConfigDef(configDef) + + private val configDef: LogConfigDef = { + import org.apache.kafka.common.config.ConfigDef.Importance._ + import org.apache.kafka.common.config.ConfigDef.Range._ + import org.apache.kafka.common.config.ConfigDef.Type._ + import org.apache.kafka.common.config.ConfigDef.ValidString._ + + @nowarn("cat=deprecation") + val logConfigDef = new LogConfigDef() + .define(SegmentBytesProp, INT, Defaults.SegmentSize, atLeast(LegacyRecord.RECORD_OVERHEAD_V0), MEDIUM, + SegmentSizeDoc, KafkaConfig.LogSegmentBytesProp) + .define(SegmentMsProp, LONG, Defaults.SegmentMs, atLeast(1), MEDIUM, SegmentMsDoc, + KafkaConfig.LogRollTimeMillisProp) + .define(SegmentJitterMsProp, LONG, Defaults.SegmentJitterMs, atLeast(0), MEDIUM, SegmentJitterMsDoc, + KafkaConfig.LogRollTimeJitterMillisProp) + .define(SegmentIndexBytesProp, INT, Defaults.MaxIndexSize, atLeast(0), MEDIUM, MaxIndexSizeDoc, + KafkaConfig.LogIndexSizeMaxBytesProp) + .define(FlushMessagesProp, LONG, Defaults.FlushInterval, atLeast(0), MEDIUM, FlushIntervalDoc, + KafkaConfig.LogFlushIntervalMessagesProp) + .define(FlushMsProp, LONG, Defaults.FlushMs, atLeast(0), MEDIUM, FlushMsDoc, + KafkaConfig.LogFlushIntervalMsProp) + // can be negative. 
See kafka.log.LogManager.cleanupSegmentsToMaintainSize + .define(RetentionBytesProp, LONG, Defaults.RetentionSize, MEDIUM, RetentionSizeDoc, + KafkaConfig.LogRetentionBytesProp) + // can be negative. See kafka.log.LogManager.cleanupExpiredSegments + .define(RetentionMsProp, LONG, Defaults.RetentionMs, atLeast(-1), MEDIUM, RetentionMsDoc, + KafkaConfig.LogRetentionTimeMillisProp) + .define(MaxMessageBytesProp, INT, Defaults.MaxMessageSize, atLeast(0), MEDIUM, MaxMessageSizeDoc, + KafkaConfig.MessageMaxBytesProp) + .define(IndexIntervalBytesProp, INT, Defaults.IndexInterval, atLeast(0), MEDIUM, IndexIntervalDoc, + KafkaConfig.LogIndexIntervalBytesProp) + .define(DeleteRetentionMsProp, LONG, Defaults.DeleteRetentionMs, atLeast(0), MEDIUM, + DeleteRetentionMsDoc, KafkaConfig.LogCleanerDeleteRetentionMsProp) + .define(MinCompactionLagMsProp, LONG, Defaults.MinCompactionLagMs, atLeast(0), MEDIUM, MinCompactionLagMsDoc, + KafkaConfig.LogCleanerMinCompactionLagMsProp) + .define(MaxCompactionLagMsProp, LONG, Defaults.MaxCompactionLagMs, atLeast(1), MEDIUM, MaxCompactionLagMsDoc, + KafkaConfig.LogCleanerMaxCompactionLagMsProp) + .define(FileDeleteDelayMsProp, LONG, Defaults.FileDeleteDelayMs, atLeast(0), MEDIUM, FileDeleteDelayMsDoc, + KafkaConfig.LogDeleteDelayMsProp) + .define(MinCleanableDirtyRatioProp, DOUBLE, Defaults.MinCleanableDirtyRatio, between(0, 1), MEDIUM, + MinCleanableRatioDoc, KafkaConfig.LogCleanerMinCleanRatioProp) + .define(CleanupPolicyProp, LIST, Defaults.CleanupPolicy, ValidList.in(LogConfig.Compact, LogConfig.Delete), MEDIUM, CompactDoc, + KafkaConfig.LogCleanupPolicyProp) + .define(UncleanLeaderElectionEnableProp, BOOLEAN, Defaults.UncleanLeaderElectionEnable, + MEDIUM, UncleanLeaderElectionEnableDoc, KafkaConfig.UncleanLeaderElectionEnableProp) + .define(MinInSyncReplicasProp, INT, Defaults.MinInSyncReplicas, atLeast(1), MEDIUM, MinInSyncReplicasDoc, + KafkaConfig.MinInSyncReplicasProp) + .define(CompressionTypeProp, STRING, Defaults.CompressionType, in(BrokerCompressionCodec.brokerCompressionOptions:_*), + MEDIUM, CompressionTypeDoc, KafkaConfig.CompressionTypeProp) + .define(PreAllocateEnableProp, BOOLEAN, Defaults.PreAllocateEnable, MEDIUM, PreAllocateEnableDoc, + KafkaConfig.LogPreAllocateProp) + .define(MessageFormatVersionProp, STRING, Defaults.MessageFormatVersion, ApiVersionValidator, MEDIUM, MessageFormatVersionDoc, + KafkaConfig.LogMessageFormatVersionProp) + .define(MessageTimestampTypeProp, STRING, Defaults.MessageTimestampType, in("CreateTime", "LogAppendTime"), MEDIUM, MessageTimestampTypeDoc, + KafkaConfig.LogMessageTimestampTypeProp) + .define(MessageTimestampDifferenceMaxMsProp, LONG, Defaults.MessageTimestampDifferenceMaxMs, + atLeast(0), MEDIUM, MessageTimestampDifferenceMaxMsDoc, KafkaConfig.LogMessageTimestampDifferenceMaxMsProp) + .define(LeaderReplicationThrottledReplicasProp, LIST, Defaults.LeaderReplicationThrottledReplicas, ThrottledReplicaListValidator, MEDIUM, + LeaderReplicationThrottledReplicasDoc, LeaderReplicationThrottledReplicasProp) + .define(FollowerReplicationThrottledReplicasProp, LIST, Defaults.FollowerReplicationThrottledReplicas, ThrottledReplicaListValidator, MEDIUM, + FollowerReplicationThrottledReplicasDoc, FollowerReplicationThrottledReplicasProp) + .define(MessageDownConversionEnableProp, BOOLEAN, Defaults.MessageDownConversionEnable, LOW, + MessageDownConversionEnableDoc, KafkaConfig.LogMessageDownConversionEnableProp) + + // RemoteLogStorageEnableProp, LocalLogRetentionMsProp, LocalLogRetentionBytesProp do not have server 
default + // config names. + logConfigDef + // This define method is not overridden in LogConfig as these configs do not have server defaults yet. + .defineInternal(RemoteLogStorageEnableProp, BOOLEAN, Defaults.RemoteLogStorageEnable, null, MEDIUM, RemoteLogStorageEnableDoc) + .defineInternal(LocalLogRetentionMsProp, LONG, Defaults.LocalRetentionMs, atLeast(-2), MEDIUM, LocalLogRetentionMsDoc) + .defineInternal(LocalLogRetentionBytesProp, LONG, Defaults.LocalRetentionBytes, atLeast(-2), MEDIUM, LocalLogRetentionBytesDoc) + + logConfigDef + } + + def apply(): LogConfig = LogConfig(new Properties()) + + def configNames: Seq[String] = configDef.names.asScala.toSeq.sorted + + def serverConfigName(configName: String): Option[String] = configDef.serverConfigName(configName) + + def configType(configName: String): Option[ConfigDef.Type] = { + Option(configDef.configKeys.get(configName)).map(_.`type`) + } + + /** + * Create a log config instance using the given properties and defaults + */ + def fromProps(defaults: java.util.Map[_ <: Object, _ <: Object], overrides: Properties): LogConfig = { + val props = new Properties() + defaults.forEach { (k, v) => props.put(k, v) } + props ++= overrides + val overriddenKeys = overrides.keySet.asScala.map(_.asInstanceOf[String]).toSet + new LogConfig(props, overriddenKeys) + } + + /** + * Check that property names are valid + */ + def validateNames(props: Properties): Unit = { + val names = configNames + for(name <- props.asScala.keys) + if (!names.contains(name)) + throw new InvalidConfigurationException(s"Unknown topic config name: $name") + } + + private[kafka] def configKeys: Map[String, ConfigKey] = configDef.configKeys.asScala + + def validateValues(props: java.util.Map[_, _]): Unit = { + val minCompactionLag = props.get(MinCompactionLagMsProp).asInstanceOf[Long] + val maxCompactionLag = props.get(MaxCompactionLagMsProp).asInstanceOf[Long] + if (minCompactionLag > maxCompactionLag) { + throw new InvalidConfigurationException(s"conflict topic config setting $MinCompactionLagMsProp " + + s"($minCompactionLag) > $MaxCompactionLagMsProp ($maxCompactionLag)") + } + } + + /** + * Check that the given properties contain only valid log config names and that all values can be parsed and are valid + */ + def validate(props: Properties): Unit = { + validateNames(props) + val valueMaps = configDef.parse(props) + validateValues(valueMaps) + } + + /** + * Map topic config to the broker config with highest priority. 
Some of these have additional synonyms + * that can be obtained using [[kafka.server.DynamicBrokerConfig#brokerConfigSynonyms]] + */ + @nowarn("cat=deprecation") + val TopicConfigSynonyms = Map( + SegmentBytesProp -> KafkaConfig.LogSegmentBytesProp, + SegmentMsProp -> KafkaConfig.LogRollTimeMillisProp, + SegmentJitterMsProp -> KafkaConfig.LogRollTimeJitterMillisProp, + SegmentIndexBytesProp -> KafkaConfig.LogIndexSizeMaxBytesProp, + FlushMessagesProp -> KafkaConfig.LogFlushIntervalMessagesProp, + FlushMsProp -> KafkaConfig.LogFlushIntervalMsProp, + RetentionBytesProp -> KafkaConfig.LogRetentionBytesProp, + RetentionMsProp -> KafkaConfig.LogRetentionTimeMillisProp, + MaxMessageBytesProp -> KafkaConfig.MessageMaxBytesProp, + IndexIntervalBytesProp -> KafkaConfig.LogIndexIntervalBytesProp, + DeleteRetentionMsProp -> KafkaConfig.LogCleanerDeleteRetentionMsProp, + MinCompactionLagMsProp -> KafkaConfig.LogCleanerMinCompactionLagMsProp, + MaxCompactionLagMsProp -> KafkaConfig.LogCleanerMaxCompactionLagMsProp, + FileDeleteDelayMsProp -> KafkaConfig.LogDeleteDelayMsProp, + MinCleanableDirtyRatioProp -> KafkaConfig.LogCleanerMinCleanRatioProp, + CleanupPolicyProp -> KafkaConfig.LogCleanupPolicyProp, + UncleanLeaderElectionEnableProp -> KafkaConfig.UncleanLeaderElectionEnableProp, + MinInSyncReplicasProp -> KafkaConfig.MinInSyncReplicasProp, + CompressionTypeProp -> KafkaConfig.CompressionTypeProp, + PreAllocateEnableProp -> KafkaConfig.LogPreAllocateProp, + MessageFormatVersionProp -> KafkaConfig.LogMessageFormatVersionProp, + MessageTimestampTypeProp -> KafkaConfig.LogMessageTimestampTypeProp, + MessageTimestampDifferenceMaxMsProp -> KafkaConfig.LogMessageTimestampDifferenceMaxMsProp, + MessageDownConversionEnableProp -> KafkaConfig.LogMessageDownConversionEnableProp + ) + + + /** + * Copy the subset of properties that are relevant to Logs. The individual properties + * are listed here since the names are slightly different in each Config class... 
+ */ + @nowarn("cat=deprecation") + def extractLogConfigMap( + kafkaConfig: KafkaConfig + ): java.util.Map[String, Object] = { + val logProps = new java.util.HashMap[String, Object]() + logProps.put(SegmentBytesProp, kafkaConfig.logSegmentBytes) + logProps.put(SegmentMsProp, kafkaConfig.logRollTimeMillis) + logProps.put(SegmentJitterMsProp, kafkaConfig.logRollTimeJitterMillis) + logProps.put(SegmentIndexBytesProp, kafkaConfig.logIndexSizeMaxBytes) + logProps.put(FlushMessagesProp, kafkaConfig.logFlushIntervalMessages) + logProps.put(FlushMsProp, kafkaConfig.logFlushIntervalMs) + logProps.put(RetentionBytesProp, kafkaConfig.logRetentionBytes) + logProps.put(RetentionMsProp, kafkaConfig.logRetentionTimeMillis: java.lang.Long) + logProps.put(MaxMessageBytesProp, kafkaConfig.messageMaxBytes) + logProps.put(IndexIntervalBytesProp, kafkaConfig.logIndexIntervalBytes) + logProps.put(DeleteRetentionMsProp, kafkaConfig.logCleanerDeleteRetentionMs) + logProps.put(MinCompactionLagMsProp, kafkaConfig.logCleanerMinCompactionLagMs) + logProps.put(MaxCompactionLagMsProp, kafkaConfig.logCleanerMaxCompactionLagMs) + logProps.put(FileDeleteDelayMsProp, kafkaConfig.logDeleteDelayMs) + logProps.put(MinCleanableDirtyRatioProp, kafkaConfig.logCleanerMinCleanRatio) + logProps.put(CleanupPolicyProp, kafkaConfig.logCleanupPolicy) + logProps.put(MinInSyncReplicasProp, kafkaConfig.minInSyncReplicas) + logProps.put(CompressionTypeProp, kafkaConfig.compressionType) + logProps.put(UncleanLeaderElectionEnableProp, kafkaConfig.uncleanLeaderElectionEnable) + logProps.put(PreAllocateEnableProp, kafkaConfig.logPreAllocateEnable) + logProps.put(MessageFormatVersionProp, kafkaConfig.logMessageFormatVersion.version) + logProps.put(MessageTimestampTypeProp, kafkaConfig.logMessageTimestampType.name) + logProps.put(MessageTimestampDifferenceMaxMsProp, kafkaConfig.logMessageTimestampDifferenceMaxMs: java.lang.Long) + logProps.put(MessageDownConversionEnableProp, kafkaConfig.logMessageDownConversionEnable: java.lang.Boolean) + logProps + } + + def shouldIgnoreMessageFormatVersion(interBrokerProtocolVersion: ApiVersion): Boolean = + interBrokerProtocolVersion >= KAFKA_3_0_IV1 + + class MessageFormatVersion(messageFormatVersionString: String, interBrokerProtocolVersionString: String) { + val messageFormatVersion = ApiVersion(messageFormatVersionString) + private val interBrokerProtocolVersion = ApiVersion(interBrokerProtocolVersionString) + + def shouldIgnore: Boolean = shouldIgnoreMessageFormatVersion(interBrokerProtocolVersion) + + def shouldWarn: Boolean = + interBrokerProtocolVersion >= KAFKA_3_0_IV1 && messageFormatVersion.recordVersion.precedes(RecordVersion.V2) + + @nowarn("cat=deprecation") + def topicWarningMessage(topicName: String): String = { + s"Topic configuration ${LogConfig.MessageFormatVersionProp} with value `$messageFormatVersionString` is ignored " + + s"for `$topicName` because the inter-broker protocol version `$interBrokerProtocolVersionString` is " + + "greater or equal than 3.0. This configuration is deprecated and it will be removed in Apache Kafka 4.0." + } + + @nowarn("cat=deprecation") + def brokerWarningMessage: String = { + s"Broker configuration ${KafkaConfig.LogMessageFormatVersionProp} with value $messageFormatVersionString is ignored " + + s"because the inter-broker protocol version `$interBrokerProtocolVersionString` is greater or equal than 3.0. " + + "This configuration is deprecated and it will be removed in Apache Kafka 4.0." 
+ } + } + +} diff --git a/core/src/main/scala/kafka/log/LogLoader.scala b/core/src/main/scala/kafka/log/LogLoader.scala new file mode 100644 index 0000000..b075069 --- /dev/null +++ b/core/src/main/scala/kafka/log/LogLoader.scala @@ -0,0 +1,524 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.log + +import java.io.{File, IOException} +import java.nio.file.{Files, NoSuchFileException} + +import kafka.common.LogSegmentOffsetOverflowException +import kafka.log.UnifiedLog.{CleanedFileSuffix, DeletedFileSuffix, SwapFileSuffix, isIndexFile, isLogFile, offsetFromFile} +import kafka.server.{LogDirFailureChannel, LogOffsetMetadata} +import kafka.server.epoch.LeaderEpochFileCache +import kafka.utils.{CoreUtils, Logging, Scheduler} +import org.apache.kafka.common.TopicPartition +import org.apache.kafka.common.errors.InvalidOffsetException +import org.apache.kafka.common.utils.Time + +import scala.collection.{Set, mutable} + +case class LoadedLogOffsets(logStartOffset: Long, + recoveryPoint: Long, + nextOffsetMetadata: LogOffsetMetadata) + +/** + * @param dir The directory from which log segments need to be loaded + * @param topicPartition The topic partition associated with the log being loaded + * @param config The configuration settings for the log being loaded + * @param scheduler The thread pool scheduler used for background actions + * @param time The time instance used for checking the clock + * @param logDirFailureChannel The LogDirFailureChannel instance to asynchronously handle log + * directory failure + * @param hadCleanShutdown Boolean flag to indicate whether the associated log previously had a + * clean shutdown + * @param segments The LogSegments instance into which segments recovered from disk will be + * populated + * @param logStartOffsetCheckpoint The checkpoint of the log start offset + * @param recoveryPointCheckpoint The checkpoint of the offset at which to begin the recovery + * @param maxProducerIdExpirationMs The maximum amount of time to wait before a producer id is + * considered expired + * @param leaderEpochCache An optional LeaderEpochFileCache instance to be updated during recovery + * @param producerStateManager The ProducerStateManager instance to be updated during recovery + */ +case class LoadLogParams(dir: File, + topicPartition: TopicPartition, + config: LogConfig, + scheduler: Scheduler, + time: Time, + logDirFailureChannel: LogDirFailureChannel, + hadCleanShutdown: Boolean, + segments: LogSegments, + logStartOffsetCheckpoint: Long, + recoveryPointCheckpoint: Long, + maxProducerIdExpirationMs: Int, + leaderEpochCache: Option[LeaderEpochFileCache], + producerStateManager: ProducerStateManager) { + val logIdentifier: String = s"[LogLoader partition=$topicPartition, dir=${dir.getParent}] " +} + 
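// Illustrative sketch only (hypothetical helpers, not part of this file): the loader below leans on
// the segment file naming convention, where a segment's base offset is encoded as a zero-padded
// 20-digit number (e.g. "00000000000000000123.log"), optionally carrying a temporary suffix such as
// ".cleaned", ".swap" or ".deleted" while cleaning, splitting or deletion is in flight.
object SegmentFileNameSketch {
  def logFileName(baseOffset: Long): String = f"$baseOffset%020d.log"
  def offsetFromFileName(name: String): Long = name.takeWhile(_.isDigit).toLong

  def main(args: Array[String]): Unit = {
    val name = logFileName(123L)
    println(name)                               // 00000000000000000123.log
    println(offsetFromFileName(name))           // 123
    println(offsetFromFileName(name + ".swap")) // 123: the suffix does not change the base offset
  }
}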
+/** + * This object is responsible for all activities related with recovery of log segments from disk. + */ +object LogLoader extends Logging { + + /** + * Clean shutdown file that indicates the broker was cleanly shutdown in 0.8 and higher. + * This is used to avoid unnecessary recovery after a clean shutdown. In theory this could be + * avoided by passing in the recovery point, however finding the correct position to do this + * requires accessing the offset index which may not be safe in an unclean shutdown. + * For more information see the discussion in PR#2104 + */ + val CleanShutdownFile = ".kafka_cleanshutdown" + + /** + * Load the log segments from the log files on disk, and returns the components of the loaded log. + * Additionally, it also suitably updates the provided LeaderEpochFileCache and ProducerStateManager + * to reflect the contents of the loaded log. + * + * In the context of the calling thread, this function does not need to convert IOException to + * KafkaStorageException because it is only called before all logs are loaded. + * + * @param params The parameters for the log being loaded from disk + * + * @return the offsets of the Log successfully loaded from disk + * + * @throws LogSegmentOffsetOverflowException if we encounter a .swap file with messages that + * overflow index offset + */ + def load(params: LoadLogParams): LoadedLogOffsets = { + // First pass: through the files in the log directory and remove any temporary files + // and find any interrupted swap operations + val swapFiles = removeTempFilesAndCollectSwapFiles(params) + + // The remaining valid swap files must come from compaction or segment split operation. We can + // simply rename them to regular segment files. But, before renaming, we should figure out which + // segments are compacted/split and delete these segment files: this is done by calculating + // min/maxSwapFileOffset. + // We store segments that require renaming in this code block, and do the actual renaming later. + var minSwapFileOffset = Long.MaxValue + var maxSwapFileOffset = Long.MinValue + swapFiles.filter(f => UnifiedLog.isLogFile(new File(CoreUtils.replaceSuffix(f.getPath, SwapFileSuffix, "")))).foreach { f => + val baseOffset = offsetFromFile(f) + val segment = LogSegment.open(f.getParentFile, + baseOffset = baseOffset, + params.config, + time = params.time, + fileSuffix = UnifiedLog.SwapFileSuffix) + info(s"${params.logIdentifier}Found log file ${f.getPath} from interrupted swap operation, which is recoverable from ${UnifiedLog.SwapFileSuffix} files by renaming.") + minSwapFileOffset = Math.min(segment.baseOffset, minSwapFileOffset) + maxSwapFileOffset = Math.max(segment.readNextOffset, maxSwapFileOffset) + } + + // Second pass: delete segments that are between minSwapFileOffset and maxSwapFileOffset. As + // discussed above, these segments were compacted or split but haven't been renamed to .delete + // before shutting down the broker. + for (file <- params.dir.listFiles if file.isFile) { + try { + if (!file.getName.endsWith(SwapFileSuffix)) { + val offset = offsetFromFile(file) + if (offset >= minSwapFileOffset && offset < maxSwapFileOffset) { + info(s"${params.logIdentifier}Deleting segment files ${file.getName} that is compacted but has not been deleted yet.") + file.delete() + } + } + } catch { + // offsetFromFile with files that do not include an offset in the file name + case _: StringIndexOutOfBoundsException => + case _: NumberFormatException => + } + } + + // Third pass: rename all swap files. 
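+ // For example, "00000000000000000123.log.swap" is renamed back to "00000000000000000123.log",
+ // and "00000000000000000123.index.swap" back to "00000000000000000123.index".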
+ for (file <- params.dir.listFiles if file.isFile) { + if (file.getName.endsWith(SwapFileSuffix)) { + info(s"${params.logIdentifier}Recovering file ${file.getName} by renaming from ${UnifiedLog.SwapFileSuffix} files.") + file.renameTo(new File(CoreUtils.replaceSuffix(file.getPath, UnifiedLog.SwapFileSuffix, ""))) + } + } + + // Fourth pass: load all the log and index files. + // We might encounter legacy log segments with offset overflow (KAFKA-6264). We need to split such segments. When + // this happens, restart loading segment files from scratch. + retryOnOffsetOverflow(params, { + // In case we encounter a segment with offset overflow, the retry logic will split it after which we need to retry + // loading of segments. In that case, we also need to close all segments that could have been left open in previous + // call to loadSegmentFiles(). + params.segments.close() + params.segments.clear() + loadSegmentFiles(params) + }) + + val (newRecoveryPoint: Long, nextOffset: Long) = { + if (!params.dir.getAbsolutePath.endsWith(UnifiedLog.DeleteDirSuffix)) { + val (newRecoveryPoint, nextOffset) = retryOnOffsetOverflow(params, { + recoverLog(params) + }) + + // reset the index size of the currently active log segment to allow more entries + params.segments.lastSegment.get.resizeIndexes(params.config.maxIndexSize) + (newRecoveryPoint, nextOffset) + } else { + if (params.segments.isEmpty) { + params.segments.add( + LogSegment.open( + dir = params.dir, + baseOffset = 0, + params.config, + time = params.time, + initFileSize = params.config.initFileSize)) + } + (0L, 0L) + } + } + + params.leaderEpochCache.foreach(_.truncateFromEnd(nextOffset)) + val newLogStartOffset = math.max(params.logStartOffsetCheckpoint, params.segments.firstSegment.get.baseOffset) + // The earliest leader epoch may not be flushed during a hard failure. Recover it here. + params.leaderEpochCache.foreach(_.truncateFromStart(params.logStartOffsetCheckpoint)) + + // Any segment loading or recovery code must not use producerStateManager, so that we can build the full state here + // from scratch. + if (!params.producerStateManager.isEmpty) + throw new IllegalStateException("Producer state must be empty during log initialization") + + // Reload all snapshots into the ProducerStateManager cache, the intermediate ProducerStateManager used + // during log recovery may have deleted some files without the LogLoader.producerStateManager instance witnessing the + // deletion. + params.producerStateManager.removeStraySnapshots(params.segments.baseOffsets.toSeq) + UnifiedLog.rebuildProducerState( + params.producerStateManager, + params.segments, + newLogStartOffset, + nextOffset, + params.config.recordVersion, + params.time, + reloadFromCleanShutdown = params.hadCleanShutdown, + params.logIdentifier) + val activeSegment = params.segments.lastSegment.get + LoadedLogOffsets( + newLogStartOffset, + newRecoveryPoint, + LogOffsetMetadata(nextOffset, activeSegment.baseOffset, activeSegment.size)) + } + + /** + * Removes any temporary files found in log directory, and creates a list of all .swap files which could be swapped + * in place of existing segment(s). For log splitting, we know that any .swap file whose base offset is higher than + * the smallest offset .clean file could be part of an incomplete split operation. Such .swap files are also deleted + * by this method. 
+ * + * @param params The parameters for the log being loaded from disk + * @return Set of .swap files that are valid to be swapped in as segment files and index files + */ + private def removeTempFilesAndCollectSwapFiles(params: LoadLogParams): Set[File] = { + + val swapFiles = mutable.Set[File]() + val cleanedFiles = mutable.Set[File]() + var minCleanedFileOffset = Long.MaxValue + + for (file <- params.dir.listFiles if file.isFile) { + if (!file.canRead) + throw new IOException(s"Could not read file $file") + val filename = file.getName + if (filename.endsWith(DeletedFileSuffix)) { + debug(s"${params.logIdentifier}Deleting stray temporary file ${file.getAbsolutePath}") + Files.deleteIfExists(file.toPath) + } else if (filename.endsWith(CleanedFileSuffix)) { + minCleanedFileOffset = Math.min(offsetFromFile(file), minCleanedFileOffset) + cleanedFiles += file + } else if (filename.endsWith(SwapFileSuffix)) { + swapFiles += file + } + } + + // KAFKA-6264: Delete all .swap files whose base offset is greater than the minimum .cleaned segment offset. Such .swap + // files could be part of an incomplete split operation that could not complete. See Log#splitOverflowedSegment + // for more details about the split operation. + val (invalidSwapFiles, validSwapFiles) = swapFiles.partition(file => offsetFromFile(file) >= minCleanedFileOffset) + invalidSwapFiles.foreach { file => + debug(s"${params.logIdentifier}Deleting invalid swap file ${file.getAbsoluteFile} minCleanedFileOffset: $minCleanedFileOffset") + Files.deleteIfExists(file.toPath) + } + + // Now that we have deleted all .swap files that constitute an incomplete split operation, let's delete all .clean files + cleanedFiles.foreach { file => + debug(s"${params.logIdentifier}Deleting stray .clean file ${file.getAbsolutePath}") + Files.deleteIfExists(file.toPath) + } + + validSwapFiles + } + + /** + * Retries the provided function only whenever an LogSegmentOffsetOverflowException is raised by + * it during execution. Before every retry, the overflowed segment is split into one or more segments + * such that there is no offset overflow in any of them. + * + * @param params The parameters for the log being loaded from disk + * @param fn The function to be executed + * @return The value returned by the function, if successful + * @throws Exception whenever the executed function throws any exception other than + * LogSegmentOffsetOverflowException, the same exception is raised to the caller + */ + private def retryOnOffsetOverflow[T](params: LoadLogParams, fn: => T): T = { + while (true) { + try { + return fn + } catch { + case e: LogSegmentOffsetOverflowException => + info(s"${params.logIdentifier}Caught segment overflow error: ${e.getMessage}. Split segment and retry.") + val result = UnifiedLog.splitOverflowedSegment( + e.segment, + params.segments, + params.dir, + params.topicPartition, + params.config, + params.scheduler, + params.logDirFailureChannel, + params.logIdentifier) + deleteProducerSnapshotsAsync(result.deletedSegments, params) + } + } + throw new IllegalStateException() + } + + /** + * Loads segments from disk into the provided params.segments. + * + * This method does not need to convert IOException to KafkaStorageException because it is only called before all logs are loaded. + * It is possible that we encounter a segment with index offset overflow in which case the LogSegmentOffsetOverflowException + * will be thrown. 
Note that any segments that were opened before we encountered the exception will remain open and the + * caller is responsible for closing them appropriately, if needed. + * + * @param params The parameters for the log being loaded from disk + * @throws LogSegmentOffsetOverflowException if the log directory contains a segment with messages that overflow the index offset + */ + private def loadSegmentFiles(params: LoadLogParams): Unit = { + // load segments in ascending order because transactional data from one segment may depend on the + // segments that come before it + for (file <- params.dir.listFiles.sortBy(_.getName) if file.isFile) { + if (isIndexFile(file)) { + // if it is an index file, make sure it has a corresponding .log file + val offset = offsetFromFile(file) + val logFile = UnifiedLog.logFile(params.dir, offset) + if (!logFile.exists) { + warn(s"${params.logIdentifier}Found an orphaned index file ${file.getAbsolutePath}, with no corresponding log file.") + Files.deleteIfExists(file.toPath) + } + } else if (isLogFile(file)) { + // if it's a log file, load the corresponding log segment + val baseOffset = offsetFromFile(file) + val timeIndexFileNewlyCreated = !UnifiedLog.timeIndexFile(params.dir, baseOffset).exists() + val segment = LogSegment.open( + dir = params.dir, + baseOffset = baseOffset, + params.config, + time = params.time, + fileAlreadyExists = true) + + try segment.sanityCheck(timeIndexFileNewlyCreated) + catch { + case _: NoSuchFileException => + error(s"${params.logIdentifier}Could not find offset index file corresponding to log file" + + s" ${segment.log.file.getAbsolutePath}, recovering segment and rebuilding index files...") + recoverSegment(segment, params) + case e: CorruptIndexException => + warn(s"${params.logIdentifier}Found a corrupted index file corresponding to log file" + + s" ${segment.log.file.getAbsolutePath} due to ${e.getMessage}}, recovering segment and" + + " rebuilding index files...") + recoverSegment(segment, params) + } + params.segments.add(segment) + } + } + } + + /** + * Just recovers the given segment, without adding it to the provided params.segments. + * + * @param segment Segment to recover + * @param params The parameters for the log being loaded from disk + * + * @return The number of bytes truncated from the segment + * + * @throws LogSegmentOffsetOverflowException if the segment contains messages that cause index offset overflow + */ + private def recoverSegment(segment: LogSegment, params: LoadLogParams): Int = { + val producerStateManager = new ProducerStateManager( + params.topicPartition, + params.dir, + params.maxProducerIdExpirationMs, + params.time) + UnifiedLog.rebuildProducerState( + producerStateManager, + params.segments, + params.logStartOffsetCheckpoint, + segment.baseOffset, + params.config.recordVersion, + params.time, + reloadFromCleanShutdown = false, + params.logIdentifier) + val bytesTruncated = segment.recover(producerStateManager, params.leaderEpochCache) + // once we have recovered the segment's data, take a snapshot to ensure that we won't + // need to reload the same segment again while recovering another segment. + producerStateManager.takeSnapshot() + bytesTruncated + } + + /** + * Recover the log segments (if there was an unclean shutdown). Ensures there is at least one + * active segment, and returns the updated recovery point and next offset after recovery. Along + * the way, the method suitably updates the LeaderEpochFileCache or ProducerStateManager inside + * the provided LogComponents. 
+ * + * This method does not need to convert IOException to KafkaStorageException because it is only + * called before all logs are loaded. + * + * @param params The parameters for the log being loaded from disk + * + * @return a tuple containing (newRecoveryPoint, nextOffset). + * + * @throws LogSegmentOffsetOverflowException if we encountered a legacy segment with offset overflow + */ + private[log] def recoverLog(params: LoadLogParams): (Long, Long) = { + /** return the log end offset if valid */ + def deleteSegmentsIfLogStartGreaterThanLogEnd(): Option[Long] = { + if (params.segments.nonEmpty) { + val logEndOffset = params.segments.lastSegment.get.readNextOffset + if (logEndOffset >= params.logStartOffsetCheckpoint) + Some(logEndOffset) + else { + warn(s"${params.logIdentifier}Deleting all segments because logEndOffset ($logEndOffset) " + + s"is smaller than logStartOffset ${params.logStartOffsetCheckpoint}. " + + "This could happen if segment files were deleted from the file system.") + removeAndDeleteSegmentsAsync(params.segments.values, params) + params.leaderEpochCache.foreach(_.clearAndFlush()) + params.producerStateManager.truncateFullyAndStartAt(params.logStartOffsetCheckpoint) + None + } + } else None + } + + // If we have the clean shutdown marker, skip recovery. + if (!params.hadCleanShutdown) { + val unflushed = params.segments.values(params.recoveryPointCheckpoint, Long.MaxValue).iterator + var truncated = false + + while (unflushed.hasNext && !truncated) { + val segment = unflushed.next() + info(s"${params.logIdentifier}Recovering unflushed segment ${segment.baseOffset}") + val truncatedBytes = + try { + recoverSegment(segment, params) + } catch { + case _: InvalidOffsetException => + val startOffset = segment.baseOffset + warn(s"${params.logIdentifier}Found invalid offset during recovery. Deleting the" + + s" corrupt segment and creating an empty one with starting offset $startOffset") + segment.truncateTo(startOffset) + } + if (truncatedBytes > 0) { + // we had an invalid message, delete all remaining log + warn(s"${params.logIdentifier}Corruption found in segment ${segment.baseOffset}," + + s" truncating to offset ${segment.readNextOffset}") + removeAndDeleteSegmentsAsync(unflushed.toList, params) + truncated = true + } + } + } + + val logEndOffsetOption = deleteSegmentsIfLogStartGreaterThanLogEnd() + + if (params.segments.isEmpty) { + // no existing segments, create a new mutable segment beginning at logStartOffset + params.segments.add( + LogSegment.open( + dir = params.dir, + baseOffset = params.logStartOffsetCheckpoint, + params.config, + time = params.time, + initFileSize = params.config.initFileSize, + preallocate = params.config.preallocate)) + } + + // Update the recovery point if there was a clean shutdown and did not perform any changes to + // the segment. Otherwise, we just ensure that the recovery point is not ahead of the log end + // offset. To ensure correctness and to make it easier to reason about, it's best to only advance + // the recovery point when the log is flushed. If we advanced the recovery point here, we could + // skip recovery for unflushed segments if the broker crashed after we checkpoint the recovery + // point and before we flush the segment. 
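+ // For example, with recoveryPointCheckpoint = 40 and a log end offset of 100: a clean shutdown
+ // yields (100, 100), while an unclean shutdown (or a missing log end offset) yields
+ // (min(40, 100), 100) = (40, 100), i.e. the recovery point stays at the checkpoint.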
+ (params.hadCleanShutdown, logEndOffsetOption) match { + case (true, Some(logEndOffset)) => + (logEndOffset, logEndOffset) + case _ => + val logEndOffset = logEndOffsetOption.getOrElse(params.segments.lastSegment.get.readNextOffset) + (Math.min(params.recoveryPointCheckpoint, logEndOffset), logEndOffset) + } + } + + /** + * This method deletes the given log segments and the associated producer snapshots, by doing the + * following for each of them: + * - It removes the segment from the segment map so that it will no longer be used for reads. + * - It schedules asynchronous deletion of the segments that allows reads to happen concurrently without + * synchronization and without the possibility of physically deleting a file while it is being + * read. + * + * This method does not need to convert IOException to KafkaStorageException because it is either + * called before all logs are loaded or the immediate caller will catch and handle IOException + * + * @param segmentsToDelete The log segments to schedule for deletion + * @param params The parameters for the log being loaded from disk + */ + private def removeAndDeleteSegmentsAsync(segmentsToDelete: Iterable[LogSegment], + params: LoadLogParams): Unit = { + if (segmentsToDelete.nonEmpty) { + // Most callers hold an iterator into the `params.segments` collection and + // `removeAndDeleteSegmentAsync` mutates it by removing the deleted segment. Therefore, + // we should force materialization of the iterator here, so that results of the iteration + // remain valid and deterministic. We should also pass only the materialized view of the + // iterator to the logic that deletes the segments. + val toDelete = segmentsToDelete.toList + info(s"${params.logIdentifier}Deleting segments as part of log recovery: ${toDelete.mkString(",")}") + toDelete.foreach { segment => + params.segments.remove(segment.baseOffset) + } + UnifiedLog.deleteSegmentFiles( + toDelete, + asyncDelete = true, + params.dir, + params.topicPartition, + params.config, + params.scheduler, + params.logDirFailureChannel, + params.logIdentifier) + deleteProducerSnapshotsAsync(segmentsToDelete, params) + } + } + + private def deleteProducerSnapshotsAsync(segments: Iterable[LogSegment], + params: LoadLogParams): Unit = { + UnifiedLog.deleteProducerSnapshots(segments, + params.producerStateManager, + asyncDelete = true, + params.scheduler, + params.config, + params.logDirFailureChannel, + params.dir.getParent, + params.topicPartition) + } +} diff --git a/core/src/main/scala/kafka/log/LogManager.scala b/core/src/main/scala/kafka/log/LogManager.scala new file mode 100755 index 0000000..c4ee18c --- /dev/null +++ b/core/src/main/scala/kafka/log/LogManager.scala @@ -0,0 +1,1319 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package kafka.log + +import kafka.api.ApiVersion +import kafka.log.LogConfig.MessageFormatVersion + +import java.io._ +import java.nio.file.Files +import java.util.concurrent._ +import java.util.concurrent.atomic.AtomicInteger +import kafka.metrics.KafkaMetricsGroup +import kafka.server.checkpoints.OffsetCheckpointFile +import kafka.server.metadata.ConfigRepository +import kafka.server._ +import kafka.utils._ +import org.apache.kafka.common.{KafkaException, TopicPartition, Uuid} +import org.apache.kafka.common.utils.{KafkaThread, Time, Utils} +import org.apache.kafka.common.errors.{InconsistentTopicIdException, KafkaStorageException, LogDirNotFoundException} + +import scala.jdk.CollectionConverters._ +import scala.collection._ +import scala.collection.mutable.ArrayBuffer +import scala.util.{Failure, Success, Try} +import kafka.utils.Implicits._ + +import java.util.Properties +import scala.annotation.nowarn + +/** + * The entry point to the kafka log management subsystem. The log manager is responsible for log creation, retrieval, and cleaning. + * All read and write operations are delegated to the individual log instances. + * + * The log manager maintains logs in one or more directories. New logs are created in the data directory + * with the fewest logs. No attempt is made to move partitions after the fact or balance based on + * size or I/O rate. + * + * A background thread handles log retention by periodically truncating excess log segments. + */ +@threadsafe +class LogManager(logDirs: Seq[File], + initialOfflineDirs: Seq[File], + configRepository: ConfigRepository, + val initialDefaultConfig: LogConfig, + val cleanerConfig: CleanerConfig, + recoveryThreadsPerDataDir: Int, + val flushCheckMs: Long, + val flushRecoveryOffsetCheckpointMs: Long, + val flushStartOffsetCheckpointMs: Long, + val retentionCheckMs: Long, + val maxPidExpirationMs: Int, + interBrokerProtocolVersion: ApiVersion, + scheduler: Scheduler, + brokerTopicStats: BrokerTopicStats, + logDirFailureChannel: LogDirFailureChannel, + time: Time, + val keepPartitionMetadataFile: Boolean) extends Logging with KafkaMetricsGroup { + + import LogManager._ + + val LockFile = ".lock" + val InitialTaskDelayMs = 30 * 1000 + + private val logCreationOrDeletionLock = new Object + private val currentLogs = new Pool[TopicPartition, UnifiedLog]() + // Future logs are put in the directory with "-future" suffix. Future log is created when user wants to move replica + // from one log directory to another log directory on the same broker. The directory of the future log will be renamed + // to replace the current log of the partition after the future log catches up with the current log + private val futureLogs = new Pool[TopicPartition, UnifiedLog]() + // Each element in the queue contains the log object to be deleted and the time it is scheduled for deletion. + private val logsToBeDeleted = new LinkedBlockingQueue[(UnifiedLog, Long)]() + + private val _liveLogDirs: ConcurrentLinkedQueue[File] = createAndValidateLogDirs(logDirs, initialOfflineDirs) + @volatile private var _currentDefaultConfig = initialDefaultConfig + @volatile private var numRecoveryThreadsPerDataDir = recoveryThreadsPerDataDir + + // This map contains all partitions whose logs are getting loaded and initialized. If log configuration + // of these partitions get updated at the same time, the corresponding entry in this map is set to "true", + // which triggers a config reload after initialization is finished (to get the latest config value). 
+ // See KAFKA-8813 for more detail on the race condition + // Visible for testing + private[log] val partitionsInitializing = new ConcurrentHashMap[TopicPartition, Boolean]().asScala + + def reconfigureDefaultLogConfig(logConfig: LogConfig): Unit = { + this._currentDefaultConfig = logConfig + } + + def currentDefaultConfig: LogConfig = _currentDefaultConfig + + def liveLogDirs: Seq[File] = { + if (_liveLogDirs.size == logDirs.size) + logDirs + else + _liveLogDirs.asScala.toBuffer + } + + private val dirLocks = lockLogDirs(liveLogDirs) + @volatile private var recoveryPointCheckpoints = liveLogDirs.map(dir => + (dir, new OffsetCheckpointFile(new File(dir, RecoveryPointCheckpointFile), logDirFailureChannel))).toMap + @volatile private var logStartOffsetCheckpoints = liveLogDirs.map(dir => + (dir, new OffsetCheckpointFile(new File(dir, LogStartOffsetCheckpointFile), logDirFailureChannel))).toMap + + private val preferredLogDirs = new ConcurrentHashMap[TopicPartition, String]() + + private def offlineLogDirs: Iterable[File] = { + val logDirsSet = mutable.Set[File]() ++= logDirs + _liveLogDirs.forEach(dir => logDirsSet -= dir) + logDirsSet + } + + @volatile private var _cleaner: LogCleaner = _ + private[kafka] def cleaner: LogCleaner = _cleaner + + newGauge("OfflineLogDirectoryCount", () => offlineLogDirs.size) + + for (dir <- logDirs) { + newGauge("LogDirectoryOffline", + () => if (_liveLogDirs.contains(dir)) 0 else 1, + Map("logDirectory" -> dir.getAbsolutePath)) + } + + /** + * Create and check validity of the given directories that are not in the given offline directories, specifically: + *
+ * 1. Ensure that there are no duplicates in the directory list
+ * 2. Create each directory if it doesn't exist
+ * 3. Check that each path is a readable directory
            + */ + private def createAndValidateLogDirs(dirs: Seq[File], initialOfflineDirs: Seq[File]): ConcurrentLinkedQueue[File] = { + val liveLogDirs = new ConcurrentLinkedQueue[File]() + val canonicalPaths = mutable.HashSet.empty[String] + + for (dir <- dirs) { + try { + if (initialOfflineDirs.contains(dir)) + throw new IOException(s"Failed to load ${dir.getAbsolutePath} during broker startup") + + if (!dir.exists) { + info(s"Log directory ${dir.getAbsolutePath} not found, creating it.") + val created = dir.mkdirs() + if (!created) + throw new IOException(s"Failed to create data directory ${dir.getAbsolutePath}") + Utils.flushDir(dir.toPath.toAbsolutePath.normalize.getParent) + } + if (!dir.isDirectory || !dir.canRead) + throw new IOException(s"${dir.getAbsolutePath} is not a readable log directory.") + + // getCanonicalPath() throws IOException if a file system query fails or if the path is invalid (e.g. contains + // the Nul character). Since there's no easy way to distinguish between the two cases, we treat them the same + // and mark the log directory as offline. + if (!canonicalPaths.add(dir.getCanonicalPath)) + throw new KafkaException(s"Duplicate log directory found: ${dirs.mkString(", ")}") + + + liveLogDirs.add(dir) + } catch { + case e: IOException => + logDirFailureChannel.maybeAddOfflineLogDir(dir.getAbsolutePath, s"Failed to create or validate data directory ${dir.getAbsolutePath}", e) + } + } + if (liveLogDirs.isEmpty) { + fatal(s"Shutdown broker because none of the specified log dirs from ${dirs.mkString(", ")} can be created or validated") + Exit.halt(1) + } + + liveLogDirs + } + + def resizeRecoveryThreadPool(newSize: Int): Unit = { + info(s"Resizing recovery thread pool size for each data dir from $numRecoveryThreadsPerDataDir to $newSize") + numRecoveryThreadsPerDataDir = newSize + } + + /** + * The log directory failure handler. It will stop log cleaning in that directory. 
+ * + * @param dir the absolute path of the log directory + */ + def handleLogDirFailure(dir: String): Unit = { + warn(s"Stopping serving logs in dir $dir") + logCreationOrDeletionLock synchronized { + _liveLogDirs.remove(new File(dir)) + if (_liveLogDirs.isEmpty) { + fatal(s"Shutdown broker because all log dirs in ${logDirs.mkString(", ")} have failed") + Exit.halt(1) + } + + recoveryPointCheckpoints = recoveryPointCheckpoints.filter { case (file, _) => file.getAbsolutePath != dir } + logStartOffsetCheckpoints = logStartOffsetCheckpoints.filter { case (file, _) => file.getAbsolutePath != dir } + if (cleaner != null) + cleaner.handleLogDirFailure(dir) + + def removeOfflineLogs(logs: Pool[TopicPartition, UnifiedLog]): Iterable[TopicPartition] = { + val offlineTopicPartitions: Iterable[TopicPartition] = logs.collect { + case (tp, log) if log.parentDir == dir => tp + } + offlineTopicPartitions.foreach { topicPartition => { + val removedLog = removeLogAndMetrics(logs, topicPartition) + removedLog.foreach { + log => log.closeHandlers() + } + }} + + offlineTopicPartitions + } + + val offlineCurrentTopicPartitions = removeOfflineLogs(currentLogs) + val offlineFutureTopicPartitions = removeOfflineLogs(futureLogs) + + warn(s"Logs for partitions ${offlineCurrentTopicPartitions.mkString(",")} are offline and " + + s"logs for future partitions ${offlineFutureTopicPartitions.mkString(",")} are offline due to failure on log directory $dir") + dirLocks.filter(_.file.getParent == dir).foreach(dir => CoreUtils.swallow(dir.destroy(), this)) + } + } + + /** + * Lock all the given directories + */ + private def lockLogDirs(dirs: Seq[File]): Seq[FileLock] = { + dirs.flatMap { dir => + try { + val lock = new FileLock(new File(dir, LockFile)) + if (!lock.tryLock()) + throw new KafkaException("Failed to acquire lock on file .lock in " + lock.file.getParent + + ". 
A Kafka instance in another process or thread is using this directory.") + Some(lock) + } catch { + case e: IOException => + logDirFailureChannel.maybeAddOfflineLogDir(dir.getAbsolutePath, s"Disk error while locking directory $dir", e) + None + } + } + } + + private def addLogToBeDeleted(log: UnifiedLog): Unit = { + this.logsToBeDeleted.add((log, time.milliseconds())) + } + + // Only for testing + private[log] def hasLogsToBeDeleted: Boolean = !logsToBeDeleted.isEmpty + + private[log] def loadLog(logDir: File, + hadCleanShutdown: Boolean, + recoveryPoints: Map[TopicPartition, Long], + logStartOffsets: Map[TopicPartition, Long], + defaultConfig: LogConfig, + topicConfigOverrides: Map[String, LogConfig]): UnifiedLog = { + val topicPartition = UnifiedLog.parseTopicPartitionName(logDir) + val config = topicConfigOverrides.getOrElse(topicPartition.topic, defaultConfig) + val logRecoveryPoint = recoveryPoints.getOrElse(topicPartition, 0L) + val logStartOffset = logStartOffsets.getOrElse(topicPartition, 0L) + + val log = UnifiedLog( + dir = logDir, + config = config, + logStartOffset = logStartOffset, + recoveryPoint = logRecoveryPoint, + maxProducerIdExpirationMs = maxPidExpirationMs, + producerIdExpirationCheckIntervalMs = LogManager.ProducerIdExpirationCheckIntervalMs, + scheduler = scheduler, + time = time, + brokerTopicStats = brokerTopicStats, + logDirFailureChannel = logDirFailureChannel, + lastShutdownClean = hadCleanShutdown, + topicId = None, + keepPartitionMetadataFile = keepPartitionMetadataFile) + + if (logDir.getName.endsWith(UnifiedLog.DeleteDirSuffix)) { + addLogToBeDeleted(log) + } else { + val previous = { + if (log.isFuture) + this.futureLogs.put(topicPartition, log) + else + this.currentLogs.put(topicPartition, log) + } + if (previous != null) { + if (log.isFuture) + throw new IllegalStateException(s"Duplicate log directories found: ${log.dir.getAbsolutePath}, ${previous.dir.getAbsolutePath}") + else + throw new IllegalStateException(s"Duplicate log directories for $topicPartition are found in both ${log.dir.getAbsolutePath} " + + s"and ${previous.dir.getAbsolutePath}. It is likely because log directory failure happened while broker was " + + s"replacing current replica with future replica. Recover broker from this failure by manually deleting one of the two directories " + + s"for this partition. It is recommended to delete the partition in the log directory that is known to have failed recently.") + } + } + + log + } + + /** + * Recover and load all logs in the given data directories + */ + private[log] def loadLogs(defaultConfig: LogConfig, topicConfigOverrides: Map[String, LogConfig]): Unit = { + info(s"Loading logs from log dirs $liveLogDirs") + val startMs = time.hiResClockMs() + val threadPools = ArrayBuffer.empty[ExecutorService] + val offlineDirs = mutable.Set.empty[(String, IOException)] + val jobs = ArrayBuffer.empty[Seq[Future[_]]] + var numTotalLogs = 0 + + for (dir <- liveLogDirs) { + val logDirAbsolutePath = dir.getAbsolutePath + var hadCleanShutdown: Boolean = false + try { + val pool = Executors.newFixedThreadPool(numRecoveryThreadsPerDataDir, + KafkaThread.nonDaemon(s"log-recovery-$logDirAbsolutePath", _)) + threadPools.append(pool) + + val cleanShutdownFile = new File(dir, LogLoader.CleanShutdownFile) + if (cleanShutdownFile.exists) { + info(s"Skipping recovery for all logs in $logDirAbsolutePath since clean shutdown file was found") + // Cache the clean shutdown status and use that for rest of log loading workflow. 
Delete the CleanShutdownFile + // so that if broker crashes while loading the log, it is considered hard shutdown during the next boot up. KAFKA-10471 + Files.deleteIfExists(cleanShutdownFile.toPath) + hadCleanShutdown = true + } else { + // log recovery itself is being performed by `Log` class during initialization + info(s"Attempting recovery for all logs in $logDirAbsolutePath since no clean shutdown file was found") + } + + var recoveryPoints = Map[TopicPartition, Long]() + try { + recoveryPoints = this.recoveryPointCheckpoints(dir).read() + } catch { + case e: Exception => + warn(s"Error occurred while reading recovery-point-offset-checkpoint file of directory " + + s"$logDirAbsolutePath, resetting the recovery checkpoint to 0", e) + } + + var logStartOffsets = Map[TopicPartition, Long]() + try { + logStartOffsets = this.logStartOffsetCheckpoints(dir).read() + } catch { + case e: Exception => + warn(s"Error occurred while reading log-start-offset-checkpoint file of directory " + + s"$logDirAbsolutePath, resetting to the base offset of the first segment", e) + } + + val logsToLoad = Option(dir.listFiles).getOrElse(Array.empty).filter(logDir => + logDir.isDirectory && UnifiedLog.parseTopicPartitionName(logDir).topic != KafkaRaftServer.MetadataTopic) + val numLogsLoaded = new AtomicInteger(0) + numTotalLogs += logsToLoad.length + + val jobsForDir = logsToLoad.map { logDir => + val runnable: Runnable = () => { + try { + debug(s"Loading log $logDir") + + val logLoadStartMs = time.hiResClockMs() + val log = loadLog(logDir, hadCleanShutdown, recoveryPoints, logStartOffsets, + defaultConfig, topicConfigOverrides) + val logLoadDurationMs = time.hiResClockMs() - logLoadStartMs + val currentNumLoaded = numLogsLoaded.incrementAndGet() + + info(s"Completed load of $log with ${log.numberOfSegments} segments in ${logLoadDurationMs}ms " + + s"($currentNumLoaded/${logsToLoad.length} loaded in $logDirAbsolutePath)") + } catch { + case e: IOException => + offlineDirs.add((logDirAbsolutePath, e)) + error(s"Error while loading log dir $logDirAbsolutePath", e) + } + } + runnable + } + + jobs += jobsForDir.map(pool.submit) + } catch { + case e: IOException => + offlineDirs.add((logDirAbsolutePath, e)) + error(s"Error while loading log dir $logDirAbsolutePath", e) + } + } + + try { + for (dirJobs <- jobs) { + dirJobs.foreach(_.get) + } + + offlineDirs.foreach { case (dir, e) => + logDirFailureChannel.maybeAddOfflineLogDir(dir, s"Error while loading log dir $dir", e) + } + } catch { + case e: ExecutionException => + error(s"There was an error in one of the threads during logs loading: ${e.getCause}") + throw e.getCause + } finally { + threadPools.foreach(_.shutdown()) + } + + info(s"Loaded $numTotalLogs logs in ${time.hiResClockMs() - startMs}ms.") + } + + /** + * Start the background threads to flush logs and do log cleanup + */ + def startup(topicNames: Set[String]): Unit = { + // ensure consistency between default config and overrides + val defaultConfig = currentDefaultConfig + startupWithConfigOverrides(defaultConfig, fetchTopicConfigOverrides(defaultConfig, topicNames)) + } + + // visible for testing + @nowarn("cat=deprecation") + private[log] def fetchTopicConfigOverrides(defaultConfig: LogConfig, topicNames: Set[String]): Map[String, LogConfig] = { + val topicConfigOverrides = mutable.Map[String, LogConfig]() + val defaultProps = defaultConfig.originals() + topicNames.foreach { topicName => + var overrides = configRepository.topicConfig(topicName) + // save memory by only including configs for topics 
with overrides + if (!overrides.isEmpty) { + Option(overrides.getProperty(LogConfig.MessageFormatVersionProp)).foreach { versionString => + val messageFormatVersion = new MessageFormatVersion(versionString, interBrokerProtocolVersion.version) + if (messageFormatVersion.shouldIgnore) { + val copy = new Properties() + copy.putAll(overrides) + copy.remove(LogConfig.MessageFormatVersionProp) + overrides = copy + + if (messageFormatVersion.shouldWarn) + warn(messageFormatVersion.topicWarningMessage(topicName)) + } + } + + val logConfig = LogConfig.fromProps(defaultProps, overrides) + topicConfigOverrides(topicName) = logConfig + } + } + topicConfigOverrides + } + + private def fetchLogConfig(topicName: String): LogConfig = { + // ensure consistency between default config and overrides + val defaultConfig = currentDefaultConfig + fetchTopicConfigOverrides(defaultConfig, Set(topicName)).values.headOption.getOrElse(defaultConfig) + } + + // visible for testing + private[log] def startupWithConfigOverrides(defaultConfig: LogConfig, topicConfigOverrides: Map[String, LogConfig]): Unit = { + loadLogs(defaultConfig, topicConfigOverrides) // this could take a while if shutdown was not clean + + /* Schedule the cleanup task to delete old logs */ + if (scheduler != null) { + info("Starting log cleanup with a period of %d ms.".format(retentionCheckMs)) + scheduler.schedule("kafka-log-retention", + cleanupLogs _, + delay = InitialTaskDelayMs, + period = retentionCheckMs, + TimeUnit.MILLISECONDS) + info("Starting log flusher with a default period of %d ms.".format(flushCheckMs)) + scheduler.schedule("kafka-log-flusher", + flushDirtyLogs _, + delay = InitialTaskDelayMs, + period = flushCheckMs, + TimeUnit.MILLISECONDS) + scheduler.schedule("kafka-recovery-point-checkpoint", + checkpointLogRecoveryOffsets _, + delay = InitialTaskDelayMs, + period = flushRecoveryOffsetCheckpointMs, + TimeUnit.MILLISECONDS) + scheduler.schedule("kafka-log-start-offset-checkpoint", + checkpointLogStartOffsets _, + delay = InitialTaskDelayMs, + period = flushStartOffsetCheckpointMs, + TimeUnit.MILLISECONDS) + scheduler.schedule("kafka-delete-logs", // will be rescheduled after each delete logs with a dynamic period + deleteLogs _, + delay = InitialTaskDelayMs, + unit = TimeUnit.MILLISECONDS) + } + if (cleanerConfig.enableCleaner) { + _cleaner = new LogCleaner(cleanerConfig, liveLogDirs, currentLogs, logDirFailureChannel, time = time) + _cleaner.startup() + } + } + + /** + * Close all the logs + */ + def shutdown(): Unit = { + info("Shutting down.") + + removeMetric("OfflineLogDirectoryCount") + for (dir <- logDirs) { + removeMetric("LogDirectoryOffline", Map("logDirectory" -> dir.getAbsolutePath)) + } + + val threadPools = ArrayBuffer.empty[ExecutorService] + val jobs = mutable.Map.empty[File, Seq[Future[_]]] + + // stop the cleaner first + if (cleaner != null) { + CoreUtils.swallow(cleaner.shutdown(), this) + } + + val localLogsByDir = logsByDir + + // close logs in each dir + for (dir <- liveLogDirs) { + debug(s"Flushing and closing logs at $dir") + + val pool = Executors.newFixedThreadPool(numRecoveryThreadsPerDataDir, + KafkaThread.nonDaemon(s"log-closing-${dir.getAbsolutePath}", _)) + threadPools.append(pool) + + val logs = logsInDir(localLogsByDir, dir).values + + val jobsForDir = logs.map { log => + val runnable: Runnable = () => { + // flush the log to ensure latest possible recovery point + log.flush() + log.close() + } + runnable + } + + jobs(dir) = jobsForDir.map(pool.submit).toSeq + } + + try { + jobs.forKeyValue { 
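+ // For each log directory, wait for its flush-and-close jobs; the recovery-point and
+ // log-start-offset checkpoints and the clean shutdown marker are only written if all
+ // of the directory's jobs completed without error.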
(dir, dirJobs) => + if (waitForAllToComplete(dirJobs, + e => warn(s"There was an error in one of the threads during LogManager shutdown: ${e.getCause}"))) { + val logs = logsInDir(localLogsByDir, dir) + + // update the last flush point + debug(s"Updating recovery points at $dir") + checkpointRecoveryOffsetsInDir(dir, logs) + + debug(s"Updating log start offsets at $dir") + checkpointLogStartOffsetsInDir(dir, logs) + + // mark that the shutdown was clean by creating marker file + debug(s"Writing clean shutdown marker at $dir") + CoreUtils.swallow(Files.createFile(new File(dir, LogLoader.CleanShutdownFile).toPath), this) + } + } + } finally { + threadPools.foreach(_.shutdown()) + // regardless of whether the close succeeded, we need to unlock the data directories + dirLocks.foreach(_.destroy()) + } + + info("Shutdown complete.") + } + + /** + * Truncate the partition logs to the specified offsets and checkpoint the recovery point to this offset + * + * @param partitionOffsets Partition logs that need to be truncated + * @param isFuture True iff the truncation should be performed on the future log of the specified partitions + */ + def truncateTo(partitionOffsets: Map[TopicPartition, Long], isFuture: Boolean): Unit = { + val affectedLogs = ArrayBuffer.empty[UnifiedLog] + for ((topicPartition, truncateOffset) <- partitionOffsets) { + val log = { + if (isFuture) + futureLogs.get(topicPartition) + else + currentLogs.get(topicPartition) + } + // If the log does not exist, skip it + if (log != null) { + // May need to abort and pause the cleaning of the log, and resume after truncation is done. + val needToStopCleaner = truncateOffset < log.activeSegment.baseOffset + if (needToStopCleaner && !isFuture) + abortAndPauseCleaning(topicPartition) + try { + if (log.truncateTo(truncateOffset)) + affectedLogs += log + if (needToStopCleaner && !isFuture) + maybeTruncateCleanerCheckpointToActiveSegmentBaseOffset(log, topicPartition) + } finally { + if (needToStopCleaner && !isFuture) + resumeCleaning(topicPartition) + } + } + } + + for (dir <- affectedLogs.map(_.parentDirFile).distinct) { + checkpointRecoveryOffsetsInDir(dir) + } + } + + /** + * Delete all data in a partition and start the log at the new offset + * + * @param topicPartition The partition whose log needs to be truncated + * @param newOffset The new offset to start the log with + * @param isFuture True iff the truncation should be performed on the future log of the specified partition + */ + def truncateFullyAndStartAt(topicPartition: TopicPartition, newOffset: Long, isFuture: Boolean): Unit = { + val log = { + if (isFuture) + futureLogs.get(topicPartition) + else + currentLogs.get(topicPartition) + } + // If the log does not exist, skip it + if (log != null) { + // Abort and pause the cleaning of the log, and resume after truncation is done. + if (!isFuture) + abortAndPauseCleaning(topicPartition) + try { + log.truncateFullyAndStartAt(newOffset) + if (!isFuture) + maybeTruncateCleanerCheckpointToActiveSegmentBaseOffset(log, topicPartition) + } finally { + if (!isFuture) + resumeCleaning(topicPartition) + } + checkpointRecoveryOffsetsInDir(log.parentDirFile) + } + } + + /** + * Write out the current recovery point for all logs to a text file in the log directory + * to avoid recovering the whole log on startup. 
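+ * One recovery-point-offset-checkpoint file is written per live log directory.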
+ */ + def checkpointLogRecoveryOffsets(): Unit = { + val logsByDirCached = logsByDir + liveLogDirs.foreach { logDir => + val logsToCheckpoint = logsInDir(logsByDirCached, logDir) + checkpointRecoveryOffsetsInDir(logDir, logsToCheckpoint) + } + } + + /** + * Write out the current log start offset for all logs to a text file in the log directory + * to avoid exposing data that have been deleted by DeleteRecordsRequest + */ + def checkpointLogStartOffsets(): Unit = { + val logsByDirCached = logsByDir + liveLogDirs.foreach { logDir => + checkpointLogStartOffsetsInDir(logDir, logsInDir(logsByDirCached, logDir)) + } + } + + /** + * Checkpoint recovery offsets for all the logs in logDir. + * + * @param logDir the directory in which the logs to be checkpointed are + */ + // Only for testing + private[log] def checkpointRecoveryOffsetsInDir(logDir: File): Unit = { + checkpointRecoveryOffsetsInDir(logDir, logsInDir(logDir)) + } + + /** + * Checkpoint recovery offsets for all the provided logs. + * + * @param logDir the directory in which the logs are + * @param logsToCheckpoint the logs to be checkpointed + */ + private def checkpointRecoveryOffsetsInDir(logDir: File, logsToCheckpoint: Map[TopicPartition, UnifiedLog]): Unit = { + try { + recoveryPointCheckpoints.get(logDir).foreach { checkpoint => + val recoveryOffsets = logsToCheckpoint.map { case (tp, log) => tp -> log.recoveryPoint } + // checkpoint.write calls Utils.atomicMoveWithFallback, which flushes the parent + // directory and guarantees crash consistency. + checkpoint.write(recoveryOffsets) + } + } catch { + case e: KafkaStorageException => + error(s"Disk error while writing recovery offsets checkpoint in directory $logDir: ${e.getMessage}") + case e: IOException => + logDirFailureChannel.maybeAddOfflineLogDir(logDir.getAbsolutePath, + s"Disk error while writing recovery offsets checkpoint in directory $logDir: ${e.getMessage}", e) + } + } + + /** + * Checkpoint log start offsets for all the provided logs in the provided directory. + * + * @param logDir the directory in which logs are checkpointed + * @param logsToCheckpoint the logs to be checkpointed + */ + private def checkpointLogStartOffsetsInDir(logDir: File, logsToCheckpoint: Map[TopicPartition, UnifiedLog]): Unit = { + try { + logStartOffsetCheckpoints.get(logDir).foreach { checkpoint => + val logStartOffsets = logsToCheckpoint.collect { + case (tp, log) if log.logStartOffset > log.logSegments.head.baseOffset => tp -> log.logStartOffset + } + checkpoint.write(logStartOffsets) + } + } catch { + case e: KafkaStorageException => + error(s"Disk error while writing log start offsets checkpoint in directory $logDir: ${e.getMessage}") + } + } + + // The logDir should be an absolute path + def maybeUpdatePreferredLogDir(topicPartition: TopicPartition, logDir: String): Unit = { + // Do not cache the preferred log directory if either the current log or the future log for this partition exists in the specified logDir + if (!getLog(topicPartition).exists(_.parentDir == logDir) && + !getLog(topicPartition, isFuture = true).exists(_.parentDir == logDir)) + preferredLogDirs.put(topicPartition, logDir) + } + + /** + * Abort and pause cleaning of the provided partition and log a message about it. 
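+ * This is a no-op when the log cleaner is not enabled.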
+ */ + def abortAndPauseCleaning(topicPartition: TopicPartition): Unit = { + if (cleaner != null) { + cleaner.abortAndPauseCleaning(topicPartition) + info(s"The cleaning for partition $topicPartition is aborted and paused") + } + } + + /** + * Abort cleaning of the provided partition and log a message about it. + */ + def abortCleaning(topicPartition: TopicPartition): Unit = { + if (cleaner != null) { + cleaner.abortCleaning(topicPartition) + info(s"The cleaning for partition $topicPartition is aborted") + } + } + + /** + * Resume cleaning of the provided partition and log a message about it. + */ + private def resumeCleaning(topicPartition: TopicPartition): Unit = { + if (cleaner != null) { + cleaner.resumeCleaning(Seq(topicPartition)) + info(s"Cleaning for partition $topicPartition is resumed") + } + } + + /** + * Truncate the cleaner's checkpoint to the based offset of the active segment of + * the provided log. + */ + private def maybeTruncateCleanerCheckpointToActiveSegmentBaseOffset(log: UnifiedLog, topicPartition: TopicPartition): Unit = { + if (cleaner != null) { + cleaner.maybeTruncateCheckpoint(log.parentDirFile, topicPartition, log.activeSegment.baseOffset) + } + } + + /** + * Get the log if it exists, otherwise return None + * + * @param topicPartition the partition of the log + * @param isFuture True iff the future log of the specified partition should be returned + */ + def getLog(topicPartition: TopicPartition, isFuture: Boolean = false): Option[UnifiedLog] = { + if (isFuture) + Option(futureLogs.get(topicPartition)) + else + Option(currentLogs.get(topicPartition)) + } + + /** + * Method to indicate that logs are getting initialized for the partition passed in as argument. + * This method should always be followed by [[kafka.log.LogManager#finishedInitializingLog]] to indicate that log + * initialization is done. + */ + def initializingLog(topicPartition: TopicPartition): Unit = { + partitionsInitializing(topicPartition) = false + } + + /** + * Mark the partition configuration for all partitions that are getting initialized for topic + * as dirty. That will result in reloading of configuration once initialization is done. + */ + def topicConfigUpdated(topic: String): Unit = { + partitionsInitializing.keys.filter(_.topic() == topic).foreach { + topicPartition => partitionsInitializing.replace(topicPartition, false, true) + } + } + + /** + * Update the configuration of the provided topic. + */ + def updateTopicConfig(topic: String, + newTopicConfig: Properties): Unit = { + topicConfigUpdated(topic) + val logs = logsByTopic(topic) + if (logs.nonEmpty) { + // Combine the default properties with the overrides in zk to create the new LogConfig + val newLogConfig = LogConfig.fromProps(currentDefaultConfig.originals, newTopicConfig) + logs.foreach { log => + val oldLogConfig = log.updateConfig(newLogConfig) + if (oldLogConfig.compact && !newLogConfig.compact) { + abortCleaning(log.topicPartition) + } + } + } + } + + /** + * Mark all in progress partitions having dirty configuration if broker configuration is updated. + */ + def brokerConfigUpdated(): Unit = { + partitionsInitializing.keys.foreach { + topicPartition => partitionsInitializing.replace(topicPartition, false, true) + } + } + + /** + * Method to indicate that the log initialization for the partition passed in as argument is + * finished. This method should follow a call to [[kafka.log.LogManager#initializingLog]]. 
+ * + * It will retrieve the topic configs a second time if they were updated while the + * relevant log was being loaded. + */ + def finishedInitializingLog(topicPartition: TopicPartition, + maybeLog: Option[UnifiedLog]): Unit = { + val removedValue = partitionsInitializing.remove(topicPartition) + if (removedValue.contains(true)) + maybeLog.foreach(_.updateConfig(fetchLogConfig(topicPartition.topic))) + } + + /** + * If the log already exists, just return a copy of the existing log + * Otherwise if isNew=true or if there is no offline log directory, create a log for the given topic and the given partition + * Otherwise throw KafkaStorageException + * + * @param topicPartition The partition whose log needs to be returned or created + * @param isNew Whether the replica should have existed on the broker or not + * @param isFuture True if the future log of the specified partition should be returned or created + * @param topicId The topic ID of the partition's topic + * @throws KafkaStorageException if isNew=false, log is not found in the cache and there is offline log directory on the broker + * @throws InconsistentTopicIdException if the topic ID in the log does not match the topic ID provided + */ + def getOrCreateLog(topicPartition: TopicPartition, isNew: Boolean = false, isFuture: Boolean = false, topicId: Option[Uuid]): UnifiedLog = { + logCreationOrDeletionLock synchronized { + val log = getLog(topicPartition, isFuture).getOrElse { + // create the log if it has not already been created in another thread + if (!isNew && offlineLogDirs.nonEmpty) + throw new KafkaStorageException(s"Can not create log for $topicPartition because log directories ${offlineLogDirs.mkString(",")} are offline") + + val logDirs: List[File] = { + val preferredLogDir = preferredLogDirs.get(topicPartition) + + if (isFuture) { + if (preferredLogDir == null) + throw new IllegalStateException(s"Can not create the future log for $topicPartition without having a preferred log directory") + else if (getLog(topicPartition).get.parentDir == preferredLogDir) + throw new IllegalStateException(s"Can not create the future log for $topicPartition in the current log directory of this partition") + } + + if (preferredLogDir != null) + List(new File(preferredLogDir)) + else + nextLogDirs() + } + + val logDirName = { + if (isFuture) + UnifiedLog.logFutureDirName(topicPartition) + else + UnifiedLog.logDirName(topicPartition) + } + + val logDir = logDirs + .iterator // to prevent actually mapping the whole list, lazy map + .map(createLogDirectory(_, logDirName)) + .find(_.isSuccess) + .getOrElse(Failure(new KafkaStorageException("No log directories available. 
Tried " + logDirs.map(_.getAbsolutePath).mkString(", ")))) + .get // If Failure, will throw + + val config = fetchLogConfig(topicPartition.topic) + val log = UnifiedLog( + dir = logDir, + config = config, + logStartOffset = 0L, + recoveryPoint = 0L, + maxProducerIdExpirationMs = maxPidExpirationMs, + producerIdExpirationCheckIntervalMs = LogManager.ProducerIdExpirationCheckIntervalMs, + scheduler = scheduler, + time = time, + brokerTopicStats = brokerTopicStats, + logDirFailureChannel = logDirFailureChannel, + topicId = topicId, + keepPartitionMetadataFile = keepPartitionMetadataFile) + + if (isFuture) + futureLogs.put(topicPartition, log) + else + currentLogs.put(topicPartition, log) + + info(s"Created log for partition $topicPartition in $logDir with properties ${config.overriddenConfigsAsLoggableString}") + // Remove the preferred log dir since it has already been satisfied + preferredLogDirs.remove(topicPartition) + + log + } + // When running a ZK controller, we may get a log that does not have a topic ID. Assign it here. + if (log.topicId.isEmpty) { + topicId.foreach(log.assignTopicId) + } + + // Ensure topic IDs are consistent + topicId.foreach { topicId => + log.topicId.foreach { logTopicId => + if (topicId != logTopicId) + throw new InconsistentTopicIdException(s"Tried to assign topic ID $topicId to log for topic partition $topicPartition," + + s"but log already contained topic ID $logTopicId") + } + } + log + } + } + + private[log] def createLogDirectory(logDir: File, logDirName: String): Try[File] = { + val logDirPath = logDir.getAbsolutePath + if (isLogDirOnline(logDirPath)) { + val dir = new File(logDirPath, logDirName) + try { + Files.createDirectories(dir.toPath) + Success(dir) + } catch { + case e: IOException => + val msg = s"Error while creating log for $logDirName in dir $logDirPath" + logDirFailureChannel.maybeAddOfflineLogDir(logDirPath, msg, e) + warn(msg, e) + Failure(new KafkaStorageException(msg, e)) + } + } else { + Failure(new KafkaStorageException(s"Can not create log $logDirName because log directory $logDirPath is offline")) + } + } + + /** + * Delete logs marked for deletion. Delete all logs for which `currentDefaultConfig.fileDeleteDelayMs` + * has elapsed after the delete was scheduled. Logs for which this interval has not yet elapsed will be + * considered for deletion in the next iteration of `deleteLogs`. The next iteration will be executed + * after the remaining time for the first log that is not deleted. If there are no more `logsToBeDeleted`, + * `deleteLogs` will be executed after `currentDefaultConfig.fileDeleteDelayMs`. 
+ */ + private def deleteLogs(): Unit = { + var nextDelayMs = 0L + val fileDeleteDelayMs = currentDefaultConfig.fileDeleteDelayMs + try { + def nextDeleteDelayMs: Long = { + if (!logsToBeDeleted.isEmpty) { + val (_, scheduleTimeMs) = logsToBeDeleted.peek() + scheduleTimeMs + fileDeleteDelayMs - time.milliseconds() + } else + fileDeleteDelayMs + } + + while ({nextDelayMs = nextDeleteDelayMs; nextDelayMs <= 0}) { + val (removedLog, _) = logsToBeDeleted.take() + if (removedLog != null) { + try { + removedLog.delete() + info(s"Deleted log for partition ${removedLog.topicPartition} in ${removedLog.dir.getAbsolutePath}.") + } catch { + case e: KafkaStorageException => + error(s"Exception while deleting $removedLog in dir ${removedLog.parentDir}.", e) + } + } + } + } catch { + case e: Throwable => + error(s"Exception in kafka-delete-logs thread.", e) + } finally { + try { + scheduler.schedule("kafka-delete-logs", + deleteLogs _, + delay = nextDelayMs, + unit = TimeUnit.MILLISECONDS) + } catch { + case e: Throwable => + if (scheduler.isStarted) { + // No errors should occur unless scheduler has been shutdown + error(s"Failed to schedule next delete in kafka-delete-logs thread", e) + } + } + } + } + + /** + * Mark the partition directory in the source log directory for deletion and + * rename the future log of this partition in the destination log directory to be the current log + * + * @param topicPartition TopicPartition that needs to be swapped + */ + def replaceCurrentWithFutureLog(topicPartition: TopicPartition): Unit = { + logCreationOrDeletionLock synchronized { + val sourceLog = currentLogs.get(topicPartition) + val destLog = futureLogs.get(topicPartition) + + info(s"Attempting to replace current log $sourceLog with $destLog for $topicPartition") + if (sourceLog == null) + throw new KafkaStorageException(s"The current replica for $topicPartition is offline") + if (destLog == null) + throw new KafkaStorageException(s"The future replica for $topicPartition is offline") + + destLog.renameDir(UnifiedLog.logDirName(topicPartition)) + destLog.updateHighWatermark(sourceLog.highWatermark) + + // Now that future replica has been successfully renamed to be the current replica + // Update the cached map and log cleaner as appropriate. + futureLogs.remove(topicPartition) + currentLogs.put(topicPartition, destLog) + if (cleaner != null) { + cleaner.alterCheckpointDir(topicPartition, sourceLog.parentDirFile, destLog.parentDirFile) + resumeCleaning(topicPartition) + } + + try { + sourceLog.renameDir(UnifiedLog.logDeleteDirName(topicPartition)) + // Now that replica in source log directory has been successfully renamed for deletion. + // Close the log, update checkpoint files, and enqueue this log to be deleted. + sourceLog.close() + val logDir = sourceLog.parentDirFile + val logsToCheckpoint = logsInDir(logDir) + checkpointRecoveryOffsetsInDir(logDir, logsToCheckpoint) + checkpointLogStartOffsetsInDir(logDir, logsToCheckpoint) + sourceLog.removeLogMetrics() + addLogToBeDeleted(sourceLog) + } catch { + case e: KafkaStorageException => + // If sourceLog's log directory is offline, we need close its handlers here. 
+ // handleLogDirFailure() will not close handlers of sourceLog because it has been removed from currentLogs map + sourceLog.closeHandlers() + sourceLog.removeLogMetrics() + throw e + } + + info(s"The current replica is successfully replaced with the future replica for $topicPartition") + } + } + + /** + * Rename the directory of the given topic-partition "logdir" as "logdir.uuid.delete" and + * add it in the queue for deletion. + * + * @param topicPartition TopicPartition that needs to be deleted + * @param isFuture True iff the future log of the specified partition should be deleted + * @param checkpoint True if checkpoints must be written + * @return the removed log + */ + def asyncDelete(topicPartition: TopicPartition, + isFuture: Boolean = false, + checkpoint: Boolean = true): Option[UnifiedLog] = { + val removedLog: Option[UnifiedLog] = logCreationOrDeletionLock synchronized { + removeLogAndMetrics(if (isFuture) futureLogs else currentLogs, topicPartition) + } + removedLog match { + case Some(removedLog) => + // We need to wait until there is no more cleaning task on the log to be deleted before actually deleting it. + if (cleaner != null && !isFuture) { + cleaner.abortCleaning(topicPartition) + if (checkpoint) { + cleaner.updateCheckpoints(removedLog.parentDirFile, partitionToRemove = Option(topicPartition)) + } + } + removedLog.renameDir(UnifiedLog.logDeleteDirName(topicPartition)) + if (checkpoint) { + val logDir = removedLog.parentDirFile + val logsToCheckpoint = logsInDir(logDir) + checkpointRecoveryOffsetsInDir(logDir, logsToCheckpoint) + checkpointLogStartOffsetsInDir(logDir, logsToCheckpoint) + } + addLogToBeDeleted(removedLog) + info(s"Log for partition ${removedLog.topicPartition} is renamed to ${removedLog.dir.getAbsolutePath} and is scheduled for deletion") + + case None => + if (offlineLogDirs.nonEmpty) { + throw new KafkaStorageException(s"Failed to delete log for ${if (isFuture) "future" else ""} $topicPartition because it may be in one of the offline directories ${offlineLogDirs.mkString(",")}") + } + } + + removedLog + } + + /** + * Rename the directories of the given topic-partitions and add them in the queue for + * deletion. Checkpoints are updated once all the directories have been renamed. + * + * @param topicPartitions The set of topic-partitions to delete asynchronously + * @param errorHandler The error handler that will be called when a exception for a particular + * topic-partition is raised + */ + def asyncDelete(topicPartitions: Set[TopicPartition], + errorHandler: (TopicPartition, Throwable) => Unit): Unit = { + val logDirs = mutable.Set.empty[File] + + topicPartitions.foreach { topicPartition => + try { + getLog(topicPartition).foreach { log => + logDirs += log.parentDirFile + asyncDelete(topicPartition, checkpoint = false) + } + getLog(topicPartition, isFuture = true).foreach { log => + logDirs += log.parentDirFile + asyncDelete(topicPartition, isFuture = true, checkpoint = false) + } + } catch { + case e: Throwable => errorHandler(topicPartition, e) + } + } + + val logsByDirCached = logsByDir + logDirs.foreach { logDir => + if (cleaner != null) cleaner.updateCheckpoints(logDir) + val logsToCheckpoint = logsInDir(logsByDirCached, logDir) + checkpointRecoveryOffsetsInDir(logDir, logsToCheckpoint) + checkpointLogStartOffsetsInDir(logDir, logsToCheckpoint) + } + } + + /** + * Provides the full ordered list of suggested directories for the next partition. 
+ * Currently this is done by calculating the number of partitions in each directory and then sorting the + * data directories by fewest partitions. + */ + private def nextLogDirs(): List[File] = { + if(_liveLogDirs.size == 1) { + List(_liveLogDirs.peek()) + } else { + // count the number of logs in each parent directory (including 0 for empty directories + val logCounts = allLogs.groupBy(_.parentDir).map { case (parent, logs) => parent -> logs.size } + val zeros = _liveLogDirs.asScala.map(dir => (dir.getPath, 0)).toMap + val dirCounts = (zeros ++ logCounts).toBuffer + + // choose the directory with the least logs in it + dirCounts.sortBy(_._2).map { + case (path: String, _: Int) => new File(path) + }.toList + } + } + + /** + * Delete any eligible logs. Return the number of segments deleted. + * Only consider logs that are not compacted. + */ + def cleanupLogs(): Unit = { + debug("Beginning log cleanup...") + var total = 0 + val startMs = time.milliseconds + + // clean current logs. + val deletableLogs = { + if (cleaner != null) { + // prevent cleaner from working on same partitions when changing cleanup policy + cleaner.pauseCleaningForNonCompactedPartitions() + } else { + currentLogs.filter { + case (_, log) => !log.config.compact + } + } + } + + try { + deletableLogs.foreach { + case (topicPartition, log) => + debug(s"Garbage collecting '${log.name}'") + total += log.deleteOldSegments() + + val futureLog = futureLogs.get(topicPartition) + if (futureLog != null) { + // clean future logs + debug(s"Garbage collecting future log '${futureLog.name}'") + total += futureLog.deleteOldSegments() + } + } + } finally { + if (cleaner != null) { + cleaner.resumeCleaning(deletableLogs.map(_._1)) + } + } + + debug(s"Log cleanup completed. $total files deleted in " + + (time.milliseconds - startMs) / 1000 + " seconds") + } + + /** + * Get all the partition logs + */ + def allLogs: Iterable[UnifiedLog] = currentLogs.values ++ futureLogs.values + + def logsByTopic(topic: String): Seq[UnifiedLog] = { + (currentLogs.toList ++ futureLogs.toList).collect { + case (topicPartition, log) if topicPartition.topic == topic => log + } + } + + /** + * Map of log dir to logs by topic and partitions in that dir + */ + private def logsByDir: Map[String, Map[TopicPartition, UnifiedLog]] = { + // This code is called often by checkpoint processes and is written in a way that reduces + // allocations and CPU with many topic partitions. 
+ // When changing this code please measure the changes with org.apache.kafka.jmh.server.CheckpointBench + val byDir = new mutable.AnyRefMap[String, mutable.AnyRefMap[TopicPartition, UnifiedLog]]() + def addToDir(tp: TopicPartition, log: UnifiedLog): Unit = { + byDir.getOrElseUpdate(log.parentDir, new mutable.AnyRefMap[TopicPartition, UnifiedLog]()).put(tp, log) + } + currentLogs.foreachEntry(addToDir) + futureLogs.foreachEntry(addToDir) + byDir + } + + private def logsInDir(dir: File): Map[TopicPartition, UnifiedLog] = { + logsByDir.getOrElse(dir.getAbsolutePath, Map.empty) + } + + private def logsInDir(cachedLogsByDir: Map[String, Map[TopicPartition, UnifiedLog]], + dir: File): Map[TopicPartition, UnifiedLog] = { + cachedLogsByDir.getOrElse(dir.getAbsolutePath, Map.empty) + } + + // logDir should be an absolute path + def isLogDirOnline(logDir: String): Boolean = { + // The logDir should be an absolute path + if (!logDirs.exists(_.getAbsolutePath == logDir)) + throw new LogDirNotFoundException(s"Log dir $logDir is not found in the config.") + + _liveLogDirs.contains(new File(logDir)) + } + + /** + * Flush any log which has exceeded its flush interval and has unwritten messages. + */ + private def flushDirtyLogs(): Unit = { + debug("Checking for dirty logs to flush...") + + for ((topicPartition, log) <- currentLogs.toList ++ futureLogs.toList) { + try { + val timeSinceLastFlush = time.milliseconds - log.lastFlushTime + debug(s"Checking if flush is needed on ${topicPartition.topic} flush interval ${log.config.flushMs}" + + s" last flushed ${log.lastFlushTime} time since last flush: $timeSinceLastFlush") + if(timeSinceLastFlush >= log.config.flushMs) + log.flush() + } catch { + case e: Throwable => + error(s"Error flushing topic ${topicPartition.topic}", e) + } + } + } + + private def removeLogAndMetrics(logs: Pool[TopicPartition, UnifiedLog], tp: TopicPartition): Option[UnifiedLog] = { + val removedLog = logs.remove(tp) + if (removedLog != null) { + removedLog.removeLogMetrics() + Some(removedLog) + } else { + None + } + } +} + +object LogManager { + + /** + * Wait all jobs to complete + * @param jobs jobs + * @param callback this will be called to handle the exception caused by each Future#get + * @return true if all pass. 
Otherwise, false + */ + private[log] def waitForAllToComplete(jobs: Seq[Future[_]], callback: Throwable => Unit): Boolean = { + jobs.count(future => Try(future.get) match { + case Success(_) => false + case Failure(e) => + callback(e) + true + }) == 0 + } + + val RecoveryPointCheckpointFile = "recovery-point-offset-checkpoint" + val LogStartOffsetCheckpointFile = "log-start-offset-checkpoint" + val ProducerIdExpirationCheckIntervalMs = 10 * 60 * 1000 + + def apply(config: KafkaConfig, + initialOfflineDirs: Seq[String], + configRepository: ConfigRepository, + kafkaScheduler: KafkaScheduler, + time: Time, + brokerTopicStats: BrokerTopicStats, + logDirFailureChannel: LogDirFailureChannel, + keepPartitionMetadataFile: Boolean): LogManager = { + val defaultProps = LogConfig.extractLogConfigMap(config) + + LogConfig.validateValues(defaultProps) + val defaultLogConfig = LogConfig(defaultProps) + + val cleanerConfig = LogCleaner.cleanerConfig(config) + + new LogManager(logDirs = config.logDirs.map(new File(_).getAbsoluteFile), + initialOfflineDirs = initialOfflineDirs.map(new File(_).getAbsoluteFile), + configRepository = configRepository, + initialDefaultConfig = defaultLogConfig, + cleanerConfig = cleanerConfig, + recoveryThreadsPerDataDir = config.numRecoveryThreadsPerDataDir, + flushCheckMs = config.logFlushSchedulerIntervalMs, + flushRecoveryOffsetCheckpointMs = config.logFlushOffsetCheckpointIntervalMs, + flushStartOffsetCheckpointMs = config.logFlushStartOffsetCheckpointIntervalMs, + retentionCheckMs = config.logCleanupIntervalMs, + maxPidExpirationMs = config.transactionalIdExpirationMs, + scheduler = kafkaScheduler, + brokerTopicStats = brokerTopicStats, + logDirFailureChannel = logDirFailureChannel, + time = time, + keepPartitionMetadataFile = keepPartitionMetadataFile, + interBrokerProtocolVersion = config.interBrokerProtocolVersion) + } + +} diff --git a/core/src/main/scala/kafka/log/LogSegment.scala b/core/src/main/scala/kafka/log/LogSegment.scala new file mode 100755 index 0000000..7daf9c4 --- /dev/null +++ b/core/src/main/scala/kafka/log/LogSegment.scala @@ -0,0 +1,691 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package kafka.log + +import java.io.{File, IOException} +import java.nio.file.{Files, NoSuchFileException} +import java.nio.file.attribute.FileTime +import java.util.concurrent.TimeUnit +import kafka.common.LogSegmentOffsetOverflowException +import kafka.metrics.{KafkaMetricsGroup, KafkaTimer} +import kafka.server.epoch.LeaderEpochFileCache +import kafka.server.{FetchDataInfo, LogOffsetMetadata} +import kafka.utils._ +import org.apache.kafka.common.InvalidRecordException +import org.apache.kafka.common.errors.CorruptRecordException +import org.apache.kafka.common.record.FileRecords.{LogOffsetPosition, TimestampAndOffset} +import org.apache.kafka.common.record._ +import org.apache.kafka.common.utils.{BufferSupplier, Time} + +import scala.jdk.CollectionConverters._ +import scala.math._ + +/** + * A segment of the log. Each segment has two components: a log and an index. The log is a FileRecords containing + * the actual messages. The index is an OffsetIndex that maps from logical offsets to physical file positions. Each + * segment has a base offset which is an offset <= the least offset of any message in this segment and > any offset in + * any previous segment. + * + * A segment with a base offset of [base_offset] would be stored in two files, a [base_offset].index and a [base_offset].log file. + * + * @param log The file records containing log entries + * @param lazyOffsetIndex The offset index + * @param lazyTimeIndex The timestamp index + * @param txnIndex The transaction index + * @param baseOffset A lower bound on the offsets in this segment + * @param indexIntervalBytes The approximate number of bytes between entries in the index + * @param rollJitterMs The maximum random jitter subtracted from the scheduled segment roll time + * @param time The time instance + */ +@nonthreadsafe +class LogSegment private[log] (val log: FileRecords, + val lazyOffsetIndex: LazyIndex[OffsetIndex], + val lazyTimeIndex: LazyIndex[TimeIndex], + val txnIndex: TransactionIndex, + val baseOffset: Long, + val indexIntervalBytes: Int, + val rollJitterMs: Long, + val time: Time) extends Logging { + + def offsetIndex: OffsetIndex = lazyOffsetIndex.get + + def timeIndex: TimeIndex = lazyTimeIndex.get + + def shouldRoll(rollParams: RollParams): Boolean = { + val reachedRollMs = timeWaitedForRoll(rollParams.now, rollParams.maxTimestampInMessages) > rollParams.maxSegmentMs - rollJitterMs + size > rollParams.maxSegmentBytes - rollParams.messagesSize || + (size > 0 && reachedRollMs) || + offsetIndex.isFull || timeIndex.isFull || !canConvertToRelativeOffset(rollParams.maxOffsetInMessages) + } + + def resizeIndexes(size: Int): Unit = { + offsetIndex.resize(size) + timeIndex.resize(size) + } + + def sanityCheck(timeIndexFileNewlyCreated: Boolean): Unit = { + if (lazyOffsetIndex.file.exists) { + // Resize the time index file to 0 if it is newly created. + if (timeIndexFileNewlyCreated) + timeIndex.resize(0) + // Sanity checks for time index and offset index are skipped because + // we will recover the segments above the recovery point in recoverLog() + // in any case so sanity checking them here is redundant. 
+ txnIndex.sanityCheck() + } + else throw new NoSuchFileException(s"Offset index file ${lazyOffsetIndex.file.getAbsolutePath} does not exist") + } + + private var created = time.milliseconds + + /* the number of bytes since we last added an entry in the offset index */ + private var bytesSinceLastIndexEntry = 0 + + // The timestamp we used for time based log rolling and for ensuring max compaction delay + // volatile for LogCleaner to see the update + @volatile private var rollingBasedTimestamp: Option[Long] = None + + /* The maximum timestamp and offset we see so far */ + @volatile private var _maxTimestampAndOffsetSoFar: TimestampOffset = TimestampOffset.Unknown + def maxTimestampAndOffsetSoFar_= (timestampOffset: TimestampOffset): Unit = _maxTimestampAndOffsetSoFar = timestampOffset + def maxTimestampAndOffsetSoFar: TimestampOffset = { + if (_maxTimestampAndOffsetSoFar == TimestampOffset.Unknown) + _maxTimestampAndOffsetSoFar = timeIndex.lastEntry + _maxTimestampAndOffsetSoFar + } + + /* The maximum timestamp we see so far */ + def maxTimestampSoFar: Long = { + maxTimestampAndOffsetSoFar.timestamp + } + + def offsetOfMaxTimestampSoFar: Long = { + maxTimestampAndOffsetSoFar.offset + } + + /* Return the size in bytes of this log segment */ + def size: Int = log.sizeInBytes() + + /** + * checks that the argument offset can be represented as an integer offset relative to the baseOffset. + */ + def canConvertToRelativeOffset(offset: Long): Boolean = { + offsetIndex.canAppendOffset(offset) + } + + /** + * Append the given messages starting with the given offset. Add + * an entry to the index if needed. + * + * It is assumed this method is being called from within a lock. + * + * @param largestOffset The last offset in the message set + * @param largestTimestamp The largest timestamp in the message set. + * @param shallowOffsetOfMaxTimestamp The offset of the message that has the largest timestamp in the messages to append. + * @param records The log entries to append. + * @return the physical position in the file of the appended records + * @throws LogSegmentOffsetOverflowException if the largest offset causes index offset overflow + */ + @nonthreadsafe + def append(largestOffset: Long, + largestTimestamp: Long, + shallowOffsetOfMaxTimestamp: Long, + records: MemoryRecords): Unit = { + if (records.sizeInBytes > 0) { + trace(s"Inserting ${records.sizeInBytes} bytes at end offset $largestOffset at position ${log.sizeInBytes} " + + s"with largest timestamp $largestTimestamp at shallow offset $shallowOffsetOfMaxTimestamp") + val physicalPosition = log.sizeInBytes() + if (physicalPosition == 0) + rollingBasedTimestamp = Some(largestTimestamp) + + ensureOffsetInRange(largestOffset) + + // append the messages + val appendedBytes = log.append(records) + trace(s"Appended $appendedBytes to ${log.file} at end offset $largestOffset") + // Update the in memory max timestamp and corresponding offset. 
+ if (largestTimestamp > maxTimestampSoFar) { + maxTimestampAndOffsetSoFar = TimestampOffset(largestTimestamp, shallowOffsetOfMaxTimestamp) + } + // append an entry to the index (if needed) + if (bytesSinceLastIndexEntry > indexIntervalBytes) { + offsetIndex.append(largestOffset, physicalPosition) + timeIndex.maybeAppend(maxTimestampSoFar, offsetOfMaxTimestampSoFar) + bytesSinceLastIndexEntry = 0 + } + bytesSinceLastIndexEntry += records.sizeInBytes + } + } + + private def ensureOffsetInRange(offset: Long): Unit = { + if (!canConvertToRelativeOffset(offset)) + throw new LogSegmentOffsetOverflowException(this, offset) + } + + private def appendChunkFromFile(records: FileRecords, position: Int, bufferSupplier: BufferSupplier): Int = { + var bytesToAppend = 0 + var maxTimestamp = Long.MinValue + var offsetOfMaxTimestamp = Long.MinValue + var maxOffset = Long.MinValue + var readBuffer = bufferSupplier.get(1024 * 1024) + + def canAppend(batch: RecordBatch) = + canConvertToRelativeOffset(batch.lastOffset) && + (bytesToAppend == 0 || bytesToAppend + batch.sizeInBytes < readBuffer.capacity) + + // find all batches that are valid to be appended to the current log segment and + // determine the maximum offset and timestamp + val nextBatches = records.batchesFrom(position).asScala.iterator + for (batch <- nextBatches.takeWhile(canAppend)) { + if (batch.maxTimestamp > maxTimestamp) { + maxTimestamp = batch.maxTimestamp + offsetOfMaxTimestamp = batch.lastOffset + } + maxOffset = batch.lastOffset + bytesToAppend += batch.sizeInBytes + } + + if (bytesToAppend > 0) { + // Grow buffer if needed to ensure we copy at least one batch + if (readBuffer.capacity < bytesToAppend) + readBuffer = bufferSupplier.get(bytesToAppend) + + readBuffer.limit(bytesToAppend) + records.readInto(readBuffer, position) + + append(maxOffset, maxTimestamp, offsetOfMaxTimestamp, MemoryRecords.readableRecords(readBuffer)) + } + + bufferSupplier.release(readBuffer) + bytesToAppend + } + + /** + * Append records from a file beginning at the given position until either the end of the file + * is reached or an offset is found which is too large to convert to a relative offset for the indexes. 
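+ * Records are copied in chunks through a growable buffer that is resized, if necessary, so
+ * that it always holds at least one complete batch.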
+ * + * @return the number of bytes appended to the log (may be less than the size of the input if an + * offset is encountered which would overflow this segment) + */ + def appendFromFile(records: FileRecords, start: Int): Int = { + var position = start + val bufferSupplier: BufferSupplier = new BufferSupplier.GrowableBufferSupplier + while (position < start + records.sizeInBytes) { + val bytesAppended = appendChunkFromFile(records, position, bufferSupplier) + if (bytesAppended == 0) + return position - start + position += bytesAppended + } + position - start + } + + @nonthreadsafe + def updateTxnIndex(completedTxn: CompletedTxn, lastStableOffset: Long): Unit = { + if (completedTxn.isAborted) { + trace(s"Writing aborted transaction $completedTxn to transaction index, last stable offset is $lastStableOffset") + txnIndex.append(new AbortedTxn(completedTxn, lastStableOffset)) + } + } + + private def updateProducerState(producerStateManager: ProducerStateManager, batch: RecordBatch): Unit = { + if (batch.hasProducerId) { + val producerId = batch.producerId + val appendInfo = producerStateManager.prepareUpdate(producerId, origin = AppendOrigin.Replication) + val maybeCompletedTxn = appendInfo.append(batch, firstOffsetMetadataOpt = None) + producerStateManager.update(appendInfo) + maybeCompletedTxn.foreach { completedTxn => + val lastStableOffset = producerStateManager.lastStableOffset(completedTxn) + updateTxnIndex(completedTxn, lastStableOffset) + producerStateManager.completeTxn(completedTxn) + } + } + producerStateManager.updateMapEndOffset(batch.lastOffset + 1) + } + + /** + * Find the physical file position for the first message with offset >= the requested offset. + * + * The startingFilePosition argument is an optimization that can be used if we already know a valid starting position + * in the file higher than the greatest-lower-bound from the index. + * + * @param offset The offset we want to translate + * @param startingFilePosition A lower bound on the file position from which to begin the search. This is purely an optimization and + * when omitted, the search will begin at the position in the offset index. + * @return The position in the log storing the message with the least offset >= the requested offset and the size of the + * message or null if no message meets this criteria. + */ + @threadsafe + private[log] def translateOffset(offset: Long, startingFilePosition: Int = 0): LogOffsetPosition = { + val mapping = offsetIndex.lookup(offset) + log.searchForOffsetWithSize(offset, max(mapping.position, startingFilePosition)) + } + + /** + * Read a message set from this segment beginning with the first offset >= startOffset. The message set will include + * no more than maxSize bytes and will end before maxOffset if a maxOffset is specified. 
+ * + * @param startOffset A lower bound on the first offset to include in the message set we read + * @param maxSize The maximum number of bytes to include in the message set we read + * @param maxPosition The maximum position in the log segment that should be exposed for read + * @param minOneMessage If this is true, the first message will be returned even if it exceeds `maxSize` (if one exists) + * + * @return The fetched data and the offset metadata of the first message whose offset is >= startOffset, + * or null if the startOffset is larger than the largest offset in this log + */ + @threadsafe + def read(startOffset: Long, + maxSize: Int, + maxPosition: Long = size, + minOneMessage: Boolean = false): FetchDataInfo = { + if (maxSize < 0) + throw new IllegalArgumentException(s"Invalid max size $maxSize for log read from segment $log") + + val startOffsetAndSize = translateOffset(startOffset) + + // if the start position is already off the end of the log, return null + if (startOffsetAndSize == null) + return null + + val startPosition = startOffsetAndSize.position + val offsetMetadata = LogOffsetMetadata(startOffset, this.baseOffset, startPosition) + + val adjustedMaxSize = + if (minOneMessage) math.max(maxSize, startOffsetAndSize.size) + else maxSize + + // return a log segment but with zero size in the case below + if (adjustedMaxSize == 0) + return FetchDataInfo(offsetMetadata, MemoryRecords.EMPTY) + + // calculate the length of the message set to read based on whether or not they gave us a maxOffset + val fetchSize: Int = min((maxPosition - startPosition).toInt, adjustedMaxSize) + + FetchDataInfo(offsetMetadata, log.slice(startPosition, fetchSize), + firstEntryIncomplete = adjustedMaxSize < startOffsetAndSize.size) + } + + def fetchUpperBoundOffset(startOffsetPosition: OffsetPosition, fetchSize: Int): Option[Long] = + offsetIndex.fetchUpperBoundOffset(startOffsetPosition, fetchSize).map(_.offset) + + /** + * Run recovery on the given segment. This will rebuild the index from the log file and lop off any invalid bytes + * from the end of the log and index. + * + * @param producerStateManager Producer state corresponding to the segment's base offset. This is needed to recover + * the transaction index. + * @param leaderEpochCache Optionally a cache for updating the leader epoch during recovery. 
+ * @return The number of bytes truncated from the log + * @throws LogSegmentOffsetOverflowException if the log segment contains an offset that causes the index offset to overflow + */ + @nonthreadsafe + def recover(producerStateManager: ProducerStateManager, leaderEpochCache: Option[LeaderEpochFileCache] = None): Int = { + offsetIndex.reset() + timeIndex.reset() + txnIndex.reset() + var validBytes = 0 + var lastIndexEntry = 0 + maxTimestampAndOffsetSoFar = TimestampOffset.Unknown + try { + for (batch <- log.batches.asScala) { + batch.ensureValid() + ensureOffsetInRange(batch.lastOffset) + + // The max timestamp is exposed at the batch level, so no need to iterate the records + if (batch.maxTimestamp > maxTimestampSoFar) { + maxTimestampAndOffsetSoFar = TimestampOffset(batch.maxTimestamp, batch.lastOffset) + } + + // Build offset index + if (validBytes - lastIndexEntry > indexIntervalBytes) { + offsetIndex.append(batch.lastOffset, validBytes) + timeIndex.maybeAppend(maxTimestampSoFar, offsetOfMaxTimestampSoFar) + lastIndexEntry = validBytes + } + validBytes += batch.sizeInBytes() + + if (batch.magic >= RecordBatch.MAGIC_VALUE_V2) { + leaderEpochCache.foreach { cache => + if (batch.partitionLeaderEpoch >= 0 && cache.latestEpoch.forall(batch.partitionLeaderEpoch > _)) + cache.assign(batch.partitionLeaderEpoch, batch.baseOffset) + } + updateProducerState(producerStateManager, batch) + } + } + } catch { + case e@ (_: CorruptRecordException | _: InvalidRecordException) => + warn("Found invalid messages in log segment %s at byte offset %d: %s. %s" + .format(log.file.getAbsolutePath, validBytes, e.getMessage, e.getCause)) + } + val truncated = log.sizeInBytes - validBytes + if (truncated > 0) + debug(s"Truncated $truncated invalid bytes at the end of segment ${log.file.getAbsoluteFile} during recovery") + + log.truncateTo(validBytes) + offsetIndex.trimToValidSize() + // A normally closed segment always appends the biggest timestamp ever seen into log segment, we do this as well. + timeIndex.maybeAppend(maxTimestampSoFar, offsetOfMaxTimestampSoFar, skipFullCheck = true) + timeIndex.trimToValidSize() + truncated + } + + private def loadLargestTimestamp(): Unit = { + // Get the last time index entry. If the time index is empty, it will return (-1, baseOffset) + val lastTimeIndexEntry = timeIndex.lastEntry + maxTimestampAndOffsetSoFar = lastTimeIndexEntry + + val offsetPosition = offsetIndex.lookup(lastTimeIndexEntry.offset) + // Scan the rest of the messages to see if there is a larger timestamp after the last time index entry. + val maxTimestampOffsetAfterLastEntry = log.largestTimestampAfter(offsetPosition.position) + if (maxTimestampOffsetAfterLastEntry.timestamp > lastTimeIndexEntry.timestamp) { + maxTimestampAndOffsetSoFar = TimestampOffset(maxTimestampOffsetAfterLastEntry.timestamp, maxTimestampOffsetAfterLastEntry.offset) + } + } + + /** + * Check whether the last offset of the last batch in this segment overflows the indexes. 
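+ * Overflow means that the offset cannot be represented as an integer offset relative to the
+ * segment's base offset.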
+ */ + def hasOverflow: Boolean = { + val nextOffset = readNextOffset + nextOffset > baseOffset && !canConvertToRelativeOffset(nextOffset - 1) + } + + def collectAbortedTxns(fetchOffset: Long, upperBoundOffset: Long): TxnIndexSearchResult = + txnIndex.collectAbortedTxns(fetchOffset, upperBoundOffset) + + override def toString: String = "LogSegment(baseOffset=" + baseOffset + + ", size=" + size + + ", lastModifiedTime=" + lastModified + + ", largestRecordTimestamp=" + largestRecordTimestamp + + ")" + + /** + * Truncate off all index and log entries with offsets >= the given offset. + * If the given offset is larger than the largest message in this segment, do nothing. + * + * @param offset The offset to truncate to + * @return The number of log bytes truncated + */ + @nonthreadsafe + def truncateTo(offset: Long): Int = { + // Do offset translation before truncating the index to avoid needless scanning + // in case we truncate the full index + val mapping = translateOffset(offset) + offsetIndex.truncateTo(offset) + timeIndex.truncateTo(offset) + txnIndex.truncateTo(offset) + + // After truncation, reset and allocate more space for the (new currently active) index + offsetIndex.resize(offsetIndex.maxIndexSize) + timeIndex.resize(timeIndex.maxIndexSize) + + val bytesTruncated = if (mapping == null) 0 else log.truncateTo(mapping.position) + if (log.sizeInBytes == 0) { + created = time.milliseconds + rollingBasedTimestamp = None + } + + bytesSinceLastIndexEntry = 0 + if (maxTimestampSoFar >= 0) + loadLargestTimestamp() + bytesTruncated + } + + /** + * Calculate the offset that would be used for the next message to be append to this segment. + * Note that this is expensive. + */ + @threadsafe + def readNextOffset: Long = { + val fetchData = read(offsetIndex.lastOffset, log.sizeInBytes) + if (fetchData == null) + baseOffset + else + fetchData.records.batches.asScala.lastOption + .map(_.nextOffset) + .getOrElse(baseOffset) + } + + /** + * Flush this log segment to disk + */ + @threadsafe + def flush(): Unit = { + LogFlushStats.logFlushTimer.time { + log.flush() + offsetIndex.flush() + timeIndex.flush() + txnIndex.flush() + } + } + + /** + * Update the directory reference for the log and indices in this segment. This would typically be called after a + * directory is renamed. + */ + def updateParentDir(dir: File): Unit = { + log.updateParentDir(dir) + lazyOffsetIndex.updateParentDir(dir) + lazyTimeIndex.updateParentDir(dir) + txnIndex.updateParentDir(dir) + } + + /** + * Change the suffix for the index and log files for this log segment + * IOException from this method should be handled by the caller + */ + def changeFileSuffixes(oldSuffix: String, newSuffix: String): Unit = { + log.renameTo(new File(CoreUtils.replaceSuffix(log.file.getPath, oldSuffix, newSuffix))) + lazyOffsetIndex.renameTo(new File(CoreUtils.replaceSuffix(lazyOffsetIndex.file.getPath, oldSuffix, newSuffix))) + lazyTimeIndex.renameTo(new File(CoreUtils.replaceSuffix(lazyTimeIndex.file.getPath, oldSuffix, newSuffix))) + txnIndex.renameTo(new File(CoreUtils.replaceSuffix(txnIndex.file.getPath, oldSuffix, newSuffix))) + } + + def hasSuffix(suffix: String): Boolean = { + log.file.getName.endsWith(suffix) && + lazyOffsetIndex.file.getName.endsWith(suffix) && + lazyTimeIndex.file.getName.endsWith(suffix) && + txnIndex.file.getName.endsWith(suffix) + } + + /** + * Append the largest time index entry to the time index and trim the log and indexes. + * + * The time index entry appended will be used to decide when to delete the segment. 
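+ * The offset index, time index and log file are also trimmed to their valid sizes.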
+   */
+  def onBecomeInactiveSegment(): Unit = {
+    timeIndex.maybeAppend(maxTimestampSoFar, offsetOfMaxTimestampSoFar, skipFullCheck = true)
+    offsetIndex.trimToValidSize()
+    timeIndex.trimToValidSize()
+    log.trim()
+  }
+
+  /**
+   * If not previously loaded,
+   * load the timestamp of the first message into memory.
+   */
+  private def loadFirstBatchTimestamp(): Unit = {
+    if (rollingBasedTimestamp.isEmpty) {
+      val iter = log.batches.iterator()
+      if (iter.hasNext)
+        rollingBasedTimestamp = Some(iter.next().maxTimestamp)
+    }
+  }
+
+  /**
+   * The time this segment has waited to be rolled.
+   * If the first message batch has a timestamp, we use its timestamp to determine when to roll a segment. A segment
+   * is rolled if the difference between the new batch's timestamp and the first batch's timestamp exceeds the
+   * segment rolling time.
+   * If the first batch does not have a timestamp, we use the wall clock time to determine when to roll a segment. A
+   * segment is rolled if the difference between the current wall clock time and the segment create time exceeds the
+   * segment rolling time.
+   */
+  def timeWaitedForRoll(now: Long, messageTimestamp: Long): Long = {
+    // Load the timestamp of the first message into memory
+    loadFirstBatchTimestamp()
+    rollingBasedTimestamp match {
+      case Some(t) if t >= 0 => messageTimestamp - t
+      case _ => now - created
+    }
+  }
+
+  /**
+   * @return the first batch timestamp if the timestamp is available. Otherwise return Long.MaxValue
+   */
+  def getFirstBatchTimestamp(): Long = {
+    loadFirstBatchTimestamp()
+    rollingBasedTimestamp match {
+      case Some(t) if t >= 0 => t
+      case _ => Long.MaxValue
+    }
+  }
+
+  /**
+   * Search the message offset based on timestamp and offset.
+   *
+   * This method returns an option of TimestampAndOffset. The returned value is determined using the following ordered list of rules:
+   *
+   * - If all the messages in the segment have smaller offsets, return None
+   * - If all the messages in the segment have smaller timestamps, return None
+   * - If all the messages in the segment have larger timestamps, or no message in the segment has a timestamp,
+   *   the returned offset will be max(the base offset of the segment, startingOffset) and the timestamp will be Message.NoTimestamp.
+   * - Otherwise, return an option of TimestampAndOffset. The offset is the offset of the first message whose timestamp
+   *   is greater than or equal to the target timestamp and whose offset is greater than or equal to the startingOffset.
+   *
+   * This method only returns None when 1) all messages' offsets are smaller than startingOffset, or 2) the log is not empty but we did not
+   * see any message when scanning the log from the indexed position. The latter could happen if the log is truncated
+   * after we get the indexed position but before we scan the log from there. In this case we simply return None and the
+   * caller will need to check on the truncated log and maybe retry or even do the search on another log segment.
+   *
+   * @param timestamp The timestamp to search for.
+   * @param startingOffset The starting offset to search.
+   * @return the timestamp and offset of the first message that meets the requirements. None will be returned if there is no such message.
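To make the rolling rule documented for `timeWaitedForRoll` above concrete, here is an illustrative check a caller might perform, assuming a configured `segmentMs` and that `rollJitterMs` is accessible on the segment; it is a sketch, not the broker's actual roll logic:

```scala
import kafka.log.LogSegment

// Sketch: roll once the waited time (batch-timestamp based when available, otherwise
// wall-clock based) exceeds the configured segment age minus the per-segment jitter.
def shouldRollOnTime(segment: LogSegment, segmentMs: Long, now: Long, largestBatchTimestamp: Long): Boolean =
  segment.timeWaitedForRoll(now, largestBatchTimestamp) > segmentMs - segment.rollJitterMs
```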
+ */ + def findOffsetByTimestamp(timestamp: Long, startingOffset: Long = baseOffset): Option[TimestampAndOffset] = { + // Get the index entry with a timestamp less than or equal to the target timestamp + val timestampOffset = timeIndex.lookup(timestamp) + val position = offsetIndex.lookup(math.max(timestampOffset.offset, startingOffset)).position + + // Search the timestamp + Option(log.searchForTimestamp(timestamp, position, startingOffset)) + } + + /** + * Close this log segment + */ + def close(): Unit = { + if (_maxTimestampAndOffsetSoFar != TimestampOffset.Unknown) + CoreUtils.swallow(timeIndex.maybeAppend(maxTimestampSoFar, offsetOfMaxTimestampSoFar, + skipFullCheck = true), this) + CoreUtils.swallow(lazyOffsetIndex.close(), this) + CoreUtils.swallow(lazyTimeIndex.close(), this) + CoreUtils.swallow(log.close(), this) + CoreUtils.swallow(txnIndex.close(), this) + } + + /** + * Close file handlers used by the log segment but don't write to disk. This is used when the disk may have failed + */ + def closeHandlers(): Unit = { + CoreUtils.swallow(lazyOffsetIndex.closeHandler(), this) + CoreUtils.swallow(lazyTimeIndex.closeHandler(), this) + CoreUtils.swallow(log.closeHandlers(), this) + CoreUtils.swallow(txnIndex.close(), this) + } + + /** + * Delete this log segment from the filesystem. + */ + def deleteIfExists(): Unit = { + def delete(delete: () => Boolean, fileType: String, file: File, logIfMissing: Boolean): Unit = { + try { + if (delete()) + info(s"Deleted $fileType ${file.getAbsolutePath}.") + else if (logIfMissing) + info(s"Failed to delete $fileType ${file.getAbsolutePath} because it does not exist.") + } + catch { + case e: IOException => throw new IOException(s"Delete of $fileType ${file.getAbsolutePath} failed.", e) + } + } + + CoreUtils.tryAll(Seq( + () => delete(log.deleteIfExists _, "log", log.file, logIfMissing = true), + () => delete(lazyOffsetIndex.deleteIfExists _, "offset index", lazyOffsetIndex.file, logIfMissing = true), + () => delete(lazyTimeIndex.deleteIfExists _, "time index", lazyTimeIndex.file, logIfMissing = true), + () => delete(txnIndex.deleteIfExists _, "transaction index", txnIndex.file, logIfMissing = false) + )) + } + + def deleted(): Boolean = { + !log.file.exists() && !lazyOffsetIndex.file.exists() && !lazyTimeIndex.file.exists() && !txnIndex.file.exists() + } + + /** + * The last modified time of this log segment as a unix time stamp + */ + def lastModified = log.file.lastModified + + /** + * The largest timestamp this segment contains, if maxTimestampSoFar >= 0, otherwise None. + */ + def largestRecordTimestamp: Option[Long] = if (maxTimestampSoFar >= 0) Some(maxTimestampSoFar) else None + + /** + * The largest timestamp this segment contains. 
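A hedged usage sketch of `findOffsetByTimestamp` defined earlier in this block; `segment` and `targetTime` are assumed to exist in the calling context:

```scala
// Sketch: resolve a timestamp to the first qualifying offset within one segment, falling
// through to the next segment when this one cannot answer.
segment.findOffsetByTimestamp(targetTime) match {
  case Some(found) => println(s"first offset with timestamp >= $targetTime is ${found.offset}")
  case None        => println("not resolvable from this segment; try the next one")
}
```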
+ */ + def largestTimestamp = if (maxTimestampSoFar >= 0) maxTimestampSoFar else lastModified + + /** + * Change the last modified time for this log segment + */ + def lastModified_=(ms: Long) = { + val fileTime = FileTime.fromMillis(ms) + Files.setLastModifiedTime(log.file.toPath, fileTime) + Files.setLastModifiedTime(lazyOffsetIndex.file.toPath, fileTime) + Files.setLastModifiedTime(lazyTimeIndex.file.toPath, fileTime) + } + +} + +object LogSegment { + + def open(dir: File, baseOffset: Long, config: LogConfig, time: Time, fileAlreadyExists: Boolean = false, + initFileSize: Int = 0, preallocate: Boolean = false, fileSuffix: String = ""): LogSegment = { + val maxIndexSize = config.maxIndexSize + new LogSegment( + FileRecords.open(UnifiedLog.logFile(dir, baseOffset, fileSuffix), fileAlreadyExists, initFileSize, preallocate), + LazyIndex.forOffset(UnifiedLog.offsetIndexFile(dir, baseOffset, fileSuffix), baseOffset = baseOffset, maxIndexSize = maxIndexSize), + LazyIndex.forTime(UnifiedLog.timeIndexFile(dir, baseOffset, fileSuffix), baseOffset = baseOffset, maxIndexSize = maxIndexSize), + new TransactionIndex(baseOffset, UnifiedLog.transactionIndexFile(dir, baseOffset, fileSuffix)), + baseOffset, + indexIntervalBytes = config.indexInterval, + rollJitterMs = config.randomSegmentJitter, + time) + } + + def deleteIfExists(dir: File, baseOffset: Long, fileSuffix: String = ""): Unit = { + UnifiedLog.deleteFileIfExists(UnifiedLog.offsetIndexFile(dir, baseOffset, fileSuffix)) + UnifiedLog.deleteFileIfExists(UnifiedLog.timeIndexFile(dir, baseOffset, fileSuffix)) + UnifiedLog.deleteFileIfExists(UnifiedLog.transactionIndexFile(dir, baseOffset, fileSuffix)) + UnifiedLog.deleteFileIfExists(UnifiedLog.logFile(dir, baseOffset, fileSuffix)) + } +} + +object LogFlushStats extends KafkaMetricsGroup { + val logFlushTimer = new KafkaTimer(newTimer("LogFlushRateAndTimeMs", TimeUnit.MILLISECONDS, TimeUnit.SECONDS)) +} diff --git a/core/src/main/scala/kafka/log/LogSegments.scala b/core/src/main/scala/kafka/log/LogSegments.scala new file mode 100644 index 0000000..564586d --- /dev/null +++ b/core/src/main/scala/kafka/log/LogSegments.scala @@ -0,0 +1,268 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package kafka.log + +import java.io.File +import java.util.Map +import java.util.concurrent.{ConcurrentNavigableMap, ConcurrentSkipListMap} + +import kafka.utils.threadsafe +import org.apache.kafka.common.TopicPartition + +import scala.jdk.CollectionConverters._ + +/** + * This class encapsulates a thread-safe navigable map of LogSegment instances and provides the + * required read and write behavior on the map. 
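Before turning to the segment map, a hedged example of the `LogSegment.open` factory defined just above. The directory path is hypothetical, and `config` (a `LogConfig`) with its `initFileSize`/`preallocate` settings is assumed to be in scope:

```scala
import java.io.File
import kafka.log.LogSegment
import org.apache.kafka.common.utils.Time

// Sketch: open (or create) the files backing the segment at base offset 0 of a partition directory.
val segment = LogSegment.open(
  dir = new File("/tmp/kafka-logs/demo-topic-0"), // hypothetical log directory
  baseOffset = 0L,
  config = config,
  time = Time.SYSTEM,
  fileAlreadyExists = false,
  initFileSize = config.initFileSize,
  preallocate = config.preallocate
)
```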
+ * + * @param topicPartition the TopicPartition associated with the segments + * (useful for logging purposes) + */ +class LogSegments(topicPartition: TopicPartition) { + + /* the segments of the log with key being LogSegment base offset and value being a LogSegment */ + private val segments: ConcurrentNavigableMap[Long, LogSegment] = new ConcurrentSkipListMap[Long, LogSegment] + + /** + * @return true if the segments are empty, false otherwise. + */ + @threadsafe + def isEmpty: Boolean = segments.isEmpty + + /** + * @return true if the segments are non-empty, false otherwise. + */ + @threadsafe + def nonEmpty: Boolean = !isEmpty + + /** + * Add the given segment, or replace an existing entry. + * + * @param segment the segment to add + */ + @threadsafe + def add(segment: LogSegment): LogSegment = this.segments.put(segment.baseOffset, segment) + + /** + * Remove the segment at the provided offset. + * + * @param offset the offset to be removed + */ + @threadsafe + def remove(offset: Long): Unit = segments.remove(offset) + + /** + * Clears all entries. + */ + @threadsafe + def clear(): Unit = segments.clear() + + /** + * Close all segments. + */ + def close(): Unit = values.foreach(_.close()) + + /** + * Close the handlers for all segments. + */ + def closeHandlers(): Unit = values.foreach(_.closeHandlers()) + + /** + * Update the directory reference for the log and indices of all segments. + * + * @param dir the renamed directory + */ + def updateParentDir(dir: File): Unit = values.foreach(_.updateParentDir(dir)) + + /** + * Take care! this is an O(n) operation, where n is the number of segments. + * + * @return The number of segments. + * + */ + @threadsafe + def numberOfSegments: Int = segments.size + + /** + * @return the base offsets of all segments + */ + def baseOffsets: Iterable[Long] = segments.values().asScala.map(_.baseOffset) + + /** + * @param offset the segment to be checked + * @return true if a segment exists at the provided offset, false otherwise. + */ + @threadsafe + def contains(offset: Long): Boolean = segments.containsKey(offset) + + /** + * Retrieves a segment at the specified offset. + * + * @param offset the segment to be retrieved + * + * @return the segment if it exists, otherwise None. + */ + @threadsafe + def get(offset: Long): Option[LogSegment] = Option(segments.get(offset)) + + /** + * @return an iterator to the log segments ordered from oldest to newest. + */ + def values: Iterable[LogSegment] = segments.values.asScala + + /** + * @return An iterator to all segments beginning with the segment that includes "from" and ending + * with the segment that includes up to "to-1" or the end of the log (if to > end of log). + */ + def values(from: Long, to: Long): Iterable[LogSegment] = { + if (from == to) { + // Handle non-segment-aligned empty sets + List.empty[LogSegment] + } else if (to < from) { + throw new IllegalArgumentException(s"Invalid log segment range: requested segments in $topicPartition " + + s"from offset $from which is greater than limit offset $to") + } else { + val view = Option(segments.floorKey(from)).map { floor => + segments.subMap(floor, to) + }.getOrElse(segments.headMap(to)) + view.values.asScala + } + } + + def nonActiveLogSegmentsFrom(from: Long): Iterable[LogSegment] = { + val activeSegment = lastSegment.get + if (from > activeSegment.baseOffset) + Seq.empty + else + values(from, activeSegment.baseOffset) + } + + /** + * @return the entry associated with the greatest offset less than or equal to the given offset, + * if it exists. 
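The floor-based selection used by `values(from, to)` above (and by `floorEntry` here) can be illustrated with a plain skip-list map and hypothetical base offsets; this sketch mirrors the lookup shape without needing real segments:

```scala
import java.util.concurrent.ConcurrentSkipListMap
import scala.jdk.CollectionConverters._

val baseOffsets = new ConcurrentSkipListMap[Long, String]()
Seq(0L, 100L, 200L).foreach(o => baseOffsets.put(o, s"segment-$o"))

// Floor the lower bound so the segment containing `from` is included, then take everything below `to`.
def range(from: Long, to: Long): Iterable[String] =
  Option(baseOffsets.floorKey(from))
    .map(floor => baseOffsets.subMap(floor, to))
    .getOrElse(baseOffsets.headMap(to))
    .values.asScala

range(150L, 250L) // segment-100 (the floor of 150) and segment-200
range(50L, 100L)  // segment-0 only
```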
+ */ + @threadsafe + private def floorEntry(offset: Long): Option[Map.Entry[Long, LogSegment]] = Option(segments.floorEntry(offset)) + + /** + * @return the log segment with the greatest offset less than or equal to the given offset, + * if it exists. + */ + @threadsafe + def floorSegment(offset: Long): Option[LogSegment] = floorEntry(offset).map(_.getValue) + + /** + * @return the entry associated with the greatest offset strictly less than the given offset, + * if it exists. + */ + @threadsafe + private def lowerEntry(offset: Long): Option[Map.Entry[Long, LogSegment]] = Option(segments.lowerEntry(offset)) + + /** + * @return the log segment with the greatest offset strictly less than the given offset, + * if it exists. + */ + @threadsafe + def lowerSegment(offset: Long): Option[LogSegment] = lowerEntry(offset).map(_.getValue) + + /** + * @return the entry associated with the smallest offset strictly greater than the given offset, + * if it exists. + */ + @threadsafe + def higherEntry(offset: Long): Option[Map.Entry[Long, LogSegment]] = Option(segments.higherEntry(offset)) + + /** + * @return the log segment with the smallest offset strictly greater than the given offset, + * if it exists. + */ + @threadsafe + def higherSegment(offset: Long): Option[LogSegment] = higherEntry(offset).map(_.getValue) + + /** + * @return the entry associated with the smallest offset, if it exists. + */ + @threadsafe + def firstEntry: Option[Map.Entry[Long, LogSegment]] = Option(segments.firstEntry) + + /** + * @return the log segment associated with the smallest offset, if it exists. + */ + @threadsafe + def firstSegment: Option[LogSegment] = firstEntry.map(_.getValue) + + /** + * @return the base offset of the log segment associated with the smallest offset, if it exists + */ + private[log] def firstSegmentBaseOffset: Option[Long] = firstSegment.map(_.baseOffset) + + /** + * @return the entry associated with the greatest offset, if it exists. + */ + @threadsafe + def lastEntry: Option[Map.Entry[Long, LogSegment]] = Option(segments.lastEntry) + + /** + * @return the log segment with the greatest offset, if it exists. + */ + @threadsafe + def lastSegment: Option[LogSegment] = lastEntry.map(_.getValue) + + /** + * @return an iterable with log segments ordered from lowest base offset to highest, + * each segment returned has a base offset strictly greater than the provided baseOffset. + */ + def higherSegments(baseOffset: Long): Iterable[LogSegment] = { + val view = + Option(segments.higherKey(baseOffset)).map { + higherOffset => segments.tailMap(higherOffset, true) + }.getOrElse(collection.immutable.Map[Long, LogSegment]().asJava) + view.values.asScala + } + + /** + * The active segment that is currently taking appends + */ + def activeSegment = lastSegment.get + + def sizeInBytes: Long = LogSegments.sizeInBytes(values) + + /** + * Returns an Iterable containing segments matching the provided predicate. + * + * @param predicate the predicate to be used for filtering segments. + */ + def filter(predicate: LogSegment => Boolean): Iterable[LogSegment] = values.filter(predicate) +} + +object LogSegments { + /** + * Calculate a log's size (in bytes) from the provided log segments. 
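The navigation methods above are typically combined to locate where a read starts and which segments follow. A hedged usage sketch; the partition name is hypothetical and the map is empty here, so `floorSegment` simply returns `None`:

```scala
import kafka.log.LogSegments
import org.apache.kafka.common.TopicPartition

val segments = new LogSegments(new TopicPartition("demo-topic", 0))
val target = 150L
segments.floorSegment(target) match {
  case Some(seg) => println(s"a read at offset $target starts in the segment based at ${seg.baseOffset}")
  case None      => println("no segment covers that offset yet")
}
segments.higherSegments(target).foreach(seg => println(s"...then continues into the segment based at ${seg.baseOffset}"))
```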
+ * + * @param segments The log segments to calculate the size of + * @return Sum of the log segments' sizes (in bytes) + */ + def sizeInBytes(segments: Iterable[LogSegment]): Long = + segments.map(_.size.toLong).sum + + def getFirstBatchTimestampForSegments(segments: Iterable[LogSegment]): Iterable[Long] = { + segments.map { + segment => + segment.getFirstBatchTimestamp() + } + } +} diff --git a/core/src/main/scala/kafka/log/LogValidator.scala b/core/src/main/scala/kafka/log/LogValidator.scala new file mode 100644 index 0000000..925c602 --- /dev/null +++ b/core/src/main/scala/kafka/log/LogValidator.scala @@ -0,0 +1,591 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package kafka.log + +import java.nio.ByteBuffer +import kafka.api.{ApiVersion, KAFKA_2_1_IV0} +import kafka.common.{LongRef, RecordValidationException} +import kafka.message.{CompressionCodec, NoCompressionCodec, ZStdCompressionCodec} +import kafka.server.{BrokerTopicStats, RequestLocal} +import kafka.utils.Logging +import org.apache.kafka.common.errors.{CorruptRecordException, InvalidTimestampException, UnsupportedCompressionTypeException, UnsupportedForMessageFormatException} +import org.apache.kafka.common.record.{AbstractRecords, CompressionType, MemoryRecords, Record, RecordBatch, RecordConversionStats, TimestampType} +import org.apache.kafka.common.InvalidRecordException +import org.apache.kafka.common.TopicPartition +import org.apache.kafka.common.protocol.Errors +import org.apache.kafka.common.requests.ProduceResponse.RecordError +import org.apache.kafka.common.utils.Time + +import scala.collection.{Seq, mutable} +import scala.jdk.CollectionConverters._ +import scala.collection.mutable.ArrayBuffer + +/** + * The source of an append to the log. This is used when determining required validations. + */ +private[kafka] sealed trait AppendOrigin +private[kafka] object AppendOrigin { + + /** + * The log append came through replication from the leader. This typically implies minimal validation. + * Particularly, we do not decompress record batches in order to validate records individually. + */ + case object Replication extends AppendOrigin + + /** + * The log append came from either the group coordinator or the transaction coordinator. We validate + * producer epochs for normal log entries (specifically offset commits from the group coordinator) and + * we validate coordinate end transaction markers from the transaction coordinator. + */ + case object Coordinator extends AppendOrigin + + /** + * The log append came from the client, which implies full validation. 
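As an editorial aside (not part of the patch): the practical effect of the origin is to gate how deep validation goes; `validateBatch` further below applies its per-record offset, count and sequence checks only for client produce requests. A sketch, assuming it lives under the `kafka` package since the trait is `private[kafka]`:

```scala
// Sketch: only client appends get the deep per-batch checks; replicated, coordinator and
// raft-leader appends are trusted to have been validated (or assigned offsets) already.
def requiresDeepValidation(origin: AppendOrigin): Boolean = origin match {
  case AppendOrigin.Client => true
  case _                   => false
}
```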
+ */ + case object Client extends AppendOrigin + + /** + * The log append come from the raft leader, which implies the offsets has been assigned + */ + case object RaftLeader extends AppendOrigin +} + +private[log] object LogValidator extends Logging { + + /** + * Update the offsets for this message set and do further validation on messages including: + * 1. Messages for compacted topics must have keys + * 2. When magic value >= 1, inner messages of a compressed message set must have monotonically increasing offsets + * starting from 0. + * 3. When magic value >= 1, validate and maybe overwrite timestamps of messages. + * 4. Declared count of records in DefaultRecordBatch must match number of valid records contained therein. + * + * This method will convert messages as necessary to the topic's configured message format version. If no format + * conversion or value overwriting is required for messages, this method will perform in-place operations to + * avoid expensive re-compression. + * + * Returns a ValidationAndOffsetAssignResult containing the validated message set, maximum timestamp, the offset + * of the shallow message with the max timestamp and a boolean indicating whether the message sizes may have changed. + */ + private[log] def validateMessagesAndAssignOffsets(records: MemoryRecords, + topicPartition: TopicPartition, + offsetCounter: LongRef, + time: Time, + now: Long, + sourceCodec: CompressionCodec, + targetCodec: CompressionCodec, + compactedTopic: Boolean, + magic: Byte, + timestampType: TimestampType, + timestampDiffMaxMs: Long, + partitionLeaderEpoch: Int, + origin: AppendOrigin, + interBrokerProtocolVersion: ApiVersion, + brokerTopicStats: BrokerTopicStats, + requestLocal: RequestLocal): ValidationAndOffsetAssignResult = { + if (sourceCodec == NoCompressionCodec && targetCodec == NoCompressionCodec) { + // check the magic value + if (!records.hasMatchingMagic(magic)) + convertAndAssignOffsetsNonCompressed(records, topicPartition, offsetCounter, compactedTopic, time, now, timestampType, + timestampDiffMaxMs, magic, partitionLeaderEpoch, origin, brokerTopicStats) + else + // Do in-place validation, offset assignment and maybe set timestamp + assignOffsetsNonCompressed(records, topicPartition, offsetCounter, now, compactedTopic, timestampType, timestampDiffMaxMs, + partitionLeaderEpoch, origin, magic, brokerTopicStats) + } else { + validateMessagesAndAssignOffsetsCompressed(records, topicPartition, offsetCounter, time, now, sourceCodec, + targetCodec, compactedTopic, magic, timestampType, timestampDiffMaxMs, partitionLeaderEpoch, origin, + interBrokerProtocolVersion, brokerTopicStats, requestLocal) + } + } + + private def getFirstBatchAndMaybeValidateNoMoreBatches(records: MemoryRecords, sourceCodec: CompressionCodec): RecordBatch = { + val batchIterator = records.batches.iterator + + if (!batchIterator.hasNext) { + throw new InvalidRecordException("Record batch has no batches at all") + } + + val batch = batchIterator.next() + + // if the format is v2 and beyond, or if the messages are compressed, we should check there's only one batch. 
+ if (batch.magic() >= RecordBatch.MAGIC_VALUE_V2 || sourceCodec != NoCompressionCodec) { + if (batchIterator.hasNext) { + throw new InvalidRecordException("Compressed outer record has more than one batch") + } + } + + batch + } + + private def validateBatch(topicPartition: TopicPartition, + firstBatch: RecordBatch, + batch: RecordBatch, + origin: AppendOrigin, + toMagic: Byte, + brokerTopicStats: BrokerTopicStats): Unit = { + // batch magic byte should have the same magic as the first batch + if (firstBatch.magic() != batch.magic()) { + brokerTopicStats.allTopicsStats.invalidMagicNumberRecordsPerSec.mark() + throw new InvalidRecordException(s"Batch magic ${batch.magic()} is not the same as the first batch'es magic byte ${firstBatch.magic()} in topic partition $topicPartition.") + } + + if (origin == AppendOrigin.Client) { + if (batch.magic >= RecordBatch.MAGIC_VALUE_V2) { + val countFromOffsets = batch.lastOffset - batch.baseOffset + 1 + if (countFromOffsets <= 0) { + brokerTopicStats.allTopicsStats.invalidOffsetOrSequenceRecordsPerSec.mark() + throw new InvalidRecordException(s"Batch has an invalid offset range: [${batch.baseOffset}, ${batch.lastOffset}] in topic partition $topicPartition.") + } + + // v2 and above messages always have a non-null count + val count = batch.countOrNull + if (count <= 0) { + brokerTopicStats.allTopicsStats.invalidOffsetOrSequenceRecordsPerSec.mark() + throw new InvalidRecordException(s"Invalid reported count for record batch: $count in topic partition $topicPartition.") + } + + if (countFromOffsets != batch.countOrNull) { + brokerTopicStats.allTopicsStats.invalidOffsetOrSequenceRecordsPerSec.mark() + throw new InvalidRecordException(s"Inconsistent batch offset range [${batch.baseOffset}, ${batch.lastOffset}] " + + s"and count of records $count in topic partition $topicPartition.") + } + } + + if (batch.isControlBatch) { + brokerTopicStats.allTopicsStats.invalidOffsetOrSequenceRecordsPerSec.mark() + throw new InvalidRecordException(s"Clients are not allowed to write control records in topic partition $topicPartition.") + } + + if (batch.hasProducerId && batch.baseSequence < 0) { + brokerTopicStats.allTopicsStats.invalidOffsetOrSequenceRecordsPerSec.mark() + throw new InvalidRecordException(s"Invalid sequence number ${batch.baseSequence} in record batch " + + s"with producerId ${batch.producerId} in topic partition $topicPartition.") + } + } + + if (batch.isTransactional && toMagic < RecordBatch.MAGIC_VALUE_V2) + throw new UnsupportedForMessageFormatException(s"Transactional records cannot be used with magic version $toMagic") + + if (batch.hasProducerId && toMagic < RecordBatch.MAGIC_VALUE_V2) + throw new UnsupportedForMessageFormatException(s"Idempotent records cannot be used with magic version $toMagic") + } + + private def validateRecord(batch: RecordBatch, topicPartition: TopicPartition, record: Record, batchIndex: Int, now: Long, + timestampType: TimestampType, timestampDiffMaxMs: Long, compactedTopic: Boolean, + brokerTopicStats: BrokerTopicStats): Option[ApiRecordError] = { + if (!record.hasMagic(batch.magic)) { + brokerTopicStats.allTopicsStats.invalidMagicNumberRecordsPerSec.mark() + return Some(ApiRecordError(Errors.INVALID_RECORD, new RecordError(batchIndex, + s"Record $record's magic does not match outer magic ${batch.magic} in topic partition $topicPartition."))) + } + + // verify the record-level CRC only if this is one of the deep entries of a compressed message + // set for magic v0 and v1. 
For non-compressed messages, there is no inner record for magic v0 and v1, + // so we depend on the batch-level CRC check in Log.analyzeAndValidateRecords(). For magic v2 and above, + // there is no record-level CRC to check. + if (batch.magic <= RecordBatch.MAGIC_VALUE_V1 && batch.isCompressed) { + try { + record.ensureValid() + } catch { + case e: InvalidRecordException => + brokerTopicStats.allTopicsStats.invalidMessageCrcRecordsPerSec.mark() + throw new CorruptRecordException(e.getMessage + s" in topic partition $topicPartition.") + } + } + + validateKey(record, batchIndex, topicPartition, compactedTopic, brokerTopicStats).orElse { + validateTimestamp(batch, record, batchIndex, now, timestampType, timestampDiffMaxMs) + } + } + + private def convertAndAssignOffsetsNonCompressed(records: MemoryRecords, + topicPartition: TopicPartition, + offsetCounter: LongRef, + compactedTopic: Boolean, + time: Time, + now: Long, + timestampType: TimestampType, + timestampDiffMaxMs: Long, + toMagicValue: Byte, + partitionLeaderEpoch: Int, + origin: AppendOrigin, + brokerTopicStats: BrokerTopicStats): ValidationAndOffsetAssignResult = { + val startNanos = time.nanoseconds + val sizeInBytesAfterConversion = AbstractRecords.estimateSizeInBytes(toMagicValue, offsetCounter.value, + CompressionType.NONE, records.records) + + val (producerId, producerEpoch, sequence, isTransactional) = { + val first = records.batches.asScala.head + (first.producerId, first.producerEpoch, first.baseSequence, first.isTransactional) + } + + // The current implementation of BufferSupplier is naive and works best when the buffer size + // cardinality is low, so don't use it here + val newBuffer = ByteBuffer.allocate(sizeInBytesAfterConversion) + val builder = MemoryRecords.builder(newBuffer, toMagicValue, CompressionType.NONE, timestampType, + offsetCounter.value, now, producerId, producerEpoch, sequence, isTransactional, partitionLeaderEpoch) + + val firstBatch = getFirstBatchAndMaybeValidateNoMoreBatches(records, NoCompressionCodec) + + records.batches.forEach { batch => + validateBatch(topicPartition, firstBatch, batch, origin, toMagicValue, brokerTopicStats) + + val recordErrors = new ArrayBuffer[ApiRecordError](0) + for ((record, batchIndex) <- batch.asScala.view.zipWithIndex) { + validateRecord(batch, topicPartition, record, batchIndex, now, timestampType, + timestampDiffMaxMs, compactedTopic, brokerTopicStats).foreach(recordError => recordErrors += recordError) + // we fail the batch if any record fails, so we stop appending if any record fails + if (recordErrors.isEmpty) + builder.appendWithOffset(offsetCounter.getAndIncrement(), record) + } + + processRecordErrors(recordErrors) + } + + val convertedRecords = builder.build() + + val info = builder.info + val recordConversionStats = new RecordConversionStats(builder.uncompressedBytesWritten, + builder.numRecords, time.nanoseconds - startNanos) + ValidationAndOffsetAssignResult( + validatedRecords = convertedRecords, + maxTimestamp = info.maxTimestamp, + shallowOffsetOfMaxTimestamp = info.shallowOffsetOfMaxTimestamp, + messageSizeMaybeChanged = true, + recordConversionStats = recordConversionStats) + } + + def assignOffsetsNonCompressed(records: MemoryRecords, + topicPartition: TopicPartition, + offsetCounter: LongRef, + now: Long, + compactedTopic: Boolean, + timestampType: TimestampType, + timestampDiffMaxMs: Long, + partitionLeaderEpoch: Int, + origin: AppendOrigin, + magic: Byte, + brokerTopicStats: BrokerTopicStats): ValidationAndOffsetAssignResult = { + var maxTimestamp 
= RecordBatch.NO_TIMESTAMP + var offsetOfMaxTimestamp = -1L + val initialOffset = offsetCounter.value + + val firstBatch = getFirstBatchAndMaybeValidateNoMoreBatches(records, NoCompressionCodec) + + records.batches.forEach { batch => + validateBatch(topicPartition, firstBatch, batch, origin, magic, brokerTopicStats) + + var maxBatchTimestamp = RecordBatch.NO_TIMESTAMP + var offsetOfMaxBatchTimestamp = -1L + + val recordErrors = new ArrayBuffer[ApiRecordError](0) + // This is a hot path and we want to avoid any unnecessary allocations. + // That said, there is no benefit in using `skipKeyValueIterator` for the uncompressed + // case since we don't do key/value copies in this path (we just slice the ByteBuffer) + var batchIndex = 0 + batch.forEach { record => + validateRecord(batch, topicPartition, record, batchIndex, now, timestampType, + timestampDiffMaxMs, compactedTopic, brokerTopicStats).foreach(recordError => recordErrors += recordError) + + val offset = offsetCounter.getAndIncrement() + if (batch.magic > RecordBatch.MAGIC_VALUE_V0 && record.timestamp > maxBatchTimestamp) { + maxBatchTimestamp = record.timestamp + offsetOfMaxBatchTimestamp = offset + } + batchIndex += 1 + } + + processRecordErrors(recordErrors) + + if (batch.magic > RecordBatch.MAGIC_VALUE_V0 && maxBatchTimestamp > maxTimestamp) { + maxTimestamp = maxBatchTimestamp + offsetOfMaxTimestamp = offsetOfMaxBatchTimestamp + } + + batch.setLastOffset(offsetCounter.value - 1) + + if (batch.magic >= RecordBatch.MAGIC_VALUE_V2) + batch.setPartitionLeaderEpoch(partitionLeaderEpoch) + + if (batch.magic > RecordBatch.MAGIC_VALUE_V0) { + if (timestampType == TimestampType.LOG_APPEND_TIME) + batch.setMaxTimestamp(TimestampType.LOG_APPEND_TIME, now) + else + batch.setMaxTimestamp(timestampType, maxBatchTimestamp) + } + } + + if (timestampType == TimestampType.LOG_APPEND_TIME) { + maxTimestamp = now + if (magic >= RecordBatch.MAGIC_VALUE_V2) + offsetOfMaxTimestamp = offsetCounter.value - 1 + else + offsetOfMaxTimestamp = initialOffset + } + + ValidationAndOffsetAssignResult( + validatedRecords = records, + maxTimestamp = maxTimestamp, + shallowOffsetOfMaxTimestamp = offsetOfMaxTimestamp, + messageSizeMaybeChanged = false, + recordConversionStats = RecordConversionStats.EMPTY) + } + + /** + * We cannot do in place assignment in one of the following situations: + * 1. Source and target compression codec are different + * 2. When the target magic is not equal to batches' magic, meaning format conversion is needed. + * 3. When the target magic is equal to V0, meaning absolute offsets need to be re-assigned. 
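A condensed, illustrative restatement of those three conditions; it deliberately ignores the control-batch exception and the legacy V1 inner-offset quirk that the method body below also handles:

```scala
import kafka.message.CompressionCodec
import org.apache.kafka.common.record.RecordBatch

// Sketch: offsets can be rewritten in place only when no recompression, no format
// conversion and no absolute-offset rewrite (magic v0) is required.
def canAssignOffsetsInPlace(sourceCodec: CompressionCodec,
                            targetCodec: CompressionCodec,
                            firstBatchMagic: Byte,
                            toMagic: Byte): Boolean =
  sourceCodec == targetCodec &&
    firstBatchMagic == toMagic &&
    toMagic != RecordBatch.MAGIC_VALUE_V0
```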
+ */ + def validateMessagesAndAssignOffsetsCompressed(records: MemoryRecords, + topicPartition: TopicPartition, + offsetCounter: LongRef, + time: Time, + now: Long, + sourceCodec: CompressionCodec, + targetCodec: CompressionCodec, + compactedTopic: Boolean, + toMagic: Byte, + timestampType: TimestampType, + timestampDiffMaxMs: Long, + partitionLeaderEpoch: Int, + origin: AppendOrigin, + interBrokerProtocolVersion: ApiVersion, + brokerTopicStats: BrokerTopicStats, + requestLocal: RequestLocal): ValidationAndOffsetAssignResult = { + + if (targetCodec == ZStdCompressionCodec && interBrokerProtocolVersion < KAFKA_2_1_IV0) + throw new UnsupportedCompressionTypeException("Produce requests to inter.broker.protocol.version < 2.1 broker " + + "are not allowed to use ZStandard compression") + + def validateRecordCompression(batchIndex: Int, record: Record): Option[ApiRecordError] = { + if (sourceCodec != NoCompressionCodec && record.isCompressed) + Some(ApiRecordError(Errors.INVALID_RECORD, new RecordError(batchIndex, + s"Compressed outer record should not have an inner record with a compression attribute set: $record"))) + else None + } + + // No in place assignment situation 1 + var inPlaceAssignment = sourceCodec == targetCodec + + var maxTimestamp = RecordBatch.NO_TIMESTAMP + val expectedInnerOffset = new LongRef(0) + val validatedRecords = new mutable.ArrayBuffer[Record] + + var uncompressedSizeInBytes = 0 + + // Assume there's only one batch with compressed memory records; otherwise, return InvalidRecordException + // One exception though is that with format smaller than v2, if sourceCodec is noCompression, then each batch is actually + // a single record so we'd need to special handle it by creating a single wrapper batch that includes all the records + val firstBatch = getFirstBatchAndMaybeValidateNoMoreBatches(records, sourceCodec) + + // No in place assignment situation 2 and 3: we only need to check for the first batch because: + // 1. For most cases (compressed records, v2, for example), there's only one batch anyways. + // 2. For cases that there may be multiple batches, all batches' magic should be the same. + if (firstBatch.magic != toMagic || toMagic == RecordBatch.MAGIC_VALUE_V0) + inPlaceAssignment = false + + // Do not compress control records unless they are written compressed + if (sourceCodec == NoCompressionCodec && firstBatch.isControlBatch) + inPlaceAssignment = true + + records.batches.forEach { batch => + validateBatch(topicPartition, firstBatch, batch, origin, toMagic, brokerTopicStats) + uncompressedSizeInBytes += AbstractRecords.recordBatchHeaderSizeInBytes(toMagic, batch.compressionType()) + + // if we are on version 2 and beyond, and we know we are going for in place assignment, + // then we can optimize the iterator to skip key / value / headers since they would not be used at all + val recordsIterator = if (inPlaceAssignment && firstBatch.magic >= RecordBatch.MAGIC_VALUE_V2) + batch.skipKeyValueIterator(requestLocal.bufferSupplier) + else + batch.streamingIterator(requestLocal.bufferSupplier) + + try { + val recordErrors = new ArrayBuffer[ApiRecordError](0) + // this is a hot path and we want to avoid any unnecessary allocations. 
+ var batchIndex = 0 + recordsIterator.forEachRemaining { record => + val expectedOffset = expectedInnerOffset.getAndIncrement() + val recordError = validateRecordCompression(batchIndex, record).orElse { + validateRecord(batch, topicPartition, record, batchIndex, now, + timestampType, timestampDiffMaxMs, compactedTopic, brokerTopicStats).orElse { + if (batch.magic > RecordBatch.MAGIC_VALUE_V0 && toMagic > RecordBatch.MAGIC_VALUE_V0) { + if (record.timestamp > maxTimestamp) + maxTimestamp = record.timestamp + + // Some older clients do not implement the V1 internal offsets correctly. + // Historically the broker handled this by rewriting the batches rather + // than rejecting the request. We must continue this handling here to avoid + // breaking these clients. + if (record.offset != expectedOffset) + inPlaceAssignment = false + } + None + } + } + + recordError match { + case Some(e) => recordErrors += e + case None => + uncompressedSizeInBytes += record.sizeInBytes() + validatedRecords += record + } + batchIndex += 1 + } + processRecordErrors(recordErrors) + } finally { + recordsIterator.close() + } + } + + if (!inPlaceAssignment) { + val (producerId, producerEpoch, sequence, isTransactional) = { + // note that we only reassign offsets for requests coming straight from a producer. For records with magic V2, + // there should be exactly one RecordBatch per request, so the following is all we need to do. For Records + // with older magic versions, there will never be a producer id, etc. + val first = records.batches.asScala.head + (first.producerId, first.producerEpoch, first.baseSequence, first.isTransactional) + } + buildRecordsAndAssignOffsets(toMagic, offsetCounter, time, timestampType, CompressionType.forId(targetCodec.codec), + now, validatedRecords, producerId, producerEpoch, sequence, isTransactional, partitionLeaderEpoch, + uncompressedSizeInBytes) + } else { + // we can update the batch only and write the compressed payload as is; + // again we assume only one record batch within the compressed set + val batch = records.batches.iterator.next() + val lastOffset = offsetCounter.addAndGet(validatedRecords.size) - 1 + + batch.setLastOffset(lastOffset) + + if (timestampType == TimestampType.LOG_APPEND_TIME) + maxTimestamp = now + + if (toMagic >= RecordBatch.MAGIC_VALUE_V1) + batch.setMaxTimestamp(timestampType, maxTimestamp) + + if (toMagic >= RecordBatch.MAGIC_VALUE_V2) + batch.setPartitionLeaderEpoch(partitionLeaderEpoch) + + val recordConversionStats = new RecordConversionStats(uncompressedSizeInBytes, 0, 0) + ValidationAndOffsetAssignResult(validatedRecords = records, + maxTimestamp = maxTimestamp, + shallowOffsetOfMaxTimestamp = lastOffset, + messageSizeMaybeChanged = false, + recordConversionStats = recordConversionStats) + } + } + + private def buildRecordsAndAssignOffsets(magic: Byte, + offsetCounter: LongRef, + time: Time, + timestampType: TimestampType, + compressionType: CompressionType, + logAppendTime: Long, + validatedRecords: Seq[Record], + producerId: Long, + producerEpoch: Short, + baseSequence: Int, + isTransactional: Boolean, + partitionLeaderEpoch: Int, + uncompressedSizeInBytes: Int): ValidationAndOffsetAssignResult = { + val startNanos = time.nanoseconds + val estimatedSize = AbstractRecords.estimateSizeInBytes(magic, offsetCounter.value, compressionType, + validatedRecords.asJava) + // The current implementation of BufferSupplier is naive and works best when the buffer size + // cardinality is low, so don't use it here + val buffer = 
ByteBuffer.allocate(estimatedSize) + val builder = MemoryRecords.builder(buffer, magic, compressionType, timestampType, offsetCounter.value, + logAppendTime, producerId, producerEpoch, baseSequence, isTransactional, partitionLeaderEpoch) + + validatedRecords.foreach { record => + builder.appendWithOffset(offsetCounter.getAndIncrement(), record) + } + + val records = builder.build() + + val info = builder.info + + // This is not strictly correct, it represents the number of records where in-place assignment is not possible + // instead of the number of records that were converted. It will over-count cases where the source and target are + // message format V0 or if the inner offsets are not consecutive. This is OK since the impact is the same: we have + // to rebuild the records (including recompression if enabled). + val conversionCount = builder.numRecords + val recordConversionStats = new RecordConversionStats(uncompressedSizeInBytes + builder.uncompressedBytesWritten, + conversionCount, time.nanoseconds - startNanos) + + ValidationAndOffsetAssignResult( + validatedRecords = records, + maxTimestamp = info.maxTimestamp, + shallowOffsetOfMaxTimestamp = info.shallowOffsetOfMaxTimestamp, + messageSizeMaybeChanged = true, + recordConversionStats = recordConversionStats) + } + + private def validateKey(record: Record, + batchIndex: Int, + topicPartition: TopicPartition, + compactedTopic: Boolean, + brokerTopicStats: BrokerTopicStats): Option[ApiRecordError] = { + if (compactedTopic && !record.hasKey) { + brokerTopicStats.allTopicsStats.noKeyCompactedTopicRecordsPerSec.mark() + Some(ApiRecordError(Errors.INVALID_RECORD, new RecordError(batchIndex, + s"Compacted topic cannot accept message without key in topic partition $topicPartition."))) + } else None + } + + private def validateTimestamp(batch: RecordBatch, + record: Record, + batchIndex: Int, + now: Long, + timestampType: TimestampType, + timestampDiffMaxMs: Long): Option[ApiRecordError] = { + if (timestampType == TimestampType.CREATE_TIME + && record.timestamp != RecordBatch.NO_TIMESTAMP + && math.abs(record.timestamp - now) > timestampDiffMaxMs) + Some(ApiRecordError(Errors.INVALID_TIMESTAMP, new RecordError(batchIndex, + s"Timestamp ${record.timestamp} of message with offset ${record.offset} is " + + s"out of range. The timestamp should be within [${now - timestampDiffMaxMs}, " + + s"${now + timestampDiffMaxMs}]"))) + else if (batch.timestampType == TimestampType.LOG_APPEND_TIME) + Some(ApiRecordError(Errors.INVALID_TIMESTAMP, new RecordError(batchIndex, + s"Invalid timestamp type in message $record. 
Producer should not set timestamp " + + "type to LogAppendTime."))) + else None + } + + private def processRecordErrors(recordErrors: Seq[ApiRecordError]): Unit = { + if (recordErrors.nonEmpty) { + val errors = recordErrors.map(_.recordError) + if (recordErrors.exists(_.apiError == Errors.INVALID_TIMESTAMP)) { + throw new RecordValidationException(new InvalidTimestampException( + "One or more records have been rejected due to invalid timestamp"), errors) + } else { + throw new RecordValidationException(new InvalidRecordException( + "One or more records have been rejected"), errors) + } + } + } + + case class ValidationAndOffsetAssignResult(validatedRecords: MemoryRecords, + maxTimestamp: Long, + shallowOffsetOfMaxTimestamp: Long, + messageSizeMaybeChanged: Boolean, + recordConversionStats: RecordConversionStats) + + private case class ApiRecordError(apiError: Errors, recordError: RecordError) +} diff --git a/core/src/main/scala/kafka/log/OffsetIndex.scala b/core/src/main/scala/kafka/log/OffsetIndex.scala new file mode 100755 index 0000000..5719ee3 --- /dev/null +++ b/core/src/main/scala/kafka/log/OffsetIndex.scala @@ -0,0 +1,207 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.log + +import java.io.File +import java.nio.ByteBuffer + +import kafka.utils.CoreUtils.inLock +import kafka.utils.Logging +import org.apache.kafka.common.errors.InvalidOffsetException + +/** + * An index that maps offsets to physical file locations for a particular log segment. This index may be sparse: + * that is it may not hold an entry for all messages in the log. + * + * The index is stored in a file that is pre-allocated to hold a fixed maximum number of 8-byte entries. + * + * The index supports lookups against a memory-map of this file. These lookups are done using a simple binary search variant + * to locate the offset/location pair for the greatest offset less than or equal to the target offset. + * + * Index files can be opened in two ways: either as an empty, mutable index that allows appends or + * an immutable read-only index file that has previously been populated. The makeReadOnly method will turn a mutable file into an + * immutable one and truncate off any extra bytes. This is done when the index file is rolled over. + * + * No attempt is made to checksum the contents of this file, in the event of a crash it is rebuilt. + * + * The file format is a series of entries. The physical format is a 4 byte "relative" offset and a 4 byte file location for the + * message with that offset. The offset stored is relative to the base offset of the index file. So, for example, + * if the base offset was 50, then the offset 55 would be stored as 5. Using relative offsets in this way let's us use + * only 4 bytes for the offset. 
+ * + * The frequency of entries is up to the user of this class. + * + * All external APIs translate from relative offsets to full offsets, so users of this class do not interact with the internal + * storage format. + */ +// Avoid shadowing mutable `file` in AbstractIndex +class OffsetIndex(_file: File, baseOffset: Long, maxIndexSize: Int = -1, writable: Boolean = true) + extends AbstractIndex(_file, baseOffset, maxIndexSize, writable) { + import OffsetIndex._ + + override def entrySize = 8 + + /* the last offset in the index */ + private[this] var _lastOffset = lastEntry.offset + + debug(s"Loaded index file ${file.getAbsolutePath} with maxEntries = $maxEntries, " + + s"maxIndexSize = $maxIndexSize, entries = ${_entries}, lastOffset = ${_lastOffset}, file position = ${mmap.position()}") + + /** + * The last entry in the index + */ + private def lastEntry: OffsetPosition = { + inLock(lock) { + _entries match { + case 0 => OffsetPosition(baseOffset, 0) + case s => parseEntry(mmap, s - 1) + } + } + } + + def lastOffset: Long = _lastOffset + + /** + * Find the largest offset less than or equal to the given targetOffset + * and return a pair holding this offset and its corresponding physical file position. + * + * @param targetOffset The offset to look up. + * @return The offset found and the corresponding file position for this offset + * If the target offset is smaller than the least entry in the index (or the index is empty), + * the pair (baseOffset, 0) is returned. + */ + def lookup(targetOffset: Long): OffsetPosition = { + maybeLock(lock) { + val idx = mmap.duplicate + val slot = largestLowerBoundSlotFor(idx, targetOffset, IndexSearchType.KEY) + if(slot == -1) + OffsetPosition(baseOffset, 0) + else + parseEntry(idx, slot) + } + } + + /** + * Find an upper bound offset for the given fetch starting position and size. This is an offset which + * is guaranteed to be outside the fetched range, but note that it will not generally be the smallest + * such offset. + */ + def fetchUpperBoundOffset(fetchOffset: OffsetPosition, fetchSize: Int): Option[OffsetPosition] = { + maybeLock(lock) { + val idx = mmap.duplicate + val slot = smallestUpperBoundSlotFor(idx, fetchOffset.position + fetchSize, IndexSearchType.VALUE) + if (slot == -1) + None + else + Some(parseEntry(idx, slot)) + } + } + + private def relativeOffset(buffer: ByteBuffer, n: Int): Int = buffer.getInt(n * entrySize) + + private def physical(buffer: ByteBuffer, n: Int): Int = buffer.getInt(n * entrySize + 4) + + override protected def parseEntry(buffer: ByteBuffer, n: Int): OffsetPosition = { + OffsetPosition(baseOffset + relativeOffset(buffer, n), physical(buffer, n)) + } + + /** + * Get the nth offset mapping from the index + * @param n The entry number in the index + * @return The offset/position pair at that entry + */ + def entry(n: Int): OffsetPosition = { + maybeLock(lock) { + if (n >= _entries) + throw new IllegalArgumentException(s"Attempt to fetch the ${n}th entry from index ${file.getAbsolutePath}, " + + s"which has size ${_entries}.") + parseEntry(mmap, n) + } + } + + /** + * Append an entry for the given offset/location pair to the index. This entry must have a larger offset than all subsequent entries. 
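The `lookup`/`parseEntry` pair above decodes the fixed 8-byte entry layout. A minimal sketch of that layout, kept separate from the index implementation:

```scala
// Sketch: each entry is a 4-byte offset relative to the segment's base offset plus a
// 4-byte physical file position.
final case class IndexEntry(relativeOffset: Int, position: Int)

def encode(baseOffset: Long, absoluteOffset: Long, position: Int): IndexEntry =
  IndexEntry((absoluteOffset - baseOffset).toInt, position)

def decode(baseOffset: Long, entry: IndexEntry): (Long, Int) =
  (baseOffset + entry.relativeOffset, entry.position)

encode(baseOffset = 50L, absoluteOffset = 55L, position = 4096) // stores relative offset 5, as in the class doc
```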
+ * @throws IndexOffsetOverflowException if the offset causes index offset to overflow + */ + def append(offset: Long, position: Int): Unit = { + inLock(lock) { + require(!isFull, "Attempt to append to a full index (size = " + _entries + ").") + if (_entries == 0 || offset > _lastOffset) { + trace(s"Adding index entry $offset => $position to ${file.getAbsolutePath}") + mmap.putInt(relativeOffset(offset)) + mmap.putInt(position) + _entries += 1 + _lastOffset = offset + require(_entries * entrySize == mmap.position(), s"$entries entries but file position in index is ${mmap.position()}.") + } else { + throw new InvalidOffsetException(s"Attempt to append an offset ($offset) to position $entries no larger than" + + s" the last offset appended (${_lastOffset}) to ${file.getAbsolutePath}.") + } + } + } + + override def truncate() = truncateToEntries(0) + + override def truncateTo(offset: Long): Unit = { + inLock(lock) { + val idx = mmap.duplicate + val slot = largestLowerBoundSlotFor(idx, offset, IndexSearchType.KEY) + + /* There are 3 cases for choosing the new size + * 1) if there is no entry in the index <= the offset, delete everything + * 2) if there is an entry for this exact offset, delete it and everything larger than it + * 3) if there is no entry for this offset, delete everything larger than the next smallest + */ + val newEntries = + if(slot < 0) + 0 + else if(relativeOffset(idx, slot) == offset - baseOffset) + slot + else + slot + 1 + truncateToEntries(newEntries) + } + } + + /** + * Truncates index to a known number of entries. + */ + private def truncateToEntries(entries: Int): Unit = { + inLock(lock) { + _entries = entries + mmap.position(_entries * entrySize) + _lastOffset = lastEntry.offset + debug(s"Truncated index ${file.getAbsolutePath} to $entries entries;" + + s" position is now ${mmap.position()} and last offset is now ${_lastOffset}") + } + } + + override def sanityCheck(): Unit = { + if (_entries != 0 && _lastOffset < baseOffset) + throw new CorruptIndexException(s"Corrupt index found, index file (${file.getAbsolutePath}) has non-zero size " + + s"but the last offset is ${_lastOffset} which is less than the base offset $baseOffset.") + if (length % entrySize != 0) + throw new CorruptIndexException(s"Index file ${file.getAbsolutePath} is corrupt, found $length bytes which is " + + s"neither positive nor a multiple of $entrySize.") + } + +} + +object OffsetIndex extends Logging { + override val loggerName: String = classOf[OffsetIndex].getName +} diff --git a/core/src/main/scala/kafka/log/OffsetMap.scala b/core/src/main/scala/kafka/log/OffsetMap.scala new file mode 100755 index 0000000..22b5305 --- /dev/null +++ b/core/src/main/scala/kafka/log/OffsetMap.scala @@ -0,0 +1,201 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package kafka.log + +import java.util.Arrays +import java.security.MessageDigest +import java.nio.ByteBuffer +import kafka.utils._ +import org.apache.kafka.common.utils.Utils + +trait OffsetMap { + def slots: Int + def put(key: ByteBuffer, offset: Long): Unit + def get(key: ByteBuffer): Long + def updateLatestOffset(offset: Long): Unit + def clear(): Unit + def size: Int + def utilization: Double = size.toDouble / slots + def latestOffset: Long +} + +/** + * An hash table used for deduplicating the log. This hash table uses a cryptographicly secure hash of the key as a proxy for the key + * for comparisons and to save space on object overhead. Collisions are resolved by probing. This hash table does not support deletes. + * @param memory The amount of memory this map can use + * @param hashAlgorithm The hash algorithm instance to use: MD2, MD5, SHA-1, SHA-256, SHA-384, SHA-512 + */ +@nonthreadsafe +class SkimpyOffsetMap(val memory: Int, val hashAlgorithm: String = "MD5") extends OffsetMap { + private val bytes = ByteBuffer.allocate(memory) + + /* the hash algorithm instance to use, default is MD5 */ + private val digest = MessageDigest.getInstance(hashAlgorithm) + + /* the number of bytes for this hash algorithm */ + private val hashSize = digest.getDigestLength + + /* create some hash buffers to avoid reallocating each time */ + private val hash1 = new Array[Byte](hashSize) + private val hash2 = new Array[Byte](hashSize) + + /* number of entries put into the map */ + private var entries = 0 + + /* number of lookups on the map */ + private var lookups = 0L + + /* the number of probes for all lookups */ + private var probes = 0L + + /* the latest offset written into the map */ + private var lastOffset = -1L + + /** + * The number of bytes of space each entry uses (the number of bytes in the hash plus an 8 byte offset) + */ + val bytesPerEntry = hashSize + 8 + + /** + * The maximum number of entries this map can contain + */ + val slots: Int = memory / bytesPerEntry + + /** + * Associate this offset to the given key. + * @param key The key + * @param offset The offset + */ + override def put(key: ByteBuffer, offset: Long): Unit = { + require(entries < slots, "Attempt to add a new entry to a full offset map.") + lookups += 1 + hashInto(key, hash1) + // probe until we find the first empty slot + var attempt = 0 + var pos = positionOf(hash1, attempt) + while(!isEmpty(pos)) { + bytes.position(pos) + bytes.get(hash2) + if(Arrays.equals(hash1, hash2)) { + // we found an existing entry, overwrite it and return (size does not change) + bytes.putLong(offset) + lastOffset = offset + return + } + attempt += 1 + pos = positionOf(hash1, attempt) + } + // found an empty slot, update it--size grows by 1 + bytes.position(pos) + bytes.put(hash1) + bytes.putLong(offset) + lastOffset = offset + entries += 1 + } + + /** + * Check that there is no entry at the given position + */ + private def isEmpty(position: Int): Boolean = + bytes.getLong(position) == 0 && bytes.getLong(position + 8) == 0 && bytes.getLong(position + 16) == 0 + + /** + * Get the offset associated with this key. 
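A hedged usage sketch of the map as a deduplication aid: later offsets for a key overwrite earlier ones, so only the record whose offset matches the map entry for its key is considered current. The keys, offsets and memory size below are made up for illustration:

```scala
import java.nio.ByteBuffer
import kafka.log.SkimpyOffsetMap

val map = new SkimpyOffsetMap(memory = 1024 * 1024) // ~1 MiB of hash+offset slots
def key(s: String): ByteBuffer = ByteBuffer.wrap(s.getBytes("UTF-8"))

map.put(key("user-1"), offset = 10L)
map.put(key("user-1"), offset = 42L) // overwrites the earlier entry for the same key
val latest = map.get(key("user-1"))  // 42L: only the newest offset is kept
val missing = map.get(key("user-2")) // -1L: the key was never inserted
```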
+ * @param key The key + * @return The offset associated with this key or -1 if the key is not found + */ + override def get(key: ByteBuffer): Long = { + lookups += 1 + hashInto(key, hash1) + // search for the hash of this key by repeated probing until we find the hash we are looking for or we find an empty slot + var attempt = 0 + var pos = 0 + //we need to guard against attempt integer overflow if the map is full + //limit attempt to number of slots once positionOf(..) enters linear search mode + val maxAttempts = slots + hashSize - 4 + do { + if(attempt >= maxAttempts) + return -1L + pos = positionOf(hash1, attempt) + bytes.position(pos) + if(isEmpty(pos)) + return -1L + bytes.get(hash2) + attempt += 1 + } while(!Arrays.equals(hash1, hash2)) + bytes.getLong() + } + + /** + * Change the salt used for key hashing making all existing keys unfindable. + */ + override def clear(): Unit = { + this.entries = 0 + this.lookups = 0L + this.probes = 0L + this.lastOffset = -1L + Arrays.fill(bytes.array, bytes.arrayOffset, bytes.arrayOffset + bytes.limit(), 0.toByte) + } + + /** + * The number of entries put into the map (note that not all may remain) + */ + override def size: Int = entries + + /** + * The rate of collisions in the lookups + */ + def collisionRate: Double = + (this.probes - this.lookups) / this.lookups.toDouble + + /** + * The latest offset put into the map + */ + override def latestOffset: Long = lastOffset + + override def updateLatestOffset(offset: Long): Unit = { + lastOffset = offset + } + + /** + * Calculate the ith probe position. We first try reading successive integers from the hash itself + * then if all of those fail we degrade to linear probing. + * @param hash The hash of the key to find the position for + * @param attempt The ith probe + * @return The byte offset in the buffer at which the ith probing for the given hash would reside + */ + private def positionOf(hash: Array[Byte], attempt: Int): Int = { + val probe = CoreUtils.readInt(hash, math.min(attempt, hashSize - 4)) + math.max(0, attempt - hashSize + 4) + val slot = Utils.abs(probe) % slots + this.probes += 1 + slot * bytesPerEntry + } + + /** + * The offset at which we have stored the given key + * @param key The key to hash + * @param buffer The buffer to store the hash into + */ + private def hashInto(key: ByteBuffer, buffer: Array[Byte]): Unit = { + key.mark() + digest.update(key) + key.reset() + digest.digest(buffer, 0, hashSize) + } + +} diff --git a/core/src/main/scala/kafka/log/ProducerStateManager.scala b/core/src/main/scala/kafka/log/ProducerStateManager.scala new file mode 100644 index 0000000..5f5c225 --- /dev/null +++ b/core/src/main/scala/kafka/log/ProducerStateManager.scala @@ -0,0 +1,907 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package kafka.log + +import java.io.File +import java.nio.ByteBuffer +import java.nio.channels.FileChannel +import java.nio.file.{Files, NoSuchFileException, StandardOpenOption} +import java.util.concurrent.ConcurrentSkipListMap +import kafka.log.UnifiedLog.offsetFromFile +import kafka.server.LogOffsetMetadata +import kafka.utils.{CoreUtils, Logging, nonthreadsafe, threadsafe} +import org.apache.kafka.common.{KafkaException, TopicPartition} +import org.apache.kafka.common.errors._ +import org.apache.kafka.common.protocol.types._ +import org.apache.kafka.common.record.{ControlRecordType, DefaultRecordBatch, EndTransactionMarker, RecordBatch} +import org.apache.kafka.common.utils.{ByteUtils, Crc32C, Time, Utils} + +import scala.jdk.CollectionConverters._ +import scala.collection.mutable.ListBuffer +import scala.collection.{immutable, mutable} + +class CorruptSnapshotException(msg: String) extends KafkaException(msg) + +/** + * The last written record for a given producer. The last data offset may be undefined + * if the only log entry for a producer is a transaction marker. + */ +case class LastRecord(lastDataOffset: Option[Long], producerEpoch: Short) + + +private[log] case class TxnMetadata( + producerId: Long, + firstOffset: LogOffsetMetadata, + var lastOffset: Option[Long] = None +) { + def this(producerId: Long, firstOffset: Long) = this(producerId, LogOffsetMetadata(firstOffset)) + + override def toString: String = { + "TxnMetadata(" + + s"producerId=$producerId, " + + s"firstOffset=$firstOffset, " + + s"lastOffset=$lastOffset)" + } +} + +private[log] object ProducerStateEntry { + private[log] val NumBatchesToRetain = 5 + + def empty(producerId: Long) = new ProducerStateEntry(producerId, + batchMetadata = mutable.Queue[BatchMetadata](), + producerEpoch = RecordBatch.NO_PRODUCER_EPOCH, + coordinatorEpoch = -1, + lastTimestamp = RecordBatch.NO_TIMESTAMP, + currentTxnFirstOffset = None) +} + +private[log] case class BatchMetadata(lastSeq: Int, lastOffset: Long, offsetDelta: Int, timestamp: Long) { + def firstSeq: Int = DefaultRecordBatch.decrementSequence(lastSeq, offsetDelta) + def firstOffset: Long = lastOffset - offsetDelta + + override def toString: String = { + "BatchMetadata(" + + s"firstSeq=$firstSeq, " + + s"lastSeq=$lastSeq, " + + s"firstOffset=$firstOffset, " + + s"lastOffset=$lastOffset, " + + s"timestamp=$timestamp)" + } +} + +// the batchMetadata is ordered such that the batch with the lowest sequence is at the head of the queue while the +// batch with the highest sequence is at the tail of the queue. We will retain at most ProducerStateEntry.NumBatchesToRetain +// elements in the queue. When the queue is at capacity, we remove the first element to make space for the incoming batch. 
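+// For illustration (hypothetical sequence numbers): with NumBatchesToRetain = 5, after caching batches ending at
+// sequences 10, 20, 30, 40 and 50, appending a batch ending at sequence 60 evicts the batch ending at sequence 10,
+// so duplicate detection via findDuplicateBatch only covers the five most recent batches for this producer.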
+private[log] class ProducerStateEntry(val producerId: Long, + val batchMetadata: mutable.Queue[BatchMetadata], + var producerEpoch: Short, + var coordinatorEpoch: Int, + var lastTimestamp: Long, + var currentTxnFirstOffset: Option[Long]) { + + def firstSeq: Int = if (isEmpty) RecordBatch.NO_SEQUENCE else batchMetadata.front.firstSeq + + def firstDataOffset: Long = if (isEmpty) -1L else batchMetadata.front.firstOffset + + def lastSeq: Int = if (isEmpty) RecordBatch.NO_SEQUENCE else batchMetadata.last.lastSeq + + def lastDataOffset: Long = if (isEmpty) -1L else batchMetadata.last.lastOffset + + def lastOffsetDelta : Int = if (isEmpty) 0 else batchMetadata.last.offsetDelta + + def isEmpty: Boolean = batchMetadata.isEmpty + + def addBatch(producerEpoch: Short, lastSeq: Int, lastOffset: Long, offsetDelta: Int, timestamp: Long): Unit = { + maybeUpdateProducerEpoch(producerEpoch) + addBatchMetadata(BatchMetadata(lastSeq, lastOffset, offsetDelta, timestamp)) + this.lastTimestamp = timestamp + } + + def maybeUpdateProducerEpoch(producerEpoch: Short): Boolean = { + if (this.producerEpoch != producerEpoch) { + batchMetadata.clear() + this.producerEpoch = producerEpoch + true + } else { + false + } + } + + private def addBatchMetadata(batch: BatchMetadata): Unit = { + if (batchMetadata.size == ProducerStateEntry.NumBatchesToRetain) + batchMetadata.dequeue() + batchMetadata.enqueue(batch) + } + + def update(nextEntry: ProducerStateEntry): Unit = { + maybeUpdateProducerEpoch(nextEntry.producerEpoch) + while (nextEntry.batchMetadata.nonEmpty) + addBatchMetadata(nextEntry.batchMetadata.dequeue()) + this.coordinatorEpoch = nextEntry.coordinatorEpoch + this.currentTxnFirstOffset = nextEntry.currentTxnFirstOffset + this.lastTimestamp = nextEntry.lastTimestamp + } + + def findDuplicateBatch(batch: RecordBatch): Option[BatchMetadata] = { + if (batch.producerEpoch != producerEpoch) + None + else + batchWithSequenceRange(batch.baseSequence, batch.lastSequence) + } + + // Return the batch metadata of the cached batch having the exact sequence range, if any. + def batchWithSequenceRange(firstSeq: Int, lastSeq: Int): Option[BatchMetadata] = { + val duplicate = batchMetadata.filter { metadata => + firstSeq == metadata.firstSeq && lastSeq == metadata.lastSeq + } + duplicate.headOption + } + + override def toString: String = { + "ProducerStateEntry(" + + s"producerId=$producerId, " + + s"producerEpoch=$producerEpoch, " + + s"currentTxnFirstOffset=$currentTxnFirstOffset, " + + s"coordinatorEpoch=$coordinatorEpoch, " + + s"lastTimestamp=$lastTimestamp, " + + s"batchMetadata=$batchMetadata" + } +} + +/** + * This class is used to validate the records appended by a given producer before they are written to the log. + * It is initialized with the producer's state after the last successful append, and transitively validates the + * sequence numbers and epochs of each new record. Additionally, this class accumulates transaction metadata + * as the incoming records are validated. + * + * @param producerId The id of the producer appending to the log + * @param currentEntry The current entry associated with the producer id which contains metadata for a fixed number of + * the most recent appends made by the producer. Validation of the first incoming append will + * be made against the latest append in the current entry. New appends will replace older appends + * in the current entry so that the space overhead is constant. + * @param origin Indicates the origin of the append which implies the extent of validation. 
For example, offset + * commits, which originate from the group coordinator, do not have sequence numbers and therefore + * only producer epoch validation is done. Appends which come through replication are not validated + * (we assume the validation has already been done) and appends from clients require full validation. + */ +private[log] class ProducerAppendInfo(val topicPartition: TopicPartition, + val producerId: Long, + val currentEntry: ProducerStateEntry, + val origin: AppendOrigin) extends Logging { + + private val transactions = ListBuffer.empty[TxnMetadata] + private val updatedEntry = ProducerStateEntry.empty(producerId) + + updatedEntry.producerEpoch = currentEntry.producerEpoch + updatedEntry.coordinatorEpoch = currentEntry.coordinatorEpoch + updatedEntry.lastTimestamp = currentEntry.lastTimestamp + updatedEntry.currentTxnFirstOffset = currentEntry.currentTxnFirstOffset + + private def maybeValidateDataBatch(producerEpoch: Short, firstSeq: Int, offset: Long): Unit = { + checkProducerEpoch(producerEpoch, offset) + if (origin == AppendOrigin.Client) { + checkSequence(producerEpoch, firstSeq, offset) + } + } + + private def checkProducerEpoch(producerEpoch: Short, offset: Long): Unit = { + if (producerEpoch < updatedEntry.producerEpoch) { + val message = s"Epoch of producer $producerId at offset $offset in $topicPartition is $producerEpoch, " + + s"which is smaller than the last seen epoch ${updatedEntry.producerEpoch}" + + if (origin == AppendOrigin.Replication) { + warn(message) + } else { + // Starting from 2.7, we replaced ProducerFenced error with InvalidProducerEpoch in the + // producer send response callback to differentiate from the former fatal exception, + // letting client abort the ongoing transaction and retry. + throw new InvalidProducerEpochException(message) + } + } + } + + private def checkSequence(producerEpoch: Short, appendFirstSeq: Int, offset: Long): Unit = { + if (producerEpoch != updatedEntry.producerEpoch) { + if (appendFirstSeq != 0) { + if (updatedEntry.producerEpoch != RecordBatch.NO_PRODUCER_EPOCH) { + throw new OutOfOrderSequenceException(s"Invalid sequence number for new epoch of producer $producerId " + + s"at offset $offset in partition $topicPartition: $producerEpoch (request epoch), $appendFirstSeq (seq. number), " + + s"${updatedEntry.producerEpoch} (current producer epoch)") + } + } + } else { + val currentLastSeq = if (!updatedEntry.isEmpty) + updatedEntry.lastSeq + else if (producerEpoch == currentEntry.producerEpoch) + currentEntry.lastSeq + else + RecordBatch.NO_SEQUENCE + + // If there is no current producer epoch (possibly because all producer records have been deleted due to + // retention or the DeleteRecords API) accept writes with any sequence number + if (!(currentEntry.producerEpoch == RecordBatch.NO_PRODUCER_EPOCH || inSequence(currentLastSeq, appendFirstSeq))) { + throw new OutOfOrderSequenceException(s"Out of order sequence number for producer $producerId at " + + s"offset $offset in partition $topicPartition: $appendFirstSeq (incoming seq. 
number), " + + s"$currentLastSeq (current end sequence number)") + } + } + } + + private def inSequence(lastSeq: Int, nextSeq: Int): Boolean = { + nextSeq == lastSeq + 1L || (nextSeq == 0 && lastSeq == Int.MaxValue) + } + + def append(batch: RecordBatch, firstOffsetMetadataOpt: Option[LogOffsetMetadata]): Option[CompletedTxn] = { + if (batch.isControlBatch) { + val recordIterator = batch.iterator + if (recordIterator.hasNext) { + val record = recordIterator.next() + val endTxnMarker = EndTransactionMarker.deserialize(record) + appendEndTxnMarker(endTxnMarker, batch.producerEpoch, batch.baseOffset, record.timestamp) + } else { + // An empty control batch means the entire transaction has been cleaned from the log, so no need to append + None + } + } else { + val firstOffsetMetadata = firstOffsetMetadataOpt.getOrElse(LogOffsetMetadata(batch.baseOffset)) + appendDataBatch(batch.producerEpoch, batch.baseSequence, batch.lastSequence, batch.maxTimestamp, + firstOffsetMetadata, batch.lastOffset, batch.isTransactional) + None + } + } + + def appendDataBatch(epoch: Short, + firstSeq: Int, + lastSeq: Int, + lastTimestamp: Long, + firstOffsetMetadata: LogOffsetMetadata, + lastOffset: Long, + isTransactional: Boolean): Unit = { + val firstOffset = firstOffsetMetadata.messageOffset + maybeValidateDataBatch(epoch, firstSeq, firstOffset) + updatedEntry.addBatch(epoch, lastSeq, lastOffset, (lastOffset - firstOffset).toInt, lastTimestamp) + + updatedEntry.currentTxnFirstOffset match { + case Some(_) if !isTransactional => + // Received a non-transactional message while a transaction is active + throw new InvalidTxnStateException(s"Expected transactional write from producer $producerId at " + + s"offset $firstOffsetMetadata in partition $topicPartition") + + case None if isTransactional => + // Began a new transaction + updatedEntry.currentTxnFirstOffset = Some(firstOffset) + transactions += TxnMetadata(producerId, firstOffsetMetadata) + + case _ => // nothing to do + } + } + + private def checkCoordinatorEpoch(endTxnMarker: EndTransactionMarker, offset: Long): Unit = { + if (updatedEntry.coordinatorEpoch > endTxnMarker.coordinatorEpoch) { + if (origin == AppendOrigin.Replication) { + info(s"Detected invalid coordinator epoch for producerId $producerId at " + + s"offset $offset in partition $topicPartition: ${endTxnMarker.coordinatorEpoch} " + + s"is older than previously known coordinator epoch ${updatedEntry.coordinatorEpoch}") + } else { + throw new TransactionCoordinatorFencedException(s"Invalid coordinator epoch for producerId $producerId at " + + s"offset $offset in partition $topicPartition: ${endTxnMarker.coordinatorEpoch} " + + s"(zombie), ${updatedEntry.coordinatorEpoch} (current)") + } + } + } + + def appendEndTxnMarker( + endTxnMarker: EndTransactionMarker, + producerEpoch: Short, + offset: Long, + timestamp: Long + ): Option[CompletedTxn] = { + checkProducerEpoch(producerEpoch, offset) + checkCoordinatorEpoch(endTxnMarker, offset) + + // Only emit the `CompletedTxn` for non-empty transactions. A transaction marker + // without any associated data will not have any impact on the last stable offset + // and would not need to be reflected in the transaction index. 
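+    // For example (hypothetical offsets): a transaction whose data occupies offsets 5..9 followed by an ABORT
+    // marker at offset 10 produces CompletedTxn(producerId, 5, 10, isAborted = true); a marker with no preceding
+    // data for this producer (currentTxnFirstOffset is None) produces no CompletedTxn at all.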
+ val completedTxn = updatedEntry.currentTxnFirstOffset.map { firstOffset => + CompletedTxn(producerId, firstOffset, offset, endTxnMarker.controlType == ControlRecordType.ABORT) + } + + updatedEntry.maybeUpdateProducerEpoch(producerEpoch) + updatedEntry.currentTxnFirstOffset = None + updatedEntry.coordinatorEpoch = endTxnMarker.coordinatorEpoch + updatedEntry.lastTimestamp = timestamp + + completedTxn + } + + def toEntry: ProducerStateEntry = updatedEntry + + def startedTransactions: List[TxnMetadata] = transactions.toList + + override def toString: String = { + "ProducerAppendInfo(" + + s"producerId=$producerId, " + + s"producerEpoch=${updatedEntry.producerEpoch}, " + + s"firstSequence=${updatedEntry.firstSeq}, " + + s"lastSequence=${updatedEntry.lastSeq}, " + + s"currentTxnFirstOffset=${updatedEntry.currentTxnFirstOffset}, " + + s"coordinatorEpoch=${updatedEntry.coordinatorEpoch}, " + + s"lastTimestamp=${updatedEntry.lastTimestamp}, " + + s"startedTransactions=$transactions)" + } +} + +object ProducerStateManager { + private val ProducerSnapshotVersion: Short = 1 + private val VersionField = "version" + private val CrcField = "crc" + private val ProducerIdField = "producer_id" + private val LastSequenceField = "last_sequence" + private val ProducerEpochField = "epoch" + private val LastOffsetField = "last_offset" + private val OffsetDeltaField = "offset_delta" + private val TimestampField = "timestamp" + private val ProducerEntriesField = "producer_entries" + private val CoordinatorEpochField = "coordinator_epoch" + private val CurrentTxnFirstOffsetField = "current_txn_first_offset" + + private val VersionOffset = 0 + private val CrcOffset = VersionOffset + 2 + private val ProducerEntriesOffset = CrcOffset + 4 + + val ProducerSnapshotEntrySchema = new Schema( + new Field(ProducerIdField, Type.INT64, "The producer ID"), + new Field(ProducerEpochField, Type.INT16, "Current epoch of the producer"), + new Field(LastSequenceField, Type.INT32, "Last written sequence of the producer"), + new Field(LastOffsetField, Type.INT64, "Last written offset of the producer"), + new Field(OffsetDeltaField, Type.INT32, "The difference of the last sequence and first sequence in the last written batch"), + new Field(TimestampField, Type.INT64, "Max timestamp from the last written entry"), + new Field(CoordinatorEpochField, Type.INT32, "The epoch of the last transaction coordinator to send an end transaction marker"), + new Field(CurrentTxnFirstOffsetField, Type.INT64, "The first offset of the on-going transaction (-1 if there is none)")) + val PidSnapshotMapSchema = new Schema( + new Field(VersionField, Type.INT16, "Version of the snapshot file"), + new Field(CrcField, Type.UNSIGNED_INT32, "CRC of the snapshot data"), + new Field(ProducerEntriesField, new ArrayOf(ProducerSnapshotEntrySchema), "The entries in the producer table")) + + def readSnapshot(file: File): Iterable[ProducerStateEntry] = { + try { + val buffer = Files.readAllBytes(file.toPath) + val struct = PidSnapshotMapSchema.read(ByteBuffer.wrap(buffer)) + + val version = struct.getShort(VersionField) + if (version != ProducerSnapshotVersion) + throw new CorruptSnapshotException(s"Snapshot contained an unknown file version $version") + + val crc = struct.getUnsignedInt(CrcField) + val computedCrc = Crc32C.compute(buffer, ProducerEntriesOffset, buffer.length - ProducerEntriesOffset) + if (crc != computedCrc) + throw new CorruptSnapshotException(s"Snapshot is corrupt (CRC is no longer valid). " + + s"Stored crc: $crc. 
Computed crc: $computedCrc") + + struct.getArray(ProducerEntriesField).map { producerEntryObj => + val producerEntryStruct = producerEntryObj.asInstanceOf[Struct] + val producerId = producerEntryStruct.getLong(ProducerIdField) + val producerEpoch = producerEntryStruct.getShort(ProducerEpochField) + val seq = producerEntryStruct.getInt(LastSequenceField) + val offset = producerEntryStruct.getLong(LastOffsetField) + val timestamp = producerEntryStruct.getLong(TimestampField) + val offsetDelta = producerEntryStruct.getInt(OffsetDeltaField) + val coordinatorEpoch = producerEntryStruct.getInt(CoordinatorEpochField) + val currentTxnFirstOffset = producerEntryStruct.getLong(CurrentTxnFirstOffsetField) + val lastAppendedDataBatches = mutable.Queue.empty[BatchMetadata] + if (offset >= 0) + lastAppendedDataBatches += BatchMetadata(seq, offset, offsetDelta, timestamp) + + val newEntry = new ProducerStateEntry(producerId, lastAppendedDataBatches, producerEpoch, + coordinatorEpoch, timestamp, if (currentTxnFirstOffset >= 0) Some(currentTxnFirstOffset) else None) + newEntry + } + } catch { + case e: SchemaException => + throw new CorruptSnapshotException(s"Snapshot failed schema validation: ${e.getMessage}") + } + } + + private def writeSnapshot(file: File, entries: mutable.Map[Long, ProducerStateEntry]): Unit = { + val struct = new Struct(PidSnapshotMapSchema) + struct.set(VersionField, ProducerSnapshotVersion) + struct.set(CrcField, 0L) // we'll fill this after writing the entries + val entriesArray = entries.map { + case (producerId, entry) => + val producerEntryStruct = struct.instance(ProducerEntriesField) + producerEntryStruct.set(ProducerIdField, producerId) + .set(ProducerEpochField, entry.producerEpoch) + .set(LastSequenceField, entry.lastSeq) + .set(LastOffsetField, entry.lastDataOffset) + .set(OffsetDeltaField, entry.lastOffsetDelta) + .set(TimestampField, entry.lastTimestamp) + .set(CoordinatorEpochField, entry.coordinatorEpoch) + .set(CurrentTxnFirstOffsetField, entry.currentTxnFirstOffset.getOrElse(-1L)) + producerEntryStruct + }.toArray + struct.set(ProducerEntriesField, entriesArray) + + val buffer = ByteBuffer.allocate(struct.sizeOf) + struct.writeTo(buffer) + buffer.flip() + + // now fill in the CRC + val crc = Crc32C.compute(buffer, ProducerEntriesOffset, buffer.limit() - ProducerEntriesOffset) + ByteUtils.writeUnsignedInt(buffer, CrcOffset, crc) + + val fileChannel = FileChannel.open(file.toPath, StandardOpenOption.CREATE, StandardOpenOption.WRITE) + try { + fileChannel.write(buffer) + fileChannel.force(true) + } finally { + fileChannel.close() + } + } + + private def isSnapshotFile(file: File): Boolean = file.getName.endsWith(UnifiedLog.ProducerSnapshotFileSuffix) + + // visible for testing + private[log] def listSnapshotFiles(dir: File): Seq[SnapshotFile] = { + if (dir.exists && dir.isDirectory) { + Option(dir.listFiles).map { files => + files.filter(f => f.isFile && isSnapshotFile(f)).map(SnapshotFile(_)).toSeq + }.getOrElse(Seq.empty) + } else Seq.empty + } +} + +/** + * Maintains a mapping from ProducerIds to metadata about the last appended entries (e.g. + * epoch, sequence number, last offset, etc.) + * + * The sequence number is the last number successfully appended to the partition for the given identifier. + * The epoch is used for fencing against zombie writers. The offset is the one of the last successful message + * appended to the partition. + * + * As long as a producer id is contained in the map, the corresponding producer can continue to write data. 
+ * However, producer ids can be expired due to lack of recent use or if the last written entry has been deleted from + * the log (e.g. if the retention policy is "delete"). For compacted topics, the log cleaner will ensure + * that the most recent entry from a given producer id is retained in the log provided it hasn't expired due to + * age. This ensures that producer ids will not be expired until either the max expiration time has been reached, + * or if the topic also is configured for deletion, the segment containing the last written offset has + * been deleted. + */ +@nonthreadsafe +class ProducerStateManager(val topicPartition: TopicPartition, + @volatile var _logDir: File, + val maxProducerIdExpirationMs: Int = 60 * 60 * 1000, + val time: Time = Time.SYSTEM) extends Logging { + import ProducerStateManager._ + import java.util + + this.logIdent = s"[ProducerStateManager partition=$topicPartition] " + + private var snapshots: ConcurrentSkipListMap[java.lang.Long, SnapshotFile] = locally { + loadSnapshots() + } + + private val producers = mutable.Map.empty[Long, ProducerStateEntry] + private var lastMapOffset = 0L + private var lastSnapOffset = 0L + + // ongoing transactions sorted by the first offset of the transaction + private val ongoingTxns = new util.TreeMap[Long, TxnMetadata] + + // completed transactions whose markers are at offsets above the high watermark + private val unreplicatedTxns = new util.TreeMap[Long, TxnMetadata] + + /** + * Load producer state snapshots by scanning the _logDir. + */ + private def loadSnapshots(): ConcurrentSkipListMap[java.lang.Long, SnapshotFile] = { + val tm = new ConcurrentSkipListMap[java.lang.Long, SnapshotFile]() + for (f <- listSnapshotFiles(_logDir)) { + tm.put(f.offset, f) + } + tm + } + + /** + * Scans the log directory, gathering all producer state snapshot files. Snapshot files which do not have an offset + * corresponding to one of the provided offsets in segmentBaseOffsets will be removed, except in the case that there + * is a snapshot file at a higher offset than any offset in segmentBaseOffsets. + * + * The goal here is to remove any snapshot files which do not have an associated segment file, but not to remove the + * largest stray snapshot file which was emitted during clean shutdown. + */ + private[log] def removeStraySnapshots(segmentBaseOffsets: Seq[Long]): Unit = { + val maxSegmentBaseOffset = if (segmentBaseOffsets.isEmpty) None else Some(segmentBaseOffsets.max) + val baseOffsets = segmentBaseOffsets.toSet + var latestStraySnapshot: Option[SnapshotFile] = None + + val ss = loadSnapshots() + for (snapshot <- ss.values().asScala) { + val key = snapshot.offset + latestStraySnapshot match { + case Some(prev) => + if (!baseOffsets.contains(key)) { + // this snapshot is now the largest stray snapshot. + prev.deleteIfExists() + ss.remove(prev.offset) + latestStraySnapshot = Some(snapshot) + } + case None => + if (!baseOffsets.contains(key)) { + latestStraySnapshot = Some(snapshot) + } + } + } + + // Check to see if the latestStraySnapshot is larger than the largest segment base offset, if it is not, + // delete the largestStraySnapshot. + for (strayOffset <- latestStraySnapshot.map(_.offset); maxOffset <- maxSegmentBaseOffset) { + if (strayOffset < maxOffset) { + Option(ss.remove(strayOffset)).foreach(_.deleteIfExists()) + } + } + + this.snapshots = ss + } + + /** + * An unstable offset is one which is either undecided (i.e. its ultimate outcome is not yet known), + * or one that is decided, but may not have been replicated (i.e. 
any transaction which has a COMMIT/ABORT + * marker written at a higher offset than the current high watermark). + */ + def firstUnstableOffset: Option[LogOffsetMetadata] = { + val unreplicatedFirstOffset = Option(unreplicatedTxns.firstEntry).map(_.getValue.firstOffset) + val undecidedFirstOffset = Option(ongoingTxns.firstEntry).map(_.getValue.firstOffset) + if (unreplicatedFirstOffset.isEmpty) + undecidedFirstOffset + else if (undecidedFirstOffset.isEmpty) + unreplicatedFirstOffset + else if (undecidedFirstOffset.get.messageOffset < unreplicatedFirstOffset.get.messageOffset) + undecidedFirstOffset + else + unreplicatedFirstOffset + } + + /** + * Acknowledge all transactions which have been completed before a given offset. This allows the LSO + * to advance to the next unstable offset. + */ + def onHighWatermarkUpdated(highWatermark: Long): Unit = { + removeUnreplicatedTransactions(highWatermark) + } + + /** + * The first undecided offset is the earliest transactional message which has not yet been committed + * or aborted. Unlike [[firstUnstableOffset]], this does not reflect the state of replication (i.e. + * whether a completed transaction marker is beyond the high watermark). + */ + private[log] def firstUndecidedOffset: Option[Long] = Option(ongoingTxns.firstEntry).map(_.getValue.firstOffset.messageOffset) + + /** + * Returns the last offset of this map + */ + def mapEndOffset: Long = lastMapOffset + + /** + * Get a copy of the active producers + */ + def activeProducers: immutable.Map[Long, ProducerStateEntry] = producers.toMap + + def isEmpty: Boolean = producers.isEmpty && unreplicatedTxns.isEmpty + + private def loadFromSnapshot(logStartOffset: Long, currentTime: Long): Unit = { + while (true) { + latestSnapshotFile match { + case Some(snapshot) => + try { + info(s"Loading producer state from snapshot file '$snapshot'") + val loadedProducers = readSnapshot(snapshot.file).filter { producerEntry => !isProducerExpired(currentTime, producerEntry) } + loadedProducers.foreach(loadProducerEntry) + lastSnapOffset = snapshot.offset + lastMapOffset = lastSnapOffset + return + } catch { + case e: CorruptSnapshotException => + warn(s"Failed to load producer snapshot from '${snapshot.file}': ${e.getMessage}") + removeAndDeleteSnapshot(snapshot.offset) + } + case None => + lastSnapOffset = logStartOffset + lastMapOffset = logStartOffset + return + } + } + } + + // visible for testing + private[log] def loadProducerEntry(entry: ProducerStateEntry): Unit = { + val producerId = entry.producerId + producers.put(producerId, entry) + entry.currentTxnFirstOffset.foreach { offset => + ongoingTxns.put(offset, new TxnMetadata(producerId, offset)) + } + } + + private def isProducerExpired(currentTimeMs: Long, producerState: ProducerStateEntry): Boolean = + producerState.currentTxnFirstOffset.isEmpty && currentTimeMs - producerState.lastTimestamp >= maxProducerIdExpirationMs + + /** + * Expire any producer ids which have been idle longer than the configured maximum expiration timeout. + */ + def removeExpiredProducers(currentTimeMs: Long): Unit = { + producers --= producers.filter { case (_, lastEntry) => isProducerExpired(currentTimeMs, lastEntry) }.keySet + } + + /** + * Truncate the producer id mapping to the given offset range and reload the entries from the most recent + * snapshot in range (if there is one). We delete snapshot files prior to the logStartOffset but do not remove + * producer state from the map. 
This means that in-memory and on-disk state can diverge, and in the case of + * broker failover or unclean shutdown, any in-memory state not persisted in the snapshots will be lost, which + * would lead to UNKNOWN_PRODUCER_ID errors. Note that the log end offset is assumed to be less than or equal + * to the high watermark. + */ + def truncateAndReload(logStartOffset: Long, logEndOffset: Long, currentTimeMs: Long): Unit = { + // remove all out of range snapshots + snapshots.values().asScala.foreach { snapshot => + if (snapshot.offset > logEndOffset || snapshot.offset <= logStartOffset) { + removeAndDeleteSnapshot(snapshot.offset) + } + } + + if (logEndOffset != mapEndOffset) { + producers.clear() + ongoingTxns.clear() + + // since we assume that the offset is less than or equal to the high watermark, it is + // safe to clear the unreplicated transactions + unreplicatedTxns.clear() + loadFromSnapshot(logStartOffset, currentTimeMs) + } else { + onLogStartOffsetIncremented(logStartOffset) + } + } + + def prepareUpdate(producerId: Long, origin: AppendOrigin): ProducerAppendInfo = { + val currentEntry = lastEntry(producerId).getOrElse(ProducerStateEntry.empty(producerId)) + new ProducerAppendInfo(topicPartition, producerId, currentEntry, origin) + } + + /** + * Update the mapping with the given append information + */ + def update(appendInfo: ProducerAppendInfo): Unit = { + if (appendInfo.producerId == RecordBatch.NO_PRODUCER_ID) + throw new IllegalArgumentException(s"Invalid producer id ${appendInfo.producerId} passed to update " + + s"for partition $topicPartition") + + trace(s"Updated producer ${appendInfo.producerId} state to $appendInfo") + val updatedEntry = appendInfo.toEntry + producers.get(appendInfo.producerId) match { + case Some(currentEntry) => + currentEntry.update(updatedEntry) + + case None => + producers.put(appendInfo.producerId, updatedEntry) + } + + appendInfo.startedTransactions.foreach { txn => + ongoingTxns.put(txn.firstOffset.messageOffset, txn) + } + } + + def updateMapEndOffset(lastOffset: Long): Unit = { + lastMapOffset = lastOffset + } + + /** + * Get the last written entry for the given producer id. + */ + def lastEntry(producerId: Long): Option[ProducerStateEntry] = producers.get(producerId) + + /** + * Take a snapshot at the current end offset if one does not already exist. + */ + def takeSnapshot(): Unit = { + // If not a new offset, then it is not worth taking another snapshot + if (lastMapOffset > lastSnapOffset) { + val snapshotFile = SnapshotFile(UnifiedLog.producerSnapshotFile(_logDir, lastMapOffset)) + val start = time.hiResClockMs() + writeSnapshot(snapshotFile.file, producers) + info(s"Wrote producer snapshot at offset $lastMapOffset with ${producers.size} producer ids in ${time.hiResClockMs() - start} ms.") + + snapshots.put(snapshotFile.offset, snapshotFile) + + // Update the last snap offset according to the serialized map + lastSnapOffset = lastMapOffset + } + } + + /** + * Update the parentDir for this ProducerStateManager and all of the snapshot files which it manages. + */ + def updateParentDir(parentDir: File): Unit = { + _logDir = parentDir + snapshots.forEach((_, s) => s.updateParentDir(parentDir)) + } + + /** + * Get the last offset (exclusive) of the latest snapshot file. + */ + def latestSnapshotOffset: Option[Long] = latestSnapshotFile.map(_.offset) + + /** + * Get the last offset (exclusive) of the oldest snapshot file. 
+ */ + def oldestSnapshotOffset: Option[Long] = oldestSnapshotFile.map(_.offset) + + /** + * Visible for testing + */ + private[log] def snapshotFileForOffset(offset: Long): Option[SnapshotFile] = { + Option(snapshots.get(offset)) + } + + /** + * Remove any unreplicated transactions lower than the provided logStartOffset and bring the lastMapOffset forward + * if necessary. + */ + def onLogStartOffsetIncremented(logStartOffset: Long): Unit = { + removeUnreplicatedTransactions(logStartOffset) + + if (lastMapOffset < logStartOffset) + lastMapOffset = logStartOffset + + lastSnapOffset = latestSnapshotOffset.getOrElse(logStartOffset) + } + + private def removeUnreplicatedTransactions(offset: Long): Unit = { + val iterator = unreplicatedTxns.entrySet.iterator + while (iterator.hasNext) { + val txnEntry = iterator.next() + val lastOffset = txnEntry.getValue.lastOffset + if (lastOffset.exists(_ < offset)) + iterator.remove() + } + } + + /** + * Truncate the producer id mapping and remove all snapshots. This resets the state of the mapping. + */ + def truncateFullyAndStartAt(offset: Long): Unit = { + producers.clear() + ongoingTxns.clear() + unreplicatedTxns.clear() + snapshots.values().asScala.foreach { snapshot => + removeAndDeleteSnapshot(snapshot.offset) + } + lastSnapOffset = 0L + lastMapOffset = offset + } + + /** + * Compute the last stable offset of a completed transaction, but do not yet mark the transaction complete. + * That will be done in `completeTxn` below. This is used to compute the LSO that will be appended to the + * transaction index, but the completion must be done only after successfully appending to the index. + */ + def lastStableOffset(completedTxn: CompletedTxn): Long = { + val nextIncompleteTxn = ongoingTxns.values.asScala.find(_.producerId != completedTxn.producerId) + nextIncompleteTxn.map(_.firstOffset.messageOffset).getOrElse(completedTxn.lastOffset + 1) + } + + /** + * Mark a transaction as completed. We will still await advancement of the high watermark before + * advancing the first unstable offset. + */ + def completeTxn(completedTxn: CompletedTxn): Unit = { + val txnMetadata = ongoingTxns.remove(completedTxn.firstOffset) + if (txnMetadata == null) + throw new IllegalArgumentException(s"Attempted to complete transaction $completedTxn on partition $topicPartition " + + s"which was not started") + + txnMetadata.lastOffset = Some(completedTxn.lastOffset) + unreplicatedTxns.put(completedTxn.firstOffset, txnMetadata) + } + + @threadsafe + def deleteSnapshotsBefore(offset: Long): Unit = { + snapshots.subMap(0, offset).values().asScala.foreach { snapshot => + removeAndDeleteSnapshot(snapshot.offset) + } + } + + private def oldestSnapshotFile: Option[SnapshotFile] = { + Option(snapshots.firstEntry()).map(_.getValue) + } + + private def latestSnapshotFile: Option[SnapshotFile] = { + Option(snapshots.lastEntry()).map(_.getValue) + } + + /** + * Removes the producer state snapshot file metadata corresponding to the provided offset if it exists from this + * ProducerStateManager, and deletes the backing snapshot file. + */ + private def removeAndDeleteSnapshot(snapshotOffset: Long): Unit = { + Option(snapshots.remove(snapshotOffset)).foreach(_.deleteIfExists()) + } + + /** + * Removes the producer state snapshot file metadata corresponding to the provided offset if it exists from this + * ProducerStateManager, and renames the backing snapshot file to have the Log.DeletionSuffix. + * + * Note: This method is safe to use with async deletes. 
If a race occurs and the snapshot file + * is deleted without this ProducerStateManager instance knowing, the resulting exception on + * SnapshotFile rename will be ignored and None will be returned. + */ + private[log] def removeAndMarkSnapshotForDeletion(snapshotOffset: Long): Option[SnapshotFile] = { + Option(snapshots.remove(snapshotOffset)).flatMap { snapshot => { + // If the file cannot be renamed, it likely means that the file was deleted already. + // This can happen due to the way we construct an intermediate producer state manager + // during log recovery, and use it to issue deletions prior to creating the "real" + // producer state manager. + // + // In any case, removeAndMarkSnapshotForDeletion is intended to be used for snapshot file + // deletion, so ignoring the exception here just means that the intended operation was + // already completed. + try { + snapshot.renameTo(UnifiedLog.DeletedFileSuffix) + Some(snapshot) + } catch { + case _: NoSuchFileException => + info(s"Failed to rename producer state snapshot ${snapshot.file.getAbsoluteFile} with deletion suffix because it was already deleted") + None + } + } + } + } +} + +case class SnapshotFile private[log] (@volatile private var _file: File, + offset: Long) extends Logging { + def deleteIfExists(): Boolean = { + val deleted = Files.deleteIfExists(file.toPath) + if (deleted) { + info(s"Deleted producer state snapshot ${file.getAbsolutePath}") + } else { + info(s"Failed to delete producer state snapshot ${file.getAbsolutePath} because it does not exist.") + } + deleted + } + + def updateParentDir(parentDir: File): Unit = { + _file = new File(parentDir, _file.getName) + } + + def file: File = { + _file + } + + def renameTo(newSuffix: String): Unit = { + val renamed = new File(CoreUtils.replaceSuffix(_file.getPath, "", newSuffix)) + try { + Utils.atomicMoveWithFallback(_file.toPath, renamed.toPath) + } finally { + _file = renamed + } + } +} + +object SnapshotFile { + def apply(file: File): SnapshotFile = { + val offset = offsetFromFile(file) + SnapshotFile(file, offset) + } +} diff --git a/core/src/main/scala/kafka/log/TimeIndex.scala b/core/src/main/scala/kafka/log/TimeIndex.scala new file mode 100644 index 0000000..779a451 --- /dev/null +++ b/core/src/main/scala/kafka/log/TimeIndex.scala @@ -0,0 +1,229 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.log + +import java.io.File +import java.nio.ByteBuffer + +import kafka.utils.CoreUtils.inLock +import kafka.utils.Logging +import org.apache.kafka.common.errors.InvalidOffsetException +import org.apache.kafka.common.record.RecordBatch + +/** + * An index that maps from the timestamp to the logical offsets of the messages in a segment. This index might be + * sparse, i.e. 
it may not hold an entry for all the messages in the segment. + * + * The index is stored in a file that is preallocated to hold a fixed maximum amount of 12-byte time index entries. + * The file format is a series of time index entries. The physical format is a 8 bytes timestamp and a 4 bytes "relative" + * offset used in the [[OffsetIndex]]. A time index entry (TIMESTAMP, OFFSET) means that the biggest timestamp seen + * before OFFSET is TIMESTAMP. i.e. Any message whose timestamp is greater than TIMESTAMP must come after OFFSET. + * + * All external APIs translate from relative offsets to full offsets, so users of this class do not interact with the internal + * storage format. + * + * The timestamps in the same time index file are guaranteed to be monotonically increasing. + * + * The index supports timestamp lookup for a memory map of this file. The lookup is done using a binary search to find + * the offset of the message whose indexed timestamp is closest but smaller or equals to the target timestamp. + * + * Time index files can be opened in two ways: either as an empty, mutable index that allows appending or + * an immutable read-only index file that has previously been populated. The makeReadOnly method will turn a mutable file into an + * immutable one and truncate off any extra bytes. This is done when the index file is rolled over. + * + * No attempt is made to checksum the contents of this file, in the event of a crash it is rebuilt. + * + */ +// Avoid shadowing mutable file in AbstractIndex +class TimeIndex(_file: File, baseOffset: Long, maxIndexSize: Int = -1, writable: Boolean = true) + extends AbstractIndex(_file, baseOffset, maxIndexSize, writable) { + import TimeIndex._ + + @volatile private var _lastEntry = lastEntryFromIndexFile + + override def entrySize = 12 + + debug(s"Loaded index file ${file.getAbsolutePath} with maxEntries = $maxEntries, maxIndexSize = $maxIndexSize," + + s" entries = ${_entries}, lastOffset = ${_lastEntry}, file position = ${mmap.position()}") + + // We override the full check to reserve the last time index entry slot for the on roll call. + override def isFull: Boolean = entries >= maxEntries - 1 + + private def timestamp(buffer: ByteBuffer, n: Int): Long = buffer.getLong(n * entrySize) + + private def relativeOffset(buffer: ByteBuffer, n: Int): Int = buffer.getInt(n * entrySize + 8) + + def lastEntry: TimestampOffset = _lastEntry + + /** + * Read the last entry from the index file. This operation involves disk access. + */ + private def lastEntryFromIndexFile: TimestampOffset = { + inLock(lock) { + _entries match { + case 0 => TimestampOffset(RecordBatch.NO_TIMESTAMP, baseOffset) + case s => parseEntry(mmap, s - 1) + } + } + } + + /** + * Get the nth timestamp mapping from the time index + * @param n The entry number in the time index + * @return The timestamp/offset pair at that entry + */ + def entry(n: Int): TimestampOffset = { + maybeLock(lock) { + if(n >= _entries) + throw new IllegalArgumentException(s"Attempt to fetch the ${n}th entry from time index ${file.getAbsolutePath} " + + s"which has size ${_entries}.") + parseEntry(mmap, n) + } + } + + override def parseEntry(buffer: ByteBuffer, n: Int): TimestampOffset = { + TimestampOffset(timestamp(buffer, n), baseOffset + relativeOffset(buffer, n)) + } + + /** + * Attempt to append a time index entry to the time index. + * The new entry is appended only if both the timestamp and offset are greater than the last appended timestamp and + * the last appended offset. 
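+   * For illustration (hypothetical values): with baseOffset = 100, appending (timestamp = 1000L, offset = 150)
+   * writes the 8-byte timestamp 1000 followed by the 4-byte relative offset 50; a subsequent call with the same
+   * timestamp and a larger offset is silently ignored because only strictly greater timestamps are appended.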
+ * + * @param timestamp The timestamp of the new time index entry + * @param offset The offset of the new time index entry + * @param skipFullCheck To skip checking whether the segment is full or not. We only skip the check when the segment + * gets rolled or the segment is closed. + */ + def maybeAppend(timestamp: Long, offset: Long, skipFullCheck: Boolean = false): Unit = { + inLock(lock) { + if (!skipFullCheck) + require(!isFull, "Attempt to append to a full time index (size = " + _entries + ").") + // We do not throw exception when the offset equals to the offset of last entry. That means we are trying + // to insert the same time index entry as the last entry. + // If the timestamp index entry to be inserted is the same as the last entry, we simply ignore the insertion + // because that could happen in the following two scenarios: + // 1. A log segment is closed. + // 2. LogSegment.onBecomeInactiveSegment() is called when an active log segment is rolled. + if (_entries != 0 && offset < lastEntry.offset) + throw new InvalidOffsetException(s"Attempt to append an offset ($offset) to slot ${_entries} no larger than" + + s" the last offset appended (${lastEntry.offset}) to ${file.getAbsolutePath}.") + if (_entries != 0 && timestamp < lastEntry.timestamp) + throw new IllegalStateException(s"Attempt to append a timestamp ($timestamp) to slot ${_entries} no larger" + + s" than the last timestamp appended (${lastEntry.timestamp}) to ${file.getAbsolutePath}.") + // We only append to the time index when the timestamp is greater than the last inserted timestamp. + // If all the messages are in message format v0, the timestamp will always be NoTimestamp. In that case, the time + // index will be empty. + if (timestamp > lastEntry.timestamp) { + trace(s"Adding index entry $timestamp => $offset to ${file.getAbsolutePath}.") + mmap.putLong(timestamp) + mmap.putInt(relativeOffset(offset)) + _entries += 1 + _lastEntry = TimestampOffset(timestamp, offset) + require(_entries * entrySize == mmap.position(), s"${_entries} entries but file position in index is ${mmap.position()}.") + } + } + } + + /** + * Find the time index entry whose timestamp is less than or equal to the given timestamp. + * If the target timestamp is smaller than the least timestamp in the time index, (NoTimestamp, baseOffset) is + * returned. + * + * @param targetTimestamp The timestamp to look up. + * @return The time index entry found. + */ + def lookup(targetTimestamp: Long): TimestampOffset = { + maybeLock(lock) { + val idx = mmap.duplicate + val slot = largestLowerBoundSlotFor(idx, targetTimestamp, IndexSearchType.KEY) + if (slot == -1) + TimestampOffset(RecordBatch.NO_TIMESTAMP, baseOffset) + else + parseEntry(idx, slot) + } + } + + override def truncate() = truncateToEntries(0) + + /** + * Remove all entries from the index which have an offset greater than or equal to the given offset. + * Truncating to an offset larger than the largest in the index has no effect. 
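+   * For example (hypothetical offsets): if the index holds entries for offsets 100 and 200, truncating to 150
+   * removes the entry for 200 and keeps the one for 100, truncating to 200 removes the entry for 200 as well,
+   * and truncating to 250 leaves the index unchanged.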
+ */ + override def truncateTo(offset: Long): Unit = { + inLock(lock) { + val idx = mmap.duplicate + val slot = largestLowerBoundSlotFor(idx, offset, IndexSearchType.VALUE) + + /* There are 3 cases for choosing the new size + * 1) if there is no entry in the index <= the offset, delete everything + * 2) if there is an entry for this exact offset, delete it and everything larger than it + * 3) if there is no entry for this offset, delete everything larger than the next smallest + */ + val newEntries = + if(slot < 0) + 0 + else if(relativeOffset(idx, slot) == offset - baseOffset) + slot + else + slot + 1 + truncateToEntries(newEntries) + } + } + + override def resize(newSize: Int): Boolean = { + inLock(lock) { + if (super.resize(newSize)) { + _lastEntry = lastEntryFromIndexFile + true + } else + false + } + } + + /** + * Truncates index to a known number of entries. + */ + private def truncateToEntries(entries: Int): Unit = { + inLock(lock) { + _entries = entries + mmap.position(_entries * entrySize) + _lastEntry = lastEntryFromIndexFile + debug(s"Truncated index ${file.getAbsolutePath} to $entries entries; position is now ${mmap.position()} and last entry is now ${_lastEntry}") + } + } + + override def sanityCheck(): Unit = { + val lastTimestamp = lastEntry.timestamp + val lastOffset = lastEntry.offset + if (_entries != 0 && lastTimestamp < timestamp(mmap, 0)) + throw new CorruptIndexException(s"Corrupt time index found, time index file (${file.getAbsolutePath}) has " + + s"non-zero size but the last timestamp is $lastTimestamp which is less than the first timestamp " + + s"${timestamp(mmap, 0)}") + if (_entries != 0 && lastOffset < baseOffset) + throw new CorruptIndexException(s"Corrupt time index found, time index file (${file.getAbsolutePath}) has " + + s"non-zero size but the last offset is $lastOffset which is less than the first offset $baseOffset") + if (length % entrySize != 0) + throw new CorruptIndexException(s"Time index file ${file.getAbsolutePath} is corrupt, found $length bytes " + + s"which is neither positive nor a multiple of $entrySize.") + } +} + +object TimeIndex extends Logging { + override val loggerName: String = classOf[TimeIndex].getName +} diff --git a/core/src/main/scala/kafka/log/TransactionIndex.scala b/core/src/main/scala/kafka/log/TransactionIndex.scala new file mode 100644 index 0000000..ca3d1bb --- /dev/null +++ b/core/src/main/scala/kafka/log/TransactionIndex.scala @@ -0,0 +1,263 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package kafka.log + +import java.io.{File, IOException} +import java.nio.ByteBuffer +import java.nio.channels.FileChannel +import java.nio.file.{Files, StandardOpenOption} +import kafka.utils.{Logging, nonthreadsafe} +import org.apache.kafka.common.KafkaException +import org.apache.kafka.common.message.FetchResponseData +import org.apache.kafka.common.utils.Utils + +import scala.collection.mutable.ListBuffer + +private[log] case class TxnIndexSearchResult(abortedTransactions: List[AbortedTxn], isComplete: Boolean) + +/** + * The transaction index maintains metadata about the aborted transactions for each segment. This includes + * the start and end offsets for the aborted transactions and the last stable offset (LSO) at the time of + * the abort. This index is used to find the aborted transactions in the range of a given fetch request at + * the READ_COMMITTED isolation level. + * + * There is at most one transaction index for each log segment. The entries correspond to the transactions + * whose commit markers were written in the corresponding log segment. Note, however, that individual transactions + * may span multiple segments. Recovering the index therefore requires scanning the earlier segments in + * order to find the start of the transactions. + */ +@nonthreadsafe +class TransactionIndex(val startOffset: Long, @volatile private var _file: File) extends Logging { + + // note that the file is not created until we need it + @volatile private var maybeChannel: Option[FileChannel] = None + private var lastOffset: Option[Long] = None + + if (_file.exists) + openChannel() + + def append(abortedTxn: AbortedTxn): Unit = { + lastOffset.foreach { offset => + if (offset >= abortedTxn.lastOffset) + throw new IllegalArgumentException(s"The last offset of appended transactions must increase sequentially, but " + + s"${abortedTxn.lastOffset} is not greater than current last offset $offset of index ${file.getAbsolutePath}") + } + lastOffset = Some(abortedTxn.lastOffset) + Utils.writeFully(channel(), abortedTxn.buffer.duplicate()) + } + + def flush(): Unit = maybeChannel.foreach(_.force(true)) + + def file: File = _file + + def updateParentDir(parentDir: File): Unit = _file = new File(parentDir, file.getName) + + /** + * Delete this index. + * + * @throws IOException if deletion fails due to an I/O error + * @return `true` if the file was deleted by this method; `false` if the file could not be deleted because it did + * not exist + */ + def deleteIfExists(): Boolean = { + close() + Files.deleteIfExists(file.toPath) + } + + private def channel(): FileChannel = { + maybeChannel match { + case Some(channel) => channel + case None => openChannel() + } + } + + private def openChannel(): FileChannel = { + val channel = FileChannel.open(file.toPath, StandardOpenOption.CREATE, StandardOpenOption.READ, + StandardOpenOption.WRITE) + maybeChannel = Some(channel) + channel.position(channel.size) + channel + } + + /** + * Remove all the entries from the index. Unlike `AbstractIndex`, this index is not resized ahead of time. 
+ */ + def reset(): Unit = { + maybeChannel.foreach(_.truncate(0)) + lastOffset = None + } + + def close(): Unit = { + maybeChannel.foreach(_.close()) + maybeChannel = None + } + + def renameTo(f: File): Unit = { + try { + if (file.exists) + Utils.atomicMoveWithFallback(file.toPath, f.toPath, false) + } finally _file = f + } + + def truncateTo(offset: Long): Unit = { + val buffer = ByteBuffer.allocate(AbortedTxn.TotalSize) + var newLastOffset: Option[Long] = None + for ((abortedTxn, position) <- iterator(() => buffer)) { + if (abortedTxn.lastOffset >= offset) { + channel().truncate(position) + lastOffset = newLastOffset + return + } + newLastOffset = Some(abortedTxn.lastOffset) + } + } + + private def iterator(allocate: () => ByteBuffer = () => ByteBuffer.allocate(AbortedTxn.TotalSize)): Iterator[(AbortedTxn, Int)] = { + maybeChannel match { + case None => Iterator.empty + case Some(channel) => + var position = 0 + + new Iterator[(AbortedTxn, Int)] { + override def hasNext: Boolean = channel.position - position >= AbortedTxn.TotalSize + + override def next(): (AbortedTxn, Int) = { + try { + val buffer = allocate() + Utils.readFully(channel, buffer, position) + buffer.flip() + + val abortedTxn = new AbortedTxn(buffer) + if (abortedTxn.version > AbortedTxn.CurrentVersion) + throw new KafkaException(s"Unexpected aborted transaction version ${abortedTxn.version} " + + s"in transaction index ${file.getAbsolutePath}, current version is ${AbortedTxn.CurrentVersion}") + val nextEntry = (abortedTxn, position) + position += AbortedTxn.TotalSize + nextEntry + } catch { + case e: IOException => + // We received an unexpected error reading from the index file. We propagate this as an + // UNKNOWN error to the consumer, which will cause it to retry the fetch. + throw new KafkaException(s"Failed to read from the transaction index ${file.getAbsolutePath}", e) + } + } + } + } + } + + def allAbortedTxns: List[AbortedTxn] = { + iterator().map(_._1).toList + } + + /** + * Collect all aborted transactions which overlap with a given fetch range. + * + * @param fetchOffset Inclusive first offset of the fetch range + * @param upperBoundOffset Exclusive last offset in the fetch range + * @return An object containing the aborted transactions and whether the search needs to continue + * into the next log segment. + */ + def collectAbortedTxns(fetchOffset: Long, upperBoundOffset: Long): TxnIndexSearchResult = { + val abortedTransactions = ListBuffer.empty[AbortedTxn] + for ((abortedTxn, _) <- iterator()) { + if (abortedTxn.lastOffset >= fetchOffset && abortedTxn.firstOffset < upperBoundOffset) + abortedTransactions += abortedTxn + + if (abortedTxn.lastStableOffset >= upperBoundOffset) + return TxnIndexSearchResult(abortedTransactions.toList, isComplete = true) + } + TxnIndexSearchResult(abortedTransactions.toList, isComplete = false) + } + + /** + * Do a basic sanity check on this index to detect obvious problems. + * + * @throws CorruptIndexException if any problems are found. 
+ */ + def sanityCheck(): Unit = { + val buffer = ByteBuffer.allocate(AbortedTxn.TotalSize) + for ((abortedTxn, _) <- iterator(() => buffer)) { + if (abortedTxn.lastOffset < startOffset) + throw new CorruptIndexException(s"Last offset of aborted transaction $abortedTxn in index " + + s"${file.getAbsolutePath} is less than start offset $startOffset") + } + } + +} + +private[log] object AbortedTxn { + val VersionOffset = 0 + val VersionSize = 2 + val ProducerIdOffset = VersionOffset + VersionSize + val ProducerIdSize = 8 + val FirstOffsetOffset = ProducerIdOffset + ProducerIdSize + val FirstOffsetSize = 8 + val LastOffsetOffset = FirstOffsetOffset + FirstOffsetSize + val LastOffsetSize = 8 + val LastStableOffsetOffset = LastOffsetOffset + LastOffsetSize + val LastStableOffsetSize = 8 + val TotalSize = LastStableOffsetOffset + LastStableOffsetSize + + val CurrentVersion: Short = 0 +} + +private[log] class AbortedTxn(val buffer: ByteBuffer) { + import AbortedTxn._ + + def this(producerId: Long, + firstOffset: Long, + lastOffset: Long, + lastStableOffset: Long) = { + this(ByteBuffer.allocate(AbortedTxn.TotalSize)) + buffer.putShort(CurrentVersion) + buffer.putLong(producerId) + buffer.putLong(firstOffset) + buffer.putLong(lastOffset) + buffer.putLong(lastStableOffset) + buffer.flip() + } + + def this(completedTxn: CompletedTxn, lastStableOffset: Long) = + this(completedTxn.producerId, completedTxn.firstOffset, completedTxn.lastOffset, lastStableOffset) + + def version: Short = buffer.get(VersionOffset) + + def producerId: Long = buffer.getLong(ProducerIdOffset) + + def firstOffset: Long = buffer.getLong(FirstOffsetOffset) + + def lastOffset: Long = buffer.getLong(LastOffsetOffset) + + def lastStableOffset: Long = buffer.getLong(LastStableOffsetOffset) + + def asAbortedTransaction: FetchResponseData.AbortedTransaction = new FetchResponseData.AbortedTransaction() + .setProducerId(producerId) + .setFirstOffset(firstOffset) + + override def toString: String = + s"AbortedTxn(version=$version, producerId=$producerId, firstOffset=$firstOffset, " + + s"lastOffset=$lastOffset, lastStableOffset=$lastStableOffset)" + + override def equals(any: Any): Boolean = { + any match { + case that: AbortedTxn => this.buffer.equals(that.buffer) + case _ => false + } + } + + override def hashCode(): Int = buffer.hashCode +} diff --git a/core/src/main/scala/kafka/log/UnifiedLog.scala b/core/src/main/scala/kafka/log/UnifiedLog.scala new file mode 100644 index 0000000..029d1fb --- /dev/null +++ b/core/src/main/scala/kafka/log/UnifiedLog.scala @@ -0,0 +1,2120 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package kafka.log
+
+import com.yammer.metrics.core.MetricName
+
+import java.io.{File, IOException}
+import java.nio.file.Files
+import java.util.Optional
+import java.util.concurrent.TimeUnit
+import kafka.api.{ApiVersion, KAFKA_0_10_0_IV0}
+import kafka.common.{LongRef, OffsetsOutOfOrderException, UnexpectedAppendOffsetException}
+import kafka.log.AppendOrigin.RaftLeader
+import kafka.message.{BrokerCompressionCodec, CompressionCodec, NoCompressionCodec}
+import kafka.metrics.KafkaMetricsGroup
+import kafka.server.checkpoints.LeaderEpochCheckpointFile
+import kafka.server.epoch.LeaderEpochFileCache
+import kafka.server.{BrokerTopicStats, FetchDataInfo, FetchHighWatermark, FetchIsolation, FetchLogEnd, FetchTxnCommitted, LogDirFailureChannel, LogOffsetMetadata, OffsetAndEpoch, PartitionMetadataFile, RequestLocal}
+import kafka.utils._
+import org.apache.kafka.common.errors._
+import org.apache.kafka.common.message.{DescribeProducersResponseData, FetchResponseData}
+import org.apache.kafka.common.record.FileRecords.TimestampAndOffset
+import org.apache.kafka.common.record._
+import org.apache.kafka.common.requests.ListOffsetsRequest
+import org.apache.kafka.common.requests.OffsetsForLeaderEpochResponse.UNDEFINED_EPOCH_OFFSET
+import org.apache.kafka.common.requests.ProduceResponse.RecordError
+import org.apache.kafka.common.utils.{Time, Utils}
+import org.apache.kafka.common.{InvalidRecordException, KafkaException, TopicPartition, Uuid}
+
+import scala.annotation.nowarn
+import scala.jdk.CollectionConverters._
+import scala.collection.mutable.ListBuffer
+import scala.collection.{Seq, immutable, mutable}
+
+object LogAppendInfo {
+  val UnknownLogAppendInfo = LogAppendInfo(None, -1, None, RecordBatch.NO_TIMESTAMP, -1L, RecordBatch.NO_TIMESTAMP, -1L,
+    RecordConversionStats.EMPTY, NoCompressionCodec, NoCompressionCodec, -1, -1, offsetsMonotonic = false, -1L)
+
+  def unknownLogAppendInfoWithLogStartOffset(logStartOffset: Long): LogAppendInfo =
+    LogAppendInfo(None, -1, None, RecordBatch.NO_TIMESTAMP, -1L, RecordBatch.NO_TIMESTAMP, logStartOffset,
+      RecordConversionStats.EMPTY, NoCompressionCodec, NoCompressionCodec, -1, -1,
+      offsetsMonotonic = false, -1L)
+
+  /**
+   * In ProduceResponse V8+, we add two new fields record_errors and error_message (see KIP-467).
+   * For any record failures with InvalidTimestamp or InvalidRecordException, we construct a LogAppendInfo object like the one
+   * in unknownLogAppendInfoWithLogStartOffset, but with additional fields recordErrors and errorMessage
+   */
+  def unknownLogAppendInfoWithAdditionalInfo(logStartOffset: Long, recordErrors: Seq[RecordError], errorMessage: String): LogAppendInfo =
+    LogAppendInfo(None, -1, None, RecordBatch.NO_TIMESTAMP, -1L, RecordBatch.NO_TIMESTAMP, logStartOffset,
+      RecordConversionStats.EMPTY, NoCompressionCodec, NoCompressionCodec, -1, -1,
+      offsetsMonotonic = false, -1L, recordErrors, errorMessage)
+}
+
+sealed trait LeaderHwChange
+object LeaderHwChange {
+  case object Increased extends LeaderHwChange
+  case object Same extends LeaderHwChange
+  case object None extends LeaderHwChange
+}
+
+/**
+ * Struct to hold various quantities we compute about each message set before appending to the log
+ *
+ * @param firstOffset The first offset in the message set unless the message format is less than V2 and we are appending
+ *                    to the follower. If the message is a duplicate message the segment base offset and relative position
+ *                    in segment will be unknown.
+ * @param lastOffset The last offset in the message set + * @param lastLeaderEpoch The partition leader epoch corresponding to the last offset, if available. + * @param maxTimestamp The maximum timestamp of the message set. + * @param offsetOfMaxTimestamp The offset of the message with the maximum timestamp. + * @param logAppendTime The log append time (if used) of the message set, otherwise Message.NoTimestamp + * @param logStartOffset The start offset of the log at the time of this append. + * @param recordConversionStats Statistics collected during record processing, `null` if `assignOffsets` is `false` + * @param sourceCodec The source codec used in the message set (sent by the producer) + * @param targetCodec The target codec of the message set (after applying the broker compression configuration, if any) + * @param shallowCount The number of shallow messages + * @param validBytes The number of valid bytes + * @param offsetsMonotonic Are the offsets in this message set monotonically increasing + * @param lastOffsetOfFirstBatch The last offset of the first batch + * @param leaderHwChange Increased if the high watermark needs to be increased after appending the records. + * Same if the high watermark is not changed. None is the default value and it means the append failed + * + */ +case class LogAppendInfo(var firstOffset: Option[LogOffsetMetadata], + var lastOffset: Long, + var lastLeaderEpoch: Option[Int], + var maxTimestamp: Long, + var offsetOfMaxTimestamp: Long, + var logAppendTime: Long, + var logStartOffset: Long, + var recordConversionStats: RecordConversionStats, + sourceCodec: CompressionCodec, + targetCodec: CompressionCodec, + shallowCount: Int, + validBytes: Int, + offsetsMonotonic: Boolean, + lastOffsetOfFirstBatch: Long, + recordErrors: Seq[RecordError] = List(), + errorMessage: String = null, + leaderHwChange: LeaderHwChange = LeaderHwChange.None) { + /** + * Get the first offset if it exists, else get the last offset of the first batch + * For magic versions 2 and newer, this method will return the first offset. For magic versions + * older than 2, we use the last offset of the first batch as an approximation of the first + * offset to avoid decompressing the data. + */ + def firstOrLastOffsetOfFirstBatch: Long = firstOffset.map(_.messageOffset).getOrElse(lastOffsetOfFirstBatch) + + /** + * Get the (maximum) number of messages described by LogAppendInfo + * @return Maximum possible number of messages described by LogAppendInfo + */ + def numMessages: Long = { + firstOffset match { + case Some(firstOffsetVal) if (firstOffsetVal.messageOffset >= 0 && lastOffset >= 0) => + (lastOffset - firstOffsetVal.messageOffset + 1) + case _ => 0 + } + } +} + +/** + * Container class which represents a snapshot of the significant offsets for a partition. This allows fetching + * of these offsets atomically without the possibility of a leader change affecting their consistency relative + * to each other. See [[UnifiedLog.fetchOffsetSnapshot()]]. + */ +case class LogOffsetSnapshot(logStartOffset: Long, + logEndOffset: LogOffsetMetadata, + highWatermark: LogOffsetMetadata, + lastStableOffset: LogOffsetMetadata) + +/** + * Another container which is used for lower level reads using [[kafka.cluster.Partition.readRecords()]]. 
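Stepping back to the LogAppendInfo.numMessages rule defined earlier in this block, a quick worked sketch may help; the helper name and the offsets are illustrative only, not part of the class. An append with a known first offset of 100 and last offset of 104 describes 104 - 100 + 1 = 5 messages, while an unknown (negative) offset yields 0.

    // Standalone restatement of the numMessages calculation; runnable in a Scala REPL.
    def numMessagesSketch(firstOffset: Option[Long], lastOffset: Long): Long =
      firstOffset match {
        case Some(first) if first >= 0 && lastOffset >= 0 => lastOffset - first + 1
        case _ => 0
      }

    // numMessagesSketch(Some(100L), 104L) == 5L
    // numMessagesSketch(None, 104L)       == 0L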
+ */ +case class LogReadInfo(fetchedData: FetchDataInfo, + divergingEpoch: Option[FetchResponseData.EpochEndOffset], + highWatermark: Long, + logStartOffset: Long, + logEndOffset: Long, + lastStableOffset: Long) + +/** + * A class used to hold useful metadata about a completed transaction. This is used to build + * the transaction index after appending to the log. + * + * @param producerId The ID of the producer + * @param firstOffset The first offset (inclusive) of the transaction + * @param lastOffset The last offset (inclusive) of the transaction. This is always the offset of the + * COMMIT/ABORT control record which indicates the transaction's completion. + * @param isAborted Whether or not the transaction was aborted + */ +case class CompletedTxn(producerId: Long, firstOffset: Long, lastOffset: Long, isAborted: Boolean) { + override def toString: String = { + "CompletedTxn(" + + s"producerId=$producerId, " + + s"firstOffset=$firstOffset, " + + s"lastOffset=$lastOffset, " + + s"isAborted=$isAborted)" + } +} + +/** + * A class used to hold params required to decide to rotate a log segment or not. + */ +case class RollParams(maxSegmentMs: Long, + maxSegmentBytes: Int, + maxTimestampInMessages: Long, + maxOffsetInMessages: Long, + messagesSize: Int, + now: Long) + +object RollParams { + def apply(config: LogConfig, appendInfo: LogAppendInfo, messagesSize: Int, now: Long): RollParams = { + new RollParams(config.maxSegmentMs, + config.segmentSize, + appendInfo.maxTimestamp, + appendInfo.lastOffset, + messagesSize, now) + } +} + +sealed trait LogStartOffsetIncrementReason +case object ClientRecordDeletion extends LogStartOffsetIncrementReason { + override def toString: String = "client delete records request" +} +case object LeaderOffsetIncremented extends LogStartOffsetIncrementReason { + override def toString: String = "leader offset increment" +} +case object SegmentDeletion extends LogStartOffsetIncrementReason { + override def toString: String = "segment deletion" +} +case object SnapshotGenerated extends LogStartOffsetIncrementReason { + override def toString: String = "snapshot generated" +} + +/** + * A log which presents a unified view of local and tiered log segments. + * + * The log consists of tiered and local segments with the tiered portion of the log being optional. There could be an + * overlap between the tiered and local segments. The active segment is always guaranteed to be local. If tiered segments + * are present, they always appear at the beginning of the log, followed by an optional region of overlap, followed by the local + * segments including the active segment. + * + * NOTE: this class handles state and behavior specific to tiered segments as well as any behavior combining both tiered + * and local segments. The state and behavior specific to local segments are handled by the encapsulated LocalLog instance. + * + * @param logStartOffset The earliest offset allowed to be exposed to kafka client. + * The logStartOffset can be updated by : + * - user's DeleteRecordsRequest + * - broker's log retention + * - broker's log truncation + * - broker's log recovery + * The logStartOffset is used to decide the following: + * - Log deletion. LogSegment whose nextOffset <= log's logStartOffset can be deleted. + * It may trigger log rolling if the active segment is deleted. + * - Earliest offset of the log in response to ListOffsetRequest. 
To avoid OffsetOutOfRange exception after user seeks to earliest offset, + * we make sure that logStartOffset <= log's highWatermark + * Other activities such as log cleaning are not affected by logStartOffset. + * @param localLog The LocalLog instance containing non-empty log segments recovered from disk + * @param brokerTopicStats Container for Broker Topic Yammer Metrics + * @param producerIdExpirationCheckIntervalMs How often to check for producer ids which need to be expired + * @param leaderEpochCache The LeaderEpochFileCache instance (if any) containing state associated + * with the provided logStartOffset and nextOffsetMetadata + * @param producerStateManager The ProducerStateManager instance containing state associated with the provided segments + * @param _topicId optional Uuid to specify the topic ID for the topic if it exists. Should only be specified when + * first creating the log through Partition.makeLeader or Partition.makeFollower. When reloading a log, + * this field will be populated by reading the topic ID value from partition.metadata if it exists. + * @param keepPartitionMetadataFile boolean flag to indicate whether the partition.metadata file should be kept in the + * log directory. A partition.metadata file is only created when the raft controller is used + * or the ZK controller and this broker's inter-broker protocol version is at least 2.8. + * This file will persist the topic ID on the broker. If inter-broker protocol for a ZK controller + * is downgraded below 2.8, a topic ID may be lost and a new ID generated upon re-upgrade. + * If the inter-broker protocol version on a ZK cluster is below 2.8, partition.metadata + * will be deleted to avoid ID conflicts upon re-upgrade. + */ +@threadsafe +class UnifiedLog(@volatile var logStartOffset: Long, + private val localLog: LocalLog, + brokerTopicStats: BrokerTopicStats, + val producerIdExpirationCheckIntervalMs: Int, + @volatile var leaderEpochCache: Option[LeaderEpochFileCache], + val producerStateManager: ProducerStateManager, + @volatile private var _topicId: Option[Uuid], + val keepPartitionMetadataFile: Boolean) extends Logging with KafkaMetricsGroup { + + import kafka.log.UnifiedLog._ + + this.logIdent = s"[UnifiedLog partition=$topicPartition, dir=$parentDir] " + + /* A lock that guards all modifications to the log */ + private val lock = new Object + + /* The earliest offset which is part of an incomplete transaction. This is used to compute the + * last stable offset (LSO) in ReplicaManager. Note that it is possible that the "true" first unstable offset + * gets removed from the log (through record or segment deletion). In this case, the first unstable offset + * will point to the log start offset, which may actually be either part of a completed transaction or not + * part of a transaction at all. However, since we only use the LSO for the purpose of restricting the + * read_committed consumer to fetching decided data (i.e. committed, aborted, or non-transactional), this + * temporary abuse seems justifiable and saves us from scanning the log after deletion to find the first offsets + * of each ongoing transaction in order to compute a new first unstable offset. It is possible, however, + * that this could result in disagreement between replicas depending on when they began replicating the log. + * In the worst case, the LSO could be seen by a consumer to go backwards. 
+ */ + @volatile private var firstUnstableOffsetMetadata: Option[LogOffsetMetadata] = None + + /* Keep track of the current high watermark in order to ensure that segments containing offsets at or above it are + * not eligible for deletion. This means that the active segment is only eligible for deletion if the high watermark + * equals the log end offset (which may never happen for a partition under consistent load). This is needed to + * prevent the log start offset (which is exposed in fetch responses) from getting ahead of the high watermark. + */ + @volatile private var highWatermarkMetadata: LogOffsetMetadata = LogOffsetMetadata(logStartOffset) + + @volatile var partitionMetadataFile : PartitionMetadataFile = null + + locally { + initializePartitionMetadata() + updateLogStartOffset(logStartOffset) + maybeIncrementFirstUnstableOffset() + initializeTopicId() + } + + /** + * Initialize topic ID information for the log by maintaining the partition metadata file and setting the in-memory _topicId. + * Delete partition metadata file if the version does not support topic IDs. + * Set _topicId based on a few scenarios: + * - Recover topic ID if present and topic IDs are supported. Ensure we do not try to assign a provided topicId that is inconsistent + * with the ID on file. + * - If we were provided a topic ID when creating the log, partition metadata files are supported, and one does not yet exist + * set _topicId and write to the partition metadata file. + * - Otherwise set _topicId to None + */ + def initializeTopicId(): Unit = { + if (partitionMetadataFile.exists()) { + if (keepPartitionMetadataFile) { + val fileTopicId = partitionMetadataFile.read().topicId + if (_topicId.isDefined && !_topicId.contains(fileTopicId)) + throw new InconsistentTopicIdException(s"Tried to assign topic ID $topicId to log for topic partition $topicPartition," + + s"but log already contained topic ID $fileTopicId") + + _topicId = Some(fileTopicId) + + } else { + try partitionMetadataFile.delete() + catch { + case e: IOException => + error(s"Error while trying to delete partition metadata file ${partitionMetadataFile}", e) + } + } + } else if (keepPartitionMetadataFile) { + _topicId.foreach(partitionMetadataFile.record) + scheduler.schedule("flush-metadata-file", maybeFlushMetadataFile) + } else { + // We want to keep the file and the in-memory topic ID in sync. + _topicId = None + } + } + + def topicId: Option[Uuid] = _topicId + + def dir: File = localLog.dir + + def parentDir: String = localLog.parentDir + + def parentDirFile: File = localLog.parentDirFile + + def name: String = localLog.name + + def recoveryPoint: Long = localLog.recoveryPoint + + def topicPartition: TopicPartition = localLog.topicPartition + + def time: Time = localLog.time + + def scheduler: Scheduler = localLog.scheduler + + def config: LogConfig = localLog.config + + def logDirFailureChannel: LogDirFailureChannel = localLog.logDirFailureChannel + + def updateConfig(newConfig: LogConfig): LogConfig = { + val oldConfig = localLog.config + localLog.updateConfig(newConfig) + val oldRecordVersion = oldConfig.recordVersion + val newRecordVersion = newConfig.recordVersion + if (newRecordVersion != oldRecordVersion) + initializeLeaderEpochCache() + oldConfig + } + + def highWatermark: Long = highWatermarkMetadata.messageOffset + + /** + * Update the high watermark to a new offset. The new high watermark will be lower + * bounded by the log start offset and upper bounded by the log end offset. 
+ * + * This is intended to be called when initializing the high watermark or when updating + * it on a follower after receiving a Fetch response from the leader. + * + * @param hw the suggested new value for the high watermark + * @return the updated high watermark offset + */ + def updateHighWatermark(hw: Long): Long = { + updateHighWatermark(LogOffsetMetadata(hw)) + } + + /** + * Update high watermark with offset metadata. The new high watermark will be lower + * bounded by the log start offset and upper bounded by the log end offset. + * + * @param highWatermarkMetadata the suggested high watermark with offset metadata + * @return the updated high watermark offset + */ + def updateHighWatermark(highWatermarkMetadata: LogOffsetMetadata): Long = { + val endOffsetMetadata = localLog.logEndOffsetMetadata + val newHighWatermarkMetadata = if (highWatermarkMetadata.messageOffset < logStartOffset) { + LogOffsetMetadata(logStartOffset) + } else if (highWatermarkMetadata.messageOffset >= endOffsetMetadata.messageOffset) { + endOffsetMetadata + } else { + highWatermarkMetadata + } + + updateHighWatermarkMetadata(newHighWatermarkMetadata) + newHighWatermarkMetadata.messageOffset + } + + /** + * Update the high watermark to a new value if and only if it is larger than the old value. It is + * an error to update to a value which is larger than the log end offset. + * + * This method is intended to be used by the leader to update the high watermark after follower + * fetch offsets have been updated. + * + * @return the old high watermark, if updated by the new value + */ + def maybeIncrementHighWatermark(newHighWatermark: LogOffsetMetadata): Option[LogOffsetMetadata] = { + if (newHighWatermark.messageOffset > logEndOffset) + throw new IllegalArgumentException(s"High watermark $newHighWatermark update exceeds current " + + s"log end offset ${localLog.logEndOffsetMetadata}") + + lock.synchronized { + val oldHighWatermark = fetchHighWatermarkMetadata + + // Ensure that the high watermark increases monotonically. We also update the high watermark when the new + // offset metadata is on a newer segment, which occurs whenever the log is rolled to a new segment. + if (oldHighWatermark.messageOffset < newHighWatermark.messageOffset || + (oldHighWatermark.messageOffset == newHighWatermark.messageOffset && oldHighWatermark.onOlderSegment(newHighWatermark))) { + updateHighWatermarkMetadata(newHighWatermark) + Some(oldHighWatermark) + } else { + None + } + } + } + + /** + * Get the offset and metadata for the current high watermark. If offset metadata is not + * known, this will do a lookup in the index and cache the result. 
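A minimal standalone sketch of the bounding rule applied by updateHighWatermark above; the helper is hypothetical and considers only message offsets, ignoring the segment position carried by LogOffsetMetadata.

    // The proposed high watermark is clamped to the [logStartOffset, logEndOffset] range.
    def clampHighWatermark(proposed: Long, logStartOffset: Long, logEndOffset: Long): Long =
      if (proposed < logStartOffset) logStartOffset
      else if (proposed >= logEndOffset) logEndOffset
      else proposed

    // clampHighWatermark(5L, logStartOffset = 10L, logEndOffset = 100L)   == 10L
    // clampHighWatermark(250L, logStartOffset = 10L, logEndOffset = 100L) == 100L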
+ */ + private def fetchHighWatermarkMetadata: LogOffsetMetadata = { + localLog.checkIfMemoryMappedBufferClosed() + + val offsetMetadata = highWatermarkMetadata + if (offsetMetadata.messageOffsetOnly) { + lock.synchronized { + val fullOffset = convertToOffsetMetadataOrThrow(highWatermark) + updateHighWatermarkMetadata(fullOffset) + fullOffset + } + } else { + offsetMetadata + } + } + + private def updateHighWatermarkMetadata(newHighWatermark: LogOffsetMetadata): Unit = { + if (newHighWatermark.messageOffset < 0) + throw new IllegalArgumentException("High watermark offset should be non-negative") + + lock synchronized { + if (newHighWatermark.messageOffset < highWatermarkMetadata.messageOffset) { + warn(s"Non-monotonic update of high watermark from $highWatermarkMetadata to $newHighWatermark") + } + + highWatermarkMetadata = newHighWatermark + producerStateManager.onHighWatermarkUpdated(newHighWatermark.messageOffset) + maybeIncrementFirstUnstableOffset() + } + trace(s"Setting high watermark $newHighWatermark") + } + + /** + * Get the first unstable offset. Unlike the last stable offset, which is always defined, + * the first unstable offset only exists if there are transactions in progress. + * + * @return the first unstable offset, if it exists + */ + private[log] def firstUnstableOffset: Option[Long] = firstUnstableOffsetMetadata.map(_.messageOffset) + + private def fetchLastStableOffsetMetadata: LogOffsetMetadata = { + localLog.checkIfMemoryMappedBufferClosed() + + // cache the current high watermark to avoid a concurrent update invalidating the range check + val highWatermarkMetadata = fetchHighWatermarkMetadata + + firstUnstableOffsetMetadata match { + case Some(offsetMetadata) if offsetMetadata.messageOffset < highWatermarkMetadata.messageOffset => + if (offsetMetadata.messageOffsetOnly) { + lock synchronized { + val fullOffset = convertToOffsetMetadataOrThrow(offsetMetadata.messageOffset) + if (firstUnstableOffsetMetadata.contains(offsetMetadata)) + firstUnstableOffsetMetadata = Some(fullOffset) + fullOffset + } + } else { + offsetMetadata + } + case _ => highWatermarkMetadata + } + } + + /** + * The last stable offset (LSO) is defined as the first offset such that all lower offsets have been "decided." + * Non-transactional messages are considered decided immediately, but transactional messages are only decided when + * the corresponding COMMIT or ABORT marker is written. This implies that the last stable offset will be equal + * to the high watermark if there are no transactional messages in the log. Note also that the LSO cannot advance + * beyond the high watermark. + */ + def lastStableOffset: Long = { + firstUnstableOffsetMetadata match { + case Some(offsetMetadata) if offsetMetadata.messageOffset < highWatermark => offsetMetadata.messageOffset + case _ => highWatermark + } + } + + def lastStableOffsetLag: Long = highWatermark - lastStableOffset + + /** + * Fully materialize and return an offset snapshot including segment position info. This method will update + * the LogOffsetMetadata for the high watermark and last stable offset if they are message-only. Throws an + * offset out of range error if the segment info cannot be loaded. 
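The last stable offset rule documented above can be restated as a one-line standalone sketch (a hypothetical helper, not a UnifiedLog member): the LSO is the first unstable offset when one exists below the high watermark, and the high watermark otherwise.

    def lastStableOffsetSketch(firstUnstableOffset: Option[Long], highWatermark: Long): Long =
      firstUnstableOffset.filter(_ < highWatermark).getOrElse(highWatermark)

    // With an open transaction starting at offset 40 and a high watermark of 75,
    // lastStableOffsetSketch(Some(40L), 75L) == 40L, so a READ_COMMITTED consumer
    // may only fetch records up to offset 39.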
+ */ + def fetchOffsetSnapshot: LogOffsetSnapshot = { + val lastStable = fetchLastStableOffsetMetadata + val highWatermark = fetchHighWatermarkMetadata + + LogOffsetSnapshot( + logStartOffset, + localLog.logEndOffsetMetadata, + highWatermark, + lastStable + ) + } + + private val tags = { + val maybeFutureTag = if (isFuture) Map("is-future" -> "true") else Map.empty[String, String] + Map("topic" -> topicPartition.topic, "partition" -> topicPartition.partition.toString) ++ maybeFutureTag + } + + newGauge(LogMetricNames.NumLogSegments, () => numberOfSegments, tags) + newGauge(LogMetricNames.LogStartOffset, () => logStartOffset, tags) + newGauge(LogMetricNames.LogEndOffset, () => logEndOffset, tags) + newGauge(LogMetricNames.Size, () => size, tags) + + val producerExpireCheck = scheduler.schedule(name = "PeriodicProducerExpirationCheck", fun = () => { + lock synchronized { + producerStateManager.removeExpiredProducers(time.milliseconds) + } + }, period = producerIdExpirationCheckIntervalMs, delay = producerIdExpirationCheckIntervalMs, unit = TimeUnit.MILLISECONDS) + + // For compatibility, metrics are defined to be under `Log` class + override def metricName(name: String, tags: scala.collection.Map[String, String]): MetricName = { + val pkg = getClass.getPackage + val pkgStr = if (pkg == null) "" else pkg.getName + explicitMetricName(pkgStr, "Log", name, tags) + } + + private def recordVersion: RecordVersion = config.recordVersion + + private def initializePartitionMetadata(): Unit = lock synchronized { + val partitionMetadata = PartitionMetadataFile.newFile(dir) + partitionMetadataFile = new PartitionMetadataFile(partitionMetadata, logDirFailureChannel) + } + + private def maybeFlushMetadataFile(): Unit = { + partitionMetadataFile.maybeFlush() + } + + /** Only used for ZK clusters when we update and start using topic IDs on existing topics */ + def assignTopicId(topicId: Uuid): Unit = { + _topicId match { + case Some(currentId) => + if (!currentId.equals(topicId)) { + throw new InconsistentTopicIdException(s"Tried to assign topic ID $topicId to log for topic partition $topicPartition," + + s"but log already contained topic ID $currentId") + } + + case None => + if (keepPartitionMetadataFile) { + _topicId = Some(topicId) + if (!partitionMetadataFile.exists()) { + partitionMetadataFile.record(topicId) + scheduler.schedule("flush-metadata-file", maybeFlushMetadataFile) + } + } + } + } + + private def initializeLeaderEpochCache(): Unit = lock synchronized { + leaderEpochCache = UnifiedLog.maybeCreateLeaderEpochCache(dir, topicPartition, logDirFailureChannel, recordVersion, logIdent) + } + + private def updateHighWatermarkWithLogEndOffset(): Unit = { + // Update the high watermark in case it has gotten ahead of the log end offset following a truncation + // or if a new segment has been rolled and the offset metadata needs to be updated. + if (highWatermark >= localLog.logEndOffset) { + updateHighWatermarkMetadata(localLog.logEndOffsetMetadata) + } + } + + private def updateLogStartOffset(offset: Long): Unit = { + logStartOffset = offset + + if (highWatermark < offset) { + updateHighWatermark(offset) + } + + if (localLog.recoveryPoint < offset) { + localLog.updateRecoveryPoint(offset) + } + } + + // Rebuild producer state until lastOffset. This method may be called from the recovery code path, and thus must be + // free of all side-effects, i.e. it must not update any log-specific state. 
+ private def rebuildProducerState(lastOffset: Long, + producerStateManager: ProducerStateManager): Unit = lock synchronized { + localLog.checkIfMemoryMappedBufferClosed() + UnifiedLog.rebuildProducerState(producerStateManager, localLog.segments, logStartOffset, lastOffset, recordVersion, time, + reloadFromCleanShutdown = false, logIdent) + } + + def activeProducers: Seq[DescribeProducersResponseData.ProducerState] = { + lock synchronized { + producerStateManager.activeProducers.map { case (producerId, state) => + new DescribeProducersResponseData.ProducerState() + .setProducerId(producerId) + .setProducerEpoch(state.producerEpoch) + .setLastSequence(state.lastSeq) + .setLastTimestamp(state.lastTimestamp) + .setCoordinatorEpoch(state.coordinatorEpoch) + .setCurrentTxnStartOffset(state.currentTxnFirstOffset.getOrElse(-1L)) + } + }.toSeq + } + + private[log] def activeProducersWithLastSequence: Map[Long, Int] = lock synchronized { + producerStateManager.activeProducers.map { case (producerId, producerIdEntry) => + (producerId, producerIdEntry.lastSeq) + } + } + + private[log] def lastRecordsOfActiveProducers: Map[Long, LastRecord] = lock synchronized { + producerStateManager.activeProducers.map { case (producerId, producerIdEntry) => + val lastDataOffset = if (producerIdEntry.lastDataOffset >= 0 ) Some(producerIdEntry.lastDataOffset) else None + val lastRecord = LastRecord(lastDataOffset, producerIdEntry.producerEpoch) + producerId -> lastRecord + } + } + + /** + * The number of segments in the log. + * Take care! this is an O(n) operation. + */ + def numberOfSegments: Int = localLog.segments.numberOfSegments + + /** + * Close this log. + * The memory mapped buffer for index files of this log will be left open until the log is deleted. + */ + def close(): Unit = { + debug("Closing log") + lock synchronized { + maybeFlushMetadataFile() + localLog.checkIfMemoryMappedBufferClosed() + producerExpireCheck.cancel(true) + maybeHandleIOException(s"Error while renaming dir for $topicPartition in dir ${dir.getParent}") { + // We take a snapshot at the last written offset to hopefully avoid the need to scan the log + // after restarting and to ensure that we cannot inadvertently hit the upgrade optimization + // (the clean shutdown file is written after the logs are all closed). + producerStateManager.takeSnapshot() + } + localLog.close() + } + } + + /** + * Rename the directory of the local log + * + * @throws KafkaStorageException if rename fails + */ + def renameDir(name: String): Unit = { + lock synchronized { + maybeHandleIOException(s"Error while renaming dir for $topicPartition in log dir ${dir.getParent}") { + // Flush partitionMetadata file before initializing again + maybeFlushMetadataFile() + if (localLog.renameDir(name)) { + producerStateManager.updateParentDir(dir) + // re-initialize leader epoch cache so that LeaderEpochCheckpointFile.checkpoint can correctly reference + // the checkpoint file in renamed log directory + initializeLeaderEpochCache() + initializePartitionMetadata() + } + } + } + } + + /** + * Close file handlers used by this log but don't write to disk. 
This is called if the log directory is offline + */ + def closeHandlers(): Unit = { + debug("Closing handlers") + lock synchronized { + localLog.closeHandlers() + } + } + + /** + * Append this message set to the active segment of the local log, assigning offsets and Partition Leader Epochs + * + * @param records The records to append + * @param origin Declares the origin of the append which affects required validations + * @param interBrokerProtocolVersion Inter-broker message protocol version + * @param requestLocal request local instance + * @throws KafkaStorageException If the append fails due to an I/O error. + * @return Information about the appended messages including the first and last offset. + */ + def appendAsLeader(records: MemoryRecords, + leaderEpoch: Int, + origin: AppendOrigin = AppendOrigin.Client, + interBrokerProtocolVersion: ApiVersion = ApiVersion.latestVersion, + requestLocal: RequestLocal = RequestLocal.NoCaching): LogAppendInfo = { + val validateAndAssignOffsets = origin != AppendOrigin.RaftLeader + append(records, origin, interBrokerProtocolVersion, validateAndAssignOffsets, leaderEpoch, Some(requestLocal), ignoreRecordSize = false) + } + + /** + * Append this message set to the active segment of the local log without assigning offsets or Partition Leader Epochs + * + * @param records The records to append + * @throws KafkaStorageException If the append fails due to an I/O error. + * @return Information about the appended messages including the first and last offset. + */ + def appendAsFollower(records: MemoryRecords): LogAppendInfo = { + append(records, + origin = AppendOrigin.Replication, + interBrokerProtocolVersion = ApiVersion.latestVersion, + validateAndAssignOffsets = false, + leaderEpoch = -1, + None, + // disable to check the validation of record size since the record is already accepted by leader. + ignoreRecordSize = true) + } + + /** + * Append this message set to the active segment of the local log, rolling over to a fresh segment if necessary. + * + * This method will generally be responsible for assigning offsets to the messages, + * however if the assignOffsets=false flag is passed we will only check that the existing offsets are valid. + * + * @param records The log records to append + * @param origin Declares the origin of the append which affects required validations + * @param interBrokerProtocolVersion Inter-broker message protocol version + * @param validateAndAssignOffsets Should the log assign offsets to this message set or blindly apply what it is given + * @param leaderEpoch The partition's leader epoch which will be applied to messages when offsets are assigned on the leader + * @param requestLocal The request local instance if assignOffsets is true + * @param ignoreRecordSize true to skip validation of record size. + * @throws KafkaStorageException If the append fails due to an I/O error. + * @throws OffsetsOutOfOrderException If out of order offsets found in 'records' + * @throws UnexpectedAppendOffsetException If the first or last offset in append is less than next offset + * @return Information about the appended messages including the first and last offset. + */ + private def append(records: MemoryRecords, + origin: AppendOrigin, + interBrokerProtocolVersion: ApiVersion, + validateAndAssignOffsets: Boolean, + leaderEpoch: Int, + requestLocal: Option[RequestLocal], + ignoreRecordSize: Boolean): LogAppendInfo = { + // We want to ensure the partition metadata file is written to the log dir before any log data is written to disk. 
+ // This will ensure that any log data can be recovered with the correct topic ID in the case of failure. + maybeFlushMetadataFile() + + val appendInfo = analyzeAndValidateRecords(records, origin, ignoreRecordSize, leaderEpoch) + + // return if we have no valid messages or if this is a duplicate of the last appended entry + if (appendInfo.shallowCount == 0) appendInfo + else { + + // trim any invalid bytes or partial messages before appending it to the on-disk log + var validRecords = trimInvalidBytes(records, appendInfo) + + // they are valid, insert them in the log + lock synchronized { + maybeHandleIOException(s"Error while appending records to $topicPartition in dir ${dir.getParent}") { + localLog.checkIfMemoryMappedBufferClosed() + if (validateAndAssignOffsets) { + // assign offsets to the message set + val offset = new LongRef(localLog.logEndOffset) + appendInfo.firstOffset = Some(LogOffsetMetadata(offset.value)) + val now = time.milliseconds + val validateAndOffsetAssignResult = try { + LogValidator.validateMessagesAndAssignOffsets(validRecords, + topicPartition, + offset, + time, + now, + appendInfo.sourceCodec, + appendInfo.targetCodec, + config.compact, + config.recordVersion.value, + config.messageTimestampType, + config.messageTimestampDifferenceMaxMs, + leaderEpoch, + origin, + interBrokerProtocolVersion, + brokerTopicStats, + requestLocal.getOrElse(throw new IllegalArgumentException( + "requestLocal should be defined if assignOffsets is true"))) + } catch { + case e: IOException => + throw new KafkaException(s"Error validating messages while appending to log $name", e) + } + validRecords = validateAndOffsetAssignResult.validatedRecords + appendInfo.maxTimestamp = validateAndOffsetAssignResult.maxTimestamp + appendInfo.offsetOfMaxTimestamp = validateAndOffsetAssignResult.shallowOffsetOfMaxTimestamp + appendInfo.lastOffset = offset.value - 1 + appendInfo.recordConversionStats = validateAndOffsetAssignResult.recordConversionStats + if (config.messageTimestampType == TimestampType.LOG_APPEND_TIME) + appendInfo.logAppendTime = now + + // re-validate message sizes if there's a possibility that they have changed (due to re-compression or message + // format conversion) + if (!ignoreRecordSize && validateAndOffsetAssignResult.messageSizeMaybeChanged) { + validRecords.batches.forEach { batch => + if (batch.sizeInBytes > config.maxMessageSize) { + // we record the original message set size instead of the trimmed size + // to be consistent with pre-compression bytesRejectedRate recording + brokerTopicStats.topicStats(topicPartition.topic).bytesRejectedRate.mark(records.sizeInBytes) + brokerTopicStats.allTopicsStats.bytesRejectedRate.mark(records.sizeInBytes) + throw new RecordTooLargeException(s"Message batch size is ${batch.sizeInBytes} bytes in append to" + + s"partition $topicPartition which exceeds the maximum configured size of ${config.maxMessageSize}.") + } + } + } + } else { + // we are taking the offsets we are given + if (!appendInfo.offsetsMonotonic) + throw new OffsetsOutOfOrderException(s"Out of order offsets found in append to $topicPartition: " + + records.records.asScala.map(_.offset)) + + if (appendInfo.firstOrLastOffsetOfFirstBatch < localLog.logEndOffset) { + // we may still be able to recover if the log is empty + // one example: fetching from log start offset on the leader which is not batch aligned, + // which may happen as a result of AdminClient#deleteRecords() + val firstOffset = appendInfo.firstOffset match { + case Some(offsetMetadata) => 
offsetMetadata.messageOffset + case None => records.batches.asScala.head.baseOffset() + } + + val firstOrLast = if (appendInfo.firstOffset.isDefined) "First offset" else "Last offset of the first batch" + throw new UnexpectedAppendOffsetException( + s"Unexpected offset in append to $topicPartition. $firstOrLast " + + s"${appendInfo.firstOrLastOffsetOfFirstBatch} is less than the next offset ${localLog.logEndOffset}. " + + s"First 10 offsets in append: ${records.records.asScala.take(10).map(_.offset)}, last offset in" + + s" append: ${appendInfo.lastOffset}. Log start offset = $logStartOffset", + firstOffset, appendInfo.lastOffset) + } + } + + // update the epoch cache with the epoch stamped onto the message by the leader + validRecords.batches.forEach { batch => + if (batch.magic >= RecordBatch.MAGIC_VALUE_V2) { + maybeAssignEpochStartOffset(batch.partitionLeaderEpoch, batch.baseOffset) + } else { + // In partial upgrade scenarios, we may get a temporary regression to the message format. In + // order to ensure the safety of leader election, we clear the epoch cache so that we revert + // to truncation by high watermark after the next leader election. + leaderEpochCache.filter(_.nonEmpty).foreach { cache => + warn(s"Clearing leader epoch cache after unexpected append with message format v${batch.magic}") + cache.clearAndFlush() + } + } + } + + // check messages set size may be exceed config.segmentSize + if (validRecords.sizeInBytes > config.segmentSize) { + throw new RecordBatchTooLargeException(s"Message batch size is ${validRecords.sizeInBytes} bytes in append " + + s"to partition $topicPartition, which exceeds the maximum configured segment size of ${config.segmentSize}.") + } + + // maybe roll the log if this segment is full + val segment = maybeRoll(validRecords.sizeInBytes, appendInfo) + + val logOffsetMetadata = LogOffsetMetadata( + messageOffset = appendInfo.firstOrLastOffsetOfFirstBatch, + segmentBaseOffset = segment.baseOffset, + relativePositionInSegment = segment.size) + + // now that we have valid records, offsets assigned, and timestamps updated, we need to + // validate the idempotent/transactional state of the producers and collect some metadata + val (updatedProducers, completedTxns, maybeDuplicate) = analyzeAndValidateProducerState( + logOffsetMetadata, validRecords, origin) + + maybeDuplicate match { + case Some(duplicate) => + appendInfo.firstOffset = Some(LogOffsetMetadata(duplicate.firstOffset)) + appendInfo.lastOffset = duplicate.lastOffset + appendInfo.logAppendTime = duplicate.timestamp + appendInfo.logStartOffset = logStartOffset + case None => + // Before appending update the first offset metadata to include segment information + appendInfo.firstOffset = appendInfo.firstOffset.map { offsetMetadata => + offsetMetadata.copy(segmentBaseOffset = segment.baseOffset, relativePositionInSegment = segment.size) + } + + // Append the records, and increment the local log end offset immediately after the append because a + // write to the transaction index below may fail and we want to ensure that the offsets + // of future appends still grow monotonically. The resulting transaction index inconsistency + // will be cleaned up after the log directory is recovered. Note that the end offset of the + // ProducerStateManager will not be updated and the last stable offset will not advance + // if the append to the transaction index fails. 
+ localLog.append(appendInfo.lastOffset, appendInfo.maxTimestamp, appendInfo.offsetOfMaxTimestamp, validRecords) + updateHighWatermarkWithLogEndOffset() + + // update the producer state + updatedProducers.values.foreach(producerAppendInfo => producerStateManager.update(producerAppendInfo)) + + // update the transaction index with the true last stable offset. The last offset visible + // to consumers using READ_COMMITTED will be limited by this value and the high watermark. + completedTxns.foreach { completedTxn => + val lastStableOffset = producerStateManager.lastStableOffset(completedTxn) + segment.updateTxnIndex(completedTxn, lastStableOffset) + producerStateManager.completeTxn(completedTxn) + } + + // always update the last producer id map offset so that the snapshot reflects the current offset + // even if there isn't any idempotent data being written + producerStateManager.updateMapEndOffset(appendInfo.lastOffset + 1) + + // update the first unstable offset (which is used to compute LSO) + maybeIncrementFirstUnstableOffset() + + trace(s"Appended message set with last offset: ${appendInfo.lastOffset}, " + + s"first offset: ${appendInfo.firstOffset}, " + + s"next offset: ${localLog.logEndOffset}, " + + s"and messages: $validRecords") + + if (localLog.unflushedMessages >= config.flushInterval) flush() + } + appendInfo + } + } + } + + def maybeAssignEpochStartOffset(leaderEpoch: Int, startOffset: Long): Unit = { + leaderEpochCache.foreach { cache => + cache.assign(leaderEpoch, startOffset) + } + } + + def latestEpoch: Option[Int] = leaderEpochCache.flatMap(_.latestEpoch) + + def endOffsetForEpoch(leaderEpoch: Int): Option[OffsetAndEpoch] = { + leaderEpochCache.flatMap { cache => + val (foundEpoch, foundOffset) = cache.endOffsetFor(leaderEpoch, logEndOffset) + if (foundOffset == UNDEFINED_EPOCH_OFFSET) + None + else + Some(OffsetAndEpoch(foundOffset, foundEpoch)) + } + } + + private def maybeIncrementFirstUnstableOffset(): Unit = lock synchronized { + localLog.checkIfMemoryMappedBufferClosed() + + val updatedFirstUnstableOffset = producerStateManager.firstUnstableOffset match { + case Some(logOffsetMetadata) if logOffsetMetadata.messageOffsetOnly || logOffsetMetadata.messageOffset < logStartOffset => + val offset = math.max(logOffsetMetadata.messageOffset, logStartOffset) + Some(convertToOffsetMetadataOrThrow(offset)) + case other => other + } + + if (updatedFirstUnstableOffset != this.firstUnstableOffsetMetadata) { + debug(s"First unstable offset updated to $updatedFirstUnstableOffset") + this.firstUnstableOffsetMetadata = updatedFirstUnstableOffset + } + } + + /** + * Increment the log start offset if the provided offset is larger. + * + * If the log start offset changed, then this method also updates a few key offsets so that + * `logStartOffset <= lastStableOffset <= highWatermark`. The leader epoch cache is also updated + * so that all offsets referenced in that component point to valid offsets in this log. + * + * @throws OffsetOutOfRangeException if the log start offset is greater than the high watermark + * @return true if the log start offset was updated; otherwise false + */ + def maybeIncrementLogStartOffset(newLogStartOffset: Long, reason: LogStartOffsetIncrementReason): Boolean = { + // We don't have to write the log start offset to log-start-offset-checkpoint immediately. + // The deleteRecordsOffset may be lost only if all in-sync replicas of this broker are shut down + // in an unclean manner within log.flush.start.offset.checkpoint.interval.ms. 
The chance of this happening is low. + var updatedLogStartOffset = false + maybeHandleIOException(s"Exception while increasing log start offset for $topicPartition to $newLogStartOffset in dir ${dir.getParent}") { + lock synchronized { + if (newLogStartOffset > highWatermark) + throw new OffsetOutOfRangeException(s"Cannot increment the log start offset to $newLogStartOffset of partition $topicPartition " + + s"since it is larger than the high watermark $highWatermark") + + localLog.checkIfMemoryMappedBufferClosed() + if (newLogStartOffset > logStartOffset) { + updatedLogStartOffset = true + updateLogStartOffset(newLogStartOffset) + info(s"Incremented log start offset to $newLogStartOffset due to $reason") + leaderEpochCache.foreach(_.truncateFromStart(logStartOffset)) + producerStateManager.onLogStartOffsetIncremented(newLogStartOffset) + maybeIncrementFirstUnstableOffset() + } + } + } + + updatedLogStartOffset + } + + private def analyzeAndValidateProducerState(appendOffsetMetadata: LogOffsetMetadata, + records: MemoryRecords, + origin: AppendOrigin): + (mutable.Map[Long, ProducerAppendInfo], List[CompletedTxn], Option[BatchMetadata]) = { + val updatedProducers = mutable.Map.empty[Long, ProducerAppendInfo] + val completedTxns = ListBuffer.empty[CompletedTxn] + var relativePositionInSegment = appendOffsetMetadata.relativePositionInSegment + + records.batches.forEach { batch => + if (batch.hasProducerId) { + // if this is a client produce request, there will be up to 5 batches which could have been duplicated. + // If we find a duplicate, we return the metadata of the appended batch to the client. + if (origin == AppendOrigin.Client) { + val maybeLastEntry = producerStateManager.lastEntry(batch.producerId) + + maybeLastEntry.flatMap(_.findDuplicateBatch(batch)).foreach { duplicate => + return (updatedProducers, completedTxns.toList, Some(duplicate)) + } + } + + // We cache offset metadata for the start of each transaction. This allows us to + // compute the last stable offset without relying on additional index lookups. + val firstOffsetMetadata = if (batch.isTransactional) + Some(LogOffsetMetadata(batch.baseOffset, appendOffsetMetadata.segmentBaseOffset, relativePositionInSegment)) + else + None + + val maybeCompletedTxn = updateProducers(producerStateManager, batch, updatedProducers, firstOffsetMetadata, origin) + maybeCompletedTxn.foreach(completedTxns += _) + } + + relativePositionInSegment += batch.sizeInBytes + } + (updatedProducers, completedTxns.toList, None) + } + + /** + * Validate the following: + *
+ * 1. each message matches its CRC
+ * 2. each message size is valid (if ignoreRecordSize is false)
+ * 3. that the sequence numbers of the incoming record batches are consistent with the existing state and with each other.
+ *
+ * Also compute the following quantities:
+ * 1. First offset in the message set
+ * 2. Last offset in the message set
+ * 3. Number of messages
+ * 4. Number of valid bytes
+ * 5. Whether the offsets are monotonically increasing
+ * 6. Whether any compression codec is used (if many are used, then the last one is given)
            + */ + private def analyzeAndValidateRecords(records: MemoryRecords, + origin: AppendOrigin, + ignoreRecordSize: Boolean, + leaderEpoch: Int): LogAppendInfo = { + var shallowMessageCount = 0 + var validBytesCount = 0 + var firstOffset: Option[LogOffsetMetadata] = None + var lastOffset = -1L + var lastLeaderEpoch = RecordBatch.NO_PARTITION_LEADER_EPOCH + var sourceCodec: CompressionCodec = NoCompressionCodec + var monotonic = true + var maxTimestamp = RecordBatch.NO_TIMESTAMP + var offsetOfMaxTimestamp = -1L + var readFirstMessage = false + var lastOffsetOfFirstBatch = -1L + + records.batches.forEach { batch => + if (origin == RaftLeader && batch.partitionLeaderEpoch != leaderEpoch) { + throw new InvalidRecordException("Append from Raft leader did not set the batch epoch correctly") + } + // we only validate V2 and higher to avoid potential compatibility issues with older clients + if (batch.magic >= RecordBatch.MAGIC_VALUE_V2 && origin == AppendOrigin.Client && batch.baseOffset != 0) + throw new InvalidRecordException(s"The baseOffset of the record batch in the append to $topicPartition should " + + s"be 0, but it is ${batch.baseOffset}") + + // update the first offset if on the first message. For magic versions older than 2, we use the last offset + // to avoid the need to decompress the data (the last offset can be obtained directly from the wrapper message). + // For magic version 2, we can get the first offset directly from the batch header. + // When appending to the leader, we will update LogAppendInfo.baseOffset with the correct value. In the follower + // case, validation will be more lenient. + // Also indicate whether we have the accurate first offset or not + if (!readFirstMessage) { + if (batch.magic >= RecordBatch.MAGIC_VALUE_V2) + firstOffset = Some(LogOffsetMetadata(batch.baseOffset)) + lastOffsetOfFirstBatch = batch.lastOffset + readFirstMessage = true + } + + // check that offsets are monotonically increasing + if (lastOffset >= batch.lastOffset) + monotonic = false + + // update the last offset seen + lastOffset = batch.lastOffset + lastLeaderEpoch = batch.partitionLeaderEpoch + + // Check if the message sizes are valid. 
+ val batchSize = batch.sizeInBytes + if (!ignoreRecordSize && batchSize > config.maxMessageSize) { + brokerTopicStats.topicStats(topicPartition.topic).bytesRejectedRate.mark(records.sizeInBytes) + brokerTopicStats.allTopicsStats.bytesRejectedRate.mark(records.sizeInBytes) + throw new RecordTooLargeException(s"The record batch size in the append to $topicPartition is $batchSize bytes " + + s"which exceeds the maximum configured value of ${config.maxMessageSize}.") + } + + // check the validity of the message by checking CRC + if (!batch.isValid) { + brokerTopicStats.allTopicsStats.invalidMessageCrcRecordsPerSec.mark() + throw new CorruptRecordException(s"Record is corrupt (stored crc = ${batch.checksum()}) in topic partition $topicPartition.") + } + + if (batch.maxTimestamp > maxTimestamp) { + maxTimestamp = batch.maxTimestamp + offsetOfMaxTimestamp = lastOffset + } + + shallowMessageCount += 1 + validBytesCount += batchSize + + val messageCodec = CompressionCodec.getCompressionCodec(batch.compressionType.id) + if (messageCodec != NoCompressionCodec) + sourceCodec = messageCodec + } + + // Apply broker-side compression if any + val targetCodec = BrokerCompressionCodec.getTargetCompressionCodec(config.compressionType, sourceCodec) + val lastLeaderEpochOpt: Option[Int] = if (lastLeaderEpoch != RecordBatch.NO_PARTITION_LEADER_EPOCH) + Some(lastLeaderEpoch) + else + None + LogAppendInfo(firstOffset, lastOffset, lastLeaderEpochOpt, maxTimestamp, offsetOfMaxTimestamp, RecordBatch.NO_TIMESTAMP, logStartOffset, + RecordConversionStats.EMPTY, sourceCodec, targetCodec, shallowMessageCount, validBytesCount, monotonic, lastOffsetOfFirstBatch) + } + + /** + * Trim any invalid bytes from the end of this message set (if there are any) + * + * @param records The records to trim + * @param info The general information of the message set + * @return A trimmed message set. This may be the same as what was passed in or it may not. + */ + private def trimInvalidBytes(records: MemoryRecords, info: LogAppendInfo): MemoryRecords = { + val validBytes = info.validBytes + if (validBytes < 0) + throw new CorruptRecordException(s"Cannot append record batch with illegal length $validBytes to " + + s"log for $topicPartition. A possible cause is a corrupted produce request.") + if (validBytes == records.sizeInBytes) { + records + } else { + // trim invalid bytes + val validByteBuffer = records.buffer.duplicate() + validByteBuffer.limit(validBytes) + MemoryRecords.readableRecords(validByteBuffer) + } + } + + private def checkLogStartOffset(offset: Long): Unit = { + if (offset < logStartOffset) + throw new OffsetOutOfRangeException(s"Received request for offset $offset for partition $topicPartition, " + + s"but we only have log segments starting from offset: $logStartOffset.") + } + + /** + * Read messages from the log. + * + * @param startOffset The offset to begin reading at + * @param maxLength The maximum number of bytes to read + * @param isolation The fetch isolation, which controls the maximum offset we are allowed to read + * @param minOneMessage If this is true, the first message will be returned even if it exceeds `maxLength` (if one exists) + * @throws OffsetOutOfRangeException If startOffset is beyond the log end offset or before the log start offset + * @return The fetch data information including fetch starting offset metadata and messages read. 
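As a hedged usage sketch of the read API documented here (assuming an already constructed UnifiedLog instance and the imports shown), the isolation argument selects the upper bound of a fetch: the last stable offset for READ_COMMITTED, the log end offset for replication-style reads.

    import kafka.log.UnifiedLog
    import kafka.server.{FetchDataInfo, FetchLogEnd, FetchTxnCommitted}

    object ReadIsolationSketch {
      def readAtBothIsolationLevels(log: UnifiedLog): (FetchDataInfo, FetchDataInfo) = {
        // READ_COMMITTED fetches are capped at the last stable offset.
        val committed = log.read(startOffset = log.logStartOffset, maxLength = 64 * 1024,
          isolation = FetchTxnCommitted, minOneMessage = true)
        // Replication-style reads may go all the way to the log end offset.
        val uncommitted = log.read(startOffset = log.logStartOffset, maxLength = 64 * 1024,
          isolation = FetchLogEnd, minOneMessage = true)
        (committed, uncommitted)
      }
    }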
+ */ + def read(startOffset: Long, + maxLength: Int, + isolation: FetchIsolation, + minOneMessage: Boolean): FetchDataInfo = { + checkLogStartOffset(startOffset) + val maxOffsetMetadata = isolation match { + case FetchLogEnd => localLog.logEndOffsetMetadata + case FetchHighWatermark => fetchHighWatermarkMetadata + case FetchTxnCommitted => fetchLastStableOffsetMetadata + } + localLog.read(startOffset, maxLength, minOneMessage, maxOffsetMetadata, isolation == FetchTxnCommitted) + } + + private[log] def collectAbortedTransactions(startOffset: Long, upperBoundOffset: Long): List[AbortedTxn] = { + localLog.collectAbortedTransactions(logStartOffset, startOffset, upperBoundOffset) + } + + /** + * Get an offset based on the given timestamp + * The offset returned is the offset of the first message whose timestamp is greater than or equals to the + * given timestamp. + * + * If no such message is found, the log end offset is returned. + * + * `NOTE:` OffsetRequest V0 does not use this method, the behavior of OffsetRequest V0 remains the same as before + * , i.e. it only gives back the timestamp based on the last modification time of the log segments. + * + * @param targetTimestamp The given timestamp for offset fetching. + * @return The offset of the first message whose timestamp is greater than or equals to the given timestamp. + * None if no such message is found. + */ + @nowarn("cat=deprecation") + def fetchOffsetByTimestamp(targetTimestamp: Long): Option[TimestampAndOffset] = { + maybeHandleIOException(s"Error while fetching offset by timestamp for $topicPartition in dir ${dir.getParent}") { + debug(s"Searching offset for timestamp $targetTimestamp") + + if (config.messageFormatVersion < KAFKA_0_10_0_IV0 && + targetTimestamp != ListOffsetsRequest.EARLIEST_TIMESTAMP && + targetTimestamp != ListOffsetsRequest.LATEST_TIMESTAMP) + throw new UnsupportedForMessageFormatException(s"Cannot search offsets based on timestamp because message format version " + + s"for partition $topicPartition is ${config.messageFormatVersion} which is earlier than the minimum " + + s"required version $KAFKA_0_10_0_IV0") + + // For the earliest and latest, we do not need to return the timestamp. + if (targetTimestamp == ListOffsetsRequest.EARLIEST_TIMESTAMP) { + // The first cached epoch usually corresponds to the log start offset, but we have to verify this since + // it may not be true following a message format version bump as the epoch will not be available for + // log entries written in the older format. + val earliestEpochEntry = leaderEpochCache.flatMap(_.earliestEntry) + val epochOpt = earliestEpochEntry match { + case Some(entry) if entry.startOffset <= logStartOffset => Optional.of[Integer](entry.epoch) + case _ => Optional.empty[Integer]() + } + Some(new TimestampAndOffset(RecordBatch.NO_TIMESTAMP, logStartOffset, epochOpt)) + } else if (targetTimestamp == ListOffsetsRequest.LATEST_TIMESTAMP) { + val latestEpochOpt = leaderEpochCache.flatMap(_.latestEpoch).map(_.asInstanceOf[Integer]) + val epochOptional = Optional.ofNullable(latestEpochOpt.orNull) + Some(new TimestampAndOffset(RecordBatch.NO_TIMESTAMP, logEndOffset, epochOptional)) + } else if (targetTimestamp == ListOffsetsRequest.MAX_TIMESTAMP) { + // Cache to avoid race conditions. `toBuffer` is faster than most alternatives and provides + // constant time access while being safe to use with concurrent collections unlike `toArray`. 
+ val segmentsCopy = logSegments.toBuffer + val latestTimestampSegment = segmentsCopy.maxBy(_.maxTimestampSoFar) + val latestEpochOpt = leaderEpochCache.flatMap(_.latestEpoch).map(_.asInstanceOf[Integer]) + val epochOptional = Optional.ofNullable(latestEpochOpt.orNull) + val latestTimestampAndOffset = latestTimestampSegment.maxTimestampAndOffsetSoFar + Some(new TimestampAndOffset(latestTimestampAndOffset.timestamp, + latestTimestampAndOffset.offset, + epochOptional)) + } else { + // Cache to avoid race conditions. `toBuffer` is faster than most alternatives and provides + // constant time access while being safe to use with concurrent collections unlike `toArray`. + val segmentsCopy = logSegments.toBuffer + // We need to search the first segment whose largest timestamp is >= the target timestamp if there is one. + val targetSeg = segmentsCopy.find(_.largestTimestamp >= targetTimestamp) + targetSeg.flatMap(_.findOffsetByTimestamp(targetTimestamp, logStartOffset)) + } + } + } + + def legacyFetchOffsetsBefore(timestamp: Long, maxNumOffsets: Int): Seq[Long] = { + // Cache to avoid race conditions. `toBuffer` is faster than most alternatives and provides + // constant time access while being safe to use with concurrent collections unlike `toArray`. + val allSegments = logSegments.toBuffer + val lastSegmentHasSize = allSegments.last.size > 0 + + val offsetTimeArray = + if (lastSegmentHasSize) + new Array[(Long, Long)](allSegments.length + 1) + else + new Array[(Long, Long)](allSegments.length) + + for (i <- allSegments.indices) + offsetTimeArray(i) = (math.max(allSegments(i).baseOffset, logStartOffset), allSegments(i).lastModified) + if (lastSegmentHasSize) + offsetTimeArray(allSegments.length) = (logEndOffset, time.milliseconds) + + var startIndex = -1 + timestamp match { + case ListOffsetsRequest.LATEST_TIMESTAMP => + startIndex = offsetTimeArray.length - 1 + case ListOffsetsRequest.EARLIEST_TIMESTAMP => + startIndex = 0 + case _ => + var isFound = false + debug("Offset time array = " + offsetTimeArray.foreach(o => "%d, %d".format(o._1, o._2))) + startIndex = offsetTimeArray.length - 1 + while (startIndex >= 0 && !isFound) { + if (offsetTimeArray(startIndex)._2 <= timestamp) + isFound = true + else + startIndex -= 1 + } + } + + val retSize = maxNumOffsets.min(startIndex + 1) + val ret = new Array[Long](retSize) + for (j <- 0 until retSize) { + ret(j) = offsetTimeArray(startIndex)._1 + startIndex -= 1 + } + // ensure that the returned seq is in descending order of offsets + ret.toSeq.sortBy(-_) + } + + /** + * Given a message offset, find its corresponding offset metadata in the log. + * If the message offset is out of range, throw an OffsetOutOfRangeException + */ + private def convertToOffsetMetadataOrThrow(offset: Long): LogOffsetMetadata = { + checkLogStartOffset(offset) + localLog.convertToOffsetMetadataOrThrow(offset) + } + + /** + * Delete any local log segments starting with the oldest segment and moving forward until + * the user-supplied predicate is false or the segment containing the current high watermark is reached. + * We do not delete segments with offsets at or beyond the high watermark to ensure that the log start + * offset can never exceed it. If the high watermark has not yet been initialized, no segments are eligible + * for deletion. 
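A standalone restatement of the deletion-eligibility rule above (a hypothetical helper mirroring the shouldDelete wrapper that follows): a segment may only be deleted once every offset it contains is below the high watermark, where the segment's upper bound is the next segment's base offset, or the log end offset for the last segment.

    def segmentFullyBelowHighWatermark(highWatermark: Long,
                                       nextSegmentBaseOffset: Option[Long],
                                       logEndOffset: Long): Boolean =
      highWatermark >= nextSegmentBaseOffset.getOrElse(logEndOffset)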
+ * + * @param predicate A function that takes in a candidate log segment and the next higher segment + * (if there is one) and returns true iff it is deletable + * @param reason The reason for the segment deletion + * @return The number of segments deleted + */ + private def deleteOldSegments(predicate: (LogSegment, Option[LogSegment]) => Boolean, + reason: SegmentDeletionReason): Int = { + def shouldDelete(segment: LogSegment, nextSegmentOpt: Option[LogSegment]): Boolean = { + highWatermark >= nextSegmentOpt.map(_.baseOffset).getOrElse(localLog.logEndOffset) && + predicate(segment, nextSegmentOpt) + } + lock synchronized { + val deletable = localLog.deletableSegments(shouldDelete) + if (deletable.nonEmpty) + deleteSegments(deletable, reason) + else + 0 + } + } + + private def deleteSegments(deletable: Iterable[LogSegment], reason: SegmentDeletionReason): Int = { + maybeHandleIOException(s"Error while deleting segments for $topicPartition in dir ${dir.getParent}") { + val numToDelete = deletable.size + if (numToDelete > 0) { + // we must always have at least one segment, so if we are going to delete all the segments, create a new one first + if (localLog.segments.numberOfSegments == numToDelete) + roll() + lock synchronized { + localLog.checkIfMemoryMappedBufferClosed() + // remove the segments for lookups + localLog.removeAndDeleteSegments(deletable, asyncDelete = true, reason) + deleteProducerSnapshots(deletable, asyncDelete = true) + maybeIncrementLogStartOffset(localLog.segments.firstSegmentBaseOffset.get, SegmentDeletion) + } + } + numToDelete + } + } + + /** + * If topic deletion is enabled, delete any local log segments that have either expired due to time based retention + * or because the log size is > retentionSize. + * + * Whether or not deletion is enabled, delete any local log segments that are before the log start offset + */ + def deleteOldSegments(): Int = { + if (config.delete) { + deleteLogStartOffsetBreachedSegments() + + deleteRetentionSizeBreachedSegments() + + deleteRetentionMsBreachedSegments() + } else { + deleteLogStartOffsetBreachedSegments() + } + } + + private def deleteRetentionMsBreachedSegments(): Int = { + if (config.retentionMs < 0) return 0 + val startMs = time.milliseconds + + def shouldDelete(segment: LogSegment, nextSegmentOpt: Option[LogSegment]): Boolean = { + startMs - segment.largestTimestamp > config.retentionMs + } + + deleteOldSegments(shouldDelete, RetentionMsBreach(this)) + } + + private def deleteRetentionSizeBreachedSegments(): Int = { + if (config.retentionSize < 0 || size < config.retentionSize) return 0 + var diff = size - config.retentionSize + def shouldDelete(segment: LogSegment, nextSegmentOpt: Option[LogSegment]): Boolean = { + if (diff - segment.size >= 0) { + diff -= segment.size + true + } else { + false + } + } + + deleteOldSegments(shouldDelete, RetentionSizeBreach(this)) + } + + private def deleteLogStartOffsetBreachedSegments(): Int = { + def shouldDelete(segment: LogSegment, nextSegmentOpt: Option[LogSegment]): Boolean = { + nextSegmentOpt.exists(_.baseOffset <= logStartOffset) + } + + deleteOldSegments(shouldDelete, StartOffsetBreach(this)) + } + + def isFuture: Boolean = localLog.isFuture + + /** + * The size of the log in bytes + */ + def size: Long = localLog.segments.sizeInBytes + + /** + * The offset of the next message that will be appended to the log + */ + def logEndOffset: Long = localLog.logEndOffset + + /** + * The offset metadata of the next message that will be appended to the log + */ + def 
logEndOffsetMetadata: LogOffsetMetadata = localLog.logEndOffsetMetadata + + /** + * Roll the log over to a new empty log segment if necessary. + * The segment will be rolled if one of the following conditions met: + * 1. The logSegment is full + * 2. The maxTime has elapsed since the timestamp of first message in the segment (or since the + * create time if the first message does not have a timestamp) + * 3. The index is full + * + * @param messagesSize The messages set size in bytes. + * @param appendInfo log append information + * + * @return The currently active segment after (perhaps) rolling to a new segment + */ + private def maybeRoll(messagesSize: Int, appendInfo: LogAppendInfo): LogSegment = lock synchronized { + val segment = localLog.segments.activeSegment + val now = time.milliseconds + + val maxTimestampInMessages = appendInfo.maxTimestamp + val maxOffsetInMessages = appendInfo.lastOffset + + if (segment.shouldRoll(RollParams(config, appendInfo, messagesSize, now))) { + debug(s"Rolling new log segment (log_size = ${segment.size}/${config.segmentSize}}, " + + s"offset_index_size = ${segment.offsetIndex.entries}/${segment.offsetIndex.maxEntries}, " + + s"time_index_size = ${segment.timeIndex.entries}/${segment.timeIndex.maxEntries}, " + + s"inactive_time_ms = ${segment.timeWaitedForRoll(now, maxTimestampInMessages)}/${config.segmentMs - segment.rollJitterMs}).") + + /* + maxOffsetInMessages - Integer.MAX_VALUE is a heuristic value for the first offset in the set of messages. + Since the offset in messages will not differ by more than Integer.MAX_VALUE, this is guaranteed <= the real + first offset in the set. Determining the true first offset in the set requires decompression, which the follower + is trying to avoid during log append. Prior behavior assigned new baseOffset = logEndOffset from old segment. + This was problematic in the case that two consecutive messages differed in offset by + Integer.MAX_VALUE.toLong + 2 or more. In this case, the prior behavior would roll a new log segment whose + base offset was too low to contain the next message. This edge case is possible when a replica is recovering a + highly compacted topic from scratch. + Note that this is only required for pre-V2 message formats because these do not store the first message offset + in the header. + */ + val rollOffset = appendInfo + .firstOffset + .map(_.messageOffset) + .getOrElse(maxOffsetInMessages - Integer.MAX_VALUE) + + roll(Some(rollOffset)) + } else { + segment + } + } + + /** + * Roll the local log over to a new active segment starting with the expectedNextOffset (when provided), + * or localLog.logEndOffset otherwise. This will trim the index to the exact size of the number of entries + * it currently contains. + * + * @return The newly rolled segment + */ + def roll(expectedNextOffset: Option[Long] = None): LogSegment = lock synchronized { + val newSegment = localLog.roll(expectedNextOffset) + // Take a snapshot of the producer state to facilitate recovery. It is useful to have the snapshot + // offset align with the new segment offset since this ensures we can recover the segment by beginning + // with the corresponding snapshot file and scanning the segment data. Because the segment base offset + // may actually be ahead of the current producer state end offset (which corresponds to the log end offset), + // we manually override the state offset here prior to taking the snapshot. 
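+    // Illustrative numbers only (not taken from a real run): if the new segment is rolled at base offset 100
+    // while the producer state's mapEndOffset is still 90, the call below moves the map end offset to 100 so
+    // that the snapshot taken next is written under a name derived from offset 100, lining up with the new
+    // segment's base offset rather than with offset 90.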
+ producerStateManager.updateMapEndOffset(newSegment.baseOffset) + producerStateManager.takeSnapshot() + updateHighWatermarkWithLogEndOffset() + // Schedule an asynchronous flush of the old segment + scheduler.schedule("flush-log", () => flush(newSegment.baseOffset)) + newSegment + } + + /** + * Flush all local log segments + */ + def flush(): Unit = flush(logEndOffset) + + /** + * Flush local log segments for all offsets up to offset-1 + * + * @param offset The offset to flush up to (non-inclusive); the new recovery point + */ + def flush(offset: Long): Unit = { + maybeHandleIOException(s"Error while flushing log for $topicPartition in dir ${dir.getParent} with offset $offset") { + if (offset > localLog.recoveryPoint) { + debug(s"Flushing log up to offset $offset, last flushed: $lastFlushTime, current time: ${time.milliseconds()}, " + + s"unflushed: ${localLog.unflushedMessages}") + localLog.flush(offset) + lock synchronized { + localLog.markFlushed(offset) + } + } + } + } + + /** + * Completely delete the local log directory and all contents from the file system with no delay + */ + private[log] def delete(): Unit = { + maybeHandleIOException(s"Error while deleting log for $topicPartition in dir ${dir.getParent}") { + lock synchronized { + localLog.checkIfMemoryMappedBufferClosed() + producerExpireCheck.cancel(true) + leaderEpochCache.foreach(_.clear()) + val deletedSegments = localLog.deleteAllSegments() + deleteProducerSnapshots(deletedSegments, asyncDelete = false) + localLog.deleteEmptyDir() + } + } + } + + // visible for testing + private[log] def takeProducerSnapshot(): Unit = lock synchronized { + localLog.checkIfMemoryMappedBufferClosed() + producerStateManager.takeSnapshot() + } + + // visible for testing + private[log] def latestProducerSnapshotOffset: Option[Long] = lock synchronized { + producerStateManager.latestSnapshotOffset + } + + // visible for testing + private[log] def oldestProducerSnapshotOffset: Option[Long] = lock synchronized { + producerStateManager.oldestSnapshotOffset + } + + // visible for testing + private[log] def latestProducerStateEndOffset: Long = lock synchronized { + producerStateManager.mapEndOffset + } + + /** + * Truncate this log so that it ends with the greatest offset < targetOffset. + * + * @param targetOffset The offset to truncate to, an upper bound on all offsets in the log after truncation is complete. + * @return True iff targetOffset < logEndOffset + */ + private[kafka] def truncateTo(targetOffset: Long): Boolean = { + maybeHandleIOException(s"Error while truncating log to offset $targetOffset for $topicPartition in dir ${dir.getParent}") { + if (targetOffset < 0) + throw new IllegalArgumentException(s"Cannot truncate partition $topicPartition to a negative offset (%d).".format(targetOffset)) + if (targetOffset >= localLog.logEndOffset) { + info(s"Truncating to $targetOffset has no effect as the largest offset in the log is ${localLog.logEndOffset - 1}") + + // Always truncate epoch cache since we may have a conflicting epoch entry at the + // end of the log from the leader. This could happen if this broker was a leader + // and inserted the first start offset entry, but then failed to append any entries + // before another leader was elected. 
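+      // Illustrative scenario (invented epochs): this broker becomes leader at epoch 5 and writes the entry
+      // (epoch=5, startOffset=logEndOffset) into the cache, but fails before appending any records; a new
+      // leader is then elected at epoch 6. Without the truncateFromEnd call below, that stale epoch-5 entry
+      // at the log end would survive even though the truncation itself is a no-op.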
+ lock synchronized { + leaderEpochCache.foreach(_.truncateFromEnd(logEndOffset)) + } + + false + } else { + info(s"Truncating to offset $targetOffset") + lock synchronized { + localLog.checkIfMemoryMappedBufferClosed() + if (localLog.segments.firstSegmentBaseOffset.get > targetOffset) { + truncateFullyAndStartAt(targetOffset) + } else { + val deletedSegments = localLog.truncateTo(targetOffset) + deleteProducerSnapshots(deletedSegments, asyncDelete = true) + leaderEpochCache.foreach(_.truncateFromEnd(targetOffset)) + logStartOffset = math.min(targetOffset, logStartOffset) + rebuildProducerState(targetOffset, producerStateManager) + if (highWatermark >= localLog.logEndOffset) + updateHighWatermark(localLog.logEndOffsetMetadata) + } + true + } + } + } + } + + /** + * Delete all data in the log and start at the new offset + * + * @param newOffset The new offset to start the log with + */ + def truncateFullyAndStartAt(newOffset: Long): Unit = { + maybeHandleIOException(s"Error while truncating the entire log for $topicPartition in dir ${dir.getParent}") { + debug(s"Truncate and start at offset $newOffset") + lock synchronized { + localLog.truncateFullyAndStartAt(newOffset) + leaderEpochCache.foreach(_.clearAndFlush()) + producerStateManager.truncateFullyAndStartAt(newOffset) + logStartOffset = newOffset + rebuildProducerState(newOffset, producerStateManager) + updateHighWatermark(localLog.logEndOffsetMetadata) + } + } + } + + /** + * The time this log is last known to have been fully flushed to disk + */ + def lastFlushTime: Long = localLog.lastFlushTime + + /** + * The active segment that is currently taking appends + */ + def activeSegment: LogSegment = localLog.segments.activeSegment + + /** + * All the log segments in this log ordered from oldest to newest + */ + def logSegments: Iterable[LogSegment] = localLog.segments.values + + /** + * Get all segments beginning with the segment that includes "from" and ending with the segment + * that includes up to "to-1" or the end of the log (if to > logEndOffset). + */ + def logSegments(from: Long, to: Long): Iterable[LogSegment] = lock synchronized { + localLog.segments.values(from, to) + } + + def nonActiveLogSegmentsFrom(from: Long): Iterable[LogSegment] = lock synchronized { + localLog.segments.nonActiveLogSegmentsFrom(from) + } + + override def toString: String = { + val logString = new StringBuilder + logString.append(s"Log(dir=$dir") + topicId.foreach(id => logString.append(s", topicId=$id")) + logString.append(s", topic=${topicPartition.topic}") + logString.append(s", partition=${topicPartition.partition}") + logString.append(s", highWatermark=$highWatermark") + logString.append(s", lastStableOffset=$lastStableOffset") + logString.append(s", logStartOffset=$logStartOffset") + logString.append(s", logEndOffset=$logEndOffset") + logString.append(")") + logString.toString + } + + private[log] def replaceSegments(newSegments: Seq[LogSegment], oldSegments: Seq[LogSegment]): Unit = { + lock synchronized { + localLog.checkIfMemoryMappedBufferClosed() + val deletedSegments = UnifiedLog.replaceSegments(localLog.segments, newSegments, oldSegments, dir, topicPartition, + config, scheduler, logDirFailureChannel, logIdent) + deleteProducerSnapshots(deletedSegments, asyncDelete = true) + } + } + + /** + * This function does not acquire Log.lock. The caller has to make sure log segments don't get deleted during + * this call, and also protects against calling this function on the same segment in parallel. 
+ * + * Currently, it is used by LogCleaner threads on log compact non-active segments only with LogCleanerManager's lock + * to ensure no other logcleaner threads and retention thread can work on the same segment. + */ + private[log] def getFirstBatchTimestampForSegments(segments: Iterable[LogSegment]): Iterable[Long] = { + LogSegments.getFirstBatchTimestampForSegments(segments) + } + + /** + * remove deleted log metrics + */ + private[log] def removeLogMetrics(): Unit = { + removeMetric(LogMetricNames.NumLogSegments, tags) + removeMetric(LogMetricNames.LogStartOffset, tags) + removeMetric(LogMetricNames.LogEndOffset, tags) + removeMetric(LogMetricNames.Size, tags) + } + + /** + * Add the given segment to the segments in this log. If this segment replaces an existing segment, delete it. + * @param segment The segment to add + */ + @threadsafe + private[log] def addSegment(segment: LogSegment): LogSegment = localLog.segments.add(segment) + + private def maybeHandleIOException[T](msg: => String)(fun: => T): T = { + LocalLog.maybeHandleIOException(logDirFailureChannel, parentDir, msg) { + fun + } + } + + private[log] def splitOverflowedSegment(segment: LogSegment): List[LogSegment] = lock synchronized { + val result = UnifiedLog.splitOverflowedSegment(segment, localLog.segments, dir, topicPartition, config, scheduler, logDirFailureChannel, logIdent) + deleteProducerSnapshots(result.deletedSegments, asyncDelete = true) + result.newSegments.toList + } + + private[log] def deleteProducerSnapshots(segments: Iterable[LogSegment], asyncDelete: Boolean): Unit = { + UnifiedLog.deleteProducerSnapshots(segments, producerStateManager, asyncDelete, scheduler, config, logDirFailureChannel, parentDir, topicPartition) + } +} + +object UnifiedLog extends Logging { + val LogFileSuffix = LocalLog.LogFileSuffix + + val IndexFileSuffix = LocalLog.IndexFileSuffix + + val TimeIndexFileSuffix = LocalLog.TimeIndexFileSuffix + + val ProducerSnapshotFileSuffix = ".snapshot" + + val TxnIndexFileSuffix = LocalLog.TxnIndexFileSuffix + + val DeletedFileSuffix = LocalLog.DeletedFileSuffix + + val CleanedFileSuffix = LocalLog.CleanedFileSuffix + + val SwapFileSuffix = LocalLog.SwapFileSuffix + + val DeleteDirSuffix = LocalLog.DeleteDirSuffix + + val FutureDirSuffix = LocalLog.FutureDirSuffix + + private[log] val DeleteDirPattern = LocalLog.DeleteDirPattern + private[log] val FutureDirPattern = LocalLog.FutureDirPattern + + val UnknownOffset = LocalLog.UnknownOffset + + def apply(dir: File, + config: LogConfig, + logStartOffset: Long, + recoveryPoint: Long, + scheduler: Scheduler, + brokerTopicStats: BrokerTopicStats, + time: Time = Time.SYSTEM, + maxProducerIdExpirationMs: Int, + producerIdExpirationCheckIntervalMs: Int, + logDirFailureChannel: LogDirFailureChannel, + lastShutdownClean: Boolean = true, + topicId: Option[Uuid], + keepPartitionMetadataFile: Boolean): UnifiedLog = { + // create the log directory if it doesn't exist + Files.createDirectories(dir.toPath) + val topicPartition = UnifiedLog.parseTopicPartitionName(dir) + val segments = new LogSegments(topicPartition) + val leaderEpochCache = UnifiedLog.maybeCreateLeaderEpochCache( + dir, + topicPartition, + logDirFailureChannel, + config.recordVersion, + s"[UnifiedLog partition=$topicPartition, dir=${dir.getParent}] ") + val producerStateManager = new ProducerStateManager(topicPartition, dir, maxProducerIdExpirationMs) + val offsets = LogLoader.load(LoadLogParams( + dir, + topicPartition, + config, + scheduler, + time, + logDirFailureChannel, + lastShutdownClean, 
+ segments, + logStartOffset, + recoveryPoint, + maxProducerIdExpirationMs, + leaderEpochCache, + producerStateManager)) + val localLog = new LocalLog(dir, config, segments, offsets.recoveryPoint, + offsets.nextOffsetMetadata, scheduler, time, topicPartition, logDirFailureChannel) + new UnifiedLog(offsets.logStartOffset, + localLog, + brokerTopicStats, + producerIdExpirationCheckIntervalMs, + leaderEpochCache, + producerStateManager, + topicId, + keepPartitionMetadataFile) + } + + def logFile(dir: File, offset: Long, suffix: String = ""): File = LocalLog.logFile(dir, offset, suffix) + + def logDeleteDirName(topicPartition: TopicPartition): String = LocalLog.logDeleteDirName(topicPartition) + + def logFutureDirName(topicPartition: TopicPartition): String = LocalLog.logFutureDirName(topicPartition) + + def logDirName(topicPartition: TopicPartition): String = LocalLog.logDirName(topicPartition) + + def offsetIndexFile(dir: File, offset: Long, suffix: String = ""): File = LocalLog.offsetIndexFile(dir, offset, suffix) + + def timeIndexFile(dir: File, offset: Long, suffix: String = ""): File = LocalLog.timeIndexFile(dir, offset, suffix) + + def deleteFileIfExists(file: File, suffix: String = ""): Unit = + Files.deleteIfExists(new File(file.getPath + suffix).toPath) + + /** + * Construct a producer id snapshot file using the given offset. + * + * @param dir The directory in which the log will reside + * @param offset The last offset (exclusive) included in the snapshot + */ + def producerSnapshotFile(dir: File, offset: Long): File = + new File(dir, LocalLog.filenamePrefixFromOffset(offset) + ProducerSnapshotFileSuffix) + + def transactionIndexFile(dir: File, offset: Long, suffix: String = ""): File = LocalLog.transactionIndexFile(dir, offset, suffix) + + def offsetFromFileName(filename: String): Long = LocalLog.offsetFromFileName(filename) + + def offsetFromFile(file: File): Long = LocalLog.offsetFromFile(file) + + def sizeInBytes(segments: Iterable[LogSegment]): Long = LogSegments.sizeInBytes(segments) + + def parseTopicPartitionName(dir: File): TopicPartition = LocalLog.parseTopicPartitionName(dir) + + private[log] def isIndexFile(file: File): Boolean = LocalLog.isIndexFile(file) + + private[log] def isLogFile(file: File): Boolean = LocalLog.isLogFile(file) + + private def loadProducersFromRecords(producerStateManager: ProducerStateManager, records: Records): Unit = { + val loadedProducers = mutable.Map.empty[Long, ProducerAppendInfo] + val completedTxns = ListBuffer.empty[CompletedTxn] + records.batches.forEach { batch => + if (batch.hasProducerId) { + val maybeCompletedTxn = updateProducers( + producerStateManager, + batch, + loadedProducers, + firstOffsetMetadata = None, + origin = AppendOrigin.Replication) + maybeCompletedTxn.foreach(completedTxns += _) + } + } + loadedProducers.values.foreach(producerStateManager.update) + completedTxns.foreach(producerStateManager.completeTxn) + } + + private def updateProducers(producerStateManager: ProducerStateManager, + batch: RecordBatch, + producers: mutable.Map[Long, ProducerAppendInfo], + firstOffsetMetadata: Option[LogOffsetMetadata], + origin: AppendOrigin): Option[CompletedTxn] = { + val producerId = batch.producerId + val appendInfo = producers.getOrElseUpdate(producerId, producerStateManager.prepareUpdate(producerId, origin)) + appendInfo.append(batch, firstOffsetMetadata) + } + + /** + * If the recordVersion is >= RecordVersion.V2, then create and return a LeaderEpochFileCache. 
+ * Otherwise, the message format is considered incompatible and the existing LeaderEpoch file + * is deleted. + * + * @param dir The directory in which the log will reside + * @param topicPartition The topic partition + * @param logDirFailureChannel The LogDirFailureChannel to asynchronously handle log dir failure + * @param recordVersion The record version + * @param logPrefix The logging prefix + * @return The new LeaderEpochFileCache instance (if created), none otherwise + */ + def maybeCreateLeaderEpochCache(dir: File, + topicPartition: TopicPartition, + logDirFailureChannel: LogDirFailureChannel, + recordVersion: RecordVersion, + logPrefix: String): Option[LeaderEpochFileCache] = { + val leaderEpochFile = LeaderEpochCheckpointFile.newFile(dir) + + def newLeaderEpochFileCache(): LeaderEpochFileCache = { + val checkpointFile = new LeaderEpochCheckpointFile(leaderEpochFile, logDirFailureChannel) + new LeaderEpochFileCache(topicPartition, checkpointFile) + } + + if (recordVersion.precedes(RecordVersion.V2)) { + val currentCache = if (leaderEpochFile.exists()) + Some(newLeaderEpochFileCache()) + else + None + + if (currentCache.exists(_.nonEmpty)) + warn(s"${logPrefix}Deleting non-empty leader epoch cache due to incompatible message format $recordVersion") + + Files.deleteIfExists(leaderEpochFile.toPath) + None + } else { + Some(newLeaderEpochFileCache()) + } + } + + private[log] def replaceSegments(existingSegments: LogSegments, + newSegments: Seq[LogSegment], + oldSegments: Seq[LogSegment], + dir: File, + topicPartition: TopicPartition, + config: LogConfig, + scheduler: Scheduler, + logDirFailureChannel: LogDirFailureChannel, + logPrefix: String, + isRecoveredSwapFile: Boolean = false): Iterable[LogSegment] = { + LocalLog.replaceSegments(existingSegments, + newSegments, + oldSegments, + dir, + topicPartition, + config, + scheduler, + logDirFailureChannel, + logPrefix, + isRecoveredSwapFile) + } + + private[log] def deleteSegmentFiles(segmentsToDelete: immutable.Iterable[LogSegment], + asyncDelete: Boolean, + dir: File, + topicPartition: TopicPartition, + config: LogConfig, + scheduler: Scheduler, + logDirFailureChannel: LogDirFailureChannel, + logPrefix: String): Unit = { + LocalLog.deleteSegmentFiles(segmentsToDelete, asyncDelete, dir, topicPartition, config, scheduler, logDirFailureChannel, logPrefix) + } + + /** + * Rebuilds producer state until the provided lastOffset. This function may be called from the + * recovery code path, and thus must be free of all side-effects, i.e. it must not update any + * log-specific state. + * + * @param producerStateManager The ProducerStateManager instance to be rebuilt. + * @param segments The segments of the log whose producer state is being rebuilt + * @param logStartOffset The log start offset + * @param lastOffset The last offset upto which the producer state needs to be rebuilt + * @param recordVersion The record version + * @param time The time instance used for checking the clock + * @param reloadFromCleanShutdown True if the producer state is being built after a clean shutdown, + * false otherwise. 
+ * @param logPrefix The logging prefix + */ + private[log] def rebuildProducerState(producerStateManager: ProducerStateManager, + segments: LogSegments, + logStartOffset: Long, + lastOffset: Long, + recordVersion: RecordVersion, + time: Time, + reloadFromCleanShutdown: Boolean, + logPrefix: String): Unit = { + val offsetsToSnapshot = + if (segments.nonEmpty) { + val lastSegmentBaseOffset = segments.lastSegment.get.baseOffset + val nextLatestSegmentBaseOffset = segments.lowerSegment(lastSegmentBaseOffset).map(_.baseOffset) + Seq(nextLatestSegmentBaseOffset, Some(lastSegmentBaseOffset), Some(lastOffset)) + } else { + Seq(Some(lastOffset)) + } + info(s"${logPrefix}Loading producer state till offset $lastOffset with message format version ${recordVersion.value}") + + // We want to avoid unnecessary scanning of the log to build the producer state when the broker is being + // upgraded. The basic idea is to use the absence of producer snapshot files to detect the upgrade case, + // but we have to be careful not to assume too much in the presence of broker failures. The two most common + // upgrade cases in which we expect to find no snapshots are the following: + // + // 1. The broker has been upgraded, but the topic is still on the old message format. + // 2. The broker has been upgraded, the topic is on the new message format, and we had a clean shutdown. + // + // If we hit either of these cases, we skip producer state loading and write a new snapshot at the log end + // offset (see below). The next time the log is reloaded, we will load producer state using this snapshot + // (or later snapshots). Otherwise, if there is no snapshot file, then we have to rebuild producer state + // from the first segment. + if (recordVersion.value < RecordBatch.MAGIC_VALUE_V2 || + (producerStateManager.latestSnapshotOffset.isEmpty && reloadFromCleanShutdown)) { + // To avoid an expensive scan through all of the segments, we take empty snapshots from the start of the + // last two segments and the last offset. This should avoid the full scan in the case that the log needs + // truncation. + offsetsToSnapshot.flatten.foreach { offset => + producerStateManager.updateMapEndOffset(offset) + producerStateManager.takeSnapshot() + } + } else { + info(s"${logPrefix}Reloading from producer snapshot and rebuilding producer state from offset $lastOffset") + val isEmptyBeforeTruncation = producerStateManager.isEmpty && producerStateManager.mapEndOffset >= lastOffset + val producerStateLoadStart = time.milliseconds() + producerStateManager.truncateAndReload(logStartOffset, lastOffset, time.milliseconds()) + val segmentRecoveryStart = time.milliseconds() + + // Only do the potentially expensive reloading if the last snapshot offset is lower than the log end + // offset (which would be the case on first startup) and there were active producers prior to truncation + // (which could be the case if truncating after initial loading). If there weren't, then truncating + // shouldn't change that fact (although it could cause a producerId to expire earlier than expected), + // and we can skip the loading. This is an optimization for users which are not yet using + // idempotent/transactional features yet. 
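+      // Worked example with invented offsets: with segments based at 0, 50 and 90, lastOffset = 100 and the
+      // latest snapshot at offset 60, truncateAndReload above leaves mapEndOffset at 60, so the loop below
+      // only scans the segments based at 50 and 90, starting from offset 60. If a snapshot already existed
+      // at offset 100, mapEndOffset would equal lastOffset and the scan would be skipped entirely.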
+ if (lastOffset > producerStateManager.mapEndOffset && !isEmptyBeforeTruncation) { + val segmentOfLastOffset = segments.floorSegment(lastOffset) + + segments.values(producerStateManager.mapEndOffset, lastOffset).foreach { segment => + val startOffset = Utils.max(segment.baseOffset, producerStateManager.mapEndOffset, logStartOffset) + producerStateManager.updateMapEndOffset(startOffset) + + if (offsetsToSnapshot.contains(Some(segment.baseOffset))) + producerStateManager.takeSnapshot() + + val maxPosition = if (segmentOfLastOffset.contains(segment)) { + Option(segment.translateOffset(lastOffset)) + .map(_.position) + .getOrElse(segment.size) + } else { + segment.size + } + + val fetchDataInfo = segment.read(startOffset, + maxSize = Int.MaxValue, + maxPosition = maxPosition) + if (fetchDataInfo != null) + loadProducersFromRecords(producerStateManager, fetchDataInfo.records) + } + } + producerStateManager.updateMapEndOffset(lastOffset) + producerStateManager.takeSnapshot() + info(s"${logPrefix}Producer state recovery took ${segmentRecoveryStart - producerStateLoadStart}ms for snapshot load " + + s"and ${time.milliseconds() - segmentRecoveryStart}ms for segment recovery from offset $lastOffset") + } + } + + private[log] def splitOverflowedSegment(segment: LogSegment, + existingSegments: LogSegments, + dir: File, + topicPartition: TopicPartition, + config: LogConfig, + scheduler: Scheduler, + logDirFailureChannel: LogDirFailureChannel, + logPrefix: String): SplitSegmentResult = { + LocalLog.splitOverflowedSegment(segment, existingSegments, dir, topicPartition, config, scheduler, logDirFailureChannel, logPrefix) + } + + private[log] def deleteProducerSnapshots(segments: Iterable[LogSegment], + producerStateManager: ProducerStateManager, + asyncDelete: Boolean, + scheduler: Scheduler, + config: LogConfig, + logDirFailureChannel: LogDirFailureChannel, + parentDir: String, + topicPartition: TopicPartition): Unit = { + val snapshotsToDelete = segments.flatMap { segment => + producerStateManager.removeAndMarkSnapshotForDeletion(segment.baseOffset)} + def deleteProducerSnapshots(): Unit = { + LocalLog.maybeHandleIOException(logDirFailureChannel, + parentDir, + s"Error while deleting producer state snapshots for $topicPartition in dir $parentDir") { + snapshotsToDelete.foreach { snapshot => + snapshot.deleteIfExists() + } + } + } + + if (asyncDelete) + scheduler.schedule("delete-producer-snapshot", () => deleteProducerSnapshots(), delay = config.fileDeleteDelayMs) + else + deleteProducerSnapshots() + } + + private[log] def createNewCleanedSegment(dir: File, logConfig: LogConfig, baseOffset: Long): LogSegment = { + LocalLog.createNewCleanedSegment(dir, logConfig, baseOffset) + } +} + +object LogMetricNames { + val NumLogSegments: String = "NumLogSegments" + val LogStartOffset: String = "LogStartOffset" + val LogEndOffset: String = "LogEndOffset" + val Size: String = "Size" + + def allMetricNames: List[String] = { + List(NumLogSegments, LogStartOffset, LogEndOffset, Size) + } +} + +case class RetentionMsBreach(log: UnifiedLog) extends SegmentDeletionReason { + override def logReason(toDelete: List[LogSegment]): Unit = { + val retentionMs = log.config.retentionMs + toDelete.foreach { segment => + segment.largestRecordTimestamp match { + case Some(_) => + log.info(s"Deleting segment $segment due to retention time ${retentionMs}ms breach based on the largest " + + s"record timestamp in the segment") + case None => + log.info(s"Deleting segment $segment due to retention time ${retentionMs}ms breach based on 
the " + + s"last modified time of the segment") + } + } + } +} + +case class RetentionSizeBreach(log: UnifiedLog) extends SegmentDeletionReason { + override def logReason(toDelete: List[LogSegment]): Unit = { + var size = log.size + toDelete.foreach { segment => + size -= segment.size + log.info(s"Deleting segment $segment due to retention size ${log.config.retentionSize} breach. Log size " + + s"after deletion will be $size.") + } + } +} + +case class StartOffsetBreach(log: UnifiedLog) extends SegmentDeletionReason { + override def logReason(toDelete: List[LogSegment]): Unit = { + log.info(s"Deleting segments due to log start offset ${log.logStartOffset} breach: ${toDelete.mkString(",")}") + } +} diff --git a/core/src/main/scala/kafka/log/package.html b/core/src/main/scala/kafka/log/package.html new file mode 100644 index 0000000..ee2f72e --- /dev/null +++ b/core/src/main/scala/kafka/log/package.html @@ -0,0 +1,24 @@ + +The log management system for Kafka. + +The entry point for this system is LogManager. LogManager is responsible for holding all the logs, and handing them out by topic/partition. It also handles the enforcement of the +flush policy and retention policies. + +The Log itself is made up of log segments. A log is a FileRecords that contains the data and an OffsetIndex that supports reads by offset on the log. \ No newline at end of file diff --git a/core/src/main/scala/kafka/message/CompressionCodec.scala b/core/src/main/scala/kafka/message/CompressionCodec.scala new file mode 100644 index 0000000..b174fea --- /dev/null +++ b/core/src/main/scala/kafka/message/CompressionCodec.scala @@ -0,0 +1,108 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+*/ + +package kafka.message + +import java.util.Locale + +import kafka.common.UnknownCodecException + +object CompressionCodec { + def getCompressionCodec(codec: Int): CompressionCodec = { + codec match { + case NoCompressionCodec.codec => NoCompressionCodec + case GZIPCompressionCodec.codec => GZIPCompressionCodec + case SnappyCompressionCodec.codec => SnappyCompressionCodec + case LZ4CompressionCodec.codec => LZ4CompressionCodec + case ZStdCompressionCodec.codec => ZStdCompressionCodec + case _ => throw new UnknownCodecException("%d is an unknown compression codec".format(codec)) + } + } + def getCompressionCodec(name: String): CompressionCodec = { + name.toLowerCase(Locale.ROOT) match { + case NoCompressionCodec.name => NoCompressionCodec + case GZIPCompressionCodec.name => GZIPCompressionCodec + case SnappyCompressionCodec.name => SnappyCompressionCodec + case LZ4CompressionCodec.name => LZ4CompressionCodec + case ZStdCompressionCodec.name => ZStdCompressionCodec + case _ => throw new kafka.common.UnknownCodecException("%s is an unknown compression codec".format(name)) + } + } +} + +object BrokerCompressionCodec { + + val brokerCompressionCodecs = List(UncompressedCodec, ZStdCompressionCodec, LZ4CompressionCodec, SnappyCompressionCodec, GZIPCompressionCodec, ProducerCompressionCodec) + val brokerCompressionOptions: List[String] = brokerCompressionCodecs.map(codec => codec.name) + + def isValid(compressionType: String): Boolean = brokerCompressionOptions.contains(compressionType.toLowerCase(Locale.ROOT)) + + def getCompressionCodec(compressionType: String): CompressionCodec = { + compressionType.toLowerCase(Locale.ROOT) match { + case UncompressedCodec.name => NoCompressionCodec + case _ => CompressionCodec.getCompressionCodec(compressionType) + } + } + + def getTargetCompressionCodec(compressionType: String, producerCompression: CompressionCodec): CompressionCodec = { + if (ProducerCompressionCodec.name.equals(compressionType)) + producerCompression + else + getCompressionCodec(compressionType) + } +} + +sealed trait CompressionCodec { def codec: Int; def name: String } +sealed trait BrokerCompressionCodec { def name: String } + +case object DefaultCompressionCodec extends CompressionCodec with BrokerCompressionCodec { + val codec: Int = GZIPCompressionCodec.codec + val name: String = GZIPCompressionCodec.name +} + +case object GZIPCompressionCodec extends CompressionCodec with BrokerCompressionCodec { + val codec = 1 + val name = "gzip" +} + +case object SnappyCompressionCodec extends CompressionCodec with BrokerCompressionCodec { + val codec = 2 + val name = "snappy" +} + +case object LZ4CompressionCodec extends CompressionCodec with BrokerCompressionCodec { + val codec = 3 + val name = "lz4" +} + +case object ZStdCompressionCodec extends CompressionCodec with BrokerCompressionCodec { + val codec = 4 + val name = "zstd" +} + +case object NoCompressionCodec extends CompressionCodec with BrokerCompressionCodec { + val codec = 0 + val name = "none" +} + +case object UncompressedCodec extends BrokerCompressionCodec { + val name = "uncompressed" +} + +case object ProducerCompressionCodec extends BrokerCompressionCodec { + val name = "producer" +} diff --git a/core/src/main/scala/kafka/metrics/KafkaCSVMetricsReporter.scala b/core/src/main/scala/kafka/metrics/KafkaCSVMetricsReporter.scala new file mode 100755 index 0000000..0d83547 --- /dev/null +++ b/core/src/main/scala/kafka/metrics/KafkaCSVMetricsReporter.scala @@ -0,0 +1,87 @@ +/** + * + * + * + * Licensed to the Apache Software 
Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.metrics + +import java.io.File +import java.nio.file.Files + +import com.yammer.metrics.reporting.CsvReporter +import java.util.concurrent.TimeUnit + +import kafka.utils.{Logging, VerifiableProperties} +import org.apache.kafka.common.utils.Utils + +private trait KafkaCSVMetricsReporterMBean extends KafkaMetricsReporterMBean + +private class KafkaCSVMetricsReporter extends KafkaMetricsReporter + with KafkaCSVMetricsReporterMBean + with Logging { + + private var csvDir: File = null + private var underlying: CsvReporter = null + private var running = false + private var initialized = false + + + override def getMBeanName = "kafka:type=kafka.metrics.KafkaCSVMetricsReporter" + + + override def init(props: VerifiableProperties): Unit = { + synchronized { + if (!initialized) { + val metricsConfig = new KafkaMetricsConfig(props) + csvDir = new File(props.getString("kafka.csv.metrics.dir", "kafka_metrics")) + Utils.delete(csvDir) + Files.createDirectories(csvDir.toPath()) + underlying = new CsvReporter(KafkaYammerMetrics.defaultRegistry(), csvDir) + if (props.getBoolean("kafka.csv.metrics.reporter.enabled", default = false)) { + initialized = true + startReporter(metricsConfig.pollingIntervalSecs) + } + } + } + } + + + override def startReporter(pollingPeriodSecs: Long): Unit = { + synchronized { + if (initialized && !running) { + underlying.start(pollingPeriodSecs, TimeUnit.SECONDS) + running = true + info("Started Kafka CSV metrics reporter with polling period %d seconds".format(pollingPeriodSecs)) + } + } + } + + + override def stopReporter(): Unit = { + synchronized { + if (initialized && running) { + underlying.shutdown() + running = false + info("Stopped Kafka CSV metrics reporter") + underlying = new CsvReporter(KafkaYammerMetrics.defaultRegistry(), csvDir) + } + } + } + +} + diff --git a/core/src/main/scala/kafka/metrics/KafkaMetricsConfig.scala b/core/src/main/scala/kafka/metrics/KafkaMetricsConfig.scala new file mode 100755 index 0000000..b13a1b9 --- /dev/null +++ b/core/src/main/scala/kafka/metrics/KafkaMetricsConfig.scala @@ -0,0 +1,41 @@ +/** + * + * + * + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.metrics + +import kafka.server.{Defaults, KafkaConfig} +import kafka.utils.{CoreUtils, VerifiableProperties} +import scala.collection.Seq + +class KafkaMetricsConfig(props: VerifiableProperties) { + + /** + * Comma-separated list of reporter types. These classes should be on the + * classpath and will be instantiated at run-time. + */ + val reporters: Seq[String] = CoreUtils.parseCsvList(props.getString(KafkaConfig.KafkaMetricsReporterClassesProp, + Defaults.KafkaMetricReporterClasses)) + + /** + * The metrics polling interval (in seconds). + */ + val pollingIntervalSecs: Int = props.getInt(KafkaConfig.KafkaMetricsPollingIntervalSecondsProp, + Defaults.KafkaMetricsPollingIntervalSeconds) +} diff --git a/core/src/main/scala/kafka/metrics/KafkaMetricsGroup.scala b/core/src/main/scala/kafka/metrics/KafkaMetricsGroup.scala new file mode 100644 index 0000000..a63be1f --- /dev/null +++ b/core/src/main/scala/kafka/metrics/KafkaMetricsGroup.scala @@ -0,0 +1,107 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.metrics + +import java.util.concurrent.TimeUnit + +import com.yammer.metrics.core.{Gauge, MetricName, Meter, Histogram, Timer} +import kafka.utils.Logging +import org.apache.kafka.common.utils.Sanitizer + +trait KafkaMetricsGroup extends Logging { + + /** + * Creates a new MetricName object for gauges, meters, etc. created for this + * metrics group. + * @param name Descriptive name of the metric. + * @param tags Additional attributes which mBean will have. + * @return Sanitized metric name object. 
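+   *
+   * For illustration (the class name here is hypothetical, not one defined in this file): if a class
+   * `kafka.server.SocketServerStats` mixes in this trait, then
+   * {{{
+   * metricName("RequestsPerSec", Map("listener" -> "PLAINTEXT"))
+   * }}}
+   * resolves to an MBean name of the form
+   * `kafka.server:type=SocketServerStats,name=RequestsPerSec,listener=PLAINTEXT`.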
+ */ + def metricName(name: String, tags: scala.collection.Map[String, String]): MetricName = { + val klass = this.getClass + val pkg = if (klass.getPackage == null) "" else klass.getPackage.getName + val simpleName = klass.getSimpleName.replaceAll("\\$$", "") + + explicitMetricName(pkg, simpleName, name, tags) + } + + + protected def explicitMetricName(group: String, typeName: String, name: String, + tags: scala.collection.Map[String, String]): MetricName = { + + val nameBuilder: StringBuilder = new StringBuilder + + nameBuilder.append(group) + + nameBuilder.append(":type=") + + nameBuilder.append(typeName) + + if (name.length > 0) { + nameBuilder.append(",name=") + nameBuilder.append(name) + } + + val scope: String = toScope(tags).orNull + val tagsName = toMBeanName(tags) + tagsName.foreach(nameBuilder.append(",").append(_)) + + new MetricName(group, typeName, name, scope, nameBuilder.toString) + } + + def newGauge[T](name: String, metric: Gauge[T], tags: scala.collection.Map[String, String] = Map.empty): Gauge[T] = + KafkaYammerMetrics.defaultRegistry().newGauge(metricName(name, tags), metric) + + def newMeter(name: String, eventType: String, timeUnit: TimeUnit, tags: scala.collection.Map[String, String] = Map.empty): Meter = + KafkaYammerMetrics.defaultRegistry().newMeter(metricName(name, tags), eventType, timeUnit) + + def newHistogram(name: String, biased: Boolean = true, tags: scala.collection.Map[String, String] = Map.empty): Histogram = + KafkaYammerMetrics.defaultRegistry().newHistogram(metricName(name, tags), biased) + + def newTimer(name: String, durationUnit: TimeUnit, rateUnit: TimeUnit, tags: scala.collection.Map[String, String] = Map.empty): Timer = + KafkaYammerMetrics.defaultRegistry().newTimer(metricName(name, tags), durationUnit, rateUnit) + + def removeMetric(name: String, tags: scala.collection.Map[String, String] = Map.empty): Unit = + KafkaYammerMetrics.defaultRegistry().removeMetric(metricName(name, tags)) + + private def toMBeanName(tags: collection.Map[String, String]): Option[String] = { + val filteredTags = tags.filter { case (_, tagValue) => tagValue != "" } + if (filteredTags.nonEmpty) { + val tagsString = filteredTags.map { case (key, value) => "%s=%s".format(key, Sanitizer.jmxSanitize(value)) }.mkString(",") + Some(tagsString) + } + else None + } + + private def toScope(tags: collection.Map[String, String]): Option[String] = { + val filteredTags = tags.filter { case (_, tagValue) => tagValue != ""} + if (filteredTags.nonEmpty) { + // convert dot to _ since reporters like Graphite typically use dot to represent hierarchy + val tagsString = filteredTags + .toList.sortWith((t1, t2) => t1._1 < t2._1) + .map { case (key, value) => "%s.%s".format(key, value.replaceAll("\\.", "_"))} + .mkString(".") + + Some(tagsString) + } + else None + } + +} + +object KafkaMetricsGroup extends KafkaMetricsGroup diff --git a/core/src/main/scala/kafka/metrics/KafkaMetricsReporter.scala b/core/src/main/scala/kafka/metrics/KafkaMetricsReporter.scala new file mode 100755 index 0000000..30baad3 --- /dev/null +++ b/core/src/main/scala/kafka/metrics/KafkaMetricsReporter.scala @@ -0,0 +1,80 @@ +/** + * + * + * + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.metrics + +import kafka.utils.{CoreUtils, VerifiableProperties} +import java.util.concurrent.atomic.AtomicBoolean + +import scala.collection.Seq +import scala.collection.mutable.ArrayBuffer + + +/** + * Base trait for reporter MBeans. If a client wants to expose these JMX + * operations on a custom reporter (that implements + * [[kafka.metrics.KafkaMetricsReporter]]), the custom reporter needs to + * additionally implement an MBean trait that extends this trait so that the + * registered MBean is compliant with the standard MBean convention. + */ +trait KafkaMetricsReporterMBean { + def startReporter(pollingPeriodInSeconds: Long): Unit + def stopReporter(): Unit + /** + * + * @return The name with which the MBean will be registered. + */ + def getMBeanName: String +} + +/** + * Implement {@link org.apache.kafka.common.ClusterResourceListener} to receive cluster metadata once it's available. Please see the class documentation for ClusterResourceListener for more information. + */ +trait KafkaMetricsReporter { + def init(props: VerifiableProperties): Unit +} + +object KafkaMetricsReporter { + val ReporterStarted: AtomicBoolean = new AtomicBoolean(false) + private var reporters: ArrayBuffer[KafkaMetricsReporter] = null + + def startReporters(verifiableProps: VerifiableProperties): Seq[KafkaMetricsReporter] = { + ReporterStarted synchronized { + if (!ReporterStarted.get()) { + reporters = ArrayBuffer[KafkaMetricsReporter]() + val metricsConfig = new KafkaMetricsConfig(verifiableProps) + if (metricsConfig.reporters.nonEmpty) { + metricsConfig.reporters.foreach(reporterType => { + val reporter = CoreUtils.createObject[KafkaMetricsReporter](reporterType) + reporter.init(verifiableProps) + reporters += reporter + reporter match { + case bean: KafkaMetricsReporterMBean => CoreUtils.registerMBean(reporter, bean.getMBeanName) + case _ => + } + }) + ReporterStarted.set(true) + } + } + } + reporters + } +} + diff --git a/core/src/main/scala/kafka/metrics/KafkaTimer.scala b/core/src/main/scala/kafka/metrics/KafkaTimer.scala new file mode 100644 index 0000000..24b54d6 --- /dev/null +++ b/core/src/main/scala/kafka/metrics/KafkaTimer.scala @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package kafka.metrics + +import com.yammer.metrics.core.Timer + +/** + * A wrapper around metrics timer object that provides a convenient mechanism + * to time code blocks. This pattern was borrowed from the metrics-scala_2.9.1 + * package. + * @param metric The underlying timer object. + */ +class KafkaTimer(metric: Timer) { + + def time[A](f: => A): A = { + val ctx = metric.time + try f + finally ctx.stop() + } +} + diff --git a/core/src/main/scala/kafka/metrics/LinuxIoMetricsCollector.scala b/core/src/main/scala/kafka/metrics/LinuxIoMetricsCollector.scala new file mode 100644 index 0000000..17de008 --- /dev/null +++ b/core/src/main/scala/kafka/metrics/LinuxIoMetricsCollector.scala @@ -0,0 +1,102 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.metrics + +import java.nio.file.{Files, Paths} + +import org.apache.kafka.common.utils.Time +import org.slf4j.Logger + +import scala.jdk.CollectionConverters._ + +/** + * Retrieves Linux /proc/self/io metrics. + */ +class LinuxIoMetricsCollector(procRoot: String, val time: Time, val logger: Logger) { + import LinuxIoMetricsCollector._ + var lastUpdateMs = -1L + var cachedReadBytes = 0L + var cachedWriteBytes = 0L + val path = Paths.get(procRoot, "self", "io") + + def readBytes(): Long = this.synchronized { + val curMs = time.milliseconds() + if (curMs != lastUpdateMs) { + updateValues(curMs) + } + cachedReadBytes + } + + def writeBytes(): Long = this.synchronized { + val curMs = time.milliseconds() + if (curMs != lastUpdateMs) { + updateValues(curMs) + } + cachedWriteBytes + } + + /** + * Read /proc/self/io. + * + * Generally, each line in this file contains a prefix followed by a colon and a number. 
+ * + * For example, it might contain this: + * rchar: 4052 + * wchar: 0 + * syscr: 13 + * syscw: 0 + * read_bytes: 0 + * write_bytes: 0 + * cancelled_write_bytes: 0 + */ + def updateValues(now: Long): Boolean = this.synchronized { + try { + cachedReadBytes = -1 + cachedWriteBytes = -1 + val lines = Files.readAllLines(path).asScala + lines.foreach(line => { + if (line.startsWith(READ_BYTES_PREFIX)) { + cachedReadBytes = line.substring(READ_BYTES_PREFIX.size).toLong + } else if (line.startsWith(WRITE_BYTES_PREFIX)) { + cachedWriteBytes = line.substring(WRITE_BYTES_PREFIX.size).toLong + } + }) + lastUpdateMs = now + true + } catch { + case t: Throwable => { + logger.warn("Unable to update IO metrics", t) + false + } + } + } + + def usable(): Boolean = { + if (path.toFile().exists()) { + updateValues(time.milliseconds()) + } else { + logger.debug(s"disabling IO metrics collection because ${path} does not exist.") + false + } + } +} + +object LinuxIoMetricsCollector { + val READ_BYTES_PREFIX = "read_bytes: " + val WRITE_BYTES_PREFIX = "write_bytes: " +} diff --git a/core/src/main/scala/kafka/network/RequestChannel.scala b/core/src/main/scala/kafka/network/RequestChannel.scala new file mode 100644 index 0000000..5e456b0 --- /dev/null +++ b/core/src/main/scala/kafka/network/RequestChannel.scala @@ -0,0 +1,586 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package kafka.network + +import java.net.InetAddress +import java.nio.ByteBuffer +import java.util.concurrent._ +import com.fasterxml.jackson.databind.JsonNode +import com.typesafe.scalalogging.Logger +import com.yammer.metrics.core.Meter +import kafka.metrics.KafkaMetricsGroup +import kafka.network +import kafka.server.KafkaConfig +import kafka.utils.{Logging, NotNothing, Pool} +import kafka.utils.Implicits._ +import org.apache.kafka.common.config.ConfigResource +import org.apache.kafka.common.memory.MemoryPool +import org.apache.kafka.common.message.ApiMessageType.ListenerType +import org.apache.kafka.common.message.EnvelopeResponseData +import org.apache.kafka.common.network.Send +import org.apache.kafka.common.protocol.{ApiKeys, Errors, ObjectSerializationCache} +import org.apache.kafka.common.requests._ +import org.apache.kafka.common.security.auth.KafkaPrincipal +import org.apache.kafka.common.utils.{Sanitizer, Time} + +import scala.annotation.nowarn +import scala.collection.mutable +import scala.jdk.CollectionConverters._ +import scala.reflect.ClassTag + +object RequestChannel extends Logging { + private val requestLogger = Logger("kafka.request.logger") + + val RequestQueueSizeMetric = "RequestQueueSize" + val ResponseQueueSizeMetric = "ResponseQueueSize" + val ProcessorMetricTag = "processor" + + def isRequestLoggingEnabled: Boolean = requestLogger.underlying.isDebugEnabled + + sealed trait BaseRequest + case object ShutdownRequest extends BaseRequest + + case class Session(principal: KafkaPrincipal, clientAddress: InetAddress) { + val sanitizedUser: String = Sanitizer.sanitize(principal.getName) + } + + class Metrics(enabledApis: Iterable[ApiKeys]) { + def this(scope: ListenerType) = { + this(ApiKeys.apisForListener(scope).asScala) + } + + private val metricsMap = mutable.Map[String, RequestMetrics]() + + (enabledApis.map(_.name) ++ + Seq(RequestMetrics.consumerFetchMetricName, RequestMetrics.followFetchMetricName)).foreach { name => + metricsMap.put(name, new RequestMetrics(name)) + } + + def apply(metricName: String): RequestMetrics = metricsMap(metricName) + + def close(): Unit = { + metricsMap.values.foreach(_.removeMetrics()) + } + } + + class Request(val processor: Int, + val context: RequestContext, + val startTimeNanos: Long, + val memoryPool: MemoryPool, + @volatile var buffer: ByteBuffer, + metrics: RequestChannel.Metrics, + val envelope: Option[RequestChannel.Request] = None) extends BaseRequest { + // These need to be volatile because the readers are in the network thread and the writers are in the request + // handler threads or the purgatory threads + @volatile var requestDequeueTimeNanos = -1L + @volatile var apiLocalCompleteTimeNanos = -1L + @volatile var responseCompleteTimeNanos = -1L + @volatile var responseDequeueTimeNanos = -1L + @volatile var messageConversionsTimeNanos = 0L + @volatile var apiThrottleTimeMs = 0L + @volatile var temporaryMemoryBytes = 0L + @volatile var recordNetworkThreadTimeCallback: Option[Long => Unit] = None + + val session = Session(context.principal, context.clientAddress) + + private val bodyAndSize: RequestAndSize = context.parseRequest(buffer) + + // This is constructed on creation of a Request so that the JSON representation is computed before the request is + // processed by the api layer. Otherwise, a ProduceRequest can occur without its data (ie. it goes into purgatory). 
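+    // This is gated on the "kafka.request.logger" logger declared at the top of this file, so the JSON is
+    // only materialized when that logger is set to DEBUG, e.g. with a log4j 1.x style line such as
+    // `log4j.logger.kafka.request.logger=DEBUG` (typically placed in config/log4j.properties).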
+ val requestLog: Option[JsonNode] = + if (RequestChannel.isRequestLoggingEnabled) Some(RequestConvertToJson.request(loggableRequest)) + else None + + def header: RequestHeader = context.header + + def sizeOfBodyInBytes: Int = bodyAndSize.size + + def sizeInBytes: Int = header.size(new ObjectSerializationCache) + sizeOfBodyInBytes + + //most request types are parsed entirely into objects at this point. for those we can release the underlying buffer. + //some (like produce, or any time the schema contains fields of types BYTES or NULLABLE_BYTES) retain a reference + //to the buffer. for those requests we cannot release the buffer early, but only when request processing is done. + if (!header.apiKey.requiresDelayedAllocation) { + releaseBuffer() + } + + def isForwarded: Boolean = envelope.isDefined + + def buildResponseSend(abstractResponse: AbstractResponse): Send = { + envelope match { + case Some(request) => + val envelopeResponse = if (abstractResponse.errorCounts().containsKey(Errors.NOT_CONTROLLER)) { + // Since it's a NOT_CONTROLLER error response, we need to make envelope response with NOT_CONTROLLER error + // to notify the requester (i.e. BrokerToControllerRequestThread) to update active controller + new EnvelopeResponse(new EnvelopeResponseData() + .setErrorCode(Errors.NOT_CONTROLLER.code())) + } else { + val responseBytes = context.buildResponseEnvelopePayload(abstractResponse) + new EnvelopeResponse(responseBytes, Errors.NONE) + } + request.context.buildResponseSend(envelopeResponse) + case None => + context.buildResponseSend(abstractResponse) + } + } + + def responseNode(response: AbstractResponse): Option[JsonNode] = { + if (RequestChannel.isRequestLoggingEnabled) + Some(RequestConvertToJson.response(response, context.apiVersion)) + else + None + } + + def headerForLoggingOrThrottling(): RequestHeader = { + envelope match { + case Some(request) => + request.context.header + case None => + context.header + } + } + + def requestDesc(details: Boolean): String = { + val forwardDescription = envelope.map { request => + s"Forwarded request: ${request.context} " + }.getOrElse("") + s"$forwardDescription$header -- ${loggableRequest.toString(details)}" + } + + def body[T <: AbstractRequest](implicit classTag: ClassTag[T], @nowarn("cat=unused") nn: NotNothing[T]): T = { + bodyAndSize.request match { + case r: T => r + case r => + throw new ClassCastException(s"Expected request with type ${classTag.runtimeClass}, but found ${r.getClass}") + } + } + + def loggableRequest: AbstractRequest = { + + bodyAndSize.request match { + case alterConfigs: AlterConfigsRequest => + val newData = alterConfigs.data().duplicate() + newData.resources().forEach(resource => { + val resourceType = ConfigResource.Type.forId(resource.resourceType()) + resource.configs().forEach(config => { + config.setValue(KafkaConfig.loggableValue(resourceType, config.name(), config.value())) + }) + }) + new AlterConfigsRequest(newData, alterConfigs.version()) + + case alterConfigs: IncrementalAlterConfigsRequest => + val newData = alterConfigs.data().duplicate() + newData.resources().forEach(resource => { + val resourceType = ConfigResource.Type.forId(resource.resourceType()) + resource.configs().forEach(config => { + config.setValue(KafkaConfig.loggableValue(resourceType, config.name(), config.value())) + }) + }) + new IncrementalAlterConfigsRequest.Builder(newData).build(alterConfigs.version()) + + case _ => + bodyAndSize.request + } + } + + trace(s"Processor $processor received request: ${requestDesc(true)}") + + def 
requestThreadTimeNanos: Long = { + if (apiLocalCompleteTimeNanos == -1L) apiLocalCompleteTimeNanos = Time.SYSTEM.nanoseconds + math.max(apiLocalCompleteTimeNanos - requestDequeueTimeNanos, 0L) + } + + def updateRequestMetrics(networkThreadTimeNanos: Long, response: Response): Unit = { + val endTimeNanos = Time.SYSTEM.nanoseconds + + /** + * Converts nanos to millis with micros precision as additional decimal places in the request log have low + * signal to noise ratio. When it comes to metrics, there is little difference either way as we round the value + * to the nearest long. + */ + def nanosToMs(nanos: Long): Double = { + val positiveNanos = math.max(nanos, 0) + TimeUnit.NANOSECONDS.toMicros(positiveNanos).toDouble / TimeUnit.MILLISECONDS.toMicros(1) + } + + val requestQueueTimeMs = nanosToMs(requestDequeueTimeNanos - startTimeNanos) + val apiLocalTimeMs = nanosToMs(apiLocalCompleteTimeNanos - requestDequeueTimeNanos) + val apiRemoteTimeMs = nanosToMs(responseCompleteTimeNanos - apiLocalCompleteTimeNanos) + val responseQueueTimeMs = nanosToMs(responseDequeueTimeNanos - responseCompleteTimeNanos) + val responseSendTimeMs = nanosToMs(endTimeNanos - responseDequeueTimeNanos) + val messageConversionsTimeMs = nanosToMs(messageConversionsTimeNanos) + val totalTimeMs = nanosToMs(endTimeNanos - startTimeNanos) + val fetchMetricNames = + if (header.apiKey == ApiKeys.FETCH) { + val isFromFollower = body[FetchRequest].isFromFollower + Seq( + if (isFromFollower) RequestMetrics.followFetchMetricName + else RequestMetrics.consumerFetchMetricName + ) + } + else Seq.empty + val metricNames = fetchMetricNames :+ header.apiKey.name + metricNames.foreach { metricName => + val m = metrics(metricName) + m.requestRate(header.apiVersion).mark() + m.requestQueueTimeHist.update(Math.round(requestQueueTimeMs)) + m.localTimeHist.update(Math.round(apiLocalTimeMs)) + m.remoteTimeHist.update(Math.round(apiRemoteTimeMs)) + m.throttleTimeHist.update(apiThrottleTimeMs) + m.responseQueueTimeHist.update(Math.round(responseQueueTimeMs)) + m.responseSendTimeHist.update(Math.round(responseSendTimeMs)) + m.totalTimeHist.update(Math.round(totalTimeMs)) + m.requestBytesHist.update(sizeOfBodyInBytes) + m.messageConversionsTimeHist.foreach(_.update(Math.round(messageConversionsTimeMs))) + m.tempMemoryBytesHist.foreach(_.update(temporaryMemoryBytes)) + } + + // Records network handler thread usage. This is included towards the request quota for the + // user/client. Throttling is only performed when request handler thread usage + // is recorded, just before responses are queued for delivery. + // The time recorded here is the time spent on the network thread for receiving this request + // and sending the response. Note that for the first request on a connection, the time includes + // the total time spent on authentication, which may be significant for SASL/SSL. 
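+ // Illustrative worked example for nanosToMs above: it keeps microsecond precision for the request log,
+ // while the histograms round to whole milliseconds:
+ //   nanosToMs(1234567L)             == 1.234   (1,234,567 ns -> 1,234 us -> 1.234 ms)
+ //   Math.round(nanosToMs(1234567L)) == 1       (the value recorded in the *TimeMs histograms)
+ // Negative deltas (e.g. a timestamp that was never set) are clamped to zero before conversion.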
+ recordNetworkThreadTimeCallback.foreach(record => record(networkThreadTimeNanos)) + + if (isRequestLoggingEnabled) { + val desc = RequestConvertToJson.requestDescMetrics(header, requestLog, response.responseLog, + context, session, isForwarded, + totalTimeMs, requestQueueTimeMs, apiLocalTimeMs, + apiRemoteTimeMs, apiThrottleTimeMs, responseQueueTimeMs, + responseSendTimeMs, temporaryMemoryBytes, + messageConversionsTimeMs) + requestLogger.debug("Completed request:" + desc.toString) + } + } + + def releaseBuffer(): Unit = { + envelope match { + case Some(request) => + request.releaseBuffer() + case None => + if (buffer != null) { + memoryPool.release(buffer) + buffer = null + } + } + } + + override def toString = s"Request(processor=$processor, " + + s"connectionId=${context.connectionId}, " + + s"session=$session, " + + s"listenerName=${context.listenerName}, " + + s"securityProtocol=${context.securityProtocol}, " + + s"buffer=$buffer, " + + s"envelope=$envelope)" + + } + + sealed abstract class Response(val request: Request) { + + def processor: Int = request.processor + + def responseLog: Option[JsonNode] = None + + def onComplete: Option[Send => Unit] = None + } + + /** responseLogValue should only be defined if request logging is enabled */ + class SendResponse(request: Request, + val responseSend: Send, + val responseLogValue: Option[JsonNode], + val onCompleteCallback: Option[Send => Unit]) extends Response(request) { + override def responseLog: Option[JsonNode] = responseLogValue + + override def onComplete: Option[Send => Unit] = onCompleteCallback + + override def toString: String = + s"Response(type=Send, request=$request, send=$responseSend, asString=$responseLogValue)" + } + + class NoOpResponse(request: Request) extends Response(request) { + override def toString: String = + s"Response(type=NoOp, request=$request)" + } + + class CloseConnectionResponse(request: Request) extends Response(request) { + override def toString: String = + s"Response(type=CloseConnection, request=$request)" + } + + class StartThrottlingResponse(request: Request) extends Response(request) { + override def toString: String = + s"Response(type=StartThrottling, request=$request)" + } + + class EndThrottlingResponse(request: Request) extends Response(request) { + override def toString: String = + s"Response(type=EndThrottling, request=$request)" + } +} + +class RequestChannel(val queueSize: Int, + val metricNamePrefix: String, + time: Time, + val metrics: RequestChannel.Metrics) extends KafkaMetricsGroup { + import RequestChannel._ + private val requestQueue = new ArrayBlockingQueue[BaseRequest](queueSize) + private val processors = new ConcurrentHashMap[Int, Processor]() + val requestQueueSizeMetricName = metricNamePrefix.concat(RequestQueueSizeMetric) + val responseQueueSizeMetricName = metricNamePrefix.concat(ResponseQueueSizeMetric) + + newGauge(requestQueueSizeMetricName, () => requestQueue.size) + + newGauge(responseQueueSizeMetricName, () => { + processors.values.asScala.foldLeft(0) {(total, processor) => + total + processor.responseQueueSize + } + }) + + def addProcessor(processor: Processor): Unit = { + if (processors.putIfAbsent(processor.id, processor) != null) + warn(s"Unexpected processor with processorId ${processor.id}") + + newGauge(responseQueueSizeMetricName, () => processor.responseQueueSize, + Map(ProcessorMetricTag -> processor.id.toString)) + } + + def removeProcessor(processorId: Int): Unit = { + processors.remove(processorId) + removeMetric(responseQueueSizeMetricName, 
Map(ProcessorMetricTag -> processorId.toString)) + } + + /** Send a request to be handled, potentially blocking until there is room in the queue for the request */ + def sendRequest(request: RequestChannel.Request): Unit = { + requestQueue.put(request) + } + + def closeConnection( + request: RequestChannel.Request, + errorCounts: java.util.Map[Errors, Integer] + ): Unit = { + // This case is used when the request handler has encountered an error, but the client + // does not expect a response (e.g. when produce request has acks set to 0) + updateErrorMetrics(request.header.apiKey, errorCounts.asScala) + sendResponse(new RequestChannel.CloseConnectionResponse(request)) + } + + def sendResponse( + request: RequestChannel.Request, + response: AbstractResponse, + onComplete: Option[Send => Unit] + ): Unit = { + updateErrorMetrics(request.header.apiKey, response.errorCounts.asScala) + sendResponse(new RequestChannel.SendResponse( + request, + request.buildResponseSend(response), + request.responseNode(response), + onComplete + )) + } + + def sendNoOpResponse(request: RequestChannel.Request): Unit = { + sendResponse(new network.RequestChannel.NoOpResponse(request)) + } + + def startThrottling(request: RequestChannel.Request): Unit = { + sendResponse(new RequestChannel.StartThrottlingResponse(request)) + } + + def endThrottling(request: RequestChannel.Request): Unit = { + sendResponse(new EndThrottlingResponse(request)) + } + + /** Send a response back to the socket server to be sent over the network */ + private[network] def sendResponse(response: RequestChannel.Response): Unit = { + if (isTraceEnabled) { + val requestHeader = response.request.headerForLoggingOrThrottling() + val message = response match { + case sendResponse: SendResponse => + s"Sending ${requestHeader.apiKey} response to client ${requestHeader.clientId} of ${sendResponse.responseSend.size} bytes." + case _: NoOpResponse => + s"Not sending ${requestHeader.apiKey} response to client ${requestHeader.clientId} as it's not required." + case _: CloseConnectionResponse => + s"Closing connection for client ${requestHeader.clientId} due to error during ${requestHeader.apiKey}." + case _: StartThrottlingResponse => + s"Notifying channel throttling has started for client ${requestHeader.clientId} for ${requestHeader.apiKey}" + case _: EndThrottlingResponse => + s"Notifying channel throttling has ended for client ${requestHeader.clientId} for ${requestHeader.apiKey}" + } + trace(message) + } + + response match { + // We should only send one of the following per request + case _: SendResponse | _: NoOpResponse | _: CloseConnectionResponse => + val request = response.request + val timeNanos = time.nanoseconds() + request.responseCompleteTimeNanos = timeNanos + if (request.apiLocalCompleteTimeNanos == -1L) + request.apiLocalCompleteTimeNanos = timeNanos + // For a given request, these may happen in addition to one in the previous section, skip updating the metrics + case _: StartThrottlingResponse | _: EndThrottlingResponse => () + } + + val processor = processors.get(response.processor) + // The processor may be null if it was shutdown. In this case, the connections + // are closed, so the response is dropped. 
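+ // Illustrative sketch of a typical round trip through this channel (handler-side names are hypothetical;
+ // a real handler must also deal with ShutdownRequest and null poll results):
+ //   requestChannel.receiveRequest(300) match {
+ //     case request: RequestChannel.Request =>
+ //       // ... handle it and build `apiResponse` ...
+ //       requestChannel.sendResponse(request, apiResponse, None)  // wrapped in SendResponse and enqueued here
+ //     case _ => // shutdown or timeout
+ //   }
+ // sendNoOpResponse is used instead when the client expects no reply (e.g. produce with acks=0), and
+ // closeConnection when such a request fails and dropping the connection is the only way to signal the error.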
+ if (processor != null) { + processor.enqueueResponse(response) + } + } + + /** Get the next request or block until specified time has elapsed */ + def receiveRequest(timeout: Long): RequestChannel.BaseRequest = + requestQueue.poll(timeout, TimeUnit.MILLISECONDS) + + /** Get the next request or block until there is one */ + def receiveRequest(): RequestChannel.BaseRequest = + requestQueue.take() + + def updateErrorMetrics(apiKey: ApiKeys, errors: collection.Map[Errors, Integer]): Unit = { + errors.forKeyValue { (error, count) => + metrics(apiKey.name).markErrorMeter(error, count) + } + } + + def clear(): Unit = { + requestQueue.clear() + } + + def shutdown(): Unit = { + clear() + metrics.close() + } + + def sendShutdownRequest(): Unit = requestQueue.put(ShutdownRequest) + +} + +object RequestMetrics { + val consumerFetchMetricName = ApiKeys.FETCH.name + "Consumer" + val followFetchMetricName = ApiKeys.FETCH.name + "Follower" + + val RequestsPerSec = "RequestsPerSec" + val RequestQueueTimeMs = "RequestQueueTimeMs" + val LocalTimeMs = "LocalTimeMs" + val RemoteTimeMs = "RemoteTimeMs" + val ThrottleTimeMs = "ThrottleTimeMs" + val ResponseQueueTimeMs = "ResponseQueueTimeMs" + val ResponseSendTimeMs = "ResponseSendTimeMs" + val TotalTimeMs = "TotalTimeMs" + val RequestBytes = "RequestBytes" + val MessageConversionsTimeMs = "MessageConversionsTimeMs" + val TemporaryMemoryBytes = "TemporaryMemoryBytes" + val ErrorsPerSec = "ErrorsPerSec" +} + +class RequestMetrics(name: String) extends KafkaMetricsGroup { + + import RequestMetrics._ + + val tags = Map("request" -> name) + val requestRateInternal = new Pool[Short, Meter]() + // time a request spent in a request queue + val requestQueueTimeHist = newHistogram(RequestQueueTimeMs, biased = true, tags) + // time a request takes to be processed at the local broker + val localTimeHist = newHistogram(LocalTimeMs, biased = true, tags) + // time a request takes to wait on remote brokers (currently only relevant to fetch and produce requests) + val remoteTimeHist = newHistogram(RemoteTimeMs, biased = true, tags) + // time a request is throttled, not part of the request processing time (throttling is done at the client level + // for clients that support KIP-219 and by muting the channel for the rest) + val throttleTimeHist = newHistogram(ThrottleTimeMs, biased = true, tags) + // time a response spent in a response queue + val responseQueueTimeHist = newHistogram(ResponseQueueTimeMs, biased = true, tags) + // time to send the response to the requester + val responseSendTimeHist = newHistogram(ResponseSendTimeMs, biased = true, tags) + val totalTimeHist = newHistogram(TotalTimeMs, biased = true, tags) + // request size in bytes + val requestBytesHist = newHistogram(RequestBytes, biased = true, tags) + // time for message conversions (only relevant to fetch and produce requests) + val messageConversionsTimeHist = + if (name == ApiKeys.FETCH.name || name == ApiKeys.PRODUCE.name) + Some(newHistogram(MessageConversionsTimeMs, biased = true, tags)) + else + None + // Temporary memory allocated for processing request (only populated for fetch and produce requests) + // This shows the memory allocated for compression/conversions excluding the actual request size + val tempMemoryBytesHist = + if (name == ApiKeys.FETCH.name || name == ApiKeys.PRODUCE.name) + Some(newHistogram(TemporaryMemoryBytes, biased = true, tags)) + else + None + + private val errorMeters = mutable.Map[Errors, ErrorMeter]() + Errors.values.foreach(error => errorMeters.put(error, new 
ErrorMeter(name, error))) + + def requestRate(version: Short): Meter = { + requestRateInternal.getAndMaybePut(version, newMeter("RequestsPerSec", "requests", TimeUnit.SECONDS, tags + ("version" -> version.toString))) + } + + class ErrorMeter(name: String, error: Errors) { + private val tags = Map("request" -> name, "error" -> error.name) + + @volatile private var meter: Meter = null + + def getOrCreateMeter(): Meter = { + if (meter != null) + meter + else { + synchronized { + if (meter == null) + meter = newMeter(ErrorsPerSec, "requests", TimeUnit.SECONDS, tags) + meter + } + } + } + + def removeMeter(): Unit = { + synchronized { + if (meter != null) { + removeMetric(ErrorsPerSec, tags) + meter = null + } + } + } + } + + def markErrorMeter(error: Errors, count: Int): Unit = { + errorMeters(error).getOrCreateMeter().mark(count.toLong) + } + + def removeMetrics(): Unit = { + for (version <- requestRateInternal.keys) removeMetric(RequestsPerSec, tags + ("version" -> version.toString)) + removeMetric(RequestQueueTimeMs, tags) + removeMetric(LocalTimeMs, tags) + removeMetric(RemoteTimeMs, tags) + removeMetric(RequestsPerSec, tags) + removeMetric(ThrottleTimeMs, tags) + removeMetric(ResponseQueueTimeMs, tags) + removeMetric(TotalTimeMs, tags) + removeMetric(ResponseSendTimeMs, tags) + removeMetric(RequestBytes, tags) + removeMetric(ResponseSendTimeMs, tags) + if (name == ApiKeys.FETCH.name || name == ApiKeys.PRODUCE.name) { + removeMetric(MessageConversionsTimeMs, tags) + removeMetric(TemporaryMemoryBytes, tags) + } + errorMeters.values.foreach(_.removeMeter()) + errorMeters.clear() + } +} diff --git a/core/src/main/scala/kafka/network/RequestConvertToJson.scala b/core/src/main/scala/kafka/network/RequestConvertToJson.scala new file mode 100644 index 0000000..bb8e327 --- /dev/null +++ b/core/src/main/scala/kafka/network/RequestConvertToJson.scala @@ -0,0 +1,225 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package kafka.network + +import com.fasterxml.jackson.databind.JsonNode +import com.fasterxml.jackson.databind.node.{BooleanNode, DoubleNode, JsonNodeFactory, LongNode, ObjectNode, TextNode} +import kafka.network.RequestChannel.Session +import org.apache.kafka.common.message._ +import org.apache.kafka.common.network.ClientInformation +import org.apache.kafka.common.requests._ + +object RequestConvertToJson { + def request(request: AbstractRequest): JsonNode = { + request match { + case req: AddOffsetsToTxnRequest => AddOffsetsToTxnRequestDataJsonConverter.write(req.data, request.version) + case req: AddPartitionsToTxnRequest => AddPartitionsToTxnRequestDataJsonConverter.write(req.data, request.version) + case req: AllocateProducerIdsRequest => AllocateProducerIdsRequestDataJsonConverter.write(req.data, request.version) + case req: AlterClientQuotasRequest => AlterClientQuotasRequestDataJsonConverter.write(req.data, request.version) + case req: AlterConfigsRequest => AlterConfigsRequestDataJsonConverter.write(req.data, request.version) + case req: AlterIsrRequest => AlterIsrRequestDataJsonConverter.write(req.data, request.version) + case req: AlterPartitionReassignmentsRequest => AlterPartitionReassignmentsRequestDataJsonConverter.write(req.data, request.version) + case req: AlterReplicaLogDirsRequest => AlterReplicaLogDirsRequestDataJsonConverter.write(req.data, request.version) + case res: AlterUserScramCredentialsRequest => AlterUserScramCredentialsRequestDataJsonConverter.write(res.data, request.version) + case req: ApiVersionsRequest => ApiVersionsRequestDataJsonConverter.write(req.data, request.version) + case req: BeginQuorumEpochRequest => BeginQuorumEpochRequestDataJsonConverter.write(req.data, request.version) + case req: BrokerHeartbeatRequest => BrokerHeartbeatRequestDataJsonConverter.write(req.data, request.version) + case req: BrokerRegistrationRequest => BrokerRegistrationRequestDataJsonConverter.write(req.data, request.version) + case req: ControlledShutdownRequest => ControlledShutdownRequestDataJsonConverter.write(req.data, request.version) + case req: CreateAclsRequest => CreateAclsRequestDataJsonConverter.write(req.data, request.version) + case req: CreateDelegationTokenRequest => CreateDelegationTokenRequestDataJsonConverter.write(req.data, request.version) + case req: CreatePartitionsRequest => CreatePartitionsRequestDataJsonConverter.write(req.data, request.version) + case req: CreateTopicsRequest => CreateTopicsRequestDataJsonConverter.write(req.data, request.version) + case req: DeleteAclsRequest => DeleteAclsRequestDataJsonConverter.write(req.data, request.version) + case req: DeleteGroupsRequest => DeleteGroupsRequestDataJsonConverter.write(req.data, request.version) + case req: DeleteRecordsRequest => DeleteRecordsRequestDataJsonConverter.write(req.data, request.version) + case req: DeleteTopicsRequest => DeleteTopicsRequestDataJsonConverter.write(req.data, request.version) + case req: DescribeAclsRequest => DescribeAclsRequestDataJsonConverter.write(req.data, request.version) + case req: DescribeClientQuotasRequest => DescribeClientQuotasRequestDataJsonConverter.write(req.data, request.version) + case req: DescribeConfigsRequest => DescribeConfigsRequestDataJsonConverter.write(req.data, request.version) + case req: DescribeDelegationTokenRequest => DescribeDelegationTokenRequestDataJsonConverter.write(req.data, request.version) + case req: DescribeGroupsRequest => DescribeGroupsRequestDataJsonConverter.write(req.data, request.version) + case req: 
DescribeLogDirsRequest => DescribeLogDirsRequestDataJsonConverter.write(req.data, request.version) + case req: DescribeQuorumRequest => DescribeQuorumRequestDataJsonConverter.write(req.data, request.version) + case res: DescribeUserScramCredentialsRequest => DescribeUserScramCredentialsRequestDataJsonConverter.write(res.data, request.version) + case req: ElectLeadersRequest => ElectLeadersRequestDataJsonConverter.write(req.data, request.version) + case req: EndTxnRequest => EndTxnRequestDataJsonConverter.write(req.data, request.version) + case req: EndQuorumEpochRequest => EndQuorumEpochRequestDataJsonConverter.write(req.data, request.version) + case req: EnvelopeRequest => EnvelopeRequestDataJsonConverter.write(req.data, request.version) + case req: ExpireDelegationTokenRequest => ExpireDelegationTokenRequestDataJsonConverter.write(req.data, request.version) + case req: FetchRequest => FetchRequestDataJsonConverter.write(req.data, request.version) + case req: FindCoordinatorRequest => FindCoordinatorRequestDataJsonConverter.write(req.data, request.version) + case req: HeartbeatRequest => HeartbeatRequestDataJsonConverter.write(req.data, request.version) + case req: IncrementalAlterConfigsRequest => IncrementalAlterConfigsRequestDataJsonConverter.write(req.data, request.version) + case req: InitProducerIdRequest => InitProducerIdRequestDataJsonConverter.write(req.data, request.version) + case req: JoinGroupRequest => JoinGroupRequestDataJsonConverter.write(req.data, request.version) + case req: LeaderAndIsrRequest => LeaderAndIsrRequestDataJsonConverter.write(req.data, request.version) + case req: LeaveGroupRequest => LeaveGroupRequestDataJsonConverter.write(req.data, request.version) + case req: ListGroupsRequest => ListGroupsRequestDataJsonConverter.write(req.data, request.version) + case req: ListOffsetsRequest => ListOffsetsRequestDataJsonConverter.write(req.data, request.version) + case req: ListPartitionReassignmentsRequest => ListPartitionReassignmentsRequestDataJsonConverter.write(req.data, request.version) + case req: MetadataRequest => MetadataRequestDataJsonConverter.write(req.data, request.version) + case req: OffsetCommitRequest => OffsetCommitRequestDataJsonConverter.write(req.data, request.version) + case req: OffsetDeleteRequest => OffsetDeleteRequestDataJsonConverter.write(req.data, request.version) + case req: OffsetFetchRequest => OffsetFetchRequestDataJsonConverter.write(req.data, request.version) + case req: OffsetsForLeaderEpochRequest => OffsetForLeaderEpochRequestDataJsonConverter.write(req.data, request.version) + case req: ProduceRequest => ProduceRequestDataJsonConverter.write(req.data, request.version, false) + case req: RenewDelegationTokenRequest => RenewDelegationTokenRequestDataJsonConverter.write(req.data, request.version) + case req: SaslAuthenticateRequest => SaslAuthenticateRequestDataJsonConverter.write(req.data, request.version) + case req: SaslHandshakeRequest => SaslHandshakeRequestDataJsonConverter.write(req.data, request.version) + case req: StopReplicaRequest => StopReplicaRequestDataJsonConverter.write(req.data, request.version) + case req: SyncGroupRequest => SyncGroupRequestDataJsonConverter.write(req.data, request.version) + case req: TxnOffsetCommitRequest => TxnOffsetCommitRequestDataJsonConverter.write(req.data, request.version) + case req: UnregisterBrokerRequest => UnregisterBrokerRequestDataJsonConverter.write(req.data, request.version) + case req: UpdateFeaturesRequest => UpdateFeaturesRequestDataJsonConverter.write(req.data, 
request.version) + case req: UpdateMetadataRequest => UpdateMetadataRequestDataJsonConverter.write(req.data, request.version) + case req: VoteRequest => VoteRequestDataJsonConverter.write(req.data, request.version) + case req: WriteTxnMarkersRequest => WriteTxnMarkersRequestDataJsonConverter.write(req.data, request.version) + case req: FetchSnapshotRequest => FetchSnapshotRequestDataJsonConverter.write(req.data, request.version) + case req: DescribeClusterRequest => DescribeClusterRequestDataJsonConverter.write(req.data, request.version) + case req: DescribeProducersRequest => DescribeProducersRequestDataJsonConverter.write(req.data, request.version) + case req: DescribeTransactionsRequest => DescribeTransactionsRequestDataJsonConverter.write(req.data, request.version) + case req: ListTransactionsRequest => ListTransactionsRequestDataJsonConverter.write(req.data, request.version) + case _ => throw new IllegalStateException(s"ApiKey ${request.apiKey} is not currently handled in `request`, the " + + "code should be updated to do so."); + } + } + + def response(response: AbstractResponse, version: Short): JsonNode = { + response match { + case res: AddOffsetsToTxnResponse => AddOffsetsToTxnResponseDataJsonConverter.write(res.data, version) + case res: AddPartitionsToTxnResponse => AddPartitionsToTxnResponseDataJsonConverter.write(res.data, version) + case res: AllocateProducerIdsResponse => AllocateProducerIdsResponseDataJsonConverter.write(res.data, version) + case res: AlterClientQuotasResponse => AlterClientQuotasResponseDataJsonConverter.write(res.data, version) + case res: AlterConfigsResponse => AlterConfigsResponseDataJsonConverter.write(res.data, version) + case res: AlterIsrResponse => AlterIsrResponseDataJsonConverter.write(res.data, version) + case res: AlterPartitionReassignmentsResponse => AlterPartitionReassignmentsResponseDataJsonConverter.write(res.data, version) + case res: AlterReplicaLogDirsResponse => AlterReplicaLogDirsResponseDataJsonConverter.write(res.data, version) + case res: AlterUserScramCredentialsResponse => AlterUserScramCredentialsResponseDataJsonConverter.write(res.data, version) + case res: ApiVersionsResponse => ApiVersionsResponseDataJsonConverter.write(res.data, version) + case res: BeginQuorumEpochResponse => BeginQuorumEpochResponseDataJsonConverter.write(res.data, version) + case res: BrokerHeartbeatResponse => BrokerHeartbeatResponseDataJsonConverter.write(res.data, version) + case res: BrokerRegistrationResponse => BrokerRegistrationResponseDataJsonConverter.write(res.data, version) + case res: ControlledShutdownResponse => ControlledShutdownResponseDataJsonConverter.write(res.data, version) + case res: CreateAclsResponse => CreateAclsResponseDataJsonConverter.write(res.data, version) + case res: CreateDelegationTokenResponse => CreateDelegationTokenResponseDataJsonConverter.write(res.data, version) + case res: CreatePartitionsResponse => CreatePartitionsResponseDataJsonConverter.write(res.data, version) + case res: CreateTopicsResponse => CreateTopicsResponseDataJsonConverter.write(res.data, version) + case res: DeleteAclsResponse => DeleteAclsResponseDataJsonConverter.write(res.data, version) + case res: DeleteGroupsResponse => DeleteGroupsResponseDataJsonConverter.write(res.data, version) + case res: DeleteRecordsResponse => DeleteRecordsResponseDataJsonConverter.write(res.data, version) + case res: DeleteTopicsResponse => DeleteTopicsResponseDataJsonConverter.write(res.data, version) + case res: DescribeAclsResponse => 
DescribeAclsResponseDataJsonConverter.write(res.data, version) + case res: DescribeClientQuotasResponse => DescribeClientQuotasResponseDataJsonConverter.write(res.data, version) + case res: DescribeConfigsResponse => DescribeConfigsResponseDataJsonConverter.write(res.data, version) + case res: DescribeDelegationTokenResponse => DescribeDelegationTokenResponseDataJsonConverter.write(res.data, version) + case res: DescribeGroupsResponse => DescribeGroupsResponseDataJsonConverter.write(res.data, version) + case res: DescribeLogDirsResponse => DescribeLogDirsResponseDataJsonConverter.write(res.data, version) + case res: DescribeQuorumResponse => DescribeQuorumResponseDataJsonConverter.write(res.data, version) + case res: DescribeUserScramCredentialsResponse => DescribeUserScramCredentialsResponseDataJsonConverter.write(res.data, version) + case res: ElectLeadersResponse => ElectLeadersResponseDataJsonConverter.write(res.data, version) + case res: EndTxnResponse => EndTxnResponseDataJsonConverter.write(res.data, version) + case res: EndQuorumEpochResponse => EndQuorumEpochResponseDataJsonConverter.write(res.data, version) + case res: EnvelopeResponse => EnvelopeResponseDataJsonConverter.write(res.data, version) + case res: ExpireDelegationTokenResponse => ExpireDelegationTokenResponseDataJsonConverter.write(res.data, version) + case res: FetchResponse => FetchResponseDataJsonConverter.write(res.data, version, false) + case res: FindCoordinatorResponse => FindCoordinatorResponseDataJsonConverter.write(res.data, version) + case res: HeartbeatResponse => HeartbeatResponseDataJsonConverter.write(res.data, version) + case res: IncrementalAlterConfigsResponse => IncrementalAlterConfigsResponseDataJsonConverter.write(res.data, version) + case res: InitProducerIdResponse => InitProducerIdResponseDataJsonConverter.write(res.data, version) + case res: JoinGroupResponse => JoinGroupResponseDataJsonConverter.write(res.data, version) + case res: LeaderAndIsrResponse => LeaderAndIsrResponseDataJsonConverter.write(res.data, version) + case res: LeaveGroupResponse => LeaveGroupResponseDataJsonConverter.write(res.data, version) + case res: ListGroupsResponse => ListGroupsResponseDataJsonConverter.write(res.data, version) + case res: ListOffsetsResponse => ListOffsetsResponseDataJsonConverter.write(res.data, version) + case res: ListPartitionReassignmentsResponse => ListPartitionReassignmentsResponseDataJsonConverter.write(res.data, version) + case res: MetadataResponse => MetadataResponseDataJsonConverter.write(res.data, version) + case res: OffsetCommitResponse => OffsetCommitResponseDataJsonConverter.write(res.data, version) + case res: OffsetDeleteResponse => OffsetDeleteResponseDataJsonConverter.write(res.data, version) + case res: OffsetFetchResponse => OffsetFetchResponseDataJsonConverter.write(res.data, version) + case res: OffsetsForLeaderEpochResponse => OffsetForLeaderEpochResponseDataJsonConverter.write(res.data, version) + case res: ProduceResponse => ProduceResponseDataJsonConverter.write(res.data, version) + case res: RenewDelegationTokenResponse => RenewDelegationTokenResponseDataJsonConverter.write(res.data, version) + case res: SaslAuthenticateResponse => SaslAuthenticateResponseDataJsonConverter.write(res.data, version) + case res: SaslHandshakeResponse => SaslHandshakeResponseDataJsonConverter.write(res.data, version) + case res: StopReplicaResponse => StopReplicaResponseDataJsonConverter.write(res.data, version) + case res: SyncGroupResponse => 
SyncGroupResponseDataJsonConverter.write(res.data, version) + case res: TxnOffsetCommitResponse => TxnOffsetCommitResponseDataJsonConverter.write(res.data, version) + case res: UnregisterBrokerResponse => UnregisterBrokerResponseDataJsonConverter.write(res.data, version) + case res: UpdateFeaturesResponse => UpdateFeaturesResponseDataJsonConverter.write(res.data, version) + case res: UpdateMetadataResponse => UpdateMetadataResponseDataJsonConverter.write(res.data, version) + case res: WriteTxnMarkersResponse => WriteTxnMarkersResponseDataJsonConverter.write(res.data, version) + case res: VoteResponse => VoteResponseDataJsonConverter.write(res.data, version) + case res: FetchSnapshotResponse => FetchSnapshotResponseDataJsonConverter.write(res.data, version) + case res: DescribeClusterResponse => DescribeClusterResponseDataJsonConverter.write(res.data, version) + case res: DescribeProducersResponse => DescribeProducersResponseDataJsonConverter.write(res.data, version) + case res: DescribeTransactionsResponse => DescribeTransactionsResponseDataJsonConverter.write(res.data, version) + case res: ListTransactionsResponse => ListTransactionsResponseDataJsonConverter.write(res.data, version) + case _ => throw new IllegalStateException(s"ApiKey ${response.apiKey} is not currently handled in `response`, the " + + "code should be updated to do so."); + } + } + + def requestHeaderNode(header: RequestHeader): JsonNode = { + val node = RequestHeaderDataJsonConverter.write(header.data, header.headerVersion, false).asInstanceOf[ObjectNode] + node.set("requestApiKeyName", new TextNode(header.apiKey.toString)) + node + } + + def requestDesc(header: RequestHeader, requestNode: Option[JsonNode], isForwarded: Boolean): JsonNode = { + val node = new ObjectNode(JsonNodeFactory.instance) + node.set("isForwarded", if (isForwarded) BooleanNode.TRUE else BooleanNode.FALSE) + node.set("requestHeader", requestHeaderNode(header)) + node.set("request", requestNode.getOrElse(new TextNode(""))) + node + } + + def clientInfoNode(clientInfo: ClientInformation): JsonNode = { + val node = new ObjectNode(JsonNodeFactory.instance) + node.set("softwareName", new TextNode(clientInfo.softwareName)) + node.set("softwareVersion", new TextNode(clientInfo.softwareVersion)) + node + } + + def requestDescMetrics(header: RequestHeader, requestNode: Option[JsonNode], responseNode: Option[JsonNode], + context: RequestContext, session: Session, isForwarded: Boolean, + totalTimeMs: Double, requestQueueTimeMs: Double, apiLocalTimeMs: Double, + apiRemoteTimeMs: Double, apiThrottleTimeMs: Long, responseQueueTimeMs: Double, + responseSendTimeMs: Double, temporaryMemoryBytes: Long, + messageConversionsTimeMs: Double): JsonNode = { + val node = requestDesc(header, requestNode, isForwarded).asInstanceOf[ObjectNode] + node.set("response", responseNode.getOrElse(new TextNode(""))) + node.set("connection", new TextNode(context.connectionId)) + node.set("totalTimeMs", new DoubleNode(totalTimeMs)) + node.set("requestQueueTimeMs", new DoubleNode(requestQueueTimeMs)) + node.set("localTimeMs", new DoubleNode(apiLocalTimeMs)) + node.set("remoteTimeMs", new DoubleNode(apiRemoteTimeMs)) + node.set("throttleTimeMs", new LongNode(apiThrottleTimeMs)) + node.set("responseQueueTimeMs", new DoubleNode(responseQueueTimeMs)) + node.set("sendTimeMs", new DoubleNode(responseSendTimeMs)) + node.set("securityProtocol", new TextNode(context.securityProtocol.toString)) + node.set("principal", new TextNode(session.principal.toString)) + node.set("listener", new 
TextNode(context.listenerName.value)) + node.set("clientInformation", clientInfoNode(context.clientInformation)) + if (temporaryMemoryBytes > 0) + node.set("temporaryMemoryBytes", new LongNode(temporaryMemoryBytes)) + if (messageConversionsTimeMs > 0) + node.set("messageConversionsTime", new DoubleNode(messageConversionsTimeMs)) + node + } +} diff --git a/core/src/main/scala/kafka/network/SocketServer.scala b/core/src/main/scala/kafka/network/SocketServer.scala new file mode 100644 index 0000000..a4a990c --- /dev/null +++ b/core/src/main/scala/kafka/network/SocketServer.scala @@ -0,0 +1,1717 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.network + +import java.io.IOException +import java.net._ +import java.nio.ByteBuffer +import java.nio.channels.{Selector => NSelector, _} +import java.util +import java.util.Optional +import java.util.concurrent._ +import java.util.concurrent.atomic._ + +import kafka.cluster.{BrokerEndPoint, EndPoint} +import kafka.metrics.KafkaMetricsGroup +import kafka.network.ConnectionQuotas._ +import kafka.network.Processor._ +import kafka.network.RequestChannel.{CloseConnectionResponse, EndThrottlingResponse, NoOpResponse, SendResponse, StartThrottlingResponse} +import kafka.network.SocketServer._ +import kafka.security.CredentialProvider +import kafka.server.{ApiVersionManager, BrokerReconfigurable, KafkaConfig} +import kafka.utils.Implicits._ +import kafka.utils._ +import org.apache.kafka.common.config.ConfigException +import org.apache.kafka.common.config.internals.QuotaConfigs +import org.apache.kafka.common.errors.InvalidRequestException +import org.apache.kafka.common.memory.{MemoryPool, SimpleMemoryPool} +import org.apache.kafka.common.metrics._ +import org.apache.kafka.common.metrics.stats.{Avg, CumulativeSum, Meter, Rate} +import org.apache.kafka.common.network.KafkaChannel.ChannelMuteEvent +import org.apache.kafka.common.network.{ChannelBuilder, ChannelBuilders, ClientInformation, KafkaChannel, ListenerName, ListenerReconfigurable, NetworkSend, Selectable, Send, Selector => KSelector} +import org.apache.kafka.common.protocol.ApiKeys +import org.apache.kafka.common.requests.{ApiVersionsRequest, RequestContext, RequestHeader} +import org.apache.kafka.common.security.auth.SecurityProtocol +import org.apache.kafka.common.utils.{KafkaThread, LogContext, Time, Utils} +import org.apache.kafka.common.{Endpoint, KafkaException, MetricName, Reconfigurable} +import org.slf4j.event.Level + +import scala.collection._ +import scala.collection.mutable.{ArrayBuffer, Buffer} +import scala.jdk.CollectionConverters._ +import scala.util.control.ControlThrowable + +/** + * Handles new connections, requests and responses to and from broker. 
+ * Kafka supports two types of request planes : + * - data-plane : + * - Handles requests from clients and other brokers in the cluster. + * - The threading model is + * 1 Acceptor thread per listener, that handles new connections. + * It is possible to configure multiple data-planes by specifying multiple "," separated endpoints for "listeners" in KafkaConfig. + * Acceptor has N Processor threads that each have their own selector and read requests from sockets + * M Handler threads that handle requests and produce responses back to the processor threads for writing. + * - control-plane : + * - Handles requests from controller. This is optional and can be configured by specifying "control.plane.listener.name". + * If not configured, the controller requests are handled by the data-plane. + * - The threading model is + * 1 Acceptor thread that handles new connections + * Acceptor has 1 Processor thread that has its own selector and read requests from the socket. + * 1 Handler thread that handles requests and produces responses back to the processor thread for writing. + */ +class SocketServer(val config: KafkaConfig, + val metrics: Metrics, + val time: Time, + val credentialProvider: CredentialProvider, + val apiVersionManager: ApiVersionManager) + extends Logging with KafkaMetricsGroup with BrokerReconfigurable { + + private val maxQueuedRequests = config.queuedMaxRequests + + private val nodeId = config.brokerId + + private val logContext = new LogContext(s"[SocketServer listenerType=${apiVersionManager.listenerType}, nodeId=$nodeId] ") + + this.logIdent = logContext.logPrefix + + private val memoryPoolSensor = metrics.sensor("MemoryPoolUtilization") + private val memoryPoolDepletedPercentMetricName = metrics.metricName("MemoryPoolAvgDepletedPercent", MetricsGroup) + private val memoryPoolDepletedTimeMetricName = metrics.metricName("MemoryPoolDepletedTimeTotal", MetricsGroup) + memoryPoolSensor.add(new Meter(TimeUnit.MILLISECONDS, memoryPoolDepletedPercentMetricName, memoryPoolDepletedTimeMetricName)) + private val memoryPool = if (config.queuedMaxBytes > 0) new SimpleMemoryPool(config.queuedMaxBytes, config.socketRequestMaxBytes, false, memoryPoolSensor) else MemoryPool.NONE + // data-plane + private val dataPlaneProcessors = new ConcurrentHashMap[Int, Processor]() + private[network] val dataPlaneAcceptors = new ConcurrentHashMap[EndPoint, Acceptor]() + val dataPlaneRequestChannel = new RequestChannel(maxQueuedRequests, DataPlaneMetricPrefix, time, apiVersionManager.newRequestMetrics) + // control-plane + private var controlPlaneProcessorOpt : Option[Processor] = None + private[network] var controlPlaneAcceptorOpt : Option[Acceptor] = None + val controlPlaneRequestChannelOpt: Option[RequestChannel] = config.controlPlaneListenerName.map(_ => + new RequestChannel(20, ControlPlaneMetricPrefix, time, apiVersionManager.newRequestMetrics)) + + private var nextProcessorId = 0 + val connectionQuotas = new ConnectionQuotas(config, time, metrics) + private var startedProcessingRequests = false + private var stoppedProcessingRequests = false + + /** + * Starts the socket server and creates all the Acceptors and the Processors. The Acceptors + * start listening at this stage so that the bound port is known when this method completes + * even when ephemeral ports are used. Acceptors and Processors are started if `startProcessingRequests` + * is true. If not, acceptors and processors are only started when [[kafka.network.SocketServer#startProcessingRequests()]] + * is invoked. 
Delayed starting of acceptors and processors is used to delay processing client + * connections until server is fully initialized, e.g. to ensure that all credentials have been + * loaded before authentications are performed. Incoming connections on this server are processed + * when processors start up and invoke [[org.apache.kafka.common.network.Selector#poll]]. + * + * @param startProcessingRequests Flag indicating whether `Processor`s must be started. + * @param controlPlaneListener The control plane listener, or None if there is none. + * @param dataPlaneListeners The data plane listeners. + */ + def startup(startProcessingRequests: Boolean = true, + controlPlaneListener: Option[EndPoint] = config.controlPlaneListener, + dataPlaneListeners: Seq[EndPoint] = config.dataPlaneListeners): Unit = { + this.synchronized { + createControlPlaneAcceptorAndProcessor(controlPlaneListener) + createDataPlaneAcceptorsAndProcessors(config.numNetworkThreads, dataPlaneListeners) + if (startProcessingRequests) { + this.startProcessingRequests() + } + } + + newGauge(s"${DataPlaneMetricPrefix}NetworkProcessorAvgIdlePercent", () => SocketServer.this.synchronized { + val ioWaitRatioMetricNames = dataPlaneProcessors.values.asScala.iterator.map { p => + metrics.metricName("io-wait-ratio", MetricsGroup, p.metricTags) + } + ioWaitRatioMetricNames.map { metricName => + Option(metrics.metric(metricName)).fold(0.0)(m => Math.min(m.metricValue.asInstanceOf[Double], 1.0)) + }.sum / dataPlaneProcessors.size + }) + newGauge(s"${ControlPlaneMetricPrefix}NetworkProcessorAvgIdlePercent", () => SocketServer.this.synchronized { + val ioWaitRatioMetricName = controlPlaneProcessorOpt.map { p => + metrics.metricName("io-wait-ratio", MetricsGroup, p.metricTags) + } + ioWaitRatioMetricName.map { metricName => + Option(metrics.metric(metricName)).fold(0.0)(m => Math.min(m.metricValue.asInstanceOf[Double], 1.0)) + }.getOrElse(Double.NaN) + }) + newGauge("MemoryPoolAvailable", () => memoryPool.availableMemory) + newGauge("MemoryPoolUsed", () => memoryPool.size() - memoryPool.availableMemory) + newGauge(s"${DataPlaneMetricPrefix}ExpiredConnectionsKilledCount", () => SocketServer.this.synchronized { + val expiredConnectionsKilledCountMetricNames = dataPlaneProcessors.values.asScala.iterator.map { p => + metrics.metricName("expired-connections-killed-count", MetricsGroup, p.metricTags) + } + expiredConnectionsKilledCountMetricNames.map { metricName => + Option(metrics.metric(metricName)).fold(0.0)(m => m.metricValue.asInstanceOf[Double]) + }.sum + }) + newGauge(s"${ControlPlaneMetricPrefix}ExpiredConnectionsKilledCount", () => SocketServer.this.synchronized { + val expiredConnectionsKilledCountMetricNames = controlPlaneProcessorOpt.map { p => + metrics.metricName("expired-connections-killed-count", MetricsGroup, p.metricTags) + } + expiredConnectionsKilledCountMetricNames.map { metricName => + Option(metrics.metric(metricName)).fold(0.0)(m => m.metricValue.asInstanceOf[Double]) + }.getOrElse(0.0) + }) + } + + /** + * Start processing requests and new connections. This method is used for delayed starting of + * all the acceptors and processors if [[kafka.network.SocketServer#startup]] was invoked with + * `startProcessingRequests=false`. + * + * Before starting processors for each endpoint, we ensure that authorizer has all the metadata + * to authorize requests on that endpoint by waiting on the provided future. We start inter-broker + * listener before other listeners. 
This allows authorization metadata for other listeners to be + * stored in Kafka topics in this cluster. + * + * @param authorizerFutures Future per [[EndPoint]] used to wait before starting the processor + * corresponding to the [[EndPoint]] + */ + def startProcessingRequests(authorizerFutures: Map[Endpoint, CompletableFuture[Void]] = Map.empty): Unit = { + info("Starting socket server acceptors and processors") + this.synchronized { + if (!startedProcessingRequests) { + startControlPlaneProcessorAndAcceptor(authorizerFutures) + startDataPlaneProcessorsAndAcceptors(authorizerFutures) + startedProcessingRequests = true + } else { + info("Socket server acceptors and processors already started") + } + } + info("Started socket server acceptors and processors") + } + + /** + * Starts processors of the provided acceptor and the acceptor itself. + * + * Before starting them, we ensure that authorizer has all the metadata to authorize + * requests on that endpoint by waiting on the provided future. + */ + private def startAcceptorAndProcessors(threadPrefix: String, + endpoint: EndPoint, + acceptor: Acceptor, + authorizerFutures: Map[Endpoint, CompletableFuture[Void]] = Map.empty): Unit = { + debug(s"Wait for authorizer to complete start up on listener ${endpoint.listenerName}") + waitForAuthorizerFuture(acceptor, authorizerFutures) + debug(s"Start processors on listener ${endpoint.listenerName}") + acceptor.startProcessors(threadPrefix) + debug(s"Start acceptor thread on listener ${endpoint.listenerName}") + if (!acceptor.isStarted()) { + KafkaThread.nonDaemon( + s"${threadPrefix}-kafka-socket-acceptor-${endpoint.listenerName}-${endpoint.securityProtocol}-${endpoint.port}", + acceptor + ).start() + acceptor.awaitStartup() + } + info(s"Started $threadPrefix acceptor and processor(s) for endpoint : ${endpoint.listenerName}") + } + + /** + * Starts processors of all the data-plane acceptors and all the acceptors of this server. + * + * We start inter-broker listener before other listeners. This allows authorization metadata for + * other listeners to be stored in Kafka topics in this cluster. + */ + private def startDataPlaneProcessorsAndAcceptors(authorizerFutures: Map[Endpoint, CompletableFuture[Void]]): Unit = { + val interBrokerListener = dataPlaneAcceptors.asScala.keySet + .find(_.listenerName == config.interBrokerListenerName) + val orderedAcceptors = interBrokerListener match { + case Some(interBrokerListener) => List(dataPlaneAcceptors.get(interBrokerListener)) ++ + dataPlaneAcceptors.asScala.filter { case (k, _) => k != interBrokerListener }.values + case None => dataPlaneAcceptors.asScala.values + } + orderedAcceptors.foreach { acceptor => + val endpoint = acceptor.endPoint + startAcceptorAndProcessors(DataPlaneThreadPrefix, endpoint, acceptor, authorizerFutures) + } + } + + /** + * Start the processor of control-plane acceptor and the acceptor of this server. 
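+ *
+ * An illustrative broker configuration that enables a separate control plane (assumed values, for
+ * exposition only):
+ *   listeners=PLAINTEXT://:9092,CONTROLLER://:9094
+ *   listener.security.protocol.map=PLAINTEXT:PLAINTEXT,CONTROLLER:PLAINTEXT
+ *   control.plane.listener.name=CONTROLLER
+ * With such a config the controller traffic gets its own acceptor, single processor and request channel;
+ * without "control.plane.listener.name" controlPlaneAcceptorOpt is None and this method does nothing.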
+ */ + private def startControlPlaneProcessorAndAcceptor(authorizerFutures: Map[Endpoint, CompletableFuture[Void]]): Unit = { + controlPlaneAcceptorOpt.foreach { controlPlaneAcceptor => + val endpoint = config.controlPlaneListener.get + startAcceptorAndProcessors(ControlPlaneThreadPrefix, endpoint, controlPlaneAcceptor, authorizerFutures) + } + } + + private def endpoints = config.listeners.map(l => l.listenerName -> l).toMap + + private def createDataPlaneAcceptorsAndProcessors(dataProcessorsPerListener: Int, + endpoints: Seq[EndPoint]): Unit = { + endpoints.foreach { endpoint => + connectionQuotas.addListener(config, endpoint.listenerName) + val dataPlaneAcceptor = createAcceptor(endpoint, DataPlaneMetricPrefix) + addDataPlaneProcessors(dataPlaneAcceptor, endpoint, dataProcessorsPerListener) + dataPlaneAcceptors.put(endpoint, dataPlaneAcceptor) + info(s"Created data-plane acceptor and processors for endpoint : ${endpoint.listenerName}") + } + } + + private def createControlPlaneAcceptorAndProcessor(endpointOpt: Option[EndPoint]): Unit = { + endpointOpt.foreach { endpoint => + connectionQuotas.addListener(config, endpoint.listenerName) + val controlPlaneAcceptor = createAcceptor(endpoint, ControlPlaneMetricPrefix) + val controlPlaneProcessor = newProcessor(nextProcessorId, controlPlaneRequestChannelOpt.get, + connectionQuotas, endpoint.listenerName, endpoint.securityProtocol, memoryPool, isPrivilegedListener = true) + controlPlaneAcceptorOpt = Some(controlPlaneAcceptor) + controlPlaneProcessorOpt = Some(controlPlaneProcessor) + val listenerProcessors = new ArrayBuffer[Processor]() + listenerProcessors += controlPlaneProcessor + controlPlaneRequestChannelOpt.foreach(_.addProcessor(controlPlaneProcessor)) + nextProcessorId += 1 + controlPlaneAcceptor.addProcessors(listenerProcessors, ControlPlaneThreadPrefix) + info(s"Created control-plane acceptor and processor for endpoint : ${endpoint.listenerName}") + } + } + + private def createAcceptor(endPoint: EndPoint, metricPrefix: String) : Acceptor = { + val sendBufferSize = config.socketSendBufferBytes + val recvBufferSize = config.socketReceiveBufferBytes + new Acceptor(endPoint, sendBufferSize, recvBufferSize, nodeId, connectionQuotas, metricPrefix, time) + } + + private def addDataPlaneProcessors(acceptor: Acceptor, endpoint: EndPoint, newProcessorsPerListener: Int): Unit = { + val listenerName = endpoint.listenerName + val securityProtocol = endpoint.securityProtocol + val listenerProcessors = new ArrayBuffer[Processor]() + val isPrivilegedListener = controlPlaneRequestChannelOpt.isEmpty && config.interBrokerListenerName == listenerName + + for (_ <- 0 until newProcessorsPerListener) { + val processor = newProcessor(nextProcessorId, dataPlaneRequestChannel, connectionQuotas, + listenerName, securityProtocol, memoryPool, isPrivilegedListener) + listenerProcessors += processor + dataPlaneRequestChannel.addProcessor(processor) + nextProcessorId += 1 + } + listenerProcessors.foreach(p => dataPlaneProcessors.put(p.id, p)) + acceptor.addProcessors(listenerProcessors, DataPlaneThreadPrefix) + } + + /** + * Stop processing requests and new connections. 
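+ *
+ * Descriptive note: shutdown is ordered so that the acceptors (and, through them, their processors) are
+ * fully stopped before the request channels are cleared; once clear() runs no new requests can be
+ * enqueued, and any requests still queued are simply dropped, since their responses could no longer be
+ * delivered.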
+ */ + def stopProcessingRequests(): Unit = { + info("Stopping socket server request processors") + this.synchronized { + dataPlaneAcceptors.asScala.values.foreach(_.initiateShutdown()) + dataPlaneAcceptors.asScala.values.foreach(_.awaitShutdown()) + controlPlaneAcceptorOpt.foreach(_.initiateShutdown()) + controlPlaneAcceptorOpt.foreach(_.awaitShutdown()) + dataPlaneRequestChannel.clear() + controlPlaneRequestChannelOpt.foreach(_.clear()) + stoppedProcessingRequests = true + } + info("Stopped socket server request processors") + } + + def resizeThreadPool(oldNumNetworkThreads: Int, newNumNetworkThreads: Int): Unit = synchronized { + info(s"Resizing network thread pool size for each data-plane listener from $oldNumNetworkThreads to $newNumNetworkThreads") + if (newNumNetworkThreads > oldNumNetworkThreads) { + dataPlaneAcceptors.forEach { (endpoint, acceptor) => + addDataPlaneProcessors(acceptor, endpoint, newNumNetworkThreads - oldNumNetworkThreads) + } + } else if (newNumNetworkThreads < oldNumNetworkThreads) + dataPlaneAcceptors.asScala.values.foreach(_.removeProcessors(oldNumNetworkThreads - newNumNetworkThreads, dataPlaneRequestChannel)) + } + + /** + * Shutdown the socket server. If still processing requests, shutdown + * acceptors and processors first. + */ + def shutdown(): Unit = { + info("Shutting down socket server") + this.synchronized { + if (!stoppedProcessingRequests) + stopProcessingRequests() + dataPlaneRequestChannel.shutdown() + controlPlaneRequestChannelOpt.foreach(_.shutdown()) + connectionQuotas.close() + } + info("Shutdown completed") + } + + def boundPort(listenerName: ListenerName): Int = { + try { + val acceptor = dataPlaneAcceptors.get(endpoints(listenerName)) + if (acceptor != null) { + acceptor.serverChannel.socket.getLocalPort + } else { + controlPlaneAcceptorOpt.map (_.serverChannel.socket().getLocalPort).getOrElse(throw new KafkaException("Could not find listenerName : " + listenerName + " in data-plane or control-plane")) + } + } catch { + case e: Exception => + throw new KafkaException("Tried to check server's port before server was started or checked for port of non-existing protocol", e) + } + } + + def addListeners(listenersAdded: Seq[EndPoint]): Unit = synchronized { + info(s"Adding data-plane listeners for endpoints $listenersAdded") + createDataPlaneAcceptorsAndProcessors(config.numNetworkThreads, listenersAdded) + listenersAdded.foreach { endpoint => + val acceptor = dataPlaneAcceptors.get(endpoint) + startAcceptorAndProcessors(DataPlaneThreadPrefix, endpoint, acceptor) + } + } + + def removeListeners(listenersRemoved: Seq[EndPoint]): Unit = synchronized { + info(s"Removing data-plane listeners for endpoints $listenersRemoved") + listenersRemoved.foreach { endpoint => + connectionQuotas.removeListener(config, endpoint.listenerName) + dataPlaneAcceptors.asScala.remove(endpoint).foreach { acceptor => + acceptor.initiateShutdown() + acceptor.awaitShutdown() + } + } + } + + override def reconfigurableConfigs: Set[String] = SocketServer.ReconfigurableConfigs + + override def validateReconfiguration(newConfig: KafkaConfig): Unit = { + + } + + override def reconfigure(oldConfig: KafkaConfig, newConfig: KafkaConfig): Unit = { + val maxConnectionsPerIp = newConfig.maxConnectionsPerIp + if (maxConnectionsPerIp != oldConfig.maxConnectionsPerIp) { + info(s"Updating maxConnectionsPerIp: $maxConnectionsPerIp") + connectionQuotas.updateMaxConnectionsPerIp(maxConnectionsPerIp) + } + val maxConnectionsPerIpOverrides = newConfig.maxConnectionsPerIpOverrides + if 
(maxConnectionsPerIpOverrides != oldConfig.maxConnectionsPerIpOverrides) { + info(s"Updating maxConnectionsPerIpOverrides: ${maxConnectionsPerIpOverrides.map { case (k, v) => s"$k=$v" }.mkString(",")}") + connectionQuotas.updateMaxConnectionsPerIpOverride(maxConnectionsPerIpOverrides) + } + val maxConnections = newConfig.maxConnections + if (maxConnections != oldConfig.maxConnections) { + info(s"Updating broker-wide maxConnections: $maxConnections") + connectionQuotas.updateBrokerMaxConnections(maxConnections) + } + val maxConnectionRate = newConfig.maxConnectionCreationRate + if (maxConnectionRate != oldConfig.maxConnectionCreationRate) { + info(s"Updating broker-wide maxConnectionCreationRate: $maxConnectionRate") + connectionQuotas.updateBrokerMaxConnectionRate(maxConnectionRate) + } + } + + private def waitForAuthorizerFuture(acceptor: Acceptor, + authorizerFutures: Map[Endpoint, CompletableFuture[Void]]): Unit = { + //we can't rely on authorizerFutures.get() due to ephemeral ports. Get the future using listener name + authorizerFutures.forKeyValue { (endpoint, future) => + if (endpoint.listenerName == Optional.of(acceptor.endPoint.listenerName.value)) + future.join() + } + } + + // `protected` for test usage + protected[network] def newProcessor(id: Int, requestChannel: RequestChannel, connectionQuotas: ConnectionQuotas, listenerName: ListenerName, + securityProtocol: SecurityProtocol, memoryPool: MemoryPool, isPrivilegedListener: Boolean): Processor = { + new Processor(id, + time, + config.socketRequestMaxBytes, + requestChannel, + connectionQuotas, + config.connectionsMaxIdleMs, + config.failedAuthenticationDelayMs, + listenerName, + securityProtocol, + config, + metrics, + credentialProvider, + memoryPool, + logContext, + Processor.ConnectionQueueSize, + isPrivilegedListener, + apiVersionManager + ) + } + + // For test usage + private[network] def connectionCount(address: InetAddress): Int = + Option(connectionQuotas).fold(0)(_.get(address)) + + // For test usage + private[network] def dataPlaneProcessor(index: Int): Processor = dataPlaneProcessors.get(index) + +} + +object SocketServer { + val MetricsGroup = "socket-server-metrics" + val DataPlaneThreadPrefix = "data-plane" + val ControlPlaneThreadPrefix = "control-plane" + val DataPlaneMetricPrefix = "" + val ControlPlaneMetricPrefix = "ControlPlane" + + val ReconfigurableConfigs = Set( + KafkaConfig.MaxConnectionsPerIpProp, + KafkaConfig.MaxConnectionsPerIpOverridesProp, + KafkaConfig.MaxConnectionsProp, + KafkaConfig.MaxConnectionCreationRateProp) + + val ListenerReconfigurableConfigs = Set(KafkaConfig.MaxConnectionsProp, KafkaConfig.MaxConnectionCreationRateProp) +} + +/** + * A base class with some helper variables and methods + */ +private[kafka] abstract class AbstractServerThread(connectionQuotas: ConnectionQuotas) extends Runnable with Logging { + + private val startupLatch = new CountDownLatch(1) + + // `shutdown()` is invoked before `startupComplete` and `shutdownComplete` if an exception is thrown in the constructor + // (e.g. if the address is already in use). We want `shutdown` to proceed in such cases, so we first assign an open + // latch and then replace it in `startupComplete()`. 
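+ // Illustrative walk-through of the latch swap described above:
+ //   1. shutdownLatch starts open (count 0), so awaitShutdown() returns immediately when shutdown() is
+ //      invoked for a thread whose constructor threw before it ever started.
+ //   2. startupComplete() installs a fresh CountDownLatch(1) *before* releasing startupLatch, so any caller
+ //      that has observed a started thread will block in awaitShutdown() until shutdownComplete() runs.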
+ @volatile private var shutdownLatch = new CountDownLatch(0) + + private val alive = new AtomicBoolean(true) + + def wakeup(): Unit + + /** + * Initiates a graceful shutdown by signaling to stop + */ + def initiateShutdown(): Unit = { + if (alive.getAndSet(false)) + wakeup() + } + + /** + * Wait for the thread to completely shutdown + */ + def awaitShutdown(): Unit = shutdownLatch.await + + /** + * Returns true if the thread is completely started + */ + def isStarted(): Boolean = startupLatch.getCount == 0 + + /** + * Wait for the thread to completely start up + */ + def awaitStartup(): Unit = startupLatch.await + + /** + * Record that the thread startup is complete + */ + protected def startupComplete(): Unit = { + // Replace the open latch with a closed one + shutdownLatch = new CountDownLatch(1) + startupLatch.countDown() + } + + /** + * Record that the thread shutdown is complete + */ + protected def shutdownComplete(): Unit = shutdownLatch.countDown() + + /** + * Is the server still running? + */ + protected def isRunning: Boolean = alive.get + + /** + * Close `channel` and decrement the connection count. + */ + def close(listenerName: ListenerName, channel: SocketChannel): Unit = { + if (channel != null) { + debug(s"Closing connection from ${channel.socket.getRemoteSocketAddress}") + connectionQuotas.dec(listenerName, channel.socket.getInetAddress) + closeSocket(channel) + } + } + + protected def closeSocket(channel: SocketChannel): Unit = { + CoreUtils.swallow(channel.socket().close(), this, Level.ERROR) + CoreUtils.swallow(channel.close(), this, Level.ERROR) + } +} + +/** + * Thread that accepts and configures new connections. There is one of these per endpoint. + */ +private[kafka] class Acceptor(val endPoint: EndPoint, + val sendBufferSize: Int, + val recvBufferSize: Int, + nodeId: Int, + connectionQuotas: ConnectionQuotas, + metricPrefix: String, + time: Time, + logPrefix: String = "") extends AbstractServerThread(connectionQuotas) with KafkaMetricsGroup { + + this.logIdent = logPrefix + private val nioSelector = NSelector.open() + val serverChannel = openServerSocket(endPoint.host, endPoint.port) + private val processors = new ArrayBuffer[Processor]() + private val processorsStarted = new AtomicBoolean + private val blockedPercentMeter = newMeter(s"${metricPrefix}AcceptorBlockedPercent", + "blocked time", TimeUnit.NANOSECONDS, Map(ListenerMetricTag -> endPoint.listenerName.value)) + private var currentProcessorIndex = 0 + private[network] val throttledSockets = new mutable.PriorityQueue[DelayedCloseSocket]() + + private[network] case class DelayedCloseSocket(socket: SocketChannel, endThrottleTimeMs: Long) extends Ordered[DelayedCloseSocket] { + override def compare(that: DelayedCloseSocket): Int = endThrottleTimeMs compare that.endThrottleTimeMs + } + + private[network] def addProcessors(newProcessors: Buffer[Processor], processorThreadPrefix: String): Unit = synchronized { + processors ++= newProcessors + if (processorsStarted.get) + startProcessors(newProcessors, processorThreadPrefix) + } + + private[network] def startProcessors(processorThreadPrefix: String): Unit = synchronized { + if (!processorsStarted.getAndSet(true)) { + startProcessors(processors, processorThreadPrefix) + } + } + + private def startProcessors(processors: Seq[Processor], processorThreadPrefix: String): Unit = synchronized { + processors.foreach { processor => + KafkaThread.nonDaemon( + 
s"${processorThreadPrefix}-kafka-network-thread-$nodeId-${endPoint.listenerName}-${endPoint.securityProtocol}-${processor.id}", + processor + ).start() + } + } + + private[network] def removeProcessors(removeCount: Int, requestChannel: RequestChannel): Unit = synchronized { + // Shutdown `removeCount` processors. Remove them from the processor list first so that no more + // connections are assigned. Shutdown the removed processors, closing the selector and its connections. + // The processors are then removed from `requestChannel` and any pending responses to these processors are dropped. + val toRemove = processors.takeRight(removeCount) + processors.remove(processors.size - removeCount, removeCount) + toRemove.foreach(_.initiateShutdown()) + toRemove.foreach(_.awaitShutdown()) + toRemove.foreach(processor => requestChannel.removeProcessor(processor.id)) + } + + override def initiateShutdown(): Unit = { + super.initiateShutdown() + synchronized { + processors.foreach(_.initiateShutdown()) + } + } + + override def awaitShutdown(): Unit = { + super.awaitShutdown() + synchronized { + processors.foreach(_.awaitShutdown()) + } + } + + /** + * Accept loop that checks for new connection attempts + */ + def run(): Unit = { + serverChannel.register(nioSelector, SelectionKey.OP_ACCEPT) + startupComplete() + try { + while (isRunning) { + try { + acceptNewConnections() + closeThrottledConnections() + } + catch { + // We catch all the throwables to prevent the acceptor thread from exiting on exceptions due + // to a select operation on a specific channel or a bad request. We don't want + // the broker to stop responding to requests from other clients in these scenarios. + case e: ControlThrowable => throw e + case e: Throwable => error("Error occurred", e) + } + } + } finally { + debug("Closing server socket, selector, and any throttled sockets.") + CoreUtils.swallow(serverChannel.close(), this, Level.ERROR) + CoreUtils.swallow(nioSelector.close(), this, Level.ERROR) + throttledSockets.foreach(throttledSocket => closeSocket(throttledSocket.socket)) + throttledSockets.clear() + shutdownComplete() + } + } + + /** + * Create a server socket to listen for connections on. + */ + private def openServerSocket(host: String, port: Int): ServerSocketChannel = { + val socketAddress = + if (Utils.isBlank(host)) + new InetSocketAddress(port) + else + new InetSocketAddress(host, port) + val serverChannel = ServerSocketChannel.open() + serverChannel.configureBlocking(false) + if (recvBufferSize != Selectable.USE_DEFAULT_BUFFER_SIZE) + serverChannel.socket().setReceiveBufferSize(recvBufferSize) + + try { + serverChannel.socket.bind(socketAddress) + info(s"Awaiting socket connections on ${socketAddress.getHostString}:${serverChannel.socket.getLocalPort}.") + } catch { + case e: SocketException => + throw new KafkaException(s"Socket server failed to bind to ${socketAddress.getHostString}:$port: ${e.getMessage}.", e) + } + serverChannel + } + + /** + * Listen for new connections and assign accepted connections to processors using round-robin. + */ + private def acceptNewConnections(): Unit = { + val ready = nioSelector.select(500) + if (ready > 0) { + val keys = nioSelector.selectedKeys() + val iter = keys.iterator() + while (iter.hasNext && isRunning) { + try { + val key = iter.next + iter.remove() + + if (key.isAcceptable) { + accept(key).foreach { socketChannel => + // Assign the channel to the next processor (using round-robin) to which the + // channel can be added without blocking. 
If newConnections queue is full on + // all processors, block until the last one is able to accept a connection. + var retriesLeft = synchronized(processors.length) + var processor: Processor = null + do { + retriesLeft -= 1 + processor = synchronized { + // adjust the index (if necessary) and retrieve the processor atomically for + // correct behaviour in case the number of processors is reduced dynamically + currentProcessorIndex = currentProcessorIndex % processors.length + processors(currentProcessorIndex) + } + currentProcessorIndex += 1 + } while (!assignNewConnection(socketChannel, processor, retriesLeft == 0)) + } + } else + throw new IllegalStateException("Unrecognized key state for acceptor thread.") + } catch { + case e: Throwable => error("Error while accepting connection", e) + } + } + } + } + + /** + * Accept a new connection + */ + private def accept(key: SelectionKey): Option[SocketChannel] = { + val serverSocketChannel = key.channel().asInstanceOf[ServerSocketChannel] + val socketChannel = serverSocketChannel.accept() + try { + connectionQuotas.inc(endPoint.listenerName, socketChannel.socket.getInetAddress, blockedPercentMeter) + socketChannel.configureBlocking(false) + socketChannel.socket().setTcpNoDelay(true) + socketChannel.socket().setKeepAlive(true) + if (sendBufferSize != Selectable.USE_DEFAULT_BUFFER_SIZE) + socketChannel.socket().setSendBufferSize(sendBufferSize) + Some(socketChannel) + } catch { + case e: TooManyConnectionsException => + info(s"Rejected connection from ${e.ip}, address already has the configured maximum of ${e.count} connections.") + close(endPoint.listenerName, socketChannel) + None + case e: ConnectionThrottledException => + val ip = socketChannel.socket.getInetAddress + debug(s"Delaying closing of connection from $ip for ${e.throttleTimeMs} ms") + val endThrottleTimeMs = e.startThrottleTimeMs + e.throttleTimeMs + throttledSockets += DelayedCloseSocket(socketChannel, endThrottleTimeMs) + None + } + } + + /** + * Close sockets for any connections that have been throttled. + */ + private def closeThrottledConnections(): Unit = { + val timeMs = time.milliseconds + while (throttledSockets.headOption.exists(_.endThrottleTimeMs < timeMs)) { + val closingSocket = throttledSockets.dequeue() + debug(s"Closing socket from ip ${closingSocket.socket.getRemoteAddress}") + closeSocket(closingSocket.socket) + } + } + + private def assignNewConnection(socketChannel: SocketChannel, processor: Processor, mayBlock: Boolean): Boolean = { + if (processor.accept(socketChannel, mayBlock, blockedPercentMeter)) { + debug(s"Accepted connection from ${socketChannel.socket.getRemoteSocketAddress} on" + + s" ${socketChannel.socket.getLocalSocketAddress} and assigned it to processor ${processor.id}," + + s" sendBufferSize [actual|requested]: [${socketChannel.socket.getSendBufferSize}|$sendBufferSize]" + + s" recvBufferSize [actual|requested]: [${socketChannel.socket.getReceiveBufferSize}|$recvBufferSize]") + true + } else + false + } + + /** + * Wakeup the thread for selection. + */ + @Override + def wakeup(): Unit = nioSelector.wakeup() + +} + +private[kafka] object Processor { + val IdlePercentMetricName = "IdlePercent" + val NetworkProcessorMetricTag = "networkProcessor" + val ListenerMetricTag = "listener" + val ConnectionQueueSize = 20 +} + +/** + * Thread that processes all requests from a single connection. 
There are N of these running in parallel + * each of which has its own selector + * + * @param isPrivilegedListener The privileged listener flag is used as one factor to determine whether + * a certain request is forwarded or not. When the control plane is defined, + * the control plane processor would be fellow broker's choice for sending + * forwarding requests; if the control plane is not defined, the processor + * relying on the inter broker listener would be acting as the privileged listener. + */ +private[kafka] class Processor(val id: Int, + time: Time, + maxRequestSize: Int, + requestChannel: RequestChannel, + connectionQuotas: ConnectionQuotas, + connectionsMaxIdleMs: Long, + failedAuthenticationDelayMs: Int, + listenerName: ListenerName, + securityProtocol: SecurityProtocol, + config: KafkaConfig, + metrics: Metrics, + credentialProvider: CredentialProvider, + memoryPool: MemoryPool, + logContext: LogContext, + connectionQueueSize: Int, + isPrivilegedListener: Boolean, + apiVersionManager: ApiVersionManager) extends AbstractServerThread(connectionQuotas) with KafkaMetricsGroup { + + private object ConnectionId { + def fromString(s: String): Option[ConnectionId] = s.split("-") match { + case Array(local, remote, index) => BrokerEndPoint.parseHostPort(local).flatMap { case (localHost, localPort) => + BrokerEndPoint.parseHostPort(remote).map { case (remoteHost, remotePort) => + ConnectionId(localHost, localPort, remoteHost, remotePort, Integer.parseInt(index)) + } + } + case _ => None + } + } + + private[network] case class ConnectionId(localHost: String, localPort: Int, remoteHost: String, remotePort: Int, index: Int) { + override def toString: String = s"$localHost:$localPort-$remoteHost:$remotePort-$index" + } + + private val newConnections = new ArrayBlockingQueue[SocketChannel](connectionQueueSize) + private val inflightResponses = mutable.Map[String, RequestChannel.Response]() + private val responseQueue = new LinkedBlockingDeque[RequestChannel.Response]() + + private[kafka] val metricTags = mutable.LinkedHashMap( + ListenerMetricTag -> listenerName.value, + NetworkProcessorMetricTag -> id.toString + ).asJava + + newGauge(IdlePercentMetricName, () => { + Option(metrics.metric(metrics.metricName("io-wait-ratio", MetricsGroup, metricTags))).fold(0.0)(m => + Math.min(m.metricValue.asInstanceOf[Double], 1.0)) + }, + // for compatibility, only add a networkProcessor tag to the Yammer Metrics alias (the equivalent Selector metric + // also includes the listener name) + Map(NetworkProcessorMetricTag -> id.toString) + ) + + val expiredConnectionsKilledCount = new CumulativeSum() + private val expiredConnectionsKilledCountMetricName = metrics.metricName("expired-connections-killed-count", MetricsGroup, metricTags) + metrics.addMetric(expiredConnectionsKilledCountMetricName, expiredConnectionsKilledCount) + + private val selector = createSelector( + ChannelBuilders.serverChannelBuilder( + listenerName, + listenerName == config.interBrokerListenerName, + securityProtocol, + config, + credentialProvider.credentialCache, + credentialProvider.tokenCache, + time, + logContext, + () => apiVersionManager.apiVersionResponse(throttleTimeMs = 0) + ) + ) + + // Visible to override for testing + protected[network] def createSelector(channelBuilder: ChannelBuilder): KSelector = { + channelBuilder match { + case reconfigurable: Reconfigurable => config.addReconfigurable(reconfigurable) + case _ => + } + new KSelector( + maxRequestSize, + connectionsMaxIdleMs, + failedAuthenticationDelayMs, + metrics, 
+ time, + "socket-server", + metricTags, + false, + true, + channelBuilder, + memoryPool, + logContext) + } + + // Connection ids have the format `localAddr:localPort-remoteAddr:remotePort-index`. The index is a + // non-negative incrementing value that ensures that even if remotePort is reused after a connection is + // closed, connection ids are not reused while requests from the closed connection are being processed. + private var nextConnectionIndex = 0 + + override def run(): Unit = { + startupComplete() + try { + while (isRunning) { + try { + // setup any new connections that have been queued up + configureNewConnections() + // register any new responses for writing + processNewResponses() + poll() + processCompletedReceives() + processCompletedSends() + processDisconnected() + closeExcessConnections() + } catch { + // We catch all the throwables here to prevent the processor thread from exiting. We do this because + // letting a processor exit might cause a bigger impact on the broker. This behavior might need to be + // reviewed if we see an exception that needs the entire broker to stop. Usually the exceptions thrown would + // be either associated with a specific socket channel or a bad request. These exceptions are caught and + // processed by the individual methods above which close the failing channel and continue processing other + // channels. So this catch block should only ever see ControlThrowables. + case e: Throwable => processException("Processor got uncaught exception.", e) + } + } + } finally { + debug(s"Closing selector - processor $id") + CoreUtils.swallow(closeAll(), this, Level.ERROR) + shutdownComplete() + } + } + + private[network] def processException(errorMessage: String, throwable: Throwable): Unit = { + throwable match { + case e: ControlThrowable => throw e + case e => error(errorMessage, e) + } + } + + private def processChannelException(channelId: String, errorMessage: String, throwable: Throwable): Unit = { + if (openOrClosingChannel(channelId).isDefined) { + error(s"Closing socket for $channelId because of error", throwable) + close(channelId) + } + processException(errorMessage, throwable) + } + + private def processNewResponses(): Unit = { + var currentResponse: RequestChannel.Response = null + while ({currentResponse = dequeueResponse(); currentResponse != null}) { + val channelId = currentResponse.request.context.connectionId + try { + currentResponse match { + case response: NoOpResponse => + // There is no response to send to the client, we need to read more pipelined requests + // that are sitting in the server's socket buffer + updateRequestMetrics(response) + trace(s"Socket server received empty response to send, registering for read: $response") + // Try unmuting the channel. If there was no quota violation and the channel has not been throttled, + // it will be unmuted immediately. If the channel has been throttled, it will be unmuted only if the + // throttling delay has already passed by now. + handleChannelMuteEvent(channelId, ChannelMuteEvent.RESPONSE_SENT) + tryUnmuteChannel(channelId) + + case response: SendResponse => + sendResponse(response, response.responseSend) + case response: CloseConnectionResponse => + updateRequestMetrics(response) + trace("Closing socket connection actively according to the response code.") + close(channelId) + case _: StartThrottlingResponse => + handleChannelMuteEvent(channelId, ChannelMuteEvent.THROTTLE_STARTED) + case _: EndThrottlingResponse => + // Try unmuting the channel. 
The channel will be unmuted only if the response has already been sent out to + // the client. + handleChannelMuteEvent(channelId, ChannelMuteEvent.THROTTLE_ENDED) + tryUnmuteChannel(channelId) + case _ => + throw new IllegalArgumentException(s"Unknown response type: ${currentResponse.getClass}") + } + } catch { + case e: Throwable => + processChannelException(channelId, s"Exception while processing response for $channelId", e) + } + } + } + + // `protected` for test usage + protected[network] def sendResponse(response: RequestChannel.Response, responseSend: Send): Unit = { + val connectionId = response.request.context.connectionId + trace(s"Socket server received response to send to $connectionId, registering for write and sending data: $response") + // `channel` can be None if the connection was closed remotely or if selector closed it for being idle for too long + if (channel(connectionId).isEmpty) { + warn(s"Attempting to send response via channel for which there is no open connection, connection id $connectionId") + response.request.updateRequestMetrics(0L, response) + } + // Invoke send for closingChannel as well so that the send is failed and the channel closed properly and + // removed from the Selector after discarding any pending staged receives. + // `openOrClosingChannel` can be None if the selector closed the connection because it was idle for too long + if (openOrClosingChannel(connectionId).isDefined) { + selector.send(new NetworkSend(connectionId, responseSend)) + inflightResponses += (connectionId -> response) + } + } + + private def poll(): Unit = { + val pollTimeout = if (newConnections.isEmpty) 300 else 0 + try selector.poll(pollTimeout) + catch { + case e @ (_: IllegalStateException | _: IOException) => + // The exception is not re-thrown and any completed sends/receives/connections/disconnections + // from this poll will be processed. + error(s"Processor $id poll failed", e) + } + } + + protected def parseRequestHeader(buffer: ByteBuffer): RequestHeader = { + val header = RequestHeader.parse(buffer) + if (apiVersionManager.isApiEnabled(header.apiKey)) { + header + } else { + throw new InvalidRequestException(s"Received request api key ${header.apiKey} which is not enabled") + } + } + + private def processCompletedReceives(): Unit = { + selector.completedReceives.forEach { receive => + try { + openOrClosingChannel(receive.source) match { + case Some(channel) => + val header = parseRequestHeader(receive.payload) + if (header.apiKey == ApiKeys.SASL_HANDSHAKE && channel.maybeBeginServerReauthentication(receive, + () => time.nanoseconds())) + trace(s"Begin re-authentication: $channel") + else { + val nowNanos = time.nanoseconds() + if (channel.serverAuthenticationSessionExpired(nowNanos)) { + // be sure to decrease connection count and drop any in-flight responses + debug(s"Disconnecting expired channel: $channel : $header") + close(channel.id) + expiredConnectionsKilledCount.record(null, 1, 0) + } else { + val connectionId = receive.source + val context = new RequestContext(header, connectionId, channel.socketAddress, + channel.principal, listenerName, securityProtocol, + channel.channelMetadataRegistry.clientInformation, isPrivilegedListener, channel.principalSerde) + + val req = new RequestChannel.Request(processor = id, context = context, + startTimeNanos = nowNanos, memoryPool, receive.payload, requestChannel.metrics, None) + + // KIP-511: ApiVersionsRequest is intercepted here to catch the client software name + // and version. 
It is done here to avoid wiring things up to the api layer. + if (header.apiKey == ApiKeys.API_VERSIONS) { + val apiVersionsRequest = req.body[ApiVersionsRequest] + if (apiVersionsRequest.isValid) { + channel.channelMetadataRegistry.registerClientInformation(new ClientInformation( + apiVersionsRequest.data.clientSoftwareName, + apiVersionsRequest.data.clientSoftwareVersion)) + } + } + requestChannel.sendRequest(req) + selector.mute(connectionId) + handleChannelMuteEvent(connectionId, ChannelMuteEvent.REQUEST_RECEIVED) + } + } + case None => + // This should never happen since completed receives are processed immediately after `poll()` + throw new IllegalStateException(s"Channel ${receive.source} removed from selector before processing completed receive") + } + } catch { + // note that even though we got an exception, we can assume that receive.source is valid. + // Issues with constructing a valid receive object were handled earlier + case e: Throwable => + processChannelException(receive.source, s"Exception while processing request from ${receive.source}", e) + } + } + selector.clearCompletedReceives() + } + + private def processCompletedSends(): Unit = { + selector.completedSends.forEach { send => + try { + val response = inflightResponses.remove(send.destinationId).getOrElse { + throw new IllegalStateException(s"Send for ${send.destinationId} completed, but not in `inflightResponses`") + } + updateRequestMetrics(response) + + // Invoke send completion callback + response.onComplete.foreach(onComplete => onComplete(send)) + + // Try unmuting the channel. If there was no quota violation and the channel has not been throttled, + // it will be unmuted immediately. If the channel has been throttled, it will unmuted only if the throttling + // delay has already passed by now. + handleChannelMuteEvent(send.destinationId, ChannelMuteEvent.RESPONSE_SENT) + tryUnmuteChannel(send.destinationId) + } catch { + case e: Throwable => processChannelException(send.destinationId, + s"Exception while processing completed send to ${send.destinationId}", e) + } + } + selector.clearCompletedSends() + } + + private def updateRequestMetrics(response: RequestChannel.Response): Unit = { + val request = response.request + val networkThreadTimeNanos = openOrClosingChannel(request.context.connectionId).fold(0L)(_.getAndResetNetworkThreadTimeNanos()) + request.updateRequestMetrics(networkThreadTimeNanos, response) + } + + private def processDisconnected(): Unit = { + selector.disconnected.keySet.forEach { connectionId => + try { + val remoteHost = ConnectionId.fromString(connectionId).getOrElse { + throw new IllegalStateException(s"connectionId has unexpected format: $connectionId") + }.remoteHost + inflightResponses.remove(connectionId).foreach(updateRequestMetrics) + // the channel has been closed by the selector but the quotas still need to be updated + connectionQuotas.dec(listenerName, InetAddress.getByName(remoteHost)) + } catch { + case e: Throwable => processException(s"Exception while processing disconnection of $connectionId", e) + } + } + } + + private def closeExcessConnections(): Unit = { + if (connectionQuotas.maxConnectionsExceeded(listenerName)) { + val channel = selector.lowestPriorityChannel() + if (channel != null) + close(channel.id) + } + } + + /** + * Close the connection identified by `connectionId` and decrement the connection count. 
+ * The channel will be immediately removed from the selector's `channels` or `closingChannels` + * and no further disconnect notifications will be sent for this channel by the selector. + * If responses are pending for the channel, they are dropped and metrics is updated. + * If the channel has already been removed from selector, no action is taken. + */ + private def close(connectionId: String): Unit = { + openOrClosingChannel(connectionId).foreach { channel => + debug(s"Closing selector connection $connectionId") + val address = channel.socketAddress + if (address != null) + connectionQuotas.dec(listenerName, address) + selector.close(connectionId) + + inflightResponses.remove(connectionId).foreach(response => updateRequestMetrics(response)) + } + } + + /** + * Queue up a new connection for reading + */ + def accept(socketChannel: SocketChannel, + mayBlock: Boolean, + acceptorIdlePercentMeter: com.yammer.metrics.core.Meter): Boolean = { + val accepted = { + if (newConnections.offer(socketChannel)) + true + else if (mayBlock) { + val startNs = time.nanoseconds + newConnections.put(socketChannel) + acceptorIdlePercentMeter.mark(time.nanoseconds() - startNs) + true + } else + false + } + if (accepted) + wakeup() + accepted + } + + /** + * Register any new connections that have been queued up. The number of connections processed + * in each iteration is limited to ensure that traffic and connection close notifications of + * existing channels are handled promptly. + */ + private def configureNewConnections(): Unit = { + var connectionsProcessed = 0 + while (connectionsProcessed < connectionQueueSize && !newConnections.isEmpty) { + val channel = newConnections.poll() + try { + debug(s"Processor $id listening to new connection from ${channel.socket.getRemoteSocketAddress}") + selector.register(connectionId(channel.socket), channel) + connectionsProcessed += 1 + } catch { + // We explicitly catch all exceptions and close the socket to avoid a socket leak. + case e: Throwable => + val remoteAddress = channel.socket.getRemoteSocketAddress + // need to close the channel here to avoid a socket leak. 
+ close(listenerName, channel) + processException(s"Processor $id closed connection from $remoteAddress", e) + } + } + } + + /** + * Close the selector and all open connections + */ + private def closeAll(): Unit = { + while (!newConnections.isEmpty) { + newConnections.poll().close() + } + selector.channels.forEach { channel => + close(channel.id) + } + selector.close() + removeMetric(IdlePercentMetricName, Map(NetworkProcessorMetricTag -> id.toString)) + } + + // 'protected` to allow override for testing + protected[network] def connectionId(socket: Socket): String = { + val localHost = socket.getLocalAddress.getHostAddress + val localPort = socket.getLocalPort + val remoteHost = socket.getInetAddress.getHostAddress + val remotePort = socket.getPort + val connId = ConnectionId(localHost, localPort, remoteHost, remotePort, nextConnectionIndex).toString + nextConnectionIndex = if (nextConnectionIndex == Int.MaxValue) 0 else nextConnectionIndex + 1 + connId + } + + private[network] def enqueueResponse(response: RequestChannel.Response): Unit = { + responseQueue.put(response) + wakeup() + } + + private def dequeueResponse(): RequestChannel.Response = { + val response = responseQueue.poll() + if (response != null) + response.request.responseDequeueTimeNanos = Time.SYSTEM.nanoseconds + response + } + + private[network] def responseQueueSize = responseQueue.size + + // Only for testing + private[network] def inflightResponseCount: Int = inflightResponses.size + + // Visible for testing + // Only methods that are safe to call on a disconnected channel should be invoked on 'openOrClosingChannel'. + private[network] def openOrClosingChannel(connectionId: String): Option[KafkaChannel] = + Option(selector.channel(connectionId)).orElse(Option(selector.closingChannel(connectionId))) + + // Indicate the specified channel that the specified channel mute-related event has happened so that it can change its + // mute state. + private def handleChannelMuteEvent(connectionId: String, event: ChannelMuteEvent): Unit = { + openOrClosingChannel(connectionId).foreach(c => c.handleChannelMuteEvent(event)) + } + + private def tryUnmuteChannel(connectionId: String) = { + openOrClosingChannel(connectionId).foreach(c => selector.unmute(c.id)) + } + + /* For test usage */ + private[network] def channel(connectionId: String): Option[KafkaChannel] = + Option(selector.channel(connectionId)) + + /** + * Wakeup the thread for selection. + */ + override def wakeup() = selector.wakeup() + + override def initiateShutdown(): Unit = { + super.initiateShutdown() + removeMetric("IdlePercent", Map("networkProcessor" -> id.toString)) + metrics.removeMetric(expiredConnectionsKilledCountMetricName) + } +} + +/** + * Interface for connection quota configuration. Connection quotas can be configured at the + * broker, listener or IP level. 
+ */ +sealed trait ConnectionQuotaEntity { + def sensorName: String + def metricName: String + def sensorExpiration: Long + def metricTags: Map[String, String] +} + +object ConnectionQuotas { + private val InactiveSensorExpirationTimeSeconds = TimeUnit.HOURS.toSeconds(1) + private val ConnectionRateSensorName = "Connection-Accept-Rate" + private val ConnectionRateMetricName = "connection-accept-rate" + private val IpMetricTag = "ip" + private val ListenerThrottlePrefix = "" + private val IpThrottlePrefix = "ip-" + + private case class ListenerQuotaEntity(listenerName: String) extends ConnectionQuotaEntity { + override def sensorName: String = s"$ConnectionRateSensorName-$listenerName" + override def sensorExpiration: Long = Long.MaxValue + override def metricName: String = ConnectionRateMetricName + override def metricTags: Map[String, String] = Map(ListenerMetricTag -> listenerName) + } + + private case object BrokerQuotaEntity extends ConnectionQuotaEntity { + override def sensorName: String = ConnectionRateSensorName + override def sensorExpiration: Long = Long.MaxValue + override def metricName: String = s"broker-$ConnectionRateMetricName" + override def metricTags: Map[String, String] = Map.empty + } + + private case class IpQuotaEntity(ip: InetAddress) extends ConnectionQuotaEntity { + override def sensorName: String = s"$ConnectionRateSensorName-${ip.getHostAddress}" + override def sensorExpiration: Long = InactiveSensorExpirationTimeSeconds + override def metricName: String = ConnectionRateMetricName + override def metricTags: Map[String, String] = Map(IpMetricTag -> ip.getHostAddress) + } +} + +class ConnectionQuotas(config: KafkaConfig, time: Time, metrics: Metrics) extends Logging with AutoCloseable { + + @volatile private var defaultMaxConnectionsPerIp: Int = config.maxConnectionsPerIp + @volatile private var maxConnectionsPerIpOverrides = config.maxConnectionsPerIpOverrides.map { case (host, count) => (InetAddress.getByName(host), count) } + @volatile private var brokerMaxConnections = config.maxConnections + private val interBrokerListenerName = config.interBrokerListenerName + private val counts = mutable.Map[InetAddress, Int]() + + // Listener counts and configs are synchronized on `counts` + private val listenerCounts = mutable.Map[ListenerName, Int]() + private[network] val maxConnectionsPerListener = mutable.Map[ListenerName, ListenerConnectionQuota]() + @volatile private var totalCount = 0 + // updates to defaultConnectionRatePerIp or connectionRatePerIp must be synchronized on `counts` + @volatile private var defaultConnectionRatePerIp = QuotaConfigs.IP_CONNECTION_RATE_DEFAULT.intValue() + private val connectionRatePerIp = new ConcurrentHashMap[InetAddress, Int]() + // sensor that tracks broker-wide connection creation rate and limit (quota) + private val brokerConnectionRateSensor = getOrCreateConnectionRateQuotaSensor(config.maxConnectionCreationRate, BrokerQuotaEntity) + private val maxThrottleTimeMs = TimeUnit.SECONDS.toMillis(config.quotaWindowSizeSeconds.toLong) + + def inc(listenerName: ListenerName, address: InetAddress, acceptorBlockedPercentMeter: com.yammer.metrics.core.Meter): Unit = { + counts.synchronized { + waitForConnectionSlot(listenerName, acceptorBlockedPercentMeter) + + recordIpConnectionMaybeThrottle(listenerName, address) + val count = counts.getOrElseUpdate(address, 0) + counts.put(address, count + 1) + totalCount += 1 + if (listenerCounts.contains(listenerName)) { + listenerCounts.put(listenerName, listenerCounts(listenerName) + 1) + } + val 
max = maxConnectionsPerIpOverrides.getOrElse(address, defaultMaxConnectionsPerIp) + if (count >= max) + throw new TooManyConnectionsException(address, max) + } + } + + private[network] def updateMaxConnectionsPerIp(maxConnectionsPerIp: Int): Unit = { + defaultMaxConnectionsPerIp = maxConnectionsPerIp + } + + private[network] def updateMaxConnectionsPerIpOverride(overrideQuotas: Map[String, Int]): Unit = { + maxConnectionsPerIpOverrides = overrideQuotas.map { case (host, count) => (InetAddress.getByName(host), count) } + } + + private[network] def updateBrokerMaxConnections(maxConnections: Int): Unit = { + counts.synchronized { + brokerMaxConnections = maxConnections + counts.notifyAll() + } + } + + private[network] def updateBrokerMaxConnectionRate(maxConnectionRate: Int): Unit = { + // if there is a connection waiting on the rate throttle delay, we will let it wait the original delay even if + // the rate limit increases, because it is just one connection per listener and the code is simpler that way + updateConnectionRateQuota(maxConnectionRate, BrokerQuotaEntity) + } + + /** + * Update the connection rate quota for a given IP and updates quota configs for updated IPs. + * If an IP is given, metric config will be updated only for the given IP, otherwise + * all metric configs will be checked and updated if required. + * + * @param ip ip to update or default if None + * @param maxConnectionRate new connection rate, or resets entity to default if None + */ + def updateIpConnectionRateQuota(ip: Option[InetAddress], maxConnectionRate: Option[Int]): Unit = synchronized { + def isIpConnectionRateMetric(metricName: MetricName) = { + metricName.name == ConnectionRateMetricName && + metricName.group == MetricsGroup && + metricName.tags.containsKey(IpMetricTag) + } + + def shouldUpdateQuota(metric: KafkaMetric, quotaLimit: Int) = { + quotaLimit != metric.config.quota.bound + } + + ip match { + case Some(address) => + // synchronize on counts to ensure reading an IP connection rate quota and creating a quota config is atomic + counts.synchronized { + maxConnectionRate match { + case Some(rate) => + info(s"Updating max connection rate override for $address to $rate") + connectionRatePerIp.put(address, rate) + case None => + info(s"Removing max connection rate override for $address") + connectionRatePerIp.remove(address) + } + } + updateConnectionRateQuota(connectionRateForIp(address), IpQuotaEntity(address)) + case None => + // synchronize on counts to ensure reading an IP connection rate quota and creating a quota config is atomic + counts.synchronized { + defaultConnectionRatePerIp = maxConnectionRate.getOrElse(QuotaConfigs.IP_CONNECTION_RATE_DEFAULT.intValue()) + } + info(s"Updated default max IP connection rate to $defaultConnectionRatePerIp") + metrics.metrics.forEach { (metricName, metric) => + if (isIpConnectionRateMetric(metricName)) { + val quota = connectionRateForIp(InetAddress.getByName(metricName.tags.get(IpMetricTag))) + if (shouldUpdateQuota(metric, quota)) { + debug(s"Updating existing connection rate quota config for ${metricName.tags} to $quota") + metric.config(rateQuotaMetricConfig(quota)) + } + } + } + } + } + + // Visible for testing + def connectionRateForIp(ip: InetAddress): Int = { + connectionRatePerIp.getOrDefault(ip, defaultConnectionRatePerIp) + } + + private[network] def addListener(config: KafkaConfig, listenerName: ListenerName): Unit = { + counts.synchronized { + if (!maxConnectionsPerListener.contains(listenerName)) { + val newListenerQuota = new 
ListenerConnectionQuota(counts, listenerName) + maxConnectionsPerListener.put(listenerName, newListenerQuota) + listenerCounts.put(listenerName, 0) + config.addReconfigurable(newListenerQuota) + newListenerQuota.configure(config.valuesWithPrefixOverride(listenerName.configPrefix)) + } + counts.notifyAll() + } + } + + private[network] def removeListener(config: KafkaConfig, listenerName: ListenerName): Unit = { + counts.synchronized { + maxConnectionsPerListener.remove(listenerName).foreach { listenerQuota => + listenerCounts.remove(listenerName) + // once listener is removed from maxConnectionsPerListener, no metrics will be recorded into listener's sensor + // so it is safe to remove sensor here + listenerQuota.close() + counts.notifyAll() // wake up any waiting acceptors to close cleanly + config.removeReconfigurable(listenerQuota) + } + } + } + + def dec(listenerName: ListenerName, address: InetAddress): Unit = { + counts.synchronized { + val count = counts.getOrElse(address, + throw new IllegalArgumentException(s"Attempted to decrease connection count for address with no connections, address: $address")) + if (count == 1) + counts.remove(address) + else + counts.put(address, count - 1) + + if (totalCount <= 0) + error(s"Attempted to decrease total connection count for broker with no connections") + totalCount -= 1 + + if (maxConnectionsPerListener.contains(listenerName)) { + val listenerCount = listenerCounts(listenerName) + if (listenerCount == 0) + error(s"Attempted to decrease connection count for listener $listenerName with no connections") + else + listenerCounts.put(listenerName, listenerCount - 1) + } + counts.notifyAll() // wake up any acceptors waiting to process a new connection since listener connection limit was reached + } + } + + def get(address: InetAddress): Int = counts.synchronized { + counts.getOrElse(address, 0) + } + + private def waitForConnectionSlot(listenerName: ListenerName, + acceptorBlockedPercentMeter: com.yammer.metrics.core.Meter): Unit = { + counts.synchronized { + val startThrottleTimeMs = time.milliseconds + val throttleTimeMs = math.max(recordConnectionAndGetThrottleTimeMs(listenerName, startThrottleTimeMs), 0) + + if (throttleTimeMs > 0 || !connectionSlotAvailable(listenerName)) { + val startNs = time.nanoseconds + val endThrottleTimeMs = startThrottleTimeMs + throttleTimeMs + var remainingThrottleTimeMs = throttleTimeMs + do { + counts.wait(remainingThrottleTimeMs) + remainingThrottleTimeMs = math.max(endThrottleTimeMs - time.milliseconds, 0) + } while (remainingThrottleTimeMs > 0 || !connectionSlotAvailable(listenerName)) + acceptorBlockedPercentMeter.mark(time.nanoseconds - startNs) + } + } + } + + // This is invoked in every poll iteration and we close one LRU connection in an iteration + // if necessary + def maxConnectionsExceeded(listenerName: ListenerName): Boolean = { + totalCount > brokerMaxConnections && !protectedListener(listenerName) + } + + private def connectionSlotAvailable(listenerName: ListenerName): Boolean = { + if (listenerCounts(listenerName) >= maxListenerConnections(listenerName)) + false + else if (protectedListener(listenerName)) + true + else + totalCount < brokerMaxConnections + } + + private def protectedListener(listenerName: ListenerName): Boolean = + interBrokerListenerName == listenerName && listenerCounts.size > 1 + + private def maxListenerConnections(listenerName: ListenerName): Int = + maxConnectionsPerListener.get(listenerName).map(_.maxConnections).getOrElse(Int.MaxValue) + + /** + * Calculates the delay needed 
to bring the observed connection creation rate to listener-level limit or to broker-wide + * limit, whichever the longest. The delay is capped to the quota window size defined by QuotaWindowSizeSecondsProp + * + * @param listenerName listener for which calculate the delay + * @param timeMs current time in milliseconds + * @return delay in milliseconds + */ + private def recordConnectionAndGetThrottleTimeMs(listenerName: ListenerName, timeMs: Long): Long = { + def recordAndGetListenerThrottleTime(minThrottleTimeMs: Int): Int = { + maxConnectionsPerListener + .get(listenerName) + .map { listenerQuota => + val listenerThrottleTimeMs = recordAndGetThrottleTimeMs(listenerQuota.connectionRateSensor, timeMs) + val throttleTimeMs = math.max(minThrottleTimeMs, listenerThrottleTimeMs) + // record throttle time due to hitting connection rate quota + if (throttleTimeMs > 0) { + listenerQuota.listenerConnectionRateThrottleSensor.record(throttleTimeMs.toDouble, timeMs) + } + throttleTimeMs + } + .getOrElse(0) + } + + if (protectedListener(listenerName)) { + recordAndGetListenerThrottleTime(0) + } else { + val brokerThrottleTimeMs = recordAndGetThrottleTimeMs(brokerConnectionRateSensor, timeMs) + recordAndGetListenerThrottleTime(brokerThrottleTimeMs) + } + } + + /** + * Record IP throttle time on the corresponding listener. To avoid over-recording listener/broker connection rate, we + * also un-record the listener and broker connection if the IP gets throttled. + * + * @param listenerName listener to un-record connection + * @param throttleMs IP throttle time to record for listener + * @param timeMs current time in milliseconds + */ + private def updateListenerMetrics(listenerName: ListenerName, throttleMs: Long, timeMs: Long): Unit = { + if (!protectedListener(listenerName)) { + brokerConnectionRateSensor.record(-1.0, timeMs, false) + } + maxConnectionsPerListener + .get(listenerName) + .foreach { listenerQuota => + listenerQuota.ipConnectionRateThrottleSensor.record(throttleMs.toDouble, timeMs) + listenerQuota.connectionRateSensor.record(-1.0, timeMs, false) + } + } + + /** + * Calculates the delay needed to bring the observed connection creation rate to the IP limit. + * If the connection would cause an IP quota violation, un-record the connection for both IP, + * listener, and broker connection rate and throw a ConnectionThrottledException. Calls to + * this function must be performed with the counts lock to ensure that reading the IP + * connection rate quota and creating the sensor's metric config is atomic. 
+ * + * @param listenerName listener to unrecord connection if throttled + * @param address ip address to record connection + */ + private def recordIpConnectionMaybeThrottle(listenerName: ListenerName, address: InetAddress): Unit = { + val connectionRateQuota = connectionRateForIp(address) + val quotaEnabled = connectionRateQuota != QuotaConfigs.IP_CONNECTION_RATE_DEFAULT + if (quotaEnabled) { + val sensor = getOrCreateConnectionRateQuotaSensor(connectionRateQuota, IpQuotaEntity(address)) + val timeMs = time.milliseconds + val throttleMs = recordAndGetThrottleTimeMs(sensor, timeMs) + if (throttleMs > 0) { + trace(s"Throttling $address for $throttleMs ms") + // unrecord the connection since we won't accept the connection + sensor.record(-1.0, timeMs, false) + updateListenerMetrics(listenerName, throttleMs, timeMs) + throw new ConnectionThrottledException(address, timeMs, throttleMs) + } + } + } + + /** + * Records a new connection into a given connection acceptance rate sensor 'sensor' and returns throttle time + * in milliseconds if quota got violated + * @param sensor sensor to record connection + * @param timeMs current time in milliseconds + * @return throttle time in milliseconds if quota got violated, otherwise 0 + */ + private def recordAndGetThrottleTimeMs(sensor: Sensor, timeMs: Long): Int = { + try { + sensor.record(1.0, timeMs) + 0 + } catch { + case e: QuotaViolationException => + val throttleTimeMs = QuotaUtils.boundedThrottleTime(e, maxThrottleTimeMs, timeMs).toInt + debug(s"Quota violated for sensor (${sensor.name}). Delay time: $throttleTimeMs ms") + throttleTimeMs + } + } + + /** + * Creates sensor for tracking the connection creation rate and corresponding connection rate quota for a given + * listener or broker-wide, if listener is not provided. 
+ * @param quotaLimit connection creation rate quota + * @param connectionQuotaEntity entity to create the sensor for + */ + private def getOrCreateConnectionRateQuotaSensor(quotaLimit: Int, connectionQuotaEntity: ConnectionQuotaEntity): Sensor = { + Option(metrics.getSensor(connectionQuotaEntity.sensorName)).getOrElse { + val sensor = metrics.sensor( + connectionQuotaEntity.sensorName, + rateQuotaMetricConfig(quotaLimit), + connectionQuotaEntity.sensorExpiration + ) + sensor.add(connectionRateMetricName(connectionQuotaEntity), new Rate, null) + sensor + } + } + + /** + * Updates quota configuration for a given connection quota entity + */ + private def updateConnectionRateQuota(quotaLimit: Int, connectionQuotaEntity: ConnectionQuotaEntity): Unit = { + Option(metrics.metric(connectionRateMetricName(connectionQuotaEntity))).foreach { metric => + metric.config(rateQuotaMetricConfig(quotaLimit)) + info(s"Updated ${connectionQuotaEntity.metricName} max connection creation rate to $quotaLimit") + } + } + + private def connectionRateMetricName(connectionQuotaEntity: ConnectionQuotaEntity): MetricName = { + metrics.metricName( + connectionQuotaEntity.metricName, + MetricsGroup, + s"Tracking rate of accepting new connections (per second)", + connectionQuotaEntity.metricTags.asJava) + } + + private def rateQuotaMetricConfig(quotaLimit: Int): MetricConfig = { + new MetricConfig() + .timeWindow(config.quotaWindowSizeSeconds.toLong, TimeUnit.SECONDS) + .samples(config.numQuotaSamples) + .quota(new Quota(quotaLimit, true)) + } + + def close(): Unit = { + metrics.removeSensor(brokerConnectionRateSensor.name) + maxConnectionsPerListener.values.foreach(_.close()) + } + + class ListenerConnectionQuota(lock: Object, listener: ListenerName) extends ListenerReconfigurable with AutoCloseable { + @volatile private var _maxConnections = Int.MaxValue + private[network] val connectionRateSensor = getOrCreateConnectionRateQuotaSensor(Int.MaxValue, ListenerQuotaEntity(listener.value)) + private[network] val listenerConnectionRateThrottleSensor = createConnectionRateThrottleSensor(ListenerThrottlePrefix) + private[network] val ipConnectionRateThrottleSensor = createConnectionRateThrottleSensor(IpThrottlePrefix) + + def maxConnections: Int = _maxConnections + + override def listenerName(): ListenerName = listener + + override def configure(configs: util.Map[String, _]): Unit = { + _maxConnections = maxConnections(configs) + updateConnectionRateQuota(maxConnectionCreationRate(configs), ListenerQuotaEntity(listener.value)) + } + + override def reconfigurableConfigs(): util.Set[String] = { + SocketServer.ListenerReconfigurableConfigs.asJava + } + + override def validateReconfiguration(configs: util.Map[String, _]): Unit = { + val value = maxConnections(configs) + if (value <= 0) + throw new ConfigException(s"Invalid ${KafkaConfig.MaxConnectionsProp} $value") + + val rate = maxConnectionCreationRate(configs) + if (rate <= 0) + throw new ConfigException(s"Invalid ${KafkaConfig.MaxConnectionCreationRateProp} $rate") + } + + override def reconfigure(configs: util.Map[String, _]): Unit = { + lock.synchronized { + _maxConnections = maxConnections(configs) + updateConnectionRateQuota(maxConnectionCreationRate(configs), ListenerQuotaEntity(listener.value)) + lock.notifyAll() + } + } + + def close(): Unit = { + metrics.removeSensor(connectionRateSensor.name) + metrics.removeSensor(listenerConnectionRateThrottleSensor.name) + metrics.removeSensor(ipConnectionRateThrottleSensor.name) + } + + private def maxConnections(configs: 
util.Map[String, _]): Int = { + Option(configs.get(KafkaConfig.MaxConnectionsProp)).map(_.toString.toInt).getOrElse(Int.MaxValue) + } + + private def maxConnectionCreationRate(configs: util.Map[String, _]): Int = { + Option(configs.get(KafkaConfig.MaxConnectionCreationRateProp)).map(_.toString.toInt).getOrElse(Int.MaxValue) + } + + /** + * Creates sensor for tracking the average throttle time on this listener due to hitting broker/listener connection + * rate or IP connection rate quota. The average is out of all throttle times > 0, which is consistent with the + * bandwidth and request quota throttle time metrics. + */ + private def createConnectionRateThrottleSensor(throttlePrefix: String): Sensor = { + val sensor = metrics.sensor(s"${throttlePrefix}ConnectionRateThrottleTime-${listener.value}") + val metricName = metrics.metricName(s"${throttlePrefix}connection-accept-throttle-time", + MetricsGroup, + "Tracking average throttle-time, out of non-zero throttle times, per listener", + Map(ListenerMetricTag -> listener.value).asJava) + sensor.add(metricName, new Avg) + sensor + } + } +} + +class TooManyConnectionsException(val ip: InetAddress, val count: Int) extends KafkaException(s"Too many connections from $ip (maximum = $count)") + +class ConnectionThrottledException(val ip: InetAddress, val startThrottleTimeMs: Long, val throttleTimeMs: Long) + extends KafkaException(s"$ip throttled for $throttleTimeMs") diff --git a/core/src/main/scala/kafka/network/package.html b/core/src/main/scala/kafka/network/package.html new file mode 100644 index 0000000..c9280fe --- /dev/null +++ b/core/src/main/scala/kafka/network/package.html @@ -0,0 +1,29 @@ + +The network server for kafka. No application specific code here, just general network server stuff. +
+The classes Receive and Send encapsulate the incoming and outgoing transmission of bytes. A Handler +is a mapping between a Receive and a Send, and represents the user's hook to add logic for mapping requests +to actual processing code. Any uncaught exceptions in the reading or writing of the transmissions will result in +the server logging an error and closing the offending socket. As a result it is the duty of the Handler +implementation to catch and serialize any application-level errors that should be sent to the client. +
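As a rough illustration of the Receive/Send/Handler mapping described in the paragraph above, here is a tiny self-contained sketch; the types and names are invented for this note and are not the actual kafka.network API.

    import java.io.{ByteArrayOutputStream, OutputStream}
    import java.nio.charset.StandardCharsets

    object HandlerSketch {
      // Hypothetical stand-ins for the incoming and outgoing transmissions.
      final case class Receive(payload: Array[Byte])
      trait Send { def writeTo(out: OutputStream): Unit }

      // A "handler" maps an incoming transmission to an outgoing one. Application-level
      // errors should be caught and encoded into the Send rather than thrown, since the
      // server only logs uncaught exceptions and closes the socket.
      type Handler = Receive => Send

      val echoHandler: Handler = receive => new Send {
        def writeTo(out: OutputStream): Unit = out.write(receive.payload)
      }

      def main(args: Array[String]): Unit = {
        val out = new ByteArrayOutputStream()
        echoHandler(Receive("ping".getBytes(StandardCharsets.UTF_8))).writeTo(out)
        println(out.toString(StandardCharsets.UTF_8.name())) // prints "ping"
      }
    }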
            +This slightly lower-level interface that models sending and receiving rather than requests and responses +is necessary in order to allow the send or receive to be overridden with a non-user-space writing of bytes +using FileChannel.transferTo. \ No newline at end of file diff --git a/core/src/main/scala/kafka/raft/KafkaMetadataLog.scala b/core/src/main/scala/kafka/raft/KafkaMetadataLog.scala new file mode 100644 index 0000000..c83aec6 --- /dev/null +++ b/core/src/main/scala/kafka/raft/KafkaMetadataLog.scala @@ -0,0 +1,644 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package kafka.raft + +import kafka.log.{AppendOrigin, Defaults, UnifiedLog, LogConfig, LogOffsetSnapshot, SnapshotGenerated} +import kafka.server.KafkaConfig.{MetadataLogSegmentBytesProp, MetadataLogSegmentMinBytesProp} +import kafka.server.{BrokerTopicStats, FetchHighWatermark, FetchLogEnd, KafkaConfig, LogDirFailureChannel, RequestLocal} +import kafka.utils.{CoreUtils, Logging, Scheduler} +import org.apache.kafka.common.config.AbstractConfig +import org.apache.kafka.common.errors.InvalidConfigurationException +import org.apache.kafka.common.record.{ControlRecordUtils, MemoryRecords, Records} +import org.apache.kafka.common.utils.{BufferSupplier, Time} +import org.apache.kafka.common.{KafkaException, TopicPartition, Uuid} +import org.apache.kafka.raft.{Isolation, KafkaRaftClient, LogAppendInfo, LogFetchInfo, LogOffsetMetadata, OffsetAndEpoch, OffsetMetadata, ReplicatedLog, ValidOffsetAndEpoch} +import org.apache.kafka.snapshot.{FileRawSnapshotReader, FileRawSnapshotWriter, RawSnapshotReader, RawSnapshotWriter, Snapshots} + +import java.io.File +import java.nio.file.{Files, NoSuchFileException, Path} +import java.util.{Optional, Properties} +import scala.annotation.nowarn +import scala.collection.mutable +import scala.compat.java8.OptionConverters._ + +final class KafkaMetadataLog private ( + val log: UnifiedLog, + time: Time, + scheduler: Scheduler, + // Access to this object needs to be synchronized because it is used by the snapshotting thread to notify the + // polling thread when snapshots are created. This object is also used to store any opened snapshot reader. 
+ snapshots: mutable.TreeMap[OffsetAndEpoch, Option[FileRawSnapshotReader]], + topicPartition: TopicPartition, + config: MetadataLogConfig +) extends ReplicatedLog with Logging { + + this.logIdent = s"[MetadataLog partition=$topicPartition, nodeId=${config.nodeId}] " + + override def read(startOffset: Long, readIsolation: Isolation): LogFetchInfo = { + val isolation = readIsolation match { + case Isolation.COMMITTED => FetchHighWatermark + case Isolation.UNCOMMITTED => FetchLogEnd + case _ => throw new IllegalArgumentException(s"Unhandled read isolation $readIsolation") + } + + val fetchInfo = log.read(startOffset, + maxLength = config.maxFetchSizeInBytes, + isolation = isolation, + minOneMessage = true) + + new LogFetchInfo( + fetchInfo.records, + + new LogOffsetMetadata( + fetchInfo.fetchOffsetMetadata.messageOffset, + Optional.of(SegmentPosition( + fetchInfo.fetchOffsetMetadata.segmentBaseOffset, + fetchInfo.fetchOffsetMetadata.relativePositionInSegment)) + ) + ) + } + + override def appendAsLeader(records: Records, epoch: Int): LogAppendInfo = { + if (records.sizeInBytes == 0) + throw new IllegalArgumentException("Attempt to append an empty record set") + + handleAndConvertLogAppendInfo( + log.appendAsLeader(records.asInstanceOf[MemoryRecords], + leaderEpoch = epoch, + origin = AppendOrigin.RaftLeader, + requestLocal = RequestLocal.NoCaching + ) + ) + } + + override def appendAsFollower(records: Records): LogAppendInfo = { + if (records.sizeInBytes == 0) + throw new IllegalArgumentException("Attempt to append an empty record set") + + handleAndConvertLogAppendInfo(log.appendAsFollower(records.asInstanceOf[MemoryRecords])) + } + + private def handleAndConvertLogAppendInfo(appendInfo: kafka.log.LogAppendInfo): LogAppendInfo = { + appendInfo.firstOffset match { + case Some(firstOffset) => + new LogAppendInfo(firstOffset.messageOffset, appendInfo.lastOffset) + case None => + throw new KafkaException(s"Append failed unexpectedly: ${appendInfo.errorMessage}") + } + } + + override def lastFetchedEpoch: Int = { + log.latestEpoch.getOrElse { + latestSnapshotId().map[Int] { snapshotId => + val logEndOffset = endOffset().offset + if (snapshotId.offset == startOffset && snapshotId.offset == logEndOffset) { + // Return the epoch of the snapshot when the log is empty + snapshotId.epoch + } else { + throw new KafkaException( + s"Log doesn't have a last fetch epoch and there is a snapshot ($snapshotId). " + + s"Expected the snapshot's end offset to match the log's end offset ($logEndOffset) " + + s"and the log start offset ($startOffset)" + ) + } + }.orElse(0) + } + } + + override def endOffsetForEpoch(epoch: Int): OffsetAndEpoch = { + (log.endOffsetForEpoch(epoch), earliestSnapshotId().asScala) match { + case (Some(offsetAndEpoch), Some(snapshotId)) if ( + offsetAndEpoch.offset == snapshotId.offset && + offsetAndEpoch.leaderEpoch == epoch) => + + // The epoch is smaller than the smallest epoch on the log. 
Override the diverging + // epoch to the oldest snapshot which should be the snapshot at the log start offset + new OffsetAndEpoch(snapshotId.offset, snapshotId.epoch) + + case (Some(offsetAndEpoch), _) => + new OffsetAndEpoch(offsetAndEpoch.offset, offsetAndEpoch.leaderEpoch) + + case (None, _) => + new OffsetAndEpoch(endOffset.offset, lastFetchedEpoch) + } + } + + override def endOffset: LogOffsetMetadata = { + val endOffsetMetadata = log.logEndOffsetMetadata + new LogOffsetMetadata( + endOffsetMetadata.messageOffset, + Optional.of(SegmentPosition( + endOffsetMetadata.segmentBaseOffset, + endOffsetMetadata.relativePositionInSegment) + ) + ) + } + + override def startOffset: Long = { + log.logStartOffset + } + + override def truncateTo(offset: Long): Unit = { + if (offset < highWatermark.offset) { + throw new IllegalArgumentException(s"Attempt to truncate to offset $offset, which is below " + + s"the current high watermark ${highWatermark.offset}") + } + log.truncateTo(offset) + } + + override def truncateToLatestSnapshot(): Boolean = { + val latestEpoch = log.latestEpoch.getOrElse(0) + val (truncated, forgottenSnapshots) = latestSnapshotId().asScala match { + case Some(snapshotId) if ( + snapshotId.epoch > latestEpoch || + (snapshotId.epoch == latestEpoch && snapshotId.offset > endOffset().offset) + ) => + // Truncate the log fully if the latest snapshot is greater than the log end offset + log.truncateFullyAndStartAt(snapshotId.offset) + + // Forget snapshots less than the log start offset + snapshots synchronized { + (true, forgetSnapshotsBefore(snapshotId)) + } + case _ => + (false, mutable.TreeMap.empty[OffsetAndEpoch, Option[FileRawSnapshotReader]]) + } + + removeSnapshots(forgottenSnapshots) + truncated + } + + override def initializeLeaderEpoch(epoch: Int): Unit = { + log.maybeAssignEpochStartOffset(epoch, log.logEndOffset) + } + + override def updateHighWatermark(offsetMetadata: LogOffsetMetadata): Unit = { + offsetMetadata.metadata.asScala match { + case Some(segmentPosition: SegmentPosition) => log.updateHighWatermark( + new kafka.server.LogOffsetMetadata( + offsetMetadata.offset, + segmentPosition.baseOffset, + segmentPosition.relativePosition) + ) + case _ => + // FIXME: This API returns the new high watermark, which may be different from the passed offset + log.updateHighWatermark(offsetMetadata.offset) + } + } + + override def highWatermark: LogOffsetMetadata = { + val LogOffsetSnapshot(_, _, hwm, _) = log.fetchOffsetSnapshot + val segmentPosition: Optional[OffsetMetadata] = if (hwm.messageOffsetOnly) { + Optional.of(SegmentPosition(hwm.segmentBaseOffset, hwm.relativePositionInSegment)) + } else { + Optional.empty() + } + + new LogOffsetMetadata(hwm.messageOffset, segmentPosition) + } + + override def flush(): Unit = { + log.flush() + } + + override def lastFlushedOffset(): Long = { + log.recoveryPoint + } + + /** + * Return the topic partition associated with the log. + */ + override def topicPartition(): TopicPartition = { + topicPartition + } + + /** + * Return the topic ID associated with the log. 
+ */ + override def topicId(): Uuid = { + log.topicId.get + } + + override def createNewSnapshot(snapshotId: OffsetAndEpoch): Optional[RawSnapshotWriter] = { + if (snapshotId.offset < startOffset) { + info(s"Cannot create a snapshot with an id ($snapshotId) less than the log start offset ($startOffset)") + return Optional.empty() + } + + val highWatermarkOffset = highWatermark.offset + if (snapshotId.offset > highWatermarkOffset) { + throw new IllegalArgumentException( + s"Cannot create a snapshot with an id ($snapshotId) greater than the high-watermark ($highWatermarkOffset)" + ) + } + + val validOffsetAndEpoch = validateOffsetAndEpoch(snapshotId.offset, snapshotId.epoch) + if (validOffsetAndEpoch.kind() != ValidOffsetAndEpoch.Kind.VALID) { + throw new IllegalArgumentException( + s"Snapshot id ($snapshotId) is not valid according to the log: $validOffsetAndEpoch" + ) + } + + storeSnapshot(snapshotId) + } + + override def storeSnapshot(snapshotId: OffsetAndEpoch): Optional[RawSnapshotWriter] = { + if (snapshots.contains(snapshotId)) { + Optional.empty() + } else { + Optional.of(FileRawSnapshotWriter.create(log.dir.toPath, snapshotId, Optional.of(this))) + } + } + + override def readSnapshot(snapshotId: OffsetAndEpoch): Optional[RawSnapshotReader] = { + snapshots synchronized { + val reader = snapshots.get(snapshotId) match { + case None => + // Snapshot doesn't exists + None + case Some(None) => + // Snapshot exists but has never been read before + try { + val snapshotReader = Some(FileRawSnapshotReader.open(log.dir.toPath, snapshotId)) + snapshots.put(snapshotId, snapshotReader) + snapshotReader + } catch { + case _: NoSuchFileException => + // Snapshot doesn't exists in the data dir; remove + val path = Snapshots.snapshotPath(log.dir.toPath, snapshotId) + warn(s"Couldn't read $snapshotId; expected to find snapshot file $path") + snapshots.remove(snapshotId) + None + } + case Some(value) => + // Snapshot exists and it is already open; do nothing + value + } + + reader.asJava.asInstanceOf[Optional[RawSnapshotReader]] + } + } + + override def latestSnapshot(): Optional[RawSnapshotReader] = { + snapshots synchronized { + latestSnapshotId().flatMap(readSnapshot) + } + } + + override def latestSnapshotId(): Optional[OffsetAndEpoch] = { + snapshots synchronized { + snapshots.lastOption.map { case (snapshotId, _) => snapshotId }.asJava + } + } + + override def earliestSnapshotId(): Optional[OffsetAndEpoch] = { + snapshots synchronized { + snapshots.headOption.map { case (snapshotId, _) => snapshotId }.asJava + } + } + + override def onSnapshotFrozen(snapshotId: OffsetAndEpoch): Unit = { + snapshots synchronized { + snapshots.put(snapshotId, None) + } + } + + /** + * Delete snapshots that come before a given snapshot ID. This is done by advancing the log start offset to the given + * snapshot and cleaning old log segments. + * + * This will only happen if the following invariants all hold true: + * + *
+   * <ul>
+   *  <li>The given snapshot precedes the latest snapshot</li>
+   *  <li>The offset of the given snapshot is greater than the log start offset</li>
+   *  <li>The log layer can advance the offset to the given snapshot</li>
+   * </ul>
          • + * + * This method is thread-safe + */ + override def deleteBeforeSnapshot(snapshotId: OffsetAndEpoch): Boolean = { + val (deleted, forgottenSnapshots) = snapshots synchronized { + latestSnapshotId().asScala match { + case Some(latestSnapshotId) if + snapshots.contains(snapshotId) && + startOffset < snapshotId.offset && + snapshotId.offset <= latestSnapshotId.offset && + log.maybeIncrementLogStartOffset(snapshotId.offset, SnapshotGenerated) => + // Delete all segments that have a "last offset" less than the log start offset + log.deleteOldSegments() + // Remove older snapshots from the snapshots cache + (true, forgetSnapshotsBefore(snapshotId)) + case _ => + (false, mutable.TreeMap.empty[OffsetAndEpoch, Option[FileRawSnapshotReader]]) + } + } + removeSnapshots(forgottenSnapshots) + deleted + } + + /** + * Force all known snapshots to have an open reader so we can know their sizes. This method is not thread-safe + */ + private def loadSnapshotSizes(): Seq[(OffsetAndEpoch, Long)] = { + snapshots.keys.toSeq.flatMap { + snapshotId => readSnapshot(snapshotId).asScala.map { reader => (snapshotId, reader.sizeInBytes())} + } + } + + /** + * Return the max timestamp of the first batch in a snapshot, if the snapshot exists and has records + */ + private def readSnapshotTimestamp(snapshotId: OffsetAndEpoch): Option[Long] = { + readSnapshot(snapshotId).asScala.flatMap { reader => + val batchIterator = reader.records().batchIterator() + + val firstBatch = batchIterator.next() + val records = firstBatch.streamingIterator(new BufferSupplier.GrowableBufferSupplier()) + if (firstBatch.isControlBatch) { + val header = ControlRecordUtils.deserializedSnapshotHeaderRecord(records.next()); + Some(header.lastContainedLogTimestamp()) + } else { + warn("Did not find control record at beginning of snapshot") + None + } + } + } + + /** + * Perform cleaning of old snapshots and log segments based on size. + * + * If our configured retention size has been violated, we perform cleaning as follows: + * + *
+   * <ul>
+   *  <li>Find oldest snapshot and delete it</li>
+   *  <li>Advance log start offset to end of next oldest snapshot</li>
+   *  <li>Delete log segments which wholly precede the new log start offset</li>
+   * </ul>
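+   *
+   * An illustrative walk-through with hypothetical sizes: if the retention limit is 100 MB and the partition
+   * holds a 70 MB log plus snapshots of 40 MB and 10 MB (120 MB in total), the 40 MB snapshot is deleted,
+   * the log start offset advances to the 10 MB snapshot, and the log segments preceding it are removed.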
          • + * + * This process is repeated until the retention size is no longer violated, or until only + * a single snapshot remains. + */ + override def maybeClean(): Boolean = { + snapshots synchronized { + var didClean = false + didClean |= cleanSnapshotsRetentionSize() + didClean |= cleanSnapshotsRetentionMs() + didClean + } + } + + /** + * Iterate through the snapshots a test the given predicate to see if we should attempt to delete it. Since + * we have some additional invariants regarding snapshots and log segments we cannot simply delete a snapshot in + * all cases. + * + * For the given predicate, we are testing if the snapshot identified by the first argument should be deleted. + */ + private def cleanSnapshots(predicate: (OffsetAndEpoch) => Boolean): Boolean = { + if (snapshots.size < 2) + return false + + var didClean = false + snapshots.keys.toSeq.sliding(2).foreach { + case Seq(snapshot: OffsetAndEpoch, nextSnapshot: OffsetAndEpoch) => + if (predicate(snapshot) && deleteBeforeSnapshot(nextSnapshot)) { + didClean = true + } else { + return didClean + } + case _ => false // Shouldn't get here with the sliding window + } + didClean + } + + private def cleanSnapshotsRetentionMs(): Boolean = { + if (config.retentionMillis < 0) + return false + + // Keep deleting snapshots as long as the + def shouldClean(snapshotId: OffsetAndEpoch): Boolean = { + val now = time.milliseconds() + readSnapshotTimestamp(snapshotId).exists { timestamp => + if (now - timestamp > config.retentionMillis) { + true + } else { + false + } + } + } + + cleanSnapshots(shouldClean) + } + + private def cleanSnapshotsRetentionSize(): Boolean = { + if (config.retentionMaxBytes < 0) + return false + + val snapshotSizes = loadSnapshotSizes().toMap + + var snapshotTotalSize: Long = snapshotSizes.values.sum + + // Keep deleting snapshots and segments as long as we exceed the retention size + def shouldClean(snapshotId: OffsetAndEpoch): Boolean = { + snapshotSizes.get(snapshotId).exists { snapshotSize => + if (log.size + snapshotTotalSize > config.retentionMaxBytes) { + snapshotTotalSize -= snapshotSize + true + } else { + false + } + } + } + + cleanSnapshots(shouldClean) + } + + /** + * Forget the snapshots earlier than a given snapshot id and return the associated + * snapshot readers. + * + * This method assumes that the lock for `snapshots` is already held. + */ + @nowarn("cat=deprecation") // Needed for TreeMap.until + private def forgetSnapshotsBefore( + logStartSnapshotId: OffsetAndEpoch + ): mutable.TreeMap[OffsetAndEpoch, Option[FileRawSnapshotReader]] = { + val expiredSnapshots = snapshots.until(logStartSnapshotId).clone() + snapshots --= expiredSnapshots.keys + + expiredSnapshots + } + + /** + * Rename the given snapshots on the log directory. Asynchronously, close and delete the + * given snapshots after some delay. 
+ */ + private def removeSnapshots( + expiredSnapshots: mutable.TreeMap[OffsetAndEpoch, Option[FileRawSnapshotReader]] + ): Unit = { + expiredSnapshots.foreach { case (snapshotId, _) => + info(s"Marking snapshot $snapshotId for deletion") + Snapshots.markForDelete(log.dir.toPath, snapshotId) + } + + if (expiredSnapshots.nonEmpty) { + scheduler.schedule( + "delete-snapshot-files", + KafkaMetadataLog.deleteSnapshotFiles(log.dir.toPath, expiredSnapshots, this), + config.fileDeleteDelayMs + ) + } + } + + override def close(): Unit = { + log.close() + snapshots synchronized { + snapshots.values.flatten.foreach(_.close()) + snapshots.clear() + } + } + + private[raft] def snapshotCount(): Int = { + snapshots synchronized { + snapshots.size + } + } +} + +object MetadataLogConfig { + def apply(config: AbstractConfig, maxBatchSizeInBytes: Int, maxFetchSizeInBytes: Int): MetadataLogConfig = { + new MetadataLogConfig( + config.getInt(KafkaConfig.MetadataLogSegmentBytesProp), + config.getInt(KafkaConfig.MetadataLogSegmentMinBytesProp), + config.getLong(KafkaConfig.MetadataLogSegmentMillisProp), + config.getLong(KafkaConfig.MetadataMaxRetentionBytesProp), + config.getLong(KafkaConfig.MetadataMaxRetentionMillisProp), + maxBatchSizeInBytes, + maxFetchSizeInBytes, + Defaults.FileDeleteDelayMs, + config.getInt(KafkaConfig.NodeIdProp) + ) + } +} + +case class MetadataLogConfig(logSegmentBytes: Int, + logSegmentMinBytes: Int, + logSegmentMillis: Long, + retentionMaxBytes: Long, + retentionMillis: Long, + maxBatchSizeInBytes: Int, + maxFetchSizeInBytes: Int, + fileDeleteDelayMs: Int, + nodeId: Int) + +object KafkaMetadataLog { + def apply( + topicPartition: TopicPartition, + topicId: Uuid, + dataDir: File, + time: Time, + scheduler: Scheduler, + config: MetadataLogConfig + ): KafkaMetadataLog = { + val props = new Properties() + props.put(LogConfig.MaxMessageBytesProp, config.maxBatchSizeInBytes.toString) + props.put(LogConfig.SegmentBytesProp, Int.box(config.logSegmentBytes)) + props.put(LogConfig.SegmentMsProp, Long.box(config.logSegmentMillis)) + props.put(LogConfig.FileDeleteDelayMsProp, Int.box(Defaults.FileDeleteDelayMs)) + LogConfig.validateValues(props) + val defaultLogConfig = LogConfig(props) + + if (config.logSegmentBytes < config.logSegmentMinBytes) { + throw new InvalidConfigurationException(s"Cannot set $MetadataLogSegmentBytesProp below ${config.logSegmentMinBytes}") + } + + val log = UnifiedLog( + dir = dataDir, + config = defaultLogConfig, + logStartOffset = 0L, + recoveryPoint = 0L, + scheduler = scheduler, + brokerTopicStats = new BrokerTopicStats, + time = time, + maxProducerIdExpirationMs = Int.MaxValue, + producerIdExpirationCheckIntervalMs = Int.MaxValue, + logDirFailureChannel = new LogDirFailureChannel(5), + lastShutdownClean = false, + topicId = Some(topicId), + keepPartitionMetadataFile = true + ) + + val metadataLog = new KafkaMetadataLog( + log, + time, + scheduler, + recoverSnapshots(log), + topicPartition, + config + ) + + // Print a warning if users have overridden the internal config + if (config.logSegmentMinBytes != KafkaRaftClient.MAX_BATCH_SIZE_BYTES) { + metadataLog.error(s"Overriding $MetadataLogSegmentMinBytesProp is only supported for testing. Setting " + + s"this value too low may lead to an inability to write batches of metadata records.") + } + + // When recovering, truncate fully if the latest snapshot is after the log end offset. 
This can happen to a follower + // when the follower crashes after downloading a snapshot from the leader but before it could truncate the log fully. + metadataLog.truncateToLatestSnapshot() + + metadataLog + } + + private def recoverSnapshots( + log: UnifiedLog + ): mutable.TreeMap[OffsetAndEpoch, Option[FileRawSnapshotReader]] = { + val snapshots = mutable.TreeMap.empty[OffsetAndEpoch, Option[FileRawSnapshotReader]] + // Scan the log directory; deleting partial snapshots and older snapshot, only remembering immutable snapshots start + // from logStartOffset + val filesInDir = Files.newDirectoryStream(log.dir.toPath) + + try { + filesInDir.forEach { path => + Snapshots.parse(path).ifPresent { snapshotPath => + if (snapshotPath.partial || + snapshotPath.deleted || + snapshotPath.snapshotId.offset < log.logStartOffset) { + // Delete partial snapshot, deleted snapshot and older snapshot + Files.deleteIfExists(snapshotPath.path) + } else { + snapshots.put(snapshotPath.snapshotId, None) + } + } + } + } finally { + filesInDir.close() + } + + snapshots + } + + private def deleteSnapshotFiles( + logDir: Path, + expiredSnapshots: mutable.TreeMap[OffsetAndEpoch, Option[FileRawSnapshotReader]], + logging: Logging + ): () => Unit = () => { + expiredSnapshots.foreach { case (snapshotId, snapshotReader) => + snapshotReader.foreach { reader => + CoreUtils.swallow(reader.close(), logging) + } + Snapshots.deleteIfExists(logDir, snapshotId) + } + } +} diff --git a/core/src/main/scala/kafka/raft/KafkaNetworkChannel.scala b/core/src/main/scala/kafka/raft/KafkaNetworkChannel.scala new file mode 100644 index 0000000..d990391 --- /dev/null +++ b/core/src/main/scala/kafka/raft/KafkaNetworkChannel.scala @@ -0,0 +1,194 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package kafka.raft + +import kafka.common.{InterBrokerSendThread, RequestAndCompletionHandler} +import kafka.utils.Logging +import org.apache.kafka.clients.{ClientResponse, KafkaClient} +import org.apache.kafka.common.Node +import org.apache.kafka.common.message._ +import org.apache.kafka.common.protocol.{ApiKeys, ApiMessage, Errors} +import org.apache.kafka.common.requests._ +import org.apache.kafka.common.utils.Time +import org.apache.kafka.raft.RaftConfig.InetAddressSpec +import org.apache.kafka.raft.{NetworkChannel, RaftRequest, RaftResponse, RaftUtil} + +import java.util.concurrent.ConcurrentLinkedQueue +import java.util.concurrent.atomic.AtomicInteger +import scala.collection.mutable + +object KafkaNetworkChannel { + + private[raft] def buildRequest(requestData: ApiMessage): AbstractRequest.Builder[_ <: AbstractRequest] = { + requestData match { + case voteRequest: VoteRequestData => + new VoteRequest.Builder(voteRequest) + case beginEpochRequest: BeginQuorumEpochRequestData => + new BeginQuorumEpochRequest.Builder(beginEpochRequest) + case endEpochRequest: EndQuorumEpochRequestData => + new EndQuorumEpochRequest.Builder(endEpochRequest) + case fetchRequest: FetchRequestData => + // Since we already have the request, we go through a simplified builder + new AbstractRequest.Builder[FetchRequest](ApiKeys.FETCH) { + override def build(version: Short): FetchRequest = new FetchRequest(fetchRequest, version) + override def toString(): String = fetchRequest.toString + } + case fetchSnapshotRequest: FetchSnapshotRequestData => + new FetchSnapshotRequest.Builder(fetchSnapshotRequest) + case _ => + throw new IllegalArgumentException(s"Unexpected type for requestData: $requestData") + } + } + +} + +private[raft] class RaftSendThread( + name: String, + networkClient: KafkaClient, + requestTimeoutMs: Int, + time: Time, + isInterruptible: Boolean = true +) extends InterBrokerSendThread( + name, + networkClient, + requestTimeoutMs, + time, + isInterruptible +) { + private val queue = new ConcurrentLinkedQueue[RequestAndCompletionHandler]() + + def generateRequests(): Iterable[RequestAndCompletionHandler] = { + val buffer = mutable.Buffer[RequestAndCompletionHandler]() + while (true) { + val request = queue.poll() + if (request == null) { + return buffer + } else { + buffer += request + } + } + buffer + } + + def sendRequest(request: RequestAndCompletionHandler): Unit = { + queue.add(request) + wakeup() + } + +} + + +class KafkaNetworkChannel( + time: Time, + client: KafkaClient, + requestTimeoutMs: Int, + threadNamePrefix: String +) extends NetworkChannel with Logging { + import KafkaNetworkChannel._ + + type ResponseHandler = AbstractResponse => Unit + + private val correlationIdCounter = new AtomicInteger(0) + private val endpoints = mutable.HashMap.empty[Int, Node] + + private val requestThread = new RaftSendThread( + name = threadNamePrefix + "-outbound-request-thread", + networkClient = client, + requestTimeoutMs = requestTimeoutMs, + time = time, + isInterruptible = false + ) + + override def send(request: RaftRequest.Outbound): Unit = { + def completeFuture(message: ApiMessage): Unit = { + val response = new RaftResponse.Inbound( + request.correlationId, + message, + request.destinationId + ) + request.completion.complete(response) + } + + def onComplete(clientResponse: ClientResponse): Unit = { + val response = if (clientResponse.versionMismatch != null) { + error(s"Request $request failed due to unsupported version error", + clientResponse.versionMismatch) + 
errorResponse(request.data, Errors.UNSUPPORTED_VERSION) + } else if (clientResponse.authenticationException != null) { + // For now we treat authentication errors as retriable. We use the + // `NETWORK_EXCEPTION` error code for lack of a good alternative. + // Note that `BrokerToControllerChannelManager` will still log the + // authentication errors so that users have a chance to fix the problem. + error(s"Request $request failed due to authentication error", + clientResponse.authenticationException) + errorResponse(request.data, Errors.NETWORK_EXCEPTION) + } else if (clientResponse.wasDisconnected()) { + errorResponse(request.data, Errors.BROKER_NOT_AVAILABLE) + } else { + clientResponse.responseBody.data + } + completeFuture(response) + } + + endpoints.get(request.destinationId) match { + case Some(node) => + requestThread.sendRequest(RequestAndCompletionHandler( + request.createdTimeMs, + destination = node, + request = buildRequest(request.data), + handler = onComplete + )) + + case None => + completeFuture(errorResponse(request.data, Errors.BROKER_NOT_AVAILABLE)) + } + } + + // Visible for testing + private[raft] def pollOnce(): Unit = { + requestThread.doWork() + } + + override def newCorrelationId(): Int = { + correlationIdCounter.getAndIncrement() + } + + private def errorResponse( + request: ApiMessage, + error: Errors + ): ApiMessage = { + val apiKey = ApiKeys.forId(request.apiKey) + RaftUtil.errorResponse(apiKey, error) + } + + override def updateEndpoint(id: Int, spec: InetAddressSpec): Unit = { + val node = new Node(id, spec.address.getHostString, spec.address.getPort) + endpoints.put(id, node) + } + + def start(): Unit = { + requestThread.start() + } + + def initiateShutdown(): Unit = { + requestThread.initiateShutdown() + } + + override def close(): Unit = { + requestThread.shutdown() + } +} diff --git a/core/src/main/scala/kafka/raft/RaftManager.scala b/core/src/main/scala/kafka/raft/RaftManager.scala new file mode 100644 index 0000000..4c29250 --- /dev/null +++ b/core/src/main/scala/kafka/raft/RaftManager.scala @@ -0,0 +1,285 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package kafka.raft + +import java.io.File +import java.nio.file.Files +import java.util +import java.util.OptionalInt +import java.util.concurrent.CompletableFuture +import kafka.log.UnifiedLog +import kafka.raft.KafkaRaftManager.RaftIoThread +import kafka.server.{KafkaConfig, MetaProperties} +import kafka.server.KafkaRaftServer.ControllerRole +import kafka.utils.timer.SystemTimer +import kafka.utils.{KafkaScheduler, Logging, ShutdownableThread} +import org.apache.kafka.clients.{ApiVersions, ManualMetadataUpdater, NetworkClient} +import org.apache.kafka.common.metrics.Metrics +import org.apache.kafka.common.network.{ChannelBuilders, ListenerName, NetworkReceive, Selectable, Selector} +import org.apache.kafka.common.protocol.ApiMessage +import org.apache.kafka.common.requests.RequestHeader +import org.apache.kafka.common.security.JaasContext +import org.apache.kafka.common.security.auth.SecurityProtocol +import org.apache.kafka.common.utils.{LogContext, Time} +import org.apache.kafka.common.{TopicPartition, Uuid} +import org.apache.kafka.raft.RaftConfig.{AddressSpec, InetAddressSpec, NON_ROUTABLE_ADDRESS, UnknownAddressSpec} +import org.apache.kafka.raft.{FileBasedStateStore, KafkaRaftClient, LeaderAndEpoch, RaftClient, RaftConfig, RaftRequest, ReplicatedLog} +import org.apache.kafka.server.common.serialization.RecordSerde +import scala.jdk.CollectionConverters._ + +object KafkaRaftManager { + class RaftIoThread( + client: KafkaRaftClient[_], + threadNamePrefix: String + ) extends ShutdownableThread( + name = threadNamePrefix + "-io-thread", + isInterruptible = false + ) { + override def doWork(): Unit = { + client.poll() + } + + override def initiateShutdown(): Boolean = { + if (super.initiateShutdown()) { + client.shutdown(5000).whenComplete { (_, exception) => + if (exception != null) { + error("Graceful shutdown of RaftClient failed", exception) + } else { + info("Completed graceful shutdown of RaftClient") + } + } + true + } else { + false + } + } + + override def isRunning: Boolean = { + client.isRunning && !isThreadFailed + } + } + + private def createLogDirectory(logDir: File, logDirName: String): File = { + val logDirPath = logDir.getAbsolutePath + val dir = new File(logDirPath, logDirName) + Files.createDirectories(dir.toPath) + dir + } +} + +trait RaftManager[T] { + def handleRequest( + header: RequestHeader, + request: ApiMessage, + createdTimeMs: Long + ): CompletableFuture[ApiMessage] + + def register( + listener: RaftClient.Listener[T] + ): Unit + + def leaderAndEpoch: LeaderAndEpoch + + def client: RaftClient[T] + + def replicatedLog: ReplicatedLog +} + +class KafkaRaftManager[T]( + metaProperties: MetaProperties, + config: KafkaConfig, + recordSerde: RecordSerde[T], + topicPartition: TopicPartition, + topicId: Uuid, + time: Time, + metrics: Metrics, + threadNamePrefixOpt: Option[String], + val controllerQuorumVotersFuture: CompletableFuture[util.Map[Integer, AddressSpec]] +) extends RaftManager[T] with Logging { + + private val raftConfig = new RaftConfig(config) + private val threadNamePrefix = threadNamePrefixOpt.getOrElse("kafka-raft") + private val logContext = new LogContext(s"[RaftManager nodeId=${config.nodeId}] ") + this.logIdent = logContext.logPrefix() + + private val scheduler = new KafkaScheduler(threads = 1, threadNamePrefix + "-scheduler") + scheduler.startup() + + private val dataDir = createDataDir() + override val replicatedLog: ReplicatedLog = buildMetadataLog() + private val netChannel = buildNetworkChannel() + override val client: 
KafkaRaftClient[T] = buildRaftClient() + private val raftIoThread = new RaftIoThread(client, threadNamePrefix) + + def startup(): Unit = { + // Update the voter endpoints (if valid) with what's in RaftConfig + val voterAddresses: util.Map[Integer, AddressSpec] = controllerQuorumVotersFuture.get() + for (voterAddressEntry <- voterAddresses.entrySet.asScala) { + voterAddressEntry.getValue match { + case spec: InetAddressSpec => + netChannel.updateEndpoint(voterAddressEntry.getKey, spec) + case _: UnknownAddressSpec => + logger.info(s"Skipping channel update for destination ID: ${voterAddressEntry.getKey} " + + s"because of non-routable endpoint: ${NON_ROUTABLE_ADDRESS.toString}") + case invalid: AddressSpec => + logger.warn(s"Unexpected address spec (type: ${invalid.getClass}) for channel update for " + + s"destination ID: ${voterAddressEntry.getKey}") + } + } + netChannel.start() + raftIoThread.start() + } + + def shutdown(): Unit = { + raftIoThread.shutdown() + client.close() + scheduler.shutdown() + netChannel.close() + replicatedLog.close() + } + + override def register( + listener: RaftClient.Listener[T] + ): Unit = { + client.register(listener) + } + + override def handleRequest( + header: RequestHeader, + request: ApiMessage, + createdTimeMs: Long + ): CompletableFuture[ApiMessage] = { + val inboundRequest = new RaftRequest.Inbound( + header.correlationId, + request, + createdTimeMs + ) + + client.handle(inboundRequest) + + inboundRequest.completion.thenApply { response => + response.data + } + } + + private def buildRaftClient(): KafkaRaftClient[T] = { + val expirationTimer = new SystemTimer("raft-expiration-executor") + val expirationService = new TimingWheelExpirationService(expirationTimer) + val quorumStateStore = new FileBasedStateStore(new File(dataDir, "quorum-state")) + + val nodeId = if (config.processRoles.contains(ControllerRole)) { + OptionalInt.of(config.nodeId) + } else { + OptionalInt.empty() + } + + val client = new KafkaRaftClient( + recordSerde, + netChannel, + replicatedLog, + quorumStateStore, + time, + metrics, + expirationService, + logContext, + metaProperties.clusterId, + nodeId, + raftConfig + ) + client.initialize() + client + } + + private def buildNetworkChannel(): KafkaNetworkChannel = { + val netClient = buildNetworkClient() + new KafkaNetworkChannel(time, netClient, config.quorumRequestTimeoutMs, threadNamePrefix) + } + + private def createDataDir(): File = { + val logDirName = UnifiedLog.logDirName(topicPartition) + KafkaRaftManager.createLogDirectory(new File(config.metadataLogDir), logDirName) + } + + private def buildMetadataLog(): KafkaMetadataLog = { + KafkaMetadataLog( + topicPartition, + topicId, + dataDir, + time, + scheduler, + config = MetadataLogConfig(config, KafkaRaftClient.MAX_BATCH_SIZE_BYTES, KafkaRaftClient.MAX_FETCH_SIZE_BYTES) + ) + } + + private def buildNetworkClient(): NetworkClient = { + val controllerListenerName = new ListenerName(config.controllerListenerNames.head) + val controllerSecurityProtocol = config.effectiveListenerSecurityProtocolMap.getOrElse(controllerListenerName, SecurityProtocol.forName(controllerListenerName.value())) + val channelBuilder = ChannelBuilders.clientChannelBuilder( + controllerSecurityProtocol, + JaasContext.Type.SERVER, + config, + controllerListenerName, + config.saslMechanismControllerProtocol, + time, + config.saslInterBrokerHandshakeRequestEnable, + logContext + ) + + val metricGroupPrefix = "raft-channel" + val collectPerConnectionMetrics = false + + val selector = new Selector( + 
NetworkReceive.UNLIMITED, + config.connectionsMaxIdleMs, + metrics, + time, + metricGroupPrefix, + Map.empty[String, String].asJava, + collectPerConnectionMetrics, + channelBuilder, + logContext + ) + + val clientId = s"raft-client-${config.nodeId}" + val maxInflightRequestsPerConnection = 1 + val reconnectBackoffMs = 50 + val reconnectBackoffMsMs = 500 + val discoverBrokerVersions = true + + new NetworkClient( + selector, + new ManualMetadataUpdater(), + clientId, + maxInflightRequestsPerConnection, + reconnectBackoffMs, + reconnectBackoffMsMs, + Selectable.USE_DEFAULT_BUFFER_SIZE, + config.socketReceiveBufferBytes, + config.quorumRequestTimeoutMs, + config.connectionSetupTimeoutMs, + config.connectionSetupTimeoutMaxMs, + time, + discoverBrokerVersions, + new ApiVersions, + logContext + ) + } + + override def leaderAndEpoch: LeaderAndEpoch = { + client.leaderAndEpoch + } +} diff --git a/core/src/main/scala/kafka/raft/SegmentPosition.scala b/core/src/main/scala/kafka/raft/SegmentPosition.scala new file mode 100644 index 0000000..eb6a59f --- /dev/null +++ b/core/src/main/scala/kafka/raft/SegmentPosition.scala @@ -0,0 +1,23 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package kafka.raft + +import org.apache.kafka.raft.OffsetMetadata + +case class SegmentPosition(baseOffset: Long, relativePosition: Int) extends OffsetMetadata { + override def toString: String = s"(segmentBaseOffset=$baseOffset,relativePositionInSegment=$relativePosition)" +} diff --git a/core/src/main/scala/kafka/raft/TimingWheelExpirationService.scala b/core/src/main/scala/kafka/raft/TimingWheelExpirationService.scala new file mode 100644 index 0000000..c07661e --- /dev/null +++ b/core/src/main/scala/kafka/raft/TimingWheelExpirationService.scala @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package kafka.raft + +import java.util.concurrent.CompletableFuture + +import kafka.utils.ShutdownableThread +import kafka.utils.timer.{Timer, TimerTask} +import org.apache.kafka.common.errors.TimeoutException +import org.apache.kafka.raft.ExpirationService + +object TimingWheelExpirationService { + private val WorkTimeoutMs: Long = 200L + + class TimerTaskCompletableFuture[T](override val delayMs: Long) extends CompletableFuture[T] with TimerTask { + override def run(): Unit = { + completeExceptionally(new TimeoutException( + s"Future failed to be completed before timeout of $delayMs ms was reached")) + } + } +} + +class TimingWheelExpirationService(timer: Timer) extends ExpirationService { + import TimingWheelExpirationService._ + + private val expirationReaper = new ExpiredOperationReaper() + + expirationReaper.start() + + override def failAfter[T](timeoutMs: Long): CompletableFuture[T] = { + val future = new TimerTaskCompletableFuture[T](timeoutMs) + future.whenComplete { (_, _) => + future.cancel() + } + timer.add(future) + future + } + + private class ExpiredOperationReaper extends ShutdownableThread( + name = "raft-expiration-reaper", isInterruptible = false) { + + override def doWork(): Unit = { + timer.advanceClock(WorkTimeoutMs) + } + } + + def shutdown(): Unit = { + expirationReaper.shutdown() + } +} diff --git a/core/src/main/scala/kafka/security/CredentialProvider.scala b/core/src/main/scala/kafka/security/CredentialProvider.scala new file mode 100644 index 0000000..9aa8bc9 --- /dev/null +++ b/core/src/main/scala/kafka/security/CredentialProvider.scala @@ -0,0 +1,54 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package kafka.security + +import java.util.{Collection, Properties} + +import org.apache.kafka.common.security.authenticator.CredentialCache +import org.apache.kafka.common.security.scram.ScramCredential +import org.apache.kafka.common.config.ConfigDef +import org.apache.kafka.common.config.ConfigDef._ +import org.apache.kafka.common.security.scram.internals.{ScramCredentialUtils, ScramMechanism} +import org.apache.kafka.common.security.token.delegation.internals.DelegationTokenCache + +class CredentialProvider(scramMechanisms: Collection[String], val tokenCache: DelegationTokenCache) { + + val credentialCache = new CredentialCache + ScramCredentialUtils.createCache(credentialCache, scramMechanisms) + + def updateCredentials(username: String, config: Properties): Unit = { + for (mechanism <- ScramMechanism.values()) { + val cache = credentialCache.cache(mechanism.mechanismName, classOf[ScramCredential]) + if (cache != null) { + config.getProperty(mechanism.mechanismName) match { + case null => cache.remove(username) + case c => cache.put(username, ScramCredentialUtils.credentialFromString(c)) + } + } + } + } +} + +object CredentialProvider { + def userCredentialConfigs: ConfigDef = { + ScramMechanism.values.foldLeft(new ConfigDef) { + (c, m) => c.define(m.mechanismName, Type.STRING, null, Importance.MEDIUM, s"User credentials for SCRAM mechanism ${m.mechanismName}") + } + } +} + diff --git a/core/src/main/scala/kafka/security/authorizer/AclAuthorizer.scala b/core/src/main/scala/kafka/security/authorizer/AclAuthorizer.scala new file mode 100644 index 0000000..88648fd --- /dev/null +++ b/core/src/main/scala/kafka/security/authorizer/AclAuthorizer.scala @@ -0,0 +1,754 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package kafka.security.authorizer + +import java.{lang, util} +import java.util.concurrent.{CompletableFuture, CompletionStage} + +import com.typesafe.scalalogging.Logger +import kafka.api.KAFKA_2_0_IV1 +import kafka.security.authorizer.AclEntry.ResourceSeparator +import kafka.server.{KafkaConfig, KafkaServer} +import kafka.utils._ +import kafka.utils.Implicits._ +import kafka.zk._ +import org.apache.kafka.common.Endpoint +import org.apache.kafka.common.acl._ +import org.apache.kafka.common.acl.AclOperation._ +import org.apache.kafka.common.acl.AclPermissionType.{ALLOW, DENY} +import org.apache.kafka.common.errors.{ApiException, InvalidRequestException, UnsupportedVersionException} +import org.apache.kafka.common.protocol.ApiKeys +import org.apache.kafka.common.resource._ +import org.apache.kafka.common.security.auth.KafkaPrincipal +import org.apache.kafka.common.utils.{SecurityUtils, Time} +import org.apache.kafka.server.authorizer.AclDeleteResult.AclBindingDeleteResult +import org.apache.kafka.server.authorizer._ +import org.apache.zookeeper.client.ZKClientConfig + +import scala.annotation.nowarn +import scala.collection.mutable.ArrayBuffer +import scala.collection.{Seq, immutable, mutable} +import scala.jdk.CollectionConverters._ +import scala.util.{Failure, Random, Success, Try} + +object AclAuthorizer { + // Optional override zookeeper cluster configuration where acls will be stored. If not specified, + // acls will be stored in the same zookeeper where all other kafka broker metadata is stored. + val configPrefix = "authorizer." + val ZkUrlProp = s"${configPrefix}zookeeper.url" + val ZkConnectionTimeOutProp = s"${configPrefix}zookeeper.connection.timeout.ms" + val ZkSessionTimeOutProp = s"${configPrefix}zookeeper.session.timeout.ms" + val ZkMaxInFlightRequests = s"${configPrefix}zookeeper.max.in.flight.requests" + + // Semi-colon separated list of users that will be treated as super users and will have access to all the resources + // for all actions from all hosts, defaults to no super users. + val SuperUsersProp = "super.users" + // If set to true when no acls are found for a resource, authorizer allows access to everyone. Defaults to false. + val AllowEveryoneIfNoAclIsFoundProp = "allow.everyone.if.no.acl.found" + + case class VersionedAcls(acls: Set[AclEntry], zkVersion: Int) { + def exists: Boolean = zkVersion != ZkVersion.UnknownVersion + } + + class AclSeqs(seqs: Seq[AclEntry]*) { + def find(p: AclEntry => Boolean): Option[AclEntry] = { + // Lazily iterate through the inner `Seq` elements and stop as soon as we find a match + val it = seqs.iterator.flatMap(_.find(p)) + if (it.hasNext) Some(it.next()) + else None + } + + def isEmpty: Boolean = !seqs.exists(_.nonEmpty) + } + + val NoAcls = VersionedAcls(Set.empty, ZkVersion.UnknownVersion) + val WildcardHost = "*" + + // Orders by resource type, then resource pattern type and finally reverse ordering by name. + class ResourceOrdering extends Ordering[ResourcePattern] { + + def compare(a: ResourcePattern, b: ResourcePattern): Int = { + val rt = a.resourceType.compareTo(b.resourceType) + if (rt != 0) + rt + else { + val rnt = a.patternType.compareTo(b.patternType) + if (rnt != 0) + rnt + else + (a.name compare b.name) * -1 + } + } + } + + private[authorizer] def zkClientConfigFromKafkaConfigAndMap(kafkaConfig: KafkaConfig, configMap: mutable.Map[String, _<:Any]): ZKClientConfig = { + val zkSslClientEnable = configMap.get(AclAuthorizer.configPrefix + KafkaConfig.ZkSslClientEnableProp). 
+ map(_.toString).getOrElse(kafkaConfig.zkSslClientEnable.toString).toBoolean + if (!zkSslClientEnable) + new ZKClientConfig + else { + // start with the base config from the Kafka configuration + // be sure to force creation since the zkSslClientEnable property in the kafkaConfig could be false + val zkClientConfig = KafkaServer.zkClientConfigFromKafkaConfig(kafkaConfig, true) + // add in any prefixed overlays + KafkaConfig.ZkSslConfigToSystemPropertyMap.forKeyValue { (kafkaProp, sysProp) => + configMap.get(AclAuthorizer.configPrefix + kafkaProp).foreach { prefixedValue => + zkClientConfig.setProperty(sysProp, + if (kafkaProp == KafkaConfig.ZkSslEndpointIdentificationAlgorithmProp) + (prefixedValue.toString.toUpperCase == "HTTPS").toString + else + prefixedValue.toString) + } + } + zkClientConfig + } + } + + private def validateAclBinding(aclBinding: AclBinding): Unit = { + if (aclBinding.isUnknown) + throw new IllegalArgumentException("ACL binding contains unknown elements") + } +} + +class AclAuthorizer extends Authorizer with Logging { + import kafka.security.authorizer.AclAuthorizer._ + + private[security] val authorizerLogger = Logger("kafka.authorizer.logger") + private var superUsers = Set.empty[KafkaPrincipal] + private var shouldAllowEveryoneIfNoAclIsFound = false + private var zkClient: KafkaZkClient = _ + private var aclChangeListeners: Iterable[AclChangeSubscription] = Iterable.empty + private var extendedAclSupport: Boolean = _ + + @volatile + private var aclCache = new scala.collection.immutable.TreeMap[ResourcePattern, VersionedAcls]()(new ResourceOrdering) + + @volatile + private var resourceCache = new scala.collection.immutable.HashMap[ResourceTypeKey, + scala.collection.immutable.HashSet[String]]() + + private val lock = new Object() + + // The maximum number of times we should try to update the resource acls in zookeeper before failing; + // This should never occur, but is a safeguard just in case. + protected[security] var maxUpdateRetries = 10 + + private val retryBackoffMs = 100 + private val retryBackoffJitterMs = 50 + + /** + * Guaranteed to be called before any authorize call is made. + */ + override def configure(javaConfigs: util.Map[String, _]): Unit = { + val configs = javaConfigs.asScala + val props = new java.util.Properties() + configs.forKeyValue { (key, value) => props.put(key, value.toString) } + + superUsers = configs.get(AclAuthorizer.SuperUsersProp).collect { + case str: String if str.nonEmpty => str.split(";").map(s => SecurityUtils.parseKafkaPrincipal(s.trim)).toSet + }.getOrElse(Set.empty[KafkaPrincipal]) + + shouldAllowEveryoneIfNoAclIsFound = configs.get(AclAuthorizer.AllowEveryoneIfNoAclIsFoundProp).exists(_.toString.toBoolean) + + // Use `KafkaConfig` in order to get the default ZK config values if not present in `javaConfigs`. Note that this + // means that `KafkaConfig.zkConnect` must always be set by the user (even if `AclAuthorizer.ZkUrlProp` is also + // set). 
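+    // Illustrative (hypothetical) overrides showing how a separate ZooKeeper ensemble could be configured
+    // for ACL storage via the prefixed properties read below, e.g.:
+    //   authorizer.zookeeper.url=acl-zk1:2181,acl-zk2:2181/kafka-acl
+    //   authorizer.zookeeper.connection.timeout.ms=30000
+    //   authorizer.zookeeper.session.timeout.ms=30000
+    //   authorizer.zookeeper.max.in.flight.requests=10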
+ val kafkaConfig = KafkaConfig.fromProps(props, doLog = false) + val zkUrl = configs.get(AclAuthorizer.ZkUrlProp).map(_.toString).getOrElse(kafkaConfig.zkConnect) + val zkConnectionTimeoutMs = configs.get(AclAuthorizer.ZkConnectionTimeOutProp).map(_.toString.toInt).getOrElse(kafkaConfig.zkConnectionTimeoutMs) + val zkSessionTimeOutMs = configs.get(AclAuthorizer.ZkSessionTimeOutProp).map(_.toString.toInt).getOrElse(kafkaConfig.zkSessionTimeoutMs) + val zkMaxInFlightRequests = configs.get(AclAuthorizer.ZkMaxInFlightRequests).map(_.toString.toInt).getOrElse(kafkaConfig.zkMaxInFlightRequests) + + val zkClientConfig = AclAuthorizer.zkClientConfigFromKafkaConfigAndMap(kafkaConfig, configs) + val time = Time.SYSTEM + // createChrootIfNecessary=true is necessary in case we are running in a KRaft cluster + // because such a cluster will not create any chroot path in ZooKeeper (it doesn't connect to ZooKeeper) + zkClient = KafkaZkClient(zkUrl, kafkaConfig.zkEnableSecureAcls, zkSessionTimeOutMs, zkConnectionTimeoutMs, + zkMaxInFlightRequests, time, name = "ACL authorizer", zkClientConfig = zkClientConfig, + metricGroup = "kafka.security", metricType = "AclAuthorizer", createChrootIfNecessary = true) + zkClient.createAclPaths() + + extendedAclSupport = kafkaConfig.interBrokerProtocolVersion >= KAFKA_2_0_IV1 + + // Start change listeners first and then populate the cache so that there is no timing window + // between loading cache and processing change notifications. + startZkChangeListeners() + loadCache() + } + + override def start(serverInfo: AuthorizerServerInfo): util.Map[Endpoint, _ <: CompletionStage[Void]] = { + serverInfo.endpoints.asScala.map { endpoint => + endpoint -> CompletableFuture.completedFuture[Void](null) }.toMap.asJava + } + + override def authorize(requestContext: AuthorizableRequestContext, actions: util.List[Action]): util.List[AuthorizationResult] = { + actions.asScala.map { action => authorizeAction(requestContext, action) }.asJava + } + + override def createAcls(requestContext: AuthorizableRequestContext, + aclBindings: util.List[AclBinding]): util.List[_ <: CompletionStage[AclCreateResult]] = { + val results = new Array[AclCreateResult](aclBindings.size) + val aclsToCreate = aclBindings.asScala.zipWithIndex.filter { case (aclBinding, i) => + try { + if (!extendedAclSupport && aclBinding.pattern.patternType == PatternType.PREFIXED) { + throw new UnsupportedVersionException(s"Adding ACLs on prefixed resource patterns requires " + + s"${KafkaConfig.InterBrokerProtocolVersionProp} of $KAFKA_2_0_IV1 or greater") + } + validateAclBinding(aclBinding) + true + } catch { + case e: Throwable => + results(i) = new AclCreateResult(new InvalidRequestException("Failed to create ACL", apiException(e))) + false + } + }.groupBy(_._1.pattern) + + if (aclsToCreate.nonEmpty) { + lock synchronized { + aclsToCreate.forKeyValue { (resource, aclsWithIndex) => + try { + updateResourceAcls(resource) { currentAcls => + val newAcls = aclsWithIndex.map { case (acl, _) => new AclEntry(acl.entry) } + currentAcls ++ newAcls + } + aclsWithIndex.foreach { case (_, index) => results(index) = AclCreateResult.SUCCESS } + } catch { + case e: Throwable => + aclsWithIndex.foreach { case (_, index) => results(index) = new AclCreateResult(apiException(e)) } + } + } + } + } + results.toBuffer.map(CompletableFuture.completedFuture[AclCreateResult]).asJava + } + + /** + * + * Concurrent updates: + *
+   * <ul>
+   * <li>If ACLs are created using [[kafka.security.authorizer.AclAuthorizer#createAcls]] while a delete is in
+   * progress, these ACLs may or may not be considered for deletion depending on the order of updates.
+   * The returned [[org.apache.kafka.server.authorizer.AclDeleteResult]] indicates which ACLs were deleted.</li>
+   * <li>If the provided filters use resource pattern type
+   * [[org.apache.kafka.common.resource.PatternType#MATCH]] that needs to filter all resources to determine
+   * matching ACLs, only ACLs that have already been propagated to the broker processing the ACL update will be
+   * deleted. This may not include some ACLs that were persisted, but not yet propagated to all brokers. The
+   * returned [[org.apache.kafka.server.authorizer.AclDeleteResult]] indicates which ACLs were deleted.</li>
+   * <li>If the provided filters use other resource pattern types that perform a direct match, all matching ACLs
+   * from previously completed [[kafka.security.authorizer.AclAuthorizer#createAcls]] are guaranteed to be deleted.</li>
+   * </ul>
            + */ + override def deleteAcls(requestContext: AuthorizableRequestContext, + aclBindingFilters: util.List[AclBindingFilter]): util.List[_ <: CompletionStage[AclDeleteResult]] = { + val deletedBindings = new mutable.HashMap[AclBinding, Int]() + val deleteExceptions = new mutable.HashMap[AclBinding, ApiException]() + val filters = aclBindingFilters.asScala.zipWithIndex + lock synchronized { + // Find all potentially matching resource patterns from the provided filters and ACL cache and apply the filters + val resources = aclCache.keys ++ filters.map(_._1.patternFilter).filter(_.matchesAtMostOne).flatMap(filterToResources) + val resourcesToUpdate = resources.map { resource => + val matchingFilters = filters.filter { case (filter, _) => + filter.patternFilter.matches(resource) + } + resource -> matchingFilters + }.toMap.filter(_._2.nonEmpty) + + resourcesToUpdate.forKeyValue { (resource, matchingFilters) => + val resourceBindingsBeingDeleted = new mutable.HashMap[AclBinding, Int]() + try { + updateResourceAcls(resource) { currentAcls => + val aclsToRemove = currentAcls.filter { acl => + matchingFilters.exists { case (filter, index) => + val matches = filter.entryFilter.matches(acl) + if (matches) { + val binding = new AclBinding(resource, acl) + deletedBindings.getOrElseUpdate(binding, index) + resourceBindingsBeingDeleted.getOrElseUpdate(binding, index) + } + matches + } + } + currentAcls -- aclsToRemove + } + } catch { + case e: Exception => + resourceBindingsBeingDeleted.keys.foreach { binding => + deleteExceptions.getOrElseUpdate(binding, apiException(e)) + } + } + } + } + val deletedResult = deletedBindings.groupBy(_._2).map { case (k, bindings) => + k -> bindings.keys.map { binding => new AclBindingDeleteResult(binding, deleteExceptions.get(binding).orNull) } + } + (0 until aclBindingFilters.size).map { i => + new AclDeleteResult(deletedResult.getOrElse(i, Set.empty[AclBindingDeleteResult]).toSet.asJava) + }.map(CompletableFuture.completedFuture[AclDeleteResult]).asJava + } + + override def acls(filter: AclBindingFilter): lang.Iterable[AclBinding] = { + val aclBindings = new util.ArrayList[AclBinding]() + aclCache.forKeyValue { case (resource, versionedAcls) => + versionedAcls.acls.foreach { acl => + val binding = new AclBinding(resource, acl.ace) + if (filter.matches(binding)) + aclBindings.add(binding) + } + } + aclBindings + } + + override def close(): Unit = { + aclChangeListeners.foreach(listener => listener.close()) + if (zkClient != null) zkClient.close() + } + + override def authorizeByResourceType(requestContext: AuthorizableRequestContext, + op: AclOperation, + resourceType: ResourceType): AuthorizationResult = { + SecurityUtils.authorizeByResourceTypeCheckArgs(op, resourceType) + + val principal = new KafkaPrincipal( + requestContext.principal().getPrincipalType, + requestContext.principal().getName) + + if (isSuperUser(principal)) + return AuthorizationResult.ALLOWED + + val resourceSnapshot = resourceCache + val principalStr = principal.toString + val host = requestContext.clientAddress().getHostAddress + val action = new Action(op, new ResourcePattern(resourceType, "NONE", PatternType.UNKNOWN), 0, true, true) + + val denyLiterals = matchingResources( + resourceSnapshot, principalStr, host, op, AclPermissionType.DENY, resourceType, PatternType.LITERAL) + + if (denyAll(denyLiterals)) { + logAuditMessage(requestContext, action, authorized = false) + return AuthorizationResult.DENIED + } + + if (shouldAllowEveryoneIfNoAclIsFound) { + logAuditMessage(requestContext, 
action, authorized = true) + return AuthorizationResult.ALLOWED + } + + val denyPrefixes = matchingResources( + resourceSnapshot, principalStr, host, op, AclPermissionType.DENY, resourceType, PatternType.PREFIXED) + + if (denyLiterals.isEmpty && denyPrefixes.isEmpty) { + if (hasMatchingResources(resourceSnapshot, principalStr, host, op, AclPermissionType.ALLOW, resourceType, PatternType.PREFIXED) + || hasMatchingResources(resourceSnapshot, principalStr, host, op, AclPermissionType.ALLOW, resourceType, PatternType.LITERAL)) { + logAuditMessage(requestContext, action, authorized = true) + return AuthorizationResult.ALLOWED + } else { + logAuditMessage(requestContext, action, authorized = false) + return AuthorizationResult.DENIED + } + } + + val allowLiterals = matchingResources( + resourceSnapshot, principalStr, host, op, AclPermissionType.ALLOW, resourceType, PatternType.LITERAL) + val allowPrefixes = matchingResources( + resourceSnapshot, principalStr, host, op, AclPermissionType.ALLOW, resourceType, PatternType.PREFIXED) + + if (allowAny(allowLiterals, allowPrefixes, denyLiterals, denyPrefixes)) { + logAuditMessage(requestContext, action, authorized = true) + return AuthorizationResult.ALLOWED + } + + logAuditMessage(requestContext, action, authorized = false) + AuthorizationResult.DENIED + } + + private def matchingResources(resourceSnapshot: immutable.Map[ResourceTypeKey, immutable.Set[String]], + principal: String, host: String, op: AclOperation, permission: AclPermissionType, + resourceType: ResourceType, patternType: PatternType): ArrayBuffer[Set[String]] = { + val matched = ArrayBuffer[immutable.Set[String]]() + for (p <- Set(principal, AclEntry.WildcardPrincipalString); + h <- Set(host, AclEntry.WildcardHost); + o <- Set(op, AclOperation.ALL)) { + val resourceTypeKey = ResourceTypeKey( + new AccessControlEntry(p, h, o, permission), resourceType, patternType) + resourceSnapshot.get(resourceTypeKey) match { + case Some(resources) => matched += resources + case None => + } + } + matched + } + + private def hasMatchingResources(resourceSnapshot: immutable.Map[ResourceTypeKey, immutable.Set[String]], + principal: String, host: String, op: AclOperation, permission: AclPermissionType, + resourceType: ResourceType, patternType: PatternType): Boolean = { + for (p <- Set(principal, AclEntry.WildcardPrincipalString); + h <- Set(host, AclEntry.WildcardHost); + o <- Set(op, AclOperation.ALL)) { + val resourceTypeKey = ResourceTypeKey( + new AccessControlEntry(p, h, o, permission), resourceType, patternType) + if (resourceSnapshot.contains(resourceTypeKey)) + return true + } + false + } + + private def denyAll(denyLiterals: ArrayBuffer[immutable.Set[String]]): Boolean = + denyLiterals.exists(_.contains(ResourcePattern.WILDCARD_RESOURCE)) + + + private def allowAny(allowLiterals: ArrayBuffer[immutable.Set[String]], allowPrefixes: ArrayBuffer[immutable.Set[String]], + denyLiterals: ArrayBuffer[immutable.Set[String]], denyPrefixes: ArrayBuffer[immutable.Set[String]]): Boolean = { + (allowPrefixes.exists(_.exists(prefix => allowPrefix(prefix, denyPrefixes))) + || allowLiterals.exists(_.exists(literal => allowLiteral(literal, denyLiterals, denyPrefixes)))) + } + + private def allowLiteral(literalName: String, denyLiterals: ArrayBuffer[immutable.Set[String]], + denyPrefixes: ArrayBuffer[immutable.Set[String]]): Boolean = { + literalName match { + case ResourcePattern.WILDCARD_RESOURCE => true + case _ => !denyLiterals.exists(_.contains(literalName)) && !hasDominantPrefixedDeny(literalName, 
denyPrefixes) + } + } + + private def allowPrefix(prefixName: String, + denyPrefixes: ArrayBuffer[immutable.Set[String]]): Boolean = { + !hasDominantPrefixedDeny(prefixName, denyPrefixes) + } + + private def hasDominantPrefixedDeny(resourceName: String, denyPrefixes: ArrayBuffer[immutable.Set[String]]): Boolean = { + val sb = new StringBuilder + for (ch <- resourceName.toCharArray) { + sb.append(ch) + if (denyPrefixes.exists(p => p.contains(sb.toString()))) { + return true + } + } + false + } + + + private def authorizeAction(requestContext: AuthorizableRequestContext, action: Action): AuthorizationResult = { + val resource = action.resourcePattern + if (resource.patternType != PatternType.LITERAL) { + throw new IllegalArgumentException("Only literal resources are supported. Got: " + resource.patternType) + } + + // ensure we compare identical classes + val sessionPrincipal = requestContext.principal + val principal = if (classOf[KafkaPrincipal] != sessionPrincipal.getClass) + new KafkaPrincipal(sessionPrincipal.getPrincipalType, sessionPrincipal.getName) + else + sessionPrincipal + + val host = requestContext.clientAddress.getHostAddress + val operation = action.operation + + def isEmptyAclAndAuthorized(acls: AclSeqs): Boolean = { + if (acls.isEmpty) { + // No ACLs found for this resource, permission is determined by value of config allow.everyone.if.no.acl.found + authorizerLogger.debug(s"No acl found for resource $resource, authorized = $shouldAllowEveryoneIfNoAclIsFound") + shouldAllowEveryoneIfNoAclIsFound + } else false + } + + def denyAclExists(acls: AclSeqs): Boolean = { + // Check if there are any Deny ACLs which would forbid this operation. + matchingAclExists(operation, resource, principal, host, DENY, acls) + } + + def allowAclExists(acls: AclSeqs): Boolean = { + // Check if there are any Allow ACLs which would allow this operation. + // Allowing read, write, delete, or alter implies allowing describe. + // See #{org.apache.kafka.common.acl.AclOperation} for more details about ACL inheritance. + val allowOps = operation match { + case DESCRIBE => Set[AclOperation](DESCRIBE, READ, WRITE, DELETE, ALTER) + case DESCRIBE_CONFIGS => Set[AclOperation](DESCRIBE_CONFIGS, ALTER_CONFIGS) + case _ => Set[AclOperation](operation) + } + allowOps.exists(operation => matchingAclExists(operation, resource, principal, host, ALLOW, acls)) + } + + def aclsAllowAccess = { + // we allow an operation if no acls are found and user has configured to allow all users + // when no acls are found or if no deny acls are found and at least one allow acls matches. + val acls = matchingAcls(resource.resourceType, resource.name) + isEmptyAclAndAuthorized(acls) || (!denyAclExists(acls) && allowAclExists(acls)) + } + + // Evaluate if operation is allowed + val authorized = isSuperUser(principal) || aclsAllowAccess + + logAuditMessage(requestContext, action, authorized) + if (authorized) AuthorizationResult.ALLOWED else AuthorizationResult.DENIED + } + + def isSuperUser(principal: KafkaPrincipal): Boolean = { + if (superUsers.contains(principal)) { + authorizerLogger.debug(s"principal = $principal is a super user, allowing operation without checking acls.") + true + } else false + } + + @nowarn("cat=deprecation") + private def matchingAcls(resourceType: ResourceType, resourceName: String): AclSeqs = { + // this code is performance sensitive, make sure to run AclAuthorizerBenchmark after any changes + + // save aclCache reference to a local val to get a consistent view of the cache during acl updates. 
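+    // Illustrative note with a hypothetical resource name: for resourceName = "payments.orders", the lookups
+    // below gather (a) wildcard LITERAL acls ("*"), (b) exact LITERAL acls for "payments.orders", and
+    // (c) PREFIXED acls whose names sort between "payments.orders" and "p" and are true prefixes of it
+    // (e.g. "payments."), relying on the reverse-by-name ResourceOrdering of the acl cache.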
+ val aclCacheSnapshot = aclCache + val wildcard = aclCacheSnapshot.get(new ResourcePattern(resourceType, ResourcePattern.WILDCARD_RESOURCE, PatternType.LITERAL)) + .map(_.acls.toBuffer) + .getOrElse(mutable.Buffer.empty) + + val literal = aclCacheSnapshot.get(new ResourcePattern(resourceType, resourceName, PatternType.LITERAL)) + .map(_.acls.toBuffer) + .getOrElse(mutable.Buffer.empty) + + val prefixed = new ArrayBuffer[AclEntry] + aclCacheSnapshot + .from(new ResourcePattern(resourceType, resourceName, PatternType.PREFIXED)) + .to(new ResourcePattern(resourceType, resourceName.take(1), PatternType.PREFIXED)) + .forKeyValue { (resource, acls) => + if (resourceName.startsWith(resource.name)) prefixed ++= acls.acls + } + + new AclSeqs(prefixed, wildcard, literal) + } + + private def matchingAclExists(operation: AclOperation, + resource: ResourcePattern, + principal: KafkaPrincipal, + host: String, + permissionType: AclPermissionType, + acls: AclSeqs): Boolean = { + acls.find { acl => + acl.permissionType == permissionType && + (acl.kafkaPrincipal == principal || acl.kafkaPrincipal == AclEntry.WildcardPrincipal) && + (operation == acl.operation || acl.operation == AclOperation.ALL) && + (acl.host == host || acl.host == AclEntry.WildcardHost) + }.exists { acl => + authorizerLogger.debug(s"operation = $operation on resource = $resource from host = $host is $permissionType based on acl = $acl") + true + } + } + + private def loadCache(): Unit = { + lock synchronized { + ZkAclStore.stores.foreach(store => { + val resourceTypes = zkClient.getResourceTypes(store.patternType) + for (rType <- resourceTypes) { + val resourceType = Try(SecurityUtils.resourceType(rType)) + resourceType match { + case Success(resourceTypeObj) => + val resourceNames = zkClient.getResourceNames(store.patternType, resourceTypeObj) + for (resourceName <- resourceNames) { + val resource = new ResourcePattern(resourceTypeObj, resourceName, store.patternType) + val versionedAcls = getAclsFromZk(resource) + updateCache(resource, versionedAcls) + } + case Failure(_) => warn(s"Ignoring unknown ResourceType: $rType") + } + } + }) + } + } + + private[authorizer] def startZkChangeListeners(): Unit = { + aclChangeListeners = ZkAclChangeStore.stores + .map(store => store.createListener(AclChangedNotificationHandler, zkClient)) + } + + private def filterToResources(filter: ResourcePatternFilter): Set[ResourcePattern] = { + filter.patternType match { + case PatternType.LITERAL | PatternType.PREFIXED => + Set(new ResourcePattern(filter.resourceType, filter.name, filter.patternType)) + case PatternType.ANY => + Set(new ResourcePattern(filter.resourceType, filter.name, PatternType.LITERAL), + new ResourcePattern(filter.resourceType, filter.name, PatternType.PREFIXED)) + case _ => throw new IllegalArgumentException(s"Cannot determine matching resources for patternType $filter") + } + } + + def logAuditMessage(requestContext: AuthorizableRequestContext, action: Action, authorized: Boolean): Unit = { + def logMessage: String = { + val principal = requestContext.principal + val operation = SecurityUtils.operationName(action.operation) + val host = requestContext.clientAddress.getHostAddress + val resourceType = SecurityUtils.resourceTypeName(action.resourcePattern.resourceType) + val resource = s"$resourceType$ResourceSeparator${action.resourcePattern.patternType}$ResourceSeparator${action.resourcePattern.name}" + val authResult = if (authorized) "Allowed" else "Denied" + val apiKey = if (ApiKeys.hasId(requestContext.requestType)) 
ApiKeys.forId(requestContext.requestType).name else requestContext.requestType + val refCount = action.resourceReferenceCount + + s"Principal = $principal is $authResult Operation = $operation " + + s"from host = $host on resource = $resource for request = $apiKey with resourceRefCount = $refCount" + } + + if (authorized) { + // logIfAllowed is true if access is granted to the resource as a result of this authorization. + // In this case, log at debug level. If false, no access is actually granted, the result is used + // only to determine authorized operations. So log only at trace level. + if (action.logIfAllowed) + authorizerLogger.debug(logMessage) + else + authorizerLogger.trace(logMessage) + } else { + // logIfDenied is true if access to the resource was explicitly requested. Since this is an attempt + // to access unauthorized resources, log at info level. If false, this is either a request to determine + // authorized operations or a filter (e.g for regex subscriptions) to filter out authorized resources. + // In this case, log only at trace level. + if (action.logIfDenied) + authorizerLogger.info(logMessage) + else + authorizerLogger.trace(logMessage) + } + } + + /** + * Safely updates the resources ACLs by ensuring reads and writes respect the expected zookeeper version. + * Continues to retry until it successfully updates zookeeper. + * + * Returns a boolean indicating if the content of the ACLs was actually changed. + * + * @param resource the resource to change ACLs for + * @param getNewAcls function to transform existing acls to new ACLs + * @return boolean indicating if a change was made + */ + private def updateResourceAcls(resource: ResourcePattern)(getNewAcls: Set[AclEntry] => Set[AclEntry]): Boolean = { + var currentVersionedAcls = + if (aclCache.contains(resource)) + getAclsFromCache(resource) + else + getAclsFromZk(resource) + var newVersionedAcls: VersionedAcls = null + var writeComplete = false + var retries = 0 + while (!writeComplete && retries <= maxUpdateRetries) { + val newAcls = getNewAcls(currentVersionedAcls.acls) + val (updateSucceeded, updateVersion) = + if (newAcls.nonEmpty) { + if (currentVersionedAcls.exists) + zkClient.conditionalSetAclsForResource(resource, newAcls, currentVersionedAcls.zkVersion) + else + zkClient.createAclsForResourceIfNotExists(resource, newAcls) + } else { + trace(s"Deleting path for $resource because it had no ACLs remaining") + (zkClient.conditionalDelete(resource, currentVersionedAcls.zkVersion), 0) + } + + if (!updateSucceeded) { + trace(s"Failed to update ACLs for $resource. Used version ${currentVersionedAcls.zkVersion}. 
Reading data and retrying update.") + Thread.sleep(backoffTime) + currentVersionedAcls = getAclsFromZk(resource) + retries += 1 + } else { + newVersionedAcls = VersionedAcls(newAcls, updateVersion) + writeComplete = updateSucceeded + } + } + + if (!writeComplete) + throw new IllegalStateException(s"Failed to update ACLs for $resource after trying a maximum of $maxUpdateRetries times") + + if (newVersionedAcls.acls != currentVersionedAcls.acls) { + info(s"Updated ACLs for $resource with new version ${newVersionedAcls.zkVersion}") + debug(s"Updated ACLs for $resource to $newVersionedAcls") + updateCache(resource, newVersionedAcls) + updateAclChangedFlag(resource) + true + } else { + debug(s"Updated ACLs for $resource, no change was made") + updateCache(resource, newVersionedAcls) // Even if no change, update the version + false + } + } + + private def getAclsFromCache(resource: ResourcePattern): VersionedAcls = { + aclCache.getOrElse(resource, throw new IllegalArgumentException(s"ACLs do not exist in the cache for resource $resource")) + } + + private def getAclsFromZk(resource: ResourcePattern): VersionedAcls = { + zkClient.getVersionedAclsForResource(resource) + } + + // Visible for benchmark + def updateCache(resource: ResourcePattern, versionedAcls: VersionedAcls): Unit = { + val currentAces: Set[AccessControlEntry] = aclCache.get(resource).map(_.acls.map(_.ace)).getOrElse(Set.empty) + val newAces: Set[AccessControlEntry] = versionedAcls.acls.map(aclEntry => aclEntry.ace) + val acesToAdd = newAces.diff(currentAces) + val acesToRemove = currentAces.diff(newAces) + + acesToAdd.foreach(ace => { + val resourceTypeKey = ResourceTypeKey(ace, resource.resourceType(), resource.patternType()) + resourceCache.get(resourceTypeKey) match { + case Some(resources) => resourceCache += (resourceTypeKey -> (resources + resource.name())) + case None => resourceCache += (resourceTypeKey -> immutable.HashSet(resource.name())) + } + }) + acesToRemove.foreach(ace => { + val resourceTypeKey = ResourceTypeKey(ace, resource.resourceType(), resource.patternType()) + resourceCache.get(resourceTypeKey) match { + case Some(resources) => + val newResources = resources - resource.name() + if (newResources.isEmpty) { + resourceCache -= resourceTypeKey + } else { + resourceCache += (resourceTypeKey -> newResources) + } + case None => + } + }) + + if (versionedAcls.acls.nonEmpty) { + aclCache = aclCache.updated(resource, versionedAcls) + } else { + aclCache -= resource + } + } + + private def updateAclChangedFlag(resource: ResourcePattern): Unit = { + zkClient.createAclChangeNotification(resource) + } + + private def backoffTime = { + retryBackoffMs + Random.nextInt(retryBackoffJitterMs) + } + + private def apiException(e: Throwable): ApiException = { + e match { + case e1: ApiException => e1 + case e1 => new ApiException(e1) + } + } + + private[authorizer] def processAclChangeNotification(resource: ResourcePattern): Unit = { + lock synchronized { + val versionedAcls = getAclsFromZk(resource) + info(s"Processing Acl change notification for $resource, versionedAcls : ${versionedAcls.acls}, zkVersion : ${versionedAcls.zkVersion}") + updateCache(resource, versionedAcls) + } + } + + object AclChangedNotificationHandler extends AclChangeNotificationHandler { + override def processNotification(resource: ResourcePattern): Unit = { + processAclChangeNotification(resource) + } + } + + private case class ResourceTypeKey(ace: AccessControlEntry, + resourceType: ResourceType, + patternType: PatternType) +} diff --git 
a/core/src/main/scala/kafka/security/authorizer/AclEntry.scala b/core/src/main/scala/kafka/security/authorizer/AclEntry.scala new file mode 100644 index 0000000..2014916 --- /dev/null +++ b/core/src/main/scala/kafka/security/authorizer/AclEntry.scala @@ -0,0 +1,146 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.security.authorizer + +import kafka.utils.Json +import org.apache.kafka.common.acl.{AccessControlEntry, AclOperation, AclPermissionType} +import org.apache.kafka.common.acl.AclOperation.{READ, WRITE, CREATE, DESCRIBE, DELETE, ALTER, DESCRIBE_CONFIGS, ALTER_CONFIGS, CLUSTER_ACTION, IDEMPOTENT_WRITE} +import org.apache.kafka.common.protocol.Errors +import org.apache.kafka.common.resource.{ResourcePattern, ResourceType} +import org.apache.kafka.common.security.auth.KafkaPrincipal +import org.apache.kafka.common.utils.SecurityUtils + +import scala.jdk.CollectionConverters._ + +object AclEntry { + val WildcardPrincipal: KafkaPrincipal = new KafkaPrincipal(KafkaPrincipal.USER_TYPE, "*") + val WildcardPrincipalString: String = WildcardPrincipal.toString + val WildcardHost: String = "*" + val WildcardResource: String = ResourcePattern.WILDCARD_RESOURCE + + val ResourceSeparator = ":" + val ResourceTypes: Set[ResourceType] = ResourceType.values.toSet + .filterNot(t => t == ResourceType.UNKNOWN || t == ResourceType.ANY) + val AclOperations: Set[AclOperation] = AclOperation.values.toSet + .filterNot(t => t == AclOperation.UNKNOWN || t == AclOperation.ANY) + + val PrincipalKey = "principal" + val PermissionTypeKey = "permissionType" + val OperationKey = "operation" + val HostsKey = "host" + val VersionKey = "version" + val CurrentVersion = 1 + val AclsKey = "acls" + + def apply(principal: KafkaPrincipal, + permissionType: AclPermissionType, + host: String, + operation: AclOperation): AclEntry = { + new AclEntry(new AccessControlEntry(if (principal == null) null else principal.toString, + host, operation, permissionType)) + } + + /** + * Parse JSON representation of ACLs + * @param bytes of acls json string + * + *

+ * <pre>
+ * {
+ *   "version": 1,
+ *   "acls": [
+ *     {
+ *       "host":"host1",
+ *       "permissionType": "Deny",
+ *       "operation": "Read",
+ *       "principal": "User:alice"
+ *     }
+ *   ]
+ * }
+ * </pre>
            + * + * @return set of AclEntry objects from the JSON string + */ + def fromBytes(bytes: Array[Byte]): Set[AclEntry] = { + if (bytes == null || bytes.isEmpty) + return collection.immutable.Set.empty[AclEntry] + + Json.parseBytes(bytes).map(_.asJsonObject).map { js => + //the acl json version. + require(js(VersionKey).to[Int] == CurrentVersion) + js(AclsKey).asJsonArray.iterator.map(_.asJsonObject).map { itemJs => + val principal = SecurityUtils.parseKafkaPrincipal(itemJs(PrincipalKey).to[String]) + val permissionType = SecurityUtils.permissionType(itemJs(PermissionTypeKey).to[String]) + val host = itemJs(HostsKey).to[String] + val operation = SecurityUtils.operation(itemJs(OperationKey).to[String]) + AclEntry(principal, permissionType, host, operation) + }.toSet + }.getOrElse(Set.empty) + } + + def toJsonCompatibleMap(acls: Set[AclEntry]): Map[String, Any] = { + Map(AclEntry.VersionKey -> AclEntry.CurrentVersion, AclEntry.AclsKey -> acls.map(acl => acl.toMap.asJava).toList.asJava) + } + + def supportedOperations(resourceType: ResourceType): Set[AclOperation] = { + resourceType match { + case ResourceType.TOPIC => Set(READ, WRITE, CREATE, DESCRIBE, DELETE, ALTER, DESCRIBE_CONFIGS, ALTER_CONFIGS) + case ResourceType.GROUP => Set(READ, DESCRIBE, DELETE) + case ResourceType.CLUSTER => Set(CREATE, CLUSTER_ACTION, DESCRIBE_CONFIGS, ALTER_CONFIGS, IDEMPOTENT_WRITE, ALTER, DESCRIBE) + case ResourceType.TRANSACTIONAL_ID => Set(DESCRIBE, WRITE) + case ResourceType.DELEGATION_TOKEN => Set(DESCRIBE) + case _ => throw new IllegalArgumentException("Not a concrete resource type") + } + } + + def authorizationError(resourceType: ResourceType): Errors = { + resourceType match { + case ResourceType.TOPIC => Errors.TOPIC_AUTHORIZATION_FAILED + case ResourceType.GROUP => Errors.GROUP_AUTHORIZATION_FAILED + case ResourceType.CLUSTER => Errors.CLUSTER_AUTHORIZATION_FAILED + case ResourceType.TRANSACTIONAL_ID => Errors.TRANSACTIONAL_ID_AUTHORIZATION_FAILED + case ResourceType.DELEGATION_TOKEN => Errors.DELEGATION_TOKEN_AUTHORIZATION_FAILED + case _ => throw new IllegalArgumentException("Authorization error type not known") + } + } +} + +class AclEntry(val ace: AccessControlEntry) + extends AccessControlEntry(ace.principal, ace.host, ace.operation, ace.permissionType) { + + val kafkaPrincipal: KafkaPrincipal = if (principal == null) + null + else + SecurityUtils.parseKafkaPrincipal(principal) + + def toMap: Map[String, Any] = { + Map(AclEntry.PrincipalKey -> principal, + AclEntry.PermissionTypeKey -> SecurityUtils.permissionTypeName(permissionType), + AclEntry.OperationKey -> SecurityUtils.operationName(operation), + AclEntry.HostsKey -> host) + } + + override def hashCode(): Int = ace.hashCode() + + override def equals(o: scala.Any): Boolean = super.equals(o) // to keep spotbugs happy + + override def toString: String = { + "%s has %s permission for operations: %s from hosts: %s".format(principal, permissionType.name, operation, host) + } + +} + diff --git a/core/src/main/scala/kafka/security/authorizer/AuthorizerUtils.scala b/core/src/main/scala/kafka/security/authorizer/AuthorizerUtils.scala new file mode 100644 index 0000000..0e417d6 --- /dev/null +++ b/core/src/main/scala/kafka/security/authorizer/AuthorizerUtils.scala @@ -0,0 +1,47 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.security.authorizer + +import java.net.InetAddress + +import kafka.network.RequestChannel.Session +import org.apache.kafka.common.resource.Resource +import org.apache.kafka.common.security.auth.{KafkaPrincipal, SecurityProtocol} +import org.apache.kafka.common.utils.Utils +import org.apache.kafka.server.authorizer.{AuthorizableRequestContext, Authorizer} + + +object AuthorizerUtils { + + def createAuthorizer(className: String): Authorizer = Utils.newInstance(className, classOf[Authorizer]) + + def isClusterResource(name: String): Boolean = name.equals(Resource.CLUSTER_NAME) + + def sessionToRequestContext(session: Session): AuthorizableRequestContext = { + new AuthorizableRequestContext { + override def clientId(): String = "" + override def requestType(): Int = -1 + override def listenerName(): String = "" + override def clientAddress(): InetAddress = session.clientAddress + override def principal(): KafkaPrincipal = session.principal + override def securityProtocol(): SecurityProtocol = null + override def correlationId(): Int = -1 + override def requestVersion(): Int = -1 + } + } +} diff --git a/core/src/main/scala/kafka/serializer/Decoder.scala b/core/src/main/scala/kafka/serializer/Decoder.scala new file mode 100644 index 0000000..4b8c545 --- /dev/null +++ b/core/src/main/scala/kafka/serializer/Decoder.scala @@ -0,0 +1,73 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.serializer + +import java.nio.ByteBuffer +import java.nio.charset.StandardCharsets + +import kafka.utils.VerifiableProperties + +/** + * A decoder is a method of turning byte arrays into objects. + * An implementation is required to provide a constructor that + * takes a VerifiableProperties instance. + */ +trait Decoder[T] { + def fromBytes(bytes: Array[Byte]): T +} + +/** + * The default implementation does nothing, just returns the same byte array it takes in. + */ +class DefaultDecoder(props: VerifiableProperties = null) extends Decoder[Array[Byte]] { + def fromBytes(bytes: Array[Byte]): Array[Byte] = bytes +} + +/** + * The string decoder translates bytes into strings. 
It uses UTF8 by default but takes + * an optional property serializer.encoding to control this. + */ +class StringDecoder(props: VerifiableProperties = null) extends Decoder[String] { + val encoding = + if(props == null) + StandardCharsets.UTF_8.name() + else + props.getString("serializer.encoding", StandardCharsets.UTF_8.name()) + + def fromBytes(bytes: Array[Byte]): String = { + new String(bytes, encoding) + } +} + +/** + * The long decoder translates bytes into longs. + */ +class LongDecoder(props: VerifiableProperties = null) extends Decoder[Long] { + def fromBytes(bytes: Array[Byte]): Long = { + ByteBuffer.wrap(bytes).getLong + } +} + +/** + * The integer decoder translates bytes into integers. + */ +class IntegerDecoder(props: VerifiableProperties = null) extends Decoder[Integer] { + def fromBytes(bytes: Array[Byte]): Integer = { + ByteBuffer.wrap(bytes).getInt() + } +} diff --git a/core/src/main/scala/kafka/server/AbstractFetcherManager.scala b/core/src/main/scala/kafka/server/AbstractFetcherManager.scala new file mode 100755 index 0000000..7780535 --- /dev/null +++ b/core/src/main/scala/kafka/server/AbstractFetcherManager.scala @@ -0,0 +1,260 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.server + +import kafka.cluster.BrokerEndPoint +import kafka.metrics.KafkaMetricsGroup +import kafka.utils.Implicits._ +import kafka.utils.Logging +import org.apache.kafka.common.{TopicPartition, Uuid} +import org.apache.kafka.common.utils.Utils + +import scala.collection.{Map, Set, mutable} + +abstract class AbstractFetcherManager[T <: AbstractFetcherThread](val name: String, clientId: String, numFetchers: Int) + extends Logging with KafkaMetricsGroup { + // map of (source broker_id, fetcher_id per source broker) => fetcher. 
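Backing up briefly to the decoders defined above, a small usage sketch (not part of the patch, and assuming the usual `VerifiableProperties(java.util.Properties)` constructor) of how `StringDecoder` picks up the `serializer.encoding` override:

```scala
import java.nio.charset.StandardCharsets
import java.util.Properties

import kafka.serializer.StringDecoder
import kafka.utils.VerifiableProperties

object StringDecoderExample {
  def main(args: Array[String]): Unit = {
    val props = new Properties()
    // Without this property the decoder falls back to UTF-8.
    props.put("serializer.encoding", StandardCharsets.ISO_8859_1.name())

    val decoder = new StringDecoder(new VerifiableProperties(props))
    println(decoder.fromBytes("café".getBytes(StandardCharsets.ISO_8859_1)))
  }
}
```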
+ // package private for test + private[server] val fetcherThreadMap = new mutable.HashMap[BrokerIdAndFetcherId, T] + private val lock = new Object + private var numFetchersPerBroker = numFetchers + val failedPartitions = new FailedPartitions + this.logIdent = "[" + name + "] " + + private val tags = Map("clientId" -> clientId) + + newGauge("MaxLag", () => { + // current max lag across all fetchers/topics/partitions + fetcherThreadMap.values.foldLeft(0L) { (curMaxLagAll, fetcherThread) => + val maxLagThread = fetcherThread.fetcherLagStats.stats.values.foldLeft(0L)((curMaxLagThread, lagMetrics) => + math.max(curMaxLagThread, lagMetrics.lag)) + math.max(curMaxLagAll, maxLagThread) + } + }, tags) + + newGauge("MinFetchRate", () => { + // current min fetch rate across all fetchers/topics/partitions + val headRate = fetcherThreadMap.values.headOption.map(_.fetcherStats.requestRate.oneMinuteRate).getOrElse(0.0) + fetcherThreadMap.values.foldLeft(headRate)((curMinAll, fetcherThread) => + math.min(curMinAll, fetcherThread.fetcherStats.requestRate.oneMinuteRate)) + }, tags) + + newGauge("FailedPartitionsCount", () => failedPartitions.size, tags) + + newGauge("DeadThreadCount", () => deadThreadCount, tags) + + private[server] def deadThreadCount: Int = lock synchronized { fetcherThreadMap.values.count(_.isThreadFailed) } + + def resizeThreadPool(newSize: Int): Unit = { + def migratePartitions(newSize: Int): Unit = { + fetcherThreadMap.forKeyValue { (id, thread) => + val partitionStates = removeFetcherForPartitions(thread.partitions) + if (id.fetcherId >= newSize) + thread.shutdown() + val fetchStates = partitionStates.map { case (topicPartition, currentFetchState) => + val initialFetchState = InitialFetchState(currentFetchState.topicId, thread.sourceBroker, + currentLeaderEpoch = currentFetchState.currentLeaderEpoch, + initOffset = currentFetchState.fetchOffset) + topicPartition -> initialFetchState + } + addFetcherForPartitions(fetchStates) + } + } + lock synchronized { + val currentSize = numFetchersPerBroker + info(s"Resizing fetcher thread pool size from $currentSize to $newSize") + numFetchersPerBroker = newSize + if (newSize != currentSize) { + // We could just migrate some partitions explicitly to new threads. But this is currently + // reassigning all partitions using the new thread size so that hash-based allocation + // works with partition add/delete as it did before. 
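The comment above ("reassigning all partitions using the new thread size so that hash-based allocation works") is easiest to see with a concrete partition: the owning fetcher is a pure function of the topic-partition hash and the pool size, exactly the arithmetic used by `getFetcherId` below. A sketch outside the patch, with an invented topic name:

```scala
import org.apache.kafka.common.TopicPartition
import org.apache.kafka.common.utils.Utils

object FetcherIdExample {
  // Same computation as getFetcherId: hash the topic-partition and mod by the pool size.
  def fetcherId(tp: TopicPartition, numFetchers: Int): Int =
    Utils.abs(31 * tp.topic.hashCode + tp.partition) % numFetchers

  def main(args: Array[String]): Unit = {
    val tp = new TopicPartition("orders", 7)
    println(fetcherId(tp, numFetchers = 2)) // owning fetcher id with a pool of two threads
    println(fetcherId(tp, numFetchers = 4)) // usually a different id once the pool is resized
  }
}
```

Because the mapping changes wholesale when `numFetchersPerBroker` changes, `resizeThreadPool` removes and re-adds every partition rather than migrating a chosen subset.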
+ migratePartitions(newSize) + } + shutdownIdleFetcherThreads() + } + } + + // Visible for testing + private[server] def getFetcher(topicPartition: TopicPartition): Option[T] = { + lock synchronized { + fetcherThreadMap.values.find { fetcherThread => + fetcherThread.fetchState(topicPartition).isDefined + } + } + } + + // Visibility for testing + private[server] def getFetcherId(topicPartition: TopicPartition): Int = { + lock synchronized { + Utils.abs(31 * topicPartition.topic.hashCode() + topicPartition.partition) % numFetchersPerBroker + } + } + + // This method is only needed by ReplicaAlterDirManager + def markPartitionsForTruncation(brokerId: Int, topicPartition: TopicPartition, truncationOffset: Long): Unit = { + lock synchronized { + val fetcherId = getFetcherId(topicPartition) + val brokerIdAndFetcherId = BrokerIdAndFetcherId(brokerId, fetcherId) + fetcherThreadMap.get(brokerIdAndFetcherId).foreach { thread => + thread.markPartitionsForTruncation(topicPartition, truncationOffset) + } + } + } + + // to be defined in subclass to create a specific fetcher + def createFetcherThread(fetcherId: Int, sourceBroker: BrokerEndPoint): T + + def addFetcherForPartitions(partitionAndOffsets: Map[TopicPartition, InitialFetchState]): Unit = { + lock synchronized { + val partitionsPerFetcher = partitionAndOffsets.groupBy { case (topicPartition, brokerAndInitialFetchOffset) => + BrokerAndFetcherId(brokerAndInitialFetchOffset.leader, getFetcherId(topicPartition)) + } + + def addAndStartFetcherThread(brokerAndFetcherId: BrokerAndFetcherId, + brokerIdAndFetcherId: BrokerIdAndFetcherId): T = { + val fetcherThread = createFetcherThread(brokerAndFetcherId.fetcherId, brokerAndFetcherId.broker) + fetcherThreadMap.put(brokerIdAndFetcherId, fetcherThread) + fetcherThread.start() + fetcherThread + } + + for ((brokerAndFetcherId, initialFetchOffsets) <- partitionsPerFetcher) { + val brokerIdAndFetcherId = BrokerIdAndFetcherId(brokerAndFetcherId.broker.id, brokerAndFetcherId.fetcherId) + val fetcherThread = fetcherThreadMap.get(brokerIdAndFetcherId) match { + case Some(currentFetcherThread) if currentFetcherThread.sourceBroker == brokerAndFetcherId.broker => + // reuse the fetcher thread + currentFetcherThread + case Some(f) => + f.shutdown() + addAndStartFetcherThread(brokerAndFetcherId, brokerIdAndFetcherId) + case None => + addAndStartFetcherThread(brokerAndFetcherId, brokerIdAndFetcherId) + } + + addPartitionsToFetcherThread(fetcherThread, initialFetchOffsets) + } + } + } + + def addFailedPartition(topicPartition: TopicPartition): Unit = { + lock synchronized { + failedPartitions.add(topicPartition) + } + } + + protected def addPartitionsToFetcherThread(fetcherThread: T, + initialOffsetAndEpochs: collection.Map[TopicPartition, InitialFetchState]): Unit = { + fetcherThread.addPartitions(initialOffsetAndEpochs) + info(s"Added fetcher to broker ${fetcherThread.sourceBroker.id} for partitions $initialOffsetAndEpochs") + } + + /** + * If the fetcher and partition state exist, update all to include the topic ID + * + * @param partitionsToUpdate a mapping of partitions to be updated to their leader IDs + * @param topicIds the mappings from topic name to ID or None if it does not exist + */ + def maybeUpdateTopicIds(partitionsToUpdate: Map[TopicPartition, Int], topicIds: String => Option[Uuid]): Unit = { + lock synchronized { + val partitionsPerFetcher = partitionsToUpdate.groupBy { case (topicPartition, leaderId) => + BrokerIdAndFetcherId(leaderId, getFetcherId(topicPartition)) + }.map { case (brokerAndFetcherId, 
partitionsToUpdate) => + (brokerAndFetcherId, partitionsToUpdate.keySet) + } + + for ((brokerIdAndFetcherId, partitions) <- partitionsPerFetcher) { + fetcherThreadMap.get(brokerIdAndFetcherId).foreach(_.maybeUpdateTopicIds(partitions, topicIds)) + } + } + } + + def removeFetcherForPartitions(partitions: Set[TopicPartition]): Map[TopicPartition, PartitionFetchState] = { + val fetchStates = mutable.Map.empty[TopicPartition, PartitionFetchState] + lock synchronized { + for (fetcher <- fetcherThreadMap.values) + fetchStates ++= fetcher.removePartitions(partitions) + failedPartitions.removeAll(partitions) + } + if (partitions.nonEmpty) + info(s"Removed fetcher for partitions $partitions") + fetchStates + } + + def shutdownIdleFetcherThreads(): Unit = { + lock synchronized { + val keysToBeRemoved = new mutable.HashSet[BrokerIdAndFetcherId] + for ((key, fetcher) <- fetcherThreadMap) { + if (fetcher.partitionCount <= 0) { + fetcher.shutdown() + keysToBeRemoved += key + } + } + fetcherThreadMap --= keysToBeRemoved + } + } + + def closeAllFetchers(): Unit = { + lock synchronized { + for ((_, fetcher) <- fetcherThreadMap) { + fetcher.initiateShutdown() + } + + for ((_, fetcher) <- fetcherThreadMap) { + fetcher.shutdown() + } + fetcherThreadMap.clear() + } + } +} + +/** + * The class FailedPartitions would keep a track of partitions marked as failed either during truncation or appending + * resulting from one of the following errors - + *
+ *   1. Storage exception
+ *   2. Fenced epoch
+ *   3. Unexpected errors
            + * The partitions which fail due to storage error are eventually removed from this set after the log directory is + * taken offline. + */ +class FailedPartitions { + private val failedPartitionsSet = new mutable.HashSet[TopicPartition] + + def size: Int = synchronized { + failedPartitionsSet.size + } + + def add(topicPartition: TopicPartition): Unit = synchronized { + failedPartitionsSet += topicPartition + } + + def removeAll(topicPartitions: Set[TopicPartition]): Unit = synchronized { + failedPartitionsSet --= topicPartitions + } + + def contains(topicPartition: TopicPartition): Boolean = synchronized { + failedPartitionsSet.contains(topicPartition) + } +} + +case class BrokerAndFetcherId(broker: BrokerEndPoint, fetcherId: Int) + +case class InitialFetchState(topicId: Option[Uuid], leader: BrokerEndPoint, currentLeaderEpoch: Int, initOffset: Long) + +case class BrokerIdAndFetcherId(brokerId: Int, fetcherId: Int) diff --git a/core/src/main/scala/kafka/server/AbstractFetcherThread.scala b/core/src/main/scala/kafka/server/AbstractFetcherThread.scala new file mode 100755 index 0000000..492cec4 --- /dev/null +++ b/core/src/main/scala/kafka/server/AbstractFetcherThread.scala @@ -0,0 +1,910 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.server + +import kafka.cluster.BrokerEndPoint +import kafka.common.ClientIdAndBroker +import kafka.log.LogAppendInfo +import kafka.metrics.KafkaMetricsGroup +import kafka.server.AbstractFetcherThread.{ReplicaFetch, ResultWithPartitions} +import kafka.utils.CoreUtils.inLock +import kafka.utils.Implicits._ +import kafka.utils.{DelayedItem, Pool, ShutdownableThread} +import org.apache.kafka.common.errors._ +import org.apache.kafka.common.internals.PartitionStates +import org.apache.kafka.common.message.OffsetForLeaderEpochResponseData.EpochEndOffset +import org.apache.kafka.common.message.{FetchResponseData, OffsetForLeaderEpochRequestData} +import org.apache.kafka.common.protocol.Errors +import org.apache.kafka.common.record.{FileRecords, MemoryRecords, Records} +import org.apache.kafka.common.requests.OffsetsForLeaderEpochResponse.{UNDEFINED_EPOCH, UNDEFINED_EPOCH_OFFSET} +import org.apache.kafka.common.requests._ +import org.apache.kafka.common.{InvalidRecordException, TopicPartition, Uuid} + +import java.nio.ByteBuffer +import java.util +import java.util.Optional +import java.util.concurrent.TimeUnit +import java.util.concurrent.atomic.AtomicLong +import java.util.concurrent.locks.ReentrantLock +import scala.collection.{Map, Set, mutable} +import scala.compat.java8.OptionConverters._ +import scala.jdk.CollectionConverters._ +import scala.math._ + +/** + * Abstract class for fetching data from multiple partitions from the same broker. 
+ */ +abstract class AbstractFetcherThread(name: String, + clientId: String, + val sourceBroker: BrokerEndPoint, + failedPartitions: FailedPartitions, + fetchBackOffMs: Int = 0, + isInterruptible: Boolean = true, + val brokerTopicStats: BrokerTopicStats) //BrokerTopicStats's lifecycle managed by ReplicaManager + extends ShutdownableThread(name, isInterruptible) { + + type FetchData = FetchResponseData.PartitionData + type EpochData = OffsetForLeaderEpochRequestData.OffsetForLeaderPartition + + private val partitionStates = new PartitionStates[PartitionFetchState] + protected val partitionMapLock = new ReentrantLock + private val partitionMapCond = partitionMapLock.newCondition() + + private val metricId = ClientIdAndBroker(clientId, sourceBroker.host, sourceBroker.port) + val fetcherStats = new FetcherStats(metricId) + val fetcherLagStats = new FetcherLagStats(metricId) + + /* callbacks to be defined in subclass */ + + // process fetched data + protected def processPartitionData(topicPartition: TopicPartition, + fetchOffset: Long, + partitionData: FetchData): Option[LogAppendInfo] + + protected def truncate(topicPartition: TopicPartition, truncationState: OffsetTruncationState): Unit + + protected def truncateFullyAndStartAt(topicPartition: TopicPartition, offset: Long): Unit + + protected def buildFetch(partitionMap: Map[TopicPartition, PartitionFetchState]): ResultWithPartitions[Option[ReplicaFetch]] + + protected def latestEpoch(topicPartition: TopicPartition): Option[Int] + + protected def logStartOffset(topicPartition: TopicPartition): Long + + protected def logEndOffset(topicPartition: TopicPartition): Long + + protected def endOffsetForEpoch(topicPartition: TopicPartition, epoch: Int): Option[OffsetAndEpoch] + + protected def fetchEpochEndOffsets(partitions: Map[TopicPartition, EpochData]): Map[TopicPartition, EpochEndOffset] + + protected def fetchFromLeader(fetchRequest: FetchRequest.Builder): Map[TopicPartition, FetchData] + + protected def fetchEarliestOffsetFromLeader(topicPartition: TopicPartition, currentLeaderEpoch: Int): Long + + protected def fetchLatestOffsetFromLeader(topicPartition: TopicPartition, currentLeaderEpoch: Int): Long + + protected val isOffsetForLeaderEpochSupported: Boolean + + protected val isTruncationOnFetchSupported: Boolean + + override def shutdown(): Unit = { + initiateShutdown() + inLock(partitionMapLock) { + partitionMapCond.signalAll() + } + awaitShutdown() + + // we don't need the lock since the thread has finished shutdown and metric removal is safe + fetcherStats.unregister() + fetcherLagStats.unregister() + } + + override def doWork(): Unit = { + maybeTruncate() + maybeFetch() + } + + private def maybeFetch(): Unit = { + val fetchRequestOpt = inLock(partitionMapLock) { + val ResultWithPartitions(fetchRequestOpt, partitionsWithError) = buildFetch(partitionStates.partitionStateMap.asScala) + + handlePartitionsWithErrors(partitionsWithError, "maybeFetch") + + if (fetchRequestOpt.isEmpty) { + trace(s"There are no active partitions. 
Back off for $fetchBackOffMs ms before sending a fetch request") + partitionMapCond.await(fetchBackOffMs, TimeUnit.MILLISECONDS) + } + + fetchRequestOpt + } + + fetchRequestOpt.foreach { case ReplicaFetch(sessionPartitions, fetchRequest) => + processFetchRequest(sessionPartitions, fetchRequest) + } + } + + // deal with partitions with errors, potentially due to leadership changes + private def handlePartitionsWithErrors(partitions: Iterable[TopicPartition], methodName: String): Unit = { + if (partitions.nonEmpty) { + debug(s"Handling errors in $methodName for partitions $partitions") + delayPartitions(partitions, fetchBackOffMs) + } + } + + /** + * Builds offset for leader epoch requests for partitions that are in the truncating phase based + * on latest epochs of the future replicas (the one that is fetching) + */ + private def fetchTruncatingPartitions(): (Map[TopicPartition, EpochData], Set[TopicPartition]) = inLock(partitionMapLock) { + val partitionsWithEpochs = mutable.Map.empty[TopicPartition, EpochData] + val partitionsWithoutEpochs = mutable.Set.empty[TopicPartition] + + partitionStates.partitionStateMap.forEach { (tp, state) => + if (state.isTruncating) { + latestEpoch(tp) match { + case Some(epoch) if isOffsetForLeaderEpochSupported => + partitionsWithEpochs += tp -> new EpochData() + .setPartition(tp.partition) + .setCurrentLeaderEpoch(state.currentLeaderEpoch) + .setLeaderEpoch(epoch) + case _ => + partitionsWithoutEpochs += tp + } + } + } + + (partitionsWithEpochs, partitionsWithoutEpochs) + } + + private def maybeTruncate(): Unit = { + val (partitionsWithEpochs, partitionsWithoutEpochs) = fetchTruncatingPartitions() + if (partitionsWithEpochs.nonEmpty) { + truncateToEpochEndOffsets(partitionsWithEpochs) + } + if (partitionsWithoutEpochs.nonEmpty) { + truncateToHighWatermark(partitionsWithoutEpochs) + } + } + + private def doTruncate(topicPartition: TopicPartition, truncationState: OffsetTruncationState): Boolean = { + try { + truncate(topicPartition, truncationState) + true + } + catch { + case e: KafkaStorageException => + error(s"Failed to truncate $topicPartition at offset ${truncationState.offset}", e) + markPartitionFailed(topicPartition) + false + case t: Throwable => + error(s"Unexpected error occurred during truncation for $topicPartition " + + s"at offset ${truncationState.offset}", t) + markPartitionFailed(topicPartition) + false + } + } + + /** + * - Build a leader epoch fetch based on partitions that are in the Truncating phase + * - Send OffsetsForLeaderEpochRequest, retrieving the latest offset for each partition's + * leader epoch. This is the offset the follower should truncate to ensure + * accurate log replication. + * - Finally truncate the logs for partitions in the truncating phase and mark the + * truncation complete. Do this within a lock to ensure no leadership changes can + * occur during truncation. + */ + private def truncateToEpochEndOffsets(latestEpochsForPartitions: Map[TopicPartition, EpochData]): Unit = { + val endOffsets = fetchEpochEndOffsets(latestEpochsForPartitions) + //Ensure we hold a lock during truncation. 
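Before the locked truncation step that follows, a worked aside on what the OffsetsForLeaderEpoch round trip described in the doc comment above decides in the common case. This is a simplification outside the patch with hypothetical numbers; `getOffsetTruncationState` below handles the full set of cases, including undefined epochs and offsets:

```scala
object EpochTruncationExample {
  // Common case: truncate to the smaller of the follower's log end offset and the end offset
  // the leader reports for the follower's latest epoch.
  def truncationOffset(followerLogEndOffset: Long, leaderEpochEndOffset: Long): Long =
    math.min(followerLogEndOffset, leaderEpochEndOffset)

  def main(args: Array[String]): Unit = {
    // The follower wrote up to offset 120 in its latest epoch, but the leader says that epoch
    // ended at offset 100: the trailing 20 records were never committed and are truncated away.
    println(truncationOffset(followerLogEndOffset = 120L, leaderEpochEndOffset = 100L)) // 100
  }
}
```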
+ inLock(partitionMapLock) { + //Check no leadership and no leader epoch changes happened whilst we were unlocked, fetching epochs + val epochEndOffsets = endOffsets.filter { case (tp, _) => + val curPartitionState = partitionStates.stateValue(tp) + val partitionEpochRequest = latestEpochsForPartitions.getOrElse(tp, { + throw new IllegalStateException( + s"Leader replied with partition $tp not requested in OffsetsForLeaderEpoch request") + }) + val leaderEpochInRequest = partitionEpochRequest.currentLeaderEpoch + curPartitionState != null && leaderEpochInRequest == curPartitionState.currentLeaderEpoch + } + + val ResultWithPartitions(fetchOffsets, partitionsWithError) = maybeTruncateToEpochEndOffsets(epochEndOffsets, latestEpochsForPartitions) + handlePartitionsWithErrors(partitionsWithError, "truncateToEpochEndOffsets") + updateFetchOffsetAndMaybeMarkTruncationComplete(fetchOffsets) + } + } + + // Visibility for unit tests + protected[server] def truncateOnFetchResponse(epochEndOffsets: Map[TopicPartition, EpochEndOffset]): Unit = { + inLock(partitionMapLock) { + val ResultWithPartitions(fetchOffsets, partitionsWithError) = maybeTruncateToEpochEndOffsets(epochEndOffsets, Map.empty) + handlePartitionsWithErrors(partitionsWithError, "truncateOnFetchResponse") + updateFetchOffsetAndMaybeMarkTruncationComplete(fetchOffsets) + } + } + + // Visible for testing + private[server] def truncateToHighWatermark(partitions: Set[TopicPartition]): Unit = inLock(partitionMapLock) { + val fetchOffsets = mutable.HashMap.empty[TopicPartition, OffsetTruncationState] + + for (tp <- partitions) { + val partitionState = partitionStates.stateValue(tp) + if (partitionState != null) { + val highWatermark = partitionState.fetchOffset + val truncationState = OffsetTruncationState(highWatermark, truncationCompleted = true) + + info(s"Truncating partition $tp with $truncationState due to local high watermark $highWatermark") + if (doTruncate(tp, truncationState)) + fetchOffsets.put(tp, truncationState) + } + } + + updateFetchOffsetAndMaybeMarkTruncationComplete(fetchOffsets) + } + + private def maybeTruncateToEpochEndOffsets(fetchedEpochs: Map[TopicPartition, EpochEndOffset], + latestEpochsForPartitions: Map[TopicPartition, EpochData]): ResultWithPartitions[Map[TopicPartition, OffsetTruncationState]] = { + val fetchOffsets = mutable.HashMap.empty[TopicPartition, OffsetTruncationState] + val partitionsWithError = mutable.HashSet.empty[TopicPartition] + + fetchedEpochs.forKeyValue { (tp, leaderEpochOffset) => + if (partitionStates.contains(tp)) { + Errors.forCode(leaderEpochOffset.errorCode) match { + case Errors.NONE => + val offsetTruncationState = getOffsetTruncationState(tp, leaderEpochOffset) + info(s"Truncating partition $tp with $offsetTruncationState due to leader epoch and offset $leaderEpochOffset") + if (doTruncate(tp, offsetTruncationState)) + fetchOffsets.put(tp, offsetTruncationState) + + case Errors.FENCED_LEADER_EPOCH => + val currentLeaderEpoch = latestEpochsForPartitions.get(tp) + .map(epochEndOffset => Int.box(epochEndOffset.currentLeaderEpoch)).asJava + if (onPartitionFenced(tp, currentLeaderEpoch)) + partitionsWithError += tp + + case error => + info(s"Retrying leaderEpoch request for partition $tp as the leader reported an error: $error") + partitionsWithError += tp + } + } else { + // Partitions may have been removed from the fetcher while the thread was waiting for fetch + // response. 
Removed partitions are filtered out while holding `partitionMapLock` to ensure that we + // don't update state for any partition that may have already been migrated to another thread. + trace(s"Ignoring epoch offsets for partition $tp since it has been removed from this fetcher thread.") + } + } + + ResultWithPartitions(fetchOffsets, partitionsWithError) + } + + /** + * remove the partition if the partition state is NOT updated. Otherwise, keep the partition active. + * @return true if the epoch in this thread is updated. otherwise, false + */ + private def onPartitionFenced(tp: TopicPartition, requestEpoch: Optional[Integer]): Boolean = inLock(partitionMapLock) { + Option(partitionStates.stateValue(tp)).exists { currentFetchState => + val currentLeaderEpoch = currentFetchState.currentLeaderEpoch + if (requestEpoch.isPresent && requestEpoch.get == currentLeaderEpoch) { + info(s"Partition $tp has an older epoch ($currentLeaderEpoch) than the current leader. Will await " + + s"the new LeaderAndIsr state before resuming fetching.") + markPartitionFailed(tp) + false + } else { + info(s"Partition $tp has an new epoch ($currentLeaderEpoch) than the current leader. retry the partition later") + true + } + } + } + + private def processFetchRequest(sessionPartitions: util.Map[TopicPartition, FetchRequest.PartitionData], + fetchRequest: FetchRequest.Builder): Unit = { + val partitionsWithError = mutable.Set[TopicPartition]() + val divergingEndOffsets = mutable.Map.empty[TopicPartition, EpochEndOffset] + var responseData: Map[TopicPartition, FetchData] = Map.empty + + try { + trace(s"Sending fetch request $fetchRequest") + responseData = fetchFromLeader(fetchRequest) + } catch { + case t: Throwable => + if (isRunning) { + warn(s"Error in response for fetch request $fetchRequest", t) + inLock(partitionMapLock) { + partitionsWithError ++= partitionStates.partitionSet.asScala + } + } + } + fetcherStats.requestRate.mark() + + if (responseData.nonEmpty) { + // process fetched data + inLock(partitionMapLock) { + responseData.forKeyValue { (topicPartition, partitionData) => + Option(partitionStates.stateValue(topicPartition)).foreach { currentFetchState => + // It's possible that a partition is removed and re-added or truncated when there is a pending fetch request. + // In this case, we only want to process the fetch response if the partition state is ready for fetch and + // the current offset is the same as the offset requested. 
+ val fetchPartitionData = sessionPartitions.get(topicPartition) + if (fetchPartitionData != null && fetchPartitionData.fetchOffset == currentFetchState.fetchOffset && currentFetchState.isReadyForFetch) { + Errors.forCode(partitionData.errorCode) match { + case Errors.NONE => + try { + // Once we hand off the partition data to the subclass, we can't mess with it any more in this thread + val logAppendInfoOpt = processPartitionData(topicPartition, currentFetchState.fetchOffset, + partitionData) + + logAppendInfoOpt.foreach { logAppendInfo => + val validBytes = logAppendInfo.validBytes + val nextOffset = if (validBytes > 0) logAppendInfo.lastOffset + 1 else currentFetchState.fetchOffset + val lag = Math.max(0L, partitionData.highWatermark - nextOffset) + fetcherLagStats.getAndMaybePut(topicPartition).lag = lag + + // ReplicaDirAlterThread may have removed topicPartition from the partitionStates after processing the partition data + if (validBytes > 0 && partitionStates.contains(topicPartition)) { + // Update partitionStates only if there is no exception during processPartitionData + val newFetchState = PartitionFetchState(currentFetchState.topicId, nextOffset, Some(lag), + currentFetchState.currentLeaderEpoch, state = Fetching, + logAppendInfo.lastLeaderEpoch) + partitionStates.updateAndMoveToEnd(topicPartition, newFetchState) + fetcherStats.byteRate.mark(validBytes) + } + } + if (isTruncationOnFetchSupported) { + FetchResponse.divergingEpoch(partitionData).ifPresent { divergingEpoch => + divergingEndOffsets += topicPartition -> new EpochEndOffset() + .setPartition(topicPartition.partition) + .setErrorCode(Errors.NONE.code) + .setLeaderEpoch(divergingEpoch.epoch) + .setEndOffset(divergingEpoch.endOffset) + } + } + } catch { + case ime@(_: CorruptRecordException | _: InvalidRecordException) => + // we log the error and continue. This ensures two things + // 1. If there is a corrupt message in a topic partition, it does not bring the fetcher thread + // down and cause other topic partition to also lag + // 2. 
If the message is corrupt due to a transient state in the log (truncation, partial writes + // can cause this), we simply continue and should get fixed in the subsequent fetches + error(s"Found invalid messages during fetch for partition $topicPartition " + + s"offset ${currentFetchState.fetchOffset}", ime) + partitionsWithError += topicPartition + case e: KafkaStorageException => + error(s"Error while processing data for partition $topicPartition " + + s"at offset ${currentFetchState.fetchOffset}", e) + markPartitionFailed(topicPartition) + case t: Throwable => + // stop monitoring this partition and add it to the set of failed partitions + error(s"Unexpected error occurred while processing data for partition $topicPartition " + + s"at offset ${currentFetchState.fetchOffset}", t) + markPartitionFailed(topicPartition) + } + case Errors.OFFSET_OUT_OF_RANGE => + if (handleOutOfRangeError(topicPartition, currentFetchState, fetchPartitionData.currentLeaderEpoch)) + partitionsWithError += topicPartition + + case Errors.UNKNOWN_LEADER_EPOCH => + debug(s"Remote broker has a smaller leader epoch for partition $topicPartition than " + + s"this replica's current leader epoch of ${currentFetchState.currentLeaderEpoch}.") + partitionsWithError += topicPartition + + case Errors.FENCED_LEADER_EPOCH => + if (onPartitionFenced(topicPartition, fetchPartitionData.currentLeaderEpoch)) + partitionsWithError += topicPartition + + case Errors.NOT_LEADER_OR_FOLLOWER => + debug(s"Remote broker is not the leader for partition $topicPartition, which could indicate " + + "that the partition is being moved") + partitionsWithError += topicPartition + + case Errors.UNKNOWN_TOPIC_OR_PARTITION => + warn(s"Received ${Errors.UNKNOWN_TOPIC_OR_PARTITION} from the leader for partition $topicPartition. " + + "This error may be returned transiently when the partition is being created or deleted, but it is not " + + "expected to persist.") + partitionsWithError += topicPartition + + case Errors.UNKNOWN_TOPIC_ID => + warn(s"Received ${Errors.UNKNOWN_TOPIC_ID} from the leader for partition $topicPartition. " + + "This error may be returned transiently when the partition is being created or deleted, but it is not " + + "expected to persist.") + partitionsWithError += topicPartition + + case Errors.INCONSISTENT_TOPIC_ID => + warn(s"Received ${Errors.INCONSISTENT_TOPIC_ID} from the leader for partition $topicPartition. " + + "This error may be returned transiently when the partition is being created or deleted, but it is not " + + "expected to persist.") + partitionsWithError += topicPartition + + case partitionError => + error(s"Error for partition $topicPartition at offset ${currentFetchState.fetchOffset}", partitionError.exception) + partitionsWithError += topicPartition + } + } + } + } + } + } + + if (divergingEndOffsets.nonEmpty) + truncateOnFetchResponse(divergingEndOffsets) + if (partitionsWithError.nonEmpty) { + handlePartitionsWithErrors(partitionsWithError, "processFetchRequest") + } + } + + /** + * This is used to mark partitions for truncation in ReplicaAlterLogDirsThread after leader + * offsets are known. 
+ */ + def markPartitionsForTruncation(topicPartition: TopicPartition, truncationOffset: Long): Unit = { + partitionMapLock.lockInterruptibly() + try { + Option(partitionStates.stateValue(topicPartition)).foreach { state => + val newState = PartitionFetchState(state.topicId, math.min(truncationOffset, state.fetchOffset), + state.lag, state.currentLeaderEpoch, state.delay, state = Truncating, + lastFetchedEpoch = None) + partitionStates.updateAndMoveToEnd(topicPartition, newState) + partitionMapCond.signalAll() + } + } finally partitionMapLock.unlock() + } + + private def markPartitionFailed(topicPartition: TopicPartition): Unit = { + partitionMapLock.lock() + try { + failedPartitions.add(topicPartition) + removePartitions(Set(topicPartition)) + } finally partitionMapLock.unlock() + warn(s"Partition $topicPartition marked as failed") + } + + /** + * Returns initial partition fetch state based on current state and the provided `initialFetchState`. + * From IBP 2.7 onwards, we can rely on truncation based on diverging data returned in fetch responses. + * For older versions, we can skip the truncation step iff the leader epoch matches the existing epoch. + */ + private def partitionFetchState(tp: TopicPartition, initialFetchState: InitialFetchState, currentState: PartitionFetchState): PartitionFetchState = { + if (currentState != null && currentState.currentLeaderEpoch == initialFetchState.currentLeaderEpoch) { + currentState + } else if (initialFetchState.initOffset < 0) { + fetchOffsetAndTruncate(tp, initialFetchState.topicId, initialFetchState.currentLeaderEpoch) + } else if (isTruncationOnFetchSupported) { + // With old message format, `latestEpoch` will be empty and we use Truncating state + // to truncate to high watermark. + val lastFetchedEpoch = latestEpoch(tp) + val state = if (lastFetchedEpoch.nonEmpty) Fetching else Truncating + PartitionFetchState(initialFetchState.topicId, initialFetchState.initOffset, None, initialFetchState.currentLeaderEpoch, + state, lastFetchedEpoch) + } else { + PartitionFetchState(initialFetchState.topicId, initialFetchState.initOffset, None, initialFetchState.currentLeaderEpoch, + state = Truncating, lastFetchedEpoch = None) + } + } + + def addPartitions(initialFetchStates: Map[TopicPartition, InitialFetchState]): Set[TopicPartition] = { + partitionMapLock.lockInterruptibly() + try { + failedPartitions.removeAll(initialFetchStates.keySet) + + initialFetchStates.forKeyValue { (tp, initialFetchState) => + val currentState = partitionStates.stateValue(tp) + val updatedState = partitionFetchState(tp, initialFetchState, currentState) + partitionStates.updateAndMoveToEnd(tp, updatedState) + } + + partitionMapCond.signalAll() + initialFetchStates.keySet + } finally partitionMapLock.unlock() + } + + def maybeUpdateTopicIds(partitions: Set[TopicPartition], topicIds: String => Option[Uuid]): Unit = { + partitionMapLock.lockInterruptibly() + try { + partitions.foreach { tp => + val currentState = partitionStates.stateValue(tp) + if (currentState != null) { + val updatedState = currentState.updateTopicId(topicIds(tp.topic)) + partitionStates.update(tp, updatedState) + } + } + partitionMapCond.signalAll() + } finally partitionMapLock.unlock() + } + + /** + * Loop through all partitions, updating their fetch offset and maybe marking them as + * truncation completed if their offsetTruncationState indicates truncation completed + * + * @param fetchOffsets the partitions to update fetch offset and maybe mark truncation complete + */ + private def 
updateFetchOffsetAndMaybeMarkTruncationComplete(fetchOffsets: Map[TopicPartition, OffsetTruncationState]): Unit = { + val newStates: Map[TopicPartition, PartitionFetchState] = partitionStates.partitionStateMap.asScala + .map { case (topicPartition, currentFetchState) => + val maybeTruncationComplete = fetchOffsets.get(topicPartition) match { + case Some(offsetTruncationState) => + val lastFetchedEpoch = latestEpoch(topicPartition) + val state = if (isTruncationOnFetchSupported || offsetTruncationState.truncationCompleted) + Fetching + else + Truncating + PartitionFetchState(currentFetchState.topicId, offsetTruncationState.offset, currentFetchState.lag, + currentFetchState.currentLeaderEpoch, currentFetchState.delay, state, lastFetchedEpoch) + case None => currentFetchState + } + (topicPartition, maybeTruncationComplete) + } + partitionStates.set(newStates.asJava) + } + + /** + * Called from ReplicaFetcherThread and ReplicaAlterLogDirsThread maybeTruncate for each topic + * partition. Returns the truncation offset and whether this is the final offset to truncate to. + * + * For each topic partition, the offset to truncate to is calculated based on the leader's returned + * epoch and offset: + * -- If the leader replied with undefined epoch offset, we must use the high watermark. This can + * happen if 1) the leader is still using message format older than KAFKA_0_11_0; 2) the follower + * requested leader epoch < the first leader epoch known to the leader. + * -- If the leader replied with the valid offset but undefined leader epoch, we truncate to + * leader's offset if it is lower than follower's Log End Offset. This may happen if the + * leader is on the inter-broker protocol version < KAFKA_2_0_IV0. + * -- If the leader replied with leader epoch not known to the follower, we truncate to the + * end offset of the largest epoch that is smaller than the epoch the leader replied with, and + * send OffsetsForLeaderEpochRequest with that leader epoch. In a rarer case, where the + * follower was not tracking epochs smaller than the epoch the leader replied with, we + * truncate to the leader's offset (and do not send any more leader epoch requests). + * -- Otherwise, truncate to min(leader's offset, end offset on the follower for epoch that + * leader replied with, follower's Log End Offset). + * + * @param tp Topic partition + * @param leaderEpochOffset Epoch end offset received from the leader for this topic partition + */ + private def getOffsetTruncationState(tp: TopicPartition, + leaderEpochOffset: EpochEndOffset): OffsetTruncationState = inLock(partitionMapLock) { + if (leaderEpochOffset.endOffset == UNDEFINED_EPOCH_OFFSET) { + // truncate to initial offset which is the high watermark for follower replica. For + // future replica, it is either high watermark of the future replica or current + // replica's truncation offset (when the current replica truncates, it forces future + // replica's partition state to 'truncating' and sets initial offset to its truncation offset) + warn(s"Based on replica's leader epoch, leader replied with an unknown offset in $tp. 
" + + s"The initial fetch offset ${partitionStates.stateValue(tp).fetchOffset} will be used for truncation.") + OffsetTruncationState(partitionStates.stateValue(tp).fetchOffset, truncationCompleted = true) + } else if (leaderEpochOffset.leaderEpoch == UNDEFINED_EPOCH) { + // either leader or follower or both use inter-broker protocol version < KAFKA_2_0_IV0 + // (version 0 of OffsetForLeaderEpoch request/response) + warn(s"Leader or replica is on protocol version where leader epoch is not considered in the OffsetsForLeaderEpoch response. " + + s"The leader's offset ${leaderEpochOffset.endOffset} will be used for truncation in $tp.") + OffsetTruncationState(min(leaderEpochOffset.endOffset, logEndOffset(tp)), truncationCompleted = true) + } else { + val replicaEndOffset = logEndOffset(tp) + + // get (leader epoch, end offset) pair that corresponds to the largest leader epoch + // less than or equal to the requested epoch. + endOffsetForEpoch(tp, leaderEpochOffset.leaderEpoch) match { + case Some(OffsetAndEpoch(followerEndOffset, followerEpoch)) => + if (followerEpoch != leaderEpochOffset.leaderEpoch) { + // the follower does not know about the epoch that leader replied with + // we truncate to the end offset of the largest epoch that is smaller than the + // epoch the leader replied with, and send another offset for leader epoch request + val intermediateOffsetToTruncateTo = min(followerEndOffset, replicaEndOffset) + info(s"Based on replica's leader epoch, leader replied with epoch ${leaderEpochOffset.leaderEpoch} " + + s"unknown to the replica for $tp. " + + s"Will truncate to $intermediateOffsetToTruncateTo and send another leader epoch request to the leader.") + OffsetTruncationState(intermediateOffsetToTruncateTo, truncationCompleted = false) + } else { + val offsetToTruncateTo = min(followerEndOffset, leaderEpochOffset.endOffset) + OffsetTruncationState(min(offsetToTruncateTo, replicaEndOffset), truncationCompleted = true) + } + case None => + // This can happen if the follower was not tracking leader epochs at that point (before the + // upgrade, or if this broker is new). Since the leader replied with an epoch lower than the + // epoch requested by the follower, it should be safe to truncate to the leader's + // offset (this is the same behavior as post-KIP-101 and pre-KIP-279) + warn(s"Based on replica's leader epoch, leader replied with epoch ${leaderEpochOffset.leaderEpoch} " + + s"below any replica's tracked epochs for $tp. " + + s"Only the leader's offset ${leaderEpochOffset.endOffset} will be used for truncation.") + OffsetTruncationState(min(leaderEpochOffset.endOffset, replicaEndOffset), truncationCompleted = true) + } + } + } + + /** + * Handle the out of range error. Return false if + * 1) the request succeeded or + * 2) the request was fenced and this thread has not received a new epoch, + * which means we need not back off and retry. Return true if there was a retriable error. + */ + private def handleOutOfRangeError(topicPartition: TopicPartition, + fetchState: PartitionFetchState, + requestEpoch: Optional[Integer]): Boolean = { + try { + val newFetchState = fetchOffsetAndTruncate(topicPartition, fetchState.topicId, fetchState.currentLeaderEpoch) + partitionStates.updateAndMoveToEnd(topicPartition, newFetchState) + info(s"Current offset ${fetchState.fetchOffset} for partition $topicPartition is " + + s"out of range, which typically implies a leader change. 
Reset fetch offset to ${newFetchState.fetchOffset}") + false + } catch { + case _: FencedLeaderEpochException => + onPartitionFenced(topicPartition, requestEpoch) + + case e @ (_ : UnknownTopicOrPartitionException | + _ : UnknownLeaderEpochException | + _ : NotLeaderOrFollowerException) => + info(s"Could not fetch offset for $topicPartition due to error: ${e.getMessage}") + true + + case e: Throwable => + error(s"Error getting offset for partition $topicPartition", e) + true + } + } + + /** + * Handle a partition whose offset is out of range and return a new fetch offset. + */ + protected def fetchOffsetAndTruncate(topicPartition: TopicPartition, topicId: Option[Uuid], currentLeaderEpoch: Int): PartitionFetchState = { + val replicaEndOffset = logEndOffset(topicPartition) + + /** + * Unclean leader election: A follower goes down, in the meanwhile the leader keeps appending messages. The follower comes back up + * and before it has completely caught up with the leader's logs, all replicas in the ISR go down. The follower is now uncleanly + * elected as the new leader, and it starts appending messages from the client. The old leader comes back up, becomes a follower + * and it may discover that the current leader's end offset is behind its own end offset. + * + * In such a case, truncate the current follower's log to the current leader's end offset and continue fetching. + * + * There is a potential for a mismatch between the logs of the two replicas here. We don't fix this mismatch as of now. + */ + val leaderEndOffset = fetchLatestOffsetFromLeader(topicPartition, currentLeaderEpoch) + if (leaderEndOffset < replicaEndOffset) { + warn(s"Reset fetch offset for partition $topicPartition from $replicaEndOffset to current " + + s"leader's latest offset $leaderEndOffset") + truncate(topicPartition, OffsetTruncationState(leaderEndOffset, truncationCompleted = true)) + + fetcherLagStats.getAndMaybePut(topicPartition).lag = 0 + PartitionFetchState(topicId, leaderEndOffset, Some(0), currentLeaderEpoch, + state = Fetching, lastFetchedEpoch = latestEpoch(topicPartition)) + } else { + /** + * If the leader's log end offset is greater than the follower's log end offset, there are two possibilities: + * 1. The follower could have been down for a long time and when it starts up, its end offset could be smaller than the leader's + * start offset because the leader has deleted old logs (log.logEndOffset < leaderStartOffset). + * 2. When unclean leader election occurs, it is possible that the old leader's high watermark is greater than + * the new leader's log end offset. So when the old leader truncates its offset to its high watermark and starts + * to fetch from the new leader, an OffsetOutOfRangeException will be thrown. After that some more messages are + * produced to the new leader. While the old leader is trying to handle the OffsetOutOfRangeException and query + * the log end offset of the new leader, the new leader's log end offset becomes higher than the follower's log end offset. + * + * In the first case, the follower's current log end offset is smaller than the leader's log start offset. So the + * follower should truncate all its logs, roll out a new segment and start to fetch from the current leader's log + * start offset. + * In the second case, the follower should just keep the current log segments and retry the fetch. In the second + * case, there will be some inconsistency of data between old and new leader. We are not solving it here. 
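+ * As a purely illustrative numeric example of the second case (the numbers are hypothetical): the old
+ * leader truncates to its high watermark 100 and fetches at offset 100 while the new leader's log end
+ * offset is only 90, so the fetch fails with OffsetOutOfRangeException; by the time this code queries the
+ * new leader, its log end offset has grown to 120, so the follower keeps its current segments and simply
+ * retries the fetch from its own log end offset (100), because the leader's log start offset is not past it.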
+ * If users want to have strong consistency guarantees, appropriate configurations need to be set for both + * brokers and producers. + * + * Putting the two cases together, the follower should fetch from the higher one of its replica log end offset + * and the current leader's log start offset. + */ + val leaderStartOffset = fetchEarliestOffsetFromLeader(topicPartition, currentLeaderEpoch) + warn(s"Reset fetch offset for partition $topicPartition from $replicaEndOffset to current " + + s"leader's start offset $leaderStartOffset") + val offsetToFetch = Math.max(leaderStartOffset, replicaEndOffset) + // Only truncate log when current leader's log start offset is greater than follower's log end offset. + if (leaderStartOffset > replicaEndOffset) + truncateFullyAndStartAt(topicPartition, leaderStartOffset) + + val initialLag = leaderEndOffset - offsetToFetch + fetcherLagStats.getAndMaybePut(topicPartition).lag = initialLag + PartitionFetchState(topicId, offsetToFetch, Some(initialLag), currentLeaderEpoch, + state = Fetching, lastFetchedEpoch = latestEpoch(topicPartition)) + } + } + + def delayPartitions(partitions: Iterable[TopicPartition], delay: Long): Unit = { + partitionMapLock.lockInterruptibly() + try { + for (partition <- partitions) { + Option(partitionStates.stateValue(partition)).foreach { currentFetchState => + if (!currentFetchState.isDelayed) { + partitionStates.updateAndMoveToEnd(partition, PartitionFetchState(currentFetchState.topicId, currentFetchState.fetchOffset, + currentFetchState.lag, currentFetchState.currentLeaderEpoch, Some(new DelayedItem(delay)), + currentFetchState.state, currentFetchState.lastFetchedEpoch)) + } + } + } + partitionMapCond.signalAll() + } finally partitionMapLock.unlock() + } + + def removePartitions(topicPartitions: Set[TopicPartition]): Map[TopicPartition, PartitionFetchState] = { + partitionMapLock.lockInterruptibly() + try { + topicPartitions.map { topicPartition => + val state = partitionStates.stateValue(topicPartition) + partitionStates.remove(topicPartition) + fetcherLagStats.unregister(topicPartition) + topicPartition -> state + }.filter(_._2 != null).toMap + } finally partitionMapLock.unlock() + } + + def partitionCount: Int = { + partitionMapLock.lockInterruptibly() + try partitionStates.size + finally partitionMapLock.unlock() + } + + def partitions: Set[TopicPartition] = { + partitionMapLock.lockInterruptibly() + try partitionStates.partitionSet.asScala.toSet + finally partitionMapLock.unlock() + } + + // Visible for testing + private[server] def fetchState(topicPartition: TopicPartition): Option[PartitionFetchState] = inLock(partitionMapLock) { + Option(partitionStates.stateValue(topicPartition)) + } + + protected def toMemoryRecords(records: Records): MemoryRecords = { + (records: @unchecked) match { + case r: MemoryRecords => r + case r: FileRecords => + val buffer = ByteBuffer.allocate(r.sizeInBytes) + r.readInto(buffer, 0) + MemoryRecords.readableRecords(buffer) + } + } +} + +object AbstractFetcherThread { + + case class ReplicaFetch(partitionData: util.Map[TopicPartition, FetchRequest.PartitionData], fetchRequest: FetchRequest.Builder) + case class ResultWithPartitions[R](result: R, partitionsWithError: Set[TopicPartition]) + +} + +object FetcherMetrics { + val ConsumerLag = "ConsumerLag" + val RequestsPerSec = "RequestsPerSec" + val BytesPerSec = "BytesPerSec" +} + +class FetcherLagMetrics(metricId: ClientIdTopicPartition) extends KafkaMetricsGroup { + + private[this] val lagVal = new AtomicLong(-1L) + private[this] val tags = 
Map( + "clientId" -> metricId.clientId, + "topic" -> metricId.topicPartition.topic, + "partition" -> metricId.topicPartition.partition.toString) + + newGauge(FetcherMetrics.ConsumerLag, () => lagVal.get, tags) + + def lag_=(newLag: Long): Unit = { + lagVal.set(newLag) + } + + def lag = lagVal.get + + def unregister(): Unit = { + removeMetric(FetcherMetrics.ConsumerLag, tags) + } +} + +class FetcherLagStats(metricId: ClientIdAndBroker) { + private val valueFactory = (k: TopicPartition) => new FetcherLagMetrics(ClientIdTopicPartition(metricId.clientId, k)) + val stats = new Pool[TopicPartition, FetcherLagMetrics](Some(valueFactory)) + + def getAndMaybePut(topicPartition: TopicPartition): FetcherLagMetrics = { + stats.getAndMaybePut(topicPartition) + } + + def unregister(topicPartition: TopicPartition): Unit = { + val lagMetrics = stats.remove(topicPartition) + if (lagMetrics != null) lagMetrics.unregister() + } + + def unregister(): Unit = { + stats.keys.toBuffer.foreach { key: TopicPartition => + unregister(key) + } + } +} + +class FetcherStats(metricId: ClientIdAndBroker) extends KafkaMetricsGroup { + val tags = Map("clientId" -> metricId.clientId, + "brokerHost" -> metricId.brokerHost, + "brokerPort" -> metricId.brokerPort.toString) + + val requestRate = newMeter(FetcherMetrics.RequestsPerSec, "requests", TimeUnit.SECONDS, tags) + + val byteRate = newMeter(FetcherMetrics.BytesPerSec, "bytes", TimeUnit.SECONDS, tags) + + def unregister(): Unit = { + removeMetric(FetcherMetrics.RequestsPerSec, tags) + removeMetric(FetcherMetrics.BytesPerSec, tags) + } + +} + +case class ClientIdTopicPartition(clientId: String, topicPartition: TopicPartition) { + override def toString: String = s"$clientId-$topicPartition" +} + +sealed trait ReplicaState +case object Truncating extends ReplicaState +case object Fetching extends ReplicaState + +object PartitionFetchState { + def apply(topicId: Option[Uuid], offset: Long, lag: Option[Long], currentLeaderEpoch: Int, state: ReplicaState, + lastFetchedEpoch: Option[Int]): PartitionFetchState = { + PartitionFetchState(topicId, offset, lag, currentLeaderEpoch, None, state, lastFetchedEpoch) + } +} + + +/** + * Case class to keep the partition offset and its state (truncating log, delayed). + * This represents a partition as being either: + * (1) Truncating its log, for example having recently become a follower + * (2) Delayed, for example due to an error, where we subsequently back off a bit + * (3) ReadyForFetch, which is the active state where the thread is actively fetching data. 
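+ *
+ * A minimal construction sketch (illustrative only; the values below are arbitrary):
+ * {{{
+ * val state = PartitionFetchState(topicId = Some(Uuid.randomUuid()), fetchOffset = 0L, lag = None,
+ *   currentLeaderEpoch = 5, delay = None, state = Truncating, lastFetchedEpoch = None)
+ * state.isTruncating    // true: Truncating and not delayed
+ * state.isReadyForFetch // false: only a non-delayed Fetching state is ready for fetch
+ * }}}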
+ */ +case class PartitionFetchState(topicId: Option[Uuid], + fetchOffset: Long, + lag: Option[Long], + currentLeaderEpoch: Int, + delay: Option[DelayedItem], + state: ReplicaState, + lastFetchedEpoch: Option[Int]) { + + def isReadyForFetch: Boolean = state == Fetching && !isDelayed + + def isReplicaInSync: Boolean = lag.isDefined && lag.get <= 0 + + def isTruncating: Boolean = state == Truncating && !isDelayed + + def isDelayed: Boolean = delay.exists(_.getDelay(TimeUnit.MILLISECONDS) > 0) + + override def toString: String = { + s"FetchState(topicId=$topicId" + + s", fetchOffset=$fetchOffset" + + s", currentLeaderEpoch=$currentLeaderEpoch" + + s", lastFetchedEpoch=$lastFetchedEpoch" + + s", state=$state" + + s", lag=$lag" + + s", delay=${delay.map(_.delayMs).getOrElse(0)}ms" + + s")" + } + + def updateTopicId(topicId: Option[Uuid]): PartitionFetchState = { + this.copy(topicId = topicId) + } +} + +case class OffsetTruncationState(offset: Long, truncationCompleted: Boolean) { + + def this(offset: Long) = this(offset, true) + + override def toString: String = s"TruncationState(offset=$offset, completed=$truncationCompleted)" +} + +case class OffsetAndEpoch(offset: Long, leaderEpoch: Int) { + override def toString: String = { + s"(offset=$offset, leaderEpoch=$leaderEpoch)" + } +} diff --git a/core/src/main/scala/kafka/server/AclApis.scala b/core/src/main/scala/kafka/server/AclApis.scala new file mode 100644 index 0000000..97b685b --- /dev/null +++ b/core/src/main/scala/kafka/server/AclApis.scala @@ -0,0 +1,155 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.server + +import kafka.network.RequestChannel +import kafka.security.authorizer.AuthorizerUtils +import kafka.utils.Logging +import org.apache.kafka.common.acl.AclOperation._ +import org.apache.kafka.common.acl.AclBinding +import org.apache.kafka.common.errors._ +import org.apache.kafka.common.message.CreateAclsResponseData.AclCreationResult +import org.apache.kafka.common.message._ +import org.apache.kafka.common.protocol.Errors +import org.apache.kafka.common.requests._ +import org.apache.kafka.common.resource.Resource.CLUSTER_NAME +import org.apache.kafka.common.resource.ResourceType +import org.apache.kafka.server.authorizer._ +import java.util + +import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable +import scala.compat.java8.OptionConverters._ +import scala.jdk.CollectionConverters._ + +/** + * Logic to handle ACL requests. 
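+ *
+ * A hedged wiring sketch (the constructor arguments below are assumed to be available from the
+ * surrounding broker setup; "broker" is just an example name):
+ * {{{
+ * val aclApis = new AclApis(authHelper, authorizerOpt, requestHelper, "broker", config)
+ * // dispatched from a request handler such as KafkaApis:
+ * //   case ApiKeys.DESCRIBE_ACLS => aclApis.handleDescribeAcls(request)
+ * //   case ApiKeys.CREATE_ACLS   => aclApis.handleCreateAcls(request)
+ * //   case ApiKeys.DELETE_ACLS   => aclApis.handleDeleteAcls(request)
+ * }}}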
+ */ +class AclApis(authHelper: AuthHelper, + authorizer: Option[Authorizer], + requestHelper: RequestHandlerHelper, + name: String, + config: KafkaConfig) extends Logging { + this.logIdent = "[AclApis-%s-%s] ".format(name, config.nodeId) + private val alterAclsPurgatory = + new DelayedFuturePurgatory(purgatoryName = "AlterAcls", brokerId = config.nodeId) + + def isClosed: Boolean = alterAclsPurgatory.isShutdown + + def close(): Unit = alterAclsPurgatory.shutdown() + + def handleDescribeAcls(request: RequestChannel.Request): Unit = { + authHelper.authorizeClusterOperation(request, DESCRIBE) + val describeAclsRequest = request.body[DescribeAclsRequest] + authorizer match { + case None => + requestHelper.sendResponseMaybeThrottle(request, requestThrottleMs => + new DescribeAclsResponse(new DescribeAclsResponseData() + .setErrorCode(Errors.SECURITY_DISABLED.code) + .setErrorMessage("No Authorizer is configured on the broker") + .setThrottleTimeMs(requestThrottleMs), + describeAclsRequest.version)) + case Some(auth) => + val filter = describeAclsRequest.filter + val returnedAcls = new util.HashSet[AclBinding]() + auth.acls(filter).forEach(returnedAcls.add) + requestHelper.sendResponseMaybeThrottle(request, requestThrottleMs => + new DescribeAclsResponse(new DescribeAclsResponseData() + .setThrottleTimeMs(requestThrottleMs) + .setResources(DescribeAclsResponse.aclsResources(returnedAcls)), + describeAclsRequest.version)) + } + } + + def handleCreateAcls(request: RequestChannel.Request): Unit = { + authHelper.authorizeClusterOperation(request, ALTER) + val createAclsRequest = request.body[CreateAclsRequest] + + authorizer match { + case None => requestHelper.sendResponseMaybeThrottle(request, requestThrottleMs => + createAclsRequest.getErrorResponse(requestThrottleMs, + new SecurityDisabledException("No Authorizer is configured."))) + case Some(auth) => + val allBindings = createAclsRequest.aclCreations.asScala.map(CreateAclsRequest.aclBinding) + val errorResults = mutable.Map[AclBinding, AclCreateResult]() + val validBindings = new ArrayBuffer[AclBinding] + allBindings.foreach { acl => + val resource = acl.pattern + val throwable = if (resource.resourceType == ResourceType.CLUSTER && !AuthorizerUtils.isClusterResource(resource.name)) + new InvalidRequestException("The only valid name for the CLUSTER resource is " + CLUSTER_NAME) + else if (resource.name.isEmpty) + new InvalidRequestException("Invalid empty resource name") + else + null + if (throwable != null) { + debug(s"Failed to add acl $acl to $resource", throwable) + errorResults(acl) = new AclCreateResult(throwable) + } else + validBindings += acl + } + + val createResults = auth.createAcls(request.context, validBindings.asJava).asScala.map(_.toCompletableFuture) + + def sendResponseCallback(): Unit = { + val aclCreationResults = allBindings.map { acl => + val result = errorResults.getOrElse(acl, createResults(validBindings.indexOf(acl)).get) + val creationResult = new AclCreationResult() + result.exception.asScala.foreach { throwable => + val apiError = ApiError.fromThrowable(throwable) + creationResult + .setErrorCode(apiError.error.code) + .setErrorMessage(apiError.message) + } + creationResult + } + requestHelper.sendResponseMaybeThrottle(request, requestThrottleMs => + new CreateAclsResponse(new CreateAclsResponseData() + .setThrottleTimeMs(requestThrottleMs) + .setResults(aclCreationResults.asJava))) + } + + alterAclsPurgatory.tryCompleteElseWatch(config.connectionsMaxIdleMs, createResults, sendResponseCallback) + } + } + + def 
handleDeleteAcls(request: RequestChannel.Request): Unit = { + authHelper.authorizeClusterOperation(request, ALTER) + val deleteAclsRequest = request.body[DeleteAclsRequest] + authorizer match { + case None => + requestHelper.sendResponseMaybeThrottle(request, requestThrottleMs => + deleteAclsRequest.getErrorResponse(requestThrottleMs, + new SecurityDisabledException("No Authorizer is configured."))) + case Some(auth) => + + val deleteResults = auth.deleteAcls(request.context, deleteAclsRequest.filters) + .asScala.map(_.toCompletableFuture).toList + + def sendResponseCallback(): Unit = { + val filterResults = deleteResults.map(_.get).map(DeleteAclsResponse.filterResult).asJava + requestHelper.sendResponseMaybeThrottle(request, requestThrottleMs => + new DeleteAclsResponse( + new DeleteAclsResponseData() + .setThrottleTimeMs(requestThrottleMs) + .setFilterResults(filterResults), + deleteAclsRequest.version)) + } + alterAclsPurgatory.tryCompleteElseWatch(config.connectionsMaxIdleMs, deleteResults, sendResponseCallback) + } + } +} diff --git a/core/src/main/scala/kafka/server/ActionQueue.scala b/core/src/main/scala/kafka/server/ActionQueue.scala new file mode 100644 index 0000000..1b6b832 --- /dev/null +++ b/core/src/main/scala/kafka/server/ActionQueue.scala @@ -0,0 +1,56 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.server + +import java.util.concurrent.ConcurrentLinkedQueue + +import kafka.utils.Logging + +/** + * This queue is used to collect actions which need to be executed later. One use case is that ReplicaManager#appendRecords + * produces record changes so we need to check and complete delayed requests. In order to avoid conflicting locking, + * we add those actions to this queue and then complete them at the end of KafkaApis.handle() or DelayedJoin.onExpiration. + */ +class ActionQueue extends Logging { + private val queue = new ConcurrentLinkedQueue[() => Unit]() + + /** + * add action to this queue. 
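+ *
+ * A minimal usage sketch (illustrative only):
+ * {{{
+ * val actionQueue = new ActionQueue
+ * actionQueue.add(() => println("complete delayed operations here"))
+ * // later, e.g. at the end of KafkaApis.handle():
+ * actionQueue.tryCompleteActions()
+ * }}}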
+ * @param action action + */ + def add(action: () => Unit): Unit = queue.add(action) + + /** + * try to complete all delayed actions + */ + def tryCompleteActions(): Unit = { + val maxToComplete = queue.size() + var count = 0 + var done = false + while (!done && count < maxToComplete) { + try { + val action = queue.poll() + if (action == null) done = true + else action() + } catch { + case e: Throwable => + error("failed to complete delayed actions", e) + } finally count += 1 + } + } +} diff --git a/core/src/main/scala/kafka/server/AlterIsrManager.scala b/core/src/main/scala/kafka/server/AlterIsrManager.scala new file mode 100644 index 0000000..b8507d0 --- /dev/null +++ b/core/src/main/scala/kafka/server/AlterIsrManager.scala @@ -0,0 +1,294 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package kafka.server + +import java.util +import java.util.concurrent.atomic.AtomicBoolean +import java.util.concurrent.{CompletableFuture, ConcurrentHashMap, TimeUnit} + +import kafka.api.LeaderAndIsr +import kafka.metrics.KafkaMetricsGroup +import kafka.utils.{KafkaScheduler, Logging, Scheduler} +import kafka.zk.KafkaZkClient +import org.apache.kafka.clients.ClientResponse +import org.apache.kafka.common.TopicPartition +import org.apache.kafka.common.errors.OperationNotAttemptedException +import org.apache.kafka.common.message.{AlterIsrRequestData, AlterIsrResponseData} +import org.apache.kafka.common.metrics.Metrics +import org.apache.kafka.common.protocol.Errors +import org.apache.kafka.common.requests.{AlterIsrRequest, AlterIsrResponse} +import org.apache.kafka.common.utils.Time + +import scala.collection.mutable +import scala.collection.mutable.ListBuffer +import scala.jdk.CollectionConverters._ + +/** + * Handles updating the ISR by sending AlterIsr requests to the controller (as of 2.7) or by updating ZK directly + * (prior to 2.7). Updating the ISR is an asynchronous operation, so partitions will learn about the result of their + * request through a callback. + * + * Note that ISR state changes can still be initiated by the controller and sent to the partitions via LeaderAndIsr + * requests. 
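+ *
+ * A hedged usage sketch (alterIsrManager, topicPartition, newLeaderAndIsr and controllerEpoch are
+ * assumed to be in scope; only the callback shape is being illustrated):
+ * {{{
+ * val future = alterIsrManager.submit(topicPartition, newLeaderAndIsr, controllerEpoch)
+ * future.whenComplete { (updatedLeaderAndIsr, error) =>
+ *   if (error != null) {
+ *     // the update could not be applied (e.g. fenced or invalid); the caller decides how to react
+ *   } else {
+ *     // updatedLeaderAndIsr is the ISR state committed by the controller (or ZK for the legacy path)
+ *   }
+ * }
+ * }}}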
+ */ +trait AlterIsrManager { + def start(): Unit = {} + + def shutdown(): Unit = {} + + def submit( + topicPartition: TopicPartition, + leaderAndIsr: LeaderAndIsr, + controllerEpoch: Int + ): CompletableFuture[LeaderAndIsr] +} + +case class AlterIsrItem(topicPartition: TopicPartition, + leaderAndIsr: LeaderAndIsr, + future: CompletableFuture[LeaderAndIsr], + controllerEpoch: Int) // controllerEpoch needed for Zk impl + +object AlterIsrManager { + + /** + * Factory to AlterIsr based implementation, used when IBP >= 2.7-IV2 + */ + def apply( + config: KafkaConfig, + metadataCache: MetadataCache, + scheduler: KafkaScheduler, + time: Time, + metrics: Metrics, + threadNamePrefix: Option[String], + brokerEpochSupplier: () => Long, + brokerId: Int + ): AlterIsrManager = { + val nodeProvider = MetadataCacheControllerNodeProvider(config, metadataCache) + + val channelManager = BrokerToControllerChannelManager( + controllerNodeProvider = nodeProvider, + time = time, + metrics = metrics, + config = config, + channelName = "alterIsr", + threadNamePrefix = threadNamePrefix, + retryTimeoutMs = Long.MaxValue + ) + new DefaultAlterIsrManager( + controllerChannelManager = channelManager, + scheduler = scheduler, + time = time, + brokerId = brokerId, + brokerEpochSupplier = brokerEpochSupplier + ) + } + + /** + * Factory for ZK based implementation, used when IBP < 2.7-IV2 + */ + def apply( + scheduler: Scheduler, + time: Time, + zkClient: KafkaZkClient + ): AlterIsrManager = { + new ZkIsrManager(scheduler, time, zkClient) + } + +} + +class DefaultAlterIsrManager( + val controllerChannelManager: BrokerToControllerChannelManager, + val scheduler: Scheduler, + val time: Time, + val brokerId: Int, + val brokerEpochSupplier: () => Long +) extends AlterIsrManager with Logging with KafkaMetricsGroup { + + // Used to allow only one pending ISR update per partition (visible for testing) + private[server] val unsentIsrUpdates: util.Map[TopicPartition, AlterIsrItem] = new ConcurrentHashMap[TopicPartition, AlterIsrItem]() + + // Used to allow only one in-flight request at a time + private val inflightRequest: AtomicBoolean = new AtomicBoolean(false) + + override def start(): Unit = { + controllerChannelManager.start() + } + + override def shutdown(): Unit = { + controllerChannelManager.shutdown() + } + + override def submit( + topicPartition: TopicPartition, + leaderAndIsr: LeaderAndIsr, + controllerEpoch: Int + ): CompletableFuture[LeaderAndIsr] = { + val future = new CompletableFuture[LeaderAndIsr]() + val alterIsrItem = AlterIsrItem(topicPartition, leaderAndIsr, future, controllerEpoch) + val enqueued = unsentIsrUpdates.putIfAbsent(alterIsrItem.topicPartition, alterIsrItem) == null + if (enqueued) { + maybePropagateIsrChanges() + } else { + future.completeExceptionally(new OperationNotAttemptedException( + s"Failed to enqueue ISR change state $leaderAndIsr for partition $topicPartition")) + } + future + } + + private[server] def maybePropagateIsrChanges(): Unit = { + // Send all pending items if there is not already a request in-flight. 
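+ // The compareAndSet below acts as a lightweight mutex: only the caller that flips the in-flight
+ // flag from false to true builds and sends a request. The flag is cleared by the response handler
+ // (clearInFlightRequest), which then re-checks the queue for updates enqueued in the meantime.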
+ if (!unsentIsrUpdates.isEmpty && inflightRequest.compareAndSet(false, true)) { + // Copy current unsent ISRs but don't remove from the map, they get cleared in the response handler + val inflightAlterIsrItems = new ListBuffer[AlterIsrItem]() + unsentIsrUpdates.values().forEach(item => inflightAlterIsrItems.append(item)) + sendRequest(inflightAlterIsrItems.toSeq) + } + } + + private[server] def clearInFlightRequest(): Unit = { + if (!inflightRequest.compareAndSet(true, false)) { + warn("Attempting to clear AlterIsr in-flight flag when no apparent request is in-flight") + } + } + + private def sendRequest(inflightAlterIsrItems: Seq[AlterIsrItem]): Unit = { + val message = buildRequest(inflightAlterIsrItems) + debug(s"Sending AlterIsr to controller $message") + + // We will not timeout AlterISR request, instead letting it retry indefinitely + // until a response is received, or a new LeaderAndIsr overwrites the existing isrState + // which causes the response for those partitions to be ignored. + controllerChannelManager.sendRequest(new AlterIsrRequest.Builder(message), + new ControllerRequestCompletionHandler { + override def onComplete(response: ClientResponse): Unit = { + debug(s"Received AlterIsr response $response") + val error = try { + if (response.authenticationException != null) { + // For now we treat authentication errors as retriable. We use the + // `NETWORK_EXCEPTION` error code for lack of a good alternative. + // Note that `BrokerToControllerChannelManager` will still log the + // authentication errors so that users have a chance to fix the problem. + Errors.NETWORK_EXCEPTION + } else if (response.versionMismatch != null) { + Errors.UNSUPPORTED_VERSION + } else { + val body = response.responseBody().asInstanceOf[AlterIsrResponse] + handleAlterIsrResponse(body, message.brokerEpoch, inflightAlterIsrItems) + } + } finally { + // clear the flag so future requests can proceed + clearInFlightRequest() + } + + // check if we need to send another request right away + error match { + case Errors.NONE => + // In the normal case, check for pending updates to send immediately + maybePropagateIsrChanges() + case _ => + // If we received a top-level error from the controller, retry the request in the near future + scheduler.schedule("send-alter-isr", () => maybePropagateIsrChanges(), 50, -1, TimeUnit.MILLISECONDS) + } + } + + override def onTimeout(): Unit = { + throw new IllegalStateException("Encountered unexpected timeout when sending AlterIsr to the controller") + } + }) + } + + private def buildRequest(inflightAlterIsrItems: Seq[AlterIsrItem]): AlterIsrRequestData = { + val message = new AlterIsrRequestData() + .setBrokerId(brokerId) + .setBrokerEpoch(brokerEpochSupplier.apply()) + .setTopics(new util.ArrayList()) + + inflightAlterIsrItems.groupBy(_.topicPartition.topic).foreach(entry => { + val topicPart = new AlterIsrRequestData.TopicData() + .setName(entry._1) + .setPartitions(new util.ArrayList()) + message.topics().add(topicPart) + entry._2.foreach(item => { + topicPart.partitions().add(new AlterIsrRequestData.PartitionData() + .setPartitionIndex(item.topicPartition.partition) + .setLeaderEpoch(item.leaderAndIsr.leaderEpoch) + .setNewIsr(item.leaderAndIsr.isr.map(Integer.valueOf).asJava) + .setCurrentIsrVersion(item.leaderAndIsr.zkVersion) + ) + }) + }) + message + } + + def handleAlterIsrResponse(alterIsrResponse: AlterIsrResponse, + sentBrokerEpoch: Long, + inflightAlterIsrItems: Seq[AlterIsrItem]): Errors = { + val data: AlterIsrResponseData = alterIsrResponse.data + + 
Errors.forCode(data.errorCode) match { + case Errors.STALE_BROKER_EPOCH => + warn(s"Broker had a stale broker epoch ($sentBrokerEpoch), retrying.") + case Errors.CLUSTER_AUTHORIZATION_FAILED => + error(s"Broker is not authorized to send AlterIsr to controller", + Errors.CLUSTER_AUTHORIZATION_FAILED.exception("Broker is not authorized to send AlterIsr to controller")) + case Errors.NONE => + // Collect partition-level responses to pass to the callbacks + val partitionResponses: mutable.Map[TopicPartition, Either[Errors, LeaderAndIsr]] = + new mutable.HashMap[TopicPartition, Either[Errors, LeaderAndIsr]]() + data.topics.forEach { topic => + topic.partitions().forEach(partition => { + val tp = new TopicPartition(topic.name, partition.partitionIndex) + val error = Errors.forCode(partition.errorCode()) + debug(s"Controller successfully handled AlterIsr request for $tp: $partition") + if (error == Errors.NONE) { + val newLeaderAndIsr = new LeaderAndIsr(partition.leaderId, partition.leaderEpoch, + partition.isr.asScala.toList.map(_.toInt), partition.currentIsrVersion) + partitionResponses(tp) = Right(newLeaderAndIsr) + } else { + partitionResponses(tp) = Left(error) + } + }) + } + + // Iterate across the items we sent rather than what we received to ensure we run the callback even if a + // partition was somehow erroneously excluded from the response. Note that these callbacks are run from + // the leaderIsrUpdateLock write lock in Partition#sendAlterIsrRequest + inflightAlterIsrItems.foreach { inflightAlterIsr => + partitionResponses.get(inflightAlterIsr.topicPartition) match { + case Some(leaderAndIsrOrError) => + try { + leaderAndIsrOrError match { + case Left(error) => inflightAlterIsr.future.completeExceptionally(error.exception) + case Right(leaderAndIsr) => inflightAlterIsr.future.complete(leaderAndIsr) + } + } finally { + // Regardless of callback outcome, we need to clear from the unsent updates map to unblock further updates + unsentIsrUpdates.remove(inflightAlterIsr.topicPartition) + } + case None => + // Don't remove this partition from the update map so it will get re-sent + warn(s"Partition ${inflightAlterIsr.topicPartition} was sent but not included in the response") + } + } + + case e: Errors => + warn(s"Controller returned an unexpected top-level error when handling AlterIsr request: $e") + } + + Errors.forCode(data.errorCode) + } +} diff --git a/core/src/main/scala/kafka/server/ApiVersionManager.scala b/core/src/main/scala/kafka/server/ApiVersionManager.scala new file mode 100644 index 0000000..640e98d --- /dev/null +++ b/core/src/main/scala/kafka/server/ApiVersionManager.scala @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package kafka.server + +import kafka.api.ApiVersion +import kafka.network +import kafka.network.RequestChannel +import org.apache.kafka.common.message.ApiMessageType.ListenerType +import org.apache.kafka.common.message.ApiVersionsResponseData +import org.apache.kafka.common.protocol.ApiKeys +import org.apache.kafka.common.requests.ApiVersionsResponse + +import scala.jdk.CollectionConverters._ + +trait ApiVersionManager { + def listenerType: ListenerType + def enabledApis: collection.Set[ApiKeys] + def apiVersionResponse(throttleTimeMs: Int): ApiVersionsResponse + def isApiEnabled(apiKey: ApiKeys): Boolean = enabledApis.contains(apiKey) + def newRequestMetrics: RequestChannel.Metrics = new network.RequestChannel.Metrics(enabledApis) +} + +object ApiVersionManager { + def apply( + listenerType: ListenerType, + config: KafkaConfig, + forwardingManager: Option[ForwardingManager], + features: BrokerFeatures, + featureCache: FinalizedFeatureCache + ): ApiVersionManager = { + new DefaultApiVersionManager( + listenerType, + config.interBrokerProtocolVersion, + forwardingManager, + features, + featureCache + ) + } +} + +class SimpleApiVersionManager( + val listenerType: ListenerType, + val enabledApis: collection.Set[ApiKeys] +) extends ApiVersionManager { + + def this(listenerType: ListenerType) = { + this(listenerType, ApiKeys.apisForListener(listenerType).asScala) + } + + private val apiVersions = ApiVersionsResponse.collectApis(enabledApis.asJava) + + override def apiVersionResponse(requestThrottleMs: Int): ApiVersionsResponse = { + ApiVersionsResponse.createApiVersionsResponse(0, apiVersions) + } +} + +class DefaultApiVersionManager( + val listenerType: ListenerType, + interBrokerProtocolVersion: ApiVersion, + forwardingManager: Option[ForwardingManager], + features: BrokerFeatures, + featureCache: FinalizedFeatureCache +) extends ApiVersionManager { + + override def apiVersionResponse(throttleTimeMs: Int): ApiVersionsResponse = { + val supportedFeatures = features.supportedFeatures + val finalizedFeaturesOpt = featureCache.get + val controllerApiVersions = forwardingManager.flatMap(_.controllerApiVersions) + + val response = finalizedFeaturesOpt match { + case Some(finalizedFeatures) => ApiVersion.apiVersionsResponse( + throttleTimeMs, + interBrokerProtocolVersion.recordVersion, + supportedFeatures, + finalizedFeatures.features, + finalizedFeatures.epoch, + controllerApiVersions, + listenerType) + case None => ApiVersion.apiVersionsResponse( + throttleTimeMs, + interBrokerProtocolVersion.recordVersion, + supportedFeatures, + controllerApiVersions, + listenerType) + } + + // This is a temporary workaround in order to allow testing of forwarding + // in integration tests. We can remove this after the KRaft controller + // is available for integration testing. 
+ if (forwardingManager.isDefined) { + response.data.apiKeys.add( + new ApiVersionsResponseData.ApiVersion() + .setApiKey(ApiKeys.ENVELOPE.id) + .setMinVersion(ApiKeys.ENVELOPE.oldestVersion) + .setMaxVersion(ApiKeys.ENVELOPE.latestVersion) + ) + } + + response + } + + override def enabledApis: collection.Set[ApiKeys] = { + forwardingManager match { + case Some(_) => ApiKeys.apisForListener(listenerType).asScala ++ Set(ApiKeys.ENVELOPE) + case None => ApiKeys.apisForListener(listenerType).asScala + } + } + + override def isApiEnabled(apiKey: ApiKeys): Boolean = { + apiKey.inScope(listenerType) || (apiKey == ApiKeys.ENVELOPE && forwardingManager.isDefined) + } +} diff --git a/core/src/main/scala/kafka/server/AuthHelper.scala b/core/src/main/scala/kafka/server/AuthHelper.scala new file mode 100644 index 0000000..50a1351 --- /dev/null +++ b/core/src/main/scala/kafka/server/AuthHelper.scala @@ -0,0 +1,133 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.server + +import java.lang.{Byte => JByte} +import java.util.Collections +import kafka.network.RequestChannel +import kafka.security.authorizer.AclEntry +import kafka.utils.CoreUtils +import org.apache.kafka.common.acl.AclOperation +import org.apache.kafka.common.errors.ClusterAuthorizationException +import org.apache.kafka.common.requests.RequestContext +import org.apache.kafka.common.resource.Resource.CLUSTER_NAME +import org.apache.kafka.common.resource.ResourceType.CLUSTER +import org.apache.kafka.common.resource.{PatternType, Resource, ResourcePattern, ResourceType} +import org.apache.kafka.common.utils.Utils +import org.apache.kafka.server.authorizer.{Action, AuthorizationResult, Authorizer} + +import scala.collection.{Map, Seq} +import scala.jdk.CollectionConverters._ + +class AuthHelper(authorizer: Option[Authorizer]) { + + def authorize(requestContext: RequestContext, + operation: AclOperation, + resourceType: ResourceType, + resourceName: String, + logIfAllowed: Boolean = true, + logIfDenied: Boolean = true, + refCount: Int = 1): Boolean = { + authorizer.forall { authZ => + val resource = new ResourcePattern(resourceType, resourceName, PatternType.LITERAL) + val actions = Collections.singletonList(new Action(operation, resource, refCount, logIfAllowed, logIfDenied)) + authZ.authorize(requestContext, actions).get(0) == AuthorizationResult.ALLOWED + } + } + + def authorizeClusterOperation(request: RequestChannel.Request, operation: AclOperation): Unit = { + if (!authorize(request.context, operation, CLUSTER, CLUSTER_NAME)) + throw new ClusterAuthorizationException(s"Request $request is not authorized.") + } + + def authorizedOperations(request: RequestChannel.Request, resource: Resource): Int = { + val supportedOps = 
AclEntry.supportedOperations(resource.resourceType).toList + val authorizedOps = authorizer match { + case Some(authZ) => + val resourcePattern = new ResourcePattern(resource.resourceType, resource.name, PatternType.LITERAL) + val actions = supportedOps.map { op => new Action(op, resourcePattern, 1, false, false) } + authZ.authorize(request.context, actions.asJava).asScala + .zip(supportedOps) + .filter(_._1 == AuthorizationResult.ALLOWED) + .map(_._2).toSet + case None => + supportedOps.toSet + } + Utils.to32BitField(authorizedOps.map(operation => operation.code.asInstanceOf[JByte]).asJava) + } + + def authorizeByResourceType(requestContext: RequestContext, operation: AclOperation, + resourceType: ResourceType): Boolean = { + authorizer.forall { authZ => + authZ.authorizeByResourceType(requestContext, operation, resourceType) == AuthorizationResult.ALLOWED + } + } + + def partitionSeqByAuthorized[T](requestContext: RequestContext, + operation: AclOperation, + resourceType: ResourceType, + resources: Seq[T], + logIfAllowed: Boolean = true, + logIfDenied: Boolean = true)(resourceName: T => String): (Seq[T], Seq[T]) = { + authorizer match { + case Some(_) => + val authorizedResourceNames = filterByAuthorized(requestContext, operation, resourceType, + resources, logIfAllowed, logIfDenied)(resourceName) + resources.partition(resource => authorizedResourceNames.contains(resourceName(resource))) + case None => (resources, Seq.empty) + } + } + + def partitionMapByAuthorized[K, V](requestContext: RequestContext, + operation: AclOperation, + resourceType: ResourceType, + resources: Map[K, V], + logIfAllowed: Boolean = true, + logIfDenied: Boolean = true)(resourceName: K => String): (Map[K, V], Map[K, V]) = { + authorizer match { + case Some(_) => + val authorizedResourceNames = filterByAuthorized(requestContext, operation, resourceType, + resources.keySet, logIfAllowed, logIfDenied)(resourceName) + resources.partition { case (k, _) => authorizedResourceNames.contains(resourceName(k)) } + case None => (resources, Map.empty) + } + } + + def filterByAuthorized[T](requestContext: RequestContext, + operation: AclOperation, + resourceType: ResourceType, + resources: Iterable[T], + logIfAllowed: Boolean = true, + logIfDenied: Boolean = true)(resourceName: T => String): Set[String] = { + authorizer match { + case Some(authZ) => + val resourceNameToCount = CoreUtils.groupMapReduce(resources)(resourceName)(_ => 1)(_ + _) + val actions = resourceNameToCount.iterator.map { case (resourceName, count) => + val resource = new ResourcePattern(resourceType, resourceName, PatternType.LITERAL) + new Action(operation, resource, count, logIfAllowed, logIfDenied) + }.toBuffer + authZ.authorize(requestContext, actions.asJava).asScala + .zip(resourceNameToCount.keySet) + .collect { case (authzResult, resourceName) if authzResult == AuthorizationResult.ALLOWED => + resourceName + }.toSet + case None => resources.iterator.map(resourceName).toSet + } + } + +} diff --git a/core/src/main/scala/kafka/server/AutoTopicCreationManager.scala b/core/src/main/scala/kafka/server/AutoTopicCreationManager.scala new file mode 100644 index 0000000..60796ab --- /dev/null +++ b/core/src/main/scala/kafka/server/AutoTopicCreationManager.scala @@ -0,0 +1,307 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.server + +import java.util.concurrent.ConcurrentHashMap +import java.util.concurrent.atomic.AtomicReference +import java.util.{Collections, Properties} + +import kafka.controller.KafkaController +import kafka.coordinator.group.GroupCoordinator +import kafka.coordinator.transaction.TransactionCoordinator +import kafka.utils.Logging +import org.apache.kafka.clients.ClientResponse +import org.apache.kafka.common.errors.InvalidTopicException +import org.apache.kafka.common.internals.Topic +import org.apache.kafka.common.internals.Topic.{GROUP_METADATA_TOPIC_NAME, TRANSACTION_STATE_TOPIC_NAME} +import org.apache.kafka.common.message.CreateTopicsRequestData +import org.apache.kafka.common.message.CreateTopicsRequestData.{CreatableTopic, CreateableTopicConfig, CreateableTopicConfigCollection} +import org.apache.kafka.common.message.MetadataResponseData.MetadataResponseTopic +import org.apache.kafka.common.protocol.{ApiKeys, Errors} +import org.apache.kafka.common.requests.{ApiError, CreateTopicsRequest, RequestContext, RequestHeader} + +import scala.collection.{Map, Seq, Set, mutable} +import scala.jdk.CollectionConverters._ + +trait AutoTopicCreationManager { + + def createTopics( + topicNames: Set[String], + controllerMutationQuota: ControllerMutationQuota, + metadataRequestContext: Option[RequestContext] + ): Seq[MetadataResponseTopic] +} + +object AutoTopicCreationManager { + + def apply( + config: KafkaConfig, + metadataCache: MetadataCache, + threadNamePrefix: Option[String], + channelManager: Option[BrokerToControllerChannelManager], + adminManager: Option[ZkAdminManager], + controller: Option[KafkaController], + groupCoordinator: GroupCoordinator, + txnCoordinator: TransactionCoordinator, + ): AutoTopicCreationManager = { + new DefaultAutoTopicCreationManager(config, channelManager, adminManager, + controller, groupCoordinator, txnCoordinator) + } +} + +class DefaultAutoTopicCreationManager( + config: KafkaConfig, + channelManager: Option[BrokerToControllerChannelManager], + adminManager: Option[ZkAdminManager], + controller: Option[KafkaController], + groupCoordinator: GroupCoordinator, + txnCoordinator: TransactionCoordinator +) extends AutoTopicCreationManager with Logging { + if (controller.isEmpty && channelManager.isEmpty) { + throw new IllegalArgumentException("Must supply a channel manager if not supplying a controller") + } + + private val inflightTopics = Collections.newSetFromMap(new ConcurrentHashMap[String, java.lang.Boolean]()) + + /** + * Initiate auto topic creation for the given topics. + * + * @param topics the topics to create + * @param controllerMutationQuota the controller mutation quota for topic creation + * @param metadataRequestContext defined when creating topics on behalf of the client. 
The goal here is to preserve + * original client principal for auditing, thus needing to wrap a plain CreateTopicsRequest + * inside Envelope to send to the controller when forwarding is enabled. + * @return auto created topic metadata responses + */ + override def createTopics( + topics: Set[String], + controllerMutationQuota: ControllerMutationQuota, + metadataRequestContext: Option[RequestContext] + ): Seq[MetadataResponseTopic] = { + val (creatableTopics, uncreatableTopicResponses) = filterCreatableTopics(topics) + + val creatableTopicResponses = if (creatableTopics.isEmpty) { + Seq.empty + } else if (controller.isEmpty || !controller.get.isActive && channelManager.isDefined) { + sendCreateTopicRequest(creatableTopics, metadataRequestContext) + } else { + createTopicsInZk(creatableTopics, controllerMutationQuota) + } + + uncreatableTopicResponses ++ creatableTopicResponses + } + + private def createTopicsInZk( + creatableTopics: Map[String, CreatableTopic], + controllerMutationQuota: ControllerMutationQuota + ): Seq[MetadataResponseTopic] = { + val topicErrors = new AtomicReference[Map[String, ApiError]]() + try { + // Note that we use timeout = 0 since we do not need to wait for metadata propagation + // and we want to get the response error immediately. + adminManager.get.createTopics( + timeout = 0, + validateOnly = false, + creatableTopics, + Map.empty, + controllerMutationQuota, + topicErrors.set + ) + + val creatableTopicResponses = Option(topicErrors.get) match { + case Some(errors) => + errors.toSeq.map { case (topic, apiError) => + val error = apiError.error match { + case Errors.TOPIC_ALREADY_EXISTS | Errors.REQUEST_TIMED_OUT => + // The timeout error is expected because we set timeout=0. This + // nevertheless indicates that the topic metadata was created + // successfully, so we return LEADER_NOT_AVAILABLE. 
+ Errors.LEADER_NOT_AVAILABLE + case error => error + } + + new MetadataResponseTopic() + .setErrorCode(error.code) + .setName(topic) + .setIsInternal(Topic.isInternal(topic)) + } + + case None => + creatableTopics.keySet.toSeq.map { topic => + new MetadataResponseTopic() + .setErrorCode(Errors.UNKNOWN_TOPIC_OR_PARTITION.code) + .setName(topic) + .setIsInternal(Topic.isInternal(topic)) + } + } + + creatableTopicResponses + } finally { + clearInflightRequests(creatableTopics) + } + } + + private def sendCreateTopicRequest( + creatableTopics: Map[String, CreatableTopic], + metadataRequestContext: Option[RequestContext] + ): Seq[MetadataResponseTopic] = { + val topicsToCreate = new CreateTopicsRequestData.CreatableTopicCollection(creatableTopics.size) + topicsToCreate.addAll(creatableTopics.values.asJavaCollection) + + val createTopicsRequest = new CreateTopicsRequest.Builder( + new CreateTopicsRequestData() + .setTimeoutMs(config.requestTimeoutMs) + .setTopics(topicsToCreate) + ) + + val requestCompletionHandler = new ControllerRequestCompletionHandler { + override def onTimeout(): Unit = { + clearInflightRequests(creatableTopics) + debug(s"Auto topic creation timed out for ${creatableTopics.keys}.") + } + + override def onComplete(response: ClientResponse): Unit = { + clearInflightRequests(creatableTopics) + if (response.authenticationException() != null) { + warn(s"Auto topic creation failed for ${creatableTopics.keys} with authentication exception") + } else if (response.versionMismatch() != null) { + warn(s"Auto topic creation failed for ${creatableTopics.keys} with invalid version exception") + } else { + debug(s"Auto topic creation completed for ${creatableTopics.keys} with response ${response.responseBody}.") + } + } + } + + val channelManager = this.channelManager.getOrElse { + throw new IllegalStateException("Channel manager must be defined in order to send CreateTopic requests.") + } + + val request = metadataRequestContext.map { context => + val requestVersion = + channelManager.controllerApiVersions() match { + case None => + // We will rely on the Metadata request to be retried in the case + // that the latest version is not usable by the controller. + ApiKeys.CREATE_TOPICS.latestVersion() + case Some(nodeApiVersions) => + nodeApiVersions.latestUsableVersion(ApiKeys.CREATE_TOPICS) + } + + // Borrow client information such as client id and correlation id from the original request, + // in order to correlate the create request with the original metadata request. 
+ val requestHeader = new RequestHeader(ApiKeys.CREATE_TOPICS, + requestVersion, + context.clientId, + context.correlationId) + ForwardingManager.buildEnvelopeRequest(context, + createTopicsRequest.build(requestVersion).serializeWithHeader(requestHeader)) + }.getOrElse(createTopicsRequest) + + channelManager.sendRequest(request, requestCompletionHandler) + + val creatableTopicResponses = creatableTopics.keySet.toSeq.map { topic => + new MetadataResponseTopic() + .setErrorCode(Errors.UNKNOWN_TOPIC_OR_PARTITION.code) + .setName(topic) + .setIsInternal(Topic.isInternal(topic)) + } + + info(s"Sent auto-creation request for ${creatableTopics.keys} to the active controller.") + creatableTopicResponses + } + + private def clearInflightRequests(creatableTopics: Map[String, CreatableTopic]): Unit = { + creatableTopics.keySet.foreach(inflightTopics.remove) + debug(s"Cleared inflight topic creation state for $creatableTopics") + } + + private def creatableTopic(topic: String): CreatableTopic = { + topic match { + case GROUP_METADATA_TOPIC_NAME => + new CreatableTopic() + .setName(topic) + .setNumPartitions(config.offsetsTopicPartitions) + .setReplicationFactor(config.offsetsTopicReplicationFactor) + .setConfigs(convertToTopicConfigCollections(groupCoordinator.offsetsTopicConfigs)) + case TRANSACTION_STATE_TOPIC_NAME => + new CreatableTopic() + .setName(topic) + .setNumPartitions(config.transactionTopicPartitions) + .setReplicationFactor(config.transactionTopicReplicationFactor) + .setConfigs(convertToTopicConfigCollections( + txnCoordinator.transactionTopicConfigs)) + case topicName => + new CreatableTopic() + .setName(topicName) + .setNumPartitions(config.numPartitions) + .setReplicationFactor(config.defaultReplicationFactor.shortValue) + } + } + + private def convertToTopicConfigCollections(config: Properties): CreateableTopicConfigCollection = { + val topicConfigs = new CreateableTopicConfigCollection() + config.forEach { + case (name, value) => + topicConfigs.add(new CreateableTopicConfig() + .setName(name.toString) + .setValue(value.toString)) + } + topicConfigs + } + + private def isValidTopicName(topic: String): Boolean = { + try { + Topic.validate(topic) + true + } catch { + case _: InvalidTopicException => + false + } + } + + private def filterCreatableTopics( + topics: Set[String] + ): (Map[String, CreatableTopic], Seq[MetadataResponseTopic]) = { + + val creatableTopics = mutable.Map.empty[String, CreatableTopic] + val uncreatableTopics = mutable.Buffer.empty[MetadataResponseTopic] + + topics.foreach { topic => + // Attempt basic topic validation before sending any requests to the controller. + val validationError: Option[Errors] = if (!isValidTopicName(topic)) { + Some(Errors.INVALID_TOPIC_EXCEPTION) + } else if (!inflightTopics.add(topic)) { + Some(Errors.UNKNOWN_TOPIC_OR_PARTITION) + } else { + None + } + + validationError match { + case Some(error) => + uncreatableTopics += new MetadataResponseTopic() + .setErrorCode(error.code) + .setName(topic) + .setIsInternal(Topic.isInternal(topic)) + case None => + creatableTopics.put(topic, creatableTopic(topic)) + } + } + + (creatableTopics, uncreatableTopics) + } +} diff --git a/core/src/main/scala/kafka/server/BrokerFeatures.scala b/core/src/main/scala/kafka/server/BrokerFeatures.scala new file mode 100644 index 0000000..dd84f9e --- /dev/null +++ b/core/src/main/scala/kafka/server/BrokerFeatures.scala @@ -0,0 +1,116 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.server + +import kafka.utils.Logging +import org.apache.kafka.common.feature.{Features, FinalizedVersionRange, SupportedVersionRange} +import org.apache.kafka.common.feature.Features._ + +import scala.jdk.CollectionConverters._ + +/** + * A class that encapsulates the latest features supported by the Broker and also provides APIs to + * check for incompatibilities between the features supported by the Broker and finalized features. + * This class is immutable in production. It provides few APIs to mutate state only for the purpose + * of testing. + */ +class BrokerFeatures private (@volatile var supportedFeatures: Features[SupportedVersionRange]) { + // For testing only. + def setSupportedFeatures(newFeatures: Features[SupportedVersionRange]): Unit = { + supportedFeatures = newFeatures + } + + /** + * Returns the default finalized features that a new Kafka cluster with IBP config >= KAFKA_2_7_IV0 + * needs to be bootstrapped with. + */ + def defaultFinalizedFeatures: Features[FinalizedVersionRange] = { + Features.finalizedFeatures( + supportedFeatures.features.asScala.map { + case(name, versionRange) => ( + name, new FinalizedVersionRange(versionRange.min, versionRange.max)) + }.asJava) + } + + /** + * Returns the set of feature names found to be incompatible. + * A feature incompatibility is a version mismatch between the latest feature supported by the + * Broker, and a provided finalized feature. This can happen because a provided finalized + * feature: + * 1) Does not exist in the Broker (i.e. it is unknown to the Broker). + * [OR] + * 2) Exists but the FinalizedVersionRange does not match with the SupportedVersionRange + * of the supported feature. + * + * @param finalized The finalized features against which incompatibilities need to be checked for. + * + * @return The subset of input features which are incompatible. If the returned object + * is empty, it means there were no feature incompatibilities found. + */ + def incompatibleFeatures(finalized: Features[FinalizedVersionRange]): Features[FinalizedVersionRange] = { + BrokerFeatures.incompatibleFeatures(supportedFeatures, finalized, logIncompatibilities = true) + } +} + +object BrokerFeatures extends Logging { + + def createDefault(): BrokerFeatures = { + // The arguments are currently empty, but, in the future as we define features we should + // populate the required values here. + new BrokerFeatures(emptySupportedFeatures) + } + + /** + * Returns true if any of the provided finalized features are incompatible with the provided + * supported features. + * + * @param supportedFeatures The supported features to be compared + * @param finalizedFeatures The finalized features to be compared + * + * @return - True if there are any feature incompatibilities found. + * - False otherwise. 
+ */ + def hasIncompatibleFeatures(supportedFeatures: Features[SupportedVersionRange], + finalizedFeatures: Features[FinalizedVersionRange]): Boolean = { + !incompatibleFeatures(supportedFeatures, finalizedFeatures, logIncompatibilities = false).empty + } + + private def incompatibleFeatures(supportedFeatures: Features[SupportedVersionRange], + finalizedFeatures: Features[FinalizedVersionRange], + logIncompatibilities: Boolean): Features[FinalizedVersionRange] = { + val incompatibleFeaturesInfo = finalizedFeatures.features.asScala.map { + case (feature, versionLevels) => + val supportedVersions = supportedFeatures.get(feature) + if (supportedVersions == null) { + (feature, versionLevels, "{feature=%s, reason='Unsupported feature'}".format(feature)) + } else if (versionLevels.isIncompatibleWith(supportedVersions)) { + (feature, versionLevels, "{feature=%s, reason='%s is incompatible with %s'}".format( + feature, versionLevels, supportedVersions)) + } else { + (feature, versionLevels, null) + } + }.filter{ case(_, _, errorReason) => errorReason != null}.toList + + if (logIncompatibilities && incompatibleFeaturesInfo.nonEmpty) { + warn("Feature incompatibilities seen: " + + incompatibleFeaturesInfo.map { case(_, _, errorReason) => errorReason }.mkString(", ")) + } + Features.finalizedFeatures( + incompatibleFeaturesInfo.map { case(feature, versionLevels, _) => (feature, versionLevels) }.toMap.asJava) + } +} diff --git a/core/src/main/scala/kafka/server/BrokerLifecycleManager.scala b/core/src/main/scala/kafka/server/BrokerLifecycleManager.scala new file mode 100644 index 0000000..394c353 --- /dev/null +++ b/core/src/main/scala/kafka/server/BrokerLifecycleManager.scala @@ -0,0 +1,485 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package kafka.server + +import java.util +import java.util.concurrent.TimeUnit.{MILLISECONDS, NANOSECONDS} +import java.util.concurrent.CompletableFuture +import kafka.utils.Logging +import org.apache.kafka.clients.ClientResponse +import org.apache.kafka.common.Uuid +import org.apache.kafka.common.message.BrokerRegistrationRequestData.ListenerCollection +import org.apache.kafka.common.message.{BrokerHeartbeatRequestData, BrokerRegistrationRequestData} +import org.apache.kafka.common.protocol.Errors +import org.apache.kafka.common.requests.{BrokerHeartbeatRequest, BrokerHeartbeatResponse, BrokerRegistrationRequest, BrokerRegistrationResponse} +import org.apache.kafka.metadata.{BrokerState, VersionRange} +import org.apache.kafka.queue.EventQueue.DeadlineFunction +import org.apache.kafka.common.utils.{ExponentialBackoff, LogContext, Time} +import org.apache.kafka.queue.{EventQueue, KafkaEventQueue} +import scala.jdk.CollectionConverters._ + + +/** + * The broker lifecycle manager owns the broker state. 
+ * + * Its inputs are messages passed in from other parts of the broker and from the + * controller: requests to start up, or shut down, for example. Its output are the broker + * state and various futures that can be used to wait for broker state transitions to + * occur. + * + * The lifecycle manager handles registering the broker with the controller, as described + * in KIP-631. After registration is complete, it handles sending periodic broker + * heartbeats and processing the responses. + * + * This code uses an event queue paradigm. Modifications get translated into events, which + * are placed on the queue to be processed sequentially. As described in the JavaDoc for + * each variable, most mutable state can be accessed only from that event queue thread. + * In some cases we expose a volatile variable which can be read from any thread, but only + * written from the event queue thread. + */ +class BrokerLifecycleManager(val config: KafkaConfig, + val time: Time, + val threadNamePrefix: Option[String]) extends Logging { + val logContext = new LogContext(s"[BrokerLifecycleManager id=${config.nodeId}] ") + + this.logIdent = logContext.logPrefix() + + /** + * The broker id. + */ + private val nodeId = config.nodeId + + /** + * The broker rack, or null if there is no configured rack. + */ + private val rack = config.rack + + /** + * How long to wait for registration to succeed before failing the startup process. + */ + private val initialTimeoutNs = + MILLISECONDS.toNanos(config.initialRegistrationTimeoutMs.longValue()) + + /** + * The exponential backoff to use for resending communication. + */ + private val resendExponentialBackoff = + new ExponentialBackoff(100, 2, config.brokerSessionTimeoutMs.toLong, 0.02) + + /** + * The number of times we've tried and failed to communicate. This variable can only be + * read or written from the event queue thread. + */ + private var failedAttempts = 0L + + /** + * The broker incarnation ID. This ID uniquely identifies each time we start the broker + */ + val incarnationId = Uuid.randomUuid() + + /** + * A future which is completed just as soon as the broker has caught up with the latest + * metadata offset for the first time. + */ + val initialCatchUpFuture = new CompletableFuture[Void]() + + /** + * A future which is completed when controlled shutdown is done. + */ + val controlledShutdownFuture = new CompletableFuture[Void]() + + /** + * The broker epoch, or -1 if the broker has not yet registered. This variable can only + * be written from the event queue thread. + */ + @volatile private var _brokerEpoch = -1L + + /** + * The current broker state. This variable can only be written from the event queue + * thread. + */ + @volatile private var _state = BrokerState.NOT_RUNNING + + /** + * A thread-safe callback function which gives this manager the current highest metadata + * offset. This variable can only be read or written from the event queue thread. + */ + private var _highestMetadataOffsetProvider: () => Long = _ + + /** + * True only if we are ready to unfence the broker. This variable can only be read or + * written from the event queue thread. + */ + private var readyToUnfence = false + + /** + * True if we sent a event queue to the active controller requesting controlled + * shutdown. This variable can only be read or written from the event queue thread. + */ + private var gotControlledShutdownResponse = false + + /** + * Whether or not this broker is registered with the controller quorum. 
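+ * Registration is performed by sendBrokerRegistration() and is considered complete once a
+ * successful BrokerRegistrationResponse has been received from the active controller.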
+ * This variable can only be read or written from the event queue thread. + */ + private var registered = false + + /** + * True if the initial registration succeeded. This variable can only be read or + * written from the event queue thread. + */ + private var initialRegistrationSucceeded = false + + /** + * The cluster ID, or null if this manager has not been started yet. This variable can + * only be read or written from the event queue thread. + */ + private var _clusterId: String = _ + + /** + * The listeners which this broker advertises. This variable can only be read or + * written from the event queue thread. + */ + private var _advertisedListeners: ListenerCollection = _ + + /** + * The features supported by this broker. This variable can only be read or written + * from the event queue thread. + */ + private var _supportedFeatures: util.Map[String, VersionRange] = _ + + /** + * The channel manager, or null if this manager has not been started yet. This variable + * can only be read or written from the event queue thread. + */ + private var _channelManager: BrokerToControllerChannelManager = _ + + /** + * The event queue. + */ + private[server] val eventQueue = new KafkaEventQueue(time, logContext, threadNamePrefix.getOrElse("")) + + /** + * Start the BrokerLifecycleManager. + * + * @param highestMetadataOffsetProvider Provides the current highest metadata offset. + * @param channelManager The brokerToControllerChannelManager to use. + * @param clusterId The cluster ID. + */ + def start(highestMetadataOffsetProvider: () => Long, + channelManager: BrokerToControllerChannelManager, + clusterId: String, + advertisedListeners: ListenerCollection, + supportedFeatures: util.Map[String, VersionRange]): Unit = { + eventQueue.append(new StartupEvent(highestMetadataOffsetProvider, + channelManager, clusterId, advertisedListeners, supportedFeatures)) + } + + def setReadyToUnfence(): Unit = { + eventQueue.append(new SetReadyToUnfenceEvent()) + } + + def brokerEpoch: Long = _brokerEpoch + + def state: BrokerState = _state + + private class BeginControlledShutdownEvent extends EventQueue.Event { + override def run(): Unit = { + _state match { + case BrokerState.PENDING_CONTROLLED_SHUTDOWN => + info("Attempted to enter pending controlled shutdown state, but we are " + + "already in that state.") + case BrokerState.RUNNING => + info("Beginning controlled shutdown.") + _state = BrokerState.PENDING_CONTROLLED_SHUTDOWN + // Send the next heartbeat immediately in order to let the controller + // begin processing the controlled shutdown as soon as possible. + scheduleNextCommunicationImmediately() + + case _ => + info(s"Skipping controlled shutdown because we are in state ${_state}.") + beginShutdown() + } + } + } + + /** + * Enter the controlled shutdown state if we are in RUNNING state. + * Or, if we're not running, shut down immediately. + */ + def beginControlledShutdown(): Unit = { + eventQueue.append(new BeginControlledShutdownEvent()) + } + + /** + * Start shutting down the BrokerLifecycleManager, but do not block. + */ + def beginShutdown(): Unit = { + eventQueue.beginShutdown("beginShutdown", new ShutdownEvent()) + } + + /** + * Shut down the BrokerLifecycleManager and block until all threads are joined. 
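+ * Equivalent to calling beginShutdown() followed by closing the event queue.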
+ */ + def close(): Unit = { + beginShutdown() + eventQueue.close() + } + + private class SetReadyToUnfenceEvent() extends EventQueue.Event { + override def run(): Unit = { + readyToUnfence = true + scheduleNextCommunicationImmediately() + } + } + + private class StartupEvent(highestMetadataOffsetProvider: () => Long, + channelManager: BrokerToControllerChannelManager, + clusterId: String, + advertisedListeners: ListenerCollection, + supportedFeatures: util.Map[String, VersionRange]) extends EventQueue.Event { + override def run(): Unit = { + _highestMetadataOffsetProvider = highestMetadataOffsetProvider + _channelManager = channelManager + _channelManager.start() + _state = BrokerState.STARTING + _clusterId = clusterId + _advertisedListeners = advertisedListeners.duplicate() + _supportedFeatures = new util.HashMap[String, VersionRange](supportedFeatures) + eventQueue.scheduleDeferred("initialRegistrationTimeout", + new DeadlineFunction(time.nanoseconds() + initialTimeoutNs), + new RegistrationTimeoutEvent()) + sendBrokerRegistration() + info(s"Incarnation ${incarnationId} of broker ${nodeId} in cluster ${clusterId} " + + "is now STARTING.") + } + } + + private def sendBrokerRegistration(): Unit = { + val features = new BrokerRegistrationRequestData.FeatureCollection() + _supportedFeatures.asScala.foreach { + case (name, range) => features.add(new BrokerRegistrationRequestData.Feature(). + setName(name). + setMinSupportedVersion(range.min()). + setMaxSupportedVersion(range.max())) + } + val data = new BrokerRegistrationRequestData(). + setBrokerId(nodeId). + setClusterId(_clusterId). + setFeatures(features). + setIncarnationId(incarnationId). + setListeners(_advertisedListeners). + setRack(rack.orNull) + if (isDebugEnabled) { + debug(s"Sending broker registration ${data}") + } + _channelManager.sendRequest(new BrokerRegistrationRequest.Builder(data), + new BrokerRegistrationResponseHandler()) + } + + private class BrokerRegistrationResponseHandler extends ControllerRequestCompletionHandler { + override def onComplete(response: ClientResponse): Unit = { + if (response.authenticationException() != null) { + error(s"Unable to register broker ${nodeId} because of an authentication exception.", + response.authenticationException()); + scheduleNextCommunicationAfterFailure() + } else if (response.versionMismatch() != null) { + error(s"Unable to register broker ${nodeId} because of an API version problem.", + response.versionMismatch()); + scheduleNextCommunicationAfterFailure() + } else if (response.responseBody() == null) { + warn(s"Unable to register broker ${nodeId}.") + scheduleNextCommunicationAfterFailure() + } else if (!response.responseBody().isInstanceOf[BrokerRegistrationResponse]) { + error(s"Unable to register broker ${nodeId} because the controller returned an " + + "invalid response type.") + scheduleNextCommunicationAfterFailure() + } else { + val message = response.responseBody().asInstanceOf[BrokerRegistrationResponse] + val errorCode = Errors.forCode(message.data().errorCode()) + if (errorCode == Errors.NONE) { + failedAttempts = 0 + _brokerEpoch = message.data().brokerEpoch() + registered = true + initialRegistrationSucceeded = true + info(s"Successfully registered broker ${nodeId} with broker epoch ${_brokerEpoch}") + scheduleNextCommunicationImmediately() // Immediately send a heartbeat + } else { + info(s"Unable to register broker ${nodeId} because the controller returned " + + s"error ${errorCode}") + scheduleNextCommunicationAfterFailure() + } + } + } + + override def 
onTimeout(): Unit = { + info(s"Unable to register the broker because the RPC got timed out before it could be sent.") + scheduleNextCommunicationAfterFailure() + } + } + + private def sendBrokerHeartbeat(): Unit = { + val metadataOffset = _highestMetadataOffsetProvider() + val data = new BrokerHeartbeatRequestData(). + setBrokerEpoch(_brokerEpoch). + setBrokerId(nodeId). + setCurrentMetadataOffset(metadataOffset). + setWantFence(!readyToUnfence). + setWantShutDown(_state == BrokerState.PENDING_CONTROLLED_SHUTDOWN) + if (isTraceEnabled) { + trace(s"Sending broker heartbeat ${data}") + } + _channelManager.sendRequest(new BrokerHeartbeatRequest.Builder(data), + new BrokerHeartbeatResponseHandler()) + } + + private class BrokerHeartbeatResponseHandler extends ControllerRequestCompletionHandler { + override def onComplete(response: ClientResponse): Unit = { + if (response.authenticationException() != null) { + error(s"Unable to send broker heartbeat for ${nodeId} because of an " + + "authentication exception.", response.authenticationException()); + scheduleNextCommunicationAfterFailure() + } else if (response.versionMismatch() != null) { + error(s"Unable to send broker heartbeat for ${nodeId} because of an API " + + "version problem.", response.versionMismatch()); + scheduleNextCommunicationAfterFailure() + } else if (response.responseBody() == null) { + warn(s"Unable to send broker heartbeat for ${nodeId}. Retrying.") + scheduleNextCommunicationAfterFailure() + } else if (!response.responseBody().isInstanceOf[BrokerHeartbeatResponse]) { + error(s"Unable to send broker heartbeat for ${nodeId} because the controller " + + "returned an invalid response type.") + scheduleNextCommunicationAfterFailure() + } else { + val message = response.responseBody().asInstanceOf[BrokerHeartbeatResponse] + val errorCode = Errors.forCode(message.data().errorCode()) + if (errorCode == Errors.NONE) { + failedAttempts = 0 + _state match { + case BrokerState.STARTING => + if (message.data().isCaughtUp()) { + info(s"The broker has caught up. Transitioning from STARTING to RECOVERY.") + _state = BrokerState.RECOVERY + initialCatchUpFuture.complete(null) + } else { + debug(s"The broker is STARTING. Still waiting to catch up with cluster metadata.") + } + // Schedule the heartbeat after only 10 ms so that in the case where + // there is no recovery work to be done, we start up a bit quicker. + scheduleNextCommunication(NANOSECONDS.convert(10, MILLISECONDS)) + case BrokerState.RECOVERY => + if (!message.data().isFenced()) { + info(s"The broker has been unfenced. Transitioning from RECOVERY to RUNNING.") + _state = BrokerState.RUNNING + } else { + info(s"The broker is in RECOVERY.") + } + scheduleNextCommunicationAfterSuccess() + case BrokerState.RUNNING => + debug(s"The broker is RUNNING. Processing heartbeat response.") + scheduleNextCommunicationAfterSuccess() + case BrokerState.PENDING_CONTROLLED_SHUTDOWN => + if (!message.data().shouldShutDown()) { + info(s"The broker is in PENDING_CONTROLLED_SHUTDOWN state, still waiting " + + "for the active controller.") + if (!gotControlledShutdownResponse) { + // If this is the first pending controlled shutdown response we got, + // schedule our next heartbeat a little bit sooner than we usually would. + // In the case where controlled shutdown completes quickly, this will + // speed things up a little bit. 
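+ // (50 ms here, versus the configured broker heartbeat interval that
+ // scheduleNextCommunicationAfterSuccess would use on later responses.)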
+ scheduleNextCommunication(NANOSECONDS.convert(50, MILLISECONDS)) + } else { + scheduleNextCommunicationAfterSuccess() + } + } else { + info(s"The controller has asked us to exit controlled shutdown.") + beginShutdown() + } + gotControlledShutdownResponse = true + case BrokerState.SHUTTING_DOWN => + info(s"The broker is SHUTTING_DOWN. Ignoring heartbeat response.") + case _ => + error(s"Unexpected broker state ${_state}") + scheduleNextCommunicationAfterSuccess() + } + } else { + warn(s"Broker ${nodeId} sent a heartbeat request but received error ${errorCode}.") + scheduleNextCommunicationAfterFailure() + } + } + } + + override def onTimeout(): Unit = { + info("Unable to send a heartbeat because the RPC got timed out before it could be sent.") + scheduleNextCommunicationAfterFailure() + } + } + + private def scheduleNextCommunicationImmediately(): Unit = scheduleNextCommunication(0) + + private def scheduleNextCommunicationAfterFailure(): Unit = { + val delayMs = resendExponentialBackoff.backoff(failedAttempts) + failedAttempts = failedAttempts + 1 + scheduleNextCommunication(NANOSECONDS.convert(delayMs, MILLISECONDS)) + } + + private def scheduleNextCommunicationAfterSuccess(): Unit = { + scheduleNextCommunication(NANOSECONDS.convert( + config.brokerHeartbeatIntervalMs.longValue() , MILLISECONDS)) + } + + private def scheduleNextCommunication(intervalNs: Long): Unit = { + trace(s"Scheduling next communication at ${MILLISECONDS.convert(intervalNs, NANOSECONDS)} " + + "ms from now.") + val deadlineNs = time.nanoseconds() + intervalNs + eventQueue.scheduleDeferred("communication", + new DeadlineFunction(deadlineNs), + new CommunicationEvent()) + } + + private class RegistrationTimeoutEvent extends EventQueue.Event { + override def run(): Unit = { + if (!initialRegistrationSucceeded) { + error("Shutting down because we were unable to register with the controller quorum.") + eventQueue.beginShutdown("registrationTimeout", new ShutdownEvent()) + } + } + } + + private class CommunicationEvent extends EventQueue.Event { + override def run(): Unit = { + if (registered) { + sendBrokerHeartbeat() + } else { + sendBrokerRegistration() + } + } + } + + private class ShutdownEvent extends EventQueue.Event { + override def run(): Unit = { + info(s"Transitioning from ${_state} to ${BrokerState.SHUTTING_DOWN}.") + _state = BrokerState.SHUTTING_DOWN + controlledShutdownFuture.complete(null) + initialCatchUpFuture.cancel(false) + if (_channelManager != null) { + _channelManager.shutdown() + _channelManager = null + } + } + } +} diff --git a/core/src/main/scala/kafka/server/BrokerMetadataCheckpoint.scala b/core/src/main/scala/kafka/server/BrokerMetadataCheckpoint.scala new file mode 100755 index 0000000..e85144b --- /dev/null +++ b/core/src/main/scala/kafka/server/BrokerMetadataCheckpoint.scala @@ -0,0 +1,240 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.server + +import java.io._ +import java.nio.file.{Files, NoSuchFileException} +import java.util.Properties + +import kafka.common.{InconsistentBrokerMetadataException, KafkaException} +import kafka.server.RawMetaProperties._ +import kafka.utils._ +import org.apache.kafka.common.utils.Utils + +import scala.collection.mutable +import scala.jdk.CollectionConverters._ + +object RawMetaProperties { + val ClusterIdKey = "cluster.id" + val BrokerIdKey = "broker.id" + val NodeIdKey = "node.id" + val VersionKey = "version" +} + +class RawMetaProperties(val props: Properties = new Properties()) { + + def clusterId: Option[String] = { + Option(props.getProperty(ClusterIdKey)) + } + + def clusterId_=(id: String): Unit = { + props.setProperty(ClusterIdKey, id) + } + + def brokerId: Option[Int] = { + intValue(BrokerIdKey) + } + + def brokerId_=(id: Int): Unit = { + props.setProperty(BrokerIdKey, id.toString) + } + + def nodeId: Option[Int] = { + intValue(NodeIdKey) + } + + def nodeId_=(id: Int): Unit = { + props.setProperty(NodeIdKey, id.toString) + } + + def version: Int = { + intValue(VersionKey).getOrElse(0) + } + + def version_=(ver: Int): Unit = { + props.setProperty(VersionKey, ver.toString) + } + + def requireVersion(expectedVersion: Int): Unit = { + if (version != expectedVersion) { + throw new RuntimeException(s"Expected version $expectedVersion, but got "+ + s"version $version") + } + } + + private def intValue(key: String): Option[Int] = { + try { + Option(props.getProperty(key)).map(Integer.parseInt) + } catch { + case e: Throwable => throw new RuntimeException(s"Failed to parse $key property " + + s"as an int: ${e.getMessage}") + } + } + + override def toString: String = { + "{" + props.keySet().asScala.toList.asInstanceOf[List[String]].sorted.map { + key => key + "=" + props.get(key) + }.mkString(", ") + "}" + } +} + +object MetaProperties { + def parse(properties: RawMetaProperties): MetaProperties = { + properties.requireVersion(expectedVersion = 1) + val clusterId = require(ClusterIdKey, properties.clusterId) + val nodeId = require(NodeIdKey, properties.nodeId) + new MetaProperties(clusterId, nodeId) + } + + def require[T](key: String, value: Option[T]): T = { + value.getOrElse(throw new RuntimeException(s"Failed to find required property $key.")) + } +} + +case class ZkMetaProperties( + clusterId: String, + brokerId: Int +) { + def toProperties: Properties = { + val properties = new RawMetaProperties() + properties.version = 0 + properties.clusterId = clusterId + properties.brokerId = brokerId + properties.props + } + + override def toString: String = { + s"ZkMetaProperties(brokerId=$brokerId, clusterId=$clusterId)" + } +} + +case class MetaProperties( + clusterId: String, + nodeId: Int, +) { + def toProperties: Properties = { + val properties = new RawMetaProperties() + properties.version = 1 + properties.clusterId = clusterId + properties.nodeId = nodeId + properties.props + } + + override def toString: String = { + s"MetaProperties(clusterId=$clusterId, nodeId=$nodeId)" + } +} + +object BrokerMetadataCheckpoint extends Logging { + def getBrokerMetadataAndOfflineDirs( + logDirs: collection.Seq[String], + ignoreMissing: Boolean + ): (RawMetaProperties, collection.Seq[String]) = { + require(logDirs.nonEmpty, "Must have at least one log dir to read meta.properties") + + val brokerMetadataMap = mutable.HashMap[String, Properties]() + val offlineDirs = 
mutable.ArrayBuffer.empty[String] + + for (logDir <- logDirs) { + val brokerCheckpointFile = new File(logDir, "meta.properties") + val brokerCheckpoint = new BrokerMetadataCheckpoint(brokerCheckpointFile) + + try { + brokerCheckpoint.read() match { + case Some(properties) => + brokerMetadataMap += logDir -> properties + case None => + if (!ignoreMissing) { + throw new KafkaException(s"No `meta.properties` found in $logDir " + + "(have you run `kafka-storage.sh` to format the directory?)") + } + } + } catch { + case e: IOException => + offlineDirs += logDir + error(s"Failed to read $brokerCheckpointFile", e) + } + } + + if (brokerMetadataMap.isEmpty) { + (new RawMetaProperties(), offlineDirs) + } else { + val numDistinctMetaProperties = brokerMetadataMap.values.toSet.size + if (numDistinctMetaProperties > 1) { + val builder = new StringBuilder + + for ((logDir, brokerMetadata) <- brokerMetadataMap) + builder ++= s"- $logDir -> $brokerMetadata\n" + + throw new InconsistentBrokerMetadataException( + s"BrokerMetadata is not consistent across log.dirs. This could happen if multiple brokers shared a log directory (log.dirs) " + + s"or partial data was manually copied from another broker. Found:\n${builder.toString()}" + ) + } + + val rawProps = new RawMetaProperties(brokerMetadataMap.head._2) + (rawProps, offlineDirs) + } + } +} + +/** + * This class saves the metadata properties to a file + */ +class BrokerMetadataCheckpoint(val file: File) extends Logging { + private val lock = new Object() + + def write(properties: Properties): Unit = { + lock synchronized { + try { + val temp = new File(file.getAbsolutePath + ".tmp") + val fileOutputStream = new FileOutputStream(temp) + try { + properties.store(fileOutputStream, "") + fileOutputStream.flush() + fileOutputStream.getFD.sync() + } finally { + Utils.closeQuietly(fileOutputStream, temp.getName) + } + Utils.atomicMoveWithFallback(temp.toPath, file.toPath) + } catch { + case ie: IOException => + error("Failed to write meta.properties due to", ie) + throw ie + } + } + } + + def read(): Option[Properties] = { + Files.deleteIfExists(new File(file.getPath + ".tmp").toPath) // try to delete any existing temp files for cleanliness + + val absolutePath = file.getAbsolutePath + lock synchronized { + try { + Some(Utils.loadProps(absolutePath)) + } catch { + case _: NoSuchFileException => + warn(s"No meta.properties file under dir $absolutePath") + None + case e: Exception => + error(s"Failed to read meta.properties file under dir $absolutePath", e) + throw e + } + } + } +} diff --git a/core/src/main/scala/kafka/server/BrokerServer.scala b/core/src/main/scala/kafka/server/BrokerServer.scala new file mode 100644 index 0000000..0871030 --- /dev/null +++ b/core/src/main/scala/kafka/server/BrokerServer.scala @@ -0,0 +1,556 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.server + +import java.net.InetAddress +import java.util +import java.util.concurrent.atomic.AtomicBoolean +import java.util.concurrent.locks.ReentrantLock +import java.util.concurrent.{CompletableFuture, TimeUnit, TimeoutException} + +import kafka.cluster.Broker.ServerInfo +import kafka.coordinator.group.GroupCoordinator +import kafka.coordinator.transaction.{ProducerIdManager, TransactionCoordinator} +import kafka.log.LogManager +import kafka.metrics.KafkaYammerMetrics +import kafka.network.SocketServer +import kafka.raft.RaftManager +import kafka.security.CredentialProvider +import kafka.server.KafkaRaftServer.ControllerRole +import kafka.server.metadata.{BrokerMetadataListener, BrokerMetadataPublisher, BrokerMetadataSnapshotter, ClientQuotaMetadataManager, KRaftMetadataCache, SnapshotWriterBuilder} +import kafka.utils.{CoreUtils, KafkaScheduler} +import org.apache.kafka.common.message.ApiMessageType.ListenerType +import org.apache.kafka.common.message.BrokerRegistrationRequestData.{Listener, ListenerCollection} +import org.apache.kafka.common.metrics.Metrics +import org.apache.kafka.common.network.ListenerName +import org.apache.kafka.common.security.auth.SecurityProtocol +import org.apache.kafka.common.security.scram.internals.ScramMechanism +import org.apache.kafka.common.security.token.delegation.internals.DelegationTokenCache +import org.apache.kafka.common.utils.{AppInfoParser, LogContext, Time, Utils} +import org.apache.kafka.common.{ClusterResource, Endpoint} +import org.apache.kafka.metadata.{BrokerState, VersionRange} +import org.apache.kafka.raft.RaftConfig.AddressSpec +import org.apache.kafka.raft.{RaftClient, RaftConfig} +import org.apache.kafka.server.authorizer.Authorizer +import org.apache.kafka.server.common.ApiMessageAndVersion +import org.apache.kafka.snapshot.SnapshotWriter + +import scala.collection.{Map, Seq} +import scala.compat.java8.OptionConverters._ +import scala.jdk.CollectionConverters._ + + +class BrokerSnapshotWriterBuilder(raftClient: RaftClient[ApiMessageAndVersion]) + extends SnapshotWriterBuilder { + override def build(committedOffset: Long, + committedEpoch: Int, + lastContainedLogTime: Long): SnapshotWriter[ApiMessageAndVersion] = { + raftClient.createSnapshot(committedOffset, committedEpoch, lastContainedLogTime). + asScala.getOrElse( + throw new RuntimeException("A snapshot already exists with " + + s"committedOffset=${committedOffset}, committedEpoch=${committedEpoch}, " + + s"lastContainedLogTime=${lastContainedLogTime}") + ) + } +} + +/** + * A Kafka broker that runs in KRaft (Kafka Raft) mode. 
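+ *
+ * Unlike the ZooKeeper-based broker, it registers itself with the controller quorum and
+ * sends periodic heartbeats through a BrokerLifecycleManager (KIP-631) instead of ZooKeeper.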
+ */ +class BrokerServer( + val config: KafkaConfig, + val metaProps: MetaProperties, + val raftManager: RaftManager[ApiMessageAndVersion], + val time: Time, + val metrics: Metrics, + val threadNamePrefix: Option[String], + val initialOfflineDirs: Seq[String], + val controllerQuorumVotersFuture: CompletableFuture[util.Map[Integer, AddressSpec]], + val supportedFeatures: util.Map[String, VersionRange] +) extends KafkaBroker { + + override def brokerState: BrokerState = lifecycleManager.state + + import kafka.server.Server._ + + private val logContext: LogContext = new LogContext(s"[BrokerServer id=${config.nodeId}] ") + + this.logIdent = logContext.logPrefix + + @volatile private var lifecycleManager: BrokerLifecycleManager = null + + private val isShuttingDown = new AtomicBoolean(false) + + val lock = new ReentrantLock() + val awaitShutdownCond = lock.newCondition() + var status: ProcessStatus = SHUTDOWN + + @volatile var dataPlaneRequestProcessor: KafkaApis = null + var controlPlaneRequestProcessor: KafkaApis = null + + var authorizer: Option[Authorizer] = None + @volatile var socketServer: SocketServer = null + var dataPlaneRequestHandlerPool: KafkaRequestHandlerPool = null + + var logDirFailureChannel: LogDirFailureChannel = null + var logManager: LogManager = null + + var tokenManager: DelegationTokenManager = null + + var dynamicConfigHandlers: Map[String, ConfigHandler] = null + + @volatile private[this] var _replicaManager: ReplicaManager = null + + var credentialProvider: CredentialProvider = null + var tokenCache: DelegationTokenCache = null + + @volatile var groupCoordinator: GroupCoordinator = null + + var transactionCoordinator: TransactionCoordinator = null + + var clientToControllerChannelManager: BrokerToControllerChannelManager = null + + var forwardingManager: ForwardingManager = null + + var alterIsrManager: AlterIsrManager = null + + var autoTopicCreationManager: AutoTopicCreationManager = null + + var kafkaScheduler: KafkaScheduler = null + + @volatile var metadataCache: KRaftMetadataCache = null + + var quotaManagers: QuotaFactory.QuotaManagers = null + + var clientQuotaMetadataManager: ClientQuotaMetadataManager = null + + @volatile var brokerTopicStats: BrokerTopicStats = null + + val brokerFeatures: BrokerFeatures = BrokerFeatures.createDefault() + + val featureCache: FinalizedFeatureCache = new FinalizedFeatureCache(brokerFeatures) + + val clusterId: String = metaProps.clusterId + + var metadataSnapshotter: Option[BrokerMetadataSnapshotter] = None + + var metadataListener: BrokerMetadataListener = null + + var metadataPublisher: BrokerMetadataPublisher = null + + def kafkaYammerMetrics: kafka.metrics.KafkaYammerMetrics = KafkaYammerMetrics.INSTANCE + + private def maybeChangeStatus(from: ProcessStatus, to: ProcessStatus): Boolean = { + lock.lock() + try { + if (status != from) return false + info(s"Transition from $status to $to") + + status = to + if (to == SHUTTING_DOWN) { + isShuttingDown.set(true) + } else if (to == SHUTDOWN) { + isShuttingDown.set(false) + awaitShutdownCond.signalAll() + } + } finally { + lock.unlock() + } + true + } + + def replicaManager: ReplicaManager = _replicaManager + + override def startup(): Unit = { + if (!maybeChangeStatus(SHUTDOWN, STARTING)) return + try { + info("Starting broker") + + config.dynamicConfig.initialize(zkClientOpt = None) + + lifecycleManager = new BrokerLifecycleManager(config, time, threadNamePrefix) + + /* start scheduler */ + kafkaScheduler = new KafkaScheduler(config.backgroundThreads) + 
kafkaScheduler.startup() + + /* register broker metrics */ + brokerTopicStats = new BrokerTopicStats + + quotaManagers = QuotaFactory.instantiate(config, metrics, time, threadNamePrefix.getOrElse("")) + + logDirFailureChannel = new LogDirFailureChannel(config.logDirs.size) + + metadataCache = MetadataCache.kRaftMetadataCache(config.nodeId) + + // Create log manager, but don't start it because we need to delay any potential unclean shutdown log recovery + // until we catch up on the metadata log and have up-to-date topic and broker configs. + logManager = LogManager(config, initialOfflineDirs, metadataCache, kafkaScheduler, time, + brokerTopicStats, logDirFailureChannel, keepPartitionMetadataFile = true) + + // Enable delegation token cache for all SCRAM mechanisms to simplify dynamic update. + // This keeps the cache up-to-date if new SCRAM mechanisms are enabled dynamically. + tokenCache = new DelegationTokenCache(ScramMechanism.mechanismNames) + credentialProvider = new CredentialProvider(ScramMechanism.mechanismNames, tokenCache) + + val controllerNodes = RaftConfig.voterConnectionsToNodes(controllerQuorumVotersFuture.get()).asScala + val controllerNodeProvider = RaftControllerNodeProvider(raftManager, config, controllerNodes) + + clientToControllerChannelManager = BrokerToControllerChannelManager( + controllerNodeProvider, + time, + metrics, + config, + channelName = "forwarding", + threadNamePrefix, + retryTimeoutMs = 60000 + ) + clientToControllerChannelManager.start() + forwardingManager = new ForwardingManagerImpl(clientToControllerChannelManager) + + val apiVersionManager = ApiVersionManager( + ListenerType.BROKER, + config, + Some(forwardingManager), + brokerFeatures, + featureCache + ) + + // Create and start the socket server acceptor threads so that the bound port is known. + // Delay starting processors until the end of the initialization sequence to ensure + // that credentials have been loaded before processing authentications. + socketServer = new SocketServer(config, metrics, time, credentialProvider, apiVersionManager) + socketServer.startup(startProcessingRequests = false) + + clientQuotaMetadataManager = new ClientQuotaMetadataManager(quotaManagers, socketServer.connectionQuotas) + + val alterIsrChannelManager = BrokerToControllerChannelManager( + controllerNodeProvider, + time, + metrics, + config, + channelName = "alterIsr", + threadNamePrefix, + retryTimeoutMs = Long.MaxValue + ) + alterIsrManager = new DefaultAlterIsrManager( + controllerChannelManager = alterIsrChannelManager, + scheduler = kafkaScheduler, + time = time, + brokerId = config.nodeId, + brokerEpochSupplier = () => lifecycleManager.brokerEpoch + ) + alterIsrManager.start() + + this._replicaManager = new ReplicaManager( + config = config, + metrics = metrics, + time = time, + scheduler = kafkaScheduler, + logManager = logManager, + quotaManagers = quotaManagers, + metadataCache = metadataCache, + logDirFailureChannel = logDirFailureChannel, + alterIsrManager = alterIsrManager, + brokerTopicStats = brokerTopicStats, + isShuttingDown = isShuttingDown, + zkClient = None, + threadNamePrefix = threadNamePrefix) + + /* start token manager */ + if (config.tokenAuthEnabled) { + throw new UnsupportedOperationException("Delegation tokens are not supported") + } + tokenManager = new DelegationTokenManager(config, tokenCache, time , null) + tokenManager.startup() // does nothing, we just need a token manager in order to compile right now... 
+ + // Create group coordinator, but don't start it until we've started replica manager. + // Hardcode Time.SYSTEM for now as some Streams tests fail otherwise, it would be good to fix the underlying issue + groupCoordinator = GroupCoordinator(config, replicaManager, Time.SYSTEM, metrics) + + val producerIdManagerSupplier = () => ProducerIdManager.rpc( + config.brokerId, + brokerEpochSupplier = () => lifecycleManager.brokerEpoch, + clientToControllerChannelManager, + config.requestTimeoutMs + ) + + // Create transaction coordinator, but don't start it until we've started replica manager. + // Hardcode Time.SYSTEM for now as some Streams tests fail otherwise, it would be good to fix the underlying issue + transactionCoordinator = TransactionCoordinator(config, replicaManager, + new KafkaScheduler(threads = 1, threadNamePrefix = "transaction-log-manager-"), + producerIdManagerSupplier, metrics, metadataCache, Time.SYSTEM) + + autoTopicCreationManager = new DefaultAutoTopicCreationManager( + config, Some(clientToControllerChannelManager), None, None, + groupCoordinator, transactionCoordinator) + + /* Add all reconfigurables for config change notification before starting the metadata listener */ + config.dynamicConfig.addReconfigurables(this) + + dynamicConfigHandlers = Map[String, ConfigHandler]( + ConfigType.Topic -> new TopicConfigHandler(logManager, config, quotaManagers, None), + ConfigType.Broker -> new BrokerConfigHandler(config, quotaManagers)) + + if (!config.processRoles.contains(ControllerRole)) { + // If no controller is defined, we rely on the broker to generate snapshots. + metadataSnapshotter = Some(new BrokerMetadataSnapshotter( + config.nodeId, + time, + threadNamePrefix, + new BrokerSnapshotWriterBuilder(raftManager.client) + )) + } + + metadataListener = new BrokerMetadataListener(config.nodeId, + time, + threadNamePrefix, + config.metadataSnapshotMaxNewRecordBytes, + metadataSnapshotter) + + val networkListeners = new ListenerCollection() + config.effectiveAdvertisedListeners.foreach { ep => + networkListeners.add(new Listener(). + setHost(if (Utils.isBlank(ep.host)) InetAddress.getLocalHost.getCanonicalHostName else ep.host). + setName(ep.listenerName.value()). + setPort(if (ep.port == 0) socketServer.boundPort(ep.listenerName) else ep.port). + setSecurityProtocol(ep.securityProtocol.id)) + } + lifecycleManager.start(() => metadataListener.highestMetadataOffset, + BrokerToControllerChannelManager(controllerNodeProvider, time, metrics, config, + "heartbeat", threadNamePrefix, config.brokerSessionTimeoutMs.toLong), + metaProps.clusterId, networkListeners, supportedFeatures) + + // Register a listener with the Raft layer to receive metadata event notifications + raftManager.register(metadataListener) + + val endpoints = new util.ArrayList[Endpoint](networkListeners.size()) + var interBrokerListener: Endpoint = null + networkListeners.iterator().forEachRemaining(listener => { + val endPoint = new Endpoint(listener.name(), + SecurityProtocol.forId(listener.securityProtocol()), + listener.host(), listener.port()) + endpoints.add(endPoint) + if (listener.name().equals(config.interBrokerListenerName.value())) { + interBrokerListener = endPoint + } + }) + if (interBrokerListener == null) { + throw new RuntimeException("Unable to find inter-broker listener " + + config.interBrokerListenerName.value() + ". 
Found listener(s): " + + endpoints.asScala.map(ep => ep.listenerName().orElse("(none)")).mkString(", ")) + } + val authorizerInfo = ServerInfo(new ClusterResource(clusterId), + config.nodeId, endpoints, interBrokerListener) + + /* Get the authorizer and initialize it if one is specified.*/ + authorizer = config.authorizer + authorizer.foreach(_.configure(config.originals)) + val authorizerFutures: Map[Endpoint, CompletableFuture[Void]] = authorizer match { + case Some(authZ) => + authZ.start(authorizerInfo).asScala.map { case (ep, cs) => + ep -> cs.toCompletableFuture + } + case None => + authorizerInfo.endpoints.asScala.map { ep => + ep -> CompletableFuture.completedFuture[Void](null) + }.toMap + } + + val fetchManager = new FetchManager(Time.SYSTEM, + new FetchSessionCache(config.maxIncrementalFetchSessionCacheSlots, + KafkaServer.MIN_INCREMENTAL_FETCH_SESSION_EVICTION_MS)) + + // Create the request processor objects. + val raftSupport = RaftSupport(forwardingManager, metadataCache) + dataPlaneRequestProcessor = new KafkaApis( + requestChannel = socketServer.dataPlaneRequestChannel, + metadataSupport = raftSupport, + replicaManager = replicaManager, + groupCoordinator = groupCoordinator, + txnCoordinator = transactionCoordinator, + autoTopicCreationManager = autoTopicCreationManager, + brokerId = config.nodeId, + config = config, + configRepository = metadataCache, + metadataCache = metadataCache, + metrics = metrics, + authorizer = authorizer, + quotas = quotaManagers, + fetchManager = fetchManager, + brokerTopicStats = brokerTopicStats, + clusterId = clusterId, + time = time, + tokenManager = tokenManager, + apiVersionManager = apiVersionManager) + + dataPlaneRequestHandlerPool = new KafkaRequestHandlerPool(config.nodeId, + socketServer.dataPlaneRequestChannel, dataPlaneRequestProcessor, time, + config.numIoThreads, s"${SocketServer.DataPlaneMetricPrefix}RequestHandlerAvgIdlePercent", + SocketServer.DataPlaneThreadPrefix) + + // Block until we've caught up with the latest metadata from the controller quorum. + lifecycleManager.initialCatchUpFuture.get() + + // Apply the metadata log changes that we've accumulated. + metadataPublisher = new BrokerMetadataPublisher(config, metadataCache, + logManager, replicaManager, groupCoordinator, transactionCoordinator, + clientQuotaMetadataManager, featureCache, dynamicConfigHandlers.toMap) + + // Tell the metadata listener to start publishing its output, and wait for the first + // publish operation to complete. This first operation will initialize logManager, + // replicaManager, groupCoordinator, and txnCoordinator. The log manager may perform + // a potentially lengthy recovery-from-unclean-shutdown operation here, if required. + metadataListener.startPublishing(metadataPublisher).get() + + // Log static broker configurations. + new KafkaConfig(config.originals(), true) + + // Enable inbound TCP connections. + socketServer.startProcessingRequests(authorizerFutures) + + // We're now ready to unfence the broker. This also allows this broker to transition + // from RECOVERY state to RUNNING state, once the controller unfences the broker. + lifecycleManager.setReadyToUnfence() + + maybeChangeStatus(STARTING, STARTED) + } catch { + case e: Throwable => + maybeChangeStatus(STARTING, STARTED) + fatal("Fatal error during broker startup. 
Prepare to shutdown", e) + shutdown() + throw e + } + } + + override def shutdown(): Unit = { + if (!maybeChangeStatus(STARTED, SHUTTING_DOWN)) return + try { + info("shutting down") + + if (config.controlledShutdownEnable) { + // Shut down the broker metadata listener, so that we don't get added to any + // more ISRs. + if (metadataListener != null) { + metadataListener.beginShutdown() + } + lifecycleManager.beginControlledShutdown() + try { + lifecycleManager.controlledShutdownFuture.get(5L, TimeUnit.MINUTES) + } catch { + case _: TimeoutException => + error("Timed out waiting for the controller to approve controlled shutdown") + case e: Throwable => + error("Got unexpected exception waiting for controlled shutdown future", e) + } + } + lifecycleManager.beginShutdown() + + // Stop socket server to stop accepting any more connections and requests. + // Socket server will be shutdown towards the end of the sequence. + if (socketServer != null) { + CoreUtils.swallow(socketServer.stopProcessingRequests(), this) + } + if (dataPlaneRequestHandlerPool != null) + CoreUtils.swallow(dataPlaneRequestHandlerPool.shutdown(), this) + if (dataPlaneRequestProcessor != null) + CoreUtils.swallow(dataPlaneRequestProcessor.close(), this) + if (controlPlaneRequestProcessor != null) + CoreUtils.swallow(controlPlaneRequestProcessor.close(), this) + CoreUtils.swallow(authorizer.foreach(_.close()), this) + if (metadataListener != null) { + CoreUtils.swallow(metadataListener.close(), this) + } + metadataSnapshotter.foreach(snapshotter => CoreUtils.swallow(snapshotter.close(), this)) + + /** + * We must shutdown the scheduler early because otherwise, the scheduler could touch other + * resources that might have been shutdown and cause exceptions. + * For example, if we didn't shutdown the scheduler first, when LogManager was closing + * partitions one by one, the scheduler might concurrently delete old segments due to + * retention. However, the old segments could have been closed by the LogManager, which would + * cause an IOException and subsequently mark logdir as offline. As a result, the broker would + * not flush the remaining partitions or write the clean shutdown marker. Ultimately, the + * broker would have to take hours to recover the log during restart. 
+ */ + if (kafkaScheduler != null) + CoreUtils.swallow(kafkaScheduler.shutdown(), this) + + if (transactionCoordinator != null) + CoreUtils.swallow(transactionCoordinator.shutdown(), this) + if (groupCoordinator != null) + CoreUtils.swallow(groupCoordinator.shutdown(), this) + + if (tokenManager != null) + CoreUtils.swallow(tokenManager.shutdown(), this) + + if (replicaManager != null) + CoreUtils.swallow(replicaManager.shutdown(), this) + + if (alterIsrManager != null) + CoreUtils.swallow(alterIsrManager.shutdown(), this) + + if (clientToControllerChannelManager != null) + CoreUtils.swallow(clientToControllerChannelManager.shutdown(), this) + + if (logManager != null) + CoreUtils.swallow(logManager.shutdown(), this) + + if (quotaManagers != null) + CoreUtils.swallow(quotaManagers.shutdown(), this) + + if (socketServer != null) + CoreUtils.swallow(socketServer.shutdown(), this) + if (metrics != null) + CoreUtils.swallow(metrics.close(), this) + if (brokerTopicStats != null) + CoreUtils.swallow(brokerTopicStats.close(), this) + + // Clear all reconfigurable instances stored in DynamicBrokerConfig + config.dynamicConfig.clear() + + isShuttingDown.set(false) + + CoreUtils.swallow(lifecycleManager.close(), this) + + CoreUtils.swallow(AppInfoParser.unregisterAppInfo(MetricsPrefix, config.nodeId.toString, metrics), this) + info("shut down completed") + } catch { + case e: Throwable => + fatal("Fatal error during broker shutdown.", e) + throw e + } finally { + maybeChangeStatus(SHUTTING_DOWN, SHUTDOWN) + } + } + + override def awaitShutdown(): Unit = { + lock.lock() + try { + while (true) { + if (status == SHUTDOWN) return + awaitShutdownCond.awaitUninterruptibly() + } + } finally { + lock.unlock() + } + } + + override def boundPort(listenerName: ListenerName): Int = socketServer.boundPort(listenerName) + +} diff --git a/core/src/main/scala/kafka/server/BrokerToControllerChannelManager.scala b/core/src/main/scala/kafka/server/BrokerToControllerChannelManager.scala new file mode 100644 index 0000000..b671c70 --- /dev/null +++ b/core/src/main/scala/kafka/server/BrokerToControllerChannelManager.scala @@ -0,0 +1,387 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package kafka.server + +import java.util.concurrent.LinkedBlockingDeque +import java.util.concurrent.atomic.AtomicReference + +import kafka.common.{InterBrokerSendThread, RequestAndCompletionHandler} +import kafka.raft.RaftManager +import kafka.utils.Logging +import org.apache.kafka.clients._ +import org.apache.kafka.common.Node +import org.apache.kafka.common.metrics.Metrics +import org.apache.kafka.common.network._ +import org.apache.kafka.common.protocol.Errors +import org.apache.kafka.common.requests.AbstractRequest +import org.apache.kafka.common.security.JaasContext +import org.apache.kafka.common.security.auth.SecurityProtocol +import org.apache.kafka.common.utils.{LogContext, Time} +import org.apache.kafka.server.common.ApiMessageAndVersion + +import scala.collection.Seq +import scala.compat.java8.OptionConverters._ +import scala.jdk.CollectionConverters._ + +trait ControllerNodeProvider { + def get(): Option[Node] + def listenerName: ListenerName + def securityProtocol: SecurityProtocol + def saslMechanism: String +} + +object MetadataCacheControllerNodeProvider { + def apply( + config: KafkaConfig, + metadataCache: kafka.server.MetadataCache + ): MetadataCacheControllerNodeProvider = { + val listenerName = config.controlPlaneListenerName + .getOrElse(config.interBrokerListenerName) + + val securityProtocol = config.controlPlaneSecurityProtocol + .getOrElse(config.interBrokerSecurityProtocol) + + new MetadataCacheControllerNodeProvider( + metadataCache, + listenerName, + securityProtocol, + config.saslMechanismInterBrokerProtocol + ) + } +} + +class MetadataCacheControllerNodeProvider( + val metadataCache: kafka.server.MetadataCache, + val listenerName: ListenerName, + val securityProtocol: SecurityProtocol, + val saslMechanism: String +) extends ControllerNodeProvider { + override def get(): Option[Node] = { + metadataCache.getControllerId + .flatMap(metadataCache.getAliveBrokerNode(_, listenerName)) + } +} + +object RaftControllerNodeProvider { + def apply(raftManager: RaftManager[ApiMessageAndVersion], + config: KafkaConfig, + controllerQuorumVoterNodes: Seq[Node]): RaftControllerNodeProvider = { + val controllerListenerName = new ListenerName(config.controllerListenerNames.head) + val controllerSecurityProtocol = config.effectiveListenerSecurityProtocolMap.getOrElse(controllerListenerName, SecurityProtocol.forName(controllerListenerName.value())) + val controllerSaslMechanism = config.saslMechanismControllerProtocol + new RaftControllerNodeProvider( + raftManager, + controllerQuorumVoterNodes, + controllerListenerName, + controllerSecurityProtocol, + controllerSaslMechanism + ) + } +} + +/** + * Finds the controller node by checking the metadata log manager. + * This provider is used when we are using a Raft-based metadata quorum. 
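+ * The current leader of the Raft quorum, as reported by the RaftManager, is treated as the
+ * active controller, and its id is resolved to a Node using the configured quorum voters.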
+ */ +class RaftControllerNodeProvider(val raftManager: RaftManager[ApiMessageAndVersion], + controllerQuorumVoterNodes: Seq[Node], + val listenerName: ListenerName, + val securityProtocol: SecurityProtocol, + val saslMechanism: String + ) extends ControllerNodeProvider with Logging { + val idToNode = controllerQuorumVoterNodes.map(node => node.id() -> node).toMap + + override def get(): Option[Node] = { + raftManager.leaderAndEpoch.leaderId.asScala.map(idToNode) + } +} + +object BrokerToControllerChannelManager { + def apply( + controllerNodeProvider: ControllerNodeProvider, + time: Time, + metrics: Metrics, + config: KafkaConfig, + channelName: String, + threadNamePrefix: Option[String], + retryTimeoutMs: Long + ): BrokerToControllerChannelManager = { + new BrokerToControllerChannelManagerImpl( + controllerNodeProvider, + time, + metrics, + config, + channelName, + threadNamePrefix, + retryTimeoutMs + ) + } +} + + +trait BrokerToControllerChannelManager { + def start(): Unit + def shutdown(): Unit + def controllerApiVersions(): Option[NodeApiVersions] + def sendRequest( + request: AbstractRequest.Builder[_ <: AbstractRequest], + callback: ControllerRequestCompletionHandler + ): Unit +} + + +/** + * This class manages the connection between a broker and the controller. It runs a single + * [[BrokerToControllerRequestThread]] which uses the broker's metadata cache as its own metadata to find + * and connect to the controller. The channel is async and runs the network connection in the background. + * The maximum number of in-flight requests are set to one to ensure orderly response from the controller, therefore + * care must be taken to not block on outstanding requests for too long. + */ +class BrokerToControllerChannelManagerImpl( + controllerNodeProvider: ControllerNodeProvider, + time: Time, + metrics: Metrics, + config: KafkaConfig, + channelName: String, + threadNamePrefix: Option[String], + retryTimeoutMs: Long +) extends BrokerToControllerChannelManager with Logging { + private val logContext = new LogContext(s"[BrokerToControllerChannelManager broker=${config.brokerId} name=$channelName] ") + private val manualMetadataUpdater = new ManualMetadataUpdater() + private val apiVersions = new ApiVersions() + private val currentNodeApiVersions = NodeApiVersions.create() + private val requestThread = newRequestThread + + def start(): Unit = { + requestThread.start() + } + + def shutdown(): Unit = { + requestThread.shutdown() + info(s"Broker to controller channel manager for $channelName shutdown") + } + + private[server] def newRequestThread = { + val networkClient = { + val channelBuilder = ChannelBuilders.clientChannelBuilder( + controllerNodeProvider.securityProtocol, + JaasContext.Type.SERVER, + config, + controllerNodeProvider.listenerName, + controllerNodeProvider.saslMechanism, + time, + config.saslInterBrokerHandshakeRequestEnable, + logContext + ) + val selector = new Selector( + NetworkReceive.UNLIMITED, + Selector.NO_IDLE_TIMEOUT_MS, + metrics, + time, + channelName, + Map("BrokerId" -> config.brokerId.toString).asJava, + false, + channelBuilder, + logContext + ) + new NetworkClient( + selector, + manualMetadataUpdater, + config.brokerId.toString, + 1, + 50, + 50, + Selectable.USE_DEFAULT_BUFFER_SIZE, + Selectable.USE_DEFAULT_BUFFER_SIZE, + config.requestTimeoutMs, + config.connectionSetupTimeoutMs, + config.connectionSetupTimeoutMaxMs, + time, + true, + apiVersions, + logContext + ) + } + val threadName = threadNamePrefix match { + case None => 
s"BrokerToControllerChannelManager broker=${config.brokerId} name=$channelName" + case Some(name) => s"$name:BrokerToControllerChannelManager broker=${config.brokerId} name=$channelName" + } + + new BrokerToControllerRequestThread( + networkClient, + manualMetadataUpdater, + controllerNodeProvider, + config, + time, + threadName, + retryTimeoutMs + ) + } + + /** + * Send request to the controller. + * + * @param request The request to be sent. + * @param callback Request completion callback. + */ + def sendRequest( + request: AbstractRequest.Builder[_ <: AbstractRequest], + callback: ControllerRequestCompletionHandler + ): Unit = { + requestThread.enqueue(BrokerToControllerQueueItem( + time.milliseconds(), + request, + callback + )) + } + + def controllerApiVersions(): Option[NodeApiVersions] = + requestThread.activeControllerAddress().flatMap( + activeController => if (activeController.id() == config.brokerId) + Some(currentNodeApiVersions) + else + Option(apiVersions.get(activeController.idString())) + ) +} + +abstract class ControllerRequestCompletionHandler extends RequestCompletionHandler { + + /** + * Fire when the request transmission time passes the caller defined deadline on the channel queue. + * It covers the total waiting time including retries which might be the result of individual request timeout. + */ + def onTimeout(): Unit +} + +case class BrokerToControllerQueueItem( + createdTimeMs: Long, + request: AbstractRequest.Builder[_ <: AbstractRequest], + callback: ControllerRequestCompletionHandler +) + +class BrokerToControllerRequestThread( + networkClient: KafkaClient, + metadataUpdater: ManualMetadataUpdater, + controllerNodeProvider: ControllerNodeProvider, + config: KafkaConfig, + time: Time, + threadName: String, + retryTimeoutMs: Long +) extends InterBrokerSendThread(threadName, networkClient, config.controllerSocketTimeoutMs, time, isInterruptible = false) { + + private val requestQueue = new LinkedBlockingDeque[BrokerToControllerQueueItem]() + private val activeController = new AtomicReference[Node](null) + + // Used for testing + @volatile + private[server] var started = false + + def activeControllerAddress(): Option[Node] = { + Option(activeController.get()) + } + + private def updateControllerAddress(newActiveController: Node): Unit = { + activeController.set(newActiveController) + } + + def enqueue(request: BrokerToControllerQueueItem): Unit = { + if (!started) { + throw new IllegalStateException("Cannot enqueue a request if the request thread is not running") + } + requestQueue.add(request) + if (activeControllerAddress().isDefined) { + wakeup() + } + } + + def queueSize: Int = { + requestQueue.size + } + + override def generateRequests(): Iterable[RequestAndCompletionHandler] = { + val currentTimeMs = time.milliseconds() + val requestIter = requestQueue.iterator() + while (requestIter.hasNext) { + val request = requestIter.next + if (currentTimeMs - request.createdTimeMs >= retryTimeoutMs) { + requestIter.remove() + request.callback.onTimeout() + } else { + val controllerAddress = activeControllerAddress() + if (controllerAddress.isDefined) { + requestIter.remove() + return Some(RequestAndCompletionHandler( + time.milliseconds(), + controllerAddress.get, + request.request, + handleResponse(request) + )) + } + } + } + None + } + + private[server] def handleResponse(queueItem: BrokerToControllerQueueItem)(response: ClientResponse): Unit = { + if (response.authenticationException != null) { + error(s"Request ${queueItem.request} failed due to authentication error 
with controller", + response.authenticationException) + queueItem.callback.onComplete(response) + } else if (response.versionMismatch != null) { + error(s"Request ${queueItem.request} failed due to unsupported version error", + response.versionMismatch) + queueItem.callback.onComplete(response) + } else if (response.wasDisconnected()) { + updateControllerAddress(null) + requestQueue.putFirst(queueItem) + } else if (response.responseBody().errorCounts().containsKey(Errors.NOT_CONTROLLER)) { + // just close the controller connection and wait for metadata cache update in doWork + activeControllerAddress().foreach { controllerAddress => { + networkClient.disconnect(controllerAddress.idString) + updateControllerAddress(null) + }} + + requestQueue.putFirst(queueItem) + } else { + queueItem.callback.onComplete(response) + } + } + + override def doWork(): Unit = { + if (activeControllerAddress().isDefined) { + super.pollOnce(Long.MaxValue) + } else { + debug("Controller isn't cached, looking for local metadata changes") + controllerNodeProvider.get() match { + case Some(controllerNode) => + info(s"Recorded new controller, from now on will use broker $controllerNode") + updateControllerAddress(controllerNode) + metadataUpdater.setNodes(Seq(controllerNode).asJava) + case None => + // need to backoff to avoid tight loops + debug("No controller defined in metadata cache, retrying after backoff") + super.pollOnce(maxTimeoutMs = 100) + } + } + } + + override def start(): Unit = { + super.start() + started = true + } +} diff --git a/core/src/main/scala/kafka/server/ClientQuotaManager.scala b/core/src/main/scala/kafka/server/ClientQuotaManager.scala new file mode 100644 index 0000000..1e6523f --- /dev/null +++ b/core/src/main/scala/kafka/server/ClientQuotaManager.scala @@ -0,0 +1,686 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package kafka.server + +import java.{lang, util} +import java.util.concurrent.{ConcurrentHashMap, DelayQueue, TimeUnit} +import java.util.concurrent.locks.ReentrantReadWriteLock + +import kafka.network.RequestChannel +import kafka.network.RequestChannel._ +import kafka.server.ClientQuotaManager._ +import kafka.utils.{Logging, QuotaUtils, ShutdownableThread} +import org.apache.kafka.common.{Cluster, MetricName} +import org.apache.kafka.common.metrics._ +import org.apache.kafka.common.metrics.Metrics +import org.apache.kafka.common.metrics.stats.{Avg, CumulativeSum, Rate} +import org.apache.kafka.common.security.auth.KafkaPrincipal +import org.apache.kafka.common.utils.{Sanitizer, Time} +import org.apache.kafka.server.quota.{ClientQuotaCallback, ClientQuotaEntity, ClientQuotaType} + +import scala.jdk.CollectionConverters._ + +/** + * Represents the sensors aggregated per client + * @param metricTags Quota metric tags for the client + * @param quotaSensor @Sensor that tracks the quota + * @param throttleTimeSensor @Sensor that tracks the throttle time + */ +case class ClientSensors(metricTags: Map[String, String], quotaSensor: Sensor, throttleTimeSensor: Sensor) + +/** + * Configuration settings for quota management + * @param numQuotaSamples The number of samples to retain in memory + * @param quotaWindowSizeSeconds The time span of each sample + * + */ +case class ClientQuotaManagerConfig(numQuotaSamples: Int = + ClientQuotaManagerConfig.DefaultNumQuotaSamples, + quotaWindowSizeSeconds: Int = + ClientQuotaManagerConfig.DefaultQuotaWindowSizeSeconds) + +object ClientQuotaManagerConfig { + // Always have 10 whole windows + 1 current window + val DefaultNumQuotaSamples = 11 + val DefaultQuotaWindowSizeSeconds = 1 +} + +object QuotaTypes { + val NoQuotas = 0 + val ClientIdQuotaEnabled = 1 + val UserQuotaEnabled = 2 + val UserClientIdQuotaEnabled = 4 + val CustomQuotas = 8 // No metric update optimizations are used with custom quotas +} + +object ClientQuotaManager { + // Purge sensors after 1 hour of inactivity + val InactiveSensorExpirationTimeSeconds = 3600 + + val DefaultClientIdQuotaEntity = KafkaQuotaEntity(None, Some(DefaultClientIdEntity)) + val DefaultUserQuotaEntity = KafkaQuotaEntity(Some(DefaultUserEntity), None) + val DefaultUserClientIdQuotaEntity = KafkaQuotaEntity(Some(DefaultUserEntity), Some(DefaultClientIdEntity)) + + sealed trait BaseUserEntity extends ClientQuotaEntity.ConfigEntity + + case class UserEntity(sanitizedUser: String) extends BaseUserEntity { + override def entityType: ClientQuotaEntity.ConfigEntityType = ClientQuotaEntity.ConfigEntityType.USER + override def name: String = Sanitizer.desanitize(sanitizedUser) + override def toString: String = s"user $sanitizedUser" + } + + case class ClientIdEntity(clientId: String) extends ClientQuotaEntity.ConfigEntity { + override def entityType: ClientQuotaEntity.ConfigEntityType = ClientQuotaEntity.ConfigEntityType.CLIENT_ID + override def name: String = clientId + override def toString: String = s"client-id $clientId" + } + + case object DefaultUserEntity extends BaseUserEntity { + override def entityType: ClientQuotaEntity.ConfigEntityType = ClientQuotaEntity.ConfigEntityType.DEFAULT_USER + override def name: String = ConfigEntityName.Default + override def toString: String = "default user" + } + + case object DefaultClientIdEntity extends ClientQuotaEntity.ConfigEntity { + override def entityType: ClientQuotaEntity.ConfigEntityType = ClientQuotaEntity.ConfigEntityType.DEFAULT_CLIENT_ID + override def name: String = 
ConfigEntityName.Default + override def toString: String = "default client-id" + } + + case class KafkaQuotaEntity(userEntity: Option[BaseUserEntity], + clientIdEntity: Option[ClientQuotaEntity.ConfigEntity]) extends ClientQuotaEntity { + override def configEntities: util.List[ClientQuotaEntity.ConfigEntity] = + (userEntity.toList ++ clientIdEntity.toList).asJava + + def sanitizedUser: String = userEntity.map { + case entity: UserEntity => entity.sanitizedUser + case DefaultUserEntity => ConfigEntityName.Default + }.getOrElse("") + + def clientId: String = clientIdEntity.map(_.name).getOrElse("") + + override def toString: String = { + val user = userEntity.map(_.toString).getOrElse("") + val clientId = clientIdEntity.map(_.toString).getOrElse("") + s"$user $clientId".trim + } + } + + object DefaultTags { + val User = "user" + val ClientId = "client-id" + } + + /** + * This calculates the amount of time needed to bring the metric within quota + * assuming that no new metrics are recorded. + * + * Basically, if O is the observed rate and T is the target rate over a window of W, to bring O down to T, + * we need to add a delay of X to W such that O * W / (W + X) = T. + * Solving for X, we get X = (O - T)/T * W. + */ + def throttleTime(e: QuotaViolationException, timeMs: Long): Long = { + val difference = e.value - e.bound + // Use the precise window used by the rate calculation + val throttleTimeMs = difference / e.bound * windowSize(e.metric, timeMs) + Math.round(throttleTimeMs) + } + + private def windowSize(metric: KafkaMetric, timeMs: Long): Long = + measurableAsRate(metric.metricName, metric.measurable).windowSize(metric.config, timeMs) + + // Casting to Rate because we only use Rate in Quota computation + private def measurableAsRate(name: MetricName, measurable: Measurable): Rate = { + measurable match { + case r: Rate => r + case _ => throw new IllegalArgumentException(s"Metric $name is not a Rate metric, value $measurable") + } + } +} + +/** + * Helper class that records per-client metrics. It is also responsible for maintaining Quota usage statistics + * for all clients. + *

+ * Quotas can be set at <user, client-id>, user or client-id levels. For a given client connection,
+ * the most specific quota matching the connection will be applied. For example, if both a <user, client-id>
+ * and a user quota match a connection, the <user, client-id> quota will be used. Otherwise, user quota takes
+ * precedence over client-id quota. The order of precedence is:
+ * <ul>
+ *   <li>/config/users/<user>/clients/<client-id>
+ *   <li>/config/users/<user>/clients/<default>
+ *   <li>/config/users/<user>
+ *   <li>/config/users/<default>/clients/<client-id>
+ *   <li>/config/users/<default>/clients/<default>
+ *   <li>/config/users/<default>
+ *   <li>/config/clients/<client-id>
+ *   <li>/config/clients/<default>
+ * </ul>
            + * Quota limits including defaults may be updated dynamically. The implementation is optimized for the case + * where a single level of quotas is configured. + * + * @param config @ClientQuotaManagerConfig quota configs + * @param metrics @Metrics Metrics instance + * @param quotaType Quota type of this quota manager + * @param time @Time object to use + * @param threadNamePrefix The thread prefix to use + * @param clientQuotaCallback An optional @ClientQuotaCallback + */ +class ClientQuotaManager(private val config: ClientQuotaManagerConfig, + private val metrics: Metrics, + private val quotaType: QuotaType, + private val time: Time, + private val threadNamePrefix: String, + private val clientQuotaCallback: Option[ClientQuotaCallback] = None) extends Logging { + + private val lock = new ReentrantReadWriteLock() + private val sensorAccessor = new SensorAccess(lock, metrics) + private val quotaCallback = clientQuotaCallback.getOrElse(new DefaultQuotaCallback) + private val clientQuotaType = QuotaType.toClientQuotaType(quotaType) + + @volatile + private var quotaTypesEnabled = clientQuotaCallback match { + case Some(_) => QuotaTypes.CustomQuotas + case None => QuotaTypes.NoQuotas + } + + private val delayQueueSensor = metrics.sensor(quotaType.toString + "-delayQueue") + delayQueueSensor.add(metrics.metricName("queue-size", quotaType.toString, + "Tracks the size of the delay queue"), new CumulativeSum()) + + private val delayQueue = new DelayQueue[ThrottledChannel]() + private[server] val throttledChannelReaper = new ThrottledChannelReaper(delayQueue, threadNamePrefix) + start() // Use start method to keep spotbugs happy + private def start(): Unit = { + throttledChannelReaper.start() + } + + /** + * Reaper thread that triggers channel unmute callbacks on all throttled channels + * @param delayQueue DelayQueue to dequeue from + */ + class ThrottledChannelReaper(delayQueue: DelayQueue[ThrottledChannel], prefix: String) extends ShutdownableThread( + s"${prefix}ThrottledChannelReaper-$quotaType", false) { + + override def doWork(): Unit = { + val throttledChannel: ThrottledChannel = delayQueue.poll(1, TimeUnit.SECONDS) + if (throttledChannel != null) { + // Decrement the size of the delay queue + delayQueueSensor.record(-1) + // Notify the socket server that throttling is done for this channel, so that it can try to unmute the channel. + throttledChannel.notifyThrottlingDone() + } + } + } + + /** + * Returns true if any quotas are enabled for this quota manager. This is used + * to determine if quota related metrics should be created. + * Note: If any quotas (static defaults, dynamic defaults or quota overrides) have + * been configured for this broker at any time for this quota type, quotasEnabled will + * return true until the next broker restart, even if all quotas are subsequently deleted. + */ + def quotasEnabled: Boolean = quotaTypesEnabled != QuotaTypes.NoQuotas + + /** + * See {recordAndGetThrottleTimeMs}. + */ + def maybeRecordAndGetThrottleTimeMs(request: RequestChannel.Request, value: Double, timeMs: Long): Int = { + maybeRecordAndGetThrottleTimeMs(request.session, request.header.clientId, value, timeMs) + } + + /** + * See {recordAndGetThrottleTimeMs}. + */ + def maybeRecordAndGetThrottleTimeMs(session: Session, clientId: String, value: Double, timeMs: Long): Int = { + // Record metrics only if quotas are enabled. 
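The throttleTime helper above solves O * W / (W + X) = T for the extra delay X = (O - T) / T * W. A tiny standalone check of that arithmetic with made-up numbers; this is only a sketch, not QuotaUtils.

object ThrottleMath {
  // Extra delay X such that observed * window / (window + X) equals the target rate.
  def throttleDelayMs(observedRate: Double, targetRate: Double, windowMs: Double): Long =
    Math.round((observedRate - targetRate) / targetRate * windowMs)

  def main(args: Array[String]): Unit = {
    // 15 MB/s observed against a 10 MB/s quota over a 10 s window needs a 5 s pause,
    // since 15 * 10 / (10 + 5) == 10.
    println(throttleDelayMs(observedRate = 15.0, targetRate = 10.0, windowMs = 10000.0)) // 5000
  }
}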
+ if (quotasEnabled) { + recordAndGetThrottleTimeMs(session, clientId, value, timeMs) + } else { + 0 + } + } + + /** + * Records that a user/clientId accumulated or would like to accumulate the provided amount at the + * the specified time, returns throttle time in milliseconds. + * + * @param session The session from which the user is extracted + * @param clientId The client id + * @param value The value to accumulate + * @param timeMs The time at which to accumulate the value + * @return The throttle time in milliseconds defines as the time to wait until the average + * rate gets back to the defined quota + */ + def recordAndGetThrottleTimeMs(session: Session, clientId: String, value: Double, timeMs: Long): Int = { + val clientSensors = getOrCreateQuotaSensors(session, clientId) + try { + clientSensors.quotaSensor.record(value, timeMs, true) + 0 + } catch { + case e: QuotaViolationException => + val throttleTimeMs = throttleTime(e, timeMs).toInt + debug(s"Quota violated for sensor (${clientSensors.quotaSensor.name}). Delay time: ($throttleTimeMs)") + throttleTimeMs + } + } + + /** + * Records that a user/clientId changed some metric being throttled without checking for + * quota violation. The aggregate value will subsequently be used for throttling when the + * next request is processed. + */ + def recordNoThrottle(session: Session, clientId: String, value: Double): Unit = { + val clientSensors = getOrCreateQuotaSensors(session, clientId) + clientSensors.quotaSensor.record(value, time.milliseconds(), false) + } + + /** + * "Unrecord" the given value that has already been recorded for the given user/client by recording a negative value + * of the same quantity. + * + * For a throttled fetch, the broker should return an empty response and thus should not record the value. Ideally, + * we would like to compute the throttle time before actually recording the value, but the current Sensor code + * couples value recording and quota checking very tightly. As a workaround, we will unrecord the value for the fetch + * in case of throttling. Rate keeps the sum of values that fall in each time window, so this should bring the + * overall sum back to the previous value. + */ + def unrecordQuotaSensor(request: RequestChannel.Request, value: Double, timeMs: Long): Unit = { + val clientSensors = getOrCreateQuotaSensors(request.session, request.header.clientId) + clientSensors.quotaSensor.record(value * (-1), timeMs, false) + } + + /** + * Returns maximum value that could be recorded without guaranteed throttling. + * Recording any larger value will always be throttled, even if no other values were recorded in the quota window. + * This is used for deciding the maximum bytes that can be fetched at once + */ + def getMaxValueInQuotaWindow(session: Session, clientId: String): Double = { + if (quotasEnabled) { + val clientSensors = getOrCreateQuotaSensors(session, clientId) + Option(quotaCallback.quotaLimit(clientQuotaType, clientSensors.metricTags.asJava)) + .map(_.toDouble * (config.numQuotaSamples - 1) * config.quotaWindowSizeSeconds) + .getOrElse(Double.MaxValue) + } else { + Double.MaxValue + } + } + + /** + * Throttle a client by muting the associated channel for the given throttle time. + * + * @param request client request + * @param throttleTimeMs Duration in milliseconds for which the channel is to be muted. 
+ * @param throttleCallback Callback for channel throttling + */ + def throttle( + request: RequestChannel.Request, + throttleCallback: ThrottleCallback, + throttleTimeMs: Int + ): Unit = { + if (throttleTimeMs > 0) { + val clientSensors = getOrCreateQuotaSensors(request.session, request.headerForLoggingOrThrottling().clientId) + clientSensors.throttleTimeSensor.record(throttleTimeMs) + val throttledChannel = new ThrottledChannel(time, throttleTimeMs, throttleCallback) + delayQueue.add(throttledChannel) + delayQueueSensor.record() + debug("Channel throttled for sensor (%s). Delay time: (%d)".format(clientSensors.quotaSensor.name(), throttleTimeMs)) + } + } + + /** + * Returns the quota for the client with the specified (non-encoded) user principal and client-id. + * + * Note: this method is expensive, it is meant to be used by tests only + */ + def quota(user: String, clientId: String): Quota = { + val userPrincipal = new KafkaPrincipal(KafkaPrincipal.USER_TYPE, user) + quota(userPrincipal, clientId) + } + + /** + * Returns the quota for the client with the specified user principal and client-id. + * + * Note: this method is expensive, it is meant to be used by tests only + */ + def quota(userPrincipal: KafkaPrincipal, clientId: String): Quota = { + val metricTags = quotaCallback.quotaMetricTags(clientQuotaType, userPrincipal, clientId) + Quota.upperBound(quotaLimit(metricTags)) + } + + private def quotaLimit(metricTags: util.Map[String, String]): Double = { + Option(quotaCallback.quotaLimit(clientQuotaType, metricTags)).map(_.toDouble).getOrElse(Long.MaxValue) + } + + /** + * This calculates the amount of time needed to bring the metric within quota + * assuming that no new metrics are recorded. + * + * See {QuotaUtils.throttleTime} for the details. + */ + protected def throttleTime(e: QuotaViolationException, timeMs: Long): Long = { + QuotaUtils.throttleTime(e, timeMs) + } + + /** + * This function either returns the sensors for a given client id or creates them if they don't exist + * First sensor of the tuple is the quota enforcement sensor. 
Second one is the throttle time sensor + */ + def getOrCreateQuotaSensors(session: Session, clientId: String): ClientSensors = { + // Use cached sanitized principal if using default callback + val metricTags = quotaCallback match { + case callback: DefaultQuotaCallback => callback.quotaMetricTags(session.sanitizedUser, clientId) + case _ => quotaCallback.quotaMetricTags(clientQuotaType, session.principal, clientId).asScala.toMap + } + // Names of the sensors to access + val sensors = ClientSensors( + metricTags, + sensorAccessor.getOrCreate( + getQuotaSensorName(metricTags), + ClientQuotaManager.InactiveSensorExpirationTimeSeconds, + registerQuotaMetrics(metricTags) + ), + sensorAccessor.getOrCreate( + getThrottleTimeSensorName(metricTags), + ClientQuotaManager.InactiveSensorExpirationTimeSeconds, + sensor => sensor.add(throttleMetricName(metricTags), new Avg) + ) + ) + if (quotaCallback.quotaResetRequired(clientQuotaType)) + updateQuotaMetricConfigs() + sensors + } + + protected def registerQuotaMetrics(metricTags: Map[String, String])(sensor: Sensor): Unit = { + sensor.add( + clientQuotaMetricName(metricTags), + new Rate, + getQuotaMetricConfig(metricTags) + ) + } + + private def metricTagsToSensorSuffix(metricTags: Map[String, String]): String = + metricTags.values.mkString(":") + + private def getThrottleTimeSensorName(metricTags: Map[String, String]): String = + s"${quotaType}ThrottleTime-${metricTagsToSensorSuffix(metricTags)}" + + private def getQuotaSensorName(metricTags: Map[String, String]): String = + s"$quotaType-${metricTagsToSensorSuffix(metricTags)}" + + protected def getQuotaMetricConfig(metricTags: Map[String, String]): MetricConfig = { + getQuotaMetricConfig(quotaLimit(metricTags.asJava)) + } + + private def getQuotaMetricConfig(quotaLimit: Double): MetricConfig = { + new MetricConfig() + .timeWindow(config.quotaWindowSizeSeconds, TimeUnit.SECONDS) + .samples(config.numQuotaSamples) + .quota(new Quota(quotaLimit, true)) + } + + protected def getOrCreateSensor(sensorName: String, metricName: MetricName): Sensor = { + sensorAccessor.getOrCreate( + sensorName, + ClientQuotaManager.InactiveSensorExpirationTimeSeconds, + sensor => sensor.add(metricName, new Rate) + ) + } + + /** + * Overrides quotas for , or or the dynamic defaults + * for any of these levels. + * + * @param sanitizedUser user to override if quota applies to or + * @param clientId client to override if quota applies to or + * @param sanitizedClientId sanitized client ID to override if quota applies to or + * @param quota custom quota to apply or None if quota override is being removed + */ + def updateQuota(sanitizedUser: Option[String], clientId: Option[String], sanitizedClientId: Option[String], quota: Option[Quota]): Unit = { + /* + * Acquire the write lock to apply changes in the quota objects. + * This method changes the quota in the overriddenQuota map and applies the update on the actual KafkaMetric object (if it exists). + * If the KafkaMetric hasn't been created, the most recent value will be used from the overriddenQuota map. + * The write lock prevents quota update and creation at the same time. 
It also guards against concurrent quota change + * notifications + */ + lock.writeLock().lock() + try { + val userEntity = sanitizedUser.map { + case ConfigEntityName.Default => DefaultUserEntity + case user => UserEntity(user) + } + val clientIdEntity = sanitizedClientId.map { + case ConfigEntityName.Default => DefaultClientIdEntity + case _ => ClientIdEntity(clientId.getOrElse(throw new IllegalStateException("Client-id not provided"))) + } + val quotaEntity = KafkaQuotaEntity(userEntity, clientIdEntity) + + if (userEntity.nonEmpty) { + if (quotaEntity.clientIdEntity.nonEmpty) + quotaTypesEnabled |= QuotaTypes.UserClientIdQuotaEnabled + else + quotaTypesEnabled |= QuotaTypes.UserQuotaEnabled + } else if (clientIdEntity.nonEmpty) + quotaTypesEnabled |= QuotaTypes.ClientIdQuotaEnabled + + quota match { + case Some(newQuota) => quotaCallback.updateQuota(clientQuotaType, quotaEntity, newQuota.bound) + case None => quotaCallback.removeQuota(clientQuotaType, quotaEntity) + } + val updatedEntity = if (userEntity.contains(DefaultUserEntity) || clientIdEntity.contains(DefaultClientIdEntity)) + None // more than one entity may need updating, so `updateQuotaMetricConfigs` will go through all metrics + else + Some(quotaEntity) + updateQuotaMetricConfigs(updatedEntity) + + } finally { + lock.writeLock().unlock() + } + } + + /** + * Updates metrics configs. This is invoked when quota configs are updated in ZooKeeper + * or when partitions leaders change and custom callbacks that implement partition-based quotas + * have updated quotas. + * + * @param updatedQuotaEntity If set to one entity and quotas have only been enabled at one + * level, then an optimized update is performed with a single metric update. If None is provided, + * or if custom callbacks are used or if multi-level quotas have been enabled, all metric configs + * are checked and updated if required. + */ + def updateQuotaMetricConfigs(updatedQuotaEntity: Option[KafkaQuotaEntity] = None): Unit = { + val allMetrics = metrics.metrics() + + // If using custom quota callbacks or if multiple-levels of quotas are defined or + // if this is a default quota update, traverse metrics to find all affected values. + // Otherwise, update just the single matching one. + val singleUpdate = quotaTypesEnabled match { + case QuotaTypes.NoQuotas | QuotaTypes.ClientIdQuotaEnabled | QuotaTypes.UserQuotaEnabled | QuotaTypes.UserClientIdQuotaEnabled => + updatedQuotaEntity.nonEmpty + case _ => false + } + if (singleUpdate) { + val quotaEntity = updatedQuotaEntity.getOrElse(throw new IllegalStateException("Quota entity not specified")) + val user = quotaEntity.sanitizedUser + val clientId = quotaEntity.clientId + val metricTags = Map(DefaultTags.User -> user, DefaultTags.ClientId -> clientId) + + val quotaMetricName = clientQuotaMetricName(metricTags) + // Change the underlying metric config if the sensor has been created + val metric = allMetrics.get(quotaMetricName) + if (metric != null) { + Option(quotaLimit(metricTags.asJava)).foreach { newQuota => + info(s"Sensor for $quotaEntity already exists. 
Changing quota to $newQuota in MetricConfig") + metric.config(getQuotaMetricConfig(newQuota)) + } + } + } else { + val quotaMetricName = clientQuotaMetricName(Map.empty) + allMetrics.forEach { (metricName, metric) => + if (metricName.name == quotaMetricName.name && metricName.group == quotaMetricName.group) { + val metricTags = metricName.tags + Option(quotaLimit(metricTags)).foreach { newQuota => + if (newQuota != metric.config.quota.bound) { + info(s"Sensor for quota-id $metricTags already exists. Setting quota to $newQuota in MetricConfig") + metric.config(getQuotaMetricConfig(newQuota)) + } + } + } + } + } + } + + /** + * Returns the MetricName of the metric used for the quota. The name is used to create the + * metric but also to find the metric when the quota is changed. + */ + protected def clientQuotaMetricName(quotaMetricTags: Map[String, String]): MetricName = { + metrics.metricName("byte-rate", quotaType.toString, + "Tracking byte-rate per user/client-id", + quotaMetricTags.asJava) + } + + private def throttleMetricName(quotaMetricTags: Map[String, String]): MetricName = { + metrics.metricName("throttle-time", + quotaType.toString, + "Tracking average throttle-time per user/client-id", + quotaMetricTags.asJava) + } + + def shutdown(): Unit = { + throttledChannelReaper.shutdown() + } + + class DefaultQuotaCallback extends ClientQuotaCallback { + private val overriddenQuotas = new ConcurrentHashMap[ClientQuotaEntity, Quota]() + + override def configure(configs: util.Map[String, _]): Unit = {} + + override def quotaMetricTags(quotaType: ClientQuotaType, principal: KafkaPrincipal, clientId: String): util.Map[String, String] = { + quotaMetricTags(Sanitizer.sanitize(principal.getName), clientId).asJava + } + + override def quotaLimit(quotaType: ClientQuotaType, metricTags: util.Map[String, String]): lang.Double = { + val sanitizedUser = metricTags.get(DefaultTags.User) + val clientId = metricTags.get(DefaultTags.ClientId) + var quota: Quota = null + + if (sanitizedUser != null && clientId != null) { + val userEntity = Some(UserEntity(sanitizedUser)) + val clientIdEntity = Some(ClientIdEntity(clientId)) + if (!sanitizedUser.isEmpty && !clientId.isEmpty) { + // /config/users//clients/ + quota = overriddenQuotas.get(KafkaQuotaEntity(userEntity, clientIdEntity)) + if (quota == null) { + // /config/users//clients/ + quota = overriddenQuotas.get(KafkaQuotaEntity(userEntity, Some(DefaultClientIdEntity))) + } + if (quota == null) { + // /config/users//clients/ + quota = overriddenQuotas.get(KafkaQuotaEntity(Some(DefaultUserEntity), clientIdEntity)) + } + if (quota == null) { + // /config/users//clients/ + quota = overriddenQuotas.get(DefaultUserClientIdQuotaEntity) + } + } else if (!sanitizedUser.isEmpty) { + // /config/users/ + quota = overriddenQuotas.get(KafkaQuotaEntity(userEntity, None)) + if (quota == null) { + // /config/users/ + quota = overriddenQuotas.get(DefaultUserQuotaEntity) + } + } else if (!clientId.isEmpty) { + // /config/clients/ + quota = overriddenQuotas.get(KafkaQuotaEntity(None, clientIdEntity)) + if (quota == null) { + // /config/clients/ + quota = overriddenQuotas.get(DefaultClientIdQuotaEntity) + } + } + } + if (quota == null) null else quota.bound + } + + override def updateClusterMetadata(cluster: Cluster): Boolean = { + // Default quota callback does not use any cluster metadata + false + } + + override def updateQuota(quotaType: ClientQuotaType, entity: ClientQuotaEntity, newValue: Double): Unit = { + val quotaEntity = entity.asInstanceOf[KafkaQuotaEntity] + 
info(s"Changing $quotaType quota for $quotaEntity to $newValue") + overriddenQuotas.put(quotaEntity, new Quota(newValue, true)) + } + + override def removeQuota(quotaType: ClientQuotaType, entity: ClientQuotaEntity): Unit = { + val quotaEntity = entity.asInstanceOf[KafkaQuotaEntity] + info(s"Removing $quotaType quota for $quotaEntity") + overriddenQuotas.remove(quotaEntity) + } + + override def quotaResetRequired(quotaType: ClientQuotaType): Boolean = false + + def quotaMetricTags(sanitizedUser: String, clientId: String) : Map[String, String] = { + val (userTag, clientIdTag) = quotaTypesEnabled match { + case QuotaTypes.NoQuotas | QuotaTypes.ClientIdQuotaEnabled => + ("", clientId) + case QuotaTypes.UserQuotaEnabled => + (sanitizedUser, "") + case QuotaTypes.UserClientIdQuotaEnabled => + (sanitizedUser, clientId) + case _ => + val userEntity = Some(UserEntity(sanitizedUser)) + val clientIdEntity = Some(ClientIdEntity(clientId)) + + var metricTags = (sanitizedUser, clientId) + // 1) /config/users//clients/ + if (!overriddenQuotas.containsKey(KafkaQuotaEntity(userEntity, clientIdEntity))) { + // 2) /config/users//clients/ + metricTags = (sanitizedUser, clientId) + if (!overriddenQuotas.containsKey(KafkaQuotaEntity(userEntity, Some(DefaultClientIdEntity)))) { + // 3) /config/users/ + metricTags = (sanitizedUser, "") + if (!overriddenQuotas.containsKey(KafkaQuotaEntity(userEntity, None))) { + // 4) /config/users//clients/ + metricTags = (sanitizedUser, clientId) + if (!overriddenQuotas.containsKey(KafkaQuotaEntity(Some(DefaultUserEntity), clientIdEntity))) { + // 5) /config/users//clients/ + metricTags = (sanitizedUser, clientId) + if (!overriddenQuotas.containsKey(DefaultUserClientIdQuotaEntity)) { + // 6) /config/users/ + metricTags = (sanitizedUser, "") + if (!overriddenQuotas.containsKey(DefaultUserQuotaEntity)) { + // 7) /config/clients/ + // 8) /config/clients/ + // 9) static client-id quota + metricTags = ("", clientId) + } + } + } + } + } + } + metricTags + } + Map(DefaultTags.User -> userTag, DefaultTags.ClientId -> clientIdTag) + } + + override def close(): Unit = {} + } +} diff --git a/core/src/main/scala/kafka/server/ClientRequestQuotaManager.scala b/core/src/main/scala/kafka/server/ClientRequestQuotaManager.scala new file mode 100644 index 0000000..2ceaab9 --- /dev/null +++ b/core/src/main/scala/kafka/server/ClientRequestQuotaManager.scala @@ -0,0 +1,90 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package kafka.server + +import java.util.concurrent.TimeUnit + +import kafka.network.RequestChannel +import kafka.utils.QuotaUtils +import org.apache.kafka.common.MetricName +import org.apache.kafka.common.metrics._ +import org.apache.kafka.common.utils.Time +import org.apache.kafka.server.quota.ClientQuotaCallback + +import scala.jdk.CollectionConverters._ + +object ClientRequestQuotaManager { + val QuotaRequestPercentDefault = Int.MaxValue.toDouble + val NanosToPercentagePerSecond = 100.0 / TimeUnit.SECONDS.toNanos(1) + + private val ExemptSensorName = "exempt-" + QuotaType.Request +} + +class ClientRequestQuotaManager(private val config: ClientQuotaManagerConfig, + private val metrics: Metrics, + private val time: Time, + private val threadNamePrefix: String, + private val quotaCallback: Option[ClientQuotaCallback]) + extends ClientQuotaManager(config, metrics, QuotaType.Request, time, threadNamePrefix, quotaCallback) { + + private val maxThrottleTimeMs = TimeUnit.SECONDS.toMillis(this.config.quotaWindowSizeSeconds) + private val exemptMetricName = metrics.metricName("exempt-request-time", + QuotaType.Request.toString, "Tracking exempt-request-time utilization percentage") + + lazy val exemptSensor: Sensor = getOrCreateSensor(ClientRequestQuotaManager.ExemptSensorName, exemptMetricName) + + def recordExempt(value: Double): Unit = { + exemptSensor.record(value) + } + + /** + * Records that a user/clientId changed request processing time being throttled. If quota has been violated, return + * throttle time in milliseconds. Throttle time calculation may be overridden by sub-classes. + * @param request client request + * @return Number of milliseconds to throttle in case of quota violation. Zero otherwise + */ + def maybeRecordAndGetThrottleTimeMs(request: RequestChannel.Request, timeMs: Long): Int = { + if (quotasEnabled) { + request.recordNetworkThreadTimeCallback = Some(timeNanos => recordNoThrottle( + request.session, request.header.clientId, nanosToPercentage(timeNanos))) + recordAndGetThrottleTimeMs(request.session, request.header.clientId, + nanosToPercentage(request.requestThreadTimeNanos), timeMs) + } else { + 0 + } + } + + def maybeRecordExempt(request: RequestChannel.Request): Unit = { + if (quotasEnabled) { + request.recordNetworkThreadTimeCallback = Some(timeNanos => recordExempt(nanosToPercentage(timeNanos))) + recordExempt(nanosToPercentage(request.requestThreadTimeNanos)) + } + } + + override protected def throttleTime(e: QuotaViolationException, timeMs: Long): Long = { + QuotaUtils.boundedThrottleTime(e, maxThrottleTimeMs, timeMs) + } + + override protected def clientQuotaMetricName(quotaMetricTags: Map[String, String]): MetricName = { + metrics.metricName("request-time", QuotaType.Request.toString, + "Tracking request-time per user/client-id", + quotaMetricTags.asJava) + } + + private def nanosToPercentage(nanos: Long): Double = + nanos * ClientRequestQuotaManager.NanosToPercentagePerSecond +} diff --git a/core/src/main/scala/kafka/server/ConfigHandler.scala b/core/src/main/scala/kafka/server/ConfigHandler.scala new file mode 100644 index 0000000..ab8639b --- /dev/null +++ b/core/src/main/scala/kafka/server/ConfigHandler.scala @@ -0,0 +1,249 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
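Request quotas in ClientRequestQuotaManager above are expressed as a percentage of one request-handler thread, so thread time in nanoseconds is scaled by 100 / 10^9 before being recorded. A quick standalone check of that conversion; the constant mirrors NanosToPercentagePerSecond and the 250 ms figure is only an example.

import java.util.concurrent.TimeUnit

object RequestPercentage {
  val NanosToPercentagePerSecond: Double = 100.0 / TimeUnit.SECONDS.toNanos(1)

  def nanosToPercentage(nanos: Long): Double = nanos * NanosToPercentagePerSecond

  def main(args: Array[String]): Unit = {
    // 250 ms of handler-thread time in a one-second window is 25% of one thread.
    println(nanosToPercentage(TimeUnit.MILLISECONDS.toNanos(250))) // 25.0
  }
}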
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.server + +import java.net.{InetAddress, UnknownHostException} +import java.util.Properties +import DynamicConfig.Broker._ +import kafka.controller.KafkaController +import kafka.log.LogConfig.MessageFormatVersion +import kafka.log.{LogConfig, LogManager} +import kafka.network.ConnectionQuotas +import kafka.security.CredentialProvider +import kafka.server.Constants._ +import kafka.server.QuotaFactory.QuotaManagers +import kafka.utils.Implicits._ +import kafka.utils.Logging +import org.apache.kafka.common.config.ConfigDef.Validator +import org.apache.kafka.common.config.ConfigException +import org.apache.kafka.common.config.internals.QuotaConfigs +import org.apache.kafka.common.metrics.Quota +import org.apache.kafka.common.metrics.Quota._ +import org.apache.kafka.common.utils.Sanitizer + +import scala.annotation.nowarn +import scala.jdk.CollectionConverters._ +import scala.collection.Seq +import scala.util.Try + +/** + * The ConfigHandler is used to process config change notifications received by the DynamicConfigManager + */ +trait ConfigHandler { + def processConfigChanges(entityName: String, value: Properties): Unit +} + +/** + * The TopicConfigHandler will process topic config changes from ZooKeeper or the metadata log. + * The callback provides the topic name and the full properties set. + */ +class TopicConfigHandler(private val logManager: LogManager, kafkaConfig: KafkaConfig, + val quotas: QuotaManagers, kafkaController: Option[KafkaController]) extends ConfigHandler with Logging { + + def processConfigChanges(topic: String, topicConfig: Properties): Unit = { + // Validate the configurations. 
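parseThrottledPartitions below accepts either "*" (all replicas) or a comma-separated list of partitionId:brokerId pairs and keeps only the partitions throttled on this broker. A standalone sketch of that filtering, with Seq(-1) standing in for Constants.AllReplicas:

object ThrottledReplicas {
  // "*" throttles all replicas; otherwise "partitionId:brokerId,partitionId:brokerId,..."
  def partitionsThrottledOnBroker(configValue: String, brokerId: Int): Seq[Int] = {
    val value = configValue.trim
    if (value.isEmpty) Seq.empty
    else if (value == "*") Seq(-1) // stand-in for Constants.AllReplicas
    else value.split(",").toSeq
      .map(_.split(":"))
      .filter(parts => parts(1).trim.toInt == brokerId) // keep entries for this broker
      .map(parts => parts(0).trim.toInt)
  }

  def main(args: Array[String]): Unit = {
    // Partitions 0 and 2 are throttled on broker 1; partition 1 is throttled on broker 2.
    println(partitionsThrottledOnBroker("0:1,1:2,2:1", brokerId = 1)) // partitions 0 and 2
  }
}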
+ val configNamesToExclude = excludedConfigs(topic, topicConfig) + val props = new Properties() + topicConfig.asScala.forKeyValue { (key, value) => + if (!configNamesToExclude.contains(key)) props.put(key, value) + } + logManager.updateTopicConfig(topic, props) + + def updateThrottledList(prop: String, quotaManager: ReplicationQuotaManager) = { + if (topicConfig.containsKey(prop) && topicConfig.getProperty(prop).nonEmpty) { + val partitions = parseThrottledPartitions(topicConfig, kafkaConfig.brokerId, prop) + quotaManager.markThrottled(topic, partitions) + debug(s"Setting $prop on broker ${kafkaConfig.brokerId} for topic: $topic and partitions $partitions") + } else { + quotaManager.removeThrottle(topic) + debug(s"Removing $prop from broker ${kafkaConfig.brokerId} for topic $topic") + } + } + updateThrottledList(LogConfig.LeaderReplicationThrottledReplicasProp, quotas.leader) + updateThrottledList(LogConfig.FollowerReplicationThrottledReplicasProp, quotas.follower) + + if (Try(topicConfig.getProperty(KafkaConfig.UncleanLeaderElectionEnableProp).toBoolean).getOrElse(false)) { + kafkaController.foreach(_.enableTopicUncleanLeaderElection(topic)) + } + } + + def parseThrottledPartitions(topicConfig: Properties, brokerId: Int, prop: String): Seq[Int] = { + val configValue = topicConfig.get(prop).toString.trim + ThrottledReplicaListValidator.ensureValidString(prop, configValue) + configValue match { + case "" => Seq() + case "*" => AllReplicas + case _ => configValue.trim + .split(",") + .map(_.split(":")) + .filter(_ (1).toInt == brokerId) //Filter this replica + .map(_ (0).toInt).toSeq //convert to list of partition ids + } + } + + @nowarn("cat=deprecation") + def excludedConfigs(topic: String, topicConfig: Properties): Set[String] = { + // Verify message format version + Option(topicConfig.getProperty(LogConfig.MessageFormatVersionProp)).flatMap { versionString => + val messageFormatVersion = new MessageFormatVersion(versionString, kafkaConfig.interBrokerProtocolVersion.version) + if (messageFormatVersion.shouldIgnore) { + if (messageFormatVersion.shouldWarn) + warn(messageFormatVersion.topicWarningMessage(topic)) + Some(LogConfig.MessageFormatVersionProp) + } else if (kafkaConfig.interBrokerProtocolVersion < messageFormatVersion.messageFormatVersion) { + warn(s"Topic configuration ${LogConfig.MessageFormatVersionProp} is ignored for `$topic` because `$versionString` " + + s"is higher than what is allowed by the inter-broker protocol version `${kafkaConfig.interBrokerProtocolVersionString}`") + Some(LogConfig.MessageFormatVersionProp) + } else + None + }.toSet + } +} + + +/** + * Handles , or quota config updates in ZK. 
+ * This implementation reports the overrides to the respective ClientQuotaManager objects + */ +class QuotaConfigHandler(private val quotaManagers: QuotaManagers) { + + def updateQuotaConfig(sanitizedUser: Option[String], sanitizedClientId: Option[String], config: Properties): Unit = { + val clientId = sanitizedClientId.map(Sanitizer.desanitize) + val producerQuota = + if (config.containsKey(QuotaConfigs.PRODUCER_BYTE_RATE_OVERRIDE_CONFIG)) + Some(new Quota(config.getProperty(QuotaConfigs.PRODUCER_BYTE_RATE_OVERRIDE_CONFIG).toLong.toDouble, true)) + else + None + quotaManagers.produce.updateQuota(sanitizedUser, clientId, sanitizedClientId, producerQuota) + val consumerQuota = + if (config.containsKey(QuotaConfigs.CONSUMER_BYTE_RATE_OVERRIDE_CONFIG)) + Some(new Quota(config.getProperty(QuotaConfigs.CONSUMER_BYTE_RATE_OVERRIDE_CONFIG).toLong.toDouble, true)) + else + None + quotaManagers.fetch.updateQuota(sanitizedUser, clientId, sanitizedClientId, consumerQuota) + val requestQuota = + if (config.containsKey(QuotaConfigs.REQUEST_PERCENTAGE_OVERRIDE_CONFIG)) + Some(new Quota(config.getProperty(QuotaConfigs.REQUEST_PERCENTAGE_OVERRIDE_CONFIG).toDouble, true)) + else + None + quotaManagers.request.updateQuota(sanitizedUser, clientId, sanitizedClientId, requestQuota) + val controllerMutationQuota = + if (config.containsKey(QuotaConfigs.CONTROLLER_MUTATION_RATE_OVERRIDE_CONFIG)) + Some(new Quota(config.getProperty(QuotaConfigs.CONTROLLER_MUTATION_RATE_OVERRIDE_CONFIG).toDouble, true)) + else + None + quotaManagers.controllerMutation.updateQuota(sanitizedUser, clientId, sanitizedClientId, controllerMutationQuota) + } +} + +/** + * The ClientIdConfigHandler will process clientId config changes in ZK. + * The callback provides the clientId and the full properties set read from ZK. + */ +class ClientIdConfigHandler(private val quotaManagers: QuotaManagers) extends QuotaConfigHandler(quotaManagers) with ConfigHandler { + + def processConfigChanges(sanitizedClientId: String, clientConfig: Properties): Unit = { + updateQuotaConfig(None, Some(sanitizedClientId), clientConfig) + } +} + +/** + * The UserConfigHandler will process and quota changes in ZK. + * The callback provides the node name containing sanitized user principal, sanitized client-id if this is + * a update and the full properties set read from ZK. 
+ */ +class UserConfigHandler(private val quotaManagers: QuotaManagers, val credentialProvider: CredentialProvider) extends QuotaConfigHandler(quotaManagers) with ConfigHandler { + + def processConfigChanges(quotaEntityPath: String, config: Properties): Unit = { + // Entity path is or /clients/ + val entities = quotaEntityPath.split("/") + if (entities.length != 1 && entities.length != 3) + throw new IllegalArgumentException("Invalid quota entity path: " + quotaEntityPath) + val sanitizedUser = entities(0) + val sanitizedClientId = if (entities.length == 3) Some(entities(2)) else None + updateQuotaConfig(Some(sanitizedUser), sanitizedClientId, config) + if (!sanitizedClientId.isDefined && sanitizedUser != ConfigEntityName.Default) + credentialProvider.updateCredentials(Sanitizer.desanitize(sanitizedUser), config) + } +} + +class IpConfigHandler(private val connectionQuotas: ConnectionQuotas) extends ConfigHandler with Logging { + + def processConfigChanges(ip: String, config: Properties): Unit = { + val ipConnectionRateQuota = Option(config.getProperty(QuotaConfigs.IP_CONNECTION_RATE_OVERRIDE_CONFIG)).map(_.toInt) + val updatedIp = { + if (ip != ConfigEntityName.Default) { + try { + Some(InetAddress.getByName(ip)) + } catch { + case _: UnknownHostException => throw new IllegalArgumentException(s"Unable to resolve address $ip") + } + } else + None + } + connectionQuotas.updateIpConnectionRateQuota(updatedIp, ipConnectionRateQuota) + } +} + +/** + * The BrokerConfigHandler will process individual broker config changes in ZK. + * The callback provides the brokerId and the full properties set read from ZK. + * This implementation reports the overrides to the respective ReplicationQuotaManager objects + */ +class BrokerConfigHandler(private val brokerConfig: KafkaConfig, + private val quotaManagers: QuotaManagers) extends ConfigHandler with Logging { + + def processConfigChanges(brokerId: String, properties: Properties): Unit = { + def getOrDefault(prop: String): Long = { + if (properties.containsKey(prop)) + properties.getProperty(prop).toLong + else + DefaultReplicationThrottledRate + } + if (brokerId == ConfigEntityName.Default) + brokerConfig.dynamicConfig.updateDefaultConfig(properties) + else if (brokerConfig.brokerId == brokerId.trim.toInt) { + brokerConfig.dynamicConfig.updateBrokerConfig(brokerConfig.brokerId, properties) + quotaManagers.leader.updateQuota(upperBound(getOrDefault(LeaderReplicationThrottledRateProp).toDouble)) + quotaManagers.follower.updateQuota(upperBound(getOrDefault(FollowerReplicationThrottledRateProp).toDouble)) + quotaManagers.alterLogDirs.updateQuota(upperBound(getOrDefault(ReplicaAlterLogDirsIoMaxBytesPerSecondProp).toDouble)) + } + } +} + +object ThrottledReplicaListValidator extends Validator { + def ensureValidString(name: String, value: String): Unit = + ensureValid(name, value.split(",").map(_.trim).toSeq) + + override def ensureValid(name: String, value: Any): Unit = { + def check(proposed: Seq[Any]): Unit = { + if (!(proposed.forall(_.toString.trim.matches("([0-9]+:[0-9]+)?")) + || proposed.mkString.trim.equals("*"))) + throw new ConfigException(name, value, + s"$name must be the literal '*' or a list of replicas in the following format: [partitionId]:[brokerId],[partitionId]:[brokerId],...") + } + value match { + case scalaSeq: Seq[_] => check(scalaSeq) + case javaList: java.util.List[_] => check(javaList.asScala) + case _ => throw new ConfigException(name, value, s"$name must be a List but was ${value.getClass.getName}") + } + } + + override def 
toString: String = "[partitionId]:[brokerId],[partitionId]:[brokerId],..." + +} diff --git a/core/src/main/scala/kafka/server/ConfigHelper.scala b/core/src/main/scala/kafka/server/ConfigHelper.scala new file mode 100644 index 0000000..174d998 --- /dev/null +++ b/core/src/main/scala/kafka/server/ConfigHelper.scala @@ -0,0 +1,216 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.server + +import java.util.{Collections, Properties} + +import kafka.log.LogConfig +import kafka.server.metadata.ConfigRepository +import kafka.utils.{Log4jController, Logging} +import org.apache.kafka.common.config.{AbstractConfig, ConfigDef, ConfigResource} +import org.apache.kafka.common.errors.{ApiException, InvalidRequestException} +import org.apache.kafka.common.internals.Topic +import org.apache.kafka.common.message.DescribeConfigsRequestData.DescribeConfigsResource +import org.apache.kafka.common.message.DescribeConfigsResponseData +import org.apache.kafka.common.protocol.Errors +import org.apache.kafka.common.requests.{ApiError, DescribeConfigsResponse} +import org.apache.kafka.common.requests.DescribeConfigsResponse.ConfigSource + +import scala.collection.{Map, mutable} +import scala.jdk.CollectionConverters._ + +class ConfigHelper(metadataCache: MetadataCache, config: KafkaConfig, configRepository: ConfigRepository) extends Logging { + + def allConfigs(config: AbstractConfig) = { + config.originals.asScala.filter(_._2 != null) ++ config.nonInternalValues.asScala + } + + def describeConfigs(resourceToConfigNames: List[DescribeConfigsResource], + includeSynonyms: Boolean, + includeDocumentation: Boolean): List[DescribeConfigsResponseData.DescribeConfigsResult] = { + resourceToConfigNames.map { case resource => + + def createResponseConfig(configs: Map[String, Any], + createConfigEntry: (String, Any) => DescribeConfigsResponseData.DescribeConfigsResourceResult): DescribeConfigsResponseData.DescribeConfigsResult = { + val filteredConfigPairs = if (resource.configurationKeys == null || resource.configurationKeys.isEmpty) + configs.toBuffer + else + configs.filter { case (configName, _) => + resource.configurationKeys.asScala.contains(configName) + }.toBuffer + + val configEntries = filteredConfigPairs.map { case (name, value) => createConfigEntry(name, value) } + new DescribeConfigsResponseData.DescribeConfigsResult().setErrorCode(Errors.NONE.code) + .setConfigs(configEntries.asJava) + } + + try { + val configResult = ConfigResource.Type.forId(resource.resourceType) match { + case ConfigResource.Type.TOPIC => + val topic = resource.resourceName + Topic.validate(topic) + if (metadataCache.contains(topic)) { + val topicProps = configRepository.topicConfig(topic) + val logConfig = LogConfig.fromProps(LogConfig.extractLogConfigMap(config), 
topicProps) + createResponseConfig(allConfigs(logConfig), createTopicConfigEntry(logConfig, topicProps, includeSynonyms, includeDocumentation)) + } else { + new DescribeConfigsResponseData.DescribeConfigsResult().setErrorCode(Errors.UNKNOWN_TOPIC_OR_PARTITION.code) + .setConfigs(Collections.emptyList[DescribeConfigsResponseData.DescribeConfigsResourceResult]) + } + + case ConfigResource.Type.BROKER => + if (resource.resourceName == null || resource.resourceName.isEmpty) + createResponseConfig(config.dynamicConfig.currentDynamicDefaultConfigs, + createBrokerConfigEntry(perBrokerConfig = false, includeSynonyms, includeDocumentation)) + else if (resourceNameToBrokerId(resource.resourceName) == config.brokerId) + createResponseConfig(allConfigs(config), + createBrokerConfigEntry(perBrokerConfig = true, includeSynonyms, includeDocumentation)) + else + throw new InvalidRequestException(s"Unexpected broker id, expected ${config.brokerId} or empty string, but received ${resource.resourceName}") + + case ConfigResource.Type.BROKER_LOGGER => + if (resource.resourceName == null || resource.resourceName.isEmpty) + throw new InvalidRequestException("Broker id must not be empty") + else if (resourceNameToBrokerId(resource.resourceName) != config.brokerId) + throw new InvalidRequestException(s"Unexpected broker id, expected ${config.brokerId} but received ${resource.resourceName}") + else + createResponseConfig(Log4jController.loggers, + (name, value) => new DescribeConfigsResponseData.DescribeConfigsResourceResult().setName(name) + .setValue(value.toString).setConfigSource(ConfigSource.DYNAMIC_BROKER_LOGGER_CONFIG.id) + .setIsSensitive(false).setReadOnly(false).setSynonyms(List.empty.asJava)) + case resourceType => throw new InvalidRequestException(s"Unsupported resource type: $resourceType") + } + configResult.setResourceName(resource.resourceName).setResourceType(resource.resourceType) + } catch { + case e: Throwable => + // Log client errors at a lower level than unexpected exceptions + val message = s"Error processing describe configs request for resource $resource" + if (e.isInstanceOf[ApiException]) + info(message, e) + else + error(message, e) + val err = ApiError.fromThrowable(e) + new DescribeConfigsResponseData.DescribeConfigsResult() + .setResourceName(resource.resourceName) + .setResourceType(resource.resourceType) + .setErrorMessage(err.message) + .setErrorCode(err.error.code) + .setConfigs(Collections.emptyList[DescribeConfigsResponseData.DescribeConfigsResourceResult]) + } + } + } + + def createTopicConfigEntry(logConfig: LogConfig, topicProps: Properties, includeSynonyms: Boolean, includeDocumentation: Boolean) + (name: String, value: Any): DescribeConfigsResponseData.DescribeConfigsResourceResult = { + val configEntryType = LogConfig.configType(name) + val isSensitive = KafkaConfig.maybeSensitive(configEntryType) + val valueAsString = if (isSensitive) null else ConfigDef.convertToString(value, configEntryType.orNull) + val allSynonyms = { + val list = LogConfig.TopicConfigSynonyms.get(name) + .map(s => configSynonyms(s, brokerSynonyms(s), isSensitive)) + .getOrElse(List.empty) + if (!topicProps.containsKey(name)) + list + else + new DescribeConfigsResponseData.DescribeConfigsSynonym().setName(name).setValue(valueAsString) + .setSource(ConfigSource.TOPIC_CONFIG.id) +: list + } + val source = if (allSynonyms.isEmpty) ConfigSource.DEFAULT_CONFIG.id else allSynonyms.head.source + val synonyms = if (!includeSynonyms) List.empty else allSynonyms + val dataType = 
configResponseType(configEntryType) + val configDocumentation = if (includeDocumentation) logConfig.documentationOf(name) else null + new DescribeConfigsResponseData.DescribeConfigsResourceResult() + .setName(name).setValue(valueAsString).setConfigSource(source) + .setIsSensitive(isSensitive).setReadOnly(false).setSynonyms(synonyms.asJava) + .setDocumentation(configDocumentation).setConfigType(dataType.id) + } + + private def createBrokerConfigEntry(perBrokerConfig: Boolean, includeSynonyms: Boolean, includeDocumentation: Boolean) + (name: String, value: Any): DescribeConfigsResponseData.DescribeConfigsResourceResult = { + val allNames = brokerSynonyms(name) + val configEntryType = KafkaConfig.configType(name) + val isSensitive = KafkaConfig.maybeSensitive(configEntryType) + val valueAsString = if (isSensitive) + null + else value match { + case v: String => v + case _ => ConfigDef.convertToString(value, configEntryType.orNull) + } + val allSynonyms = configSynonyms(name, allNames, isSensitive) + .filter(perBrokerConfig || _.source == ConfigSource.DYNAMIC_DEFAULT_BROKER_CONFIG.id) + val synonyms = if (!includeSynonyms) List.empty else allSynonyms + val source = if (allSynonyms.isEmpty) ConfigSource.DEFAULT_CONFIG.id else allSynonyms.head.source + val readOnly = !DynamicBrokerConfig.AllDynamicConfigs.contains(name) + + val dataType = configResponseType(configEntryType) + val configDocumentation = if (includeDocumentation) brokerDocumentation(name) else null + new DescribeConfigsResponseData.DescribeConfigsResourceResult().setName(name).setValue(valueAsString).setConfigSource(source) + .setIsSensitive(isSensitive).setReadOnly(readOnly).setSynonyms(synonyms.asJava) + .setDocumentation(configDocumentation).setConfigType(dataType.id) + } + + private def configSynonyms(name: String, synonyms: List[String], isSensitive: Boolean): List[DescribeConfigsResponseData.DescribeConfigsSynonym] = { + val dynamicConfig = config.dynamicConfig + val allSynonyms = mutable.Buffer[DescribeConfigsResponseData.DescribeConfigsSynonym]() + + def maybeAddSynonym(map: Map[String, String], source: ConfigSource)(name: String): Unit = { + map.get(name).map { value => + val configValue = if (isSensitive) null else value + allSynonyms += new DescribeConfigsResponseData.DescribeConfigsSynonym().setName(name).setValue(configValue).setSource(source.id) + } + } + + synonyms.foreach(maybeAddSynonym(dynamicConfig.currentDynamicBrokerConfigs, ConfigSource.DYNAMIC_BROKER_CONFIG)) + synonyms.foreach(maybeAddSynonym(dynamicConfig.currentDynamicDefaultConfigs, ConfigSource.DYNAMIC_DEFAULT_BROKER_CONFIG)) + synonyms.foreach(maybeAddSynonym(dynamicConfig.staticBrokerConfigs, ConfigSource.STATIC_BROKER_CONFIG)) + synonyms.foreach(maybeAddSynonym(dynamicConfig.staticDefaultConfigs, ConfigSource.DEFAULT_CONFIG)) + allSynonyms.dropWhile(s => s.name != name).toList // e.g. 
drop listener overrides when describing base config + } + + private def brokerSynonyms(name: String): List[String] = { + DynamicBrokerConfig.brokerConfigSynonyms(name, matchListenerOverride = true) + } + + private def brokerDocumentation(name: String): String = { + config.documentationOf(name) + } + + private def configResponseType(configType: Option[ConfigDef.Type]): DescribeConfigsResponse.ConfigType = { + if (configType.isEmpty) + DescribeConfigsResponse.ConfigType.UNKNOWN + else configType.get match { + case ConfigDef.Type.BOOLEAN => DescribeConfigsResponse.ConfigType.BOOLEAN + case ConfigDef.Type.STRING => DescribeConfigsResponse.ConfigType.STRING + case ConfigDef.Type.INT => DescribeConfigsResponse.ConfigType.INT + case ConfigDef.Type.SHORT => DescribeConfigsResponse.ConfigType.SHORT + case ConfigDef.Type.LONG => DescribeConfigsResponse.ConfigType.LONG + case ConfigDef.Type.DOUBLE => DescribeConfigsResponse.ConfigType.DOUBLE + case ConfigDef.Type.LIST => DescribeConfigsResponse.ConfigType.LIST + case ConfigDef.Type.CLASS => DescribeConfigsResponse.ConfigType.CLASS + case ConfigDef.Type.PASSWORD => DescribeConfigsResponse.ConfigType.PASSWORD + case _ => DescribeConfigsResponse.ConfigType.UNKNOWN + } + } + + private def resourceNameToBrokerId(resourceName: String): Int = { + try resourceName.toInt catch { + case _: NumberFormatException => + throw new InvalidRequestException(s"Broker id must be an integer, but it is: $resourceName") + } + } +} diff --git a/core/src/main/scala/kafka/server/ControllerApis.scala b/core/src/main/scala/kafka/server/ControllerApis.scala new file mode 100644 index 0000000..ed9b55a --- /dev/null +++ b/core/src/main/scala/kafka/server/ControllerApis.scala @@ -0,0 +1,778 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
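configSynonyms above gathers candidate values in precedence order (dynamic per-broker, then dynamic cluster default, then static broker properties, then static defaults), and the head of the surviving list supplies the value and source that describeConfigs reports. The sketch below illustrates just that layering with plain maps; the source labels echo the ConfigSource names used above, the config key is only an example, and listener-prefixed synonyms are ignored.

object ConfigPrecedence {
  final case class Synonym(name: String, value: String, source: String)

  // Layers are checked from highest to lowest precedence; the first layer containing the key wins.
  def effectiveConfig(name: String,
                      dynamicBroker: Map[String, String],
                      dynamicDefault: Map[String, String],
                      staticBroker: Map[String, String]): Option[Synonym] = {
    val layers = Seq(
      "DYNAMIC_BROKER_CONFIG" -> dynamicBroker,
      "DYNAMIC_DEFAULT_BROKER_CONFIG" -> dynamicDefault,
      "STATIC_BROKER_CONFIG" -> staticBroker
    )
    layers.collectFirst {
      case (source, layer) if layer.contains(name) => Synonym(name, layer(name), source)
    }
  }

  def main(args: Array[String]): Unit = {
    val winner = effectiveConfig(
      "log.retention.ms",
      dynamicBroker = Map.empty,
      dynamicDefault = Map("log.retention.ms" -> "86400000"),
      staticBroker = Map("log.retention.ms" -> "604800000")
    )
    // The dynamic cluster default shadows the value from server.properties.
    println(winner) // Some(Synonym(log.retention.ms,86400000,DYNAMIC_DEFAULT_BROKER_CONFIG))
  }
}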
+ */ + +package kafka.server + +import java.util +import java.util.Collections +import java.util.Map.Entry +import java.util.concurrent.TimeUnit.{MILLISECONDS, NANOSECONDS} +import java.util.concurrent.{CompletableFuture, ExecutionException} + +import kafka.network.RequestChannel +import kafka.raft.RaftManager +import kafka.server.QuotaFactory.QuotaManagers +import kafka.utils.Logging +import org.apache.kafka.clients.admin.AlterConfigOp +import org.apache.kafka.common.Uuid.ZERO_UUID +import org.apache.kafka.common.acl.AclOperation.{ALTER, ALTER_CONFIGS, CLUSTER_ACTION, CREATE, DELETE, DESCRIBE} +import org.apache.kafka.common.config.ConfigResource +import org.apache.kafka.common.errors.{ApiException, ClusterAuthorizationException, InvalidRequestException, TopicDeletionDisabledException} +import org.apache.kafka.common.internals.FatalExitError +import org.apache.kafka.common.message.AlterConfigsResponseData.{AlterConfigsResourceResponse => OldAlterConfigsResourceResponse} +import org.apache.kafka.common.message.CreatePartitionsRequestData.CreatePartitionsTopic +import org.apache.kafka.common.message.CreatePartitionsResponseData.CreatePartitionsTopicResult +import org.apache.kafka.common.message.CreateTopicsResponseData.CreatableTopicResult +import org.apache.kafka.common.message.DeleteTopicsResponseData.{DeletableTopicResult, DeletableTopicResultCollection} +import org.apache.kafka.common.message.IncrementalAlterConfigsResponseData.AlterConfigsResourceResponse +import org.apache.kafka.common.message.{CreateTopicsRequestData, _} +import org.apache.kafka.common.protocol.Errors._ +import org.apache.kafka.common.protocol.{ApiKeys, ApiMessage, Errors} +import org.apache.kafka.common.requests._ +import org.apache.kafka.common.resource.Resource.CLUSTER_NAME +import org.apache.kafka.common.resource.ResourceType.{CLUSTER, TOPIC} +import org.apache.kafka.common.utils.Time +import org.apache.kafka.common.{Node, Uuid} +import org.apache.kafka.controller.Controller +import org.apache.kafka.metadata.{BrokerHeartbeatReply, BrokerRegistrationReply, VersionRange} +import org.apache.kafka.server.authorizer.Authorizer +import org.apache.kafka.server.common.ApiMessageAndVersion + +import scala.jdk.CollectionConverters._ + + +/** + * Request handler for Controller APIs + */ +class ControllerApis(val requestChannel: RequestChannel, + val authorizer: Option[Authorizer], + val quotas: QuotaManagers, + val time: Time, + val supportedFeatures: Map[String, VersionRange], + val controller: Controller, + val raftManager: RaftManager[ApiMessageAndVersion], + val config: KafkaConfig, + val metaProperties: MetaProperties, + val controllerNodes: Seq[Node], + val apiVersionManager: ApiVersionManager) extends ApiRequestHandler with Logging { + + val authHelper = new AuthHelper(authorizer) + val requestHelper = new RequestHandlerHelper(requestChannel, quotas, time) + private val aclApis = new AclApis(authHelper, authorizer, requestHelper, "controller", config) + + def isClosed: Boolean = aclApis.isClosed + + def close(): Unit = aclApis.close() + + override def handle(request: RequestChannel.Request, requestLocal: RequestLocal): Unit = { + try { + request.header.apiKey match { + case ApiKeys.FETCH => handleFetch(request) + case ApiKeys.FETCH_SNAPSHOT => handleFetchSnapshot(request) + case ApiKeys.CREATE_TOPICS => handleCreateTopics(request) + case ApiKeys.DELETE_TOPICS => handleDeleteTopics(request) + case ApiKeys.API_VERSIONS => handleApiVersionsRequest(request) + case ApiKeys.ALTER_CONFIGS => 
handleLegacyAlterConfigs(request) + case ApiKeys.VOTE => handleVote(request) + case ApiKeys.BEGIN_QUORUM_EPOCH => handleBeginQuorumEpoch(request) + case ApiKeys.END_QUORUM_EPOCH => handleEndQuorumEpoch(request) + case ApiKeys.DESCRIBE_QUORUM => handleDescribeQuorum(request) + case ApiKeys.ALTER_ISR => handleAlterIsrRequest(request) + case ApiKeys.BROKER_REGISTRATION => handleBrokerRegistration(request) + case ApiKeys.BROKER_HEARTBEAT => handleBrokerHeartBeatRequest(request) + case ApiKeys.UNREGISTER_BROKER => handleUnregisterBroker(request) + case ApiKeys.ALTER_CLIENT_QUOTAS => handleAlterClientQuotas(request) + case ApiKeys.INCREMENTAL_ALTER_CONFIGS => handleIncrementalAlterConfigs(request) + case ApiKeys.ALTER_PARTITION_REASSIGNMENTS => handleAlterPartitionReassignments(request) + case ApiKeys.LIST_PARTITION_REASSIGNMENTS => handleListPartitionReassignments(request) + case ApiKeys.ENVELOPE => handleEnvelopeRequest(request, requestLocal) + case ApiKeys.SASL_HANDSHAKE => handleSaslHandshakeRequest(request) + case ApiKeys.SASL_AUTHENTICATE => handleSaslAuthenticateRequest(request) + case ApiKeys.ALLOCATE_PRODUCER_IDS => handleAllocateProducerIdsRequest(request) + case ApiKeys.CREATE_PARTITIONS => handleCreatePartitions(request) + case ApiKeys.DESCRIBE_ACLS => aclApis.handleDescribeAcls(request) + case ApiKeys.CREATE_ACLS => aclApis.handleCreateAcls(request) + case ApiKeys.DELETE_ACLS => aclApis.handleDeleteAcls(request) + case ApiKeys.ELECT_LEADERS => handleElectLeaders(request) + case _ => throw new ApiException(s"Unsupported ApiKey ${request.context.header.apiKey}") + } + } catch { + case e: FatalExitError => throw e + case e: ExecutionException => requestHelper.handleError(request, e.getCause) + case e: Throwable => requestHelper.handleError(request, e) + } + } + + def handleEnvelopeRequest(request: RequestChannel.Request, requestLocal: RequestLocal): Unit = { + if (!authHelper.authorize(request.context, CLUSTER_ACTION, CLUSTER, CLUSTER_NAME)) { + requestHelper.sendErrorResponseMaybeThrottle(request, new ClusterAuthorizationException( + s"Principal ${request.context.principal} does not have required CLUSTER_ACTION for envelope")) + } else { + EnvelopeUtils.handleEnvelopeRequest(request, requestChannel.metrics, handle(_, requestLocal)) + } + } + + def handleSaslHandshakeRequest(request: RequestChannel.Request): Unit = { + val responseData = new SaslHandshakeResponseData().setErrorCode(ILLEGAL_SASL_STATE.code) + requestHelper.sendResponseMaybeThrottle(request, _ => new SaslHandshakeResponse(responseData)) + } + + def handleSaslAuthenticateRequest(request: RequestChannel.Request): Unit = { + val responseData = new SaslAuthenticateResponseData() + .setErrorCode(ILLEGAL_SASL_STATE.code) + .setErrorMessage("SaslAuthenticate request received after successful authentication") + requestHelper.sendResponseMaybeThrottle(request, _ => new SaslAuthenticateResponse(responseData)) + } + + def handleFetch(request: RequestChannel.Request): Unit = { + authHelper.authorizeClusterOperation(request, CLUSTER_ACTION) + handleRaftRequest(request, response => new FetchResponse(response.asInstanceOf[FetchResponseData])) + } + + def handleFetchSnapshot(request: RequestChannel.Request): Unit = { + authHelper.authorizeClusterOperation(request, CLUSTER_ACTION) + handleRaftRequest(request, response => new FetchSnapshotResponse(response.asInstanceOf[FetchSnapshotResponseData])) + } + + def handleDeleteTopics(request: RequestChannel.Request): Unit = { + val deleteTopicsRequest = request.body[DeleteTopicsRequest] + val 
future = deleteTopics(deleteTopicsRequest.data, + request.context.apiVersion, + authHelper.authorize(request.context, DELETE, CLUSTER, CLUSTER_NAME), + names => authHelper.filterByAuthorized(request.context, DESCRIBE, TOPIC, names)(n => n), + names => authHelper.filterByAuthorized(request.context, DELETE, TOPIC, names)(n => n)) + future.whenComplete { (results, exception) => + requestHelper.sendResponseMaybeThrottle(request, throttleTimeMs => { + if (exception != null) { + deleteTopicsRequest.getErrorResponse(throttleTimeMs, exception) + } else { + val responseData = new DeleteTopicsResponseData(). + setResponses(new DeletableTopicResultCollection(results.iterator)). + setThrottleTimeMs(throttleTimeMs) + new DeleteTopicsResponse(responseData) + } + }) + } + } + + def deleteTopics(request: DeleteTopicsRequestData, + apiVersion: Int, + hasClusterAuth: Boolean, + getDescribableTopics: Iterable[String] => Set[String], + getDeletableTopics: Iterable[String] => Set[String]) + : CompletableFuture[util.List[DeletableTopicResult]] = { + // Check if topic deletion is enabled at all. + if (!config.deleteTopicEnable) { + if (apiVersion < 3) { + throw new InvalidRequestException("Topic deletion is disabled.") + } else { + throw new TopicDeletionDisabledException() + } + } + val deadlineNs = time.nanoseconds() + NANOSECONDS.convert(request.timeoutMs, MILLISECONDS); + // The first step is to load up the names and IDs that have been provided by the + // request. This is a bit messy because we support multiple ways of referring to + // topics (both by name and by id) and because we need to check for duplicates or + // other invalid inputs. + val responses = new util.ArrayList[DeletableTopicResult] + def appendResponse(name: String, id: Uuid, error: ApiError): Unit = { + responses.add(new DeletableTopicResult(). + setName(name). + setTopicId(id). + setErrorCode(error.error.code). + setErrorMessage(error.message)) + } + val providedNames = new util.HashSet[String] + val duplicateProvidedNames = new util.HashSet[String] + val providedIds = new util.HashSet[Uuid] + val duplicateProvidedIds = new util.HashSet[Uuid] + def addProvidedName(name: String): Unit = { + if (duplicateProvidedNames.contains(name) || !providedNames.add(name)) { + duplicateProvidedNames.add(name) + providedNames.remove(name) + } + } + request.topicNames.forEach(addProvidedName) + request.topics.forEach { + topic => if (topic.name == null) { + if (topic.topicId.equals(ZERO_UUID)) { + appendResponse(null, ZERO_UUID, new ApiError(INVALID_REQUEST, + "Neither topic name nor id were specified.")) + } else if (duplicateProvidedIds.contains(topic.topicId) || !providedIds.add(topic.topicId)) { + duplicateProvidedIds.add(topic.topicId) + providedIds.remove(topic.topicId) + } + } else { + if (topic.topicId.equals(ZERO_UUID)) { + addProvidedName(topic.name) + } else { + appendResponse(topic.name, topic.topicId, new ApiError(INVALID_REQUEST, + "You may not specify both topic name and topic id.")) + } + } + } + // Create error responses for duplicates. + duplicateProvidedNames.forEach(name => appendResponse(name, ZERO_UUID, + new ApiError(INVALID_REQUEST, "Duplicate topic name."))) + duplicateProvidedIds.forEach(id => appendResponse(null, id, + new ApiError(INVALID_REQUEST, "Duplicate topic id."))) + // At this point we have all the valid names and IDs that have been provided. + // However, the Authorizer needs topic names as inputs, not topic IDs. So + // we need to resolve all IDs to names. 
+    val toAuthenticate = new util.HashSet[String]
+    toAuthenticate.addAll(providedNames)
+    val idToName = new util.HashMap[Uuid, String]
+    controller.findTopicNames(deadlineNs, providedIds).thenCompose { topicNames =>
+      topicNames.forEach { (id, nameOrError) =>
+        if (nameOrError.isError) {
+          appendResponse(null, id, nameOrError.error())
+        } else {
+          toAuthenticate.add(nameOrError.result())
+          idToName.put(id, nameOrError.result())
+        }
+      }
+      // Get the list of deletable topics (those we can delete) and the list of describeable
+      // topics.
+      val topicsToAuthenticate = toAuthenticate.asScala
+      val (describeable, deletable) = if (hasClusterAuth) {
+        (topicsToAuthenticate.toSet, topicsToAuthenticate.toSet)
+      } else {
+        (getDescribableTopics(topicsToAuthenticate), getDeletableTopics(topicsToAuthenticate))
+      }
+      // For each topic that was provided by ID, check if authorization failed.
+      // If so, remove it from the idToName map and create an error response for it.
+      val iterator = idToName.entrySet().iterator()
+      while (iterator.hasNext) {
+        val entry = iterator.next()
+        val id = entry.getKey
+        val name = entry.getValue
+        if (!deletable.contains(name)) {
+          if (describeable.contains(name)) {
+            appendResponse(name, id, new ApiError(TOPIC_AUTHORIZATION_FAILED))
+          } else {
+            appendResponse(null, id, new ApiError(TOPIC_AUTHORIZATION_FAILED))
+          }
+          iterator.remove()
+        }
+      }
+      // For each topic that was provided by name, check if authorization failed.
+      // If so, create an error response for it. Otherwise, add it to the idToName map.
+      controller.findTopicIds(deadlineNs, providedNames).thenCompose { topicIds =>
+        topicIds.forEach { (name, idOrError) =>
+          if (!describeable.contains(name)) {
+            appendResponse(name, ZERO_UUID, new ApiError(TOPIC_AUTHORIZATION_FAILED))
+          } else if (idOrError.isError) {
+            appendResponse(name, ZERO_UUID, idOrError.error)
+          } else if (deletable.contains(name)) {
+            val id = idOrError.result()
+            if (duplicateProvidedIds.contains(id) || idToName.put(id, name) != null) {
+              // This is kind of a weird case: what if we supply topic ID X and also a name
+              // that maps to ID X? In that case, _if authorization succeeds_, we end up
+              // here. If authorization doesn't succeed, we refrain from commenting on the
+              // situation since it would reveal topic ID mappings.
+              duplicateProvidedIds.add(id)
+              idToName.remove(id)
+              appendResponse(name, id, new ApiError(INVALID_REQUEST,
+                "The provided topic name maps to an ID that was already supplied."))
+            }
+          } else {
+            appendResponse(name, ZERO_UUID, new ApiError(TOPIC_AUTHORIZATION_FAILED))
+          }
+        }
+        // Finally, the idToName map contains all the topics that we are authorized to delete.
+        // Perform the deletion and create responses for each one.
+        controller.deleteTopics(deadlineNs, idToName.keySet).thenApply { idToError =>
+          idToError.forEach { (id, error) =>
+            appendResponse(idToName.get(id), id, error)
+          }
+          // Shuffle the responses so that users cannot use patterns in their positions to
+          // distinguish between absent topics and topics we are not permitted to see.
+ Collections.shuffle(responses) + responses + } + } + } + } + + def handleCreateTopics(request: RequestChannel.Request): Unit = { + val createTopicsRequest = request.body[CreateTopicsRequest] + val future = createTopics(createTopicsRequest.data(), + authHelper.authorize(request.context, CREATE, CLUSTER, CLUSTER_NAME), + names => authHelper.filterByAuthorized(request.context, CREATE, TOPIC, names)(identity)) + future.whenComplete { (result, exception) => + requestHelper.sendResponseMaybeThrottle(request, throttleTimeMs => { + if (exception != null) { + createTopicsRequest.getErrorResponse(throttleTimeMs, exception) + } else { + result.setThrottleTimeMs(throttleTimeMs) + new CreateTopicsResponse(result) + } + }) + } + } + + def createTopics(request: CreateTopicsRequestData, + hasClusterAuth: Boolean, + getCreatableTopics: Iterable[String] => Set[String]) + : CompletableFuture[CreateTopicsResponseData] = { + val topicNames = new util.HashSet[String]() + val duplicateTopicNames = new util.HashSet[String]() + request.topics().forEach { topicData => + if (!duplicateTopicNames.contains(topicData.name())) { + if (!topicNames.add(topicData.name())) { + topicNames.remove(topicData.name()) + duplicateTopicNames.add(topicData.name()) + } + } + } + val authorizedTopicNames = if (hasClusterAuth) { + topicNames.asScala + } else { + getCreatableTopics.apply(topicNames.asScala) + } + val effectiveRequest = request.duplicate() + val iterator = effectiveRequest.topics().iterator() + while (iterator.hasNext) { + val creatableTopic = iterator.next() + if (duplicateTopicNames.contains(creatableTopic.name()) || + !authorizedTopicNames.contains(creatableTopic.name())) { + iterator.remove() + } + } + controller.createTopics(effectiveRequest).thenApply { response => + duplicateTopicNames.forEach { name => + response.topics().add(new CreatableTopicResult(). + setName(name). + setErrorCode(INVALID_REQUEST.code). + setErrorMessage("Duplicate topic name.")) + } + topicNames.forEach { name => + if (!authorizedTopicNames.contains(name)) { + response.topics().add(new CreatableTopicResult(). + setName(name). + setErrorCode(TOPIC_AUTHORIZATION_FAILED.code)) + } + } + response + } + } + + def handleApiVersionsRequest(request: RequestChannel.Request): Unit = { + // Note that broker returns its full list of supported ApiKeys and versions regardless of current + // authentication state (e.g., before SASL authentication on an SASL listener, do note that no + // Kafka protocol requests may take place on an SSL listener before the SSL handshake is finished). + // If this is considered to leak information about the broker version a workaround is to use SSL + // with client authentication which is performed at an earlier stage of the connection where the + // ApiVersionRequest is not available. 
+ def createResponseCallback(requestThrottleMs: Int): ApiVersionsResponse = { + val apiVersionRequest = request.body[ApiVersionsRequest] + if (apiVersionRequest.hasUnsupportedRequestVersion) { + apiVersionRequest.getErrorResponse(requestThrottleMs, UNSUPPORTED_VERSION.exception) + } else if (!apiVersionRequest.isValid) { + apiVersionRequest.getErrorResponse(requestThrottleMs, INVALID_REQUEST.exception) + } else { + apiVersionManager.apiVersionResponse(requestThrottleMs) + } + } + requestHelper.sendResponseMaybeThrottle(request, createResponseCallback) + } + + def authorizeAlterResource(requestContext: RequestContext, + resource: ConfigResource): ApiError = { + resource.`type` match { + case ConfigResource.Type.BROKER => + if (authHelper.authorize(requestContext, ALTER_CONFIGS, CLUSTER, CLUSTER_NAME)) { + new ApiError(NONE) + } else { + new ApiError(CLUSTER_AUTHORIZATION_FAILED) + } + case ConfigResource.Type.TOPIC => + if (authHelper.authorize(requestContext, ALTER_CONFIGS, TOPIC, resource.name)) { + new ApiError(NONE) + } else { + new ApiError(TOPIC_AUTHORIZATION_FAILED) + } + case rt => new ApiError(INVALID_REQUEST, s"Unexpected resource type $rt.") + } + } + + def handleLegacyAlterConfigs(request: RequestChannel.Request): Unit = { + val response = new AlterConfigsResponseData() + val alterConfigsRequest = request.body[AlterConfigsRequest] + val duplicateResources = new util.HashSet[ConfigResource] + val configChanges = new util.HashMap[ConfigResource, util.Map[String, String]]() + alterConfigsRequest.data.resources.forEach { resource => + val configResource = new ConfigResource( + ConfigResource.Type.forId(resource.resourceType), resource.resourceName()) + if (configResource.`type`().equals(ConfigResource.Type.UNKNOWN)) { + response.responses().add(new OldAlterConfigsResourceResponse(). + setErrorCode(UNSUPPORTED_VERSION.code()). + setErrorMessage("Unknown resource type " + resource.resourceType() + "."). + setResourceName(resource.resourceName()). + setResourceType(resource.resourceType())) + } else if (!duplicateResources.contains(configResource)) { + val configs = new util.HashMap[String, String]() + resource.configs().forEach(config => configs.put(config.name(), config.value())) + if (configChanges.put(configResource, configs) != null) { + duplicateResources.add(configResource) + configChanges.remove(configResource) + response.responses().add(new OldAlterConfigsResourceResponse(). + setErrorCode(INVALID_REQUEST.code()). + setErrorMessage("Duplicate resource."). + setResourceName(resource.resourceName()). + setResourceType(resource.resourceType())) + } + } + } + val iterator = configChanges.keySet().iterator() + while (iterator.hasNext) { + val resource = iterator.next() + val apiError = authorizeAlterResource(request.context, resource) + if (apiError.isFailure) { + response.responses().add(new OldAlterConfigsResourceResponse(). + setErrorCode(apiError.error().code()). + setErrorMessage(apiError.message()). + setResourceName(resource.name()). + setResourceType(resource.`type`().id())) + iterator.remove() + } + } + controller.legacyAlterConfigs(configChanges, alterConfigsRequest.data.validateOnly) + .whenComplete { (controllerResults, exception) => + if (exception != null) { + requestHelper.handleError(request, exception) + } else { + controllerResults.entrySet().forEach(entry => response.responses().add( + new OldAlterConfigsResourceResponse(). + setErrorCode(entry.getValue.error().code()). + setErrorMessage(entry.getValue.message()). + setResourceName(entry.getKey.name()). 
+ setResourceType(entry.getKey.`type`().id()))) + requestHelper.sendResponseMaybeThrottle(request, throttleMs => + new AlterConfigsResponse(response.setThrottleTimeMs(throttleMs))) + } + } + } + + def handleVote(request: RequestChannel.Request): Unit = { + authHelper.authorizeClusterOperation(request, CLUSTER_ACTION) + handleRaftRequest(request, response => new VoteResponse(response.asInstanceOf[VoteResponseData])) + } + + def handleBeginQuorumEpoch(request: RequestChannel.Request): Unit = { + authHelper.authorizeClusterOperation(request, CLUSTER_ACTION) + handleRaftRequest(request, response => new BeginQuorumEpochResponse(response.asInstanceOf[BeginQuorumEpochResponseData])) + } + + def handleEndQuorumEpoch(request: RequestChannel.Request): Unit = { + authHelper.authorizeClusterOperation(request, CLUSTER_ACTION) + handleRaftRequest(request, response => new EndQuorumEpochResponse(response.asInstanceOf[EndQuorumEpochResponseData])) + } + + def handleDescribeQuorum(request: RequestChannel.Request): Unit = { + authHelper.authorizeClusterOperation(request, DESCRIBE) + handleRaftRequest(request, response => new DescribeQuorumResponse(response.asInstanceOf[DescribeQuorumResponseData])) + } + + def handleElectLeaders(request: RequestChannel.Request): Unit = { + authHelper.authorizeClusterOperation(request, ALTER) + + val electLeadersRequest = request.body[ElectLeadersRequest] + val future = controller.electLeaders(electLeadersRequest.data) + future.whenComplete { (responseData, exception) => + if (exception != null) { + requestHelper.sendResponseMaybeThrottle(request, throttleMs => { + electLeadersRequest.getErrorResponse(throttleMs, exception) + }) + } else { + requestHelper.sendResponseMaybeThrottle(request, throttleMs => { + new ElectLeadersResponse(responseData.setThrottleTimeMs(throttleMs)) + }) + } + } + } + + def handleAlterIsrRequest(request: RequestChannel.Request): Unit = { + val alterIsrRequest = request.body[AlterIsrRequest] + authHelper.authorizeClusterOperation(request, CLUSTER_ACTION) + val future = controller.alterIsr(alterIsrRequest.data) + future.whenComplete { (result, exception) => + val response = if (exception != null) { + alterIsrRequest.getErrorResponse(exception) + } else { + new AlterIsrResponse(result) + } + requestHelper.sendResponseExemptThrottle(request, response) + } + } + + def handleBrokerHeartBeatRequest(request: RequestChannel.Request): Unit = { + val heartbeatRequest = request.body[BrokerHeartbeatRequest] + authHelper.authorizeClusterOperation(request, CLUSTER_ACTION) + + controller.processBrokerHeartbeat(heartbeatRequest.data).handle[Unit] { (reply, e) => + def createResponseCallback(requestThrottleMs: Int, + reply: BrokerHeartbeatReply, + e: Throwable): BrokerHeartbeatResponse = { + if (e != null) { + new BrokerHeartbeatResponse(new BrokerHeartbeatResponseData(). + setThrottleTimeMs(requestThrottleMs). + setErrorCode(Errors.forException(e).code)) + } else { + new BrokerHeartbeatResponse(new BrokerHeartbeatResponseData(). + setThrottleTimeMs(requestThrottleMs). + setErrorCode(NONE.code). + setIsCaughtUp(reply.isCaughtUp). + setIsFenced(reply.isFenced). 
+ setShouldShutDown(reply.shouldShutDown)) + } + } + requestHelper.sendResponseMaybeThrottle(request, + requestThrottleMs => createResponseCallback(requestThrottleMs, reply, e)) + } + } + + def handleUnregisterBroker(request: RequestChannel.Request): Unit = { + val decommissionRequest = request.body[UnregisterBrokerRequest] + authHelper.authorizeClusterOperation(request, ALTER) + + controller.unregisterBroker(decommissionRequest.data().brokerId()).handle[Unit] { (_, e) => + def createResponseCallback(requestThrottleMs: Int, + e: Throwable): UnregisterBrokerResponse = { + if (e != null) { + new UnregisterBrokerResponse(new UnregisterBrokerResponseData(). + setThrottleTimeMs(requestThrottleMs). + setErrorCode(Errors.forException(e).code)) + } else { + new UnregisterBrokerResponse(new UnregisterBrokerResponseData(). + setThrottleTimeMs(requestThrottleMs)) + } + } + requestHelper.sendResponseMaybeThrottle(request, + requestThrottleMs => createResponseCallback(requestThrottleMs, e)) + } + } + + def handleBrokerRegistration(request: RequestChannel.Request): Unit = { + val registrationRequest = request.body[BrokerRegistrationRequest] + authHelper.authorizeClusterOperation(request, CLUSTER_ACTION) + + controller.registerBroker(registrationRequest.data).handle[Unit] { (reply, e) => + def createResponseCallback(requestThrottleMs: Int, + reply: BrokerRegistrationReply, + e: Throwable): BrokerRegistrationResponse = { + if (e != null) { + new BrokerRegistrationResponse(new BrokerRegistrationResponseData(). + setThrottleTimeMs(requestThrottleMs). + setErrorCode(Errors.forException(e).code)) + } else { + new BrokerRegistrationResponse(new BrokerRegistrationResponseData(). + setThrottleTimeMs(requestThrottleMs). + setErrorCode(NONE.code). + setBrokerEpoch(reply.epoch)) + } + } + requestHelper.sendResponseMaybeThrottle(request, + requestThrottleMs => createResponseCallback(requestThrottleMs, reply, e)) + } + } + + private def handleRaftRequest(request: RequestChannel.Request, + buildResponse: ApiMessage => AbstractResponse): Unit = { + val requestBody = request.body[AbstractRequest] + val future = raftManager.handleRequest(request.header, requestBody.data, time.milliseconds()) + + future.whenComplete { (responseData, exception) => + val response = if (exception != null) { + requestBody.getErrorResponse(exception) + } else { + buildResponse(responseData) + } + requestHelper.sendResponseExemptThrottle(request, response) + } + } + + def handleAlterClientQuotas(request: RequestChannel.Request): Unit = { + val quotaRequest = request.body[AlterClientQuotasRequest] + authHelper.authorizeClusterOperation(request, ALTER_CONFIGS) + controller.alterClientQuotas(quotaRequest.entries, quotaRequest.validateOnly) + .whenComplete { (results, exception) => + if (exception != null) { + requestHelper.handleError(request, exception) + } else { + requestHelper.sendResponseMaybeThrottle(request, requestThrottleMs => + AlterClientQuotasResponse.fromQuotaEntities(results, requestThrottleMs)) + } + } + } + + def handleIncrementalAlterConfigs(request: RequestChannel.Request): Unit = { + val response = new IncrementalAlterConfigsResponseData() + val alterConfigsRequest = request.body[IncrementalAlterConfigsRequest] + val duplicateResources = new util.HashSet[ConfigResource] + val configChanges = new util.HashMap[ConfigResource, + util.Map[String, Entry[AlterConfigOp.OpType, String]]]() + alterConfigsRequest.data.resources.forEach { resource => + val configResource = new ConfigResource( + 
ConfigResource.Type.forId(resource.resourceType), resource.resourceName()) + if (configResource.`type`().equals(ConfigResource.Type.UNKNOWN)) { + response.responses().add(new AlterConfigsResourceResponse(). + setErrorCode(UNSUPPORTED_VERSION.code()). + setErrorMessage("Unknown resource type " + resource.resourceType() + "."). + setResourceName(resource.resourceName()). + setResourceType(resource.resourceType())) + } else if (!duplicateResources.contains(configResource)) { + val altersByName = new util.HashMap[String, Entry[AlterConfigOp.OpType, String]]() + resource.configs.forEach { config => + altersByName.put(config.name, new util.AbstractMap.SimpleEntry[AlterConfigOp.OpType, String]( + AlterConfigOp.OpType.forId(config.configOperation), config.value)) + } + if (configChanges.put(configResource, altersByName) != null) { + duplicateResources.add(configResource) + configChanges.remove(configResource) + response.responses().add(new AlterConfigsResourceResponse(). + setErrorCode(INVALID_REQUEST.code()). + setErrorMessage("Duplicate resource."). + setResourceName(resource.resourceName()). + setResourceType(resource.resourceType())) + } + } + } + val iterator = configChanges.keySet().iterator() + while (iterator.hasNext) { + val resource = iterator.next() + val apiError = authorizeAlterResource(request.context, resource) + if (apiError.isFailure) { + response.responses().add(new AlterConfigsResourceResponse(). + setErrorCode(apiError.error().code()). + setErrorMessage(apiError.message()). + setResourceName(resource.name()). + setResourceType(resource.`type`().id())) + iterator.remove() + } + } + controller.incrementalAlterConfigs(configChanges, alterConfigsRequest.data.validateOnly) + .whenComplete { (controllerResults, exception) => + if (exception != null) { + requestHelper.handleError(request, exception) + } else { + controllerResults.entrySet().forEach(entry => response.responses().add( + new AlterConfigsResourceResponse(). + setErrorCode(entry.getValue.error().code()). + setErrorMessage(entry.getValue.message()). + setResourceName(entry.getKey.name()). + setResourceType(entry.getKey.`type`().id()))) + requestHelper.sendResponseMaybeThrottle(request, throttleMs => + new IncrementalAlterConfigsResponse(response.setThrottleTimeMs(throttleMs))) + } + } + } + + def handleCreatePartitions(request: RequestChannel.Request): Unit = { + val future = createPartitions(request.body[CreatePartitionsRequest].data, + authHelper.authorize(request.context, CREATE, CLUSTER, CLUSTER_NAME), + names => authHelper.filterByAuthorized(request.context, CREATE, TOPIC, names)(n => n)) + future.whenComplete { (responses, exception) => + if (exception != null) { + requestHelper.handleError(request, exception) + } else { + requestHelper.sendResponseMaybeThrottle(request, requestThrottleMs => { + val responseData = new CreatePartitionsResponseData(). + setResults(responses). 
+ setThrottleTimeMs(requestThrottleMs) + new CreatePartitionsResponse(responseData) + }) + } + } + } + + def createPartitions(request: CreatePartitionsRequestData, + hasClusterAuth: Boolean, + getCreatableTopics: Iterable[String] => Set[String]) + : CompletableFuture[util.List[CreatePartitionsTopicResult]] = { + val deadlineNs = time.nanoseconds() + NANOSECONDS.convert(request.timeoutMs, MILLISECONDS); + val responses = new util.ArrayList[CreatePartitionsTopicResult]() + val duplicateTopicNames = new util.HashSet[String]() + val topicNames = new util.HashSet[String]() + request.topics().forEach { + topic => + if (!topicNames.add(topic.name())) { + duplicateTopicNames.add(topic.name()) + } + } + duplicateTopicNames.forEach { topicName => + responses.add(new CreatePartitionsTopicResult(). + setName(topicName). + setErrorCode(INVALID_REQUEST.code). + setErrorMessage("Duplicate topic name.")) + topicNames.remove(topicName) + } + val authorizedTopicNames = { + if (hasClusterAuth) { + topicNames.asScala + } else { + getCreatableTopics(topicNames.asScala) + } + } + val topics = new util.ArrayList[CreatePartitionsTopic] + topicNames.forEach { topicName => + if (authorizedTopicNames.contains(topicName)) { + topics.add(request.topics().find(topicName)) + } else { + responses.add(new CreatePartitionsTopicResult(). + setName(topicName). + setErrorCode(TOPIC_AUTHORIZATION_FAILED.code)) + } + } + controller.createPartitions(deadlineNs, topics).thenApply { results => + results.forEach(response => responses.add(response)) + responses + } + } + + def handleAlterPartitionReassignments(request: RequestChannel.Request): Unit = { + val alterRequest = request.body[AlterPartitionReassignmentsRequest] + authHelper.authorizeClusterOperation(request, ALTER) + val response = controller.alterPartitionReassignments(alterRequest.data()).get() + requestHelper.sendResponseMaybeThrottle(request, requestThrottleMs => + new AlterPartitionReassignmentsResponse(response.setThrottleTimeMs(requestThrottleMs))) + } + + def handleListPartitionReassignments(request: RequestChannel.Request): Unit = { + val listRequest = request.body[ListPartitionReassignmentsRequest] + authHelper.authorizeClusterOperation(request, DESCRIBE) + val response = controller.listPartitionReassignments(listRequest.data()).get() + requestHelper.sendResponseMaybeThrottle(request, requestThrottleMs => + new ListPartitionReassignmentsResponse(response.setThrottleTimeMs(requestThrottleMs))) + } + + def handleAllocateProducerIdsRequest(request: RequestChannel.Request): Unit = { + val allocatedProducerIdsRequest = request.body[AllocateProducerIdsRequest] + authHelper.authorizeClusterOperation(request, CLUSTER_ACTION) + controller.allocateProducerIds(allocatedProducerIdsRequest.data) + .whenComplete((results, exception) => { + if (exception != null) { + requestHelper.handleError(request, exception) + } else { + requestHelper.sendResponseMaybeThrottle(request, requestThrottleMs => { + results.setThrottleTimeMs(requestThrottleMs) + new AllocateProducerIdsResponse(results) + }) + } + }) + } +} diff --git a/core/src/main/scala/kafka/server/ControllerConfigurationValidator.scala b/core/src/main/scala/kafka/server/ControllerConfigurationValidator.scala new file mode 100644 index 0000000..dfb78b2 --- /dev/null +++ b/core/src/main/scala/kafka/server/ControllerConfigurationValidator.scala @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.server + +import java.util +import java.util.Properties + +import kafka.log.LogConfig +import org.apache.kafka.common.config.ConfigResource +import org.apache.kafka.common.config.ConfigResource.Type.{BROKER, TOPIC} +import org.apache.kafka.controller.ConfigurationValidator +import org.apache.kafka.common.errors.InvalidRequestException + +import scala.collection.mutable + +class ControllerConfigurationValidator extends ConfigurationValidator { + override def validate(resource: ConfigResource, config: util.Map[String, String]): Unit = { + resource.`type`() match { + case TOPIC => + val properties = new Properties() + val nullTopicConfigs = new mutable.ArrayBuffer[String]() + config.entrySet().forEach(e => { + if (e.getValue() == null) { + nullTopicConfigs += e.getKey() + } else { + properties.setProperty(e.getKey(), e.getValue()) + } + }) + if (nullTopicConfigs.nonEmpty) { + throw new InvalidRequestException("Null value not supported for topic configs : " + + nullTopicConfigs.mkString(",")) + } + LogConfig.validate(properties) + case BROKER => + // TODO: add broker configuration validation + case _ => + // Note: we should never handle BROKER_LOGGER resources here, since changes to + // those resources are not persisted in the metadata. + throw new InvalidRequestException(s"Unknown resource type ${resource.`type`}") + } + } +} \ No newline at end of file diff --git a/core/src/main/scala/kafka/server/ControllerMutationQuotaManager.scala b/core/src/main/scala/kafka/server/ControllerMutationQuotaManager.scala new file mode 100644 index 0000000..f011a6b --- /dev/null +++ b/core/src/main/scala/kafka/server/ControllerMutationQuotaManager.scala @@ -0,0 +1,282 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package kafka.server + +import kafka.network.RequestChannel +import kafka.network.RequestChannel.Session +import org.apache.kafka.common.MetricName +import org.apache.kafka.common.errors.ThrottlingQuotaExceededException +import org.apache.kafka.common.metrics.Metrics +import org.apache.kafka.common.metrics.QuotaViolationException +import org.apache.kafka.common.metrics.Sensor +import org.apache.kafka.common.metrics.stats.Rate +import org.apache.kafka.common.metrics.stats.TokenBucket +import org.apache.kafka.common.protocol.Errors +import org.apache.kafka.common.utils.Time +import org.apache.kafka.server.quota.ClientQuotaCallback + +import scala.jdk.CollectionConverters._ + +/** + * The ControllerMutationQuota trait defines a quota for a given user/clientId pair. Such + * quota is not meant to be cached forever but rather during the lifetime of processing + * a request. + */ +trait ControllerMutationQuota { + def isExceeded: Boolean + def record(permits: Double): Unit + def throttleTime: Int +} + +/** + * Default quota used when quota is disabled. + */ +object UnboundedControllerMutationQuota extends ControllerMutationQuota { + override def isExceeded: Boolean = false + override def record(permits: Double): Unit = () + override def throttleTime: Int = 0 +} + +/** + * The AbstractControllerMutationQuota is the base class of StrictControllerMutationQuota and + * PermissiveControllerMutationQuota. + * + * @param time @Time object to use + */ +abstract class AbstractControllerMutationQuota(private val time: Time) extends ControllerMutationQuota { + protected var lastThrottleTimeMs = 0L + protected var lastRecordedTimeMs = 0L + + protected def updateThrottleTime(e: QuotaViolationException, timeMs: Long): Unit = { + lastThrottleTimeMs = ControllerMutationQuotaManager.throttleTimeMs(e, timeMs) + lastRecordedTimeMs = timeMs + } + + override def throttleTime: Int = { + // If a throttle time has been recorded, we adjust it by deducting the time elapsed + // between the recording and now. We do this because `throttleTime` may be called + // long after having recorded it, especially when a request waits in the purgatory. + val deltaTimeMs = time.milliseconds - lastRecordedTimeMs + Math.max(0, lastThrottleTimeMs - deltaTimeMs).toInt + } +} + +/** + * The StrictControllerMutationQuota defines a strict quota for a given user/clientId pair. The + * quota is strict meaning that 1) it does not accept any mutations once the quota is exhausted + * until it gets back to the defined rate; and 2) it does not throttle for any number of mutations + * if quota is not already exhausted. + * + * @param time @Time object to use + * @param quotaSensor @Sensor object with a defined quota for a given user/clientId pair + */ +class StrictControllerMutationQuota(private val time: Time, + private val quotaSensor: Sensor) + extends AbstractControllerMutationQuota(time) { + + override def isExceeded: Boolean = lastThrottleTimeMs > 0 + + override def record(permits: Double): Unit = { + val timeMs = time.milliseconds + try { + quotaSensor synchronized { + quotaSensor.checkQuotas(timeMs) + quotaSensor.record(permits, timeMs, false) + } + } catch { + case e: QuotaViolationException => + updateThrottleTime(e, timeMs) + throw new ThrottlingQuotaExceededException(lastThrottleTimeMs.toInt, + Errors.THROTTLING_QUOTA_EXCEEDED.message) + } + } +} + +/** + * The PermissiveControllerMutationQuota defines a permissive quota for a given user/clientId pair. 
+ * The quota is permissive meaning that 1) it does accept any mutations even if the quota is + * exhausted; and 2) it does throttle as soon as the quota is exhausted. + * + * @param time @Time object to use + * @param quotaSensor @Sensor object with a defined quota for a given user/clientId pair + */ +class PermissiveControllerMutationQuota(private val time: Time, + private val quotaSensor: Sensor) + extends AbstractControllerMutationQuota(time) { + + override def isExceeded: Boolean = false + + override def record(permits: Double): Unit = { + val timeMs = time.milliseconds + try { + quotaSensor.record(permits, timeMs, true) + } catch { + case e: QuotaViolationException => + updateThrottleTime(e, timeMs) + } + } +} + +object ControllerMutationQuotaManager { + val QuotaControllerMutationDefault = Int.MaxValue.toDouble + + /** + * This calculates the amount of time needed to bring the TokenBucket within quota + * assuming that no new metrics are recorded. + * + * Basically, if a value < 0 is observed, the time required to bring it to zero is + * -value / refill rate (quota bound) * 1000. + */ + def throttleTimeMs(e: QuotaViolationException, timeMs: Long): Long = { + e.metric().measurable() match { + case _: TokenBucket => + Math.round(-e.value() / e.bound() * 1000) + case _ => throw new IllegalArgumentException( + s"Metric ${e.metric().metricName()} is not a TokenBucket metric, value ${e.metric().measurable()}") + } + } +} + +/** + * The ControllerMutationQuotaManager is a specialized ClientQuotaManager used in the context + * of throttling controller's operations/mutations. + * + * @param config @ClientQuotaManagerConfig quota configs + * @param metrics @Metrics Metrics instance + * @param time @Time object to use + * @param threadNamePrefix The thread prefix to use + * @param quotaCallback @ClientQuotaCallback ClientQuotaCallback to use + */ +class ControllerMutationQuotaManager(private val config: ClientQuotaManagerConfig, + private val metrics: Metrics, + private val time: Time, + private val threadNamePrefix: String, + private val quotaCallback: Option[ClientQuotaCallback]) + extends ClientQuotaManager(config, metrics, QuotaType.ControllerMutation, time, threadNamePrefix, quotaCallback) { + + override protected def clientQuotaMetricName(quotaMetricTags: Map[String, String]): MetricName = { + metrics.metricName("tokens", QuotaType.ControllerMutation.toString, + "Tracking remaining tokens in the token bucket per user/client-id", + quotaMetricTags.asJava) + } + + private def clientRateMetricName(quotaMetricTags: Map[String, String]): MetricName = { + metrics.metricName("mutation-rate", QuotaType.ControllerMutation.toString, + "Tracking mutation-rate per user/client-id", + quotaMetricTags.asJava) + } + + override protected def registerQuotaMetrics(metricTags: Map[String, String])(sensor: Sensor): Unit = { + sensor.add( + clientRateMetricName(metricTags), + new Rate + ) + sensor.add( + clientQuotaMetricName(metricTags), + new TokenBucket, + getQuotaMetricConfig(metricTags) + ) + } + + /** + * Records that a user/clientId accumulated or would like to accumulate the provided amount at the + * the specified time, returns throttle time in milliseconds. The quota is strict meaning that it + * does not accept any mutations once the quota is exhausted until it gets back to the defined rate. 
+   *
+   * @param session The session from which the user is extracted
+   * @param clientId The client id
+   * @param value The value to accumulate
+   * @param timeMs The time at which to accumulate the value
+   * @return The throttle time in milliseconds, defined as the time to wait until the average
+   *         rate gets back to the defined quota
+   */
+  override def recordAndGetThrottleTimeMs(session: Session, clientId: String, value: Double, timeMs: Long): Int = {
+    val clientSensors = getOrCreateQuotaSensors(session, clientId)
+    val quotaSensor = clientSensors.quotaSensor
+    try {
+      quotaSensor synchronized {
+        quotaSensor.checkQuotas(timeMs)
+        quotaSensor.record(value, timeMs, false)
+      }
+      0
+    } catch {
+      case e: QuotaViolationException =>
+        val throttleTimeMs = ControllerMutationQuotaManager.throttleTimeMs(e, timeMs).toInt
+        debug(s"Quota violated for sensor (${quotaSensor.name}). Delay time: ($throttleTimeMs)")
+        throttleTimeMs
+    }
+  }
+
+  /**
+   * Returns a StrictControllerMutationQuota for the given user/clientId pair or
+   * an UnboundedControllerMutationQuota if the quota is disabled.
+   *
+   * @param session The session from which the user is extracted
+   * @param clientId The client id
+   * @return ControllerMutationQuota
+   */
+  def newStrictQuotaFor(session: Session, clientId: String): ControllerMutationQuota = {
+    if (quotasEnabled) {
+      val clientSensors = getOrCreateQuotaSensors(session, clientId)
+      new StrictControllerMutationQuota(time, clientSensors.quotaSensor)
+    } else {
+      UnboundedControllerMutationQuota
+    }
+  }
+
+  def newStrictQuotaFor(request: RequestChannel.Request): ControllerMutationQuota =
+    newStrictQuotaFor(request.session, request.header.clientId)
+
+  /**
+   * Returns a PermissiveControllerMutationQuota for the given user/clientId pair or
+   * an UnboundedControllerMutationQuota if the quota is disabled.
+   *
+   * @param session The session from which the user is extracted
+   * @param clientId The client id
+   * @return ControllerMutationQuota
+   */
+  def newPermissiveQuotaFor(session: Session, clientId: String): ControllerMutationQuota = {
+    if (quotasEnabled) {
+      val clientSensors = getOrCreateQuotaSensors(session, clientId)
+      new PermissiveControllerMutationQuota(time, clientSensors.quotaSensor)
+    } else {
+      UnboundedControllerMutationQuota
+    }
+  }
+
+  def newPermissiveQuotaFor(request: RequestChannel.Request): ControllerMutationQuota =
+    newPermissiveQuotaFor(request.session, request.header.clientId)
+
+  /**
+   * Returns a ControllerMutationQuota based on `strictSinceVersion`. It returns a strict
+   * quota if the request version is equal to or greater than `strictSinceVersion`, a permissive
+   * quota if the version is below it, and an unbounded quota if the quota is disabled.
+   *
+   * When the quota is strictly enforced, any operation above the quota is rejected
+   * with a THROTTLING_QUOTA_EXCEEDED error.
+ * + * @param request The request to extract the user and the clientId from + * @param strictSinceVersion The version since quota is strict + * @return + */ + def newQuotaFor(request: RequestChannel.Request, strictSinceVersion: Short): ControllerMutationQuota = { + if (request.header.apiVersion() >= strictSinceVersion) + newStrictQuotaFor(request) + else + newPermissiveQuotaFor(request) + } +} diff --git a/core/src/main/scala/kafka/server/ControllerServer.scala b/core/src/main/scala/kafka/server/ControllerServer.scala new file mode 100644 index 0000000..ede71d4 --- /dev/null +++ b/core/src/main/scala/kafka/server/ControllerServer.scala @@ -0,0 +1,250 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.server + +import java.util +import java.util.concurrent.locks.ReentrantLock +import java.util.concurrent.{CompletableFuture, TimeUnit} + +import kafka.cluster.Broker.ServerInfo +import kafka.log.LogConfig +import kafka.metrics.{KafkaMetricsGroup, KafkaYammerMetrics, LinuxIoMetricsCollector} +import kafka.network.SocketServer +import kafka.raft.RaftManager +import kafka.security.CredentialProvider +import kafka.server.KafkaConfig.{AlterConfigPolicyClassNameProp, CreateTopicPolicyClassNameProp} +import kafka.server.QuotaFactory.QuotaManagers +import kafka.utils.{CoreUtils, Logging} +import org.apache.kafka.common.config.ConfigResource +import org.apache.kafka.common.message.ApiMessageType.ListenerType +import org.apache.kafka.common.metrics.Metrics +import org.apache.kafka.common.security.scram.internals.ScramMechanism +import org.apache.kafka.common.security.token.delegation.internals.DelegationTokenCache +import org.apache.kafka.common.utils.{LogContext, Time} +import org.apache.kafka.common.{ClusterResource, Endpoint} +import org.apache.kafka.controller.{Controller, QuorumController, QuorumControllerMetrics} +import org.apache.kafka.metadata.VersionRange +import org.apache.kafka.raft.RaftConfig +import org.apache.kafka.raft.RaftConfig.AddressSpec +import org.apache.kafka.server.authorizer.Authorizer +import org.apache.kafka.server.common.ApiMessageAndVersion +import org.apache.kafka.common.config.ConfigException +import org.apache.kafka.server.policy.{AlterConfigPolicy, CreateTopicPolicy} + +import scala.jdk.CollectionConverters._ +import scala.compat.java8.OptionConverters._ + +/** + * A Kafka controller that runs in KRaft (Kafka Raft) mode. 
+ */ +class ControllerServer( + val metaProperties: MetaProperties, + val config: KafkaConfig, + val raftManager: RaftManager[ApiMessageAndVersion], + val time: Time, + val metrics: Metrics, + val threadNamePrefix: Option[String], + val controllerQuorumVotersFuture: CompletableFuture[util.Map[Integer, AddressSpec]] +) extends Logging with KafkaMetricsGroup { + import kafka.server.Server._ + + val lock = new ReentrantLock() + val awaitShutdownCond = lock.newCondition() + var status: ProcessStatus = SHUTDOWN + + var linuxIoMetricsCollector: LinuxIoMetricsCollector = null + var authorizer: Option[Authorizer] = null + var tokenCache: DelegationTokenCache = null + var credentialProvider: CredentialProvider = null + var socketServer: SocketServer = null + val socketServerFirstBoundPortFuture = new CompletableFuture[Integer]() + var createTopicPolicy: Option[CreateTopicPolicy] = None + var alterConfigPolicy: Option[AlterConfigPolicy] = None + var controller: Controller = null + val supportedFeatures: Map[String, VersionRange] = Map() + var quotaManagers: QuotaManagers = null + var controllerApis: ControllerApis = null + var controllerApisHandlerPool: KafkaRequestHandlerPool = null + + private def maybeChangeStatus(from: ProcessStatus, to: ProcessStatus): Boolean = { + lock.lock() + try { + if (status != from) return false + status = to + if (to == SHUTDOWN) awaitShutdownCond.signalAll() + } finally { + lock.unlock() + } + true + } + + def clusterId: String = metaProperties.clusterId.toString + + def startup(): Unit = { + if (!maybeChangeStatus(SHUTDOWN, STARTING)) return + try { + info("Starting controller") + + maybeChangeStatus(STARTING, STARTED) + // TODO: initialize the log dir(s) + this.logIdent = new LogContext(s"[ControllerServer id=${config.nodeId}] ").logPrefix() + + newGauge("ClusterId", () => clusterId) + newGauge("yammer-metrics-count", () => KafkaYammerMetrics.defaultRegistry.allMetrics.size) + + linuxIoMetricsCollector = new LinuxIoMetricsCollector("/proc", time, logger.underlying) + if (linuxIoMetricsCollector.usable()) { + newGauge("linux-disk-read-bytes", () => linuxIoMetricsCollector.readBytes()) + newGauge("linux-disk-write-bytes", () => linuxIoMetricsCollector.writeBytes()) + } + + val javaListeners = config.controllerListeners.map(_.toJava).asJava + authorizer = config.authorizer + authorizer.foreach(_.configure(config.originals)) + + val authorizerFutures: Map[Endpoint, CompletableFuture[Void]] = authorizer match { + case Some(authZ) => + // It would be nice to remove some of the broker-specific assumptions from + // AuthorizerServerInfo, such as the assumption that there is an inter-broker + // listener, or that ID is named brokerId. 
+ val controllerAuthorizerInfo = ServerInfo( + new ClusterResource(clusterId), config.nodeId, javaListeners, javaListeners.get(0)) + authZ.start(controllerAuthorizerInfo).asScala.map { case (ep, cs) => + ep -> cs.toCompletableFuture + }.toMap + case None => + javaListeners.asScala.map { + ep => ep -> CompletableFuture.completedFuture[Void](null) + }.toMap + } + + val apiVersionManager = new SimpleApiVersionManager(ListenerType.CONTROLLER) + + tokenCache = new DelegationTokenCache(ScramMechanism.mechanismNames) + credentialProvider = new CredentialProvider(ScramMechanism.mechanismNames, tokenCache) + socketServer = new SocketServer(config, + metrics, + time, + credentialProvider, + apiVersionManager) + socketServer.startup(startProcessingRequests = false, controlPlaneListener = None, config.controllerListeners) + + if (config.controllerListeners.nonEmpty) { + socketServerFirstBoundPortFuture.complete(socketServer.boundPort( + config.controllerListeners.head.listenerName)) + } else { + throw new ConfigException("No controller.listener.names defined for controller"); + } + + val configDefs = Map(ConfigResource.Type.BROKER -> KafkaConfig.configDef, + ConfigResource.Type.TOPIC -> LogConfig.configDefCopy).asJava + val threadNamePrefixAsString = threadNamePrefix.getOrElse("") + + createTopicPolicy = Option(config. + getConfiguredInstance(CreateTopicPolicyClassNameProp, classOf[CreateTopicPolicy])) + alterConfigPolicy = Option(config. + getConfiguredInstance(AlterConfigPolicyClassNameProp, classOf[AlterConfigPolicy])) + + controller = new QuorumController.Builder(config.nodeId). + setTime(time). + setThreadNamePrefix(threadNamePrefixAsString). + setConfigDefs(configDefs). + setRaftClient(raftManager.client). + setDefaultReplicationFactor(config.defaultReplicationFactor.toShort). + setDefaultNumPartitions(config.numPartitions.intValue()). + setSessionTimeoutNs(TimeUnit.NANOSECONDS.convert(config.brokerSessionTimeoutMs.longValue(), + TimeUnit.MILLISECONDS)). + setSnapshotMaxNewRecordBytes(config.metadataSnapshotMaxNewRecordBytes). + setMetrics(new QuorumControllerMetrics(KafkaYammerMetrics.defaultRegistry())). + setCreateTopicPolicy(createTopicPolicy.asJava). + setAlterConfigPolicy(alterConfigPolicy.asJava). + setConfigurationValidator(new ControllerConfigurationValidator()). + build() + + quotaManagers = QuotaFactory.instantiate(config, metrics, time, threadNamePrefix.getOrElse("")) + val controllerNodes = RaftConfig.voterConnectionsToNodes(controllerQuorumVotersFuture.get()).asScala + controllerApis = new ControllerApis(socketServer.dataPlaneRequestChannel, + authorizer, + quotaManagers, + time, + supportedFeatures, + controller, + raftManager, + config, + metaProperties, + controllerNodes.toSeq, + apiVersionManager) + controllerApisHandlerPool = new KafkaRequestHandlerPool(config.nodeId, + socketServer.dataPlaneRequestChannel, + controllerApis, + time, + config.numIoThreads, + s"${SocketServer.DataPlaneMetricPrefix}RequestHandlerAvgIdlePercent", + SocketServer.DataPlaneThreadPrefix) + socketServer.startProcessingRequests(authorizerFutures) + } catch { + case e: Throwable => + maybeChangeStatus(STARTING, STARTED) + fatal("Fatal error during controller startup. 
Prepare to shutdown", e) + shutdown() + throw e + } + } + + def shutdown(): Unit = { + if (!maybeChangeStatus(STARTED, SHUTTING_DOWN)) return + try { + info("shutting down") + if (socketServer != null) + CoreUtils.swallow(socketServer.stopProcessingRequests(), this) + if (controller != null) + controller.beginShutdown() + if (socketServer != null) + CoreUtils.swallow(socketServer.shutdown(), this) + if (controllerApisHandlerPool != null) + CoreUtils.swallow(controllerApisHandlerPool.shutdown(), this) + if (controllerApis != null) + CoreUtils.swallow(controllerApis.close(), this) + if (quotaManagers != null) + CoreUtils.swallow(quotaManagers.shutdown(), this) + if (controller != null) + controller.close() + createTopicPolicy.foreach(policy => CoreUtils.swallow(policy.close(), this)) + alterConfigPolicy.foreach(policy => CoreUtils.swallow(policy.close(), this)) + socketServerFirstBoundPortFuture.completeExceptionally(new RuntimeException("shutting down")) + } catch { + case e: Throwable => + fatal("Fatal error during controller shutdown.", e) + throw e + } finally { + maybeChangeStatus(SHUTTING_DOWN, SHUTDOWN) + } + } + + def awaitShutdown(): Unit = { + lock.lock() + try { + while (true) { + if (status == SHUTDOWN) return + awaitShutdownCond.awaitUninterruptibly() + } + } finally { + lock.unlock() + } + } +} diff --git a/core/src/main/scala/kafka/server/DelayedCreatePartitions.scala b/core/src/main/scala/kafka/server/DelayedCreatePartitions.scala new file mode 100644 index 0000000..7809aaa --- /dev/null +++ b/core/src/main/scala/kafka/server/DelayedCreatePartitions.scala @@ -0,0 +1,102 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.server + +import kafka.api.LeaderAndIsr +import org.apache.kafka.common.protocol.Errors +import org.apache.kafka.common.requests.ApiError + +import scala.collection._ + +/** + * The create metadata maintained by the delayed create topic or create partitions operations. + */ +case class CreatePartitionsMetadata(topic: String, partitions: Set[Int], error: ApiError) + +object CreatePartitionsMetadata { + def apply(topic: String, partitions: Set[Int]): CreatePartitionsMetadata = { + CreatePartitionsMetadata(topic, partitions, ApiError.NONE) + } + + def apply(topic: String, error: Errors): CreatePartitionsMetadata = { + CreatePartitionsMetadata(topic, Set.empty, new ApiError(error, null)) + } + + def apply(topic: String, throwable: Throwable): CreatePartitionsMetadata = { + CreatePartitionsMetadata(topic, Set.empty, ApiError.fromThrowable(throwable)) + } +} + +/** + * A delayed create topic or create partitions operation that is stored in the topic purgatory. 
+ */ +class DelayedCreatePartitions(delayMs: Long, + createMetadata: Seq[CreatePartitionsMetadata], + adminManager: ZkAdminManager, + responseCallback: Map[String, ApiError] => Unit) + extends DelayedOperation(delayMs) { + + /** + * The operation can be completed if all of the topics that do not have an error exist and every partition has a + * leader in the controller. + * See KafkaController.onNewTopicCreation + */ + override def tryComplete() : Boolean = { + trace(s"Trying to complete operation for $createMetadata") + + val leaderlessPartitionCount = createMetadata.filter(_.error.isSuccess).foldLeft(0) { case (topicCounter, metadata) => + topicCounter + missingLeaderCount(metadata.topic, metadata.partitions) + } + + if (leaderlessPartitionCount == 0) { + trace("All partitions have a leader, completing the delayed operation") + forceComplete() + } else { + trace(s"$leaderlessPartitionCount partitions do not have a leader, not completing the delayed operation") + false + } + } + + /** + * Check for partitions that are still missing a leader, update their error code and call the responseCallback + */ + override def onComplete(): Unit = { + trace(s"Completing operation for $createMetadata") + val results = createMetadata.map { metadata => + // ignore topics that already have errors + if (metadata.error.isSuccess && missingLeaderCount(metadata.topic, metadata.partitions) > 0) + (metadata.topic, new ApiError(Errors.REQUEST_TIMED_OUT, null)) + else + (metadata.topic, metadata.error) + }.toMap + responseCallback(results) + } + + override def onExpiration(): Unit = {} + + private def missingLeaderCount(topic: String, partitions: Set[Int]): Int = { + partitions.foldLeft(0) { case (counter, partition) => + if (isMissingLeader(topic, partition)) counter + 1 else counter + } + } + + private def isMissingLeader(topic: String, partition: Int): Boolean = { + val partitionInfo = adminManager.metadataCache.getPartitionInfo(topic, partition) + partitionInfo.forall(_.leader == LeaderAndIsr.NoLeader) + } +} diff --git a/core/src/main/scala/kafka/server/DelayedDeleteRecords.scala b/core/src/main/scala/kafka/server/DelayedDeleteRecords.scala new file mode 100644 index 0000000..317d0b8 --- /dev/null +++ b/core/src/main/scala/kafka/server/DelayedDeleteRecords.scala @@ -0,0 +1,133 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package kafka.server + + +import java.util.concurrent.TimeUnit + +import kafka.metrics.KafkaMetricsGroup +import kafka.utils.Implicits._ +import org.apache.kafka.common.TopicPartition +import org.apache.kafka.common.message.DeleteRecordsResponseData +import org.apache.kafka.common.protocol.Errors +import org.apache.kafka.common.requests.DeleteRecordsResponse + +import scala.collection._ + + +case class DeleteRecordsPartitionStatus(requiredOffset: Long, + responseStatus: DeleteRecordsResponseData.DeleteRecordsPartitionResult) { + @volatile var acksPending = false + + override def toString = "[acksPending: %b, error: %s, lowWatermark: %d, requiredOffset: %d]" + .format(acksPending, Errors.forCode(responseStatus.errorCode).toString, responseStatus.lowWatermark, requiredOffset) +} + +/** + * A delayed delete records operation that can be created by the replica manager and watched + * in the delete records operation purgatory + */ +class DelayedDeleteRecords(delayMs: Long, + deleteRecordsStatus: Map[TopicPartition, DeleteRecordsPartitionStatus], + replicaManager: ReplicaManager, + responseCallback: Map[TopicPartition, DeleteRecordsResponseData.DeleteRecordsPartitionResult] => Unit) + extends DelayedOperation(delayMs) { + + // first update the acks pending variable according to the error code + deleteRecordsStatus.forKeyValue { (topicPartition, status) => + if (status.responseStatus.errorCode == Errors.NONE.code) { + // Timeout error state will be cleared when required acks are received + status.acksPending = true + status.responseStatus.setErrorCode(Errors.REQUEST_TIMED_OUT.code) + } else { + status.acksPending = false + } + + trace("Initial partition status for %s is %s".format(topicPartition, status)) + } + + /** + * The delayed delete records operation can be completed if every partition specified in the request satisfied one of the following: + * + * 1) There was an error while checking if all replicas have caught up to the deleteRecordsOffset: set an error in response + * 2) The low watermark of the partition has caught up to the deleteRecordsOffset. 
set the low watermark in response + * + */ + override def tryComplete(): Boolean = { + // check for each partition if it still has pending acks + deleteRecordsStatus.forKeyValue { (topicPartition, status) => + trace(s"Checking delete records satisfaction for $topicPartition, current status $status") + // skip those partitions that have already been satisfied + if (status.acksPending) { + val (lowWatermarkReached, error, lw) = replicaManager.getPartition(topicPartition) match { + case HostedPartition.Online(partition) => + partition.leaderLogIfLocal match { + case Some(_) => + val leaderLW = partition.lowWatermarkIfLeader + (leaderLW >= status.requiredOffset, Errors.NONE, leaderLW) + case None => + (false, Errors.NOT_LEADER_OR_FOLLOWER, DeleteRecordsResponse.INVALID_LOW_WATERMARK) + } + + case HostedPartition.Offline => + (false, Errors.KAFKA_STORAGE_ERROR, DeleteRecordsResponse.INVALID_LOW_WATERMARK) + + case HostedPartition.None => + (false, Errors.UNKNOWN_TOPIC_OR_PARTITION, DeleteRecordsResponse.INVALID_LOW_WATERMARK) + } + if (error != Errors.NONE || lowWatermarkReached) { + status.acksPending = false + status.responseStatus.setErrorCode(error.code) + status.responseStatus.setLowWatermark(lw) + } + } + } + + // check if every partition has satisfied at least one of case A or B + if (!deleteRecordsStatus.values.exists(_.acksPending)) + forceComplete() + else + false + } + + override def onExpiration(): Unit = { + deleteRecordsStatus.forKeyValue { (topicPartition, status) => + if (status.acksPending) { + DelayedDeleteRecordsMetrics.recordExpiration(topicPartition) + } + } + } + + /** + * Upon completion, return the current response status along with the error code per partition + */ + override def onComplete(): Unit = { + val responseStatus = deleteRecordsStatus.map { case (k, status) => k -> status.responseStatus } + responseCallback(responseStatus) + } +} + +object DelayedDeleteRecordsMetrics extends KafkaMetricsGroup { + + private val aggregateExpirationMeter = newMeter("ExpiresPerSec", "requests", TimeUnit.SECONDS) + + def recordExpiration(partition: TopicPartition): Unit = { + aggregateExpirationMeter.mark() + } +} + diff --git a/core/src/main/scala/kafka/server/DelayedDeleteTopics.scala b/core/src/main/scala/kafka/server/DelayedDeleteTopics.scala new file mode 100644 index 0000000..523252c --- /dev/null +++ b/core/src/main/scala/kafka/server/DelayedDeleteTopics.scala @@ -0,0 +1,83 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package kafka.server + +import org.apache.kafka.common.protocol.Errors + +import scala.collection._ + +/** + * The delete metadata maintained by the delayed delete operation + */ +case class DeleteTopicMetadata(topic: String, error: Errors) + +object DeleteTopicMetadata { + def apply(topic: String, throwable: Throwable): DeleteTopicMetadata = { + DeleteTopicMetadata(topic, Errors.forException(throwable)) + } +} + +/** + * A delayed delete topics operation that can be created by the admin manager and watched + * in the topic purgatory + */ +class DelayedDeleteTopics(delayMs: Long, + deleteMetadata: Seq[DeleteTopicMetadata], + adminManager: ZkAdminManager, + responseCallback: Map[String, Errors] => Unit) + extends DelayedOperation(delayMs) { + + /** + * The operation can be completed if all of the topics not in error have been removed + */ + override def tryComplete() : Boolean = { + trace(s"Trying to complete operation for $deleteMetadata") + + // Ignore topics that already have errors + val existingTopics = deleteMetadata.count { metadata => metadata.error == Errors.NONE && topicExists(metadata.topic) } + + if (existingTopics == 0) { + trace("All topics have been deleted or have errors, completing the delayed operation") + forceComplete() + } else { + trace(s"$existingTopics topics still exist, not completing the delayed operation") + false + } + } + + /** + * Check for partitions that still exist, update their error code and call the responseCallback + */ + override def onComplete(): Unit = { + trace(s"Completing operation for $deleteMetadata") + val results = deleteMetadata.map { metadata => + // ignore topics that already have errors + if (metadata.error == Errors.NONE && topicExists(metadata.topic)) + (metadata.topic, Errors.REQUEST_TIMED_OUT) + else + (metadata.topic, metadata.error) + }.toMap + responseCallback(results) + } + + override def onExpiration(): Unit = { } + + private def topicExists(topic: String): Boolean = { + adminManager.metadataCache.contains(topic) + } +} diff --git a/core/src/main/scala/kafka/server/DelayedElectLeader.scala b/core/src/main/scala/kafka/server/DelayedElectLeader.scala new file mode 100644 index 0000000..cd0a804 --- /dev/null +++ b/core/src/main/scala/kafka/server/DelayedElectLeader.scala @@ -0,0 +1,83 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
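// -------------------------------------------------------------------------------------------
// Illustrative sketch only (topic names and the 30s delay are invented): a worked example of
// the delete-topics timeout behaviour described above. adminManager and the callback are
// assumed to be supplied by ZkAdminManager.
import org.apache.kafka.common.protocol.Errors

def exampleDelayedDelete(adminManager: ZkAdminManager,
                         respond: scala.collection.Map[String, Errors] => Unit): DelayedDeleteTopics = {
  val deleteMetadata = Seq(
    DeleteTopicMetadata("orders", Errors.NONE),                       // deletion in progress
    DeleteTopicMetadata("secret", Errors.TOPIC_AUTHORIZATION_FAILED)) // already failed
  // If "orders" is still in the metadata cache when the delay expires, `respond` receives
  // Map("orders" -> Errors.REQUEST_TIMED_OUT, "secret" -> Errors.TOPIC_AUTHORIZATION_FAILED);
  // if it disappears earlier, tryComplete() force-completes and "orders" maps to Errors.NONE.
  new DelayedDeleteTopics(30000L, deleteMetadata, adminManager, respond)
}
// -------------------------------------------------------------------------------------------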
+ */ + +package kafka.server + +import org.apache.kafka.common.TopicPartition +import org.apache.kafka.common.protocol.Errors +import org.apache.kafka.common.requests.ApiError + +import scala.collection.{Map, mutable} + +/** A delayed elect leader operation that can be created by the replica manager and watched + * in the elect leader purgatory + */ +class DelayedElectLeader( + delayMs: Long, + expectedLeaders: Map[TopicPartition, Int], + results: Map[TopicPartition, ApiError], + replicaManager: ReplicaManager, + responseCallback: Map[TopicPartition, ApiError] => Unit +) extends DelayedOperation(delayMs) { + + private val waitingPartitions = mutable.Map() ++= expectedLeaders + private val fullResults = mutable.Map() ++= results + + + /** + * Call-back to execute when a delayed operation gets expired and hence forced to complete. + */ + override def onExpiration(): Unit = {} + + /** + * Process for completing an operation; This function needs to be defined + * in subclasses and will be called exactly once in forceComplete() + */ + override def onComplete(): Unit = { + // This could be called to force complete, so I need the full list of partitions, so I can time them all out. + updateWaiting() + val timedOut = waitingPartitions.map { + case (tp, _) => tp -> new ApiError(Errors.REQUEST_TIMED_OUT, null) + } + responseCallback(timedOut ++ fullResults) + } + + /** + * Try to complete the delayed operation by first checking if the operation + * can be completed by now. If yes execute the completion logic by calling + * forceComplete() and return true iff forceComplete returns true; otherwise return false + * + * This function needs to be defined in subclasses + */ + override def tryComplete(): Boolean = { + updateWaiting() + debug(s"tryComplete() waitingPartitions: $waitingPartitions") + waitingPartitions.isEmpty && forceComplete() + } + + private def updateWaiting(): Unit = { + val metadataCache = replicaManager.metadataCache + val completedPartitions = waitingPartitions.collect { + case (tp, leader) if metadataCache.getPartitionInfo(tp.topic, tp.partition).exists(_.leader == leader) => tp + } + completedPartitions.foreach { tp => + waitingPartitions -= tp + fullResults += tp -> ApiError.NONE + } + } + +} diff --git a/core/src/main/scala/kafka/server/DelayedFetch.scala b/core/src/main/scala/kafka/server/DelayedFetch.scala new file mode 100644 index 0000000..1bc2a73 --- /dev/null +++ b/core/src/main/scala/kafka/server/DelayedFetch.scala @@ -0,0 +1,201 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
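// -------------------------------------------------------------------------------------------
// Illustrative sketch only (partition names and expected leaders are invented): an election
// over two partitions. Each partition leaves the waiting set once the metadata cache reports
// its expected leader; anything still waiting when the delay expires is answered with
// REQUEST_TIMED_OUT. replicaManager and the callback are assumed to come from the caller.
import org.apache.kafka.common.TopicPartition
import org.apache.kafka.common.requests.ApiError

def exampleDelayedElection(replicaManager: ReplicaManager,
                           respond: scala.collection.Map[TopicPartition, ApiError] => Unit): DelayedElectLeader = {
  val tp0 = new TopicPartition("orders", 0)
  val tp1 = new TopicPartition("orders", 1)
  // Wait up to 30s for broker 1 to lead partition 0 and broker 2 to lead partition 1.
  new DelayedElectLeader(30000L, Map(tp0 -> 1, tp1 -> 2), Map.empty[TopicPartition, ApiError],
    replicaManager, respond)
}
// -------------------------------------------------------------------------------------------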
+ */ + +package kafka.server + +import java.util.concurrent.TimeUnit + +import kafka.metrics.KafkaMetricsGroup +import org.apache.kafka.common.TopicIdPartition +import org.apache.kafka.common.errors._ +import org.apache.kafka.common.protocol.Errors +import org.apache.kafka.common.replica.ClientMetadata +import org.apache.kafka.common.requests.FetchRequest.PartitionData +import org.apache.kafka.common.requests.OffsetsForLeaderEpochResponse.{UNDEFINED_EPOCH, UNDEFINED_EPOCH_OFFSET} + +import scala.collection._ + +case class FetchPartitionStatus(startOffsetMetadata: LogOffsetMetadata, fetchInfo: PartitionData) { + + override def toString: String = { + "[startOffsetMetadata: " + startOffsetMetadata + + ", fetchInfo: " + fetchInfo + + "]" + } +} + +/** + * The fetch metadata maintained by the delayed fetch operation + */ +case class FetchMetadata(fetchMinBytes: Int, + fetchMaxBytes: Int, + hardMaxBytesLimit: Boolean, + fetchOnlyLeader: Boolean, + fetchIsolation: FetchIsolation, + isFromFollower: Boolean, + replicaId: Int, + fetchPartitionStatus: Seq[(TopicIdPartition, FetchPartitionStatus)]) { + + override def toString = "FetchMetadata(minBytes=" + fetchMinBytes + ", " + + "maxBytes=" + fetchMaxBytes + ", " + + "onlyLeader=" + fetchOnlyLeader + ", " + + "fetchIsolation=" + fetchIsolation + ", " + + "replicaId=" + replicaId + ", " + + "partitionStatus=" + fetchPartitionStatus + ")" +} +/** + * A delayed fetch operation that can be created by the replica manager and watched + * in the fetch operation purgatory + */ +class DelayedFetch(delayMs: Long, + fetchMetadata: FetchMetadata, + replicaManager: ReplicaManager, + quota: ReplicaQuota, + clientMetadata: Option[ClientMetadata], + responseCallback: Seq[(TopicIdPartition, FetchPartitionData)] => Unit) + extends DelayedOperation(delayMs) { + + /** + * The operation can be completed if: + * + * Case A: This broker is no longer the leader for some partitions it tries to fetch + * Case B: The replica is no longer available on this broker + * Case C: This broker does not know of some partitions it tries to fetch + * Case D: The partition is in an offline log directory on this broker + * Case E: This broker is the leader, but the requested epoch is now fenced + * Case F: The fetch offset locates not on the last segment of the log + * Case G: The accumulated bytes from all the fetching partitions exceeds the minimum bytes + * Case H: A diverging epoch was found, return response to trigger truncation + * Upon completion, should return whatever data is available for each valid partition + */ + override def tryComplete(): Boolean = { + var accumulatedSize = 0 + fetchMetadata.fetchPartitionStatus.foreach { + case (topicIdPartition, fetchStatus) => + val fetchOffset = fetchStatus.startOffsetMetadata + val fetchLeaderEpoch = fetchStatus.fetchInfo.currentLeaderEpoch + try { + if (fetchOffset != LogOffsetMetadata.UnknownOffsetMetadata) { + val partition = replicaManager.getPartitionOrException(topicIdPartition.topicPartition) + val offsetSnapshot = partition.fetchOffsetSnapshot(fetchLeaderEpoch, fetchMetadata.fetchOnlyLeader) + + val endOffset = fetchMetadata.fetchIsolation match { + case FetchLogEnd => offsetSnapshot.logEndOffset + case FetchHighWatermark => offsetSnapshot.highWatermark + case FetchTxnCommitted => offsetSnapshot.lastStableOffset + } + + // Go directly to the check for Case G if the message offsets are the same. 
If the log segment + // has just rolled, then the high watermark offset will remain the same but be on the old segment, + // which would incorrectly be seen as an instance of Case F. + if (endOffset.messageOffset != fetchOffset.messageOffset) { + if (endOffset.onOlderSegment(fetchOffset)) { + // Case F, this can happen when the new fetch operation is on a truncated leader + debug(s"Satisfying fetch $fetchMetadata since it is fetching later segments of partition $topicIdPartition.") + return forceComplete() + } else if (fetchOffset.onOlderSegment(endOffset)) { + // Case F, this can happen when the fetch operation is falling behind the current segment + // or the partition has just rolled a new segment + debug(s"Satisfying fetch $fetchMetadata immediately since it is fetching older segments.") + // We will not force complete the fetch request if a replica should be throttled. + if (!replicaManager.shouldLeaderThrottle(quota, partition, fetchMetadata.replicaId)) + return forceComplete() + } else if (fetchOffset.messageOffset < endOffset.messageOffset) { + // we take the partition fetch size as upper bound when accumulating the bytes (skip if a throttled partition) + val bytesAvailable = math.min(endOffset.positionDiff(fetchOffset), fetchStatus.fetchInfo.maxBytes) + if (!replicaManager.shouldLeaderThrottle(quota, partition, fetchMetadata.replicaId)) + accumulatedSize += bytesAvailable + } + } + + // Case H: If truncation has caused diverging epoch while this request was in purgatory, return to trigger truncation + fetchStatus.fetchInfo.lastFetchedEpoch.ifPresent { fetchEpoch => + val epochEndOffset = partition.lastOffsetForLeaderEpoch(fetchLeaderEpoch, fetchEpoch, fetchOnlyFromLeader = false) + if (epochEndOffset.errorCode != Errors.NONE.code() + || epochEndOffset.endOffset == UNDEFINED_EPOCH_OFFSET + || epochEndOffset.leaderEpoch == UNDEFINED_EPOCH) { + debug(s"Could not obtain last offset for leader epoch for partition $topicIdPartition, epochEndOffset=$epochEndOffset.") + return forceComplete() + } else if (epochEndOffset.leaderEpoch < fetchEpoch || epochEndOffset.endOffset < fetchStatus.fetchInfo.fetchOffset) { + debug(s"Satisfying fetch $fetchMetadata since it has diverging epoch requiring truncation for partition " + + s"$topicIdPartition epochEndOffset=$epochEndOffset fetchEpoch=$fetchEpoch fetchOffset=${fetchStatus.fetchInfo.fetchOffset}.") + return forceComplete() + } + } + } + } catch { + case _: NotLeaderOrFollowerException => // Case A or Case B + debug(s"Broker is no longer the leader or follower of $topicIdPartition, satisfy $fetchMetadata immediately") + return forceComplete() + case _: UnknownTopicOrPartitionException => // Case C + debug(s"Broker no longer knows of partition $topicIdPartition, satisfy $fetchMetadata immediately") + return forceComplete() + case _: KafkaStorageException => // Case D + debug(s"Partition $topicIdPartition is in an offline log directory, satisfy $fetchMetadata immediately") + return forceComplete() + case _: FencedLeaderEpochException => // Case E + debug(s"Broker is the leader of partition $topicIdPartition, but the requested epoch " + + s"$fetchLeaderEpoch is fenced by the latest leader epoch, satisfy $fetchMetadata immediately") + return forceComplete() + } + } + + // Case G + if (accumulatedSize >= fetchMetadata.fetchMinBytes) + forceComplete() + else + false + } + + override def onExpiration(): Unit = { + if (fetchMetadata.isFromFollower) + DelayedFetchMetrics.followerExpiredRequestMeter.mark() + else + 
DelayedFetchMetrics.consumerExpiredRequestMeter.mark() + } + + /** + * Upon completion, read whatever data is available and pass to the complete callback + */ + override def onComplete(): Unit = { + val logReadResults = replicaManager.readFromLocalLog( + replicaId = fetchMetadata.replicaId, + fetchOnlyFromLeader = fetchMetadata.fetchOnlyLeader, + fetchIsolation = fetchMetadata.fetchIsolation, + fetchMaxBytes = fetchMetadata.fetchMaxBytes, + hardMaxBytesLimit = fetchMetadata.hardMaxBytesLimit, + readPartitionInfo = fetchMetadata.fetchPartitionStatus.map { case (tp, status) => tp -> status.fetchInfo }, + clientMetadata = clientMetadata, + quota = quota) + + val fetchPartitionData = logReadResults.map { case (tp, result) => + val isReassignmentFetch = fetchMetadata.isFromFollower && + replicaManager.isAddingReplica(tp.topicPartition, fetchMetadata.replicaId) + + tp -> result.toFetchPartitionData(isReassignmentFetch) + } + + responseCallback(fetchPartitionData) + } +} + +object DelayedFetchMetrics extends KafkaMetricsGroup { + private val FetcherTypeKey = "fetcherType" + val followerExpiredRequestMeter = newMeter("ExpiresPerSec", "requests", TimeUnit.SECONDS, tags = Map(FetcherTypeKey -> "follower")) + val consumerExpiredRequestMeter = newMeter("ExpiresPerSec", "requests", TimeUnit.SECONDS, tags = Map(FetcherTypeKey -> "consumer")) +} + diff --git a/core/src/main/scala/kafka/server/DelayedFuture.scala b/core/src/main/scala/kafka/server/DelayedFuture.scala new file mode 100644 index 0000000..018fae0 --- /dev/null +++ b/core/src/main/scala/kafka/server/DelayedFuture.scala @@ -0,0 +1,102 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.server + +import java.util.concurrent._ +import java.util.function.BiConsumer + +import org.apache.kafka.common.errors.TimeoutException +import org.apache.kafka.common.utils.KafkaThread + +import scala.collection.Seq + +/** + * A delayed operation using CompletionFutures that can be created by KafkaApis and watched + * in a DelayedFuturePurgatory purgatory. This is used for ACL updates using async Authorizers. + */ +class DelayedFuture[T](timeoutMs: Long, + futures: Seq[CompletableFuture[T]], + responseCallback: () => Unit) + extends DelayedOperation(timeoutMs) { + + /** + * The operation can be completed if all the futures have completed successfully + * or failed with exceptions. 
+ */ + override def tryComplete() : Boolean = { + trace(s"Trying to complete operation for ${futures.size} futures") + + val pending = futures.count(future => !future.isDone) + if (pending == 0) { + trace("All futures have been completed or have errors, completing the delayed operation") + forceComplete() + } else { + trace(s"$pending future still pending, not completing the delayed operation") + false + } + } + + /** + * Timeout any pending futures and invoke responseCallback. This is invoked when all + * futures have completed or the operation has timed out. + */ + override def onComplete(): Unit = { + val pendingFutures = futures.filterNot(_.isDone) + trace(s"Completing operation for ${futures.size} futures, expired ${pendingFutures.size}") + pendingFutures.foreach(_.completeExceptionally(new TimeoutException(s"Request has been timed out after $timeoutMs ms"))) + responseCallback.apply() + } + + /** + * This is invoked after onComplete(), so no actions required. + */ + override def onExpiration(): Unit = { + } +} + +class DelayedFuturePurgatory(purgatoryName: String, brokerId: Int) { + private val purgatory = DelayedOperationPurgatory[DelayedFuture[_]](purgatoryName, brokerId) + private val executor = new ThreadPoolExecutor(1, 1, 0, TimeUnit.MILLISECONDS, + new LinkedBlockingQueue[Runnable](), + new ThreadFactory { + override def newThread(r: Runnable): Thread = new KafkaThread(s"DelayedExecutor-$purgatoryName", r, true) + }) + val purgatoryKey = new Object + + def tryCompleteElseWatch[T](timeoutMs: Long, + futures: Seq[CompletableFuture[T]], + responseCallback: () => Unit): DelayedFuture[T] = { + val delayedFuture = new DelayedFuture[T](timeoutMs, futures, responseCallback) + val done = purgatory.tryCompleteElseWatch(delayedFuture, Seq(purgatoryKey)) + if (!done) { + val callbackAction = new BiConsumer[Void, Throwable]() { + override def accept(result: Void, exception: Throwable): Unit = delayedFuture.forceComplete() + } + CompletableFuture.allOf(futures.toArray: _*).whenCompleteAsync(callbackAction, executor) + } + delayedFuture + } + + def shutdown(): Unit = { + executor.shutdownNow() + executor.awaitTermination(60, TimeUnit.SECONDS) + purgatory.shutdown() + } + + def isShutdown: Boolean = executor.isShutdown +} diff --git a/core/src/main/scala/kafka/server/DelayedOperation.scala b/core/src/main/scala/kafka/server/DelayedOperation.scala new file mode 100644 index 0000000..09fd337 --- /dev/null +++ b/core/src/main/scala/kafka/server/DelayedOperation.scala @@ -0,0 +1,436 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
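// -------------------------------------------------------------------------------------------
// Illustrative sketch only (purgatory name and timeout are invented): tying a batch of
// CompletableFutures, e.g. from an asynchronous Authorizer ACL update, to a
// DelayedFuturePurgatory so the response is sent once all futures finish or the timeout hits.
import java.util.concurrent.CompletableFuture

def watchAclUpdateFutures(brokerId: Int,
                          futures: Seq[CompletableFuture[Unit]],
                          sendResponse: () => Unit): DelayedFuturePurgatory = {
  val purgatory = new DelayedFuturePurgatory("acl-update", brokerId)
  // Invokes `sendResponse` once every future is done, or after 30s, whichever comes first.
  purgatory.tryCompleteElseWatch(30000L, futures, sendResponse)
  purgatory // callers would keep this around and call shutdown() when closing
}
// -------------------------------------------------------------------------------------------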
+ */ + +package kafka.server + +import java.util.concurrent._ +import java.util.concurrent.atomic._ +import java.util.concurrent.locks.{Lock, ReentrantLock} + +import kafka.metrics.KafkaMetricsGroup +import kafka.utils.CoreUtils.inLock +import kafka.utils._ +import kafka.utils.timer._ + +import scala.collection._ +import scala.collection.mutable.ListBuffer + +/** + * An operation whose processing needs to be delayed for at most the given delayMs. For example + * a delayed produce operation could be waiting for specified number of acks; or + * a delayed fetch operation could be waiting for a given number of bytes to accumulate. + * + * The logic upon completing a delayed operation is defined in onComplete() and will be called exactly once. + * Once an operation is completed, isCompleted() will return true. onComplete() can be triggered by either + * forceComplete(), which forces calling onComplete() after delayMs if the operation is not yet completed, + * or tryComplete(), which first checks if the operation can be completed or not now, and if yes calls + * forceComplete(). + * + * A subclass of DelayedOperation needs to provide an implementation of both onComplete() and tryComplete(). + * + * Noted that if you add a future delayed operation that calls ReplicaManager.appendRecords() in onComplete() + * like DelayedJoin, you must be aware that this operation's onExpiration() needs to call actionQueue.tryCompleteAction(). + */ +abstract class DelayedOperation(override val delayMs: Long, + lockOpt: Option[Lock] = None) + extends TimerTask with Logging { + + private val completed = new AtomicBoolean(false) + // Visible for testing + private[server] val lock: Lock = lockOpt.getOrElse(new ReentrantLock) + + /* + * Force completing the delayed operation, if not already completed. + * This function can be triggered when + * + * 1. The operation has been verified to be completable inside tryComplete() + * 2. The operation has expired and hence needs to be completed right now + * + * Return true iff the operation is completed by the caller: note that + * concurrent threads can try to complete the same operation, but only + * the first thread will succeed in completing the operation and return + * true, others will still return false + */ + def forceComplete(): Boolean = { + if (completed.compareAndSet(false, true)) { + // cancel the timeout timer + cancel() + onComplete() + true + } else { + false + } + } + + /** + * Check if the delayed operation is already completed + */ + def isCompleted: Boolean = completed.get() + + /** + * Call-back to execute when a delayed operation gets expired and hence forced to complete. + */ + def onExpiration(): Unit + + /** + * Process for completing an operation; This function needs to be defined + * in subclasses and will be called exactly once in forceComplete() + */ + def onComplete(): Unit + + /** + * Try to complete the delayed operation by first checking if the operation + * can be completed by now. 
If yes execute the completion logic by calling + * forceComplete() and return true iff forceComplete returns true; otherwise return false + * + * This function needs to be defined in subclasses + */ + def tryComplete(): Boolean + + /** + * Thread-safe variant of tryComplete() and call extra function if first tryComplete returns false + * @param f else function to be executed after first tryComplete returns false + * @return result of tryComplete + */ + private[server] def safeTryCompleteOrElse(f: => Unit): Boolean = inLock(lock) { + if (tryComplete()) true + else { + f + // last completion check + tryComplete() + } + } + + /** + * Thread-safe variant of tryComplete() + */ + private[server] def safeTryComplete(): Boolean = inLock(lock)(tryComplete()) + + /* + * run() method defines a task that is executed on timeout + */ + override def run(): Unit = { + if (forceComplete()) + onExpiration() + } +} + +object DelayedOperationPurgatory { + + private val Shards = 512 // Shard the watcher list to reduce lock contention + + def apply[T <: DelayedOperation](purgatoryName: String, + brokerId: Int = 0, + purgeInterval: Int = 1000, + reaperEnabled: Boolean = true, + timerEnabled: Boolean = true): DelayedOperationPurgatory[T] = { + val timer = new SystemTimer(purgatoryName) + new DelayedOperationPurgatory[T](purgatoryName, timer, brokerId, purgeInterval, reaperEnabled, timerEnabled) + } + +} + +/** + * A helper purgatory class for bookkeeping delayed operations with a timeout, and expiring timed out operations. + */ +final class DelayedOperationPurgatory[T <: DelayedOperation](purgatoryName: String, + timeoutTimer: Timer, + brokerId: Int = 0, + purgeInterval: Int = 1000, + reaperEnabled: Boolean = true, + timerEnabled: Boolean = true) + extends Logging with KafkaMetricsGroup { + /* a list of operation watching keys */ + private class WatcherList { + val watchersByKey = new Pool[Any, Watchers](Some((key: Any) => new Watchers(key))) + + val watchersLock = new ReentrantLock() + + /* + * Return all the current watcher lists, + * note that the returned watchers may be removed from the list by other threads + */ + def allWatchers = { + watchersByKey.values + } + } + + private val watcherLists = Array.fill[WatcherList](DelayedOperationPurgatory.Shards)(new WatcherList) + private def watcherList(key: Any): WatcherList = { + watcherLists(Math.abs(key.hashCode() % watcherLists.length)) + } + + // the number of estimated total operations in the purgatory + private[this] val estimatedTotalOperations = new AtomicInteger(0) + + /* background thread expiring operations that have timed out */ + private val expirationReaper = new ExpiredOperationReaper() + + private val metricsTags = Map("delayedOperation" -> purgatoryName) + newGauge("PurgatorySize", () => watched, metricsTags) + newGauge("NumDelayedOperations", () => numDelayed, metricsTags) + + if (reaperEnabled) + expirationReaper.start() + + /** + * Check if the operation can be completed, if not watch it based on the given watch keys + * + * Note that a delayed operation can be watched on multiple keys. It is possible that + * an operation is completed after it has been added to the watch list for some, but + * not all of the keys. In this case, the operation is considered completed and won't + * be added to the watch list of the remaining keys. The expiration reaper thread will + * remove this operation from any watcher list in which the operation exists. 
+ * + * @param operation the delayed operation to be checked + * @param watchKeys keys for bookkeeping the operation + * @return true iff the delayed operations can be completed by the caller + */ + def tryCompleteElseWatch(operation: T, watchKeys: Seq[Any]): Boolean = { + assert(watchKeys.nonEmpty, "The watch key list can't be empty") + + // The cost of tryComplete() is typically proportional to the number of keys. Calling tryComplete() for each key is + // going to be expensive if there are many keys. Instead, we do the check in the following way through safeTryCompleteOrElse(). + // If the operation is not completed, we just add the operation to all keys. Then we call tryComplete() again. At + // this time, if the operation is still not completed, we are guaranteed that it won't miss any future triggering + // event since the operation is already on the watcher list for all keys. + // + // ==============[story about lock]============== + // Through safeTryCompleteOrElse(), we hold the operation's lock while adding the operation to watch list and doing + // the tryComplete() check. This is to avoid a potential deadlock between the callers to tryCompleteElseWatch() and + // checkAndComplete(). For example, the following deadlock can happen if the lock is only held for the final tryComplete() + // 1) thread_a holds readlock of stateLock from TransactionStateManager + // 2) thread_a is executing tryCompleteElseWatch() + // 3) thread_a adds op to watch list + // 4) thread_b requires writelock of stateLock from TransactionStateManager (blocked by thread_a) + // 5) thread_c calls checkAndComplete() and holds lock of op + // 6) thread_c is waiting readlock of stateLock to complete op (blocked by thread_b) + // 7) thread_a is waiting lock of op to call the final tryComplete() (blocked by thread_c) + // + // Note that even with the current approach, deadlocks could still be introduced. For example, + // 1) thread_a calls tryCompleteElseWatch() and gets lock of op + // 2) thread_a adds op to watch list + // 3) thread_a calls op#tryComplete and tries to require lock_b + // 4) thread_b holds lock_b and calls checkAndComplete() + // 5) thread_b sees op from watch list + // 6) thread_b needs lock of op + // To avoid the above scenario, we recommend DelayedOperationPurgatory.checkAndComplete() be called without holding + // any exclusive lock. Since DelayedOperationPurgatory.checkAndComplete() completes delayed operations asynchronously, + // holding a exclusive lock to make the call is often unnecessary. + if (operation.safeTryCompleteOrElse { + watchKeys.foreach(key => watchForOperation(key, operation)) + if (watchKeys.nonEmpty) estimatedTotalOperations.incrementAndGet() + }) return true + + // if it cannot be completed by now and hence is watched, add to the expire queue also + if (!operation.isCompleted) { + if (timerEnabled) + timeoutTimer.add(operation) + if (operation.isCompleted) { + // cancel the timer task + operation.cancel() + } + } + + false + } + + /** + * Check if some delayed operations can be completed with the given watch key, + * and if yes complete them. 
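// -------------------------------------------------------------------------------------------
// Illustrative sketch only: a minimal DelayedOperation that waits for an external flag,
// showing the tryComplete()/onComplete()/onExpiration() contract together with the
// tryCompleteElseWatch()/checkAndComplete() flow. The flag, key and callback are made up.
import java.util.concurrent.atomic.AtomicBoolean

class DelayedFlag(delayMs: Long, flag: AtomicBoolean, callback: Boolean => Unit)
  extends DelayedOperation(delayMs) {
  // Complete as soon as the flag is set.
  override def tryComplete(): Boolean = if (flag.get) forceComplete() else false
  // Called exactly once, either because the flag was set or because the delay expired.
  override def onComplete(): Unit = callback(flag.get)
  // Nothing extra to do here; onComplete() already reported the final state.
  override def onExpiration(): Unit = ()
}

// Typical flow: park the operation, then poke the purgatory whenever the flag may have changed.
// val purgatory = DelayedOperationPurgatory[DelayedFlag]("flag", brokerId = 0)
// purgatory.tryCompleteElseWatch(new DelayedFlag(30000L, flag, done => println(s"done=$done")), Seq("flag-key"))
// ... after flag.set(true):
// purgatory.checkAndComplete("flag-key")
// -------------------------------------------------------------------------------------------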
+ * + * @return the number of completed operations during this process + */ + def checkAndComplete(key: Any): Int = { + val wl = watcherList(key) + val watchers = inLock(wl.watchersLock) { wl.watchersByKey.get(key) } + val numCompleted = if (watchers == null) + 0 + else + watchers.tryCompleteWatched() + debug(s"Request key $key unblocked $numCompleted $purgatoryName operations") + numCompleted + } + + /** + * Return the total size of watch lists the purgatory. Since an operation may be watched + * on multiple lists, and some of its watched entries may still be in the watch lists + * even when it has been completed, this number may be larger than the number of real operations watched + */ + def watched: Int = { + watcherLists.foldLeft(0) { case (sum, watcherList) => sum + watcherList.allWatchers.map(_.countWatched).sum } + } + + /** + * Return the number of delayed operations in the expiry queue + */ + def numDelayed: Int = timeoutTimer.size + + /** + * Cancel watching on any delayed operations for the given key. Note the operation will not be completed + */ + def cancelForKey(key: Any): List[T] = { + val wl = watcherList(key) + inLock(wl.watchersLock) { + val watchers = wl.watchersByKey.remove(key) + if (watchers != null) + watchers.cancel() + else + Nil + } + } + + /* + * Return the watch list of the given key, note that we need to + * grab the removeWatchersLock to avoid the operation being added to a removed watcher list + */ + private def watchForOperation(key: Any, operation: T): Unit = { + val wl = watcherList(key) + inLock(wl.watchersLock) { + val watcher = wl.watchersByKey.getAndMaybePut(key) + watcher.watch(operation) + } + } + + /* + * Remove the key from watcher lists if its list is empty + */ + private def removeKeyIfEmpty(key: Any, watchers: Watchers): Unit = { + val wl = watcherList(key) + inLock(wl.watchersLock) { + // if the current key is no longer correlated to the watchers to remove, skip + if (wl.watchersByKey.get(key) != watchers) + return + + if (watchers != null && watchers.isEmpty) { + wl.watchersByKey.remove(key) + } + } + } + + /** + * Shutdown the expire reaper thread + */ + def shutdown(): Unit = { + if (reaperEnabled) + expirationReaper.shutdown() + timeoutTimer.shutdown() + removeMetric("PurgatorySize", metricsTags) + removeMetric("NumDelayedOperations", metricsTags) + } + + /** + * A linked list of watched delayed operations based on some key + */ + private class Watchers(val key: Any) { + private[this] val operations = new ConcurrentLinkedQueue[T]() + + // count the current number of watched operations. 
This is O(n), so use isEmpty() if possible + def countWatched: Int = operations.size + + def isEmpty: Boolean = operations.isEmpty + + // add the element to watch + def watch(t: T): Unit = { + operations.add(t) + } + + // traverse the list and try to complete some watched elements + def tryCompleteWatched(): Int = { + var completed = 0 + + val iter = operations.iterator() + while (iter.hasNext) { + val curr = iter.next() + if (curr.isCompleted) { + // another thread has completed this operation, just remove it + iter.remove() + } else if (curr.safeTryComplete()) { + iter.remove() + completed += 1 + } + } + + if (operations.isEmpty) + removeKeyIfEmpty(key, this) + + completed + } + + def cancel(): List[T] = { + val iter = operations.iterator() + val cancelled = new ListBuffer[T]() + while (iter.hasNext) { + val curr = iter.next() + curr.cancel() + iter.remove() + cancelled += curr + } + cancelled.toList + } + + // traverse the list and purge elements that are already completed by others + def purgeCompleted(): Int = { + var purged = 0 + + val iter = operations.iterator() + while (iter.hasNext) { + val curr = iter.next() + if (curr.isCompleted) { + iter.remove() + purged += 1 + } + } + + if (operations.isEmpty) + removeKeyIfEmpty(key, this) + + purged + } + } + + def advanceClock(timeoutMs: Long): Unit = { + timeoutTimer.advanceClock(timeoutMs) + + // Trigger a purge if the number of completed but still being watched operations is larger than + // the purge threshold. That number is computed by the difference btw the estimated total number of + // operations and the number of pending delayed operations. + if (estimatedTotalOperations.get - numDelayed > purgeInterval) { + // now set estimatedTotalOperations to delayed (the number of pending operations) since we are going to + // clean up watchers. Note that, if more operations are completed during the clean up, we may end up with + // a little overestimated total number of operations. + estimatedTotalOperations.getAndSet(numDelayed) + debug("Begin purging watch lists") + val purged = watcherLists.foldLeft(0) { + case (sum, watcherList) => sum + watcherList.allWatchers.map(_.purgeCompleted()).sum + } + debug("Purged %d elements from watch lists.".format(purged)) + } + } + + /** + * A background reaper to expire delayed operations that have timed out + */ + private class ExpiredOperationReaper extends ShutdownableThread( + "ExpirationReaper-%d-%s".format(brokerId, purgatoryName), + false) { + + override def doWork(): Unit = { + advanceClock(200L) + } + } +} diff --git a/core/src/main/scala/kafka/server/DelayedOperationKey.scala b/core/src/main/scala/kafka/server/DelayedOperationKey.scala new file mode 100644 index 0000000..13ed462 --- /dev/null +++ b/core/src/main/scala/kafka/server/DelayedOperationKey.scala @@ -0,0 +1,65 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
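// -------------------------------------------------------------------------------------------
// Illustrative numbers only, for the purge heuristic in advanceClock() above: a purge of
// completed-but-still-watched operations is triggered when the estimated total minus the
// pending delayed operations exceeds purgeInterval.
object PurgeHeuristicExample {
  val estimatedTotal = 1500 // estimatedTotalOperations
  val pendingDelayed = 400  // numDelayed (still in the expiry queue)
  val purgeInterval = 1000  // default threshold
  // 1500 - 400 = 1100 > 1000, so the watch lists are purged and the estimate resets to 400;
  // with estimatedTotal = 1200 (1200 - 400 = 800) the purge would be skipped.
  val shouldPurge: Boolean = estimatedTotal - pendingDelayed > purgeInterval // true
}
// -------------------------------------------------------------------------------------------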
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.server + +import org.apache.kafka.common.{TopicIdPartition, TopicPartition} + +/** + * Keys used for delayed operation metrics recording + */ +trait DelayedOperationKey { + def keyLabel: String +} + +object DelayedOperationKey { + val globalLabel = "All" +} + +/* used by delayed-produce and delayed-fetch operations */ +case class TopicPartitionOperationKey(topic: String, partition: Int) extends DelayedOperationKey { + override def keyLabel: String = "%s-%d".format(topic, partition) +} + +object TopicPartitionOperationKey { + def apply(topicPartition: TopicPartition): TopicPartitionOperationKey = { + apply(topicPartition.topic, topicPartition.partition) + } + def apply(topicIdPartition: TopicIdPartition): TopicPartitionOperationKey = { + apply(topicIdPartition.topic, topicIdPartition.partition) + } +} + +/* used by delayed-join-group operations */ +case class MemberKey(groupId: String, consumerId: String) extends DelayedOperationKey { + override def keyLabel: String = "%s-%s".format(groupId, consumerId) +} + +/* used by delayed-join operations */ +case class GroupJoinKey(groupId: String) extends DelayedOperationKey { + override def keyLabel: String = "join-%s".format(groupId) +} + +/* used by delayed-sync operations */ +case class GroupSyncKey(groupId: String) extends DelayedOperationKey { + override def keyLabel: String = "sync-%s".format(groupId) +} + +/* used by delayed-topic operations */ +case class TopicKey(topic: String) extends DelayedOperationKey { + override def keyLabel: String = topic +} diff --git a/core/src/main/scala/kafka/server/DelayedProduce.scala b/core/src/main/scala/kafka/server/DelayedProduce.scala new file mode 100644 index 0000000..5e7e7bf --- /dev/null +++ b/core/src/main/scala/kafka/server/DelayedProduce.scala @@ -0,0 +1,148 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
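// -------------------------------------------------------------------------------------------
// Illustrative sketch only (group, member and topic names are invented): the keyLabel values
// produced by the operation keys defined above.
object KeyLabelExamples {
  val byPartition = TopicPartitionOperationKey("payments", 3).keyLabel // "payments-3"
  val byMember    = MemberKey("my-group", "consumer-1").keyLabel       // "my-group-consumer-1"
  val byJoin      = GroupJoinKey("my-group").keyLabel                  // "join-my-group"
  val bySync      = GroupSyncKey("my-group").keyLabel                  // "sync-my-group"
  val byTopic     = TopicKey("payments").keyLabel                      // "payments"
}
// -------------------------------------------------------------------------------------------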
+ */ + +package kafka.server + +import java.util.concurrent.TimeUnit +import java.util.concurrent.locks.Lock + +import com.yammer.metrics.core.Meter +import kafka.metrics.KafkaMetricsGroup +import kafka.utils.Implicits._ +import kafka.utils.Pool +import org.apache.kafka.common.TopicPartition +import org.apache.kafka.common.protocol.Errors +import org.apache.kafka.common.requests.ProduceResponse.PartitionResponse + +import scala.collection._ + +case class ProducePartitionStatus(requiredOffset: Long, responseStatus: PartitionResponse) { + @volatile var acksPending = false + + override def toString = s"[acksPending: $acksPending, error: ${responseStatus.error.code}, " + + s"startOffset: ${responseStatus.baseOffset}, requiredOffset: $requiredOffset]" +} + +/** + * The produce metadata maintained by the delayed produce operation + */ +case class ProduceMetadata(produceRequiredAcks: Short, + produceStatus: Map[TopicPartition, ProducePartitionStatus]) { + + override def toString = s"[requiredAcks: $produceRequiredAcks, partitionStatus: $produceStatus]" +} + +/** + * A delayed produce operation that can be created by the replica manager and watched + * in the produce operation purgatory + */ +class DelayedProduce(delayMs: Long, + produceMetadata: ProduceMetadata, + replicaManager: ReplicaManager, + responseCallback: Map[TopicPartition, PartitionResponse] => Unit, + lockOpt: Option[Lock] = None) + extends DelayedOperation(delayMs, lockOpt) { + + // first update the acks pending variable according to the error code + produceMetadata.produceStatus.forKeyValue { (topicPartition, status) => + if (status.responseStatus.error == Errors.NONE) { + // Timeout error state will be cleared when required acks are received + status.acksPending = true + status.responseStatus.error = Errors.REQUEST_TIMED_OUT + } else { + status.acksPending = false + } + + trace(s"Initial partition status for $topicPartition is $status") + } + + /** + * The delayed produce operation can be completed if every partition + * it produces to is satisfied by one of the following: + * + * Case A: Replica not assigned to partition + * Case B: Replica is no longer the leader of this partition + * Case C: This broker is the leader: + * C.1 - If there was a local error thrown while checking if at least requiredAcks + * replicas have caught up to this operation: set an error in response + * C.2 - Otherwise, set the response with no error. 
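// -------------------------------------------------------------------------------------------
// Illustrative sketch only (topic, offset and timeout are invented): wrapping an acks=-1
// produce into a DelayedProduce. replicaManager and the callback are assumed to be supplied
// by the caller; in practice ReplicaManager builds this inside appendRecords().
import org.apache.kafka.common.TopicPartition
import org.apache.kafka.common.protocol.Errors
import org.apache.kafka.common.requests.ProduceResponse.PartitionResponse

def exampleDelayedProduce(replicaManager: ReplicaManager,
                          respond: scala.collection.Map[TopicPartition, PartitionResponse] => Unit): DelayedProduce = {
  val tp = new TopicPartition("payments", 0)
  // 100L is the last offset of the appended batch; the constructor rewrites the initial NONE
  // response to REQUEST_TIMED_OUT until enough replicas have reached that offset.
  val status = ProducePartitionStatus(100L, new PartitionResponse(Errors.NONE))
  val requiredAcks: Short = -1
  new DelayedProduce(30000L, ProduceMetadata(requiredAcks, Map(tp -> status)), replicaManager, respond)
}
// -------------------------------------------------------------------------------------------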
+ */ + override def tryComplete(): Boolean = { + // check for each partition if it still has pending acks + produceMetadata.produceStatus.forKeyValue { (topicPartition, status) => + trace(s"Checking produce satisfaction for $topicPartition, current status $status") + // skip those partitions that have already been satisfied + if (status.acksPending) { + val (hasEnough, error) = replicaManager.getPartitionOrError(topicPartition) match { + case Left(err) => + // Case A + (false, err) + + case Right(partition) => + partition.checkEnoughReplicasReachOffset(status.requiredOffset) + } + + // Case B || C.1 || C.2 + if (error != Errors.NONE || hasEnough) { + status.acksPending = false + status.responseStatus.error = error + } + } + } + + // check if every partition has satisfied at least one of case A, B or C + if (!produceMetadata.produceStatus.values.exists(_.acksPending)) + forceComplete() + else + false + } + + override def onExpiration(): Unit = { + produceMetadata.produceStatus.forKeyValue { (topicPartition, status) => + if (status.acksPending) { + debug(s"Expiring produce request for partition $topicPartition with status $status") + DelayedProduceMetrics.recordExpiration(topicPartition) + } + } + } + + /** + * Upon completion, return the current response status along with the error code per partition + */ + override def onComplete(): Unit = { + val responseStatus = produceMetadata.produceStatus.map { case (k, status) => k -> status.responseStatus } + responseCallback(responseStatus) + } +} + +object DelayedProduceMetrics extends KafkaMetricsGroup { + + private val aggregateExpirationMeter = newMeter("ExpiresPerSec", "requests", TimeUnit.SECONDS) + + private val partitionExpirationMeterFactory = (key: TopicPartition) => + newMeter("ExpiresPerSec", + "requests", + TimeUnit.SECONDS, + tags = Map("topic" -> key.topic, "partition" -> key.partition.toString)) + private val partitionExpirationMeters = new Pool[TopicPartition, Meter](valueFactory = Some(partitionExpirationMeterFactory)) + + def recordExpiration(partition: TopicPartition): Unit = { + aggregateExpirationMeter.mark() + partitionExpirationMeters.getAndMaybePut(partition).mark() + } +} + diff --git a/core/src/main/scala/kafka/server/DelegationTokenManager.scala b/core/src/main/scala/kafka/server/DelegationTokenManager.scala new file mode 100644 index 0000000..536a296 --- /dev/null +++ b/core/src/main/scala/kafka/server/DelegationTokenManager.scala @@ -0,0 +1,512 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package kafka.server + +import java.nio.ByteBuffer +import java.nio.charset.StandardCharsets +import java.security.InvalidKeyException +import java.util.Base64 + +import javax.crypto.spec.SecretKeySpec +import javax.crypto.{Mac, SecretKey} +import kafka.common.{NotificationHandler, ZkNodeChangeNotificationListener} +import kafka.metrics.KafkaMetricsGroup +import kafka.utils.{CoreUtils, Json, Logging} +import kafka.zk.{DelegationTokenChangeNotificationSequenceZNode, DelegationTokenChangeNotificationZNode, DelegationTokensZNode, KafkaZkClient} +import org.apache.kafka.common.protocol.Errors +import org.apache.kafka.common.security.auth.KafkaPrincipal +import org.apache.kafka.common.security.scram.internals.{ScramFormatter, ScramMechanism} +import org.apache.kafka.common.security.scram.ScramCredential +import org.apache.kafka.common.security.token.delegation.internals.DelegationTokenCache +import org.apache.kafka.common.security.token.delegation.{DelegationToken, TokenInformation} +import org.apache.kafka.common.utils.{Sanitizer, SecurityUtils, Time} + +import scala.jdk.CollectionConverters._ +import scala.collection.mutable + +object DelegationTokenManager { + val DefaultHmacAlgorithm = "HmacSHA512" + val OwnerKey ="owner" + val RenewersKey = "renewers" + val IssueTimestampKey = "issueTimestamp" + val MaxTimestampKey = "maxTimestamp" + val ExpiryTimestampKey = "expiryTimestamp" + val TokenIdKey = "tokenId" + val VersionKey = "version" + val CurrentVersion = 1 + val ErrorTimestamp = -1 + + /** + * + * @param tokenId + * @param secretKey + * @return + */ + def createHmac(tokenId: String, secretKey: String) : Array[Byte] = { + createHmac(tokenId, createSecretKey(secretKey.getBytes(StandardCharsets.UTF_8))) + } + + /** + * Convert the byte[] to a secret key + * @param keybytes the byte[] to create the secret key from + * @return the secret key + */ + def createSecretKey(keybytes: Array[Byte]) : SecretKey = { + new SecretKeySpec(keybytes, DefaultHmacAlgorithm) + } + + /** + * + * + * @param tokenId + * @param secretKey + * @return + */ + def createBase64HMAC(tokenId: String, secretKey: SecretKey) : String = { + val hmac = createHmac(tokenId, secretKey) + Base64.getEncoder.encodeToString(hmac) + } + + /** + * Compute HMAC of the identifier using the secret key + * @param tokenId the bytes of the identifier + * @param secretKey the secret key + * @return String of the generated hmac + */ + def createHmac(tokenId: String, secretKey: SecretKey) : Array[Byte] = { + val mac = Mac.getInstance(DefaultHmacAlgorithm) + try + mac.init(secretKey) + catch { + case ike: InvalidKeyException => throw new IllegalArgumentException("Invalid key to HMAC computation", ike); + } + mac.doFinal(tokenId.getBytes(StandardCharsets.UTF_8)) + } + + def toJsonCompatibleMap(token: DelegationToken): Map[String, Any] = { + val tokenInfo = token.tokenInfo + val tokenInfoMap = mutable.Map[String, Any]() + tokenInfoMap(VersionKey) = CurrentVersion + tokenInfoMap(OwnerKey) = Sanitizer.sanitize(tokenInfo.ownerAsString) + tokenInfoMap(RenewersKey) = tokenInfo.renewersAsString.asScala.map(e => Sanitizer.sanitize(e)).asJava + tokenInfoMap(IssueTimestampKey) = tokenInfo.issueTimestamp + tokenInfoMap(MaxTimestampKey) = tokenInfo.maxTimestamp + tokenInfoMap(ExpiryTimestampKey) = tokenInfo.expiryTimestamp + tokenInfoMap(TokenIdKey) = tokenInfo.tokenId() + tokenInfoMap.toMap + } + + def fromBytes(bytes: Array[Byte]): Option[TokenInformation] = { + if (bytes == null || bytes.isEmpty) + return None + + Json.parseBytes(bytes) match { + 
case Some(js) => + val mainJs = js.asJsonObject + require(mainJs(VersionKey).to[Int] == CurrentVersion) + val owner = SecurityUtils.parseKafkaPrincipal(Sanitizer.desanitize(mainJs(OwnerKey).to[String])) + val renewerStr = mainJs(RenewersKey).to[Seq[String]] + val renewers = renewerStr.map(Sanitizer.desanitize(_)).map(SecurityUtils.parseKafkaPrincipal(_)) + val issueTimestamp = mainJs(IssueTimestampKey).to[Long] + val expiryTimestamp = mainJs(ExpiryTimestampKey).to[Long] + val maxTimestamp = mainJs(MaxTimestampKey).to[Long] + val tokenId = mainJs(TokenIdKey).to[String] + + val tokenInfo = new TokenInformation(tokenId, owner, renewers.asJava, + issueTimestamp, maxTimestamp, expiryTimestamp) + + Some(tokenInfo) + case None => + None + } + } + + def filterToken(requestedPrincipal: KafkaPrincipal, owners : Option[List[KafkaPrincipal]], token: TokenInformation, authorizeToken: String => Boolean) : Boolean = { + + val allow = + //exclude tokens which are not requested + if (!owners.isEmpty && !owners.get.exists(owner => token.ownerOrRenewer(owner))) { + false + //Owners and the renewers can describe their own tokens + } else if (token.ownerOrRenewer(requestedPrincipal)) { + true + // Check permission for non-owned tokens + } else if ((authorizeToken(token.tokenId))) { + true + } + else { + false + } + + allow + } +} + +class DelegationTokenManager(val config: KafkaConfig, + val tokenCache: DelegationTokenCache, + val time: Time, + val zkClient: KafkaZkClient) extends Logging with KafkaMetricsGroup { + this.logIdent = s"[Token Manager on Broker ${config.brokerId}]: " + + import DelegationTokenManager._ + + type CreateResponseCallback = CreateTokenResult => Unit + type RenewResponseCallback = (Errors, Long) => Unit + type ExpireResponseCallback = (Errors, Long) => Unit + type DescribeResponseCallback = (Errors, List[DelegationToken]) => Unit + + val secretKey = { + val keyBytes = if (config.tokenAuthEnabled) config.delegationTokenSecretKey.value.getBytes(StandardCharsets.UTF_8) else null + if (keyBytes == null || keyBytes.length == 0) null + else + createSecretKey(keyBytes) + } + + val tokenMaxLifetime: Long = config.delegationTokenMaxLifeMs + val defaultTokenRenewTime: Long = config.delegationTokenExpiryTimeMs + val tokenRemoverScanInterval: Long = config.delegationTokenExpiryCheckIntervalMs + private val lock = new Object() + private var tokenChangeListener: ZkNodeChangeNotificationListener = null + + def startup() = { + if (config.tokenAuthEnabled) { + zkClient.createDelegationTokenPaths() + loadCache() + tokenChangeListener = new ZkNodeChangeNotificationListener(zkClient, DelegationTokenChangeNotificationZNode.path, DelegationTokenChangeNotificationSequenceZNode.SequenceNumberPrefix, TokenChangedNotificationHandler) + tokenChangeListener.init() + } + } + + def shutdown() = { + if (config.tokenAuthEnabled) { + if (tokenChangeListener != null) tokenChangeListener.close() + } + } + + private def loadCache(): Unit = { + lock.synchronized { + val tokens = zkClient.getChildren(DelegationTokensZNode.path) + info(s"Loading the token cache. 
Total token count: ${tokens.size}") + for (tokenId <- tokens) { + try { + getTokenFromZk(tokenId) match { + case Some(token) => updateCache(token) + case None => + } + } catch { + case ex: Throwable => error(s"Error while getting Token for tokenId: $tokenId", ex) + } + } + } + } + + private def getTokenFromZk(tokenId: String): Option[DelegationToken] = { + zkClient.getDelegationTokenInfo(tokenId) match { + case Some(tokenInformation) => { + val hmac = createHmac(tokenId, secretKey) + Some(new DelegationToken(tokenInformation, hmac)) + } + case None => + None + } + } + + /** + * + * @param token + */ + private def updateCache(token: DelegationToken): Unit = { + val hmacString = token.hmacAsBase64String + val scramCredentialMap = prepareScramCredentials(hmacString) + tokenCache.updateCache(token, scramCredentialMap.asJava) + } + /** + * @param hmacString + */ + private def prepareScramCredentials(hmacString: String) : Map[String, ScramCredential] = { + val scramCredentialMap = mutable.Map[String, ScramCredential]() + + def scramCredential(mechanism: ScramMechanism): ScramCredential = { + new ScramFormatter(mechanism).generateCredential(hmacString, mechanism.minIterations) + } + + for (mechanism <- ScramMechanism.values) + scramCredentialMap(mechanism.mechanismName) = scramCredential(mechanism) + + scramCredentialMap.toMap + } + + /** + * + * @param owner + * @param renewers + * @param maxLifeTimeMs + * @param responseCallback + */ + def createToken(owner: KafkaPrincipal, + renewers: List[KafkaPrincipal], + maxLifeTimeMs: Long, + responseCallback: CreateResponseCallback): Unit = { + + if (!config.tokenAuthEnabled) { + responseCallback(CreateTokenResult(-1, -1, -1, "", Array[Byte](), Errors.DELEGATION_TOKEN_AUTH_DISABLED)) + } else { + lock.synchronized { + val tokenId = CoreUtils.generateUuidAsBase64() + + val issueTimeStamp = time.milliseconds + val maxLifeTime = if (maxLifeTimeMs <= 0) tokenMaxLifetime else Math.min(maxLifeTimeMs, tokenMaxLifetime) + val maxLifeTimeStamp = issueTimeStamp + maxLifeTime + val expiryTimeStamp = Math.min(maxLifeTimeStamp, issueTimeStamp + defaultTokenRenewTime) + + val tokenInfo = new TokenInformation(tokenId, owner, renewers.asJava, issueTimeStamp, maxLifeTimeStamp, expiryTimeStamp) + + val hmac = createHmac(tokenId, secretKey) + val token = new DelegationToken(tokenInfo, hmac) + updateToken(token) + info(s"Created a delegation token: $tokenId for owner: $owner") + responseCallback(CreateTokenResult(issueTimeStamp, expiryTimeStamp, maxLifeTimeStamp, tokenId, hmac, Errors.NONE)) + } + } + } + + /** + * + * @param principal + * @param hmac + * @param renewLifeTimeMs + * @param renewCallback + */ + def renewToken(principal: KafkaPrincipal, + hmac: ByteBuffer, + renewLifeTimeMs: Long, + renewCallback: RenewResponseCallback): Unit = { + + if (!config.tokenAuthEnabled) { + renewCallback(Errors.DELEGATION_TOKEN_AUTH_DISABLED, -1) + } else { + lock.synchronized { + getToken(hmac) match { + case Some(token) => { + val now = time.milliseconds + val tokenInfo = token.tokenInfo + + if (!allowedToRenew(principal, tokenInfo)) { + renewCallback(Errors.DELEGATION_TOKEN_OWNER_MISMATCH, -1) + } else if (tokenInfo.maxTimestamp < now || tokenInfo.expiryTimestamp < now) { + renewCallback(Errors.DELEGATION_TOKEN_EXPIRED, -1) + } else { + val renewLifeTime = if (renewLifeTimeMs < 0) defaultTokenRenewTime else renewLifeTimeMs + val renewTimeStamp = now + renewLifeTime + val expiryTimeStamp = Math.min(tokenInfo.maxTimestamp, renewTimeStamp) + 
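+            // The renewed expiry is capped by the token's max lifetime: e.g. a 24h renewal requested
+            // with only 2h left until maxTimestamp results in an expiry equal to maxTimestamp.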
tokenInfo.setExpiryTimestamp(expiryTimeStamp) + + updateToken(token) + info(s"Delegation token renewed for token: ${tokenInfo.tokenId} for owner: ${tokenInfo.owner}") + renewCallback(Errors.NONE, expiryTimeStamp) + } + } + case None => renewCallback(Errors.DELEGATION_TOKEN_NOT_FOUND, -1) + } + } + } + } + + /** + * @param token + */ + private def updateToken(token: DelegationToken): Unit = { + zkClient.setOrCreateDelegationToken(token) + updateCache(token) + zkClient.createTokenChangeNotification(token.tokenInfo.tokenId()) + } + + /** + * + * @param hmac + * @return + */ + private def getToken(hmac: ByteBuffer): Option[DelegationToken] = { + try { + val byteArray = new Array[Byte](hmac.remaining) + hmac.get(byteArray) + val base64Pwd = Base64.getEncoder.encodeToString(byteArray) + val tokenInfo = tokenCache.tokenForHmac(base64Pwd) + if (tokenInfo == null) None else Some(new DelegationToken(tokenInfo, byteArray)) + } catch { + case e: Exception => + error("Exception while getting token for hmac", e) + None + } + } + + /** + * + * @param principal + * @param tokenInfo + * @return + */ + private def allowedToRenew(principal: KafkaPrincipal, tokenInfo: TokenInformation): Boolean = { + if (principal.equals(tokenInfo.owner) || tokenInfo.renewers.asScala.toList.contains(principal)) true else false + } + + /** + * + * @param tokenId + * @return + */ + def getToken(tokenId: String): Option[DelegationToken] = { + val tokenInfo = tokenCache.token(tokenId) + if (tokenInfo != null) Some(getToken(tokenInfo)) else None + } + + /** + * + * @param tokenInfo + * @return + */ + private def getToken(tokenInfo: TokenInformation): DelegationToken = { + val hmac = createHmac(tokenInfo.tokenId, secretKey) + new DelegationToken(tokenInfo, hmac) + } + + /** + * + * @param principal + * @param hmac + * @param expireLifeTimeMs + * @param expireResponseCallback + */ + def expireToken(principal: KafkaPrincipal, + hmac: ByteBuffer, + expireLifeTimeMs: Long, + expireResponseCallback: ExpireResponseCallback): Unit = { + + if (!config.tokenAuthEnabled) { + expireResponseCallback(Errors.DELEGATION_TOKEN_AUTH_DISABLED, -1) + } else { + lock.synchronized { + getToken(hmac) match { + case Some(token) => { + val tokenInfo = token.tokenInfo + val now = time.milliseconds + + if (!allowedToRenew(principal, tokenInfo)) { + expireResponseCallback(Errors.DELEGATION_TOKEN_OWNER_MISMATCH, -1) + } else if (tokenInfo.maxTimestamp < now || tokenInfo.expiryTimestamp < now) { + expireResponseCallback(Errors.DELEGATION_TOKEN_EXPIRED, -1) + } else if (expireLifeTimeMs < 0) { //expire immediately + removeToken(tokenInfo.tokenId) + info(s"Token expired for token: ${tokenInfo.tokenId} for owner: ${tokenInfo.owner}") + expireResponseCallback(Errors.NONE, now) + } else { + //set expiry time stamp + val expiryTimeStamp = Math.min(tokenInfo.maxTimestamp, now + expireLifeTimeMs) + tokenInfo.setExpiryTimestamp(expiryTimeStamp) + + updateToken(token) + info(s"Updated expiry time for token: ${tokenInfo.tokenId} for owner: ${tokenInfo.owner}") + expireResponseCallback(Errors.NONE, expiryTimeStamp) + } + } + case None => expireResponseCallback(Errors.DELEGATION_TOKEN_NOT_FOUND, -1) + } + } + } + } + + /** + * + * @param tokenId + */ + private def removeToken(tokenId: String): Unit = { + zkClient.deleteDelegationToken(tokenId) + removeCache(tokenId) + zkClient.createTokenChangeNotification(tokenId) + } + + /** + * + * @param tokenId + */ + private def removeCache(tokenId: String): Unit = { + tokenCache.removeCache(tokenId) + } + + /** + * + * @return + */ 
+ def expireTokens(): Unit = { + lock.synchronized { + for (tokenInfo <- getAllTokenInformation) { + val now = time.milliseconds + if (tokenInfo.maxTimestamp < now || tokenInfo.expiryTimestamp < now) { + info(s"Delegation token expired for token: ${tokenInfo.tokenId} for owner: ${tokenInfo.owner}") + removeToken(tokenInfo.tokenId) + } + } + } + } + + def getAllTokenInformation: List[TokenInformation] = tokenCache.tokens.asScala.toList + + def getTokens(filterToken: TokenInformation => Boolean): List[DelegationToken] = { + getAllTokenInformation.filter(filterToken).map(token => getToken(token)) + } + + object TokenChangedNotificationHandler extends NotificationHandler { + override def processNotification(tokenIdBytes: Array[Byte]): Unit = { + lock.synchronized { + val tokenId = new String(tokenIdBytes, StandardCharsets.UTF_8) + info(s"Processing Token Notification for tokenId: $tokenId") + getTokenFromZk(tokenId) match { + case Some(token) => updateCache(token) + case None => removeCache(tokenId) + } + } + } + } + +} + +case class CreateTokenResult(issueTimestamp: Long, + expiryTimestamp: Long, + maxTimestamp: Long, + tokenId: String, + hmac: Array[Byte], + error: Errors) { + + override def equals(other: Any): Boolean = { + other match { + case that: CreateTokenResult => + error.equals(that.error) && + tokenId.equals(that.tokenId) && + issueTimestamp.equals(that.issueTimestamp) && + expiryTimestamp.equals(that.expiryTimestamp) && + maxTimestamp.equals(that.maxTimestamp) && + (hmac sameElements that.hmac) + case _ => false + } + } + + override def hashCode(): Int = { + val fields = Seq(issueTimestamp, expiryTimestamp, maxTimestamp, tokenId, hmac, error) + fields.map(_.hashCode()).foldLeft(0)((a, b) => 31 * a + b) + } +} diff --git a/core/src/main/scala/kafka/server/DynamicBrokerConfig.scala b/core/src/main/scala/kafka/server/DynamicBrokerConfig.scala new file mode 100755 index 0000000..11d6970 --- /dev/null +++ b/core/src/main/scala/kafka/server/DynamicBrokerConfig.scala @@ -0,0 +1,961 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package kafka.server + +import java.util +import java.util.{Collections, Properties} +import java.util.concurrent.locks.ReentrantReadWriteLock +import kafka.cluster.EndPoint +import kafka.log.{LogCleaner, LogConfig, LogManager} +import kafka.network.SocketServer +import kafka.server.DynamicBrokerConfig._ +import kafka.utils.{CoreUtils, Logging, PasswordEncoder} +import kafka.utils.Implicits._ +import kafka.zk.{AdminZkClient, KafkaZkClient} +import org.apache.kafka.common.Reconfigurable +import org.apache.kafka.common.config.{AbstractConfig, ConfigDef, ConfigException, SslConfigs} +import org.apache.kafka.common.metrics.MetricsReporter +import org.apache.kafka.common.config.types.Password +import org.apache.kafka.common.network.{ListenerName, ListenerReconfigurable} +import org.apache.kafka.common.security.authenticator.LoginManager +import org.apache.kafka.common.utils.{ConfigUtils, Utils} + +import scala.annotation.nowarn +import scala.collection._ +import scala.jdk.CollectionConverters._ + +/** + * Dynamic broker configurations are stored in ZooKeeper and may be defined at two levels: + *
+ * - Per-broker configs persisted at /configs/brokers/{brokerId}: These can be described/altered
+ *   using AdminClient using the resource name brokerId.
+ * - Cluster-wide defaults persisted at /configs/brokers/<default>: These can be described/altered
+ *   using AdminClient using an empty resource name.
+ *
+ * The order of precedence for broker configs is:
+ *
+ * 1. DYNAMIC_BROKER_CONFIG: stored in ZK at /configs/brokers/{brokerId}
+ * 2. DYNAMIC_DEFAULT_BROKER_CONFIG: stored in ZK at /configs/brokers/<default>
+ * 3. STATIC_BROKER_CONFIG: properties that the broker is started up with, typically from the server.properties file
+ * 4. DEFAULT_CONFIG: default configs defined in KafkaConfig
+ *
+ * Log configs use topic config overrides if defined and fall back to broker defaults using the order of precedence above.
+ * Topic config overrides may use a different config name from the default broker config.
+ * See [[kafka.log.LogConfig#TopicConfigSynonyms]] for the mapping.
+ *
+ * AdminClient returns all config synonyms in the order of precedence when configs are described with
+ * includeSynonyms. In addition to configs that may be defined with the same name at different levels,
+ * some configs have additional synonyms.
+ *
+ * - Listener configs may be defined using the prefix listener.name.{listenerName}.{configName}. These may be
+ *   configured as dynamic or static broker configs. Listener configs have higher precedence than the base configs
+ *   that don't specify the listener name. Listeners without a listener config use the base config. Base configs
+ *   may be defined only as STATIC_BROKER_CONFIG or DEFAULT_CONFIG and cannot be updated dynamically.
+ * - Some configs may be defined using multiple properties. For example, log.roll.ms and
+ *   log.roll.hours refer to the same config that may be defined in milliseconds or hours. The order of
+ *   precedence of these synonyms is described in the docs of these configs in [[kafka.server.KafkaConfig]].
            + * + */ +object DynamicBrokerConfig { + + private[server] val DynamicSecurityConfigs = SslConfigs.RECONFIGURABLE_CONFIGS.asScala + + val AllDynamicConfigs = DynamicSecurityConfigs ++ + LogCleaner.ReconfigurableConfigs ++ + DynamicLogConfig.ReconfigurableConfigs ++ + DynamicThreadPool.ReconfigurableConfigs ++ + Set(KafkaConfig.MetricReporterClassesProp) ++ + DynamicListenerConfig.ReconfigurableConfigs ++ + SocketServer.ReconfigurableConfigs + + private val ClusterLevelListenerConfigs = Set(KafkaConfig.MaxConnectionsProp, KafkaConfig.MaxConnectionCreationRateProp) + private val PerBrokerConfigs = (DynamicSecurityConfigs ++ DynamicListenerConfig.ReconfigurableConfigs).diff( + ClusterLevelListenerConfigs) + private val ListenerMechanismConfigs = Set(KafkaConfig.SaslJaasConfigProp, + KafkaConfig.SaslLoginCallbackHandlerClassProp, + KafkaConfig.SaslLoginClassProp, + KafkaConfig.SaslServerCallbackHandlerClassProp, + KafkaConfig.ConnectionsMaxReauthMsProp) + + private val ReloadableFileConfigs = Set(SslConfigs.SSL_KEYSTORE_LOCATION_CONFIG, SslConfigs.SSL_TRUSTSTORE_LOCATION_CONFIG) + + val ListenerConfigRegex = """listener\.name\.[^.]*\.(.*)""".r + + private val DynamicPasswordConfigs = { + val passwordConfigs = KafkaConfig.configKeys.filter(_._2.`type` == ConfigDef.Type.PASSWORD).keySet + AllDynamicConfigs.intersect(passwordConfigs) + } + + def isPasswordConfig(name: String): Boolean = DynamicBrokerConfig.DynamicPasswordConfigs.exists(name.endsWith) + + def brokerConfigSynonyms(name: String, matchListenerOverride: Boolean): List[String] = { + name match { + case KafkaConfig.LogRollTimeMillisProp | KafkaConfig.LogRollTimeHoursProp => + List(KafkaConfig.LogRollTimeMillisProp, KafkaConfig.LogRollTimeHoursProp) + case KafkaConfig.LogRollTimeJitterMillisProp | KafkaConfig.LogRollTimeJitterHoursProp => + List(KafkaConfig.LogRollTimeJitterMillisProp, KafkaConfig.LogRollTimeJitterHoursProp) + case KafkaConfig.LogFlushIntervalMsProp => // LogFlushSchedulerIntervalMsProp is used as default + List(KafkaConfig.LogFlushIntervalMsProp, KafkaConfig.LogFlushSchedulerIntervalMsProp) + case KafkaConfig.LogRetentionTimeMillisProp | KafkaConfig.LogRetentionTimeMinutesProp | KafkaConfig.LogRetentionTimeHoursProp => + List(KafkaConfig.LogRetentionTimeMillisProp, KafkaConfig.LogRetentionTimeMinutesProp, KafkaConfig.LogRetentionTimeHoursProp) + case ListenerConfigRegex(baseName) if matchListenerOverride => + // `ListenerMechanismConfigs` are specified as listenerPrefix.mechanism. + // and other listener configs are specified as listenerPrefix. + // Add as a synonym in both cases. 
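+        // For example (assuming a listener named "internal"):
+        //   brokerConfigSynonyms("listener.name.internal.ssl.keystore.location", matchListenerOverride = true)
+        //     == List("listener.name.internal.ssl.keystore.location", "ssl.keystore.location")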
+ val mechanismConfig = ListenerMechanismConfigs.find(baseName.endsWith) + List(name, mechanismConfig.getOrElse(baseName)) + case _ => List(name) + } + } + + def validateConfigs(props: Properties, perBrokerConfig: Boolean): Unit = { + def checkInvalidProps(invalidPropNames: Set[String], errorMessage: String): Unit = { + if (invalidPropNames.nonEmpty) + throw new ConfigException(s"$errorMessage: $invalidPropNames") + } + checkInvalidProps(nonDynamicConfigs(props), "Cannot update these configs dynamically") + checkInvalidProps(securityConfigsWithoutListenerPrefix(props), + "These security configs can be dynamically updated only per-listener using the listener prefix") + validateConfigTypes(props) + if (!perBrokerConfig) { + checkInvalidProps(perBrokerConfigs(props), + "Cannot update these configs at default cluster level, broker id must be specified") + } + } + + private def perBrokerConfigs(props: Properties): Set[String] = { + val configNames = props.asScala.keySet + def perBrokerListenerConfig(name: String): Boolean = { + name match { + case ListenerConfigRegex(baseName) => !ClusterLevelListenerConfigs.contains(baseName) + case _ => false + } + } + configNames.intersect(PerBrokerConfigs) ++ configNames.filter(perBrokerListenerConfig) + } + + private def nonDynamicConfigs(props: Properties): Set[String] = { + props.asScala.keySet.intersect(DynamicConfig.Broker.nonDynamicProps) + } + + private def securityConfigsWithoutListenerPrefix(props: Properties): Set[String] = { + DynamicSecurityConfigs.filter(props.containsKey) + } + + private def validateConfigTypes(props: Properties): Unit = { + val baseProps = new Properties + props.asScala.foreach { + case (ListenerConfigRegex(baseName), v) => baseProps.put(baseName, v) + case (k, v) => baseProps.put(k, v) + } + DynamicConfig.Broker.validate(baseProps) + } + + private[server] def addDynamicConfigs(configDef: ConfigDef): Unit = { + KafkaConfig.configKeys.forKeyValue { (configName, config) => + if (AllDynamicConfigs.contains(configName)) { + configDef.define(config.name, config.`type`, config.defaultValue, config.validator, + config.importance, config.documentation, config.group, config.orderInGroup, config.width, + config.displayName, config.dependents, config.recommender) + } + } + } + + private[server] def dynamicConfigUpdateModes: util.Map[String, String] = { + AllDynamicConfigs.map { name => + val mode = if (PerBrokerConfigs.contains(name)) "per-broker" else "cluster-wide" + (name -> mode) + }.toMap.asJava + } + + private[server] def resolveVariableConfigs(propsOriginal: Properties): Properties = { + val props = new Properties + val config = new AbstractConfig(new ConfigDef(), propsOriginal, false) + config.originals.asScala.filter(!_._1.startsWith(AbstractConfig.CONFIG_PROVIDERS_CONFIG)).foreach {case (key: String, value: Object) => { + props.put(key, value) + }} + props + } +} + +class DynamicBrokerConfig(private val kafkaConfig: KafkaConfig) extends Logging { + + private[server] val staticBrokerConfigs = ConfigDef.convertToStringMapWithPasswordValues(kafkaConfig.originalsFromThisConfig).asScala + private[server] val staticDefaultConfigs = ConfigDef.convertToStringMapWithPasswordValues(KafkaConfig.defaultValues.asJava).asScala + private val dynamicBrokerConfigs = mutable.Map[String, String]() + private val dynamicDefaultConfigs = mutable.Map[String, String]() + private val reconfigurables = mutable.Buffer[Reconfigurable]() + private val brokerReconfigurables = mutable.Buffer[BrokerReconfigurable]() + private val lock = new 
ReentrantReadWriteLock + private var currentConfig: KafkaConfig = null + private val dynamicConfigPasswordEncoder = maybeCreatePasswordEncoder(kafkaConfig.passwordEncoderSecret) + + private[server] def initialize(zkClientOpt: Option[KafkaZkClient]): Unit = { + currentConfig = new KafkaConfig(kafkaConfig.props, false, None) + + zkClientOpt.foreach { zkClient => + val adminZkClient = new AdminZkClient(zkClient) + updateDefaultConfig(adminZkClient.fetchEntityConfig(ConfigType.Broker, ConfigEntityName.Default)) + val props = adminZkClient.fetchEntityConfig(ConfigType.Broker, kafkaConfig.brokerId.toString) + val brokerConfig = maybeReEncodePasswords(props, adminZkClient) + updateBrokerConfig(kafkaConfig.brokerId, brokerConfig) + } + } + + /** + * Clear all cached values. This is used to clear state on broker shutdown to avoid + * exceptions in tests when broker is restarted. These fields are re-initialized when + * broker starts up. + */ + private[server] def clear(): Unit = { + dynamicBrokerConfigs.clear() + dynamicDefaultConfigs.clear() + reconfigurables.clear() + brokerReconfigurables.clear() + } + + /** + * Add reconfigurables to be notified when a dynamic broker config is updated. + * + * `Reconfigurable` is the public API used by configurable plugins like metrics reporter + * and quota callbacks. These are reconfigured before `KafkaConfig` is updated so that + * the update can be aborted if `reconfigure()` fails with an exception. + * + * `BrokerReconfigurable` is used for internal reconfigurable classes. These are + * reconfigured after `KafkaConfig` is updated so that they can access `KafkaConfig` + * directly. They are provided both old and new configs. + */ + def addReconfigurables(kafkaServer: KafkaBroker): Unit = { + kafkaServer.authorizer match { + case Some(authz: Reconfigurable) => addReconfigurable(authz) + case _ => + } + addReconfigurable(kafkaServer.kafkaYammerMetrics) + addReconfigurable(new DynamicMetricsReporters(kafkaConfig.brokerId, kafkaServer)) + addReconfigurable(new DynamicClientQuotaCallback(kafkaConfig.brokerId, kafkaServer)) + + addBrokerReconfigurable(new DynamicThreadPool(kafkaServer)) + if (kafkaServer.logManager.cleaner != null) + addBrokerReconfigurable(kafkaServer.logManager.cleaner) + addBrokerReconfigurable(new DynamicLogConfig(kafkaServer.logManager, kafkaServer)) + addBrokerReconfigurable(new DynamicListenerConfig(kafkaServer)) + addBrokerReconfigurable(kafkaServer.socketServer) + } + + def addReconfigurable(reconfigurable: Reconfigurable): Unit = CoreUtils.inWriteLock(lock) { + verifyReconfigurableConfigs(reconfigurable.reconfigurableConfigs.asScala) + reconfigurables += reconfigurable + } + + def addBrokerReconfigurable(reconfigurable: BrokerReconfigurable): Unit = CoreUtils.inWriteLock(lock) { + verifyReconfigurableConfigs(reconfigurable.reconfigurableConfigs) + brokerReconfigurables += reconfigurable + } + + def removeReconfigurable(reconfigurable: Reconfigurable): Unit = CoreUtils.inWriteLock(lock) { + reconfigurables -= reconfigurable + } + + private def verifyReconfigurableConfigs(configNames: Set[String]): Unit = CoreUtils.inWriteLock(lock) { + val nonDynamic = configNames.filter(DynamicConfig.Broker.nonDynamicProps.contains) + require(nonDynamic.isEmpty, s"Reconfigurable contains non-dynamic configs $nonDynamic") + } + + // Visibility for testing + private[server] def currentKafkaConfig: KafkaConfig = CoreUtils.inReadLock(lock) { + currentConfig + } + + private[server] def currentDynamicBrokerConfigs: Map[String, String] = 
CoreUtils.inReadLock(lock) { + dynamicBrokerConfigs.clone() + } + + private[server] def currentDynamicDefaultConfigs: Map[String, String] = CoreUtils.inReadLock(lock) { + dynamicDefaultConfigs.clone() + } + + private[server] def updateBrokerConfig(brokerId: Int, persistentProps: Properties): Unit = CoreUtils.inWriteLock(lock) { + try { + val props = fromPersistentProps(persistentProps, perBrokerConfig = true) + dynamicBrokerConfigs.clear() + dynamicBrokerConfigs ++= props.asScala + updateCurrentConfig() + } catch { + case e: Exception => error(s"Per-broker configs of $brokerId could not be applied: ${persistentProps.keys()}", e) + } + } + + private[server] def updateDefaultConfig(persistentProps: Properties): Unit = CoreUtils.inWriteLock(lock) { + try { + val props = fromPersistentProps(persistentProps, perBrokerConfig = false) + dynamicDefaultConfigs.clear() + dynamicDefaultConfigs ++= props.asScala + updateCurrentConfig() + } catch { + case e: Exception => error(s"Cluster default configs could not be applied: ${persistentProps.keys()}", e) + } + } + + /** + * All config updates through ZooKeeper are triggered through actual changes in values stored in ZooKeeper. + * For some configs like SSL keystores and truststores, we also want to reload the store if it was modified + * in-place, even though the actual value of the file path and password haven't changed. This scenario alone + * is handled here when a config update request using admin client is processed by ZkAdminManager. If any of + * the SSL configs have changed, then the update will not be done here, but will be handled later when ZK + * changes are processed. At the moment, only listener configs are considered for reloading. + */ + private[server] def reloadUpdatedFilesWithoutConfigChange(newProps: Properties): Unit = CoreUtils.inWriteLock(lock) { + reconfigurables + .filter(reconfigurable => ReloadableFileConfigs.exists(reconfigurable.reconfigurableConfigs.contains)) + .foreach { + case reconfigurable: ListenerReconfigurable => + val kafkaProps = validatedKafkaProps(newProps, perBrokerConfig = true) + val newConfig = new KafkaConfig(kafkaProps.asJava, false, None) + processListenerReconfigurable(reconfigurable, newConfig, Collections.emptyMap(), validateOnly = false, reloadOnly = true) + case reconfigurable => + trace(s"Files will not be reloaded without config change for $reconfigurable") + } + } + + private def maybeCreatePasswordEncoder(secret: Option[Password]): Option[PasswordEncoder] = { + secret.map { secret => + new PasswordEncoder(secret, + kafkaConfig.passwordEncoderKeyFactoryAlgorithm, + kafkaConfig.passwordEncoderCipherAlgorithm, + kafkaConfig.passwordEncoderKeyLength, + kafkaConfig.passwordEncoderIterations) + } + } + + private def passwordEncoder: PasswordEncoder = { + dynamicConfigPasswordEncoder.getOrElse(throw new ConfigException("Password encoder secret not configured")) + } + + private[server] def toPersistentProps(configProps: Properties, perBrokerConfig: Boolean): Properties = { + val props = configProps.clone().asInstanceOf[Properties] + + def encodePassword(configName: String, value: String): Unit = { + if (value != null) { + if (!perBrokerConfig) + throw new ConfigException("Password config can be defined only at broker level") + props.setProperty(configName, passwordEncoder.encode(new Password(value))) + } + } + configProps.asScala.forKeyValue { (name, value) => + if (isPasswordConfig(name)) + encodePassword(name, value) + } + props + } + + private[server] def fromPersistentProps(persistentProps: 
Properties, + perBrokerConfig: Boolean): Properties = { + val props = persistentProps.clone().asInstanceOf[Properties] + + // Remove all invalid configs from `props` + removeInvalidConfigs(props, perBrokerConfig) + def removeInvalidProps(invalidPropNames: Set[String], errorMessage: String): Unit = { + if (invalidPropNames.nonEmpty) { + invalidPropNames.foreach(props.remove) + error(s"$errorMessage: $invalidPropNames") + } + } + removeInvalidProps(nonDynamicConfigs(props), "Non-dynamic configs configured in ZooKeeper will be ignored") + removeInvalidProps(securityConfigsWithoutListenerPrefix(props), + "Security configs can be dynamically updated only using listener prefix, base configs will be ignored") + if (!perBrokerConfig) + removeInvalidProps(perBrokerConfigs(props), "Per-broker configs defined at default cluster level will be ignored") + + def decodePassword(configName: String, value: String): Unit = { + if (value != null) { + try { + props.setProperty(configName, passwordEncoder.decode(value).value) + } catch { + case e: Exception => + error(s"Dynamic password config $configName could not be decoded, ignoring.", e) + props.remove(configName) + } + } + } + + props.asScala.forKeyValue { (name, value) => + if (isPasswordConfig(name)) + decodePassword(name, value) + } + props + } + + // If the secret has changed, password.encoder.old.secret contains the old secret that was used + // to encode the configs in ZK. Decode passwords using the old secret and update ZK with values + // encoded using the current secret. Ignore any errors during decoding since old secret may not + // have been removed during broker restart. + private def maybeReEncodePasswords(persistentProps: Properties, adminZkClient: AdminZkClient): Properties = { + val props = persistentProps.clone().asInstanceOf[Properties] + if (props.asScala.keySet.exists(isPasswordConfig)) { + maybeCreatePasswordEncoder(kafkaConfig.passwordEncoderOldSecret).foreach { passwordDecoder => + persistentProps.asScala.forKeyValue { (configName, value) => + if (isPasswordConfig(configName) && value != null) { + val decoded = try { + Some(passwordDecoder.decode(value).value) + } catch { + case _: Exception => + debug(s"Dynamic password config $configName could not be decoded using old secret, new secret will be used.") + None + } + decoded.foreach { value => props.put(configName, passwordEncoder.encode(new Password(value))) } + } + } + adminZkClient.changeBrokerConfig(Some(kafkaConfig.brokerId), props) + } + } + props + } + + /** + * Validate the provided configs `propsOverride` and return the full Kafka configs with + * the configured defaults and these overrides. + * + * Note: The caller must acquire the read or write lock before invoking this method. 
+ */ + private def validatedKafkaProps(propsOverride: Properties, perBrokerConfig: Boolean): Map[String, String] = { + val propsResolved = DynamicBrokerConfig.resolveVariableConfigs(propsOverride) + validateConfigs(propsResolved, perBrokerConfig) + val newProps = mutable.Map[String, String]() + newProps ++= staticBrokerConfigs + if (perBrokerConfig) { + overrideProps(newProps, dynamicDefaultConfigs) + overrideProps(newProps, propsResolved.asScala) + } else { + overrideProps(newProps, propsResolved.asScala) + overrideProps(newProps, dynamicBrokerConfigs) + } + newProps + } + + private[server] def validate(props: Properties, perBrokerConfig: Boolean): Unit = CoreUtils.inReadLock(lock) { + val newProps = validatedKafkaProps(props, perBrokerConfig) + processReconfiguration(newProps, validateOnly = true) + } + + private def removeInvalidConfigs(props: Properties, perBrokerConfig: Boolean): Unit = { + try { + validateConfigTypes(props) + props.asScala + } catch { + case e: Exception => + val invalidProps = props.asScala.filter { case (k, v) => + val props1 = new Properties + props1.put(k, v) + try { + validateConfigTypes(props1) + false + } catch { + case _: Exception => true + } + } + invalidProps.keys.foreach(props.remove) + val configSource = if (perBrokerConfig) "broker" else "default cluster" + error(s"Dynamic $configSource config contains invalid values in: ${invalidProps.keys}, these configs will be ignored", e) + } + } + + private[server] def maybeReconfigure(reconfigurable: Reconfigurable, oldConfig: KafkaConfig, newConfig: util.Map[String, _]): Unit = { + if (reconfigurable.reconfigurableConfigs.asScala.exists(key => oldConfig.originals.get(key) != newConfig.get(key))) + reconfigurable.reconfigure(newConfig) + } + + /** + * Returns the change in configurations between the new props and current props by returning a + * map of the changed configs, as well as the set of deleted keys + */ + private def updatedConfigs(newProps: java.util.Map[String, _], currentProps: java.util.Map[String, _]): + (mutable.Map[String, _], Set[String]) = { + val changeMap = newProps.asScala.filter { + case (k, v) => v != currentProps.get(k) + } + val deletedKeySet = currentProps.asScala.filter { + case (k, _) => !newProps.containsKey(k) + }.keySet + (changeMap, deletedKeySet) + } + + /** + * Updates values in `props` with the new values from `propsOverride`. Synonyms of updated configs + * are removed from `props` to ensure that the config with the higher precedence is applied. For example, + * if `log.roll.ms` was defined in server.properties and `log.roll.hours` is configured dynamically, + * `log.roll.hours` from the dynamic configuration will be used and `log.roll.ms` will be removed from + * `props` (even though `log.roll.hours` is secondary to `log.roll.ms`). + */ + private def overrideProps(props: mutable.Map[String, String], propsOverride: mutable.Map[String, String]): Unit = { + propsOverride.forKeyValue { (k, v) => + // Remove synonyms of `k` to ensure the right precedence is applied. But disable `matchListenerOverride` + // so that base configs corresponding to listener configs are not removed. Base configs should not be removed + // since they may be used by other listeners. It is ok to retain them in `props` since base configs cannot be + // dynamically updated and listener-specific configs have the higher precedence. 
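+      // For example, a dynamically configured log.retention.ms removes the statically configured
+      // log.retention.hours/log.retention.minutes synonyms from `props` before being applied.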
+ brokerConfigSynonyms(k, matchListenerOverride = false).foreach(props.remove) + props.put(k, v) + } + } + + private def updateCurrentConfig(): Unit = { + val newProps = mutable.Map[String, String]() + newProps ++= staticBrokerConfigs + overrideProps(newProps, dynamicDefaultConfigs) + overrideProps(newProps, dynamicBrokerConfigs) + + val oldConfig = currentConfig + val (newConfig, brokerReconfigurablesToUpdate) = processReconfiguration(newProps, validateOnly = false) + if (newConfig ne currentConfig) { + currentConfig = newConfig + kafkaConfig.updateCurrentConfig(newConfig) + + // Process BrokerReconfigurable updates after current config is updated + brokerReconfigurablesToUpdate.foreach(_.reconfigure(oldConfig, newConfig)) + } + } + + private def processReconfiguration(newProps: Map[String, String], validateOnly: Boolean): (KafkaConfig, List[BrokerReconfigurable]) = { + val newConfig = new KafkaConfig(newProps.asJava, !validateOnly, None) + val (changeMap, deletedKeySet) = updatedConfigs(newConfig.originalsFromThisConfig, currentConfig.originals) + if (changeMap.nonEmpty || deletedKeySet.nonEmpty) { + try { + val customConfigs = new util.HashMap[String, Object](newConfig.originalsFromThisConfig) // non-Kafka configs + newConfig.valuesFromThisConfig.keySet.forEach(customConfigs.remove(_)) + reconfigurables.foreach { + case listenerReconfigurable: ListenerReconfigurable => + processListenerReconfigurable(listenerReconfigurable, newConfig, customConfigs, validateOnly, reloadOnly = false) + case reconfigurable => + if (needsReconfiguration(reconfigurable.reconfigurableConfigs, changeMap.keySet, deletedKeySet)) + processReconfigurable(reconfigurable, changeMap.keySet, newConfig.valuesFromThisConfig, customConfigs, validateOnly) + } + + // BrokerReconfigurable updates are processed after config is updated. Only do the validation here. 
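+        // Only validate BrokerReconfigurables at this point; they are reconfigured by
+        // updateCurrentConfig() after the new KafkaConfig has been installed as the current config.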
+ val brokerReconfigurablesToUpdate = mutable.Buffer[BrokerReconfigurable]() + brokerReconfigurables.foreach { reconfigurable => + if (needsReconfiguration(reconfigurable.reconfigurableConfigs.asJava, changeMap.keySet, deletedKeySet)) { + reconfigurable.validateReconfiguration(newConfig) + if (!validateOnly) + brokerReconfigurablesToUpdate += reconfigurable + } + } + (newConfig, brokerReconfigurablesToUpdate.toList) + } catch { + case e: Exception => + if (!validateOnly) + error(s"Failed to update broker configuration with configs : " + + s"${ConfigUtils.configMapToRedactedString(newConfig.originalsFromThisConfig, KafkaConfig.configDef)}", e) + throw new ConfigException("Invalid dynamic configuration", e) + } + } + else + (currentConfig, List.empty) + } + + private def needsReconfiguration(reconfigurableConfigs: util.Set[String], updatedKeys: Set[String], deletedKeys: Set[String]): Boolean = { + reconfigurableConfigs.asScala.intersect(updatedKeys).nonEmpty || + reconfigurableConfigs.asScala.intersect(deletedKeys).nonEmpty + } + + private def processListenerReconfigurable(listenerReconfigurable: ListenerReconfigurable, + newConfig: KafkaConfig, + customConfigs: util.Map[String, Object], + validateOnly: Boolean, + reloadOnly: Boolean): Unit = { + val listenerName = listenerReconfigurable.listenerName + val oldValues = currentConfig.valuesWithPrefixOverride(listenerName.configPrefix) + val newValues = newConfig.valuesFromThisConfigWithPrefixOverride(listenerName.configPrefix) + val (changeMap, deletedKeys) = updatedConfigs(newValues, oldValues) + val updatedKeys = changeMap.keySet + val configsChanged = needsReconfiguration(listenerReconfigurable.reconfigurableConfigs, updatedKeys, deletedKeys) + // if `reloadOnly`, reconfigure if configs haven't changed. Otherwise reconfigure if configs have changed + if (reloadOnly != configsChanged) + processReconfigurable(listenerReconfigurable, updatedKeys, newValues, customConfigs, validateOnly) + } + + private def processReconfigurable(reconfigurable: Reconfigurable, + updatedConfigNames: Set[String], + allNewConfigs: util.Map[String, _], + newCustomConfigs: util.Map[String, Object], + validateOnly: Boolean): Unit = { + val newConfigs = new util.HashMap[String, Object] + allNewConfigs.forEach { (k, v) => newConfigs.put(k, v.asInstanceOf[AnyRef]) } + newConfigs.putAll(newCustomConfigs) + try { + reconfigurable.validateReconfiguration(newConfigs) + } catch { + case e: ConfigException => throw e + case _: Exception => + throw new ConfigException(s"Validation of dynamic config update of $updatedConfigNames failed with class ${reconfigurable.getClass}") + } + + if (!validateOnly) { + info(s"Reconfiguring $reconfigurable, updated configs: $updatedConfigNames " + + s"custom configs: ${ConfigUtils.configMapToRedactedString(newCustomConfigs, KafkaConfig.configDef)}") + reconfigurable.reconfigure(newConfigs) + } + } +} + +trait BrokerReconfigurable { + + def reconfigurableConfigs: Set[String] + + def validateReconfiguration(newConfig: KafkaConfig): Unit + + def reconfigure(oldConfig: KafkaConfig, newConfig: KafkaConfig): Unit +} + +object DynamicLogConfig { + // Exclude message.format.version for now since we need to check that the version + // is supported on all brokers in the cluster. 
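+  // message.format.version (LogMessageFormatVersionProp) is deprecated, hence the deprecation
+  // warning suppression on the definition below.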
+ @nowarn("cat=deprecation") + val ExcludedConfigs = Set(KafkaConfig.LogMessageFormatVersionProp) + + val ReconfigurableConfigs = LogConfig.TopicConfigSynonyms.values.toSet -- ExcludedConfigs + val KafkaConfigToLogConfigName = LogConfig.TopicConfigSynonyms.map { case (k, v) => (v, k) } +} + +class DynamicLogConfig(logManager: LogManager, server: KafkaBroker) extends BrokerReconfigurable with Logging { + + override def reconfigurableConfigs: Set[String] = { + DynamicLogConfig.ReconfigurableConfigs + } + + override def validateReconfiguration(newConfig: KafkaConfig): Unit = { + // For update of topic config overrides, only config names and types are validated + // Names and types have already been validated. For consistency with topic config + // validation, no additional validation is performed. + } + + private def updateLogsConfig(newBrokerDefaults: Map[String, Object]): Unit = { + logManager.brokerConfigUpdated() + logManager.allLogs.foreach { log => + val props = mutable.Map.empty[Any, Any] + props ++= newBrokerDefaults + props ++= log.config.originals.asScala.filter { case (k, _) => log.config.overriddenConfigs.contains(k) } + + val logConfig = LogConfig(props.asJava, log.config.overriddenConfigs) + log.updateConfig(logConfig) + } + } + + override def reconfigure(oldConfig: KafkaConfig, newConfig: KafkaConfig): Unit = { + val originalLogConfig = logManager.currentDefaultConfig + val originalUncleanLeaderElectionEnable = originalLogConfig.uncleanLeaderElectionEnable + val newBrokerDefaults = new util.HashMap[String, Object](originalLogConfig.originals) + newConfig.valuesFromThisConfig.forEach { (k, v) => + if (DynamicLogConfig.ReconfigurableConfigs.contains(k)) { + DynamicLogConfig.KafkaConfigToLogConfigName.get(k).foreach { configName => + if (v == null) + newBrokerDefaults.remove(configName) + else + newBrokerDefaults.put(configName, v.asInstanceOf[AnyRef]) + } + } + } + + logManager.reconfigureDefaultLogConfig(LogConfig(newBrokerDefaults)) + + updateLogsConfig(newBrokerDefaults.asScala) + + if (logManager.currentDefaultConfig.uncleanLeaderElectionEnable && !originalUncleanLeaderElectionEnable) { + server match { + case kafkaServer: KafkaServer => kafkaServer.kafkaController.enableDefaultUncleanLeaderElection() + case _ => + } + } + } +} + +object DynamicThreadPool { + val ReconfigurableConfigs = Set( + KafkaConfig.NumIoThreadsProp, + KafkaConfig.NumNetworkThreadsProp, + KafkaConfig.NumReplicaFetchersProp, + KafkaConfig.NumRecoveryThreadsPerDataDirProp, + KafkaConfig.BackgroundThreadsProp) +} + +class DynamicThreadPool(server: KafkaBroker) extends BrokerReconfigurable { + + override def reconfigurableConfigs: Set[String] = { + DynamicThreadPool.ReconfigurableConfigs + } + + override def validateReconfiguration(newConfig: KafkaConfig): Unit = { + newConfig.values.forEach { (k, v) => + if (DynamicThreadPool.ReconfigurableConfigs.contains(k)) { + val newValue = v.asInstanceOf[Int] + val oldValue = currentValue(k) + if (newValue != oldValue) { + val errorMsg = s"Dynamic thread count update validation failed for $k=$v" + if (newValue <= 0) + throw new ConfigException(s"$errorMsg, value should be at least 1") + if (newValue < oldValue / 2) + throw new ConfigException(s"$errorMsg, value should be at least half the current value $oldValue") + if (newValue > oldValue * 2) + throw new ConfigException(s"$errorMsg, value should not be greater than double the current value $oldValue") + } + } + } + } + + override def reconfigure(oldConfig: KafkaConfig, newConfig: KafkaConfig): Unit = { + if 
(newConfig.numIoThreads != oldConfig.numIoThreads) + server.dataPlaneRequestHandlerPool.resizeThreadPool(newConfig.numIoThreads) + if (newConfig.numNetworkThreads != oldConfig.numNetworkThreads) + server.socketServer.resizeThreadPool(oldConfig.numNetworkThreads, newConfig.numNetworkThreads) + if (newConfig.numReplicaFetchers != oldConfig.numReplicaFetchers) + server.replicaManager.resizeFetcherThreadPool(newConfig.numReplicaFetchers) + if (newConfig.numRecoveryThreadsPerDataDir != oldConfig.numRecoveryThreadsPerDataDir) + server.logManager.resizeRecoveryThreadPool(newConfig.numRecoveryThreadsPerDataDir) + if (newConfig.backgroundThreads != oldConfig.backgroundThreads) + server.kafkaScheduler.resizeThreadPool(newConfig.backgroundThreads) + } + + private def currentValue(name: String): Int = { + name match { + case KafkaConfig.NumIoThreadsProp => server.config.numIoThreads + case KafkaConfig.NumNetworkThreadsProp => server.config.numNetworkThreads + case KafkaConfig.NumReplicaFetchersProp => server.config.numReplicaFetchers + case KafkaConfig.NumRecoveryThreadsPerDataDirProp => server.config.numRecoveryThreadsPerDataDir + case KafkaConfig.BackgroundThreadsProp => server.config.backgroundThreads + case n => throw new IllegalStateException(s"Unexpected config $n") + } + } +} + +class DynamicMetricsReporters(brokerId: Int, server: KafkaBroker) extends Reconfigurable { + + private val dynamicConfig = server.config.dynamicConfig + private val metrics = server.metrics + private val propsOverride = Map[String, AnyRef](KafkaConfig.BrokerIdProp -> brokerId.toString) + private val currentReporters = mutable.Map[String, MetricsReporter]() + + createReporters(dynamicConfig.currentKafkaConfig.getList(KafkaConfig.MetricReporterClassesProp), + Collections.emptyMap[String, Object]) + + private[server] def currentMetricsReporters: List[MetricsReporter] = currentReporters.values.toList + + override def configure(configs: util.Map[String, _]): Unit = {} + + override def reconfigurableConfigs(): util.Set[String] = { + val configs = new util.HashSet[String]() + configs.add(KafkaConfig.MetricReporterClassesProp) + currentReporters.values.foreach { + case reporter: Reconfigurable => configs.addAll(reporter.reconfigurableConfigs) + case _ => + } + configs + } + + override def validateReconfiguration(configs: util.Map[String, _]): Unit = { + val updatedMetricsReporters = metricsReporterClasses(configs) + + // Ensure all the reporter classes can be loaded and have a default constructor + updatedMetricsReporters.foreach { className => + val clazz = Utils.loadClass(className, classOf[MetricsReporter]) + clazz.getConstructor() + } + + // Validate the new configuration using every reconfigurable reporter instance that is not being deleted + currentReporters.values.foreach { + case reporter: Reconfigurable => + if (updatedMetricsReporters.contains(reporter.getClass.getName)) + reporter.validateReconfiguration(configs) + case _ => + } + } + + override def reconfigure(configs: util.Map[String, _]): Unit = { + val updatedMetricsReporters = metricsReporterClasses(configs) + val deleted = currentReporters.keySet.toSet -- updatedMetricsReporters + deleted.foreach(removeReporter) + currentReporters.values.foreach { + case reporter: Reconfigurable => dynamicConfig.maybeReconfigure(reporter, dynamicConfig.currentKafkaConfig, configs) + case _ => + } + val added = updatedMetricsReporters.filterNot(currentReporters.keySet) + createReporters(added.asJava, configs) + } + + private def createReporters(reporterClasses: 
util.List[String], + updatedConfigs: util.Map[String, _]): Unit = { + val props = new util.HashMap[String, AnyRef] + updatedConfigs.forEach { (k, v) => props.put(k, v.asInstanceOf[AnyRef]) } + propsOverride.forKeyValue { (k, v) => props.put(k, v) } + val reporters = dynamicConfig.currentKafkaConfig.getConfiguredInstances(reporterClasses, classOf[MetricsReporter], props) + reporters.forEach { reporter => + metrics.addReporter(reporter) + currentReporters += reporter.getClass.getName -> reporter + } + KafkaBroker.notifyClusterListeners(server.clusterId, reporters.asScala) + KafkaBroker.notifyMetricsReporters(server.clusterId, server.config, reporters.asScala) + } + + private def removeReporter(className: String): Unit = { + currentReporters.remove(className).foreach(metrics.removeReporter) + } + + private def metricsReporterClasses(configs: util.Map[String, _]): mutable.Buffer[String] = { + configs.get(KafkaConfig.MetricReporterClassesProp).asInstanceOf[util.List[String]].asScala + } +} +object DynamicListenerConfig { + + val ReconfigurableConfigs = Set( + // Listener configs + KafkaConfig.AdvertisedListenersProp, + KafkaConfig.ListenersProp, + KafkaConfig.ListenerSecurityProtocolMapProp, + + // SSL configs + KafkaConfig.PrincipalBuilderClassProp, + KafkaConfig.SslProtocolProp, + KafkaConfig.SslProviderProp, + KafkaConfig.SslCipherSuitesProp, + KafkaConfig.SslEnabledProtocolsProp, + KafkaConfig.SslKeystoreTypeProp, + KafkaConfig.SslKeystoreLocationProp, + KafkaConfig.SslKeystorePasswordProp, + KafkaConfig.SslKeyPasswordProp, + KafkaConfig.SslTruststoreTypeProp, + KafkaConfig.SslTruststoreLocationProp, + KafkaConfig.SslTruststorePasswordProp, + KafkaConfig.SslKeyManagerAlgorithmProp, + KafkaConfig.SslTrustManagerAlgorithmProp, + KafkaConfig.SslEndpointIdentificationAlgorithmProp, + KafkaConfig.SslSecureRandomImplementationProp, + KafkaConfig.SslClientAuthProp, + KafkaConfig.SslEngineFactoryClassProp, + + // SASL configs + KafkaConfig.SaslMechanismInterBrokerProtocolProp, + KafkaConfig.SaslJaasConfigProp, + KafkaConfig.SaslEnabledMechanismsProp, + KafkaConfig.SaslKerberosServiceNameProp, + KafkaConfig.SaslKerberosKinitCmdProp, + KafkaConfig.SaslKerberosTicketRenewWindowFactorProp, + KafkaConfig.SaslKerberosTicketRenewJitterProp, + KafkaConfig.SaslKerberosMinTimeBeforeReloginProp, + KafkaConfig.SaslKerberosPrincipalToLocalRulesProp, + KafkaConfig.SaslLoginRefreshWindowFactorProp, + KafkaConfig.SaslLoginRefreshWindowJitterProp, + KafkaConfig.SaslLoginRefreshMinPeriodSecondsProp, + KafkaConfig.SaslLoginRefreshBufferSecondsProp, + + // Connection limit configs + KafkaConfig.MaxConnectionsProp, + KafkaConfig.MaxConnectionCreationRateProp + ) +} + +class DynamicClientQuotaCallback(brokerId: Int, server: KafkaBroker) extends Reconfigurable { + + override def configure(configs: util.Map[String, _]): Unit = {} + + override def reconfigurableConfigs(): util.Set[String] = { + val configs = new util.HashSet[String]() + server.quotaManagers.clientQuotaCallback.foreach { + case callback: Reconfigurable => configs.addAll(callback.reconfigurableConfigs) + case _ => + } + configs + } + + override def validateReconfiguration(configs: util.Map[String, _]): Unit = { + server.quotaManagers.clientQuotaCallback.foreach { + case callback: Reconfigurable => callback.validateReconfiguration(configs) + case _ => + } + } + + override def reconfigure(configs: util.Map[String, _]): Unit = { + val config = server.config + server.quotaManagers.clientQuotaCallback.foreach { + case callback: Reconfigurable => + 
config.dynamicConfig.maybeReconfigure(callback, config.dynamicConfig.currentKafkaConfig, configs) + true + case _ => false + } + } +} + +class DynamicListenerConfig(server: KafkaBroker) extends BrokerReconfigurable with Logging { + + override def reconfigurableConfigs: Set[String] = { + DynamicListenerConfig.ReconfigurableConfigs + } + + def validateReconfiguration(newConfig: KafkaConfig): Unit = { + val oldConfig = server.config + if (!oldConfig.requiresZookeeper) { + throw new ConfigException("Dynamic reconfiguration of listeners is not yet supported when using a Raft-based metadata quorum") + } + val newListeners = listenersToMap(newConfig.listeners) + val newAdvertisedListeners = listenersToMap(newConfig.effectiveAdvertisedListeners) + val oldListeners = listenersToMap(oldConfig.listeners) + if (!newAdvertisedListeners.keySet.subsetOf(newListeners.keySet)) + throw new ConfigException(s"Advertised listeners '$newAdvertisedListeners' must be a subset of listeners '$newListeners'") + if (!newListeners.keySet.subsetOf(newConfig.effectiveListenerSecurityProtocolMap.keySet)) + throw new ConfigException(s"Listeners '$newListeners' must be subset of listener map '${newConfig.effectiveListenerSecurityProtocolMap}'") + newListeners.keySet.intersect(oldListeners.keySet).foreach { listenerName => + def immutableListenerConfigs(kafkaConfig: KafkaConfig, prefix: String): Map[String, AnyRef] = { + kafkaConfig.originalsWithPrefix(prefix, true).asScala.filter { case (key, _) => + // skip the reconfigurable configs + !DynamicSecurityConfigs.contains(key) && !SocketServer.ListenerReconfigurableConfigs.contains(key) + } + } + if (immutableListenerConfigs(newConfig, listenerName.configPrefix) != immutableListenerConfigs(oldConfig, listenerName.configPrefix)) + throw new ConfigException(s"Configs cannot be updated dynamically for existing listener $listenerName, " + + "restart broker or create a new listener for update") + if (oldConfig.effectiveListenerSecurityProtocolMap(listenerName) != newConfig.effectiveListenerSecurityProtocolMap(listenerName)) + throw new ConfigException(s"Security protocol cannot be updated for existing listener $listenerName") + } + if (!newAdvertisedListeners.contains(newConfig.interBrokerListenerName)) + throw new ConfigException(s"Advertised listener must be specified for inter-broker listener ${newConfig.interBrokerListenerName}") + } + + def reconfigure(oldConfig: KafkaConfig, newConfig: KafkaConfig): Unit = { + val newListeners = newConfig.listeners + val newListenerMap = listenersToMap(newListeners) + val oldListeners = oldConfig.listeners + val oldListenerMap = listenersToMap(oldListeners) + val listenersRemoved = oldListeners.filterNot(e => newListenerMap.contains(e.listenerName)) + val listenersAdded = newListeners.filterNot(e => oldListenerMap.contains(e.listenerName)) + + // Clear SASL login cache to force re-login + if (listenersAdded.nonEmpty || listenersRemoved.nonEmpty) + LoginManager.closeAll() + + server.socketServer.removeListeners(listenersRemoved) + if (listenersAdded.nonEmpty) + server.socketServer.addListeners(listenersAdded) + + server match { + case kafkaServer: KafkaServer => kafkaServer.kafkaController.updateBrokerInfo(kafkaServer.createBrokerInfo) + case _ => + } + } + + private def listenersToMap(listeners: Seq[EndPoint]): Map[ListenerName, EndPoint] = + listeners.map(e => (e.listenerName, e)).toMap + +} diff --git a/core/src/main/scala/kafka/server/DynamicConfig.scala b/core/src/main/scala/kafka/server/DynamicConfig.scala new file mode 100644 index 
0000000..ddcda03 --- /dev/null +++ b/core/src/main/scala/kafka/server/DynamicConfig.scala @@ -0,0 +1,127 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.server + +import java.net.{InetAddress, UnknownHostException} +import java.util.Properties + +import kafka.log.LogConfig +import org.apache.kafka.common.config.ConfigDef +import org.apache.kafka.common.config.ConfigDef.Importance._ +import org.apache.kafka.common.config.ConfigDef.Range._ +import org.apache.kafka.common.config.ConfigDef.Type._ + +import scala.jdk.CollectionConverters._ + +/** + * Class used to hold dynamic configs. These are configs which have no physical manifestation in the server.properties + * and can only be set dynamically. + */ +object DynamicConfig { + + object Broker { + // Properties + val LeaderReplicationThrottledRateProp = "leader.replication.throttled.rate" + val FollowerReplicationThrottledRateProp = "follower.replication.throttled.rate" + val ReplicaAlterLogDirsIoMaxBytesPerSecondProp = "replica.alter.log.dirs.io.max.bytes.per.second" + + // Defaults + val DefaultReplicationThrottledRate = ReplicationQuotaManagerConfig.QuotaBytesPerSecondDefault + + // Documentation + val LeaderReplicationThrottledRateDoc = "A long representing the upper bound (bytes/sec) on replication traffic for leaders enumerated in the " + + s"property ${LogConfig.LeaderReplicationThrottledReplicasProp} (for each topic). This property can be only set dynamically. It is suggested that the " + + s"limit be kept above 1MB/s for accurate behaviour." + val FollowerReplicationThrottledRateDoc = "A long representing the upper bound (bytes/sec) on replication traffic for followers enumerated in the " + + s"property ${LogConfig.FollowerReplicationThrottledReplicasProp} (for each topic). This property can be only set dynamically. It is suggested that the " + + s"limit be kept above 1MB/s for accurate behaviour." + val ReplicaAlterLogDirsIoMaxBytesPerSecondDoc = "A long representing the upper bound (bytes/sec) on disk IO used for moving replica between log directories on the same broker. " + + s"This property can be only set dynamically. It is suggested that the limit be kept above 1MB/s for accurate behaviour." + + // Definitions + val brokerConfigDef = new ConfigDef() + // Round minimum value down, to make it easier for users. 
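+    // These throttle rates can only be set dynamically, e.g. (illustrative command):
+    //   kafka-configs.sh --bootstrap-server localhost:9092 --alter --entity-type brokers --entity-name 0 \
+    //     --add-config leader.replication.throttled.rate=10485760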
+ .define(LeaderReplicationThrottledRateProp, LONG, DefaultReplicationThrottledRate, atLeast(0), MEDIUM, LeaderReplicationThrottledRateDoc) + .define(FollowerReplicationThrottledRateProp, LONG, DefaultReplicationThrottledRate, atLeast(0), MEDIUM, FollowerReplicationThrottledRateDoc) + .define(ReplicaAlterLogDirsIoMaxBytesPerSecondProp, LONG, DefaultReplicationThrottledRate, atLeast(0), MEDIUM, ReplicaAlterLogDirsIoMaxBytesPerSecondDoc) + DynamicBrokerConfig.addDynamicConfigs(brokerConfigDef) + val nonDynamicProps = KafkaConfig.configNames.toSet -- brokerConfigDef.names.asScala + + def names = brokerConfigDef.names + + def validate(props: Properties) = DynamicConfig.validate(brokerConfigDef, props, customPropsAllowed = true) + } + + object QuotaConfigs { + def isClientOrUserQuotaConfig(name: String): Boolean = org.apache.kafka.common.config.internals.QuotaConfigs.isClientOrUserConfig(name) + } + + object Client { + private val clientConfigs = org.apache.kafka.common.config.internals.QuotaConfigs.clientConfigs() + + def configKeys = clientConfigs.configKeys + + def names = clientConfigs.names + + def validate(props: Properties) = DynamicConfig.validate(clientConfigs, props, customPropsAllowed = false) + } + + object User { + private val userConfigs = org.apache.kafka.common.config.internals.QuotaConfigs.userConfigs() + + def configKeys = userConfigs.configKeys + + def names = userConfigs.names + + def validate(props: Properties) = DynamicConfig.validate(userConfigs, props, customPropsAllowed = false) + } + + object Ip { + private val ipConfigs = org.apache.kafka.common.config.internals.QuotaConfigs.ipConfigs() + + def configKeys = ipConfigs.configKeys + + def names = ipConfigs.names + + def validate(props: Properties) = DynamicConfig.validate(ipConfigs, props, customPropsAllowed = false) + + def isValidIpEntity(ip: String): Boolean = { + if (ip != ConfigEntityName.Default) { + try { + InetAddress.getByName(ip) + } catch { + case _: UnknownHostException => return false + } + } + true + } + } + + private def validate(configDef: ConfigDef, props: Properties, customPropsAllowed: Boolean) = { + // Validate Names + val names = configDef.names() + val propKeys = props.keySet.asScala.map(_.asInstanceOf[String]) + if (!customPropsAllowed) { + val unknownKeys = propKeys.filter(!names.contains(_)) + require(unknownKeys.isEmpty, s"Unknown Dynamic Configuration: $unknownKeys.") + } + val propResolved = DynamicBrokerConfig.resolveVariableConfigs(props) + // ValidateValues + configDef.parse(propResolved) + } +} diff --git a/core/src/main/scala/kafka/server/DynamicConfigManager.scala b/core/src/main/scala/kafka/server/DynamicConfigManager.scala new file mode 100644 index 0000000..3eed382 --- /dev/null +++ b/core/src/main/scala/kafka/server/DynamicConfigManager.scala @@ -0,0 +1,184 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.server + +import java.nio.charset.StandardCharsets + +import kafka.common.{NotificationHandler, ZkNodeChangeNotificationListener} +import kafka.utils.{Json, Logging} +import kafka.utils.json.JsonObject +import kafka.zk.{AdminZkClient, ConfigEntityChangeNotificationSequenceZNode, ConfigEntityChangeNotificationZNode, KafkaZkClient} +import org.apache.kafka.common.config.types.Password +import org.apache.kafka.common.security.scram.internals.ScramMechanism +import org.apache.kafka.common.utils.Time + +import scala.jdk.CollectionConverters._ +import scala.collection._ + +/** + * Represents all the entities that can be configured via ZK + */ +object ConfigType { + val Topic = "topics" + val Client = "clients" + val User = "users" + val Broker = "brokers" + val Ip = "ips" + val all = Seq(Topic, Client, User, Broker, Ip) +} + +object ConfigEntityName { + val Default = "" +} + +/** + * This class initiates and carries out config changes for all entities defined in ConfigType. + * + * It works as follows. + * + * Config is stored under the path: /config/entityType/entityName + * E.g. /config/topics/ and /config/clients/ + * This znode stores the overrides for this entity in properties format with defaults stored using entityName "". + * Multiple entity names may be specified (eg. quotas) using a hierarchical path: + * E.g. /config/users//clients/ + * + * To avoid watching all topics for changes instead we have a notification path + * /config/changes + * The DynamicConfigManager has a child watch on this path. + * + * To update a config we first update the config properties. Then we create a new sequential + * znode under the change path which contains the name of the entityType and entityName that was updated, say + * /config/changes/config_change_13321 + * The sequential znode contains data in this format: {"version" : 1, "entity_type":"topic/client", "entity_name" : "topic_name/client_id"} + * This is just a notification--the actual config change is stored only once under the /config/entityType/entityName path. + * Version 2 of notifications has the format: {"version" : 2, "entity_path":"entity_type/entity_name"} + * Multiple entities may be specified as a hierarchical path (eg. users//clients/). + * + * This will fire a watcher on all brokers. This watcher works as follows. It reads all the config change notifications. + * It keeps track of the highest config change suffix number it has applied previously. For any previously applied change it finds + * it checks if this notification is larger than a static expiration time (say 10mins) and if so it deletes this notification. + * For any new changes it reads the new configuration, combines it with the defaults, and updates the existing config. + * + * Note that config is always read from the config path in zk, the notification is just a trigger to do so. So if a broker is + * down and misses a change that is fine--when it restarts it will be loading the full config anyway. Note also that + * if there are two consecutive config changes it is possible that only the last one will be applied (since by the time the + * broker reads the config the both changes may have been made). In this case the broker would needlessly refresh the config twice, + * but that is harmless. + * + * On restart the config manager re-processes all notifications. 
This will usually be wasted work, but avoids any race conditions + * on startup where a change might be missed between the initial config load and registering for change notifications. + * + */ +class DynamicConfigManager(private val zkClient: KafkaZkClient, + private val configHandlers: Map[String, ConfigHandler], + private val changeExpirationMs: Long = 15*60*1000, + private val time: Time = Time.SYSTEM) extends Logging { + val adminZkClient = new AdminZkClient(zkClient) + + object ConfigChangedNotificationHandler extends NotificationHandler { + override def processNotification(jsonBytes: Array[Byte]) = { + // Ignore non-json notifications because they can be from the deprecated TopicConfigManager + Json.parseBytes(jsonBytes).foreach { js => + val jsObject = js.asJsonObjectOption.getOrElse { + throw new IllegalArgumentException("Config change notification has an unexpected value. The format is:" + + """{"version" : 1, "entity_type":"topics/clients", "entity_name" : "topic_name/client_id"} or """ + + """{"version" : 2, "entity_path":"entity_type/entity_name"}. """ + + s"Received: ${new String(jsonBytes, StandardCharsets.UTF_8)}") + } + jsObject("version").to[Int] match { + case 1 => processEntityConfigChangeVersion1(jsonBytes, jsObject) + case 2 => processEntityConfigChangeVersion2(jsonBytes, jsObject) + case version => throw new IllegalArgumentException("Config change notification has unsupported version " + + s"'$version', supported versions are 1 and 2.") + } + } + } + + private def processEntityConfigChangeVersion1(jsonBytes: Array[Byte], js: JsonObject): Unit = { + val validConfigTypes = Set(ConfigType.Topic, ConfigType.Client) + val entityType = js.get("entity_type").flatMap(_.to[Option[String]]).filter(validConfigTypes).getOrElse { + throw new IllegalArgumentException("Version 1 config change notification must have 'entity_type' set to " + + s"'clients' or 'topics'. Received: ${new String(jsonBytes, StandardCharsets.UTF_8)}") + } + + val entity = js.get("entity_name").flatMap(_.to[Option[String]]).getOrElse { + throw new IllegalArgumentException("Version 1 config change notification does not specify 'entity_name'. " + + s"Received: ${new String(jsonBytes, StandardCharsets.UTF_8)}") + } + + val entityConfig = adminZkClient.fetchEntityConfig(entityType, entity) + info(s"Processing override for entityType: $entityType, entity: $entity with config: $entityConfig") + configHandlers(entityType).processConfigChanges(entity, entityConfig) + + } + + private def processEntityConfigChangeVersion2(jsonBytes: Array[Byte], js: JsonObject): Unit = { + + val entityPath = js.get("entity_path").flatMap(_.to[Option[String]]).getOrElse { + throw new IllegalArgumentException(s"Version 2 config change notification must specify 'entity_path'. " + + s"Received: ${new String(jsonBytes, StandardCharsets.UTF_8)}") + } + + val index = entityPath.indexOf('/') + val rootEntityType = entityPath.substring(0, index) + if (index < 0 || !configHandlers.contains(rootEntityType)) { + val entityTypes = configHandlers.keys.map(entityType => s"'$entityType'/").mkString(", ") + throw new IllegalArgumentException("Version 2 config change notification must have 'entity_path' starting with " + + s"one of $entityTypes. 
Received: ${new String(jsonBytes, StandardCharsets.UTF_8)}") + } + val fullSanitizedEntityName = entityPath.substring(index + 1) + + val entityConfig = adminZkClient.fetchEntityConfig(rootEntityType, fullSanitizedEntityName) + val loggableConfig = entityConfig.asScala.map { + case (k, v) => (k, if (ScramMechanism.isScram(k)) Password.HIDDEN else v) + } + info(s"Processing override for entityPath: $entityPath with config: $loggableConfig") + configHandlers(rootEntityType).processConfigChanges(fullSanitizedEntityName, entityConfig) + + } + } + + private val configChangeListener = new ZkNodeChangeNotificationListener(zkClient, ConfigEntityChangeNotificationZNode.path, + ConfigEntityChangeNotificationSequenceZNode.SequenceNumberPrefix, ConfigChangedNotificationHandler) + + /** + * Begin watching for config changes + */ + def startup(): Unit = { + configChangeListener.init() + + // Apply all existing client/user configs to the ClientIdConfigHandler/UserConfigHandler to bootstrap the overrides + configHandlers.foreach { + case (ConfigType.User, handler) => + adminZkClient.fetchAllEntityConfigs(ConfigType.User).foreach { + case (sanitizedUser, properties) => handler.processConfigChanges(sanitizedUser, properties) + } + adminZkClient.fetchAllChildEntityConfigs(ConfigType.User, ConfigType.Client).foreach { + case (sanitizedUserClientId, properties) => handler.processConfigChanges(sanitizedUserClientId, properties) + } + case (configType, handler) => + adminZkClient.fetchAllEntityConfigs(configType).foreach { + case (entityName, properties) => handler.processConfigChanges(entityName, properties) + } + } + } + + def shutdown(): Unit = { + configChangeListener.close() + } +} diff --git a/core/src/main/scala/kafka/server/EnvelopeUtils.scala b/core/src/main/scala/kafka/server/EnvelopeUtils.scala new file mode 100644 index 0000000..ec8871f --- /dev/null +++ b/core/src/main/scala/kafka/server/EnvelopeUtils.scala @@ -0,0 +1,137 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package kafka.server + +import java.net.{InetAddress, UnknownHostException} +import java.nio.ByteBuffer + +import kafka.network.RequestChannel +import org.apache.kafka.common.errors.{InvalidRequestException, PrincipalDeserializationException, UnsupportedVersionException} +import org.apache.kafka.common.network.ClientInformation +import org.apache.kafka.common.requests.{EnvelopeRequest, RequestContext, RequestHeader} +import org.apache.kafka.common.security.auth.KafkaPrincipal + +import scala.compat.java8.OptionConverters._ + +object EnvelopeUtils { + def handleEnvelopeRequest( + request: RequestChannel.Request, + requestChannelMetrics: RequestChannel.Metrics, + handler: RequestChannel.Request => Unit): Unit = { + val envelope = request.body[EnvelopeRequest] + val forwardedPrincipal = parseForwardedPrincipal(request.context, envelope.requestPrincipal) + val forwardedClientAddress = parseForwardedClientAddress(envelope.clientAddress) + + val forwardedRequestBuffer = envelope.requestData.duplicate() + val forwardedRequestHeader = parseForwardedRequestHeader(forwardedRequestBuffer) + + val forwardedApi = forwardedRequestHeader.apiKey + if (!forwardedApi.forwardable) { + throw new InvalidRequestException(s"API $forwardedApi is not enabled or is not eligible for forwarding") + } + + val forwardedContext = new RequestContext( + forwardedRequestHeader, + request.context.connectionId, + forwardedClientAddress, + forwardedPrincipal, + request.context.listenerName, + request.context.securityProtocol, + ClientInformation.EMPTY, + request.context.fromPrivilegedListener + ) + + val forwardedRequest = parseForwardedRequest( + request, + forwardedContext, + forwardedRequestBuffer, + requestChannelMetrics + ) + handler(forwardedRequest) + } + + private def parseForwardedClientAddress( + address: Array[Byte] + ): InetAddress = { + try { + InetAddress.getByAddress(address) + } catch { + case e: UnknownHostException => + throw new InvalidRequestException("Failed to parse client address from envelope", e) + } + } + + private def parseForwardedRequest( + envelope: RequestChannel.Request, + forwardedContext: RequestContext, + buffer: ByteBuffer, + requestChannelMetrics: RequestChannel.Metrics + ): RequestChannel.Request = { + try { + new RequestChannel.Request( + processor = envelope.processor, + context = forwardedContext, + startTimeNanos = envelope.startTimeNanos, + envelope.memoryPool, + buffer, + requestChannelMetrics, + Some(envelope) + ) + } catch { + case e: InvalidRequestException => + // We use UNSUPPORTED_VERSION if the embedded request cannot be parsed. + // The purpose is to disambiguate structural errors in the envelope request + // itself, such as an invalid client address. + throw new UnsupportedVersionException(s"Failed to parse forwarded request " + + s"with header ${forwardedContext.header}", e) + } + } + + private def parseForwardedRequestHeader( + buffer: ByteBuffer + ): RequestHeader = { + try { + RequestHeader.parse(buffer) + } catch { + case e: InvalidRequestException => + // We use UNSUPPORTED_VERSION if the embedded request cannot be parsed. + // The purpose is to disambiguate structural errors in the envelope request + // itself, such as an invalid client address. 
+ throw new UnsupportedVersionException("Failed to parse request header from envelope", e) + } + } + + private def parseForwardedPrincipal( + envelopeContext: RequestContext, + principalBytes: Array[Byte] + ): KafkaPrincipal = { + envelopeContext.principalSerde.asScala match { + case Some(serde) => + try { + serde.deserialize(principalBytes) + } catch { + case e: Exception => + throw new PrincipalDeserializationException("Failed to deserialize client principal from envelope", e) + } + + case None => + throw new PrincipalDeserializationException("Could not deserialize principal since " + + "no `KafkaPrincipalSerde` has been defined") + } + } +} diff --git a/core/src/main/scala/kafka/server/FetchDataInfo.scala b/core/src/main/scala/kafka/server/FetchDataInfo.scala new file mode 100644 index 0000000..f6cf725 --- /dev/null +++ b/core/src/main/scala/kafka/server/FetchDataInfo.scala @@ -0,0 +1,31 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.server + +import org.apache.kafka.common.message.FetchResponseData +import org.apache.kafka.common.record.Records + +sealed trait FetchIsolation +case object FetchLogEnd extends FetchIsolation +case object FetchHighWatermark extends FetchIsolation +case object FetchTxnCommitted extends FetchIsolation + +case class FetchDataInfo(fetchOffsetMetadata: LogOffsetMetadata, + records: Records, + firstEntryIncomplete: Boolean = false, + abortedTransactions: Option[List[FetchResponseData.AbortedTransaction]] = None) diff --git a/core/src/main/scala/kafka/server/FetchSession.scala b/core/src/main/scala/kafka/server/FetchSession.scala new file mode 100644 index 0000000..f7d348d --- /dev/null +++ b/core/src/main/scala/kafka/server/FetchSession.scala @@ -0,0 +1,839 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package kafka.server + +import kafka.metrics.KafkaMetricsGroup +import kafka.utils.Logging +import org.apache.kafka.common.{TopicIdPartition, TopicPartition, Uuid} +import org.apache.kafka.common.message.FetchResponseData +import org.apache.kafka.common.protocol.Errors +import org.apache.kafka.common.requests.FetchMetadata.{FINAL_EPOCH, INITIAL_EPOCH, INVALID_SESSION_ID} +import org.apache.kafka.common.requests.{FetchRequest, FetchResponse, FetchMetadata => JFetchMetadata} +import org.apache.kafka.common.utils.{ImplicitLinkedHashCollection, Time, Utils} +import java.util +import java.util.Optional +import java.util.concurrent.{ThreadLocalRandom, TimeUnit} + +import scala.collection.{mutable, _} +import scala.math.Ordered.orderingToOrdered + +object FetchSession { + type REQ_MAP = util.Map[TopicIdPartition, FetchRequest.PartitionData] + type RESP_MAP = util.LinkedHashMap[TopicIdPartition, FetchResponseData.PartitionData] + type CACHE_MAP = ImplicitLinkedHashCollection[CachedPartition] + type RESP_MAP_ITER = util.Iterator[util.Map.Entry[TopicIdPartition, FetchResponseData.PartitionData]] + type TOPIC_NAME_MAP = util.Map[Uuid, String] + + val NUM_INCREMENTAL_FETCH_SESSIONS = "NumIncrementalFetchSessions" + val NUM_INCREMENTAL_FETCH_PARTITIONS_CACHED = "NumIncrementalFetchPartitionsCached" + val INCREMENTAL_FETCH_SESSIONS_EVICTIONS_PER_SEC = "IncrementalFetchSessionEvictionsPerSec" + val EVICTIONS = "evictions" + + def partitionsToLogString(partitions: util.Collection[TopicIdPartition], traceEnabled: Boolean): String = { + if (traceEnabled) { + "(" + Utils.join(partitions, ", ") + ")" + } else { + s"${partitions.size} partition(s)" + } + } +} + +/** + * A cached partition. + * + * The broker maintains a set of these objects for each incremental fetch session. + * When an incremental fetch request is made, any partitions which are not explicitly + * enumerated in the fetch request are loaded from the cache. Similarly, when an + * incremental fetch response is being prepared, any partitions that have not changed and + * do not have errors are left out of the response. + * + * We store many of these objects, so it is important for them to be memory-efficient. + * That is why we store topic and partition separately rather than storing a TopicPartition + * object. The TP object takes up more memory because it is a separate JVM object, and + * because it stores the cached hash code in memory. + * + * Note that fetcherLogStartOffset is the LSO of the follower performing the fetch, whereas + * localLogStartOffset is the log start offset of the partition on this broker. 
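+ * (Here "LSO" means the log start offset reported by the fetcher in its request, not the
+ * last stable offset.)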
+ */ +class CachedPartition(var topic: String, + val topicId: Uuid, + val partition: Int, + var maxBytes: Int, + var fetchOffset: Long, + var highWatermark: Long, + var leaderEpoch: Optional[Integer], + var fetcherLogStartOffset: Long, + var localLogStartOffset: Long, + var lastFetchedEpoch: Optional[Integer]) + extends ImplicitLinkedHashCollection.Element { + + var cachedNext: Int = ImplicitLinkedHashCollection.INVALID_INDEX + var cachedPrev: Int = ImplicitLinkedHashCollection.INVALID_INDEX + + override def next: Int = cachedNext + override def setNext(next: Int): Unit = this.cachedNext = next + override def prev: Int = cachedPrev + override def setPrev(prev: Int): Unit = this.cachedPrev = prev + + def this(topic: String, topicId: Uuid, partition: Int) = + this(topic, topicId, partition, -1, -1, -1, Optional.empty(), -1, -1, Optional.empty[Integer]) + + def this(part: TopicIdPartition) = { + this(part.topic, part.topicId, part.partition) + } + + def this(part: TopicIdPartition, reqData: FetchRequest.PartitionData) = + this(part.topic, part.topicId, part.partition, reqData.maxBytes, reqData.fetchOffset, -1, + reqData.currentLeaderEpoch, reqData.logStartOffset, -1, reqData.lastFetchedEpoch) + + def this(part: TopicIdPartition, reqData: FetchRequest.PartitionData, + respData: FetchResponseData.PartitionData) = + this(part.topic, part.topicId, part.partition, reqData.maxBytes, reqData.fetchOffset, respData.highWatermark, + reqData.currentLeaderEpoch, reqData.logStartOffset, respData.logStartOffset, reqData.lastFetchedEpoch) + + def reqData = new FetchRequest.PartitionData(topicId, fetchOffset, fetcherLogStartOffset, maxBytes, leaderEpoch, lastFetchedEpoch) + + def updateRequestParams(reqData: FetchRequest.PartitionData): Unit = { + // Update our cached request parameters. + maxBytes = reqData.maxBytes + fetchOffset = reqData.fetchOffset + fetcherLogStartOffset = reqData.logStartOffset + leaderEpoch = reqData.currentLeaderEpoch + lastFetchedEpoch = reqData.lastFetchedEpoch + } + + def maybeResolveUnknownName(topicNames: FetchSession.TOPIC_NAME_MAP): Unit = { + if (this.topic == null) { + this.topic = topicNames.get(this.topicId) + } + } + + /** + * Determine whether or not the specified cached partition should be included in the FetchResponse we send back to + * the fetcher and update it if requested. + * + * This function should be called while holding the appropriate session lock. + * + * @param respData partition data + * @param updateResponseData if set to true, update this CachedPartition with new request and response data. + * @return True if this partition should be included in the response; false if it can be omitted. + */ + def maybeUpdateResponseData(respData: FetchResponseData.PartitionData, updateResponseData: Boolean): Boolean = { + // Check the response data. + var mustRespond = false + if (FetchResponse.recordsSize(respData) > 0) { + // Partitions with new data are always included in the response. 
+ mustRespond = true + } + if (highWatermark != respData.highWatermark) { + mustRespond = true + if (updateResponseData) + highWatermark = respData.highWatermark + } + if (localLogStartOffset != respData.logStartOffset) { + mustRespond = true + if (updateResponseData) + localLogStartOffset = respData.logStartOffset + } + if (FetchResponse.isPreferredReplica(respData)) { + // If the broker computed a preferred read replica, we need to include it in the response + mustRespond = true + } + if (respData.errorCode != Errors.NONE.code) { + // Partitions with errors are always included in the response. + // We also set the cached highWatermark to an invalid offset, -1. + // This ensures that when the error goes away, we re-send the partition. + if (updateResponseData) + highWatermark = -1 + mustRespond = true + } + + if (FetchResponse.isDivergingEpoch(respData)) { + // Partitions with diverging epoch are always included in response to trigger truncation. + mustRespond = true + } + mustRespond + } + + /** + * We have different equality checks depending on whether topic IDs are used. + * This means we need a different hash function as well. We use name to calculate the hash if the ID is zero and unused. + * Otherwise, we use the topic ID in the hash calculation. + * + * @return the hash code for the CachedPartition depending on what request version we are using. + */ + override def hashCode: Int = + if (topicId != Uuid.ZERO_UUID) + (31 * partition) + topicId.hashCode + else + (31 * partition) + topic.hashCode + + /** + * We have different equality checks depending on whether topic IDs are used. + * + * This is because when we use topic IDs, a partition with a given ID and an unknown name is the same as a partition with that + * ID and a known name. This means we can only use topic ID and partition when determining equality. + * + * On the other hand, if we are using topic names, all IDs are zero. This means we can only use topic name and partition + * when determining equality. + */ + override def equals(that: Any): Boolean = + that match { + case that: CachedPartition => + this.eq(that) || (if (this.topicId != Uuid.ZERO_UUID) + this.partition.equals(that.partition) && this.topicId.equals(that.topicId) + else + this.partition.equals(that.partition) && this.topic.equals(that.topic)) + case _ => false + } + + override def toString: String = synchronized { + "CachedPartition(topic=" + topic + + ", topicId=" + topicId + + ", partition=" + partition + + ", maxBytes=" + maxBytes + + ", fetchOffset=" + fetchOffset + + ", highWatermark=" + highWatermark + + ", fetcherLogStartOffset=" + fetcherLogStartOffset + + ", localLogStartOffset=" + localLogStartOffset + + ")" + } +} + +/** + * The fetch session. + * + * Each fetch session is protected by its own lock, which must be taken before mutable + * fields are read or modified. This includes modification of the session partition map. + * + * @param id The unique fetch session ID. + * @param privileged True if this session is privileged. Sessions crated by followers + * are privileged; session created by consumers are not. + * @param partitionMap The CachedPartitionMap. + * @param usesTopicIds True if this session is using topic IDs + * @param creationMs The time in milliseconds when this session was created. + * @param lastUsedMs The last used time in milliseconds. This should only be updated by + * FetchSessionCache#touch. + * @param epoch The fetch session sequence number. 
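+ *              A new session starts at JFetchMetadata.nextEpoch(INITIAL_EPOCH); the epoch is
+ *              then advanced by FetchManager.newContext each time it builds an incremental
+ *              context from this session.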
+ */ +class FetchSession(val id: Int, + val privileged: Boolean, + val partitionMap: FetchSession.CACHE_MAP, + val usesTopicIds: Boolean, + val creationMs: Long, + var lastUsedMs: Long, + var epoch: Int) { + // This is used by the FetchSessionCache to store the last known size of this session. + // If this is -1, the Session is not in the cache. + var cachedSize = -1 + + def size: Int = synchronized { + partitionMap.size + } + + def isEmpty: Boolean = synchronized { + partitionMap.isEmpty + } + + def lastUsedKey: LastUsedKey = synchronized { + LastUsedKey(lastUsedMs, id) + } + + def evictableKey: EvictableKey = synchronized { + EvictableKey(privileged, cachedSize, id) + } + + def metadata: JFetchMetadata = synchronized { new JFetchMetadata(id, epoch) } + + def getFetchOffset(topicIdPartition: TopicIdPartition): Option[Long] = synchronized { + Option(partitionMap.find(new CachedPartition(topicIdPartition))).map(_.fetchOffset) + } + + type TL = util.ArrayList[TopicIdPartition] + + // Update the cached partition data based on the request. + def update(fetchData: FetchSession.REQ_MAP, + toForget: util.List[TopicIdPartition], + reqMetadata: JFetchMetadata): (TL, TL, TL) = synchronized { + val added = new TL + val updated = new TL + val removed = new TL + fetchData.forEach { (topicPart, reqData) => + val cachedPartitionKey = new CachedPartition(topicPart, reqData) + val cachedPart = partitionMap.find(cachedPartitionKey) + if (cachedPart == null) { + partitionMap.mustAdd(cachedPartitionKey) + added.add(topicPart) + } else { + cachedPart.updateRequestParams(reqData) + updated.add(topicPart) + } + } + toForget.forEach { p => + if (partitionMap.remove(new CachedPartition(p))) { + removed.add(p) + } + } + (added, updated, removed) + } + + override def toString: String = synchronized { + "FetchSession(id=" + id + + ", privileged=" + privileged + + ", partitionMap.size=" + partitionMap.size + + ", usesTopicIds=" + usesTopicIds + + ", creationMs=" + creationMs + + ", lastUsedMs=" + lastUsedMs + + ", epoch=" + epoch + ")" + } +} + +trait FetchContext extends Logging { + /** + * Get the fetch offset for a given partition. + */ + def getFetchOffset(part: TopicIdPartition): Option[Long] + + /** + * Apply a function to each partition in the fetch request. + */ + def foreachPartition(fun: (TopicIdPartition, FetchRequest.PartitionData) => Unit): Unit + + /** + * Get the response size to be used for quota computation. Since we are returning an empty response in case of + * throttling, we are not supposed to update the context until we know that we are not going to throttle. + */ + def getResponseSize(updates: FetchSession.RESP_MAP, versionId: Short): Int + + /** + * Updates the fetch context with new partition information. Generates response data. + * The response data may require subsequent down-conversion. + */ + def updateAndGenerateResponseData(updates: FetchSession.RESP_MAP): FetchResponse + + def partitionsToLogString(partitions: util.Collection[TopicIdPartition]): String = + FetchSession.partitionsToLogString(partitions, isTraceEnabled) + + /** + * Return an empty throttled response due to quota violation. + */ + def getThrottledResponse(throttleTimeMs: Int): FetchResponse = + FetchResponse.of(Errors.NONE, throttleTimeMs, INVALID_SESSION_ID, new FetchSession.RESP_MAP) +} + +/** + * The fetch context for a fetch request that had a session error. 
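+ * It carries no partition data: the generated response contains only the error code, an
+ * empty partition map and INVALID_SESSION_ID.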
+ */ +class SessionErrorContext(val error: Errors, + val reqMetadata: JFetchMetadata) extends FetchContext { + override def getFetchOffset(part: TopicIdPartition): Option[Long] = None + + override def foreachPartition(fun: (TopicIdPartition, FetchRequest.PartitionData) => Unit): Unit = {} + + override def getResponseSize(updates: FetchSession.RESP_MAP, versionId: Short): Int = { + FetchResponse.sizeOf(versionId, (new FetchSession.RESP_MAP).entrySet.iterator) + } + + // Because of the fetch session error, we don't know what partitions were supposed to be in this request. + override def updateAndGenerateResponseData(updates: FetchSession.RESP_MAP): FetchResponse = { + debug(s"Session error fetch context returning $error") + FetchResponse.of(error, 0, INVALID_SESSION_ID, new FetchSession.RESP_MAP) + } +} + +/** + * The fetch context for a sessionless fetch request. + * + * @param fetchData The partition data from the fetch request. + */ +class SessionlessFetchContext(val fetchData: util.Map[TopicIdPartition, FetchRequest.PartitionData]) extends FetchContext { + override def getFetchOffset(part: TopicIdPartition): Option[Long] = + Option(fetchData.get(part)).map(_.fetchOffset) + + override def foreachPartition(fun: (TopicIdPartition, FetchRequest.PartitionData) => Unit): Unit = { + fetchData.forEach((tp, data) => fun(tp, data)) + } + + override def getResponseSize(updates: FetchSession.RESP_MAP, versionId: Short): Int = { + FetchResponse.sizeOf(versionId, updates.entrySet.iterator) + } + + override def updateAndGenerateResponseData(updates: FetchSession.RESP_MAP): FetchResponse = { + debug(s"Sessionless fetch context returning ${partitionsToLogString(updates.keySet)}") + FetchResponse.of(Errors.NONE, 0, INVALID_SESSION_ID, updates) + } +} + +/** + * The fetch context for a full fetch request. + * + * @param time The clock to use. + * @param cache The fetch session cache. + * @param reqMetadata The request metadata. + * @param fetchData The partition data from the fetch request. + * @param usesTopicIds True if this session should use topic IDs. + * @param isFromFollower True if this fetch request came from a follower. 
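+ *                       When the response is generated, a new session may be registered in
+ *                       the cache via maybeCreateSession, with this flag determining whether
+ *                       the session is privileged.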
+ */ +class FullFetchContext(private val time: Time, + private val cache: FetchSessionCache, + private val reqMetadata: JFetchMetadata, + private val fetchData: util.Map[TopicIdPartition, FetchRequest.PartitionData], + private val usesTopicIds: Boolean, + private val isFromFollower: Boolean) extends FetchContext { + override def getFetchOffset(part: TopicIdPartition): Option[Long] = + Option(fetchData.get(part)).map(_.fetchOffset) + + override def foreachPartition(fun: (TopicIdPartition, FetchRequest.PartitionData) => Unit): Unit = { + fetchData.forEach((tp, data) => fun(tp, data)) + } + + override def getResponseSize(updates: FetchSession.RESP_MAP, versionId: Short): Int = { + FetchResponse.sizeOf(versionId, updates.entrySet.iterator) + } + + override def updateAndGenerateResponseData(updates: FetchSession.RESP_MAP): FetchResponse = { + def createNewSession: FetchSession.CACHE_MAP = { + val cachedPartitions = new FetchSession.CACHE_MAP(updates.size) + updates.forEach { (part, respData) => + val reqData = fetchData.get(part) + cachedPartitions.mustAdd(new CachedPartition(part, reqData, respData)) + } + cachedPartitions + } + val responseSessionId = cache.maybeCreateSession(time.milliseconds(), isFromFollower, + updates.size, usesTopicIds, () => createNewSession) + debug(s"Full fetch context with session id $responseSessionId returning " + + s"${partitionsToLogString(updates.keySet)}") + FetchResponse.of(Errors.NONE, 0, responseSessionId, updates) + } +} + +/** + * The fetch context for an incremental fetch request. + * + * @param time The clock to use. + * @param reqMetadata The request metadata. + * @param session The incremental fetch request session. + * @param topicNames A mapping from topic ID to topic name used to resolve partitions already in the session. + */ +class IncrementalFetchContext(private val time: Time, + private val reqMetadata: JFetchMetadata, + private val session: FetchSession, + private val topicNames: FetchSession.TOPIC_NAME_MAP) extends FetchContext { + + override def getFetchOffset(tp: TopicIdPartition): Option[Long] = session.getFetchOffset(tp) + + override def foreachPartition(fun: (TopicIdPartition, FetchRequest.PartitionData) => Unit): Unit = { + // Take the session lock and iterate over all the cached partitions. + session.synchronized { + session.partitionMap.forEach { part => + // Try to resolve an unresolved partition if it does not yet have a name + if (session.usesTopicIds) + part.maybeResolveUnknownName(topicNames) + fun(new TopicIdPartition(part.topicId, new TopicPartition(part.topic, part.partition)), part.reqData) + } + } + } + + // Iterator that goes over the given partition map and selects partitions that need to be included in the response. + // If updateFetchContextAndRemoveUnselected is set to true, the fetch context will be updated for the selected + // partitions and also remove unselected ones as they are encountered. 
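+  // Partitions whose response contains records are removed and re-added, which moves them to
+  // the end of the session's partition map iteration order.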
+ private class PartitionIterator(val iter: FetchSession.RESP_MAP_ITER, + val updateFetchContextAndRemoveUnselected: Boolean) + extends FetchSession.RESP_MAP_ITER { + var nextElement: util.Map.Entry[TopicIdPartition, FetchResponseData.PartitionData] = null + + override def hasNext: Boolean = { + while ((nextElement == null) && iter.hasNext) { + val element = iter.next() + val topicPart = element.getKey + val respData = element.getValue + val cachedPart = session.partitionMap.find(new CachedPartition(topicPart)) + val mustRespond = cachedPart.maybeUpdateResponseData(respData, updateFetchContextAndRemoveUnselected) + if (mustRespond) { + nextElement = element + if (updateFetchContextAndRemoveUnselected && FetchResponse.recordsSize(respData) > 0) { + session.partitionMap.remove(cachedPart) + session.partitionMap.mustAdd(cachedPart) + } + } else { + if (updateFetchContextAndRemoveUnselected) { + iter.remove() + } + } + } + nextElement != null + } + + override def next(): util.Map.Entry[TopicIdPartition, FetchResponseData.PartitionData] = { + if (!hasNext) throw new NoSuchElementException + val element = nextElement + nextElement = null + element + } + + override def remove() = throw new UnsupportedOperationException + } + + override def getResponseSize(updates: FetchSession.RESP_MAP, versionId: Short): Int = { + session.synchronized { + val expectedEpoch = JFetchMetadata.nextEpoch(reqMetadata.epoch) + if (session.epoch != expectedEpoch) { + FetchResponse.sizeOf(versionId, (new FetchSession.RESP_MAP).entrySet.iterator) + } else { + // Pass the partition iterator which updates neither the fetch context nor the partition map. + FetchResponse.sizeOf(versionId, new PartitionIterator(updates.entrySet.iterator, false)) + } + } + } + + override def updateAndGenerateResponseData(updates: FetchSession.RESP_MAP): FetchResponse = { + session.synchronized { + // Check to make sure that the session epoch didn't change in between + // creating this fetch context and generating this response. + val expectedEpoch = JFetchMetadata.nextEpoch(reqMetadata.epoch) + if (session.epoch != expectedEpoch) { + info(s"Incremental fetch session ${session.id} expected epoch $expectedEpoch, but " + + s"got ${session.epoch}. Possible duplicate request.") + FetchResponse.of(Errors.INVALID_FETCH_SESSION_EPOCH, 0, session.id, new FetchSession.RESP_MAP) + } else { + // Iterate over the update list using PartitionIterator. This will prune updates which don't need to be sent + val partitionIter = new PartitionIterator(updates.entrySet.iterator, true) + while (partitionIter.hasNext) { + partitionIter.next() + } + debug(s"Incremental fetch context with session id ${session.id} returning " + + s"${partitionsToLogString(updates.keySet)}") + FetchResponse.of(Errors.NONE, 0, session.id, updates) + } + } + } + + override def getThrottledResponse(throttleTimeMs: Int): FetchResponse = { + session.synchronized { + // Check to make sure that the session epoch didn't change in between + // creating this fetch context and generating this response. + val expectedEpoch = JFetchMetadata.nextEpoch(reqMetadata.epoch) + if (session.epoch != expectedEpoch) { + info(s"Incremental fetch session ${session.id} expected epoch $expectedEpoch, but " + + s"got ${session.epoch}. 
Possible duplicate request.") + FetchResponse.of(Errors.INVALID_FETCH_SESSION_EPOCH, throttleTimeMs, session.id, new FetchSession.RESP_MAP) + } else { + FetchResponse.of(Errors.NONE, throttleTimeMs, session.id, new FetchSession.RESP_MAP) + } + } + } +} + +case class LastUsedKey(lastUsedMs: Long, id: Int) extends Comparable[LastUsedKey] { + override def compareTo(other: LastUsedKey): Int = + (lastUsedMs, id) compare (other.lastUsedMs, other.id) +} + +case class EvictableKey(privileged: Boolean, size: Int, id: Int) extends Comparable[EvictableKey] { + override def compareTo(other: EvictableKey): Int = + (privileged, size, id) compare (other.privileged, other.size, other.id) +} + +/** + * Caches fetch sessions. + * + * See tryEvict for an explanation of the cache eviction strategy. + * + * The FetchSessionCache is thread-safe because all of its methods are synchronized. + * Note that individual fetch sessions have their own locks which are separate from the + * FetchSessionCache lock. In order to avoid deadlock, the FetchSessionCache lock + * must never be acquired while an individual FetchSession lock is already held. + * + * @param maxEntries The maximum number of entries that can be in the cache. + * @param evictionMs The minimum time that an entry must be unused in order to be evictable. + */ +class FetchSessionCache(private val maxEntries: Int, + private val evictionMs: Long) extends Logging with KafkaMetricsGroup { + private var numPartitions: Long = 0 + + // A map of session ID to FetchSession. + private val sessions = new mutable.HashMap[Int, FetchSession] + + // Maps last used times to sessions. + private val lastUsed = new util.TreeMap[LastUsedKey, FetchSession] + + // A map containing sessions which can be evicted by both privileged and + // unprivileged sessions. + private val evictableByAll = new util.TreeMap[EvictableKey, FetchSession] + + // A map containing sessions which can be evicted by privileged sessions. + private val evictableByPrivileged = new util.TreeMap[EvictableKey, FetchSession] + + // Set up metrics. + removeMetric(FetchSession.NUM_INCREMENTAL_FETCH_SESSIONS) + newGauge(FetchSession.NUM_INCREMENTAL_FETCH_SESSIONS, () => FetchSessionCache.this.size) + removeMetric(FetchSession.NUM_INCREMENTAL_FETCH_PARTITIONS_CACHED) + newGauge(FetchSession.NUM_INCREMENTAL_FETCH_PARTITIONS_CACHED, () => FetchSessionCache.this.totalPartitions) + removeMetric(FetchSession.INCREMENTAL_FETCH_SESSIONS_EVICTIONS_PER_SEC) + private[server] val evictionsMeter = newMeter(FetchSession.INCREMENTAL_FETCH_SESSIONS_EVICTIONS_PER_SEC, + FetchSession.EVICTIONS, TimeUnit.SECONDS, Map.empty) + + /** + * Get a session by session ID. + * + * @param sessionId The session ID. + * @return The session, or None if no such session was found. + */ + def get(sessionId: Int): Option[FetchSession] = synchronized { + sessions.get(sessionId) + } + + /** + * Get the number of entries currently in the fetch session cache. + */ + def size: Int = synchronized { + sessions.size + } + + /** + * Get the total number of cached partitions. + */ + def totalPartitions: Long = synchronized { + numPartitions + } + + /** + * Creates a new random session ID. The new session ID will be positive and unique on this broker. + * + * @return The new session ID. + */ + def newSessionId(): Int = synchronized { + var id = 0 + do { + id = ThreadLocalRandom.current().nextInt(1, Int.MaxValue) + } while (sessions.contains(id) || id == INVALID_SESSION_ID) + id + } + + /** + * Try to create a new session. 
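+   * A session is created only if the cache has spare capacity or tryEvict can free a slot for
+   * the proposed entry.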
+ * + * @param now The current time in milliseconds. + * @param privileged True if the new entry we are trying to create is privileged. + * @param size The number of cached partitions in the new entry we are trying to create. + * @param usesTopicIds True if this session should use topic IDs. + * @param createPartitions A callback function which creates the map of cached partitions and the mapping from + * topic name to topic ID for the topics. + * @return If we created a session, the ID; INVALID_SESSION_ID otherwise. + */ + def maybeCreateSession(now: Long, + privileged: Boolean, + size: Int, + usesTopicIds: Boolean, + createPartitions: () => FetchSession.CACHE_MAP): Int = + synchronized { + // If there is room, create a new session entry. + if ((sessions.size < maxEntries) || + tryEvict(privileged, EvictableKey(privileged, size, 0), now)) { + val partitionMap = createPartitions() + val session = new FetchSession(newSessionId(), privileged, partitionMap, usesTopicIds, + now, now, JFetchMetadata.nextEpoch(INITIAL_EPOCH)) + debug(s"Created fetch session ${session.toString}") + sessions.put(session.id, session) + touch(session, now) + session.id + } else { + debug(s"No fetch session created for privileged=$privileged, size=$size.") + INVALID_SESSION_ID + } + } + + /** + * Try to evict an entry from the session cache. + * + * A proposed new element A may evict an existing element B if: + * 1. A is privileged and B is not, or + * 2. B is considered "stale" because it has been inactive for a long time, or + * 3. A contains more partitions than B, and B is not recently created. + * + * @param privileged True if the new entry we would like to add is privileged. + * @param key The EvictableKey for the new entry we would like to add. + * @param now The current time in milliseconds. + * @return True if an entry was evicted; false otherwise. + */ + def tryEvict(privileged: Boolean, key: EvictableKey, now: Long): Boolean = synchronized { + // Try to evict an entry which is stale. + val lastUsedEntry = lastUsed.firstEntry + if (lastUsedEntry == null) { + trace("There are no cache entries to evict.") + false + } else if (now - lastUsedEntry.getKey.lastUsedMs > evictionMs) { + val session = lastUsedEntry.getValue + trace(s"Evicting stale FetchSession ${session.id}.") + remove(session) + evictionsMeter.mark() + true + } else { + // If there are no stale entries, check the first evictable entry. + // If it is less valuable than our proposed entry, evict it. + val map = if (privileged) evictableByPrivileged else evictableByAll + val evictableEntry = map.firstEntry + if (evictableEntry == null) { + trace("No evictable entries found.") + false + } else if (key.compareTo(evictableEntry.getKey) < 0) { + trace(s"Can't evict ${evictableEntry.getKey} with ${key.toString}") + false + } else { + trace(s"Evicting ${evictableEntry.getKey} with ${key.toString}.") + remove(evictableEntry.getValue) + evictionsMeter.mark() + true + } + } + } + + def remove(sessionId: Int): Option[FetchSession] = synchronized { + get(sessionId) match { + case None => None + case Some(session) => remove(session) + } + } + + /** + * Remove an entry from the session cache. + * + * @param session The session. + * + * @return The removed session, or None if there was no such session. 
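+   *         Removing a session also subtracts its cachedSize from the cached partition count
+   *         and drops it from the lastUsed and evictable indexes.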
+ */ + def remove(session: FetchSession): Option[FetchSession] = synchronized { + val evictableKey = session.synchronized { + lastUsed.remove(session.lastUsedKey) + session.evictableKey + } + evictableByAll.remove(evictableKey) + evictableByPrivileged.remove(evictableKey) + val removeResult = sessions.remove(session.id) + if (removeResult.isDefined) { + numPartitions = numPartitions - session.cachedSize + } + removeResult + } + + /** + * Update a session's position in the lastUsed and evictable trees. + * + * @param session The session. + * @param now The current time in milliseconds. + */ + def touch(session: FetchSession, now: Long): Unit = synchronized { + session.synchronized { + // Update the lastUsed map. + lastUsed.remove(session.lastUsedKey) + session.lastUsedMs = now + lastUsed.put(session.lastUsedKey, session) + + val oldSize = session.cachedSize + if (oldSize != -1) { + val oldEvictableKey = session.evictableKey + evictableByPrivileged.remove(oldEvictableKey) + evictableByAll.remove(oldEvictableKey) + numPartitions = numPartitions - oldSize + } + session.cachedSize = session.size + val newEvictableKey = session.evictableKey + if ((!session.privileged) || (now - session.creationMs > evictionMs)) { + evictableByPrivileged.put(newEvictableKey, session) + } + if (now - session.creationMs > evictionMs) { + evictableByAll.put(newEvictableKey, session) + } + numPartitions = numPartitions + session.cachedSize + } + } +} + +class FetchManager(private val time: Time, + private val cache: FetchSessionCache) extends Logging { + def newContext(reqVersion: Short, + reqMetadata: JFetchMetadata, + isFollower: Boolean, + fetchData: FetchSession.REQ_MAP, + toForget: util.List[TopicIdPartition], + topicNames: FetchSession.TOPIC_NAME_MAP): FetchContext = { + val context = if (reqMetadata.isFull) { + var removedFetchSessionStr = "" + if (reqMetadata.sessionId != INVALID_SESSION_ID) { + // Any session specified in a FULL fetch request will be closed. + if (cache.remove(reqMetadata.sessionId).isDefined) { + removedFetchSessionStr = s" Removed fetch session ${reqMetadata.sessionId}." + } + } + var suffix = "" + val context = if (reqMetadata.epoch == FINAL_EPOCH) { + // If the epoch is FINAL_EPOCH, don't try to create a new session. + suffix = " Will not try to create a new session." 
+ new SessionlessFetchContext(fetchData) + } else { + new FullFetchContext(time, cache, reqMetadata, fetchData, reqVersion >= 13, isFollower) + } + debug(s"Created a new full FetchContext with ${partitionsToLogString(fetchData.keySet)}."+ + s"${removedFetchSessionStr}${suffix}") + context + } else { + cache.synchronized { + cache.get(reqMetadata.sessionId) match { + case None => { + debug(s"Session error for ${reqMetadata.sessionId}: no such session ID found.") + new SessionErrorContext(Errors.FETCH_SESSION_ID_NOT_FOUND, reqMetadata) + } + case Some(session) => session.synchronized { + if (session.epoch != reqMetadata.epoch) { + debug(s"Session error for ${reqMetadata.sessionId}: expected epoch " + + s"${session.epoch}, but got ${reqMetadata.epoch} instead.") + new SessionErrorContext(Errors.INVALID_FETCH_SESSION_EPOCH, reqMetadata) + } else if (session.usesTopicIds && reqVersion < 13 || !session.usesTopicIds && reqVersion >= 13) { + debug(s"Session error for ${reqMetadata.sessionId}: expected " + + s"${if (session.usesTopicIds) "to use topic IDs" else "to not use topic IDs"}" + + s", but request version $reqVersion means that we can not.") + new SessionErrorContext(Errors.FETCH_SESSION_TOPIC_ID_ERROR, reqMetadata) + } else { + val (added, updated, removed) = session.update(fetchData, toForget, reqMetadata) + if (session.isEmpty) { + debug(s"Created a new sessionless FetchContext and closing session id ${session.id}, " + + s"epoch ${session.epoch}: after removing ${partitionsToLogString(removed)}, " + + s"there are no more partitions left.") + cache.remove(session) + new SessionlessFetchContext(fetchData) + } else { + cache.touch(session, time.milliseconds()) + session.epoch = JFetchMetadata.nextEpoch(session.epoch) + debug(s"Created a new incremental FetchContext for session id ${session.id}, " + + s"epoch ${session.epoch}: added ${partitionsToLogString(added)}, " + + s"updated ${partitionsToLogString(updated)}, " + + s"removed ${partitionsToLogString(removed)}") + new IncrementalFetchContext(time, reqMetadata, session, topicNames) + } + } + } + } + } + } + context + } + + def partitionsToLogString(partitions: util.Collection[TopicIdPartition]): String = + FetchSession.partitionsToLogString(partitions, isTraceEnabled) +} diff --git a/core/src/main/scala/kafka/server/FinalizedFeatureCache.scala b/core/src/main/scala/kafka/server/FinalizedFeatureCache.scala new file mode 100644 index 0000000..88addb7 --- /dev/null +++ b/core/src/main/scala/kafka/server/FinalizedFeatureCache.scala @@ -0,0 +1,183 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package kafka.server + +import java.util +import java.util.Collections +import kafka.utils.Logging +import org.apache.kafka.common.feature.{Features, FinalizedVersionRange} +import org.apache.kafka.image.FeaturesDelta + +import scala.concurrent.TimeoutException +import scala.math.max + +import scala.compat.java8.OptionConverters._ + +// Raised whenever there was an error in updating the FinalizedFeatureCache with features. +class FeatureCacheUpdateException(message: String) extends RuntimeException(message) { +} + +// Helper class that represents finalized features along with an epoch value. +case class FinalizedFeaturesAndEpoch(features: Features[FinalizedVersionRange], epoch: Long) { + override def toString(): String = { + s"FinalizedFeaturesAndEpoch(features=$features, epoch=$epoch)" + } +} + +/** + * A common mutable cache containing the latest finalized features and epoch. By default the contents of + * the cache are empty. This cache needs to be populated at least once for its contents to become + * non-empty. Currently the main reader of this cache is the read path that serves an ApiVersionsRequest, + * returning the features information in the response. This cache is typically updated asynchronously + * whenever the finalized features and epoch values are modified in ZK by the KafkaController. + * This cache is thread-safe for reads and writes. + * + * @see FinalizedFeatureChangeListener + */ +class FinalizedFeatureCache(private val brokerFeatures: BrokerFeatures) extends Logging { + @volatile private var featuresAndEpoch: Option[FinalizedFeaturesAndEpoch] = Option.empty + + /** + * @return the latest known FinalizedFeaturesAndEpoch or empty if not defined in the cache. + */ + def get: Option[FinalizedFeaturesAndEpoch] = { + featuresAndEpoch + } + + def isEmpty: Boolean = { + featuresAndEpoch.isEmpty + } + + /** + * Waits no more than timeoutMs for the cache's epoch to reach an epoch >= minExpectedEpoch. + * + * @param minExpectedEpoch the minimum expected epoch to be reached by the cache + * (should be >= 0) + * @param timeoutMs the timeout (in milli seconds) + * + * @throws TimeoutException if the cache's epoch has not reached at least + * minExpectedEpoch within timeoutMs. + */ + def waitUntilEpochOrThrow(minExpectedEpoch: Long, timeoutMs: Long): Unit = { + if(minExpectedEpoch < 0L) { + throw new IllegalArgumentException( + s"Expected minExpectedEpoch >= 0, but $minExpectedEpoch was provided.") + } + waitUntilConditionOrThrow( + () => featuresAndEpoch.isDefined && featuresAndEpoch.get.epoch >= minExpectedEpoch, + timeoutMs) + } + + /** + * Clears all existing finalized features and epoch from the cache. + */ + def clear(): Unit = { + synchronized { + featuresAndEpoch = Option.empty + notifyAll() + } + info("Cleared cache") + } + + /** + * Updates the cache to the latestFeatures, and updates the existing epoch to latestEpoch. + * Expects that the latestEpoch should be always greater than the existing epoch (when the + * existing epoch is defined). + * + * @param latestFeatures the latest finalized features to be set in the cache + * @param latestEpoch the latest epoch value to be set in the cache + * + * @throws FeatureCacheUpdateException if the cache update operation fails + * due to invalid parameters or incompatibilities with the broker's + * supported features. In such a case, the existing cache contents are + * not modified. 
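+   *         For example, an update whose latestEpoch is lower than the currently cached epoch
+   *         is rejected, as is one whose features are incompatible with the broker's supported
+   *         features.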
+ */ + def updateOrThrow(latestFeatures: Features[FinalizedVersionRange], latestEpoch: Long): Unit = { + val latest = FinalizedFeaturesAndEpoch(latestFeatures, latestEpoch) + val existing = featuresAndEpoch.map(item => item.toString()).getOrElse("") + if (featuresAndEpoch.isDefined && featuresAndEpoch.get.epoch > latest.epoch) { + val errorMsg = s"FinalizedFeatureCache update failed due to invalid epoch in new $latest." + + s" The existing cache contents are $existing." + throw new FeatureCacheUpdateException(errorMsg) + } else { + val incompatibleFeatures = brokerFeatures.incompatibleFeatures(latest.features) + if (!incompatibleFeatures.empty) { + val errorMsg = "FinalizedFeatureCache update failed since feature compatibility" + + s" checks failed! Supported ${brokerFeatures.supportedFeatures} has incompatibilities" + + s" with the latest $latest." + throw new FeatureCacheUpdateException(errorMsg) + } else { + val logMsg = s"Updated cache from existing $existing to latest $latest." + synchronized { + featuresAndEpoch = Some(latest) + notifyAll() + } + info(logMsg) + } + } + } + + def update(featuresDelta: FeaturesDelta, highestMetadataOffset: Long): Unit = { + val features = featuresAndEpoch.getOrElse( + FinalizedFeaturesAndEpoch(Features.emptyFinalizedFeatures(), -1)) + val newFeatures = new util.HashMap[String, FinalizedVersionRange]() + newFeatures.putAll(features.features.features()) + featuresDelta.changes().entrySet().forEach { e => + e.getValue().asScala match { + case None => newFeatures.remove(e.getKey) + case Some(feature) => newFeatures.put(e.getKey, + new FinalizedVersionRange(feature.min(), feature.max())) + } + } + featuresAndEpoch = Some(FinalizedFeaturesAndEpoch(Features.finalizedFeatures( + Collections.unmodifiableMap(newFeatures)), highestMetadataOffset)) + } + + /** + * Causes the current thread to wait no more than timeoutMs for the specified condition to be met. + * It is guaranteed that the provided condition will always be invoked only from within a + * synchronized block. + * + * @param waitCondition the condition to be waited upon: + * - if the condition returns true, then, the wait will stop. + * - if the condition returns false, it means the wait must continue until + * timeout. + * + * @param timeoutMs the timeout (in milli seconds) + * + * @throws TimeoutException if the condition is not met within timeoutMs. + */ + private def waitUntilConditionOrThrow(waitCondition: () => Boolean, timeoutMs: Long): Unit = { + if(timeoutMs < 0L) { + throw new IllegalArgumentException(s"Expected timeoutMs >= 0, but $timeoutMs was provided.") + } + val waitEndTimeNanos = System.nanoTime() + (timeoutMs * 1000000) + synchronized { + while (!waitCondition()) { + val nowNanos = System.nanoTime() + if (nowNanos > waitEndTimeNanos) { + throw new TimeoutException( + s"Timed out after waiting for ${timeoutMs}ms for required condition to be met." + + s" Current epoch: ${featuresAndEpoch.map(fe => fe.epoch).getOrElse("")}.") + } + val sleepTimeMs = max(1L, (waitEndTimeNanos - nowNanos) / 1000000) + wait(sleepTimeMs) + } + } + } +} diff --git a/core/src/main/scala/kafka/server/FinalizedFeatureChangeListener.scala b/core/src/main/scala/kafka/server/FinalizedFeatureChangeListener.scala new file mode 100644 index 0000000..8f10ab6 --- /dev/null +++ b/core/src/main/scala/kafka/server/FinalizedFeatureChangeListener.scala @@ -0,0 +1,256 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.server + +import java.util.concurrent.{CountDownLatch, LinkedBlockingQueue, TimeUnit} + +import kafka.utils.{Logging, ShutdownableThread} +import kafka.zk.{FeatureZNode, FeatureZNodeStatus, KafkaZkClient, ZkVersion} +import kafka.zookeeper.{StateChangeHandler, ZNodeChangeHandler} +import org.apache.kafka.common.internals.FatalExitError + +import scala.concurrent.TimeoutException + +/** + * Listens to changes in the ZK feature node, via the ZK client. Whenever a change notification + * is received from ZK, the feature cache in FinalizedFeatureCache is asynchronously updated + * to the latest features read from ZK. The cache updates are serialized through a single + * notification processor thread. + * + * @param finalizedFeatureCache the finalized feature cache + * @param zkClient the Zookeeper client + */ +class FinalizedFeatureChangeListener(private val finalizedFeatureCache: FinalizedFeatureCache, + private val zkClient: KafkaZkClient) extends Logging { + + /** + * Helper class used to update the FinalizedFeatureCache. + * + * @param featureZkNodePath the path to the ZK feature node to be read + * @param maybeNotifyOnce an optional latch that can be used to notify the caller when an + * updateOrThrow() operation is over + */ + private class FeatureCacheUpdater(featureZkNodePath: String, maybeNotifyOnce: Option[CountDownLatch]) { + + def this(featureZkNodePath: String) = this(featureZkNodePath, Option.empty) + + /** + * Updates the feature cache in FinalizedFeatureCache with the latest features read from the + * ZK node in featureZkNodePath. If the cache update is not successful, then, a suitable + * exception is raised. + * + * NOTE: if a notifier was provided in the constructor, then, this method can be invoked exactly + * once successfully. A subsequent invocation will raise an exception. + * + * @throws IllegalStateException, if a non-empty notifier was provided in the constructor, and + * this method is called again after a successful previous invocation. + * @throws FeatureCacheUpdateException, if there was an error in updating the + * FinalizedFeatureCache. + */ + def updateLatestOrThrow(): Unit = { + maybeNotifyOnce.foreach(notifier => { + if (notifier.getCount != 1) { + throw new IllegalStateException( + "Can not notify after updateLatestOrThrow was called more than once successfully.") + } + }) + + debug(s"Reading feature ZK node at path: $featureZkNodePath") + val (mayBeFeatureZNodeBytes, version) = zkClient.getDataAndVersion(featureZkNodePath) + + // There are 4 cases: + // + // (empty dataBytes, valid version) => The empty dataBytes will fail FeatureZNode deserialization. + // FeatureZNode, when present in ZK, can not have empty contents. 
+ // (non-empty dataBytes, valid version) => This is a valid case, and should pass FeatureZNode deserialization + // if dataBytes contains valid data. + // (empty dataBytes, unknown version) => This is a valid case, and this can happen if the FeatureZNode + // does not exist in ZK. + // (non-empty dataBytes, unknown version) => This case is impossible, since, KafkaZkClient.getDataAndVersion + // API ensures that unknown version is returned only when the + // ZK node is absent. Therefore dataBytes should be empty in such + // a case. + if (version == ZkVersion.UnknownVersion) { + info(s"Feature ZK node at path: $featureZkNodePath does not exist") + finalizedFeatureCache.clear() + } else { + var maybeFeatureZNode: Option[FeatureZNode] = Option.empty + try { + maybeFeatureZNode = Some(FeatureZNode.decode(mayBeFeatureZNodeBytes.get)) + } catch { + case e: IllegalArgumentException => { + error(s"Unable to deserialize feature ZK node at path: $featureZkNodePath", e) + finalizedFeatureCache.clear() + } + } + maybeFeatureZNode.foreach(featureZNode => { + featureZNode.status match { + case FeatureZNodeStatus.Disabled => { + info(s"Feature ZK node at path: $featureZkNodePath is in disabled status.") + finalizedFeatureCache.clear() + } + case FeatureZNodeStatus.Enabled => { + finalizedFeatureCache.updateOrThrow(featureZNode.features, version) + } + case _ => throw new IllegalStateException(s"Unexpected FeatureZNodeStatus found in $featureZNode") + } + }) + } + + maybeNotifyOnce.foreach(notifier => notifier.countDown()) + } + + /** + * Waits until at least a single updateLatestOrThrow completes successfully. This method returns + * immediately if an updateLatestOrThrow call had already completed successfully. + * + * @param waitTimeMs the timeout for the wait operation + * + * @throws TimeoutException if the wait can not be completed in waitTimeMs + * milli seconds + */ + def awaitUpdateOrThrow(waitTimeMs: Long): Unit = { + maybeNotifyOnce.foreach(notifier => { + if (!notifier.await(waitTimeMs, TimeUnit.MILLISECONDS)) { + throw new TimeoutException( + s"Timed out after waiting for ${waitTimeMs}ms for FeatureCache to be updated.") + } + }) + } + } + + /** + * A shutdownable thread to process feature node change notifications that are populated into the + * queue. If any change notification can not be processed successfully (unless it is due to an + * interrupt), the thread treats it as a fatal event and triggers Broker exit. + * + * @param name name of the thread + */ + private class ChangeNotificationProcessorThread(name: String) extends ShutdownableThread(name = name) { + override def doWork(): Unit = { + try { + queue.take.updateLatestOrThrow() + } catch { + case ie: InterruptedException => + // While the queue is empty and this thread is blocking on taking an item from the queue, + // a concurrent call to FinalizedFeatureChangeListener.close() could interrupt the thread + // and cause an InterruptedException to be raised from queue.take(). In such a case, it is + // safe to ignore the exception if the thread is being shutdown. We raise the exception + // here again, because, it is ignored by ShutdownableThread if it is shutting down. + throw ie + case e: Exception => { + error("Failed to process feature ZK node change event. The broker will eventually exit.", e) + throw new FatalExitError(1) + } + } + } + } + + // Feature ZK node change handler. 
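/*
 * Illustrative aside, not part of this patch: the handler object below, together with
 * ZkStateChangeHandler and ChangeNotificationProcessorThread above, follows a simple
 * producer/consumer shape: every ZK event (creation, update, deletion, session
 * re-initialization) is enqueued, and a single thread drains the queue so that cache
 * updates are applied strictly in arrival order. A minimal sketch of that shape, using
 * only java.util.concurrent; `Event`, `onZNodeEvent` and `startConsumer` are hypothetical
 * names introduced for the sketch:
 *
 *   import java.util.concurrent.LinkedBlockingQueue
 *
 *   final case class Event(path: String)
 *
 *   object QueueSketch {
 *     private val queue = new LinkedBlockingQueue[Event]()
 *
 *     // ZK callbacks only enqueue; they never touch the cache directly.
 *     def onZNodeEvent(path: String): Unit = queue.add(Event(path))
 *
 *     // One consumer thread applies events in order (shutdown/fatal-exit handling omitted).
 *     def startConsumer(process: Event => Unit): Thread = {
 *       val t = new Thread(() => {
 *         while (!Thread.currentThread().isInterrupted) process(queue.take())
 *       })
 *       t.setDaemon(true)
 *       t.start()
 *       t
 *     }
 *   }
 *
 * The actual feature ZK node change handler follows.
 */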
+ object FeatureZNodeChangeHandler extends ZNodeChangeHandler { + override val path: String = FeatureZNode.path + + override def handleCreation(): Unit = { + info(s"Feature ZK node created at path: $path") + queue.add(new FeatureCacheUpdater(path)) + } + + override def handleDataChange(): Unit = { + info(s"Feature ZK node updated at path: $path") + queue.add(new FeatureCacheUpdater(path)) + } + + override def handleDeletion(): Unit = { + warn(s"Feature ZK node deleted at path: $path") + // This event may happen, rarely (ex: ZK corruption or operational error). + // In such a case, we prefer to just log a warning and treat the case as if the node is absent, + // and populate the FinalizedFeatureCache with empty finalized features. + queue.add(new FeatureCacheUpdater(path)) + } + } + + object ZkStateChangeHandler extends StateChangeHandler { + val path: String = FeatureZNode.path + + override val name: String = path + + override def afterInitializingSession(): Unit = { + queue.add(new FeatureCacheUpdater(path)) + } + } + + private val queue = new LinkedBlockingQueue[FeatureCacheUpdater] + + private val thread = new ChangeNotificationProcessorThread("feature-zk-node-event-process-thread") + + /** + * This method initializes the feature ZK node change listener. Optionally, it also ensures to + * update the FinalizedFeatureCache once with the latest contents of the feature ZK node + * (if the node exists). This step helps ensure that feature incompatibilities (if any) in brokers + * are conveniently detected before the initOrThrow() method returns to the caller. If feature + * incompatibilities are detected, this method will throw an Exception to the caller, and the Broker + * will exit eventually. + * + * @param waitOnceForCacheUpdateMs # of milli seconds to wait for feature cache to be updated once. + * (should be > 0) + * + * @throws Exception if feature incompatibility check could not be finished in a timely manner + */ + def initOrThrow(waitOnceForCacheUpdateMs: Long): Unit = { + if (waitOnceForCacheUpdateMs <= 0) { + throw new IllegalArgumentException( + s"Expected waitOnceForCacheUpdateMs > 0, but provided: $waitOnceForCacheUpdateMs") + } + + thread.start() + zkClient.registerStateChangeHandler(ZkStateChangeHandler) + zkClient.registerZNodeChangeHandlerAndCheckExistence(FeatureZNodeChangeHandler) + val ensureCacheUpdateOnce = new FeatureCacheUpdater( + FeatureZNodeChangeHandler.path, Some(new CountDownLatch(1))) + queue.add(ensureCacheUpdateOnce) + try { + ensureCacheUpdateOnce.awaitUpdateOrThrow(waitOnceForCacheUpdateMs) + } catch { + case e: Exception => { + close() + throw e + } + } + } + + /** + * Closes the feature ZK node change listener by unregistering the listener from ZK client, + * clearing the queue and shutting down the ChangeNotificationProcessorThread. + */ + def close(): Unit = { + zkClient.unregisterStateChangeHandler(ZkStateChangeHandler.name) + zkClient.unregisterZNodeChangeHandler(FeatureZNodeChangeHandler.path) + queue.clear() + thread.shutdown() + } + + // For testing only. + def isListenerInitiated: Boolean = { + thread.isRunning && thread.isAlive + } + + // For testing only. 
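/*
 * Illustrative aside, not part of this patch (the test-only `isListenerDead` hook that the
 * preceding comment refers to continues immediately after this block). initOrThrow() above
 * gates broker startup on the first successful cache update by handing the initial
 * FeatureCacheUpdater a CountDownLatch(1): the updater counts the latch down once it has
 * refreshed the cache, and the initializer awaits it with a timeout, closing the listener
 * and rethrowing on failure. A minimal sketch of that latch handshake; `FirstUpdateGate`,
 * `markUpdated` and `awaitOrThrow` are hypothetical names:
 *
 *   import java.util.concurrent.{CountDownLatch, TimeUnit}
 *   import scala.concurrent.TimeoutException
 *
 *   final class FirstUpdateGate {
 *     private val latch = new CountDownLatch(1)
 *
 *     // Called by the updater after the first successful cache refresh.
 *     def markUpdated(): Unit = latch.countDown()
 *
 *     // Called by the initializer; mirrors awaitUpdateOrThrow().
 *     def awaitOrThrow(waitTimeMs: Long): Unit =
 *       if (!latch.await(waitTimeMs, TimeUnit.MILLISECONDS))
 *         throw new TimeoutException(
 *           s"Timed out after ${waitTimeMs}ms waiting for the first feature cache update.")
 *   }
 */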
+ def isListenerDead: Boolean = { + !thread.isRunning && !thread.isAlive + } +} diff --git a/core/src/main/scala/kafka/server/ForwardingManager.scala b/core/src/main/scala/kafka/server/ForwardingManager.scala new file mode 100644 index 0000000..e84592b --- /dev/null +++ b/core/src/main/scala/kafka/server/ForwardingManager.scala @@ -0,0 +1,147 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.server + +import java.nio.ByteBuffer + +import kafka.network.RequestChannel +import kafka.utils.Logging +import org.apache.kafka.clients.{ClientResponse, NodeApiVersions} +import org.apache.kafka.common.errors.TimeoutException +import org.apache.kafka.common.protocol.Errors +import org.apache.kafka.common.requests.{AbstractRequest, AbstractResponse, EnvelopeRequest, EnvelopeResponse, RequestContext, RequestHeader} + +import scala.compat.java8.OptionConverters._ + +trait ForwardingManager { + def forwardRequest( + request: RequestChannel.Request, + responseCallback: Option[AbstractResponse] => Unit + ): Unit + + def controllerApiVersions: Option[NodeApiVersions] +} + +object ForwardingManager { + def apply( + channelManager: BrokerToControllerChannelManager + ): ForwardingManager = { + new ForwardingManagerImpl(channelManager) + } + + private[server] def buildEnvelopeRequest(context: RequestContext, + forwardRequestBuffer: ByteBuffer): EnvelopeRequest.Builder = { + val principalSerde = context.principalSerde.asScala.getOrElse( + throw new IllegalArgumentException(s"Cannot deserialize principal from request context $context " + + "since there is no serde defined") + ) + val serializedPrincipal = principalSerde.serialize(context.principal) + new EnvelopeRequest.Builder( + forwardRequestBuffer, + serializedPrincipal, + context.clientAddress.getAddress + ) + } +} + +class ForwardingManagerImpl( + channelManager: BrokerToControllerChannelManager +) extends ForwardingManager with Logging { + + /** + * Forward given request to the active controller. + * + * @param request request to be forwarded + * @param responseCallback callback which takes in an `Option[AbstractResponse]`, where + * None is indicating that controller doesn't support the request + * version. 
+ */ + override def forwardRequest( + request: RequestChannel.Request, + responseCallback: Option[AbstractResponse] => Unit + ): Unit = { + val requestBuffer = request.buffer.duplicate() + requestBuffer.flip() + val envelopeRequest = ForwardingManager.buildEnvelopeRequest(request.context, requestBuffer) + + class ForwardingResponseHandler extends ControllerRequestCompletionHandler { + override def onComplete(clientResponse: ClientResponse): Unit = { + val requestBody = request.body[AbstractRequest] + + if (clientResponse.versionMismatch != null) { + debug(s"Returning `UNKNOWN_SERVER_ERROR` in response to request $requestBody " + + s"due to unexpected version error", clientResponse.versionMismatch) + responseCallback(Some(requestBody.getErrorResponse(Errors.UNKNOWN_SERVER_ERROR.exception))) + } else if (clientResponse.authenticationException != null) { + debug(s"Returning `UNKNOWN_SERVER_ERROR` in response to request $requestBody " + + s"due to authentication error", clientResponse.authenticationException) + responseCallback(Some(requestBody.getErrorResponse(Errors.UNKNOWN_SERVER_ERROR.exception))) + } else { + val envelopeResponse = clientResponse.responseBody.asInstanceOf[EnvelopeResponse] + val envelopeError = envelopeResponse.error() + + // Unsupported version indicates an incompatibility between controller and client API versions. This + // could happen when the controller changed after the connection was established. The forwarding broker + // should close the connection with the client and let it reinitialize the connection and refresh + // the controller API versions. + if (envelopeError == Errors.UNSUPPORTED_VERSION) { + responseCallback(None) + } else { + val response = if (envelopeError != Errors.NONE) { + // A general envelope error indicates broker misconfiguration (e.g. the principal serde + // might not be defined on the receiving broker). In this case, we do not return + // the error directly to the client since it would not be expected. Instead we + // return `UNKNOWN_SERVER_ERROR` so that the user knows that there is a problem + // on the broker. 
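/*
 * Illustrative aside, not part of this patch: the branches above and below reduce to a small
 * decision table on the envelope error. Restated as a hypothetical pure function
 * (`ForwardOutcome` and `outcome` are names introduced for the sketch; the real code
 * expresses the same choice through `responseCallback`):
 *
 *   import org.apache.kafka.common.protocol.Errors
 *
 *   sealed trait ForwardOutcome
 *   case object CloseClientConnection    extends ForwardOutcome // UNSUPPORTED_VERSION: let the client reconnect
 *   case object ReturnUnknownServerError extends ForwardOutcome // any other envelope error: broker-side problem
 *   case object ParseWrappedResponse     extends ForwardOutcome // NONE: unwrap and return the real response
 *
 *   object ForwardOutcomeSketch {
 *     def outcome(envelopeError: Errors): ForwardOutcome = envelopeError match {
 *       case Errors.UNSUPPORTED_VERSION => CloseClientConnection
 *       case Errors.NONE                => ParseWrappedResponse
 *       case _                          => ReturnUnknownServerError
 *     }
 *   }
 */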
+ debug(s"Forwarded request $request failed with an error in the envelope response $envelopeError") + requestBody.getErrorResponse(Errors.UNKNOWN_SERVER_ERROR.exception) + } else { + parseResponse(envelopeResponse.responseData, requestBody, request.header) + } + responseCallback(Option(response)) + } + } + } + + override def onTimeout(): Unit = { + debug(s"Forwarding of the request $request failed due to timeout exception") + val response = request.body[AbstractRequest].getErrorResponse(new TimeoutException()) + responseCallback(Option(response)) + } + } + + channelManager.sendRequest(envelopeRequest, new ForwardingResponseHandler) + } + + override def controllerApiVersions: Option[NodeApiVersions] = + channelManager.controllerApiVersions() + + private def parseResponse( + buffer: ByteBuffer, + request: AbstractRequest, + header: RequestHeader + ): AbstractResponse = { + try { + AbstractResponse.parseResponse(buffer, header) + } catch { + case e: Exception => + error(s"Failed to parse response from envelope for request with header $header", e) + request.getErrorResponse(Errors.UNKNOWN_SERVER_ERROR.exception) + } + } +} diff --git a/core/src/main/scala/kafka/server/KafkaApis.scala b/core/src/main/scala/kafka/server/KafkaApis.scala new file mode 100644 index 0000000..a4b38be --- /dev/null +++ b/core/src/main/scala/kafka/server/KafkaApis.scala @@ -0,0 +1,3527 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package kafka.server + +import kafka.admin.AdminUtils +import kafka.api.{ApiVersion, ElectLeadersRequestOps, KAFKA_0_11_0_IV0, KAFKA_2_3_IV0} +import kafka.common.OffsetAndMetadata +import kafka.controller.ReplicaAssignment +import kafka.coordinator.group._ +import kafka.coordinator.transaction.{InitProducerIdResult, TransactionCoordinator} +import kafka.log.AppendOrigin +import kafka.message.ZStdCompressionCodec +import kafka.network.RequestChannel +import kafka.server.QuotaFactory.{QuotaManagers, UnboundedQuota} +import kafka.server.metadata.ConfigRepository +import kafka.utils.Implicits._ +import kafka.utils.{CoreUtils, Logging} +import org.apache.kafka.clients.admin.AlterConfigOp.OpType +import org.apache.kafka.clients.admin.{AlterConfigOp, ConfigEntry} +import org.apache.kafka.common.acl.AclOperation._ +import org.apache.kafka.common.acl.AclOperation +import org.apache.kafka.common.config.ConfigResource +import org.apache.kafka.common.errors._ +import org.apache.kafka.common.internals.Topic.{GROUP_METADATA_TOPIC_NAME, TRANSACTION_STATE_TOPIC_NAME, isInternal} +import org.apache.kafka.common.internals.{FatalExitError, Topic} +import org.apache.kafka.common.message.AlterConfigsResponseData.AlterConfigsResourceResponse +import org.apache.kafka.common.message.AlterPartitionReassignmentsResponseData.{ReassignablePartitionResponse, ReassignableTopicResponse} +import org.apache.kafka.common.message.CreatePartitionsResponseData.CreatePartitionsTopicResult +import org.apache.kafka.common.message.CreateTopicsRequestData.CreatableTopic +import org.apache.kafka.common.message.CreateTopicsResponseData.{CreatableTopicResult, CreatableTopicResultCollection} +import org.apache.kafka.common.message.DeleteGroupsResponseData.{DeletableGroupResult, DeletableGroupResultCollection} +import org.apache.kafka.common.message.DeleteRecordsResponseData.{DeleteRecordsPartitionResult, DeleteRecordsTopicResult} +import org.apache.kafka.common.message.DeleteTopicsResponseData.{DeletableTopicResult, DeletableTopicResultCollection} +import org.apache.kafka.common.message.ElectLeadersResponseData.{PartitionResult, ReplicaElectionResult} +import org.apache.kafka.common.message.LeaveGroupResponseData.MemberResponse +import org.apache.kafka.common.message.ListOffsetsRequestData.ListOffsetsPartition +import org.apache.kafka.common.message.ListOffsetsResponseData.{ListOffsetsPartitionResponse, ListOffsetsTopicResponse} +import org.apache.kafka.common.message.MetadataResponseData.{MetadataResponsePartition, MetadataResponseTopic} +import org.apache.kafka.common.message.OffsetForLeaderEpochRequestData.OffsetForLeaderTopic +import org.apache.kafka.common.message.OffsetForLeaderEpochResponseData.{EpochEndOffset, OffsetForLeaderTopicResult, OffsetForLeaderTopicResultCollection} +import org.apache.kafka.common.message._ +import org.apache.kafka.common.metrics.Metrics +import org.apache.kafka.common.network.{ListenerName, Send} +import org.apache.kafka.common.protocol.{ApiKeys, Errors} +import org.apache.kafka.common.record._ +import org.apache.kafka.common.replica.ClientMetadata +import org.apache.kafka.common.replica.ClientMetadata.DefaultClientMetadata +import org.apache.kafka.common.requests.FindCoordinatorRequest.CoordinatorType +import org.apache.kafka.common.requests.OffsetFetchResponse.PartitionData +import org.apache.kafka.common.requests.ProduceResponse.PartitionResponse +import org.apache.kafka.common.requests._ +import org.apache.kafka.common.resource.Resource.CLUSTER_NAME +import 
org.apache.kafka.common.resource.ResourceType._ +import org.apache.kafka.common.resource.{Resource, ResourceType} +import org.apache.kafka.common.security.auth.{KafkaPrincipal, SecurityProtocol} +import org.apache.kafka.common.security.token.delegation.{DelegationToken, TokenInformation} +import org.apache.kafka.common.utils.{ProducerIdAndEpoch, Time} +import org.apache.kafka.common.{Node, TopicIdPartition, TopicPartition, Uuid} +import org.apache.kafka.server.authorizer._ +import java.lang.{Long => JLong} +import java.nio.ByteBuffer +import java.util +import java.util.concurrent.ConcurrentHashMap +import java.util.concurrent.atomic.AtomicInteger +import java.util.{Collections, Optional} + +import scala.annotation.nowarn +import scala.collection.{Map, Seq, Set, immutable, mutable} +import scala.jdk.CollectionConverters._ +import scala.util.{Failure, Success, Try} + +/** + * Logic to handle the various Kafka requests + */ +class KafkaApis(val requestChannel: RequestChannel, + val metadataSupport: MetadataSupport, + val replicaManager: ReplicaManager, + val groupCoordinator: GroupCoordinator, + val txnCoordinator: TransactionCoordinator, + val autoTopicCreationManager: AutoTopicCreationManager, + val brokerId: Int, + val config: KafkaConfig, + val configRepository: ConfigRepository, + val metadataCache: MetadataCache, + val metrics: Metrics, + val authorizer: Option[Authorizer], + val quotas: QuotaManagers, + val fetchManager: FetchManager, + brokerTopicStats: BrokerTopicStats, + val clusterId: String, + time: Time, + val tokenManager: DelegationTokenManager, + val apiVersionManager: ApiVersionManager) extends ApiRequestHandler with Logging { + + type FetchResponseStats = Map[TopicPartition, RecordConversionStats] + this.logIdent = "[KafkaApi-%d] ".format(brokerId) + val configHelper = new ConfigHelper(metadataCache, config, configRepository) + val authHelper = new AuthHelper(authorizer) + val requestHelper = new RequestHandlerHelper(requestChannel, quotas, time) + val aclApis = new AclApis(authHelper, authorizer, requestHelper, "broker", config) + + def close(): Unit = { + aclApis.close() + info("Shutdown complete.") + } + + private def isForwardingEnabled(request: RequestChannel.Request): Boolean = { + metadataSupport.forwardingManager.isDefined && request.context.principalSerde.isPresent + } + + private def maybeForwardToController( + request: RequestChannel.Request, + handler: RequestChannel.Request => Unit + ): Unit = { + def responseCallback(responseOpt: Option[AbstractResponse]): Unit = { + responseOpt match { + case Some(response) => requestHelper.sendForwardedResponse(request, response) + case None => + info(s"The client connection will be closed due to controller responded " + + s"unsupported version exception during $request forwarding. 
" + + s"This could happen when the controller changed after the connection was established.") + requestChannel.closeConnection(request, Collections.emptyMap()) + } + } + + metadataSupport.maybeForward(request, handler, responseCallback) + } + + private def forwardToControllerOrFail( + request: RequestChannel.Request + ): Unit = { + def errorHandler(request: RequestChannel.Request): Unit = { + throw new IllegalStateException(s"Unable to forward $request to the controller") + } + + maybeForwardToController(request, errorHandler) + } + + /** + * Top-level method that handles all requests and multiplexes to the right api + */ + override def handle(request: RequestChannel.Request, requestLocal: RequestLocal): Unit = { + try { + trace(s"Handling request:${request.requestDesc(true)} from connection ${request.context.connectionId};" + + s"securityProtocol:${request.context.securityProtocol},principal:${request.context.principal}") + + if (!apiVersionManager.isApiEnabled(request.header.apiKey)) { + // The socket server will reject APIs which are not exposed in this scope and close the connection + // before handing them to the request handler, so this path should not be exercised in practice + throw new IllegalStateException(s"API ${request.header.apiKey} is not enabled") + } + + request.header.apiKey match { + case ApiKeys.PRODUCE => handleProduceRequest(request, requestLocal) + case ApiKeys.FETCH => handleFetchRequest(request) + case ApiKeys.LIST_OFFSETS => handleListOffsetRequest(request) + case ApiKeys.METADATA => handleTopicMetadataRequest(request) + case ApiKeys.LEADER_AND_ISR => handleLeaderAndIsrRequest(request) + case ApiKeys.STOP_REPLICA => handleStopReplicaRequest(request) + case ApiKeys.UPDATE_METADATA => handleUpdateMetadataRequest(request, requestLocal) + case ApiKeys.CONTROLLED_SHUTDOWN => handleControlledShutdownRequest(request) + case ApiKeys.OFFSET_COMMIT => handleOffsetCommitRequest(request, requestLocal) + case ApiKeys.OFFSET_FETCH => handleOffsetFetchRequest(request) + case ApiKeys.FIND_COORDINATOR => handleFindCoordinatorRequest(request) + case ApiKeys.JOIN_GROUP => handleJoinGroupRequest(request, requestLocal) + case ApiKeys.HEARTBEAT => handleHeartbeatRequest(request) + case ApiKeys.LEAVE_GROUP => handleLeaveGroupRequest(request) + case ApiKeys.SYNC_GROUP => handleSyncGroupRequest(request, requestLocal) + case ApiKeys.DESCRIBE_GROUPS => handleDescribeGroupRequest(request) + case ApiKeys.LIST_GROUPS => handleListGroupsRequest(request) + case ApiKeys.SASL_HANDSHAKE => handleSaslHandshakeRequest(request) + case ApiKeys.API_VERSIONS => handleApiVersionsRequest(request) + case ApiKeys.CREATE_TOPICS => maybeForwardToController(request, handleCreateTopicsRequest) + case ApiKeys.DELETE_TOPICS => maybeForwardToController(request, handleDeleteTopicsRequest) + case ApiKeys.DELETE_RECORDS => handleDeleteRecordsRequest(request) + case ApiKeys.INIT_PRODUCER_ID => handleInitProducerIdRequest(request, requestLocal) + case ApiKeys.OFFSET_FOR_LEADER_EPOCH => handleOffsetForLeaderEpochRequest(request) + case ApiKeys.ADD_PARTITIONS_TO_TXN => handleAddPartitionToTxnRequest(request, requestLocal) + case ApiKeys.ADD_OFFSETS_TO_TXN => handleAddOffsetsToTxnRequest(request, requestLocal) + case ApiKeys.END_TXN => handleEndTxnRequest(request, requestLocal) + case ApiKeys.WRITE_TXN_MARKERS => handleWriteTxnMarkersRequest(request, requestLocal) + case ApiKeys.TXN_OFFSET_COMMIT => handleTxnOffsetCommitRequest(request, requestLocal) + case ApiKeys.DESCRIBE_ACLS => handleDescribeAcls(request) + case 
ApiKeys.CREATE_ACLS => maybeForwardToController(request, handleCreateAcls) + case ApiKeys.DELETE_ACLS => maybeForwardToController(request, handleDeleteAcls) + case ApiKeys.ALTER_CONFIGS => maybeForwardToController(request, handleAlterConfigsRequest) + case ApiKeys.DESCRIBE_CONFIGS => handleDescribeConfigsRequest(request) + case ApiKeys.ALTER_REPLICA_LOG_DIRS => handleAlterReplicaLogDirsRequest(request) + case ApiKeys.DESCRIBE_LOG_DIRS => handleDescribeLogDirsRequest(request) + case ApiKeys.SASL_AUTHENTICATE => handleSaslAuthenticateRequest(request) + case ApiKeys.CREATE_PARTITIONS => maybeForwardToController(request, handleCreatePartitionsRequest) + case ApiKeys.CREATE_DELEGATION_TOKEN => maybeForwardToController(request, handleCreateTokenRequest) + case ApiKeys.RENEW_DELEGATION_TOKEN => maybeForwardToController(request, handleRenewTokenRequest) + case ApiKeys.EXPIRE_DELEGATION_TOKEN => maybeForwardToController(request, handleExpireTokenRequest) + case ApiKeys.DESCRIBE_DELEGATION_TOKEN => handleDescribeTokensRequest(request) + case ApiKeys.DELETE_GROUPS => handleDeleteGroupsRequest(request, requestLocal) + case ApiKeys.ELECT_LEADERS => maybeForwardToController(request, handleElectLeaders) + case ApiKeys.INCREMENTAL_ALTER_CONFIGS => maybeForwardToController(request, handleIncrementalAlterConfigsRequest) + case ApiKeys.ALTER_PARTITION_REASSIGNMENTS => maybeForwardToController(request, handleAlterPartitionReassignmentsRequest) + case ApiKeys.LIST_PARTITION_REASSIGNMENTS => maybeForwardToController(request, handleListPartitionReassignmentsRequest) + case ApiKeys.OFFSET_DELETE => handleOffsetDeleteRequest(request, requestLocal) + case ApiKeys.DESCRIBE_CLIENT_QUOTAS => handleDescribeClientQuotasRequest(request) + case ApiKeys.ALTER_CLIENT_QUOTAS => maybeForwardToController(request, handleAlterClientQuotasRequest) + case ApiKeys.DESCRIBE_USER_SCRAM_CREDENTIALS => handleDescribeUserScramCredentialsRequest(request) + case ApiKeys.ALTER_USER_SCRAM_CREDENTIALS => maybeForwardToController(request, handleAlterUserScramCredentialsRequest) + case ApiKeys.ALTER_ISR => handleAlterIsrRequest(request) + case ApiKeys.UPDATE_FEATURES => maybeForwardToController(request, handleUpdateFeatures) + case ApiKeys.ENVELOPE => handleEnvelope(request, requestLocal) + case ApiKeys.DESCRIBE_CLUSTER => handleDescribeCluster(request) + case ApiKeys.DESCRIBE_PRODUCERS => handleDescribeProducersRequest(request) + case ApiKeys.DESCRIBE_TRANSACTIONS => handleDescribeTransactionsRequest(request) + case ApiKeys.LIST_TRANSACTIONS => handleListTransactionsRequest(request) + case ApiKeys.ALLOCATE_PRODUCER_IDS => handleAllocateProducerIdsRequest(request) + case ApiKeys.DESCRIBE_QUORUM => forwardToControllerOrFail(request) + case _ => throw new IllegalStateException(s"No handler for request api key ${request.header.apiKey}") + } + } catch { + case e: FatalExitError => throw e + case e: Throwable => + error(s"Unexpected error handling request ${request.requestDesc(true)} " + + s"with context ${request.context}", e) + requestHelper.handleError(request, e) + } finally { + // try to complete delayed action. In order to avoid conflicting locking, the actions to complete delayed requests + // are kept in a queue. We add the logic to check the ReplicaManager queue at the end of KafkaApis.handle() and the + // expiration thread for certain delayed operations (e.g. DelayedJoin) + replicaManager.tryCompleteActions() + // The local completion time may be set while processing the request. Only record it if it's unset. 
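/*
 * Illustrative aside, not part of this patch (the completion-time check that the comment
 * above refers to continues right after this block). handle() follows a dispatch-and-recover
 * shape: look up the handler for the api key, rethrow fatal errors so the broker exits,
 * turn every other failure into an error response, and always run the delayed-action and
 * timing bookkeeping in `finally`. A compressed, hypothetical sketch of that shape
 * (`dispatch`, `FatalError` and the String key are stand-ins for the real
 * ApiKeys / FatalExitError types):
 *
 *   object DispatchSketch {
 *     final class FatalError extends RuntimeException
 *
 *     def dispatch(apiName: String,
 *                  handlers: Map[String, () => Unit],
 *                  onError: Throwable => Unit,
 *                  afterEach: () => Unit): Unit = {
 *       try {
 *         handlers.getOrElse(apiName,
 *           throw new IllegalStateException(s"No handler for request api key $apiName")).apply()
 *       } catch {
 *         case e: FatalError => throw e    // fatal: propagate so the process exits
 *         case e: Throwable  => onError(e) // everything else becomes an error response
 *       } finally {
 *         afterEach()                      // e.g. try to complete delayed actions
 *       }
 *     }
 *   }
 */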
+ if (request.apiLocalCompleteTimeNanos < 0) + request.apiLocalCompleteTimeNanos = time.nanoseconds + } + } + + def handleLeaderAndIsrRequest(request: RequestChannel.Request): Unit = { + val zkSupport = metadataSupport.requireZkOrThrow(KafkaApis.shouldNeverReceive(request)) + // ensureTopicExists is only for client facing requests + // We can't have the ensureTopicExists check here since the controller sends it as an advisory to all brokers so they + // stop serving data to clients for the topic being deleted + val correlationId = request.header.correlationId + val leaderAndIsrRequest = request.body[LeaderAndIsrRequest] + + authHelper.authorizeClusterOperation(request, CLUSTER_ACTION) + if (isBrokerEpochStale(zkSupport, leaderAndIsrRequest.brokerEpoch)) { + // When the broker restarts very quickly, it is possible for this broker to receive request intended + // for its previous generation so the broker should skip the stale request. + info("Received LeaderAndIsr request with broker epoch " + + s"${leaderAndIsrRequest.brokerEpoch} smaller than the current broker epoch ${zkSupport.controller.brokerEpoch}") + requestHelper.sendResponseExemptThrottle(request, leaderAndIsrRequest.getErrorResponse(0, Errors.STALE_BROKER_EPOCH.exception)) + } else { + val response = replicaManager.becomeLeaderOrFollower(correlationId, leaderAndIsrRequest, + RequestHandlerHelper.onLeadershipChange(groupCoordinator, txnCoordinator, _, _)) + requestHelper.sendResponseExemptThrottle(request, response) + } + } + + def handleStopReplicaRequest(request: RequestChannel.Request): Unit = { + val zkSupport = metadataSupport.requireZkOrThrow(KafkaApis.shouldNeverReceive(request)) + // ensureTopicExists is only for client facing requests + // We can't have the ensureTopicExists check here since the controller sends it as an advisory to all brokers so they + // stop serving data to clients for the topic being deleted + val stopReplicaRequest = request.body[StopReplicaRequest] + authHelper.authorizeClusterOperation(request, CLUSTER_ACTION) + if (isBrokerEpochStale(zkSupport, stopReplicaRequest.brokerEpoch)) { + // When the broker restarts very quickly, it is possible for this broker to receive request intended + // for its previous generation so the broker should skip the stale request. + info("Received StopReplica request with broker epoch " + + s"${stopReplicaRequest.brokerEpoch} smaller than the current broker epoch ${zkSupport.controller.brokerEpoch}") + requestHelper.sendResponseExemptThrottle(request, new StopReplicaResponse( + new StopReplicaResponseData().setErrorCode(Errors.STALE_BROKER_EPOCH.code))) + } else { + val partitionStates = stopReplicaRequest.partitionStates().asScala + val (result, error) = replicaManager.stopReplicas( + request.context.correlationId, + stopReplicaRequest.controllerId, + stopReplicaRequest.controllerEpoch, + partitionStates) + // Clear the coordinator caches in case we were the leader. In the case of a reassignment, we + // cannot rely on the LeaderAndIsr API for this since it is only sent to active replicas. 
+ result.forKeyValue { (topicPartition, error) => + if (error == Errors.NONE) { + val partitionState = partitionStates(topicPartition) + if (topicPartition.topic == GROUP_METADATA_TOPIC_NAME + && partitionState.deletePartition) { + val leaderEpoch = if (partitionState.leaderEpoch >= 0) + Some(partitionState.leaderEpoch) + else + None + groupCoordinator.onResignation(topicPartition.partition, leaderEpoch) + } else if (topicPartition.topic == TRANSACTION_STATE_TOPIC_NAME + && partitionState.deletePartition) { + val leaderEpoch = if (partitionState.leaderEpoch >= 0) + Some(partitionState.leaderEpoch) + else + None + txnCoordinator.onResignation(topicPartition.partition, coordinatorEpoch = leaderEpoch) + } + } + } + + def toStopReplicaPartition(tp: TopicPartition, error: Errors) = + new StopReplicaResponseData.StopReplicaPartitionError() + .setTopicName(tp.topic) + .setPartitionIndex(tp.partition) + .setErrorCode(error.code) + + requestHelper.sendResponseExemptThrottle(request, new StopReplicaResponse(new StopReplicaResponseData() + .setErrorCode(error.code) + .setPartitionErrors(result.map { + case (tp, error) => toStopReplicaPartition(tp, error) + }.toBuffer.asJava))) + } + + CoreUtils.swallow(replicaManager.replicaFetcherManager.shutdownIdleFetcherThreads(), this) + } + + def handleUpdateMetadataRequest(request: RequestChannel.Request, requestLocal: RequestLocal): Unit = { + val zkSupport = metadataSupport.requireZkOrThrow(KafkaApis.shouldNeverReceive(request)) + val correlationId = request.header.correlationId + val updateMetadataRequest = request.body[UpdateMetadataRequest] + + authHelper.authorizeClusterOperation(request, CLUSTER_ACTION) + if (isBrokerEpochStale(zkSupport, updateMetadataRequest.brokerEpoch)) { + // When the broker restarts very quickly, it is possible for this broker to receive request intended + // for its previous generation so the broker should skip the stale request. 
+ info("Received update metadata request with broker epoch " + + s"${updateMetadataRequest.brokerEpoch} smaller than the current broker epoch ${zkSupport.controller.brokerEpoch}") + requestHelper.sendResponseExemptThrottle(request, + new UpdateMetadataResponse(new UpdateMetadataResponseData().setErrorCode(Errors.STALE_BROKER_EPOCH.code))) + } else { + val deletedPartitions = replicaManager.maybeUpdateMetadataCache(correlationId, updateMetadataRequest) + if (deletedPartitions.nonEmpty) + groupCoordinator.handleDeletedPartitions(deletedPartitions, requestLocal) + + if (zkSupport.adminManager.hasDelayedTopicOperations) { + updateMetadataRequest.partitionStates.forEach { partitionState => + zkSupport.adminManager.tryCompleteDelayedTopicOperations(partitionState.topicName) + } + } + + quotas.clientQuotaCallback.foreach { callback => + if (callback.updateClusterMetadata(metadataCache.getClusterMetadata(clusterId, request.context.listenerName))) { + quotas.fetch.updateQuotaMetricConfigs() + quotas.produce.updateQuotaMetricConfigs() + quotas.request.updateQuotaMetricConfigs() + quotas.controllerMutation.updateQuotaMetricConfigs() + } + } + if (replicaManager.hasDelayedElectionOperations) { + updateMetadataRequest.partitionStates.forEach { partitionState => + val tp = new TopicPartition(partitionState.topicName, partitionState.partitionIndex) + replicaManager.tryCompleteElection(TopicPartitionOperationKey(tp)) + } + } + requestHelper.sendResponseExemptThrottle(request, new UpdateMetadataResponse( + new UpdateMetadataResponseData().setErrorCode(Errors.NONE.code))) + } + } + + def handleControlledShutdownRequest(request: RequestChannel.Request): Unit = { + val zkSupport = metadataSupport.requireZkOrThrow(KafkaApis.shouldNeverReceive(request)) + // ensureTopicExists is only for client facing requests + // We can't have the ensureTopicExists check here since the controller sends it as an advisory to all brokers so they + // stop serving data to clients for the topic being deleted + val controlledShutdownRequest = request.body[ControlledShutdownRequest] + authHelper.authorizeClusterOperation(request, CLUSTER_ACTION) + + def controlledShutdownCallback(controlledShutdownResult: Try[Set[TopicPartition]]): Unit = { + val response = controlledShutdownResult match { + case Success(partitionsRemaining) => + ControlledShutdownResponse.prepareResponse(Errors.NONE, partitionsRemaining.asJava) + + case Failure(throwable) => + controlledShutdownRequest.getErrorResponse(throwable) + } + requestHelper.sendResponseExemptThrottle(request, response) + } + zkSupport.controller.controlledShutdown(controlledShutdownRequest.data.brokerId, controlledShutdownRequest.data.brokerEpoch, controlledShutdownCallback) + } + + /** + * Handle an offset commit request + */ + def handleOffsetCommitRequest(request: RequestChannel.Request, requestLocal: RequestLocal): Unit = { + val header = request.header + val offsetCommitRequest = request.body[OffsetCommitRequest] + + val unauthorizedTopicErrors = mutable.Map[TopicPartition, Errors]() + val nonExistingTopicErrors = mutable.Map[TopicPartition, Errors]() + // the callback for sending an offset commit response + def sendResponseCallback(commitStatus: Map[TopicPartition, Errors]): Unit = { + val combinedCommitStatus = commitStatus ++ unauthorizedTopicErrors ++ nonExistingTopicErrors + if (isDebugEnabled) + combinedCommitStatus.forKeyValue { (topicPartition, error) => + if (error != Errors.NONE) { + debug(s"Offset commit request with correlation id ${header.correlationId} from client 
${header.clientId} " + + s"on partition $topicPartition failed due to ${error.exceptionName}") + } + } + requestHelper.sendResponseMaybeThrottle(request, requestThrottleMs => + new OffsetCommitResponse(requestThrottleMs, combinedCommitStatus.asJava)) + } + + // reject the request if not authorized to the group + if (!authHelper.authorize(request.context, READ, GROUP, offsetCommitRequest.data.groupId)) { + val error = Errors.GROUP_AUTHORIZATION_FAILED + val responseTopicList = OffsetCommitRequest.getErrorResponseTopics( + offsetCommitRequest.data.topics, + error) + + requestHelper.sendResponseMaybeThrottle(request, requestThrottleMs => new OffsetCommitResponse( + new OffsetCommitResponseData() + .setTopics(responseTopicList) + .setThrottleTimeMs(requestThrottleMs) + )) + } else if (offsetCommitRequest.data.groupInstanceId != null && config.interBrokerProtocolVersion < KAFKA_2_3_IV0) { + // Only enable static membership when IBP >= 2.3, because it is not safe for the broker to use the static member logic + // until we are sure that all brokers support it. If static group being loaded by an older coordinator, it will discard + // the group.instance.id field, so static members could accidentally become "dynamic", which leads to wrong states. + val errorMap = new mutable.HashMap[TopicPartition, Errors] + for (topicData <- offsetCommitRequest.data.topics.asScala) { + for (partitionData <- topicData.partitions.asScala) { + val topicPartition = new TopicPartition(topicData.name, partitionData.partitionIndex) + errorMap += topicPartition -> Errors.UNSUPPORTED_VERSION + } + } + sendResponseCallback(errorMap.toMap) + } else { + val authorizedTopicRequestInfoBldr = immutable.Map.newBuilder[TopicPartition, OffsetCommitRequestData.OffsetCommitRequestPartition] + + val topics = offsetCommitRequest.data.topics.asScala + val authorizedTopics = authHelper.filterByAuthorized(request.context, READ, TOPIC, topics)(_.name) + for (topicData <- topics) { + for (partitionData <- topicData.partitions.asScala) { + val topicPartition = new TopicPartition(topicData.name, partitionData.partitionIndex) + if (!authorizedTopics.contains(topicData.name)) + unauthorizedTopicErrors += (topicPartition -> Errors.TOPIC_AUTHORIZATION_FAILED) + else if (!metadataCache.contains(topicPartition)) + nonExistingTopicErrors += (topicPartition -> Errors.UNKNOWN_TOPIC_OR_PARTITION) + else + authorizedTopicRequestInfoBldr += (topicPartition -> partitionData) + } + } + + val authorizedTopicRequestInfo = authorizedTopicRequestInfoBldr.result() + + if (authorizedTopicRequestInfo.isEmpty) + sendResponseCallback(Map.empty) + else if (header.apiVersion == 0) { + // for version 0 always store offsets to ZK + val zkSupport = metadataSupport.requireZkOrThrow(KafkaApis.unsupported("Version 0 offset commit requests")) + val responseInfo = authorizedTopicRequestInfo.map { + case (topicPartition, partitionData) => + try { + if (partitionData.committedMetadata() != null + && partitionData.committedMetadata().length > config.offsetMetadataMaxSize) + (topicPartition, Errors.OFFSET_METADATA_TOO_LARGE) + else { + zkSupport.zkClient.setOrCreateConsumerOffset( + offsetCommitRequest.data.groupId, + topicPartition, + partitionData.committedOffset) + (topicPartition, Errors.NONE) + } + } catch { + case e: Throwable => (topicPartition, Errors.forException(e)) + } + } + sendResponseCallback(responseInfo) + } else { + // for version 1 and beyond store offsets in offset manager + + // "default" expiration timestamp is now + retention (and retention may be 
overridden if v2) + // expire timestamp is computed differently for v1 and v2. + // - If v1 and no explicit commit timestamp is provided we treat it the same as v5. + // - If v1 and explicit retention time is provided we calculate expiration timestamp based on that + // - If v2/v3/v4 (no explicit commit timestamp) we treat it the same as v5. + // - For v5 and beyond there is no per partition expiration timestamp, so this field is no longer in effect + val currentTimestamp = time.milliseconds + val partitionData = authorizedTopicRequestInfo.map { case (k, partitionData) => + val metadata = if (partitionData.committedMetadata == null) + OffsetAndMetadata.NoMetadata + else + partitionData.committedMetadata + + val leaderEpochOpt = if (partitionData.committedLeaderEpoch == RecordBatch.NO_PARTITION_LEADER_EPOCH) + Optional.empty[Integer] + else + Optional.of[Integer](partitionData.committedLeaderEpoch) + + k -> new OffsetAndMetadata( + offset = partitionData.committedOffset, + leaderEpoch = leaderEpochOpt, + metadata = metadata, + commitTimestamp = partitionData.commitTimestamp match { + case OffsetCommitRequest.DEFAULT_TIMESTAMP => currentTimestamp + case customTimestamp => customTimestamp + }, + expireTimestamp = offsetCommitRequest.data.retentionTimeMs match { + case OffsetCommitRequest.DEFAULT_RETENTION_TIME => None + case retentionTime => Some(currentTimestamp + retentionTime) + } + ) + } + + // call coordinator to handle commit offset + groupCoordinator.handleCommitOffsets( + offsetCommitRequest.data.groupId, + offsetCommitRequest.data.memberId, + Option(offsetCommitRequest.data.groupInstanceId), + offsetCommitRequest.data.generationId, + partitionData, + sendResponseCallback, + requestLocal) + } + } + } + + /** + * Handle a produce request + */ + def handleProduceRequest(request: RequestChannel.Request, requestLocal: RequestLocal): Unit = { + val produceRequest = request.body[ProduceRequest] + val requestSize = request.sizeInBytes + + if (RequestUtils.hasTransactionalRecords(produceRequest)) { + val isAuthorizedTransactional = produceRequest.transactionalId != null && + authHelper.authorize(request.context, WRITE, TRANSACTIONAL_ID, produceRequest.transactionalId) + if (!isAuthorizedTransactional) { + requestHelper.sendErrorResponseMaybeThrottle(request, Errors.TRANSACTIONAL_ID_AUTHORIZATION_FAILED.exception) + return + } + } + + val unauthorizedTopicResponses = mutable.Map[TopicPartition, PartitionResponse]() + val nonExistingTopicResponses = mutable.Map[TopicPartition, PartitionResponse]() + val invalidRequestResponses = mutable.Map[TopicPartition, PartitionResponse]() + val authorizedRequestInfo = mutable.Map[TopicPartition, MemoryRecords]() + // cache the result to avoid redundant authorization calls + val authorizedTopics = authHelper.filterByAuthorized(request.context, WRITE, TOPIC, + produceRequest.data().topicData().asScala)(_.name()) + + produceRequest.data.topicData.forEach(topic => topic.partitionData.forEach { partition => + val topicPartition = new TopicPartition(topic.name, partition.index) + // This caller assumes the type is MemoryRecords and that is true on current serialization + // We cast the type to avoid causing big change to code base. 
+ // https://issues.apache.org/jira/browse/KAFKA-10698 + val memoryRecords = partition.records.asInstanceOf[MemoryRecords] + if (!authorizedTopics.contains(topicPartition.topic)) + unauthorizedTopicResponses += topicPartition -> new PartitionResponse(Errors.TOPIC_AUTHORIZATION_FAILED) + else if (!metadataCache.contains(topicPartition)) + nonExistingTopicResponses += topicPartition -> new PartitionResponse(Errors.UNKNOWN_TOPIC_OR_PARTITION) + else + try { + ProduceRequest.validateRecords(request.header.apiVersion, memoryRecords) + authorizedRequestInfo += (topicPartition -> memoryRecords) + } catch { + case e: ApiException => + invalidRequestResponses += topicPartition -> new PartitionResponse(Errors.forException(e)) + } + }) + + // the callback for sending a produce response + // The construction of ProduceResponse is able to accept auto-generated protocol data so + // KafkaApis#handleProduceRequest should apply auto-generated protocol to avoid extra conversion. + // https://issues.apache.org/jira/browse/KAFKA-10730 + @nowarn("cat=deprecation") + def sendResponseCallback(responseStatus: Map[TopicPartition, PartitionResponse]): Unit = { + val mergedResponseStatus = responseStatus ++ unauthorizedTopicResponses ++ nonExistingTopicResponses ++ invalidRequestResponses + var errorInResponse = false + + mergedResponseStatus.forKeyValue { (topicPartition, status) => + if (status.error != Errors.NONE) { + errorInResponse = true + debug("Produce request with correlation id %d from client %s on partition %s failed due to %s".format( + request.header.correlationId, + request.header.clientId, + topicPartition, + status.error.exceptionName)) + } + } + + // Record both bandwidth and request quota-specific values and throttle by muting the channel if any of the quotas + // have been violated. If both quotas have been violated, use the max throttle time between the two quotas. Note + // that the request quota is not enforced if acks == 0. + val timeMs = time.milliseconds() + val bandwidthThrottleTimeMs = quotas.produce.maybeRecordAndGetThrottleTimeMs(request, requestSize, timeMs) + val requestThrottleTimeMs = + if (produceRequest.acks == 0) 0 + else quotas.request.maybeRecordAndGetThrottleTimeMs(request, timeMs) + val maxThrottleTimeMs = Math.max(bandwidthThrottleTimeMs, requestThrottleTimeMs) + if (maxThrottleTimeMs > 0) { + request.apiThrottleTimeMs = maxThrottleTimeMs + if (bandwidthThrottleTimeMs > requestThrottleTimeMs) { + requestHelper.throttle(quotas.produce, request, bandwidthThrottleTimeMs) + } else { + requestHelper.throttle(quotas.request, request, requestThrottleTimeMs) + } + } + + // Send the response immediately. In case of throttling, the channel has already been muted. 
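/*
 * Illustrative aside, not part of this patch: the quota logic above mutes the channel for the
 * larger of the bandwidth and request throttle times and charges the delay to whichever quota
 * produced it (with acks == 0 exempt from the request quota, handled in the branch below:
 * no response is sent, and the connection is closed if any partition failed). A minimal
 * restatement of the "pick the larger throttle" choice; `pickThrottle` and the label strings
 * are hypothetical:
 *
 *   object ThrottleSketch {
 *     def pickThrottle(bandwidthThrottleMs: Int, requestThrottleMs: Int): Option[(String, Int)] = {
 *       val maxThrottleMs = math.max(bandwidthThrottleMs, requestThrottleMs)
 *       if (maxThrottleMs <= 0) None // nothing violated, send immediately
 *       else if (bandwidthThrottleMs > requestThrottleMs) Some("bandwidth-quota" -> bandwidthThrottleMs)
 *       else Some("request-quota" -> requestThrottleMs)
 *     }
 *   }
 */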
+ if (produceRequest.acks == 0) { + // no operation needed if producer request.required.acks = 0; however, if there is any error in handling + // the request, since no response is expected by the producer, the server will close socket server so that + // the producer client will know that some error has happened and will refresh its metadata + if (errorInResponse) { + val exceptionsSummary = mergedResponseStatus.map { case (topicPartition, status) => + topicPartition -> status.error.exceptionName + }.mkString(", ") + info( + s"Closing connection due to error during produce request with correlation id ${request.header.correlationId} " + + s"from client id ${request.header.clientId} with ack=0\n" + + s"Topic and partition to exceptions: $exceptionsSummary" + ) + requestChannel.closeConnection(request, new ProduceResponse(mergedResponseStatus.asJava).errorCounts) + } else { + // Note that although request throttling is exempt for acks == 0, the channel may be throttled due to + // bandwidth quota violation. + requestHelper.sendNoOpResponseExemptThrottle(request) + } + } else { + requestChannel.sendResponse(request, new ProduceResponse(mergedResponseStatus.asJava, maxThrottleTimeMs), None) + } + } + + def processingStatsCallback(processingStats: FetchResponseStats): Unit = { + processingStats.forKeyValue { (tp, info) => + updateRecordConversionStats(request, tp, info) + } + } + + if (authorizedRequestInfo.isEmpty) + sendResponseCallback(Map.empty) + else { + val internalTopicsAllowed = request.header.clientId == AdminUtils.AdminClientId + + // call the replica manager to append messages to the replicas + replicaManager.appendRecords( + timeout = produceRequest.timeout.toLong, + requiredAcks = produceRequest.acks, + internalTopicsAllowed = internalTopicsAllowed, + origin = AppendOrigin.Client, + entriesPerPartition = authorizedRequestInfo, + requestLocal = requestLocal, + responseCallback = sendResponseCallback, + recordConversionStatsCallback = processingStatsCallback) + + // if the request is put into the purgatory, it will have a held reference and hence cannot be garbage collected; + // hence we clear its data here in order to let GC reclaim its memory since it is already appended to log + produceRequest.clearPartitionRecords() + } + } + + /** + * Handle a fetch request + */ + def handleFetchRequest(request: RequestChannel.Request): Unit = { + val versionId = request.header.apiVersion + val clientId = request.header.clientId + val fetchRequest = request.body[FetchRequest] + val topicNames = + if (fetchRequest.version() >= 13) + metadataCache.topicIdsToNames() + else + Collections.emptyMap[Uuid, String]() + + val fetchData = fetchRequest.fetchData(topicNames) + val forgottenTopics = fetchRequest.forgottenTopics(topicNames) + + val fetchContext = fetchManager.newContext( + fetchRequest.version, + fetchRequest.metadata, + fetchRequest.isFromFollower, + fetchData, + forgottenTopics, + topicNames) + + val clientMetadata: Option[ClientMetadata] = if (versionId >= 11) { + // Fetch API version 11 added preferred replica logic + Some(new DefaultClientMetadata( + fetchRequest.rackId, + clientId, + request.context.clientAddress, + request.context.principal, + request.context.listenerName.value)) + } else { + None + } + + val erroneous = mutable.ArrayBuffer[(TopicIdPartition, FetchResponseData.PartitionData)]() + val interesting = mutable.ArrayBuffer[(TopicIdPartition, FetchRequest.PartitionData)]() + if (fetchRequest.isFromFollower) { + // The follower must have ClusterAction on ClusterResource in 
order to fetch partition data. + if (authHelper.authorize(request.context, CLUSTER_ACTION, CLUSTER, CLUSTER_NAME)) { + fetchContext.foreachPartition { (topicIdPartition, data) => + if (topicIdPartition.topic == null) + erroneous += topicIdPartition -> FetchResponse.partitionResponse(topicIdPartition, Errors.UNKNOWN_TOPIC_ID) + else if (!metadataCache.contains(topicIdPartition.topicPartition)) + erroneous += topicIdPartition -> FetchResponse.partitionResponse(topicIdPartition, Errors.UNKNOWN_TOPIC_OR_PARTITION) + else + interesting += topicIdPartition -> data + } + } else { + fetchContext.foreachPartition { (topicIdPartition, _) => + erroneous += topicIdPartition -> FetchResponse.partitionResponse(topicIdPartition, Errors.TOPIC_AUTHORIZATION_FAILED) + } + } + } else { + // Regular Kafka consumers need READ permission on each partition they are fetching. + val partitionDatas = new mutable.ArrayBuffer[(TopicIdPartition, FetchRequest.PartitionData)] + fetchContext.foreachPartition { (topicIdPartition, partitionData) => + if (topicIdPartition.topic == null) + erroneous += topicIdPartition -> FetchResponse.partitionResponse(topicIdPartition, Errors.UNKNOWN_TOPIC_ID) + else + partitionDatas += topicIdPartition -> partitionData + } + val authorizedTopics = authHelper.filterByAuthorized(request.context, READ, TOPIC, partitionDatas)(_._1.topicPartition.topic) + partitionDatas.foreach { case (topicIdPartition, data) => + if (!authorizedTopics.contains(topicIdPartition.topic)) + erroneous += topicIdPartition -> FetchResponse.partitionResponse(topicIdPartition, Errors.TOPIC_AUTHORIZATION_FAILED) + else if (!metadataCache.contains(topicIdPartition.topicPartition)) + erroneous += topicIdPartition -> FetchResponse.partitionResponse(topicIdPartition, Errors.UNKNOWN_TOPIC_OR_PARTITION) + else + interesting += topicIdPartition -> data + } + } + + def maybeDownConvertStorageError(error: Errors): Errors = { + // If consumer sends FetchRequest V5 or earlier, the client library is not guaranteed to recognize the error code + // for KafkaStorageException. In this case the client library will translate KafkaStorageException to + // UnknownServerException which is not retriable. We can ensure that consumer will update metadata and retry + // by converting the KafkaStorageException to NotLeaderOrFollowerException in the response if FetchRequest version <= 5 + if (error == Errors.KAFKA_STORAGE_ERROR && versionId <= 5) { + Errors.NOT_LEADER_OR_FOLLOWER + } else { + error + } + } + + def maybeConvertFetchedData(tp: TopicIdPartition, + partitionData: FetchResponseData.PartitionData): FetchResponseData.PartitionData = { + // We will never return a logConfig when the topic is unresolved and the name is null. This is ok since we won't have any records to convert. + val logConfig = replicaManager.getLogConfig(tp.topicPartition) + + if (logConfig.exists(_.compressionType == ZStdCompressionCodec.name) && versionId < 10) { + trace(s"Fetching messages is disabled for ZStandard compressed partition $tp. Sending unsupported version response to $clientId.") + FetchResponse.partitionResponse(tp, Errors.UNSUPPORTED_COMPRESSION_TYPE) + } else { + // Down-conversion of fetched records is needed when the on-disk magic value is greater than what is + // supported by the fetch request version. + // If the inter-broker protocol version is `3.0` or higher, the log config message format version is + // always `3.0` (i.e. magic value is `v2`). 
As a result, we always go through the down-conversion + // path if the fetch version is 3 or lower (in rare cases the down-conversion may not be needed, but + // it's not worth optimizing for them). + // If the inter-broker protocol version is lower than `3.0`, we rely on the log config message format + // version as a proxy for the on-disk magic value to maintain the long-standing behavior originally + // introduced in Kafka 0.10.0. An important implication is that it's unsafe to downgrade the message + // format version after a single message has been produced (the broker would return the message(s) + // without down-conversion irrespective of the fetch version). + val unconvertedRecords = FetchResponse.recordsOrFail(partitionData) + val downConvertMagic = + logConfig.map(_.recordVersion.value).flatMap { magic => + if (magic > RecordBatch.MAGIC_VALUE_V0 && versionId <= 1) + Some(RecordBatch.MAGIC_VALUE_V0) + else if (magic > RecordBatch.MAGIC_VALUE_V1 && versionId <= 3) + Some(RecordBatch.MAGIC_VALUE_V1) + else + None + } + + downConvertMagic match { + case Some(magic) => + // For fetch requests from clients, check if down-conversion is disabled for the particular partition + if (!fetchRequest.isFromFollower && !logConfig.forall(_.messageDownConversionEnable)) { + trace(s"Conversion to message format ${downConvertMagic.get} is disabled for partition $tp. Sending unsupported version response to $clientId.") + FetchResponse.partitionResponse(tp, Errors.UNSUPPORTED_VERSION) + } else { + try { + trace(s"Down converting records from partition $tp to message format version $magic for fetch request from $clientId") + // Because down-conversion is extremely memory intensive, we want to try and delay the down-conversion as much + // as possible. With KIP-283, we have the ability to lazily down-convert in a chunked manner. The lazy, chunked + // down-conversion always guarantees that at least one batch of messages is down-converted and sent out to the + // client. 
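/*
 * Illustrative aside, not part of this patch: the target-magic selection a few lines above
 * reduces to a small pure function of the on-disk magic value and the fetch request version
 * (RecordBatch.MAGIC_VALUE_V0 = 0 and MAGIC_VALUE_V1 = 1). `downConvertTarget` is a
 * hypothetical name introduced for the sketch:
 *
 *   object DownConversionSketch {
 *     def downConvertTarget(onDiskMagic: Byte, fetchVersion: Short): Option[Byte] =
 *       if (onDiskMagic > 0 && fetchVersion <= 1) Some(0.toByte)      // fetch v0/v1 clients need magic v0
 *       else if (onDiskMagic > 1 && fetchVersion <= 3) Some(1.toByte) // fetch v2/v3 clients need magic v1
 *       else None                                                     // record format already readable
 *   }
 */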
+ new FetchResponseData.PartitionData() + .setPartitionIndex(tp.partition) + .setErrorCode(maybeDownConvertStorageError(Errors.forCode(partitionData.errorCode)).code) + .setHighWatermark(partitionData.highWatermark) + .setLastStableOffset(partitionData.lastStableOffset) + .setLogStartOffset(partitionData.logStartOffset) + .setAbortedTransactions(partitionData.abortedTransactions) + .setRecords(new LazyDownConversionRecords(tp.topicPartition, unconvertedRecords, magic, fetchContext.getFetchOffset(tp).get, time)) + .setPreferredReadReplica(partitionData.preferredReadReplica()) + } catch { + case e: UnsupportedCompressionTypeException => + trace("Received unsupported compression type error during down-conversion", e) + FetchResponse.partitionResponse(tp, Errors.UNSUPPORTED_COMPRESSION_TYPE) + } + } + case None => + new FetchResponseData.PartitionData() + .setPartitionIndex(tp.partition) + .setErrorCode(maybeDownConvertStorageError(Errors.forCode(partitionData.errorCode)).code) + .setHighWatermark(partitionData.highWatermark) + .setLastStableOffset(partitionData.lastStableOffset) + .setLogStartOffset(partitionData.logStartOffset) + .setAbortedTransactions(partitionData.abortedTransactions) + .setRecords(unconvertedRecords) + .setPreferredReadReplica(partitionData.preferredReadReplica) + .setDivergingEpoch(partitionData.divergingEpoch) + } + } + } + + // the callback for process a fetch response, invoked before throttling + def processResponseCallback(responsePartitionData: Seq[(TopicIdPartition, FetchPartitionData)]): Unit = { + val partitions = new util.LinkedHashMap[TopicIdPartition, FetchResponseData.PartitionData] + val reassigningPartitions = mutable.Set[TopicIdPartition]() + responsePartitionData.foreach { case (tp, data) => + val abortedTransactions = data.abortedTransactions.map(_.asJava).orNull + val lastStableOffset = data.lastStableOffset.getOrElse(FetchResponse.INVALID_LAST_STABLE_OFFSET) + if (data.isReassignmentFetch) reassigningPartitions.add(tp) + val partitionData = new FetchResponseData.PartitionData() + .setPartitionIndex(tp.partition) + .setErrorCode(maybeDownConvertStorageError(data.error).code) + .setHighWatermark(data.highWatermark) + .setLastStableOffset(lastStableOffset) + .setLogStartOffset(data.logStartOffset) + .setAbortedTransactions(abortedTransactions) + .setRecords(data.records) + .setPreferredReadReplica(data.preferredReadReplica.getOrElse(FetchResponse.INVALID_PREFERRED_REPLICA_ID)) + data.divergingEpoch.foreach(partitionData.setDivergingEpoch) + partitions.put(tp, partitionData) + } + erroneous.foreach { case (tp, data) => partitions.put(tp, data) } + + var unconvertedFetchResponse: FetchResponse = null + + def createResponse(throttleTimeMs: Int): FetchResponse = { + // Down-convert messages for each partition if required + val convertedData = new util.LinkedHashMap[TopicIdPartition, FetchResponseData.PartitionData] + unconvertedFetchResponse.data().responses().forEach { topicResponse => + topicResponse.partitions().forEach { unconvertedPartitionData => + val tp = new TopicIdPartition(topicResponse.topicId, new TopicPartition(topicResponse.topic, unconvertedPartitionData.partitionIndex())) + val error = Errors.forCode(unconvertedPartitionData.errorCode) + if (error != Errors.NONE) + debug(s"Fetch request with correlation id ${request.header.correlationId} from client $clientId " + + s"on partition $tp failed due to ${error.exceptionName}") + convertedData.put(tp, maybeConvertFetchedData(tp, unconvertedPartitionData)) + } + } + + // Prepare fetch response from 
converted data + val response = + FetchResponse.of(unconvertedFetchResponse.error, throttleTimeMs, unconvertedFetchResponse.sessionId, convertedData) + // record the bytes out metrics only when the response is being sent + response.data.responses.forEach { topicResponse => + topicResponse.partitions.forEach { data => + // If the topic name was not known, we will have no bytes out. + if (topicResponse.topic != null) { + val tp = new TopicIdPartition(topicResponse.topicId, new TopicPartition(topicResponse.topic, data.partitionIndex)) + brokerTopicStats.updateBytesOut(tp.topic, fetchRequest.isFromFollower, reassigningPartitions.contains(tp), FetchResponse.recordsSize(data)) + } + } + } + response + } + + def updateConversionStats(send: Send): Unit = { + send match { + case send: MultiRecordsSend if send.recordConversionStats != null => + send.recordConversionStats.asScala.toMap.foreach { + case (tp, stats) => updateRecordConversionStats(request, tp, stats) + } + case _ => + } + } + + if (fetchRequest.isFromFollower) { + // We've already evaluated against the quota and are good to go. Just need to record it now. + unconvertedFetchResponse = fetchContext.updateAndGenerateResponseData(partitions) + val responseSize = KafkaApis.sizeOfThrottledPartitions(versionId, unconvertedFetchResponse, quotas.leader) + quotas.leader.record(responseSize) + val responsePartitionsSize = unconvertedFetchResponse.data().responses().stream().mapToInt(_.partitions().size()).sum() + trace(s"Sending Fetch response with partitions.size=$responsePartitionsSize, " + + s"metadata=${unconvertedFetchResponse.sessionId}") + requestHelper.sendResponseExemptThrottle(request, createResponse(0), Some(updateConversionStats)) + } else { + // Fetch size used to determine throttle time is calculated before any down conversions. + // This may be slightly different from the actual response size. But since down conversions + // result in data being loaded into memory, we should do this only when we are not going to throttle. + // + // Record both bandwidth and request quota-specific values and throttle by muting the channel if any of the + // quotas have been violated. If both quotas have been violated, use the max throttle time between the two + // quotas. When throttled, we unrecord the recorded bandwidth quota value + val responseSize = fetchContext.getResponseSize(partitions, versionId) + val timeMs = time.milliseconds() + val requestThrottleTimeMs = quotas.request.maybeRecordAndGetThrottleTimeMs(request, timeMs) + val bandwidthThrottleTimeMs = quotas.fetch.maybeRecordAndGetThrottleTimeMs(request, responseSize, timeMs) + + val maxThrottleTimeMs = math.max(bandwidthThrottleTimeMs, requestThrottleTimeMs) + if (maxThrottleTimeMs > 0) { + request.apiThrottleTimeMs = maxThrottleTimeMs + // Even if we need to throttle for request quota violation, we should "unrecord" the already recorded value + // from the fetch quota because we are going to return an empty response. + quotas.fetch.unrecordQuotaSensor(request, responseSize, timeMs) + if (bandwidthThrottleTimeMs > requestThrottleTimeMs) { + requestHelper.throttle(quotas.fetch, request, bandwidthThrottleTimeMs) + } else { + requestHelper.throttle(quotas.request, request, requestThrottleTimeMs) + } + // If throttling is required, return an empty response. + unconvertedFetchResponse = fetchContext.getThrottledResponse(maxThrottleTimeMs) + } else { + // Get the actual response. This will update the fetch context. 
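+          // Note: in this non-throttled path the bandwidth quota value recorded above stays recorded,
+          // and generating the response below also advances the fetch session state held in fetchContext.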
+ unconvertedFetchResponse = fetchContext.updateAndGenerateResponseData(partitions) + val responsePartitionsSize = unconvertedFetchResponse.data().responses().stream().mapToInt(_.partitions().size()).sum() + trace(s"Sending Fetch response with partitions.size=$responsePartitionsSize, " + + s"metadata=${unconvertedFetchResponse.sessionId}") + } + + // Send the response immediately. + requestChannel.sendResponse(request, createResponse(maxThrottleTimeMs), Some(updateConversionStats)) + } + } + + // for fetch from consumer, cap fetchMaxBytes to the maximum bytes that could be fetched without being throttled given + // no bytes were recorded in the recent quota window + // trying to fetch more bytes would result in a guaranteed throttling potentially blocking consumer progress + val maxQuotaWindowBytes = if (fetchRequest.isFromFollower) + Int.MaxValue + else + quotas.fetch.getMaxValueInQuotaWindow(request.session, clientId).toInt + + val fetchMaxBytes = Math.min(Math.min(fetchRequest.maxBytes, config.fetchMaxBytes), maxQuotaWindowBytes) + val fetchMinBytes = Math.min(fetchRequest.minBytes, fetchMaxBytes) + if (interesting.isEmpty) + processResponseCallback(Seq.empty) + else { + // call the replica manager to fetch messages from the local replica + replicaManager.fetchMessages( + fetchRequest.maxWait.toLong, + fetchRequest.replicaId, + fetchMinBytes, + fetchMaxBytes, + versionId <= 2, + interesting, + replicationQuota(fetchRequest), + processResponseCallback, + fetchRequest.isolationLevel, + clientMetadata) + } + } + + def replicationQuota(fetchRequest: FetchRequest): ReplicaQuota = + if (fetchRequest.isFromFollower) quotas.leader else UnboundedQuota + + def handleListOffsetRequest(request: RequestChannel.Request): Unit = { + val version = request.header.apiVersion + + val topics = if (version == 0) + handleListOffsetRequestV0(request) + else + handleListOffsetRequestV1AndAbove(request) + + requestHelper.sendResponseMaybeThrottle(request, requestThrottleMs => new ListOffsetsResponse(new ListOffsetsResponseData() + .setThrottleTimeMs(requestThrottleMs) + .setTopics(topics.asJava))) + } + + private def handleListOffsetRequestV0(request : RequestChannel.Request) : List[ListOffsetsTopicResponse] = { + val correlationId = request.header.correlationId + val clientId = request.header.clientId + val offsetRequest = request.body[ListOffsetsRequest] + + val (authorizedRequestInfo, unauthorizedRequestInfo) = authHelper.partitionSeqByAuthorized(request.context, + DESCRIBE, TOPIC, offsetRequest.topics.asScala.toSeq)(_.name) + + val unauthorizedResponseStatus = unauthorizedRequestInfo.map(topic => + new ListOffsetsTopicResponse() + .setName(topic.name) + .setPartitions(topic.partitions.asScala.map(partition => + new ListOffsetsPartitionResponse() + .setPartitionIndex(partition.partitionIndex) + .setErrorCode(Errors.TOPIC_AUTHORIZATION_FAILED.code)).asJava) + ) + + val responseTopics = authorizedRequestInfo.map { topic => + val responsePartitions = topic.partitions.asScala.map { partition => + val topicPartition = new TopicPartition(topic.name, partition.partitionIndex) + + try { + val offsets = replicaManager.legacyFetchOffsetsForTimestamp( + topicPartition = topicPartition, + timestamp = partition.timestamp, + maxNumOffsets = partition.maxNumOffsets, + isFromConsumer = offsetRequest.replicaId == ListOffsetsRequest.CONSUMER_REPLICA_ID, + fetchOnlyFromLeader = offsetRequest.replicaId != ListOffsetsRequest.DEBUGGING_REPLICA_ID) + new ListOffsetsPartitionResponse() + .setPartitionIndex(partition.partitionIndex) 
+ .setErrorCode(Errors.NONE.code) + .setOldStyleOffsets(offsets.map(JLong.valueOf).asJava) + } catch { + // NOTE: UnknownTopicOrPartitionException and NotLeaderOrFollowerException are special cases since these error messages + // are typically transient and there is no value in logging the entire stack trace for the same + case e @ (_ : UnknownTopicOrPartitionException | + _ : NotLeaderOrFollowerException | + _ : KafkaStorageException) => + debug("Offset request with correlation id %d from client %s on partition %s failed due to %s".format( + correlationId, clientId, topicPartition, e.getMessage)) + new ListOffsetsPartitionResponse() + .setPartitionIndex(partition.partitionIndex) + .setErrorCode(Errors.forException(e).code) + case e: Throwable => + error("Error while responding to offset request", e) + new ListOffsetsPartitionResponse() + .setPartitionIndex(partition.partitionIndex) + .setErrorCode(Errors.forException(e).code) + } + } + new ListOffsetsTopicResponse().setName(topic.name).setPartitions(responsePartitions.asJava) + } + (responseTopics ++ unauthorizedResponseStatus).toList + } + + private def handleListOffsetRequestV1AndAbove(request : RequestChannel.Request): List[ListOffsetsTopicResponse] = { + val correlationId = request.header.correlationId + val clientId = request.header.clientId + val offsetRequest = request.body[ListOffsetsRequest] + val version = request.header.apiVersion + + def buildErrorResponse(e: Errors, partition: ListOffsetsPartition): ListOffsetsPartitionResponse = { + new ListOffsetsPartitionResponse() + .setPartitionIndex(partition.partitionIndex) + .setErrorCode(e.code) + .setTimestamp(ListOffsetsResponse.UNKNOWN_TIMESTAMP) + .setOffset(ListOffsetsResponse.UNKNOWN_OFFSET) + } + + val (authorizedRequestInfo, unauthorizedRequestInfo) = authHelper.partitionSeqByAuthorized(request.context, + DESCRIBE, TOPIC, offsetRequest.topics.asScala.toSeq)(_.name) + + val unauthorizedResponseStatus = unauthorizedRequestInfo.map(topic => + new ListOffsetsTopicResponse() + .setName(topic.name) + .setPartitions(topic.partitions.asScala.map(partition => + buildErrorResponse(Errors.TOPIC_AUTHORIZATION_FAILED, partition)).asJava) + ) + + val responseTopics = authorizedRequestInfo.map { topic => + val responsePartitions = topic.partitions.asScala.map { partition => + val topicPartition = new TopicPartition(topic.name, partition.partitionIndex) + if (offsetRequest.duplicatePartitions.contains(topicPartition)) { + debug(s"OffsetRequest with correlation id $correlationId from client $clientId on partition $topicPartition " + + s"failed because the partition is duplicated in the request.") + buildErrorResponse(Errors.INVALID_REQUEST, partition) + } else { + try { + val fetchOnlyFromLeader = offsetRequest.replicaId != ListOffsetsRequest.DEBUGGING_REPLICA_ID + val isClientRequest = offsetRequest.replicaId == ListOffsetsRequest.CONSUMER_REPLICA_ID + val isolationLevelOpt = if (isClientRequest) + Some(offsetRequest.isolationLevel) + else + None + + val foundOpt = replicaManager.fetchOffsetForTimestamp(topicPartition, + partition.timestamp, + isolationLevelOpt, + if (partition.currentLeaderEpoch == ListOffsetsResponse.UNKNOWN_EPOCH) Optional.empty() else Optional.of(partition.currentLeaderEpoch), + fetchOnlyFromLeader) + + val response = foundOpt match { + case Some(found) => + val partitionResponse = new ListOffsetsPartitionResponse() + .setPartitionIndex(partition.partitionIndex) + .setErrorCode(Errors.NONE.code) + .setTimestamp(found.timestamp) + .setOffset(found.offset) + if 
(found.leaderEpoch.isPresent && version >= 4) + partitionResponse.setLeaderEpoch(found.leaderEpoch.get) + partitionResponse + case None => + buildErrorResponse(Errors.NONE, partition) + } + response + } catch { + // NOTE: These exceptions are special cases since these error messages are typically transient or the client + // would have received a clear exception and there is no value in logging the entire stack trace for the same + case e @ (_ : UnknownTopicOrPartitionException | + _ : NotLeaderOrFollowerException | + _ : UnknownLeaderEpochException | + _ : FencedLeaderEpochException | + _ : KafkaStorageException | + _ : UnsupportedForMessageFormatException) => + debug(s"Offset request with correlation id $correlationId from client $clientId on " + + s"partition $topicPartition failed due to ${e.getMessage}") + buildErrorResponse(Errors.forException(e), partition) + + // Only V5 and newer ListOffset calls should get OFFSET_NOT_AVAILABLE + case e: OffsetNotAvailableException => + if (request.header.apiVersion >= 5) { + buildErrorResponse(Errors.forException(e), partition) + } else { + buildErrorResponse(Errors.LEADER_NOT_AVAILABLE, partition) + } + + case e: Throwable => + error("Error while responding to offset request", e) + buildErrorResponse(Errors.forException(e), partition) + } + } + } + new ListOffsetsTopicResponse().setName(topic.name).setPartitions(responsePartitions.asJava) + } + (responseTopics ++ unauthorizedResponseStatus).toList + } + + private def metadataResponseTopic(error: Errors, + topic: String, + topicId: Uuid, + isInternal: Boolean, + partitionData: util.List[MetadataResponsePartition]): MetadataResponseTopic = { + new MetadataResponseTopic() + .setErrorCode(error.code) + .setName(topic) + .setTopicId(topicId) + .setIsInternal(isInternal) + .setPartitions(partitionData) + } + + private def getTopicMetadata( + request: RequestChannel.Request, + fetchAllTopics: Boolean, + allowAutoTopicCreation: Boolean, + topics: Set[String], + listenerName: ListenerName, + errorUnavailableEndpoints: Boolean, + errorUnavailableListeners: Boolean + ): Seq[MetadataResponseTopic] = { + val topicResponses = metadataCache.getTopicMetadata(topics, listenerName, + errorUnavailableEndpoints, errorUnavailableListeners) + + if (topics.isEmpty || topicResponses.size == topics.size || fetchAllTopics) { + topicResponses + } else { + val nonExistingTopics = topics.diff(topicResponses.map(_.name).toSet) + val nonExistingTopicResponses = if (allowAutoTopicCreation) { + val controllerMutationQuota = quotas.controllerMutation.newPermissiveQuotaFor(request) + autoTopicCreationManager.createTopics(nonExistingTopics, controllerMutationQuota, Some(request.context)) + } else { + nonExistingTopics.map { topic => + val error = try { + Topic.validate(topic) + Errors.UNKNOWN_TOPIC_OR_PARTITION + } catch { + case _: InvalidTopicException => + Errors.INVALID_TOPIC_EXCEPTION + } + + metadataResponseTopic( + error, + topic, + metadataCache.getTopicId(topic), + Topic.isInternal(topic), + util.Collections.emptyList() + ) + } + } + + topicResponses ++ nonExistingTopicResponses + } + } + + def handleTopicMetadataRequest(request: RequestChannel.Request): Unit = { + val metadataRequest = request.body[MetadataRequest] + val requestVersion = request.header.apiVersion + + // Topic IDs are not supported for versions 10 and 11. Topic names can not be null in these versions. 
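+    // To illustrate the validation below (a sketch, assuming the generated
+    // MetadataRequestData.MetadataRequestTopic accessors): on v11 and earlier a topic entry must carry a
+    // non-null name and a zero topic ID, e.g.
+    //   new MetadataRequestData.MetadataRequestTopic().setName("my-topic")
+    // whereas from v12 onwards a topic may instead be addressed purely by ID, e.g.
+    //   new MetadataRequestData.MetadataRequestTopic().setName(null).setTopicId(topicId)
+    // Entries violating these rules are rejected with InvalidRequestException.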
+    if (!metadataRequest.isAllTopics) {
+      metadataRequest.data.topics.forEach{ topic =>
+        if (topic.name == null && metadataRequest.version < 12) {
+          throw new InvalidRequestException(s"Topic name can not be null for version ${metadataRequest.version}")
+        } else if (topic.topicId != Uuid.ZERO_UUID && metadataRequest.version < 12) {
+          throw new InvalidRequestException(s"Topic IDs are not supported in requests for version ${metadataRequest.version}")
+        }
+      }
+    }
+
+    // Check first whether the request addresses topics by ID.
+    val topicIds = metadataRequest.topicIds.asScala.toSet.filterNot(_ == Uuid.ZERO_UUID)
+    val useTopicId = topicIds.nonEmpty
+
+    // Resolve topic IDs to names; this is only meaningful when the request uses topic IDs.
+    val unknownTopicIds = topicIds.filter(metadataCache.getTopicName(_).isEmpty)
+    val knownTopicNames = topicIds.flatMap(metadataCache.getTopicName)
+
+    val unknownTopicIdsTopicMetadata = unknownTopicIds.map(topicId =>
+        metadataResponseTopic(Errors.UNKNOWN_TOPIC_ID, null, topicId, false, util.Collections.emptyList())).toSeq
+
+    val topics = if (metadataRequest.isAllTopics)
+      metadataCache.getAllTopics()
+    else if (useTopicId)
+      knownTopicNames
+    else
+      metadataRequest.topics.asScala.toSet
+
+    val authorizedForDescribeTopics = authHelper.filterByAuthorized(request.context, DESCRIBE, TOPIC,
+      topics, logIfDenied = !metadataRequest.isAllTopics)(identity)
+    var (authorizedTopics, unauthorizedForDescribeTopics) = topics.partition(authorizedForDescribeTopics.contains)
+    var unauthorizedForCreateTopics = Set[String]()
+
+    if (authorizedTopics.nonEmpty) {
+      val nonExistingTopics = authorizedTopics.filterNot(metadataCache.contains(_))
+      if (metadataRequest.allowAutoTopicCreation && config.autoCreateTopicsEnable && nonExistingTopics.nonEmpty) {
+        if (!authHelper.authorize(request.context, CREATE, CLUSTER, CLUSTER_NAME, logIfDenied = false)) {
+          val authorizedForCreateTopics = authHelper.filterByAuthorized(request.context, CREATE, TOPIC,
+            nonExistingTopics)(identity)
+          unauthorizedForCreateTopics = nonExistingTopics.diff(authorizedForCreateTopics)
+          authorizedTopics = authorizedTopics.diff(unauthorizedForCreateTopics)
+        }
+      }
+    }
+
+    val unauthorizedForCreateTopicMetadata = unauthorizedForCreateTopics.map(topic =>
+      // Set the topic ID to Uuid.ZERO_UUID: the topic will never be created, so it has no ID.
+      metadataResponseTopic(Errors.TOPIC_AUTHORIZATION_FAILED, topic, Uuid.ZERO_UUID, isInternal(topic), util.Collections.emptyList()))
+
+    // Do not disclose the existence of topics the caller is not authorized to Describe; we have not even checked whether they exist.
+    val unauthorizedForDescribeTopicMetadata =
+      // When fetching all topics, don't include topics unauthorized for Describe
+      if ((requestVersion == 0 && (metadataRequest.topics == null || metadataRequest.topics.isEmpty)) || metadataRequest.isAllTopics)
+        Set.empty[MetadataResponseTopic]
+      else if (useTopicId) {
+        // Topic IDs are not considered sensitive information, so returning TOPIC_AUTHORIZATION_FAILED is OK
+        unauthorizedForDescribeTopics.map(topic =>
+          metadataResponseTopic(Errors.TOPIC_AUTHORIZATION_FAILED, null, metadataCache.getTopicId(topic), false, util.Collections.emptyList()))
+      } else {
+        // We should not return the topic ID on an authorization error, so we return a zero UUID.
+ unauthorizedForDescribeTopics.map(topic => + metadataResponseTopic(Errors.TOPIC_AUTHORIZATION_FAILED, topic, Uuid.ZERO_UUID, false, util.Collections.emptyList())) + } + + // In version 0, we returned an error when brokers with replicas were unavailable, + // while in higher versions we simply don't include the broker in the returned broker list + val errorUnavailableEndpoints = requestVersion == 0 + // In versions 5 and below, we returned LEADER_NOT_AVAILABLE if a matching listener was not found on the leader. + // From version 6 onwards, we return LISTENER_NOT_FOUND to enable diagnosis of configuration errors. + val errorUnavailableListeners = requestVersion >= 6 + + val allowAutoCreation = config.autoCreateTopicsEnable && metadataRequest.allowAutoTopicCreation && !metadataRequest.isAllTopics + val topicMetadata = getTopicMetadata(request, metadataRequest.isAllTopics, allowAutoCreation, authorizedTopics, + request.context.listenerName, errorUnavailableEndpoints, errorUnavailableListeners) + + var clusterAuthorizedOperations = Int.MinValue // Default value in the schema + if (requestVersion >= 8) { + // get cluster authorized operations + if (requestVersion <= 10) { + if (metadataRequest.data.includeClusterAuthorizedOperations) { + if (authHelper.authorize(request.context, DESCRIBE, CLUSTER, CLUSTER_NAME)) + clusterAuthorizedOperations = authHelper.authorizedOperations(request, Resource.CLUSTER) + else + clusterAuthorizedOperations = 0 + } + } + + // get topic authorized operations + if (metadataRequest.data.includeTopicAuthorizedOperations) { + def setTopicAuthorizedOperations(topicMetadata: Seq[MetadataResponseTopic]): Unit = { + topicMetadata.foreach { topicData => + topicData.setTopicAuthorizedOperations(authHelper.authorizedOperations(request, new Resource(ResourceType.TOPIC, topicData.name))) + } + } + setTopicAuthorizedOperations(topicMetadata) + } + } + + val completeTopicMetadata = unknownTopicIdsTopicMetadata ++ + topicMetadata ++ unauthorizedForCreateTopicMetadata ++ unauthorizedForDescribeTopicMetadata + + val brokers = metadataCache.getAliveBrokerNodes(request.context.listenerName) + + trace("Sending topic metadata %s and brokers %s for correlation id %d to client %s".format(completeTopicMetadata.mkString(","), + brokers.mkString(","), request.header.correlationId, request.header.clientId)) + + requestHelper.sendResponseMaybeThrottle(request, requestThrottleMs => + MetadataResponse.prepareResponse( + requestVersion, + requestThrottleMs, + brokers.toList.asJava, + clusterId, + metadataSupport.controllerId.getOrElse(MetadataResponse.NO_CONTROLLER_ID), + completeTopicMetadata.asJava, + clusterAuthorizedOperations + )) + } + + /** + * Handle an offset fetch request + */ + def handleOffsetFetchRequest(request: RequestChannel.Request): Unit = { + val version = request.header.apiVersion + if (version == 0) { + // reading offsets from ZK + handleOffsetFetchRequestV0(request) + } else if (version >= 1 && version <= 7) { + // reading offsets from Kafka + handleOffsetFetchRequestBetweenV1AndV7(request) + } else { + // batching offset reads for multiple groups starts with version 8 and greater + handleOffsetFetchRequestV8AndAbove(request) + } + } + + private def handleOffsetFetchRequestV0(request: RequestChannel.Request): Unit = { + val header = request.header + val offsetFetchRequest = request.body[OffsetFetchRequest] + + def createResponse(requestThrottleMs: Int): AbstractResponse = { + val offsetFetchResponse = + // reject the request if not authorized to the group + if 
(!authHelper.authorize(request.context, DESCRIBE, GROUP, offsetFetchRequest.groupId)) + offsetFetchRequest.getErrorResponse(requestThrottleMs, Errors.GROUP_AUTHORIZATION_FAILED) + else { + val zkSupport = metadataSupport.requireZkOrThrow(KafkaApis.unsupported("Version 0 offset fetch requests")) + val (authorizedPartitions, unauthorizedPartitions) = partitionByAuthorized( + offsetFetchRequest.partitions.asScala, request.context) + + // version 0 reads offsets from ZK + val authorizedPartitionData = authorizedPartitions.map { topicPartition => + try { + if (!metadataCache.contains(topicPartition)) + (topicPartition, OffsetFetchResponse.UNKNOWN_PARTITION) + else { + val payloadOpt = zkSupport.zkClient.getConsumerOffset(offsetFetchRequest.groupId, topicPartition) + payloadOpt match { + case Some(payload) => + (topicPartition, new OffsetFetchResponse.PartitionData(payload.toLong, + Optional.empty(), OffsetFetchResponse.NO_METADATA, Errors.NONE)) + case None => + (topicPartition, OffsetFetchResponse.UNKNOWN_PARTITION) + } + } + } catch { + case e: Throwable => + (topicPartition, new OffsetFetchResponse.PartitionData(OffsetFetchResponse.INVALID_OFFSET, + Optional.empty(), OffsetFetchResponse.NO_METADATA, Errors.forException(e))) + } + }.toMap + + val unauthorizedPartitionData = unauthorizedPartitions.map(_ -> OffsetFetchResponse.UNAUTHORIZED_PARTITION).toMap + new OffsetFetchResponse(requestThrottleMs, Errors.NONE, (authorizedPartitionData ++ unauthorizedPartitionData).asJava) + } + trace(s"Sending offset fetch response $offsetFetchResponse for correlation id ${header.correlationId} to client ${header.clientId}.") + offsetFetchResponse + } + requestHelper.sendResponseMaybeThrottle(request, createResponse) + } + + private def handleOffsetFetchRequestBetweenV1AndV7(request: RequestChannel.Request): Unit = { + val header = request.header + val offsetFetchRequest = request.body[OffsetFetchRequest] + val groupId = offsetFetchRequest.groupId() + val (error, partitionData) = fetchOffsets(groupId, offsetFetchRequest.isAllPartitions, + offsetFetchRequest.requireStable, offsetFetchRequest.partitions, request.context) + def createResponse(requestThrottleMs: Int): AbstractResponse = { + val offsetFetchResponse = + if (error != Errors.NONE) { + offsetFetchRequest.getErrorResponse(requestThrottleMs, error) + } else { + new OffsetFetchResponse(requestThrottleMs, Errors.NONE, partitionData.asJava) + } + trace(s"Sending offset fetch response $offsetFetchResponse for correlation id ${header.correlationId} to client ${header.clientId}.") + offsetFetchResponse + } + requestHelper.sendResponseMaybeThrottle(request, createResponse) + } + + private def handleOffsetFetchRequestV8AndAbove(request: RequestChannel.Request): Unit = { + val header = request.header + val offsetFetchRequest = request.body[OffsetFetchRequest] + val groupIds = offsetFetchRequest.groupIds().asScala + val groupToErrorMap = mutable.Map.empty[String, Errors] + val groupToPartitionData = mutable.Map.empty[String, util.Map[TopicPartition, PartitionData]] + val groupToTopicPartitions = offsetFetchRequest.groupIdsToPartitions() + groupIds.foreach(g => { + val (error, partitionData) = fetchOffsets(g, + offsetFetchRequest.isAllPartitionsForGroup(g), + offsetFetchRequest.requireStable(), + groupToTopicPartitions.get(g), request.context) + groupToErrorMap += (g -> error) + groupToPartitionData += (g -> partitionData.asJava) + }) + + def createResponse(requestThrottleMs: Int): AbstractResponse = { + val offsetFetchResponse = new 
OffsetFetchResponse(requestThrottleMs, + groupToErrorMap.asJava, groupToPartitionData.asJava) + trace(s"Sending offset fetch response $offsetFetchResponse for correlation id ${header.correlationId} to client ${header.clientId}.") + offsetFetchResponse + } + + requestHelper.sendResponseMaybeThrottle(request, createResponse) + } + + private def fetchOffsets(groupId: String, isAllPartitions: Boolean, requireStable: Boolean, + partitions: util.List[TopicPartition], context: RequestContext): (Errors, Map[TopicPartition, OffsetFetchResponse.PartitionData]) = { + if (!authHelper.authorize(context, DESCRIBE, GROUP, groupId)) { + (Errors.GROUP_AUTHORIZATION_FAILED, Map.empty) + } else { + if (isAllPartitions) { + val (error, allPartitionData) = groupCoordinator.handleFetchOffsets(groupId, requireStable) + if (error != Errors.NONE) { + (error, allPartitionData) + } else { + // clients are not allowed to see offsets for topics that are not authorized for Describe + val (authorizedPartitionData, _) = authHelper.partitionMapByAuthorized(context, + DESCRIBE, TOPIC, allPartitionData)(_.topic) + (Errors.NONE, authorizedPartitionData) + } + } else { + val (authorizedPartitions, unauthorizedPartitions) = partitionByAuthorized( + partitions.asScala, context) + val (error, authorizedPartitionData) = groupCoordinator.handleFetchOffsets(groupId, + requireStable, Some(authorizedPartitions)) + if (error != Errors.NONE) { + (error, authorizedPartitionData) + } else { + val unauthorizedPartitionData = unauthorizedPartitions.map(_ -> OffsetFetchResponse.UNAUTHORIZED_PARTITION).toMap + (Errors.NONE, authorizedPartitionData ++ unauthorizedPartitionData) + } + } + } + } + + private def partitionByAuthorized(seq: Seq[TopicPartition], context: RequestContext): + (Seq[TopicPartition], Seq[TopicPartition]) = + authHelper.partitionSeqByAuthorized(context, DESCRIBE, TOPIC, seq)(_.topic) + + def handleFindCoordinatorRequest(request: RequestChannel.Request): Unit = { + val version = request.header.apiVersion + if (version < 4) { + handleFindCoordinatorRequestLessThanV4(request) + } else { + handleFindCoordinatorRequestV4AndAbove(request) + } + } + + private def handleFindCoordinatorRequestV4AndAbove(request: RequestChannel.Request): Unit = { + val findCoordinatorRequest = request.body[FindCoordinatorRequest] + + val coordinators = findCoordinatorRequest.data.coordinatorKeys.asScala.map { key => + val (error, node) = getCoordinator(request, findCoordinatorRequest.data.keyType, key) + new FindCoordinatorResponseData.Coordinator() + .setKey(key) + .setErrorCode(error.code) + .setHost(node.host) + .setNodeId(node.id) + .setPort(node.port) + } + def createResponse(requestThrottleMs: Int): AbstractResponse = { + val response = new FindCoordinatorResponse( + new FindCoordinatorResponseData() + .setCoordinators(coordinators.asJava) + .setThrottleTimeMs(requestThrottleMs)) + trace("Sending FindCoordinator response %s for correlation id %d to client %s." 
+ .format(response, request.header.correlationId, request.header.clientId)) + response + } + requestHelper.sendResponseMaybeThrottle(request, createResponse) + } + + private def handleFindCoordinatorRequestLessThanV4(request: RequestChannel.Request): Unit = { + val findCoordinatorRequest = request.body[FindCoordinatorRequest] + + val (error, node) = getCoordinator(request, findCoordinatorRequest.data.keyType, findCoordinatorRequest.data.key) + def createResponse(requestThrottleMs: Int): AbstractResponse = { + val responseBody = new FindCoordinatorResponse( + new FindCoordinatorResponseData() + .setErrorCode(error.code) + .setErrorMessage(error.message()) + .setNodeId(node.id) + .setHost(node.host) + .setPort(node.port) + .setThrottleTimeMs(requestThrottleMs)) + trace("Sending FindCoordinator response %s for correlation id %d to client %s." + .format(responseBody, request.header.correlationId, request.header.clientId)) + responseBody + } + if (error == Errors.NONE) { + requestHelper.sendResponseMaybeThrottle(request, createResponse) + } else { + requestHelper.sendErrorResponseMaybeThrottle(request, error.exception) + } + } + + private def getCoordinator(request: RequestChannel.Request, keyType: Byte, key: String): (Errors, Node) = { + if (keyType == CoordinatorType.GROUP.id && + !authHelper.authorize(request.context, DESCRIBE, GROUP, key)) + (Errors.GROUP_AUTHORIZATION_FAILED, Node.noNode) + else if (keyType == CoordinatorType.TRANSACTION.id && + !authHelper.authorize(request.context, DESCRIBE, TRANSACTIONAL_ID, key)) + (Errors.TRANSACTIONAL_ID_AUTHORIZATION_FAILED, Node.noNode) + else { + val (partition, internalTopicName) = CoordinatorType.forId(keyType) match { + case CoordinatorType.GROUP => + (groupCoordinator.partitionFor(key), GROUP_METADATA_TOPIC_NAME) + + case CoordinatorType.TRANSACTION => + (txnCoordinator.partitionFor(key), TRANSACTION_STATE_TOPIC_NAME) + } + + val topicMetadata = metadataCache.getTopicMetadata(Set(internalTopicName), request.context.listenerName) + + if (topicMetadata.headOption.isEmpty) { + val controllerMutationQuota = quotas.controllerMutation.newPermissiveQuotaFor(request) + autoTopicCreationManager.createTopics(Seq(internalTopicName).toSet, controllerMutationQuota, None) + (Errors.COORDINATOR_NOT_AVAILABLE, Node.noNode) + } else { + if (topicMetadata.head.errorCode != Errors.NONE.code) { + (Errors.COORDINATOR_NOT_AVAILABLE, Node.noNode) + } else { + val coordinatorEndpoint = topicMetadata.head.partitions.asScala + .find(_.partitionIndex == partition) + .filter(_.leaderId != MetadataResponse.NO_LEADER_ID) + .flatMap(metadata => metadataCache. 
+ getAliveBrokerNode(metadata.leaderId, request.context.listenerName)) + + coordinatorEndpoint match { + case Some(endpoint) => + (Errors.NONE, endpoint) + case _ => + (Errors.COORDINATOR_NOT_AVAILABLE, Node.noNode) + } + } + } + } + } + + def handleDescribeGroupRequest(request: RequestChannel.Request): Unit = { + + def sendResponseCallback(describeGroupsResponseData: DescribeGroupsResponseData): Unit = { + def createResponse(requestThrottleMs: Int): AbstractResponse = { + describeGroupsResponseData.setThrottleTimeMs(requestThrottleMs) + new DescribeGroupsResponse(describeGroupsResponseData) + } + requestHelper.sendResponseMaybeThrottle(request, createResponse) + } + + val describeRequest = request.body[DescribeGroupsRequest] + val describeGroupsResponseData = new DescribeGroupsResponseData() + + describeRequest.data.groups.forEach { groupId => + if (!authHelper.authorize(request.context, DESCRIBE, GROUP, groupId)) { + describeGroupsResponseData.groups.add(DescribeGroupsResponse.forError(groupId, Errors.GROUP_AUTHORIZATION_FAILED)) + } else { + val (error, summary) = groupCoordinator.handleDescribeGroup(groupId) + val members = summary.members.map { member => + new DescribeGroupsResponseData.DescribedGroupMember() + .setMemberId(member.memberId) + .setGroupInstanceId(member.groupInstanceId.orNull) + .setClientId(member.clientId) + .setClientHost(member.clientHost) + .setMemberAssignment(member.assignment) + .setMemberMetadata(member.metadata) + } + + val describedGroup = new DescribeGroupsResponseData.DescribedGroup() + .setErrorCode(error.code) + .setGroupId(groupId) + .setGroupState(summary.state) + .setProtocolType(summary.protocolType) + .setProtocolData(summary.protocol) + .setMembers(members.asJava) + + if (request.header.apiVersion >= 3) { + if (error == Errors.NONE && describeRequest.data.includeAuthorizedOperations) { + describedGroup.setAuthorizedOperations(authHelper.authorizedOperations(request, new Resource(ResourceType.GROUP, groupId))) + } + } + + describeGroupsResponseData.groups.add(describedGroup) + } + } + + sendResponseCallback(describeGroupsResponseData) + } + + def handleListGroupsRequest(request: RequestChannel.Request): Unit = { + val listGroupsRequest = request.body[ListGroupsRequest] + val states = if (listGroupsRequest.data.statesFilter == null) + // Handle a null array the same as empty + immutable.Set[String]() + else + listGroupsRequest.data.statesFilter.asScala.toSet + + def createResponse(throttleMs: Int, groups: List[GroupOverview], error: Errors): AbstractResponse = { + new ListGroupsResponse(new ListGroupsResponseData() + .setErrorCode(error.code) + .setGroups(groups.map { group => + val listedGroup = new ListGroupsResponseData.ListedGroup() + .setGroupId(group.groupId) + .setProtocolType(group.protocolType) + .setGroupState(group.state.toString) + listedGroup + }.asJava) + .setThrottleTimeMs(throttleMs) + ) + } + val (error, groups) = groupCoordinator.handleListGroups(states) + if (authHelper.authorize(request.context, DESCRIBE, CLUSTER, CLUSTER_NAME)) + // With describe cluster access all groups are returned. We keep this alternative for backward compatibility. 
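+      // Without DESCRIBE on the cluster, the response is instead filtered down to the groups the caller
+      // is individually authorized to DESCRIBE (see the else branch below), so callers cannot discover
+      // groups they have no access to.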
+ requestHelper.sendResponseMaybeThrottle(request, requestThrottleMs => + createResponse(requestThrottleMs, groups, error)) + else { + val filteredGroups = groups.filter(group => authHelper.authorize(request.context, DESCRIBE, GROUP, group.groupId)) + requestHelper.sendResponseMaybeThrottle(request, requestThrottleMs => + createResponse(requestThrottleMs, filteredGroups, error)) + } + } + + def handleJoinGroupRequest(request: RequestChannel.Request, requestLocal: RequestLocal): Unit = { + val joinGroupRequest = request.body[JoinGroupRequest] + + // the callback for sending a join-group response + def sendResponseCallback(joinResult: JoinGroupResult): Unit = { + def createResponse(requestThrottleMs: Int): AbstractResponse = { + val protocolName = if (request.context.apiVersion() >= 7) + joinResult.protocolName.orNull + else + joinResult.protocolName.getOrElse(GroupCoordinator.NoProtocol) + + val responseBody = new JoinGroupResponse( + new JoinGroupResponseData() + .setThrottleTimeMs(requestThrottleMs) + .setErrorCode(joinResult.error.code) + .setGenerationId(joinResult.generationId) + .setProtocolType(joinResult.protocolType.orNull) + .setProtocolName(protocolName) + .setLeader(joinResult.leaderId) + .setMemberId(joinResult.memberId) + .setMembers(joinResult.members.asJava) + ) + + trace("Sending join group response %s for correlation id %d to client %s." + .format(responseBody, request.header.correlationId, request.header.clientId)) + responseBody + } + requestHelper.sendResponseMaybeThrottle(request, createResponse) + } + + if (joinGroupRequest.data.groupInstanceId != null && config.interBrokerProtocolVersion < KAFKA_2_3_IV0) { + // Only enable static membership when IBP >= 2.3, because it is not safe for the broker to use the static member logic + // until we are sure that all brokers support it. If static group being loaded by an older coordinator, it will discard + // the group.instance.id field, so static members could accidentally become "dynamic", which leads to wrong states. + sendResponseCallback(JoinGroupResult(JoinGroupRequest.UNKNOWN_MEMBER_ID, Errors.UNSUPPORTED_VERSION)) + } else if (!authHelper.authorize(request.context, READ, GROUP, joinGroupRequest.data.groupId)) { + sendResponseCallback(JoinGroupResult(JoinGroupRequest.UNKNOWN_MEMBER_ID, Errors.GROUP_AUTHORIZATION_FAILED)) + } else { + val groupInstanceId = Option(joinGroupRequest.data.groupInstanceId) + + // Only return MEMBER_ID_REQUIRED error if joinGroupRequest version is >= 4 + // and groupInstanceId is configured to unknown. 
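+      // In other words: a dynamic member on a v4+ JoinGroup is first rejected with MEMBER_ID_REQUIRED and
+      // a coordinator-assigned member id, and then re-joins with that id; static members (those providing
+      // a group.instance.id) skip this extra round trip.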
+ val requireKnownMemberId = joinGroupRequest.version >= 4 && groupInstanceId.isEmpty + + // let the coordinator handle join-group + val protocols = joinGroupRequest.data.protocols.valuesList.asScala.map(protocol => + (protocol.name, protocol.metadata)).toList + + groupCoordinator.handleJoinGroup( + joinGroupRequest.data.groupId, + joinGroupRequest.data.memberId, + groupInstanceId, + requireKnownMemberId, + request.header.clientId, + request.context.clientAddress.toString, + joinGroupRequest.data.rebalanceTimeoutMs, + joinGroupRequest.data.sessionTimeoutMs, + joinGroupRequest.data.protocolType, + protocols, + sendResponseCallback, + requestLocal) + } + } + + def handleSyncGroupRequest(request: RequestChannel.Request, requestLocal: RequestLocal): Unit = { + val syncGroupRequest = request.body[SyncGroupRequest] + + def sendResponseCallback(syncGroupResult: SyncGroupResult): Unit = { + requestHelper.sendResponseMaybeThrottle(request, requestThrottleMs => + new SyncGroupResponse( + new SyncGroupResponseData() + .setErrorCode(syncGroupResult.error.code) + .setProtocolType(syncGroupResult.protocolType.orNull) + .setProtocolName(syncGroupResult.protocolName.orNull) + .setAssignment(syncGroupResult.memberAssignment) + .setThrottleTimeMs(requestThrottleMs) + )) + } + + if (syncGroupRequest.data.groupInstanceId != null && config.interBrokerProtocolVersion < KAFKA_2_3_IV0) { + // Only enable static membership when IBP >= 2.3, because it is not safe for the broker to use the static member logic + // until we are sure that all brokers support it. If static group being loaded by an older coordinator, it will discard + // the group.instance.id field, so static members could accidentally become "dynamic", which leads to wrong states. + sendResponseCallback(SyncGroupResult(Errors.UNSUPPORTED_VERSION)) + } else if (!syncGroupRequest.areMandatoryProtocolTypeAndNamePresent()) { + // Starting from version 5, ProtocolType and ProtocolName fields are mandatory. 
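+      // The request-level check above is expected to pass unconditionally for pre-v5 requests, so only
+      // v5+ requests that omit ProtocolType or ProtocolName are rejected here with INCONSISTENT_GROUP_PROTOCOL.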
+ sendResponseCallback(SyncGroupResult(Errors.INCONSISTENT_GROUP_PROTOCOL)) + } else if (!authHelper.authorize(request.context, READ, GROUP, syncGroupRequest.data.groupId)) { + sendResponseCallback(SyncGroupResult(Errors.GROUP_AUTHORIZATION_FAILED)) + } else { + val assignmentMap = immutable.Map.newBuilder[String, Array[Byte]] + syncGroupRequest.data.assignments.forEach { assignment => + assignmentMap += (assignment.memberId -> assignment.assignment) + } + + groupCoordinator.handleSyncGroup( + syncGroupRequest.data.groupId, + syncGroupRequest.data.generationId, + syncGroupRequest.data.memberId, + Option(syncGroupRequest.data.protocolType), + Option(syncGroupRequest.data.protocolName), + Option(syncGroupRequest.data.groupInstanceId), + assignmentMap.result(), + sendResponseCallback, + requestLocal + ) + } + } + + def handleDeleteGroupsRequest(request: RequestChannel.Request, requestLocal: RequestLocal): Unit = { + val deleteGroupsRequest = request.body[DeleteGroupsRequest] + val groups = deleteGroupsRequest.data.groupsNames.asScala.distinct + + val (authorizedGroups, unauthorizedGroups) = authHelper.partitionSeqByAuthorized(request.context, DELETE, GROUP, + groups)(identity) + + val groupDeletionResult = groupCoordinator.handleDeleteGroups(authorizedGroups.toSet, requestLocal) ++ + unauthorizedGroups.map(_ -> Errors.GROUP_AUTHORIZATION_FAILED) + + requestHelper.sendResponseMaybeThrottle(request, requestThrottleMs => { + val deletionCollections = new DeletableGroupResultCollection() + groupDeletionResult.forKeyValue { (groupId, error) => + deletionCollections.add(new DeletableGroupResult() + .setGroupId(groupId) + .setErrorCode(error.code) + ) + } + + new DeleteGroupsResponse(new DeleteGroupsResponseData() + .setResults(deletionCollections) + .setThrottleTimeMs(requestThrottleMs) + ) + }) + } + + def handleHeartbeatRequest(request: RequestChannel.Request): Unit = { + val heartbeatRequest = request.body[HeartbeatRequest] + + // the callback for sending a heartbeat response + def sendResponseCallback(error: Errors): Unit = { + def createResponse(requestThrottleMs: Int): AbstractResponse = { + val response = new HeartbeatResponse( + new HeartbeatResponseData() + .setThrottleTimeMs(requestThrottleMs) + .setErrorCode(error.code)) + trace("Sending heartbeat response %s for correlation id %d to client %s." + .format(response, request.header.correlationId, request.header.clientId)) + response + } + requestHelper.sendResponseMaybeThrottle(request, createResponse) + } + + if (heartbeatRequest.data.groupInstanceId != null && config.interBrokerProtocolVersion < KAFKA_2_3_IV0) { + // Only enable static membership when IBP >= 2.3, because it is not safe for the broker to use the static member logic + // until we are sure that all brokers support it. If static group being loaded by an older coordinator, it will discard + // the group.instance.id field, so static members could accidentally become "dynamic", which leads to wrong states. 
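+      // An inter.broker.protocol.version below 2.3 means other brokers (and therefore a coordinator elected
+      // on them) may not understand group.instance.id yet, so the request is rejected with UNSUPPORTED_VERSION
+      // instead of silently dropping the static membership information.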
+ sendResponseCallback(Errors.UNSUPPORTED_VERSION) + } else if (!authHelper.authorize(request.context, READ, GROUP, heartbeatRequest.data.groupId)) { + requestHelper.sendResponseMaybeThrottle(request, requestThrottleMs => + new HeartbeatResponse( + new HeartbeatResponseData() + .setThrottleTimeMs(requestThrottleMs) + .setErrorCode(Errors.GROUP_AUTHORIZATION_FAILED.code))) + } else { + // let the coordinator to handle heartbeat + groupCoordinator.handleHeartbeat( + heartbeatRequest.data.groupId, + heartbeatRequest.data.memberId, + Option(heartbeatRequest.data.groupInstanceId), + heartbeatRequest.data.generationId, + sendResponseCallback) + } + } + + def handleLeaveGroupRequest(request: RequestChannel.Request): Unit = { + val leaveGroupRequest = request.body[LeaveGroupRequest] + + val members = leaveGroupRequest.members.asScala.toList + + if (!authHelper.authorize(request.context, READ, GROUP, leaveGroupRequest.data.groupId)) { + requestHelper.sendResponseMaybeThrottle(request, requestThrottleMs => { + new LeaveGroupResponse(new LeaveGroupResponseData() + .setThrottleTimeMs(requestThrottleMs) + .setErrorCode(Errors.GROUP_AUTHORIZATION_FAILED.code) + ) + }) + } else { + def sendResponseCallback(leaveGroupResult : LeaveGroupResult): Unit = { + val memberResponses = leaveGroupResult.memberResponses.map( + leaveGroupResult => + new MemberResponse() + .setErrorCode(leaveGroupResult.error.code) + .setMemberId(leaveGroupResult.memberId) + .setGroupInstanceId(leaveGroupResult.groupInstanceId.orNull) + ) + def createResponse(requestThrottleMs: Int): AbstractResponse = { + new LeaveGroupResponse( + memberResponses.asJava, + leaveGroupResult.topLevelError, + requestThrottleMs, + leaveGroupRequest.version) + } + requestHelper.sendResponseMaybeThrottle(request, createResponse) + } + + groupCoordinator.handleLeaveGroup( + leaveGroupRequest.data.groupId, + members, + sendResponseCallback) + } + } + + def handleSaslHandshakeRequest(request: RequestChannel.Request): Unit = { + val responseData = new SaslHandshakeResponseData().setErrorCode(Errors.ILLEGAL_SASL_STATE.code) + requestHelper.sendResponseMaybeThrottle(request, _ => new SaslHandshakeResponse(responseData)) + } + + def handleSaslAuthenticateRequest(request: RequestChannel.Request): Unit = { + val responseData = new SaslAuthenticateResponseData() + .setErrorCode(Errors.ILLEGAL_SASL_STATE.code) + .setErrorMessage("SaslAuthenticate request received after successful authentication") + requestHelper.sendResponseMaybeThrottle(request, _ => new SaslAuthenticateResponse(responseData)) + } + + def handleApiVersionsRequest(request: RequestChannel.Request): Unit = { + // Note that broker returns its full list of supported ApiKeys and versions regardless of current + // authentication state (e.g., before SASL authentication on an SASL listener, do note that no + // Kafka protocol requests may take place on an SSL listener before the SSL handshake is finished). + // If this is considered to leak information about the broker version a workaround is to use SSL + // with client authentication which is performed at an earlier stage of the connection where the + // ApiVersionRequest is not available. 
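+    // The callback below returns UNSUPPORTED_VERSION when the ApiVersionsRequest itself uses an unsupported
+    // version, INVALID_REQUEST when the request fails validation, and otherwise the broker's full set of
+    // supported APIs as provided by apiVersionManager.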
+ def createResponseCallback(requestThrottleMs: Int): ApiVersionsResponse = { + val apiVersionRequest = request.body[ApiVersionsRequest] + if (apiVersionRequest.hasUnsupportedRequestVersion) { + apiVersionRequest.getErrorResponse(requestThrottleMs, Errors.UNSUPPORTED_VERSION.exception) + } else if (!apiVersionRequest.isValid) { + apiVersionRequest.getErrorResponse(requestThrottleMs, Errors.INVALID_REQUEST.exception) + } else { + apiVersionManager.apiVersionResponse(requestThrottleMs) + } + } + requestHelper.sendResponseMaybeThrottle(request, createResponseCallback) + } + + def handleCreateTopicsRequest(request: RequestChannel.Request): Unit = { + val zkSupport = metadataSupport.requireZkOrThrow(KafkaApis.shouldAlwaysForward(request)) + val controllerMutationQuota = quotas.controllerMutation.newQuotaFor(request, strictSinceVersion = 6) + + def sendResponseCallback(results: CreatableTopicResultCollection): Unit = { + def createResponse(requestThrottleMs: Int): AbstractResponse = { + val responseData = new CreateTopicsResponseData() + .setThrottleTimeMs(requestThrottleMs) + .setTopics(results) + val responseBody = new CreateTopicsResponse(responseData) + trace(s"Sending create topics response $responseData for correlation id " + + s"${request.header.correlationId} to client ${request.header.clientId}.") + responseBody + } + requestHelper.sendResponseMaybeThrottleWithControllerQuota(controllerMutationQuota, request, createResponse) + } + + val createTopicsRequest = request.body[CreateTopicsRequest] + val results = new CreatableTopicResultCollection(createTopicsRequest.data.topics.size) + if (!zkSupport.controller.isActive) { + createTopicsRequest.data.topics.forEach { topic => + results.add(new CreatableTopicResult().setName(topic.name) + .setErrorCode(Errors.NOT_CONTROLLER.code)) + } + sendResponseCallback(results) + } else { + createTopicsRequest.data.topics.forEach { topic => + results.add(new CreatableTopicResult().setName(topic.name)) + } + val hasClusterAuthorization = authHelper.authorize(request.context, CREATE, CLUSTER, CLUSTER_NAME, + logIfDenied = false) + val topics = createTopicsRequest.data.topics.asScala.map(_.name) + val authorizedTopics = + if (hasClusterAuthorization) topics.toSet + else authHelper.filterByAuthorized(request.context, CREATE, TOPIC, topics)(identity) + val authorizedForDescribeConfigs = authHelper.filterByAuthorized(request.context, DESCRIBE_CONFIGS, TOPIC, + topics, logIfDenied = false)(identity).map(name => name -> results.find(name)).toMap + + results.forEach { topic => + if (results.findAll(topic.name).size > 1) { + topic.setErrorCode(Errors.INVALID_REQUEST.code) + topic.setErrorMessage("Found multiple entries for this topic.") + } else if (!authorizedTopics.contains(topic.name)) { + topic.setErrorCode(Errors.TOPIC_AUTHORIZATION_FAILED.code) + topic.setErrorMessage("Authorization failed.") + } + if (!authorizedForDescribeConfigs.contains(topic.name)) { + topic.setTopicConfigErrorCode(Errors.TOPIC_AUTHORIZATION_FAILED.code) + } + } + val toCreate = mutable.Map[String, CreatableTopic]() + createTopicsRequest.data.topics.forEach { topic => + if (results.find(topic.name).errorCode == Errors.NONE.code) { + toCreate += topic.name -> topic + } + } + def handleCreateTopicsResults(errors: Map[String, ApiError]): Unit = { + errors.foreach { case (topicName, error) => + val result = results.find(topicName) + result.setErrorCode(error.error.code) + .setErrorMessage(error.message) + // Reset any configs in the response if Create failed + if (error != ApiError.NONE) { + 
result.setConfigs(List.empty.asJava) + .setNumPartitions(-1) + .setReplicationFactor(-1) + .setTopicConfigErrorCode(Errors.NONE.code) + } + } + sendResponseCallback(results) + } + zkSupport.adminManager.createTopics( + createTopicsRequest.data.timeoutMs, + createTopicsRequest.data.validateOnly, + toCreate, + authorizedForDescribeConfigs, + controllerMutationQuota, + handleCreateTopicsResults) + } + } + + def handleCreatePartitionsRequest(request: RequestChannel.Request): Unit = { + val zkSupport = metadataSupport.requireZkOrThrow(KafkaApis.shouldAlwaysForward(request)) + val createPartitionsRequest = request.body[CreatePartitionsRequest] + val controllerMutationQuota = quotas.controllerMutation.newQuotaFor(request, strictSinceVersion = 3) + + def sendResponseCallback(results: Map[String, ApiError]): Unit = { + def createResponse(requestThrottleMs: Int): AbstractResponse = { + val createPartitionsResults = results.map { + case (topic, error) => new CreatePartitionsTopicResult() + .setName(topic) + .setErrorCode(error.error.code) + .setErrorMessage(error.message) + }.toSeq + val responseBody = new CreatePartitionsResponse(new CreatePartitionsResponseData() + .setThrottleTimeMs(requestThrottleMs) + .setResults(createPartitionsResults.asJava)) + trace(s"Sending create partitions response $responseBody for correlation id ${request.header.correlationId} to " + + s"client ${request.header.clientId}.") + responseBody + } + requestHelper.sendResponseMaybeThrottleWithControllerQuota(controllerMutationQuota, request, createResponse) + } + + if (!zkSupport.controller.isActive) { + val result = createPartitionsRequest.data.topics.asScala.map { topic => + (topic.name, new ApiError(Errors.NOT_CONTROLLER, null)) + }.toMap + sendResponseCallback(result) + } else { + // Special handling to add duplicate topics to the response + val topics = createPartitionsRequest.data.topics.asScala.toSeq + val dupes = topics.groupBy(_.name) + .filter { _._2.size > 1 } + .keySet + val notDuped = topics.filterNot(topic => dupes.contains(topic.name)) + val (authorized, unauthorized) = authHelper.partitionSeqByAuthorized(request.context, ALTER, TOPIC, + notDuped)(_.name) + + val (queuedForDeletion, valid) = authorized.partition { topic => + zkSupport.controller.topicDeletionManager.isTopicQueuedUpForDeletion(topic.name) + } + + val errors = dupes.map(_ -> new ApiError(Errors.INVALID_REQUEST, "Duplicate topic in request.")) ++ + unauthorized.map(_.name -> new ApiError(Errors.TOPIC_AUTHORIZATION_FAILED, "The topic authorization is failed.")) ++ + queuedForDeletion.map(_.name -> new ApiError(Errors.INVALID_TOPIC_EXCEPTION, "The topic is queued for deletion.")) + + zkSupport.adminManager.createPartitions( + createPartitionsRequest.data.timeoutMs, + valid, + createPartitionsRequest.data.validateOnly, + controllerMutationQuota, + result => sendResponseCallback(result ++ errors)) + } + } + + def handleDeleteTopicsRequest(request: RequestChannel.Request): Unit = { + val zkSupport = metadataSupport.requireZkOrThrow(KafkaApis.shouldAlwaysForward(request)) + val controllerMutationQuota = quotas.controllerMutation.newQuotaFor(request, strictSinceVersion = 5) + + def sendResponseCallback(results: DeletableTopicResultCollection): Unit = { + def createResponse(requestThrottleMs: Int): AbstractResponse = { + val responseData = new DeleteTopicsResponseData() + .setThrottleTimeMs(requestThrottleMs) + .setResponses(results) + val responseBody = new DeleteTopicsResponse(responseData) + trace(s"Sending delete topics response $responseBody for 
correlation id ${request.header.correlationId} to client ${request.header.clientId}.") + responseBody + } + requestHelper.sendResponseMaybeThrottleWithControllerQuota(controllerMutationQuota, request, createResponse) + } + + val deleteTopicRequest = request.body[DeleteTopicsRequest] + val results = new DeletableTopicResultCollection(deleteTopicRequest.numberOfTopics()) + val toDelete = mutable.Set[String]() + if (!zkSupport.controller.isActive) { + deleteTopicRequest.topics().forEach { topic => + results.add(new DeletableTopicResult() + .setName(topic.name()) + .setTopicId(topic.topicId()) + .setErrorCode(Errors.NOT_CONTROLLER.code)) + } + sendResponseCallback(results) + } else if (!config.deleteTopicEnable) { + val error = if (request.context.apiVersion < 3) Errors.INVALID_REQUEST else Errors.TOPIC_DELETION_DISABLED + deleteTopicRequest.topics().forEach { topic => + results.add(new DeletableTopicResult() + .setName(topic.name()) + .setTopicId(topic.topicId()) + .setErrorCode(error.code)) + } + sendResponseCallback(results) + } else { + val topicIdsFromRequest = deleteTopicRequest.topicIds().asScala.filter(topicId => topicId != Uuid.ZERO_UUID).toSet + deleteTopicRequest.topics().forEach { topic => + if (topic.name() != null && topic.topicId() != Uuid.ZERO_UUID) + throw new InvalidRequestException("Topic name and topic ID can not both be specified.") + val name = if (topic.topicId() == Uuid.ZERO_UUID) topic.name() + else zkSupport.controller.controllerContext.topicName(topic.topicId).orNull + results.add(new DeletableTopicResult() + .setName(name) + .setTopicId(topic.topicId())) + } + val authorizedDescribeTopics = authHelper.filterByAuthorized(request.context, DESCRIBE, TOPIC, + results.asScala.filter(result => result.name() != null))(_.name) + val authorizedDeleteTopics = authHelper.filterByAuthorized(request.context, DELETE, TOPIC, + results.asScala.filter(result => result.name() != null))(_.name) + results.forEach { topic => + val unresolvedTopicId = topic.topicId() != Uuid.ZERO_UUID && topic.name() == null + if (unresolvedTopicId) { + topic.setErrorCode(Errors.UNKNOWN_TOPIC_ID.code) + } else if (topicIdsFromRequest.contains(topic.topicId) && !authorizedDescribeTopics.contains(topic.name)) { + + // Because the client does not have Describe permission, the name should + // not be returned in the response. Note, however, that we do not consider + // the topicId itself to be sensitive, so there is no reason to obscure + // this case with `UNKNOWN_TOPIC_ID`. 
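+        // The name in this branch was resolved broker-side from the topic ID, so clearing it below avoids
+        // leaking a topic name to a caller that lacks Describe on the topic.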
+ topic.setName(null) + topic.setErrorCode(Errors.TOPIC_AUTHORIZATION_FAILED.code) + } else if (!authorizedDeleteTopics.contains(topic.name)) { + topic.setErrorCode(Errors.TOPIC_AUTHORIZATION_FAILED.code) + } else if (!metadataCache.contains(topic.name)) { + topic.setErrorCode(Errors.UNKNOWN_TOPIC_OR_PARTITION.code) + } else { + toDelete += topic.name + } + } + // If no authorized topics return immediately + if (toDelete.isEmpty) + sendResponseCallback(results) + else { + def handleDeleteTopicsResults(errors: Map[String, Errors]): Unit = { + errors.foreach { + case (topicName, error) => + results.find(topicName) + .setErrorCode(error.code) + } + sendResponseCallback(results) + } + + zkSupport.adminManager.deleteTopics( + deleteTopicRequest.data.timeoutMs, + toDelete, + controllerMutationQuota, + handleDeleteTopicsResults + ) + } + } + } + + def handleDeleteRecordsRequest(request: RequestChannel.Request): Unit = { + val deleteRecordsRequest = request.body[DeleteRecordsRequest] + + val unauthorizedTopicResponses = mutable.Map[TopicPartition, DeleteRecordsPartitionResult]() + val nonExistingTopicResponses = mutable.Map[TopicPartition, DeleteRecordsPartitionResult]() + val authorizedForDeleteTopicOffsets = mutable.Map[TopicPartition, Long]() + + val topics = deleteRecordsRequest.data.topics.asScala + val authorizedTopics = authHelper.filterByAuthorized(request.context, DELETE, TOPIC, topics)(_.name) + val deleteTopicPartitions = topics.flatMap { deleteTopic => + deleteTopic.partitions.asScala.map { deletePartition => + new TopicPartition(deleteTopic.name, deletePartition.partitionIndex) -> deletePartition.offset + } + } + for ((topicPartition, offset) <- deleteTopicPartitions) { + if (!authorizedTopics.contains(topicPartition.topic)) + unauthorizedTopicResponses += topicPartition -> new DeleteRecordsPartitionResult() + .setLowWatermark(DeleteRecordsResponse.INVALID_LOW_WATERMARK) + .setErrorCode(Errors.TOPIC_AUTHORIZATION_FAILED.code) + else if (!metadataCache.contains(topicPartition)) + nonExistingTopicResponses += topicPartition -> new DeleteRecordsPartitionResult() + .setLowWatermark(DeleteRecordsResponse.INVALID_LOW_WATERMARK) + .setErrorCode(Errors.UNKNOWN_TOPIC_OR_PARTITION.code) + else + authorizedForDeleteTopicOffsets += (topicPartition -> offset) + } + + // the callback for sending a DeleteRecordsResponse + def sendResponseCallback(authorizedTopicResponses: Map[TopicPartition, DeleteRecordsPartitionResult]): Unit = { + val mergedResponseStatus = authorizedTopicResponses ++ unauthorizedTopicResponses ++ nonExistingTopicResponses + mergedResponseStatus.forKeyValue { (topicPartition, status) => + if (status.errorCode != Errors.NONE.code) { + debug("DeleteRecordsRequest with correlation id %d from client %s on partition %s failed due to %s".format( + request.header.correlationId, + request.header.clientId, + topicPartition, + Errors.forCode(status.errorCode).exceptionName)) + } + } + + requestHelper.sendResponseMaybeThrottle(request, requestThrottleMs => + new DeleteRecordsResponse(new DeleteRecordsResponseData() + .setThrottleTimeMs(requestThrottleMs) + .setTopics(new DeleteRecordsResponseData.DeleteRecordsTopicResultCollection(mergedResponseStatus.groupBy(_._1.topic).map { case (topic, partitionMap) => { + new DeleteRecordsTopicResult() + .setName(topic) + .setPartitions(new DeleteRecordsResponseData.DeleteRecordsPartitionResultCollection(partitionMap.map { case (topicPartition, partitionResult) => { + new DeleteRecordsPartitionResult().setPartitionIndex(topicPartition.partition) + 
.setLowWatermark(partitionResult.lowWatermark) + .setErrorCode(partitionResult.errorCode) + } + }.toList.asJava.iterator())) + } + }.toList.asJava.iterator())))) + } + + if (authorizedForDeleteTopicOffsets.isEmpty) + sendResponseCallback(Map.empty) + else { + // call the replica manager to append messages to the replicas + replicaManager.deleteRecords( + deleteRecordsRequest.data.timeoutMs.toLong, + authorizedForDeleteTopicOffsets, + sendResponseCallback) + } + } + + def handleInitProducerIdRequest(request: RequestChannel.Request, requestLocal: RequestLocal): Unit = { + val initProducerIdRequest = request.body[InitProducerIdRequest] + val transactionalId = initProducerIdRequest.data.transactionalId + + if (transactionalId != null) { + if (!authHelper.authorize(request.context, WRITE, TRANSACTIONAL_ID, transactionalId)) { + requestHelper.sendErrorResponseMaybeThrottle(request, Errors.TRANSACTIONAL_ID_AUTHORIZATION_FAILED.exception) + return + } + } else if (!authHelper.authorize(request.context, IDEMPOTENT_WRITE, CLUSTER, CLUSTER_NAME, true, false) + && !authHelper.authorizeByResourceType(request.context, AclOperation.WRITE, ResourceType.TOPIC)) { + requestHelper.sendErrorResponseMaybeThrottle(request, Errors.CLUSTER_AUTHORIZATION_FAILED.exception) + return + } + + def sendResponseCallback(result: InitProducerIdResult): Unit = { + def createResponse(requestThrottleMs: Int): AbstractResponse = { + val finalError = + if (initProducerIdRequest.version < 4 && result.error == Errors.PRODUCER_FENCED) { + // For older clients, they could not understand the new PRODUCER_FENCED error code, + // so we need to return the INVALID_PRODUCER_EPOCH to have the same client handling logic. + Errors.INVALID_PRODUCER_EPOCH + } else { + result.error + } + val responseData = new InitProducerIdResponseData() + .setProducerId(result.producerId) + .setProducerEpoch(result.producerEpoch) + .setThrottleTimeMs(requestThrottleMs) + .setErrorCode(finalError.code) + val responseBody = new InitProducerIdResponse(responseData) + trace(s"Completed $transactionalId's InitProducerIdRequest with result $result from client ${request.header.clientId}.") + responseBody + } + requestHelper.sendResponseMaybeThrottle(request, createResponse) + } + + val producerIdAndEpoch = (initProducerIdRequest.data.producerId, initProducerIdRequest.data.producerEpoch) match { + case (RecordBatch.NO_PRODUCER_ID, RecordBatch.NO_PRODUCER_EPOCH) => Right(None) + case (RecordBatch.NO_PRODUCER_ID, _) | (_, RecordBatch.NO_PRODUCER_EPOCH) => Left(Errors.INVALID_REQUEST) + case (_, _) => Right(Some(new ProducerIdAndEpoch(initProducerIdRequest.data.producerId, initProducerIdRequest.data.producerEpoch))) + } + + producerIdAndEpoch match { + case Right(producerIdAndEpoch) => txnCoordinator.handleInitProducerId(transactionalId, initProducerIdRequest.data.transactionTimeoutMs, + producerIdAndEpoch, sendResponseCallback, requestLocal) + case Left(error) => requestHelper.sendErrorResponseMaybeThrottle(request, error.exception) + } + } + + def handleEndTxnRequest(request: RequestChannel.Request, requestLocal: RequestLocal): Unit = { + ensureInterBrokerVersion(KAFKA_0_11_0_IV0) + val endTxnRequest = request.body[EndTxnRequest] + val transactionalId = endTxnRequest.data.transactionalId + + if (authHelper.authorize(request.context, WRITE, TRANSACTIONAL_ID, transactionalId)) { + def sendResponseCallback(error: Errors): Unit = { + def createResponse(requestThrottleMs: Int): AbstractResponse = { + val finalError = + if (endTxnRequest.version < 2 && error == 
Errors.PRODUCER_FENCED) { + // For older clients, they could not understand the new PRODUCER_FENCED error code, + // so we need to return the INVALID_PRODUCER_EPOCH to have the same client handling logic. + Errors.INVALID_PRODUCER_EPOCH + } else { + error + } + val responseBody = new EndTxnResponse(new EndTxnResponseData() + .setErrorCode(finalError.code) + .setThrottleTimeMs(requestThrottleMs)) + trace(s"Completed ${endTxnRequest.data.transactionalId}'s EndTxnRequest " + + s"with committed: ${endTxnRequest.data.committed}, " + + s"errors: $error from client ${request.header.clientId}.") + responseBody + } + requestHelper.sendResponseMaybeThrottle(request, createResponse) + } + + txnCoordinator.handleEndTransaction(endTxnRequest.data.transactionalId, + endTxnRequest.data.producerId, + endTxnRequest.data.producerEpoch, + endTxnRequest.result(), + sendResponseCallback, + requestLocal) + } else + requestHelper.sendResponseMaybeThrottle(request, requestThrottleMs => + new EndTxnResponse(new EndTxnResponseData() + .setErrorCode(Errors.TRANSACTIONAL_ID_AUTHORIZATION_FAILED.code) + .setThrottleTimeMs(requestThrottleMs)) + ) + } + + def handleWriteTxnMarkersRequest(request: RequestChannel.Request, requestLocal: RequestLocal): Unit = { + ensureInterBrokerVersion(KAFKA_0_11_0_IV0) + authHelper.authorizeClusterOperation(request, CLUSTER_ACTION) + val writeTxnMarkersRequest = request.body[WriteTxnMarkersRequest] + val errors = new ConcurrentHashMap[java.lang.Long, util.Map[TopicPartition, Errors]]() + val markers = writeTxnMarkersRequest.markers + val numAppends = new AtomicInteger(markers.size) + + if (numAppends.get == 0) { + requestHelper.sendResponseExemptThrottle(request, new WriteTxnMarkersResponse(errors)) + return + } + + def updateErrors(producerId: Long, currentErrors: ConcurrentHashMap[TopicPartition, Errors]): Unit = { + val previousErrors = errors.putIfAbsent(producerId, currentErrors) + if (previousErrors != null) + previousErrors.putAll(currentErrors) + } + + /** + * This is the call back invoked when a log append of transaction markers succeeds. This can be called multiple + * times when handling a single WriteTxnMarkersRequest because there is one append per TransactionMarker in the + * request, so there could be multiple appends of markers to the log. The final response will be sent only + * after all appends have returned. 
+ */ + def maybeSendResponseCallback(producerId: Long, result: TransactionResult)(responseStatus: Map[TopicPartition, PartitionResponse]): Unit = { + trace(s"End transaction marker append for producer id $producerId completed with status: $responseStatus") + val currentErrors = new ConcurrentHashMap[TopicPartition, Errors](responseStatus.map { case (k, v) => k -> v.error }.asJava) + updateErrors(producerId, currentErrors) + val successfulOffsetsPartitions = responseStatus.filter { case (topicPartition, partitionResponse) => + topicPartition.topic == GROUP_METADATA_TOPIC_NAME && partitionResponse.error == Errors.NONE + }.keys + + if (successfulOffsetsPartitions.nonEmpty) { + // as soon as the end transaction marker has been written for a transactional offset commit, + // call to the group coordinator to materialize the offsets into the cache + try { + groupCoordinator.scheduleHandleTxnCompletion(producerId, successfulOffsetsPartitions, result) + } catch { + case e: Exception => + error(s"Received an exception while trying to update the offsets cache on transaction marker append", e) + val updatedErrors = new ConcurrentHashMap[TopicPartition, Errors]() + successfulOffsetsPartitions.foreach(updatedErrors.put(_, Errors.UNKNOWN_SERVER_ERROR)) + updateErrors(producerId, updatedErrors) + } + } + + if (numAppends.decrementAndGet() == 0) + requestHelper.sendResponseExemptThrottle(request, new WriteTxnMarkersResponse(errors)) + } + + // TODO: The current append API makes doing separate writes per producerId a little easier, but it would + // be nice to have only one append to the log. This requires pushing the building of the control records + // into Log so that we only append those having a valid producer epoch, and exposing a new appendControlRecord + // API in ReplicaManager. 
For now, we've done the simpler approach + var skippedMarkers = 0 + for (marker <- markers.asScala) { + val producerId = marker.producerId + val partitionsWithCompatibleMessageFormat = new mutable.ArrayBuffer[TopicPartition] + + val currentErrors = new ConcurrentHashMap[TopicPartition, Errors]() + marker.partitions.forEach { partition => + replicaManager.getMagic(partition) match { + case Some(magic) => + if (magic < RecordBatch.MAGIC_VALUE_V2) + currentErrors.put(partition, Errors.UNSUPPORTED_FOR_MESSAGE_FORMAT) + else + partitionsWithCompatibleMessageFormat += partition + case None => + currentErrors.put(partition, Errors.UNKNOWN_TOPIC_OR_PARTITION) + } + } + + if (!currentErrors.isEmpty) + updateErrors(producerId, currentErrors) + + if (partitionsWithCompatibleMessageFormat.isEmpty) { + numAppends.decrementAndGet() + skippedMarkers += 1 + } else { + val controlRecords = partitionsWithCompatibleMessageFormat.map { partition => + val controlRecordType = marker.transactionResult match { + case TransactionResult.COMMIT => ControlRecordType.COMMIT + case TransactionResult.ABORT => ControlRecordType.ABORT + } + val endTxnMarker = new EndTransactionMarker(controlRecordType, marker.coordinatorEpoch) + partition -> MemoryRecords.withEndTransactionMarker(producerId, marker.producerEpoch, endTxnMarker) + }.toMap + + replicaManager.appendRecords( + timeout = config.requestTimeoutMs.toLong, + requiredAcks = -1, + internalTopicsAllowed = true, + origin = AppendOrigin.Coordinator, + entriesPerPartition = controlRecords, + requestLocal = requestLocal, + responseCallback = maybeSendResponseCallback(producerId, marker.transactionResult)) + } + } + + // No log appends were written as all partitions had incorrect log format + // so we need to send the error response + if (skippedMarkers == markers.size) + requestHelper.sendResponseExemptThrottle(request, new WriteTxnMarkersResponse(errors)) + } + + def ensureInterBrokerVersion(version: ApiVersion): Unit = { + if (config.interBrokerProtocolVersion < version) + throw new UnsupportedVersionException(s"inter.broker.protocol.version: ${config.interBrokerProtocolVersion.version} is less than the required version: ${version.version}") + } + + def handleAddPartitionToTxnRequest(request: RequestChannel.Request, requestLocal: RequestLocal): Unit = { + ensureInterBrokerVersion(KAFKA_0_11_0_IV0) + val addPartitionsToTxnRequest = request.body[AddPartitionsToTxnRequest] + val transactionalId = addPartitionsToTxnRequest.data.transactionalId + val partitionsToAdd = addPartitionsToTxnRequest.partitions.asScala + if (!authHelper.authorize(request.context, WRITE, TRANSACTIONAL_ID, transactionalId)) + requestHelper.sendResponseMaybeThrottle(request, requestThrottleMs => + addPartitionsToTxnRequest.getErrorResponse(requestThrottleMs, Errors.TRANSACTIONAL_ID_AUTHORIZATION_FAILED.exception)) + else { + val unauthorizedTopicErrors = mutable.Map[TopicPartition, Errors]() + val nonExistingTopicErrors = mutable.Map[TopicPartition, Errors]() + val authorizedPartitions = mutable.Set[TopicPartition]() + + val authorizedTopics = authHelper.filterByAuthorized(request.context, WRITE, TOPIC, + partitionsToAdd.filterNot(tp => Topic.isInternal(tp.topic)))(_.topic) + for (topicPartition <- partitionsToAdd) { + if (!authorizedTopics.contains(topicPartition.topic)) + unauthorizedTopicErrors += topicPartition -> Errors.TOPIC_AUTHORIZATION_FAILED + else if (!metadataCache.contains(topicPartition)) + nonExistingTopicErrors += topicPartition -> Errors.UNKNOWN_TOPIC_OR_PARTITION + else + 
authorizedPartitions.add(topicPartition) + } + + if (unauthorizedTopicErrors.nonEmpty || nonExistingTopicErrors.nonEmpty) { + // Any failed partition check causes the entire request to fail. We send the appropriate error codes for the + // partitions which failed, and an 'OPERATION_NOT_ATTEMPTED' error code for the partitions which succeeded + // the authorization check to indicate that they were not added to the transaction. + val partitionErrors = unauthorizedTopicErrors ++ nonExistingTopicErrors ++ + authorizedPartitions.map(_ -> Errors.OPERATION_NOT_ATTEMPTED) + requestHelper.sendResponseMaybeThrottle(request, requestThrottleMs => + new AddPartitionsToTxnResponse(requestThrottleMs, partitionErrors.asJava)) + } else { + def sendResponseCallback(error: Errors): Unit = { + def createResponse(requestThrottleMs: Int): AbstractResponse = { + val finalError = + if (addPartitionsToTxnRequest.version < 2 && error == Errors.PRODUCER_FENCED) { + // For older clients, they could not understand the new PRODUCER_FENCED error code, + // so we need to return the old INVALID_PRODUCER_EPOCH to have the same client handling logic. + Errors.INVALID_PRODUCER_EPOCH + } else { + error + } + + val responseBody: AddPartitionsToTxnResponse = new AddPartitionsToTxnResponse(requestThrottleMs, + partitionsToAdd.map{tp => (tp, finalError)}.toMap.asJava) + trace(s"Completed $transactionalId's AddPartitionsToTxnRequest with partitions $partitionsToAdd: errors: $error from client ${request.header.clientId}") + responseBody + } + + requestHelper.sendResponseMaybeThrottle(request, createResponse) + } + + txnCoordinator.handleAddPartitionsToTransaction(transactionalId, + addPartitionsToTxnRequest.data.producerId, + addPartitionsToTxnRequest.data.producerEpoch, + authorizedPartitions, + sendResponseCallback, + requestLocal) + } + } + } + + def handleAddOffsetsToTxnRequest(request: RequestChannel.Request, requestLocal: RequestLocal): Unit = { + ensureInterBrokerVersion(KAFKA_0_11_0_IV0) + val addOffsetsToTxnRequest = request.body[AddOffsetsToTxnRequest] + val transactionalId = addOffsetsToTxnRequest.data.transactionalId + val groupId = addOffsetsToTxnRequest.data.groupId + val offsetTopicPartition = new TopicPartition(GROUP_METADATA_TOPIC_NAME, groupCoordinator.partitionFor(groupId)) + + if (!authHelper.authorize(request.context, WRITE, TRANSACTIONAL_ID, transactionalId)) + requestHelper.sendResponseMaybeThrottle(request, requestThrottleMs => + new AddOffsetsToTxnResponse(new AddOffsetsToTxnResponseData() + .setErrorCode(Errors.TRANSACTIONAL_ID_AUTHORIZATION_FAILED.code) + .setThrottleTimeMs(requestThrottleMs))) + else if (!authHelper.authorize(request.context, READ, GROUP, groupId)) + requestHelper.sendResponseMaybeThrottle(request, requestThrottleMs => + new AddOffsetsToTxnResponse(new AddOffsetsToTxnResponseData() + .setErrorCode(Errors.GROUP_AUTHORIZATION_FAILED.code) + .setThrottleTimeMs(requestThrottleMs)) + ) + else { + def sendResponseCallback(error: Errors): Unit = { + def createResponse(requestThrottleMs: Int): AbstractResponse = { + val finalError = + if (addOffsetsToTxnRequest.version < 2 && error == Errors.PRODUCER_FENCED) { + // For older clients, they could not understand the new PRODUCER_FENCED error code, + // so we need to return the old INVALID_PRODUCER_EPOCH to have the same client handling logic. 
+              Errors.INVALID_PRODUCER_EPOCH
+            } else {
+              error
+            }
+
+          val responseBody: AddOffsetsToTxnResponse = new AddOffsetsToTxnResponse(
+            new AddOffsetsToTxnResponseData()
+              .setErrorCode(finalError.code)
+              .setThrottleTimeMs(requestThrottleMs))
+          trace(s"Completed $transactionalId's AddOffsetsToTxnRequest for group $groupId on partition " +
+            s"$offsetTopicPartition: errors: $error from client ${request.header.clientId}")
+          responseBody
+        }
+        requestHelper.sendResponseMaybeThrottle(request, createResponse)
+      }
+
+      txnCoordinator.handleAddPartitionsToTransaction(transactionalId,
+        addOffsetsToTxnRequest.data.producerId,
+        addOffsetsToTxnRequest.data.producerEpoch,
+        Set(offsetTopicPartition),
+        sendResponseCallback,
+        requestLocal)
+    }
+  }
+
+  def handleTxnOffsetCommitRequest(request: RequestChannel.Request, requestLocal: RequestLocal): Unit = {
+    ensureInterBrokerVersion(KAFKA_0_11_0_IV0)
+    val header = request.header
+    val txnOffsetCommitRequest = request.body[TxnOffsetCommitRequest]
+
+    // authorize for the transactionalId and the consumer group. Note that we skip producerId authorization
+    // since it is implied by transactionalId authorization
+    if (!authHelper.authorize(request.context, WRITE, TRANSACTIONAL_ID, txnOffsetCommitRequest.data.transactionalId))
+      requestHelper.sendErrorResponseMaybeThrottle(request, Errors.TRANSACTIONAL_ID_AUTHORIZATION_FAILED.exception)
+    else if (!authHelper.authorize(request.context, READ, GROUP, txnOffsetCommitRequest.data.groupId))
+      requestHelper.sendErrorResponseMaybeThrottle(request, Errors.GROUP_AUTHORIZATION_FAILED.exception)
+    else {
+      val unauthorizedTopicErrors = mutable.Map[TopicPartition, Errors]()
+      val nonExistingTopicErrors = mutable.Map[TopicPartition, Errors]()
+      val authorizedTopicCommittedOffsets = mutable.Map[TopicPartition, TxnOffsetCommitRequest.CommittedOffset]()
+      val committedOffsets = txnOffsetCommitRequest.offsets.asScala
+      val authorizedTopics = authHelper.filterByAuthorized(request.context, READ, TOPIC, committedOffsets)(_._1.topic)
+
+      for ((topicPartition, committedOffset) <- committedOffsets) {
+        if (!authorizedTopics.contains(topicPartition.topic))
+          unauthorizedTopicErrors += topicPartition -> Errors.TOPIC_AUTHORIZATION_FAILED
+        else if (!metadataCache.contains(topicPartition))
+          nonExistingTopicErrors += topicPartition -> Errors.UNKNOWN_TOPIC_OR_PARTITION
+        else
+          authorizedTopicCommittedOffsets += (topicPartition -> committedOffset)
+      }
+
+      // the callback for sending an offset commit response
+      def sendResponseCallback(authorizedTopicErrors: Map[TopicPartition, Errors]): Unit = {
+        val combinedCommitStatus = mutable.Map() ++= authorizedTopicErrors ++= unauthorizedTopicErrors ++= nonExistingTopicErrors
+        if (isDebugEnabled)
+          combinedCommitStatus.forKeyValue { (topicPartition, error) =>
+            if (error != Errors.NONE) {
+              debug(s"TxnOffsetCommit with correlation id ${header.correlationId} from client ${header.clientId} " +
+                s"on partition $topicPartition failed due to ${error.exceptionName}")
+            }
+          }
+
+        // We need to replace COORDINATOR_LOAD_IN_PROGRESS with COORDINATOR_NOT_AVAILABLE
+        // for older producer clients (0.11 up to, but not including, 2.0), which could potentially crash due
+        // to an unexpected loading error. This bug was later fixed by KAFKA-7296. Clients using
+        // txn commit protocol >= 2 (version 2.3 and onwards) are guaranteed to have
+        // the fix to check for the loading error.
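+        // The remapping below rewrites the affected entries in place; only the error codes change, the
+        // set of partitions in the response stays the same.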
+ if (txnOffsetCommitRequest.version < 2) { + combinedCommitStatus ++= combinedCommitStatus.collect { + case (tp, error) if error == Errors.COORDINATOR_LOAD_IN_PROGRESS => tp -> Errors.COORDINATOR_NOT_AVAILABLE + } + } + + requestHelper.sendResponseMaybeThrottle(request, requestThrottleMs => + new TxnOffsetCommitResponse(requestThrottleMs, combinedCommitStatus.asJava)) + } + + if (authorizedTopicCommittedOffsets.isEmpty) + sendResponseCallback(Map.empty) + else { + val offsetMetadata = convertTxnOffsets(authorizedTopicCommittedOffsets.toMap) + groupCoordinator.handleTxnCommitOffsets( + txnOffsetCommitRequest.data.groupId, + txnOffsetCommitRequest.data.producerId, + txnOffsetCommitRequest.data.producerEpoch, + txnOffsetCommitRequest.data.memberId, + Option(txnOffsetCommitRequest.data.groupInstanceId), + txnOffsetCommitRequest.data.generationId, + offsetMetadata, + sendResponseCallback, + requestLocal) + } + } + } + + private def convertTxnOffsets(offsetsMap: immutable.Map[TopicPartition, TxnOffsetCommitRequest.CommittedOffset]): immutable.Map[TopicPartition, OffsetAndMetadata] = { + val currentTimestamp = time.milliseconds + offsetsMap.map { case (topicPartition, partitionData) => + val metadata = if (partitionData.metadata == null) OffsetAndMetadata.NoMetadata else partitionData.metadata + topicPartition -> new OffsetAndMetadata( + offset = partitionData.offset, + leaderEpoch = partitionData.leaderEpoch, + metadata = metadata, + commitTimestamp = currentTimestamp, + expireTimestamp = None) + } + } + + def handleDescribeAcls(request: RequestChannel.Request): Unit = { + aclApis.handleDescribeAcls(request) + } + + def handleCreateAcls(request: RequestChannel.Request): Unit = { + metadataSupport.requireZkOrThrow(KafkaApis.shouldAlwaysForward(request)) + aclApis.handleCreateAcls(request) + } + + def handleDeleteAcls(request: RequestChannel.Request): Unit = { + metadataSupport.requireZkOrThrow(KafkaApis.shouldAlwaysForward(request)) + aclApis.handleDeleteAcls(request) + } + + def handleOffsetForLeaderEpochRequest(request: RequestChannel.Request): Unit = { + val offsetForLeaderEpoch = request.body[OffsetsForLeaderEpochRequest] + val topics = offsetForLeaderEpoch.data.topics.asScala.toSeq + + // The OffsetsForLeaderEpoch API was initially only used for inter-broker communication and required + // cluster permission. With KIP-320, the consumer now also uses this API to check for log truncation + // following a leader change, so we also allow topic describe permission. 
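+    // Principals with ClusterAction (i.e. brokers) may query every topic; other principals are limited to
+    // the topics they are allowed to describe, and the remaining topics are answered with
+    // TOPIC_AUTHORIZATION_FAILED below.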
+ val (authorizedTopics, unauthorizedTopics) = + if (authHelper.authorize(request.context, CLUSTER_ACTION, CLUSTER, CLUSTER_NAME, logIfDenied = false)) + (topics, Seq.empty[OffsetForLeaderTopic]) + else authHelper.partitionSeqByAuthorized(request.context, DESCRIBE, TOPIC, topics)(_.topic) + + val endOffsetsForAuthorizedPartitions = replicaManager.lastOffsetForLeaderEpoch(authorizedTopics) + val endOffsetsForUnauthorizedPartitions = unauthorizedTopics.map { offsetForLeaderTopic => + val partitions = offsetForLeaderTopic.partitions.asScala.map { offsetForLeaderPartition => + new EpochEndOffset() + .setPartition(offsetForLeaderPartition.partition) + .setErrorCode(Errors.TOPIC_AUTHORIZATION_FAILED.code) + } + + new OffsetForLeaderTopicResult() + .setTopic(offsetForLeaderTopic.topic) + .setPartitions(partitions.toList.asJava) + } + + val endOffsetsForAllTopics = new OffsetForLeaderTopicResultCollection( + (endOffsetsForAuthorizedPartitions ++ endOffsetsForUnauthorizedPartitions).asJava.iterator + ) + + requestHelper.sendResponseMaybeThrottle(request, requestThrottleMs => + new OffsetsForLeaderEpochResponse(new OffsetForLeaderEpochResponseData() + .setThrottleTimeMs(requestThrottleMs) + .setTopics(endOffsetsForAllTopics))) + } + + def handleAlterConfigsRequest(request: RequestChannel.Request): Unit = { + val zkSupport = metadataSupport.requireZkOrThrow(KafkaApis.shouldAlwaysForward(request)) + val alterConfigsRequest = request.body[AlterConfigsRequest] + val (authorizedResources, unauthorizedResources) = alterConfigsRequest.configs.asScala.toMap.partition { case (resource, _) => + resource.`type` match { + case ConfigResource.Type.BROKER_LOGGER => + throw new InvalidRequestException(s"AlterConfigs is deprecated and does not support the resource type ${ConfigResource.Type.BROKER_LOGGER}") + case ConfigResource.Type.BROKER => + authHelper.authorize(request.context, ALTER_CONFIGS, CLUSTER, CLUSTER_NAME) + case ConfigResource.Type.TOPIC => + authHelper.authorize(request.context, ALTER_CONFIGS, TOPIC, resource.name) + case rt => throw new InvalidRequestException(s"Unexpected resource type $rt") + } + } + val authorizedResult = zkSupport.adminManager.alterConfigs(authorizedResources, alterConfigsRequest.validateOnly) + val unauthorizedResult = unauthorizedResources.keys.map { resource => + resource -> configsAuthorizationApiError(resource) + } + def responseCallback(requestThrottleMs: Int): AlterConfigsResponse = { + val data = new AlterConfigsResponseData() + .setThrottleTimeMs(requestThrottleMs) + (authorizedResult ++ unauthorizedResult).foreach{ case (resource, error) => + data.responses().add(new AlterConfigsResourceResponse() + .setErrorCode(error.error.code) + .setErrorMessage(error.message) + .setResourceName(resource.name) + .setResourceType(resource.`type`.id)) + } + new AlterConfigsResponse(data) + } + requestHelper.sendResponseMaybeThrottle(request, responseCallback) + } + + def handleAlterPartitionReassignmentsRequest(request: RequestChannel.Request): Unit = { + val zkSupport = metadataSupport.requireZkOrThrow(KafkaApis.shouldAlwaysForward(request)) + authHelper.authorizeClusterOperation(request, ALTER) + val alterPartitionReassignmentsRequest = request.body[AlterPartitionReassignmentsRequest] + + def sendResponseCallback(result: Either[Map[TopicPartition, ApiError], ApiError]): Unit = { + val responseData = result match { + case Right(topLevelError) => + new AlterPartitionReassignmentsResponseData().setErrorMessage(topLevelError.message).setErrorCode(topLevelError.error.code) + + case 
Left(assignments) => + val topicResponses = assignments.groupBy(_._1.topic).map { + case (topic, reassignmentsByTp) => + val partitionResponses = reassignmentsByTp.map { + case (topicPartition, error) => + new ReassignablePartitionResponse().setPartitionIndex(topicPartition.partition) + .setErrorCode(error.error.code).setErrorMessage(error.message) + } + new ReassignableTopicResponse().setName(topic).setPartitions(partitionResponses.toList.asJava) + } + new AlterPartitionReassignmentsResponseData().setResponses(topicResponses.toList.asJava) + } + + requestHelper.sendResponseMaybeThrottle(request, requestThrottleMs => + new AlterPartitionReassignmentsResponse(responseData.setThrottleTimeMs(requestThrottleMs)) + ) + } + + val reassignments = alterPartitionReassignmentsRequest.data.topics.asScala.flatMap { + reassignableTopic => reassignableTopic.partitions.asScala.map { + reassignablePartition => + val tp = new TopicPartition(reassignableTopic.name, reassignablePartition.partitionIndex) + if (reassignablePartition.replicas == null) + tp -> None // revert call + else + tp -> Some(reassignablePartition.replicas.asScala.map(_.toInt)) + } + }.toMap + + zkSupport.controller.alterPartitionReassignments(reassignments, sendResponseCallback) + } + + def handleListPartitionReassignmentsRequest(request: RequestChannel.Request): Unit = { + val zkSupport = metadataSupport.requireZkOrThrow(KafkaApis.shouldAlwaysForward(request)) + authHelper.authorizeClusterOperation(request, DESCRIBE) + val listPartitionReassignmentsRequest = request.body[ListPartitionReassignmentsRequest] + + def sendResponseCallback(result: Either[Map[TopicPartition, ReplicaAssignment], ApiError]): Unit = { + val responseData = result match { + case Right(error) => new ListPartitionReassignmentsResponseData().setErrorMessage(error.message).setErrorCode(error.error.code) + + case Left(assignments) => + val topicReassignments = assignments.groupBy(_._1.topic).map { + case (topic, reassignmentsByTp) => + val partitionReassignments = reassignmentsByTp.map { + case (topicPartition, assignment) => + new ListPartitionReassignmentsResponseData.OngoingPartitionReassignment() + .setPartitionIndex(topicPartition.partition) + .setAddingReplicas(assignment.addingReplicas.toList.asJava.asInstanceOf[java.util.List[java.lang.Integer]]) + .setRemovingReplicas(assignment.removingReplicas.toList.asJava.asInstanceOf[java.util.List[java.lang.Integer]]) + .setReplicas(assignment.replicas.toList.asJava.asInstanceOf[java.util.List[java.lang.Integer]]) + }.toList + + new ListPartitionReassignmentsResponseData.OngoingTopicReassignment().setName(topic) + .setPartitions(partitionReassignments.asJava) + }.toList + + new ListPartitionReassignmentsResponseData().setTopics(topicReassignments.asJava) + } + + requestHelper.sendResponseMaybeThrottle(request, requestThrottleMs => + new ListPartitionReassignmentsResponse(responseData.setThrottleTimeMs(requestThrottleMs)) + ) + } + + val partitionsOpt = listPartitionReassignmentsRequest.data.topics match { + case topics: Any => + Some(topics.iterator().asScala.flatMap { topic => + topic.partitionIndexes.iterator().asScala + .map { tp => new TopicPartition(topic.name(), tp) } + }.toSet) + case _ => None + } + + zkSupport.controller.listPartitionReassignments(partitionsOpt, sendResponseCallback) + } + + private def configsAuthorizationApiError(resource: ConfigResource): ApiError = { + val error = resource.`type` match { + case ConfigResource.Type.BROKER | ConfigResource.Type.BROKER_LOGGER => 
Errors.CLUSTER_AUTHORIZATION_FAILED + case ConfigResource.Type.TOPIC => Errors.TOPIC_AUTHORIZATION_FAILED + case rt => throw new InvalidRequestException(s"Unexpected resource type $rt for resource ${resource.name}") + } + new ApiError(error, null) + } + + def handleIncrementalAlterConfigsRequest(request: RequestChannel.Request): Unit = { + val zkSupport = metadataSupport.requireZkOrThrow(KafkaApis.shouldAlwaysForward(request)) + val alterConfigsRequest = request.body[IncrementalAlterConfigsRequest] + + val configs = alterConfigsRequest.data.resources.iterator.asScala.map { alterConfigResource => + val configResource = new ConfigResource(ConfigResource.Type.forId(alterConfigResource.resourceType), + alterConfigResource.resourceName) + configResource -> alterConfigResource.configs.iterator.asScala.map { + alterConfig => new AlterConfigOp(new ConfigEntry(alterConfig.name, alterConfig.value), + OpType.forId(alterConfig.configOperation)) + }.toBuffer + }.toMap + + val (authorizedResources, unauthorizedResources) = configs.partition { case (resource, _) => + resource.`type` match { + case ConfigResource.Type.BROKER | ConfigResource.Type.BROKER_LOGGER => + authHelper.authorize(request.context, ALTER_CONFIGS, CLUSTER, CLUSTER_NAME) + case ConfigResource.Type.TOPIC => + authHelper.authorize(request.context, ALTER_CONFIGS, TOPIC, resource.name) + case rt => throw new InvalidRequestException(s"Unexpected resource type $rt") + } + } + + val authorizedResult = zkSupport.adminManager.incrementalAlterConfigs(authorizedResources, alterConfigsRequest.data.validateOnly) + val unauthorizedResult = unauthorizedResources.keys.map { resource => + resource -> configsAuthorizationApiError(resource) + } + + requestHelper.sendResponseMaybeThrottle(request, requestThrottleMs => new IncrementalAlterConfigsResponse( + requestThrottleMs, (authorizedResult ++ unauthorizedResult).asJava)) + } + + def handleDescribeConfigsRequest(request: RequestChannel.Request): Unit = { + val describeConfigsRequest = request.body[DescribeConfigsRequest] + val (authorizedResources, unauthorizedResources) = describeConfigsRequest.data.resources.asScala.partition { resource => + ConfigResource.Type.forId(resource.resourceType) match { + case ConfigResource.Type.BROKER | ConfigResource.Type.BROKER_LOGGER => + authHelper.authorize(request.context, DESCRIBE_CONFIGS, CLUSTER, CLUSTER_NAME) + case ConfigResource.Type.TOPIC => + authHelper.authorize(request.context, DESCRIBE_CONFIGS, TOPIC, resource.resourceName) + case rt => throw new InvalidRequestException(s"Unexpected resource type $rt for resource ${resource.resourceName}") + } + } + val authorizedConfigs = configHelper.describeConfigs(authorizedResources.toList, describeConfigsRequest.data.includeSynonyms, describeConfigsRequest.data.includeDocumentation) + val unauthorizedConfigs = unauthorizedResources.map { resource => + val error = ConfigResource.Type.forId(resource.resourceType) match { + case ConfigResource.Type.BROKER | ConfigResource.Type.BROKER_LOGGER => Errors.CLUSTER_AUTHORIZATION_FAILED + case ConfigResource.Type.TOPIC => Errors.TOPIC_AUTHORIZATION_FAILED + case rt => throw new InvalidRequestException(s"Unexpected resource type $rt for resource ${resource.resourceName}") + } + new DescribeConfigsResponseData.DescribeConfigsResult().setErrorCode(error.code) + .setErrorMessage(error.message) + .setConfigs(Collections.emptyList[DescribeConfigsResponseData.DescribeConfigsResourceResult]) + .setResourceName(resource.resourceName) + .setResourceType(resource.resourceType) + } + + 
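+    // Describe results for authorized resources and per-resource authorization errors are returned in a
+    // single DescribeConfigs response.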
requestHelper.sendResponseMaybeThrottle(request, requestThrottleMs => + new DescribeConfigsResponse(new DescribeConfigsResponseData().setThrottleTimeMs(requestThrottleMs) + .setResults((authorizedConfigs ++ unauthorizedConfigs).asJava))) + } + + def handleAlterReplicaLogDirsRequest(request: RequestChannel.Request): Unit = { + val alterReplicaDirsRequest = request.body[AlterReplicaLogDirsRequest] + if (authHelper.authorize(request.context, ALTER, CLUSTER, CLUSTER_NAME)) { + val result = replicaManager.alterReplicaLogDirs(alterReplicaDirsRequest.partitionDirs.asScala) + requestHelper.sendResponseMaybeThrottle(request, requestThrottleMs => + new AlterReplicaLogDirsResponse(new AlterReplicaLogDirsResponseData() + .setResults(result.groupBy(_._1.topic).map { + case (topic, errors) => new AlterReplicaLogDirsResponseData.AlterReplicaLogDirTopicResult() + .setTopicName(topic) + .setPartitions(errors.map { + case (tp, error) => new AlterReplicaLogDirsResponseData.AlterReplicaLogDirPartitionResult() + .setPartitionIndex(tp.partition) + .setErrorCode(error.code) + }.toList.asJava) + }.toList.asJava) + .setThrottleTimeMs(requestThrottleMs))) + } else { + requestHelper.sendResponseMaybeThrottle(request, requestThrottleMs => + alterReplicaDirsRequest.getErrorResponse(requestThrottleMs, Errors.CLUSTER_AUTHORIZATION_FAILED.exception)) + } + } + + def handleDescribeLogDirsRequest(request: RequestChannel.Request): Unit = { + val describeLogDirsDirRequest = request.body[DescribeLogDirsRequest] + val logDirInfos = { + if (authHelper.authorize(request.context, DESCRIBE, CLUSTER, CLUSTER_NAME)) { + val partitions = + if (describeLogDirsDirRequest.isAllTopicPartitions) + replicaManager.logManager.allLogs.map(_.topicPartition).toSet + else + describeLogDirsDirRequest.data.topics.asScala.flatMap( + logDirTopic => logDirTopic.partitions.asScala.map(partitionIndex => + new TopicPartition(logDirTopic.topic, partitionIndex))).toSet + + replicaManager.describeLogDirs(partitions) + } else { + List.empty[DescribeLogDirsResponseData.DescribeLogDirsResult] + } + } + requestHelper.sendResponseMaybeThrottle(request, throttleTimeMs => new DescribeLogDirsResponse(new DescribeLogDirsResponseData() + .setThrottleTimeMs(throttleTimeMs) + .setResults(logDirInfos.asJava))) + } + + def handleCreateTokenRequest(request: RequestChannel.Request): Unit = { + metadataSupport.requireZkOrThrow(KafkaApis.shouldAlwaysForward(request)) + val createTokenRequest = request.body[CreateDelegationTokenRequest] + + // the callback for sending a create token response + def sendResponseCallback(createResult: CreateTokenResult): Unit = { + trace(s"Sending create token response for correlation id ${request.header.correlationId} " + + s"to client ${request.header.clientId}.") + requestHelper.sendResponseMaybeThrottle(request, requestThrottleMs => + CreateDelegationTokenResponse.prepareResponse(requestThrottleMs, createResult.error, request.context.principal, createResult.issueTimestamp, + createResult.expiryTimestamp, createResult.maxTimestamp, createResult.tokenId, ByteBuffer.wrap(createResult.hmac))) + } + + if (!allowTokenRequests(request)) + requestHelper.sendResponseMaybeThrottle(request, requestThrottleMs => + CreateDelegationTokenResponse.prepareResponse(requestThrottleMs, Errors.DELEGATION_TOKEN_REQUEST_NOT_ALLOWED, request.context.principal)) + else { + val renewerList = createTokenRequest.data.renewers.asScala.toList.map(entry => + new KafkaPrincipal(entry.principalType, entry.principalName)) + + if (renewerList.exists(principal => 
principal.getPrincipalType != KafkaPrincipal.USER_TYPE)) { + requestHelper.sendResponseMaybeThrottle(request, requestThrottleMs => + CreateDelegationTokenResponse.prepareResponse(requestThrottleMs, Errors.INVALID_PRINCIPAL_TYPE, request.context.principal)) + } + else { + tokenManager.createToken( + request.context.principal, + renewerList, + createTokenRequest.data.maxLifetimeMs, + sendResponseCallback + ) + } + } + } + + def handleRenewTokenRequest(request: RequestChannel.Request): Unit = { + metadataSupport.requireZkOrThrow(KafkaApis.shouldAlwaysForward(request)) + val renewTokenRequest = request.body[RenewDelegationTokenRequest] + + // the callback for sending a renew token response + def sendResponseCallback(error: Errors, expiryTimestamp: Long): Unit = { + trace("Sending renew token response for correlation id %d to client %s." + .format(request.header.correlationId, request.header.clientId)) + requestHelper.sendResponseMaybeThrottle(request, requestThrottleMs => + new RenewDelegationTokenResponse( + new RenewDelegationTokenResponseData() + .setThrottleTimeMs(requestThrottleMs) + .setErrorCode(error.code) + .setExpiryTimestampMs(expiryTimestamp))) + } + + if (!allowTokenRequests(request)) + sendResponseCallback(Errors.DELEGATION_TOKEN_REQUEST_NOT_ALLOWED, DelegationTokenManager.ErrorTimestamp) + else { + tokenManager.renewToken( + request.context.principal, + ByteBuffer.wrap(renewTokenRequest.data.hmac), + renewTokenRequest.data.renewPeriodMs, + sendResponseCallback + ) + } + } + + def handleExpireTokenRequest(request: RequestChannel.Request): Unit = { + metadataSupport.requireZkOrThrow(KafkaApis.shouldAlwaysForward(request)) + val expireTokenRequest = request.body[ExpireDelegationTokenRequest] + + // the callback for sending a expire token response + def sendResponseCallback(error: Errors, expiryTimestamp: Long): Unit = { + trace("Sending expire token response for correlation id %d to client %s." + .format(request.header.correlationId, request.header.clientId)) + requestHelper.sendResponseMaybeThrottle(request, requestThrottleMs => + new ExpireDelegationTokenResponse( + new ExpireDelegationTokenResponseData() + .setThrottleTimeMs(requestThrottleMs) + .setErrorCode(error.code) + .setExpiryTimestampMs(expiryTimestamp))) + } + + if (!allowTokenRequests(request)) + sendResponseCallback(Errors.DELEGATION_TOKEN_REQUEST_NOT_ALLOWED, DelegationTokenManager.ErrorTimestamp) + else { + tokenManager.expireToken( + request.context.principal, + expireTokenRequest.hmac(), + expireTokenRequest.expiryTimePeriod(), + sendResponseCallback + ) + } + } + + def handleDescribeTokensRequest(request: RequestChannel.Request): Unit = { + val describeTokenRequest = request.body[DescribeDelegationTokenRequest] + + // the callback for sending a describe token response + def sendResponseCallback(error: Errors, tokenDetails: List[DelegationToken]): Unit = { + requestHelper.sendResponseMaybeThrottle(request, requestThrottleMs => + new DescribeDelegationTokenResponse(requestThrottleMs, error, tokenDetails.asJava)) + trace("Sending describe token response for correlation id %d to client %s." 
+        .format(request.header.correlationId, request.header.clientId))
+    }
+
+    if (!allowTokenRequests(request))
+      sendResponseCallback(Errors.DELEGATION_TOKEN_REQUEST_NOT_ALLOWED, List.empty)
+    else if (!config.tokenAuthEnabled)
+      sendResponseCallback(Errors.DELEGATION_TOKEN_AUTH_DISABLED, List.empty)
+    else {
+      val requestPrincipal = request.context.principal
+
+      if (describeTokenRequest.ownersListEmpty()) {
+        sendResponseCallback(Errors.NONE, List())
+      }
+      else {
+        val owners = if (describeTokenRequest.data.owners == null)
+          None
+        else
+          Some(describeTokenRequest.data.owners.asScala.map(p => new KafkaPrincipal(p.principalType(), p.principalName)).toList)
+        def authorizeToken(tokenId: String) = authHelper.authorize(request.context, DESCRIBE, DELEGATION_TOKEN, tokenId)
+        def eligible(token: TokenInformation) = DelegationTokenManager.filterToken(requestPrincipal, owners, token, authorizeToken)
+        val tokens = tokenManager.getTokens(eligible)
+        sendResponseCallback(Errors.NONE, tokens)
+      }
+    }
+  }
+
+  def allowTokenRequests(request: RequestChannel.Request): Boolean = {
+    val protocol = request.context.securityProtocol
+    if (request.context.principal.tokenAuthenticated ||
+      protocol == SecurityProtocol.PLAINTEXT ||
+      // disallow requests from 1-way SSL
+      (protocol == SecurityProtocol.SSL && request.context.principal == KafkaPrincipal.ANONYMOUS))
+      false
+    else
+      true
+  }
+
+  def handleElectLeaders(request: RequestChannel.Request): Unit = {
+    val zkSupport = metadataSupport.requireZkOrThrow(KafkaApis.shouldAlwaysForward(request))
+    val electionRequest = request.body[ElectLeadersRequest]
+
+    def sendResponseCallback(
+      error: ApiError
+    )(
+      results: Map[TopicPartition, ApiError]
+    ): Unit = {
+      requestHelper.sendResponseMaybeThrottle(request, requestThrottleMs => {
+        val adjustedResults = if (electionRequest.data.topicPartitions == null) {
+          /* When performing elections across all of the partitions we should only return
+           * partitions for which an election was attempted or which resulted in an error. In other
+           * words, partitions that didn't need election because they already have the correct
+           * leader are not returned to the client.
+ */ + results.filter { case (_, error) => + error.error != Errors.ELECTION_NOT_NEEDED + } + } else results + + val electionResults = new util.ArrayList[ReplicaElectionResult]() + adjustedResults + .groupBy { case (tp, _) => tp.topic } + .forKeyValue { (topic, ps) => + val electionResult = new ReplicaElectionResult() + + electionResult.setTopic(topic) + ps.forKeyValue { (topicPartition, error) => + val partitionResult = new PartitionResult() + partitionResult.setPartitionId(topicPartition.partition) + partitionResult.setErrorCode(error.error.code) + partitionResult.setErrorMessage(error.message) + electionResult.partitionResult.add(partitionResult) + } + + electionResults.add(electionResult) + } + + new ElectLeadersResponse( + requestThrottleMs, + error.error.code, + electionResults, + electionRequest.version + ) + }) + } + + if (!authHelper.authorize(request.context, ALTER, CLUSTER, CLUSTER_NAME)) { + val error = new ApiError(Errors.CLUSTER_AUTHORIZATION_FAILED, null) + val partitionErrors: Map[TopicPartition, ApiError] = + electionRequest.topicPartitions.iterator.map(partition => partition -> error).toMap + + sendResponseCallback(error)(partitionErrors) + } else { + val partitions = if (electionRequest.data.topicPartitions == null) { + metadataCache.getAllTopics().flatMap(metadataCache.getTopicPartitions) + } else { + electionRequest.topicPartitions + } + + replicaManager.electLeaders( + zkSupport.controller, + partitions, + electionRequest.electionType, + sendResponseCallback(ApiError.NONE), + electionRequest.data.timeoutMs + ) + } + } + + def handleOffsetDeleteRequest(request: RequestChannel.Request, requestLocal: RequestLocal): Unit = { + val offsetDeleteRequest = request.body[OffsetDeleteRequest] + val groupId = offsetDeleteRequest.data.groupId + + if (authHelper.authorize(request.context, DELETE, GROUP, groupId)) { + val topics = offsetDeleteRequest.data.topics.asScala + val authorizedTopics = authHelper.filterByAuthorized(request.context, READ, TOPIC, topics)(_.name) + + val topicPartitionErrors = mutable.Map[TopicPartition, Errors]() + val topicPartitions = mutable.ArrayBuffer[TopicPartition]() + + for (topic <- topics) { + for (partition <- topic.partitions.asScala) { + val tp = new TopicPartition(topic.name, partition.partitionIndex) + if (!authorizedTopics.contains(topic.name)) + topicPartitionErrors(tp) = Errors.TOPIC_AUTHORIZATION_FAILED + else if (!metadataCache.contains(tp)) + topicPartitionErrors(tp) = Errors.UNKNOWN_TOPIC_OR_PARTITION + else + topicPartitions += tp + } + } + + val (groupError, authorizedTopicPartitionsErrors) = groupCoordinator.handleDeleteOffsets( + groupId, topicPartitions, requestLocal) + + topicPartitionErrors ++= authorizedTopicPartitionsErrors + + requestHelper.sendResponseMaybeThrottle(request, requestThrottleMs => { + if (groupError != Errors.NONE) + offsetDeleteRequest.getErrorResponse(requestThrottleMs, groupError) + else { + val topics = new OffsetDeleteResponseData.OffsetDeleteResponseTopicCollection + topicPartitionErrors.groupBy(_._1.topic).forKeyValue { (topic, topicPartitions) => + val partitions = new OffsetDeleteResponseData.OffsetDeleteResponsePartitionCollection + topicPartitions.forKeyValue { (topicPartition, error) => + partitions.add( + new OffsetDeleteResponseData.OffsetDeleteResponsePartition() + .setPartitionIndex(topicPartition.partition) + .setErrorCode(error.code) + ) + } + topics.add(new OffsetDeleteResponseData.OffsetDeleteResponseTopic() + .setName(topic) + .setPartitions(partitions)) + } + + new OffsetDeleteResponse(new 
OffsetDeleteResponseData() + .setTopics(topics) + .setThrottleTimeMs(requestThrottleMs)) + } + }) + } else { + requestHelper.sendResponseMaybeThrottle(request, requestThrottleMs => + offsetDeleteRequest.getErrorResponse(requestThrottleMs, Errors.GROUP_AUTHORIZATION_FAILED)) + } + } + + def handleDescribeClientQuotasRequest(request: RequestChannel.Request): Unit = { + val describeClientQuotasRequest = request.body[DescribeClientQuotasRequest] + + if (!authHelper.authorize(request.context, DESCRIBE_CONFIGS, CLUSTER, CLUSTER_NAME)) { + requestHelper.sendResponseMaybeThrottle(request, requestThrottleMs => + describeClientQuotasRequest.getErrorResponse(requestThrottleMs, Errors.CLUSTER_AUTHORIZATION_FAILED.exception)) + } else { + metadataSupport match { + case ZkSupport(adminManager, controller, zkClient, forwardingManager, metadataCache) => + val result = adminManager.describeClientQuotas(describeClientQuotasRequest.filter) + + val entriesData = result.iterator.map { case (quotaEntity, quotaValues) => + val entityData = quotaEntity.entries.asScala.iterator.map { case (entityType, entityName) => + new DescribeClientQuotasResponseData.EntityData() + .setEntityType(entityType) + .setEntityName(entityName) + }.toBuffer + + val valueData = quotaValues.iterator.map { case (key, value) => + new DescribeClientQuotasResponseData.ValueData() + .setKey(key) + .setValue(value) + }.toBuffer + + new DescribeClientQuotasResponseData.EntryData() + .setEntity(entityData.asJava) + .setValues(valueData.asJava) + }.toBuffer + + requestHelper.sendResponseMaybeThrottle(request, requestThrottleMs => + new DescribeClientQuotasResponse(new DescribeClientQuotasResponseData() + .setThrottleTimeMs(requestThrottleMs) + .setEntries(entriesData.asJava))) + case RaftSupport(_, metadataCache) => + val result = metadataCache.describeClientQuotas(describeClientQuotasRequest.data()) + requestHelper.sendResponseMaybeThrottle(request, requestThrottleMs => { + result.setThrottleTimeMs(requestThrottleMs) + new DescribeClientQuotasResponse(result) + }) + } + } + } + + def handleAlterClientQuotasRequest(request: RequestChannel.Request): Unit = { + val zkSupport = metadataSupport.requireZkOrThrow(KafkaApis.shouldAlwaysForward(request)) + val alterClientQuotasRequest = request.body[AlterClientQuotasRequest] + + if (authHelper.authorize(request.context, ALTER_CONFIGS, CLUSTER, CLUSTER_NAME)) { + val result = zkSupport.adminManager.alterClientQuotas(alterClientQuotasRequest.entries.asScala, + alterClientQuotasRequest.validateOnly) + + val entriesData = result.iterator.map { case (quotaEntity, apiError) => + val entityData = quotaEntity.entries.asScala.iterator.map { case (key, value) => + new AlterClientQuotasResponseData.EntityData() + .setEntityType(key) + .setEntityName(value) + }.toBuffer + + new AlterClientQuotasResponseData.EntryData() + .setErrorCode(apiError.error.code) + .setErrorMessage(apiError.message) + .setEntity(entityData.asJava) + }.toBuffer + + requestHelper.sendResponseMaybeThrottle(request, requestThrottleMs => + new AlterClientQuotasResponse(new AlterClientQuotasResponseData() + .setThrottleTimeMs(requestThrottleMs) + .setEntries(entriesData.asJava))) + } else { + requestHelper.sendResponseMaybeThrottle(request, requestThrottleMs => + alterClientQuotasRequest.getErrorResponse(requestThrottleMs, Errors.CLUSTER_AUTHORIZATION_FAILED.exception)) + } + } + + def handleDescribeUserScramCredentialsRequest(request: RequestChannel.Request): Unit = { + val zkSupport = 
metadataSupport.requireZkOrThrow(KafkaApis.notYetSupported(request)) + val describeUserScramCredentialsRequest = request.body[DescribeUserScramCredentialsRequest] + + if (authHelper.authorize(request.context, DESCRIBE, CLUSTER, CLUSTER_NAME)) { + val result = zkSupport.adminManager.describeUserScramCredentials( + Option(describeUserScramCredentialsRequest.data.users).map(_.asScala.map(_.name).toList)) + requestHelper.sendResponseMaybeThrottle(request, requestThrottleMs => + new DescribeUserScramCredentialsResponse(result.setThrottleTimeMs(requestThrottleMs))) + } else { + requestHelper.sendResponseMaybeThrottle(request, requestThrottleMs => + describeUserScramCredentialsRequest.getErrorResponse(requestThrottleMs, Errors.CLUSTER_AUTHORIZATION_FAILED.exception)) + } + } + + def handleAlterUserScramCredentialsRequest(request: RequestChannel.Request): Unit = { + val zkSupport = metadataSupport.requireZkOrThrow(KafkaApis.shouldAlwaysForward(request)) + val alterUserScramCredentialsRequest = request.body[AlterUserScramCredentialsRequest] + + if (!zkSupport.controller.isActive) { + requestHelper.sendResponseMaybeThrottle(request, requestThrottleMs => + alterUserScramCredentialsRequest.getErrorResponse(requestThrottleMs, Errors.NOT_CONTROLLER.exception)) + } else if (authHelper.authorize(request.context, ALTER, CLUSTER, CLUSTER_NAME)) { + val result = zkSupport.adminManager.alterUserScramCredentials( + alterUserScramCredentialsRequest.data.upsertions().asScala, alterUserScramCredentialsRequest.data.deletions().asScala) + requestHelper.sendResponseMaybeThrottle(request, requestThrottleMs => + new AlterUserScramCredentialsResponse(result.setThrottleTimeMs(requestThrottleMs))) + } else { + requestHelper.sendResponseMaybeThrottle(request, requestThrottleMs => + alterUserScramCredentialsRequest.getErrorResponse(requestThrottleMs, Errors.CLUSTER_AUTHORIZATION_FAILED.exception)) + } + } + + def handleAlterIsrRequest(request: RequestChannel.Request): Unit = { + val zkSupport = metadataSupport.requireZkOrThrow(KafkaApis.shouldNeverReceive(request)) + val alterIsrRequest = request.body[AlterIsrRequest] + authHelper.authorizeClusterOperation(request, CLUSTER_ACTION) + + if (!zkSupport.controller.isActive) + requestHelper.sendResponseExemptThrottle(request, alterIsrRequest.getErrorResponse( + AbstractResponse.DEFAULT_THROTTLE_TIME, Errors.NOT_CONTROLLER.exception)) + else + zkSupport.controller.alterIsrs(alterIsrRequest.data, alterIsrResp => + requestHelper.sendResponseExemptThrottle(request, new AlterIsrResponse(alterIsrResp)) + ) + } + + def handleUpdateFeatures(request: RequestChannel.Request): Unit = { + val zkSupport = metadataSupport.requireZkOrThrow(KafkaApis.shouldAlwaysForward(request)) + val updateFeaturesRequest = request.body[UpdateFeaturesRequest] + + def sendResponseCallback(errors: Either[ApiError, Map[String, ApiError]]): Unit = { + def createResponse(throttleTimeMs: Int): UpdateFeaturesResponse = { + errors match { + case Left(topLevelError) => + UpdateFeaturesResponse.createWithErrors( + topLevelError, + Collections.emptyMap(), + throttleTimeMs) + case Right(featureUpdateErrors) => + UpdateFeaturesResponse.createWithErrors( + ApiError.NONE, + featureUpdateErrors.asJava, + throttleTimeMs) + } + } + requestHelper.sendResponseMaybeThrottle(request, requestThrottleMs => createResponse(requestThrottleMs)) + } + + if (!authHelper.authorize(request.context, ALTER, CLUSTER, CLUSTER_NAME)) { + sendResponseCallback(Left(new ApiError(Errors.CLUSTER_AUTHORIZATION_FAILED))) + } else if 
(!zkSupport.controller.isActive) { + sendResponseCallback(Left(new ApiError(Errors.NOT_CONTROLLER))) + } else if (!config.isFeatureVersioningSupported) { + sendResponseCallback(Left(new ApiError(Errors.INVALID_REQUEST, "Feature versioning system is disabled."))) + } else { + zkSupport.controller.updateFeatures(updateFeaturesRequest, sendResponseCallback) + } + } + + def handleDescribeCluster(request: RequestChannel.Request): Unit = { + val describeClusterRequest = request.body[DescribeClusterRequest] + + var clusterAuthorizedOperations = Int.MinValue // Default value in the schema + // get cluster authorized operations + if (describeClusterRequest.data.includeClusterAuthorizedOperations) { + if (authHelper.authorize(request.context, DESCRIBE, CLUSTER, CLUSTER_NAME)) + clusterAuthorizedOperations = authHelper.authorizedOperations(request, Resource.CLUSTER) + else + clusterAuthorizedOperations = 0 + } + + val brokers = metadataCache.getAliveBrokerNodes(request.context.listenerName) + val controllerId = metadataSupport.controllerId.getOrElse(MetadataResponse.NO_CONTROLLER_ID) + + requestHelper.sendResponseMaybeThrottle(request, requestThrottleMs => { + val data = new DescribeClusterResponseData() + .setThrottleTimeMs(requestThrottleMs) + .setClusterId(clusterId) + .setControllerId(controllerId) + .setClusterAuthorizedOperations(clusterAuthorizedOperations); + + + brokers.foreach { broker => + data.brokers.add(new DescribeClusterResponseData.DescribeClusterBroker() + .setBrokerId(broker.id) + .setHost(broker.host) + .setPort(broker.port) + .setRack(broker.rack)) + } + + new DescribeClusterResponse(data) + }) + } + + def handleEnvelope(request: RequestChannel.Request, requestLocal: RequestLocal): Unit = { + val zkSupport = metadataSupport.requireZkOrThrow(KafkaApis.shouldNeverReceive(request)) + + // If forwarding is not yet enabled or this request has been received on an invalid endpoint, + // then we treat the request as unparsable and close the connection. 
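+    // Beyond that, the wrapped request is only unwrapped when it arrives on the privileged (inter-broker)
+    // listener from a principal with CLUSTER_ACTION permission and this broker is the active controller;
+    // every other case is rejected below.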
+ if (!isForwardingEnabled(request)) { + info(s"Closing connection ${request.context.connectionId} because it sent an `Envelope` " + + "request even though forwarding has not been enabled") + requestChannel.closeConnection(request, Collections.emptyMap()) + return + } else if (!request.context.fromPrivilegedListener) { + info(s"Closing connection ${request.context.connectionId} from listener ${request.context.listenerName} " + + s"because it sent an `Envelope` request, which is only accepted on the inter-broker listener " + + s"${config.interBrokerListenerName}.") + requestChannel.closeConnection(request, Collections.emptyMap()) + return + } else if (!authHelper.authorize(request.context, CLUSTER_ACTION, CLUSTER, CLUSTER_NAME)) { + requestHelper.sendErrorResponseMaybeThrottle(request, new ClusterAuthorizationException( + s"Principal ${request.context.principal} does not have required CLUSTER_ACTION for envelope")) + return + } else if (!zkSupport.controller.isActive) { + requestHelper.sendErrorResponseMaybeThrottle(request, new NotControllerException( + s"Broker $brokerId is not the active controller")) + return + } + + EnvelopeUtils.handleEnvelopeRequest(request, requestChannel.metrics, handle(_, requestLocal)) + } + + def handleDescribeProducersRequest(request: RequestChannel.Request): Unit = { + val describeProducersRequest = request.body[DescribeProducersRequest] + + def partitionError( + topicPartition: TopicPartition, + apiError: ApiError + ): DescribeProducersResponseData.PartitionResponse = { + new DescribeProducersResponseData.PartitionResponse() + .setPartitionIndex(topicPartition.partition) + .setErrorCode(apiError.error.code) + .setErrorMessage(apiError.message) + } + + val response = new DescribeProducersResponseData() + describeProducersRequest.data.topics.forEach { topicRequest => + val topicResponse = new DescribeProducersResponseData.TopicResponse() + .setName(topicRequest.name) + + val invalidTopicError = checkValidTopic(topicRequest.name) + + val topicError = invalidTopicError.orElse { + if (!authHelper.authorize(request.context, READ, TOPIC, topicRequest.name)) { + Some(new ApiError(Errors.TOPIC_AUTHORIZATION_FAILED)) + } else if (!metadataCache.contains(topicRequest.name)) + Some(new ApiError(Errors.UNKNOWN_TOPIC_OR_PARTITION)) + else { + None + } + } + + topicRequest.partitionIndexes.forEach { partitionId => + val topicPartition = new TopicPartition(topicRequest.name, partitionId) + val partitionResponse = topicError match { + case Some(error) => partitionError(topicPartition, error) + case None => replicaManager.activeProducerState(topicPartition) + } + topicResponse.partitions.add(partitionResponse) + } + + response.topics.add(topicResponse) + } + + requestHelper.sendResponseMaybeThrottle(request, requestThrottleMs => + new DescribeProducersResponse(response.setThrottleTimeMs(requestThrottleMs))) + } + + private def checkValidTopic(topic: String): Option[ApiError] = { + try { + Topic.validate(topic) + None + } catch { + case e: Throwable => Some(ApiError.fromThrowable(e)) + } + } + + def handleDescribeTransactionsRequest(request: RequestChannel.Request): Unit = { + val describeTransactionsRequest = request.body[DescribeTransactionsRequest] + val response = new DescribeTransactionsResponseData() + + describeTransactionsRequest.data.transactionalIds.forEach { transactionalId => + val transactionState = if (!authHelper.authorize(request.context, DESCRIBE, TRANSACTIONAL_ID, transactionalId)) { + new DescribeTransactionsResponseData.TransactionState() + 
.setTransactionalId(transactionalId) + .setErrorCode(Errors.TRANSACTIONAL_ID_AUTHORIZATION_FAILED.code) + } else { + txnCoordinator.handleDescribeTransactions(transactionalId) + } + + // Include only partitions which the principal is authorized to describe + val topicIter = transactionState.topics.iterator() + while (topicIter.hasNext) { + val topic = topicIter.next().topic + if (!authHelper.authorize(request.context, DESCRIBE, TOPIC, topic)) { + topicIter.remove() + } + } + response.transactionStates.add(transactionState) + } + + requestHelper.sendResponseMaybeThrottle(request, requestThrottleMs => + new DescribeTransactionsResponse(response.setThrottleTimeMs(requestThrottleMs))) + } + + def handleListTransactionsRequest(request: RequestChannel.Request): Unit = { + val listTransactionsRequest = request.body[ListTransactionsRequest] + val filteredProducerIds = listTransactionsRequest.data.producerIdFilters.asScala.map(Long.unbox).toSet + val filteredStates = listTransactionsRequest.data.stateFilters.asScala.toSet + val response = txnCoordinator.handleListTransactions(filteredProducerIds, filteredStates) + + // The response should contain only transactionalIds that the principal + // has `Describe` permission to access. + val transactionStateIter = response.transactionStates.iterator() + while (transactionStateIter.hasNext) { + val transactionState = transactionStateIter.next() + if (!authHelper.authorize(request.context, DESCRIBE, TRANSACTIONAL_ID, transactionState.transactionalId)) { + transactionStateIter.remove() + } + } + + requestHelper.sendResponseMaybeThrottle(request, requestThrottleMs => + new ListTransactionsResponse(response.setThrottleTimeMs(requestThrottleMs))) + } + + def handleAllocateProducerIdsRequest(request: RequestChannel.Request): Unit = { + val zkSupport = metadataSupport.requireZkOrThrow(KafkaApis.shouldNeverReceive(request)) + authHelper.authorizeClusterOperation(request, CLUSTER_ACTION) + + val allocateProducerIdsRequest = request.body[AllocateProducerIdsRequest] + + if (!zkSupport.controller.isActive) + requestHelper.sendResponseMaybeThrottle(request, throttleTimeMs => + allocateProducerIdsRequest.getErrorResponse(throttleTimeMs, Errors.NOT_CONTROLLER.exception)) + else + zkSupport.controller.allocateProducerIds(allocateProducerIdsRequest.data, producerIdsResponse => + requestHelper.sendResponseMaybeThrottle(request, throttleTimeMs => + new AllocateProducerIdsResponse(producerIdsResponse.setThrottleTimeMs(throttleTimeMs))) + ) + } + + private def updateRecordConversionStats(request: RequestChannel.Request, + tp: TopicPartition, + conversionStats: RecordConversionStats): Unit = { + val conversionCount = conversionStats.numRecordsConverted + if (conversionCount > 0) { + request.header.apiKey match { + case ApiKeys.PRODUCE => + brokerTopicStats.topicStats(tp.topic).produceMessageConversionsRate.mark(conversionCount) + brokerTopicStats.allTopicsStats.produceMessageConversionsRate.mark(conversionCount) + case ApiKeys.FETCH => + brokerTopicStats.topicStats(tp.topic).fetchMessageConversionsRate.mark(conversionCount) + brokerTopicStats.allTopicsStats.fetchMessageConversionsRate.mark(conversionCount) + case _ => + throw new IllegalStateException("Message conversion info is recorded only for Produce/Fetch requests") + } + request.messageConversionsTimeNanos = conversionStats.conversionTimeNanos + } + request.temporaryMemoryBytes = conversionStats.temporaryMemoryBytes + } + + private def isBrokerEpochStale(zkSupport: ZkSupport, brokerEpochInRequest: Long): Boolean = { + // 
Broker epoch in LeaderAndIsr/UpdateMetadata/StopReplica request is unknown + // if the controller hasn't been upgraded to use KIP-380 + if (brokerEpochInRequest == AbstractControlRequest.UNKNOWN_BROKER_EPOCH) false + else { + // brokerEpochInRequest > controller.brokerEpoch is possible in rare scenarios where the controller gets notified + // about the new broker epoch and sends a control request with this epoch before the broker learns about it + brokerEpochInRequest < zkSupport.controller.brokerEpoch + } + } +} + +object KafkaApis { + // Traffic from both in-sync and out of sync replicas are accounted for in replication quota to ensure total replication + // traffic doesn't exceed quota. + // TODO: remove resolvedResponseData method when sizeOf can take a data object. + private[server] def sizeOfThrottledPartitions(versionId: Short, + unconvertedResponse: FetchResponse, + quota: ReplicationQuotaManager): Int = { + val responseData = new util.LinkedHashMap[TopicIdPartition, FetchResponseData.PartitionData] + unconvertedResponse.data.responses().forEach(topicResponse => + topicResponse.partitions().forEach(partition => + responseData.put(new TopicIdPartition(topicResponse.topicId, new TopicPartition(topicResponse.topic(), partition.partitionIndex)), partition))) + FetchResponse.sizeOf(versionId, responseData.entrySet + .iterator.asScala.filter(element => element.getKey.topicPartition.topic != null && quota.isThrottled(element.getKey.topicPartition)).asJava) + } + + // visible for testing + private[server] def shouldNeverReceive(request: RequestChannel.Request): Exception = { + new UnsupportedVersionException(s"Should never receive when using a Raft-based metadata quorum: ${request.header.apiKey()}") + } + + // visible for testing + private[server] def shouldAlwaysForward(request: RequestChannel.Request): Exception = { + new UnsupportedVersionException(s"Should always be forwarded to the Active Controller when using a Raft-based metadata quorum: ${request.header.apiKey}") + } + + private def unsupported(text: String): Exception = { + new UnsupportedVersionException(s"Unsupported when using a Raft-based metadata quorum: $text") + } + + private def notYetSupported(request: RequestChannel.Request): Exception = { + notYetSupported(request.header.apiKey().toString) + } + + private def notYetSupported(text: String): Exception = { + new UnsupportedVersionException(s"Not yet supported when using a Raft-based metadata quorum: $text") + } +} diff --git a/core/src/main/scala/kafka/server/KafkaBroker.scala b/core/src/main/scala/kafka/server/KafkaBroker.scala new file mode 100644 index 0000000..f4c6abc --- /dev/null +++ b/core/src/main/scala/kafka/server/KafkaBroker.scala @@ -0,0 +1,108 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package kafka.server + +import com.yammer.metrics.core.MetricName +import kafka.coordinator.group.GroupCoordinator +import kafka.log.LogManager +import kafka.metrics.{KafkaMetricsGroup, KafkaYammerMetrics, LinuxIoMetricsCollector} +import kafka.network.SocketServer +import kafka.security.CredentialProvider +import kafka.utils.KafkaScheduler +import org.apache.kafka.common.ClusterResource +import org.apache.kafka.common.internals.ClusterResourceListeners +import org.apache.kafka.common.metrics.{Metrics, MetricsReporter} +import org.apache.kafka.common.network.ListenerName +import org.apache.kafka.common.utils.Time +import org.apache.kafka.metadata.BrokerState +import org.apache.kafka.server.authorizer.Authorizer + +import scala.collection.Seq +import scala.jdk.CollectionConverters._ + +object KafkaBroker { + //properties for MetricsContext + val MetricsTypeName: String = "KafkaServer" + + private[server] def notifyClusterListeners(clusterId: String, + clusterListeners: Seq[AnyRef]): Unit = { + val clusterResourceListeners = new ClusterResourceListeners + clusterResourceListeners.maybeAddAll(clusterListeners.asJava) + clusterResourceListeners.onUpdate(new ClusterResource(clusterId)) + } + + private[server] def notifyMetricsReporters(clusterId: String, + config: KafkaConfig, + metricsReporters: Seq[AnyRef]): Unit = { + val metricsContext = Server.createKafkaMetricsContext(config, clusterId) + metricsReporters.foreach { + case x: MetricsReporter => x.contextChange(metricsContext) + case _ => //do nothing + } + } + + /** + * The log message that we print when the broker has been successfully started. + * The ducktape system tests look for a line matching the regex 'Kafka\s*Server.*started' + * to know when the broker is started, so it is best not to change this message -- but if + * you do change it, be sure to make it match that regex or the system tests will fail. 
+ */ + val STARTED_MESSAGE = "Kafka Server started" +} + +trait KafkaBroker extends KafkaMetricsGroup { + + def authorizer: Option[Authorizer] + def brokerState: BrokerState + def clusterId: String + def config: KafkaConfig + def dataPlaneRequestHandlerPool: KafkaRequestHandlerPool + def dataPlaneRequestProcessor: KafkaApis + def kafkaScheduler: KafkaScheduler + def kafkaYammerMetrics: KafkaYammerMetrics + def logManager: LogManager + def metrics: Metrics + def quotaManagers: QuotaFactory.QuotaManagers + def replicaManager: ReplicaManager + def socketServer: SocketServer + def metadataCache: MetadataCache + def groupCoordinator: GroupCoordinator + def boundPort(listenerName: ListenerName): Int + def startup(): Unit + def awaitShutdown(): Unit + def shutdown(): Unit + def brokerTopicStats: BrokerTopicStats + def credentialProvider: CredentialProvider + + // For backwards compatibility, we need to keep older metrics tied + // to their original name when this class was named `KafkaServer` + override def metricName(name: String, metricTags: scala.collection.Map[String, String]): MetricName = { + explicitMetricName(Server.MetricsPrefix, KafkaBroker.MetricsTypeName, name, metricTags) + } + + newGauge("BrokerState", () => brokerState.value) + newGauge("ClusterId", () => clusterId) + newGauge("yammer-metrics-count", () => KafkaYammerMetrics.defaultRegistry.allMetrics.size) + + private val linuxIoMetricsCollector = new LinuxIoMetricsCollector("/proc", Time.SYSTEM, logger.underlying) + + if (linuxIoMetricsCollector.usable()) { + newGauge("linux-disk-read-bytes", () => linuxIoMetricsCollector.readBytes()) + newGauge("linux-disk-write-bytes", () => linuxIoMetricsCollector.writeBytes()) + } +} diff --git a/core/src/main/scala/kafka/server/KafkaConfig.scala b/core/src/main/scala/kafka/server/KafkaConfig.scala new file mode 100755 index 0000000..15c5d09 --- /dev/null +++ b/core/src/main/scala/kafka/server/KafkaConfig.scala @@ -0,0 +1,2191 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package kafka.server + +import java.util +import java.util.{Collections, Locale, Properties} +import kafka.api.{ApiVersion, ApiVersionValidator, KAFKA_0_10_0_IV1, KAFKA_2_1_IV0, KAFKA_2_7_IV0, KAFKA_2_8_IV0, KAFKA_3_0_IV1} +import kafka.cluster.EndPoint +import kafka.coordinator.group.OffsetConfig +import kafka.coordinator.transaction.{TransactionLog, TransactionStateManager} +import kafka.log.LogConfig +import kafka.log.LogConfig.MessageFormatVersion +import kafka.message.{BrokerCompressionCodec, CompressionCodec, ZStdCompressionCodec} +import kafka.security.authorizer.AuthorizerUtils +import kafka.server.KafkaConfig.{ControllerListenerNamesProp, ListenerSecurityProtocolMapProp} +import kafka.server.KafkaRaftServer.{BrokerRole, ControllerRole, ProcessRole} +import kafka.utils.CoreUtils.parseCsvList +import kafka.utils.{CoreUtils, Logging} +import kafka.utils.Implicits._ +import org.apache.kafka.clients.CommonClientConfigs +import org.apache.kafka.common.Reconfigurable +import org.apache.kafka.common.config.{AbstractConfig, ConfigDef, ConfigException, ConfigResource, SaslConfigs, SecurityConfig, SslClientAuth, SslConfigs, TopicConfig} +import org.apache.kafka.common.config.ConfigDef.{ConfigKey, ValidList} +import org.apache.kafka.common.config.internals.BrokerSecurityConfigs +import org.apache.kafka.common.config.types.Password +import org.apache.kafka.common.metrics.Sensor +import org.apache.kafka.common.network.ListenerName +import org.apache.kafka.common.record.{LegacyRecord, Records, TimestampType} +import org.apache.kafka.common.security.auth.KafkaPrincipalSerde +import org.apache.kafka.common.security.auth.SecurityProtocol +import org.apache.kafka.common.security.authenticator.DefaultKafkaPrincipalBuilder +import org.apache.kafka.common.utils.Utils +import org.apache.kafka.raft.RaftConfig +import org.apache.kafka.server.authorizer.Authorizer +import org.apache.kafka.server.log.remote.storage.RemoteLogManagerConfig +import org.apache.zookeeper.client.ZKClientConfig + +import scala.annotation.nowarn +import scala.jdk.CollectionConverters._ +import scala.collection.{Map, Seq} + +object Defaults { + /** ********* Zookeeper Configuration ***********/ + val ZkSessionTimeoutMs = 18000 + val ZkSyncTimeMs = 2000 + val ZkEnableSecureAcls = false + val ZkMaxInFlightRequests = 10 + val ZkSslClientEnable = false + val ZkSslProtocol = "TLSv1.2" + val ZkSslEndpointIdentificationAlgorithm = "HTTPS" + val ZkSslCrlEnable = false + val ZkSslOcspEnable = false + + /** ********* General Configuration ***********/ + val BrokerIdGenerationEnable = true + val MaxReservedBrokerId = 1000 + val BrokerId = -1 + val MessageMaxBytes = 1024 * 1024 + Records.LOG_OVERHEAD + val NumNetworkThreads = 3 + val NumIoThreads = 8 + val BackgroundThreads = 10 + val QueuedMaxRequests = 500 + val QueuedMaxRequestBytes = -1 + val InitialBrokerRegistrationTimeoutMs = 60000 + val BrokerHeartbeatIntervalMs = 2000 + val BrokerSessionTimeoutMs = 9000 + val MetadataSnapshotMaxNewRecordBytes = 20 * 1024 * 1024 + + /** KRaft mode configs */ + val EmptyNodeId: Int = -1 + + /************* Authorizer Configuration ***********/ + val AuthorizerClassName = "" + + /** ********* Socket Server Configuration ***********/ + val Listeners = "PLAINTEXT://:9092" + val ListenerSecurityProtocolMap: String = EndPoint.DefaultSecurityProtocolMap.map { case (listenerName, securityProtocol) => + s"${listenerName.value}:${securityProtocol.name}" + }.mkString(",") + + val SocketSendBufferBytes: Int = 100 * 1024 + val SocketReceiveBufferBytes: Int = 
100 * 1024 + val SocketRequestMaxBytes: Int = 100 * 1024 * 1024 + val MaxConnectionsPerIp: Int = Int.MaxValue + val MaxConnectionsPerIpOverrides: String = "" + val MaxConnections: Int = Int.MaxValue + val MaxConnectionCreationRate: Int = Int.MaxValue + val ConnectionsMaxIdleMs = 10 * 60 * 1000L + val RequestTimeoutMs = 30000 + val ConnectionSetupTimeoutMs = CommonClientConfigs.DEFAULT_SOCKET_CONNECTION_SETUP_TIMEOUT_MS + val ConnectionSetupTimeoutMaxMs = CommonClientConfigs.DEFAULT_SOCKET_CONNECTION_SETUP_TIMEOUT_MAX_MS + val FailedAuthenticationDelayMs = 100 + + /** ********* Log Configuration ***********/ + val NumPartitions = 1 + val LogDir = "/tmp/kafka-logs" + val LogSegmentBytes = 1 * 1024 * 1024 * 1024 + val LogRollHours = 24 * 7 + val LogRollJitterHours = 0 + val LogRetentionHours = 24 * 7 + + val LogRetentionBytes = -1L + val LogCleanupIntervalMs = 5 * 60 * 1000L + val Delete = "delete" + val Compact = "compact" + val LogCleanupPolicy = Delete + val LogCleanerThreads = 1 + val LogCleanerIoMaxBytesPerSecond = Double.MaxValue + val LogCleanerDedupeBufferSize = 128 * 1024 * 1024L + val LogCleanerIoBufferSize = 512 * 1024 + val LogCleanerDedupeBufferLoadFactor = 0.9d + val LogCleanerBackoffMs = 15 * 1000 + val LogCleanerMinCleanRatio = 0.5d + val LogCleanerEnable = true + val LogCleanerDeleteRetentionMs = 24 * 60 * 60 * 1000L + val LogCleanerMinCompactionLagMs = 0L + val LogCleanerMaxCompactionLagMs = Long.MaxValue + val LogIndexSizeMaxBytes = 10 * 1024 * 1024 + val LogIndexIntervalBytes = 4096 + val LogFlushIntervalMessages = Long.MaxValue + val LogDeleteDelayMs = 60000 + val LogFlushSchedulerIntervalMs = Long.MaxValue + val LogFlushOffsetCheckpointIntervalMs = 60000 + val LogFlushStartOffsetCheckpointIntervalMs = 60000 + val LogPreAllocateEnable = false + + /* See `TopicConfig.MESSAGE_FORMAT_VERSION_CONFIG` for details */ + @deprecated("3.0") + val LogMessageFormatVersion = KAFKA_3_0_IV1.version + + val LogMessageTimestampType = "CreateTime" + val LogMessageTimestampDifferenceMaxMs = Long.MaxValue + val NumRecoveryThreadsPerDataDir = 1 + val AutoCreateTopicsEnable = true + val MinInSyncReplicas = 1 + val MessageDownConversionEnable = true + + /** ********* Replication configuration ***********/ + val ControllerSocketTimeoutMs = RequestTimeoutMs + val ControllerMessageQueueSize = Int.MaxValue + val DefaultReplicationFactor = 1 + val ReplicaLagTimeMaxMs = 30000L + val ReplicaSocketTimeoutMs = 30 * 1000 + val ReplicaSocketReceiveBufferBytes = 64 * 1024 + val ReplicaFetchMaxBytes = 1024 * 1024 + val ReplicaFetchWaitMaxMs = 500 + val ReplicaFetchMinBytes = 1 + val ReplicaFetchResponseMaxBytes = 10 * 1024 * 1024 + val NumReplicaFetchers = 1 + val ReplicaFetchBackoffMs = 1000 + val ReplicaHighWatermarkCheckpointIntervalMs = 5000L + val FetchPurgatoryPurgeIntervalRequests = 1000 + val ProducerPurgatoryPurgeIntervalRequests = 1000 + val DeleteRecordsPurgatoryPurgeIntervalRequests = 1 + val AutoLeaderRebalanceEnable = true + val LeaderImbalancePerBrokerPercentage = 10 + val LeaderImbalanceCheckIntervalSeconds = 300 + val UncleanLeaderElectionEnable = false + val InterBrokerSecurityProtocol = SecurityProtocol.PLAINTEXT.toString + val InterBrokerProtocolVersion = ApiVersion.latestVersion.toString + + /** ********* Controlled shutdown configuration ***********/ + val ControlledShutdownMaxRetries = 3 + val ControlledShutdownRetryBackoffMs = 5000 + val ControlledShutdownEnable = true + + /** ********* Group coordinator configuration ***********/ + val GroupMinSessionTimeoutMs = 6000 + val 
GroupMaxSessionTimeoutMs = 1800000 + val GroupInitialRebalanceDelayMs = 3000 + val GroupMaxSize: Int = Int.MaxValue + + /** ********* Offset management configuration ***********/ + val OffsetMetadataMaxSize = OffsetConfig.DefaultMaxMetadataSize + val OffsetsLoadBufferSize = OffsetConfig.DefaultLoadBufferSize + val OffsetsTopicReplicationFactor = OffsetConfig.DefaultOffsetsTopicReplicationFactor + val OffsetsTopicPartitions: Int = OffsetConfig.DefaultOffsetsTopicNumPartitions + val OffsetsTopicSegmentBytes: Int = OffsetConfig.DefaultOffsetsTopicSegmentBytes + val OffsetsTopicCompressionCodec: Int = OffsetConfig.DefaultOffsetsTopicCompressionCodec.codec + val OffsetsRetentionMinutes: Int = 7 * 24 * 60 + val OffsetsRetentionCheckIntervalMs: Long = OffsetConfig.DefaultOffsetsRetentionCheckIntervalMs + val OffsetCommitTimeoutMs = OffsetConfig.DefaultOffsetCommitTimeoutMs + val OffsetCommitRequiredAcks = OffsetConfig.DefaultOffsetCommitRequiredAcks + + /** ********* Transaction management configuration ***********/ + val TransactionalIdExpirationMs = TransactionStateManager.DefaultTransactionalIdExpirationMs + val TransactionsMaxTimeoutMs = TransactionStateManager.DefaultTransactionsMaxTimeoutMs + val TransactionsTopicMinISR = TransactionLog.DefaultMinInSyncReplicas + val TransactionsLoadBufferSize = TransactionLog.DefaultLoadBufferSize + val TransactionsTopicReplicationFactor = TransactionLog.DefaultReplicationFactor + val TransactionsTopicPartitions = TransactionLog.DefaultNumPartitions + val TransactionsTopicSegmentBytes = TransactionLog.DefaultSegmentBytes + val TransactionsAbortTimedOutTransactionsCleanupIntervalMS = TransactionStateManager.DefaultAbortTimedOutTransactionsIntervalMs + val TransactionsRemoveExpiredTransactionsCleanupIntervalMS = TransactionStateManager.DefaultRemoveExpiredTransactionalIdsIntervalMs + + /** ********* Fetch Configuration **************/ + val MaxIncrementalFetchSessionCacheSlots = 1000 + val FetchMaxBytes = 55 * 1024 * 1024 + + /** ********* Quota Configuration ***********/ + val NumQuotaSamples: Int = ClientQuotaManagerConfig.DefaultNumQuotaSamples + val QuotaWindowSizeSeconds: Int = ClientQuotaManagerConfig.DefaultQuotaWindowSizeSeconds + val NumReplicationQuotaSamples: Int = ReplicationQuotaManagerConfig.DefaultNumQuotaSamples + val ReplicationQuotaWindowSizeSeconds: Int = ReplicationQuotaManagerConfig.DefaultQuotaWindowSizeSeconds + val NumAlterLogDirsReplicationQuotaSamples: Int = ReplicationQuotaManagerConfig.DefaultNumQuotaSamples + val AlterLogDirsReplicationQuotaWindowSizeSeconds: Int = ReplicationQuotaManagerConfig.DefaultQuotaWindowSizeSeconds + val NumControllerQuotaSamples: Int = ClientQuotaManagerConfig.DefaultNumQuotaSamples + val ControllerQuotaWindowSizeSeconds: Int = ClientQuotaManagerConfig.DefaultQuotaWindowSizeSeconds + + /** ********* Transaction Configuration ***********/ + val TransactionalIdExpirationMsDefault = 604800000 + + val DeleteTopicEnable = true + + val CompressionType = "producer" + + val MaxIdMapSnapshots = 2 + /** ********* Kafka Metrics Configuration ***********/ + val MetricNumSamples = 2 + val MetricSampleWindowMs = 30000 + val MetricReporterClasses = "" + val MetricRecordingLevel = Sensor.RecordingLevel.INFO.toString() + + + /** ********* Kafka Yammer Metrics Reporter Configuration ***********/ + val KafkaMetricReporterClasses = "" + val KafkaMetricsPollingIntervalSeconds = 10 + + /** ********* SSL configuration ***********/ + val SslProtocol = SslConfigs.DEFAULT_SSL_PROTOCOL + val SslEnabledProtocols = 
SslConfigs.DEFAULT_SSL_ENABLED_PROTOCOLS + val SslKeystoreType = SslConfigs.DEFAULT_SSL_KEYSTORE_TYPE + val SslTruststoreType = SslConfigs.DEFAULT_SSL_TRUSTSTORE_TYPE + val SslKeyManagerAlgorithm = SslConfigs.DEFAULT_SSL_KEYMANGER_ALGORITHM + val SslTrustManagerAlgorithm = SslConfigs.DEFAULT_SSL_TRUSTMANAGER_ALGORITHM + val SslEndpointIdentificationAlgorithm = SslConfigs.DEFAULT_SSL_ENDPOINT_IDENTIFICATION_ALGORITHM + val SslClientAuthentication = SslClientAuth.NONE.name().toLowerCase(Locale.ROOT) + val SslClientAuthenticationValidValues = SslClientAuth.VALUES.asScala.map(v => v.toString().toLowerCase(Locale.ROOT)).asJava.toArray(new Array[String](0)) + val SslPrincipalMappingRules = BrokerSecurityConfigs.DEFAULT_SSL_PRINCIPAL_MAPPING_RULES + + /** ********* General Security configuration ***********/ + val ConnectionsMaxReauthMsDefault = 0L + val DefaultPrincipalSerde = classOf[DefaultKafkaPrincipalBuilder] + + /** ********* Sasl configuration ***********/ + val SaslMechanismInterBrokerProtocol = SaslConfigs.DEFAULT_SASL_MECHANISM + val SaslEnabledMechanisms = BrokerSecurityConfigs.DEFAULT_SASL_ENABLED_MECHANISMS + val SaslKerberosKinitCmd = SaslConfigs.DEFAULT_KERBEROS_KINIT_CMD + val SaslKerberosTicketRenewWindowFactor = SaslConfigs.DEFAULT_KERBEROS_TICKET_RENEW_WINDOW_FACTOR + val SaslKerberosTicketRenewJitter = SaslConfigs.DEFAULT_KERBEROS_TICKET_RENEW_JITTER + val SaslKerberosMinTimeBeforeRelogin = SaslConfigs.DEFAULT_KERBEROS_MIN_TIME_BEFORE_RELOGIN + val SaslKerberosPrincipalToLocalRules = BrokerSecurityConfigs.DEFAULT_SASL_KERBEROS_PRINCIPAL_TO_LOCAL_RULES + val SaslLoginRefreshWindowFactor = SaslConfigs.DEFAULT_LOGIN_REFRESH_WINDOW_FACTOR + val SaslLoginRefreshWindowJitter = SaslConfigs.DEFAULT_LOGIN_REFRESH_WINDOW_JITTER + val SaslLoginRefreshMinPeriodSeconds = SaslConfigs.DEFAULT_LOGIN_REFRESH_MIN_PERIOD_SECONDS + val SaslLoginRefreshBufferSeconds = SaslConfigs.DEFAULT_LOGIN_REFRESH_BUFFER_SECONDS + val SaslLoginRetryBackoffMaxMs = SaslConfigs.DEFAULT_SASL_LOGIN_RETRY_BACKOFF_MAX_MS + val SaslLoginRetryBackoffMs = SaslConfigs.DEFAULT_SASL_LOGIN_RETRY_BACKOFF_MS + val SaslOAuthBearerScopeClaimName = SaslConfigs.DEFAULT_SASL_OAUTHBEARER_SCOPE_CLAIM_NAME + val SaslOAuthBearerSubClaimName = SaslConfigs.DEFAULT_SASL_OAUTHBEARER_SUB_CLAIM_NAME + val SaslOAuthBearerJwksEndpointRefreshMs = SaslConfigs.DEFAULT_SASL_OAUTHBEARER_JWKS_ENDPOINT_REFRESH_MS + val SaslOAuthBearerJwksEndpointRetryBackoffMaxMs = SaslConfigs.DEFAULT_SASL_OAUTHBEARER_JWKS_ENDPOINT_RETRY_BACKOFF_MAX_MS + val SaslOAuthBearerJwksEndpointRetryBackoffMs = SaslConfigs.DEFAULT_SASL_OAUTHBEARER_JWKS_ENDPOINT_RETRY_BACKOFF_MS + val SaslOAuthBearerClockSkewSeconds = SaslConfigs.DEFAULT_SASL_OAUTHBEARER_CLOCK_SKEW_SECONDS + + /** ********* Delegation Token configuration ***********/ + val DelegationTokenMaxLifeTimeMsDefault = 7 * 24 * 60 * 60 * 1000L + val DelegationTokenExpiryTimeMsDefault = 24 * 60 * 60 * 1000L + val DelegationTokenExpiryCheckIntervalMsDefault = 1 * 60 * 60 * 1000L + + /** ********* Password encryption configuration for dynamic configs *********/ + val PasswordEncoderCipherAlgorithm = "AES/CBC/PKCS5Padding" + val PasswordEncoderKeyLength = 128 + val PasswordEncoderIterations = 4096 + + /** ********* Raft Quorum Configuration *********/ + val QuorumVoters = RaftConfig.DEFAULT_QUORUM_VOTERS + val QuorumElectionTimeoutMs = RaftConfig.DEFAULT_QUORUM_ELECTION_TIMEOUT_MS + val QuorumFetchTimeoutMs = RaftConfig.DEFAULT_QUORUM_FETCH_TIMEOUT_MS + val QuorumElectionBackoffMs = 
RaftConfig.DEFAULT_QUORUM_ELECTION_BACKOFF_MAX_MS + val QuorumLingerMs = RaftConfig.DEFAULT_QUORUM_LINGER_MS + val QuorumRequestTimeoutMs = RaftConfig.DEFAULT_QUORUM_REQUEST_TIMEOUT_MS + val QuorumRetryBackoffMs = RaftConfig.DEFAULT_QUORUM_RETRY_BACKOFF_MS +} + +object KafkaConfig { + + private val LogConfigPrefix = "log." + + def main(args: Array[String]): Unit = { + System.out.println(configDef.toHtml(4, (config: String) => "brokerconfigs_" + config, + DynamicBrokerConfig.dynamicConfigUpdateModes)) + } + + /** ********* Zookeeper Configuration ***********/ + val ZkConnectProp = "zookeeper.connect" + val ZkSessionTimeoutMsProp = "zookeeper.session.timeout.ms" + val ZkConnectionTimeoutMsProp = "zookeeper.connection.timeout.ms" + val ZkSyncTimeMsProp = "zookeeper.sync.time.ms" + val ZkEnableSecureAclsProp = "zookeeper.set.acl" + val ZkMaxInFlightRequestsProp = "zookeeper.max.in.flight.requests" + val ZkSslClientEnableProp = "zookeeper.ssl.client.enable" + val ZkClientCnxnSocketProp = "zookeeper.clientCnxnSocket" + val ZkSslKeyStoreLocationProp = "zookeeper.ssl.keystore.location" + val ZkSslKeyStorePasswordProp = "zookeeper.ssl.keystore.password" + val ZkSslKeyStoreTypeProp = "zookeeper.ssl.keystore.type" + val ZkSslTrustStoreLocationProp = "zookeeper.ssl.truststore.location" + val ZkSslTrustStorePasswordProp = "zookeeper.ssl.truststore.password" + val ZkSslTrustStoreTypeProp = "zookeeper.ssl.truststore.type" + val ZkSslProtocolProp = "zookeeper.ssl.protocol" + val ZkSslEnabledProtocolsProp = "zookeeper.ssl.enabled.protocols" + val ZkSslCipherSuitesProp = "zookeeper.ssl.cipher.suites" + val ZkSslEndpointIdentificationAlgorithmProp = "zookeeper.ssl.endpoint.identification.algorithm" + val ZkSslCrlEnableProp = "zookeeper.ssl.crl.enable" + val ZkSslOcspEnableProp = "zookeeper.ssl.ocsp.enable" + + // a map from the Kafka config to the corresponding ZooKeeper Java system property + private[kafka] val ZkSslConfigToSystemPropertyMap: Map[String, String] = Map( + ZkSslClientEnableProp -> ZKClientConfig.SECURE_CLIENT, + ZkClientCnxnSocketProp -> ZKClientConfig.ZOOKEEPER_CLIENT_CNXN_SOCKET, + ZkSslKeyStoreLocationProp -> "zookeeper.ssl.keyStore.location", + ZkSslKeyStorePasswordProp -> "zookeeper.ssl.keyStore.password", + ZkSslKeyStoreTypeProp -> "zookeeper.ssl.keyStore.type", + ZkSslTrustStoreLocationProp -> "zookeeper.ssl.trustStore.location", + ZkSslTrustStorePasswordProp -> "zookeeper.ssl.trustStore.password", + ZkSslTrustStoreTypeProp -> "zookeeper.ssl.trustStore.type", + ZkSslProtocolProp -> "zookeeper.ssl.protocol", + ZkSslEnabledProtocolsProp -> "zookeeper.ssl.enabledProtocols", + ZkSslCipherSuitesProp -> "zookeeper.ssl.ciphersuites", + ZkSslEndpointIdentificationAlgorithmProp -> "zookeeper.ssl.hostnameVerification", + ZkSslCrlEnableProp -> "zookeeper.ssl.crl", + ZkSslOcspEnableProp -> "zookeeper.ssl.ocsp") + + private[kafka] def zooKeeperClientProperty(clientConfig: ZKClientConfig, kafkaPropName: String): Option[String] = { + Option(clientConfig.getProperty(ZkSslConfigToSystemPropertyMap(kafkaPropName))) + } + + private[kafka] def setZooKeeperClientProperty(clientConfig: ZKClientConfig, kafkaPropName: String, kafkaPropValue: Any): Unit = { + clientConfig.setProperty(ZkSslConfigToSystemPropertyMap(kafkaPropName), + kafkaPropName match { + case ZkSslEndpointIdentificationAlgorithmProp => (kafkaPropValue.toString.toUpperCase == "HTTPS").toString + case ZkSslEnabledProtocolsProp | ZkSslCipherSuitesProp => kafkaPropValue match { + case list: java.util.List[_] => list.asScala.mkString(",") + case _ 
=> kafkaPropValue.toString + } + case _ => kafkaPropValue.toString + }) + } + + // For ZooKeeper TLS client authentication to be enabled the client must (at a minimum) configure itself as using TLS + // with both a client connection socket and a key store location explicitly set. + private[kafka] def zkTlsClientAuthEnabled(zkClientConfig: ZKClientConfig): Boolean = { + zooKeeperClientProperty(zkClientConfig, ZkSslClientEnableProp).contains("true") && + zooKeeperClientProperty(zkClientConfig, ZkClientCnxnSocketProp).isDefined && + zooKeeperClientProperty(zkClientConfig, ZkSslKeyStoreLocationProp).isDefined + } + + /** ********* General Configuration ***********/ + val BrokerIdGenerationEnableProp = "broker.id.generation.enable" + val MaxReservedBrokerIdProp = "reserved.broker.max.id" + val BrokerIdProp = "broker.id" + val MessageMaxBytesProp = "message.max.bytes" + val NumNetworkThreadsProp = "num.network.threads" + val NumIoThreadsProp = "num.io.threads" + val BackgroundThreadsProp = "background.threads" + val NumReplicaAlterLogDirsThreadsProp = "num.replica.alter.log.dirs.threads" + val QueuedMaxRequestsProp = "queued.max.requests" + val QueuedMaxBytesProp = "queued.max.request.bytes" + val RequestTimeoutMsProp = CommonClientConfigs.REQUEST_TIMEOUT_MS_CONFIG + val ConnectionSetupTimeoutMsProp = CommonClientConfigs.SOCKET_CONNECTION_SETUP_TIMEOUT_MS_CONFIG + val ConnectionSetupTimeoutMaxMsProp = CommonClientConfigs.SOCKET_CONNECTION_SETUP_TIMEOUT_MAX_MS_CONFIG + + /** KRaft mode configs */ + val ProcessRolesProp = "process.roles" + val InitialBrokerRegistrationTimeoutMsProp = "initial.broker.registration.timeout.ms" + val BrokerHeartbeatIntervalMsProp = "broker.heartbeat.interval.ms" + val BrokerSessionTimeoutMsProp = "broker.session.timeout.ms" + val NodeIdProp = "node.id" + val MetadataLogDirProp = "metadata.log.dir" + val MetadataSnapshotMaxNewRecordBytesProp = "metadata.log.max.record.bytes.between.snapshots" + val ControllerListenerNamesProp = "controller.listener.names" + val SaslMechanismControllerProtocolProp = "sasl.mechanism.controller.protocol" + val MetadataLogSegmentMinBytesProp = "metadata.log.segment.min.bytes" + val MetadataLogSegmentBytesProp = "metadata.log.segment.bytes" + val MetadataLogSegmentMillisProp = "metadata.log.segment.ms" + val MetadataMaxRetentionBytesProp = "metadata.max.retention.bytes" + val MetadataMaxRetentionMillisProp = "metadata.max.retention.ms" + val QuorumVotersProp = RaftConfig.QUORUM_VOTERS_CONFIG + + /************* Authorizer Configuration ***********/ + val AuthorizerClassNameProp = "authorizer.class.name" + /** ********* Socket Server Configuration ***********/ + val ListenersProp = "listeners" + val AdvertisedListenersProp = "advertised.listeners" + val ListenerSecurityProtocolMapProp = "listener.security.protocol.map" + val ControlPlaneListenerNameProp = "control.plane.listener.name" + val SocketSendBufferBytesProp = "socket.send.buffer.bytes" + val SocketReceiveBufferBytesProp = "socket.receive.buffer.bytes" + val SocketRequestMaxBytesProp = "socket.request.max.bytes" + val MaxConnectionsPerIpProp = "max.connections.per.ip" + val MaxConnectionsPerIpOverridesProp = "max.connections.per.ip.overrides" + val MaxConnectionsProp = "max.connections" + val MaxConnectionCreationRateProp = "max.connection.creation.rate" + val ConnectionsMaxIdleMsProp = "connections.max.idle.ms" + val FailedAuthenticationDelayMsProp = "connection.failed.authentication.delay.ms" + /***************** rack configuration *************/ + val RackProp = "broker.rack" + /** 
********* Log Configuration ***********/ + val NumPartitionsProp = "num.partitions" + val LogDirsProp = "log.dirs" + val LogDirProp = "log.dir" + val LogSegmentBytesProp = "log.segment.bytes" + + val LogRollTimeMillisProp = "log.roll.ms" + val LogRollTimeHoursProp = "log.roll.hours" + + val LogRollTimeJitterMillisProp = "log.roll.jitter.ms" + val LogRollTimeJitterHoursProp = "log.roll.jitter.hours" + + val LogRetentionTimeMillisProp = "log.retention.ms" + val LogRetentionTimeMinutesProp = "log.retention.minutes" + val LogRetentionTimeHoursProp = "log.retention.hours" + + val LogRetentionBytesProp = "log.retention.bytes" + val LogCleanupIntervalMsProp = "log.retention.check.interval.ms" + val LogCleanupPolicyProp = "log.cleanup.policy" + val LogCleanerThreadsProp = "log.cleaner.threads" + val LogCleanerIoMaxBytesPerSecondProp = "log.cleaner.io.max.bytes.per.second" + val LogCleanerDedupeBufferSizeProp = "log.cleaner.dedupe.buffer.size" + val LogCleanerIoBufferSizeProp = "log.cleaner.io.buffer.size" + val LogCleanerDedupeBufferLoadFactorProp = "log.cleaner.io.buffer.load.factor" + val LogCleanerBackoffMsProp = "log.cleaner.backoff.ms" + val LogCleanerMinCleanRatioProp = "log.cleaner.min.cleanable.ratio" + val LogCleanerEnableProp = "log.cleaner.enable" + val LogCleanerDeleteRetentionMsProp = "log.cleaner.delete.retention.ms" + val LogCleanerMinCompactionLagMsProp = "log.cleaner.min.compaction.lag.ms" + val LogCleanerMaxCompactionLagMsProp = "log.cleaner.max.compaction.lag.ms" + val LogIndexSizeMaxBytesProp = "log.index.size.max.bytes" + val LogIndexIntervalBytesProp = "log.index.interval.bytes" + val LogFlushIntervalMessagesProp = "log.flush.interval.messages" + val LogDeleteDelayMsProp = "log.segment.delete.delay.ms" + val LogFlushSchedulerIntervalMsProp = "log.flush.scheduler.interval.ms" + val LogFlushIntervalMsProp = "log.flush.interval.ms" + val LogFlushOffsetCheckpointIntervalMsProp = "log.flush.offset.checkpoint.interval.ms" + val LogFlushStartOffsetCheckpointIntervalMsProp = "log.flush.start.offset.checkpoint.interval.ms" + val LogPreAllocateProp = "log.preallocate" + + /* See `TopicConfig.MESSAGE_FORMAT_VERSION_CONFIG` for details */ + @deprecated("3.0") + val LogMessageFormatVersionProp = LogConfigPrefix + "message.format.version" + + val LogMessageTimestampTypeProp = LogConfigPrefix + "message.timestamp.type" + val LogMessageTimestampDifferenceMaxMsProp = LogConfigPrefix + "message.timestamp.difference.max.ms" + val LogMaxIdMapSnapshotsProp = LogConfigPrefix + "max.id.map.snapshots" + val NumRecoveryThreadsPerDataDirProp = "num.recovery.threads.per.data.dir" + val AutoCreateTopicsEnableProp = "auto.create.topics.enable" + val MinInSyncReplicasProp = "min.insync.replicas" + val CreateTopicPolicyClassNameProp = "create.topic.policy.class.name" + val AlterConfigPolicyClassNameProp = "alter.config.policy.class.name" + val LogMessageDownConversionEnableProp = LogConfigPrefix + "message.downconversion.enable" + /** ********* Replication configuration ***********/ + val ControllerSocketTimeoutMsProp = "controller.socket.timeout.ms" + val DefaultReplicationFactorProp = "default.replication.factor" + val ReplicaLagTimeMaxMsProp = "replica.lag.time.max.ms" + val ReplicaSocketTimeoutMsProp = "replica.socket.timeout.ms" + val ReplicaSocketReceiveBufferBytesProp = "replica.socket.receive.buffer.bytes" + val ReplicaFetchMaxBytesProp = "replica.fetch.max.bytes" + val ReplicaFetchWaitMaxMsProp = "replica.fetch.wait.max.ms" + val ReplicaFetchMinBytesProp = "replica.fetch.min.bytes" + val 
ReplicaFetchResponseMaxBytesProp = "replica.fetch.response.max.bytes" + val ReplicaFetchBackoffMsProp = "replica.fetch.backoff.ms" + val NumReplicaFetchersProp = "num.replica.fetchers" + val ReplicaHighWatermarkCheckpointIntervalMsProp = "replica.high.watermark.checkpoint.interval.ms" + val FetchPurgatoryPurgeIntervalRequestsProp = "fetch.purgatory.purge.interval.requests" + val ProducerPurgatoryPurgeIntervalRequestsProp = "producer.purgatory.purge.interval.requests" + val DeleteRecordsPurgatoryPurgeIntervalRequestsProp = "delete.records.purgatory.purge.interval.requests" + val AutoLeaderRebalanceEnableProp = "auto.leader.rebalance.enable" + val LeaderImbalancePerBrokerPercentageProp = "leader.imbalance.per.broker.percentage" + val LeaderImbalanceCheckIntervalSecondsProp = "leader.imbalance.check.interval.seconds" + val UncleanLeaderElectionEnableProp = "unclean.leader.election.enable" + val InterBrokerSecurityProtocolProp = "security.inter.broker.protocol" + val InterBrokerProtocolVersionProp = "inter.broker.protocol.version" + val InterBrokerListenerNameProp = "inter.broker.listener.name" + val ReplicaSelectorClassProp = "replica.selector.class" + /** ********* Controlled shutdown configuration ***********/ + val ControlledShutdownMaxRetriesProp = "controlled.shutdown.max.retries" + val ControlledShutdownRetryBackoffMsProp = "controlled.shutdown.retry.backoff.ms" + val ControlledShutdownEnableProp = "controlled.shutdown.enable" + /** ********* Group coordinator configuration ***********/ + val GroupMinSessionTimeoutMsProp = "group.min.session.timeout.ms" + val GroupMaxSessionTimeoutMsProp = "group.max.session.timeout.ms" + val GroupInitialRebalanceDelayMsProp = "group.initial.rebalance.delay.ms" + val GroupMaxSizeProp = "group.max.size" + /** ********* Offset management configuration ***********/ + val OffsetMetadataMaxSizeProp = "offset.metadata.max.bytes" + val OffsetsLoadBufferSizeProp = "offsets.load.buffer.size" + val OffsetsTopicReplicationFactorProp = "offsets.topic.replication.factor" + val OffsetsTopicPartitionsProp = "offsets.topic.num.partitions" + val OffsetsTopicSegmentBytesProp = "offsets.topic.segment.bytes" + val OffsetsTopicCompressionCodecProp = "offsets.topic.compression.codec" + val OffsetsRetentionMinutesProp = "offsets.retention.minutes" + val OffsetsRetentionCheckIntervalMsProp = "offsets.retention.check.interval.ms" + val OffsetCommitTimeoutMsProp = "offsets.commit.timeout.ms" + val OffsetCommitRequiredAcksProp = "offsets.commit.required.acks" + /** ********* Transaction management configuration ***********/ + val TransactionalIdExpirationMsProp = "transactional.id.expiration.ms" + val TransactionsMaxTimeoutMsProp = "transaction.max.timeout.ms" + val TransactionsTopicMinISRProp = "transaction.state.log.min.isr" + val TransactionsLoadBufferSizeProp = "transaction.state.log.load.buffer.size" + val TransactionsTopicPartitionsProp = "transaction.state.log.num.partitions" + val TransactionsTopicSegmentBytesProp = "transaction.state.log.segment.bytes" + val TransactionsTopicReplicationFactorProp = "transaction.state.log.replication.factor" + val TransactionsAbortTimedOutTransactionCleanupIntervalMsProp = "transaction.abort.timed.out.transaction.cleanup.interval.ms" + val TransactionsRemoveExpiredTransactionalIdCleanupIntervalMsProp = "transaction.remove.expired.transaction.cleanup.interval.ms" + + /** ********* Fetch Configuration **************/ + val MaxIncrementalFetchSessionCacheSlots = "max.incremental.fetch.session.cache.slots" + val FetchMaxBytes = 
"fetch.max.bytes" + + /** ********* Quota Configuration ***********/ + val NumQuotaSamplesProp = "quota.window.num" + val NumReplicationQuotaSamplesProp = "replication.quota.window.num" + val NumAlterLogDirsReplicationQuotaSamplesProp = "alter.log.dirs.replication.quota.window.num" + val NumControllerQuotaSamplesProp = "controller.quota.window.num" + val QuotaWindowSizeSecondsProp = "quota.window.size.seconds" + val ReplicationQuotaWindowSizeSecondsProp = "replication.quota.window.size.seconds" + val AlterLogDirsReplicationQuotaWindowSizeSecondsProp = "alter.log.dirs.replication.quota.window.size.seconds" + val ControllerQuotaWindowSizeSecondsProp = "controller.quota.window.size.seconds" + val ClientQuotaCallbackClassProp = "client.quota.callback.class" + + val DeleteTopicEnableProp = "delete.topic.enable" + val CompressionTypeProp = "compression.type" + + /** ********* Kafka Metrics Configuration ***********/ + val MetricSampleWindowMsProp = CommonClientConfigs.METRICS_SAMPLE_WINDOW_MS_CONFIG + val MetricNumSamplesProp: String = CommonClientConfigs.METRICS_NUM_SAMPLES_CONFIG + val MetricReporterClassesProp: String = CommonClientConfigs.METRIC_REPORTER_CLASSES_CONFIG + val MetricRecordingLevelProp: String = CommonClientConfigs.METRICS_RECORDING_LEVEL_CONFIG + + /** ********* Kafka Yammer Metrics Reporters Configuration ***********/ + val KafkaMetricsReporterClassesProp = "kafka.metrics.reporters" + val KafkaMetricsPollingIntervalSecondsProp = "kafka.metrics.polling.interval.secs" + + /** ******** Common Security Configuration *************/ + val PrincipalBuilderClassProp = BrokerSecurityConfigs.PRINCIPAL_BUILDER_CLASS_CONFIG + val ConnectionsMaxReauthMsProp = BrokerSecurityConfigs.CONNECTIONS_MAX_REAUTH_MS + val securityProviderClassProp = SecurityConfig.SECURITY_PROVIDERS_CONFIG + + /** ********* SSL Configuration ****************/ + val SslProtocolProp = SslConfigs.SSL_PROTOCOL_CONFIG + val SslProviderProp = SslConfigs.SSL_PROVIDER_CONFIG + val SslCipherSuitesProp = SslConfigs.SSL_CIPHER_SUITES_CONFIG + val SslEnabledProtocolsProp = SslConfigs.SSL_ENABLED_PROTOCOLS_CONFIG + val SslKeystoreTypeProp = SslConfigs.SSL_KEYSTORE_TYPE_CONFIG + val SslKeystoreLocationProp = SslConfigs.SSL_KEYSTORE_LOCATION_CONFIG + val SslKeystorePasswordProp = SslConfigs.SSL_KEYSTORE_PASSWORD_CONFIG + val SslKeyPasswordProp = SslConfigs.SSL_KEY_PASSWORD_CONFIG + val SslKeystoreKeyProp = SslConfigs.SSL_KEYSTORE_KEY_CONFIG + val SslKeystoreCertificateChainProp = SslConfigs.SSL_KEYSTORE_CERTIFICATE_CHAIN_CONFIG + val SslTruststoreTypeProp = SslConfigs.SSL_TRUSTSTORE_TYPE_CONFIG + val SslTruststoreLocationProp = SslConfigs.SSL_TRUSTSTORE_LOCATION_CONFIG + val SslTruststorePasswordProp = SslConfigs.SSL_TRUSTSTORE_PASSWORD_CONFIG + val SslTruststoreCertificatesProp = SslConfigs.SSL_TRUSTSTORE_CERTIFICATES_CONFIG + val SslKeyManagerAlgorithmProp = SslConfigs.SSL_KEYMANAGER_ALGORITHM_CONFIG + val SslTrustManagerAlgorithmProp = SslConfigs.SSL_TRUSTMANAGER_ALGORITHM_CONFIG + val SslEndpointIdentificationAlgorithmProp = SslConfigs.SSL_ENDPOINT_IDENTIFICATION_ALGORITHM_CONFIG + val SslSecureRandomImplementationProp = SslConfigs.SSL_SECURE_RANDOM_IMPLEMENTATION_CONFIG + val SslClientAuthProp = BrokerSecurityConfigs.SSL_CLIENT_AUTH_CONFIG + val SslPrincipalMappingRulesProp = BrokerSecurityConfigs.SSL_PRINCIPAL_MAPPING_RULES_CONFIG + var SslEngineFactoryClassProp = SslConfigs.SSL_ENGINE_FACTORY_CLASS_CONFIG + + /** ********* SASL Configuration ****************/ + val SaslMechanismInterBrokerProtocolProp = 
"sasl.mechanism.inter.broker.protocol" + val SaslJaasConfigProp = SaslConfigs.SASL_JAAS_CONFIG + val SaslEnabledMechanismsProp = BrokerSecurityConfigs.SASL_ENABLED_MECHANISMS_CONFIG + val SaslServerCallbackHandlerClassProp = BrokerSecurityConfigs.SASL_SERVER_CALLBACK_HANDLER_CLASS + val SaslClientCallbackHandlerClassProp = SaslConfigs.SASL_CLIENT_CALLBACK_HANDLER_CLASS + val SaslLoginClassProp = SaslConfigs.SASL_LOGIN_CLASS + val SaslLoginCallbackHandlerClassProp = SaslConfigs.SASL_LOGIN_CALLBACK_HANDLER_CLASS + val SaslKerberosServiceNameProp = SaslConfigs.SASL_KERBEROS_SERVICE_NAME + val SaslKerberosKinitCmdProp = SaslConfigs.SASL_KERBEROS_KINIT_CMD + val SaslKerberosTicketRenewWindowFactorProp = SaslConfigs.SASL_KERBEROS_TICKET_RENEW_WINDOW_FACTOR + val SaslKerberosTicketRenewJitterProp = SaslConfigs.SASL_KERBEROS_TICKET_RENEW_JITTER + val SaslKerberosMinTimeBeforeReloginProp = SaslConfigs.SASL_KERBEROS_MIN_TIME_BEFORE_RELOGIN + val SaslKerberosPrincipalToLocalRulesProp = BrokerSecurityConfigs.SASL_KERBEROS_PRINCIPAL_TO_LOCAL_RULES_CONFIG + val SaslLoginRefreshWindowFactorProp = SaslConfigs.SASL_LOGIN_REFRESH_WINDOW_FACTOR + val SaslLoginRefreshWindowJitterProp = SaslConfigs.SASL_LOGIN_REFRESH_WINDOW_JITTER + val SaslLoginRefreshMinPeriodSecondsProp = SaslConfigs.SASL_LOGIN_REFRESH_MIN_PERIOD_SECONDS + val SaslLoginRefreshBufferSecondsProp = SaslConfigs.SASL_LOGIN_REFRESH_BUFFER_SECONDS + + val SaslLoginConnectTimeoutMsProp = SaslConfigs.SASL_LOGIN_CONNECT_TIMEOUT_MS + val SaslLoginReadTimeoutMsProp = SaslConfigs.SASL_LOGIN_READ_TIMEOUT_MS + val SaslLoginRetryBackoffMaxMsProp = SaslConfigs.SASL_LOGIN_RETRY_BACKOFF_MAX_MS + val SaslLoginRetryBackoffMsProp = SaslConfigs.SASL_LOGIN_RETRY_BACKOFF_MS + val SaslOAuthBearerScopeClaimNameProp = SaslConfigs.SASL_OAUTHBEARER_SCOPE_CLAIM_NAME + val SaslOAuthBearerSubClaimNameProp = SaslConfigs.SASL_OAUTHBEARER_SUB_CLAIM_NAME + val SaslOAuthBearerTokenEndpointUrlProp = SaslConfigs.SASL_OAUTHBEARER_TOKEN_ENDPOINT_URL + val SaslOAuthBearerJwksEndpointUrlProp = SaslConfigs.SASL_OAUTHBEARER_JWKS_ENDPOINT_URL + val SaslOAuthBearerJwksEndpointRefreshMsProp = SaslConfigs.SASL_OAUTHBEARER_JWKS_ENDPOINT_REFRESH_MS + val SaslOAuthBearerJwksEndpointRetryBackoffMaxMsProp = SaslConfigs.SASL_OAUTHBEARER_JWKS_ENDPOINT_RETRY_BACKOFF_MAX_MS + val SaslOAuthBearerJwksEndpointRetryBackoffMsProp = SaslConfigs.SASL_OAUTHBEARER_JWKS_ENDPOINT_RETRY_BACKOFF_MS + val SaslOAuthBearerClockSkewSecondsProp = SaslConfigs.SASL_OAUTHBEARER_CLOCK_SKEW_SECONDS + val SaslOAuthBearerExpectedAudienceProp = SaslConfigs.SASL_OAUTHBEARER_EXPECTED_AUDIENCE + val SaslOAuthBearerExpectedIssuerProp = SaslConfigs.SASL_OAUTHBEARER_EXPECTED_ISSUER + + /** ********* Delegation Token Configuration ****************/ + val DelegationTokenSecretKeyAliasProp = "delegation.token.master.key" + val DelegationTokenSecretKeyProp = "delegation.token.secret.key" + val DelegationTokenMaxLifeTimeProp = "delegation.token.max.lifetime.ms" + val DelegationTokenExpiryTimeMsProp = "delegation.token.expiry.time.ms" + val DelegationTokenExpiryCheckIntervalMsProp = "delegation.token.expiry.check.interval.ms" + + /** ********* Password encryption configuration for dynamic configs *********/ + val PasswordEncoderSecretProp = "password.encoder.secret" + val PasswordEncoderOldSecretProp = "password.encoder.old.secret" + val PasswordEncoderKeyFactoryAlgorithmProp = "password.encoder.keyfactory.algorithm" + val PasswordEncoderCipherAlgorithmProp = "password.encoder.cipher.algorithm" + val PasswordEncoderKeyLengthProp = 
"password.encoder.key.length" + val PasswordEncoderIterationsProp = "password.encoder.iterations" + + /* Documentation */ + /** ********* Zookeeper Configuration ***********/ + val ZkConnectDoc = "Specifies the ZooKeeper connection string in the form hostname:port where host and port are the " + + "host and port of a ZooKeeper server. To allow connecting through other ZooKeeper nodes when that ZooKeeper machine is " + + "down you can also specify multiple hosts in the form hostname1:port1,hostname2:port2,hostname3:port3.\n" + + "The server can also have a ZooKeeper chroot path as part of its ZooKeeper connection string which puts its data under some path in the global ZooKeeper namespace. " + + "For example to give a chroot path of /chroot/path you would give the connection string as hostname1:port1,hostname2:port2,hostname3:port3/chroot/path." + val ZkSessionTimeoutMsDoc = "Zookeeper session timeout" + val ZkConnectionTimeoutMsDoc = "The max time that the client waits to establish a connection to zookeeper. If not set, the value in " + ZkSessionTimeoutMsProp + " is used" + val ZkSyncTimeMsDoc = "How far a ZK follower can be behind a ZK leader" + val ZkEnableSecureAclsDoc = "Set client to use secure ACLs" + val ZkMaxInFlightRequestsDoc = "The maximum number of unacknowledged requests the client will send to Zookeeper before blocking." + val ZkSslClientEnableDoc = "Set client to use TLS when connecting to ZooKeeper." + + " An explicit value overrides any value set via the zookeeper.client.secure system property (note the different name)." + + s" Defaults to false if neither is set; when true, $ZkClientCnxnSocketProp must be set (typically to org.apache.zookeeper.ClientCnxnSocketNetty); other values to set may include " + + ZkSslConfigToSystemPropertyMap.keys.toList.sorted.filter(x => x != ZkSslClientEnableProp && x != ZkClientCnxnSocketProp).mkString("", ", ", "") + val ZkClientCnxnSocketDoc = "Typically set to org.apache.zookeeper.ClientCnxnSocketNetty when using TLS connectivity to ZooKeeper." + + s" Overrides any explicit value set via the same-named ${ZkSslConfigToSystemPropertyMap(ZkClientCnxnSocketProp)} system property." + val ZkSslKeyStoreLocationDoc = "Keystore location when using a client-side certificate with TLS connectivity to ZooKeeper." + + s" Overrides any explicit value set via the ${ZkSslConfigToSystemPropertyMap(ZkSslKeyStoreLocationProp)} system property (note the camelCase)." + val ZkSslKeyStorePasswordDoc = "Keystore password when using a client-side certificate with TLS connectivity to ZooKeeper." + + s" Overrides any explicit value set via the ${ZkSslConfigToSystemPropertyMap(ZkSslKeyStorePasswordProp)} system property (note the camelCase)." + + " Note that ZooKeeper does not support a key password different from the keystore password, so be sure to set the key password in the keystore to be identical to the keystore password; otherwise the connection attempt to Zookeeper will fail." + val ZkSslKeyStoreTypeDoc = "Keystore type when using a client-side certificate with TLS connectivity to ZooKeeper." + + s" Overrides any explicit value set via the ${ZkSslConfigToSystemPropertyMap(ZkSslKeyStoreTypeProp)} system property (note the camelCase)." + + " The default value of null means the type will be auto-detected based on the filename extension of the keystore." + val ZkSslTrustStoreLocationDoc = "Truststore location when using TLS connectivity to ZooKeeper." 
+ + s" Overrides any explicit value set via the ${ZkSslConfigToSystemPropertyMap(ZkSslTrustStoreLocationProp)} system property (note the camelCase)." + val ZkSslTrustStorePasswordDoc = "Truststore password when using TLS connectivity to ZooKeeper." + + s" Overrides any explicit value set via the ${ZkSslConfigToSystemPropertyMap(ZkSslTrustStorePasswordProp)} system property (note the camelCase)." + val ZkSslTrustStoreTypeDoc = "Truststore type when using TLS connectivity to ZooKeeper." + + s" Overrides any explicit value set via the ${ZkSslConfigToSystemPropertyMap(ZkSslTrustStoreTypeProp)} system property (note the camelCase)." + + " The default value of null means the type will be auto-detected based on the filename extension of the truststore." + val ZkSslProtocolDoc = "Specifies the protocol to be used in ZooKeeper TLS negotiation." + + s" An explicit value overrides any value set via the same-named ${ZkSslConfigToSystemPropertyMap(ZkSslProtocolProp)} system property." + val ZkSslEnabledProtocolsDoc = "Specifies the enabled protocol(s) in ZooKeeper TLS negotiation (csv)." + + s" Overrides any explicit value set via the ${ZkSslConfigToSystemPropertyMap(ZkSslEnabledProtocolsProp)} system property (note the camelCase)." + + s" The default value of null means the enabled protocol will be the value of the ${KafkaConfig.ZkSslProtocolProp} configuration property." + val ZkSslCipherSuitesDoc = "Specifies the enabled cipher suites to be used in ZooKeeper TLS negotiation (csv)." + + s""" Overrides any explicit value set via the ${ZkSslConfigToSystemPropertyMap(ZkSslCipherSuitesProp)} system property (note the single word \"ciphersuites\").""" + + " The default value of null means the list of enabled cipher suites is determined by the Java runtime being used." + val ZkSslEndpointIdentificationAlgorithmDoc = "Specifies whether to enable hostname verification in the ZooKeeper TLS negotiation process, with (case-insensitively) \"https\" meaning ZooKeeper hostname verification is enabled and an explicit blank value meaning it is disabled (disabling it is only recommended for testing purposes)." + + s""" An explicit value overrides any \"true\" or \"false\" value set via the ${ZkSslConfigToSystemPropertyMap(ZkSslEndpointIdentificationAlgorithmProp)} system property (note the different name and values; true implies https and false implies blank).""" + val ZkSslCrlEnableDoc = "Specifies whether to enable Certificate Revocation List in the ZooKeeper TLS protocols." + + s" Overrides any explicit value set via the ${ZkSslConfigToSystemPropertyMap(ZkSslCrlEnableProp)} system property (note the shorter name)." + val ZkSslOcspEnableDoc = "Specifies whether to enable Online Certificate Status Protocol in the ZooKeeper TLS protocols." + + s" Overrides any explicit value set via the ${ZkSslConfigToSystemPropertyMap(ZkSslOcspEnableProp)} system property (note the shorter name)." + /** ********* General Configuration ***********/ + val BrokerIdGenerationEnableDoc = s"Enable automatic broker id generation on the server. When enabled the value configured for $MaxReservedBrokerIdProp should be reviewed." + val MaxReservedBrokerIdDoc = "Max number that can be used for a broker.id" + val BrokerIdDoc = "The broker id for this server. If unset, a unique broker id will be generated." + + "To avoid conflicts between zookeeper generated broker id's and user configured broker id's, generated broker ids " + + "start from " + MaxReservedBrokerIdProp + " + 1." 
+ val MessageMaxBytesDoc = TopicConfig.MAX_MESSAGE_BYTES_DOC + + s"This can be set per topic with the topic level ${TopicConfig.MAX_MESSAGE_BYTES_CONFIG} config." + val NumNetworkThreadsDoc = "The number of threads that the server uses for receiving requests from the network and sending responses to the network" + val NumIoThreadsDoc = "The number of threads that the server uses for processing requests, which may include disk I/O" + val NumReplicaAlterLogDirsThreadsDoc = "The number of threads that can move replicas between log directories, which may include disk I/O" + val BackgroundThreadsDoc = "The number of threads to use for various background processing tasks" + val QueuedMaxRequestsDoc = "The number of queued requests allowed for data-plane, before blocking the network threads" + val QueuedMaxRequestBytesDoc = "The number of queued bytes allowed before no more requests are read" + val RequestTimeoutMsDoc = CommonClientConfigs.REQUEST_TIMEOUT_MS_DOC + val ConnectionSetupTimeoutMsDoc = CommonClientConfigs.SOCKET_CONNECTION_SETUP_TIMEOUT_MS_DOC + val ConnectionSetupTimeoutMaxMsDoc = CommonClientConfigs.SOCKET_CONNECTION_SETUP_TIMEOUT_MAX_MS_DOC + + /** KRaft mode configs */ + val ProcessRolesDoc = "The roles that this process plays: 'broker', 'controller', or 'broker,controller' if it is both. " + + "This configuration is only applicable for clusters in KRaft (Kafka Raft) mode (instead of ZooKeeper). Leave this config undefined or empty for Zookeeper clusters." + val InitialBrokerRegistrationTimeoutMsDoc = "When initially registering with the controller quorum, the number of milliseconds to wait before declaring failure and exiting the broker process." + val BrokerHeartbeatIntervalMsDoc = "The length of time in milliseconds between broker heartbeats. Used when running in KRaft mode." + val BrokerSessionTimeoutMsDoc = "The length of time in milliseconds that a broker lease lasts if no heartbeats are made. Used when running in KRaft mode." + val NodeIdDoc = "The node ID associated with the roles this process is playing when `process.roles` is non-empty. " + + "This is required configuration when running in KRaft mode." + val MetadataLogDirDoc = "This configuration determines where we put the metadata log for clusters in KRaft mode. " + + "If it is not set, the metadata log is placed in the first log directory from log.dirs." + val MetadataSnapshotMaxNewRecordBytesDoc = "This is the maximum number of bytes in the log between the latest snapshot and the high-watermark needed before generating a new snapshot." + val ControllerListenerNamesDoc = "A comma-separated list of the names of the listeners used by the controller. This is required " + + "if running in KRaft mode. The ZK-based controller will not use this configuration." + val SaslMechanismControllerProtocolDoc = "SASL mechanism used for communication with controllers. Default is GSSAPI." + val MetadataLogSegmentBytesDoc = "The maximum size of a single metadata log file." + val MetadataLogSegmentMinBytesDoc = "Override the minimum size for a single metadata log file. This should be used for testing only." + + val MetadataLogSegmentMillisDoc = "The maximum time before a new metadata log file is rolled out (in milliseconds)." + val MetadataMaxRetentionBytesDoc = "The maximum combined size of the metadata log and snapshots before deleting old " + + "snapshots and log files. Since at least one snapshot must exist before any logs can be deleted, this is a soft limit." 
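+ // Illustrative note (an editor-added example, not upstream text), assuming the defaults defined earlier in this file: with metadata.log.max.record.bytes.between.snapshots left at 20 * 1024 * 1024, a new metadata snapshot is generated once roughly 20 MiB of records accumulate beyond the latest snapshot, and the metadata retention limits documented here act only as soft limits, since at least one snapshot must exist before any log files can be deleted.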
+ val MetadataMaxRetentionMillisDoc = "The number of milliseconds to keep a metadata log file or snapshot before " + + "deleting it. Since at least one snapshot must exist before any logs can be deleted, this is a soft limit." + + /************* Authorizer Configuration ***********/ + val AuthorizerClassNameDoc = s"The fully qualified name of a class that implements ${classOf[Authorizer].getName}" + + " interface, which is used by the broker for authorization." + /** ********* Socket Server Configuration ***********/ + val ListenersDoc = "Listener List - Comma-separated list of URIs we will listen on and the listener names." + + s" If the listener name is not a security protocol, $ListenerSecurityProtocolMapProp must also be set.\n" + + " Listener names and port numbers must be unique.\n" + + " Specify hostname as 0.0.0.0 to bind to all interfaces.\n" + + " Leave hostname empty to bind to default interface.\n" + + " Examples of legal listener lists:\n" + + " PLAINTEXT://myhost:9092,SSL://:9091\n" + + " CLIENT://0.0.0.0:9092,REPLICATION://localhost:9093\n" + val AdvertisedListenersDoc = s"Listeners to publish to ZooKeeper for clients to use, if different than the $ListenersProp config property." + + " In IaaS environments, this may need to be different from the interface to which the broker binds." + + s" If this is not set, the value for $ListenersProp will be used." + + s" Unlike $ListenersProp, it is not valid to advertise the 0.0.0.0 meta-address.\n" + + s" Also unlike $ListenersProp, there can be duplicated ports in this property," + + " so that one listener can be configured to advertise another listener's address." + + " This can be useful in some cases where external load balancers are used." + val ListenerSecurityProtocolMapDoc = "Map between listener names and security protocols. This must be defined for " + + "the same security protocol to be usable in more than one port or IP. For example, internal and " + + "external traffic can be separated even if SSL is required for both. Concretely, the user could define listeners " + + "with names INTERNAL and EXTERNAL and this property as: `INTERNAL:SSL,EXTERNAL:SSL`. As shown, key and value are " + + "separated by a colon and map entries are separated by commas. Each listener name should only appear once in the map. " + + "Different security (SSL and SASL) settings can be configured for each listener by adding a normalised " + + "prefix (the listener name is lowercased) to the config name. For example, to set a different keystore for the " + + "INTERNAL listener, a config with name listener.name.internal.ssl.keystore.location would be set. " + + "If the config for the listener name is not set, the config will fallback to the generic config (i.e. ssl.keystore.location). " + + "Note that in KRaft a default mapping from the listener names defined by controller.listener.names to PLAINTEXT " + + "is assumed if no explicit mapping is provided and no other security protocol is in use." + val controlPlaneListenerNameDoc = "Name of listener used for communication between controller and brokers. " + + s"Broker will use the $ControlPlaneListenerNameProp to locate the endpoint in $ListenersProp list, to listen for connections from the controller. 
" + + "For example, if a broker's config is :\n" + + "listeners = INTERNAL://192.1.1.8:9092, EXTERNAL://10.1.1.5:9093, CONTROLLER://192.1.1.8:9094\n" + + "listener.security.protocol.map = INTERNAL:PLAINTEXT, EXTERNAL:SSL, CONTROLLER:SSL\n" + + "control.plane.listener.name = CONTROLLER\n" + + "On startup, the broker will start listening on \"192.1.1.8:9094\" with security protocol \"SSL\".\n" + + s"On controller side, when it discovers a broker's published endpoints through zookeeper, it will use the $ControlPlaneListenerNameProp " + + "to find the endpoint, which it will use to establish connection to the broker.\n" + + "For example, if the broker's published endpoints on zookeeper are :\n" + + "\"endpoints\" : [\"INTERNAL://broker1.example.com:9092\",\"EXTERNAL://broker1.example.com:9093\",\"CONTROLLER://broker1.example.com:9094\"]\n" + + " and the controller's config is :\n" + + "listener.security.protocol.map = INTERNAL:PLAINTEXT, EXTERNAL:SSL, CONTROLLER:SSL\n" + + "control.plane.listener.name = CONTROLLER\n" + + "then controller will use \"broker1.example.com:9094\" with security protocol \"SSL\" to connect to the broker.\n" + + "If not explicitly configured, the default value will be null and there will be no dedicated endpoints for controller connections." + + val SocketSendBufferBytesDoc = "The SO_SNDBUF buffer of the socket server sockets. If the value is -1, the OS default will be used." + val SocketReceiveBufferBytesDoc = "The SO_RCVBUF buffer of the socket server sockets. If the value is -1, the OS default will be used." + val SocketRequestMaxBytesDoc = "The maximum number of bytes in a socket request" + val MaxConnectionsPerIpDoc = "The maximum number of connections we allow from each ip address. This can be set to 0 if there are overrides " + + s"configured using $MaxConnectionsPerIpOverridesProp property. New connections from the ip address are dropped if the limit is reached." + val MaxConnectionsPerIpOverridesDoc = "A comma-separated list of per-ip or hostname overrides to the default maximum number of connections. " + + "An example value is \"hostName:100,127.0.0.1:200\"" + val MaxConnectionsDoc = "The maximum number of connections we allow in the broker at any time. This limit is applied in addition " + + s"to any per-ip limits configured using $MaxConnectionsPerIpProp. Listener-level limits may also be configured by prefixing the " + + s"config name with the listener prefix, for example, listener.name.internal.$MaxConnectionsProp. Broker-wide limit " + + "should be configured based on broker capacity while listener limits should be configured based on application requirements. " + + "New connections are blocked if either the listener or broker limit is reached. Connections on the inter-broker listener are " + + "permitted even if broker-wide limit is reached. The least recently used connection on another listener will be closed in this case." + val MaxConnectionCreationRateDoc = "The maximum connection creation rate we allow in the broker at any time. Listener-level limits " + + s"may also be configured by prefixing the config name with the listener prefix, for example, listener.name.internal.$MaxConnectionCreationRateProp." + + "Broker-wide connection rate limit should be configured based on broker capacity while listener limits should be configured based on " + + "application requirements. New connections will be throttled if either the listener or the broker limit is reached, with the exception " + + "of inter-broker listener. 
Connections on the inter-broker listener will be throttled only when the listener-level rate limit is reached." + val ConnectionsMaxIdleMsDoc = "Idle connections timeout: the server socket processor threads close the connections that idle for more than this" + val FailedAuthenticationDelayMsDoc = "Connection close delay on failed authentication: this is the time (in milliseconds) by which connection close will be delayed on authentication failure. " + + s"This must be configured to be less than $ConnectionsMaxIdleMsProp to prevent connection timeout." + /************* Rack Configuration **************/ + val RackDoc = "Rack of the broker. This will be used in rack-aware replication assignment for fault tolerance. Examples: `RACK1`, `us-east-1d`" + /** ********* Log Configuration ***********/ + val NumPartitionsDoc = "The default number of log partitions per topic" + val LogDirDoc = "The directory in which the log data is kept (supplemental for " + LogDirsProp + " property)" + val LogDirsDoc = "The directories in which the log data is kept. If not set, the value in " + LogDirProp + " is used" + val LogSegmentBytesDoc = "The maximum size of a single log file" + val LogRollTimeMillisDoc = "The maximum time before a new log segment is rolled out (in milliseconds). If not set, the value in " + LogRollTimeHoursProp + " is used" + val LogRollTimeHoursDoc = "The maximum time before a new log segment is rolled out (in hours), secondary to " + LogRollTimeMillisProp + " property" + + val LogRollTimeJitterMillisDoc = "The maximum jitter to subtract from logRollTimeMillis (in milliseconds). If not set, the value in " + LogRollTimeJitterHoursProp + " is used" + val LogRollTimeJitterHoursDoc = "The maximum jitter to subtract from logRollTimeMillis (in hours), secondary to " + LogRollTimeJitterMillisProp + " property" + + val LogRetentionTimeMillisDoc = "The number of milliseconds to keep a log file before deleting it. If not set, the value in " + LogRetentionTimeMinutesProp + " is used. If set to -1, no time limit is applied." + val LogRetentionTimeMinsDoc = "The number of minutes to keep a log file before deleting it (in minutes), secondary to " + LogRetentionTimeMillisProp + " property. If not set, the value in " + LogRetentionTimeHoursProp + " is used" + val LogRetentionTimeHoursDoc = "The number of hours to keep a log file before deleting it (in hours), tertiary to " + LogRetentionTimeMillisProp + " property" + + val LogRetentionBytesDoc = "The maximum size of the log before deleting it" + val LogCleanupIntervalMsDoc = "The frequency in milliseconds that the log cleaner checks whether any log is eligible for deletion" + val LogCleanupPolicyDoc = "The default cleanup policy for segments beyond the retention window. A comma-separated list of valid policies. Valid policies are: \"delete\" and \"compact\"" + val LogCleanerThreadsDoc = "The number of background threads to use for log cleaning" + val LogCleanerIoMaxBytesPerSecondDoc = "The log cleaner will be throttled so that the sum of its read and write i/o will be less than this value on average" + val LogCleanerDedupeBufferSizeDoc = "The total memory used for log deduplication across all cleaner threads" + val LogCleanerIoBufferSizeDoc = "The total memory used for log cleaner I/O buffers across all cleaner threads" + val LogCleanerDedupeBufferLoadFactorDoc = "Log cleaner dedupe buffer load factor. The percentage full the dedupe buffer can become. 
A higher value " + + "will allow more log to be cleaned at once but will lead to more hash collisions" + val LogCleanerBackoffMsDoc = "The amount of time to sleep when there are no logs to clean" + val LogCleanerMinCleanRatioDoc = "The minimum ratio of dirty log to total log for a log to be eligible for cleaning. " + "If the " + LogCleanerMaxCompactionLagMsProp + " or the " + LogCleanerMinCompactionLagMsProp + + " configurations are also specified, then the log compactor considers the log eligible for compaction " + "as soon as either: (i) the dirty ratio threshold has been met and the log has had dirty (uncompacted) " + "records for at least the " + LogCleanerMinCompactionLagMsProp + " duration, or (ii) if the log has had " + "dirty (uncompacted) records for at most the " + LogCleanerMaxCompactionLagMsProp + " period." + val LogCleanerEnableDoc = "Enable the log cleaner process to run on the server. Should be enabled if using any topics with a cleanup.policy=compact including the internal offsets topic. If disabled, those topics will not be compacted and will continually grow in size." + val LogCleanerDeleteRetentionMsDoc = "How long delete records are retained." + val LogCleanerMinCompactionLagMsDoc = "The minimum time a message will remain uncompacted in the log. Only applicable for logs that are being compacted." + val LogCleanerMaxCompactionLagMsDoc = "The maximum time a message will remain ineligible for compaction in the log. Only applicable for logs that are being compacted." + val LogIndexSizeMaxBytesDoc = "The maximum size in bytes of the offset index" + val LogIndexIntervalBytesDoc = "The interval with which we add an entry to the offset index" + val LogFlushIntervalMessagesDoc = "The number of messages accumulated on a log partition before messages are flushed to disk " + val LogDeleteDelayMsDoc = "The amount of time to wait before deleting a file from the filesystem" + val LogFlushSchedulerIntervalMsDoc = "The frequency in ms that the log flusher checks whether any log needs to be flushed to disk" + val LogFlushIntervalMsDoc = "The maximum time in ms that a message in any topic is kept in memory before being flushed to disk. If not set, the value in " + LogFlushSchedulerIntervalMsProp + " is used" + val LogFlushOffsetCheckpointIntervalMsDoc = "The frequency with which we update the persistent record of the last flush which acts as the log recovery point" + val LogFlushStartOffsetCheckpointIntervalMsDoc = "The frequency with which we update the persistent record of log start offset" + val LogPreAllocateEnableDoc = "Whether to pre-allocate the file when creating a new segment. If you are using Kafka on Windows, you probably need to set it to true." + val LogMessageFormatVersionDoc = "Specify the message format version the broker will use to append messages to the logs. The value should be a valid ApiVersion. " + "Some examples are: 0.8.2, 0.9.0.0, 0.10.0; check ApiVersion for more details. By setting a particular message format version, the " + "user is certifying that all the existing messages on disk are smaller than or equal to the specified version. Setting this value incorrectly " + "will cause consumers with older versions to break as they will receive messages with a format that they don't understand." + + val LogMessageTimestampTypeDoc = "Define whether the timestamp in the message is message create time or log append time. 
The value should be either " + + "`CreateTime` or `LogAppendTime`" + + val LogMessageTimestampDifferenceMaxMsDoc = "The maximum difference allowed between the timestamp when a broker receives " + "a message and the timestamp specified in the message. If log.message.timestamp.type=CreateTime, a message will be rejected " + "if the difference in timestamp exceeds this threshold. This configuration is ignored if log.message.timestamp.type=LogAppendTime. " + "The maximum timestamp difference allowed should be no greater than log.retention.ms to avoid unnecessarily frequent log rolling." + val NumRecoveryThreadsPerDataDirDoc = "The number of threads per data directory to be used for log recovery at startup and flushing at shutdown" + val AutoCreateTopicsEnableDoc = "Enable auto creation of topics on the server" + val MinInSyncReplicasDoc = "When a producer sets acks to \"all\" (or \"-1\"), " + "min.insync.replicas specifies the minimum number of replicas that must acknowledge " + "a write for the write to be considered successful. If this minimum cannot be met, " + "then the producer will raise an exception (either NotEnoughReplicas or " + "NotEnoughReplicasAfterAppend). 
When used together, min.insync.replicas and acks " + "allow you to enforce greater durability guarantees. A typical scenario would be to " + "create a topic with a replication factor of 3, set min.insync.replicas to 2, and " + "produce with acks of \"all\". This will ensure that the producer raises an exception " + "if a majority of replicas do not receive a write." + + val CreateTopicPolicyClassNameDoc = "The create topic policy class that should be used for validation. The class should " + "implement the org.apache.kafka.server.policy.CreateTopicPolicy interface." + val AlterConfigPolicyClassNameDoc = "The alter configs policy class that should be used for validation. The class should " + "implement the org.apache.kafka.server.policy.AlterConfigPolicy interface." + val LogMessageDownConversionEnableDoc = TopicConfig.MESSAGE_DOWNCONVERSION_ENABLE_DOC; + + /** ********* Replication configuration ***********/ + val ControllerSocketTimeoutMsDoc = "The socket timeout for controller-to-broker channels" + val ControllerMessageQueueSizeDoc = "The buffer size for controller-to-broker-channels" + val DefaultReplicationFactorDoc = "The default replication factor for automatically created topics" + val ReplicaLagTimeMaxMsDoc = "If a follower hasn't sent any fetch requests or hasn't consumed up to the leader's log end offset for at least this time," + " the leader will remove the follower from the ISR" + val ReplicaSocketTimeoutMsDoc = "The socket timeout for network requests. Its value should be at least replica.fetch.wait.max.ms" + val ReplicaSocketReceiveBufferBytesDoc = "The socket receive buffer for network requests" + val ReplicaFetchMaxBytesDoc = "The number of bytes of messages to attempt to fetch for each partition. This is not an absolute maximum; " + "if the first record batch in the first non-empty partition of the fetch is larger than this value, the record batch will still be returned " + "to ensure that progress can be made. The maximum record batch size accepted by the broker is defined via " + "message.max.bytes (broker config) or max.message.bytes (topic config)." + val ReplicaFetchWaitMaxMsDoc = "The maximum wait time for each fetcher request issued by follower replicas. This value should always be less than " + "replica.lag.time.max.ms to prevent frequent shrinking of ISR for low throughput topics" + val ReplicaFetchMinBytesDoc = "Minimum bytes expected for each fetch response. If not enough bytes, wait up to replica.fetch.wait.max.ms (broker config)." + val ReplicaFetchResponseMaxBytesDoc = "Maximum bytes expected for the entire fetch response. Records are fetched in batches, " + "and if the first record batch in the first non-empty partition of the fetch is larger than this value, the record batch " + "will still be returned to ensure that progress can be made. As such, this is not an absolute maximum. The maximum " + "record batch size accepted by the broker is defined via message.max.bytes (broker config) or " + "max.message.bytes (topic config)." + val NumReplicaFetchersDoc = "Number of fetcher threads used to replicate messages from a source broker. " + "Increasing this value can increase the degree of I/O parallelism in the follower broker." + val ReplicaFetchBackoffMsDoc = "The amount of time to sleep when a fetch partition error occurs." 
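+  // Illustrative sketch (not part of KafkaConfig): client-side view of the acks / min.insync.replicas
+  // interplay described in MinInSyncReplicasDoc above. The topic name, bootstrap address and
+  // replication settings below are placeholder assumptions, not values defined in this file.
+  //
+  //   import java.util.{Collections, Properties}
+  //   import scala.jdk.CollectionConverters._
+  //   import org.apache.kafka.clients.admin.{Admin, NewTopic}
+  //   import org.apache.kafka.clients.producer.{KafkaProducer, ProducerConfig, ProducerRecord}
+  //   import org.apache.kafka.common.config.TopicConfig
+  //   import org.apache.kafka.common.serialization.StringSerializer
+  //
+  //   val props = new Properties()
+  //   props.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9092")
+  //   // Topic with replication factor 3 and min.insync.replicas=2.
+  //   val admin = Admin.create(props)
+  //   val topic = new NewTopic("durable-topic", 1, 3.toShort)
+  //     .configs(Map(TopicConfig.MIN_IN_SYNC_REPLICAS_CONFIG -> "2").asJava)
+  //   admin.createTopics(Collections.singletonList(topic)).all().get()
+  //   // With acks=all the produce fails (NotEnoughReplicas) once fewer than 2 replicas are in sync.
+  //   props.put(ProducerConfig.ACKS_CONFIG, "all")
+  //   props.put(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG, classOf[StringSerializer].getName)
+  //   props.put(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG, classOf[StringSerializer].getName)
+  //   val producer = new KafkaProducer[String, String](props)
+  //   producer.send(new ProducerRecord("durable-topic", "key", "value")).get()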
+ val ReplicaHighWatermarkCheckpointIntervalMsDoc = "The frequency with which the high watermark is saved out to disk" + val FetchPurgatoryPurgeIntervalRequestsDoc = "The purge interval (in number of requests) of the fetch request purgatory" + val ProducerPurgatoryPurgeIntervalRequestsDoc = "The purge interval (in number of requests) of the producer request purgatory" + val DeleteRecordsPurgatoryPurgeIntervalRequestsDoc = "The purge interval (in number of requests) of the delete records request purgatory" + val AutoLeaderRebalanceEnableDoc = "Enables auto leader balancing. A background thread checks the distribution of partition leaders at regular intervals, configurable by `leader.imbalance.check.interval.seconds`. If the leader imbalance exceeds `leader.imbalance.per.broker.percentage`, a leader rebalance to the preferred leaders is triggered." + val LeaderImbalancePerBrokerPercentageDoc = "The ratio of leader imbalance allowed per broker. The controller would trigger a leader balance if it goes above this value per broker. The value is specified in percentage." + val LeaderImbalanceCheckIntervalSecondsDoc = "The frequency with which the partition rebalance check is triggered by the controller" + val UncleanLeaderElectionEnableDoc = "Indicates whether to enable replicas not in the ISR set to be elected as leader as a last resort, even though doing so may result in data loss" + val InterBrokerSecurityProtocolDoc = "Security protocol used to communicate between brokers. Valid values are: " + s"${SecurityProtocol.names.asScala.mkString(", ")}. It is an error to set this and $InterBrokerListenerNameProp " + "properties at the same time." + val InterBrokerProtocolVersionDoc = "Specify which version of the inter-broker protocol will be used.\n" + " This is typically bumped after all brokers have been upgraded to a new version.\n" + " Examples of valid values are: 0.8.0, 0.8.1, 0.8.1.1, 0.8.2, 0.8.2.0, 0.8.2.1, 0.9.0.0, 0.9.0.1. Check ApiVersion for the full list." + val InterBrokerListenerNameDoc = s"Name of listener used for communication between brokers. If this is unset, the listener name is defined by $InterBrokerSecurityProtocolProp. " + s"It is an error to set this and $InterBrokerSecurityProtocolProp properties at the same time." + val ReplicaSelectorClassDoc = "The fully qualified class name that implements ReplicaSelector. This is used by the broker to find the preferred read replica. By default, we use an implementation that returns the leader." + /** ********* Controlled shutdown configuration ***********/ + val ControlledShutdownMaxRetriesDoc = "Controlled shutdown can fail for multiple reasons. This determines the number of retries when such a failure happens" + val ControlledShutdownRetryBackoffMsDoc = "Before each retry, the system needs time to recover from the state that caused the previous failure (controller failover, replica lag, etc.). This config determines the amount of time to wait before retrying." + val ControlledShutdownEnableDoc = "Enable controlled shutdown of the server" + /** ********* Group coordinator configuration ***********/ + val GroupMinSessionTimeoutMsDoc = "The minimum allowed session timeout for registered consumers. Shorter timeouts result in quicker failure detection at the cost of more frequent consumer heartbeating, which can overwhelm broker resources." + val GroupMaxSessionTimeoutMsDoc = "The maximum allowed session timeout for registered consumers. 
Longer timeouts give consumers more time to process messages in between heartbeats at the cost of a longer time to detect failures." + val GroupInitialRebalanceDelayMsDoc = "The amount of time the group coordinator will wait for more consumers to join a new group before performing the first rebalance. A longer delay means potentially fewer rebalances, but increases the time until processing begins." + val GroupMaxSizeDoc = "The maximum number of consumers that a single consumer group can accommodate." + /** ********* Offset management configuration ***********/ + val OffsetMetadataMaxSizeDoc = "The maximum size for a metadata entry associated with an offset commit" + val OffsetsLoadBufferSizeDoc = "Batch size for reading from the offsets segments when loading offsets into the cache (soft-limit, overridden if records are too large)." + val OffsetsTopicReplicationFactorDoc = "The replication factor for the offsets topic (set higher to ensure availability). " + "Internal topic creation will fail until the cluster size meets this replication factor requirement." + val OffsetsTopicPartitionsDoc = "The number of partitions for the offset commit topic (should not change after deployment)" + val OffsetsTopicSegmentBytesDoc = "The offsets topic segment bytes should be kept relatively small in order to facilitate faster log compaction and cache loads" + val OffsetsTopicCompressionCodecDoc = "Compression codec for the offsets topic - compression may be used to achieve \"atomic\" commits" + val OffsetsRetentionMinutesDoc = "After a consumer group loses all its consumers (i.e. becomes empty) its offsets will be kept for this retention period before getting discarded. " + "For standalone consumers (using manual assignment), offsets will be expired after the time of last commit plus this retention period." + val OffsetsRetentionCheckIntervalMsDoc = "Frequency at which to check for stale offsets" + val OffsetCommitTimeoutMsDoc = "Offset commit will be delayed until all replicas for the offsets topic receive the commit " + "or this timeout is reached. This is similar to the producer request timeout." + val OffsetCommitRequiredAcksDoc = "The required acks before the commit can be accepted. In general, the default (-1) should not be overridden" + /** ********* Transaction management configuration ***********/ + val TransactionalIdExpirationMsDoc = "The time in ms that the transaction coordinator will wait without receiving any transaction status updates " + "for the current transaction before expiring its transactional id. This setting also influences producer id expiration - producer ids are expired " + "once this time has elapsed after the last write with the given producer id. Note that producer ids may expire sooner if the last write from the producer id is deleted due to the topic's retention settings." + val TransactionsMaxTimeoutMsDoc = "The maximum allowed timeout for transactions. " + "If a client's requested transaction timeout exceeds this, then the broker will return an error in InitProducerIdRequest. This prevents a client from using too large a timeout, which can stall consumers reading from topics included in the transaction." + val TransactionsTopicMinISRDoc = "Overridden " + MinInSyncReplicasProp + " config for the transaction topic." + val TransactionsLoadBufferSizeDoc = "Batch size for reading from the transaction log segments when loading producer ids and transactions into the cache (soft-limit, overridden if records are too large)." 
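+  // Illustrative sketch (not part of KafkaConfig): a transactional producer whose client-side
+  // transaction.timeout.ms must not exceed the broker's transaction.max.timeout.ms documented
+  // above, otherwise InitProducerId is rejected. The id, topic and address are placeholders.
+  //
+  //   import java.util.Properties
+  //   import org.apache.kafka.clients.producer.{KafkaProducer, ProducerConfig, ProducerRecord}
+  //   import org.apache.kafka.common.serialization.StringSerializer
+  //
+  //   val props = new Properties()
+  //   props.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9092")
+  //   props.put(ProducerConfig.TRANSACTIONAL_ID_CONFIG, "example-txn-id")
+  //   props.put(ProducerConfig.TRANSACTION_TIMEOUT_CONFIG, "60000") // must be <= transaction.max.timeout.ms
+  //   props.put(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG, classOf[StringSerializer].getName)
+  //   props.put(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG, classOf[StringSerializer].getName)
+  //   val producer = new KafkaProducer[String, String](props)
+  //   producer.initTransactions()
+  //   producer.beginTransaction()
+  //   producer.send(new ProducerRecord("example-topic", "key", "value"))
+  //   producer.commitTransaction()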
+ val TransactionsTopicReplicationFactorDoc = "The replication factor for the transaction topic (set higher to ensure availability). " + + "Internal topic creation will fail until the cluster size meets this replication factor requirement." + val TransactionsTopicPartitionsDoc = "The number of partitions for the transaction topic (should not change after deployment)." + val TransactionsTopicSegmentBytesDoc = "The transaction topic segment bytes should be kept relatively small in order to facilitate faster log compaction and cache loads" + val TransactionsAbortTimedOutTransactionsIntervalMsDoc = "The interval at which to rollback transactions that have timed out" + val TransactionsRemoveExpiredTransactionsIntervalMsDoc = "The interval at which to remove transactions that have expired due to transactional.id.expiration.ms passing" + + /** ********* Fetch Configuration **************/ + val MaxIncrementalFetchSessionCacheSlotsDoc = "The maximum number of incremental fetch sessions that we will maintain." + val FetchMaxBytesDoc = "The maximum number of bytes we will return for a fetch request. Must be at least 1024." + + /** ********* Quota Configuration ***********/ + val NumQuotaSamplesDoc = "The number of samples to retain in memory for client quotas" + val NumReplicationQuotaSamplesDoc = "The number of samples to retain in memory for replication quotas" + val NumAlterLogDirsReplicationQuotaSamplesDoc = "The number of samples to retain in memory for alter log dirs replication quotas" + val NumControllerQuotaSamplesDoc = "The number of samples to retain in memory for controller mutation quotas" + val QuotaWindowSizeSecondsDoc = "The time span of each sample for client quotas" + val ReplicationQuotaWindowSizeSecondsDoc = "The time span of each sample for replication quotas" + val AlterLogDirsReplicationQuotaWindowSizeSecondsDoc = "The time span of each sample for alter log dirs replication quotas" + val ControllerQuotaWindowSizeSecondsDoc = "The time span of each sample for controller mutations quotas" + + val ClientQuotaCallbackClassDoc = "The fully qualified name of a class that implements the ClientQuotaCallback interface, " + + "which is used to determine quota limits applied to client requests. By default, <user>, <client-id>, <user> or <client-id> " + + "quotas stored in ZooKeeper are applied. For any given request, the most specific quota that matches the user principal " + + "of the session and the client-id of the request is applied." + + val DeleteTopicEnableDoc = "Enables delete topic. Delete topic through the admin tool will have no effect if this config is turned off" + val CompressionTypeDoc = "Specify the final compression type for a given topic. This configuration accepts the standard compression codecs " + + "('gzip', 'snappy', 'lz4', 'zstd'). It additionally accepts 'uncompressed' which is equivalent to no compression; and " + + "'producer' which means retain the original compression codec set by the producer." + + /** ********* Kafka Metrics Configuration ***********/ + val MetricSampleWindowMsDoc = CommonClientConfigs.METRICS_SAMPLE_WINDOW_MS_DOC + val MetricNumSamplesDoc = CommonClientConfigs.METRICS_NUM_SAMPLES_DOC + val MetricReporterClassesDoc = CommonClientConfigs.METRIC_REPORTER_CLASSES_DOC + val MetricRecordingLevelDoc = CommonClientConfigs.METRICS_RECORDING_LEVEL_DOC + + + /** ********* Kafka Yammer Metrics Reporter Configuration ***********/ + val KafkaMetricsReporterClassesDoc = "A list of classes to use as Yammer metrics custom reporters." 
+ + " The reporters should implement kafka.metrics.KafkaMetricsReporter trait. If a client wants" + + " to expose JMX operations on a custom reporter, the custom reporter needs to additionally implement an MBean" + + " trait that extends kafka.metrics.KafkaMetricsReporterMBean trait so that the registered MBean is compliant with" + + " the standard MBean convention." + + val KafkaMetricsPollingIntervalSecondsDoc = s"The metrics polling interval (in seconds) which can be used" + + s" in $KafkaMetricsReporterClassesProp implementations." + + /** ******** Common Security Configuration *************/ + val PrincipalBuilderClassDoc = BrokerSecurityConfigs.PRINCIPAL_BUILDER_CLASS_DOC + val ConnectionsMaxReauthMsDoc = BrokerSecurityConfigs.CONNECTIONS_MAX_REAUTH_MS_DOC + val securityProviderClassDoc = SecurityConfig.SECURITY_PROVIDERS_DOC + + /** ********* SSL Configuration ****************/ + val SslProtocolDoc = SslConfigs.SSL_PROTOCOL_DOC + val SslProviderDoc = SslConfigs.SSL_PROVIDER_DOC + val SslCipherSuitesDoc = SslConfigs.SSL_CIPHER_SUITES_DOC + val SslEnabledProtocolsDoc = SslConfigs.SSL_ENABLED_PROTOCOLS_DOC + val SslKeystoreTypeDoc = SslConfigs.SSL_KEYSTORE_TYPE_DOC + val SslKeystoreLocationDoc = SslConfigs.SSL_KEYSTORE_LOCATION_DOC + val SslKeystorePasswordDoc = SslConfigs.SSL_KEYSTORE_PASSWORD_DOC + val SslKeyPasswordDoc = SslConfigs.SSL_KEY_PASSWORD_DOC + val SslKeystoreKeyDoc = SslConfigs.SSL_KEYSTORE_KEY_DOC + val SslKeystoreCertificateChainDoc = SslConfigs.SSL_KEYSTORE_CERTIFICATE_CHAIN_DOC + val SslTruststoreTypeDoc = SslConfigs.SSL_TRUSTSTORE_TYPE_DOC + val SslTruststorePasswordDoc = SslConfigs.SSL_TRUSTSTORE_PASSWORD_DOC + val SslTruststoreLocationDoc = SslConfigs.SSL_TRUSTSTORE_LOCATION_DOC + val SslTruststoreCertificatesDoc = SslConfigs.SSL_TRUSTSTORE_CERTIFICATES_DOC + val SslKeyManagerAlgorithmDoc = SslConfigs.SSL_KEYMANAGER_ALGORITHM_DOC + val SslTrustManagerAlgorithmDoc = SslConfigs.SSL_TRUSTMANAGER_ALGORITHM_DOC + val SslEndpointIdentificationAlgorithmDoc = SslConfigs.SSL_ENDPOINT_IDENTIFICATION_ALGORITHM_DOC + val SslSecureRandomImplementationDoc = SslConfigs.SSL_SECURE_RANDOM_IMPLEMENTATION_DOC + val SslClientAuthDoc = BrokerSecurityConfigs.SSL_CLIENT_AUTH_DOC + val SslPrincipalMappingRulesDoc = BrokerSecurityConfigs.SSL_PRINCIPAL_MAPPING_RULES_DOC + val SslEngineFactoryClassDoc = SslConfigs.SSL_ENGINE_FACTORY_CLASS_DOC + + /** ********* Sasl Configuration ****************/ + val SaslMechanismInterBrokerProtocolDoc = "SASL mechanism used for inter-broker communication. Default is GSSAPI." 
+ val SaslJaasConfigDoc = SaslConfigs.SASL_JAAS_CONFIG_DOC + val SaslEnabledMechanismsDoc = BrokerSecurityConfigs.SASL_ENABLED_MECHANISMS_DOC + val SaslServerCallbackHandlerClassDoc = BrokerSecurityConfigs.SASL_SERVER_CALLBACK_HANDLER_CLASS_DOC + val SaslClientCallbackHandlerClassDoc = SaslConfigs.SASL_CLIENT_CALLBACK_HANDLER_CLASS_DOC + val SaslLoginClassDoc = SaslConfigs.SASL_LOGIN_CLASS_DOC + val SaslLoginCallbackHandlerClassDoc = SaslConfigs.SASL_LOGIN_CALLBACK_HANDLER_CLASS_DOC + val SaslKerberosServiceNameDoc = SaslConfigs.SASL_KERBEROS_SERVICE_NAME_DOC + val SaslKerberosKinitCmdDoc = SaslConfigs.SASL_KERBEROS_KINIT_CMD_DOC + val SaslKerberosTicketRenewWindowFactorDoc = SaslConfigs.SASL_KERBEROS_TICKET_RENEW_WINDOW_FACTOR_DOC + val SaslKerberosTicketRenewJitterDoc = SaslConfigs.SASL_KERBEROS_TICKET_RENEW_JITTER_DOC + val SaslKerberosMinTimeBeforeReloginDoc = SaslConfigs.SASL_KERBEROS_MIN_TIME_BEFORE_RELOGIN_DOC + val SaslKerberosPrincipalToLocalRulesDoc = BrokerSecurityConfigs.SASL_KERBEROS_PRINCIPAL_TO_LOCAL_RULES_DOC + val SaslLoginRefreshWindowFactorDoc = SaslConfigs.SASL_LOGIN_REFRESH_WINDOW_FACTOR_DOC + val SaslLoginRefreshWindowJitterDoc = SaslConfigs.SASL_LOGIN_REFRESH_WINDOW_JITTER_DOC + val SaslLoginRefreshMinPeriodSecondsDoc = SaslConfigs.SASL_LOGIN_REFRESH_MIN_PERIOD_SECONDS_DOC + val SaslLoginRefreshBufferSecondsDoc = SaslConfigs.SASL_LOGIN_REFRESH_BUFFER_SECONDS_DOC + + val SaslLoginConnectTimeoutMsDoc = SaslConfigs.SASL_LOGIN_CONNECT_TIMEOUT_MS_DOC + val SaslLoginReadTimeoutMsDoc = SaslConfigs.SASL_LOGIN_READ_TIMEOUT_MS_DOC + val SaslLoginRetryBackoffMaxMsDoc = SaslConfigs.SASL_LOGIN_RETRY_BACKOFF_MAX_MS_DOC + val SaslLoginRetryBackoffMsDoc = SaslConfigs.SASL_LOGIN_RETRY_BACKOFF_MS_DOC + val SaslOAuthBearerScopeClaimNameDoc = SaslConfigs.SASL_OAUTHBEARER_SCOPE_CLAIM_NAME_DOC + val SaslOAuthBearerSubClaimNameDoc = SaslConfigs.SASL_OAUTHBEARER_SUB_CLAIM_NAME_DOC + val SaslOAuthBearerTokenEndpointUrlDoc = SaslConfigs.SASL_OAUTHBEARER_TOKEN_ENDPOINT_URL_DOC + val SaslOAuthBearerJwksEndpointUrlDoc = SaslConfigs.SASL_OAUTHBEARER_JWKS_ENDPOINT_URL_DOC + val SaslOAuthBearerJwksEndpointRefreshMsDoc = SaslConfigs.SASL_OAUTHBEARER_JWKS_ENDPOINT_REFRESH_MS_DOC + val SaslOAuthBearerJwksEndpointRetryBackoffMaxMsDoc = SaslConfigs.SASL_OAUTHBEARER_JWKS_ENDPOINT_RETRY_BACKOFF_MAX_MS_DOC + val SaslOAuthBearerJwksEndpointRetryBackoffMsDoc = SaslConfigs.SASL_OAUTHBEARER_JWKS_ENDPOINT_RETRY_BACKOFF_MS_DOC + val SaslOAuthBearerClockSkewSecondsDoc = SaslConfigs.SASL_OAUTHBEARER_CLOCK_SKEW_SECONDS_DOC + val SaslOAuthBearerExpectedAudienceDoc = SaslConfigs.SASL_OAUTHBEARER_EXPECTED_AUDIENCE_DOC + val SaslOAuthBearerExpectedIssuerDoc = SaslConfigs.SASL_OAUTHBEARER_EXPECTED_ISSUER_DOC + + /** ********* Delegation Token Configuration ****************/ + val DelegationTokenSecretKeyAliasDoc = s"DEPRECATED: An alias for $DelegationTokenSecretKeyProp, which should be used instead of this config." + val DelegationTokenSecretKeyDoc = "Secret key to generate and verify delegation tokens. The same key must be configured across all the brokers. " + + "If the key is not set or set to an empty string, brokers will disable delegation token support." + val DelegationTokenMaxLifeTimeDoc = "The token has a maximum lifetime beyond which it cannot be renewed anymore. Default value 7 days." + val DelegationTokenExpiryTimeMsDoc = "The token validity time in milliseconds before the token needs to be renewed. Default value 1 day." 
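+  // Illustrative sketch (not part of KafkaConfig): once delegation.token.secret.key is set to the
+  // same value on every broker, an authenticated client can request a token through the Admin API.
+  // The property values and bootstrap address below are placeholder assumptions.
+  //
+  //   # server.properties (same secret on all brokers)
+  //   delegation.token.secret.key=<shared-secret>
+  //   delegation.token.max.lifetime.ms=604800000   # 7 days
+  //   delegation.token.expiry.time.ms=86400000     # 1 day
+  //
+  //   import java.util.Properties
+  //   import org.apache.kafka.clients.admin.Admin
+  //
+  //   val props = new Properties()                      // must also carry SASL credentials; token
+  //   props.put("bootstrap.servers", "localhost:9092")  // requests need an authenticated connection
+  //   val admin = Admin.create(props)
+  //   val token = admin.createDelegationToken().delegationToken().get()
+  //   println(s"created token ${token.tokenInfo.tokenId}")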
+ val DelegationTokenExpiryCheckIntervalDoc = "Scan interval to remove expired delegation tokens." + + /** ********* Password encryption configuration for dynamic configs *********/ + val PasswordEncoderSecretDoc = "The secret used for encoding dynamically configured passwords for this broker." + val PasswordEncoderOldSecretDoc = "The old secret that was used for encoding dynamically configured passwords. " + + "This is required only when the secret is updated. If specified, all dynamically encoded passwords are " + + s"decoded using this old secret and re-encoded using $PasswordEncoderSecretProp when broker starts up." + val PasswordEncoderKeyFactoryAlgorithmDoc = "The SecretKeyFactory algorithm used for encoding dynamically configured passwords. " + + "Default is PBKDF2WithHmacSHA512 if available and PBKDF2WithHmacSHA1 otherwise." + val PasswordEncoderCipherAlgorithmDoc = "The Cipher algorithm used for encoding dynamically configured passwords." + val PasswordEncoderKeyLengthDoc = "The key length used for encoding dynamically configured passwords." + val PasswordEncoderIterationsDoc = "The iteration count used for encoding dynamically configured passwords." + + @nowarn("cat=deprecation") + private[server] val configDef = { + import ConfigDef.Importance._ + import ConfigDef.Range._ + import ConfigDef.Type._ + import ConfigDef.ValidString._ + + new ConfigDef() + + /** ********* Zookeeper Configuration ***********/ + .define(ZkConnectProp, STRING, null, HIGH, ZkConnectDoc) + .define(ZkSessionTimeoutMsProp, INT, Defaults.ZkSessionTimeoutMs, HIGH, ZkSessionTimeoutMsDoc) + .define(ZkConnectionTimeoutMsProp, INT, null, HIGH, ZkConnectionTimeoutMsDoc) + .define(ZkSyncTimeMsProp, INT, Defaults.ZkSyncTimeMs, LOW, ZkSyncTimeMsDoc) + .define(ZkEnableSecureAclsProp, BOOLEAN, Defaults.ZkEnableSecureAcls, HIGH, ZkEnableSecureAclsDoc) + .define(ZkMaxInFlightRequestsProp, INT, Defaults.ZkMaxInFlightRequests, atLeast(1), HIGH, ZkMaxInFlightRequestsDoc) + .define(ZkSslClientEnableProp, BOOLEAN, Defaults.ZkSslClientEnable, MEDIUM, ZkSslClientEnableDoc) + .define(ZkClientCnxnSocketProp, STRING, null, MEDIUM, ZkClientCnxnSocketDoc) + .define(ZkSslKeyStoreLocationProp, STRING, null, MEDIUM, ZkSslKeyStoreLocationDoc) + .define(ZkSslKeyStorePasswordProp, PASSWORD, null, MEDIUM, ZkSslKeyStorePasswordDoc) + .define(ZkSslKeyStoreTypeProp, STRING, null, MEDIUM, ZkSslKeyStoreTypeDoc) + .define(ZkSslTrustStoreLocationProp, STRING, null, MEDIUM, ZkSslTrustStoreLocationDoc) + .define(ZkSslTrustStorePasswordProp, PASSWORD, null, MEDIUM, ZkSslTrustStorePasswordDoc) + .define(ZkSslTrustStoreTypeProp, STRING, null, MEDIUM, ZkSslTrustStoreTypeDoc) + .define(ZkSslProtocolProp, STRING, Defaults.ZkSslProtocol, LOW, ZkSslProtocolDoc) + .define(ZkSslEnabledProtocolsProp, LIST, null, LOW, ZkSslEnabledProtocolsDoc) + .define(ZkSslCipherSuitesProp, LIST, null, LOW, ZkSslCipherSuitesDoc) + .define(ZkSslEndpointIdentificationAlgorithmProp, STRING, Defaults.ZkSslEndpointIdentificationAlgorithm, LOW, ZkSslEndpointIdentificationAlgorithmDoc) + .define(ZkSslCrlEnableProp, BOOLEAN, Defaults.ZkSslCrlEnable, LOW, ZkSslCrlEnableDoc) + .define(ZkSslOcspEnableProp, BOOLEAN, Defaults.ZkSslOcspEnable, LOW, ZkSslOcspEnableDoc) + + /** ********* General Configuration ***********/ + .define(BrokerIdGenerationEnableProp, BOOLEAN, Defaults.BrokerIdGenerationEnable, MEDIUM, BrokerIdGenerationEnableDoc) + .define(MaxReservedBrokerIdProp, INT, Defaults.MaxReservedBrokerId, atLeast(0), MEDIUM, MaxReservedBrokerIdDoc) + .define(BrokerIdProp, INT, 
Defaults.BrokerId, HIGH, BrokerIdDoc) + .define(MessageMaxBytesProp, INT, Defaults.MessageMaxBytes, atLeast(0), HIGH, MessageMaxBytesDoc) + .define(NumNetworkThreadsProp, INT, Defaults.NumNetworkThreads, atLeast(1), HIGH, NumNetworkThreadsDoc) + .define(NumIoThreadsProp, INT, Defaults.NumIoThreads, atLeast(1), HIGH, NumIoThreadsDoc) + .define(NumReplicaAlterLogDirsThreadsProp, INT, null, HIGH, NumReplicaAlterLogDirsThreadsDoc) + .define(BackgroundThreadsProp, INT, Defaults.BackgroundThreads, atLeast(1), HIGH, BackgroundThreadsDoc) + .define(QueuedMaxRequestsProp, INT, Defaults.QueuedMaxRequests, atLeast(1), HIGH, QueuedMaxRequestsDoc) + .define(QueuedMaxBytesProp, LONG, Defaults.QueuedMaxRequestBytes, MEDIUM, QueuedMaxRequestBytesDoc) + .define(RequestTimeoutMsProp, INT, Defaults.RequestTimeoutMs, HIGH, RequestTimeoutMsDoc) + .define(ConnectionSetupTimeoutMsProp, LONG, Defaults.ConnectionSetupTimeoutMs, MEDIUM, ConnectionSetupTimeoutMsDoc) + .define(ConnectionSetupTimeoutMaxMsProp, LONG, Defaults.ConnectionSetupTimeoutMaxMs, MEDIUM, ConnectionSetupTimeoutMaxMsDoc) + + /* + * KRaft mode configs. + */ + .define(MetadataSnapshotMaxNewRecordBytesProp, LONG, Defaults.MetadataSnapshotMaxNewRecordBytes, atLeast(1), HIGH, MetadataSnapshotMaxNewRecordBytesDoc) + + /* + * KRaft mode private configs. Note that these configs are defined as internal. We will make them public in the 3.0.0 release. + */ + .define(ProcessRolesProp, LIST, Collections.emptyList(), ValidList.in("broker", "controller"), HIGH, ProcessRolesDoc) + .define(NodeIdProp, INT, Defaults.EmptyNodeId, null, HIGH, NodeIdDoc) + .define(InitialBrokerRegistrationTimeoutMsProp, INT, Defaults.InitialBrokerRegistrationTimeoutMs, null, MEDIUM, InitialBrokerRegistrationTimeoutMsDoc) + .define(BrokerHeartbeatIntervalMsProp, INT, Defaults.BrokerHeartbeatIntervalMs, null, MEDIUM, BrokerHeartbeatIntervalMsDoc) + .define(BrokerSessionTimeoutMsProp, INT, Defaults.BrokerSessionTimeoutMs, null, MEDIUM, BrokerSessionTimeoutMsDoc) + .define(ControllerListenerNamesProp, STRING, null, null, HIGH, ControllerListenerNamesDoc) + .define(SaslMechanismControllerProtocolProp, STRING, SaslConfigs.DEFAULT_SASL_MECHANISM, null, HIGH, SaslMechanismControllerProtocolDoc) + .define(MetadataLogDirProp, STRING, null, null, HIGH, MetadataLogDirDoc) + .define(MetadataLogSegmentBytesProp, INT, Defaults.LogSegmentBytes, atLeast(Records.LOG_OVERHEAD), HIGH, MetadataLogSegmentBytesDoc) + .defineInternal(MetadataLogSegmentMinBytesProp, INT, 8 * 1024 * 1024, atLeast(Records.LOG_OVERHEAD), HIGH, MetadataLogSegmentMinBytesDoc) + .define(MetadataLogSegmentMillisProp, LONG, Defaults.LogRollHours * 60 * 60 * 1000L, null, HIGH, MetadataLogSegmentMillisDoc) + .define(MetadataMaxRetentionBytesProp, LONG, Defaults.LogRetentionBytes, null, HIGH, MetadataMaxRetentionBytesDoc) + .define(MetadataMaxRetentionMillisProp, LONG, Defaults.LogRetentionHours * 60 * 60 * 1000L, null, HIGH, MetadataMaxRetentionMillisDoc) + + /************* Authorizer Configuration ***********/ + .define(AuthorizerClassNameProp, STRING, Defaults.AuthorizerClassName, LOW, AuthorizerClassNameDoc) + + /** ********* Socket Server Configuration ***********/ + .define(ListenersProp, STRING, Defaults.Listeners, HIGH, ListenersDoc) + .define(AdvertisedListenersProp, STRING, null, HIGH, AdvertisedListenersDoc) + .define(ListenerSecurityProtocolMapProp, STRING, Defaults.ListenerSecurityProtocolMap, LOW, ListenerSecurityProtocolMapDoc) + .define(ControlPlaneListenerNameProp, STRING, null, HIGH, controlPlaneListenerNameDoc) + 
.define(SocketSendBufferBytesProp, INT, Defaults.SocketSendBufferBytes, HIGH, SocketSendBufferBytesDoc) + .define(SocketReceiveBufferBytesProp, INT, Defaults.SocketReceiveBufferBytes, HIGH, SocketReceiveBufferBytesDoc) + .define(SocketRequestMaxBytesProp, INT, Defaults.SocketRequestMaxBytes, atLeast(1), HIGH, SocketRequestMaxBytesDoc) + .define(MaxConnectionsPerIpProp, INT, Defaults.MaxConnectionsPerIp, atLeast(0), MEDIUM, MaxConnectionsPerIpDoc) + .define(MaxConnectionsPerIpOverridesProp, STRING, Defaults.MaxConnectionsPerIpOverrides, MEDIUM, MaxConnectionsPerIpOverridesDoc) + .define(MaxConnectionsProp, INT, Defaults.MaxConnections, atLeast(0), MEDIUM, MaxConnectionsDoc) + .define(MaxConnectionCreationRateProp, INT, Defaults.MaxConnectionCreationRate, atLeast(0), MEDIUM, MaxConnectionCreationRateDoc) + .define(ConnectionsMaxIdleMsProp, LONG, Defaults.ConnectionsMaxIdleMs, MEDIUM, ConnectionsMaxIdleMsDoc) + .define(FailedAuthenticationDelayMsProp, INT, Defaults.FailedAuthenticationDelayMs, atLeast(0), LOW, FailedAuthenticationDelayMsDoc) + + /************ Rack Configuration ******************/ + .define(RackProp, STRING, null, MEDIUM, RackDoc) + + /** ********* Log Configuration ***********/ + .define(NumPartitionsProp, INT, Defaults.NumPartitions, atLeast(1), MEDIUM, NumPartitionsDoc) + .define(LogDirProp, STRING, Defaults.LogDir, HIGH, LogDirDoc) + .define(LogDirsProp, STRING, null, HIGH, LogDirsDoc) + .define(LogSegmentBytesProp, INT, Defaults.LogSegmentBytes, atLeast(LegacyRecord.RECORD_OVERHEAD_V0), HIGH, LogSegmentBytesDoc) + + .define(LogRollTimeMillisProp, LONG, null, HIGH, LogRollTimeMillisDoc) + .define(LogRollTimeHoursProp, INT, Defaults.LogRollHours, atLeast(1), HIGH, LogRollTimeHoursDoc) + + .define(LogRollTimeJitterMillisProp, LONG, null, HIGH, LogRollTimeJitterMillisDoc) + .define(LogRollTimeJitterHoursProp, INT, Defaults.LogRollJitterHours, atLeast(0), HIGH, LogRollTimeJitterHoursDoc) + + .define(LogRetentionTimeMillisProp, LONG, null, HIGH, LogRetentionTimeMillisDoc) + .define(LogRetentionTimeMinutesProp, INT, null, HIGH, LogRetentionTimeMinsDoc) + .define(LogRetentionTimeHoursProp, INT, Defaults.LogRetentionHours, HIGH, LogRetentionTimeHoursDoc) + + .define(LogRetentionBytesProp, LONG, Defaults.LogRetentionBytes, HIGH, LogRetentionBytesDoc) + .define(LogCleanupIntervalMsProp, LONG, Defaults.LogCleanupIntervalMs, atLeast(1), MEDIUM, LogCleanupIntervalMsDoc) + .define(LogCleanupPolicyProp, LIST, Defaults.LogCleanupPolicy, ValidList.in(Defaults.Compact, Defaults.Delete), MEDIUM, LogCleanupPolicyDoc) + .define(LogCleanerThreadsProp, INT, Defaults.LogCleanerThreads, atLeast(0), MEDIUM, LogCleanerThreadsDoc) + .define(LogCleanerIoMaxBytesPerSecondProp, DOUBLE, Defaults.LogCleanerIoMaxBytesPerSecond, MEDIUM, LogCleanerIoMaxBytesPerSecondDoc) + .define(LogCleanerDedupeBufferSizeProp, LONG, Defaults.LogCleanerDedupeBufferSize, MEDIUM, LogCleanerDedupeBufferSizeDoc) + .define(LogCleanerIoBufferSizeProp, INT, Defaults.LogCleanerIoBufferSize, atLeast(0), MEDIUM, LogCleanerIoBufferSizeDoc) + .define(LogCleanerDedupeBufferLoadFactorProp, DOUBLE, Defaults.LogCleanerDedupeBufferLoadFactor, MEDIUM, LogCleanerDedupeBufferLoadFactorDoc) + .define(LogCleanerBackoffMsProp, LONG, Defaults.LogCleanerBackoffMs, atLeast(0), MEDIUM, LogCleanerBackoffMsDoc) + .define(LogCleanerMinCleanRatioProp, DOUBLE, Defaults.LogCleanerMinCleanRatio, MEDIUM, LogCleanerMinCleanRatioDoc) + .define(LogCleanerEnableProp, BOOLEAN, Defaults.LogCleanerEnable, MEDIUM, LogCleanerEnableDoc) + 
.define(LogCleanerDeleteRetentionMsProp, LONG, Defaults.LogCleanerDeleteRetentionMs, MEDIUM, LogCleanerDeleteRetentionMsDoc) + .define(LogCleanerMinCompactionLagMsProp, LONG, Defaults.LogCleanerMinCompactionLagMs, MEDIUM, LogCleanerMinCompactionLagMsDoc) + .define(LogCleanerMaxCompactionLagMsProp, LONG, Defaults.LogCleanerMaxCompactionLagMs, MEDIUM, LogCleanerMaxCompactionLagMsDoc) + .define(LogIndexSizeMaxBytesProp, INT, Defaults.LogIndexSizeMaxBytes, atLeast(4), MEDIUM, LogIndexSizeMaxBytesDoc) + .define(LogIndexIntervalBytesProp, INT, Defaults.LogIndexIntervalBytes, atLeast(0), MEDIUM, LogIndexIntervalBytesDoc) + .define(LogFlushIntervalMessagesProp, LONG, Defaults.LogFlushIntervalMessages, atLeast(1), HIGH, LogFlushIntervalMessagesDoc) + .define(LogDeleteDelayMsProp, LONG, Defaults.LogDeleteDelayMs, atLeast(0), HIGH, LogDeleteDelayMsDoc) + .define(LogFlushSchedulerIntervalMsProp, LONG, Defaults.LogFlushSchedulerIntervalMs, HIGH, LogFlushSchedulerIntervalMsDoc) + .define(LogFlushIntervalMsProp, LONG, null, HIGH, LogFlushIntervalMsDoc) + .define(LogFlushOffsetCheckpointIntervalMsProp, INT, Defaults.LogFlushOffsetCheckpointIntervalMs, atLeast(0), HIGH, LogFlushOffsetCheckpointIntervalMsDoc) + .define(LogFlushStartOffsetCheckpointIntervalMsProp, INT, Defaults.LogFlushStartOffsetCheckpointIntervalMs, atLeast(0), HIGH, LogFlushStartOffsetCheckpointIntervalMsDoc) + .define(LogPreAllocateProp, BOOLEAN, Defaults.LogPreAllocateEnable, MEDIUM, LogPreAllocateEnableDoc) + .define(NumRecoveryThreadsPerDataDirProp, INT, Defaults.NumRecoveryThreadsPerDataDir, atLeast(1), HIGH, NumRecoveryThreadsPerDataDirDoc) + .define(AutoCreateTopicsEnableProp, BOOLEAN, Defaults.AutoCreateTopicsEnable, HIGH, AutoCreateTopicsEnableDoc) + .define(MinInSyncReplicasProp, INT, Defaults.MinInSyncReplicas, atLeast(1), HIGH, MinInSyncReplicasDoc) + .define(LogMessageFormatVersionProp, STRING, Defaults.LogMessageFormatVersion, ApiVersionValidator, MEDIUM, LogMessageFormatVersionDoc) + .define(LogMessageTimestampTypeProp, STRING, Defaults.LogMessageTimestampType, in("CreateTime", "LogAppendTime"), MEDIUM, LogMessageTimestampTypeDoc) + .define(LogMessageTimestampDifferenceMaxMsProp, LONG, Defaults.LogMessageTimestampDifferenceMaxMs, MEDIUM, LogMessageTimestampDifferenceMaxMsDoc) + .define(CreateTopicPolicyClassNameProp, CLASS, null, LOW, CreateTopicPolicyClassNameDoc) + .define(AlterConfigPolicyClassNameProp, CLASS, null, LOW, AlterConfigPolicyClassNameDoc) + .define(LogMessageDownConversionEnableProp, BOOLEAN, Defaults.MessageDownConversionEnable, LOW, LogMessageDownConversionEnableDoc) + + /** ********* Replication configuration ***********/ + .define(ControllerSocketTimeoutMsProp, INT, Defaults.ControllerSocketTimeoutMs, MEDIUM, ControllerSocketTimeoutMsDoc) + .define(DefaultReplicationFactorProp, INT, Defaults.DefaultReplicationFactor, MEDIUM, DefaultReplicationFactorDoc) + .define(ReplicaLagTimeMaxMsProp, LONG, Defaults.ReplicaLagTimeMaxMs, HIGH, ReplicaLagTimeMaxMsDoc) + .define(ReplicaSocketTimeoutMsProp, INT, Defaults.ReplicaSocketTimeoutMs, HIGH, ReplicaSocketTimeoutMsDoc) + .define(ReplicaSocketReceiveBufferBytesProp, INT, Defaults.ReplicaSocketReceiveBufferBytes, HIGH, ReplicaSocketReceiveBufferBytesDoc) + .define(ReplicaFetchMaxBytesProp, INT, Defaults.ReplicaFetchMaxBytes, atLeast(0), MEDIUM, ReplicaFetchMaxBytesDoc) + .define(ReplicaFetchWaitMaxMsProp, INT, Defaults.ReplicaFetchWaitMaxMs, HIGH, ReplicaFetchWaitMaxMsDoc) + .define(ReplicaFetchBackoffMsProp, INT, Defaults.ReplicaFetchBackoffMs, atLeast(0), MEDIUM, 
ReplicaFetchBackoffMsDoc) + .define(ReplicaFetchMinBytesProp, INT, Defaults.ReplicaFetchMinBytes, HIGH, ReplicaFetchMinBytesDoc) + .define(ReplicaFetchResponseMaxBytesProp, INT, Defaults.ReplicaFetchResponseMaxBytes, atLeast(0), MEDIUM, ReplicaFetchResponseMaxBytesDoc) + .define(NumReplicaFetchersProp, INT, Defaults.NumReplicaFetchers, HIGH, NumReplicaFetchersDoc) + .define(ReplicaHighWatermarkCheckpointIntervalMsProp, LONG, Defaults.ReplicaHighWatermarkCheckpointIntervalMs, HIGH, ReplicaHighWatermarkCheckpointIntervalMsDoc) + .define(FetchPurgatoryPurgeIntervalRequestsProp, INT, Defaults.FetchPurgatoryPurgeIntervalRequests, MEDIUM, FetchPurgatoryPurgeIntervalRequestsDoc) + .define(ProducerPurgatoryPurgeIntervalRequestsProp, INT, Defaults.ProducerPurgatoryPurgeIntervalRequests, MEDIUM, ProducerPurgatoryPurgeIntervalRequestsDoc) + .define(DeleteRecordsPurgatoryPurgeIntervalRequestsProp, INT, Defaults.DeleteRecordsPurgatoryPurgeIntervalRequests, MEDIUM, DeleteRecordsPurgatoryPurgeIntervalRequestsDoc) + .define(AutoLeaderRebalanceEnableProp, BOOLEAN, Defaults.AutoLeaderRebalanceEnable, HIGH, AutoLeaderRebalanceEnableDoc) + .define(LeaderImbalancePerBrokerPercentageProp, INT, Defaults.LeaderImbalancePerBrokerPercentage, HIGH, LeaderImbalancePerBrokerPercentageDoc) + .define(LeaderImbalanceCheckIntervalSecondsProp, LONG, Defaults.LeaderImbalanceCheckIntervalSeconds, HIGH, LeaderImbalanceCheckIntervalSecondsDoc) + .define(UncleanLeaderElectionEnableProp, BOOLEAN, Defaults.UncleanLeaderElectionEnable, HIGH, UncleanLeaderElectionEnableDoc) + .define(InterBrokerSecurityProtocolProp, STRING, Defaults.InterBrokerSecurityProtocol, MEDIUM, InterBrokerSecurityProtocolDoc) + .define(InterBrokerProtocolVersionProp, STRING, Defaults.InterBrokerProtocolVersion, ApiVersionValidator, MEDIUM, InterBrokerProtocolVersionDoc) + .define(InterBrokerListenerNameProp, STRING, null, MEDIUM, InterBrokerListenerNameDoc) + .define(ReplicaSelectorClassProp, STRING, null, MEDIUM, ReplicaSelectorClassDoc) + + /** ********* Controlled shutdown configuration ***********/ + .define(ControlledShutdownMaxRetriesProp, INT, Defaults.ControlledShutdownMaxRetries, MEDIUM, ControlledShutdownMaxRetriesDoc) + .define(ControlledShutdownRetryBackoffMsProp, LONG, Defaults.ControlledShutdownRetryBackoffMs, MEDIUM, ControlledShutdownRetryBackoffMsDoc) + .define(ControlledShutdownEnableProp, BOOLEAN, Defaults.ControlledShutdownEnable, MEDIUM, ControlledShutdownEnableDoc) + + /** ********* Group coordinator configuration ***********/ + .define(GroupMinSessionTimeoutMsProp, INT, Defaults.GroupMinSessionTimeoutMs, MEDIUM, GroupMinSessionTimeoutMsDoc) + .define(GroupMaxSessionTimeoutMsProp, INT, Defaults.GroupMaxSessionTimeoutMs, MEDIUM, GroupMaxSessionTimeoutMsDoc) + .define(GroupInitialRebalanceDelayMsProp, INT, Defaults.GroupInitialRebalanceDelayMs, MEDIUM, GroupInitialRebalanceDelayMsDoc) + .define(GroupMaxSizeProp, INT, Defaults.GroupMaxSize, atLeast(1), MEDIUM, GroupMaxSizeDoc) + + /** ********* Offset management configuration ***********/ + .define(OffsetMetadataMaxSizeProp, INT, Defaults.OffsetMetadataMaxSize, HIGH, OffsetMetadataMaxSizeDoc) + .define(OffsetsLoadBufferSizeProp, INT, Defaults.OffsetsLoadBufferSize, atLeast(1), HIGH, OffsetsLoadBufferSizeDoc) + .define(OffsetsTopicReplicationFactorProp, SHORT, Defaults.OffsetsTopicReplicationFactor, atLeast(1), HIGH, OffsetsTopicReplicationFactorDoc) + .define(OffsetsTopicPartitionsProp, INT, Defaults.OffsetsTopicPartitions, atLeast(1), HIGH, OffsetsTopicPartitionsDoc) + 
.define(OffsetsTopicSegmentBytesProp, INT, Defaults.OffsetsTopicSegmentBytes, atLeast(1), HIGH, OffsetsTopicSegmentBytesDoc) + .define(OffsetsTopicCompressionCodecProp, INT, Defaults.OffsetsTopicCompressionCodec, HIGH, OffsetsTopicCompressionCodecDoc) + .define(OffsetsRetentionMinutesProp, INT, Defaults.OffsetsRetentionMinutes, atLeast(1), HIGH, OffsetsRetentionMinutesDoc) + .define(OffsetsRetentionCheckIntervalMsProp, LONG, Defaults.OffsetsRetentionCheckIntervalMs, atLeast(1), HIGH, OffsetsRetentionCheckIntervalMsDoc) + .define(OffsetCommitTimeoutMsProp, INT, Defaults.OffsetCommitTimeoutMs, atLeast(1), HIGH, OffsetCommitTimeoutMsDoc) + .define(OffsetCommitRequiredAcksProp, SHORT, Defaults.OffsetCommitRequiredAcks, HIGH, OffsetCommitRequiredAcksDoc) + .define(DeleteTopicEnableProp, BOOLEAN, Defaults.DeleteTopicEnable, HIGH, DeleteTopicEnableDoc) + .define(CompressionTypeProp, STRING, Defaults.CompressionType, HIGH, CompressionTypeDoc) + + /** ********* Transaction management configuration ***********/ + .define(TransactionalIdExpirationMsProp, INT, Defaults.TransactionalIdExpirationMs, atLeast(1), HIGH, TransactionalIdExpirationMsDoc) + .define(TransactionsMaxTimeoutMsProp, INT, Defaults.TransactionsMaxTimeoutMs, atLeast(1), HIGH, TransactionsMaxTimeoutMsDoc) + .define(TransactionsTopicMinISRProp, INT, Defaults.TransactionsTopicMinISR, atLeast(1), HIGH, TransactionsTopicMinISRDoc) + .define(TransactionsLoadBufferSizeProp, INT, Defaults.TransactionsLoadBufferSize, atLeast(1), HIGH, TransactionsLoadBufferSizeDoc) + .define(TransactionsTopicReplicationFactorProp, SHORT, Defaults.TransactionsTopicReplicationFactor, atLeast(1), HIGH, TransactionsTopicReplicationFactorDoc) + .define(TransactionsTopicPartitionsProp, INT, Defaults.TransactionsTopicPartitions, atLeast(1), HIGH, TransactionsTopicPartitionsDoc) + .define(TransactionsTopicSegmentBytesProp, INT, Defaults.TransactionsTopicSegmentBytes, atLeast(1), HIGH, TransactionsTopicSegmentBytesDoc) + .define(TransactionsAbortTimedOutTransactionCleanupIntervalMsProp, INT, Defaults.TransactionsAbortTimedOutTransactionsCleanupIntervalMS, atLeast(1), LOW, TransactionsAbortTimedOutTransactionsIntervalMsDoc) + .define(TransactionsRemoveExpiredTransactionalIdCleanupIntervalMsProp, INT, Defaults.TransactionsRemoveExpiredTransactionsCleanupIntervalMS, atLeast(1), LOW, TransactionsRemoveExpiredTransactionsIntervalMsDoc) + + /** ********* Fetch Configuration **************/ + .define(MaxIncrementalFetchSessionCacheSlots, INT, Defaults.MaxIncrementalFetchSessionCacheSlots, atLeast(0), MEDIUM, MaxIncrementalFetchSessionCacheSlotsDoc) + .define(FetchMaxBytes, INT, Defaults.FetchMaxBytes, atLeast(1024), MEDIUM, FetchMaxBytesDoc) + + /** ********* Kafka Metrics Configuration ***********/ + .define(MetricNumSamplesProp, INT, Defaults.MetricNumSamples, atLeast(1), LOW, MetricNumSamplesDoc) + .define(MetricSampleWindowMsProp, LONG, Defaults.MetricSampleWindowMs, atLeast(1), LOW, MetricSampleWindowMsDoc) + .define(MetricReporterClassesProp, LIST, Defaults.MetricReporterClasses, LOW, MetricReporterClassesDoc) + .define(MetricRecordingLevelProp, STRING, Defaults.MetricRecordingLevel, LOW, MetricRecordingLevelDoc) + + /** ********* Kafka Yammer Metrics Reporter Configuration for docs ***********/ + .define(KafkaMetricsReporterClassesProp, LIST, Defaults.KafkaMetricReporterClasses, LOW, KafkaMetricsReporterClassesDoc) + .define(KafkaMetricsPollingIntervalSecondsProp, INT, Defaults.KafkaMetricsPollingIntervalSeconds, atLeast(1), LOW, KafkaMetricsPollingIntervalSecondsDoc) 
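+ // Illustrative sketch only: the ConfigDef pattern used throughout this chain, shown standalone.
+ // "my.sample.ms" is a made-up key used purely to demonstrate the type, default, validator,
+ // importance and documentation arguments; it is not a real broker config.
+ //
+ //   import org.apache.kafka.common.config.ConfigDef
+ //   import org.apache.kafka.common.config.ConfigDef.{Importance, Range, Type}
+ //
+ //   val sampleDef = new ConfigDef()
+ //     .define("my.sample.ms", Type.LONG, 1000L, Range.atLeast(1), Importance.LOW, "A sample doc")
+ //   // parse() applies defaults and validators, throwing ConfigException on invalid values.
+ //   val parsed = sampleDef.parse(java.util.Collections.singletonMap("my.sample.ms", "2000"))
+ //   assert(parsed.get("my.sample.ms") == 2000L)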
+ + /** ********* Quota configuration ***********/ + .define(NumQuotaSamplesProp, INT, Defaults.NumQuotaSamples, atLeast(1), LOW, NumQuotaSamplesDoc) + .define(NumReplicationQuotaSamplesProp, INT, Defaults.NumReplicationQuotaSamples, atLeast(1), LOW, NumReplicationQuotaSamplesDoc) + .define(NumAlterLogDirsReplicationQuotaSamplesProp, INT, Defaults.NumAlterLogDirsReplicationQuotaSamples, atLeast(1), LOW, NumAlterLogDirsReplicationQuotaSamplesDoc) + .define(NumControllerQuotaSamplesProp, INT, Defaults.NumControllerQuotaSamples, atLeast(1), LOW, NumControllerQuotaSamplesDoc) + .define(QuotaWindowSizeSecondsProp, INT, Defaults.QuotaWindowSizeSeconds, atLeast(1), LOW, QuotaWindowSizeSecondsDoc) + .define(ReplicationQuotaWindowSizeSecondsProp, INT, Defaults.ReplicationQuotaWindowSizeSeconds, atLeast(1), LOW, ReplicationQuotaWindowSizeSecondsDoc) + .define(AlterLogDirsReplicationQuotaWindowSizeSecondsProp, INT, Defaults.AlterLogDirsReplicationQuotaWindowSizeSeconds, atLeast(1), LOW, AlterLogDirsReplicationQuotaWindowSizeSecondsDoc) + .define(ControllerQuotaWindowSizeSecondsProp, INT, Defaults.ControllerQuotaWindowSizeSeconds, atLeast(1), LOW, ControllerQuotaWindowSizeSecondsDoc) + .define(ClientQuotaCallbackClassProp, CLASS, null, LOW, ClientQuotaCallbackClassDoc) + + /** ********* General Security Configuration ****************/ + .define(ConnectionsMaxReauthMsProp, LONG, Defaults.ConnectionsMaxReauthMsDefault, MEDIUM, ConnectionsMaxReauthMsDoc) + .define(securityProviderClassProp, STRING, null, LOW, securityProviderClassDoc) + + /** ********* SSL Configuration ****************/ + .define(PrincipalBuilderClassProp, CLASS, Defaults.DefaultPrincipalSerde, MEDIUM, PrincipalBuilderClassDoc) + .define(SslProtocolProp, STRING, Defaults.SslProtocol, MEDIUM, SslProtocolDoc) + .define(SslProviderProp, STRING, null, MEDIUM, SslProviderDoc) + .define(SslEnabledProtocolsProp, LIST, Defaults.SslEnabledProtocols, MEDIUM, SslEnabledProtocolsDoc) + .define(SslKeystoreTypeProp, STRING, Defaults.SslKeystoreType, MEDIUM, SslKeystoreTypeDoc) + .define(SslKeystoreLocationProp, STRING, null, MEDIUM, SslKeystoreLocationDoc) + .define(SslKeystorePasswordProp, PASSWORD, null, MEDIUM, SslKeystorePasswordDoc) + .define(SslKeyPasswordProp, PASSWORD, null, MEDIUM, SslKeyPasswordDoc) + .define(SslKeystoreKeyProp, PASSWORD, null, MEDIUM, SslKeystoreKeyDoc) + .define(SslKeystoreCertificateChainProp, PASSWORD, null, MEDIUM, SslKeystoreCertificateChainDoc) + .define(SslTruststoreTypeProp, STRING, Defaults.SslTruststoreType, MEDIUM, SslTruststoreTypeDoc) + .define(SslTruststoreLocationProp, STRING, null, MEDIUM, SslTruststoreLocationDoc) + .define(SslTruststorePasswordProp, PASSWORD, null, MEDIUM, SslTruststorePasswordDoc) + .define(SslTruststoreCertificatesProp, PASSWORD, null, MEDIUM, SslTruststoreCertificatesDoc) + .define(SslKeyManagerAlgorithmProp, STRING, Defaults.SslKeyManagerAlgorithm, MEDIUM, SslKeyManagerAlgorithmDoc) + .define(SslTrustManagerAlgorithmProp, STRING, Defaults.SslTrustManagerAlgorithm, MEDIUM, SslTrustManagerAlgorithmDoc) + .define(SslEndpointIdentificationAlgorithmProp, STRING, Defaults.SslEndpointIdentificationAlgorithm, LOW, SslEndpointIdentificationAlgorithmDoc) + .define(SslSecureRandomImplementationProp, STRING, null, LOW, SslSecureRandomImplementationDoc) + .define(SslClientAuthProp, STRING, Defaults.SslClientAuthentication, in(Defaults.SslClientAuthenticationValidValues:_*), MEDIUM, SslClientAuthDoc) + .define(SslCipherSuitesProp, LIST, Collections.emptyList(), MEDIUM, SslCipherSuitesDoc) + 
.define(SslPrincipalMappingRulesProp, STRING, Defaults.SslPrincipalMappingRules, LOW, SslPrincipalMappingRulesDoc) + .define(SslEngineFactoryClassProp, CLASS, null, LOW, SslEngineFactoryClassDoc) + + /** ********* Sasl Configuration ****************/ + .define(SaslMechanismInterBrokerProtocolProp, STRING, Defaults.SaslMechanismInterBrokerProtocol, MEDIUM, SaslMechanismInterBrokerProtocolDoc) + .define(SaslJaasConfigProp, PASSWORD, null, MEDIUM, SaslJaasConfigDoc) + .define(SaslEnabledMechanismsProp, LIST, Defaults.SaslEnabledMechanisms, MEDIUM, SaslEnabledMechanismsDoc) + .define(SaslServerCallbackHandlerClassProp, CLASS, null, MEDIUM, SaslServerCallbackHandlerClassDoc) + .define(SaslClientCallbackHandlerClassProp, CLASS, null, MEDIUM, SaslClientCallbackHandlerClassDoc) + .define(SaslLoginClassProp, CLASS, null, MEDIUM, SaslLoginClassDoc) + .define(SaslLoginCallbackHandlerClassProp, CLASS, null, MEDIUM, SaslLoginCallbackHandlerClassDoc) + .define(SaslKerberosServiceNameProp, STRING, null, MEDIUM, SaslKerberosServiceNameDoc) + .define(SaslKerberosKinitCmdProp, STRING, Defaults.SaslKerberosKinitCmd, MEDIUM, SaslKerberosKinitCmdDoc) + .define(SaslKerberosTicketRenewWindowFactorProp, DOUBLE, Defaults.SaslKerberosTicketRenewWindowFactor, MEDIUM, SaslKerberosTicketRenewWindowFactorDoc) + .define(SaslKerberosTicketRenewJitterProp, DOUBLE, Defaults.SaslKerberosTicketRenewJitter, MEDIUM, SaslKerberosTicketRenewJitterDoc) + .define(SaslKerberosMinTimeBeforeReloginProp, LONG, Defaults.SaslKerberosMinTimeBeforeRelogin, MEDIUM, SaslKerberosMinTimeBeforeReloginDoc) + .define(SaslKerberosPrincipalToLocalRulesProp, LIST, Defaults.SaslKerberosPrincipalToLocalRules, MEDIUM, SaslKerberosPrincipalToLocalRulesDoc) + .define(SaslLoginRefreshWindowFactorProp, DOUBLE, Defaults.SaslLoginRefreshWindowFactor, MEDIUM, SaslLoginRefreshWindowFactorDoc) + .define(SaslLoginRefreshWindowJitterProp, DOUBLE, Defaults.SaslLoginRefreshWindowJitter, MEDIUM, SaslLoginRefreshWindowJitterDoc) + .define(SaslLoginRefreshMinPeriodSecondsProp, SHORT, Defaults.SaslLoginRefreshMinPeriodSeconds, MEDIUM, SaslLoginRefreshMinPeriodSecondsDoc) + .define(SaslLoginRefreshBufferSecondsProp, SHORT, Defaults.SaslLoginRefreshBufferSeconds, MEDIUM, SaslLoginRefreshBufferSecondsDoc) + .define(SaslLoginConnectTimeoutMsProp, INT, null, LOW, SaslLoginConnectTimeoutMsDoc) + .define(SaslLoginReadTimeoutMsProp, INT, null, LOW, SaslLoginReadTimeoutMsDoc) + .define(SaslLoginRetryBackoffMaxMsProp, LONG, Defaults.SaslLoginRetryBackoffMaxMs, LOW, SaslLoginRetryBackoffMaxMsDoc) + .define(SaslLoginRetryBackoffMsProp, LONG, Defaults.SaslLoginRetryBackoffMs, LOW, SaslLoginRetryBackoffMsDoc) + .define(SaslOAuthBearerScopeClaimNameProp, STRING, Defaults.SaslOAuthBearerScopeClaimName, LOW, SaslOAuthBearerScopeClaimNameDoc) + .define(SaslOAuthBearerSubClaimNameProp, STRING, Defaults.SaslOAuthBearerSubClaimName, LOW, SaslOAuthBearerSubClaimNameDoc) + .define(SaslOAuthBearerTokenEndpointUrlProp, STRING, null, MEDIUM, SaslOAuthBearerTokenEndpointUrlDoc) + .define(SaslOAuthBearerJwksEndpointUrlProp, STRING, null, MEDIUM, SaslOAuthBearerJwksEndpointUrlDoc) + .define(SaslOAuthBearerJwksEndpointRefreshMsProp, LONG, Defaults.SaslOAuthBearerJwksEndpointRefreshMs, LOW, SaslOAuthBearerJwksEndpointRefreshMsDoc) + .define(SaslOAuthBearerJwksEndpointRetryBackoffMsProp, LONG, Defaults.SaslOAuthBearerJwksEndpointRetryBackoffMs, LOW, SaslOAuthBearerJwksEndpointRetryBackoffMsDoc) + .define(SaslOAuthBearerJwksEndpointRetryBackoffMaxMsProp, LONG, 
Defaults.SaslOAuthBearerJwksEndpointRetryBackoffMaxMs, LOW, SaslOAuthBearerJwksEndpointRetryBackoffMaxMsDoc) + .define(SaslOAuthBearerClockSkewSecondsProp, INT, Defaults.SaslOAuthBearerClockSkewSeconds, LOW, SaslOAuthBearerClockSkewSecondsDoc) + .define(SaslOAuthBearerExpectedAudienceProp, LIST, null, LOW, SaslOAuthBearerExpectedAudienceDoc) + .define(SaslOAuthBearerExpectedIssuerProp, STRING, null, LOW, SaslOAuthBearerExpectedIssuerDoc) + + /** ********* Delegation Token Configuration ****************/ + .define(DelegationTokenSecretKeyAliasProp, PASSWORD, null, MEDIUM, DelegationTokenSecretKeyAliasDoc) + .define(DelegationTokenSecretKeyProp, PASSWORD, null, MEDIUM, DelegationTokenSecretKeyDoc) + .define(DelegationTokenMaxLifeTimeProp, LONG, Defaults.DelegationTokenMaxLifeTimeMsDefault, atLeast(1), MEDIUM, DelegationTokenMaxLifeTimeDoc) + .define(DelegationTokenExpiryTimeMsProp, LONG, Defaults.DelegationTokenExpiryTimeMsDefault, atLeast(1), MEDIUM, DelegationTokenExpiryTimeMsDoc) + .define(DelegationTokenExpiryCheckIntervalMsProp, LONG, Defaults.DelegationTokenExpiryCheckIntervalMsDefault, atLeast(1), LOW, DelegationTokenExpiryCheckIntervalDoc) + + /** ********* Password encryption configuration for dynamic configs *********/ + .define(PasswordEncoderSecretProp, PASSWORD, null, MEDIUM, PasswordEncoderSecretDoc) + .define(PasswordEncoderOldSecretProp, PASSWORD, null, MEDIUM, PasswordEncoderOldSecretDoc) + .define(PasswordEncoderKeyFactoryAlgorithmProp, STRING, null, LOW, PasswordEncoderKeyFactoryAlgorithmDoc) + .define(PasswordEncoderCipherAlgorithmProp, STRING, Defaults.PasswordEncoderCipherAlgorithm, LOW, PasswordEncoderCipherAlgorithmDoc) + .define(PasswordEncoderKeyLengthProp, INT, Defaults.PasswordEncoderKeyLength, atLeast(8), LOW, PasswordEncoderKeyLengthDoc) + .define(PasswordEncoderIterationsProp, INT, Defaults.PasswordEncoderIterations, atLeast(1024), LOW, PasswordEncoderIterationsDoc) + + /** ********* Raft Quorum Configuration *********/ + .define(RaftConfig.QUORUM_VOTERS_CONFIG, LIST, Defaults.QuorumVoters, new RaftConfig.ControllerQuorumVotersValidator(), HIGH, RaftConfig.QUORUM_VOTERS_DOC) + .define(RaftConfig.QUORUM_ELECTION_TIMEOUT_MS_CONFIG, INT, Defaults.QuorumElectionTimeoutMs, null, HIGH, RaftConfig.QUORUM_ELECTION_TIMEOUT_MS_DOC) + .define(RaftConfig.QUORUM_FETCH_TIMEOUT_MS_CONFIG, INT, Defaults.QuorumFetchTimeoutMs, null, HIGH, RaftConfig.QUORUM_FETCH_TIMEOUT_MS_DOC) + .define(RaftConfig.QUORUM_ELECTION_BACKOFF_MAX_MS_CONFIG, INT, Defaults.QuorumElectionBackoffMs, null, HIGH, RaftConfig.QUORUM_ELECTION_BACKOFF_MAX_MS_DOC) + .define(RaftConfig.QUORUM_LINGER_MS_CONFIG, INT, Defaults.QuorumLingerMs, null, MEDIUM, RaftConfig.QUORUM_LINGER_MS_DOC) + .define(RaftConfig.QUORUM_REQUEST_TIMEOUT_MS_CONFIG, INT, Defaults.QuorumRequestTimeoutMs, null, MEDIUM, RaftConfig.QUORUM_REQUEST_TIMEOUT_MS_DOC) + .define(RaftConfig.QUORUM_RETRY_BACKOFF_MS_CONFIG, INT, Defaults.QuorumRetryBackoffMs, null, LOW, RaftConfig.QUORUM_RETRY_BACKOFF_MS_DOC) + } + + /** ********* Remote Log Management Configuration *********/ + RemoteLogManagerConfig.CONFIG_DEF.configKeys().values().forEach(key => configDef.define(key)) + + def configNames: Seq[String] = configDef.names.asScala.toBuffer.sorted + private[server] def defaultValues: Map[String, _] = configDef.defaultValues.asScala + private[server] def configKeys: Map[String, ConfigKey] = configDef.configKeys.asScala + + def fromProps(props: Properties): KafkaConfig = + fromProps(props, true) + + def fromProps(props: Properties, doLog: Boolean): 
KafkaConfig = + new KafkaConfig(props, doLog) + + def fromProps(defaults: Properties, overrides: Properties): KafkaConfig = + fromProps(defaults, overrides, true) + + def fromProps(defaults: Properties, overrides: Properties, doLog: Boolean): KafkaConfig = { + val props = new Properties() + props ++= defaults + props ++= overrides + fromProps(props, doLog) + } + + def apply(props: java.util.Map[_, _], doLog: Boolean = true): KafkaConfig = new KafkaConfig(props, doLog) + + private def typeOf(name: String): Option[ConfigDef.Type] = Option(configDef.configKeys.get(name)).map(_.`type`) + + def configType(configName: String): Option[ConfigDef.Type] = { + val configType = configTypeExact(configName) + if (configType.isDefined) { + return configType + } + typeOf(configName) match { + case Some(t) => Some(t) + case None => + DynamicBrokerConfig.brokerConfigSynonyms(configName, matchListenerOverride = true).flatMap(typeOf).headOption + } + } + + private def configTypeExact(exactName: String): Option[ConfigDef.Type] = { + val configType = typeOf(exactName).orNull + if (configType != null) { + Some(configType) + } else { + val configKey = DynamicConfig.Broker.brokerConfigDef.configKeys().get(exactName) + if (configKey != null) { + Some(configKey.`type`) + } else { + None + } + } + } + + def maybeSensitive(configType: Option[ConfigDef.Type]): Boolean = { + // If we can't determine the config entry type, treat it as a sensitive config to be safe + configType.isEmpty || configType.contains(ConfigDef.Type.PASSWORD) + } + + def loggableValue(resourceType: ConfigResource.Type, name: String, value: String): String = { + val maybeSensitive = resourceType match { + case ConfigResource.Type.BROKER => KafkaConfig.maybeSensitive(KafkaConfig.configType(name)) + case ConfigResource.Type.TOPIC => KafkaConfig.maybeSensitive(LogConfig.configType(name)) + case ConfigResource.Type.BROKER_LOGGER => false + case _ => true + } + if (maybeSensitive) Password.HIDDEN else value + } + + /** + * Copy a configuration map, populating some keys that we want to treat as synonyms. 
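+ * For example, if the input contains only `node.id=1`, the returned copy also carries
+ * `broker.id=1` (and vice versa); when both keys are already present they are left untouched.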
+ */ + def populateSynonyms(input: util.Map[_, _]): util.Map[Any, Any] = { + val output = new util.HashMap[Any, Any](input) + val brokerId = output.get(KafkaConfig.BrokerIdProp) + val nodeId = output.get(KafkaConfig.NodeIdProp) + if (brokerId == null && nodeId != null) { + output.put(KafkaConfig.BrokerIdProp, nodeId) + } else if (brokerId != null && nodeId == null) { + output.put(KafkaConfig.NodeIdProp, brokerId) + } + output + } +} + +class KafkaConfig private(doLog: Boolean, val props: java.util.Map[_, _], dynamicConfigOverride: Option[DynamicBrokerConfig]) + extends AbstractConfig(KafkaConfig.configDef, props, doLog) with Logging { + + def this(props: java.util.Map[_, _]) = this(true, KafkaConfig.populateSynonyms(props), None) + def this(props: java.util.Map[_, _], doLog: Boolean) = this(doLog, KafkaConfig.populateSynonyms(props), None) + def this(props: java.util.Map[_, _], doLog: Boolean, dynamicConfigOverride: Option[DynamicBrokerConfig]) = + this(doLog, KafkaConfig.populateSynonyms(props), dynamicConfigOverride) + + // Cache the current config to avoid acquiring read lock to access from dynamicConfig + @volatile private var currentConfig = this + private[server] val dynamicConfig = dynamicConfigOverride.getOrElse(new DynamicBrokerConfig(this)) + + private[server] def updateCurrentConfig(newConfig: KafkaConfig): Unit = { + this.currentConfig = newConfig + } + + // The following captures any system properties impacting ZooKeeper TLS configuration + // and defines the default values this instance will use if no explicit config is given. + // We make it part of each instance rather than the object to facilitate testing. + private val zkClientConfigViaSystemProperties = new ZKClientConfig() + + override def originals: util.Map[String, AnyRef] = + if (this eq currentConfig) super.originals else currentConfig.originals + override def values: util.Map[String, _] = + if (this eq currentConfig) super.values else currentConfig.values + override def nonInternalValues: util.Map[String, _] = + if (this eq currentConfig) super.nonInternalValues else currentConfig.values + override def originalsStrings: util.Map[String, String] = + if (this eq currentConfig) super.originalsStrings else currentConfig.originalsStrings + override def originalsWithPrefix(prefix: String): util.Map[String, AnyRef] = + if (this eq currentConfig) super.originalsWithPrefix(prefix) else currentConfig.originalsWithPrefix(prefix) + override def valuesWithPrefixOverride(prefix: String): util.Map[String, AnyRef] = + if (this eq currentConfig) super.valuesWithPrefixOverride(prefix) else currentConfig.valuesWithPrefixOverride(prefix) + override def get(key: String): AnyRef = + if (this eq currentConfig) super.get(key) else currentConfig.get(key) + + // During dynamic update, we use the values from this config, these are only used in DynamicBrokerConfig + private[server] def originalsFromThisConfig: util.Map[String, AnyRef] = super.originals + private[server] def valuesFromThisConfig: util.Map[String, _] = super.values + private[server] def valuesFromThisConfigWithPrefixOverride(prefix: String): util.Map[String, AnyRef] = + super.valuesWithPrefixOverride(prefix) + + /** ********* Zookeeper Configuration ***********/ + val zkConnect: String = getString(KafkaConfig.ZkConnectProp) + val zkSessionTimeoutMs: Int = getInt(KafkaConfig.ZkSessionTimeoutMsProp) + val zkConnectionTimeoutMs: Int = + Option(getInt(KafkaConfig.ZkConnectionTimeoutMsProp)).map(_.toInt).getOrElse(getInt(KafkaConfig.ZkSessionTimeoutMsProp)) + val zkSyncTimeMs: Int = 
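The `originals`/`values`/`get` overrides above read through the `@volatile` `currentConfig` reference so that a dynamically updated broker config becomes visible to existing readers without locking. A minimal, self-contained sketch of that pattern (the `Settings` class and its names are invented for illustration, not part of this patch):

```scala
object VolatileSnapshotExample extends App {
  class Settings(val entries: Map[String, String]) {
    @volatile private var current: Settings = this            // latest applied snapshot
    def updateCurrent(newSettings: Settings): Unit = current = newSettings
    def get(key: String): Option[String] =
      if (this eq current) entries.get(key)                    // we are the newest snapshot
      else current.get(key)                                    // otherwise read through it
  }

  val initial = new Settings(Map("log.retention.hours" -> "168"))
  initial.updateCurrent(new Settings(Map("log.retention.hours" -> "24")))
  assert(initial.get("log.retention.hours").contains("24"))   // readers observe the update
}
```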
getInt(KafkaConfig.ZkSyncTimeMsProp) + val zkEnableSecureAcls: Boolean = getBoolean(KafkaConfig.ZkEnableSecureAclsProp) + val zkMaxInFlightRequests: Int = getInt(KafkaConfig.ZkMaxInFlightRequestsProp) + + private val _remoteLogManagerConfig = new RemoteLogManagerConfig(this) + def remoteLogManagerConfig = _remoteLogManagerConfig + + private def zkBooleanConfigOrSystemPropertyWithDefaultValue(propKey: String): Boolean = { + // Use the system property if it exists and the Kafka config value was defaulted rather than actually provided + // Need to translate any system property value from true/false (String) to true/false (Boolean) + val actuallyProvided = originals.containsKey(propKey) + if (actuallyProvided) getBoolean(propKey) else { + val sysPropValue = KafkaConfig.zooKeeperClientProperty(zkClientConfigViaSystemProperties, propKey) + sysPropValue match { + case Some("true") => true + case Some(_) => false + case _ => getBoolean(propKey) // not specified so use the default value + } + } + } + + private def zkStringConfigOrSystemPropertyWithDefaultValue(propKey: String): String = { + // Use the system property if it exists and the Kafka config value was defaulted rather than actually provided + val actuallyProvided = originals.containsKey(propKey) + if (actuallyProvided) getString(propKey) else { + KafkaConfig.zooKeeperClientProperty(zkClientConfigViaSystemProperties, propKey) match { + case Some(v) => v + case _ => getString(propKey) // not specified so use the default value + } + } + } + + private def zkOptionalStringConfigOrSystemProperty(propKey: String): Option[String] = { + Option(getString(propKey)).orElse { + KafkaConfig.zooKeeperClientProperty(zkClientConfigViaSystemProperties, propKey) + } + } + private def zkPasswordConfigOrSystemProperty(propKey: String): Option[Password] = { + Option(getPassword(propKey)).orElse { + KafkaConfig.zooKeeperClientProperty(zkClientConfigViaSystemProperties, propKey).map(new Password(_)) + } + } + private def zkListConfigOrSystemProperty(propKey: String): Option[util.List[String]] = { + Option(getList(propKey)).orElse { + KafkaConfig.zooKeeperClientProperty(zkClientConfigViaSystemProperties, propKey).map { sysProp => + sysProp.split("\\s*,\\s*").toBuffer.asJava + } + } + } + + val zkSslClientEnable = zkBooleanConfigOrSystemPropertyWithDefaultValue(KafkaConfig.ZkSslClientEnableProp) + val zkClientCnxnSocketClassName = zkOptionalStringConfigOrSystemProperty(KafkaConfig.ZkClientCnxnSocketProp) + val zkSslKeyStoreLocation = zkOptionalStringConfigOrSystemProperty(KafkaConfig.ZkSslKeyStoreLocationProp) + val zkSslKeyStorePassword = zkPasswordConfigOrSystemProperty(KafkaConfig.ZkSslKeyStorePasswordProp) + val zkSslKeyStoreType = zkOptionalStringConfigOrSystemProperty(KafkaConfig.ZkSslKeyStoreTypeProp) + val zkSslTrustStoreLocation = zkOptionalStringConfigOrSystemProperty(KafkaConfig.ZkSslTrustStoreLocationProp) + val zkSslTrustStorePassword = zkPasswordConfigOrSystemProperty(KafkaConfig.ZkSslTrustStorePasswordProp) + val zkSslTrustStoreType = zkOptionalStringConfigOrSystemProperty(KafkaConfig.ZkSslTrustStoreTypeProp) + val ZkSslProtocol = zkStringConfigOrSystemPropertyWithDefaultValue(KafkaConfig.ZkSslProtocolProp) + val ZkSslEnabledProtocols = zkListConfigOrSystemProperty(KafkaConfig.ZkSslEnabledProtocolsProp) + val ZkSslCipherSuites = zkListConfigOrSystemProperty(KafkaConfig.ZkSslCipherSuitesProp) + val ZkSslEndpointIdentificationAlgorithm = { + // Use the system property if it exists and the Kafka config value was defaulted rather than actually provided + 
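+ // (As with the zk*OrSystemProperty helpers above: an explicitly supplied Kafka config wins,
+ // otherwise the matching ZooKeeper client system property is consulted, and only then does
+ // the Kafka default apply.)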
// Need to translate any system property value from true/false to HTTPS/ + val kafkaProp = KafkaConfig.ZkSslEndpointIdentificationAlgorithmProp + val actuallyProvided = originals.containsKey(kafkaProp) + if (actuallyProvided) + getString(kafkaProp) + else { + KafkaConfig.zooKeeperClientProperty(zkClientConfigViaSystemProperties, kafkaProp) match { + case Some("true") => "HTTPS" + case Some(_) => "" + case None => getString(kafkaProp) // not specified so use the default value + } + } + } + val ZkSslCrlEnable = zkBooleanConfigOrSystemPropertyWithDefaultValue(KafkaConfig.ZkSslCrlEnableProp) + val ZkSslOcspEnable = zkBooleanConfigOrSystemPropertyWithDefaultValue(KafkaConfig.ZkSslOcspEnableProp) + /** ********* General Configuration ***********/ + val brokerIdGenerationEnable: Boolean = getBoolean(KafkaConfig.BrokerIdGenerationEnableProp) + val maxReservedBrokerId: Int = getInt(KafkaConfig.MaxReservedBrokerIdProp) + var brokerId: Int = getInt(KafkaConfig.BrokerIdProp) + val nodeId: Int = getInt(KafkaConfig.NodeIdProp) + val processRoles: Set[ProcessRole] = parseProcessRoles() + val initialRegistrationTimeoutMs: Int = getInt(KafkaConfig.InitialBrokerRegistrationTimeoutMsProp) + val brokerHeartbeatIntervalMs: Int = getInt(KafkaConfig.BrokerHeartbeatIntervalMsProp) + val brokerSessionTimeoutMs: Int = getInt(KafkaConfig.BrokerSessionTimeoutMsProp) + + def requiresZookeeper: Boolean = processRoles.isEmpty + def usesSelfManagedQuorum: Boolean = processRoles.nonEmpty + + private def parseProcessRoles(): Set[ProcessRole] = { + val roles = getList(KafkaConfig.ProcessRolesProp).asScala.map { + case "broker" => BrokerRole + case "controller" => ControllerRole + case role => throw new ConfigException(s"Unknown process role '$role'" + + " (only 'broker' and 'controller' are allowed roles)") + } + + val distinctRoles: Set[ProcessRole] = roles.toSet + + if (distinctRoles.size != roles.size) { + throw new ConfigException(s"Duplicate role names found in `${KafkaConfig.ProcessRolesProp}`: $roles") + } + + distinctRoles + } + + def metadataLogDir: String = { + Option(getString(KafkaConfig.MetadataLogDirProp)) match { + case Some(dir) => dir + case None => logDirs.head + } + } + + def metadataLogSegmentBytes = getInt(KafkaConfig.MetadataLogSegmentBytesProp) + def metadataLogSegmentMillis = getLong(KafkaConfig.MetadataLogSegmentMillisProp) + def metadataRetentionBytes = getLong(KafkaConfig.MetadataMaxRetentionBytesProp) + def metadataRetentionMillis = getLong(KafkaConfig.MetadataMaxRetentionMillisProp) + + + def numNetworkThreads = getInt(KafkaConfig.NumNetworkThreadsProp) + def backgroundThreads = getInt(KafkaConfig.BackgroundThreadsProp) + val queuedMaxRequests = getInt(KafkaConfig.QueuedMaxRequestsProp) + val queuedMaxBytes = getLong(KafkaConfig.QueuedMaxBytesProp) + def numIoThreads = getInt(KafkaConfig.NumIoThreadsProp) + def messageMaxBytes = getInt(KafkaConfig.MessageMaxBytesProp) + val requestTimeoutMs = getInt(KafkaConfig.RequestTimeoutMsProp) + val connectionSetupTimeoutMs = getLong(KafkaConfig.ConnectionSetupTimeoutMsProp) + val connectionSetupTimeoutMaxMs = getLong(KafkaConfig.ConnectionSetupTimeoutMaxMsProp) + + def getNumReplicaAlterLogDirsThreads: Int = { + val numThreads: Integer = Option(getInt(KafkaConfig.NumReplicaAlterLogDirsThreadsProp)).getOrElse(logDirs.size) + numThreads + } + + /************* Metadata Configuration ***********/ + val metadataSnapshotMaxNewRecordBytes = getLong(KafkaConfig.MetadataSnapshotMaxNewRecordBytesProp) + + /************* Authorizer Configuration ***********/ + val 
authorizer: Option[Authorizer] = { + val className = getString(KafkaConfig.AuthorizerClassNameProp) + if (className == null || className.isEmpty) + None + else { + Some(AuthorizerUtils.createAuthorizer(className)) + } + } + + /** ********* Socket Server Configuration ***********/ + val socketSendBufferBytes = getInt(KafkaConfig.SocketSendBufferBytesProp) + val socketReceiveBufferBytes = getInt(KafkaConfig.SocketReceiveBufferBytesProp) + val socketRequestMaxBytes = getInt(KafkaConfig.SocketRequestMaxBytesProp) + val maxConnectionsPerIp = getInt(KafkaConfig.MaxConnectionsPerIpProp) + val maxConnectionsPerIpOverrides: Map[String, Int] = + getMap(KafkaConfig.MaxConnectionsPerIpOverridesProp, getString(KafkaConfig.MaxConnectionsPerIpOverridesProp)).map { case (k, v) => (k, v.toInt)} + def maxConnections = getInt(KafkaConfig.MaxConnectionsProp) + def maxConnectionCreationRate = getInt(KafkaConfig.MaxConnectionCreationRateProp) + val connectionsMaxIdleMs = getLong(KafkaConfig.ConnectionsMaxIdleMsProp) + val failedAuthenticationDelayMs = getInt(KafkaConfig.FailedAuthenticationDelayMsProp) + + /***************** rack configuration **************/ + val rack = Option(getString(KafkaConfig.RackProp)) + val replicaSelectorClassName = Option(getString(KafkaConfig.ReplicaSelectorClassProp)) + + /** ********* Log Configuration ***********/ + val autoCreateTopicsEnable = getBoolean(KafkaConfig.AutoCreateTopicsEnableProp) + val numPartitions = getInt(KafkaConfig.NumPartitionsProp) + val logDirs = CoreUtils.parseCsvList(Option(getString(KafkaConfig.LogDirsProp)).getOrElse(getString(KafkaConfig.LogDirProp))) + def logSegmentBytes = getInt(KafkaConfig.LogSegmentBytesProp) + def logFlushIntervalMessages = getLong(KafkaConfig.LogFlushIntervalMessagesProp) + val logCleanerThreads = getInt(KafkaConfig.LogCleanerThreadsProp) + def numRecoveryThreadsPerDataDir = getInt(KafkaConfig.NumRecoveryThreadsPerDataDirProp) + val logFlushSchedulerIntervalMs = getLong(KafkaConfig.LogFlushSchedulerIntervalMsProp) + val logFlushOffsetCheckpointIntervalMs = getInt(KafkaConfig.LogFlushOffsetCheckpointIntervalMsProp).toLong + val logFlushStartOffsetCheckpointIntervalMs = getInt(KafkaConfig.LogFlushStartOffsetCheckpointIntervalMsProp).toLong + val logCleanupIntervalMs = getLong(KafkaConfig.LogCleanupIntervalMsProp) + def logCleanupPolicy = getList(KafkaConfig.LogCleanupPolicyProp) + val offsetsRetentionMinutes = getInt(KafkaConfig.OffsetsRetentionMinutesProp) + val offsetsRetentionCheckIntervalMs = getLong(KafkaConfig.OffsetsRetentionCheckIntervalMsProp) + def logRetentionBytes = getLong(KafkaConfig.LogRetentionBytesProp) + val logCleanerDedupeBufferSize = getLong(KafkaConfig.LogCleanerDedupeBufferSizeProp) + val logCleanerDedupeBufferLoadFactor = getDouble(KafkaConfig.LogCleanerDedupeBufferLoadFactorProp) + val logCleanerIoBufferSize = getInt(KafkaConfig.LogCleanerIoBufferSizeProp) + val logCleanerIoMaxBytesPerSecond = getDouble(KafkaConfig.LogCleanerIoMaxBytesPerSecondProp) + def logCleanerDeleteRetentionMs = getLong(KafkaConfig.LogCleanerDeleteRetentionMsProp) + def logCleanerMinCompactionLagMs = getLong(KafkaConfig.LogCleanerMinCompactionLagMsProp) + def logCleanerMaxCompactionLagMs = getLong(KafkaConfig.LogCleanerMaxCompactionLagMsProp) + val logCleanerBackoffMs = getLong(KafkaConfig.LogCleanerBackoffMsProp) + def logCleanerMinCleanRatio = getDouble(KafkaConfig.LogCleanerMinCleanRatioProp) + val logCleanerEnable = getBoolean(KafkaConfig.LogCleanerEnableProp) + def logIndexSizeMaxBytes = 
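The `max.connections.per.ip.overrides` value above is a comma-separated list of `host:count` pairs that `getMap` turns into a `Map[String, Int]`. A rough, standalone sketch of that format (the `parseOverrides` helper is hypothetical and is not Kafka's `CoreUtils.parseCsvMap`):

```scala
object PerIpOverridesExample extends App {
  // Turn "host1:100,127.0.0.1:200" into Map(host -> limit); sketch of the CSV-map format only.
  def parseOverrides(value: String): Map[String, Int] =
    value.split(",").iterator
      .map(_.trim)
      .filter(_.nonEmpty)
      .map { entry =>
        val idx = entry.indexOf(':')                       // assumes each entry has a colon
        entry.take(idx).trim -> entry.drop(idx + 1).trim.toInt
      }
      .toMap

  println(parseOverrides("host1:100,127.0.0.1:200"))       // Map(host1 -> 100, 127.0.0.1 -> 200)
}
```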
getInt(KafkaConfig.LogIndexSizeMaxBytesProp) + def logIndexIntervalBytes = getInt(KafkaConfig.LogIndexIntervalBytesProp) + def logDeleteDelayMs = getLong(KafkaConfig.LogDeleteDelayMsProp) + def logRollTimeMillis: java.lang.Long = Option(getLong(KafkaConfig.LogRollTimeMillisProp)).getOrElse(60 * 60 * 1000L * getInt(KafkaConfig.LogRollTimeHoursProp)) + def logRollTimeJitterMillis: java.lang.Long = Option(getLong(KafkaConfig.LogRollTimeJitterMillisProp)).getOrElse(60 * 60 * 1000L * getInt(KafkaConfig.LogRollTimeJitterHoursProp)) + def logFlushIntervalMs: java.lang.Long = Option(getLong(KafkaConfig.LogFlushIntervalMsProp)).getOrElse(getLong(KafkaConfig.LogFlushSchedulerIntervalMsProp)) + def minInSyncReplicas = getInt(KafkaConfig.MinInSyncReplicasProp) + def logPreAllocateEnable: java.lang.Boolean = getBoolean(KafkaConfig.LogPreAllocateProp) + + // We keep the user-provided String as `ApiVersion.apply` can choose a slightly different version (eg if `0.10.0` + // is passed, `0.10.0-IV0` may be picked) + @nowarn("cat=deprecation") + private val logMessageFormatVersionString = getString(KafkaConfig.LogMessageFormatVersionProp) + + /* See `TopicConfig.MESSAGE_FORMAT_VERSION_CONFIG` for details */ + @deprecated("3.0") + lazy val logMessageFormatVersion = + if (LogConfig.shouldIgnoreMessageFormatVersion(interBrokerProtocolVersion)) + ApiVersion(Defaults.LogMessageFormatVersion) + else ApiVersion(logMessageFormatVersionString) + + def logMessageTimestampType = TimestampType.forName(getString(KafkaConfig.LogMessageTimestampTypeProp)) + def logMessageTimestampDifferenceMaxMs: Long = getLong(KafkaConfig.LogMessageTimestampDifferenceMaxMsProp) + def logMessageDownConversionEnable: Boolean = getBoolean(KafkaConfig.LogMessageDownConversionEnableProp) + + /** ********* Replication configuration ***********/ + val controllerSocketTimeoutMs: Int = getInt(KafkaConfig.ControllerSocketTimeoutMsProp) + val defaultReplicationFactor: Int = getInt(KafkaConfig.DefaultReplicationFactorProp) + val replicaLagTimeMaxMs = getLong(KafkaConfig.ReplicaLagTimeMaxMsProp) + val replicaSocketTimeoutMs = getInt(KafkaConfig.ReplicaSocketTimeoutMsProp) + val replicaSocketReceiveBufferBytes = getInt(KafkaConfig.ReplicaSocketReceiveBufferBytesProp) + val replicaFetchMaxBytes = getInt(KafkaConfig.ReplicaFetchMaxBytesProp) + val replicaFetchWaitMaxMs = getInt(KafkaConfig.ReplicaFetchWaitMaxMsProp) + val replicaFetchMinBytes = getInt(KafkaConfig.ReplicaFetchMinBytesProp) + val replicaFetchResponseMaxBytes = getInt(KafkaConfig.ReplicaFetchResponseMaxBytesProp) + val replicaFetchBackoffMs = getInt(KafkaConfig.ReplicaFetchBackoffMsProp) + def numReplicaFetchers = getInt(KafkaConfig.NumReplicaFetchersProp) + val replicaHighWatermarkCheckpointIntervalMs = getLong(KafkaConfig.ReplicaHighWatermarkCheckpointIntervalMsProp) + val fetchPurgatoryPurgeIntervalRequests = getInt(KafkaConfig.FetchPurgatoryPurgeIntervalRequestsProp) + val producerPurgatoryPurgeIntervalRequests = getInt(KafkaConfig.ProducerPurgatoryPurgeIntervalRequestsProp) + val deleteRecordsPurgatoryPurgeIntervalRequests = getInt(KafkaConfig.DeleteRecordsPurgatoryPurgeIntervalRequestsProp) + val autoLeaderRebalanceEnable = getBoolean(KafkaConfig.AutoLeaderRebalanceEnableProp) + val leaderImbalancePerBrokerPercentage = getInt(KafkaConfig.LeaderImbalancePerBrokerPercentageProp) + val leaderImbalanceCheckIntervalSeconds = getLong(KafkaConfig.LeaderImbalanceCheckIntervalSecondsProp) + def uncleanLeaderElectionEnable: java.lang.Boolean = 
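+ // For example, with no log.roll.ms set and log.roll.hours=168, logRollTimeMillis above
+ // resolves to 168 * 60 * 60 * 1000 = 604800000 ms (7 days); the jitter and flush-interval
+ // fallbacks follow the same Option(...).getOrElse(...) pattern.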
getBoolean(KafkaConfig.UncleanLeaderElectionEnableProp) + + // We keep the user-provided String as `ApiVersion.apply` can choose a slightly different version (eg if `0.10.0` + // is passed, `0.10.0-IV0` may be picked) + val interBrokerProtocolVersionString = getString(KafkaConfig.InterBrokerProtocolVersionProp) + val interBrokerProtocolVersion = ApiVersion(interBrokerProtocolVersionString) + + /** ********* Controlled shutdown configuration ***********/ + val controlledShutdownMaxRetries = getInt(KafkaConfig.ControlledShutdownMaxRetriesProp) + val controlledShutdownRetryBackoffMs = getLong(KafkaConfig.ControlledShutdownRetryBackoffMsProp) + val controlledShutdownEnable = getBoolean(KafkaConfig.ControlledShutdownEnableProp) + + /** ********* Feature configuration ***********/ + def isFeatureVersioningSupported = interBrokerProtocolVersion >= KAFKA_2_7_IV0 + + /** ********* Group coordinator configuration ***********/ + val groupMinSessionTimeoutMs = getInt(KafkaConfig.GroupMinSessionTimeoutMsProp) + val groupMaxSessionTimeoutMs = getInt(KafkaConfig.GroupMaxSessionTimeoutMsProp) + val groupInitialRebalanceDelay = getInt(KafkaConfig.GroupInitialRebalanceDelayMsProp) + val groupMaxSize = getInt(KafkaConfig.GroupMaxSizeProp) + + /** ********* Offset management configuration ***********/ + val offsetMetadataMaxSize = getInt(KafkaConfig.OffsetMetadataMaxSizeProp) + val offsetsLoadBufferSize = getInt(KafkaConfig.OffsetsLoadBufferSizeProp) + val offsetsTopicReplicationFactor = getShort(KafkaConfig.OffsetsTopicReplicationFactorProp) + val offsetsTopicPartitions = getInt(KafkaConfig.OffsetsTopicPartitionsProp) + val offsetCommitTimeoutMs = getInt(KafkaConfig.OffsetCommitTimeoutMsProp) + val offsetCommitRequiredAcks = getShort(KafkaConfig.OffsetCommitRequiredAcksProp) + val offsetsTopicSegmentBytes = getInt(KafkaConfig.OffsetsTopicSegmentBytesProp) + val offsetsTopicCompressionCodec = Option(getInt(KafkaConfig.OffsetsTopicCompressionCodecProp)).map(value => CompressionCodec.getCompressionCodec(value)).orNull + + /** ********* Transaction management configuration ***********/ + val transactionalIdExpirationMs = getInt(KafkaConfig.TransactionalIdExpirationMsProp) + val transactionMaxTimeoutMs = getInt(KafkaConfig.TransactionsMaxTimeoutMsProp) + val transactionTopicMinISR = getInt(KafkaConfig.TransactionsTopicMinISRProp) + val transactionsLoadBufferSize = getInt(KafkaConfig.TransactionsLoadBufferSizeProp) + val transactionTopicReplicationFactor = getShort(KafkaConfig.TransactionsTopicReplicationFactorProp) + val transactionTopicPartitions = getInt(KafkaConfig.TransactionsTopicPartitionsProp) + val transactionTopicSegmentBytes = getInt(KafkaConfig.TransactionsTopicSegmentBytesProp) + val transactionAbortTimedOutTransactionCleanupIntervalMs = getInt(KafkaConfig.TransactionsAbortTimedOutTransactionCleanupIntervalMsProp) + val transactionRemoveExpiredTransactionalIdCleanupIntervalMs = getInt(KafkaConfig.TransactionsRemoveExpiredTransactionalIdCleanupIntervalMsProp) + + + /** ********* Metric Configuration **************/ + val metricNumSamples = getInt(KafkaConfig.MetricNumSamplesProp) + val metricSampleWindowMs = getLong(KafkaConfig.MetricSampleWindowMsProp) + val metricRecordingLevel = getString(KafkaConfig.MetricRecordingLevelProp) + + /** ********* SSL/SASL Configuration **************/ + // Security configs may be overridden for listeners, so it is not safe to use the base values + // Hence the base SSL/SASL configs are not fields of KafkaConfig, listener configs should be + // retrieved using 
KafkaConfig#valuesWithPrefixOverride + private def saslEnabledMechanisms(listenerName: ListenerName): Set[String] = { + val value = valuesWithPrefixOverride(listenerName.configPrefix).get(KafkaConfig.SaslEnabledMechanismsProp) + if (value != null) + value.asInstanceOf[util.List[String]].asScala.toSet + else + Set.empty[String] + } + + def interBrokerListenerName = getInterBrokerListenerNameAndSecurityProtocol._1 + def interBrokerSecurityProtocol = getInterBrokerListenerNameAndSecurityProtocol._2 + def controlPlaneListenerName = getControlPlaneListenerNameAndSecurityProtocol.map { case (listenerName, _) => listenerName } + def controlPlaneSecurityProtocol = getControlPlaneListenerNameAndSecurityProtocol.map { case (_, securityProtocol) => securityProtocol } + def saslMechanismInterBrokerProtocol = getString(KafkaConfig.SaslMechanismInterBrokerProtocolProp) + val saslInterBrokerHandshakeRequestEnable = interBrokerProtocolVersion >= KAFKA_0_10_0_IV1 + + /** ********* DelegationToken Configuration **************/ + val delegationTokenSecretKey = Option(getPassword(KafkaConfig.DelegationTokenSecretKeyProp)) + .getOrElse(getPassword(KafkaConfig.DelegationTokenSecretKeyAliasProp)) + val tokenAuthEnabled = (delegationTokenSecretKey != null && !delegationTokenSecretKey.value.isEmpty) + val delegationTokenMaxLifeMs = getLong(KafkaConfig.DelegationTokenMaxLifeTimeProp) + val delegationTokenExpiryTimeMs = getLong(KafkaConfig.DelegationTokenExpiryTimeMsProp) + val delegationTokenExpiryCheckIntervalMs = getLong(KafkaConfig.DelegationTokenExpiryCheckIntervalMsProp) + + /** ********* Password encryption configuration for dynamic configs *********/ + def passwordEncoderSecret = Option(getPassword(KafkaConfig.PasswordEncoderSecretProp)) + def passwordEncoderOldSecret = Option(getPassword(KafkaConfig.PasswordEncoderOldSecretProp)) + def passwordEncoderCipherAlgorithm = getString(KafkaConfig.PasswordEncoderCipherAlgorithmProp) + def passwordEncoderKeyFactoryAlgorithm = Option(getString(KafkaConfig.PasswordEncoderKeyFactoryAlgorithmProp)) + def passwordEncoderKeyLength = getInt(KafkaConfig.PasswordEncoderKeyLengthProp) + def passwordEncoderIterations = getInt(KafkaConfig.PasswordEncoderIterationsProp) + + /** ********* Quota Configuration **************/ + val numQuotaSamples = getInt(KafkaConfig.NumQuotaSamplesProp) + val quotaWindowSizeSeconds = getInt(KafkaConfig.QuotaWindowSizeSecondsProp) + val numReplicationQuotaSamples = getInt(KafkaConfig.NumReplicationQuotaSamplesProp) + val replicationQuotaWindowSizeSeconds = getInt(KafkaConfig.ReplicationQuotaWindowSizeSecondsProp) + val numAlterLogDirsReplicationQuotaSamples = getInt(KafkaConfig.NumAlterLogDirsReplicationQuotaSamplesProp) + val alterLogDirsReplicationQuotaWindowSizeSeconds = getInt(KafkaConfig.AlterLogDirsReplicationQuotaWindowSizeSecondsProp) + val numControllerQuotaSamples = getInt(KafkaConfig.NumControllerQuotaSamplesProp) + val controllerQuotaWindowSizeSeconds = getInt(KafkaConfig.ControllerQuotaWindowSizeSecondsProp) + + /** ********* Fetch Configuration **************/ + val maxIncrementalFetchSessionCacheSlots = getInt(KafkaConfig.MaxIncrementalFetchSessionCacheSlots) + val fetchMaxBytes = getInt(KafkaConfig.FetchMaxBytes) + + val deleteTopicEnable = getBoolean(KafkaConfig.DeleteTopicEnableProp) + def compressionType = getString(KafkaConfig.CompressionTypeProp) + + /** ********* Raft Quorum Configuration *********/ + val quorumVoters = getList(RaftConfig.QUORUM_VOTERS_CONFIG) + val quorumElectionTimeoutMs = 
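Because per-listener overrides take precedence over the base SASL/SSL values, listener-scoped configs are looked up via `valuesWithPrefixOverride` as the comment above describes. A hedged sketch of that lookup (minimal ZooKeeper-mode properties; the listener name `BROKER` is invented for illustration):

```scala
import java.util.Properties
import kafka.server.KafkaConfig

object ListenerOverrideExample extends App {
  val props = new Properties()
  props.put("zookeeper.connect", "localhost:2181")
  // Base value plus a listener-scoped override for the invented listener "BROKER":
  props.put("sasl.enabled.mechanisms", "GSSAPI")
  props.put("listener.name.broker.sasl.enabled.mechanisms", "PLAIN,SCRAM-SHA-256")

  val config = KafkaConfig.fromProps(props)
  // The prefix is stripped and the listener-scoped value wins over the base one:
  val perListener = config.valuesWithPrefixOverride("listener.name.broker.")
  println(perListener.get("sasl.enabled.mechanisms"))  // [PLAIN, SCRAM-SHA-256]
}
```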
getInt(RaftConfig.QUORUM_ELECTION_TIMEOUT_MS_CONFIG) + val quorumFetchTimeoutMs = getInt(RaftConfig.QUORUM_FETCH_TIMEOUT_MS_CONFIG) + val quorumElectionBackoffMs = getInt(RaftConfig.QUORUM_ELECTION_BACKOFF_MAX_MS_CONFIG) + val quorumLingerMs = getInt(RaftConfig.QUORUM_LINGER_MS_CONFIG) + val quorumRequestTimeoutMs = getInt(RaftConfig.QUORUM_REQUEST_TIMEOUT_MS_CONFIG) + val quorumRetryBackoffMs = getInt(RaftConfig.QUORUM_RETRY_BACKOFF_MS_CONFIG) + + def addReconfigurable(reconfigurable: Reconfigurable): Unit = { + dynamicConfig.addReconfigurable(reconfigurable) + } + + def removeReconfigurable(reconfigurable: Reconfigurable): Unit = { + dynamicConfig.removeReconfigurable(reconfigurable) + } + + def logRetentionTimeMillis: Long = { + val millisInMinute = 60L * 1000L + val millisInHour = 60L * millisInMinute + + val millis: java.lang.Long = + Option(getLong(KafkaConfig.LogRetentionTimeMillisProp)).getOrElse( + Option(getInt(KafkaConfig.LogRetentionTimeMinutesProp)) match { + case Some(mins) => millisInMinute * mins + case None => getInt(KafkaConfig.LogRetentionTimeHoursProp) * millisInHour + }) + + if (millis < 0) return -1 + millis + } + + private def getMap(propName: String, propValue: String): Map[String, String] = { + try { + CoreUtils.parseCsvMap(propValue) + } catch { + case e: Exception => throw new IllegalArgumentException("Error parsing configuration property '%s': %s".format(propName, e.getMessage)) + } + } + + def listeners: Seq[EndPoint] = + CoreUtils.listenerListToEndPoints(getString(KafkaConfig.ListenersProp), effectiveListenerSecurityProtocolMap) + + def controllerListenerNames: Seq[String] = { + val value = Option(getString(KafkaConfig.ControllerListenerNamesProp)).getOrElse("") + if (value.isEmpty) { + Seq.empty + } else { + value.split(",") + } + } + + def controllerListeners: Seq[EndPoint] = + listeners.filter(l => controllerListenerNames.contains(l.listenerName.value())) + + def saslMechanismControllerProtocol: String = getString(KafkaConfig.SaslMechanismControllerProtocolProp) + + def controlPlaneListener: Option[EndPoint] = { + controlPlaneListenerName.map { listenerName => + listeners.filter(endpoint => endpoint.listenerName.value() == listenerName.value()).head + } + } + + def dataPlaneListeners: Seq[EndPoint] = { + listeners.filterNot { listener => + val name = listener.listenerName.value() + name.equals(getString(KafkaConfig.ControlPlaneListenerNameProp)) || + controllerListenerNames.contains(name) + } + } + + // Use advertised listeners if defined, fallback to listeners otherwise + def effectiveAdvertisedListeners: Seq[EndPoint] = { + val advertisedListenersProp = getString(KafkaConfig.AdvertisedListenersProp) + if (advertisedListenersProp != null) + CoreUtils.listenerListToEndPoints(advertisedListenersProp, effectiveListenerSecurityProtocolMap, requireDistinctPorts=false) + else + listeners.filterNot(l => controllerListenerNames.contains(l.listenerName.value())) + } + + private def getInterBrokerListenerNameAndSecurityProtocol: (ListenerName, SecurityProtocol) = { + Option(getString(KafkaConfig.InterBrokerListenerNameProp)) match { + case Some(_) if originals.containsKey(KafkaConfig.InterBrokerSecurityProtocolProp) => + throw new ConfigException(s"Only one of ${KafkaConfig.InterBrokerListenerNameProp} and " + + s"${KafkaConfig.InterBrokerSecurityProtocolProp} should be set.") + case Some(name) => + val listenerName = ListenerName.normalised(name) + val securityProtocol = effectiveListenerSecurityProtocolMap.getOrElse(listenerName, + throw new 
ConfigException(s"Listener with name ${listenerName.value} defined in " + + s"${KafkaConfig.InterBrokerListenerNameProp} not found in ${KafkaConfig.ListenerSecurityProtocolMapProp}.")) + (listenerName, securityProtocol) + case None => + val securityProtocol = getSecurityProtocol(getString(KafkaConfig.InterBrokerSecurityProtocolProp), + KafkaConfig.InterBrokerSecurityProtocolProp) + (ListenerName.forSecurityProtocol(securityProtocol), securityProtocol) + } + } + + private def getControlPlaneListenerNameAndSecurityProtocol: Option[(ListenerName, SecurityProtocol)] = { + Option(getString(KafkaConfig.ControlPlaneListenerNameProp)) match { + case Some(name) => + val listenerName = ListenerName.normalised(name) + val securityProtocol = effectiveListenerSecurityProtocolMap.getOrElse(listenerName, + throw new ConfigException(s"Listener with ${listenerName.value} defined in " + + s"${KafkaConfig.ControlPlaneListenerNameProp} not found in ${KafkaConfig.ListenerSecurityProtocolMapProp}.")) + Some(listenerName, securityProtocol) + + case None => None + } + } + + private def getSecurityProtocol(protocolName: String, configName: String): SecurityProtocol = { + try SecurityProtocol.forName(protocolName) + catch { + case _: IllegalArgumentException => + throw new ConfigException(s"Invalid security protocol `$protocolName` defined in $configName") + } + } + + def effectiveListenerSecurityProtocolMap: Map[ListenerName, SecurityProtocol] = { + val mapValue = getMap(KafkaConfig.ListenerSecurityProtocolMapProp, getString(KafkaConfig.ListenerSecurityProtocolMapProp)) + .map { case (listenerName, protocolName) => + ListenerName.normalised(listenerName) -> getSecurityProtocol(protocolName, KafkaConfig.ListenerSecurityProtocolMapProp) + } + if (usesSelfManagedQuorum && !originals.containsKey(ListenerSecurityProtocolMapProp)) { + // Nothing was specified explicitly for listener.security.protocol.map, so we are using the default value, + // and we are using KRaft. 
+ // Add PLAINTEXT mappings for controller listeners as long as there is no SSL or SASL_{PLAINTEXT,SSL} in use + def isSslOrSasl(name: String): Boolean = name.equals(SecurityProtocol.SSL.name) || name.equals(SecurityProtocol.SASL_SSL.name) || name.equals(SecurityProtocol.SASL_PLAINTEXT.name) + // check controller listener names (they won't appear in listeners when process.roles=broker) + // as well as listeners for occurrences of SSL or SASL_* + if (controllerListenerNames.exists(isSslOrSasl) || + parseCsvList(getString(KafkaConfig.ListenersProp)).exists(listenerValue => isSslOrSasl(EndPoint.parseListenerName(listenerValue)))) { + mapValue // don't add default mappings since we found something that is SSL or SASL_* + } else { + // add the PLAINTEXT mappings for all controller listener names that are not explicitly PLAINTEXT + mapValue ++ controllerListenerNames.filter(!SecurityProtocol.PLAINTEXT.name.equals(_)).map( + new ListenerName(_) -> SecurityProtocol.PLAINTEXT) + } + } else { + mapValue + } + } + + // Topic IDs are used with all self-managed quorum clusters and ZK cluster with IBP greater than or equal to 2.8 + def usesTopicId: Boolean = + usesSelfManagedQuorum || interBrokerProtocolVersion >= KAFKA_2_8_IV0 + + validateValues() + + @nowarn("cat=deprecation") + private def validateValues(): Unit = { + if (nodeId != brokerId) { + throw new ConfigException(s"You must set `${KafkaConfig.NodeIdProp}` to the same value as `${KafkaConfig.BrokerIdProp}`.") + } + if (requiresZookeeper) { + if (zkConnect == null) { + throw new ConfigException(s"Missing required configuration `${KafkaConfig.ZkConnectProp}` which has no default value.") + } + if (brokerIdGenerationEnable) { + require(brokerId >= -1 && brokerId <= maxReservedBrokerId, "broker.id must be greater than or equal to -1 and not greater than reserved.broker.max.id") + } else { + require(brokerId >= 0, "broker.id must be greater than or equal to 0") + } + } else { + // KRaft-based metadata quorum + if (nodeId < 0) { + throw new ConfigException(s"Missing configuration `${KafkaConfig.NodeIdProp}` which is required " + + s"when `process.roles` is defined (i.e. when running in KRaft mode).") + } + } + require(logRollTimeMillis >= 1, "log.roll.ms must be greater than or equal to 1") + require(logRollTimeJitterMillis >= 0, "log.roll.jitter.ms must be greater than or equal to 0") + require(logRetentionTimeMillis >= 1 || logRetentionTimeMillis == -1, "log.retention.ms must be unlimited (-1) or, greater than or equal to 1") + require(logDirs.nonEmpty, "At least one log directory must be defined via log.dirs or log.dir.") + require(logCleanerDedupeBufferSize / logCleanerThreads > 1024 * 1024, "log.cleaner.dedupe.buffer.size must be at least 1MB per cleaner thread.") + require(replicaFetchWaitMaxMs <= replicaSocketTimeoutMs, "replica.socket.timeout.ms should always be at least replica.fetch.wait.max.ms" + + " to prevent unnecessary socket timeouts") + require(replicaFetchWaitMaxMs <= replicaLagTimeMaxMs, "replica.fetch.wait.max.ms should always be less than or equal to replica.lag.time.max.ms" + + " to prevent frequent changes in ISR") + require(offsetCommitRequiredAcks >= -1 && offsetCommitRequiredAcks <= offsetsTopicReplicationFactor, + "offsets.commit.required.acks must be greater or equal -1 and less or equal to offsets.topic.replication.factor") + require(BrokerCompressionCodec.isValid(compressionType), "compression.type : " + compressionType + " is not valid." 
+ + " Valid options are " + BrokerCompressionCodec.brokerCompressionOptions.mkString(",")) + val advertisedListenerNames = effectiveAdvertisedListeners.map(_.listenerName).toSet + + // validate KRaft-related configs + val voterAddressSpecsByNodeId = RaftConfig.parseVoterConnections(quorumVoters) + def validateNonEmptyQuorumVotersForKRaft(): Unit = { + if (voterAddressSpecsByNodeId.isEmpty) { + throw new ConfigException(s"If using ${KafkaConfig.ProcessRolesProp}, ${KafkaConfig.QuorumVotersProp} must contain a parseable set of voters.") + } + } + def validateControlPlaneListenerEmptyForKRaft(): Unit = { + require(controlPlaneListenerName.isEmpty, + s"${KafkaConfig.ControlPlaneListenerNameProp} is not supported in KRaft mode. KRaft uses ${KafkaConfig.ControllerListenerNamesProp} instead.") + } + def validateAdvertisedListenersDoesNotContainControllerListenersForKRaftBroker(): Unit = { + require(!advertisedListenerNames.exists(aln => controllerListenerNames.contains(aln.value())), + s"The advertised.listeners config must not contain KRaft controller listeners from ${KafkaConfig.ControllerListenerNamesProp} when ${KafkaConfig.ProcessRolesProp} contains the broker role because Kafka clients that send requests via advertised listeners do not send requests to KRaft controllers -- they only send requests to KRaft brokers.") + } + def validateControllerQuorumVotersMustContainNodeIdForKRaftController(): Unit = { + require(voterAddressSpecsByNodeId.containsKey(nodeId), + s"If ${KafkaConfig.ProcessRolesProp} contains the 'controller' role, the node id $nodeId must be included in the set of voters ${KafkaConfig.QuorumVotersProp}=${voterAddressSpecsByNodeId.asScala.keySet.toSet}") + } + def validateControllerListenerExistsForKRaftController(): Unit = { + require(controllerListeners.nonEmpty, + s"${KafkaConfig.ControllerListenerNamesProp} must contain at least one value appearing in the '${KafkaConfig.ListenersProp}' configuration when running the KRaft controller role") + } + def validateControllerListenerNamesMustAppearInListenersForKRaftController(): Unit = { + val listenerNameValues = listeners.map(_.listenerName.value).toSet + require(controllerListenerNames.forall(cln => listenerNameValues.contains(cln)), + s"${KafkaConfig.ControllerListenerNamesProp} must only contain values appearing in the '${KafkaConfig.ListenersProp}' configuration when running the KRaft controller role") + } + def validateAdvertisedListenersNonEmptyForBroker(): Unit = { + require(advertisedListenerNames.nonEmpty, + "There must be at least one advertised listener." + ( + if (processRoles.contains(BrokerRole)) s" Perhaps all listeners appear in ${ControllerListenerNamesProp}?" else "")) + } + if (processRoles == Set(BrokerRole)) { + // KRaft broker-only + validateNonEmptyQuorumVotersForKRaft() + validateControlPlaneListenerEmptyForKRaft() + validateAdvertisedListenersDoesNotContainControllerListenersForKRaftBroker() + // nodeId must not appear in controller.quorum.voters + require(!voterAddressSpecsByNodeId.containsKey(nodeId), + s"If ${KafkaConfig.ProcessRolesProp} contains just the 'broker' role, the node id $nodeId must not be included in the set of voters ${KafkaConfig.QuorumVotersProp}=${voterAddressSpecsByNodeId.asScala.keySet.toSet}") + // controller.listener.names must be non-empty... + require(controllerListenerNames.nonEmpty, + s"${KafkaConfig.ControllerListenerNamesProp} must contain at least one value when running KRaft with just the broker role") + // controller.listener.names are forbidden in listeners... 
+ require(controllerListeners.isEmpty, + s"${KafkaConfig.ControllerListenerNamesProp} must not contain a value appearing in the '${KafkaConfig.ListenersProp}' configuration when running KRaft with just the broker role") + // controller.listener.names must all appear in listener.security.protocol.map + controllerListenerNames.foreach { name => + val listenerName = ListenerName.normalised(name) + if (!effectiveListenerSecurityProtocolMap.contains(listenerName)) { + throw new ConfigException(s"Controller listener with name ${listenerName.value} defined in " + + s"${KafkaConfig.ControllerListenerNamesProp} not found in ${KafkaConfig.ListenerSecurityProtocolMapProp} (an explicit security mapping for each controller listener is required if ${KafkaConfig.ListenerSecurityProtocolMapProp} is non-empty, or if there are security protocols other than PLAINTEXT in use)") + } + } + // warn that only the first controller listener is used if there is more than one + if (controllerListenerNames.size > 1) { + warn(s"${KafkaConfig.ControllerListenerNamesProp} has multiple entries; only the first will be used since ${KafkaConfig.ProcessRolesProp}=broker: ${controllerListenerNames.asJava}") + } + validateAdvertisedListenersNonEmptyForBroker() + } else if (processRoles == Set(ControllerRole)) { + // KRaft controller-only + validateNonEmptyQuorumVotersForKRaft() + validateControlPlaneListenerEmptyForKRaft() + // advertised listeners must be empty when not also running the broker role + val sourceOfAdvertisedListeners: String = + if (getString(KafkaConfig.AdvertisedListenersProp) != null) + s"${KafkaConfig.AdvertisedListenersProp}" + else + s"${KafkaConfig.ListenersProp}" + require(effectiveAdvertisedListeners.isEmpty, + s"The $sourceOfAdvertisedListeners config must only contain KRaft controller listeners from ${KafkaConfig.ControllerListenerNamesProp} when ${KafkaConfig.ProcessRolesProp}=controller") + validateControllerQuorumVotersMustContainNodeIdForKRaftController() + validateControllerListenerExistsForKRaftController() + validateControllerListenerNamesMustAppearInListenersForKRaftController() + } else if (processRoles == Set(BrokerRole, ControllerRole)) { + // KRaft colocated broker and controller + validateNonEmptyQuorumVotersForKRaft() + validateControlPlaneListenerEmptyForKRaft() + validateAdvertisedListenersDoesNotContainControllerListenersForKRaftBroker() + validateControllerQuorumVotersMustContainNodeIdForKRaftController() + validateControllerListenerExistsForKRaftController() + validateControllerListenerNamesMustAppearInListenersForKRaftController() + validateAdvertisedListenersNonEmptyForBroker() + } else { + // ZK-based + // controller listener names must be empty when not in KRaft mode + require(controllerListenerNames.isEmpty, s"${KafkaConfig.ControllerListenerNamesProp} must be empty when not running in KRaft mode: ${controllerListenerNames.asJava}") + validateAdvertisedListenersNonEmptyForBroker() + } + + val listenerNames = listeners.map(_.listenerName).toSet + if (processRoles.isEmpty || processRoles.contains(BrokerRole)) { + // validations for all broker setups (i.e. ZooKeeper and KRaft broker-only and KRaft co-located) + validateAdvertisedListenersNonEmptyForBroker() + require(advertisedListenerNames.contains(interBrokerListenerName), + s"${KafkaConfig.InterBrokerListenerNameProp} must be a listener name defined in ${KafkaConfig.AdvertisedListenersProp}. 
" + + s"The valid options based on currently configured listeners are ${advertisedListenerNames.map(_.value).mkString(",")}") + require(advertisedListenerNames.subsetOf(listenerNames), + s"${KafkaConfig.AdvertisedListenersProp} listener names must be equal to or a subset of the ones defined in ${KafkaConfig.ListenersProp}. " + + s"Found ${advertisedListenerNames.map(_.value).mkString(",")}. The valid options based on the current configuration " + + s"are ${listenerNames.map(_.value).mkString(",")}" + ) + } + + require(!effectiveAdvertisedListeners.exists(endpoint => endpoint.host=="0.0.0.0"), + s"${KafkaConfig.AdvertisedListenersProp} cannot use the nonroutable meta-address 0.0.0.0. "+ + s"Use a routable IP address.") + + // validate control.plane.listener.name config + if (controlPlaneListenerName.isDefined) { + require(advertisedListenerNames.contains(controlPlaneListenerName.get), + s"${KafkaConfig.ControlPlaneListenerNameProp} must be a listener name defined in ${KafkaConfig.AdvertisedListenersProp}. " + + s"The valid options based on currently configured listeners are ${advertisedListenerNames.map(_.value).mkString(",")}") + // controlPlaneListenerName should be different from interBrokerListenerName + require(!controlPlaneListenerName.get.value().equals(interBrokerListenerName.value()), + s"${KafkaConfig.ControlPlaneListenerNameProp}, when defined, should have a different value from the inter broker listener name. " + + s"Currently they both have the value ${controlPlaneListenerName.get}") + } + + val messageFormatVersion = new MessageFormatVersion(logMessageFormatVersionString, interBrokerProtocolVersionString) + if (messageFormatVersion.shouldWarn) + warn(messageFormatVersion.brokerWarningMessage) + + val recordVersion = logMessageFormatVersion.recordVersion + require(interBrokerProtocolVersion.recordVersion.value >= recordVersion.value, + s"log.message.format.version $logMessageFormatVersionString can only be used when inter.broker.protocol.version " + + s"is set to version ${ApiVersion.minSupportedFor(recordVersion).shortVersion} or higher") + + if (offsetsTopicCompressionCodec == ZStdCompressionCodec) + require(interBrokerProtocolVersion.recordVersion.value >= KAFKA_2_1_IV0.recordVersion.value, + "offsets.topic.compression.codec zstd can only be used when inter.broker.protocol.version " + + s"is set to version ${KAFKA_2_1_IV0.shortVersion} or higher") + + val interBrokerUsesSasl = interBrokerSecurityProtocol == SecurityProtocol.SASL_PLAINTEXT || interBrokerSecurityProtocol == SecurityProtocol.SASL_SSL + require(!interBrokerUsesSasl || saslInterBrokerHandshakeRequestEnable || saslMechanismInterBrokerProtocol == SaslConfigs.GSSAPI_MECHANISM, + s"Only GSSAPI mechanism is supported for inter-broker communication with SASL when inter.broker.protocol.version is set to $interBrokerProtocolVersionString") + require(!interBrokerUsesSasl || saslEnabledMechanisms(interBrokerListenerName).contains(saslMechanismInterBrokerProtocol), + s"${KafkaConfig.SaslMechanismInterBrokerProtocolProp} must be included in ${KafkaConfig.SaslEnabledMechanismsProp} when SASL is used for inter-broker communication") + require(queuedMaxBytes <= 0 || queuedMaxBytes >= socketRequestMaxBytes, + s"${KafkaConfig.QueuedMaxBytesProp} must be larger or equal to ${KafkaConfig.SocketRequestMaxBytesProp}") + + if (maxConnectionsPerIp == 0) + require(!maxConnectionsPerIpOverrides.isEmpty, s"${KafkaConfig.MaxConnectionsPerIpProp} can be set to zero only if" + + s" ${KafkaConfig.MaxConnectionsPerIpOverridesProp} property is 
set.") + + val invalidAddresses = maxConnectionsPerIpOverrides.keys.filterNot(address => Utils.validHostPattern(address)) + if (!invalidAddresses.isEmpty) + throw new IllegalArgumentException(s"${KafkaConfig.MaxConnectionsPerIpOverridesProp} contains invalid addresses : ${invalidAddresses.mkString(",")}") + + if (connectionsMaxIdleMs >= 0) + require(failedAuthenticationDelayMs < connectionsMaxIdleMs, + s"${KafkaConfig.FailedAuthenticationDelayMsProp}=$failedAuthenticationDelayMs should always be less than" + + s" ${KafkaConfig.ConnectionsMaxIdleMsProp}=$connectionsMaxIdleMs to prevent failed" + + s" authentication responses from timing out") + + val principalBuilderClass = getClass(KafkaConfig.PrincipalBuilderClassProp) + require(principalBuilderClass != null, s"${KafkaConfig.PrincipalBuilderClassProp} must be non-null") + require(classOf[KafkaPrincipalSerde].isAssignableFrom(principalBuilderClass), + s"${KafkaConfig.PrincipalBuilderClassProp} must implement KafkaPrincipalSerde") + } +} diff --git a/core/src/main/scala/kafka/server/KafkaRaftServer.scala b/core/src/main/scala/kafka/server/KafkaRaftServer.scala new file mode 100644 index 0000000..cda545d --- /dev/null +++ b/core/src/main/scala/kafka/server/KafkaRaftServer.scala @@ -0,0 +1,179 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package kafka.server + +import java.io.File +import java.util.concurrent.CompletableFuture +import kafka.common.{InconsistentNodeIdException, KafkaException} +import kafka.log.UnifiedLog +import kafka.metrics.{KafkaMetricsReporter, KafkaYammerMetrics} +import kafka.raft.KafkaRaftManager +import kafka.server.KafkaRaftServer.{BrokerRole, ControllerRole} +import kafka.utils.{CoreUtils, Logging, Mx4jLoader, VerifiableProperties} +import org.apache.kafka.common.utils.{AppInfoParser, Time} +import org.apache.kafka.common.{TopicPartition, Uuid} +import org.apache.kafka.metadata.MetadataRecordSerde +import org.apache.kafka.raft.RaftConfig +import org.apache.kafka.server.common.ApiMessageAndVersion + +import scala.collection.Seq + +/** + * This class implements the KRaft (Kafka Raft) mode server which relies + * on a KRaft quorum for maintaining cluster metadata. It is responsible for + * constructing the controller and/or broker based on the `process.roles` + * configuration and for managing their basic lifecycle (startup and shutdown). + * + * Note that this server is a work in progress and we are releasing it as + * early access in 2.8.0. 
+ */ +class KafkaRaftServer( + config: KafkaConfig, + time: Time, + threadNamePrefix: Option[String] +) extends Server with Logging { + + KafkaMetricsReporter.startReporters(VerifiableProperties(config.originals)) + KafkaYammerMetrics.INSTANCE.configure(config.originals) + + private val (metaProps, offlineDirs) = KafkaRaftServer.initializeLogDirs(config) + + private val metrics = Server.initializeMetrics( + config, + time, + metaProps.clusterId + ) + + private val controllerQuorumVotersFuture = CompletableFuture.completedFuture( + RaftConfig.parseVoterConnections(config.quorumVoters)) + + private val raftManager = new KafkaRaftManager[ApiMessageAndVersion]( + metaProps, + config, + new MetadataRecordSerde, + KafkaRaftServer.MetadataPartition, + KafkaRaftServer.MetadataTopicId, + time, + metrics, + threadNamePrefix, + controllerQuorumVotersFuture + ) + + private val broker: Option[BrokerServer] = if (config.processRoles.contains(BrokerRole)) { + Some(new BrokerServer( + config, + metaProps, + raftManager, + time, + metrics, + threadNamePrefix, + offlineDirs, + controllerQuorumVotersFuture, + Server.SUPPORTED_FEATURES + )) + } else { + None + } + + private val controller: Option[ControllerServer] = if (config.processRoles.contains(ControllerRole)) { + Some(new ControllerServer( + metaProps, + config, + raftManager, + time, + metrics, + threadNamePrefix, + controllerQuorumVotersFuture + )) + } else { + None + } + + override def startup(): Unit = { + Mx4jLoader.maybeLoad() + raftManager.startup() + controller.foreach(_.startup()) + broker.foreach(_.startup()) + AppInfoParser.registerAppInfo(Server.MetricsPrefix, config.brokerId.toString, metrics, time.milliseconds()) + info(KafkaBroker.STARTED_MESSAGE) + } + + override def shutdown(): Unit = { + broker.foreach(_.shutdown()) + raftManager.shutdown() + controller.foreach(_.shutdown()) + CoreUtils.swallow(AppInfoParser.unregisterAppInfo(Server.MetricsPrefix, config.brokerId.toString, metrics), this) + + } + + override def awaitShutdown(): Unit = { + broker.foreach(_.awaitShutdown()) + controller.foreach(_.awaitShutdown()) + } + +} + +object KafkaRaftServer { + val MetadataTopic = "__cluster_metadata" + val MetadataPartition = new TopicPartition(MetadataTopic, 0) + val MetadataTopicId = Uuid.METADATA_TOPIC_ID + + sealed trait ProcessRole + case object BrokerRole extends ProcessRole + case object ControllerRole extends ProcessRole + + /** + * Initialize the configured log directories, including both [[KafkaConfig.MetadataLogDirProp]] + * and [[KafkaConfig.LogDirProp]]. This method performs basic validation to ensure that all + * directories are accessible and have been initialized with consistent `meta.properties`. + * + * @param config The process configuration + * @return A tuple containing the loaded meta properties (which are guaranteed to + * be consistent across all log dirs) and the offline directories + */ + def initializeLogDirs(config: KafkaConfig): (MetaProperties, Seq[String]) = { + val logDirs = (config.logDirs.toSet + config.metadataLogDir).toSeq + val (rawMetaProperties, offlineDirs) = BrokerMetadataCheckpoint. 
+ getBrokerMetadataAndOfflineDirs(logDirs, ignoreMissing = false) + + if (offlineDirs.contains(config.metadataLogDir)) { + throw new KafkaException("Cannot start server since `meta.properties` could not be " + + s"loaded from ${config.metadataLogDir}") + } + + val metadataPartitionDirName = UnifiedLog.logDirName(MetadataPartition) + val onlineNonMetadataDirs = logDirs.diff(offlineDirs :+ config.metadataLogDir) + onlineNonMetadataDirs.foreach { logDir => + val metadataDir = new File(logDir, metadataPartitionDirName) + if (metadataDir.exists) { + throw new KafkaException(s"Found unexpected metadata location in data directory `$metadataDir` " + + s"(the configured metadata directory is ${config.metadataLogDir}).") + } + } + + val metaProperties = MetaProperties.parse(rawMetaProperties) + if (config.nodeId != metaProperties.nodeId) { + throw new InconsistentNodeIdException( + s"Configured node.id `${config.nodeId}` doesn't match stored node.id `${metaProperties.nodeId}' in " + + "meta.properties. If you moved your data, make sure your configured controller.id matches. " + + "If you intend to create a new broker, you should remove all data in your data directories (log.dirs).") + } + + (metaProperties, offlineDirs.toSeq) + } + +} diff --git a/core/src/main/scala/kafka/server/KafkaRequestHandler.scala b/core/src/main/scala/kafka/server/KafkaRequestHandler.scala new file mode 100755 index 0000000..4d38c6e --- /dev/null +++ b/core/src/main/scala/kafka/server/KafkaRequestHandler.scala @@ -0,0 +1,375 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.server + +import kafka.network._ +import kafka.utils._ +import kafka.metrics.KafkaMetricsGroup + +import java.util.concurrent.{CountDownLatch, TimeUnit} +import java.util.concurrent.atomic.AtomicInteger +import com.yammer.metrics.core.Meter +import org.apache.kafka.common.internals.FatalExitError +import org.apache.kafka.common.utils.{KafkaThread, Time} + +import scala.collection.mutable +import scala.jdk.CollectionConverters._ + +trait ApiRequestHandler { + def handle(request: RequestChannel.Request, requestLocal: RequestLocal): Unit +} + +/** + * A thread that answers kafka requests. + */ +class KafkaRequestHandler(id: Int, + brokerId: Int, + val aggregateIdleMeter: Meter, + val totalHandlerThreads: AtomicInteger, + val requestChannel: RequestChannel, + apis: ApiRequestHandler, + time: Time) extends Runnable with Logging { + this.logIdent = s"[Kafka Request Handler $id on Broker $brokerId], " + private val shutdownComplete = new CountDownLatch(1) + private val requestLocal = RequestLocal.withThreadConfinedCaching + @volatile private var stopped = false + + def run(): Unit = { + while (!stopped) { + // We use a single meter for aggregate idle percentage for the thread pool. 
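+ // (For example, with 8 handler threads that are each idle for 800 ms of a one-second
+ // window, each thread marks 800/8 = 100 ms below, so the pool-wide meter reports
+ // roughly 80% aggregate idle time.)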
+ // Since meter is calculated as total_recorded_value / time_window and + // time_window is independent of the number of threads, each recorded idle + // time should be discounted by # threads. + val startSelectTime = time.nanoseconds + + val req = requestChannel.receiveRequest(300) + val endTime = time.nanoseconds + val idleTime = endTime - startSelectTime + aggregateIdleMeter.mark(idleTime / totalHandlerThreads.get) + + req match { + case RequestChannel.ShutdownRequest => + debug(s"Kafka request handler $id on broker $brokerId received shut down command") + completeShutdown() + return + + case request: RequestChannel.Request => + try { + request.requestDequeueTimeNanos = endTime + trace(s"Kafka request handler $id on broker $brokerId handling request $request") + apis.handle(request, requestLocal) + } catch { + case e: FatalExitError => + completeShutdown() + Exit.exit(e.statusCode) + case e: Throwable => error("Exception when handling request", e) + } finally { + request.releaseBuffer() + } + + case null => // continue + } + } + completeShutdown() + } + + private def completeShutdown(): Unit = { + requestLocal.close() + shutdownComplete.countDown() + } + + def stop(): Unit = { + stopped = true + } + + def initiateShutdown(): Unit = requestChannel.sendShutdownRequest() + + def awaitShutdown(): Unit = shutdownComplete.await() + +} + +class KafkaRequestHandlerPool(val brokerId: Int, + val requestChannel: RequestChannel, + val apis: ApiRequestHandler, + time: Time, + numThreads: Int, + requestHandlerAvgIdleMetricName: String, + logAndThreadNamePrefix : String) extends Logging with KafkaMetricsGroup { + + private val threadPoolSize: AtomicInteger = new AtomicInteger(numThreads) + /* a meter to track the average free capacity of the request handlers */ + private val aggregateIdleMeter = newMeter(requestHandlerAvgIdleMetricName, "percent", TimeUnit.NANOSECONDS) + + this.logIdent = "[" + logAndThreadNamePrefix + " Kafka Request Handler on Broker " + brokerId + "], " + val runnables = new mutable.ArrayBuffer[KafkaRequestHandler](numThreads) + for (i <- 0 until numThreads) { + createHandler(i) + } + + def createHandler(id: Int): Unit = synchronized { + runnables += new KafkaRequestHandler(id, brokerId, aggregateIdleMeter, threadPoolSize, requestChannel, apis, time) + KafkaThread.daemon(logAndThreadNamePrefix + "-kafka-request-handler-" + id, runnables(id)).start() + } + + def resizeThreadPool(newSize: Int): Unit = synchronized { + val currentSize = threadPoolSize.get + info(s"Resizing request handler thread pool size from $currentSize to $newSize") + if (newSize > currentSize) { + for (i <- currentSize until newSize) { + createHandler(i) + } + } else if (newSize < currentSize) { + for (i <- 1 to (currentSize - newSize)) { + runnables.remove(currentSize - i).stop() + } + } + threadPoolSize.set(newSize) + } + + def shutdown(): Unit = synchronized { + info("shutting down") + for (handler <- runnables) + handler.initiateShutdown() + for (handler <- runnables) + handler.awaitShutdown() + info("shut down completely") + } +} + +class BrokerTopicMetrics(name: Option[String]) extends KafkaMetricsGroup { + val tags: scala.collection.Map[String, String] = name match { + case None => Map.empty + case Some(topic) => Map("topic" -> topic) + } + + case class MeterWrapper(metricType: String, eventType: String) { + @volatile private var lazyMeter: Meter = _ + private val meterLock = new Object + + def meter(): Meter = { + var meter = lazyMeter + if (meter == null) { + meterLock synchronized { + meter = lazyMeter 
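+ // re-read and re-check under the lock (double-checked locking) so the meter is created only once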
+ if (meter == null) { + meter = newMeter(metricType, eventType, TimeUnit.SECONDS, tags) + lazyMeter = meter + } + } + } + meter + } + + def close(): Unit = meterLock synchronized { + if (lazyMeter != null) { + removeMetric(metricType, tags) + lazyMeter = null + } + } + + if (tags.isEmpty) // greedily initialize the general topic metrics + meter() + } + + // an internal map for "lazy initialization" of certain metrics + private val metricTypeMap = new Pool[String, MeterWrapper]() + metricTypeMap.putAll(Map( + BrokerTopicStats.MessagesInPerSec -> MeterWrapper(BrokerTopicStats.MessagesInPerSec, "messages"), + BrokerTopicStats.BytesInPerSec -> MeterWrapper(BrokerTopicStats.BytesInPerSec, "bytes"), + BrokerTopicStats.BytesOutPerSec -> MeterWrapper(BrokerTopicStats.BytesOutPerSec, "bytes"), + BrokerTopicStats.BytesRejectedPerSec -> MeterWrapper(BrokerTopicStats.BytesRejectedPerSec, "bytes"), + BrokerTopicStats.FailedProduceRequestsPerSec -> MeterWrapper(BrokerTopicStats.FailedProduceRequestsPerSec, "requests"), + BrokerTopicStats.FailedFetchRequestsPerSec -> MeterWrapper(BrokerTopicStats.FailedFetchRequestsPerSec, "requests"), + BrokerTopicStats.TotalProduceRequestsPerSec -> MeterWrapper(BrokerTopicStats.TotalProduceRequestsPerSec, "requests"), + BrokerTopicStats.TotalFetchRequestsPerSec -> MeterWrapper(BrokerTopicStats.TotalFetchRequestsPerSec, "requests"), + BrokerTopicStats.FetchMessageConversionsPerSec -> MeterWrapper(BrokerTopicStats.FetchMessageConversionsPerSec, "requests"), + BrokerTopicStats.ProduceMessageConversionsPerSec -> MeterWrapper(BrokerTopicStats.ProduceMessageConversionsPerSec, "requests"), + BrokerTopicStats.NoKeyCompactedTopicRecordsPerSec -> MeterWrapper(BrokerTopicStats.NoKeyCompactedTopicRecordsPerSec, "requests"), + BrokerTopicStats.InvalidMagicNumberRecordsPerSec -> MeterWrapper(BrokerTopicStats.InvalidMagicNumberRecordsPerSec, "requests"), + BrokerTopicStats.InvalidMessageCrcRecordsPerSec -> MeterWrapper(BrokerTopicStats.InvalidMessageCrcRecordsPerSec, "requests"), + BrokerTopicStats.InvalidOffsetOrSequenceRecordsPerSec -> MeterWrapper(BrokerTopicStats.InvalidOffsetOrSequenceRecordsPerSec, "requests") + ).asJava) + if (name.isEmpty) { + metricTypeMap.put(BrokerTopicStats.ReplicationBytesInPerSec, MeterWrapper(BrokerTopicStats.ReplicationBytesInPerSec, "bytes")) + metricTypeMap.put(BrokerTopicStats.ReplicationBytesOutPerSec, MeterWrapper(BrokerTopicStats.ReplicationBytesOutPerSec, "bytes")) + metricTypeMap.put(BrokerTopicStats.ReassignmentBytesInPerSec, MeterWrapper(BrokerTopicStats.ReassignmentBytesInPerSec, "bytes")) + metricTypeMap.put(BrokerTopicStats.ReassignmentBytesOutPerSec, MeterWrapper(BrokerTopicStats.ReassignmentBytesOutPerSec, "bytes")) + } + + // used for testing only + def metricMap: Map[String, MeterWrapper] = metricTypeMap.toMap + + def messagesInRate: Meter = metricTypeMap.get(BrokerTopicStats.MessagesInPerSec).meter() + + def bytesInRate: Meter = metricTypeMap.get(BrokerTopicStats.BytesInPerSec).meter() + + def bytesOutRate: Meter = metricTypeMap.get(BrokerTopicStats.BytesOutPerSec).meter() + + def bytesRejectedRate: Meter = metricTypeMap.get(BrokerTopicStats.BytesRejectedPerSec).meter() + + private[server] def replicationBytesInRate: Option[Meter] = + if (name.isEmpty) Some(metricTypeMap.get(BrokerTopicStats.ReplicationBytesInPerSec).meter()) + else None + + private[server] def replicationBytesOutRate: Option[Meter] = + if (name.isEmpty) Some(metricTypeMap.get(BrokerTopicStats.ReplicationBytesOutPerSec).meter()) + else None + + private[server] def 
reassignmentBytesInPerSec: Option[Meter] = + if (name.isEmpty) Some(metricTypeMap.get(BrokerTopicStats.ReassignmentBytesInPerSec).meter()) + else None + + private[server] def reassignmentBytesOutPerSec: Option[Meter] = + if (name.isEmpty) Some(metricTypeMap.get(BrokerTopicStats.ReassignmentBytesOutPerSec).meter()) + else None + + def failedProduceRequestRate: Meter = metricTypeMap.get(BrokerTopicStats.FailedProduceRequestsPerSec).meter() + + def failedFetchRequestRate: Meter = metricTypeMap.get(BrokerTopicStats.FailedFetchRequestsPerSec).meter() + + def totalProduceRequestRate: Meter = metricTypeMap.get(BrokerTopicStats.TotalProduceRequestsPerSec).meter() + + def totalFetchRequestRate: Meter = metricTypeMap.get(BrokerTopicStats.TotalFetchRequestsPerSec).meter() + + def fetchMessageConversionsRate: Meter = metricTypeMap.get(BrokerTopicStats.FetchMessageConversionsPerSec).meter() + + def produceMessageConversionsRate: Meter = metricTypeMap.get(BrokerTopicStats.ProduceMessageConversionsPerSec).meter() + + def noKeyCompactedTopicRecordsPerSec: Meter = metricTypeMap.get(BrokerTopicStats.NoKeyCompactedTopicRecordsPerSec).meter() + + def invalidMagicNumberRecordsPerSec: Meter = metricTypeMap.get(BrokerTopicStats.InvalidMagicNumberRecordsPerSec).meter() + + def invalidMessageCrcRecordsPerSec: Meter = metricTypeMap.get(BrokerTopicStats.InvalidMessageCrcRecordsPerSec).meter() + + def invalidOffsetOrSequenceRecordsPerSec: Meter = metricTypeMap.get(BrokerTopicStats.InvalidOffsetOrSequenceRecordsPerSec).meter() + + def closeMetric(metricType: String): Unit = { + val meter = metricTypeMap.get(metricType) + if (meter != null) + meter.close() + } + + def close(): Unit = metricTypeMap.values.foreach(_.close()) +} + +object BrokerTopicStats { + val MessagesInPerSec = "MessagesInPerSec" + val BytesInPerSec = "BytesInPerSec" + val BytesOutPerSec = "BytesOutPerSec" + val BytesRejectedPerSec = "BytesRejectedPerSec" + val ReplicationBytesInPerSec = "ReplicationBytesInPerSec" + val ReplicationBytesOutPerSec = "ReplicationBytesOutPerSec" + val FailedProduceRequestsPerSec = "FailedProduceRequestsPerSec" + val FailedFetchRequestsPerSec = "FailedFetchRequestsPerSec" + val TotalProduceRequestsPerSec = "TotalProduceRequestsPerSec" + val TotalFetchRequestsPerSec = "TotalFetchRequestsPerSec" + val FetchMessageConversionsPerSec = "FetchMessageConversionsPerSec" + val ProduceMessageConversionsPerSec = "ProduceMessageConversionsPerSec" + val ReassignmentBytesInPerSec = "ReassignmentBytesInPerSec" + val ReassignmentBytesOutPerSec = "ReassignmentBytesOutPerSec" + + // These following topics are for LogValidator for better debugging on failed records + val NoKeyCompactedTopicRecordsPerSec = "NoKeyCompactedTopicRecordsPerSec" + val InvalidMagicNumberRecordsPerSec = "InvalidMagicNumberRecordsPerSec" + val InvalidMessageCrcRecordsPerSec = "InvalidMessageCrcRecordsPerSec" + val InvalidOffsetOrSequenceRecordsPerSec = "InvalidOffsetOrSequenceRecordsPerSec" + + private val valueFactory = (k: String) => new BrokerTopicMetrics(Some(k)) +} + +class BrokerTopicStats extends Logging { + import BrokerTopicStats._ + + private val stats = new Pool[String, BrokerTopicMetrics](Some(valueFactory)) + val allTopicsStats = new BrokerTopicMetrics(None) + + def topicStats(topic: String): BrokerTopicMetrics = + stats.getAndMaybePut(topic) + + def updateReplicationBytesIn(value: Long): Unit = { + allTopicsStats.replicationBytesInRate.foreach { metric => + metric.mark(value) + } + } + + private def updateReplicationBytesOut(value: Long): Unit = { + 
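+ // replicationBytesOutRate is an Option because the meter is registered only on the aggregate (all-topics) metrics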
allTopicsStats.replicationBytesOutRate.foreach { metric => + metric.mark(value) + } + } + + def updateReassignmentBytesIn(value: Long): Unit = { + allTopicsStats.reassignmentBytesInPerSec.foreach { metric => + metric.mark(value) + } + } + + def updateReassignmentBytesOut(value: Long): Unit = { + allTopicsStats.reassignmentBytesOutPerSec.foreach { metric => + metric.mark(value) + } + } + + // This method only removes metrics only used for leader + def removeOldLeaderMetrics(topic: String): Unit = { + val topicMetrics = topicStats(topic) + if (topicMetrics != null) { + topicMetrics.closeMetric(BrokerTopicStats.MessagesInPerSec) + topicMetrics.closeMetric(BrokerTopicStats.BytesInPerSec) + topicMetrics.closeMetric(BrokerTopicStats.BytesRejectedPerSec) + topicMetrics.closeMetric(BrokerTopicStats.FailedProduceRequestsPerSec) + topicMetrics.closeMetric(BrokerTopicStats.TotalProduceRequestsPerSec) + topicMetrics.closeMetric(BrokerTopicStats.ProduceMessageConversionsPerSec) + topicMetrics.closeMetric(BrokerTopicStats.ReplicationBytesOutPerSec) + topicMetrics.closeMetric(BrokerTopicStats.ReassignmentBytesOutPerSec) + } + } + + // This method only removes metrics only used for follower + def removeOldFollowerMetrics(topic: String): Unit = { + val topicMetrics = topicStats(topic) + if (topicMetrics != null) { + topicMetrics.closeMetric(BrokerTopicStats.ReplicationBytesInPerSec) + topicMetrics.closeMetric(BrokerTopicStats.ReassignmentBytesInPerSec) + } + } + + def removeMetrics(topic: String): Unit = { + val metrics = stats.remove(topic) + if (metrics != null) + metrics.close() + } + + def updateBytesOut(topic: String, isFollower: Boolean, isReassignment: Boolean, value: Long): Unit = { + if (isFollower) { + if (isReassignment) + updateReassignmentBytesOut(value) + updateReplicationBytesOut(value) + } else { + topicStats(topic).bytesOutRate.mark(value) + allTopicsStats.bytesOutRate.mark(value) + } + } + + def close(): Unit = { + allTopicsStats.close() + stats.values.foreach(_.close()) + + info("Broker and topic stats closed") + } +} diff --git a/core/src/main/scala/kafka/server/KafkaServer.scala b/core/src/main/scala/kafka/server/KafkaServer.scala new file mode 100755 index 0000000..416895e --- /dev/null +++ b/core/src/main/scala/kafka/server/KafkaServer.scala @@ -0,0 +1,868 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package kafka.server + +import java.io.{File, IOException} +import java.net.{InetAddress, SocketTimeoutException} +import java.util.concurrent._ +import java.util.concurrent.atomic.{AtomicBoolean, AtomicInteger} + +import kafka.api.{KAFKA_0_9_0, KAFKA_2_2_IV0, KAFKA_2_4_IV1} +import kafka.cluster.{Broker, EndPoint} +import kafka.common.{GenerateBrokerIdException, InconsistentBrokerIdException, InconsistentClusterIdException} +import kafka.controller.KafkaController +import kafka.coordinator.group.GroupCoordinator +import kafka.coordinator.transaction.{ProducerIdManager, TransactionCoordinator} +import kafka.log.LogManager +import kafka.metrics.{KafkaMetricsReporter, KafkaYammerMetrics} +import kafka.network.{RequestChannel, SocketServer} +import kafka.security.CredentialProvider +import kafka.server.metadata.{ZkConfigRepository, ZkMetadataCache} +import kafka.utils._ +import kafka.zk.{AdminZkClient, BrokerInfo, KafkaZkClient} +import org.apache.kafka.clients.{ApiVersions, ManualMetadataUpdater, NetworkClient, NetworkClientUtils} +import org.apache.kafka.common.internals.Topic +import org.apache.kafka.common.message.ApiMessageType.ListenerType +import org.apache.kafka.common.message.ControlledShutdownRequestData +import org.apache.kafka.common.metrics.Metrics +import org.apache.kafka.common.network._ +import org.apache.kafka.common.protocol.Errors +import org.apache.kafka.common.requests.{ControlledShutdownRequest, ControlledShutdownResponse} +import org.apache.kafka.common.security.scram.internals.ScramMechanism +import org.apache.kafka.common.security.token.delegation.internals.DelegationTokenCache +import org.apache.kafka.common.security.{JaasContext, JaasUtils} +import org.apache.kafka.common.utils.{AppInfoParser, LogContext, Time, Utils} +import org.apache.kafka.common.{Endpoint, Node} +import org.apache.kafka.metadata.BrokerState +import org.apache.kafka.server.authorizer.Authorizer +import org.apache.zookeeper.client.ZKClientConfig + +import scala.collection.{Map, Seq} +import scala.jdk.CollectionConverters._ + +object KafkaServer { + + def zkClientConfigFromKafkaConfig(config: KafkaConfig, forceZkSslClientEnable: Boolean = false): ZKClientConfig = { + val clientConfig = new ZKClientConfig + if (config.zkSslClientEnable || forceZkSslClientEnable) { + KafkaConfig.setZooKeeperClientProperty(clientConfig, KafkaConfig.ZkSslClientEnableProp, "true") + config.zkClientCnxnSocketClassName.foreach(KafkaConfig.setZooKeeperClientProperty(clientConfig, KafkaConfig.ZkClientCnxnSocketProp, _)) + config.zkSslKeyStoreLocation.foreach(KafkaConfig.setZooKeeperClientProperty(clientConfig, KafkaConfig.ZkSslKeyStoreLocationProp, _)) + config.zkSslKeyStorePassword.foreach(x => KafkaConfig.setZooKeeperClientProperty(clientConfig, KafkaConfig.ZkSslKeyStorePasswordProp, x.value)) + config.zkSslKeyStoreType.foreach(KafkaConfig.setZooKeeperClientProperty(clientConfig, KafkaConfig.ZkSslKeyStoreTypeProp, _)) + config.zkSslTrustStoreLocation.foreach(KafkaConfig.setZooKeeperClientProperty(clientConfig, KafkaConfig.ZkSslTrustStoreLocationProp, _)) + config.zkSslTrustStorePassword.foreach(x => KafkaConfig.setZooKeeperClientProperty(clientConfig, KafkaConfig.ZkSslTrustStorePasswordProp, x.value)) + config.zkSslTrustStoreType.foreach(KafkaConfig.setZooKeeperClientProperty(clientConfig, KafkaConfig.ZkSslTrustStoreTypeProp, _)) + KafkaConfig.setZooKeeperClientProperty(clientConfig, KafkaConfig.ZkSslProtocolProp, config.ZkSslProtocol) + 
config.ZkSslEnabledProtocols.foreach(KafkaConfig.setZooKeeperClientProperty(clientConfig, KafkaConfig.ZkSslEnabledProtocolsProp, _)) + config.ZkSslCipherSuites.foreach(KafkaConfig.setZooKeeperClientProperty(clientConfig, KafkaConfig.ZkSslCipherSuitesProp, _)) + KafkaConfig.setZooKeeperClientProperty(clientConfig, KafkaConfig.ZkSslEndpointIdentificationAlgorithmProp, config.ZkSslEndpointIdentificationAlgorithm) + KafkaConfig.setZooKeeperClientProperty(clientConfig, KafkaConfig.ZkSslCrlEnableProp, config.ZkSslCrlEnable.toString) + KafkaConfig.setZooKeeperClientProperty(clientConfig, KafkaConfig.ZkSslOcspEnableProp, config.ZkSslOcspEnable.toString) + } + clientConfig + } + + val MIN_INCREMENTAL_FETCH_SESSION_EVICTION_MS: Long = 120000 +} + +/** + * Represents the lifecycle of a single Kafka broker. Handles all functionality required + * to start up and shutdown a single Kafka node. + */ +class KafkaServer( + val config: KafkaConfig, + time: Time = Time.SYSTEM, + threadNamePrefix: Option[String] = None, + enableForwarding: Boolean = false +) extends KafkaBroker with Server { + + private val startupComplete = new AtomicBoolean(false) + private val isShuttingDown = new AtomicBoolean(false) + private val isStartingUp = new AtomicBoolean(false) + + @volatile private var _brokerState: BrokerState = BrokerState.NOT_RUNNING + private var shutdownLatch = new CountDownLatch(1) + private var logContext: LogContext = null + + private val kafkaMetricsReporters: Seq[KafkaMetricsReporter] = + KafkaMetricsReporter.startReporters(VerifiableProperties(config.originals)) + var kafkaYammerMetrics: KafkaYammerMetrics = null + var metrics: Metrics = null + + @volatile var dataPlaneRequestProcessor: KafkaApis = null + var controlPlaneRequestProcessor: KafkaApis = null + + var authorizer: Option[Authorizer] = None + @volatile var socketServer: SocketServer = null + var dataPlaneRequestHandlerPool: KafkaRequestHandlerPool = null + var controlPlaneRequestHandlerPool: KafkaRequestHandlerPool = null + + var logDirFailureChannel: LogDirFailureChannel = null + @volatile private var _logManager: LogManager = null + + @volatile private var _replicaManager: ReplicaManager = null + var adminManager: ZkAdminManager = null + var tokenManager: DelegationTokenManager = null + + var dynamicConfigHandlers: Map[String, ConfigHandler] = null + var dynamicConfigManager: DynamicConfigManager = null + var credentialProvider: CredentialProvider = null + var tokenCache: DelegationTokenCache = null + + @volatile var groupCoordinator: GroupCoordinator = null + + var transactionCoordinator: TransactionCoordinator = null + + @volatile private var _kafkaController: KafkaController = null + + var forwardingManager: Option[ForwardingManager] = None + + var autoTopicCreationManager: AutoTopicCreationManager = null + + var clientToControllerChannelManager: BrokerToControllerChannelManager = null + + var alterIsrManager: AlterIsrManager = null + + var kafkaScheduler: KafkaScheduler = null + + @volatile var metadataCache: ZkMetadataCache = null + var quotaManagers: QuotaFactory.QuotaManagers = null + + val zkClientConfig: ZKClientConfig = KafkaServer.zkClientConfigFromKafkaConfig(config) + private var _zkClient: KafkaZkClient = null + private var configRepository: ZkConfigRepository = null + + val correlationId: AtomicInteger = new AtomicInteger(0) + val brokerMetaPropsFile = "meta.properties" + val brokerMetadataCheckpoints = config.logDirs.map { logDir => + (logDir, new BrokerMetadataCheckpoint(new File(logDir + File.separator + 
brokerMetaPropsFile))) + }.toMap + + private var _clusterId: String = null + @volatile var _brokerTopicStats: BrokerTopicStats = null + + private var _featureChangeListener: FinalizedFeatureChangeListener = null + + val brokerFeatures: BrokerFeatures = BrokerFeatures.createDefault() + val featureCache: FinalizedFeatureCache = new FinalizedFeatureCache(brokerFeatures) + + override def brokerState: BrokerState = _brokerState + + def clusterId: String = _clusterId + + // Visible for testing + private[kafka] def zkClient = _zkClient + + override def brokerTopicStats = _brokerTopicStats + + private[kafka] def featureChangeListener = _featureChangeListener + + override def replicaManager: ReplicaManager = _replicaManager + + override def logManager: LogManager = _logManager + + def kafkaController: KafkaController = _kafkaController + + /** + * Start up API for bringing up a single instance of the Kafka server. + * Instantiates the LogManager, the SocketServer and the request handlers - KafkaRequestHandlers + */ + override def startup(): Unit = { + try { + info("starting") + + if (isShuttingDown.get) + throw new IllegalStateException("Kafka server is still shutting down, cannot re-start!") + + if (startupComplete.get) + return + + val canStartup = isStartingUp.compareAndSet(false, true) + if (canStartup) { + _brokerState = BrokerState.STARTING + + /* setup zookeeper */ + initZkClient(time) + configRepository = new ZkConfigRepository(new AdminZkClient(zkClient)) + + /* initialize features */ + _featureChangeListener = new FinalizedFeatureChangeListener(featureCache, _zkClient) + if (config.isFeatureVersioningSupported) { + _featureChangeListener.initOrThrow(config.zkConnectionTimeoutMs) + } + + /* Get or create cluster_id */ + _clusterId = getOrGenerateClusterId(zkClient) + info(s"Cluster ID = ${clusterId}") + + /* load metadata */ + val (preloadedBrokerMetadataCheckpoint, initialOfflineDirs) = + BrokerMetadataCheckpoint.getBrokerMetadataAndOfflineDirs(config.logDirs, ignoreMissing = true) + + if (preloadedBrokerMetadataCheckpoint.version != 0) { + throw new RuntimeException(s"Found unexpected version in loaded `meta.properties`: " + + s"$preloadedBrokerMetadataCheckpoint. Zk-based brokers only support version 0 " + + "(which is implicit when the `version` field is missing).") + } + + /* check cluster id */ + if (preloadedBrokerMetadataCheckpoint.clusterId.isDefined && preloadedBrokerMetadataCheckpoint.clusterId.get != clusterId) + throw new InconsistentClusterIdException( + s"The Cluster ID ${clusterId} doesn't match stored clusterId ${preloadedBrokerMetadataCheckpoint.clusterId} in meta.properties. " + + s"The broker is trying to join the wrong cluster. Configured zookeeper.connect may be wrong.") + + /* generate brokerId */ + config.brokerId = getOrGenerateBrokerId(preloadedBrokerMetadataCheckpoint) + logContext = new LogContext(s"[KafkaServer id=${config.brokerId}] ") + this.logIdent = logContext.logPrefix + + // initialize dynamic broker configs from ZooKeeper. Any updates made after this will be + // applied after DynamicConfigManager starts. 
+ config.dynamicConfig.initialize(Some(zkClient)) + + /* start scheduler */ + kafkaScheduler = new KafkaScheduler(config.backgroundThreads) + kafkaScheduler.startup() + + /* create and configure metrics */ + kafkaYammerMetrics = KafkaYammerMetrics.INSTANCE + kafkaYammerMetrics.configure(config.originals) + metrics = Server.initializeMetrics(config, time, clusterId) + + /* register broker metrics */ + _brokerTopicStats = new BrokerTopicStats + + quotaManagers = QuotaFactory.instantiate(config, metrics, time, threadNamePrefix.getOrElse("")) + KafkaBroker.notifyClusterListeners(clusterId, kafkaMetricsReporters ++ metrics.reporters.asScala) + + logDirFailureChannel = new LogDirFailureChannel(config.logDirs.size) + + /* start log manager */ + _logManager = LogManager(config, initialOfflineDirs, + new ZkConfigRepository(new AdminZkClient(zkClient)), + kafkaScheduler, time, brokerTopicStats, logDirFailureChannel, config.usesTopicId) + _brokerState = BrokerState.RECOVERY + logManager.startup(zkClient.getAllTopicsInCluster()) + + metadataCache = MetadataCache.zkMetadataCache(config.brokerId) + // Enable delegation token cache for all SCRAM mechanisms to simplify dynamic update. + // This keeps the cache up-to-date if new SCRAM mechanisms are enabled dynamically. + tokenCache = new DelegationTokenCache(ScramMechanism.mechanismNames) + credentialProvider = new CredentialProvider(ScramMechanism.mechanismNames, tokenCache) + + clientToControllerChannelManager = BrokerToControllerChannelManager( + controllerNodeProvider = MetadataCacheControllerNodeProvider(config, metadataCache), + time = time, + metrics = metrics, + config = config, + channelName = "forwarding", + threadNamePrefix = threadNamePrefix, + retryTimeoutMs = config.requestTimeoutMs.longValue) + clientToControllerChannelManager.start() + + /* start forwarding manager */ + var autoTopicCreationChannel = Option.empty[BrokerToControllerChannelManager] + if (enableForwarding) { + this.forwardingManager = Some(ForwardingManager(clientToControllerChannelManager)) + autoTopicCreationChannel = Some(clientToControllerChannelManager) + } + + val apiVersionManager = ApiVersionManager( + ListenerType.ZK_BROKER, + config, + forwardingManager, + brokerFeatures, + featureCache + ) + + // Create and start the socket server acceptor threads so that the bound port is known. + // Delay starting processors until the end of the initialization sequence to ensure + // that credentials have been loaded before processing authentications. + // + // Note that we allow the use of KRaft mode controller APIs when forwarding is enabled + // so that the Envelope request is exposed. This is only used in testing currently. 
+ socketServer = new SocketServer(config, metrics, time, credentialProvider, apiVersionManager) + socketServer.startup(startProcessingRequests = false) + + /* start replica manager */ + alterIsrManager = if (config.interBrokerProtocolVersion.isAlterIsrSupported) { + AlterIsrManager( + config = config, + metadataCache = metadataCache, + scheduler = kafkaScheduler, + time = time, + metrics = metrics, + threadNamePrefix = threadNamePrefix, + brokerEpochSupplier = () => kafkaController.brokerEpoch, + config.brokerId + ) + } else { + AlterIsrManager(kafkaScheduler, time, zkClient) + } + alterIsrManager.start() + + _replicaManager = createReplicaManager(isShuttingDown) + replicaManager.startup() + + val brokerInfo = createBrokerInfo + val brokerEpoch = zkClient.registerBroker(brokerInfo) + + // Now that the broker is successfully registered, checkpoint its metadata + checkpointBrokerMetadata(ZkMetaProperties(clusterId, config.brokerId)) + + /* start token manager */ + tokenManager = new DelegationTokenManager(config, tokenCache, time , zkClient) + tokenManager.startup() + + /* start kafka controller */ + _kafkaController = new KafkaController(config, zkClient, time, metrics, brokerInfo, brokerEpoch, tokenManager, brokerFeatures, featureCache, threadNamePrefix) + kafkaController.startup() + + adminManager = new ZkAdminManager(config, metrics, metadataCache, zkClient) + + /* start group coordinator */ + // Hardcode Time.SYSTEM for now as some Streams tests fail otherwise, it would be good to fix the underlying issue + groupCoordinator = GroupCoordinator(config, replicaManager, Time.SYSTEM, metrics) + groupCoordinator.startup(() => zkClient.getTopicPartitionCount(Topic.GROUP_METADATA_TOPIC_NAME).getOrElse(config.offsetsTopicPartitions)) + + /* create producer ids manager */ + val producerIdManager = if (config.interBrokerProtocolVersion.isAllocateProducerIdsSupported) { + ProducerIdManager.rpc( + config.brokerId, + brokerEpochSupplier = () => kafkaController.brokerEpoch, + clientToControllerChannelManager, + config.requestTimeoutMs + ) + } else { + ProducerIdManager.zk(config.brokerId, zkClient) + } + /* start transaction coordinator, with a separate background thread scheduler for transaction expiration and log loading */ + // Hardcode Time.SYSTEM for now as some Streams tests fail otherwise, it would be good to fix the underlying issue + transactionCoordinator = TransactionCoordinator(config, replicaManager, new KafkaScheduler(threads = 1, threadNamePrefix = "transaction-log-manager-"), + () => producerIdManager, metrics, metadataCache, Time.SYSTEM) + transactionCoordinator.startup( + () => zkClient.getTopicPartitionCount(Topic.TRANSACTION_STATE_TOPIC_NAME).getOrElse(config.transactionTopicPartitions)) + + /* start auto topic creation manager */ + this.autoTopicCreationManager = AutoTopicCreationManager( + config, + metadataCache, + threadNamePrefix, + autoTopicCreationChannel, + Some(adminManager), + Some(kafkaController), + groupCoordinator, + transactionCoordinator + ) + + /* Get the authorizer and initialize it if one is specified.*/ + authorizer = config.authorizer + authorizer.foreach(_.configure(config.originals)) + val authorizerFutures: Map[Endpoint, CompletableFuture[Void]] = authorizer match { + case Some(authZ) => + authZ.start(brokerInfo.broker.toServerInfo(clusterId, config)).asScala.map { case (ep, cs) => + ep -> cs.toCompletableFuture + } + case None => + brokerInfo.broker.endPoints.map { ep => + ep.toJava -> CompletableFuture.completedFuture[Void](null) + }.toMap + } + + val 
fetchManager = new FetchManager(Time.SYSTEM, + new FetchSessionCache(config.maxIncrementalFetchSessionCacheSlots, + KafkaServer.MIN_INCREMENTAL_FETCH_SESSION_EVICTION_MS)) + + /* start processing requests */ + val zkSupport = ZkSupport(adminManager, kafkaController, zkClient, forwardingManager, metadataCache) + + def createKafkaApis(requestChannel: RequestChannel): KafkaApis = new KafkaApis( + requestChannel = requestChannel, + metadataSupport = zkSupport, + replicaManager = replicaManager, + groupCoordinator = groupCoordinator, + txnCoordinator = transactionCoordinator, + autoTopicCreationManager = autoTopicCreationManager, + brokerId = config.brokerId, + config = config, + configRepository = configRepository, + metadataCache = metadataCache, + metrics = metrics, + authorizer = authorizer, + quotas = quotaManagers, + fetchManager = fetchManager, + brokerTopicStats = brokerTopicStats, + clusterId = clusterId, + time = time, + tokenManager = tokenManager, + apiVersionManager = apiVersionManager) + + dataPlaneRequestProcessor = createKafkaApis(socketServer.dataPlaneRequestChannel) + + dataPlaneRequestHandlerPool = new KafkaRequestHandlerPool(config.brokerId, socketServer.dataPlaneRequestChannel, dataPlaneRequestProcessor, time, + config.numIoThreads, s"${SocketServer.DataPlaneMetricPrefix}RequestHandlerAvgIdlePercent", SocketServer.DataPlaneThreadPrefix) + + socketServer.controlPlaneRequestChannelOpt.foreach { controlPlaneRequestChannel => + controlPlaneRequestProcessor = createKafkaApis(controlPlaneRequestChannel) + controlPlaneRequestHandlerPool = new KafkaRequestHandlerPool(config.brokerId, socketServer.controlPlaneRequestChannelOpt.get, controlPlaneRequestProcessor, time, + 1, s"${SocketServer.ControlPlaneMetricPrefix}RequestHandlerAvgIdlePercent", SocketServer.ControlPlaneThreadPrefix) + } + + Mx4jLoader.maybeLoad() + + /* Add all reconfigurables for config change notification before starting config handlers */ + config.dynamicConfig.addReconfigurables(this) + + /* start dynamic config manager */ + dynamicConfigHandlers = Map[String, ConfigHandler](ConfigType.Topic -> new TopicConfigHandler(logManager, config, quotaManagers, Some(kafkaController)), + ConfigType.Client -> new ClientIdConfigHandler(quotaManagers), + ConfigType.User -> new UserConfigHandler(quotaManagers, credentialProvider), + ConfigType.Broker -> new BrokerConfigHandler(config, quotaManagers), + ConfigType.Ip -> new IpConfigHandler(socketServer.connectionQuotas)) + + // Create the config manager. start listening to notifications + dynamicConfigManager = new DynamicConfigManager(zkClient, dynamicConfigHandlers) + dynamicConfigManager.startup() + + socketServer.startProcessingRequests(authorizerFutures) + + _brokerState = BrokerState.RUNNING + shutdownLatch = new CountDownLatch(1) + startupComplete.set(true) + isStartingUp.set(false) + AppInfoParser.registerAppInfo(Server.MetricsPrefix, config.brokerId.toString, metrics, time.milliseconds()) + info("started") + } + } + catch { + case e: Throwable => + fatal("Fatal error during KafkaServer startup. 
Prepare to shutdown", e) + isStartingUp.set(false) + shutdown() + throw e + } + } + + protected def createReplicaManager(isShuttingDown: AtomicBoolean): ReplicaManager = { + new ReplicaManager( + metrics = metrics, + config = config, + time = time, + scheduler = kafkaScheduler, + logManager = logManager, + quotaManagers = quotaManagers, + metadataCache = metadataCache, + logDirFailureChannel = logDirFailureChannel, + alterIsrManager = alterIsrManager, + brokerTopicStats = brokerTopicStats, + isShuttingDown = isShuttingDown, + zkClient = Some(zkClient), + threadNamePrefix = threadNamePrefix) + } + + private def initZkClient(time: Time): Unit = { + info(s"Connecting to zookeeper on ${config.zkConnect}") + + val secureAclsEnabled = config.zkEnableSecureAcls + val isZkSecurityEnabled = JaasUtils.isZkSaslEnabled() || KafkaConfig.zkTlsClientAuthEnabled(zkClientConfig) + + if (secureAclsEnabled && !isZkSecurityEnabled) + throw new java.lang.SecurityException(s"${KafkaConfig.ZkEnableSecureAclsProp} is true, but ZooKeeper client TLS configuration identifying at least $KafkaConfig.ZkSslClientEnableProp, $KafkaConfig.ZkClientCnxnSocketProp, and $KafkaConfig.ZkSslKeyStoreLocationProp was not present and the " + + s"verification of the JAAS login file failed ${JaasUtils.zkSecuritySysConfigString}") + + _zkClient = KafkaZkClient(config.zkConnect, secureAclsEnabled, config.zkSessionTimeoutMs, config.zkConnectionTimeoutMs, + config.zkMaxInFlightRequests, time, name = "Kafka server", zkClientConfig = zkClientConfig, + createChrootIfNecessary = true) + _zkClient.createTopLevelPaths() + } + + private def getOrGenerateClusterId(zkClient: KafkaZkClient): String = { + zkClient.getClusterId.getOrElse(zkClient.createOrGetClusterId(CoreUtils.generateUuidAsBase64())) + } + + def createBrokerInfo: BrokerInfo = { + val endPoints = config.effectiveAdvertisedListeners.map(e => s"${e.host}:${e.port}") + zkClient.getAllBrokersInCluster.filter(_.id != config.brokerId).foreach { broker => + val commonEndPoints = broker.endPoints.map(e => s"${e.host}:${e.port}").intersect(endPoints) + require(commonEndPoints.isEmpty, s"Configured end points ${commonEndPoints.mkString(",")} in" + + s" advertised listeners are already registered by broker ${broker.id}") + } + + val listeners = config.effectiveAdvertisedListeners.map { endpoint => + if (endpoint.port == 0) + endpoint.copy(port = socketServer.boundPort(endpoint.listenerName)) + else + endpoint + } + + val updatedEndpoints = listeners.map(endpoint => + if (Utils.isBlank(endpoint.host)) + endpoint.copy(host = InetAddress.getLocalHost.getCanonicalHostName) + else + endpoint + ) + + val jmxPort = System.getProperty("com.sun.management.jmxremote.port", "-1").toInt + BrokerInfo( + Broker(config.brokerId, updatedEndpoints, config.rack, brokerFeatures.supportedFeatures), + config.interBrokerProtocolVersion, + jmxPort) + } + + /** + * Performs controlled shutdown + */ + private def controlledShutdown(): Unit = { + val socketTimeoutMs = config.controllerSocketTimeoutMs + + def doControlledShutdown(retries: Int): Boolean = { + val metadataUpdater = new ManualMetadataUpdater() + val networkClient = { + val channelBuilder = ChannelBuilders.clientChannelBuilder( + config.interBrokerSecurityProtocol, + JaasContext.Type.SERVER, + config, + config.interBrokerListenerName, + config.saslMechanismInterBrokerProtocol, + time, + config.saslInterBrokerHandshakeRequestEnable, + logContext) + val selector = new Selector( + NetworkReceive.UNLIMITED, + config.connectionsMaxIdleMs, + metrics, + time, + 
"kafka-server-controlled-shutdown", + Map.empty.asJava, + false, + channelBuilder, + logContext + ) + new NetworkClient( + selector, + metadataUpdater, + config.brokerId.toString, + 1, + 0, + 0, + Selectable.USE_DEFAULT_BUFFER_SIZE, + Selectable.USE_DEFAULT_BUFFER_SIZE, + config.requestTimeoutMs, + config.connectionSetupTimeoutMs, + config.connectionSetupTimeoutMaxMs, + time, + false, + new ApiVersions, + logContext) + } + + var shutdownSucceeded: Boolean = false + + try { + + var remainingRetries = retries + var prevController: Node = null + var ioException = false + + while (!shutdownSucceeded && remainingRetries > 0) { + remainingRetries = remainingRetries - 1 + + // 1. Find the controller and establish a connection to it. + // If the controller id or the broker registration are missing, we sleep and retry (if there are remaining retries) + metadataCache.getControllerId match { + case Some(controllerId) => + metadataCache.getAliveBrokerNode(controllerId, config.interBrokerListenerName) match { + case Some(broker) => + // if this is the first attempt, if the controller has changed or if an exception was thrown in a previous + // attempt, connect to the most recent controller + if (ioException || broker != prevController) { + + ioException = false + + if (prevController != null) + networkClient.close(prevController.idString) + + prevController = broker + metadataUpdater.setNodes(Seq(prevController).asJava) + } + case None => + info(s"Broker registration for controller $controllerId is not available in the metadata cache") + } + case None => + info("No controller present in the metadata cache") + } + + // 2. issue a controlled shutdown to the controller + if (prevController != null) { + try { + + if (!NetworkClientUtils.awaitReady(networkClient, prevController, time, socketTimeoutMs)) + throw new SocketTimeoutException(s"Failed to connect within $socketTimeoutMs ms") + + // send the controlled shutdown request + val controlledShutdownApiVersion: Short = + if (config.interBrokerProtocolVersion < KAFKA_0_9_0) 0 + else if (config.interBrokerProtocolVersion < KAFKA_2_2_IV0) 1 + else if (config.interBrokerProtocolVersion < KAFKA_2_4_IV1) 2 + else 3 + + val controlledShutdownRequest = new ControlledShutdownRequest.Builder( + new ControlledShutdownRequestData() + .setBrokerId(config.brokerId) + .setBrokerEpoch(kafkaController.brokerEpoch), + controlledShutdownApiVersion) + val request = networkClient.newClientRequest(prevController.idString, controlledShutdownRequest, + time.milliseconds(), true) + val clientResponse = NetworkClientUtils.sendAndReceive(networkClient, request, time) + + val shutdownResponse = clientResponse.responseBody.asInstanceOf[ControlledShutdownResponse] + if (shutdownResponse.error != Errors.NONE) { + info(s"Controlled shutdown request returned after ${clientResponse.requestLatencyMs}ms " + + s"with error ${shutdownResponse.error}") + } else if (shutdownResponse.data.remainingPartitions.isEmpty) { + shutdownSucceeded = true + info("Controlled shutdown request returned successfully " + + s"after ${clientResponse.requestLatencyMs}ms") + } else { + info(s"Controlled shutdown request returned after ${clientResponse.requestLatencyMs}ms " + + s"with ${shutdownResponse.data.remainingPartitions.size} partitions remaining to move") + + if (isDebugEnabled) { + debug("Remaining partitions to move during controlled shutdown: " + + s"${shutdownResponse.data.remainingPartitions}") + } + } + } + catch { + case ioe: IOException => + ioException = true + warn("Error during controlled 
shutdown, possibly because leader movement took longer than the " + + s"configured controller.socket.timeout.ms and/or request.timeout.ms: ${ioe.getMessage}") + // ignore and try again + } + } + if (!shutdownSucceeded && remainingRetries > 0) { + Thread.sleep(config.controlledShutdownRetryBackoffMs) + info(s"Retrying controlled shutdown ($remainingRetries retries remaining)") + } + } + } + finally + networkClient.close() + + shutdownSucceeded + } + + if (startupComplete.get() && config.controlledShutdownEnable) { + // We request the controller to do a controlled shutdown. On failure, we backoff for a configured period + // of time and try again for a configured number of retries. If all the attempt fails, we simply force + // the shutdown. + info("Starting controlled shutdown") + + _brokerState = BrokerState.PENDING_CONTROLLED_SHUTDOWN + + val shutdownSucceeded = doControlledShutdown(config.controlledShutdownMaxRetries.intValue) + + if (!shutdownSucceeded) + warn("Proceeding to do an unclean shutdown as all the controlled shutdown attempts failed") + } + } + + /** + * Shutdown API for shutting down a single instance of the Kafka server. + * Shuts down the LogManager, the SocketServer and the log cleaner scheduler thread + */ + override def shutdown(): Unit = { + try { + info("shutting down") + + if (isStartingUp.get) + throw new IllegalStateException("Kafka server is still starting up, cannot shut down!") + + // To ensure correct behavior under concurrent calls, we need to check `shutdownLatch` first since it gets updated + // last in the `if` block. If the order is reversed, we could shutdown twice or leave `isShuttingDown` set to + // `true` at the end of this method. + if (shutdownLatch.getCount > 0 && isShuttingDown.compareAndSet(false, true)) { + CoreUtils.swallow(controlledShutdown(), this) + _brokerState = BrokerState.SHUTTING_DOWN + + if (dynamicConfigManager != null) + CoreUtils.swallow(dynamicConfigManager.shutdown(), this) + + // Stop socket server to stop accepting any more connections and requests. + // Socket server will be shutdown towards the end of the sequence. + if (socketServer != null) + CoreUtils.swallow(socketServer.stopProcessingRequests(), this) + if (dataPlaneRequestHandlerPool != null) + CoreUtils.swallow(dataPlaneRequestHandlerPool.shutdown(), this) + if (controlPlaneRequestHandlerPool != null) + CoreUtils.swallow(controlPlaneRequestHandlerPool.shutdown(), this) + + /** + * We must shutdown the scheduler early because otherwise, the scheduler could touch other + * resources that might have been shutdown and cause exceptions. + * For example, if we didn't shutdown the scheduler first, when LogManager was closing + * partitions one by one, the scheduler might concurrently delete old segments due to + * retention. However, the old segments could have been closed by the LogManager, which would + * cause an IOException and subsequently mark logdir as offline. As a result, the broker would + * not flush the remaining partitions or write the clean shutdown marker. Ultimately, the + * broker would have to take hours to recover the log during restart. 
+ */ + if (kafkaScheduler != null) + CoreUtils.swallow(kafkaScheduler.shutdown(), this) + + if (dataPlaneRequestProcessor != null) + CoreUtils.swallow(dataPlaneRequestProcessor.close(), this) + if (controlPlaneRequestProcessor != null) + CoreUtils.swallow(controlPlaneRequestProcessor.close(), this) + CoreUtils.swallow(authorizer.foreach(_.close()), this) + if (adminManager != null) + CoreUtils.swallow(adminManager.shutdown(), this) + + if (transactionCoordinator != null) + CoreUtils.swallow(transactionCoordinator.shutdown(), this) + if (groupCoordinator != null) + CoreUtils.swallow(groupCoordinator.shutdown(), this) + + if (tokenManager != null) + CoreUtils.swallow(tokenManager.shutdown(), this) + + if (replicaManager != null) + CoreUtils.swallow(replicaManager.shutdown(), this) + + if (alterIsrManager != null) + CoreUtils.swallow(alterIsrManager.shutdown(), this) + + if (clientToControllerChannelManager != null) + CoreUtils.swallow(clientToControllerChannelManager.shutdown(), this) + + if (logManager != null) + CoreUtils.swallow(logManager.shutdown(), this) + + if (kafkaController != null) + CoreUtils.swallow(kafkaController.shutdown(), this) + + if (featureChangeListener != null) + CoreUtils.swallow(featureChangeListener.close(), this) + + if (zkClient != null) + CoreUtils.swallow(zkClient.close(), this) + + if (quotaManagers != null) + CoreUtils.swallow(quotaManagers.shutdown(), this) + + // Even though socket server is stopped much earlier, controller can generate + // response for controlled shutdown request. Shutdown server at the end to + // avoid any failures (e.g. when metrics are recorded) + if (socketServer != null) + CoreUtils.swallow(socketServer.shutdown(), this) + if (metrics != null) + CoreUtils.swallow(metrics.close(), this) + if (brokerTopicStats != null) + CoreUtils.swallow(brokerTopicStats.close(), this) + + // Clear all reconfigurable instances stored in DynamicBrokerConfig + config.dynamicConfig.clear() + + _brokerState = BrokerState.NOT_RUNNING + + startupComplete.set(false) + isShuttingDown.set(false) + CoreUtils.swallow(AppInfoParser.unregisterAppInfo(Server.MetricsPrefix, config.brokerId.toString, metrics), this) + shutdownLatch.countDown() + info("shut down completed") + } + } + catch { + case e: Throwable => + fatal("Fatal error during KafkaServer shutdown.", e) + isShuttingDown.set(false) + throw e + } + } + + /** + * After calling shutdown(), use this API to wait until the shutdown is complete + */ + override def awaitShutdown(): Unit = shutdownLatch.await() + + def getLogManager: LogManager = logManager + + override def boundPort(listenerName: ListenerName): Int = socketServer.boundPort(listenerName) + + /** Return advertised listeners with the bound port (this may differ from the configured port if the latter is `0`). */ + def advertisedListeners: Seq[EndPoint] = { + config.effectiveAdvertisedListeners.map { endPoint => + endPoint.copy(port = boundPort(endPoint.listenerName)) + } + } + + /** + * Checkpoint the BrokerMetadata to all the online log.dirs + * + * @param brokerMetadata + */ + private def checkpointBrokerMetadata(brokerMetadata: ZkMetaProperties) = { + for (logDir <- config.logDirs if logManager.isLogDirOnline(new File(logDir).getAbsolutePath)) { + val checkpoint = brokerMetadataCheckpoints(logDir) + checkpoint.write(brokerMetadata.toProperties) + } + } + + /** + * Generates new brokerId if enabled or reads from meta.properties based on following conditions + *
              + *
+ * 1. config has no broker.id provided and broker id generation is enabled: generate a broker.id based on ZooKeeper's sequence
+ * 2. config has broker.id and meta.properties also contains a broker.id: if they don't match, throw InconsistentBrokerIdException
+ * 3. config has broker.id and there is no meta.properties file: create a new meta.properties and store the broker.id
                + * + * @return The brokerId. + */ + private def getOrGenerateBrokerId(brokerMetadata: RawMetaProperties): Int = { + val brokerId = config.brokerId + + if (brokerId >= 0 && brokerMetadata.brokerId.exists(_ != brokerId)) + throw new InconsistentBrokerIdException( + s"Configured broker.id $brokerId doesn't match stored broker.id ${brokerMetadata.brokerId} in meta.properties. " + + s"If you moved your data, make sure your configured broker.id matches. " + + s"If you intend to create a new broker, you should remove all data in your data directories (log.dirs).") + else if (brokerMetadata.brokerId.isDefined) + brokerMetadata.brokerId.get + else if (brokerId < 0 && config.brokerIdGenerationEnable) // generate a new brokerId from Zookeeper + generateBrokerId() + else + brokerId + } + + /** + * Return a sequence id generated by updating the broker sequence id path in ZK. + * Users can provide brokerId in the config. To avoid conflicts between ZK generated + * sequence id and configured brokerId, we increment the generated sequence id by KafkaConfig.MaxReservedBrokerId. + */ + private def generateBrokerId(): Int = { + try { + zkClient.generateBrokerSequenceId() + config.maxReservedBrokerId + } catch { + case e: Exception => + error("Failed to generate broker.id due to ", e) + throw new GenerateBrokerIdException("Failed to generate broker.id", e) + } + } +} diff --git a/core/src/main/scala/kafka/server/LogDirFailureChannel.scala b/core/src/main/scala/kafka/server/LogDirFailureChannel.scala new file mode 100644 index 0000000..71ba9ac --- /dev/null +++ b/core/src/main/scala/kafka/server/LogDirFailureChannel.scala @@ -0,0 +1,62 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +package kafka.server + +import java.io.IOException +import java.util.concurrent.{ArrayBlockingQueue, ConcurrentHashMap} + +import kafka.utils.Logging + +/* + * LogDirFailureChannel allows an external thread to block waiting for new offline log dirs. + * + * There should be a single instance of LogDirFailureChannel accessible by any class that does disk-IO operation. + * If IOException is encountered while accessing a log directory, the corresponding class can add the log directory name + * to the LogDirFailureChannel using maybeAddOfflineLogDir(). Each log directory will be added only once. After a log + * directory is added for the first time, a thread which is blocked waiting for new offline log directories + * can take the name of the new offline log directory out of the LogDirFailureChannel and handle the log failure properly. + * An offline log directory will stay offline until the broker is restarted. 
+ * + */ +class LogDirFailureChannel(logDirNum: Int) extends Logging { + + private val offlineLogDirs = new ConcurrentHashMap[String, String] + private val offlineLogDirQueue = new ArrayBlockingQueue[String](logDirNum) + + def hasOfflineLogDir(logDir: String): Boolean = { + offlineLogDirs.containsKey(logDir) + } + + /* + * If the given logDir is not already offline, add it to the + * set of offline log dirs and enqueue it to the logDirFailureEvent queue + */ + def maybeAddOfflineLogDir(logDir: String, msg: => String, e: IOException): Unit = { + error(msg, e) + if (offlineLogDirs.putIfAbsent(logDir, logDir) == null) + offlineLogDirQueue.add(logDir) + } + + /* + * Get the next offline log dir from logDirFailureEvent queue. + * The method will wait if necessary until a new offline log directory becomes available + */ + def takeNextOfflineLogDir(): String = offlineLogDirQueue.take() + +} diff --git a/core/src/main/scala/kafka/server/LogOffsetMetadata.scala b/core/src/main/scala/kafka/server/LogOffsetMetadata.scala new file mode 100644 index 0000000..9400260 --- /dev/null +++ b/core/src/main/scala/kafka/server/LogOffsetMetadata.scala @@ -0,0 +1,84 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.server + +import kafka.log.UnifiedLog +import org.apache.kafka.common.KafkaException + +object LogOffsetMetadata { + val UnknownOffsetMetadata = LogOffsetMetadata(-1, 0, 0) + val UnknownFilePosition = -1 + + class OffsetOrdering extends Ordering[LogOffsetMetadata] { + override def compare(x: LogOffsetMetadata, y: LogOffsetMetadata): Int = { + x.offsetDiff(y).toInt + } + } + +} + +/* + * A log offset structure, including: + * 1. the message offset + * 2. the base message offset of the located segment + * 3. 
the physical position on the located segment + */ +case class LogOffsetMetadata(messageOffset: Long, + segmentBaseOffset: Long = UnifiedLog.UnknownOffset, + relativePositionInSegment: Int = LogOffsetMetadata.UnknownFilePosition) { + + // check if this offset is already on an older segment compared with the given offset + def onOlderSegment(that: LogOffsetMetadata): Boolean = { + if (messageOffsetOnly) + throw new KafkaException(s"$this cannot compare its segment info with $that since it only has message offset info") + + this.segmentBaseOffset < that.segmentBaseOffset + } + + // check if this offset is on the same segment with the given offset + def onSameSegment(that: LogOffsetMetadata): Boolean = { + if (messageOffsetOnly) + throw new KafkaException(s"$this cannot compare its segment info with $that since it only has message offset info") + + this.segmentBaseOffset == that.segmentBaseOffset + } + + // compute the number of messages between this offset to the given offset + def offsetDiff(that: LogOffsetMetadata): Long = { + this.messageOffset - that.messageOffset + } + + // compute the number of bytes between this offset to the given offset + // if they are on the same segment and this offset precedes the given offset + def positionDiff(that: LogOffsetMetadata): Int = { + if(!onSameSegment(that)) + throw new KafkaException(s"$this cannot compare its segment position with $that since they are not on the same segment") + if(messageOffsetOnly) + throw new KafkaException(s"$this cannot compare its segment position with $that since it only has message offset info") + + this.relativePositionInSegment - that.relativePositionInSegment + } + + // decide if the offset metadata only contains message offset info + def messageOffsetOnly: Boolean = { + segmentBaseOffset == UnifiedLog.UnknownOffset && relativePositionInSegment == LogOffsetMetadata.UnknownFilePosition + } + + override def toString = s"(offset=$messageOffset segment=[$segmentBaseOffset:$relativePositionInSegment])" + +} diff --git a/core/src/main/scala/kafka/server/MetadataCache.scala b/core/src/main/scala/kafka/server/MetadataCache.scala new file mode 100755 index 0000000..2e2da0c --- /dev/null +++ b/core/src/main/scala/kafka/server/MetadataCache.scala @@ -0,0 +1,105 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.server + +import kafka.admin.BrokerMetadata +import kafka.server.metadata.{KRaftMetadataCache, ZkMetadataCache} +import org.apache.kafka.common.message.{MetadataResponseData, UpdateMetadataRequestData} +import org.apache.kafka.common.network.ListenerName +import org.apache.kafka.common.{Cluster, Node, TopicPartition, Uuid} + +import java.util + +trait MetadataCache { + + /** + * Return topic metadata for a given set of topics and listener. 
See KafkaApis#handleTopicMetadataRequest for details + * on the use of the two boolean flags. + * + * @param topics The set of topics. + * @param listenerName The listener name. + * @param errorUnavailableEndpoints If true, we return an error on unavailable brokers. This is used to support + * MetadataResponse version 0. + * @param errorUnavailableListeners If true, return LEADER_NOT_AVAILABLE if the listener is not found on the leader. + * This is used for MetadataResponse versions 0-5. + * @return A collection of topic metadata. + */ + def getTopicMetadata( + topics: collection.Set[String], + listenerName: ListenerName, + errorUnavailableEndpoints: Boolean = false, + errorUnavailableListeners: Boolean = false): collection.Seq[MetadataResponseData.MetadataResponseTopic] + + def getAllTopics(): collection.Set[String] + + def getTopicPartitions(topicName: String): collection.Set[TopicPartition] + + def hasAliveBroker(brokerId: Int): Boolean + + def getAliveBrokers(): Iterable[BrokerMetadata] + + def getTopicId(topicName: String): Uuid + + def getTopicName(topicId: Uuid): Option[String] + + def getAliveBrokerNode(brokerId: Int, listenerName: ListenerName): Option[Node] + + def getAliveBrokerNodes(listenerName: ListenerName): Iterable[Node] + + def getPartitionInfo(topic: String, partitionId: Int): Option[UpdateMetadataRequestData.UpdateMetadataPartitionState] + + /** + * Return the number of partitions in the given topic, or None if the given topic does not exist. + */ + def numPartitions(topic: String): Option[Int] + + def topicNamesToIds(): util.Map[String, Uuid] + + def topicIdsToNames(): util.Map[Uuid, String] + + def topicIdInfo(): (util.Map[String, Uuid], util.Map[Uuid, String]) + + /** + * Get a partition leader's endpoint + * + * @return If the leader is known, and the listener name is available, return Some(node). If the leader is known, + * but the listener is unavailable, return Some(Node.NO_NODE). Otherwise, if the leader is not known, + * return None + */ + def getPartitionLeaderEndpoint(topic: String, partitionId: Int, listenerName: ListenerName): Option[Node] + + def getPartitionReplicaEndpoints(tp: TopicPartition, listenerName: ListenerName): Map[Int, Node] + + def getControllerId: Option[Int] + + def getClusterMetadata(clusterId: String, listenerName: ListenerName): Cluster + + def contains(topic: String): Boolean + + def contains(tp: TopicPartition): Boolean +} + +object MetadataCache { + def zkMetadataCache(brokerId: Int): ZkMetadataCache = { + new ZkMetadataCache(brokerId) + } + + def kRaftMetadataCache(brokerId: Int): KRaftMetadataCache = { + new KRaftMetadataCache(brokerId) + } +} diff --git a/core/src/main/scala/kafka/server/MetadataSupport.scala b/core/src/main/scala/kafka/server/MetadataSupport.scala new file mode 100644 index 0000000..ecacffa --- /dev/null +++ b/core/src/main/scala/kafka/server/MetadataSupport.scala @@ -0,0 +1,121 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.server + +import kafka.controller.KafkaController +import kafka.network.RequestChannel +import kafka.server.metadata.{KRaftMetadataCache, ZkMetadataCache} +import kafka.zk.{AdminZkClient, KafkaZkClient} +import org.apache.kafka.common.requests.AbstractResponse + +sealed trait MetadataSupport { + /** + * Provide a uniform way of getting to the ForwardingManager, which is a shared concept + * despite being optional when using ZooKeeper and required when using Raft + */ + val forwardingManager: Option[ForwardingManager] + + /** + * Return this instance downcast for use with ZooKeeper + * + * @param createException function to create an exception to throw + * @return this instance downcast for use with ZooKeeper + * @throws Exception if this instance is not for ZooKeeper + */ + def requireZkOrThrow(createException: => Exception): ZkSupport + + /** + * Return this instance downcast for use with Raft + * + * @param createException function to create an exception to throw + * @return this instance downcast for use with Raft + * @throws Exception if this instance is not for Raft + */ + def requireRaftOrThrow(createException: => Exception): RaftSupport + + /** + * Confirm that this instance is consistent with the given config + * + * @param config the config to check for consistency with this instance + * @throws IllegalStateException if there is an inconsistency (Raft for a ZooKeeper config or vice-versa) + */ + def ensureConsistentWith(config: KafkaConfig): Unit + + def maybeForward(request: RequestChannel.Request, + handler: RequestChannel.Request => Unit, + responseCallback: Option[AbstractResponse] => Unit): Unit + + def controllerId: Option[Int] +} + +case class ZkSupport(adminManager: ZkAdminManager, + controller: KafkaController, + zkClient: KafkaZkClient, + forwardingManager: Option[ForwardingManager], + metadataCache: ZkMetadataCache) extends MetadataSupport { + val adminZkClient = new AdminZkClient(zkClient) + + override def requireZkOrThrow(createException: => Exception): ZkSupport = this + override def requireRaftOrThrow(createException: => Exception): RaftSupport = throw createException + + override def ensureConsistentWith(config: KafkaConfig): Unit = { + if (!config.requiresZookeeper) { + throw new IllegalStateException("Config specifies Raft but metadata support instance is for ZooKeeper") + } + } + + override def maybeForward(request: RequestChannel.Request, + handler: RequestChannel.Request => Unit, + responseCallback: Option[AbstractResponse] => Unit): Unit = { + forwardingManager match { + case Some(mgr) if !request.isForwarded && !controller.isActive => mgr.forwardRequest(request, responseCallback) + case _ => handler(request) + } + } + + override def controllerId: Option[Int] = metadataCache.getControllerId +} + +case class RaftSupport(fwdMgr: ForwardingManager, metadataCache: KRaftMetadataCache) + extends MetadataSupport { + override val forwardingManager: Option[ForwardingManager] = Some(fwdMgr) + override def requireZkOrThrow(createException: => Exception): ZkSupport = throw createException + override def 
requireRaftOrThrow(createException: => Exception): RaftSupport = this + + override def ensureConsistentWith(config: KafkaConfig): Unit = { + if (config.requiresZookeeper) { + throw new IllegalStateException("Config specifies ZooKeeper but metadata support instance is for Raft") + } + } + + override def maybeForward(request: RequestChannel.Request, + handler: RequestChannel.Request => Unit, + responseCallback: Option[AbstractResponse] => Unit): Unit = { + if (!request.isForwarded) { + fwdMgr.forwardRequest(request, responseCallback) + } else { + handler(request) // will reject + } + } + + /** + * Get the broker ID to return from a MetadataResponse. This will be a broker ID, as + * described in KRaftMetadataCache#getControllerId. See that function for more details. + */ + override def controllerId: Option[Int] = metadataCache.getControllerId +} diff --git a/core/src/main/scala/kafka/server/PartitionMetadataFile.scala b/core/src/main/scala/kafka/server/PartitionMetadataFile.scala new file mode 100644 index 0000000..749b6dd --- /dev/null +++ b/core/src/main/scala/kafka/server/PartitionMetadataFile.scala @@ -0,0 +1,167 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package kafka.server + +import java.io.{BufferedReader, BufferedWriter, File, FileOutputStream, IOException, OutputStreamWriter} +import java.nio.charset.StandardCharsets +import java.nio.file.{Files, Paths} +import java.util.regex.Pattern + +import kafka.utils.Logging +import org.apache.kafka.common.Uuid +import org.apache.kafka.common.errors.{InconsistentTopicIdException, KafkaStorageException} +import org.apache.kafka.common.utils.Utils + + + +object PartitionMetadataFile { + private val PartitionMetadataFilename = "partition.metadata" + private val WhiteSpacesPattern = Pattern.compile(":\\s+") + private val CurrentVersion = 0 + + def newFile(dir: File): File = new File(dir, PartitionMetadataFilename) + + object PartitionMetadataFileFormatter { + def toFile(data: PartitionMetadata): String = { + s"version: ${data.version}\ntopic_id: ${data.topicId}" + } + + } + + class PartitionMetadataReadBuffer[T](location: String, + reader: BufferedReader, + version: Int) extends Logging { + def read(): PartitionMetadata = { + def malformedLineException(line: String) = + new IOException(s"Malformed line in checkpoint file ($location): '$line'") + + var line: String = null + var metadataTopicId: Uuid = null + try { + line = reader.readLine() + WhiteSpacesPattern.split(line) match { + case Array(_, version) => + if (version.toInt == CurrentVersion) { + line = reader.readLine() + WhiteSpacesPattern.split(line) match { + case Array(_, topicId) => metadataTopicId = Uuid.fromString(topicId) + case _ => throw malformedLineException(line) + } + if (metadataTopicId.equals(Uuid.ZERO_UUID)) { + throw new IOException(s"Invalid topic ID in partition metadata file ($location)") + } + new PartitionMetadata(CurrentVersion, metadataTopicId) + } else { + throw new IOException(s"Unrecognized version of partition metadata file ($location): " + version) + } + case _ => throw malformedLineException(line) + } + } catch { + case _: NumberFormatException => throw malformedLineException(line) + } + } + } + +} + +class PartitionMetadata(val version: Int, val topicId: Uuid) + + +class PartitionMetadataFile(val file: File, + logDirFailureChannel: LogDirFailureChannel) extends Logging { + import kafka.server.PartitionMetadataFile.{CurrentVersion, PartitionMetadataFileFormatter, PartitionMetadataReadBuffer} + + private val path = file.toPath.toAbsolutePath + private val tempPath = Paths.get(path.toString + ".tmp") + private val lock = new Object() + private val logDir = file.getParentFile.getParent + @volatile private var dirtyTopicIdOpt : Option[Uuid] = None + + /** + * Records the topic ID that will be flushed to disk. + */ + def record(topicId: Uuid): Unit = { + // Topic IDs should not differ, but we defensively check here to fail earlier in the case that the IDs somehow differ. 
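+    // record() only stages the topic ID in memory (dirtyTopicIdOpt); the actual write happens in
+    // maybeFlush(), which serializes the state via PartitionMetadataFileFormatter.toFile and then
+    // atomically swaps a temp file into place. The resulting file holds two "key: value" lines,
+    // for example (illustrative values):
+    //   version: 0
+    //   topic_id: <base64-encoded Uuid string>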
+ dirtyTopicIdOpt.foreach { dirtyTopicId => + if (dirtyTopicId != topicId) + throw new InconsistentTopicIdException(s"Tried to record topic ID $topicId to file " + + s"but had already recorded $dirtyTopicId") + } + dirtyTopicIdOpt = Some(topicId) + } + + def maybeFlush(): Unit = { + // We check dirtyTopicId first to avoid having to take the lock unnecessarily in the frequently called log append path + dirtyTopicIdOpt.foreach { _ => + // We synchronize on the actual write to disk + lock synchronized { + dirtyTopicIdOpt.foreach { topicId => + try { + // write to temp file and then swap with the existing file + val fileOutputStream = new FileOutputStream(tempPath.toFile) + val writer = new BufferedWriter(new OutputStreamWriter(fileOutputStream, StandardCharsets.UTF_8)) + try { + writer.write(PartitionMetadataFileFormatter.toFile(new PartitionMetadata(CurrentVersion, topicId))) + writer.flush() + fileOutputStream.getFD().sync() + } finally { + writer.close() + } + + Utils.atomicMoveWithFallback(tempPath, path) + } catch { + case e: IOException => + val msg = s"Error while writing to partition metadata file ${file.getAbsolutePath}" + logDirFailureChannel.maybeAddOfflineLogDir(logDir, msg, e) + throw new KafkaStorageException(msg, e) + } + dirtyTopicIdOpt = None + } + } + } + } + + def read(): PartitionMetadata = { + lock synchronized { + try { + val reader = Files.newBufferedReader(path) + try { + val partitionBuffer = new PartitionMetadataReadBuffer(file.getAbsolutePath, reader, CurrentVersion) + partitionBuffer.read() + } finally { + reader.close() + } + } catch { + case e: IOException => + val msg = s"Error while reading partition metadata file ${file.getAbsolutePath}" + logDirFailureChannel.maybeAddOfflineLogDir(logDir, msg, e) + throw new KafkaStorageException(msg, e) + } + } + } + + def exists(): Boolean = { + file.exists() + } + + def delete(): Unit = { + Files.delete(file.toPath) + } + + override def toString: String = s"PartitionMetadataFile(path=$path)" +} diff --git a/core/src/main/scala/kafka/server/QuotaFactory.scala b/core/src/main/scala/kafka/server/QuotaFactory.scala new file mode 100644 index 0000000..f3901f6 --- /dev/null +++ b/core/src/main/scala/kafka/server/QuotaFactory.scala @@ -0,0 +1,119 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package kafka.server + +import kafka.server.QuotaType._ +import kafka.utils.Logging +import org.apache.kafka.common.TopicPartition +import org.apache.kafka.common.metrics.Metrics +import org.apache.kafka.server.quota.ClientQuotaCallback +import org.apache.kafka.common.utils.Time +import org.apache.kafka.server.quota.ClientQuotaType + +object QuotaType { + case object Fetch extends QuotaType + case object Produce extends QuotaType + case object Request extends QuotaType + case object ControllerMutation extends QuotaType + case object LeaderReplication extends QuotaType + case object FollowerReplication extends QuotaType + case object AlterLogDirsReplication extends QuotaType + + def toClientQuotaType(quotaType: QuotaType): ClientQuotaType = { + quotaType match { + case QuotaType.Fetch => ClientQuotaType.FETCH + case QuotaType.Produce => ClientQuotaType.PRODUCE + case QuotaType.Request => ClientQuotaType.REQUEST + case QuotaType.ControllerMutation => ClientQuotaType.CONTROLLER_MUTATION + case _ => throw new IllegalArgumentException(s"Not a client quota type: $quotaType") + } + } +} + +sealed trait QuotaType + +object QuotaFactory extends Logging { + + object UnboundedQuota extends ReplicaQuota { + override def isThrottled(topicPartition: TopicPartition): Boolean = false + override def isQuotaExceeded: Boolean = false + def record(value: Long): Unit = () + } + + case class QuotaManagers(fetch: ClientQuotaManager, + produce: ClientQuotaManager, + request: ClientRequestQuotaManager, + controllerMutation: ControllerMutationQuotaManager, + leader: ReplicationQuotaManager, + follower: ReplicationQuotaManager, + alterLogDirs: ReplicationQuotaManager, + clientQuotaCallback: Option[ClientQuotaCallback]) { + def shutdown(): Unit = { + fetch.shutdown() + produce.shutdown() + request.shutdown() + controllerMutation.shutdown() + clientQuotaCallback.foreach(_.close()) + } + } + + def instantiate(cfg: KafkaConfig, metrics: Metrics, time: Time, threadNamePrefix: String): QuotaManagers = { + + val clientQuotaCallback = Option(cfg.getConfiguredInstance(KafkaConfig.ClientQuotaCallbackClassProp, + classOf[ClientQuotaCallback])) + QuotaManagers( + new ClientQuotaManager(clientConfig(cfg), metrics, Fetch, time, threadNamePrefix, clientQuotaCallback), + new ClientQuotaManager(clientConfig(cfg), metrics, Produce, time, threadNamePrefix, clientQuotaCallback), + new ClientRequestQuotaManager(clientConfig(cfg), metrics, time, threadNamePrefix, clientQuotaCallback), + new ControllerMutationQuotaManager(clientControllerMutationConfig(cfg), metrics, time, + threadNamePrefix, clientQuotaCallback), + new ReplicationQuotaManager(replicationConfig(cfg), metrics, LeaderReplication, time), + new ReplicationQuotaManager(replicationConfig(cfg), metrics, FollowerReplication, time), + new ReplicationQuotaManager(alterLogDirsReplicationConfig(cfg), metrics, AlterLogDirsReplication, time), + clientQuotaCallback + ) + } + + def clientConfig(cfg: KafkaConfig): ClientQuotaManagerConfig = { + ClientQuotaManagerConfig( + numQuotaSamples = cfg.numQuotaSamples, + quotaWindowSizeSeconds = cfg.quotaWindowSizeSeconds + ) + } + + def clientControllerMutationConfig(cfg: KafkaConfig): ClientQuotaManagerConfig = { + ClientQuotaManagerConfig( + numQuotaSamples = cfg.numControllerQuotaSamples, + quotaWindowSizeSeconds = cfg.controllerQuotaWindowSizeSeconds + ) + } + + def replicationConfig(cfg: KafkaConfig): ReplicationQuotaManagerConfig = { + ReplicationQuotaManagerConfig( + numQuotaSamples = cfg.numReplicationQuotaSamples, + 
quotaWindowSizeSeconds = cfg.replicationQuotaWindowSizeSeconds + ) + } + + def alterLogDirsReplicationConfig(cfg: KafkaConfig): ReplicationQuotaManagerConfig = { + ReplicationQuotaManagerConfig( + numQuotaSamples = cfg.numAlterLogDirsReplicationQuotaSamples, + quotaWindowSizeSeconds = cfg.alterLogDirsReplicationQuotaWindowSizeSeconds + ) + } + +} diff --git a/core/src/main/scala/kafka/server/ReplicaAlterLogDirsManager.scala b/core/src/main/scala/kafka/server/ReplicaAlterLogDirsManager.scala new file mode 100644 index 0000000..b45a766 --- /dev/null +++ b/core/src/main/scala/kafka/server/ReplicaAlterLogDirsManager.scala @@ -0,0 +1,58 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.server + +import kafka.cluster.BrokerEndPoint +import org.apache.kafka.common.TopicPartition + +class ReplicaAlterLogDirsManager(brokerConfig: KafkaConfig, + replicaManager: ReplicaManager, + quotaManager: ReplicationQuotaManager, + brokerTopicStats: BrokerTopicStats) + extends AbstractFetcherManager[ReplicaAlterLogDirsThread]( + name = s"ReplicaAlterLogDirsManager on broker ${brokerConfig.brokerId}", + clientId = "ReplicaAlterLogDirs", + numFetchers = brokerConfig.getNumReplicaAlterLogDirsThreads) { + + override def createFetcherThread(fetcherId: Int, sourceBroker: BrokerEndPoint): ReplicaAlterLogDirsThread = { + val threadName = s"ReplicaAlterLogDirsThread-$fetcherId" + new ReplicaAlterLogDirsThread(threadName, sourceBroker, brokerConfig, failedPartitions, replicaManager, + quotaManager, brokerTopicStats) + } + + override protected def addPartitionsToFetcherThread(fetcherThread: ReplicaAlterLogDirsThread, + initialOffsetAndEpochs: collection.Map[TopicPartition, InitialFetchState]): Unit = { + val addedPartitions = fetcherThread.addPartitions(initialOffsetAndEpochs) + val (addedInitialOffsets, notAddedInitialOffsets) = initialOffsetAndEpochs.partition { case (tp, _) => + addedPartitions.contains(tp) + } + + if (addedInitialOffsets.nonEmpty) + info(s"Added log dir fetcher for partitions with initial offsets $addedInitialOffsets") + + if (notAddedInitialOffsets.nonEmpty) + info(s"Failed to add log dir fetch for partitions ${notAddedInitialOffsets.keySet} " + + s"since the log dir reassignment has already completed") + } + + def shutdown(): Unit = { + info("shutting down") + closeAllFetchers() + info("shutdown completed") + } +} diff --git a/core/src/main/scala/kafka/server/ReplicaAlterLogDirsThread.scala b/core/src/main/scala/kafka/server/ReplicaAlterLogDirsThread.scala new file mode 100644 index 0000000..2ce33c8 --- /dev/null +++ b/core/src/main/scala/kafka/server/ReplicaAlterLogDirsThread.scala @@ -0,0 +1,314 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.server + +import kafka.api.Request +import kafka.cluster.BrokerEndPoint +import kafka.log.{LeaderOffsetIncremented, LogAppendInfo} +import kafka.server.AbstractFetcherThread.{ReplicaFetch, ResultWithPartitions} +import kafka.server.QuotaFactory.UnboundedQuota +import org.apache.kafka.common.{TopicIdPartition, TopicPartition, Uuid} +import org.apache.kafka.common.errors.KafkaStorageException +import org.apache.kafka.common.message.FetchResponseData +import org.apache.kafka.common.message.OffsetForLeaderEpochResponseData.EpochEndOffset +import org.apache.kafka.common.protocol.{ApiKeys, Errors} +import org.apache.kafka.common.requests.OffsetsForLeaderEpochResponse.UNDEFINED_EPOCH +import org.apache.kafka.common.requests.{FetchRequest, FetchResponse, RequestUtils} +import java.util +import java.util.Optional +import scala.collection.{Map, Seq, Set, mutable} +import scala.compat.java8.OptionConverters._ +import scala.jdk.CollectionConverters._ + +class ReplicaAlterLogDirsThread(name: String, + sourceBroker: BrokerEndPoint, + brokerConfig: KafkaConfig, + failedPartitions: FailedPartitions, + replicaMgr: ReplicaManager, + quota: ReplicationQuotaManager, + brokerTopicStats: BrokerTopicStats) + extends AbstractFetcherThread(name = name, + clientId = name, + sourceBroker = sourceBroker, + failedPartitions, + fetchBackOffMs = brokerConfig.replicaFetchBackoffMs, + isInterruptible = false, + brokerTopicStats) { + + private val replicaId = brokerConfig.brokerId + private val maxBytes = brokerConfig.replicaFetchResponseMaxBytes + private val fetchSize = brokerConfig.replicaFetchMaxBytes + private var inProgressPartition: Option[TopicPartition] = None + + override protected def latestEpoch(topicPartition: TopicPartition): Option[Int] = { + replicaMgr.futureLocalLogOrException(topicPartition).latestEpoch + } + + override protected def logStartOffset(topicPartition: TopicPartition): Long = { + replicaMgr.futureLocalLogOrException(topicPartition).logStartOffset + } + + override protected def logEndOffset(topicPartition: TopicPartition): Long = { + replicaMgr.futureLocalLogOrException(topicPartition).logEndOffset + } + + override protected def endOffsetForEpoch(topicPartition: TopicPartition, epoch: Int): Option[OffsetAndEpoch] = { + replicaMgr.futureLocalLogOrException(topicPartition).endOffsetForEpoch(epoch) + } + + def fetchFromLeader(fetchRequest: FetchRequest.Builder): Map[TopicPartition, FetchData] = { + var partitionData: Seq[(TopicPartition, FetchData)] = null + val request = fetchRequest.build() + + // We can build the map from the request since it contains topic IDs and names. + // Only one ID can be associated with a name and vice versa. 
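+    // The resulting Uuid -> name map is handed to request.fetchData below so that the
+    // topic-ID-keyed fetch request can be resolved back into TopicPartitions for the
+    // local read performed by replicaMgr.fetchMessages.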
+ val topicNames = new mutable.HashMap[Uuid, String]() + request.data.topics.forEach { topic => + topicNames.put(topic.topicId, topic.topic) + } + + + def processResponseCallback(responsePartitionData: Seq[(TopicIdPartition, FetchPartitionData)]): Unit = { + partitionData = responsePartitionData.map { case (tp, data) => + val abortedTransactions = data.abortedTransactions.map(_.asJava).orNull + val lastStableOffset = data.lastStableOffset.getOrElse(FetchResponse.INVALID_LAST_STABLE_OFFSET) + tp.topicPartition -> new FetchResponseData.PartitionData() + .setPartitionIndex(tp.topicPartition.partition) + .setErrorCode(data.error.code) + .setHighWatermark(data.highWatermark) + .setLastStableOffset(lastStableOffset) + .setLogStartOffset(data.logStartOffset) + .setAbortedTransactions(abortedTransactions) + .setRecords(data.records) + } + } + + val fetchData = request.fetchData(topicNames.asJava) + + replicaMgr.fetchMessages( + 0L, // timeout is 0 so that the callback will be executed immediately + Request.FutureLocalReplicaId, + request.minBytes, + request.maxBytes, + false, + fetchData.asScala.toSeq, + UnboundedQuota, + processResponseCallback, + request.isolationLevel, + None) + + if (partitionData == null) + throw new IllegalStateException(s"Failed to fetch data for partitions ${fetchData.keySet().toArray.mkString(",")}") + + partitionData.toMap + } + + // process fetched data + override def processPartitionData(topicPartition: TopicPartition, + fetchOffset: Long, + partitionData: FetchData): Option[LogAppendInfo] = { + val partition = replicaMgr.getPartitionOrException(topicPartition) + val futureLog = partition.futureLocalLogOrException + val records = toMemoryRecords(FetchResponse.recordsOrFail(partitionData)) + + if (fetchOffset != futureLog.logEndOffset) + throw new IllegalStateException("Offset mismatch for the future replica %s: fetched offset = %d, log end offset = %d.".format( + topicPartition, fetchOffset, futureLog.logEndOffset)) + + val logAppendInfo = if (records.sizeInBytes() > 0) + partition.appendRecordsToFollowerOrFutureReplica(records, isFuture = true) + else + None + + futureLog.updateHighWatermark(partitionData.highWatermark) + futureLog.maybeIncrementLogStartOffset(partitionData.logStartOffset, LeaderOffsetIncremented) + + if (partition.maybeReplaceCurrentWithFutureReplica()) + removePartitions(Set(topicPartition)) + + quota.record(records.sizeInBytes) + logAppendInfo + } + + override def addPartitions(initialFetchStates: Map[TopicPartition, InitialFetchState]): Set[TopicPartition] = { + partitionMapLock.lockInterruptibly() + try { + // It is possible that the log dir fetcher completed just before this call, so we + // filter only the partitions which still have a future log dir. 
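+      // futureLogExists is consulted while holding partitionMapLock, so a partition whose
+      // log dir move finished concurrently is not re-added to this fetcher.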
+ val filteredFetchStates = initialFetchStates.filter { case (tp, _) => + replicaMgr.futureLogExists(tp) + } + super.addPartitions(filteredFetchStates) + } finally { + partitionMapLock.unlock() + } + } + + override protected def fetchEarliestOffsetFromLeader(topicPartition: TopicPartition, leaderEpoch: Int): Long = { + val partition = replicaMgr.getPartitionOrException(topicPartition) + partition.localLogOrException.logStartOffset + } + + override protected def fetchLatestOffsetFromLeader(topicPartition: TopicPartition, leaderEpoch: Int): Long = { + val partition = replicaMgr.getPartitionOrException(topicPartition) + partition.localLogOrException.logEndOffset + } + + /** + * Fetches offset for leader epoch from local replica for each given topic partitions + * @param partitions map of topic partition -> leader epoch of the future replica + * @return map of topic partition -> end offset for a requested leader epoch + */ + override def fetchEpochEndOffsets(partitions: Map[TopicPartition, EpochData]): Map[TopicPartition, EpochEndOffset] = { + partitions.map { case (tp, epochData) => + try { + val endOffset = if (epochData.leaderEpoch == UNDEFINED_EPOCH) { + new EpochEndOffset() + .setPartition(tp.partition) + .setErrorCode(Errors.NONE.code) + } else { + val partition = replicaMgr.getPartitionOrException(tp) + partition.lastOffsetForLeaderEpoch( + currentLeaderEpoch = RequestUtils.getLeaderEpoch(epochData.currentLeaderEpoch), + leaderEpoch = epochData.leaderEpoch, + fetchOnlyFromLeader = false) + } + tp -> endOffset + } catch { + case t: Throwable => + warn(s"Error when getting EpochEndOffset for $tp", t) + tp -> new EpochEndOffset() + .setPartition(tp.partition) + .setErrorCode(Errors.forException(t).code) + } + } + } + + override protected val isOffsetForLeaderEpochSupported: Boolean = true + + override protected val isTruncationOnFetchSupported: Boolean = false + + /** + * Truncate the log for each partition based on current replica's returned epoch and offset. + * + * The logic for finding the truncation offset is the same as in ReplicaFetcherThread + * and mainly implemented in AbstractFetcherThread.getOffsetTruncationState. One difference is + * that the initial fetch offset for topic partition could be set to the truncation offset of + * the current replica if that replica truncates. Otherwise, it is high watermark as in ReplicaFetcherThread. + * + * The reason we have to follow the leader epoch approach for truncating a future replica is to + * cover the case where a future replica is offline when the current replica truncates and + * re-replicates offsets that may have already been copied to the future replica. 
In that case, + * the future replica may miss "mark for truncation" event and must use the offset for leader epoch + * exchange with the current replica to truncate to the largest common log prefix for the topic partition + */ + override def truncate(topicPartition: TopicPartition, truncationState: OffsetTruncationState): Unit = { + val partition = replicaMgr.getPartitionOrException(topicPartition) + partition.truncateTo(truncationState.offset, isFuture = true) + } + + override protected def truncateFullyAndStartAt(topicPartition: TopicPartition, offset: Long): Unit = { + val partition = replicaMgr.getPartitionOrException(topicPartition) + partition.truncateFullyAndStartAt(offset, isFuture = true) + } + + private def nextReadyPartition(partitionMap: Map[TopicPartition, PartitionFetchState]): Option[(TopicPartition, PartitionFetchState)] = { + partitionMap.filter { case (_, partitionFetchState) => + partitionFetchState.isReadyForFetch + }.reduceLeftOption { (left, right) => + if ((left._1.topic < right._1.topic) || (left._1.topic == right._1.topic && left._1.partition < right._1.partition)) + left + else + right + } + } + + private def selectPartitionToFetch(partitionMap: Map[TopicPartition, PartitionFetchState]): Option[(TopicPartition, PartitionFetchState)] = { + // Only move one partition at a time to increase its catch-up rate and thus reduce the time spent on + // moving any given replica. Replicas are selected in ascending order (lexicographically by topic) from the + // partitions that are ready to fetch. Once selected, we will continue fetching the same partition until it + // becomes unavailable or is removed. + + inProgressPartition.foreach { tp => + val fetchStateOpt = partitionMap.get(tp) + fetchStateOpt.filter(_.isReadyForFetch).foreach { fetchState => + return Some((tp, fetchState)) + } + } + + inProgressPartition = None + + val nextPartitionOpt = nextReadyPartition(partitionMap) + nextPartitionOpt.foreach { case (tp, fetchState) => + inProgressPartition = Some(tp) + info(s"Beginning/resuming copy of partition $tp from offset ${fetchState.fetchOffset}. 
" + + s"Including this partition, there are ${partitionMap.size} remaining partitions to copy by this thread.") + } + nextPartitionOpt + } + + private def buildFetchForPartition(tp: TopicPartition, fetchState: PartitionFetchState): ResultWithPartitions[Option[ReplicaFetch]] = { + val requestMap = new util.LinkedHashMap[TopicPartition, FetchRequest.PartitionData] + val partitionsWithError = mutable.Set[TopicPartition]() + + try { + val logStartOffset = replicaMgr.futureLocalLogOrException(tp).logStartOffset + val lastFetchedEpoch = if (isTruncationOnFetchSupported) + fetchState.lastFetchedEpoch.map(_.asInstanceOf[Integer]).asJava + else + Optional.empty[Integer] + val topicId = fetchState.topicId.getOrElse(Uuid.ZERO_UUID) + requestMap.put(tp, new FetchRequest.PartitionData(topicId, fetchState.fetchOffset, logStartOffset, + fetchSize, Optional.of(fetchState.currentLeaderEpoch), lastFetchedEpoch)) + } catch { + case e: KafkaStorageException => + debug(s"Failed to build fetch for $tp", e) + partitionsWithError += tp + } + + val fetchRequestOpt = if (requestMap.isEmpty) { + None + } else { + val version: Short = if (fetchState.topicId.isEmpty) + 12 + else + ApiKeys.FETCH.latestVersion + // Set maxWait and minBytes to 0 because the response should return immediately if + // the future log has caught up with the current log of the partition + val requestBuilder = FetchRequest.Builder.forReplica(version, replicaId, 0, 0, requestMap).setMaxBytes(maxBytes) + Some(ReplicaFetch(requestMap, requestBuilder)) + } + + ResultWithPartitions(fetchRequestOpt, partitionsWithError) + } + + def buildFetch(partitionMap: Map[TopicPartition, PartitionFetchState]): ResultWithPartitions[Option[ReplicaFetch]] = { + // Only include replica in the fetch request if it is not throttled. + if (quota.isQuotaExceeded) { + ResultWithPartitions(None, Set.empty) + } else { + selectPartitionToFetch(partitionMap) match { + case Some((tp, fetchState)) => + buildFetchForPartition(tp, fetchState) + case None => + ResultWithPartitions(None, Set.empty) + } + } + } + +} diff --git a/core/src/main/scala/kafka/server/ReplicaFetcherBlockingSend.scala b/core/src/main/scala/kafka/server/ReplicaFetcherBlockingSend.scala new file mode 100644 index 0000000..fd69b5a --- /dev/null +++ b/core/src/main/scala/kafka/server/ReplicaFetcherBlockingSend.scala @@ -0,0 +1,127 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package kafka.server + +import java.net.SocketTimeoutException + +import kafka.cluster.BrokerEndPoint +import org.apache.kafka.clients._ +import org.apache.kafka.common.metrics.Metrics +import org.apache.kafka.common.network._ +import org.apache.kafka.common.requests.AbstractRequest +import org.apache.kafka.common.security.JaasContext +import org.apache.kafka.common.utils.{LogContext, Time} +import org.apache.kafka.clients.{ApiVersions, ClientResponse, ManualMetadataUpdater, NetworkClient} +import org.apache.kafka.common.{Node, Reconfigurable} +import org.apache.kafka.common.requests.AbstractRequest.Builder + +import scala.jdk.CollectionConverters._ + +trait BlockingSend { + + def sendRequest(requestBuilder: AbstractRequest.Builder[_ <: AbstractRequest]): ClientResponse + + def initiateClose(): Unit + + def close(): Unit +} + +class ReplicaFetcherBlockingSend(sourceBroker: BrokerEndPoint, + brokerConfig: KafkaConfig, + metrics: Metrics, + time: Time, + fetcherId: Int, + clientId: String, + logContext: LogContext) extends BlockingSend { + + private val sourceNode = new Node(sourceBroker.id, sourceBroker.host, sourceBroker.port) + private val socketTimeout: Int = brokerConfig.replicaSocketTimeoutMs + + private val (networkClient, reconfigurableChannelBuilder) = { + val channelBuilder = ChannelBuilders.clientChannelBuilder( + brokerConfig.interBrokerSecurityProtocol, + JaasContext.Type.SERVER, + brokerConfig, + brokerConfig.interBrokerListenerName, + brokerConfig.saslMechanismInterBrokerProtocol, + time, + brokerConfig.saslInterBrokerHandshakeRequestEnable, + logContext + ) + val reconfigurableChannelBuilder = channelBuilder match { + case reconfigurable: Reconfigurable => + brokerConfig.addReconfigurable(reconfigurable) + Some(reconfigurable) + case _ => None + } + val selector = new Selector( + NetworkReceive.UNLIMITED, + brokerConfig.connectionsMaxIdleMs, + metrics, + time, + "replica-fetcher", + Map("broker-id" -> sourceBroker.id.toString, "fetcher-id" -> fetcherId.toString).asJava, + false, + channelBuilder, + logContext + ) + val networkClient = new NetworkClient( + selector, + new ManualMetadataUpdater(), + clientId, + 1, + 0, + 0, + Selectable.USE_DEFAULT_BUFFER_SIZE, + brokerConfig.replicaSocketReceiveBufferBytes, + brokerConfig.requestTimeoutMs, + brokerConfig.connectionSetupTimeoutMs, + brokerConfig.connectionSetupTimeoutMaxMs, + time, + false, + new ApiVersions, + logContext + ) + (networkClient, reconfigurableChannelBuilder) + } + + override def sendRequest(requestBuilder: Builder[_ <: AbstractRequest]): ClientResponse = { + try { + if (!NetworkClientUtils.awaitReady(networkClient, sourceNode, time, socketTimeout)) + throw new SocketTimeoutException(s"Failed to connect within $socketTimeout ms") + else { + val clientRequest = networkClient.newClientRequest(sourceBroker.id.toString, requestBuilder, + time.milliseconds(), true) + NetworkClientUtils.sendAndReceive(networkClient, clientRequest, time) + } + } + catch { + case e: Throwable => + networkClient.close(sourceBroker.id.toString) + throw e + } + } + + override def initiateClose(): Unit = { + reconfigurableChannelBuilder.foreach(brokerConfig.removeReconfigurable) + networkClient.initiateClose() + } + + def close(): Unit = { + networkClient.close() + } +} diff --git a/core/src/main/scala/kafka/server/ReplicaFetcherManager.scala b/core/src/main/scala/kafka/server/ReplicaFetcherManager.scala new file mode 100644 index 0000000..d547e1b --- /dev/null +++ b/core/src/main/scala/kafka/server/ReplicaFetcherManager.scala @@ -0,0 
+1,47 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.server + +import kafka.cluster.BrokerEndPoint +import org.apache.kafka.common.metrics.Metrics +import org.apache.kafka.common.utils.Time + +class ReplicaFetcherManager(brokerConfig: KafkaConfig, + protected val replicaManager: ReplicaManager, + metrics: Metrics, + time: Time, + threadNamePrefix: Option[String] = None, + quotaManager: ReplicationQuotaManager) + extends AbstractFetcherManager[ReplicaFetcherThread]( + name = "ReplicaFetcherManager on broker " + brokerConfig.brokerId, + clientId = "Replica", + numFetchers = brokerConfig.numReplicaFetchers) { + + override def createFetcherThread(fetcherId: Int, sourceBroker: BrokerEndPoint): ReplicaFetcherThread = { + val prefix = threadNamePrefix.map(tp => s"$tp:").getOrElse("") + val threadName = s"${prefix}ReplicaFetcherThread-$fetcherId-${sourceBroker.id}" + new ReplicaFetcherThread(threadName, fetcherId, sourceBroker, brokerConfig, failedPartitions, replicaManager, + metrics, time, quotaManager) + } + + def shutdown(): Unit = { + info("shutting down") + closeAllFetchers() + info("shutdown completed") + } +} diff --git a/core/src/main/scala/kafka/server/ReplicaFetcherThread.scala b/core/src/main/scala/kafka/server/ReplicaFetcherThread.scala new file mode 100644 index 0000000..57d89dc --- /dev/null +++ b/core/src/main/scala/kafka/server/ReplicaFetcherThread.scala @@ -0,0 +1,396 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package kafka.server + +import java.util.Collections +import java.util.Optional + +import kafka.api._ +import kafka.cluster.BrokerEndPoint +import kafka.log.{LeaderOffsetIncremented, LogAppendInfo} +import kafka.server.AbstractFetcherThread.ReplicaFetch +import kafka.server.AbstractFetcherThread.ResultWithPartitions +import kafka.utils.Implicits._ +import org.apache.kafka.clients.FetchSessionHandler +import org.apache.kafka.common.{TopicPartition, Uuid} +import org.apache.kafka.common.errors.KafkaStorageException +import org.apache.kafka.common.message.ListOffsetsRequestData.{ListOffsetsPartition, ListOffsetsTopic} +import org.apache.kafka.common.message.OffsetForLeaderEpochRequestData.OffsetForLeaderTopic +import org.apache.kafka.common.message.OffsetForLeaderEpochRequestData.OffsetForLeaderTopicCollection +import org.apache.kafka.common.message.OffsetForLeaderEpochResponseData.EpochEndOffset +import org.apache.kafka.common.metrics.Metrics +import org.apache.kafka.common.protocol.Errors +import org.apache.kafka.common.record.MemoryRecords +import org.apache.kafka.common.requests._ +import org.apache.kafka.common.utils.{LogContext, Time} + +import scala.jdk.CollectionConverters._ +import scala.collection.{Map, mutable} +import scala.compat.java8.OptionConverters._ + +class ReplicaFetcherThread(name: String, + fetcherId: Int, + sourceBroker: BrokerEndPoint, + brokerConfig: KafkaConfig, + failedPartitions: FailedPartitions, + replicaMgr: ReplicaManager, + metrics: Metrics, + time: Time, + quota: ReplicaQuota, + leaderEndpointBlockingSend: Option[BlockingSend] = None) + extends AbstractFetcherThread(name = name, + clientId = name, + sourceBroker = sourceBroker, + failedPartitions, + fetchBackOffMs = brokerConfig.replicaFetchBackoffMs, + isInterruptible = false, + replicaMgr.brokerTopicStats) { + + private val replicaId = brokerConfig.brokerId + private val logContext = new LogContext(s"[ReplicaFetcher replicaId=$replicaId, leaderId=${sourceBroker.id}, " + + s"fetcherId=$fetcherId] ") + this.logIdent = logContext.logPrefix + + private val leaderEndpoint = leaderEndpointBlockingSend.getOrElse( + new ReplicaFetcherBlockingSend(sourceBroker, brokerConfig, metrics, time, fetcherId, + s"broker-$replicaId-fetcher-$fetcherId", logContext)) + + // Visible for testing + private[server] val fetchRequestVersion: Short = + if (brokerConfig.interBrokerProtocolVersion >= KAFKA_3_1_IV0) 13 + else if (brokerConfig.interBrokerProtocolVersion >= KAFKA_2_7_IV1) 12 + else if (brokerConfig.interBrokerProtocolVersion >= KAFKA_2_3_IV1) 11 + else if (brokerConfig.interBrokerProtocolVersion >= KAFKA_2_1_IV2) 10 + else if (brokerConfig.interBrokerProtocolVersion >= KAFKA_2_0_IV1) 8 + else if (brokerConfig.interBrokerProtocolVersion >= KAFKA_1_1_IV0) 7 + else if (brokerConfig.interBrokerProtocolVersion >= KAFKA_0_11_0_IV1) 5 + else if (brokerConfig.interBrokerProtocolVersion >= KAFKA_0_11_0_IV0) 4 + else if (brokerConfig.interBrokerProtocolVersion >= KAFKA_0_10_1_IV1) 3 + else if (brokerConfig.interBrokerProtocolVersion >= KAFKA_0_10_0_IV0) 2 + else if (brokerConfig.interBrokerProtocolVersion >= KAFKA_0_9_0) 1 + else 0 + + // Visible for testing + private[server] val offsetForLeaderEpochRequestVersion: Short = + if (brokerConfig.interBrokerProtocolVersion >= KAFKA_2_8_IV0) 4 + else if (brokerConfig.interBrokerProtocolVersion >= KAFKA_2_3_IV1) 3 + else if (brokerConfig.interBrokerProtocolVersion >= KAFKA_2_1_IV1) 2 + else if (brokerConfig.interBrokerProtocolVersion >= KAFKA_2_0_IV0) 1 + else 0 + + // Visible for 
testing + private[server] val listOffsetRequestVersion: Short = + if (brokerConfig.interBrokerProtocolVersion >= KAFKA_3_0_IV1) 7 + else if (brokerConfig.interBrokerProtocolVersion >= KAFKA_2_8_IV0) 6 + else if (brokerConfig.interBrokerProtocolVersion >= KAFKA_2_2_IV1) 5 + else if (brokerConfig.interBrokerProtocolVersion >= KAFKA_2_1_IV1) 4 + else if (brokerConfig.interBrokerProtocolVersion >= KAFKA_2_0_IV1) 3 + else if (brokerConfig.interBrokerProtocolVersion >= KAFKA_0_11_0_IV0) 2 + else if (brokerConfig.interBrokerProtocolVersion >= KAFKA_0_10_1_IV2) 1 + else 0 + + private val maxWait = brokerConfig.replicaFetchWaitMaxMs + private val minBytes = brokerConfig.replicaFetchMinBytes + private val maxBytes = brokerConfig.replicaFetchResponseMaxBytes + private val fetchSize = brokerConfig.replicaFetchMaxBytes + override protected val isOffsetForLeaderEpochSupported: Boolean = brokerConfig.interBrokerProtocolVersion >= KAFKA_0_11_0_IV2 + override protected val isTruncationOnFetchSupported = ApiVersion.isTruncationOnFetchSupported(brokerConfig.interBrokerProtocolVersion) + val fetchSessionHandler = new FetchSessionHandler(logContext, sourceBroker.id) + + override protected def latestEpoch(topicPartition: TopicPartition): Option[Int] = { + replicaMgr.localLogOrException(topicPartition).latestEpoch + } + + override protected def logStartOffset(topicPartition: TopicPartition): Long = { + replicaMgr.localLogOrException(topicPartition).logStartOffset + } + + override protected def logEndOffset(topicPartition: TopicPartition): Long = { + replicaMgr.localLogOrException(topicPartition).logEndOffset + } + + override protected def endOffsetForEpoch(topicPartition: TopicPartition, epoch: Int): Option[OffsetAndEpoch] = { + replicaMgr.localLogOrException(topicPartition).endOffsetForEpoch(epoch) + } + + override def initiateShutdown(): Boolean = { + val justShutdown = super.initiateShutdown() + if (justShutdown) { + // This is thread-safe, so we don't expect any exceptions, but catch and log any errors + // to avoid failing the caller, especially during shutdown. We will attempt to close + // leaderEndpoint after the thread terminates. + try { + leaderEndpoint.initiateClose() + } catch { + case t: Throwable => + error(s"Failed to initiate shutdown of leader endpoint $leaderEndpoint after initiating replica fetcher thread shutdown", t) + } + } + justShutdown + } + + override def awaitShutdown(): Unit = { + super.awaitShutdown() + // We don't expect any exceptions here, but catch and log any errors to avoid failing the caller, + // especially during shutdown. It is safe to catch the exception here without causing correctness + // issue because we are going to shutdown the thread and will not re-use the leaderEndpoint anyway. 
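+    // initiateShutdown() has already invoked leaderEndpoint.initiateClose(); now that the
+    // fetcher thread has fully stopped, close() releases the underlying network client.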
+ try { + leaderEndpoint.close() + } catch { + case t: Throwable => + error(s"Failed to close leader endpoint $leaderEndpoint after shutting down replica fetcher thread", t) + } + } + + // process fetched data + override def processPartitionData(topicPartition: TopicPartition, + fetchOffset: Long, + partitionData: FetchData): Option[LogAppendInfo] = { + val logTrace = isTraceEnabled + val partition = replicaMgr.getPartitionOrException(topicPartition) + val log = partition.localLogOrException + val records = toMemoryRecords(FetchResponse.recordsOrFail(partitionData)) + + maybeWarnIfOversizedRecords(records, topicPartition) + + if (fetchOffset != log.logEndOffset) + throw new IllegalStateException("Offset mismatch for partition %s: fetched offset = %d, log end offset = %d.".format( + topicPartition, fetchOffset, log.logEndOffset)) + + if (logTrace) + trace("Follower has replica log end offset %d for partition %s. Received %d messages and leader hw %d" + .format(log.logEndOffset, topicPartition, records.sizeInBytes, partitionData.highWatermark)) + + // Append the leader's messages to the log + val logAppendInfo = partition.appendRecordsToFollowerOrFutureReplica(records, isFuture = false) + + if (logTrace) + trace("Follower has replica log end offset %d after appending %d bytes of messages for partition %s" + .format(log.logEndOffset, records.sizeInBytes, topicPartition)) + val leaderLogStartOffset = partitionData.logStartOffset + + // For the follower replica, we do not need to keep its segment base offset and physical position. + // These values will be computed upon becoming leader or handling a preferred read replica fetch. + val followerHighWatermark = log.updateHighWatermark(partitionData.highWatermark) + log.maybeIncrementLogStartOffset(leaderLogStartOffset, LeaderOffsetIncremented) + if (logTrace) + trace(s"Follower set replica high watermark for partition $topicPartition to $followerHighWatermark") + + // Traffic from both in-sync and out of sync replicas are accounted for in replication quota to ensure total replication + // traffic doesn't exceed quota. + if (quota.isThrottled(topicPartition)) + quota.record(records.sizeInBytes) + + if (partition.isReassigning && partition.isAddingLocalReplica) + brokerTopicStats.updateReassignmentBytesIn(records.sizeInBytes) + + brokerTopicStats.updateReplicationBytesIn(records.sizeInBytes) + + logAppendInfo + } + + def maybeWarnIfOversizedRecords(records: MemoryRecords, topicPartition: TopicPartition): Unit = { + // oversized messages don't cause replication to fail from fetch request version 3 (KIP-74) + if (fetchRequestVersion <= 2 && records.sizeInBytes > 0 && records.validBytes <= 0) + error(s"Replication is failing due to a message that is greater than replica.fetch.max.bytes for partition $topicPartition. " + + "This generally occurs when the max.message.bytes has been overridden to exceed this value and a suitably large " + + "message has also been sent. 
To fix this problem increase replica.fetch.max.bytes in your broker config to be " + + "equal or larger than your settings for max.message.bytes, both at a broker and topic level.") + } + + + override protected def fetchFromLeader(fetchRequest: FetchRequest.Builder): Map[TopicPartition, FetchData] = { + val clientResponse = try { + leaderEndpoint.sendRequest(fetchRequest) + } catch { + case t: Throwable => + fetchSessionHandler.handleError(t) + throw t + } + val fetchResponse = clientResponse.responseBody.asInstanceOf[FetchResponse] + if (!fetchSessionHandler.handleResponse(fetchResponse, clientResponse.requestHeader().apiVersion())) { + // If we had a session topic ID related error, throw it, otherwise return an empty fetch data map. + if (fetchResponse.error == Errors.FETCH_SESSION_TOPIC_ID_ERROR) { + throw Errors.forCode(fetchResponse.error().code()).exception() + } else { + Map.empty + } + } else { + fetchResponse.responseData(fetchSessionHandler.sessionTopicNames, clientResponse.requestHeader().apiVersion()).asScala + } + } + + override protected def fetchEarliestOffsetFromLeader(topicPartition: TopicPartition, currentLeaderEpoch: Int): Long = { + fetchOffsetFromLeader(topicPartition, currentLeaderEpoch, ListOffsetsRequest.EARLIEST_TIMESTAMP) + } + + override protected def fetchLatestOffsetFromLeader(topicPartition: TopicPartition, currentLeaderEpoch: Int): Long = { + fetchOffsetFromLeader(topicPartition, currentLeaderEpoch, ListOffsetsRequest.LATEST_TIMESTAMP) + } + + private def fetchOffsetFromLeader(topicPartition: TopicPartition, currentLeaderEpoch: Int, earliestOrLatest: Long): Long = { + val topic = new ListOffsetsTopic() + .setName(topicPartition.topic) + .setPartitions(Collections.singletonList( + new ListOffsetsPartition() + .setPartitionIndex(topicPartition.partition) + .setCurrentLeaderEpoch(currentLeaderEpoch) + .setTimestamp(earliestOrLatest))) + val requestBuilder = ListOffsetsRequest.Builder.forReplica(listOffsetRequestVersion, replicaId) + .setTargetTimes(Collections.singletonList(topic)) + + val clientResponse = leaderEndpoint.sendRequest(requestBuilder) + val response = clientResponse.responseBody.asInstanceOf[ListOffsetsResponse] + val responsePartition = response.topics.asScala.find(_.name == topicPartition.topic).get + .partitions.asScala.find(_.partitionIndex == topicPartition.partition).get + + Errors.forCode(responsePartition.errorCode) match { + case Errors.NONE => + if (brokerConfig.interBrokerProtocolVersion >= KAFKA_0_10_1_IV2) + responsePartition.offset + else + responsePartition.oldStyleOffsets.get(0) + case error => throw error.exception + } + } + + override def buildFetch(partitionMap: Map[TopicPartition, PartitionFetchState]): ResultWithPartitions[Option[ReplicaFetch]] = { + val partitionsWithError = mutable.Set[TopicPartition]() + + val builder = fetchSessionHandler.newBuilder(partitionMap.size, false) + partitionMap.forKeyValue { (topicPartition, fetchState) => + // We will not include a replica in the fetch request if it should be throttled. 
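+      // shouldFollowerThrottle (defined at the bottom of this class) only skips a replica that is
+      // out of sync and in the throttled replica list while the follower quota is exceeded, so
+      // in-sync followers keep fetching even when the replication quota is saturated.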
+ if (fetchState.isReadyForFetch && !shouldFollowerThrottle(quota, fetchState, topicPartition)) { + try { + val logStartOffset = this.logStartOffset(topicPartition) + val lastFetchedEpoch = if (isTruncationOnFetchSupported) + fetchState.lastFetchedEpoch.map(_.asInstanceOf[Integer]).asJava + else + Optional.empty[Integer] + builder.add(topicPartition, new FetchRequest.PartitionData( + fetchState.topicId.getOrElse(Uuid.ZERO_UUID), + fetchState.fetchOffset, + logStartOffset, + fetchSize, + Optional.of(fetchState.currentLeaderEpoch), + lastFetchedEpoch)) + } catch { + case _: KafkaStorageException => + // The replica has already been marked offline due to log directory failure and the original failure should have already been logged. + // This partition should be removed from ReplicaFetcherThread soon by ReplicaManager.handleLogDirFailure() + partitionsWithError += topicPartition + } + } + } + + val fetchData = builder.build() + val fetchRequestOpt = if (fetchData.sessionPartitions.isEmpty && fetchData.toForget.isEmpty) { + None + } else { + val version: Short = if (fetchRequestVersion >= 13 && !fetchData.canUseTopicIds) 12 else fetchRequestVersion + val requestBuilder = FetchRequest.Builder + .forReplica(version, replicaId, maxWait, minBytes, fetchData.toSend) + .setMaxBytes(maxBytes) + .removed(fetchData.toForget) + .replaced(fetchData.toReplace) + .metadata(fetchData.metadata) + Some(ReplicaFetch(fetchData.sessionPartitions(), requestBuilder)) + } + + ResultWithPartitions(fetchRequestOpt, partitionsWithError) + } + + /** + * Truncate the log for each partition's epoch based on leader's returned epoch and offset. + * The logic for finding the truncation offset is implemented in AbstractFetcherThread.getOffsetTruncationState + */ + override def truncate(tp: TopicPartition, offsetTruncationState: OffsetTruncationState): Unit = { + val partition = replicaMgr.getPartitionOrException(tp) + val log = partition.localLogOrException + + partition.truncateTo(offsetTruncationState.offset, isFuture = false) + + if (offsetTruncationState.offset < log.highWatermark) + warn(s"Truncating $tp to offset ${offsetTruncationState.offset} below high watermark " + + s"${log.highWatermark}") + + // mark the future replica for truncation only when we do last truncation + if (offsetTruncationState.truncationCompleted) + replicaMgr.replicaAlterLogDirsManager.markPartitionsForTruncation(brokerConfig.brokerId, tp, + offsetTruncationState.offset) + } + + override protected def truncateFullyAndStartAt(topicPartition: TopicPartition, offset: Long): Unit = { + val partition = replicaMgr.getPartitionOrException(topicPartition) + partition.truncateFullyAndStartAt(offset, isFuture = false) + } + + override def fetchEpochEndOffsets(partitions: Map[TopicPartition, EpochData]): Map[TopicPartition, EpochEndOffset] = { + + if (partitions.isEmpty) { + debug("Skipping leaderEpoch request since all partitions do not have an epoch") + return Map.empty + } + + val topics = new OffsetForLeaderTopicCollection(partitions.size) + partitions.forKeyValue { (topicPartition, epochData) => + var topic = topics.find(topicPartition.topic) + if (topic == null) { + topic = new OffsetForLeaderTopic().setTopic(topicPartition.topic) + topics.add(topic) + } + topic.partitions.add(epochData) + } + + val epochRequest = OffsetsForLeaderEpochRequest.Builder.forFollower( + offsetForLeaderEpochRequestVersion, topics, brokerConfig.brokerId) + debug(s"Sending offset for leader epoch request $epochRequest") + + try { + val response = 
leaderEndpoint.sendRequest(epochRequest) + val responseBody = response.responseBody.asInstanceOf[OffsetsForLeaderEpochResponse] + debug(s"Received leaderEpoch response $response") + responseBody.data.topics.asScala.flatMap { offsetForLeaderTopicResult => + offsetForLeaderTopicResult.partitions.asScala.map { offsetForLeaderPartitionResult => + val tp = new TopicPartition(offsetForLeaderTopicResult.topic, offsetForLeaderPartitionResult.partition) + tp -> offsetForLeaderPartitionResult + } + }.toMap + } catch { + case t: Throwable => + warn(s"Error when sending leader epoch request for $partitions", t) + + // if we get any unexpected exception, mark all partitions with an error + val error = Errors.forException(t) + partitions.map { case (tp, _) => + tp -> new EpochEndOffset() + .setPartition(tp.partition) + .setErrorCode(error.code) + } + } + } + + /** + * To avoid ISR thrashing, we only throttle a replica on the follower if it's in the throttled replica list, + * the quota is exceeded and the replica is not in sync. + */ + private def shouldFollowerThrottle(quota: ReplicaQuota, fetchState: PartitionFetchState, topicPartition: TopicPartition): Boolean = { + !fetchState.isReplicaInSync && quota.isThrottled(topicPartition) && quota.isQuotaExceeded + } + +} diff --git a/core/src/main/scala/kafka/server/ReplicaManager.scala b/core/src/main/scala/kafka/server/ReplicaManager.scala new file mode 100644 index 0000000..9f8d923 --- /dev/null +++ b/core/src/main/scala/kafka/server/ReplicaManager.scala @@ -0,0 +1,2284 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package kafka.server + +import java.io.File +import java.util.Optional +import java.util.concurrent.TimeUnit +import java.util.concurrent.atomic.AtomicBoolean +import java.util.concurrent.locks.Lock +import com.yammer.metrics.core.Meter +import kafka.api._ +import kafka.cluster.{BrokerEndPoint, Partition} +import kafka.common.RecordValidationException +import kafka.controller.{KafkaController, StateChangeLogger} +import kafka.log._ +import kafka.metrics.KafkaMetricsGroup +import kafka.server.{FetchMetadata => SFetchMetadata} +import kafka.server.HostedPartition.Online +import kafka.server.QuotaFactory.QuotaManagers +import kafka.server.checkpoints.{LazyOffsetCheckpoints, OffsetCheckpointFile, OffsetCheckpoints} +import kafka.server.metadata.ZkMetadataCache +import kafka.utils._ +import kafka.utils.Implicits._ +import kafka.zk.KafkaZkClient +import org.apache.kafka.common.{ElectionType, IsolationLevel, Node, TopicIdPartition, TopicPartition, Uuid} +import org.apache.kafka.common.errors._ +import org.apache.kafka.common.internals.Topic +import org.apache.kafka.common.message.LeaderAndIsrRequestData.LeaderAndIsrPartitionState +import org.apache.kafka.common.message.DeleteRecordsResponseData.DeleteRecordsPartitionResult +import org.apache.kafka.common.message.{DescribeLogDirsResponseData, DescribeProducersResponseData, FetchResponseData, LeaderAndIsrResponseData} +import org.apache.kafka.common.message.LeaderAndIsrResponseData.LeaderAndIsrTopicError +import org.apache.kafka.common.message.LeaderAndIsrResponseData.LeaderAndIsrPartitionError +import org.apache.kafka.common.message.OffsetForLeaderEpochRequestData.OffsetForLeaderTopic +import org.apache.kafka.common.message.OffsetForLeaderEpochResponseData.{EpochEndOffset, OffsetForLeaderTopicResult} +import org.apache.kafka.common.message.StopReplicaRequestData.StopReplicaPartitionState +import org.apache.kafka.common.metrics.Metrics +import org.apache.kafka.common.network.ListenerName +import org.apache.kafka.common.protocol.Errors +import org.apache.kafka.common.record.FileRecords.TimestampAndOffset +import org.apache.kafka.common.record._ +import org.apache.kafka.common.replica.PartitionView.DefaultPartitionView +import org.apache.kafka.common.replica.ReplicaView.DefaultReplicaView +import org.apache.kafka.common.replica.{ClientMetadata, _} +import org.apache.kafka.common.requests.FetchRequest.PartitionData +import org.apache.kafka.common.requests.ProduceResponse.PartitionResponse +import org.apache.kafka.common.requests._ +import org.apache.kafka.common.utils.Time +import org.apache.kafka.image.{LocalReplicaChanges, MetadataImage, TopicsDelta} + +import scala.jdk.CollectionConverters._ +import scala.collection.{Map, Seq, Set, mutable} +import scala.compat.java8.OptionConverters._ + +/* + * Result metadata of a log append operation on the log + */ +case class LogAppendResult(info: LogAppendInfo, exception: Option[Throwable] = None) { + def error: Errors = exception match { + case None => Errors.NONE + case Some(e) => Errors.forException(e) + } +} + +case class LogDeleteRecordsResult(requestedOffset: Long, lowWatermark: Long, exception: Option[Throwable] = None) { + def error: Errors = exception match { + case None => Errors.NONE + case Some(e) => Errors.forException(e) + } +} + +/** + * Result metadata of a log read operation on the log + * @param info @FetchDataInfo returned by the @Log read + * @param divergingEpoch Optional epoch and end offset which indicates the largest epoch such + * that subsequent records are known to diverge on 
the follower/consumer + * @param highWatermark high watermark of the local replica + * @param leaderLogStartOffset The log start offset of the leader at the time of the read + * @param leaderLogEndOffset The log end offset of the leader at the time of the read + * @param followerLogStartOffset The log start offset of the follower taken from the Fetch request + * @param fetchTimeMs The time the fetch was received + * @param lastStableOffset Current LSO or None if the result has an exception + * @param preferredReadReplica the preferred read replica to be used for future fetches + * @param exception Exception if error encountered while reading from the log + */ +case class LogReadResult(info: FetchDataInfo, + divergingEpoch: Option[FetchResponseData.EpochEndOffset], + highWatermark: Long, + leaderLogStartOffset: Long, + leaderLogEndOffset: Long, + followerLogStartOffset: Long, + fetchTimeMs: Long, + lastStableOffset: Option[Long], + preferredReadReplica: Option[Int] = None, + exception: Option[Throwable] = None) { + + def error: Errors = exception match { + case None => Errors.NONE + case Some(e) => Errors.forException(e) + } + + def toFetchPartitionData(isReassignmentFetch: Boolean): FetchPartitionData = FetchPartitionData( + this.error, + this.highWatermark, + this.leaderLogStartOffset, + this.info.records, + this.divergingEpoch, + this.lastStableOffset, + this.info.abortedTransactions, + this.preferredReadReplica, + isReassignmentFetch) + + def withEmptyFetchInfo: LogReadResult = + copy(info = FetchDataInfo(LogOffsetMetadata.UnknownOffsetMetadata, MemoryRecords.EMPTY)) + + override def toString = { + "LogReadResult(" + + s"info=$info, " + + s"divergingEpoch=$divergingEpoch, " + + s"highWatermark=$highWatermark, " + + s"leaderLogStartOffset=$leaderLogStartOffset, " + + s"leaderLogEndOffset=$leaderLogEndOffset, " + + s"followerLogStartOffset=$followerLogStartOffset, " + + s"fetchTimeMs=$fetchTimeMs, " + + s"preferredReadReplica=$preferredReadReplica, " + + s"lastStableOffset=$lastStableOffset, " + + s"error=$error" + + ")" + } + +} + +case class FetchPartitionData(error: Errors = Errors.NONE, + highWatermark: Long, + logStartOffset: Long, + records: Records, + divergingEpoch: Option[FetchResponseData.EpochEndOffset], + lastStableOffset: Option[Long], + abortedTransactions: Option[List[FetchResponseData.AbortedTransaction]], + preferredReadReplica: Option[Int], + isReassignmentFetch: Boolean) + +/** + * Trait to represent the state of hosted partitions. We create a concrete (active) Partition + * instance when the broker receives a LeaderAndIsr request from the controller or a metadata + * log record from the Quorum controller indicating that the broker should be either a leader + * or follower of a partition. + */ +sealed trait HostedPartition + +object HostedPartition { + /** + * This broker does not have any state for this partition locally. + */ + final object None extends HostedPartition + + /** + * This broker hosts the partition and it is online. + */ + final case class Online(partition: Partition) extends HostedPartition + + /** + * This broker hosts the partition, but it is in an offline log directory. 
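+ * Requests that address such a partition are generally answered with KAFKA_STORAGE_ERROR
+ * (see, for example, getPartitionOrError below).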
+ */ + final object Offline extends HostedPartition +} + +object ReplicaManager { + val HighWatermarkFilename = "replication-offset-checkpoint" +} + +class ReplicaManager(val config: KafkaConfig, + metrics: Metrics, + time: Time, + scheduler: Scheduler, + val logManager: LogManager, + quotaManagers: QuotaManagers, + val metadataCache: MetadataCache, + logDirFailureChannel: LogDirFailureChannel, + val alterIsrManager: AlterIsrManager, + val brokerTopicStats: BrokerTopicStats = new BrokerTopicStats(), + val isShuttingDown: AtomicBoolean = new AtomicBoolean(false), + val zkClient: Option[KafkaZkClient] = None, + delayedProducePurgatoryParam: Option[DelayedOperationPurgatory[DelayedProduce]] = None, + delayedFetchPurgatoryParam: Option[DelayedOperationPurgatory[DelayedFetch]] = None, + delayedDeleteRecordsPurgatoryParam: Option[DelayedOperationPurgatory[DelayedDeleteRecords]] = None, + delayedElectLeaderPurgatoryParam: Option[DelayedOperationPurgatory[DelayedElectLeader]] = None, + threadNamePrefix: Option[String] = None, + ) extends Logging with KafkaMetricsGroup { + + val delayedProducePurgatory = delayedProducePurgatoryParam.getOrElse( + DelayedOperationPurgatory[DelayedProduce]( + purgatoryName = "Produce", brokerId = config.brokerId, + purgeInterval = config.producerPurgatoryPurgeIntervalRequests)) + val delayedFetchPurgatory = delayedFetchPurgatoryParam.getOrElse( + DelayedOperationPurgatory[DelayedFetch]( + purgatoryName = "Fetch", brokerId = config.brokerId, + purgeInterval = config.fetchPurgatoryPurgeIntervalRequests)) + val delayedDeleteRecordsPurgatory = delayedDeleteRecordsPurgatoryParam.getOrElse( + DelayedOperationPurgatory[DelayedDeleteRecords]( + purgatoryName = "DeleteRecords", brokerId = config.brokerId, + purgeInterval = config.deleteRecordsPurgatoryPurgeIntervalRequests)) + val delayedElectLeaderPurgatory = delayedElectLeaderPurgatoryParam.getOrElse( + DelayedOperationPurgatory[DelayedElectLeader]( + purgatoryName = "ElectLeader", brokerId = config.brokerId)) + + /* epoch of the controller that last changed the leader */ + @volatile private[server] var controllerEpoch: Int = KafkaController.InitialControllerEpoch + protected val localBrokerId = config.brokerId + protected val allPartitions = new Pool[TopicPartition, HostedPartition]( + valueFactory = Some(tp => HostedPartition.Online(Partition(tp, time, this))) + ) + protected val replicaStateChangeLock = new Object + val replicaFetcherManager = createReplicaFetcherManager(metrics, time, threadNamePrefix, quotaManagers.follower) + private[server] val replicaAlterLogDirsManager = createReplicaAlterLogDirsManager(quotaManagers.alterLogDirs, brokerTopicStats) + private val highWatermarkCheckPointThreadStarted = new AtomicBoolean(false) + @volatile private[server] var highWatermarkCheckpoints: Map[String, OffsetCheckpointFile] = logManager.liveLogDirs.map(dir => + (dir.getAbsolutePath, new OffsetCheckpointFile(new File(dir, ReplicaManager.HighWatermarkFilename), logDirFailureChannel))).toMap + + this.logIdent = s"[ReplicaManager broker=$localBrokerId] " + protected val stateChangeLogger = new StateChangeLogger(localBrokerId, inControllerContext = false, None) + + private var logDirFailureHandler: LogDirFailureHandler = null + + private class LogDirFailureHandler(name: String, haltBrokerOnDirFailure: Boolean) extends ShutdownableThread(name) { + override def doWork(): Unit = { + val newOfflineLogDir = logDirFailureChannel.takeNextOfflineLogDir() + if (haltBrokerOnDirFailure) { + fatal(s"Halting broker because dir 
$newOfflineLogDir is offline") + Exit.halt(1) + } + handleLogDirFailure(newOfflineLogDir) + } + } + + // Visible for testing + private[server] val replicaSelectorOpt: Option[ReplicaSelector] = createReplicaSelector() + + newGauge("LeaderCount", () => leaderPartitionsIterator.size) + // Visible for testing + private[kafka] val partitionCount = newGauge("PartitionCount", () => allPartitions.size) + newGauge("OfflineReplicaCount", () => offlinePartitionCount) + newGauge("UnderReplicatedPartitions", () => underReplicatedPartitionCount) + newGauge("UnderMinIsrPartitionCount", () => leaderPartitionsIterator.count(_.isUnderMinIsr)) + newGauge("AtMinIsrPartitionCount", () => leaderPartitionsIterator.count(_.isAtMinIsr)) + newGauge("ReassigningPartitions", () => reassigningPartitionsCount) + + def reassigningPartitionsCount: Int = leaderPartitionsIterator.count(_.isReassigning) + + val isrExpandRate: Meter = newMeter("IsrExpandsPerSec", "expands", TimeUnit.SECONDS) + val isrShrinkRate: Meter = newMeter("IsrShrinksPerSec", "shrinks", TimeUnit.SECONDS) + val failedIsrUpdatesRate: Meter = newMeter("FailedIsrUpdatesPerSec", "failedUpdates", TimeUnit.SECONDS) + + def underReplicatedPartitionCount: Int = leaderPartitionsIterator.count(_.isUnderReplicated) + + def startHighWatermarkCheckPointThread(): Unit = { + if (highWatermarkCheckPointThreadStarted.compareAndSet(false, true)) + scheduler.schedule("highwatermark-checkpoint", checkpointHighWatermarks _, period = config.replicaHighWatermarkCheckpointIntervalMs, unit = TimeUnit.MILLISECONDS) + } + + // When ReplicaAlterDirThread finishes replacing a current replica with a future replica, it will + // remove the partition from the partition state map. But it will not close itself even if the + // partition state map is empty. Thus we need to call shutdownIdleReplicaAlterDirThread() periodically + // to shutdown idle ReplicaAlterDirThread + def shutdownIdleReplicaAlterLogDirsThread(): Unit = { + replicaAlterLogDirsManager.shutdownIdleFetcherThreads() + } + + def resizeFetcherThreadPool(newSize: Int): Unit = { + replicaFetcherManager.resizeThreadPool(newSize) + } + + def getLog(topicPartition: TopicPartition): Option[UnifiedLog] = logManager.getLog(topicPartition) + + def hasDelayedElectionOperations: Boolean = delayedElectLeaderPurgatory.numDelayed != 0 + + def tryCompleteElection(key: DelayedOperationKey): Unit = { + val completed = delayedElectLeaderPurgatory.checkAndComplete(key) + debug("Request key %s unblocked %d ElectLeader.".format(key.keyLabel, completed)) + } + + def startup(): Unit = { + // start ISR expiration thread + // A follower can lag behind leader for up to config.replicaLagTimeMaxMs x 1.5 before it is removed from ISR + scheduler.schedule("isr-expiration", maybeShrinkIsr _, period = config.replicaLagTimeMaxMs / 2, unit = TimeUnit.MILLISECONDS) + scheduler.schedule("shutdown-idle-replica-alter-log-dirs-thread", shutdownIdleReplicaAlterLogDirsThread _, period = 10000L, unit = TimeUnit.MILLISECONDS) + + // If inter-broker protocol (IBP) < 1.0, the controller will send LeaderAndIsrRequest V0 which does not include isNew field. + // In this case, the broker receiving the request cannot determine whether it is safe to create a partition if a log directory has failed. 
+ // Thus, we choose to halt the broker on any log directory failure if IBP < 1.0
+ val haltBrokerOnFailure = config.interBrokerProtocolVersion < KAFKA_1_0_IV0
+ logDirFailureHandler = new LogDirFailureHandler("LogDirFailureHandler", haltBrokerOnFailure)
+ logDirFailureHandler.start()
+ }
+
+ private def maybeRemoveTopicMetrics(topic: String): Unit = {
+ val topicHasNonOfflinePartition = allPartitions.values.exists {
+ case online: HostedPartition.Online => topic == online.partition.topic
+ case HostedPartition.None | HostedPartition.Offline => false
+ }
+ if (!topicHasNonOfflinePartition) // nothing online or deferred
+ brokerTopicStats.removeMetrics(topic)
+ }
+
+ protected def completeDelayedFetchOrProduceRequests(topicPartition: TopicPartition): Unit = {
+ val topicPartitionOperationKey = TopicPartitionOperationKey(topicPartition)
+ delayedProducePurgatory.checkAndComplete(topicPartitionOperationKey)
+ delayedFetchPurgatory.checkAndComplete(topicPartitionOperationKey)
+ }
+
+ def stopReplicas(correlationId: Int,
+ controllerId: Int,
+ controllerEpoch: Int,
+ partitionStates: Map[TopicPartition, StopReplicaPartitionState]
+ ): (mutable.Map[TopicPartition, Errors], Errors) = {
+ replicaStateChangeLock synchronized {
+ stateChangeLogger.info(s"Handling StopReplica request correlationId $correlationId from controller " +
+ s"$controllerId for ${partitionStates.size} partitions")
+ if (stateChangeLogger.isTraceEnabled)
+ partitionStates.forKeyValue { (topicPartition, partitionState) =>
+ stateChangeLogger.trace(s"Received StopReplica request $partitionState " +
+ s"correlation id $correlationId from controller $controllerId " +
+ s"epoch $controllerEpoch for partition $topicPartition")
+ }
+
+ val responseMap = new collection.mutable.HashMap[TopicPartition, Errors]
+ if (controllerEpoch < this.controllerEpoch) {
+ stateChangeLogger.warn(s"Ignoring StopReplica request from " +
+ s"controller $controllerId with correlation id $correlationId " +
+ s"since its controller epoch $controllerEpoch is old. " +
+ s"Latest known controller epoch is ${this.controllerEpoch}")
+ (responseMap, Errors.STALE_CONTROLLER_EPOCH)
+ } else {
+ this.controllerEpoch = controllerEpoch
+
+ val stoppedPartitions = mutable.Map.empty[TopicPartition, Boolean]
+ partitionStates.forKeyValue { (topicPartition, partitionState) =>
+ val deletePartition = partitionState.deletePartition()
+
+ getPartition(topicPartition) match {
+ case HostedPartition.Offline =>
+ stateChangeLogger.warn(s"Ignoring StopReplica request (delete=$deletePartition) from " +
+ s"controller $controllerId with correlation id $correlationId " +
+ s"epoch $controllerEpoch for partition $topicPartition as the local replica for the " +
+ "partition is in an offline log directory")
+ responseMap.put(topicPartition, Errors.KAFKA_STORAGE_ERROR)
+
+ case HostedPartition.Online(partition) =>
+ val currentLeaderEpoch = partition.getLeaderEpoch
+ val requestLeaderEpoch = partitionState.leaderEpoch
+ // When a topic is deleted, the leader epoch is not incremented. To circumvent this,
+ // a sentinel value (EpochDuringDelete) overwriting any previous epoch is used.
+ // When an older version of the StopReplica request that does not contain the leader
+ // epoch is used, a sentinel value (NoEpoch) bypasses the epoch validation.
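+ // The request is therefore accepted when the request epoch is the EpochDuringDelete sentinel,
+ // the NoEpoch sentinel, or strictly greater than the current leader epoch; an equal or smaller
+ // epoch is rejected with FENCED_LEADER_EPOCH below.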
+ if (requestLeaderEpoch == LeaderAndIsr.EpochDuringDelete || + requestLeaderEpoch == LeaderAndIsr.NoEpoch || + requestLeaderEpoch > currentLeaderEpoch) { + stoppedPartitions += topicPartition -> deletePartition + // Assume that everything will go right. It is overwritten in case of an error. + responseMap.put(topicPartition, Errors.NONE) + } else if (requestLeaderEpoch < currentLeaderEpoch) { + stateChangeLogger.warn(s"Ignoring StopReplica request (delete=$deletePartition) from " + + s"controller $controllerId with correlation id $correlationId " + + s"epoch $controllerEpoch for partition $topicPartition since its associated " + + s"leader epoch $requestLeaderEpoch is smaller than the current " + + s"leader epoch $currentLeaderEpoch") + responseMap.put(topicPartition, Errors.FENCED_LEADER_EPOCH) + } else { + stateChangeLogger.info(s"Ignoring StopReplica request (delete=$deletePartition) from " + + s"controller $controllerId with correlation id $correlationId " + + s"epoch $controllerEpoch for partition $topicPartition since its associated " + + s"leader epoch $requestLeaderEpoch matches the current leader epoch") + responseMap.put(topicPartition, Errors.FENCED_LEADER_EPOCH) + } + + case HostedPartition.None => + // Delete log and corresponding folders in case replica manager doesn't hold them anymore. + // This could happen when topic is being deleted while broker is down and recovers. + stoppedPartitions += topicPartition -> deletePartition + responseMap.put(topicPartition, Errors.NONE) + } + } + + stopPartitions(stoppedPartitions).foreach { case (topicPartition, e) => + if (e.isInstanceOf[KafkaStorageException]) { + stateChangeLogger.error(s"Ignoring StopReplica request (delete=true) from " + + s"controller $controllerId with correlation id $correlationId " + + s"epoch $controllerEpoch for partition $topicPartition as the local replica for the " + + "partition is in an offline log directory") + } else { + stateChangeLogger.error(s"Ignoring StopReplica request (delete=true) from " + + s"controller $controllerId with correlation id $correlationId " + + s"epoch $controllerEpoch for partition $topicPartition due to an unexpected " + + s"${e.getClass.getName} exception: ${e.getMessage}") + } + responseMap.put(topicPartition, Errors.forException(e)) + } + (responseMap, Errors.NONE) + } + } + } + + /** + * Stop the given partitions. + * + * @param partitionsToStop A map from a topic partition to a boolean indicating + * whether the partition should be deleted. + * + * @return A map from partitions to exceptions which occurred. + * If no errors occurred, the map will be empty. + */ + protected def stopPartitions( + partitionsToStop: Map[TopicPartition, Boolean] + ): Map[TopicPartition, Throwable] = { + // First stop fetchers for all partitions. + val partitions = partitionsToStop.keySet + replicaFetcherManager.removeFetcherForPartitions(partitions) + replicaAlterLogDirsManager.removeFetcherForPartitions(partitions) + + // Second remove deleted partitions from the partition map. Fetchers rely on the + // ReplicaManager to get Partition's information so they must be stopped first. + val partitionsToDelete = mutable.Set.empty[TopicPartition] + partitionsToStop.forKeyValue { (topicPartition, shouldDelete) => + if (shouldDelete) { + getPartition(topicPartition) match { + case hostedPartition: HostedPartition.Online => + if (allPartitions.remove(topicPartition, hostedPartition)) { + maybeRemoveTopicMetrics(topicPartition.topic) + // Logs are not deleted here. 
They are deleted in a single batch later on. + // This is done to avoid having to checkpoint for every deletions. + hostedPartition.partition.delete() + } + + case _ => + } + partitionsToDelete += topicPartition + } + // If we were the leader, we may have some operations still waiting for completion. + // We force completion to prevent them from timing out. + completeDelayedFetchOrProduceRequests(topicPartition) + } + + // Third delete the logs and checkpoint. + val errorMap = new mutable.HashMap[TopicPartition, Throwable]() + if (partitionsToDelete.nonEmpty) { + // Delete the logs and checkpoint. + logManager.asyncDelete(partitionsToDelete, (tp, e) => errorMap.put(tp, e)) + } + errorMap + } + + def getPartition(topicPartition: TopicPartition): HostedPartition = { + Option(allPartitions.get(topicPartition)).getOrElse(HostedPartition.None) + } + + def isAddingReplica(topicPartition: TopicPartition, replicaId: Int): Boolean = { + getPartition(topicPartition) match { + case Online(partition) => partition.isAddingReplica(replicaId) + case _ => false + } + } + + // Visible for testing + def createPartition(topicPartition: TopicPartition): Partition = { + val partition = Partition(topicPartition, time, this) + allPartitions.put(topicPartition, HostedPartition.Online(partition)) + partition + } + + def onlinePartition(topicPartition: TopicPartition): Option[Partition] = { + getPartition(topicPartition) match { + case HostedPartition.Online(partition) => Some(partition) + case _ => None + } + } + + // An iterator over all non offline partitions. This is a weakly consistent iterator; a partition made offline after + // the iterator has been constructed could still be returned by this iterator. + private def onlinePartitionsIterator: Iterator[Partition] = { + allPartitions.values.iterator.flatMap { + case HostedPartition.Online(partition) => Some(partition) + case _ => None + } + } + + private def offlinePartitionCount: Int = { + allPartitions.values.iterator.count(_ == HostedPartition.Offline) + } + + def getPartitionOrException(topicPartition: TopicPartition): Partition = { + getPartitionOrError(topicPartition) match { + case Left(Errors.KAFKA_STORAGE_ERROR) => + throw new KafkaStorageException(s"Partition $topicPartition is in an offline log directory") + + case Left(error) => + throw error.exception(s"Error while fetching partition state for $topicPartition") + + case Right(partition) => partition + } + } + + def getPartitionOrError(topicPartition: TopicPartition): Either[Errors, Partition] = { + getPartition(topicPartition) match { + case HostedPartition.Online(partition) => + Right(partition) + + case HostedPartition.Offline => + Left(Errors.KAFKA_STORAGE_ERROR) + + case HostedPartition.None if metadataCache.contains(topicPartition) => + // The topic exists, but this broker is no longer a replica of it, so we return NOT_LEADER_OR_FOLLOWER which + // forces clients to refresh metadata to find the new location. This can happen, for example, + // during a partition reassignment if a produce request from the client is sent to a broker after + // the local replica has been deleted. 
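+ // (Contrast with the plain HostedPartition.None case below, where the topic is not in the
+ // metadata cache at all and UNKNOWN_TOPIC_OR_PARTITION is returned instead.)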
+ Left(Errors.NOT_LEADER_OR_FOLLOWER) + + case HostedPartition.None => + Left(Errors.UNKNOWN_TOPIC_OR_PARTITION) + } + } + + def localLogOrException(topicPartition: TopicPartition): UnifiedLog = { + getPartitionOrException(topicPartition).localLogOrException + } + + def futureLocalLogOrException(topicPartition: TopicPartition): UnifiedLog = { + getPartitionOrException(topicPartition).futureLocalLogOrException + } + + def futureLogExists(topicPartition: TopicPartition): Boolean = { + getPartitionOrException(topicPartition).futureLog.isDefined + } + + def localLog(topicPartition: TopicPartition): Option[UnifiedLog] = { + onlinePartition(topicPartition).flatMap(_.log) + } + + def getLogDir(topicPartition: TopicPartition): Option[String] = { + localLog(topicPartition).map(_.parentDir) + } + + /** + * TODO: move this action queue to handle thread so we can simplify concurrency handling + */ + private val actionQueue = new ActionQueue + + def tryCompleteActions(): Unit = actionQueue.tryCompleteActions() + + /** + * Append messages to leader replicas of the partition, and wait for them to be replicated to other replicas; + * the callback function will be triggered either when timeout or the required acks are satisfied; + * if the callback function itself is already synchronized on some object then pass this object to avoid deadlock. + * + * Noted that all pending delayed check operations are stored in a queue. All callers to ReplicaManager.appendRecords() + * are expected to call ActionQueue.tryCompleteActions for all affected partitions, without holding any conflicting + * locks. + */ + def appendRecords(timeout: Long, + requiredAcks: Short, + internalTopicsAllowed: Boolean, + origin: AppendOrigin, + entriesPerPartition: Map[TopicPartition, MemoryRecords], + responseCallback: Map[TopicPartition, PartitionResponse] => Unit, + delayedProduceLock: Option[Lock] = None, + recordConversionStatsCallback: Map[TopicPartition, RecordConversionStats] => Unit = _ => (), + requestLocal: RequestLocal = RequestLocal.NoCaching): Unit = { + if (isValidRequiredAcks(requiredAcks)) { + val sTime = time.milliseconds + val localProduceResults = appendToLocalLog(internalTopicsAllowed = internalTopicsAllowed, + origin, entriesPerPartition, requiredAcks, requestLocal) + debug("Produce to local log in %d ms".format(time.milliseconds - sTime)) + + val produceStatus = localProduceResults.map { case (topicPartition, result) => + topicPartition -> ProducePartitionStatus( + result.info.lastOffset + 1, // required offset + new PartitionResponse( + result.error, + result.info.firstOffset.map(_.messageOffset).getOrElse(-1), + result.info.logAppendTime, + result.info.logStartOffset, + result.info.recordErrors.asJava, + result.info.errorMessage + ) + ) // response status + } + + actionQueue.add { + () => + localProduceResults.foreach { + case (topicPartition, result) => + val requestKey = TopicPartitionOperationKey(topicPartition) + result.info.leaderHwChange match { + case LeaderHwChange.Increased => + // some delayed operations may be unblocked after HW changed + delayedProducePurgatory.checkAndComplete(requestKey) + delayedFetchPurgatory.checkAndComplete(requestKey) + delayedDeleteRecordsPurgatory.checkAndComplete(requestKey) + case LeaderHwChange.Same => + // probably unblock some follower fetch requests since log end offset has been updated + delayedFetchPurgatory.checkAndComplete(requestKey) + case LeaderHwChange.None => + // nothing + } + } + } + + recordConversionStatsCallback(localProduceResults.map { case (k, v) => k -> 
v.info.recordConversionStats }) + + if (delayedProduceRequestRequired(requiredAcks, entriesPerPartition, localProduceResults)) { + // create delayed produce operation + val produceMetadata = ProduceMetadata(requiredAcks, produceStatus) + val delayedProduce = new DelayedProduce(timeout, produceMetadata, this, responseCallback, delayedProduceLock) + + // create a list of (topic, partition) pairs to use as keys for this delayed produce operation + val producerRequestKeys = entriesPerPartition.keys.map(TopicPartitionOperationKey(_)).toSeq + + // try to complete the request immediately, otherwise put it into the purgatory + // this is because while the delayed produce operation is being created, new + // requests may arrive and hence make this operation completable. + delayedProducePurgatory.tryCompleteElseWatch(delayedProduce, producerRequestKeys) + + } else { + // we can respond immediately + val produceResponseStatus = produceStatus.map { case (k, status) => k -> status.responseStatus } + responseCallback(produceResponseStatus) + } + } else { + // If required.acks is outside accepted range, something is wrong with the client + // Just return an error and don't handle the request at all + val responseStatus = entriesPerPartition.map { case (topicPartition, _) => + topicPartition -> new PartitionResponse( + Errors.INVALID_REQUIRED_ACKS, + LogAppendInfo.UnknownLogAppendInfo.firstOffset.map(_.messageOffset).getOrElse(-1), + RecordBatch.NO_TIMESTAMP, + LogAppendInfo.UnknownLogAppendInfo.logStartOffset + ) + } + responseCallback(responseStatus) + } + } + + /** + * Delete records on leader replicas of the partition, and wait for delete records operation be propagated to other replicas; + * the callback function will be triggered either when timeout or logStartOffset of all live replicas have reached the specified offset + */ + private def deleteRecordsOnLocalLog(offsetPerPartition: Map[TopicPartition, Long]): Map[TopicPartition, LogDeleteRecordsResult] = { + trace("Delete records on local logs to offsets [%s]".format(offsetPerPartition)) + offsetPerPartition.map { case (topicPartition, requestedOffset) => + // reject delete records operation on internal topics + if (Topic.isInternal(topicPartition.topic)) { + (topicPartition, LogDeleteRecordsResult(-1L, -1L, Some(new InvalidTopicException(s"Cannot delete records of internal topic ${topicPartition.topic}")))) + } else { + try { + val partition = getPartitionOrException(topicPartition) + val logDeleteResult = partition.deleteRecordsOnLeader(requestedOffset) + (topicPartition, logDeleteResult) + } catch { + case e@ (_: UnknownTopicOrPartitionException | + _: NotLeaderOrFollowerException | + _: OffsetOutOfRangeException | + _: PolicyViolationException | + _: KafkaStorageException) => + (topicPartition, LogDeleteRecordsResult(-1L, -1L, Some(e))) + case t: Throwable => + error("Error processing delete records operation on partition %s".format(topicPartition), t) + (topicPartition, LogDeleteRecordsResult(-1L, -1L, Some(t))) + } + } + } + } + + // If there exists a topic partition that meets the following requirement, + // we need to put a delayed DeleteRecordsRequest and wait for the delete records operation to complete + // + // 1. the delete records operation on this partition is successful + // 2. 
low watermark of this partition is smaller than the specified offset
+ private def delayedDeleteRecordsRequired(localDeleteRecordsResults: Map[TopicPartition, LogDeleteRecordsResult]): Boolean = {
+ localDeleteRecordsResults.exists{ case (_, deleteRecordsResult) =>
+ deleteRecordsResult.exception.isEmpty && deleteRecordsResult.lowWatermark < deleteRecordsResult.requestedOffset
+ }
+ }
+
+ /**
+ * For each pair of partition and log directory specified in the map, if the partition has already been created on
+ * this broker, move its log files to the specified log directory. Otherwise, record the pair in memory so that
+ * the partition will be created in the specified log directory when the broker receives LeaderAndIsrRequest for the partition later.
+ */
+ def alterReplicaLogDirs(partitionDirs: Map[TopicPartition, String]): Map[TopicPartition, Errors] = {
+ replicaStateChangeLock synchronized {
+ partitionDirs.map { case (topicPartition, destinationDir) =>
+ try {
+ /* If the topic name is exceptionally long, we can't support altering the log directory.
+ * See KAFKA-4893 for details.
+ * TODO: fix this by implementing topic IDs. */
+ if (UnifiedLog.logFutureDirName(topicPartition).size > 255)
+ throw new InvalidTopicException("The topic name is too long.")
+ if (!logManager.isLogDirOnline(destinationDir))
+ throw new KafkaStorageException(s"Log directory $destinationDir is offline")
+
+ getPartition(topicPartition) match {
+ case HostedPartition.Online(partition) =>
+ // Stop current replica movement if the destinationDir is different from the existing destination log directory
+ if (partition.futureReplicaDirChanged(destinationDir)) {
+ replicaAlterLogDirsManager.removeFetcherForPartitions(Set(topicPartition))
+ partition.removeFutureLocalReplica()
+ }
+ case HostedPartition.Offline =>
+ throw new KafkaStorageException(s"Partition $topicPartition is offline")
+
+ case HostedPartition.None => // Do nothing
+ }
+
+ // If the log for this partition has not been created yet:
+ // 1) Record the destination log directory in memory so that the partition will be created in this log directory
+ // when the broker receives LeaderAndIsrRequest for this partition later.
+ // 2) Respond with NotLeaderOrFollowerException for this partition in the AlterReplicaLogDirsResponse
+ logManager.maybeUpdatePreferredLogDir(topicPartition, destinationDir)
+
+ // throw NotLeaderOrFollowerException if replica does not exist for the given partition
+ val partition = getPartitionOrException(topicPartition)
+ val log = partition.localLogOrException
+ val topicId = log.topicId
+
+ // If the destinationDir is different from the current log directory of the replica:
+ // - If there is no offline log directory, create the future log in the destinationDir (if it does not exist) and
+ // start ReplicaAlterDirThread to move data of this partition from the current log to the future log
+ // - Otherwise, return KafkaStorageException. We do not create the future log while there is an offline log directory
+ // so that we can avoid creating a future log for the same partition in multiple log directories.
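+ // The code below creates the future log (if it does not already exist) using the checkpointed
+ // high watermarks, aborts and pauses log cleaning for the partition, and registers a
+ // replica-alter-log-dirs fetcher to move the data; the BrokerEndPoint(config.brokerId, "localhost", -1)
+ // passed to the fetcher points at this broker itself, since the future log is populated from the local log.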
+ val highWatermarkCheckpoints = new LazyOffsetCheckpoints(this.highWatermarkCheckpoints) + if (partition.maybeCreateFutureReplica(destinationDir, highWatermarkCheckpoints)) { + val futureLog = futureLocalLogOrException(topicPartition) + logManager.abortAndPauseCleaning(topicPartition) + + val initialFetchState = InitialFetchState(topicId, BrokerEndPoint(config.brokerId, "localhost", -1), + partition.getLeaderEpoch, futureLog.highWatermark) + replicaAlterLogDirsManager.addFetcherForPartitions(Map(topicPartition -> initialFetchState)) + } + + (topicPartition, Errors.NONE) + } catch { + case e@(_: InvalidTopicException | + _: LogDirNotFoundException | + _: ReplicaNotAvailableException | + _: KafkaStorageException) => + warn(s"Unable to alter log dirs for $topicPartition", e) + (topicPartition, Errors.forException(e)) + case e: NotLeaderOrFollowerException => + // Retaining REPLICA_NOT_AVAILABLE exception for ALTER_REPLICA_LOG_DIRS for compatibility + warn(s"Unable to alter log dirs for $topicPartition", e) + (topicPartition, Errors.REPLICA_NOT_AVAILABLE) + case t: Throwable => + error("Error while changing replica dir for partition %s".format(topicPartition), t) + (topicPartition, Errors.forException(t)) + } + } + } + } + + /* + * Get the LogDirInfo for the specified list of partitions. + * + * Each LogDirInfo specifies the following information for a given log directory: + * 1) Error of the log directory, e.g. whether the log is online or offline + * 2) size and lag of current and future logs for each partition in the given log directory. Only logs of the queried partitions + * are included. There may be future logs (which will replace the current logs of the partition in the future) on the broker after KIP-113 is implemented. + */ + def describeLogDirs(partitions: Set[TopicPartition]): List[DescribeLogDirsResponseData.DescribeLogDirsResult] = { + val logsByDir = logManager.allLogs.groupBy(log => log.parentDir) + + config.logDirs.toSet.map { logDir: String => + val absolutePath = new File(logDir).getAbsolutePath + try { + if (!logManager.isLogDirOnline(absolutePath)) + throw new KafkaStorageException(s"Log directory $absolutePath is offline") + + logsByDir.get(absolutePath) match { + case Some(logs) => + val topicInfos = logs.groupBy(_.topicPartition.topic).map{case (topic, logs) => + new DescribeLogDirsResponseData.DescribeLogDirsTopic().setName(topic).setPartitions( + logs.filter { log => + partitions.contains(log.topicPartition) + }.map { log => + new DescribeLogDirsResponseData.DescribeLogDirsPartition() + .setPartitionSize(log.size) + .setPartitionIndex(log.topicPartition.partition) + .setOffsetLag(getLogEndOffsetLag(log.topicPartition, log.logEndOffset, log.isFuture)) + .setIsFutureKey(log.isFuture) + }.toList.asJava) + }.toList.asJava + + new DescribeLogDirsResponseData.DescribeLogDirsResult().setLogDir(absolutePath) + .setErrorCode(Errors.NONE.code).setTopics(topicInfos) + case None => + new DescribeLogDirsResponseData.DescribeLogDirsResult().setLogDir(absolutePath) + .setErrorCode(Errors.NONE.code) + } + + } catch { + case e: KafkaStorageException => + warn("Unable to describe replica dirs for %s".format(absolutePath), e) + new DescribeLogDirsResponseData.DescribeLogDirsResult() + .setLogDir(absolutePath) + .setErrorCode(Errors.KAFKA_STORAGE_ERROR.code) + case t: Throwable => + error(s"Error while describing replica in dir $absolutePath", t) + new DescribeLogDirsResponseData.DescribeLogDirsResult() + .setLogDir(absolutePath) + .setErrorCode(Errors.forException(t).code) + } + 
}.toList + } + + def getLogEndOffsetLag(topicPartition: TopicPartition, logEndOffset: Long, isFuture: Boolean): Long = { + localLog(topicPartition) match { + case Some(log) => + if (isFuture) + log.logEndOffset - logEndOffset + else + math.max(log.highWatermark - logEndOffset, 0) + case None => + // return -1L to indicate that the LEO lag is not available if the replica is not created or is offline + DescribeLogDirsResponse.INVALID_OFFSET_LAG + } + } + + def deleteRecords(timeout: Long, + offsetPerPartition: Map[TopicPartition, Long], + responseCallback: Map[TopicPartition, DeleteRecordsPartitionResult] => Unit): Unit = { + val timeBeforeLocalDeleteRecords = time.milliseconds + val localDeleteRecordsResults = deleteRecordsOnLocalLog(offsetPerPartition) + debug("Delete records on local log in %d ms".format(time.milliseconds - timeBeforeLocalDeleteRecords)) + + val deleteRecordsStatus = localDeleteRecordsResults.map { case (topicPartition, result) => + topicPartition -> + DeleteRecordsPartitionStatus( + result.requestedOffset, // requested offset + new DeleteRecordsPartitionResult() + .setLowWatermark(result.lowWatermark) + .setErrorCode(result.error.code) + .setPartitionIndex(topicPartition.partition)) // response status + } + + if (delayedDeleteRecordsRequired(localDeleteRecordsResults)) { + // create delayed delete records operation + val delayedDeleteRecords = new DelayedDeleteRecords(timeout, deleteRecordsStatus, this, responseCallback) + + // create a list of (topic, partition) pairs to use as keys for this delayed delete records operation + val deleteRecordsRequestKeys = offsetPerPartition.keys.map(TopicPartitionOperationKey(_)).toSeq + + // try to complete the request immediately, otherwise put it into the purgatory + // this is because while the delayed delete records operation is being created, new + // requests may arrive and hence make this operation completable. + delayedDeleteRecordsPurgatory.tryCompleteElseWatch(delayedDeleteRecords, deleteRecordsRequestKeys) + } else { + // we can respond immediately + val deleteRecordsResponseStatus = deleteRecordsStatus.map { case (k, status) => k -> status.responseStatus } + responseCallback(deleteRecordsResponseStatus) + } + } + + // If all the following conditions are true, we need to put a delayed produce request and wait for replication to complete + // + // 1. required acks = -1 + // 2. there is data to append + // 3. 
at least one partition append was successful (fewer errors than partitions) + private def delayedProduceRequestRequired(requiredAcks: Short, + entriesPerPartition: Map[TopicPartition, MemoryRecords], + localProduceResults: Map[TopicPartition, LogAppendResult]): Boolean = { + requiredAcks == -1 && + entriesPerPartition.nonEmpty && + localProduceResults.values.count(_.exception.isDefined) < entriesPerPartition.size + } + + private def isValidRequiredAcks(requiredAcks: Short): Boolean = { + requiredAcks == -1 || requiredAcks == 1 || requiredAcks == 0 + } + + /** + * Append the messages to the local replica logs + */ + private def appendToLocalLog(internalTopicsAllowed: Boolean, + origin: AppendOrigin, + entriesPerPartition: Map[TopicPartition, MemoryRecords], + requiredAcks: Short, + requestLocal: RequestLocal): Map[TopicPartition, LogAppendResult] = { + val traceEnabled = isTraceEnabled + def processFailedRecord(topicPartition: TopicPartition, t: Throwable) = { + val logStartOffset = onlinePartition(topicPartition).map(_.logStartOffset).getOrElse(-1L) + brokerTopicStats.topicStats(topicPartition.topic).failedProduceRequestRate.mark() + brokerTopicStats.allTopicsStats.failedProduceRequestRate.mark() + error(s"Error processing append operation on partition $topicPartition", t) + + logStartOffset + } + + if (traceEnabled) + trace(s"Append [$entriesPerPartition] to local log") + + entriesPerPartition.map { case (topicPartition, records) => + brokerTopicStats.topicStats(topicPartition.topic).totalProduceRequestRate.mark() + brokerTopicStats.allTopicsStats.totalProduceRequestRate.mark() + + // reject appending to internal topics if it is not allowed + if (Topic.isInternal(topicPartition.topic) && !internalTopicsAllowed) { + (topicPartition, LogAppendResult( + LogAppendInfo.UnknownLogAppendInfo, + Some(new InvalidTopicException(s"Cannot append to internal topic ${topicPartition.topic}")))) + } else { + try { + val partition = getPartitionOrException(topicPartition) + val info = partition.appendRecordsToLeader(records, origin, requiredAcks, requestLocal) + val numAppendedMessages = info.numMessages + + // update stats for successfully appended bytes and messages as bytesInRate and messageInRate + brokerTopicStats.topicStats(topicPartition.topic).bytesInRate.mark(records.sizeInBytes) + brokerTopicStats.allTopicsStats.bytesInRate.mark(records.sizeInBytes) + brokerTopicStats.topicStats(topicPartition.topic).messagesInRate.mark(numAppendedMessages) + brokerTopicStats.allTopicsStats.messagesInRate.mark(numAppendedMessages) + + if (traceEnabled) + trace(s"${records.sizeInBytes} written to log $topicPartition beginning at offset " + + s"${info.firstOffset.getOrElse(-1)} and ending at offset ${info.lastOffset}") + + (topicPartition, LogAppendResult(info)) + } catch { + // NOTE: Failed produce requests metric is not incremented for known exceptions + // it is supposed to indicate un-expected failures of a broker in handling a produce request + case e@ (_: UnknownTopicOrPartitionException | + _: NotLeaderOrFollowerException | + _: RecordTooLargeException | + _: RecordBatchTooLargeException | + _: CorruptRecordException | + _: KafkaStorageException) => + (topicPartition, LogAppendResult(LogAppendInfo.UnknownLogAppendInfo, Some(e))) + case rve: RecordValidationException => + val logStartOffset = processFailedRecord(topicPartition, rve.invalidException) + val recordErrors = rve.recordErrors + (topicPartition, LogAppendResult(LogAppendInfo.unknownLogAppendInfoWithAdditionalInfo( + logStartOffset, recordErrors, 
rve.invalidException.getMessage), Some(rve.invalidException))) + case t: Throwable => + val logStartOffset = processFailedRecord(topicPartition, t) + (topicPartition, LogAppendResult(LogAppendInfo.unknownLogAppendInfoWithLogStartOffset(logStartOffset), Some(t))) + } + } + } + } + + def fetchOffsetForTimestamp(topicPartition: TopicPartition, + timestamp: Long, + isolationLevel: Option[IsolationLevel], + currentLeaderEpoch: Optional[Integer], + fetchOnlyFromLeader: Boolean): Option[TimestampAndOffset] = { + val partition = getPartitionOrException(topicPartition) + partition.fetchOffsetForTimestamp(timestamp, isolationLevel, currentLeaderEpoch, fetchOnlyFromLeader) + } + + def legacyFetchOffsetsForTimestamp(topicPartition: TopicPartition, + timestamp: Long, + maxNumOffsets: Int, + isFromConsumer: Boolean, + fetchOnlyFromLeader: Boolean): Seq[Long] = { + val partition = getPartitionOrException(topicPartition) + partition.legacyFetchOffsetsForTimestamp(timestamp, maxNumOffsets, isFromConsumer, fetchOnlyFromLeader) + } + + /** + * Fetch messages from a replica, and wait until enough data can be fetched and return; + * the callback function will be triggered either when timeout or required fetch info is satisfied. + * Consumers may fetch from any replica, but followers can only fetch from the leader. + */ + def fetchMessages(timeout: Long, + replicaId: Int, + fetchMinBytes: Int, + fetchMaxBytes: Int, + hardMaxBytesLimit: Boolean, + fetchInfos: Seq[(TopicIdPartition, PartitionData)], + quota: ReplicaQuota, + responseCallback: Seq[(TopicIdPartition, FetchPartitionData)] => Unit, + isolationLevel: IsolationLevel, + clientMetadata: Option[ClientMetadata]): Unit = { + val isFromFollower = Request.isValidBrokerId(replicaId) + val isFromConsumer = !(isFromFollower || replicaId == Request.FutureLocalReplicaId) + val fetchIsolation = if (!isFromConsumer) + FetchLogEnd + else if (isolationLevel == IsolationLevel.READ_COMMITTED) + FetchTxnCommitted + else + FetchHighWatermark + + // Restrict fetching to leader if request is from follower or from a client with older version (no ClientMetadata) + val fetchOnlyFromLeader = isFromFollower || (isFromConsumer && clientMetadata.isEmpty) + def readFromLog(): Seq[(TopicIdPartition, LogReadResult)] = { + val result = readFromLocalLog( + replicaId = replicaId, + fetchOnlyFromLeader = fetchOnlyFromLeader, + fetchIsolation = fetchIsolation, + fetchMaxBytes = fetchMaxBytes, + hardMaxBytesLimit = hardMaxBytesLimit, + readPartitionInfo = fetchInfos, + quota = quota, + clientMetadata = clientMetadata) + if (isFromFollower) updateFollowerFetchState(replicaId, result) + else result + } + + val logReadResults = readFromLog() + + // check if this fetch request can be satisfied right away + var bytesReadable: Long = 0 + var errorReadingData = false + var hasDivergingEpoch = false + val logReadResultMap = new mutable.HashMap[TopicIdPartition, LogReadResult] + logReadResults.foreach { case (topicIdPartition, logReadResult) => + brokerTopicStats.topicStats(topicIdPartition.topicPartition.topic).totalFetchRequestRate.mark() + brokerTopicStats.allTopicsStats.totalFetchRequestRate.mark() + + if (logReadResult.error != Errors.NONE) + errorReadingData = true + if (logReadResult.divergingEpoch.nonEmpty) + hasDivergingEpoch = true + bytesReadable = bytesReadable + logReadResult.info.records.sizeInBytes + logReadResultMap.put(topicIdPartition, logReadResult) + } + + // respond immediately if 1) fetch request does not want to wait + // 2) fetch request does not require any data + // 3) has 
enough data to respond + // 4) some error happens while reading data + // 5) we found a diverging epoch + if (timeout <= 0 || fetchInfos.isEmpty || bytesReadable >= fetchMinBytes || errorReadingData || hasDivergingEpoch) { + val fetchPartitionData = logReadResults.map { case (tp, result) => + val isReassignmentFetch = isFromFollower && isAddingReplica(tp.topicPartition, replicaId) + tp -> result.toFetchPartitionData(isReassignmentFetch) + } + responseCallback(fetchPartitionData) + } else { + // construct the fetch results from the read results + val fetchPartitionStatus = new mutable.ArrayBuffer[(TopicIdPartition, FetchPartitionStatus)] + fetchInfos.foreach { case (topicIdPartition, partitionData) => + logReadResultMap.get(topicIdPartition).foreach(logReadResult => { + val logOffsetMetadata = logReadResult.info.fetchOffsetMetadata + fetchPartitionStatus += (topicIdPartition -> FetchPartitionStatus(logOffsetMetadata, partitionData)) + }) + } + val fetchMetadata: SFetchMetadata = SFetchMetadata(fetchMinBytes, fetchMaxBytes, hardMaxBytesLimit, + fetchOnlyFromLeader, fetchIsolation, isFromFollower, replicaId, fetchPartitionStatus) + val delayedFetch = new DelayedFetch(timeout, fetchMetadata, this, quota, clientMetadata, + responseCallback) + + // create a list of (topic, partition) pairs to use as keys for this delayed fetch operation + val delayedFetchKeys = fetchPartitionStatus.map { case (tp, _) => TopicPartitionOperationKey(tp) } + + // try to complete the request immediately, otherwise put it into the purgatory; + // this is because while the delayed fetch operation is being created, new requests + // may arrive and hence make this operation completable. + delayedFetchPurgatory.tryCompleteElseWatch(delayedFetch, delayedFetchKeys) + } + } + + /** + * Read from multiple topic partitions at the given offset up to maxSize bytes + */ + def readFromLocalLog(replicaId: Int, + fetchOnlyFromLeader: Boolean, + fetchIsolation: FetchIsolation, + fetchMaxBytes: Int, + hardMaxBytesLimit: Boolean, + readPartitionInfo: Seq[(TopicIdPartition, PartitionData)], + quota: ReplicaQuota, + clientMetadata: Option[ClientMetadata]): Seq[(TopicIdPartition, LogReadResult)] = { + val traceEnabled = isTraceEnabled + + def read(tp: TopicIdPartition, fetchInfo: PartitionData, limitBytes: Int, minOneMessage: Boolean): LogReadResult = { + val offset = fetchInfo.fetchOffset + val partitionFetchSize = fetchInfo.maxBytes + val followerLogStartOffset = fetchInfo.logStartOffset + + val adjustedMaxBytes = math.min(fetchInfo.maxBytes, limitBytes) + try { + if (traceEnabled) + trace(s"Fetching log segment for partition $tp, offset $offset, partition fetch size $partitionFetchSize, " + + s"remaining response limit $limitBytes" + + (if (minOneMessage) s", ignoring response/partition size limits" else "")) + + val partition = getPartitionOrException(tp.topicPartition) + val fetchTimeMs = time.milliseconds + + // Check if topic ID from the fetch request/session matches the ID in the log + val topicId = if (tp.topicId == Uuid.ZERO_UUID) None else Some(tp.topicId) + if (!hasConsistentTopicId(topicId, partition.topicId)) + throw new InconsistentTopicIdException("Topic ID in the fetch session did not match the topic ID in the log.") + + // If we are the leader, determine the preferred read-replica + val preferredReadReplica = clientMetadata.flatMap( + metadata => findPreferredReadReplica(partition, metadata, replicaId, fetchInfo.fetchOffset, fetchTimeMs)) + + if (preferredReadReplica.isDefined) { + replicaSelectorOpt.foreach { selector 
=> + debug(s"Replica selector ${selector.getClass.getSimpleName} returned preferred replica " + + s"${preferredReadReplica.get} for $clientMetadata") + } + // If a preferred read-replica is set, skip the read + val offsetSnapshot = partition.fetchOffsetSnapshot(fetchInfo.currentLeaderEpoch, fetchOnlyFromLeader = false) + LogReadResult(info = FetchDataInfo(LogOffsetMetadata.UnknownOffsetMetadata, MemoryRecords.EMPTY), + divergingEpoch = None, + highWatermark = offsetSnapshot.highWatermark.messageOffset, + leaderLogStartOffset = offsetSnapshot.logStartOffset, + leaderLogEndOffset = offsetSnapshot.logEndOffset.messageOffset, + followerLogStartOffset = followerLogStartOffset, + fetchTimeMs = -1L, + lastStableOffset = Some(offsetSnapshot.lastStableOffset.messageOffset), + preferredReadReplica = preferredReadReplica, + exception = None) + } else { + // Try the read first, this tells us whether we need all of adjustedFetchSize for this partition + val readInfo: LogReadInfo = partition.readRecords( + lastFetchedEpoch = fetchInfo.lastFetchedEpoch, + fetchOffset = fetchInfo.fetchOffset, + currentLeaderEpoch = fetchInfo.currentLeaderEpoch, + maxBytes = adjustedMaxBytes, + fetchIsolation = fetchIsolation, + fetchOnlyFromLeader = fetchOnlyFromLeader, + minOneMessage = minOneMessage) + + val fetchDataInfo = if (shouldLeaderThrottle(quota, partition, replicaId)) { + // If the partition is being throttled, simply return an empty set. + FetchDataInfo(readInfo.fetchedData.fetchOffsetMetadata, MemoryRecords.EMPTY) + } else if (!hardMaxBytesLimit && readInfo.fetchedData.firstEntryIncomplete) { + // For FetchRequest version 3, we replace incomplete message sets with an empty one as consumers can make + // progress in such cases and don't need to report a `RecordTooLargeException` + FetchDataInfo(readInfo.fetchedData.fetchOffsetMetadata, MemoryRecords.EMPTY) + } else { + readInfo.fetchedData + } + + LogReadResult(info = fetchDataInfo, + divergingEpoch = readInfo.divergingEpoch, + highWatermark = readInfo.highWatermark, + leaderLogStartOffset = readInfo.logStartOffset, + leaderLogEndOffset = readInfo.logEndOffset, + followerLogStartOffset = followerLogStartOffset, + fetchTimeMs = fetchTimeMs, + lastStableOffset = Some(readInfo.lastStableOffset), + preferredReadReplica = preferredReadReplica, + exception = None) + } + } catch { + // NOTE: Failed fetch requests metric is not incremented for known exceptions since it + // is supposed to indicate un-expected failure of a broker in handling a fetch request + case e@ (_: UnknownTopicOrPartitionException | + _: NotLeaderOrFollowerException | + _: UnknownLeaderEpochException | + _: FencedLeaderEpochException | + _: ReplicaNotAvailableException | + _: KafkaStorageException | + _: OffsetOutOfRangeException | + _: InconsistentTopicIdException) => + LogReadResult(info = FetchDataInfo(LogOffsetMetadata.UnknownOffsetMetadata, MemoryRecords.EMPTY), + divergingEpoch = None, + highWatermark = UnifiedLog.UnknownOffset, + leaderLogStartOffset = UnifiedLog.UnknownOffset, + leaderLogEndOffset = UnifiedLog.UnknownOffset, + followerLogStartOffset = UnifiedLog.UnknownOffset, + fetchTimeMs = -1L, + lastStableOffset = None, + exception = Some(e)) + case e: Throwable => + brokerTopicStats.topicStats(tp.topic).failedFetchRequestRate.mark() + brokerTopicStats.allTopicsStats.failedFetchRequestRate.mark() + + val fetchSource = Request.describeReplicaId(replicaId) + error(s"Error processing fetch with max size $adjustedMaxBytes from $fetchSource " + + s"on partition $tp: $fetchInfo", e) + + 
LogReadResult(info = FetchDataInfo(LogOffsetMetadata.UnknownOffsetMetadata, MemoryRecords.EMPTY), + divergingEpoch = None, + highWatermark = UnifiedLog.UnknownOffset, + leaderLogStartOffset = UnifiedLog.UnknownOffset, + leaderLogEndOffset = UnifiedLog.UnknownOffset, + followerLogStartOffset = UnifiedLog.UnknownOffset, + fetchTimeMs = -1L, + lastStableOffset = None, + exception = Some(e)) + } + } + + var limitBytes = fetchMaxBytes + val result = new mutable.ArrayBuffer[(TopicIdPartition, LogReadResult)] + var minOneMessage = !hardMaxBytesLimit + readPartitionInfo.foreach { case (tp, fetchInfo) => + val readResult = read(tp, fetchInfo, limitBytes, minOneMessage) + val recordBatchSize = readResult.info.records.sizeInBytes + // Once we read from a non-empty partition, we stop ignoring request and partition level size limits + if (recordBatchSize > 0) + minOneMessage = false + limitBytes = math.max(0, limitBytes - recordBatchSize) + result += (tp -> readResult) + } + result + } + + /** + * Using the configured [[ReplicaSelector]], determine the preferred read replica for a partition given the + * client metadata, the requested offset, and the current set of replicas. If the preferred read replica is the + * leader, return None + */ + def findPreferredReadReplica(partition: Partition, + clientMetadata: ClientMetadata, + replicaId: Int, + fetchOffset: Long, + currentTimeMs: Long): Option[Int] = { + partition.leaderReplicaIdOpt.flatMap { leaderReplicaId => + // Don't look up preferred for follower fetches via normal replication + if (Request.isValidBrokerId(replicaId)) + None + else { + replicaSelectorOpt.flatMap { replicaSelector => + val replicaEndpoints = metadataCache.getPartitionReplicaEndpoints(partition.topicPartition, + new ListenerName(clientMetadata.listenerName)) + val replicaInfos = partition.remoteReplicas + // Exclude replicas that don't have the requested offset (whether or not if they're in the ISR) + .filter(replica => replica.logEndOffset >= fetchOffset && replica.logStartOffset <= fetchOffset) + .map(replica => new DefaultReplicaView( + replicaEndpoints.getOrElse(replica.brokerId, Node.noNode()), + replica.logEndOffset, + currentTimeMs - replica.lastCaughtUpTimeMs)) + + val leaderReplica = new DefaultReplicaView( + replicaEndpoints.getOrElse(leaderReplicaId, Node.noNode()), + partition.localLogOrException.logEndOffset, 0L) + val replicaInfoSet = mutable.Set[ReplicaView]() ++= replicaInfos += leaderReplica + + val partitionInfo = new DefaultPartitionView(replicaInfoSet.asJava, leaderReplica) + replicaSelector.select(partition.topicPartition, clientMetadata, partitionInfo).asScala.collect { + // Even though the replica selector can return the leader, we don't want to send it out with the + // FetchResponse, so we exclude it here + case selected if !selected.endpoint.isEmpty && selected != leaderReplica => selected.endpoint.id + } + } + } + } + } + + /** + * To avoid ISR thrashing, we only throttle a replica on the leader if it's in the throttled replica list, + * the quota is exceeded and the replica is not in sync. 
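+ * In-sync followers are therefore never throttled here, so replication throttling cannot push a
+ * replica out of the ISR.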
+ */ + def shouldLeaderThrottle(quota: ReplicaQuota, partition: Partition, replicaId: Int): Boolean = { + val isReplicaInSync = partition.inSyncReplicaIds.contains(replicaId) + !isReplicaInSync && quota.isThrottled(partition.topicPartition) && quota.isQuotaExceeded + } + + def getLogConfig(topicPartition: TopicPartition): Option[LogConfig] = localLog(topicPartition).map(_.config) + + def getMagic(topicPartition: TopicPartition): Option[Byte] = getLogConfig(topicPartition).map(_.recordVersion.value) + + def maybeUpdateMetadataCache(correlationId: Int, updateMetadataRequest: UpdateMetadataRequest) : Seq[TopicPartition] = { + replicaStateChangeLock synchronized { + if (updateMetadataRequest.controllerEpoch < controllerEpoch) { + val stateControllerEpochErrorMessage = s"Received update metadata request with correlation id $correlationId " + + s"from an old controller ${updateMetadataRequest.controllerId} with epoch ${updateMetadataRequest.controllerEpoch}. " + + s"Latest known controller epoch is $controllerEpoch" + stateChangeLogger.warn(stateControllerEpochErrorMessage) + throw new ControllerMovedException(stateChangeLogger.messageWithPrefix(stateControllerEpochErrorMessage)) + } else { + val zkMetadataCache = metadataCache.asInstanceOf[ZkMetadataCache] + val deletedPartitions = zkMetadataCache.updateMetadata(correlationId, updateMetadataRequest) + controllerEpoch = updateMetadataRequest.controllerEpoch + deletedPartitions + } + } + } + + def becomeLeaderOrFollower(correlationId: Int, + leaderAndIsrRequest: LeaderAndIsrRequest, + onLeadershipChange: (Iterable[Partition], Iterable[Partition]) => Unit): LeaderAndIsrResponse = { + val startMs = time.milliseconds() + replicaStateChangeLock synchronized { + val controllerId = leaderAndIsrRequest.controllerId + val requestPartitionStates = leaderAndIsrRequest.partitionStates.asScala + stateChangeLogger.info(s"Handling LeaderAndIsr request correlationId $correlationId from controller " + + s"$controllerId for ${requestPartitionStates.size} partitions") + if (stateChangeLogger.isTraceEnabled) + requestPartitionStates.foreach { partitionState => + stateChangeLogger.trace(s"Received LeaderAndIsr request $partitionState " + + s"correlation id $correlationId from controller $controllerId " + + s"epoch ${leaderAndIsrRequest.controllerEpoch}") + } + val topicIds = leaderAndIsrRequest.topicIds() + def topicIdFromRequest(topicName: String): Option[Uuid] = { + val topicId = topicIds.get(topicName) + // if invalid topic ID return None + if (topicId == null || topicId == Uuid.ZERO_UUID) + None + else + Some(topicId) + } + + val response = { + if (leaderAndIsrRequest.controllerEpoch < controllerEpoch) { + stateChangeLogger.warn(s"Ignoring LeaderAndIsr request from controller $controllerId with " + + s"correlation id $correlationId since its controller epoch ${leaderAndIsrRequest.controllerEpoch} is old. 
" + + s"Latest known controller epoch is $controllerEpoch") + leaderAndIsrRequest.getErrorResponse(0, Errors.STALE_CONTROLLER_EPOCH.exception) + } else { + val responseMap = new mutable.HashMap[TopicPartition, Errors] + controllerEpoch = leaderAndIsrRequest.controllerEpoch + + val partitions = new mutable.HashSet[Partition]() + val partitionsToBeLeader = new mutable.HashMap[Partition, LeaderAndIsrPartitionState]() + val partitionsToBeFollower = new mutable.HashMap[Partition, LeaderAndIsrPartitionState]() + val topicIdUpdateFollowerPartitions = new mutable.HashSet[Partition]() + + // First create the partition if it doesn't exist already + requestPartitionStates.foreach { partitionState => + val topicPartition = new TopicPartition(partitionState.topicName, partitionState.partitionIndex) + val partitionOpt = getPartition(topicPartition) match { + case HostedPartition.Offline => + stateChangeLogger.warn(s"Ignoring LeaderAndIsr request from " + + s"controller $controllerId with correlation id $correlationId " + + s"epoch $controllerEpoch for partition $topicPartition as the local replica for the " + + "partition is in an offline log directory") + responseMap.put(topicPartition, Errors.KAFKA_STORAGE_ERROR) + None + + case HostedPartition.Online(partition) => + Some(partition) + + case HostedPartition.None => + val partition = Partition(topicPartition, time, this) + allPartitions.putIfNotExists(topicPartition, HostedPartition.Online(partition)) + Some(partition) + } + + // Next check the topic ID and the partition's leader epoch + partitionOpt.foreach { partition => + val currentLeaderEpoch = partition.getLeaderEpoch + val requestLeaderEpoch = partitionState.leaderEpoch + val requestTopicId = topicIdFromRequest(topicPartition.topic) + val logTopicId = partition.topicId + + if (!hasConsistentTopicId(requestTopicId, logTopicId)) { + stateChangeLogger.error(s"Topic ID in memory: ${logTopicId.get} does not" + + s" match the topic ID for partition $topicPartition received: " + + s"${requestTopicId.get}.") + responseMap.put(topicPartition, Errors.INCONSISTENT_TOPIC_ID) + } else if (requestLeaderEpoch > currentLeaderEpoch) { + // If the leader epoch is valid record the epoch of the controller that made the leadership decision. 
+ // This is useful while updating the isr to maintain the decision maker controller's epoch in the zookeeper path + if (partitionState.replicas.contains(localBrokerId)) { + partitions += partition + if (partitionState.leader == localBrokerId) { + partitionsToBeLeader.put(partition, partitionState) + } else { + partitionsToBeFollower.put(partition, partitionState) + } + } else { + stateChangeLogger.warn(s"Ignoring LeaderAndIsr request from controller $controllerId with " + + s"correlation id $correlationId epoch $controllerEpoch for partition $topicPartition as itself is not " + + s"in assigned replica list ${partitionState.replicas.asScala.mkString(",")}") + responseMap.put(topicPartition, Errors.UNKNOWN_TOPIC_OR_PARTITION) + } + } else if (requestLeaderEpoch < currentLeaderEpoch) { + stateChangeLogger.warn(s"Ignoring LeaderAndIsr request from " + + s"controller $controllerId with correlation id $correlationId " + + s"epoch $controllerEpoch for partition $topicPartition since its associated " + + s"leader epoch $requestLeaderEpoch is smaller than the current " + + s"leader epoch $currentLeaderEpoch") + responseMap.put(topicPartition, Errors.STALE_CONTROLLER_EPOCH) + } else { + val error = requestTopicId match { + case Some(topicId) if logTopicId.isEmpty => + // The controller may send LeaderAndIsr to upgrade to using topic IDs without bumping the epoch. + // If we have a matching epoch, we expect the log to be defined. + val log = localLogOrException(partition.topicPartition) + log.assignTopicId(topicId) + stateChangeLogger.info(s"Updating log for $topicPartition to assign topic ID " + + s"$topicId from LeaderAndIsr request from controller $controllerId with correlation " + + s"id $correlationId epoch $controllerEpoch") + if (partitionState.leader != localBrokerId) + topicIdUpdateFollowerPartitions.add(partition) + Errors.NONE + case None if logTopicId.isDefined && partitionState.leader != localBrokerId => + // If we have a topic ID in the log but not in the request, we must have previously had topic IDs but + // are now downgrading. If we are a follower, remove the topic ID from the PartitionFetchState. 
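The epoch-equal branch that follows reconciles topic IDs between the request and the local log. As a minimal standalone sketch of that decision table, with illustrative names that are not part of this patch (java.util.UUID stands in for Kafka's Uuid):

import java.util.UUID

sealed trait TopicIdAction
case object AssignIdToLog          extends TopicIdAction // upgrade: the request carries an ID the log lacks
case object RemoveIdFromFetchState extends TopicIdAction // downgrade: a follower drops the ID it no longer receives
case object NothingToReconcile     extends TopicIdAction // neither case applies; the request is reported as stale

def reconcileTopicId(requestId: Option[UUID], logId: Option[UUID], isLocalLeader: Boolean): TopicIdAction =
  (requestId, logId) match {
    case (Some(_), None)                   => AssignIdToLog
    case (None, Some(_)) if !isLocalLeader => RemoveIdFromFetchState
    case _                                 => NothingToReconcile
  }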
+ stateChangeLogger.info(s"Updating PartitionFetchState for $topicPartition to remove log topic ID " + + s"${logTopicId.get} since LeaderAndIsr request from controller $controllerId with correlation " + + s"id $correlationId epoch $controllerEpoch did not contain a topic ID") + topicIdUpdateFollowerPartitions.add(partition) + Errors.NONE + case _ => + stateChangeLogger.info(s"Ignoring LeaderAndIsr request from " + + s"controller $controllerId with correlation id $correlationId " + + s"epoch $controllerEpoch for partition $topicPartition since its associated " + + s"leader epoch $requestLeaderEpoch matches the current leader epoch") + Errors.STALE_CONTROLLER_EPOCH + } + responseMap.put(topicPartition, error) + } + } + } + + val highWatermarkCheckpoints = new LazyOffsetCheckpoints(this.highWatermarkCheckpoints) + val partitionsBecomeLeader = if (partitionsToBeLeader.nonEmpty) + makeLeaders(controllerId, controllerEpoch, partitionsToBeLeader, correlationId, responseMap, + highWatermarkCheckpoints, topicIdFromRequest) + else + Set.empty[Partition] + val partitionsBecomeFollower = if (partitionsToBeFollower.nonEmpty) + makeFollowers(controllerId, controllerEpoch, partitionsToBeFollower, correlationId, responseMap, + highWatermarkCheckpoints, topicIdFromRequest) + else + Set.empty[Partition] + + val followerTopicSet = partitionsBecomeFollower.map(_.topic).toSet + updateLeaderAndFollowerMetrics(followerTopicSet) + + if (topicIdUpdateFollowerPartitions.nonEmpty) + updateTopicIdForFollowers(controllerId, controllerEpoch, topicIdUpdateFollowerPartitions, correlationId, topicIdFromRequest) + + // We initialize highwatermark thread after the first LeaderAndIsr request. This ensures that all the partitions + // have been completely populated before starting the checkpointing there by avoiding weird race conditions + startHighWatermarkCheckPointThread() + + maybeAddLogDirFetchers(partitions, highWatermarkCheckpoints, topicIdFromRequest) + + replicaFetcherManager.shutdownIdleFetcherThreads() + replicaAlterLogDirsManager.shutdownIdleFetcherThreads() + onLeadershipChange(partitionsBecomeLeader, partitionsBecomeFollower) + + val data = new LeaderAndIsrResponseData().setErrorCode(Errors.NONE.code) + if (leaderAndIsrRequest.version < 5) { + responseMap.forKeyValue { (tp, error) => + data.partitionErrors.add(new LeaderAndIsrPartitionError() + .setTopicName(tp.topic) + .setPartitionIndex(tp.partition) + .setErrorCode(error.code)) + } + } else { + responseMap.forKeyValue { (tp, error) => + val topicId = topicIds.get(tp.topic) + var topic = data.topics.find(topicId) + if (topic == null) { + topic = new LeaderAndIsrTopicError().setTopicId(topicId) + data.topics.add(topic) + } + topic.partitionErrors.add(new LeaderAndIsrPartitionError() + .setPartitionIndex(tp.partition) + .setErrorCode(error.code)) + } + } + new LeaderAndIsrResponse(data, leaderAndIsrRequest.version) + } + } + val endMs = time.milliseconds() + val elapsedMs = endMs - startMs + stateChangeLogger.info(s"Finished LeaderAndIsr request in ${elapsedMs}ms correlationId $correlationId from controller " + + s"$controllerId for ${requestPartitionStates.size} partitions") + response + } + } + + /** + * Checks if the topic ID provided in the request is consistent with the topic ID in the log. + * When using this method to handle a Fetch request, the topic ID may have been provided by an earlier request. + * + * If the request had an invalid topic ID (null or zero), then we assume that topic IDs are not supported. 
+ * The topic ID was not inconsistent, so return true. + * If the log does not exist or the topic ID is not yet set, logTopicIdOpt will be None. + * In both cases, the ID is not inconsistent so return true. + * + * @param requestTopicIdOpt the topic ID from the request if it exists + * @param logTopicIdOpt the topic ID in the log if the log and the topic ID exist + * @return true if the request topic id is consistent, false otherwise + */ + private def hasConsistentTopicId(requestTopicIdOpt: Option[Uuid], logTopicIdOpt: Option[Uuid]): Boolean = { + requestTopicIdOpt match { + case None => true + case Some(requestTopicId) => logTopicIdOpt.isEmpty || logTopicIdOpt.contains(requestTopicId) + } + } + + /** + * KAFKA-8392 + * For topic partitions of which the broker is no longer a leader, delete metrics related to + * those topics. Note that this means the broker stops being either a replica or a leader of + * partitions of said topics + */ + protected def updateLeaderAndFollowerMetrics(newFollowerTopics: Set[String]): Unit = { + val leaderTopicSet = leaderPartitionsIterator.map(_.topic).toSet + newFollowerTopics.diff(leaderTopicSet).foreach(brokerTopicStats.removeOldLeaderMetrics) + + // remove metrics for brokers which are not followers of a topic + leaderTopicSet.diff(newFollowerTopics).foreach(brokerTopicStats.removeOldFollowerMetrics) + } + + protected def maybeAddLogDirFetchers(partitions: Set[Partition], + offsetCheckpoints: OffsetCheckpoints, + topicIds: String => Option[Uuid]): Unit = { + val futureReplicasAndInitialOffset = new mutable.HashMap[TopicPartition, InitialFetchState] + for (partition <- partitions) { + val topicPartition = partition.topicPartition + if (logManager.getLog(topicPartition, isFuture = true).isDefined) { + partition.log.foreach { log => + val leader = BrokerEndPoint(config.brokerId, "localhost", -1) + + // Add future replica log to partition's map + partition.createLogIfNotExists( + isNew = false, + isFutureReplica = true, + offsetCheckpoints, + topicIds(partition.topic)) + + // pause cleaning for partitions that are being moved and start ReplicaAlterDirThread to move + // replica from source dir to destination dir + logManager.abortAndPauseCleaning(topicPartition) + + futureReplicasAndInitialOffset.put(topicPartition, InitialFetchState(topicIds(topicPartition.topic), leader, + partition.getLeaderEpoch, log.highWatermark)) + } + } + } + + if (futureReplicasAndInitialOffset.nonEmpty) + replicaAlterLogDirsManager.addFetcherForPartitions(futureReplicasAndInitialOffset) + } + + /* + * Make the current broker to become leader for a given set of partitions by: + * + * 1. Stop fetchers for these partitions + * 2. Update the partition metadata in cache + * 3. Add these partitions to the leader partitions set + * + * If an unexpected error is thrown in this function, it will be propagated to KafkaApis where + * the error message will be set on each partition since we do not know which partition caused it. 
Otherwise, + * return the set of partitions that are made leader due to this method + * + * TODO: the above may need to be fixed later + */ + private def makeLeaders(controllerId: Int, + controllerEpoch: Int, + partitionStates: Map[Partition, LeaderAndIsrPartitionState], + correlationId: Int, + responseMap: mutable.Map[TopicPartition, Errors], + highWatermarkCheckpoints: OffsetCheckpoints, + topicIds: String => Option[Uuid]): Set[Partition] = { + val traceEnabled = stateChangeLogger.isTraceEnabled + partitionStates.keys.foreach { partition => + if (traceEnabled) + stateChangeLogger.trace(s"Handling LeaderAndIsr request correlationId $correlationId from " + + s"controller $controllerId epoch $controllerEpoch starting the become-leader transition for " + + s"partition ${partition.topicPartition}") + responseMap.put(partition.topicPartition, Errors.NONE) + } + + val partitionsToMakeLeaders = mutable.Set[Partition]() + + try { + // First stop fetchers for all the partitions + replicaFetcherManager.removeFetcherForPartitions(partitionStates.keySet.map(_.topicPartition)) + stateChangeLogger.info(s"Stopped fetchers as part of LeaderAndIsr request correlationId $correlationId from " + + s"controller $controllerId epoch $controllerEpoch as part of the become-leader transition for " + + s"${partitionStates.size} partitions") + // Update the partition information to be the leader + partitionStates.forKeyValue { (partition, partitionState) => + try { + if (partition.makeLeader(partitionState, highWatermarkCheckpoints, topicIds(partitionState.topicName))) + partitionsToMakeLeaders += partition + else + stateChangeLogger.info(s"Skipped the become-leader state change after marking its " + + s"partition as leader with correlation id $correlationId from controller $controllerId epoch $controllerEpoch for " + + s"partition ${partition.topicPartition} (last update controller epoch ${partitionState.controllerEpoch}) " + + s"since it is already the leader for the partition.") + } catch { + case e: KafkaStorageException => + stateChangeLogger.error(s"Skipped the become-leader state change with " + + s"correlation id $correlationId from controller $controllerId epoch $controllerEpoch for partition ${partition.topicPartition} " + + s"(last update controller epoch ${partitionState.controllerEpoch}) since " + + s"the replica for the partition is offline due to storage error $e") + // If there is an offline log directory, a Partition object may have been created and have been added + // to `ReplicaManager.allPartitions` before `createLogIfNotExists()` failed to create local replica due + // to KafkaStorageException. In this case `ReplicaManager.allPartitions` will map this topic-partition + // to an empty Partition object. We need to map this topic-partition to OfflinePartition instead. 
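// For context: markPartitionOffline, defined further down in this file, replaces the allPartitions
// entry with HostedPartition.Offline and removes the partition's metrics, so subsequent lookups see
// the replica as offline rather than as an empty Partition object.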
+ markPartitionOffline(partition.topicPartition) + responseMap.put(partition.topicPartition, Errors.KAFKA_STORAGE_ERROR) + } + } + + } catch { + case e: Throwable => + partitionStates.keys.foreach { partition => + stateChangeLogger.error(s"Error while processing LeaderAndIsr request correlationId $correlationId received " + + s"from controller $controllerId epoch $controllerEpoch for partition ${partition.topicPartition}", e) + } + // Re-throw the exception for it to be caught in KafkaApis + throw e + } + + if (traceEnabled) + partitionStates.keys.foreach { partition => + stateChangeLogger.trace(s"Completed LeaderAndIsr request correlationId $correlationId from controller $controllerId " + + s"epoch $controllerEpoch for the become-leader transition for partition ${partition.topicPartition}") + } + + partitionsToMakeLeaders + } + + /* + * Make the current broker to become follower for a given set of partitions by: + * + * 1. Remove these partitions from the leader partitions set. + * 2. Mark the replicas as followers so that no more data can be added from the producer clients. + * 3. Stop fetchers for these partitions so that no more data can be added by the replica fetcher threads. + * 4. Truncate the log and checkpoint offsets for these partitions. + * 5. Clear the produce and fetch requests in the purgatory + * 6. If the broker is not shutting down, add the fetcher to the new leaders. + * + * The ordering of doing these steps make sure that the replicas in transition will not + * take any more messages before checkpointing offsets so that all messages before the checkpoint + * are guaranteed to be flushed to disks + * + * If an unexpected error is thrown in this function, it will be propagated to KafkaApis where + * the error message will be set on each partition since we do not know which partition caused it. 
Otherwise, + * return the set of partitions that are made follower due to this method + */ + private def makeFollowers(controllerId: Int, + controllerEpoch: Int, + partitionStates: Map[Partition, LeaderAndIsrPartitionState], + correlationId: Int, + responseMap: mutable.Map[TopicPartition, Errors], + highWatermarkCheckpoints: OffsetCheckpoints, + topicIds: String => Option[Uuid]) : Set[Partition] = { + val traceLoggingEnabled = stateChangeLogger.isTraceEnabled + partitionStates.forKeyValue { (partition, partitionState) => + if (traceLoggingEnabled) + stateChangeLogger.trace(s"Handling LeaderAndIsr request correlationId $correlationId from controller $controllerId " + + s"epoch $controllerEpoch starting the become-follower transition for partition ${partition.topicPartition} with leader " + + s"${partitionState.leader}") + responseMap.put(partition.topicPartition, Errors.NONE) + } + + val partitionsToMakeFollower: mutable.Set[Partition] = mutable.Set() + try { + // TODO: Delete leaders from LeaderAndIsrRequest + partitionStates.forKeyValue { (partition, partitionState) => + val newLeaderBrokerId = partitionState.leader + try { + if (metadataCache.hasAliveBroker(newLeaderBrokerId)) { + // Only change partition state when the leader is available + if (partition.makeFollower(partitionState, highWatermarkCheckpoints, topicIds(partitionState.topicName))) + partitionsToMakeFollower += partition + else + stateChangeLogger.info(s"Skipped the become-follower state change after marking its partition as " + + s"follower with correlation id $correlationId from controller $controllerId epoch $controllerEpoch " + + s"for partition ${partition.topicPartition} (last update " + + s"controller epoch ${partitionState.controllerEpoch}) " + + s"since the new leader $newLeaderBrokerId is the same as the old leader") + } else { + // The leader broker should always be present in the metadata cache. + // If not, we should record the error message and abort the transition process for this partition + stateChangeLogger.error(s"Received LeaderAndIsrRequest with correlation id $correlationId from " + + s"controller $controllerId epoch $controllerEpoch for partition ${partition.topicPartition} " + + s"(last update controller epoch ${partitionState.controllerEpoch}) " + + s"but cannot become follower since the new leader $newLeaderBrokerId is unavailable.") + // Create the local replica even if the leader is unavailable. This is required to ensure that we include + // the partition's high watermark in the checkpoint file (see KAFKA-1647) + partition.createLogIfNotExists(isNew = partitionState.isNew, isFutureReplica = false, + highWatermarkCheckpoints, topicIds(partitionState.topicName)) + } + } catch { + case e: KafkaStorageException => + stateChangeLogger.error(s"Skipped the become-follower state change with correlation id $correlationId from " + + s"controller $controllerId epoch $controllerEpoch for partition ${partition.topicPartition} " + + s"(last update controller epoch ${partitionState.controllerEpoch}) with leader " + + s"$newLeaderBrokerId since the replica for the partition is offline due to storage error $e") + // If there is an offline log directory, a Partition object may have been created and have been added + // to `ReplicaManager.allPartitions` before `createLogIfNotExists()` failed to create local replica due + // to KafkaStorageException. In this case `ReplicaManager.allPartitions` will map this topic-partition + // to an empty Partition object. 
We need to map this topic-partition to OfflinePartition instead. + markPartitionOffline(partition.topicPartition) + responseMap.put(partition.topicPartition, Errors.KAFKA_STORAGE_ERROR) + } + } + + // Stopping the fetchers must be done first in order to initialize the fetch + // position correctly. + replicaFetcherManager.removeFetcherForPartitions(partitionsToMakeFollower.map(_.topicPartition)) + stateChangeLogger.info(s"Stopped fetchers as part of become-follower request from controller $controllerId " + + s"epoch $controllerEpoch with correlation id $correlationId for ${partitionsToMakeFollower.size} partitions") + + partitionsToMakeFollower.foreach { partition => + completeDelayedFetchOrProduceRequests(partition.topicPartition) + } + + if (isShuttingDown.get()) { + if (traceLoggingEnabled) { + partitionsToMakeFollower.foreach { partition => + stateChangeLogger.trace(s"Skipped the adding-fetcher step of the become-follower state " + + s"change with correlation id $correlationId from controller $controllerId epoch $controllerEpoch for " + + s"partition ${partition.topicPartition} with leader ${partitionStates(partition).leader} " + + "since it is shutting down") + } + } + } else { + // we do not need to check if the leader exists again since this has been done at the beginning of this process + val partitionsToMakeFollowerWithLeaderAndOffset = partitionsToMakeFollower.map { partition => + val leaderNode = partition.leaderReplicaIdOpt.flatMap(leaderId => metadataCache. + getAliveBrokerNode(leaderId, config.interBrokerListenerName)).getOrElse(Node.noNode()) + val leader = new BrokerEndPoint(leaderNode.id(), leaderNode.host(), leaderNode.port()) + val log = partition.localLogOrException + val fetchOffset = initialFetchOffset(log) + partition.topicPartition -> InitialFetchState(topicIds(partition.topic), leader, partition.getLeaderEpoch, fetchOffset) + }.toMap + + replicaFetcherManager.addFetcherForPartitions(partitionsToMakeFollowerWithLeaderAndOffset) + } + } catch { + case e: Throwable => + stateChangeLogger.error(s"Error while processing LeaderAndIsr request with correlationId $correlationId " + + s"received from controller $controllerId epoch $controllerEpoch", e) + // Re-throw the exception for it to be caught in KafkaApis + throw e + } + + if (traceLoggingEnabled) + partitionStates.keys.foreach { partition => + stateChangeLogger.trace(s"Completed LeaderAndIsr request correlationId $correlationId from controller $controllerId " + + s"epoch $controllerEpoch for the become-follower transition for partition ${partition.topicPartition} with leader " + + s"${partitionStates(partition).leader}") + } + + partitionsToMakeFollower + } + + private def updateTopicIdForFollowers(controllerId: Int, + controllerEpoch: Int, + partitions: Set[Partition], + correlationId: Int, + topicIds: String => Option[Uuid]): Unit = { + val traceLoggingEnabled = stateChangeLogger.isTraceEnabled + + try { + if (isShuttingDown.get()) { + if (traceLoggingEnabled) { + partitions.foreach { partition => + stateChangeLogger.trace(s"Skipped the update topic ID step of the become-follower state " + + s"change with correlation id $correlationId from controller $controllerId epoch $controllerEpoch for " + + s"partition ${partition.topicPartition} since it is shutting down") + } + } + } else { + val partitionsToUpdateFollowerWithLeader = mutable.Map.empty[TopicPartition, Int] + partitions.foreach { partition => + partition.leaderReplicaIdOpt.foreach { leader => + if (metadataCache.hasAliveBroker(leader)) { + 
partitionsToUpdateFollowerWithLeader += partition.topicPartition -> leader + } + } + } + replicaFetcherManager.maybeUpdateTopicIds(partitionsToUpdateFollowerWithLeader, topicIds) + } + } catch { + case e: Throwable => + stateChangeLogger.error(s"Error while processing LeaderAndIsr request with correlationId $correlationId " + + s"received from controller $controllerId epoch $controllerEpoch when trying to update topic IDs in the fetchers", e) + // Re-throw the exception for it to be caught in KafkaApis + throw e + } + } + + /** + * From IBP 2.7 onwards, we send latest fetch epoch in the request and truncate if a + * diverging epoch is returned in the response, avoiding the need for a separate + * OffsetForLeaderEpoch request. + */ + protected def initialFetchOffset(log: UnifiedLog): Long = { + if (ApiVersion.isTruncationOnFetchSupported(config.interBrokerProtocolVersion) && log.latestEpoch.nonEmpty) + log.logEndOffset + else + log.highWatermark + } + + private def maybeShrinkIsr(): Unit = { + trace("Evaluating ISR list of partitions to see which replicas can be removed from the ISR") + + // Shrink ISRs for non offline partitions + allPartitions.keys.foreach { topicPartition => + onlinePartition(topicPartition).foreach(_.maybeShrinkIsr()) + } + } + + /** + * Update the follower's fetch state on the leader based on the last fetch request and update `readResult`. + * If the follower replica is not recognized to be one of the assigned replicas, do not update + * `readResult` so that log start/end offset and high watermark is consistent with + * records in fetch response. Log start/end offset and high watermark may change not only due to + * this fetch request, e.g., rolling new log segment and removing old log segment may move log + * start offset further than the last offset in the fetched records. The followers will get the + * updated leader's state in the next fetch response. If follower has a diverging epoch or if read + * fails with any error, follower fetch state is not updated. + */ + private def updateFollowerFetchState(followerId: Int, + readResults: Seq[(TopicIdPartition, LogReadResult)]): Seq[(TopicIdPartition, LogReadResult)] = { + readResults.map { case (topicIdPartition, readResult) => + val updatedReadResult = if (readResult.error != Errors.NONE) { + debug(s"Skipping update of fetch state for follower $followerId since the " + + s"log read returned error ${readResult.error}") + readResult + } else if (readResult.divergingEpoch.nonEmpty) { + debug(s"Skipping update of fetch state for follower $followerId since the " + + s"log read returned diverging epoch ${readResult.divergingEpoch}") + readResult + } else { + onlinePartition(topicIdPartition.topicPartition) match { + case Some(partition) => + if (partition.updateFollowerFetchState(followerId, + followerFetchOffsetMetadata = readResult.info.fetchOffsetMetadata, + followerStartOffset = readResult.followerLogStartOffset, + followerFetchTimeMs = readResult.fetchTimeMs, + leaderEndOffset = readResult.leaderLogEndOffset)) { + readResult + } else { + warn(s"Leader $localBrokerId failed to record follower $followerId's position " + + s"${readResult.info.fetchOffsetMetadata.messageOffset}, and last sent HW since the replica " + + s"is not recognized to be one of the assigned replicas ${partition.assignmentState.replicas.mkString(",")} " + + s"for partition $topicIdPartition. 
Empty records will be returned for this partition.") + readResult.withEmptyFetchInfo + } + case None => + warn(s"While recording the replica LEO, the partition $topicIdPartition hasn't been created.") + readResult + } + } + topicIdPartition -> updatedReadResult + } + } + + private def leaderPartitionsIterator: Iterator[Partition] = + onlinePartitionsIterator.filter(_.leaderLogIfLocal.isDefined) + + def getLogEndOffset(topicPartition: TopicPartition): Option[Long] = + onlinePartition(topicPartition).flatMap(_.leaderLogIfLocal.map(_.logEndOffset)) + + // Flushes the highwatermark value for all partitions to the highwatermark file + def checkpointHighWatermarks(): Unit = { + def putHw(logDirToCheckpoints: mutable.AnyRefMap[String, mutable.AnyRefMap[TopicPartition, Long]], + log: UnifiedLog): Unit = { + val checkpoints = logDirToCheckpoints.getOrElseUpdate(log.parentDir, + new mutable.AnyRefMap[TopicPartition, Long]()) + checkpoints.put(log.topicPartition, log.highWatermark) + } + + val logDirToHws = new mutable.AnyRefMap[String, mutable.AnyRefMap[TopicPartition, Long]]( + allPartitions.size) + onlinePartitionsIterator.foreach { partition => + partition.log.foreach(putHw(logDirToHws, _)) + partition.futureLog.foreach(putHw(logDirToHws, _)) + } + + for ((logDir, hws) <- logDirToHws) { + try highWatermarkCheckpoints.get(logDir).foreach(_.write(hws)) + catch { + case e: KafkaStorageException => + error(s"Error while writing to highwatermark file in directory $logDir", e) + } + } + } + + def markPartitionOffline(tp: TopicPartition): Unit = replicaStateChangeLock synchronized { + allPartitions.put(tp, HostedPartition.Offline) + Partition.removeMetrics(tp) + } + + /** + * The log directory failure handler for the replica + * + * @param dir the absolute path of the log directory + * @param sendZkNotification check if we need to send notification to zookeeper node (needed for unit test) + */ + def handleLogDirFailure(dir: String, sendZkNotification: Boolean = true): Unit = { + if (!logManager.isLogDirOnline(dir)) + return + warn(s"Stopping serving replicas in dir $dir") + replicaStateChangeLock synchronized { + val newOfflinePartitions = onlinePartitionsIterator.filter { partition => + partition.log.exists { _.parentDir == dir } + }.map(_.topicPartition).toSet + + val partitionsWithOfflineFutureReplica = onlinePartitionsIterator.filter { partition => + partition.futureLog.exists { _.parentDir == dir } + }.toSet + + replicaFetcherManager.removeFetcherForPartitions(newOfflinePartitions) + replicaAlterLogDirsManager.removeFetcherForPartitions(newOfflinePartitions ++ partitionsWithOfflineFutureReplica.map(_.topicPartition)) + + partitionsWithOfflineFutureReplica.foreach(partition => partition.removeFutureLocalReplica(deleteFromLogDir = false)) + newOfflinePartitions.foreach { topicPartition => + markPartitionOffline(topicPartition) + } + newOfflinePartitions.map(_.topic).foreach { topic: String => + maybeRemoveTopicMetrics(topic) + } + highWatermarkCheckpoints = highWatermarkCheckpoints.filter { case (checkpointDir, _) => checkpointDir != dir } + + warn(s"Broker $localBrokerId stopped fetcher for partitions ${newOfflinePartitions.mkString(",")} and stopped moving logs " + + s"for partitions ${partitionsWithOfflineFutureReplica.mkString(",")} because they are in the failed log directory $dir.") + } + logManager.handleLogDirFailure(dir) + + if (sendZkNotification) + if (zkClient.isEmpty) { + warn("Unable to propagate log dir failure via Zookeeper in KRaft mode") + } else { + 
zkClient.get.propagateLogDirEvent(localBrokerId) + } + warn(s"Stopped serving replicas in dir $dir") + } + + def removeMetrics(): Unit = { + removeMetric("LeaderCount") + removeMetric("PartitionCount") + removeMetric("OfflineReplicaCount") + removeMetric("UnderReplicatedPartitions") + removeMetric("UnderMinIsrPartitionCount") + removeMetric("AtMinIsrPartitionCount") + removeMetric("ReassigningPartitions") + } + + // High watermark do not need to be checkpointed only when under unit tests + def shutdown(checkpointHW: Boolean = true): Unit = { + info("Shutting down") + removeMetrics() + if (logDirFailureHandler != null) + logDirFailureHandler.shutdown() + replicaFetcherManager.shutdown() + replicaAlterLogDirsManager.shutdown() + delayedFetchPurgatory.shutdown() + delayedProducePurgatory.shutdown() + delayedDeleteRecordsPurgatory.shutdown() + delayedElectLeaderPurgatory.shutdown() + if (checkpointHW) + checkpointHighWatermarks() + replicaSelectorOpt.foreach(_.close) + info("Shut down completely") + } + + protected def createReplicaFetcherManager(metrics: Metrics, time: Time, threadNamePrefix: Option[String], quotaManager: ReplicationQuotaManager) = { + new ReplicaFetcherManager(config, this, metrics, time, threadNamePrefix, quotaManager) + } + + protected def createReplicaAlterLogDirsManager(quotaManager: ReplicationQuotaManager, brokerTopicStats: BrokerTopicStats) = { + new ReplicaAlterLogDirsManager(config, this, quotaManager, brokerTopicStats) + } + + protected def createReplicaSelector(): Option[ReplicaSelector] = { + config.replicaSelectorClassName.map { className => + val tmpReplicaSelector: ReplicaSelector = CoreUtils.createObject[ReplicaSelector](className) + tmpReplicaSelector.configure(config.originals()) + tmpReplicaSelector + } + } + + def lastOffsetForLeaderEpoch( + requestedEpochInfo: Seq[OffsetForLeaderTopic] + ): Seq[OffsetForLeaderTopicResult] = { + requestedEpochInfo.map { offsetForLeaderTopic => + val partitions = offsetForLeaderTopic.partitions.asScala.map { offsetForLeaderPartition => + val tp = new TopicPartition(offsetForLeaderTopic.topic, offsetForLeaderPartition.partition) + getPartition(tp) match { + case HostedPartition.Online(partition) => + val currentLeaderEpochOpt = + if (offsetForLeaderPartition.currentLeaderEpoch == RecordBatch.NO_PARTITION_LEADER_EPOCH) + Optional.empty[Integer] + else + Optional.of[Integer](offsetForLeaderPartition.currentLeaderEpoch) + + partition.lastOffsetForLeaderEpoch( + currentLeaderEpochOpt, + offsetForLeaderPartition.leaderEpoch, + fetchOnlyFromLeader = true) + + case HostedPartition.Offline => + new EpochEndOffset() + .setPartition(offsetForLeaderPartition.partition) + .setErrorCode(Errors.KAFKA_STORAGE_ERROR.code) + + case HostedPartition.None if metadataCache.contains(tp) => + new EpochEndOffset() + .setPartition(offsetForLeaderPartition.partition) + .setErrorCode(Errors.NOT_LEADER_OR_FOLLOWER.code) + + case HostedPartition.None => + new EpochEndOffset() + .setPartition(offsetForLeaderPartition.partition) + .setErrorCode(Errors.UNKNOWN_TOPIC_OR_PARTITION.code) + } + } + + new OffsetForLeaderTopicResult() + .setTopic(offsetForLeaderTopic.topic) + .setPartitions(partitions.toList.asJava) + } + } + + def electLeaders( + controller: KafkaController, + partitions: Set[TopicPartition], + electionType: ElectionType, + responseCallback: Map[TopicPartition, ApiError] => Unit, + requestTimeout: Int + ): Unit = { + + val deadline = time.milliseconds() + requestTimeout + + def electionCallback(results: Map[TopicPartition, Either[ApiError, 
Int]]): Unit = { + val expectedLeaders = mutable.Map.empty[TopicPartition, Int] + val failures = mutable.Map.empty[TopicPartition, ApiError] + results.foreach { + case (partition, Right(leader)) => expectedLeaders += partition -> leader + case (partition, Left(error)) => failures += partition -> error + } + if (expectedLeaders.nonEmpty) { + val watchKeys = expectedLeaders.iterator.map { + case (tp, _) => TopicPartitionOperationKey(tp) + }.toBuffer + + delayedElectLeaderPurgatory.tryCompleteElseWatch( + new DelayedElectLeader( + math.max(0, deadline - time.milliseconds()), + expectedLeaders, + failures, + this, + responseCallback + ), + watchKeys + ) + } else { + // There are no partitions actually being elected, so return immediately + responseCallback(failures) + } + } + + controller.electLeaders(partitions, electionType, electionCallback) + } + + def activeProducerState(requestPartition: TopicPartition): DescribeProducersResponseData.PartitionResponse = { + getPartitionOrError(requestPartition) match { + case Left(error) => new DescribeProducersResponseData.PartitionResponse() + .setPartitionIndex(requestPartition.partition) + .setErrorCode(error.code) + case Right(partition) => partition.activeProducerState + } + } + + private[kafka] def getOrCreatePartition(tp: TopicPartition, + delta: TopicsDelta, + topicId: Uuid): Option[(Partition, Boolean)] = { + getPartition(tp) match { + case HostedPartition.Offline => + stateChangeLogger.warn(s"Unable to bring up new local leader ${tp} " + + s"with topic id ${topicId} because it resides in an offline log " + + "directory.") + None + + case HostedPartition.Online(partition) => { + if (partition.topicId.exists(_ != topicId)) { + // Note: Partition#topicId will be None here if the Log object for this partition + // has not been created. + throw new IllegalStateException(s"Topic ${tp} exists, but its ID is " + + s"${partition.topicId.get}, not ${topicId} as expected") + } + Some(partition, false) + } + + case HostedPartition.None => + if (delta.image().topicsById().containsKey(topicId)) { + stateChangeLogger.error(s"Expected partition ${tp} with topic id " + + s"${topicId} to exist, but it was missing. Creating...") + } else { + stateChangeLogger.info(s"Creating new partition ${tp} with topic id " + + s"${topicId}.") + } + // it's a partition that we don't know about yet, so create it and mark it online + val partition = Partition(tp, time, this) + allPartitions.put(tp, HostedPartition.Online(partition)) + Some(partition, true) + } + } + + /** + * Apply a KRaft topic change delta. + * + * @param delta The delta to apply. + * @param newImage The new metadata image. + */ + def applyDelta(delta: TopicsDelta, newImage: MetadataImage): Unit = { + // Before taking the lock, compute the local changes + val localChanges = delta.localChanges(config.nodeId) + + replicaStateChangeLock.synchronized { + // Handle deleted partitions. We need to do this first because we might subsequently + // create new partitions with the same names as the ones we are deleting here. 
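// Ordering inside the lock, as implemented below: (1) stop and delete the partitions removed by the
// delta, (2) apply leader changes, (3) apply follower changes, (4) add any log-dir (future replica)
// fetchers, then (5) shut down fetcher threads left idle by the transitions.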
+ if (!localChanges.deletes.isEmpty) { + val deletes = localChanges.deletes.asScala.map(tp => (tp, true)).toMap + stateChangeLogger.info(s"Deleting ${deletes.size} partition(s).") + stopPartitions(deletes).forKeyValue { (topicPartition, e) => + if (e.isInstanceOf[KafkaStorageException]) { + stateChangeLogger.error(s"Unable to delete replica ${topicPartition} because " + + "the local replica for the partition is in an offline log directory") + } else { + stateChangeLogger.error(s"Unable to delete replica ${topicPartition} because " + + s"we got an unexpected ${e.getClass.getName} exception: ${e.getMessage}") + } + } + } + + // Handle partitions which we are now the leader or follower for. + if (!localChanges.leaders.isEmpty || !localChanges.followers.isEmpty) { + val lazyOffsetCheckpoints = new LazyOffsetCheckpoints(this.highWatermarkCheckpoints) + val changedPartitions = new mutable.HashSet[Partition] + if (!localChanges.leaders.isEmpty) { + applyLocalLeadersDelta(changedPartitions, delta, lazyOffsetCheckpoints, localChanges.leaders.asScala) + } + if (!localChanges.followers.isEmpty) { + applyLocalFollowersDelta(changedPartitions, newImage, delta, lazyOffsetCheckpoints, localChanges.followers.asScala) + } + maybeAddLogDirFetchers(changedPartitions, lazyOffsetCheckpoints, + name => Option(newImage.topics().getTopic(name)).map(_.id())) + + replicaFetcherManager.shutdownIdleFetcherThreads() + replicaAlterLogDirsManager.shutdownIdleFetcherThreads() + } + } + } + + private def applyLocalLeadersDelta( + changedPartitions: mutable.Set[Partition], + delta: TopicsDelta, + offsetCheckpoints: OffsetCheckpoints, + newLocalLeaders: mutable.Map[TopicPartition, LocalReplicaChanges.PartitionInfo] + ): Unit = { + stateChangeLogger.info(s"Transitioning ${newLocalLeaders.size} partition(s) to " + + "local leaders.") + replicaFetcherManager.removeFetcherForPartitions(newLocalLeaders.keySet) + newLocalLeaders.forKeyValue { (tp, info) => + getOrCreatePartition(tp, delta, info.topicId).foreach { case (partition, isNew) => + try { + val state = info.partition.toLeaderAndIsrPartitionState(tp, isNew) + if (!partition.makeLeader(state, offsetCheckpoints, Some(info.topicId))) { + stateChangeLogger.info("Skipped the become-leader state change for " + + s"$tp with topic id ${info.topicId} because this partition is " + + "already a local leader.") + } + changedPartitions.add(partition) + } catch { + case e: KafkaStorageException => + stateChangeLogger.info(s"Skipped the become-leader state change for $tp " + + s"with topic id ${info.topicId} due to a storage error ${e.getMessage}") + // If there is an offline log directory, a Partition object may have been created by + // `getOrCreatePartition()` before `createLogIfNotExists()` failed to create local replica due + // to KafkaStorageException. In this case `ReplicaManager.allPartitions` will map this topic-partition + // to an empty Partition object. We need to map this topic-partition to OfflinePartition instead. 
+ markPartitionOffline(tp) + } + } + } + } + + private def applyLocalFollowersDelta( + changedPartitions: mutable.Set[Partition], + newImage: MetadataImage, + delta: TopicsDelta, + offsetCheckpoints: OffsetCheckpoints, + newLocalFollowers: mutable.Map[TopicPartition, LocalReplicaChanges.PartitionInfo] + ): Unit = { + stateChangeLogger.info(s"Transitioning ${newLocalFollowers.size} partition(s) to " + + "local followers.") + val shuttingDown = isShuttingDown.get() + val partitionsToMakeFollower = new mutable.HashMap[TopicPartition, Partition] + val newFollowerTopicSet = new mutable.HashSet[String] + newLocalFollowers.forKeyValue { (tp, info) => + getOrCreatePartition(tp, delta, info.topicId).foreach { case (partition, isNew) => + try { + newFollowerTopicSet.add(tp.topic) + + if (shuttingDown) { + stateChangeLogger.trace(s"Unable to start fetching $tp with topic " + + s"ID ${info.topicId} because the replica manager is shutting down.") + } else { + val leader = info.partition.leader + if (newImage.cluster.broker(leader) == null) { + stateChangeLogger.trace(s"Unable to start fetching $tp with topic ID ${info.topicId} " + + s"from leader $leader because it is not alive.") + + // Create the local replica even if the leader is unavailable. This is required + // to ensure that we include the partition's high watermark in the checkpoint + // file (see KAFKA-1647). + partition.createLogIfNotExists(isNew, false, offsetCheckpoints, Some(info.topicId)) + } else { + val state = info.partition.toLeaderAndIsrPartitionState(tp, isNew) + if (partition.makeFollower(state, offsetCheckpoints, Some(info.topicId))) { + partitionsToMakeFollower.put(tp, partition) + } else { + stateChangeLogger.info("Skipped the become-follower state change after marking its " + + s"partition as follower for partition $tp with id ${info.topicId} and partition state $state.") + } + } + } + changedPartitions.add(partition) + } catch { + case e: KafkaStorageException => + stateChangeLogger.error(s"Unable to start fetching $tp " + + s"with topic ID ${info.topicId} due to a storage error ${e.getMessage}", e) + replicaFetcherManager.addFailedPartition(tp) + // If there is an offline log directory, a Partition object may have been created by + // `getOrCreatePartition()` before `createLogIfNotExists()` failed to create local replica due + // to KafkaStorageException. In this case `ReplicaManager.allPartitions` will map this topic-partition + // to an empty Partition object. We need to map this topic-partition to OfflinePartition instead. + markPartitionOffline(tp) + + case e: Throwable => + stateChangeLogger.error(s"Unable to start fetching $tp " + + s"with topic ID ${info.topicId} due to ${e.getClass.getSimpleName}", e) + replicaFetcherManager.addFailedPartition(tp) + } + } + } + + // Stopping the fetchers must be done first in order to initialize the fetch + // position correctly. 
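The fetch position referred to above comes from initialFetchOffset, defined earlier in this file. A condensed sketch of that choice, using illustrative parameter names that are not part of the patch:

// Start from the log end offset when truncation on fetch is available (IBP 2.7+ and the log has a
// latest epoch), relying on diverging-epoch truncation; otherwise fall back to the high watermark.
def chooseInitialFetchOffset(truncationOnFetchSupported: Boolean,
                             hasLatestEpoch: Boolean,
                             logEndOffset: Long,
                             highWatermark: Long): Long =
  if (truncationOnFetchSupported && hasLatestEpoch) logEndOffset
  else highWatermark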
+ replicaFetcherManager.removeFetcherForPartitions(partitionsToMakeFollower.keySet) + stateChangeLogger.info(s"Stopped fetchers as part of become-follower for ${partitionsToMakeFollower.size} partitions") + + val listenerName = config.interBrokerListenerName.value + val partitionAndOffsets = new mutable.HashMap[TopicPartition, InitialFetchState] + partitionsToMakeFollower.forKeyValue { (topicPartition, partition) => + val node = partition.leaderReplicaIdOpt + .flatMap(leaderId => Option(newImage.cluster.broker(leaderId))) + .flatMap(_.node(listenerName).asScala) + .getOrElse(Node.noNode) + val log = partition.localLogOrException + partitionAndOffsets.put(topicPartition, InitialFetchState( + log.topicId, + new BrokerEndPoint(node.id, node.host, node.port), + partition.getLeaderEpoch, + initialFetchOffset(log) + )) + } + + replicaFetcherManager.addFetcherForPartitions(partitionAndOffsets) + stateChangeLogger.info(s"Started fetchers as part of become-follower for ${partitionsToMakeFollower.size} partitions") + + partitionsToMakeFollower.keySet.foreach(completeDelayedFetchOrProduceRequests) + + updateLeaderAndFollowerMetrics(newFollowerTopicSet) + } + + def deleteStrayReplicas(topicPartitions: Iterable[TopicPartition]): Unit = { + stopPartitions(topicPartitions.map(tp => tp -> true).toMap).forKeyValue { (topicPartition, exception) => + exception match { + case e: KafkaStorageException => + stateChangeLogger.error(s"Unable to delete stray replica $topicPartition because " + + s"the local replica for the partition is in an offline log directory: ${e.getMessage}.") + case e: Throwable => + stateChangeLogger.error(s"Unable to delete stray replica $topicPartition because " + + s"we got an unexpected ${e.getClass.getName} exception: ${e.getMessage}", e) + } + } + } +} diff --git a/core/src/main/scala/kafka/server/ReplicationQuotaManager.scala b/core/src/main/scala/kafka/server/ReplicationQuotaManager.scala new file mode 100644 index 0000000..3035cb1 --- /dev/null +++ b/core/src/main/scala/kafka/server/ReplicationQuotaManager.scala @@ -0,0 +1,200 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package kafka.server + +import java.util.concurrent.{ConcurrentHashMap, TimeUnit} +import java.util.concurrent.locks.ReentrantReadWriteLock + +import scala.collection.Seq + +import kafka.server.Constants._ +import kafka.server.ReplicationQuotaManagerConfig._ +import kafka.utils.CoreUtils._ +import kafka.utils.Logging +import org.apache.kafka.common.metrics._ + +import org.apache.kafka.common.TopicPartition +import org.apache.kafka.common.metrics.stats.SimpleRate +import org.apache.kafka.common.utils.Time + +/** + * Configuration settings for quota management + * + * @param quotaBytesPerSecondDefault The default bytes per second quota allocated to internal replication + * @param numQuotaSamples The number of samples to retain in memory + * @param quotaWindowSizeSeconds The time span of each sample + * + */ +case class ReplicationQuotaManagerConfig(quotaBytesPerSecondDefault: Long = QuotaBytesPerSecondDefault, + numQuotaSamples: Int = DefaultNumQuotaSamples, + quotaWindowSizeSeconds: Int = DefaultQuotaWindowSizeSeconds) + +object ReplicationQuotaManagerConfig { + val QuotaBytesPerSecondDefault = Long.MaxValue + // Always have 10 whole windows + 1 current window + val DefaultNumQuotaSamples = 11 + val DefaultQuotaWindowSizeSeconds = 1 + // Purge sensors after 1 hour of inactivity + val InactiveSensorExpirationTimeSeconds = 3600 +} + +trait ReplicaQuota { + def record(value: Long): Unit + def isThrottled(topicPartition: TopicPartition): Boolean + def isQuotaExceeded: Boolean +} + +object Constants { + val AllReplicas = Seq[Int](-1) +} + +/** + * Tracks replication metrics and comparing them to any quotas for throttled partitions. + * + * @param config The quota configs + * @param metrics The Metrics instance + * @param replicationType The name / key for this quota manager, typically leader or follower + * @param time Time object to use + */ +class ReplicationQuotaManager(val config: ReplicationQuotaManagerConfig, + private val metrics: Metrics, + private val replicationType: QuotaType, + private val time: Time) extends Logging with ReplicaQuota { + private val lock = new ReentrantReadWriteLock() + private val throttledPartitions = new ConcurrentHashMap[String, Seq[Int]]() + private var quota: Quota = null + private val sensorAccess = new SensorAccess(lock, metrics) + private val rateMetricName = metrics.metricName("byte-rate", replicationType.toString, + s"Tracking byte-rate for ${replicationType}") + + /** + * Update the quota + * + * @param quota + */ + def updateQuota(quota: Quota): Unit = { + inWriteLock(lock) { + this.quota = quota + //The metric could be expired by another thread, so use a local variable and null check. 
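Stepping back from the update path for a moment: callers combine the two checks exposed by ReplicaQuota, as shouldLeaderThrottle does earlier in this diff. A minimal sketch of that caller-side pattern (the wrapper name is illustrative, not patch content):

import org.apache.kafka.common.TopicPartition

// Replication traffic for a partition is throttled only when the partition is explicitly marked as
// throttled for this quota AND the shared byte-rate quota is currently exceeded; the leader-side
// check additionally requires that the fetching replica is not in the ISR.
def shouldThrottleReplication(quota: ReplicaQuota, tp: TopicPartition): Boolean =
  quota.isThrottled(tp) && quota.isQuotaExceeded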
+ val metric = metrics.metrics.get(rateMetricName) + if (metric != null) { + metric.config(getQuotaMetricConfig(quota)) + } + } + } + + /** + * Check if the quota is currently exceeded + * + * @return + */ + override def isQuotaExceeded: Boolean = { + try { + sensor().checkQuotas() + } catch { + case qve: QuotaViolationException => + trace(s"$replicationType: Quota violated for sensor (${sensor().name}), metric: (${qve.metric.metricName}), " + + s"metric-value: (${qve.value}), bound: (${qve.bound})") + return true + } + false + } + + /** + * Is the passed partition throttled by this ReplicationQuotaManager + * + * @param topicPartition the partition to check + * @return + */ + override def isThrottled(topicPartition: TopicPartition): Boolean = { + val partitions = throttledPartitions.get(topicPartition.topic) + if (partitions != null) + (partitions eq AllReplicas) || partitions.contains(topicPartition.partition) + else false + } + + /** + * Add the passed value to the throttled rate. This method ignores the quota with + * the value being added to the rate even if the quota is exceeded + * + * @param value + */ + def record(value: Long): Unit = { + sensor().record(value.toDouble, time.milliseconds(), false) + } + + /** + * Update the set of throttled partitions for this QuotaManager. The partitions passed, for + * any single topic, will replace any previous + * + * @param topic + * @param partitions the set of throttled partitions + * @return + */ + def markThrottled(topic: String, partitions: Seq[Int]): Unit = { + throttledPartitions.put(topic, partitions) + } + + /** + * Mark all replicas for this topic as throttled + * + * @param topic + * @return + */ + def markThrottled(topic: String): Unit = { + markThrottled(topic, AllReplicas) + } + + /** + * Remove list of throttled replicas for a certain topic + * + * @param topic + * @return + */ + def removeThrottle(topic: String): Unit = { + throttledPartitions.remove(topic) + } + + /** + * Returns the bound of the configured quota + * + * @return + */ + def upperBound: Long = { + inReadLock(lock) { + if (quota != null) + quota.bound.toLong + else + Long.MaxValue + } + } + + private def getQuotaMetricConfig(quota: Quota): MetricConfig = { + new MetricConfig() + .timeWindow(config.quotaWindowSizeSeconds, TimeUnit.SECONDS) + .samples(config.numQuotaSamples) + .quota(quota) + } + + private def sensor(): Sensor = { + sensorAccess.getOrCreate( + replicationType.toString, + InactiveSensorExpirationTimeSeconds, + sensor => sensor.add(rateMetricName, new SimpleRate, getQuotaMetricConfig(quota)) + ) + } +} diff --git a/core/src/main/scala/kafka/server/RequestHandlerHelper.scala b/core/src/main/scala/kafka/server/RequestHandlerHelper.scala new file mode 100644 index 0000000..a1aab61 --- /dev/null +++ b/core/src/main/scala/kafka/server/RequestHandlerHelper.scala @@ -0,0 +1,171 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.server + +import kafka.cluster.Partition +import kafka.coordinator.group.GroupCoordinator +import kafka.coordinator.transaction.TransactionCoordinator +import kafka.network.RequestChannel +import kafka.server.QuotaFactory.QuotaManagers +import org.apache.kafka.common.errors.ClusterAuthorizationException +import org.apache.kafka.common.internals.Topic +import org.apache.kafka.common.network.Send +import org.apache.kafka.common.requests.{AbstractRequest, AbstractResponse} +import org.apache.kafka.common.utils.Time + +object RequestHandlerHelper { + + def onLeadershipChange(groupCoordinator: GroupCoordinator, + txnCoordinator: TransactionCoordinator, + updatedLeaders: Iterable[Partition], + updatedFollowers: Iterable[Partition]): Unit = { + // for each new leader or follower, call coordinator to handle consumer group migration. + // this callback is invoked under the replica state change lock to ensure proper order of + // leadership changes + updatedLeaders.foreach { partition => + if (partition.topic == Topic.GROUP_METADATA_TOPIC_NAME) + groupCoordinator.onElection(partition.partitionId, partition.getLeaderEpoch) + else if (partition.topic == Topic.TRANSACTION_STATE_TOPIC_NAME) + txnCoordinator.onElection(partition.partitionId, partition.getLeaderEpoch) + } + + updatedFollowers.foreach { partition => + if (partition.topic == Topic.GROUP_METADATA_TOPIC_NAME) + groupCoordinator.onResignation(partition.partitionId, Some(partition.getLeaderEpoch)) + else if (partition.topic == Topic.TRANSACTION_STATE_TOPIC_NAME) + txnCoordinator.onResignation(partition.partitionId, Some(partition.getLeaderEpoch)) + } + } + +} + +class RequestHandlerHelper( + requestChannel: RequestChannel, + quotas: QuotaManagers, + time: Time +) { + + def throttle( + quotaManager: ClientQuotaManager, + request: RequestChannel.Request, + throttleTimeMs: Int + ): Unit = { + val callback = new ThrottleCallback { + override def startThrottling(): Unit = requestChannel.startThrottling(request) + override def endThrottling(): Unit = requestChannel.endThrottling(request) + } + quotaManager.throttle(request, callback, throttleTimeMs) + } + + def handleError(request: RequestChannel.Request, e: Throwable): Unit = { + val mayThrottle = e.isInstanceOf[ClusterAuthorizationException] || !request.header.apiKey.clusterAction + if (mayThrottle) + sendErrorResponseMaybeThrottle(request, e) + else + sendErrorResponseExemptThrottle(request, e) + } + + def sendErrorOrCloseConnection( + request: RequestChannel.Request, + error: Throwable, + throttleMs: Int + ): Unit = { + val requestBody = request.body[AbstractRequest] + val response = requestBody.getErrorResponse(throttleMs, error) + if (response == null) + requestChannel.closeConnection(request, requestBody.errorCounts(error)) + else + requestChannel.sendResponse(request, response, None) + } + + def sendForwardedResponse(request: RequestChannel.Request, + response: AbstractResponse): Unit = { + // For forwarded requests, we take the throttle time from the broker that + // the request was forwarded to + val throttleTimeMs = response.throttleTimeMs() + 
throttle(quotas.request, request, throttleTimeMs) + requestChannel.sendResponse(request, response, None) + } + + // Throttle the channel if the request quota is enabled but has been violated. Regardless of throttling, send the + // response immediately. + def sendResponseMaybeThrottle(request: RequestChannel.Request, + createResponse: Int => AbstractResponse): Unit = { + val throttleTimeMs = maybeRecordAndGetThrottleTimeMs(request) + // Only throttle non-forwarded requests + if (!request.isForwarded) + throttle(quotas.request, request, throttleTimeMs) + requestChannel.sendResponse(request, createResponse(throttleTimeMs), None) + } + + def sendErrorResponseMaybeThrottle(request: RequestChannel.Request, error: Throwable): Unit = { + val throttleTimeMs = maybeRecordAndGetThrottleTimeMs(request) + // Only throttle non-forwarded requests or cluster authorization failures + if (error.isInstanceOf[ClusterAuthorizationException] || !request.isForwarded) + throttle(quotas.request, request, throttleTimeMs) + sendErrorOrCloseConnection(request, error, throttleTimeMs) + } + + def maybeRecordAndGetThrottleTimeMs(request: RequestChannel.Request): Int = { + val throttleTimeMs = quotas.request.maybeRecordAndGetThrottleTimeMs(request, time.milliseconds()) + request.apiThrottleTimeMs = throttleTimeMs + throttleTimeMs + } + + /** + * Throttle the channel if the controller mutations quota or the request quota have been violated. + * Regardless of throttling, send the response immediately. + */ + def sendResponseMaybeThrottleWithControllerQuota(controllerMutationQuota: ControllerMutationQuota, + request: RequestChannel.Request, + createResponse: Int => AbstractResponse): Unit = { + val timeMs = time.milliseconds + val controllerThrottleTimeMs = controllerMutationQuota.throttleTime + val requestThrottleTimeMs = quotas.request.maybeRecordAndGetThrottleTimeMs(request, timeMs) + val maxThrottleTimeMs = Math.max(controllerThrottleTimeMs, requestThrottleTimeMs) + // Only throttle non-forwarded requests + if (maxThrottleTimeMs > 0 && !request.isForwarded) { + request.apiThrottleTimeMs = maxThrottleTimeMs + if (controllerThrottleTimeMs > requestThrottleTimeMs) { + throttle(quotas.controllerMutation, request, controllerThrottleTimeMs) + } else { + throttle(quotas.request, request, requestThrottleTimeMs) + } + } + + requestChannel.sendResponse(request, createResponse(maxThrottleTimeMs), None) + } + + def sendResponseExemptThrottle(request: RequestChannel.Request, + response: AbstractResponse, + onComplete: Option[Send => Unit] = None): Unit = { + quotas.request.maybeRecordExempt(request) + requestChannel.sendResponse(request, response, onComplete) + } + + def sendErrorResponseExemptThrottle(request: RequestChannel.Request, error: Throwable): Unit = { + quotas.request.maybeRecordExempt(request) + sendErrorOrCloseConnection(request, error, 0) + } + + def sendNoOpResponseExemptThrottle(request: RequestChannel.Request): Unit = { + quotas.request.maybeRecordExempt(request) + requestChannel.sendNoOpResponse(request) + } + +} diff --git a/core/src/main/scala/kafka/server/RequestLocal.scala b/core/src/main/scala/kafka/server/RequestLocal.scala new file mode 100644 index 0000000..5af495f --- /dev/null +++ b/core/src/main/scala/kafka/server/RequestLocal.scala @@ -0,0 +1,37 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.server + +import org.apache.kafka.common.utils.BufferSupplier + +object RequestLocal { + val NoCaching: RequestLocal = RequestLocal(BufferSupplier.NO_CACHING) + + /** The returned instance should be confined to a single thread. */ + def withThreadConfinedCaching: RequestLocal = RequestLocal(BufferSupplier.create()) +} + +/** + * Container for stateful instances where the lifecycle is scoped to one request. + * + * When each request is handled by one thread, efficient data structures with no locking or atomic operations + * can be used (see RequestLocal.withThreadConfinedCaching). + */ +case class RequestLocal(bufferSupplier: BufferSupplier) { + def close(): Unit = bufferSupplier.close() +} diff --git a/core/src/main/scala/kafka/server/SensorAccess.scala b/core/src/main/scala/kafka/server/SensorAccess.scala new file mode 100644 index 0000000..3a063f2 --- /dev/null +++ b/core/src/main/scala/kafka/server/SensorAccess.scala @@ -0,0 +1,71 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package kafka.server + +import java.util.concurrent.locks.ReadWriteLock + +import org.apache.kafka.common.metrics.{Metrics, Sensor} + +/** + * Class which centralises the logic for creating/accessing sensors. + * The quota can be updated by wrapping it in the passed MetricConfig + * + * The later arguments are passed as methods as they are only called when the sensor is instantiated. + */ +class SensorAccess(lock: ReadWriteLock, metrics: Metrics) { + + def getOrCreate(sensorName: String, expirationTime: Long, registerMetrics: Sensor => Unit): Sensor = { + var sensor: Sensor = null + + /* Acquire the read lock to fetch the sensor. It is safe to call getSensor from multiple threads. + * The read lock allows a thread to create a sensor in isolation. The thread creating the sensor + * will acquire the write lock and prevent the sensors from being read while they are being created. + * It should be sufficient to simply check if the sensor is null without acquiring a read lock but the + * sensor being present doesn't mean that it is fully initialized i.e. all the Metrics may not have been added. 
+ * This read lock waits until the writer thread has released its lock i.e. fully initialized the sensor + * at which point it is safe to read + */ + lock.readLock().lock() + try sensor = metrics.getSensor(sensorName) + finally lock.readLock().unlock() + + /* If the sensor is null, try to create it else return the existing sensor + * The sensor can be null, hence the null checks + */ + if (sensor == null) { + /* Acquire a write lock because the sensor may not have been created and we only want one thread to create it. + * Note that multiple threads may acquire the write lock if they all see a null sensor initially + * In this case, the writer checks the sensor after acquiring the lock again. + * This is safe from Double Checked Locking because the references are read + * after acquiring read locks and hence they cannot see a partially published reference + */ + lock.writeLock().lock() + try { + // Set the var for both sensors in case another thread has won the race to acquire the write lock. This will + // ensure that we initialise `ClientSensors` with non-null parameters. + sensor = metrics.getSensor(sensorName) + if (sensor == null) { + sensor = metrics.sensor(sensorName, null, expirationTime) + registerMetrics(sensor) + } + } finally { + lock.writeLock().unlock() + } + } + sensor + } +} diff --git a/core/src/main/scala/kafka/server/Server.scala b/core/src/main/scala/kafka/server/Server.scala new file mode 100644 index 0000000..c395df4 --- /dev/null +++ b/core/src/main/scala/kafka/server/Server.scala @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package kafka.server + +import java.util.Collections +import java.util.concurrent.TimeUnit + +import org.apache.kafka.clients.CommonClientConfigs +import org.apache.kafka.common.metrics.{JmxReporter, KafkaMetricsContext, MetricConfig, Metrics, MetricsReporter, Sensor} +import org.apache.kafka.common.utils.Time +import org.apache.kafka.metadata.VersionRange + +import scala.jdk.CollectionConverters._ + +trait Server { + def startup(): Unit + def shutdown(): Unit + def awaitShutdown(): Unit +} + +object Server { + val MetricsPrefix: String = "kafka.server" + val ClusterIdLabel: String = "kafka.cluster.id" + val BrokerIdLabel: String = "kafka.broker.id" + val NodeIdLabel: String = "kafka.node.id" + + def initializeMetrics( + config: KafkaConfig, + time: Time, + clusterId: String + ): Metrics = { + val metricsContext = createKafkaMetricsContext(config, clusterId) + buildMetrics(config, time, metricsContext) + } + + private def buildMetrics( + config: KafkaConfig, + time: Time, + metricsContext: KafkaMetricsContext + ): Metrics = { + val defaultReporters = initializeDefaultReporters(config) + val metricConfig = buildMetricsConfig(config) + new Metrics(metricConfig, defaultReporters, time, true, metricsContext) + } + + def buildMetricsConfig( + kafkaConfig: KafkaConfig + ): MetricConfig = { + new MetricConfig() + .samples(kafkaConfig.metricNumSamples) + .recordLevel(Sensor.RecordingLevel.forName(kafkaConfig.metricRecordingLevel)) + .timeWindow(kafkaConfig.metricSampleWindowMs, TimeUnit.MILLISECONDS) + } + + private[server] def createKafkaMetricsContext( + config: KafkaConfig, + clusterId: String + ): KafkaMetricsContext = { + val contextLabels = new java.util.HashMap[String, Object] + contextLabels.put(ClusterIdLabel, clusterId) + + if (config.usesSelfManagedQuorum) { + contextLabels.put(NodeIdLabel, config.nodeId.toString) + } else { + contextLabels.put(BrokerIdLabel, config.brokerId.toString) + } + + contextLabels.putAll(config.originalsWithPrefix(CommonClientConfigs.METRICS_CONTEXT_PREFIX)) + new KafkaMetricsContext(MetricsPrefix, contextLabels) + } + + private def initializeDefaultReporters( + config: KafkaConfig + ): java.util.List[MetricsReporter] = { + val jmxReporter = new JmxReporter() + jmxReporter.configure(config.originals) + + val reporters = new java.util.ArrayList[MetricsReporter] + reporters.add(jmxReporter) + reporters + } + + sealed trait ProcessStatus + case object SHUTDOWN extends ProcessStatus + case object STARTING extends ProcessStatus + case object STARTED extends ProcessStatus + case object SHUTTING_DOWN extends ProcessStatus + + val SUPPORTED_FEATURES = Collections. + unmodifiableMap[String, VersionRange](Map[String, VersionRange]().asJava) +} diff --git a/core/src/main/scala/kafka/server/ThrottledChannel.scala b/core/src/main/scala/kafka/server/ThrottledChannel.scala new file mode 100644 index 0000000..8091678 --- /dev/null +++ b/core/src/main/scala/kafka/server/ThrottledChannel.scala @@ -0,0 +1,61 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.server + +import java.util.concurrent.{Delayed, TimeUnit} + +import kafka.utils.Logging +import org.apache.kafka.common.utils.Time + +trait ThrottleCallback { + def startThrottling(): Unit + def endThrottling(): Unit +} + +/** + * Represents a request whose response has been delayed. + * @param time Time instance to use + * @param throttleTimeMs Delay associated with this request + * @param callback Callback for channel throttling + */ +class ThrottledChannel( + val time: Time, + val throttleTimeMs: Int, + val callback: ThrottleCallback +) extends Delayed with Logging { + + private val endTimeNanos = time.nanoseconds() + TimeUnit.MILLISECONDS.toNanos(throttleTimeMs) + + // Notify the socket server that throttling has started for this channel. + callback.startThrottling() + + // Notify the socket server that throttling has been done for this channel. + def notifyThrottlingDone(): Unit = { + trace(s"Channel throttled for: $throttleTimeMs ms") + callback.endThrottling() + } + + override def getDelay(unit: TimeUnit): Long = { + unit.convert(endTimeNanos - time.nanoseconds(), TimeUnit.NANOSECONDS) + } + + override def compareTo(d: Delayed): Int = { + val other = d.asInstanceOf[ThrottledChannel] + java.lang.Long.compare(this.endTimeNanos, other.endTimeNanos) + } +} diff --git a/core/src/main/scala/kafka/server/ZkAdminManager.scala b/core/src/main/scala/kafka/server/ZkAdminManager.scala new file mode 100644 index 0000000..d2e7456 --- /dev/null +++ b/core/src/main/scala/kafka/server/ZkAdminManager.scala @@ -0,0 +1,1149 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package kafka.server + +import java.util +import java.util.Properties + +import kafka.admin.{AdminOperationException, AdminUtils} +import kafka.common.TopicAlreadyMarkedForDeletionException +import kafka.log.LogConfig +import kafka.utils.Log4jController +import kafka.metrics.KafkaMetricsGroup +import kafka.server.DynamicConfig.QuotaConfigs +import kafka.server.metadata.ZkConfigRepository +import kafka.utils._ +import kafka.utils.Implicits._ +import kafka.zk.{AdminZkClient, KafkaZkClient} +import org.apache.kafka.clients.admin.{AlterConfigOp, ScramMechanism} +import org.apache.kafka.clients.admin.AlterConfigOp.OpType +import org.apache.kafka.common.Uuid +import org.apache.kafka.common.config.ConfigDef.ConfigKey +import org.apache.kafka.common.config.{ConfigDef, ConfigException, ConfigResource, LogLevelConfig} +import org.apache.kafka.common.errors.ThrottlingQuotaExceededException +import org.apache.kafka.common.errors.{ApiException, InvalidConfigurationException, InvalidPartitionsException, InvalidReplicaAssignmentException, InvalidRequestException, ReassignmentInProgressException, TopicExistsException, UnknownTopicOrPartitionException, UnsupportedVersionException} +import org.apache.kafka.common.message.AlterUserScramCredentialsResponseData.AlterUserScramCredentialsResult +import org.apache.kafka.common.message.CreatePartitionsRequestData.CreatePartitionsTopic +import org.apache.kafka.common.message.CreateTopicsRequestData.CreatableTopic +import org.apache.kafka.common.message.CreateTopicsResponseData.{CreatableTopicConfigs, CreatableTopicResult} +import org.apache.kafka.common.message.{AlterUserScramCredentialsRequestData, AlterUserScramCredentialsResponseData, DescribeUserScramCredentialsResponseData} +import org.apache.kafka.common.message.DescribeUserScramCredentialsResponseData.CredentialInfo +import org.apache.kafka.common.metrics.Metrics +import org.apache.kafka.common.security.scram.internals.{ScramMechanism => InternalScramMechanism} +import org.apache.kafka.server.policy.{AlterConfigPolicy, CreateTopicPolicy} +import org.apache.kafka.server.policy.CreateTopicPolicy.RequestMetadata +import org.apache.kafka.common.protocol.Errors +import org.apache.kafka.common.quota.{ClientQuotaAlteration, ClientQuotaEntity, ClientQuotaFilter, ClientQuotaFilterComponent} +import org.apache.kafka.common.requests.CreateTopicsRequest._ +import org.apache.kafka.common.requests.{AlterConfigsRequest, ApiError} +import org.apache.kafka.common.security.scram.internals.{ScramCredentialUtils, ScramFormatter} +import org.apache.kafka.common.utils.Sanitizer + +import scala.collection.{Map, mutable, _} +import scala.jdk.CollectionConverters._ + +class ZkAdminManager(val config: KafkaConfig, + val metrics: Metrics, + val metadataCache: MetadataCache, + val zkClient: KafkaZkClient) extends Logging with KafkaMetricsGroup { + + this.logIdent = "[Admin Manager on Broker " + config.brokerId + "]: " + + private val topicPurgatory = DelayedOperationPurgatory[DelayedOperation]("topic", config.brokerId) + private val adminZkClient = new AdminZkClient(zkClient) + private val configHelper = new ConfigHelper(metadataCache, config, new ZkConfigRepository(adminZkClient)) + + private val createTopicPolicy = + Option(config.getConfiguredInstance(KafkaConfig.CreateTopicPolicyClassNameProp, classOf[CreateTopicPolicy])) + + private val alterConfigPolicy = + Option(config.getConfiguredInstance(KafkaConfig.AlterConfigPolicyClassNameProp, classOf[AlterConfigPolicy])) + + def hasDelayedTopicOperations = 
topicPurgatory.numDelayed != 0 + + private val defaultNumPartitions = config.numPartitions.intValue() + private val defaultReplicationFactor = config.defaultReplicationFactor.shortValue() + + /** + * Try to complete delayed topic operations with the request key + */ + def tryCompleteDelayedTopicOperations(topic: String): Unit = { + val key = TopicKey(topic) + val completed = topicPurgatory.checkAndComplete(key) + debug(s"Request key ${key.keyLabel} unblocked $completed topic requests.") + } + + private def validateTopicCreatePolicy(topic: CreatableTopic, + resolvedNumPartitions: Int, + resolvedReplicationFactor: Short, + assignments: Map[Int, Seq[Int]]): Unit = { + createTopicPolicy.foreach { policy => + // Use `null` for unset fields in the public API + val numPartitions: java.lang.Integer = + if (topic.assignments().isEmpty) resolvedNumPartitions else null + val replicationFactor: java.lang.Short = + if (topic.assignments().isEmpty) resolvedReplicationFactor else null + val javaAssignments = if (topic.assignments().isEmpty) { + null + } else { + assignments.map { case (k, v) => + (k: java.lang.Integer) -> v.map(i => i: java.lang.Integer).asJava + }.asJava + } + val javaConfigs = new java.util.HashMap[String, String] + topic.configs.forEach(config => javaConfigs.put(config.name, config.value)) + policy.validate(new RequestMetadata(topic.name, numPartitions, replicationFactor, + javaAssignments, javaConfigs)) + } + } + + private def maybePopulateMetadataAndConfigs(metadataAndConfigs: Map[String, CreatableTopicResult], + topicName: String, + configs: Properties, + assignments: Map[Int, Seq[Int]]): Unit = { + metadataAndConfigs.get(topicName).foreach { result => + val logConfig = LogConfig.fromProps(LogConfig.extractLogConfigMap(config), configs) + val createEntry = configHelper.createTopicConfigEntry(logConfig, configs, includeSynonyms = false, includeDocumentation = false)(_, _) + val topicConfigs = configHelper.allConfigs(logConfig).map { case (k, v) => + val entry = createEntry(k, v) + new CreatableTopicConfigs() + .setName(k) + .setValue(entry.value) + .setIsSensitive(entry.isSensitive) + .setReadOnly(entry.readOnly) + .setConfigSource(entry.configSource) + }.toList.asJava + result.setConfigs(topicConfigs) + result.setNumPartitions(assignments.size) + result.setReplicationFactor(assignments(0).size.toShort) + } + } + + private def populateIds(metadataAndConfigs: Map[String, CreatableTopicResult], + topicName: String) : Unit = { + metadataAndConfigs.get(topicName).foreach { result => + result.setTopicId(zkClient.getTopicIdsForTopics(Predef.Set(result.name())).getOrElse(result.name(), Uuid.ZERO_UUID)) + } + } + + /** + * Create topics and wait until the topics have been completely created. + * The callback function will be triggered either when timeout, error or the topics are created. + */ + def createTopics(timeout: Int, + validateOnly: Boolean, + toCreate: Map[String, CreatableTopic], + includeConfigsAndMetadata: Map[String, CreatableTopicResult], + controllerMutationQuota: ControllerMutationQuota, + responseCallback: Map[String, ApiError] => Unit): Unit = { + + // 1. 
map over topics creating assignment and calling zookeeper + val brokers = metadataCache.getAliveBrokers() + val metadata = toCreate.values.map(topic => + try { + if (metadataCache.contains(topic.name)) + throw new TopicExistsException(s"Topic '${topic.name}' already exists.") + + val nullConfigs = topic.configs.asScala.filter(_.value == null).map(_.name) + if (nullConfigs.nonEmpty) + throw new InvalidRequestException(s"Null value not supported for topic configs : ${nullConfigs.mkString(",")}") + + if ((topic.numPartitions != NO_NUM_PARTITIONS || topic.replicationFactor != NO_REPLICATION_FACTOR) + && !topic.assignments().isEmpty) { + throw new InvalidRequestException("Both numPartitions or replicationFactor and replicasAssignments were set. " + + "Both cannot be used at the same time.") + } + + val resolvedNumPartitions = if (topic.numPartitions == NO_NUM_PARTITIONS) + defaultNumPartitions else topic.numPartitions + val resolvedReplicationFactor = if (topic.replicationFactor == NO_REPLICATION_FACTOR) + defaultReplicationFactor else topic.replicationFactor + + val assignments = if (topic.assignments.isEmpty) { + AdminUtils.assignReplicasToBrokers( + brokers, resolvedNumPartitions, resolvedReplicationFactor) + } else { + val assignments = new mutable.HashMap[Int, Seq[Int]] + // Note: we don't check that replicaAssignment contains unknown brokers - unlike in add-partitions case, + // this follows the existing logic in TopicCommand + topic.assignments.forEach { assignment => + assignments(assignment.partitionIndex) = assignment.brokerIds.asScala.map(a => a: Int) + } + assignments + } + trace(s"Assignments for topic $topic are $assignments ") + + val configs = new Properties() + topic.configs.forEach(entry => configs.setProperty(entry.name, entry.value)) + adminZkClient.validateTopicCreate(topic.name, assignments, configs) + validateTopicCreatePolicy(topic, resolvedNumPartitions, resolvedReplicationFactor, assignments) + + // For responses with DescribeConfigs permission, populate metadata and configs. It is + // safe to populate it before creating the topic because the values are unset if the + // creation fails. + maybePopulateMetadataAndConfigs(includeConfigsAndMetadata, topic.name, configs, assignments) + + if (validateOnly) { + CreatePartitionsMetadata(topic.name, assignments.keySet) + } else { + controllerMutationQuota.record(assignments.size) + adminZkClient.createTopicWithAssignment(topic.name, configs, assignments, validate = false, config.usesTopicId) + populateIds(includeConfigsAndMetadata, topic.name) + CreatePartitionsMetadata(topic.name, assignments.keySet) + } + } catch { + // Log client errors at a lower level than unexpected exceptions + case e: TopicExistsException => + debug(s"Topic creation failed since topic '${topic.name}' already exists.", e) + CreatePartitionsMetadata(topic.name, e) + case e: ThrottlingQuotaExceededException => + debug(s"Topic creation not allowed because quota is violated. Delay time: ${e.throttleTimeMs}") + CreatePartitionsMetadata(topic.name, e) + case e: ApiException => + info(s"Error processing create topic request $topic", e) + CreatePartitionsMetadata(topic.name, e) + case e: ConfigException => + info(s"Error processing create topic request $topic", e) + CreatePartitionsMetadata(topic.name, new InvalidConfigurationException(e.getMessage, e.getCause)) + case e: Throwable => + error(s"Error processing create topic request $topic", e) + CreatePartitionsMetadata(topic.name, e) + }).toBuffer + + // 2. 
if timeout <= 0, validateOnly or no topics can proceed return immediately + if (timeout <= 0 || validateOnly || !metadata.exists(_.error.is(Errors.NONE))) { + val results = metadata.map { createTopicMetadata => + // ignore topics that already have errors + if (createTopicMetadata.error.isSuccess && !validateOnly) { + (createTopicMetadata.topic, new ApiError(Errors.REQUEST_TIMED_OUT, null)) + } else { + (createTopicMetadata.topic, createTopicMetadata.error) + } + }.toMap + responseCallback(results) + } else { + // 3. else pass the assignments and errors to the delayed operation and set the keys + val delayedCreate = new DelayedCreatePartitions(timeout, metadata, this, + responseCallback) + val delayedCreateKeys = toCreate.values.map(topic => TopicKey(topic.name)).toBuffer + // try to complete the request immediately, otherwise put it into the purgatory + topicPurgatory.tryCompleteElseWatch(delayedCreate, delayedCreateKeys) + } + } + + /** + * Delete topics and wait until the topics have been completely deleted. + * The callback function will be triggered either when timeout, error or the topics are deleted. + */ + def deleteTopics(timeout: Int, + topics: Set[String], + controllerMutationQuota: ControllerMutationQuota, + responseCallback: Map[String, Errors] => Unit): Unit = { + // 1. map over topics calling the asynchronous delete + val metadata = topics.map { topic => + try { + controllerMutationQuota.record(metadataCache.numPartitions(topic).getOrElse(0).toDouble) + adminZkClient.deleteTopic(topic) + DeleteTopicMetadata(topic, Errors.NONE) + } catch { + case _: TopicAlreadyMarkedForDeletionException => + // swallow the exception, and still track deletion allowing multiple calls to wait for deletion + DeleteTopicMetadata(topic, Errors.NONE) + case e: ThrottlingQuotaExceededException => + debug(s"Topic deletion not allowed because quota is violated. Delay time: ${e.throttleTimeMs}") + DeleteTopicMetadata(topic, e) + case e: Throwable => + error(s"Error processing delete topic request for topic $topic", e) + DeleteTopicMetadata(topic, e) + } + } + + // 2. if timeout <= 0 or no topics can proceed return immediately + if (timeout <= 0 || !metadata.exists(_.error == Errors.NONE)) { + val results = metadata.map { deleteTopicMetadata => + // ignore topics that already have errors + if (deleteTopicMetadata.error == Errors.NONE) { + (deleteTopicMetadata.topic, Errors.REQUEST_TIMED_OUT) + } else { + (deleteTopicMetadata.topic, deleteTopicMetadata.error) + } + }.toMap + responseCallback(results) + } else { + // 3. else pass the topics and errors to the delayed operation and set the keys + val delayedDelete = new DelayedDeleteTopics(timeout, metadata.toSeq, this, responseCallback) + val delayedDeleteKeys = topics.map(TopicKey).toSeq + // try to complete the request immediately, otherwise put it into the purgatory + topicPurgatory.tryCompleteElseWatch(delayedDelete, delayedDeleteKeys) + } + } + + def createPartitions(timeout: Int, + newPartitions: Seq[CreatePartitionsTopic], + validateOnly: Boolean, + controllerMutationQuota: ControllerMutationQuota, + callback: Map[String, ApiError] => Unit): Unit = { + val allBrokers = adminZkClient.getBrokerMetadatas() + val allBrokerIds = allBrokers.map(_.id) + + // 1. 
map over topics creating assignment and calling AdminUtils + val metadata = newPartitions.map { newPartition => + val topic = newPartition.name + + try { + val existingAssignment = zkClient.getFullReplicaAssignmentForTopics(immutable.Set(topic)).map { + case (topicPartition, assignment) => + if (assignment.isBeingReassigned) { + // We prevent adding partitions while topic reassignment is in progress, to protect from a race condition + // between the controller thread processing reassignment update and createPartitions(this) request. + throw new ReassignmentInProgressException(s"A partition reassignment is in progress for the topic '$topic'.") + } + topicPartition.partition -> assignment + } + if (existingAssignment.isEmpty) + throw new UnknownTopicOrPartitionException(s"The topic '$topic' does not exist.") + + val oldNumPartitions = existingAssignment.size + val newNumPartitions = newPartition.count + val numPartitionsIncrement = newNumPartitions - oldNumPartitions + if (numPartitionsIncrement < 0) { + throw new InvalidPartitionsException( + s"Topic currently has $oldNumPartitions partitions, which is higher than the requested $newNumPartitions.") + } else if (numPartitionsIncrement == 0) { + throw new InvalidPartitionsException(s"Topic already has $oldNumPartitions partitions.") + } + + val newPartitionsAssignment = Option(newPartition.assignments).map { assignmentMap => + val assignments = assignmentMap.asScala.map { + createPartitionAssignment => createPartitionAssignment.brokerIds.asScala.map(_.toInt) + } + val unknownBrokers = assignments.flatten.toSet -- allBrokerIds + if (unknownBrokers.nonEmpty) + throw new InvalidReplicaAssignmentException( + s"Unknown broker(s) in replica assignment: ${unknownBrokers.mkString(", ")}.") + + if (assignments.size != numPartitionsIncrement) + throw new InvalidReplicaAssignmentException( + s"Increasing the number of partitions by $numPartitionsIncrement " + + s"but ${assignments.size} assignments provided.") + + assignments.zipWithIndex.map { case (replicas, index) => + existingAssignment.size + index -> replicas + }.toMap + } + + val assignmentForNewPartitions = adminZkClient.createNewPartitionsAssignment( + topic, existingAssignment, allBrokers, newPartition.count, newPartitionsAssignment) + + if (validateOnly) { + CreatePartitionsMetadata(topic, (existingAssignment ++ assignmentForNewPartitions).keySet) + } else { + controllerMutationQuota.record(numPartitionsIncrement) + val updatedReplicaAssignment = adminZkClient.createPartitionsWithAssignment( + topic, existingAssignment, assignmentForNewPartitions) + CreatePartitionsMetadata(topic, updatedReplicaAssignment.keySet) + } + } catch { + case e: AdminOperationException => + CreatePartitionsMetadata(topic, e) + case e: ThrottlingQuotaExceededException => + debug(s"Partition(s) creation not allowed because quota is violated. Delay time: ${e.throttleTimeMs}") + CreatePartitionsMetadata(topic, e) + case e: ApiException => + CreatePartitionsMetadata(topic, e) + } + } + + // 2. if timeout <= 0, validateOnly or no topics can proceed return immediately + if (timeout <= 0 || validateOnly || !metadata.exists(_.error.is(Errors.NONE))) { + val results = metadata.map { createPartitionMetadata => + // ignore topics that already have errors + if (createPartitionMetadata.error.isSuccess && !validateOnly) { + (createPartitionMetadata.topic, new ApiError(Errors.REQUEST_TIMED_OUT, null)) + } else { + (createPartitionMetadata.topic, createPartitionMetadata.error) + } + }.toMap + callback(results) + } else { + // 3. 
else pass the assignments and errors to the delayed operation and set the keys + val delayedCreate = new DelayedCreatePartitions(timeout, metadata, this, callback) + val delayedCreateKeys = newPartitions.map(createPartitionTopic => TopicKey(createPartitionTopic.name)) + // try to complete the request immediately, otherwise put it into the purgatory + topicPurgatory.tryCompleteElseWatch(delayedCreate, delayedCreateKeys) + } + } + + def alterConfigs(configs: Map[ConfigResource, AlterConfigsRequest.Config], validateOnly: Boolean): Map[ConfigResource, ApiError] = { + configs.map { case (resource, config) => + + try { + val nullUpdates = config.entries.asScala.filter(_.value == null).map(_.name) + if (nullUpdates.nonEmpty) + throw new InvalidRequestException(s"Null value not supported for : ${nullUpdates.mkString(",")}") + + val configEntriesMap = config.entries.asScala.map(entry => (entry.name, entry.value)).toMap + + val configProps = new Properties + config.entries.asScala.filter(_.value != null).foreach { configEntry => + configProps.setProperty(configEntry.name, configEntry.value) + } + + resource.`type` match { + case ConfigResource.Type.TOPIC => alterTopicConfigs(resource, validateOnly, configProps, configEntriesMap) + case ConfigResource.Type.BROKER => alterBrokerConfigs(resource, validateOnly, configProps, configEntriesMap) + case resourceType => + throw new InvalidRequestException(s"AlterConfigs is only supported for topics and brokers, but resource type is $resourceType") + } + } catch { + case e @ (_: ConfigException | _: IllegalArgumentException) => + val message = s"Invalid config value for resource $resource: ${e.getMessage}" + info(message) + resource -> ApiError.fromThrowable(new InvalidRequestException(message, e)) + case e: Throwable => + val configProps = new Properties + config.entries.asScala.filter(_.value != null).foreach { configEntry => + configProps.setProperty(configEntry.name, configEntry.value) + } + // Log client errors at a lower level than unexpected exceptions + val message = s"Error processing alter configs request for resource $resource, config ${toLoggableProps(resource, configProps).mkString(",")}" + if (e.isInstanceOf[ApiException]) + info(message, e) + else + error(message, e) + resource -> ApiError.fromThrowable(e) + } + }.toMap + } + + private def alterTopicConfigs(resource: ConfigResource, validateOnly: Boolean, + configProps: Properties, configEntriesMap: Map[String, String]): (ConfigResource, ApiError) = { + val topic = resource.name + if (!metadataCache.contains(topic)) + throw new UnknownTopicOrPartitionException(s"The topic '$topic' does not exist.") + + adminZkClient.validateTopicConfig(topic, configProps) + validateConfigPolicy(resource, configEntriesMap) + if (!validateOnly) { + info(s"Updating topic $topic with new configuration : ${toLoggableProps(resource, configProps).mkString(",")}") + adminZkClient.changeTopicConfig(topic, configProps) + } + + resource -> ApiError.NONE + } + + private def alterBrokerConfigs(resource: ConfigResource, validateOnly: Boolean, + configProps: Properties, configEntriesMap: Map[String, String]): (ConfigResource, ApiError) = { + val brokerId = getBrokerId(resource) + val perBrokerConfig = brokerId.nonEmpty + this.config.dynamicConfig.validate(configProps, perBrokerConfig) + validateConfigPolicy(resource, configEntriesMap) + if (!validateOnly) { + if (perBrokerConfig) + this.config.dynamicConfig.reloadUpdatedFilesWithoutConfigChange(configProps) + + if (perBrokerConfig) + info(s"Updating broker ${brokerId.get} with 
new configuration : ${toLoggableProps(resource, configProps).mkString(",")}") + else + info(s"Updating brokers with new configuration : ${toLoggableProps(resource, configProps).mkString(",")}") + + adminZkClient.changeBrokerConfig(brokerId, + this.config.dynamicConfig.toPersistentProps(configProps, perBrokerConfig)) + } + + resource -> ApiError.NONE + } + + private def toLoggableProps(resource: ConfigResource, configProps: Properties): Map[String, String] = { + configProps.asScala.map { + case (key, value) => (key, KafkaConfig.loggableValue(resource.`type`, key, value)) + } + } + + private def alterLogLevelConfigs(alterConfigOps: Seq[AlterConfigOp]): Unit = { + alterConfigOps.foreach { alterConfigOp => + val loggerName = alterConfigOp.configEntry().name() + val logLevel = alterConfigOp.configEntry().value() + alterConfigOp.opType() match { + case OpType.SET => + info(s"Updating the log level of $loggerName to $logLevel") + Log4jController.logLevel(loggerName, logLevel) + case OpType.DELETE => + info(s"Unset the log level of $loggerName") + Log4jController.unsetLogLevel(loggerName) + case _ => throw new IllegalArgumentException( + s"Log level cannot be changed for OpType: ${alterConfigOp.opType()}") + } + } + } + + private def getBrokerId(resource: ConfigResource) = { + if (resource.name == null || resource.name.isEmpty) + None + else { + val id = resourceNameToBrokerId(resource.name) + if (id != this.config.brokerId) + throw new InvalidRequestException(s"Unexpected broker id, expected ${this.config.brokerId}, but received ${resource.name}") + Some(id) + } + } + + private def validateConfigPolicy(resource: ConfigResource, configEntriesMap: Map[String, String]): Unit = { + alterConfigPolicy match { + case Some(policy) => + policy.validate(new AlterConfigPolicy.RequestMetadata( + new ConfigResource(resource.`type`(), resource.name), configEntriesMap.asJava)) + case None => + } + } + + def incrementalAlterConfigs(configs: Map[ConfigResource, Seq[AlterConfigOp]], validateOnly: Boolean): Map[ConfigResource, ApiError] = { + configs.map { case (resource, alterConfigOps) => + try { + // throw InvalidRequestException if any duplicate keys + val duplicateKeys = alterConfigOps.groupBy(config => config.configEntry.name).filter { case (_, v) => + v.size > 1 + }.keySet + if (duplicateKeys.nonEmpty) + throw new InvalidRequestException(s"Error due to duplicate config keys : ${duplicateKeys.mkString(",")}") + val nullUpdates = alterConfigOps + .filter(entry => entry.configEntry.value == null && entry.opType() != OpType.DELETE) + .map(entry => s"${entry.opType}:${entry.configEntry.name}") + if (nullUpdates.nonEmpty) + throw new InvalidRequestException(s"Null value not supported for : ${nullUpdates.mkString(",")}") + + val configEntriesMap = alterConfigOps.map(entry => (entry.configEntry.name, entry.configEntry.value)).toMap + + resource.`type` match { + case ConfigResource.Type.TOPIC => + val configProps = adminZkClient.fetchEntityConfig(ConfigType.Topic, resource.name) + prepareIncrementalConfigs(alterConfigOps, configProps, LogConfig.configKeys) + alterTopicConfigs(resource, validateOnly, configProps, configEntriesMap) + + case ConfigResource.Type.BROKER => + val brokerId = getBrokerId(resource) + val perBrokerConfig = brokerId.nonEmpty + + val persistentProps = if (perBrokerConfig) adminZkClient.fetchEntityConfig(ConfigType.Broker, brokerId.get.toString) + else adminZkClient.fetchEntityConfig(ConfigType.Broker, ConfigEntityName.Default) + + val configProps = 
this.config.dynamicConfig.fromPersistentProps(persistentProps, perBrokerConfig) + prepareIncrementalConfigs(alterConfigOps, configProps, KafkaConfig.configKeys) + alterBrokerConfigs(resource, validateOnly, configProps, configEntriesMap) + + case ConfigResource.Type.BROKER_LOGGER => + getBrokerId(resource) + validateLogLevelConfigs(alterConfigOps) + + if (!validateOnly) + alterLogLevelConfigs(alterConfigOps) + resource -> ApiError.NONE + case resourceType => + throw new InvalidRequestException(s"AlterConfigs is only supported for topics and brokers, but resource type is $resourceType") + } + } catch { + case e @ (_: ConfigException | _: IllegalArgumentException) => + val message = s"Invalid config value for resource $resource: ${e.getMessage}" + info(message) + resource -> ApiError.fromThrowable(new InvalidRequestException(message, e)) + case e: Throwable => + // Log client errors at a lower level than unexpected exceptions + val message = s"Error processing alter configs request for resource $resource, config $alterConfigOps" + if (e.isInstanceOf[ApiException]) + info(message, e) + else + error(message, e) + resource -> ApiError.fromThrowable(e) + } + }.toMap + } + + private def validateLogLevelConfigs(alterConfigOps: Seq[AlterConfigOp]): Unit = { + def validateLoggerNameExists(loggerName: String): Unit = { + if (!Log4jController.loggerExists(loggerName)) + throw new ConfigException(s"Logger $loggerName does not exist!") + } + + alterConfigOps.foreach { alterConfigOp => + val loggerName = alterConfigOp.configEntry.name + alterConfigOp.opType() match { + case OpType.SET => + validateLoggerNameExists(loggerName) + val logLevel = alterConfigOp.configEntry.value + if (!LogLevelConfig.VALID_LOG_LEVELS.contains(logLevel)) { + val validLevelsStr = LogLevelConfig.VALID_LOG_LEVELS.asScala.mkString(", ") + throw new ConfigException( + s"Cannot set the log level of $loggerName to $logLevel as it is not a supported log level. 
" + + s"Valid log levels are $validLevelsStr" + ) + } + case OpType.DELETE => + validateLoggerNameExists(loggerName) + if (loggerName == Log4jController.ROOT_LOGGER) + throw new InvalidRequestException(s"Removing the log level of the ${Log4jController.ROOT_LOGGER} logger is not allowed") + case OpType.APPEND => throw new InvalidRequestException(s"${OpType.APPEND} operation is not allowed for the ${ConfigResource.Type.BROKER_LOGGER} resource") + case OpType.SUBTRACT => throw new InvalidRequestException(s"${OpType.SUBTRACT} operation is not allowed for the ${ConfigResource.Type.BROKER_LOGGER} resource") + } + } + } + + private def prepareIncrementalConfigs(alterConfigOps: Seq[AlterConfigOp], configProps: Properties, configKeys: Map[String, ConfigKey]): Unit = { + + def listType(configName: String, configKeys: Map[String, ConfigKey]): Boolean = { + val configKey = configKeys(configName) + if (configKey == null) + throw new InvalidConfigurationException(s"Unknown topic config name: $configName") + configKey.`type` == ConfigDef.Type.LIST + } + + alterConfigOps.foreach { alterConfigOp => + val configPropName = alterConfigOp.configEntry.name + alterConfigOp.opType() match { + case OpType.SET => configProps.setProperty(alterConfigOp.configEntry.name, alterConfigOp.configEntry.value) + case OpType.DELETE => configProps.remove(alterConfigOp.configEntry.name) + case OpType.APPEND => { + if (!listType(alterConfigOp.configEntry.name, configKeys)) + throw new InvalidRequestException(s"Config value append is not allowed for config key: ${alterConfigOp.configEntry.name}") + val oldValueList = Option(configProps.getProperty(alterConfigOp.configEntry.name)) + .orElse(Option(ConfigDef.convertToString(configKeys(configPropName).defaultValue, ConfigDef.Type.LIST))) + .getOrElse("") + .split(",").toList + val newValueList = oldValueList ::: alterConfigOp.configEntry.value.split(",").toList + configProps.setProperty(alterConfigOp.configEntry.name, newValueList.mkString(",")) + } + case OpType.SUBTRACT => { + if (!listType(alterConfigOp.configEntry.name, configKeys)) + throw new InvalidRequestException(s"Config value subtract is not allowed for config key: ${alterConfigOp.configEntry.name}") + val oldValueList = Option(configProps.getProperty(alterConfigOp.configEntry.name)) + .orElse(Option(ConfigDef.convertToString(configKeys(configPropName).defaultValue, ConfigDef.Type.LIST))) + .getOrElse("") + .split(",").toList + val newValueList = oldValueList.diff(alterConfigOp.configEntry.value.split(",").toList) + configProps.setProperty(alterConfigOp.configEntry.name, newValueList.mkString(",")) + } + } + } + } + + def shutdown(): Unit = { + topicPurgatory.shutdown() + CoreUtils.swallow(createTopicPolicy.foreach(_.close()), this) + CoreUtils.swallow(alterConfigPolicy.foreach(_.close()), this) + } + + private def resourceNameToBrokerId(resourceName: String): Int = { + try resourceName.toInt catch { + case _: NumberFormatException => + throw new InvalidRequestException(s"Broker id must be an integer, but it is: $resourceName") + } + } + + private def sanitizeEntityName(entityName: String): String = + Option(entityName) match { + case None => ConfigEntityName.Default + case Some(name) => Sanitizer.sanitize(name) + } + + private def desanitizeEntityName(sanitizedEntityName: String): String = + sanitizedEntityName match { + case ConfigEntityName.Default => null + case name => Sanitizer.desanitize(name) + } + + private def parseAndSanitizeQuotaEntity(entity: ClientQuotaEntity): (Option[String], Option[String], 
Option[String]) = { + if (entity.entries.isEmpty) + throw new InvalidRequestException("Invalid empty client quota entity") + + var user: Option[String] = None + var clientId: Option[String] = None + var ip: Option[String] = None + entity.entries.forEach { (entityType, entityName) => + val sanitizedEntityName = Some(sanitizeEntityName(entityName)) + entityType match { + case ClientQuotaEntity.USER => user = sanitizedEntityName + case ClientQuotaEntity.CLIENT_ID => clientId = sanitizedEntityName + case ClientQuotaEntity.IP => ip = sanitizedEntityName + case _ => throw new InvalidRequestException(s"Unhandled client quota entity type: ${entityType}") + } + if (entityName != null && entityName.isEmpty) + throw new InvalidRequestException(s"Empty ${entityType} not supported") + } + (user, clientId, ip) + } + + private def userClientIdToEntity(user: Option[String], clientId: Option[String]): ClientQuotaEntity = { + new ClientQuotaEntity((user.map(u => ClientQuotaEntity.USER -> u) ++ clientId.map(c => ClientQuotaEntity.CLIENT_ID -> c)).toMap.asJava) + } + + def describeClientQuotas(filter: ClientQuotaFilter): Map[ClientQuotaEntity, Map[String, Double]] = { + var userComponent: Option[ClientQuotaFilterComponent] = None + var clientIdComponent: Option[ClientQuotaFilterComponent] = None + var ipComponent: Option[ClientQuotaFilterComponent] = None + filter.components.forEach { component => + component.entityType match { + case ClientQuotaEntity.USER => + if (userComponent.isDefined) + throw new InvalidRequestException(s"Duplicate user filter component entity type") + userComponent = Some(component) + case ClientQuotaEntity.CLIENT_ID => + if (clientIdComponent.isDefined) + throw new InvalidRequestException(s"Duplicate client filter component entity type") + clientIdComponent = Some(component) + case ClientQuotaEntity.IP => + if (ipComponent.isDefined) + throw new InvalidRequestException(s"Duplicate ip filter component entity type") + ipComponent = Some(component) + case "" => + throw new InvalidRequestException(s"Unexpected empty filter component entity type") + case et => + // Supplying other entity types is not yet supported. 
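+ // (USER, CLIENT_ID, IP and the empty entity type are all matched above, so anything else falls through to this case.)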
+ throw new UnsupportedVersionException(s"Custom entity type '${et}' not supported") + } + } + if ((userComponent.isDefined || clientIdComponent.isDefined) && ipComponent.isDefined) + throw new InvalidRequestException(s"Invalid entity filter component combination, IP filter component should not be used with " + + s"user or clientId filter component.") + + val userClientQuotas = if (ipComponent.isEmpty) + handleDescribeClientQuotas(userComponent, clientIdComponent, filter.strict) + else + Map.empty + + val ipQuotas = if (userComponent.isEmpty && clientIdComponent.isEmpty) + handleDescribeIpQuotas(ipComponent, filter.strict) + else + Map.empty + + (userClientQuotas ++ ipQuotas).toMap + } + + private def wantExact(component: Option[ClientQuotaFilterComponent]): Boolean = component.exists(_.`match` != null) + + private def toOption(opt: java.util.Optional[String]): Option[String] = { + if (opt == null) + None + else if (opt.isPresent) + Some(opt.get) + else + Some(null) + } + + private def sanitized(name: Option[String]): String = name.map(n => sanitizeEntityName(n)).getOrElse("") + + private def fromProps(props: Map[String, String]): Map[String, Double] = { + props.map { case (key, value) => + val doubleValue = try value.toDouble catch { + case _: NumberFormatException => + throw new IllegalStateException(s"Unexpected client quota configuration value: $key -> $value") + } + key -> doubleValue + } + } + + def handleDescribeClientQuotas(userComponent: Option[ClientQuotaFilterComponent], + clientIdComponent: Option[ClientQuotaFilterComponent], strict: Boolean): Map[ClientQuotaEntity, Map[String, Double]] = { + + val user = userComponent.flatMap(c => toOption(c.`match`)) + val clientId = clientIdComponent.flatMap(c => toOption(c.`match`)) + + val sanitizedUser = sanitized(user) + val sanitizedClientId = sanitized(clientId) + + val exactUser = wantExact(userComponent) + val exactClientId = wantExact(clientIdComponent) + + def wantExcluded(component: Option[ClientQuotaFilterComponent]): Boolean = strict && !component.isDefined + val excludeUser = wantExcluded(userComponent) + val excludeClientId = wantExcluded(clientIdComponent) + + val userEntries = if (exactUser && excludeClientId) + Map((Some(user.get), None) -> adminZkClient.fetchEntityConfig(ConfigType.User, sanitizedUser)) + else if (!excludeUser && !exactClientId) + adminZkClient.fetchAllEntityConfigs(ConfigType.User).map { case (name, props) => + (Some(desanitizeEntityName(name)), None) -> props + } + else + Map.empty + + val clientIdEntries = if (excludeUser && exactClientId) + Map((None, Some(clientId.get)) -> adminZkClient.fetchEntityConfig(ConfigType.Client, sanitizedClientId)) + else if (!exactUser && !excludeClientId) + adminZkClient.fetchAllEntityConfigs(ConfigType.Client).map { case (name, props) => + (None, Some(desanitizeEntityName(name))) -> props + } + else + Map.empty + + val bothEntries = if (exactUser && exactClientId) + Map((Some(user.get), Some(clientId.get)) -> + adminZkClient.fetchEntityConfig(ConfigType.User, s"${sanitizedUser}/clients/${sanitizedClientId}")) + else if (!excludeUser && !excludeClientId) + adminZkClient.fetchAllChildEntityConfigs(ConfigType.User, ConfigType.Client).map { case (name, props) => + val components = name.split("/") + if (components.size != 3 || components(1) != "clients") + throw new IllegalArgumentException(s"Unexpected config path: ${name}") + (Some(desanitizeEntityName(components(0))), Some(desanitizeEntityName(components(2)))) -> props + } + else + Map.empty + + def matches(nameComponent: 
Option[ClientQuotaFilterComponent], name: Option[String]): Boolean = nameComponent match { + case Some(component) => + toOption(component.`match`) match { + case Some(n) => name.exists(_ == n) + case None => name.isDefined + } + case None => + !name.isDefined || !strict + } + + (userEntries ++ clientIdEntries ++ bothEntries).flatMap { case ((u, c), p) => + val quotaProps = p.asScala.filter { case (key, _) => QuotaConfigs.isClientOrUserQuotaConfig(key) } + if (quotaProps.nonEmpty && matches(userComponent, u) && matches(clientIdComponent, c)) + Some(userClientIdToEntity(u, c) -> fromProps(quotaProps)) + else + None + }.toMap + } + + def handleDescribeIpQuotas(ipComponent: Option[ClientQuotaFilterComponent], strict: Boolean): Map[ClientQuotaEntity, Map[String, Double]] = { + val ip = ipComponent.flatMap(c => toOption(c.`match`)) + val exactIp = wantExact(ipComponent) + val allIps = ipComponent.exists(_.`match` == null) || (ipComponent.isEmpty && !strict) + val ipEntries = if (exactIp) + Map(Some(ip.get) -> adminZkClient.fetchEntityConfig(ConfigType.Ip, sanitized(ip))) + else if (allIps) + adminZkClient.fetchAllEntityConfigs(ConfigType.Ip).map { case (name, props) => + Some(desanitizeEntityName(name)) -> props + } + else + Map.empty + + def ipToQuotaEntity(ip: Option[String]): ClientQuotaEntity = { + new ClientQuotaEntity(ip.map(ipName => ClientQuotaEntity.IP -> ipName).toMap.asJava) + } + + ipEntries.flatMap { case (ip, props) => + val ipQuotaProps = props.asScala.filter { case (key, _) => DynamicConfig.Ip.names.contains(key) } + if (ipQuotaProps.nonEmpty) + Some(ipToQuotaEntity(ip) -> fromProps(ipQuotaProps)) + else + None + } + } + + def alterClientQuotas(entries: Seq[ClientQuotaAlteration], validateOnly: Boolean): Map[ClientQuotaEntity, ApiError] = { + def alterEntityQuotas(entity: ClientQuotaEntity, ops: Iterable[ClientQuotaAlteration.Op]): Unit = { + val (path, configType, configKeys) = parseAndSanitizeQuotaEntity(entity) match { + case (Some(user), Some(clientId), None) => (user + "/clients/" + clientId, ConfigType.User, DynamicConfig.User.configKeys) + case (Some(user), None, None) => (user, ConfigType.User, DynamicConfig.User.configKeys) + case (None, Some(clientId), None) => (clientId, ConfigType.Client, DynamicConfig.Client.configKeys) + case (None, None, Some(ip)) => + if (!DynamicConfig.Ip.isValidIpEntity(ip)) + throw new InvalidRequestException(s"$ip is not a valid IP or resolvable host.") + (ip, ConfigType.Ip, DynamicConfig.Ip.configKeys) + case (_, _, Some(_)) => throw new InvalidRequestException(s"Invalid quota entity combination, " + + s"IP entity should not be used with user/client ID entity.") + case _ => throw new InvalidRequestException("Invalid client quota entity") + } + + val props = adminZkClient.fetchEntityConfig(configType, path) + ops.foreach { op => + op.value match { + case null => + props.remove(op.key) + case value => configKeys.get(op.key) match { + case null => + throw new InvalidRequestException(s"Invalid configuration key ${op.key}") + case key => key.`type` match { + case ConfigDef.Type.DOUBLE => + props.setProperty(op.key, value.toString) + case ConfigDef.Type.LONG | ConfigDef.Type.INT => + val epsilon = 1e-6 + val intValue = if (key.`type` == ConfigDef.Type.LONG) + (value + epsilon).toLong + else + (value + epsilon).toInt + if ((intValue.toDouble - value).abs > epsilon) + throw new InvalidRequestException(s"Configuration ${op.key} must be a ${key.`type`} value") + props.setProperty(op.key, intValue.toString) + case _ => + throw new 
IllegalStateException(s"Unexpected config type ${key.`type`}") + } + } + } + } + if (!validateOnly) + adminZkClient.changeConfigs(configType, path, props) + } + entries.map { entry => + val apiError = try { + alterEntityQuotas(entry.entity, entry.ops.asScala) + ApiError.NONE + } catch { + case e: Throwable => + info(s"Error encountered while updating client quotas", e) + ApiError.fromThrowable(e) + } + entry.entity -> apiError + }.toMap + } + + private val usernameMustNotBeEmptyMsg = "Username must not be empty" + private val errorProcessingDescribe = "Error processing describe user SCRAM credential configs request" + private val attemptToDescribeUserThatDoesNotExist = "Attempt to describe a user credential that does not exist" + + def describeUserScramCredentials(users: Option[Seq[String]]): DescribeUserScramCredentialsResponseData = { + val describingAllUsers = !users.isDefined || users.get.isEmpty + val retval = new DescribeUserScramCredentialsResponseData() + val userResults = mutable.Map[String, DescribeUserScramCredentialsResponseData.DescribeUserScramCredentialsResult]() + + def addToResultsIfHasScramCredential(user: String, userConfig: Properties, explicitUser: Boolean = false): Unit = { + val result = new DescribeUserScramCredentialsResponseData.DescribeUserScramCredentialsResult().setUser(user) + val configKeys = userConfig.stringPropertyNames + val hasScramCredential = ScramMechanism.values().toList.exists(key => key != ScramMechanism.UNKNOWN && configKeys.contains(key.mechanismName)) + if (hasScramCredential) { + val credentialInfos = new util.ArrayList[CredentialInfo] + try { + ScramMechanism.values().filter(_ != ScramMechanism.UNKNOWN).foreach { mechanism => + val propertyValue = userConfig.getProperty(mechanism.mechanismName) + if (propertyValue != null) { + val iterations = ScramCredentialUtils.credentialFromString(propertyValue).iterations + credentialInfos.add(new CredentialInfo().setMechanism(mechanism.`type`).setIterations(iterations)) + } + } + result.setCredentialInfos(credentialInfos) + } catch { + case e: Exception => { // should generally never happen, but just in case bad data gets in... 
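+ // For example, a stored credential string that ScramCredentialUtils.credentialFromString cannot parse ends up here; the error is recorded on this user's result rather than failing the whole request.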
+ val apiError = apiErrorFrom(e, errorProcessingDescribe) + result.setErrorCode(apiError.error.code).setErrorMessage(apiError.error.message) + } + } + userResults += (user -> result) + } else if (explicitUser) { + // it is an error to request credentials for a user that has no credentials + result.setErrorCode(Errors.RESOURCE_NOT_FOUND.code).setErrorMessage(s"$attemptToDescribeUserThatDoesNotExist: $user") + userResults += (user -> result) + } + } + + def collectRetrievedResults(): Unit = { + if (describingAllUsers) { + val usersSorted = SortedSet.empty[String] ++ userResults.keys + usersSorted.foreach { user => retval.results.add(userResults(user)) } + } else { + // be sure to only include a single copy of a result for any user requested multiple times + users.get.distinct.foreach { user => retval.results.add(userResults(user)) } + } + } + + try { + if (describingAllUsers) + adminZkClient.fetchAllEntityConfigs(ConfigType.User).foreach { + case (user, properties) => addToResultsIfHasScramCredential(user, properties) } + else { + // describing specific users + val illegalUsers = users.get.filter(_.isEmpty).toSet + illegalUsers.foreach { user => + userResults += (user -> new DescribeUserScramCredentialsResponseData.DescribeUserScramCredentialsResult() + .setUser(user) + .setErrorCode(Errors.RESOURCE_NOT_FOUND.code) + .setErrorMessage(usernameMustNotBeEmptyMsg)) } + val duplicatedUsers = users.get.groupBy(identity).filter( + userAndOccurrencesTuple => userAndOccurrencesTuple._2.length > 1).keys + duplicatedUsers.filterNot(illegalUsers.contains).foreach { user => + userResults += (user -> new DescribeUserScramCredentialsResponseData.DescribeUserScramCredentialsResult() + .setUser(user) + .setErrorCode(Errors.DUPLICATE_RESOURCE.code) + .setErrorMessage(s"Cannot describe SCRAM credentials for the same user twice in a single request: $user")) } + val usersToSkip = illegalUsers ++ duplicatedUsers + users.get.filterNot(usersToSkip.contains).foreach { user => + try { + val userConfigs = adminZkClient.fetchEntityConfig(ConfigType.User, Sanitizer.sanitize(user)) + addToResultsIfHasScramCredential(user, userConfigs, true) + } catch { + case e: Exception => { + val apiError = apiErrorFrom(e, errorProcessingDescribe) + userResults += (user -> new DescribeUserScramCredentialsResponseData.DescribeUserScramCredentialsResult() + .setUser(user) + .setErrorCode(apiError.error.code) + .setErrorMessage(apiError.error.message)) + } + } + } + } + collectRetrievedResults() + } catch { + case e: Exception => { + // this should generally only happen when we get a failure trying to retrieve all user configs from ZooKeeper + val apiError = apiErrorFrom(e, errorProcessingDescribe) + retval.setErrorCode(apiError.error.code).setErrorMessage(apiError.messageWithFallback()) + } + } + retval + } + + def apiErrorFrom(e: Exception, message: String): ApiError = { + if (e.isInstanceOf[ApiException]) + info(message, e) + else + error(message, e) + ApiError.fromThrowable(e) + } + + case class requestStatus(user: String, mechanism: Option[ScramMechanism], legalRequest: Boolean, iterations: Int) {} + + def alterUserScramCredentials(upsertions: Seq[AlterUserScramCredentialsRequestData.ScramCredentialUpsertion], + deletions: Seq[AlterUserScramCredentialsRequestData.ScramCredentialDeletion]): AlterUserScramCredentialsResponseData = { + + def scramMechanism(mechanism: Byte): ScramMechanism = { + ScramMechanism.fromType(mechanism) + } + + def mechanismName(mechanism: Byte): String = { + scramMechanism(mechanism).mechanismName + } + + val 
retval = new AlterUserScramCredentialsResponseData() + + // fail any user that is invalid due to an empty user name, an unknown SCRAM mechanism, or unacceptable number of iterations + val maxIterations = 16384 + val illegalUpsertions = upsertions.map(upsertion => + if (upsertion.name.isEmpty) + requestStatus(upsertion.name, None, false, upsertion.iterations) // no determined mechanism -- empty user is the cause of failure + else { + val publicScramMechanism = scramMechanism(upsertion.mechanism) + if (publicScramMechanism == ScramMechanism.UNKNOWN) { + requestStatus(upsertion.name, Some(publicScramMechanism), false, upsertion.iterations) // unknown mechanism is the cause of failure + } else { + if (upsertion.iterations < InternalScramMechanism.forMechanismName(publicScramMechanism.mechanismName).minIterations + || upsertion.iterations > maxIterations) { + requestStatus(upsertion.name, Some(publicScramMechanism), false, upsertion.iterations) // known mechanism, bad iterations is the cause of failure + } else { + requestStatus(upsertion.name, Some(publicScramMechanism), true, upsertion.iterations) // legal + } + } + }).filter { !_.legalRequest } + val illegalDeletions = deletions.map(deletion => + if (deletion.name.isEmpty) { + requestStatus(deletion.name, None, false, 0) // no determined mechanism -- empty user is the cause of failure + } else { + val publicScramMechanism = scramMechanism(deletion.mechanism) + requestStatus(deletion.name, Some(publicScramMechanism), publicScramMechanism != ScramMechanism.UNKNOWN, 0) + }).filter { !_.legalRequest } + // map user names to error messages + val unknownScramMechanismMsg = "Unknown SCRAM mechanism" + val tooFewIterationsMsg = "Too few iterations" + val tooManyIterationsMsg = "Too many iterations" + val illegalRequestsByUser = + illegalDeletions.map(requestStatus => + if (requestStatus.user.isEmpty) { + (requestStatus.user, usernameMustNotBeEmptyMsg) + } else { + (requestStatus.user, unknownScramMechanismMsg) + } + ).toMap ++ illegalUpsertions.map(requestStatus => + if (requestStatus.user.isEmpty) { + (requestStatus.user, usernameMustNotBeEmptyMsg) + } else if (requestStatus.mechanism == Some(ScramMechanism.UNKNOWN)) { + (requestStatus.user, unknownScramMechanismMsg) + } else { + (requestStatus.user, if (requestStatus.iterations > maxIterations) {tooManyIterationsMsg} else {tooFewIterationsMsg}) + } + ).toMap + + illegalRequestsByUser.forKeyValue { (user, errorMessage) => + retval.results.add(new AlterUserScramCredentialsResult().setUser(user) + .setErrorCode(if (errorMessage == unknownScramMechanismMsg) {Errors.UNSUPPORTED_SASL_MECHANISM.code} else {Errors.UNACCEPTABLE_CREDENTIAL.code}) + .setErrorMessage(errorMessage)) } + + val invalidUsers = (illegalUpsertions ++ illegalDeletions).map(_.user).toSet + val initiallyValidUserMechanismPairs = (upsertions.filter(upsertion => !invalidUsers.contains(upsertion.name)).map(upsertion => (upsertion.name, upsertion.mechanism)) ++ + deletions.filter(deletion => !invalidUsers.contains(deletion.name)).map(deletion => (deletion.name, deletion.mechanism))) + + val usersWithDuplicateUserMechanismPairs = initiallyValidUserMechanismPairs.groupBy(identity).filter ( + userMechanismPairAndOccurrencesTuple => userMechanismPairAndOccurrencesTuple._2.length > 1).keys.map(userMechanismPair => userMechanismPair._1).toSet + usersWithDuplicateUserMechanismPairs.foreach { user => + retval.results.add(new AlterUserScramCredentialsResult() + .setUser(user) + .setErrorCode(Errors.DUPLICATE_RESOURCE.code).setErrorMessage("A user 
credential cannot be altered twice in the same request")) } + + def potentiallyValidUserMechanismPairs = initiallyValidUserMechanismPairs.filter(pair => !usersWithDuplicateUserMechanismPairs.contains(pair._1)) + + val potentiallyValidUsers = potentiallyValidUserMechanismPairs.map(_._1).toSet + val configsByPotentiallyValidUser = potentiallyValidUsers.map(user => (user, adminZkClient.fetchEntityConfig(ConfigType.User, Sanitizer.sanitize(user)))).toMap + + // check for deletion of a credential that does not exist + val invalidDeletions = deletions.filter(deletion => potentiallyValidUsers.contains(deletion.name)).filter(deletion => + configsByPotentiallyValidUser(deletion.name).getProperty(mechanismName(deletion.mechanism)) == null) + val invalidUsersDueToInvalidDeletions = invalidDeletions.map(_.name).toSet + invalidUsersDueToInvalidDeletions.foreach { user => + retval.results.add(new AlterUserScramCredentialsResult() + .setUser(user) + .setErrorCode(Errors.RESOURCE_NOT_FOUND.code).setErrorMessage("Attempt to delete a user credential that does not exist")) } + + // now prepare the new set of property values for users that don't have any issues identified above, + // keeping track of ones that fail + val usersToTryToAlter = potentiallyValidUsers.diff(invalidUsersDueToInvalidDeletions) + val usersFailedToPrepareProperties = usersToTryToAlter.map(user => { + try { + // deletions: remove property keys + deletions.filter(deletion => usersToTryToAlter.contains(deletion.name)).foreach { deletion => + configsByPotentiallyValidUser(deletion.name).remove(mechanismName(deletion.mechanism)) } + // upsertions: put property key/value + upsertions.filter(upsertion => usersToTryToAlter.contains(upsertion.name)).foreach { upsertion => + val mechanism = InternalScramMechanism.forMechanismName(mechanismName(upsertion.mechanism)) + val credential = new ScramFormatter(mechanism) + .generateCredential(upsertion.salt, upsertion.saltedPassword, upsertion.iterations) + configsByPotentiallyValidUser(upsertion.name).put(mechanismName(upsertion.mechanism), ScramCredentialUtils.credentialToString(credential)) } + (user) // success, 1 element, won't be matched + } catch { + case e: Exception => + info(s"Error encountered while altering user SCRAM credentials", e) + (user, e) // fail, 2 elements, will be matched + } + }).collect { case (user: String, exception: Exception) => (user, exception) }.toMap + + // now persist the properties we have prepared, again keeping track of whatever fails + val usersFailedToPersist = usersToTryToAlter.filterNot(usersFailedToPrepareProperties.contains).map(user => { + try { + adminZkClient.changeConfigs(ConfigType.User, Sanitizer.sanitize(user), configsByPotentiallyValidUser(user)) + (user) // success, 1 element, won't be matched + } catch { + case e: Exception => + info(s"Error encountered while altering user SCRAM credentials", e) + (user, e) // fail, 2 elements, will be matched + } + }).collect { case (user: String, exception: Exception) => (user, exception) }.toMap + + // report failures + usersFailedToPrepareProperties.++(usersFailedToPersist).forKeyValue { (user, exception) => + val error = Errors.forException(exception) + retval.results.add(new AlterUserScramCredentialsResult() + .setUser(user) + .setErrorCode(error.code) + .setErrorMessage(error.message)) } + + // report successes + usersToTryToAlter.filterNot(usersFailedToPrepareProperties.contains).filterNot(usersFailedToPersist.contains).foreach { user => + retval.results.add(new AlterUserScramCredentialsResult() + 
.setUser(user) + .setErrorCode(Errors.NONE.code)) } + + retval + } +} diff --git a/core/src/main/scala/kafka/server/ZkIsrManager.scala b/core/src/main/scala/kafka/server/ZkIsrManager.scala new file mode 100644 index 0000000..65e8c14 --- /dev/null +++ b/core/src/main/scala/kafka/server/ZkIsrManager.scala @@ -0,0 +1,109 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package kafka.server + +import kafka.utils.{Logging, ReplicationUtils, Scheduler} +import kafka.zk.KafkaZkClient +import org.apache.kafka.common.TopicPartition +import java.util.concurrent.atomic.AtomicLong +import java.util.concurrent.{CompletableFuture, TimeUnit} + +import kafka.api.LeaderAndIsr +import org.apache.kafka.common.errors.InvalidUpdateVersionException +import org.apache.kafka.common.utils.Time + +import scala.collection.mutable + +/** + * @param checkIntervalMs How often to check for ISR + * @param maxDelayMs Maximum time that an ISR change may be delayed before sending the notification + * @param lingerMs Maximum time to await additional changes before sending the notification + */ +case class IsrChangePropagationConfig(checkIntervalMs: Long, maxDelayMs: Long, lingerMs: Long) + +object ZkIsrManager { + // This field is mutable to allow overriding change notification behavior in test cases + @volatile var DefaultIsrPropagationConfig: IsrChangePropagationConfig = IsrChangePropagationConfig( + checkIntervalMs = 2500, + lingerMs = 5000, + maxDelayMs = 60000, + ) +} + +class ZkIsrManager(scheduler: Scheduler, time: Time, zkClient: KafkaZkClient) extends AlterIsrManager with Logging { + + private val isrChangeNotificationConfig = ZkIsrManager.DefaultIsrPropagationConfig + // Visible for testing + private[server] val isrChangeSet: mutable.Set[TopicPartition] = new mutable.HashSet[TopicPartition]() + private val lastIsrChangeMs = new AtomicLong(time.milliseconds()) + private val lastIsrPropagationMs = new AtomicLong(time.milliseconds()) + + override def start(): Unit = { + scheduler.schedule("isr-change-propagation", maybePropagateIsrChanges _, + period = isrChangeNotificationConfig.checkIntervalMs, unit = TimeUnit.MILLISECONDS) + } + + override def submit( + topicPartition: TopicPartition, + leaderAndIsr: LeaderAndIsr, + controllerEpoch: Int + ): CompletableFuture[LeaderAndIsr]= { + debug(s"Writing new ISR ${leaderAndIsr.isr} to ZooKeeper with version " + + s"${leaderAndIsr.zkVersion} for partition $topicPartition") + + val (updateSucceeded, newVersion) = ReplicationUtils.updateLeaderAndIsr(zkClient, topicPartition, + leaderAndIsr, controllerEpoch) + + val future = new CompletableFuture[LeaderAndIsr]() + if (updateSucceeded) { + // Track which partitions need to be propagated to the controller + isrChangeSet synchronized { + isrChangeSet += topicPartition + 
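+ // Record when this change was queued; the scheduled maybePropagateIsrChanges() task later batches queued changes into a single ZooKeeper notification.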
lastIsrChangeMs.set(time.milliseconds()) + } + + // We rely on Partition#isrState being properly set to the pending ISR at this point since we are synchronously + // applying the callback + future.complete(leaderAndIsr.withZkVersion(newVersion)) + } else { + future.completeExceptionally(new InvalidUpdateVersionException( + s"ISR update $leaderAndIsr for partition $topicPartition with controller epoch $controllerEpoch " + + "failed with an invalid version error")) + } + future + } + + /** + * This function runs periodically to check whether ISR changes need to be propagated. It propagates ISR changes when: + * 1. There is an ISR change that has not been propagated yet. + * 2. There has been no ISR change in the last five seconds, or more than 60 seconds have passed since the last ISR propagation. + * This allows an occasional ISR change to be propagated within a few seconds, and avoids overwhelming the controller and + * other brokers when a large number of ISR changes occur. + */ + private[server] def maybePropagateIsrChanges(): Unit = { + val now = time.milliseconds() + isrChangeSet synchronized { + if (isrChangeSet.nonEmpty && + (lastIsrChangeMs.get() + isrChangeNotificationConfig.lingerMs < now || + lastIsrPropagationMs.get() + isrChangeNotificationConfig.maxDelayMs < now)) { + zkClient.propagateIsrChanges(isrChangeSet) + isrChangeSet.clear() + lastIsrPropagationMs.set(now) + } + } + } +} diff --git a/core/src/main/scala/kafka/server/checkpoints/CheckpointFileWithFailureHandler.scala b/core/src/main/scala/kafka/server/checkpoints/CheckpointFileWithFailureHandler.scala new file mode 100644 index 0000000..7021c67 --- /dev/null +++ b/core/src/main/scala/kafka/server/checkpoints/CheckpointFileWithFailureHandler.scala @@ -0,0 +1,56 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package kafka.server.checkpoints + +import kafka.server.LogDirFailureChannel +import org.apache.kafka.common.errors.KafkaStorageException +import org.apache.kafka.server.common.CheckpointFile +import CheckpointFile.EntryFormatter + +import java.io._ +import scala.collection.Seq +import scala.jdk.CollectionConverters._ + +class CheckpointFileWithFailureHandler[T](val file: File, + version: Int, + formatter: EntryFormatter[T], + logDirFailureChannel: LogDirFailureChannel, + logDir: String) { + private val checkpointFile = new CheckpointFile[T](file, version, formatter) + + def write(entries: Iterable[T]): Unit = { + try { + checkpointFile.write(entries.toSeq.asJava) + } catch { + case e: IOException => + val msg = s"Error while writing to checkpoint file ${file.getAbsolutePath}" + logDirFailureChannel.maybeAddOfflineLogDir(logDir, msg, e) + throw new KafkaStorageException(msg, e) + } + } + + def read(): Seq[T] = { + try { + checkpointFile.read().asScala + } catch { + case e: IOException => + val msg = s"Error while reading checkpoint file ${file.getAbsolutePath}" + logDirFailureChannel.maybeAddOfflineLogDir(logDir, msg, e) + throw new KafkaStorageException(msg, e) + } + } +} diff --git a/core/src/main/scala/kafka/server/checkpoints/LeaderEpochCheckpointFile.scala b/core/src/main/scala/kafka/server/checkpoints/LeaderEpochCheckpointFile.scala new file mode 100644 index 0000000..c772b82 --- /dev/null +++ b/core/src/main/scala/kafka/server/checkpoints/LeaderEpochCheckpointFile.scala @@ -0,0 +1,74 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package kafka.server.checkpoints + +import kafka.server.LogDirFailureChannel +import kafka.server.epoch.EpochEntry +import org.apache.kafka.server.common.CheckpointFile.EntryFormatter + +import java.io._ +import java.util.Optional +import java.util.regex.Pattern +import scala.collection._ + +trait LeaderEpochCheckpoint { + def write(epochs: Iterable[EpochEntry]): Unit + def read(): Seq[EpochEntry] +} + +object LeaderEpochCheckpointFile { + private val LeaderEpochCheckpointFilename = "leader-epoch-checkpoint" + private val WhiteSpacesPattern = Pattern.compile("\\s+") + private val CurrentVersion = 0 + + def newFile(dir: File): File = new File(dir, LeaderEpochCheckpointFilename) + + object Formatter extends EntryFormatter[EpochEntry] { + + override def toString(entry: EpochEntry): String = s"${entry.epoch} ${entry.startOffset}" + + override def fromString(line: String): Optional[EpochEntry] = { + WhiteSpacesPattern.split(line) match { + case Array(epoch, offset) => + Optional.of(EpochEntry(epoch.toInt, offset.toLong)) + case _ => Optional.empty() + } + } + + } +} + +/** + * This class persists a map of (LeaderEpoch => Offsets) to a file (for a certain replica) + * + * The format in the LeaderEpoch checkpoint file is like this: + * -----checkpoint file begin------ + * 0 <- LeaderEpochCheckpointFile.currentVersion + * 2 <- following entries size + * 0 1 <- the format is: leader_epoch(int32) start_offset(int64) + * 1 2 + * -----checkpoint file end---------- + */ +class LeaderEpochCheckpointFile(val file: File, logDirFailureChannel: LogDirFailureChannel = null) extends LeaderEpochCheckpoint { + import LeaderEpochCheckpointFile._ + + val checkpoint = new CheckpointFileWithFailureHandler[EpochEntry](file, CurrentVersion, Formatter, logDirFailureChannel, file.getParentFile.getParent) + + def write(epochs: Iterable[EpochEntry]): Unit = checkpoint.write(epochs) + + def read(): Seq[EpochEntry] = checkpoint.read() +} diff --git a/core/src/main/scala/kafka/server/checkpoints/OffsetCheckpointFile.scala b/core/src/main/scala/kafka/server/checkpoints/OffsetCheckpointFile.scala new file mode 100644 index 0000000..f7b83ea --- /dev/null +++ b/core/src/main/scala/kafka/server/checkpoints/OffsetCheckpointFile.scala @@ -0,0 +1,100 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package kafka.server.checkpoints + +import kafka.server.LogDirFailureChannel +import kafka.server.epoch.EpochEntry +import org.apache.kafka.common.TopicPartition +import org.apache.kafka.server.common.CheckpointFile.EntryFormatter + +import java.io._ +import java.util.Optional +import java.util.regex.Pattern +import scala.collection._ + +object OffsetCheckpointFile { + private val WhiteSpacesPattern = Pattern.compile("\\s+") + private[checkpoints] val CurrentVersion = 0 + + object Formatter extends EntryFormatter[(TopicPartition, Long)] { + override def toString(entry: (TopicPartition, Long)): String = { + s"${entry._1.topic} ${entry._1.partition} ${entry._2}" + } + + override def fromString(line: String): Optional[(TopicPartition, Long)] = { + WhiteSpacesPattern.split(line) match { + case Array(topic, partition, offset) => + Optional.of(new TopicPartition(topic, partition.toInt), offset.toLong) + case _ => Optional.empty() + } + } + } +} + +trait OffsetCheckpoint { + def write(epochs: Seq[EpochEntry]): Unit + def read(): Seq[EpochEntry] +} + +/** + * This class persists a map of (Partition => Offsets) to a file (for a certain replica) + * + * The format in the offset checkpoint file is like this: + * -----checkpoint file begin------ + * 0 <- OffsetCheckpointFile.currentVersion + * 2 <- following entries size + * tp1 par1 1 <- the format is: TOPIC PARTITION OFFSET + * tp1 par2 2 + * -----checkpoint file end---------- + */ +class OffsetCheckpointFile(val file: File, logDirFailureChannel: LogDirFailureChannel = null) { + val checkpoint = new CheckpointFileWithFailureHandler[(TopicPartition, Long)](file, OffsetCheckpointFile.CurrentVersion, + OffsetCheckpointFile.Formatter, logDirFailureChannel, file.getParent) + + def write(offsets: Map[TopicPartition, Long]): Unit = checkpoint.write(offsets) + + def read(): Map[TopicPartition, Long] = checkpoint.read().toMap + +} + +trait OffsetCheckpoints { + def fetch(logDir: String, topicPartition: TopicPartition): Option[Long] +} + +/** + * Loads checkpoint files on demand and caches the offsets for reuse. + */ +class LazyOffsetCheckpoints(checkpointsByLogDir: Map[String, OffsetCheckpointFile]) extends OffsetCheckpoints { + private val lazyCheckpointsByLogDir = checkpointsByLogDir.map { case (logDir, checkpointFile) => + logDir -> new LazyOffsetCheckpointMap(checkpointFile) + }.toMap + + override def fetch(logDir: String, topicPartition: TopicPartition): Option[Long] = { + val offsetCheckpointFile = lazyCheckpointsByLogDir.getOrElse(logDir, + throw new IllegalArgumentException(s"No checkpoint file for log dir $logDir")) + offsetCheckpointFile.fetch(topicPartition) + } +} + +class LazyOffsetCheckpointMap(checkpoint: OffsetCheckpointFile) { + private lazy val offsets: Map[TopicPartition, Long] = checkpoint.read() + + def fetch(topicPartition: TopicPartition): Option[Long] = { + offsets.get(topicPartition) + } + +} diff --git a/core/src/main/scala/kafka/server/epoch/LeaderEpochFileCache.scala b/core/src/main/scala/kafka/server/epoch/LeaderEpochFileCache.scala new file mode 100644 index 0000000..e6e45fd --- /dev/null +++ b/core/src/main/scala/kafka/server/epoch/LeaderEpochFileCache.scala @@ -0,0 +1,301 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package kafka.server.epoch + +import java.util +import java.util.concurrent.locks.ReentrantReadWriteLock + +import kafka.server.checkpoints.LeaderEpochCheckpoint +import kafka.utils.CoreUtils._ +import kafka.utils.Logging +import org.apache.kafka.common.TopicPartition +import org.apache.kafka.common.requests.OffsetsForLeaderEpochResponse.{UNDEFINED_EPOCH, UNDEFINED_EPOCH_OFFSET} + +import scala.collection.{Seq, mutable} +import scala.jdk.CollectionConverters._ + +/** + * Represents a cache of (LeaderEpoch => Offset) mappings for a particular replica. + * + * Leader Epoch = epoch assigned to each leader by the controller. + * Offset = offset of the first message in each epoch. + * + * @param topicPartition the associated topic partition + * @param checkpoint the checkpoint file + */ +class LeaderEpochFileCache(topicPartition: TopicPartition, + checkpoint: LeaderEpochCheckpoint) extends Logging { + this.logIdent = s"[LeaderEpochCache $topicPartition] " + + private val lock = new ReentrantReadWriteLock() + private val epochs = new util.TreeMap[Int, EpochEntry]() + + inWriteLock(lock) { + checkpoint.read().foreach(assign) + } + + /** + * Assigns the supplied Leader Epoch to the supplied Offset + * Once the epoch is assigned it cannot be reassigned + */ + def assign(epoch: Int, startOffset: Long): Unit = { + val entry = EpochEntry(epoch, startOffset) + if (assign(entry)) { + debug(s"Appended new epoch entry $entry. Cache now contains ${epochs.size} entries.") + flush() + } + } + + private def assign(entry: EpochEntry): Boolean = { + if (entry.epoch < 0 || entry.startOffset < 0) { + throw new IllegalArgumentException(s"Received invalid partition leader epoch entry $entry") + } + + def isUpdateNeeded: Boolean = { + latestEntry match { + case Some(lastEntry) => + entry.epoch != lastEntry.epoch || entry.startOffset < lastEntry.startOffset + case None => + true + } + } + + // Check whether the append is needed before acquiring the write lock + // in order to avoid contention with readers in the common case + if (!isUpdateNeeded) + return false + + inWriteLock(lock) { + if (isUpdateNeeded) { + maybeTruncateNonMonotonicEntries(entry) + epochs.put(entry.epoch, entry) + true + } else { + false + } + } + } + + /** + * Remove any entries which violate monotonicity prior to appending a new entry + */ + private def maybeTruncateNonMonotonicEntries(newEntry: EpochEntry): Unit = { + val removedEpochs = removeFromEnd { entry => + entry.epoch >= newEntry.epoch || entry.startOffset >= newEntry.startOffset + } + + if (removedEpochs.size > 1 + || (removedEpochs.nonEmpty && removedEpochs.head.startOffset != newEntry.startOffset)) { + + // Only log a warning if there were non-trivial removals. If the start offset of the new entry + // matches the start offset of the removed epoch, then no data has been written and the truncation + // is expected. 
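+ // For example, appending EpochEntry(6, 100) when the latest entry is EpochEntry(5, 100) silently replaces the old entry.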
+ warn(s"New epoch entry $newEntry caused truncation of conflicting entries $removedEpochs. " + + s"Cache now contains ${epochs.size} entries.") + } + } + + private def removeFromEnd(predicate: EpochEntry => Boolean): Seq[EpochEntry] = { + removeWhileMatching(epochs.descendingMap.entrySet().iterator(), predicate) + } + + private def removeFromStart(predicate: EpochEntry => Boolean): Seq[EpochEntry] = { + removeWhileMatching(epochs.entrySet().iterator(), predicate) + } + + private def removeWhileMatching( + iterator: util.Iterator[util.Map.Entry[Int, EpochEntry]], + predicate: EpochEntry => Boolean + ): Seq[EpochEntry] = { + val removedEpochs = mutable.ListBuffer.empty[EpochEntry] + + while (iterator.hasNext) { + val entry = iterator.next().getValue + if (predicate.apply(entry)) { + removedEpochs += entry + iterator.remove() + } else { + return removedEpochs + } + } + + removedEpochs + } + + def nonEmpty: Boolean = inReadLock(lock) { + !epochs.isEmpty + } + + def latestEntry: Option[EpochEntry] = { + inReadLock(lock) { + Option(epochs.lastEntry).map(_.getValue) + } + } + + /** + * Returns the current Leader Epoch if one exists. This is the latest epoch + * which has messages assigned to it. + */ + def latestEpoch: Option[Int] = { + latestEntry.map(_.epoch) + } + + def previousEpoch: Option[Int] = { + inReadLock(lock) { + latestEntry.flatMap(entry => Option(epochs.lowerEntry(entry.epoch))).map(_.getKey) + } + } + + /** + * Get the earliest cached entry if one exists. + */ + def earliestEntry: Option[EpochEntry] = { + inReadLock(lock) { + Option(epochs.firstEntry).map(_.getValue) + } + } + + /** + * Returns the Leader Epoch and the End Offset for a requested Leader Epoch. + * + * The Leader Epoch returned is the largest epoch less than or equal to the requested Leader + * Epoch. The End Offset is the end offset of this epoch, which is defined as the start offset + * of the first Leader Epoch larger than the Leader Epoch requested, or else the Log End + * Offset if the latest epoch was requested. + * + * During the upgrade phase, where there are existing messages may not have a leader epoch, + * if requestedEpoch is < the first epoch cached, UNDEFINED_EPOCH_OFFSET will be returned + * so that the follower falls back to High Water Mark. + * + * @param requestedEpoch requested leader epoch + * @param logEndOffset the existing Log End Offset + * @return found leader epoch and end offset + */ + def endOffsetFor(requestedEpoch: Int, logEndOffset: Long): (Int, Long) = { + inReadLock(lock) { + val epochAndOffset = + if (requestedEpoch == UNDEFINED_EPOCH) { + // This may happen if a bootstrapping follower sends a request with undefined epoch or + // a follower is on the older message format where leader epochs are not recorded + (UNDEFINED_EPOCH, UNDEFINED_EPOCH_OFFSET) + } else if (latestEpoch.contains(requestedEpoch)) { + // For the leader, the latest epoch is always the current leader epoch that is still being written to. + // Followers should not have any reason to query for the end offset of the current epoch, but a consumer + // might if it is verifying its committed offset following a group rebalance. In this case, we return + // the current log end offset which makes the truncation check work as expected. + (requestedEpoch, logEndOffset) + } else { + val higherEntry = epochs.higherEntry(requestedEpoch) + if (higherEntry == null) { + // The requested epoch is larger than any known epoch. This case should never be hit because + // the latest cached epoch is always the largest. 
+ (UNDEFINED_EPOCH, UNDEFINED_EPOCH_OFFSET) + } else { + val floorEntry = epochs.floorEntry(requestedEpoch) + if (floorEntry == null) { + // The requested epoch is smaller than any known epoch, so we return the start offset of the first + // known epoch which is larger than it. This may be inaccurate as there could have been + // epochs in between, but the point is that the data has already been removed from the log + // and we want to ensure that the follower can replicate correctly beginning from the leader's + // start offset. + (requestedEpoch, higherEntry.getValue.startOffset) + } else { + // We have at least one previous epoch and one subsequent epoch. The result is the first + // prior epoch and the starting offset of the first subsequent epoch. + (floorEntry.getValue.epoch, higherEntry.getValue.startOffset) + } + } + } + trace(s"Processed end offset request for epoch $requestedEpoch and returning epoch ${epochAndOffset._1} " + + s"with end offset ${epochAndOffset._2} from epoch cache of size ${epochs.size}") + epochAndOffset + } + } + + /** + * Removes all epoch entries from the store with start offsets greater than or equal to the passed offset. + */ + def truncateFromEnd(endOffset: Long): Unit = { + inWriteLock(lock) { + if (endOffset >= 0 && latestEntry.exists(_.startOffset >= endOffset)) { + val removedEntries = removeFromEnd(_.startOffset >= endOffset) + + flush() + + debug(s"Cleared entries $removedEntries from epoch cache after " + + s"truncating to end offset $endOffset, leaving ${epochs.size} entries in the cache.") + } + } + } + + /** + * Clears old epoch entries. This method searches for the oldest epoch < offset, updates the saved epoch offset to + * be offset, then clears any previous epoch entries. + * + * This method is exclusive: so truncateFromStart(6) will retain an entry at offset 6. + * + * @param startOffset the offset to clear up to + */ + def truncateFromStart(startOffset: Long): Unit = { + inWriteLock(lock) { + val removedEntries = removeFromStart { entry => + entry.startOffset <= startOffset + } + + removedEntries.lastOption.foreach { firstBeforeStartOffset => + val updatedFirstEntry = EpochEntry(firstBeforeStartOffset.epoch, startOffset) + epochs.put(updatedFirstEntry.epoch, updatedFirstEntry) + + flush() + + debug(s"Cleared entries $removedEntries and rewrote first entry $updatedFirstEntry after " + + s"truncating to start offset $startOffset, leaving ${epochs.size} in the cache.") + } + } + } + + /** + * Delete all entries. + */ + def clearAndFlush(): Unit = { + inWriteLock(lock) { + epochs.clear() + flush() + } + } + + def clear(): Unit = { + inWriteLock(lock) { + epochs.clear() + } + } + + // Visible for testing + def epochEntries: Seq[EpochEntry] = epochs.values.asScala.toSeq + + private def flush(): Unit = { + checkpoint.write(epochs.values.asScala) + } + +} + +// Mapping of epoch to the first offset of the subsequent epoch +case class EpochEntry(epoch: Int, startOffset: Long) { + override def toString: String = { + s"EpochEntry(epoch=$epoch, startOffset=$startOffset)" + } +} diff --git a/core/src/main/scala/kafka/server/metadata/BrokerMetadataListener.scala b/core/src/main/scala/kafka/server/metadata/BrokerMetadataListener.scala new file mode 100644 index 0000000..702d227 --- /dev/null +++ b/core/src/main/scala/kafka/server/metadata/BrokerMetadataListener.scala @@ -0,0 +1,297 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package kafka.server.metadata + +import java.util +import java.util.concurrent.{CompletableFuture, TimeUnit} +import java.util.function.Consumer + +import kafka.metrics.KafkaMetricsGroup +import org.apache.kafka.image.{MetadataDelta, MetadataImage} +import org.apache.kafka.common.utils.{LogContext, Time} +import org.apache.kafka.queue.{EventQueue, KafkaEventQueue} +import org.apache.kafka.raft.{Batch, BatchReader, LeaderAndEpoch, RaftClient} +import org.apache.kafka.server.common.ApiMessageAndVersion +import org.apache.kafka.snapshot.SnapshotReader + + +object BrokerMetadataListener { + val MetadataBatchProcessingTimeUs = "MetadataBatchProcessingTimeUs" + val MetadataBatchSizes = "MetadataBatchSizes" +} + +class BrokerMetadataListener( + val brokerId: Int, + time: Time, + threadNamePrefix: Option[String], + val maxBytesBetweenSnapshots: Long, + val snapshotter: Option[MetadataSnapshotter] +) extends RaftClient.Listener[ApiMessageAndVersion] with KafkaMetricsGroup { + private val logContext = new LogContext(s"[BrokerMetadataListener id=${brokerId}] ") + private val log = logContext.logger(classOf[BrokerMetadataListener]) + logIdent = logContext.logPrefix() + + /** + * A histogram tracking the time in microseconds it took to process batches of events. + */ + private val batchProcessingTimeHist = newHistogram(BrokerMetadataListener.MetadataBatchProcessingTimeUs) + + /** + * A histogram tracking the sizes of batches that we have processed. + */ + private val metadataBatchSizeHist = newHistogram(BrokerMetadataListener.MetadataBatchSizes) + + /** + * The highest metadata offset that we've seen. Written only from the event queue thread. + */ + @volatile var _highestOffset = -1L + + /** + * The highest metadata log time that we've seen. Written only from the event queue thread. + */ + private var _highestTimestamp = -1L + + /** + * The current broker metadata image. Accessed only from the event queue thread. + */ + private var _image = MetadataImage.EMPTY + + /** + * The current metadata delta. Accessed only from the event queue thread. + */ + private var _delta = new MetadataDelta(_image) + + /** + * The object to use to publish new metadata changes, or None if this listener has not + * been activated yet. Accessed only from the event queue thread. + */ + private var _publisher: Option[MetadataPublisher] = None + + /** + * The number of bytes of records that we have read since the last snapshot we took. + * This does not include records we read from a snapshot. + * Accessed only from the event queue thread. + */ + private var _bytesSinceLastSnapshot: Long = 0L + + /** + * The event queue which runs this listener. + */ + val eventQueue = new KafkaEventQueue(time, logContext, threadNamePrefix.getOrElse("")) + + /** + * Returns the highest metadata-offset. Thread-safe. 
+ */ + def highestMetadataOffset: Long = _highestOffset + + /** + * Handle new metadata records. + */ + override def handleCommit(reader: BatchReader[ApiMessageAndVersion]): Unit = + eventQueue.append(new HandleCommitsEvent(reader)) + + class HandleCommitsEvent(reader: BatchReader[ApiMessageAndVersion]) + extends EventQueue.FailureLoggingEvent(log) { + override def run(): Unit = { + val results = try { + val loadResults = loadBatches(_delta, reader, None, None, None) + if (isDebugEnabled) { + debug(s"Loaded new commits: ${loadResults}") + } + loadResults + } finally { + reader.close() + } + _publisher.foreach(publish) + + snapshotter.foreach { snapshotter => + _bytesSinceLastSnapshot = _bytesSinceLastSnapshot + results.numBytes + if (shouldSnapshot()) { + if (snapshotter.maybeStartSnapshot(_highestTimestamp, _delta.apply())) { + _bytesSinceLastSnapshot = 0L + } + } + } + } + } + + private def shouldSnapshot(): Boolean = { + _bytesSinceLastSnapshot >= maxBytesBetweenSnapshots + } + + /** + * Handle metadata snapshots + */ + override def handleSnapshot(reader: SnapshotReader[ApiMessageAndVersion]): Unit = + eventQueue.append(new HandleSnapshotEvent(reader)) + + class HandleSnapshotEvent(reader: SnapshotReader[ApiMessageAndVersion]) + extends EventQueue.FailureLoggingEvent(log) { + override def run(): Unit = { + try { + info(s"Loading snapshot ${reader.snapshotId().offset}-${reader.snapshotId().epoch}.") + _delta = new MetadataDelta(_image) // Discard any previous deltas. + val loadResults = loadBatches( + _delta, + reader, + Some(reader.lastContainedLogTimestamp), + Some(reader.lastContainedLogOffset), + Some(reader.lastContainedLogEpoch) + ) + _delta.finishSnapshot() + info(s"Loaded snapshot ${reader.snapshotId().offset}-${reader.snapshotId().epoch}: " + + s"${loadResults}") + } finally { + reader.close() + } + _publisher.foreach(publish) + } + } + + case class BatchLoadResults(numBatches: Int, numRecords: Int, elapsedUs: Long, numBytes: Long) { + override def toString(): String = { + s"${numBatches} batch(es) with ${numRecords} record(s) in ${numBytes} bytes " + + s"ending at offset ${highestMetadataOffset} in ${elapsedUs} microseconds" + } + } + + /** + * Load and replay the batches to the metadata delta. + * + * When loading and replaying a snapshot, the lastAppendTimestamp, lastCommittedOffset and lastCommittedEpoch parameters should be provided. + * In a snapshot, the append timestamp, offset and epoch reported by the batch are independent of the ones + * reported by the metadata log. 
+ * + * @param delta metadata delta on which to replay the records + * @param iterator sequence of metadata record batches to replay + * @param lastAppendTimestamp optional append timestamp to use instead of the batches timestamp + * @param lastCommittedOffset optional offset to use instead of the batches offset + * @param lastCommittedEpoch optional epoch to use instead of the batches epoch + */ + private def loadBatches( + delta: MetadataDelta, + iterator: util.Iterator[Batch[ApiMessageAndVersion]], + lastAppendTimestamp: Option[Long], + lastCommittedOffset: Option[Long], + lastCommittedEpoch: Option[Int] + ): BatchLoadResults = { + val startTimeNs = time.nanoseconds() + var numBatches = 0 + var numRecords = 0 + var numBytes = 0L + + while (iterator.hasNext()) { + val batch = iterator.next() + + val epoch = lastCommittedEpoch.getOrElse(batch.epoch()) + _highestTimestamp = lastAppendTimestamp.getOrElse(batch.appendTimestamp()) + + var index = 0 + batch.records().forEach { messageAndVersion => + if (isTraceEnabled) { + trace("Metadata batch %d: processing [%d/%d]: %s.".format(batch.lastOffset, index + 1, + batch.records().size(), messageAndVersion.message().toString())) + } + + _highestOffset = lastCommittedOffset.getOrElse(batch.baseOffset() + index) + + delta.replay(highestMetadataOffset, epoch, messageAndVersion.message()) + numRecords += 1 + index += 1 + } + numBytes = numBytes + batch.sizeInBytes() + metadataBatchSizeHist.update(batch.records().size()) + numBatches = numBatches + 1 + } + + val endTimeNs = time.nanoseconds() + val elapsedUs = TimeUnit.MICROSECONDS.convert(endTimeNs - startTimeNs, TimeUnit.NANOSECONDS) + batchProcessingTimeHist.update(elapsedUs) + BatchLoadResults(numBatches, numRecords, elapsedUs, numBytes) + } + + def startPublishing(publisher: MetadataPublisher): CompletableFuture[Void] = { + val event = new StartPublishingEvent(publisher) + eventQueue.append(event) + event.future + } + + class StartPublishingEvent(publisher: MetadataPublisher) + extends EventQueue.FailureLoggingEvent(log) { + val future = new CompletableFuture[Void]() + + override def run(): Unit = { + _publisher = Some(publisher) + log.info(s"Starting to publish metadata events at offset ${highestMetadataOffset}.") + try { + publish(publisher) + future.complete(null) + } catch { + case e: Throwable => + future.completeExceptionally(e) + throw e + } + } + } + + private def publish(publisher: MetadataPublisher): Unit = { + val delta = _delta + _image = _delta.apply() + _delta = new MetadataDelta(_image) + publisher.publish(delta, _image) + } + + override def handleLeaderChange(leaderAndEpoch: LeaderAndEpoch): Unit = { + // Nothing to do. 
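+ // The broker listener only consumes committed metadata records, so a change of quorum leadership requires no action here.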
+ } + + override def beginShutdown(): Unit = { + eventQueue.beginShutdown("beginShutdown", new ShutdownEvent()) + } + + class ShutdownEvent() extends EventQueue.FailureLoggingEvent(log) { + override def run(): Unit = { + removeMetric(BrokerMetadataListener.MetadataBatchProcessingTimeUs) + removeMetric(BrokerMetadataListener.MetadataBatchSizes) + } + } + + def close(): Unit = { + beginShutdown() + eventQueue.close() + } + + // VisibleForTesting + private[kafka] def getImageRecords(): CompletableFuture[util.List[ApiMessageAndVersion]] = { + val future = new CompletableFuture[util.List[ApiMessageAndVersion]]() + eventQueue.append(new GetImageRecordsEvent(future)) + future + } + + class GetImageRecordsEvent(future: CompletableFuture[util.List[ApiMessageAndVersion]]) + extends EventQueue.FailureLoggingEvent(log) with Consumer[util.List[ApiMessageAndVersion]] { + val records = new util.ArrayList[ApiMessageAndVersion]() + override def accept(batch: util.List[ApiMessageAndVersion]): Unit = { + records.addAll(batch) + } + + override def run(): Unit = { + _image.write(this) + future.complete(records) + } + } +} diff --git a/core/src/main/scala/kafka/server/metadata/BrokerMetadataPublisher.scala b/core/src/main/scala/kafka/server/metadata/BrokerMetadataPublisher.scala new file mode 100644 index 0000000..8cdcb3d --- /dev/null +++ b/core/src/main/scala/kafka/server/metadata/BrokerMetadataPublisher.scala @@ -0,0 +1,273 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.server.metadata + +import kafka.coordinator.group.GroupCoordinator +import kafka.coordinator.transaction.TransactionCoordinator +import kafka.log.{UnifiedLog, LogManager} +import kafka.server.ConfigType +import kafka.server.{ConfigEntityName, ConfigHandler, FinalizedFeatureCache, KafkaConfig, ReplicaManager, RequestLocal} +import kafka.utils.Logging +import org.apache.kafka.common.TopicPartition +import org.apache.kafka.common.config.ConfigResource +import org.apache.kafka.common.internals.Topic +import org.apache.kafka.image.{MetadataDelta, MetadataImage, TopicDelta, TopicsImage} + +import scala.collection.mutable + + +object BrokerMetadataPublisher extends Logging { + /** + * Given a topic name, find out if it changed. Note: if a topic named X was deleted and + * then re-created, this method will return just the re-creation. The deletion will show + * up in deletedTopicIds and must be handled separately. + * + * @param topicName The topic name. + * @param newImage The new metadata image. + * @param delta The metadata delta to search. + * + * @return The delta, or None if appropriate. 
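+ * (None when the topic does not exist in the new image or has no changes in this delta.)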
+ */ + def getTopicDelta(topicName: String, + newImage: MetadataImage, + delta: MetadataDelta): Option[TopicDelta] = { + Option(newImage.topics().getTopic(topicName)).flatMap { + topicImage => Option(delta.topicsDelta()).flatMap { + topicDelta => Option(topicDelta.changedTopic(topicImage.id())) + } + } + } + + /** + * Find logs which should not be on the current broker, according to the metadata image. + * + * @param brokerId The ID of the current broker. + * @param newTopicsImage The new topics image after broker has been reloaded + * @param logs A collection of Log objects. + * + * @return The topic partitions which are no longer needed on this broker. + */ + def findStrayPartitions(brokerId: Int, + newTopicsImage: TopicsImage, + logs: Iterable[UnifiedLog]): Iterable[TopicPartition] = { + logs.flatMap { log => + val topicId = log.topicId.getOrElse { + throw new RuntimeException(s"The log dir $log does not have a topic ID, " + + "which is not allowed when running in KRaft mode.") + } + + val partitionId = log.topicPartition.partition() + Option(newTopicsImage.getPartition(topicId, partitionId)) match { + case Some(partition) => + if (!partition.replicas.contains(brokerId)) { + info(s"Found stray log dir $log: the current replica assignment ${partition.replicas} " + + s"does not contain the local brokerId $brokerId.") + Some(log.topicPartition) + } else { + None + } + + case None => + info(s"Found stray log dir $log: the topicId $topicId does not exist in the metadata image") + Some(log.topicPartition) + } + } + } +} + +class BrokerMetadataPublisher(conf: KafkaConfig, + metadataCache: KRaftMetadataCache, + logManager: LogManager, + replicaManager: ReplicaManager, + groupCoordinator: GroupCoordinator, + txnCoordinator: TransactionCoordinator, + clientQuotaMetadataManager: ClientQuotaMetadataManager, + featureCache: FinalizedFeatureCache, + dynamicConfigHandlers: Map[String, ConfigHandler]) extends MetadataPublisher with Logging { + logIdent = s"[BrokerMetadataPublisher id=${conf.nodeId}] " + + import BrokerMetadataPublisher._ + + /** + * The broker ID. + */ + val brokerId = conf.nodeId + + /** + * True if this is the first time we have published metadata. + */ + var _firstPublish = true + + override def publish(delta: MetadataDelta, newImage: MetadataImage): Unit = { + val highestOffsetAndEpoch = newImage.highestOffsetAndEpoch() + + try { + trace(s"Publishing delta $delta with highest offset $highestOffsetAndEpoch") + + // Publish the new metadata image to the metadata cache. + metadataCache.setImage(newImage) + + if (_firstPublish) { + info(s"Publishing initial metadata at offset $highestOffsetAndEpoch.") + + // If this is the first metadata update we are applying, initialize the managers + // first (but after setting up the metadata cache). + initializeManagers() + } else if (isDebugEnabled) { + debug(s"Publishing metadata at offset $highestOffsetAndEpoch.") + } + + // Apply feature deltas. + Option(delta.featuresDelta()).foreach { featuresDelta => + featureCache.update(featuresDelta, highestOffsetAndEpoch.offset) + } + + // Apply topic deltas. + Option(delta.topicsDelta()).foreach { topicsDelta => + // Notify the replica manager about changes to topics. + replicaManager.applyDelta(topicsDelta, newImage) + + // Handle the case where the old consumer offsets topic was deleted. 
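+ // If this broker was the leader for any of the deleted topic's partitions, resign group coordinatorship so the cached group state is unloaded.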
+ if (topicsDelta.topicWasDeleted(Topic.GROUP_METADATA_TOPIC_NAME)) { + topicsDelta.image().getTopic(Topic.GROUP_METADATA_TOPIC_NAME).partitions().entrySet().forEach { + entry => + if (entry.getValue().leader == brokerId) { + groupCoordinator.onResignation(entry.getKey(), Some(entry.getValue().leaderEpoch)) + } + } + } + // Handle the case where we have new local leaders or followers for the consumer + // offsets topic. + getTopicDelta(Topic.GROUP_METADATA_TOPIC_NAME, newImage, delta).foreach { topicDelta => + val changes = topicDelta.localChanges(brokerId) + + changes.deletes.forEach { topicPartition => + groupCoordinator.onResignation(topicPartition.partition, None) + } + changes.leaders.forEach { (topicPartition, partitionInfo) => + groupCoordinator.onElection(topicPartition.partition, partitionInfo.partition.leaderEpoch) + } + changes.followers.forEach { (topicPartition, partitionInfo) => + groupCoordinator.onResignation(topicPartition.partition, Some(partitionInfo.partition.leaderEpoch)) + } + } + + // Handle the case where the old transaction state topic was deleted. + if (topicsDelta.topicWasDeleted(Topic.TRANSACTION_STATE_TOPIC_NAME)) { + topicsDelta.image().getTopic(Topic.TRANSACTION_STATE_TOPIC_NAME).partitions().entrySet().forEach { + entry => + if (entry.getValue().leader == brokerId) { + txnCoordinator.onResignation(entry.getKey(), Some(entry.getValue().leaderEpoch)) + } + } + } + // If the transaction state topic changed in a way that's relevant to this broker, + // notify the transaction coordinator. + getTopicDelta(Topic.TRANSACTION_STATE_TOPIC_NAME, newImage, delta).foreach { topicDelta => + val changes = topicDelta.localChanges(brokerId) + + changes.deletes.forEach { topicPartition => + txnCoordinator.onResignation(topicPartition.partition, None) + } + changes.leaders.forEach { (topicPartition, partitionInfo) => + txnCoordinator.onElection(topicPartition.partition, partitionInfo.partition.leaderEpoch) + } + changes.followers.forEach { (topicPartition, partitionInfo) => + txnCoordinator.onResignation(topicPartition.partition, Some(partitionInfo.partition.leaderEpoch)) + } + } + + // Notify the group coordinator about deleted topics. + val deletedTopicPartitions = new mutable.ArrayBuffer[TopicPartition]() + topicsDelta.deletedTopicIds().forEach { id => + val topicImage = topicsDelta.image().getTopic(id) + topicImage.partitions().keySet().forEach { + id => deletedTopicPartitions += new TopicPartition(topicImage.name(), id) + } + } + if (deletedTopicPartitions.nonEmpty) { + groupCoordinator.handleDeletedPartitions(deletedTopicPartitions, RequestLocal.NoCaching) + } + } + + // Apply configuration deltas. + Option(delta.configsDelta()).foreach { configsDelta => + configsDelta.changes().keySet().forEach { configResource => + val tag = configResource.`type`() match { + case ConfigResource.Type.TOPIC => Some(ConfigType.Topic) + case ConfigResource.Type.BROKER => Some(ConfigType.Broker) + case _ => None + } + tag.foreach { t => + val newProperties = newImage.configs().configProperties(configResource) + val maybeDefaultName = configResource.name() match { + case "" => ConfigEntityName.Default + case k => k + } + dynamicConfigHandlers(t).processConfigChanges(maybeDefaultName, newProperties) + } + } + } + + // Apply client quotas delta. 
+ Option(delta.clientQuotasDelta()).foreach { clientQuotasDelta => + clientQuotaMetadataManager.update(clientQuotasDelta) + } + + if (_firstPublish) { + finishInitializingReplicaManager(newImage) + } + } catch { + case t: Throwable => error(s"Error publishing broker metadata at $highestOffsetAndEpoch", t) + throw t + } finally { + _firstPublish = false + } + } + + private def initializeManagers(): Unit = { + // Start log manager, which will perform (potentially lengthy) + // recovery-from-unclean-shutdown if required. + logManager.startup(metadataCache.getAllTopics()) + + // Start the replica manager. + replicaManager.startup() + + // Start the group coordinator. + groupCoordinator.startup(() => metadataCache.numPartitions( + Topic.GROUP_METADATA_TOPIC_NAME).getOrElse(conf.offsetsTopicPartitions)) + + // Start the transaction coordinator. + txnCoordinator.startup(() => metadataCache.numPartitions( + Topic.TRANSACTION_STATE_TOPIC_NAME).getOrElse(conf.transactionTopicPartitions)) + } + + private def finishInitializingReplicaManager(newImage: MetadataImage): Unit = { + // Delete log directories which we're not supposed to have, according to the + // latest metadata. This is only necessary to do when we're first starting up. If + // we have to load a snapshot later, these topics will appear in deletedTopicIds. + val strayPartitions = findStrayPartitions(brokerId, newImage.topics, logManager.allLogs) + if (strayPartitions.nonEmpty) { + replicaManager.deleteStrayReplicas(strayPartitions) + } + + // Make sure that the high water mark checkpoint thread is running for the replica + // manager. + replicaManager.startHighWatermarkCheckPointThread() + } +} diff --git a/core/src/main/scala/kafka/server/metadata/BrokerMetadataSnapshotter.scala b/core/src/main/scala/kafka/server/metadata/BrokerMetadataSnapshotter.scala new file mode 100644 index 0000000..fb5bfbb --- /dev/null +++ b/core/src/main/scala/kafka/server/metadata/BrokerMetadataSnapshotter.scala @@ -0,0 +1,114 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package kafka.server.metadata + +import java.util.concurrent.RejectedExecutionException + +import kafka.utils.Logging +import org.apache.kafka.image.MetadataImage +import org.apache.kafka.common.utils.{LogContext, Time} +import org.apache.kafka.queue.{EventQueue, KafkaEventQueue} +import org.apache.kafka.server.common.ApiMessageAndVersion +import org.apache.kafka.snapshot.SnapshotWriter + + +trait SnapshotWriterBuilder { + def build(committedOffset: Long, + committedEpoch: Int, + lastContainedLogTime: Long): SnapshotWriter[ApiMessageAndVersion] +} + +class BrokerMetadataSnapshotter( + brokerId: Int, + val time: Time, + threadNamePrefix: Option[String], + writerBuilder: SnapshotWriterBuilder +) extends Logging with MetadataSnapshotter { + private val logContext = new LogContext(s"[BrokerMetadataSnapshotter id=${brokerId}] ") + logIdent = logContext.logPrefix() + + /** + * The offset of the snapshot in progress, or -1 if there isn't one. Accessed only under + * the object lock. + */ + private var _currentSnapshotOffset = -1L + + /** + * The event queue which runs this listener. + */ + val eventQueue = new KafkaEventQueue(time, logContext, threadNamePrefix.getOrElse("")) + + override def maybeStartSnapshot(lastContainedLogTime: Long, image: MetadataImage): Boolean = synchronized { + if (_currentSnapshotOffset == -1L) { + val writer = writerBuilder.build( + image.highestOffsetAndEpoch().offset, + image.highestOffsetAndEpoch().epoch, + lastContainedLogTime + ) + _currentSnapshotOffset = image.highestOffsetAndEpoch().offset + info(s"Creating a new snapshot at offset ${_currentSnapshotOffset}...") + eventQueue.append(new CreateSnapshotEvent(image, writer)) + true + } else { + warn(s"Declining to create a new snapshot at ${image.highestOffsetAndEpoch()} because " + + s"there is already a snapshot in progress at offset ${_currentSnapshotOffset}") + false + } + } + + class CreateSnapshotEvent(image: MetadataImage, + writer: SnapshotWriter[ApiMessageAndVersion]) + extends EventQueue.Event { + override def run(): Unit = { + try { + image.write(writer.append(_)) + writer.freeze() + } finally { + try { + writer.close() + } finally { + BrokerMetadataSnapshotter.this.synchronized { + _currentSnapshotOffset = -1L + } + } + } + } + + override def handleException(e: Throwable): Unit = { + e match { + case _: RejectedExecutionException => + info("Not processing CreateSnapshotEvent because the event queue is closed.") + case _ => error("Unexpected error handling CreateSnapshotEvent", e) + } + writer.close() + } + } + + def beginShutdown(): Unit = { + eventQueue.beginShutdown("beginShutdown", new ShutdownEvent()) + } + + class ShutdownEvent() extends EventQueue.Event { + override def run(): Unit = { + } + } + + def close(): Unit = { + beginShutdown() + eventQueue.close() + } +} diff --git a/core/src/main/scala/kafka/server/metadata/ClientQuotaMetadataManager.scala b/core/src/main/scala/kafka/server/metadata/ClientQuotaMetadataManager.scala new file mode 100644 index 0000000..6ada6b2 --- /dev/null +++ b/core/src/main/scala/kafka/server/metadata/ClientQuotaMetadataManager.scala @@ -0,0 +1,171 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.server.metadata + +import kafka.network.ConnectionQuotas +import kafka.server.ConfigEntityName +import kafka.server.QuotaFactory.QuotaManagers +import kafka.utils.Logging +import org.apache.kafka.common.config.internals.QuotaConfigs +import org.apache.kafka.common.metrics.Quota +import org.apache.kafka.common.quota.ClientQuotaEntity +import org.apache.kafka.common.utils.Sanitizer +import java.net.{InetAddress, UnknownHostException} + +import org.apache.kafka.image.{ClientQuotaDelta, ClientQuotasDelta} + +import scala.compat.java8.OptionConverters._ + + +// A strict hierarchy of entities that we support +sealed trait QuotaEntity +case class IpEntity(ip: String) extends QuotaEntity +case object DefaultIpEntity extends QuotaEntity +case class UserEntity(user: String) extends QuotaEntity +case object DefaultUserEntity extends QuotaEntity +case class ClientIdEntity(clientId: String) extends QuotaEntity +case object DefaultClientIdEntity extends QuotaEntity +case class ExplicitUserExplicitClientIdEntity(user: String, clientId: String) extends QuotaEntity +case class ExplicitUserDefaultClientIdEntity(user: String) extends QuotaEntity +case class DefaultUserExplicitClientIdEntity(clientId: String) extends QuotaEntity +case object DefaultUserDefaultClientIdEntity extends QuotaEntity + +/** + * Process quota metadata records as they appear in the metadata log and update quota managers and cache as necessary + */ +class ClientQuotaMetadataManager(private[metadata] val quotaManagers: QuotaManagers, + private[metadata] val connectionQuotas: ConnectionQuotas) extends Logging { + + def update(quotasDelta: ClientQuotasDelta): Unit = { + quotasDelta.changes().entrySet().forEach { e => + update(e.getKey, e.getValue) + } + } + + private def update(entity: ClientQuotaEntity, quotaDelta: ClientQuotaDelta): Unit = { + if (entity.entries().containsKey(ClientQuotaEntity.IP)) { + // In the IP quota manager, None is used for default entity + val ipEntity = Option(entity.entries().get(ClientQuotaEntity.IP)) match { + case Some(ip) => IpEntity(ip) + case None => DefaultIpEntity + } + handleIpQuota(ipEntity, quotaDelta) + } else if (entity.entries().containsKey(ClientQuotaEntity.USER) || + entity.entries().containsKey(ClientQuotaEntity.CLIENT_ID)) { + // These values may be null, which is why we needed to use containsKey. 
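+ // A key that is present with a null value denotes the default entity, while an absent key means that entity type is not part of this quota's target.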
+ val userVal = entity.entries().get(ClientQuotaEntity.USER) + val clientIdVal = entity.entries().get(ClientQuotaEntity.CLIENT_ID) + + // In User+Client quota managers, "" is used for default entity, so we need to represent all possible + // combinations of values, defaults, and absent entities + val userClientEntity = if (entity.entries().containsKey(ClientQuotaEntity.USER) && + entity.entries().containsKey(ClientQuotaEntity.CLIENT_ID)) { + if (userVal == null && clientIdVal == null) { + DefaultUserDefaultClientIdEntity + } else if (userVal == null) { + DefaultUserExplicitClientIdEntity(clientIdVal) + } else if (clientIdVal == null) { + ExplicitUserDefaultClientIdEntity(userVal) + } else { + ExplicitUserExplicitClientIdEntity(userVal, clientIdVal) + } + } else if (entity.entries().containsKey(ClientQuotaEntity.USER)) { + if (userVal == null) { + DefaultUserEntity + } else { + UserEntity(userVal) + } + } else { + if (clientIdVal == null) { + DefaultClientIdEntity + } else { + ClientIdEntity(clientIdVal) + } + } + quotaDelta.changes().entrySet().forEach { e => + handleUserClientQuotaChange(userClientEntity, e.getKey(), e.getValue().asScala.map(_.toDouble)) + } + } else { + warn(s"Ignoring unsupported quota entity ${entity}.") + } + } + + def handleIpQuota(ipEntity: QuotaEntity, quotaDelta: ClientQuotaDelta): Unit = { + val inetAddress = ipEntity match { + case IpEntity(ip) => + try { + Some(InetAddress.getByName(ip)) + } catch { + case _: UnknownHostException => throw new IllegalArgumentException(s"Unable to resolve address $ip") + } + case DefaultIpEntity => None + case _ => throw new IllegalStateException("Should only handle IP quota entities here") + } + + quotaDelta.changes().entrySet().forEach { e => + // The connection quota only understands the connection rate limit + val quotaName = e.getKey() + val quotaValue = e.getValue() + if (!quotaName.equals(QuotaConfigs.IP_CONNECTION_RATE_OVERRIDE_CONFIG)) { + warn(s"Ignoring unexpected quota key ${quotaName} for entity $ipEntity") + } else { + try { + connectionQuotas.updateIpConnectionRateQuota(inetAddress, quotaValue.asScala.map(_.toInt)) + } catch { + case t: Throwable => error(s"Failed to update IP quota $ipEntity", t) + } + } + } + } + + def handleUserClientQuotaChange(quotaEntity: QuotaEntity, key: String, newValue: Option[Double]): Unit = { + val manager = key match { + case QuotaConfigs.CONSUMER_BYTE_RATE_OVERRIDE_CONFIG => quotaManagers.fetch + case QuotaConfigs.PRODUCER_BYTE_RATE_OVERRIDE_CONFIG => quotaManagers.produce + case QuotaConfigs.REQUEST_PERCENTAGE_OVERRIDE_CONFIG => quotaManagers.request + case QuotaConfigs.CONTROLLER_MUTATION_RATE_OVERRIDE_CONFIG => quotaManagers.controllerMutation + case _ => + warn(s"Ignoring unexpected quota key ${key} for entity $quotaEntity") + return + } + + // Convert entity into Options with sanitized values for QuotaManagers + val (sanitizedUser, sanitizedClientId) = quotaEntity match { + case UserEntity(user) => (Some(Sanitizer.sanitize(user)), None) + case DefaultUserEntity => (Some(ConfigEntityName.Default), None) + case ClientIdEntity(clientId) => (None, Some(Sanitizer.sanitize(clientId))) + case DefaultClientIdEntity => (None, Some(ConfigEntityName.Default)) + case ExplicitUserExplicitClientIdEntity(user, clientId) => (Some(Sanitizer.sanitize(user)), Some(Sanitizer.sanitize(clientId))) + case ExplicitUserDefaultClientIdEntity(user) => (Some(Sanitizer.sanitize(user)), Some(ConfigEntityName.Default)) + case DefaultUserExplicitClientIdEntity(clientId) => (Some(ConfigEntityName.Default), 
Some(Sanitizer.sanitize(clientId))) + case DefaultUserDefaultClientIdEntity => (Some(ConfigEntityName.Default), Some(ConfigEntityName.Default)) + case IpEntity(_) | DefaultIpEntity => throw new IllegalStateException("Should not see IP quota entities here") + } + + val quotaValue = newValue.map(new Quota(_, true)) + try { + manager.updateQuota( + sanitizedUser = sanitizedUser, + clientId = sanitizedClientId.map(Sanitizer.desanitize), + sanitizedClientId = sanitizedClientId, + quota = quotaValue) + } catch { + case t: Throwable => error(s"Failed to update user-client quota $quotaEntity", t) + } + } +} diff --git a/core/src/main/scala/kafka/server/metadata/ConfigRepository.scala b/core/src/main/scala/kafka/server/metadata/ConfigRepository.scala new file mode 100644 index 0000000..68000d0 --- /dev/null +++ b/core/src/main/scala/kafka/server/metadata/ConfigRepository.scala @@ -0,0 +1,52 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.server.metadata + +import java.util.Properties + +import org.apache.kafka.common.config.ConfigResource +import org.apache.kafka.common.config.ConfigResource.Type + +trait ConfigRepository { + /** + * Return a copy of the topic configuration for the given topic. Future changes will not be reflected. + * + * @param topicName the name of the topic for which the configuration will be returned + * @return a copy of the topic configuration for the given topic + */ + def topicConfig(topicName: String): Properties = { + config(new ConfigResource(Type.TOPIC, topicName)) + } + + /** + * Return a copy of the broker configuration for the given broker. Future changes will not be reflected. + * + * @param brokerId the id of the broker for which configuration will be returned + * @return a copy of the broker configuration for the given broker + */ + def brokerConfig(brokerId: Int): Properties = { + config(new ConfigResource(Type.BROKER, brokerId.toString)) + } + + /** + * Return a copy of the configuration for the given resource. Future changes will not be reflected. + * @param configResource the resource for which the configuration will be returned + * @return a copy of the configuration for the given resource + */ + def config(configResource: ConfigResource): Properties +} diff --git a/core/src/main/scala/kafka/server/metadata/KRaftMetadataCache.scala b/core/src/main/scala/kafka/server/metadata/KRaftMetadataCache.scala new file mode 100644 index 0000000..1ff7a80 --- /dev/null +++ b/core/src/main/scala/kafka/server/metadata/KRaftMetadataCache.scala @@ -0,0 +1,367 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.server.metadata + +import kafka.controller.StateChangeLogger +import kafka.server.MetadataCache +import kafka.utils.Logging +import org.apache.kafka.common.internals.Topic +import org.apache.kafka.common.message.MetadataResponseData.{MetadataResponsePartition, MetadataResponseTopic} +import org.apache.kafka.common.{Cluster, Node, PartitionInfo, TopicPartition, Uuid} +import org.apache.kafka.common.message.UpdateMetadataRequestData.UpdateMetadataPartitionState +import org.apache.kafka.common.network.ListenerName +import org.apache.kafka.common.protocol.Errors +import org.apache.kafka.common.requests.MetadataResponse +import org.apache.kafka.image.MetadataImage +import java.util +import java.util.{Collections, Properties} +import java.util.concurrent.ThreadLocalRandom + +import kafka.admin.BrokerMetadata +import org.apache.kafka.common.config.ConfigResource +import org.apache.kafka.common.message.{DescribeClientQuotasRequestData, DescribeClientQuotasResponseData} +import org.apache.kafka.metadata.{PartitionRegistration, Replicas} + +import scala.collection.{Seq, Set, mutable} +import scala.jdk.CollectionConverters._ +import scala.compat.java8.OptionConverters._ + + +class KRaftMetadataCache(val brokerId: Int) extends MetadataCache with Logging with ConfigRepository { + this.logIdent = s"[MetadataCache brokerId=$brokerId] " + + // This is the cache state. Every MetadataImage instance is immutable, and updates + // replace this value with a completely new one. This means reads (which are not under + // any lock) need to grab the value of this variable once, and retain that read copy for + // the duration of their operation. Multiple reads of this value risk getting different + // image values. + @volatile private var _currentImage: MetadataImage = MetadataImage.EMPTY + + private val stateChangeLogger = new StateChangeLogger(brokerId, inControllerContext = false, None) + + // This method is the main hotspot when it comes to the performance of metadata requests, + // we should be careful about adding additional logic here. + // filterUnavailableEndpoints exists to support v0 MetadataResponses + private def maybeFilterAliveReplicas(image: MetadataImage, + brokers: Array[Int], + listenerName: ListenerName, + filterUnavailableEndpoints: Boolean): java.util.List[Integer] = { + if (!filterUnavailableEndpoints) { + Replicas.toList(brokers) + } else { + val res = new util.ArrayList[Integer](brokers.length) + for (brokerId <- brokers) { + Option(image.cluster().broker(brokerId)).foreach { b => + if (!b.fenced() && b.listeners().containsKey(listenerName.value())) { + res.add(brokerId) + } + } + } + res + } + } + + def currentImage(): MetadataImage = _currentImage + + // errorUnavailableEndpoints exists to support v0 MetadataResponses + // If errorUnavailableListeners=true, return LISTENER_NOT_FOUND if listener is missing on the broker. 
+ // Otherwise, return LEADER_NOT_AVAILABLE for broker unavailable and missing listener (Metadata response v5 and below). + private def getPartitionMetadata(image: MetadataImage, topicName: String, listenerName: ListenerName, errorUnavailableEndpoints: Boolean, + errorUnavailableListeners: Boolean): Option[Iterator[MetadataResponsePartition]] = { + Option(image.topics().getTopic(topicName)) match { + case None => None + case Some(topic) => Some(topic.partitions().entrySet().asScala.map { entry => + val partitionId = entry.getKey + val partition = entry.getValue + val filteredReplicas = maybeFilterAliveReplicas(image, partition.replicas, + listenerName, errorUnavailableEndpoints) + val filteredIsr = maybeFilterAliveReplicas(image, partition.isr, listenerName, + errorUnavailableEndpoints) + val offlineReplicas = getOfflineReplicas(image, partition, listenerName) + val maybeLeader = getAliveEndpoint(image, partition.leader, listenerName) + maybeLeader match { + case None => + val error = if (!image.cluster().brokers.containsKey(partition.leader)) { + debug(s"Error while fetching metadata for ${topicName}-${partitionId}: leader not available") + Errors.LEADER_NOT_AVAILABLE + } else { + debug(s"Error while fetching metadata for ${topicName}-${partitionId}: listener $listenerName " + + s"not found on leader ${partition.leader}") + if (errorUnavailableListeners) Errors.LISTENER_NOT_FOUND else Errors.LEADER_NOT_AVAILABLE + } + new MetadataResponsePartition() + .setErrorCode(error.code) + .setPartitionIndex(partitionId) + .setLeaderId(MetadataResponse.NO_LEADER_ID) + .setLeaderEpoch(partition.leaderEpoch) + .setReplicaNodes(filteredReplicas) + .setIsrNodes(filteredIsr) + .setOfflineReplicas(offlineReplicas) + case Some(leader) => + val error = if (filteredReplicas.size < partition.replicas.size) { + debug(s"Error while fetching metadata for ${topicName}-${partitionId}: replica information not available for " + + s"following brokers ${partition.replicas.filterNot(filteredReplicas.contains).mkString(",")}") + Errors.REPLICA_NOT_AVAILABLE + } else if (filteredIsr.size < partition.isr.size) { + debug(s"Error while fetching metadata for ${topicName}-${partitionId}: in sync replica information not available for " + + s"following brokers ${partition.isr.filterNot(filteredIsr.contains).mkString(",")}") + Errors.REPLICA_NOT_AVAILABLE + } else { + Errors.NONE + } + + new MetadataResponsePartition() + .setErrorCode(error.code) + .setPartitionIndex(partitionId) + .setLeaderId(leader.id()) + .setLeaderEpoch(partition.leaderEpoch) + .setReplicaNodes(filteredReplicas) + .setIsrNodes(filteredIsr) + .setOfflineReplicas(offlineReplicas) + } + }.iterator) + } + } + + private def getOfflineReplicas(image: MetadataImage, + partition: PartitionRegistration, + listenerName: ListenerName): util.List[Integer] = { + // TODO: in order to really implement this correctly, we would need JBOD support. + // That would require us to track which replicas were offline on a per-replica basis. + // See KAFKA-13005. + val offlineReplicas = new util.ArrayList[Integer](0) + for (brokerId <- partition.replicas) { + Option(image.cluster().broker(brokerId)) match { + case None => offlineReplicas.add(brokerId) + case Some(broker) => if (broker.fenced() || !broker.listeners().containsKey(listenerName.value())) { + offlineReplicas.add(brokerId) + } + } + } + offlineReplicas + } + + /** + * Get the endpoint matching the provided listener if the broker is alive. 
Note that listeners can + * be added dynamically, so a broker with a missing listener could be a transient error. + * + * @return None if broker is not alive or if the broker does not have a listener named `listenerName`. + */ + private def getAliveEndpoint(image: MetadataImage, id: Int, listenerName: ListenerName): Option[Node] = { + Option(image.cluster().broker(id)).flatMap(_.node(listenerName.value()).asScala) + } + + // errorUnavailableEndpoints exists to support v0 MetadataResponses + override def getTopicMetadata(topics: Set[String], + listenerName: ListenerName, + errorUnavailableEndpoints: Boolean = false, + errorUnavailableListeners: Boolean = false): Seq[MetadataResponseTopic] = { + val image = _currentImage + topics.toSeq.flatMap { topic => + getPartitionMetadata(image, topic, listenerName, errorUnavailableEndpoints, errorUnavailableListeners).map { partitionMetadata => + new MetadataResponseTopic() + .setErrorCode(Errors.NONE.code) + .setName(topic) + .setTopicId(Option(image.topics().getTopic(topic).id()).getOrElse(Uuid.ZERO_UUID)) + .setIsInternal(Topic.isInternal(topic)) + .setPartitions(partitionMetadata.toBuffer.asJava) + } + } + } + + override def getAllTopics(): Set[String] = _currentImage.topics().topicsByName().keySet().asScala + + override def getTopicPartitions(topicName: String): Set[TopicPartition] = { + Option(_currentImage.topics().getTopic(topicName)) match { + case None => Set.empty + case Some(topic) => topic.partitions().keySet().asScala.map(new TopicPartition(topicName, _)) + } + } + + override def getTopicId(topicName: String): Uuid = _currentImage.topics().topicsByName().asScala.get(topicName).map(_.id()).getOrElse(Uuid.ZERO_UUID) + + override def getTopicName(topicId: Uuid): Option[String] = _currentImage.topics().topicsById.asScala.get(topicId).map(_.name()) + + override def hasAliveBroker(brokerId: Int): Boolean = { + Option(_currentImage.cluster().broker(brokerId)).count(!_.fenced()) == 1 + } + + override def getAliveBrokers(): Iterable[BrokerMetadata] = getAliveBrokers(_currentImage) + + private def getAliveBrokers(image: MetadataImage): Iterable[BrokerMetadata] = { + image.cluster().brokers().values().asScala.filter(!_.fenced()). + map(b => BrokerMetadata(b.id, b.rack.asScala)) + } + + override def getAliveBrokerNode(brokerId: Int, listenerName: ListenerName): Option[Node] = { + Option(_currentImage.cluster().broker(brokerId)). + flatMap(_.node(listenerName.value()).asScala) + } + + override def getAliveBrokerNodes(listenerName: ListenerName): Seq[Node] = { + _currentImage.cluster().brokers().values().asScala.filter(!_.fenced()). + flatMap(_.node(listenerName.value()).asScala).toSeq + } + + override def getPartitionInfo(topicName: String, partitionId: Int): Option[UpdateMetadataPartitionState] = { + Option(_currentImage.topics().getTopic(topicName)). + flatMap(topic => Some(topic.partitions().get(partitionId))). + flatMap(partition => Some(new UpdateMetadataPartitionState(). + setTopicName(topicName). + setPartitionIndex(partitionId). + setControllerEpoch(-1). // Controller epoch is not stored in the cache. + setLeader(partition.leader). + setLeaderEpoch(partition.leaderEpoch). + setIsr(Replicas.toList(partition.isr)). + setZkVersion(partition.partitionEpoch))) + } + + override def numPartitions(topicName: String): Option[Int] = { + Option(_currentImage.topics().getTopic(topicName)). 
+ map(topic => topic.partitions().size()) + } + + override def topicNamesToIds(): util.Map[String, Uuid] = _currentImage.topics.topicNameToIdView() + + override def topicIdsToNames(): util.Map[Uuid, String] = _currentImage.topics.topicIdToNameView() + + override def topicIdInfo(): (util.Map[String, Uuid], util.Map[Uuid, String]) = { + val image = _currentImage + (image.topics.topicNameToIdView(), image.topics.topicIdToNameView()) + } + + // if the leader is not known, return None; + // if the leader is known and corresponding node is available, return Some(node) + // if the leader is known but corresponding node with the listener name is not available, return Some(NO_NODE) + override def getPartitionLeaderEndpoint(topicName: String, partitionId: Int, listenerName: ListenerName): Option[Node] = { + val image = _currentImage + Option(image.topics().getTopic(topicName)) match { + case None => None + case Some(topic) => Option(topic.partitions().get(partitionId)) match { + case None => None + case Some(partition) => Option(image.cluster().broker(partition.leader)) match { + case None => Some(Node.noNode) + case Some(broker) => Some(broker.node(listenerName.value()).orElse(Node.noNode())) + } + } + } + } + + override def getPartitionReplicaEndpoints(tp: TopicPartition, listenerName: ListenerName): Map[Int, Node] = { + val image = _currentImage + val result = new mutable.HashMap[Int, Node]() + Option(image.topics().getTopic(tp.topic())).foreach { topic => + topic.partitions().values().forEach { case partition => + partition.replicas.map { case replicaId => + result.put(replicaId, Option(image.cluster().broker(replicaId)) match { + case None => Node.noNode() + case Some(broker) => broker.node(listenerName.value()).asScala.getOrElse(Node.noNode()) + }) + } + } + } + result.toMap + } + + override def getControllerId: Option[Int] = getRandomAliveBroker(_currentImage) + + /** + * Choose a random broker node to report as the controller. We do this because we want + * the client to send requests destined for the controller to a random broker. + * Clients do not have direct access to the controller in the KRaft world, as explained + * in KIP-590. + */ + private def getRandomAliveBroker(image: MetadataImage): Option[Int] = { + val aliveBrokers = getAliveBrokers(image).toList + if (aliveBrokers.size == 0) { + None + } else { + Some(aliveBrokers(ThreadLocalRandom.current().nextInt(aliveBrokers.size)).id) + } + } + + override def getClusterMetadata(clusterId: String, listenerName: ListenerName): Cluster = { + val image = _currentImage + val nodes = new util.HashMap[Integer, Node] + image.cluster().brokers().values().forEach { broker => + if (!broker.fenced()) { + broker.node(listenerName.value()).asScala.foreach { node => + nodes.put(broker.id(), node) + } + } + } + + def node(id: Int): Node = { + Option(nodes.get(id)).getOrElse(Node.noNode()) + } + + val partitionInfos = new util.ArrayList[PartitionInfo] + val internalTopics = new util.HashSet[String] + + image.topics().topicsByName().values().forEach { topic => + topic.partitions().entrySet().forEach { entry => + val partitionId = entry.getKey() + val partition = entry.getValue() + partitionInfos.add(new PartitionInfo(topic.name(), + partitionId, + node(partition.leader), + partition.replicas.map(replica => node(replica)), + partition.isr.map(replica => node(replica)), + getOfflineReplicas(image, partition, listenerName).asScala. 
+ map(replica => node(replica)).toArray)) + if (Topic.isInternal(topic.name())) { + internalTopics.add(topic.name()) + } + } + } + val controllerNode = node(getRandomAliveBroker(image).getOrElse(-1)) + // Note: the constructor of Cluster does not allow us to reference unregistered nodes. + // So, for example, if partition foo-0 has replicas [1, 2] but broker 2 is not + // registered, we pass its replicas as [1, -1]. This doesn't make a lot of sense, but + // we are duplicating the behavior of ZkMetadataCache, for now. + new Cluster(clusterId, nodes.values(), + partitionInfos, Collections.emptySet(), internalTopics, controllerNode) + } + + def stateChangeTraceEnabled(): Boolean = { + stateChangeLogger.isTraceEnabled + } + + def logStateChangeTrace(str: String): Unit = { + stateChangeLogger.trace(str) + } + + override def contains(topicName: String): Boolean = + _currentImage.topics().topicsByName().containsKey(topicName) + + override def contains(tp: TopicPartition): Boolean = { + Option(_currentImage.topics().getTopic(tp.topic())) match { + case None => false + case Some(topic) => topic.partitions().containsKey(tp.partition()) + } + } + + def setImage(newImage: MetadataImage): Unit = _currentImage = newImage + + override def config(configResource: ConfigResource): Properties = + _currentImage.configs().configProperties(configResource) + + def describeClientQuotas(request: DescribeClientQuotasRequestData): DescribeClientQuotasResponseData = { + _currentImage.clientQuotas().describe(request) + } +} diff --git a/core/src/main/scala/kafka/server/metadata/MetadataPublisher.scala b/core/src/main/scala/kafka/server/metadata/MetadataPublisher.scala new file mode 100644 index 0000000..104d164 --- /dev/null +++ b/core/src/main/scala/kafka/server/metadata/MetadataPublisher.scala @@ -0,0 +1,33 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package kafka.server.metadata + +import org.apache.kafka.image.{MetadataDelta, MetadataImage} + +/** + * An object which publishes a new metadata image. + */ +trait MetadataPublisher { + /** + * Publish a new metadata image. + * + * @param delta The delta between the old image and the new one. + * @param newImage The new image, which is the result of applying the + * delta to the previous image. + */ + def publish(delta: MetadataDelta, newImage: MetadataImage): Unit +} diff --git a/core/src/main/scala/kafka/server/metadata/MetadataSnapshotter.scala b/core/src/main/scala/kafka/server/metadata/MetadataSnapshotter.scala new file mode 100644 index 0000000..c9d7292 --- /dev/null +++ b/core/src/main/scala/kafka/server/metadata/MetadataSnapshotter.scala @@ -0,0 +1,35 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package kafka.server.metadata + +import org.apache.kafka.image.MetadataImage + + +/** + * Handles creating snapshots. + */ +trait MetadataSnapshotter { + /** + * If there is no other snapshot being written out, start writing out a snapshot. + * + * @param lastContainedLogTime The highest time contained in the snapshot. + * @param image The metadata image to write out. + * + * @return True if we will write out a new snapshot; false otherwise. + */ + def maybeStartSnapshot(lastContainedLogTime: Long, image: MetadataImage): Boolean +} diff --git a/core/src/main/scala/kafka/server/metadata/ZkConfigRepository.scala b/core/src/main/scala/kafka/server/metadata/ZkConfigRepository.scala new file mode 100644 index 0000000..95fe752 --- /dev/null +++ b/core/src/main/scala/kafka/server/metadata/ZkConfigRepository.scala @@ -0,0 +1,42 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package kafka.server.metadata + +import java.util.Properties + +import kafka.server.ConfigType +import kafka.zk.{AdminZkClient, KafkaZkClient} +import org.apache.kafka.common.config.ConfigResource +import org.apache.kafka.common.config.ConfigResource.Type + + +object ZkConfigRepository { + def apply(zkClient: KafkaZkClient): ZkConfigRepository = + new ZkConfigRepository(new AdminZkClient(zkClient)) +} + +class ZkConfigRepository(adminZkClient: AdminZkClient) extends ConfigRepository { + override def config(configResource: ConfigResource): Properties = { + val configTypeForZk = configResource.`type` match { + case Type.TOPIC => ConfigType.Topic + case Type.BROKER => ConfigType.Broker + case tpe => throw new IllegalArgumentException(s"Unsupported config type: $tpe") + } + adminZkClient.fetchEntityConfig(configTypeForZk, configResource.name) + } +} diff --git a/core/src/main/scala/kafka/server/metadata/ZkMetadataCache.scala b/core/src/main/scala/kafka/server/metadata/ZkMetadataCache.scala new file mode 100755 index 0000000..0356873 --- /dev/null +++ b/core/src/main/scala/kafka/server/metadata/ZkMetadataCache.scala @@ -0,0 +1,433 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.server.metadata + +import java.util +import java.util.Collections +import java.util.concurrent.locks.ReentrantReadWriteLock +import kafka.admin.BrokerMetadata + +import scala.collection.{Seq, Set, mutable} +import scala.jdk.CollectionConverters._ +import kafka.cluster.{Broker, EndPoint} +import kafka.api._ +import kafka.controller.StateChangeLogger +import kafka.server.MetadataCache +import kafka.utils.CoreUtils._ +import kafka.utils.Logging +import kafka.utils.Implicits._ +import org.apache.kafka.common.internals.Topic +import org.apache.kafka.common.message.UpdateMetadataRequestData.UpdateMetadataPartitionState +import org.apache.kafka.common.{Cluster, Node, PartitionInfo, TopicPartition, Uuid} +import org.apache.kafka.common.message.MetadataResponseData.MetadataResponseTopic +import org.apache.kafka.common.message.MetadataResponseData.MetadataResponsePartition +import org.apache.kafka.common.network.ListenerName +import org.apache.kafka.common.protocol.Errors +import org.apache.kafka.common.requests.{MetadataResponse, UpdateMetadataRequest} +import org.apache.kafka.common.security.auth.SecurityProtocol + +/** + * A cache for the state (e.g., current leader) of each partition. This cache is updated through + * UpdateMetadataRequest from the controller. Every broker maintains the same cache, asynchronously. + */ +class ZkMetadataCache(brokerId: Int) extends MetadataCache with Logging { + + private val partitionMetadataLock = new ReentrantReadWriteLock() + //this is the cache state. 
every MetadataSnapshot instance is immutable, and updates (performed under a lock) + //replace the value with a completely new one. this means reads (which are not under any lock) need to grab + //the value of this var (into a val) ONCE and retain that read copy for the duration of their operation. + //multiple reads of this value risk getting different snapshots. + @volatile private var metadataSnapshot: MetadataSnapshot = MetadataSnapshot(partitionStates = mutable.AnyRefMap.empty, + topicIds = Map.empty, controllerId = None, aliveBrokers = mutable.LongMap.empty, aliveNodes = mutable.LongMap.empty) + + this.logIdent = s"[MetadataCache brokerId=$brokerId] " + private val stateChangeLogger = new StateChangeLogger(brokerId, inControllerContext = false, None) + + // This method is the main hotspot when it comes to the performance of metadata requests, + // we should be careful about adding additional logic here. Relatedly, `brokers` is + // `List[Integer]` instead of `List[Int]` to avoid a collection copy. + // filterUnavailableEndpoints exists to support v0 MetadataResponses + private def maybeFilterAliveReplicas(snapshot: MetadataSnapshot, + brokers: java.util.List[Integer], + listenerName: ListenerName, + filterUnavailableEndpoints: Boolean): java.util.List[Integer] = { + if (!filterUnavailableEndpoints) { + brokers + } else { + val res = new util.ArrayList[Integer](math.min(snapshot.aliveBrokers.size, brokers.size)) + for (brokerId <- brokers.asScala) { + if (hasAliveEndpoint(snapshot, brokerId, listenerName)) + res.add(brokerId) + } + res + } + } + + // errorUnavailableEndpoints exists to support v0 MetadataResponses + // If errorUnavailableListeners=true, return LISTENER_NOT_FOUND if listener is missing on the broker. + // Otherwise, return LEADER_NOT_AVAILABLE for broker unavailable and missing listener (Metadata response v5 and below). 
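// Illustrative sketch (not part of this patch): the comments above describe how the error for an
// unavailable leader endpoint is chosen, and the same rule appears in both ZkMetadataCache and
// KRaftMetadataCache. `leaderRegistered` and `listenerPresent` are hypothetical stand-ins for the
// snapshot/image lookups; only the Errors constants come from the real Kafka client library.
object LeaderErrorSelectionSketch {
  import org.apache.kafka.common.protocol.Errors

  def leaderError(leaderRegistered: Boolean,
                  listenerPresent: Boolean,
                  errorUnavailableListeners: Boolean): Errors = {
    if (!leaderRegistered) {
      // The leader broker is not registered in the cache at all.
      Errors.LEADER_NOT_AVAILABLE
    } else if (!listenerPresent) {
      // The broker is alive but does not expose the requested listener. Metadata response
      // versions v5 and below cannot express LISTENER_NOT_FOUND, hence the flag.
      if (errorUnavailableListeners) Errors.LISTENER_NOT_FOUND else Errors.LEADER_NOT_AVAILABLE
    } else {
      Errors.NONE
    }
  }
}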
+ private def getPartitionMetadata(snapshot: MetadataSnapshot, topic: String, listenerName: ListenerName, errorUnavailableEndpoints: Boolean, + errorUnavailableListeners: Boolean): Option[Iterable[MetadataResponsePartition]] = { + snapshot.partitionStates.get(topic).map { partitions => + partitions.map { case (partitionId, partitionState) => + val topicPartition = new TopicPartition(topic, partitionId.toInt) + val leaderBrokerId = partitionState.leader + val leaderEpoch = partitionState.leaderEpoch + val maybeLeader = getAliveEndpoint(snapshot, leaderBrokerId, listenerName) + + val replicas = partitionState.replicas + val filteredReplicas = maybeFilterAliveReplicas(snapshot, replicas, listenerName, errorUnavailableEndpoints) + + val isr = partitionState.isr + val filteredIsr = maybeFilterAliveReplicas(snapshot, isr, listenerName, errorUnavailableEndpoints) + + val offlineReplicas = partitionState.offlineReplicas + + maybeLeader match { + case None => + val error = if (!snapshot.aliveBrokers.contains(leaderBrokerId)) { // we are already holding the read lock + debug(s"Error while fetching metadata for $topicPartition: leader not available") + Errors.LEADER_NOT_AVAILABLE + } else { + debug(s"Error while fetching metadata for $topicPartition: listener $listenerName " + + s"not found on leader $leaderBrokerId") + if (errorUnavailableListeners) Errors.LISTENER_NOT_FOUND else Errors.LEADER_NOT_AVAILABLE + } + + new MetadataResponsePartition() + .setErrorCode(error.code) + .setPartitionIndex(partitionId.toInt) + .setLeaderId(MetadataResponse.NO_LEADER_ID) + .setLeaderEpoch(leaderEpoch) + .setReplicaNodes(filteredReplicas) + .setIsrNodes(filteredIsr) + .setOfflineReplicas(offlineReplicas) + + case Some(_) => + val error = if (filteredReplicas.size < replicas.size) { + debug(s"Error while fetching metadata for $topicPartition: replica information not available for " + + s"following brokers ${replicas.asScala.filterNot(filteredReplicas.contains).mkString(",")}") + Errors.REPLICA_NOT_AVAILABLE + } else if (filteredIsr.size < isr.size) { + debug(s"Error while fetching metadata for $topicPartition: in sync replica information not available for " + + s"following brokers ${isr.asScala.filterNot(filteredIsr.contains).mkString(",")}") + Errors.REPLICA_NOT_AVAILABLE + } else { + Errors.NONE + } + + new MetadataResponsePartition() + .setErrorCode(error.code) + .setPartitionIndex(partitionId.toInt) + .setLeaderId(maybeLeader.map(_.id()).getOrElse(MetadataResponse.NO_LEADER_ID)) + .setLeaderEpoch(leaderEpoch) + .setReplicaNodes(filteredReplicas) + .setIsrNodes(filteredIsr) + .setOfflineReplicas(offlineReplicas) + } + } + } + } + + /** + * Check whether a broker is alive and has a registered listener matching the provided name. + * This method was added to avoid unnecessary allocations in [[maybeFilterAliveReplicas]], which is + * a hotspot in metadata handling. + */ + private def hasAliveEndpoint(snapshot: MetadataSnapshot, brokerId: Int, listenerName: ListenerName): Boolean = { + snapshot.aliveNodes.get(brokerId).exists(_.contains(listenerName)) + } + + /** + * Get the endpoint matching the provided listener if the broker is alive. Note that listeners can + * be added dynamically, so a broker with a missing listener could be a transient error. + * + * @return None if broker is not alive or if the broker does not have a listener named `listenerName`. 
+ */ + private def getAliveEndpoint(snapshot: MetadataSnapshot, brokerId: Int, listenerName: ListenerName): Option[Node] = { + snapshot.aliveNodes.get(brokerId).flatMap(_.get(listenerName)) + } + + // errorUnavailableEndpoints exists to support v0 MetadataResponses + def getTopicMetadata(topics: Set[String], + listenerName: ListenerName, + errorUnavailableEndpoints: Boolean = false, + errorUnavailableListeners: Boolean = false): Seq[MetadataResponseTopic] = { + val snapshot = metadataSnapshot + topics.toSeq.flatMap { topic => + getPartitionMetadata(snapshot, topic, listenerName, errorUnavailableEndpoints, errorUnavailableListeners).map { partitionMetadata => + new MetadataResponseTopic() + .setErrorCode(Errors.NONE.code) + .setName(topic) + .setTopicId(snapshot.topicIds.getOrElse(topic, Uuid.ZERO_UUID)) + .setIsInternal(Topic.isInternal(topic)) + .setPartitions(partitionMetadata.toBuffer.asJava) + } + } + } + + def topicNamesToIds(): util.Map[String, Uuid] = { + Collections.unmodifiableMap(metadataSnapshot.topicIds.asJava) + } + + def topicIdsToNames(): util.Map[Uuid, String] = { + Collections.unmodifiableMap(metadataSnapshot.topicNames.asJava) + } + + /** + * This method returns a map from topic names to IDs and a map from topic IDs to names + */ + def topicIdInfo(): (util.Map[String, Uuid], util.Map[Uuid, String]) = { + val snapshot = metadataSnapshot + (Collections.unmodifiableMap(snapshot.topicIds.asJava), Collections.unmodifiableMap(snapshot.topicNames.asJava)) + } + + override def getAllTopics(): Set[String] = { + getAllTopics(metadataSnapshot) + } + + override def getTopicPartitions(topicName: String): Set[TopicPartition] = { + metadataSnapshot.partitionStates.getOrElse(topicName, Map.empty).values. + map(p => new TopicPartition(topicName, p.partitionIndex())).toSet + } + + private def getAllTopics(snapshot: MetadataSnapshot): Set[String] = { + snapshot.partitionStates.keySet + } + + private def getAllPartitions(snapshot: MetadataSnapshot): Map[TopicPartition, UpdateMetadataPartitionState] = { + snapshot.partitionStates.flatMap { case (topic, partitionStates) => + partitionStates.map { case (partition, state) => (new TopicPartition(topic, partition.toInt), state) } + }.toMap + } + + def getNonExistingTopics(topics: Set[String]): Set[String] = { + topics.diff(metadataSnapshot.partitionStates.keySet) + } + + override def hasAliveBroker(brokerId: Int): Boolean = metadataSnapshot.aliveBrokers.contains(brokerId) + + override def getAliveBrokers(): Iterable[BrokerMetadata] = { + metadataSnapshot.aliveBrokers.values.map(b => new BrokerMetadata(b.id, b.rack)) + } + + override def getAliveBrokerNode(brokerId: Int, listenerName: ListenerName): Option[Node] = { + metadataSnapshot.aliveBrokers.get(brokerId).flatMap(_.getNode(listenerName)) + } + + override def getAliveBrokerNodes(listenerName: ListenerName): Iterable[Node] = { + metadataSnapshot.aliveBrokers.values.flatMap(_.getNode(listenerName)) + } + + def getTopicId(topicName: String): Uuid = { + metadataSnapshot.topicIds.getOrElse(topicName, Uuid.ZERO_UUID) + } + + def getTopicName(topicId: Uuid): Option[String] = { + metadataSnapshot.topicNames.get(topicId) + } + + private def addOrUpdatePartitionInfo(partitionStates: mutable.AnyRefMap[String, mutable.LongMap[UpdateMetadataPartitionState]], + topic: String, + partitionId: Int, + stateInfo: UpdateMetadataPartitionState): Unit = { + val infos = partitionStates.getOrElseUpdate(topic, mutable.LongMap.empty) + infos(partitionId) = stateInfo + } + + def getPartitionInfo(topic: String, 
partitionId: Int): Option[UpdateMetadataPartitionState] = { + metadataSnapshot.partitionStates.get(topic).flatMap(_.get(partitionId)) + } + + def numPartitions(topic: String): Option[Int] = { + metadataSnapshot.partitionStates.get(topic).map(_.size) + } + + // if the leader is not known, return None; + // if the leader is known and corresponding node is available, return Some(node) + // if the leader is known but corresponding node with the listener name is not available, return Some(NO_NODE) + def getPartitionLeaderEndpoint(topic: String, partitionId: Int, listenerName: ListenerName): Option[Node] = { + val snapshot = metadataSnapshot + snapshot.partitionStates.get(topic).flatMap(_.get(partitionId)) map { partitionInfo => + val leaderId = partitionInfo.leader + + snapshot.aliveNodes.get(leaderId) match { + case Some(nodeMap) => + nodeMap.getOrElse(listenerName, Node.noNode) + case None => + Node.noNode + } + } + } + + def getPartitionReplicaEndpoints(tp: TopicPartition, listenerName: ListenerName): Map[Int, Node] = { + val snapshot = metadataSnapshot + snapshot.partitionStates.get(tp.topic).flatMap(_.get(tp.partition)).map { partitionInfo => + val replicaIds = partitionInfo.replicas + replicaIds.asScala + .map(replicaId => replicaId.intValue() -> { + snapshot.aliveBrokers.get(replicaId.longValue()) match { + case Some(broker) => + broker.getNode(listenerName).getOrElse(Node.noNode()) + case None => + Node.noNode() + } + }).toMap + .filter(pair => pair match { + case (_, node) => !node.isEmpty + }) + }.getOrElse(Map.empty[Int, Node]) + } + + def getControllerId: Option[Int] = metadataSnapshot.controllerId + + def getClusterMetadata(clusterId: String, listenerName: ListenerName): Cluster = { + val snapshot = metadataSnapshot + val nodes = snapshot.aliveNodes.flatMap { case (id, nodesByListener) => + nodesByListener.get(listenerName).map { node => + id -> node + } + } + + def node(id: Integer): Node = { + nodes.getOrElse(id.toLong, new Node(id, "", -1)) + } + + val partitions = getAllPartitions(snapshot) + .filter { case (_, state) => state.leader != LeaderAndIsr.LeaderDuringDelete } + .map { case (tp, state) => + new PartitionInfo(tp.topic, tp.partition, node(state.leader), + state.replicas.asScala.map(node).toArray, + state.isr.asScala.map(node).toArray, + state.offlineReplicas.asScala.map(node).toArray) + } + val unauthorizedTopics = Collections.emptySet[String] + val internalTopics = getAllTopics(snapshot).filter(Topic.isInternal).asJava + new Cluster(clusterId, nodes.values.toBuffer.asJava, + partitions.toBuffer.asJava, + unauthorizedTopics, internalTopics, + snapshot.controllerId.map(id => node(id)).orNull) + } + + // This method returns the deleted TopicPartitions received from UpdateMetadataRequest + def updateMetadata(correlationId: Int, updateMetadataRequest: UpdateMetadataRequest): Seq[TopicPartition] = { + inWriteLock(partitionMetadataLock) { + + val aliveBrokers = new mutable.LongMap[Broker](metadataSnapshot.aliveBrokers.size) + val aliveNodes = new mutable.LongMap[collection.Map[ListenerName, Node]](metadataSnapshot.aliveNodes.size) + val controllerIdOpt = updateMetadataRequest.controllerId match { + case id if id < 0 => None + case id => Some(id) + } + + updateMetadataRequest.liveBrokers.forEach { broker => + // `aliveNodes` is a hot path for metadata requests for large clusters, so we use java.util.HashMap which + // is a bit faster than scala.collection.mutable.HashMap. When we drop support for Scala 2.10, we could + // move to `AnyRefMap`, which has comparable performance. 
+ val nodes = new java.util.HashMap[ListenerName, Node] + val endPoints = new mutable.ArrayBuffer[EndPoint] + broker.endpoints.forEach { ep => + val listenerName = new ListenerName(ep.listener) + endPoints += new EndPoint(ep.host, ep.port, listenerName, SecurityProtocol.forId(ep.securityProtocol)) + nodes.put(listenerName, new Node(broker.id, ep.host, ep.port)) + } + aliveBrokers(broker.id) = Broker(broker.id, endPoints, Option(broker.rack)) + aliveNodes(broker.id) = nodes.asScala + } + aliveNodes.get(brokerId).foreach { listenerMap => + val listeners = listenerMap.keySet + if (!aliveNodes.values.forall(_.keySet == listeners)) + error(s"Listeners are not identical across brokers: $aliveNodes") + } + + val topicIds = mutable.Map.empty[String, Uuid] + topicIds ++= metadataSnapshot.topicIds + val (newTopicIds, newZeroIds) = updateMetadataRequest.topicStates().asScala + .map(topicState => (topicState.topicName(), topicState.topicId())) + .partition { case (_, topicId) => topicId != Uuid.ZERO_UUID } + newZeroIds.foreach { case (zeroIdTopic, _) => topicIds.remove(zeroIdTopic) } + topicIds ++= newTopicIds.toMap + + val deletedPartitions = new mutable.ArrayBuffer[TopicPartition] + if (!updateMetadataRequest.partitionStates.iterator.hasNext) { + metadataSnapshot = MetadataSnapshot(metadataSnapshot.partitionStates, topicIds.toMap, controllerIdOpt, aliveBrokers, aliveNodes) + } else { + //since kafka may do partial metadata updates, we start by copying the previous state + val partitionStates = new mutable.AnyRefMap[String, mutable.LongMap[UpdateMetadataPartitionState]](metadataSnapshot.partitionStates.size) + metadataSnapshot.partitionStates.forKeyValue { (topic, oldPartitionStates) => + val copy = new mutable.LongMap[UpdateMetadataPartitionState](oldPartitionStates.size) + copy ++= oldPartitionStates + partitionStates(topic) = copy + } + + val traceEnabled = stateChangeLogger.isTraceEnabled + val controllerId = updateMetadataRequest.controllerId + val controllerEpoch = updateMetadataRequest.controllerEpoch + val newStates = updateMetadataRequest.partitionStates.asScala + newStates.foreach { state => + // per-partition logging here can be very expensive due to going through all partitions in the cluster + val tp = new TopicPartition(state.topicName, state.partitionIndex) + if (state.leader == LeaderAndIsr.LeaderDuringDelete) { + removePartitionInfo(partitionStates, topicIds, tp.topic, tp.partition) + if (traceEnabled) + stateChangeLogger.trace(s"Deleted partition $tp from metadata cache in response to UpdateMetadata " + + s"request sent by controller $controllerId epoch $controllerEpoch with correlation id $correlationId") + deletedPartitions += tp + } else { + addOrUpdatePartitionInfo(partitionStates, tp.topic, tp.partition, state) + if (traceEnabled) + stateChangeLogger.trace(s"Cached leader info $state for partition $tp in response to " + + s"UpdateMetadata request sent by controller $controllerId epoch $controllerEpoch with correlation id $correlationId") + } + } + val cachedPartitionsCount = newStates.size - deletedPartitions.size + stateChangeLogger.info(s"Added $cachedPartitionsCount partitions and deleted ${deletedPartitions.size} partitions from metadata cache " + + s"in response to UpdateMetadata request sent by controller $controllerId epoch $controllerEpoch with correlation id $correlationId") + + metadataSnapshot = MetadataSnapshot(partitionStates, topicIds.toMap, controllerIdOpt, aliveBrokers, aliveNodes) + } + deletedPartitions + } + } + + def contains(topic: String): Boolean = { + 
metadataSnapshot.partitionStates.contains(topic) + } + + def contains(tp: TopicPartition): Boolean = getPartitionInfo(tp.topic, tp.partition).isDefined + + private def removePartitionInfo(partitionStates: mutable.AnyRefMap[String, mutable.LongMap[UpdateMetadataPartitionState]], + topicIds: mutable.Map[String, Uuid], topic: String, partitionId: Int): Boolean = { + partitionStates.get(topic).exists { infos => + infos.remove(partitionId) + if (infos.isEmpty) { + partitionStates.remove(topic) + topicIds.remove(topic) + } + true + } + } + + case class MetadataSnapshot(partitionStates: mutable.AnyRefMap[String, mutable.LongMap[UpdateMetadataPartitionState]], + topicIds: Map[String, Uuid], + controllerId: Option[Int], + aliveBrokers: mutable.LongMap[Broker], + aliveNodes: mutable.LongMap[collection.Map[ListenerName, Node]]) { + val topicNames: Map[Uuid, String] = topicIds.map { case (topicName, topicId) => (topicId, topicName) } + } +} diff --git a/core/src/main/scala/kafka/server/package.html b/core/src/main/scala/kafka/server/package.html new file mode 100644 index 0000000..56ab9d4 --- /dev/null +++ b/core/src/main/scala/kafka/server/package.html @@ -0,0 +1,19 @@ + +The kafka server. \ No newline at end of file diff --git a/core/src/main/scala/kafka/tools/ClusterTool.scala b/core/src/main/scala/kafka/tools/ClusterTool.scala new file mode 100644 index 0000000..b868f72 --- /dev/null +++ b/core/src/main/scala/kafka/tools/ClusterTool.scala @@ -0,0 +1,124 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.tools + +import java.io.PrintStream +import java.util.Properties +import java.util.concurrent.ExecutionException + +import kafka.utils.{Exit, Logging} +import net.sourceforge.argparse4j.ArgumentParsers +import net.sourceforge.argparse4j.impl.Arguments.store +import org.apache.kafka.clients.admin.Admin +import org.apache.kafka.common.errors.UnsupportedVersionException +import org.apache.kafka.common.utils.Utils + +object ClusterTool extends Logging { + def main(args: Array[String]): Unit = { + try { + val parser = ArgumentParsers. + newArgumentParser("kafka-cluster"). + defaultHelp(true). + description("The Kafka cluster tool.") + val subparsers = parser.addSubparsers().dest("command") + + val clusterIdParser = subparsers.addParser("cluster-id"). + help("Get information about the ID of a cluster.") + val unregisterParser = subparsers.addParser("unregister"). + help("Unregister a broker.") + List(clusterIdParser, unregisterParser).foreach(parser => { + parser.addArgument("--bootstrap-server", "-b"). + action(store()). + help("A list of host/port pairs to use for establishing the connection to the kafka cluster.") + parser.addArgument("--config", "-c"). + action(store()). 
+ help("A property file containing configs to passed to AdminClient.") + }) + unregisterParser.addArgument("--id", "-i"). + `type`(classOf[Integer]). + action(store()). + help("The ID of the broker to unregister.") + + val namespace = parser.parseArgsOrFail(args) + val command = namespace.getString("command") + val configPath = namespace.getString("config") + val properties = if (configPath == null) { + new Properties() + } else { + Utils.loadProps(configPath) + } + Option(namespace.getString("bootstrap_server")). + foreach(b => properties.setProperty("bootstrap.servers", b)) + if (properties.getProperty("bootstrap.servers") == null) { + throw new TerseFailure("Please specify --bootstrap-server.") + } + + command match { + case "cluster-id" => + val adminClient = Admin.create(properties) + try { + clusterIdCommand(System.out, adminClient) + } finally { + adminClient.close() + } + Exit.exit(0) + case "unregister" => + val adminClient = Admin.create(properties) + try { + unregisterCommand(System.out, adminClient, namespace.getInt("id")) + } finally { + adminClient.close() + } + Exit.exit(0) + case _ => + throw new RuntimeException(s"Unknown command $command") + } + } catch { + case e: TerseFailure => + System.err.println(e.getMessage) + System.exit(1) + } + } + + def clusterIdCommand(stream: PrintStream, + adminClient: Admin): Unit = { + val clusterId = Option(adminClient.describeCluster().clusterId().get()) + clusterId match { + case None => stream.println(s"No cluster ID found. The Kafka version is probably too old.") + case Some(id) => stream.println(s"Cluster ID: ${id}") + } + } + + def unregisterCommand(stream: PrintStream, + adminClient: Admin, + id: Int): Unit = { + try { + Option(adminClient.unregisterBroker(id).all().get()) + stream.println(s"Broker ${id} is no longer registered.") + } catch { + case e: ExecutionException => { + val cause = e.getCause() + if (cause.isInstanceOf[UnsupportedVersionException]) { + stream.println(s"The target cluster does not support the broker unregistration API.") + } else { + throw e + } + } + } + } +} diff --git a/core/src/main/scala/kafka/tools/ConsoleConsumer.scala b/core/src/main/scala/kafka/tools/ConsoleConsumer.scala new file mode 100755 index 0000000..4390999 --- /dev/null +++ b/core/src/main/scala/kafka/tools/ConsoleConsumer.scala @@ -0,0 +1,629 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package kafka.tools + +import java.io.PrintStream +import java.nio.charset.StandardCharsets +import java.time.Duration +import java.util.concurrent.CountDownLatch +import java.util.regex.Pattern +import java.util.{Collections, Locale, Map, Optional, Properties, Random} +import com.typesafe.scalalogging.LazyLogging +import joptsimple._ +import kafka.utils.Implicits._ +import kafka.utils.{Exit, _} +import org.apache.kafka.clients.consumer.{Consumer, ConsumerConfig, ConsumerRecord, KafkaConsumer} +import org.apache.kafka.common.{MessageFormatter, TopicPartition} +import org.apache.kafka.common.errors.{AuthenticationException, TimeoutException, WakeupException} +import org.apache.kafka.common.record.TimestampType +import org.apache.kafka.common.requests.ListOffsetsRequest +import org.apache.kafka.common.serialization.{ByteArrayDeserializer, Deserializer} +import org.apache.kafka.common.utils.Utils + +import scala.jdk.CollectionConverters._ + +/** + * Consumer that dumps messages to standard out. + */ +object ConsoleConsumer extends Logging { + + var messageCount = 0 + + private val shutdownLatch = new CountDownLatch(1) + + def main(args: Array[String]): Unit = { + val conf = new ConsumerConfig(args) + try { + run(conf) + } catch { + case e: AuthenticationException => + error("Authentication failed: terminating consumer process", e) + Exit.exit(1) + case e: Throwable => + error("Unknown error when running consumer: ", e) + Exit.exit(1) + } + } + + def run(conf: ConsumerConfig): Unit = { + val timeoutMs = if (conf.timeoutMs >= 0) conf.timeoutMs else Long.MaxValue + val consumer = new KafkaConsumer(consumerProps(conf), new ByteArrayDeserializer, new ByteArrayDeserializer) + + val consumerWrapper = + if (conf.partitionArg.isDefined) + new ConsumerWrapper(Option(conf.topicArg), conf.partitionArg, Option(conf.offsetArg), None, consumer, timeoutMs) + else + new ConsumerWrapper(Option(conf.topicArg), None, None, Option(conf.includedTopicsArg), consumer, timeoutMs) + + addShutdownHook(consumerWrapper, conf) + + try process(conf.maxMessages, conf.formatter, consumerWrapper, System.out, conf.skipMessageOnError) + finally { + consumerWrapper.cleanup() + conf.formatter.close() + reportRecordCount() + + shutdownLatch.countDown() + } + } + + def addShutdownHook(consumer: ConsumerWrapper, conf: ConsumerConfig): Unit = { + Exit.addShutdownHook("consumer-shutdown-hook", { + consumer.wakeup() + + shutdownLatch.await() + + if (conf.enableSystestEventsLogging) { + System.out.println("shutdown_complete") + } + }) + } + + def process(maxMessages: Integer, formatter: MessageFormatter, consumer: ConsumerWrapper, output: PrintStream, + skipMessageOnError: Boolean): Unit = { + while (messageCount < maxMessages || maxMessages == -1) { + val msg: ConsumerRecord[Array[Byte], Array[Byte]] = try { + consumer.receive() + } catch { + case _: WakeupException => + trace("Caught WakeupException because consumer is shutdown, ignore and terminate.") + // Consumer will be closed + return + case e: Throwable => + error("Error processing message, terminating consumer process: ", e) + // Consumer will be closed + return + } + messageCount += 1 + try { + formatter.writeTo(new ConsumerRecord(msg.topic, msg.partition, msg.offset, msg.timestamp, msg.timestampType, + 0, 0, msg.key, msg.value, msg.headers, Optional.empty[Integer]), output) + } catch { + case e: Throwable => + if (skipMessageOnError) { + error("Error processing message, skipping this message: ", e) + } else { + // Consumer will be closed + throw e + } + } + if 
(checkErr(output, formatter)) { + // Consumer will be closed + return + } + } + } + + def reportRecordCount(): Unit = { + System.err.println(s"Processed a total of $messageCount messages") + } + + def checkErr(output: PrintStream, formatter: MessageFormatter): Boolean = { + val gotError = output.checkError() + if (gotError) { + // This means no one is listening to our output stream anymore, time to shut down + System.err.println("Unable to write to standard out, closing consumer.") + } + gotError + } + + private[tools] def consumerProps(config: ConsumerConfig): Properties = { + val props = new Properties + props ++= config.consumerProps + props ++= config.extraConsumerProps + setAutoOffsetResetValue(config, props) + props.put(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, config.bootstrapServer) + if (props.getProperty(ConsumerConfig.CLIENT_ID_CONFIG) == null) + props.put(ConsumerConfig.CLIENT_ID_CONFIG, "console-consumer") + CommandLineUtils.maybeMergeOptions( + props, ConsumerConfig.ISOLATION_LEVEL_CONFIG, config.options, config.isolationLevelOpt) + props + } + + /** + * Used by consumerProps to retrieve the correct value for the consumer parameter 'auto.offset.reset'. + * + * Order of priority is: + * 1. Explicitly set parameter via --consumer.property command line parameter + * 2. Explicit --from-beginning given -> 'earliest' + * 3. Default value of 'latest' + * + * In case both --from-beginning and an explicit value are specified an error is thrown if these + * are conflicting. + */ + def setAutoOffsetResetValue(config: ConsumerConfig, props: Properties): Unit = { + val (earliestConfigValue, latestConfigValue) = ("earliest", "latest") + + if (props.containsKey(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG)) { + // auto.offset.reset parameter was specified on the command line + val autoResetOption = props.getProperty(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG) + if (config.options.has(config.resetBeginningOpt) && earliestConfigValue != autoResetOption) { + // conflicting options - latest and earliest, throw an error + System.err.println(s"Can't simultaneously specify --from-beginning and 'auto.offset.reset=$autoResetOption', " + + "please remove one option") + Exit.exit(1) + } + // nothing to do, checking for valid parameter values happens later and the specified + // value was already copied during .putAll operation + } else { + // no explicit value for auto.offset.reset was specified + // if --from-beginning was specified use earliest, otherwise default to latest + val autoResetOption = if (config.options.has(config.resetBeginningOpt)) earliestConfigValue else latestConfigValue + props.put(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, autoResetOption) + } + } + + class ConsumerConfig(args: Array[String]) extends CommandDefaultOptions(args) { + val topicOpt = parser.accepts("topic", "The topic to consume on.") + .withRequiredArg + .describedAs("topic") + .ofType(classOf[String]) + val whitelistOpt = parser.accepts("whitelist", + "DEPRECATED, use --include instead; ignored if --include specified. Regular expression specifying list of topics to include for consumption.") + .withRequiredArg + .describedAs("Java regex (String)") + .ofType(classOf[String]) + val includeOpt = parser.accepts("include", + "Regular expression specifying list of topics to include for consumption.") + .withRequiredArg + .describedAs("Java regex (String)") + .ofType(classOf[String]) + val partitionIdOpt = parser.accepts("partition", "The partition to consume from. 
Consumption " + + "starts from the end of the partition unless '--offset' is specified.") + .withRequiredArg + .describedAs("partition") + .ofType(classOf[java.lang.Integer]) + val offsetOpt = parser.accepts("offset", "The offset to consume from (a non-negative number), or 'earliest' which means from beginning, or 'latest' which means from end") + .withRequiredArg + .describedAs("consume offset") + .ofType(classOf[String]) + .defaultsTo("latest") + val consumerPropertyOpt = parser.accepts("consumer-property", "A mechanism to pass user-defined properties in the form key=value to the consumer.") + .withRequiredArg + .describedAs("consumer_prop") + .ofType(classOf[String]) + val consumerConfigOpt = parser.accepts("consumer.config", s"Consumer config properties file. Note that $consumerPropertyOpt takes precedence over this config.") + .withRequiredArg + .describedAs("config file") + .ofType(classOf[String]) + val messageFormatterOpt = parser.accepts("formatter", "The name of a class to use for formatting kafka messages for display.") + .withRequiredArg + .describedAs("class") + .ofType(classOf[String]) + .defaultsTo(classOf[DefaultMessageFormatter].getName) + val messageFormatterArgOpt = parser.accepts("property", + """The properties to initialize the message formatter. Default properties include: + | print.timestamp=true|false + | print.key=true|false + | print.offset=true|false + | print.partition=true|false + | print.headers=true|false + | print.value=true|false + | key.separator= + | line.separator= + | headers.separator= + | null.literal= + | key.deserializer= + | value.deserializer= + | header.deserializer= + | + |Users can also pass in customized properties for their formatter; more specifically, users can pass in properties keyed with 'key.deserializer.', 'value.deserializer.' and 'headers.deserializer.' prefixes to configure their deserializers.""" + .stripMargin) + .withRequiredArg + .describedAs("prop") + .ofType(classOf[String]) + val resetBeginningOpt = parser.accepts("from-beginning", "If the consumer does not already have an established offset to consume from, " + + "start with the earliest message present in the log rather than the latest message.") + val maxMessagesOpt = parser.accepts("max-messages", "The maximum number of messages to consume before exiting. If not set, consumption is continual.") + .withRequiredArg + .describedAs("num_messages") + .ofType(classOf[java.lang.Integer]) + val timeoutMsOpt = parser.accepts("timeout-ms", "If specified, exit if no message is available for consumption for the specified interval.") + .withRequiredArg + .describedAs("timeout_ms") + .ofType(classOf[java.lang.Integer]) + val skipMessageOnErrorOpt = parser.accepts("skip-message-on-error", "If there is an error when processing a message, " + + "skip it instead of halt.") + val bootstrapServerOpt = parser.accepts("bootstrap-server", "REQUIRED: The server(s) to connect to.") + .withRequiredArg + .describedAs("server to connect to") + .ofType(classOf[String]) + val keyDeserializerOpt = parser.accepts("key-deserializer") + .withRequiredArg + .describedAs("deserializer for key") + .ofType(classOf[String]) + val valueDeserializerOpt = parser.accepts("value-deserializer") + .withRequiredArg + .describedAs("deserializer for values") + .ofType(classOf[String]) + val enableSystestEventsLoggingOpt = parser.accepts("enable-systest-events", + "Log lifecycle events of the consumer in addition to logging consumed " + + "messages. 
(This is specific for system tests.)") + val isolationLevelOpt = parser.accepts("isolation-level", + "Set to read_committed in order to filter out transactional messages which are not committed. Set to read_uncommitted " + + "to read all messages.") + .withRequiredArg() + .ofType(classOf[String]) + .defaultsTo("read_uncommitted") + + val groupIdOpt = parser.accepts("group", "The consumer group id of the consumer.") + .withRequiredArg + .describedAs("consumer group id") + .ofType(classOf[String]) + + options = tryParse(parser, args) + + CommandLineUtils.printHelpAndExitIfNeeded(this, "This tool helps to read data from Kafka topics and outputs it to standard output.") + + var groupIdPassed = true + val enableSystestEventsLogging = options.has(enableSystestEventsLoggingOpt) + + // topic must be specified. + var topicArg: String = null + var includedTopicsArg: String = null + var filterSpec: TopicFilter = null + val extraConsumerProps = CommandLineUtils.parseKeyValueArgs(options.valuesOf(consumerPropertyOpt).asScala) + val consumerProps = if (options.has(consumerConfigOpt)) + Utils.loadProps(options.valueOf(consumerConfigOpt)) + else + new Properties() + val fromBeginning = options.has(resetBeginningOpt) + val partitionArg = if (options.has(partitionIdOpt)) Some(options.valueOf(partitionIdOpt).intValue) else None + val skipMessageOnError = options.has(skipMessageOnErrorOpt) + val messageFormatterClass = Class.forName(options.valueOf(messageFormatterOpt)) + val formatterArgs = CommandLineUtils.parseKeyValueArgs(options.valuesOf(messageFormatterArgOpt).asScala) + val maxMessages = if (options.has(maxMessagesOpt)) options.valueOf(maxMessagesOpt).intValue else -1 + val timeoutMs = if (options.has(timeoutMsOpt)) options.valueOf(timeoutMsOpt).intValue else -1 + val bootstrapServer = options.valueOf(bootstrapServerOpt) + val keyDeserializer = options.valueOf(keyDeserializerOpt) + val valueDeserializer = options.valueOf(valueDeserializerOpt) + val formatter: MessageFormatter = messageFormatterClass.getDeclaredConstructor().newInstance().asInstanceOf[MessageFormatter] + + if (keyDeserializer != null && keyDeserializer.nonEmpty) { + formatterArgs.setProperty(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, keyDeserializer) + } + if (valueDeserializer != null && valueDeserializer.nonEmpty) { + formatterArgs.setProperty(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, valueDeserializer) + } + + formatter.configure(formatterArgs.asScala.asJava) + + topicArg = options.valueOf(topicOpt) + includedTopicsArg = if (options.has(includeOpt)) + options.valueOf(includeOpt) + else + options.valueOf(whitelistOpt) + + val topicOrFilterArgs = List(topicArg, includedTopicsArg).filterNot(_ == null) + // the user needs to specify a value for either --topic or one of the include filter options (--include or --whitelist) + if (topicOrFilterArgs.size != 1) + CommandLineUtils.printUsageAndDie(parser, s"Exactly one of --include/--topic is required. 
" + + s"${if (options.has(whitelistOpt)) "--whitelist is DEPRECATED, use --include instead; ignored if --include specified."}") + + if (partitionArg.isDefined) { + if (!options.has(topicOpt)) + CommandLineUtils.printUsageAndDie(parser, "The topic is required when partition is specified.") + if (fromBeginning && options.has(offsetOpt)) + CommandLineUtils.printUsageAndDie(parser, "Options from-beginning and offset cannot be specified together.") + } else if (options.has(offsetOpt)) + CommandLineUtils.printUsageAndDie(parser, "The partition is required when offset is specified.") + + def invalidOffset(offset: String): Nothing = + CommandLineUtils.printUsageAndDie(parser, s"The provided offset value '$offset' is incorrect. Valid values are " + + "'earliest', 'latest', or a non-negative long.") + + val offsetArg = + if (options.has(offsetOpt)) { + options.valueOf(offsetOpt).toLowerCase(Locale.ROOT) match { + case "earliest" => ListOffsetsRequest.EARLIEST_TIMESTAMP + case "latest" => ListOffsetsRequest.LATEST_TIMESTAMP + case offsetString => + try { + val offset = offsetString.toLong + if (offset < 0) + invalidOffset(offsetString) + offset + } catch { + case _: NumberFormatException => invalidOffset(offsetString) + } + } + } + else if (fromBeginning) ListOffsetsRequest.EARLIEST_TIMESTAMP + else ListOffsetsRequest.LATEST_TIMESTAMP + + CommandLineUtils.checkRequiredArgs(parser, options, bootstrapServerOpt) + + // if the group id is provided in more than one place (through different means) all values must be the same + val groupIdsProvided = Set( + Option(options.valueOf(groupIdOpt)), // via --group + Option(consumerProps.get(ConsumerConfig.GROUP_ID_CONFIG)), // via --consumer-property + Option(extraConsumerProps.get(ConsumerConfig.GROUP_ID_CONFIG)) // via --consumer.config + ).flatten + + if (groupIdsProvided.size > 1) { + CommandLineUtils.printUsageAndDie(parser, "The group ids provided in different places (directly using '--group', " + + "via '--consumer-property', or via '--consumer.config') do not match. 
" + + s"Detected group ids: ${groupIdsProvided.mkString("'", "', '", "'")}") + } + + groupIdsProvided.headOption match { + case Some(group) => + consumerProps.put(ConsumerConfig.GROUP_ID_CONFIG, group) + case None => + consumerProps.put(ConsumerConfig.GROUP_ID_CONFIG, s"console-consumer-${new Random().nextInt(100000)}") + // By default, avoid unnecessary expansion of the coordinator cache since + // the auto-generated group and its offsets is not intended to be used again + if (!consumerProps.containsKey(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG)) + consumerProps.put(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG, "false") + groupIdPassed = false + } + + if (groupIdPassed && partitionArg.isDefined) + CommandLineUtils.printUsageAndDie(parser, "Options group and partition cannot be specified together.") + + def tryParse(parser: OptionParser, args: Array[String]): OptionSet = { + try + parser.parse(args: _*) + catch { + case e: OptionException => + CommandLineUtils.printUsageAndDie(parser, e.getMessage) + } + } + } + + private[tools] class ConsumerWrapper(topic: Option[String], partitionId: Option[Int], offset: Option[Long], includedTopics: Option[String], + consumer: Consumer[Array[Byte], Array[Byte]], val timeoutMs: Long = Long.MaxValue) { + consumerInit() + var recordIter = Collections.emptyList[ConsumerRecord[Array[Byte], Array[Byte]]]().iterator() + + def consumerInit(): Unit = { + (topic, partitionId, offset, includedTopics) match { + case (Some(topic), Some(partitionId), Some(offset), None) => + seek(topic, partitionId, offset) + case (Some(topic), Some(partitionId), None, None) => + // default to latest if no offset is provided + seek(topic, partitionId, ListOffsetsRequest.LATEST_TIMESTAMP) + case (Some(topic), None, None, None) => + consumer.subscribe(Collections.singletonList(topic)) + case (None, None, None, Some(include)) => + consumer.subscribe(Pattern.compile(include)) + case _ => + throw new IllegalArgumentException("An invalid combination of arguments is provided. " + + "Exactly one of 'topic' or 'include' must be provided. " + + "If 'topic' is provided, an optional 'partition' may also be provided. 
" + + "If 'partition' is provided, an optional 'offset' may also be provided, otherwise, consumption starts from the end of the partition.") + } + } + + def seek(topic: String, partitionId: Int, offset: Long): Unit = { + val topicPartition = new TopicPartition(topic, partitionId) + consumer.assign(Collections.singletonList(topicPartition)) + offset match { + case ListOffsetsRequest.EARLIEST_TIMESTAMP => consumer.seekToBeginning(Collections.singletonList(topicPartition)) + case ListOffsetsRequest.LATEST_TIMESTAMP => consumer.seekToEnd(Collections.singletonList(topicPartition)) + case _ => consumer.seek(topicPartition, offset) + } + } + + def resetUnconsumedOffsets(): Unit = { + val smallestUnconsumedOffsets = collection.mutable.Map[TopicPartition, Long]() + while (recordIter.hasNext) { + val record = recordIter.next() + val tp = new TopicPartition(record.topic, record.partition) + // avoid auto-committing offsets which haven't been consumed + smallestUnconsumedOffsets.getOrElseUpdate(tp, record.offset) + } + smallestUnconsumedOffsets.forKeyValue { (tp, offset) => consumer.seek(tp, offset) } + } + + def receive(): ConsumerRecord[Array[Byte], Array[Byte]] = { + if (!recordIter.hasNext) { + recordIter = consumer.poll(Duration.ofMillis(timeoutMs)).iterator + if (!recordIter.hasNext) + throw new TimeoutException() + } + + recordIter.next + } + + def wakeup(): Unit = { + this.consumer.wakeup() + } + + def cleanup(): Unit = { + resetUnconsumedOffsets() + this.consumer.close() + } + + } +} + +class DefaultMessageFormatter extends MessageFormatter { + var printTimestamp = false + var printKey = false + var printValue = true + var printPartition = false + var printOffset = false + var printHeaders = false + var keySeparator = utfBytes("\t") + var lineSeparator = utfBytes("\n") + var headersSeparator = utfBytes(",") + var nullLiteral = utfBytes("null") + + var keyDeserializer: Option[Deserializer[_]] = None + var valueDeserializer: Option[Deserializer[_]] = None + var headersDeserializer: Option[Deserializer[_]] = None + + override def configure(configs: Map[String, _]): Unit = { + getPropertyIfExists(configs, "print.timestamp", getBoolProperty).foreach(printTimestamp = _) + getPropertyIfExists(configs, "print.key", getBoolProperty).foreach(printKey = _) + getPropertyIfExists(configs, "print.offset", getBoolProperty).foreach(printOffset = _) + getPropertyIfExists(configs, "print.partition", getBoolProperty).foreach(printPartition = _) + getPropertyIfExists(configs, "print.headers", getBoolProperty).foreach(printHeaders = _) + getPropertyIfExists(configs, "print.value", getBoolProperty).foreach(printValue = _) + getPropertyIfExists(configs, "key.separator", getByteProperty).foreach(keySeparator = _) + getPropertyIfExists(configs, "line.separator", getByteProperty).foreach(lineSeparator = _) + getPropertyIfExists(configs, "headers.separator", getByteProperty).foreach(headersSeparator = _) + getPropertyIfExists(configs, "null.literal", getByteProperty).foreach(nullLiteral = _) + + keyDeserializer = getPropertyIfExists(configs, "key.deserializer", getDeserializerProperty(true)) + valueDeserializer = getPropertyIfExists(configs, "value.deserializer", getDeserializerProperty(false)) + headersDeserializer = getPropertyIfExists(configs, "headers.deserializer", getDeserializerProperty(false)) + } + + def writeTo(consumerRecord: ConsumerRecord[Array[Byte], Array[Byte]], output: PrintStream): Unit = { + + def writeSeparator(columnSeparator: Boolean): Unit = { + if (columnSeparator) + output.write(keySeparator) + 
else + output.write(lineSeparator) + } + + def deserialize(deserializer: Option[Deserializer[_]], sourceBytes: Array[Byte], topic: String) = { + val nonNullBytes = Option(sourceBytes).getOrElse(nullLiteral) + val convertedBytes = deserializer + .map(d => utfBytes(d.deserialize(topic, consumerRecord.headers, nonNullBytes).toString)) + .getOrElse(nonNullBytes) + convertedBytes + } + + import consumerRecord._ + + if (printTimestamp) { + if (timestampType != TimestampType.NO_TIMESTAMP_TYPE) + output.write(utfBytes(s"$timestampType:$timestamp")) + else + output.write(utfBytes("NO_TIMESTAMP")) + writeSeparator(columnSeparator = printOffset || printPartition || printHeaders || printKey || printValue) + } + + if (printPartition) { + output.write(utfBytes("Partition:")) + output.write(utfBytes(partition().toString)) + writeSeparator(columnSeparator = printOffset || printHeaders || printKey || printValue) + } + + if (printOffset) { + output.write(utfBytes("Offset:")) + output.write(utfBytes(offset().toString)) + writeSeparator(columnSeparator = printHeaders || printKey || printValue) + } + + if (printHeaders) { + val headersIt = headers().iterator.asScala + if (headersIt.hasNext) { + headersIt.foreach { header => + output.write(utfBytes(header.key() + ":")) + output.write(deserialize(headersDeserializer, header.value(), topic)) + if (headersIt.hasNext) { + output.write(headersSeparator) + } + } + } else { + output.write(utfBytes("NO_HEADERS")) + } + writeSeparator(columnSeparator = printKey || printValue) + } + + if (printKey) { + output.write(deserialize(keyDeserializer, key, topic)) + writeSeparator(columnSeparator = printValue) + } + + if (printValue) { + output.write(deserialize(valueDeserializer, value, topic)) + output.write(lineSeparator) + } + } + + private def propertiesWithKeyPrefixStripped(prefix: String, configs: Map[String, _]): Map[String, _] = { + val newConfigs = collection.mutable.Map[String, Any]() + configs.asScala.foreach { case (key, value) => + if (key.startsWith(prefix) && key.length > prefix.length) + newConfigs.put(key.substring(prefix.length), value) + } + newConfigs.asJava + } + + private def utfBytes(str: String) = str.getBytes(StandardCharsets.UTF_8) + + private def getByteProperty(configs: Map[String, _], key: String): Array[Byte] = { + utfBytes(configs.get(key).asInstanceOf[String]) + } + + private def getBoolProperty(configs: Map[String, _], key: String): Boolean = { + configs.get(key).asInstanceOf[String].trim.equalsIgnoreCase("true") + } + + private def getDeserializerProperty(isKey: Boolean)(configs: Map[String, _], propertyName: String): Deserializer[_] = { + val deserializer = Class.forName(configs.get(propertyName).asInstanceOf[String]).getDeclaredConstructor().newInstance().asInstanceOf[Deserializer[_]] + val deserializerConfig = propertiesWithKeyPrefixStripped(propertyName + ".", configs) + .asScala + .asJava + deserializer.configure(deserializerConfig, isKey) + deserializer + } + + private def getPropertyIfExists[T](configs: Map[String, _], key: String, getter: (Map[String, _], String) => T): Option[T] = { + if (configs.containsKey(key)) + Some(getter(configs, key)) + else + None + } +} + +class LoggingMessageFormatter extends MessageFormatter with LazyLogging { + private val defaultWriter: DefaultMessageFormatter = new DefaultMessageFormatter + + override def configure(configs: Map[String, _]): Unit = defaultWriter.configure(configs) + + def writeTo(consumerRecord: ConsumerRecord[Array[Byte], Array[Byte]], output: PrintStream): Unit = { + import 
consumerRecord._ + defaultWriter.writeTo(consumerRecord, output) + logger.info({if (timestampType != TimestampType.NO_TIMESTAMP_TYPE) s"$timestampType:$timestamp, " else ""} + + s"key:${if (key == null) "null" else new String(key, StandardCharsets.UTF_8)}, " + + s"value:${if (value == null) "null" else new String(value, StandardCharsets.UTF_8)}") + } +} + +class NoOpMessageFormatter extends MessageFormatter { + + def writeTo(consumerRecord: ConsumerRecord[Array[Byte], Array[Byte]], output: PrintStream): Unit = {} +} + diff --git a/core/src/main/scala/kafka/tools/ConsoleProducer.scala b/core/src/main/scala/kafka/tools/ConsoleProducer.scala new file mode 100644 index 0000000..7c221ba --- /dev/null +++ b/core/src/main/scala/kafka/tools/ConsoleProducer.scala @@ -0,0 +1,302 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.tools + +import java.io._ +import java.nio.charset.StandardCharsets +import java.util.Properties + +import joptsimple.{OptionException, OptionParser, OptionSet} +import kafka.common._ +import kafka.message._ +import kafka.utils.Implicits._ +import kafka.utils.{CommandDefaultOptions, CommandLineUtils, Exit, ToolsUtils} +import org.apache.kafka.clients.producer.internals.ErrorLoggingCallback +import org.apache.kafka.clients.producer.{KafkaProducer, ProducerConfig, ProducerRecord} +import org.apache.kafka.common.KafkaException +import org.apache.kafka.common.utils.Utils + +import scala.jdk.CollectionConverters._ + +object ConsoleProducer { + + def main(args: Array[String]): Unit = { + + try { + val config = new ProducerConfig(args) + val reader = Class.forName(config.readerClass).getDeclaredConstructor().newInstance().asInstanceOf[MessageReader] + reader.init(System.in, getReaderProps(config)) + + val producer = new KafkaProducer[Array[Byte], Array[Byte]](producerProps(config)) + + Exit.addShutdownHook("producer-shutdown-hook", producer.close) + + var record: ProducerRecord[Array[Byte], Array[Byte]] = null + do { + record = reader.readMessage() + if (record != null) + send(producer, record, config.sync) + } while (record != null) + } catch { + case e: joptsimple.OptionException => + System.err.println(e.getMessage) + Exit.exit(1) + case e: Exception => + e.printStackTrace + Exit.exit(1) + } + Exit.exit(0) + } + + private def send(producer: KafkaProducer[Array[Byte], Array[Byte]], + record: ProducerRecord[Array[Byte], Array[Byte]], sync: Boolean): Unit = { + if (sync) + producer.send(record).get() + else + producer.send(record, new ErrorLoggingCallback(record.topic, record.key, record.value, false)) + } + + def getReaderProps(config: ProducerConfig): Properties = { + val props = new Properties + props.put("topic", config.topic) + props ++= config.cmdLineProps + props + } + + def producerProps(config: 
ProducerConfig): Properties = { + val props = + if (config.options.has(config.producerConfigOpt)) + Utils.loadProps(config.options.valueOf(config.producerConfigOpt)) + else new Properties + + props ++= config.extraProducerProps + + if (config.bootstrapServer != null) + props.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, config.bootstrapServer) + else + props.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, config.brokerList) + + props.put(ProducerConfig.COMPRESSION_TYPE_CONFIG, config.compressionCodec) + if (props.getProperty(ProducerConfig.CLIENT_ID_CONFIG) == null) + props.put(ProducerConfig.CLIENT_ID_CONFIG, "console-producer") + props.put(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG, "org.apache.kafka.common.serialization.ByteArraySerializer") + props.put(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG, "org.apache.kafka.common.serialization.ByteArraySerializer") + + CommandLineUtils.maybeMergeOptions( + props, ProducerConfig.LINGER_MS_CONFIG, config.options, config.sendTimeoutOpt) + CommandLineUtils.maybeMergeOptions( + props, ProducerConfig.ACKS_CONFIG, config.options, config.requestRequiredAcksOpt) + CommandLineUtils.maybeMergeOptions( + props, ProducerConfig.REQUEST_TIMEOUT_MS_CONFIG, config.options, config.requestTimeoutMsOpt) + CommandLineUtils.maybeMergeOptions( + props, ProducerConfig.RETRIES_CONFIG, config.options, config.messageSendMaxRetriesOpt) + CommandLineUtils.maybeMergeOptions( + props, ProducerConfig.RETRY_BACKOFF_MS_CONFIG, config.options, config.retryBackoffMsOpt) + CommandLineUtils.maybeMergeOptions( + props, ProducerConfig.SEND_BUFFER_CONFIG, config.options, config.socketBufferSizeOpt) + CommandLineUtils.maybeMergeOptions( + props, ProducerConfig.BUFFER_MEMORY_CONFIG, config.options, config.maxMemoryBytesOpt) + CommandLineUtils.maybeMergeOptions( + props, ProducerConfig.BATCH_SIZE_CONFIG, config.options, config.maxPartitionMemoryBytesOpt) + CommandLineUtils.maybeMergeOptions( + props, ProducerConfig.METADATA_MAX_AGE_CONFIG, config.options, config.metadataExpiryMsOpt) + CommandLineUtils.maybeMergeOptions( + props, ProducerConfig.MAX_BLOCK_MS_CONFIG, config.options, config.maxBlockMsOpt) + + props + } + + class ProducerConfig(args: Array[String]) extends CommandDefaultOptions(args) { + val topicOpt = parser.accepts("topic", "REQUIRED: The topic id to produce messages to.") + .withRequiredArg + .describedAs("topic") + .ofType(classOf[String]) + val brokerListOpt = parser.accepts("broker-list", "DEPRECATED, use --bootstrap-server instead; ignored if --bootstrap-server is specified. The broker list string in the form HOST1:PORT1,HOST2:PORT2.") + .withRequiredArg + .describedAs("broker-list") + .ofType(classOf[String]) + val bootstrapServerOpt = parser.accepts("bootstrap-server", "REQUIRED unless --broker-list(deprecated) is specified. The server(s) to connect to. The broker list string in the form HOST1:PORT1,HOST2:PORT2.") + .requiredUnless("broker-list") + .withRequiredArg + .describedAs("server to connect to") + .ofType(classOf[String]) + val syncOpt = parser.accepts("sync", "If set message send requests to the brokers are synchronously, one at a time as they arrive.") + val compressionCodecOpt = parser.accepts("compression-codec", "The compression codec: either 'none', 'gzip', 'snappy', 'lz4', or 'zstd'." 
+ + "If specified without value, then it defaults to 'gzip'") + .withOptionalArg() + .describedAs("compression-codec") + .ofType(classOf[String]) + val batchSizeOpt = parser.accepts("batch-size", "Number of messages to send in a single batch if they are not being sent synchronously.") + .withRequiredArg + .describedAs("size") + .ofType(classOf[java.lang.Integer]) + .defaultsTo(200) + val messageSendMaxRetriesOpt = parser.accepts("message-send-max-retries", "Brokers can fail receiving the message for multiple reasons, and being unavailable transiently is just one of them. This property specifies the number of retries before the producer give up and drop this message.") + .withRequiredArg + .ofType(classOf[java.lang.Integer]) + .defaultsTo(3) + val retryBackoffMsOpt = parser.accepts("retry-backoff-ms", "Before each retry, the producer refreshes the metadata of relevant topics. Since leader election takes a bit of time, this property specifies the amount of time that the producer waits before refreshing the metadata.") + .withRequiredArg + .ofType(classOf[java.lang.Integer]) + .defaultsTo(100) + val sendTimeoutOpt = parser.accepts("timeout", "If set and the producer is running in asynchronous mode, this gives the maximum amount of time" + + " a message will queue awaiting sufficient batch size. The value is given in ms.") + .withRequiredArg + .describedAs("timeout_ms") + .ofType(classOf[java.lang.Integer]) + .defaultsTo(1000) + val requestRequiredAcksOpt = parser.accepts("request-required-acks", "The required acks of the producer requests") + .withRequiredArg + .describedAs("request required acks") + .ofType(classOf[java.lang.String]) + .defaultsTo("1") + val requestTimeoutMsOpt = parser.accepts("request-timeout-ms", "The ack timeout of the producer requests. Value must be non-negative and non-zero") + .withRequiredArg + .describedAs("request timeout ms") + .ofType(classOf[java.lang.Integer]) + .defaultsTo(1500) + val metadataExpiryMsOpt = parser.accepts("metadata-expiry-ms", + "The period of time in milliseconds after which we force a refresh of metadata even if we haven't seen any leadership changes.") + .withRequiredArg + .describedAs("metadata expiration interval") + .ofType(classOf[java.lang.Long]) + .defaultsTo(5*60*1000L) + val maxBlockMsOpt = parser.accepts("max-block-ms", + "The max time that the producer will block for during a send request") + .withRequiredArg + .describedAs("max block on send") + .ofType(classOf[java.lang.Long]) + .defaultsTo(60*1000L) + val maxMemoryBytesOpt = parser.accepts("max-memory-bytes", + "The total memory used by the producer to buffer records waiting to be sent to the server.") + .withRequiredArg + .describedAs("total memory in bytes") + .ofType(classOf[java.lang.Long]) + .defaultsTo(32 * 1024 * 1024L) + val maxPartitionMemoryBytesOpt = parser.accepts("max-partition-memory-bytes", + "The buffer size allocated for a partition. When records are received which are smaller than this size the producer " + + "will attempt to optimistically group them together until this size is reached.") + .withRequiredArg + .describedAs("memory in bytes per partition") + .ofType(classOf[java.lang.Long]) + .defaultsTo(16 * 1024L) + val messageReaderOpt = parser.accepts("line-reader", "The class name of the class to use for reading lines from standard in. 
" + + "By default each line is read as a separate message.") + .withRequiredArg + .describedAs("reader_class") + .ofType(classOf[java.lang.String]) + .defaultsTo(classOf[LineMessageReader].getName) + val socketBufferSizeOpt = parser.accepts("socket-buffer-size", "The size of the tcp RECV size.") + .withRequiredArg + .describedAs("size") + .ofType(classOf[java.lang.Integer]) + .defaultsTo(1024*100) + val propertyOpt = parser.accepts("property", "A mechanism to pass user-defined properties in the form key=value to the message reader. " + + "This allows custom configuration for a user-defined message reader. Default properties include:\n" + + "\tparse.key=true|false\n" + + "\tkey.separator=\n" + + "\tignore.error=true|false") + .withRequiredArg + .describedAs("prop") + .ofType(classOf[String]) + val producerPropertyOpt = parser.accepts("producer-property", "A mechanism to pass user-defined properties in the form key=value to the producer. ") + .withRequiredArg + .describedAs("producer_prop") + .ofType(classOf[String]) + val producerConfigOpt = parser.accepts("producer.config", s"Producer config properties file. Note that $producerPropertyOpt takes precedence over this config.") + .withRequiredArg + .describedAs("config file") + .ofType(classOf[String]) + + options = tryParse(parser, args) + + CommandLineUtils.printHelpAndExitIfNeeded(this, "This tool helps to read data from standard input and publish it to Kafka.") + + CommandLineUtils.checkRequiredArgs(parser, options, topicOpt) + + val topic = options.valueOf(topicOpt) + + val bootstrapServer = options.valueOf(bootstrapServerOpt) + val brokerList = options.valueOf(brokerListOpt) + + val brokerHostsAndPorts = options.valueOf(if (options.has(bootstrapServerOpt)) bootstrapServerOpt else brokerListOpt) + ToolsUtils.validatePortOrDie(parser, brokerHostsAndPorts) + + val sync = options.has(syncOpt) + val compressionCodecOptionValue = options.valueOf(compressionCodecOpt) + val compressionCodec = if (options.has(compressionCodecOpt)) + if (compressionCodecOptionValue == null || compressionCodecOptionValue.isEmpty) + DefaultCompressionCodec.name + else compressionCodecOptionValue + else NoCompressionCodec.name + val batchSize = options.valueOf(batchSizeOpt) + val readerClass = options.valueOf(messageReaderOpt) + val cmdLineProps = CommandLineUtils.parseKeyValueArgs(options.valuesOf(propertyOpt).asScala) + val extraProducerProps = CommandLineUtils.parseKeyValueArgs(options.valuesOf(producerPropertyOpt).asScala) + + def tryParse(parser: OptionParser, args: Array[String]): OptionSet = { + try + parser.parse(args: _*) + catch { + case e: OptionException => + CommandLineUtils.printUsageAndDie(parser, e.getMessage) + } + } + } + + class LineMessageReader extends MessageReader { + var topic: String = null + var reader: BufferedReader = null + var parseKey = false + var keySeparator = "\t" + var ignoreError = false + var lineNumber = 0 + var printPrompt = System.console != null + + override def init(inputStream: InputStream, props: Properties): Unit = { + topic = props.getProperty("topic") + if (props.containsKey("parse.key")) + parseKey = props.getProperty("parse.key").trim.equalsIgnoreCase("true") + if (props.containsKey("key.separator")) + keySeparator = props.getProperty("key.separator") + if (props.containsKey("ignore.error")) + ignoreError = props.getProperty("ignore.error").trim.equalsIgnoreCase("true") + reader = new BufferedReader(new InputStreamReader(inputStream, StandardCharsets.UTF_8)) + } + + override def readMessage() = { + lineNumber += 1 + 
if (printPrompt) + print(">") + (reader.readLine(), parseKey) match { + case (null, _) => null + case (line, true) => + line.indexOf(keySeparator) match { + case -1 => + if (ignoreError) new ProducerRecord(topic, line.getBytes(StandardCharsets.UTF_8)) + else throw new KafkaException(s"No key found on line $lineNumber: $line") + case n => + val value = (if (n + keySeparator.size > line.size) "" else line.substring(n + keySeparator.size)).getBytes(StandardCharsets.UTF_8) + new ProducerRecord(topic, line.substring(0, n).getBytes(StandardCharsets.UTF_8), value) + } + case (line, false) => + new ProducerRecord(topic, line.getBytes(StandardCharsets.UTF_8)) + } + } + } +} diff --git a/core/src/main/scala/kafka/tools/ConsumerPerformance.scala b/core/src/main/scala/kafka/tools/ConsumerPerformance.scala new file mode 100644 index 0000000..89428e5 --- /dev/null +++ b/core/src/main/scala/kafka/tools/ConsumerPerformance.scala @@ -0,0 +1,306 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.tools + +import java.text.SimpleDateFormat +import java.time.Duration +import java.util +import java.util.concurrent.atomic.AtomicLong +import java.util.{Properties, Random} + +import com.typesafe.scalalogging.LazyLogging +import joptsimple.OptionException +import kafka.utils.{CommandLineUtils, ToolsUtils} +import org.apache.kafka.clients.consumer.{ConsumerRebalanceListener, KafkaConsumer} +import org.apache.kafka.common.serialization.ByteArrayDeserializer +import org.apache.kafka.common.utils.Utils +import org.apache.kafka.common.{Metric, MetricName, TopicPartition} + +import scala.jdk.CollectionConverters._ +import scala.collection.mutable + +/** + * Performance test for the full zookeeper consumer + */ +object ConsumerPerformance extends LazyLogging { + + def main(args: Array[String]): Unit = { + + val config = new ConsumerPerfConfig(args) + logger.info("Starting consumer...") + val totalMessagesRead = new AtomicLong(0) + val totalBytesRead = new AtomicLong(0) + var metrics: mutable.Map[MetricName, _ <: Metric] = null + val joinGroupTimeInMs = new AtomicLong(0) + + if (!config.hideHeader) + printHeader(config.showDetailedStats) + + var startMs, endMs = 0L + val consumer = new KafkaConsumer[Array[Byte], Array[Byte]](config.props) + startMs = System.currentTimeMillis + consume(consumer, List(config.topic), config.numMessages, config.recordFetchTimeoutMs, config, totalMessagesRead, totalBytesRead, joinGroupTimeInMs, startMs) + endMs = System.currentTimeMillis + + if (config.printMetrics) { + metrics = consumer.metrics.asScala + } + consumer.close() + val elapsedSecs = (endMs - startMs) / 1000.0 + val fetchTimeInMs = (endMs - startMs) - joinGroupTimeInMs.get + if (!config.showDetailedStats) { + val totalMBRead = (totalBytesRead.get * 1.0) / (1024 * 
1024) + println("%s, %s, %.4f, %.4f, %d, %.4f, %d, %d, %.4f, %.4f".format( + config.dateFormat.format(startMs), + config.dateFormat.format(endMs), + totalMBRead, + totalMBRead / elapsedSecs, + totalMessagesRead.get, + totalMessagesRead.get / elapsedSecs, + joinGroupTimeInMs.get, + fetchTimeInMs, + totalMBRead / (fetchTimeInMs / 1000.0), + totalMessagesRead.get / (fetchTimeInMs / 1000.0) + )) + } + + if (metrics != null) { + ToolsUtils.printMetrics(metrics) + } + + } + + private[tools] def printHeader(showDetailedStats: Boolean): Unit = { + val newFieldsInHeader = ", rebalance.time.ms, fetch.time.ms, fetch.MB.sec, fetch.nMsg.sec" + if (!showDetailedStats) + println("start.time, end.time, data.consumed.in.MB, MB.sec, data.consumed.in.nMsg, nMsg.sec" + newFieldsInHeader) + else + println("time, threadId, data.consumed.in.MB, MB.sec, data.consumed.in.nMsg, nMsg.sec" + newFieldsInHeader) + } + + def consume(consumer: KafkaConsumer[Array[Byte], Array[Byte]], + topics: List[String], + count: Long, + timeout: Long, + config: ConsumerPerfConfig, + totalMessagesRead: AtomicLong, + totalBytesRead: AtomicLong, + joinTime: AtomicLong, + testStartTime: Long): Unit = { + var bytesRead = 0L + var messagesRead = 0L + var lastBytesRead = 0L + var lastMessagesRead = 0L + var joinStart = System.currentTimeMillis + var joinTimeMsInSingleRound = 0L + + consumer.subscribe(topics.asJava, new ConsumerRebalanceListener { + def onPartitionsAssigned(partitions: util.Collection[TopicPartition]): Unit = { + joinTime.addAndGet(System.currentTimeMillis - joinStart) + joinTimeMsInSingleRound += System.currentTimeMillis - joinStart + } + def onPartitionsRevoked(partitions: util.Collection[TopicPartition]): Unit = { + joinStart = System.currentTimeMillis + }}) + + // Now start the benchmark + var currentTimeMillis = System.currentTimeMillis + var lastReportTime: Long = currentTimeMillis + var lastConsumedTime = currentTimeMillis + + while (messagesRead < count && currentTimeMillis - lastConsumedTime <= timeout) { + val records = consumer.poll(Duration.ofMillis(100)).asScala + currentTimeMillis = System.currentTimeMillis + if (records.nonEmpty) + lastConsumedTime = currentTimeMillis + for (record <- records) { + messagesRead += 1 + if (record.key != null) + bytesRead += record.key.size + if (record.value != null) + bytesRead += record.value.size + + if (currentTimeMillis - lastReportTime >= config.reportingInterval) { + if (config.showDetailedStats) + printConsumerProgress(0, bytesRead, lastBytesRead, messagesRead, lastMessagesRead, + lastReportTime, currentTimeMillis, config.dateFormat, joinTimeMsInSingleRound) + joinTimeMsInSingleRound = 0L + lastReportTime = currentTimeMillis + lastMessagesRead = messagesRead + lastBytesRead = bytesRead + } + } + } + + if (messagesRead < count) + println(s"WARNING: Exiting before consuming the expected number of messages: timeout ($timeout ms) exceeded. 
" + + "You can use the --timeout option to increase the timeout.") + totalMessagesRead.set(messagesRead) + totalBytesRead.set(bytesRead) + } + + def printConsumerProgress(id: Int, + bytesRead: Long, + lastBytesRead: Long, + messagesRead: Long, + lastMessagesRead: Long, + startMs: Long, + endMs: Long, + dateFormat: SimpleDateFormat, + periodicJoinTimeInMs: Long): Unit = { + printBasicProgress(id, bytesRead, lastBytesRead, messagesRead, lastMessagesRead, startMs, endMs, dateFormat) + printExtendedProgress(bytesRead, lastBytesRead, messagesRead, lastMessagesRead, startMs, endMs, periodicJoinTimeInMs) + println() + } + + private def printBasicProgress(id: Int, + bytesRead: Long, + lastBytesRead: Long, + messagesRead: Long, + lastMessagesRead: Long, + startMs: Long, + endMs: Long, + dateFormat: SimpleDateFormat): Unit = { + val elapsedMs: Double = (endMs - startMs).toDouble + val totalMbRead = (bytesRead * 1.0) / (1024 * 1024) + val intervalMbRead = ((bytesRead - lastBytesRead) * 1.0) / (1024 * 1024) + val intervalMbPerSec = 1000.0 * intervalMbRead / elapsedMs + val intervalMessagesPerSec = ((messagesRead - lastMessagesRead) / elapsedMs) * 1000.0 + print("%s, %d, %.4f, %.4f, %d, %.4f".format(dateFormat.format(endMs), id, totalMbRead, + intervalMbPerSec, messagesRead, intervalMessagesPerSec)) + } + + private def printExtendedProgress(bytesRead: Long, + lastBytesRead: Long, + messagesRead: Long, + lastMessagesRead: Long, + startMs: Long, + endMs: Long, + periodicJoinTimeInMs: Long): Unit = { + val fetchTimeMs = endMs - startMs - periodicJoinTimeInMs + val intervalMbRead = ((bytesRead - lastBytesRead) * 1.0) / (1024 * 1024) + val intervalMessagesRead = messagesRead - lastMessagesRead + val (intervalMbPerSec, intervalMessagesPerSec) = if (fetchTimeMs <= 0) + (0.0, 0.0) + else + (1000.0 * intervalMbRead / fetchTimeMs, 1000.0 * intervalMessagesRead / fetchTimeMs) + print(", %d, %d, %.4f, %.4f".format(periodicJoinTimeInMs, fetchTimeMs, intervalMbPerSec, intervalMessagesPerSec)) + } + + class ConsumerPerfConfig(args: Array[String]) extends PerfConfig(args) { + val brokerListOpt = parser.accepts("broker-list", "DEPRECATED, use --bootstrap-server instead; ignored if --bootstrap-server is specified. The broker list string in the form HOST1:PORT1,HOST2:PORT2.") + .withRequiredArg + .describedAs("broker-list") + .ofType(classOf[String]) + val bootstrapServerOpt = parser.accepts("bootstrap-server", "REQUIRED unless --broker-list(deprecated) is specified. 
The server(s) to connect to.") + .requiredUnless("broker-list") + .withRequiredArg + .describedAs("server to connect to") + .ofType(classOf[String]) + val topicOpt = parser.accepts("topic", "REQUIRED: The topic to consume from.") + .withRequiredArg + .describedAs("topic") + .ofType(classOf[String]) + val groupIdOpt = parser.accepts("group", "The group id to consume on.") + .withRequiredArg + .describedAs("gid") + .defaultsTo("perf-consumer-" + new Random().nextInt(100000)) + .ofType(classOf[String]) + val fetchSizeOpt = parser.accepts("fetch-size", "The amount of data to fetch in a single request.") + .withRequiredArg + .describedAs("size") + .ofType(classOf[java.lang.Integer]) + .defaultsTo(1024 * 1024) + val resetBeginningOffsetOpt = parser.accepts("from-latest", "If the consumer does not already have an established " + + "offset to consume from, start with the latest message present in the log rather than the earliest message.") + val socketBufferSizeOpt = parser.accepts("socket-buffer-size", "The size of the tcp RECV size.") + .withRequiredArg + .describedAs("size") + .ofType(classOf[java.lang.Integer]) + .defaultsTo(2 * 1024 * 1024) + val numThreadsOpt = parser.accepts("threads", "DEPRECATED AND IGNORED: Number of processing threads.") + .withRequiredArg + .describedAs("count") + .ofType(classOf[java.lang.Integer]) + .defaultsTo(10) + val numFetchersOpt = parser.accepts("num-fetch-threads", "DEPRECATED AND IGNORED: Number of fetcher threads.") + .withRequiredArg + .describedAs("count") + .ofType(classOf[java.lang.Integer]) + .defaultsTo(1) + val consumerConfigOpt = parser.accepts("consumer.config", "Consumer config properties file.") + .withRequiredArg + .describedAs("config file") + .ofType(classOf[String]) + val printMetricsOpt = parser.accepts("print-metrics", "Print out the metrics.") + val showDetailedStatsOpt = parser.accepts("show-detailed-stats", "If set, stats are reported for each reporting " + + "interval as configured by reporting-interval") + val recordFetchTimeoutOpt = parser.accepts("timeout", "The maximum allowed time in milliseconds between returned records.") + .withOptionalArg() + .describedAs("milliseconds") + .ofType(classOf[Long]) + .defaultsTo(10000) + + try + options = parser.parse(args: _*) + catch { + case e: OptionException => + CommandLineUtils.printUsageAndDie(parser, e.getMessage) + } + + if(options.has(numThreadsOpt) || options.has(numFetchersOpt)) + println("WARNING: option [threads] and [num-fetch-threads] have been deprecated and will be ignored by the test") + + CommandLineUtils.printHelpAndExitIfNeeded(this, "This tool helps in performance test for the full zookeeper consumer") + + CommandLineUtils.checkRequiredArgs(parser, options, topicOpt, numMessagesOpt) + + val printMetrics = options.has(printMetricsOpt) + + val props = if (options.has(consumerConfigOpt)) + Utils.loadProps(options.valueOf(consumerConfigOpt)) + else + new Properties + + import org.apache.kafka.clients.consumer.ConsumerConfig + + val brokerHostsAndPorts = options.valueOf(if (options.has(bootstrapServerOpt)) bootstrapServerOpt else brokerListOpt) + props.put(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, brokerHostsAndPorts) + props.put(ConsumerConfig.GROUP_ID_CONFIG, options.valueOf(groupIdOpt)) + props.put(ConsumerConfig.RECEIVE_BUFFER_CONFIG, options.valueOf(socketBufferSizeOpt).toString) + props.put(ConsumerConfig.MAX_PARTITION_FETCH_BYTES_CONFIG, options.valueOf(fetchSizeOpt).toString) + props.put(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, if (options.has(resetBeginningOffsetOpt)) 
"latest" else "earliest") + props.put(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, classOf[ByteArrayDeserializer]) + props.put(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, classOf[ByteArrayDeserializer]) + props.put(ConsumerConfig.CHECK_CRCS_CONFIG, "false") + if (props.getProperty(ConsumerConfig.CLIENT_ID_CONFIG) == null) + props.put(ConsumerConfig.CLIENT_ID_CONFIG, "perf-consumer-client") + + val numThreads = options.valueOf(numThreadsOpt).intValue + val topic = options.valueOf(topicOpt) + val numMessages = options.valueOf(numMessagesOpt).longValue + val reportingInterval = options.valueOf(reportingIntervalOpt).intValue + if (reportingInterval <= 0) + throw new IllegalArgumentException("Reporting interval must be greater than 0.") + val showDetailedStats = options.has(showDetailedStatsOpt) + val dateFormat = new SimpleDateFormat(options.valueOf(dateFormatOpt)) + val hideHeader = options.has(hideHeaderOpt) + val recordFetchTimeoutMs = options.valueOf(recordFetchTimeoutOpt).longValue() + } + +} diff --git a/core/src/main/scala/kafka/tools/DumpLogSegments.scala b/core/src/main/scala/kafka/tools/DumpLogSegments.scala new file mode 100755 index 0000000..f254fa8 --- /dev/null +++ b/core/src/main/scala/kafka/tools/DumpLogSegments.scala @@ -0,0 +1,482 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package kafka.tools + +import java.io._ + +import com.fasterxml.jackson.databind.node.{IntNode, JsonNodeFactory, ObjectNode, TextNode} +import kafka.coordinator.group.GroupMetadataManager +import kafka.coordinator.transaction.TransactionLog +import kafka.log._ +import kafka.serializer.Decoder +import kafka.utils._ +import kafka.utils.Implicits._ +import org.apache.kafka.common.metadata.{MetadataJsonConverters, MetadataRecordType} +import org.apache.kafka.common.protocol.ByteBufferAccessor +import org.apache.kafka.common.record._ +import org.apache.kafka.common.utils.Utils +import org.apache.kafka.metadata.MetadataRecordSerde + +import scala.jdk.CollectionConverters._ +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer + +object DumpLogSegments { + + // visible for testing + private[tools] val RecordIndent = "|" + + def main(args: Array[String]): Unit = { + val opts = new DumpLogSegmentsOptions(args) + CommandLineUtils.printHelpAndExitIfNeeded(opts, "This tool helps to parse a log file and dump its contents to the console, useful for debugging a seemingly corrupt log segment.") + opts.checkArgs() + + val misMatchesForIndexFilesMap = mutable.Map[String, List[(Long, Long)]]() + val timeIndexDumpErrors = new TimeIndexDumpErrors + val nonConsecutivePairsForLogFilesMap = mutable.Map[String, List[(Long, Long)]]() + + for (arg <- opts.files) { + val file = new File(arg) + println(s"Dumping $file") + + val filename = file.getName + val suffix = filename.substring(filename.lastIndexOf(".")) + suffix match { + case UnifiedLog.LogFileSuffix => + dumpLog(file, opts.shouldPrintDataLog, nonConsecutivePairsForLogFilesMap, opts.isDeepIteration, + opts.maxMessageSize, opts.messageParser, opts.skipRecordMetadata) + case UnifiedLog.IndexFileSuffix => + dumpIndex(file, opts.indexSanityOnly, opts.verifyOnly, misMatchesForIndexFilesMap, opts.maxMessageSize) + case UnifiedLog.TimeIndexFileSuffix => + dumpTimeIndex(file, opts.indexSanityOnly, opts.verifyOnly, timeIndexDumpErrors, opts.maxMessageSize) + case UnifiedLog.ProducerSnapshotFileSuffix => + dumpProducerIdSnapshot(file) + case UnifiedLog.TxnIndexFileSuffix => + dumpTxnIndex(file) + case _ => + System.err.println(s"Ignoring unknown file $file") + } + } + + misMatchesForIndexFilesMap.forKeyValue { (fileName, listOfMismatches) => + System.err.println(s"Mismatches in :$fileName") + listOfMismatches.foreach { case (indexOffset, logOffset) => + System.err.println(s" Index offset: $indexOffset, log offset: $logOffset") + } + } + + timeIndexDumpErrors.printErrors() + + nonConsecutivePairsForLogFilesMap.forKeyValue { (fileName, listOfNonConsecutivePairs) => + System.err.println(s"Non-consecutive offsets in $fileName") + listOfNonConsecutivePairs.foreach { case (first, second) => + System.err.println(s" $first is followed by $second") + } + } + } + + private def dumpTxnIndex(file: File): Unit = { + val index = new TransactionIndex(UnifiedLog.offsetFromFile(file), file) + for (abortedTxn <- index.allAbortedTxns) { + println(s"version: ${abortedTxn.version} producerId: ${abortedTxn.producerId} firstOffset: ${abortedTxn.firstOffset} " + + s"lastOffset: ${abortedTxn.lastOffset} lastStableOffset: ${abortedTxn.lastStableOffset}") + } + } + + private def dumpProducerIdSnapshot(file: File): Unit = { + try { + ProducerStateManager.readSnapshot(file).foreach { entry => + print(s"producerId: ${entry.producerId} producerEpoch: ${entry.producerEpoch} " + + s"coordinatorEpoch: ${entry.coordinatorEpoch} currentTxnFirstOffset: 
${entry.currentTxnFirstOffset} " + + s"lastTimestamp: ${entry.lastTimestamp} ") + entry.batchMetadata.headOption.foreach { metadata => + print(s"firstSequence: ${metadata.firstSeq} lastSequence: ${metadata.lastSeq} " + + s"lastOffset: ${metadata.lastOffset} offsetDelta: ${metadata.offsetDelta} timestamp: ${metadata.timestamp}") + } + println() + } + } catch { + case e: CorruptSnapshotException => + System.err.println(e.getMessage) + } + } + + /* print out the contents of the index */ + // Visible for testing + private[tools] def dumpIndex(file: File, + indexSanityOnly: Boolean, + verifyOnly: Boolean, + misMatchesForIndexFilesMap: mutable.Map[String, List[(Long, Long)]], + maxMessageSize: Int): Unit = { + val startOffset = file.getName.split("\\.")(0).toLong + val logFile = new File(file.getAbsoluteFile.getParent, file.getName.split("\\.")(0) + UnifiedLog.LogFileSuffix) + val fileRecords = FileRecords.open(logFile, false) + val index = new OffsetIndex(file, baseOffset = startOffset, writable = false) + + if (index.entries == 0) { + println(s"$file is empty.") + return + } + + //Check that index passes sanityCheck, this is the check that determines if indexes will be rebuilt on startup or not. + if (indexSanityOnly) { + index.sanityCheck() + println(s"$file passed sanity check.") + return + } + + for (i <- 0 until index.entries) { + val entry = index.entry(i) + + // since it is a sparse file, in the event of a crash there may be many zero entries, stop if we see one + if (entry.offset == index.baseOffset && i > 0) + return + + val slice = fileRecords.slice(entry.position, maxMessageSize) + val firstBatchLastOffset = slice.batches.iterator.next().lastOffset + if (firstBatchLastOffset != entry.offset) { + var misMatchesSeq = misMatchesForIndexFilesMap.getOrElse(file.getAbsolutePath, List[(Long, Long)]()) + misMatchesSeq ::= (entry.offset, firstBatchLastOffset) + misMatchesForIndexFilesMap.put(file.getAbsolutePath, misMatchesSeq) + } + if (!verifyOnly) + println(s"offset: ${entry.offset} position: ${entry.position}") + } + } + + // Visible for testing + private[tools] def dumpTimeIndex(file: File, + indexSanityOnly: Boolean, + verifyOnly: Boolean, + timeIndexDumpErrors: TimeIndexDumpErrors, + maxMessageSize: Int): Unit = { + val startOffset = file.getName.split("\\.")(0).toLong + val logFile = new File(file.getAbsoluteFile.getParent, file.getName.split("\\.")(0) + UnifiedLog.LogFileSuffix) + val fileRecords = FileRecords.open(logFile, false) + val indexFile = new File(file.getAbsoluteFile.getParent, file.getName.split("\\.")(0) + UnifiedLog.IndexFileSuffix) + val index = new OffsetIndex(indexFile, baseOffset = startOffset, writable = false) + val timeIndex = new TimeIndex(file, baseOffset = startOffset, writable = false) + + try { + //Check that index passes sanityCheck, this is the check that determines if indexes will be rebuilt on startup or not. 
+ if (indexSanityOnly) { + timeIndex.sanityCheck() + println(s"$file passed sanity check.") + return + } + + var prevTimestamp = RecordBatch.NO_TIMESTAMP + for (i <- 0 until timeIndex.entries) { + val entry = timeIndex.entry(i) + + // since it is a sparse file, in the event of a crash there may be many zero entries, stop if we see one + if (entry.offset == timeIndex.baseOffset && i > 0) + return + + val position = index.lookup(entry.offset).position + val partialFileRecords = fileRecords.slice(position, Int.MaxValue) + val batches = partialFileRecords.batches.asScala + var maxTimestamp = RecordBatch.NO_TIMESTAMP + // We first find the message by offset then check if the timestamp is correct. + batches.find(_.lastOffset >= entry.offset) match { + case None => + timeIndexDumpErrors.recordShallowOffsetNotFound(file, entry.offset, + -1.toLong) + case Some(batch) if batch.lastOffset != entry.offset => + timeIndexDumpErrors.recordShallowOffsetNotFound(file, entry.offset, batch.lastOffset) + case Some(batch) => + for (record <- batch.asScala) + maxTimestamp = math.max(maxTimestamp, record.timestamp) + + if (maxTimestamp != entry.timestamp) + timeIndexDumpErrors.recordMismatchTimeIndex(file, entry.timestamp, maxTimestamp) + + if (prevTimestamp >= entry.timestamp) + timeIndexDumpErrors.recordOutOfOrderIndexTimestamp(file, entry.timestamp, prevTimestamp) + } + if (!verifyOnly) + println(s"timestamp: ${entry.timestamp} offset: ${entry.offset}") + prevTimestamp = entry.timestamp + } + } finally { + fileRecords.closeHandlers() + index.closeHandler() + timeIndex.closeHandler() + } + } + + private[kafka] trait MessageParser[K, V] { + def parse(record: Record): (Option[K], Option[V]) + } + + private class DecoderMessageParser[K, V](keyDecoder: Decoder[K], valueDecoder: Decoder[V]) extends MessageParser[K, V] { + override def parse(record: Record): (Option[K], Option[V]) = { + val key = if (record.hasKey) + Some(keyDecoder.fromBytes(Utils.readBytes(record.key))) + else + None + + if (!record.hasValue) { + (key, None) + } else { + val payload = Some(valueDecoder.fromBytes(Utils.readBytes(record.value))) + + (key, payload) + } + } + } + + /* print out the contents of the log */ + private def dumpLog(file: File, + printContents: Boolean, + nonConsecutivePairsForLogFilesMap: mutable.Map[String, List[(Long, Long)]], + isDeepIteration: Boolean, + maxMessageSize: Int, + parser: MessageParser[_, _], + skipRecordMetadata: Boolean): Unit = { + val startOffset = file.getName.split("\\.")(0).toLong + println("Starting offset: " + startOffset) + val fileRecords = FileRecords.open(file, false) + try { + var validBytes = 0L + var lastOffset = -1L + + for (batch <- fileRecords.batches.asScala) { + printBatchLevel(batch, validBytes) + if (isDeepIteration) { + for (record <- batch.asScala) { + if (lastOffset == -1) + lastOffset = record.offset + else if (record.offset != lastOffset + 1) { + var nonConsecutivePairsSeq = nonConsecutivePairsForLogFilesMap.getOrElse(file.getAbsolutePath, List[(Long, Long)]()) + nonConsecutivePairsSeq ::= (lastOffset, record.offset) + nonConsecutivePairsForLogFilesMap.put(file.getAbsolutePath, nonConsecutivePairsSeq) + } + lastOffset = record.offset + + var prefix = s"$RecordIndent " + if (!skipRecordMetadata) { + print(s"${prefix}offset: ${record.offset} ${batch.timestampType}: ${record.timestamp} " + + s"keySize: ${record.keySize} valueSize: ${record.valueSize}") + prefix = " " + + if (batch.magic >= RecordBatch.MAGIC_VALUE_V2) { + print(" sequence: " + record.sequence + " headerKeys: " + 
record.headers.map(_.key).mkString("[", ",", "]")) + } + record match { + case r: AbstractLegacyRecordBatch => print(s" isValid: ${r.isValid} crc: ${r.checksum}}") + case _ => + } + + if (batch.isControlBatch) { + val controlTypeId = ControlRecordType.parseTypeId(record.key) + ControlRecordType.fromTypeId(controlTypeId) match { + case ControlRecordType.ABORT | ControlRecordType.COMMIT => + val endTxnMarker = EndTransactionMarker.deserialize(record) + print(s" endTxnMarker: ${endTxnMarker.controlType} coordinatorEpoch: ${endTxnMarker.coordinatorEpoch}") + case controlType => + print(s" controlType: $controlType($controlTypeId)") + } + } + } + if (printContents && !batch.isControlBatch) { + val (key, payload) = parser.parse(record) + key.foreach { key => + print(s"${prefix}key: $key") + prefix = " " + } + payload.foreach(payload => print(s" payload: $payload")) + } + println() + } + } + validBytes += batch.sizeInBytes + } + val trailingBytes = fileRecords.sizeInBytes - validBytes + if (trailingBytes > 0) + println(s"Found $trailingBytes invalid bytes at the end of ${file.getName}") + } finally fileRecords.closeHandlers() + } + + private def printBatchLevel(batch: FileLogInputStream.FileChannelRecordBatch, accumulativeBytes: Long): Unit = { + if (batch.magic >= RecordBatch.MAGIC_VALUE_V2) + print("baseOffset: " + batch.baseOffset + " lastOffset: " + batch.lastOffset + " count: " + batch.countOrNull + + " baseSequence: " + batch.baseSequence + " lastSequence: " + batch.lastSequence + + " producerId: " + batch.producerId + " producerEpoch: " + batch.producerEpoch + + " partitionLeaderEpoch: " + batch.partitionLeaderEpoch + " isTransactional: " + batch.isTransactional + + " isControl: " + batch.isControlBatch) + else + print("offset: " + batch.lastOffset) + + println(" position: " + accumulativeBytes + " " + batch.timestampType + ": " + batch.maxTimestamp + + " size: " + batch.sizeInBytes + " magic: " + batch.magic + + " compresscodec: " + batch.compressionType.name + " crc: " + batch.checksum + " isvalid: " + batch.isValid) + } + + class TimeIndexDumpErrors { + val misMatchesForTimeIndexFilesMap = mutable.Map[String, ArrayBuffer[(Long, Long)]]() + val outOfOrderTimestamp = mutable.Map[String, ArrayBuffer[(Long, Long)]]() + val shallowOffsetNotFound = mutable.Map[String, ArrayBuffer[(Long, Long)]]() + + def recordMismatchTimeIndex(file: File, indexTimestamp: Long, logTimestamp: Long): Unit = { + val misMatchesSeq = misMatchesForTimeIndexFilesMap.getOrElse(file.getAbsolutePath, new ArrayBuffer[(Long, Long)]()) + if (misMatchesSeq.isEmpty) + misMatchesForTimeIndexFilesMap.put(file.getAbsolutePath, misMatchesSeq) + misMatchesSeq += ((indexTimestamp, logTimestamp)) + } + + def recordOutOfOrderIndexTimestamp(file: File, indexTimestamp: Long, prevIndexTimestamp: Long): Unit = { + val outOfOrderSeq = outOfOrderTimestamp.getOrElse(file.getAbsolutePath, new ArrayBuffer[(Long, Long)]()) + if (outOfOrderSeq.isEmpty) + outOfOrderTimestamp.put(file.getAbsolutePath, outOfOrderSeq) + outOfOrderSeq += ((indexTimestamp, prevIndexTimestamp)) + } + + def recordShallowOffsetNotFound(file: File, indexOffset: Long, logOffset: Long): Unit = { + val shallowOffsetNotFoundSeq = shallowOffsetNotFound.getOrElse(file.getAbsolutePath, new ArrayBuffer[(Long, Long)]()) + if (shallowOffsetNotFoundSeq.isEmpty) + shallowOffsetNotFound.put(file.getAbsolutePath, shallowOffsetNotFoundSeq) + shallowOffsetNotFoundSeq += ((indexOffset, logOffset)) + } + + def printErrors(): Unit = { + misMatchesForTimeIndexFilesMap.foreach { + case 
(fileName, listOfMismatches) => { + System.err.println("Found timestamp mismatch in :" + fileName) + listOfMismatches.foreach(m => { + System.err.println(" Index timestamp: %d, log timestamp: %d".format(m._1, m._2)) + }) + } + } + + outOfOrderTimestamp.foreach { + case (fileName, outOfOrderTimestamps) => { + System.err.println("Found out of order timestamp in :" + fileName) + outOfOrderTimestamps.foreach(m => { + System.err.println(" Index timestamp: %d, Previously indexed timestamp: %d".format(m._1, m._2)) + }) + } + } + + shallowOffsetNotFound.values.foreach { listOfShallowOffsetNotFound => + System.err.println("The following indexed offsets are not found in the log.") + listOfShallowOffsetNotFound.foreach { case (indexedOffset, logOffset) => + System.err.println(s"Indexed offset: $indexedOffset, found log offset: $logOffset") + } + } + } + } + + private class OffsetsMessageParser extends MessageParser[String, String] { + override def parse(record: Record): (Option[String], Option[String]) = { + GroupMetadataManager.formatRecordKeyAndValue(record) + } + } + + private class TransactionLogMessageParser extends MessageParser[String, String] { + override def parse(record: Record): (Option[String], Option[String]) = { + TransactionLog.formatRecordKeyAndValue(record) + } + } + + private class ClusterMetadataLogMessageParser extends MessageParser[String, String] { + val metadataRecordSerde = new MetadataRecordSerde() + + override def parse(record: Record): (Option[String], Option[String]) = { + val output = try { + val messageAndVersion = metadataRecordSerde. + read(new ByteBufferAccessor(record.value), record.valueSize()) + val json = new ObjectNode(JsonNodeFactory.instance) + json.set("type", new TextNode(MetadataRecordType.fromId( + messageAndVersion.message().apiKey()).toString)) + json.set("version", new IntNode(messageAndVersion.version())) + json.set("data", MetadataJsonConverters.writeJson( + messageAndVersion.message(), messageAndVersion.version())) + json.toString() + } catch { + case e: Throwable => { + s"Error at ${record.offset}, skipping. ${e.getMessage}" + } + } + // No keys for metadata records + (None, Some(output)) + } + } + + private class DumpLogSegmentsOptions(args: Array[String]) extends CommandDefaultOptions(args) { + val printOpt = parser.accepts("print-data-log", "if set, printing the messages content when dumping data logs. Automatically set if any decoder option is specified.") + val verifyOpt = parser.accepts("verify-index-only", "if set, just verify the index log without printing its content.") + val indexSanityOpt = parser.accepts("index-sanity-check", "if set, just checks the index sanity without printing its content. " + + "This is the same check that is executed on broker startup to determine if an index needs rebuilding or not.") + val filesOpt = parser.accepts("files", "REQUIRED: The comma separated list of data and index log files to be dumped.") + .withRequiredArg + .describedAs("file1, file2, ...") + .ofType(classOf[String]) + val maxMessageSizeOpt = parser.accepts("max-message-size", "Size of largest message.") + .withRequiredArg + .describedAs("size") + .ofType(classOf[java.lang.Integer]) + .defaultsTo(5 * 1024 * 1024) + val deepIterationOpt = parser.accepts("deep-iteration", "if set, uses deep instead of shallow iteration. Automatically set if print-data-log is enabled.") + val valueDecoderOpt = parser.accepts("value-decoder-class", "if set, used to deserialize the messages. This class should implement kafka.serializer.Decoder trait. 
Custom jar should be available in kafka/libs directory.") + .withOptionalArg() + .ofType(classOf[java.lang.String]) + .defaultsTo("kafka.serializer.StringDecoder") + val keyDecoderOpt = parser.accepts("key-decoder-class", "if set, used to deserialize the keys. This class should implement kafka.serializer.Decoder trait. Custom jar should be available in kafka/libs directory.") + .withOptionalArg() + .ofType(classOf[java.lang.String]) + .defaultsTo("kafka.serializer.StringDecoder") + val offsetsOpt = parser.accepts("offsets-decoder", "if set, log data will be parsed as offset data from the " + + "__consumer_offsets topic.") + val transactionLogOpt = parser.accepts("transaction-log-decoder", "if set, log data will be parsed as " + + "transaction metadata from the __transaction_state topic.") + val clusterMetadataOpt = parser.accepts("cluster-metadata-decoder", "if set, log data will be parsed as cluster metadata records.") + val skipRecordMetadataOpt = parser.accepts("skip-record-metadata", "whether to skip printing metadata for each record.") + options = parser.parse(args : _*) + + def messageParser: MessageParser[_, _] = + if (options.has(offsetsOpt)) { + new OffsetsMessageParser + } else if (options.has(transactionLogOpt)) { + new TransactionLogMessageParser + } else if (options.has(clusterMetadataOpt)) { + new ClusterMetadataLogMessageParser + } else { + val valueDecoder: Decoder[_] = CoreUtils.createObject[Decoder[_]](options.valueOf(valueDecoderOpt), new VerifiableProperties) + val keyDecoder: Decoder[_] = CoreUtils.createObject[Decoder[_]](options.valueOf(keyDecoderOpt), new VerifiableProperties) + new DecoderMessageParser(keyDecoder, valueDecoder) + } + + lazy val shouldPrintDataLog: Boolean = options.has(printOpt) || + options.has(offsetsOpt) || + options.has(transactionLogOpt) || + options.has(clusterMetadataOpt) || + options.has(valueDecoderOpt) || + options.has(keyDecoderOpt) + + lazy val skipRecordMetadata = options.has(skipRecordMetadataOpt) + lazy val isDeepIteration: Boolean = options.has(deepIterationOpt) || shouldPrintDataLog + lazy val verifyOnly: Boolean = options.has(verifyOpt) + lazy val indexSanityOnly: Boolean = options.has(indexSanityOpt) + lazy val files = options.valueOf(filesOpt).split(",") + lazy val maxMessageSize = options.valueOf(maxMessageSizeOpt).intValue() + + def checkArgs(): Unit = CommandLineUtils.checkRequiredArgs(parser, options, filesOpt) + + } +} diff --git a/core/src/main/scala/kafka/tools/EndToEndLatency.scala b/core/src/main/scala/kafka/tools/EndToEndLatency.scala new file mode 100755 index 0000000..8a0e670 --- /dev/null +++ b/core/src/main/scala/kafka/tools/EndToEndLatency.scala @@ -0,0 +1,179 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package kafka.tools + +import java.nio.charset.StandardCharsets +import java.time.Duration +import java.util.{Arrays, Collections, Properties} + +import kafka.utils.Exit +import org.apache.kafka.clients.admin.{Admin, NewTopic} +import org.apache.kafka.clients.CommonClientConfigs +import org.apache.kafka.clients.consumer.{ConsumerConfig, KafkaConsumer} +import org.apache.kafka.clients.producer._ +import org.apache.kafka.common.TopicPartition +import org.apache.kafka.common.utils.Utils + +import scala.jdk.CollectionConverters._ +import scala.util.Random + + +/** + * This class records the average end to end latency for a single message to travel through Kafka + * + * broker_list = location of the bootstrap broker for both the producer and the consumer + * num_messages = # messages to send + * producer_acks = See ProducerConfig.ACKS_DOC + * message_size_bytes = size of each message in bytes + * + * e.g. [localhost:9092 test 10000 1 20] + */ + +object EndToEndLatency { + private val timeout: Long = 60000 + private val defaultReplicationFactor: Short = 1 + private val defaultNumPartitions: Int = 1 + + def main(args: Array[String]): Unit = { + if (args.length != 5 && args.length != 6) { + System.err.println("USAGE: java " + getClass.getName + " broker_list topic num_messages producer_acks message_size_bytes [optional] properties_file") + Exit.exit(1) + } + + val brokerList = args(0) + val topic = args(1) + val numMessages = args(2).toInt + val producerAcks = args(3) + val messageLen = args(4).toInt + val propsFile = if (args.length > 5) Some(args(5)).filter(_.nonEmpty) else None + + if (!List("1", "all").contains(producerAcks)) + throw new IllegalArgumentException("Latency testing requires synchronous acknowledgement. Please use 1 or all") + + def loadPropsWithBootstrapServers: Properties = { + val props = propsFile.map(Utils.loadProps).getOrElse(new Properties()) + props.put(CommonClientConfigs.BOOTSTRAP_SERVERS_CONFIG, brokerList) + props + } + + val consumerProps = loadPropsWithBootstrapServers + consumerProps.put(ConsumerConfig.GROUP_ID_CONFIG, "test-group-" + System.currentTimeMillis()) + consumerProps.put(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG, "false") + consumerProps.put(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "latest") + consumerProps.put(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, "org.apache.kafka.common.serialization.ByteArrayDeserializer") + consumerProps.put(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, "org.apache.kafka.common.serialization.ByteArrayDeserializer") + consumerProps.put(ConsumerConfig.FETCH_MAX_WAIT_MS_CONFIG, "0") //ensure we have no temporal batching + val consumer = new KafkaConsumer[Array[Byte], Array[Byte]](consumerProps) + + val producerProps = loadPropsWithBootstrapServers + producerProps.put(ProducerConfig.LINGER_MS_CONFIG, "0") //ensure writes are synchronous + producerProps.put(ProducerConfig.MAX_BLOCK_MS_CONFIG, Long.MaxValue.toString) + producerProps.put(ProducerConfig.ACKS_CONFIG, producerAcks.toString) + producerProps.put(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG, "org.apache.kafka.common.serialization.ByteArraySerializer") + producerProps.put(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG, "org.apache.kafka.common.serialization.ByteArraySerializer") + val producer = new KafkaProducer[Array[Byte], Array[Byte]](producerProps) + + def finalise(): Unit = { + consumer.commitSync() + producer.close() + consumer.close() + } + + // create topic if it does not exist + if (!consumer.listTopics().containsKey(topic)) { + try { + createTopic(topic, 
loadPropsWithBootstrapServers) + } catch { + case t: Throwable => + finalise() + throw new RuntimeException(s"Failed to create topic $topic", t) + } + } + + val topicPartitions = consumer.partitionsFor(topic).asScala + .map(p => new TopicPartition(p.topic(), p.partition())).asJava + consumer.assign(topicPartitions) + consumer.seekToEnd(topicPartitions) + consumer.assignment.forEach(consumer.position(_)) + + var totalTime = 0.0 + val latencies = new Array[Long](numMessages) + val random = new Random(0) + + for (i <- 0 until numMessages) { + val message = randomBytesOfLen(random, messageLen) + val begin = System.nanoTime + + //Send message (of random bytes) synchronously then immediately poll for it + producer.send(new ProducerRecord[Array[Byte], Array[Byte]](topic, message)).get() + val recordIter = consumer.poll(Duration.ofMillis(timeout)).iterator + + val elapsed = System.nanoTime - begin + + //Check we got results + if (!recordIter.hasNext) { + finalise() + throw new RuntimeException(s"poll() timed out before finding a result (timeout:[$timeout])") + } + + //Check result matches the original record + val sent = new String(message, StandardCharsets.UTF_8) + val read = new String(recordIter.next().value(), StandardCharsets.UTF_8) + if (!read.equals(sent)) { + finalise() + throw new RuntimeException(s"The message read [$read] did not match the message sent [$sent]") + } + + //Check we only got the one message + if (recordIter.hasNext) { + val count = 1 + recordIter.asScala.size + throw new RuntimeException(s"Only one result was expected during this test. We found [$count]") + } + + //Report progress + if (i % 1000 == 0) + println(i.toString + "\t" + elapsed / 1000.0 / 1000.0) + totalTime += elapsed + latencies(i) = elapsed / 1000 / 1000 + } + + //Results + println("Avg latency: %.4f ms\n".format(totalTime / numMessages / 1000.0 / 1000.0)) + Arrays.sort(latencies) + val p50 = latencies((latencies.length * 0.5).toInt) + val p99 = latencies((latencies.length * 0.99).toInt) + val p999 = latencies((latencies.length * 0.999).toInt) + println("Percentiles: 50th = %d, 99th = %d, 99.9th = %d".format(p50, p99, p999)) + + finalise() + } + + def randomBytesOfLen(random: Random, len: Int): Array[Byte] = { + Array.fill(len)((random.nextInt(26) + 65).toByte) + } + + def createTopic(topic: String, props: Properties): Unit = { + println("Topic \"%s\" does not exist. Will create topic with %d partition(s) and replication factor = %d" + .format(topic, defaultNumPartitions, defaultReplicationFactor)) + + val adminClient = Admin.create(props) + val newTopic = new NewTopic(topic, defaultNumPartitions, defaultReplicationFactor) + try adminClient.createTopics(Collections.singleton(newTopic)).all().get() + finally Utils.closeQuietly(adminClient, "AdminClient") + } +} diff --git a/core/src/main/scala/kafka/tools/GetOffsetShell.scala b/core/src/main/scala/kafka/tools/GetOffsetShell.scala new file mode 100644 index 0000000..dfd5a22 --- /dev/null +++ b/core/src/main/scala/kafka/tools/GetOffsetShell.scala @@ -0,0 +1,232 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package kafka.tools + +import java.util.Properties +import joptsimple._ +import kafka.utils.{CommandLineUtils, Exit, IncludeList, ToolsUtils} +import org.apache.kafka.clients.consumer.{ConsumerConfig, KafkaConsumer} +import org.apache.kafka.common.requests.ListOffsetsRequest +import org.apache.kafka.common.{PartitionInfo, TopicPartition} +import org.apache.kafka.common.serialization.ByteArrayDeserializer +import org.apache.kafka.common.utils.Utils + +import java.util.regex.Pattern +import scala.jdk.CollectionConverters._ +import scala.collection.Seq +import scala.math.Ordering.Implicits.infixOrderingOps + +object GetOffsetShell { + private val TopicPartitionPattern = Pattern.compile("([^:,]*)(?::(?:([0-9]*)|(?:([0-9]*)-([0-9]*))))?") + + def main(args: Array[String]): Unit = { + try { + fetchOffsets(args) + } catch { + case e: Exception => + println(s"Error occurred: ${e.getMessage}") + Exit.exit(1, Some(e.getMessage)) + } + } + + private def fetchOffsets(args: Array[String]): Unit = { + val parser = new OptionParser(false) + val brokerListOpt = parser.accepts("broker-list", "DEPRECATED, use --bootstrap-server instead; ignored if --bootstrap-server is specified. The server(s) to connect to in the form HOST1:PORT1,HOST2:PORT2.") + .withRequiredArg + .describedAs("HOST1:PORT1,...,HOST3:PORT3") + .ofType(classOf[String]) + val bootstrapServerOpt = parser.accepts("bootstrap-server", "REQUIRED. The server(s) to connect to in the form HOST1:PORT1,HOST2:PORT2.") + .requiredUnless("broker-list") + .withRequiredArg + .describedAs("HOST1:PORT1,...,HOST3:PORT3") + .ofType(classOf[String]) + val topicPartitionsOpt = parser.accepts("topic-partitions", s"Comma separated list of topic-partition patterns to get the offsets for, with the format of '$TopicPartitionPattern'." + + " The first group is an optional regex for the topic name, if omitted, it matches any topic name." + + " The section after ':' describes a 'partition' pattern, which can be: a number, a range in the format of 'NUMBER-NUMBER' (lower inclusive, upper exclusive), an inclusive lower bound in the format of 'NUMBER-', an exclusive upper bound in the format of '-NUMBER' or may be omitted to accept all partitions.") + .withRequiredArg + .describedAs("topic1:1,topic2:0-3,topic3,topic4:5-,topic5:-3") + .ofType(classOf[String]) + val topicOpt = parser.accepts("topic", s"The topic to get the offsets for. It also accepts a regular expression. If not present, all authorized topics are queried. Cannot be used if --topic-partitions is present.") + .withRequiredArg + .describedAs("topic") + .ofType(classOf[String]) + val partitionsOpt = parser.accepts("partitions", s"Comma separated list of partition ids to get the offsets for. If not present, all partitions of the authorized topics are queried. Cannot be used if --topic-partitions is present.") + .withRequiredArg + .describedAs("partition ids") + .ofType(classOf[String]) + val timeOpt = parser.accepts("time", "timestamp of the offsets before that. 
[Note: No offset is returned, if the timestamp greater than recently committed record timestamp is given.]") + .withRequiredArg + .describedAs("timestamp/-1(latest)/-2(earliest)") + .ofType(classOf[java.lang.Long]) + .defaultsTo(-1L) + val commandConfigOpt = parser.accepts("command-config", s"Property file containing configs to be passed to Consumer Client.") + .withRequiredArg + .describedAs("config file") + .ofType(classOf[String]) + val excludeInternalTopicsOpt = parser.accepts("exclude-internal-topics", s"By default, internal topics are included. If specified, internal topics are excluded.") + + if (args.length == 0) + CommandLineUtils.printUsageAndDie(parser, "An interactive shell for getting topic-partition offsets.") + + val options = parser.parse(args : _*) + + val effectiveBrokerListOpt = if (options.has(bootstrapServerOpt)) + bootstrapServerOpt + else + brokerListOpt + + CommandLineUtils.checkRequiredArgs(parser, options, effectiveBrokerListOpt) + + val clientId = "GetOffsetShell" + val brokerList = options.valueOf(effectiveBrokerListOpt) + + ToolsUtils.validatePortOrDie(parser, brokerList) + val excludeInternalTopics = options.has(excludeInternalTopicsOpt) + + if (options.has(topicPartitionsOpt) && (options.has(topicOpt) || options.has(partitionsOpt))) { + throw new IllegalArgumentException("--topic-partitions cannot be used with --topic or --partitions") + } + + val listOffsetsTimestamp = options.valueOf(timeOpt).longValue + + val topicPartitionFilter = if (options.has(topicPartitionsOpt)) { + createTopicPartitionFilterWithPatternList(options.valueOf(topicPartitionsOpt), excludeInternalTopics) + } else { + val partitionIdsRequested = createPartitionSet(options.valueOf(partitionsOpt)) + + createTopicPartitionFilterWithTopicAndPartitionPattern( + if (options.has(topicOpt)) Some(options.valueOf(topicOpt)) else None, + excludeInternalTopics, + partitionIdsRequested + ) + } + + val config = if (options.has(commandConfigOpt)) + Utils.loadProps(options.valueOf(commandConfigOpt)) + else + new Properties + config.setProperty(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, brokerList) + config.setProperty(ConsumerConfig.CLIENT_ID_CONFIG, clientId) + val consumer = new KafkaConsumer(config, new ByteArrayDeserializer, new ByteArrayDeserializer) + + try { + val partitionInfos = listPartitionInfos(consumer, topicPartitionFilter) + + if (partitionInfos.isEmpty) { + throw new IllegalArgumentException("Could not match any topic-partitions with the specified filters") + } + + val topicPartitions = partitionInfos.flatMap { p => + if (p.leader == null) { + System.err.println(s"Error: topic-partition ${p.topic}:${p.partition} does not have a leader. 
Skip getting offsets") + None + } else + Some(new TopicPartition(p.topic, p.partition)) + } + + /* Note that the value of the map can be null */ + val partitionOffsets: collection.Map[TopicPartition, java.lang.Long] = listOffsetsTimestamp match { + case ListOffsetsRequest.EARLIEST_TIMESTAMP => consumer.beginningOffsets(topicPartitions.asJava).asScala + case ListOffsetsRequest.LATEST_TIMESTAMP => consumer.endOffsets(topicPartitions.asJava).asScala + case _ => + val timestampsToSearch = topicPartitions.map(tp => tp -> (listOffsetsTimestamp: java.lang.Long)).toMap.asJava + consumer.offsetsForTimes(timestampsToSearch).asScala.map { case (k, x) => + if (x == null) (k, null) else (k, x.offset: java.lang.Long) + } + } + + partitionOffsets.toSeq.sortWith((tp1, tp2) => compareTopicPartitions(tp1._1, tp2._1)).foreach { + case (tp, offset) => println(s"${tp.topic}:${tp.partition}:${Option(offset).getOrElse("")}") + } + } finally { + consumer.close() + } + } + + def compareTopicPartitions(a: TopicPartition, b: TopicPartition): Boolean = { + (a.topic(), a.partition()) < (b.topic(), b.partition()) + } + + /** + * Creates a topic-partition filter based on a list of patterns. + * Expected format: + * List: TopicPartitionPattern(, TopicPartitionPattern)* + * TopicPartitionPattern: TopicPattern(:PartitionPattern)? | :PartitionPattern + * TopicPattern: REGEX + * PartitionPattern: NUMBER | NUMBER-(NUMBER)? | -NUMBER + */ + def createTopicPartitionFilterWithPatternList(topicPartitions: String, excludeInternalTopics: Boolean): PartitionInfo => Boolean = { + val ruleSpecs = topicPartitions.split(",") + val rules = ruleSpecs.map(ruleSpec => parseRuleSpec(ruleSpec, excludeInternalTopics)) + tp => rules.exists { rule => rule.apply(tp) } + } + + def parseRuleSpec(ruleSpec: String, excludeInternalTopics: Boolean): PartitionInfo => Boolean = { + val matcher = TopicPartitionPattern.matcher(ruleSpec) + if (!matcher.matches()) + throw new IllegalArgumentException(s"Invalid rule specification: $ruleSpec") + + def group(group: Int): Option[String] = { + Option(matcher.group(group)).filter(s => s != null && s.nonEmpty) + } + + val topicFilter = IncludeList(group(1).getOrElse(".*")) + val partitionFilter = group(2).map(_.toInt) match { + case Some(partition) => + (p: Int) => p == partition + case None => + val lowerRange = group(3).map(_.toInt).getOrElse(0) + val upperRange = group(4).map(_.toInt).getOrElse(Int.MaxValue) + (p: Int) => p >= lowerRange && p < upperRange + } + + tp => topicFilter.isTopicAllowed(tp.topic, excludeInternalTopics) && partitionFilter(tp.partition) + } + + /** + * Creates a topic-partition filter based on a topic pattern and a set of partition ids. + */ + def createTopicPartitionFilterWithTopicAndPartitionPattern(topicOpt: Option[String], excludeInternalTopics: Boolean, partitionIds: Set[Int]): PartitionInfo => Boolean = { + val topicsFilter = IncludeList(topicOpt.getOrElse(".*")) + t => topicsFilter.isTopicAllowed(t.topic, excludeInternalTopics) && (partitionIds.isEmpty || partitionIds.contains(t.partition)) + } + + def createPartitionSet(partitionsString: String): Set[Int] = { + if (partitionsString == null || partitionsString.isEmpty) + Set.empty + else + partitionsString.split(",").map { partitionString => + try partitionString.toInt + catch { + case _: NumberFormatException => + throw new IllegalArgumentException(s"--partitions expects a comma separated list of numeric " + + s"partition ids, but received: $partitionsString") + } + }.toSet + } + + /** + * Return the partition infos. 
Filter them with topicPartitionFilter. + */ + private def listPartitionInfos(consumer: KafkaConsumer[_, _], topicPartitionFilter: PartitionInfo => Boolean): Seq[PartitionInfo] = { + consumer.listTopics.asScala.values.flatMap { partitions => + partitions.asScala.filter(topicPartitionFilter) + }.toBuffer + } +} diff --git a/core/src/main/scala/kafka/tools/JmxTool.scala b/core/src/main/scala/kafka/tools/JmxTool.scala new file mode 100644 index 0000000..f7ace83 --- /dev/null +++ b/core/src/main/scala/kafka/tools/JmxTool.scala @@ -0,0 +1,275 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package kafka.tools + +import java.util.{Date, Objects} +import java.text.SimpleDateFormat +import javax.management._ +import javax.management.remote._ +import javax.rmi.ssl.SslRMIClientSocketFactory + +import joptsimple.OptionParser + +import scala.jdk.CollectionConverters._ +import scala.collection.mutable +import scala.math._ +import kafka.utils.{CommandLineUtils, Exit, Logging} + + +/** + * A program for reading JMX metrics from a given endpoint. + * + * This tool only works reliably if the JmxServer is fully initialized prior to invoking the tool. See KAFKA-4620 for + * details. + */ +object JmxTool extends Logging { + + def main(args: Array[String]): Unit = { + // Parse command line + val parser = new OptionParser(false) + val objectNameOpt = + parser.accepts("object-name", "A JMX object name to use as a query. This can contain wild cards, and this option " + + "can be given multiple times to specify more than one query. If no objects are specified " + + "all objects will be queried.") + .withRequiredArg + .describedAs("name") + .ofType(classOf[String]) + val attributesOpt = + parser.accepts("attributes", "The list of attributes to include in the query. This is a comma-separated list. If no " + + "attributes are specified all objects will be queried.") + .withRequiredArg + .describedAs("name") + .ofType(classOf[String]) + val reportingIntervalOpt = parser.accepts("reporting-interval", "Interval in MS with which to poll jmx stats; default value is 2 seconds. " + + "Value of -1 equivalent to setting one-time to true") + .withRequiredArg + .describedAs("ms") + .ofType(classOf[java.lang.Integer]) + .defaultsTo(2000) + val oneTimeOpt = parser.accepts("one-time", "Flag to indicate run once only.") + .withRequiredArg + .describedAs("one-time") + .ofType(classOf[java.lang.Boolean]) + .defaultsTo(false) + val dateFormatOpt = parser.accepts("date-format", "The date format to use for formatting the time field. " + + "See java.text.SimpleDateFormat for options.") + .withRequiredArg + .describedAs("format") + .ofType(classOf[String]) + val jmxServiceUrlOpt = + parser.accepts("jmx-url", "The url to connect to poll JMX data. 
See Oracle javadoc for JMXServiceURL for details.") + .withRequiredArg + .describedAs("service-url") + .ofType(classOf[String]) + .defaultsTo("service:jmx:rmi:///jndi/rmi://:9999/jmxrmi") + val reportFormatOpt = parser.accepts("report-format", "output format name: either 'original', 'properties', 'csv', 'tsv' ") + .withRequiredArg + .describedAs("report-format") + .ofType(classOf[java.lang.String]) + .defaultsTo("original") + val jmxAuthPropOpt = parser.accepts("jmx-auth-prop", "A mechanism to pass property in the form 'username=password' " + + "when enabling remote JMX with password authentication.") + .withRequiredArg + .describedAs("jmx-auth-prop") + .ofType(classOf[String]) + val jmxSslEnableOpt = parser.accepts("jmx-ssl-enable", "Flag to enable remote JMX with SSL.") + .withRequiredArg + .describedAs("ssl-enable") + .ofType(classOf[java.lang.Boolean]) + .defaultsTo(false) + val waitOpt = parser.accepts("wait", "Wait for requested JMX objects to become available before starting output. " + + "Only supported when the list of objects is non-empty and contains no object name patterns.") + val helpOpt = parser.accepts("help", "Print usage information.") + + + if(args.length == 0) + CommandLineUtils.printUsageAndDie(parser, "Dump JMX values to standard output.") + + val options = parser.parse(args : _*) + + if(options.has(helpOpt)) { + parser.printHelpOn(System.out) + Exit.exit(0) + } + + val url = new JMXServiceURL(options.valueOf(jmxServiceUrlOpt)) + val interval = options.valueOf(reportingIntervalOpt).intValue + val oneTime = interval < 0 || options.has(oneTimeOpt) + val attributesIncludeExists = options.has(attributesOpt) + val attributesInclude = if(attributesIncludeExists) Some(options.valueOf(attributesOpt).split(",").filterNot(_.equals(""))) else None + val dateFormatExists = options.has(dateFormatOpt) + val dateFormat = if(dateFormatExists) Some(new SimpleDateFormat(options.valueOf(dateFormatOpt))) else None + val wait = options.has(waitOpt) + + val reportFormat = parseFormat(options.valueOf(reportFormatOpt).toLowerCase) + val reportFormatOriginal = reportFormat.equals("original") + + val enablePasswordAuth = options.has(jmxAuthPropOpt) + val enableSsl = options.has(jmxSslEnableOpt) + + var jmxc: JMXConnector = null + var mbsc: MBeanServerConnection = null + var connected = false + val connectTimeoutMs = 10000 + val connectTestStarted = System.currentTimeMillis + do { + try { + System.err.println(s"Trying to connect to JMX url: $url.") + val env = new java.util.HashMap[String, AnyRef] + // ssl enable + if (enableSsl) { + val csf = new SslRMIClientSocketFactory + env.put("com.sun.jndi.rmi.factory.socket", csf) + } + // password authentication enable + if (enablePasswordAuth) { + val credentials = options.valueOf(jmxAuthPropOpt).split("=", 2) + env.put(JMXConnector.CREDENTIALS, credentials) + } + jmxc = JMXConnectorFactory.connect(url, env) + mbsc = jmxc.getMBeanServerConnection + connected = true + } catch { + case e : Exception => + System.err.println(s"Could not connect to JMX url: $url. 
Exception ${e.getMessage}.") + e.printStackTrace() + Thread.sleep(100) + } + } while (System.currentTimeMillis - connectTestStarted < connectTimeoutMs && !connected) + + if (!connected) { + System.err.println(s"Could not connect to JMX url $url after $connectTimeoutMs ms.") + System.err.println("Exiting.") + sys.exit(1) + } + + val queries: Iterable[ObjectName] = + if(options.has(objectNameOpt)) + options.valuesOf(objectNameOpt).asScala.map(new ObjectName(_)) + else + List(null) + + val hasPatternQueries = queries.filterNot(Objects.isNull).exists((name: ObjectName) => name.isPattern) + + var names: Iterable[ObjectName] = null + def namesSet = Option(names).toSet.flatten + def foundAllObjects = queries.toSet == namesSet + val waitTimeoutMs = 10000 + if (!hasPatternQueries) { + val start = System.currentTimeMillis + do { + if (names != null) { + System.err.println("Could not find all object names, retrying") + Thread.sleep(100) + } + names = queries.flatMap((name: ObjectName) => mbsc.queryNames(name, null).asScala) + } while (wait && System.currentTimeMillis - start < waitTimeoutMs && !foundAllObjects) + } + + if (wait && !foundAllObjects) { + val missing = (queries.toSet - namesSet).mkString(", ") + System.err.println(s"Could not find all requested object names after $waitTimeoutMs ms. Missing $missing") + System.err.println("Exiting.") + sys.exit(1) + } + + val numExpectedAttributes: Map[ObjectName, Int] = + if (!attributesIncludeExists) + names.map{name: ObjectName => + val mbean = mbsc.getMBeanInfo(name) + (name, mbsc.getAttributes(name, mbean.getAttributes.map(_.getName)).size)}.toMap + else { + if (!hasPatternQueries) + names.map{name: ObjectName => + val mbean = mbsc.getMBeanInfo(name) + val attributes = mbsc.getAttributes(name, mbean.getAttributes.map(_.getName)) + val expectedAttributes = attributes.asScala.asInstanceOf[mutable.Buffer[Attribute]] + .filter(attr => attributesInclude.get.contains(attr.getName)) + (name, expectedAttributes.size)}.toMap.filter(_._2 > 0) + else + queries.map((_, attributesInclude.get.length)).toMap + } + + if(numExpectedAttributes.isEmpty) { + CommandLineUtils.printUsageAndDie(parser, s"No matched attributes for the queried objects $queries.") + } + + // print csv header + val keys = List("time") ++ queryAttributes(mbsc, names, attributesInclude).keys.toArray.sorted + if(reportFormatOriginal && keys.size == numExpectedAttributes.values.sum + 1) { + println(keys.map("\"" + _ + "\"").mkString(",")) + } + + var keepGoing = true + while (keepGoing) { + val start = System.currentTimeMillis + val attributes = queryAttributes(mbsc, names, attributesInclude) + attributes("time") = dateFormat match { + case Some(dFormat) => dFormat.format(new Date) + case None => System.currentTimeMillis().toString + } + if(attributes.keySet.size == numExpectedAttributes.values.sum + 1) { + if(reportFormatOriginal) { + println(keys.map(attributes(_)).mkString(",")) + } + else if(reportFormat.equals("properties")) { + keys.foreach( k => { println(k + "=" + attributes(k) ) } ) + } + else if(reportFormat.equals("csv")) { + keys.foreach( k => { println(k + ",\"" + attributes(k) + "\"" ) } ) + } + else { // tsv + keys.foreach( k => { println(k + "\t" + attributes(k) ) } ) + } + } + + if (oneTime) { + keepGoing = false + } + else { + val sleep = max(0, interval - (System.currentTimeMillis - start)) + Thread.sleep(sleep) + } + } + } + + def queryAttributes(mbsc: MBeanServerConnection, names: Iterable[ObjectName], attributesInclude: Option[Array[String]]): mutable.Map[String, Any] = { + 
val attributes = new mutable.HashMap[String, Any]() + for (name <- names) { + val mbean = mbsc.getMBeanInfo(name) + for (attrObj <- mbsc.getAttributes(name, mbean.getAttributes.map(_.getName)).asScala) { + val attr = attrObj.asInstanceOf[Attribute] + attributesInclude match { + case Some(allowedAttributes) => + if (allowedAttributes.contains(attr.getName)) + attributes(name.toString + ":" + attr.getName) = attr.getValue + case None => attributes(name.toString + ":" + attr.getName) = attr.getValue + } + } + } + attributes + } + + def parseFormat(reportFormatOpt : String): String = reportFormatOpt match { + case "properties" => "properties" + case "csv" => "csv" + case "tsv" => "tsv" + case _ => "original" + } +} diff --git a/core/src/main/scala/kafka/tools/MirrorMaker.scala b/core/src/main/scala/kafka/tools/MirrorMaker.scala new file mode 100755 index 0000000..f6e2865 --- /dev/null +++ b/core/src/main/scala/kafka/tools/MirrorMaker.scala @@ -0,0 +1,589 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.tools + +import java.time.Duration +import java.util +import java.util.concurrent.atomic.{AtomicBoolean, AtomicInteger} +import java.util.concurrent.CountDownLatch +import java.util.regex.Pattern +import java.util.{Collections, Properties} + +import kafka.consumer.BaseConsumerRecord +import kafka.metrics.KafkaMetricsGroup +import kafka.utils._ +import org.apache.kafka.clients.consumer._ +import org.apache.kafka.clients.producer.internals.ErrorLoggingCallback +import org.apache.kafka.clients.producer.{KafkaProducer, ProducerConfig, ProducerRecord, RecordMetadata} +import org.apache.kafka.common.errors.{TimeoutException, WakeupException} +import org.apache.kafka.common.record.RecordBatch +import org.apache.kafka.common.serialization.{ByteArrayDeserializer, ByteArraySerializer} +import org.apache.kafka.common.utils.{Time, Utils} +import org.apache.kafka.common.{KafkaException, TopicPartition} + +import scala.jdk.CollectionConverters._ +import scala.collection.mutable.HashMap +import scala.util.control.ControlThrowable +import scala.util.{Failure, Success, Try} + +/** + * The mirror maker has the following architecture: + * - There are N mirror maker threads, each of which is equipped with a separate KafkaConsumer instance. + * - All the mirror maker threads share one producer. + * - Each mirror maker thread periodically flushes the producer and then commits all offsets. + * + * @note For mirror maker, the following settings are set by default to make sure there is no data loss: + * 1. use producer with following settings + * acks=all + * delivery.timeout.ms=max integer + * max.block.ms=max long + * max.in.flight.requests.per.connection=1 + * 2. Consumer Settings + * enable.auto.commit=false + * 3. 
Mirror Maker Setting: + * abort.on.send.failure=true + * + * @deprecated Since 3.0, use the Connect-based MirrorMaker instead (aka MM2). + */ +@deprecated(message = "Use the Connect-based MirrorMaker instead (aka MM2).", since = "3.0") +object MirrorMaker extends Logging with KafkaMetricsGroup { + + private[tools] var producer: MirrorMakerProducer = null + private var mirrorMakerThreads: Seq[MirrorMakerThread] = null + private val isShuttingDown: AtomicBoolean = new AtomicBoolean(false) + // Track the messages not successfully sent by mirror maker. + private val numDroppedMessages: AtomicInteger = new AtomicInteger(0) + private var messageHandler: MirrorMakerMessageHandler = null + private var offsetCommitIntervalMs = 0 + private var abortOnSendFailure: Boolean = true + @volatile private var exitingOnSendFailure: Boolean = false + private var lastSuccessfulCommitTime = -1L + private val time = Time.SYSTEM + + // If a message send failed after retries are exhausted. The offset of the messages will also be removed from + // the unacked offset list to avoid offset commit being stuck on that offset. In this case, the offset of that + // message was not really acked, but was skipped. This metric records the number of skipped offsets. + newGauge("MirrorMaker-numDroppedMessages", () => numDroppedMessages.get()) + + def main(args: Array[String]): Unit = { + + warn("This tool is deprecated and may be removed in a future major release.") + info("Starting mirror maker") + try { + val opts = new MirrorMakerOptions(args) + CommandLineUtils.printHelpAndExitIfNeeded(opts, "This tool helps to continuously copy data between two Kafka clusters.") + opts.checkArgs() + } catch { + case ct: ControlThrowable => throw ct + case t: Throwable => + error("Exception when starting mirror maker.", t) + } + + mirrorMakerThreads.foreach(_.start()) + mirrorMakerThreads.foreach(_.awaitShutdown()) + } + + def createConsumers(numStreams: Int, + consumerConfigProps: Properties, + customRebalanceListener: Option[ConsumerRebalanceListener], + include: Option[String]): Seq[ConsumerWrapper] = { + // Disable consumer auto offsets commit to prevent data loss. 
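+     // (Offsets are committed manually via commitOffsets(), after the shared producer has been flushed.)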
+ maybeSetDefaultProperty(consumerConfigProps, "enable.auto.commit", "false") + // Hardcode the deserializer to ByteArrayDeserializer + consumerConfigProps.setProperty("key.deserializer", classOf[ByteArrayDeserializer].getName) + consumerConfigProps.setProperty("value.deserializer", classOf[ByteArrayDeserializer].getName) + // The default client id is group id, we manually set client id to groupId-index to avoid metric collision + val groupIdString = consumerConfigProps.getProperty("group.id") + val consumers = (0 until numStreams) map { i => + consumerConfigProps.setProperty("client.id", groupIdString + "-" + i.toString) + new KafkaConsumer[Array[Byte], Array[Byte]](consumerConfigProps) + } + include.getOrElse(throw new IllegalArgumentException("include list cannot be empty")) + consumers.map(consumer => new ConsumerWrapper(consumer, customRebalanceListener, include)) + } + + def commitOffsets(consumerWrapper: ConsumerWrapper): Unit = { + if (!exitingOnSendFailure) { + var retry = 0 + var retryNeeded = true + while (retryNeeded) { + trace("Committing offsets.") + try { + consumerWrapper.commit() + lastSuccessfulCommitTime = time.milliseconds + retryNeeded = false + } catch { + case e: WakeupException => + // we only call wakeup() once to close the consumer, + // so if we catch it in commit we can safely retry + // and re-throw to break the loop + commitOffsets(consumerWrapper) + throw e + + case _: TimeoutException => + Try(consumerWrapper.consumer.listTopics) match { + case Success(visibleTopics) => + consumerWrapper.offsets --= consumerWrapper.offsets.keySet.filter(tp => !visibleTopics.containsKey(tp.topic)) + case Failure(e) => + warn("Failed to list all authorized topics after committing offsets timed out: ", e) + } + + retry += 1 + warn("Failed to commit offsets because the offset commit request processing can not be completed in time. " + + s"If you see this regularly, it could indicate that you need to increase the consumer's ${ConsumerConfig.DEFAULT_API_TIMEOUT_MS_CONFIG} " + + s"Last successful offset commit timestamp=$lastSuccessfulCommitTime, retry count=$retry") + Thread.sleep(100) + + case _: CommitFailedException => + retryNeeded = false + warn("Failed to commit offsets because the consumer group has rebalanced and assigned partitions to " + + "another instance. If you see this regularly, it could indicate that you need to either increase " + + s"the consumer's ${ConsumerConfig.SESSION_TIMEOUT_MS_CONFIG} or reduce the number of records " + + s"handled on each iteration with ${ConsumerConfig.MAX_POLL_RECORDS_CONFIG}") + } + } + } else { + info("Exiting on send failure, skip committing offsets.") + } + } + + def cleanShutdown(): Unit = { + if (isShuttingDown.compareAndSet(false, true)) { + info("Start clean shutdown.") + // Shutdown consumer threads. 
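+       // shutdown() flags each mirror maker thread and wakes up its blocked consumer; awaitShutdown() then waits for it to exit.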
+ info("Shutting down consumer threads.") + if (mirrorMakerThreads != null) { + mirrorMakerThreads.foreach(_.shutdown()) + mirrorMakerThreads.foreach(_.awaitShutdown()) + } + info("Closing producer.") + producer.close() + info("Kafka mirror maker shutdown successfully") + } + } + + private def maybeSetDefaultProperty(properties: Properties, propertyName: String, defaultValue: String): Unit = { + val propertyValue = properties.getProperty(propertyName) + properties.setProperty(propertyName, Option(propertyValue).getOrElse(defaultValue)) + if (properties.getProperty(propertyName) != defaultValue) + info("Property %s is overridden to %s - data loss or message reordering is possible.".format(propertyName, propertyValue)) + } + + class MirrorMakerThread(consumerWrapper: ConsumerWrapper, + val threadId: Int) extends Thread with Logging with KafkaMetricsGroup { + private val threadName = "mirrormaker-thread-" + threadId + private val shutdownLatch: CountDownLatch = new CountDownLatch(1) + private var lastOffsetCommitMs = System.currentTimeMillis() + @volatile private var shuttingDown: Boolean = false + this.logIdent = "[%s] ".format(threadName) + + setName(threadName) + + private def toBaseConsumerRecord(record: ConsumerRecord[Array[Byte], Array[Byte]]): BaseConsumerRecord = + BaseConsumerRecord(record.topic, + record.partition, + record.offset, + record.timestamp, + record.timestampType, + record.key, + record.value, + record.headers) + + override def run(): Unit = { + info(s"Starting mirror maker thread $threadName") + try { + consumerWrapper.init() + + // We needed two while loops due to the old consumer semantics, this can now be simplified + while (!exitingOnSendFailure && !shuttingDown) { + try { + while (!exitingOnSendFailure && !shuttingDown) { + val data = consumerWrapper.receive() + if (data.value != null) { + trace("Sending message with value size %d and offset %d.".format(data.value.length, data.offset)) + } else { + trace("Sending message with null value and offset %d.".format(data.offset)) + } + val records = messageHandler.handle(toBaseConsumerRecord(data)) + records.forEach(producer.send) + maybeFlushAndCommitOffsets() + } + } catch { + case _: NoRecordsException => + trace("Caught NoRecordsException, continue iteration.") + case _: WakeupException => + trace("Caught WakeupException, continue iteration.") + case e: KafkaException if (shuttingDown || exitingOnSendFailure) => + trace(s"Ignoring caught KafkaException during shutdown. 
sendFailure: $exitingOnSendFailure.", e) + } + maybeFlushAndCommitOffsets() + } + } catch { + case t: Throwable => + exitingOnSendFailure = true + fatal("Mirror maker thread failure due to ", t) + } finally { + CoreUtils.swallow ({ + info("Flushing producer.") + producer.flush() + + // note that this commit is skipped if flush() fails which ensures that we don't lose messages + info("Committing consumer offsets.") + commitOffsets(consumerWrapper) + }, this) + + info("Shutting down consumer connectors.") + CoreUtils.swallow(consumerWrapper.wakeup(), this) + CoreUtils.swallow(consumerWrapper.close(), this) + shutdownLatch.countDown() + info("Mirror maker thread stopped") + // if it exits accidentally, stop the entire mirror maker + if (!isShuttingDown.get()) { + fatal("Mirror maker thread exited abnormally, stopping the whole mirror maker.") + sys.exit(-1) + } + } + } + + def maybeFlushAndCommitOffsets(): Unit = { + if (System.currentTimeMillis() - lastOffsetCommitMs > offsetCommitIntervalMs) { + debug("Committing MirrorMaker state.") + producer.flush() + commitOffsets(consumerWrapper) + lastOffsetCommitMs = System.currentTimeMillis() + } + } + + def shutdown(): Unit = { + try { + info(s"$threadName shutting down") + shuttingDown = true + consumerWrapper.wakeup() + } + catch { + case _: InterruptedException => + warn("Interrupt during shutdown of the mirror maker thread") + } + } + + def awaitShutdown(): Unit = { + try { + shutdownLatch.await() + info("Mirror maker thread shutdown complete") + } catch { + case _: InterruptedException => + warn("Shutdown of the mirror maker thread interrupted") + } + } + } + + // Visible for testing + private[tools] class ConsumerWrapper(private[tools] val consumer: Consumer[Array[Byte], Array[Byte]], + customRebalanceListener: Option[ConsumerRebalanceListener], + includeOpt: Option[String]) { + val regex = includeOpt.getOrElse(throw new IllegalArgumentException("New consumer only supports include.")) + var recordIter: java.util.Iterator[ConsumerRecord[Array[Byte], Array[Byte]]] = null + + // We manually maintain the consumed offsets for historical reasons and it could be simplified + // Visible for testing + private[tools] val offsets = new HashMap[TopicPartition, Long]() + + def init(): Unit = { + debug("Initiating consumer") + val consumerRebalanceListener = new InternalRebalanceListener(this, customRebalanceListener) + includeOpt.foreach { include => + try { + consumer.subscribe(Pattern.compile(IncludeList(include).regex), consumerRebalanceListener) + } catch { + case pse: RuntimeException => + error(s"Invalid expression syntax: $include") + throw pse + } + } + } + + def receive(): ConsumerRecord[Array[Byte], Array[Byte]] = { + if (recordIter == null || !recordIter.hasNext) { + // In scenarios where data does not arrive within offsetCommitIntervalMs and + // offsetCommitIntervalMs is less than poll's timeout, offset commit will be delayed for any + // uncommitted record since last poll. Using one second as poll's timeout ensures that + // offsetCommitIntervalMs, of value greater than 1 second, does not see delays in offset + // commit. 
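+         // For example, with the default offset.commit.interval.ms of 60000, a one-second poll timeout delays a pending commit by at most about a second.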
+ recordIter = consumer.poll(Duration.ofSeconds(1L)).iterator + if (!recordIter.hasNext) + throw new NoRecordsException + } + + val record = recordIter.next() + val tp = new TopicPartition(record.topic, record.partition) + + offsets.put(tp, record.offset + 1) + record + } + + def wakeup(): Unit = { + consumer.wakeup() + } + + def close(): Unit = { + consumer.close() + } + + def commit(): Unit = { + consumer.commitSync(offsets.map { case (tp, offset) => (tp, new OffsetAndMetadata(offset)) }.asJava) + offsets.clear() + } + } + + private class InternalRebalanceListener(consumerWrapper: ConsumerWrapper, + customRebalanceListener: Option[ConsumerRebalanceListener]) + extends ConsumerRebalanceListener { + + override def onPartitionsLost(partitions: util.Collection[TopicPartition]): Unit = {} + + override def onPartitionsRevoked(partitions: util.Collection[TopicPartition]): Unit = { + producer.flush() + commitOffsets(consumerWrapper) + customRebalanceListener.foreach(_.onPartitionsRevoked(partitions)) + } + + override def onPartitionsAssigned(partitions: util.Collection[TopicPartition]): Unit = { + customRebalanceListener.foreach(_.onPartitionsAssigned(partitions)) + } + } + + private[tools] class MirrorMakerProducer(val sync: Boolean, val producerProps: Properties) { + + val producer = new KafkaProducer[Array[Byte], Array[Byte]](producerProps) + + def send(record: ProducerRecord[Array[Byte], Array[Byte]]): Unit = { + if (sync) { + this.producer.send(record).get() + } else { + this.producer.send(record, + new MirrorMakerProducerCallback(record.topic(), record.key(), record.value())) + } + } + + def flush(): Unit = { + this.producer.flush() + } + + def close(): Unit = { + this.producer.close() + } + + def close(timeout: Long): Unit = { + this.producer.close(Duration.ofMillis(timeout)) + } + } + + private class MirrorMakerProducerCallback (topic: String, key: Array[Byte], value: Array[Byte]) + extends ErrorLoggingCallback(topic, key, value, false) { + + override def onCompletion(metadata: RecordMetadata, exception: Exception): Unit = { + if (exception != null) { + // Use default call back to log error. This means the max retries of producer has reached and message + // still could not be sent. + super.onCompletion(metadata, exception) + // If abort.on.send.failure is set, stop the mirror maker. Otherwise log skipped message and move on. + if (abortOnSendFailure) { + info("Closing producer due to send failure.") + exitingOnSendFailure = true + producer.close(0) + } + numDroppedMessages.incrementAndGet() + } + } + } + + /** + * If message.handler.args is specified. A constructor that takes in a String as argument must exist. 
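+ * As a purely illustrative sketch (the class name and the meaning of the argument are hypothetical), a handler that
+ * redirects every record to a fixed destination topic supplied via --message.handler.args could look like:
+ *
+ *   class TopicRenamingHandler(targetTopic: String) extends MirrorMaker.MirrorMakerMessageHandler {
+ *     override def handle(record: BaseConsumerRecord): java.util.List[ProducerRecord[Array[Byte], Array[Byte]]] =
+ *       java.util.Collections.singletonList(new ProducerRecord(targetTopic, null, record.key, record.value, record.headers))
+ *   }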
+ */ + trait MirrorMakerMessageHandler { + def handle(record: BaseConsumerRecord): util.List[ProducerRecord[Array[Byte], Array[Byte]]] + } + + private[tools] object defaultMirrorMakerMessageHandler extends MirrorMakerMessageHandler { + override def handle(record: BaseConsumerRecord): util.List[ProducerRecord[Array[Byte], Array[Byte]]] = { + val timestamp: java.lang.Long = if (record.timestamp == RecordBatch.NO_TIMESTAMP) null else record.timestamp + Collections.singletonList(new ProducerRecord(record.topic, null, timestamp, record.key, record.value, record.headers)) + } + } + + // package-private for tests + private[tools] class NoRecordsException extends RuntimeException + + class MirrorMakerOptions(args: Array[String]) extends CommandDefaultOptions(args) { + + val consumerConfigOpt = parser.accepts("consumer.config", + "Embedded consumer config for consuming from the source cluster.") + .withRequiredArg() + .describedAs("config file") + .ofType(classOf[String]) + + parser.accepts("new.consumer", + "DEPRECATED Use new consumer in mirror maker (this is the default so this option will be removed in " + + "a future version).") + + val producerConfigOpt = parser.accepts("producer.config", + "Embedded producer config.") + .withRequiredArg() + .describedAs("config file") + .ofType(classOf[String]) + + val numStreamsOpt = parser.accepts("num.streams", + "Number of consumption streams.") + .withRequiredArg() + .describedAs("Number of threads") + .ofType(classOf[java.lang.Integer]) + .defaultsTo(1) + + val whitelistOpt = parser.accepts("whitelist", + "DEPRECATED, use --include instead; ignored if --include specified. List of included topics to mirror.") + .withRequiredArg() + .describedAs("Java regex (String)") + .ofType(classOf[String]) + + val includeOpt = parser.accepts("include", + "List of included topics to mirror.") + .withRequiredArg() + .describedAs("Java regex (String)") + .ofType(classOf[String]) + + val offsetCommitIntervalMsOpt = parser.accepts("offset.commit.interval.ms", + "Offset commit interval in ms.") + .withRequiredArg() + .describedAs("offset commit interval in millisecond") + .ofType(classOf[java.lang.Integer]) + .defaultsTo(60000) + + val consumerRebalanceListenerOpt = parser.accepts("consumer.rebalance.listener", + "The consumer rebalance listener to use for mirror maker consumer.") + .withRequiredArg() + .describedAs("A custom rebalance listener of type ConsumerRebalanceListener") + .ofType(classOf[String]) + + val rebalanceListenerArgsOpt = parser.accepts("rebalance.listener.args", + "Arguments used by custom rebalance listener for mirror maker consumer.") + .withRequiredArg() + .describedAs("Arguments passed to custom rebalance listener constructor as a string.") + .ofType(classOf[String]) + + val messageHandlerOpt = parser.accepts("message.handler", + "Message handler which will process every record in-between consumer and producer.") + .withRequiredArg() + .describedAs("A custom message handler of type MirrorMakerMessageHandler") + .ofType(classOf[String]) + + val messageHandlerArgsOpt = parser.accepts("message.handler.args", + "Arguments used by custom message handler for mirror maker.") + .withRequiredArg() + .describedAs("Arguments passed to message handler constructor.") + .ofType(classOf[String]) + + val abortOnSendFailureOpt = parser.accepts("abort.on.send.failure", + "Configure the mirror maker to exit on a failed send.") + .withRequiredArg() + .describedAs("Stop the entire mirror maker when a send failure occurs") + .ofType(classOf[String]) + 
.defaultsTo("true") + + options = parser.parse(args: _*) + + def checkArgs() = { + CommandLineUtils.checkRequiredArgs(parser, options, consumerConfigOpt, producerConfigOpt) + val consumerProps = Utils.loadProps(options.valueOf(consumerConfigOpt)) + + + if (!options.has(includeOpt) && !options.has(whitelistOpt)) { + error("include list must be specified") + sys.exit(1) + } + + if (!consumerProps.containsKey(ConsumerConfig.PARTITION_ASSIGNMENT_STRATEGY_CONFIG)) + System.err.println("WARNING: The default partition assignment strategy of the mirror maker will " + + "change from 'range' to 'roundrobin' in an upcoming release (so that better load balancing can be achieved). If " + + "you prefer to make this switch in advance of that release add the following to the corresponding " + + "config: 'partition.assignment.strategy=org.apache.kafka.clients.consumer.RoundRobinAssignor'") + + abortOnSendFailure = options.valueOf(abortOnSendFailureOpt).toBoolean + offsetCommitIntervalMs = options.valueOf(offsetCommitIntervalMsOpt).intValue() + val numStreams = options.valueOf(numStreamsOpt).intValue() + + Exit.addShutdownHook("MirrorMakerShutdownHook", cleanShutdown()) + + // create producer + val producerProps = Utils.loadProps(options.valueOf(producerConfigOpt)) + val sync = producerProps.getProperty("producer.type", "async").equals("sync") + producerProps.remove("producer.type") + // Defaults to no data loss settings. + maybeSetDefaultProperty(producerProps, ProducerConfig.DELIVERY_TIMEOUT_MS_CONFIG, Int.MaxValue.toString) + maybeSetDefaultProperty(producerProps, ProducerConfig.MAX_BLOCK_MS_CONFIG, Long.MaxValue.toString) + maybeSetDefaultProperty(producerProps, ProducerConfig.ACKS_CONFIG, "all") + maybeSetDefaultProperty(producerProps, ProducerConfig.MAX_IN_FLIGHT_REQUESTS_PER_CONNECTION, "1") + // Always set producer key and value serializer to ByteArraySerializer. + producerProps.setProperty(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG, classOf[ByteArraySerializer].getName) + producerProps.setProperty(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG, classOf[ByteArraySerializer].getName) + producer = new MirrorMakerProducer(sync, producerProps) + + // Create consumers + val customRebalanceListener: Option[ConsumerRebalanceListener] = { + val customRebalanceListenerClass = options.valueOf(consumerRebalanceListenerOpt) + if (customRebalanceListenerClass != null) { + val rebalanceListenerArgs = options.valueOf(rebalanceListenerArgsOpt) + if (rebalanceListenerArgs != null) + Some(CoreUtils.createObject[ConsumerRebalanceListener](customRebalanceListenerClass, rebalanceListenerArgs)) + else + Some(CoreUtils.createObject[ConsumerRebalanceListener](customRebalanceListenerClass)) + } else { + None + } + } + + val includedTopicsValue = if (options.has(includeOpt)) + Option(options.valueOf(includeOpt)) + else + Option(options.valueOf(whitelistOpt)) + + val mirrorMakerConsumers = createConsumers( + numStreams, + consumerProps, + customRebalanceListener, + includedTopicsValue) + + // Create mirror maker threads. 
+ mirrorMakerThreads = (0 until numStreams) map (i => + new MirrorMakerThread(mirrorMakerConsumers(i), i)) + + // Create and initialize message handler + val customMessageHandlerClass = options.valueOf(messageHandlerOpt) + val messageHandlerArgs = options.valueOf(messageHandlerArgsOpt) + messageHandler = { + if (customMessageHandlerClass != null) { + if (messageHandlerArgs != null) + CoreUtils.createObject[MirrorMakerMessageHandler](customMessageHandlerClass, messageHandlerArgs) + else + CoreUtils.createObject[MirrorMakerMessageHandler](customMessageHandlerClass) + } else { + defaultMirrorMakerMessageHandler + } + } + } + } +} diff --git a/core/src/main/scala/kafka/tools/PerfConfig.scala b/core/src/main/scala/kafka/tools/PerfConfig.scala new file mode 100644 index 0000000..836163c --- /dev/null +++ b/core/src/main/scala/kafka/tools/PerfConfig.scala @@ -0,0 +1,40 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. +*/ + +package kafka.tools + +import kafka.utils.CommandDefaultOptions + + +class PerfConfig(args: Array[String]) extends CommandDefaultOptions(args) { + val numMessagesOpt = parser.accepts("messages", "REQUIRED: The number of messages to send or consume") + .withRequiredArg + .describedAs("count") + .ofType(classOf[java.lang.Long]) + val reportingIntervalOpt = parser.accepts("reporting-interval", "Interval in milliseconds at which to print progress info.") + .withRequiredArg + .describedAs("interval_ms") + .ofType(classOf[java.lang.Integer]) + .defaultsTo(5000) + val dateFormatOpt = parser.accepts("date-format", "The date format to use for formatting the time field. " + + "See java.text.SimpleDateFormat for options.") + .withRequiredArg + .describedAs("date format") + .ofType(classOf[String]) + .defaultsTo("yyyy-MM-dd HH:mm:ss:SSS") + val hideHeaderOpt = parser.accepts("hide-header", "If set, skips printing the header for the stats ") +} diff --git a/core/src/main/scala/kafka/tools/ReplicaVerificationTool.scala b/core/src/main/scala/kafka/tools/ReplicaVerificationTool.scala new file mode 100644 index 0000000..2c74550 --- /dev/null +++ b/core/src/main/scala/kafka/tools/ReplicaVerificationTool.scala @@ -0,0 +1,521 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.tools + +import joptsimple.OptionParser +import kafka.api._ +import kafka.utils.{IncludeList, _} +import org.apache.kafka.clients._ +import org.apache.kafka.clients.admin.{Admin, ListTopicsOptions, TopicDescription} +import org.apache.kafka.clients.consumer.{ConsumerConfig, KafkaConsumer} +import org.apache.kafka.common.message.FetchResponseData +import org.apache.kafka.common.metrics.Metrics +import org.apache.kafka.common.network.{NetworkReceive, Selectable, Selector} +import org.apache.kafka.common.protocol.{ApiKeys, Errors} +import org.apache.kafka.common.requests.AbstractRequest.Builder +import org.apache.kafka.common.requests.{AbstractRequest, FetchResponse, ListOffsetsRequest, FetchRequest => JFetchRequest} +import org.apache.kafka.common.serialization.StringDeserializer +import org.apache.kafka.common.utils.{LogContext, Time} +import org.apache.kafka.common.{Node, TopicPartition, Uuid} +import java.net.SocketTimeoutException +import java.text.SimpleDateFormat +import java.util +import java.util.concurrent.CountDownLatch +import java.util.concurrent.atomic.{AtomicInteger, AtomicReference} +import java.util.regex.{Pattern, PatternSyntaxException} +import java.util.{Date, Optional, Properties} + +import scala.collection.Seq +import scala.jdk.CollectionConverters._ + +/** + * For verifying the consistency among replicas. + * + * 1. start a fetcher on every broker. + * 2. each fetcher does the following + * 2.1 issues fetch request + * 2.2 puts the fetched result in a shared buffer + * 2.3 waits for all other fetchers to finish step 2.2 + * 2.4 one of the fetchers verifies the consistency of fetched results among replicas + * + * The consistency verification is up to the high watermark. The tool reports the + * max lag between the verified offset and the high watermark among all partitions. + * + * If a broker goes down, the verification of the partitions on that broker is delayed + * until the broker is up again. + * + * Caveats: + * 1. The tools needs all brokers to be up at startup time. + * 2. The tool doesn't handle out of range offsets. 
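+ *
+ * Illustrative invocation (assuming the usual bin/kafka-replica-verification.sh wrapper; the flags map to the options defined below):
+ *   bin/kafka-replica-verification.sh --broker-list localhost:9092 --topics-include '^my-topic$' --report-interval-ms 5000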
+ */ + +object ReplicaVerificationTool extends Logging { + val clientId = "replicaVerificationTool" + val dateFormatString = "yyyy-MM-dd HH:mm:ss,SSS" + val dateFormat = new SimpleDateFormat(dateFormatString) + + def getCurrentTimeString() = { + ReplicaVerificationTool.dateFormat.format(new Date(Time.SYSTEM.milliseconds)) + } + + def main(args: Array[String]): Unit = { + val parser = new OptionParser(false) + val brokerListOpt = parser.accepts("broker-list", "REQUIRED: The list of hostname and port of the server to connect to.") + .withRequiredArg + .describedAs("hostname:port,...,hostname:port") + .ofType(classOf[String]) + val fetchSizeOpt = parser.accepts("fetch-size", "The fetch size of each request.") + .withRequiredArg + .describedAs("bytes") + .ofType(classOf[java.lang.Integer]) + .defaultsTo(ConsumerConfig.DEFAULT_MAX_PARTITION_FETCH_BYTES) + val maxWaitMsOpt = parser.accepts("max-wait-ms", "The max amount of time each fetch request waits.") + .withRequiredArg + .describedAs("ms") + .ofType(classOf[java.lang.Integer]) + .defaultsTo(1000) + val topicWhiteListOpt = parser.accepts("topic-white-list", "DEPRECATED use --topics-include instead; ignored if --topics-include specified. List of topics to verify replica consistency. Defaults to '.*' (all topics)") + .withRequiredArg + .describedAs("Java regex (String)") + .ofType(classOf[String]) + .defaultsTo(".*") + val topicsIncludeOpt = parser.accepts("topics-include", "List of topics to verify replica consistency. Defaults to '.*' (all topics)") + .withRequiredArg + .describedAs("Java regex (String)") + .ofType(classOf[String]) + .defaultsTo(".*") + val initialOffsetTimeOpt = parser.accepts("time", "Timestamp for getting the initial offsets.") + .withRequiredArg + .describedAs("timestamp/-1(latest)/-2(earliest)") + .ofType(classOf[java.lang.Long]) + .defaultsTo(-1L) + val reportIntervalOpt = parser.accepts("report-interval-ms", "The reporting interval.") + .withRequiredArg + .describedAs("ms") + .ofType(classOf[java.lang.Long]) + .defaultsTo(30 * 1000L) + val helpOpt = parser.accepts("help", "Print usage information.").forHelp() + val versionOpt = parser.accepts("version", "Print version information and exit.").forHelp() + + val options = parser.parse(args: _*) + + if (args.length == 0 || options.has(helpOpt)) { + CommandLineUtils.printUsageAndDie(parser, "Validate that all replicas for a set of topics have the same data.") + } + + if (options.has(versionOpt)) { + CommandLineUtils.printVersionAndDie() + } + CommandLineUtils.checkRequiredArgs(parser, options, brokerListOpt) + + val regex = if (options.has(topicsIncludeOpt)) + options.valueOf(topicsIncludeOpt) + else + options.valueOf(topicWhiteListOpt) + + val topicsIncludeFilter = new IncludeList(regex) + + try Pattern.compile(regex) + catch { + case _: PatternSyntaxException => + throw new RuntimeException(s"$regex is an invalid regex.") + } + + val fetchSize = options.valueOf(fetchSizeOpt).intValue + val maxWaitMs = options.valueOf(maxWaitMsOpt).intValue + val initialOffsetTime = options.valueOf(initialOffsetTimeOpt).longValue + val reportInterval = options.valueOf(reportIntervalOpt).longValue + // getting topic metadata + info("Getting topic metadata...") + val brokerList = options.valueOf(brokerListOpt) + ToolsUtils.validatePortOrDie(parser, brokerList) + + val (topicsMetadata, brokerInfo) = { + val adminClient = createAdminClient(brokerList) + try ((listTopicsMetadata(adminClient), brokerDetails(adminClient))) + finally CoreUtils.swallow(adminClient.close(), this) + } + + val 
topicIds = topicsMetadata.map( metadata => metadata.name() -> metadata.topicId()).toMap + + val filteredTopicMetadata = topicsMetadata.filter { topicMetaData => + topicsIncludeFilter.isTopicAllowed(topicMetaData.name, excludeInternalTopics = false) + } + + if (filteredTopicMetadata.isEmpty) { + error(s"No topics found. $topicsIncludeOpt if specified, is either filtering out all topics or there is no topic.") + Exit.exit(1) + } + + val topicPartitionReplicas = filteredTopicMetadata.flatMap { topicMetadata => + topicMetadata.partitions.asScala.flatMap { partitionMetadata => + partitionMetadata.replicas.asScala.map { node => + TopicPartitionReplica(topic = topicMetadata.name, partitionId = partitionMetadata.partition, replicaId = node.id) + } + } + } + debug(s"Selected topic partitions: $topicPartitionReplicas") + val brokerToTopicPartitions = topicPartitionReplicas.groupBy(_.replicaId).map { case (brokerId, partitions) => + brokerId -> partitions.map { partition => new TopicPartition(partition.topic, partition.partitionId) } + } + debug(s"Topic partitions per broker: $brokerToTopicPartitions") + val expectedReplicasPerTopicPartition = topicPartitionReplicas.groupBy { replica => + new TopicPartition(replica.topic, replica.partitionId) + }.map { case (topicAndPartition, replicaSet) => topicAndPartition -> replicaSet.size } + debug(s"Expected replicas per topic partition: $expectedReplicasPerTopicPartition") + + val topicPartitions = filteredTopicMetadata.flatMap { topicMetaData => + topicMetaData.partitions.asScala.map { partitionMetadata => + new TopicPartition(topicMetaData.name, partitionMetadata.partition) + } + } + + val consumerProps = consumerConfig(brokerList) + + val replicaBuffer = new ReplicaBuffer(expectedReplicasPerTopicPartition, + initialOffsets(topicPartitions, consumerProps, initialOffsetTime), + brokerToTopicPartitions.size, + reportInterval) + // create all replica fetcher threads + val verificationBrokerId = brokerToTopicPartitions.head._1 + val counter = new AtomicInteger(0) + val fetcherThreads = brokerToTopicPartitions.map { case (brokerId, topicPartitions) => + new ReplicaFetcher(name = s"ReplicaFetcher-$brokerId", + sourceBroker = brokerInfo(brokerId), + topicPartitions = topicPartitions, + topicIds = topicIds, + replicaBuffer = replicaBuffer, + socketTimeout = 30000, + socketBufferSize = 256000, + fetchSize = fetchSize, + maxWait = maxWaitMs, + minBytes = 1, + doVerification = brokerId == verificationBrokerId, + consumerProps, + fetcherId = counter.incrementAndGet()) + } + + Exit.addShutdownHook("ReplicaVerificationToolShutdownHook", { + info("Stopping all fetchers") + fetcherThreads.foreach(_.shutdown()) + }) + fetcherThreads.foreach(_.start()) + println(s"${ReplicaVerificationTool.getCurrentTimeString()}: verification process is started.") + + } + + private def listTopicsMetadata(adminClient: Admin): Seq[TopicDescription] = { + val topics = adminClient.listTopics(new ListTopicsOptions().listInternal(true)).names.get + adminClient.describeTopics(topics).allTopicNames.get.values.asScala.toBuffer + } + + private def brokerDetails(adminClient: Admin): Map[Int, Node] = { + adminClient.describeCluster.nodes.get.asScala.map(n => (n.id, n)).toMap + } + + private def createAdminClient(brokerUrl: String): Admin = { + val props = new Properties() + props.put(CommonClientConfigs.BOOTSTRAP_SERVERS_CONFIG, brokerUrl) + Admin.create(props) + } + + private def initialOffsets(topicPartitions: Seq[TopicPartition], consumerConfig: Properties, + initialOffsetTime: Long): 
collection.Map[TopicPartition, Long] = { + val consumer = createConsumer(consumerConfig) + try { + if (ListOffsetsRequest.LATEST_TIMESTAMP == initialOffsetTime) + consumer.endOffsets(topicPartitions.asJava).asScala.map { case (k, v) => k -> v.longValue } + else if (ListOffsetsRequest.EARLIEST_TIMESTAMP == initialOffsetTime) + consumer.beginningOffsets(topicPartitions.asJava).asScala.map { case (k, v) => k -> v.longValue } + else { + val timestampsToSearch = topicPartitions.map(tp => tp -> (initialOffsetTime: java.lang.Long)).toMap + consumer.offsetsForTimes(timestampsToSearch.asJava).asScala.map { case (k, v) => k -> v.offset } + } + } finally consumer.close() + } + + private def consumerConfig(brokerUrl: String): Properties = { + val properties = new Properties() + properties.put(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, brokerUrl) + properties.put(ConsumerConfig.GROUP_ID_CONFIG, "ReplicaVerification") + properties.put(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, classOf[StringDeserializer]) + properties.put(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, classOf[StringDeserializer]) + properties + } + + private def createConsumer(consumerConfig: Properties): KafkaConsumer[String, String] = + new KafkaConsumer(consumerConfig) +} + +private case class TopicPartitionReplica(topic: String, partitionId: Int, replicaId: Int) + +private case class MessageInfo(replicaId: Int, offset: Long, nextOffset: Long, checksum: Long) + +private class ReplicaBuffer(expectedReplicasPerTopicPartition: collection.Map[TopicPartition, Int], + initialOffsets: collection.Map[TopicPartition, Long], + expectedNumFetchers: Int, + reportInterval: Long) extends Logging { + private val fetchOffsetMap = new Pool[TopicPartition, Long] + private val recordsCache = new Pool[TopicPartition, Pool[Int, FetchResponseData.PartitionData]] + private val fetcherBarrier = new AtomicReference(new CountDownLatch(expectedNumFetchers)) + private val verificationBarrier = new AtomicReference(new CountDownLatch(1)) + @volatile private var lastReportTime = Time.SYSTEM.milliseconds + private var maxLag: Long = -1L + private var offsetWithMaxLag: Long = -1L + private var maxLagTopicAndPartition: TopicPartition = null + initialize() + + def createNewFetcherBarrier(): Unit = { + fetcherBarrier.set(new CountDownLatch(expectedNumFetchers)) + } + + def getFetcherBarrier() = fetcherBarrier.get + + def createNewVerificationBarrier(): Unit = { + verificationBarrier.set(new CountDownLatch(1)) + } + + def getVerificationBarrier() = verificationBarrier.get + + private def initialize(): Unit = { + for (topicPartition <- expectedReplicasPerTopicPartition.keySet) + recordsCache.put(topicPartition, new Pool[Int, FetchResponseData.PartitionData]) + setInitialOffsets() + } + + + private def setInitialOffsets(): Unit = { + for ((tp, offset) <- initialOffsets) + fetchOffsetMap.put(tp, offset) + } + + def addFetchedData(topicAndPartition: TopicPartition, replicaId: Int, partitionData: FetchResponseData.PartitionData): Unit = { + recordsCache.get(topicAndPartition).put(replicaId, partitionData) + } + + def getOffset(topicAndPartition: TopicPartition) = { + fetchOffsetMap.get(topicAndPartition) + } + + def verifyCheckSum(println: String => Unit): Unit = { + debug("Begin verification") + maxLag = -1L + for ((topicPartition, fetchResponsePerReplica) <- recordsCache) { + debug(s"Verifying $topicPartition") + assert(fetchResponsePerReplica.size == expectedReplicasPerTopicPartition(topicPartition), + "fetched " + fetchResponsePerReplica.size + " replicas for " + 
topicPartition + ", but expected " + + expectedReplicasPerTopicPartition(topicPartition) + " replicas") + val recordBatchIteratorMap = fetchResponsePerReplica.map { case (replicaId, fetchResponse) => + replicaId -> FetchResponse.recordsOrFail(fetchResponse).batches.iterator + } + val maxHw = fetchResponsePerReplica.values.map(_.highWatermark).max + + // Iterate one message at a time from every replica, until high watermark is reached. + var isMessageInAllReplicas = true + while (isMessageInAllReplicas) { + var messageInfoFromFirstReplicaOpt: Option[MessageInfo] = None + for ((replicaId, recordBatchIterator) <- recordBatchIteratorMap) { + try { + if (recordBatchIterator.hasNext) { + val batch = recordBatchIterator.next() + + // only verify up to the high watermark + if (batch.lastOffset >= fetchResponsePerReplica.get(replicaId).highWatermark) + isMessageInAllReplicas = false + else { + messageInfoFromFirstReplicaOpt match { + case None => + messageInfoFromFirstReplicaOpt = Some( + MessageInfo(replicaId, batch.lastOffset, batch.nextOffset, batch.checksum)) + case Some(messageInfoFromFirstReplica) => + if (messageInfoFromFirstReplica.offset != batch.lastOffset) { + println(ReplicaVerificationTool.getCurrentTimeString() + ": partition " + topicPartition + + ": replica " + messageInfoFromFirstReplica.replicaId + "'s offset " + + messageInfoFromFirstReplica.offset + " doesn't match replica " + + replicaId + "'s offset " + batch.lastOffset) + Exit.exit(1) + } + if (messageInfoFromFirstReplica.checksum != batch.checksum) + println(ReplicaVerificationTool.getCurrentTimeString() + ": partition " + + topicPartition + " has unmatched checksum at offset " + batch.lastOffset + "; replica " + + messageInfoFromFirstReplica.replicaId + "'s checksum " + messageInfoFromFirstReplica.checksum + + "; replica " + replicaId + "'s checksum " + batch.checksum) + } + } + } else + isMessageInAllReplicas = false + } catch { + case t: Throwable => + throw new RuntimeException("Error in processing replica %d in partition %s at offset %d." 
+ .format(replicaId, topicPartition, fetchOffsetMap.get(topicPartition)), t) + } + } + if (isMessageInAllReplicas) { + val nextOffset = messageInfoFromFirstReplicaOpt.get.nextOffset + fetchOffsetMap.put(topicPartition, nextOffset) + debug(s"${expectedReplicasPerTopicPartition(topicPartition)} replicas match at offset " + + s"$nextOffset for $topicPartition") + } + } + if (maxHw - fetchOffsetMap.get(topicPartition) > maxLag) { + offsetWithMaxLag = fetchOffsetMap.get(topicPartition) + maxLag = maxHw - offsetWithMaxLag + maxLagTopicAndPartition = topicPartition + } + fetchResponsePerReplica.clear() + } + val currentTimeMs = Time.SYSTEM.milliseconds + if (currentTimeMs - lastReportTime > reportInterval) { + println(ReplicaVerificationTool.dateFormat.format(new Date(currentTimeMs)) + ": max lag is " + + maxLag + " for partition " + maxLagTopicAndPartition + " at offset " + offsetWithMaxLag + + " among " + recordsCache.size + " partitions") + lastReportTime = currentTimeMs + } + } +} + +private class ReplicaFetcher(name: String, sourceBroker: Node, topicPartitions: Iterable[TopicPartition], + topicIds: Map[String, Uuid], replicaBuffer: ReplicaBuffer, socketTimeout: Int, socketBufferSize: Int, + fetchSize: Int, maxWait: Int, minBytes: Int, doVerification: Boolean, consumerConfig: Properties, + fetcherId: Int) + extends ShutdownableThread(name) { + + private val fetchEndpoint = new ReplicaFetcherBlockingSend(sourceBroker, new ConsumerConfig(consumerConfig), new Metrics(), Time.SYSTEM, fetcherId, + s"broker-${Request.DebuggingConsumerId}-fetcher-$fetcherId") + + private val topicNames = topicIds.map(_.swap) + + override def doWork(): Unit = { + + val fetcherBarrier = replicaBuffer.getFetcherBarrier() + val verificationBarrier = replicaBuffer.getVerificationBarrier() + + val requestMap = new util.LinkedHashMap[TopicPartition, JFetchRequest.PartitionData] + for (topicPartition <- topicPartitions) + requestMap.put(topicPartition, new JFetchRequest.PartitionData(topicIds.getOrElse(topicPartition.topic, Uuid.ZERO_UUID), replicaBuffer.getOffset(topicPartition), + 0L, fetchSize, Optional.empty())) + + val fetchRequestBuilder = JFetchRequest.Builder. 
+ forReplica(ApiKeys.FETCH.latestVersion, Request.DebuggingConsumerId, maxWait, minBytes, requestMap) + + debug("Issuing fetch request ") + + var fetchResponse: FetchResponse = null + try { + val clientResponse = fetchEndpoint.sendRequest(fetchRequestBuilder) + fetchResponse = clientResponse.responseBody.asInstanceOf[FetchResponse] + } catch { + case t: Throwable => + if (!isRunning) + throw t + } + + if (fetchResponse != null) { + fetchResponse.responseData(topicNames.asJava, ApiKeys.FETCH.latestVersion()).forEach { (tp, partitionData) => + replicaBuffer.addFetchedData(tp, sourceBroker.id, partitionData) + } + } else { + for (topicAndPartition <- topicPartitions) + replicaBuffer.addFetchedData(topicAndPartition, sourceBroker.id, FetchResponse.partitionResponse(topicAndPartition.partition, Errors.NONE)) + } + + fetcherBarrier.countDown() + debug("Done fetching") + + // wait for all fetchers to finish + fetcherBarrier.await() + debug("Ready for verification") + + // one of the fetchers will do the verification + if (doVerification) { + debug("Do verification") + replicaBuffer.verifyCheckSum(println) + replicaBuffer.createNewFetcherBarrier() + replicaBuffer.createNewVerificationBarrier() + debug("Created new barrier") + verificationBarrier.countDown() + } + + verificationBarrier.await() + debug("Done verification") + } +} + +private class ReplicaFetcherBlockingSend(sourceNode: Node, + consumerConfig: ConsumerConfig, + metrics: Metrics, + time: Time, + fetcherId: Int, + clientId: String) { + + private val socketTimeout: Int = consumerConfig.getInt(ConsumerConfig.REQUEST_TIMEOUT_MS_CONFIG) + + private val networkClient = { + val logContext = new LogContext() + val channelBuilder = org.apache.kafka.clients.ClientUtils.createChannelBuilder(consumerConfig, time, logContext) + val selector = new Selector( + NetworkReceive.UNLIMITED, + consumerConfig.getLong(ConsumerConfig.CONNECTIONS_MAX_IDLE_MS_CONFIG), + metrics, + time, + "replica-fetcher", + Map("broker-id" -> sourceNode.id.toString, "fetcher-id" -> fetcherId.toString).asJava, + false, + channelBuilder, + logContext + ) + new NetworkClient( + selector, + new ManualMetadataUpdater(), + clientId, + 1, + 0, + 0, + Selectable.USE_DEFAULT_BUFFER_SIZE, + consumerConfig.getInt(ConsumerConfig.RECEIVE_BUFFER_CONFIG), + consumerConfig.getInt(ConsumerConfig.REQUEST_TIMEOUT_MS_CONFIG), + consumerConfig.getLong(ConsumerConfig.SOCKET_CONNECTION_SETUP_TIMEOUT_MS_CONFIG), + consumerConfig.getLong(ConsumerConfig.SOCKET_CONNECTION_SETUP_TIMEOUT_MAX_MS_CONFIG), + time, + false, + new ApiVersions, + logContext + ) + } + + def sendRequest(requestBuilder: Builder[_ <: AbstractRequest]): ClientResponse = { + try { + if (!NetworkClientUtils.awaitReady(networkClient, sourceNode, time, socketTimeout)) + throw new SocketTimeoutException(s"Failed to connect within $socketTimeout ms") + else { + val clientRequest = networkClient.newClientRequest(sourceNode.id.toString, requestBuilder, + time.milliseconds(), true) + NetworkClientUtils.sendAndReceive(networkClient, clientRequest, time) + } + } + catch { + case e: Throwable => + networkClient.close(sourceNode.id.toString) + throw e + } + } + + def close(): Unit = { + networkClient.close() + } +} diff --git a/core/src/main/scala/kafka/tools/StateChangeLogMerger.scala b/core/src/main/scala/kafka/tools/StateChangeLogMerger.scala new file mode 100755 index 0000000..de711e5 --- /dev/null +++ b/core/src/main/scala/kafka/tools/StateChangeLogMerger.scala @@ -0,0 +1,196 @@ +/** + * Licensed to the Apache Software Foundation (ASF) 
under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.tools + +import joptsimple._ + +import scala.util.matching.Regex +import collection.mutable +import java.util.Date +import java.text.SimpleDateFormat + +import kafka.utils.{CommandLineUtils, CoreUtils, Exit, Logging} +import java.io.{BufferedOutputStream, OutputStream} +import java.nio.charset.StandardCharsets + +import org.apache.kafka.common.internals.Topic + +/** + * A utility that merges the state change logs (possibly obtained from different brokers and over multiple days). + * + * This utility expects at least one of the following two arguments - + * 1. A list of state change log files + * 2. A regex to specify state change log file names. + * + * This utility optionally also accepts the following arguments - + * 1. The topic whose state change logs should be merged + * 2. A list of partitions whose state change logs should be merged (can be specified only when the topic argument + * is explicitly specified) + * 3. Start time from when the logs should be merged + * 4. End time until when the logs should be merged + */ + +object StateChangeLogMerger extends Logging { + + val dateFormatString = "yyyy-MM-dd HH:mm:ss,SSS" + val topicPartitionRegex = new Regex("\\[(" + Topic.LEGAL_CHARS + "+),( )*([0-9]+)\\]") + val dateRegex = new Regex("[0-9]{4}-[0-9]{2}-[0-9]{2} [0-9]{2}:[0-9]{2}:[0-9]{2},[0-9]{3}") + val dateFormat = new SimpleDateFormat(dateFormatString) + var files: List[String] = List() + var topic: String = null + var partitions: List[Int] = List() + var startDate: Date = null + var endDate: Date = null + + def main(args: Array[String]): Unit = { + + // Parse input arguments. 
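 + // A minimal example invocation (a sketch; assumes the class is launched via the standard
 + // bin/kafka-run-class.sh script, with placeholder log path, topic, and time range; the
 + // flags are the options defined just below):
 + //   bin/kafka-run-class.sh kafka.tools.StateChangeLogMerger \
 + //     --logs-regex '/tmp/state-change.log*' --topic my-topic \
 + //     --start-time '2021-08-01 00:00:00,000' --end-time '2021-08-02 00:00:00,000'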
+ val parser = new OptionParser(false) + val filesOpt = parser.accepts("logs", "Comma separated list of state change logs or a regex for the log file names") + .withRequiredArg + .describedAs("file1,file2,...") + .ofType(classOf[String]) + val regexOpt = parser.accepts("logs-regex", "Regex to match the state change log files to be merged") + .withRequiredArg + .describedAs("for example: /tmp/state-change.log*") + .ofType(classOf[String]) + val topicOpt = parser.accepts("topic", "The topic whose state change logs should be merged") + .withRequiredArg + .describedAs("topic") + .ofType(classOf[String]) + val partitionsOpt = parser.accepts("partitions", "Comma separated list of partition ids whose state change logs should be merged") + .withRequiredArg + .describedAs("0,1,2,...") + .ofType(classOf[String]) + val startTimeOpt = parser.accepts("start-time", "The earliest timestamp of state change log entries to be merged") + .withRequiredArg + .describedAs("start timestamp in the format " + dateFormat) + .ofType(classOf[String]) + .defaultsTo("0000-00-00 00:00:00,000") + val endTimeOpt = parser.accepts("end-time", "The latest timestamp of state change log entries to be merged") + .withRequiredArg + .describedAs("end timestamp in the format " + dateFormat) + .ofType(classOf[String]) + .defaultsTo("9999-12-31 23:59:59,999") + + if(args.length == 0) + CommandLineUtils.printUsageAndDie(parser, "A tool for merging the log files from several brokers to reconnstruct a unified history of what happened.") + + + val options = parser.parse(args : _*) + if ((!options.has(filesOpt) && !options.has(regexOpt)) || (options.has(filesOpt) && options.has(regexOpt))) { + System.err.println("Provide arguments to exactly one of the two options \"" + filesOpt + "\" or \"" + regexOpt + "\"") + parser.printHelpOn(System.err) + Exit.exit(1) + } + if (options.has(partitionsOpt) && !options.has(topicOpt)) { + System.err.println("The option \"" + topicOpt + "\" needs to be provided an argument when specifying partition ids") + parser.printHelpOn(System.err) + Exit.exit(1) + } + + // Populate data structures. + if (options.has(filesOpt)) { + files :::= options.valueOf(filesOpt).split(",").toList + } else if (options.has(regexOpt)) { + val regex = options.valueOf(regexOpt) + val fileNameIndex = regex.lastIndexOf('/') + 1 + val dirName = if (fileNameIndex == 0) "." else regex.substring(0, fileNameIndex - 1) + val fileNameRegex = new Regex(regex.substring(fileNameIndex)) + files :::= new java.io.File(dirName).listFiles.filter(f => fileNameRegex.findFirstIn(f.getName).isDefined).map(dirName + "/" + _.getName).toList + } + if (options.has(topicOpt)) { + topic = options.valueOf(topicOpt) + } + if (options.has(partitionsOpt)) { + partitions = options.valueOf(partitionsOpt).split(",").toList.map(_.toInt) + val duplicatePartitions = CoreUtils.duplicates(partitions) + if (duplicatePartitions.nonEmpty) { + System.err.println("The list of partitions contains repeated entries: %s".format(duplicatePartitions.mkString(","))) + Exit.exit(1) + } + } + startDate = dateFormat.parse(options.valueOf(startTimeOpt).replace('\"', ' ').trim) + endDate = dateFormat.parse(options.valueOf(endTimeOpt).replace('\"', ' ').trim) + + /** + * n-way merge from m input files: + * 1. Read a line that matches the specified topic/partitions and date range from every input file in a priority queue. + * 2. Take the line from the file with the earliest date and add it to a buffered output stream. + * 3. 
Add another line from the file selected in step 2 in the priority queue. + * 4. Flush the output buffer at the end. (The buffer will also be automatically flushed every K bytes.) + */ + val pqueue = new mutable.PriorityQueue[LineIterator]()(dateBasedOrdering) + val output: OutputStream = new BufferedOutputStream(System.out, 1024*1024) + val lineIterators = files.map(scala.io.Source.fromFile(_).getLines()) + var lines: List[LineIterator] = List() + + for (itr <- lineIterators) { + val lineItr = getNextLine(itr) + if (!lineItr.isEmpty) + lines ::= lineItr + } + if (lines.nonEmpty) pqueue.enqueue(lines:_*) + + while (pqueue.nonEmpty) { + val lineItr = pqueue.dequeue() + output.write((lineItr.line + "\n").getBytes(StandardCharsets.UTF_8)) + val nextLineItr = getNextLine(lineItr.itr) + if (!nextLineItr.isEmpty) + pqueue.enqueue(nextLineItr) + } + + output.flush() + } + + /** + * Returns the next line that matches the specified topic/partitions from the file that has the earliest date + * from the specified date range. + * @param itr Line iterator of a file + * @return (line from a file, line iterator for the same file) + */ + def getNextLine(itr: Iterator[String]): LineIterator = { + while (itr != null && itr.hasNext) { + val nextLine = itr.next() + dateRegex.findFirstIn(nextLine).foreach { d => + val date = dateFormat.parse(d) + if ((date.equals(startDate) || date.after(startDate)) && (date.equals(endDate) || date.before(endDate))) { + topicPartitionRegex.findFirstMatchIn(nextLine).foreach { matcher => + if ((topic == null || topic == matcher.group(1)) && (partitions.isEmpty || partitions.contains(matcher.group(3).toInt))) + return new LineIterator(nextLine, itr) + } + } + } + } + new LineIterator() + } + + class LineIterator(val line: String, val itr: Iterator[String]) { + def this() = this("", null) + def isEmpty = line == "" && itr == null + } + + implicit object dateBasedOrdering extends Ordering[LineIterator] { + def compare(first: LineIterator, second: LineIterator) = { + val firstDate = dateRegex.findFirstIn(first.line).get + val secondDate = dateRegex.findFirstIn(second.line).get + secondDate.compareTo(firstDate) + } + } + +} diff --git a/core/src/main/scala/kafka/tools/StorageTool.scala b/core/src/main/scala/kafka/tools/StorageTool.scala new file mode 100644 index 0000000..28377d2 --- /dev/null +++ b/core/src/main/scala/kafka/tools/StorageTool.scala @@ -0,0 +1,238 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package kafka.tools + +import java.io.PrintStream +import java.nio.file.{Files, Paths} + +import kafka.server.{BrokerMetadataCheckpoint, KafkaConfig, MetaProperties, RawMetaProperties} +import kafka.utils.{Exit, Logging} +import net.sourceforge.argparse4j.ArgumentParsers +import net.sourceforge.argparse4j.impl.Arguments.{store, storeTrue} +import org.apache.kafka.common.Uuid +import org.apache.kafka.common.utils.Utils + +import scala.collection.mutable + +object StorageTool extends Logging { + def main(args: Array[String]): Unit = { + try { + val parser = ArgumentParsers. + newArgumentParser("kafka-storage"). + defaultHelp(true). + description("The Kafka storage tool.") + val subparsers = parser.addSubparsers().dest("command") + + val infoParser = subparsers.addParser("info"). + help("Get information about the Kafka log directories on this node.") + val formatParser = subparsers.addParser("format"). + help("Format the Kafka log directories on this node.") + subparsers.addParser("random-uuid").help("Print a random UUID.") + List(infoParser, formatParser).foreach(parser => { + parser.addArgument("--config", "-c"). + action(store()). + required(true). + help("The Kafka configuration file to use.") + }) + formatParser.addArgument("--cluster-id", "-t"). + action(store()). + required(true). + help("The cluster ID to use.") + formatParser.addArgument("--ignore-formatted", "-g"). + action(storeTrue()) + + val namespace = parser.parseArgsOrFail(args) + val command = namespace.getString("command") + val config = Option(namespace.getString("config")).flatMap( + p => Some(new KafkaConfig(Utils.loadProps(p)))) + + command match { + case "info" => + val directories = configToLogDirectories(config.get) + val selfManagedMode = configToSelfManagedMode(config.get) + Exit.exit(infoCommand(System.out, selfManagedMode, directories)) + + case "format" => + val directories = configToLogDirectories(config.get) + val clusterId = namespace.getString("cluster_id") + val metaProperties = buildMetadataProperties(clusterId, config.get) + val ignoreFormatted = namespace.getBoolean("ignore_formatted") + if (!configToSelfManagedMode(config.get)) { + throw new TerseFailure("The kafka configuration file appears to be for " + + "a legacy cluster. 
Formatting is only supported for clusters in KRaft mode.") + } + Exit.exit(formatCommand(System.out, directories, metaProperties, ignoreFormatted )) + + case "random-uuid" => + System.out.println(Uuid.randomUuid) + Exit.exit(0) + + case _ => + throw new RuntimeException(s"Unknown command $command") + } + } catch { + case e: TerseFailure => + System.err.println(e.getMessage) + System.exit(1) + } + } + + def configToLogDirectories(config: KafkaConfig): Seq[String] = { + val directories = new mutable.TreeSet[String] + directories ++= config.logDirs + Option(config.metadataLogDir).foreach(directories.add) + directories.toSeq + } + + def configToSelfManagedMode(config: KafkaConfig): Boolean = config.processRoles.nonEmpty + + def infoCommand(stream: PrintStream, selfManagedMode: Boolean, directories: Seq[String]): Int = { + val problems = new mutable.ArrayBuffer[String] + val foundDirectories = new mutable.ArrayBuffer[String] + var prevMetadata: Option[RawMetaProperties] = None + directories.sorted.foreach(directory => { + val directoryPath = Paths.get(directory) + if (!Files.isDirectory(directoryPath)) { + if (!Files.exists(directoryPath)) { + problems += s"$directoryPath does not exist" + } else { + problems += s"$directoryPath is not a directory" + } + } else { + foundDirectories += directoryPath.toString + val metaPath = directoryPath.resolve("meta.properties") + if (!Files.exists(metaPath)) { + problems += s"$directoryPath is not formatted." + } else { + val properties = Utils.loadProps(metaPath.toString) + val rawMetaProperties = new RawMetaProperties(properties) + + val curMetadata = rawMetaProperties.version match { + case 0 | 1 => Some(rawMetaProperties) + case v => + problems += s"Unsupported version for $metaPath: $v" + None + } + + if (prevMetadata.isEmpty) { + prevMetadata = curMetadata + } else { + if (!prevMetadata.get.equals(curMetadata.get)) { + problems += s"Metadata for $metaPath was ${curMetadata.get}, " + + s"but other directories featured ${prevMetadata.get}" + } + } + } + } + }) + + prevMetadata.foreach { prev => + if (selfManagedMode) { + if (prev.version == 0) { + problems += "The kafka configuration file appears to be for a cluster in KRaft mode, but " + + "the directories are formatted for legacy mode." + } + } else if (prev.version == 1) { + problems += "The kafka configuration file appears to be for a legacy cluster, but " + + "the directories are formatted for a cluster in KRaft mode." 
+ } + } + + if (directories.isEmpty) { + stream.println("No directories specified.") + 0 + } else { + if (foundDirectories.nonEmpty) { + if (foundDirectories.size == 1) { + stream.println("Found log directory:") + } else { + stream.println("Found log directories:") + } + foundDirectories.foreach(d => stream.println(" %s".format(d))) + stream.println("") + } + + prevMetadata.foreach { prev => + stream.println(s"Found metadata: ${prev}") + stream.println("") + } + + if (problems.nonEmpty) { + if (problems.size == 1) { + stream.println("Found problem:") + } else { + stream.println("Found problems:") + } + problems.foreach(d => stream.println(" %s".format(d))) + stream.println("") + 1 + } else { + 0 + } + } + } + + def buildMetadataProperties( + clusterIdStr: String, + config: KafkaConfig + ): MetaProperties = { + val effectiveClusterId = try { + Uuid.fromString(clusterIdStr) + } catch { + case e: Throwable => throw new TerseFailure(s"Cluster ID string $clusterIdStr " + + s"does not appear to be a valid UUID: ${e.getMessage}") + } + require(config.nodeId >= 0, s"The node.id must be set to a non-negative integer.") + new MetaProperties(effectiveClusterId.toString, config.nodeId) + } + + def formatCommand(stream: PrintStream, + directories: Seq[String], + metaProperties: MetaProperties, + ignoreFormatted: Boolean): Int = { + if (directories.isEmpty) { + throw new TerseFailure("No log directories found in the configuration.") + } + val unformattedDirectories = directories.filter(directory => { + if (!Files.isDirectory(Paths.get(directory)) || !Files.exists(Paths.get(directory, "meta.properties"))) { + true + } else if (!ignoreFormatted) { + throw new TerseFailure(s"Log directory ${directory} is already formatted. " + + "Use --ignore-formatted to ignore this directory and format the others.") + } else { + false + } + }) + if (unformattedDirectories.isEmpty) { + stream.println("All of the log directories are already formatted.") + } + unformattedDirectories.foreach(directory => { + try { + Files.createDirectories(Paths.get(directory)) + } catch { + case e: Throwable => throw new TerseFailure(s"Unable to create storage " + + s"directory ${directory}: ${e.getMessage}") + } + val metaPropertiesPath = Paths.get(directory, "meta.properties") + val checkpoint = new BrokerMetadataCheckpoint(metaPropertiesPath.toFile) + checkpoint.write(metaProperties.toProperties) + stream.println(s"Formatting ${directory}") + }) + 0 + } +} diff --git a/core/src/main/scala/kafka/tools/StreamsResetter.java b/core/src/main/scala/kafka/tools/StreamsResetter.java new file mode 100644 index 0000000..3b7e1cb --- /dev/null +++ b/core/src/main/scala/kafka/tools/StreamsResetter.java @@ -0,0 +1,696 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package kafka.tools; + +import joptsimple.OptionException; +import joptsimple.OptionParser; +import joptsimple.OptionSet; +import joptsimple.OptionSpec; +import joptsimple.OptionSpecBuilder; +import kafka.utils.CommandLineUtils; +import org.apache.kafka.clients.admin.Admin; +import org.apache.kafka.clients.admin.DeleteTopicsResult; +import org.apache.kafka.clients.admin.DescribeConsumerGroupsOptions; +import org.apache.kafka.clients.admin.DescribeConsumerGroupsResult; +import org.apache.kafka.clients.admin.MemberDescription; +import org.apache.kafka.clients.admin.RemoveMembersFromConsumerGroupOptions; +import org.apache.kafka.clients.consumer.Consumer; +import org.apache.kafka.clients.consumer.ConsumerConfig; +import org.apache.kafka.clients.consumer.KafkaConsumer; +import org.apache.kafka.clients.consumer.OffsetAndTimestamp; +import org.apache.kafka.common.KafkaFuture; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.annotation.InterfaceStability; +import org.apache.kafka.common.requests.ListOffsetsResponse; +import org.apache.kafka.common.serialization.ByteArrayDeserializer; +import org.apache.kafka.common.utils.Exit; +import org.apache.kafka.common.utils.Utils; +import scala.collection.JavaConverters; + +import java.io.IOException; +import java.text.ParseException; +import java.time.Duration; +import java.time.Instant; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Properties; +import java.util.Set; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; + +/** + * {@link StreamsResetter} resets the processing state of a Kafka Streams application so that, for example, + * you can reprocess its input from scratch. + *

 + * This class is not part of public API. For backward compatibility, + * use the provided script in "bin/" instead of calling this class directly from your code. + *
 + * Resetting the processing state of an application includes the following actions: + * 1. setting the application's consumer offsets for input and internal topics to zero + * 2. skipping over all intermediate user topics (i.e., "seekToEnd" for consumers of intermediate topics) + * 3. deleting any topics created internally by Kafka Streams for this application + *
 + * Only use this tool if no application instance is running. + * Otherwise, the application will get into an invalid state and crash or produce wrong results. + *
 + * If you run multiple application instances, running this tool once is sufficient. + * However, you need to call {@code KafkaStreams#cleanUp()} before re-starting any instance + * (to clean the local state store directory). + * Otherwise, your application is in an invalid state. + *
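 + * A minimal example invocation (a sketch; assumes the standard bin/ wrapper script, with a + * placeholder application id, topic, and broker address; the flags are the options defined + * in parseArguments() below): + *
 + *   bin/kafka-streams-application-reset.sh --application-id my-streams-app \
 + *     --bootstrap-servers localhost:9092 --input-topics my-input-topic --dry-run
 + *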

                + * User output topics will not be deleted or modified by this tool. + * If downstream applications consume intermediate or output topics, + * it is the user's responsibility to adjust those applications manually if required. + */ +@InterfaceStability.Unstable +public class StreamsResetter { + private static final int EXIT_CODE_SUCCESS = 0; + private static final int EXIT_CODE_ERROR = 1; + + private static OptionSpec bootstrapServerOption; + private static OptionSpec applicationIdOption; + private static OptionSpec inputTopicsOption; + private static OptionSpec intermediateTopicsOption; + private static OptionSpec internalTopicsOption; + private static OptionSpec toOffsetOption; + private static OptionSpec toDatetimeOption; + private static OptionSpec byDurationOption; + private static OptionSpecBuilder toEarliestOption; + private static OptionSpecBuilder toLatestOption; + private static OptionSpec fromFileOption; + private static OptionSpec shiftByOption; + private static OptionSpecBuilder dryRunOption; + private static OptionSpec helpOption; + private static OptionSpec versionOption; + private static OptionSpec commandConfigOption; + private static OptionSpecBuilder forceOption; + + private final static String USAGE = "This tool helps to quickly reset an application in order to reprocess " + + "its data from scratch.\n" + + "* This tool resets offsets of input topics to the earliest available offset and it skips to the end of " + + "intermediate topics (topics that are input and output topics, e.g., used by deprecated through() method).\n" + + "* This tool deletes the internal topics that were created by Kafka Streams (topics starting with " + + "\"-\").\n" + + "The tool finds these internal topics automatically. If the topics flagged automatically for deletion by " + + "the dry-run are unsuitable, you can specify a subset with the \"--internal-topics\" option.\n" + + "* This tool will not delete output topics (if you want to delete them, you need to do it yourself " + + "with the bin/kafka-topics.sh command).\n" + + "* This tool will not clean up the local state on the stream application instances (the persisted " + + "stores used to cache aggregation results).\n" + + "You need to call KafkaStreams#cleanUp() in your application or manually delete them from the " + + "directory specified by \"state.dir\" configuration (${java.io.tmpdir}/kafka-streams/ by default).\n" + + "* When long session timeout has been configured, active members could take longer to get expired on the " + + "broker thus blocking the reset job to complete. Use the \"--force\" option could remove those left-over " + + "members immediately. Make sure to stop all stream applications when this option is specified " + + "to avoid unexpected disruptions.\n\n" + + "*** Important! You will get wrong output if you don't clean up the local stores after running the " + + "reset tool!\n\n" + + "*** Warning! This tool makes irreversible changes to your application. 
It is strongly recommended that " + + "you run this once with \"--dry-run\" to preview your changes before making them.\n\n"; + + private OptionSet options = null; + private final List allTopics = new LinkedList<>(); + + + public int run(final String[] args) { + return run(args, new Properties()); + } + + public int run(final String[] args, + final Properties config) { + int exitCode; + + Admin adminClient = null; + try { + parseArguments(args); + + final boolean dryRun = options.has(dryRunOption); + + final String groupId = options.valueOf(applicationIdOption); + final Properties properties = new Properties(); + if (options.has(commandConfigOption)) { + properties.putAll(Utils.loadProps(options.valueOf(commandConfigOption))); + } + properties.put(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, options.valueOf(bootstrapServerOption)); + + adminClient = Admin.create(properties); + maybeDeleteActiveConsumers(groupId, adminClient); + + allTopics.clear(); + allTopics.addAll(adminClient.listTopics().names().get(60, TimeUnit.SECONDS)); + + if (dryRun) { + System.out.println("----Dry run displays the actions which will be performed when running Streams Reset Tool----"); + } + + final HashMap consumerConfig = new HashMap<>(config); + consumerConfig.putAll(properties); + exitCode = maybeResetInputAndSeekToEndIntermediateTopicOffsets(consumerConfig, dryRun); + exitCode |= maybeDeleteInternalTopics(adminClient, dryRun); + } catch (final Throwable e) { + exitCode = EXIT_CODE_ERROR; + System.err.println("ERROR: " + e); + e.printStackTrace(System.err); + } finally { + if (adminClient != null) { + adminClient.close(Duration.ofSeconds(60)); + } + } + + return exitCode; + } + + private void maybeDeleteActiveConsumers(final String groupId, + final Admin adminClient) + throws ExecutionException, InterruptedException { + + final DescribeConsumerGroupsResult describeResult = adminClient.describeConsumerGroups( + Collections.singleton(groupId), + new DescribeConsumerGroupsOptions().timeoutMs(10 * 1000)); + final List members = + new ArrayList<>(describeResult.describedGroups().get(groupId).get().members()); + if (!members.isEmpty()) { + if (options.has(forceOption)) { + System.out.println("Force deleting all active members in the group: " + groupId); + adminClient.removeMembersFromConsumerGroup(groupId, new RemoveMembersFromConsumerGroupOptions()).all().get(); + } else { + throw new IllegalStateException("Consumer group '" + groupId + "' is still active " + + "and has following members: " + members + ". " + + "Make sure to stop all running application instances before running the reset tool." + + " You can use option '--force' to remove active members from the group."); + } + } + } + + private void parseArguments(final String[] args) { + final OptionParser optionParser = new OptionParser(false); + applicationIdOption = optionParser.accepts("application-id", "The Kafka Streams application ID (application.id).") + .withRequiredArg() + .ofType(String.class) + .describedAs("id") + .required(); + bootstrapServerOption = optionParser.accepts("bootstrap-servers", "Comma-separated list of broker urls with format: HOST1:PORT1,HOST2:PORT2") + .withRequiredArg() + .ofType(String.class) + .defaultsTo("localhost:9092") + .describedAs("urls"); + inputTopicsOption = optionParser.accepts("input-topics", "Comma-separated list of user input topics. 
For these topics, the tool will reset the offset to the earliest available offset.") + .withRequiredArg() + .ofType(String.class) + .withValuesSeparatedBy(',') + .describedAs("list"); + intermediateTopicsOption = optionParser.accepts("intermediate-topics", "Comma-separated list of intermediate user topics (topics that are input and output topics, e.g., used in the deprecated through() method). For these topics, the tool will skip to the end.") + .withRequiredArg() + .ofType(String.class) + .withValuesSeparatedBy(',') + .describedAs("list"); + internalTopicsOption = optionParser.accepts("internal-topics", "Comma-separated list of " + + "internal topics to delete. Must be a subset of the internal topics marked for deletion by the " + + "default behaviour (do a dry-run without this option to view these topics).") + .withRequiredArg() + .ofType(String.class) + .withValuesSeparatedBy(',') + .describedAs("list"); + toOffsetOption = optionParser.accepts("to-offset", "Reset offsets to a specific offset.") + .withRequiredArg() + .ofType(Long.class); + toDatetimeOption = optionParser.accepts("to-datetime", "Reset offsets to offset from datetime. Format: 'YYYY-MM-DDTHH:mm:SS.sss'") + .withRequiredArg() + .ofType(String.class); + byDurationOption = optionParser.accepts("by-duration", "Reset offsets to offset by duration from current timestamp. Format: 'PnDTnHnMnS'") + .withRequiredArg() + .ofType(String.class); + toEarliestOption = optionParser.accepts("to-earliest", "Reset offsets to earliest offset."); + toLatestOption = optionParser.accepts("to-latest", "Reset offsets to latest offset."); + fromFileOption = optionParser.accepts("from-file", "Reset offsets to values defined in CSV file.") + .withRequiredArg() + .ofType(String.class); + shiftByOption = optionParser.accepts("shift-by", "Reset offsets shifting current offset by 'n', where 'n' can be positive or negative") + .withRequiredArg() + .describedAs("number-of-offsets") + .ofType(Long.class); + commandConfigOption = optionParser.accepts("config-file", "Property file containing configs to be passed to admin clients and embedded consumer.") + .withRequiredArg() + .ofType(String.class) + .describedAs("file name"); + forceOption = optionParser.accepts("force", "Force the removal of members of the consumer group (intended to remove stopped members if a long session timeout was used). 
" + + "Make sure to shut down all stream applications when this option is specified to avoid unexpected rebalances."); + + dryRunOption = optionParser.accepts("dry-run", "Display the actions that would be performed without executing the reset commands."); + helpOption = optionParser.accepts("help", "Print usage information.").forHelp(); + versionOption = optionParser.accepts("version", "Print version information and exit.").forHelp(); + + try { + options = optionParser.parse(args); + if (args.length == 0 || options.has(helpOption)) { + CommandLineUtils.printUsageAndDie(optionParser, USAGE); + } + if (options.has(versionOption)) { + CommandLineUtils.printVersionAndDie(); + } + } catch (final OptionException e) { + CommandLineUtils.printUsageAndDie(optionParser, e.getMessage()); + } + + final Set> allScenarioOptions = new HashSet<>(); + allScenarioOptions.add(toOffsetOption); + allScenarioOptions.add(toDatetimeOption); + allScenarioOptions.add(byDurationOption); + allScenarioOptions.add(toEarliestOption); + allScenarioOptions.add(toLatestOption); + allScenarioOptions.add(fromFileOption); + allScenarioOptions.add(shiftByOption); + + checkInvalidArgs(optionParser, options, allScenarioOptions, toOffsetOption); + checkInvalidArgs(optionParser, options, allScenarioOptions, toDatetimeOption); + checkInvalidArgs(optionParser, options, allScenarioOptions, byDurationOption); + checkInvalidArgs(optionParser, options, allScenarioOptions, toEarliestOption); + checkInvalidArgs(optionParser, options, allScenarioOptions, toLatestOption); + checkInvalidArgs(optionParser, options, allScenarioOptions, fromFileOption); + checkInvalidArgs(optionParser, options, allScenarioOptions, shiftByOption); + } + + private void checkInvalidArgs(final OptionParser optionParser, + final OptionSet options, + final Set> allOptions, + final OptionSpec option) { + final Set> invalidOptions = new HashSet<>(allOptions); + invalidOptions.remove(option); + CommandLineUtils.checkInvalidArgs( + optionParser, + options, + option, + JavaConverters.asScalaSetConverter(invalidOptions).asScala()); + } + + private int maybeResetInputAndSeekToEndIntermediateTopicOffsets(final Map consumerConfig, + final boolean dryRun) + throws IOException, ParseException { + + final List inputTopics = options.valuesOf(inputTopicsOption); + final List intermediateTopics = options.valuesOf(intermediateTopicsOption); + int topicNotFound = EXIT_CODE_SUCCESS; + + final List notFoundInputTopics = new ArrayList<>(); + final List notFoundIntermediateTopics = new ArrayList<>(); + + final String groupId = options.valueOf(applicationIdOption); + + if (inputTopics.size() == 0 && intermediateTopics.size() == 0) { + System.out.println("No input or intermediate topics specified. 
Skipping seek."); + return EXIT_CODE_SUCCESS; + } + + if (inputTopics.size() != 0) { + System.out.println("Reset-offsets for input topics " + inputTopics); + } + if (intermediateTopics.size() != 0) { + System.out.println("Seek-to-end for intermediate topics " + intermediateTopics); + } + + final Set topicsToSubscribe = new HashSet<>(inputTopics.size() + intermediateTopics.size()); + + for (final String topic : inputTopics) { + if (!allTopics.contains(topic)) { + notFoundInputTopics.add(topic); + } else { + topicsToSubscribe.add(topic); + } + } + for (final String topic : intermediateTopics) { + if (!allTopics.contains(topic)) { + notFoundIntermediateTopics.add(topic); + } else { + topicsToSubscribe.add(topic); + } + } + + if (!notFoundInputTopics.isEmpty()) { + System.out.println("Following input topics are not found, skipping them"); + for (final String topic : notFoundInputTopics) { + System.out.println("Topic: " + topic); + } + topicNotFound = EXIT_CODE_ERROR; + } + + if (!notFoundIntermediateTopics.isEmpty()) { + System.out.println("Following intermediate topics are not found, skipping them"); + for (final String topic : notFoundIntermediateTopics) { + System.out.println("Topic:" + topic); + } + topicNotFound = EXIT_CODE_ERROR; + } + + // Return early if there are no topics to reset (the consumer will raise an error if we + // try to poll with an empty subscription) + if (topicsToSubscribe.isEmpty()) { + return topicNotFound; + } + + final Properties config = new Properties(); + config.putAll(consumerConfig); + config.setProperty(ConsumerConfig.GROUP_ID_CONFIG, groupId); + config.setProperty(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG, "false"); + + try (final KafkaConsumer client = + new KafkaConsumer<>(config, new ByteArrayDeserializer(), new ByteArrayDeserializer())) { + + final Collection partitions = topicsToSubscribe.stream().map(client::partitionsFor) + .flatMap(Collection::stream) + .map(info -> new TopicPartition(info.topic(), info.partition())) + .collect(Collectors.toList()); + client.assign(partitions); + + final Set inputTopicPartitions = new HashSet<>(); + final Set intermediateTopicPartitions = new HashSet<>(); + + for (final TopicPartition p : partitions) { + final String topic = p.topic(); + if (isInputTopic(topic)) { + inputTopicPartitions.add(p); + } else if (isIntermediateTopic(topic)) { + intermediateTopicPartitions.add(p); + } else { + System.err.println("Skipping invalid partition: " + p); + } + } + + maybeReset(groupId, client, inputTopicPartitions); + + maybeSeekToEnd(groupId, client, intermediateTopicPartitions); + + if (!dryRun) { + for (final TopicPartition p : partitions) { + client.position(p); + } + client.commitSync(); + } + } catch (final IOException | ParseException e) { + System.err.println("ERROR: Resetting offsets failed."); + throw e; + } + System.out.println("Done."); + return topicNotFound; + } + + // visible for testing + public void maybeSeekToEnd(final String groupId, + final Consumer client, + final Set intermediateTopicPartitions) { + if (intermediateTopicPartitions.size() > 0) { + System.out.println("Following intermediate topics offsets will be reset to end (for consumer group " + groupId + ")"); + for (final TopicPartition topicPartition : intermediateTopicPartitions) { + if (allTopics.contains(topicPartition.topic())) { + System.out.println("Topic: " + topicPartition.topic()); + } + } + client.seekToEnd(intermediateTopicPartitions); + } + } + + private void maybeReset(final String groupId, + final Consumer client, + final Set 
inputTopicPartitions) + throws IOException, ParseException { + + if (inputTopicPartitions.size() > 0) { + System.out.println("Following input topics offsets will be reset to (for consumer group " + groupId + ")"); + if (options.has(toOffsetOption)) { + resetOffsetsTo(client, inputTopicPartitions, options.valueOf(toOffsetOption)); + } else if (options.has(toEarliestOption)) { + client.seekToBeginning(inputTopicPartitions); + } else if (options.has(toLatestOption)) { + client.seekToEnd(inputTopicPartitions); + } else if (options.has(shiftByOption)) { + shiftOffsetsBy(client, inputTopicPartitions, options.valueOf(shiftByOption)); + } else if (options.has(toDatetimeOption)) { + final String ts = options.valueOf(toDatetimeOption); + final long timestamp = Utils.getDateTime(ts); + resetToDatetime(client, inputTopicPartitions, timestamp); + } else if (options.has(byDurationOption)) { + final String duration = options.valueOf(byDurationOption); + resetByDuration(client, inputTopicPartitions, Duration.parse(duration)); + } else if (options.has(fromFileOption)) { + final String resetPlanPath = options.valueOf(fromFileOption); + final Map topicPartitionsAndOffset = + getTopicPartitionOffsetFromResetPlan(resetPlanPath); + resetOffsetsFromResetPlan(client, inputTopicPartitions, topicPartitionsAndOffset); + } else { + client.seekToBeginning(inputTopicPartitions); + } + + for (final TopicPartition p : inputTopicPartitions) { + System.out.println("Topic: " + p.topic() + " Partition: " + p.partition() + " Offset: " + client.position(p)); + } + } + } + + // visible for testing + public void resetOffsetsFromResetPlan(final Consumer client, + final Set inputTopicPartitions, + final Map topicPartitionsAndOffset) { + final Map endOffsets = client.endOffsets(inputTopicPartitions); + final Map beginningOffsets = client.beginningOffsets(inputTopicPartitions); + + final Map validatedTopicPartitionsAndOffset = + checkOffsetRange(topicPartitionsAndOffset, beginningOffsets, endOffsets); + + for (final TopicPartition topicPartition : inputTopicPartitions) { + client.seek(topicPartition, validatedTopicPartitionsAndOffset.get(topicPartition)); + } + } + + private Map getTopicPartitionOffsetFromResetPlan(final String resetPlanPath) + throws IOException, ParseException { + + final String resetPlanCsv = Utils.readFileAsString(resetPlanPath); + return parseResetPlan(resetPlanCsv); + } + + private void resetByDuration(final Consumer client, + final Set inputTopicPartitions, + final Duration duration) { + resetToDatetime(client, inputTopicPartitions, Instant.now().minus(duration).toEpochMilli()); + } + + // visible for testing + public void resetToDatetime(final Consumer client, + final Set inputTopicPartitions, + final Long timestamp) { + final Map topicPartitionsAndTimes = new HashMap<>(inputTopicPartitions.size()); + for (final TopicPartition topicPartition : inputTopicPartitions) { + topicPartitionsAndTimes.put(topicPartition, timestamp); + } + + final Map topicPartitionsAndOffset = client.offsetsForTimes(topicPartitionsAndTimes); + + for (final TopicPartition topicPartition : inputTopicPartitions) { + final Optional partitionOffset = Optional.ofNullable(topicPartitionsAndOffset.get(topicPartition)) + .map(OffsetAndTimestamp::offset) + .filter(offset -> offset != ListOffsetsResponse.UNKNOWN_OFFSET); + if (partitionOffset.isPresent()) { + client.seek(topicPartition, partitionOffset.get()); + } else { + client.seekToEnd(Collections.singletonList(topicPartition)); + System.out.println("Partition " + topicPartition.partition() 
+ " from topic " + topicPartition.topic() + + " is empty, without a committed record. Falling back to latest known offset."); + } + } + } + + // visible for testing + public void shiftOffsetsBy(final Consumer client, + final Set inputTopicPartitions, + final long shiftBy) { + final Map endOffsets = client.endOffsets(inputTopicPartitions); + final Map beginningOffsets = client.beginningOffsets(inputTopicPartitions); + + final Map topicPartitionsAndOffset = new HashMap<>(inputTopicPartitions.size()); + for (final TopicPartition topicPartition : inputTopicPartitions) { + final long position = client.position(topicPartition); + final long offset = position + shiftBy; + topicPartitionsAndOffset.put(topicPartition, offset); + } + + final Map validatedTopicPartitionsAndOffset = + checkOffsetRange(topicPartitionsAndOffset, beginningOffsets, endOffsets); + + for (final TopicPartition topicPartition : inputTopicPartitions) { + client.seek(topicPartition, validatedTopicPartitionsAndOffset.get(topicPartition)); + } + } + + // visible for testing + public void resetOffsetsTo(final Consumer client, + final Set inputTopicPartitions, + final Long offset) { + final Map endOffsets = client.endOffsets(inputTopicPartitions); + final Map beginningOffsets = client.beginningOffsets(inputTopicPartitions); + + final Map topicPartitionsAndOffset = new HashMap<>(inputTopicPartitions.size()); + for (final TopicPartition topicPartition : inputTopicPartitions) { + topicPartitionsAndOffset.put(topicPartition, offset); + } + + final Map validatedTopicPartitionsAndOffset = + checkOffsetRange(topicPartitionsAndOffset, beginningOffsets, endOffsets); + + for (final TopicPartition topicPartition : inputTopicPartitions) { + client.seek(topicPartition, validatedTopicPartitionsAndOffset.get(topicPartition)); + } + } + + + private Map parseResetPlan(final String resetPlanCsv) throws ParseException { + final Map topicPartitionAndOffset = new HashMap<>(); + if (resetPlanCsv == null || resetPlanCsv.isEmpty()) { + throw new ParseException("Error parsing reset plan CSV file. It is empty,", 0); + } + + final String[] resetPlanCsvParts = resetPlanCsv.split("\n"); + + for (final String line : resetPlanCsvParts) { + final String[] lineParts = line.split(","); + if (lineParts.length != 3) { + throw new ParseException("Reset plan CSV file is not following the format `TOPIC,PARTITION,OFFSET`.", 0); + } + final String topic = lineParts[0]; + final int partition = Integer.parseInt(lineParts[1]); + final long offset = Long.parseLong(lineParts[2]); + final TopicPartition topicPartition = new TopicPartition(topic, partition); + topicPartitionAndOffset.put(topicPartition, offset); + } + + return topicPartitionAndOffset; + } + + private Map checkOffsetRange(final Map inputTopicPartitionsAndOffset, + final Map beginningOffsets, + final Map endOffsets) { + final Map validatedTopicPartitionsOffsets = new HashMap<>(); + for (final Map.Entry topicPartitionAndOffset : inputTopicPartitionsAndOffset.entrySet()) { + final long endOffset = endOffsets.get(topicPartitionAndOffset.getKey()); + final long offset = topicPartitionAndOffset.getValue(); + if (offset < endOffset) { + final long beginningOffset = beginningOffsets.get(topicPartitionAndOffset.getKey()); + if (offset > beginningOffset) { + validatedTopicPartitionsOffsets.put(topicPartitionAndOffset.getKey(), offset); + } else { + System.out.println("New offset (" + offset + ") is lower than earliest offset. 
Value will be set to " + beginningOffset); + validatedTopicPartitionsOffsets.put(topicPartitionAndOffset.getKey(), beginningOffset); + } + } else { + System.out.println("New offset (" + offset + ") is higher than latest offset. Value will be set to " + endOffset); + validatedTopicPartitionsOffsets.put(topicPartitionAndOffset.getKey(), endOffset); + } + } + return validatedTopicPartitionsOffsets; + } + + private boolean isInputTopic(final String topic) { + return options.valuesOf(inputTopicsOption).contains(topic); + } + + private boolean isIntermediateTopic(final String topic) { + return options.valuesOf(intermediateTopicsOption).contains(topic); + } + + private int maybeDeleteInternalTopics(final Admin adminClient, final boolean dryRun) { + final List inferredInternalTopics = allTopics.stream() + .filter(this::isInferredInternalTopic) + .collect(Collectors.toList()); + final List specifiedInternalTopics = options.valuesOf(internalTopicsOption); + final List topicsToDelete; + + if (!specifiedInternalTopics.isEmpty()) { + if (!inferredInternalTopics.containsAll(specifiedInternalTopics)) { + throw new IllegalArgumentException("Invalid topic specified in the " + + "--internal-topics option. " + + "Ensure that the topics specified are all internal topics. " + + "Do a dry run without the --internal-topics option to see the " + + "list of all internal topics that can be deleted."); + } + + topicsToDelete = specifiedInternalTopics; + System.out.println("Deleting specified internal topics " + topicsToDelete); + } else { + topicsToDelete = inferredInternalTopics; + System.out.println("Deleting inferred internal topics " + topicsToDelete); + } + + if (!dryRun) { + doDelete(topicsToDelete, adminClient); + } + + System.out.println("Done."); + return EXIT_CODE_SUCCESS; + } + + // visible for testing + public void doDelete(final List topicsToDelete, + final Admin adminClient) { + boolean hasDeleteErrors = false; + final DeleteTopicsResult deleteTopicsResult = adminClient.deleteTopics(topicsToDelete); + final Map> results = deleteTopicsResult.topicNameValues(); + + for (final Map.Entry> entry : results.entrySet()) { + try { + entry.getValue().get(30, TimeUnit.SECONDS); + } catch (final Exception e) { + System.err.println("ERROR: deleting topic " + entry.getKey()); + e.printStackTrace(System.err); + hasDeleteErrors = true; + } + } + if (hasDeleteErrors) { + throw new RuntimeException("Encountered an error deleting one or more topics"); + } + } + + private boolean isInferredInternalTopic(final String topicName) { + // Specified input/intermediate topics might be named like internal topics (by chance). + // Even is this is not expected in general, we need to exclude those topics here + // and don't consider them as internal topics even if they follow the same naming schema. + // Cf. 
https://issues.apache.org/jira/browse/KAFKA-7930 + return !isInputTopic(topicName) && !isIntermediateTopic(topicName) && topicName.startsWith(options.valueOf(applicationIdOption) + "-") + && matchesInternalTopicFormat(topicName); + } + + // visible for testing + public static boolean matchesInternalTopicFormat(final String topicName) { + return topicName.endsWith("-changelog") || topicName.endsWith("-repartition") + || topicName.endsWith("-subscription-registration-topic") + || topicName.endsWith("-subscription-response-topic") + || topicName.matches(".+-KTABLE-FK-JOIN-SUBSCRIPTION-REGISTRATION-\\d+-topic") + || topicName.matches(".+-KTABLE-FK-JOIN-SUBSCRIPTION-RESPONSE-\\d+-topic"); + } + + public static void main(final String[] args) { + Exit.exit(new StreamsResetter().run(args)); + } + +} diff --git a/core/src/main/scala/kafka/tools/TerseFailure.scala b/core/src/main/scala/kafka/tools/TerseFailure.scala new file mode 100644 index 0000000..c37b613 --- /dev/null +++ b/core/src/main/scala/kafka/tools/TerseFailure.scala @@ -0,0 +1,30 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.tools + +import org.apache.kafka.common.KafkaException + +/** + * An exception thrown to indicate that the command has failed, but we don't want to + * print a stack trace. + * + * @param message The message to print out before exiting. A stack trace will not + * be printed. + */ +class TerseFailure(message: String) extends KafkaException(message) { +} diff --git a/core/src/main/scala/kafka/tools/TestRaftRequestHandler.scala b/core/src/main/scala/kafka/tools/TestRaftRequestHandler.scala new file mode 100644 index 0000000..a9b471b --- /dev/null +++ b/core/src/main/scala/kafka/tools/TestRaftRequestHandler.scala @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package kafka.tools + +import kafka.network.RequestChannel +import kafka.raft.RaftManager +import kafka.server.{ApiRequestHandler, ApiVersionManager, RequestLocal} +import kafka.utils.Logging +import org.apache.kafka.common.internals.FatalExitError +import org.apache.kafka.common.message.{BeginQuorumEpochResponseData, EndQuorumEpochResponseData, FetchResponseData, FetchSnapshotResponseData, VoteResponseData} +import org.apache.kafka.common.protocol.{ApiKeys, ApiMessage} +import org.apache.kafka.common.requests.{AbstractRequest, AbstractResponse, BeginQuorumEpochResponse, EndQuorumEpochResponse, FetchResponse, FetchSnapshotResponse, VoteResponse} +import org.apache.kafka.common.utils.Time + +/** + * Simple request handler implementation for use by [[TestRaftServer]]. + */ +class TestRaftRequestHandler( + raftManager: RaftManager[_], + requestChannel: RequestChannel, + time: Time, + apiVersionManager: ApiVersionManager +) extends ApiRequestHandler with Logging { + + override def handle(request: RequestChannel.Request, requestLocal: RequestLocal): Unit = { + try { + trace(s"Handling request:${request.requestDesc(true)} with context ${request.context}") + request.header.apiKey match { + case ApiKeys.API_VERSIONS => handleApiVersions(request) + case ApiKeys.VOTE => handleVote(request) + case ApiKeys.BEGIN_QUORUM_EPOCH => handleBeginQuorumEpoch(request) + case ApiKeys.END_QUORUM_EPOCH => handleEndQuorumEpoch(request) + case ApiKeys.FETCH => handleFetch(request) + case ApiKeys.FETCH_SNAPSHOT => handleFetchSnapshot(request) + case _ => throw new IllegalArgumentException(s"Unsupported api key: ${request.header.apiKey}") + } + } catch { + case e: FatalExitError => throw e + case e: Throwable => + error(s"Unexpected error handling request ${request.requestDesc(true)} " + + s"with context ${request.context}", e) + val errorResponse = request.body[AbstractRequest].getErrorResponse(e) + requestChannel.sendResponse(request, errorResponse, None) + } finally { + // The local completion time may be set while processing the request. Only record it if it's unset. 
+ if (request.apiLocalCompleteTimeNanos < 0) + request.apiLocalCompleteTimeNanos = time.nanoseconds + } + } + + private def handleApiVersions(request: RequestChannel.Request): Unit = { + requestChannel.sendResponse(request, apiVersionManager.apiVersionResponse(throttleTimeMs = 0), None) + } + + private def handleVote(request: RequestChannel.Request): Unit = { + handle(request, response => new VoteResponse(response.asInstanceOf[VoteResponseData])) + } + + private def handleBeginQuorumEpoch(request: RequestChannel.Request): Unit = { + handle(request, response => new BeginQuorumEpochResponse(response.asInstanceOf[BeginQuorumEpochResponseData])) + } + + private def handleEndQuorumEpoch(request: RequestChannel.Request): Unit = { + handle(request, response => new EndQuorumEpochResponse(response.asInstanceOf[EndQuorumEpochResponseData])) + } + + private def handleFetch(request: RequestChannel.Request): Unit = { + handle(request, response => new FetchResponse(response.asInstanceOf[FetchResponseData])) + } + + private def handleFetchSnapshot(request: RequestChannel.Request): Unit = { + handle(request, response => new FetchSnapshotResponse(response.asInstanceOf[FetchSnapshotResponseData])) + } + + private def handle( + request: RequestChannel.Request, + buildResponse: ApiMessage => AbstractResponse + ): Unit = { + val requestBody = request.body[AbstractRequest] + + val future = raftManager.handleRequest( + request.header, + requestBody.data, + time.milliseconds() + ) + + future.whenComplete((response, exception) => { + val res = if (exception != null) { + requestBody.getErrorResponse(exception) + } else { + buildResponse(response) + } + requestChannel.sendResponse(request, res, None) + }) + } + +} diff --git a/core/src/main/scala/kafka/tools/TestRaftServer.scala b/core/src/main/scala/kafka/tools/TestRaftServer.scala new file mode 100644 index 0000000..5099138 --- /dev/null +++ b/core/src/main/scala/kafka/tools/TestRaftServer.scala @@ -0,0 +1,473 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package kafka.tools + +import java.util.concurrent.atomic.{AtomicInteger, AtomicLong} +import java.util.concurrent.{CompletableFuture, CountDownLatch, LinkedBlockingDeque, TimeUnit} +import joptsimple.OptionException +import kafka.network.SocketServer +import kafka.raft.{KafkaRaftManager, RaftManager} +import kafka.security.CredentialProvider +import kafka.server.{KafkaConfig, KafkaRequestHandlerPool, MetaProperties, SimpleApiVersionManager} +import kafka.utils.{CommandDefaultOptions, CommandLineUtils, CoreUtils, Exit, Logging, ShutdownableThread} +import org.apache.kafka.common.errors.InvalidConfigurationException +import org.apache.kafka.common.message.ApiMessageType.ListenerType +import org.apache.kafka.common.metrics.Metrics +import org.apache.kafka.common.metrics.stats.Percentiles.BucketSizing +import org.apache.kafka.common.metrics.stats.{Meter, Percentile, Percentiles} +import org.apache.kafka.common.protocol.{ObjectSerializationCache, Writable} +import org.apache.kafka.common.security.scram.internals.ScramMechanism +import org.apache.kafka.common.security.token.delegation.internals.DelegationTokenCache +import org.apache.kafka.common.utils.{Time, Utils} +import org.apache.kafka.common.{TopicPartition, Uuid, protocol} +import org.apache.kafka.raft.errors.NotLeaderException +import org.apache.kafka.raft.{Batch, BatchReader, LeaderAndEpoch, RaftClient, RaftConfig} +import org.apache.kafka.server.common.serialization.RecordSerde +import org.apache.kafka.snapshot.SnapshotReader + +import scala.jdk.CollectionConverters._ + +/** + * This is an experimental server which is intended for testing the performance + * of the Raft implementation. It uses a hard-coded `__raft_performance_test` topic. + */ +class TestRaftServer( + val config: KafkaConfig, + val throughput: Int, + val recordSize: Int +) extends Logging { + import kafka.tools.TestRaftServer._ + + private val partition = new TopicPartition("__raft_performance_test", 0) + // The topic ID must be constant. This value was chosen as to not conflict with the topic ID used for __cluster_metadata. 
+ private val topicId = new Uuid(0L, 2L) + private val time = Time.SYSTEM + private val metrics = new Metrics(time) + private val shutdownLatch = new CountDownLatch(1) + private val threadNamePrefix = "test-raft" + + var socketServer: SocketServer = _ + var credentialProvider: CredentialProvider = _ + var tokenCache: DelegationTokenCache = _ + var dataPlaneRequestHandlerPool: KafkaRequestHandlerPool = _ + var workloadGenerator: RaftWorkloadGenerator = _ + var raftManager: KafkaRaftManager[Array[Byte]] = _ + + def startup(): Unit = { + tokenCache = new DelegationTokenCache(ScramMechanism.mechanismNames) + credentialProvider = new CredentialProvider(ScramMechanism.mechanismNames, tokenCache) + + val apiVersionManager = new SimpleApiVersionManager(ListenerType.CONTROLLER) + socketServer = new SocketServer(config, metrics, time, credentialProvider, apiVersionManager) + socketServer.startup(startProcessingRequests = false) + + val metaProperties = MetaProperties( + clusterId = Uuid.ZERO_UUID.toString, + nodeId = config.nodeId + ) + + raftManager = new KafkaRaftManager[Array[Byte]]( + metaProperties, + config, + new ByteArraySerde, + partition, + topicId, + time, + metrics, + Some(threadNamePrefix), + CompletableFuture.completedFuture(RaftConfig.parseVoterConnections(config.quorumVoters)) + ) + + workloadGenerator = new RaftWorkloadGenerator( + raftManager, + time, + recordsPerSec = 20000, + recordSize = 256 + ) + + val requestHandler = new TestRaftRequestHandler( + raftManager, + socketServer.dataPlaneRequestChannel, + time, + apiVersionManager + ) + + dataPlaneRequestHandlerPool = new KafkaRequestHandlerPool( + config.brokerId, + socketServer.dataPlaneRequestChannel, + requestHandler, + time, + config.numIoThreads, + s"${SocketServer.DataPlaneMetricPrefix}RequestHandlerAvgIdlePercent", + SocketServer.DataPlaneThreadPrefix + ) + + workloadGenerator.start() + raftManager.startup() + socketServer.startProcessingRequests(Map.empty) + } + + def shutdown(): Unit = { + if (raftManager != null) + CoreUtils.swallow(raftManager.shutdown(), this) + if (workloadGenerator != null) + CoreUtils.swallow(workloadGenerator.shutdown(), this) + if (dataPlaneRequestHandlerPool != null) + CoreUtils.swallow(dataPlaneRequestHandlerPool.shutdown(), this) + if (socketServer != null) + CoreUtils.swallow(socketServer.shutdown(), this) + if (metrics != null) + CoreUtils.swallow(metrics.close(), this) + shutdownLatch.countDown() + } + + def awaitShutdown(): Unit = { + shutdownLatch.await() + } + + class RaftWorkloadGenerator( + raftManager: RaftManager[Array[Byte]], + time: Time, + recordsPerSec: Int, + recordSize: Int + ) extends ShutdownableThread(name = "raft-workload-generator") + with RaftClient.Listener[Array[Byte]] { + + sealed trait RaftEvent + case class HandleClaim(epoch: Int) extends RaftEvent + case object HandleResign extends RaftEvent + case class HandleCommit(reader: BatchReader[Array[Byte]]) extends RaftEvent + case class HandleSnapshot(reader: SnapshotReader[Array[Byte]]) extends RaftEvent + case object Shutdown extends RaftEvent + + private val eventQueue = new LinkedBlockingDeque[RaftEvent]() + private val stats = new WriteStats(metrics, time, printIntervalMs = 5000) + private val payload = new Array[Byte](recordSize) + private val pendingAppends = new LinkedBlockingDeque[PendingAppend]() + private val recordCount = new AtomicInteger(0) + private val throttler = new ThroughputThrottler(time, recordsPerSec) + + private var claimedEpoch: Option[Int] = None + + raftManager.register(this) + + override def 
handleLeaderChange(newLeaderAndEpoch: LeaderAndEpoch): Unit = { + if (newLeaderAndEpoch.isLeader(config.nodeId)) { + eventQueue.offer(HandleClaim(newLeaderAndEpoch.epoch)) + } else if (claimedEpoch.isDefined) { + eventQueue.offer(HandleResign) + } + } + + override def handleCommit(reader: BatchReader[Array[Byte]]): Unit = { + eventQueue.offer(HandleCommit(reader)) + } + + override def handleSnapshot(reader: SnapshotReader[Array[Byte]]): Unit = { + eventQueue.offer(HandleSnapshot(reader)) + } + + override def initiateShutdown(): Boolean = { + val initiated = super.initiateShutdown() + eventQueue.offer(Shutdown) + initiated + } + + private def sendNow( + leaderEpoch: Int, + currentTimeMs: Long + ): Unit = { + recordCount.incrementAndGet() + try { + val offset = raftManager.client.scheduleAppend(leaderEpoch, List(payload).asJava) + pendingAppends.offer(PendingAppend(offset, currentTimeMs)) + } catch { + case e: NotLeaderException => + logger.debug(s"Append failed because this node is no longer leader in epoch $leaderEpoch", e) + time.sleep(10) + } + } + + override def doWork(): Unit = { + val startTimeMs = time.milliseconds() + val eventTimeoutMs = claimedEpoch.map { leaderEpoch => + val throttleTimeMs = throttler.maybeThrottle(recordCount.get() + 1, startTimeMs) + if (throttleTimeMs == 0) { + sendNow(leaderEpoch, startTimeMs) + } + throttleTimeMs + }.getOrElse(Long.MaxValue) + + eventQueue.poll(eventTimeoutMs, TimeUnit.MILLISECONDS) match { + case HandleClaim(epoch) => + claimedEpoch = Some(epoch) + throttler.reset() + pendingAppends.clear() + recordCount.set(0) + + case HandleResign => + claimedEpoch = None + pendingAppends.clear() + + case HandleCommit(reader) => + try { + while (reader.hasNext) { + val batch = reader.next() + claimedEpoch.foreach { leaderEpoch => + handleLeaderCommit(leaderEpoch, batch) + } + } + } finally { + reader.close() + } + + case HandleSnapshot(reader) => + // Ignore snapshots; only interested in records appended by this leader + reader.close() + + case Shutdown => // Ignore shutdown command + + case null => // Ignore null when timeout expires. + } + } + + private def handleLeaderCommit( + leaderEpoch: Int, + batch: Batch[Array[Byte]] + ): Unit = { + val batchEpoch = batch.epoch() + var offset = batch.baseOffset + val currentTimeMs = time.milliseconds() + + // We are only interested in batches written during the current leader's + // epoch since this allows us to rely on the local clock + if (batchEpoch != leaderEpoch) { + return + } + + for (record <- batch.records.asScala) { + val pendingAppend = pendingAppends.peek() + + if (pendingAppend == null || pendingAppend.offset != offset) { + throw new IllegalStateException(s"Unexpected append at offset $offset. 
The " + + s"next offset we expected was ${pendingAppend.offset}") + } + + pendingAppends.poll() + val latencyMs = math.max(0, currentTimeMs - pendingAppend.appendTimeMs).toInt + stats.record(latencyMs, record.length, currentTimeMs) + offset += 1 + } + } + + } + +} + +object TestRaftServer extends Logging { + + case class PendingAppend( + offset: Long, + appendTimeMs: Long + ) { + override def toString: String = { + s"PendingAppend(offset=$offset, appendTimeMs=$appendTimeMs)" + } + } + + class ByteArraySerde extends RecordSerde[Array[Byte]] { + override def recordSize(data: Array[Byte], serializationCache: ObjectSerializationCache): Int = { + data.length + } + + override def write(data: Array[Byte], serializationCache: ObjectSerializationCache, out: Writable): Unit = { + out.writeByteArray(data) + } + + override def read(input: protocol.Readable, size: Int): Array[Byte] = { + val data = new Array[Byte](size) + input.readArray(data) + data + } + } + + private class LatencyHistogram( + metrics: Metrics, + name: String, + group: String + ) { + private val sensor = metrics.sensor(name) + private val latencyP75Name = metrics.metricName(s"$name.p75", group) + private val latencyP99Name = metrics.metricName(s"$name.p99", group) + private val latencyP999Name = metrics.metricName(s"$name.p999", group) + + sensor.add(new Percentiles( + 1000, + 250.0, + BucketSizing.CONSTANT, + new Percentile(latencyP75Name, 75), + new Percentile(latencyP99Name, 99), + new Percentile(latencyP999Name, 99.9) + )) + + private val p75 = metrics.metric(latencyP75Name) + private val p99 = metrics.metric(latencyP99Name) + private val p999 = metrics.metric(latencyP999Name) + + def record(latencyMs: Int): Unit = sensor.record(latencyMs) + def currentP75: Double = p75.metricValue.asInstanceOf[Double] + def currentP99: Double = p99.metricValue.asInstanceOf[Double] + def currentP999: Double = p999.metricValue.asInstanceOf[Double] + } + + private class ThroughputMeter( + metrics: Metrics, + name: String, + group: String + ) { + private val sensor = metrics.sensor(name) + private val throughputRateName = metrics.metricName(s"$name.rate", group) + private val throughputTotalName = metrics.metricName(s"$name.total", group) + + sensor.add(new Meter(throughputRateName, throughputTotalName)) + + private val rate = metrics.metric(throughputRateName) + + def record(bytes: Int): Unit = sensor.record(bytes) + def currentRate: Double = rate.metricValue.asInstanceOf[Double] + } + + private class ThroughputThrottler( + time: Time, + targetRecordsPerSec: Int + ) { + private val startTimeMs = new AtomicLong(time.milliseconds()) + + require(targetRecordsPerSec > 0) + + def reset(): Unit = { + this.startTimeMs.set(time.milliseconds()) + } + + def maybeThrottle( + currentCount: Int, + currentTimeMs: Long + ): Long = { + val targetDurationMs = math.round(currentCount / targetRecordsPerSec.toDouble * 1000) + if (targetDurationMs > 0) { + val targetDeadlineMs = startTimeMs.get() + targetDurationMs + if (targetDeadlineMs > currentTimeMs) { + val throttleDurationMs = targetDeadlineMs - currentTimeMs + return throttleDurationMs + } + } + 0 + } + } + + private class WriteStats( + metrics: Metrics, + time: Time, + printIntervalMs: Long + ) { + private var lastReportTimeMs = time.milliseconds() + private val latency = new LatencyHistogram(metrics, name = "commit.latency", group = "kafka.raft") + private val throughput = new ThroughputMeter(metrics, name = "bytes.committed", group = "kafka.raft") + + def record( + latencyMs: Int, + bytes: Int, + 
currentTimeMs: Long + ): Unit = { + throughput.record(bytes) + latency.record(latencyMs) + + if (currentTimeMs - lastReportTimeMs >= printIntervalMs) { + printSummary() + this.lastReportTimeMs = currentTimeMs + } + } + + private def printSummary(): Unit = { + println("Throughput (bytes/second): %.2f, Latency (ms): %.1f p75 %.1f p99 %.1f p999".format( + throughput.currentRate, + latency.currentP75, + latency.currentP99, + latency.currentP999 + )) + } + } + + class TestRaftServerOptions(args: Array[String]) extends CommandDefaultOptions(args) { + val configOpt = parser.accepts("config", "Required configured file") + .withRequiredArg + .describedAs("filename") + .ofType(classOf[String]) + + val throughputOpt = parser.accepts("throughput", + "The number of records per second the leader will write to the metadata topic") + .withRequiredArg + .describedAs("records/sec") + .ofType(classOf[Int]) + .defaultsTo(5000) + + val recordSizeOpt = parser.accepts("record-size", "The size of each record") + .withRequiredArg + .describedAs("size in bytes") + .ofType(classOf[Int]) + .defaultsTo(256) + + options = parser.parse(args : _*) + } + + def main(args: Array[String]): Unit = { + val opts = new TestRaftServerOptions(args) + try { + CommandLineUtils.printHelpAndExitIfNeeded(opts, + "Standalone raft server for performance testing") + + val configFile = opts.options.valueOf(opts.configOpt) + if (configFile == null) { + throw new InvalidConfigurationException("Missing configuration file. Should specify with '--config'") + } + val serverProps = Utils.loadProps(configFile) + + // KafkaConfig requires either `process.roles` or `zookeeper.connect`. Neither are + // actually used by the test server, so we fill in `process.roles` with an arbitrary value. + serverProps.put(KafkaConfig.ProcessRolesProp, "controller") + + val config = KafkaConfig.fromProps(serverProps, doLog = false) + val throughput = opts.options.valueOf(opts.throughputOpt) + val recordSize = opts.options.valueOf(opts.recordSizeOpt) + val server = new TestRaftServer(config, throughput, recordSize) + + Exit.addShutdownHook("raft-shutdown-hook", server.shutdown()) + + server.startup() + server.awaitShutdown() + Exit.exit(0) + } catch { + case e: OptionException => + CommandLineUtils.printUsageAndDie(opts.parser, e.getMessage) + case e: Throwable => + fatal("Exiting raft server due to fatal exception", e) + Exit.exit(1) + } + } + +} diff --git a/core/src/main/scala/kafka/utils/Annotations.scala b/core/src/main/scala/kafka/utils/Annotations.scala new file mode 100644 index 0000000..da4a25c --- /dev/null +++ b/core/src/main/scala/kafka/utils/Annotations.scala @@ -0,0 +1,38 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package kafka.utils + +import scala.annotation.StaticAnnotation + +/* Some helpful annotations */ + +/** + * Indicates that the annotated class is meant to be threadsafe. For an abstract class it is a part of the interface that an implementation + * must respect + */ +class threadsafe extends StaticAnnotation + +/** + * Indicates that the annotated class is not threadsafe + */ +class nonthreadsafe extends StaticAnnotation + +/** + * Indicates that the annotated class is immutable + */ +class immutable extends StaticAnnotation diff --git a/core/src/main/scala/kafka/utils/CommandDefaultOptions.scala b/core/src/main/scala/kafka/utils/CommandDefaultOptions.scala new file mode 100644 index 0000000..2cdb408 --- /dev/null +++ b/core/src/main/scala/kafka/utils/CommandDefaultOptions.scala @@ -0,0 +1,27 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.utils + +import joptsimple.{OptionParser, OptionSet} + +abstract class CommandDefaultOptions(val args: Array[String], allowCommandOptionAbbreviation: Boolean = false) { + val parser = new OptionParser(allowCommandOptionAbbreviation) + val helpOpt = parser.accepts("help", "Print usage information.").forHelp() + val versionOpt = parser.accepts("version", "Display Kafka version.").forHelp() + var options: OptionSet = _ +} diff --git a/core/src/main/scala/kafka/utils/CommandLineUtils.scala b/core/src/main/scala/kafka/utils/CommandLineUtils.scala new file mode 100644 index 0000000..80726ce --- /dev/null +++ b/core/src/main/scala/kafka/utils/CommandLineUtils.scala @@ -0,0 +1,145 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + package kafka.utils + +import java.util.Properties + +import joptsimple.{OptionParser, OptionSet, OptionSpec} + +import scala.collection.Set + +/** + * Helper functions for dealing with command line utilities + */ +object CommandLineUtils extends Logging { + /** + * Check if there are no options or `--help` option from command line + * + * @param commandOpts Acceptable options for a command + * @return true on matching the help check condition + */ + def isPrintHelpNeeded(commandOpts: CommandDefaultOptions): Boolean = { + commandOpts.args.length == 0 || commandOpts.options.has(commandOpts.helpOpt) + } + + def isPrintVersionNeeded(commandOpts: CommandDefaultOptions): Boolean = { + commandOpts.options.has(commandOpts.versionOpt) + } + + /** + * Check and print help message if there is no options or `--help` option + * from command line, if `--version` is specified on the command line + * print version information and exit. + * NOTE: The function name is not strictly speaking correct anymore + * as it also checks whether the version needs to be printed, but + * refactoring this would have meant changing all command line tools + * and unnecessarily increased the blast radius of this change. + * + * @param commandOpts Acceptable options for a command + * @param message Message to display on successful check + */ + def printHelpAndExitIfNeeded(commandOpts: CommandDefaultOptions, message: String) = { + if (isPrintHelpNeeded(commandOpts)) + printUsageAndDie(commandOpts.parser, message) + if (isPrintVersionNeeded(commandOpts)) + printVersionAndDie() + } + + /** + * Check that all the listed options are present + */ + def checkRequiredArgs(parser: OptionParser, options: OptionSet, required: OptionSpec[_]*): Unit = { + for (arg <- required) { + if (!options.has(arg)) + printUsageAndDie(parser, "Missing required argument \"" + arg + "\"") + } + } + + /** + * Check that none of the listed options are present + */ + def checkInvalidArgs(parser: OptionParser, options: OptionSet, usedOption: OptionSpec[_], invalidOptions: Set[OptionSpec[_]]): Unit = { + if (options.has(usedOption)) { + for (arg <- invalidOptions) { + if (options.has(arg)) + printUsageAndDie(parser, "Option \"" + usedOption + "\" can't be used with option \"" + arg + "\"") + } + } + } + + /** + * Check that none of the listed options are present with the combination of used options + */ + def checkInvalidArgsSet(parser: OptionParser, options: OptionSet, usedOptions: Set[OptionSpec[_]], invalidOptions: Set[OptionSpec[_]], + trailingAdditionalMessage: Option[String] = None): Unit = { + if (usedOptions.count(options.has) == usedOptions.size) { + for (arg <- invalidOptions) { + if (options.has(arg)) + printUsageAndDie(parser, "Option combination \"" + usedOptions.mkString(",") + "\" can't be used with option \"" + arg + "\"" + trailingAdditionalMessage.getOrElse("")) + } + } + } + + /** + * Print usage and exit + */ + def printUsageAndDie(parser: OptionParser, message: String): Nothing = { + System.err.println(message) + parser.printHelpOn(System.err) + Exit.exit(1, Some(message)) + } + + def printVersionAndDie(): Nothing = { + System.out.println(VersionInfo.getVersionString) + Exit.exit(0) + } + + /** + * Parse key-value pairs in the form key=value + * value may contain equals sign + */ + def parseKeyValueArgs(args: Iterable[String], acceptMissingValue: Boolean = true): Properties = { + val splits = args.map(_.split("=", 2)).filterNot(_.length == 0) + + val props = new Properties + for (a <- splits) { + if (a.length == 1 || (a.length 
== 2 && a(1).isEmpty())) { + if (acceptMissingValue) props.put(a(0), "") + else throw new IllegalArgumentException(s"Missing value for key ${a(0)}") + } + else props.put(a(0), a(1)) + } + props + } + + /** + * Merge the options into {@code props} for key {@code key}, with the following precedence, from high to low: + * 1) if {@code spec} is specified on {@code options} explicitly, use the value; + * 2) if {@code props} already has {@code key} set, keep it; + * 3) otherwise, use the default value of {@code spec}. + * A {@code null} value means to remove {@code key} from the {@code props}. + */ + def maybeMergeOptions[V](props: Properties, key: String, options: OptionSet, spec: OptionSpec[V]): Unit = { + if (options.has(spec) || !props.containsKey(key)) { + val value = options.valueOf(spec) + if (value == null) + props.remove(key) + else + props.put(key, value.toString) + } + } +} diff --git a/core/src/main/scala/kafka/utils/CoreUtils.scala b/core/src/main/scala/kafka/utils/CoreUtils.scala new file mode 100755 index 0000000..96a71c2 --- /dev/null +++ b/core/src/main/scala/kafka/utils/CoreUtils.scala @@ -0,0 +1,327 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.utils + +import java.io._ +import java.nio._ +import java.nio.channels._ +import java.util.concurrent.locks.{Lock, ReadWriteLock} +import java.lang.management._ +import java.util.{Base64, Properties, UUID} +import com.typesafe.scalalogging.Logger + +import javax.management._ +import scala.collection._ +import scala.collection.{Seq, mutable} +import kafka.cluster.EndPoint +import org.apache.kafka.common.network.ListenerName +import org.apache.kafka.common.security.auth.SecurityProtocol +import org.apache.kafka.common.utils.Utils +import org.slf4j.event.Level + +import scala.annotation.nowarn + +/** + * General helper functions! + * + * This is for general helper functions that aren't specific to Kafka logic. Things that should have been included in + * the standard library etc. + * + * If you are making a new helper function and want to add it to this class please ensure the following: + * 1. It has documentation + * 2. It is the most general possible utility, not just the thing you needed in one particular place + * 3. You have tests for it if it is nontrivial in any way + */ +object CoreUtils { + private val logger = Logger(getClass) + + /** + * Return the smallest element in `iterable` if it is not empty. Otherwise return `ifEmpty`. + */ + def min[A, B >: A](iterable: Iterable[A], ifEmpty: A)(implicit cmp: Ordering[B]): A = + if (iterable.isEmpty) ifEmpty else iterable.min(cmp) + + /** + * Do the given action and log any exceptions thrown without rethrowing them. + * + * @param action The action to execute. 
+ * @param logging The logging instance to use for logging the thrown exception. + * @param logLevel The log level to use for logging. + */ + def swallow(action: => Unit, logging: Logging, logLevel: Level = Level.WARN): Unit = { + try { + action + } catch { + case e: Throwable => logLevel match { + case Level.ERROR => logger.error(e.getMessage, e) + case Level.WARN => logger.warn(e.getMessage, e) + case Level.INFO => logger.info(e.getMessage, e) + case Level.DEBUG => logger.debug(e.getMessage, e) + case Level.TRACE => logger.trace(e.getMessage, e) + } + } + } + + /** + * Recursively delete the list of files/directories and any subfiles (if any exist) + * @param files sequence of files to be deleted + */ + def delete(files: Seq[String]): Unit = files.foreach(f => Utils.delete(new File(f))) + + /** + * Invokes every function in `all` even if one or more functions throws an exception. + * + * If any of the functions throws an exception, the first one will be rethrown at the end with subsequent exceptions + * added as suppressed exceptions. + */ + // Note that this is a generalised version of `Utils.closeAll`. We could potentially make it more general by + // changing the signature to `def tryAll[R](all: Seq[() => R]): Seq[R]` + def tryAll(all: Seq[() => Unit]): Unit = { + var exception: Throwable = null + all.foreach { element => + try element.apply() + catch { + case e: Throwable => + if (exception != null) + exception.addSuppressed(e) + else + exception = e + } + } + if (exception != null) + throw exception + } + + /** + * Register the given mbean with the platform mbean server, + * unregistering any mbean that was there before. Note, + * this method will not throw an exception if the registration + * fails (since there is nothing you can do and it isn't fatal), + * instead it just returns false indicating the registration failed. + * @param mbean The object to register as an mbean + * @param name The name to register this mbean with + * @return true if the registration succeeded + */ + def registerMBean(mbean: Object, name: String): Boolean = { + try { + val mbs = ManagementFactory.getPlatformMBeanServer() + mbs synchronized { + val objName = new ObjectName(name) + if (mbs.isRegistered(objName)) + mbs.unregisterMBean(objName) + mbs.registerMBean(mbean, objName) + true + } + } catch { + case e: Exception => + logger.error(s"Failed to register Mbean $name", e) + false + } + } + + /** + * Unregister the mbean with the given name, if there is one registered + * @param name The mbean name to unregister + */ + def unregisterMBean(name: String): Unit = { + val mbs = ManagementFactory.getPlatformMBeanServer() + mbs synchronized { + val objName = new ObjectName(name) + if (mbs.isRegistered(objName)) + mbs.unregisterMBean(objName) + } + } + + /** + * Read some bytes into the provided buffer, and return the number of bytes read. If the + * channel has been closed or we get -1 on the read for any reason, throw an EOFException + */ + def read(channel: ReadableByteChannel, buffer: ByteBuffer): Int = { + channel.read(buffer) match { + case -1 => throw new EOFException("Received -1 when reading from channel, socket has likely been closed.") + case n => n + } + } + + /** + * This method gets comma separated values which contains key,value pairs and returns a map of + * key value pairs. the format of allCSVal is key1:val1, key2:val2 .... 
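 * A minimal sketch of the expected mapping (the input below is hypothetical):
 * {{{
 *   parseCsvMap("k1:v1, k2:v2")   // Map("k1" -> "v1", "k2" -> "v2")
 * }}}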
+ * Also supports strings with multiple ":" such as IpV6 addresses, taking the last occurrence + * of the ":" in the pair as the split, eg a:b:c:val1, d:e:f:val2 => a:b:c -> val1, d:e:f -> val2 + */ + def parseCsvMap(str: String): Map[String, String] = { + val map = new mutable.HashMap[String, String] + if ("".equals(str)) + return map + val keyVals = str.split("\\s*,\\s*").map(s => { + val lio = s.lastIndexOf(":") + (s.substring(0,lio).trim, s.substring(lio + 1).trim) + }) + keyVals.toMap + } + + /** + * Parse a comma separated string into a sequence of strings. + * Whitespace surrounding the comma will be removed. + */ + def parseCsvList(csvList: String): Seq[String] = { + if (csvList == null || csvList.isEmpty) + Seq.empty[String] + else + csvList.split("\\s*,\\s*").filter(v => !v.equals("")) + } + + /** + * Create an instance of the class with the given class name + */ + def createObject[T <: AnyRef](className: String, args: AnyRef*): T = { + val klass = Class.forName(className, true, Utils.getContextOrKafkaClassLoader()).asInstanceOf[Class[T]] + val constructor = klass.getConstructor(args.map(_.getClass): _*) + constructor.newInstance(args: _*) + } + + /** + * Create a circular (looping) iterator over a collection. + * @param coll An iterable over the underlying collection. + * @return A circular iterator over the collection. + */ + def circularIterator[T](coll: Iterable[T]) = + for (_ <- Iterator.continually(1); t <- coll) yield t + + /** + * Replace the given string suffix with the new suffix. If the string doesn't end with the given suffix throw an exception. + */ + def replaceSuffix(s: String, oldSuffix: String, newSuffix: String): String = { + if(!s.endsWith(oldSuffix)) + throw new IllegalArgumentException("Expected string to end with '%s' but string is '%s'".format(oldSuffix, s)) + s.substring(0, s.length - oldSuffix.length) + newSuffix + } + + /** + * Read a big-endian integer from a byte array + */ + def readInt(bytes: Array[Byte], offset: Int): Int = { + ((bytes(offset) & 0xFF) << 24) | + ((bytes(offset + 1) & 0xFF) << 16) | + ((bytes(offset + 2) & 0xFF) << 8) | + (bytes(offset + 3) & 0xFF) + } + + /** + * Execute the given function inside the lock + */ + def inLock[T](lock: Lock)(fun: => T): T = { + lock.lock() + try { + fun + } finally { + lock.unlock() + } + } + + def inReadLock[T](lock: ReadWriteLock)(fun: => T): T = inLock[T](lock.readLock)(fun) + + def inWriteLock[T](lock: ReadWriteLock)(fun: => T): T = inLock[T](lock.writeLock)(fun) + + /** + * Returns a list of duplicated items + */ + def duplicates[T](s: Iterable[T]): Iterable[T] = { + s.groupBy(identity) + .map { case (k, l) => (k, l.size)} + .filter { case (_, l) => l > 1 } + .keys + } + + def listenerListToEndPoints(listeners: String, securityProtocolMap: Map[ListenerName, SecurityProtocol]): Seq[EndPoint] = { + listenerListToEndPoints(listeners, securityProtocolMap, true) + } + + def listenerListToEndPoints(listeners: String, securityProtocolMap: Map[ListenerName, SecurityProtocol], requireDistinctPorts: Boolean): Seq[EndPoint] = { + def validate(endPoints: Seq[EndPoint]): Unit = { + // filter port 0 for unit tests + val portsExcludingZero = endPoints.map(_.port).filter(_ != 0) + val distinctListenerNames = endPoints.map(_.listenerName).distinct + + require(distinctListenerNames.size == endPoints.size, s"Each listener must have a different name, listeners: $listeners") + if (requireDistinctPorts) { + val distinctPorts = portsExcludingZero.distinct + require(distinctPorts.size == portsExcludingZero.size, s"Each 
listener must have a different port, listeners: $listeners") + } + } + + val endPoints = try { + val listenerList = parseCsvList(listeners) + listenerList.map(EndPoint.createEndPoint(_, Some(securityProtocolMap))) + } catch { + case e: Exception => + throw new IllegalArgumentException(s"Error creating broker listeners from '$listeners': ${e.getMessage}", e) + } + validate(endPoints) + endPoints + } + + def generateUuidAsBase64(): String = { + val uuid = UUID.randomUUID() + Base64.getUrlEncoder.withoutPadding.encodeToString(getBytesFromUuid(uuid)) + } + + def getBytesFromUuid(uuid: UUID): Array[Byte] = { + // Extract bytes for uuid which is 128 bits (or 16 bytes) long. + val uuidBytes = ByteBuffer.wrap(new Array[Byte](16)) + uuidBytes.putLong(uuid.getMostSignificantBits) + uuidBytes.putLong(uuid.getLeastSignificantBits) + uuidBytes.array + } + + def propsWith(key: String, value: String): Properties = { + propsWith((key, value)) + } + + def propsWith(props: (String, String)*): Properties = { + val properties = new Properties() + props.foreach { case (k, v) => properties.put(k, v) } + properties + } + + /** + * Atomic `getOrElseUpdate` for concurrent maps. This is optimized for the case where + * keys often exist in the map, avoiding the need to create a new value. `createValue` + * may be invoked more than once if multiple threads attempt to insert a key at the same + * time, but the same inserted value will be returned to all threads. + * + * In Scala 2.12, `ConcurrentMap.getOrElse` has the same behaviour as this method, but JConcurrentMapWrapper that + * wraps Java maps does not. + */ + def atomicGetOrUpdate[K, V](map: concurrent.Map[K, V], key: K, createValue: => V): V = { + map.get(key) match { + case Some(value) => value + case None => + val value = createValue + map.putIfAbsent(key, value).getOrElse(value) + } + } + + @nowarn("cat=unused") // see below for explanation + def groupMapReduce[T, K, B](elements: Iterable[T])(key: T => K)(f: T => B)(reduce: (B, B) => B): Map[K, B] = { + // required for Scala 2.12 compatibility, unused in Scala 2.13 and hence we need to suppress the unused warning + import scala.collection.compat._ + elements.groupMapReduce(key)(f)(reduce) + } + +} diff --git a/core/src/main/scala/kafka/utils/DelayedItem.scala b/core/src/main/scala/kafka/utils/DelayedItem.scala new file mode 100644 index 0000000..cfb8771 --- /dev/null +++ b/core/src/main/scala/kafka/utils/DelayedItem.scala @@ -0,0 +1,44 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package kafka.utils + +import java.util.concurrent._ + +import org.apache.kafka.common.utils.Time + +import scala.math._ + +class DelayedItem(val delayMs: Long) extends Delayed with Logging { + + private val dueMs = Time.SYSTEM.milliseconds + delayMs + + def this(delay: Long, unit: TimeUnit) = this(unit.toMillis(delay)) + + /** + * The remaining delay time + */ + def getDelay(unit: TimeUnit): Long = { + unit.convert(max(dueMs - Time.SYSTEM.milliseconds, 0), TimeUnit.MILLISECONDS) + } + + def compareTo(d: Delayed): Int = { + val other = d.asInstanceOf[DelayedItem] + java.lang.Long.compare(dueMs, other.dueMs) + } + +} diff --git a/core/src/main/scala/kafka/utils/Exit.scala b/core/src/main/scala/kafka/utils/Exit.scala new file mode 100644 index 0000000..ad17237 --- /dev/null +++ b/core/src/main/scala/kafka/utils/Exit.scala @@ -0,0 +1,63 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package kafka.utils + +import org.apache.kafka.common.utils.{Exit => JExit} + +/** + * Internal class that should be used instead of `System.exit()` and `Runtime.getRuntime().halt()` so that tests can + * easily change the behaviour. 
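 * For example, a test might install a throwing exit procedure so that `Exit.exit` surfaces as an
 * exception instead of terminating the JVM (a sketch; the exception type chosen here is arbitrary):
 * {{{
 *   Exit.setExitProcedure((code, _) => throw new IllegalStateException(s"exit($code) was called"))
 *   try {
 *     // ... exercise code that may call Exit.exit(...)
 *   } finally {
 *     Exit.resetExitProcedure()
 *   }
 * }}}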
+ */ +object Exit { + + def exit(statusCode: Int, message: Option[String] = None): Nothing = { + JExit.exit(statusCode, message.orNull) + throw new AssertionError("exit should not return, but it did.") + } + + def halt(statusCode: Int, message: Option[String] = None): Nothing = { + JExit.halt(statusCode, message.orNull) + throw new AssertionError("halt should not return, but it did.") + } + + def addShutdownHook(name: String, shutdownHook: => Unit): Unit = { + JExit.addShutdownHook(name, () => shutdownHook) + } + + def setExitProcedure(exitProcedure: (Int, Option[String]) => Nothing): Unit = + JExit.setExitProcedure(functionToProcedure(exitProcedure)) + + def setHaltProcedure(haltProcedure: (Int, Option[String]) => Nothing): Unit = + JExit.setHaltProcedure(functionToProcedure(haltProcedure)) + + def setShutdownHookAdder(shutdownHookAdder: (String, => Unit) => Unit): Unit = { + JExit.setShutdownHookAdder((name, runnable) => shutdownHookAdder(name, runnable.run)) + } + + def resetExitProcedure(): Unit = + JExit.resetExitProcedure() + + def resetHaltProcedure(): Unit = + JExit.resetHaltProcedure() + + def resetShutdownHookAdder(): Unit = + JExit.resetShutdownHookAdder() + + private def functionToProcedure(procedure: (Int, Option[String]) => Nothing) = new JExit.Procedure { + def execute(statusCode: Int, message: String): Unit = procedure(statusCode, Option(message)) + } +} diff --git a/core/src/main/scala/kafka/utils/FileLock.scala b/core/src/main/scala/kafka/utils/FileLock.scala new file mode 100644 index 0000000..c635f76 --- /dev/null +++ b/core/src/main/scala/kafka/utils/FileLock.scala @@ -0,0 +1,82 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package kafka.utils + +import java.io._ +import java.nio.channels._ +import java.nio.file.StandardOpenOption + +/** + * A file lock a la flock/funlock + * + * The given path will be created and opened if it doesn't exist. 
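 * A minimal usage sketch (the lock file path below is hypothetical):
 * {{{
 *   val lock = new FileLock(new java.io.File("/tmp/example.lock"))
 *   if (lock.tryLock()) {
 *     try {
 *       // ... do work while holding the lock
 *     } finally {
 *       lock.unlock()
 *     }
 *   }
 *   lock.destroy()
 * }}}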
+ */ +class FileLock(val file: File) extends Logging { + + private val channel = FileChannel.open(file.toPath, StandardOpenOption.CREATE, StandardOpenOption.READ, + StandardOpenOption.WRITE) + private var flock: java.nio.channels.FileLock = null + + /** + * Lock the file or throw an exception if the lock is already held + */ + def lock(): Unit = { + this synchronized { + trace(s"Acquiring lock on ${file.getAbsolutePath}") + flock = channel.lock() + } + } + + /** + * Try to lock the file and return true if the locking succeeds + */ + def tryLock(): Boolean = { + this synchronized { + trace(s"Acquiring lock on ${file.getAbsolutePath}") + try { + // weirdly this method will return null if the lock is held by another + // process, but will throw an exception if the lock is held by this process + // so we have to handle both cases + flock = channel.tryLock() + flock != null + } catch { + case _: OverlappingFileLockException => false + } + } + } + + /** + * Unlock the lock if it is held + */ + def unlock(): Unit = { + this synchronized { + trace(s"Releasing lock on ${file.getAbsolutePath}") + if(flock != null) + flock.release() + } + } + + /** + * Destroy this lock, closing the associated FileChannel + */ + def destroy() = { + this synchronized { + unlock() + channel.close() + } + } +} diff --git a/core/src/main/scala/kafka/utils/Implicits.scala b/core/src/main/scala/kafka/utils/Implicits.scala new file mode 100644 index 0000000..fbd22ec --- /dev/null +++ b/core/src/main/scala/kafka/utils/Implicits.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.utils + +import java.util +import java.util.Properties + +import scala.annotation.nowarn +import scala.jdk.CollectionConverters._ + +/** + * In order to have these implicits in scope, add the following import: + * + * `import kafka.utils.Implicits._` + */ +object Implicits { + + /** + * The java.util.Properties.putAll override introduced in Java 9 is seen as an overload by the + * Scala compiler causing ambiguity errors in some cases. The `++=` methods introduced via + * implicits provide a concise alternative. + * + * See https://github.com/scala/bug/issues/10418 for more details. + */ + implicit class PropertiesOps(properties: Properties) { + + def ++=(props: Properties): Unit = + (properties: util.Hashtable[AnyRef, AnyRef]).putAll(props) + + def ++=(map: collection.Map[String, AnyRef]): Unit = + (properties: util.Hashtable[AnyRef, AnyRef]).putAll(map.asJava) + + } + + /** + * Exposes `forKeyValue` which maps to `foreachEntry` in Scala 2.13 and `foreach` in Scala 2.12 + * (with the help of scala.collection.compat). `foreachEntry` avoids the tuple allocation and + * is more efficient. 
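 * A short sketch (the map contents are hypothetical):
 * {{{
 *   import kafka.utils.Implicits._
 *   val replicaCounts = Map("topic-a" -> 3, "topic-b" -> 2)
 *   replicaCounts.forKeyValue { (topic, count) =>
 *     println(s"$topic has $count replicas")
 *   }
 * }}}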
+ * + * This was not named `foreachEntry` to avoid `unused import` warnings in Scala 2.13 (the implicit + * would not be triggered in Scala 2.13 since `Map.foreachEntry` would have precedence). + */ + @nowarn("cat=unused-imports") + implicit class MapExtensionMethods[K, V](private val self: scala.collection.Map[K, V]) extends AnyVal { + import scala.collection.compat._ + def forKeyValue[U](f: (K, V) => U): Unit = { + self.foreachEntry { (k, v) => f(k, v) } + } + + } + +} diff --git a/core/src/main/scala/kafka/utils/Json.scala b/core/src/main/scala/kafka/utils/Json.scala new file mode 100644 index 0000000..049941c --- /dev/null +++ b/core/src/main/scala/kafka/utils/Json.scala @@ -0,0 +1,92 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package kafka.utils + +import com.fasterxml.jackson.core.{JsonParseException, JsonProcessingException} +import com.fasterxml.jackson.databind.ObjectMapper +import com.fasterxml.jackson.databind.node.MissingNode +import kafka.utils.json.JsonValue + +import scala.reflect.ClassTag + +/** + * Provides methods for parsing JSON with Jackson and encoding to JSON with a simple and naive custom implementation. + */ +object Json { + + private val mapper = new ObjectMapper() + + /** + * Parse a JSON string into a JsonValue if possible. `None` is returned if `input` is not valid JSON. + */ + def parseFull(input: String): Option[JsonValue] = tryParseFull(input).toOption + + /** + * Parse a JSON string into either a generic type T, or a JsonProcessingException in the case of + * exception. + */ + def parseStringAs[T](input: String)(implicit tag: ClassTag[T]): Either[JsonProcessingException, T] = { + try Right(mapper.readValue(input, tag.runtimeClass).asInstanceOf[T]) + catch { case e: JsonProcessingException => Left(e) } + } + + /** + * Parse a JSON byte array into a JsonValue if possible. `None` is returned if `input` is not valid JSON. + */ + def parseBytes(input: Array[Byte]): Option[JsonValue] = + try Option(mapper.readTree(input)).map(JsonValue(_)) + catch { case _: JsonProcessingException => None } + + def tryParseBytes(input: Array[Byte]): Either[JsonProcessingException, JsonValue] = + try Right(mapper.readTree(input)).map(JsonValue(_)) + catch { case e: JsonProcessingException => Left(e) } + + /** + * Parse a JSON byte array into either a generic type T, or a JsonProcessingException in the case of exception. + */ + def parseBytesAs[T](input: Array[Byte])(implicit tag: ClassTag[T]): Either[JsonProcessingException, T] = { + try Right(mapper.readValue(input, tag.runtimeClass).asInstanceOf[T]) + catch { case e: JsonProcessingException => Left(e) } + } + + /** + * Parse a JSON string into a JsonValue if possible. It returns an `Either` where `Left` will be an exception and + * `Right` is the `JsonValue`. 
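 * A small sketch (the JSON payload is made up):
 * {{{
 *   Json.tryParseFull("""{"version": 1, "brokers": [0, 1]}""") match {
 *     case Right(js) => // work with the parsed JsonValue
 *     case Left(e)   => // handle the JsonProcessingException
 *   }
 * }}}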
+ * @param input a JSON string to parse + * @return An `Either` which in case of `Left` means an exception and `Right` is the actual return value. + */ + def tryParseFull(input: String): Either[JsonProcessingException, JsonValue] = + if (input == null || input.isEmpty) + Left(new JsonParseException(MissingNode.getInstance().traverse(), "The input string shouldn't be empty")) + else + try Right(mapper.readTree(input)).map(JsonValue(_)) + catch { case e: JsonProcessingException => Left(e) } + + /** + * Encode an object into a JSON string. This method accepts any type supported by Jackson's ObjectMapper in + * the default configuration. That is, Java collections are supported, but Scala collections are not (to avoid + * a jackson-scala dependency). + */ + def encodeAsString(obj: Any): String = mapper.writeValueAsString(obj) + + /** + * Encode an object into a JSON value in bytes. This method accepts any type supported by Jackson's ObjectMapper in + * the default configuration. That is, Java collections are supported, but Scala collections are not (to avoid + * a jackson-scala dependency). + */ + def encodeAsBytes(obj: Any): Array[Byte] = mapper.writeValueAsBytes(obj) +} diff --git a/core/src/main/scala/kafka/utils/KafkaScheduler.scala b/core/src/main/scala/kafka/utils/KafkaScheduler.scala new file mode 100755 index 0000000..bec511b --- /dev/null +++ b/core/src/main/scala/kafka/utils/KafkaScheduler.scala @@ -0,0 +1,165 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.utils + +import java.util.concurrent._ +import atomic._ +import org.apache.kafka.common.utils.KafkaThread + +import java.util.concurrent.TimeUnit.NANOSECONDS + +/** + * A scheduler for running jobs + * + * This interface controls a job scheduler that allows scheduling either repeating background jobs + * that execute periodically or delayed one-time actions that are scheduled in the future. + */ +trait Scheduler { + + /** + * Initialize this scheduler so it is ready to accept scheduling of tasks + */ + def startup(): Unit + + /** + * Shutdown this scheduler. When this method is complete no more executions of background tasks will occur. + * This includes tasks scheduled with a delayed execution. + */ + def shutdown(): Unit + + /** + * Check if the scheduler has been started + */ + def isStarted: Boolean + + /** + * Schedule a task + * @param name The name of this task + * @param delay The amount of time to wait before the first execution + * @param period The period with which to execute the task. If < 0 the task will execute only once. + * @param unit The unit for the preceding times. + * @return A Future object to manage the task scheduled. 
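The `Json` helper above is intentionally symmetric: encode with Jackson's default `ObjectMapper`, read back with `parseFull`/`tryParseFull`. A small round-trip sketch with a made-up payload (a Java map is used on purpose, per the `encodeAsString` note):

```scala
import java.util.{HashMap => JHashMap}
import kafka.utils.Json

val payload = new JHashMap[String, Any]()
payload.put("version", 1)
payload.put("leader", 2)
val json = Json.encodeAsString(payload)

Json.parseFull(json) match {
  case Some(value) => println(s"parsed: $value")
  case None        => println("not valid JSON")
}

// tryParseFull keeps the parse error instead of discarding it.
println(Json.tryParseFull("{not json").isLeft) // true
```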
+ */ + def schedule(name: String, fun: ()=>Unit, delay: Long = 0, period: Long = -1, unit: TimeUnit = TimeUnit.MILLISECONDS) : ScheduledFuture[_] +} + +/** + * A scheduler based on java.util.concurrent.ScheduledThreadPoolExecutor + * + * It has a pool of kafka-scheduler- threads that do the actual work. + * + * @param threads The number of threads in the thread pool + * @param threadNamePrefix The name to use for scheduler threads. This prefix will have a number appended to it. + * @param daemon If true the scheduler threads will be "daemon" threads and will not block jvm shutdown. + */ +@threadsafe +class KafkaScheduler(val threads: Int, + val threadNamePrefix: String = "kafka-scheduler-", + daemon: Boolean = true) extends Scheduler with Logging { + private var executor: ScheduledThreadPoolExecutor = null + private val schedulerThreadId = new AtomicInteger(0) + + override def startup(): Unit = { + debug("Initializing task scheduler.") + this synchronized { + if(isStarted) + throw new IllegalStateException("This scheduler has already been started!") + executor = new ScheduledThreadPoolExecutor(threads) + executor.setContinueExistingPeriodicTasksAfterShutdownPolicy(false) + executor.setExecuteExistingDelayedTasksAfterShutdownPolicy(false) + executor.setRemoveOnCancelPolicy(true) + executor.setThreadFactory(runnable => + new KafkaThread(threadNamePrefix + schedulerThreadId.getAndIncrement(), runnable, daemon)) + } + } + + override def shutdown(): Unit = { + debug("Shutting down task scheduler.") + // We use the local variable to avoid NullPointerException if another thread shuts down scheduler at same time. + val cachedExecutor = this.executor + if (cachedExecutor != null) { + this synchronized { + cachedExecutor.shutdown() + this.executor = null + } + cachedExecutor.awaitTermination(1, TimeUnit.DAYS) + } + } + + def scheduleOnce(name: String, fun: () => Unit): Unit = { + schedule(name, fun, delay = 0L, period = -1L, unit = TimeUnit.MILLISECONDS) + } + + def schedule(name: String, fun: () => Unit, delay: Long, period: Long, unit: TimeUnit): ScheduledFuture[_] = { + debug("Scheduling task %s with initial delay %d ms and period %d ms." + .format(name, TimeUnit.MILLISECONDS.convert(delay, unit), TimeUnit.MILLISECONDS.convert(period, unit))) + this synchronized { + if (isStarted) { + val runnable: Runnable = () => { + try { + trace("Beginning execution of scheduled task '%s'.".format(name)) + fun() + } catch { + case t: Throwable => error(s"Uncaught exception in scheduled task '$name'", t) + } finally { + trace("Completed execution of scheduled task '%s'.".format(name)) + } + } + if (period >= 0) + executor.scheduleAtFixedRate(runnable, delay, period, unit) + else + executor.schedule(runnable, delay, unit) + } else { + info("Kafka scheduler is not running at the time task '%s' is scheduled. The task is ignored.".format(name)) + new NoOpScheduledFutureTask + } + } + } + + /** + * Package private for testing. 
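A sketch of how the scheduler is typically driven; the task names, bodies and timings are arbitrary:

```scala
import java.util.concurrent.TimeUnit
import kafka.utils.KafkaScheduler

val scheduler = new KafkaScheduler(threads = 1)
scheduler.startup()

// Repeats every 30 seconds after an initial 10 second delay.
val recurring = scheduler.schedule("tick-example", () => println("tick"),
  delay = 10L, period = 30L, unit = TimeUnit.SECONDS)

// A negative period (the default) means the task runs exactly once.
scheduler.scheduleOnce("one-shot-example", () => println("once"))

recurring.cancel(false) // stop the recurring task without interrupting a run in progress
scheduler.shutdown()
```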
+ */ + private[kafka] def taskRunning(task: ScheduledFuture[_]): Boolean = { + executor.getQueue().contains(task) + } + + def resizeThreadPool(newSize: Int): Unit = { + executor.setCorePoolSize(newSize) + } + + def isStarted: Boolean = { + this synchronized { + executor != null + } + } +} + +private class NoOpScheduledFutureTask() extends ScheduledFuture[Unit] { + override def cancel(mayInterruptIfRunning: Boolean): Boolean = true + override def isCancelled: Boolean = true + override def isDone: Boolean = true + override def get(): Unit = {} + override def get(timeout: Long, unit: TimeUnit): Unit = {} + override def getDelay(unit: TimeUnit): Long = 0 + override def compareTo(o: Delayed): Int = { + val diff = getDelay(NANOSECONDS) - o.getDelay(NANOSECONDS) + if (diff < 0) -1 + else if (diff > 0) 1 + else 0 + } +} diff --git a/core/src/main/scala/kafka/utils/Log4jController.scala b/core/src/main/scala/kafka/utils/Log4jController.scala new file mode 100755 index 0000000..0d54c74 --- /dev/null +++ b/core/src/main/scala/kafka/utils/Log4jController.scala @@ -0,0 +1,135 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.utils + +import java.util +import java.util.Locale + +import org.apache.kafka.common.utils.Utils +import org.apache.log4j.{Level, LogManager, Logger} + +import scala.collection.mutable +import scala.jdk.CollectionConverters._ + + +object Log4jController { + val ROOT_LOGGER = "root" + + private def resolveLevel(logger: Logger): String = { + var name = logger.getName + var level = logger.getLevel + while (level == null) { + val index = name.lastIndexOf(".") + if (index > 0) { + name = name.substring(0, index) + val ancestor = existingLogger(name) + if (ancestor != null) { + level = ancestor.getLevel + } + } else { + level = existingLogger(ROOT_LOGGER).getLevel + } + } + level.toString + } + + /** + * Returns a map of the log4j loggers and their assigned log level. 
+ * If a logger does not have a log level assigned, we return the root logger's log level + */ + def loggers: mutable.Map[String, String] = { + val logs = new mutable.HashMap[String, String]() + val rootLoggerLvl = existingLogger(ROOT_LOGGER).getLevel.toString + logs.put(ROOT_LOGGER, rootLoggerLvl) + + val loggers = LogManager.getCurrentLoggers + while (loggers.hasMoreElements) { + val logger = loggers.nextElement().asInstanceOf[Logger] + if (logger != null) { + logs.put(logger.getName, resolveLevel(logger)) + } + } + logs + } + + /** + * Sets the log level of a particular logger + */ + def logLevel(loggerName: String, logLevel: String): Boolean = { + val log = existingLogger(loggerName) + if (!Utils.isBlank(loggerName) && !Utils.isBlank(logLevel) && log != null) { + log.setLevel(Level.toLevel(logLevel.toUpperCase(Locale.ROOT))) + true + } + else false + } + + def unsetLogLevel(loggerName: String): Boolean = { + val log = existingLogger(loggerName) + if (!Utils.isBlank(loggerName) && log != null) { + log.setLevel(null) + true + } + else false + } + + def loggerExists(loggerName: String): Boolean = existingLogger(loggerName) != null + + private def existingLogger(loggerName: String) = + if (loggerName == ROOT_LOGGER) + LogManager.getRootLogger + else LogManager.exists(loggerName) +} + +/** + * An MBean that allows the user to dynamically alter log4j levels at runtime. + * The companion object contains the singleton instance of this class and + * registers the MBean. The [[kafka.utils.Logging]] trait forces initialization + * of the companion object. + */ +class Log4jController extends Log4jControllerMBean { + + def getLoggers: util.List[String] = { + // we replace scala collection by java collection so mbean client is able to deserialize it without scala library. + new util.ArrayList[String](Log4jController.loggers.map { + case (logger, level) => s"$logger=$level" + }.toSeq.asJava) + } + + + def getLogLevel(loggerName: String): String = { + val log = Log4jController.existingLogger(loggerName) + if (log != null) { + val level = log.getLevel + if (level != null) + log.getLevel.toString + else + Log4jController.resolveLevel(log) + } + else "No such logger." + } + + def setLogLevel(loggerName: String, level: String): Boolean = Log4jController.logLevel(loggerName, level) +} + + +trait Log4jControllerMBean { + def getLoggers: java.util.List[String] + def getLogLevel(logger: String): String + def setLogLevel(logger: String, level: String): Boolean +} diff --git a/core/src/main/scala/kafka/utils/Logging.scala b/core/src/main/scala/kafka/utils/Logging.scala new file mode 100755 index 0000000..0221821 --- /dev/null +++ b/core/src/main/scala/kafka/utils/Logging.scala @@ -0,0 +1,83 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
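The `Log4jController` object above can also be used directly rather than through the MBean; a sketch with an illustrative logger name:

```scala
import kafka.utils.Log4jController

// Every known logger with its effective level, e.g. "kafka.server.KafkaApis=INFO".
Log4jController.loggers.foreach { case (name, level) => println(s"$name=$level") }

// Turn on DEBUG for one logger, then fall back to the level inherited from its ancestors.
Log4jController.logLevel("kafka.server.KafkaApis", "DEBUG")
Log4jController.unsetLogLevel("kafka.server.KafkaApis")
```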
+ */ + +package kafka.utils + +import com.typesafe.scalalogging.Logger +import org.slf4j.{LoggerFactory, Marker, MarkerFactory} + + +object Log4jControllerRegistration { + private val logger = Logger(this.getClass.getName) + + try { + val log4jController = Class.forName("kafka.utils.Log4jController").asInstanceOf[Class[Object]] + val instance = log4jController.getDeclaredConstructor().newInstance() + CoreUtils.registerMBean(instance, "kafka:type=kafka.Log4jController") + logger.info("Registered kafka:type=kafka.Log4jController MBean") + } catch { + case _: Exception => logger.info("Couldn't register kafka:type=kafka.Log4jController MBean") + } +} + +private object Logging { + private val FatalMarker: Marker = MarkerFactory.getMarker("FATAL") +} + +trait Logging { + + protected lazy val logger = Logger(LoggerFactory.getLogger(loggerName)) + + protected var logIdent: String = _ + + Log4jControllerRegistration + + protected def loggerName: String = getClass.getName + + protected def msgWithLogIdent(msg: String): String = + if (logIdent == null) msg else logIdent + msg + + def trace(msg: => String): Unit = logger.trace(msgWithLogIdent(msg)) + + def trace(msg: => String, e: => Throwable): Unit = logger.trace(msgWithLogIdent(msg),e) + + def isDebugEnabled: Boolean = logger.underlying.isDebugEnabled + + def isTraceEnabled: Boolean = logger.underlying.isTraceEnabled + + def debug(msg: => String): Unit = logger.debug(msgWithLogIdent(msg)) + + def debug(msg: => String, e: => Throwable): Unit = logger.debug(msgWithLogIdent(msg),e) + + def info(msg: => String): Unit = logger.info(msgWithLogIdent(msg)) + + def info(msg: => String,e: => Throwable): Unit = logger.info(msgWithLogIdent(msg),e) + + def warn(msg: => String): Unit = logger.warn(msgWithLogIdent(msg)) + + def warn(msg: => String, e: => Throwable): Unit = logger.warn(msgWithLogIdent(msg),e) + + def error(msg: => String): Unit = logger.error(msgWithLogIdent(msg)) + + def error(msg: => String, e: => Throwable): Unit = logger.error(msgWithLogIdent(msg),e) + + def fatal(msg: => String): Unit = + logger.error(Logging.FatalMarker, msgWithLogIdent(msg)) + + def fatal(msg: => String, e: => Throwable): Unit = + logger.error(Logging.FatalMarker, msgWithLogIdent(msg), e) +} diff --git a/core/src/main/scala/kafka/utils/Mx4jLoader.scala b/core/src/main/scala/kafka/utils/Mx4jLoader.scala new file mode 100644 index 0000000..e49f3a5 --- /dev/null +++ b/core/src/main/scala/kafka/utils/Mx4jLoader.scala @@ -0,0 +1,71 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.utils + + +import java.lang.management.ManagementFactory +import javax.management.ObjectName + +/** + * If mx4j-tools is in the classpath call maybeLoad to load the HTTP interface of mx4j. + * + * The default port is 8082. 
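Most classes in this patch obtain their log methods by mixing in the `Logging` trait above; a minimal sketch of the pattern with a made-up class:

```scala
import kafka.utils.Logging

class ExampleWorker(id: Int) extends Logging {
  // Every message from this instance is prefixed, e.g. "[Worker 3]: starting".
  this.logIdent = s"[Worker $id]: "

  def start(): Unit = {
    info("starting")
    if (isDebugEnabled) debug(s"worker $id detailed state ...")
  }
}
```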
To override that provide e.g. -Dmx4jport=8083 + * The default listen address is 0.0.0.0. To override that provide -Dmx4jaddress=127.0.0.1 + * This feature must be enabled with -Dmx4jenable=true + * + * This is a Scala port of org.apache.cassandra.utils.Mx4jTool written by Ran Tavory for CASSANDRA-1068 + * */ +object Mx4jLoader extends Logging { + + def maybeLoad(): Boolean = { + val props = new VerifiableProperties(System.getProperties()) + if (!props.getBoolean("kafka_mx4jenable", false)) + return false + val address = props.getString("mx4jaddress", "0.0.0.0") + val port = props.getInt("mx4jport", 8082) + try { + debug("Will try to load MX4j now, if it's in the classpath") + + val mbs = ManagementFactory.getPlatformMBeanServer() + val processorName = new ObjectName("Server:name=XSLTProcessor") + + val httpAdaptorClass = Class.forName("mx4j.tools.adaptor.http.HttpAdaptor") + val httpAdaptor = httpAdaptorClass.getDeclaredConstructor().newInstance() + httpAdaptorClass.getMethod("setHost", classOf[String]).invoke(httpAdaptor, address.asInstanceOf[AnyRef]) + httpAdaptorClass.getMethod("setPort", Integer.TYPE).invoke(httpAdaptor, port.asInstanceOf[AnyRef]) + + val httpName = new ObjectName("system:name=http") + mbs.registerMBean(httpAdaptor, httpName) + + val xsltProcessorClass = Class.forName("mx4j.tools.adaptor.http.XSLTProcessor") + val xsltProcessor = xsltProcessorClass.getDeclaredConstructor().newInstance() + httpAdaptorClass.getMethod("setProcessor", Class.forName("mx4j.tools.adaptor.http.ProcessorMBean")).invoke(httpAdaptor, xsltProcessor.asInstanceOf[AnyRef]) + mbs.registerMBean(xsltProcessor, processorName) + httpAdaptorClass.getMethod("start").invoke(httpAdaptor) + info("mx4j successfully loaded") + return true + } + catch { + case _: ClassNotFoundException => + info("Will not load MX4J, mx4j-tools.jar is not in the classpath") + case e: Throwable => + warn("Could not start register mbean in JMX", e) + } + false + } +} diff --git a/core/src/main/scala/kafka/utils/NotNothing.scala b/core/src/main/scala/kafka/utils/NotNothing.scala new file mode 100644 index 0000000..aee345e --- /dev/null +++ b/core/src/main/scala/kafka/utils/NotNothing.scala @@ -0,0 +1,41 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.utils + +import scala.annotation.implicitNotFound + +/** + * This is a trick to prevent the compiler from inferring the `Nothing` type in cases where it would be a bug to do + * so. An example is the following method: + * + * ``` + * def body[T <: AbstractRequest](implicit classTag: ClassTag[T], nn: NotNothing[T]): T + * ``` + * + * If we remove the `nn` parameter and we invoke it without any type parameters (e.g. `request.body`), `Nothing` would + * be inferred, which is not desirable. 
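For reference against `maybeLoad()` above, a sketch of the system properties it reads; the values are examples, and the call still returns false unless mx4j-tools is actually on the classpath:

```scala
// Property names match the ones read by Mx4jLoader.maybeLoad().
System.setProperty("kafka_mx4jenable", "true")
System.setProperty("mx4jaddress", "127.0.0.1")
System.setProperty("mx4jport", "8083")

val started = kafka.utils.Mx4jLoader.maybeLoad()
println(s"mx4j HTTP adaptor started: $started")
```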
As defined above, we get a helpful compiler error asking the user to provide + * the type parameter explicitly. + */ +@implicitNotFound("Unable to infer type parameter, please provide it explicitly.") +trait NotNothing[T] + +object NotNothing { + private val evidence: NotNothing[Any] = new Object with NotNothing[Any] + + implicit def notNothingEvidence[T](implicit n: T =:= T): NotNothing[T] = evidence.asInstanceOf[NotNothing[T]] +} diff --git a/core/src/main/scala/kafka/utils/PasswordEncoder.scala b/core/src/main/scala/kafka/utils/PasswordEncoder.scala new file mode 100644 index 0000000..f748a45 --- /dev/null +++ b/core/src/main/scala/kafka/utils/PasswordEncoder.scala @@ -0,0 +1,175 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. +*/ +package kafka.utils + +import java.nio.charset.StandardCharsets +import java.security.{AlgorithmParameters, NoSuchAlgorithmException, SecureRandom} +import java.security.spec.AlgorithmParameterSpec +import java.util.Base64 + +import javax.crypto.{Cipher, SecretKeyFactory} +import javax.crypto.spec._ +import kafka.utils.PasswordEncoder._ +import org.apache.kafka.common.config.ConfigException +import org.apache.kafka.common.config.types.Password + +import scala.collection.Map + +object PasswordEncoder { + val KeyFactoryAlgorithmProp = "keyFactoryAlgorithm" + val CipherAlgorithmProp = "cipherAlgorithm" + val InitializationVectorProp = "initializationVector" + val KeyLengthProp = "keyLength" + val SaltProp = "salt" + val IterationsProp = "iterations" + val EncyrptedPasswordProp = "encryptedPassword" + val PasswordLengthProp = "passwordLength" +} + +/** + * Password encoder and decoder implementation. Encoded passwords are persisted as a CSV map + * containing the encoded password in base64 and along with the properties used for encryption. + * + * @param secret The secret used for encoding and decoding + * @param keyFactoryAlgorithm Key factory algorithm if configured. By default, PBKDF2WithHmacSHA512 is + * used if available, PBKDF2WithHmacSHA1 otherwise. + * @param cipherAlgorithm Cipher algorithm used for encoding. + * @param keyLength Key length used for encoding. This should be valid for the specified algorithms. + * @param iterations Iteration count used for encoding. + * + * The provided `keyFactoryAlgorithm`, 'cipherAlgorithm`, `keyLength` and `iterations` are used for encoding passwords. + * The values used for encoding are stored along with the encoded password and the stored values are used for decoding. 
+ * + */ +class PasswordEncoder(secret: Password, + keyFactoryAlgorithm: Option[String], + cipherAlgorithm: String, + keyLength: Int, + iterations: Int) extends Logging { + + private val secureRandom = new SecureRandom + private val cipherParamsEncoder = cipherParamsInstance(cipherAlgorithm) + + def encode(password: Password): String = { + val salt = new Array[Byte](256) + secureRandom.nextBytes(salt) + val cipher = Cipher.getInstance(cipherAlgorithm) + val keyFactory = secretKeyFactory(keyFactoryAlgorithm) + val keySpec = secretKeySpec(keyFactory, cipherAlgorithm, keyLength, salt, iterations) + cipher.init(Cipher.ENCRYPT_MODE, keySpec) + val encryptedPassword = cipher.doFinal(password.value.getBytes(StandardCharsets.UTF_8)) + val encryptedMap = Map( + KeyFactoryAlgorithmProp -> keyFactory.getAlgorithm, + CipherAlgorithmProp -> cipherAlgorithm, + KeyLengthProp -> keyLength, + SaltProp -> base64Encode(salt), + IterationsProp -> iterations.toString, + EncyrptedPasswordProp -> base64Encode(encryptedPassword), + PasswordLengthProp -> password.value.length + ) ++ cipherParamsEncoder.toMap(cipher.getParameters) + encryptedMap.map { case (k, v) => s"$k:$v" }.mkString(",") + } + + def decode(encodedPassword: String): Password = { + val params = CoreUtils.parseCsvMap(encodedPassword) + val keyFactoryAlg = params(KeyFactoryAlgorithmProp) + val cipherAlg = params(CipherAlgorithmProp) + val keyLength = params(KeyLengthProp).toInt + val salt = base64Decode(params(SaltProp)) + val iterations = params(IterationsProp).toInt + val encryptedPassword = base64Decode(params(EncyrptedPasswordProp)) + val passwordLengthProp = params(PasswordLengthProp).toInt + val cipher = Cipher.getInstance(cipherAlg) + val keyFactory = secretKeyFactory(Some(keyFactoryAlg)) + val keySpec = secretKeySpec(keyFactory, cipherAlg, keyLength, salt, iterations) + cipher.init(Cipher.DECRYPT_MODE, keySpec, cipherParamsEncoder.toParameterSpec(params)) + val password = try { + val decrypted = cipher.doFinal(encryptedPassword) + new String(decrypted, StandardCharsets.UTF_8) + } catch { + case e: Exception => throw new ConfigException("Password could not be decoded", e) + } + if (password.length != passwordLengthProp) // Sanity check + throw new ConfigException("Password could not be decoded, sanity check of length failed") + new Password(password) + } + + private def secretKeyFactory(keyFactoryAlg: Option[String]): SecretKeyFactory = { + keyFactoryAlg match { + case Some(algorithm) => SecretKeyFactory.getInstance(algorithm) + case None => + try { + SecretKeyFactory.getInstance("PBKDF2WithHmacSHA512") + } catch { + case _: NoSuchAlgorithmException => SecretKeyFactory.getInstance("PBKDF2WithHmacSHA1") + } + } + } + + private def secretKeySpec(keyFactory: SecretKeyFactory, + cipherAlg: String, + keyLength: Int, + salt: Array[Byte], iterations: Int): SecretKeySpec = { + val keySpec = new PBEKeySpec(secret.value.toCharArray, salt, iterations, keyLength) + val algorithm = if (cipherAlg.indexOf('/') > 0) cipherAlg.substring(0, cipherAlg.indexOf('/')) else cipherAlg + new SecretKeySpec(keyFactory.generateSecret(keySpec).getEncoded, algorithm) + } + + private def base64Encode(bytes: Array[Byte]): String = Base64.getEncoder.encodeToString(bytes) + + private[utils] def base64Decode(encoded: String): Array[Byte] = Base64.getDecoder.decode(encoded) + + private def cipherParamsInstance(cipherAlgorithm: String): CipherParamsEncoder = { + val aesPattern = "AES/(.*)/.*".r + cipherAlgorithm match { + case aesPattern("GCM") => new GcmParamsEncoder + case _ 
=> new IvParamsEncoder + } + } + + private trait CipherParamsEncoder { + def toMap(cipher: AlgorithmParameters): Map[String, String] + def toParameterSpec(paramMap: Map[String, String]): AlgorithmParameterSpec + } + + private class IvParamsEncoder extends CipherParamsEncoder { + def toMap(cipherParams: AlgorithmParameters): Map[String, String] = { + if (cipherParams != null) { + val ivSpec = cipherParams.getParameterSpec(classOf[IvParameterSpec]) + Map(InitializationVectorProp -> base64Encode(ivSpec.getIV)) + } else + throw new IllegalStateException("Could not determine initialization vector for cipher") + } + def toParameterSpec(paramMap: Map[String, String]): AlgorithmParameterSpec = { + new IvParameterSpec(base64Decode(paramMap(InitializationVectorProp))) + } + } + + private class GcmParamsEncoder extends CipherParamsEncoder { + def toMap(cipherParams: AlgorithmParameters): Map[String, String] = { + if (cipherParams != null) { + val spec = cipherParams.getParameterSpec(classOf[GCMParameterSpec]) + Map(InitializationVectorProp -> base64Encode(spec.getIV), + "authenticationTagLength" -> spec.getTLen.toString) + } else + throw new IllegalStateException("Could not determine initialization vector for cipher") + } + def toParameterSpec(paramMap: Map[String, String]): AlgorithmParameterSpec = { + new GCMParameterSpec(paramMap("authenticationTagLength").toInt, base64Decode(paramMap(InitializationVectorProp))) + } + } +} diff --git a/core/src/main/scala/kafka/utils/Pool.scala b/core/src/main/scala/kafka/utils/Pool.scala new file mode 100644 index 0000000..84bedc1 --- /dev/null +++ b/core/src/main/scala/kafka/utils/Pool.scala @@ -0,0 +1,108 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.utils + +import java.util.concurrent._ + +import org.apache.kafka.common.KafkaException + +import collection.Set +import scala.jdk.CollectionConverters._ + +class Pool[K,V](valueFactory: Option[K => V] = None) extends Iterable[(K, V)] { + + private val pool: ConcurrentMap[K, V] = new ConcurrentHashMap[K, V] + + def put(k: K, v: V): V = pool.put(k, v) + + def putAll(map: java.util.Map[K, V]): Unit = pool.putAll(map) + + def putIfNotExists(k: K, v: V): V = pool.putIfAbsent(k, v) + + /** + * Gets the value associated with the given key. If there is no associated + * value, then create the value using the pool's value factory and return the + * value associated with the key. The user should declare the factory method + * as lazy if its side-effects need to be avoided. + * + * @param key The key to lookup. + * @return The final value associated with the key. 
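A round-trip sketch of the `PasswordEncoder` above; the secret, cipher and key parameters are illustrative values, not recommendations:

```scala
import kafka.utils.PasswordEncoder
import org.apache.kafka.common.config.types.Password

val encoder = new PasswordEncoder(
  secret = new Password("encoder-secret"), // illustrative secret
  keyFactoryAlgorithm = None,              // falls back to PBKDF2WithHmacSHA512, else SHA1
  cipherAlgorithm = "AES/CBC/PKCS5Padding",
  keyLength = 128,
  iterations = 4096)

val encoded = encoder.encode(new Password("broker-password"))
// `encoded` is a CSV map: keyFactoryAlgorithm, cipherAlgorithm, salt, iterations, encryptedPassword, ...
val decoded = encoder.decode(encoded)
println(decoded.value == "broker-password") // true
```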
+ */ + def getAndMaybePut(key: K): V = { + if (valueFactory.isEmpty) + throw new KafkaException("Empty value factory in pool.") + getAndMaybePut(key, valueFactory.get(key)) + } + + /** + * Gets the value associated with the given key. If there is no associated + * value, then create the value using the provided by `createValue` and return the + * value associated with the key. + * + * @param key The key to lookup. + * @param createValue Factory function. + * @return The final value associated with the key. + */ + def getAndMaybePut(key: K, createValue: => V): V = + pool.computeIfAbsent(key, _ => createValue) + + def contains(id: K): Boolean = pool.containsKey(id) + + def get(key: K): V = pool.get(key) + + def remove(key: K): V = pool.remove(key) + + def remove(key: K, value: V): Boolean = pool.remove(key, value) + + def removeAll(keys: Iterable[K]): Unit = pool.keySet.removeAll(keys.asJavaCollection) + + def keys: Set[K] = pool.keySet.asScala + + def values: Iterable[V] = pool.values.asScala + + def clear(): Unit = { pool.clear() } + + def foreachEntry(f: (K, V) => Unit): Unit = { + pool.forEach((k, v) => f(k, v)) + } + + def foreachWhile(f: (K, V) => Boolean): Unit = { + val iter = pool.entrySet().iterator() + var finished = false + while (!finished && iter.hasNext) { + val entry = iter.next + finished = !f(entry.getKey, entry.getValue) + } + } + + override def size: Int = pool.size + + override def iterator: Iterator[(K, V)] = new Iterator[(K,V)]() { + + private val iter = pool.entrySet.iterator + + def hasNext: Boolean = iter.hasNext + + def next(): (K, V) = { + val n = iter.next + (n.getKey, n.getValue) + } + + } + +} diff --git a/core/src/main/scala/kafka/utils/QuotaUtils.scala b/core/src/main/scala/kafka/utils/QuotaUtils.scala new file mode 100755 index 0000000..93d5cde --- /dev/null +++ b/core/src/main/scala/kafka/utils/QuotaUtils.scala @@ -0,0 +1,75 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.utils + +import org.apache.kafka.common.MetricName +import org.apache.kafka.common.metrics.{KafkaMetric, Measurable, QuotaViolationException} +import org.apache.kafka.common.metrics.stats.Rate + +/** + * Helper functions related to quotas + */ +object QuotaUtils { + + /** + * This calculates the amount of time needed to bring the observed rate within quota + * assuming that no new metrics are recorded. + * + * If O is the observed rate and T is the target rate over a window of W, to bring O down to T, + * we need to add a delay of X to W such that O * W / (W + X) = T. + * Solving for X, we get X = (O - T)/T * W. 
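Plugging illustrative numbers into that formula: with an observed rate O = 15 MB/s against a quota T = 10 MB/s measured over a W = 30 s window, the delay is X = (15 - 10) / 10 * 30 s = 15 s. The sketch below simply restates that arithmetic with made-up values:

```scala
// Illustrative values only; in throttleTime() the window W comes from the violated metric's config.
val observed = 15.0        // O, MB/s
val quota    = 10.0        // T, MB/s
val windowMs = 30 * 1000L  // W, in milliseconds
val throttleMs = (observed - quota) / quota * windowMs // 15000.0 ms
```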
+ * + * @param timeMs current time in milliseconds + * @return Delay in milliseconds + */ + def throttleTime(e: QuotaViolationException, timeMs: Long): Long = { + val difference = e.value - e.bound + // Use the precise window used by the rate calculation + val throttleTimeMs = difference / e.bound * windowSize(e.metric, timeMs) + Math.round(throttleTimeMs) + } + + /** + * Calculates the amount of time needed to bring the observed rate within quota using the same algorithm as + * throttleTime() utility method but the returned value is capped to given maxThrottleTime + */ + def boundedThrottleTime(e: QuotaViolationException, maxThrottleTime: Long, timeMs: Long): Long = { + math.min(throttleTime(e, timeMs), maxThrottleTime) + } + + /** + * Returns window size of the given metric + * + * @param metric metric with measurable of type Rate + * @param timeMs current time in milliseconds + * @throws IllegalArgumentException if given measurable is not Rate + */ + private def windowSize(metric: KafkaMetric, timeMs: Long): Long = + measurableAsRate(metric.metricName, metric.measurable).windowSize(metric.config, timeMs) + + /** + * Casts provided Measurable to Rate + * @throws IllegalArgumentException if given measurable is not Rate + */ + private def measurableAsRate(name: MetricName, measurable: Measurable): Rate = { + measurable match { + case r: Rate => r + case _ => throw new IllegalArgumentException(s"Metric $name is not a Rate metric, value $measurable") + } + } +} diff --git a/core/src/main/scala/kafka/utils/ReplicationUtils.scala b/core/src/main/scala/kafka/utils/ReplicationUtils.scala new file mode 100644 index 0000000..e2733b8 --- /dev/null +++ b/core/src/main/scala/kafka/utils/ReplicationUtils.scala @@ -0,0 +1,56 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
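The main convenience of the `Pool` class (added just before `QuotaUtils`) over a bare `ConcurrentHashMap` is `getAndMaybePut`; a sketch with arbitrary keys:

```scala
import java.util.concurrent.atomic.AtomicLong
import kafka.utils.Pool

// A pool whose values are created on demand by the supplied factory.
val counters = new Pool[String, AtomicLong](valueFactory = Some(_ => new AtomicLong(0)))
counters.getAndMaybePut("topic-a").incrementAndGet() // creates the counter on first access
counters.getAndMaybePut("topic-a").incrementAndGet()
println(counters.get("topic-a").get()) // 2

// Without a factory, the caller supplies the value lazily per call.
val cache = new Pool[String, String]()
cache.getAndMaybePut("k", "expensive-to-compute")
```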
+ */ + +package kafka.utils + +import kafka.api.LeaderAndIsr +import kafka.controller.LeaderIsrAndControllerEpoch +import kafka.zk._ +import org.apache.kafka.common.TopicPartition + +object ReplicationUtils extends Logging { + + def updateLeaderAndIsr(zkClient: KafkaZkClient, partition: TopicPartition, newLeaderAndIsr: LeaderAndIsr, + controllerEpoch: Int): (Boolean, Int) = { + debug(s"Updated ISR for $partition to ${newLeaderAndIsr.isr.mkString(",")}") + val path = TopicPartitionStateZNode.path(partition) + val newLeaderData = TopicPartitionStateZNode.encode(LeaderIsrAndControllerEpoch(newLeaderAndIsr, controllerEpoch)) + // use the epoch of the controller that made the leadership decision, instead of the current controller epoch + val updatePersistentPath: (Boolean, Int) = zkClient.conditionalUpdatePath(path, newLeaderData, + newLeaderAndIsr.zkVersion, Some(checkLeaderAndIsrZkData)) + updatePersistentPath + } + + private def checkLeaderAndIsrZkData(zkClient: KafkaZkClient, path: String, expectedLeaderAndIsrInfo: Array[Byte]): (Boolean, Int) = { + try { + val (writtenLeaderOpt, writtenStat) = zkClient.getDataAndStat(path) + val expectedLeaderOpt = TopicPartitionStateZNode.decode(expectedLeaderAndIsrInfo, writtenStat) + val succeeded = writtenLeaderOpt.exists { writtenData => + val writtenLeaderOpt = TopicPartitionStateZNode.decode(writtenData, writtenStat) + (expectedLeaderOpt, writtenLeaderOpt) match { + case (Some(expectedLeader), Some(writtenLeader)) if expectedLeader == writtenLeader => true + case _ => false + } + } + if (succeeded) (true, writtenStat.getVersion) + else (false, -1) + } catch { + case _: Exception => (false, -1) + } + } + +} diff --git a/core/src/main/scala/kafka/utils/ShutdownableThread.scala b/core/src/main/scala/kafka/utils/ShutdownableThread.scala new file mode 100644 index 0000000..0ca21c4 --- /dev/null +++ b/core/src/main/scala/kafka/utils/ShutdownableThread.scala @@ -0,0 +1,113 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package kafka.utils + +import java.util.concurrent.{CountDownLatch, TimeUnit} + +import org.apache.kafka.common.internals.FatalExitError + +abstract class ShutdownableThread(val name: String, val isInterruptible: Boolean = true) + extends Thread(name) with Logging { + this.setDaemon(false) + this.logIdent = "[" + name + "]: " + private val shutdownInitiated = new CountDownLatch(1) + private val shutdownComplete = new CountDownLatch(1) + @volatile private var isStarted: Boolean = false + + def shutdown(): Unit = { + initiateShutdown() + awaitShutdown() + } + + def isShutdownInitiated: Boolean = shutdownInitiated.getCount == 0 + + def isShutdownComplete: Boolean = shutdownComplete.getCount == 0 + + /** + * @return true if there has been an unexpected error and the thread shut down + */ + // mind that run() might set both when we're shutting down the broker + // but the return value of this function at that point wouldn't matter + def isThreadFailed: Boolean = isShutdownComplete && !isShutdownInitiated + + def initiateShutdown(): Boolean = { + this.synchronized { + if (isRunning) { + info("Shutting down") + shutdownInitiated.countDown() + if (isInterruptible) + interrupt() + true + } else + false + } + } + + /** + * After calling initiateShutdown(), use this API to wait until the shutdown is complete + */ + def awaitShutdown(): Unit = { + if (!isShutdownInitiated) + throw new IllegalStateException("initiateShutdown() was not called before awaitShutdown()") + else { + if (isStarted) + shutdownComplete.await() + info("Shutdown completed") + } + } + + /** + * Causes the current thread to wait until the shutdown is initiated, + * or the specified waiting time elapses. + * + * @param timeout + * @param unit + */ + def pause(timeout: Long, unit: TimeUnit): Unit = { + if (shutdownInitiated.await(timeout, unit)) + trace("shutdownInitiated latch count reached zero. Shutdown called.") + } + + /** + * This method is repeatedly invoked until the thread shuts down or this method throws an exception + */ + def doWork(): Unit + + override def run(): Unit = { + isStarted = true + info("Starting") + try { + while (isRunning) + doWork() + } catch { + case e: FatalExitError => + shutdownInitiated.countDown() + shutdownComplete.countDown() + info("Stopped") + Exit.exit(e.statusCode()) + case e: Throwable => + if (isRunning) + error("Error due to", e) + } finally { + shutdownComplete.countDown() + } + info("Stopped") + } + + def isRunning: Boolean = !isShutdownInitiated +} diff --git a/core/src/main/scala/kafka/utils/Throttler.scala b/core/src/main/scala/kafka/utils/Throttler.scala new file mode 100644 index 0000000..cce6270 --- /dev/null +++ b/core/src/main/scala/kafka/utils/Throttler.scala @@ -0,0 +1,105 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
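A sketch of a `ShutdownableThread` subclass following the contract above; the work body is arbitrary:

```scala
import java.util.concurrent.TimeUnit
import kafka.utils.ShutdownableThread

class ExamplePoller extends ShutdownableThread(name = "example-poller") {
  // Invoked repeatedly by run() until shutdown is initiated.
  override def doWork(): Unit = {
    // ... poll a source, process a batch ...
    pause(100, TimeUnit.MILLISECONDS) // waits, returning early once shutdown is initiated
  }
}

val poller = new ExamplePoller
poller.start()    // run() loops, calling doWork() while isRunning is true
poller.shutdown() // initiateShutdown() followed by awaitShutdown()
```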
+ */ + +package kafka.utils + +import kafka.metrics.KafkaMetricsGroup +import org.apache.kafka.common.utils.Time + +import java.util.concurrent.TimeUnit +import java.util.Random + +import scala.math._ + +/** + * A class to measure and throttle the rate of some process. The throttler takes a desired rate-per-second + * (the units of the process don't matter, it could be bytes or a count of some other thing), and will sleep for + * an appropriate amount of time when maybeThrottle() is called to attain the desired rate. + * + * @param desiredRatePerSec: The rate we want to hit in units/sec + * @param checkIntervalMs: The interval at which to check our rate + * @param throttleDown: Does throttling increase or decrease our rate? + * @param time: The time implementation to use + */ +@threadsafe +class Throttler(desiredRatePerSec: Double, + checkIntervalMs: Long = 100L, + throttleDown: Boolean = true, + metricName: String = "throttler", + units: String = "entries", + time: Time = Time.SYSTEM) extends Logging with KafkaMetricsGroup { + + private val lock = new Object + private val meter = newMeter(metricName, units, TimeUnit.SECONDS) + private val checkIntervalNs = TimeUnit.MILLISECONDS.toNanos(checkIntervalMs) + private var periodStartNs: Long = time.nanoseconds + private var observedSoFar: Double = 0.0 + + def maybeThrottle(observed: Double): Unit = { + val msPerSec = TimeUnit.SECONDS.toMillis(1) + val nsPerSec = TimeUnit.SECONDS.toNanos(1) + + meter.mark(observed.toLong) + lock synchronized { + observedSoFar += observed + val now = time.nanoseconds + val elapsedNs = now - periodStartNs + // if we have completed an interval AND we have observed something, maybe + // we should take a little nap + if (elapsedNs > checkIntervalNs && observedSoFar > 0) { + val rateInSecs = (observedSoFar * nsPerSec) / elapsedNs + val needAdjustment = !(throttleDown ^ (rateInSecs > desiredRatePerSec)) + if (needAdjustment) { + // solve for the amount of time to sleep to make us hit the desired rate + val desiredRateMs = desiredRatePerSec / msPerSec.toDouble + val elapsedMs = TimeUnit.NANOSECONDS.toMillis(elapsedNs) + val sleepTime = round(observedSoFar / desiredRateMs - elapsedMs) + if (sleepTime > 0) { + trace("Natural rate is %f per second but desired rate is %f, sleeping for %d ms to compensate.".format(rateInSecs, desiredRatePerSec, sleepTime)) + time.sleep(sleepTime) + } + } + periodStartNs = time.nanoseconds() + observedSoFar = 0 + } + } + } + +} + +object Throttler { + + def main(args: Array[String]): Unit = { + val rand = new Random() + val throttler = new Throttler(100000, 100, true, time = Time.SYSTEM) + val interval = 30000 + var start = System.currentTimeMillis + var total = 0 + while(true) { + val value = rand.nextInt(1000) + Thread.sleep(1) + throttler.maybeThrottle(value) + total += value + val now = System.currentTimeMillis + if(now - start >= interval) { + println(total / (interval/1000.0)) + start = now + total = 0 + } + } + } +} diff --git a/core/src/main/scala/kafka/utils/ToolsUtils.scala b/core/src/main/scala/kafka/utils/ToolsUtils.scala new file mode 100644 index 0000000..0f3de76 --- /dev/null +++ b/core/src/main/scala/kafka/utils/ToolsUtils.scala @@ -0,0 +1,67 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. +*/ +package kafka.utils + +import joptsimple.OptionParser +import org.apache.kafka.common.{Metric, MetricName} + +import scala.collection.mutable + +object ToolsUtils { + + def validatePortOrDie(parser: OptionParser, hostPort: String) = { + val hostPorts: Array[String] = if(hostPort.contains(',')) + hostPort.split(",") + else + Array(hostPort) + val validHostPort = hostPorts.filter { hostPortData => + org.apache.kafka.common.utils.Utils.getPort(hostPortData) != null + } + val isValid = !validHostPort.isEmpty && validHostPort.size == hostPorts.length + if(!isValid) + CommandLineUtils.printUsageAndDie(parser, "Please provide valid host:port like host1:9091,host2:9092\n ") + } + + /** + * print out the metrics in alphabetical order + * @param metrics the metrics to be printed out + */ + def printMetrics(metrics: mutable.Map[MetricName, _ <: Metric]): Unit = { + var maxLengthOfDisplayName = 0 + + val sortedMap = metrics.toSeq.sortWith( (s,t) => + Array(s._1.group(), s._1.name(), s._1.tags()).mkString(":") + .compareTo(Array(t._1.group(), t._1.name(), t._1.tags()).mkString(":")) < 0 + ).map { + case (key, value) => + val mergedKeyName = Array(key.group(), key.name(), key.tags()).mkString(":") + if (maxLengthOfDisplayName < mergedKeyName.length) { + maxLengthOfDisplayName = mergedKeyName.length + } + (mergedKeyName, value.metricValue) + } + println(s"\n%-${maxLengthOfDisplayName}s %s".format("Metric Name", "Value")) + sortedMap.foreach { + case (metricName, value) => + val specifier = value match { + case _ @ (_: java.lang.Float | _: java.lang.Double) => "%.3f" + case _ => "%s" + } + println(s"%-${maxLengthOfDisplayName}s : $specifier".format(metricName, value)) + } + } +} diff --git a/core/src/main/scala/kafka/utils/TopicFilter.scala b/core/src/main/scala/kafka/utils/TopicFilter.scala new file mode 100644 index 0000000..075c147 --- /dev/null +++ b/core/src/main/scala/kafka/utils/TopicFilter.scala @@ -0,0 +1,55 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package kafka.utils + +import java.util.regex.{Pattern, PatternSyntaxException} + +import org.apache.kafka.common.internals.Topic + +sealed abstract class TopicFilter(rawRegex: String) extends Logging { + + val regex = rawRegex + .trim + .replace(',', '|') + .replace(" ", "") + .replaceAll("""^["']+""","") + .replaceAll("""["']+$""","") // property files may bring quotes + + try { + Pattern.compile(regex) + } + catch { + case _: PatternSyntaxException => + throw new RuntimeException(regex + " is an invalid regex.") + } + + override def toString = regex + + def isTopicAllowed(topic: String, excludeInternalTopics: Boolean): Boolean +} + +case class IncludeList(rawRegex: String) extends TopicFilter(rawRegex) { + override def isTopicAllowed(topic: String, excludeInternalTopics: Boolean) = { + val allowed = topic.matches(regex) && !(Topic.isInternal(topic) && excludeInternalTopics) + + debug("%s %s".format( + topic, if (allowed) "allowed" else "filtered")) + + allowed + } +} diff --git a/core/src/main/scala/kafka/utils/VerifiableProperties.scala b/core/src/main/scala/kafka/utils/VerifiableProperties.scala new file mode 100755 index 0000000..878398a --- /dev/null +++ b/core/src/main/scala/kafka/utils/VerifiableProperties.scala @@ -0,0 +1,238 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
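A sketch of the `IncludeList` filter above; the pattern and topic names are examples:

```scala
import kafka.utils.IncludeList

val filter = IncludeList("test-.*, __consumer_offsets") // commas become regex alternation
println(filter.isTopicAllowed("test-1", excludeInternalTopics = true))              // true
println(filter.isTopicAllowed("other", excludeInternalTopics = true))               // false: no match
println(filter.isTopicAllowed("__consumer_offsets", excludeInternalTopics = true))  // false: internal topic excluded
println(filter.isTopicAllowed("__consumer_offsets", excludeInternalTopics = false)) // true
```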
+ */ + +package kafka.utils + +import java.util.Properties +import java.util.Collections +import scala.collection._ +import kafka.message.{CompressionCodec, NoCompressionCodec} +import scala.jdk.CollectionConverters._ +import kafka.utils.Implicits._ + +object VerifiableProperties { + def apply(map: java.util.Map[String, AnyRef]): VerifiableProperties = { + val props = new Properties() + props ++= map.asScala + new VerifiableProperties(props) + } +} + +class VerifiableProperties(val props: Properties) extends Logging { + private val referenceSet = mutable.HashSet[String]() + + def this() = this(new Properties) + + def containsKey(name: String): Boolean = { + props.containsKey(name) + } + + def getProperty(name: String): String = { + val value = props.getProperty(name) + referenceSet.add(name) + if(value == null) value else value.trim() + } + + /** + * Read a required integer property value or throw an exception if no such property is found + */ + def getInt(name: String): Int = getString(name).toInt + + def getIntInRange(name: String, range: (Int, Int)): Int = { + require(containsKey(name), "Missing required property '" + name + "'") + getIntInRange(name, -1, range) + } + + /** + * Read an integer from the properties instance + * @param name The property name + * @param default The default value to use if the property is not found + * @return the integer value + */ + def getInt(name: String, default: Int): Int = + getIntInRange(name, default, (Int.MinValue, Int.MaxValue)) + + def getShort(name: String, default: Short): Short = + getShortInRange(name, default, (Short.MinValue, Short.MaxValue)) + + /** + * Read an integer from the properties instance. Throw an exception + * if the value is not in the given range (inclusive) + * @param name The property name + * @param default The default value to use if the property is not found + * @param range The range in which the value must fall (inclusive) + * @throws IllegalArgumentException If the value is not in the given range + * @return the integer value + */ + def getIntInRange(name: String, default: Int, range: (Int, Int)): Int = { + val v = + if(containsKey(name)) + getProperty(name).toInt + else + default + require(v >= range._1 && v <= range._2, name + " has value " + v + " which is not in the range " + range + ".") + v + } + + def getShortInRange(name: String, default: Short, range: (Short, Short)): Short = { + val v = + if(containsKey(name)) + getProperty(name).toShort + else + default + require(v >= range._1 && v <= range._2, name + " has value " + v + " which is not in the range " + range + ".") + v + } + + /** + * Read a required long property value or throw an exception if no such property is found + */ + def getLong(name: String): Long = getString(name).toLong + + /** + * Read an long from the properties instance + * @param name The property name + * @param default The default value to use if the property is not found + * @return the long value + */ + def getLong(name: String, default: Long): Long = + getLongInRange(name, default, (Long.MinValue, Long.MaxValue)) + + /** + * Read an long from the properties instance. 
Throw an exception + * if the value is not in the given range (inclusive) + * @param name The property name + * @param default The default value to use if the property is not found + * @param range The range in which the value must fall (inclusive) + * @throws IllegalArgumentException If the value is not in the given range + * @return the long value + */ + def getLongInRange(name: String, default: Long, range: (Long, Long)): Long = { + val v = + if(containsKey(name)) + getProperty(name).toLong + else + default + require(v >= range._1 && v <= range._2, name + " has value " + v + " which is not in the range " + range + ".") + v + } + + /** + * Get a required argument as a double + * @param name The property name + * @return the value + * @throws IllegalArgumentException If the given property is not present + */ + def getDouble(name: String): Double = getString(name).toDouble + + /** + * Get an optional argument as a double + * @param name The property name + * @param default The default value for the property if not present + */ + def getDouble(name: String, default: Double): Double = { + if(containsKey(name)) + getDouble(name) + else + default + } + + /** + * Read a boolean value from the properties instance + * @param name The property name + * @param default The default value to use if the property is not found + * @return the boolean value + */ + def getBoolean(name: String, default: Boolean): Boolean = { + if(!containsKey(name)) + default + else { + val v = getProperty(name) + require(v == "true" || v == "false", "Unacceptable value for property '" + name + "', boolean values must be either 'true' or 'false") + v.toBoolean + } + } + + def getBoolean(name: String) = getString(name).toBoolean + + /** + * Get a string property, or, if no such property is defined, return the given default value + */ + def getString(name: String, default: String): String = { + if(containsKey(name)) + getProperty(name) + else + default + } + + /** + * Get a string property or throw and exception if no such property is defined. + */ + def getString(name: String): String = { + require(containsKey(name), "Missing required property '" + name + "'") + getProperty(name) + } + + /** + * Get a Map[String, String] from a property list in the form k1:v2, k2:v2, ... + */ + def getMap(name: String, valid: String => Boolean = _ => true): Map[String, String] = { + try { + val m = CoreUtils.parseCsvMap(getString(name, "")) + m.foreach { + case(key, value) => + if(!valid(value)) + throw new IllegalArgumentException("Invalid entry '%s' = '%s' for property '%s'".format(key, value, name)) + } + m + } catch { + case e: Exception => throw new IllegalArgumentException("Error parsing configuration property '%s': %s".format(name, e.getMessage)) + } + } + + /** + * Parse compression codec from a property list in either. Codecs may be specified as integers, or as strings. + * See [[kafka.message.CompressionCodec]] for more details. 
+ * @param name The property name + * @param default Default compression codec + * @return compression codec + */ + def getCompressionCodec(name: String, default: CompressionCodec) = { + val prop = getString(name, NoCompressionCodec.name) + try { + CompressionCodec.getCompressionCodec(prop.toInt) + } + catch { + case _: NumberFormatException => + CompressionCodec.getCompressionCodec(prop) + } + } + + def verify(): Unit = { + info("Verifying properties") + val propNames = Collections.list(props.propertyNames).asScala.map(_.toString).sorted + for(key <- propNames) { + if (!referenceSet.contains(key) && !key.startsWith("external")) + warn("Property %s is not valid".format(key)) + else + info("Property %s is overridden to %s".format(key, props.getProperty(key))) + } + } + + override def toString: String = props.toString + +} diff --git a/core/src/main/scala/kafka/utils/VersionInfo.scala b/core/src/main/scala/kafka/utils/VersionInfo.scala new file mode 100644 index 0000000..9d3130e --- /dev/null +++ b/core/src/main/scala/kafka/utils/VersionInfo.scala @@ -0,0 +1,40 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.utils + +import org.apache.kafka.common.utils.AppInfoParser + +object VersionInfo { + + def main(args: Array[String]): Unit = { + System.out.println(getVersionString) + System.exit(0) + } + + def getVersion: String = { + AppInfoParser.getVersion + } + + def getCommit: String = { + AppInfoParser.getCommitId + } + + def getVersionString: String = { + s"${getVersion} (Commit:${getCommit})" + } +} diff --git a/core/src/main/scala/kafka/utils/json/DecodeJson.scala b/core/src/main/scala/kafka/utils/json/DecodeJson.scala new file mode 100644 index 0000000..71bfd61 --- /dev/null +++ b/core/src/main/scala/kafka/utils/json/DecodeJson.scala @@ -0,0 +1,111 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
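// A small, hedged sketch of the DecodeJson type class defined below: decodeEither returns
// Right on success and Left with an error message on a type mismatch.
import com.fasterxml.jackson.databind.ObjectMapper
val mapper = new ObjectMapper()
DecodeJson.DecodeInt.decodeEither(mapper.readTree("42"))       // Right(42)
DecodeJson.DecodeInt.decodeEither(mapper.readTree("\"oops\"")) // Left("Expected `Int` value, received ...")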
+ */ + +package kafka.utils.json + +import scala.collection.{Map, Seq} +import scala.collection.compat._ +import scala.jdk.CollectionConverters._ + +import com.fasterxml.jackson.databind.{JsonMappingException, JsonNode} + +/** + * A type class for parsing JSON. This should typically be used via `JsonValue.apply`. + */ +trait DecodeJson[T] { + + /** + * Decode the JSON node provided into an instance of `Right[T]`, if possible. Otherwise, return an error message + * wrapped by an instance of `Left`. + */ + def decodeEither(node: JsonNode): Either[String, T] + + /** + * Decode the JSON node provided into an instance of `T`. + * + * @throws JsonMappingException if `node` cannot be decoded into `T`. + */ + def decode(node: JsonNode): T = + decodeEither(node) match { + case Right(x) => x + case Left(x) => throw new JsonMappingException(null, x) + } + +} + +/** + * Contains `DecodeJson` type class instances. That is, we need one instance for each type that we want to be able to + * to parse into. It is a compiler error to try to parse into a type for which there is no instance. + */ +object DecodeJson { + + implicit object DecodeBoolean extends DecodeJson[Boolean] { + def decodeEither(node: JsonNode): Either[String, Boolean] = + if (node.isBoolean) Right(node.booleanValue) else Left(s"Expected `Boolean` value, received $node") + } + + implicit object DecodeDouble extends DecodeJson[Double] { + def decodeEither(node: JsonNode): Either[String, Double] = + if (node.isDouble || node.isLong || node.isInt) + Right(node.doubleValue) + else Left(s"Expected `Double` value, received $node") + } + + implicit object DecodeInt extends DecodeJson[Int] { + def decodeEither(node: JsonNode): Either[String, Int] = + if (node.isInt) Right(node.intValue) else Left(s"Expected `Int` value, received $node") + } + + implicit object DecodeLong extends DecodeJson[Long] { + def decodeEither(node: JsonNode): Either[String, Long] = + if (node.isLong || node.isInt) Right(node.longValue) else Left(s"Expected `Long` value, received $node") + } + + implicit object DecodeString extends DecodeJson[String] { + def decodeEither(node: JsonNode): Either[String, String] = + if (node.isTextual) Right(node.textValue) else Left(s"Expected `String` value, received $node") + } + + implicit def decodeOption[E](implicit decodeJson: DecodeJson[E]): DecodeJson[Option[E]] = (node: JsonNode) => { + if (node.isNull) Right(None) + else decodeJson.decodeEither(node).map(Some(_)) + } + + implicit def decodeSeq[E, S[+T] <: Seq[E]](implicit decodeJson: DecodeJson[E], factory: Factory[E, S[E]]): DecodeJson[S[E]] = (node: JsonNode) => { + if (node.isArray) + decodeIterator(node.elements.asScala)(decodeJson.decodeEither) + else Left(s"Expected JSON array, received $node") + } + + implicit def decodeMap[V, M[K, +V] <: Map[K, V]](implicit decodeJson: DecodeJson[V], factory: Factory[(String, V), M[String, V]]): DecodeJson[M[String, V]] = (node: JsonNode) => { + if (node.isObject) + decodeIterator(node.fields.asScala)(e => decodeJson.decodeEither(e.getValue).map(v => (e.getKey, v))) + else Left(s"Expected JSON object, received $node") + } + + private def decodeIterator[S, T, C](it: Iterator[S])(f: S => Either[String, T])(implicit factory: Factory[T, C]): Either[String, C] = { + val result = factory.newBuilder + while (it.hasNext) { + f(it.next()) match { + case Right(x) => result += x + case Left(x) => return Left(x) + } + } + Right(result.result()) + } + +} diff --git a/core/src/main/scala/kafka/utils/json/JsonArray.scala 
b/core/src/main/scala/kafka/utils/json/JsonArray.scala new file mode 100644 index 0000000..c22eda8 --- /dev/null +++ b/core/src/main/scala/kafka/utils/json/JsonArray.scala @@ -0,0 +1,27 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.utils.json + +import scala.collection.Iterator +import scala.jdk.CollectionConverters._ + +import com.fasterxml.jackson.databind.node.ArrayNode + +class JsonArray private[json] (protected val node: ArrayNode) extends JsonValue { + def iterator: Iterator[JsonValue] = node.elements.asScala.map(JsonValue(_)) +} diff --git a/core/src/main/scala/kafka/utils/json/JsonObject.scala b/core/src/main/scala/kafka/utils/json/JsonObject.scala new file mode 100644 index 0000000..9bf91ae --- /dev/null +++ b/core/src/main/scala/kafka/utils/json/JsonObject.scala @@ -0,0 +1,42 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.utils.json + +import com.fasterxml.jackson.databind.JsonMappingException + +import scala.jdk.CollectionConverters._ + +import com.fasterxml.jackson.databind.node.ObjectNode + +import scala.collection.Iterator + +/** + * A thin wrapper over Jackson's `ObjectNode` for a more idiomatic API. See `JsonValue` for more details. 
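// A hedged usage sketch: wrapping a Jackson node and reading fields through JsonObject.
import com.fasterxml.jackson.databind.ObjectMapper
val objNode = new ObjectMapper().readTree("""{"a": 1, "b": "two"}""")
val jsonObject = JsonValue(objNode).asJsonObject
jsonObject("a").to[Int]   // 1
jsonObject.get("missing") // None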
+ */ +class JsonObject private[json] (protected val node: ObjectNode) extends JsonValue { + + def apply(name: String): JsonValue = + get(name).getOrElse(throw new JsonMappingException(null, s"No such field exists: `$name`")) + + def get(name: String): Option[JsonValue] = Option(node.get(name)).map(JsonValue(_)) + + def iterator: Iterator[(String, JsonValue)] = node.fields.asScala.map { entry => + (entry.getKey, JsonValue(entry.getValue)) + } + +} diff --git a/core/src/main/scala/kafka/utils/json/JsonValue.scala b/core/src/main/scala/kafka/utils/json/JsonValue.scala new file mode 100644 index 0000000..ff62c6c --- /dev/null +++ b/core/src/main/scala/kafka/utils/json/JsonValue.scala @@ -0,0 +1,116 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.utils.json + +import com.fasterxml.jackson.databind.{JsonMappingException, JsonNode} +import com.fasterxml.jackson.databind.node.{ArrayNode, ObjectNode} + +/** + * A simple wrapper over Jackson's JsonNode that enables type safe parsing via the `DecodeJson` type + * class. + * + * Typical usage would be something like: + * + * {{{ + * val jsonNode: JsonNode = ??? + * val jsonObject = JsonValue(jsonNode).asJsonObject + * val intValue = jsonObject("int_field").to[Int] + * val optionLongValue = jsonObject("option_long_field").to[Option[Long]] + * val mapStringIntField = jsonObject("map_string_int_field").to[Map[String, Int]] + * val seqStringField = jsonObject("seq_string_field").to[Seq[String] + * }}} + * + * The `to` method throws an exception if the value cannot be converted to the requested type. An alternative is the + * `toEither` method that returns an `Either` instead. + */ +trait JsonValue { + + protected def node: JsonNode + + /** + * Decode this JSON value into an instance of `T`. + * + * @throws JsonMappingException if this value cannot be decoded into `T`. + */ + def to[T](implicit decodeJson: DecodeJson[T]): T = decodeJson.decode(node) + + /** + * Decode this JSON value into an instance of `Right[T]`, if possible. Otherwise, return an error message + * wrapped by an instance of `Left`. + */ + def toEither[T](implicit decodeJson: DecodeJson[T]): Either[String, T] = decodeJson.decodeEither(node) + + /** + * If this is a JSON object, return an instance of JsonObject. Otherwise, throw a JsonMappingException. + */ + def asJsonObject: JsonObject = + asJsonObjectOption.getOrElse(throw new JsonMappingException(null, s"Expected JSON object, received $node")) + + /** + * If this is a JSON object, return a JsonObject wrapped by a `Some`. Otherwise, return None. 
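// A hedged sketch of the safe variants: a mismatched shape yields None instead of throwing.
import com.fasterxml.jackson.databind.ObjectMapper
JsonValue(new ObjectMapper().readTree("[1, 2, 3]")).asJsonObjectOption // None, it is an array
JsonValue(new ObjectMapper().readTree("[1, 2, 3]")).asJsonArrayOption  // Some(JsonArray)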
+ */ + def asJsonObjectOption: Option[JsonObject] = this match { + case j: JsonObject => Some(j) + case _ => node match { + case n: ObjectNode => Some(new JsonObject(n)) + case _ => None + } + } + + /** + * If this is a JSON array, return an instance of JsonArray. Otherwise, throw a JsonMappingException. + */ + def asJsonArray: JsonArray = + asJsonArrayOption.getOrElse(throw new JsonMappingException(null, s"Expected JSON array, received $node")) + + /** + * If this is a JSON array, return a JsonArray wrapped by a `Some`. Otherwise, return None. + */ + def asJsonArrayOption: Option[JsonArray] = this match { + case j: JsonArray => Some(j) + case _ => node match { + case n: ArrayNode => Some(new JsonArray(n)) + case _ => None + } + } + + override def hashCode: Int = node.hashCode + + override def equals(a: Any): Boolean = a match { + case a: JsonValue => node == a.node + case _ => false + } + + override def toString: String = node.toString + +} + +object JsonValue { + + /** + * Create an instance of `JsonValue` from Jackson's `JsonNode`. + */ + def apply(node: JsonNode): JsonValue = node match { + case n: ObjectNode => new JsonObject(n) + case n: ArrayNode => new JsonArray(n) + case _ => new BasicJsonValue(node) + } + + private class BasicJsonValue private[json] (protected val node: JsonNode) extends JsonValue + +} diff --git a/core/src/main/scala/kafka/utils/timer/Timer.scala b/core/src/main/scala/kafka/utils/timer/Timer.scala new file mode 100644 index 0000000..6973dcb --- /dev/null +++ b/core/src/main/scala/kafka/utils/timer/Timer.scala @@ -0,0 +1,125 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package kafka.utils.timer + +import java.util.concurrent.atomic.AtomicInteger +import java.util.concurrent.locks.ReentrantReadWriteLock +import java.util.concurrent.{DelayQueue, Executors, TimeUnit} + +import kafka.utils.threadsafe +import org.apache.kafka.common.utils.{KafkaThread, Time} + +trait Timer { + /** + * Add a new task to this executor. It will be executed after the task's delay + * (beginning from the time of submission) + * @param timerTask the task to add + */ + def add(timerTask: TimerTask): Unit + + /** + * Advance the internal clock, executing any tasks whose expiration has been + * reached within the duration of the passed timeout. 
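// A hedged end-to-end sketch of the Timer API using the SystemTimer implementation below:
// a task is scheduled, the clock is advanced, and the expired task runs on the executor.
val timer = new SystemTimer("demo")
timer.add(new TimerTask {
  val delayMs: Long = 100
  def run(): Unit = println("timer task fired")
})
timer.advanceClock(200) // blocks up to 200 ms; returns true once the bucket expires
timer.shutdown()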
+ * @param timeoutMs + * @return whether or not any tasks were executed + */ + def advanceClock(timeoutMs: Long): Boolean + + /** + * Get the number of tasks pending execution + * @return the number of tasks + */ + def size: Int + + /** + * Shutdown the timer service, leaving pending tasks unexecuted + */ + def shutdown(): Unit +} + +@threadsafe +class SystemTimer(executorName: String, + tickMs: Long = 1, + wheelSize: Int = 20, + startMs: Long = Time.SYSTEM.hiResClockMs) extends Timer { + + // timeout timer + private[this] val taskExecutor = Executors.newFixedThreadPool(1, + (runnable: Runnable) => KafkaThread.nonDaemon("executor-" + executorName, runnable)) + + private[this] val delayQueue = new DelayQueue[TimerTaskList]() + private[this] val taskCounter = new AtomicInteger(0) + private[this] val timingWheel = new TimingWheel( + tickMs = tickMs, + wheelSize = wheelSize, + startMs = startMs, + taskCounter = taskCounter, + delayQueue + ) + + // Locks used to protect data structures while ticking + private[this] val readWriteLock = new ReentrantReadWriteLock() + private[this] val readLock = readWriteLock.readLock() + private[this] val writeLock = readWriteLock.writeLock() + + def add(timerTask: TimerTask): Unit = { + readLock.lock() + try { + addTimerTaskEntry(new TimerTaskEntry(timerTask, timerTask.delayMs + Time.SYSTEM.hiResClockMs)) + } finally { + readLock.unlock() + } + } + + private def addTimerTaskEntry(timerTaskEntry: TimerTaskEntry): Unit = { + if (!timingWheel.add(timerTaskEntry)) { + // Already expired or cancelled + if (!timerTaskEntry.cancelled) + taskExecutor.submit(timerTaskEntry.timerTask) + } + } + + /* + * Advances the clock if there is an expired bucket. If there isn't any expired bucket when called, + * waits up to timeoutMs before giving up. + */ + def advanceClock(timeoutMs: Long): Boolean = { + var bucket = delayQueue.poll(timeoutMs, TimeUnit.MILLISECONDS) + if (bucket != null) { + writeLock.lock() + try { + while (bucket != null) { + timingWheel.advanceClock(bucket.getExpiration) + bucket.flush(addTimerTaskEntry) + bucket = delayQueue.poll() + } + } finally { + writeLock.unlock() + } + true + } else { + false + } + } + + def size: Int = taskCounter.get + + override def shutdown(): Unit = { + taskExecutor.shutdown() + } + +} diff --git a/core/src/main/scala/kafka/utils/timer/TimerTask.scala b/core/src/main/scala/kafka/utils/timer/TimerTask.scala new file mode 100644 index 0000000..a1995c1 --- /dev/null +++ b/core/src/main/scala/kafka/utils/timer/TimerTask.scala @@ -0,0 +1,45 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
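// A hedged sketch of cancellation with the TimerTask trait defined below, reusing the
// SystemTimer instance from the previous sketch (an assumption): a cancelled task is
// unlinked from its TimerTaskList and is skipped when its bucket expires.
val task = new TimerTask {
  val delayMs: Long = 30000
  def run(): Unit = println("should never fire")
}
timer.add(task)
task.cancel()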
+ */ +package kafka.utils.timer + +trait TimerTask extends Runnable { + + val delayMs: Long // timestamp in millisecond + + private[this] var timerTaskEntry: TimerTaskEntry = null + + def cancel(): Unit = { + synchronized { + if (timerTaskEntry != null) timerTaskEntry.remove() + timerTaskEntry = null + } + } + + private[timer] def setTimerTaskEntry(entry: TimerTaskEntry): Unit = { + synchronized { + // if this timerTask is already held by an existing timer task entry, + // we will remove such an entry first. + if (timerTaskEntry != null && timerTaskEntry != entry) + timerTaskEntry.remove() + + timerTaskEntry = entry + } + } + + private[timer] def getTimerTaskEntry: TimerTaskEntry = timerTaskEntry + +} diff --git a/core/src/main/scala/kafka/utils/timer/TimerTaskList.scala b/core/src/main/scala/kafka/utils/timer/TimerTaskList.scala new file mode 100644 index 0000000..efd0480 --- /dev/null +++ b/core/src/main/scala/kafka/utils/timer/TimerTaskList.scala @@ -0,0 +1,159 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package kafka.utils.timer + +import java.util.concurrent.{Delayed, TimeUnit} +import java.util.concurrent.atomic.{AtomicInteger, AtomicLong} + +import kafka.utils.threadsafe +import org.apache.kafka.common.utils.Time + +import scala.math._ + +@threadsafe +private[timer] class TimerTaskList(taskCounter: AtomicInteger) extends Delayed { + + // TimerTaskList forms a doubly linked cyclic list using a dummy root entry + // root.next points to the head + // root.prev points to the tail + private[this] val root = new TimerTaskEntry(null, -1) + root.next = root + root.prev = root + + private[this] val expiration = new AtomicLong(-1L) + + // Set the bucket's expiration time + // Returns true if the expiration time is changed + def setExpiration(expirationMs: Long): Boolean = { + expiration.getAndSet(expirationMs) != expirationMs + } + + // Get the bucket's expiration time + def getExpiration: Long = expiration.get + + // Apply the supplied function to each of tasks in this list + def foreach(f: (TimerTask)=>Unit): Unit = { + synchronized { + var entry = root.next + while (entry ne root) { + val nextEntry = entry.next + + if (!entry.cancelled) f(entry.timerTask) + + entry = nextEntry + } + } + } + + // Add a timer task entry to this list + def add(timerTaskEntry: TimerTaskEntry): Unit = { + var done = false + while (!done) { + // Remove the timer task entry if it is already in any other list + // We do this outside of the sync block below to avoid deadlocking. + // We may retry until timerTaskEntry.list becomes null. + timerTaskEntry.remove() + + synchronized { + timerTaskEntry.synchronized { + if (timerTaskEntry.list == null) { + // put the timer task entry to the end of the list. 
(root.prev points to the tail entry) + val tail = root.prev + timerTaskEntry.next = root + timerTaskEntry.prev = tail + timerTaskEntry.list = this + tail.next = timerTaskEntry + root.prev = timerTaskEntry + taskCounter.incrementAndGet() + done = true + } + } + } + } + } + + // Remove the specified timer task entry from this list + def remove(timerTaskEntry: TimerTaskEntry): Unit = { + synchronized { + timerTaskEntry.synchronized { + if (timerTaskEntry.list eq this) { + timerTaskEntry.next.prev = timerTaskEntry.prev + timerTaskEntry.prev.next = timerTaskEntry.next + timerTaskEntry.next = null + timerTaskEntry.prev = null + timerTaskEntry.list = null + taskCounter.decrementAndGet() + } + } + } + } + + // Remove all task entries and apply the supplied function to each of them + def flush(f: TimerTaskEntry => Unit): Unit = { + synchronized { + var head = root.next + while (head ne root) { + remove(head) + f(head) + head = root.next + } + expiration.set(-1L) + } + } + + def getDelay(unit: TimeUnit): Long = { + unit.convert(max(getExpiration - Time.SYSTEM.hiResClockMs, 0), TimeUnit.MILLISECONDS) + } + + def compareTo(d: Delayed): Int = { + val other = d.asInstanceOf[TimerTaskList] + java.lang.Long.compare(getExpiration, other.getExpiration) + } + +} + +private[timer] class TimerTaskEntry(val timerTask: TimerTask, val expirationMs: Long) extends Ordered[TimerTaskEntry] { + + @volatile + var list: TimerTaskList = null + var next: TimerTaskEntry = null + var prev: TimerTaskEntry = null + + // if this timerTask is already held by an existing timer task entry, + // setTimerTaskEntry will remove it. + if (timerTask != null) timerTask.setTimerTaskEntry(this) + + def cancelled: Boolean = { + timerTask.getTimerTaskEntry != this + } + + def remove(): Unit = { + var currentList = list + // If remove is called when another thread is moving the entry from a task entry list to another, + // this may fail to remove the entry due to the change of value of list. Thus, we retry until the list becomes null. + // In a rare case, this thread sees null and exits the loop, but the other thread insert the entry to another list later. + while (currentList != null) { + currentList.remove(this) + currentList = list + } + } + + override def compare(that: TimerTaskEntry): Int = { + java.lang.Long.compare(expirationMs, that.expirationMs) + } +} + diff --git a/core/src/main/scala/kafka/utils/timer/TimingWheel.scala b/core/src/main/scala/kafka/utils/timer/TimingWheel.scala new file mode 100644 index 0000000..4535f3f --- /dev/null +++ b/core/src/main/scala/kafka/utils/timer/TimingWheel.scala @@ -0,0 +1,166 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
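// A hedged sketch of the bucket arithmetic used by TimingWheel.add below, assuming the
// SystemTimer defaults tickMs = 1 and wheelSize = 20.
val tickMs = 1L
val wheelSize = 20
val expirationMs = 1234L
val virtualId = expirationMs / tickMs            // 1234
val bucketIndex = (virtualId % wheelSize).toInt  // 14: the bucket this entry lands in
val bucketExpiration = virtualId * tickMs        // 1234: the value passed to setExpiration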
+ */ +package kafka.utils.timer + +import kafka.utils.nonthreadsafe + +import java.util.concurrent.DelayQueue +import java.util.concurrent.atomic.AtomicInteger + +/* + * Hierarchical Timing Wheels + * + * A simple timing wheel is a circular list of buckets of timer tasks. Let u be the time unit. + * A timing wheel with size n has n buckets and can hold timer tasks in n * u time interval. + * Each bucket holds timer tasks that fall into the corresponding time range. At the beginning, + * the first bucket holds tasks for [0, u), the second bucket holds tasks for [u, 2u), …, + * the n-th bucket for [u * (n -1), u * n). Every interval of time unit u, the timer ticks and + * moved to the next bucket then expire all timer tasks in it. So, the timer never insert a task + * into the bucket for the current time since it is already expired. The timer immediately runs + * the expired task. The emptied bucket is then available for the next round, so if the current + * bucket is for the time t, it becomes the bucket for [t + u * n, t + (n + 1) * u) after a tick. + * A timing wheel has O(1) cost for insert/delete (start-timer/stop-timer) whereas priority queue + * based timers, such as java.util.concurrent.DelayQueue and java.util.Timer, have O(log n) + * insert/delete cost. + * + * A major drawback of a simple timing wheel is that it assumes that a timer request is within + * the time interval of n * u from the current time. If a timer request is out of this interval, + * it is an overflow. A hierarchical timing wheel deals with such overflows. It is a hierarchically + * organized timing wheels. The lowest level has the finest time resolution. As moving up the + * hierarchy, time resolutions become coarser. If the resolution of a wheel at one level is u and + * the size is n, the resolution of the next level should be n * u. At each level overflows are + * delegated to the wheel in one level higher. When the wheel in the higher level ticks, it reinsert + * timer tasks to the lower level. An overflow wheel can be created on-demand. When a bucket in an + * overflow bucket expires, all tasks in it are reinserted into the timer recursively. The tasks + * are then moved to the finer grain wheels or be executed. The insert (start-timer) cost is O(m) + * where m is the number of wheels, which is usually very small compared to the number of requests + * in the system, and the delete (stop-timer) cost is still O(1). + * + * Example + * Let's say that u is 1 and n is 3. If the start time is c, + * then the buckets at different levels are: + * + * level buckets + * 1 [c,c] [c+1,c+1] [c+2,c+2] + * 2 [c,c+2] [c+3,c+5] [c+6,c+8] + * 3 [c,c+8] [c+9,c+17] [c+18,c+26] + * + * The bucket expiration is at the time of bucket beginning. + * So at time = c+1, buckets [c,c], [c,c+2] and [c,c+8] are expired. + * Level 1's clock moves to c+1, and [c+3,c+3] is created. + * Level 2 and level3's clock stay at c since their clocks move in unit of 3 and 9, respectively. + * So, no new buckets are created in level 2 and 3. + * + * Note that bucket [c,c+2] in level 2 won't receive any task since that range is already covered in level 1. + * The same is true for the bucket [c,c+8] in level 3 since its range is covered in level 2. + * This is a bit wasteful, but simplifies the implementation. + * + * 1 [c+1,c+1] [c+2,c+2] [c+3,c+3] + * 2 [c,c+2] [c+3,c+5] [c+6,c+8] + * 3 [c,c+8] [c+9,c+17] [c+18,c+26] + * + * At time = c+2, [c+1,c+1] is newly expired. 
+ * Level 1 moves to c+2, and [c+4,c+4] is created, + * + * 1 [c+2,c+2] [c+3,c+3] [c+4,c+4] + * 2 [c,c+2] [c+3,c+5] [c+6,c+8] + * 3 [c,c+8] [c+9,c+17] [c+18,c+18] + * + * At time = c+3, [c+2,c+2] is newly expired. + * Level 2 moves to c+3, and [c+5,c+5] and [c+9,c+11] are created. + * Level 3 stay at c. + * + * 1 [c+3,c+3] [c+4,c+4] [c+5,c+5] + * 2 [c+3,c+5] [c+6,c+8] [c+9,c+11] + * 3 [c,c+8] [c+9,c+17] [c+8,c+11] + * + * The hierarchical timing wheels works especially well when operations are completed before they time out. + * Even when everything times out, it still has advantageous when there are many items in the timer. + * Its insert cost (including reinsert) and delete cost are O(m) and O(1), respectively while priority + * queue based timers takes O(log N) for both insert and delete where N is the number of items in the queue. + * + * This class is not thread-safe. There should not be any add calls while advanceClock is executing. + * It is caller's responsibility to enforce it. Simultaneous add calls are thread-safe. + */ +@nonthreadsafe +private[timer] class TimingWheel(tickMs: Long, wheelSize: Int, startMs: Long, taskCounter: AtomicInteger, queue: DelayQueue[TimerTaskList]) { + + private[this] val interval = tickMs * wheelSize + private[this] val buckets = Array.tabulate[TimerTaskList](wheelSize) { _ => new TimerTaskList(taskCounter) } + + private[this] var currentTime = startMs - (startMs % tickMs) // rounding down to multiple of tickMs + + // overflowWheel can potentially be updated and read by two concurrent threads through add(). + // Therefore, it needs to be volatile due to the issue of Double-Checked Locking pattern with JVM + @volatile private[this] var overflowWheel: TimingWheel = null + + private[this] def addOverflowWheel(): Unit = { + synchronized { + if (overflowWheel == null) { + overflowWheel = new TimingWheel( + tickMs = interval, + wheelSize = wheelSize, + startMs = currentTime, + taskCounter = taskCounter, + queue + ) + } + } + } + + def add(timerTaskEntry: TimerTaskEntry): Boolean = { + val expiration = timerTaskEntry.expirationMs + + if (timerTaskEntry.cancelled) { + // Cancelled + false + } else if (expiration < currentTime + tickMs) { + // Already expired + false + } else if (expiration < currentTime + interval) { + // Put in its own bucket + val virtualId = expiration / tickMs + val bucket = buckets((virtualId % wheelSize.toLong).toInt) + bucket.add(timerTaskEntry) + + // Set the bucket expiration time + if (bucket.setExpiration(virtualId * tickMs)) { + // The bucket needs to be enqueued because it was an expired bucket + // We only need to enqueue the bucket when its expiration time has changed, i.e. the wheel has advanced + // and the previous buckets gets reused; further calls to set the expiration within the same wheel cycle + // will pass in the same value and hence return false, thus the bucket with the same expiration will not + // be enqueued multiple times. + queue.offer(bucket) + } + true + } else { + // Out of the interval. 
Put it into the parent timer + if (overflowWheel == null) addOverflowWheel() + overflowWheel.add(timerTaskEntry) + } + } + + // Try to advance the clock + def advanceClock(timeMs: Long): Unit = { + if (timeMs >= currentTime + tickMs) { + currentTime = timeMs - (timeMs % tickMs) + + // Try to advance the clock of the overflow wheel if present + if (overflowWheel != null) overflowWheel.advanceClock(currentTime) + } + } +} diff --git a/core/src/main/scala/kafka/zk/AdminZkClient.scala b/core/src/main/scala/kafka/zk/AdminZkClient.scala new file mode 100644 index 0000000..0c4e3b5 --- /dev/null +++ b/core/src/main/scala/kafka/zk/AdminZkClient.scala @@ -0,0 +1,544 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. +*/ +package kafka.zk + +import java.util.Properties + +import kafka.admin.{AdminOperationException, AdminUtils, BrokerMetadata, RackAwareMode} +import kafka.common.TopicAlreadyMarkedForDeletionException +import kafka.controller.ReplicaAssignment +import kafka.log.LogConfig +import kafka.server.{ConfigEntityName, ConfigType, DynamicConfig} +import kafka.utils._ +import kafka.utils.Implicits._ +import org.apache.kafka.common.{TopicPartition, Uuid} +import org.apache.kafka.common.errors._ +import org.apache.kafka.common.internals.Topic +import org.apache.zookeeper.KeeperException.NodeExistsException + +import scala.collection.{Map, Seq} + +/** + * Provides admin related methods for interacting with ZooKeeper. + * + * This is an internal class and no compatibility guarantees are provided, + * see org.apache.kafka.clients.admin.AdminClient for publicly supported APIs. + */ +class AdminZkClient(zkClient: KafkaZkClient) extends Logging { + + /** + * Creates the topic with given configuration + * @param topic topic name to create + * @param partitions Number of partitions to be set + * @param replicationFactor Replication factor + * @param topicConfig topic configs + * @param rackAwareMode rack aware mode for replica assignment + * @param usesTopicId Boolean indicating whether the topic ID will be created + */ + def createTopic(topic: String, + partitions: Int, + replicationFactor: Int, + topicConfig: Properties = new Properties, + rackAwareMode: RackAwareMode = RackAwareMode.Enforced, + usesTopicId: Boolean = false): Unit = { + val brokerMetadatas = getBrokerMetadatas(rackAwareMode) + val replicaAssignment = AdminUtils.assignReplicasToBrokers(brokerMetadatas, partitions, replicationFactor) + createTopicWithAssignment(topic, topicConfig, replicaAssignment, usesTopicId = usesTopicId) + } + + /** + * Gets broker metadata list + * + * @param rackAwareMode rack aware mode for replica assignment + * @param brokerList The brokers to gather metadata about. + * @return The metadata for each broker that was found. 
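// A hedged sketch of topic creation through AdminZkClient, assuming an already
// connected KafkaZkClient instance named zkClient.
val adminZkClient = new AdminZkClient(zkClient)
val topicConfig = new Properties()
topicConfig.put("cleanup.policy", "compact")
adminZkClient.createTopic("demo-topic", partitions = 3, replicationFactor = 1, topicConfig)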
+ */ + def getBrokerMetadatas(rackAwareMode: RackAwareMode = RackAwareMode.Enforced, + brokerList: Option[Seq[Int]] = None): Seq[BrokerMetadata] = { + val allBrokers = zkClient.getAllBrokersInCluster + val brokers = brokerList.map(brokerIds => allBrokers.filter(b => brokerIds.contains(b.id))).getOrElse(allBrokers) + val brokersWithRack = brokers.filter(_.rack.nonEmpty) + if (rackAwareMode == RackAwareMode.Enforced && brokersWithRack.nonEmpty && brokersWithRack.size < brokers.size) { + throw new AdminOperationException("Not all brokers have rack information. Add --disable-rack-aware in command line" + + " to make replica assignment without rack information.") + } + val brokerMetadatas = rackAwareMode match { + case RackAwareMode.Disabled => brokers.map(broker => BrokerMetadata(broker.id, None)) + case RackAwareMode.Safe if brokersWithRack.size < brokers.size => + brokers.map(broker => BrokerMetadata(broker.id, None)) + case _ => brokers.map(broker => BrokerMetadata(broker.id, broker.rack)) + } + brokerMetadatas.sortBy(_.id) + } + + /** + * Create topic and optionally validate its parameters. Note that this method is used by the + * TopicCommand as well. + * + * @param topic The name of the topic + * @param config The config of the topic + * @param partitionReplicaAssignment The assignments of the topic + * @param validate Boolean indicating if parameters must be validated or not (true by default) + * @param usesTopicId Boolean indicating whether the topic ID will be created + */ + def createTopicWithAssignment(topic: String, + config: Properties, + partitionReplicaAssignment: Map[Int, Seq[Int]], + validate: Boolean = true, + usesTopicId: Boolean = false): Unit = { + if (validate) + validateTopicCreate(topic, partitionReplicaAssignment, config) + + info(s"Creating topic $topic with configuration $config and initial partition " + + s"assignment $partitionReplicaAssignment") + + // write out the config if there is any, this isn't transactional with the partition assignments + zkClient.setOrCreateEntityConfigs(ConfigType.Topic, topic, config) + + // create the partition assignment + writeTopicPartitionAssignment(topic, partitionReplicaAssignment.map { case (k, v) => k -> ReplicaAssignment(v) }, + isUpdate = false, usesTopicId) + } + + /** + * Validate topic creation parameters. Note that this method is indirectly used by the + * TopicCommand via the `createTopicWithAssignment` method. 
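// A hedged sketch of manual assignment via createTopicWithAssignment, reusing the
// hypothetical adminZkClient above: partition 0 on brokers 1 and 2, partition 1 on 2 and 3.
adminZkClient.createTopicWithAssignment(
  "manually-assigned-topic",
  new Properties(),
  Map(0 -> Seq(1, 2), 1 -> Seq(2, 3))
)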
+ * + * @param topic The name of the topic + * @param partitionReplicaAssignment The assignments of the topic + * @param config The config of the topic + */ + def validateTopicCreate(topic: String, + partitionReplicaAssignment: Map[Int, Seq[Int]], + config: Properties): Unit = { + Topic.validate(topic) + if (zkClient.isTopicMarkedForDeletion(topic)) { + throw new TopicExistsException(s"Topic '$topic' is marked for deletion.") + } + if (zkClient.topicExists(topic)) + throw new TopicExistsException(s"Topic '$topic' already exists.") + else if (Topic.hasCollisionChars(topic)) { + val allTopics = zkClient.getAllTopicsInCluster() + // check again in case the topic was created in the meantime, otherwise the + // topic could potentially collide with itself + if (allTopics.contains(topic)) + throw new TopicExistsException(s"Topic '$topic' already exists.") + val collidingTopics = allTopics.filter(Topic.hasCollision(topic, _)) + if (collidingTopics.nonEmpty) { + throw new InvalidTopicException(s"Topic '$topic' collides with existing topics: ${collidingTopics.mkString(", ")}") + } + } + + if (partitionReplicaAssignment.values.map(_.size).toSet.size != 1) + throw new InvalidReplicaAssignmentException("All partitions should have the same number of replicas") + + partitionReplicaAssignment.values.foreach(reps => + if (reps.size != reps.toSet.size) + throw new InvalidReplicaAssignmentException("Duplicate replica assignment found: " + partitionReplicaAssignment) + ) + + val partitionSize = partitionReplicaAssignment.size + val sequenceSum = partitionSize * (partitionSize - 1) / 2 + if (partitionReplicaAssignment.size != partitionReplicaAssignment.toSet.size || + partitionReplicaAssignment.keys.filter(_ >= 0).sum != sequenceSum) + throw new InvalidReplicaAssignmentException("partitions should be a consecutive 0-based integer sequence") + + LogConfig.validate(config) + } + + private def writeTopicPartitionAssignment(topic: String, replicaAssignment: Map[Int, ReplicaAssignment], + isUpdate: Boolean, usesTopicId: Boolean = false): Unit = { + try { + val assignment = replicaAssignment.map { case (partitionId, replicas) => (new TopicPartition(topic,partitionId), replicas) }.toMap + + if (!isUpdate) { + val topicIdOpt = if (usesTopicId) Some(Uuid.randomUuid()) else None + zkClient.createTopicAssignment(topic, topicIdOpt, assignment.map { case (k, v) => k -> v.replicas }) + } else { + val topicIds = zkClient.getTopicIdsForTopics(Set(topic)) + zkClient.setTopicAssignment(topic, topicIds.get(topic), assignment) + } + debug("Updated path %s with %s for replica assignment".format(TopicZNode.path(topic), assignment)) + } catch { + case _: NodeExistsException => throw new TopicExistsException(s"Topic '$topic' already exists.") + case e2: Throwable => throw new AdminOperationException(e2.toString) + } + } + + /** + * Creates a delete path for a given topic + * @param topic Topic name to delete + */ + def deleteTopic(topic: String): Unit = { + if (zkClient.topicExists(topic)) { + try { + zkClient.createDeleteTopicPath(topic) + } catch { + case _: NodeExistsException => throw new TopicAlreadyMarkedForDeletionException( + "topic %s is already marked for deletion".format(topic)) + case e: Throwable => throw new AdminOperationException(e.getMessage) + } + } else { + throw new UnknownTopicOrPartitionException(s"Topic `$topic` to delete does not exist") + } + } + + /** + * Add partitions to existing topic with optional replica assignment. Note that this + * method is used by the TopicCommand. 
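// A hedged sketch of growing a topic from one to three partitions, assuming the current
// assignment has already been read from ZooKeeper into existingAssignment.
val existingAssignment: Map[Int, ReplicaAssignment] = Map(0 -> ReplicaAssignment(Seq(1, 2)))
adminZkClient.addPartitions(
  "single-partition-topic",
  existingAssignment,
  adminZkClient.getBrokerMetadatas(),
  numPartitions = 3
)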
+ * + * @param topic Topic for adding partitions to + * @param existingAssignment A map from partition id to its assignment + * @param allBrokers All brokers in the cluster + * @param numPartitions Number of partitions to be set + * @param replicaAssignment Manual replica assignment, or none + * @param validateOnly If true, validate the parameters without actually adding the partitions + * @return the updated replica assignment + */ + def addPartitions(topic: String, + existingAssignment: Map[Int, ReplicaAssignment], + allBrokers: Seq[BrokerMetadata], + numPartitions: Int = 1, + replicaAssignment: Option[Map[Int, Seq[Int]]] = None, + validateOnly: Boolean = false): Map[Int, Seq[Int]] = { + + val proposedAssignmentForNewPartitions = createNewPartitionsAssignment( + topic, + existingAssignment, + allBrokers, + numPartitions, + replicaAssignment + ) + + if (validateOnly) { + (existingAssignment ++ proposedAssignmentForNewPartitions) + .map { case (k, v) => k -> v.replicas } + } else { + createPartitionsWithAssignment(topic, existingAssignment, proposedAssignmentForNewPartitions) + .map { case (k, v) => k -> v.replicas } + } + } + + /** + * Create assignment to add the given number of partitions while validating the + * provided arguments. + * + * @param topic Topic for adding partitions to + * @param existingAssignment A map from partition id to its assignment + * @param allBrokers All brokers in the cluster + * @param numPartitions Number of partitions to be set + * @param replicaAssignment Manual replica assignment, or none + * @return the assignment for the new partitions + */ + def createNewPartitionsAssignment(topic: String, + existingAssignment: Map[Int, ReplicaAssignment], + allBrokers: Seq[BrokerMetadata], + numPartitions: Int = 1, + replicaAssignment: Option[Map[Int, Seq[Int]]] = None): Map[Int, ReplicaAssignment] = { + val existingAssignmentPartition0 = existingAssignment.getOrElse(0, + throw new AdminOperationException( + s"Unexpected existing replica assignment for topic '$topic', partition id 0 is missing. " + + s"Assignment: $existingAssignment")).replicas + + val partitionsToAdd = numPartitions - existingAssignment.size + if (partitionsToAdd <= 0) + throw new InvalidPartitionsException( + s"The number of partitions for a topic can only be increased. " + + s"Topic $topic currently has ${existingAssignment.size} partitions, " + + s"$numPartitions would not be an increase.") + + replicaAssignment.foreach { proposedReplicaAssignment => + validateReplicaAssignment(proposedReplicaAssignment, existingAssignmentPartition0.size, + allBrokers.map(_.id).toSet) + } + + val proposedAssignmentForNewPartitions = replicaAssignment.getOrElse { + val startIndex = math.max(0, allBrokers.indexWhere(_.id >= existingAssignmentPartition0.head)) + AdminUtils.assignReplicasToBrokers(allBrokers, partitionsToAdd, existingAssignmentPartition0.size, + startIndex, existingAssignment.size) + } + + proposedAssignmentForNewPartitions.map { case (tp, replicas) => + tp -> ReplicaAssignment(replicas, List(), List()) + } + } + + /** + * Add partitions to the existing topic with the provided assignment. This method does + * not validate the provided assignments. Validation must be done beforehand. 
+ * + * @param topic Topic for adding partitions to + * @param existingAssignment A map from partition id to its assignment + * @param newPartitionAssignment The assignments to add + * @return the updated replica assignment + */ + def createPartitionsWithAssignment(topic: String, + existingAssignment: Map[Int, ReplicaAssignment], + newPartitionAssignment: Map[Int, ReplicaAssignment]): Map[Int, ReplicaAssignment] = { + + info(s"Creating ${newPartitionAssignment.size} partitions for '$topic' with the following replica assignment: " + + s"$newPartitionAssignment.") + + val combinedAssignment = existingAssignment ++ newPartitionAssignment + + writeTopicPartitionAssignment(topic, combinedAssignment, isUpdate = true) + + combinedAssignment + } + + private def validateReplicaAssignment(replicaAssignment: Map[Int, Seq[Int]], + expectedReplicationFactor: Int, + availableBrokerIds: Set[Int]): Unit = { + + replicaAssignment.forKeyValue { (partitionId, replicas) => + if (replicas.isEmpty) + throw new InvalidReplicaAssignmentException( + s"Cannot have replication factor of 0 for partition id $partitionId.") + if (replicas.size != replicas.toSet.size) + throw new InvalidReplicaAssignmentException( + s"Duplicate brokers not allowed in replica assignment: " + + s"${replicas.mkString(", ")} for partition id $partitionId.") + if (!replicas.toSet.subsetOf(availableBrokerIds)) + throw new BrokerNotAvailableException( + s"Some brokers specified for partition id $partitionId are not available. " + + s"Specified brokers: ${replicas.mkString(", ")}, " + + s"available brokers: ${availableBrokerIds.mkString(", ")}.") + partitionId -> replicas.size + } + val badRepFactors = replicaAssignment.collect { + case (partition, replicas) if replicas.size != expectedReplicationFactor => partition -> replicas.size + } + if (badRepFactors.nonEmpty) { + val sortedBadRepFactors = badRepFactors.toSeq.sortBy { case (partitionId, _) => partitionId } + val partitions = sortedBadRepFactors.map { case (partitionId, _) => partitionId } + val repFactors = sortedBadRepFactors.map { case (_, rf) => rf } + throw new InvalidReplicaAssignmentException(s"Inconsistent replication factor between partitions, " + + s"partition 0 has $expectedReplicationFactor while partitions [${partitions.mkString(", ")}] have " + + s"replication factors [${repFactors.mkString(", ")}], respectively.") + } + } + + /** + * Parse broker from entity name to integer id + * @param broker The broker entity name to parse + * @return Integer brokerId after successfully parsed or default None + */ + def parseBroker(broker: String): Option[Int] = { + broker match { + case ConfigEntityName.Default => None + case _ => + try Some(broker.toInt) + catch { + case _: NumberFormatException => + throw new IllegalArgumentException(s"Error parsing broker $broker. 
The broker's Entity Name must be a single integer value") + } + } + } + + /** + * Change the configs for a given entityType and entityName + * @param entityType The entityType of the configs that will be changed + * @param entityName The entityName of the entityType + * @param configs The config of the entityName + */ + def changeConfigs(entityType: String, entityName: String, configs: Properties): Unit = { + + entityType match { + case ConfigType.Topic => changeTopicConfig(entityName, configs) + case ConfigType.Client => changeClientIdConfig(entityName, configs) + case ConfigType.User => changeUserOrUserClientIdConfig(entityName, configs) + case ConfigType.Broker => changeBrokerConfig(parseBroker(entityName), configs) + case ConfigType.Ip => changeIpConfig(entityName, configs) + case _ => throw new IllegalArgumentException(s"$entityType is not a known entityType. Should be one of ${ConfigType.all}") + } + } + + /** + * Update the config for a client and create a change notification so the change will propagate to other brokers. + * If clientId is , default clientId config is updated. ClientId configs are used only if + * and configs are not specified. + * + * @param sanitizedClientId: The sanitized clientId for which configs are being changed + * @param configs: The final set of configs that will be applied to the topic. If any new configs need to be added or + * existing configs need to be deleted, it should be done prior to invoking this API + * + */ + def changeClientIdConfig(sanitizedClientId: String, configs: Properties): Unit = { + DynamicConfig.Client.validate(configs) + changeEntityConfig(ConfigType.Client, sanitizedClientId, configs) + } + + /** + * Update the config for a or and create a change notification so the change will propagate to other brokers. + * User and/or clientId components of the path may be , indicating that the configuration is the default + * value to be applied if a more specific override is not configured. + * + * @param sanitizedEntityName: or /clients/ + * @param configs: The final set of configs that will be applied to the topic. If any new configs need to be added or + * existing configs need to be deleted, it should be done prior to invoking this API + * + */ + def changeUserOrUserClientIdConfig(sanitizedEntityName: String, configs: Properties): Unit = { + if (sanitizedEntityName == ConfigEntityName.Default || sanitizedEntityName.contains("/clients")) + DynamicConfig.Client.validate(configs) + else + DynamicConfig.User.validate(configs) + changeEntityConfig(ConfigType.User, sanitizedEntityName, configs) + } + + /** + * Validates the IP configs. + * @param ip ip for which configs are being validated + * @param configs properties to validate for the IP + */ + def validateIpConfig(ip: String, configs: Properties): Unit = { + if (!DynamicConfig.Ip.isValidIpEntity(ip)) + throw new AdminOperationException(s"$ip is not a valid IP or resolvable host.") + DynamicConfig.Ip.validate(configs) + } + + /** + * Update the config for an IP. These overrides will be persisted between sessions, and will override any default + * IP properties. 
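// A hedged sketch of applying a dynamic override through changeConfigs above; a
// topic-level "retention.ms" override is used purely as an example.
val overrides = new Properties()
overrides.put("retention.ms", "86400000")
adminZkClient.changeConfigs(ConfigType.Topic, "demo-topic", overrides)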
+ * @param ip ip for which configs are being updated + * @param configs properties to update for the IP + */ + def changeIpConfig(ip: String, configs: Properties): Unit = { + validateIpConfig(ip, configs) + changeEntityConfig(ConfigType.Ip, ip, configs) + } + + /** + * validates the topic configs + * @param topic topic for which configs are being validated + * @param configs properties to validate for the topic + */ + def validateTopicConfig(topic: String, configs: Properties): Unit = { + Topic.validate(topic) + if (!zkClient.topicExists(topic)) + throw new UnknownTopicOrPartitionException(s"Topic '$topic' does not exist.") + // remove the topic overrides + LogConfig.validate(configs) + } + + /** + * Update the config for an existing topic and create a change notification so the change will propagate to other brokers + * + * @param topic: The topic for which configs are being changed + * @param configs: The final set of configs that will be applied to the topic. If any new configs need to be added or + * existing configs need to be deleted, it should be done prior to invoking this API + * + */ + def changeTopicConfig(topic: String, configs: Properties): Unit = { + validateTopicConfig(topic, configs) + changeEntityConfig(ConfigType.Topic, topic, configs) + } + + /** + * Override the broker config on some set of brokers. These overrides will be persisted between sessions, and will + * override any defaults entered in the broker's config files + * + * @param brokers: The list of brokers to apply config changes to + * @param configs: The config to change, as properties + */ + def changeBrokerConfig(brokers: Seq[Int], configs: Properties): Unit = { + validateBrokerConfig(configs) + brokers.foreach { + broker => changeEntityConfig(ConfigType.Broker, broker.toString, configs) + } + } + + /** + * Override a broker override or broker default config. These overrides will be persisted between sessions, and will + * override any defaults entered in the broker's config files + * + * @param broker: The broker to apply config changes to or None to update dynamic default configs + * @param configs: The config to change, as properties + */ + def changeBrokerConfig(broker: Option[Int], configs: Properties): Unit = { + validateBrokerConfig(configs) + changeEntityConfig(ConfigType.Broker, broker.map(_.toString).getOrElse(ConfigEntityName.Default), configs) + } + + /** + * Validate dynamic broker configs. Since broker configs may contain custom configs, the validation + * only verifies that the provided config does not contain any static configs. + * @param configs configs to validate + */ + def validateBrokerConfig(configs: Properties): Unit = { + DynamicConfig.Broker.validate(configs) + } + + private def changeEntityConfig(rootEntityType: String, fullSanitizedEntityName: String, configs: Properties): Unit = { + val sanitizedEntityPath = rootEntityType + '/' + fullSanitizedEntityName + zkClient.setOrCreateEntityConfigs(rootEntityType, fullSanitizedEntityName, configs) + + // create the change notification + zkClient.createConfigChangeNotification(sanitizedEntityPath) + } + + /** + * Read the entity (topic, broker, client, user, or ) config (if any) from zk + * sanitizedEntityName is , , , , /clients/ or . 
+ * @param rootEntityType entityType for which configs are being fetched + * @param sanitizedEntityName entityName of the entityType + * @return The successfully gathered configs + */ + def fetchEntityConfig(rootEntityType: String, sanitizedEntityName: String): Properties = { + zkClient.getEntityConfigs(rootEntityType, sanitizedEntityName) + } + + /** + * Gets all topic configs + * @return The successfully gathered configs of all topics + */ + def getAllTopicConfigs(): Map[String, Properties] = + zkClient.getAllTopicsInCluster().map(topic => (topic, fetchEntityConfig(ConfigType.Topic, topic))).toMap + + /** + * Gets all the entity configs for a given entityType + * @param entityType entityType for which configs are being fetched + * @return The successfully gathered configs of the entityType + */ + def fetchAllEntityConfigs(entityType: String): Map[String, Properties] = + zkClient.getAllEntitiesWithConfig(entityType).map(entity => (entity, fetchEntityConfig(entityType, entity))).toMap + + /** + * Gets all the entity configs for a given childEntityType + * @param rootEntityType rootEntityType for which configs are being fetched + * @param childEntityType childEntityType of the rootEntityType + * @return The successfully gathered configs of the childEntityType + */ + def fetchAllChildEntityConfigs(rootEntityType: String, childEntityType: String): Map[String, Properties] = { + def entityPaths(rootPath: Option[String]): Seq[String] = { + val root = rootPath match { + case Some(path) => rootEntityType + '/' + path + case None => rootEntityType + } + val entityNames = zkClient.getAllEntitiesWithConfig(root) + rootPath match { + case Some(path) => entityNames.map(entityName => path + '/' + entityName) + case None => entityNames + } + } + entityPaths(None) + .flatMap(entity => entityPaths(Some(entity + '/' + childEntityType))) + .map(entityPath => (entityPath, fetchEntityConfig(rootEntityType, entityPath))).toMap + } + +} + diff --git a/core/src/main/scala/kafka/zk/KafkaZkClient.scala b/core/src/main/scala/kafka/zk/KafkaZkClient.scala new file mode 100644 index 0000000..823d6e8 --- /dev/null +++ b/core/src/main/scala/kafka/zk/KafkaZkClient.scala @@ -0,0 +1,2047 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+*/ +package kafka.zk + +import java.util.Properties +import com.yammer.metrics.core.MetricName +import kafka.api.LeaderAndIsr +import kafka.cluster.Broker +import kafka.controller.{KafkaController, LeaderIsrAndControllerEpoch, ReplicaAssignment} +import kafka.log.LogConfig +import kafka.metrics.KafkaMetricsGroup +import kafka.security.authorizer.AclAuthorizer.{NoAcls, VersionedAcls} +import kafka.security.authorizer.AclEntry +import kafka.server.ConfigType +import kafka.utils.Logging +import kafka.zk.TopicZNode.TopicIdReplicaAssignment +import kafka.zookeeper._ +import org.apache.kafka.common.errors.ControllerMovedException +import org.apache.kafka.common.resource.{PatternType, ResourcePattern, ResourceType} +import org.apache.kafka.common.security.token.delegation.{DelegationToken, TokenInformation} +import org.apache.kafka.common.utils.{Time, Utils} +import org.apache.kafka.common.{KafkaException, TopicPartition, Uuid} +import org.apache.zookeeper.KeeperException.{Code, NodeExistsException} +import org.apache.zookeeper.OpResult.{CreateResult, ErrorResult, SetDataResult} +import org.apache.zookeeper.client.ZKClientConfig +import org.apache.zookeeper.common.ZKConfig +import org.apache.zookeeper.data.{ACL, Stat} +import org.apache.zookeeper.{CreateMode, KeeperException, ZooKeeper} + +import scala.collection.{Map, Seq, mutable} + +/** + * Provides higher level Kafka-specific operations on top of the pipelined [[kafka.zookeeper.ZooKeeperClient]]. + * + * Implementation note: this class includes methods for various components (Controller, Configs, Old Consumer, etc.) + * and returns instances of classes from the calling packages in some cases. This is not ideal, but it made it + * easier to migrate away from `ZkUtils` (since removed). We should revisit this. We should also consider whether a + * monolithic [[kafka.zk.ZkData]] is the way to go. + */ +class KafkaZkClient private[zk] (zooKeeperClient: ZooKeeperClient, isSecure: Boolean, time: Time) extends AutoCloseable with + Logging with KafkaMetricsGroup { + + override def metricName(name: String, metricTags: scala.collection.Map[String, String]): MetricName = { + explicitMetricName("kafka.server", "ZooKeeperClientMetrics", name, metricTags) + } + + private val latencyMetric = newHistogram("ZooKeeperRequestLatencyMs") + + import KafkaZkClient._ + + // Only for testing + private[kafka] def currentZooKeeper: ZooKeeper = zooKeeperClient.currentZooKeeper + + // This variable holds the Zookeeper session id at the moment a Broker gets registered in Zookeeper and the subsequent + // updates of the session id. It is possible that the session id changes over the time for 'Session expired'. + // This code is part of the work around done in the KAFKA-7165, once ZOOKEEPER-2985 is complete, this code must + // be deleted. + private var currentZooKeeperSessionId: Long = -1 + + /** + * Create a sequential persistent path. That is, the znode will not be automatically deleted upon client's disconnect + * and a monotonically increasing number will be appended to its name. 
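+   *
+   * A hypothetical sketch of a call from within this package (the path and payload are made up;
+   * ZooKeeper appends the sequence number to the supplied path):
+   * {{{
+   * // may return e.g. "/my/notifications/notification-0000000007"
+   * val createdPath = createSequentialPersistentPath("/my/notifications/notification-", Array[Byte](1, 2, 3))
+   * }}}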
+ * + * @param path the path to create (with the monotonically increasing number appended) + * @param data the znode data + * @return the created path (including the appended monotonically increasing number) + */ + private[kafka] def createSequentialPersistentPath(path: String, data: Array[Byte]): String = { + val createRequest = CreateRequest(path, data, defaultAcls(path), CreateMode.PERSISTENT_SEQUENTIAL) + val createResponse = retryRequestUntilConnected(createRequest) + createResponse.maybeThrow() + createResponse.name + } + + /** + * Registers the broker in zookeeper and return the broker epoch. + * @param brokerInfo payload of the broker znode + * @return broker epoch (znode create transaction id) + */ + def registerBroker(brokerInfo: BrokerInfo): Long = { + val path = brokerInfo.path + val stat = checkedEphemeralCreate(path, brokerInfo.toJsonBytes) + info(s"Registered broker ${brokerInfo.broker.id} at path $path with addresses: " + + s"${brokerInfo.broker.endPoints.map(_.connectionString).mkString(",")}, czxid (broker epoch): ${stat.getCzxid}") + stat.getCzxid + } + + /** + * Registers a given broker in zookeeper as the controller and increments controller epoch. + * @param controllerId the id of the broker that is to be registered as the controller. + * @return the (updated controller epoch, epoch zkVersion) tuple + * @throws ControllerMovedException if fail to create /controller or fail to increment controller epoch. + */ + def registerControllerAndIncrementControllerEpoch(controllerId: Int): (Int, Int) = { + val timestamp = time.milliseconds() + + // Read /controller_epoch to get the current controller epoch and zkVersion, + // create /controller_epoch with initial value if not exists + val (curEpoch, curEpochZkVersion) = getControllerEpoch + .map(e => (e._1, e._2.getVersion)) + .getOrElse(maybeCreateControllerEpochZNode()) + + // Create /controller and update /controller_epoch atomically + val newControllerEpoch = curEpoch + 1 + val expectedControllerEpochZkVersion = curEpochZkVersion + + debug(s"Try to create ${ControllerZNode.path} and increment controller epoch to $newControllerEpoch with expected controller epoch zkVersion $expectedControllerEpochZkVersion") + + def checkControllerAndEpoch(): (Int, Int) = { + val curControllerId = getControllerId.getOrElse(throw new ControllerMovedException( + s"The ephemeral node at ${ControllerZNode.path} went away while checking whether the controller election succeeds. " + + s"Aborting controller startup procedure")) + if (controllerId == curControllerId) { + val (epoch, stat) = getControllerEpoch.getOrElse( + throw new IllegalStateException(s"${ControllerEpochZNode.path} existed before but goes away while trying to read it")) + + // If the epoch is the same as newControllerEpoch, it is safe to infer that the returned epoch zkVersion + // is associated with the current broker during controller election because we already knew that the zk + // transaction succeeds based on the controller znode verification. Other rounds of controller + // election will result in larger epoch number written in zk. + if (epoch == newControllerEpoch) + return (newControllerEpoch, stat.getVersion) + } + throw new ControllerMovedException("Controller moved to another broker. 
Aborting controller startup procedure") + } + + def tryCreateControllerZNodeAndIncrementEpoch(): (Int, Int) = { + val response = retryRequestUntilConnected( + MultiRequest(Seq( + CreateOp(ControllerZNode.path, ControllerZNode.encode(controllerId, timestamp), defaultAcls(ControllerZNode.path), CreateMode.EPHEMERAL), + SetDataOp(ControllerEpochZNode.path, ControllerEpochZNode.encode(newControllerEpoch), expectedControllerEpochZkVersion))) + ) + response.resultCode match { + case Code.NODEEXISTS | Code.BADVERSION => checkControllerAndEpoch() + case Code.OK => + val setDataResult = response.zkOpResults(1).rawOpResult.asInstanceOf[SetDataResult] + (newControllerEpoch, setDataResult.getStat.getVersion) + case code => throw KeeperException.create(code) + } + } + + tryCreateControllerZNodeAndIncrementEpoch() + } + + private def maybeCreateControllerEpochZNode(): (Int, Int) = { + createControllerEpochRaw(KafkaController.InitialControllerEpoch).resultCode match { + case Code.OK => + info(s"Successfully created ${ControllerEpochZNode.path} with initial epoch ${KafkaController.InitialControllerEpoch}") + (KafkaController.InitialControllerEpoch, KafkaController.InitialControllerEpochZkVersion) + case Code.NODEEXISTS => + val (epoch, stat) = getControllerEpoch.getOrElse(throw new IllegalStateException(s"${ControllerEpochZNode.path} existed before but goes away while trying to read it")) + (epoch, stat.getVersion) + case code => + throw KeeperException.create(code) + } + } + + def updateBrokerInfo(brokerInfo: BrokerInfo): Unit = { + val brokerIdPath = brokerInfo.path + val setDataRequest = SetDataRequest(brokerIdPath, brokerInfo.toJsonBytes, ZkVersion.MatchAnyVersion) + val response = retryRequestUntilConnected(setDataRequest) + response.maybeThrow() + info("Updated broker %d at path %s with addresses: %s".format(brokerInfo.broker.id, brokerIdPath, brokerInfo.broker.endPoints)) + } + + /** + * Gets topic partition states for the given partitions. + * @param partitions the partitions for which we want ot get states. + * @return sequence of GetDataResponses whose contexts are the partitions they are associated with. + */ + def getTopicPartitionStatesRaw(partitions: Seq[TopicPartition]): Seq[GetDataResponse] = { + val getDataRequests = partitions.map { partition => + GetDataRequest(TopicPartitionStateZNode.path(partition), ctx = Some(partition)) + } + retryRequestsUntilConnected(getDataRequests) + } + + /** + * Sets topic partition states for the given partitions. + * @param leaderIsrAndControllerEpochs the partition states of each partition whose state we wish to set. + * @param expectedControllerEpochZkVersion expected controller epoch zkVersion. + * @return sequence of SetDataResponse whose contexts are the partitions they are associated with. + */ + def setTopicPartitionStatesRaw(leaderIsrAndControllerEpochs: Map[TopicPartition, LeaderIsrAndControllerEpoch], expectedControllerEpochZkVersion: Int): Seq[SetDataResponse] = { + val setDataRequests = leaderIsrAndControllerEpochs.map { case (partition, leaderIsrAndControllerEpoch) => + val path = TopicPartitionStateZNode.path(partition) + val data = TopicPartitionStateZNode.encode(leaderIsrAndControllerEpoch) + SetDataRequest(path, data, leaderIsrAndControllerEpoch.leaderAndIsr.zkVersion, Some(partition)) + } + retryRequestsUntilConnected(setDataRequests.toSeq, expectedControllerEpochZkVersion) + } + + /** + * Creates topic partition state znodes for the given partitions. 
+ * @param leaderIsrAndControllerEpochs the partition states of each partition whose state we wish to set. + * @param expectedControllerEpochZkVersion expected controller epoch zkVersion. + * @return sequence of CreateResponse whose contexts are the partitions they are associated with. + */ + def createTopicPartitionStatesRaw(leaderIsrAndControllerEpochs: Map[TopicPartition, LeaderIsrAndControllerEpoch], expectedControllerEpochZkVersion: Int): Seq[CreateResponse] = { + createTopicPartitions(leaderIsrAndControllerEpochs.keys.map(_.topic).toSet.toSeq, expectedControllerEpochZkVersion) + createTopicPartition(leaderIsrAndControllerEpochs.keys.toSeq, expectedControllerEpochZkVersion) + val createRequests = leaderIsrAndControllerEpochs.map { case (partition, leaderIsrAndControllerEpoch) => + val path = TopicPartitionStateZNode.path(partition) + val data = TopicPartitionStateZNode.encode(leaderIsrAndControllerEpoch) + CreateRequest(path, data, defaultAcls(path), CreateMode.PERSISTENT, Some(partition)) + } + retryRequestsUntilConnected(createRequests.toSeq, expectedControllerEpochZkVersion) + } + + /** + * Sets the controller epoch conditioned on the given epochZkVersion. + * @param epoch the epoch to set + * @param epochZkVersion the expected version number of the epoch znode. + * @return SetDataResponse + */ + def setControllerEpochRaw(epoch: Int, epochZkVersion: Int): SetDataResponse = { + val setDataRequest = SetDataRequest(ControllerEpochZNode.path, ControllerEpochZNode.encode(epoch), epochZkVersion) + retryRequestUntilConnected(setDataRequest) + } + + /** + * Creates the controller epoch znode. + * @param epoch the epoch to set + * @return CreateResponse + */ + def createControllerEpochRaw(epoch: Int): CreateResponse = { + val createRequest = CreateRequest(ControllerEpochZNode.path, ControllerEpochZNode.encode(epoch), + defaultAcls(ControllerEpochZNode.path), CreateMode.PERSISTENT) + retryRequestUntilConnected(createRequest) + } + + /** + * Update the partition states of multiple partitions in zookeeper. + * @param leaderAndIsrs The partition states to update. + * @param controllerEpoch The current controller epoch. + * @param expectedControllerEpochZkVersion expected controller epoch zkVersion. + * @return UpdateLeaderAndIsrResult instance containing per partition results. 
+ */ + def updateLeaderAndIsr( + leaderAndIsrs: Map[TopicPartition, LeaderAndIsr], + controllerEpoch: Int, + expectedControllerEpochZkVersion: Int + ): UpdateLeaderAndIsrResult = { + val leaderIsrAndControllerEpochs = leaderAndIsrs.map { case (partition, leaderAndIsr) => + partition -> LeaderIsrAndControllerEpoch(leaderAndIsr, controllerEpoch) + } + val setDataResponses = try { + setTopicPartitionStatesRaw(leaderIsrAndControllerEpochs, expectedControllerEpochZkVersion) + } catch { + case e: ControllerMovedException => throw e + case e: Exception => + return UpdateLeaderAndIsrResult(leaderAndIsrs.keys.iterator.map(_ -> Left(e)).toMap, Seq.empty) + } + + val updatesToRetry = mutable.Buffer.empty[TopicPartition] + val finished = setDataResponses.iterator.flatMap { setDataResponse => + val partition = setDataResponse.ctx.get.asInstanceOf[TopicPartition] + setDataResponse.resultCode match { + case Code.OK => + val updatedLeaderAndIsr = leaderAndIsrs(partition).withZkVersion(setDataResponse.stat.getVersion) + Some(partition -> Right(updatedLeaderAndIsr)) + case Code.BADVERSION => + // Update the buffer for partitions to retry + updatesToRetry += partition + None + case _ => + Some(partition -> Left(setDataResponse.resultException.get)) + } + }.toMap + + UpdateLeaderAndIsrResult(finished, updatesToRetry) + } + + /** + * Get log configs that merge local configs with topic-level configs in zookeeper. + * @param topics The topics to get log configs for. + * @param config The local configs. + * @return A tuple of two values: + * 1. The successfully gathered log configs + * 2. Exceptions corresponding to failed log config lookups. + */ + def getLogConfigs( + topics: Set[String], + config: java.util.Map[String, AnyRef] + ): (Map[String, LogConfig], Map[String, Exception]) = { + val logConfigs = mutable.Map.empty[String, LogConfig] + val failed = mutable.Map.empty[String, Exception] + val configResponses = try { + getTopicConfigs(topics) + } catch { + case e: Exception => + topics.foreach(topic => failed.put(topic, e)) + return (logConfigs.toMap, failed.toMap) + } + configResponses.foreach { configResponse => + val topic = configResponse.ctx.get.asInstanceOf[String] + configResponse.resultCode match { + case Code.OK => + val overrides = ConfigEntityZNode.decode(configResponse.data) + val logConfig = LogConfig.fromProps(config, overrides) + logConfigs.put(topic, logConfig) + case Code.NONODE => + val logConfig = LogConfig.fromProps(config, new Properties) + logConfigs.put(topic, logConfig) + case _ => failed.put(topic, configResponse.resultException.get) + } + } + (logConfigs.toMap, failed.toMap) + } + + /** + * Get entity configs for a given entity name + * @param rootEntityType entity type + * @param sanitizedEntityName entity name + * @return The successfully gathered log configs + */ + def getEntityConfigs(rootEntityType: String, sanitizedEntityName: String): Properties = { + val getDataRequest = GetDataRequest(ConfigEntityZNode.path(rootEntityType, sanitizedEntityName)) + val getDataResponse = retryRequestUntilConnected(getDataRequest) + + getDataResponse.resultCode match { + case Code.OK => + ConfigEntityZNode.decode(getDataResponse.data) + case Code.NONODE => new Properties() + case _ => throw getDataResponse.resultException.get + } + } + + /** + * Sets or creates the entity znode path with the given configs depending + * on whether it already exists or not. + * + * If this is method is called concurrently, the last writer wins. 
In cases where we update configs and then + * partition assignment (i.e. create topic), it's possible for one thread to set this and the other to set the + * partition assignment. As such, the recommendation is to never call create topic for the same topic with different + * configs/partition assignment concurrently. + * + * @param rootEntityType entity type + * @param sanitizedEntityName entity name + * @throws KeeperException if there is an error while setting or creating the znode + */ + def setOrCreateEntityConfigs(rootEntityType: String, sanitizedEntityName: String, config: Properties) = { + + def set(configData: Array[Byte]): SetDataResponse = { + val setDataRequest = SetDataRequest(ConfigEntityZNode.path(rootEntityType, sanitizedEntityName), + configData, ZkVersion.MatchAnyVersion) + retryRequestUntilConnected(setDataRequest) + } + + def createOrSet(configData: Array[Byte]): Unit = { + val path = ConfigEntityZNode.path(rootEntityType, sanitizedEntityName) + try createRecursive(path, configData) + catch { + case _: NodeExistsException => set(configData).maybeThrow() + } + } + + val configData = ConfigEntityZNode.encode(config) + + val setDataResponse = set(configData) + setDataResponse.resultCode match { + case Code.NONODE => createOrSet(configData) + case _ => setDataResponse.maybeThrow() + } + } + + /** + * Returns all the entities for a given entityType + * @param entityType entity type + * @return List of all entity names + */ + def getAllEntitiesWithConfig(entityType: String): Seq[String] = { + getChildren(ConfigEntityTypeZNode.path(entityType)) + } + + /** + * Creates config change notification + * @param sanitizedEntityPath sanitizedEntityPath path to write + * @throws KeeperException if there is an error while setting or creating the znode + */ + def createConfigChangeNotification(sanitizedEntityPath: String): Unit = { + makeSurePersistentPathExists(ConfigEntityChangeNotificationZNode.path) + val path = ConfigEntityChangeNotificationSequenceZNode.createPath + val createRequest = CreateRequest(path, ConfigEntityChangeNotificationSequenceZNode.encode(sanitizedEntityPath), defaultAcls(path), CreateMode.PERSISTENT_SEQUENTIAL) + val createResponse = retryRequestUntilConnected(createRequest) + createResponse.maybeThrow() + } + + /** + * Gets all brokers in the cluster. + * @return sequence of brokers in the cluster. + */ + def getAllBrokersInCluster: Seq[Broker] = { + val brokerIds = getSortedBrokerList + val getDataRequests = brokerIds.map(brokerId => GetDataRequest(BrokerIdZNode.path(brokerId), ctx = Some(brokerId))) + val getDataResponses = retryRequestsUntilConnected(getDataRequests) + getDataResponses.flatMap { getDataResponse => + val brokerId = getDataResponse.ctx.get.asInstanceOf[Int] + getDataResponse.resultCode match { + case Code.OK => + Option(BrokerIdZNode.decode(brokerId, getDataResponse.data).broker) + case Code.NONODE => None + case _ => throw getDataResponse.resultException.get + } + } + } + + /** + * Gets all brokers with broker epoch in the cluster. + * @return map of broker to epoch in the cluster. 
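+   *
+   * A hypothetical sketch (the `zkClient` instance name is made up; it stands for a connected
+   * instance of this class):
+   * {{{
+   * zkClient.getAllBrokerAndEpochsInCluster.foreach { case (broker, epoch) =>
+   *   println(s"broker ${broker.id} registered with epoch $epoch")
+   * }
+   * }}}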
+ */ + def getAllBrokerAndEpochsInCluster: Map[Broker, Long] = { + val brokerIds = getSortedBrokerList + val getDataRequests = brokerIds.map(brokerId => GetDataRequest(BrokerIdZNode.path(brokerId), ctx = Some(brokerId))) + val getDataResponses = retryRequestsUntilConnected(getDataRequests) + getDataResponses.flatMap { getDataResponse => + val brokerId = getDataResponse.ctx.get.asInstanceOf[Int] + getDataResponse.resultCode match { + case Code.OK => + Some((BrokerIdZNode.decode(brokerId, getDataResponse.data).broker, getDataResponse.stat.getCzxid)) + case Code.NONODE => None + case _ => throw getDataResponse.resultException.get + } + }.toMap + } + + /** + * Get a broker from ZK + * @return an optional Broker + */ + def getBroker(brokerId: Int): Option[Broker] = { + val getDataRequest = GetDataRequest(BrokerIdZNode.path(brokerId)) + val getDataResponse = retryRequestUntilConnected(getDataRequest) + getDataResponse.resultCode match { + case Code.OK => + Option(BrokerIdZNode.decode(brokerId, getDataResponse.data).broker) + case Code.NONODE => None + case _ => throw getDataResponse.resultException.get + } + } + + /** + * Gets the list of sorted broker Ids + */ + def getSortedBrokerList: Seq[Int] = getChildren(BrokerIdsZNode.path).map(_.toInt).sorted + + /** + * Gets all topics in the cluster. + * @param registerWatch indicates if a watch must be registered or not + * @return sequence of topics in the cluster. + */ + def getAllTopicsInCluster(registerWatch: Boolean = false): Set[String] = { + val getChildrenResponse = retryRequestUntilConnected( + GetChildrenRequest(TopicsZNode.path, registerWatch)) + getChildrenResponse.resultCode match { + case Code.OK => getChildrenResponse.children.toSet + case Code.NONODE => Set.empty + case _ => throw getChildrenResponse.resultException.get + } + } + + /** + * Checks the topic existence + * @param topicName + * @return true if topic exists else false + */ + def topicExists(topicName: String): Boolean = { + pathExists(TopicZNode.path(topicName)) + } + + /** + * Adds a topic ID to existing topic and replica assignments + * @param topicIdReplicaAssignments the TopicIDReplicaAssignments to add a topic ID to + * @return the updated TopicIdReplicaAssigments including the newly created topic IDs + */ + def setTopicIds(topicIdReplicaAssignments: collection.Set[TopicIdReplicaAssignment], + expectedControllerEpochZkVersion: Int): Set[TopicIdReplicaAssignment] = { + val updatedAssignments = topicIdReplicaAssignments.map { + case TopicIdReplicaAssignment(topic, None, assignments) => + TopicIdReplicaAssignment(topic, Some(Uuid.randomUuid()), assignments) + case TopicIdReplicaAssignment(topic, Some(_), _) => + throw new IllegalArgumentException("TopicIdReplicaAssignment for " + topic + " already contains a topic ID.") + }.toSet + + val setDataRequests = updatedAssignments.map { case TopicIdReplicaAssignment(topic, topicIdOpt, assignments) => + SetDataRequest(TopicZNode.path(topic), TopicZNode.encode(topicIdOpt, assignments), ZkVersion.MatchAnyVersion) + }.toSeq + + retryRequestsUntilConnected(setDataRequests, expectedControllerEpochZkVersion) + updatedAssignments + } + + /** + * Sets the topic znode with the given assignment. + * @param topic the topic whose assignment is being set. + * @param topicId unique topic ID for the topic if the version supports it + * @param assignment the partition to replica mapping to set for the given topic + * @param expectedControllerEpochZkVersion expected controller epoch zkVersion. 
+ * @return SetDataResponse + */ + def setTopicAssignmentRaw(topic: String, + topicId: Option[Uuid], + assignment: collection.Map[TopicPartition, ReplicaAssignment], + expectedControllerEpochZkVersion: Int): SetDataResponse = { + val setDataRequest = SetDataRequest(TopicZNode.path(topic), TopicZNode.encode(topicId, assignment), ZkVersion.MatchAnyVersion) + retryRequestUntilConnected(setDataRequest, expectedControllerEpochZkVersion) + } + + /** + * Sets the topic znode with the given assignment. + * @param topic the topic whose assignment is being set. + * @param topicId unique topic ID for the topic if the version supports it + * @param assignment the partition to replica mapping to set for the given topic + * @param expectedControllerEpochZkVersion expected controller epoch zkVersion. + * @throws KeeperException if there is an error while setting assignment + */ + def setTopicAssignment(topic: String, + topicId: Option[Uuid], + assignment: Map[TopicPartition, ReplicaAssignment], + expectedControllerEpochZkVersion: Int = ZkVersion.MatchAnyVersion) = { + val setDataResponse = setTopicAssignmentRaw(topic, topicId, assignment, expectedControllerEpochZkVersion) + setDataResponse.maybeThrow() + } + + /** + * Create the topic znode with the given assignment. + * @param topic the topic whose assignment is being set. + * @param topicId unique topic ID for the topic if the version supports it + * @param assignment the partition to replica mapping to set for the given topic + * @throws KeeperException if there is an error while creating assignment + */ + def createTopicAssignment(topic: String, topicId: Option[Uuid], assignment: Map[TopicPartition, Seq[Int]]): Unit = { + val persistedAssignments = assignment.map { case (k, v) => k -> ReplicaAssignment(v) } + createRecursive(TopicZNode.path(topic), TopicZNode.encode(topicId, persistedAssignments)) + } + + /** + * Gets the log dir event notifications as strings. These strings are the znode names and not the absolute znode path. + * @return sequence of znode names and not the absolute znode path. + */ + def getAllLogDirEventNotifications: Seq[String] = { + val getChildrenResponse = retryRequestUntilConnected(GetChildrenRequest(LogDirEventNotificationZNode.path, registerWatch = true)) + getChildrenResponse.resultCode match { + case Code.OK => getChildrenResponse.children.map(LogDirEventNotificationSequenceZNode.sequenceNumber) + case Code.NONODE => Seq.empty + case _ => throw getChildrenResponse.resultException.get + } + } + + /** + * Reads each of the log dir event notifications associated with the given sequence numbers and extracts the broker ids. + * @param sequenceNumbers the sequence numbers associated with the log dir event notifications. + * @return broker ids associated with the given log dir event notifications. + */ + def getBrokerIdsFromLogDirEvents(sequenceNumbers: Seq[String]): Seq[Int] = { + val getDataRequests = sequenceNumbers.map { sequenceNumber => + GetDataRequest(LogDirEventNotificationSequenceZNode.path(sequenceNumber)) + } + val getDataResponses = retryRequestsUntilConnected(getDataRequests) + getDataResponses.flatMap { getDataResponse => + getDataResponse.resultCode match { + case Code.OK => LogDirEventNotificationSequenceZNode.decode(getDataResponse.data) + case Code.NONODE => None + case _ => throw getDataResponse.resultException.get + } + } + } + + /** + * Deletes all log dir event notifications. + * @param expectedControllerEpochZkVersion expected controller epoch zkVersion. 
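+   *
+   * A hypothetical sketch of the full notification round trip (the `zkClient` instance name and the
+   * broker id are made up):
+   * {{{
+   * zkClient.propagateLogDirEvent(brokerId = 1)
+   * val sequenceNumbers = zkClient.getAllLogDirEventNotifications
+   * val brokerIds = zkClient.getBrokerIdsFromLogDirEvents(sequenceNumbers)
+   * zkClient.deleteLogDirEventNotifications(sequenceNumbers, ZkVersion.MatchAnyVersion)
+   * }}}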
+ */ + def deleteLogDirEventNotifications(expectedControllerEpochZkVersion: Int): Unit = { + val getChildrenResponse = retryRequestUntilConnected(GetChildrenRequest(LogDirEventNotificationZNode.path, registerWatch = true)) + if (getChildrenResponse.resultCode == Code.OK) { + deleteLogDirEventNotifications(getChildrenResponse.children.map(LogDirEventNotificationSequenceZNode.sequenceNumber), expectedControllerEpochZkVersion) + } else if (getChildrenResponse.resultCode != Code.NONODE) { + getChildrenResponse.maybeThrow() + } + } + + /** + * Deletes the log dir event notifications associated with the given sequence numbers. + * @param sequenceNumbers the sequence numbers associated with the log dir event notifications to be deleted. + * @param expectedControllerEpochZkVersion expected controller epoch zkVersion. + */ + def deleteLogDirEventNotifications(sequenceNumbers: Seq[String], expectedControllerEpochZkVersion: Int): Unit = { + val deleteRequests = sequenceNumbers.map { sequenceNumber => + DeleteRequest(LogDirEventNotificationSequenceZNode.path(sequenceNumber), ZkVersion.MatchAnyVersion) + } + retryRequestsUntilConnected(deleteRequests, expectedControllerEpochZkVersion) + } + + /** + * Gets the topic IDs for the given topics. + * @param topics the topics we wish to retrieve the Topic IDs for + * @return the Topic IDs + */ + def getTopicIdsForTopics(topics: Set[String]): Map[String, Uuid] = { + val getDataRequests = topics.map(topic => GetDataRequest(TopicZNode.path(topic), ctx = Some(topic))) + val getDataResponses = retryRequestsUntilConnected(getDataRequests.toSeq) + getDataResponses.map { getDataResponse => + val topic = getDataResponse.ctx.get.asInstanceOf[String] + getDataResponse.resultCode match { + case Code.OK => Some(TopicZNode.decode(topic, getDataResponse.data)) + case Code.NONODE => None + case _ => throw getDataResponse.resultException.get + } + }.filter(_.flatMap(_.topicId).isDefined) + .map(_.get) + .map(topicIdAssignment => (topicIdAssignment.topic, topicIdAssignment.topicId.get)) + .toMap + } + + /** + * Gets the replica assignments for the given topics. + * This function does not return information about which replicas are being added or removed from the assignment. + * @param topics the topics whose partitions we wish to get the assignments for. + * @return the replica assignment for each partition from the given topics. + */ + def getReplicaAssignmentForTopics(topics: Set[String]): Map[TopicPartition, Seq[Int]] = { + getFullReplicaAssignmentForTopics(topics).map { case (k, v) => k -> v.replicas } + } + + /** + * Gets the TopicID and replica assignments for the given topics. + * @param topics the topics whose partitions we wish to get the assignments for. + * @return the TopicIdReplicaAssignment for each partition for the given topics. + */ + def getReplicaAssignmentAndTopicIdForTopics(topics: Set[String]): Set[TopicIdReplicaAssignment] = { + val getDataRequests = topics.map(topic => GetDataRequest(TopicZNode.path(topic), ctx = Some(topic))) + val getDataResponses = retryRequestsUntilConnected(getDataRequests.toSeq) + getDataResponses.map { getDataResponse => + val topic = getDataResponse.ctx.get.asInstanceOf[String] + getDataResponse.resultCode match { + case Code.OK => TopicZNode.decode(topic, getDataResponse.data) + case Code.NONODE => TopicIdReplicaAssignment(topic, None, Map.empty[TopicPartition, ReplicaAssignment]) + case _ => throw getDataResponse.resultException.get + } + }.toSet + } + + /** + * Gets the replica assignments for the given topics. 
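+   *
+   * A hypothetical sketch (the `zkClient` instance name and the topic are made up):
+   * {{{
+   * val assignment = zkClient.getFullReplicaAssignmentForTopics(Set("my-topic"))
+   * assignment.foreach { case (topicPartition, replicaAssignment) =>
+   *   println(s"$topicPartition -> ${replicaAssignment.replicas.mkString(",")}")
+   * }
+   * }}}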
+   * @param topics the topics whose partitions we wish to get the assignments for.
+   * @return the full replica assignment for each partition from the given topics.
+   */
+  def getFullReplicaAssignmentForTopics(topics: Set[String]): Map[TopicPartition, ReplicaAssignment] = {
+    val getDataRequests = topics.map(topic => GetDataRequest(TopicZNode.path(topic), ctx = Some(topic)))
+    val getDataResponses = retryRequestsUntilConnected(getDataRequests.toSeq)
+    getDataResponses.flatMap { getDataResponse =>
+      val topic = getDataResponse.ctx.get.asInstanceOf[String]
+      getDataResponse.resultCode match {
+        case Code.OK => TopicZNode.decode(topic, getDataResponse.data).assignment
+        case Code.NONODE => Map.empty[TopicPartition, ReplicaAssignment]
+        case _ => throw getDataResponse.resultException.get
+      }
+    }.toMap
+  }
+
+  /**
+   * Gets the partition assignments for the given topics.
+   * @param topics the topics whose partitions we wish to get the assignments for.
+   * @return the partition assignment for each partition from the given topics.
+   */
+  def getPartitionAssignmentForTopics(topics: Set[String]): Map[String, Map[Int, ReplicaAssignment]] = {
+    val getDataRequests = topics.map(topic => GetDataRequest(TopicZNode.path(topic), ctx = Some(topic)))
+    val getDataResponses = retryRequestsUntilConnected(getDataRequests.toSeq)
+    getDataResponses.flatMap { getDataResponse =>
+      val topic = getDataResponse.ctx.get.asInstanceOf[String]
+      if (getDataResponse.resultCode == Code.OK) {
+        val partitionMap = TopicZNode.decode(topic, getDataResponse.data).assignment.map { case (k, v) => (k.partition, v) }
+        Map(topic -> partitionMap)
+      } else if (getDataResponse.resultCode == Code.NONODE) {
+        Map.empty[String, Map[Int, ReplicaAssignment]]
+      } else {
+        throw getDataResponse.resultException.get
+      }
+    }.toMap
+  }
+
+  /**
+   * Gets the partition numbers for the given topics.
+   * @param topics the topics whose partitions we wish to get.
+   * @return the sorted partition numbers for each topic from the given topics.
+   */
+  def getPartitionsForTopics(topics: Set[String]): Map[String, Seq[Int]] = {
+    getPartitionAssignmentForTopics(topics).map { topicAndPartitionMap =>
+      val topic = topicAndPartitionMap._1
+      val partitionMap = topicAndPartitionMap._2
+      topic -> partitionMap.keys.toSeq.sortWith((s, t) => s < t)
+    }
+  }
+
+  /**
+   * Gets the partition count for a given topic.
+   * @param topic The topic to get partition count for.
+   * @return optional integer that is Some if the topic exists and None otherwise.
+   */
+  def getTopicPartitionCount(topic: String): Option[Int] = {
+    val topicData = getReplicaAssignmentForTopics(Set(topic))
+    if (topicData.nonEmpty)
+      Some(topicData.size)
+    else
+      None
+  }
+
+  /**
+   * Gets the assigned replicas for a specific topic and partition
+   * @param topicPartition the TopicPartition to get assigned replicas for.
+   * @return List of assigned replicas
+   */
+  def getReplicasForPartition(topicPartition: TopicPartition): Seq[Int] = {
+    val topicData = getReplicaAssignmentForTopics(Set(topicPartition.topic))
+    topicData.getOrElse(topicPartition, Seq.empty)
+  }
+
+  /**
+   * Gets all partitions in the cluster
+   * @return all partitions in the cluster
+   */
+  def getAllPartitions: Set[TopicPartition] = {
+    val topics = getChildren(TopicsZNode.path)
+    if (topics == null) Set.empty
+    else {
+      topics.flatMap { topic =>
+        // The partitions path may not exist if the topic is in the process of being deleted
+        getChildren(TopicPartitionsZNode.path(topic)).map(_.toInt).map(new TopicPartition(topic, _))
+      }.toSet
+    }
+  }
+
+  /**
+   * Gets the data and version at the given zk path
+   * @param path zk node path
+   * @return A tuple of 2 elements, where the first element is the zk node data as an array of bytes
+   *         and the second element is the zk node version.
+   *         Returns (None, ZkVersion.UnknownVersion) if the node doesn't exist and throws an exception for any other error.
+   */
+  def getDataAndVersion(path: String): (Option[Array[Byte]], Int) = {
+    val (data, stat) = getDataAndStat(path)
+    stat match {
+      case ZkStat.NoStat => (data, ZkVersion.UnknownVersion)
+      case _ => (data, stat.getVersion)
+    }
+  }
+
+  /**
+   * Gets the data and Stat at the given zk path
+   * @param path zk node path
+   * @return A tuple of 2 elements, where the first element is the zk node data as an array of bytes
+   *         and the second element is the zk node Stat.
+   *         Returns (None, ZkStat.NoStat) if the node doesn't exist and throws an exception for any other error.
+   */
+  def getDataAndStat(path: String): (Option[Array[Byte]], Stat) = {
+    val getDataRequest = GetDataRequest(path)
+    val getDataResponse = retryRequestUntilConnected(getDataRequest)
+
+    getDataResponse.resultCode match {
+      case Code.OK => (Option(getDataResponse.data), getDataResponse.stat)
+      case Code.NONODE => (None, ZkStat.NoStat)
+      case _ => throw getDataResponse.resultException.get
+    }
+  }
+
+  /**
+   * Gets all the child nodes at a given zk node path
+   * @param path zk node path
+   * @return list of child node names
+   */
+  def getChildren(path: String): Seq[String] = {
+    val getChildrenResponse = retryRequestUntilConnected(GetChildrenRequest(path, registerWatch = true))
+    getChildrenResponse.resultCode match {
+      case Code.OK => getChildrenResponse.children
+      case Code.NONODE => Seq.empty
+      case _ => throw getChildrenResponse.resultException.get
+    }
+  }
+
+  /**
+   * Conditionally update the persistent path data. Returns (true, newVersion) if the update succeeds, otherwise
+   * (e.g. the path doesn't exist or the current version is not the expected version) returns (false, ZkVersion.UnknownVersion).
+   *
+   * When there is a ConnectionLossException during the conditional update, ZooKeeperClient will retry the update and may fail
+   * since the previous update may have succeeded (but the stored zkVersion no longer matches the expected one).
+   * In this case, we will run the optionalChecker to further check whether the previous write did indeed succeed.
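+   *
+   * A hypothetical read-modify-write sketch (the `zkClient` instance name, the path and the payload
+   * are made up):
+   * {{{
+   * val (_, version) = zkClient.getDataAndVersion("/my/path")
+   * val newData = "new-value".getBytes(java.nio.charset.StandardCharsets.UTF_8)
+   * val (succeeded, newVersion) = zkClient.conditionalUpdatePath("/my/path", newData, version)
+   * }}}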
+ */ + def conditionalUpdatePath(path: String, data: Array[Byte], expectVersion: Int, + optionalChecker: Option[(KafkaZkClient, String, Array[Byte]) => (Boolean,Int)] = None): (Boolean, Int) = { + + val setDataRequest = SetDataRequest(path, data, expectVersion) + val setDataResponse = retryRequestUntilConnected(setDataRequest) + + setDataResponse.resultCode match { + case Code.OK => + debug("Conditional update of path %s with value %s and expected version %d succeeded, returning the new version: %d" + .format(path, Utils.utf8(data), expectVersion, setDataResponse.stat.getVersion)) + (true, setDataResponse.stat.getVersion) + + case Code.BADVERSION => + optionalChecker match { + case Some(checker) => checker(this, path, data) + case _ => + debug("Checker method is not passed skipping zkData match") + debug("Conditional update of path %s with data %s and expected version %d failed due to %s" + .format(path, Utils.utf8(data), expectVersion, setDataResponse.resultException.get.getMessage)) + (false, ZkVersion.UnknownVersion) + } + + case Code.NONODE => + debug("Conditional update of path %s with data %s and expected version %d failed due to %s".format(path, + Utils.utf8(data), expectVersion, setDataResponse.resultException.get.getMessage)) + (false, ZkVersion.UnknownVersion) + + case _ => + debug("Conditional update of path %s with data %s and expected version %d failed due to %s".format(path, + Utils.utf8(data), expectVersion, setDataResponse.resultException.get.getMessage)) + throw setDataResponse.resultException.get + } + } + + /** + * Creates the delete topic znode. + * @param topicName topic name + * @throws KeeperException if there is an error while setting or creating the znode + */ + def createDeleteTopicPath(topicName: String): Unit = { + createRecursive(DeleteTopicsTopicZNode.path(topicName)) + } + + /** + * Checks if topic is marked for deletion + * @param topic + * @return true if topic is marked for deletion, else false + */ + def isTopicMarkedForDeletion(topic: String): Boolean = { + pathExists(DeleteTopicsTopicZNode.path(topic)) + } + + /** + * Get all topics marked for deletion. + * @return sequence of topics marked for deletion. + */ + def getTopicDeletions: Seq[String] = { + val getChildrenResponse = retryRequestUntilConnected(GetChildrenRequest(DeleteTopicsZNode.path, registerWatch = true)) + getChildrenResponse.resultCode match { + case Code.OK => getChildrenResponse.children + case Code.NONODE => Seq.empty + case _ => throw getChildrenResponse.resultException.get + } + } + + /** + * Remove the given topics from the topics marked for deletion. + * @param topics the topics to remove. + * @param expectedControllerEpochZkVersion expected controller epoch zkVersion. + */ + def deleteTopicDeletions(topics: Seq[String], expectedControllerEpochZkVersion: Int): Unit = { + val deleteRequests = topics.map(topic => DeleteRequest(DeleteTopicsTopicZNode.path(topic), ZkVersion.MatchAnyVersion)) + retryRequestsUntilConnected(deleteRequests, expectedControllerEpochZkVersion) + } + + /** + * Returns all reassignments. + * @return the reassignments for each partition. 
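+   *
+   * A hypothetical sketch (the `zkClient` instance name, topic, partition and broker ids are made up):
+   * {{{
+   * zkClient.createPartitionReassignment(Map(new TopicPartition("my-topic", 0) -> Seq(1, 2, 3)))
+   * val inProgress = zkClient.reassignPartitionsInProgress   // true until the controller deletes the znode
+   * val current = zkClient.getPartitionReassignment
+   * }}}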
+ */ + def getPartitionReassignment: collection.Map[TopicPartition, Seq[Int]] = { + val getDataRequest = GetDataRequest(ReassignPartitionsZNode.path) + val getDataResponse = retryRequestUntilConnected(getDataRequest) + getDataResponse.resultCode match { + case Code.OK => + ReassignPartitionsZNode.decode(getDataResponse.data) match { + case Left(e) => + logger.warn(s"Ignoring partition reassignment due to invalid json: ${e.getMessage}", e) + Map.empty[TopicPartition, Seq[Int]] + case Right(assignments) => assignments + } + case Code.NONODE => Map.empty + case _ => throw getDataResponse.resultException.get + } + } + + /** + * Sets or creates the partition reassignment znode with the given reassignment depending on whether it already + * exists or not. + * + * @param reassignment the reassignment to set on the reassignment znode + * @param expectedControllerEpochZkVersion expected controller epoch zkVersion. + * @throws KeeperException if there is an error while setting or creating the znode + * @deprecated Use the PartitionReassignment Kafka API instead + */ + @Deprecated + def setOrCreatePartitionReassignment(reassignment: collection.Map[TopicPartition, Seq[Int]], expectedControllerEpochZkVersion: Int): Unit = { + + def set(reassignmentData: Array[Byte]): SetDataResponse = { + val setDataRequest = SetDataRequest(ReassignPartitionsZNode.path, reassignmentData, ZkVersion.MatchAnyVersion) + retryRequestUntilConnected(setDataRequest, expectedControllerEpochZkVersion) + } + + def create(reassignmentData: Array[Byte]): CreateResponse = { + val createRequest = CreateRequest(ReassignPartitionsZNode.path, reassignmentData, defaultAcls(ReassignPartitionsZNode.path), + CreateMode.PERSISTENT) + retryRequestUntilConnected(createRequest, expectedControllerEpochZkVersion) + } + + val reassignmentData = ReassignPartitionsZNode.encode(reassignment) + val setDataResponse = set(reassignmentData) + setDataResponse.resultCode match { + case Code.NONODE => + val createDataResponse = create(reassignmentData) + createDataResponse.maybeThrow() + case _ => setDataResponse.maybeThrow() + } + } + + /** + * Creates the partition reassignment znode with the given reassignment. + * @param reassignment the reassignment to set on the reassignment znode. + * @throws KeeperException if there is an error while creating the znode. + */ + def createPartitionReassignment(reassignment: Map[TopicPartition, Seq[Int]]) = { + createRecursive(ReassignPartitionsZNode.path, ReassignPartitionsZNode.encode(reassignment)) + } + + /** + * Deletes the partition reassignment znode. + * @param expectedControllerEpochZkVersion expected controller epoch zkVersion. + */ + def deletePartitionReassignment(expectedControllerEpochZkVersion: Int): Unit = { + deletePath(ReassignPartitionsZNode.path, expectedControllerEpochZkVersion) + } + + /** + * Checks if reassign partitions is in progress. + * @return true if reassign partitions is in progress, else false. + */ + def reassignPartitionsInProgress: Boolean = { + pathExists(ReassignPartitionsZNode.path) + } + + /** + * Gets topic partition states for the given partitions. + * @param partitions the partitions for which we want to get states. + * @return map containing LeaderIsrAndControllerEpoch of each partition for we were able to lookup the partition state. 
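+   *
+   * A hypothetical sketch (the `zkClient` instance name, topic and partition are made up):
+   * {{{
+   * val partition = new TopicPartition("my-topic", 0)
+   * val leader = zkClient.getLeaderForPartition(partition)        // Option[Int]
+   * val isr = zkClient.getInSyncReplicasForPartition(partition)   // Option[Seq[Int]]
+   * val leaderEpoch = zkClient.getEpochForPartition(partition)    // Option[Int]
+   * }}}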
+ */ + def getTopicPartitionStates(partitions: Seq[TopicPartition]): Map[TopicPartition, LeaderIsrAndControllerEpoch] = { + val getDataResponses = getTopicPartitionStatesRaw(partitions) + getDataResponses.flatMap { getDataResponse => + val partition = getDataResponse.ctx.get.asInstanceOf[TopicPartition] + getDataResponse.resultCode match { + case Code.OK => TopicPartitionStateZNode.decode(getDataResponse.data, getDataResponse.stat).map(partition -> _) + case Code.NONODE => None + case _ => throw getDataResponse.resultException.get + } + }.toMap + } + + /** + * Gets topic partition state for the given partition. + * @param partition the partition for which we want to get state. + * @return LeaderIsrAndControllerEpoch of the partition state if exists, else None + */ + def getTopicPartitionState(partition: TopicPartition): Option[LeaderIsrAndControllerEpoch] = { + val getDataResponse = getTopicPartitionStatesRaw(Seq(partition)).head + if (getDataResponse.resultCode == Code.OK) { + TopicPartitionStateZNode.decode(getDataResponse.data, getDataResponse.stat) + } else if (getDataResponse.resultCode == Code.NONODE) { + None + } else { + throw getDataResponse.resultException.get + } + } + + /** + * Gets the leader for a given partition + * @param partition The partition for which we want to get leader. + * @return optional integer if the leader exists and None otherwise. + */ + def getLeaderForPartition(partition: TopicPartition): Option[Int] = + getTopicPartitionState(partition).map(_.leaderAndIsr.leader) + + /** + * Gets the in-sync replicas (ISR) for a specific topicPartition + * @param partition The partition for which we want to get ISR. + * @return optional ISR if exists and None otherwise + */ + def getInSyncReplicasForPartition(partition: TopicPartition): Option[Seq[Int]] = + getTopicPartitionState(partition).map(_.leaderAndIsr.isr) + + + /** + * Gets the leader epoch for a specific topicPartition + * @param partition The partition for which we want to get the leader epoch + * @return optional integer if the leader exists and None otherwise + */ + def getEpochForPartition(partition: TopicPartition): Option[Int] = { + getTopicPartitionState(partition).map(_.leaderAndIsr.leaderEpoch) + } + + /** + * Gets the isr change notifications as strings. These strings are the znode names and not the absolute znode path. + * @return sequence of znode names and not the absolute znode path. + */ + def getAllIsrChangeNotifications: Seq[String] = { + val getChildrenResponse = retryRequestUntilConnected(GetChildrenRequest(IsrChangeNotificationZNode.path, registerWatch = true)) + getChildrenResponse.resultCode match { + case Code.OK => getChildrenResponse.children.map(IsrChangeNotificationSequenceZNode.sequenceNumber) + case Code.NONODE => Seq.empty + case _ => throw getChildrenResponse.resultException.get + } + } + + /** + * Reads each of the isr change notifications associated with the given sequence numbers and extracts the partitions. + * @param sequenceNumbers the sequence numbers associated with the isr change notifications. + * @return partitions associated with the given isr change notifications. 
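+   *
+   * A hypothetical sketch of the notification round trip (the `zkClient` instance name, topic and
+   * partition are made up):
+   * {{{
+   * zkClient.propagateIsrChanges(Set(new TopicPartition("my-topic", 0)))
+   * val sequenceNumbers = zkClient.getAllIsrChangeNotifications
+   * val partitions = zkClient.getPartitionsFromIsrChangeNotifications(sequenceNumbers)
+   * zkClient.deleteIsrChangeNotifications(sequenceNumbers, ZkVersion.MatchAnyVersion)
+   * }}}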
+ */ + def getPartitionsFromIsrChangeNotifications(sequenceNumbers: Seq[String]): Seq[TopicPartition] = { + val getDataRequests = sequenceNumbers.map { sequenceNumber => + GetDataRequest(IsrChangeNotificationSequenceZNode.path(sequenceNumber)) + } + val getDataResponses = retryRequestsUntilConnected(getDataRequests) + getDataResponses.flatMap { getDataResponse => + getDataResponse.resultCode match { + case Code.OK => IsrChangeNotificationSequenceZNode.decode(getDataResponse.data) + case Code.NONODE => None + case _ => throw getDataResponse.resultException.get + } + } + } + + /** + * Deletes all isr change notifications. + * @param expectedControllerEpochZkVersion expected controller epoch zkVersion. + */ + def deleteIsrChangeNotifications(expectedControllerEpochZkVersion: Int): Unit = { + val getChildrenResponse = retryRequestUntilConnected(GetChildrenRequest(IsrChangeNotificationZNode.path, registerWatch = true)) + if (getChildrenResponse.resultCode == Code.OK) { + deleteIsrChangeNotifications(getChildrenResponse.children.map(IsrChangeNotificationSequenceZNode.sequenceNumber), expectedControllerEpochZkVersion) + } else if (getChildrenResponse.resultCode != Code.NONODE) { + getChildrenResponse.maybeThrow() + } + } + + /** + * Deletes the isr change notifications associated with the given sequence numbers. + * @param sequenceNumbers the sequence numbers associated with the isr change notifications to be deleted. + * @param expectedControllerEpochZkVersion expected controller epoch zkVersion. + */ + def deleteIsrChangeNotifications(sequenceNumbers: Seq[String], expectedControllerEpochZkVersion: Int): Unit = { + val deleteRequests = sequenceNumbers.map { sequenceNumber => + DeleteRequest(IsrChangeNotificationSequenceZNode.path(sequenceNumber), ZkVersion.MatchAnyVersion) + } + retryRequestsUntilConnected(deleteRequests, expectedControllerEpochZkVersion) + } + + /** + * Creates preferred replica election znode with partitions undergoing election + * @param partitions + * @throws KeeperException if there is an error while creating the znode + */ + def createPreferredReplicaElection(partitions: Set[TopicPartition]): Unit = { + createRecursive(PreferredReplicaElectionZNode.path, PreferredReplicaElectionZNode.encode(partitions)) + } + + /** + * Gets the partitions marked for preferred replica election. + * @return sequence of partitions. + */ + def getPreferredReplicaElection: Set[TopicPartition] = { + val getDataRequest = GetDataRequest(PreferredReplicaElectionZNode.path) + val getDataResponse = retryRequestUntilConnected(getDataRequest) + getDataResponse.resultCode match { + case Code.OK => PreferredReplicaElectionZNode.decode(getDataResponse.data) + case Code.NONODE => Set.empty + case _ => throw getDataResponse.resultException.get + } + } + + /** + * Deletes the preferred replica election znode. + * @param expectedControllerEpochZkVersion expected controller epoch zkVersion. + */ + def deletePreferredReplicaElection(expectedControllerEpochZkVersion: Int): Unit = { + val deleteRequest = DeleteRequest(PreferredReplicaElectionZNode.path, ZkVersion.MatchAnyVersion) + retryRequestUntilConnected(deleteRequest, expectedControllerEpochZkVersion) + } + + /** + * Gets the controller id. + * @return optional integer that is Some if the controller znode exists and can be parsed and None otherwise. 
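+   *
+   * A hypothetical sketch (the `zkClient` instance name is made up):
+   * {{{
+   * val controllerId = zkClient.getControllerId          // e.g. Some(1) while a controller is elected
+   * val controllerEpoch = zkClient.getControllerEpoch    // Option[(Int, Stat)]
+   * }}}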
+ */ + def getControllerId: Option[Int] = { + val getDataRequest = GetDataRequest(ControllerZNode.path) + val getDataResponse = retryRequestUntilConnected(getDataRequest) + getDataResponse.resultCode match { + case Code.OK => ControllerZNode.decode(getDataResponse.data) + case Code.NONODE => None + case _ => throw getDataResponse.resultException.get + } + } + + /** + * Deletes the controller znode. + * @param expectedControllerEpochZkVersion expected controller epoch zkVersion. + */ + def deleteController(expectedControllerEpochZkVersion: Int): Unit = { + val deleteRequest = DeleteRequest(ControllerZNode.path, ZkVersion.MatchAnyVersion) + retryRequestUntilConnected(deleteRequest, expectedControllerEpochZkVersion) + } + + /** + * Gets the controller epoch. + * @return optional (Int, Stat) that is Some if the controller epoch path exists and None otherwise. + */ + def getControllerEpoch: Option[(Int, Stat)] = { + val getDataRequest = GetDataRequest(ControllerEpochZNode.path) + val getDataResponse = retryRequestUntilConnected(getDataRequest) + getDataResponse.resultCode match { + case Code.OK => + val epoch = ControllerEpochZNode.decode(getDataResponse.data) + Option(epoch, getDataResponse.stat) + case Code.NONODE => None + case _ => throw getDataResponse.resultException.get + } + } + + /** + * Recursively deletes the topic znode. + * @param topic the topic whose topic znode we wish to delete. + * @param expectedControllerEpochZkVersion expected controller epoch zkVersion. + */ + def deleteTopicZNode(topic: String, expectedControllerEpochZkVersion: Int): Unit = { + deleteRecursive(TopicZNode.path(topic), expectedControllerEpochZkVersion) + } + + /** + * Deletes the topic configs for the given topics. + * @param topics the topics whose configs we wish to delete. + * @param expectedControllerEpochZkVersion expected controller epoch zkVersion. + */ + def deleteTopicConfigs(topics: Seq[String], expectedControllerEpochZkVersion: Int): Unit = { + val deleteRequests = topics.map(topic => DeleteRequest(ConfigEntityZNode.path(ConfigType.Topic, topic), + ZkVersion.MatchAnyVersion)) + retryRequestsUntilConnected(deleteRequests, expectedControllerEpochZkVersion) + } + + //Acl management methods + + /** + * Creates the required zk nodes for Acl storage and Acl change storage. + */ + def createAclPaths(): Unit = { + ZkAclStore.stores.foreach(store => { + createRecursive(store.aclPath, throwIfPathExists = false) + AclEntry.ResourceTypes.foreach(resourceType => createRecursive(store.path(resourceType), throwIfPathExists = false)) + }) + + ZkAclChangeStore.stores.foreach(store => createRecursive(store.aclChangePath, throwIfPathExists = false)) + } + + /** + * Gets VersionedAcls for a given Resource + * @param resource Resource to get VersionedAcls for + * @return VersionedAcls + */ + def getVersionedAclsForResource(resource: ResourcePattern): VersionedAcls = { + val getDataRequest = GetDataRequest(ResourceZNode.path(resource)) + val getDataResponse = retryRequestUntilConnected(getDataRequest) + getDataResponse.resultCode match { + case Code.OK => ResourceZNode.decode(getDataResponse.data, getDataResponse.stat) + case Code.NONODE => NoAcls + case _ => throw getDataResponse.resultException.get + } + } + + /** + * Sets or creates the resource znode path with the given acls and expected zk version depending + * on whether it already exists or not. 
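+   *
+   * A hypothetical sketch (the `zkClient` instance name is made up; `resource`, `acls` and
+   * `currentVersion` are assumed to have been obtained elsewhere, e.g. from a previous
+   * getVersionedAclsForResource call):
+   * {{{
+   * val (updated, newVersion) = zkClient.conditionalSetAclsForResource(resource, acls, currentVersion)
+   * if (updated)
+   *   zkClient.createAclChangeNotification(resource)
+   * }}}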
+ * @param resource + * @param aclsSet + * @param expectedVersion + * @return true if the update was successful and the new version + */ + def conditionalSetAclsForResource(resource: ResourcePattern, + aclsSet: Set[AclEntry], + expectedVersion: Int): (Boolean, Int) = { + def set(aclData: Array[Byte], expectedVersion: Int): SetDataResponse = { + val setDataRequest = SetDataRequest(ResourceZNode.path(resource), aclData, expectedVersion) + retryRequestUntilConnected(setDataRequest) + } + + if (expectedVersion < 0) + throw new IllegalArgumentException(s"Invalid version $expectedVersion provided for conditional update") + + val aclData = ResourceZNode.encode(aclsSet) + + val setDataResponse = set(aclData, expectedVersion) + setDataResponse.resultCode match { + case Code.OK => (true, setDataResponse.stat.getVersion) + case Code.NONODE | Code.BADVERSION => (false, ZkVersion.UnknownVersion) + case _ => throw setDataResponse.resultException.get + } + } + + def createAclsForResourceIfNotExists(resource: ResourcePattern, aclsSet: Set[AclEntry]): (Boolean, Int) = { + def create(aclData: Array[Byte]): CreateResponse = { + val path = ResourceZNode.path(resource) + val createRequest = CreateRequest(path, aclData, defaultAcls(path), CreateMode.PERSISTENT) + retryRequestUntilConnected(createRequest) + } + + val aclData = ResourceZNode.encode(aclsSet) + + val createResponse = create(aclData) + createResponse.resultCode match { + case Code.OK => (true, 0) + case Code.NODEEXISTS => (false, ZkVersion.UnknownVersion) + case _ => throw createResponse.resultException.get + } + } + + /** + * Creates an Acl change notification message. + * @param resource resource pattern that has changed + */ + def createAclChangeNotification(resource: ResourcePattern): Unit = { + val aclChange = ZkAclStore(resource.patternType).changeStore.createChangeNode(resource) + val createRequest = CreateRequest(aclChange.path, aclChange.bytes, defaultAcls(aclChange.path), CreateMode.PERSISTENT_SEQUENTIAL) + val createResponse = retryRequestUntilConnected(createRequest) + createResponse.maybeThrow() + } + + def propagateLogDirEvent(brokerId: Int): Unit = { + val logDirEventNotificationPath: String = createSequentialPersistentPath( + LogDirEventNotificationZNode.path + "/" + LogDirEventNotificationSequenceZNode.SequenceNumberPrefix, + LogDirEventNotificationSequenceZNode.encode(brokerId)) + debug(s"Added $logDirEventNotificationPath for broker $brokerId") + } + + def propagateIsrChanges(isrChangeSet: collection.Set[TopicPartition]): Unit = { + val isrChangeNotificationPath: String = createSequentialPersistentPath(IsrChangeNotificationSequenceZNode.path(), + IsrChangeNotificationSequenceZNode.encode(isrChangeSet)) + debug(s"Added $isrChangeNotificationPath for $isrChangeSet") + } + + /** + * Deletes all Acl change notifications. 
+ * @throws KeeperException if there is an error while deleting Acl change notifications + */ + def deleteAclChangeNotifications(): Unit = { + ZkAclChangeStore.stores.foreach(store => { + val getChildrenResponse = retryRequestUntilConnected(GetChildrenRequest(store.aclChangePath, registerWatch = true)) + if (getChildrenResponse.resultCode == Code.OK) { + deleteAclChangeNotifications(store.aclChangePath, getChildrenResponse.children) + } else if (getChildrenResponse.resultCode != Code.NONODE) { + getChildrenResponse.maybeThrow() + } + }) + } + + /** + * Deletes the Acl change notifications associated with the given sequence nodes + * + * @param aclChangePath the root path + * @param sequenceNodes the name of the node to delete. + */ + private def deleteAclChangeNotifications(aclChangePath: String, sequenceNodes: Seq[String]): Unit = { + val deleteRequests = sequenceNodes.map { sequenceNode => + DeleteRequest(s"$aclChangePath/$sequenceNode", ZkVersion.MatchAnyVersion) + } + + val deleteResponses = retryRequestsUntilConnected(deleteRequests) + deleteResponses.foreach { deleteResponse => + if (deleteResponse.resultCode != Code.NONODE) { + deleteResponse.maybeThrow() + } + } + } + + /** + * Gets the resource types, for which ACLs are stored, for the supplied resource pattern type. + * @param patternType The resource pattern type to retrieve the names for. + * @return list of resource type names + */ + def getResourceTypes(patternType: PatternType): Seq[String] = { + getChildren(ZkAclStore(patternType).aclPath) + } + + /** + * Gets the resource names, for which ACLs are stored, for a given resource type and pattern type + * @param patternType The resource pattern type to retrieve the names for. + * @param resourceType Resource type to retrieve the names for. + * @return list of resource names + */ + def getResourceNames(patternType: PatternType, resourceType: ResourceType): Seq[String] = { + getChildren(ZkAclStore(patternType).path(resourceType)) + } + + /** + * Deletes the given Resource node + * @param resource + * @return delete status + */ + def deleteResource(resource: ResourcePattern): Boolean = { + deleteRecursive(ResourceZNode.path(resource)) + } + + /** + * checks the resource existence + * @param resource + * @return existence status + */ + def resourceExists(resource: ResourcePattern): Boolean = { + pathExists(ResourceZNode.path(resource)) + } + + /** + * Conditional delete the resource node + * @param resource + * @param expectedVersion + * @return return true if it succeeds, false otherwise (the current version is not the expected version) + */ + def conditionalDelete(resource: ResourcePattern, expectedVersion: Int): Boolean = { + val deleteRequest = DeleteRequest(ResourceZNode.path(resource), expectedVersion) + val deleteResponse = retryRequestUntilConnected(deleteRequest) + deleteResponse.resultCode match { + case Code.OK | Code.NONODE => true + case Code.BADVERSION => false + case _ => throw deleteResponse.resultException.get + } + } + + /** + * Deletes the zk node recursively + * @param path path to delete + * @param expectedControllerEpochZkVersion expected controller epoch zkVersion. 
+ * @param recursiveDelete enable recursive delete + * @return KeeperException if there is an error while deleting the path + */ + def deletePath(path: String, expectedControllerEpochZkVersion: Int = ZkVersion.MatchAnyVersion, recursiveDelete: Boolean = true): Unit = { + if (recursiveDelete) + deleteRecursive(path, expectedControllerEpochZkVersion) + else { + val deleteRequest = DeleteRequest(path, ZkVersion.MatchAnyVersion) + val deleteResponse = retryRequestUntilConnected(deleteRequest, expectedControllerEpochZkVersion) + if (deleteResponse.resultCode != Code.OK && deleteResponse.resultCode != Code.NONODE) { + throw deleteResponse.resultException.get + } + } + } + + /** + * Creates the required zk nodes for Delegation Token storage + */ + def createDelegationTokenPaths(): Unit = { + createRecursive(DelegationTokenChangeNotificationZNode.path, throwIfPathExists = false) + createRecursive(DelegationTokensZNode.path, throwIfPathExists = false) + } + + /** + * Creates Delegation Token change notification message + * @param tokenId token Id + */ + def createTokenChangeNotification(tokenId: String): Unit = { + val path = DelegationTokenChangeNotificationSequenceZNode.createPath + val createRequest = CreateRequest(path, DelegationTokenChangeNotificationSequenceZNode.encode(tokenId), defaultAcls(path), CreateMode.PERSISTENT_SEQUENTIAL) + val createResponse = retryRequestUntilConnected(createRequest) + createResponse.resultException.foreach(e => throw e) + } + + /** + * Sets or creates token info znode with the given token details depending on whether it already + * exists or not. + * + * @param token the token to set on the token znode + * @throws KeeperException if there is an error while setting or creating the znode + */ + def setOrCreateDelegationToken(token: DelegationToken): Unit = { + + def set(tokenData: Array[Byte]): SetDataResponse = { + val setDataRequest = SetDataRequest(DelegationTokenInfoZNode.path(token.tokenInfo().tokenId()), tokenData, ZkVersion.MatchAnyVersion) + retryRequestUntilConnected(setDataRequest) + } + + def create(tokenData: Array[Byte]): CreateResponse = { + val path = DelegationTokenInfoZNode.path(token.tokenInfo().tokenId()) + val createRequest = CreateRequest(path, tokenData, defaultAcls(path), CreateMode.PERSISTENT) + retryRequestUntilConnected(createRequest) + } + + val tokenInfo = DelegationTokenInfoZNode.encode(token) + val setDataResponse = set(tokenInfo) + setDataResponse.resultCode match { + case Code.NONODE => + val createDataResponse = create(tokenInfo) + createDataResponse.maybeThrow() + case _ => setDataResponse.maybeThrow() + } + } + + /** + * Gets the Delegation Token Info + * @return optional TokenInfo that is Some if the token znode exists and can be parsed and None otherwise. 
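+ *
+ * Illustrative usage only (not part of the original doc); the token id "token-1" is hypothetical and a
+ * connected `zkClient` is assumed:
+ * {{{
+ * val info: Option[TokenInformation] = zkClient.getDelegationTokenInfo("token-1")
+ * info.foreach(t => println(t.tokenId))
+ * }}}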
+ */ + def getDelegationTokenInfo(delegationTokenId: String): Option[TokenInformation] = { + val getDataRequest = GetDataRequest(DelegationTokenInfoZNode.path(delegationTokenId)) + val getDataResponse = retryRequestUntilConnected(getDataRequest) + getDataResponse.resultCode match { + case Code.OK => DelegationTokenInfoZNode.decode(getDataResponse.data) + case Code.NONODE => None + case _ => throw getDataResponse.resultException.get + } + } + + /** + * Deletes the given Delegation token node + * @param delegationTokenId + * @return delete status + */ + def deleteDelegationToken(delegationTokenId: String): Boolean = { + deleteRecursive(DelegationTokenInfoZNode.path(delegationTokenId)) + } + + /** + * This registers a ZNodeChangeHandler and attempts to register a watcher with an ExistsRequest, which allows data + * watcher registrations on paths which might not even exist. + * + * @param zNodeChangeHandler + * @return `true` if the path exists or `false` if it does not + * @throws KeeperException if an error is returned by ZooKeeper + */ + def registerZNodeChangeHandlerAndCheckExistence(zNodeChangeHandler: ZNodeChangeHandler): Boolean = { + zooKeeperClient.registerZNodeChangeHandler(zNodeChangeHandler) + val existsResponse = retryRequestUntilConnected(ExistsRequest(zNodeChangeHandler.path)) + existsResponse.resultCode match { + case Code.OK => true + case Code.NONODE => false + case _ => throw existsResponse.resultException.get + } + } + + /** + * See ZooKeeperClient.registerZNodeChangeHandler + * @param zNodeChangeHandler + */ + def registerZNodeChangeHandler(zNodeChangeHandler: ZNodeChangeHandler): Unit = { + zooKeeperClient.registerZNodeChangeHandler(zNodeChangeHandler) + } + + /** + * See ZooKeeperClient.unregisterZNodeChangeHandler + * @param path + */ + def unregisterZNodeChangeHandler(path: String): Unit = { + zooKeeperClient.unregisterZNodeChangeHandler(path) + } + + /** + * See ZooKeeperClient.registerZNodeChildChangeHandler + * @param zNodeChildChangeHandler + */ + def registerZNodeChildChangeHandler(zNodeChildChangeHandler: ZNodeChildChangeHandler): Unit = { + zooKeeperClient.registerZNodeChildChangeHandler(zNodeChildChangeHandler) + } + + /** + * See ZooKeeperClient.unregisterZNodeChildChangeHandler + * @param path + */ + def unregisterZNodeChildChangeHandler(path: String): Unit = { + zooKeeperClient.unregisterZNodeChildChangeHandler(path) + } + + /** + * + * @param stateChangeHandler + */ + def registerStateChangeHandler(stateChangeHandler: StateChangeHandler): Unit = { + zooKeeperClient.registerStateChangeHandler(stateChangeHandler) + } + + /** + * + * @param name + */ + def unregisterStateChangeHandler(name: String): Unit = { + zooKeeperClient.unregisterStateChangeHandler(name) + } + + /** + * Close the underlying ZooKeeperClient. + */ + def close(): Unit = { + removeMetric("ZooKeeperRequestLatencyMs") + zooKeeperClient.close() + } + + /** + * Get the committed offset for a topic partition and group + * @param group the group we wish to get offset for + * @param topicPartition the topic partition we wish to get the offset for + * @return optional long that is Some if there was an offset committed for topic partition, group and None otherwise. 
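+ *
+ * Illustrative usage only (group and topic names are hypothetical), assuming a connected `zkClient`:
+ * {{{
+ * val offset: Option[Long] = zkClient.getConsumerOffset("group-1", new TopicPartition("topic-a", 0))
+ * }}}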
+ */ + def getConsumerOffset(group: String, topicPartition: TopicPartition): Option[Long] = { + val getDataRequest = GetDataRequest(ConsumerOffset.path(group, topicPartition.topic, topicPartition.partition)) + val getDataResponse = retryRequestUntilConnected(getDataRequest) + getDataResponse.resultCode match { + case Code.OK => ConsumerOffset.decode(getDataResponse.data) + case Code.NONODE => None + case _ => throw getDataResponse.resultException.get + } + } + + /** + * Set the committed offset for a topic partition and group + * @param group the group whose offset is being set + * @param topicPartition the topic partition whose offset is being set + * @param offset the offset value + */ + def setOrCreateConsumerOffset(group: String, topicPartition: TopicPartition, offset: Long): Unit = { + val setDataResponse = setConsumerOffset(group, topicPartition, offset) + if (setDataResponse.resultCode == Code.NONODE) { + createConsumerOffset(group, topicPartition, offset) + } else { + setDataResponse.maybeThrow() + } + } + + /** + * Get the cluster id. + * @return optional cluster id in String. + */ + def getClusterId: Option[String] = { + val getDataRequest = GetDataRequest(ClusterIdZNode.path) + val getDataResponse = retryRequestUntilConnected(getDataRequest) + getDataResponse.resultCode match { + case Code.OK => Some(ClusterIdZNode.fromJson(getDataResponse.data)) + case Code.NONODE => None + case _ => throw getDataResponse.resultException.get + } + } + + /** + * Return the ACLs of the node of the given path + * @param path the given path for the node + * @return the ACL array of the given node. + */ + def getAcl(path: String): Seq[ACL] = { + val getAclRequest = GetAclRequest(path) + val getAclResponse = retryRequestUntilConnected(getAclRequest) + getAclResponse.resultCode match { + case Code.OK => getAclResponse.acl + case _ => throw getAclResponse.resultException.get + } + } + + /** + * sets the ACLs to the node of the given path + * @param path the given path for the node + * @param acl the given acl for the node + */ + def setAcl(path: String, acl: Seq[ACL]): Unit = { + val setAclRequest = SetAclRequest(path, acl, ZkVersion.MatchAnyVersion) + val setAclResponse = retryRequestUntilConnected(setAclRequest) + setAclResponse.maybeThrow() + } + + /** + * Create the cluster Id. If the cluster id already exists, return the current cluster id. + * @return cluster id + */ + def createOrGetClusterId(proposedClusterId: String): String = { + try { + createRecursive(ClusterIdZNode.path, ClusterIdZNode.toJson(proposedClusterId)) + proposedClusterId + } catch { + case _: NodeExistsException => getClusterId.getOrElse( + throw new KafkaException("Failed to get cluster id from Zookeeper. This can happen if /cluster/id is deleted from Zookeeper.")) + } + } + + /** + * Generate a broker id by updating the broker sequence id path in ZK and return the version of the path. + * The version is incremented by one on every update starting from 1. 
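+ *
+ * Illustrative sketch (not part of the original doc): because the returned value is the znode version,
+ * successive calls on a fresh cluster would yield 1, 2, 3, ...:
+ * {{{
+ * val autoGeneratedBrokerId: Int = zkClient.generateBrokerSequenceId()
+ * }}}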
+ * @return sequence number as the broker id + */ + def generateBrokerSequenceId(): Int = { + val setDataRequest = SetDataRequest(BrokerSequenceIdZNode.path, Array.empty[Byte], ZkVersion.MatchAnyVersion) + val setDataResponse = retryRequestUntilConnected(setDataRequest) + setDataResponse.resultCode match { + case Code.OK => setDataResponse.stat.getVersion + case Code.NONODE => + // maker sure the path exists + createRecursive(BrokerSequenceIdZNode.path, Array.empty[Byte], throwIfPathExists = false) + generateBrokerSequenceId() + case _ => throw setDataResponse.resultException.get + } + } + + /** + * Pre-create top level paths in ZK if needed. + */ + def createTopLevelPaths(): Unit = { + ZkData.PersistentZkPaths.foreach(makeSurePersistentPathExists(_)) + } + + /** + * Make sure a persistent path exists in ZK. + * @param path + */ + def makeSurePersistentPathExists(path: String): Unit = { + createRecursive(path, data = null, throwIfPathExists = false) + } + + def createFeatureZNode(nodeContents: FeatureZNode): Unit = { + val createRequest = CreateRequest( + FeatureZNode.path, + FeatureZNode.encode(nodeContents), + defaultAcls(FeatureZNode.path), + CreateMode.PERSISTENT) + val response = retryRequestUntilConnected(createRequest) + response.maybeThrow() + } + + def updateFeatureZNode(nodeContents: FeatureZNode): Int = { + val setRequest = SetDataRequest( + FeatureZNode.path, + FeatureZNode.encode(nodeContents), + ZkVersion.MatchAnyVersion) + val response = retryRequestUntilConnected(setRequest) + response.maybeThrow() + response.stat.getVersion + } + + def deleteFeatureZNode(): Unit = { + deletePath(FeatureZNode.path, ZkVersion.MatchAnyVersion, false) + } + + private def setConsumerOffset(group: String, topicPartition: TopicPartition, offset: Long): SetDataResponse = { + val setDataRequest = SetDataRequest(ConsumerOffset.path(group, topicPartition.topic, topicPartition.partition), + ConsumerOffset.encode(offset), ZkVersion.MatchAnyVersion) + retryRequestUntilConnected(setDataRequest) + } + + private def createConsumerOffset(group: String, topicPartition: TopicPartition, offset: Long) = { + val path = ConsumerOffset.path(group, topicPartition.topic, topicPartition.partition) + createRecursive(path, ConsumerOffset.encode(offset)) + } + + /** + * Deletes the given zk path recursively + * @param path + * @param expectedControllerEpochZkVersion expected controller epoch zkVersion. 
+ * @return true if path gets deleted successfully, false if root path doesn't exist + * @throws KeeperException if there is an error while deleting the znodes + */ + def deleteRecursive(path: String, expectedControllerEpochZkVersion: Int = ZkVersion.MatchAnyVersion): Boolean = { + val getChildrenResponse = retryRequestUntilConnected(GetChildrenRequest(path, registerWatch = true)) + getChildrenResponse.resultCode match { + case Code.OK => + getChildrenResponse.children.foreach(child => deleteRecursive(s"$path/$child", expectedControllerEpochZkVersion)) + val deleteResponse = retryRequestUntilConnected(DeleteRequest(path, ZkVersion.MatchAnyVersion), expectedControllerEpochZkVersion) + if (deleteResponse.resultCode != Code.OK && deleteResponse.resultCode != Code.NONODE) + throw deleteResponse.resultException.get + true + case Code.NONODE => false + case _ => throw getChildrenResponse.resultException.get + } + } + + def pathExists(path: String): Boolean = { + val existsRequest = ExistsRequest(path) + val existsResponse = retryRequestUntilConnected(existsRequest) + existsResponse.resultCode match { + case Code.OK => true + case Code.NONODE => false + case _ => throw existsResponse.resultException.get + } + } + + private[kafka] def createRecursive(path: String, data: Array[Byte] = null, throwIfPathExists: Boolean = true) = { + + def parentPath(path: String): String = { + val indexOfLastSlash = path.lastIndexOf("/") + if (indexOfLastSlash == -1) throw new IllegalArgumentException(s"Invalid path ${path}") + path.substring(0, indexOfLastSlash) + } + + def createRecursive0(path: String): Unit = { + val createRequest = CreateRequest(path, null, defaultAcls(path), CreateMode.PERSISTENT) + var createResponse = retryRequestUntilConnected(createRequest) + if (createResponse.resultCode == Code.NONODE) { + createRecursive0(parentPath(path)) + createResponse = retryRequestUntilConnected(createRequest) + if (createResponse.resultCode != Code.OK && createResponse.resultCode != Code.NODEEXISTS) { + throw createResponse.resultException.get + } + } else if (createResponse.resultCode != Code.OK && createResponse.resultCode != Code.NODEEXISTS) { + throw createResponse.resultException.get + } + } + + val createRequest = CreateRequest(path, data, defaultAcls(path), CreateMode.PERSISTENT) + var createResponse = retryRequestUntilConnected(createRequest) + + if (throwIfPathExists && createResponse.resultCode == Code.NODEEXISTS) { + createResponse.maybeThrow() + } else if (createResponse.resultCode == Code.NONODE) { + createRecursive0(parentPath(path)) + createResponse = retryRequestUntilConnected(createRequest) + if (throwIfPathExists || createResponse.resultCode != Code.NODEEXISTS) + createResponse.maybeThrow() + } else if (createResponse.resultCode != Code.NODEEXISTS) + createResponse.maybeThrow() + + } + + private def createTopicPartition(partitions: Seq[TopicPartition], expectedControllerEpochZkVersion: Int): Seq[CreateResponse] = { + val createRequests = partitions.map { partition => + val path = TopicPartitionZNode.path(partition) + CreateRequest(path, null, defaultAcls(path), CreateMode.PERSISTENT, Some(partition)) + } + retryRequestsUntilConnected(createRequests, expectedControllerEpochZkVersion) + } + + private def createTopicPartitions(topics: Seq[String], expectedControllerEpochZkVersion: Int): Seq[CreateResponse] = { + val createRequests = topics.map { topic => + val path = TopicPartitionsZNode.path(topic) + CreateRequest(path, null, defaultAcls(path), CreateMode.PERSISTENT, Some(topic)) + } + 
retryRequestsUntilConnected(createRequests, expectedControllerEpochZkVersion) + } + + private def getTopicConfigs(topics: Set[String]): Seq[GetDataResponse] = { + val getDataRequests: Seq[GetDataRequest] = topics.iterator.map { topic => + GetDataRequest(ConfigEntityZNode.path(ConfigType.Topic, topic), ctx = Some(topic)) + }.toBuffer + + retryRequestsUntilConnected(getDataRequests) + } + + def defaultAcls(path: String): Seq[ACL] = ZkData.defaultAcls(isSecure, path) + + def secure: Boolean = isSecure + + private[zk] def retryRequestUntilConnected[Req <: AsyncRequest](request: Req, expectedControllerZkVersion: Int = ZkVersion.MatchAnyVersion): Req#Response = { + retryRequestsUntilConnected(Seq(request), expectedControllerZkVersion).head + } + + private def retryRequestsUntilConnected[Req <: AsyncRequest](requests: Seq[Req], expectedControllerZkVersion: Int): Seq[Req#Response] = { + expectedControllerZkVersion match { + case ZkVersion.MatchAnyVersion => retryRequestsUntilConnected(requests) + case version if version >= 0 => + retryRequestsUntilConnected(requests.map(wrapRequestWithControllerEpochCheck(_, version))) + .map(unwrapResponseWithControllerEpochCheck(_).asInstanceOf[Req#Response]) + case invalidVersion => + throw new IllegalArgumentException(s"Expected controller epoch zkVersion $invalidVersion should be non-negative or equal to ${ZkVersion.MatchAnyVersion}") + } + } + + private def retryRequestsUntilConnected[Req <: AsyncRequest](requests: Seq[Req]): Seq[Req#Response] = { + val remainingRequests = new mutable.ArrayBuffer(requests.size) ++= requests + val responses = new mutable.ArrayBuffer[Req#Response] + while (remainingRequests.nonEmpty) { + val batchResponses = zooKeeperClient.handleRequests(remainingRequests) + + batchResponses.foreach(response => latencyMetric.update(response.metadata.responseTimeMs)) + + // Only execute slow path if we find a response with CONNECTIONLOSS + if (batchResponses.exists(_.resultCode == Code.CONNECTIONLOSS)) { + val requestResponsePairs = remainingRequests.zip(batchResponses) + + remainingRequests.clear() + requestResponsePairs.foreach { case (request, response) => + if (response.resultCode == Code.CONNECTIONLOSS) + remainingRequests += request + else + responses += response + } + + if (remainingRequests.nonEmpty) + zooKeeperClient.waitUntilConnected() + } else { + remainingRequests.clear() + responses ++= batchResponses + } + } + responses + } + + private def checkedEphemeralCreate(path: String, data: Array[Byte]): Stat = { + val checkedEphemeral = new CheckedEphemeral(path, data) + info(s"Creating $path (is it secure? 
$isSecure)") + val stat = checkedEphemeral.create() + info(s"Stat of the created znode at $path is: $stat") + stat + } + + private def isZKSessionIdDiffFromCurrentZKSessionId(): Boolean = { + zooKeeperClient.sessionId != currentZooKeeperSessionId + } + + private def isZKSessionTheEphemeralOwner(ephemeralOwnerId: Long): Boolean = { + ephemeralOwnerId == currentZooKeeperSessionId + } + + private[zk] def shouldReCreateEphemeralZNode(ephemeralOwnerId: Long): Boolean = { + isZKSessionTheEphemeralOwner(ephemeralOwnerId) && isZKSessionIdDiffFromCurrentZKSessionId() + } + + private def updateCurrentZKSessionId(newSessionId: Long): Unit = { + currentZooKeeperSessionId = newSessionId + } + + private class CheckedEphemeral(path: String, data: Array[Byte]) extends Logging { + def create(): Stat = { + val response = retryRequestUntilConnected( + MultiRequest(Seq( + CreateOp(path, null, defaultAcls(path), CreateMode.EPHEMERAL), + SetDataOp(path, data, 0))) + ) + val stat = response.resultCode match { + case Code.OK => + val setDataResult = response.zkOpResults(1).rawOpResult.asInstanceOf[SetDataResult] + setDataResult.getStat + case Code.NODEEXISTS => + getAfterNodeExists() + case code => + error(s"Error while creating ephemeral at $path with return code: $code") + throw KeeperException.create(code) + } + + // At this point, we need to save a reference to the zookeeper session id. + // This is done here since the Zookeeper session id may not be available at the Object creation time. + // This is assuming the 'retryRequestUntilConnected' method got connected and a valid session id is present. + // This code is part of the workaround done in the KAFKA-7165, once ZOOKEEPER-2985 is complete, this code + // must be deleted. + updateCurrentZKSessionId(zooKeeperClient.sessionId) + + stat + } + + // This method is part of the work around done in the KAFKA-7165, once ZOOKEEPER-2985 is complete, this code must + // be deleted. + private def delete(): Code = { + val deleteRequest = DeleteRequest(path, ZkVersion.MatchAnyVersion) + val deleteResponse = retryRequestUntilConnected(deleteRequest) + deleteResponse.resultCode match { + case code@ Code.OK => code + case code@ Code.NONODE => code + case code => + error(s"Error while deleting ephemeral node at $path with return code: $code") + code + } + } + + private def reCreate(): Stat = { + val codeAfterDelete = delete() + val codeAfterReCreate = codeAfterDelete + debug(s"Result of znode ephemeral deletion at $path is: $codeAfterDelete") + if (codeAfterDelete == Code.OK || codeAfterDelete == Code.NONODE) { + create() + } else { + throw KeeperException.create(codeAfterReCreate) + } + } + + private def getAfterNodeExists(): Stat = { + val getDataRequest = GetDataRequest(path) + val getDataResponse = retryRequestUntilConnected(getDataRequest) + val ephemeralOwnerId = getDataResponse.stat.getEphemeralOwner + getDataResponse.resultCode match { + // At this point, the Zookeeper session could be different (due a 'Session expired') from the one that initially + // registered the Broker into the Zookeeper ephemeral node, but the znode is still present in ZooKeeper. + // The expected behaviour is that Zookeeper server removes the ephemeral node associated with the expired session + // but due an already reported bug in Zookeeper (ZOOKEEPER-2985) this is not happening, so, the following check + // will validate if this Broker got registered with the previous (expired) session and try to register again, + // deleting the ephemeral node and creating it again. 
+ // This code is part of the work around done in the KAFKA-7165, once ZOOKEEPER-2985 is complete, this code must + // be deleted. + case Code.OK if shouldReCreateEphemeralZNode(ephemeralOwnerId) => + info(s"Was not possible to create the ephemeral at $path, node already exists and owner " + + s"'$ephemeralOwnerId' does not match current session '${zooKeeperClient.sessionId}'" + + s", trying to delete and re-create it with the newest Zookeeper session") + reCreate() + case Code.OK if ephemeralOwnerId != zooKeeperClient.sessionId => + error(s"Error while creating ephemeral at $path, node already exists and owner " + + s"'$ephemeralOwnerId' does not match current session '${zooKeeperClient.sessionId}'") + throw KeeperException.create(Code.NODEEXISTS) + case Code.OK => + getDataResponse.stat + case Code.NONODE => + info(s"The ephemeral node at $path went away while reading it, attempting create() again") + create() + case code => + error(s"Error while creating ephemeral at $path as it already exists and error getting the node data due to $code") + throw KeeperException.create(code) + } + } + } +} + +object KafkaZkClient { + + /** + * @param finishedPartitions Partitions that finished either in successfully + * updated partition states or failed with an exception. + * @param partitionsToRetry The partitions that we should retry due to a zookeeper BADVERSION conflict. Version conflicts + * can occur if the partition leader updated partition state while the controller attempted to + * update partition state. + */ + case class UpdateLeaderAndIsrResult( + finishedPartitions: Map[TopicPartition, Either[Exception, LeaderAndIsr]], + partitionsToRetry: Seq[TopicPartition] + ) + + /** + * Create an instance of this class with the provided parameters. + * + * The metric group and type are preserved by default for compatibility with previous versions. + */ + def apply(connectString: String, + isSecure: Boolean, + sessionTimeoutMs: Int, + connectionTimeoutMs: Int, + maxInFlightRequests: Int, + time: Time, + name: String, + zkClientConfig: ZKClientConfig, + metricGroup: String = "kafka.server", + metricType: String = "SessionExpireListener", + createChrootIfNecessary: Boolean = false + ): KafkaZkClient = { + + /* ZooKeeper 3.6.0 changed the default configuration for JUTE_MAXBUFFER from 4 MB to 1 MB. + * This causes a regression if Kafka tries to retrieve a large amount of data across many + * znodes – in such a case the ZooKeeper client will repeatedly emit a message of the form + * "java.io.IOException: Packet len <####> is out of range". + * + * We restore the 3.4.x/3.5.x behavior unless the caller has set the property (note that ZKConfig + * auto configures itself if certain system properties have been set). + * + * See https://github.com/apache/zookeeper/pull/1129 for the details on why the behavior + * changed in 3.6.0. 
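+ *
+ * Illustrative example (not part of the original comment): a caller that wants a different limit can set
+ * the property itself before constructing the client, in which case the default below is not applied.
+ * The 8 MB value is hypothetical:
+ *
+ *   val clientConfig = new ZKClientConfig()
+ *   clientConfig.setProperty(ZKConfig.JUTE_MAXBUFFER, (8 * 1024 * 1024).toString)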
+ */ + if (zkClientConfig.getProperty(ZKConfig.JUTE_MAXBUFFER) == null) + zkClientConfig.setProperty(ZKConfig.JUTE_MAXBUFFER, ((4096 * 1024).toString)) + + if (createChrootIfNecessary) { + val chrootIndex = connectString.indexOf("/") + if (chrootIndex > 0) { + val zkConnWithoutChrootForChrootCreation = connectString.substring(0, chrootIndex) + val zkClientForChrootCreation = apply(zkConnWithoutChrootForChrootCreation, isSecure, sessionTimeoutMs, + connectionTimeoutMs, maxInFlightRequests, time, name, zkClientConfig, metricGroup, metricType) + try { + val chroot = connectString.substring(chrootIndex) + if (!zkClientForChrootCreation.pathExists(chroot)) { + zkClientForChrootCreation.makeSurePersistentPathExists(chroot) + } + } finally { + zkClientForChrootCreation.close() + } + } + } + val zooKeeperClient = new ZooKeeperClient(connectString, sessionTimeoutMs, connectionTimeoutMs, maxInFlightRequests, + time, metricGroup, metricType, zkClientConfig, name) + new KafkaZkClient(zooKeeperClient, isSecure, time) + } + + // A helper function to transform a regular request into a MultiRequest + // with the check on controller epoch znode zkVersion. + // This is used for fencing zookeeper updates in controller. + private def wrapRequestWithControllerEpochCheck(request: AsyncRequest, expectedControllerZkVersion: Int): MultiRequest = { + val checkOp = CheckOp(ControllerEpochZNode.path, expectedControllerZkVersion) + request match { + case CreateRequest(path, data, acl, createMode, ctx) => + MultiRequest(Seq(checkOp, CreateOp(path, data, acl, createMode)), ctx) + case DeleteRequest(path, version, ctx) => + MultiRequest(Seq(checkOp, DeleteOp(path, version)), ctx) + case SetDataRequest(path, data, version, ctx) => + MultiRequest(Seq(checkOp, SetDataOp(path, data, version)), ctx) + case _ => throw new IllegalStateException(s"$request does not need controller epoch check") + } + } + + // A helper function to transform a MultiResponse with the check on + // controller epoch znode zkVersion back into a regular response. + // ControllerMovedException will be thrown if the controller epoch + // znode zkVersion check fails. This is used for fencing zookeeper + // updates in controller. + private def unwrapResponseWithControllerEpochCheck(response: AsyncResponse): AsyncResponse = { + response match { + case MultiResponse(resultCode, _, ctx, zkOpResults, responseMetadata) => + zkOpResults match { + case Seq(ZkOpResult(checkOp: CheckOp, checkOpResult), zkOpResult) => + checkOpResult match { + case errorResult: ErrorResult => + if (checkOp.path.equals(ControllerEpochZNode.path)) { + val errorCode = Code.get(errorResult.getErr) + if (errorCode == Code.BADVERSION) + // Throw ControllerMovedException when the zkVersionCheck is performed on the controller epoch znode and the check fails + throw new ControllerMovedException(s"Controller epoch zkVersion check fails. 
Expected zkVersion = ${checkOp.version}") + else if (errorCode != Code.OK) + throw KeeperException.create(errorCode, checkOp.path) + } + case _ => + } + val rawOpResult = zkOpResult.rawOpResult + zkOpResult.zkOp match { + case createOp: CreateOp => + val name = rawOpResult match { + case c: CreateResult => c.getPath + case _ => null + } + CreateResponse(resultCode, createOp.path, ctx, name, responseMetadata) + case deleteOp: DeleteOp => + DeleteResponse(resultCode, deleteOp.path, ctx, responseMetadata) + case setDataOp: SetDataOp => + val stat = rawOpResult match { + case s: SetDataResult => s.getStat + case _ => null + } + SetDataResponse(resultCode, setDataOp.path, ctx, stat, responseMetadata) + case zkOp => throw new IllegalStateException(s"Unexpected zkOp: $zkOp") + } + case null => throw KeeperException.create(resultCode) + case _ => throw new IllegalStateException(s"Cannot unwrap $response because the first zookeeper op is not check op in original MultiRequest") + } + case _ => throw new IllegalStateException(s"Cannot unwrap $response because it is not a MultiResponse") + } + } +} diff --git a/core/src/main/scala/kafka/zk/ZkData.scala b/core/src/main/scala/kafka/zk/ZkData.scala new file mode 100644 index 0000000..0f6db4a --- /dev/null +++ b/core/src/main/scala/kafka/zk/ZkData.scala @@ -0,0 +1,1014 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+*/ +package kafka.zk + +import java.nio.charset.StandardCharsets.UTF_8 +import java.util +import java.util.Properties +import com.fasterxml.jackson.annotation.JsonProperty +import com.fasterxml.jackson.core.JsonProcessingException +import kafka.api.{ApiVersion, KAFKA_0_10_0_IV1, KAFKA_2_7_IV0, LeaderAndIsr} +import kafka.cluster.{Broker, EndPoint} +import kafka.common.{NotificationHandler, ZkNodeChangeNotificationListener} +import kafka.controller.{IsrChangeNotificationHandler, LeaderIsrAndControllerEpoch, ReplicaAssignment} +import kafka.security.authorizer.AclAuthorizer.VersionedAcls +import kafka.security.authorizer.AclEntry +import kafka.server.{ConfigType, DelegationTokenManager} +import kafka.utils.Json +import kafka.utils.json.JsonObject +import org.apache.kafka.common.{KafkaException, TopicPartition, Uuid} +import org.apache.kafka.common.errors.UnsupportedVersionException +import org.apache.kafka.common.feature.{Features, FinalizedVersionRange, SupportedVersionRange} +import org.apache.kafka.common.feature.Features._ +import org.apache.kafka.common.network.ListenerName +import org.apache.kafka.common.resource.{PatternType, ResourcePattern, ResourceType} +import org.apache.kafka.common.security.auth.SecurityProtocol +import org.apache.kafka.common.security.token.delegation.{DelegationToken, TokenInformation} +import org.apache.kafka.common.utils.{SecurityUtils, Time} +import org.apache.kafka.server.common.ProducerIdsBlock +import org.apache.zookeeper.ZooDefs +import org.apache.zookeeper.data.{ACL, Stat} + +import scala.beans.BeanProperty +import scala.jdk.CollectionConverters._ +import scala.collection.mutable.ArrayBuffer +import scala.collection.{Map, Seq, immutable, mutable} +import scala.util.{Failure, Success, Try} + +// This file contains objects for encoding/decoding data stored in ZooKeeper nodes (znodes). + +object ControllerZNode { + def path = "/controller" + def encode(brokerId: Int, timestamp: Long): Array[Byte] = { + Json.encodeAsBytes(Map("version" -> 1, "brokerid" -> brokerId, "timestamp" -> timestamp.toString).asJava) + } + def decode(bytes: Array[Byte]): Option[Int] = Json.parseBytes(bytes).map { js => + js.asJsonObject("brokerid").to[Int] + } +} + +object ControllerEpochZNode { + def path = "/controller_epoch" + def encode(epoch: Int): Array[Byte] = epoch.toString.getBytes(UTF_8) + def decode(bytes: Array[Byte]): Int = new String(bytes, UTF_8).toInt +} + +object ConfigZNode { + def path = "/config" +} + +object BrokersZNode { + def path = "/brokers" +} + +object BrokerIdsZNode { + def path = s"${BrokersZNode.path}/ids" + def encode: Array[Byte] = null +} + +object BrokerInfo { + + /** + * - Create a broker info with v5 json format if the apiVersion is 2.7.x or above. + * - Create a broker info with v4 json format (which includes multiple endpoints and rack) if + * the apiVersion is 0.10.0.X or above but lesser than 2.7.x. + * - Register the broker with v2 json format otherwise. + * + * Due to KAFKA-3100, 0.9.0.0 broker and old clients will break if JSON version is above 2. + * + * We include v2 to make it possible for the broker to migrate from 0.9.0.0 to 0.10.0.X or above + * without having to upgrade to 0.9.0.1 first (clients have to be upgraded to 0.9.0.1 in + * any case). 
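+ *
+ * Illustrative mapping of the rules above (not part of the original doc; `broker` is assumed to be in scope):
+ * {{{
+ * BrokerInfo(broker, KAFKA_2_7_IV0, jmxPort = 9999).version    // 5
+ * BrokerInfo(broker, KAFKA_0_10_0_IV1, jmxPort = 9999).version // 4
+ * }}}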
+ */ + def apply(broker: Broker, apiVersion: ApiVersion, jmxPort: Int): BrokerInfo = { + val version = { + if (apiVersion >= KAFKA_2_7_IV0) + 5 + else if (apiVersion >= KAFKA_0_10_0_IV1) + 4 + else + 2 + } + BrokerInfo(broker, version, jmxPort) + } + +} + +case class BrokerInfo(broker: Broker, version: Int, jmxPort: Int) { + val path: String = BrokerIdZNode.path(broker.id) + def toJsonBytes: Array[Byte] = BrokerIdZNode.encode(this) +} + +object BrokerIdZNode { + private val HostKey = "host" + private val PortKey = "port" + private val VersionKey = "version" + private val EndpointsKey = "endpoints" + private val RackKey = "rack" + private val JmxPortKey = "jmx_port" + private val ListenerSecurityProtocolMapKey = "listener_security_protocol_map" + private val TimestampKey = "timestamp" + private val FeaturesKey = "features" + + def path(id: Int) = s"${BrokerIdsZNode.path}/$id" + + /** + * Encode to JSON bytes. + * + * The JSON format includes a top level host and port for compatibility with older clients. + */ + def encode(version: Int, host: String, port: Int, advertisedEndpoints: Seq[EndPoint], jmxPort: Int, + rack: Option[String], features: Features[SupportedVersionRange]): Array[Byte] = { + val jsonMap = collection.mutable.Map(VersionKey -> version, + HostKey -> host, + PortKey -> port, + EndpointsKey -> advertisedEndpoints.map(_.connectionString).toBuffer.asJava, + JmxPortKey -> jmxPort, + TimestampKey -> Time.SYSTEM.milliseconds().toString + ) + rack.foreach(rack => if (version >= 3) jsonMap += (RackKey -> rack)) + + if (version >= 4) { + jsonMap += (ListenerSecurityProtocolMapKey -> advertisedEndpoints.map { endPoint => + endPoint.listenerName.value -> endPoint.securityProtocol.name + }.toMap.asJava) + } + + if (version >= 5) { + jsonMap += (FeaturesKey -> features.toMap) + } + Json.encodeAsBytes(jsonMap.asJava) + } + + def encode(brokerInfo: BrokerInfo): Array[Byte] = { + val broker = brokerInfo.broker + // the default host and port are here for compatibility with older clients that only support PLAINTEXT + // we choose the first plaintext port, if there is one + // or we register an empty endpoint, which means that older clients will not be able to connect + val plaintextEndpoint = broker.endPoints.find(_.securityProtocol == SecurityProtocol.PLAINTEXT).getOrElse( + new EndPoint(null, -1, null, null)) + encode(brokerInfo.version, plaintextEndpoint.host, plaintextEndpoint.port, broker.endPoints, brokerInfo.jmxPort, + broker.rack, broker.features) + } + + def featuresAsJavaMap(brokerInfo: JsonObject): util.Map[String, util.Map[String, java.lang.Short]] = { + FeatureZNode.asJavaMap(brokerInfo + .get(FeaturesKey) + .flatMap(_.to[Option[Map[String, Map[String, Int]]]]) + .map(theMap => theMap.map { + case(featureName, versionsInfo) => featureName -> versionsInfo.map { + case(label, version) => label -> version.asInstanceOf[Short] + }.toMap + }.toMap) + .getOrElse(Map[String, Map[String, Short]]())) + } + + /** + * Create a BrokerInfo object from id and JSON bytes. 
+ * + * @param id + * @param jsonBytes + * + * Version 1 JSON schema for a broker is: + * { + * "version":1, + * "host":"localhost", + * "port":9092 + * "jmx_port":9999, + * "timestamp":"2233345666" + * } + * + * Version 2 JSON schema for a broker is: + * { + * "version":2, + * "host":"localhost", + * "port":9092, + * "jmx_port":9999, + * "timestamp":"2233345666", + * "endpoints":["PLAINTEXT://host1:9092", "SSL://host1:9093"] + * } + * + * Version 3 JSON schema for a broker is: + * { + * "version":3, + * "host":"localhost", + * "port":9092, + * "jmx_port":9999, + * "timestamp":"2233345666", + * "endpoints":["PLAINTEXT://host1:9092", "SSL://host1:9093"], + * "rack":"dc1" + * } + * + * Version 4 JSON schema for a broker is: + * { + * "version":4, + * "host":"localhost", + * "port":9092, + * "jmx_port":9999, + * "timestamp":"2233345666", + * "endpoints":["CLIENT://host1:9092", "REPLICATION://host1:9093"], + * "rack":"dc1" + * } + * + * Version 5 (current) JSON schema for a broker is: + * { + * "version":5, + * "host":"localhost", + * "port":9092, + * "jmx_port":9999, + * "timestamp":"2233345666", + * "endpoints":["CLIENT://host1:9092", "REPLICATION://host1:9093"], + * "rack":"dc1", + * "features": {"feature": {"min_version":1, "first_active_version":2, "max_version":3}} + * } + */ + def decode(id: Int, jsonBytes: Array[Byte]): BrokerInfo = { + Json.tryParseBytes(jsonBytes) match { + case Right(js) => + val brokerInfo = js.asJsonObject + val version = brokerInfo(VersionKey).to[Int] + val jmxPort = brokerInfo(JmxPortKey).to[Int] + + val endpoints = + if (version < 1) + throw new KafkaException("Unsupported version of broker registration: " + + s"${new String(jsonBytes, UTF_8)}") + else if (version == 1) { + val host = brokerInfo(HostKey).to[String] + val port = brokerInfo(PortKey).to[Int] + val securityProtocol = SecurityProtocol.PLAINTEXT + val endPoint = new EndPoint(host, port, ListenerName.forSecurityProtocol(securityProtocol), securityProtocol) + Seq(endPoint) + } + else { + val securityProtocolMap = brokerInfo.get(ListenerSecurityProtocolMapKey).map( + _.to[Map[String, String]].map { case (listenerName, securityProtocol) => + new ListenerName(listenerName) -> SecurityProtocol.forName(securityProtocol) + }) + val listeners = brokerInfo(EndpointsKey).to[Seq[String]] + listeners.map(EndPoint.createEndPoint(_, securityProtocolMap)) + } + + val rack = brokerInfo.get(RackKey).flatMap(_.to[Option[String]]) + val features = featuresAsJavaMap(brokerInfo) + BrokerInfo( + Broker(id, endpoints, rack, fromSupportedFeaturesMap(features)), version, jmxPort) + case Left(e) => + throw new KafkaException(s"Failed to parse ZooKeeper registration for broker $id: " + + s"${new String(jsonBytes, UTF_8)}", e) + } + } +} + +object TopicsZNode { + def path = s"${BrokersZNode.path}/topics" +} + +object TopicZNode { + case class TopicIdReplicaAssignment(topic: String, + topicId: Option[Uuid], + assignment: Map[TopicPartition, ReplicaAssignment]) + def path(topic: String) = s"${TopicsZNode.path}/$topic" + def encode(topicId: Option[Uuid], + assignment: collection.Map[TopicPartition, ReplicaAssignment]): Array[Byte] = { + val replicaAssignmentJson = mutable.Map[String, util.List[Int]]() + val addingReplicasAssignmentJson = mutable.Map[String, util.List[Int]]() + val removingReplicasAssignmentJson = mutable.Map[String, util.List[Int]]() + + for ((partition, replicaAssignment) <- assignment) { + replicaAssignmentJson += (partition.partition.toString -> replicaAssignment.replicas.asJava) + if 
(replicaAssignment.addingReplicas.nonEmpty) + addingReplicasAssignmentJson += (partition.partition.toString -> replicaAssignment.addingReplicas.asJava) + if (replicaAssignment.removingReplicas.nonEmpty) + removingReplicasAssignmentJson += (partition.partition.toString -> replicaAssignment.removingReplicas.asJava) + } + + val topicAssignment = mutable.Map( + "version" -> 3, + "partitions" -> replicaAssignmentJson.asJava, + "adding_replicas" -> addingReplicasAssignmentJson.asJava, + "removing_replicas" -> removingReplicasAssignmentJson.asJava + ) + topicId.foreach(id => topicAssignment += "topic_id" -> id.toString) + + Json.encodeAsBytes(topicAssignment.asJava) + } + def decode(topic: String, bytes: Array[Byte]): TopicIdReplicaAssignment = { + def getReplicas(replicasJsonOpt: Option[JsonObject], partition: String): Seq[Int] = { + replicasJsonOpt match { + case Some(replicasJson) => replicasJson.get(partition) match { + case Some(ar) => ar.to[Seq[Int]] + case None => Seq.empty[Int] + } + case None => Seq.empty[Int] + } + } + + Json.parseBytes(bytes).map { js => + val assignmentJson = js.asJsonObject + val topicId = assignmentJson.get("topic_id").map(_.to[String]).map(Uuid.fromString) + val addingReplicasJsonOpt = assignmentJson.get("adding_replicas").map(_.asJsonObject) + val removingReplicasJsonOpt = assignmentJson.get("removing_replicas").map(_.asJsonObject) + val partitionsJsonOpt = assignmentJson.get("partitions").map(_.asJsonObject) + val partitions = partitionsJsonOpt.map { partitionsJson => + partitionsJson.iterator.map { case (partition, replicas) => + new TopicPartition(topic, partition.toInt) -> ReplicaAssignment( + replicas.to[Seq[Int]], + getReplicas(addingReplicasJsonOpt, partition), + getReplicas(removingReplicasJsonOpt, partition) + ) + }.toMap + }.getOrElse(immutable.Map.empty[TopicPartition, ReplicaAssignment]) + + TopicIdReplicaAssignment(topic, topicId, partitions) + }.getOrElse(TopicIdReplicaAssignment(topic, None, Map.empty[TopicPartition, ReplicaAssignment])) + } +} + +object TopicPartitionsZNode { + def path(topic: String) = s"${TopicZNode.path(topic)}/partitions" +} + +object TopicPartitionZNode { + def path(partition: TopicPartition) = s"${TopicPartitionsZNode.path(partition.topic)}/${partition.partition}" +} + +object TopicPartitionStateZNode { + def path(partition: TopicPartition) = s"${TopicPartitionZNode.path(partition)}/state" + def encode(leaderIsrAndControllerEpoch: LeaderIsrAndControllerEpoch): Array[Byte] = { + val leaderAndIsr = leaderIsrAndControllerEpoch.leaderAndIsr + val controllerEpoch = leaderIsrAndControllerEpoch.controllerEpoch + Json.encodeAsBytes(Map("version" -> 1, "leader" -> leaderAndIsr.leader, "leader_epoch" -> leaderAndIsr.leaderEpoch, + "controller_epoch" -> controllerEpoch, "isr" -> leaderAndIsr.isr.asJava).asJava) + } + def decode(bytes: Array[Byte], stat: Stat): Option[LeaderIsrAndControllerEpoch] = { + Json.parseBytes(bytes).map { js => + val leaderIsrAndEpochInfo = js.asJsonObject + val leader = leaderIsrAndEpochInfo("leader").to[Int] + val epoch = leaderIsrAndEpochInfo("leader_epoch").to[Int] + val isr = leaderIsrAndEpochInfo("isr").to[List[Int]] + val controllerEpoch = leaderIsrAndEpochInfo("controller_epoch").to[Int] + val zkPathVersion = stat.getVersion + LeaderIsrAndControllerEpoch(LeaderAndIsr(leader, epoch, isr, zkPathVersion), controllerEpoch) + } + } +} + +object ConfigEntityTypeZNode { + def path(entityType: String) = s"${ConfigZNode.path}/$entityType" +} + +object ConfigEntityZNode { + def path(entityType: String, entityName: 
String) = s"${ConfigEntityTypeZNode.path(entityType)}/$entityName" + def encode(config: Properties): Array[Byte] = { + Json.encodeAsBytes(Map("version" -> 1, "config" -> config).asJava) + } + def decode(bytes: Array[Byte]): Properties = { + val props = new Properties() + if (bytes != null) { + Json.parseBytes(bytes).foreach { js => + val configOpt = js.asJsonObjectOption.flatMap(_.get("config").flatMap(_.asJsonObjectOption)) + configOpt.foreach(config => config.iterator.foreach { case (k, v) => props.setProperty(k, v.to[String]) }) + } + } + props + } +} + +object ConfigEntityChangeNotificationZNode { + def path = s"${ConfigZNode.path}/changes" +} + +object ConfigEntityChangeNotificationSequenceZNode { + val SequenceNumberPrefix = "config_change_" + def createPath = s"${ConfigEntityChangeNotificationZNode.path}/$SequenceNumberPrefix" + def encode(sanitizedEntityPath: String): Array[Byte] = Json.encodeAsBytes( + Map("version" -> 2, "entity_path" -> sanitizedEntityPath).asJava) +} + +object IsrChangeNotificationZNode { + def path = "/isr_change_notification" +} + +object IsrChangeNotificationSequenceZNode { + val SequenceNumberPrefix = "isr_change_" + def path(sequenceNumber: String = "") = s"${IsrChangeNotificationZNode.path}/$SequenceNumberPrefix$sequenceNumber" + def encode(partitions: collection.Set[TopicPartition]): Array[Byte] = { + val partitionsJson = partitions.map(partition => Map("topic" -> partition.topic, "partition" -> partition.partition).asJava) + Json.encodeAsBytes(Map("version" -> IsrChangeNotificationHandler.Version, "partitions" -> partitionsJson.asJava).asJava) + } + + def decode(bytes: Array[Byte]): Set[TopicPartition] = { + Json.parseBytes(bytes).map { js => + val partitionsJson = js.asJsonObject("partitions").asJsonArray + partitionsJson.iterator.map { partitionsJson => + val partitionJson = partitionsJson.asJsonObject + val topic = partitionJson("topic").to[String] + val partition = partitionJson("partition").to[Int] + new TopicPartition(topic, partition) + } + } + }.map(_.toSet).getOrElse(Set.empty) + def sequenceNumber(path: String) = path.substring(path.lastIndexOf(SequenceNumberPrefix) + SequenceNumberPrefix.length) +} + +object LogDirEventNotificationZNode { + def path = "/log_dir_event_notification" +} + +object LogDirEventNotificationSequenceZNode { + val SequenceNumberPrefix = "log_dir_event_" + val LogDirFailureEvent = 1 + def path(sequenceNumber: String) = s"${LogDirEventNotificationZNode.path}/$SequenceNumberPrefix$sequenceNumber" + def encode(brokerId: Int) = { + Json.encodeAsBytes(Map("version" -> 1, "broker" -> brokerId, "event" -> LogDirFailureEvent).asJava) + } + def decode(bytes: Array[Byte]): Option[Int] = Json.parseBytes(bytes).map { js => + js.asJsonObject("broker").to[Int] + } + def sequenceNumber(path: String) = path.substring(path.lastIndexOf(SequenceNumberPrefix) + SequenceNumberPrefix.length) +} + +object AdminZNode { + def path = "/admin" +} + +object DeleteTopicsZNode { + def path = s"${AdminZNode.path}/delete_topics" +} + +object DeleteTopicsTopicZNode { + def path(topic: String) = s"${DeleteTopicsZNode.path}/$topic" +} + +/** + * The znode for initiating a partition reassignment. + * @deprecated Since 2.4, use the PartitionReassignment Kafka API instead. + */ +object ReassignPartitionsZNode { + + /** + * The assignment of brokers for a `TopicPartition`. + * + * A replica assignment consists of a `topic`, `partition` and a list of `replicas`, which + * represent the broker ids that the `TopicPartition` is assigned to. 
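+ *
+ * For illustration only (topic name and broker ids are hypothetical), the znode content produced by
+ * `encode` for a single reassignment looks like:
+ * {{{
+ * {"version":1,"partitions":[{"topic":"topic-a","partition":0,"replicas":[1,2,3]}]}
+ * }}}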
+ */ + case class ReplicaAssignment(@BeanProperty @JsonProperty("topic") topic: String, + @BeanProperty @JsonProperty("partition") partition: Int, + @BeanProperty @JsonProperty("replicas") replicas: java.util.List[Int]) + + /** + * An assignment consists of a `version` and a list of `partitions`, which represent the + * assignment of topic-partitions to brokers. + * @deprecated Use the PartitionReassignment Kafka API instead + */ + @Deprecated + case class LegacyPartitionAssignment(@BeanProperty @JsonProperty("version") version: Int, + @BeanProperty @JsonProperty("partitions") partitions: java.util.List[ReplicaAssignment]) + + def path = s"${AdminZNode.path}/reassign_partitions" + + def encode(reassignmentMap: collection.Map[TopicPartition, Seq[Int]]): Array[Byte] = { + val reassignment = LegacyPartitionAssignment(1, + reassignmentMap.toSeq.map { case (tp, replicas) => + ReplicaAssignment(tp.topic, tp.partition, replicas.asJava) + }.asJava + ) + Json.encodeAsBytes(reassignment) + } + + def decode(bytes: Array[Byte]): Either[JsonProcessingException, collection.Map[TopicPartition, Seq[Int]]] = + Json.parseBytesAs[LegacyPartitionAssignment](bytes).map { partitionAssignment => + partitionAssignment.partitions.asScala.iterator.map { replicaAssignment => + new TopicPartition(replicaAssignment.topic, replicaAssignment.partition) -> replicaAssignment.replicas.asScala + }.toMap + } +} + +object PreferredReplicaElectionZNode { + def path = s"${AdminZNode.path}/preferred_replica_election" + def encode(partitions: Set[TopicPartition]): Array[Byte] = { + val jsonMap = Map("version" -> 1, + "partitions" -> partitions.map(tp => Map("topic" -> tp.topic, "partition" -> tp.partition).asJava).asJava) + Json.encodeAsBytes(jsonMap.asJava) + } + def decode(bytes: Array[Byte]): Set[TopicPartition] = Json.parseBytes(bytes).map { js => + val partitionsJson = js.asJsonObject("partitions").asJsonArray + partitionsJson.iterator.map { partitionsJson => + val partitionJson = partitionsJson.asJsonObject + val topic = partitionJson("topic").to[String] + val partition = partitionJson("partition").to[Int] + new TopicPartition(topic, partition) + } + }.map(_.toSet).getOrElse(Set.empty) +} + +//old consumer path znode +object ConsumerPathZNode { + def path = "/consumers" +} + +object ConsumerOffset { + def path(group: String, topic: String, partition: Integer) = s"${ConsumerPathZNode.path}/${group}/offsets/${topic}/${partition}" + def encode(offset: Long): Array[Byte] = offset.toString.getBytes(UTF_8) + def decode(bytes: Array[Byte]): Option[Long] = Option(bytes).map(new String(_, UTF_8).toLong) +} + +object ZkVersion { + val MatchAnyVersion = -1 // if used in a conditional set, matches any version (the value should match ZooKeeper codebase) + val UnknownVersion = -2 // Version returned from get if node does not exist (internal constant for Kafka codebase, unused value in ZK) +} + +object ZkStat { + val NoStat = new Stat() +} + +object StateChangeHandlers { + val ControllerHandler = "controller-state-change-handler" + def zkNodeChangeListenerHandler(seqNodeRoot: String) = s"change-notification-$seqNodeRoot" +} + +/** + * Acls for resources are stored in ZK under two root paths: + *
+ *  - [[org.apache.kafka.common.resource.PatternType#LITERAL Literal]] patterns are stored under '/kafka-acl'.
+ *    The format is JSON. See [[kafka.zk.ResourceZNode]] for details.
+ *  - All other patterns are stored under '/kafka-acl-extended/<pattern-type>'.
+ *    The format is JSON. See [[kafka.zk.ResourceZNode]] for details.
+ *
+ * Under each root node there will be one child node per resource type (Topic, Cluster, Group, etc).
+ * Under each resourceType there will be a unique child for each resource pattern and the data for that child will contain
+ * list of its acls as a json object. Following gives an example:
+ *
+ * // Literal patterns:
+ * /kafka-acl/Topic/topic-1 => {"version": 1, "acls": [ { "host":"host1", "permissionType": "Allow","operation": "Read","principal": "User:alice"}]}
+ * /kafka-acl/Cluster/kafka-cluster => {"version": 1, "acls": [ { "host":"host1", "permissionType": "Allow","operation": "Read","principal": "User:alice"}]}
+ *
+ * // Prefixed patterns:
+ * /kafka-acl-extended/PREFIXED/Group/group-1 => {"version": 1, "acls": [ { "host":"host1", "permissionType": "Allow","operation": "Read","principal": "User:alice"}]}
+ *
+ * Acl change events are also stored under two paths:
+ *  - [[org.apache.kafka.common.resource.PatternType#LITERAL Literal]] patterns are stored under '/kafka-acl-changes'.
+ *    The format is a UTF8 string in the form: <resource-type>:<resource-name>
+ *  - All other patterns are stored under '/kafka-acl-extended-changes'
+ *    The format is JSON, as defined by [[kafka.zk.ExtendedAclChangeEvent]]
                + */ +sealed trait ZkAclStore { + val patternType: PatternType + val aclPath: String + + def path(resourceType: ResourceType): String = s"$aclPath/${SecurityUtils.resourceTypeName(resourceType)}" + + def path(resourceType: ResourceType, resourceName: String): String = s"$aclPath/${SecurityUtils.resourceTypeName(resourceType)}/$resourceName" + + def changeStore: ZkAclChangeStore +} + +object ZkAclStore { + private val storesByType: Map[PatternType, ZkAclStore] = PatternType.values + .filter(_.isSpecific) + .map(patternType => (patternType, create(patternType))) + .toMap + + val stores: Iterable[ZkAclStore] = storesByType.values + + val securePaths: Iterable[String] = stores + .flatMap(store => Set(store.aclPath, store.changeStore.aclChangePath)) + + def apply(patternType: PatternType): ZkAclStore = { + storesByType.get(patternType) match { + case Some(store) => store + case None => throw new KafkaException(s"Invalid pattern type: $patternType") + } + } + + private def create(patternType: PatternType) = { + patternType match { + case PatternType.LITERAL => LiteralAclStore + case _ => new ExtendedAclStore(patternType) + } + } +} + +object LiteralAclStore extends ZkAclStore { + val patternType: PatternType = PatternType.LITERAL + val aclPath: String = "/kafka-acl" + + def changeStore: ZkAclChangeStore = LiteralAclChangeStore +} + +class ExtendedAclStore(val patternType: PatternType) extends ZkAclStore { + if (patternType == PatternType.LITERAL) + throw new IllegalArgumentException("Literal pattern types are not supported") + + val aclPath: String = s"${ExtendedAclZNode.path}/${patternType.name.toLowerCase}" + + def changeStore: ZkAclChangeStore = ExtendedAclChangeStore +} + +object ExtendedAclZNode { + def path = "/kafka-acl-extended" +} + +trait AclChangeNotificationHandler { + def processNotification(resource: ResourcePattern): Unit +} + +trait AclChangeSubscription extends AutoCloseable { + def close(): Unit +} + +case class AclChangeNode(path: String, bytes: Array[Byte]) + +sealed trait ZkAclChangeStore { + val aclChangePath: String + def createPath: String = s"$aclChangePath/${ZkAclChangeStore.SequenceNumberPrefix}" + + def decode(bytes: Array[Byte]): ResourcePattern + + protected def encode(resource: ResourcePattern): Array[Byte] + + def createChangeNode(resource: ResourcePattern): AclChangeNode = AclChangeNode(createPath, encode(resource)) + + def createListener(handler: AclChangeNotificationHandler, zkClient: KafkaZkClient): AclChangeSubscription = { + val rawHandler: NotificationHandler = (bytes: Array[Byte]) => handler.processNotification(decode(bytes)) + + val aclChangeListener = new ZkNodeChangeNotificationListener( + zkClient, aclChangePath, ZkAclChangeStore.SequenceNumberPrefix, rawHandler) + + aclChangeListener.init() + + () => aclChangeListener.close() + } +} + +object ZkAclChangeStore { + val stores: Iterable[ZkAclChangeStore] = List(LiteralAclChangeStore, ExtendedAclChangeStore) + + def SequenceNumberPrefix = "acl_changes_" +} + +case object LiteralAclChangeStore extends ZkAclChangeStore { + val name = "LiteralAclChangeStore" + val aclChangePath: String = "/kafka-acl-changes" + + def encode(resource: ResourcePattern): Array[Byte] = { + if (resource.patternType != PatternType.LITERAL) + throw new IllegalArgumentException("Only literal resource patterns can be encoded") + + val legacyName = resource.resourceType.toString + AclEntry.ResourceSeparator + resource.name + legacyName.getBytes(UTF_8) + } + + def decode(bytes: Array[Byte]): ResourcePattern = { + val 
string = new String(bytes, UTF_8) + string.split(AclEntry.ResourceSeparator, 2) match { + case Array(resourceType, resourceName, _*) => new ResourcePattern(ResourceType.fromString(resourceType), resourceName, PatternType.LITERAL) + case _ => throw new IllegalArgumentException("expected a string in format ResourceType:ResourceName but got " + string) + } + } +} + +case object ExtendedAclChangeStore extends ZkAclChangeStore { + val name = "ExtendedAclChangeStore" + val aclChangePath: String = "/kafka-acl-extended-changes" + + def encode(resource: ResourcePattern): Array[Byte] = { + if (resource.patternType == PatternType.LITERAL) + throw new IllegalArgumentException("Literal pattern types are not supported") + + Json.encodeAsBytes(ExtendedAclChangeEvent( + ExtendedAclChangeEvent.currentVersion, + resource.resourceType.name, + resource.name, + resource.patternType.name)) + } + + def decode(bytes: Array[Byte]): ResourcePattern = { + val changeEvent = Json.parseBytesAs[ExtendedAclChangeEvent](bytes) match { + case Right(event) => event + case Left(e) => throw new IllegalArgumentException("Failed to parse ACL change event", e) + } + + changeEvent.toResource match { + case Success(r) => r + case Failure(e) => throw new IllegalArgumentException("Failed to convert ACL change event to resource", e) + } + } +} + +object ResourceZNode { + def path(resource: ResourcePattern): String = ZkAclStore(resource.patternType).path(resource.resourceType, resource.name) + + def encode(acls: Set[AclEntry]): Array[Byte] = Json.encodeAsBytes(AclEntry.toJsonCompatibleMap(acls).asJava) + def decode(bytes: Array[Byte], stat: Stat): VersionedAcls = VersionedAcls(AclEntry.fromBytes(bytes), stat.getVersion) +} + +object ExtendedAclChangeEvent { + val currentVersion: Int = 1 +} + +case class ExtendedAclChangeEvent(@BeanProperty @JsonProperty("version") version: Int, + @BeanProperty @JsonProperty("resourceType") resourceType: String, + @BeanProperty @JsonProperty("name") name: String, + @BeanProperty @JsonProperty("patternType") patternType: String) { + if (version > ExtendedAclChangeEvent.currentVersion) + throw new UnsupportedVersionException(s"Acl change event received for unsupported version: $version") + + def toResource: Try[ResourcePattern] = { + for { + resType <- Try(ResourceType.fromString(resourceType)) + patType <- Try(PatternType.fromString(patternType)) + resource = new ResourcePattern(resType, name, patType) + } yield resource + } +} + +object ClusterZNode { + def path = "/cluster" +} + +object ClusterIdZNode { + def path = s"${ClusterZNode.path}/id" + + def toJson(id: String): Array[Byte] = { + Json.encodeAsBytes(Map("version" -> "1", "id" -> id).asJava) + } + + def fromJson(clusterIdJson: Array[Byte]): String = { + Json.parseBytes(clusterIdJson).map(_.asJsonObject("id").to[String]).getOrElse { + throw new KafkaException(s"Failed to parse the cluster id json $clusterIdJson") + } + } +} + +object BrokerSequenceIdZNode { + def path = s"${BrokersZNode.path}/seqid" +} + +object ProducerIdBlockZNode { + val CurrentVersion: Long = 1L + + def path = "/latest_producer_id_block" + + def generateProducerIdBlockJson(producerIdBlock: ProducerIdsBlock): Array[Byte] = { + Json.encodeAsBytes(Map("version" -> CurrentVersion, + "broker" -> producerIdBlock.brokerId, + "block_start" -> producerIdBlock.producerIdStart.toString, + "block_end" -> producerIdBlock.producerIdEnd.toString).asJava + ) + } + + def parseProducerIdBlockData(jsonData: Array[Byte]): ProducerIdsBlock = { + val jsonDataAsString = 
jsonData.map(_.toChar).mkString + try { + Json.parseBytes(jsonData).map(_.asJsonObject).flatMap { js => + val brokerId = js("broker").to[Int] + val blockStart = js("block_start").to[String].toLong + val blockEnd = js("block_end").to[String].toLong + Some(new ProducerIdsBlock(brokerId, blockStart, Math.toIntExact(blockEnd - blockStart + 1))) + }.getOrElse(throw new KafkaException(s"Failed to parse the producerId block json $jsonDataAsString")) + } catch { + case e: java.lang.NumberFormatException => + // this should never happen: the written data has exceeded long type limit + throw new KafkaException(s"Read jason data $jsonDataAsString contains producerIds that have exceeded long type limit", e) + } + } +} + +object DelegationTokenAuthZNode { + def path = "/delegation_token" +} + +object DelegationTokenChangeNotificationZNode { + def path = s"${DelegationTokenAuthZNode.path}/token_changes" +} + +object DelegationTokenChangeNotificationSequenceZNode { + val SequenceNumberPrefix = "token_change_" + def createPath = s"${DelegationTokenChangeNotificationZNode.path}/$SequenceNumberPrefix" + def deletePath(sequenceNode: String) = s"${DelegationTokenChangeNotificationZNode.path}/${sequenceNode}" + def encode(tokenId : String): Array[Byte] = tokenId.getBytes(UTF_8) + def decode(bytes: Array[Byte]): String = new String(bytes, UTF_8) +} + +object DelegationTokensZNode { + def path = s"${DelegationTokenAuthZNode.path}/tokens" +} + +object DelegationTokenInfoZNode { + def path(tokenId: String) = s"${DelegationTokensZNode.path}/$tokenId" + def encode(token: DelegationToken): Array[Byte] = Json.encodeAsBytes(DelegationTokenManager.toJsonCompatibleMap(token).asJava) + def decode(bytes: Array[Byte]): Option[TokenInformation] = DelegationTokenManager.fromBytes(bytes) +} + +/** + * Represents the status of the FeatureZNode. + * + * Enabled -> This status means the feature versioning system (KIP-584) is enabled, and, the + * finalized features stored in the FeatureZNode are active. This status is written by + * the controller to the FeatureZNode only when the broker IBP config is greater than + * or equal to KAFKA_2_7_IV0. + * + * Disabled -> This status means the feature versioning system (KIP-584) is disabled, and, the + * the finalized features stored in the FeatureZNode is not relevant. This status is + * written by the controller to the FeatureZNode only when the broker IBP config + * is less than KAFKA_2_7_IV0. + */ +sealed trait FeatureZNodeStatus { + def id: Int +} + +object FeatureZNodeStatus { + case object Disabled extends FeatureZNodeStatus { + val id: Int = 0 + } + + case object Enabled extends FeatureZNodeStatus { + val id: Int = 1 + } + + def withNameOpt(id: Int): Option[FeatureZNodeStatus] = { + id match { + case Disabled.id => Some(Disabled) + case Enabled.id => Some(Enabled) + case _ => Option.empty + } + } +} + +/** + * Represents the contents of the ZK node containing finalized feature information. + * + * @param status the status of the ZK node + * @param features the cluster-wide finalized features + */ +case class FeatureZNode(status: FeatureZNodeStatus, features: Features[FinalizedVersionRange]) { +} + +object FeatureZNode { + private val VersionKey = "version" + private val StatusKey = "status" + private val FeaturesKey = "features" + + // V1 contains 'version', 'status' and 'features' keys. 
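A round-trip sketch for the producer-id block payload defined above. The import path for ProducerIdsBlock is an assumption here, and the block values are illustrative:

import org.apache.kafka.server.common.ProducerIdsBlock   // assumed location of the block class used above

val block  = new ProducerIdsBlock(0, 2000L, 1000)        // brokerId, first id, block size
val json   = ProducerIdBlockZNode.generateProducerIdBlockJson(block)
val parsed = ProducerIdBlockZNode.parseProducerIdBlockData(json)
// parsed covers producer ids 2000..2999 for broker 0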
+ val V1 = 1 + val CurrentVersion = V1 + + def path = "/feature" + + def asJavaMap(scalaMap: Map[String, Map[String, Short]]): util.Map[String, util.Map[String, java.lang.Short]] = { + scalaMap + .map { + case(featureName, versionInfo) => featureName -> versionInfo.map { + case(label, version) => label -> java.lang.Short.valueOf(version) + }.asJava + }.asJava + } + + /** + * Encodes a FeatureZNode to JSON. + * + * @param featureZNode FeatureZNode to be encoded + * + * @return JSON representation of the FeatureZNode, as an Array[Byte] + */ + def encode(featureZNode: FeatureZNode): Array[Byte] = { + val jsonMap = collection.mutable.Map( + VersionKey -> CurrentVersion, + StatusKey -> featureZNode.status.id, + FeaturesKey -> featureZNode.features.toMap) + Json.encodeAsBytes(jsonMap.asJava) + } + + /** + * Decodes the contents of the feature ZK node from Array[Byte] to a FeatureZNode. + * + * @param jsonBytes the contents of the feature ZK node + * + * @return the FeatureZNode created from jsonBytes + * + * @throws IllegalArgumentException if the Array[Byte] can not be decoded. + */ + def decode(jsonBytes: Array[Byte]): FeatureZNode = { + Json.tryParseBytes(jsonBytes) match { + case Right(js) => + val featureInfo = js.asJsonObject + val version = featureInfo(VersionKey).to[Int] + if (version < V1) { + throw new IllegalArgumentException(s"Unsupported version: $version of feature information: " + + s"${new String(jsonBytes, UTF_8)}") + } + + val featuresMap = featureInfo + .get(FeaturesKey) + .flatMap(_.to[Option[Map[String, Map[String, Int]]]]) + + if (featuresMap.isEmpty) { + throw new IllegalArgumentException("Features map can not be absent in: " + + s"${new String(jsonBytes, UTF_8)}") + } + val features = asJavaMap( + featuresMap + .map(theMap => theMap.map { + case (featureName, versionInfo) => featureName -> versionInfo.map { + case (label, version) => label -> version.asInstanceOf[Short] + } + }).getOrElse(Map[String, Map[String, Short]]())) + + val statusInt = featureInfo + .get(StatusKey) + .flatMap(_.to[Option[Int]]) + if (statusInt.isEmpty) { + throw new IllegalArgumentException("Status can not be absent in feature information: " + + s"${new String(jsonBytes, UTF_8)}") + } + val status = FeatureZNodeStatus.withNameOpt(statusInt.get) + if (status.isEmpty) { + throw new IllegalArgumentException( + s"Malformed status: $statusInt found in feature information: ${new String(jsonBytes, UTF_8)}") + } + + var finalizedFeatures: Features[FinalizedVersionRange] = null + try { + finalizedFeatures = fromFinalizedFeaturesMap(features) + } catch { + case e: Exception => throw new IllegalArgumentException( + "Unable to convert to finalized features from map: " + features, e) + } + FeatureZNode(status.get, finalizedFeatures) + case Left(e) => + throw new IllegalArgumentException(s"Failed to parse feature information: " + + s"${new String(jsonBytes, UTF_8)}", e) + } + } +} + +object ZkData { + + // Important: it is necessary to add any new top level Zookeeper path to the Seq + val SecureRootPaths = Seq(AdminZNode.path, + BrokersZNode.path, + ClusterZNode.path, + ConfigZNode.path, + ControllerZNode.path, + ControllerEpochZNode.path, + IsrChangeNotificationZNode.path, + ProducerIdBlockZNode.path, + LogDirEventNotificationZNode.path, + DelegationTokenAuthZNode.path, + ExtendedAclZNode.path) ++ ZkAclStore.securePaths + + // These are persistent ZK paths that should exist on kafka broker startup. 
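A small encode/decode sketch for the /feature payload above, assuming the Features and FinalizedVersionRange helpers from org.apache.kafka.common.feature that the surrounding code relies on; the feature name is hypothetical:

import org.apache.kafka.common.feature.{Features, FinalizedVersionRange}
import scala.jdk.CollectionConverters._

val finalized = Features.finalizedFeatures(
  Map("example.feature" -> new FinalizedVersionRange(1.toShort, 3.toShort)).asJava)
val node    = FeatureZNode(FeatureZNodeStatus.Enabled, finalized)
val decoded = FeatureZNode.decode(FeatureZNode.encode(node))
// decoded.status == FeatureZNodeStatus.Enabled and decoded.features == finalized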
+ val PersistentZkPaths = Seq( + ConsumerPathZNode.path, // old consumer path + BrokerIdsZNode.path, + TopicsZNode.path, + ConfigEntityChangeNotificationZNode.path, + DeleteTopicsZNode.path, + BrokerSequenceIdZNode.path, + IsrChangeNotificationZNode.path, + ProducerIdBlockZNode.path, + LogDirEventNotificationZNode.path + ) ++ ConfigType.all.map(ConfigEntityTypeZNode.path) + + val SensitiveRootPaths = Seq( + ConfigEntityTypeZNode.path(ConfigType.User), + ConfigEntityTypeZNode.path(ConfigType.Broker), + DelegationTokensZNode.path + ) + + def sensitivePath(path: String): Boolean = { + path != null && SensitiveRootPaths.exists(path.startsWith) + } + + def defaultAcls(isSecure: Boolean, path: String): Seq[ACL] = { + //Old Consumer path is kept open as different consumers will write under this node. + if (!ConsumerPathZNode.path.equals(path) && isSecure) { + val acls = new ArrayBuffer[ACL] + acls ++= ZooDefs.Ids.CREATOR_ALL_ACL.asScala + if (!sensitivePath(path)) + acls ++= ZooDefs.Ids.READ_ACL_UNSAFE.asScala + acls + } else ZooDefs.Ids.OPEN_ACL_UNSAFE.asScala + } +} diff --git a/core/src/main/scala/kafka/zk/ZkSecurityMigratorUtils.scala b/core/src/main/scala/kafka/zk/ZkSecurityMigratorUtils.scala new file mode 100644 index 0000000..31a7ba2 --- /dev/null +++ b/core/src/main/scala/kafka/zk/ZkSecurityMigratorUtils.scala @@ -0,0 +1,30 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package kafka.zk + +import org.apache.zookeeper.ZooKeeper + +/** + * This class should only be used in ZkSecurityMigrator tool. + * This class will be removed after we migrate ZkSecurityMigrator away from ZK's asynchronous API. + * @param kafkaZkClient + */ +class ZkSecurityMigratorUtils(val kafkaZkClient: KafkaZkClient) { + + def currentZooKeeper: ZooKeeper = kafkaZkClient.currentZooKeeper + +} diff --git a/core/src/main/scala/kafka/zookeeper/ZooKeeperClient.scala b/core/src/main/scala/kafka/zookeeper/ZooKeeperClient.scala new file mode 100755 index 0000000..bc634a8 --- /dev/null +++ b/core/src/main/scala/kafka/zookeeper/ZooKeeperClient.scala @@ -0,0 +1,598 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
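Looking back at ZkData.defaultAcls above, a short sketch of the ZooKeeper ACLs it hands out; the paths come from the znode objects defined in this file:

val tokenAcls  = ZkData.defaultAcls(isSecure = true,  DelegationTokensZNode.path)   // sensitive: creator-only
val brokerAcls = ZkData.defaultAcls(isSecure = true,  BrokerIdsZNode.path)          // creator-all plus world-readable
val openAcls   = ZkData.defaultAcls(isSecure = false, BrokerIdsZNode.path)          // unsecured cluster: OPEN_ACL_UNSAFE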
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.zookeeper + +import java.util.Locale +import java.util.concurrent.locks.{ReentrantLock, ReentrantReadWriteLock} +import java.util.concurrent._ +import java.util.{List => JList} + +import com.yammer.metrics.core.MetricName +import kafka.metrics.KafkaMetricsGroup +import kafka.utils.CoreUtils.{inLock, inReadLock, inWriteLock} +import kafka.utils.{KafkaScheduler, Logging} +import kafka.zookeeper.ZooKeeperClient._ +import org.apache.kafka.common.utils.Time +import org.apache.zookeeper.AsyncCallback.{Children2Callback, DataCallback, StatCallback} +import org.apache.zookeeper.KeeperException.Code +import org.apache.zookeeper.Watcher.Event.{EventType, KeeperState} +import org.apache.zookeeper.ZooKeeper.States +import org.apache.zookeeper.data.{ACL, Stat} +import org.apache.zookeeper._ +import org.apache.zookeeper.client.ZKClientConfig + +import scala.jdk.CollectionConverters._ +import scala.collection.Seq +import scala.collection.mutable.Set + +object ZooKeeperClient { + val RetryBackoffMs = 1000 +} + +/** + * A ZooKeeper client that encourages pipelined requests. + * + * @param connectString comma separated host:port pairs, each corresponding to a zk server + * @param sessionTimeoutMs session timeout in milliseconds + * @param connectionTimeoutMs connection timeout in milliseconds + * @param maxInFlightRequests maximum number of unacknowledged requests the client will send before blocking. + * @param name name of the client instance + * @param zkClientConfig ZooKeeper client configuration, for TLS configs if desired + */ +class ZooKeeperClient(connectString: String, + sessionTimeoutMs: Int, + connectionTimeoutMs: Int, + maxInFlightRequests: Int, + time: Time, + metricGroup: String, + metricType: String, + private[zookeeper] val clientConfig: ZKClientConfig, + name: String) extends Logging with KafkaMetricsGroup { + + this.logIdent = s"[ZooKeeperClient $name] " + private val initializationLock = new ReentrantReadWriteLock() + private val isConnectedOrExpiredLock = new ReentrantLock() + private val isConnectedOrExpiredCondition = isConnectedOrExpiredLock.newCondition() + private val zNodeChangeHandlers = new ConcurrentHashMap[String, ZNodeChangeHandler]().asScala + private val zNodeChildChangeHandlers = new ConcurrentHashMap[String, ZNodeChildChangeHandler]().asScala + private val inFlightRequests = new Semaphore(maxInFlightRequests) + private val stateChangeHandlers = new ConcurrentHashMap[String, StateChangeHandler]().asScala + private[zookeeper] val reinitializeScheduler = new KafkaScheduler(threads = 1, s"zk-client-${threadPrefix}reinit-") + private var isFirstConnectionEstablished = false + + private val metricNames = Set[String]() + + // The state map has to be created before creating ZooKeeper since it's needed in the ZooKeeper callback. 
+ private val stateToMeterMap = { + import KeeperState._ + val stateToEventTypeMap = Map( + Disconnected -> "Disconnects", + SyncConnected -> "SyncConnects", + AuthFailed -> "AuthFailures", + ConnectedReadOnly -> "ReadOnlyConnects", + SaslAuthenticated -> "SaslAuthentications", + Expired -> "Expires" + ) + stateToEventTypeMap.map { case (state, eventType) => + val name = s"ZooKeeper${eventType}PerSec" + metricNames += name + state -> newMeter(name, eventType.toLowerCase(Locale.ROOT), TimeUnit.SECONDS) + } + } + + info(s"Initializing a new session to $connectString.") + // Fail-fast if there's an error during construction (so don't call initialize, which retries forever) + @volatile private var zooKeeper = new ZooKeeper(connectString, sessionTimeoutMs, ZooKeeperClientWatcher, + clientConfig) + + newGauge("SessionState", () => connectionState.toString) + + metricNames += "SessionState" + + reinitializeScheduler.startup() + try waitUntilConnected(connectionTimeoutMs, TimeUnit.MILLISECONDS) + catch { + case e: Throwable => + close() + throw e + } + + override def metricName(name: String, metricTags: scala.collection.Map[String, String]): MetricName = { + explicitMetricName(metricGroup, metricType, name, metricTags) + } + + /** + * Return the state of the ZooKeeper connection. + */ + def connectionState: States = zooKeeper.getState + + /** + * Send a request and wait for its response. See handle(Seq[AsyncRequest]) for details. + * + * @param request a single request to send and wait on. + * @return an instance of the response with the specific type (e.g. CreateRequest -> CreateResponse). + */ + def handleRequest[Req <: AsyncRequest](request: Req): Req#Response = { + handleRequests(Seq(request)).head + } + + /** + * Send a pipelined sequence of requests and wait for all of their responses. + * + * The watch flag on each outgoing request will be set if we've already registered a handler for the + * path associated with the request. + * + * @param requests a sequence of requests to send and wait on. + * @return the responses for the requests. If all requests have the same type, the responses will have the respective + * response type (e.g. Seq[CreateRequest] -> Seq[CreateResponse]). Otherwise, the most specific common supertype + * will be used (e.g. Seq[AsyncRequest] -> Seq[AsyncResponse]). 
+ */ + def handleRequests[Req <: AsyncRequest](requests: Seq[Req]): Seq[Req#Response] = { + if (requests.isEmpty) + Seq.empty + else { + val countDownLatch = new CountDownLatch(requests.size) + val responseQueue = new ArrayBlockingQueue[Req#Response](requests.size) + + requests.foreach { request => + inFlightRequests.acquire() + try { + inReadLock(initializationLock) { + send(request) { response => + responseQueue.add(response) + inFlightRequests.release() + countDownLatch.countDown() + } + } + } catch { + case e: Throwable => + inFlightRequests.release() + throw e + } + } + countDownLatch.await() + responseQueue.asScala.toBuffer + } + } + + // Visibility to override for testing + private[zookeeper] def send[Req <: AsyncRequest](request: Req)(processResponse: Req#Response => Unit): Unit = { + // Safe to cast as we always create a response of the right type + def callback(response: AsyncResponse): Unit = processResponse(response.asInstanceOf[Req#Response]) + + def responseMetadata(sendTimeMs: Long) = new ResponseMetadata(sendTimeMs, receivedTimeMs = time.hiResClockMs()) + + val sendTimeMs = time.hiResClockMs() + + // Cast to AsyncRequest to workaround a scalac bug that results in an false exhaustiveness warning + // with -Xlint:strict-unsealed-patmat + (request: AsyncRequest) match { + case ExistsRequest(path, ctx) => + zooKeeper.exists(path, shouldWatch(request), new StatCallback { + def processResult(rc: Int, path: String, ctx: Any, stat: Stat): Unit = + callback(ExistsResponse(Code.get(rc), path, Option(ctx), stat, responseMetadata(sendTimeMs))) + }, ctx.orNull) + case GetDataRequest(path, ctx) => + zooKeeper.getData(path, shouldWatch(request), new DataCallback { + def processResult(rc: Int, path: String, ctx: Any, data: Array[Byte], stat: Stat): Unit = + callback(GetDataResponse(Code.get(rc), path, Option(ctx), data, stat, responseMetadata(sendTimeMs))) + }, ctx.orNull) + case GetChildrenRequest(path, _, ctx) => + zooKeeper.getChildren(path, shouldWatch(request), new Children2Callback { + def processResult(rc: Int, path: String, ctx: Any, children: JList[String], stat: Stat): Unit = + callback(GetChildrenResponse(Code.get(rc), path, Option(ctx), Option(children).map(_.asScala).getOrElse(Seq.empty), + stat, responseMetadata(sendTimeMs))) + }, ctx.orNull) + case CreateRequest(path, data, acl, createMode, ctx) => + zooKeeper.create(path, data, acl.asJava, createMode, + (rc, path, ctx, name) => + callback(CreateResponse(Code.get(rc), path, Option(ctx), name, responseMetadata(sendTimeMs))), + ctx.orNull) + case SetDataRequest(path, data, version, ctx) => + zooKeeper.setData(path, data, version, + (rc, path, ctx, stat) => + callback(SetDataResponse(Code.get(rc), path, Option(ctx), stat, responseMetadata(sendTimeMs))), + ctx.orNull) + case DeleteRequest(path, version, ctx) => + zooKeeper.delete(path, version, + (rc, path, ctx) => callback(DeleteResponse(Code.get(rc), path, Option(ctx), responseMetadata(sendTimeMs))), + ctx.orNull) + case GetAclRequest(path, ctx) => + zooKeeper.getACL(path, null, + (rc, path, ctx, acl, stat) => + callback(GetAclResponse(Code.get(rc), path, Option(ctx), Option(acl).map(_.asScala).getOrElse(Seq.empty), + stat, responseMetadata(sendTimeMs))), + ctx.orNull) + case SetAclRequest(path, acl, version, ctx) => + zooKeeper.setACL(path, acl.asJava, version, + (rc, path, ctx, stat) => + callback(SetAclResponse(Code.get(rc), path, Option(ctx), stat, responseMetadata(sendTimeMs))), + ctx.orNull) + case MultiRequest(zkOps, ctx) => + def toZkOpResult(opResults: 
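A usage sketch for the pipelined request API described above; the connection settings and metric names are illustrative:

import kafka.zookeeper._
import org.apache.kafka.common.utils.Time
import org.apache.zookeeper.client.ZKClientConfig

val client = new ZooKeeperClient("localhost:2181", 30000, 9000, 100,
  Time.SYSTEM, "kafka.server", "SessionExpireListener", new ZKClientConfig, "sketch")
try {
  // Both reads are in flight at the same time, bounded by maxInFlightRequests.
  val responses = client.handleRequests(Seq(GetDataRequest("/cluster/id"), GetDataRequest("/controller")))
  responses.foreach(r => println(s"${r.path} -> ${r.resultCode}"))
} finally {
  client.close()
}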
JList[OpResult]): Seq[ZkOpResult] = + Option(opResults).map(results => zkOps.zip(results.asScala).map { case (zkOp, result) => + ZkOpResult(zkOp, result) + }).orNull + zooKeeper.multi(zkOps.map(_.toZookeeperOp).asJava, + (rc, path, ctx, opResults) => + callback(MultiResponse(Code.get(rc), path, Option(ctx), toZkOpResult(opResults), responseMetadata(sendTimeMs))), + ctx.orNull) + } + } + + /** + * Wait indefinitely until the underlying zookeeper client to reaches the CONNECTED state. + * @throws ZooKeeperClientAuthFailedException if the authentication failed either before or while waiting for connection. + * @throws ZooKeeperClientExpiredException if the session expired either before or while waiting for connection. + */ + def waitUntilConnected(): Unit = inLock(isConnectedOrExpiredLock) { + waitUntilConnected(Long.MaxValue, TimeUnit.MILLISECONDS) + } + + private def waitUntilConnected(timeout: Long, timeUnit: TimeUnit): Unit = { + info("Waiting until connected.") + var nanos = timeUnit.toNanos(timeout) + inLock(isConnectedOrExpiredLock) { + var state = connectionState + while (!state.isConnected && state.isAlive) { + if (nanos <= 0) { + throw new ZooKeeperClientTimeoutException(s"Timed out waiting for connection while in state: $state") + } + nanos = isConnectedOrExpiredCondition.awaitNanos(nanos) + state = connectionState + } + if (state == States.AUTH_FAILED) { + throw new ZooKeeperClientAuthFailedException("Auth failed either before or while waiting for connection") + } else if (state == States.CLOSED) { + throw new ZooKeeperClientExpiredException("Session expired either before or while waiting for connection") + } + isFirstConnectionEstablished = true + } + info("Connected.") + } + + // If this method is changed, the documentation for registerZNodeChangeHandler and/or registerZNodeChildChangeHandler + // may need to be updated. + private def shouldWatch(request: AsyncRequest): Boolean = request match { + case GetChildrenRequest(_, registerWatch, _) => registerWatch && zNodeChildChangeHandlers.contains(request.path) + case _: ExistsRequest | _: GetDataRequest => zNodeChangeHandlers.contains(request.path) + case _ => throw new IllegalArgumentException(s"Request $request is not watchable") + } + + /** + * Register the handler to ZooKeeperClient. This is just a local operation. This does not actually register a watcher. + * + * The watcher is only registered once the user calls handle(AsyncRequest) or handle(Seq[AsyncRequest]) + * with either a GetDataRequest or ExistsRequest. + * + * NOTE: zookeeper only allows registration to a nonexistent znode with ExistsRequest. + * + * @param zNodeChangeHandler the handler to register + */ + def registerZNodeChangeHandler(zNodeChangeHandler: ZNodeChangeHandler): Unit = { + zNodeChangeHandlers.put(zNodeChangeHandler.path, zNodeChangeHandler) + } + + /** + * Unregister the handler from ZooKeeperClient. This is just a local operation. + * @param path the path of the handler to unregister + */ + def unregisterZNodeChangeHandler(path: String): Unit = { + zNodeChangeHandlers.remove(path) + } + + /** + * Register the handler to ZooKeeperClient. This is just a local operation. This does not actually register a watcher. + * + * The watcher is only registered once the user calls handle(AsyncRequest) or handle(Seq[AsyncRequest]) with a GetChildrenRequest. 
+ * + * @param zNodeChildChangeHandler the handler to register + */ + def registerZNodeChildChangeHandler(zNodeChildChangeHandler: ZNodeChildChangeHandler): Unit = { + zNodeChildChangeHandlers.put(zNodeChildChangeHandler.path, zNodeChildChangeHandler) + } + + /** + * Unregister the handler from ZooKeeperClient. This is just a local operation. + * @param path the path of the handler to unregister + */ + def unregisterZNodeChildChangeHandler(path: String): Unit = { + zNodeChildChangeHandlers.remove(path) + } + + /** + * @param stateChangeHandler + */ + def registerStateChangeHandler(stateChangeHandler: StateChangeHandler): Unit = inReadLock(initializationLock) { + if (stateChangeHandler != null) + stateChangeHandlers.put(stateChangeHandler.name, stateChangeHandler) + } + + /** + * + * @param name + */ + def unregisterStateChangeHandler(name: String): Unit = inReadLock(initializationLock) { + stateChangeHandlers.remove(name) + } + + def close(): Unit = { + info("Closing.") + + // Shutdown scheduler outside of lock to avoid deadlock if scheduler + // is waiting for lock to process session expiry. Close expiry thread + // first to ensure that new clients are not created during close(). + reinitializeScheduler.shutdown() + + inWriteLock(initializationLock) { + zNodeChangeHandlers.clear() + zNodeChildChangeHandlers.clear() + stateChangeHandlers.clear() + zooKeeper.close() + metricNames.foreach(removeMetric(_)) + } + info("Closed.") + } + + def sessionId: Long = inReadLock(initializationLock) { + zooKeeper.getSessionId + } + + // Only for testing + private[kafka] def currentZooKeeper: ZooKeeper = inReadLock(initializationLock) { + zooKeeper + } + + private def reinitialize(): Unit = { + // Initialization callbacks are invoked outside of the lock to avoid deadlock potential since their completion + // may require additional Zookeeper requests, which will block to acquire the initialization lock + stateChangeHandlers.values.foreach(callBeforeInitializingSession _) + + inWriteLock(initializationLock) { + if (!connectionState.isAlive) { + zooKeeper.close() + info(s"Initializing a new session to $connectString.") + // retry forever until ZooKeeper can be instantiated + var connected = false + while (!connected) { + try { + zooKeeper = new ZooKeeper(connectString, sessionTimeoutMs, ZooKeeperClientWatcher, clientConfig) + connected = true + } catch { + case e: Exception => + info("Error when recreating ZooKeeper, retrying after a short sleep", e) + Thread.sleep(RetryBackoffMs) + } + } + } + } + + stateChangeHandlers.values.foreach(callAfterInitializingSession _) + } + + /** + * Close the zookeeper client to force session reinitialization. This is visible for testing only. 
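To make the registration contract above concrete: registering a handler is purely local bookkeeping, and the ZooKeeper watch is only armed by a subsequent ExistsRequest or GetDataRequest. A sketch, reusing the connected client from the earlier snippet:

val controllerHandler = new ZNodeChangeHandler {
  override val path: String = "/controller"
  override def handleCreation(): Unit = println(s"$path created")
  override def handleDeletion(): Unit = println(s"$path deleted")
  override def handleDataChange(): Unit = println(s"$path changed")
}
client.registerZNodeChangeHandler(controllerHandler)          // local registration only
client.handleRequest(ExistsRequest(controllerHandler.path))   // this request actually sets the watch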
+ */ + private[zookeeper] def forceReinitialize(): Unit = { + zooKeeper.close() + reinitialize() + } + + private def callBeforeInitializingSession(handler: StateChangeHandler): Unit = { + try { + handler.beforeInitializingSession() + } catch { + case t: Throwable => + error(s"Uncaught error in handler ${handler.name}", t) + } + } + + private def callAfterInitializingSession(handler: StateChangeHandler): Unit = { + try { + handler.afterInitializingSession() + } catch { + case t: Throwable => + error(s"Uncaught error in handler ${handler.name}", t) + } + } + + // Visibility for testing + private[zookeeper] def scheduleReinitialize(name: String, message: String, delayMs: Long): Unit = { + reinitializeScheduler.schedule(name, () => { + info(message) + reinitialize() + }, delayMs, period = -1L, unit = TimeUnit.MILLISECONDS) + } + + private def threadPrefix: String = name.replaceAll("\\s", "") + "-" + + // package level visibility for testing only + private[zookeeper] object ZooKeeperClientWatcher extends Watcher { + override def process(event: WatchedEvent): Unit = { + debug(s"Received event: $event") + Option(event.getPath) match { + case None => + val state = event.getState + stateToMeterMap.get(state).foreach(_.mark()) + inLock(isConnectedOrExpiredLock) { + isConnectedOrExpiredCondition.signalAll() + } + if (state == KeeperState.AuthFailed) { + error(s"Auth failed, initialized=$isFirstConnectionEstablished connectionState=$connectionState") + stateChangeHandlers.values.foreach(_.onAuthFailure()) + + // If this is during initial startup, we fail fast. Otherwise, schedule retry. + val initialized = inLock(isConnectedOrExpiredLock) { + isFirstConnectionEstablished + } + if (initialized && !connectionState.isAlive) + scheduleReinitialize("auth-failed", "Reinitializing due to auth failure.", RetryBackoffMs) + } else if (state == KeeperState.Expired) { + scheduleReinitialize("session-expired", "Session expired.", delayMs = 0L) + } + case Some(path) => + (event.getType: @unchecked) match { + case EventType.NodeChildrenChanged => zNodeChildChangeHandlers.get(path).foreach(_.handleChildChange()) + case EventType.NodeCreated => zNodeChangeHandlers.get(path).foreach(_.handleCreation()) + case EventType.NodeDeleted => zNodeChangeHandlers.get(path).foreach(_.handleDeletion()) + case EventType.NodeDataChanged => zNodeChangeHandlers.get(path).foreach(_.handleDataChange()) + } + } + } + } +} + +trait StateChangeHandler { + val name: String + def beforeInitializingSession(): Unit = {} + def afterInitializingSession(): Unit = {} + def onAuthFailure(): Unit = {} +} + +trait ZNodeChangeHandler { + val path: String + def handleCreation(): Unit = {} + def handleDeletion(): Unit = {} + def handleDataChange(): Unit = {} +} + +trait ZNodeChildChangeHandler { + val path: String + def handleChildChange(): Unit = {} +} + +// Thin wrapper for zookeeper.Op +sealed trait ZkOp { + def toZookeeperOp: Op +} + +case class CreateOp(path: String, data: Array[Byte], acl: Seq[ACL], createMode: CreateMode) extends ZkOp { + override def toZookeeperOp: Op = Op.create(path, data, acl.asJava, createMode) +} + +case class DeleteOp(path: String, version: Int) extends ZkOp { + override def toZookeeperOp: Op = Op.delete(path, version) +} + +case class SetDataOp(path: String, data: Array[Byte], version: Int) extends ZkOp { + override def toZookeeperOp: Op = Op.setData(path, data, version) +} + +case class CheckOp(path: String, version: Int) extends ZkOp { + override def toZookeeperOp: Op = Op.check(path, version) +} + +case class 
ZkOpResult(zkOp: ZkOp, rawOpResult: OpResult) + +sealed trait AsyncRequest { + /** + * This type member allows us to define methods that take requests and return responses with the correct types. + * See ``ZooKeeperClient.handleRequests`` for example. + */ + type Response <: AsyncResponse + def path: String + def ctx: Option[Any] +} + +case class CreateRequest(path: String, data: Array[Byte], acl: Seq[ACL], createMode: CreateMode, + ctx: Option[Any] = None) extends AsyncRequest { + type Response = CreateResponse +} + +case class DeleteRequest(path: String, version: Int, ctx: Option[Any] = None) extends AsyncRequest { + type Response = DeleteResponse +} + +case class ExistsRequest(path: String, ctx: Option[Any] = None) extends AsyncRequest { + type Response = ExistsResponse +} + +case class GetDataRequest(path: String, ctx: Option[Any] = None) extends AsyncRequest { + type Response = GetDataResponse +} + +case class SetDataRequest(path: String, data: Array[Byte], version: Int, ctx: Option[Any] = None) extends AsyncRequest { + type Response = SetDataResponse +} + +case class GetAclRequest(path: String, ctx: Option[Any] = None) extends AsyncRequest { + type Response = GetAclResponse +} + +case class SetAclRequest(path: String, acl: Seq[ACL], version: Int, ctx: Option[Any] = None) extends AsyncRequest { + type Response = SetAclResponse +} + +case class GetChildrenRequest(path: String, registerWatch: Boolean, ctx: Option[Any] = None) extends AsyncRequest { + type Response = GetChildrenResponse +} + +case class MultiRequest(zkOps: Seq[ZkOp], ctx: Option[Any] = None) extends AsyncRequest { + type Response = MultiResponse + + override def path: String = null +} + + +sealed abstract class AsyncResponse { + def resultCode: Code + def path: String + def ctx: Option[Any] + + /** Return None if the result code is OK and KeeperException otherwise. */ + def resultException: Option[KeeperException] = + if (resultCode == Code.OK) None else Some(KeeperException.create(resultCode, path)) + + /** + * Throw KeeperException if the result code is not OK. 
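The ZkOp wrappers and MultiRequest above combine into atomic multi-ops. A short sketch, again reusing the connected client; the znode path and versions are illustrative:

val multi = MultiRequest(Seq(
  CheckOp("/controller_epoch", 3),
  SetDataOp("/controller_epoch", "4".getBytes(java.nio.charset.StandardCharsets.UTF_8), 3)))

val response = client.handleRequest(multi)
response.maybeThrow()   // KeeperException if the check failed or the update was rejected
response.zkOpResults.foreach(result => println(s"${result.zkOp} -> ${result.rawOpResult}"))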
+ */ + def maybeThrow(): Unit = { + if (resultCode != Code.OK) + throw KeeperException.create(resultCode, path) + } + + def metadata: ResponseMetadata +} + +case class ResponseMetadata(sendTimeMs: Long, receivedTimeMs: Long) { + def responseTimeMs: Long = receivedTimeMs - sendTimeMs +} + +case class CreateResponse(resultCode: Code, path: String, ctx: Option[Any], name: String, + metadata: ResponseMetadata) extends AsyncResponse +case class DeleteResponse(resultCode: Code, path: String, ctx: Option[Any], + metadata: ResponseMetadata) extends AsyncResponse +case class ExistsResponse(resultCode: Code, path: String, ctx: Option[Any], stat: Stat, + metadata: ResponseMetadata) extends AsyncResponse +case class GetDataResponse(resultCode: Code, path: String, ctx: Option[Any], data: Array[Byte], stat: Stat, + metadata: ResponseMetadata) extends AsyncResponse +case class SetDataResponse(resultCode: Code, path: String, ctx: Option[Any], stat: Stat, + metadata: ResponseMetadata) extends AsyncResponse +case class GetAclResponse(resultCode: Code, path: String, ctx: Option[Any], acl: Seq[ACL], stat: Stat, + metadata: ResponseMetadata) extends AsyncResponse +case class SetAclResponse(resultCode: Code, path: String, ctx: Option[Any], stat: Stat, + metadata: ResponseMetadata) extends AsyncResponse +case class GetChildrenResponse(resultCode: Code, path: String, ctx: Option[Any], children: Seq[String], stat: Stat, + metadata: ResponseMetadata) extends AsyncResponse +case class MultiResponse(resultCode: Code, path: String, ctx: Option[Any], zkOpResults: Seq[ZkOpResult], + metadata: ResponseMetadata) extends AsyncResponse + +class ZooKeeperClientException(message: String) extends RuntimeException(message) +class ZooKeeperClientExpiredException(message: String) extends ZooKeeperClientException(message) +class ZooKeeperClientAuthFailedException(message: String) extends ZooKeeperClientException(message) +class ZooKeeperClientTimeoutException(message: String) extends ZooKeeperClientException(message) diff --git a/core/src/main/scala/org/apache/zookeeper/ZooKeeperMainWithTlsSupportForKafka.scala b/core/src/main/scala/org/apache/zookeeper/ZooKeeperMainWithTlsSupportForKafka.scala new file mode 100644 index 0000000..93674ad --- /dev/null +++ b/core/src/main/scala/org/apache/zookeeper/ZooKeeperMainWithTlsSupportForKafka.scala @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.zookeeper + +import kafka.admin.ZkSecurityMigrator +import org.apache.zookeeper.admin.ZooKeeperAdmin +import org.apache.zookeeper.cli.CommandNotFoundException +import org.apache.zookeeper.cli.MalformedCommandException +import org.apache.zookeeper.client.ZKClientConfig + +import scala.jdk.CollectionConverters._ + +object ZooKeeperMainWithTlsSupportForKafka { + val zkTlsConfigFileOption = "-zk-tls-config-file" + def main(args: Array[String]): Unit = { + val zkTlsConfigFileIndex = args.indexOf(zkTlsConfigFileOption) + val zooKeeperMain: ZooKeeperMain = + if (zkTlsConfigFileIndex < 0) + // no TLS config, so just pass args directly + new ZooKeeperMainWithTlsSupportForKafka(args, None) + else if (zkTlsConfigFileIndex == args.length - 1) + throw new IllegalArgumentException(s"Error: no filename provided with option $zkTlsConfigFileOption") + else + // found TLS config, so instantiate it and pass args without the two TLS config-related arguments + new ZooKeeperMainWithTlsSupportForKafka( + args.slice(0, zkTlsConfigFileIndex) ++ args.slice(zkTlsConfigFileIndex + 2, args.length), + Some(ZkSecurityMigrator.createZkClientConfigFromFile(args(zkTlsConfigFileIndex + 1)))) + // The run method of ZooKeeperMain is package-private, + // therefore this code unfortunately must reside in the same org.apache.zookeeper package. + zooKeeperMain.run + } +} + +class ZooKeeperMainWithTlsSupportForKafka(args: Array[String], val zkClientConfig: Option[ZKClientConfig]) + extends ZooKeeperMain(args) with Watcher { + + override def processZKCmd (co: ZooKeeperMain.MyCommandOptions): Boolean = { + // Unfortunately the usage() method is static, so it can't be overridden. + // This method is where usage() gets called. We don't cover all possible calls + // to usage() -- we would have to implement the entire method to do that -- but + // the short implementation below covers most cases. + val args = co.getArgArray + val cmd = co.getCommand + if (args.length < 1) { + kafkaTlsUsage() + throw new MalformedCommandException("No command entered") + } + + if (!ZooKeeperMain.commandMap.containsKey(cmd)) { + kafkaTlsUsage() + throw new CommandNotFoundException(s"Command not found $cmd") + } + super.processZKCmd(co) + } + + def kafkaTlsUsage(): Unit = { + System.err.println("ZooKeeper -server host:port [-zk-tls-config-file ] cmd args") + ZooKeeperMain.commandMap.keySet.asScala.toList.sorted.foreach(cmd => + System.err.println(s"\t$cmd ${ZooKeeperMain.commandMap.get(cmd)}")) + } + + override def connectToZK(newHost: String) = { + // ZooKeeperAdmin has no constructor that supports passing in both readOnly and ZkClientConfig, + // and readOnly ends up being set to false when passing in a ZkClientConfig instance; + // therefore it is currently not possible for us to construct a ZooKeeperAdmin instance with + // both an explicit ZkClientConfig instance and a readOnly value of true. 
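For illustration, the TLS-aware shell above can also be driven programmatically rather than through the zookeeper-shell.sh wrapper; the config file path, host, and znode here are placeholders:

ZooKeeperMainWithTlsSupportForKafka.main(Array(
  "-zk-tls-config-file", "/etc/kafka/zk-tls.properties",   // stripped before the remaining args reach ZooKeeperMain
  "-server", "zk1:2182",
  "get", "/cluster/id"))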
+ val readOnlyRequested = cl.getOption("readonly") != null + if (readOnlyRequested && zkClientConfig.isDefined) + throw new IllegalArgumentException( + s"read-only mode (-r) is not supported with an explicit TLS config (${ZooKeeperMainWithTlsSupportForKafka.zkTlsConfigFileOption})") + if (zk != null && zk.getState.isAlive) zk.close() + host = newHost + zk = if (zkClientConfig.isDefined) + new ZooKeeperAdmin(host, cl.getOption("timeout").toInt, this, zkClientConfig.get) + else + new ZooKeeperAdmin(host, cl.getOption("timeout").toInt, this, readOnlyRequested) + } + + override def process(event: WatchedEvent): Unit = { + if (getPrintWatches) { + ZooKeeperMain.printMessage("WATCHER::") + ZooKeeperMain.printMessage(event.toString) + } + } +} diff --git a/core/src/test/java/kafka/test/ClusterConfig.java b/core/src/test/java/kafka/test/ClusterConfig.java new file mode 100644 index 0000000..20b74cf --- /dev/null +++ b/core/src/test/java/kafka/test/ClusterConfig.java @@ -0,0 +1,227 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package kafka.test; + +import kafka.test.annotation.Type; +import org.apache.kafka.common.security.auth.SecurityProtocol; + +import java.io.File; +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.Map; +import java.util.Optional; +import java.util.Properties; + +/** + * Represents a requested configuration of a Kafka cluster for integration testing + */ +public class ClusterConfig { + + private final Type type; + private final int brokers; + private final int controllers; + private final String name; + private final boolean autoStart; + + private final SecurityProtocol securityProtocol; + private final String listenerName; + private final File trustStoreFile; + private final String ibp; + + private final Properties serverProperties = new Properties(); + private final Properties producerProperties = new Properties(); + private final Properties consumerProperties = new Properties(); + private final Properties adminClientProperties = new Properties(); + private final Properties saslServerProperties = new Properties(); + private final Properties saslClientProperties = new Properties(); + private final Map perBrokerOverrideProperties = new HashMap<>(); + + ClusterConfig(Type type, int brokers, int controllers, String name, boolean autoStart, + SecurityProtocol securityProtocol, String listenerName, File trustStoreFile, + String ibp) { + this.type = type; + this.brokers = brokers; + this.controllers = controllers; + this.name = name; + this.autoStart = autoStart; + this.securityProtocol = securityProtocol; + this.listenerName = listenerName; + this.trustStoreFile = trustStoreFile; + this.ibp = ibp; + } + + public Type clusterType() { + return type; + } + + public int numBrokers() { + return brokers; + } + + public int numControllers() { + return controllers; + } + + public Optional name() { + return Optional.ofNullable(name); + } + + public boolean isAutoStart() { + return autoStart; + } + + public Properties serverProperties() { + return serverProperties; + } + + public Properties producerProperties() { + return producerProperties; + } + + public Properties consumerProperties() { + return consumerProperties; + } + + public Properties adminClientProperties() { + return adminClientProperties; + } + + public Properties saslServerProperties() { + return saslServerProperties; + } + + public Properties saslClientProperties() { + return saslClientProperties; + } + + public SecurityProtocol securityProtocol() { + return securityProtocol; + } + + public Optional listenerName() { + return Optional.ofNullable(listenerName); + } + + public Optional trustStoreFile() { + return Optional.ofNullable(trustStoreFile); + } + + public Optional ibp() { + return Optional.ofNullable(ibp); + } + + public Properties brokerServerProperties(int brokerId) { + return perBrokerOverrideProperties.computeIfAbsent(brokerId, __ -> new Properties()); + } + + public Map nameTags() { + Map tags = new LinkedHashMap<>(3); + name().ifPresent(name -> tags.put("Name", name)); + ibp().ifPresent(ibp -> tags.put("IBP", ibp)); + tags.put("Security", securityProtocol.name()); + listenerName().ifPresent(listener -> tags.put("Listener", listener)); + return tags; + } + + public ClusterConfig copyOf() { + ClusterConfig copy = new ClusterConfig(type, brokers, controllers, name, autoStart, securityProtocol, listenerName, trustStoreFile, ibp); + copy.serverProperties.putAll(serverProperties); + copy.producerProperties.putAll(producerProperties); + copy.consumerProperties.putAll(consumerProperties); + 
copy.saslServerProperties.putAll(saslServerProperties); + copy.saslClientProperties.putAll(saslClientProperties); + return copy; + } + + public static Builder defaultClusterBuilder() { + return new Builder(Type.ZK, 1, 1, true, SecurityProtocol.PLAINTEXT); + } + + public static Builder clusterBuilder(Type type, int brokers, int controllers, boolean autoStart, SecurityProtocol securityProtocol) { + return new Builder(type, brokers, controllers, autoStart, securityProtocol); + } + + public static class Builder { + private Type type; + private int brokers; + private int controllers; + private String name; + private boolean autoStart; + private SecurityProtocol securityProtocol; + private String listenerName; + private File trustStoreFile; + private String ibp; + + Builder(Type type, int brokers, int controllers, boolean autoStart, SecurityProtocol securityProtocol) { + this.type = type; + this.brokers = brokers; + this.controllers = controllers; + this.autoStart = autoStart; + this.securityProtocol = securityProtocol; + } + + public Builder type(Type type) { + this.type = type; + return this; + } + + public Builder brokers(int brokers) { + this.brokers = brokers; + return this; + } + + public Builder controllers(int controllers) { + this.controllers = controllers; + return this; + } + + public Builder name(String name) { + this.name = name; + return this; + } + + public Builder autoStart(boolean autoStart) { + this.autoStart = autoStart; + return this; + } + + public Builder securityProtocol(SecurityProtocol securityProtocol) { + this.securityProtocol = securityProtocol; + return this; + } + + public Builder listenerName(String listenerName) { + this.listenerName = listenerName; + return this; + } + + public Builder trustStoreFile(File trustStoreFile) { + this.trustStoreFile = trustStoreFile; + return this; + } + + public Builder ibp(String ibp) { + this.ibp = ibp; + return this; + } + + public ClusterConfig build() { + return new ClusterConfig(type, brokers, controllers, name, autoStart, securityProtocol, listenerName, trustStoreFile, ibp); + } + } +} diff --git a/core/src/test/java/kafka/test/ClusterGenerator.java b/core/src/test/java/kafka/test/ClusterGenerator.java new file mode 100644 index 0000000..97a2463 --- /dev/null +++ b/core/src/test/java/kafka/test/ClusterGenerator.java @@ -0,0 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
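A sketch of building a test cluster specification with the ClusterConfig builder above; values are illustrative, and the snippet is written in Scala although the API itself is plain Java:

import kafka.test.ClusterConfig
import kafka.test.annotation.Type
import org.apache.kafka.common.security.auth.SecurityProtocol

val config = ClusterConfig.clusterBuilder(Type.ZK, 3, 1, true, SecurityProtocol.PLAINTEXT)
  .name("acl-it")
  .ibp("2.8")
  .build()
config.serverProperties().put("auto.create.topics.enable", "false")
// config.nameTags() now carries Name=acl-it, IBP=2.8 and Security=PLAINTEXT for test display names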
+ */ + +package kafka.test; + +import java.util.function.Consumer; + +@FunctionalInterface +public interface ClusterGenerator extends Consumer { + +} diff --git a/core/src/test/java/kafka/test/ClusterInstance.java b/core/src/test/java/kafka/test/ClusterInstance.java new file mode 100644 index 0000000..23b417e --- /dev/null +++ b/core/src/test/java/kafka/test/ClusterInstance.java @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.test; + +import kafka.network.SocketServer; +import kafka.test.annotation.ClusterTest; +import org.apache.kafka.clients.admin.Admin; +import org.apache.kafka.common.network.ListenerName; + +import java.util.Collection; +import java.util.Properties; + +public interface ClusterInstance { + + enum ClusterType { + ZK, + RAFT + } + + /** + * Cluster type. For now, only ZK is supported. + */ + ClusterType clusterType(); + + /** + * The cluster configuration used to create this cluster. Changing data in this instance through this accessor will + * have no affect on the cluster since it is already provisioned. + */ + ClusterConfig config(); + + /** + * The listener for this cluster as configured by {@link ClusterTest} or by {@link ClusterConfig}. If + * unspecified by those sources, this will return the listener for the default security protocol PLAINTEXT + */ + ListenerName clientListener(); + + /** + * The broker connect string which can be used by clients for bootstrapping + */ + String bootstrapServers(); + + /** + * A collection of all brokers in the cluster. In ZK-based clusters this will also include the broker which is + * acting as the controller (since ZK controllers serve both broker and controller roles). + */ + Collection brokerSocketServers(); + + /** + * A collection of all controllers in the cluster. For ZK-based clusters, this will return the broker which is also + * currently the active controller. For Raft-based clusters, this will return all controller servers. + */ + Collection controllerSocketServers(); + + /** + * Return any one of the broker servers. Throw an error if none are found + */ + SocketServer anyBrokerSocketServer(); + + /** + * Return any one of the controller servers. Throw an error if none are found + */ + SocketServer anyControllerSocketServer(); + + /** + * The underlying object which is responsible for setting up and tearing down the cluster. 
+ */ + Object getUnderlying(); + + default T getUnderlying(Class asClass) { + return asClass.cast(getUnderlying()); + } + + Admin createAdminClient(Properties configOverrides); + + default Admin createAdminClient() { + return createAdminClient(new Properties()); + } + + void start(); + + void stop(); + + void shutdownBroker(int brokerId); + + void startBroker(int brokerId); + + void rollingBrokerRestart(); + + void waitForReadyBrokers() throws InterruptedException; +} diff --git a/core/src/test/java/kafka/test/ClusterTestExtensionsTest.java b/core/src/test/java/kafka/test/ClusterTestExtensionsTest.java new file mode 100644 index 0000000..767a279 --- /dev/null +++ b/core/src/test/java/kafka/test/ClusterTestExtensionsTest.java @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.test; + +import kafka.test.annotation.AutoStart; +import kafka.test.annotation.ClusterConfigProperty; +import kafka.test.annotation.ClusterTemplate; +import kafka.test.annotation.ClusterTest; +import kafka.test.annotation.ClusterTestDefaults; +import kafka.test.annotation.ClusterTests; +import kafka.test.annotation.Type; +import kafka.test.junit.ClusterTestExtensions; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.extension.ExtendWith; + + +@ClusterTestDefaults(clusterType = Type.ZK) // Set defaults for a few params in @ClusterTest(s) +@ExtendWith(ClusterTestExtensions.class) +public class ClusterTestExtensionsTest { + + private final ClusterInstance clusterInstance; + private final ClusterConfig config; + + ClusterTestExtensionsTest(ClusterInstance clusterInstance, ClusterConfig config) { // Constructor injections + this.clusterInstance = clusterInstance; + this.config = config; + } + + // Static methods can generate cluster configurations + static void generate1(ClusterGenerator clusterGenerator) { + clusterGenerator.accept(ClusterConfig.defaultClusterBuilder().name("Generated Test").build()); + } + + // BeforeEach run after class construction, but before cluster initialization and test invocation + @BeforeEach + public void beforeEach(ClusterConfig config) { + Assertions.assertSame(this.config, config, "Injected objects should be the same"); + config.serverProperties().put("before", "each"); + } + + // AfterEach runs after test invocation and cluster teardown + @AfterEach + public void afterEach(ClusterConfig config) { + Assertions.assertSame(this.config, config, "Injected objects should be the same"); + } + + // With no params, configuration comes from the annotation defaults as well as @ClusterTestDefaults (if present) + @ClusterTest + public void testClusterTest(ClusterConfig config, ClusterInstance clusterInstance) { + 
Assertions.assertSame(this.config, config, "Injected objects should be the same"); + Assertions.assertSame(this.clusterInstance, clusterInstance, "Injected objects should be the same"); + Assertions.assertEquals(clusterInstance.clusterType(), ClusterInstance.ClusterType.ZK); // From the class level default + Assertions.assertEquals(clusterInstance.config().serverProperties().getProperty("before"), "each"); + } + + // generate1 is a template method which generates any number of cluster configs + @ClusterTemplate("generate1") + public void testClusterTemplate() { + Assertions.assertEquals(clusterInstance.clusterType(), ClusterInstance.ClusterType.ZK, + "generate1 provided a Zk cluster, so we should see that here"); + Assertions.assertEquals(clusterInstance.config().name().orElse(""), "Generated Test", + "generate 1 named this cluster config, so we should see that here"); + Assertions.assertEquals(clusterInstance.config().serverProperties().getProperty("before"), "each"); + } + + // Multiple @ClusterTest can be used with @ClusterTests + @ClusterTests({ + @ClusterTest(name = "cluster-tests-1", clusterType = Type.ZK, serverProperties = { + @ClusterConfigProperty(key = "foo", value = "bar"), + @ClusterConfigProperty(key = "spam", value = "eggs") + }), + @ClusterTest(name = "cluster-tests-2", clusterType = Type.KRAFT, serverProperties = { + @ClusterConfigProperty(key = "foo", value = "baz"), + @ClusterConfigProperty(key = "spam", value = "eggz") + }) + }) + public void testClusterTests() { + if (clusterInstance.clusterType().equals(ClusterInstance.ClusterType.ZK)) { + Assertions.assertEquals(clusterInstance.config().serverProperties().getProperty("foo"), "bar"); + Assertions.assertEquals(clusterInstance.config().serverProperties().getProperty("spam"), "eggs"); + } else if (clusterInstance.clusterType().equals(ClusterInstance.ClusterType.RAFT)) { + Assertions.assertEquals(clusterInstance.config().serverProperties().getProperty("foo"), "baz"); + Assertions.assertEquals(clusterInstance.config().serverProperties().getProperty("spam"), "eggz"); + } else { + Assertions.fail("Unknown cluster type " + clusterInstance.clusterType()); + } + } + + @ClusterTest(autoStart = AutoStart.NO) + public void testNoAutoStart() { + Assertions.assertThrows(RuntimeException.class, clusterInstance::anyBrokerSocketServer); + clusterInstance.start(); + Assertions.assertNotNull(clusterInstance.anyBrokerSocketServer()); + } +} diff --git a/core/src/test/java/kafka/test/MockController.java b/core/src/test/java/kafka/test/MockController.java new file mode 100644 index 0000000..f56e2cb --- /dev/null +++ b/core/src/test/java/kafka/test/MockController.java @@ -0,0 +1,357 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
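A sketch of how a test body can consume the ClusterInstance surface used by the tests above; the topic name is illustrative, and Scala is shown although the interface is Java:

import java.util.Collections
import kafka.test.ClusterInstance
import org.apache.kafka.clients.admin.NewTopic

def createExampleTopic(cluster: ClusterInstance): Unit = {
  val admin = cluster.createAdminClient()
  try {
    admin.createTopics(Collections.singleton(new NewTopic("example", 1, 1.toShort))).all().get()
    println(s"created 'example' via ${cluster.bootstrapServers()}")
  } finally {
    admin.close()
  }
}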
+ */ + +package kafka.test; + +import org.apache.kafka.clients.admin.AlterConfigOp; +import org.apache.kafka.common.Uuid; +import org.apache.kafka.common.config.ConfigResource; +import org.apache.kafka.common.errors.NotControllerException; +import org.apache.kafka.common.message.AllocateProducerIdsRequestData; +import org.apache.kafka.common.message.AllocateProducerIdsResponseData; +import org.apache.kafka.common.message.AlterIsrRequestData; +import org.apache.kafka.common.message.AlterIsrResponseData; +import org.apache.kafka.common.message.AlterPartitionReassignmentsRequestData; +import org.apache.kafka.common.message.AlterPartitionReassignmentsResponseData; +import org.apache.kafka.common.message.BrokerHeartbeatRequestData; +import org.apache.kafka.common.message.BrokerRegistrationRequestData; +import org.apache.kafka.common.message.CreatePartitionsRequestData.CreatePartitionsTopic; +import org.apache.kafka.common.message.CreatePartitionsResponseData.CreatePartitionsTopicResult; +import org.apache.kafka.common.message.CreateTopicsRequestData; +import org.apache.kafka.common.message.CreateTopicsRequestData.CreatableTopic; +import org.apache.kafka.common.message.CreateTopicsResponseData; +import org.apache.kafka.common.message.CreateTopicsResponseData.CreatableTopicResult; +import org.apache.kafka.common.message.ElectLeadersRequestData; +import org.apache.kafka.common.message.ElectLeadersResponseData; +import org.apache.kafka.common.message.ListPartitionReassignmentsRequestData; +import org.apache.kafka.common.message.ListPartitionReassignmentsResponseData; +import org.apache.kafka.common.protocol.Errors; +import org.apache.kafka.common.quota.ClientQuotaAlteration; +import org.apache.kafka.common.quota.ClientQuotaEntity; +import org.apache.kafka.common.requests.ApiError; +import org.apache.kafka.controller.Controller; +import org.apache.kafka.controller.ResultOrError; +import org.apache.kafka.metadata.BrokerHeartbeatReply; +import org.apache.kafka.metadata.BrokerRegistrationReply; +import org.apache.kafka.metadata.FeatureMapAndEpoch; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.atomic.AtomicLong; + +import static org.apache.kafka.clients.admin.AlterConfigOp.OpType.DELETE; +import static org.apache.kafka.clients.admin.AlterConfigOp.OpType.SET; +import static org.apache.kafka.common.protocol.Errors.INVALID_REQUEST; + + +public class MockController implements Controller { + private final static NotControllerException NOT_CONTROLLER_EXCEPTION = + new NotControllerException("This is not the correct controller for this cluster."); + + private final AtomicLong nextTopicId = new AtomicLong(1); + + public static class Builder { + private final Map initialTopics = new HashMap<>(); + + public Builder newInitialTopic(String name, Uuid id) { + initialTopics.put(name, new MockTopic(name, id)); + return this; + } + + public MockController build() { + return new MockController(initialTopics.values()); + } + } + + private volatile boolean active = true; + + private MockController(Collection initialTopics) { + for (MockTopic topic : initialTopics) { + topics.put(topic.id, topic); + topicNameToId.put(topic.name, topic.id); + } + } + + @Override + public CompletableFuture alterIsr(AlterIsrRequestData request) { + throw new UnsupportedOperationException(); + } + + @Override + 
synchronized public CompletableFuture + createTopics(CreateTopicsRequestData request) { + CreateTopicsResponseData response = new CreateTopicsResponseData(); + for (CreatableTopic topic : request.topics()) { + if (topicNameToId.containsKey(topic.name())) { + response.topics().add(new CreatableTopicResult(). + setName(topic.name()). + setErrorCode(Errors.TOPIC_ALREADY_EXISTS.code())); + } else { + long topicId = nextTopicId.getAndIncrement(); + Uuid topicUuid = new Uuid(0, topicId); + topicNameToId.put(topic.name(), topicUuid); + topics.put(topicUuid, new MockTopic(topic.name(), topicUuid)); + response.topics().add(new CreatableTopicResult(). + setName(topic.name()). + setErrorCode(Errors.NONE.code()). + setTopicId(topicUuid)); + // For a better mock, we might want to return configs, replication factor, + // etc. Right now, the tests that use MockController don't need these + // things. + } + } + return CompletableFuture.completedFuture(response); + } + + @Override + public CompletableFuture unregisterBroker(int brokerId) { + throw new UnsupportedOperationException(); + } + + static class MockTopic { + private final String name; + private final Uuid id; + + MockTopic(String name, Uuid id) { + this.name = name; + this.id = id; + } + } + + private final Map topicNameToId = new HashMap<>(); + + private final Map topics = new HashMap<>(); + + private final Map> configs = new HashMap<>(); + + @Override + synchronized public CompletableFuture>> + findTopicIds(long deadlineNs, Collection topicNames) { + Map> results = new HashMap<>(); + for (String topicName : topicNames) { + if (!topicNameToId.containsKey(topicName)) { + results.put(topicName, new ResultOrError<>(new ApiError(Errors.UNKNOWN_TOPIC_OR_PARTITION))); + } else { + results.put(topicName, new ResultOrError<>(topicNameToId.get(topicName))); + } + } + return CompletableFuture.completedFuture(results); + } + + @Override + synchronized public CompletableFuture>> + findTopicNames(long deadlineNs, Collection topicIds) { + Map> results = new HashMap<>(); + for (Uuid topicId : topicIds) { + MockTopic topic = topics.get(topicId); + if (topic == null) { + results.put(topicId, new ResultOrError<>(new ApiError(Errors.UNKNOWN_TOPIC_ID))); + } else { + results.put(topicId, new ResultOrError<>(topic.name)); + } + } + return CompletableFuture.completedFuture(results); + } + + @Override + synchronized public CompletableFuture> + deleteTopics(long deadlineNs, Collection topicIds) { + if (!active) { + CompletableFuture> future = new CompletableFuture<>(); + future.completeExceptionally(NOT_CONTROLLER_EXCEPTION); + return future; + } + Map results = new HashMap<>(); + for (Uuid topicId : topicIds) { + MockTopic topic = topics.remove(topicId); + if (topic == null) { + results.put(topicId, new ApiError(Errors.UNKNOWN_TOPIC_ID)); + } else { + topicNameToId.remove(topic.name); + results.put(topicId, ApiError.NONE); + } + } + return CompletableFuture.completedFuture(results); + } + + @Override + public CompletableFuture>>> describeConfigs(Map> resources) { + throw new UnsupportedOperationException(); + } + + @Override + public CompletableFuture electLeaders(ElectLeadersRequestData request) { + throw new UnsupportedOperationException(); + } + + @Override + public CompletableFuture finalizedFeatures() { + throw new UnsupportedOperationException(); + } + + @Override + public CompletableFuture> incrementalAlterConfigs( + Map>> configChanges, + boolean validateOnly) { + Map results = new HashMap<>(); + for (Entry>> entry : + configChanges.entrySet()) { + 
ConfigResource resource = entry.getKey(); + results.put(resource, incrementalAlterResource(resource, entry.getValue(), validateOnly)); + } + CompletableFuture> future = new CompletableFuture<>(); + future.complete(results); + return future; + } + + private ApiError incrementalAlterResource(ConfigResource resource, + Map> ops, boolean validateOnly) { + for (Entry> entry : ops.entrySet()) { + AlterConfigOp.OpType opType = entry.getValue().getKey(); + if (opType != SET && opType != DELETE) { + return new ApiError(INVALID_REQUEST, "This mock does not " + + "support the " + opType + " config operation."); + } + } + if (!validateOnly) { + for (Entry> entry : ops.entrySet()) { + String key = entry.getKey(); + AlterConfigOp.OpType op = entry.getValue().getKey(); + String value = entry.getValue().getValue(); + switch (op) { + case SET: + configs.computeIfAbsent(resource, __ -> new HashMap<>()).put(key, value); + break; + case DELETE: + configs.getOrDefault(resource, Collections.emptyMap()).remove(key); + break; + } + } + } + return ApiError.NONE; + } + + @Override + public CompletableFuture + alterPartitionReassignments(AlterPartitionReassignmentsRequestData request) { + throw new UnsupportedOperationException(); + } + + @Override + public CompletableFuture + listPartitionReassignments(ListPartitionReassignmentsRequestData request) { + throw new UnsupportedOperationException(); + } + + @Override + public CompletableFuture> legacyAlterConfigs( + Map> newConfigs, boolean validateOnly) { + Map results = new HashMap<>(); + if (!validateOnly) { + for (Entry> entry : newConfigs.entrySet()) { + ConfigResource resource = entry.getKey(); + Map map = configs.computeIfAbsent(resource, __ -> new HashMap<>()); + map.clear(); + map.putAll(entry.getValue()); + } + } + CompletableFuture> future = new CompletableFuture<>(); + future.complete(results); + return future; + } + + @Override + public CompletableFuture + processBrokerHeartbeat(BrokerHeartbeatRequestData request) { + throw new UnsupportedOperationException(); + } + + @Override + public CompletableFuture + registerBroker(BrokerRegistrationRequestData request) { + throw new UnsupportedOperationException(); + } + + @Override + public CompletableFuture waitForReadyBrokers(int minBrokers) { + throw new UnsupportedOperationException(); + } + + @Override + public CompletableFuture> + alterClientQuotas(Collection quotaAlterations, boolean validateOnly) { + throw new UnsupportedOperationException(); + } + + @Override + public CompletableFuture allocateProducerIds(AllocateProducerIdsRequestData request) { + throw new UnsupportedOperationException(); + } + + @Override + synchronized public CompletableFuture> + createPartitions(long deadlineNs, List topicList) { + if (!active) { + CompletableFuture> future = new CompletableFuture<>(); + future.completeExceptionally(NOT_CONTROLLER_EXCEPTION); + return future; + } + List results = new ArrayList<>(); + for (CreatePartitionsTopic topic : topicList) { + if (topicNameToId.containsKey(topic.name())) { + results.add(new CreatePartitionsTopicResult().setName(topic.name()). + setErrorCode(Errors.NONE.code()). + setErrorMessage(null)); + } else { + results.add(new CreatePartitionsTopicResult().setName(topic.name()). + setErrorCode(Errors.UNKNOWN_TOPIC_OR_PARTITION.code()). 
+ setErrorMessage("No such topic as " + topic.name())); + } + } + return CompletableFuture.completedFuture(results); + } + + @Override + public CompletableFuture beginWritingSnapshot() { + throw new UnsupportedOperationException(); + } + + @Override + public void beginShutdown() { + this.active = false; + } + + public void setActive(boolean active) { + this.active = active; + } + + @Override + public int curClaimEpoch() { + return active ? 1 : -1; + } + + @Override + public void close() { + beginShutdown(); + } +} diff --git a/core/src/test/java/kafka/test/annotation/AutoStart.java b/core/src/test/java/kafka/test/annotation/AutoStart.java new file mode 100644 index 0000000..24fdedf --- /dev/null +++ b/core/src/test/java/kafka/test/annotation/AutoStart.java @@ -0,0 +1,24 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.test.annotation; + +public enum AutoStart { + YES, + NO, + DEFAULT +} diff --git a/core/src/test/java/kafka/test/annotation/ClusterConfigProperty.java b/core/src/test/java/kafka/test/annotation/ClusterConfigProperty.java new file mode 100644 index 0000000..eb1434d --- /dev/null +++ b/core/src/test/java/kafka/test/annotation/ClusterConfigProperty.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package kafka.test.annotation; + +import java.lang.annotation.Documented; +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +@Documented +@Target({ElementType.ANNOTATION_TYPE}) +@Retention(RetentionPolicy.RUNTIME) +public @interface ClusterConfigProperty { + String key(); + String value(); +} diff --git a/core/src/test/java/kafka/test/annotation/ClusterTemplate.java b/core/src/test/java/kafka/test/annotation/ClusterTemplate.java new file mode 100644 index 0000000..f776b4e --- /dev/null +++ b/core/src/test/java/kafka/test/annotation/ClusterTemplate.java @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.test.annotation; + +import kafka.test.ClusterConfig; +import kafka.test.ClusterGenerator; +import org.junit.jupiter.api.TestTemplate; + +import java.lang.annotation.Documented; +import java.lang.annotation.Retention; +import java.lang.annotation.Target; + +import static java.lang.annotation.ElementType.METHOD; +import static java.lang.annotation.RetentionPolicy.RUNTIME; + +/** + * Used to indicate that a test should call the method given by {@link #value()} to generate a number of + * cluster configurations. The method specified by the value should accept a single argument of the type + * {@link ClusterGenerator}. Any return value from the method is ignore. A test invocation + * will be generated for each {@link ClusterConfig} provided to the ClusterGenerator instance. + * + * The method given here must be static since it is invoked before any tests are actually run. Each test generated + * by this annotation will run as if it was defined as a separate test method with its own + * {@link org.junit.jupiter.api.Test}. That is to say, each generated test invocation will have a separate lifecycle. + * + * This annotation may be used in conjunction with {@link ClusterTest} and {@link ClusterTests} which also yield + * ClusterConfig instances. + * + * For Scala tests, the method should be defined in a companion object with the same name as the test class. + */ +@Documented +@Target({METHOD}) +@Retention(RUNTIME) +@TestTemplate +public @interface ClusterTemplate { + /** + * Specify the static method used for generating cluster configs + */ + String value(); +} diff --git a/core/src/test/java/kafka/test/annotation/ClusterTest.java b/core/src/test/java/kafka/test/annotation/ClusterTest.java new file mode 100644 index 0000000..11336ab --- /dev/null +++ b/core/src/test/java/kafka/test/annotation/ClusterTest.java @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.test.annotation; + +import org.apache.kafka.common.security.auth.SecurityProtocol; +import org.junit.jupiter.api.TestTemplate; + +import java.lang.annotation.Documented; +import java.lang.annotation.Retention; +import java.lang.annotation.Target; + +import static java.lang.annotation.ElementType.METHOD; +import static java.lang.annotation.RetentionPolicy.RUNTIME; + +@Documented +@Target({METHOD}) +@Retention(RUNTIME) +@TestTemplate +public @interface ClusterTest { + Type clusterType() default Type.DEFAULT; + int brokers() default 0; + int controllers() default 0; + AutoStart autoStart() default AutoStart.DEFAULT; + + String name() default ""; + SecurityProtocol securityProtocol() default SecurityProtocol.PLAINTEXT; + String listener() default ""; + String ibp() default ""; + ClusterConfigProperty[] serverProperties() default {}; +} diff --git a/core/src/test/java/kafka/test/annotation/ClusterTestDefaults.java b/core/src/test/java/kafka/test/annotation/ClusterTestDefaults.java new file mode 100644 index 0000000..cd8a66d --- /dev/null +++ b/core/src/test/java/kafka/test/annotation/ClusterTestDefaults.java @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.test.annotation; + +import kafka.test.junit.ClusterTestExtensions; + +import java.lang.annotation.Documented; +import java.lang.annotation.Retention; +import java.lang.annotation.Target; + +import static java.lang.annotation.ElementType.TYPE; +import static java.lang.annotation.RetentionPolicy.RUNTIME; + +/** + * Used to set class level defaults for any test template methods annotated with {@link ClusterTest} or + * {@link ClusterTests}. The default values here are also used as the source for defaults in + * {@link ClusterTestExtensions}. 
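+ *
+ * <p>A minimal usage sketch (the test class and method names below are illustrative only, not part of this
+ * patch):
+ * <pre>
+ * @ClusterTestDefaults(clusterType = Type.KRAFT, brokers = 3)
+ * @ExtendWith(ClusterTestExtensions.class)
+ * class SomeIntegrationTest {
+ *     @ClusterTest   // inherits clusterType = KRAFT and brokers = 3 from the class-level defaults
+ *     void testSomething(ClusterInstance cluster) { }
+ * }
+ * </pre>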
+ */ +@Documented +@Target({TYPE}) +@Retention(RUNTIME) +public @interface ClusterTestDefaults { + Type clusterType() default Type.ZK; + int brokers() default 1; + int controllers() default 1; + boolean autoStart() default true; +} diff --git a/core/src/test/java/kafka/test/annotation/ClusterTests.java b/core/src/test/java/kafka/test/annotation/ClusterTests.java new file mode 100644 index 0000000..64905f8 --- /dev/null +++ b/core/src/test/java/kafka/test/annotation/ClusterTests.java @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.test.annotation; + +import org.junit.jupiter.api.TestTemplate; + +import java.lang.annotation.Documented; +import java.lang.annotation.Retention; +import java.lang.annotation.Target; + +import static java.lang.annotation.ElementType.METHOD; +import static java.lang.annotation.RetentionPolicy.RUNTIME; + +@Documented +@Target({METHOD}) +@Retention(RUNTIME) +@TestTemplate +public @interface ClusterTests { + ClusterTest[] value(); +} diff --git a/core/src/test/java/kafka/test/annotation/Type.java b/core/src/test/java/kafka/test/annotation/Type.java new file mode 100644 index 0000000..0d1a161 --- /dev/null +++ b/core/src/test/java/kafka/test/annotation/Type.java @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.test.annotation; + +import kafka.test.ClusterConfig; +import kafka.test.junit.RaftClusterInvocationContext; +import kafka.test.junit.ZkClusterInvocationContext; +import org.junit.jupiter.api.extension.TestTemplateInvocationContext; + +import java.util.function.Consumer; + +/** + * The type of cluster config being requested. Used by {@link kafka.test.ClusterConfig} and the test annotations. 
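+ *
+ * <p>A rough sketch of how the extension consumes a Type when expanding a template method ({@code config}
+ * and {@code generatedContexts} are illustrative local names, not part of this patch):
+ * <pre>
+ * Type.BOTH.invocationContexts(config, generatedContexts::add);   // adds one KRaft and one ZK invocation
+ * </pre>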
+ */ +public enum Type { + KRAFT { + @Override + public void invocationContexts(ClusterConfig config, Consumer invocationConsumer) { + invocationConsumer.accept(new RaftClusterInvocationContext(config.copyOf())); + } + }, + ZK { + @Override + public void invocationContexts(ClusterConfig config, Consumer invocationConsumer) { + invocationConsumer.accept(new ZkClusterInvocationContext(config.copyOf())); + } + }, + BOTH { + @Override + public void invocationContexts(ClusterConfig config, Consumer invocationConsumer) { + invocationConsumer.accept(new RaftClusterInvocationContext(config.copyOf())); + invocationConsumer.accept(new ZkClusterInvocationContext(config.copyOf())); + } + }, + DEFAULT { + @Override + public void invocationContexts(ClusterConfig config, Consumer invocationConsumer) { + throw new UnsupportedOperationException("Cannot create invocation contexts for DEFAULT type"); + } + }; + + public abstract void invocationContexts(ClusterConfig config, Consumer invocationConsumer); +} diff --git a/core/src/test/java/kafka/test/junit/ClusterInstanceParameterResolver.java b/core/src/test/java/kafka/test/junit/ClusterInstanceParameterResolver.java new file mode 100644 index 0000000..3329e32 --- /dev/null +++ b/core/src/test/java/kafka/test/junit/ClusterInstanceParameterResolver.java @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.test.junit; + +import kafka.test.ClusterInstance; +import org.junit.jupiter.api.TestTemplate; +import org.junit.jupiter.api.extension.ExtensionContext; +import org.junit.jupiter.api.extension.ParameterContext; +import org.junit.jupiter.api.extension.ParameterResolver; + +import java.lang.reflect.Executable; + +import static org.junit.platform.commons.util.AnnotationUtils.isAnnotated; + +/** + * This resolver provides an instance of {@link ClusterInstance} to a test invocation. The instance represents the + * underlying cluster being run for the current test. It can be injected into test methods or into the class + * constructor. + * + * N.B., if injected into the class constructor, the instance will not be fully initialized until the actual test method + * is being invoked. This is because the cluster is not started until after class construction and after "before" + * lifecycle methods have been run. Constructor injection is meant for convenience so helper methods can be defined on + * the test which can rely on a class member rather than an argument for ClusterInstance. 
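+ *
+ * <p>A sketch of constructor injection as described above (the test class name is hypothetical):
+ * <pre>
+ * @ExtendWith(ClusterTestExtensions.class)
+ * class SomeTestClass {
+ *     private final ClusterInstance cluster;
+ *
+ *     SomeTestClass(ClusterInstance cluster) {
+ *         this.cluster = cluster;   // not fully usable until the test method itself is running
+ *     }
+ * }
+ * </pre>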
+ */ +public class ClusterInstanceParameterResolver implements ParameterResolver { + private final ClusterInstance clusterInstance; + + ClusterInstanceParameterResolver(ClusterInstance clusterInstance) { + this.clusterInstance = clusterInstance; + } + + @Override + public boolean supportsParameter(ParameterContext parameterContext, ExtensionContext extensionContext) { + if (!parameterContext.getParameter().getType().equals(ClusterInstance.class)) { + return false; + } + + if (!extensionContext.getTestMethod().isPresent()) { + // Allow this to be injected into the class + extensionContext.getRequiredTestClass(); + return true; + } else { + // If we're injecting into a method, make sure it's a test method and not a lifecycle method + Executable parameterizedMethod = parameterContext.getParameter().getDeclaringExecutable(); + return isAnnotated(parameterizedMethod, TestTemplate.class); + } + } + + @Override + public Object resolveParameter(ParameterContext parameterContext, ExtensionContext extensionContext) { + return clusterInstance; + } +} diff --git a/core/src/test/java/kafka/test/junit/ClusterTestExtensions.java b/core/src/test/java/kafka/test/junit/ClusterTestExtensions.java new file mode 100644 index 0000000..293f00b --- /dev/null +++ b/core/src/test/java/kafka/test/junit/ClusterTestExtensions.java @@ -0,0 +1,215 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.test.junit; + +import kafka.test.ClusterConfig; +import kafka.test.ClusterGenerator; +import kafka.test.annotation.ClusterTestDefaults; +import kafka.test.annotation.ClusterConfigProperty; +import kafka.test.annotation.ClusterTemplate; +import kafka.test.annotation.ClusterTest; +import kafka.test.annotation.ClusterTests; +import kafka.test.annotation.Type; +import org.junit.jupiter.api.extension.ExtensionContext; +import org.junit.jupiter.api.extension.TestTemplateInvocationContext; +import org.junit.jupiter.api.extension.TestTemplateInvocationContextProvider; +import org.junit.platform.commons.util.ReflectionUtils; + +import java.lang.reflect.Method; +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; +import java.util.Properties; +import java.util.function.Consumer; +import java.util.stream.Stream; + +/** + * This class is a custom JUnit extension that will generate some number of test invocations depending on the processing + * of a few custom annotations. These annotations are placed on so-called test template methods. Template methods look + * like normal JUnit test methods, but instead of being invoked directly, they are used as templates for generating + * multiple test invocations. + * + * Test class that use this extension should use one of the following annotations on each template method: + * + *
+ * <ul>
+ *   <li>{@link ClusterTest}, define a single cluster configuration</li>
+ *   <li>{@link ClusterTests}, provide multiple instances of @ClusterTest</li>
+ *   <li>{@link ClusterTemplate}, define a static method that generates cluster configurations</li>
+ * </ul>
+ *
+ * Any combination of these annotations may be used on a given test template method. If no test invocations are
+ * generated after processing the annotations, an error is thrown.
+ *
+ * Depending on which annotations are used, and what values are given, different {@link ClusterConfig} will be
+ * generated. Each ClusterConfig is used to create an underlying Kafka cluster that is used for the actual test
+ * invocation.
+ *
+ * For example:
+ *
+ * <pre>
                + * @ExtendWith(value = Array(classOf[ClusterTestExtensions]))
                + * class SomeIntegrationTest {
                + *   @ClusterTest(brokers = 1, controllers = 1, clusterType = ClusterType.Both)
                + *   def someTest(): Unit = {
                + *     assertTrue(condition)
                + *   }
                + * }
+ * </pre>
                + * + * will generate two invocations of "someTest" (since ClusterType.Both was given). For each invocation, the test class + * SomeIntegrationTest will be instantiated, lifecycle methods (before/after) will be run, and "someTest" will be invoked. + * + **/ +public class ClusterTestExtensions implements TestTemplateInvocationContextProvider { + @Override + public boolean supportsTestTemplate(ExtensionContext context) { + return true; + } + + @Override + public Stream provideTestTemplateInvocationContexts(ExtensionContext context) { + ClusterTestDefaults defaults = getClusterTestDefaults(context.getRequiredTestClass()); + List generatedContexts = new ArrayList<>(); + + // Process the @ClusterTemplate annotation + ClusterTemplate clusterTemplateAnnot = context.getRequiredTestMethod().getDeclaredAnnotation(ClusterTemplate.class); + if (clusterTemplateAnnot != null) { + processClusterTemplate(context, clusterTemplateAnnot, generatedContexts::add); + if (generatedContexts.size() == 0) { + throw new IllegalStateException("ClusterConfig generator method should provide at least one config"); + } + } + + // Process single @ClusterTest annotation + ClusterTest clusterTestAnnot = context.getRequiredTestMethod().getDeclaredAnnotation(ClusterTest.class); + if (clusterTestAnnot != null) { + processClusterTest(context, clusterTestAnnot, defaults, generatedContexts::add); + } + + // Process multiple @ClusterTest annotation within @ClusterTests + ClusterTests clusterTestsAnnot = context.getRequiredTestMethod().getDeclaredAnnotation(ClusterTests.class); + if (clusterTestsAnnot != null) { + for (ClusterTest annot : clusterTestsAnnot.value()) { + processClusterTest(context, annot, defaults, generatedContexts::add); + } + } + + if (generatedContexts.size() == 0) { + throw new IllegalStateException("Please annotate test methods with @ClusterTemplate, @ClusterTest, or " + + "@ClusterTests when using the ClusterTestExtensions provider"); + } + + return generatedContexts.stream(); + } + + private void processClusterTemplate(ExtensionContext context, ClusterTemplate annot, + Consumer testInvocations) { + // If specified, call cluster config generated method (must be static) + List generatedClusterConfigs = new ArrayList<>(); + if (!annot.value().isEmpty()) { + generateClusterConfigurations(context, annot.value(), generatedClusterConfigs::add); + } else { + // Ensure we have at least one cluster config + generatedClusterConfigs.add(ClusterConfig.defaultClusterBuilder().build()); + } + + generatedClusterConfigs.forEach(config -> config.clusterType().invocationContexts(config, testInvocations)); + } + + private void generateClusterConfigurations(ExtensionContext context, String generateClustersMethods, ClusterGenerator generator) { + Object testInstance = context.getTestInstance().orElse(null); + Method method = ReflectionUtils.getRequiredMethod(context.getRequiredTestClass(), generateClustersMethods, ClusterGenerator.class); + ReflectionUtils.invokeMethod(method, testInstance, generator); + } + + private void processClusterTest(ExtensionContext context, ClusterTest annot, ClusterTestDefaults defaults, + Consumer testInvocations) { + final Type type; + if (annot.clusterType() == Type.DEFAULT) { + type = defaults.clusterType(); + } else { + type = annot.clusterType(); + } + + final int brokers; + if (annot.brokers() == 0) { + brokers = defaults.brokers(); + } else { + brokers = annot.brokers(); + } + + final int controllers; + if (annot.controllers() == 0) { + controllers = defaults.controllers(); + } 
else { + controllers = annot.controllers(); + } + + if (brokers <= 0 || controllers <= 0) { + throw new IllegalArgumentException("Number of brokers/controllers must be greater than zero."); + } + + final boolean autoStart; + switch (annot.autoStart()) { + case YES: + autoStart = true; + break; + case NO: + autoStart = false; + break; + case DEFAULT: + autoStart = defaults.autoStart(); + break; + default: + throw new IllegalStateException(); + } + + ClusterConfig.Builder builder = ClusterConfig.clusterBuilder(type, brokers, controllers, autoStart, annot.securityProtocol()); + if (!annot.name().isEmpty()) { + builder.name(annot.name()); + } else { + builder.name(context.getRequiredTestMethod().getName()); + } + if (!annot.listener().isEmpty()) { + builder.listenerName(annot.listener()); + } + + Properties properties = new Properties(); + for (ClusterConfigProperty property : annot.serverProperties()) { + properties.put(property.key(), property.value()); + } + + if (!annot.ibp().isEmpty()) { + builder.ibp(annot.ibp()); + } + + ClusterConfig config = builder.build(); + config.serverProperties().putAll(properties); + type.invocationContexts(config, testInvocations); + } + + private ClusterTestDefaults getClusterTestDefaults(Class testClass) { + return Optional.ofNullable(testClass.getDeclaredAnnotation(ClusterTestDefaults.class)) + .orElseGet(() -> EmptyClass.class.getDeclaredAnnotation(ClusterTestDefaults.class)); + } + + @ClusterTestDefaults + private final static class EmptyClass { + // Just used as a convenience to get default values from the annotation + } +} diff --git a/core/src/test/java/kafka/test/junit/GenericParameterResolver.java b/core/src/test/java/kafka/test/junit/GenericParameterResolver.java new file mode 100644 index 0000000..70387e1 --- /dev/null +++ b/core/src/test/java/kafka/test/junit/GenericParameterResolver.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.test.junit; + +import org.junit.jupiter.api.extension.ExtensionContext; +import org.junit.jupiter.api.extension.ParameterContext; +import org.junit.jupiter.api.extension.ParameterResolver; + +/** + * This resolver is used for supplying any type of object to the test invocation. It does not restrict where the given + * type can be injected, it simply checks if the requested injection type matches the type given in the constructor. If + * it matches, the given object is returned. + * + * This is useful for injecting helper objects and objects which can be fully initialized before the test lifecycle + * begins. 
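+ *
+ * <p>For example, the invocation contexts in this package register a resolver along these lines, which lets a
+ * test method declare a {@code ClusterConfig} parameter and receive the shared instance:
+ * <pre>
+ * new GenericParameterResolver<>(clusterConfig, ClusterConfig.class)
+ * </pre>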
+ */ +public class GenericParameterResolver implements ParameterResolver { + + private final T instance; + private final Class clazz; + + GenericParameterResolver(T instance, Class clazz) { + this.instance = instance; + this.clazz = clazz; + } + + @Override + public boolean supportsParameter(ParameterContext parameterContext, ExtensionContext extensionContext) { + return parameterContext.getParameter().getType().equals(clazz); + } + + @Override + public Object resolveParameter(ParameterContext parameterContext, ExtensionContext extensionContext) { + return instance; + } +} diff --git a/core/src/test/java/kafka/test/junit/README.md b/core/src/test/java/kafka/test/junit/README.md new file mode 100644 index 0000000..6df7a26 --- /dev/null +++ b/core/src/test/java/kafka/test/junit/README.md @@ -0,0 +1,139 @@ +This document describes a custom JUnit extension which allows for running the same JUnit tests against multiple Kafka +cluster configurations. + +# Annotations + +A new `@ClusterTest` annotation is introduced which allows for a test to declaratively configure an underlying Kafka cluster. + +```scala +@ClusterTest +def testSomething(): Unit = { ... } +``` + +This annotation has fields for cluster type and number of brokers, as well as commonly parameterized configurations. +Arbitrary server properties can also be provided in the annotation: + +```java +@ClusterTest(clusterType = Type.Zk, securityProtocol = "PLAINTEXT", properties = { + @ClusterProperty(key = "inter.broker.protocol.version", value = "2.7-IV2"), + @ClusterProperty(key = "socket.send.buffer.bytes", value = "10240"), +}) +void testSomething() { ... } +``` + +Multiple `@ClusterTest` annotations can be given to generate more than one test invocation for the annotated method. + +```scala +@ClusterTests(Array( + @ClusterTest(securityProtocol = "PLAINTEXT"), + @ClusterTest(securityProtocol = "SASL_PLAINTEXT") +)) +def testSomething(): Unit = { ... } +``` + +A class-level `@ClusterTestDefaults` annotation is added to provide default values for `@ClusterTest` defined within +the class. The intention here is to reduce repetitive annotation declarations and also make changing defaults easier +for a class with many test cases. + +# Dynamic Configuration + +In order to allow for more flexible cluster configuration, a `@ClusterTemplate` annotation is also introduced. This +annotation takes a single string value which references a static method on the test class. This method is used to +produce any number of test configurations using a fluent builder style API. + +```java +@ClusterTemplate("generateConfigs") +void testSomething() { ... } + +static void generateConfigs(ClusterGenerator clusterGenerator) { + clusterGenerator.accept(ClusterConfig.defaultClusterBuilder() + .name("Generated Test 1") + .serverProperties(props1) + .ibp("2.7-IV1") + .build()); + clusterGenerator.accept(ClusterConfig.defaultClusterBuilder() + .name("Generated Test 2") + .serverProperties(props2) + .ibp("2.7-IV2") + .build()); + clusterGenerator.accept(ClusterConfig.defaultClusterBuilder() + .name("Generated Test 3") + .serverProperties(props3) + .build()); +} +``` + +This "escape hatch" from the simple declarative style configuration makes it easy to dynamically configure clusters. + + +# JUnit Extension + +One thing to note is that our "test*" methods are no longer _tests_, but rather they are test templates. We have added +a JUnit extension called `ClusterTestExtensions` which knows how to process these annotations in order to generate test +invocations. 
Test classes that wish to make use of these annotations need to explicitly register this extension: + +```scala +import kafka.test.junit.ClusterTestExtensions + +@ExtendWith(value = Array(classOf[ClusterTestExtensions])) +class ApiVersionsRequestTest { + ... +} +``` + +# JUnit Lifecycle + +The lifecycle of a test class that is extended with `ClusterTestExtensions` follows: + +* JUnit discovers test template methods that are annotated with `@ClusterTest`, `@ClusterTests`, or `@ClusterTemplate` +* `ClusterTestExtensions` is called for each of these template methods in order to generate some number of test invocations + +For each generated invocation: +* Static `@BeforeAll` methods are called +* Test class is instantiated +* Non-static `@BeforeEach` methods are called +* Kafka Cluster is started +* Test method is invoked +* Kafka Cluster is stopped +* Non-static `@AfterEach` methods are called +* Static `@AfterAll` methods are called + +`@BeforeEach` methods give an opportunity to setup additional test dependencies before the cluster is started. + +# Dependency Injection + +A few classes are introduced to provide context to the underlying cluster and to provide reusable functionality that was +previously garnered from the test hierarchy. + +* ClusterConfig: a mutable cluster configuration, includes cluster type, number of brokers, properties, etc +* ClusterInstance: a shim to the underlying class that actually runs the cluster, provides access to things like SocketServers +* IntegrationTestHelper: connection related functions taken from IntegrationTestHarness and BaseRequestTest + +In order to have one of these objects injected, simply add it as a parameter to your test class, `@BeforeEach` method, or test method. + +| Injection | Class | BeforeEach | Test | Notes +| --- | --- | --- | --- | --- | +| ClusterConfig | yes | yes | yes* | Once in the test, changing config has no effect | +| ClusterInstance | yes* | no | yes | Injectable at class level for convenience, can only be accessed inside test | +| IntegrationTestHelper | yes | yes | yes | - | + +```scala +@ExtendWith(value = Array(classOf[ClusterTestExtensions])) +class SomeTestClass(helper: IntegrationTestHelper) { + + @BeforeEach + def setup(config: ClusterConfig): Unit = { + config.serverProperties().put("foo", "bar") + } + + @ClusterTest + def testSomething(cluster: ClusterInstance): Unit = { + val topics = cluster.createAdminClient().listTopics() + } +} +``` + +# Gotchas +* Test methods annotated with JUnit's `@Test` will still be run, but no cluster will be started and no dependency + injection will happen. This is generally not what you want. +* Even though ClusterConfig is accessible and mutable inside the test method, changing it will have no effect on the cluster. \ No newline at end of file diff --git a/core/src/test/java/kafka/test/junit/RaftClusterInvocationContext.java b/core/src/test/java/kafka/test/junit/RaftClusterInvocationContext.java new file mode 100644 index 0000000..599bdf0 --- /dev/null +++ b/core/src/test/java/kafka/test/junit/RaftClusterInvocationContext.java @@ -0,0 +1,244 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.test.junit; + +import kafka.network.SocketServer; +import kafka.server.BrokerServer; +import kafka.server.ControllerServer; +import kafka.test.ClusterConfig; +import kafka.test.ClusterInstance; +import kafka.testkit.KafkaClusterTestKit; +import kafka.testkit.TestKitNodes; +import org.apache.kafka.clients.CommonClientConfigs; +import org.apache.kafka.clients.admin.Admin; +import org.apache.kafka.common.network.ListenerName; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.metadata.BrokerState; +import org.junit.jupiter.api.extension.AfterTestExecutionCallback; +import org.junit.jupiter.api.extension.BeforeTestExecutionCallback; +import org.junit.jupiter.api.extension.Extension; +import org.junit.jupiter.api.extension.TestTemplateInvocationContext; + +import java.util.Arrays; +import java.util.Collection; +import java.util.List; +import java.util.Optional; +import java.util.Properties; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +/** + * Wraps a {@link KafkaClusterTestKit} inside lifecycle methods for a test invocation. Each instance of this + * class is provided with a configuration for the cluster. + * + * This context also provides parameter resolvers for: + * + *
+ * <ul>
+ *   <li>ClusterConfig (the same instance passed to the constructor)</li>
+ *   <li>ClusterInstance (includes methods to expose underlying SocketServer-s)</li>
+ *   <li>IntegrationTestHelper (helper methods)</li>
+ * </ul>
                + */ +public class RaftClusterInvocationContext implements TestTemplateInvocationContext { + + private final ClusterConfig clusterConfig; + private final AtomicReference clusterReference; + + public RaftClusterInvocationContext(ClusterConfig clusterConfig) { + this.clusterConfig = clusterConfig; + this.clusterReference = new AtomicReference<>(); + } + + @Override + public String getDisplayName(int invocationIndex) { + String clusterDesc = clusterConfig.nameTags().entrySet().stream() + .map(Object::toString) + .collect(Collectors.joining(", ")); + return String.format("[%d] Type=Raft, %s", invocationIndex, clusterDesc); + } + + @Override + public List getAdditionalExtensions() { + RaftClusterInstance clusterInstance = new RaftClusterInstance(clusterReference, clusterConfig); + return Arrays.asList( + (BeforeTestExecutionCallback) context -> { + TestKitNodes nodes = new TestKitNodes.Builder(). + setNumBrokerNodes(clusterConfig.numBrokers()). + setNumControllerNodes(clusterConfig.numControllers()).build(); + nodes.brokerNodes().forEach((brokerId, brokerNode) -> { + clusterConfig.brokerServerProperties(brokerId).forEach( + (key, value) -> brokerNode.propertyOverrides().put(key.toString(), value.toString())); + }); + KafkaClusterTestKit.Builder builder = new KafkaClusterTestKit.Builder(nodes); + + // Copy properties into the TestKit builder + clusterConfig.serverProperties().forEach((key, value) -> builder.setConfigProp(key.toString(), value.toString())); + // KAFKA-12512 need to pass security protocol and listener name here + KafkaClusterTestKit cluster = builder.build(); + clusterReference.set(cluster); + cluster.format(); + cluster.startup(); + kafka.utils.TestUtils.waitUntilTrue( + () -> cluster.brokers().get(0).brokerState() == BrokerState.RUNNING, + () -> "Broker never made it to RUNNING state.", + org.apache.kafka.test.TestUtils.DEFAULT_MAX_WAIT_MS, + 100L); + }, + (AfterTestExecutionCallback) context -> clusterInstance.stop(), + new ClusterInstanceParameterResolver(clusterInstance), + new GenericParameterResolver<>(clusterConfig, ClusterConfig.class) + ); + } + + public static class RaftClusterInstance implements ClusterInstance { + + private final AtomicReference clusterReference; + private final ClusterConfig clusterConfig; + final AtomicBoolean started = new AtomicBoolean(false); + final AtomicBoolean stopped = new AtomicBoolean(false); + private final ConcurrentLinkedQueue admins = new ConcurrentLinkedQueue<>(); + + RaftClusterInstance(AtomicReference clusterReference, ClusterConfig clusterConfig) { + this.clusterReference = clusterReference; + this.clusterConfig = clusterConfig; + } + + @Override + public String bootstrapServers() { + return clusterReference.get().clientProperties().getProperty(CommonClientConfigs.BOOTSTRAP_SERVERS_CONFIG); + } + + @Override + public Collection brokerSocketServers() { + return brokers() + .map(BrokerServer::socketServer) + .collect(Collectors.toList()); + } + + @Override + public ListenerName clientListener() { + return ListenerName.normalised("EXTERNAL"); + } + + @Override + public Collection controllerSocketServers() { + return controllers() + .map(ControllerServer::socketServer) + .collect(Collectors.toList()); + } + + @Override + public SocketServer anyBrokerSocketServer() { + return brokers() + .map(BrokerServer::socketServer) + .findFirst() + .orElseThrow(() -> new RuntimeException("No broker SocketServers found")); + } + + @Override + public SocketServer anyControllerSocketServer() { + return controllers() + 
.map(ControllerServer::socketServer) + .findFirst() + .orElseThrow(() -> new RuntimeException("No controller SocketServers found")); + } + + @Override + public ClusterType clusterType() { + return ClusterType.RAFT; + } + + @Override + public ClusterConfig config() { + return clusterConfig; + } + + @Override + public KafkaClusterTestKit getUnderlying() { + return clusterReference.get(); + } + + @Override + public Admin createAdminClient(Properties configOverrides) { + Admin admin = Admin.create(clusterReference.get().clientProperties()); + admins.add(admin); + return admin; + } + + @Override + public void start() { + if (started.compareAndSet(false, true)) { + try { + clusterReference.get().startup(); + } catch (Exception e) { + throw new RuntimeException("Failed to start Raft server", e); + } + } + } + + @Override + public void stop() { + if (stopped.compareAndSet(false, true)) { + admins.forEach(admin -> Utils.closeQuietly(admin, "admin")); + Utils.closeQuietly(clusterReference.get(), "cluster"); + } + } + + @Override + public void shutdownBroker(int brokerId) { + findBrokerOrThrow(brokerId).shutdown(); + } + + @Override + public void startBroker(int brokerId) { + findBrokerOrThrow(brokerId).startup(); + } + + @Override + public void waitForReadyBrokers() throws InterruptedException { + try { + clusterReference.get().waitForReadyBrokers(); + } catch (ExecutionException e) { + throw new AssertionError("Failed while waiting for brokers to become ready", e); + } + } + + @Override + public void rollingBrokerRestart() { + throw new UnsupportedOperationException("Restarting Raft servers is not yet supported."); + } + + private BrokerServer findBrokerOrThrow(int brokerId) { + return Optional.ofNullable(clusterReference.get().brokers().get(brokerId)) + .orElseThrow(() -> new IllegalArgumentException("Unknown brokerId " + brokerId)); + } + + private Stream brokers() { + return clusterReference.get().brokers().values().stream(); + } + + private Stream controllers() { + return clusterReference.get().controllers().values().stream(); + } + + } +} diff --git a/core/src/test/java/kafka/test/junit/ZkClusterInvocationContext.java b/core/src/test/java/kafka/test/junit/ZkClusterInvocationContext.java new file mode 100644 index 0000000..4d9c6b7 --- /dev/null +++ b/core/src/test/java/kafka/test/junit/ZkClusterInvocationContext.java @@ -0,0 +1,308 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package kafka.test.junit; + +import kafka.api.IntegrationTestHarness; +import kafka.network.SocketServer; +import kafka.server.KafkaConfig; +import kafka.server.KafkaServer; +import kafka.test.ClusterConfig; +import kafka.test.ClusterInstance; +import kafka.utils.EmptyTestInfo; +import kafka.utils.TestUtils; +import org.apache.kafka.clients.admin.Admin; +import org.apache.kafka.common.network.ListenerName; +import org.apache.kafka.common.security.auth.SecurityProtocol; +import org.junit.jupiter.api.extension.AfterTestExecutionCallback; +import org.junit.jupiter.api.extension.BeforeTestExecutionCallback; +import org.junit.jupiter.api.extension.Extension; +import org.junit.jupiter.api.extension.TestTemplateInvocationContext; +import scala.Option; +import scala.collection.JavaConverters; +import scala.collection.Seq; +import scala.compat.java8.OptionConverters; + +import java.io.File; +import java.util.Arrays; +import java.util.Collection; +import java.util.List; +import java.util.Properties; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +/** + * Wraps a {@link IntegrationTestHarness} inside lifecycle methods for a test invocation. Each instance of this + * class is provided with a configuration for the cluster. + * + * This context also provides parameter resolvers for: + * + *
+ * <ul>
+ *   <li>ClusterConfig (the same instance passed to the constructor)</li>
+ *   <li>ClusterInstance (includes methods to expose underlying SocketServer-s)</li>
+ *   <li>IntegrationTestHelper (helper methods)</li>
+ * </ul>
                + */ +public class ZkClusterInvocationContext implements TestTemplateInvocationContext { + + private final ClusterConfig clusterConfig; + private final AtomicReference clusterReference; + + public ZkClusterInvocationContext(ClusterConfig clusterConfig) { + this.clusterConfig = clusterConfig; + this.clusterReference = new AtomicReference<>(); + } + + @Override + public String getDisplayName(int invocationIndex) { + String clusterDesc = clusterConfig.nameTags().entrySet().stream() + .map(Object::toString) + .collect(Collectors.joining(", ")); + return String.format("[%d] Type=ZK, %s", invocationIndex, clusterDesc); + } + + @Override + public List getAdditionalExtensions() { + if (clusterConfig.numControllers() != 1) { + throw new IllegalArgumentException("For ZK clusters, please specify exactly 1 controller."); + } + ClusterInstance clusterShim = new ZkClusterInstance(clusterConfig, clusterReference); + return Arrays.asList( + (BeforeTestExecutionCallback) context -> { + // We have to wait to actually create the underlying cluster until after our @BeforeEach methods + // have run. This allows tests to set up external dependencies like ZK, MiniKDC, etc. + // However, since we cannot create this instance until we are inside the test invocation, we have + // to use a container class (AtomicReference) to provide this cluster object to the test itself + + // This is what tests normally extend from to start a cluster, here we create it anonymously and + // configure the cluster using values from ClusterConfig + IntegrationTestHarness cluster = new IntegrationTestHarness() { + + @Override + public void modifyConfigs(Seq props) { + super.modifyConfigs(props); + for (int i = 0; i < props.length(); i++) { + props.apply(i).putAll(clusterConfig.brokerServerProperties(i)); + } + } + + @Override + public Properties serverConfig() { + Properties props = clusterConfig.serverProperties(); + clusterConfig.ibp().ifPresent(ibp -> props.put(KafkaConfig.InterBrokerProtocolVersionProp(), ibp)); + return props; + } + + @Override + public Properties adminClientConfig() { + return clusterConfig.adminClientProperties(); + } + + @Override + public Properties consumerConfig() { + return clusterConfig.consumerProperties(); + } + + @Override + public Properties producerConfig() { + return clusterConfig.producerProperties(); + } + + @Override + public SecurityProtocol securityProtocol() { + return clusterConfig.securityProtocol(); + } + + @Override + public ListenerName listenerName() { + return clusterConfig.listenerName().map(ListenerName::normalised) + .orElseGet(() -> ListenerName.forSecurityProtocol(securityProtocol())); + } + + @Override + public Option serverSaslProperties() { + if (clusterConfig.saslServerProperties().isEmpty()) { + return Option.empty(); + } else { + return Option.apply(clusterConfig.saslServerProperties()); + } + } + + @Override + public Option clientSaslProperties() { + if (clusterConfig.saslClientProperties().isEmpty()) { + return Option.empty(); + } else { + return Option.apply(clusterConfig.saslClientProperties()); + } + } + + @Override + public int brokerCount() { + // Controllers are also brokers in zk mode, so just use broker count + return clusterConfig.numBrokers(); + } + + @Override + public Option trustStoreFile() { + return OptionConverters.toScala(clusterConfig.trustStoreFile()); + } + }; + + clusterReference.set(cluster); + if (clusterConfig.isAutoStart()) { + clusterShim.start(); + } + }, + (AfterTestExecutionCallback) context -> clusterShim.stop(), + new 
ClusterInstanceParameterResolver(clusterShim), + new GenericParameterResolver<>(clusterConfig, ClusterConfig.class) + ); + } + + public static class ZkClusterInstance implements ClusterInstance { + + final AtomicReference clusterReference; + final ClusterConfig config; + final AtomicBoolean started = new AtomicBoolean(false); + final AtomicBoolean stopped = new AtomicBoolean(false); + + ZkClusterInstance(ClusterConfig config, AtomicReference clusterReference) { + this.config = config; + this.clusterReference = clusterReference; + } + + @Override + public String bootstrapServers() { + return TestUtils.bootstrapServers(clusterReference.get().servers(), clusterReference.get().listenerName()); + } + + @Override + public Collection brokerSocketServers() { + return servers() + .map(KafkaServer::socketServer) + .collect(Collectors.toList()); + } + + @Override + public ListenerName clientListener() { + return clusterReference.get().listenerName(); + } + + @Override + public Collection controllerSocketServers() { + return servers() + .filter(broker -> broker.kafkaController().isActive()) + .map(KafkaServer::socketServer) + .collect(Collectors.toList()); + } + + @Override + public SocketServer anyBrokerSocketServer() { + return servers() + .map(KafkaServer::socketServer) + .findFirst() + .orElseThrow(() -> new RuntimeException("No broker SocketServers found")); + } + + @Override + public SocketServer anyControllerSocketServer() { + return servers() + .filter(broker -> broker.kafkaController().isActive()) + .map(KafkaServer::socketServer) + .findFirst() + .orElseThrow(() -> new RuntimeException("No broker SocketServers found")); + } + + @Override + public ClusterType clusterType() { + return ClusterType.ZK; + } + + @Override + public ClusterConfig config() { + return config; + } + + @Override + public IntegrationTestHarness getUnderlying() { + return clusterReference.get(); + } + + @Override + public Admin createAdminClient(Properties configOverrides) { + return clusterReference.get().createAdminClient(configOverrides); + } + + @Override + public void start() { + if (started.compareAndSet(false, true)) { + clusterReference.get().setUp(new EmptyTestInfo()); + } + } + + @Override + public void stop() { + if (stopped.compareAndSet(false, true)) { + clusterReference.get().tearDown(); + } + } + + @Override + public void shutdownBroker(int brokerId) { + findBrokerOrThrow(brokerId).shutdown(); + } + + @Override + public void startBroker(int brokerId) { + findBrokerOrThrow(brokerId).startup(); + } + + @Override + public void rollingBrokerRestart() { + if (!started.get()) { + throw new IllegalStateException("Tried to restart brokers but the cluster has not been started!"); + } + for (int i = 0; i < clusterReference.get().brokerCount(); i++) { + clusterReference.get().killBroker(i); + } + clusterReference.get().restartDeadBrokers(true); + } + + @Override + public void waitForReadyBrokers() throws InterruptedException { + org.apache.kafka.test.TestUtils.waitForCondition(() -> { + int numRegisteredBrokers = clusterReference.get().zkClient().getAllBrokersInCluster().size(); + return numRegisteredBrokers == config.numBrokers(); + }, "Timed out while waiting for brokers to become ready"); + } + + private KafkaServer findBrokerOrThrow(int brokerId) { + return servers() + .filter(server -> server.config().brokerId() == brokerId) + .findFirst() + .orElseThrow(() -> new IllegalArgumentException("Unknown brokerId " + brokerId)); + } + + private Stream servers() { + return 
JavaConverters.asJavaCollection(clusterReference.get().servers()).stream(); + } + + } +} diff --git a/core/src/test/java/kafka/testkit/BrokerNode.java b/core/src/test/java/kafka/testkit/BrokerNode.java new file mode 100644 index 0000000..005d498 --- /dev/null +++ b/core/src/test/java/kafka/testkit/BrokerNode.java @@ -0,0 +1,117 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.testkit; + +import org.apache.kafka.common.Uuid; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static java.util.Collections.emptyMap; + +public class BrokerNode implements TestKitNode { + public static class Builder { + private int id = -1; + private Uuid incarnationId = null; + private String metadataDirectory = null; + private List logDataDirectories = null; + + public Builder setId(int id) { + this.id = id; + return this; + } + + public Builder setLogDirectories(List logDataDirectories) { + this.logDataDirectories = logDataDirectories; + return this; + } + + public Builder setMetadataDirectory(String metadataDirectory) { + this.metadataDirectory = metadataDirectory; + return this; + } + + public BrokerNode build() { + if (id == -1) { + throw new RuntimeException("You must set the node id"); + } + if (incarnationId == null) { + incarnationId = Uuid.randomUuid(); + } + if (logDataDirectories == null) { + logDataDirectories = Collections. 
+ singletonList(String.format("broker_%d_data0", id)); + } + if (metadataDirectory == null) { + metadataDirectory = logDataDirectories.get(0); + } + return new BrokerNode(id, incarnationId, metadataDirectory, + logDataDirectories); + } + } + + private final int id; + private final Uuid incarnationId; + private final String metadataDirectory; + private final List logDataDirectories; + private final Map propertyOverrides; + + BrokerNode(int id, + Uuid incarnationId, + String metadataDirectory, + List logDataDirectories) { + this(id, incarnationId, metadataDirectory, logDataDirectories, emptyMap()); + } + + BrokerNode(int id, + Uuid incarnationId, + String metadataDirectory, + List logDataDirectories, + Map propertyOverrides) { + this.id = id; + this.incarnationId = incarnationId; + this.metadataDirectory = metadataDirectory; + this.logDataDirectories = new ArrayList<>(logDataDirectories); + this.propertyOverrides = new HashMap<>(propertyOverrides); + } + + @Override + public int id() { + return id; + } + + public Uuid incarnationId() { + return incarnationId; + } + + @Override + public String metadataDirectory() { + return metadataDirectory; + } + + public List logDataDirectories() { + return logDataDirectories; + } + + public Map propertyOverrides() { + return propertyOverrides; + } +} diff --git a/core/src/test/java/kafka/testkit/ControllerNode.java b/core/src/test/java/kafka/testkit/ControllerNode.java new file mode 100644 index 0000000..be6c806 --- /dev/null +++ b/core/src/test/java/kafka/testkit/ControllerNode.java @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package kafka.testkit;
+
+public class ControllerNode implements TestKitNode {
+    public static class Builder {
+        private int id = -1;
+        private String metadataDirectory = null;
+
+        public Builder setId(int id) {
+            this.id = id;
+            return this;
+        }
+
+        public Builder setMetadataDirectory(String metadataDirectory) {
+            this.metadataDirectory = metadataDirectory;
+            return this;
+        }
+
+        public ControllerNode build() {
+            if (id == -1) {
+                throw new RuntimeException("You must set the node id");
+            }
+            if (metadataDirectory == null) {
+                metadataDirectory = String.format("controller_%d", id);
+            }
+            return new ControllerNode(id, metadataDirectory);
+        }
+    }
+
+    private final int id;
+    private final String metadataDirectory;
+
+    ControllerNode(int id, String metadataDirectory) {
+        this.id = id;
+        this.metadataDirectory = metadataDirectory;
+    }
+
+    @Override
+    public int id() {
+        return id;
+    }
+
+    @Override
+    public String metadataDirectory() {
+        return metadataDirectory;
+    }
+}
diff --git a/core/src/test/java/kafka/testkit/KafkaClusterTestKit.java b/core/src/test/java/kafka/testkit/KafkaClusterTestKit.java
new file mode 100644
index 0000000..6947702
--- /dev/null
+++ b/core/src/test/java/kafka/testkit/KafkaClusterTestKit.java
@@ -0,0 +1,490 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +package kafka.testkit; + +import kafka.raft.KafkaRaftManager; +import kafka.server.BrokerServer; +import kafka.server.ControllerServer; +import kafka.server.KafkaConfig; +import kafka.server.KafkaConfig$; +import kafka.server.KafkaRaftServer; +import kafka.server.MetaProperties; +import kafka.server.Server; +import kafka.tools.StorageTool; +import kafka.utils.Logging; +import org.apache.kafka.clients.CommonClientConfigs; +import org.apache.kafka.common.Node; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.metrics.Metrics; +import org.apache.kafka.common.network.ListenerName; +import org.apache.kafka.common.utils.ThreadUtils; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.controller.Controller; +import org.apache.kafka.metadata.MetadataRecordSerde; +import org.apache.kafka.raft.RaftConfig; +import org.apache.kafka.server.common.ApiMessageAndVersion; +import org.apache.kafka.test.TestUtils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import scala.Option; +import scala.collection.JavaConverters; + +import java.io.ByteArrayOutputStream; +import java.io.File; +import java.io.IOException; +import java.io.PrintStream; +import java.net.InetSocketAddress; +import java.nio.file.Files; +import java.nio.file.Paths; +import java.util.AbstractMap.SimpleImmutableEntry; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; +import java.util.Properties; +import java.util.TreeMap; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import java.util.function.Consumer; +import java.util.stream.Collectors; + + +@SuppressWarnings("deprecation") // Needed for Scala 2.12 compatibility +public class KafkaClusterTestKit implements AutoCloseable { + private final static Logger log = LoggerFactory.getLogger(KafkaClusterTestKit.class); + + /** + * This class manages a future which is completed with the proper value for + * controller.quorum.voters once the randomly assigned ports for all the controllers are + * known. + */ + private static class ControllerQuorumVotersFutureManager implements AutoCloseable { + private final int expectedControllers; + private final CompletableFuture> future = new CompletableFuture<>(); + private final Map controllerPorts = new TreeMap<>(); + + ControllerQuorumVotersFutureManager(int expectedControllers) { + this.expectedControllers = expectedControllers; + } + + synchronized void registerPort(int nodeId, int port) { + controllerPorts.put(nodeId, port); + if (controllerPorts.size() >= expectedControllers) { + future.complete(controllerPorts.entrySet().stream(). 
+ collect(Collectors.toMap( + Map.Entry::getKey, + entry -> new RaftConfig.InetAddressSpec(new InetSocketAddress("localhost", entry.getValue())) + ))); + } + } + + void fail(Throwable e) { + future.completeExceptionally(e); + } + + @Override + public void close() { + future.cancel(true); + } + } + + public static class Builder { + private TestKitNodes nodes; + private Map configProps = new HashMap<>(); + + public Builder(TestKitNodes nodes) { + this.nodes = nodes; + } + + public Builder setConfigProp(String key, String value) { + this.configProps.put(key, value); + return this; + } + + public KafkaClusterTestKit build() throws Exception { + Map controllers = new HashMap<>(); + Map brokers = new HashMap<>(); + Map> raftManagers = new HashMap<>(); + String uninitializedQuorumVotersString = nodes.controllerNodes().keySet().stream(). + map(controllerNode -> String.format("%d@0.0.0.0:0", controllerNode)). + collect(Collectors.joining(",")); + /* + Number of threads = Total number of brokers + Total number of controllers + Total number of Raft Managers + = Total number of brokers + Total number of controllers * 2 + (Raft Manager per broker/controller) + */ + int numOfExecutorThreads = (nodes.brokerNodes().size() + nodes.controllerNodes().size()) * 2; + ExecutorService executorService = null; + ControllerQuorumVotersFutureManager connectFutureManager = + new ControllerQuorumVotersFutureManager(nodes.controllerNodes().size()); + File baseDirectory = null; + + try { + baseDirectory = TestUtils.tempDirectory(); + nodes = nodes.copyWithAbsolutePaths(baseDirectory.getAbsolutePath()); + executorService = Executors.newFixedThreadPool(numOfExecutorThreads, + ThreadUtils.createThreadFactory("KafkaClusterTestKit%d", false)); + for (ControllerNode node : nodes.controllerNodes().values()) { + Map props = new HashMap<>(configProps); + props.put(KafkaConfig$.MODULE$.ProcessRolesProp(), "controller"); + props.put(KafkaConfig$.MODULE$.NodeIdProp(), + Integer.toString(node.id())); + props.put(KafkaConfig$.MODULE$.MetadataLogDirProp(), + node.metadataDirectory()); + props.put(KafkaConfig$.MODULE$.ListenerSecurityProtocolMapProp(), + "CONTROLLER:PLAINTEXT"); + props.put(KafkaConfig$.MODULE$.ListenersProp(), + "CONTROLLER://localhost:0"); + props.put(KafkaConfig$.MODULE$.ControllerListenerNamesProp(), + "CONTROLLER"); + // Note: we can't accurately set controller.quorum.voters yet, since we don't + // yet know what ports each controller will pick. Set it to a dummy string \ + // for now as a placeholder. 
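+                    // (For example, with controller ids 3000 and 3001 the placeholder is
+                    // "3000@0.0.0.0:0,3001@0.0.0.0:0"; ControllerQuorumVotersFutureManager completes the
+                    // future with the real host:port pairs once every controller reports its bound port.)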
+ props.put(RaftConfig.QUORUM_VOTERS_CONFIG, uninitializedQuorumVotersString); + setupNodeDirectories(baseDirectory, node.metadataDirectory(), Collections.emptyList()); + KafkaConfig config = new KafkaConfig(props, false, Option.empty()); + + String threadNamePrefix = String.format("controller%d_", node.id()); + MetaProperties metaProperties = MetaProperties.apply(nodes.clusterId().toString(), node.id()); + TopicPartition metadataPartition = new TopicPartition(KafkaRaftServer.MetadataTopic(), 0); + KafkaRaftManager raftManager = new KafkaRaftManager<>( + metaProperties, config, new MetadataRecordSerde(), metadataPartition, KafkaRaftServer.MetadataTopicId(), + Time.SYSTEM, new Metrics(), Option.apply(threadNamePrefix), connectFutureManager.future); + ControllerServer controller = new ControllerServer( + nodes.controllerProperties(node.id()), + config, + raftManager, + Time.SYSTEM, + new Metrics(), + Option.apply(threadNamePrefix), + connectFutureManager.future + ); + controllers.put(node.id(), controller); + controller.socketServerFirstBoundPortFuture().whenComplete((port, e) -> { + if (e != null) { + connectFutureManager.fail(e); + } else { + connectFutureManager.registerPort(node.id(), port); + } + }); + raftManagers.put(node.id(), raftManager); + } + for (BrokerNode node : nodes.brokerNodes().values()) { + Map props = new HashMap<>(configProps); + props.put(KafkaConfig$.MODULE$.ProcessRolesProp(), "broker"); + props.put(KafkaConfig$.MODULE$.BrokerIdProp(), + Integer.toString(node.id())); + props.put(KafkaConfig$.MODULE$.MetadataLogDirProp(), + node.metadataDirectory()); + props.put(KafkaConfig$.MODULE$.LogDirsProp(), + String.join(",", node.logDataDirectories())); + props.put(KafkaConfig$.MODULE$.ListenerSecurityProtocolMapProp(), + "EXTERNAL:PLAINTEXT,CONTROLLER:PLAINTEXT"); + props.put(KafkaConfig$.MODULE$.ListenersProp(), + "EXTERNAL://localhost:0"); + props.put(KafkaConfig$.MODULE$.InterBrokerListenerNameProp(), + nodes.interBrokerListenerName().value()); + props.put(KafkaConfig$.MODULE$.ControllerListenerNamesProp(), + "CONTROLLER"); + + setupNodeDirectories(baseDirectory, node.metadataDirectory(), + node.logDataDirectories()); + + // Just like above, we set a placeholder voter list here until we + // find out what ports the controllers picked. 
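+                    // (The placeholder voters are resolved at startup through connectFutureManager.future,
+                    // which is passed to the broker's KafkaRaftManager and BrokerServer below;
+                    // node.propertyOverrides() is applied last so a BrokerNode can replace any of the
+                    // defaults set above.)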
+ props.put(RaftConfig.QUORUM_VOTERS_CONFIG, uninitializedQuorumVotersString); + props.putAll(node.propertyOverrides()); + KafkaConfig config = new KafkaConfig(props, false, Option.empty()); + + String threadNamePrefix = String.format("broker%d_", node.id()); + MetaProperties metaProperties = MetaProperties.apply(nodes.clusterId().toString(), node.id()); + TopicPartition metadataPartition = new TopicPartition(KafkaRaftServer.MetadataTopic(), 0); + KafkaRaftManager raftManager = new KafkaRaftManager<>( + metaProperties, config, new MetadataRecordSerde(), metadataPartition, KafkaRaftServer.MetadataTopicId(), + Time.SYSTEM, new Metrics(), Option.apply(threadNamePrefix), connectFutureManager.future); + BrokerServer broker = new BrokerServer( + config, + nodes.brokerProperties(node.id()), + raftManager, + Time.SYSTEM, + new Metrics(), + Option.apply(threadNamePrefix), + JavaConverters.asScalaBuffer(Collections.emptyList()).toSeq(), + connectFutureManager.future, + Server.SUPPORTED_FEATURES() + ); + brokers.put(node.id(), broker); + raftManagers.put(node.id(), raftManager); + } + } catch (Exception e) { + if (executorService != null) { + executorService.shutdownNow(); + executorService.awaitTermination(5, TimeUnit.MINUTES); + } + for (ControllerServer controller : controllers.values()) { + controller.shutdown(); + } + for (BrokerServer brokerServer : brokers.values()) { + brokerServer.shutdown(); + } + for (KafkaRaftManager raftManager : raftManagers.values()) { + raftManager.shutdown(); + } + connectFutureManager.close(); + if (baseDirectory != null) { + Utils.delete(baseDirectory); + } + throw e; + } + return new KafkaClusterTestKit(executorService, nodes, controllers, + brokers, raftManagers, connectFutureManager, baseDirectory); + } + + static private void setupNodeDirectories(File baseDirectory, + String metadataDirectory, + Collection logDataDirectories) throws Exception { + Files.createDirectories(new File(baseDirectory, "local").toPath()); + Files.createDirectories(Paths.get(metadataDirectory)); + for (String logDataDirectory : logDataDirectories) { + Files.createDirectories(Paths.get(logDataDirectory)); + } + } + } + + private final ExecutorService executorService; + private final TestKitNodes nodes; + private final Map controllers; + private final Map brokers; + private final Map> raftManagers; + private final ControllerQuorumVotersFutureManager controllerQuorumVotersFutureManager; + private final File baseDirectory; + + private KafkaClusterTestKit(ExecutorService executorService, + TestKitNodes nodes, + Map controllers, + Map brokers, + Map> raftManagers, + ControllerQuorumVotersFutureManager controllerQuorumVotersFutureManager, + File baseDirectory) { + this.executorService = executorService; + this.nodes = nodes; + this.controllers = controllers; + this.brokers = brokers; + this.raftManagers = raftManagers; + this.controllerQuorumVotersFutureManager = controllerQuorumVotersFutureManager; + this.baseDirectory = baseDirectory; + } + + public void format() throws Exception { + List> futures = new ArrayList<>(); + try { + for (Entry entry : controllers.entrySet()) { + int nodeId = entry.getKey(); + ControllerServer controller = entry.getValue(); + formatNodeAndLog(nodes.controllerProperties(nodeId), controller.config().metadataLogDir(), + controller, futures::add); + } + for (Entry entry : brokers.entrySet()) { + int nodeId = entry.getKey(); + BrokerServer broker = entry.getValue(); + formatNodeAndLog(nodes.brokerProperties(nodeId), broker.config().metadataLogDir(), + broker, 
futures::add); + } + for (Future future: futures) { + future.get(); + } + } catch (Exception e) { + for (Future future: futures) { + future.cancel(true); + } + throw e; + } + } + + private void formatNodeAndLog(MetaProperties properties, String metadataLogDir, Logging loggingMixin, + Consumer> futureConsumer) { + futureConsumer.accept(executorService.submit(() -> { + try (ByteArrayOutputStream stream = new ByteArrayOutputStream()) { + try (PrintStream out = new PrintStream(stream)) { + StorageTool.formatCommand(out, + JavaConverters.asScalaBuffer(Collections.singletonList(metadataLogDir)).toSeq(), + properties, + false); + } finally { + for (String line : stream.toString().split(String.format("%n"))) { + loggingMixin.info(() -> line); + } + } + } catch (IOException e) { + throw new RuntimeException(e); + } + })); + } + + public void startup() throws ExecutionException, InterruptedException { + List> futures = new ArrayList<>(); + try { + for (ControllerServer controller : controllers.values()) { + futures.add(executorService.submit(controller::startup)); + } + for (KafkaRaftManager raftManager : raftManagers.values()) { + futures.add(controllerQuorumVotersFutureManager.future.thenRunAsync(raftManager::startup)); + } + for (BrokerServer broker : brokers.values()) { + futures.add(executorService.submit(broker::startup)); + } + for (Future future: futures) { + future.get(); + } + } catch (Exception e) { + for (Future future: futures) { + future.cancel(true); + } + throw e; + } + } + + /** + * Wait for a controller to mark all the brokers as ready (registered and unfenced). + */ + public void waitForReadyBrokers() throws ExecutionException, InterruptedException { + // We can choose any controller, not just the active controller. + // If we choose a standby controller, we will wait slightly longer. + ControllerServer controllerServer = controllers.values().iterator().next(); + Controller controller = controllerServer.controller(); + controller.waitForReadyBrokers(brokers.size()).get(); + } + + public Properties controllerClientProperties() throws ExecutionException, InterruptedException { + Properties properties = new Properties(); + if (!controllers.isEmpty()) { + Collection controllerNodes = RaftConfig.voterConnectionsToNodes( + controllerQuorumVotersFutureManager.future.get()); + + StringBuilder bld = new StringBuilder(); + String prefix = ""; + for (Node node : controllerNodes) { + bld.append(prefix).append(node.id()).append('@'); + bld.append(node.host()).append(":").append(node.port()); + prefix = ","; + } + properties.setProperty(RaftConfig.QUORUM_VOTERS_CONFIG, bld.toString()); + properties.setProperty(CommonClientConfigs.BOOTSTRAP_SERVERS_CONFIG, + controllerNodes.stream().map(n -> n.host() + ":" + n.port()). + collect(Collectors.joining(","))); + } + return properties; + } + + public Properties clientProperties() { + Properties properties = new Properties(); + if (!brokers.isEmpty()) { + StringBuilder bld = new StringBuilder(); + String prefix = ""; + for (Entry entry : brokers.entrySet()) { + int brokerId = entry.getKey(); + BrokerServer broker = entry.getValue(); + ListenerName listenerName = nodes.externalListenerName(); + int port = broker.boundPort(listenerName); + if (port <= 0) { + throw new RuntimeException("Broker " + brokerId + " does not yet " + + "have a bound port for " + listenerName + ". 
Did you start " + + "the cluster yet?"); + } + bld.append(prefix).append("localhost:").append(port); + prefix = ","; + } + properties.setProperty(CommonClientConfigs.BOOTSTRAP_SERVERS_CONFIG, bld.toString()); + } + return properties; + } + + public Map controllers() { + return controllers; + } + + public Map brokers() { + return brokers; + } + + public Map> raftManagers() { + return raftManagers; + } + + public TestKitNodes nodes() { + return nodes; + } + + @Override + public void close() throws Exception { + List>> futureEntries = new ArrayList<>(); + try { + controllerQuorumVotersFutureManager.close(); + for (Entry entry : brokers.entrySet()) { + int brokerId = entry.getKey(); + BrokerServer broker = entry.getValue(); + futureEntries.add(new SimpleImmutableEntry<>("broker" + brokerId, + executorService.submit(broker::shutdown))); + } + waitForAllFutures(futureEntries); + futureEntries.clear(); + for (Entry entry : controllers.entrySet()) { + int controllerId = entry.getKey(); + ControllerServer controller = entry.getValue(); + futureEntries.add(new SimpleImmutableEntry<>("controller" + controllerId, + executorService.submit(controller::shutdown))); + } + waitForAllFutures(futureEntries); + futureEntries.clear(); + for (Entry> entry : raftManagers.entrySet()) { + int raftManagerId = entry.getKey(); + KafkaRaftManager raftManager = entry.getValue(); + futureEntries.add(new SimpleImmutableEntry<>("raftManager" + raftManagerId, + executorService.submit(raftManager::shutdown))); + } + waitForAllFutures(futureEntries); + futureEntries.clear(); + Utils.delete(baseDirectory); + } catch (Exception e) { + for (Entry> entry : futureEntries) { + entry.getValue().cancel(true); + } + throw e; + } finally { + executorService.shutdownNow(); + executorService.awaitTermination(5, TimeUnit.MINUTES); + } + } + + private void waitForAllFutures(List>> futureEntries) + throws Exception { + for (Entry> entry : futureEntries) { + log.debug("waiting for {} to shut down.", entry.getKey()); + entry.getValue().get(); + log.debug("{} successfully shut down.", entry.getKey()); + } + } +} diff --git a/core/src/test/java/kafka/testkit/TestKitNode.java b/core/src/test/java/kafka/testkit/TestKitNode.java new file mode 100644 index 0000000..a5423d1 --- /dev/null +++ b/core/src/test/java/kafka/testkit/TestKitNode.java @@ -0,0 +1,23 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package kafka.testkit; + +public interface TestKitNode { + int id(); + String metadataDirectory(); +} diff --git a/core/src/test/java/kafka/testkit/TestKitNodes.java b/core/src/test/java/kafka/testkit/TestKitNodes.java new file mode 100644 index 0000000..d52b800 --- /dev/null +++ b/core/src/test/java/kafka/testkit/TestKitNodes.java @@ -0,0 +1,181 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.testkit; + +import kafka.server.MetaProperties; +import org.apache.kafka.common.Uuid; +import org.apache.kafka.common.network.ListenerName; + +import java.nio.file.Paths; +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; +import java.util.NavigableMap; +import java.util.TreeMap; + +public class TestKitNodes { + public static class Builder { + private Uuid clusterId = null; + private final NavigableMap controllerNodes = new TreeMap<>(); + private final NavigableMap brokerNodes = new TreeMap<>(); + + public Builder setClusterId(Uuid clusterId) { + this.clusterId = clusterId; + return this; + } + + public Builder addNodes(TestKitNode[] nodes) { + for (TestKitNode node : nodes) { + addNode(node); + } + return this; + } + + public Builder addNode(TestKitNode node) { + if (node instanceof ControllerNode) { + ControllerNode controllerNode = (ControllerNode) node; + controllerNodes.put(node.id(), controllerNode); + } else if (node instanceof BrokerNode) { + BrokerNode brokerNode = (BrokerNode) node; + brokerNodes.put(node.id(), brokerNode); + } else { + throw new RuntimeException("Can't handle TestKitNode subclass " + + node.getClass().getSimpleName()); + } + return this; + } + + public Builder setNumControllerNodes(int numControllerNodes) { + if (numControllerNodes < 0) { + throw new RuntimeException("Invalid negative value for numControllerNodes"); + } + + while (controllerNodes.size() > numControllerNodes) { + controllerNodes.pollFirstEntry(); + } + while (controllerNodes.size() < numControllerNodes) { + int nextId = 3000; + if (!controllerNodes.isEmpty()) { + nextId = controllerNodes.lastKey() + 1; + } + controllerNodes.put(nextId, new ControllerNode.Builder(). + setId(nextId).build()); + } + return this; + } + + public Builder setNumBrokerNodes(int numBrokerNodes) { + if (numBrokerNodes < 0) { + throw new RuntimeException("Invalid negative value for numBrokerNodes"); + } + while (brokerNodes.size() > numBrokerNodes) { + brokerNodes.pollFirstEntry(); + } + while (brokerNodes.size() < numBrokerNodes) { + int nextId = 0; + if (!brokerNodes.isEmpty()) { + nextId = brokerNodes.lastKey() + 1; + } + brokerNodes.put(nextId, new BrokerNode.Builder(). 
+ setId(nextId).build()); + } + return this; + } + + public TestKitNodes build() { + if (clusterId == null) { + clusterId = Uuid.randomUuid(); + } + return new TestKitNodes(clusterId, controllerNodes, brokerNodes); + } + } + + private final Uuid clusterId; + private final NavigableMap controllerNodes; + private final NavigableMap brokerNodes; + + private TestKitNodes(Uuid clusterId, + NavigableMap controllerNodes, + NavigableMap brokerNodes) { + this.clusterId = clusterId; + this.controllerNodes = controllerNodes; + this.brokerNodes = brokerNodes; + } + + public Uuid clusterId() { + return clusterId; + } + + public Map controllerNodes() { + return controllerNodes; + } + + public NavigableMap brokerNodes() { + return brokerNodes; + } + + public MetaProperties controllerProperties(int id) { + return MetaProperties.apply(clusterId.toString(), id); + } + + public MetaProperties brokerProperties(int id) { + return MetaProperties.apply(clusterId.toString(), id); + } + + public ListenerName interBrokerListenerName() { + return new ListenerName("EXTERNAL"); + } + + public ListenerName externalListenerName() { + return new ListenerName("EXTERNAL"); + } + + public TestKitNodes copyWithAbsolutePaths(String baseDirectory) { + NavigableMap newControllerNodes = new TreeMap<>(); + NavigableMap newBrokerNodes = new TreeMap<>(); + for (Entry entry : controllerNodes.entrySet()) { + ControllerNode node = entry.getValue(); + newControllerNodes.put(entry.getKey(), new ControllerNode(node.id(), + absolutize(baseDirectory, node.metadataDirectory()))); + } + for (Entry entry : brokerNodes.entrySet()) { + BrokerNode node = entry.getValue(); + newBrokerNodes.put(entry.getKey(), new BrokerNode(node.id(), + node.incarnationId(), absolutize(baseDirectory, node.metadataDirectory()), + absolutize(baseDirectory, node.logDataDirectories()), node.propertyOverrides())); + } + return new TestKitNodes(clusterId, newControllerNodes, newBrokerNodes); + } + + private static List absolutize(String base, Collection directories) { + List newDirectories = new ArrayList<>(); + for (String directory : directories) { + newDirectories.add(absolutize(base, directory)); + } + return newDirectories; + } + + private static String absolutize(String base, String directory) { + if (Paths.get(directory).isAbsolute()) { + return directory; + } + return Paths.get(base, directory).toAbsolutePath().toString(); + } +} diff --git a/core/src/test/resources/log4j.properties b/core/src/test/resources/log4j.properties new file mode 100644 index 0000000..f7fb736 --- /dev/null +++ b/core/src/test/resources/log4j.properties @@ -0,0 +1,26 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+log4j.rootLogger=OFF, stdout + +log4j.appender.stdout=org.apache.log4j.ConsoleAppender +log4j.appender.stdout.layout=org.apache.log4j.PatternLayout +log4j.appender.stdout.layout.ConversionPattern=[%d] %p %m (%c:%L)%n + +log4j.logger.kafka=WARN +log4j.logger.org.apache.kafka=WARN + + +# zkclient can be verbose, during debugging it is common to adjust it separately +log4j.logger.org.apache.zookeeper=WARN diff --git a/core/src/test/resources/minikdc-krb5.conf b/core/src/test/resources/minikdc-krb5.conf new file mode 100644 index 0000000..20f1be5 --- /dev/null +++ b/core/src/test/resources/minikdc-krb5.conf @@ -0,0 +1,27 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +[libdefaults] +default_realm = {0} +udp_preference_limit = 1 +default_tkt_enctypes=aes128-cts-hmac-sha1-96 +default_tgs_enctypes=aes128-cts-hmac-sha1-96 + +[realms] +{0} = '{' + kdc = {1}:{2} +'}' diff --git a/core/src/test/resources/minikdc.ldiff b/core/src/test/resources/minikdc.ldiff new file mode 100644 index 0000000..75e4dfd --- /dev/null +++ b/core/src/test/resources/minikdc.ldiff @@ -0,0 +1,47 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +dn: ou=users,dc=${0},dc=${1} +objectClass: organizationalUnit +objectClass: top +ou: users + +dn: uid=krbtgt,ou=users,dc=${0},dc=${1} +objectClass: top +objectClass: person +objectClass: inetOrgPerson +objectClass: krb5principal +objectClass: krb5kdcentry +cn: KDC Service +sn: Service +uid: krbtgt +userPassword: secret +krb5PrincipalName: krbtgt/${2}.${3}@${2}.${3} +krb5KeyVersionNumber: 0 + +dn: uid=ldap,ou=users,dc=${0},dc=${1} +objectClass: top +objectClass: person +objectClass: inetOrgPerson +objectClass: krb5principal +objectClass: krb5kdcentry +cn: LDAP +sn: Service +uid: ldap +userPassword: secret +krb5PrincipalName: ldap/${4}@${2}.${3} +krb5KeyVersionNumber: 0 diff --git a/core/src/test/scala/integration/kafka/admin/BrokerApiVersionsCommandTest.scala b/core/src/test/scala/integration/kafka/admin/BrokerApiVersionsCommandTest.scala new file mode 100644 index 0000000..2db694f --- /dev/null +++ b/core/src/test/scala/integration/kafka/admin/BrokerApiVersionsCommandTest.scala @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.admin + +import java.io.{ByteArrayOutputStream, PrintStream} +import java.nio.charset.StandardCharsets +import scala.collection.Seq +import kafka.integration.KafkaServerTestHarness +import kafka.server.KafkaConfig +import kafka.utils.TestUtils +import org.apache.kafka.clients.NodeApiVersions +import org.apache.kafka.common.protocol.ApiKeys +import org.junit.jupiter.api.Assertions.{assertEquals, assertFalse, assertNotNull, assertTrue} +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.Timeout + +import scala.jdk.CollectionConverters._ + +class BrokerApiVersionsCommandTest extends KafkaServerTestHarness { + + def generateConfigs: Seq[KafkaConfig] = + TestUtils.createBrokerConfigs(1, zkConnect).map(props => { + // Configure control plane listener to make sure we have separate listeners from client, + // in order to avoid returning Envelope API version. 
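+      // (With a dedicated control plane listener the client listener does not advertise ENVELOPE,
+      // so the expected output asserted below only needs to cover ApiKeys.zkBrokerApis.)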
+ props.setProperty(KafkaConfig.ControlPlaneListenerNameProp, "CONTROLLER") + props.setProperty(KafkaConfig.ListenerSecurityProtocolMapProp, "CONTROLLER:PLAINTEXT,PLAINTEXT:PLAINTEXT") + props.setProperty("listeners", "PLAINTEXT://localhost:0,CONTROLLER://localhost:0") + props.setProperty(KafkaConfig.AdvertisedListenersProp, "PLAINTEXT://localhost:0,CONTROLLER://localhost:0") + props + }).map(KafkaConfig.fromProps) + + @Timeout(120) + @Test + def checkBrokerApiVersionCommandOutput(): Unit = { + val byteArrayOutputStream = new ByteArrayOutputStream + val printStream = new PrintStream(byteArrayOutputStream, false, StandardCharsets.UTF_8.name()) + BrokerApiVersionsCommand.execute(Array("--bootstrap-server", brokerList), printStream) + val content = new String(byteArrayOutputStream.toByteArray, StandardCharsets.UTF_8) + val lineIter = content.split("\n").iterator + assertTrue(lineIter.hasNext) + assertEquals(s"$brokerList (id: 0 rack: null) -> (", lineIter.next()) + val nodeApiVersions = NodeApiVersions.create + val enabledApis = ApiKeys.zkBrokerApis.asScala + for (apiKey <- enabledApis) { + val apiVersion = nodeApiVersions.apiVersion(apiKey) + assertNotNull(apiVersion) + + val versionRangeStr = + if (apiVersion.minVersion == apiVersion.maxVersion) apiVersion.minVersion.toString + else s"${apiVersion.minVersion} to ${apiVersion.maxVersion}" + val usableVersion = nodeApiVersions.latestUsableVersion(apiKey) + + val terminator = if (apiKey == enabledApis.last) "" else "," + + val line = s"\t${apiKey.name}(${apiKey.id}): $versionRangeStr [usable: $usableVersion]$terminator" + assertTrue(lineIter.hasNext) + assertEquals(line, lineIter.next()) + } + assertTrue(lineIter.hasNext) + assertEquals(")", lineIter.next()) + assertFalse(lineIter.hasNext) + } +} diff --git a/core/src/test/scala/integration/kafka/admin/ListOffsetsIntegrationTest.scala b/core/src/test/scala/integration/kafka/admin/ListOffsetsIntegrationTest.scala new file mode 100644 index 0000000..6e20188 --- /dev/null +++ b/core/src/test/scala/integration/kafka/admin/ListOffsetsIntegrationTest.scala @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package integration.kafka.admin + +import kafka.integration.KafkaServerTestHarness +import kafka.server.KafkaConfig +import kafka.utils.TestUtils +import org.apache.kafka.clients.admin._ +import org.apache.kafka.clients.producer.ProducerRecord +import org.apache.kafka.common.TopicPartition +import org.apache.kafka.common.utils.Utils +import org.junit.jupiter.api.Assertions.assertEquals +import org.junit.jupiter.api.{AfterEach, BeforeEach, Test, TestInfo} + +import scala.collection.{Map, Seq} +import scala.jdk.CollectionConverters._ + +class ListOffsetsIntegrationTest extends KafkaServerTestHarness { + + val topicName = "foo" + var adminClient: Admin = null + + @BeforeEach + override def setUp(testInfo: TestInfo): Unit = { + super.setUp(testInfo) + createTopic(topicName, 1, 1.toShort) + produceMessages() + adminClient = Admin.create(Map[String, Object]( + AdminClientConfig.BOOTSTRAP_SERVERS_CONFIG -> brokerList + ).asJava) + } + + @AfterEach + override def tearDown(): Unit = { + Utils.closeQuietly(adminClient, "ListOffsetsAdminClient") + super.tearDown() + } + + @Test + def testEarliestOffset(): Unit = { + val earliestOffset = runFetchOffsets(adminClient, OffsetSpec.earliest()) + assertEquals(0, earliestOffset.offset()) + } + + @Test + def testLatestOffset(): Unit = { + val latestOffset = runFetchOffsets(adminClient, OffsetSpec.latest()) + assertEquals(3, latestOffset.offset()) + } + + @Test + def testMaxTimestampOffset(): Unit = { + val maxTimestampOffset = runFetchOffsets(adminClient, OffsetSpec.maxTimestamp()) + assertEquals(1, maxTimestampOffset.offset()) + } + + private def runFetchOffsets(adminClient: Admin, + offsetSpec: OffsetSpec): ListOffsetsResult.ListOffsetsResultInfo = { + val tp = new TopicPartition(topicName, 0) + adminClient.listOffsets(Map( + tp -> offsetSpec + ).asJava, new ListOffsetsOptions()).all().get().get(tp) + } + + def produceMessages(): Unit = { + val records = Seq( + new ProducerRecord[Array[Byte], Array[Byte]](topicName, 0, 100L, + null, new Array[Byte](10000)), + new ProducerRecord[Array[Byte], Array[Byte]](topicName, 0, 999L, + null, new Array[Byte](10000)), + new ProducerRecord[Array[Byte], Array[Byte]](topicName, 0, 200L, + null, new Array[Byte](10000)), + ) + TestUtils.produceMessages(servers, records, -1) + } + + def generateConfigs: Seq[KafkaConfig] = + TestUtils.createBrokerConfigs(1, zkConnect).map(KafkaConfig.fromProps) +} + diff --git a/core/src/test/scala/integration/kafka/admin/ReassignPartitionsIntegrationTest.scala b/core/src/test/scala/integration/kafka/admin/ReassignPartitionsIntegrationTest.scala new file mode 100644 index 0000000..2969a95 --- /dev/null +++ b/core/src/test/scala/integration/kafka/admin/ReassignPartitionsIntegrationTest.scala @@ -0,0 +1,660 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.admin + +import java.io.Closeable +import java.util.{Collections, HashMap, List} +import kafka.admin.ReassignPartitionsCommand._ +import kafka.api.KAFKA_2_7_IV1 +import kafka.server.{IsrChangePropagationConfig, KafkaConfig, KafkaServer, ZkIsrManager} +import kafka.utils.Implicits._ +import kafka.utils.TestUtils +import kafka.server.QuorumTestHarness +import org.apache.kafka.clients.admin.{Admin, AdminClientConfig, AlterConfigOp, ConfigEntry, DescribeLogDirsResult, NewTopic} +import org.apache.kafka.clients.producer.ProducerRecord +import org.apache.kafka.common.config.ConfigResource +import org.apache.kafka.common.network.ListenerName +import org.apache.kafka.common.security.auth.SecurityProtocol +import org.apache.kafka.common.utils.Utils +import org.apache.kafka.common.{TopicPartition, TopicPartitionReplica} +import org.junit.jupiter.api.Assertions.{assertEquals, assertFalse, assertTrue} +import org.junit.jupiter.api.{AfterEach, Test, Timeout} + +import scala.collection.{Map, Seq, mutable} +import scala.jdk.CollectionConverters._ + +@Timeout(300) +class ReassignPartitionsIntegrationTest extends QuorumTestHarness { + + var cluster: ReassignPartitionsTestCluster = null + + @AfterEach + override def tearDown(): Unit = { + Utils.closeQuietly(cluster, "ReassignPartitionsTestCluster") + super.tearDown() + } + + val unthrottledBrokerConfigs = + 0.to(4).map { brokerId => + brokerId -> brokerLevelThrottles.map(throttle => (throttle, -1L)).toMap + }.toMap + + + @Test + def testReassignment(): Unit = { + cluster = new ReassignPartitionsTestCluster(zkConnect) + cluster.setup() + executeAndVerifyReassignment() + } + + @Test + def testReassignmentWithAlterIsrDisabled(): Unit = { + // Test reassignment when the IBP is on an older version which does not use + // the `AlterIsr` API. In this case, the controller will register individual + // watches for each reassigning partition so that the reassignment can be + // completed as soon as the ISR is expanded. + val configOverrides = Map(KafkaConfig.InterBrokerProtocolVersionProp -> KAFKA_2_7_IV1.version) + cluster = new ReassignPartitionsTestCluster(zkConnect, configOverrides = configOverrides) + cluster.setup() + executeAndVerifyReassignment() + } + + @Test + def testReassignmentCompletionDuringPartialUpgrade(): Unit = { + // Test reassignment during a partial upgrade when some brokers are relying on + // `AlterIsr` and some rely on the old notification logic through Zookeeper. + // In this test case, broker 0 starts up first on the latest IBP and is typically + // elected as controller. The three remaining brokers start up on the older IBP. + // We want to ensure that reassignment can still complete through the ISR change + // notification path even though the controller expects `AlterIsr`. 
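+    // (Only brokers 1, 2 and 3 get the old IBP via brokerConfigOverrides below; broker 0 stays on the
+    // latest IBP, so the controller side of the reassignment still expects AlterIsr.)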
+ + // Override change notification settings so that test is not delayed by ISR + // change notification delay + ZkIsrManager.DefaultIsrPropagationConfig = IsrChangePropagationConfig( + checkIntervalMs = 500, + lingerMs = 100, + maxDelayMs = 500 + ) + + val oldIbpConfig = Map(KafkaConfig.InterBrokerProtocolVersionProp -> KAFKA_2_7_IV1.version) + val brokerConfigOverrides = Map(1 -> oldIbpConfig, 2 -> oldIbpConfig, 3 -> oldIbpConfig) + + cluster = new ReassignPartitionsTestCluster(zkConnect, brokerConfigOverrides = brokerConfigOverrides) + cluster.setup() + + executeAndVerifyReassignment() + } + + def executeAndVerifyReassignment(): Unit = { + val assignment = """{"version":1,"partitions":""" + + """[{"topic":"foo","partition":0,"replicas":[0,1,3],"log_dirs":["any","any","any"]},""" + + """{"topic":"bar","partition":0,"replicas":[3,2,0],"log_dirs":["any","any","any"]}""" + + """]}""" + + // Check that the assignment has not yet been started yet. + val initialAssignment = Map( + new TopicPartition("foo", 0) -> + PartitionReassignmentState(Seq(0, 1, 2), Seq(0, 1, 3), true), + new TopicPartition("bar", 0) -> + PartitionReassignmentState(Seq(3, 2, 1), Seq(3, 2, 0), true) + ) + waitForVerifyAssignment(cluster.adminClient, assignment, false, + VerifyAssignmentResult(initialAssignment)) + + // Execute the assignment + runExecuteAssignment(cluster.adminClient, false, assignment, -1L, -1L) + assertEquals(unthrottledBrokerConfigs, describeBrokerLevelThrottles(unthrottledBrokerConfigs.keySet.toSeq)) + val finalAssignment = Map( + new TopicPartition("foo", 0) -> + PartitionReassignmentState(Seq(0, 1, 3), Seq(0, 1, 3), true), + new TopicPartition("bar", 0) -> + PartitionReassignmentState(Seq(3, 2, 0), Seq(3, 2, 0), true) + ) + + val verifyAssignmentResult = runVerifyAssignment(cluster.adminClient, assignment, false) + assertFalse(verifyAssignmentResult.movesOngoing) + + // Wait for the assignment to complete + waitForVerifyAssignment(cluster.adminClient, assignment, false, + VerifyAssignmentResult(finalAssignment)) + + assertEquals(unthrottledBrokerConfigs, + describeBrokerLevelThrottles(unthrottledBrokerConfigs.keySet.toSeq)) + } + + @Test + def testHighWaterMarkAfterPartitionReassignment(): Unit = { + cluster = new ReassignPartitionsTestCluster(zkConnect) + cluster.setup() + val assignment = """{"version":1,"partitions":""" + + """[{"topic":"foo","partition":0,"replicas":[3,1,2],"log_dirs":["any","any","any"]}""" + + """]}""" + + // Set the high water mark of foo-0 to 123 on its leader. + val part = new TopicPartition("foo", 0) + cluster.servers(0).replicaManager.logManager.truncateFullyAndStartAt(part, 123L, false) + + // Execute the assignment + runExecuteAssignment(cluster.adminClient, false, assignment, -1L, -1L) + val finalAssignment = Map(part -> + PartitionReassignmentState(Seq(3, 1, 2), Seq(3, 1, 2), true)) + + // Wait for the assignment to complete + waitForVerifyAssignment(cluster.adminClient, assignment, false, + VerifyAssignmentResult(finalAssignment)) + + TestUtils.waitUntilTrue(() => { + cluster.servers(3).replicaManager.onlinePartition(part). 
+ flatMap(_.leaderLogIfLocal).isDefined + }, "broker 3 should be the new leader", pause = 10L) + assertEquals(123L, cluster.servers(3).replicaManager.localLogOrException(part).highWatermark, + s"Expected broker 3 to have the correct high water mark for the partition.") + } + + @Test + def testAlterReassignmentThrottle(): Unit = { + cluster = new ReassignPartitionsTestCluster(zkConnect) + cluster.setup() + cluster.produceMessages("foo", 0, 50) + cluster.produceMessages("baz", 2, 60) + val assignment = """{"version":1,"partitions": + [{"topic":"foo","partition":0,"replicas":[0,3,2],"log_dirs":["any","any","any"]}, + {"topic":"baz","partition":2,"replicas":[3,2,1],"log_dirs":["any","any","any"]} + ]}""" + + // Execute the assignment with a low throttle + val initialThrottle = 1L + runExecuteAssignment(cluster.adminClient, false, assignment, initialThrottle, -1L) + waitForInterBrokerThrottle(Set(0, 1, 2, 3), initialThrottle) + + // Now update the throttle and verify the reassignment completes + val updatedThrottle = 300000L + runExecuteAssignment(cluster.adminClient, additional = true, assignment, updatedThrottle, -1L) + waitForInterBrokerThrottle(Set(0, 1, 2, 3), updatedThrottle) + + val finalAssignment = Map( + new TopicPartition("foo", 0) -> + PartitionReassignmentState(Seq(0, 3, 2), Seq(0, 3, 2), true), + new TopicPartition("baz", 2) -> + PartitionReassignmentState(Seq(3, 2, 1), Seq(3, 2, 1), true)) + + // Now remove the throttles. + waitForVerifyAssignment(cluster.adminClient, assignment, false, + VerifyAssignmentResult(finalAssignment)) + waitForBrokerLevelThrottles(unthrottledBrokerConfigs) + } + + /** + * Test running a reassignment with the interBrokerThrottle set. + */ + @Test + def testThrottledReassignment(): Unit = { + cluster = new ReassignPartitionsTestCluster(zkConnect) + cluster.setup() + cluster.produceMessages("foo", 0, 50) + cluster.produceMessages("baz", 2, 60) + val assignment = """{"version":1,"partitions":""" + + """[{"topic":"foo","partition":0,"replicas":[0,3,2],"log_dirs":["any","any","any"]},""" + + """{"topic":"baz","partition":2,"replicas":[3,2,1],"log_dirs":["any","any","any"]}""" + + """]}""" + + // Check that the assignment has not yet been started yet. + val initialAssignment = Map( + new TopicPartition("foo", 0) -> + PartitionReassignmentState(Seq(0, 1, 2), Seq(0, 3, 2), true), + new TopicPartition("baz", 2) -> + PartitionReassignmentState(Seq(0, 2, 1), Seq(3, 2, 1), true)) + assertEquals(VerifyAssignmentResult(initialAssignment), runVerifyAssignment(cluster.adminClient, assignment, false)) + assertEquals(unthrottledBrokerConfigs, describeBrokerLevelThrottles(unthrottledBrokerConfigs.keySet.toSeq)) + + // Execute the assignment + val interBrokerThrottle = 300000L + runExecuteAssignment(cluster.adminClient, false, assignment, interBrokerThrottle, -1L) + waitForInterBrokerThrottle(Set(0, 1, 2, 3), interBrokerThrottle) + + val finalAssignment = Map( + new TopicPartition("foo", 0) -> + PartitionReassignmentState(Seq(0, 3, 2), Seq(0, 3, 2), true), + new TopicPartition("baz", 2) -> + PartitionReassignmentState(Seq(3, 2, 1), Seq(3, 2, 1), true)) + + // Wait for the assignment to complete + TestUtils.waitUntilTrue( + () => { + // Check the reassignment status. 
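+        // (preserveThrottles = true here, so the throttle checks in this loop and after the wait
+        // still see the configured inter-broker throttle.)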
+ val result = runVerifyAssignment(cluster.adminClient, assignment, true) + if (!result.partsOngoing) { + true + } else { + assertFalse(result.partStates.forall(_._2.done), s"Expected at least one partition reassignment to be ongoing when result = $result") + assertEquals(Seq(0, 3, 2), result.partStates(new TopicPartition("foo", 0)).targetReplicas) + assertEquals(Seq(3, 2, 1), result.partStates(new TopicPartition("baz", 2)).targetReplicas) + logger.info(s"Current result: ${result}") + waitForInterBrokerThrottle(Set(0, 1, 2, 3), interBrokerThrottle) + false + } + }, "Expected reassignment to complete.") + waitForVerifyAssignment(cluster.adminClient, assignment, true, + VerifyAssignmentResult(finalAssignment)) + // The throttles should still have been preserved, since we ran with --preserve-throttles + waitForInterBrokerThrottle(Set(0, 1, 2, 3), interBrokerThrottle) + // Now remove the throttles. + waitForVerifyAssignment(cluster.adminClient, assignment, false, + VerifyAssignmentResult(finalAssignment)) + waitForBrokerLevelThrottles(unthrottledBrokerConfigs) + } + + @Test + def testProduceAndConsumeWithReassignmentInProgress(): Unit = { + cluster = new ReassignPartitionsTestCluster(zkConnect) + cluster.setup() + cluster.produceMessages("baz", 2, 60) + val assignment = """{"version":1,"partitions":""" + + """[{"topic":"baz","partition":2,"replicas":[3,2,1],"log_dirs":["any","any","any"]}""" + + """]}""" + runExecuteAssignment(cluster.adminClient, false, assignment, 300L, -1L) + cluster.produceMessages("baz", 2, 100) + val consumer = TestUtils.createConsumer(cluster.brokerList) + val part = new TopicPartition("baz", 2) + try { + consumer.assign(Seq(part).asJava) + TestUtils.pollUntilAtLeastNumRecords(consumer, numRecords = 100) + } finally { + consumer.close() + } + TestUtils.removeReplicationThrottleForPartitions(cluster.adminClient, Seq(0,1,2,3), Set(part)) + val finalAssignment = Map(part -> + PartitionReassignmentState(Seq(3, 2, 1), Seq(3, 2, 1), true)) + waitForVerifyAssignment(cluster.adminClient, assignment, false, + VerifyAssignmentResult(finalAssignment)) + } + + /** + * Test running a reassignment and then cancelling it. + */ + @Test + def testCancellation(): Unit = { + cluster = new ReassignPartitionsTestCluster(zkConnect) + cluster.setup() + cluster.produceMessages("foo", 0, 200) + cluster.produceMessages("baz", 1, 200) + val assignment = """{"version":1,"partitions":""" + + """[{"topic":"foo","partition":0,"replicas":[0,1,3],"log_dirs":["any","any","any"]},""" + + """{"topic":"baz","partition":1,"replicas":[0,2,3],"log_dirs":["any","any","any"]}""" + + """]}""" + assertEquals(unthrottledBrokerConfigs, + describeBrokerLevelThrottles(unthrottledBrokerConfigs.keySet.toSeq)) + val interBrokerThrottle = 1L + runExecuteAssignment(cluster.adminClient, false, assignment, interBrokerThrottle, -1L) + waitForInterBrokerThrottle(Set(0, 1, 2, 3), interBrokerThrottle) + + // Verify that the reassignment is running. The very low throttle should keep it + // from completing before this runs. + waitForVerifyAssignment(cluster.adminClient, assignment, true, + VerifyAssignmentResult(Map( + new TopicPartition("foo", 0) -> PartitionReassignmentState(Seq(0, 1, 3, 2), Seq(0, 1, 3), false), + new TopicPartition("baz", 1) -> PartitionReassignmentState(Seq(0, 2, 3, 1), Seq(0, 2, 3), false)), + true, Map(), false)) + // Cancel the reassignment. 
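+    // (runCancelAssignment returns a pair of (partition reassignments cancelled, log-dir replica moves
+    // cancelled); passing preserveThrottles = true keeps the broker throttles, which the next check verifies.)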
+ assertEquals((Set( + new TopicPartition("foo", 0), + new TopicPartition("baz", 1) + ), Set()), runCancelAssignment(cluster.adminClient, assignment, true)) + // Broker throttles are still active because we passed --preserve-throttles + waitForInterBrokerThrottle(Set(0, 1, 2, 3), interBrokerThrottle) + // Cancelling the reassignment again should reveal nothing to cancel. + assertEquals((Set(), Set()), runCancelAssignment(cluster.adminClient, assignment, false)) + // This time, the broker throttles were removed. + waitForBrokerLevelThrottles(unthrottledBrokerConfigs) + // Verify that there are no ongoing reassignments. + assertFalse(runVerifyAssignment(cluster.adminClient, assignment, false).partsOngoing) + } + + private def waitForLogDirThrottle(throttledBrokers: Set[Int], logDirThrottle: Long): Unit = { + val throttledConfigMap = Map[String, Long]( + brokerLevelLeaderThrottle -> -1, + brokerLevelFollowerThrottle -> -1, + brokerLevelLogDirThrottle -> logDirThrottle) + waitForBrokerThrottles(throttledBrokers, throttledConfigMap) + } + + private def waitForInterBrokerThrottle(throttledBrokers: Set[Int], interBrokerThrottle: Long): Unit = { + val throttledConfigMap = Map[String, Long]( + brokerLevelLeaderThrottle -> interBrokerThrottle, + brokerLevelFollowerThrottle -> interBrokerThrottle, + brokerLevelLogDirThrottle -> -1L) + waitForBrokerThrottles(throttledBrokers, throttledConfigMap) + } + + private def waitForBrokerThrottles(throttledBrokers: Set[Int], throttleConfig: Map[String, Long]): Unit = { + val throttledBrokerConfigs = unthrottledBrokerConfigs.map { case (brokerId, unthrottledConfig) => + val expectedThrottleConfig = if (throttledBrokers.contains(brokerId)) { + throttleConfig + } else { + unthrottledConfig + } + brokerId -> expectedThrottleConfig + } + waitForBrokerLevelThrottles(throttledBrokerConfigs) + } + + private def waitForBrokerLevelThrottles(targetThrottles: Map[Int, Map[String, Long]]): Unit = { + var curThrottles: Map[Int, Map[String, Long]] = Map.empty + TestUtils.waitUntilTrue(() => { + curThrottles = describeBrokerLevelThrottles(targetThrottles.keySet.toSeq) + targetThrottles.equals(curThrottles) + }, s"timed out waiting for broker throttle to become ${targetThrottles}. " + + s"Latest throttles were ${curThrottles}", pause = 25) + } + + /** + * Describe the broker-level throttles in the cluster. + * + * @return A map whose keys are broker IDs and whose values are throttle + * information. The nested maps are keyed on throttle name. + */ + private def describeBrokerLevelThrottles(brokerIds: Seq[Int]): Map[Int, Map[String, Long]] = { + brokerIds.map { brokerId => + val props = zkClient.getEntityConfigs("brokers", brokerId.toString) + val throttles = brokerLevelThrottles.map { throttleName => + (throttleName, props.getOrDefault(throttleName, "-1").asInstanceOf[String].toLong) + }.toMap + brokerId -> throttles + }.toMap + } + + /** + * Test moving partitions between directories. + */ + @Test + def testLogDirReassignment(): Unit = { + val topicPartition = new TopicPartition("foo", 0) + + cluster = new ReassignPartitionsTestCluster(zkConnect) + cluster.setup() + cluster.produceMessages(topicPartition.topic, topicPartition.partition, 700) + + val targetBrokerId = 0 + val replicas = Seq(0, 1, 2) + val reassignment = buildLogDirReassignment(topicPartition, targetBrokerId, replicas) + + // Start the replica move, but throttle it to be very slow so that it can't complete + // before our next checks happen. 
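+    // (A log dir throttle of 1 byte/s effectively freezes the intra-broker move, so the --verify check
+    // below can observe an ActiveMoveState before the throttle is removed.)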
+ val logDirThrottle = 1L + runExecuteAssignment(cluster.adminClient, additional = false, reassignment.json, + interBrokerThrottle = -1L, logDirThrottle) + + // Check the output of --verify + waitForVerifyAssignment(cluster.adminClient, reassignment.json, true, + VerifyAssignmentResult(Map( + topicPartition -> PartitionReassignmentState(Seq(0, 1, 2), Seq(0, 1, 2), true) + ), false, Map( + new TopicPartitionReplica(topicPartition.topic, topicPartition.partition, 0) -> + ActiveMoveState(reassignment.currentDir, reassignment.targetDir, reassignment.targetDir) + ), true)) + waitForLogDirThrottle(Set(0), logDirThrottle) + + // Remove the throttle + cluster.adminClient.incrementalAlterConfigs(Collections.singletonMap( + new ConfigResource(ConfigResource.Type.BROKER, "0"), + Collections.singletonList(new AlterConfigOp( + new ConfigEntry(brokerLevelLogDirThrottle, ""), AlterConfigOp.OpType.DELETE)))) + .all().get() + waitForBrokerLevelThrottles(unthrottledBrokerConfigs) + + // Wait for the directory movement to complete. + waitForVerifyAssignment(cluster.adminClient, reassignment.json, true, + VerifyAssignmentResult(Map( + topicPartition -> PartitionReassignmentState(Seq(0, 1, 2), Seq(0, 1, 2), true) + ), false, Map( + new TopicPartitionReplica(topicPartition.topic, topicPartition.partition, 0) -> + CompletedMoveState(reassignment.targetDir) + ), false)) + + val info1 = new BrokerDirs(cluster.adminClient.describeLogDirs(0.to(4). + map(_.asInstanceOf[Integer]).asJavaCollection), 0) + assertEquals(reassignment.targetDir, info1.curLogDirs.getOrElse(topicPartition, "")) + } + + @Test + def testAlterLogDirReassignmentThrottle(): Unit = { + val topicPartition = new TopicPartition("foo", 0) + + cluster = new ReassignPartitionsTestCluster(zkConnect) + cluster.setup() + cluster.produceMessages(topicPartition.topic, topicPartition.partition, 700) + + val targetBrokerId = 0 + val replicas = Seq(0, 1, 2) + val reassignment = buildLogDirReassignment(topicPartition, targetBrokerId, replicas) + + // Start the replica move with a low throttle so it does not complete + val initialLogDirThrottle = 1L + runExecuteAssignment(cluster.adminClient, false, reassignment.json, + interBrokerThrottle = -1L, initialLogDirThrottle) + waitForLogDirThrottle(Set(0), initialLogDirThrottle) + + // Now increase the throttle and verify that the log dir movement completes + val updatedLogDirThrottle = 3000000L + runExecuteAssignment(cluster.adminClient, additional = true, reassignment.json, + interBrokerThrottle = -1L, replicaAlterLogDirsThrottle = updatedLogDirThrottle) + waitForLogDirThrottle(Set(0), updatedLogDirThrottle) + + waitForVerifyAssignment(cluster.adminClient, reassignment.json, true, + VerifyAssignmentResult(Map( + topicPartition -> PartitionReassignmentState(Seq(0, 1, 2), Seq(0, 1, 2), true) + ), false, Map( + new TopicPartitionReplica(topicPartition.topic, topicPartition.partition, targetBrokerId) -> + CompletedMoveState(reassignment.targetDir) + ), false)) + } + + case class LogDirReassignment(json: String, currentDir: String, targetDir: String) + + private def buildLogDirReassignment(topicPartition: TopicPartition, + brokerId: Int, + replicas: Seq[Int]): LogDirReassignment = { + + val describeLogDirsResult = cluster.adminClient.describeLogDirs( + 0.to(4).map(_.asInstanceOf[Integer]).asJavaCollection) + + val logDirInfo = new BrokerDirs(describeLogDirsResult, brokerId) + assertTrue(logDirInfo.futureLogDirs.isEmpty) + + val currentDir = logDirInfo.curLogDirs(topicPartition) + val newDir = 
logDirInfo.logDirs.find(!_.equals(currentDir)).get + + val logDirs = replicas.map { replicaId => + if (replicaId == brokerId) + s""""$newDir"""" + else + "\"any\"" + } + + val reassignmentJson = + s""" + | { "version": 1, + | "partitions": [ + | { + | "topic": "${topicPartition.topic}", + | "partition": ${topicPartition.partition}, + | "replicas": [${replicas.mkString(",")}], + | "log_dirs": [${logDirs.mkString(",")}] + | } + | ] + | } + |""".stripMargin + + LogDirReassignment(reassignmentJson, currentDir = currentDir, targetDir = newDir) + } + + private def runVerifyAssignment(adminClient: Admin, jsonString: String, + preserveThrottles: Boolean) = { + println(s"==> verifyAssignment(adminClient, jsonString=${jsonString})") + verifyAssignment(adminClient, jsonString, preserveThrottles) + } + + private def waitForVerifyAssignment(adminClient: Admin, + jsonString: String, + preserveThrottles: Boolean, + expectedResult: VerifyAssignmentResult): Unit = { + var latestResult: VerifyAssignmentResult = null + TestUtils.waitUntilTrue( + () => { + latestResult = runVerifyAssignment(adminClient, jsonString, preserveThrottles) + expectedResult.equals(latestResult) + }, s"Timed out waiting for verifyAssignment result ${expectedResult}. " + + s"The latest result was ${latestResult}", pause = 10L) + } + + private def runExecuteAssignment(adminClient: Admin, + additional: Boolean, + reassignmentJson: String, + interBrokerThrottle: Long, + replicaAlterLogDirsThrottle: Long) = { + println(s"==> executeAssignment(adminClient, additional=${additional}, " + + s"reassignmentJson=${reassignmentJson}, " + + s"interBrokerThrottle=${interBrokerThrottle}, " + + s"replicaAlterLogDirsThrottle=${replicaAlterLogDirsThrottle}))") + executeAssignment(adminClient, additional, reassignmentJson, + interBrokerThrottle, replicaAlterLogDirsThrottle) + } + + private def runCancelAssignment(adminClient: Admin, jsonString: String, + preserveThrottles: Boolean) = { + println(s"==> cancelAssignment(adminClient, jsonString=${jsonString})") + cancelAssignment(adminClient, jsonString, preserveThrottles) + } + + class BrokerDirs(result: DescribeLogDirsResult, val brokerId: Int) { + val logDirs = new mutable.HashSet[String] + val curLogDirs = new mutable.HashMap[TopicPartition, String] + val futureLogDirs = new mutable.HashMap[TopicPartition, String] + result.descriptions.get(brokerId).get().forEach { + case (logDirName, logDirInfo) => { + logDirs.add(logDirName) + logDirInfo.replicaInfos.forEach { + case (part, info) => + if (info.isFuture) { + futureLogDirs.put(part, logDirName) + } else { + curLogDirs.put(part, logDirName) + } + } + } + } + } + + class ReassignPartitionsTestCluster( + val zkConnect: String, + configOverrides: Map[String, String] = Map.empty, + brokerConfigOverrides: Map[Int, Map[String, String]] = Map.empty + ) extends Closeable { + val brokers = Map( + 0 -> "rack0", + 1 -> "rack0", + 2 -> "rack1", + 3 -> "rack1", + 4 -> "rack1" + ) + + val topics = Map( + "foo" -> Seq(Seq(0, 1, 2), Seq(1, 2, 3)), + "bar" -> Seq(Seq(3, 2, 1)), + "baz" -> Seq(Seq(1, 0, 2), Seq(2, 0, 1), Seq(0, 2, 1)) + ) + + val brokerConfigs = brokers.map { + case (brokerId, rack) => + val config = TestUtils.createBrokerConfig( + nodeId = brokerId, + zkConnect = zkConnect, + rack = Some(rack), + enableControlledShutdown = false, // shorten test time + logDirCount = 3) + // shorter backoff to reduce test durations when no active partitions are eligible for fetching due to throttling + config.setProperty(KafkaConfig.ReplicaFetchBackoffMsProp, "100") + // 
Don't move partition leaders automatically. + config.setProperty(KafkaConfig.AutoLeaderRebalanceEnableProp, "false") + config.setProperty(KafkaConfig.ReplicaLagTimeMaxMsProp, "1000") + configOverrides.forKeyValue(config.setProperty) + + brokerConfigOverrides.get(brokerId).foreach { overrides => + overrides.forKeyValue(config.setProperty) + } + + config + }.toBuffer + + var servers = new mutable.ArrayBuffer[KafkaServer] + + var brokerList: String = null + + var adminClient: Admin = null + + def setup(): Unit = { + createServers() + createTopics() + } + + def createServers(): Unit = { + brokers.keySet.foreach { brokerId => + servers += TestUtils.createServer(KafkaConfig(brokerConfigs(brokerId))) + } + } + + def createTopics(): Unit = { + TestUtils.waitUntilBrokerMetadataIsPropagated(servers) + brokerList = TestUtils.bootstrapServers(servers, + ListenerName.forSecurityProtocol(SecurityProtocol.PLAINTEXT)) + adminClient = Admin.create(Map[String, Object]( + AdminClientConfig.BOOTSTRAP_SERVERS_CONFIG -> brokerList + ).asJava) + adminClient.createTopics(topics.map { + case (topicName, parts) => + val partMap = new HashMap[Integer, List[Integer]]() + parts.zipWithIndex.foreach { + case (part, index) => partMap.put(index, part.map(Integer.valueOf(_)).asJava) + } + new NewTopic(topicName, partMap) + }.toList.asJava).all().get() + topics.foreach { + case (topicName, parts) => + TestUtils.waitForAllPartitionsMetadata(servers, topicName, parts.size) + } + } + + def produceMessages(topic: String, partition: Int, numMessages: Int): Unit = { + val records = (0 until numMessages).map(_ => + new ProducerRecord[Array[Byte], Array[Byte]](topic, partition, + null, new Array[Byte](10000))) + TestUtils.produceMessages(servers, records, -1) + } + + override def close(): Unit = { + brokerList = null + Utils.closeQuietly(adminClient, "adminClient") + adminClient = null + try { + TestUtils.shutdownServers(servers) + } finally { + servers.clear() + } + } + } +} diff --git a/core/src/test/scala/integration/kafka/api/AbstractConsumerTest.scala b/core/src/test/scala/integration/kafka/api/AbstractConsumerTest.scala new file mode 100644 index 0000000..56bc47c --- /dev/null +++ b/core/src/test/scala/integration/kafka/api/AbstractConsumerTest.scala @@ -0,0 +1,480 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package kafka.api + +import java.time.Duration +import java.util +import java.util.Properties + +import org.apache.kafka.clients.consumer._ +import org.apache.kafka.clients.producer.{ProducerConfig, ProducerRecord} +import org.apache.kafka.common.record.TimestampType +import org.apache.kafka.common.TopicPartition +import kafka.utils.{ShutdownableThread, TestUtils} +import kafka.server.{BaseRequestTest, KafkaConfig} +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.{BeforeEach, TestInfo} + +import scala.jdk.CollectionConverters._ +import scala.collection.mutable.{ArrayBuffer, Buffer} +import org.apache.kafka.clients.producer.KafkaProducer +import org.apache.kafka.common.errors.WakeupException + +import scala.collection.mutable + +/** + * Extension point for consumer integration tests. + */ +abstract class AbstractConsumerTest extends BaseRequestTest { + + val epsilon = 0.1 + override def brokerCount: Int = 3 + + val topic = "topic" + val part = 0 + val tp = new TopicPartition(topic, part) + val part2 = 1 + val tp2 = new TopicPartition(topic, part2) + val group = "my-test" + val producerClientId = "ConsumerTestProducer" + val consumerClientId = "ConsumerTestConsumer" + val groupMaxSessionTimeoutMs = 60000L + + this.producerConfig.setProperty(ProducerConfig.ACKS_CONFIG, "all") + this.producerConfig.setProperty(ProducerConfig.CLIENT_ID_CONFIG, producerClientId) + this.consumerConfig.setProperty(ConsumerConfig.CLIENT_ID_CONFIG, consumerClientId) + this.consumerConfig.setProperty(ConsumerConfig.GROUP_ID_CONFIG, group) + this.consumerConfig.setProperty(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "earliest") + this.consumerConfig.setProperty(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG, "false") + this.consumerConfig.setProperty(ConsumerConfig.METADATA_MAX_AGE_CONFIG, "100") + this.consumerConfig.setProperty(ConsumerConfig.MAX_POLL_INTERVAL_MS_CONFIG, "6000") + + + override protected def brokerPropertyOverrides(properties: Properties): Unit = { + properties.setProperty(KafkaConfig.ControlledShutdownEnableProp, "false") // speed up shutdown + properties.setProperty(KafkaConfig.OffsetsTopicReplicationFactorProp, "3") // don't want to lose offset + properties.setProperty(KafkaConfig.OffsetsTopicPartitionsProp, "1") + properties.setProperty(KafkaConfig.GroupMinSessionTimeoutMsProp, "100") // set small enough session timeout + properties.setProperty(KafkaConfig.GroupMaxSessionTimeoutMsProp, groupMaxSessionTimeoutMs.toString) + properties.setProperty(KafkaConfig.GroupInitialRebalanceDelayMsProp, "10") + } + + @BeforeEach + override def setUp(testInfo: TestInfo): Unit = { + super.setUp(testInfo) + + // create the test topic with all the brokers as replicas + createTopic(topic, 2, brokerCount) + } + + protected class TestConsumerReassignmentListener extends ConsumerRebalanceListener { + var callsToAssigned = 0 + var callsToRevoked = 0 + + def onPartitionsAssigned(partitions: java.util.Collection[TopicPartition]): Unit = { + info("onPartitionsAssigned called.") + callsToAssigned += 1 + } + + def onPartitionsRevoked(partitions: java.util.Collection[TopicPartition]): Unit = { + info("onPartitionsRevoked called.") + callsToRevoked += 1 + } + } + + protected def createConsumerWithGroupId(groupId: String): KafkaConsumer[Array[Byte], Array[Byte]] = { + val groupOverrideConfig = new Properties + groupOverrideConfig.setProperty(ConsumerConfig.GROUP_ID_CONFIG, groupId) + createConsumer(configOverrides = groupOverrideConfig) + } + + protected def sendRecords(producer: KafkaProducer[Array[Byte], 
Array[Byte]], numRecords: Int, + tp: TopicPartition, + startingTimestamp: Long = System.currentTimeMillis()): Seq[ProducerRecord[Array[Byte], Array[Byte]]] = { + val records = (0 until numRecords).map { i => + val timestamp = startingTimestamp + i.toLong + val record = new ProducerRecord(tp.topic(), tp.partition(), timestamp, s"key $i".getBytes, s"value $i".getBytes) + producer.send(record) + record + } + producer.flush() + + records + } + + protected def consumeAndVerifyRecords(consumer: Consumer[Array[Byte], Array[Byte]], + numRecords: Int, + startingOffset: Int, + startingKeyAndValueIndex: Int = 0, + startingTimestamp: Long = 0L, + timestampType: TimestampType = TimestampType.CREATE_TIME, + tp: TopicPartition = tp, + maxPollRecords: Int = Int.MaxValue): Unit = { + val records = consumeRecords(consumer, numRecords, maxPollRecords = maxPollRecords) + val now = System.currentTimeMillis() + for (i <- 0 until numRecords) { + val record = records(i) + val offset = startingOffset + i + assertEquals(tp.topic, record.topic) + assertEquals(tp.partition, record.partition) + if (timestampType == TimestampType.CREATE_TIME) { + assertEquals(timestampType, record.timestampType) + val timestamp = startingTimestamp + i + assertEquals(timestamp.toLong, record.timestamp) + } else + assertTrue(record.timestamp >= startingTimestamp && record.timestamp <= now, + s"Got unexpected timestamp ${record.timestamp}. Timestamp should be between [$startingTimestamp, $now}]") + assertEquals(offset.toLong, record.offset) + val keyAndValueIndex = startingKeyAndValueIndex + i + assertEquals(s"key $keyAndValueIndex", new String(record.key)) + assertEquals(s"value $keyAndValueIndex", new String(record.value)) + // this is true only because K and V are byte arrays + assertEquals(s"key $keyAndValueIndex".length, record.serializedKeySize) + assertEquals(s"value $keyAndValueIndex".length, record.serializedValueSize) + } + } + + protected def consumeRecords[K, V](consumer: Consumer[K, V], + numRecords: Int, + maxPollRecords: Int = Int.MaxValue): ArrayBuffer[ConsumerRecord[K, V]] = { + val records = new ArrayBuffer[ConsumerRecord[K, V]] + def pollAction(polledRecords: ConsumerRecords[K, V]): Boolean = { + assertTrue(polledRecords.asScala.size <= maxPollRecords) + records ++= polledRecords.asScala + records.size >= numRecords + } + TestUtils.pollRecordsUntilTrue(consumer, pollAction, waitTimeMs = 60000, + msg = s"Timed out before consuming expected $numRecords records. 
" + + s"The number consumed was ${records.size}.") + records + } + + + /** + * Creates topic 'topicName' with 'numPartitions' partitions and produces 'recordsPerPartition' + * records to each partition + */ + protected def createTopicAndSendRecords(producer: KafkaProducer[Array[Byte], Array[Byte]], + topicName: String, + numPartitions: Int, + recordsPerPartition: Int): Set[TopicPartition] = { + createTopic(topicName, numPartitions, brokerCount) + var parts = Set[TopicPartition]() + for (partition <- 0 until numPartitions) { + val tp = new TopicPartition(topicName, partition) + sendRecords(producer, recordsPerPartition, tp) + parts = parts + tp + } + parts + } + + protected def sendAndAwaitAsyncCommit[K, V](consumer: Consumer[K, V], + offsetsOpt: Option[Map[TopicPartition, OffsetAndMetadata]] = None): Unit = { + + def sendAsyncCommit(callback: OffsetCommitCallback) = { + offsetsOpt match { + case Some(offsets) => consumer.commitAsync(offsets.asJava, callback) + case None => consumer.commitAsync(callback) + } + } + + class RetryCommitCallback extends OffsetCommitCallback { + var isComplete = false + var error: Option[Exception] = None + + override def onComplete(offsets: util.Map[TopicPartition, OffsetAndMetadata], exception: Exception): Unit = { + exception match { + case e: RetriableCommitFailedException => + sendAsyncCommit(this) + case e => + isComplete = true + error = Option(e) + } + } + } + + val commitCallback = new RetryCommitCallback + + sendAsyncCommit(commitCallback) + TestUtils.pollUntilTrue(consumer, () => commitCallback.isComplete, + "Failed to observe commit callback before timeout", waitTimeMs = 10000) + + assertEquals(None, commitCallback.error) + } + + /** + * Create 'numOfConsumersToAdd' consumers add then to the consumer group 'consumerGroup', and create corresponding + * pollers for these consumers. Wait for partition re-assignment and validate. + * + * Currently, assignment validation requires that total number of partitions is greater or equal to + * number of consumers, so subscriptions.size must be greater or equal the resulting number of consumers in the group + * + * @param numOfConsumersToAdd number of consumers to create and add to the consumer group + * @param consumerGroup current consumer group + * @param consumerPollers current consumer pollers + * @param topicsToSubscribe topics to which new consumers will subscribe to + * @param subscriptions set of all topic partitions + */ + def addConsumersToGroupAndWaitForGroupAssignment(numOfConsumersToAdd: Int, + consumerGroup: mutable.Buffer[KafkaConsumer[Array[Byte], Array[Byte]]], + consumerPollers: mutable.Buffer[ConsumerAssignmentPoller], + topicsToSubscribe: List[String], + subscriptions: Set[TopicPartition], + group: String = group): (mutable.Buffer[KafkaConsumer[Array[Byte], Array[Byte]]], mutable.Buffer[ConsumerAssignmentPoller]) = { + assertTrue(consumerGroup.size + numOfConsumersToAdd <= subscriptions.size) + addConsumersToGroup(numOfConsumersToAdd, consumerGroup, consumerPollers, topicsToSubscribe, subscriptions, group) + // wait until topics get re-assigned and validate assignment + validateGroupAssignment(consumerPollers, subscriptions) + + (consumerGroup, consumerPollers) + } + + /** + * Create 'numOfConsumersToAdd' consumers add then to the consumer group 'consumerGroup', and create corresponding + * pollers for these consumers. 
+ * + * + * @param numOfConsumersToAdd number of consumers to create and add to the consumer group + * @param consumerGroup current consumer group + * @param consumerPollers current consumer pollers + * @param topicsToSubscribe topics to which new consumers will subscribe to + * @param subscriptions set of all topic partitions + */ + def addConsumersToGroup(numOfConsumersToAdd: Int, + consumerGroup: mutable.Buffer[KafkaConsumer[Array[Byte], Array[Byte]]], + consumerPollers: mutable.Buffer[ConsumerAssignmentPoller], + topicsToSubscribe: List[String], + subscriptions: Set[TopicPartition], + group: String = group): (mutable.Buffer[KafkaConsumer[Array[Byte], Array[Byte]]], mutable.Buffer[ConsumerAssignmentPoller]) = { + for (_ <- 0 until numOfConsumersToAdd) { + val consumer = createConsumerWithGroupId(group) + consumerGroup += consumer + consumerPollers += subscribeConsumerAndStartPolling(consumer, topicsToSubscribe) + } + + (consumerGroup, consumerPollers) + } + + /** + * Wait for consumers to get partition assignment and validate it. + * + * @param consumerPollers consumer pollers corresponding to the consumer group we are testing + * @param subscriptions set of all topic partitions + * @param msg message to print when waiting for/validating assignment fails + */ + def validateGroupAssignment(consumerPollers: mutable.Buffer[ConsumerAssignmentPoller], + subscriptions: Set[TopicPartition], + msg: Option[String] = None, + waitTime: Long = 10000L, + expectedAssignment: Buffer[Set[TopicPartition]] = Buffer()): Unit = { + val assignments = mutable.Buffer[Set[TopicPartition]]() + TestUtils.waitUntilTrue(() => { + assignments.clear() + consumerPollers.foreach(assignments += _.consumerAssignment()) + isPartitionAssignmentValid(assignments, subscriptions, expectedAssignment) + }, msg.getOrElse(s"Did not get valid assignment for partitions $subscriptions. Instead, got $assignments"), waitTime) + } + + /** + * Subscribes consumer 'consumer' to a given list of topics 'topicsToSubscribe', creates + * consumer poller and starts polling. + * Assumes that the consumer is not subscribed to any topics yet + * + * @param consumer consumer + * @param topicsToSubscribe topics that this consumer will subscribe to + * @return consumer poller for the given consumer + */ + def subscribeConsumerAndStartPolling(consumer: Consumer[Array[Byte], Array[Byte]], + topicsToSubscribe: List[String], + partitionsToAssign: Set[TopicPartition] = Set.empty[TopicPartition]): ConsumerAssignmentPoller = { + assertEquals(0, consumer.assignment().size) + val consumerPoller = if (topicsToSubscribe.nonEmpty) + new ConsumerAssignmentPoller(consumer, topicsToSubscribe) + else + new ConsumerAssignmentPoller(consumer, partitionsToAssign) + + consumerPoller.start() + consumerPoller + } + + protected def awaitRebalance(consumer: Consumer[_, _], rebalanceListener: TestConsumerReassignmentListener): Unit = { + val numReassignments = rebalanceListener.callsToAssigned + TestUtils.pollUntilTrue(consumer, () => rebalanceListener.callsToAssigned > numReassignments, + "Timed out before expected rebalance completed") + } + + protected def ensureNoRebalance(consumer: Consumer[_, _], rebalanceListener: TestConsumerReassignmentListener): Unit = { + // The best way to verify that the current membership is still active is to commit offsets. + // This would fail if the group had rebalanced. 
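+    // A rebalance would also invoke onPartitionsRevoked, so the revocation count must not
+    // change across the commit.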
+    val initialRevokeCalls = rebalanceListener.callsToRevoked
+    sendAndAwaitAsyncCommit(consumer)
+    assertEquals(initialRevokeCalls, rebalanceListener.callsToRevoked)
+  }
+
+  protected class CountConsumerCommitCallback extends OffsetCommitCallback {
+    var successCount = 0
+    var failCount = 0
+    var lastError: Option[Exception] = None
+
+    override def onComplete(offsets: util.Map[TopicPartition, OffsetAndMetadata], exception: Exception): Unit = {
+      if (exception == null) {
+        successCount += 1
+      } else {
+        failCount += 1
+        lastError = Some(exception)
+      }
+    }
+  }
+
+  protected class ConsumerAssignmentPoller(consumer: Consumer[Array[Byte], Array[Byte]],
+                                           topicsToSubscribe: List[String],
+                                           partitionsToAssign: Set[TopicPartition])
+    extends ShutdownableThread("daemon-consumer-assignment", false) {
+
+    def this(consumer: Consumer[Array[Byte], Array[Byte]], topicsToSubscribe: List[String]) = {
+      this(consumer, topicsToSubscribe, Set.empty[TopicPartition])
+    }
+
+    def this(consumer: Consumer[Array[Byte], Array[Byte]], partitionsToAssign: Set[TopicPartition]) = {
+      this(consumer, List.empty[String], partitionsToAssign)
+    }
+
+    @volatile var thrownException: Option[Throwable] = None
+    @volatile var receivedMessages = 0
+
+    private val partitionAssignment = mutable.Set[TopicPartition]()
+    @volatile private var subscriptionChanged = false
+    private var topicsSubscription = topicsToSubscribe
+
+    val rebalanceListener: ConsumerRebalanceListener = new ConsumerRebalanceListener {
+      override def onPartitionsAssigned(partitions: util.Collection[TopicPartition]) = {
+        partitionAssignment ++= partitions.toArray(new Array[TopicPartition](0))
+      }
+
+      override def onPartitionsRevoked(partitions: util.Collection[TopicPartition]) = {
+        partitionAssignment --= partitions.toArray(new Array[TopicPartition](0))
+      }
+    }
+
+    if (partitionsToAssign.isEmpty) {
+      consumer.subscribe(topicsToSubscribe.asJava, rebalanceListener)
+    } else {
+      consumer.assign(partitionsToAssign.asJava)
+    }
+
+    def consumerAssignment(): Set[TopicPartition] = {
+      partitionAssignment.toSet
+    }
+
+    /**
+     * Subscribe the consumer to a new set of topics.
+     * Since this method will most likely be called from a different thread, it only
+     * "schedules" the subscription change; the actual call to consumer.subscribe is made
+     * in the doWork() method.
+     *
+     * This method does not allow the subscription to be changed until doWork has processed the previous call
+     * to this method.
This restriction exists only to avoid race conditions and provides enough functionality for testing purposes.
+     * @param newTopicsToSubscribe the new set of topics that the consumer should subscribe to
+     */
+    def subscribe(newTopicsToSubscribe: List[String]): Unit = {
+      if (subscriptionChanged)
+        throw new IllegalStateException("Do not call subscribe until the previous subscribe request is processed.")
+      if (partitionsToAssign.nonEmpty)
+        throw new IllegalStateException("Cannot call subscribe when configured to use manual partition assignment")
+
+      topicsSubscription = newTopicsToSubscribe
+      subscriptionChanged = true
+    }
+
+    def isSubscribeRequestProcessed: Boolean = {
+      !subscriptionChanged
+    }
+
+    override def initiateShutdown(): Boolean = {
+      val res = super.initiateShutdown()
+      consumer.wakeup()
+      res
+    }
+
+    override def doWork(): Unit = {
+      if (subscriptionChanged) {
+        consumer.subscribe(topicsSubscription.asJava, rebalanceListener)
+        subscriptionChanged = false
+      }
+      try {
+        receivedMessages += consumer.poll(Duration.ofMillis(50)).count()
+      } catch {
+        case _: WakeupException => // ignore for shutdown
+        case e: Throwable =>
+          thrownException = Some(e)
+          throw e
+      }
+    }
+  }
+
+  /**
+   * Check whether a partition assignment is valid.
+   * A partition assignment is considered valid iff:
+   * 1. Every consumer got assigned at least one partition
+   * 2. Each partition is assigned to only one consumer
+   * 3. Every partition is assigned to one of the consumers
+   * 4. The assignment is the same as the expected assignment (if provided)
+   *
+   * @param assignments set of consumer assignments; one per consumer
+   * @param partitions set of partitions that consumers subscribed to
+   * @param expectedAssignment expected assignment, one set per consumer; empty if any valid assignment is acceptable
+   * @return true if partition assignment is valid
+   */
+  def isPartitionAssignmentValid(assignments: Buffer[Set[TopicPartition]],
+                                 partitions: Set[TopicPartition],
+                                 expectedAssignment: Buffer[Set[TopicPartition]]): Boolean = {
+    val allNonEmptyAssignments = assignments.forall(assignment => assignment.nonEmpty)
+    if (!allNonEmptyAssignments) {
+      // at least one consumer got empty assignment
+      return false
+    }
+
+    // make sure that sum of all partitions to all consumers equals total number of partitions
+    val totalPartitionsInAssignments = assignments.foldLeft(0)(_ + _.size)
+    if (totalPartitionsInAssignments != partitions.size) {
+      // either same partitions got assigned to more than one consumer or some
+      // partitions were not assigned
+      return false
+    }
+
+    // The above checks could miss the case where one or more partitions were assigned to more
+    // than one consumer and the same number of partitions were missing from assignments.
+ // Make sure that all unique assignments are the same as 'partitions' + val uniqueAssignedPartitions = assignments.foldLeft(Set.empty[TopicPartition])(_ ++ _) + if (uniqueAssignedPartitions != partitions) { + return false + } + + // check the assignment is the same as the expected assignment if provided + // Note: since we've checked that each partition is assigned to only one consumer, + // we just need to check the assignment is included in the expected assignment + if (expectedAssignment.nonEmpty) { + for (assignment <- assignments) { + if (!expectedAssignment.contains(assignment)) { + return false + } + } + } + + true + } + +} diff --git a/core/src/test/scala/integration/kafka/api/AdminClientWithPoliciesIntegrationTest.scala b/core/src/test/scala/integration/kafka/api/AdminClientWithPoliciesIntegrationTest.scala new file mode 100644 index 0000000..61018bb --- /dev/null +++ b/core/src/test/scala/integration/kafka/api/AdminClientWithPoliciesIntegrationTest.scala @@ -0,0 +1,203 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more contributor license agreements. See the NOTICE + * file distributed with this work for additional information regarding copyright ownership. The ASF licenses this file + * to You under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the + * License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package kafka.api + +import java.util +import java.util.Properties +import java.util.concurrent.ExecutionException +import kafka.integration.KafkaServerTestHarness +import kafka.log.LogConfig +import kafka.server.{Defaults, KafkaConfig} +import kafka.utils.{Logging, TestUtils} +import org.apache.kafka.clients.admin.{Admin, AdminClientConfig, AlterConfigsOptions, Config, ConfigEntry} +import org.apache.kafka.common.config.{ConfigResource, TopicConfig} +import org.apache.kafka.common.errors.{InvalidRequestException, PolicyViolationException} +import org.apache.kafka.common.utils.Utils +import org.apache.kafka.server.policy.AlterConfigPolicy +import org.junit.jupiter.api.Assertions.{assertEquals, assertNull, assertThrows, assertTrue} +import org.junit.jupiter.api.{AfterEach, BeforeEach, Test, TestInfo, Timeout} + +import scala.annotation.nowarn +import scala.jdk.CollectionConverters._ + +/** + * Tests AdminClient calls when the broker is configured with policies like AlterConfigPolicy, CreateTopicPolicy, etc. 
+ */ +@Timeout(120) +class AdminClientWithPoliciesIntegrationTest extends KafkaServerTestHarness with Logging { + + import AdminClientWithPoliciesIntegrationTest._ + + var client: Admin = null + val brokerCount = 3 + + @BeforeEach + override def setUp(testInfo: TestInfo): Unit = { + super.setUp(testInfo) + TestUtils.waitUntilBrokerMetadataIsPropagated(servers) + } + + @AfterEach + override def tearDown(): Unit = { + if (client != null) + Utils.closeQuietly(client, "AdminClient") + super.tearDown() + } + + def createConfig: util.Map[String, Object] = + Map[String, Object](AdminClientConfig.BOOTSTRAP_SERVERS_CONFIG -> brokerList).asJava + + override def generateConfigs = { + val configs = TestUtils.createBrokerConfigs(brokerCount, zkConnect) + configs.foreach(props => props.put(KafkaConfig.AlterConfigPolicyClassNameProp, classOf[Policy])) + configs.map(KafkaConfig.fromProps) + } + + @Test + def testValidAlterConfigs(): Unit = { + client = Admin.create(createConfig) + // Create topics + val topic1 = "describe-alter-configs-topic-1" + val topicResource1 = new ConfigResource(ConfigResource.Type.TOPIC, topic1) + val topicConfig1 = new Properties + topicConfig1.setProperty(LogConfig.MaxMessageBytesProp, "500000") + topicConfig1.setProperty(LogConfig.RetentionMsProp, "60000000") + createTopic(topic1, 1, 1, topicConfig1) + + val topic2 = "describe-alter-configs-topic-2" + val topicResource2 = new ConfigResource(ConfigResource.Type.TOPIC, topic2) + createTopic(topic2, 1, 1) + + PlaintextAdminIntegrationTest.checkValidAlterConfigs(client, topicResource1, topicResource2) + } + + @Test + def testInvalidAlterConfigs(): Unit = { + client = Admin.create(createConfig) + PlaintextAdminIntegrationTest.checkInvalidAlterConfigs(zkClient, servers, client) + } + + @nowarn("cat=deprecation") + @Test + def testInvalidAlterConfigsDueToPolicy(): Unit = { + client = Admin.create(createConfig) + + // Create topics + val topic1 = "invalid-alter-configs-due-to-policy-topic-1" + val topicResource1 = new ConfigResource(ConfigResource.Type.TOPIC, topic1) + createTopic(topic1, 1, 1) + + val topic2 = "invalid-alter-configs-due-to-policy-topic-2" + val topicResource2 = new ConfigResource(ConfigResource.Type.TOPIC, topic2) + createTopic(topic2, 1, 1) + + val topic3 = "invalid-alter-configs-due-to-policy-topic-3" + val topicResource3 = new ConfigResource(ConfigResource.Type.TOPIC, topic3) + createTopic(topic3, 1, 1) + + val topicConfigEntries1 = Seq( + new ConfigEntry(LogConfig.MinCleanableDirtyRatioProp, "0.9"), + new ConfigEntry(LogConfig.MinInSyncReplicasProp, "2") // policy doesn't allow this + ).asJava + + var topicConfigEntries2 = Seq(new ConfigEntry(LogConfig.MinCleanableDirtyRatioProp, "0.8")).asJava + + val topicConfigEntries3 = Seq(new ConfigEntry(LogConfig.MinInSyncReplicasProp, "-1")).asJava + + val brokerResource = new ConfigResource(ConfigResource.Type.BROKER, servers.head.config.brokerId.toString) + val brokerConfigEntries = Seq(new ConfigEntry(KafkaConfig.SslTruststorePasswordProp, "12313")).asJava + + // Alter configs: second is valid, the others are invalid + var alterResult = client.alterConfigs(Map( + topicResource1 -> new Config(topicConfigEntries1), + topicResource2 -> new Config(topicConfigEntries2), + topicResource3 -> new Config(topicConfigEntries3), + brokerResource -> new Config(brokerConfigEntries) + ).asJava) + + assertEquals(Set(topicResource1, topicResource2, topicResource3, brokerResource).asJava, alterResult.values.keySet) + assertTrue(assertThrows(classOf[ExecutionException], () => 
alterResult.values.get(topicResource1).get).getCause.isInstanceOf[PolicyViolationException]) + alterResult.values.get(topicResource2).get + assertTrue(assertThrows(classOf[ExecutionException], () => alterResult.values.get(topicResource3).get).getCause.isInstanceOf[InvalidRequestException]) + assertTrue(assertThrows(classOf[ExecutionException], () => alterResult.values.get(brokerResource).get).getCause.isInstanceOf[InvalidRequestException]) + + // Verify that the second resource was updated and the others were not + var describeResult = client.describeConfigs(Seq(topicResource1, topicResource2, topicResource3, brokerResource).asJava) + var configs = describeResult.all.get + assertEquals(4, configs.size) + + assertEquals(Defaults.LogCleanerMinCleanRatio.toString, configs.get(topicResource1).get(LogConfig.MinCleanableDirtyRatioProp).value) + assertEquals(Defaults.MinInSyncReplicas.toString, configs.get(topicResource1).get(LogConfig.MinInSyncReplicasProp).value) + + assertEquals("0.8", configs.get(topicResource2).get(LogConfig.MinCleanableDirtyRatioProp).value) + + assertNull(configs.get(brokerResource).get(KafkaConfig.SslTruststorePasswordProp).value) + + // Alter configs with validateOnly = true: only second is valid + topicConfigEntries2 = Seq(new ConfigEntry(LogConfig.MinCleanableDirtyRatioProp, "0.7")).asJava + + alterResult = client.alterConfigs(Map( + topicResource1 -> new Config(topicConfigEntries1), + topicResource2 -> new Config(topicConfigEntries2), + brokerResource -> new Config(brokerConfigEntries), + topicResource3 -> new Config(topicConfigEntries3) + ).asJava, new AlterConfigsOptions().validateOnly(true)) + + assertEquals(Set(topicResource1, topicResource2, topicResource3, brokerResource).asJava, alterResult.values.keySet) + assertTrue(assertThrows(classOf[ExecutionException], () => alterResult.values.get(topicResource1).get).getCause.isInstanceOf[PolicyViolationException]) + alterResult.values.get(topicResource2).get + assertTrue(assertThrows(classOf[ExecutionException], () => alterResult.values.get(topicResource3).get).getCause.isInstanceOf[InvalidRequestException]) + assertTrue(assertThrows(classOf[ExecutionException], () => alterResult.values.get(brokerResource).get).getCause.isInstanceOf[InvalidRequestException]) + + // Verify that no resources are updated since validate_only = true + describeResult = client.describeConfigs(Seq(topicResource1, topicResource2, topicResource3, brokerResource).asJava) + configs = describeResult.all.get + assertEquals(4, configs.size) + + assertEquals(Defaults.LogCleanerMinCleanRatio.toString, configs.get(topicResource1).get(LogConfig.MinCleanableDirtyRatioProp).value) + assertEquals(Defaults.MinInSyncReplicas.toString, configs.get(topicResource1).get(LogConfig.MinInSyncReplicasProp).value) + + assertEquals("0.8", configs.get(topicResource2).get(LogConfig.MinCleanableDirtyRatioProp).value) + + assertNull(configs.get(brokerResource).get(KafkaConfig.SslTruststorePasswordProp).value) + } + + +} + +object AdminClientWithPoliciesIntegrationTest { + + class Policy extends AlterConfigPolicy { + + var configs: Map[String, _] = _ + var closed = false + + def configure(configs: util.Map[String, _]): Unit = { + this.configs = configs.asScala.toMap + } + + def validate(requestMetadata: AlterConfigPolicy.RequestMetadata): Unit = { + require(!closed, "Policy should not be closed") + require(!configs.isEmpty, "configure should have been called with non empty configs") + require(!requestMetadata.configs.isEmpty, "request configs should not be empty") + 
require(requestMetadata.resource.name.nonEmpty, "resource name should not be empty") + require(requestMetadata.resource.name.contains("topic")) + if (requestMetadata.configs.containsKey(TopicConfig.MIN_IN_SYNC_REPLICAS_CONFIG)) + throw new PolicyViolationException("Min in sync replicas cannot be updated") + } + + def close(): Unit = closed = true + + } +} diff --git a/core/src/test/scala/integration/kafka/api/AuthorizerIntegrationTest.scala b/core/src/test/scala/integration/kafka/api/AuthorizerIntegrationTest.scala new file mode 100644 index 0000000..6efb860 --- /dev/null +++ b/core/src/test/scala/integration/kafka/api/AuthorizerIntegrationTest.scala @@ -0,0 +1,2396 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more contributor license agreements. See the NOTICE + * file distributed with this work for additional information regarding copyright ownership. The ASF licenses this file + * to You under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the + * License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ +package kafka.api + +import java.lang.{Byte => JByte} +import java.time.Duration +import java.util +import java.util.concurrent.ExecutionException +import java.util.regex.Pattern +import java.util.{Collections, Optional, Properties} + +import kafka.admin.ConsumerGroupCommand.{ConsumerGroupCommandOptions, ConsumerGroupService} +import kafka.log.LogConfig +import kafka.security.authorizer.{AclAuthorizer, AclEntry} +import kafka.security.authorizer.AclEntry.WildcardHost +import kafka.server.{BaseRequestTest, KafkaConfig} +import kafka.utils.TestUtils +import org.apache.kafka.clients.admin.{Admin, AdminClientConfig, AlterConfigOp} +import org.apache.kafka.clients.consumer._ +import org.apache.kafka.clients.consumer.internals.NoOpConsumerRebalanceListener +import org.apache.kafka.clients.producer._ +import org.apache.kafka.common.acl.AclOperation._ +import org.apache.kafka.common.acl.AclPermissionType.{ALLOW, DENY} +import org.apache.kafka.common.acl.{AccessControlEntry, AccessControlEntryFilter, AclBindingFilter, AclOperation, AclPermissionType} +import org.apache.kafka.common.config.internals.BrokerSecurityConfigs +import org.apache.kafka.common.config.{ConfigResource, LogLevelConfig} +import org.apache.kafka.common.errors._ +import org.apache.kafka.common.internals.Topic.GROUP_METADATA_TOPIC_NAME +import org.apache.kafka.common.message.CreatePartitionsRequestData.CreatePartitionsTopic +import org.apache.kafka.common.message.CreateTopicsRequestData.{CreatableTopic, CreatableTopicCollection} +import org.apache.kafka.common.message.IncrementalAlterConfigsRequestData.{AlterConfigsResource, AlterableConfig, AlterableConfigCollection} +import org.apache.kafka.common.message.JoinGroupRequestData.JoinGroupRequestProtocolCollection +import org.apache.kafka.common.message.LeaderAndIsrRequestData.LeaderAndIsrPartitionState +import org.apache.kafka.common.message.LeaveGroupRequestData.MemberIdentity +import org.apache.kafka.common.message.ListOffsetsRequestData.{ListOffsetsPartition, ListOffsetsTopic} +import 
org.apache.kafka.common.message.OffsetForLeaderEpochRequestData.{OffsetForLeaderPartition, OffsetForLeaderTopic, OffsetForLeaderTopicCollection} +import org.apache.kafka.common.message.StopReplicaRequestData.{StopReplicaPartitionState, StopReplicaTopicState} +import org.apache.kafka.common.message.UpdateMetadataRequestData.{UpdateMetadataBroker, UpdateMetadataEndpoint, UpdateMetadataPartitionState} +import org.apache.kafka.common.message.{AddOffsetsToTxnRequestData, AlterPartitionReassignmentsRequestData, AlterReplicaLogDirsRequestData, ControlledShutdownRequestData, CreateAclsRequestData, CreatePartitionsRequestData, CreateTopicsRequestData, DeleteAclsRequestData, DeleteGroupsRequestData, DeleteRecordsRequestData, DeleteTopicsRequestData, DescribeClusterRequestData, DescribeConfigsRequestData, DescribeGroupsRequestData, DescribeLogDirsRequestData, DescribeProducersRequestData, DescribeTransactionsRequestData, FindCoordinatorRequestData, HeartbeatRequestData, IncrementalAlterConfigsRequestData, JoinGroupRequestData, ListPartitionReassignmentsRequestData, ListTransactionsRequestData, MetadataRequestData, OffsetCommitRequestData, ProduceRequestData, SyncGroupRequestData} +import org.apache.kafka.common.network.ListenerName +import org.apache.kafka.common.protocol.{ApiKeys, Errors} +import org.apache.kafka.common.record.{CompressionType, MemoryRecords, RecordBatch, SimpleRecord} +import org.apache.kafka.common.requests.OffsetFetchResponse.PartitionData +import org.apache.kafka.common.requests._ +import org.apache.kafka.common.resource.PatternType.{LITERAL, PREFIXED} +import org.apache.kafka.common.resource.ResourceType._ +import org.apache.kafka.common.resource.{PatternType, Resource, ResourcePattern, ResourcePatternFilter, ResourceType} +import org.apache.kafka.common.security.auth.{AuthenticationContext, KafkaPrincipal, SecurityProtocol} +import org.apache.kafka.common.security.authenticator.DefaultKafkaPrincipalBuilder +import org.apache.kafka.common.utils.Utils +import org.apache.kafka.common.{ElectionType, IsolationLevel, Node, TopicPartition, Uuid, requests} +import org.apache.kafka.test.{TestUtils => JTestUtils} +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.{AfterEach, BeforeEach, Test, TestInfo} +import org.junit.jupiter.params.ParameterizedTest +import org.junit.jupiter.params.provider.ValueSource +import java.util.Collections.singletonList + +import scala.annotation.nowarn +import scala.collection.mutable +import scala.collection.mutable.Buffer +import scala.jdk.CollectionConverters._ + +object AuthorizerIntegrationTest { + val BrokerPrincipal = new KafkaPrincipal(KafkaPrincipal.USER_TYPE, "broker") + val ClientPrincipal = new KafkaPrincipal(KafkaPrincipal.USER_TYPE, "client") + + val BrokerListenerName = "BROKER" + val ClientListenerName = "CLIENT" + + class PrincipalBuilder extends DefaultKafkaPrincipalBuilder(null, null) { + override def build(context: AuthenticationContext): KafkaPrincipal = { + context.listenerName match { + case BrokerListenerName => BrokerPrincipal + case ClientListenerName => ClientPrincipal + case listenerName => throw new IllegalArgumentException(s"No principal mapped to listener $listenerName") + } + } + } +} + +class AuthorizerIntegrationTest extends BaseRequestTest { + import AuthorizerIntegrationTest._ + + override def interBrokerListenerName: ListenerName = new ListenerName(BrokerListenerName) + override def listenerName: ListenerName = new ListenerName(ClientListenerName) + override def brokerCount: Int = 1 + + def 
clientPrincipal: KafkaPrincipal = ClientPrincipal + def brokerPrincipal: KafkaPrincipal = BrokerPrincipal + + val clientPrincipalString: String = clientPrincipal.toString + + val brokerId: Integer = 0 + val topic = "topic" + val topicPattern = "topic.*" + val transactionalId = "transactional.id" + val producerId = 83392L + val part = 0 + val correlationId = 0 + val clientId = "client-Id" + val tp = new TopicPartition(topic, part) + val logDir = "logDir" + val group = "my-group" + val protocolType = "consumer" + val protocolName = "consumer-range" + val clusterResource = new ResourcePattern(CLUSTER, Resource.CLUSTER_NAME, LITERAL) + val topicResource = new ResourcePattern(TOPIC, topic, LITERAL) + val groupResource = new ResourcePattern(GROUP, group, LITERAL) + val transactionalIdResource = new ResourcePattern(TRANSACTIONAL_ID, transactionalId, LITERAL) + + val groupReadAcl = Map(groupResource -> Set(new AccessControlEntry(clientPrincipalString, WildcardHost, READ, ALLOW))) + val groupDescribeAcl = Map(groupResource -> Set(new AccessControlEntry(clientPrincipalString, WildcardHost, DESCRIBE, ALLOW))) + val groupDeleteAcl = Map(groupResource -> Set(new AccessControlEntry(clientPrincipalString, WildcardHost, DELETE, ALLOW))) + val clusterAcl = Map(clusterResource -> Set(new AccessControlEntry(clientPrincipalString, WildcardHost, CLUSTER_ACTION, ALLOW))) + val clusterCreateAcl = Map(clusterResource -> Set(new AccessControlEntry(clientPrincipalString, WildcardHost, CREATE, ALLOW))) + val clusterAlterAcl = Map(clusterResource -> Set(new AccessControlEntry(clientPrincipalString, WildcardHost, ALTER, ALLOW))) + val clusterDescribeAcl = Map(clusterResource -> Set(new AccessControlEntry(clientPrincipalString, WildcardHost, DESCRIBE, ALLOW))) + val clusterAlterConfigsAcl = Map(clusterResource -> Set(new AccessControlEntry(clientPrincipalString, WildcardHost, ALTER_CONFIGS, ALLOW))) + val clusterIdempotentWriteAcl = Map(clusterResource -> Set(new AccessControlEntry(clientPrincipalString, WildcardHost, IDEMPOTENT_WRITE, ALLOW))) + val topicCreateAcl = Map(topicResource -> Set(new AccessControlEntry(clientPrincipalString, WildcardHost, CREATE, ALLOW))) + val topicReadAcl = Map(topicResource -> Set(new AccessControlEntry(clientPrincipalString, WildcardHost, READ, ALLOW))) + val topicWriteAcl = Map(topicResource -> Set(new AccessControlEntry(clientPrincipalString, WildcardHost, WRITE, ALLOW))) + val topicDescribeAcl = Map(topicResource -> Set(new AccessControlEntry(clientPrincipalString, WildcardHost, DESCRIBE, ALLOW))) + val topicAlterAcl = Map(topicResource -> Set(new AccessControlEntry(clientPrincipalString, WildcardHost, ALTER, ALLOW))) + val topicDeleteAcl = Map(topicResource -> Set(new AccessControlEntry(clientPrincipalString, WildcardHost, DELETE, ALLOW))) + val topicDescribeConfigsAcl = Map(topicResource -> Set(new AccessControlEntry(clientPrincipalString, WildcardHost, DESCRIBE_CONFIGS, ALLOW))) + val topicAlterConfigsAcl = Map(topicResource -> Set(new AccessControlEntry(clientPrincipalString, WildcardHost, ALTER_CONFIGS, ALLOW))) + val transactionIdWriteAcl = Map(transactionalIdResource -> Set(new AccessControlEntry(clientPrincipalString, WildcardHost, WRITE, ALLOW))) + val transactionalIdDescribeAcl = Map(transactionalIdResource -> Set(new AccessControlEntry(clientPrincipalString, WildcardHost, DESCRIBE, ALLOW))) + + val numRecords = 1 + val adminClients = Buffer[Admin]() + + producerConfig.setProperty(ProducerConfig.ACKS_CONFIG, "1") + 
producerConfig.setProperty(ProducerConfig.MAX_BLOCK_MS_CONFIG, "50000") + consumerConfig.setProperty(ConsumerConfig.GROUP_ID_CONFIG, group) + + override def brokerPropertyOverrides(properties: Properties): Unit = { + properties.put(KafkaConfig.AuthorizerClassNameProp, classOf[AclAuthorizer].getName) + properties.put(KafkaConfig.BrokerIdProp, brokerId.toString) + properties.put(KafkaConfig.OffsetsTopicPartitionsProp, "1") + properties.put(KafkaConfig.OffsetsTopicReplicationFactorProp, "1") + properties.put(KafkaConfig.TransactionsTopicPartitionsProp, "1") + properties.put(KafkaConfig.TransactionsTopicReplicationFactorProp, "1") + properties.put(KafkaConfig.TransactionsTopicMinISRProp, "1") + properties.put(BrokerSecurityConfigs.PRINCIPAL_BUILDER_CLASS_CONFIG, + classOf[PrincipalBuilder].getName) + } + + val requestKeyToError = (topicNames: Map[Uuid, String], version: Short) => Map[ApiKeys, Nothing => Errors]( + ApiKeys.METADATA -> ((resp: requests.MetadataResponse) => resp.errors.asScala.find(_._1 == topic).getOrElse(("test", Errors.NONE))._2), + ApiKeys.PRODUCE -> ((resp: requests.ProduceResponse) => { + Errors.forCode( + resp.data + .responses.find(topic) + .partitionResponses.asScala.find(_.index == part).get + .errorCode + ) + }), + // We may need to get the top level error if the topic does not exist in the response + ApiKeys.FETCH -> ((resp: requests.FetchResponse) => Errors.forCode(resp.responseData(topicNames.asJava, version).asScala.find { + case (topicPartition, _) => topicPartition == tp}.map { case (_, data) => data.errorCode }.getOrElse(resp.error.code()))), + ApiKeys.LIST_OFFSETS -> ((resp: ListOffsetsResponse) => { + Errors.forCode( + resp.data + .topics.asScala.find(_.name == topic).get + .partitions.asScala.find(_.partitionIndex == part).get + .errorCode + ) + }), + ApiKeys.OFFSET_COMMIT -> ((resp: requests.OffsetCommitResponse) => Errors.forCode( + resp.data.topics().get(0).partitions().get(0).errorCode)), + ApiKeys.OFFSET_FETCH -> ((resp: requests.OffsetFetchResponse) => resp.groupLevelError(group)), + ApiKeys.FIND_COORDINATOR -> ((resp: FindCoordinatorResponse) => { + Errors.forCode(resp.data.coordinators.asScala.find(g => group == g.key).head.errorCode) + }), + ApiKeys.UPDATE_METADATA -> ((resp: requests.UpdateMetadataResponse) => resp.error), + ApiKeys.JOIN_GROUP -> ((resp: JoinGroupResponse) => resp.error), + ApiKeys.SYNC_GROUP -> ((resp: SyncGroupResponse) => Errors.forCode(resp.data.errorCode)), + ApiKeys.DESCRIBE_GROUPS -> ((resp: DescribeGroupsResponse) => { + Errors.forCode(resp.data.groups.asScala.find(g => group == g.groupId).head.errorCode) + }), + ApiKeys.HEARTBEAT -> ((resp: HeartbeatResponse) => resp.error), + ApiKeys.LEAVE_GROUP -> ((resp: LeaveGroupResponse) => resp.error), + ApiKeys.DELETE_GROUPS -> ((resp: DeleteGroupsResponse) => resp.get(group)), + ApiKeys.LEADER_AND_ISR -> ((resp: requests.LeaderAndIsrResponse) => Errors.forCode( + resp.topics.asScala.find(t => topicNames(t.topicId) == tp.topic).get.partitionErrors.asScala.find( + p => p.partitionIndex == tp.partition).get.errorCode)), + ApiKeys.STOP_REPLICA -> ((resp: requests.StopReplicaResponse) => Errors.forCode( + resp.partitionErrors.asScala.find(pe => pe.topicName == tp.topic && pe.partitionIndex == tp.partition).get.errorCode)), + ApiKeys.CONTROLLED_SHUTDOWN -> ((resp: requests.ControlledShutdownResponse) => resp.error), + ApiKeys.CREATE_TOPICS -> ((resp: CreateTopicsResponse) => Errors.forCode(resp.data.topics.find(topic).errorCode)), + ApiKeys.DELETE_TOPICS -> ((resp: 
requests.DeleteTopicsResponse) => Errors.forCode(resp.data.responses.find(topic).errorCode)), + ApiKeys.DELETE_RECORDS -> ((resp: requests.DeleteRecordsResponse) => Errors.forCode( + resp.data.topics.find(tp.topic).partitions.find(tp.partition).errorCode)), + ApiKeys.OFFSET_FOR_LEADER_EPOCH -> ((resp: OffsetsForLeaderEpochResponse) => Errors.forCode( + resp.data.topics.find(tp.topic).partitions.asScala.find(_.partition == tp.partition).get.errorCode)), + ApiKeys.DESCRIBE_CONFIGS -> ((resp: DescribeConfigsResponse) => + Errors.forCode(resp.resultMap.get(new ConfigResource(ConfigResource.Type.TOPIC, tp.topic)).errorCode)), + ApiKeys.ALTER_CONFIGS -> ((resp: AlterConfigsResponse) => + resp.errors.get(new ConfigResource(ConfigResource.Type.TOPIC, tp.topic)).error), + ApiKeys.INIT_PRODUCER_ID -> ((resp: InitProducerIdResponse) => resp.error), + ApiKeys.WRITE_TXN_MARKERS -> ((resp: WriteTxnMarkersResponse) => resp.errorsByProducerId.get(producerId).get(tp)), + ApiKeys.ADD_PARTITIONS_TO_TXN -> ((resp: AddPartitionsToTxnResponse) => resp.errors.get(tp)), + ApiKeys.ADD_OFFSETS_TO_TXN -> ((resp: AddOffsetsToTxnResponse) => Errors.forCode(resp.data.errorCode)), + ApiKeys.END_TXN -> ((resp: EndTxnResponse) => resp.error), + ApiKeys.TXN_OFFSET_COMMIT -> ((resp: TxnOffsetCommitResponse) => resp.errors.get(tp)), + ApiKeys.CREATE_ACLS -> ((resp: CreateAclsResponse) => Errors.forCode(resp.results.asScala.head.errorCode)), + ApiKeys.DESCRIBE_ACLS -> ((resp: DescribeAclsResponse) => resp.error.error), + ApiKeys.DELETE_ACLS -> ((resp: DeleteAclsResponse) => Errors.forCode(resp.filterResults.asScala.head.errorCode)), + ApiKeys.ALTER_REPLICA_LOG_DIRS -> ((resp: AlterReplicaLogDirsResponse) => Errors.forCode(resp.data.results.asScala + .find(x => x.topicName == tp.topic).get.partitions.asScala + .find(p => p.partitionIndex == tp.partition).get.errorCode)), + ApiKeys.DESCRIBE_LOG_DIRS -> ((resp: DescribeLogDirsResponse) => + if (resp.data.results.size() > 0) Errors.forCode(resp.data.results.get(0).errorCode) else Errors.CLUSTER_AUTHORIZATION_FAILED), + ApiKeys.CREATE_PARTITIONS -> ((resp: CreatePartitionsResponse) => Errors.forCode(resp.data.results.asScala.head.errorCode)), + ApiKeys.ELECT_LEADERS -> ((resp: ElectLeadersResponse) => Errors.forCode(resp.data.errorCode)), + ApiKeys.INCREMENTAL_ALTER_CONFIGS -> ((resp: IncrementalAlterConfigsResponse) => { + val topicResourceError = IncrementalAlterConfigsResponse.fromResponseData(resp.data).get(new ConfigResource(ConfigResource.Type.TOPIC, tp.topic)) + if (topicResourceError == null) + IncrementalAlterConfigsResponse.fromResponseData(resp.data).get(new ConfigResource(ConfigResource.Type.BROKER_LOGGER, brokerId.toString)).error + else + topicResourceError.error() + }), + ApiKeys.ALTER_PARTITION_REASSIGNMENTS -> ((resp: AlterPartitionReassignmentsResponse) => Errors.forCode(resp.data.errorCode)), + ApiKeys.LIST_PARTITION_REASSIGNMENTS -> ((resp: ListPartitionReassignmentsResponse) => Errors.forCode(resp.data.errorCode)), + ApiKeys.OFFSET_DELETE -> ((resp: OffsetDeleteResponse) => { + Errors.forCode( + resp.data + .topics.asScala.find(_.name == topic).get + .partitions.asScala.find(_.partitionIndex == part).get + .errorCode + ) + }), + ApiKeys.DESCRIBE_PRODUCERS -> ((resp: DescribeProducersResponse) => { + Errors.forCode( + resp.data + .topics.asScala.find(_.name == topic).get + .partitions.asScala.find(_.partitionIndex == part).get + .errorCode + ) + }), + ApiKeys.DESCRIBE_TRANSACTIONS -> ((resp: DescribeTransactionsResponse) => { + Errors.forCode( + resp.data + 
.transactionStates.asScala.find(_.transactionalId == transactionalId).get + .errorCode + ) + }) + ) + + def findErrorForTopicId(id: Uuid, response: AbstractResponse): Errors = { + response match { + case res: DeleteTopicsResponse => + Errors.forCode(res.data.responses.asScala.find(_.topicId == id).get.errorCode) + case _ => + fail(s"Unexpected response type $response") + } + } + + val requestKeysToAcls = Map[ApiKeys, Map[ResourcePattern, Set[AccessControlEntry]]]( + ApiKeys.METADATA -> topicDescribeAcl, + ApiKeys.PRODUCE -> (topicWriteAcl ++ transactionIdWriteAcl ++ clusterIdempotentWriteAcl), + ApiKeys.FETCH -> topicReadAcl, + ApiKeys.LIST_OFFSETS -> topicDescribeAcl, + ApiKeys.OFFSET_COMMIT -> (topicReadAcl ++ groupReadAcl), + ApiKeys.OFFSET_FETCH -> (topicReadAcl ++ groupDescribeAcl), + ApiKeys.FIND_COORDINATOR -> (topicReadAcl ++ groupDescribeAcl ++ transactionalIdDescribeAcl), + ApiKeys.UPDATE_METADATA -> clusterAcl, + ApiKeys.JOIN_GROUP -> groupReadAcl, + ApiKeys.SYNC_GROUP -> groupReadAcl, + ApiKeys.DESCRIBE_GROUPS -> groupDescribeAcl, + ApiKeys.HEARTBEAT -> groupReadAcl, + ApiKeys.LEAVE_GROUP -> groupReadAcl, + ApiKeys.DELETE_GROUPS -> groupDeleteAcl, + ApiKeys.LEADER_AND_ISR -> clusterAcl, + ApiKeys.STOP_REPLICA -> clusterAcl, + ApiKeys.CONTROLLED_SHUTDOWN -> clusterAcl, + ApiKeys.CREATE_TOPICS -> topicCreateAcl, + ApiKeys.DELETE_TOPICS -> topicDeleteAcl, + ApiKeys.DELETE_RECORDS -> topicDeleteAcl, + ApiKeys.OFFSET_FOR_LEADER_EPOCH -> topicDescribeAcl, + ApiKeys.DESCRIBE_CONFIGS -> topicDescribeConfigsAcl, + ApiKeys.ALTER_CONFIGS -> topicAlterConfigsAcl, + ApiKeys.INIT_PRODUCER_ID -> (transactionIdWriteAcl ++ clusterIdempotentWriteAcl), + ApiKeys.WRITE_TXN_MARKERS -> clusterAcl, + ApiKeys.ADD_PARTITIONS_TO_TXN -> (topicWriteAcl ++ transactionIdWriteAcl), + ApiKeys.ADD_OFFSETS_TO_TXN -> (groupReadAcl ++ transactionIdWriteAcl), + ApiKeys.END_TXN -> transactionIdWriteAcl, + ApiKeys.TXN_OFFSET_COMMIT -> (groupReadAcl ++ transactionIdWriteAcl), + ApiKeys.CREATE_ACLS -> clusterAlterAcl, + ApiKeys.DESCRIBE_ACLS -> clusterDescribeAcl, + ApiKeys.DELETE_ACLS -> clusterAlterAcl, + ApiKeys.ALTER_REPLICA_LOG_DIRS -> clusterAlterAcl, + ApiKeys.DESCRIBE_LOG_DIRS -> clusterDescribeAcl, + ApiKeys.CREATE_PARTITIONS -> topicAlterAcl, + ApiKeys.ELECT_LEADERS -> clusterAlterAcl, + ApiKeys.INCREMENTAL_ALTER_CONFIGS -> topicAlterConfigsAcl, + ApiKeys.ALTER_PARTITION_REASSIGNMENTS -> clusterAlterAcl, + ApiKeys.LIST_PARTITION_REASSIGNMENTS -> clusterDescribeAcl, + ApiKeys.OFFSET_DELETE -> groupReadAcl, + ApiKeys.DESCRIBE_PRODUCERS -> topicReadAcl, + ApiKeys.DESCRIBE_TRANSACTIONS -> transactionalIdDescribeAcl + ) + + @BeforeEach + override def setUp(testInfo: TestInfo): Unit = { + doSetup(testInfo, createOffsetsTopic = false) + + // Allow inter-broker communication + addAndVerifyAcls(Set(new AccessControlEntry(brokerPrincipal.toString, WildcardHost, CLUSTER_ACTION, ALLOW)), clusterResource) + + TestUtils.createOffsetsTopic(zkClient, servers) + } + + @AfterEach + override def tearDown(): Unit = { + adminClients.foreach(_.close()) + removeAllClientAcls() + super.tearDown() + } + + private def createMetadataRequest(allowAutoTopicCreation: Boolean) = { + new requests.MetadataRequest.Builder(List(topic).asJava, allowAutoTopicCreation).build() + } + + private def createProduceRequest = + requests.ProduceRequest.forCurrentMagic(new ProduceRequestData() + .setTopicData(new ProduceRequestData.TopicProduceDataCollection( + Collections.singletonList(new ProduceRequestData.TopicProduceData() + 
.setName(tp.topic).setPartitionData(Collections.singletonList( + new ProduceRequestData.PartitionProduceData() + .setIndex(tp.partition) + .setRecords(MemoryRecords.withRecords(CompressionType.NONE, new SimpleRecord("test".getBytes)))))) + .iterator)) + .setAcks(1.toShort) + .setTimeoutMs(5000)) + .build() + + private def createFetchRequest = { + val partitionMap = new util.LinkedHashMap[TopicPartition, requests.FetchRequest.PartitionData] + partitionMap.put(tp, new requests.FetchRequest.PartitionData(getTopicIds().getOrElse(tp.topic, Uuid.ZERO_UUID), + 0, 0, 100, Optional.of(27))) + requests.FetchRequest.Builder.forConsumer(ApiKeys.FETCH.latestVersion, 100, Int.MaxValue, partitionMap).build() + } + + private def createFetchRequestWithUnknownTopic(id: Uuid, version: Short) = { + val partitionMap = new util.LinkedHashMap[TopicPartition, requests.FetchRequest.PartitionData] + partitionMap.put(tp, + new requests.FetchRequest.PartitionData(id, 0, 0, 100, Optional.of(27))) + requests.FetchRequest.Builder.forConsumer(version, 100, Int.MaxValue, partitionMap).build() + } + + private def createFetchFollowerRequest = { + val partitionMap = new util.LinkedHashMap[TopicPartition, requests.FetchRequest.PartitionData] + partitionMap.put(tp, new requests.FetchRequest.PartitionData(getTopicIds().getOrElse(tp.topic, Uuid.ZERO_UUID), + 0, 0, 100, Optional.of(27))) + val version = ApiKeys.FETCH.latestVersion + requests.FetchRequest.Builder.forReplica(version, 5000, 100, Int.MaxValue, partitionMap).build() + } + + private def createListOffsetsRequest = { + requests.ListOffsetsRequest.Builder.forConsumer(false, IsolationLevel.READ_UNCOMMITTED, false).setTargetTimes( + List(new ListOffsetsTopic() + .setName(tp.topic) + .setPartitions(List(new ListOffsetsPartition() + .setPartitionIndex(tp.partition) + .setTimestamp(0L) + .setCurrentLeaderEpoch(27)).asJava)).asJava + ). 
+ build() + } + + private def offsetsForLeaderEpochRequest: OffsetsForLeaderEpochRequest = { + val epochs = new OffsetForLeaderTopicCollection() + epochs.add(new OffsetForLeaderTopic() + .setTopic(tp.topic) + .setPartitions(List(new OffsetForLeaderPartition() + .setPartition(tp.partition) + .setLeaderEpoch(7) + .setCurrentLeaderEpoch(27)).asJava)) + OffsetsForLeaderEpochRequest.Builder.forConsumer(epochs).build() + } + + private def createOffsetFetchRequest: OffsetFetchRequest = { + new requests.OffsetFetchRequest.Builder(group, false, List(tp).asJava, false).build() + } + + private def createOffsetFetchRequestAllPartitions: OffsetFetchRequest = { + new requests.OffsetFetchRequest.Builder(group, false, null, false).build() + } + + private def createOffsetFetchRequest(groupToPartitionMap: util.Map[String, util.List[TopicPartition]]): OffsetFetchRequest = { + new requests.OffsetFetchRequest.Builder(groupToPartitionMap, false, false).build() + } + + private def createFindCoordinatorRequest = { + new FindCoordinatorRequest.Builder( + new FindCoordinatorRequestData() + .setKeyType(FindCoordinatorRequest.CoordinatorType.GROUP.id) + .setCoordinatorKeys(Collections.singletonList(group))).build() + } + + private def createUpdateMetadataRequest = { + val partitionStates = Seq(new UpdateMetadataPartitionState() + .setTopicName(tp.topic) + .setPartitionIndex(tp.partition) + .setControllerEpoch(Int.MaxValue) + .setLeader(brokerId) + .setLeaderEpoch(Int.MaxValue) + .setIsr(List(brokerId).asJava) + .setZkVersion(2) + .setReplicas(Seq(brokerId).asJava)).asJava + val securityProtocol = SecurityProtocol.PLAINTEXT + val brokers = Seq(new UpdateMetadataBroker() + .setId(brokerId) + .setEndpoints(Seq(new UpdateMetadataEndpoint() + .setHost("localhost") + .setPort(0) + .setSecurityProtocol(securityProtocol.id) + .setListener(ListenerName.forSecurityProtocol(securityProtocol).value)).asJava)).asJava + val version = ApiKeys.UPDATE_METADATA.latestVersion + new requests.UpdateMetadataRequest.Builder(version, brokerId, Int.MaxValue, Long.MaxValue, partitionStates, + brokers, Collections.emptyMap()).build() + } + + private def createJoinGroupRequest = { + val protocolSet = new JoinGroupRequestProtocolCollection( + Collections.singletonList(new JoinGroupRequestData.JoinGroupRequestProtocol() + .setName(protocolName) + .setMetadata("test".getBytes()) + ).iterator()) + + new JoinGroupRequest.Builder( + new JoinGroupRequestData() + .setGroupId(group) + .setSessionTimeoutMs(10000) + .setMemberId(JoinGroupRequest.UNKNOWN_MEMBER_ID) + .setGroupInstanceId(null) + .setProtocolType(protocolType) + .setProtocols(protocolSet) + .setRebalanceTimeoutMs(60000) + ).build() + } + + private def createSyncGroupRequest = { + new SyncGroupRequest.Builder( + new SyncGroupRequestData() + .setGroupId(group) + .setGenerationId(1) + .setMemberId(JoinGroupRequest.UNKNOWN_MEMBER_ID) + .setProtocolType(protocolType) + .setProtocolName(protocolName) + .setAssignments(Collections.emptyList()) + ).build() + } + + private def createDescribeGroupsRequest = { + new DescribeGroupsRequest.Builder(new DescribeGroupsRequestData().setGroups(List(group).asJava)).build() + } + + private def createOffsetCommitRequest = { + new requests.OffsetCommitRequest.Builder( + new OffsetCommitRequestData() + .setGroupId(group) + .setMemberId(JoinGroupRequest.UNKNOWN_MEMBER_ID) + .setGenerationId(1) + .setTopics(Collections.singletonList( + new OffsetCommitRequestData.OffsetCommitRequestTopic() + .setName(topic) + .setPartitions(Collections.singletonList( + new 
OffsetCommitRequestData.OffsetCommitRequestPartition() + .setPartitionIndex(part) + .setCommittedOffset(0) + .setCommittedLeaderEpoch(RecordBatch.NO_PARTITION_LEADER_EPOCH) + .setCommitTimestamp(OffsetCommitRequest.DEFAULT_TIMESTAMP) + .setCommittedMetadata("metadata") + ))) + ) + ).build() + } + + private def createPartitionsRequest = { + val partitionTopic = new CreatePartitionsTopic() + .setName(topic) + .setCount(10) + .setAssignments(null) + val data = new CreatePartitionsRequestData() + .setTimeoutMs(10000) + .setValidateOnly(true) + data.topics().add(partitionTopic) + new CreatePartitionsRequest.Builder(data).build(0.toShort) + } + + private def heartbeatRequest = new HeartbeatRequest.Builder( + new HeartbeatRequestData() + .setGroupId(group) + .setGenerationId(1) + .setMemberId(JoinGroupRequest.UNKNOWN_MEMBER_ID)).build() + + private def leaveGroupRequest = new LeaveGroupRequest.Builder( + group, Collections.singletonList( + new MemberIdentity() + .setMemberId(JoinGroupRequest.UNKNOWN_MEMBER_ID) + )).build() + + private def deleteGroupsRequest = new DeleteGroupsRequest.Builder( + new DeleteGroupsRequestData() + .setGroupsNames(Collections.singletonList(group)) + ).build() + + private def leaderAndIsrRequest: LeaderAndIsrRequest = { + new requests.LeaderAndIsrRequest.Builder(ApiKeys.LEADER_AND_ISR.latestVersion, brokerId, Int.MaxValue, Long.MaxValue, + Seq(new LeaderAndIsrPartitionState() + .setTopicName(tp.topic) + .setPartitionIndex(tp.partition) + .setControllerEpoch(Int.MaxValue) + .setLeader(brokerId) + .setLeaderEpoch(Int.MaxValue) + .setIsr(List(brokerId).asJava) + .setZkVersion(2) + .setReplicas(Seq(brokerId).asJava) + .setIsNew(false)).asJava, + getTopicIds().asJava, + Set(new Node(brokerId, "localhost", 0)).asJava).build() + } + + private def stopReplicaRequest: StopReplicaRequest = { + val topicStates = Seq( + new StopReplicaTopicState() + .setTopicName(tp.topic) + .setPartitionStates(Seq(new StopReplicaPartitionState() + .setPartitionIndex(tp.partition) + .setLeaderEpoch(LeaderAndIsr.initialLeaderEpoch + 2) + .setDeletePartition(true)).asJava) + ).asJava + new StopReplicaRequest.Builder(ApiKeys.STOP_REPLICA.latestVersion, brokerId, Int.MaxValue, + Long.MaxValue, false, topicStates).build() + } + + private def controlledShutdownRequest: ControlledShutdownRequest = { + new ControlledShutdownRequest.Builder( + new ControlledShutdownRequestData() + .setBrokerId(brokerId) + .setBrokerEpoch(Long.MaxValue), + ApiKeys.CONTROLLED_SHUTDOWN.latestVersion).build() + } + + private def createTopicsRequest: CreateTopicsRequest = { + new CreateTopicsRequest.Builder(new CreateTopicsRequestData().setTopics( + new CreatableTopicCollection(Collections.singleton(new CreatableTopic(). + setName(topic).setNumPartitions(1). 
+ setReplicationFactor(1.toShort)).iterator))).build() + } + + private def deleteTopicsRequest: DeleteTopicsRequest = { + new DeleteTopicsRequest.Builder( + new DeleteTopicsRequestData() + .setTopicNames(Collections.singletonList(topic)) + .setTimeoutMs(5000)).build() + } + + private def deleteTopicsWithIdsRequest(topicId: Uuid): DeleteTopicsRequest = { + new DeleteTopicsRequest.Builder( + new DeleteTopicsRequestData() + .setTopics(Collections.singletonList( + new DeleteTopicsRequestData.DeleteTopicState() + .setTopicId(topicId))) + .setTimeoutMs(5000)).build() + } + + private def deleteRecordsRequest = new DeleteRecordsRequest.Builder( + new DeleteRecordsRequestData() + .setTimeoutMs(5000) + .setTopics(Collections.singletonList(new DeleteRecordsRequestData.DeleteRecordsTopic() + .setName(tp.topic) + .setPartitions(Collections.singletonList(new DeleteRecordsRequestData.DeleteRecordsPartition() + .setPartitionIndex(tp.partition) + .setOffset(0L)))))).build() + + private def describeConfigsRequest = + new DescribeConfigsRequest.Builder(new DescribeConfigsRequestData().setResources(Collections.singletonList( + new DescribeConfigsRequestData.DescribeConfigsResource().setResourceType(ConfigResource.Type.TOPIC.id) + .setResourceName(tp.topic)))).build() + + private def alterConfigsRequest = + new AlterConfigsRequest.Builder( + Collections.singletonMap(new ConfigResource(ConfigResource.Type.TOPIC, tp.topic), + new AlterConfigsRequest.Config(Collections.singleton( + new AlterConfigsRequest.ConfigEntry(LogConfig.MaxMessageBytesProp, "1000000") + ))), true).build() + + private def incrementalAlterConfigsRequest = { + val data = new IncrementalAlterConfigsRequestData + val alterableConfig = new AlterableConfig + alterableConfig.setName(LogConfig.MaxMessageBytesProp). + setValue("1000000").setConfigOperation(AlterConfigOp.OpType.SET.id()) + val alterableConfigSet = new AlterableConfigCollection + alterableConfigSet.add(alterableConfig) + data.resources().add(new AlterConfigsResource(). + setResourceName(tp.topic).setResourceType(ConfigResource.Type.TOPIC.id()). 
+ setConfigs(alterableConfigSet)) + new IncrementalAlterConfigsRequest.Builder(data).build() + } + + private def describeAclsRequest = new DescribeAclsRequest.Builder(AclBindingFilter.ANY).build() + + private def createAclsRequest: CreateAclsRequest = new CreateAclsRequest.Builder( + new CreateAclsRequestData().setCreations(Collections.singletonList( + new CreateAclsRequestData.AclCreation() + .setResourceType(ResourceType.TOPIC.code) + .setResourceName("mytopic") + .setResourcePatternType(PatternType.LITERAL.code) + .setPrincipal(clientPrincipalString) + .setHost("*") + .setOperation(AclOperation.WRITE.code) + .setPermissionType(AclPermissionType.DENY.code))) + ).build() + + private def deleteAclsRequest: DeleteAclsRequest = new DeleteAclsRequest.Builder( + new DeleteAclsRequestData().setFilters(Collections.singletonList( + new DeleteAclsRequestData.DeleteAclsFilter() + .setResourceTypeFilter(ResourceType.TOPIC.code) + .setResourceNameFilter(null) + .setPatternTypeFilter(PatternType.LITERAL.code) + .setPrincipalFilter(clientPrincipalString) + .setHostFilter("*") + .setOperation(AclOperation.ANY.code) + .setPermissionType(AclPermissionType.DENY.code))) + ).build() + + private def alterReplicaLogDirsRequest = { + val dir = new AlterReplicaLogDirsRequestData.AlterReplicaLogDir() + .setPath(logDir) + dir.topics.add(new AlterReplicaLogDirsRequestData.AlterReplicaLogDirTopic() + .setName(tp.topic) + .setPartitions(Collections.singletonList(tp.partition))) + val data = new AlterReplicaLogDirsRequestData(); + data.dirs.add(dir) + new AlterReplicaLogDirsRequest.Builder(data).build() + } + + private def describeLogDirsRequest = new DescribeLogDirsRequest.Builder(new DescribeLogDirsRequestData().setTopics(new DescribeLogDirsRequestData.DescribableLogDirTopicCollection(Collections.singleton( + new DescribeLogDirsRequestData.DescribableLogDirTopic().setTopic(tp.topic).setPartitions(Collections.singletonList(tp.partition))).iterator()))).build() + + private def addPartitionsToTxnRequest = new AddPartitionsToTxnRequest.Builder(transactionalId, 1, 1, Collections.singletonList(tp)).build() + + private def addOffsetsToTxnRequest = new AddOffsetsToTxnRequest.Builder( + new AddOffsetsToTxnRequestData() + .setTransactionalId(transactionalId) + .setProducerId(1) + .setProducerEpoch(1) + .setGroupId(group) + ).build() + + private def electLeadersRequest = new ElectLeadersRequest.Builder( + ElectionType.PREFERRED, + Collections.singleton(tp), + 10000 + ).build() + + private def describeProducersRequest: DescribeProducersRequest = new DescribeProducersRequest.Builder( + new DescribeProducersRequestData() + .setTopics(List( + new DescribeProducersRequestData.TopicRequest() + .setName(tp.topic) + .setPartitionIndexes(List(Int.box(tp.partition)).asJava) + ).asJava) + ).build() + + private def describeTransactionsRequest: DescribeTransactionsRequest = new DescribeTransactionsRequest.Builder( + new DescribeTransactionsRequestData().setTransactionalIds(List(transactionalId).asJava) + ).build() + + private def alterPartitionReassignmentsRequest = new AlterPartitionReassignmentsRequest.Builder( + new AlterPartitionReassignmentsRequestData().setTopics( + List(new AlterPartitionReassignmentsRequestData.ReassignableTopic() + .setName(topic) + .setPartitions( + List(new AlterPartitionReassignmentsRequestData.ReassignablePartition().setPartitionIndex(tp.partition)).asJava + )).asJava + ) + ).build() + + private def listPartitionReassignmentsRequest = new ListPartitionReassignmentsRequest.Builder( + new 
ListPartitionReassignmentsRequestData().setTopics( + List(new ListPartitionReassignmentsRequestData.ListPartitionReassignmentsTopics() + .setName(topic) + .setPartitionIndexes( + List(Integer.valueOf(tp.partition)).asJava + )).asJava + ) + ).build() + + private def sendRequests(requestKeyToRequest: mutable.Map[ApiKeys, AbstractRequest], topicExists: Boolean = true, + topicNames: Map[Uuid, String] = getTopicNames()) = { + for ((key, request) <- requestKeyToRequest) { + removeAllClientAcls() + val resources = requestKeysToAcls(key).map(_._1.resourceType).toSet + sendRequestAndVerifyResponseError(request, resources, isAuthorized = false, topicExists = topicExists, topicNames = topicNames) + + val resourceToAcls = requestKeysToAcls(key) + resourceToAcls.get(topicResource).foreach { acls => + val describeAcls = topicDescribeAcl(topicResource) + val isAuthorized = describeAcls == acls + addAndVerifyAcls(describeAcls, topicResource) + sendRequestAndVerifyResponseError(request, resources, isAuthorized = isAuthorized, topicExists = topicExists, topicNames = topicNames) + removeAllClientAcls() + } + + for ((resource, acls) <- resourceToAcls) + addAndVerifyAcls(acls, resource) + sendRequestAndVerifyResponseError(request, resources, isAuthorized = true, topicExists = topicExists, topicNames = topicNames) + } + } + + @Test + def testAuthorizationWithTopicExisting(): Unit = { + //First create the topic so we have a valid topic ID + sendRequests(mutable.Map(ApiKeys.CREATE_TOPICS -> createTopicsRequest)) + + val requestKeyToRequest = mutable.LinkedHashMap[ApiKeys, AbstractRequest]( + ApiKeys.METADATA -> createMetadataRequest(allowAutoTopicCreation = true), + ApiKeys.PRODUCE -> createProduceRequest, + ApiKeys.FETCH -> createFetchRequest, + ApiKeys.LIST_OFFSETS -> createListOffsetsRequest, + ApiKeys.OFFSET_FETCH -> createOffsetFetchRequest, + ApiKeys.FIND_COORDINATOR -> createFindCoordinatorRequest, + ApiKeys.JOIN_GROUP -> createJoinGroupRequest, + ApiKeys.SYNC_GROUP -> createSyncGroupRequest, + ApiKeys.DESCRIBE_GROUPS -> createDescribeGroupsRequest, + ApiKeys.OFFSET_COMMIT -> createOffsetCommitRequest, + ApiKeys.HEARTBEAT -> heartbeatRequest, + ApiKeys.LEAVE_GROUP -> leaveGroupRequest, + ApiKeys.DELETE_RECORDS -> deleteRecordsRequest, + ApiKeys.OFFSET_FOR_LEADER_EPOCH -> offsetsForLeaderEpochRequest, + ApiKeys.DESCRIBE_CONFIGS -> describeConfigsRequest, + ApiKeys.ALTER_CONFIGS -> alterConfigsRequest, + ApiKeys.CREATE_ACLS -> createAclsRequest, + ApiKeys.DELETE_ACLS -> deleteAclsRequest, + ApiKeys.DESCRIBE_ACLS -> describeAclsRequest, + ApiKeys.ALTER_REPLICA_LOG_DIRS -> alterReplicaLogDirsRequest, + ApiKeys.DESCRIBE_LOG_DIRS -> describeLogDirsRequest, + ApiKeys.CREATE_PARTITIONS -> createPartitionsRequest, + ApiKeys.ADD_PARTITIONS_TO_TXN -> addPartitionsToTxnRequest, + ApiKeys.ADD_OFFSETS_TO_TXN -> addOffsetsToTxnRequest, + ApiKeys.ELECT_LEADERS -> electLeadersRequest, + ApiKeys.INCREMENTAL_ALTER_CONFIGS -> incrementalAlterConfigsRequest, + ApiKeys.ALTER_PARTITION_REASSIGNMENTS -> alterPartitionReassignmentsRequest, + ApiKeys.LIST_PARTITION_REASSIGNMENTS -> listPartitionReassignmentsRequest, + ApiKeys.DESCRIBE_PRODUCERS -> describeProducersRequest, + ApiKeys.DESCRIBE_TRANSACTIONS -> describeTransactionsRequest, + + // Inter-broker APIs use an invalid broker epoch, so does not affect the test case + ApiKeys.UPDATE_METADATA -> createUpdateMetadataRequest, + ApiKeys.LEADER_AND_ISR -> leaderAndIsrRequest, + ApiKeys.STOP_REPLICA -> stopReplicaRequest, + ApiKeys.CONTROLLED_SHUTDOWN -> controlledShutdownRequest, + 
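+ // sendRequests iterates this LinkedHashMap in insertion order, so the topic created above
+ // remains available to the APIs listed before the final DELETE_TOPICS entry removes it.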
+ // Delete the topic last + ApiKeys.DELETE_TOPICS -> deleteTopicsRequest + ) + + sendRequests(requestKeyToRequest, true) + } + + /* + * even if the topic doesn't exist, request APIs should not leak the topic name + */ + @Test + def testAuthorizationWithTopicNotExisting(): Unit = { + val id = Uuid.randomUuid() + val topicNames = Map(id -> "topic") + val requestKeyToRequest = mutable.LinkedHashMap[ApiKeys, AbstractRequest]( + ApiKeys.METADATA -> createMetadataRequest(allowAutoTopicCreation = false), + ApiKeys.PRODUCE -> createProduceRequest, + ApiKeys.FETCH -> createFetchRequestWithUnknownTopic(id, ApiKeys.FETCH.latestVersion()), + ApiKeys.LIST_OFFSETS -> createListOffsetsRequest, + ApiKeys.OFFSET_COMMIT -> createOffsetCommitRequest, + ApiKeys.OFFSET_FETCH -> createOffsetFetchRequest, + ApiKeys.DELETE_TOPICS -> deleteTopicsRequest, + ApiKeys.DELETE_RECORDS -> deleteRecordsRequest, + ApiKeys.ADD_PARTITIONS_TO_TXN -> addPartitionsToTxnRequest, + ApiKeys.ADD_OFFSETS_TO_TXN -> addOffsetsToTxnRequest, + ApiKeys.CREATE_PARTITIONS -> createPartitionsRequest, + ApiKeys.DELETE_GROUPS -> deleteGroupsRequest, + ApiKeys.OFFSET_FOR_LEADER_EPOCH -> offsetsForLeaderEpochRequest, + ApiKeys.ELECT_LEADERS -> electLeadersRequest + ) + + sendRequests(requestKeyToRequest, false, topicNames) + } + + @ParameterizedTest + @ValueSource(booleans = Array(true, false)) + def testTopicIdAuthorization(withTopicExisting: Boolean): Unit = { + val topicId = if (withTopicExisting) { + createTopic(topic) + getTopicIds()(topic) + } else { + Uuid.randomUuid() + } + + val requestKeyToRequest = mutable.LinkedHashMap[ApiKeys, AbstractRequest]( + ApiKeys.DELETE_TOPICS -> deleteTopicsWithIdsRequest(topicId) + ) + + def sendAndVerify( + request: AbstractRequest, + isAuthorized: Boolean, + isDescribeAuthorized: Boolean + ): Unit = { + val response = connectAndReceive[AbstractResponse](request) + val error = findErrorForTopicId(topicId, response) + if (!withTopicExisting) { + assertEquals(Errors.UNKNOWN_TOPIC_ID, error) + } else if (!isDescribeAuthorized || !isAuthorized) { + assertEquals(Errors.TOPIC_AUTHORIZATION_FAILED, error) + } + } + + for ((key, request) <- requestKeyToRequest) { + removeAllClientAcls() + sendAndVerify(request, isAuthorized = false, isDescribeAuthorized = false) + + val describeAcls = topicDescribeAcl(topicResource) + addAndVerifyAcls(describeAcls, topicResource) + + val resourceToAcls = requestKeysToAcls(key) + resourceToAcls.get(topicResource).foreach { acls => + val isAuthorized = describeAcls == acls + sendAndVerify(request, isAuthorized = isAuthorized, isDescribeAuthorized = true) + } + + removeAllClientAcls() + for ((resource, acls) <- resourceToAcls) { + addAndVerifyAcls(acls, resource) + } + + sendAndVerify(request, isAuthorized = true, isDescribeAuthorized = true) + } + } + + /* + * even if the topic doesn't exist, request APIs should not leak the topic name + */ + @Test + def testAuthorizationFetchV12WithTopicNotExisting(): Unit = { + val id = Uuid.ZERO_UUID + val topicNames = Map(id -> "topic") + val requestKeyToRequest = mutable.LinkedHashMap[ApiKeys, AbstractRequest]( + ApiKeys.FETCH -> createFetchRequestWithUnknownTopic(id, 12), + ) + + sendRequests(requestKeyToRequest, false, topicNames) + } + + @Test + def testCreateTopicAuthorizationWithClusterCreate(): Unit = { + removeAllClientAcls() + val resources = Set[ResourceType](TOPIC) + + sendRequestAndVerifyResponseError(createTopicsRequest, resources, isAuthorized = false) + + for ((resource, acls) <- clusterCreateAcl) + addAndVerifyAcls(acls, 
resource) + sendRequestAndVerifyResponseError(createTopicsRequest, resources, isAuthorized = true) + } + + @Test + def testFetchFollowerRequest(): Unit = { + createTopic(topic) + + val request = createFetchFollowerRequest + + removeAllClientAcls() + val resources = Set(topicResource.resourceType, clusterResource.resourceType) + sendRequestAndVerifyResponseError(request, resources, isAuthorized = false) + + val readAcls = topicReadAcl(topicResource) + addAndVerifyAcls(readAcls, topicResource) + sendRequestAndVerifyResponseError(request, resources, isAuthorized = false) + + val clusterAcls = clusterAcl(clusterResource) + addAndVerifyAcls(clusterAcls, clusterResource) + sendRequestAndVerifyResponseError(request, resources, isAuthorized = true) + } + + @Test + def testIncrementalAlterConfigsRequestRequiresClusterPermissionForBrokerLogger(): Unit = { + createTopic(topic) + + val data = new IncrementalAlterConfigsRequestData + val alterableConfig = new AlterableConfig().setName("kafka.controller.KafkaController"). + setValue(LogLevelConfig.DEBUG_LOG_LEVEL).setConfigOperation(AlterConfigOp.OpType.DELETE.id()) + val alterableConfigSet = new AlterableConfigCollection + alterableConfigSet.add(alterableConfig) + data.resources().add(new AlterConfigsResource(). + setResourceName(brokerId.toString).setResourceType(ConfigResource.Type.BROKER_LOGGER.id()). + setConfigs(alterableConfigSet)) + val request = new IncrementalAlterConfigsRequest.Builder(data).build() + + removeAllClientAcls() + val resources = Set(topicResource.resourceType, clusterResource.resourceType) + sendRequestAndVerifyResponseError(request, resources, isAuthorized = false) + + val clusterAcls = clusterAlterConfigsAcl(clusterResource) + addAndVerifyAcls(clusterAcls, clusterResource) + sendRequestAndVerifyResponseError(request, resources, isAuthorized = true) + } + + @Test + def testOffsetsForLeaderEpochClusterPermission(): Unit = { + createTopic(topic) + + val request = offsetsForLeaderEpochRequest + + removeAllClientAcls() + + val resources = Set(topicResource.resourceType, clusterResource.resourceType) + sendRequestAndVerifyResponseError(request, resources, isAuthorized = false) + + // Although the OffsetsForLeaderEpoch API now accepts topic describe, we should continue + // allowing cluster action for backwards compatibility + val clusterAcls = clusterAcl(clusterResource) + addAndVerifyAcls(clusterAcls, clusterResource) + sendRequestAndVerifyResponseError(request, resources, isAuthorized = true) + } + + @Test + def testProduceWithNoTopicAccess(): Unit = { + createTopic(topic) + val producer = createProducer() + assertThrows(classOf[TopicAuthorizationException], () => sendRecords(producer, numRecords, tp)) + } + + @Test + def testProduceWithTopicDescribe(): Unit = { + createTopic(topic) + addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, DESCRIBE, ALLOW)), topicResource) + val producer = createProducer() + assertThrows(classOf[TopicAuthorizationException], () => sendRecords(producer, numRecords, tp)) + } + + @Test + def testProduceWithTopicRead(): Unit = { + createTopic(topic) + addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, READ, ALLOW)), topicResource) + val producer = createProducer() + assertThrows(classOf[TopicAuthorizationException], () => sendRecords(producer, numRecords, tp)) + } + + @Test + def testProduceWithTopicWrite(): Unit = { + createTopic(topic) + addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, WRITE, ALLOW)), 
topicResource) + val producer = createProducer() + sendRecords(producer, numRecords, tp) + } + + @Test + def testCreatePermissionOnTopicToWriteToNonExistentTopic(): Unit = { + testCreatePermissionNeededToWriteToNonExistentTopic(TOPIC) + } + + @Test + def testCreatePermissionOnClusterToWriteToNonExistentTopic(): Unit = { + testCreatePermissionNeededToWriteToNonExistentTopic(CLUSTER) + } + + private def testCreatePermissionNeededToWriteToNonExistentTopic(resType: ResourceType): Unit = { + val newTopicResource = new ResourcePattern(TOPIC, topic, LITERAL) + addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, WRITE, ALLOW)), newTopicResource) + val producer = createProducer() + val e = assertThrows(classOf[TopicAuthorizationException], () => sendRecords(producer, numRecords, tp)) + assertEquals(Collections.singleton(tp.topic), e.unauthorizedTopics()) + + val resource = if (resType == ResourceType.TOPIC) newTopicResource else clusterResource + addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, CREATE, ALLOW)), resource) + + sendRecords(producer, numRecords, tp) + } + + @Test + def testConsumeUsingAssignWithNoAccess(): Unit = { + createTopic(topic) + + addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, WRITE, ALLOW)), topicResource) + val producer = createProducer() + sendRecords(producer, 1, tp) + removeAllClientAcls() + + val consumer = createConsumer() + consumer.assign(List(tp).asJava) + assertThrows(classOf[TopicAuthorizationException], () => consumeRecords(consumer)) + } + + @Test + def testSimpleConsumeWithOffsetLookupAndNoGroupAccess(): Unit = { + createTopic(topic) + + addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, WRITE, ALLOW)), topicResource) + val producer = createProducer() + sendRecords(producer, 1, tp) + removeAllClientAcls() + + addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, READ, ALLOW)), topicResource) + + // note this still depends on group access because we haven't set offsets explicitly, which means + // they will first be fetched from the consumer coordinator (which requires group access) + val consumer = createConsumer() + consumer.assign(List(tp).asJava) + val e = assertThrows(classOf[GroupAuthorizationException], () => consumeRecords(consumer)) + assertEquals(group, e.groupId()) + } + + @Test + def testSimpleConsumeWithExplicitSeekAndNoGroupAccess(): Unit = { + createTopic(topic) + + addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, WRITE, ALLOW)), topicResource) + val producer = createProducer() + sendRecords(producer, 1, tp) + removeAllClientAcls() + + addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, READ, ALLOW)), topicResource) + + // in this case, we do an explicit seek, so there should be no need to query the coordinator at all + val consumer = createConsumer() + consumer.assign(List(tp).asJava) + consumer.seekToBeginning(List(tp).asJava) + consumeRecords(consumer) + } + + @Test + def testConsumeWithoutTopicDescribeAccess(): Unit = { + createTopic(topic) + + addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, WRITE, ALLOW)), topicResource) + val producer = createProducer() + sendRecords(producer, 1, tp) + removeAllClientAcls() + + addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, READ, ALLOW)), groupResource) + val consumer = createConsumer() + consumer.assign(List(tp).asJava) + + val e = 
assertThrows(classOf[TopicAuthorizationException], () => consumeRecords(consumer)) + assertEquals(Collections.singleton(topic), e.unauthorizedTopics()) + } + + @Test + def testConsumeWithTopicDescribe(): Unit = { + createTopic(topic) + + addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, WRITE, ALLOW)), topicResource) + val producer = createProducer() + sendRecords(producer, 1, tp) + removeAllClientAcls() + + addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, DESCRIBE, ALLOW)), topicResource) + addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, READ, ALLOW)), groupResource) + + val consumer = createConsumer() + consumer.assign(List(tp).asJava) + val e = assertThrows(classOf[TopicAuthorizationException], () => consumeRecords(consumer)) + assertEquals(Collections.singleton(topic), e.unauthorizedTopics()) + } + + @Test + def testConsumeWithTopicWrite(): Unit = { + createTopic(topic) + + addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, WRITE, ALLOW)), topicResource) + val producer = createProducer() + sendRecords(producer, 1, tp) + removeAllClientAcls() + + addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, WRITE, ALLOW)), topicResource) + addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, READ, ALLOW)), groupResource) + + val consumer = createConsumer() + consumer.assign(List(tp).asJava) + val e = assertThrows(classOf[TopicAuthorizationException], () => consumeRecords(consumer)) + assertEquals(Collections.singleton(topic), e.unauthorizedTopics()) + } + + @Test + def testConsumeWithTopicAndGroupRead(): Unit = { + createTopic(topic) + + addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, WRITE, ALLOW)), topicResource) + val producer = createProducer() + sendRecords(producer, 1, tp) + removeAllClientAcls() + + addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, READ, ALLOW)), topicResource) + addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, READ, ALLOW)), groupResource) + + val consumer = createConsumer() + consumer.assign(List(tp).asJava) + consumeRecords(consumer) + } + + @nowarn("cat=deprecation") + @Test + def testPatternSubscriptionWithNoTopicAccess(): Unit = { + createTopic(topic) + + addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, WRITE, ALLOW)), topicResource) + val producer = createProducer() + sendRecords(producer, 1, tp) + removeAllClientAcls() + + addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, READ, ALLOW)), groupResource) + + val consumer = createConsumer() + consumer.subscribe(Pattern.compile(topicPattern), new NoOpConsumerRebalanceListener) + consumer.poll(0) + assertTrue(consumer.subscription.isEmpty) + } + + @Test + def testPatternSubscriptionWithTopicDescribeOnlyAndGroupRead(): Unit = { + createTopic(topic) + + addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, WRITE, ALLOW)), topicResource) + val producer = createProducer() + sendRecords(producer, 1, tp) + removeAllClientAcls() + + addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, DESCRIBE, ALLOW)), topicResource) + addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, READ, ALLOW)), groupResource) + val consumer = createConsumer() + consumer.subscribe(Pattern.compile(topicPattern)) + val e = 
assertThrows(classOf[TopicAuthorizationException], () => consumeRecords(consumer)) + assertEquals(Collections.singleton(topic), e.unauthorizedTopics()) + } + + @nowarn("cat=deprecation") + @Test + def testPatternSubscriptionWithTopicAndGroupRead(): Unit = { + createTopic(topic) + + addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, WRITE, ALLOW)), topicResource) + val producer = createProducer() + sendRecords(producer, 1, tp) + + // create an unmatched topic + val unmatchedTopic = "unmatched" + createTopic(unmatchedTopic) + addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, WRITE, ALLOW)), new ResourcePattern(TOPIC, unmatchedTopic, LITERAL)) + sendRecords(producer, 1, new TopicPartition(unmatchedTopic, part)) + removeAllClientAcls() + + addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, READ, ALLOW)), topicResource) + addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, READ, ALLOW)), groupResource) + val consumer = createConsumer() + consumer.subscribe(Pattern.compile(topicPattern)) + consumeRecords(consumer) + + // set the subscription pattern to an internal topic that the consumer has read permission to. Since + // internal topics are not included, we should not be assigned any partitions from this topic + addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, READ, ALLOW)), new ResourcePattern(TOPIC, + GROUP_METADATA_TOPIC_NAME, LITERAL)) + consumer.subscribe(Pattern.compile(GROUP_METADATA_TOPIC_NAME)) + consumer.poll(0) + assertTrue(consumer.subscription().isEmpty) + assertTrue(consumer.assignment().isEmpty) + } + + @nowarn("cat=deprecation") + @Test + def testPatternSubscriptionMatchingInternalTopic(): Unit = { + createTopic(topic) + + addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, WRITE, ALLOW)), topicResource) + val producer = createProducer() + sendRecords(producer, 1, tp) + removeAllClientAcls() + + addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, READ, ALLOW)), topicResource) + addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, READ, ALLOW)), groupResource) + + consumerConfig.put(ConsumerConfig.EXCLUDE_INTERNAL_TOPICS_CONFIG, "false") + val consumer = createConsumer() + // ensure that internal topics are not included if no permission + consumer.subscribe(Pattern.compile(".*")) + consumeRecords(consumer) + assertEquals(Set(topic).asJava, consumer.subscription) + + // now authorize the user for the internal topic and verify that we can subscribe + addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, READ, ALLOW)), new ResourcePattern(TOPIC, + GROUP_METADATA_TOPIC_NAME, LITERAL)) + consumer.subscribe(Pattern.compile(GROUP_METADATA_TOPIC_NAME)) + consumer.poll(0) + assertEquals(Set(GROUP_METADATA_TOPIC_NAME), consumer.subscription.asScala) + } + + @Test + def testPatternSubscriptionMatchingInternalTopicWithDescribeOnlyPermission(): Unit = { + createTopic(topic) + + addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, WRITE, ALLOW)), topicResource) + val producer = createProducer() + sendRecords(producer, 1, tp) + removeAllClientAcls() + + addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, READ, ALLOW)), topicResource) + addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, READ, ALLOW)), groupResource) + val internalTopicResource = new 
ResourcePattern(TOPIC, GROUP_METADATA_TOPIC_NAME, LITERAL) + addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, DESCRIBE, ALLOW)), internalTopicResource) + + consumerConfig.put(ConsumerConfig.EXCLUDE_INTERNAL_TOPICS_CONFIG, "false") + val consumer = createConsumer() + consumer.subscribe(Pattern.compile(".*")) + val e = assertThrows(classOf[TopicAuthorizationException], () => { + // It is possible that the first call returns records of "topic" and the second call throws TopicAuthorizationException + consumeRecords(consumer) + consumeRecords(consumer) + }) + assertEquals(Collections.singleton(GROUP_METADATA_TOPIC_NAME), e.unauthorizedTopics()) + } + + @Test + def testPatternSubscriptionNotMatchingInternalTopic(): Unit = { + createTopic(topic) + + addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, WRITE, ALLOW)), topicResource) + val producer = createProducer() + sendRecords(producer, 1, tp) + removeAllClientAcls() + + addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, READ, ALLOW)), topicResource) + addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, READ, ALLOW)), groupResource) + + consumerConfig.put(ConsumerConfig.EXCLUDE_INTERNAL_TOPICS_CONFIG, "false") + val consumer = createConsumer() + consumer.subscribe(Pattern.compile(topicPattern)) + consumeRecords(consumer) + } + + @Test + def testCreatePermissionOnTopicToReadFromNonExistentTopic(): Unit = { + testCreatePermissionNeededToReadFromNonExistentTopic("newTopic", + Set(new AccessControlEntry(clientPrincipalString, WildcardHost, CREATE, ALLOW)), + TOPIC) + } + + @Test + def testCreatePermissionOnClusterToReadFromNonExistentTopic(): Unit = { + testCreatePermissionNeededToReadFromNonExistentTopic("newTopic", + Set(new AccessControlEntry(clientPrincipalString, WildcardHost, CREATE, ALLOW)), + CLUSTER) + } + + private def testCreatePermissionNeededToReadFromNonExistentTopic(newTopic: String, acls: Set[AccessControlEntry], resType: ResourceType): Unit = { + val topicPartition = new TopicPartition(newTopic, 0) + val newTopicResource = new ResourcePattern(TOPIC, newTopic, LITERAL) + addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, READ, ALLOW)), newTopicResource) + addAndVerifyAcls(groupReadAcl(groupResource), groupResource) + val consumer = createConsumer() + consumer.assign(List(topicPartition).asJava) + val unauthorizedTopics = assertThrows(classOf[TopicAuthorizationException], + () => (0 until 10).foreach(_ => consumer.poll(Duration.ofMillis(50L)))).unauthorizedTopics + assertEquals(Collections.singleton(newTopic), unauthorizedTopics) + + val resource = if (resType == TOPIC) newTopicResource else clusterResource + addAndVerifyAcls(acls, resource) + + TestUtils.waitUntilTrue(() => { + consumer.poll(Duration.ofMillis(50L)) + this.zkClient.topicExists(newTopic) + }, "Expected topic was not created") + } + + @Test + def testCreatePermissionMetadataRequestAutoCreate(): Unit = { + val readAcls = topicReadAcl(topicResource) + addAndVerifyAcls(readAcls, topicResource) + assertFalse(zkClient.topicExists(topic)) + + val metadataRequest = new MetadataRequest.Builder(List(topic).asJava, true).build() + val metadataResponse = connectAndReceive[MetadataResponse](metadataRequest) + + assertEquals(Set().asJava, metadataResponse.topicsByError(Errors.NONE)) + + val createAcls = topicCreateAcl(topicResource) + addAndVerifyAcls(createAcls, topicResource) + + // retry as topic being created can have 
MetadataResponse with Errors.LEADER_NOT_AVAILABLE + TestUtils.retry(JTestUtils.DEFAULT_MAX_WAIT_MS) { + val metadataResponse = connectAndReceive[MetadataResponse](metadataRequest) + assertEquals(Set(topic).asJava, metadataResponse.topicsByError(Errors.NONE)) + } + } + + @Test + def testCommitWithNoAccess(): Unit = { + val consumer = createConsumer() + assertThrows(classOf[GroupAuthorizationException], () => consumer.commitSync(Map(tp -> new OffsetAndMetadata(5)).asJava)) + } + + @Test + def testCommitWithNoTopicAccess(): Unit = { + addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, READ, ALLOW)), groupResource) + val consumer = createConsumer() + assertThrows(classOf[TopicAuthorizationException], () => consumer.commitSync(Map(tp -> new OffsetAndMetadata(5)).asJava)) + } + + @Test + def testCommitWithTopicWrite(): Unit = { + createTopic(topic) + + addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, READ, ALLOW)), groupResource) + addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, WRITE, ALLOW)), topicResource) + val consumer = createConsumer() + assertThrows(classOf[TopicAuthorizationException], () => consumer.commitSync(Map(tp -> new OffsetAndMetadata(5)).asJava)) + } + + @Test + def testCommitWithTopicDescribe(): Unit = { + createTopic(topic) + + addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, READ, ALLOW)), groupResource) + addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, DESCRIBE, ALLOW)), topicResource) + val consumer = createConsumer() + assertThrows(classOf[TopicAuthorizationException], () => consumer.commitSync(Map(tp -> new OffsetAndMetadata(5)).asJava)) + } + + @Test + def testCommitWithNoGroupAccess(): Unit = { + addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, READ, ALLOW)), topicResource) + val consumer = createConsumer() + assertThrows(classOf[GroupAuthorizationException], () => consumer.commitSync(Map(tp -> new OffsetAndMetadata(5)).asJava)) + } + + @Test + def testCommitWithTopicAndGroupRead(): Unit = { + createTopic(topic) + addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, READ, ALLOW)), groupResource) + addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, READ, ALLOW)), topicResource) + val consumer = createConsumer() + consumer.commitSync(Map(tp -> new OffsetAndMetadata(5)).asJava) + } + + @Test + def testOffsetFetchWithNoAccess(): Unit = { + val consumer = createConsumer() + consumer.assign(List(tp).asJava) + assertThrows(classOf[TopicAuthorizationException], () => consumer.position(tp)) + } + + @Test + def testOffsetFetchWithNoGroupAccess(): Unit = { + createTopic(topic) + addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, READ, ALLOW)), topicResource) + val consumer = createConsumer() + consumer.assign(List(tp).asJava) + assertThrows(classOf[GroupAuthorizationException], () => consumer.position(tp)) + } + + @Test + def testOffsetFetchWithNoTopicAccess(): Unit = { + addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, READ, ALLOW)), groupResource) + val consumer = createConsumer() + consumer.assign(List(tp).asJava) + assertThrows(classOf[TopicAuthorizationException], () => consumer.position(tp)) + } + + @Test + def testOffsetFetchAllTopicPartitionsAuthorization(): Unit = { + createTopic(topic) + + val offset = 15L + addAndVerifyAcls(Set(new 
AccessControlEntry(clientPrincipalString, WildcardHost, READ, ALLOW)), groupResource) + addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, READ, ALLOW)), topicResource) + val consumer = createConsumer() + consumer.assign(List(tp).asJava) + consumer.commitSync(Map(tp -> new OffsetAndMetadata(offset)).asJava) + + removeAllClientAcls() + addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, READ, ALLOW)), groupResource) + + // send offset fetch requests directly since the consumer does not expose an API to do so + // note there's only one broker, so no need to lookup the group coordinator + + // without describe permission on the topic, we shouldn't be able to fetch offsets + val offsetFetchRequest = createOffsetFetchRequestAllPartitions + var offsetFetchResponse = connectAndReceive[OffsetFetchResponse](offsetFetchRequest) + assertEquals(Errors.NONE, offsetFetchResponse.groupLevelError(group)) + assertTrue(offsetFetchResponse.partitionDataMap(group).isEmpty) + + // now add describe permission on the topic and verify that the offset can be fetched + addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, DESCRIBE, ALLOW)), topicResource) + offsetFetchResponse = connectAndReceive[OffsetFetchResponse](offsetFetchRequest) + assertEquals(Errors.NONE, offsetFetchResponse.groupLevelError(group)) + assertTrue(offsetFetchResponse.partitionDataMap(group).containsKey(tp)) + assertEquals(offset, offsetFetchResponse.partitionDataMap(group).get(tp).offset) + } + + @Test + def testOffsetFetchMultipleGroupsAuthorization(): Unit = { + val groups: Seq[String] = (1 to 5).map(i => s"group$i") + val groupResources = groups.map(group => new ResourcePattern(GROUP, group, LITERAL)) + val topics: Seq[String] = (1 to 3).map(i => s"topic$i") + val topicResources = topics.map(topic => new ResourcePattern(TOPIC, topic, LITERAL)) + + val topic1List = singletonList(new TopicPartition(topics(0), 0)) + val topic1And2List = util.Arrays.asList( + new TopicPartition(topics(0), 0), + new TopicPartition(topics(1), 0), + new TopicPartition(topics(1), 1)) + val allTopicsList = util.Arrays.asList( + new TopicPartition(topics(0), 0), + new TopicPartition(topics(1), 0), + new TopicPartition(topics(1), 1), + new TopicPartition(topics(2), 0), + new TopicPartition(topics(2), 1), + new TopicPartition(topics(2), 2)) + + // create group to partition map to build batched offsetFetch request + val groupToPartitionMap = new util.HashMap[String, util.List[TopicPartition]]() + groupToPartitionMap.put(groups(0), topic1List) + groupToPartitionMap.put(groups(1), topic1And2List) + groupToPartitionMap.put(groups(2), allTopicsList) + groupToPartitionMap.put(groups(3), null) + groupToPartitionMap.put(groups(4), null) + + createTopic(topics(0)) + createTopic(topics(1), numPartitions = 2) + createTopic(topics(2), numPartitions = 3) + groupResources.foreach(r => { + addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, READ, ALLOW)), r) + }) + topicResources.foreach(t => { + addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, READ, ALLOW)), t) + }) + + val offset = 15L + val leaderEpoch: Optional[Integer] = Optional.of(1) + val metadata = "metadata" + + def commitOffsets(tpList: util.List[TopicPartition]): Unit = { + val consumer = createConsumer() + consumer.assign(tpList) + val offsets = tpList.asScala.map{ + tp => (tp, new OffsetAndMetadata(offset, leaderEpoch, metadata)) + }.toMap.asJava + 
consumer.commitSync(offsets) + consumer.close() + } + + // create 5 consumers to commit offsets so we can fetch them later + val partitionMap = groupToPartitionMap.asScala.map(e => (e._1, Option(e._2).getOrElse(allTopicsList))) + groups.foreach { groupId => + consumerConfig.setProperty(ConsumerConfig.GROUP_ID_CONFIG, groupId) + commitOffsets(partitionMap(groupId)) + } + + removeAllClientAcls() + + def verifyPartitionData(partitionData: OffsetFetchResponse.PartitionData): Unit = { + assertTrue(!partitionData.hasError) + assertEquals(offset, partitionData.offset) + assertEquals(metadata, partitionData.metadata) + assertEquals(leaderEpoch.get(), partitionData.leaderEpoch.get()) + } + + def verifyResponse(groupLevelResponse: Errors, + partitionData: util.Map[TopicPartition, PartitionData], + topicList: util.List[TopicPartition]): Unit = { + assertEquals(Errors.NONE, groupLevelResponse) + assertTrue(partitionData.size() == topicList.size()) + topicList.forEach(t => verifyPartitionData(partitionData.get(t))) + } + + // test handling partial errors, where one group is fully authorized, some groups don't have + // the right topic authorizations, and some groups have no authorization + addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, READ, ALLOW)), groupResources(0)) + addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, READ, ALLOW)), groupResources(1)) + addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, READ, ALLOW)), groupResources(3)) + addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, DESCRIBE, ALLOW)), topicResources(0)) + val offsetFetchRequest = createOffsetFetchRequest(groupToPartitionMap) + var offsetFetchResponse = connectAndReceive[OffsetFetchResponse](offsetFetchRequest) + offsetFetchResponse.data().groups().forEach(g => + g.groupId() match { + case "group1" => + verifyResponse(offsetFetchResponse.groupLevelError(groups(0)), offsetFetchResponse + .partitionDataMap(groups(0)), topic1List) + case "group2" => + assertEquals(Errors.NONE, offsetFetchResponse.groupLevelError(groups(1))) + val group2Response = offsetFetchResponse.partitionDataMap(groups(1)) + assertTrue(group2Response.size() == 3) + assertTrue(group2Response.keySet().containsAll(topic1And2List)) + verifyPartitionData(group2Response.get(topic1And2List.get(0))) + assertTrue(group2Response.get(topic1And2List.get(1)).hasError) + assertTrue(group2Response.get(topic1And2List.get(2)).hasError) + assertEquals(OffsetFetchResponse.UNAUTHORIZED_PARTITION, group2Response.get(topic1And2List.get(1))) + assertEquals(OffsetFetchResponse.UNAUTHORIZED_PARTITION, group2Response.get(topic1And2List.get(2))) + case "group3" => + assertEquals(Errors.GROUP_AUTHORIZATION_FAILED, offsetFetchResponse.groupLevelError(groups(2))) + assertTrue(offsetFetchResponse.partitionDataMap(groups(2)).size() == 0) + case "group4" => + verifyResponse(offsetFetchResponse.groupLevelError(groups(3)), offsetFetchResponse + .partitionDataMap(groups(3)), topic1List) + case "group5" => + assertEquals(Errors.GROUP_AUTHORIZATION_FAILED, offsetFetchResponse.groupLevelError(groups(4))) + assertTrue(offsetFetchResponse.partitionDataMap(groups(4)).size() == 0) + }) + + // test that after adding some of the ACLs, we get no group level authorization errors, but + // still get topic level authorization errors for topics we don't have ACLs for + addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, READ, ALLOW)), groupResources(2)) + 
addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, READ, ALLOW)), groupResources(4)) + addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, DESCRIBE, ALLOW)), topicResources(1)) + offsetFetchResponse = connectAndReceive[OffsetFetchResponse](offsetFetchRequest) + offsetFetchResponse.data().groups().forEach(g => + g.groupId() match { + case "group1" => + verifyResponse(offsetFetchResponse.groupLevelError(groups(0)), offsetFetchResponse + .partitionDataMap(groups(0)), topic1List) + case "group2" => + verifyResponse(offsetFetchResponse.groupLevelError(groups(1)), offsetFetchResponse + .partitionDataMap(groups(1)), topic1And2List) + case "group3" => + assertEquals(Errors.NONE, offsetFetchResponse.groupLevelError(groups(2))) + val group3Response = offsetFetchResponse.partitionDataMap(groups(2)) + assertTrue(group3Response.size() == 6) + assertTrue(group3Response.keySet().containsAll(allTopicsList)) + verifyPartitionData(group3Response.get(allTopicsList.get(0))) + verifyPartitionData(group3Response.get(allTopicsList.get(1))) + verifyPartitionData(group3Response.get(allTopicsList.get(2))) + assertTrue(group3Response.get(allTopicsList.get(3)).hasError) + assertTrue(group3Response.get(allTopicsList.get(4)).hasError) + assertTrue(group3Response.get(allTopicsList.get(5)).hasError) + assertEquals(OffsetFetchResponse.UNAUTHORIZED_PARTITION, group3Response.get(allTopicsList.get(3))) + assertEquals(OffsetFetchResponse.UNAUTHORIZED_PARTITION, group3Response.get(allTopicsList.get(4))) + assertEquals(OffsetFetchResponse.UNAUTHORIZED_PARTITION, group3Response.get(allTopicsList.get(5))) + case "group4" => + verifyResponse(offsetFetchResponse.groupLevelError(groups(3)), offsetFetchResponse + .partitionDataMap(groups(3)), topic1And2List) + case "group5" => + verifyResponse(offsetFetchResponse.groupLevelError(groups(4)), offsetFetchResponse + .partitionDataMap(groups(4)), topic1And2List) + }) + + // test that after adding all necessary ACLs, we get no partition level or group level errors + // from the offsetFetch response + addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, DESCRIBE, ALLOW)), topicResources(2)) + offsetFetchResponse = connectAndReceive[OffsetFetchResponse](offsetFetchRequest) + offsetFetchResponse.data.groups.asScala.map(_.groupId).foreach( groupId => + verifyResponse(offsetFetchResponse.groupLevelError(groupId), offsetFetchResponse.partitionDataMap(groupId), partitionMap(groupId)) + ) + } + + @Test + def testOffsetFetchTopicDescribe(): Unit = { + createTopic(topic) + addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, DESCRIBE, ALLOW)), groupResource) + addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, DESCRIBE, ALLOW)), topicResource) + val consumer = createConsumer() + consumer.assign(List(tp).asJava) + consumer.position(tp) + } + + @Test + def testOffsetFetchWithTopicAndGroupRead(): Unit = { + createTopic(topic) + addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, READ, ALLOW)), groupResource) + addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, READ, ALLOW)), topicResource) + val consumer = createConsumer() + consumer.assign(List(tp).asJava) + consumer.position(tp) + } + + @Test + def testMetadataWithNoTopicAccess(): Unit = { + val consumer = createConsumer() + assertThrows(classOf[TopicAuthorizationException], () => consumer.partitionsFor(topic)) + } + + @Test + def 
testMetadataWithTopicDescribe(): Unit = { + createTopic(topic) + addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, DESCRIBE, ALLOW)), topicResource) + val consumer = createConsumer() + consumer.partitionsFor(topic) + } + + @Test + def testListOffsetsWithNoTopicAccess(): Unit = { + val consumer = createConsumer() + assertThrows(classOf[TopicAuthorizationException], () => consumer.endOffsets(Set(tp).asJava)) + } + + @Test + def testListOffsetsWithTopicDescribe(): Unit = { + createTopic(topic) + addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, DESCRIBE, ALLOW)), topicResource) + val consumer = createConsumer() + consumer.endOffsets(Set(tp).asJava) + } + + @Test + def testDescribeGroupApiWithNoGroupAcl(): Unit = { + addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, DESCRIBE, ALLOW)), topicResource) + val result = createAdminClient().describeConsumerGroups(Seq(group).asJava) + TestUtils.assertFutureExceptionTypeEquals(result.describedGroups().get(group), classOf[GroupAuthorizationException]) + } + + @Test + def testDescribeGroupApiWithGroupDescribe(): Unit = { + createTopic(topic) + addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, DESCRIBE, ALLOW)), groupResource) + addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, DESCRIBE, ALLOW)), topicResource) + createAdminClient().describeConsumerGroups(Seq(group).asJava).describedGroups().get(group).get() + } + + @Test + def testDescribeGroupCliWithGroupDescribe(): Unit = { + createTopic(topic) + addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, DESCRIBE, ALLOW)), groupResource) + addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, DESCRIBE, ALLOW)), topicResource) + + val cgcArgs = Array("--bootstrap-server", brokerList, "--describe", "--group", group) + val opts = new ConsumerGroupCommandOptions(cgcArgs) + val consumerGroupService = new ConsumerGroupService(opts) + consumerGroupService.describeGroups() + consumerGroupService.close() + } + + @Test + def testListGroupApiWithAndWithoutListGroupAcls(): Unit = { + createTopic(topic) + + // write some record to the topic + addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, WRITE, ALLOW)), topicResource) + val producer = createProducer() + sendRecords(producer, numRecords = 1, tp) + + // use two consumers to write to two different groups + val group2 = "other group" + addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, READ, ALLOW)), groupResource) + addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, READ, ALLOW)), new ResourcePattern(GROUP, group2, LITERAL)) + addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, READ, ALLOW)), topicResource) + val consumer = createConsumer() + consumer.subscribe(Collections.singleton(topic)) + consumeRecords(consumer) + + val otherConsumerProps = new Properties + otherConsumerProps.put(ConsumerConfig.GROUP_ID_CONFIG, group2) + val otherConsumer = createConsumer(configOverrides = otherConsumerProps) + otherConsumer.subscribe(Collections.singleton(topic)) + consumeRecords(otherConsumer) + + val adminClient = createAdminClient() + + // first use cluster describe permission + removeAllClientAcls() + addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, DESCRIBE, ALLOW)), clusterResource) + // it should list both groups (due 
to cluster describe permission) + assertEquals(Set(group, group2), adminClient.listConsumerGroups().all().get().asScala.map(_.groupId()).toSet) + + // now replace cluster describe with group read permission + removeAllClientAcls() + addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, READ, ALLOW)), groupResource) + // it should list only one group now + val groupList = adminClient.listConsumerGroups().all().get().asScala.toList + assertEquals(1, groupList.length) + assertEquals(group, groupList.head.groupId) + + // now remove all acls and verify describe group access is required to list any group + removeAllClientAcls() + val listGroupResult = adminClient.listConsumerGroups() + assertEquals(List(), listGroupResult.errors().get().asScala.toList) + assertEquals(List(), listGroupResult.all().get().asScala.toList) + otherConsumer.close() + } + + @Test + def testDeleteGroupApiWithDeleteGroupAcl(): Unit = { + createTopic(topic) + + addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, READ, ALLOW)), groupResource) + addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, READ, ALLOW)), topicResource) + addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, DELETE, ALLOW)), groupResource) + val consumer = createConsumer() + consumer.assign(List(tp).asJava) + consumer.commitSync(Map(tp -> new OffsetAndMetadata(5, "")).asJava) + createAdminClient().deleteConsumerGroups(Seq(group).asJava).deletedGroups().get(group).get() + } + + @Test + def testDeleteGroupApiWithNoDeleteGroupAcl(): Unit = { + createTopic(topic) + + addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, READ, ALLOW)), groupResource) + addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, READ, ALLOW)), topicResource) + val consumer = createConsumer() + consumer.assign(List(tp).asJava) + consumer.commitSync(Map(tp -> new OffsetAndMetadata(5, "")).asJava) + val result = createAdminClient().deleteConsumerGroups(Seq(group).asJava) + TestUtils.assertFutureExceptionTypeEquals(result.deletedGroups().get(group), classOf[GroupAuthorizationException]) + } + + @Test + def testDeleteGroupApiWithNoDeleteGroupAcl2(): Unit = { + val result = createAdminClient().deleteConsumerGroups(Seq(group).asJava) + TestUtils.assertFutureExceptionTypeEquals(result.deletedGroups().get(group), classOf[GroupAuthorizationException]) + } + + @Test + def testDeleteGroupOffsetsWithAcl(): Unit = { + createTopic(topic) + + addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, DELETE, ALLOW)), groupResource) + addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, READ, ALLOW)), groupResource) + addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, READ, ALLOW)), topicResource) + val consumer = createConsumer() + consumer.assign(List(tp).asJava) + consumer.commitSync(Map(tp -> new OffsetAndMetadata(5, "")).asJava) + consumer.close() + val result = createAdminClient().deleteConsumerGroupOffsets(group, Set(tp).asJava) + assertNull(result.partitionResult(tp).get()) + } + + @Test + def testDeleteGroupOffsetsWithoutDeleteAcl(): Unit = { + createTopic(topic) + + addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, READ, ALLOW)), groupResource) + addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, READ, ALLOW)), topicResource) + val consumer = createConsumer() + 
consumer.assign(List(tp).asJava) + consumer.commitSync(Map(tp -> new OffsetAndMetadata(5, "")).asJava) + consumer.close() + val result = createAdminClient().deleteConsumerGroupOffsets(group, Set(tp).asJava) + TestUtils.assertFutureExceptionTypeEquals(result.all(), classOf[GroupAuthorizationException]) + } + + @Test + def testDeleteGroupOffsetsWithDeleteAclWithoutTopicAcl(): Unit = { + createTopic(topic) + // Create the consumer group + addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, READ, ALLOW)), groupResource) + addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, READ, ALLOW)), topicResource) + val consumer = createConsumer() + consumer.assign(List(tp).asJava) + consumer.commitSync(Map(tp -> new OffsetAndMetadata(5, "")).asJava) + consumer.close() + + // Remove the topic ACL & Check that it does not work without it + removeAllClientAcls() + addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, DELETE, ALLOW)), groupResource) + addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, READ, ALLOW)), groupResource) + val result = createAdminClient().deleteConsumerGroupOffsets(group, Set(tp).asJava) + TestUtils.assertFutureExceptionTypeEquals(result.all(), classOf[TopicAuthorizationException]) + TestUtils.assertFutureExceptionTypeEquals(result.partitionResult(tp), classOf[TopicAuthorizationException]) + } + + @Test + def testDeleteGroupOffsetsWithNoAcl(): Unit = { + val result = createAdminClient().deleteConsumerGroupOffsets(group, Set(tp).asJava) + TestUtils.assertFutureExceptionTypeEquals(result.all(), classOf[GroupAuthorizationException]) + } + + @Test + def testUnauthorizedDeleteTopicsWithoutDescribe(): Unit = { + val deleteResponse = connectAndReceive[DeleteTopicsResponse](deleteTopicsRequest) + assertEquals(Errors.TOPIC_AUTHORIZATION_FAILED.code, deleteResponse.data.responses.find(topic).errorCode) + } + + @Test + def testUnauthorizedDeleteTopicsWithDescribe(): Unit = { + createTopic(topic) + addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, DESCRIBE, ALLOW)), topicResource) + val deleteResponse = connectAndReceive[DeleteTopicsResponse](deleteTopicsRequest) + assertEquals(Errors.TOPIC_AUTHORIZATION_FAILED.code, deleteResponse.data.responses.find(topic).errorCode) + } + + @Test + def testDeleteTopicsWithWildCardAuth(): Unit = { + createTopic(topic) + addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, DELETE, ALLOW)), new ResourcePattern(TOPIC, "*", LITERAL)) + val deleteResponse = connectAndReceive[DeleteTopicsResponse](deleteTopicsRequest) + assertEquals(Errors.NONE.code, deleteResponse.data.responses.find(topic).errorCode) + } + + @Test + def testUnauthorizedDeleteRecordsWithoutDescribe(): Unit = { + val deleteRecordsResponse = connectAndReceive[DeleteRecordsResponse](deleteRecordsRequest) + assertEquals(Errors.TOPIC_AUTHORIZATION_FAILED.code, deleteRecordsResponse.data.topics.asScala.head. + partitions.asScala.head.errorCode) + } + + @Test + def testUnauthorizedDeleteRecordsWithDescribe(): Unit = { + createTopic(topic) + addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, DESCRIBE, ALLOW)), topicResource) + val deleteRecordsResponse = connectAndReceive[DeleteRecordsResponse](deleteRecordsRequest) + assertEquals(Errors.TOPIC_AUTHORIZATION_FAILED.code, deleteRecordsResponse.data.topics.asScala.head. 
+ partitions.asScala.head.errorCode) + } + + @Test + def testDeleteRecordsWithWildCardAuth(): Unit = { + createTopic(topic) + addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, DELETE, ALLOW)), new ResourcePattern(TOPIC, "*", LITERAL)) + val deleteRecordsResponse = connectAndReceive[DeleteRecordsResponse](deleteRecordsRequest) + assertEquals(Errors.NONE.code, deleteRecordsResponse.data.topics.asScala.head. + partitions.asScala.head.errorCode) + } + + @Test + def testUnauthorizedCreatePartitions(): Unit = { + val createPartitionsResponse = connectAndReceive[CreatePartitionsResponse](createPartitionsRequest) + assertEquals(Errors.TOPIC_AUTHORIZATION_FAILED.code, createPartitionsResponse.data.results.asScala.head.errorCode) + } + + @Test + def testCreatePartitionsWithWildCardAuth(): Unit = { + createTopic(topic) + addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, ALTER, ALLOW)), new ResourcePattern(TOPIC, "*", LITERAL)) + val createPartitionsResponse = connectAndReceive[CreatePartitionsResponse](createPartitionsRequest) + assertEquals(Errors.NONE.code, createPartitionsResponse.data.results.asScala.head.errorCode) + } + + @Test + def testTransactionalProducerInitTransactionsNoWriteTransactionalIdAcl(): Unit = { + addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, DESCRIBE, ALLOW)), transactionalIdResource) + val producer = buildTransactionalProducer() + assertThrows(classOf[TransactionalIdAuthorizationException], () => producer.initTransactions()) + } + + @Test + def testTransactionalProducerInitTransactionsNoDescribeTransactionalIdAcl(): Unit = { + val producer = buildTransactionalProducer() + assertThrows(classOf[TransactionalIdAuthorizationException], () => producer.initTransactions()) + } + + @Test + def testSendOffsetsWithNoConsumerGroupDescribeAccess(): Unit = { + createTopic(topic) + + addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, CLUSTER_ACTION, ALLOW)), clusterResource) + addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, WRITE, ALLOW)), topicResource) + addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, WRITE, ALLOW)), transactionalIdResource) + val producer = buildTransactionalProducer() + producer.initTransactions() + producer.beginTransaction() + + assertThrows(classOf[GroupAuthorizationException], + () => producer.sendOffsetsToTransaction(Map(tp -> new OffsetAndMetadata(0L)).asJava, new ConsumerGroupMetadata(group))) + } + + @Test + def testSendOffsetsWithNoConsumerGroupWriteAccess(): Unit = { + createTopic(topic) + + addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, WRITE, ALLOW)), transactionalIdResource) + addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, DESCRIBE, ALLOW)), groupResource) + val producer = buildTransactionalProducer() + producer.initTransactions() + producer.beginTransaction() + + assertThrows(classOf[GroupAuthorizationException], + () => producer.sendOffsetsToTransaction(Map(tp -> new OffsetAndMetadata(0L)).asJava, new ConsumerGroupMetadata(group))) + } + + @Test + def testIdempotentProducerNoIdempotentWriteAclInInitProducerId(): Unit = { + createTopic(topic) + addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, READ, ALLOW)), topicResource) + shouldIdempotentProducerFailInInitProducerId(true) + } + + def shouldIdempotentProducerFailInInitProducerId(expectAuthException: Boolean): Unit = { 
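+    // Shared helper: sends twice with an idempotent producer. When expectAuthException is true, each
+    // attempt is expected to fail with ClusterAuthorizationException (raised either from the implicit
+    // InitProducerId request or from the send itself); when false, both sends are expected to succeed.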
+ val producer = buildIdempotentProducer() + try { + // the InitProducerId is sent asynchronously, so we expect the error either in the callback + // or raised from send itself + producer.send(new ProducerRecord[Array[Byte], Array[Byte]](topic, "hi".getBytes)).get() + if (expectAuthException) + fail("Should have raised ClusterAuthorizationException") + } catch { + case e: ExecutionException => + assertTrue(e.getCause.isInstanceOf[ClusterAuthorizationException]) + } + try { + // the second time, the call to send itself should fail (the producer becomes unusable + // if no producerId can be obtained) + producer.send(new ProducerRecord[Array[Byte], Array[Byte]](topic, "hi".getBytes)).get() + if (expectAuthException) + fail("Should have raised ClusterAuthorizationException") + } catch { + case e: ExecutionException => + assertTrue(e.getCause.isInstanceOf[ClusterAuthorizationException]) + } + } + + @Test + def testIdempotentProducerNoIdempotentWriteAclInProduce(): Unit = { + createTopic(topic) + addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, WRITE, ALLOW)), topicResource) + addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, IDEMPOTENT_WRITE, ALLOW)), clusterResource) + idempotentProducerShouldFailInProduce(() => removeAllClientAcls()) + } + + def idempotentProducerShouldFailInProduce(removeAclIdempotenceRequired: () => Unit): Unit = { + val producer = buildIdempotentProducer() + + // first send should be fine since we have permission to get a ProducerId + producer.send(new ProducerRecord[Array[Byte], Array[Byte]](topic, "hi".getBytes)).get() + + // revoke the IdempotentWrite permission + removeAclIdempotenceRequired() + addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, DESCRIBE, ALLOW)), topicResource) + + // the send should now fail with a cluster auth error + var e = assertThrows(classOf[ExecutionException], () => producer.send(new ProducerRecord[Array[Byte], Array[Byte]](topic, "hi".getBytes)).get()) + assertTrue(e.getCause.isInstanceOf[TopicAuthorizationException]) + + // the second time, the call to send itself should fail (the producer becomes unusable + // if no producerId can be obtained) + e = assertThrows(classOf[ExecutionException], () => producer.send(new ProducerRecord[Array[Byte], Array[Byte]](topic, "hi".getBytes)).get()) + assertTrue(e.getCause.isInstanceOf[TopicAuthorizationException]) + } + + @Test + def shouldInitTransactionsWhenAclSet(): Unit = { + addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, WRITE, ALLOW)), transactionalIdResource) + val producer = buildTransactionalProducer() + producer.initTransactions() + } + + @Test + def testTransactionalProducerTopicAuthorizationExceptionInSendCallback(): Unit = { + createTopic(topic) + + addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, WRITE, ALLOW)), transactionalIdResource) + // add describe access so that we can fetch metadata + addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, DESCRIBE, ALLOW)), topicResource) + val producer = buildTransactionalProducer() + producer.initTransactions() + producer.beginTransaction() + + val future = producer.send(new ProducerRecord(tp.topic, tp.partition, "1".getBytes, "1".getBytes)) + val e = JTestUtils.assertFutureThrows(future, classOf[TopicAuthorizationException]) + assertEquals(Set(topic), e.unauthorizedTopics.asScala) + } + + @Test + def testTransactionalProducerTopicAuthorizationExceptionInCommit(): 
Unit = { + createTopic(topic) + + addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, WRITE, ALLOW)), transactionalIdResource) + // add describe access so that we can fetch metadata + addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, DESCRIBE, ALLOW)), topicResource) + val producer = buildTransactionalProducer() + producer.initTransactions() + producer.beginTransaction() + + assertThrows(classOf[TopicAuthorizationException], () => { + producer.send(new ProducerRecord(tp.topic, tp.partition, "1".getBytes, "1".getBytes)) + producer.commitTransaction() + }) + } + + @Test + def shouldThrowTransactionalIdAuthorizationExceptionWhenNoTransactionAccessDuringSend(): Unit = { + createTopic(topic) + + addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, WRITE, ALLOW)), transactionalIdResource) + val producer = buildTransactionalProducer() + producer.initTransactions() + removeAllClientAcls() + addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, WRITE, ALLOW)), topicResource) + producer.beginTransaction() + val future = producer.send(new ProducerRecord(tp.topic, tp.partition, "1".getBytes, "1".getBytes)) + JTestUtils.assertFutureThrows(future, classOf[TransactionalIdAuthorizationException]) + } + + @Test + def shouldThrowTransactionalIdAuthorizationExceptionWhenNoTransactionAccessOnEndTransaction(): Unit = { + createTopic(topic) + + addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, WRITE, ALLOW)), transactionalIdResource) + addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, WRITE, ALLOW)), topicResource) + val producer = buildTransactionalProducer() + producer.initTransactions() + producer.beginTransaction() + producer.send(new ProducerRecord(tp.topic, tp.partition, "1".getBytes, "1".getBytes)).get + removeAllClientAcls() + assertThrows(classOf[TransactionalIdAuthorizationException], () => producer.commitTransaction()) + } + + @Test + def testListTransactionsAuthorization(): Unit = { + createTopic(topic) + addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, WRITE, ALLOW)), transactionalIdResource) + addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, WRITE, ALLOW)), topicResource) + + // Start a transaction and write to a topic. 
+ val producer = buildTransactionalProducer() + producer.initTransactions() + producer.beginTransaction() + producer.send(new ProducerRecord(tp.topic, tp.partition, "1".getBytes, "1".getBytes)).get + + def assertListTransactionResult( + expectedTransactionalIds: Set[String] + ): Unit = { + val listTransactionsRequest = new ListTransactionsRequest.Builder(new ListTransactionsRequestData()).build() + val listTransactionsResponse = connectAndReceive[ListTransactionsResponse](listTransactionsRequest) + assertEquals(Errors.NONE, Errors.forCode(listTransactionsResponse.data.errorCode)) + assertEquals(expectedTransactionalIds, listTransactionsResponse.data.transactionStates.asScala.map(_.transactionalId).toSet) + } + + // First verify that we can list the transaction + assertListTransactionResult(expectedTransactionalIds = Set(transactionalId)) + + // Now revoke authorization and verify that the transaction is no longer listable + removeAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, WRITE, ALLOW)), transactionalIdResource) + assertListTransactionResult(expectedTransactionalIds = Set()) + + // The minimum permission needed is `Describe` + addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, DESCRIBE, ALLOW)), transactionalIdResource) + assertListTransactionResult(expectedTransactionalIds = Set(transactionalId)) + } + + @Test + def shouldNotIncludeUnauthorizedTopicsInDescribeTransactionsResponse(): Unit = { + createTopic(topic) + addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, WRITE, ALLOW)), transactionalIdResource) + addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, WRITE, ALLOW)), topicResource) + + // Start a transaction and write to a topic. + val producer = buildTransactionalProducer() + producer.initTransactions() + producer.beginTransaction() + producer.send(new ProducerRecord(tp.topic, tp.partition, "1".getBytes, "1".getBytes)).get + + // Remove only topic authorization so that we can verify that the + // topic does not get included in the response. 
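+    // The Write ACL on the transactional id remains in place, so the transaction state itself is still returned.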
+ removeAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, WRITE, ALLOW)), topicResource) + val response = connectAndReceive[DescribeTransactionsResponse](describeTransactionsRequest) + assertEquals(1, response.data.transactionStates.size) + val transactionStateData = response.data.transactionStates.asScala.find(_.transactionalId == transactionalId).get + assertEquals("Ongoing", transactionStateData.transactionState) + assertEquals(List.empty, transactionStateData.topics.asScala.toList) + } + + @Test + def shouldSuccessfullyAbortTransactionAfterTopicAuthorizationException(): Unit = { + createTopic(topic) + addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, WRITE, ALLOW)), transactionalIdResource) + addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, WRITE, ALLOW)), topicResource) + addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, DESCRIBE, ALLOW)), + new ResourcePattern(TOPIC, topic, LITERAL)) + val producer = buildTransactionalProducer() + producer.initTransactions() + producer.beginTransaction() + producer.send(new ProducerRecord(tp.topic, tp.partition, "1".getBytes, "1".getBytes)).get + // try and add a partition resulting in TopicAuthorizationException + val future = producer.send(new ProducerRecord("otherTopic", 0, "1".getBytes, "1".getBytes)) + val e = JTestUtils.assertFutureThrows(future, classOf[TopicAuthorizationException]) + assertEquals(Set("otherTopic"), e.unauthorizedTopics.asScala) + // now rollback + producer.abortTransaction() + } + + @Test + def shouldThrowTransactionalIdAuthorizationExceptionWhenNoTransactionAccessOnSendOffsetsToTxn(): Unit = { + addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, WRITE, ALLOW)), transactionalIdResource) + addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, WRITE, ALLOW)), groupResource) + val producer = buildTransactionalProducer() + producer.initTransactions() + producer.beginTransaction() + removeAllClientAcls() + assertThrows(classOf[TransactionalIdAuthorizationException], () => { + val offsets = Map(tp -> new OffsetAndMetadata(1L)).asJava + producer.sendOffsetsToTransaction(offsets, new ConsumerGroupMetadata(group)) + producer.commitTransaction() + }) + } + + @Test + def shouldSendSuccessfullyWhenIdempotentAndHasCorrectACL(): Unit = { + createTopic(topic) + addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, IDEMPOTENT_WRITE, ALLOW)), clusterResource) + addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, WRITE, ALLOW)), topicResource) + val producer = buildIdempotentProducer() + producer.send(new ProducerRecord(tp.topic, tp.partition, "1".getBytes, "1".getBytes)).get + } + + // Verify that metadata request without topics works without any ACLs and returns cluster id + @Test + def testClusterId(): Unit = { + val request = new requests.MetadataRequest.Builder(List.empty.asJava, false).build() + val response = connectAndReceive[MetadataResponse](request) + assertEquals(Collections.emptyMap, response.errorCounts) + assertFalse(response.clusterId.isEmpty, "Cluster id not returned") + } + + @Test + def testAuthorizeByResourceTypeMultipleAddAndRemove(): Unit = { + createTopic(topic) + + for (_ <- 1 to 3) { + addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, DESCRIBE, ALLOW)), topicResource) + shouldIdempotentProducerFailInInitProducerId(true) + + addAndVerifyAcls(Set(new 
AccessControlEntry(clientPrincipalString, WildcardHost, WRITE, ALLOW)), topicResource) + shouldIdempotentProducerFailInInitProducerId(false) + + removeAllClientAcls() + addAndVerifyAcls(Set(new AccessControlEntry(clientPrincipalString, WildcardHost, DESCRIBE, ALLOW)), topicResource) + shouldIdempotentProducerFailInInitProducerId(true) + } + } + + @Test + def testAuthorizeByResourceTypeIsolationUnrelatedDenyWontDominateAllow(): Unit = { + createTopic(topic) + createTopic("topic-2") + createTopic("to") + + val unrelatedPrincipalString = new KafkaPrincipal(KafkaPrincipal.USER_TYPE, "unrelated").toString + val unrelatedTopicResource = new ResourcePattern(TOPIC, "topic-2", LITERAL) + val unrelatedGroupResource = new ResourcePattern(GROUP, "to", PREFIXED) + + val acl1 = new AccessControlEntry(clientPrincipalString, WildcardHost, READ, DENY) + val acl2 = new AccessControlEntry(unrelatedPrincipalString, WildcardHost, READ, DENY) + val acl3 = new AccessControlEntry(clientPrincipalString, WildcardHost, WRITE, DENY) + val acl4 = new AccessControlEntry(clientPrincipalString, WildcardHost, WRITE, ALLOW) + val acl5 = new AccessControlEntry(clientPrincipalString, WildcardHost, DESCRIBE, ALLOW) + + addAndVerifyAcls(Set(acl1, acl4, acl5), topicResource) + addAndVerifyAcls(Set(acl2, acl3), unrelatedTopicResource) + addAndVerifyAcls(Set(acl2, acl3), unrelatedGroupResource) + shouldIdempotentProducerFailInInitProducerId(false) + } + + @Test + def testAuthorizeByResourceTypeDenyTakesPrecedence(): Unit = { + createTopic(topic) + val allowWriteAce = new AccessControlEntry(clientPrincipalString, WildcardHost, WRITE, ALLOW) + addAndVerifyAcls(Set(allowWriteAce), topicResource) + shouldIdempotentProducerFailInInitProducerId(false) + + val denyWriteAce = new AccessControlEntry(clientPrincipalString, WildcardHost, WRITE, DENY) + addAndVerifyAcls(Set(denyWriteAce), topicResource) + shouldIdempotentProducerFailInInitProducerId(true) + } + + @Test + def testAuthorizeByResourceTypeWildcardResourceDenyDominate(): Unit = { + createTopic(topic) + val wildcard = new ResourcePattern(TOPIC, ResourcePattern.WILDCARD_RESOURCE, LITERAL) + val prefixed = new ResourcePattern(TOPIC, "t", PREFIXED) + val literal = new ResourcePattern(TOPIC, topic, LITERAL) + val allowWriteAce = new AccessControlEntry(clientPrincipalString, WildcardHost, WRITE, ALLOW) + val denyWriteAce = new AccessControlEntry(clientPrincipalString, WildcardHost, WRITE, DENY) + + addAndVerifyAcls(Set(allowWriteAce), prefixed) + addAndVerifyAcls(Set(allowWriteAce), literal) + shouldIdempotentProducerFailInInitProducerId(false) + + addAndVerifyAcls(Set(denyWriteAce), wildcard) + shouldIdempotentProducerFailInInitProducerId(true) + } + + @Test + def testAuthorizeByResourceTypePrefixedResourceDenyDominate(): Unit = { + createTopic(topic) + val prefixed = new ResourcePattern(TOPIC, topic.substring(0, 1), PREFIXED) + val literal = new ResourcePattern(TOPIC, topic, LITERAL) + val allowWriteAce = new AccessControlEntry(clientPrincipalString, WildcardHost, WRITE, ALLOW) + val denyWriteAce = new AccessControlEntry(clientPrincipalString, WildcardHost, WRITE, DENY) + + addAndVerifyAcls(Set(denyWriteAce), prefixed) + addAndVerifyAcls(Set(allowWriteAce), literal) + shouldIdempotentProducerFailInInitProducerId(true) + } + + @Test + def testMetadataClusterAuthorizedOperationsWithoutDescribeCluster(): Unit = { + removeAllClientAcls() + + // MetadataRequest versions older than 1 are not supported. 
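+    // With all client ACLs removed, the expected cluster authorized operations bit field is 0.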
+ for (version <- 1 to ApiKeys.METADATA.latestVersion) { + testMetadataClusterClusterAuthorizedOperations(version.toShort, 0) + } + } + + @Test + def testMetadataClusterAuthorizedOperationsWithDescribeAndAlterCluster(): Unit = { + removeAllClientAcls() + + val clusterResource = new ResourcePattern(ResourceType.CLUSTER, Resource.CLUSTER_NAME, PatternType.LITERAL) + val acls = Set( + new AccessControlEntry(clientPrincipalString, WildcardHost, DESCRIBE, ALLOW), + new AccessControlEntry(clientPrincipalString, WildcardHost, ALTER, ALLOW) + ) + addAndVerifyAcls(acls, clusterResource) + + val expectedClusterAuthorizedOperations = Utils.to32BitField( + acls.map(_.operation.code.asInstanceOf[JByte]).asJava) + + // MetadataRequest versions older than 1 are not supported. + for (version <- 1 to ApiKeys.METADATA.latestVersion) { + testMetadataClusterClusterAuthorizedOperations(version.toShort, expectedClusterAuthorizedOperations) + } + } + + private def testMetadataClusterClusterAuthorizedOperations( + version: Short, + expectedClusterAuthorizedOperations: Int + ): Unit = { + val metadataRequest = new MetadataRequest.Builder(new MetadataRequestData() + .setTopics(Collections.emptyList()) + .setAllowAutoTopicCreation(true) + .setIncludeClusterAuthorizedOperations(true)) + .build(version) + + // The expected value is only verified if the request supports it. + if (version >= 8 && version <= 10) { + val metadataResponse = connectAndReceive[MetadataResponse](metadataRequest) + assertEquals(expectedClusterAuthorizedOperations, metadataResponse.data.clusterAuthorizedOperations) + } else { + assertThrows(classOf[UnsupportedVersionException], + () => connectAndReceive[MetadataResponse](metadataRequest)) + } + } + + @Test + def testDescribeClusterClusterAuthorizedOperationsWithoutDescribeCluster(): Unit = { + removeAllClientAcls() + + for (version <- ApiKeys.DESCRIBE_CLUSTER.oldestVersion to ApiKeys.DESCRIBE_CLUSTER.latestVersion) { + testDescribeClusterClusterAuthorizedOperations(version.toShort, 0) + } + } + + @Test + def testDescribeClusterClusterAuthorizedOperationsWithDescribeAndAlterCluster(): Unit = { + removeAllClientAcls() + + val clusterResource = new ResourcePattern(ResourceType.CLUSTER, Resource.CLUSTER_NAME, PatternType.LITERAL) + val acls = Set( + new AccessControlEntry(clientPrincipalString, WildcardHost, DESCRIBE, ALLOW), + new AccessControlEntry(clientPrincipalString, WildcardHost, ALTER, ALLOW) + ) + addAndVerifyAcls(acls, clusterResource) + + val expectedClusterAuthorizedOperations = Utils.to32BitField( + acls.map(_.operation.code.asInstanceOf[JByte]).asJava) + + for (version <- ApiKeys.DESCRIBE_CLUSTER.oldestVersion to ApiKeys.DESCRIBE_CLUSTER.latestVersion) { + testDescribeClusterClusterAuthorizedOperations(version.toShort, expectedClusterAuthorizedOperations) + } + } + + private def testDescribeClusterClusterAuthorizedOperations( + version: Short, + expectedClusterAuthorizedOperations: Int + ): Unit = { + val describeClusterRequest = new DescribeClusterRequest.Builder(new DescribeClusterRequestData() + .setIncludeClusterAuthorizedOperations(true)) + .build(version) + + val describeClusterResponse = connectAndReceive[DescribeClusterResponse](describeClusterRequest) + assertEquals(expectedClusterAuthorizedOperations, describeClusterResponse.data.clusterAuthorizedOperations) + } + + def removeAllClientAcls(): Unit = { + val authorizer = servers.head.dataPlaneRequestProcessor.authorizer.get + val aclEntryFilter = new AccessControlEntryFilter(clientPrincipalString, null, AclOperation.ANY, 
AclPermissionType.ANY) + val aclFilter = new AclBindingFilter(ResourcePatternFilter.ANY, aclEntryFilter) + + authorizer.deleteAcls(null, List(aclFilter).asJava).asScala.map(_.toCompletableFuture.get).flatMap { deletion => + deletion.aclBindingDeleteResults().asScala.map(_.aclBinding.pattern).toSet + }.foreach { resource => + TestUtils.waitAndVerifyAcls(Set.empty[AccessControlEntry], authorizer, resource, aclEntryFilter) + } + } + + private def sendRequestAndVerifyResponseError(request: AbstractRequest, + resources: Set[ResourceType], + isAuthorized: Boolean, + topicExists: Boolean = true, + topicNames: Map[Uuid, String] = getTopicNames()): AbstractResponse = { + val apiKey = request.apiKey + val response = connectAndReceive[AbstractResponse](request) + val error = requestKeyToError(topicNames, request.version())(apiKey).asInstanceOf[AbstractResponse => Errors](response) + + val authorizationErrors = resources.flatMap { resourceType => + if (resourceType == TOPIC) { + if (isAuthorized) + Set(Errors.UNKNOWN_TOPIC_OR_PARTITION, AclEntry.authorizationError(ResourceType.TOPIC)) + else + Set(AclEntry.authorizationError(ResourceType.TOPIC)) + } else { + Set(AclEntry.authorizationError(resourceType)) + } + } + + if (topicExists) + if (isAuthorized) + assertFalse(authorizationErrors.contains(error), s"$apiKey should be allowed. Found unexpected authorization error $error") + else + assertTrue(authorizationErrors.contains(error), s"$apiKey should be forbidden. Found error $error but expected one of $authorizationErrors") + else if (resources == Set(TOPIC)) + if (isAuthorized) + if (apiKey.equals(ApiKeys.FETCH) && request.version() >= 13) + assertEquals(Errors.UNKNOWN_TOPIC_ID, error, s"$apiKey had an unexpected error") + else + assertEquals(Errors.UNKNOWN_TOPIC_OR_PARTITION, error, s"$apiKey had an unexpected error") + else { + if (apiKey.equals(ApiKeys.FETCH) && request.version() >= 13) + assertEquals(Errors.UNKNOWN_TOPIC_ID, error, s"$apiKey had an unexpected error") + else + assertEquals(Errors.TOPIC_AUTHORIZATION_FAILED, error, s"$apiKey had an unexpected error") + } + + response + } + + private def sendRecords(producer: KafkaProducer[Array[Byte], Array[Byte]], + numRecords: Int, + tp: TopicPartition): Unit = { + val futures = (0 until numRecords).map { i => + producer.send(new ProducerRecord(tp.topic, tp.partition, i.toString.getBytes, i.toString.getBytes)) + } + try { + futures.foreach(_.get) + } catch { + case e: ExecutionException => throw e.getCause + } + } + + private def addAndVerifyAcls(acls: Set[AccessControlEntry], resource: ResourcePattern): Unit = { + TestUtils.addAndVerifyAcls(servers.head, acls, resource) + } + + private def removeAndVerifyAcls(acls: Set[AccessControlEntry], resource: ResourcePattern): Unit = { + TestUtils.removeAndVerifyAcls(servers.head, acls, resource) + } + + private def consumeRecords(consumer: Consumer[Array[Byte], Array[Byte]], + numRecords: Int = 1, + startingOffset: Int = 0, + topic: String = topic, + part: Int = part): Unit = { + val records = TestUtils.consumeRecords(consumer, numRecords) + + for (i <- 0 until numRecords) { + val record = records(i) + val offset = startingOffset + i + assertEquals(topic, record.topic) + assertEquals(part, record.partition) + assertEquals(offset.toLong, record.offset) + } + } + + private def buildTransactionalProducer(): KafkaProducer[Array[Byte], Array[Byte]] = { + producerConfig.setProperty(ProducerConfig.TRANSACTIONAL_ID_CONFIG, transactionalId) + producerConfig.setProperty(ProducerConfig.ACKS_CONFIG, "all") + 
createProducer() + } + + private def buildIdempotentProducer(): KafkaProducer[Array[Byte], Array[Byte]] = { + producerConfig.setProperty(ProducerConfig.ENABLE_IDEMPOTENCE_CONFIG, "true") + producerConfig.setProperty(ProducerConfig.ACKS_CONFIG, "all") + createProducer() + } + + private def createAdminClient(): Admin = { + val props = new Properties() + props.put(AdminClientConfig.BOOTSTRAP_SERVERS_CONFIG, brokerList) + val adminClient = Admin.create(props) + adminClients += adminClient + adminClient + } +} diff --git a/core/src/test/scala/integration/kafka/api/BaseAdminIntegrationTest.scala b/core/src/test/scala/integration/kafka/api/BaseAdminIntegrationTest.scala new file mode 100644 index 0000000..9c8e32f --- /dev/null +++ b/core/src/test/scala/integration/kafka/api/BaseAdminIntegrationTest.scala @@ -0,0 +1,239 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package kafka.api + +import java.util +import java.util.Properties +import java.util.concurrent.ExecutionException +import kafka.security.authorizer.AclEntry +import kafka.server.KafkaConfig +import kafka.utils.Logging +import kafka.utils.TestUtils._ +import org.apache.kafka.clients.admin.{Admin, AdminClientConfig, CreateTopicsOptions, CreateTopicsResult, DescribeClusterOptions, DescribeTopicsOptions, NewTopic, TopicDescription} +import org.apache.kafka.common.Uuid +import org.apache.kafka.common.acl.AclOperation +import org.apache.kafka.common.errors.{TopicExistsException, UnknownTopicOrPartitionException} +import org.apache.kafka.common.resource.ResourceType +import org.apache.kafka.common.utils.Utils +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.{AfterEach, BeforeEach, Test, TestInfo, Timeout} + +import scala.jdk.CollectionConverters._ +import scala.collection.Seq +import scala.compat.java8.OptionConverters._ + +/** + * Base integration test cases for [[Admin]]. Each test case added here will be executed + * in extending classes. Typically we prefer to write basic Admin functionality test cases in + * [[kafka.api.PlaintextAdminIntegrationTest]] rather than here to avoid unnecessary execution + * time to the build. However, if an admin API involves differing interactions with + * authentication/authorization layers, we may add the test case here. 
+ */ +@Timeout(120) +abstract class BaseAdminIntegrationTest extends IntegrationTestHarness with Logging { + def brokerCount = 3 + override def logDirCount = 2 + + var client: Admin = _ + + @BeforeEach + override def setUp(testInfo: TestInfo): Unit = { + super.setUp(testInfo) + waitUntilBrokerMetadataIsPropagated(servers) + } + + @AfterEach + override def tearDown(): Unit = { + if (client != null) + Utils.closeQuietly(client, "AdminClient") + super.tearDown() + } + + @Test + def testCreateDeleteTopics(): Unit = { + client = Admin.create(createConfig) + val topics = Seq("mytopic", "mytopic2", "mytopic3") + val newTopics = Seq( + new NewTopic("mytopic", Map((0: Integer) -> Seq[Integer](1, 2).asJava, (1: Integer) -> Seq[Integer](2, 0).asJava).asJava), + new NewTopic("mytopic2", 3, 3.toShort), + new NewTopic("mytopic3", Option.empty[Integer].asJava, Option.empty[java.lang.Short].asJava) + ) + val validateResult = client.createTopics(newTopics.asJava, new CreateTopicsOptions().validateOnly(true)) + validateResult.all.get() + waitForTopics(client, List(), topics) + + def validateMetadataAndConfigs(result: CreateTopicsResult): Unit = { + assertEquals(2, result.numPartitions("mytopic").get()) + assertEquals(2, result.replicationFactor("mytopic").get()) + assertEquals(3, result.numPartitions("mytopic2").get()) + assertEquals(3, result.replicationFactor("mytopic2").get()) + assertEquals(configs.head.numPartitions, result.numPartitions("mytopic3").get()) + assertEquals(configs.head.defaultReplicationFactor, result.replicationFactor("mytopic3").get()) + assertFalse(result.config("mytopic").get().entries.isEmpty) + } + validateMetadataAndConfigs(validateResult) + + val createResult = client.createTopics(newTopics.asJava) + createResult.all.get() + waitForTopics(client, topics, List()) + validateMetadataAndConfigs(createResult) + val topicIds = getTopicIds() + topics.foreach { topic => + assertNotEquals(Uuid.ZERO_UUID, createResult.topicId(topic).get()) + assertEquals(topicIds(topic), createResult.topicId(topic).get()) + } + + + val failedCreateResult = client.createTopics(newTopics.asJava) + val results = failedCreateResult.values() + assertTrue(results.containsKey("mytopic")) + assertFutureExceptionTypeEquals(results.get("mytopic"), classOf[TopicExistsException]) + assertTrue(results.containsKey("mytopic2")) + assertFutureExceptionTypeEquals(results.get("mytopic2"), classOf[TopicExistsException]) + assertTrue(results.containsKey("mytopic3")) + assertFutureExceptionTypeEquals(results.get("mytopic3"), classOf[TopicExistsException]) + assertFutureExceptionTypeEquals(failedCreateResult.numPartitions("mytopic3"), classOf[TopicExistsException]) + assertFutureExceptionTypeEquals(failedCreateResult.replicationFactor("mytopic3"), classOf[TopicExistsException]) + assertFutureExceptionTypeEquals(failedCreateResult.config("mytopic3"), classOf[TopicExistsException]) + + val topicToDescription = client.describeTopics(topics.asJava).allTopicNames.get() + assertEquals(topics.toSet, topicToDescription.keySet.asScala) + + val topic0 = topicToDescription.get("mytopic") + assertEquals(false, topic0.isInternal) + assertEquals("mytopic", topic0.name) + assertEquals(2, topic0.partitions.size) + val topic0Partition0 = topic0.partitions.get(0) + assertEquals(1, topic0Partition0.leader.id) + assertEquals(0, topic0Partition0.partition) + assertEquals(Seq(1, 2), topic0Partition0.isr.asScala.map(_.id)) + assertEquals(Seq(1, 2), topic0Partition0.replicas.asScala.map(_.id)) + val topic0Partition1 = topic0.partitions.get(1) + 
assertEquals(2, topic0Partition1.leader.id) + assertEquals(1, topic0Partition1.partition) + assertEquals(Seq(2, 0), topic0Partition1.isr.asScala.map(_.id)) + assertEquals(Seq(2, 0), topic0Partition1.replicas.asScala.map(_.id)) + + val topic1 = topicToDescription.get("mytopic2") + assertEquals(false, topic1.isInternal) + assertEquals("mytopic2", topic1.name) + assertEquals(3, topic1.partitions.size) + for (partitionId <- 0 until 3) { + val partition = topic1.partitions.get(partitionId) + assertEquals(partitionId, partition.partition) + assertEquals(3, partition.replicas.size) + partition.replicas.forEach { replica => + assertTrue(replica.id >= 0) + assertTrue(replica.id < brokerCount) + } + assertEquals(partition.replicas.size, partition.replicas.asScala.map(_.id).distinct.size, "No duplicate replica ids") + + assertEquals(3, partition.isr.size) + assertEquals(partition.replicas, partition.isr) + assertTrue(partition.replicas.contains(partition.leader)) + } + + val topic3 = topicToDescription.get("mytopic3") + assertEquals("mytopic3", topic3.name) + assertEquals(configs.head.numPartitions, topic3.partitions.size) + assertEquals(configs.head.defaultReplicationFactor, topic3.partitions.get(0).replicas().size()) + + client.deleteTopics(topics.asJava).all.get() + waitForTopics(client, List(), topics) + } + + @Test + def testAuthorizedOperations(): Unit = { + client = Admin.create(createConfig) + + // without includeAuthorizedOperations flag + var result = client.describeCluster + assertNull(result.authorizedOperations().get()) + + //with includeAuthorizedOperations flag + result = client.describeCluster(new DescribeClusterOptions().includeAuthorizedOperations(true)) + var expectedOperations = configuredClusterPermissions.asJava + assertEquals(expectedOperations, result.authorizedOperations().get()) + + val topic = "mytopic" + val newTopics = Seq(new NewTopic(topic, 3, 3.toShort)) + client.createTopics(newTopics.asJava).all.get() + waitForTopics(client, expectedPresent = Seq(topic), expectedMissing = List()) + + // without includeAuthorizedOperations flag + var topicResult = getTopicMetadata(client, topic) + assertNull(topicResult.authorizedOperations) + + //with includeAuthorizedOperations flag + topicResult = getTopicMetadata(client, topic, new DescribeTopicsOptions().includeAuthorizedOperations(true)) + expectedOperations = AclEntry.supportedOperations(ResourceType.TOPIC).asJava + assertEquals(expectedOperations, topicResult.authorizedOperations) + } + + def configuredClusterPermissions: Set[AclOperation] = + AclEntry.supportedOperations(ResourceType.CLUSTER) + + override def modifyConfigs(configs: Seq[Properties]): Unit = { + super.modifyConfigs(configs) + configs.foreach { config => + config.setProperty(KafkaConfig.DeleteTopicEnableProp, "true") + config.setProperty(KafkaConfig.GroupInitialRebalanceDelayMsProp, "0") + config.setProperty(KafkaConfig.AutoLeaderRebalanceEnableProp, "false") + config.setProperty(KafkaConfig.ControlledShutdownEnableProp, "false") + // We set this in order to test that we don't expose sensitive data via describe configs. This will already be + // set for subclasses with security enabled and we don't want to overwrite it. 
+ if (!config.containsKey(KafkaConfig.SslTruststorePasswordProp)) + config.setProperty(KafkaConfig.SslTruststorePasswordProp, "some.invalid.pass") + } + } + + def createConfig: util.Map[String, Object] = { + val config = new util.HashMap[String, Object] + config.put(AdminClientConfig.BOOTSTRAP_SERVERS_CONFIG, brokerList) + config.put(AdminClientConfig.REQUEST_TIMEOUT_MS_CONFIG, "20000") + val securityProps: util.Map[Object, Object] = + adminClientSecurityConfigs(securityProtocol, trustStoreFile, clientSaslProperties) + securityProps.forEach { (key, value) => config.put(key.asInstanceOf[String], value) } + config + } + + def waitForTopics(client: Admin, expectedPresent: Seq[String], expectedMissing: Seq[String]): Unit = { + waitUntilTrue(() => { + val topics = client.listTopics.names.get() + expectedPresent.forall(topicName => topics.contains(topicName)) && + expectedMissing.forall(topicName => !topics.contains(topicName)) + }, "timed out waiting for topics") + } + + def getTopicMetadata(client: Admin, + topic: String, + describeOptions: DescribeTopicsOptions = new DescribeTopicsOptions, + expectedNumPartitionsOpt: Option[Int] = None): TopicDescription = { + var result: TopicDescription = null + waitUntilTrue(() => { + val topicResult = client.describeTopics(Set(topic).asJava, describeOptions).topicNameValues().get(topic) + try { + result = topicResult.get + expectedNumPartitionsOpt.map(_ == result.partitions.size).getOrElse(true) + } catch { + case e: ExecutionException if e.getCause.isInstanceOf[UnknownTopicOrPartitionException] => false // metadata may not have propagated yet, so retry + } + }, s"Timed out waiting for metadata for $topic") + result + } + +} diff --git a/core/src/test/scala/integration/kafka/api/BaseConsumerTest.scala b/core/src/test/scala/integration/kafka/api/BaseConsumerTest.scala new file mode 100644 index 0000000..fe56040 --- /dev/null +++ b/core/src/test/scala/integration/kafka/api/BaseConsumerTest.scala @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package kafka.api + +import org.apache.kafka.clients.consumer.ConsumerConfig +import org.apache.kafka.common.PartitionInfo +import org.apache.kafka.common.internals.Topic +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.Assertions._ + +import scala.jdk.CollectionConverters._ +import scala.collection.Seq + +/** + * Integration tests for the consumer that cover basic usage as well as coordinator failure + */ +abstract class BaseConsumerTest extends AbstractConsumerTest { + + @Test + def testSimpleConsumption(): Unit = { + val numRecords = 10000 + val producer = createProducer() + val startingTimestamp = System.currentTimeMillis() + sendRecords(producer, numRecords, tp, startingTimestamp = startingTimestamp) + + val consumer = createConsumer() + assertEquals(0, consumer.assignment.size) + consumer.assign(List(tp).asJava) + assertEquals(1, consumer.assignment.size) + + consumer.seek(tp, 0) + consumeAndVerifyRecords(consumer = consumer, numRecords = numRecords, startingOffset = 0, startingTimestamp = startingTimestamp) + + // check async commit callbacks + sendAndAwaitAsyncCommit(consumer) + } + + @Test + def testCoordinatorFailover(): Unit = { + val listener = new TestConsumerReassignmentListener() + this.consumerConfig.setProperty(ConsumerConfig.SESSION_TIMEOUT_MS_CONFIG, "5001") + this.consumerConfig.setProperty(ConsumerConfig.HEARTBEAT_INTERVAL_MS_CONFIG, "1000") + // Use higher poll timeout to avoid consumer leaving the group due to timeout + this.consumerConfig.setProperty(ConsumerConfig.MAX_POLL_INTERVAL_MS_CONFIG, "15000") + val consumer = createConsumer() + + consumer.subscribe(List(topic).asJava, listener) + + // the initial subscription should cause a callback execution + awaitRebalance(consumer, listener) + assertEquals(1, listener.callsToAssigned) + + // get metadata for the topic + var parts: Seq[PartitionInfo] = null + while (parts == null) + parts = consumer.partitionsFor(Topic.GROUP_METADATA_TOPIC_NAME).asScala + assertEquals(1, parts.size) + assertNotNull(parts.head.leader()) + + // shutdown the coordinator + val coordinator = parts.head.leader().id() + this.servers(coordinator).shutdown() + + // the failover should not cause a rebalance + ensureNoRebalance(consumer, listener) + } +} diff --git a/core/src/test/scala/integration/kafka/api/BaseProducerSendTest.scala b/core/src/test/scala/integration/kafka/api/BaseProducerSendTest.scala new file mode 100644 index 0000000..8c2b6da --- /dev/null +++ b/core/src/test/scala/integration/kafka/api/BaseProducerSendTest.scala @@ -0,0 +1,488 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package kafka.api + +import java.time.Duration +import java.nio.charset.StandardCharsets +import java.util.Properties +import java.util.concurrent.TimeUnit + +import kafka.integration.KafkaServerTestHarness +import kafka.log.LogConfig +import kafka.server.KafkaConfig +import kafka.utils.TestUtils +import org.apache.kafka.clients.consumer.KafkaConsumer +import org.apache.kafka.clients.producer._ +import org.apache.kafka.common.errors.TimeoutException +import org.apache.kafka.common.record.TimestampType +import org.apache.kafka.common.security.auth.SecurityProtocol +import org.apache.kafka.common.{KafkaException, TopicPartition} +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.{AfterEach, BeforeEach, Test, TestInfo} + +import scala.jdk.CollectionConverters._ +import scala.collection.mutable.Buffer +import scala.concurrent.ExecutionException + +abstract class BaseProducerSendTest extends KafkaServerTestHarness { + + def generateConfigs = { + val overridingProps = new Properties() + val numServers = 2 + overridingProps.put(KafkaConfig.NumPartitionsProp, 4.toString) + TestUtils.createBrokerConfigs(numServers, zkConnect, false, interBrokerSecurityProtocol = Some(securityProtocol), + trustStoreFile = trustStoreFile, saslProperties = serverSaslProperties).map(KafkaConfig.fromProps(_, overridingProps)) + } + + private var consumer: KafkaConsumer[Array[Byte], Array[Byte]] = _ + private val producers = Buffer[KafkaProducer[Array[Byte], Array[Byte]]]() + + protected val topic = "topic" + private val numRecords = 100 + + @BeforeEach + override def setUp(testInfo: TestInfo): Unit = { + super.setUp(testInfo) + consumer = TestUtils.createConsumer(TestUtils.getBrokerListStrFromServers(servers), securityProtocol = SecurityProtocol.PLAINTEXT) + } + + @AfterEach + override def tearDown(): Unit = { + consumer.close() + // Ensure that all producers are closed since unclosed producers impact other tests when Kafka server ports are reused + producers.foreach(_.close()) + + super.tearDown() + } + + protected def createProducer(brokerList: String, + lingerMs: Int = 0, + deliveryTimeoutMs: Int = 2 * 60 * 1000, + batchSize: Int = 16384, + compressionType: String = "none", + maxBlockMs: Long = 60 * 1000L, + bufferSize: Long = 1024L * 1024L): KafkaProducer[Array[Byte],Array[Byte]] = { + val producer = TestUtils.createProducer(brokerList, + compressionType = compressionType, + securityProtocol = securityProtocol, + trustStoreFile = trustStoreFile, + saslProperties = clientSaslProperties, + lingerMs = lingerMs, + deliveryTimeoutMs = deliveryTimeoutMs, + maxBlockMs = maxBlockMs, + batchSize = batchSize, + bufferSize = bufferSize) + registerProducer(producer) + } + + protected def registerProducer(producer: KafkaProducer[Array[Byte], Array[Byte]]): KafkaProducer[Array[Byte], Array[Byte]] = { + producers += producer + producer + } + + /** + * testSendOffset checks the basic send API behavior + * + * 1. Send with null key/value/partition-id should be accepted; send with null topic should be rejected. + * 2. 
Last message of the non-blocking send should return the correct offset metadata + */ + @Test + def testSendOffset(): Unit = { + val producer = createProducer(brokerList) + val partition = 0 + + object callback extends Callback { + var offset = 0L + + def onCompletion(metadata: RecordMetadata, exception: Exception): Unit = { + if (exception == null) { + assertEquals(offset, metadata.offset()) + assertEquals(topic, metadata.topic()) + assertEquals(partition, metadata.partition()) + offset match { + case 0 => assertEquals(metadata.serializedKeySize + metadata.serializedValueSize, + "key".getBytes(StandardCharsets.UTF_8).length + "value".getBytes(StandardCharsets.UTF_8).length) + case 1 => assertEquals(metadata.serializedKeySize(), "key".getBytes(StandardCharsets.UTF_8).length) + case 2 => assertEquals(metadata.serializedValueSize, "value".getBytes(StandardCharsets.UTF_8).length) + case _ => assertTrue(metadata.serializedValueSize > 0) + } + offset += 1 + } else { + fail(s"Send callback returns the following exception: $exception") + } + } + } + + try { + // create topic + createTopic(topic, 1, 2) + + // send a normal record + val record0 = new ProducerRecord[Array[Byte], Array[Byte]](topic, partition, "key".getBytes(StandardCharsets.UTF_8), + "value".getBytes(StandardCharsets.UTF_8)) + assertEquals(0L, producer.send(record0, callback).get.offset, "Should have offset 0") + + // send a record with null value should be ok + val record1 = new ProducerRecord[Array[Byte], Array[Byte]](topic, partition, "key".getBytes(StandardCharsets.UTF_8), null) + assertEquals(1L, producer.send(record1, callback).get.offset, "Should have offset 1") + + // send a record with null key should be ok + val record2 = new ProducerRecord[Array[Byte], Array[Byte]](topic, partition, null, "value".getBytes(StandardCharsets.UTF_8)) + assertEquals(2L, producer.send(record2, callback).get.offset, "Should have offset 2") + + // send a record with null part id should be ok + val record3 = new ProducerRecord[Array[Byte], Array[Byte]](topic, null, "key".getBytes(StandardCharsets.UTF_8), + "value".getBytes(StandardCharsets.UTF_8)) + assertEquals(3L, producer.send(record3, callback).get.offset, "Should have offset 3") + + // non-blocking send a list of records + for (_ <- 1 to numRecords) + producer.send(record0, callback) + + // check that all messages have been acked via offset + assertEquals(numRecords + 4L, producer.send(record0, callback).get.offset, "Should have offset " + (numRecords + 4)) + + } finally { + producer.close() + } + } + + @Test + def testSendCompressedMessageWithCreateTime(): Unit = { + val producer = createProducer(brokerList = brokerList, + compressionType = "gzip", + lingerMs = Int.MaxValue, + deliveryTimeoutMs = Int.MaxValue) + sendAndVerifyTimestamp(producer, TimestampType.CREATE_TIME) + } + + @Test + def testSendNonCompressedMessageWithCreateTime(): Unit = { + val producer = createProducer(brokerList = brokerList, lingerMs = Int.MaxValue, deliveryTimeoutMs = Int.MaxValue) + sendAndVerifyTimestamp(producer, TimestampType.CREATE_TIME) + } + + protected def sendAndVerify(producer: KafkaProducer[Array[Byte], Array[Byte]], + numRecords: Int = numRecords, + timeoutMs: Long = 20000L): Unit = { + val partition = 0 + try { + createTopic(topic, 1, 2) + + val futures = for (i <- 1 to numRecords) yield { + val record = new ProducerRecord(topic, partition, s"key$i".getBytes(StandardCharsets.UTF_8), + s"value$i".getBytes(StandardCharsets.UTF_8)) + producer.send(record) + } + producer.close(Duration.ofMillis(timeoutMs)) 
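+      // close() blocks (up to timeoutMs) until outstanding sends complete, so every future below should
+      // already be done; the fold verifies the offsets are consecutive, starting at 0.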
+ val lastOffset = futures.foldLeft(0) { (offset, future) => + val recordMetadata = future.get + assertEquals(topic, recordMetadata.topic) + assertEquals(partition, recordMetadata.partition) + assertEquals(offset, recordMetadata.offset) + offset + 1 + } + assertEquals(numRecords, lastOffset) + } finally { + producer.close() + } + } + + protected def sendAndVerifyTimestamp(producer: KafkaProducer[Array[Byte], Array[Byte]], timestampType: TimestampType): Unit = { + val partition = 0 + + val baseTimestamp = 123456L + val startTime = System.currentTimeMillis() + + object callback extends Callback { + var offset = 0L + var timestampDiff = 1L + + def onCompletion(metadata: RecordMetadata, exception: Exception): Unit = { + if (exception == null) { + assertEquals(offset, metadata.offset) + assertEquals(topic, metadata.topic) + if (timestampType == TimestampType.CREATE_TIME) + assertEquals(baseTimestamp + timestampDiff, metadata.timestamp) + else + assertTrue(metadata.timestamp >= startTime && metadata.timestamp <= System.currentTimeMillis()) + assertEquals(partition, metadata.partition) + offset += 1 + timestampDiff += 1 + } else { + fail(s"Send callback returns the following exception: $exception") + } + } + } + + try { + // create topic + val topicProps = new Properties() + if (timestampType == TimestampType.LOG_APPEND_TIME) + topicProps.setProperty(LogConfig.MessageTimestampTypeProp, "LogAppendTime") + else + topicProps.setProperty(LogConfig.MessageTimestampTypeProp, "CreateTime") + createTopic(topic, 1, 2, topicProps) + + val recordAndFutures = for (i <- 1 to numRecords) yield { + val record = new ProducerRecord(topic, partition, baseTimestamp + i, s"key$i".getBytes(StandardCharsets.UTF_8), + s"value$i".getBytes(StandardCharsets.UTF_8)) + (record, producer.send(record, callback)) + } + producer.close(Duration.ofSeconds(20L)) + recordAndFutures.foreach { case (record, future) => + val recordMetadata = future.get + if (timestampType == TimestampType.LOG_APPEND_TIME) + assertTrue(recordMetadata.timestamp >= startTime && recordMetadata.timestamp <= System.currentTimeMillis()) + else + assertEquals(record.timestamp, recordMetadata.timestamp) + } + assertEquals(numRecords, callback.offset, s"Should have offset $numRecords but only successfully sent ${callback.offset}") + } finally { + producer.close() + } + } + + /** + * testClose checks the closing behavior + * + * After close() returns, all messages should be sent with correct returned offset metadata + */ + @Test + def testClose(): Unit = { + val producer = createProducer(brokerList) + + try { + // create topic + createTopic(topic, 1, 2) + + // non-blocking send a list of records + val record0 = new ProducerRecord[Array[Byte], Array[Byte]](topic, null, "key".getBytes(StandardCharsets.UTF_8), + "value".getBytes(StandardCharsets.UTF_8)) + for (_ <- 1 to numRecords) + producer.send(record0) + val response0 = producer.send(record0) + + // close the producer + producer.close() + + // check that all messages have been acked via offset, + // this also checks that messages with same key go to the same partition + assertTrue(response0.isDone, "The last message should be acked before producer is shutdown") + assertEquals(numRecords.toLong, response0.get.offset, "Should have offset " + numRecords) + + } finally { + producer.close() + } + } + + /** + * testSendToPartition checks the partitioning behavior + * + * The specified partition-id should be respected + */ + @Test + def testSendToPartition(): Unit = { + val producer = createProducer(brokerList) + + 
try { + createTopic(topic, 2, 2) + val partition = 1 + + val now = System.currentTimeMillis() + val futures = (1 to numRecords).map { i => + producer.send(new ProducerRecord(topic, partition, now, null, ("value" + i).getBytes(StandardCharsets.UTF_8))) + }.map(_.get(30, TimeUnit.SECONDS)) + + // make sure all of them end up in the same partition with increasing offset values + for ((recordMetadata, offset) <- futures zip (0 until numRecords)) { + assertEquals(offset.toLong, recordMetadata.offset) + assertEquals(topic, recordMetadata.topic) + assertEquals(partition, recordMetadata.partition) + } + + consumer.assign(List(new TopicPartition(topic, partition)).asJava) + + // make sure the fetched messages also respect the partitioning and ordering + val records = TestUtils.consumeRecords(consumer, numRecords) + + records.zipWithIndex.foreach { case (record, i) => + assertEquals(topic, record.topic) + assertEquals(partition, record.partition) + assertEquals(i.toLong, record.offset) + assertNull(record.key) + assertEquals(s"value${i + 1}", new String(record.value)) + assertEquals(now, record.timestamp) + } + + } finally { + producer.close() + } + } + + /** + * Checks partitioning behavior before and after partitions are added + * + * Producer will attempt to send messages to the partition specified in each record, and should + * succeed as long as the partition is included in the metadata. + */ + @Test + def testSendBeforeAndAfterPartitionExpansion(): Unit = { + val producer = createProducer(brokerList, maxBlockMs = 5 * 1000L) + + // create topic + createTopic(topic, 1, 2) + val partition0 = 0 + + var futures0 = (1 to numRecords).map { i => + producer.send(new ProducerRecord(topic, partition0, null, ("value" + i).getBytes(StandardCharsets.UTF_8))) + }.map(_.get(30, TimeUnit.SECONDS)) + + // make sure all of them end up in the same partition with increasing offset values + for ((recordMetadata, offset) <- futures0 zip (0 until numRecords)) { + assertEquals(offset.toLong, recordMetadata.offset) + assertEquals(topic, recordMetadata.topic) + assertEquals(partition0, recordMetadata.partition) + } + + // Trying to send a record to a partition beyond topic's partition range before adding the partition should fail. + val partition1 = 1 + val e = assertThrows(classOf[ExecutionException], () => producer.send(new ProducerRecord(topic, partition1, null, "value".getBytes(StandardCharsets.UTF_8))).get()) + assertEquals(classOf[TimeoutException], e.getCause.getClass) + + val existingAssignment = zkClient.getFullReplicaAssignmentForTopics(Set(topic)).map { + case (topicPartition, assignment) => topicPartition.partition -> assignment + } + adminZkClient.addPartitions(topic, existingAssignment, adminZkClient.getBrokerMetadatas(), 2) + // read metadata from a broker and verify the new topic partitions exist + TestUtils.waitForPartitionMetadata(servers, topic, 0) + TestUtils.waitForPartitionMetadata(servers, topic, 1) + + // send records to the newly added partition after confirming that metadata have been updated. 
+ val futures1 = (1 to numRecords).map { i => + producer.send(new ProducerRecord(topic, partition1, null, ("value" + i).getBytes(StandardCharsets.UTF_8))) + }.map(_.get(30, TimeUnit.SECONDS)) + + // make sure all of them end up in the same partition with increasing offset values + for ((recordMetadata, offset) <- futures1 zip (0 until numRecords)) { + assertEquals(offset.toLong, recordMetadata.offset) + assertEquals(topic, recordMetadata.topic) + assertEquals(partition1, recordMetadata.partition) + } + + futures0 = (1 to numRecords).map { i => + producer.send(new ProducerRecord(topic, partition0, null, ("value" + i).getBytes(StandardCharsets.UTF_8))) + }.map(_.get(30, TimeUnit.SECONDS)) + + // make sure all of them end up in the same partition with increasing offset values starting where previous + for ((recordMetadata, offset) <- futures0 zip (numRecords until 2 * numRecords)) { + assertEquals(offset.toLong, recordMetadata.offset) + assertEquals(topic, recordMetadata.topic) + assertEquals(partition0, recordMetadata.partition) + } + } + + /** + * Test that flush immediately sends all accumulated requests. + */ + @Test + def testFlush(): Unit = { + val producer = createProducer(brokerList, lingerMs = Int.MaxValue, deliveryTimeoutMs = Int.MaxValue) + try { + createTopic(topic, 2, 2) + val record = new ProducerRecord[Array[Byte], Array[Byte]](topic, + "value".getBytes(StandardCharsets.UTF_8)) + for (_ <- 0 until 50) { + val responses = (0 until numRecords) map (_ => producer.send(record)) + assertTrue(responses.forall(!_.isDone()), "No request is complete.") + producer.flush() + assertTrue(responses.forall(_.isDone()), "All requests are complete.") + } + } finally { + producer.close() + } + } + + /** + * Test close with zero timeout from caller thread + */ + @Test + def testCloseWithZeroTimeoutFromCallerThread(): Unit = { + createTopic(topic, 2, 2) + val partition = 0 + consumer.assign(List(new TopicPartition(topic, partition)).asJava) + val record0 = new ProducerRecord[Array[Byte], Array[Byte]](topic, partition, null, + "value".getBytes(StandardCharsets.UTF_8)) + + // Test closing from caller thread. + for (_ <- 0 until 50) { + val producer = createProducer(brokerList, lingerMs = Int.MaxValue, deliveryTimeoutMs = Int.MaxValue) + val responses = (0 until numRecords) map (_ => producer.send(record0)) + assertTrue(responses.forall(!_.isDone()), "No request is complete.") + producer.close(Duration.ZERO) + responses.foreach { future => + val e = assertThrows(classOf[ExecutionException], () => future.get()) + assertEquals(classOf[KafkaException], e.getCause.getClass) + } + assertEquals(0, consumer.poll(Duration.ofMillis(50L)).count, "Fetch response should have no message returned.") + } + } + + /** + * Test close with zero and non-zero timeout from sender thread + */ + @Test + def testCloseWithZeroTimeoutFromSenderThread(): Unit = { + createTopic(topic, 1, 2) + val partition = 0 + consumer.assign(List(new TopicPartition(topic, partition)).asJava) + val record = new ProducerRecord[Array[Byte], Array[Byte]](topic, partition, null, "value".getBytes(StandardCharsets.UTF_8)) + + // Test closing from sender thread. + class CloseCallback(producer: KafkaProducer[Array[Byte], Array[Byte]], sendRecords: Boolean) extends Callback { + override def onCompletion(metadata: RecordMetadata, exception: Exception): Unit = { + // Trigger another batch in accumulator before close the producer. These messages should + // not be sent. 
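        // Editorial note (not part of this patch): callbacks run on the producer's sender/IO thread,
        // and KafkaProducer turns a close() issued from that thread into the equivalent of
        // close(Duration.ZERO) instead of blocking, since the sender cannot wait for itself to finish.
        // That is why the second, untimed close() below is expected not to block.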
+ if (sendRecords) + (0 until numRecords) foreach (_ => producer.send(record)) + // The close call will be called by all the message callbacks. This tests idempotence of the close call. + producer.close(Duration.ZERO) + // Test close with non zero timeout. Should not block at all. + producer.close() + } + } + for (i <- 0 until 50) { + val producer = createProducer(brokerList, lingerMs = Int.MaxValue, deliveryTimeoutMs = Int.MaxValue) + try { + // send message to partition 0 + // Only send the records in the first callback since we close the producer in the callback and no records + // can be sent afterwards. + val responses = (0 until numRecords) map (i => producer.send(record, new CloseCallback(producer, i == 0))) + assertTrue(responses.forall(!_.isDone()), "No request is complete.") + // flush the messages. + producer.flush() + assertTrue(responses.forall(_.isDone()), "All requests are complete.") + // Check the messages received by broker. + TestUtils.pollUntilAtLeastNumRecords(consumer, numRecords) + } finally { + producer.close() + } + } + } + +} diff --git a/core/src/test/scala/integration/kafka/api/BaseQuotaTest.scala b/core/src/test/scala/integration/kafka/api/BaseQuotaTest.scala new file mode 100644 index 0000000..9f73236 --- /dev/null +++ b/core/src/test/scala/integration/kafka/api/BaseQuotaTest.scala @@ -0,0 +1,387 @@ +/** + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ **/ + +package kafka.api + +import java.time.Duration +import java.util.concurrent.TimeUnit +import java.util.{Collections, HashMap, Properties} +import com.yammer.metrics.core.{Histogram, Meter} +import kafka.api.QuotaTestClients._ +import kafka.metrics.KafkaYammerMetrics +import kafka.server.{ClientQuotaManager, ClientQuotaManagerConfig, KafkaConfig, KafkaServer, QuotaType} +import kafka.utils.TestUtils +import org.apache.kafka.clients.admin.Admin +import org.apache.kafka.clients.consumer.{ConsumerConfig, KafkaConsumer} +import org.apache.kafka.clients.producer._ +import org.apache.kafka.clients.producer.internals.ErrorLoggingCallback +import org.apache.kafka.common.config.internals.QuotaConfigs +import org.apache.kafka.common.{Metric, MetricName, TopicPartition} +import org.apache.kafka.common.metrics.{KafkaMetric, Quota} +import org.apache.kafka.common.protocol.ApiKeys +import org.apache.kafka.common.quota.ClientQuotaAlteration +import org.apache.kafka.common.quota.ClientQuotaEntity +import org.apache.kafka.common.security.auth.KafkaPrincipal +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.{BeforeEach, Test, TestInfo} + +import scala.collection.Map +import scala.jdk.CollectionConverters._ + +abstract class BaseQuotaTest extends IntegrationTestHarness { + + override val brokerCount = 2 + + protected def producerClientId = "QuotasTestProducer-1" + protected def consumerClientId = "QuotasTestConsumer-1" + protected def createQuotaTestClients(topic: String, leaderNode: KafkaServer): QuotaTestClients + + this.serverConfig.setProperty(KafkaConfig.ControlledShutdownEnableProp, "false") + this.serverConfig.setProperty(KafkaConfig.OffsetsTopicReplicationFactorProp, "2") + this.serverConfig.setProperty(KafkaConfig.OffsetsTopicPartitionsProp, "1") + this.serverConfig.setProperty(KafkaConfig.GroupMinSessionTimeoutMsProp, "100") + this.serverConfig.setProperty(KafkaConfig.GroupMaxSessionTimeoutMsProp, "60000") + this.serverConfig.setProperty(KafkaConfig.GroupInitialRebalanceDelayMsProp, "0") + this.producerConfig.setProperty(ProducerConfig.ACKS_CONFIG, "-1") + this.producerConfig.setProperty(ProducerConfig.BUFFER_MEMORY_CONFIG, "300000") + this.producerConfig.setProperty(ProducerConfig.CLIENT_ID_CONFIG, producerClientId) + this.consumerConfig.setProperty(ConsumerConfig.GROUP_ID_CONFIG, "QuotasTest") + this.consumerConfig.setProperty(ConsumerConfig.MAX_PARTITION_FETCH_BYTES_CONFIG, 4096.toString) + this.consumerConfig.setProperty(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "earliest") + this.consumerConfig.setProperty(ConsumerConfig.CLIENT_ID_CONFIG, consumerClientId) + this.consumerConfig.setProperty(ConsumerConfig.FETCH_MIN_BYTES_CONFIG, "0") + this.consumerConfig.setProperty(ConsumerConfig.FETCH_MAX_WAIT_MS_CONFIG, "0") + + // Low enough quota that a producer sending a small payload in a tight loop should get throttled + val defaultProducerQuota: Long = 8000 + val defaultConsumerQuota: Long = 2500 + val defaultRequestQuota: Double = Long.MaxValue.toDouble + + val topic1 = "topic-1" + var leaderNode: KafkaServer = _ + var followerNode: KafkaServer = _ + var quotaTestClients: QuotaTestClients = _ + + @BeforeEach + override def setUp(testInfo: TestInfo): Unit = { + super.setUp(testInfo) + + val numPartitions = 1 + val leaders = createTopic(topic1, numPartitions, brokerCount) + leaderNode = if (leaders(0) == servers.head.config.brokerId) servers.head else servers(1) + followerNode = if (leaders(0) != servers.head.config.brokerId) servers.head else servers(1) + quotaTestClients 
= createQuotaTestClients(topic1, leaderNode) + } + + @Test + def testThrottledProducerConsumer(): Unit = { + val numRecords = 1000 + val produced = quotaTestClients.produceUntilThrottled(numRecords) + quotaTestClients.verifyProduceThrottle(expectThrottle = true) + + // Consumer should read in a bursty manner and get throttled immediately + assertTrue(quotaTestClients.consumeUntilThrottled(produced) > 0, "Should have consumed at least one record") + quotaTestClients.verifyConsumeThrottle(expectThrottle = true) + } + + @Test + def testProducerConsumerOverrideUnthrottled(): Unit = { + // Give effectively unlimited quota for producer and consumer + val props = new Properties() + props.put(QuotaConfigs.PRODUCER_BYTE_RATE_OVERRIDE_CONFIG, Long.MaxValue.toString) + props.put(QuotaConfigs.CONSUMER_BYTE_RATE_OVERRIDE_CONFIG, Long.MaxValue.toString) + + quotaTestClients.overrideQuotas(Long.MaxValue, Long.MaxValue, Long.MaxValue.toDouble) + quotaTestClients.waitForQuotaUpdate(Long.MaxValue, Long.MaxValue, Long.MaxValue.toDouble) + + val numRecords = 1000 + assertEquals(numRecords, quotaTestClients.produceUntilThrottled(numRecords)) + quotaTestClients.verifyProduceThrottle(expectThrottle = false) + + // The "client" consumer does not get throttled. + assertEquals(numRecords, quotaTestClients.consumeUntilThrottled(numRecords)) + quotaTestClients.verifyConsumeThrottle(expectThrottle = false) + } + + @Test + def testProducerConsumerOverrideLowerQuota(): Unit = { + // consumer quota is set such that consumer quota * default quota window (10 seconds) is less than + // MAX_PARTITION_FETCH_BYTES_CONFIG, so that we can test consumer ability to fetch in this case + // In this case, 250 * 10 < 4096 + quotaTestClients.overrideQuotas(2000, 250, Long.MaxValue.toDouble) + quotaTestClients.waitForQuotaUpdate(2000, 250, Long.MaxValue.toDouble) + + val numRecords = 1000 + val produced = quotaTestClients.produceUntilThrottled(numRecords) + quotaTestClients.verifyProduceThrottle(expectThrottle = true) + + // Consumer should be able to consume at least one record, even when throttled + assertTrue(quotaTestClients.consumeUntilThrottled(produced) > 0, "Should have consumed at least one record") + quotaTestClients.verifyConsumeThrottle(expectThrottle = true) + } + + @Test + def testQuotaOverrideDelete(): Unit = { + // Override producer and consumer quotas to unlimited + quotaTestClients.overrideQuotas(Long.MaxValue, Long.MaxValue, Long.MaxValue.toDouble) + quotaTestClients.waitForQuotaUpdate(Long.MaxValue, Long.MaxValue, Long.MaxValue.toDouble) + + val numRecords = 1000 + assertEquals(numRecords, quotaTestClients.produceUntilThrottled(numRecords)) + quotaTestClients.verifyProduceThrottle(expectThrottle = false) + assertEquals(numRecords, quotaTestClients.consumeUntilThrottled(numRecords)) + quotaTestClients.verifyConsumeThrottle(expectThrottle = false) + + // Delete producer and consumer quota overrides. 
Consumer and producer should now be + // throttled since broker defaults are very small + quotaTestClients.removeQuotaOverrides() + quotaTestClients.waitForQuotaUpdate(defaultProducerQuota, defaultConsumerQuota, defaultRequestQuota) + val produced = quotaTestClients.produceUntilThrottled(numRecords) + quotaTestClients.verifyProduceThrottle(expectThrottle = true) + + // Since producer may have been throttled after producing a couple of records, + // consume from beginning till throttled + quotaTestClients.consumer.seekToBeginning(Collections.singleton(new TopicPartition(topic1, 0))) + quotaTestClients.consumeUntilThrottled(numRecords + produced) + quotaTestClients.verifyConsumeThrottle(expectThrottle = true) + } + + @Test + def testThrottledRequest(): Unit = { + quotaTestClients.overrideQuotas(Long.MaxValue, Long.MaxValue, 0.1) + quotaTestClients.waitForQuotaUpdate(Long.MaxValue, Long.MaxValue, 0.1) + + val consumer = quotaTestClients.consumer + consumer.subscribe(Collections.singleton(topic1)) + val endTimeMs = System.currentTimeMillis + 10000 + var throttled = false + while ((!throttled || quotaTestClients.exemptRequestMetric == null) && System.currentTimeMillis < endTimeMs) { + consumer.poll(Duration.ofMillis(100L)) + val throttleMetric = quotaTestClients.throttleMetric(QuotaType.Request, consumerClientId) + throttled = throttleMetric != null && metricValue(throttleMetric) > 0 + } + + assertTrue(throttled, "Should have been throttled") + quotaTestClients.verifyConsumerClientThrottleTimeMetric(expectThrottle = true, + Some(ClientQuotaManagerConfig.DefaultQuotaWindowSizeSeconds * 1000.0)) + + val exemptMetric = quotaTestClients.exemptRequestMetric + assertNotNull(exemptMetric, "Exempt requests not recorded") + assertTrue(metricValue(exemptMetric) > 0, "Exempt requests not recorded") + } +} + +object QuotaTestClients { + val DefaultEntity: String = null + + def metricValue(metric: Metric): Double = metric.metricValue().asInstanceOf[Double] +} + +abstract class QuotaTestClients(topic: String, + leaderNode: KafkaServer, + producerClientId: String, + consumerClientId: String, + val producer: KafkaProducer[Array[Byte], Array[Byte]], + val consumer: KafkaConsumer[Array[Byte], Array[Byte]], + val adminClient: Admin) { + + def overrideQuotas(producerQuota: Long, consumerQuota: Long, requestQuota: Double): Unit + def removeQuotaOverrides(): Unit + + protected def userPrincipal: KafkaPrincipal + protected def quotaMetricTags(clientId: String): Map[String, String] + + def produceUntilThrottled(maxRecords: Int, waitForRequestCompletion: Boolean = true): Int = { + var numProduced = 0 + var throttled = false + do { + val payload = numProduced.toString.getBytes + val future = producer.send(new ProducerRecord[Array[Byte], Array[Byte]](topic, null, null, payload), + new ErrorLoggingCallback(topic, null, null, true)) + numProduced += 1 + do { + val metric = throttleMetric(QuotaType.Produce, producerClientId) + throttled = metric != null && metricValue(metric) > 0 + } while (!future.isDone && (!throttled || waitForRequestCompletion)) + } while (numProduced < maxRecords && !throttled) + numProduced + } + + def consumeUntilThrottled(maxRecords: Int, waitForRequestCompletion: Boolean = true): Int = { + val timeoutMs = TimeUnit.MINUTES.toMillis(1) + + consumer.subscribe(Collections.singleton(topic)) + var numConsumed = 0 + var throttled = false + val startMs = System.currentTimeMillis + do { + numConsumed += consumer.poll(Duration.ofMillis(100L)).count + val metric = throttleMetric(QuotaType.Fetch, 
consumerClientId) + throttled = metric != null && metricValue(metric) > 0 + } while (numConsumed < maxRecords && !throttled && System.currentTimeMillis < startMs + timeoutMs) + + // If throttled, wait for the records from the last fetch to be received + if (throttled && numConsumed < maxRecords && waitForRequestCompletion) { + val minRecords = numConsumed + 1 + val startMs = System.currentTimeMillis + while (numConsumed < minRecords && System.currentTimeMillis < startMs + timeoutMs) + numConsumed += consumer.poll(Duration.ofMillis(100L)).count + } + numConsumed + } + + private def quota(quotaManager: ClientQuotaManager, userPrincipal: KafkaPrincipal, clientId: String): Quota = { + quotaManager.quota(userPrincipal, clientId) + } + + private def verifyThrottleTimeRequestChannelMetric(apiKey: ApiKeys, metricNameSuffix: String, + clientId: String, expectThrottle: Boolean): Unit = { + val throttleTimeMs = brokerRequestMetricsThrottleTimeMs(apiKey, metricNameSuffix) + if (expectThrottle) + assertTrue(throttleTimeMs > 0, s"Client with id=$clientId should have been throttled, $throttleTimeMs") + else + assertEquals(0.0, throttleTimeMs, 0.0, s"Client with id=$clientId should not have been throttled") + } + + def verifyProduceThrottle(expectThrottle: Boolean, verifyClientMetric: Boolean = true, + verifyRequestChannelMetric: Boolean = true): Unit = { + verifyThrottleTimeMetric(QuotaType.Produce, producerClientId, expectThrottle) + if (verifyRequestChannelMetric) + verifyThrottleTimeRequestChannelMetric(ApiKeys.PRODUCE, "", producerClientId, expectThrottle) + if (verifyClientMetric) + verifyProducerClientThrottleTimeMetric(expectThrottle) + } + + def verifyConsumeThrottle(expectThrottle: Boolean, verifyClientMetric: Boolean = true, + verifyRequestChannelMetric: Boolean = true): Unit = { + verifyThrottleTimeMetric(QuotaType.Fetch, consumerClientId, expectThrottle) + if (verifyRequestChannelMetric) + verifyThrottleTimeRequestChannelMetric(ApiKeys.FETCH, "Consumer", consumerClientId, expectThrottle) + if (verifyClientMetric) + verifyConsumerClientThrottleTimeMetric(expectThrottle) + } + + private def verifyThrottleTimeMetric(quotaType: QuotaType, clientId: String, expectThrottle: Boolean): Unit = { + val throttleMetricValue = metricValue(throttleMetric(quotaType, clientId)) + if (expectThrottle) { + assertTrue(throttleMetricValue > 0, s"Client with id=$clientId should have been throttled") + } else { + assertTrue(throttleMetricValue.isNaN, s"Client with id=$clientId should not have been throttled") + } + } + + private def throttleMetricName(quotaType: QuotaType, clientId: String): MetricName = { + leaderNode.metrics.metricName("throttle-time", + quotaType.toString, + quotaMetricTags(clientId).asJava) + } + + def throttleMetric(quotaType: QuotaType, clientId: String): KafkaMetric = { + leaderNode.metrics.metrics.get(throttleMetricName(quotaType, clientId)) + } + + private def brokerRequestMetricsThrottleTimeMs(apiKey: ApiKeys, metricNameSuffix: String): Double = { + def yammerMetricValue(name: String): Double = { + val allMetrics = KafkaYammerMetrics.defaultRegistry.allMetrics.asScala + val (_, metric) = allMetrics.find { case (metricName, _) => + metricName.getMBeanName.startsWith(name) + }.getOrElse(fail(s"Unable to find broker metric $name: allMetrics: ${allMetrics.keySet.map(_.getMBeanName)}")) + metric match { + case m: Meter => m.count.toDouble + case m: Histogram => m.max + case m => throw new AssertionError(s"Unexpected broker metric of class ${m.getClass}") + } + } + + 
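    // Editorial note (not part of this patch): the prefix matched below follows the broker's request
    // metrics naming, e.g. a throttled consumer fetch is registered as
    //   kafka.network:type=RequestMetrics,name=ThrottleTimeMs,request=FetchConsumer
    // which is why verifyConsumeThrottle passes "Consumer" as the suffix while produce passes "".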
yammerMetricValue(s"kafka.network:type=RequestMetrics,name=ThrottleTimeMs,request=${apiKey.name}$metricNameSuffix") + } + + def exemptRequestMetric: KafkaMetric = { + val metricName = leaderNode.metrics.metricName("exempt-request-time", QuotaType.Request.toString, "") + leaderNode.metrics.metrics.get(metricName) + } + + private def verifyProducerClientThrottleTimeMetric(expectThrottle: Boolean): Unit = { + val tags = new HashMap[String, String] + tags.put("client-id", producerClientId) + val avgMetric = producer.metrics.get(new MetricName("produce-throttle-time-avg", "producer-metrics", "", tags)) + val maxMetric = producer.metrics.get(new MetricName("produce-throttle-time-max", "producer-metrics", "", tags)) + + if (expectThrottle) { + TestUtils.waitUntilTrue(() => metricValue(avgMetric) > 0.0 && metricValue(maxMetric) > 0.0, + s"Producer throttle metric not updated: avg=${metricValue(avgMetric)} max=${metricValue(maxMetric)}") + } else + assertEquals(0.0, metricValue(maxMetric), 0.0, "Should not have been throttled") + } + + def verifyConsumerClientThrottleTimeMetric(expectThrottle: Boolean, maxThrottleTime: Option[Double] = None): Unit = { + val tags = new HashMap[String, String] + tags.put("client-id", consumerClientId) + val avgMetric = consumer.metrics.get(new MetricName("fetch-throttle-time-avg", "consumer-fetch-manager-metrics", "", tags)) + val maxMetric = consumer.metrics.get(new MetricName("fetch-throttle-time-max", "consumer-fetch-manager-metrics", "", tags)) + + if (expectThrottle) { + TestUtils.waitUntilTrue(() => metricValue(avgMetric) > 0.0 && metricValue(maxMetric) > 0.0, + s"Consumer throttle metric not updated: avg=${metricValue(avgMetric)} max=${metricValue(maxMetric)}") + maxThrottleTime.foreach(max => assertTrue(metricValue(maxMetric) <= max, + s"Maximum consumer throttle too high: ${metricValue(maxMetric)}")) + } else + assertEquals(0.0, metricValue(maxMetric), 0.0, "Should not have been throttled") + } + + def clientQuotaEntity(user: Option[String], clientId: Option[String]): ClientQuotaEntity = { + var entries = Map.empty[String, String] + user.foreach(user => entries = entries ++ Map(ClientQuotaEntity.USER -> user)) + clientId.foreach(clientId => entries = entries ++ Map(ClientQuotaEntity.CLIENT_ID -> clientId)) + new ClientQuotaEntity(entries.asJava) + } + + // None is translated to `null` which remove the quota + def clientQuotaAlteration(quotaEntity: ClientQuotaEntity, + producerQuota: Option[Long], + consumerQuota: Option[Long], + requestQuota: Option[Double]): ClientQuotaAlteration = { + var ops = Seq.empty[ClientQuotaAlteration.Op] + def addOp(key: String, value: Option[Double]): Unit = { + ops = ops ++ Seq(new ClientQuotaAlteration.Op(key, value.map(Double.box).orNull)) + } + addOp(QuotaConfigs.PRODUCER_BYTE_RATE_OVERRIDE_CONFIG, producerQuota.map(_.toDouble)) + addOp(QuotaConfigs.CONSUMER_BYTE_RATE_OVERRIDE_CONFIG, consumerQuota.map(_.toDouble)) + addOp(QuotaConfigs.REQUEST_PERCENTAGE_OVERRIDE_CONFIG, requestQuota) + new ClientQuotaAlteration(quotaEntity, ops.asJava) + } + + def alterClientQuotas(quotaAlterations: ClientQuotaAlteration *): Unit = { + adminClient.alterClientQuotas(quotaAlterations.asJava).all().get() + } + + def waitForQuotaUpdate(producerQuota: Long, consumerQuota: Long, requestQuota: Double, server: KafkaServer = leaderNode): Unit = { + TestUtils.retry(10000) { + val quotaManagers = server.dataPlaneRequestProcessor.quotas + val overrideProducerQuota = quota(quotaManagers.produce, userPrincipal, producerClientId) + val 
overrideConsumerQuota = quota(quotaManagers.fetch, userPrincipal, consumerClientId) + val overrideProducerRequestQuota = quota(quotaManagers.request, userPrincipal, producerClientId) + val overrideConsumerRequestQuota = quota(quotaManagers.request, userPrincipal, consumerClientId) + + assertEquals(Quota.upperBound(producerQuota.toDouble), overrideProducerQuota, + s"ClientId $producerClientId of user $userPrincipal must have producer quota") + assertEquals(Quota.upperBound(consumerQuota.toDouble), overrideConsumerQuota, + s"ClientId $consumerClientId of user $userPrincipal must have consumer quota") + assertEquals(Quota.upperBound(requestQuota.toDouble), overrideProducerRequestQuota, + s"ClientId $producerClientId of user $userPrincipal must have request quota") + assertEquals(Quota.upperBound(requestQuota.toDouble), overrideConsumerRequestQuota, + s"ClientId $consumerClientId of user $userPrincipal must have request quota") + } + } +} diff --git a/core/src/test/scala/integration/kafka/api/ClientIdQuotaTest.scala b/core/src/test/scala/integration/kafka/api/ClientIdQuotaTest.scala new file mode 100644 index 0000000..e4cebe2 --- /dev/null +++ b/core/src/test/scala/integration/kafka/api/ClientIdQuotaTest.scala @@ -0,0 +1,77 @@ +/** + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ **/ + +package kafka.api + +import kafka.server.KafkaServer +import org.apache.kafka.common.security.auth.KafkaPrincipal +import org.junit.jupiter.api.{BeforeEach, TestInfo} + +class ClientIdQuotaTest extends BaseQuotaTest { + + override def producerClientId = "QuotasTestProducer-!@#$%^&*()" + override def consumerClientId = "QuotasTestConsumer-!@#$%^&*()" + + @BeforeEach + override def setUp(testInfo: TestInfo): Unit = { + super.setUp(testInfo) + quotaTestClients.alterClientQuotas( + quotaTestClients.clientQuotaAlteration( + quotaTestClients.clientQuotaEntity(None, Some(QuotaTestClients.DefaultEntity)), + Some(defaultProducerQuota), Some(defaultConsumerQuota), Some(defaultRequestQuota) + ) + ) + quotaTestClients.waitForQuotaUpdate(defaultProducerQuota, defaultConsumerQuota, defaultRequestQuota) + } + + override def createQuotaTestClients(topic: String, leaderNode: KafkaServer): QuotaTestClients = { + val producer = createProducer() + val consumer = createConsumer() + val adminClient = createAdminClient() + + new QuotaTestClients(topic, leaderNode, producerClientId, consumerClientId, producer, consumer, adminClient) { + override def userPrincipal: KafkaPrincipal = KafkaPrincipal.ANONYMOUS + + override def quotaMetricTags(clientId: String): Map[String, String] = { + Map("user" -> "", "client-id" -> clientId) + } + + override def overrideQuotas(producerQuota: Long, consumerQuota: Long, requestQuota: Double): Unit = { + alterClientQuotas( + clientQuotaAlteration( + clientQuotaEntity(None, Some(producerClientId)), + Some(producerQuota), None, Some(requestQuota) + ), + clientQuotaAlteration( + clientQuotaEntity(None, Some(consumerClientId)), + None, Some(consumerQuota), Some(requestQuota) + ) + ) + } + + override def removeQuotaOverrides(): Unit = { + alterClientQuotas( + clientQuotaAlteration( + clientQuotaEntity(None, Some(producerClientId)), + None, None, None + ), + clientQuotaAlteration( + clientQuotaEntity(None, Some(consumerClientId)), + None, None, None + ) + ) + } + } + } +} diff --git a/core/src/test/scala/integration/kafka/api/ConsumerBounceTest.scala b/core/src/test/scala/integration/kafka/api/ConsumerBounceTest.scala new file mode 100644 index 0000000..9fc8727 --- /dev/null +++ b/core/src/test/scala/integration/kafka/api/ConsumerBounceTest.scala @@ -0,0 +1,540 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more contributor license agreements. See the NOTICE + * file distributed with this work for additional information regarding copyright ownership. The ASF licenses this file + * to You under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the + * License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. 
+ */ + +package kafka.api + +import java.time +import java.util.concurrent._ +import java.util.{Collection, Collections, Properties} + +import kafka.server.KafkaConfig +import kafka.utils.{Logging, ShutdownableThread, TestUtils} +import org.apache.kafka.clients.consumer._ +import org.apache.kafka.clients.producer.{KafkaProducer, ProducerRecord} +import org.apache.kafka.common.TopicPartition +import org.apache.kafka.common.errors.GroupMaxSizeReachedException +import org.apache.kafka.common.message.FindCoordinatorRequestData +import org.apache.kafka.common.protocol.Errors +import org.apache.kafka.common.requests.{FindCoordinatorRequest, FindCoordinatorResponse} +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.{AfterEach, Disabled, Test} + +import scala.annotation.nowarn +import scala.jdk.CollectionConverters._ +import scala.collection.{Seq, mutable} + +/** + * Integration tests for the consumer that cover basic usage as well as server failures + */ +class ConsumerBounceTest extends AbstractConsumerTest with Logging { + val maxGroupSize = 5 + + // Time to process commit and leave group requests in tests when brokers are available + val gracefulCloseTimeMs = Some(1000L) + val executor: ScheduledExecutorService = Executors.newScheduledThreadPool(2) + val consumerPollers: mutable.Buffer[ConsumerAssignmentPoller] = mutable.Buffer[ConsumerAssignmentPoller]() + + this.consumerConfig.setProperty(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG, "true") + + override def generateConfigs: Seq[KafkaConfig] = { + generateKafkaConfigs() + } + + private def generateKafkaConfigs(maxGroupSize: String = maxGroupSize.toString): Seq[KafkaConfig] = { + val properties = new Properties + properties.put(KafkaConfig.OffsetsTopicReplicationFactorProp, "3") // don't want to lose offset + properties.put(KafkaConfig.OffsetsTopicPartitionsProp, "1") + properties.put(KafkaConfig.GroupMinSessionTimeoutMsProp, "10") // set small enough session timeout + properties.put(KafkaConfig.GroupInitialRebalanceDelayMsProp, "0") + properties.put(KafkaConfig.GroupMaxSizeProp, maxGroupSize) + properties.put(KafkaConfig.UncleanLeaderElectionEnableProp, "true") + properties.put(KafkaConfig.AutoCreateTopicsEnableProp, "false") + + FixedPortTestUtils.createBrokerConfigs(brokerCount, zkConnect, enableControlledShutdown = false) + .map(KafkaConfig.fromProps(_, properties)) + } + + @AfterEach + override def tearDown(): Unit = { + try { + consumerPollers.foreach(_.shutdown()) + executor.shutdownNow() + // Wait for any active tasks to terminate to ensure consumer is not closed while being used from another thread + assertTrue(executor.awaitTermination(5000, TimeUnit.MILLISECONDS), "Executor did not terminate") + } finally { + super.tearDown() + } + } + + @Test + @Disabled // To be re-enabled once we can make it less flaky (KAFKA-4801) + def testConsumptionWithBrokerFailures(): Unit = consumeWithBrokerFailures(10) + + /* + * 1. Produce a bunch of messages + * 2. 
Then consume the messages while killing and restarting brokers at random + */ + @nowarn("cat=deprecation") + def consumeWithBrokerFailures(numIters: Int): Unit = { + val numRecords = 1000 + val producer = createProducer() + producerSend(producer, numRecords) + + var consumed = 0L + val consumer = createConsumer() + + consumer.subscribe(Collections.singletonList(topic)) + + val scheduler = new BounceBrokerScheduler(numIters) + scheduler.start() + + while (scheduler.isRunning) { + val records = consumer.poll(100).asScala + assertEquals(Set(tp), consumer.assignment.asScala) + + for (record <- records) { + assertEquals(consumed, record.offset()) + consumed += 1 + } + + if (records.nonEmpty) { + consumer.commitSync() + assertEquals(consumer.position(tp), consumer.committed(Set(tp).asJava).get(tp).offset) + + if (consumer.position(tp) == numRecords) { + consumer.seekToBeginning(Collections.emptyList()) + consumed = 0 + } + } + } + scheduler.shutdown() + } + + @Test + def testSeekAndCommitWithBrokerFailures(): Unit = seekAndCommitWithBrokerFailures(5) + + def seekAndCommitWithBrokerFailures(numIters: Int): Unit = { + val numRecords = 1000 + val producer = createProducer() + producerSend(producer, numRecords) + + val consumer = createConsumer() + consumer.assign(Collections.singletonList(tp)) + consumer.seek(tp, 0) + + // wait until all the followers have synced the last HW with leader + TestUtils.waitUntilTrue(() => servers.forall(server => + server.replicaManager.localLog(tp).get.highWatermark == numRecords + ), "Failed to update high watermark for followers after timeout") + + val scheduler = new BounceBrokerScheduler(numIters) + scheduler.start() + + while(scheduler.isRunning) { + val coin = TestUtils.random.nextInt(3) + if (coin == 0) { + info("Seeking to end of log") + consumer.seekToEnd(Collections.emptyList()) + assertEquals(numRecords.toLong, consumer.position(tp)) + } else if (coin == 1) { + val pos = TestUtils.random.nextInt(numRecords).toLong + info("Seeking to " + pos) + consumer.seek(tp, pos) + assertEquals(pos, consumer.position(tp)) + } else if (coin == 2) { + info("Committing offset.") + consumer.commitSync() + assertEquals(consumer.position(tp), consumer.committed(Set(tp).asJava).get(tp).offset) + } + } + } + + @Test + def testSubscribeWhenTopicUnavailable(): Unit = { + val numRecords = 1000 + val newtopic = "newtopic" + + val consumer = createConsumer() + consumer.subscribe(Collections.singleton(newtopic)) + executor.schedule(new Runnable { + def run() = createTopic(newtopic, numPartitions = brokerCount, replicationFactor = brokerCount) + }, 2, TimeUnit.SECONDS) + consumer.poll(time.Duration.ZERO) + + val producer = createProducer() + + def sendRecords(numRecords: Int, topic: String): Unit = { + var remainingRecords = numRecords + val endTimeMs = System.currentTimeMillis + 20000 + while (remainingRecords > 0 && System.currentTimeMillis < endTimeMs) { + val futures = (0 until remainingRecords).map { i => + producer.send(new ProducerRecord(topic, part, i.toString.getBytes, i.toString.getBytes)) + } + futures.map { future => + try { + future.get + remainingRecords -= 1 + } catch { + case _: Exception => + } + } + } + assertEquals(0, remainingRecords) + } + + val poller = new ConsumerAssignmentPoller(consumer, List(newtopic)) + consumerPollers += poller + poller.start() + sendRecords(numRecords, newtopic) + receiveExactRecords(poller, numRecords, 10000) + poller.shutdown() + + servers.foreach(server => killBroker(server.config.brokerId)) + Thread.sleep(500) + restartDeadBrokers() + 
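    // Editorial note (not part of this patch): the fixed 500 ms sleep above is a heuristic; a
    // condition-based wait, e.g. reusing the helper already used elsewhere in these tests,
    //   TestUtils.waitForPartitionMetadata(servers, newtopic, 0)
    // would make the restart hand-off deterministic. The original behaviour is left unchanged here.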
+ val poller2 = new ConsumerAssignmentPoller(consumer, List(newtopic)) + consumerPollers += poller2 + poller2.start() + sendRecords(numRecords, newtopic) + receiveExactRecords(poller, numRecords, 10000L) + } + + @Test + def testClose(): Unit = { + val numRecords = 10 + val producer = createProducer() + producerSend(producer, numRecords) + + checkCloseGoodPath(numRecords, "group1") + checkCloseWithCoordinatorFailure(numRecords, "group2", "group3") + checkCloseWithClusterFailure(numRecords, "group4", "group5") + } + + /** + * Consumer is closed while cluster is healthy. Consumer should complete pending offset commits + * and leave group. New consumer instance should be able join group and start consuming from + * last committed offset. + */ + private def checkCloseGoodPath(numRecords: Int, groupId: String): Unit = { + val consumer = createConsumerAndReceive(groupId, false, numRecords) + val future = submitCloseAndValidate(consumer, Long.MaxValue, None, gracefulCloseTimeMs) + future.get + checkClosedState(groupId, numRecords) + } + + /** + * Consumer closed while coordinator is unavailable. Close of consumers using group + * management should complete after commit attempt even though commits fail due to rebalance. + * Close of consumers using manual assignment should complete with successful commits since a + * broker is available. + */ + private def checkCloseWithCoordinatorFailure(numRecords: Int, dynamicGroup: String, manualGroup: String): Unit = { + val consumer1 = createConsumerAndReceive(dynamicGroup, false, numRecords) + val consumer2 = createConsumerAndReceive(manualGroup, true, numRecords) + + killBroker(findCoordinator(dynamicGroup)) + killBroker(findCoordinator(manualGroup)) + + val future1 = submitCloseAndValidate(consumer1, Long.MaxValue, None, gracefulCloseTimeMs) + + val future2 = submitCloseAndValidate(consumer2, Long.MaxValue, None, gracefulCloseTimeMs) + + future1.get + future2.get + + restartDeadBrokers() + checkClosedState(dynamicGroup, 0) + checkClosedState(manualGroup, numRecords) + } + + private def findCoordinator(group: String): Int = { + val request = new FindCoordinatorRequest.Builder(new FindCoordinatorRequestData() + .setKeyType(FindCoordinatorRequest.CoordinatorType.GROUP.id) + .setCoordinatorKeys(Collections.singletonList(group))).build() + var nodeId = -1 + TestUtils.waitUntilTrue(() => { + val response = connectAndReceive[FindCoordinatorResponse](request) + nodeId = response.node.id + response.error == Errors.NONE + }, s"Failed to find coordinator for group $group") + nodeId + } + + /** + * Consumer is closed while all brokers are unavailable. Cannot rebalance or commit offsets since + * there is no coordinator, but close should timeout and return. If close is invoked with a very + * large timeout, close should timeout after request timeout. 
+ */ + private def checkCloseWithClusterFailure(numRecords: Int, group1: String, group2: String): Unit = { + val consumer1 = createConsumerAndReceive(group1, false, numRecords) + + val requestTimeout = 6000 + this.consumerConfig.setProperty(ConsumerConfig.SESSION_TIMEOUT_MS_CONFIG, "5000") + this.consumerConfig.setProperty(ConsumerConfig.HEARTBEAT_INTERVAL_MS_CONFIG, "1000") + this.consumerConfig.setProperty(ConsumerConfig.REQUEST_TIMEOUT_MS_CONFIG, requestTimeout.toString) + val consumer2 = createConsumerAndReceive(group2, true, numRecords) + + servers.foreach(server => killBroker(server.config.brokerId)) + val closeTimeout = 2000 + val future1 = submitCloseAndValidate(consumer1, closeTimeout, None, Some(closeTimeout)) + val future2 = submitCloseAndValidate(consumer2, Long.MaxValue, Some(requestTimeout), Some(requestTimeout)) + future1.get + future2.get + } + + /** + * If we have a running consumer group of size N, configure consumer.group.max.size = N-1 and restart all brokers, + * the group should be forced to rebalance when it becomes hosted on a Coordinator with the new config. + * Then, 1 consumer should be left out of the group. + */ + @Test + @Disabled // To be re-enabled once we fix KAFKA-13421 + def testRollingBrokerRestartsWithSmallerMaxGroupSizeConfigDisruptsBigGroup(): Unit = { + val group = "group-max-size-test" + val topic = "group-max-size-test" + val maxGroupSize = 2 + val consumerCount = maxGroupSize + 1 + val partitionCount = consumerCount * 2 + + this.consumerConfig.setProperty(ConsumerConfig.MAX_POLL_INTERVAL_MS_CONFIG, "60000") + this.consumerConfig.setProperty(ConsumerConfig.HEARTBEAT_INTERVAL_MS_CONFIG, "1000") + this.consumerConfig.setProperty(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG, "false") + val partitions = createTopicPartitions(topic, numPartitions = partitionCount, replicationFactor = brokerCount) + + addConsumersToGroupAndWaitForGroupAssignment(consumerCount, mutable.Buffer[KafkaConsumer[Array[Byte], Array[Byte]]](), + consumerPollers, List[String](topic), partitions, group) + + // roll all brokers with a lesser max group size to make sure coordinator has the new config + val newConfigs = generateKafkaConfigs(maxGroupSize.toString) + for (serverIdx <- servers.indices) { + killBroker(serverIdx) + val config = newConfigs(serverIdx) + servers(serverIdx) = TestUtils.createServer(config, time = brokerTime(config.brokerId)) + restartDeadBrokers() + } + + def raisedExceptions: Seq[Throwable] = { + consumerPollers.flatten(_.thrownException) + } + + // we are waiting for the group to rebalance and one member to get kicked + TestUtils.waitUntilTrue(() => raisedExceptions.nonEmpty, + msg = "The remaining consumers in the group could not fetch the expected records", 10000L) + + assertEquals(1, raisedExceptions.size) + assertTrue(raisedExceptions.head.isInstanceOf[GroupMaxSizeReachedException]) + } + + /** + * When we have the consumer group max size configured to X, the X+1th consumer trying to join should receive a fatal exception + */ + @Test + def testConsumerReceivesFatalExceptionWhenGroupPassesMaxSize(): Unit = { + val group = "fatal-exception-test" + val topic = "fatal-exception-test" + this.consumerConfig.setProperty(ConsumerConfig.MAX_POLL_INTERVAL_MS_CONFIG, "60000") + this.consumerConfig.setProperty(ConsumerConfig.HEARTBEAT_INTERVAL_MS_CONFIG, "1000") + this.consumerConfig.setProperty(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG, "false") + + val partitions = createTopicPartitions(topic, numPartitions = maxGroupSize, replicationFactor = brokerCount) + + // Create 
N+1 consumers in the same consumer group and assert that the N+1th consumer receives a fatal error when it tries to join the group + addConsumersToGroupAndWaitForGroupAssignment(maxGroupSize, mutable.Buffer[KafkaConsumer[Array[Byte], Array[Byte]]](), + consumerPollers, List[String](topic), partitions, group) + val (_, rejectedConsumerPollers) = addConsumersToGroup(1, + mutable.Buffer[KafkaConsumer[Array[Byte], Array[Byte]]](), mutable.Buffer[ConsumerAssignmentPoller](), List[String](topic), partitions, group) + val rejectedConsumer = rejectedConsumerPollers.head + TestUtils.waitUntilTrue(() => { + rejectedConsumer.thrownException.isDefined + }, "Extra consumer did not throw an exception") + assertTrue(rejectedConsumer.thrownException.get.isInstanceOf[GroupMaxSizeReachedException]) + + // assert group continues to live + producerSend(createProducer(), maxGroupSize * 100, topic, numPartitions = Some(partitions.size)) + TestUtils.waitUntilTrue(() => { + consumerPollers.forall(p => p.receivedMessages >= 100) + }, "The consumers in the group could not fetch the expected records", 10000L) + } + + /** + * Consumer is closed during rebalance. Close should leave group and close + * immediately if rebalance is in progress. If brokers are not available, + * close should terminate immediately without sending leave group. + */ + @Test + def testCloseDuringRebalance(): Unit = { + val topic = "closetest" + createTopic(topic, 10, brokerCount) + this.consumerConfig.setProperty(ConsumerConfig.MAX_POLL_INTERVAL_MS_CONFIG, "60000") + this.consumerConfig.setProperty(ConsumerConfig.HEARTBEAT_INTERVAL_MS_CONFIG, "1000") + this.consumerConfig.setProperty(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG, "false") + checkCloseDuringRebalance("group1", topic, executor, true) + } + + @nowarn("cat=deprecation") + private def checkCloseDuringRebalance(groupId: String, topic: String, executor: ExecutorService, brokersAvailableDuringClose: Boolean): Unit = { + + def subscribeAndPoll(consumer: KafkaConsumer[Array[Byte], Array[Byte]], revokeSemaphore: Option[Semaphore] = None): Future[Any] = { + executor.submit(() => { + consumer.subscribe(Collections.singletonList(topic)) + revokeSemaphore.foreach(s => s.release()) + // requires to used deprecated `poll(long)` to trigger metadata update + consumer.poll(0L) + }, 0) + } + + def waitForRebalance(timeoutMs: Long, future: Future[Any], otherConsumers: KafkaConsumer[Array[Byte], Array[Byte]]*): Unit = { + val startMs = System.currentTimeMillis + while (System.currentTimeMillis < startMs + timeoutMs && !future.isDone) + otherConsumers.foreach(consumer => consumer.poll(time.Duration.ofMillis(100L))) + assertTrue(future.isDone, "Rebalance did not complete in time") + } + + def createConsumerToRebalance(): Future[Any] = { + val consumer = createConsumerWithGroupId(groupId) + val rebalanceSemaphore = new Semaphore(0) + val future = subscribeAndPoll(consumer, Some(rebalanceSemaphore)) + // Wait for consumer to poll and trigger rebalance + assertTrue(rebalanceSemaphore.tryAcquire(2000, TimeUnit.MILLISECONDS), "Rebalance not triggered") + // Rebalance is blocked by other consumers not polling + assertFalse(future.isDone, "Rebalance completed too early") + future + } + val consumer1 = createConsumerWithGroupId(groupId) + waitForRebalance(2000, subscribeAndPoll(consumer1)) + val consumer2 = createConsumerWithGroupId(groupId) + waitForRebalance(2000, subscribeAndPoll(consumer2), consumer1) + val rebalanceFuture = createConsumerToRebalance() + + // consumer1 should leave group and close immediately 
even though rebalance is in progress + val closeFuture1 = submitCloseAndValidate(consumer1, Long.MaxValue, None, gracefulCloseTimeMs) + + // Rebalance should complete without waiting for consumer1 to timeout since consumer1 has left the group + waitForRebalance(2000, rebalanceFuture, consumer2) + + // Trigger another rebalance and shutdown all brokers + // This consumer poll() doesn't complete and `tearDown` shuts down the executor and closes the consumer + createConsumerToRebalance() + servers.foreach(server => killBroker(server.config.brokerId)) + + // consumer2 should close immediately without LeaveGroup request since there are no brokers available + val closeFuture2 = submitCloseAndValidate(consumer2, Long.MaxValue, None, Some(0)) + + // Ensure futures complete to avoid concurrent shutdown attempt during test cleanup + closeFuture1.get(2000, TimeUnit.MILLISECONDS) + closeFuture2.get(2000, TimeUnit.MILLISECONDS) + } + + private def createConsumerAndReceive(groupId: String, manualAssign: Boolean, numRecords: Int): KafkaConsumer[Array[Byte], Array[Byte]] = { + val consumer = createConsumerWithGroupId(groupId) + val consumerPoller = if (manualAssign) + subscribeConsumerAndStartPolling(consumer, List(), Set(tp)) + else + subscribeConsumerAndStartPolling(consumer, List(topic)) + + receiveExactRecords(consumerPoller, numRecords) + consumerPoller.shutdown() + consumer + } + + private def receiveExactRecords(consumer: ConsumerAssignmentPoller, numRecords: Int, timeoutMs: Long = 60000): Unit = { + TestUtils.waitUntilTrue(() => { + consumer.receivedMessages == numRecords + }, s"Consumer did not receive expected $numRecords. It received ${consumer.receivedMessages}", timeoutMs) + } + + private def submitCloseAndValidate(consumer: KafkaConsumer[Array[Byte], Array[Byte]], + closeTimeoutMs: Long, minCloseTimeMs: Option[Long], maxCloseTimeMs: Option[Long]): Future[Any] = { + executor.submit(() => { + val closeGraceTimeMs = 2000 + val startMs = System.currentTimeMillis() + info("Closing consumer with timeout " + closeTimeoutMs + " ms.") + consumer.close(time.Duration.ofMillis(closeTimeoutMs)) + val timeTakenMs = System.currentTimeMillis() - startMs + maxCloseTimeMs.foreach { ms => + assertTrue(timeTakenMs < ms + closeGraceTimeMs, "Close took too long " + timeTakenMs) + } + minCloseTimeMs.foreach { ms => + assertTrue(timeTakenMs >= ms, "Close finished too quickly " + timeTakenMs) + } + info("consumer.close() completed in " + timeTakenMs + " ms.") + }, 0) + } + + private def checkClosedState(groupId: String, committedRecords: Int): Unit = { + // Check that close was graceful with offsets committed and leave group sent. + // New instance of consumer should be assigned partitions immediately and should see committed offsets. 
+ val assignSemaphore = new Semaphore(0) + val consumer = createConsumerWithGroupId(groupId) + consumer.subscribe(Collections.singletonList(topic), new ConsumerRebalanceListener { + def onPartitionsAssigned(partitions: Collection[TopicPartition]): Unit = { + assignSemaphore.release() + } + def onPartitionsRevoked(partitions: Collection[TopicPartition]): Unit = { + }}) + + TestUtils.waitUntilTrue(() => { + consumer.poll(time.Duration.ofMillis(100L)) + assignSemaphore.tryAcquire() + }, "Assignment did not complete on time") + + if (committedRecords > 0) + assertEquals(committedRecords, consumer.committed(Set(tp).asJava).get(tp).offset) + consumer.close() + } + + private class BounceBrokerScheduler(val numIters: Int) extends ShutdownableThread("daemon-bounce-broker", false) { + var iter: Int = 0 + + override def doWork(): Unit = { + killRandomBroker() + Thread.sleep(500) + restartDeadBrokers() + + iter += 1 + if (iter == numIters) + initiateShutdown() + else + Thread.sleep(500) + } + } + + private def createTopicPartitions(topic: String, numPartitions: Int, replicationFactor: Int, + topicConfig: Properties = new Properties): Set[TopicPartition] = { + createTopic(topic, numPartitions = numPartitions, replicationFactor = replicationFactor, topicConfig = topicConfig) + Range(0, numPartitions).map(part => new TopicPartition(topic, part)).toSet + } + + private def producerSend(producer: KafkaProducer[Array[Byte], Array[Byte]], + numRecords: Int, + topic: String = this.topic, + numPartitions: Option[Int] = None): Unit = { + var partitionIndex = 0 + def getPartition: Int = { + numPartitions match { + case Some(partitions) => + val nextPart = partitionIndex % partitions + partitionIndex += 1 + nextPart + case None => part + } + } + + val futures = (0 until numRecords).map { i => + producer.send(new ProducerRecord(topic, getPartition, i.toString.getBytes, i.toString.getBytes)) + } + futures.map(_.get) + } + +} diff --git a/core/src/test/scala/integration/kafka/api/ConsumerTopicCreationTest.scala b/core/src/test/scala/integration/kafka/api/ConsumerTopicCreationTest.scala new file mode 100644 index 0000000..e0853e4 --- /dev/null +++ b/core/src/test/scala/integration/kafka/api/ConsumerTopicCreationTest.scala @@ -0,0 +1,122 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package kafka.api + +import java.lang.{Boolean => JBoolean} +import java.time.Duration +import java.util +import java.util.Collections + +import kafka.api +import kafka.server.KafkaConfig +import kafka.utils.{EmptyTestInfo, TestUtils} +import org.apache.kafka.clients.admin.NewTopic +import org.apache.kafka.clients.consumer.ConsumerConfig +import org.apache.kafka.clients.producer.{ProducerConfig, ProducerRecord} +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.params.ParameterizedTest +import org.junit.jupiter.params.provider.{Arguments, MethodSource} + +/** + * Tests behavior of specifying auto topic creation configuration for the consumer and broker + */ +class ConsumerTopicCreationTest { + + @ParameterizedTest + @MethodSource(Array("parameters")) + def testAutoTopicCreation(brokerAutoTopicCreationEnable: JBoolean, consumerAllowAutoCreateTopics: JBoolean): Unit = { + val testCase = new ConsumerTopicCreationTest.TestCase(brokerAutoTopicCreationEnable, consumerAllowAutoCreateTopics) + testCase.setUp(new EmptyTestInfo()) + try testCase.test() finally testCase.tearDown() + } + + @ParameterizedTest + @MethodSource(Array("parameters")) + def testAutoTopicCreationWithForwarding(brokerAutoTopicCreationEnable: JBoolean, consumerAllowAutoCreateTopics: JBoolean): Unit = { + val testCase = new api.ConsumerTopicCreationTest.TestCaseWithForwarding(brokerAutoTopicCreationEnable, consumerAllowAutoCreateTopics) + testCase.setUp(new EmptyTestInfo()) + try testCase.test() finally testCase.tearDown() + } +} + +object ConsumerTopicCreationTest { + + private class TestCaseWithForwarding(brokerAutoTopicCreationEnable: JBoolean, consumerAllowAutoCreateTopics: JBoolean) + extends TestCase(brokerAutoTopicCreationEnable, consumerAllowAutoCreateTopics) { + + override protected def brokerCount: Int = 3 + + override def enableForwarding: Boolean = true + } + + private class TestCase(brokerAutoTopicCreationEnable: JBoolean, consumerAllowAutoCreateTopics: JBoolean) extends IntegrationTestHarness { + private val topic_1 = "topic-1" + private val topic_2 = "topic-2" + private val producerClientId = "ConsumerTestProducer" + private val consumerClientId = "ConsumerTestConsumer" + + // configure server properties + this.serverConfig.setProperty(KafkaConfig.ControlledShutdownEnableProp, "false") // speed up shutdown + this.serverConfig.setProperty(KafkaConfig.AutoCreateTopicsEnableProp, brokerAutoTopicCreationEnable.toString) + + // configure client properties + this.producerConfig.setProperty(ProducerConfig.CLIENT_ID_CONFIG, producerClientId) + this.consumerConfig.setProperty(ConsumerConfig.CLIENT_ID_CONFIG, consumerClientId) + this.consumerConfig.setProperty(ConsumerConfig.GROUP_ID_CONFIG, "my-test") + this.consumerConfig.setProperty(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "earliest") + this.consumerConfig.setProperty(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG, "false") + this.consumerConfig.setProperty(ConsumerConfig.METADATA_MAX_AGE_CONFIG, "100") + this.consumerConfig.setProperty(ConsumerConfig.ALLOW_AUTO_CREATE_TOPICS_CONFIG, consumerAllowAutoCreateTopics.toString) + override protected def brokerCount: Int = 1 + + + def test(): Unit = { + val consumer = createConsumer() + val producer = createProducer() + val adminClient = createAdminClient() + val record = new ProducerRecord(topic_1, 0, "key".getBytes, "value".getBytes) + + // create `topic_1` and produce a record to it + adminClient.createTopics(Collections.singleton(new NewTopic(topic_1, 1, 1.toShort))).all.get + producer.send(record).get + 
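      // Editorial note (not part of this patch): topic_2 is never created explicitly. Whether the
      // consumer's metadata request below creates it depends on the two switches set in this class:
      //   serverConfig:   auto.create.topics.enable   (KafkaConfig.AutoCreateTopicsEnableProp)
      //   consumerConfig: allow.auto.create.topics    (ConsumerConfig.ALLOW_AUTO_CREATE_TOPICS_CONFIG)
      // Only when both are true should the topic appear, which is what test() asserts at the end.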
+ consumer.subscribe(util.Arrays.asList(topic_1, topic_2)) + + // Wait until the produced record was consumed. This guarantees that metadata request for `topic_2` was sent to the + // broker. + TestUtils.waitUntilTrue(() => { + consumer.poll(Duration.ofMillis(100)).count > 0 + }, "Timed out waiting to consume") + + // MetadataRequest is guaranteed to create the topic znode if creation was required + val topicCreated = zkClient.getAllTopicsInCluster().contains(topic_2) + if (brokerAutoTopicCreationEnable && consumerAllowAutoCreateTopics) + assertTrue(topicCreated) + else + assertFalse(topicCreated) + } + } + + def parameters: java.util.stream.Stream[Arguments] = { + val data = new java.util.ArrayList[Arguments]() + for (brokerAutoTopicCreationEnable <- Array(JBoolean.TRUE, JBoolean.FALSE)) + for (consumerAutoCreateTopicsPolicy <- Array(JBoolean.TRUE, JBoolean.FALSE)) + data.add(Arguments.of(brokerAutoTopicCreationEnable, consumerAutoCreateTopicsPolicy)) + data.stream() + } +} diff --git a/core/src/test/scala/integration/kafka/api/ConsumerWithLegacyMessageFormatIntegrationTest.scala b/core/src/test/scala/integration/kafka/api/ConsumerWithLegacyMessageFormatIntegrationTest.scala new file mode 100644 index 0000000..e8c451e --- /dev/null +++ b/core/src/test/scala/integration/kafka/api/ConsumerWithLegacyMessageFormatIntegrationTest.scala @@ -0,0 +1,127 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package kafka.api + +import kafka.log.LogConfig +import kafka.server.KafkaConfig +import org.apache.kafka.common.TopicPartition +import org.junit.jupiter.api.Assertions.{assertEquals, assertNull, assertThrows} +import org.junit.jupiter.api.Test + +import java.util +import java.util.{Collections, Optional, Properties} +import scala.annotation.nowarn +import scala.jdk.CollectionConverters._ + +class ConsumerWithLegacyMessageFormatIntegrationTest extends AbstractConsumerTest { + + override protected def brokerPropertyOverrides(properties: Properties): Unit = { + // legacy message formats are only supported with IBP < 3.0 + properties.put(KafkaConfig.InterBrokerProtocolVersionProp, "2.8") + } + + @nowarn("cat=deprecation") + @Test + def testOffsetsForTimes(): Unit = { + val numParts = 2 + val topic1 = "part-test-topic-1" + val topic2 = "part-test-topic-2" + val topic3 = "part-test-topic-3" + val props = new Properties() + props.setProperty(LogConfig.MessageFormatVersionProp, "0.9.0") + createTopic(topic1, numParts, 1) + // Topic2 is in old message format. 
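    // Editorial note (not part of this patch): offsetsForTimes returns, per partition, the earliest
    // offset whose timestamp is greater than or equal to the requested timestamp, e.g.
    //   consumer.offsetsForTimes(java.util.Collections.singletonMap(
    //     new TopicPartition(topic1, 0), java.lang.Long.valueOf(20L)))
    // Message format 0.9.0 has no per-record timestamps, which is why the entries for topic2 are
    // expected to be null further down.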
+ createTopic(topic2, numParts, 1, props) + createTopic(topic3, numParts, 1) + + val consumer = createConsumer() + + // Test negative target time + assertThrows(classOf[IllegalArgumentException], + () => consumer.offsetsForTimes(Collections.singletonMap(new TopicPartition(topic1, 0), -1))) + + val producer = createProducer() + val timestampsToSearch = new util.HashMap[TopicPartition, java.lang.Long]() + var i = 0 + for (topic <- List(topic1, topic2, topic3)) { + for (part <- 0 until numParts) { + val tp = new TopicPartition(topic, part) + // In sendRecords(), each message will have key, value and timestamp equal to the sequence number. + sendRecords(producer, numRecords = 100, tp, startingTimestamp = 0) + timestampsToSearch.put(tp, (i * 20).toLong) + i += 1 + } + } + // The timestampToSearch map should contain: + // (topic1Partition0 -> 0, + // topic1Partitoin1 -> 20, + // topic2Partition0 -> 40, + // topic2Partition1 -> 60, + // topic3Partition0 -> 80, + // topic3Partition1 -> 100) + val timestampOffsets = consumer.offsetsForTimes(timestampsToSearch) + + val timestampTopic1P0 = timestampOffsets.get(new TopicPartition(topic1, 0)) + assertEquals(0, timestampTopic1P0.offset) + assertEquals(0, timestampTopic1P0.timestamp) + assertEquals(Optional.of(0), timestampTopic1P0.leaderEpoch) + + val timestampTopic1P1 = timestampOffsets.get(new TopicPartition(topic1, 1)) + assertEquals(20, timestampTopic1P1.offset) + assertEquals(20, timestampTopic1P1.timestamp) + assertEquals(Optional.of(0), timestampTopic1P1.leaderEpoch) + + assertNull(timestampOffsets.get(new TopicPartition(topic2, 0)), "null should be returned when message format is 0.9.0") + assertNull(timestampOffsets.get(new TopicPartition(topic2, 1)), "null should be returned when message format is 0.9.0") + + val timestampTopic3P0 = timestampOffsets.get(new TopicPartition(topic3, 0)) + assertEquals(80, timestampTopic3P0.offset) + assertEquals(80, timestampTopic3P0.timestamp) + assertEquals(Optional.of(0), timestampTopic3P0.leaderEpoch) + + assertNull(timestampOffsets.get(new TopicPartition(topic3, 1))) + } + + @nowarn("cat=deprecation") + @Test + def testEarliestOrLatestOffsets(): Unit = { + val topic0 = "topicWithNewMessageFormat" + val topic1 = "topicWithOldMessageFormat" + val producer = createProducer() + createTopicAndSendRecords(producer, topicName = topic0, numPartitions = 2, recordsPerPartition = 100) + val props = new Properties() + props.setProperty(LogConfig.MessageFormatVersionProp, "0.9.0") + createTopic(topic1, numPartitions = 1, replicationFactor = 1, props) + sendRecords(producer, numRecords = 100, new TopicPartition(topic1, 0)) + + val t0p0 = new TopicPartition(topic0, 0) + val t0p1 = new TopicPartition(topic0, 1) + val t1p0 = new TopicPartition(topic1, 0) + val partitions = Set(t0p0, t0p1, t1p0).asJava + val consumer = createConsumer() + + val earliests = consumer.beginningOffsets(partitions) + assertEquals(0L, earliests.get(t0p0)) + assertEquals(0L, earliests.get(t0p1)) + assertEquals(0L, earliests.get(t1p0)) + + val latests = consumer.endOffsets(partitions) + assertEquals(100L, latests.get(t0p0)) + assertEquals(100L, latests.get(t0p1)) + assertEquals(100L, latests.get(t1p0)) + } +} diff --git a/core/src/test/scala/integration/kafka/api/CustomQuotaCallbackTest.scala b/core/src/test/scala/integration/kafka/api/CustomQuotaCallbackTest.scala new file mode 100644 index 0000000..201fc4b --- /dev/null +++ b/core/src/test/scala/integration/kafka/api/CustomQuotaCallbackTest.scala @@ -0,0 +1,469 @@ +/** + * Licensed under the Apache 
License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + **/ + +package kafka.api + +import java.io.File +import java.{lang, util} +import java.util.concurrent.ConcurrentHashMap +import java.util.concurrent.atomic.{AtomicBoolean, AtomicInteger} +import java.util.Properties + +import kafka.api.GroupedUserPrincipalBuilder._ +import kafka.api.GroupedUserQuotaCallback._ +import kafka.server._ +import kafka.utils.JaasTestUtils.ScramLoginModule +import kafka.utils.{JaasTestUtils, Logging, TestUtils} +import kafka.zk.ConfigEntityChangeNotificationZNode +import org.apache.kafka.clients.admin.{Admin, AdminClientConfig} +import org.apache.kafka.clients.consumer.{ConsumerConfig, KafkaConsumer} +import org.apache.kafka.clients.producer.{KafkaProducer, ProducerConfig, ProducerRecord} +import org.apache.kafka.common.{Cluster, Reconfigurable} +import org.apache.kafka.common.config.SaslConfigs +import org.apache.kafka.common.network.ListenerName +import org.apache.kafka.common.security.auth._ +import org.apache.kafka.server.quota._ +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.{AfterEach, BeforeEach, Test, TestInfo} + +import scala.collection.mutable.ArrayBuffer +import scala.jdk.CollectionConverters._ + +class CustomQuotaCallbackTest extends IntegrationTestHarness with SaslSetup { + + override protected def securityProtocol = SecurityProtocol.SASL_SSL + override protected def listenerName = new ListenerName("CLIENT") + override protected def interBrokerListenerName: ListenerName = new ListenerName("BROKER") + + override protected lazy val trustStoreFile = Some(File.createTempFile("truststore", ".jks")) + override val brokerCount: Int = 2 + + private val kafkaServerSaslMechanisms = Seq("SCRAM-SHA-256") + private val kafkaClientSaslMechanism = "SCRAM-SHA-256" + override protected val serverSaslProperties = Some(kafkaServerSaslProperties(kafkaServerSaslMechanisms, kafkaClientSaslMechanism)) + override protected val clientSaslProperties = Some(kafkaClientSaslProperties(kafkaClientSaslMechanism)) + private val adminClients = new ArrayBuffer[Admin]() + private var producerWithoutQuota: KafkaProducer[Array[Byte], Array[Byte]] = _ + + val defaultRequestQuota = 1000 + val defaultProduceQuota = 2000 * 1000 * 1000 + val defaultConsumeQuota = 1000 * 1000 * 1000 + + @BeforeEach + override def setUp(testInfo: TestInfo): Unit = { + startSasl(jaasSections(kafkaServerSaslMechanisms, Some("SCRAM-SHA-256"), KafkaSasl, JaasTestUtils.KafkaServerContextName)) + this.serverConfig.setProperty(KafkaConfig.ClientQuotaCallbackClassProp, classOf[GroupedUserQuotaCallback].getName) + this.serverConfig.setProperty(s"${listenerName.configPrefix}${KafkaConfig.PrincipalBuilderClassProp}", + classOf[GroupedUserPrincipalBuilder].getName) + this.serverConfig.setProperty(KafkaConfig.DeleteTopicEnableProp, "true") + super.setUp(testInfo) + brokerList = TestUtils.bootstrapServers(servers, listenerName) + + producerConfig.put(SaslConfigs.SASL_JAAS_CONFIG, + ScramLoginModule(JaasTestUtils.KafkaScramAdmin, 
JaasTestUtils.KafkaScramAdminPassword).toString) + producerWithoutQuota = createProducer() + } + + @AfterEach + override def tearDown(): Unit = { + adminClients.foreach(_.close()) + GroupedUserQuotaCallback.tearDown() + super.tearDown() + } + + override def configureSecurityBeforeServersStart(): Unit = { + super.configureSecurityBeforeServersStart() + zkClient.makeSurePersistentPathExists(ConfigEntityChangeNotificationZNode.path) + createScramCredentials(zkConnect, JaasTestUtils.KafkaScramAdmin, JaasTestUtils.KafkaScramAdminPassword) + } + + @Test + def testCustomQuotaCallback(): Unit = { + // Large quota override, should not throttle + var brokerId = 0 + var user = createGroupWithOneUser("group0_user1", brokerId) + user.configureAndWaitForQuota(1000000, 2000000) + quotaLimitCalls.values.foreach(_.set(0)) + user.produceConsume(expectProduceThrottle = false, expectConsumeThrottle = false) + + // ClientQuotaCallback#quotaLimit is invoked by each quota manager once per throttled produce request for each client + assertEquals(1, quotaLimitCalls(ClientQuotaType.PRODUCE).get) + // ClientQuotaCallback#quotaLimit is invoked once per each unthrottled and two for each throttled request + // since we don't know the total number of requests, we verify it was called at least twice (at least one throttled request) + assertTrue(quotaLimitCalls(ClientQuotaType.FETCH).get > 2, "quotaLimit must be called at least twice") + assertTrue(quotaLimitCalls(ClientQuotaType.REQUEST).get <= 10, s"Too many quotaLimit calls $quotaLimitCalls") // sanity check + // Large quota updated to small quota, should throttle + user.configureAndWaitForQuota(9000, 3000) + user.produceConsume(expectProduceThrottle = true, expectConsumeThrottle = true) + + // Quota override deletion - verify default quota applied (large quota, no throttling) + user = addUser("group0_user2", brokerId) + user.removeQuotaOverrides() + user.waitForQuotaUpdate(defaultProduceQuota, defaultConsumeQuota, defaultRequestQuota) + user.removeThrottleMetrics() // since group was throttled before + user.produceConsume(expectProduceThrottle = false, expectConsumeThrottle = false) + + // Make default quota smaller, should throttle + user.configureAndWaitForQuota(8000, 2500, divisor = 1, group = None) + user.produceConsume(expectProduceThrottle = true, expectConsumeThrottle = true) + + // Configure large quota override, should not throttle + user = addUser("group0_user3", brokerId) + user.configureAndWaitForQuota(2000000, 2000000) + user.removeThrottleMetrics() // since group was throttled before + user.produceConsume(expectProduceThrottle = false, expectConsumeThrottle = false) + + // Quota large enough for one partition, should not throttle + brokerId = 1 + user = createGroupWithOneUser("group1_user1", brokerId) + user.configureAndWaitForQuota(8000 * 100, 2500 * 100) + user.produceConsume(expectProduceThrottle = false, expectConsumeThrottle = false) + + // Create large number of partitions on another broker, should result in throttling on first partition + val largeTopic = "group1_largeTopic" + createTopic(largeTopic, numPartitions = 99, leader = 0) + user.waitForQuotaUpdate(8000, 2500, defaultRequestQuota) + user.produceConsume(expectProduceThrottle = true, expectConsumeThrottle = true) + + // Remove quota override and test default quota applied with scaling based on partitions + user = addUser("group1_user2", brokerId) + user.waitForQuotaUpdate(defaultProduceQuota / 100, defaultConsumeQuota / 100, defaultRequestQuota) + user.removeThrottleMetrics() // since group 
was throttled before + user.produceConsume(expectProduceThrottle = false, expectConsumeThrottle = false) + user.configureAndWaitForQuota(8000 * 100, 2500 * 100, divisor=100, group = None) + user.produceConsume(expectProduceThrottle = true, expectConsumeThrottle = true) + + // Remove the second topic with large number of partitions, verify no longer throttled + adminZkClient.deleteTopic(largeTopic) + user = addUser("group1_user3", brokerId) + user.waitForQuotaUpdate(8000 * 100, 2500 * 100, defaultRequestQuota) + user.removeThrottleMetrics() // since group was throttled before + user.produceConsume(expectProduceThrottle = false, expectConsumeThrottle = false) + + // Alter configs of custom callback dynamically + val adminClient = createAdminClient() + val newProps = new Properties + newProps.put(GroupedUserQuotaCallback.DefaultProduceQuotaProp, "8000") + newProps.put(GroupedUserQuotaCallback.DefaultFetchQuotaProp, "2500") + TestUtils.incrementalAlterConfigs(servers, adminClient, newProps, perBrokerConfig = false) + user.waitForQuotaUpdate(8000, 2500, defaultRequestQuota) + user.produceConsume(expectProduceThrottle = true, expectConsumeThrottle = true) + + assertEquals(brokerCount, callbackInstances.get) + } + + /** + * Creates a group with one user and one topic with one partition. + * @param firstUser First user to create in the group + * @param brokerId The broker id to use as leader of the partition + */ + private def createGroupWithOneUser(firstUser: String, brokerId: Int): GroupedUser = { + val user = addUser(firstUser, brokerId) + createTopic(user.topic, numPartitions = 1, brokerId) + user.configureAndWaitForQuota(defaultProduceQuota, defaultConsumeQuota, divisor = 1, group = None) + user + } + + private def createTopic(topic: String, numPartitions: Int, leader: Int): Unit = { + val assignment = (0 until numPartitions).map { i => i -> Seq(leader) }.toMap + TestUtils.createTopic(zkClient, topic, assignment, servers) + } + + private def createAdminClient(): Admin = { + val config = new util.HashMap[String, Object] + config.put(AdminClientConfig.BOOTSTRAP_SERVERS_CONFIG, + TestUtils.bootstrapServers(servers, new ListenerName("BROKER"))) + clientSecurityProps("admin-client").asInstanceOf[util.Map[Object, Object]].forEach { (key, value) => + config.put(key.toString, value) + } + config.put(SaslConfigs.SASL_JAAS_CONFIG, + ScramLoginModule(JaasTestUtils.KafkaScramAdmin, JaasTestUtils.KafkaScramAdminPassword).toString) + val adminClient = Admin.create(config) + adminClients += adminClient + adminClient + } + + private def produceWithoutThrottle(topic: String, numRecords: Int): Unit = { + (0 until numRecords).foreach { i => + val payload = i.toString.getBytes + producerWithoutQuota.send(new ProducerRecord[Array[Byte], Array[Byte]](topic, null, null, payload)) + } + } + + private def passwordForUser(user: String) = { + s"$user:secret" + } + + private def addUser(user: String, leader: Int): GroupedUser = { + val adminClient = createAdminClient() + createScramCredentials(adminClient, user, passwordForUser(user)) + waitForUserScramCredentialToAppearOnAllBrokers(user, kafkaClientSaslMechanism) + groupedUser(adminClient, user, leader) + } + + private def groupedUser(adminClient: Admin, user: String, leader: Int): GroupedUser = { + val password = passwordForUser(user) + val userGroup = group(user) + val topic = s"${userGroup}_topic" + val producerClientId = s"$user:producer-client-id" + val consumerClientId = s"$user:consumer-client-id" + + producerConfig.put(ProducerConfig.CLIENT_ID_CONFIG, 
producerClientId) + producerConfig.put(SaslConfigs.SASL_JAAS_CONFIG, ScramLoginModule(user, password).toString) + + consumerConfig.put(ConsumerConfig.CLIENT_ID_CONFIG, consumerClientId) + consumerConfig.put(ConsumerConfig.MAX_PARTITION_FETCH_BYTES_CONFIG, 4096.toString) + consumerConfig.put(ConsumerConfig.GROUP_ID_CONFIG, s"$user-group") + consumerConfig.put(SaslConfigs.SASL_JAAS_CONFIG, ScramLoginModule(user, password).toString) + + GroupedUser(user, userGroup, topic, servers(leader), producerClientId, consumerClientId, + createProducer(), createConsumer(), adminClient) + } + + case class GroupedUser(user: String, userGroup: String, topic: String, leaderNode: KafkaServer, + producerClientId: String, consumerClientId: String, + override val producer: KafkaProducer[Array[Byte], Array[Byte]], + override val consumer: KafkaConsumer[Array[Byte], Array[Byte]], + override val adminClient: Admin) extends + QuotaTestClients(topic, leaderNode, producerClientId, consumerClientId, producer, consumer, adminClient) { + + override def userPrincipal: KafkaPrincipal = GroupedUserPrincipal(user, userGroup) + + override def quotaMetricTags(clientId: String): Map[String, String] = { + Map(GroupedUserQuotaCallback.QuotaGroupTag -> userGroup) + } + + override def overrideQuotas(producerQuota: Long, consumerQuota: Long, requestQuota: Double): Unit = { + configureQuota(userGroup, producerQuota, consumerQuota, requestQuota) + } + + override def removeQuotaOverrides(): Unit = { + alterClientQuotas( + clientQuotaAlteration( + clientQuotaEntity(Some(quotaEntityName(userGroup)), None), + None, None, None + ) + ) + } + + def configureQuota(userGroup: String, producerQuota: Long, consumerQuota: Long, requestQuota: Double): Unit = { + alterClientQuotas( + clientQuotaAlteration( + clientQuotaEntity(Some(quotaEntityName(userGroup)), None), + Some(producerQuota), Some(consumerQuota), Some(requestQuota) + ) + ) + } + + def configureAndWaitForQuota(produceQuota: Long, fetchQuota: Long, divisor: Int = 1, + group: Option[String] = Some(userGroup)): Unit = { + configureQuota(group.getOrElse(""), produceQuota, fetchQuota, defaultRequestQuota) + waitForQuotaUpdate(produceQuota / divisor, fetchQuota / divisor, defaultRequestQuota) + } + + def produceConsume(expectProduceThrottle: Boolean, expectConsumeThrottle: Boolean): Unit = { + val numRecords = 1000 + val produced = produceUntilThrottled(numRecords, waitForRequestCompletion = false) + // don't verify request channel metrics as it's difficult to write non flaky assertions + // given the specifics of this test (throttle metric removal followed by produce/consume + // until throttled) + verifyProduceThrottle(expectProduceThrottle, verifyClientMetric = false, + verifyRequestChannelMetric = false) + // make sure there are enough records on the topic to test consumer throttling + produceWithoutThrottle(topic, numRecords - produced) + consumeUntilThrottled(numRecords, waitForRequestCompletion = false) + verifyConsumeThrottle(expectConsumeThrottle, verifyClientMetric = false, + verifyRequestChannelMetric = false) + } + + def removeThrottleMetrics(): Unit = { + def removeSensors(quotaType: QuotaType, clientId: String): Unit = { + val sensorSuffix = quotaMetricTags(clientId).values.mkString(":") + leaderNode.metrics.removeSensor(s"${quotaType}ThrottleTime-$sensorSuffix") + leaderNode.metrics.removeSensor(s"$quotaType-$sensorSuffix") + } + removeSensors(QuotaType.Produce, producerClientId) + removeSensors(QuotaType.Fetch, consumerClientId) + removeSensors(QuotaType.Request, 
producerClientId) + removeSensors(QuotaType.Request, consumerClientId) + } + + private def quotaEntityName(userGroup: String): String = s"${userGroup}_" + } +} + +object GroupedUserPrincipalBuilder { + def group(str: String): String = { + if (str.indexOf("_") <= 0) + "" + else + str.substring(0, str.indexOf("_")) + } +} + +class GroupedUserPrincipalBuilder extends KafkaPrincipalBuilder { + override def build(context: AuthenticationContext): KafkaPrincipal = { + val securityProtocol = context.securityProtocol + if (securityProtocol == SecurityProtocol.SASL_PLAINTEXT || securityProtocol == SecurityProtocol.SASL_SSL) { + val user = context.asInstanceOf[SaslAuthenticationContext].server().getAuthorizationID + val userGroup = group(user) + if (userGroup.isEmpty) + new KafkaPrincipal(KafkaPrincipal.USER_TYPE, user) + else + GroupedUserPrincipal(user, userGroup) + } else + throw new IllegalStateException(s"Unexpected security protocol $securityProtocol") + } +} + +case class GroupedUserPrincipal(user: String, userGroup: String) extends KafkaPrincipal(KafkaPrincipal.USER_TYPE, user) + +object GroupedUserQuotaCallback { + val QuotaGroupTag = "group" + val DefaultProduceQuotaProp = "default.produce.quota" + val DefaultFetchQuotaProp = "default.fetch.quota" + val UnlimitedQuotaMetricTags = new util.HashMap[String, String] + val quotaLimitCalls = Map( + ClientQuotaType.PRODUCE -> new AtomicInteger, + ClientQuotaType.FETCH -> new AtomicInteger, + ClientQuotaType.REQUEST -> new AtomicInteger + ) + val callbackInstances = new AtomicInteger + + def tearDown(): Unit = { + callbackInstances.set(0) + quotaLimitCalls.values.foreach(_.set(0)) + UnlimitedQuotaMetricTags.clear() + } +} + +/** + * Quota callback for a grouped user. Both user principals and topics of each group + * are prefixed with the group name followed by '_'. This callback defines quotas of different + * types at the group level. Group quotas are configured in ZooKeeper as user quotas with + * the entity name "${group}_". Default group quotas are configured in ZooKeeper as user quotas + * with the entity name "_". + * + * Default group quotas may also be configured using the configuration options + * "default.produce.quota" and "default.fetch.quota" which can be reconfigured dynamically + * without restarting the broker. 
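+ * A deployment might wire this up roughly as follows (illustrative sketch, not part of the test):
+ * {{{
+ * # server.properties on every broker
+ * client.quota.callback.class=kafka.api.GroupedUserQuotaCallback
+ *
+ * # group-wide defaults, adjustable at runtime as a cluster-default dynamic broker config, e.g.
+ * # bin/kafka-configs.sh --bootstrap-server localhost:9092 --alter \
+ * #   --entity-type brokers --entity-default \
+ * #   --add-config default.produce.quota=8000,default.fetch.quota=2500
+ * }}}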
This tests custom reconfigurable options for quota callbacks, + */ +class GroupedUserQuotaCallback extends ClientQuotaCallback with Reconfigurable with Logging { + + var brokerId: Int = -1 + val customQuotasUpdated = ClientQuotaType.values.map(quotaType => quotaType -> new AtomicBoolean).toMap + val quotas = ClientQuotaType.values.map(quotaType => quotaType -> new ConcurrentHashMap[String, Double]).toMap + + val partitionRatio = new ConcurrentHashMap[String, Double]() + + override def configure(configs: util.Map[String, _]): Unit = { + brokerId = configs.get(KafkaConfig.BrokerIdProp).toString.toInt + callbackInstances.incrementAndGet + } + + override def reconfigurableConfigs: util.Set[String] = { + Set(DefaultProduceQuotaProp, DefaultFetchQuotaProp).asJava + } + + override def validateReconfiguration(configs: util.Map[String, _]): Unit = { + reconfigurableConfigs.forEach(configValue(configs, _)) + } + + override def reconfigure(configs: util.Map[String, _]): Unit = { + configValue(configs, DefaultProduceQuotaProp).foreach(value => quotas(ClientQuotaType.PRODUCE).put("", value.toDouble)) + configValue(configs, DefaultFetchQuotaProp).foreach(value => quotas(ClientQuotaType.FETCH).put("", value.toDouble)) + customQuotasUpdated.values.foreach(_.set(true)) + } + + private def configValue(configs: util.Map[String, _], key: String): Option[Long] = { + val value = configs.get(key) + if (value != null) Some(value.toString.toLong) else None + } + + override def quotaMetricTags(quotaType: ClientQuotaType, principal: KafkaPrincipal, clientId: String): util.Map[String, String] = { + principal match { + case groupPrincipal: GroupedUserPrincipal => + val userGroup = groupPrincipal.userGroup + val quotaLimit = quotaOrDefault(userGroup, quotaType) + if (quotaLimit != null) + Map(QuotaGroupTag -> userGroup).asJava + else + UnlimitedQuotaMetricTags + case _ => + UnlimitedQuotaMetricTags + } + } + + override def quotaLimit(quotaType: ClientQuotaType, metricTags: util.Map[String, String]): lang.Double = { + quotaLimitCalls(quotaType).incrementAndGet + val group = metricTags.get(QuotaGroupTag) + if (group != null) quotaOrDefault(group, quotaType) else null + } + + override def updateClusterMetadata(cluster: Cluster): Boolean = { + val topicsByGroup = cluster.topics.asScala.groupBy(group) + + topicsByGroup.map { case (group, groupTopics) => + val groupPartitions = groupTopics.flatMap(topic => cluster.partitionsForTopic(topic).asScala) + val totalPartitions = groupPartitions.size + val partitionsOnThisBroker = groupPartitions.count { p => p.leader != null && p.leader.id == brokerId } + val multiplier = if (totalPartitions == 0) + 1 + else if (partitionsOnThisBroker == 0) + 1.0 / totalPartitions + else + partitionsOnThisBroker.toDouble / totalPartitions + partitionRatio.put(group, multiplier) != multiplier + }.exists(identity) + } + + override def updateQuota(quotaType: ClientQuotaType, quotaEntity: ClientQuotaEntity, newValue: Double): Unit = { + quotas(quotaType).put(userGroup(quotaEntity), newValue) + } + + override def removeQuota(quotaType: ClientQuotaType, quotaEntity: ClientQuotaEntity): Unit = { + quotas(quotaType).remove(userGroup(quotaEntity)) + } + + override def quotaResetRequired(quotaType: ClientQuotaType): Boolean = customQuotasUpdated(quotaType).getAndSet(false) + + def close(): Unit = {} + + private def userGroup(quotaEntity: ClientQuotaEntity): String = { + val configEntity = quotaEntity.configEntities.get(0) + if (configEntity.entityType == ClientQuotaEntity.ConfigEntityType.USER) + 
group(configEntity.name) + else + throw new IllegalArgumentException(s"Config entity type ${configEntity.entityType} is not supported") + } + + private def quotaOrDefault(group: String, quotaType: ClientQuotaType): lang.Double = { + val quotaMap = quotas(quotaType) + var quotaLimit: Any = quotaMap.get(group) + if (quotaLimit == null) + quotaLimit = quotaMap.get("") + if (quotaLimit != null) scaledQuota(quotaType, group, quotaLimit.asInstanceOf[Double]) else null + } + + private def scaledQuota(quotaType: ClientQuotaType, group: String, configuredQuota: Double): Double = { + if (quotaType == ClientQuotaType.REQUEST) + configuredQuota + else { + val multiplier = partitionRatio.get(group) + if (multiplier <= 0.0) configuredQuota else configuredQuota * multiplier + } + } +} + + diff --git a/core/src/test/scala/integration/kafka/api/DelegationTokenEndToEndAuthorizationTest.scala b/core/src/test/scala/integration/kafka/api/DelegationTokenEndToEndAuthorizationTest.scala new file mode 100644 index 0000000..7336055 --- /dev/null +++ b/core/src/test/scala/integration/kafka/api/DelegationTokenEndToEndAuthorizationTest.scala @@ -0,0 +1,126 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package kafka.api + +import java.util.Properties + +import kafka.server.KafkaConfig +import kafka.utils.{JaasTestUtils, TestUtils} +import kafka.zk.ConfigEntityChangeNotificationZNode +import org.apache.kafka.clients.admin.{Admin, AdminClientConfig, ScramCredentialInfo, UserScramCredentialAlteration, UserScramCredentialUpsertion, ScramMechanism => PublicScramMechanism} +import org.apache.kafka.common.config.SaslConfigs +import org.apache.kafka.common.security.auth.{KafkaPrincipal, SecurityProtocol} +import org.apache.kafka.common.security.scram.internals.ScramMechanism +import org.apache.kafka.common.security.token.delegation.DelegationToken +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.{BeforeEach, Test, TestInfo} + +import scala.jdk.CollectionConverters._ + +class DelegationTokenEndToEndAuthorizationTest extends EndToEndAuthorizationTest { + + val kafkaClientSaslMechanism = "SCRAM-SHA-256" + val kafkaServerSaslMechanisms = ScramMechanism.mechanismNames.asScala.toList + + override protected def securityProtocol = SecurityProtocol.SASL_SSL + + override protected val serverSaslProperties = Some(kafkaServerSaslProperties(kafkaServerSaslMechanisms, kafkaClientSaslMechanism)) + override protected val clientSaslProperties = Some(kafkaClientSaslProperties(kafkaClientSaslMechanism)) + + override val clientPrincipal = new KafkaPrincipal(KafkaPrincipal.USER_TYPE, JaasTestUtils.KafkaScramUser) + private val clientPassword = JaasTestUtils.KafkaScramPassword + + override val kafkaPrincipal = new KafkaPrincipal(KafkaPrincipal.USER_TYPE, JaasTestUtils.KafkaScramAdmin) + private val kafkaPassword = JaasTestUtils.KafkaScramAdminPassword + + private val privilegedAdminClientConfig = new Properties() + + this.serverConfig.setProperty(KafkaConfig.DelegationTokenSecretKeyProp, "testKey") + + override def configureSecurityBeforeServersStart(): Unit = { + super.configureSecurityBeforeServersStart() + zkClient.makeSurePersistentPathExists(ConfigEntityChangeNotificationZNode.path) + // Create broker admin credentials before starting brokers + createScramCredentials(zkConnect, kafkaPrincipal.getName, kafkaPassword) + } + + override def createPrivilegedAdminClient() = createScramAdminClient(kafkaClientSaslMechanism, kafkaPrincipal.getName, kafkaPassword) + + override def configureSecurityAfterServersStart(): Unit = { + super.configureSecurityAfterServersStart() + + // create scram credential for user "scram-user" + createScramCredentialsViaPrivilegedAdminClient(clientPrincipal.getName, clientPassword) + waitForUserScramCredentialToAppearOnAllBrokers(clientPrincipal.getName, kafkaClientSaslMechanism) + + //create a token with "scram-user" credentials and a privileged token with scram-admin credentials + val tokens = createDelegationTokens() + val token = tokens._1 + val privilegedToken = tokens._2 + + privilegedAdminClientConfig.putAll(adminClientConfig) + + // pass token to client jaas config + val clientLoginContext = JaasTestUtils.tokenClientLoginModule(token.tokenInfo().tokenId(), token.hmacAsBase64String()) + producerConfig.put(SaslConfigs.SASL_JAAS_CONFIG, clientLoginContext) + consumerConfig.put(SaslConfigs.SASL_JAAS_CONFIG, clientLoginContext) + adminClientConfig.put(SaslConfigs.SASL_JAAS_CONFIG, clientLoginContext) + val privilegedClientLoginContext = JaasTestUtils.tokenClientLoginModule(privilegedToken.tokenInfo().tokenId(), privilegedToken.hmacAsBase64String()) + privilegedAdminClientConfig.put(SaslConfigs.SASL_JAAS_CONFIG, privilegedClientLoginContext) + } + + @Test + def 
testCreateUserWithDelegationToken(): Unit = { + val privilegedAdminClient = Admin.create(privilegedAdminClientConfig) + try { + val user = "user" + val results = privilegedAdminClient.alterUserScramCredentials(List[UserScramCredentialAlteration]( + new UserScramCredentialUpsertion(user, new ScramCredentialInfo(PublicScramMechanism.SCRAM_SHA_256, 4096), "password")).asJava) + assertEquals(1, results.values.size) + val future = results.values.get(user) + future.get // make sure we haven't completed exceptionally + } finally { + privilegedAdminClient.close() + } + } + + @BeforeEach + override def setUp(testInfo: TestInfo): Unit = { + startSasl(jaasSections(kafkaServerSaslMechanisms, Option(kafkaClientSaslMechanism), Both)) + super.setUp(testInfo) + privilegedAdminClientConfig.put(AdminClientConfig.BOOTSTRAP_SERVERS_CONFIG, brokerList) + } + + private def createDelegationTokens(): (DelegationToken, DelegationToken) = { + val adminClient = createScramAdminClient(kafkaClientSaslMechanism, clientPrincipal.getName, clientPassword) + try { + val privilegedAdminClient = createScramAdminClient(kafkaClientSaslMechanism, kafkaPrincipal.getName, kafkaPassword) + try { + val token = adminClient.createDelegationToken().delegationToken().get() + val privilegedToken = privilegedAdminClient.createDelegationToken().delegationToken().get() + //wait for tokens to reach all the brokers + TestUtils.waitUntilTrue(() => servers.forall(server => server.tokenCache.tokens().size() == 2), + "Timed out waiting for token to propagate to all servers") + (token, privilegedToken) + } finally { + privilegedAdminClient.close() + } + } finally { + adminClient.close() + } + } +} diff --git a/core/src/test/scala/integration/kafka/api/DescribeAuthorizedOperationsTest.scala b/core/src/test/scala/integration/kafka/api/DescribeAuthorizedOperationsTest.scala new file mode 100644 index 0000000..0a46972 --- /dev/null +++ b/core/src/test/scala/integration/kafka/api/DescribeAuthorizedOperationsTest.scala @@ -0,0 +1,209 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more contributor license agreements. See the NOTICE + * file distributed with this work for additional information regarding copyright ownership. The ASF licenses this file + * to You under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the + * License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. 
+ */ +package kafka.api + +import java.io.File +import java.util +import java.util.Properties + +import kafka.security.authorizer.{AclAuthorizer, AclEntry} +import kafka.server.KafkaConfig +import kafka.utils.{CoreUtils, JaasTestUtils, TestUtils} +import org.apache.kafka.clients.admin._ +import org.apache.kafka.common.acl.AclOperation.{ALL, ALTER, CLUSTER_ACTION, DELETE, DESCRIBE} +import org.apache.kafka.common.acl.AclPermissionType.ALLOW +import org.apache.kafka.common.acl._ +import org.apache.kafka.common.resource.{PatternType, Resource, ResourcePattern, ResourceType} +import org.apache.kafka.common.security.auth.{KafkaPrincipal, SecurityProtocol} +import org.apache.kafka.common.utils.Utils +import org.apache.kafka.server.authorizer.Authorizer +import org.junit.jupiter.api.Assertions.{assertEquals, assertFalse, assertNull} +import org.junit.jupiter.api.{AfterEach, BeforeEach, Test, TestInfo} + +import scala.jdk.CollectionConverters._ + +object DescribeAuthorizedOperationsTest { + val Group1 = "group1" + val Group2 = "group2" + val Group3 = "group3" + val Topic1 = "topic1" + val Topic2 = "topic2" + + val Group1Acl = new AclBinding( + new ResourcePattern(ResourceType.GROUP, Group1, PatternType.LITERAL), + accessControlEntry(JaasTestUtils.KafkaClientPrincipalUnqualifiedName2, ALL)) + + val Group2Acl = new AclBinding( + new ResourcePattern(ResourceType.GROUP, Group2, PatternType.LITERAL), + accessControlEntry(JaasTestUtils.KafkaClientPrincipalUnqualifiedName2, DESCRIBE)) + + val Group3Acl = new AclBinding( + new ResourcePattern(ResourceType.GROUP, Group3, PatternType.LITERAL), + accessControlEntry(JaasTestUtils.KafkaClientPrincipalUnqualifiedName2, DELETE)) + + val ClusterAllAcl = new AclBinding( + new ResourcePattern(ResourceType.CLUSTER, Resource.CLUSTER_NAME, PatternType.LITERAL), + accessControlEntry(JaasTestUtils.KafkaClientPrincipalUnqualifiedName2, ALL)) + + val Topic1Acl = new AclBinding( + new ResourcePattern(ResourceType.TOPIC, Topic1, PatternType.LITERAL), + accessControlEntry(JaasTestUtils.KafkaClientPrincipalUnqualifiedName2, ALL)) + + val Topic2All = new AclBinding( + new ResourcePattern(ResourceType.TOPIC, Topic2, PatternType.LITERAL), + accessControlEntry(JaasTestUtils.KafkaClientPrincipalUnqualifiedName2, DELETE)) + + private def accessControlEntry( + userName: String, + operation: AclOperation + ): AccessControlEntry = { + new AccessControlEntry(new KafkaPrincipal(KafkaPrincipal.USER_TYPE, userName).toString, + AclEntry.WildcardHost, operation, ALLOW) + } +} + +class DescribeAuthorizedOperationsTest extends IntegrationTestHarness with SaslSetup { + import DescribeAuthorizedOperationsTest._ + + override val brokerCount = 1 + this.serverConfig.setProperty(KafkaConfig.ZkEnableSecureAclsProp, "true") + this.serverConfig.setProperty(KafkaConfig.AuthorizerClassNameProp, classOf[AclAuthorizer].getName) + + var client: Admin = _ + + override protected def securityProtocol = SecurityProtocol.SASL_SSL + + override protected lazy val trustStoreFile = Some(File.createTempFile("truststore", ".jks")) + + override def configureSecurityBeforeServersStart(): Unit = { + val authorizer = CoreUtils.createObject[Authorizer](classOf[AclAuthorizer].getName) + val clusterResource = new ResourcePattern(ResourceType.CLUSTER, Resource.CLUSTER_NAME, PatternType.LITERAL) + val topicResource = new ResourcePattern(ResourceType.TOPIC, AclEntry.WildcardResource, PatternType.LITERAL) + + try { + authorizer.configure(this.configs.head.originals()) + val result = authorizer.createAcls(null, List( + new 
AclBinding(clusterResource, accessControlEntry( + JaasTestUtils.KafkaServerPrincipalUnqualifiedName, CLUSTER_ACTION)), + new AclBinding(clusterResource, accessControlEntry( + JaasTestUtils.KafkaClientPrincipalUnqualifiedName2, ALTER)), + new AclBinding(topicResource, accessControlEntry( + JaasTestUtils.KafkaClientPrincipalUnqualifiedName2, DESCRIBE)) + ).asJava) + result.asScala.map(_.toCompletableFuture.get).foreach(result => assertFalse(result.exception.isPresent)) + } finally { + authorizer.close() + } + } + + @BeforeEach + override def setUp(testInfo: TestInfo): Unit = { + startSasl(jaasSections(Seq("GSSAPI"), Some("GSSAPI"), Both, JaasTestUtils.KafkaServerContextName)) + super.setUp(testInfo) + TestUtils.waitUntilBrokerMetadataIsPropagated(servers) + client = Admin.create(createConfig()) + } + + @AfterEach + override def tearDown(): Unit = { + Utils.closeQuietly(client, "AdminClient") + super.tearDown() + closeSasl() + } + + private def createConfig(): Properties = { + val adminClientConfig = new Properties() + adminClientConfig.put(AdminClientConfig.BOOTSTRAP_SERVERS_CONFIG, brokerList) + adminClientConfig.put(AdminClientConfig.REQUEST_TIMEOUT_MS_CONFIG, "20000") + val securityProps: util.Map[Object, Object] = + TestUtils.adminClientSecurityConfigs(securityProtocol, trustStoreFile, clientSaslProperties) + adminClientConfig.putAll(securityProps) + adminClientConfig + } + + @Test + def testConsumerGroupAuthorizedOperations(): Unit = { + val results = client.createAcls(List(Group1Acl, Group2Acl, Group3Acl).asJava) + assertEquals(Set(Group1Acl, Group2Acl, Group3Acl), results.values.keySet.asScala) + results.all.get + + val describeConsumerGroupsResult = client.describeConsumerGroups(Seq(Group1, Group2, Group3).asJava, + new DescribeConsumerGroupsOptions().includeAuthorizedOperations(true)) + assertEquals(3, describeConsumerGroupsResult.describedGroups().size()) + val expectedOperations = AclEntry.supportedOperations(ResourceType.GROUP).asJava + + val group1Description = describeConsumerGroupsResult.describedGroups().get(Group1).get + assertEquals(expectedOperations, group1Description.authorizedOperations()) + + val group2Description = describeConsumerGroupsResult.describedGroups().get(Group2).get + assertEquals(Set(AclOperation.DESCRIBE), group2Description.authorizedOperations().asScala.toSet) + + val group3Description = describeConsumerGroupsResult.describedGroups().get(Group3).get + assertEquals(Set(AclOperation.DESCRIBE, AclOperation.DELETE), group3Description.authorizedOperations().asScala.toSet) + } + + @Test + def testClusterAuthorizedOperations(): Unit = { + // test without includeAuthorizedOperations flag + var clusterDescribeResult = client.describeCluster() + assertNull(clusterDescribeResult.authorizedOperations.get()) + + // test with includeAuthorizedOperations flag, we have give Alter permission + // in configureSecurityBeforeServersStart() + clusterDescribeResult = client.describeCluster(new DescribeClusterOptions(). + includeAuthorizedOperations(true)) + assertEquals(Set(AclOperation.DESCRIBE, AclOperation.ALTER), + clusterDescribeResult.authorizedOperations().get().asScala.toSet) + + // enable all operations for cluster resource + val results = client.createAcls(List(ClusterAllAcl).asJava) + assertEquals(Set(ClusterAllAcl), results.values.keySet.asScala) + results.all.get + + val expectedOperations = AclEntry.supportedOperations(ResourceType.CLUSTER).asJava + + clusterDescribeResult = client.describeCluster(new DescribeClusterOptions(). 
+ includeAuthorizedOperations(true)) + assertEquals(expectedOperations, clusterDescribeResult.authorizedOperations().get()) + } + + @Test + def testTopicAuthorizedOperations(): Unit = { + createTopic(Topic1) + createTopic(Topic2) + + // test without includeAuthorizedOperations flag + var describeTopicsResult = client.describeTopics(Set(Topic1, Topic2).asJava).allTopicNames.get() + assertNull(describeTopicsResult.get(Topic1).authorizedOperations) + assertNull(describeTopicsResult.get(Topic2).authorizedOperations) + + // test with includeAuthorizedOperations flag + describeTopicsResult = client.describeTopics(Set(Topic1, Topic2).asJava, + new DescribeTopicsOptions().includeAuthorizedOperations(true)).allTopicNames.get() + assertEquals(Set(AclOperation.DESCRIBE), describeTopicsResult.get(Topic1).authorizedOperations().asScala.toSet) + assertEquals(Set(AclOperation.DESCRIBE), describeTopicsResult.get(Topic2).authorizedOperations().asScala.toSet) + + // add few permissions + val results = client.createAcls(List(Topic1Acl, Topic2All).asJava) + assertEquals(Set(Topic1Acl, Topic2All), results.values.keySet.asScala) + results.all.get + + val expectedOperations = AclEntry.supportedOperations(ResourceType.TOPIC).asJava + + describeTopicsResult = client.describeTopics(Set(Topic1, Topic2).asJava, + new DescribeTopicsOptions().includeAuthorizedOperations(true)).allTopicNames.get() + assertEquals(expectedOperations, describeTopicsResult.get(Topic1).authorizedOperations()) + assertEquals(Set(AclOperation.DESCRIBE, AclOperation.DELETE), + describeTopicsResult.get(Topic2).authorizedOperations().asScala.toSet) + } +} diff --git a/core/src/test/scala/integration/kafka/api/EndToEndAuthorizationTest.scala b/core/src/test/scala/integration/kafka/api/EndToEndAuthorizationTest.scala new file mode 100644 index 0000000..cbb536b --- /dev/null +++ b/core/src/test/scala/integration/kafka/api/EndToEndAuthorizationTest.scala @@ -0,0 +1,553 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package kafka.api + +import com.yammer.metrics.core.Gauge + +import java.io.File +import java.util.Collections +import java.util.concurrent.ExecutionException +import kafka.admin.AclCommand +import kafka.metrics.KafkaYammerMetrics +import kafka.security.authorizer.AclAuthorizer +import kafka.security.authorizer.AclEntry.WildcardHost +import kafka.server._ +import kafka.utils._ +import org.apache.kafka.clients.admin.Admin +import org.apache.kafka.clients.consumer.{Consumer, ConsumerConfig, ConsumerRecords} +import org.apache.kafka.clients.producer.{KafkaProducer, ProducerRecord} +import org.apache.kafka.common.acl._ +import org.apache.kafka.common.acl.AclOperation._ +import org.apache.kafka.common.acl.AclPermissionType._ +import org.apache.kafka.common.{KafkaException, TopicPartition} +import org.apache.kafka.common.errors.{GroupAuthorizationException, TopicAuthorizationException} +import org.apache.kafka.common.resource._ +import org.apache.kafka.common.resource.ResourceType._ +import org.apache.kafka.common.resource.PatternType.{LITERAL, PREFIXED} +import org.apache.kafka.common.security.auth.KafkaPrincipal +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.{AfterEach, BeforeEach, Test, TestInfo} + +import scala.jdk.CollectionConverters._ + +/** + * The test cases here verify that a producer authorized to publish to a topic + * is able to, and that consumers in a group authorized to consume are able to + * to do so. + * + * This test relies on a chain of test harness traits to set up. It directly + * extends IntegrationTestHarness. IntegrationTestHarness creates producers and + * consumers, and it extends KafkaServerTestHarness. KafkaServerTestHarness starts + * brokers, but first it initializes a ZooKeeper server and client, which happens + * in QuorumTestHarness. + * + * To start brokers we need to set a cluster ACL, which happens optionally in KafkaServerTestHarness. + * The remaining ACLs to enable access to producers and consumers are set here. To set ACLs, we use AclCommand directly. + * + * Finally, we rely on SaslSetup to bootstrap and setup Kerberos. We don't use + * SaslTestHarness here directly because it extends QuorumTestHarness, and we + * would end up with QuorumTestHarness twice. 
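+ * As an illustration, the producer-side ACLs used by these tests are added with an invocation
+ * equivalent to the following sketch (mirroring produceAclArgs below):
+ * {{{
+ * AclCommand.main(Array(
+ *   "--authorizer-properties", s"zookeeper.connect=$zkConnect",
+ *   "--add", "--topic=e2etopic", "--producer",
+ *   s"--allow-principal=$clientPrincipal"))
+ * }}}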
+ */ +abstract class EndToEndAuthorizationTest extends IntegrationTestHarness with SaslSetup { + override val brokerCount = 3 + + override def configureSecurityBeforeServersStart(): Unit = { + AclCommand.main(clusterActionArgs) + AclCommand.main(clusterAlterArgs) + AclCommand.main(topicBrokerReadAclArgs) + } + + val numRecords = 1 + val groupPrefix = "gr" + val group = s"${groupPrefix}oup" + val topicPrefix = "e2e" + val topic = s"${topicPrefix}topic" + val wildcard = "*" + val part = 0 + val tp = new TopicPartition(topic, part) + + override protected lazy val trustStoreFile = Some(File.createTempFile("truststore", ".jks")) + protected def authorizerClass: Class[_] = classOf[AclAuthorizer] + + val topicResource = new ResourcePattern(TOPIC, topic, LITERAL) + val groupResource = new ResourcePattern(GROUP, group, LITERAL) + val clusterResource = new ResourcePattern(CLUSTER, Resource.CLUSTER_NAME, LITERAL) + val prefixedTopicResource = new ResourcePattern(TOPIC, topicPrefix, PREFIXED) + val prefixedGroupResource = new ResourcePattern(GROUP, groupPrefix, PREFIXED) + val wildcardTopicResource = new ResourcePattern(TOPIC, wildcard, LITERAL) + val wildcardGroupResource = new ResourcePattern(GROUP, wildcard, LITERAL) + + def clientPrincipal: KafkaPrincipal + def kafkaPrincipal: KafkaPrincipal + + // Arguments to AclCommand to set ACLs. + def clusterActionArgs: Array[String] = Array("--authorizer-properties", + s"zookeeper.connect=$zkConnect", + s"--add", + s"--cluster", + s"--operation=ClusterAction", + s"--allow-principal=$kafkaPrincipal") + // necessary to create SCRAM credentials via the admin client using the broker's credentials + // without this we would need to create the SCRAM credentials via ZooKeeper + def clusterAlterArgs: Array[String] = Array("--authorizer-properties", + s"zookeeper.connect=$zkConnect", + s"--add", + s"--cluster", + s"--operation=Alter", + s"--allow-principal=$kafkaPrincipal") + def topicBrokerReadAclArgs: Array[String] = Array("--authorizer-properties", + s"zookeeper.connect=$zkConnect", + s"--add", + s"--topic=$wildcard", + s"--operation=Read", + s"--allow-principal=$kafkaPrincipal") + def produceAclArgs(topic: String): Array[String] = Array("--authorizer-properties", + s"zookeeper.connect=$zkConnect", + s"--add", + s"--topic=$topic", + s"--producer", + s"--allow-principal=$clientPrincipal") + def describeAclArgs: Array[String] = Array("--authorizer-properties", + s"zookeeper.connect=$zkConnect", + s"--add", + s"--topic=$topic", + s"--operation=Describe", + s"--allow-principal=$clientPrincipal") + def deleteDescribeAclArgs: Array[String] = Array("--authorizer-properties", + s"zookeeper.connect=$zkConnect", + s"--remove", + s"--force", + s"--topic=$topic", + s"--operation=Describe", + s"--allow-principal=$clientPrincipal") + def deleteWriteAclArgs: Array[String] = Array("--authorizer-properties", + s"zookeeper.connect=$zkConnect", + s"--remove", + s"--force", + s"--topic=$topic", + s"--operation=Write", + s"--allow-principal=$clientPrincipal") + def consumeAclArgs(topic: String): Array[String] = Array("--authorizer-properties", + s"zookeeper.connect=$zkConnect", + s"--add", + s"--topic=$topic", + s"--group=$group", + s"--consumer", + s"--allow-principal=$clientPrincipal") + def groupAclArgs: Array[String] = Array("--authorizer-properties", + s"zookeeper.connect=$zkConnect", + s"--add", + s"--group=$group", + s"--operation=Read", + s"--allow-principal=$clientPrincipal") + def produceConsumeWildcardAclArgs: Array[String] = Array("--authorizer-properties", + 
s"zookeeper.connect=$zkConnect", + s"--add", + s"--topic=$wildcard", + s"--group=$wildcard", + s"--consumer", + s"--producer", + s"--allow-principal=$clientPrincipal") + def produceConsumePrefixedAclsArgs: Array[String] = Array("--authorizer-properties", + s"zookeeper.connect=$zkConnect", + s"--add", + s"--topic=$topicPrefix", + s"--group=$groupPrefix", + s"--resource-pattern-type=prefixed", + s"--consumer", + s"--producer", + s"--allow-principal=$clientPrincipal") + + def ClusterActionAndClusterAlterAcls = Set(new AccessControlEntry(kafkaPrincipal.toString, WildcardHost, CLUSTER_ACTION, ALLOW), + new AccessControlEntry(kafkaPrincipal.toString, WildcardHost, ALTER, ALLOW)) + def TopicBrokerReadAcl = Set(new AccessControlEntry(kafkaPrincipal.toString, WildcardHost, READ, ALLOW)) + def GroupReadAcl = Set(new AccessControlEntry(clientPrincipal.toString, WildcardHost, READ, ALLOW)) + def TopicReadAcl = Set(new AccessControlEntry(clientPrincipal.toString, WildcardHost, READ, ALLOW)) + def TopicWriteAcl = Set(new AccessControlEntry(clientPrincipal.toString, WildcardHost, WRITE, ALLOW)) + def TopicDescribeAcl = Set(new AccessControlEntry(clientPrincipal.toString, WildcardHost, DESCRIBE, ALLOW)) + def TopicCreateAcl = Set(new AccessControlEntry(clientPrincipal.toString, WildcardHost, CREATE, ALLOW)) + // The next two configuration parameters enable ZooKeeper secure ACLs + // and sets the Kafka authorizer, both necessary to enable security. + this.serverConfig.setProperty(KafkaConfig.ZkEnableSecureAclsProp, "true") + this.serverConfig.setProperty(KafkaConfig.AuthorizerClassNameProp, authorizerClass.getName) + // Some needed configuration for brokers, producers, and consumers + this.serverConfig.setProperty(KafkaConfig.OffsetsTopicPartitionsProp, "1") + this.serverConfig.setProperty(KafkaConfig.OffsetsTopicReplicationFactorProp, "3") + this.serverConfig.setProperty(KafkaConfig.MinInSyncReplicasProp, "3") + this.serverConfig.setProperty(KafkaConfig.DefaultReplicationFactorProp, "3") + this.serverConfig.setProperty(KafkaConfig.ConnectionsMaxReauthMsProp, "1500") + this.consumerConfig.setProperty(ConsumerConfig.GROUP_ID_CONFIG, "group") + this.consumerConfig.setProperty(ConsumerConfig.METADATA_MAX_AGE_CONFIG, "1500") + + /** + * Starts MiniKDC and only then sets up the parent trait. + */ + @BeforeEach + override def setUp(testInfo: TestInfo): Unit = { + super.setUp(testInfo) + servers.foreach { s => + TestUtils.waitAndVerifyAcls(ClusterActionAndClusterAlterAcls, s.dataPlaneRequestProcessor.authorizer.get, clusterResource) + TestUtils.waitAndVerifyAcls(TopicBrokerReadAcl, s.dataPlaneRequestProcessor.authorizer.get, new ResourcePattern(TOPIC, "*", LITERAL)) + } + // create the test topic with all the brokers as replicas + createTopic(topic, 1, 3) + } + + /** + * Closes MiniKDC last when tearing down. + */ + @AfterEach + override def tearDown(): Unit = { + super.tearDown() + closeSasl() + } + + /** + * Tests the ability of producing and consuming with the appropriate ACLs set. 
+ */ + @Test + def testProduceConsumeViaAssign(): Unit = { + setAclsAndProduce(tp) + val consumer = createConsumer() + consumer.assign(List(tp).asJava) + consumeRecords(consumer, numRecords) + confirmReauthenticationMetrics() + } + + protected def confirmReauthenticationMetrics(): Unit = { + val expiredConnectionsKilledCountTotal = getGauge("ExpiredConnectionsKilledCount").value() + servers.foreach { s => + val numExpiredKilled = TestUtils.totalMetricValue(s, "expired-connections-killed-count") + assertEquals(0, numExpiredKilled, "Should have been zero expired connections killed: " + numExpiredKilled + "(total=" + expiredConnectionsKilledCountTotal + ")") + } + assertEquals(0, expiredConnectionsKilledCountTotal, 0.0, "Should have been zero expired connections killed total") + servers.foreach { s => + assertEquals(0, TestUtils.totalMetricValue(s, "failed-reauthentication-total"), "failed re-authentications not 0") + } + } + + private def getGauge(metricName: String) = { + KafkaYammerMetrics.defaultRegistry.allMetrics.asScala + .find { case (k, _) => k.getName == metricName } + .getOrElse(throw new RuntimeException( "Unable to find metric " + metricName)) + ._2.asInstanceOf[Gauge[Double]] + } + + @Test + def testProduceConsumeViaSubscribe(): Unit = { + setAclsAndProduce(tp) + val consumer = createConsumer() + consumer.subscribe(List(topic).asJava) + consumeRecords(consumer, numRecords) + confirmReauthenticationMetrics() + } + + @Test + def testProduceConsumeWithWildcardAcls(): Unit = { + setWildcardResourceAcls() + val producer = createProducer() + sendRecords(producer, numRecords, tp) + val consumer = createConsumer() + consumer.subscribe(List(topic).asJava) + consumeRecords(consumer, numRecords) + confirmReauthenticationMetrics() + } + + @Test + def testProduceConsumeWithPrefixedAcls(): Unit = { + setPrefixedResourceAcls() + val producer = createProducer() + sendRecords(producer, numRecords, tp) + val consumer = createConsumer() + consumer.subscribe(List(topic).asJava) + consumeRecords(consumer, numRecords) + confirmReauthenticationMetrics() + } + + @Test + def testProduceConsumeTopicAutoCreateTopicCreateAcl(): Unit = { + // topic2 is not created on setup() + val tp2 = new TopicPartition("topic2", 0) + setAclsAndProduce(tp2) + val consumer = createConsumer() + consumer.assign(List(tp2).asJava) + consumeRecords(consumer, numRecords, topic = tp2.topic) + confirmReauthenticationMetrics() + } + + private def setWildcardResourceAcls(): Unit = { + AclCommand.main(produceConsumeWildcardAclArgs) + servers.foreach { s => + TestUtils.waitAndVerifyAcls(TopicReadAcl ++ TopicWriteAcl ++ TopicDescribeAcl ++ TopicCreateAcl ++ TopicBrokerReadAcl, s.dataPlaneRequestProcessor.authorizer.get, wildcardTopicResource) + TestUtils.waitAndVerifyAcls(GroupReadAcl, s.dataPlaneRequestProcessor.authorizer.get, wildcardGroupResource) + } + } + + private def setPrefixedResourceAcls(): Unit = { + AclCommand.main(produceConsumePrefixedAclsArgs) + servers.foreach { s => + TestUtils.waitAndVerifyAcls(TopicReadAcl ++ TopicWriteAcl ++ TopicDescribeAcl ++ TopicCreateAcl, s.dataPlaneRequestProcessor.authorizer.get, prefixedTopicResource) + TestUtils.waitAndVerifyAcls(GroupReadAcl, s.dataPlaneRequestProcessor.authorizer.get, prefixedGroupResource) + } + } + + private def setReadAndWriteAcls(tp: TopicPartition): Unit = { + AclCommand.main(produceAclArgs(tp.topic)) + AclCommand.main(consumeAclArgs(tp.topic)) + servers.foreach { s => + TestUtils.waitAndVerifyAcls(TopicReadAcl ++ TopicWriteAcl ++ TopicDescribeAcl ++ TopicCreateAcl, 
s.dataPlaneRequestProcessor.authorizer.get, + new ResourcePattern(TOPIC, tp.topic, LITERAL)) + TestUtils.waitAndVerifyAcls(GroupReadAcl, s.dataPlaneRequestProcessor.authorizer.get, groupResource) + } + } + + protected def setAclsAndProduce(tp: TopicPartition): Unit = { + setReadAndWriteAcls(tp) + val producer = createProducer() + sendRecords(producer, numRecords, tp) + } + + private def setConsumerGroupAcls(): Unit = { + AclCommand.main(groupAclArgs) + servers.foreach { s => + TestUtils.waitAndVerifyAcls(GroupReadAcl, s.dataPlaneRequestProcessor.authorizer.get, groupResource) + } + } + + /** + * Tests that producer, consumer and adminClient fail to publish messages, consume + * messages and describe topics respectively when the describe ACL isn't set. + * Also verifies that subsequent publish, consume and describe to authorized topic succeeds. + */ + @Test + def testNoDescribeProduceOrConsumeWithoutTopicDescribeAcl(): Unit = { + // Set consumer group acls since we are testing topic authorization + setConsumerGroupAcls() + + // Verify produce/consume/describe throw TopicAuthorizationException + val producer = createProducer() + assertThrows(classOf[TopicAuthorizationException], () => sendRecords(producer, numRecords, tp)) + val consumer = createConsumer() + consumer.assign(List(tp).asJava) + assertThrows(classOf[TopicAuthorizationException], () => consumeRecords(consumer, numRecords, topic = tp.topic)) + val adminClient = createAdminClient() + val e1 = assertThrows(classOf[ExecutionException], () => adminClient.describeTopics(Set(topic).asJava).allTopicNames().get()) + assertTrue(e1.getCause.isInstanceOf[TopicAuthorizationException], "Unexpected exception " + e1.getCause) + + // Verify successful produce/consume/describe on another topic using the same producer, consumer and adminClient + val topic2 = "topic2" + val tp2 = new TopicPartition(topic2, 0) + setReadAndWriteAcls(tp2) + sendRecords(producer, numRecords, tp2) + consumer.assign(List(tp2).asJava) + consumeRecords(consumer, numRecords, topic = topic2) + val describeResults = adminClient.describeTopics(Set(topic, topic2).asJava).topicNameValues() + assertEquals(1, describeResults.get(topic2).get().partitions().size()) + val e2 = assertThrows(classOf[ExecutionException], () => adminClient.describeTopics(Set(topic).asJava).allTopicNames().get()) + assertTrue(e2.getCause.isInstanceOf[TopicAuthorizationException], "Unexpected exception " + e2.getCause) + + // Verify that consumer manually assigning both authorized and unauthorized topic doesn't consume + // from the unauthorized topic and throw; since we can now return data during the time we are updating + // metadata / fetching positions, it is possible that the authorized topic record is returned during this time. 
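+    // The topic2RecordConsumed flag below records whether the already-authorized topic2 record was
+    // returned before the authorization error, so the later consume of topic2 can be skipped in that case.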
+ consumer.assign(List(tp, tp2).asJava) + sendRecords(producer, numRecords, tp2) + var topic2RecordConsumed = false + def verifyNoRecords(records: ConsumerRecords[Array[Byte], Array[Byte]]): Boolean = { + assertEquals(Collections.singleton(tp2), records.partitions(), "Consumed records with unexpected partitions: " + records) + topic2RecordConsumed = true + false + } + assertThrows(classOf[TopicAuthorizationException], + () => TestUtils.pollRecordsUntilTrue(consumer, verifyNoRecords, "Consumer didn't fail with authorization exception within timeout")) + + // Add ACLs and verify successful produce/consume/describe on first topic + setReadAndWriteAcls(tp) + if (!topic2RecordConsumed) { + consumeRecordsIgnoreOneAuthorizationException(consumer, numRecords, startingOffset = 1, topic2) + } + sendRecords(producer, numRecords, tp) + consumeRecordsIgnoreOneAuthorizationException(consumer, numRecords, startingOffset = 0, topic) + val describeResults2 = adminClient.describeTopics(Set(topic, topic2).asJava).topicNameValues + assertEquals(1, describeResults2.get(topic).get().partitions().size()) + assertEquals(1, describeResults2.get(topic2).get().partitions().size()) + } + + @Test + def testNoProduceWithDescribeAcl(): Unit = { + AclCommand.main(describeAclArgs) + servers.foreach { s => + TestUtils.waitAndVerifyAcls(TopicDescribeAcl, s.dataPlaneRequestProcessor.authorizer.get, topicResource) + } + val producer = createProducer() + val e = assertThrows(classOf[TopicAuthorizationException], () => sendRecords(producer, numRecords, tp)) + assertEquals(Set(topic).asJava, e.unauthorizedTopics()) + confirmReauthenticationMetrics() + } + + /** + * Tests that a consumer fails to consume messages without the appropriate + * ACL set. + */ + @Test + def testNoConsumeWithoutDescribeAclViaAssign(): Unit = { + noConsumeWithoutDescribeAclSetup() + val consumer = createConsumer() + consumer.assign(List(tp).asJava) + // the exception is expected when the consumer attempts to lookup offsets + assertThrows(classOf[KafkaException], () => consumeRecords(consumer)) + confirmReauthenticationMetrics() + } + + @Test + def testNoConsumeWithoutDescribeAclViaSubscribe(): Unit = { + noConsumeWithoutDescribeAclSetup() + val consumer = createConsumer() + consumer.subscribe(List(topic).asJava) + // this should timeout since the consumer will not be able to fetch any metadata for the topic + assertThrows(classOf[TopicAuthorizationException], () => consumeRecords(consumer, timeout = 3000)) + + // Verify that no records are consumed even if one of the requested topics is authorized + setReadAndWriteAcls(tp) + consumer.subscribe(List(topic, "topic2").asJava) + assertThrows(classOf[TopicAuthorizationException], () => consumeRecords(consumer, timeout = 3000)) + + // Verify that records are consumed if all topics are authorized + consumer.subscribe(List(topic).asJava) + consumeRecordsIgnoreOneAuthorizationException(consumer) + } + + private def noConsumeWithoutDescribeAclSetup(): Unit = { + AclCommand.main(produceAclArgs(tp.topic)) + AclCommand.main(groupAclArgs) + servers.foreach { s => + TestUtils.waitAndVerifyAcls(TopicWriteAcl ++ TopicDescribeAcl ++ TopicCreateAcl, s.dataPlaneRequestProcessor.authorizer.get, topicResource) + TestUtils.waitAndVerifyAcls(GroupReadAcl, s.dataPlaneRequestProcessor.authorizer.get, groupResource) + } + + val producer = createProducer() + sendRecords(producer, numRecords, tp) + + AclCommand.main(deleteDescribeAclArgs) + AclCommand.main(deleteWriteAclArgs) + servers.foreach { s => + 
TestUtils.waitAndVerifyAcls(GroupReadAcl, s.dataPlaneRequestProcessor.authorizer.get, groupResource) + } + } + + @Test + def testNoConsumeWithDescribeAclViaAssign(): Unit = { + noConsumeWithDescribeAclSetup() + val consumer = createConsumer() + consumer.assign(List(tp).asJava) + + val e = assertThrows(classOf[TopicAuthorizationException], () => consumeRecords(consumer)) + assertEquals(Set(topic).asJava, e.unauthorizedTopics()) + confirmReauthenticationMetrics() + } + + @Test + def testNoConsumeWithDescribeAclViaSubscribe(): Unit = { + noConsumeWithDescribeAclSetup() + val consumer = createConsumer() + consumer.subscribe(List(topic).asJava) + + val e = assertThrows(classOf[TopicAuthorizationException], () => consumeRecords(consumer)) + assertEquals(Set(topic).asJava, e.unauthorizedTopics()) + confirmReauthenticationMetrics() + } + + private def noConsumeWithDescribeAclSetup(): Unit = { + AclCommand.main(produceAclArgs(tp.topic)) + AclCommand.main(groupAclArgs) + servers.foreach { s => + TestUtils.waitAndVerifyAcls(TopicWriteAcl ++ TopicDescribeAcl ++ TopicCreateAcl, s.dataPlaneRequestProcessor.authorizer.get, topicResource) + TestUtils.waitAndVerifyAcls(GroupReadAcl, s.dataPlaneRequestProcessor.authorizer.get, groupResource) + } + val producer = createProducer() + sendRecords(producer, numRecords, tp) + } + + /** + * Tests that a consumer fails to consume messages without the appropriate + * ACL set. + */ + @Test + def testNoGroupAcl(): Unit = { + AclCommand.main(produceAclArgs(tp.topic)) + servers.foreach { s => + TestUtils.waitAndVerifyAcls(TopicWriteAcl ++ TopicDescribeAcl ++ TopicCreateAcl, s.dataPlaneRequestProcessor.authorizer.get, topicResource) + } + val producer = createProducer() + sendRecords(producer, numRecords, tp) + + val consumer = createConsumer() + consumer.assign(List(tp).asJava) + val e = assertThrows(classOf[GroupAuthorizationException], () => consumeRecords(consumer)) + assertEquals(group, e.groupId()) + confirmReauthenticationMetrics() + } + + protected final def sendRecords(producer: KafkaProducer[Array[Byte], Array[Byte]], + numRecords: Int, tp: TopicPartition): Unit = { + val futures = (0 until numRecords).map { i => + val record = new ProducerRecord(tp.topic(), tp.partition(), s"$i".getBytes, s"$i".getBytes) + debug(s"Sending this record: $record") + producer.send(record) + } + try { + futures.foreach(_.get) + } catch { + case e: ExecutionException => throw e.getCause + } + } + + protected final def consumeRecords(consumer: Consumer[Array[Byte], Array[Byte]], + numRecords: Int = 1, + startingOffset: Int = 0, + topic: String = topic, + part: Int = part, + timeout: Long = 10000): Unit = { + val records = TestUtils.consumeRecords(consumer, numRecords, timeout) + + for (i <- 0 until numRecords) { + val record = records(i) + val offset = startingOffset + i + assertEquals(topic, record.topic) + assertEquals(part, record.partition) + assertEquals(offset.toLong, record.offset) + } + } + + protected def createScramAdminClient(scramMechanism: String, user: String, password: String): Admin = { + createAdminClient(brokerList, securityProtocol, trustStoreFile, clientSaslProperties, + scramMechanism, user, password) + } + + // Consume records, ignoring at most one TopicAuthorization exception from previously sent request + private def consumeRecordsIgnoreOneAuthorizationException(consumer: Consumer[Array[Byte], Array[Byte]], + numRecords: Int = 1, + startingOffset: Int = 0, + topic: String = topic): Unit = { + try { + consumeRecords(consumer, numRecords, startingOffset, topic) 
+ } catch { + case _: TopicAuthorizationException => consumeRecords(consumer, numRecords, startingOffset, topic) + } + } +} + diff --git a/core/src/test/scala/integration/kafka/api/EndToEndClusterIdTest.scala b/core/src/test/scala/integration/kafka/api/EndToEndClusterIdTest.scala new file mode 100644 index 0000000..7b05113 --- /dev/null +++ b/core/src/test/scala/integration/kafka/api/EndToEndClusterIdTest.scala @@ -0,0 +1,213 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.api + +import java.util.concurrent.ExecutionException +import java.util.concurrent.atomic.AtomicReference +import java.util.Properties +import kafka.integration.KafkaServerTestHarness +import kafka.server._ +import kafka.utils._ +import kafka.utils.Implicits._ +import org.apache.kafka.clients.consumer._ +import org.apache.kafka.clients.producer.{KafkaProducer, ProducerConfig, ProducerRecord} +import org.apache.kafka.common.{ClusterResource, ClusterResourceListener, TopicPartition} +import org.apache.kafka.test.{TestUtils => _, _} +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.{BeforeEach, Test, TestInfo} + +import scala.jdk.CollectionConverters._ +import org.apache.kafka.test.TestUtils.isValidClusterId + +/** The test cases here verify the following conditions. + * 1. The ProducerInterceptor receives the cluster id after the onSend() method is called and before onAcknowledgement() method is called. + * 2. The Serializer receives the cluster id before the serialize() method is called. + * 3. The producer MetricReporter receives the cluster id after send() method is called on KafkaProducer. + * 4. The ConsumerInterceptor receives the cluster id before the onConsume() method. + * 5. The Deserializer receives the cluster id before the deserialize() method is called. + * 6. The consumer MetricReporter receives the cluster id after poll() is called on KafkaConsumer. + * 7. The broker MetricReporter receives the cluster id after the broker startup is over. + * 8. The broker KafkaMetricReporter receives the cluster id after the broker startup is over. + * 9. All the components receive the same cluster id. 
+ */ + +object EndToEndClusterIdTest { + + object MockConsumerMetricsReporter { + val CLUSTER_META = new AtomicReference[ClusterResource] + } + + class MockConsumerMetricsReporter extends MockMetricsReporter with ClusterResourceListener { + + override def onUpdate(clusterMetadata: ClusterResource): Unit = { + MockConsumerMetricsReporter.CLUSTER_META.set(clusterMetadata) + } + } + + object MockProducerMetricsReporter { + val CLUSTER_META = new AtomicReference[ClusterResource] + } + + class MockProducerMetricsReporter extends MockMetricsReporter with ClusterResourceListener { + + override def onUpdate(clusterMetadata: ClusterResource): Unit = { + MockProducerMetricsReporter.CLUSTER_META.set(clusterMetadata) + } + } + + object MockBrokerMetricsReporter { + val CLUSTER_META = new AtomicReference[ClusterResource] + } + + class MockBrokerMetricsReporter extends MockMetricsReporter with ClusterResourceListener { + + override def onUpdate(clusterMetadata: ClusterResource): Unit = { + MockBrokerMetricsReporter.CLUSTER_META.set(clusterMetadata) + } + } +} + +class EndToEndClusterIdTest extends KafkaServerTestHarness { + + import EndToEndClusterIdTest._ + + val producerCount = 1 + val consumerCount = 1 + val serverCount = 1 + lazy val producerConfig = new Properties + lazy val consumerConfig = new Properties + lazy val serverConfig = new Properties + val numRecords = 1 + val topic = "e2etopic" + val part = 0 + val tp = new TopicPartition(topic, part) + this.serverConfig.setProperty(KafkaConfig.MetricReporterClassesProp, classOf[MockBrokerMetricsReporter].getName) + + override def generateConfigs = { + val cfgs = TestUtils.createBrokerConfigs(serverCount, zkConnect, interBrokerSecurityProtocol = Some(securityProtocol), + trustStoreFile = trustStoreFile, saslProperties = serverSaslProperties) + cfgs.foreach(_ ++= serverConfig) + cfgs.map(KafkaConfig.fromProps) + } + + @BeforeEach + override def setUp(testInfo: TestInfo): Unit = { + super.setUp(testInfo) + MockDeserializer.resetStaticVariables + // create the consumer offset topic + createTopic(topic, 2, serverCount) + } + + @Test + def testEndToEnd(): Unit = { + val appendStr = "mock" + MockConsumerInterceptor.resetCounters() + MockProducerInterceptor.resetCounters() + + assertNotNull(MockBrokerMetricsReporter.CLUSTER_META) + isValidClusterId(MockBrokerMetricsReporter.CLUSTER_META.get.clusterId) + + val producerProps = new Properties() + producerProps.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, brokerList) + producerProps.put(ProducerConfig.INTERCEPTOR_CLASSES_CONFIG, classOf[MockProducerInterceptor].getName) + producerProps.put("mock.interceptor.append", appendStr) + producerProps.put(ProducerConfig.METRIC_REPORTER_CLASSES_CONFIG, classOf[MockProducerMetricsReporter].getName) + val testProducer = new KafkaProducer(producerProps, new MockSerializer, new MockSerializer) + + // Send one record and make sure clusterId is set after send and before onAcknowledgement + sendRecords(testProducer, 1, tp) + assertNotEquals(MockProducerInterceptor.CLUSTER_ID_BEFORE_ON_ACKNOWLEDGEMENT, MockProducerInterceptor.NO_CLUSTER_ID) + assertNotNull(MockProducerInterceptor.CLUSTER_META) + assertEquals(MockProducerInterceptor.CLUSTER_ID_BEFORE_ON_ACKNOWLEDGEMENT.get.clusterId, MockProducerInterceptor.CLUSTER_META.get.clusterId) + isValidClusterId(MockProducerInterceptor.CLUSTER_META.get.clusterId) + + // Make sure that serializer gets the cluster id before serialize method. 
+ assertNotEquals(MockSerializer.CLUSTER_ID_BEFORE_SERIALIZE, MockSerializer.NO_CLUSTER_ID) + assertNotNull(MockSerializer.CLUSTER_META) + isValidClusterId(MockSerializer.CLUSTER_META.get.clusterId) + + assertNotNull(MockProducerMetricsReporter.CLUSTER_META) + isValidClusterId(MockProducerMetricsReporter.CLUSTER_META.get.clusterId) + + this.consumerConfig.put(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, brokerList) + this.consumerConfig.setProperty(ConsumerConfig.INTERCEPTOR_CLASSES_CONFIG, classOf[MockConsumerInterceptor].getName) + this.consumerConfig.put(ConsumerConfig.METRIC_REPORTER_CLASSES_CONFIG, classOf[MockConsumerMetricsReporter].getName) + val testConsumer = new KafkaConsumer(this.consumerConfig, new MockDeserializer, new MockDeserializer) + testConsumer.assign(List(tp).asJava) + testConsumer.seek(tp, 0) + + // consume and verify that values are modified by interceptors + consumeRecords(testConsumer, numRecords) + + // Check that cluster id is present after the first poll call. + assertNotEquals(MockConsumerInterceptor.CLUSTER_ID_BEFORE_ON_CONSUME, MockConsumerInterceptor.NO_CLUSTER_ID) + assertNotNull(MockConsumerInterceptor.CLUSTER_META) + isValidClusterId(MockConsumerInterceptor.CLUSTER_META.get.clusterId) + assertEquals(MockConsumerInterceptor.CLUSTER_ID_BEFORE_ON_CONSUME.get.clusterId, MockConsumerInterceptor.CLUSTER_META.get.clusterId) + + assertNotEquals(MockDeserializer.clusterIdBeforeDeserialize, MockDeserializer.noClusterId) + assertNotNull(MockDeserializer.clusterMeta) + isValidClusterId(MockDeserializer.clusterMeta.get.clusterId) + assertEquals(MockDeserializer.clusterIdBeforeDeserialize.get.clusterId, MockDeserializer.clusterMeta.get.clusterId) + + assertNotNull(MockConsumerMetricsReporter.CLUSTER_META) + isValidClusterId(MockConsumerMetricsReporter.CLUSTER_META.get.clusterId) + + // Make sure everyone receives the same cluster id. 
+ assertEquals(MockProducerInterceptor.CLUSTER_META.get.clusterId, MockSerializer.CLUSTER_META.get.clusterId) + assertEquals(MockProducerInterceptor.CLUSTER_META.get.clusterId, MockProducerMetricsReporter.CLUSTER_META.get.clusterId) + assertEquals(MockProducerInterceptor.CLUSTER_META.get.clusterId, MockConsumerInterceptor.CLUSTER_META.get.clusterId) + assertEquals(MockProducerInterceptor.CLUSTER_META.get.clusterId, MockDeserializer.clusterMeta.get.clusterId) + assertEquals(MockProducerInterceptor.CLUSTER_META.get.clusterId, MockConsumerMetricsReporter.CLUSTER_META.get.clusterId) + assertEquals(MockProducerInterceptor.CLUSTER_META.get.clusterId, MockBrokerMetricsReporter.CLUSTER_META.get.clusterId) + + testConsumer.close() + testProducer.close() + MockConsumerInterceptor.resetCounters() + MockProducerInterceptor.resetCounters() + } + + private def sendRecords(producer: KafkaProducer[Array[Byte], Array[Byte]], numRecords: Int, tp: TopicPartition): Unit = { + val futures = (0 until numRecords).map { i => + val record = new ProducerRecord(tp.topic(), tp.partition(), s"$i".getBytes, s"$i".getBytes) + debug(s"Sending this record: $record") + producer.send(record) + } + try { + futures.foreach(_.get) + } catch { + case e: ExecutionException => throw e.getCause + } + } + + private def consumeRecords(consumer: Consumer[Array[Byte], Array[Byte]], + numRecords: Int, + startingOffset: Int = 0, + topic: String = topic, + part: Int = part): Unit = { + val records = TestUtils.consumeRecords(consumer, numRecords) + + for (i <- 0 until numRecords) { + val record = records(i) + val offset = startingOffset + i + assertEquals(topic, record.topic) + assertEquals(part, record.partition) + assertEquals(offset.toLong, record.offset) + } + } +} diff --git a/core/src/test/scala/integration/kafka/api/FixedPortTestUtils.scala b/core/src/test/scala/integration/kafka/api/FixedPortTestUtils.scala new file mode 100644 index 0000000..bf5f8c1 --- /dev/null +++ b/core/src/test/scala/integration/kafka/api/FixedPortTestUtils.scala @@ -0,0 +1,51 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more contributor license agreements. See the NOTICE + * file distributed with this work for additional information regarding copyright ownership. The ASF licenses this file + * to You under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the + * License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package kafka.api + +import java.io.IOException +import java.net.ServerSocket +import java.util.Properties + +import kafka.utils.TestUtils + +/** + * DO NOT USE THESE UTILITIES UNLESS YOU ABSOLUTELY MUST + * + * These are utilities for selecting fixed (preselected), ephemeral ports to use with tests. This is not a reliable way + * of testing on most machines because you can easily run into port conflicts. If you're using this class, you're almost + * certainly doing something wrong unless you can prove that your test **cannot function** properly without it. 
+ */ +object FixedPortTestUtils { + def choosePorts(count: Int): Seq[Int] = { + try { + val sockets = (0 until count).map(_ => new ServerSocket(0)) + val ports = sockets.map(_.getLocalPort()) + sockets.foreach(_.close()) + ports + } catch { + case e: IOException => throw new RuntimeException(e) + } + } + + def createBrokerConfigs(numConfigs: Int, + zkConnect: String, + enableControlledShutdown: Boolean = true, + enableDeleteTopic: Boolean = false): Seq[Properties] = { + val ports = FixedPortTestUtils.choosePorts(numConfigs) + (0 until numConfigs).map { node => + TestUtils.createBrokerConfig(node, zkConnect, enableControlledShutdown, enableDeleteTopic, ports(node)) + } + } + +} diff --git a/core/src/test/scala/integration/kafka/api/GroupAuthorizerIntegrationTest.scala b/core/src/test/scala/integration/kafka/api/GroupAuthorizerIntegrationTest.scala new file mode 100644 index 0000000..cedfc17 --- /dev/null +++ b/core/src/test/scala/integration/kafka/api/GroupAuthorizerIntegrationTest.scala @@ -0,0 +1,137 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more contributor license agreements. See the NOTICE + * file distributed with this work for additional information regarding copyright ownership. The ASF licenses this file + * to You under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the + * License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. 
+ */ +package kafka.api + +import java.util.Properties +import java.util.concurrent.ExecutionException + +import kafka.api.GroupAuthorizerIntegrationTest._ +import kafka.security.authorizer.AclAuthorizer +import kafka.security.authorizer.AclEntry.WildcardHost +import kafka.server.{BaseRequestTest, KafkaConfig} +import kafka.utils.TestUtils +import org.apache.kafka.clients.consumer.ConsumerConfig +import org.apache.kafka.clients.producer.ProducerRecord +import org.apache.kafka.common.TopicPartition +import org.apache.kafka.common.acl.{AccessControlEntry, AclOperation, AclPermissionType} +import org.apache.kafka.common.config.internals.BrokerSecurityConfigs +import org.apache.kafka.common.errors.TopicAuthorizationException +import org.apache.kafka.common.network.ListenerName +import org.apache.kafka.common.resource.{PatternType, Resource, ResourcePattern, ResourceType} +import org.apache.kafka.common.security.auth.{AuthenticationContext, KafkaPrincipal} +import org.apache.kafka.common.security.authenticator.DefaultKafkaPrincipalBuilder +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.{BeforeEach, Test, TestInfo} + +import scala.jdk.CollectionConverters._ + +object GroupAuthorizerIntegrationTest { + val BrokerPrincipal = new KafkaPrincipal("Group", "broker") + val ClientPrincipal = new KafkaPrincipal("Group", "client") + + val BrokerListenerName = "BROKER" + val ClientListenerName = "CLIENT" + + class GroupPrincipalBuilder extends DefaultKafkaPrincipalBuilder(null, null) { + override def build(context: AuthenticationContext): KafkaPrincipal = { + context.listenerName match { + case BrokerListenerName => BrokerPrincipal + case ClientListenerName => ClientPrincipal + case listenerName => throw new IllegalArgumentException(s"No principal mapped to listener $listenerName") + } + } + } +} + +class GroupAuthorizerIntegrationTest extends BaseRequestTest { + + val brokerId: Integer = 0 + + override def brokerCount: Int = 1 + override def interBrokerListenerName: ListenerName = new ListenerName(BrokerListenerName) + override def listenerName: ListenerName = new ListenerName(ClientListenerName) + + def brokerPrincipal: KafkaPrincipal = BrokerPrincipal + def clientPrincipal: KafkaPrincipal = ClientPrincipal + + override def brokerPropertyOverrides(properties: Properties): Unit = { + properties.put(KafkaConfig.AuthorizerClassNameProp, classOf[AclAuthorizer].getName) + properties.put(KafkaConfig.BrokerIdProp, brokerId.toString) + properties.put(KafkaConfig.OffsetsTopicPartitionsProp, "1") + properties.put(KafkaConfig.OffsetsTopicReplicationFactorProp, "1") + properties.put(KafkaConfig.TransactionsTopicPartitionsProp, "1") + properties.put(KafkaConfig.TransactionsTopicReplicationFactorProp, "1") + properties.put(KafkaConfig.TransactionsTopicMinISRProp, "1") + properties.put(BrokerSecurityConfigs.PRINCIPAL_BUILDER_CLASS_CONFIG, classOf[GroupPrincipalBuilder].getName) + } + + @BeforeEach + override def setUp(testInfo: TestInfo): Unit = { + doSetup(testInfo, createOffsetsTopic = false) + + // Allow inter-broker communication + TestUtils.addAndVerifyAcls(servers.head, + Set(createAcl(AclOperation.CLUSTER_ACTION, AclPermissionType.ALLOW, principal = BrokerPrincipal)), + new ResourcePattern(ResourceType.CLUSTER, Resource.CLUSTER_NAME, PatternType.LITERAL)) + + TestUtils.createOffsetsTopic(zkClient, servers) + } + + private def createAcl(aclOperation: AclOperation, + aclPermissionType: AclPermissionType, + principal: KafkaPrincipal = ClientPrincipal): AccessControlEntry = { + new 
AccessControlEntry(principal.toString, WildcardHost, aclOperation, aclPermissionType) + } + + @Test + def testUnauthorizedProduceAndConsume(): Unit = { + val topic = "topic" + val topicPartition = new TopicPartition("topic", 0) + + createTopic(topic) + + val producer = createProducer() + val produceException = assertThrows(classOf[ExecutionException], + () => producer.send(new ProducerRecord[Array[Byte], Array[Byte]](topic, "message".getBytes)).get()).getCause + assertTrue(produceException.isInstanceOf[TopicAuthorizationException]) + assertEquals(Set(topic), produceException.asInstanceOf[TopicAuthorizationException].unauthorizedTopics.asScala) + + val consumer = createConsumer(configsToRemove = List(ConsumerConfig.GROUP_ID_CONFIG)) + consumer.assign(List(topicPartition).asJava) + val consumeException = assertThrows(classOf[TopicAuthorizationException], + () => TestUtils.pollUntilAtLeastNumRecords(consumer, numRecords = 1)) + assertEquals(Set(topic), consumeException.unauthorizedTopics.asScala) + } + + @Test + def testAuthorizedProduceAndConsume(): Unit = { + val topic = "topic" + val topicPartition = new TopicPartition("topic", 0) + + createTopic(topic) + + TestUtils.addAndVerifyAcls(servers.head, + Set(createAcl(AclOperation.WRITE, AclPermissionType.ALLOW)), + new ResourcePattern(ResourceType.TOPIC, topic, PatternType.LITERAL)) + val producer = createProducer() + producer.send(new ProducerRecord[Array[Byte], Array[Byte]](topic, "message".getBytes)).get() + + TestUtils.addAndVerifyAcls(servers.head, + Set(createAcl(AclOperation.READ, AclPermissionType.ALLOW)), + new ResourcePattern(ResourceType.TOPIC, topic, PatternType.LITERAL)) + val consumer = createConsumer(configsToRemove = List(ConsumerConfig.GROUP_ID_CONFIG)) + consumer.assign(List(topicPartition).asJava) + TestUtils.pollUntilAtLeastNumRecords(consumer, numRecords = 1) + } + +} diff --git a/core/src/test/scala/integration/kafka/api/GroupCoordinatorIntegrationTest.scala b/core/src/test/scala/integration/kafka/api/GroupCoordinatorIntegrationTest.scala new file mode 100644 index 0000000..a6b59f0 --- /dev/null +++ b/core/src/test/scala/integration/kafka/api/GroupCoordinatorIntegrationTest.scala @@ -0,0 +1,62 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more contributor license agreements. See the NOTICE + * file distributed with this work for additional information regarding copyright ownership. The ASF licenses this file + * to You under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the + * License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. 
+ */ +package kafka.api + +import kafka.integration.KafkaServerTestHarness +import kafka.log.UnifiedLog +import kafka.server.KafkaConfig +import kafka.utils.TestUtils +import org.apache.kafka.clients.consumer.OffsetAndMetadata +import org.apache.kafka.common.TopicPartition +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.Assertions._ + +import scala.jdk.CollectionConverters._ +import java.util.Properties + +import org.apache.kafka.common.internals.Topic +import org.apache.kafka.common.record.CompressionType + +class GroupCoordinatorIntegrationTest extends KafkaServerTestHarness { + val offsetsTopicCompressionCodec = CompressionType.GZIP + val overridingProps = new Properties() + overridingProps.put(KafkaConfig.OffsetsTopicPartitionsProp, "1") + overridingProps.put(KafkaConfig.OffsetsTopicCompressionCodecProp, offsetsTopicCompressionCodec.id.toString) + + override def generateConfigs = TestUtils.createBrokerConfigs(1, zkConnect, enableControlledShutdown = false).map { + KafkaConfig.fromProps(_, overridingProps) + } + + @Test + def testGroupCoordinatorPropagatesOffsetsTopicCompressionCodec(): Unit = { + val consumer = TestUtils.createConsumer(TestUtils.getBrokerListStrFromServers(servers)) + val offsetMap = Map( + new TopicPartition(Topic.GROUP_METADATA_TOPIC_NAME, 0) -> new OffsetAndMetadata(10, "") + ).asJava + consumer.commitSync(offsetMap) + val logManager = servers.head.getLogManager + def getGroupMetadataLogOpt: Option[UnifiedLog] = + logManager.getLog(new TopicPartition(Topic.GROUP_METADATA_TOPIC_NAME, 0)) + + TestUtils.waitUntilTrue(() => getGroupMetadataLogOpt.exists(_.logSegments.exists(_.log.batches.asScala.nonEmpty)), + "Commit message not appended in time") + + val logSegments = getGroupMetadataLogOpt.get.logSegments + val incorrectCompressionCodecs = logSegments + .flatMap(_.log.batches.asScala.map(_.compressionType)) + .filter(_ != offsetsTopicCompressionCodec) + assertEquals(Seq.empty, incorrectCompressionCodecs, "Incorrect compression codecs should be empty") + + consumer.close() + } +} diff --git a/core/src/test/scala/integration/kafka/api/GroupEndToEndAuthorizationTest.scala b/core/src/test/scala/integration/kafka/api/GroupEndToEndAuthorizationTest.scala new file mode 100644 index 0000000..8bd6393 --- /dev/null +++ b/core/src/test/scala/integration/kafka/api/GroupEndToEndAuthorizationTest.scala @@ -0,0 +1,47 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package kafka.api + +import kafka.api.GroupEndToEndAuthorizationTest._ +import kafka.utils.JaasTestUtils +import org.apache.kafka.common.config.internals.BrokerSecurityConfigs +import org.apache.kafka.common.security.auth.{AuthenticationContext, KafkaPrincipal, SaslAuthenticationContext} +import org.apache.kafka.common.security.authenticator.DefaultKafkaPrincipalBuilder + +object GroupEndToEndAuthorizationTest { + val GroupPrincipalType = "Group" + val ClientGroup = "testGroup" + class GroupPrincipalBuilder extends DefaultKafkaPrincipalBuilder(null, null) { + override def build(context: AuthenticationContext): KafkaPrincipal = { + context match { + case ctx: SaslAuthenticationContext => + if (ctx.server.getAuthorizationID == JaasTestUtils.KafkaScramUser) + new KafkaPrincipal(GroupPrincipalType, ClientGroup) + else + new KafkaPrincipal(GroupPrincipalType, ctx.server.getAuthorizationID) + case _ => + KafkaPrincipal.ANONYMOUS + } + } + } +} + +class GroupEndToEndAuthorizationTest extends SaslScramSslEndToEndAuthorizationTest { + override val clientPrincipal = new KafkaPrincipal(GroupPrincipalType, ClientGroup) + override val kafkaPrincipal = new KafkaPrincipal(GroupPrincipalType, JaasTestUtils.KafkaScramAdmin) + this.serverConfig.setProperty(BrokerSecurityConfigs.PRINCIPAL_BUILDER_CLASS_CONFIG, classOf[GroupPrincipalBuilder].getName) +} diff --git a/core/src/test/scala/integration/kafka/api/IntegrationTestHarness.scala b/core/src/test/scala/integration/kafka/api/IntegrationTestHarness.scala new file mode 100644 index 0000000..0f987e9 --- /dev/null +++ b/core/src/test/scala/integration/kafka/api/IntegrationTestHarness.scala @@ -0,0 +1,171 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package kafka.api + +import java.time.Duration + +import org.apache.kafka.clients.consumer.{ConsumerConfig, KafkaConsumer} +import kafka.utils.TestUtils +import kafka.utils.Implicits._ +import java.util.Properties + +import org.apache.kafka.clients.producer.{KafkaProducer, ProducerConfig} +import kafka.server.KafkaConfig +import kafka.integration.KafkaServerTestHarness +import org.apache.kafka.clients.admin.{Admin, AdminClientConfig} +import org.apache.kafka.common.network.{ListenerName, Mode} +import org.apache.kafka.common.serialization.{ByteArrayDeserializer, ByteArraySerializer, Deserializer, Serializer} +import org.junit.jupiter.api.{AfterEach, BeforeEach, TestInfo} + +import scala.collection.mutable +import scala.collection.Seq + +/** + * A helper class for writing integration tests that involve producers, consumers, and servers + */ +abstract class IntegrationTestHarness extends KafkaServerTestHarness { + protected def brokerCount: Int + protected def logDirCount: Int = 1 + + val producerConfig = new Properties + val consumerConfig = new Properties + val adminClientConfig = new Properties + val serverConfig = new Properties + + private val consumers = mutable.Buffer[KafkaConsumer[_, _]]() + private val producers = mutable.Buffer[KafkaProducer[_, _]]() + private val adminClients = mutable.Buffer[Admin]() + + protected def interBrokerListenerName: ListenerName = listenerName + + protected def modifyConfigs(props: Seq[Properties]): Unit = { + props.foreach(_ ++= serverConfig) + } + + override def generateConfigs: Seq[KafkaConfig] = { + val cfgs = TestUtils.createBrokerConfigs(brokerCount, zkConnectOrNull, interBrokerSecurityProtocol = Some(securityProtocol), + trustStoreFile = trustStoreFile, saslProperties = serverSaslProperties, logDirCount = logDirCount) + configureListeners(cfgs) + modifyConfigs(cfgs) + cfgs.map(KafkaConfig.fromProps) + } + + protected def configureListeners(props: Seq[Properties]): Unit = { + props.foreach { config => + config.remove(KafkaConfig.InterBrokerSecurityProtocolProp) + config.setProperty(KafkaConfig.InterBrokerListenerNameProp, interBrokerListenerName.value) + + val listenerNames = Set(listenerName, interBrokerListenerName) + val listeners = listenerNames.map(listenerName => s"${listenerName.value}://localhost:${TestUtils.RandomPort}").mkString(",") + val listenerSecurityMap = listenerNames.map(listenerName => s"${listenerName.value}:${securityProtocol.name}").mkString(",") + + config.setProperty(KafkaConfig.ListenersProp, listeners) + config.setProperty(KafkaConfig.AdvertisedListenersProp, listeners) + config.setProperty(KafkaConfig.ListenerSecurityProtocolMapProp, listenerSecurityMap) + } + } + + @BeforeEach + override def setUp(testInfo: TestInfo): Unit = { + doSetup(testInfo, createOffsetsTopic = true) + } + + def doSetup(testInfo: TestInfo, + createOffsetsTopic: Boolean): Unit = { + // Generate client security properties before starting the brokers in case certs are needed + producerConfig ++= clientSecurityProps("producer") + consumerConfig ++= clientSecurityProps("consumer") + adminClientConfig ++= clientSecurityProps("adminClient") + + super.setUp(testInfo) + + producerConfig.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, brokerList) + producerConfig.putIfAbsent(ProducerConfig.ACKS_CONFIG, "-1") + producerConfig.putIfAbsent(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG, classOf[ByteArraySerializer].getName) + producerConfig.putIfAbsent(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG, classOf[ByteArraySerializer].getName) + + 
consumerConfig.put(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, brokerList) + consumerConfig.putIfAbsent(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "earliest") + consumerConfig.putIfAbsent(ConsumerConfig.GROUP_ID_CONFIG, "group") + consumerConfig.putIfAbsent(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, classOf[ByteArrayDeserializer].getName) + consumerConfig.putIfAbsent(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, classOf[ByteArrayDeserializer].getName) + + adminClientConfig.put(AdminClientConfig.BOOTSTRAP_SERVERS_CONFIG, brokerList) + + if (createOffsetsTopic) { + if (isKRaftTest()) { + TestUtils.createOffsetsTopicWithAdmin(brokers, adminClientConfig) + } else { + TestUtils.createOffsetsTopic(zkClient, servers) + } + } + } + + def clientSecurityProps(certAlias: String): Properties = { + TestUtils.securityConfigs(Mode.CLIENT, securityProtocol, trustStoreFile, certAlias, TestUtils.SslCertificateCn, + clientSaslProperties) + } + + def createProducer[K, V](keySerializer: Serializer[K] = new ByteArraySerializer, + valueSerializer: Serializer[V] = new ByteArraySerializer, + configOverrides: Properties = new Properties): KafkaProducer[K, V] = { + val props = new Properties + props ++= producerConfig + props ++= configOverrides + val producer = new KafkaProducer[K, V](props, keySerializer, valueSerializer) + producers += producer + producer + } + + def createConsumer[K, V](keyDeserializer: Deserializer[K] = new ByteArrayDeserializer, + valueDeserializer: Deserializer[V] = new ByteArrayDeserializer, + configOverrides: Properties = new Properties, + configsToRemove: List[String] = List()): KafkaConsumer[K, V] = { + val props = new Properties + props ++= consumerConfig + props ++= configOverrides + configsToRemove.foreach(props.remove(_)) + val consumer = new KafkaConsumer[K, V](props, keyDeserializer, valueDeserializer) + consumers += consumer + consumer + } + + def createAdminClient(configOverrides: Properties = new Properties): Admin = { + val props = new Properties + props ++= adminClientConfig + props ++= configOverrides + val adminClient = Admin.create(props) + adminClients += adminClient + adminClient + } + + @AfterEach + override def tearDown(): Unit = { + producers.foreach(_.close(Duration.ZERO)) + consumers.foreach(_.wakeup()) + consumers.foreach(_.close(Duration.ZERO)) + adminClients.foreach(_.close(Duration.ZERO)) + + producers.clear() + consumers.clear() + adminClients.clear() + + super.tearDown() + } + +} diff --git a/core/src/test/scala/integration/kafka/api/LogAppendTimeTest.scala b/core/src/test/scala/integration/kafka/api/LogAppendTimeTest.scala new file mode 100644 index 0000000..6f397d8 --- /dev/null +++ b/core/src/test/scala/integration/kafka/api/LogAppendTimeTest.scala @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package kafka.api + +import java.util.Collections +import java.util.concurrent.TimeUnit +import kafka.server.KafkaConfig +import kafka.utils.TestUtils +import org.apache.kafka.clients.producer.ProducerRecord +import org.apache.kafka.common.record.TimestampType +import org.junit.jupiter.api.{BeforeEach, Test, TestInfo} +import org.junit.jupiter.api.Assertions.{assertEquals, assertNotEquals, assertTrue} + +/** + * Tests where the broker is configured to use LogAppendTime. For tests where LogAppendTime is configured via topic + * level configs, see the *ProducerSendTest classes. + */ +class LogAppendTimeTest extends IntegrationTestHarness { + val producerCount: Int = 1 + val consumerCount: Int = 1 + val brokerCount: Int = 2 + + // This will be used for the offsets topic as well + serverConfig.put(KafkaConfig.LogMessageTimestampTypeProp, TimestampType.LOG_APPEND_TIME.name) + serverConfig.put(KafkaConfig.OffsetsTopicReplicationFactorProp, "2") + + private val topic = "topic" + + @BeforeEach + override def setUp(testInfo: TestInfo): Unit = { + super.setUp(testInfo) + createTopic(topic) + } + + @Test + def testProduceConsume(): Unit = { + val producer = createProducer() + val now = System.currentTimeMillis() + val createTime = now - TimeUnit.DAYS.toMillis(1) + val producerRecords = (1 to 10).map(i => new ProducerRecord(topic, null, createTime, s"key$i".getBytes, + s"value$i".getBytes)) + val recordMetadatas = producerRecords.map(producer.send).map(_.get(10, TimeUnit.SECONDS)) + recordMetadatas.foreach { recordMetadata => + assertTrue(recordMetadata.timestamp >= now) + assertTrue(recordMetadata.timestamp < now + TimeUnit.SECONDS.toMillis(60)) + } + + val consumer = createConsumer() + consumer.subscribe(Collections.singleton(topic)) + val consumerRecords = TestUtils.consumeRecords(consumer, producerRecords.size) + + consumerRecords.zipWithIndex.foreach { case (consumerRecord, index) => + val producerRecord = producerRecords(index) + val recordMetadata = recordMetadatas(index) + assertEquals(new String(producerRecord.key), new String(consumerRecord.key)) + assertEquals(new String(producerRecord.value), new String(consumerRecord.value)) + assertNotEquals(producerRecord.timestamp, consumerRecord.timestamp) + assertEquals(recordMetadata.timestamp, consumerRecord.timestamp) + assertEquals(TimestampType.LOG_APPEND_TIME, consumerRecord.timestampType) + } + } + +} diff --git a/core/src/test/scala/integration/kafka/api/MetricsTest.scala b/core/src/test/scala/integration/kafka/api/MetricsTest.scala new file mode 100644 index 0000000..850ac89 --- /dev/null +++ b/core/src/test/scala/integration/kafka/api/MetricsTest.scala @@ -0,0 +1,309 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more contributor license agreements. See the NOTICE + * file distributed with this work for additional information regarding copyright ownership. The ASF licenses this file + * to You under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the + * License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. 
+ */ +package kafka.api + +import java.util.{Locale, Properties} +import kafka.log.LogConfig +import kafka.server.{KafkaConfig, KafkaServer} +import kafka.utils.{JaasTestUtils, TestUtils} +import com.yammer.metrics.core.{Gauge, Histogram, Meter} +import kafka.metrics.KafkaYammerMetrics +import org.apache.kafka.clients.consumer.KafkaConsumer +import org.apache.kafka.clients.producer.{KafkaProducer, ProducerConfig, ProducerRecord} +import org.apache.kafka.common.{Metric, MetricName, TopicPartition} +import org.apache.kafka.common.config.SaslConfigs +import org.apache.kafka.common.errors.{InvalidTopicException, UnknownTopicOrPartitionException} +import org.apache.kafka.common.network.ListenerName +import org.apache.kafka.common.security.auth.SecurityProtocol +import org.apache.kafka.common.security.authenticator.TestJaasConfig +import org.junit.jupiter.api.{AfterEach, BeforeEach, Test, TestInfo} +import org.junit.jupiter.api.Assertions._ + +import scala.annotation.nowarn +import scala.jdk.CollectionConverters._ + +class MetricsTest extends IntegrationTestHarness with SaslSetup { + + override val brokerCount = 1 + + override protected def listenerName = new ListenerName("CLIENT") + private val kafkaClientSaslMechanism = "PLAIN" + private val kafkaServerSaslMechanisms = List(kafkaClientSaslMechanism) + private val kafkaServerJaasEntryName = + s"${listenerName.value.toLowerCase(Locale.ROOT)}.${JaasTestUtils.KafkaServerContextName}" + this.serverConfig.setProperty(KafkaConfig.ZkEnableSecureAclsProp, "false") + this.serverConfig.setProperty(KafkaConfig.AutoCreateTopicsEnableProp, "false") + this.serverConfig.setProperty(KafkaConfig.InterBrokerProtocolVersionProp, "2.8") + this.producerConfig.setProperty(ProducerConfig.LINGER_MS_CONFIG, "10") + // intentionally slow message down conversion via gzip compression to ensure we can measure the time it takes + this.producerConfig.setProperty(ProducerConfig.COMPRESSION_TYPE_CONFIG, "gzip") + override protected def securityProtocol = SecurityProtocol.SASL_PLAINTEXT + override protected val serverSaslProperties = + Some(kafkaServerSaslProperties(kafkaServerSaslMechanisms, kafkaClientSaslMechanism)) + override protected val clientSaslProperties = + Some(kafkaClientSaslProperties(kafkaClientSaslMechanism)) + + @BeforeEach + override def setUp(testInfo: TestInfo): Unit = { + verifyNoRequestMetrics("Request metrics not removed in a previous test") + startSasl(jaasSections(kafkaServerSaslMechanisms, Some(kafkaClientSaslMechanism), KafkaSasl, kafkaServerJaasEntryName)) + super.setUp(testInfo) + } + + @AfterEach + override def tearDown(): Unit = { + super.tearDown() + closeSasl() + verifyNoRequestMetrics("Request metrics not removed in this test") + } + + /** + * Verifies some of the metrics of producer, consumer as well as server. 
+ */ + @nowarn("cat=deprecation") + @Test + def testMetrics(): Unit = { + val topic = "topicWithOldMessageFormat" + val props = new Properties + props.setProperty(LogConfig.MessageFormatVersionProp, "0.9.0") + createTopic(topic, numPartitions = 1, replicationFactor = 1, props) + val tp = new TopicPartition(topic, 0) + + // Produce and consume some records + val numRecords = 10 + val recordSize = 100000 + val producer = createProducer() + sendRecords(producer, numRecords, recordSize, tp) + + val consumer = createConsumer() + consumer.assign(List(tp).asJava) + consumer.seek(tp, 0) + TestUtils.consumeRecords(consumer, numRecords) + + verifyKafkaRateMetricsHaveCumulativeCount(producer, consumer) + verifyClientVersionMetrics(consumer.metrics, "Consumer") + verifyClientVersionMetrics(producer.metrics, "Producer") + + val server = servers.head + verifyBrokerMessageConversionMetrics(server, recordSize, tp) + verifyBrokerErrorMetrics(servers.head) + verifyBrokerZkMetrics(server, topic) + + generateAuthenticationFailure(tp) + verifyBrokerAuthenticationMetrics(server) + } + + private def sendRecords(producer: KafkaProducer[Array[Byte], Array[Byte]], numRecords: Int, + recordSize: Int, tp: TopicPartition) = { + val bytes = new Array[Byte](recordSize) + (0 until numRecords).map { i => + producer.send(new ProducerRecord(tp.topic, tp.partition, i.toLong, s"key $i".getBytes, bytes)) + } + producer.flush() + } + + // Create a producer that fails authentication to verify authentication failure metrics + private def generateAuthenticationFailure(tp: TopicPartition): Unit = { + val saslProps = new Properties() + saslProps.put(SaslConfigs.SASL_MECHANISM, kafkaClientSaslMechanism) + saslProps.put(SaslConfigs.SASL_JAAS_CONFIG, TestJaasConfig.jaasConfigProperty(kafkaClientSaslMechanism, "badUser", "badPass")) + // Use acks=0 to verify error metric when connection is closed without a response + val producer = TestUtils.createProducer(brokerList, + acks = 0, + requestTimeoutMs = 1000, + maxBlockMs = 1000, + securityProtocol = securityProtocol, + trustStoreFile = trustStoreFile, + saslProperties = Some(saslProps)) + + try { + producer.send(new ProducerRecord(tp.topic, tp.partition, "key".getBytes, "value".getBytes)).get + } catch { + case _: Exception => // expected exception + } finally { + producer.close() + } + } + + private def verifyKafkaRateMetricsHaveCumulativeCount(producer: KafkaProducer[Array[Byte], Array[Byte]], + consumer: KafkaConsumer[Array[Byte], Array[Byte]]): Unit = { + + def exists(name: String, rateMetricName: MetricName, allMetricNames: Set[MetricName]): Boolean = { + allMetricNames.contains(new MetricName(name, rateMetricName.group, "", rateMetricName.tags)) + } + + def verify(rateMetricName: MetricName, allMetricNames: Set[MetricName]): Unit = { + val name = rateMetricName.name + val totalExists = exists(name.replace("-rate", "-total"), rateMetricName, allMetricNames) + val totalTimeExists = exists(name.replace("-rate", "-time"), rateMetricName, allMetricNames) + assertTrue(totalExists || totalTimeExists, s"No cumulative count/time metric for rate metric $rateMetricName") + } + + val consumerMetricNames = consumer.metrics.keySet.asScala.toSet + consumerMetricNames.filter(_.name.endsWith("-rate")) + .foreach(verify(_, consumerMetricNames)) + + val producerMetricNames = producer.metrics.keySet.asScala.toSet + val producerExclusions = Set("compression-rate") // compression-rate is an Average metric, not Rate + producerMetricNames.filter(_.name.endsWith("-rate")) + .filterNot(metricName => 
producerExclusions.contains(metricName.name)) + .foreach(verify(_, producerMetricNames)) + + // Check a couple of metrics of consumer and producer to ensure that values are set + verifyKafkaMetricRecorded("records-consumed-rate", consumer.metrics, "Consumer") + verifyKafkaMetricRecorded("records-consumed-total", consumer.metrics, "Consumer") + verifyKafkaMetricRecorded("record-send-rate", producer.metrics, "Producer") + verifyKafkaMetricRecorded("record-send-total", producer.metrics, "Producer") + } + + private def verifyClientVersionMetrics(metrics: java.util.Map[MetricName, _ <: Metric], entity: String): Unit = { + Seq("commit-id", "version").foreach { name => + verifyKafkaMetric(name, metrics, entity) { matchingMetrics => + assertEquals(1, matchingMetrics.size) + val metric = matchingMetrics.head + val value = metric.metricValue + assertNotNull(value, s"$entity metric not recorded $name") + assertTrue(value.isInstanceOf[String] && value.asInstanceOf[String].nonEmpty, + s"$entity metric $name should be a non-empty String") + assertTrue(metric.metricName.tags.containsKey("client-id"), "Client-id not specified") + } + } + } + + private def verifyBrokerAuthenticationMetrics(server: KafkaServer): Unit = { + val metrics = server.metrics.metrics + TestUtils.waitUntilTrue(() => + maxKafkaMetricValue("failed-authentication-total", metrics, "Broker", Some("socket-server-metrics")) > 0, + "failed-authentication-total not updated") + verifyKafkaMetricRecorded("successful-authentication-rate", metrics, "Broker", Some("socket-server-metrics")) + verifyKafkaMetricRecorded("successful-authentication-total", metrics, "Broker", Some("socket-server-metrics")) + verifyKafkaMetricRecorded("failed-authentication-rate", metrics, "Broker", Some("socket-server-metrics")) + verifyKafkaMetricRecorded("failed-authentication-total", metrics, "Broker", Some("socket-server-metrics")) + } + + private def verifyBrokerMessageConversionMetrics(server: KafkaServer, recordSize: Int, tp: TopicPartition): Unit = { + val requestMetricsPrefix = "kafka.network:type=RequestMetrics" + val requestBytes = verifyYammerMetricRecorded(s"$requestMetricsPrefix,name=RequestBytes,request=Produce") + val tempBytes = verifyYammerMetricRecorded(s"$requestMetricsPrefix,name=TemporaryMemoryBytes,request=Produce") + assertTrue(tempBytes >= recordSize, s"Unexpected temporary memory size requestBytes $requestBytes tempBytes $tempBytes") + + verifyYammerMetricRecorded(s"kafka.server:type=BrokerTopicMetrics,name=ProduceMessageConversionsPerSec") + verifyYammerMetricRecorded(s"$requestMetricsPrefix,name=MessageConversionsTimeMs,request=Produce", value => value > 0.0) + verifyYammerMetricRecorded(s"$requestMetricsPrefix,name=RequestBytes,request=Fetch") + verifyYammerMetricRecorded(s"$requestMetricsPrefix,name=TemporaryMemoryBytes,request=Fetch", value => value == 0.0) + + // request size recorded for all request types, check one + verifyYammerMetricRecorded(s"$requestMetricsPrefix,name=RequestBytes,request=Metadata") + } + + private def verifyBrokerZkMetrics(server: KafkaServer, topic: String): Unit = { + val histogram = yammerHistogram("kafka.server:type=ZooKeeperClientMetrics,name=ZooKeeperRequestLatencyMs") + // Latency is rounded to milliseconds, so check the count instead + val initialCount = histogram.count + servers.head.zkClient.getLeaderForPartition(new TopicPartition(topic, 0)) + val newCount = histogram.count + assertTrue(newCount > initialCount, "ZooKeeper latency not recorded") + + val min = histogram.min + assertTrue(min >= 0, s"Min
latency should not be negative: $min") + + assertEquals("CONNECTED", yammerMetricValue("SessionState"), s"Unexpected ZK state") + } + + private def verifyBrokerErrorMetrics(server: KafkaServer): Unit = { + + def errorMetricCount = KafkaYammerMetrics.defaultRegistry.allMetrics.keySet.asScala.filter(_.getName == "ErrorsPerSec").size + + val startErrorMetricCount = errorMetricCount + val errorMetricPrefix = "kafka.network:type=RequestMetrics,name=ErrorsPerSec" + verifyYammerMetricRecorded(s"$errorMetricPrefix,request=Metadata,error=NONE") + + val consumer = createConsumer() + try { + consumer.partitionsFor("12{}!") + } catch { + case _: InvalidTopicException => // expected + } + verifyYammerMetricRecorded(s"$errorMetricPrefix,request=Metadata,error=INVALID_TOPIC_EXCEPTION") + + // Check that error metrics are registered dynamically + val currentErrorMetricCount = errorMetricCount + assertEquals(startErrorMetricCount + 1, currentErrorMetricCount) + assertTrue(currentErrorMetricCount < 10, s"Too many error metrics $currentErrorMetricCount") + + try { + consumer.partitionsFor("non-existing-topic") + } catch { + case _: UnknownTopicOrPartitionException => // expected + } + verifyYammerMetricRecorded(s"$errorMetricPrefix,request=Metadata,error=UNKNOWN_TOPIC_OR_PARTITION") + } + + private def verifyKafkaMetric[T](name: String, metrics: java.util.Map[MetricName, _ <: Metric], entity: String, + group: Option[String] = None)(verify: Iterable[Metric] => T) : T = { + val matchingMetrics = metrics.asScala.filter { + case (metricName, _) => metricName.name == name && group.forall(_ == metricName.group) + } + assertTrue(matchingMetrics.nonEmpty, s"Metric not found $name") + verify(matchingMetrics.values) + } + + private def maxKafkaMetricValue(name: String, metrics: java.util.Map[MetricName, _ <: Metric], entity: String, + group: Option[String]): Double = { + // Use max value of all matching metrics since Selector metrics are recorded for each Processor + verifyKafkaMetric(name, metrics, entity, group) { matchingMetrics => + matchingMetrics.foldLeft(0.0)((max, metric) => Math.max(max, metric.metricValue.asInstanceOf[Double])) + } + } + + private def verifyKafkaMetricRecorded(name: String, metrics: java.util.Map[MetricName, _ <: Metric], entity: String, + group: Option[String] = None): Unit = { + val value = maxKafkaMetricValue(name, metrics, entity, group) + assertTrue(value > 0.0, s"$entity metric not recorded correctly for $name value $value") + } + + private def yammerMetricValue(name: String): Any = { + val allMetrics = KafkaYammerMetrics.defaultRegistry.allMetrics.asScala + val (_, metric) = allMetrics.find { case (n, _) => n.getMBeanName.endsWith(name) } + .getOrElse(fail(s"Unable to find broker metric $name: allMetrics: ${allMetrics.keySet.map(_.getMBeanName)}")) + metric match { + case m: Meter => m.count.toDouble + case m: Histogram => m.max + case m: Gauge[_] => m.value + case m => fail(s"Unexpected broker metric of class ${m.getClass}") + } + } + + private def yammerHistogram(name: String): Histogram = { + val allMetrics = KafkaYammerMetrics.defaultRegistry.allMetrics.asScala + val (_, metric) = allMetrics.find { case (n, _) => n.getMBeanName.endsWith(name) } + .getOrElse(fail(s"Unable to find broker metric $name: allMetrics: ${allMetrics.keySet.map(_.getMBeanName)}")) + metric match { + case m: Histogram => m + case m => throw new AssertionError(s"Unexpected broker metric of class ${m.getClass}") + } + } + + private def verifyYammerMetricRecorded(name: String, verify: Double => Boolean = d => d 
> 0): Double = { + val metricValue = yammerMetricValue(name).asInstanceOf[Double] + assertTrue(verify(metricValue), s"Broker metric not recorded correctly for $name value $metricValue") + metricValue + } + + private def verifyNoRequestMetrics(errorMessage: String): Unit = { + val metrics = KafkaYammerMetrics.defaultRegistry.allMetrics.asScala.filter { case (n, _) => + n.getMBeanName.startsWith("kafka.network:type=RequestMetrics") + } + assertTrue(metrics.isEmpty, s"$errorMessage: ${metrics.keys}") + } +} diff --git a/core/src/test/scala/integration/kafka/api/PlaintextAdminIntegrationTest.scala b/core/src/test/scala/integration/kafka/api/PlaintextAdminIntegrationTest.scala new file mode 100644 index 0000000..94fc04c --- /dev/null +++ b/core/src/test/scala/integration/kafka/api/PlaintextAdminIntegrationTest.scala @@ -0,0 +1,2422 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package kafka.api + +import java.io.File +import java.net.InetAddress +import java.lang.{Long => JLong} +import java.time.{Duration => JDuration} +import java.util.Arrays.asList +import java.util.concurrent.atomic.{AtomicBoolean, AtomicInteger} +import java.util.concurrent.{CountDownLatch, ExecutionException, TimeUnit} +import java.util.{Collections, Optional, Properties} +import java.{time, util} + +import kafka.log.LogConfig +import kafka.security.authorizer.AclEntry +import kafka.server.{Defaults, DynamicConfig, KafkaConfig, KafkaServer} +import kafka.utils.TestUtils._ +import kafka.utils.{Log4jController, TestUtils} +import kafka.zk.KafkaZkClient +import org.apache.kafka.clients.HostResolver +import org.apache.kafka.clients.admin._ +import org.apache.kafka.clients.consumer.{ConsumerConfig, KafkaConsumer} +import org.apache.kafka.clients.producer.{KafkaProducer, ProducerRecord} +import org.apache.kafka.common.acl.{AccessControlEntry, AclBinding, AclBindingFilter, AclOperation, AclPermissionType} +import org.apache.kafka.common.config.{ConfigResource, LogLevelConfig} +import org.apache.kafka.common.errors._ +import org.apache.kafka.common.requests.{DeleteRecordsRequest, MetadataResponse} +import org.apache.kafka.common.resource.{PatternType, ResourcePattern, ResourceType} +import org.apache.kafka.common.utils.{Time, Utils} +import org.apache.kafka.common.{ConsumerGroupState, ElectionType, TopicCollection, TopicPartition, TopicPartitionInfo, TopicPartitionReplica, Uuid} +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.{AfterEach, BeforeEach, Disabled, Test, TestInfo} +import org.slf4j.LoggerFactory + +import scala.annotation.nowarn +import scala.collection.Seq +import scala.compat.java8.OptionConverters._ +import scala.concurrent.duration.Duration +import scala.concurrent.{Await, Future} +import scala.jdk.CollectionConverters._ +import 
scala.util.Random + +/** + * An integration test of the KafkaAdminClient. + * + * Also see [[org.apache.kafka.clients.admin.KafkaAdminClientTest]] for unit tests of the admin client. + */ +class PlaintextAdminIntegrationTest extends BaseAdminIntegrationTest { + import PlaintextAdminIntegrationTest._ + + val topic = "topic" + val partition = 0 + val topicPartition = new TopicPartition(topic, partition) + + private var brokerLoggerConfigResource: ConfigResource = _ + private val changedBrokerLoggers = scala.collection.mutable.Set[String]() + + @BeforeEach + override def setUp(testInfo: TestInfo): Unit = { + super.setUp(testInfo) + brokerLoggerConfigResource = new ConfigResource( + ConfigResource.Type.BROKER_LOGGER, servers.head.config.brokerId.toString) + } + + @AfterEach + override def tearDown(): Unit = { + teardownBrokerLoggers() + super.tearDown() + } + + @Test + def testClose(): Unit = { + val client = Admin.create(createConfig) + client.close() + client.close() // double close has no effect + } + + @Test + def testListNodes(): Unit = { + client = Admin.create(createConfig) + val brokerStrs = brokerList.split(",").toList.sorted + var nodeStrs: List[String] = null + do { + val nodes = client.describeCluster().nodes().get().asScala + nodeStrs = nodes.map ( node => s"${node.host}:${node.port}" ).toList.sorted + } while (nodeStrs.size < brokerStrs.size) + assertEquals(brokerStrs.mkString(","), nodeStrs.mkString(",")) + } + + @Test + def testAdminClientHandlingBadIPWithoutTimeout(): Unit = { + val config = createConfig + config.put(AdminClientConfig.SOCKET_CONNECTION_SETUP_TIMEOUT_MS_CONFIG, "1000") + val returnBadAddressFirst = new HostResolver { + override def resolve(host: String): Array[InetAddress] = { + Array[InetAddress](InetAddress.getByName("10.200.20.100"), InetAddress.getByName(host)) + } + } + client = AdminClientTestUtils.create(config, returnBadAddressFirst) + // simply check that a call, e.g. 
describeCluster, returns normally + client.describeCluster().nodes().get() + } + + @Test + def testCreateExistingTopicsThrowTopicExistsException(): Unit = { + client = Admin.create(createConfig) + val topic = "mytopic" + val topics = Seq(topic) + val newTopics = Seq(new NewTopic(topic, 1, 1.toShort)) + + client.createTopics(newTopics.asJava).all.get() + waitForTopics(client, topics, List()) + + val newTopicsWithInvalidRF = Seq(new NewTopic(topic, 1, (servers.size + 1).toShort)) + val e = assertThrows(classOf[ExecutionException], + () => client.createTopics(newTopicsWithInvalidRF.asJava, new CreateTopicsOptions().validateOnly(true)).all.get()) + assertTrue(e.getCause.isInstanceOf[TopicExistsException]) + } + + @Test + def testDeleteTopicsWithIds(): Unit = { + client = Admin.create(createConfig) + val topics = Seq("mytopic", "mytopic2", "mytopic3") + val newTopics = Seq( + new NewTopic("mytopic", Map((0: Integer) -> Seq[Integer](1, 2).asJava, (1: Integer) -> Seq[Integer](2, 0).asJava).asJava), + new NewTopic("mytopic2", 3, 3.toShort), + new NewTopic("mytopic3", Option.empty[Integer].asJava, Option.empty[java.lang.Short].asJava) + ) + val createResult = client.createTopics(newTopics.asJava) + createResult.all.get() + waitForTopics(client, topics, List()) + val topicIds = getTopicIds().values.toSet + + client.deleteTopics(TopicCollection.ofTopicIds(topicIds.asJava)).all.get() + waitForTopics(client, List(), topics) + } + + @Test + def testMetadataRefresh(): Unit = { + client = Admin.create(createConfig) + val topics = Seq("mytopic") + val newTopics = Seq(new NewTopic("mytopic", 3, 3.toShort)) + client.createTopics(newTopics.asJava).all.get() + waitForTopics(client, expectedPresent = topics, expectedMissing = List()) + + val controller = servers.find(_.config.brokerId == TestUtils.waitUntilControllerElected(zkClient)).get + controller.shutdown() + controller.awaitShutdown() + val topicDesc = client.describeTopics(topics.asJava).allTopicNames.get() + assertEquals(topics.toSet, topicDesc.keySet.asScala) + } + + /** + * describe should not auto create topics + */ + @Test + def testDescribeNonExistingTopic(): Unit = { + client = Admin.create(createConfig) + + val existingTopic = "existing-topic" + client.createTopics(Seq(existingTopic).map(new NewTopic(_, 1, 1.toShort)).asJava).all.get() + waitForTopics(client, Seq(existingTopic), List()) + + val nonExistingTopic = "non-existing" + val results = client.describeTopics(Seq(nonExistingTopic, existingTopic).asJava).topicNameValues() + assertEquals(existingTopic, results.get(existingTopic).get.name) + assertTrue(assertThrows(classOf[ExecutionException], () => results.get(nonExistingTopic).get).getCause.isInstanceOf[UnknownTopicOrPartitionException]) + assertEquals(None, zkClient.getTopicPartitionCount(nonExistingTopic)) + } + + @Test + def testDescribeTopicsWithIds(): Unit = { + client = Admin.create(createConfig) + + val existingTopic = "existing-topic" + client.createTopics(Seq(existingTopic).map(new NewTopic(_, 1, 1.toShort)).asJava).all.get() + waitForTopics(client, Seq(existingTopic), List()) + val existingTopicId = zkClient.getTopicIdsForTopics(Set(existingTopic)).values.head + + val nonExistingTopicId = Uuid.randomUuid() + + val results = client.describeTopics(TopicCollection.ofTopicIds(Seq(existingTopicId, nonExistingTopicId).asJava)).topicIdValues() + assertEquals(existingTopicId, results.get(existingTopicId).get.topicId()) + assertTrue(assertThrows(classOf[ExecutionException], () => results.get(nonExistingTopicId).get).getCause.isInstanceOf[UnknownTopicIdException])
+ } + + @Test + def testDescribeCluster(): Unit = { + client = Admin.create(createConfig) + val result = client.describeCluster + val nodes = result.nodes.get() + val clusterId = result.clusterId().get() + assertEquals(servers.head.dataPlaneRequestProcessor.clusterId, clusterId) + val controller = result.controller().get() + assertEquals(servers.head.dataPlaneRequestProcessor.metadataCache.getControllerId. + getOrElse(MetadataResponse.NO_CONTROLLER_ID), controller.id()) + val brokers = brokerList.split(",") + assertEquals(brokers.size, nodes.size) + for (node <- nodes.asScala) { + val hostStr = s"${node.host}:${node.port}" + assertTrue(brokers.contains(hostStr), s"Unknown host:port pair $hostStr in brokerVersionInfos") + } + } + + @Test + def testDescribeLogDirs(): Unit = { + client = Admin.create(createConfig) + val topic = "topic" + val leaderByPartition = createTopic(topic, numPartitions = 10) + val partitionsByBroker = leaderByPartition.groupBy { case (_, leaderId) => leaderId }.map { case (k, v) => + k -> v.keys.toSeq + } + val brokers = (0 until brokerCount).map(Integer.valueOf) + val logDirInfosByBroker = client.describeLogDirs(brokers.asJava).allDescriptions.get + + (0 until brokerCount).foreach { brokerId => + val server = servers.find(_.config.brokerId == brokerId).get + val expectedPartitions = partitionsByBroker(brokerId) + val logDirInfos = logDirInfosByBroker.get(brokerId) + val replicaInfos = logDirInfos.asScala.flatMap { case (_, logDirInfo) => + logDirInfo.replicaInfos.asScala + }.filter { case (k, _) => k.topic == topic } + + assertEquals(expectedPartitions.toSet, replicaInfos.keys.map(_.partition).toSet) + logDirInfos.forEach { (logDir, logDirInfo) => + logDirInfo.replicaInfos.asScala.keys.foreach(tp => + assertEquals(server.logManager.getLog(tp).get.dir.getParent, logDir) + ) + } + } + } + + @Test + def testDescribeReplicaLogDirs(): Unit = { + client = Admin.create(createConfig) + val topic = "topic" + val leaderByPartition = createTopic(topic, numPartitions = 10) + val replicas = leaderByPartition.map { case (partition, brokerId) => + new TopicPartitionReplica(topic, partition, brokerId) + }.toSeq + + val replicaDirInfos = client.describeReplicaLogDirs(replicas.asJavaCollection).all.get + replicaDirInfos.forEach { (topicPartitionReplica, replicaDirInfo) => + val server = servers.find(_.config.brokerId == topicPartitionReplica.brokerId()).get + val tp = new TopicPartition(topicPartitionReplica.topic(), topicPartitionReplica.partition()) + assertEquals(server.logManager.getLog(tp).get.dir.getParent, replicaDirInfo.getCurrentReplicaLogDir) + } + } + + @Test + def testAlterReplicaLogDirs(): Unit = { + client = Admin.create(createConfig) + val topic = "topic" + val tp = new TopicPartition(topic, 0) + val randomNums = servers.map(server => server -> Random.nextInt(2)).toMap + + // Generate two mutually exclusive replicaAssignment + val firstReplicaAssignment = servers.map { server => + val logDir = new File(server.config.logDirs(randomNums(server))).getAbsolutePath + new TopicPartitionReplica(topic, 0, server.config.brokerId) -> logDir + }.toMap + val secondReplicaAssignment = servers.map { server => + val logDir = new File(server.config.logDirs(1 - randomNums(server))).getAbsolutePath + new TopicPartitionReplica(topic, 0, server.config.brokerId) -> logDir + }.toMap + + // Verify that replica can be created in the specified log directory + val futures = client.alterReplicaLogDirs(firstReplicaAssignment.asJava, + new AlterReplicaLogDirsOptions).values.asScala.values + 
futures.foreach { future => + val exception = assertThrows(classOf[ExecutionException], () => future.get) + assertTrue(exception.getCause.isInstanceOf[UnknownTopicOrPartitionException]) + } + + createTopic(topic, replicationFactor = brokerCount) + servers.foreach { server => + val logDir = server.logManager.getLog(tp).get.dir.getParent + assertEquals(firstReplicaAssignment(new TopicPartitionReplica(topic, 0, server.config.brokerId)), logDir) + } + + // Verify that replica can be moved to the specified log directory after the topic has been created + client.alterReplicaLogDirs(secondReplicaAssignment.asJava, new AlterReplicaLogDirsOptions).all.get + servers.foreach { server => + TestUtils.waitUntilTrue(() => { + val logDir = server.logManager.getLog(tp).get.dir.getParent + secondReplicaAssignment(new TopicPartitionReplica(topic, 0, server.config.brokerId)) == logDir + }, "timed out waiting for replica movement") + } + + // Verify that replica can be moved to the specified log directory while the producer is sending messages + val running = new AtomicBoolean(true) + val numMessages = new AtomicInteger + import scala.concurrent.ExecutionContext.Implicits._ + val producerFuture = Future { + val producer = TestUtils.createProducer( + TestUtils.getBrokerListStrFromServers(servers, protocol = securityProtocol), + securityProtocol = securityProtocol, + trustStoreFile = trustStoreFile, + retries = 0, // Producer should not have to retry when broker is moving replica between log directories. + requestTimeoutMs = 10000, + acks = -1 + ) + try { + while (running.get) { + val future = producer.send(new ProducerRecord(topic, s"xxxxxxxxxxxxxxxxxxxx-$numMessages".getBytes)) + numMessages.incrementAndGet() + future.get(10, TimeUnit.SECONDS) + } + numMessages.get + } finally producer.close() + } + + try { + TestUtils.waitUntilTrue(() => numMessages.get > 10, s"only $numMessages messages are produced before timeout. Producer future ${producerFuture.value}") + client.alterReplicaLogDirs(firstReplicaAssignment.asJava, new AlterReplicaLogDirsOptions).all.get + servers.foreach { server => + TestUtils.waitUntilTrue(() => { + val logDir = server.logManager.getLog(tp).get.dir.getParent + firstReplicaAssignment(new TopicPartitionReplica(topic, 0, server.config.brokerId)) == logDir + }, s"timed out waiting for replica movement. Producer future ${producerFuture.value}") + } + + val currentMessagesNum = numMessages.get + TestUtils.waitUntilTrue(() => numMessages.get - currentMessagesNum > 10, + s"only ${numMessages.get - currentMessagesNum} messages are produced within timeout after replica movement. 
Producer future ${producerFuture.value}") + } finally running.set(false) + + val finalNumMessages = Await.result(producerFuture, Duration(20, TimeUnit.SECONDS)) + + // Verify that all messages that are produced can be consumed + val consumerRecords = TestUtils.consumeTopicRecords(servers, topic, finalNumMessages, + securityProtocol = securityProtocol, trustStoreFile = trustStoreFile) + consumerRecords.zipWithIndex.foreach { case (consumerRecord, index) => + assertEquals(s"xxxxxxxxxxxxxxxxxxxx-$index", new String(consumerRecord.value)) + } + } + + @Test + def testDescribeAndAlterConfigs(): Unit = { + client = Admin.create(createConfig) + + // Create topics + val topic1 = "describe-alter-configs-topic-1" + val topicResource1 = new ConfigResource(ConfigResource.Type.TOPIC, topic1) + val topicConfig1 = new Properties + topicConfig1.setProperty(LogConfig.MaxMessageBytesProp, "500000") + topicConfig1.setProperty(LogConfig.RetentionMsProp, "60000000") + createTopic(topic1, numPartitions = 1, replicationFactor = 1, topicConfig1) + + val topic2 = "describe-alter-configs-topic-2" + val topicResource2 = new ConfigResource(ConfigResource.Type.TOPIC, topic2) + createTopic(topic2) + + // Describe topics and broker + val brokerResource1 = new ConfigResource(ConfigResource.Type.BROKER, servers(1).config.brokerId.toString) + val brokerResource2 = new ConfigResource(ConfigResource.Type.BROKER, servers(2).config.brokerId.toString) + val configResources = Seq(topicResource1, topicResource2, brokerResource1, brokerResource2) + val describeResult = client.describeConfigs(configResources.asJava) + val configs = describeResult.all.get + + assertEquals(4, configs.size) + + val maxMessageBytes1 = configs.get(topicResource1).get(LogConfig.MaxMessageBytesProp) + assertEquals(LogConfig.MaxMessageBytesProp, maxMessageBytes1.name) + assertEquals(topicConfig1.get(LogConfig.MaxMessageBytesProp), maxMessageBytes1.value) + assertFalse(maxMessageBytes1.isDefault) + assertFalse(maxMessageBytes1.isSensitive) + assertFalse(maxMessageBytes1.isReadOnly) + + assertEquals(topicConfig1.get(LogConfig.RetentionMsProp), + configs.get(topicResource1).get(LogConfig.RetentionMsProp).value) + + val maxMessageBytes2 = configs.get(topicResource2).get(LogConfig.MaxMessageBytesProp) + assertEquals(Defaults.MessageMaxBytes.toString, maxMessageBytes2.value) + assertEquals(LogConfig.MaxMessageBytesProp, maxMessageBytes2.name) + assertTrue(maxMessageBytes2.isDefault) + assertFalse(maxMessageBytes2.isSensitive) + assertFalse(maxMessageBytes2.isReadOnly) + + assertEquals(servers(1).config.nonInternalValues.size, configs.get(brokerResource1).entries.size) + assertEquals(servers(1).config.brokerId.toString, configs.get(brokerResource1).get(KafkaConfig.BrokerIdProp).value) + val listenerSecurityProtocolMap = configs.get(brokerResource1).get(KafkaConfig.ListenerSecurityProtocolMapProp) + assertEquals(servers(1).config.getString(KafkaConfig.ListenerSecurityProtocolMapProp), listenerSecurityProtocolMap.value) + assertEquals(KafkaConfig.ListenerSecurityProtocolMapProp, listenerSecurityProtocolMap.name) + assertFalse(listenerSecurityProtocolMap.isDefault) + assertFalse(listenerSecurityProtocolMap.isSensitive) + assertFalse(listenerSecurityProtocolMap.isReadOnly) + val truststorePassword = configs.get(brokerResource1).get(KafkaConfig.SslTruststorePasswordProp) + assertEquals(KafkaConfig.SslTruststorePasswordProp, truststorePassword.name) + assertNull(truststorePassword.value) + assertFalse(truststorePassword.isDefault) + 
assertTrue(truststorePassword.isSensitive) + assertFalse(truststorePassword.isReadOnly) + val compressionType = configs.get(brokerResource1).get(KafkaConfig.CompressionTypeProp) + assertEquals(servers(1).config.compressionType, compressionType.value) + assertEquals(KafkaConfig.CompressionTypeProp, compressionType.name) + assertTrue(compressionType.isDefault) + assertFalse(compressionType.isSensitive) + assertFalse(compressionType.isReadOnly) + + assertEquals(servers(2).config.nonInternalValues.size, configs.get(brokerResource2).entries.size) + assertEquals(servers(2).config.brokerId.toString, configs.get(brokerResource2).get(KafkaConfig.BrokerIdProp).value) + assertEquals(servers(2).config.logCleanerThreads.toString, + configs.get(brokerResource2).get(KafkaConfig.LogCleanerThreadsProp).value) + + checkValidAlterConfigs(client, topicResource1, topicResource2) + } + + @Test + def testCreatePartitions(): Unit = { + client = Admin.create(createConfig) + + // Create topics + val topic1 = "create-partitions-topic-1" + createTopic(topic1) + + val topic2 = "create-partitions-topic-2" + createTopic(topic2, replicationFactor = 2) + + // assert that both the topics have 1 partition + val topic1_metadata = getTopicMetadata(client, topic1) + val topic2_metadata = getTopicMetadata(client, topic2) + assertEquals(1, topic1_metadata.partitions.size) + assertEquals(1, topic2_metadata.partitions.size) + + val validateOnly = new CreatePartitionsOptions().validateOnly(true) + val actuallyDoIt = new CreatePartitionsOptions().validateOnly(false) + + def partitions(topic: String, expectedNumPartitionsOpt: Option[Int]): util.List[TopicPartitionInfo] = { + getTopicMetadata(client, topic, expectedNumPartitionsOpt = expectedNumPartitionsOpt).partitions + } + + def numPartitions(topic: String, expectedNumPartitionsOpt: Option[Int] = None): Int = partitions(topic, expectedNumPartitionsOpt).size + + // validateOnly: try creating a new partition (no assignments), to bring the total to 3 partitions + var alterResult = client.createPartitions(Map(topic1 -> + NewPartitions.increaseTo(3)).asJava, validateOnly) + var altered = alterResult.values.get(topic1).get + assertEquals(1, numPartitions(topic1)) + + // try creating a new partition (no assignments), to bring the total to 3 partitions + alterResult = client.createPartitions(Map(topic1 -> + NewPartitions.increaseTo(3)).asJava, actuallyDoIt) + altered = alterResult.values.get(topic1).get + TestUtils.waitUntilTrue(() => numPartitions(topic1) == 3, "Timed out waiting for new partitions to appear") + + // validateOnly: now try creating a new partition (with assignments), to bring the total to 3 partitions + val newPartition2Assignments = asList[util.List[Integer]](asList(0, 1), asList(1, 2)) + alterResult = client.createPartitions(Map(topic2 -> + NewPartitions.increaseTo(3, newPartition2Assignments)).asJava, validateOnly) + altered = alterResult.values.get(topic2).get + assertEquals(1, numPartitions(topic2)) + + // now try creating a new partition (with assignments), to bring the total to 3 partitions + alterResult = client.createPartitions(Map(topic2 -> + NewPartitions.increaseTo(3, newPartition2Assignments)).asJava, actuallyDoIt) + altered = alterResult.values.get(topic2).get + val actualPartitions2 = partitions(topic2, expectedNumPartitionsOpt = Some(3)) + assertEquals(3, actualPartitions2.size) + assertEquals(Seq(0, 1), actualPartitions2.get(1).replicas.asScala.map(_.id).toList) + assertEquals(Seq(1, 2), actualPartitions2.get(2).replicas.asScala.map(_.id).toList) + + // 
loop over error cases calling with+without validate-only + for (option <- Seq(validateOnly, actuallyDoIt)) { + val desc = if (option.validateOnly()) "validateOnly" else "validateOnly=false" + + // try a newCount which would be a decrease + alterResult = client.createPartitions(Map(topic1 -> + NewPartitions.increaseTo(1)).asJava, option) + + var e = assertThrows(classOf[ExecutionException], () => alterResult.values.get(topic1).get, + () => s"$desc: Expect InvalidPartitionsException when newCount is a decrease") + assertTrue(e.getCause.isInstanceOf[InvalidPartitionsException], desc) + assertEquals("Topic currently has 3 partitions, which is higher than the requested 1.", e.getCause.getMessage, desc) + assertEquals(3, numPartitions(topic1), desc) + + // try a newCount which would be a noop (without assignment) + alterResult = client.createPartitions(Map(topic2 -> + NewPartitions.increaseTo(3)).asJava, option) + e = assertThrows(classOf[ExecutionException], () => alterResult.values.get(topic2).get, + () => s"$desc: Expect InvalidPartitionsException when requesting a noop") + assertTrue(e.getCause.isInstanceOf[InvalidPartitionsException], desc) + assertEquals("Topic already has 3 partitions.", e.getCause.getMessage, desc) + assertEquals(3, numPartitions(topic2, Some(3)), desc) + + // try a newCount which would be a noop (where the assignment matches current state) + alterResult = client.createPartitions(Map(topic2 -> + NewPartitions.increaseTo(3, newPartition2Assignments)).asJava, option) + e = assertThrows(classOf[ExecutionException], () => alterResult.values.get(topic2).get) + assertTrue(e.getCause.isInstanceOf[InvalidPartitionsException], desc) + assertEquals("Topic already has 3 partitions.", e.getCause.getMessage, desc) + assertEquals(3, numPartitions(topic2, Some(3)), desc) + + // try a newCount which would be a noop (where the assignment doesn't match current state) + alterResult = client.createPartitions(Map(topic2 -> + NewPartitions.increaseTo(3, newPartition2Assignments.asScala.reverse.toList.asJava)).asJava, option) + e = assertThrows(classOf[ExecutionException], () => alterResult.values.get(topic2).get) + assertTrue(e.getCause.isInstanceOf[InvalidPartitionsException], desc) + assertEquals("Topic already has 3 partitions.", e.getCause.getMessage, desc) + assertEquals(3, numPartitions(topic2, Some(3)), desc) + + // try a bad topic name + val unknownTopic = "an-unknown-topic" + alterResult = client.createPartitions(Map(unknownTopic -> + NewPartitions.increaseTo(2)).asJava, option) + e = assertThrows(classOf[ExecutionException], () => alterResult.values.get(unknownTopic).get, + () => s"$desc: Expect InvalidTopicException when using an unknown topic") + assertTrue(e.getCause.isInstanceOf[UnknownTopicOrPartitionException], desc) + assertEquals("The topic 'an-unknown-topic' does not exist.", e.getCause.getMessage, desc) + + // try an invalid newCount + alterResult = client.createPartitions(Map(topic1 -> + NewPartitions.increaseTo(-22)).asJava, option) + e = assertThrows(classOf[ExecutionException], () => alterResult.values.get(topic1).get, + () => s"$desc: Expect InvalidPartitionsException when newCount is invalid") + assertTrue(e.getCause.isInstanceOf[InvalidPartitionsException], desc) + assertEquals("Topic currently has 3 partitions, which is higher than the requested -22.", e.getCause.getMessage, + desc) + assertEquals(3, numPartitions(topic1), desc) + + // try assignments where the number of brokers != replication factor + alterResult = client.createPartitions(Map(topic1 -> + 
NewPartitions.increaseTo(4, asList(asList(1, 2)))).asJava, option) + e = assertThrows(classOf[ExecutionException], () => alterResult.values.get(topic1).get, + () => s"$desc: Expect InvalidPartitionsException when #brokers != replication factor") + assertTrue(e.getCause.isInstanceOf[InvalidReplicaAssignmentException], desc) + assertEquals("Inconsistent replication factor between partitions, partition 0 has 1 " + + "while partitions [3] have replication factors [2], respectively.", + e.getCause.getMessage, desc) + assertEquals(3, numPartitions(topic1), desc) + + // try #assignments < with the increase + alterResult = client.createPartitions(Map(topic1 -> + NewPartitions.increaseTo(6, asList(asList(1)))).asJava, option) + e = assertThrows(classOf[ExecutionException], () => alterResult.values.get(topic1).get, + () => s"$desc: Expect InvalidReplicaAssignmentException when #assignments != newCount - oldCount") + assertTrue(e.getCause.isInstanceOf[InvalidReplicaAssignmentException], desc) + assertEquals("Increasing the number of partitions by 3 but 1 assignments provided.", e.getCause.getMessage, desc) + assertEquals(3, numPartitions(topic1), desc) + + // try #assignments > with the increase + alterResult = client.createPartitions(Map(topic1 -> + NewPartitions.increaseTo(4, asList(asList(1), asList(2)))).asJava, option) + e = assertThrows(classOf[ExecutionException], () => alterResult.values.get(topic1).get, + () => s"$desc: Expect InvalidReplicaAssignmentException when #assignments != newCount - oldCount") + assertTrue(e.getCause.isInstanceOf[InvalidReplicaAssignmentException], desc) + assertEquals("Increasing the number of partitions by 1 but 2 assignments provided.", e.getCause.getMessage, desc) + assertEquals(3, numPartitions(topic1), desc) + + // try with duplicate brokers in assignments + alterResult = client.createPartitions(Map(topic1 -> + NewPartitions.increaseTo(4, asList(asList(1, 1)))).asJava, option) + e = assertThrows(classOf[ExecutionException], () => alterResult.values.get(topic1).get, + () => s"$desc: Expect InvalidReplicaAssignmentException when assignments has duplicate brokers") + assertTrue(e.getCause.isInstanceOf[InvalidReplicaAssignmentException], desc) + assertEquals("Duplicate brokers not allowed in replica assignment: 1, 1 for partition id 3.", + e.getCause.getMessage, desc) + assertEquals(3, numPartitions(topic1), desc) + + // try assignments with differently sized inner lists + alterResult = client.createPartitions(Map(topic1 -> + NewPartitions.increaseTo(5, asList(asList(1), asList(1, 0)))).asJava, option) + e = assertThrows(classOf[ExecutionException], () => alterResult.values.get(topic1).get, + () => s"$desc: Expect InvalidReplicaAssignmentException when assignments have differently sized inner lists") + assertTrue(e.getCause.isInstanceOf[InvalidReplicaAssignmentException], desc) + assertEquals("Inconsistent replication factor between partitions, partition 0 has 1 " + + "while partitions [4] have replication factors [2], respectively.", e.getCause.getMessage, desc) + assertEquals(3, numPartitions(topic1), desc) + + // try assignments with unknown brokers + alterResult = client.createPartitions(Map(topic1 -> + NewPartitions.increaseTo(4, asList(asList(12)))).asJava, option) + e = assertThrows(classOf[ExecutionException], () => alterResult.values.get(topic1).get, + () => s"$desc: Expect InvalidReplicaAssignmentException when assignments contains an unknown broker") + assertTrue(e.getCause.isInstanceOf[InvalidReplicaAssignmentException], desc) + assertEquals("Unknown 
broker(s) in replica assignment: 12.", e.getCause.getMessage, desc) + assertEquals(3, numPartitions(topic1), desc) + + // try with empty assignments + alterResult = client.createPartitions(Map(topic1 -> + NewPartitions.increaseTo(4, Collections.emptyList())).asJava, option) + e = assertThrows(classOf[ExecutionException], () => alterResult.values.get(topic1).get, + () => s"$desc: Expect InvalidReplicaAssignmentException when assignments is empty") + assertTrue(e.getCause.isInstanceOf[InvalidReplicaAssignmentException], desc) + assertEquals("Increasing the number of partitions by 1 but 0 assignments provided.", e.getCause.getMessage, desc) + assertEquals(3, numPartitions(topic1), desc) + } + + // a mixed success, failure response + alterResult = client.createPartitions(Map( + topic1 -> NewPartitions.increaseTo(4), + topic2 -> NewPartitions.increaseTo(2)).asJava, actuallyDoIt) + // assert that the topic1 now has 4 partitions + altered = alterResult.values.get(topic1).get + TestUtils.waitUntilTrue(() => numPartitions(topic1) == 4, "Timed out waiting for new partitions to appear") + var e = assertThrows(classOf[ExecutionException], () => alterResult.values.get(topic2).get) + assertTrue(e.getCause.isInstanceOf[InvalidPartitionsException]) + assertEquals("Topic currently has 3 partitions, which is higher than the requested 2.", e.getCause.getMessage) + assertEquals(3, numPartitions(topic2)) + + // finally, try to add partitions to a topic queued for deletion + val deleteResult = client.deleteTopics(asList(topic1)) + deleteResult.topicNameValues.get(topic1).get + alterResult = client.createPartitions(Map(topic1 -> + NewPartitions.increaseTo(4)).asJava, validateOnly) + e = assertThrows(classOf[ExecutionException], () => alterResult.values.get(topic1).get, + () => "Expect InvalidTopicException when the topic is queued for deletion") + assertTrue(e.getCause.isInstanceOf[InvalidTopicException]) + assertEquals("The topic is queued for deletion.", e.getCause.getMessage) + } + + @Test + def testSeekAfterDeleteRecords(): Unit = { + createTopic(topic, numPartitions = 2, replicationFactor = brokerCount) + + client = Admin.create(createConfig) + + val consumer = createConsumer() + subscribeAndWaitForAssignment(topic, consumer) + + val producer = createProducer() + sendRecords(producer, 10, topicPartition) + consumer.seekToBeginning(Collections.singleton(topicPartition)) + assertEquals(0L, consumer.position(topicPartition)) + + val result = client.deleteRecords(Map(topicPartition -> RecordsToDelete.beforeOffset(5L)).asJava) + val lowWatermark = result.lowWatermarks().get(topicPartition).get.lowWatermark + assertEquals(5L, lowWatermark) + + consumer.seekToBeginning(Collections.singletonList(topicPartition)) + assertEquals(5L, consumer.position(topicPartition)) + + consumer.seek(topicPartition, 7L) + assertEquals(7L, consumer.position(topicPartition)) + + client.deleteRecords(Map(topicPartition -> RecordsToDelete.beforeOffset(DeleteRecordsRequest.HIGH_WATERMARK)).asJava).all.get + consumer.seekToBeginning(Collections.singletonList(topicPartition)) + assertEquals(10L, consumer.position(topicPartition)) + } + + @Test + def testLogStartOffsetCheckpoint(): Unit = { + createTopic(topic, numPartitions = 2, replicationFactor = brokerCount) + + client = Admin.create(createConfig) + + val consumer = createConsumer() + subscribeAndWaitForAssignment(topic, consumer) + + val producer = createProducer() + sendRecords(producer, 10, topicPartition) + var result = client.deleteRecords(Map(topicPartition -> 
RecordsToDelete.beforeOffset(5L)).asJava) + var lowWatermark: Option[Long] = Some(result.lowWatermarks.get(topicPartition).get.lowWatermark) + assertEquals(Some(5), lowWatermark) + + for (i <- 0 until brokerCount) { + killBroker(i) + } + restartDeadBrokers() + + client.close() + brokerList = TestUtils.bootstrapServers(servers, listenerName) + client = Admin.create(createConfig) + + TestUtils.waitUntilTrue(() => { + // Need to retry if leader is not available for the partition + result = client.deleteRecords(Map(topicPartition -> RecordsToDelete.beforeOffset(0L)).asJava) + + lowWatermark = None + val future = result.lowWatermarks().get(topicPartition) + try { + lowWatermark = Some(future.get.lowWatermark) + lowWatermark.contains(5L) + } catch { + case e: ExecutionException if e.getCause.isInstanceOf[LeaderNotAvailableException] || + e.getCause.isInstanceOf[NotLeaderOrFollowerException] => false + } + }, s"Expected low watermark of the partition to be 5 but got ${lowWatermark.getOrElse("no response within the timeout")}") + } + + @Test + def testLogStartOffsetAfterDeleteRecords(): Unit = { + createTopic(topic, numPartitions = 2, replicationFactor = brokerCount) + + client = Admin.create(createConfig) + + val consumer = createConsumer() + subscribeAndWaitForAssignment(topic, consumer) + + val producer = createProducer() + sendRecords(producer, 10, topicPartition) + + val result = client.deleteRecords(Map(topicPartition -> RecordsToDelete.beforeOffset(3L)).asJava) + val lowWatermark = result.lowWatermarks.get(topicPartition).get.lowWatermark + assertEquals(3L, lowWatermark) + + for (i <- 0 until brokerCount) + assertEquals(3, servers(i).replicaManager.localLog(topicPartition).get.logStartOffset) + } + + @Test + def testReplicaCanFetchFromLogStartOffsetAfterDeleteRecords(): Unit = { + val leaders = createTopic(topic, replicationFactor = brokerCount) + val followerIndex = if (leaders(0) != servers(0).config.brokerId) 0 else 1 + + def waitForFollowerLog(expectedStartOffset: Long, expectedEndOffset: Long): Unit = { + TestUtils.waitUntilTrue(() => servers(followerIndex).replicaManager.localLog(topicPartition) != None, + "Expected follower to create replica for partition") + + // wait until the follower discovers that log start offset moved beyond its HW + TestUtils.waitUntilTrue(() => { + servers(followerIndex).replicaManager.localLog(topicPartition).get.logStartOffset == expectedStartOffset + }, s"Expected follower to discover new log start offset $expectedStartOffset") + + TestUtils.waitUntilTrue(() => { + servers(followerIndex).replicaManager.localLog(topicPartition).get.logEndOffset == expectedEndOffset + }, s"Expected follower to catch up to log end offset $expectedEndOffset") + } + + // we will produce to topic and delete records while one follower is down + killBroker(followerIndex) + + client = Admin.create(createConfig) + val producer = createProducer() + sendRecords(producer, 100, topicPartition) + + val result = client.deleteRecords(Map(topicPartition -> RecordsToDelete.beforeOffset(3L)).asJava) + result.all().get() + + // start the stopped broker to verify that it will be able to fetch from new log start offset + restartDeadBrokers() + + waitForFollowerLog(expectedStartOffset=3L, expectedEndOffset=100L) + + // after the new replica caught up, all replicas should have same log start offset + for (i <- 0 until brokerCount) + assertEquals(3, servers(i).replicaManager.localLog(topicPartition).get.logStartOffset) + + // kill the same follower again, produce more records, and delete records 
beyond follower's LOE + killBroker(followerIndex) + sendRecords(producer, 100, topicPartition) + val result1 = client.deleteRecords(Map(topicPartition -> RecordsToDelete.beforeOffset(117L)).asJava) + result1.all().get() + restartDeadBrokers() + waitForFollowerLog(expectedStartOffset=117L, expectedEndOffset=200L) + } + + @Test + def testAlterLogDirsAfterDeleteRecords(): Unit = { + client = Admin.create(createConfig) + createTopic(topic, replicationFactor = brokerCount) + val expectedLEO = 100 + val producer = createProducer() + sendRecords(producer, expectedLEO, topicPartition) + + // delete records to move log start offset + val result = client.deleteRecords(Map(topicPartition -> RecordsToDelete.beforeOffset(3L)).asJava) + result.all().get() + // make sure we are in the expected state after delete records + for (i <- 0 until brokerCount) { + assertEquals(3, servers(i).replicaManager.localLog(topicPartition).get.logStartOffset) + assertEquals(expectedLEO, servers(i).replicaManager.localLog(topicPartition).get.logEndOffset) + } + + // we will create another dir just for one server + val futureLogDir = servers(0).config.logDirs(1) + val futureReplica = new TopicPartitionReplica(topic, 0, servers(0).config.brokerId) + + // Verify that replica can be moved to the specified log directory + client.alterReplicaLogDirs(Map(futureReplica -> futureLogDir).asJava).all.get + TestUtils.waitUntilTrue(() => { + futureLogDir == servers(0).logManager.getLog(topicPartition).get.dir.getParent + }, "timed out waiting for replica movement") + + // once replica moved, its LSO and LEO should match other replicas + assertEquals(3, servers.head.replicaManager.localLog(topicPartition).get.logStartOffset) + assertEquals(expectedLEO, servers.head.replicaManager.localLog(topicPartition).get.logEndOffset) + } + + @Test + def testOffsetsForTimesAfterDeleteRecords(): Unit = { + createTopic(topic, numPartitions = 2, replicationFactor = brokerCount) + + client = Admin.create(createConfig) + + val consumer = createConsumer() + subscribeAndWaitForAssignment(topic, consumer) + + val producer = createProducer() + sendRecords(producer, 10, topicPartition) + assertEquals(0L, consumer.offsetsForTimes(Map(topicPartition -> JLong.valueOf(0L)).asJava).get(topicPartition).offset()) + + var result = client.deleteRecords(Map(topicPartition -> RecordsToDelete.beforeOffset(5L)).asJava) + result.all.get + assertEquals(5L, consumer.offsetsForTimes(Map(topicPartition -> JLong.valueOf(0L)).asJava).get(topicPartition).offset()) + + result = client.deleteRecords(Map(topicPartition -> RecordsToDelete.beforeOffset(DeleteRecordsRequest.HIGH_WATERMARK)).asJava) + result.all.get + assertNull(consumer.offsetsForTimes(Map(topicPartition -> JLong.valueOf(0L)).asJava).get(topicPartition)) + } + + @Test + def testConsumeAfterDeleteRecords(): Unit = { + val consumer = createConsumer() + subscribeAndWaitForAssignment(topic, consumer) + + client = Admin.create(createConfig) + + val producer = createProducer() + sendRecords(producer, 10, topicPartition) + var messageCount = 0 + TestUtils.consumeRecords(consumer, 10) + + client.deleteRecords(Map(topicPartition -> RecordsToDelete.beforeOffset(3L)).asJava).all.get + consumer.seek(topicPartition, 1) + messageCount = 0 + TestUtils.consumeRecords(consumer, 7) + + client.deleteRecords(Map(topicPartition -> RecordsToDelete.beforeOffset(8L)).asJava).all.get + consumer.seek(topicPartition, 1) + messageCount = 0 + TestUtils.consumeRecords(consumer, 2) + } + + @Test + def testDeleteRecordsWithException(): Unit = { + val 
consumer = createConsumer() + subscribeAndWaitForAssignment(topic, consumer) + + client = Admin.create(createConfig) + + val producer = createProducer() + sendRecords(producer, 10, topicPartition) + + assertEquals(5L, client.deleteRecords(Map(topicPartition -> RecordsToDelete.beforeOffset(5L)).asJava) + .lowWatermarks.get(topicPartition).get.lowWatermark) + + // OffsetOutOfRangeException if offset > high_watermark + var cause = assertThrows(classOf[ExecutionException], + () => client.deleteRecords(Map(topicPartition -> RecordsToDelete.beforeOffset(20L)).asJava).lowWatermarks.get(topicPartition).get).getCause + assertEquals(classOf[OffsetOutOfRangeException], cause.getClass) + + val nonExistPartition = new TopicPartition(topic, 3) + // LeaderNotAvailableException if non existent partition + cause = assertThrows(classOf[ExecutionException], + () => client.deleteRecords(Map(nonExistPartition -> RecordsToDelete.beforeOffset(20L)).asJava).lowWatermarks.get(nonExistPartition).get).getCause + assertEquals(classOf[LeaderNotAvailableException], cause.getClass) + } + + @Test + def testDescribeConfigsForTopic(): Unit = { + createTopic(topic, numPartitions = 2, replicationFactor = brokerCount) + client = Admin.create(createConfig) + + val existingTopic = new ConfigResource(ConfigResource.Type.TOPIC, topic) + client.describeConfigs(Collections.singletonList(existingTopic)).values.get(existingTopic).get() + + val nonExistentTopic = new ConfigResource(ConfigResource.Type.TOPIC, "unknown") + val describeResult1 = client.describeConfigs(Collections.singletonList(nonExistentTopic)) + + assertTrue(assertThrows(classOf[ExecutionException], () => describeResult1.values.get(nonExistentTopic).get).getCause.isInstanceOf[UnknownTopicOrPartitionException]) + + val invalidTopic = new ConfigResource(ConfigResource.Type.TOPIC, "(invalid topic)") + val describeResult2 = client.describeConfigs(Collections.singletonList(invalidTopic)) + + assertTrue(assertThrows(classOf[ExecutionException], () => describeResult2.values.get(invalidTopic).get).getCause.isInstanceOf[InvalidTopicException]) + } + + private def subscribeAndWaitForAssignment(topic: String, consumer: KafkaConsumer[Array[Byte], Array[Byte]]): Unit = { + consumer.subscribe(Collections.singletonList(topic)) + TestUtils.pollUntilTrue(consumer, () => !consumer.assignment.isEmpty, "Expected non-empty assignment") + } + + private def sendRecords(producer: KafkaProducer[Array[Byte], Array[Byte]], + numRecords: Int, + topicPartition: TopicPartition): Unit = { + val futures = (0 until numRecords).map( i => { + val record = new ProducerRecord(topicPartition.topic, topicPartition.partition, s"$i".getBytes, s"$i".getBytes) + debug(s"Sending this record: $record") + producer.send(record) + }) + + futures.foreach(_.get) + } + + @Test + def testInvalidAlterConfigs(): Unit = { + client = Admin.create(createConfig) + checkInvalidAlterConfigs(zkClient, servers, client) + } + + /** + * Test that ACL operations are not possible when the authorizer is disabled. + * Also see [[kafka.api.SaslSslAdminIntegrationTest.testAclOperations()]] for tests of ACL operations + * when the authorizer is enabled. 
+ */ + @Test + def testAclOperations(): Unit = { + val acl = new AclBinding(new ResourcePattern(ResourceType.TOPIC, "mytopic3", PatternType.LITERAL), + new AccessControlEntry("User:ANONYMOUS", "*", AclOperation.DESCRIBE, AclPermissionType.ALLOW)) + client = Admin.create(createConfig) + assertFutureExceptionTypeEquals(client.describeAcls(AclBindingFilter.ANY).values(), classOf[SecurityDisabledException]) + assertFutureExceptionTypeEquals(client.createAcls(Collections.singleton(acl)).all(), + classOf[SecurityDisabledException]) + assertFutureExceptionTypeEquals(client.deleteAcls(Collections.singleton(acl.toFilter())).all(), + classOf[SecurityDisabledException]) + } + + /** + * Test closing the AdminClient with a generous timeout. Calls in progress should be completed, + * since they can be done within the timeout. New calls should receive timeouts. + */ + @Test + def testDelayedClose(): Unit = { + client = Admin.create(createConfig) + val topics = Seq("mytopic", "mytopic2") + val newTopics = topics.map(new NewTopic(_, 1, 1.toShort)) + val future = client.createTopics(newTopics.asJava, new CreateTopicsOptions().validateOnly(true)).all() + client.close(time.Duration.ofHours(2)) + val future2 = client.createTopics(newTopics.asJava, new CreateTopicsOptions().validateOnly(true)).all() + assertFutureExceptionTypeEquals(future2, classOf[TimeoutException]) + future.get + client.close(time.Duration.ofMinutes(30)) // multiple close-with-timeout should have no effect + } + + /** + * Test closing the AdminClient with a timeout of 0, when there are calls with extremely long + * timeouts in progress. The calls should be aborted after the hard shutdown timeout elapses. + */ + @Test + def testForceClose(): Unit = { + val config = createConfig + config.put(AdminClientConfig.BOOTSTRAP_SERVERS_CONFIG, s"localhost:${TestUtils.IncorrectBrokerPort}") + client = Admin.create(config) + // Because the bootstrap servers are set up incorrectly, this call will not complete, but must be + // cancelled by the close operation. + val future = client.createTopics(Seq("mytopic", "mytopic2").map(new NewTopic(_, 1, 1.toShort)).asJava, + new CreateTopicsOptions().timeoutMs(900000)).all() + client.close(time.Duration.ZERO) + assertFutureExceptionTypeEquals(future, classOf[TimeoutException]) + } + + /** + * Check that a call with a timeout does not complete before the minimum timeout has elapsed, + * even when the default request timeout is shorter. + */ + @Test + def testMinimumRequestTimeouts(): Unit = { + val config = createConfig + config.put(AdminClientConfig.BOOTSTRAP_SERVERS_CONFIG, s"localhost:${TestUtils.IncorrectBrokerPort}") + config.put(AdminClientConfig.REQUEST_TIMEOUT_MS_CONFIG, "0") + client = Admin.create(config) + val startTimeMs = Time.SYSTEM.milliseconds() + val future = client.createTopics(Seq("mytopic", "mytopic2").map(new NewTopic(_, 1, 1.toShort)).asJava, + new CreateTopicsOptions().timeoutMs(2)).all() + assertFutureExceptionTypeEquals(future, classOf[TimeoutException]) + val endTimeMs = Time.SYSTEM.milliseconds() + assertTrue(endTimeMs > startTimeMs, "Expected the timeout to take at least one millisecond.") + } + + /** + * Test injecting timeouts for calls that are in flight. 
+ */ + @Test + def testCallInFlightTimeouts(): Unit = { + val config = createConfig + config.put(AdminClientConfig.DEFAULT_API_TIMEOUT_MS_CONFIG, "100000000") + config.put(AdminClientConfig.RETRIES_CONFIG, "0") + val factory = new KafkaAdminClientTest.FailureInjectingTimeoutProcessorFactory() + client = KafkaAdminClientTest.createInternal(new AdminClientConfig(config), factory) + val future = client.createTopics(Seq("mytopic", "mytopic2").map(new NewTopic(_, 1, 1.toShort)).asJava, + new CreateTopicsOptions().validateOnly(true)).all() + assertFutureExceptionTypeEquals(future, classOf[TimeoutException]) + val future2 = client.createTopics(Seq("mytopic3", "mytopic4").map(new NewTopic(_, 1, 1.toShort)).asJava, + new CreateTopicsOptions().validateOnly(true)).all() + future2.get + assertEquals(1, factory.failuresInjected) + } + + /** + * Test the consumer group APIs. + */ + @Test + def testConsumerGroups(): Unit = { + val config = createConfig + client = Admin.create(config) + try { + // Verify that initially there are no consumer groups to list. + val list1 = client.listConsumerGroups() + assertTrue(0 == list1.all().get().size()) + assertTrue(0 == list1.errors().get().size()) + assertTrue(0 == list1.valid().get().size()) + val testTopicName = "test_topic" + val testTopicName1 = testTopicName + "1" + val testTopicName2 = testTopicName + "2" + val testNumPartitions = 2 + + client.createTopics(util.Arrays.asList( + new NewTopic(testTopicName, testNumPartitions, 1.toShort), + new NewTopic(testTopicName1, testNumPartitions, 1.toShort), + new NewTopic(testTopicName2, testNumPartitions, 1.toShort) + )).all().get() + waitForTopics(client, List(testTopicName, testTopicName1, testTopicName2), List()) + + val producer = createProducer() + try { + producer.send(new ProducerRecord(testTopicName, 0, null, null)).get() + } finally { + Utils.closeQuietly(producer, "producer") + } + + val EMPTY_GROUP_INSTANCE_ID = "" + val testGroupId = "test_group_id" + val testClientId = "test_client_id" + val testInstanceId1 = "test_instance_id_1" + val testInstanceId2 = "test_instance_id_2" + val fakeGroupId = "fake_group_id" + + def createProperties(groupInstanceId: String): Properties = { + val newConsumerConfig = new Properties(consumerConfig) + // We need to disable the auto commit because after the members got removed from group, the offset commit + // will cause the member rejoining and the test will be flaky (check ConsumerCoordinator#OffsetCommitResponseHandler) + newConsumerConfig.setProperty(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG, "false") + newConsumerConfig.setProperty(ConsumerConfig.GROUP_ID_CONFIG, testGroupId) + newConsumerConfig.setProperty(ConsumerConfig.CLIENT_ID_CONFIG, testClientId) + if (groupInstanceId != EMPTY_GROUP_INSTANCE_ID) { + newConsumerConfig.setProperty(ConsumerConfig.GROUP_INSTANCE_ID_CONFIG, groupInstanceId) + } + newConsumerConfig + } + + // contains two static members and one dynamic member + val groupInstanceSet = Set(testInstanceId1, testInstanceId2, EMPTY_GROUP_INSTANCE_ID) + val consumerSet = groupInstanceSet.map { groupInstanceId => createConsumer(configOverrides = createProperties(groupInstanceId))} + val topicSet = Set(testTopicName, testTopicName1, testTopicName2) + + val latch = new CountDownLatch(consumerSet.size) + try { + def createConsumerThread[K,V](consumer: KafkaConsumer[K,V], topic: String): Thread = { + new Thread { + override def run : Unit = { + consumer.subscribe(Collections.singleton(topic)) + try { + while (true) { + consumer.poll(JDuration.ofSeconds(5)) + if 
(!consumer.assignment.isEmpty && latch.getCount > 0L) + latch.countDown() + consumer.commitSync() + } + } catch { + case _: InterruptException => // Suppress the output to stderr + } + } + } + } + + // Start consumers in a thread that will subscribe to a new group. + val consumerThreads = consumerSet.zip(topicSet).map(zipped => createConsumerThread(zipped._1, zipped._2)) + + try { + consumerThreads.foreach(_.start()) + assertTrue(latch.await(30000, TimeUnit.MILLISECONDS)) + // Test that we can list the new group. + TestUtils.waitUntilTrue(() => { + val matching = client.listConsumerGroups.all.get.asScala.filter(group => + group.groupId == testGroupId && + group.state.get == ConsumerGroupState.STABLE) + matching.size == 1 + }, s"Expected to be able to list $testGroupId") + + TestUtils.waitUntilTrue(() => { + val options = new ListConsumerGroupsOptions().inStates(Set(ConsumerGroupState.STABLE).asJava) + val matching = client.listConsumerGroups(options).all.get.asScala.filter(group => + group.groupId == testGroupId && + group.state.get == ConsumerGroupState.STABLE) + matching.size == 1 + }, s"Expected to be able to list $testGroupId in state Stable") + + TestUtils.waitUntilTrue(() => { + val options = new ListConsumerGroupsOptions().inStates(Set(ConsumerGroupState.EMPTY).asJava) + val matching = client.listConsumerGroups(options).all.get.asScala.filter( + _.groupId == testGroupId) + matching.isEmpty + }, s"Expected to find zero groups") + + val describeWithFakeGroupResult = client.describeConsumerGroups(Seq(testGroupId, fakeGroupId).asJava, + new DescribeConsumerGroupsOptions().includeAuthorizedOperations(true)) + assertEquals(2, describeWithFakeGroupResult.describedGroups().size()) + + // Test that we can get information about the test consumer group. + assertTrue(describeWithFakeGroupResult.describedGroups().containsKey(testGroupId)) + var testGroupDescription = describeWithFakeGroupResult.describedGroups().get(testGroupId).get() + + assertEquals(testGroupId, testGroupDescription.groupId()) + assertFalse(testGroupDescription.isSimpleConsumerGroup) + assertEquals(groupInstanceSet.size, testGroupDescription.members().size()) + val members = testGroupDescription.members() + members.asScala.foreach(member => assertEquals(testClientId, member.clientId())) + val topicPartitionsByTopic = members.asScala.flatMap(_.assignment().topicPartitions().asScala).groupBy(_.topic()) + topicSet.foreach { topic => + val topicPartitions = topicPartitionsByTopic.getOrElse(topic, List.empty) + assertEquals(testNumPartitions, topicPartitions.size) + } + + val expectedOperations = AclEntry.supportedOperations(ResourceType.GROUP).asJava + assertEquals(expectedOperations, testGroupDescription.authorizedOperations()) + + // Test that the fake group is listed as dead. 
+ assertTrue(describeWithFakeGroupResult.describedGroups().containsKey(fakeGroupId)) + val fakeGroupDescription = describeWithFakeGroupResult.describedGroups().get(fakeGroupId).get() + + assertEquals(fakeGroupId, fakeGroupDescription.groupId()) + assertEquals(0, fakeGroupDescription.members().size()) + assertEquals("", fakeGroupDescription.partitionAssignor()) + assertEquals(ConsumerGroupState.DEAD, fakeGroupDescription.state()) + assertEquals(expectedOperations, fakeGroupDescription.authorizedOperations()) + + // Test that all() returns 2 results + assertEquals(2, describeWithFakeGroupResult.all().get().size()) + + // Test listConsumerGroupOffsets + TestUtils.waitUntilTrue(() => { + val parts = client.listConsumerGroupOffsets(testGroupId).partitionsToOffsetAndMetadata().get() + val part = new TopicPartition(testTopicName, 0) + parts.containsKey(part) && (parts.get(part).offset() == 1) + }, s"Expected the offset for partition 0 to eventually become 1.") + + // Test delete non-exist consumer instance + val invalidInstanceId = "invalid-instance-id" + var removeMembersResult = client.removeMembersFromConsumerGroup(testGroupId, new RemoveMembersFromConsumerGroupOptions( + Collections.singleton(new MemberToRemove(invalidInstanceId)) + )) + + TestUtils.assertFutureExceptionTypeEquals(removeMembersResult.all, classOf[UnknownMemberIdException]) + val firstMemberFuture = removeMembersResult.memberResult(new MemberToRemove(invalidInstanceId)) + TestUtils.assertFutureExceptionTypeEquals(firstMemberFuture, classOf[UnknownMemberIdException]) + + // Test consumer group deletion + var deleteResult = client.deleteConsumerGroups(Seq(testGroupId, fakeGroupId).asJava) + assertEquals(2, deleteResult.deletedGroups().size()) + + // Deleting the fake group ID should get GroupIdNotFoundException. + assertTrue(deleteResult.deletedGroups().containsKey(fakeGroupId)) + assertFutureExceptionTypeEquals(deleteResult.deletedGroups().get(fakeGroupId), + classOf[GroupIdNotFoundException]) + + // Deleting the real group ID should get GroupNotEmptyException + assertTrue(deleteResult.deletedGroups().containsKey(testGroupId)) + assertFutureExceptionTypeEquals(deleteResult.deletedGroups().get(testGroupId), + classOf[GroupNotEmptyException]) + + // Test delete one correct static member + removeMembersResult = client.removeMembersFromConsumerGroup(testGroupId, new RemoveMembersFromConsumerGroupOptions( + Collections.singleton(new MemberToRemove(testInstanceId1)) + )) + + assertNull(removeMembersResult.all().get()) + val validMemberFuture = removeMembersResult.memberResult(new MemberToRemove(testInstanceId1)) + assertNull(validMemberFuture.get()) + + val describeTestGroupResult = client.describeConsumerGroups(Seq(testGroupId).asJava, + new DescribeConsumerGroupsOptions().includeAuthorizedOperations(true)) + assertEquals(1, describeTestGroupResult.describedGroups().size()) + + testGroupDescription = describeTestGroupResult.describedGroups().get(testGroupId).get() + + assertEquals(testGroupId, testGroupDescription.groupId) + assertFalse(testGroupDescription.isSimpleConsumerGroup) + assertEquals(consumerSet.size - 1, testGroupDescription.members().size()) + + // Delete all active members remaining (a static member + a dynamic member) + removeMembersResult = client.removeMembersFromConsumerGroup(testGroupId, new RemoveMembersFromConsumerGroupOptions()) + assertNull(removeMembersResult.all().get()) + + // The group should contain no members now. 
+ testGroupDescription = client.describeConsumerGroups(Seq(testGroupId).asJava, + new DescribeConsumerGroupsOptions().includeAuthorizedOperations(true)) + .describedGroups().get(testGroupId).get() + assertTrue(testGroupDescription.members().isEmpty) + + // Consumer group deletion on empty group should succeed + deleteResult = client.deleteConsumerGroups(Seq(testGroupId).asJava) + assertEquals(1, deleteResult.deletedGroups().size()) + + assertTrue(deleteResult.deletedGroups().containsKey(testGroupId)) + assertNull(deleteResult.deletedGroups().get(testGroupId).get()) + } finally { + consumerThreads.foreach { + case consumerThread => + consumerThread.interrupt() + consumerThread.join() + } + } + } finally { + consumerSet.zip(groupInstanceSet).foreach(zipped => Utils.closeQuietly(zipped._1, zipped._2)) + } + } finally { + Utils.closeQuietly(client, "adminClient") + } + } + + @Test + def testDeleteConsumerGroupOffsets(): Unit = { + val config = createConfig + client = Admin.create(config) + try { + val testTopicName = "test_topic" + val testGroupId = "test_group_id" + val testClientId = "test_client_id" + val fakeGroupId = "fake_group_id" + + val tp1 = new TopicPartition(testTopicName, 0) + val tp2 = new TopicPartition("foo", 0) + + client.createTopics(Collections.singleton( + new NewTopic(testTopicName, 1, 1.toShort))).all().get() + waitForTopics(client, List(testTopicName), List()) + + val producer = createProducer() + try { + producer.send(new ProducerRecord(testTopicName, 0, null, null)).get() + } finally { + Utils.closeQuietly(producer, "producer") + } + + val newConsumerConfig = new Properties(consumerConfig) + newConsumerConfig.setProperty(ConsumerConfig.GROUP_ID_CONFIG, testGroupId) + newConsumerConfig.setProperty(ConsumerConfig.CLIENT_ID_CONFIG, testClientId) + // Increase timeouts to avoid having a rebalance during the test + newConsumerConfig.setProperty(ConsumerConfig.MAX_POLL_INTERVAL_MS_CONFIG, Integer.MAX_VALUE.toString) + newConsumerConfig.setProperty(ConsumerConfig.SESSION_TIMEOUT_MS_CONFIG, Defaults.GroupMaxSessionTimeoutMs.toString) + val consumer = createConsumer(configOverrides = newConsumerConfig) + + try { + TestUtils.subscribeAndWaitForRecords(testTopicName, consumer) + consumer.commitSync() + + // Test offset deletion while consuming + val offsetDeleteResult = client.deleteConsumerGroupOffsets(testGroupId, Set(tp1, tp2).asJava) + + // Top level error will equal to the first partition level error + assertFutureExceptionTypeEquals(offsetDeleteResult.all(), classOf[GroupSubscribedToTopicException]) + assertFutureExceptionTypeEquals(offsetDeleteResult.partitionResult(tp1), + classOf[GroupSubscribedToTopicException]) + assertFutureExceptionTypeEquals(offsetDeleteResult.partitionResult(tp2), + classOf[UnknownTopicOrPartitionException]) + + // Test the fake group ID + val fakeDeleteResult = client.deleteConsumerGroupOffsets(fakeGroupId, Set(tp1, tp2).asJava) + + assertFutureExceptionTypeEquals(fakeDeleteResult.all(), classOf[GroupIdNotFoundException]) + assertFutureExceptionTypeEquals(fakeDeleteResult.partitionResult(tp1), + classOf[GroupIdNotFoundException]) + assertFutureExceptionTypeEquals(fakeDeleteResult.partitionResult(tp2), + classOf[GroupIdNotFoundException]) + + } finally { + Utils.closeQuietly(consumer, "consumer") + } + + // Test offset deletion when group is empty + val offsetDeleteResult = client.deleteConsumerGroupOffsets(testGroupId, Set(tp1, tp2).asJava) + + assertFutureExceptionTypeEquals(offsetDeleteResult.all(), + classOf[UnknownTopicOrPartitionException]) + 
assertNull(offsetDeleteResult.partitionResult(tp1).get()) + assertFutureExceptionTypeEquals(offsetDeleteResult.partitionResult(tp2), + classOf[UnknownTopicOrPartitionException]) + } finally { + Utils.closeQuietly(client, "adminClient") + } + } + + @Test + def testElectPreferredLeaders(): Unit = { + client = Admin.create(createConfig) + + val prefer0 = Seq(0, 1, 2) + val prefer1 = Seq(1, 2, 0) + val prefer2 = Seq(2, 0, 1) + + val partition1 = new TopicPartition("elect-preferred-leaders-topic-1", 0) + TestUtils.createTopic(zkClient, partition1.topic, Map[Int, Seq[Int]](partition1.partition -> prefer0), servers) + + val partition2 = new TopicPartition("elect-preferred-leaders-topic-2", 0) + TestUtils.createTopic(zkClient, partition2.topic, Map[Int, Seq[Int]](partition2.partition -> prefer0), servers) + + def preferredLeader(topicPartition: TopicPartition): Int = { + val partitionMetadata = getTopicMetadata(client, topicPartition.topic).partitions.get(topicPartition.partition) + val preferredLeaderMetadata = partitionMetadata.replicas.get(0) + preferredLeaderMetadata.id + } + + /** Changes the preferred leader without changing the current leader. */ + def changePreferredLeader(newAssignment: Seq[Int]) = { + val preferred = newAssignment.head + val prior1 = zkClient.getLeaderForPartition(partition1).get + val prior2 = zkClient.getLeaderForPartition(partition2).get + + var m = Map.empty[TopicPartition, Seq[Int]] + + if (prior1 != preferred) + m += partition1 -> newAssignment + if (prior2 != preferred) + m += partition2 -> newAssignment + + zkClient.createPartitionReassignment(m) + TestUtils.waitUntilTrue( + () => preferredLeader(partition1) == preferred && preferredLeader(partition2) == preferred, + s"Expected preferred leader to become $preferred, but is ${preferredLeader(partition1)} and ${preferredLeader(partition2)}", + 10000) + // Check the leader hasn't moved + TestUtils.assertLeader(client, partition1, prior1) + TestUtils.assertLeader(client, partition2, prior2) + } + + // Check current leaders are 0 + TestUtils.assertLeader(client, partition1, 0) + TestUtils.assertLeader(client, partition2, 0) + + // Noop election + var electResult = client.electLeaders(ElectionType.PREFERRED, Set(partition1).asJava) + var exception = electResult.partitions.get.get(partition1).get + assertEquals(classOf[ElectionNotNeededException], exception.getClass) + TestUtils.assertLeader(client, partition1, 0) + + // Noop election with null partitions + electResult = client.electLeaders(ElectionType.PREFERRED, null) + assertTrue(electResult.partitions.get.isEmpty) + TestUtils.assertLeader(client, partition1, 0) + TestUtils.assertLeader(client, partition2, 0) + + // Now change the preferred leader to 1 + changePreferredLeader(prefer1) + + // meaningful election + electResult = client.electLeaders(ElectionType.PREFERRED, Set(partition1).asJava) + assertEquals(Set(partition1).asJava, electResult.partitions.get.keySet) + assertFalse(electResult.partitions.get.get(partition1).isPresent) + TestUtils.assertLeader(client, partition1, 1) + + // topic 2 unchanged + assertFalse(electResult.partitions.get.containsKey(partition2)) + TestUtils.assertLeader(client, partition2, 0) + + // meaningful election with null partitions + electResult = client.electLeaders(ElectionType.PREFERRED, null) + assertEquals(Set(partition2), electResult.partitions.get.keySet.asScala) + assertFalse(electResult.partitions.get.get(partition2).isPresent) + TestUtils.assertLeader(client, partition2, 1) + + // unknown topic + val unknownPartition = new 
TopicPartition("topic-does-not-exist", 0) + electResult = client.electLeaders(ElectionType.PREFERRED, Set(unknownPartition).asJava) + assertEquals(Set(unknownPartition).asJava, electResult.partitions.get.keySet) + exception = electResult.partitions.get.get(unknownPartition).get + assertEquals(classOf[UnknownTopicOrPartitionException], exception.getClass) + assertEquals("The partition does not exist.", exception.getMessage) + TestUtils.assertLeader(client, partition1, 1) + TestUtils.assertLeader(client, partition2, 1) + + // Now change the preferred leader to 2 + changePreferredLeader(prefer2) + + // mixed results + electResult = client.electLeaders(ElectionType.PREFERRED, Set(unknownPartition, partition1).asJava) + assertEquals(Set(unknownPartition, partition1).asJava, electResult.partitions.get.keySet) + TestUtils.assertLeader(client, partition1, 2) + TestUtils.assertLeader(client, partition2, 1) + exception = electResult.partitions.get.get(unknownPartition).get + assertEquals(classOf[UnknownTopicOrPartitionException], exception.getClass) + assertEquals("The partition does not exist.", exception.getMessage) + + // elect preferred leader for partition 2 + electResult = client.electLeaders(ElectionType.PREFERRED, Set(partition2).asJava) + assertEquals(Set(partition2).asJava, electResult.partitions.get.keySet) + assertFalse(electResult.partitions.get.get(partition2).isPresent) + TestUtils.assertLeader(client, partition2, 2) + + // Now change the preferred leader to 1 + changePreferredLeader(prefer1) + // but shut it down... + servers(1).shutdown() + TestUtils.waitForBrokersOutOfIsr(client, Set(partition1, partition2), Set(1)) + + // ... now what happens if we try to elect the preferred leader and it's down? + val shortTimeout = new ElectLeadersOptions().timeoutMs(10000) + electResult = client.electLeaders(ElectionType.PREFERRED, Set(partition1).asJava, shortTimeout) + assertEquals(Set(partition1).asJava, electResult.partitions.get.keySet) + exception = electResult.partitions.get.get(partition1).get + assertEquals(classOf[PreferredLeaderNotAvailableException], exception.getClass) + assertTrue(exception.getMessage.contains( + "Failed to elect leader for partition elect-preferred-leaders-topic-1-0 under strategy PreferredReplicaPartitionLeaderElectionStrategy"), + s"Wrong message ${exception.getMessage}") + TestUtils.assertLeader(client, partition1, 2) + + // preferred leader unavailable with null argument + electResult = client.electLeaders(ElectionType.PREFERRED, null, shortTimeout) + + exception = electResult.partitions.get.get(partition1).get + assertEquals(classOf[PreferredLeaderNotAvailableException], exception.getClass) + assertTrue(exception.getMessage.contains( + "Failed to elect leader for partition elect-preferred-leaders-topic-1-0 under strategy PreferredReplicaPartitionLeaderElectionStrategy"), + s"Wrong message ${exception.getMessage}") + + exception = electResult.partitions.get.get(partition2).get + assertEquals(classOf[PreferredLeaderNotAvailableException], exception.getClass) + assertTrue(exception.getMessage.contains( + "Failed to elect leader for partition elect-preferred-leaders-topic-2-0 under strategy PreferredReplicaPartitionLeaderElectionStrategy"), + s"Wrong message ${exception.getMessage}") + + TestUtils.assertLeader(client, partition1, 2) + TestUtils.assertLeader(client, partition2, 2) + } + + @Test + def testElectUncleanLeadersForOnePartition(): Unit = { + // Case: unclean leader election with one topic partition + client = Admin.create(createConfig) + + val broker1 
= 1 + val broker2 = 2 + val assignment1 = Seq(broker1, broker2) + + val partition1 = new TopicPartition("unclean-test-topic-1", 0) + TestUtils.createTopic(zkClient, partition1.topic, Map[Int, Seq[Int]](partition1.partition -> assignment1), servers) + + TestUtils.assertLeader(client, partition1, broker1) + + servers(broker2).shutdown() + TestUtils.waitForBrokersOutOfIsr(client, Set(partition1), Set(broker2)) + servers(broker1).shutdown() + TestUtils.assertNoLeader(client, partition1) + servers(broker2).startup() + + val electResult = client.electLeaders(ElectionType.UNCLEAN, Set(partition1).asJava) + assertFalse(electResult.partitions.get.get(partition1).isPresent) + TestUtils.assertLeader(client, partition1, broker2) + } + + @Test + def testElectUncleanLeadersForManyPartitions(): Unit = { + // Case: unclean leader election with many topic partitions + client = Admin.create(createConfig) + + val broker1 = 1 + val broker2 = 2 + val assignment1 = Seq(broker1, broker2) + val assignment2 = Seq(broker1, broker2) + + val topic = "unclean-test-topic-1" + val partition1 = new TopicPartition(topic, 0) + val partition2 = new TopicPartition(topic, 1) + + TestUtils.createTopic( + zkClient, + topic, + Map(partition1.partition -> assignment1, partition2.partition -> assignment2), + servers + ) + + TestUtils.assertLeader(client, partition1, broker1) + TestUtils.assertLeader(client, partition2, broker1) + + servers(broker2).shutdown() + TestUtils.waitForBrokersOutOfIsr(client, Set(partition1, partition2), Set(broker2)) + servers(broker1).shutdown() + TestUtils.assertNoLeader(client, partition1) + TestUtils.assertNoLeader(client, partition2) + servers(broker2).startup() + + val electResult = client.electLeaders(ElectionType.UNCLEAN, Set(partition1, partition2).asJava) + assertFalse(electResult.partitions.get.get(partition1).isPresent) + assertFalse(electResult.partitions.get.get(partition2).isPresent) + TestUtils.assertLeader(client, partition1, broker2) + TestUtils.assertLeader(client, partition2, broker2) + } + + @Test + def testElectUncleanLeadersForAllPartitions(): Unit = { + // Case: noop unclean leader election and valid unclean leader election for all partitions + client = Admin.create(createConfig) + + val broker1 = 1 + val broker2 = 2 + val broker3 = 0 + val assignment1 = Seq(broker1, broker2) + val assignment2 = Seq(broker1, broker3) + + val topic = "unclean-test-topic-1" + val partition1 = new TopicPartition(topic, 0) + val partition2 = new TopicPartition(topic, 1) + + TestUtils.createTopic( + zkClient, + topic, + Map(partition1.partition -> assignment1, partition2.partition -> assignment2), + servers + ) + + TestUtils.assertLeader(client, partition1, broker1) + TestUtils.assertLeader(client, partition2, broker1) + + servers(broker2).shutdown() + TestUtils.waitForBrokersOutOfIsr(client, Set(partition1), Set(broker2)) + servers(broker1).shutdown() + TestUtils.assertNoLeader(client, partition1) + TestUtils.assertLeader(client, partition2, broker3) + servers(broker2).startup() + + val electResult = client.electLeaders(ElectionType.UNCLEAN, null) + assertFalse(electResult.partitions.get.get(partition1).isPresent) + assertFalse(electResult.partitions.get.containsKey(partition2)) + TestUtils.assertLeader(client, partition1, broker2) + TestUtils.assertLeader(client, partition2, broker3) + } + + @Test + def testElectUncleanLeadersForUnknownPartitions(): Unit = { + // Case: unclean leader election for unknown topic + client = Admin.create(createConfig) + + val broker1 = 1 + val broker2 = 2 + val 
assignment1 = Seq(broker1, broker2) + + val topic = "unclean-test-topic-1" + val unknownPartition = new TopicPartition(topic, 1) + val unknownTopic = new TopicPartition("unknown-topic", 0) + + TestUtils.createTopic( + zkClient, + topic, + Map(0 -> assignment1), + servers + ) + + TestUtils.assertLeader(client, new TopicPartition(topic, 0), broker1) + + val electResult = client.electLeaders(ElectionType.UNCLEAN, Set(unknownPartition, unknownTopic).asJava) + assertTrue(electResult.partitions.get.get(unknownPartition).get.isInstanceOf[UnknownTopicOrPartitionException]) + assertTrue(electResult.partitions.get.get(unknownTopic).get.isInstanceOf[UnknownTopicOrPartitionException]) + } + + @Test + def testElectUncleanLeadersWhenNoLiveBrokers(): Unit = { + // Case: unclean leader election with no live brokers + client = Admin.create(createConfig) + + val broker1 = 1 + val broker2 = 2 + val assignment1 = Seq(broker1, broker2) + + val topic = "unclean-test-topic-1" + val partition1 = new TopicPartition(topic, 0) + + TestUtils.createTopic( + zkClient, + topic, + Map(partition1.partition -> assignment1), + servers + ) + + TestUtils.assertLeader(client, partition1, broker1) + + servers(broker2).shutdown() + TestUtils.waitForBrokersOutOfIsr(client, Set(partition1), Set(broker2)) + servers(broker1).shutdown() + TestUtils.assertNoLeader(client, partition1) + + val electResult = client.electLeaders(ElectionType.UNCLEAN, Set(partition1).asJava) + assertTrue(electResult.partitions.get.get(partition1).get.isInstanceOf[EligibleLeadersNotAvailableException]) + } + + @Test + def testElectUncleanLeadersNoop(): Unit = { + // Case: noop unclean leader election with explicit topic partitions + client = Admin.create(createConfig) + + val broker1 = 1 + val broker2 = 2 + val assignment1 = Seq(broker1, broker2) + + val topic = "unclean-test-topic-1" + val partition1 = new TopicPartition(topic, 0) + + TestUtils.createTopic( + zkClient, + topic, + Map(partition1.partition -> assignment1), + servers + ) + + TestUtils.assertLeader(client, partition1, broker1) + + servers(broker1).shutdown() + TestUtils.assertLeader(client, partition1, broker2) + servers(broker1).startup() + + val electResult = client.electLeaders(ElectionType.UNCLEAN, Set(partition1).asJava) + assertTrue(electResult.partitions.get.get(partition1).get.isInstanceOf[ElectionNotNeededException]) + } + + @Test + def testElectUncleanLeadersAndNoop(): Unit = { + // Case: one noop unclean leader election and one valid unclean leader election + client = Admin.create(createConfig) + + val broker1 = 1 + val broker2 = 2 + val broker3 = 0 + val assignment1 = Seq(broker1, broker2) + val assignment2 = Seq(broker1, broker3) + + val topic = "unclean-test-topic-1" + val partition1 = new TopicPartition(topic, 0) + val partition2 = new TopicPartition(topic, 1) + + TestUtils.createTopic( + zkClient, + topic, + Map(partition1.partition -> assignment1, partition2.partition -> assignment2), + servers + ) + + TestUtils.assertLeader(client, partition1, broker1) + TestUtils.assertLeader(client, partition2, broker1) + + servers(broker2).shutdown() + TestUtils.waitForBrokersOutOfIsr(client, Set(partition1), Set(broker2)) + servers(broker1).shutdown() + TestUtils.assertNoLeader(client, partition1) + TestUtils.assertLeader(client, partition2, broker3) + servers(broker2).startup() + + val electResult = client.electLeaders(ElectionType.UNCLEAN, Set(partition1, partition2).asJava) + assertFalse(electResult.partitions.get.get(partition1).isPresent) + 
assertTrue(electResult.partitions.get.get(partition2).get.isInstanceOf[ElectionNotNeededException]) + TestUtils.assertLeader(client, partition1, broker2) + TestUtils.assertLeader(client, partition2, broker3) + } + + @Test + def testListReassignmentsDoesNotShowNonReassigningPartitions(): Unit = { + client = Admin.create(createConfig) + + // Create topics + val topic = "list-reassignments-no-reassignments" + createTopic(topic, replicationFactor = 3) + val tp = new TopicPartition(topic, 0) + + val reassignmentsMap = client.listPartitionReassignments(Set(tp).asJava).reassignments().get() + assertEquals(0, reassignmentsMap.size()) + + val allReassignmentsMap = client.listPartitionReassignments().reassignments().get() + assertEquals(0, allReassignmentsMap.size()) + } + + @Test + def testListReassignmentsDoesNotShowDeletedPartitions(): Unit = { + client = Admin.create(createConfig) + + val topic = "list-reassignments-no-reassignments" + val tp = new TopicPartition(topic, 0) + + val reassignmentsMap = client.listPartitionReassignments(Set(tp).asJava).reassignments().get() + assertEquals(0, reassignmentsMap.size()) + + val allReassignmentsMap = client.listPartitionReassignments().reassignments().get() + assertEquals(0, allReassignmentsMap.size()) + } + + @Test + def testValidIncrementalAlterConfigs(): Unit = { + client = Admin.create(createConfig) + + // Create topics + val topic1 = "incremental-alter-configs-topic-1" + val topic1Resource = new ConfigResource(ConfigResource.Type.TOPIC, topic1) + val topic1CreateConfigs = new Properties + topic1CreateConfigs.setProperty(LogConfig.RetentionMsProp, "60000000") + topic1CreateConfigs.setProperty(LogConfig.CleanupPolicyProp, LogConfig.Compact) + createTopic(topic1, numPartitions = 1, replicationFactor = 1, topic1CreateConfigs) + + val topic2 = "incremental-alter-configs-topic-2" + val topic2Resource = new ConfigResource(ConfigResource.Type.TOPIC, topic2) + createTopic(topic2) + + // Alter topic configs + var topic1AlterConfigs = Seq( + new AlterConfigOp(new ConfigEntry(LogConfig.FlushMsProp, "1000"), AlterConfigOp.OpType.SET), + new AlterConfigOp(new ConfigEntry(LogConfig.CleanupPolicyProp, LogConfig.Delete), AlterConfigOp.OpType.APPEND), + new AlterConfigOp(new ConfigEntry(LogConfig.RetentionMsProp, ""), AlterConfigOp.OpType.DELETE) + ).asJavaCollection + + // Test SET and APPEND on previously unset properties + var topic2AlterConfigs = Seq( + new AlterConfigOp(new ConfigEntry(LogConfig.MinCleanableDirtyRatioProp, "0.9"), AlterConfigOp.OpType.SET), + new AlterConfigOp(new ConfigEntry(LogConfig.CompressionTypeProp, "lz4"), AlterConfigOp.OpType.SET), + new AlterConfigOp(new ConfigEntry(LogConfig.CleanupPolicyProp, LogConfig.Compact), AlterConfigOp.OpType.APPEND) + ).asJavaCollection + + var alterResult = client.incrementalAlterConfigs(Map( + topic1Resource -> topic1AlterConfigs, + topic2Resource -> topic2AlterConfigs + ).asJava) + + assertEquals(Set(topic1Resource, topic2Resource).asJava, alterResult.values.keySet) + alterResult.all.get + + // Verify that topics were updated correctly + var describeResult = client.describeConfigs(Seq(topic1Resource, topic2Resource).asJava) + var configs = describeResult.all.get + + assertEquals(2, configs.size) + + assertEquals("1000", configs.get(topic1Resource).get(LogConfig.FlushMsProp).value) + assertEquals("compact,delete", configs.get(topic1Resource).get(LogConfig.CleanupPolicyProp).value) + assertEquals((Defaults.LogRetentionHours * 60 * 60 * 1000).toString, 
configs.get(topic1Resource).get(LogConfig.RetentionMsProp).value) + + assertEquals("0.9", configs.get(topic2Resource).get(LogConfig.MinCleanableDirtyRatioProp).value) + assertEquals("lz4", configs.get(topic2Resource).get(LogConfig.CompressionTypeProp).value) + assertEquals("delete,compact", configs.get(topic2Resource).get(LogConfig.CleanupPolicyProp).value) + + //verify subtract operation, including from an empty property + topic1AlterConfigs = Seq( + new AlterConfigOp(new ConfigEntry(LogConfig.CleanupPolicyProp, LogConfig.Compact), AlterConfigOp.OpType.SUBTRACT), + new AlterConfigOp(new ConfigEntry(LogConfig.LeaderReplicationThrottledReplicasProp, "0"), AlterConfigOp.OpType.SUBTRACT) + ).asJava + + // subtract all from this list property + topic2AlterConfigs = Seq( + new AlterConfigOp(new ConfigEntry(LogConfig.CleanupPolicyProp, LogConfig.Compact + "," + LogConfig.Delete), AlterConfigOp.OpType.SUBTRACT) + ).asJavaCollection + + alterResult = client.incrementalAlterConfigs(Map( + topic1Resource -> topic1AlterConfigs, + topic2Resource -> topic2AlterConfigs + ).asJava) + assertEquals(Set(topic1Resource, topic2Resource).asJava, alterResult.values.keySet) + alterResult.all.get + + // Verify that topics were updated correctly + describeResult = client.describeConfigs(Seq(topic1Resource, topic2Resource).asJava) + configs = describeResult.all.get + + assertEquals(2, configs.size) + + assertEquals("delete", configs.get(topic1Resource).get(LogConfig.CleanupPolicyProp).value) + assertEquals("1000", configs.get(topic1Resource).get(LogConfig.FlushMsProp).value) // verify previous change is still intact + assertEquals("", configs.get(topic1Resource).get(LogConfig.LeaderReplicationThrottledReplicasProp).value) + assertEquals("", configs.get(topic2Resource).get(LogConfig.CleanupPolicyProp).value ) + + // Alter topics with validateOnly=true + topic1AlterConfigs = Seq( + new AlterConfigOp(new ConfigEntry(LogConfig.CleanupPolicyProp, LogConfig.Compact), AlterConfigOp.OpType.APPEND) + ).asJava + + alterResult = client.incrementalAlterConfigs(Map( + topic1Resource -> topic1AlterConfigs + ).asJava, new AlterConfigsOptions().validateOnly(true)) + alterResult.all.get + + // Verify that topics were not updated due to validateOnly = true + describeResult = client.describeConfigs(Seq(topic1Resource).asJava) + configs = describeResult.all.get + + assertEquals("delete", configs.get(topic1Resource).get(LogConfig.CleanupPolicyProp).value) + + //Alter topics with validateOnly=true with invalid configs + topic1AlterConfigs = Seq( + new AlterConfigOp(new ConfigEntry(LogConfig.CompressionTypeProp, "zip"), AlterConfigOp.OpType.SET) + ).asJava + + alterResult = client.incrementalAlterConfigs(Map( + topic1Resource -> topic1AlterConfigs + ).asJava, new AlterConfigsOptions().validateOnly(true)) + + assertFutureExceptionTypeEquals(alterResult.values().get(topic1Resource), classOf[InvalidRequestException], + Some("Invalid config value for resource")) + } + + @Test + def testIncrementalAlterConfigsDeleteAndSetBrokerConfigs(): Unit = { + client = Admin.create(createConfig) + val broker0Resource = new ConfigResource(ConfigResource.Type.BROKER, "0") + client.incrementalAlterConfigs(Map(broker0Resource -> + Seq(new AlterConfigOp(new ConfigEntry(DynamicConfig.Broker.LeaderReplicationThrottledRateProp, "123"), + AlterConfigOp.OpType.SET), + new AlterConfigOp(new ConfigEntry(DynamicConfig.Broker.FollowerReplicationThrottledRateProp, "456"), + AlterConfigOp.OpType.SET) + ).asJavaCollection).asJava).all().get() + TestUtils.waitUntilTrue(() 
=> { + val broker0Configs = client.describeConfigs(Seq(broker0Resource).asJava). + all().get().get(broker0Resource).entries().asScala.map { + case entry => (entry.name, entry.value) + }.toMap + ("123".equals(broker0Configs.getOrElse(DynamicConfig.Broker.LeaderReplicationThrottledRateProp, "")) && + "456".equals(broker0Configs.getOrElse(DynamicConfig.Broker.FollowerReplicationThrottledRateProp, ""))) + }, "Expected to see the broker properties we just set", pause=25) + client.incrementalAlterConfigs(Map(broker0Resource -> + Seq(new AlterConfigOp(new ConfigEntry(DynamicConfig.Broker.LeaderReplicationThrottledRateProp, ""), + AlterConfigOp.OpType.DELETE), + new AlterConfigOp(new ConfigEntry(DynamicConfig.Broker.FollowerReplicationThrottledRateProp, "654"), + AlterConfigOp.OpType.SET), + new AlterConfigOp(new ConfigEntry(DynamicConfig.Broker.ReplicaAlterLogDirsIoMaxBytesPerSecondProp, "987"), + AlterConfigOp.OpType.SET) + ).asJavaCollection).asJava).all().get() + TestUtils.waitUntilTrue(() => { + val broker0Configs = client.describeConfigs(Seq(broker0Resource).asJava). + all().get().get(broker0Resource).entries().asScala.map { + case entry => (entry.name, entry.value) + }.toMap + ("".equals(broker0Configs.getOrElse(DynamicConfig.Broker.LeaderReplicationThrottledRateProp, "")) && + "654".equals(broker0Configs.getOrElse(DynamicConfig.Broker.FollowerReplicationThrottledRateProp, "")) && + "987".equals(broker0Configs.getOrElse(DynamicConfig.Broker.ReplicaAlterLogDirsIoMaxBytesPerSecondProp, ""))) + }, "Expected to see the broker properties we just modified", pause=25) + } + + @Test + def testIncrementalAlterConfigsDeleteBrokerConfigs(): Unit = { + client = Admin.create(createConfig) + val broker0Resource = new ConfigResource(ConfigResource.Type.BROKER, "0") + client.incrementalAlterConfigs(Map(broker0Resource -> + Seq(new AlterConfigOp(new ConfigEntry(DynamicConfig.Broker.LeaderReplicationThrottledRateProp, "123"), + AlterConfigOp.OpType.SET), + new AlterConfigOp(new ConfigEntry(DynamicConfig.Broker.FollowerReplicationThrottledRateProp, "456"), + AlterConfigOp.OpType.SET), + new AlterConfigOp(new ConfigEntry(DynamicConfig.Broker.ReplicaAlterLogDirsIoMaxBytesPerSecondProp, "789"), + AlterConfigOp.OpType.SET) + ).asJavaCollection).asJava).all().get() + TestUtils.waitUntilTrue(() => { + val broker0Configs = client.describeConfigs(Seq(broker0Resource).asJava). + all().get().get(broker0Resource).entries().asScala.map { + case entry => (entry.name, entry.value) + }.toMap + ("123".equals(broker0Configs.getOrElse(DynamicConfig.Broker.LeaderReplicationThrottledRateProp, "")) && + "456".equals(broker0Configs.getOrElse(DynamicConfig.Broker.FollowerReplicationThrottledRateProp, "")) && + "789".equals(broker0Configs.getOrElse(DynamicConfig.Broker.ReplicaAlterLogDirsIoMaxBytesPerSecondProp, ""))) + }, "Expected to see the broker properties we just set", pause=25) + client.incrementalAlterConfigs(Map(broker0Resource -> + Seq(new AlterConfigOp(new ConfigEntry(DynamicConfig.Broker.LeaderReplicationThrottledRateProp, ""), + AlterConfigOp.OpType.DELETE), + new AlterConfigOp(new ConfigEntry(DynamicConfig.Broker.FollowerReplicationThrottledRateProp, ""), + AlterConfigOp.OpType.DELETE), + new AlterConfigOp(new ConfigEntry(DynamicConfig.Broker.ReplicaAlterLogDirsIoMaxBytesPerSecondProp, ""), + AlterConfigOp.OpType.DELETE) + ).asJavaCollection).asJava).all().get() + TestUtils.waitUntilTrue(() => { + val broker0Configs = client.describeConfigs(Seq(broker0Resource).asJava). 
+        all().get().get(broker0Resource).entries().asScala.map {
+        case entry => (entry.name, entry.value)
+      }.toMap
+      ("".equals(broker0Configs.getOrElse(DynamicConfig.Broker.LeaderReplicationThrottledRateProp, "")) &&
+        "".equals(broker0Configs.getOrElse(DynamicConfig.Broker.FollowerReplicationThrottledRateProp, "")) &&
+        "".equals(broker0Configs.getOrElse(DynamicConfig.Broker.ReplicaAlterLogDirsIoMaxBytesPerSecondProp, "")))
+    }, "Expected the broker properties we just removed to be deleted", pause=25)
+  }
+
+  @Test
+  def testInvalidIncrementalAlterConfigs(): Unit = {
+    client = Admin.create(createConfig)
+
+    // Create topics
+    val topic1 = "incremental-alter-configs-topic-1"
+    val topic1Resource = new ConfigResource(ConfigResource.Type.TOPIC, topic1)
+    createTopic(topic1)
+
+    val topic2 = "incremental-alter-configs-topic-2"
+    val topic2Resource = new ConfigResource(ConfigResource.Type.TOPIC, topic2)
+    createTopic(topic2)
+
+    // Add duplicate keys for topic1
+    var topic1AlterConfigs = Seq(
+      new AlterConfigOp(new ConfigEntry(LogConfig.MinCleanableDirtyRatioProp, "0.75"), AlterConfigOp.OpType.SET),
+      new AlterConfigOp(new ConfigEntry(LogConfig.MinCleanableDirtyRatioProp, "0.65"), AlterConfigOp.OpType.SET),
+      new AlterConfigOp(new ConfigEntry(LogConfig.CompressionTypeProp, "gzip"), AlterConfigOp.OpType.SET) // valid entry
+    ).asJavaCollection
+
+    // Add a valid config for topic2
+    var topic2AlterConfigs = Seq(
+      new AlterConfigOp(new ConfigEntry(LogConfig.MinCleanableDirtyRatioProp, "0.9"), AlterConfigOp.OpType.SET)
+    ).asJavaCollection
+
+    var alterResult = client.incrementalAlterConfigs(Map(
+      topic1Resource -> topic1AlterConfigs,
+      topic2Resource -> topic2AlterConfigs
+    ).asJava)
+    assertEquals(Set(topic1Resource, topic2Resource).asJava, alterResult.values.keySet)
+
+    // Expect an InvalidRequestException for topic1
+    assertFutureExceptionTypeEquals(alterResult.values().get(topic1Resource), classOf[InvalidRequestException],
+      Some("Error due to duplicate config keys"))
+
+    // The operation should succeed for topic2
+    alterResult.values().get(topic2Resource).get()
+
+    // Verify that topic1's config was not updated, while topic2's config was updated
+    val describeResult = client.describeConfigs(Seq(topic1Resource, topic2Resource).asJava)
+    val configs = describeResult.all.get
+    assertEquals(2, configs.size)
+
+    assertEquals(Defaults.LogCleanerMinCleanRatio.toString, configs.get(topic1Resource).get(LogConfig.MinCleanableDirtyRatioProp).value)
+    assertEquals(Defaults.CompressionType.toString, configs.get(topic1Resource).get(LogConfig.CompressionTypeProp).value)
+    assertEquals("0.9", configs.get(topic2Resource).get(LogConfig.MinCleanableDirtyRatioProp).value)
+
+    // Check invalid use of the APPEND/SUBTRACT operation types
+    topic1AlterConfigs = Seq(
+      new AlterConfigOp(new ConfigEntry(LogConfig.CompressionTypeProp, "gzip"), AlterConfigOp.OpType.APPEND)
+    ).asJavaCollection
+
+    topic2AlterConfigs = Seq(
+      new AlterConfigOp(new ConfigEntry(LogConfig.CompressionTypeProp, "snappy"), AlterConfigOp.OpType.SUBTRACT)
+    ).asJavaCollection
+
+    alterResult = client.incrementalAlterConfigs(Map(
+      topic1Resource -> topic1AlterConfigs,
+      topic2Resource -> topic2AlterConfigs
+    ).asJava)
+    assertEquals(Set(topic1Resource, topic2Resource).asJava, alterResult.values.keySet)
+
+    assertFutureExceptionTypeEquals(alterResult.values().get(topic1Resource), classOf[InvalidRequestException],
+      Some("Config value append is not allowed for config"))
+
+    assertFutureExceptionTypeEquals(alterResult.values().get(topic2Resource),
classOf[InvalidRequestException], + Some("Config value subtract is not allowed for config")) + + + //try to add invalid config + topic1AlterConfigs = Seq( + new AlterConfigOp(new ConfigEntry(LogConfig.MinCleanableDirtyRatioProp, "1.1"), AlterConfigOp.OpType.SET) + ).asJavaCollection + + alterResult = client.incrementalAlterConfigs(Map( + topic1Resource -> topic1AlterConfigs + ).asJava) + assertEquals(Set(topic1Resource).asJava, alterResult.values.keySet) + + assertFutureExceptionTypeEquals(alterResult.values().get(topic1Resource), classOf[InvalidRequestException], + Some("Invalid config value for resource")) + } + + @Test + def testInvalidAlterPartitionReassignments(): Unit = { + client = Admin.create(createConfig) + val topic = "alter-reassignments-topic-1" + val tp1 = new TopicPartition(topic, 0) + val tp2 = new TopicPartition(topic, 1) + val tp3 = new TopicPartition(topic, 2) + createTopic(topic, numPartitions = 4) + + + val validAssignment = Optional.of(new NewPartitionReassignment( + (0 until brokerCount).map(_.asInstanceOf[Integer]).asJava + )) + + val nonExistentTp1 = new TopicPartition("topicA", 0) + val nonExistentTp2 = new TopicPartition(topic, 4) + val nonExistentPartitionsResult = client.alterPartitionReassignments(Map( + tp1 -> validAssignment, + tp2 -> validAssignment, + tp3 -> validAssignment, + nonExistentTp1 -> validAssignment, + nonExistentTp2 -> validAssignment + ).asJava).values() + assertFutureExceptionTypeEquals(nonExistentPartitionsResult.get(nonExistentTp1), classOf[UnknownTopicOrPartitionException]) + assertFutureExceptionTypeEquals(nonExistentPartitionsResult.get(nonExistentTp2), classOf[UnknownTopicOrPartitionException]) + + val extraNonExistentReplica = Optional.of(new NewPartitionReassignment((0 until brokerCount + 1).map(_.asInstanceOf[Integer]).asJava)) + val negativeIdReplica = Optional.of(new NewPartitionReassignment(Seq(-3, -2, -1).map(_.asInstanceOf[Integer]).asJava)) + val duplicateReplica = Optional.of(new NewPartitionReassignment(Seq(0, 1, 1).map(_.asInstanceOf[Integer]).asJava)) + val invalidReplicaResult = client.alterPartitionReassignments(Map( + tp1 -> extraNonExistentReplica, + tp2 -> negativeIdReplica, + tp3 -> duplicateReplica + ).asJava).values() + assertFutureExceptionTypeEquals(invalidReplicaResult.get(tp1), classOf[InvalidReplicaAssignmentException]) + assertFutureExceptionTypeEquals(invalidReplicaResult.get(tp2), classOf[InvalidReplicaAssignmentException]) + assertFutureExceptionTypeEquals(invalidReplicaResult.get(tp3), classOf[InvalidReplicaAssignmentException]) + } + + @Test + def testLongTopicNames(): Unit = { + val client = Admin.create(createConfig) + val longTopicName = String.join("", Collections.nCopies(249, "x")); + val invalidTopicName = String.join("", Collections.nCopies(250, "x")); + val newTopics2 = Seq(new NewTopic(invalidTopicName, 3, 3.toShort), + new NewTopic(longTopicName, 3, 3.toShort)) + val results = client.createTopics(newTopics2.asJava).values() + assertTrue(results.containsKey(longTopicName)) + results.get(longTopicName).get() + assertTrue(results.containsKey(invalidTopicName)) + assertFutureExceptionTypeEquals(results.get(invalidTopicName), classOf[InvalidTopicException]) + assertFutureExceptionTypeEquals(client.alterReplicaLogDirs( + Map(new TopicPartitionReplica(longTopicName, 0, 0) -> servers(0).config.logDirs(0)).asJava).all(), + classOf[InvalidTopicException]) + client.close() + } + + // Verify that createTopics and alterConfigs fail with null values + @Test + def testNullConfigs(): Unit = { + + def 
validateLogConfig(compressionType: String): Unit = { + val logConfig = zkClient.getLogConfigs(Set(topic), Collections.emptyMap[String, AnyRef])._1(topic) + + assertEquals(compressionType, logConfig.originals.get(LogConfig.CompressionTypeProp)) + assertNull(logConfig.originals.get(LogConfig.RetentionBytesProp)) + assertEquals(Defaults.LogRetentionBytes, logConfig.retentionSize) + } + + client = Admin.create(createConfig) + val invalidConfigs = Map[String, String]( + LogConfig.RetentionBytesProp -> null, + LogConfig.CompressionTypeProp -> "producer" + ).asJava + val newTopic = new NewTopic(topic, 2, brokerCount.toShort) + val e1 = assertThrows(classOf[ExecutionException], + () => client.createTopics(Collections.singletonList(newTopic.configs(invalidConfigs))).all.get()) + assertTrue(e1.getCause.isInstanceOf[InvalidRequestException], + s"Unexpected exception ${e1.getCause.getClass}") + + val validConfigs = Map[String, String](LogConfig.CompressionTypeProp -> "producer").asJava + client.createTopics(Collections.singletonList(newTopic.configs(validConfigs))).all.get() + waitForTopics(client, expectedPresent = Seq(topic), expectedMissing = List()) + validateLogConfig(compressionType = "producer") + + val topicResource = new ConfigResource(ConfigResource.Type.TOPIC, topic) + val alterOps = Seq( + new AlterConfigOp(new ConfigEntry(LogConfig.RetentionBytesProp, null), AlterConfigOp.OpType.SET), + new AlterConfigOp(new ConfigEntry(LogConfig.CompressionTypeProp, "lz4"), AlterConfigOp.OpType.SET) + ) + val e2 = assertThrows(classOf[ExecutionException], + () => client.incrementalAlterConfigs(Map(topicResource -> alterOps.asJavaCollection).asJava).all.get) + assertTrue(e2.getCause.isInstanceOf[InvalidRequestException], + s"Unexpected exception ${e2.getCause.getClass}") + validateLogConfig(compressionType = "producer") + } + + @Test + def testDescribeConfigsForLog4jLogLevels(): Unit = { + client = Admin.create(createConfig) + LoggerFactory.getLogger("kafka.cluster.Replica").trace("Message to create the logger") + val loggerConfig = describeBrokerLoggers() + val kafkaLogLevel = loggerConfig.get("kafka").value() + val logCleanerLogLevelConfig = loggerConfig.get("kafka.cluster.Replica") + // we expect the log level to be inherited from the first ancestor with a level configured + assertEquals(kafkaLogLevel, logCleanerLogLevelConfig.value()) + assertEquals("kafka.cluster.Replica", logCleanerLogLevelConfig.name()) + assertEquals(ConfigEntry.ConfigSource.DYNAMIC_BROKER_LOGGER_CONFIG, logCleanerLogLevelConfig.source()) + assertEquals(false, logCleanerLogLevelConfig.isReadOnly) + assertEquals(false, logCleanerLogLevelConfig.isSensitive) + assertTrue(logCleanerLogLevelConfig.synonyms().isEmpty) + } + + @Test + @Disabled // To be re-enabled once KAFKA-8779 is resolved + def testIncrementalAlterConfigsForLog4jLogLevels(): Unit = { + client = Admin.create(createConfig) + + val initialLoggerConfig = describeBrokerLoggers() + val initialRootLogLevel = initialLoggerConfig.get(Log4jController.ROOT_LOGGER).value() + assertEquals(initialRootLogLevel, initialLoggerConfig.get("kafka.controller.KafkaController").value()) + assertEquals(initialRootLogLevel, initialLoggerConfig.get("kafka.log.LogCleaner").value()) + assertEquals(initialRootLogLevel, initialLoggerConfig.get("kafka.server.ReplicaManager").value()) + + val newRootLogLevel = LogLevelConfig.DEBUG_LOG_LEVEL + val alterRootLoggerEntry = Seq( + new AlterConfigOp(new ConfigEntry(Log4jController.ROOT_LOGGER, newRootLogLevel), AlterConfigOp.OpType.SET) + 
).asJavaCollection + // Test validateOnly does not change anything + alterBrokerLoggers(alterRootLoggerEntry, validateOnly = true) + val validatedLoggerConfig = describeBrokerLoggers() + assertEquals(initialRootLogLevel, validatedLoggerConfig.get(Log4jController.ROOT_LOGGER).value()) + assertEquals(initialRootLogLevel, validatedLoggerConfig.get("kafka.controller.KafkaController").value()) + assertEquals(initialRootLogLevel, validatedLoggerConfig.get("kafka.log.LogCleaner").value()) + assertEquals(initialRootLogLevel, validatedLoggerConfig.get("kafka.server.ReplicaManager").value()) + assertEquals(initialRootLogLevel, validatedLoggerConfig.get("kafka.zookeeper.ZooKeeperClient").value()) + + // test that we can change them and unset loggers still use the root's log level + alterBrokerLoggers(alterRootLoggerEntry) + val changedRootLoggerConfig = describeBrokerLoggers() + assertEquals(newRootLogLevel, changedRootLoggerConfig.get(Log4jController.ROOT_LOGGER).value()) + assertEquals(newRootLogLevel, changedRootLoggerConfig.get("kafka.controller.KafkaController").value()) + assertEquals(newRootLogLevel, changedRootLoggerConfig.get("kafka.log.LogCleaner").value()) + assertEquals(newRootLogLevel, changedRootLoggerConfig.get("kafka.server.ReplicaManager").value()) + assertEquals(newRootLogLevel, changedRootLoggerConfig.get("kafka.zookeeper.ZooKeeperClient").value()) + + // alter the ZK client's logger so we can later test resetting it + val alterZKLoggerEntry = Seq( + new AlterConfigOp(new ConfigEntry("kafka.zookeeper.ZooKeeperClient", LogLevelConfig.ERROR_LOG_LEVEL), AlterConfigOp.OpType.SET) + ).asJavaCollection + alterBrokerLoggers(alterZKLoggerEntry) + val changedZKLoggerConfig = describeBrokerLoggers() + assertEquals(LogLevelConfig.ERROR_LOG_LEVEL, changedZKLoggerConfig.get("kafka.zookeeper.ZooKeeperClient").value()) + + // properly test various set operations and one delete + val alterLogLevelsEntries = Seq( + new AlterConfigOp(new ConfigEntry("kafka.controller.KafkaController", LogLevelConfig.INFO_LOG_LEVEL), AlterConfigOp.OpType.SET), + new AlterConfigOp(new ConfigEntry("kafka.log.LogCleaner", LogLevelConfig.ERROR_LOG_LEVEL), AlterConfigOp.OpType.SET), + new AlterConfigOp(new ConfigEntry("kafka.server.ReplicaManager", LogLevelConfig.TRACE_LOG_LEVEL), AlterConfigOp.OpType.SET), + new AlterConfigOp(new ConfigEntry("kafka.zookeeper.ZooKeeperClient", ""), AlterConfigOp.OpType.DELETE) // should reset to the root logger level + ).asJavaCollection + alterBrokerLoggers(alterLogLevelsEntries) + val alteredLoggerConfig = describeBrokerLoggers() + assertEquals(newRootLogLevel, alteredLoggerConfig.get(Log4jController.ROOT_LOGGER).value()) + assertEquals(LogLevelConfig.INFO_LOG_LEVEL, alteredLoggerConfig.get("kafka.controller.KafkaController").value()) + assertEquals(LogLevelConfig.ERROR_LOG_LEVEL, alteredLoggerConfig.get("kafka.log.LogCleaner").value()) + assertEquals(LogLevelConfig.TRACE_LOG_LEVEL, alteredLoggerConfig.get("kafka.server.ReplicaManager").value()) + assertEquals(newRootLogLevel, alteredLoggerConfig.get("kafka.zookeeper.ZooKeeperClient").value()) + } + + /** + * 1. Assume ROOT logger == TRACE + * 2. Change kafka.controller.KafkaController logger to INFO + * 3. Unset kafka.controller.KafkaController via AlterConfigOp.OpType.DELETE (resets it to the root logger - TRACE) + * 4. Change ROOT logger to ERROR + * 5. 
Ensure the kafka.controller.KafkaController logger's level is ERROR (the current root logger level)
+   */
+  @Test
+  @Disabled // To be re-enabled once KAFKA-8779 is resolved
+  def testIncrementalAlterConfigsForLog4jLogLevelsCanResetLoggerToCurrentRoot(): Unit = {
+    client = Admin.create(createConfig)
+    // step 1 - configure root logger
+    val initialRootLogLevel = LogLevelConfig.TRACE_LOG_LEVEL
+    val alterRootLoggerEntry = Seq(
+      new AlterConfigOp(new ConfigEntry(Log4jController.ROOT_LOGGER, initialRootLogLevel), AlterConfigOp.OpType.SET)
+    ).asJavaCollection
+    alterBrokerLoggers(alterRootLoggerEntry)
+    val initialLoggerConfig = describeBrokerLoggers()
+    assertEquals(initialRootLogLevel, initialLoggerConfig.get(Log4jController.ROOT_LOGGER).value())
+    assertEquals(initialRootLogLevel, initialLoggerConfig.get("kafka.controller.KafkaController").value())
+
+    // step 2 - change KafkaController logger to INFO
+    val alterControllerLoggerEntry = Seq(
+      new AlterConfigOp(new ConfigEntry("kafka.controller.KafkaController", LogLevelConfig.INFO_LOG_LEVEL), AlterConfigOp.OpType.SET)
+    ).asJavaCollection
+    alterBrokerLoggers(alterControllerLoggerEntry)
+    val changedControllerLoggerConfig = describeBrokerLoggers()
+    assertEquals(initialRootLogLevel, changedControllerLoggerConfig.get(Log4jController.ROOT_LOGGER).value())
+    assertEquals(LogLevelConfig.INFO_LOG_LEVEL, changedControllerLoggerConfig.get("kafka.controller.KafkaController").value())
+
+    // step 3 - unset KafkaController logger
+    val deleteControllerLoggerEntry = Seq(
+      new AlterConfigOp(new ConfigEntry("kafka.controller.KafkaController", ""), AlterConfigOp.OpType.DELETE)
+    ).asJavaCollection
+    alterBrokerLoggers(deleteControllerLoggerEntry)
+    val deletedControllerLoggerConfig = describeBrokerLoggers()
+    assertEquals(initialRootLogLevel, deletedControllerLoggerConfig.get(Log4jController.ROOT_LOGGER).value())
+    assertEquals(initialRootLogLevel, deletedControllerLoggerConfig.get("kafka.controller.KafkaController").value())
+
+    // step 4 - change the root logger to ERROR; step 5 - verify the unset controller logger inherits the new root level
+    val newRootLogLevel = LogLevelConfig.ERROR_LOG_LEVEL
+    val newAlterRootLoggerEntry = Seq(
+      new AlterConfigOp(new ConfigEntry(Log4jController.ROOT_LOGGER, newRootLogLevel), AlterConfigOp.OpType.SET)
+    ).asJavaCollection
+    alterBrokerLoggers(newAlterRootLoggerEntry)
+    val newRootLoggerConfig = describeBrokerLoggers()
+    assertEquals(newRootLogLevel, newRootLoggerConfig.get(Log4jController.ROOT_LOGGER).value())
+    assertEquals(newRootLogLevel, newRootLoggerConfig.get("kafka.controller.KafkaController").value())
+  }
+
+  @Test
+  @Disabled // To be re-enabled once KAFKA-8779 is resolved
+  def testIncrementalAlterConfigsForLog4jLogLevelsCannotResetRootLogger(): Unit = {
+    client = Admin.create(createConfig)
+    val deleteRootLoggerEntry = Seq(
+      new AlterConfigOp(new ConfigEntry(Log4jController.ROOT_LOGGER, ""), AlterConfigOp.OpType.DELETE)
+    ).asJavaCollection
+
+    assertTrue(assertThrows(classOf[ExecutionException], () => alterBrokerLoggers(deleteRootLoggerEntry)).getCause.isInstanceOf[InvalidRequestException])
+  }
+
+  @Test
+  @Disabled // To be re-enabled once KAFKA-8779 is resolved
+  def testIncrementalAlterConfigsForLog4jLogLevelsDoesNotWorkWithInvalidConfigs(): Unit = {
+    client = Admin.create(createConfig)
+    val validLoggerName = "kafka.server.KafkaRequestHandler"
+    val expectedValidLoggerLogLevel = describeBrokerLoggers().get(validLoggerName)
+    def assertLogLevelDidNotChange(): Unit = {
+      assertEquals(expectedValidLoggerLogLevel, describeBrokerLoggers().get(validLoggerName))
+    }
+
+    val appendLogLevelEntries = Seq(
+      new
AlterConfigOp(new ConfigEntry("kafka.server.KafkaRequestHandler", LogLevelConfig.INFO_LOG_LEVEL), AlterConfigOp.OpType.SET), // valid + new AlterConfigOp(new ConfigEntry("kafka.network.SocketServer", LogLevelConfig.ERROR_LOG_LEVEL), AlterConfigOp.OpType.APPEND) // append is not supported + ).asJavaCollection + assertTrue(assertThrows(classOf[ExecutionException], + () => alterBrokerLoggers(appendLogLevelEntries)).getCause.isInstanceOf[InvalidRequestException]) + assertLogLevelDidNotChange() + + val subtractLogLevelEntries = Seq( + new AlterConfigOp(new ConfigEntry("kafka.server.KafkaRequestHandler", LogLevelConfig.INFO_LOG_LEVEL), AlterConfigOp.OpType.SET), // valid + new AlterConfigOp(new ConfigEntry("kafka.network.SocketServer", LogLevelConfig.ERROR_LOG_LEVEL), AlterConfigOp.OpType.SUBTRACT) // subtract is not supported + ).asJavaCollection + assertTrue(assertThrows(classOf[ExecutionException], () => alterBrokerLoggers(subtractLogLevelEntries)).getCause.isInstanceOf[InvalidRequestException]) + assertLogLevelDidNotChange() + + val invalidLogLevelLogLevelEntries = Seq( + new AlterConfigOp(new ConfigEntry("kafka.server.KafkaRequestHandler", LogLevelConfig.INFO_LOG_LEVEL), AlterConfigOp.OpType.SET), // valid + new AlterConfigOp(new ConfigEntry("kafka.network.SocketServer", "OFF"), AlterConfigOp.OpType.SET) // OFF is not a valid log level + ).asJavaCollection + assertTrue(assertThrows(classOf[ExecutionException], () => alterBrokerLoggers(invalidLogLevelLogLevelEntries)).getCause.isInstanceOf[InvalidRequestException]) + assertLogLevelDidNotChange() + + val invalidLoggerNameLogLevelEntries = Seq( + new AlterConfigOp(new ConfigEntry("kafka.server.KafkaRequestHandler", LogLevelConfig.INFO_LOG_LEVEL), AlterConfigOp.OpType.SET), // valid + new AlterConfigOp(new ConfigEntry("Some Other LogCleaner", LogLevelConfig.ERROR_LOG_LEVEL), AlterConfigOp.OpType.SET) // invalid logger name is not supported + ).asJavaCollection + assertTrue(assertThrows(classOf[ExecutionException], () => alterBrokerLoggers(invalidLoggerNameLogLevelEntries)).getCause.isInstanceOf[InvalidRequestException]) + assertLogLevelDidNotChange() + } + + /** + * The AlterConfigs API is deprecated and should not support altering log levels + */ + @nowarn("cat=deprecation") + @Test + @Disabled // To be re-enabled once KAFKA-8779 is resolved + def testAlterConfigsForLog4jLogLevelsDoesNotWork(): Unit = { + client = Admin.create(createConfig) + + val alterLogLevelsEntries = Seq( + new ConfigEntry("kafka.controller.KafkaController", LogLevelConfig.INFO_LOG_LEVEL) + ).asJavaCollection + val alterResult = client.alterConfigs(Map(brokerLoggerConfigResource -> new Config(alterLogLevelsEntries)).asJava) + assertTrue(assertThrows(classOf[ExecutionException], () => alterResult.values.get(brokerLoggerConfigResource).get).getCause.isInstanceOf[InvalidRequestException]) + } + + def alterBrokerLoggers(entries: util.Collection[AlterConfigOp], validateOnly: Boolean = false): Unit = { + if (!validateOnly) { + for (entry <- entries.asScala) + changedBrokerLoggers.add(entry.configEntry().name()) + } + + client.incrementalAlterConfigs(Map(brokerLoggerConfigResource -> entries).asJava, new AlterConfigsOptions().validateOnly(validateOnly)) + .values.get(brokerLoggerConfigResource).get() + } + + def describeBrokerLoggers(): Config = + client.describeConfigs(Collections.singletonList(brokerLoggerConfigResource)).values.get(brokerLoggerConfigResource).get() + + /** + * Due to the fact that log4j is not re-initialized across tests, changing a logger's log level 
persists across test classes. + * We need to clean up the changes done while testing. + */ + private def teardownBrokerLoggers(): Unit = { + if (changedBrokerLoggers.nonEmpty) { + val validLoggers = describeBrokerLoggers().entries().asScala.filterNot(_.name.equals(Log4jController.ROOT_LOGGER)).map(_.name).toSet + val unsetBrokerLoggersEntries = changedBrokerLoggers + .intersect(validLoggers) + .map { logger => new AlterConfigOp(new ConfigEntry(logger, ""), AlterConfigOp.OpType.DELETE) } + .asJavaCollection + + // ensure that we first reset the root logger to an arbitrary log level. Note that we cannot reset it to its original value + alterBrokerLoggers(List( + new AlterConfigOp(new ConfigEntry(Log4jController.ROOT_LOGGER, LogLevelConfig.FATAL_LOG_LEVEL), AlterConfigOp.OpType.SET) + ).asJavaCollection) + alterBrokerLoggers(unsetBrokerLoggersEntries) + + changedBrokerLoggers.clear() + } + } + +} + +object PlaintextAdminIntegrationTest { + + @nowarn("cat=deprecation") + def checkValidAlterConfigs(client: Admin, topicResource1: ConfigResource, topicResource2: ConfigResource): Unit = { + // Alter topics + var topicConfigEntries1 = Seq( + new ConfigEntry(LogConfig.FlushMsProp, "1000") + ).asJava + + var topicConfigEntries2 = Seq( + new ConfigEntry(LogConfig.MinCleanableDirtyRatioProp, "0.9"), + new ConfigEntry(LogConfig.CompressionTypeProp, "lz4") + ).asJava + + var alterResult = client.alterConfigs(Map( + topicResource1 -> new Config(topicConfigEntries1), + topicResource2 -> new Config(topicConfigEntries2) + ).asJava) + + assertEquals(Set(topicResource1, topicResource2).asJava, alterResult.values.keySet) + alterResult.all.get + + // Verify that topics were updated correctly + var describeResult = client.describeConfigs(Seq(topicResource1, topicResource2).asJava) + var configs = describeResult.all.get + + assertEquals(2, configs.size) + + assertEquals("1000", configs.get(topicResource1).get(LogConfig.FlushMsProp).value) + assertEquals(Defaults.MessageMaxBytes.toString, + configs.get(topicResource1).get(LogConfig.MaxMessageBytesProp).value) + assertEquals((Defaults.LogRetentionHours * 60 * 60 * 1000).toString, + configs.get(topicResource1).get(LogConfig.RetentionMsProp).value) + + assertEquals("0.9", configs.get(topicResource2).get(LogConfig.MinCleanableDirtyRatioProp).value) + assertEquals("lz4", configs.get(topicResource2).get(LogConfig.CompressionTypeProp).value) + + // Alter topics with validateOnly=true + topicConfigEntries1 = Seq( + new ConfigEntry(LogConfig.MaxMessageBytesProp, "10") + ).asJava + + topicConfigEntries2 = Seq( + new ConfigEntry(LogConfig.MinCleanableDirtyRatioProp, "0.3") + ).asJava + + alterResult = client.alterConfigs(Map( + topicResource1 -> new Config(topicConfigEntries1), + topicResource2 -> new Config(topicConfigEntries2) + ).asJava, new AlterConfigsOptions().validateOnly(true)) + + assertEquals(Set(topicResource1, topicResource2).asJava, alterResult.values.keySet) + alterResult.all.get + + // Verify that topics were not updated due to validateOnly = true + describeResult = client.describeConfigs(Seq(topicResource1, topicResource2).asJava) + configs = describeResult.all.get + + assertEquals(2, configs.size) + + assertEquals(Defaults.MessageMaxBytes.toString, + configs.get(topicResource1).get(LogConfig.MaxMessageBytesProp).value) + assertEquals("0.9", configs.get(topicResource2).get(LogConfig.MinCleanableDirtyRatioProp).value) + } + + @nowarn("cat=deprecation") + def checkInvalidAlterConfigs(zkClient: KafkaZkClient, servers: Seq[KafkaServer], client: Admin): Unit = { + 
// Create topics + val topic1 = "invalid-alter-configs-topic-1" + val topicResource1 = new ConfigResource(ConfigResource.Type.TOPIC, topic1) + TestUtils.createTopic(zkClient, topic1, 1, 1, servers) + + val topic2 = "invalid-alter-configs-topic-2" + val topicResource2 = new ConfigResource(ConfigResource.Type.TOPIC, topic2) + TestUtils.createTopic(zkClient, topic2, 1, 1, servers) + + val topicConfigEntries1 = Seq( + new ConfigEntry(LogConfig.MinCleanableDirtyRatioProp, "1.1"), // this value is invalid as it's above 1.0 + new ConfigEntry(LogConfig.CompressionTypeProp, "lz4") + ).asJava + + var topicConfigEntries2 = Seq(new ConfigEntry(LogConfig.CompressionTypeProp, "snappy")).asJava + + val brokerResource = new ConfigResource(ConfigResource.Type.BROKER, servers.head.config.brokerId.toString) + val brokerConfigEntries = Seq(new ConfigEntry(KafkaConfig.ZkConnectProp, "localhost:2181")).asJava + + // Alter configs: first and third are invalid, second is valid + var alterResult = client.alterConfigs(Map( + topicResource1 -> new Config(topicConfigEntries1), + topicResource2 -> new Config(topicConfigEntries2), + brokerResource -> new Config(brokerConfigEntries) + ).asJava) + + assertEquals(Set(topicResource1, topicResource2, brokerResource).asJava, alterResult.values.keySet) + assertTrue(assertThrows(classOf[ExecutionException], () => alterResult.values.get(topicResource1).get).getCause.isInstanceOf[InvalidRequestException]) + alterResult.values.get(topicResource2).get + assertTrue(assertThrows(classOf[ExecutionException], () => alterResult.values.get(brokerResource).get).getCause.isInstanceOf[InvalidRequestException]) + + // Verify that first and third resources were not updated and second was updated + var describeResult = client.describeConfigs(Seq(topicResource1, topicResource2, brokerResource).asJava) + var configs = describeResult.all.get + assertEquals(3, configs.size) + + assertEquals(Defaults.LogCleanerMinCleanRatio.toString, + configs.get(topicResource1).get(LogConfig.MinCleanableDirtyRatioProp).value) + assertEquals(Defaults.CompressionType, + configs.get(topicResource1).get(LogConfig.CompressionTypeProp).value) + + assertEquals("snappy", configs.get(topicResource2).get(LogConfig.CompressionTypeProp).value) + + assertEquals(Defaults.CompressionType, configs.get(brokerResource).get(KafkaConfig.CompressionTypeProp).value) + + // Alter configs with validateOnly = true: first and third are invalid, second is valid + topicConfigEntries2 = Seq(new ConfigEntry(LogConfig.CompressionTypeProp, "gzip")).asJava + + alterResult = client.alterConfigs(Map( + topicResource1 -> new Config(topicConfigEntries1), + topicResource2 -> new Config(topicConfigEntries2), + brokerResource -> new Config(brokerConfigEntries) + ).asJava, new AlterConfigsOptions().validateOnly(true)) + + assertEquals(Set(topicResource1, topicResource2, brokerResource).asJava, alterResult.values.keySet) + assertTrue(assertThrows(classOf[ExecutionException], () => alterResult.values.get(topicResource1).get).getCause.isInstanceOf[InvalidRequestException]) + alterResult.values.get(topicResource2).get + assertTrue(assertThrows(classOf[ExecutionException], () => alterResult.values.get(brokerResource).get).getCause.isInstanceOf[InvalidRequestException]) + + // Verify that no resources are updated since validate_only = true + describeResult = client.describeConfigs(Seq(topicResource1, topicResource2, brokerResource).asJava) + configs = describeResult.all.get + assertEquals(3, configs.size) + + 
assertEquals(Defaults.LogCleanerMinCleanRatio.toString, + configs.get(topicResource1).get(LogConfig.MinCleanableDirtyRatioProp).value) + assertEquals(Defaults.CompressionType, + configs.get(topicResource1).get(LogConfig.CompressionTypeProp).value) + + assertEquals("snappy", configs.get(topicResource2).get(LogConfig.CompressionTypeProp).value) + + assertEquals(Defaults.CompressionType, configs.get(brokerResource).get(KafkaConfig.CompressionTypeProp).value) + } + +} diff --git a/core/src/test/scala/integration/kafka/api/PlaintextConsumerTest.scala b/core/src/test/scala/integration/kafka/api/PlaintextConsumerTest.scala new file mode 100644 index 0000000..c34d5e1 --- /dev/null +++ b/core/src/test/scala/integration/kafka/api/PlaintextConsumerTest.scala @@ -0,0 +1,1794 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more contributor license agreements. See the NOTICE + * file distributed with this work for additional information regarding copyright ownership. The ASF licenses this file + * to You under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the + * License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ +package kafka.api + +import java.time.Duration +import java.util +import java.util.Arrays.asList +import java.util.regex.Pattern +import java.util.{Locale, Optional, Properties} +import kafka.log.LogConfig +import kafka.utils.TestUtils +import org.apache.kafka.clients.consumer._ +import org.apache.kafka.clients.producer.{ProducerConfig, ProducerRecord} +import org.apache.kafka.common.{MetricName, TopicPartition} +import org.apache.kafka.common.errors.{InvalidGroupIdException, InvalidTopicException} +import org.apache.kafka.common.header.Headers +import org.apache.kafka.common.record.{CompressionType, TimestampType} +import org.apache.kafka.common.serialization._ +import org.apache.kafka.common.utils.Utils +import org.apache.kafka.test.{MockConsumerInterceptor, MockProducerInterceptor} +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.Test + +import scala.jdk.CollectionConverters._ +import scala.collection.mutable.Buffer +import kafka.server.QuotaType +import kafka.server.KafkaServer + +import scala.collection.mutable + +/* We have some tests in this class instead of `BaseConsumerTest` in order to keep the build time under control. 
*/ +class PlaintextConsumerTest extends BaseConsumerTest { + + @Test + def testHeaders(): Unit = { + val numRecords = 1 + val record = new ProducerRecord(tp.topic, tp.partition, null, "key".getBytes, "value".getBytes) + + record.headers().add("headerKey", "headerValue".getBytes) + + val producer = createProducer() + producer.send(record) + + val consumer = createConsumer() + assertEquals(0, consumer.assignment.size) + consumer.assign(List(tp).asJava) + assertEquals(1, consumer.assignment.size) + + consumer.seek(tp, 0) + val records = consumeRecords(consumer = consumer, numRecords = numRecords) + + assertEquals(numRecords, records.size) + + for (i <- 0 until numRecords) { + val record = records(i) + val header = record.headers().lastHeader("headerKey") + assertEquals("headerValue", if (header == null) null else new String(header.value())) + } + } + + trait SerializerImpl extends Serializer[Array[Byte]]{ + var serializer = new ByteArraySerializer() + + override def serialize(topic: String, headers: Headers, data: Array[Byte]): Array[Byte] = { + headers.add("content-type", "application/octet-stream".getBytes) + serializer.serialize(topic, data) + } + + override def configure(configs: util.Map[String, _], isKey: Boolean): Unit = serializer.configure(configs, isKey) + + override def close(): Unit = serializer.close() + + override def serialize(topic: String, data: Array[Byte]): Array[Byte] = { + fail("method should not be invoked") + null + } + } + + trait DeserializerImpl extends Deserializer[Array[Byte]]{ + var deserializer = new ByteArrayDeserializer() + + override def deserialize(topic: String, headers: Headers, data: Array[Byte]): Array[Byte] = { + val header = headers.lastHeader("content-type") + assertEquals("application/octet-stream", if (header == null) null else new String(header.value())) + deserializer.deserialize(topic, data) + } + + override def configure(configs: util.Map[String, _], isKey: Boolean): Unit = deserializer.configure(configs, isKey) + + override def close(): Unit = deserializer.close() + + override def deserialize(topic: String, data: Array[Byte]): Array[Byte] = { + fail("method should not be invoked") + null + } + } + + private def testHeadersSerializeDeserialize(serializer: Serializer[Array[Byte]], deserializer: Deserializer[Array[Byte]]): Unit = { + val numRecords = 1 + val record = new ProducerRecord(tp.topic, tp.partition, null, "key".getBytes, "value".getBytes) + + val producer = createProducer( + keySerializer = new ByteArraySerializer, + valueSerializer = serializer) + producer.send(record) + + val consumer = createConsumer( + keyDeserializer = new ByteArrayDeserializer, + valueDeserializer = deserializer) + assertEquals(0, consumer.assignment.size) + consumer.assign(List(tp).asJava) + assertEquals(1, consumer.assignment.size) + + consumer.seek(tp, 0) + val records = consumeRecords(consumer = consumer, numRecords = numRecords) + + assertEquals(numRecords, records.size) + } + + @deprecated("poll(Duration) is the replacement", since = "2.0") + @Test + def testDeprecatedPollBlocksForAssignment(): Unit = { + val consumer = createConsumer() + consumer.subscribe(Set(topic).asJava) + consumer.poll(0) + assertEquals(Set(tp, tp2), consumer.assignment().asScala) + } + + @Test + def testHeadersSerializerDeserializer(): Unit = { + val extendedSerializer = new Serializer[Array[Byte]] with SerializerImpl + + val extendedDeserializer = new Deserializer[Array[Byte]] with DeserializerImpl + + testHeadersSerializeDeserialize(extendedSerializer, extendedDeserializer) + } + + 
@Test + def testMaxPollRecords(): Unit = { + val maxPollRecords = 2 + val numRecords = 10000 + + val producer = createProducer() + val startingTimestamp = System.currentTimeMillis() + sendRecords(producer, numRecords, tp, startingTimestamp = startingTimestamp) + + this.consumerConfig.setProperty(ConsumerConfig.MAX_POLL_RECORDS_CONFIG, maxPollRecords.toString) + val consumer = createConsumer() + consumer.assign(List(tp).asJava) + consumeAndVerifyRecords(consumer, numRecords = numRecords, startingOffset = 0, maxPollRecords = maxPollRecords, + startingTimestamp = startingTimestamp) + } + + @Test + def testMaxPollIntervalMs(): Unit = { + this.consumerConfig.setProperty(ConsumerConfig.MAX_POLL_INTERVAL_MS_CONFIG, 1000.toString) + this.consumerConfig.setProperty(ConsumerConfig.HEARTBEAT_INTERVAL_MS_CONFIG, 500.toString) + this.consumerConfig.setProperty(ConsumerConfig.SESSION_TIMEOUT_MS_CONFIG, 2000.toString) + + val consumer = createConsumer() + + val listener = new TestConsumerReassignmentListener() + consumer.subscribe(List(topic).asJava, listener) + + // rebalance to get the initial assignment + awaitRebalance(consumer, listener) + assertEquals(1, listener.callsToAssigned) + assertEquals(0, listener.callsToRevoked) + + // after we extend longer than max.poll a rebalance should be triggered + // NOTE we need to have a relatively much larger value than max.poll to let heartbeat expired for sure + Thread.sleep(3000) + + awaitRebalance(consumer, listener) + assertEquals(2, listener.callsToAssigned) + assertEquals(1, listener.callsToRevoked) + } + + @Test + def testMaxPollIntervalMsDelayInRevocation(): Unit = { + this.consumerConfig.setProperty(ConsumerConfig.MAX_POLL_INTERVAL_MS_CONFIG, 5000.toString) + this.consumerConfig.setProperty(ConsumerConfig.HEARTBEAT_INTERVAL_MS_CONFIG, 500.toString) + this.consumerConfig.setProperty(ConsumerConfig.SESSION_TIMEOUT_MS_CONFIG, 1000.toString) + this.consumerConfig.setProperty(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG, false.toString) + + val consumer = createConsumer() + var commitCompleted = false + var committedPosition: Long = -1 + + val listener = new TestConsumerReassignmentListener { + override def onPartitionsLost(partitions: util.Collection[TopicPartition]): Unit = {} + override def onPartitionsRevoked(partitions: util.Collection[TopicPartition]): Unit = { + if (!partitions.isEmpty && partitions.contains(tp)) { + // on the second rebalance (after we have joined the group initially), sleep longer + // than session timeout and then try a commit. 
We should still be in the group, + // so the commit should succeed + Utils.sleep(1500) + committedPosition = consumer.position(tp) + consumer.commitSync(Map(tp -> new OffsetAndMetadata(committedPosition)).asJava) + commitCompleted = true + } + super.onPartitionsRevoked(partitions) + } + } + + consumer.subscribe(List(topic).asJava, listener) + + // rebalance to get the initial assignment + awaitRebalance(consumer, listener) + + // force a rebalance to trigger an invocation of the revocation callback while in the group + consumer.subscribe(List("otherTopic").asJava, listener) + awaitRebalance(consumer, listener) + + assertEquals(0, committedPosition) + assertTrue(commitCompleted) + } + + @Test + def testMaxPollIntervalMsDelayInAssignment(): Unit = { + this.consumerConfig.setProperty(ConsumerConfig.MAX_POLL_INTERVAL_MS_CONFIG, 5000.toString) + this.consumerConfig.setProperty(ConsumerConfig.HEARTBEAT_INTERVAL_MS_CONFIG, 500.toString) + this.consumerConfig.setProperty(ConsumerConfig.SESSION_TIMEOUT_MS_CONFIG, 1000.toString) + this.consumerConfig.setProperty(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG, false.toString) + + val consumer = createConsumer() + val listener = new TestConsumerReassignmentListener { + override def onPartitionsAssigned(partitions: util.Collection[TopicPartition]): Unit = { + // sleep longer than the session timeout, we should still be in the group after invocation + Utils.sleep(1500) + super.onPartitionsAssigned(partitions) + } + } + consumer.subscribe(List(topic).asJava, listener) + + // rebalance to get the initial assignment + awaitRebalance(consumer, listener) + + // We should still be in the group after this invocation + ensureNoRebalance(consumer, listener) + } + + @Test + def testAutoCommitOnClose(): Unit = { + this.consumerConfig.setProperty(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG, "true") + val consumer = createConsumer() + + val numRecords = 10000 + val producer = createProducer() + sendRecords(producer, numRecords, tp) + + consumer.subscribe(List(topic).asJava) + awaitAssignment(consumer, Set(tp, tp2)) + + // should auto-commit seeked positions before closing + consumer.seek(tp, 300) + consumer.seek(tp2, 500) + consumer.close() + + // now we should see the committed positions from another consumer + val anotherConsumer = createConsumer() + assertEquals(300, anotherConsumer.committed(Set(tp).asJava).get(tp).offset) + assertEquals(500, anotherConsumer.committed(Set(tp2).asJava).get(tp2).offset) + } + + @Test + def testAutoCommitOnCloseAfterWakeup(): Unit = { + this.consumerConfig.setProperty(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG, "true") + val consumer = createConsumer() + + val numRecords = 10000 + val producer = createProducer() + sendRecords(producer, numRecords, tp) + + consumer.subscribe(List(topic).asJava) + awaitAssignment(consumer, Set(tp, tp2)) + + // should auto-commit seeked positions before closing + consumer.seek(tp, 300) + consumer.seek(tp2, 500) + + // wakeup the consumer before closing to simulate trying to break a poll + // loop from another thread + consumer.wakeup() + consumer.close() + + // now we should see the committed positions from another consumer + val anotherConsumer = createConsumer() + assertEquals(300, anotherConsumer.committed(Set(tp).asJava).get(tp).offset) + assertEquals(500, anotherConsumer.committed(Set(tp2).asJava).get(tp2).offset) + } + + @Test + def testAutoOffsetReset(): Unit = { + val producer = createProducer() + val startingTimestamp = System.currentTimeMillis() + sendRecords(producer, numRecords = 1, tp, 
startingTimestamp = startingTimestamp) + + val consumer = createConsumer() + consumer.assign(List(tp).asJava) + consumeAndVerifyRecords(consumer = consumer, numRecords = 1, startingOffset = 0, startingTimestamp = startingTimestamp) + } + + @Test + def testGroupConsumption(): Unit = { + val producer = createProducer() + val startingTimestamp = System.currentTimeMillis() + sendRecords(producer, numRecords = 10, tp, startingTimestamp = startingTimestamp) + + val consumer = createConsumer() + consumer.subscribe(List(topic).asJava) + consumeAndVerifyRecords(consumer = consumer, numRecords = 1, startingOffset = 0, startingTimestamp = startingTimestamp) + } + + /** + * Verifies that pattern subscription performs as expected. + * The pattern matches the topics 'topic' and 'tblablac', but not 'tblablak' or 'tblab1'. + * It is expected that the consumer is subscribed to all partitions of 'topic' and + * 'tblablac' after the subscription when metadata is refreshed. + * When a new topic 'tsomec' is added afterwards, it is expected that upon the next + * metadata refresh the consumer becomes subscribed to this new topic and all partitions + * of that topic are assigned to it. + */ + @Test + def testPatternSubscription(): Unit = { + val numRecords = 10000 + val producer = createProducer() + sendRecords(producer, numRecords, tp) + + val topic1 = "tblablac" // matches subscribed pattern + createTopic(topic1, 2, brokerCount) + sendRecords(producer, numRecords = 1000, new TopicPartition(topic1, 0)) + sendRecords(producer, numRecords = 1000, new TopicPartition(topic1, 1)) + + val topic2 = "tblablak" // does not match subscribed pattern + createTopic(topic2, 2, brokerCount) + sendRecords(producer,numRecords = 1000, new TopicPartition(topic2, 0)) + sendRecords(producer, numRecords = 1000, new TopicPartition(topic2, 1)) + + val topic3 = "tblab1" // does not match subscribed pattern + createTopic(topic3, 2, brokerCount) + sendRecords(producer, numRecords = 1000, new TopicPartition(topic3, 0)) + sendRecords(producer, numRecords = 1000, new TopicPartition(topic3, 1)) + + val consumer = createConsumer() + assertEquals(0, consumer.assignment().size) + + val pattern = Pattern.compile("t.*c") + consumer.subscribe(pattern, new TestConsumerReassignmentListener) + + var assignment = Set( + new TopicPartition(topic, 0), + new TopicPartition(topic, 1), + new TopicPartition(topic1, 0), + new TopicPartition(topic1, 1)) + awaitAssignment(consumer, assignment) + + val topic4 = "tsomec" // matches subscribed pattern + createTopic(topic4, 2, brokerCount) + sendRecords(producer, numRecords = 1000, new TopicPartition(topic4, 0)) + sendRecords(producer, numRecords = 1000, new TopicPartition(topic4, 1)) + + assignment ++= Set( + new TopicPartition(topic4, 0), + new TopicPartition(topic4, 1)) + awaitAssignment(consumer, assignment) + + consumer.unsubscribe() + assertEquals(0, consumer.assignment().size) + } + + /** + * Verifies that a second call to pattern subscription succeeds and performs as expected. + * The initial subscription is to a pattern that matches two topics 'topic' and 'foo'. + * The second subscription is to a pattern that matches 'foo' and a new topic 'bar'. + * It is expected that the consumer is subscribed to all partitions of 'topic' and 'foo' after + * the first subscription, and to all partitions of 'foo' and 'bar' after the second. 
+ * The metadata refresh interval is intentionally increased to a large enough value to guarantee + * that it is the subscription call that triggers a metadata refresh, and not the timeout. + */ + @Test + def testSubsequentPatternSubscription(): Unit = { + this.consumerConfig.setProperty(ConsumerConfig.METADATA_MAX_AGE_CONFIG, "30000") + val consumer = createConsumer() + + val numRecords = 10000 + val producer = createProducer() + sendRecords(producer, numRecords = numRecords, tp) + + // the first topic ('topic') matches first subscription pattern only + + val fooTopic = "foo" // matches both subscription patterns + createTopic(fooTopic, 1, brokerCount) + sendRecords(producer, numRecords = 1000, new TopicPartition(fooTopic, 0)) + + assertEquals(0, consumer.assignment().size) + + val pattern1 = Pattern.compile(".*o.*") // only 'topic' and 'foo' match this + consumer.subscribe(pattern1, new TestConsumerReassignmentListener) + + var assignment = Set( + new TopicPartition(topic, 0), + new TopicPartition(topic, 1), + new TopicPartition(fooTopic, 0)) + awaitAssignment(consumer, assignment) + + val barTopic = "bar" // matches the next subscription pattern + createTopic(barTopic, 1, brokerCount) + sendRecords(producer, numRecords = 1000, new TopicPartition(barTopic, 0)) + + val pattern2 = Pattern.compile("...") // only 'foo' and 'bar' match this + consumer.subscribe(pattern2, new TestConsumerReassignmentListener) + assignment --= Set( + new TopicPartition(topic, 0), + new TopicPartition(topic, 1)) + assignment ++= Set( + new TopicPartition(barTopic, 0)) + awaitAssignment(consumer, assignment) + + consumer.unsubscribe() + assertEquals(0, consumer.assignment().size) + } + + /** + * Verifies that pattern unsubscription performs as expected. + * The pattern matches the topics 'topic' and 'tblablac'. + * It is expected that the consumer is subscribed to all partitions of 'topic' and + * 'tblablac' after the subscription when metadata is refreshed. + * When consumer unsubscribes from all its subscriptions, it is expected that its + * assignments are cleared right away. 
+ */ + @Test + def testPatternUnsubscription(): Unit = { + val numRecords = 10000 + val producer = createProducer() + sendRecords(producer, numRecords, tp) + + val topic1 = "tblablac" // matches the subscription pattern + createTopic(topic1, 2, brokerCount) + sendRecords(producer, numRecords = 1000, new TopicPartition(topic1, 0)) + sendRecords(producer, numRecords = 1000, new TopicPartition(topic1, 1)) + + val consumer = createConsumer() + assertEquals(0, consumer.assignment().size) + + consumer.subscribe(Pattern.compile("t.*c"), new TestConsumerReassignmentListener) + val assignment = Set( + new TopicPartition(topic, 0), + new TopicPartition(topic, 1), + new TopicPartition(topic1, 0), + new TopicPartition(topic1, 1)) + awaitAssignment(consumer, assignment) + + consumer.unsubscribe() + assertEquals(0, consumer.assignment().size) + } + + @Test + def testCommitMetadata(): Unit = { + val consumer = createConsumer() + consumer.assign(List(tp).asJava) + + // sync commit + val syncMetadata = new OffsetAndMetadata(5, Optional.of(15), "foo") + consumer.commitSync(Map((tp, syncMetadata)).asJava) + assertEquals(syncMetadata, consumer.committed(Set(tp).asJava).get(tp)) + + // async commit + val asyncMetadata = new OffsetAndMetadata(10, "bar") + sendAndAwaitAsyncCommit(consumer, Some(Map(tp -> asyncMetadata))) + assertEquals(asyncMetadata, consumer.committed(Set(tp).asJava).get(tp)) + + // handle null metadata + val nullMetadata = new OffsetAndMetadata(5, null) + consumer.commitSync(Map(tp -> nullMetadata).asJava) + assertEquals(nullMetadata, consumer.committed(Set(tp).asJava).get(tp)) + } + + @Test + def testAsyncCommit(): Unit = { + val consumer = createConsumer() + consumer.assign(List(tp).asJava) + + val callback = new CountConsumerCommitCallback + val count = 5 + + for (i <- 1 to count) + consumer.commitAsync(Map(tp -> new OffsetAndMetadata(i)).asJava, callback) + + TestUtils.pollUntilTrue(consumer, () => callback.successCount >= count || callback.lastError.isDefined, + "Failed to observe commit callback before timeout", waitTimeMs = 10000) + + assertEquals(None, callback.lastError) + assertEquals(count, callback.successCount) + assertEquals(new OffsetAndMetadata(count), consumer.committed(Set(tp).asJava).get(tp)) + } + + @Test + def testExpandingTopicSubscriptions(): Unit = { + val otherTopic = "other" + val initialAssignment = Set(new TopicPartition(topic, 0), new TopicPartition(topic, 1)) + val consumer = createConsumer() + consumer.subscribe(List(topic).asJava) + awaitAssignment(consumer, initialAssignment) + + createTopic(otherTopic, 2, brokerCount) + val expandedAssignment = initialAssignment ++ Set(new TopicPartition(otherTopic, 0), new TopicPartition(otherTopic, 1)) + consumer.subscribe(List(topic, otherTopic).asJava) + awaitAssignment(consumer, expandedAssignment) + } + + @Test + def testShrinkingTopicSubscriptions(): Unit = { + val otherTopic = "other" + createTopic(otherTopic, 2, brokerCount) + val initialAssignment = Set(new TopicPartition(topic, 0), new TopicPartition(topic, 1), new TopicPartition(otherTopic, 0), new TopicPartition(otherTopic, 1)) + val consumer = createConsumer() + consumer.subscribe(List(topic, otherTopic).asJava) + awaitAssignment(consumer, initialAssignment) + + val shrunkenAssignment = Set(new TopicPartition(topic, 0), new TopicPartition(topic, 1)) + consumer.subscribe(List(topic).asJava) + awaitAssignment(consumer, shrunkenAssignment) + } + + @Test + def testPartitionsFor(): Unit = { + val numParts = 2 + createTopic("part-test", numParts, 1) + val consumer = 
createConsumer() + val parts = consumer.partitionsFor("part-test") + assertNotNull(parts) + assertEquals(2, parts.size) + } + + @Test + def testPartitionsForAutoCreate(): Unit = { + val consumer = createConsumer() + // First call would create the topic + consumer.partitionsFor("non-exist-topic") + val partitions = consumer.partitionsFor("non-exist-topic") + assertFalse(partitions.isEmpty) + } + + @Test + def testPartitionsForInvalidTopic(): Unit = { + val consumer = createConsumer() + assertThrows(classOf[InvalidTopicException], () => consumer.partitionsFor(";3# ads,{234")) + } + + @Test + def testSeek(): Unit = { + val consumer = createConsumer() + val totalRecords = 50L + val mid = totalRecords / 2 + + // Test seek non-compressed message + val producer = createProducer() + val startingTimestamp = 0 + sendRecords(producer, totalRecords.toInt, tp, startingTimestamp = startingTimestamp) + consumer.assign(List(tp).asJava) + + consumer.seekToEnd(List(tp).asJava) + assertEquals(totalRecords, consumer.position(tp)) + assertTrue(consumer.poll(Duration.ofMillis(50)).isEmpty) + + consumer.seekToBeginning(List(tp).asJava) + assertEquals(0L, consumer.position(tp)) + consumeAndVerifyRecords(consumer, numRecords = 1, startingOffset = 0, startingTimestamp = startingTimestamp) + + consumer.seek(tp, mid) + assertEquals(mid, consumer.position(tp)) + + consumeAndVerifyRecords(consumer, numRecords = 1, startingOffset = mid.toInt, startingKeyAndValueIndex = mid.toInt, + startingTimestamp = mid.toLong) + + // Test seek compressed message + sendCompressedMessages(totalRecords.toInt, tp2) + consumer.assign(List(tp2).asJava) + + consumer.seekToEnd(List(tp2).asJava) + assertEquals(totalRecords, consumer.position(tp2)) + assertTrue(consumer.poll(Duration.ofMillis(50)).isEmpty) + + consumer.seekToBeginning(List(tp2).asJava) + assertEquals(0L, consumer.position(tp2)) + consumeAndVerifyRecords(consumer, numRecords = 1, startingOffset = 0, tp = tp2) + + consumer.seek(tp2, mid) + assertEquals(mid, consumer.position(tp2)) + consumeAndVerifyRecords(consumer, numRecords = 1, startingOffset = mid.toInt, startingKeyAndValueIndex = mid.toInt, + startingTimestamp = mid.toLong, tp = tp2) + } + + private def sendCompressedMessages(numRecords: Int, tp: TopicPartition): Unit = { + val producerProps = new Properties() + producerProps.setProperty(ProducerConfig.COMPRESSION_TYPE_CONFIG, CompressionType.GZIP.name) + producerProps.setProperty(ProducerConfig.LINGER_MS_CONFIG, Int.MaxValue.toString) + val producer = createProducer(configOverrides = producerProps) + (0 until numRecords).foreach { i => + producer.send(new ProducerRecord(tp.topic, tp.partition, i.toLong, s"key $i".getBytes, s"value $i".getBytes)) + } + producer.close() + } + + @Test + def testPositionAndCommit(): Unit = { + val producer = createProducer() + var startingTimestamp = System.currentTimeMillis() + sendRecords(producer, numRecords = 5, tp, startingTimestamp = startingTimestamp) + + val topicPartition = new TopicPartition(topic, 15) + val consumer = createConsumer() + assertNull(consumer.committed(Set(topicPartition).asJava).get(topicPartition)) + + // position() on a partition that we aren't subscribed to throws an exception + assertThrows(classOf[IllegalStateException], () => consumer.position(topicPartition)) + + consumer.assign(List(tp).asJava) + + assertEquals(0L, consumer.position(tp), "position() on a partition that we are subscribed to should reset the offset") + consumer.commitSync() + assertEquals(0L, consumer.committed(Set(tp).asJava).get(tp).offset) + 
+ consumeAndVerifyRecords(consumer = consumer, numRecords = 5, startingOffset = 0, startingTimestamp = startingTimestamp) + assertEquals(5L, consumer.position(tp), "After consuming 5 records, position should be 5") + consumer.commitSync() + assertEquals(5L, consumer.committed(Set(tp).asJava).get(tp).offset, "Committed offset should be returned") + + startingTimestamp = System.currentTimeMillis() + sendRecords(producer, numRecords = 1, tp, startingTimestamp = startingTimestamp) + + // another consumer in the same group should get the same position + val otherConsumer = createConsumer() + otherConsumer.assign(List(tp).asJava) + consumeAndVerifyRecords(consumer = otherConsumer, numRecords = 1, startingOffset = 5, startingTimestamp = startingTimestamp) + } + + @Test + def testPartitionPauseAndResume(): Unit = { + val partitions = List(tp).asJava + val producer = createProducer() + var startingTimestamp = System.currentTimeMillis() + sendRecords(producer, numRecords = 5, tp, startingTimestamp = startingTimestamp) + + val consumer = createConsumer() + consumer.assign(partitions) + consumeAndVerifyRecords(consumer = consumer, numRecords = 5, startingOffset = 0, startingTimestamp = startingTimestamp) + consumer.pause(partitions) + startingTimestamp = System.currentTimeMillis() + sendRecords(producer, numRecords = 5, tp, startingTimestamp = startingTimestamp) + assertTrue(consumer.poll(Duration.ofMillis(100)).isEmpty) + consumer.resume(partitions) + consumeAndVerifyRecords(consumer = consumer, numRecords = 5, startingOffset = 5, startingTimestamp = startingTimestamp) + } + + @Test + def testFetchInvalidOffset(): Unit = { + this.consumerConfig.setProperty(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "none") + val consumer = createConsumer() + + // produce two records + val totalRecords = 2 + val producer = createProducer() + sendRecords(producer, totalRecords, tp) + consumer.assign(List(tp).asJava) + + // poll should fail because there is no offset reset strategy set. + // we only fail when resetting positions after the coordinator is known, so use a long timeout.
+ assertThrows(classOf[NoOffsetForPartitionException], () => consumer.poll(Duration.ofMillis(15000))) + + // seek to out of range position + val outOfRangePos = totalRecords + 1 + consumer.seek(tp, outOfRangePos) + val e = assertThrows(classOf[OffsetOutOfRangeException], () => consumer.poll(Duration.ofMillis(20000))) + val outOfRangePartitions = e.offsetOutOfRangePartitions() + assertNotNull(outOfRangePartitions) + assertEquals(1, outOfRangePartitions.size) + assertEquals(outOfRangePos.toLong, outOfRangePartitions.get(tp)) + } + + @Test + def testFetchRecordLargerThanFetchMaxBytes(): Unit = { + val maxFetchBytes = 10 * 1024 + this.consumerConfig.setProperty(ConsumerConfig.FETCH_MAX_BYTES_CONFIG, maxFetchBytes.toString) + checkLargeRecord(maxFetchBytes + 1) + } + + private def checkLargeRecord(producerRecordSize: Int): Unit = { + val consumer = createConsumer() + + // produce a record that is larger than the configured fetch size + val record = new ProducerRecord(tp.topic(), tp.partition(), "key".getBytes, + new Array[Byte](producerRecordSize)) + val producer = createProducer() + producer.send(record) + + // consuming a record that is too large should succeed since KIP-74 + consumer.assign(List(tp).asJava) + val records = consumer.poll(Duration.ofMillis(20000)) + assertEquals(1, records.count) + val consumerRecord = records.iterator().next() + assertEquals(0L, consumerRecord.offset) + assertEquals(tp.topic(), consumerRecord.topic()) + assertEquals(tp.partition(), consumerRecord.partition()) + assertArrayEquals(record.key(), consumerRecord.key()) + assertArrayEquals(record.value(), consumerRecord.value()) + } + + /** We should only return a large record if it's the first record in the first non-empty partition of the fetch request */ + @Test + def testFetchHonoursFetchSizeIfLargeRecordNotFirst(): Unit = { + val maxFetchBytes = 10 * 1024 + this.consumerConfig.setProperty(ConsumerConfig.FETCH_MAX_BYTES_CONFIG, maxFetchBytes.toString) + checkFetchHonoursSizeIfLargeRecordNotFirst(maxFetchBytes) + } + + private def checkFetchHonoursSizeIfLargeRecordNotFirst(largeProducerRecordSize: Int): Unit = { + val consumer = createConsumer() + + val smallRecord = new ProducerRecord(tp.topic(), tp.partition(), "small".getBytes, + "value".getBytes) + val largeRecord = new ProducerRecord(tp.topic(), tp.partition(), "large".getBytes, + new Array[Byte](largeProducerRecordSize)) + + val producer = createProducer() + producer.send(smallRecord).get + producer.send(largeRecord).get + + // we should only get the small record in the first `poll` + consumer.assign(List(tp).asJava) + val records = consumer.poll(Duration.ofMillis(20000)) + assertEquals(1, records.count) + val consumerRecord = records.iterator().next() + assertEquals(0L, consumerRecord.offset) + assertEquals(tp.topic(), consumerRecord.topic()) + assertEquals(tp.partition(), consumerRecord.partition()) + assertArrayEquals(smallRecord.key(), consumerRecord.key()) + assertArrayEquals(smallRecord.value(), consumerRecord.value()) + } + + /** We should only return a large record if it's the first record in the first partition of the fetch request */ + @Test + def testFetchHonoursMaxPartitionFetchBytesIfLargeRecordNotFirst(): Unit = { + val maxPartitionFetchBytes = 10 * 1024 + this.consumerConfig.setProperty(ConsumerConfig.MAX_PARTITION_FETCH_BYTES_CONFIG, maxPartitionFetchBytes.toString) + checkFetchHonoursSizeIfLargeRecordNotFirst(maxPartitionFetchBytes) + } + + @Test + def testFetchRecordLargerThanMaxPartitionFetchBytes(): Unit = { + val maxPartitionFetchBytes 
= 10 * 1024 + this.consumerConfig.setProperty(ConsumerConfig.MAX_PARTITION_FETCH_BYTES_CONFIG, maxPartitionFetchBytes.toString) + checkLargeRecord(maxPartitionFetchBytes + 1) + } + + /** Test that we consume all partitions if fetch max bytes and max.partition.fetch.bytes are low */ + @Test + def testLowMaxFetchSizeForRequestAndPartition(): Unit = { + // one of the effects of this is that there will be some log reads where `0 < remaining limit bytes < message size` + // and we don't return the message because it's not the first message in the first non-empty partition of the fetch + // this behaves a little differently than when the remaining limit bytes is 0 and it's important to test it + this.consumerConfig.setProperty(ConsumerConfig.FETCH_MAX_BYTES_CONFIG, "500") + this.consumerConfig.setProperty(ConsumerConfig.MAX_PARTITION_FETCH_BYTES_CONFIG, "100") + + // Avoid a rebalance while the records are being sent (the default is 6 seconds) + this.consumerConfig.setProperty(ConsumerConfig.MAX_POLL_INTERVAL_MS_CONFIG, 20000.toString) + val consumer = createConsumer() + + val topic1 = "topic1" + val topic2 = "topic2" + val topic3 = "topic3" + val partitionCount = 30 + val topics = Seq(topic1, topic2, topic3) + topics.foreach { topicName => + createTopic(topicName, partitionCount, brokerCount) + } + + val partitions = topics.flatMap { topic => + (0 until partitionCount).map(new TopicPartition(topic, _)) + } + + assertEquals(0, consumer.assignment().size) + + consumer.subscribe(List(topic1, topic2, topic3).asJava) + + awaitAssignment(consumer, partitions.toSet) + + val producer = createProducer() + + val producerRecords = partitions.flatMap(sendRecords(producer, numRecords = partitionCount, _)) + + val consumerRecords = consumeRecords(consumer, producerRecords.size) + + val expected = producerRecords.map { record => + (record.topic, record.partition, new String(record.key), new String(record.value), record.timestamp) + }.toSet + + val actual = consumerRecords.map { record => + (record.topic, record.partition, new String(record.key), new String(record.value), record.timestamp) + }.toSet + + assertEquals(expected, actual) + } + + @Test + def testRoundRobinAssignment(): Unit = { + // 1 consumer using round-robin assignment + this.consumerConfig.setProperty(ConsumerConfig.GROUP_ID_CONFIG, "roundrobin-group") + this.consumerConfig.setProperty(ConsumerConfig.PARTITION_ASSIGNMENT_STRATEGY_CONFIG, classOf[RoundRobinAssignor].getName) + val consumer = createConsumer() + + // create two new topics, each having 2 partitions + val topic1 = "topic1" + val topic2 = "topic2" + val producer = createProducer() + val expectedAssignment = createTopicAndSendRecords(producer, topic1, 2, 100) ++ + createTopicAndSendRecords(producer, topic2, 2, 100) + + assertEquals(0, consumer.assignment().size) + + // subscribe to two topics + consumer.subscribe(List(topic1, topic2).asJava) + awaitAssignment(consumer, expectedAssignment) + + // add one more topic with 2 partitions + val topic3 = "topic3" + createTopicAndSendRecords(producer, topic3, 2, 100) + + val newExpectedAssignment = expectedAssignment ++ Set(new TopicPartition(topic3, 0), new TopicPartition(topic3, 1)) + consumer.subscribe(List(topic1, topic2, topic3).asJava) + awaitAssignment(consumer, newExpectedAssignment) + + // remove the topic we just added + consumer.subscribe(List(topic1, topic2).asJava) + awaitAssignment(consumer, expectedAssignment) + + consumer.unsubscribe() + assertEquals(0, consumer.assignment().size) + } + + @Test + def
testMultiConsumerRoundRobinAssignor(): Unit = { + this.consumerConfig.setProperty(ConsumerConfig.GROUP_ID_CONFIG, "roundrobin-group") + this.consumerConfig.setProperty(ConsumerConfig.PARTITION_ASSIGNMENT_STRATEGY_CONFIG, classOf[RoundRobinAssignor].getName) + + // create two new topics, total number of partitions must be greater than number of consumers + val topic1 = "topic1" + val topic2 = "topic2" + val producer = createProducer() + val subscriptions = createTopicAndSendRecords(producer, topic1, 5, 100) ++ + createTopicAndSendRecords(producer, topic2, 8, 100) + + // create a group of consumers, subscribe the consumers to all the topics and start polling + // for the topic partition assignment + val (consumerGroup, consumerPollers) = createConsumerGroupAndWaitForAssignment(10, List(topic1, topic2), subscriptions) + try { + validateGroupAssignment(consumerPollers, subscriptions) + + // add one more consumer and validate re-assignment + addConsumersToGroupAndWaitForGroupAssignment(1, consumerGroup, consumerPollers, + List(topic1, topic2), subscriptions, "roundrobin-group") + } finally { + consumerPollers.foreach(_.shutdown()) + } + } + + /** + * This test runs the following scenario to verify sticky assignor behavior. + * Topics: single-topic, with random number of partitions, where #par is 10, 20, 30, 40, 50, 60, 70, 80, 90, or 100 + * Consumers: 9 consumers subscribed to the single topic + * Expected initial assignment: partitions are assigned to consumers in a round robin fashion. + * - (#par mod 9) consumers will get (#par / 9 + 1) partitions, and the rest get (#par / 9) partitions + * Then consumer #10 is added to the list (subscribing to the same single topic) + * Expected new assignment: + * - (#par / 10) partition per consumer, where one partition from each of the early (#par mod 9) consumers + * will move to consumer #10, leading to a total of (#par mod 9) partition movement + */ + @Test + def testMultiConsumerStickyAssignor(): Unit = { + + def reverse(m: Map[Long, Set[TopicPartition]]) = + m.values.toSet.flatten.map(v => (v, m.keys.filter(m(_).contains(v)).head)).toMap + + this.consumerConfig.setProperty(ConsumerConfig.GROUP_ID_CONFIG, "sticky-group") + this.consumerConfig.setProperty(ConsumerConfig.PARTITION_ASSIGNMENT_STRATEGY_CONFIG, classOf[StickyAssignor].getName) + + // create one new topic + val topic = "single-topic" + val rand = 1 + scala.util.Random.nextInt(10) + val producer = createProducer() + val partitions = createTopicAndSendRecords(producer, topic, rand * 10, 100) + + // create a group of consumers, subscribe the consumers to the single topic and start polling + // for the topic partition assignment + val (consumerGroup, consumerPollers) = createConsumerGroupAndWaitForAssignment(9, List(topic), partitions) + validateGroupAssignment(consumerPollers, partitions) + val prePartition2PollerId = reverse(consumerPollers.map(poller => (poller.getId, poller.consumerAssignment())).toMap) + + // add one more consumer and validate re-assignment + addConsumersToGroupAndWaitForGroupAssignment(1, consumerGroup, consumerPollers, List(topic), partitions, "sticky-group") + + val postPartition2PollerId = reverse(consumerPollers.map(poller => (poller.getId, poller.consumerAssignment())).toMap) + val keys = prePartition2PollerId.keySet.union(postPartition2PollerId.keySet) + var changes = 0 + keys.foreach { key => + val preVal = prePartition2PollerId.get(key) + val postVal = postPartition2PollerId.get(key) + if (preVal.nonEmpty && postVal.nonEmpty) { + if (preVal.get != postVal.get) + 
changes += 1 + } else + changes += 1 + } + + consumerPollers.foreach(_.shutdown()) + + assertEquals(rand, changes, s"Expected $rand topic partitions to have switched to other consumers.") + } + + /** + * This test re-uses BaseConsumerTest's consumers. + * As a result, it is testing the default assignment strategy set by BaseConsumerTest + */ + @Test + def testMultiConsumerDefaultAssignor(): Unit = { + // use consumers and topics defined in this class + one more topic + val producer = createProducer() + sendRecords(producer, numRecords = 100, tp) + sendRecords(producer, numRecords = 100, tp2) + val topic1 = "topic1" + val subscriptions = Set(tp, tp2) ++ createTopicAndSendRecords(producer, topic1, 5, 100) + + // subscribe all consumers to all topics and validate the assignment + + val consumersInGroup = Buffer[KafkaConsumer[Array[Byte], Array[Byte]]]() + consumersInGroup += createConsumer() + consumersInGroup += createConsumer() + + val consumerPollers = subscribeConsumers(consumersInGroup, List(topic, topic1)) + try { + validateGroupAssignment(consumerPollers, subscriptions) + + // add 2 more consumers and validate re-assignment + addConsumersToGroupAndWaitForGroupAssignment(2, consumersInGroup, consumerPollers, List(topic, topic1), subscriptions) + + // add one more topic and validate partition re-assignment + val topic2 = "topic2" + val expandedSubscriptions = subscriptions ++ createTopicAndSendRecords(producer, topic2, 3, 100) + changeConsumerGroupSubscriptionAndValidateAssignment(consumerPollers, List(topic, topic1, topic2), expandedSubscriptions) + + // remove the topic we just added and validate re-assignment + changeConsumerGroupSubscriptionAndValidateAssignment(consumerPollers, List(topic, topic1), subscriptions) + + } finally { + consumerPollers.foreach(_.shutdown()) + } + } + + /** + * This test re-uses BaseConsumerTest's consumers. + * As a result, it is testing the default assignment strategy set by BaseConsumerTest + * It tests that the assignment results are as expected when using the default assignor (i.e.
Range assignor) + */ + @Test + def testMultiConsumerDefaultAssignorAndVerifyAssignment(): Unit = { + // create two new topics, each having 3 partitions + val topic1 = "topic1" + val topic2 = "topic2" + + createTopic(topic1, 3) + createTopic(topic2, 3) + + val consumersInGroup = Buffer[KafkaConsumer[Array[Byte], Array[Byte]]]() + consumersInGroup += createConsumer() + consumersInGroup += createConsumer() + + val tp1_0 = new TopicPartition(topic1, 0) + val tp1_1 = new TopicPartition(topic1, 1) + val tp1_2 = new TopicPartition(topic1, 2) + val tp2_0 = new TopicPartition(topic2, 0) + val tp2_1 = new TopicPartition(topic2, 1) + val tp2_2 = new TopicPartition(topic2, 2) + + val subscriptions = Set(tp1_0, tp1_1, tp1_2, tp2_0, tp2_1, tp2_2) + val consumerPollers = subscribeConsumers(consumersInGroup, List(topic1, topic2)) + + val expectedAssignment = Buffer(Set(tp1_0, tp1_1, tp2_0, tp2_1), Set(tp1_2, tp2_2)) + + try { + validateGroupAssignment(consumerPollers, subscriptions, expectedAssignment = expectedAssignment) + } finally { + consumerPollers.foreach(_.shutdown()) + } + } + + @Test + def testMultiConsumerSessionTimeoutOnStopPolling(): Unit = { + runMultiConsumerSessionTimeoutTest(false) + } + + @Test + def testMultiConsumerSessionTimeoutOnClose(): Unit = { + runMultiConsumerSessionTimeoutTest(true) + } + + @Test + def testInterceptors(): Unit = { + val appendStr = "mock" + MockConsumerInterceptor.resetCounters() + MockProducerInterceptor.resetCounters() + + // create producer with interceptor + val producerProps = new Properties() + producerProps.put(ProducerConfig.INTERCEPTOR_CLASSES_CONFIG, classOf[MockProducerInterceptor].getName) + producerProps.put("mock.interceptor.append", appendStr) + val testProducer = createProducer(keySerializer = new StringSerializer, + valueSerializer = new StringSerializer, + configOverrides = producerProps) + + // produce records + val numRecords = 10 + (0 until numRecords).map { i => + testProducer.send(new ProducerRecord(tp.topic, tp.partition, s"key $i", s"value $i")) + }.foreach(_.get) + assertEquals(numRecords, MockProducerInterceptor.ONSEND_COUNT.intValue) + assertEquals(numRecords, MockProducerInterceptor.ON_SUCCESS_COUNT.intValue) + // send invalid record + assertThrows(classOf[Throwable], () => testProducer.send(null), () => "Should not allow sending a null record") + assertEquals(1, MockProducerInterceptor.ON_ERROR_COUNT.intValue, "Interceptor should be notified about exception") + assertEquals(0, MockProducerInterceptor.ON_ERROR_WITH_METADATA_COUNT.intValue(), "Interceptor should not receive metadata with an exception when record is null") + + // create consumer with interceptor + this.consumerConfig.setProperty(ConsumerConfig.INTERCEPTOR_CLASSES_CONFIG, "org.apache.kafka.test.MockConsumerInterceptor") + val testConsumer = createConsumer(keyDeserializer = new StringDeserializer, valueDeserializer = new StringDeserializer) + testConsumer.assign(List(tp).asJava) + testConsumer.seek(tp, 0) + + // consume and verify that values are modified by interceptors + val records = consumeRecords(testConsumer, numRecords) + for (i <- 0 until numRecords) { + val record = records(i) + assertEquals(s"key $i", new String(record.key)) + assertEquals(s"value $i$appendStr".toUpperCase(Locale.ROOT), new String(record.value)) + } + + // commit sync and verify onCommit is called + val commitCountBefore = MockConsumerInterceptor.ON_COMMIT_COUNT.intValue + testConsumer.commitSync(Map[TopicPartition, OffsetAndMetadata]((tp, new OffsetAndMetadata(2L))).asJava) + assertEquals(2, 
testConsumer.committed(Set(tp).asJava).get(tp).offset) + assertEquals(commitCountBefore + 1, MockConsumerInterceptor.ON_COMMIT_COUNT.intValue) + + // commit async and verify onCommit is called + sendAndAwaitAsyncCommit(testConsumer, Some(Map(tp -> new OffsetAndMetadata(5L)))) + assertEquals(5, testConsumer.committed(Set(tp).asJava).get(tp).offset) + assertEquals(commitCountBefore + 2, MockConsumerInterceptor.ON_COMMIT_COUNT.intValue) + + testConsumer.close() + testProducer.close() + + // cleanup + MockConsumerInterceptor.resetCounters() + MockProducerInterceptor.resetCounters() + } + + @Test + def testAutoCommitIntercept(): Unit = { + val topic2 = "topic2" + createTopic(topic2, 2, brokerCount) + + // produce records + val numRecords = 100 + val testProducer = createProducer(keySerializer = new StringSerializer, valueSerializer = new StringSerializer) + (0 until numRecords).map { i => + testProducer.send(new ProducerRecord(tp.topic(), tp.partition(), s"key $i", s"value $i")) + }.foreach(_.get) + + // create consumer with interceptor + this.consumerConfig.setProperty(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG, "true") + this.consumerConfig.setProperty(ConsumerConfig.INTERCEPTOR_CLASSES_CONFIG, "org.apache.kafka.test.MockConsumerInterceptor") + val testConsumer = createConsumer(keyDeserializer = new StringDeserializer, valueDeserializer = new StringDeserializer) + val rebalanceListener = new ConsumerRebalanceListener { + override def onPartitionsAssigned(partitions: util.Collection[TopicPartition]) = { + // keep partitions paused in this test so that we can verify the commits based on specific seeks + testConsumer.pause(partitions) + } + + override def onPartitionsRevoked(partitions: util.Collection[TopicPartition]) = {} + } + changeConsumerSubscriptionAndValidateAssignment(testConsumer, List(topic), Set(tp, tp2), rebalanceListener) + testConsumer.seek(tp, 10) + testConsumer.seek(tp2, 20) + + // change subscription to trigger rebalance + val commitCountBeforeRebalance = MockConsumerInterceptor.ON_COMMIT_COUNT.intValue() + changeConsumerSubscriptionAndValidateAssignment(testConsumer, + List(topic, topic2), + Set(tp, tp2, new TopicPartition(topic2, 0), new TopicPartition(topic2, 1)), + rebalanceListener) + + // after rebalancing, we should have reset to the committed positions + assertEquals(10, testConsumer.committed(Set(tp).asJava).get(tp).offset) + assertEquals(20, testConsumer.committed(Set(tp2).asJava).get(tp2).offset) + assertTrue(MockConsumerInterceptor.ON_COMMIT_COUNT.intValue() > commitCountBeforeRebalance) + + // verify commits are intercepted on close + val commitCountBeforeClose = MockConsumerInterceptor.ON_COMMIT_COUNT.intValue() + testConsumer.close() + assertTrue(MockConsumerInterceptor.ON_COMMIT_COUNT.intValue() > commitCountBeforeClose) + testProducer.close() + + // cleanup + MockConsumerInterceptor.resetCounters() + } + + @Test + def testInterceptorsWithWrongKeyValue(): Unit = { + val appendStr = "mock" + // create producer with interceptor that has different key and value types from the producer + val producerProps = new Properties() + producerProps.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, brokerList) + producerProps.put(ProducerConfig.INTERCEPTOR_CLASSES_CONFIG, "org.apache.kafka.test.MockProducerInterceptor") + producerProps.put("mock.interceptor.append", appendStr) + val testProducer = createProducer() + + // producing records should succeed + testProducer.send(new ProducerRecord(tp.topic(), tp.partition(), s"key".getBytes, s"value will not be modified".getBytes)) + + // 
create consumer with interceptor that has different key and value types from the consumer + this.consumerConfig.setProperty(ConsumerConfig.INTERCEPTOR_CLASSES_CONFIG, "org.apache.kafka.test.MockConsumerInterceptor") + val testConsumer = createConsumer() + + testConsumer.assign(List(tp).asJava) + testConsumer.seek(tp, 0) + + // consume and verify that values are not modified by interceptors -- their exceptions are caught and logged, but not propagated + val records = consumeRecords(testConsumer, 1) + val record = records.head + assertEquals(s"value will not be modified", new String(record.value())) + } + + @Test + def testConsumeMessagesWithCreateTime(): Unit = { + val numRecords = 50 + // Test non-compressed messages + val producer = createProducer() + val startingTimestamp = System.currentTimeMillis() + sendRecords(producer, numRecords, tp, startingTimestamp = startingTimestamp) + val consumer = createConsumer() + consumer.assign(List(tp).asJava) + consumeAndVerifyRecords(consumer = consumer, numRecords = numRecords, startingOffset = 0, startingTimestamp = startingTimestamp) + + // Test compressed messages + sendCompressedMessages(numRecords, tp2) + consumer.assign(List(tp2).asJava) + consumeAndVerifyRecords(consumer = consumer, numRecords = numRecords, tp = tp2, startingOffset = 0) + } + + @Test + def testConsumeMessagesWithLogAppendTime(): Unit = { + val topicName = "testConsumeMessagesWithLogAppendTime" + val topicProps = new Properties() + topicProps.setProperty(LogConfig.MessageTimestampTypeProp, "LogAppendTime") + createTopic(topicName, 2, 2, topicProps) + + val startTime = System.currentTimeMillis() + val numRecords = 50 + + // Test non-compressed messages + val tp1 = new TopicPartition(topicName, 0) + val producer = createProducer() + sendRecords(producer, numRecords, tp1) + + val consumer = createConsumer() + consumer.assign(List(tp1).asJava) + consumeAndVerifyRecords(consumer = consumer, numRecords = numRecords, tp = tp1, startingOffset = 0, startingKeyAndValueIndex = 0, + startingTimestamp = startTime, timestampType = TimestampType.LOG_APPEND_TIME) + + // Test compressed messages + val tp2 = new TopicPartition(topicName, 1) + sendCompressedMessages(numRecords, tp2) + consumer.assign(List(tp2).asJava) + consumeAndVerifyRecords(consumer = consumer, numRecords = numRecords, tp = tp2, startingOffset = 0, startingKeyAndValueIndex = 0, + startingTimestamp = startTime, timestampType = TimestampType.LOG_APPEND_TIME) + } + + @Test + def testListTopics(): Unit = { + val numParts = 2 + val topic1 = "part-test-topic-1" + val topic2 = "part-test-topic-2" + val topic3 = "part-test-topic-3" + createTopic(topic1, numParts, 1) + createTopic(topic2, numParts, 1) + createTopic(topic3, numParts, 1) + + val consumer = createConsumer() + val topics = consumer.listTopics() + assertNotNull(topics) + assertEquals(5, topics.size()) + assertEquals(5, topics.keySet().size()) + assertEquals(2, topics.get(topic1).size) + assertEquals(2, topics.get(topic2).size) + assertEquals(2, topics.get(topic3).size) + } + + @Test + def testUnsubscribeTopic(): Unit = { + this.consumerConfig.setProperty(ConsumerConfig.SESSION_TIMEOUT_MS_CONFIG, "100") // timeout quickly to avoid slow test + this.consumerConfig.setProperty(ConsumerConfig.HEARTBEAT_INTERVAL_MS_CONFIG, "30") + val consumer = createConsumer() + + val listener = new TestConsumerReassignmentListener() + consumer.subscribe(List(topic).asJava, listener) + + // the initial subscription should cause a callback execution + awaitRebalance(consumer, listener) + + 
consumer.subscribe(List[String]().asJava) + assertEquals(0, consumer.assignment.size()) + } + + @Test + def testPauseStateNotPreservedByRebalance(): Unit = { + this.consumerConfig.setProperty(ConsumerConfig.SESSION_TIMEOUT_MS_CONFIG, "100") // timeout quickly to avoid slow test + this.consumerConfig.setProperty(ConsumerConfig.HEARTBEAT_INTERVAL_MS_CONFIG, "30") + val consumer = createConsumer() + + val producer = createProducer() + val startingTimestamp = System.currentTimeMillis() + sendRecords(producer, numRecords = 5, tp, startingTimestamp = startingTimestamp) + consumer.subscribe(List(topic).asJava) + consumeAndVerifyRecords(consumer = consumer, numRecords = 5, startingOffset = 0, startingTimestamp = startingTimestamp) + consumer.pause(List(tp).asJava) + + // subscribe to a new topic to trigger a rebalance + consumer.subscribe(List("topic2").asJava) + + // after rebalance, our position should be reset and our pause state lost, + // so we should be able to consume from the beginning + consumeAndVerifyRecords(consumer = consumer, numRecords = 0, startingOffset = 5, startingTimestamp = startingTimestamp) + } + + @Test + def testCommitSpecifiedOffsets(): Unit = { + val producer = createProducer() + sendRecords(producer, numRecords = 5, tp) + sendRecords(producer, numRecords = 7, tp2) + + val consumer = createConsumer() + consumer.assign(List(tp, tp2).asJava) + + val pos1 = consumer.position(tp) + val pos2 = consumer.position(tp2) + consumer.commitSync(Map[TopicPartition, OffsetAndMetadata]((tp, new OffsetAndMetadata(3L))).asJava) + assertEquals(3, consumer.committed(Set(tp).asJava).get(tp).offset) + assertNull(consumer.committed(Set(tp2).asJava).get(tp2)) + + // Positions should not change + assertEquals(pos1, consumer.position(tp)) + assertEquals(pos2, consumer.position(tp2)) + consumer.commitSync(Map[TopicPartition, OffsetAndMetadata]((tp2, new OffsetAndMetadata(5L))).asJava) + assertEquals(3, consumer.committed(Set(tp).asJava).get(tp).offset) + assertEquals(5, consumer.committed(Set(tp2).asJava).get(tp2).offset) + + // Using async should pick up the committed changes after commit completes + sendAndAwaitAsyncCommit(consumer, Some(Map(tp2 -> new OffsetAndMetadata(7L)))) + assertEquals(7, consumer.committed(Set(tp2).asJava).get(tp2).offset) + } + + @Test + def testAutoCommitOnRebalance(): Unit = { + val topic2 = "topic2" + createTopic(topic2, 2, brokerCount) + + this.consumerConfig.setProperty(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG, "true") + val consumer = createConsumer() + + val numRecords = 10000 + val producer = createProducer() + sendRecords(producer, numRecords, tp) + + val rebalanceListener = new ConsumerRebalanceListener { + override def onPartitionsAssigned(partitions: util.Collection[TopicPartition]) = { + // keep partitions paused in this test so that we can verify the commits based on specific seeks + consumer.pause(partitions) + } + + override def onPartitionsRevoked(partitions: util.Collection[TopicPartition]) = {} + } + + consumer.subscribe(List(topic).asJava, rebalanceListener) + + awaitAssignment(consumer, Set(tp, tp2)) + + consumer.seek(tp, 300) + consumer.seek(tp2, 500) + + // change subscription to trigger rebalance + consumer.subscribe(List(topic, topic2).asJava, rebalanceListener) + + val newAssignment = Set(tp, tp2, new TopicPartition(topic2, 0), new TopicPartition(topic2, 1)) + awaitAssignment(consumer, newAssignment) + + // after rebalancing, we should have reset to the committed positions + assertEquals(300, consumer.committed(Set(tp).asJava).get(tp).offset) + 
assertEquals(500, consumer.committed(Set(tp2).asJava).get(tp2).offset) + } + + @Test + def testPerPartitionLeadMetricsCleanUpWithSubscribe(): Unit = { + val numMessages = 1000 + val topic2 = "topic2" + createTopic(topic2, 2, brokerCount) + // send some messages. + val producer = createProducer() + sendRecords(producer, numMessages, tp) + // Test subscribe + // Create a consumer and consume some messages. + consumerConfig.setProperty(ConsumerConfig.GROUP_ID_CONFIG, "testPerPartitionLeadMetricsCleanUpWithSubscribe") + consumerConfig.setProperty(ConsumerConfig.CLIENT_ID_CONFIG, "testPerPartitionLeadMetricsCleanUpWithSubscribe") + val consumer = createConsumer() + val listener = new TestConsumerReassignmentListener + consumer.subscribe(List(topic, topic2).asJava, listener) + val records = awaitNonEmptyRecords(consumer, tp) + assertEquals(1, listener.callsToAssigned, "should be assigned once") + // Verify that the metric exists. + val tags1 = new util.HashMap[String, String]() + tags1.put("client-id", "testPerPartitionLeadMetricsCleanUpWithSubscribe") + tags1.put("topic", tp.topic()) + tags1.put("partition", String.valueOf(tp.partition())) + + val tags2 = new util.HashMap[String, String]() + tags2.put("client-id", "testPerPartitionLeadMetricsCleanUpWithSubscribe") + tags2.put("topic", tp2.topic()) + tags2.put("partition", String.valueOf(tp2.partition())) + val fetchLead0 = consumer.metrics.get(new MetricName("records-lead", "consumer-fetch-manager-metrics", "", tags1)) + assertNotNull(fetchLead0) + assertEquals(records.count.toDouble, fetchLead0.metricValue(), s"The lead should be ${records.count}") + + // Remove topic from subscription + consumer.subscribe(List(topic2).asJava, listener) + awaitRebalance(consumer, listener) + // Verify the metric has gone + assertNull(consumer.metrics.get(new MetricName("records-lead", "consumer-fetch-manager-metrics", "", tags1))) + assertNull(consumer.metrics.get(new MetricName("records-lead", "consumer-fetch-manager-metrics", "", tags2))) + } + + @Test + def testPerPartitionLagMetricsCleanUpWithSubscribe(): Unit = { + val numMessages = 1000 + val topic2 = "topic2" + createTopic(topic2, 2, brokerCount) + // send some messages. + val producer = createProducer() + sendRecords(producer, numMessages, tp) + // Test subscribe + // Create a consumer and consume some messages. + consumerConfig.setProperty(ConsumerConfig.GROUP_ID_CONFIG, "testPerPartitionLagMetricsCleanUpWithSubscribe") + consumerConfig.setProperty(ConsumerConfig.CLIENT_ID_CONFIG, "testPerPartitionLagMetricsCleanUpWithSubscribe") + val consumer = createConsumer() + val listener = new TestConsumerReassignmentListener + consumer.subscribe(List(topic, topic2).asJava, listener) + val records = awaitNonEmptyRecords(consumer, tp) + assertEquals(1, listener.callsToAssigned, "should be assigned once") + // Verify that the metric exists.
+ val tags1 = new util.HashMap[String, String]() + tags1.put("client-id", "testPerPartitionLagMetricsCleanUpWithSubscribe") + tags1.put("topic", tp.topic()) + tags1.put("partition", String.valueOf(tp.partition())) + + val tags2 = new util.HashMap[String, String]() + tags2.put("client-id", "testPerPartitionLagMetricsCleanUpWithSubscribe") + tags2.put("topic", tp2.topic()) + tags2.put("partition", String.valueOf(tp2.partition())) + val fetchLag0 = consumer.metrics.get(new MetricName("records-lag", "consumer-fetch-manager-metrics", "", tags1)) + assertNotNull(fetchLag0) + val expectedLag = numMessages - records.count + assertEquals(expectedLag, fetchLag0.metricValue.asInstanceOf[Double], epsilon, s"The lag should be $expectedLag") + + // Remove topic from subscription + consumer.subscribe(List(topic2).asJava, listener) + awaitRebalance(consumer, listener) + // Verify the metric has gone + assertNull(consumer.metrics.get(new MetricName("records-lag", "consumer-fetch-manager-metrics", "", tags1))) + assertNull(consumer.metrics.get(new MetricName("records-lag", "consumer-fetch-manager-metrics", "", tags2))) + } + + @Test + def testPerPartitionLeadMetricsCleanUpWithAssign(): Unit = { + val numMessages = 1000 + // Test assign + // send some messages. + val producer = createProducer() + sendRecords(producer, numMessages, tp) + sendRecords(producer, numMessages, tp2) + + consumerConfig.setProperty(ConsumerConfig.GROUP_ID_CONFIG, "testPerPartitionLeadMetricsCleanUpWithAssign") + consumerConfig.setProperty(ConsumerConfig.CLIENT_ID_CONFIG, "testPerPartitionLeadMetricsCleanUpWithAssign") + val consumer = createConsumer() + consumer.assign(List(tp).asJava) + val records = awaitNonEmptyRecords(consumer, tp) + // Verify that the metric exists. + val tags = new util.HashMap[String, String]() + tags.put("client-id", "testPerPartitionLeadMetricsCleanUpWithAssign") + tags.put("topic", tp.topic()) + tags.put("partition", String.valueOf(tp.partition())) + val fetchLead = consumer.metrics.get(new MetricName("records-lead", "consumer-fetch-manager-metrics", "", tags)) + assertNotNull(fetchLead) + + assertEquals(records.count.toDouble, fetchLead.metricValue(), s"The lead should be ${records.count}") + + consumer.assign(List(tp2).asJava) + awaitNonEmptyRecords(consumer, tp2) + assertNull(consumer.metrics.get(new MetricName("records-lead", "consumer-fetch-manager-metrics", "", tags))) + } + + @Test + def testPerPartitionLagMetricsCleanUpWithAssign(): Unit = { + val numMessages = 1000 + // Test assign + // send some messages. + val producer = createProducer() + sendRecords(producer, numMessages, tp) + sendRecords(producer, numMessages, tp2) + + consumerConfig.setProperty(ConsumerConfig.GROUP_ID_CONFIG, "testPerPartitionLagMetricsCleanUpWithAssign") + consumerConfig.setProperty(ConsumerConfig.CLIENT_ID_CONFIG, "testPerPartitionLagMetricsCleanUpWithAssign") + val consumer = createConsumer() + consumer.assign(List(tp).asJava) + val records = awaitNonEmptyRecords(consumer, tp) + // Verify that the metric exists.
+ val tags = new util.HashMap[String, String]() + tags.put("client-id", "testPerPartitionLagMetricsCleanUpWithAssign") + tags.put("topic", tp.topic()) + tags.put("partition", String.valueOf(tp.partition())) + val fetchLag = consumer.metrics.get(new MetricName("records-lag", "consumer-fetch-manager-metrics", "", tags)) + assertNotNull(fetchLag) + + val expectedLag = numMessages - records.count + assertEquals(expectedLag, fetchLag.metricValue.asInstanceOf[Double], epsilon, s"The lag should be $expectedLag") + + consumer.assign(List(tp2).asJava) + awaitNonEmptyRecords(consumer, tp2) + assertNull(consumer.metrics.get(new MetricName(tp.toString + ".records-lag", "consumer-fetch-manager-metrics", "", tags))) + assertNull(consumer.metrics.get(new MetricName("records-lag", "consumer-fetch-manager-metrics", "", tags))) + } + + @Test + def testPerPartitionLagMetricsWhenReadCommitted(): Unit = { + val numMessages = 1000 + // send some messages. + val producer = createProducer() + sendRecords(producer, numMessages, tp) + sendRecords(producer, numMessages, tp2) + + consumerConfig.setProperty(ConsumerConfig.ISOLATION_LEVEL_CONFIG, "read_committed") + consumerConfig.setProperty(ConsumerConfig.GROUP_ID_CONFIG, "testPerPartitionLagMetricsCleanUpWithAssign") + consumerConfig.setProperty(ConsumerConfig.CLIENT_ID_CONFIG, "testPerPartitionLagMetricsCleanUpWithAssign") + val consumer = createConsumer() + consumer.assign(List(tp).asJava) + awaitNonEmptyRecords(consumer, tp) + // Verify that the metric exists. + val tags = new util.HashMap[String, String]() + tags.put("client-id", "testPerPartitionLagMetricsCleanUpWithAssign") + tags.put("topic", tp.topic()) + tags.put("partition", String.valueOf(tp.partition())) + val fetchLag = consumer.metrics.get(new MetricName("records-lag", "consumer-fetch-manager-metrics", "", tags)) + assertNotNull(fetchLag) + } + + @Test + def testPerPartitionLeadWithMaxPollRecords(): Unit = { + val numMessages = 1000 + val maxPollRecords = 10 + val producer = createProducer() + sendRecords(producer, numMessages, tp) + + consumerConfig.setProperty(ConsumerConfig.GROUP_ID_CONFIG, "testPerPartitionLeadWithMaxPollRecords") + consumerConfig.setProperty(ConsumerConfig.CLIENT_ID_CONFIG, "testPerPartitionLeadWithMaxPollRecords") + consumerConfig.setProperty(ConsumerConfig.MAX_POLL_RECORDS_CONFIG, maxPollRecords.toString) + val consumer = createConsumer() + consumer.assign(List(tp).asJava) + awaitNonEmptyRecords(consumer, tp) + + val tags = new util.HashMap[String, String]() + tags.put("client-id", "testPerPartitionLeadWithMaxPollRecords") + tags.put("topic", tp.topic()) + tags.put("partition", String.valueOf(tp.partition())) + val lead = consumer.metrics.get(new MetricName("records-lead", "consumer-fetch-manager-metrics", "", tags)) + assertEquals(maxPollRecords, lead.metricValue().asInstanceOf[Double], s"The lead should be $maxPollRecords") + } + + @Test + def testPerPartitionLagWithMaxPollRecords(): Unit = { + val numMessages = 1000 + val maxPollRecords = 10 + val producer = createProducer() + sendRecords(producer, numMessages, tp) + + consumerConfig.setProperty(ConsumerConfig.GROUP_ID_CONFIG, "testPerPartitionLagWithMaxPollRecords") + consumerConfig.setProperty(ConsumerConfig.CLIENT_ID_CONFIG, "testPerPartitionLagWithMaxPollRecords") + consumerConfig.setProperty(ConsumerConfig.MAX_POLL_RECORDS_CONFIG, maxPollRecords.toString) + val consumer = createConsumer() + consumer.assign(List(tp).asJava) + val records = awaitNonEmptyRecords(consumer, tp) + + val tags = new util.HashMap[String, String]() +
tags.put("client-id", "testPerPartitionLagWithMaxPollRecords") + tags.put("topic", tp.topic()) + tags.put("partition", String.valueOf(tp.partition())) + val lag = consumer.metrics.get(new MetricName("records-lag", "consumer-fetch-manager-metrics", "", tags)) + + assertEquals(numMessages - records.count, lag.metricValue.asInstanceOf[Double], epsilon, s"The lag should be ${numMessages - records.count}") + } + + @Test + def testQuotaMetricsNotCreatedIfNoQuotasConfigured(): Unit = { + val numRecords = 1000 + val producer = createProducer() + val startingTimestamp = System.currentTimeMillis() + sendRecords(producer, numRecords, tp, startingTimestamp = startingTimestamp) + + val consumer = createConsumer() + consumer.assign(List(tp).asJava) + consumer.seek(tp, 0) + consumeAndVerifyRecords(consumer = consumer, numRecords = numRecords, startingOffset = 0, startingTimestamp = startingTimestamp) + + def assertNoMetric(broker: KafkaServer, name: String, quotaType: QuotaType, clientId: String): Unit = { + val metricName = broker.metrics.metricName("throttle-time", + quotaType.toString, + "", + "user", "", + "client-id", clientId) + assertNull(broker.metrics.metric(metricName), "Metric should not have been created " + metricName) + } + servers.foreach(assertNoMetric(_, "byte-rate", QuotaType.Produce, producerClientId)) + servers.foreach(assertNoMetric(_, "throttle-time", QuotaType.Produce, producerClientId)) + servers.foreach(assertNoMetric(_, "byte-rate", QuotaType.Fetch, consumerClientId)) + servers.foreach(assertNoMetric(_, "throttle-time", QuotaType.Fetch, consumerClientId)) + + servers.foreach(assertNoMetric(_, "request-time", QuotaType.Request, producerClientId)) + servers.foreach(assertNoMetric(_, "throttle-time", QuotaType.Request, producerClientId)) + servers.foreach(assertNoMetric(_, "request-time", QuotaType.Request, consumerClientId)) + servers.foreach(assertNoMetric(_, "throttle-time", QuotaType.Request, consumerClientId)) + + def assertNoExemptRequestMetric(broker: KafkaServer): Unit = { + val metricName = broker.metrics.metricName("exempt-request-time", QuotaType.Request.toString, "") + assertNull(broker.metrics.metric(metricName), "Metric should not have been created " + metricName) + } + servers.foreach(assertNoExemptRequestMetric) + } + + def runMultiConsumerSessionTimeoutTest(closeConsumer: Boolean): Unit = { + // use consumers defined in this class plus one additional consumer + // Use topic defined in this class + one additional topic + val producer = createProducer() + sendRecords(producer, numRecords = 100, tp) + sendRecords(producer, numRecords = 100, tp2) + val topic1 = "topic1" + val subscriptions = Set(tp, tp2) ++ createTopicAndSendRecords(producer, topic1, 6, 100) + + // first subscribe consumers that are defined in this class + val consumerPollers = Buffer[ConsumerAssignmentPoller]() + consumerPollers += subscribeConsumerAndStartPolling(createConsumer(), List(topic, topic1)) + consumerPollers += subscribeConsumerAndStartPolling(createConsumer(), List(topic, topic1)) + + // create one more consumer and add it to the group; we will timeout this consumer + val timeoutConsumer = createConsumer() + val timeoutPoller = subscribeConsumerAndStartPolling(timeoutConsumer, List(topic, topic1)) + consumerPollers += timeoutPoller + + // validate the initial assignment + validateGroupAssignment(consumerPollers, subscriptions) + + // stop polling and close one of the consumers, should trigger partition re-assignment among alive consumers + timeoutPoller.shutdown() + consumerPollers -= 
timeoutPoller + if (closeConsumer) + timeoutConsumer.close() + + validateGroupAssignment(consumerPollers, subscriptions, + Some(s"Did not get valid assignment for partitions ${subscriptions.asJava} after one consumer left"), 3 * groupMaxSessionTimeoutMs) + + // done with pollers and consumers + for (poller <- consumerPollers) + poller.shutdown() + } + + /** + * Creates consumer pollers corresponding to a given consumer group, one per consumer; subscribes consumers to + * 'topicsToSubscribe' topics, waits until consumers get topics assignment. + * + * When the function returns, consumer pollers will continue to poll until shutdown is called on every poller. + * + * @param consumerGroup consumer group + * @param topicsToSubscribe topics to which consumers will subscribe to + * @return collection of consumer pollers + */ + def subscribeConsumers(consumerGroup: mutable.Buffer[KafkaConsumer[Array[Byte], Array[Byte]]], + topicsToSubscribe: List[String]): mutable.Buffer[ConsumerAssignmentPoller] = { + val consumerPollers = mutable.Buffer[ConsumerAssignmentPoller]() + for (consumer <- consumerGroup) + consumerPollers += subscribeConsumerAndStartPolling(consumer, topicsToSubscribe) + consumerPollers + } + + /** + * Creates 'consumerCount' consumers and consumer pollers, one per consumer; subscribes consumers to + * 'topicsToSubscribe' topics, waits until consumers get topics assignment. + * + * When the function returns, consumer pollers will continue to poll until shutdown is called on every poller. + * + * @param consumerCount number of consumers to create + * @param topicsToSubscribe topics to which consumers will subscribe to + * @param subscriptions set of all topic partitions + * @return collection of created consumers and collection of corresponding consumer pollers + */ + def createConsumerGroupAndWaitForAssignment(consumerCount: Int, + topicsToSubscribe: List[String], + subscriptions: Set[TopicPartition]): (Buffer[KafkaConsumer[Array[Byte], Array[Byte]]], Buffer[ConsumerAssignmentPoller]) = { + assertTrue(consumerCount <= subscriptions.size) + val consumerGroup = Buffer[KafkaConsumer[Array[Byte], Array[Byte]]]() + for (_ <- 0 until consumerCount) + consumerGroup += createConsumer() + + // create consumer pollers, wait for assignment and validate it + val consumerPollers = subscribeConsumers(consumerGroup, topicsToSubscribe) + (consumerGroup, consumerPollers) + } + + def changeConsumerGroupSubscriptionAndValidateAssignment(consumerPollers: Buffer[ConsumerAssignmentPoller], + topicsToSubscribe: List[String], + subscriptions: Set[TopicPartition]): Unit = { + for (poller <- consumerPollers) + poller.subscribe(topicsToSubscribe) + + // since subscribe call to poller does not actually call consumer subscribe right away, wait + // until subscribe is called on all consumers + TestUtils.waitUntilTrue(() => { + consumerPollers.forall { poller => poller.isSubscribeRequestProcessed } + }, s"Failed to call subscribe on all consumers in the group for subscription $subscriptions", 1000L) + + validateGroupAssignment(consumerPollers, subscriptions, + Some(s"Did not get valid assignment for partitions ${subscriptions.asJava} after we changed subscription")) + } + + def changeConsumerSubscriptionAndValidateAssignment[K, V](consumer: Consumer[K, V], + topicsToSubscribe: List[String], + expectedAssignment: Set[TopicPartition], + rebalanceListener: ConsumerRebalanceListener): Unit = { + consumer.subscribe(topicsToSubscribe.asJava, rebalanceListener) + awaitAssignment(consumer, expectedAssignment) + } + + 
private def awaitNonEmptyRecords[K, V](consumer: Consumer[K, V], partition: TopicPartition): ConsumerRecords[K, V] = { + TestUtils.pollRecordsUntilTrue(consumer, (polledRecords: ConsumerRecords[K, V]) => { + if (polledRecords.records(partition).asScala.nonEmpty) + return polledRecords + false + }, s"Consumer did not consume any messages for partition $partition before timeout.") + throw new IllegalStateException("Should have timed out before reaching here") + } + + private def awaitAssignment(consumer: Consumer[_, _], expectedAssignment: Set[TopicPartition]): Unit = { + TestUtils.pollUntilTrue(consumer, () => consumer.assignment() == expectedAssignment.asJava, + s"Timed out while awaiting expected assignment $expectedAssignment. " + + s"The current assignment is ${consumer.assignment()}") + } + + @Test + def testConsumingWithNullGroupId(): Unit = { + val topic = "test_topic" + val partition = 0; + val tp = new TopicPartition(topic, partition) + createTopic(topic, 1, 1) + + TestUtils.waitUntilTrue(() => { + this.zkClient.topicExists(topic) + }, "Failed to create topic") + + val producer = createProducer() + producer.send(new ProducerRecord(topic, partition, "k1".getBytes, "v1".getBytes)).get() + producer.send(new ProducerRecord(topic, partition, "k2".getBytes, "v2".getBytes)).get() + producer.send(new ProducerRecord(topic, partition, "k3".getBytes, "v3".getBytes)).get() + producer.close() + + // consumer 1 uses the default group id and consumes from earliest offset + val consumer1Config = new Properties(consumerConfig) + consumer1Config.put(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "earliest") + consumer1Config.put(ConsumerConfig.CLIENT_ID_CONFIG, "consumer1") + val consumer1 = createConsumer( + configOverrides = consumer1Config, + configsToRemove = List(ConsumerConfig.GROUP_ID_CONFIG)) + + // consumer 2 uses the default group id and consumes from latest offset + val consumer2Config = new Properties(consumerConfig) + consumer2Config.put(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "latest") + consumer2Config.put(ConsumerConfig.CLIENT_ID_CONFIG, "consumer2") + val consumer2 = createConsumer( + configOverrides = consumer2Config, + configsToRemove = List(ConsumerConfig.GROUP_ID_CONFIG)) + + // consumer 3 uses the default group id and starts from an explicit offset + val consumer3Config = new Properties(consumerConfig) + consumer3Config.put(ConsumerConfig.CLIENT_ID_CONFIG, "consumer3") + val consumer3 = createConsumer( + configOverrides = consumer3Config, + configsToRemove = List(ConsumerConfig.GROUP_ID_CONFIG)) + + consumer1.assign(asList(tp)) + consumer2.assign(asList(tp)) + consumer3.assign(asList(tp)) + consumer3.seek(tp, 1) + + val numRecords1 = consumer1.poll(Duration.ofMillis(5000)).count() + assertThrows(classOf[InvalidGroupIdException], () => consumer1.commitSync()) + assertThrows(classOf[InvalidGroupIdException], () => consumer2.committed(Set(tp).asJava)) + + val numRecords2 = consumer2.poll(Duration.ofMillis(5000)).count() + val numRecords3 = consumer3.poll(Duration.ofMillis(5000)).count() + + consumer1.unsubscribe() + consumer2.unsubscribe() + consumer3.unsubscribe() + + consumer1.close() + consumer2.close() + consumer3.close() + + assertEquals(3, numRecords1, "Expected consumer1 to consume from earliest offset") + assertEquals(0, numRecords2, "Expected consumer2 to consume from latest offset") + assertEquals(2, numRecords3, "Expected consumer3 to consume from offset 1") + } + + @Test + def testConsumingWithEmptyGroupId(): Unit = { + val topic = "test_topic" + val partition = 0; + val tp 
= new TopicPartition(topic, partition) + createTopic(topic, 1, 1) + + TestUtils.waitUntilTrue(() => { + this.zkClient.topicExists(topic) + }, "Failed to create topic") + + val producer = createProducer() + producer.send(new ProducerRecord(topic, partition, "k1".getBytes, "v1".getBytes)).get() + producer.send(new ProducerRecord(topic, partition, "k2".getBytes, "v2".getBytes)).get() + producer.close() + + // consumer 1 uses the empty group id + val consumer1Config = new Properties(consumerConfig) + consumer1Config.put(ConsumerConfig.GROUP_ID_CONFIG, "") + consumer1Config.put(ConsumerConfig.CLIENT_ID_CONFIG, "consumer1") + consumer1Config.put(ConsumerConfig.MAX_POLL_RECORDS_CONFIG, "1") + val consumer1 = createConsumer(configOverrides = consumer1Config) + + // consumer 2 uses the empty group id and consumes from latest offset if there is no committed offset + val consumer2Config = new Properties(consumerConfig) + consumer2Config.put(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "latest") + consumer2Config.put(ConsumerConfig.GROUP_ID_CONFIG, "") + consumer2Config.put(ConsumerConfig.CLIENT_ID_CONFIG, "consumer2") + consumer2Config.put(ConsumerConfig.MAX_POLL_RECORDS_CONFIG, "1") + val consumer2 = createConsumer(configOverrides = consumer2Config) + + consumer1.assign(asList(tp)) + consumer2.assign(asList(tp)) + + val records1 = consumer1.poll(Duration.ofMillis(5000)) + consumer1.commitSync() + + val records2 = consumer2.poll(Duration.ofMillis(5000)) + consumer2.commitSync() + + consumer1.close() + consumer2.close() + + assertTrue(records1.count() == 1 && records1.records(tp).asScala.head.offset == 0, + "Expected consumer1 to consume one message from offset 0") + assertTrue(records2.count() == 1 && records2.records(tp).asScala.head.offset == 1, + "Expected consumer2 to consume one message from offset 1, which is the committed offset of consumer1") + } +} diff --git a/core/src/test/scala/integration/kafka/api/PlaintextEndToEndAuthorizationTest.scala b/core/src/test/scala/integration/kafka/api/PlaintextEndToEndAuthorizationTest.scala new file mode 100644 index 0000000..8e69a2d --- /dev/null +++ b/core/src/test/scala/integration/kafka/api/PlaintextEndToEndAuthorizationTest.scala @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package kafka.api + +import org.apache.kafka.common.config.internals.BrokerSecurityConfigs +import org.apache.kafka.common.network.ListenerName +import org.apache.kafka.common.security.auth._ +import org.apache.kafka.common.security.authenticator.DefaultKafkaPrincipalBuilder +import org.junit.jupiter.api.{BeforeEach, Test, TestInfo} +import org.junit.jupiter.api.Assertions._ +import org.apache.kafka.common.errors.TopicAuthorizationException + +// This test case uses a separate listener for client and inter-broker communication, from +// which we derive corresponding principals +object PlaintextEndToEndAuthorizationTest { + @volatile + private var clientListenerName = None: Option[String] + @volatile + private var serverListenerName = None: Option[String] + class TestClientPrincipalBuilder extends DefaultKafkaPrincipalBuilder(null, null) { + override def build(context: AuthenticationContext): KafkaPrincipal = { + clientListenerName = Some(context.listenerName) + context match { + case ctx: PlaintextAuthenticationContext if ctx.clientAddress != null => + new KafkaPrincipal(KafkaPrincipal.USER_TYPE, "client") + case _ => + KafkaPrincipal.ANONYMOUS + } + } + } + + class TestServerPrincipalBuilder extends DefaultKafkaPrincipalBuilder(null, null) { + override def build(context: AuthenticationContext): KafkaPrincipal = { + serverListenerName = Some(context.listenerName) + context match { + case ctx: PlaintextAuthenticationContext => + new KafkaPrincipal(KafkaPrincipal.USER_TYPE, "server") + case _ => + KafkaPrincipal.ANONYMOUS + } + } + } +} + +class PlaintextEndToEndAuthorizationTest extends EndToEndAuthorizationTest { + import PlaintextEndToEndAuthorizationTest.{TestClientPrincipalBuilder, TestServerPrincipalBuilder} + + override protected def securityProtocol = SecurityProtocol.PLAINTEXT + override protected def listenerName: ListenerName = new ListenerName("CLIENT") + override protected def interBrokerListenerName: ListenerName = new ListenerName("SERVER") + + this.serverConfig.setProperty("listener.name.client." + BrokerSecurityConfigs.PRINCIPAL_BUILDER_CLASS_CONFIG, + classOf[TestClientPrincipalBuilder].getName) + this.serverConfig.setProperty("listener.name.server." + BrokerSecurityConfigs.PRINCIPAL_BUILDER_CLASS_CONFIG, + classOf[TestServerPrincipalBuilder].getName) + override val clientPrincipal = new KafkaPrincipal(KafkaPrincipal.USER_TYPE, "client") + override val kafkaPrincipal = new KafkaPrincipal(KafkaPrincipal.USER_TYPE, "server") + + @BeforeEach + override def setUp(testInfo: TestInfo): Unit = { + startSasl(jaasSections(List.empty, None, ZkSasl)) + super.setUp(testInfo) + } + + @Test + def testListenerName(): Unit = { + // To check the client listener name, establish a session on the server by sending any request eg sendRecords + val producer = createProducer() + assertThrows(classOf[TopicAuthorizationException], () => sendRecords(producer, numRecords = 1, tp)) + + assertEquals(Some("CLIENT"), PlaintextEndToEndAuthorizationTest.clientListenerName) + assertEquals(Some("SERVER"), PlaintextEndToEndAuthorizationTest.serverListenerName) + } + +} diff --git a/core/src/test/scala/integration/kafka/api/PlaintextProducerSendTest.scala b/core/src/test/scala/integration/kafka/api/PlaintextProducerSendTest.scala new file mode 100644 index 0000000..38febbc --- /dev/null +++ b/core/src/test/scala/integration/kafka/api/PlaintextProducerSendTest.scala @@ -0,0 +1,196 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.api + +import java.util.Properties +import java.util.concurrent.{ExecutionException, Future, TimeUnit} + +import kafka.log.LogConfig +import kafka.server.Defaults +import kafka.utils.TestUtils +import org.apache.kafka.clients.producer.{BufferExhaustedException, KafkaProducer, ProducerConfig, ProducerRecord, RecordMetadata} +import org.apache.kafka.common.errors.{InvalidTimestampException, RecordTooLargeException, SerializationException, TimeoutException} +import org.apache.kafka.common.record.{DefaultRecord, DefaultRecordBatch, Records, TimestampType} +import org.apache.kafka.common.serialization.ByteArraySerializer +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.Test + + +class PlaintextProducerSendTest extends BaseProducerSendTest { + + @Test + def testWrongSerializer(): Unit = { + val producerProps = new Properties() + producerProps.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, brokerList) + producerProps.put(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG, "org.apache.kafka.common.serialization.StringSerializer") + producerProps.put(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG, "org.apache.kafka.common.serialization.StringSerializer") + val producer = registerProducer(new KafkaProducer(producerProps)) + val record = new ProducerRecord[Array[Byte], Array[Byte]](topic, 0, "key".getBytes, "value".getBytes) + assertThrows(classOf[SerializationException], () => producer.send(record)) + } + + @Test + def testBatchSizeZero(): Unit = { + val producer = createProducer(brokerList = brokerList, + lingerMs = Int.MaxValue, + deliveryTimeoutMs = Int.MaxValue, + batchSize = 0) + sendAndVerify(producer) + } + + @Test + def testSendCompressedMessageWithLogAppendTime(): Unit = { + val producer = createProducer(brokerList = brokerList, + compressionType = "gzip", + lingerMs = Int.MaxValue, + deliveryTimeoutMs = Int.MaxValue) + sendAndVerifyTimestamp(producer, TimestampType.LOG_APPEND_TIME) + } + + @Test + def testSendNonCompressedMessageWithLogAppendTime(): Unit = { + val producer = createProducer(brokerList = brokerList, lingerMs = Int.MaxValue, deliveryTimeoutMs = Int.MaxValue) + sendAndVerifyTimestamp(producer, TimestampType.LOG_APPEND_TIME) + } + + /** + * testAutoCreateTopic + * + * The topic should be created upon sending the first message + */ + @Test + def testAutoCreateTopic(): Unit = { + val producer = createProducer(brokerList) + try { + // Send a message to auto-create the topic + val record = new ProducerRecord(topic, null, "key".getBytes, "value".getBytes) + assertEquals(0L, producer.send(record).get.offset, "Should have offset 0") + + // double check that the topic is created with leader elected + TestUtils.waitUntilLeaderIsElectedOrChanged(zkClient, topic, 0) + + } finally { + producer.close() + } + } + + @Test + def testSendWithInvalidCreateTime(): Unit = { + val 
topicProps = new Properties() + topicProps.setProperty(LogConfig.MessageTimestampDifferenceMaxMsProp, "1000") + createTopic(topic, 1, 2, topicProps) + + val producer = createProducer(brokerList = brokerList) + try { + val e = assertThrows(classOf[ExecutionException], + () => producer.send(new ProducerRecord(topic, 0, System.currentTimeMillis() - 1001, "key".getBytes, "value".getBytes)).get()).getCause + assertTrue(e.isInstanceOf[InvalidTimestampException]) + } finally { + producer.close() + } + + // Test compressed messages. + val compressedProducer = createProducer(brokerList = brokerList, compressionType = "gzip") + try { + val e = assertThrows(classOf[ExecutionException], + () => compressedProducer.send(new ProducerRecord(topic, 0, System.currentTimeMillis() - 1001, "key".getBytes, "value".getBytes)).get()).getCause + assertTrue(e.isInstanceOf[InvalidTimestampException]) + } finally { + compressedProducer.close() + } + } + + // Test that producer with max.block.ms=0 can be used to send in non-blocking mode + // where requests are failed immediately without blocking if metadata is not available + // or buffer is full. + @Test + def testNonBlockingProducer(): Unit = { + + def send(producer: KafkaProducer[Array[Byte],Array[Byte]]): Future[RecordMetadata] = { + producer.send(new ProducerRecord(topic, 0, "key".getBytes, new Array[Byte](1000))) + } + + def sendUntilQueued(producer: KafkaProducer[Array[Byte],Array[Byte]]): Future[RecordMetadata] = { + val (future, _) = TestUtils.computeUntilTrue(send(producer))(future => { + if (future.isDone) { + try { + future.get + true // Send was queued and completed successfully + } catch { + case _: ExecutionException => false + } + } else + true // Send future not yet complete, so it has been queued to be sent + }) + future + } + + def verifySendSuccess(future: Future[RecordMetadata]): Unit = { + val recordMetadata = future.get(30, TimeUnit.SECONDS) + assertEquals(topic, recordMetadata.topic) + assertEquals(0, recordMetadata.partition) + assertTrue(recordMetadata.offset >= 0, s"Invalid offset $recordMetadata") + } + + def verifyMetadataNotAvailable(future: Future[RecordMetadata]): Unit = { + assertTrue(future.isDone) // verify future was completed immediately + assertEquals(classOf[TimeoutException], assertThrows(classOf[ExecutionException], () => future.get).getCause.getClass) + } + + def verifyBufferExhausted(future: Future[RecordMetadata]): Unit = { + assertTrue(future.isDone) // verify future was completed immediately + assertEquals(classOf[BufferExhaustedException], assertThrows(classOf[ExecutionException], () => future.get).getCause.getClass) + } + + // Topic metadata not available, send should fail without blocking + val producer = createProducer(brokerList = brokerList, maxBlockMs = 0) + verifyMetadataNotAvailable(send(producer)) + + // Test that send starts succeeding once metadata is available + val future = sendUntilQueued(producer) + verifySendSuccess(future) + + // Verify that send fails immediately without blocking when there is no space left in the buffer + val producer2 = createProducer(brokerList = brokerList, maxBlockMs = 0, + lingerMs = 15000, batchSize = 1100, bufferSize = 1500) + val future2 = sendUntilQueued(producer2) // wait until metadata is available and one record is queued + verifyBufferExhausted(send(producer2)) // should fail send since buffer is full + verifySendSuccess(future2) // previous batch should be completed and sent now + } + + @Test + def testSendRecordBatchWithMaxRequestSizeAndHigher(): Unit = { + val 
producerProps = new Properties() + producerProps.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, brokerList) + val producer = registerProducer(new KafkaProducer(producerProps, new ByteArraySerializer, new ByteArraySerializer)) + + val keyLengthSize = 1 + val headerLengthSize = 1 + val valueLengthSize = 3 + val overhead = Records.LOG_OVERHEAD + DefaultRecordBatch.RECORD_BATCH_OVERHEAD + DefaultRecord.MAX_RECORD_OVERHEAD + + keyLengthSize + headerLengthSize + valueLengthSize + val valueSize = Defaults.MessageMaxBytes - overhead + + val record0 = new ProducerRecord(topic, new Array[Byte](0), new Array[Byte](valueSize)) + assertEquals(record0.value.length, producer.send(record0).get.serializedValueSize) + + val record1 = new ProducerRecord(topic, new Array[Byte](0), new Array[Byte](valueSize + 1)) + assertEquals(classOf[RecordTooLargeException], assertThrows(classOf[ExecutionException], () => producer.send(record1).get).getCause.getClass) + } + +} diff --git a/core/src/test/scala/integration/kafka/api/ProducerCompressionTest.scala b/core/src/test/scala/integration/kafka/api/ProducerCompressionTest.scala new file mode 100755 index 0000000..62b2689 --- /dev/null +++ b/core/src/test/scala/integration/kafka/api/ProducerCompressionTest.scala @@ -0,0 +1,117 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package kafka.api.test + +import kafka.server.{KafkaConfig, KafkaServer} +import kafka.utils.TestUtils +import kafka.server.QuorumTestHarness +import org.apache.kafka.clients.producer.{KafkaProducer, ProducerConfig, ProducerRecord} +import org.apache.kafka.common.TopicPartition +import org.apache.kafka.common.serialization.ByteArraySerializer +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.{AfterEach, BeforeEach, TestInfo} +import org.junit.jupiter.params.ParameterizedTest +import org.junit.jupiter.params.provider.{Arguments, MethodSource} + +import java.util.{Collections, Properties} +import scala.jdk.CollectionConverters._ + +class ProducerCompressionTest extends QuorumTestHarness { + + private val brokerId = 0 + private val topic = "topic" + private val numRecords = 2000 + + private var server: KafkaServer = null + + @BeforeEach + override def setUp(testInfo: TestInfo): Unit = { + super.setUp(testInfo) + val props = TestUtils.createBrokerConfig(brokerId, zkConnect) + server = TestUtils.createServer(KafkaConfig.fromProps(props)) + } + + @AfterEach + override def tearDown(): Unit = { + TestUtils.shutdownServers(Seq(server)) + super.tearDown() + } + + /** + * testCompression + * + * Compressed messages should be able to sent and consumed correctly + */ + @ParameterizedTest + @MethodSource(Array("parameters")) + def testCompression(compression: String): Unit = { + + val producerProps = new Properties() + val bootstrapServers = TestUtils.getBrokerListStrFromServers(Seq(server)) + producerProps.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, bootstrapServers) + producerProps.put(ProducerConfig.COMPRESSION_TYPE_CONFIG, compression) + producerProps.put(ProducerConfig.BATCH_SIZE_CONFIG, "66000") + producerProps.put(ProducerConfig.LINGER_MS_CONFIG, "200") + val producer = new KafkaProducer(producerProps, new ByteArraySerializer, new ByteArraySerializer) + val consumer = TestUtils.createConsumer(bootstrapServers) + + try { + // create topic + TestUtils.createTopic(zkClient, topic, 1, 1, List(server)) + val partition = 0 + + // prepare the messages + val messageValues = (0 until numRecords).map(i => "value" + i) + + // make sure the returned messages are correct + val now = System.currentTimeMillis() + val responses = for (message <- messageValues) + yield producer.send(new ProducerRecord(topic, null, now, null, message.getBytes)) + for ((future, offset) <- responses.zipWithIndex) { + assertEquals(offset.toLong, future.get.offset) + } + + val tp = new TopicPartition(topic, partition) + // make sure the fetched message count match + consumer.assign(Collections.singleton(tp)) + consumer.seek(tp, 0) + val records = TestUtils.consumeRecords(consumer, numRecords) + + for (((messageValue, record), index) <- messageValues.zip(records).zipWithIndex) { + assertEquals(messageValue, new String(record.value)) + assertEquals(now, record.timestamp) + assertEquals(index.toLong, record.offset) + } + } finally { + producer.close() + consumer.close() + } + } +} + +object ProducerCompressionTest { + def parameters: java.util.stream.Stream[Arguments] = { + Seq( + Arguments.of("none"), + Arguments.of("gzip"), + Arguments.of("snappy"), + Arguments.of("lz4"), + Arguments.of("zstd") + ).asJava.stream() + } +} diff --git a/core/src/test/scala/integration/kafka/api/ProducerFailureHandlingTest.scala b/core/src/test/scala/integration/kafka/api/ProducerFailureHandlingTest.scala new file mode 100644 index 0000000..4d45da9 --- /dev/null +++ 
b/core/src/test/scala/integration/kafka/api/ProducerFailureHandlingTest.scala @@ -0,0 +1,260 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.api + +import java.util.concurrent.ExecutionException +import java.util.Properties + +import kafka.integration.KafkaServerTestHarness +import kafka.log.LogConfig +import kafka.server.KafkaConfig +import kafka.utils.TestUtils +import org.apache.kafka.clients.producer._ +import org.apache.kafka.common.errors._ +import org.apache.kafka.common.internals.Topic +import org.apache.kafka.common.record.{DefaultRecord, DefaultRecordBatch} +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.{AfterEach, BeforeEach, Test, TestInfo} + +class ProducerFailureHandlingTest extends KafkaServerTestHarness { + private val producerBufferSize = 30000 + private val serverMessageMaxBytes = producerBufferSize/2 + private val replicaFetchMaxPartitionBytes = serverMessageMaxBytes + 200 + private val replicaFetchMaxResponseBytes = replicaFetchMaxPartitionBytes + 200 + + val numServers = 2 + + val overridingProps = new Properties() + overridingProps.put(KafkaConfig.AutoCreateTopicsEnableProp, false.toString) + overridingProps.put(KafkaConfig.MessageMaxBytesProp, serverMessageMaxBytes.toString) + overridingProps.put(KafkaConfig.ReplicaFetchMaxBytesProp, replicaFetchMaxPartitionBytes.toString) + overridingProps.put(KafkaConfig.ReplicaFetchResponseMaxBytesDoc, replicaFetchMaxResponseBytes.toString) + // Set a smaller value for the number of partitions for the offset commit topic (__consumer_offset topic) + // so that the creation of that topic/partition(s) and subsequent leader assignment doesn't take relatively long + overridingProps.put(KafkaConfig.OffsetsTopicPartitionsProp, 1.toString) + + def generateConfigs = + TestUtils.createBrokerConfigs(numServers, zkConnect, false).map(KafkaConfig.fromProps(_, overridingProps)) + + private var producer1: KafkaProducer[Array[Byte], Array[Byte]] = null + private var producer2: KafkaProducer[Array[Byte], Array[Byte]] = null + private var producer3: KafkaProducer[Array[Byte], Array[Byte]] = null + private var producer4: KafkaProducer[Array[Byte], Array[Byte]] = null + + private val topic1 = "topic-1" + private val topic2 = "topic-2" + + @BeforeEach + override def setUp(testInfo: TestInfo): Unit = { + super.setUp(testInfo) + + producer1 = TestUtils.createProducer(brokerList, acks = 0, retries = 0, requestTimeoutMs = 30000, maxBlockMs = 10000L, + bufferSize = producerBufferSize) + producer2 = TestUtils.createProducer(brokerList, acks = 1, retries = 0, requestTimeoutMs = 30000, maxBlockMs = 10000L, + bufferSize = producerBufferSize) + producer3 = TestUtils.createProducer(brokerList, acks = -1, retries = 0, requestTimeoutMs = 30000, maxBlockMs = 10000L, + 
bufferSize = producerBufferSize) + } + + @AfterEach + override def tearDown(): Unit = { + if (producer1 != null) producer1.close() + if (producer2 != null) producer2.close() + if (producer3 != null) producer3.close() + if (producer4 != null) producer4.close() + + super.tearDown() + } + + /** + * With ack == 0 the future metadata will have no exceptions with offset -1 + */ + @Test + def testTooLargeRecordWithAckZero(): Unit = { + // create topic + createTopic(topic1, replicationFactor = numServers) + + // send a too-large record + val record = new ProducerRecord(topic1, null, "key".getBytes, new Array[Byte](serverMessageMaxBytes + 1)) + + val recordMetadata = producer1.send(record).get() + assertNotNull(recordMetadata) + assertFalse(recordMetadata.hasOffset) + assertEquals(-1L, recordMetadata.offset) + } + + /** + * With ack == 1 the future metadata will throw ExecutionException caused by RecordTooLargeException + */ + @Test + def testTooLargeRecordWithAckOne(): Unit = { + // create topic + createTopic(topic1, replicationFactor = numServers) + + // send a too-large record + val record = new ProducerRecord(topic1, null, "key".getBytes, new Array[Byte](serverMessageMaxBytes + 1)) + assertThrows(classOf[ExecutionException], () => producer2.send(record).get) + } + + private def checkTooLargeRecordForReplicationWithAckAll(maxFetchSize: Int): Unit = { + val maxMessageSize = maxFetchSize + 100 + val topicConfig = new Properties + topicConfig.setProperty(LogConfig.MinInSyncReplicasProp, numServers.toString) + topicConfig.setProperty(LogConfig.MaxMessageBytesProp, maxMessageSize.toString) + + // create topic + val topic10 = "topic10" + createTopic(topic10, numPartitions = servers.size, replicationFactor = numServers, topicConfig) + + // send a record that is too large for replication, but within the broker max message limit + val value = new Array[Byte](maxMessageSize - DefaultRecordBatch.RECORD_BATCH_OVERHEAD - DefaultRecord.MAX_RECORD_OVERHEAD) + val record = new ProducerRecord[Array[Byte], Array[Byte]](topic10, null, value) + val recordMetadata = producer3.send(record).get + + assertEquals(topic10, recordMetadata.topic) + } + + /** This should succeed as the replica fetcher thread can handle oversized messages since KIP-74 */ + @Test + def testPartitionTooLargeForReplicationWithAckAll(): Unit = { + checkTooLargeRecordForReplicationWithAckAll(replicaFetchMaxPartitionBytes) + } + + /** This should succeed as the replica fetcher thread can handle oversized messages since KIP-74 */ + @Test + def testResponseTooLargeForReplicationWithAckAll(): Unit = { + checkTooLargeRecordForReplicationWithAckAll(replicaFetchMaxResponseBytes) + } + + /** + * With non-exist-topic the future metadata should return ExecutionException caused by TimeoutException + */ + @Test + def testNonExistentTopic(): Unit = { + // send a record with non-exist topic + val record = new ProducerRecord(topic2, null, "key".getBytes, "value".getBytes) + assertThrows(classOf[ExecutionException], () => producer1.send(record).get) + } + + /** + * With incorrect broker-list the future metadata should return ExecutionException caused by TimeoutException + * + * TODO: other exceptions that can be thrown in ExecutionException: + * UnknownTopicOrPartitionException + * NotLeaderOrFollowerException + * LeaderNotAvailableException + * CorruptRecordException + * TimeoutException + */ + @Test + def testWrongBrokerList(): Unit = { + // create topic + createTopic(topic1, replicationFactor = numServers) + + // producer with incorrect broker list + 
producer4 = TestUtils.createProducer("localhost:8686,localhost:4242", acks = 1, maxBlockMs = 10000L, bufferSize = producerBufferSize) + + // send a record with incorrect broker list + val record = new ProducerRecord(topic1, null, "key".getBytes, "value".getBytes) + assertThrows(classOf[ExecutionException], () => producer4.send(record).get) + } + + /** + * Send with invalid partition id should return ExecutionException caused by TimeoutException + * when partition is higher than the upper bound of partitions. + */ + @Test + def testInvalidPartition(): Unit = { + // create topic with a single partition + createTopic(topic1, numPartitions = 1, replicationFactor = numServers) + + // create a record with incorrect partition id (higher than the number of partitions), send should fail + val higherRecord = new ProducerRecord(topic1, 1, "key".getBytes, "value".getBytes) + val e = assertThrows(classOf[ExecutionException], () => producer1.send(higherRecord).get) + assertEquals(classOf[TimeoutException], e.getCause.getClass) + } + + /** + * The send call after producer closed should throw IllegalStateException + */ + @Test + def testSendAfterClosed(): Unit = { + // create topic + createTopic(topic1, replicationFactor = numServers) + + val record = new ProducerRecord[Array[Byte], Array[Byte]](topic1, null, "key".getBytes, "value".getBytes) + + // first send a message to make sure the metadata is refreshed + producer1.send(record).get + producer2.send(record).get + producer3.send(record).get + + producer1.close() + assertThrows(classOf[IllegalStateException], () => producer1.send(record)) + producer2.close() + assertThrows(classOf[IllegalStateException], () => producer2.send(record)) + producer3.close() + assertThrows(classOf[IllegalStateException], () => producer3.send(record)) + } + + @Test + def testCannotSendToInternalTopic(): Unit = { + TestUtils.createOffsetsTopic(zkClient, servers) + val thrown = assertThrows(classOf[ExecutionException], + () => producer2.send(new ProducerRecord(Topic.GROUP_METADATA_TOPIC_NAME, "test".getBytes, "test".getBytes)).get) + assertTrue(thrown.getCause.isInstanceOf[InvalidTopicException], "Unexpected exception while sending to an invalid topic " + thrown.getCause) + } + + @Test + def testNotEnoughReplicas(): Unit = { + val topicName = "minisrtest" + val topicProps = new Properties() + topicProps.put("min.insync.replicas",(numServers+1).toString) + + createTopic(topicName, replicationFactor = numServers, topicConfig = topicProps) + + val record = new ProducerRecord(topicName, null, "key".getBytes, "value".getBytes) + val e = assertThrows(classOf[ExecutionException], () => producer3.send(record).get) + assertEquals(classOf[NotEnoughReplicasException], e.getCause.getClass) + } + + @Test + def testNotEnoughReplicasAfterBrokerShutdown(): Unit = { + val topicName = "minisrtest2" + val topicProps = new Properties() + topicProps.put("min.insync.replicas", numServers.toString) + + createTopic(topicName, replicationFactor = numServers, topicConfig = topicProps) + + val record = new ProducerRecord(topicName, null, "key".getBytes, "value".getBytes) + // this should work with all brokers up and running + producer3.send(record).get + + // shut down one broker + servers.head.shutdown() + servers.head.awaitShutdown() + val e = assertThrows(classOf[ExecutionException], () => producer3.send(record).get) + assertTrue(e.getCause.isInstanceOf[NotEnoughReplicasException] || + e.getCause.isInstanceOf[NotEnoughReplicasAfterAppendException] || + e.getCause.isInstanceOf[TimeoutException]) + + // 
restart the server + servers.head.startup() + } + +} diff --git a/core/src/test/scala/integration/kafka/api/ProducerSendWhileDeletionTest.scala b/core/src/test/scala/integration/kafka/api/ProducerSendWhileDeletionTest.scala new file mode 100644 index 0000000..ec05bb2 --- /dev/null +++ b/core/src/test/scala/integration/kafka/api/ProducerSendWhileDeletionTest.scala @@ -0,0 +1,83 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package kafka.api + +import kafka.server.KafkaConfig +import kafka.utils.TestUtils +import org.apache.kafka.clients.producer.{ProducerConfig, ProducerRecord} + +import org.apache.kafka.common.TopicPartition +import org.junit.jupiter.api.Assertions.assertEquals +import org.junit.jupiter.api.Test + +import java.nio.charset.StandardCharsets + + +class ProducerSendWhileDeletionTest extends IntegrationTestHarness { + val producerCount: Int = 1 + val brokerCount: Int = 2 + + serverConfig.put(KafkaConfig.NumPartitionsProp, 2.toString) + serverConfig.put(KafkaConfig.DefaultReplicationFactorProp, 2.toString) + serverConfig.put(KafkaConfig.AutoLeaderRebalanceEnableProp, false.toString) + + producerConfig.put(ProducerConfig.MAX_BLOCK_MS_CONFIG, 5000L.toString) + producerConfig.put(ProducerConfig.REQUEST_TIMEOUT_MS_CONFIG, 10000.toString) + producerConfig.put(ProducerConfig.DELIVERY_TIMEOUT_MS_CONFIG, 10000.toString) + + /** + * Tests that Producer gets self-recovered when a topic is deleted mid-way of produce. + * + * Producer will attempt to send messages to the partition specified in each record, and should + * succeed as long as the partition is included in the metadata. + */ + @Test + def testSendWithTopicDeletionMidWay(): Unit = { + val numRecords = 10 + val topic = "topic" + + // Create topic with leader as 0 for the 2 partitions. 
+ createTopic(topic, Map(0 -> Seq(0, 1), 1 -> Seq(0, 1))) + + val reassignment = Map( + new TopicPartition(topic, 0) -> Seq(1, 0), + new TopicPartition(topic, 1) -> Seq(1, 0) + ) + + // Change leader to 1 for both the partitions to increase leader epoch from 0 -> 1 + zkClient.createPartitionReassignment(reassignment) + TestUtils.waitUntilTrue(() => !zkClient.reassignPartitionsInProgress, + "failed to remove reassign partitions path after completion") + + val producer = createProducer() + + (1 to numRecords).foreach { i => + val resp = producer.send(new ProducerRecord(topic, null, ("value" + i).getBytes(StandardCharsets.UTF_8))).get + assertEquals(topic, resp.topic()) + } + + // Start topic deletion + adminZkClient.deleteTopic(topic) + + // Verify that the topic is deleted when no metadata request comes in + TestUtils.verifyTopicDeletion(zkClient, topic, 2, servers) + + // Producer should be able to send messages even after topic gets deleted and auto-created + assertEquals(topic, producer.send(new ProducerRecord(topic, null, "value".getBytes(StandardCharsets.UTF_8))).get.topic()) + } + +} diff --git a/core/src/test/scala/integration/kafka/api/RackAwareAutoTopicCreationTest.scala b/core/src/test/scala/integration/kafka/api/RackAwareAutoTopicCreationTest.scala new file mode 100644 index 0000000..745e86d --- /dev/null +++ b/core/src/test/scala/integration/kafka/api/RackAwareAutoTopicCreationTest.scala @@ -0,0 +1,65 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package kafka.api + +import java.util.Properties + +import kafka.admin.{RackAwareMode, RackAwareTest} +import kafka.integration.KafkaServerTestHarness +import kafka.server.KafkaConfig +import kafka.utils.TestUtils +import org.apache.kafka.clients.producer.ProducerRecord +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.Test +import scala.collection.Map + +class RackAwareAutoTopicCreationTest extends KafkaServerTestHarness with RackAwareTest { + val numServers = 4 + val numPartitions = 8 + val replicationFactor = 2 + val overridingProps = new Properties() + overridingProps.put(KafkaConfig.NumPartitionsProp, numPartitions.toString) + overridingProps.put(KafkaConfig.DefaultReplicationFactorProp, replicationFactor.toString) + + def generateConfigs = + (0 until numServers) map { node => + TestUtils.createBrokerConfig(node, zkConnect, enableControlledShutdown = false, rack = Some((node / 2).toString)) + } map (KafkaConfig.fromProps(_, overridingProps)) + + private val topic = "topic" + + @Test + def testAutoCreateTopic(): Unit = { + val producer = TestUtils.createProducer(brokerList) + try { + // Send a message to auto-create the topic + val record = new ProducerRecord(topic, null, "key".getBytes, "value".getBytes) + assertEquals(0L, producer.send(record).get.offset, "Should have offset 0") + + // double check that the topic is created with leader elected + TestUtils.waitUntilLeaderIsElectedOrChanged(zkClient, topic, 0) + val assignment = zkClient.getReplicaAssignmentForTopics(Set(topic)).map { case (topicPartition, replicas) => + topicPartition.partition -> replicas + } + val brokerMetadatas = adminZkClient.getBrokerMetadatas(RackAwareMode.Enforced) + val expectedMap = Map(0 -> "0", 1 -> "0", 2 -> "1", 3 -> "1") + assertEquals(expectedMap, brokerMetadatas.map(b => b.id -> b.rack.get).toMap) + checkReplicaDistribution(assignment, expectedMap, numServers, numPartitions, replicationFactor) + } finally producer.close() + } +} + diff --git a/core/src/test/scala/integration/kafka/api/SaslClientsWithInvalidCredentialsTest.scala b/core/src/test/scala/integration/kafka/api/SaslClientsWithInvalidCredentialsTest.scala new file mode 100644 index 0000000..a9f2c6c --- /dev/null +++ b/core/src/test/scala/integration/kafka/api/SaslClientsWithInvalidCredentialsTest.scala @@ -0,0 +1,245 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more contributor license agreements. See the NOTICE + * file distributed with this work for additional information regarding copyright ownership. The ASF licenses this file + * to You under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the + * License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. 
+ */ +package kafka.api + +import java.nio.file.Files +import java.time.Duration +import java.util.Collections +import java.util.concurrent.{ExecutionException, TimeUnit} + +import scala.jdk.CollectionConverters._ +import org.apache.kafka.clients.admin.{Admin, AdminClientConfig} +import org.apache.kafka.clients.consumer.{ConsumerConfig, KafkaConsumer} +import org.apache.kafka.clients.producer.{KafkaProducer, ProducerConfig, ProducerRecord} +import org.apache.kafka.common.{KafkaException, TopicPartition} +import org.apache.kafka.common.errors.SaslAuthenticationException +import org.junit.jupiter.api.{AfterEach, BeforeEach, Test, TestInfo} +import org.junit.jupiter.api.Assertions._ +import kafka.admin.ConsumerGroupCommand.{ConsumerGroupCommandOptions, ConsumerGroupService} +import kafka.server.KafkaConfig +import kafka.utils.{JaasTestUtils, TestUtils} +import kafka.zk.ConfigEntityChangeNotificationZNode +import org.apache.kafka.common.security.auth.SecurityProtocol + +class SaslClientsWithInvalidCredentialsTest extends IntegrationTestHarness with SaslSetup { + private val kafkaClientSaslMechanism = "SCRAM-SHA-256" + private val kafkaServerSaslMechanisms = List(kafkaClientSaslMechanism) + override protected val securityProtocol = SecurityProtocol.SASL_PLAINTEXT + override protected val serverSaslProperties = Some(kafkaServerSaslProperties(kafkaServerSaslMechanisms, kafkaClientSaslMechanism)) + override protected val clientSaslProperties = Some(kafkaClientSaslProperties(kafkaClientSaslMechanism)) + val consumerCount = 1 + val producerCount = 1 + val brokerCount = 1 + + this.serverConfig.setProperty(KafkaConfig.OffsetsTopicReplicationFactorProp, "1") + this.serverConfig.setProperty(KafkaConfig.TransactionsTopicReplicationFactorProp, "1") + this.serverConfig.setProperty(KafkaConfig.TransactionsTopicMinISRProp, "1") + this.consumerConfig.setProperty(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "earliest") + + val topic = "topic" + val numPartitions = 1 + val tp = new TopicPartition(topic, 0) + + override def configureSecurityBeforeServersStart(): Unit = { + super.configureSecurityBeforeServersStart() + zkClient.makeSurePersistentPathExists(ConfigEntityChangeNotificationZNode.path) + // Create broker credentials before starting brokers + createScramCredentials(zkConnect, JaasTestUtils.KafkaScramAdmin, JaasTestUtils.KafkaScramAdminPassword) + } + + override def createPrivilegedAdminClient() = { + createAdminClient(brokerList, securityProtocol, trustStoreFile, clientSaslProperties, + kafkaClientSaslMechanism, JaasTestUtils.KafkaScramAdmin, JaasTestUtils.KafkaScramAdminPassword) + } + + @BeforeEach + override def setUp(testInfo: TestInfo): Unit = { + startSasl(jaasSections(kafkaServerSaslMechanisms, Some(kafkaClientSaslMechanism), Both, + JaasTestUtils.KafkaServerContextName)) + super.setUp(testInfo) + createTopic(topic, numPartitions, brokerCount) + } + + @AfterEach + override def tearDown(): Unit = { + super.tearDown() + closeSasl() + } + + @Test + def testProducerWithAuthenticationFailure(): Unit = { + val producer = createProducer() + verifyAuthenticationException(sendOneRecord(producer, maxWaitMs = 10000)) + verifyAuthenticationException(producer.partitionsFor(topic)) + + createClientCredential() + verifyWithRetry(sendOneRecord(producer)) + } + + @Test + def testTransactionalProducerWithAuthenticationFailure(): Unit = { + val txProducer = createTransactionalProducer() + verifyAuthenticationException(txProducer.initTransactions()) + + createClientCredential() + assertThrows(classOf[KafkaException], 
() => txProducer.initTransactions()) + } + + @Test + def testConsumerWithAuthenticationFailure(): Unit = { + val consumer = createConsumer() + consumer.subscribe(List(topic).asJava) + verifyConsumerWithAuthenticationFailure(consumer) + } + + @Test + def testManualAssignmentConsumerWithAuthenticationFailure(): Unit = { + val consumer = createConsumer() + consumer.assign(List(tp).asJava) + verifyConsumerWithAuthenticationFailure(consumer) + } + + @Test + def testManualAssignmentConsumerWithAutoCommitDisabledWithAuthenticationFailure(): Unit = { + this.consumerConfig.setProperty(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG, false.toString) + val consumer = createConsumer() + consumer.assign(List(tp).asJava) + consumer.seek(tp, 0) + verifyConsumerWithAuthenticationFailure(consumer) + } + + private def verifyConsumerWithAuthenticationFailure(consumer: KafkaConsumer[Array[Byte], Array[Byte]]): Unit = { + verifyAuthenticationException(consumer.poll(Duration.ofMillis(1000))) + verifyAuthenticationException(consumer.partitionsFor(topic)) + + createClientCredential() + val producer = createProducer() + verifyWithRetry(sendOneRecord(producer)) + verifyWithRetry(assertEquals(1, consumer.poll(Duration.ofMillis(1000)).count)) + } + + @Test + def testKafkaAdminClientWithAuthenticationFailure(): Unit = { + val props = TestUtils.adminClientSecurityConfigs(securityProtocol, trustStoreFile, clientSaslProperties) + props.put(AdminClientConfig.BOOTSTRAP_SERVERS_CONFIG, brokerList) + val adminClient = Admin.create(props) + + def describeTopic(): Unit = { + try { + val response = adminClient.describeTopics(Collections.singleton(topic)).allTopicNames.get + assertEquals(1, response.size) + response.forEach { (topic, description) => + assertEquals(numPartitions, description.partitions.size) + } + } catch { + case e: ExecutionException => throw e.getCause + } + } + + try { + verifyAuthenticationException(describeTopic()) + + createClientCredential() + verifyWithRetry(describeTopic()) + } finally { + adminClient.close() + } + } + + @Test + def testConsumerGroupServiceWithAuthenticationFailure(): Unit = { + val consumerGroupService: ConsumerGroupService = prepareConsumerGroupService + + val consumer = createConsumer() + try { + consumer.subscribe(List(topic).asJava) + + verifyAuthenticationException(consumerGroupService.listGroups()) + } finally consumerGroupService.close() + } + + @Test + def testConsumerGroupServiceWithAuthenticationSuccess(): Unit = { + createClientCredential() + val consumerGroupService: ConsumerGroupService = prepareConsumerGroupService + + val consumer = createConsumer() + try { + consumer.subscribe(List(topic).asJava) + + verifyWithRetry(consumer.poll(Duration.ofMillis(1000))) + assertEquals(1, consumerGroupService.listConsumerGroups().size) + } + finally consumerGroupService.close() + } + + private def prepareConsumerGroupService = { + val propsFile = TestUtils.tempFile() + val propsStream = Files.newOutputStream(propsFile.toPath) + try { + propsStream.write("security.protocol=SASL_PLAINTEXT\n".getBytes()) + propsStream.write(s"sasl.mechanism=$kafkaClientSaslMechanism".getBytes()) + } + finally propsStream.close() + + val cgcArgs = Array("--bootstrap-server", brokerList, + "--describe", + "--group", "test.group", + "--command-config", propsFile.getAbsolutePath) + val opts = new ConsumerGroupCommandOptions(cgcArgs) + val consumerGroupService = new ConsumerGroupService(opts) + consumerGroupService + } + + private def createClientCredential(): Unit = { + 
createScramCredentialsViaPrivilegedAdminClient(JaasTestUtils.KafkaScramUser2, JaasTestUtils.KafkaScramPassword2) + } + + private def sendOneRecord(producer: KafkaProducer[Array[Byte], Array[Byte]], maxWaitMs: Long = 15000): Unit = { + val record = new ProducerRecord(tp.topic(), tp.partition(), 0L, "key".getBytes, "value".getBytes) + val future = producer.send(record) + producer.flush() + try { + val recordMetadata = future.get(maxWaitMs, TimeUnit.MILLISECONDS) + assertTrue(recordMetadata.offset >= 0, s"Invalid offset $recordMetadata") + } catch { + case e: ExecutionException => throw e.getCause + } + } + + private def verifyAuthenticationException(action: => Unit): Unit = { + val startMs = System.currentTimeMillis + assertThrows(classOf[Exception], () => action) + val elapsedMs = System.currentTimeMillis - startMs + assertTrue(elapsedMs <= 5000, s"Poll took too long, elapsed=$elapsedMs") + } + + private def verifyWithRetry(action: => Unit): Unit = { + var attempts = 0 + TestUtils.waitUntilTrue(() => { + try { + attempts += 1 + action + true + } catch { + case _: SaslAuthenticationException => false + } + }, s"Operation did not succeed within timeout after $attempts") + } + + private def createTransactionalProducer(): KafkaProducer[Array[Byte], Array[Byte]] = { + producerConfig.setProperty(ProducerConfig.TRANSACTIONAL_ID_CONFIG, "txclient-1") + producerConfig.put(ProducerConfig.ENABLE_IDEMPOTENCE_CONFIG, "true") + createProducer() + } +} diff --git a/core/src/test/scala/integration/kafka/api/SaslEndToEndAuthorizationTest.scala b/core/src/test/scala/integration/kafka/api/SaslEndToEndAuthorizationTest.scala new file mode 100644 index 0000000..e406450 --- /dev/null +++ b/core/src/test/scala/integration/kafka/api/SaslEndToEndAuthorizationTest.scala @@ -0,0 +1,81 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package kafka.api + +import org.apache.kafka.common.config.SaslConfigs +import org.apache.kafka.common.security.auth.SecurityProtocol +import org.apache.kafka.common.errors.{GroupAuthorizationException, TopicAuthorizationException} +import org.junit.jupiter.api.{BeforeEach, Test, TestInfo, Timeout} +import org.junit.jupiter.api.Assertions.{assertEquals, assertTrue, fail} + +import scala.collection.immutable.List +import scala.jdk.CollectionConverters._ + +abstract class SaslEndToEndAuthorizationTest extends EndToEndAuthorizationTest { + override protected def securityProtocol = SecurityProtocol.SASL_SSL + override protected val serverSaslProperties = Some(kafkaServerSaslProperties(kafkaServerSaslMechanisms, kafkaClientSaslMechanism)) + override protected val clientSaslProperties = Some(kafkaClientSaslProperties(kafkaClientSaslMechanism)) + + protected def kafkaClientSaslMechanism: String + protected def kafkaServerSaslMechanisms: List[String] + + @BeforeEach + override def setUp(testInfo: TestInfo): Unit = { + // create static config including client login context with credentials for JaasTestUtils 'client2' + startSasl(jaasSections(kafkaServerSaslMechanisms, Option(kafkaClientSaslMechanism), Both)) + // set dynamic properties with credentials for JaasTestUtils 'client1' so that dynamic JAAS configuration is also + // tested by this set of tests + val clientLoginContext = jaasClientLoginModule(kafkaClientSaslMechanism) + producerConfig.put(SaslConfigs.SASL_JAAS_CONFIG, clientLoginContext) + consumerConfig.put(SaslConfigs.SASL_JAAS_CONFIG, clientLoginContext) + adminClientConfig.put(SaslConfigs.SASL_JAAS_CONFIG, clientLoginContext) + super.setUp(testInfo) + } + + /** + * Test with two consumers, each with different valid SASL credentials. + * The first consumer succeeds because it is allowed by the ACL, + * the second one connects ok, but fails to consume messages due to the ACL. + */ + @Timeout(15) + @Test + def testTwoConsumersWithDifferentSaslCredentials(): Unit = { + setAclsAndProduce(tp) + val consumer1 = createConsumer() + + // consumer2 retrieves its credentials from the static JAAS configuration, so we test also this path + consumerConfig.remove(SaslConfigs.SASL_JAAS_CONFIG) + consumerConfig.remove(SaslConfigs.SASL_CLIENT_CALLBACK_HANDLER_CLASS) + + val consumer2 = createConsumer() + consumer1.assign(List(tp).asJava) + consumer2.assign(List(tp).asJava) + + consumeRecords(consumer1, numRecords) + + try { + consumeRecords(consumer2) + fail("Expected exception as consumer2 has no access to topic or group") + } catch { + // Either exception is possible depending on the order that the first Metadata + // and FindCoordinator requests are received + case e: TopicAuthorizationException => assertTrue(e.unauthorizedTopics.contains(topic)) + case e: GroupAuthorizationException => assertEquals(group, e.groupId) + } + confirmReauthenticationMetrics() + } +} diff --git a/core/src/test/scala/integration/kafka/api/SaslGssapiSslEndToEndAuthorizationTest.scala b/core/src/test/scala/integration/kafka/api/SaslGssapiSslEndToEndAuthorizationTest.scala new file mode 100644 index 0000000..17e39f6 --- /dev/null +++ b/core/src/test/scala/integration/kafka/api/SaslGssapiSslEndToEndAuthorizationTest.scala @@ -0,0 +1,46 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package kafka.api + +import kafka.security.authorizer.AclAuthorizer +import kafka.server.KafkaConfig +import kafka.utils.JaasTestUtils +import org.apache.kafka.common.config.SslConfigs +import org.apache.kafka.common.security.auth.KafkaPrincipal +import org.junit.jupiter.api.Assertions.assertNull + +import scala.collection.immutable.List + +class SaslGssapiSslEndToEndAuthorizationTest extends SaslEndToEndAuthorizationTest { + override val clientPrincipal = new KafkaPrincipal(KafkaPrincipal.USER_TYPE, + JaasTestUtils.KafkaClientPrincipalUnqualifiedName) + override val kafkaPrincipal = new KafkaPrincipal(KafkaPrincipal.USER_TYPE, + JaasTestUtils.KafkaServerPrincipalUnqualifiedName) + + override protected def kafkaClientSaslMechanism = "GSSAPI" + override protected def kafkaServerSaslMechanisms = List("GSSAPI") + override protected def authorizerClass = classOf[AclAuthorizer] + + // Configure brokers to require SSL client authentication in order to verify that SASL_SSL works correctly even if the + // client doesn't have a keystore. We want to cover the scenario where a broker requires either SSL client + // authentication or SASL authentication with SSL as the transport layer (but not both). + serverConfig.put(KafkaConfig.SslClientAuthProp, "required") + assertNull(producerConfig.get(SslConfigs.SSL_KEYSTORE_LOCATION_CONFIG)) + assertNull(consumerConfig.get(SslConfigs.SSL_KEYSTORE_LOCATION_CONFIG)) + assertNull(adminClientConfig.get(SslConfigs.SSL_KEYSTORE_LOCATION_CONFIG)) + +} diff --git a/core/src/test/scala/integration/kafka/api/SaslMultiMechanismConsumerTest.scala b/core/src/test/scala/integration/kafka/api/SaslMultiMechanismConsumerTest.scala new file mode 100644 index 0000000..37b36e1 --- /dev/null +++ b/core/src/test/scala/integration/kafka/api/SaslMultiMechanismConsumerTest.scala @@ -0,0 +1,95 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more contributor license agreements. See the NOTICE + * file distributed with this work for additional information regarding copyright ownership. The ASF licenses this file + * to You under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the + * License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. 
+ */ +package kafka.api + +import java.io.File + +import kafka.server.KafkaConfig +import org.junit.jupiter.api.{AfterEach, BeforeEach, Test, TestInfo} +import kafka.utils.JaasTestUtils +import org.apache.kafka.common.security.auth.SecurityProtocol + +import scala.jdk.CollectionConverters._ + +class SaslMultiMechanismConsumerTest extends BaseConsumerTest with SaslSetup { + private val kafkaClientSaslMechanism = "PLAIN" + private val kafkaServerSaslMechanisms = List("GSSAPI", "PLAIN") + this.serverConfig.setProperty(KafkaConfig.ZkEnableSecureAclsProp, "true") + override protected def securityProtocol = SecurityProtocol.SASL_SSL + override protected lazy val trustStoreFile = Some(File.createTempFile("truststore", ".jks")) + override protected val serverSaslProperties = Some(kafkaServerSaslProperties(kafkaServerSaslMechanisms, kafkaClientSaslMechanism)) + override protected val clientSaslProperties = Some(kafkaClientSaslProperties(kafkaClientSaslMechanism)) + + @BeforeEach + override def setUp(testInfo: TestInfo): Unit = { + startSasl(jaasSections(kafkaServerSaslMechanisms, Some(kafkaClientSaslMechanism), Both, + JaasTestUtils.KafkaServerContextName)) + super.setUp(testInfo) + } + + @AfterEach + override def tearDown(): Unit = { + super.tearDown() + closeSasl() + } + + @Test + def testMultipleBrokerMechanisms(): Unit = { + val plainSaslProducer = createProducer() + val plainSaslConsumer = createConsumer() + + val gssapiSaslProperties = kafkaClientSaslProperties("GSSAPI", dynamicJaasConfig = true) + val gssapiSaslProducer = createProducer(configOverrides = gssapiSaslProperties) + val gssapiSaslConsumer = createConsumer(configOverrides = gssapiSaslProperties) + val numRecords = 1000 + var startingOffset = 0 + + // Test SASL/PLAIN producer and consumer + var startingTimestamp = System.currentTimeMillis() + sendRecords(plainSaslProducer, numRecords, tp, startingTimestamp = startingTimestamp) + plainSaslConsumer.assign(List(tp).asJava) + plainSaslConsumer.seek(tp, 0) + consumeAndVerifyRecords(consumer = plainSaslConsumer, numRecords = numRecords, startingOffset = startingOffset, + startingTimestamp = startingTimestamp) + sendAndAwaitAsyncCommit(plainSaslConsumer) + startingOffset += numRecords + + // Test SASL/GSSAPI producer and consumer + startingTimestamp = System.currentTimeMillis() + sendRecords(gssapiSaslProducer, numRecords, tp, startingTimestamp = startingTimestamp) + gssapiSaslConsumer.assign(List(tp).asJava) + gssapiSaslConsumer.seek(tp, startingOffset) + consumeAndVerifyRecords(consumer = gssapiSaslConsumer, numRecords = numRecords, startingOffset = startingOffset, + startingTimestamp = startingTimestamp) + sendAndAwaitAsyncCommit(gssapiSaslConsumer) + startingOffset += numRecords + + // Test SASL/PLAIN producer and SASL/GSSAPI consumer + startingTimestamp = System.currentTimeMillis() + sendRecords(plainSaslProducer, numRecords, tp, startingTimestamp = startingTimestamp) + gssapiSaslConsumer.assign(List(tp).asJava) + gssapiSaslConsumer.seek(tp, startingOffset) + consumeAndVerifyRecords(consumer = gssapiSaslConsumer, numRecords = numRecords, startingOffset = startingOffset, + startingTimestamp = startingTimestamp) + startingOffset += numRecords + + // Test SASL/GSSAPI producer and SASL/PLAIN consumer + startingTimestamp = System.currentTimeMillis() + sendRecords(gssapiSaslProducer, numRecords, tp, startingTimestamp = startingTimestamp) + plainSaslConsumer.assign(List(tp).asJava) + plainSaslConsumer.seek(tp, startingOffset) + consumeAndVerifyRecords(consumer = plainSaslConsumer, 
numRecords = numRecords, startingOffset = startingOffset, + startingTimestamp = startingTimestamp) + } + +} diff --git a/core/src/test/scala/integration/kafka/api/SaslOAuthBearerSslEndToEndAuthorizationTest.scala b/core/src/test/scala/integration/kafka/api/SaslOAuthBearerSslEndToEndAuthorizationTest.scala new file mode 100644 index 0000000..4baa61d --- /dev/null +++ b/core/src/test/scala/integration/kafka/api/SaslOAuthBearerSslEndToEndAuthorizationTest.scala @@ -0,0 +1,27 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package kafka.api + +import kafka.utils.JaasTestUtils +import org.apache.kafka.common.security.auth.KafkaPrincipal + +class SaslOAuthBearerSslEndToEndAuthorizationTest extends SaslEndToEndAuthorizationTest { + override protected def kafkaClientSaslMechanism = "OAUTHBEARER" + override protected def kafkaServerSaslMechanisms = List(kafkaClientSaslMechanism) + override val clientPrincipal = new KafkaPrincipal(KafkaPrincipal.USER_TYPE, JaasTestUtils.KafkaOAuthBearerUser) + override val kafkaPrincipal = new KafkaPrincipal(KafkaPrincipal.USER_TYPE, JaasTestUtils.KafkaOAuthBearerAdmin) +} diff --git a/core/src/test/scala/integration/kafka/api/SaslPlainPlaintextConsumerTest.scala b/core/src/test/scala/integration/kafka/api/SaslPlainPlaintextConsumerTest.scala new file mode 100644 index 0000000..042aa3d --- /dev/null +++ b/core/src/test/scala/integration/kafka/api/SaslPlainPlaintextConsumerTest.scala @@ -0,0 +1,58 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more contributor license agreements. See the NOTICE + * file distributed with this work for additional information regarding copyright ownership. The ASF licenses this file + * to You under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the + * License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. 
+ */ +package kafka.api + +import java.io.File +import java.util.Locale + +import kafka.server.KafkaConfig +import kafka.utils.{JaasTestUtils, TestUtils} +import org.apache.kafka.common.network.ListenerName +import org.apache.kafka.common.security.auth.SecurityProtocol +import org.junit.jupiter.api.{AfterEach, BeforeEach, Test, TestInfo} + +class SaslPlainPlaintextConsumerTest extends BaseConsumerTest with SaslSetup { + override protected def listenerName = new ListenerName("CLIENT") + private val kafkaClientSaslMechanism = "PLAIN" + private val kafkaServerSaslMechanisms = List(kafkaClientSaslMechanism) + private val kafkaServerJaasEntryName = + s"${listenerName.value.toLowerCase(Locale.ROOT)}.${JaasTestUtils.KafkaServerContextName}" + this.serverConfig.setProperty(KafkaConfig.ZkEnableSecureAclsProp, "false") + // disable secure acls of zkClient in QuorumTestHarness + override protected def zkAclsEnabled = Some(false) + override protected def securityProtocol = SecurityProtocol.SASL_PLAINTEXT + override protected lazy val trustStoreFile = Some(File.createTempFile("truststore", ".jks")) + override protected val serverSaslProperties = Some(kafkaServerSaslProperties(kafkaServerSaslMechanisms, kafkaClientSaslMechanism)) + override protected val clientSaslProperties = Some(kafkaClientSaslProperties(kafkaClientSaslMechanism)) + + @BeforeEach + override def setUp(testInfo: TestInfo): Unit = { + startSasl(jaasSections(kafkaServerSaslMechanisms, Some(kafkaClientSaslMechanism), Both, kafkaServerJaasEntryName)) + super.setUp(testInfo) + } + + @AfterEach + override def tearDown(): Unit = { + super.tearDown() + closeSasl() + } + + /** + * Checks that everyone can access ZkData.SecureZkRootPaths and ZkData.SensitiveZkRootPaths + * when zookeeper.set.acl=false, even if ZooKeeper is SASL-enabled. + */ + @Test + def testZkAclsDisabled(): Unit = { + TestUtils.verifyUnsecureZkAcls(zkClient) + } +} diff --git a/core/src/test/scala/integration/kafka/api/SaslPlainSslEndToEndAuthorizationTest.scala b/core/src/test/scala/integration/kafka/api/SaslPlainSslEndToEndAuthorizationTest.scala new file mode 100644 index 0000000..7727803 --- /dev/null +++ b/core/src/test/scala/integration/kafka/api/SaslPlainSslEndToEndAuthorizationTest.scala @@ -0,0 +1,157 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package kafka.api + +import java.security.AccessController +import java.util.Properties + +import javax.security.auth.callback._ +import javax.security.auth.Subject +import javax.security.auth.login.AppConfigurationEntry + +import scala.collection.Seq +import kafka.server.KafkaConfig +import kafka.utils.TestUtils +import kafka.utils.JaasTestUtils._ +import org.apache.kafka.common.config.SaslConfigs +import org.apache.kafka.common.config.internals.BrokerSecurityConfigs +import org.apache.kafka.common.network.Mode +import org.apache.kafka.common.security.auth._ +import org.apache.kafka.common.security.authenticator.DefaultKafkaPrincipalBuilder +import org.apache.kafka.common.security.plain.PlainAuthenticateCallback +import org.junit.jupiter.api.Assertions.assertTrue +import org.junit.jupiter.api.Test + +object SaslPlainSslEndToEndAuthorizationTest { + + class TestPrincipalBuilder extends DefaultKafkaPrincipalBuilder(null, null) { + + override def build(context: AuthenticationContext): KafkaPrincipal = { + val saslContext = context.asInstanceOf[SaslAuthenticationContext] + + // Verify that peer principal can be obtained from the SSLSession provided in the context + // since we have enabled TLS mutual authentication for the listener + val sslPrincipal = saslContext.sslSession.get.getPeerPrincipal.getName + assertTrue(sslPrincipal.endsWith(s"CN=${TestUtils.SslCertificateCn}"), s"Unexpected SSL principal $sslPrincipal") + + saslContext.server.getAuthorizationID match { + case KafkaPlainAdmin => + new KafkaPrincipal(KafkaPrincipal.USER_TYPE, "admin") + case KafkaPlainUser => + new KafkaPrincipal(KafkaPrincipal.USER_TYPE, "user") + case _ => + KafkaPrincipal.ANONYMOUS + } + } + } + + object Credentials { + val allUsers = Map(KafkaPlainUser -> "user1-password", + KafkaPlainUser2 -> KafkaPlainPassword2, + KafkaPlainAdmin -> "broker-password") + } + + class TestServerCallbackHandler extends AuthenticateCallbackHandler { + def configure(configs: java.util.Map[String, _], saslMechanism: String, jaasConfigEntries: java.util.List[AppConfigurationEntry]): Unit = {} + def handle(callbacks: Array[Callback]): Unit = { + var username: String = null + for (callback <- callbacks) { + if (callback.isInstanceOf[NameCallback]) + username = callback.asInstanceOf[NameCallback].getDefaultName + else if (callback.isInstanceOf[PlainAuthenticateCallback]) { + val plainCallback = callback.asInstanceOf[PlainAuthenticateCallback] + plainCallback.authenticated(Credentials.allUsers(username) == new String(plainCallback.password)) + } else + throw new UnsupportedCallbackException(callback) + } + } + def close(): Unit = {} + } + + class TestClientCallbackHandler extends AuthenticateCallbackHandler { + def configure(configs: java.util.Map[String, _], saslMechanism: String, jaasConfigEntries: java.util.List[AppConfigurationEntry]): Unit = {} + def handle(callbacks: Array[Callback]): Unit = { + val subject = Subject.getSubject(AccessController.getContext()) + val username = subject.getPublicCredentials(classOf[String]).iterator().next() + for (callback <- callbacks) { + if (callback.isInstanceOf[NameCallback]) + callback.asInstanceOf[NameCallback].setName(username) + else if (callback.isInstanceOf[PasswordCallback]) { + if (username == KafkaPlainUser || username == KafkaPlainAdmin) + callback.asInstanceOf[PasswordCallback].setPassword(Credentials.allUsers(username).toCharArray) + } else + throw new UnsupportedCallbackException(callback) + } + } + def close(): Unit = {} + } +} + + +// This test uses SASL callback handler 
overrides for server connections of Kafka broker +// and client connections of Kafka producers and consumers. Client connections from Kafka brokers +// used for inter-broker communication also use custom callback handlers. The second client used in +// the multi-user test SaslEndToEndAuthorizationTest#testTwoConsumersWithDifferentSaslCredentials uses +// static JAAS configuration with default callback handlers to test those code paths as well. +class SaslPlainSslEndToEndAuthorizationTest extends SaslEndToEndAuthorizationTest { + import SaslPlainSslEndToEndAuthorizationTest._ + + this.serverConfig.setProperty(s"${listenerName.configPrefix}${KafkaConfig.SslClientAuthProp}", "required") + this.serverConfig.setProperty(BrokerSecurityConfigs.PRINCIPAL_BUILDER_CLASS_CONFIG, classOf[TestPrincipalBuilder].getName) + this.serverConfig.put(KafkaConfig.SaslClientCallbackHandlerClassProp, classOf[TestClientCallbackHandler].getName) + val mechanismPrefix = listenerName.saslMechanismConfigPrefix("PLAIN") + this.serverConfig.put(s"$mechanismPrefix${KafkaConfig.SaslServerCallbackHandlerClassProp}", classOf[TestServerCallbackHandler].getName) + this.producerConfig.put(SaslConfigs.SASL_CLIENT_CALLBACK_HANDLER_CLASS, classOf[TestClientCallbackHandler].getName) + this.consumerConfig.put(SaslConfigs.SASL_CLIENT_CALLBACK_HANDLER_CLASS, classOf[TestClientCallbackHandler].getName) + this.adminClientConfig.put(SaslConfigs.SASL_CLIENT_CALLBACK_HANDLER_CLASS, classOf[TestClientCallbackHandler].getName) + private val plainLogin = s"org.apache.kafka.common.security.plain.PlainLoginModule username=$KafkaPlainUser required;" + this.producerConfig.put(SaslConfigs.SASL_JAAS_CONFIG, plainLogin) + this.consumerConfig.put(SaslConfigs.SASL_JAAS_CONFIG, plainLogin) + this.adminClientConfig.put(SaslConfigs.SASL_JAAS_CONFIG, plainLogin) + + override protected def kafkaClientSaslMechanism = "PLAIN" + override protected def kafkaServerSaslMechanisms = List("PLAIN") + + override val clientPrincipal = new KafkaPrincipal(KafkaPrincipal.USER_TYPE, "user") + override val kafkaPrincipal = new KafkaPrincipal(KafkaPrincipal.USER_TYPE, "admin") + + override def jaasSections(kafkaServerSaslMechanisms: Seq[String], + kafkaClientSaslMechanism: Option[String], + mode: SaslSetupMode, + kafkaServerEntryName: String): Seq[JaasSection] = { + val brokerLogin = PlainLoginModule(KafkaPlainAdmin, "") // Password provided by callback handler + val clientLogin = PlainLoginModule(KafkaPlainUser2, KafkaPlainPassword2) + Seq(JaasSection(kafkaServerEntryName, Seq(brokerLogin)), + JaasSection(KafkaClientContextName, Seq(clientLogin))) ++ zkSections + } + + // Generate SSL certificates for clients since we are enabling TLS mutual authentication + // in this test for the SASL_SSL listener. + override def clientSecurityProps(certAlias: String): Properties = { + TestUtils.securityConfigs(Mode.CLIENT, securityProtocol, trustStoreFile, certAlias, TestUtils.SslCertificateCn, + clientSaslProperties, needsClientCert = Some(true)) + } + + /** + * Checks that secure paths created by broker and acl paths created by AclCommand + * have expected ACLs. 
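+   * (Counterpart of SaslPlainPlaintextConsumerTest#testZkAclsDisabled, which checks the
+   * unsecured ZooKeeper ACLs when zookeeper.set.acl=false.)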
+ */ + @Test + def testAcls(): Unit = { + TestUtils.verifySecureZkAcls(zkClient, 1) + } +} diff --git a/core/src/test/scala/integration/kafka/api/SaslPlaintextConsumerTest.scala b/core/src/test/scala/integration/kafka/api/SaslPlaintextConsumerTest.scala new file mode 100644 index 0000000..0933818 --- /dev/null +++ b/core/src/test/scala/integration/kafka/api/SaslPlaintextConsumerTest.scala @@ -0,0 +1,34 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more contributor license agreements. See the NOTICE + * file distributed with this work for additional information regarding copyright ownership. The ASF licenses this file + * to You under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the + * License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ +package kafka.api + +import kafka.utils.JaasTestUtils +import org.apache.kafka.common.security.auth.SecurityProtocol +import org.junit.jupiter.api.{AfterEach, BeforeEach, TestInfo} + +class SaslPlaintextConsumerTest extends BaseConsumerTest with SaslSetup { + override protected def securityProtocol = SecurityProtocol.SASL_PLAINTEXT + + @BeforeEach + override def setUp(testInfo: TestInfo): Unit = { + startSasl(jaasSections(Seq("GSSAPI"), Some("GSSAPI"), KafkaSasl, JaasTestUtils.KafkaServerContextName)) + super.setUp(testInfo) + } + + @AfterEach + override def tearDown(): Unit = { + super.tearDown() + closeSasl() + } + +} diff --git a/core/src/test/scala/integration/kafka/api/SaslScramSslEndToEndAuthorizationTest.scala b/core/src/test/scala/integration/kafka/api/SaslScramSslEndToEndAuthorizationTest.scala new file mode 100644 index 0000000..6e334d1 --- /dev/null +++ b/core/src/test/scala/integration/kafka/api/SaslScramSslEndToEndAuthorizationTest.scala @@ -0,0 +1,61 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package kafka.api + +import java.util.Properties + +import kafka.utils.JaasTestUtils +import kafka.zk.ConfigEntityChangeNotificationZNode +import org.apache.kafka.common.security.auth.KafkaPrincipal +import org.apache.kafka.common.security.scram.internals.ScramMechanism +import org.apache.kafka.test.TestSslUtils + +import scala.jdk.CollectionConverters._ +import org.junit.jupiter.api.{BeforeEach, TestInfo} + +class SaslScramSslEndToEndAuthorizationTest extends SaslEndToEndAuthorizationTest { + override protected def kafkaClientSaslMechanism = "SCRAM-SHA-256" + override protected def kafkaServerSaslMechanisms = ScramMechanism.mechanismNames.asScala.toList + override val clientPrincipal = new KafkaPrincipal(KafkaPrincipal.USER_TYPE, JaasTestUtils.KafkaScramUser) + override val kafkaPrincipal = new KafkaPrincipal(KafkaPrincipal.USER_TYPE, JaasTestUtils.KafkaScramAdmin) + private val kafkaPassword = JaasTestUtils.KafkaScramAdminPassword + + override def configureSecurityBeforeServersStart(): Unit = { + super.configureSecurityBeforeServersStart() + zkClient.makeSurePersistentPathExists(ConfigEntityChangeNotificationZNode.path) + // Create broker credentials before starting brokers + createScramCredentials(zkConnect, kafkaPrincipal.getName, kafkaPassword) + TestSslUtils.convertToPemWithoutFiles(producerConfig) + TestSslUtils.convertToPemWithoutFiles(consumerConfig) + TestSslUtils.convertToPemWithoutFiles(adminClientConfig) + } + + override def configureListeners(props: collection.Seq[Properties]): Unit = { + props.foreach(TestSslUtils.convertToPemWithoutFiles) + super.configureListeners(props) + } + + override def createPrivilegedAdminClient() = createScramAdminClient(kafkaClientSaslMechanism, kafkaPrincipal.getName, kafkaPassword) + + @BeforeEach + override def setUp(testInfo: TestInfo): Unit = { + super.setUp(testInfo) + // Create client credentials after starting brokers so that dynamic credential creation is also tested + createScramCredentialsViaPrivilegedAdminClient(JaasTestUtils.KafkaScramUser, JaasTestUtils.KafkaScramPassword) + createScramCredentialsViaPrivilegedAdminClient(JaasTestUtils.KafkaScramUser2, JaasTestUtils.KafkaScramPassword2) + } +} diff --git a/core/src/test/scala/integration/kafka/api/SaslSetup.scala b/core/src/test/scala/integration/kafka/api/SaslSetup.scala new file mode 100644 index 0000000..d613b72 --- /dev/null +++ b/core/src/test/scala/integration/kafka/api/SaslSetup.scala @@ -0,0 +1,215 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package kafka.api + +import java.io.File +import java.util +import java.util.Properties + +import javax.security.auth.login.Configuration + +import scala.collection.Seq +import kafka.security.minikdc.MiniKdc +import kafka.server.{ConfigType, KafkaConfig} +import kafka.utils.JaasTestUtils.{JaasSection, Krb5LoginModule, ZkDigestModule} +import kafka.utils.{JaasTestUtils, TestUtils} +import kafka.zk.{AdminZkClient, KafkaZkClient} +import org.apache.kafka.clients.admin.{Admin, AdminClientConfig, ScramCredentialInfo, UserScramCredentialAlteration, UserScramCredentialUpsertion, ScramMechanism => PublicScramMechanism} +import org.apache.kafka.common.config.SaslConfigs +import org.apache.kafka.common.config.internals.BrokerSecurityConfigs +import org.apache.kafka.common.security.JaasUtils +import org.apache.kafka.common.security.auth.SecurityProtocol +import org.apache.kafka.common.security.authenticator.LoginManager +import org.apache.kafka.common.security.scram.internals.{ScramCredentialUtils, ScramFormatter, ScramMechanism} +import org.apache.kafka.common.utils.Time +import org.apache.zookeeper.client.ZKClientConfig + +import scala.jdk.CollectionConverters._ + +/* + * Implements an enumeration for the modes enabled here: + * zk only, kafka only, both, custom KafkaServer. + */ +sealed trait SaslSetupMode +case object ZkSasl extends SaslSetupMode +case object KafkaSasl extends SaslSetupMode +case object Both extends SaslSetupMode + +/* + * Trait used in SaslTestHarness and EndToEndAuthorizationTest to setup keytab and jaas files. + */ +trait SaslSetup { + private val workDir = TestUtils.tempDir() + private val kdcConf = MiniKdc.createConfig + private var kdc: MiniKdc = null + private var serverKeytabFile: Option[File] = None + private var clientKeytabFile: Option[File] = None + + def startSasl(jaasSections: Seq[JaasSection]): Unit = { + // Important if tests leak consumers, producers or brokers + LoginManager.closeAll() + val hasKerberos = jaasSections.exists(_.modules.exists { + case _: Krb5LoginModule => true + case _ => false + }) + if (hasKerberos) { + initializeKerberos() + } + writeJaasConfigurationToFile(jaasSections) + val hasZk = jaasSections.exists(_.modules.exists { + case _: ZkDigestModule => true + case _ => false + }) + if (hasZk) + System.setProperty("zookeeper.authProvider.1", "org.apache.zookeeper.server.auth.SASLAuthenticationProvider") + } + + protected def initializeKerberos(): Unit = { + val (serverKeytabFile, clientKeytabFile) = maybeCreateEmptyKeytabFiles() + kdc = new MiniKdc(kdcConf, workDir) + kdc.start() + kdc.createPrincipal(serverKeytabFile, JaasTestUtils.KafkaServerPrincipalUnqualifiedName + "/localhost") + kdc.createPrincipal(clientKeytabFile, + JaasTestUtils.KafkaClientPrincipalUnqualifiedName, JaasTestUtils.KafkaClientPrincipalUnqualifiedName2) + } + + /** Return a tuple with the path to the server keytab file and client keytab file */ + protected def maybeCreateEmptyKeytabFiles(): (File, File) = { + if (serverKeytabFile.isEmpty) + serverKeytabFile = Some(TestUtils.tempFile()) + if (clientKeytabFile.isEmpty) + clientKeytabFile = Some(TestUtils.tempFile()) + (serverKeytabFile.get, clientKeytabFile.get) + } + + def jaasSections(kafkaServerSaslMechanisms: Seq[String], + kafkaClientSaslMechanism: Option[String], + mode: SaslSetupMode = Both, + kafkaServerEntryName: String = JaasTestUtils.KafkaServerContextName): Seq[JaasSection] = { + val hasKerberos = mode != ZkSasl && + (kafkaServerSaslMechanisms.contains("GSSAPI") || 
kafkaClientSaslMechanism.contains("GSSAPI")) + if (hasKerberos) + maybeCreateEmptyKeytabFiles() + mode match { + case ZkSasl => JaasTestUtils.zkSections + case KafkaSasl => + Seq(JaasTestUtils.kafkaServerSection(kafkaServerEntryName, kafkaServerSaslMechanisms, serverKeytabFile), + JaasTestUtils.kafkaClientSection(kafkaClientSaslMechanism, clientKeytabFile)) + case Both => Seq(JaasTestUtils.kafkaServerSection(kafkaServerEntryName, kafkaServerSaslMechanisms, serverKeytabFile), + JaasTestUtils.kafkaClientSection(kafkaClientSaslMechanism, clientKeytabFile)) ++ JaasTestUtils.zkSections + } + } + + private def writeJaasConfigurationToFile(jaasSections: Seq[JaasSection]): Unit = { + val file = JaasTestUtils.writeJaasContextsToFile(jaasSections) + System.setProperty(JaasUtils.JAVA_LOGIN_CONFIG_PARAM, file.getAbsolutePath) + // This will cause a reload of the Configuration singleton when `getConfiguration` is called + Configuration.setConfiguration(null) + } + + def closeSasl(): Unit = { + if (kdc != null) + kdc.stop() + // Important if tests leak consumers, producers or brokers + LoginManager.closeAll() + System.clearProperty(JaasUtils.JAVA_LOGIN_CONFIG_PARAM) + System.clearProperty("zookeeper.authProvider.1") + Configuration.setConfiguration(null) + } + + def kafkaServerSaslProperties(serverSaslMechanisms: Seq[String], interBrokerSaslMechanism: String): Properties = { + val props = new Properties + props.put(KafkaConfig.SaslMechanismInterBrokerProtocolProp, interBrokerSaslMechanism) + props.put(BrokerSecurityConfigs.SASL_ENABLED_MECHANISMS_CONFIG, serverSaslMechanisms.mkString(",")) + props + } + + def kafkaClientSaslProperties(clientSaslMechanism: String, dynamicJaasConfig: Boolean = false): Properties = { + val props = new Properties + props.put(SaslConfigs.SASL_MECHANISM, clientSaslMechanism) + if (dynamicJaasConfig) + props.put(SaslConfigs.SASL_JAAS_CONFIG, jaasClientLoginModule(clientSaslMechanism)) + props + } + + def jaasClientLoginModule(clientSaslMechanism: String, serviceName: Option[String] = None): String = { + if (serviceName.isDefined) + JaasTestUtils.clientLoginModule(clientSaslMechanism, clientKeytabFile, serviceName.get) + else + JaasTestUtils.clientLoginModule(clientSaslMechanism, clientKeytabFile) + } + + def jaasScramClientLoginModule(clientSaslScramMechanism: String, scramUser: String, scramPassword: String): String = { + JaasTestUtils.scramClientLoginModule(clientSaslScramMechanism, scramUser, scramPassword) + } + + def createPrivilegedAdminClient(): Admin = { + // create an admin client instance that is authorized to create credentials + throw new UnsupportedOperationException("Must implement this if a test needs to use it") + } + + def createAdminClient(brokerList: String, securityProtocol: SecurityProtocol, trustStoreFile: Option[File], + clientSaslProperties: Option[Properties], scramMechanism: String, user: String, password: String) : Admin = { + val config = new util.HashMap[String, Object] + config.put(AdminClientConfig.BOOTSTRAP_SERVERS_CONFIG, brokerList) + val securityProps: util.Map[Object, Object] = + TestUtils.adminClientSecurityConfigs(securityProtocol, trustStoreFile, clientSaslProperties) + securityProps.forEach { (key, value) => config.put(key.asInstanceOf[String], value) } + config.put(SaslConfigs.SASL_JAAS_CONFIG, jaasScramClientLoginModule(scramMechanism, user, password)) + Admin.create(config) + } + + def createScramCredentialsViaPrivilegedAdminClient(userName: String, password: String): Unit = { + val privilegedAdminClient = 
createPrivilegedAdminClient() // must explicitly implement this method + try { + // create the SCRAM credential for the given user + createScramCredentials(privilegedAdminClient, userName, password) + } finally { + privilegedAdminClient.close() + } + } + + def createScramCredentials(adminClient: Admin, userName: String, password: String): Unit = { + val results = adminClient.alterUserScramCredentials(PublicScramMechanism.values().filter(_ != PublicScramMechanism.UNKNOWN).map(mechanism => + new UserScramCredentialUpsertion(userName, new ScramCredentialInfo(mechanism, 4096), password) + .asInstanceOf[UserScramCredentialAlteration]).toList.asJava) + results.all.get + } + + def createScramCredentials(zkConnect: String, userName: String, password: String): Unit = { + val zkClientConfig = new ZKClientConfig() + val zkClient = KafkaZkClient( + zkConnect, JaasUtils.isZkSaslEnabled || KafkaConfig.zkTlsClientAuthEnabled(zkClientConfig), 30000, 30000, + Int.MaxValue, Time.SYSTEM, name = "SaslSetup", zkClientConfig = zkClientConfig) + val adminZkClient = new AdminZkClient(zkClient) + + val entityType = ConfigType.User + val entityName = userName + val configs = adminZkClient.fetchEntityConfig(entityType, entityName) + + ScramMechanism.values().foreach(mechanism => { + val credential = new ScramFormatter(mechanism).generateCredential(password, 4096) + val credentialString = ScramCredentialUtils.credentialToString(credential) + configs.setProperty(mechanism.mechanismName, credentialString) + }) + + adminZkClient.changeConfigs(entityType, entityName, configs) + zkClient.close() + } + +} diff --git a/core/src/test/scala/integration/kafka/api/SaslSslAdminIntegrationTest.scala b/core/src/test/scala/integration/kafka/api/SaslSslAdminIntegrationTest.scala new file mode 100644 index 0000000..3cb424d --- /dev/null +++ b/core/src/test/scala/integration/kafka/api/SaslSslAdminIntegrationTest.scala @@ -0,0 +1,528 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more contributor license agreements. See the NOTICE + * file distributed with this work for additional information regarding copyright ownership. The ASF licenses this file + * to You under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the + * License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. 
+ */ +package kafka.api + +import java.io.File +import java.util + +import kafka.log.LogConfig +import kafka.security.authorizer.AclAuthorizer +import kafka.security.authorizer.AclEntry.{WildcardHost, WildcardPrincipalString} +import kafka.server.{Defaults, KafkaConfig} +import kafka.utils.{CoreUtils, JaasTestUtils, TestUtils} +import kafka.utils.TestUtils._ +import org.apache.kafka.clients.admin._ +import org.apache.kafka.common.Uuid +import org.apache.kafka.common.acl._ +import org.apache.kafka.common.acl.AclOperation.{ALL, ALTER, ALTER_CONFIGS, CLUSTER_ACTION, CREATE, DELETE, DESCRIBE} +import org.apache.kafka.common.acl.AclPermissionType.{ALLOW, DENY} +import org.apache.kafka.common.config.ConfigResource +import org.apache.kafka.common.errors.{ClusterAuthorizationException, InvalidRequestException, TopicAuthorizationException, UnknownTopicOrPartitionException} +import org.apache.kafka.common.resource.PatternType.LITERAL +import org.apache.kafka.common.resource.ResourceType.{GROUP, TOPIC} +import org.apache.kafka.common.resource.{PatternType, Resource, ResourcePattern, ResourcePatternFilter, ResourceType} +import org.apache.kafka.common.security.auth.{KafkaPrincipal, SecurityProtocol} +import org.apache.kafka.server.authorizer.Authorizer +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.{AfterEach, BeforeEach, Test, TestInfo} +import java.util.Collections + +import scala.jdk.CollectionConverters._ +import scala.collection.Seq +import scala.compat.java8.OptionConverters._ +import scala.concurrent.ExecutionException +import scala.util.{Failure, Success, Try} + +class SaslSslAdminIntegrationTest extends BaseAdminIntegrationTest with SaslSetup { + val clusterResourcePattern = new ResourcePattern(ResourceType.CLUSTER, Resource.CLUSTER_NAME, PatternType.LITERAL) + + val authorizationAdmin = new AclAuthorizationAdmin(classOf[AclAuthorizer], classOf[AclAuthorizer]) + + this.serverConfig.setProperty(KafkaConfig.ZkEnableSecureAclsProp, "true") + + override protected def securityProtocol = SecurityProtocol.SASL_SSL + override protected lazy val trustStoreFile = Some(File.createTempFile("truststore", ".jks")) + + override def generateConfigs: Seq[KafkaConfig] = { + this.serverConfig.setProperty(KafkaConfig.AuthorizerClassNameProp, authorizationAdmin.authorizerClassName) + super.generateConfigs + } + + override def configureSecurityBeforeServersStart(): Unit = { + authorizationAdmin.initializeAcls() + } + + @BeforeEach + override def setUp(testInfo: TestInfo): Unit = { + setUpSasl() + super.setUp(testInfo) + } + + def setUpSasl(): Unit = { + startSasl(jaasSections(Seq("GSSAPI"), Some("GSSAPI"), Both, JaasTestUtils.KafkaServerContextName)) + } + + @AfterEach + override def tearDown(): Unit = { + super.tearDown() + closeSasl() + } + + val anyAcl = new AclBinding(new ResourcePattern(ResourceType.TOPIC, "*", PatternType.LITERAL), + new AccessControlEntry("User:*", "*", AclOperation.ALL, AclPermissionType.ALLOW)) + val acl2 = new AclBinding(new ResourcePattern(ResourceType.TOPIC, "mytopic2", PatternType.LITERAL), + new AccessControlEntry("User:ANONYMOUS", "*", AclOperation.WRITE, AclPermissionType.ALLOW)) + val acl3 = new AclBinding(new ResourcePattern(ResourceType.TOPIC, "mytopic3", PatternType.LITERAL), + new AccessControlEntry("User:ANONYMOUS", "*", AclOperation.READ, AclPermissionType.ALLOW)) + val fooAcl = new AclBinding(new ResourcePattern(ResourceType.TOPIC, "foobar", PatternType.LITERAL), + new AccessControlEntry("User:ANONYMOUS", "*", AclOperation.READ, 
AclPermissionType.ALLOW)) + val prefixAcl = new AclBinding(new ResourcePattern(ResourceType.TOPIC, "mytopic", PatternType.PREFIXED), + new AccessControlEntry("User:ANONYMOUS", "*", AclOperation.READ, AclPermissionType.ALLOW)) + val transactionalIdAcl = new AclBinding(new ResourcePattern(ResourceType.TRANSACTIONAL_ID, "transactional_id", PatternType.LITERAL), + new AccessControlEntry("User:ANONYMOUS", "*", AclOperation.WRITE, AclPermissionType.ALLOW)) + val groupAcl = new AclBinding(new ResourcePattern(ResourceType.GROUP, "*", PatternType.LITERAL), + new AccessControlEntry("User:*", "*", AclOperation.ALL, AclPermissionType.ALLOW)) + + @Test + def testAclOperations(): Unit = { + client = Admin.create(createConfig) + val acl = new AclBinding(new ResourcePattern(ResourceType.TOPIC, "mytopic3", PatternType.LITERAL), + new AccessControlEntry("User:ANONYMOUS", "*", AclOperation.DESCRIBE, AclPermissionType.ALLOW)) + assertEquals(7, getAcls(AclBindingFilter.ANY).size) + val results = client.createAcls(List(acl2, acl3).asJava) + assertEquals(Set(acl2, acl3), results.values.keySet().asScala) + results.values.values.forEach(value => value.get) + val aclUnknown = new AclBinding(new ResourcePattern(ResourceType.TOPIC, "mytopic3", PatternType.LITERAL), + new AccessControlEntry("User:ANONYMOUS", "*", AclOperation.UNKNOWN, AclPermissionType.ALLOW)) + val results2 = client.createAcls(List(aclUnknown).asJava) + assertEquals(Set(aclUnknown), results2.values.keySet().asScala) + assertFutureExceptionTypeEquals(results2.all, classOf[InvalidRequestException]) + val results3 = client.deleteAcls(List(acl.toFilter, acl2.toFilter, acl3.toFilter).asJava).values + assertEquals(Set(acl.toFilter, acl2.toFilter, acl3.toFilter), results3.keySet.asScala) + assertEquals(0, results3.get(acl.toFilter).get.values.size()) + assertEquals(Set(acl2), results3.get(acl2.toFilter).get.values.asScala.map(_.binding).toSet) + assertEquals(Set(acl3), results3.get(acl3.toFilter).get.values.asScala.map(_.binding).toSet) + } + + @Test + def testAclOperations2(): Unit = { + client = Admin.create(createConfig) + val results = client.createAcls(List(acl2, acl2, transactionalIdAcl).asJava) + assertEquals(Set(acl2, acl2, transactionalIdAcl), results.values.keySet.asScala) + results.all.get() + waitForDescribeAcls(client, acl2.toFilter, Set(acl2)) + waitForDescribeAcls(client, transactionalIdAcl.toFilter, Set(transactionalIdAcl)) + + val filterA = new AclBindingFilter(new ResourcePatternFilter(ResourceType.GROUP, null, PatternType.LITERAL), AccessControlEntryFilter.ANY) + val filterB = new AclBindingFilter(new ResourcePatternFilter(ResourceType.TOPIC, "mytopic2", PatternType.LITERAL), AccessControlEntryFilter.ANY) + val filterC = new AclBindingFilter(new ResourcePatternFilter(ResourceType.TRANSACTIONAL_ID, null, PatternType.LITERAL), AccessControlEntryFilter.ANY) + + waitForDescribeAcls(client, filterA, Set(groupAcl)) + waitForDescribeAcls(client, filterC, Set(transactionalIdAcl)) + + val results2 = client.deleteAcls(List(filterA, filterB, filterC).asJava, new DeleteAclsOptions()) + assertEquals(Set(filterA, filterB, filterC), results2.values.keySet.asScala) + assertEquals(Set(groupAcl), results2.values.get(filterA).get.values.asScala.map(_.binding).toSet) + assertEquals(Set(transactionalIdAcl), results2.values.get(filterC).get.values.asScala.map(_.binding).toSet) + assertEquals(Set(acl2), results2.values.get(filterB).get.values.asScala.map(_.binding).toSet) + + waitForDescribeAcls(client, filterB, Set()) + waitForDescribeAcls(client, filterC, 
Set()) + } + + @Test + def testAclDescribe(): Unit = { + client = Admin.create(createConfig) + ensureAcls(Set(anyAcl, acl2, fooAcl, prefixAcl)) + + val allTopicAcls = new AclBindingFilter(new ResourcePatternFilter(ResourceType.TOPIC, null, PatternType.ANY), AccessControlEntryFilter.ANY) + val allLiteralTopicAcls = new AclBindingFilter(new ResourcePatternFilter(ResourceType.TOPIC, null, PatternType.LITERAL), AccessControlEntryFilter.ANY) + val allPrefixedTopicAcls = new AclBindingFilter(new ResourcePatternFilter(ResourceType.TOPIC, null, PatternType.PREFIXED), AccessControlEntryFilter.ANY) + val literalMyTopic2Acls = new AclBindingFilter(new ResourcePatternFilter(ResourceType.TOPIC, "mytopic2", PatternType.LITERAL), AccessControlEntryFilter.ANY) + val prefixedMyTopicAcls = new AclBindingFilter(new ResourcePatternFilter(ResourceType.TOPIC, "mytopic", PatternType.PREFIXED), AccessControlEntryFilter.ANY) + val allMyTopic2Acls = new AclBindingFilter(new ResourcePatternFilter(ResourceType.TOPIC, "mytopic2", PatternType.MATCH), AccessControlEntryFilter.ANY) + val allFooTopicAcls = new AclBindingFilter(new ResourcePatternFilter(ResourceType.TOPIC, "foobar", PatternType.MATCH), AccessControlEntryFilter.ANY) + + assertEquals(Set(anyAcl), getAcls(anyAcl.toFilter)) + assertEquals(Set(prefixAcl), getAcls(prefixAcl.toFilter)) + assertEquals(Set(acl2), getAcls(acl2.toFilter)) + assertEquals(Set(fooAcl), getAcls(fooAcl.toFilter)) + + assertEquals(Set(acl2), getAcls(literalMyTopic2Acls)) + assertEquals(Set(prefixAcl), getAcls(prefixedMyTopicAcls)) + assertEquals(Set(anyAcl, acl2, fooAcl), getAcls(allLiteralTopicAcls)) + assertEquals(Set(prefixAcl), getAcls(allPrefixedTopicAcls)) + assertEquals(Set(anyAcl, acl2, prefixAcl), getAcls(allMyTopic2Acls)) + assertEquals(Set(anyAcl, fooAcl), getAcls(allFooTopicAcls)) + assertEquals(Set(anyAcl, acl2, fooAcl, prefixAcl), getAcls(allTopicAcls)) + } + + @Test + def testAclDelete(): Unit = { + client = Admin.create(createConfig) + ensureAcls(Set(anyAcl, acl2, fooAcl, prefixAcl)) + + val allTopicAcls = new AclBindingFilter(new ResourcePatternFilter(ResourceType.TOPIC, null, PatternType.MATCH), AccessControlEntryFilter.ANY) + val allLiteralTopicAcls = new AclBindingFilter(new ResourcePatternFilter(ResourceType.TOPIC, null, PatternType.LITERAL), AccessControlEntryFilter.ANY) + val allPrefixedTopicAcls = new AclBindingFilter(new ResourcePatternFilter(ResourceType.TOPIC, null, PatternType.PREFIXED), AccessControlEntryFilter.ANY) + + // Delete only ACLs on literal 'mytopic2' topic + var deleted = client.deleteAcls(List(acl2.toFilter).asJava).all().get().asScala.toSet + assertEquals(Set(acl2), deleted) + assertEquals(Set(anyAcl, fooAcl, prefixAcl), getAcls(allTopicAcls)) + + ensureAcls(deleted) + + // Delete only ACLs on literal '*' topic + deleted = client.deleteAcls(List(anyAcl.toFilter).asJava).all().get().asScala.toSet + assertEquals(Set(anyAcl), deleted) + assertEquals(Set(acl2, fooAcl, prefixAcl), getAcls(allTopicAcls)) + + ensureAcls(deleted) + + // Delete only ACLs on specific prefixed 'mytopic' topics: + deleted = client.deleteAcls(List(prefixAcl.toFilter).asJava).all().get().asScala.toSet + assertEquals(Set(prefixAcl), deleted) + assertEquals(Set(anyAcl, acl2, fooAcl), getAcls(allTopicAcls)) + + ensureAcls(deleted) + + // Delete all literal ACLs: + deleted = client.deleteAcls(List(allLiteralTopicAcls).asJava).all().get().asScala.toSet + assertEquals(Set(anyAcl, acl2, fooAcl), deleted) + assertEquals(Set(prefixAcl), getAcls(allTopicAcls)) + + ensureAcls(deleted) + + 
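+    // (ensureAcls re-creates the bindings deleted above, so every delete scenario in this test
+    // starts from the same four ACLs: anyAcl, acl2, fooAcl and prefixAcl.)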
// Delete all prefixed ACLs: + deleted = client.deleteAcls(List(allPrefixedTopicAcls).asJava).all().get().asScala.toSet + assertEquals(Set(prefixAcl), deleted) + assertEquals(Set(anyAcl, acl2, fooAcl), getAcls(allTopicAcls)) + + ensureAcls(deleted) + + // Delete all topic ACLs: + deleted = client.deleteAcls(List(allTopicAcls).asJava).all().get().asScala.toSet + assertEquals(Set(), getAcls(allTopicAcls)) + } + + //noinspection ScalaDeprecation - test explicitly covers clients using legacy / deprecated constructors + @Test + def testLegacyAclOpsNeverAffectOrReturnPrefixed(): Unit = { + client = Admin.create(createConfig) + ensureAcls(Set(anyAcl, acl2, fooAcl, prefixAcl)) // <-- prefixed exists, but should never be returned. + + val allTopicAcls = new AclBindingFilter(new ResourcePatternFilter(ResourceType.TOPIC, null, PatternType.MATCH), AccessControlEntryFilter.ANY) + val legacyAllTopicAcls = new AclBindingFilter(new ResourcePatternFilter(ResourceType.TOPIC, null, PatternType.LITERAL), AccessControlEntryFilter.ANY) + val legacyMyTopic2Acls = new AclBindingFilter(new ResourcePatternFilter(ResourceType.TOPIC, "mytopic2", PatternType.LITERAL), AccessControlEntryFilter.ANY) + val legacyAnyTopicAcls = new AclBindingFilter(new ResourcePatternFilter(ResourceType.TOPIC, "*", PatternType.LITERAL), AccessControlEntryFilter.ANY) + val legacyFooTopicAcls = new AclBindingFilter(new ResourcePatternFilter(ResourceType.TOPIC, "foobar", PatternType.LITERAL), AccessControlEntryFilter.ANY) + + assertEquals(Set(anyAcl, acl2, fooAcl), getAcls(legacyAllTopicAcls)) + assertEquals(Set(acl2), getAcls(legacyMyTopic2Acls)) + assertEquals(Set(anyAcl), getAcls(legacyAnyTopicAcls)) + assertEquals(Set(fooAcl), getAcls(legacyFooTopicAcls)) + + // Delete only (legacy) ACLs on 'mytopic2' topic + var deleted = client.deleteAcls(List(legacyMyTopic2Acls).asJava).all().get().asScala.toSet + assertEquals(Set(acl2), deleted) + assertEquals(Set(anyAcl, fooAcl, prefixAcl), getAcls(allTopicAcls)) + + ensureAcls(deleted) + + // Delete only (legacy) ACLs on '*' topic + deleted = client.deleteAcls(List(legacyAnyTopicAcls).asJava).all().get().asScala.toSet + assertEquals(Set(anyAcl), deleted) + assertEquals(Set(acl2, fooAcl, prefixAcl), getAcls(allTopicAcls)) + + ensureAcls(deleted) + + // Delete all (legacy) topic ACLs: + deleted = client.deleteAcls(List(legacyAllTopicAcls).asJava).all().get().asScala.toSet + assertEquals(Set(anyAcl, acl2, fooAcl), deleted) + assertEquals(Set(), getAcls(legacyAllTopicAcls)) + assertEquals(Set(prefixAcl), getAcls(allTopicAcls)) + } + + @Test + def testAttemptToCreateInvalidAcls(): Unit = { + client = Admin.create(createConfig) + val clusterAcl = new AclBinding(new ResourcePattern(ResourceType.CLUSTER, "foobar", PatternType.LITERAL), + new AccessControlEntry("User:ANONYMOUS", "*", AclOperation.READ, AclPermissionType.ALLOW)) + val emptyResourceNameAcl = new AclBinding(new ResourcePattern(ResourceType.TOPIC, "", PatternType.LITERAL), + new AccessControlEntry("User:ANONYMOUS", "*", AclOperation.READ, AclPermissionType.ALLOW)) + val results = client.createAcls(List(clusterAcl, emptyResourceNameAcl).asJava, new CreateAclsOptions()) + assertEquals(Set(clusterAcl, emptyResourceNameAcl), results.values.keySet().asScala) + assertFutureExceptionTypeEquals(results.values.get(clusterAcl), classOf[InvalidRequestException]) + assertFutureExceptionTypeEquals(results.values.get(emptyResourceNameAcl), classOf[InvalidRequestException]) + } + + override def configuredClusterPermissions: Set[AclOperation] = { + 
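+    // Cluster operations the authenticated admin client is expected to hold in this secured setup;
+    // presumably compared against describeCluster's authorized operations by the base test class.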
Set(AclOperation.ALTER, AclOperation.CREATE, AclOperation.CLUSTER_ACTION, AclOperation.ALTER_CONFIGS, + AclOperation.DESCRIBE, AclOperation.DESCRIBE_CONFIGS) + } + + private def verifyCauseIsClusterAuth(e: Throwable): Unit = assertEquals(classOf[ClusterAuthorizationException], e.getCause.getClass) + + private def testAclCreateGetDelete(expectAuth: Boolean): Unit = { + TestUtils.waitUntilTrue(() => { + val result = client.createAcls(List(fooAcl, transactionalIdAcl).asJava, new CreateAclsOptions) + if (expectAuth) { + Try(result.all.get) match { + case Failure(e) => + verifyCauseIsClusterAuth(e) + false + case Success(_) => true + } + } else { + Try(result.all.get) match { + case Failure(e) => + verifyCauseIsClusterAuth(e) + true + case Success(_) => false + } + } + }, "timed out waiting for createAcls to " + (if (expectAuth) "succeed" else "fail")) + if (expectAuth) { + waitForDescribeAcls(client, fooAcl.toFilter, Set(fooAcl)) + waitForDescribeAcls(client, transactionalIdAcl.toFilter, Set(transactionalIdAcl)) + } + TestUtils.waitUntilTrue(() => { + val result = client.deleteAcls(List(fooAcl.toFilter, transactionalIdAcl.toFilter).asJava, new DeleteAclsOptions) + if (expectAuth) { + Try(result.all.get) match { + case Failure(e) => + verifyCauseIsClusterAuth(e) + false + case Success(_) => true + } + } else { + Try(result.all.get) match { + case Failure(e) => + verifyCauseIsClusterAuth(e) + true + case Success(_) => + assertEquals(Set(fooAcl, transactionalIdAcl), result.values.keySet) + assertEquals(Set(fooAcl), result.values.get(fooAcl.toFilter).get.values.asScala.map(_.binding).toSet) + assertEquals(Set(transactionalIdAcl), + result.values.get(transactionalIdAcl.toFilter).get.values.asScala.map(_.binding).toSet) + true + } + } + }, "timed out waiting for deleteAcls to " + (if (expectAuth) "succeed" else "fail")) + if (expectAuth) { + waitForDescribeAcls(client, fooAcl.toFilter, Set.empty) + waitForDescribeAcls(client, transactionalIdAcl.toFilter, Set.empty) + } + } + + private def testAclGet(expectAuth: Boolean): Unit = { + TestUtils.waitUntilTrue(() => { + val userAcl = new AclBinding(new ResourcePattern(ResourceType.TOPIC, "*", PatternType.LITERAL), + new AccessControlEntry("User:*", "*", AclOperation.ALL, AclPermissionType.ALLOW)) + val results = client.describeAcls(userAcl.toFilter) + if (expectAuth) { + Try(results.values.get) match { + case Failure(e) => + verifyCauseIsClusterAuth(e) + false + case Success(acls) => Set(userAcl).equals(acls.asScala.toSet) + } + } else { + Try(results.values.get) match { + case Failure(e) => + verifyCauseIsClusterAuth(e) + true + case Success(_) => false + } + } + }, "timed out waiting for describeAcls to " + (if (expectAuth) "succeed" else "fail")) + } + + @Test + def testAclAuthorizationDenied(): Unit = { + client = Admin.create(createConfig) + + // Test that we cannot create or delete ACLs when ALTER is denied. + authorizationAdmin.addClusterAcl(DENY, ALTER) + testAclGet(expectAuth = true) + testAclCreateGetDelete(expectAuth = false) + + // Test that we cannot do anything with ACLs when DESCRIBE and ALTER are denied. + authorizationAdmin.addClusterAcl(DENY, DESCRIBE) + testAclGet(expectAuth = false) + testAclCreateGetDelete(expectAuth = false) + + // Test that we can create, delete, and get ACLs with the default ACLs. 
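+    // (Dropping the two DENY entries added above restores the ALLOW ACLs installed by
+    // AclAuthorizationAdmin.initializeAcls, so both describe and alter succeed again.)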
+ authorizationAdmin.removeClusterAcl(DENY, DESCRIBE) + authorizationAdmin.removeClusterAcl(DENY, ALTER) + testAclGet(expectAuth = true) + testAclCreateGetDelete(expectAuth = true) + + // Test that we can't do anything with ACLs without the ALLOW ALTER ACL in place. + authorizationAdmin.removeClusterAcl(ALLOW, ALTER) + authorizationAdmin.removeClusterAcl(ALLOW, DELETE) + testAclGet(expectAuth = false) + testAclCreateGetDelete(expectAuth = false) + + // Test that we can describe, but not alter ACLs, with only the ALLOW DESCRIBE ACL in place. + authorizationAdmin.addClusterAcl(ALLOW, DESCRIBE) + testAclGet(expectAuth = true) + testAclCreateGetDelete(expectAuth = false) + } + + @Test + def testCreateTopicsResponseMetadataAndConfig(): Unit = { + val topic1 = "mytopic1" + val topic2 = "mytopic2" + val denyAcl = new AclBinding(new ResourcePattern(ResourceType.TOPIC, topic2, PatternType.LITERAL), + new AccessControlEntry("User:*", "*", AclOperation.DESCRIBE_CONFIGS, AclPermissionType.DENY)) + + client = Admin.create(createConfig) + client.createAcls(List(denyAcl).asJava, new CreateAclsOptions()).all().get() + + val topics = Seq(topic1, topic2) + val configsOverride = Map(LogConfig.SegmentBytesProp -> "100000").asJava + val newTopics = Seq( + new NewTopic(topic1, 2, 3.toShort).configs(configsOverride), + new NewTopic(topic2, Option.empty[Integer].asJava, Option.empty[java.lang.Short].asJava).configs(configsOverride)) + val validateResult = client.createTopics(newTopics.asJava, new CreateTopicsOptions().validateOnly(true)) + validateResult.all.get() + waitForTopics(client, List(), topics) + + def validateMetadataAndConfigs(result: CreateTopicsResult): Unit = { + assertEquals(2, result.numPartitions(topic1).get()) + assertEquals(3, result.replicationFactor(topic1).get()) + val topicConfigs = result.config(topic1).get().entries.asScala + assertTrue(topicConfigs.nonEmpty) + val segmentBytesConfig = topicConfigs.find(_.name == LogConfig.SegmentBytesProp).get + assertEquals(100000, segmentBytesConfig.value.toLong) + assertEquals(ConfigEntry.ConfigSource.DYNAMIC_TOPIC_CONFIG, segmentBytesConfig.source) + val compressionConfig = topicConfigs.find(_.name == LogConfig.CompressionTypeProp).get + assertEquals(Defaults.CompressionType, compressionConfig.value) + assertEquals(ConfigEntry.ConfigSource.DEFAULT_CONFIG, compressionConfig.source) + + assertFutureExceptionTypeEquals(result.numPartitions(topic2), classOf[TopicAuthorizationException]) + assertFutureExceptionTypeEquals(result.replicationFactor(topic2), classOf[TopicAuthorizationException]) + assertFutureExceptionTypeEquals(result.config(topic2), classOf[TopicAuthorizationException]) + } + validateMetadataAndConfigs(validateResult) + + val createResult = client.createTopics(newTopics.asJava, new CreateTopicsOptions()) + createResult.all.get() + waitForTopics(client, topics, List()) + validateMetadataAndConfigs(createResult) + val topicIds = getTopicIds() + assertNotEquals(Uuid.ZERO_UUID, createResult.topicId(topic1).get()) + assertEquals(topicIds(topic1), createResult.topicId(topic1).get()) + assertFutureExceptionTypeEquals(createResult.topicId(topic2), classOf[TopicAuthorizationException]) + + val createResponseConfig = createResult.config(topic1).get().entries.asScala + + val describeResponseConfig = describeConfigs(topic1) + assertEquals(describeResponseConfig.map(_.name).toSet, createResponseConfig.map(_.name).toSet) + describeResponseConfig.foreach { describeEntry => + val name = describeEntry.name + val createEntry = 
createResponseConfig.find(_.name == name).get + assertEquals(describeEntry.value, createEntry.value, s"Value mismatch for $name") + assertEquals(describeEntry.isReadOnly, createEntry.isReadOnly, s"isReadOnly mismatch for $name") + assertEquals(describeEntry.isSensitive, createEntry.isSensitive, s"isSensitive mismatch for $name") + assertEquals(describeEntry.source, createEntry.source, s"Source mismatch for $name") + } + } + + private def describeConfigs(topic: String): Iterable[ConfigEntry] = { + val topicResource = new ConfigResource(ConfigResource.Type.TOPIC, topic) + var configEntries: Iterable[ConfigEntry] = null + + TestUtils.waitUntilTrue(() => { + try { + val topicResponse = client.describeConfigs(List(topicResource).asJava).all.get.get(topicResource) + configEntries = topicResponse.entries.asScala + true + } catch { + case e: ExecutionException if e.getCause.isInstanceOf[UnknownTopicOrPartitionException] => false + } + }, "Timed out waiting for describeConfigs") + + configEntries + } + + private def waitForDescribeAcls(client: Admin, filter: AclBindingFilter, acls: Set[AclBinding]): Unit = { + var lastResults: util.Collection[AclBinding] = null + TestUtils.waitUntilTrue(() => { + lastResults = client.describeAcls(filter).values.get() + acls == lastResults.asScala.toSet + }, s"timed out waiting for ACLs $acls.\nActual $lastResults") + } + + private def ensureAcls(bindings: Set[AclBinding]): Unit = { + client.createAcls(bindings.asJava).all().get() + + bindings.foreach(binding => waitForDescribeAcls(client, binding.toFilter, Set(binding))) + } + + private def getAcls(allTopicAcls: AclBindingFilter) = { + client.describeAcls(allTopicAcls).values.get().asScala.toSet + } + + class AclAuthorizationAdmin(authorizerClass: Class[_ <: AclAuthorizer], authorizerForInitClass: Class[_ <: AclAuthorizer]) { + + def authorizerClassName: String = authorizerClass.getName + + def initializeAcls(): Unit = { + val authorizer = CoreUtils.createObject[Authorizer](authorizerForInitClass.getName) + try { + authorizer.configure(configs.head.originals()) + val ace = new AccessControlEntry(WildcardPrincipalString, WildcardHost, ALL, ALLOW) + authorizer.createAcls(null, List(new AclBinding(new ResourcePattern(TOPIC, "*", LITERAL), ace)).asJava) + authorizer.createAcls(null, List(new AclBinding(new ResourcePattern(GROUP, "*", LITERAL), ace)).asJava) + + authorizer.createAcls(null, List(clusterAcl(ALLOW, CREATE), + clusterAcl(ALLOW, DELETE), + clusterAcl(ALLOW, CLUSTER_ACTION), + clusterAcl(ALLOW, ALTER_CONFIGS), + clusterAcl(ALLOW, ALTER)) + .map(ace => new AclBinding(clusterResourcePattern, ace)).asJava) + } finally { + authorizer.close() + } + } + + def addClusterAcl(permissionType: AclPermissionType, operation: AclOperation): Unit = { + val ace = clusterAcl(permissionType, operation) + val aclBinding = new AclBinding(clusterResourcePattern, ace) + val authorizer = servers.head.dataPlaneRequestProcessor.authorizer.get + val prevAcls = authorizer.acls(new AclBindingFilter(clusterResourcePattern.toFilter, AccessControlEntryFilter.ANY)) + .asScala.map(_.entry).toSet + authorizer.createAcls(null, Collections.singletonList(aclBinding)) + TestUtils.waitAndVerifyAcls(prevAcls ++ Set(ace), authorizer, clusterResourcePattern) + } + + def removeClusterAcl(permissionType: AclPermissionType, operation: AclOperation): Unit = { + val ace = clusterAcl(permissionType, operation) + val authorizer = servers.head.dataPlaneRequestProcessor.authorizer.get + val clusterFilter = new AclBindingFilter(clusterResourcePattern.toFilter, 
AccessControlEntryFilter.ANY) + val prevAcls = authorizer.acls(clusterFilter).asScala.map(_.entry).toSet + val deleteFilter = new AclBindingFilter(clusterResourcePattern.toFilter, ace.toFilter) + assertFalse(authorizer.deleteAcls(null, Collections.singletonList(deleteFilter)) + .get(0).toCompletableFuture.get.aclBindingDeleteResults().asScala.head.exception.isPresent) + TestUtils.waitAndVerifyAcls(prevAcls -- Set(ace), authorizer, clusterResourcePattern) + } + + private def clusterAcl(permissionType: AclPermissionType, operation: AclOperation): AccessControlEntry = { + new AccessControlEntry(new KafkaPrincipal(KafkaPrincipal.USER_TYPE, "*").toString, + WildcardHost, operation, permissionType) + } + } +} diff --git a/core/src/test/scala/integration/kafka/api/SaslSslConsumerTest.scala b/core/src/test/scala/integration/kafka/api/SaslSslConsumerTest.scala new file mode 100644 index 0000000..563481d --- /dev/null +++ b/core/src/test/scala/integration/kafka/api/SaslSslConsumerTest.scala @@ -0,0 +1,39 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more contributor license agreements. See the NOTICE + * file distributed with this work for additional information regarding copyright ownership. The ASF licenses this file + * to You under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the + * License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ +package kafka.api + +import java.io.File + +import kafka.server.KafkaConfig +import kafka.utils.JaasTestUtils +import org.apache.kafka.common.security.auth.SecurityProtocol +import org.junit.jupiter.api.{AfterEach, BeforeEach, TestInfo} + +class SaslSslConsumerTest extends BaseConsumerTest with SaslSetup { + this.serverConfig.setProperty(KafkaConfig.ZkEnableSecureAclsProp, "true") + override protected def securityProtocol = SecurityProtocol.SASL_SSL + override protected lazy val trustStoreFile = Some(File.createTempFile("truststore", ".jks")) + + @BeforeEach + override def setUp(testInfo: TestInfo): Unit = { + startSasl(jaasSections(Seq("GSSAPI"), Some("GSSAPI"), Both, JaasTestUtils.KafkaServerContextName)) + super.setUp(testInfo) + } + + @AfterEach + override def tearDown(): Unit = { + super.tearDown() + closeSasl() + } + +} diff --git a/core/src/test/scala/integration/kafka/api/SslAdminIntegrationTest.scala b/core/src/test/scala/integration/kafka/api/SslAdminIntegrationTest.scala new file mode 100644 index 0000000..b918081 --- /dev/null +++ b/core/src/test/scala/integration/kafka/api/SslAdminIntegrationTest.scala @@ -0,0 +1,262 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more contributor license agreements. See the NOTICE + * file distributed with this work for additional information regarding copyright ownership. The ASF licenses this file + * to You under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the + * License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ +package kafka.api + +import java.io.File +import java.util +import java.util.concurrent._ +import com.yammer.metrics.core.Gauge +import kafka.metrics.KafkaYammerMetrics +import kafka.security.authorizer.AclAuthorizer +import kafka.server.KafkaConfig +import kafka.utils.TestUtils +import org.apache.kafka.clients.admin.{Admin, AdminClientConfig, CreateAclsResult} +import org.apache.kafka.common.acl._ +import org.apache.kafka.common.protocol.ApiKeys +import org.apache.kafka.common.resource.{PatternType, ResourcePattern, ResourceType} +import org.apache.kafka.common.security.auth.{KafkaPrincipal, SecurityProtocol} +import org.apache.kafka.server.authorizer._ +import org.junit.jupiter.api.Assertions.{assertEquals, assertFalse, assertNotNull, assertTrue} +import org.junit.jupiter.api.{AfterEach, Test} + +import scala.jdk.CollectionConverters._ +import scala.collection.mutable + +object SslAdminIntegrationTest { + @volatile var semaphore: Option[Semaphore] = None + @volatile var executor: Option[ExecutorService] = None + @volatile var lastUpdateRequestContext: Option[AuthorizableRequestContext] = None + class TestableAclAuthorizer extends AclAuthorizer { + override def createAcls(requestContext: AuthorizableRequestContext, + aclBindings: util.List[AclBinding]): util.List[_ <: CompletionStage[AclCreateResult]] = { + lastUpdateRequestContext = Some(requestContext) + execute[AclCreateResult](aclBindings.size, () => super.createAcls(requestContext, aclBindings)) + } + + override def deleteAcls(requestContext: AuthorizableRequestContext, + aclBindingFilters: util.List[AclBindingFilter]): util.List[_ <: CompletionStage[AclDeleteResult]] = { + lastUpdateRequestContext = Some(requestContext) + execute[AclDeleteResult](aclBindingFilters.size, () => super.deleteAcls(requestContext, aclBindingFilters)) + } + + private def execute[T](batchSize: Int, action: () => util.List[_ <: CompletionStage[T]]): util.List[CompletableFuture[T]] = { + val futures = (0 until batchSize).map(_ => new CompletableFuture[T]).toList + val runnable = new Runnable { + override def run(): Unit = { + semaphore.foreach(_.acquire()) + try { + action.apply().asScala.zip(futures).foreach { case (baseFuture, resultFuture) => + baseFuture.whenComplete { (result, exception) => + if (exception != null) + resultFuture.completeExceptionally(exception) + else + resultFuture.complete(result) + } + } + } finally { + semaphore.foreach(_.release()) + } + } + } + executor match { + case Some(executorService) => executorService.submit(runnable) + case None => runnable.run() + } + futures.asJava + } + } +} + +class SslAdminIntegrationTest extends SaslSslAdminIntegrationTest { + override val authorizationAdmin = new AclAuthorizationAdmin(classOf[SslAdminIntegrationTest.TestableAclAuthorizer], classOf[AclAuthorizer]) + + this.serverConfig.setProperty(KafkaConfig.ZkEnableSecureAclsProp, "true") + + override protected def securityProtocol = SecurityProtocol.SSL + override protected lazy val trustStoreFile = Some(File.createTempFile("truststore", ".jks")) + private val adminClients = mutable.Buffer.empty[Admin] + + override def setUpSasl(): Unit 
= { + SslAdminIntegrationTest.semaphore = None + SslAdminIntegrationTest.executor = None + SslAdminIntegrationTest.lastUpdateRequestContext = None + + startSasl(jaasSections(List.empty, None, ZkSasl)) + } + + @AfterEach + override def tearDown(): Unit = { + // Ensure semaphore doesn't block shutdown even if test has failed + val semaphore = SslAdminIntegrationTest.semaphore + SslAdminIntegrationTest.semaphore = None + semaphore.foreach(s => s.release(s.getQueueLength)) + + adminClients.foreach(_.close()) + super.tearDown() + } + + @Test + def testAclUpdatesUsingSynchronousAuthorizer(): Unit = { + verifyAclUpdates() + } + + @Test + def testAclUpdatesUsingAsynchronousAuthorizer(): Unit = { + SslAdminIntegrationTest.executor = Some(Executors.newSingleThreadExecutor) + verifyAclUpdates() + } + + /** + * Verify that ACL updates using synchronous authorizer are performed synchronously + * on request threads without any performance overhead introduced by a purgatory. + */ + @Test + def testSynchronousAuthorizerAclUpdatesBlockRequestThreads(): Unit = { + val testSemaphore = new Semaphore(0) + SslAdminIntegrationTest.semaphore = Some(testSemaphore) + waitForNoBlockedRequestThreads() + + // Queue requests until all threads are blocked. ACL create requests are sent to least loaded + // node, so we may need more than `numRequestThreads` requests to block all threads. + val aclFutures = mutable.Buffer[CreateAclsResult]() + while (blockedRequestThreads.size < numRequestThreads) { + aclFutures += createAdminClient.createAcls(List(acl2).asJava) + assertTrue(aclFutures.size < numRequestThreads * 10, + s"Request threads not blocked numRequestThreads=$numRequestThreads blocked=$blockedRequestThreads") + } + assertEquals(0, purgatoryMetric("NumDelayedOperations")) + assertEquals(0, purgatoryMetric("PurgatorySize")) + + // Verify that operations on other clients are blocked + val describeFuture = createAdminClient.describeCluster().clusterId() + assertFalse(describeFuture.isDone) + + // Release the semaphore and verify that all requests complete + testSemaphore.release(aclFutures.size) + waitForNoBlockedRequestThreads() + assertNotNull(describeFuture.get(10, TimeUnit.SECONDS)) + // If any of the requests time out since we were blocking the threads earlier, retry the request. + val numTimedOut = aclFutures.count { future => + try { + future.all().get() + false + } catch { + case e: ExecutionException => + if (e.getCause.isInstanceOf[org.apache.kafka.common.errors.TimeoutException]) + true + else + throw e.getCause + } + } + (0 until numTimedOut) + .map(_ => createAdminClient.createAcls(List(acl2).asJava)) + .foreach(_.all().get(30, TimeUnit.SECONDS)) + } + + /** + * Verify that ACL updates using an asynchronous authorizer are completed asynchronously + * using a purgatory, enabling other requests to be processed even when ACL updates are blocked. 
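+ * The asynchronous path is driven by the TestableAclAuthorizer above, which hands createAcls/deleteAcls off to a single-thread executor and blocks them on a semaphore until the test releases it.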
+ */ + @Test + def testAsynchronousAuthorizerAclUpdatesDontBlockRequestThreads(): Unit = { + SslAdminIntegrationTest.executor = Some(Executors.newSingleThreadExecutor) + val testSemaphore = new Semaphore(0) + SslAdminIntegrationTest.semaphore = Some(testSemaphore) + + waitForNoBlockedRequestThreads() + + val aclFutures = (0 until numRequestThreads).map(_ => createAdminClient.createAcls(List(acl2).asJava)) + waitForNoBlockedRequestThreads() + assertTrue(aclFutures.forall(future => !future.all.isDone)) + // Other requests should succeed even though ACL updates are blocked + assertNotNull(createAdminClient.describeCluster().clusterId().get(10, TimeUnit.SECONDS)) + TestUtils.waitUntilTrue(() => purgatoryMetric("PurgatorySize") > 0, "PurgatorySize metrics not updated") + TestUtils.waitUntilTrue(() => purgatoryMetric("NumDelayedOperations") > 0, "NumDelayedOperations metrics not updated") + + // Release the semaphore and verify that ACL update requests complete + testSemaphore.release(aclFutures.size) + aclFutures.foreach(_.all.get()) + assertEquals(0, purgatoryMetric("NumDelayedOperations")) + } + + private def verifyAclUpdates(): Unit = { + val acl = new AclBinding(new ResourcePattern(ResourceType.TOPIC, "mytopic3", PatternType.LITERAL), + new AccessControlEntry("User:ANONYMOUS", "*", AclOperation.DESCRIBE, AclPermissionType.ALLOW)) + + def validateRequestContext(context: AuthorizableRequestContext, apiKey: ApiKeys): Unit = { + assertEquals(SecurityProtocol.SSL, context.securityProtocol) + assertEquals("SSL", context.listenerName) + assertEquals(KafkaPrincipal.ANONYMOUS, context.principal) + assertEquals(apiKey.id.toInt, context.requestType) + assertEquals(apiKey.latestVersion.toInt, context.requestVersion) + assertTrue(context.correlationId > 0, s"Invalid correlation id: ${context.correlationId}") + assertTrue(context.clientId.startsWith("adminclient"), s"Invalid client id: ${context.clientId}") + assertTrue(context.clientAddress.isLoopbackAddress, s"Invalid host address: ${context.clientAddress}") + } + + val testSemaphore = new Semaphore(0) + SslAdminIntegrationTest.semaphore = Some(testSemaphore) + + client = Admin.create(createConfig) + val results = client.createAcls(List(acl2, acl3).asJava).values + assertEquals(Set(acl2, acl3), results.keySet().asScala) + assertFalse(results.values.asScala.exists(_.isDone)) + TestUtils.waitUntilTrue(() => testSemaphore.hasQueuedThreads, "Authorizer not blocked in createAcls") + testSemaphore.release() + results.values.forEach(_.get) + validateRequestContext(SslAdminIntegrationTest.lastUpdateRequestContext.get, ApiKeys.CREATE_ACLS) + + testSemaphore.acquire() + val results2 = client.deleteAcls(List(acl.toFilter, acl2.toFilter, acl3.toFilter).asJava).values + assertEquals(Set(acl.toFilter, acl2.toFilter, acl3.toFilter), results2.keySet.asScala) + assertFalse(results2.values.asScala.exists(_.isDone)) + TestUtils.waitUntilTrue(() => testSemaphore.hasQueuedThreads, "Authorizer not blocked in deleteAcls") + testSemaphore.release() + results.values.forEach(_.get) + assertEquals(0, results2.get(acl.toFilter).get.values.size()) + assertEquals(Set(acl2), results2.get(acl2.toFilter).get.values.asScala.map(_.binding).toSet) + assertEquals(Set(acl3), results2.get(acl3.toFilter).get.values.asScala.map(_.binding).toSet) + validateRequestContext(SslAdminIntegrationTest.lastUpdateRequestContext.get, ApiKeys.DELETE_ACLS) + } + + private def createAdminClient: Admin = { + val config = createConfig + config.put(AdminClientConfig.DEFAULT_API_TIMEOUT_MS_CONFIG, "40000") + 
val client = Admin.create(config) + adminClients += client + client + } + + private def blockedRequestThreads: List[Thread] = { + val requestThreads = Thread.getAllStackTraces.keySet.asScala + .filter(_.getName.contains("data-plane-kafka-request-handler")) + assertEquals(numRequestThreads, requestThreads.size) + requestThreads.filter(_.getState == Thread.State.WAITING).toList + } + + private def numRequestThreads = servers.head.config.numIoThreads * servers.size + + private def waitForNoBlockedRequestThreads(): Unit = { + val (blockedThreads, _) = TestUtils.computeUntilTrue(blockedRequestThreads)(_.isEmpty) + assertEquals(List.empty, blockedThreads) + } + + private def purgatoryMetric(name: String): Int = { + val allMetrics = KafkaYammerMetrics.defaultRegistry.allMetrics.asScala + val metrics = allMetrics.filter { case (metricName, _) => + metricName.getMBeanName.contains("delayedOperation=AlterAcls") && metricName.getMBeanName.contains(s"name=$name") + }.values.toList + assertTrue(metrics.nonEmpty, s"Unable to find metric $name: allMetrics: ${allMetrics.keySet.map(_.getMBeanName)}") + metrics.map(_.asInstanceOf[Gauge[Int]].value).sum + } +} diff --git a/core/src/test/scala/integration/kafka/api/SslConsumerTest.scala b/core/src/test/scala/integration/kafka/api/SslConsumerTest.scala new file mode 100644 index 0000000..a09fcdc --- /dev/null +++ b/core/src/test/scala/integration/kafka/api/SslConsumerTest.scala @@ -0,0 +1,22 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more contributor license agreements. See the NOTICE + * file distributed with this work for additional information regarding copyright ownership. The ASF licenses this file + * to You under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the + * License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ +package kafka.api + +import java.io.File + +import org.apache.kafka.common.security.auth.SecurityProtocol + +class SslConsumerTest extends BaseConsumerTest { + override protected def securityProtocol = SecurityProtocol.SSL + override protected lazy val trustStoreFile = Some(File.createTempFile("truststore", ".jks")) +} diff --git a/core/src/test/scala/integration/kafka/api/SslEndToEndAuthorizationTest.scala b/core/src/test/scala/integration/kafka/api/SslEndToEndAuthorizationTest.scala new file mode 100644 index 0000000..850d99f --- /dev/null +++ b/core/src/test/scala/integration/kafka/api/SslEndToEndAuthorizationTest.scala @@ -0,0 +1,84 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.api + +import java.util.Properties + +import kafka.utils.TestUtils +import org.apache.kafka.common.config.SslConfigs +import org.apache.kafka.common.config.internals.BrokerSecurityConfigs +import org.apache.kafka.common.network.Mode +import org.apache.kafka.common.security.auth._ +import org.apache.kafka.common.security.authenticator.DefaultKafkaPrincipalBuilder +import org.apache.kafka.common.utils.Java +import org.junit.jupiter.api.{BeforeEach, TestInfo} + +object SslEndToEndAuthorizationTest { + class TestPrincipalBuilder extends DefaultKafkaPrincipalBuilder(null, null) { + private val Pattern = "O=A (.*?),CN=(.*?)".r + + // Use full DN as client principal to test special characters in principal + // Use field from DN as server principal to test custom PrincipalBuilder + override def build(context: AuthenticationContext): KafkaPrincipal = { + val peerPrincipal = context.asInstanceOf[SslAuthenticationContext].session.getPeerPrincipal.getName + peerPrincipal match { + case Pattern(name, _) => + val principal = if (name == "server") name else peerPrincipal + new KafkaPrincipal(KafkaPrincipal.USER_TYPE, principal) + case _ => + KafkaPrincipal.ANONYMOUS + } + } + } +} + +class SslEndToEndAuthorizationTest extends EndToEndAuthorizationTest { + + import kafka.api.SslEndToEndAuthorizationTest.TestPrincipalBuilder + + override protected def securityProtocol = SecurityProtocol.SSL + // Since there are other E2E tests that enable SSL, running this test with TLSv1.3 if supported + private val tlsProtocol = if (Java.IS_JAVA11_COMPATIBLE) "TLSv1.3" else "TLSv1.2" + + this.serverConfig.setProperty(BrokerSecurityConfigs.SSL_CLIENT_AUTH_CONFIG, "required") + this.serverConfig.setProperty(BrokerSecurityConfigs.PRINCIPAL_BUILDER_CLASS_CONFIG, classOf[TestPrincipalBuilder].getName) + this.serverConfig.setProperty(SslConfigs.SSL_PROTOCOL_CONFIG, tlsProtocol) + this.serverConfig.setProperty(SslConfigs.SSL_ENABLED_PROTOCOLS_CONFIG, tlsProtocol) + // Escaped characters in DN attribute values: from http://www.ietf.org/rfc/rfc2253.txt + // - a space or "#" character occurring at the beginning of the string + // - a space character occurring at the end of the string + // - one of the characters ",", "+", """, "\", "<", ">" or ";" + // + // Leading and trailing spaces in Kafka principal dont work with ACLs, but we can workaround by using + // a PrincipalBuilder that removes/replaces them. 
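+ // The CN below exercises every escaped character listed above; TestPrincipalBuilder keeps the full DN as the client principal and reduces the broker certificate's DN to just "server".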
+ private val clientCn = """\#A client with special chars in CN : (\, \+ \" \\ \< \> \; ')""" + override val clientPrincipal = new KafkaPrincipal(KafkaPrincipal.USER_TYPE, s"O=A client,CN=$clientCn") + override val kafkaPrincipal = new KafkaPrincipal(KafkaPrincipal.USER_TYPE, "server") + @BeforeEach + override def setUp(testInfo: TestInfo): Unit = { + startSasl(jaasSections(List.empty, None, ZkSasl)) + super.setUp(testInfo) + } + + override def clientSecurityProps(certAlias: String): Properties = { + val props = TestUtils.securityConfigs(Mode.CLIENT, securityProtocol, trustStoreFile, + certAlias, clientCn, clientSaslProperties, tlsProtocol) + props.remove(SslConfigs.SSL_ENDPOINT_IDENTIFICATION_ALGORITHM_CONFIG) + props + } +} diff --git a/core/src/test/scala/integration/kafka/api/SslProducerSendTest.scala b/core/src/test/scala/integration/kafka/api/SslProducerSendTest.scala new file mode 100644 index 0000000..cadc02f --- /dev/null +++ b/core/src/test/scala/integration/kafka/api/SslProducerSendTest.scala @@ -0,0 +1,27 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.api + +import java.io.File + +import org.apache.kafka.common.security.auth.SecurityProtocol + +class SslProducerSendTest extends BaseProducerSendTest { + override protected def securityProtocol = SecurityProtocol.SSL + override protected lazy val trustStoreFile = Some(File.createTempFile("truststore", ".jks")) +} diff --git a/core/src/test/scala/integration/kafka/api/TransactionsBounceTest.scala b/core/src/test/scala/integration/kafka/api/TransactionsBounceTest.scala new file mode 100644 index 0000000..204ab38 --- /dev/null +++ b/core/src/test/scala/integration/kafka/api/TransactionsBounceTest.scala @@ -0,0 +1,210 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package kafka.api + +import java.util.Properties + +import kafka.server.KafkaConfig +import kafka.utils.{ShutdownableThread, TestUtils} +import org.apache.kafka.clients.consumer.{ConsumerConfig, KafkaConsumer} +import org.apache.kafka.clients.producer.{KafkaProducer, ProducerConfig} +import org.apache.kafka.clients.producer.internals.ErrorLoggingCallback +import org.apache.kafka.common.TopicPartition +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.Test + +import scala.annotation.nowarn +import scala.jdk.CollectionConverters._ +import scala.collection.mutable + +class TransactionsBounceTest extends IntegrationTestHarness { + private val consumeRecordTimeout = 30000 + private val producerBufferSize = 65536 + private val serverMessageMaxBytes = producerBufferSize/2 + private val numPartitions = 3 + private val outputTopic = "output-topic" + private val inputTopic = "input-topic" + + val overridingProps = new Properties() + overridingProps.put(KafkaConfig.AutoCreateTopicsEnableProp, false.toString) + overridingProps.put(KafkaConfig.MessageMaxBytesProp, serverMessageMaxBytes.toString) + // Set a smaller value for the number of partitions for the offset commit topic (__consumer_offset topic) + // so that the creation of that topic/partition(s) and subsequent leader assignment doesn't take relatively long + overridingProps.put(KafkaConfig.ControlledShutdownEnableProp, true.toString) + overridingProps.put(KafkaConfig.UncleanLeaderElectionEnableProp, false.toString) + overridingProps.put(KafkaConfig.AutoLeaderRebalanceEnableProp, false.toString) + overridingProps.put(KafkaConfig.OffsetsTopicPartitionsProp, 1.toString) + overridingProps.put(KafkaConfig.OffsetsTopicReplicationFactorProp, 3.toString) + overridingProps.put(KafkaConfig.MinInSyncReplicasProp, 2.toString) + overridingProps.put(KafkaConfig.TransactionsTopicPartitionsProp, 1.toString) + overridingProps.put(KafkaConfig.TransactionsTopicReplicationFactorProp, 3.toString) + overridingProps.put(KafkaConfig.GroupMinSessionTimeoutMsProp, "10") // set small enough session timeout + overridingProps.put(KafkaConfig.GroupInitialRebalanceDelayMsProp, "0") + + // This is the one of the few tests we currently allow to preallocate ports, despite the fact that this can result in transient + // failures due to ports getting reused. We can't use random ports because of bad behavior that can result from bouncing + // brokers too quickly when they get new, random ports. If we're not careful, the client can end up in a situation + // where metadata is not refreshed quickly enough, and by the time it's actually trying to, all the servers have + // been bounced and have new addresses. None of the bootstrap nodes or current metadata can get them connected to a + // running server. + // + // Since such quick rotation of servers is incredibly unrealistic, we allow this one test to preallocate ports, leaving + // a small risk of hitting errors due to port conflicts. Hopefully this is infrequent enough to not cause problems. 
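+ // FixedPortTestUtils pins each broker to the same port across restarts, so a bounced broker comes back at an address the client already has in its metadata.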
+ override def generateConfigs = { + FixedPortTestUtils.createBrokerConfigs(brokerCount, zkConnect, enableControlledShutdown = true) + .map(KafkaConfig.fromProps(_, overridingProps)) + } + + override protected def brokerCount: Int = 4 + + @nowarn("cat=deprecation") + @Test + def testWithGroupId(): Unit = { + testBrokerFailure((producer, groupId, consumer) => + producer.sendOffsetsToTransaction(TestUtils.consumerPositions(consumer).asJava, groupId)) + } + + @Test + def testWithGroupMetadata(): Unit = { + testBrokerFailure((producer, _, consumer) => + producer.sendOffsetsToTransaction(TestUtils.consumerPositions(consumer).asJava, consumer.groupMetadata())) + } + + private def testBrokerFailure(commit: (KafkaProducer[Array[Byte], Array[Byte]], + String, KafkaConsumer[Array[Byte], Array[Byte]]) => Unit): Unit = { + // basic idea is to seed a topic with 10000 records, and copy it transactionally while bouncing brokers + // constantly through the period. + val consumerGroup = "myGroup" + val numInputRecords = 10000 + createTopics() + + TestUtils.seedTopicWithNumberedRecords(inputTopic, numInputRecords, servers) + val consumer = createConsumerAndSubscribe(consumerGroup, List(inputTopic)) + val producer = createTransactionalProducer("test-txn") + + producer.initTransactions() + + val scheduler = new BounceScheduler + scheduler.start() + + try { + var numMessagesProcessed = 0 + var iteration = 0 + + while (numMessagesProcessed < numInputRecords) { + val toRead = Math.min(200, numInputRecords - numMessagesProcessed) + trace(s"$iteration: About to read $toRead messages, processed $numMessagesProcessed so far..") + val records = TestUtils.pollUntilAtLeastNumRecords(consumer, toRead, waitTimeMs = consumeRecordTimeout) + trace(s"Received ${records.size} messages, sending them transactionally to $outputTopic") + + producer.beginTransaction() + val shouldAbort = iteration % 3 == 0 + records.foreach { record => + producer.send(TestUtils.producerRecordWithExpectedTransactionStatus(outputTopic, null, record.key, record.value, !shouldAbort), new ErrorLoggingCallback(outputTopic, record.key, record.value, true)) + } + trace(s"Sent ${records.size} messages. Committing offsets.") + commit(producer, consumerGroup, consumer) + + if (shouldAbort) { + trace(s"Committed offsets. Aborting transaction of ${records.size} messages.") + producer.abortTransaction() + TestUtils.resetToCommittedPositions(consumer) + } else { + trace(s"Committed offsets. 
committing transaction of ${records.size} messages.") + producer.commitTransaction() + numMessagesProcessed += records.size + } + iteration += 1 + } + } finally { + scheduler.shutdown() + } + + val verifyingConsumer = createConsumerAndSubscribe("randomGroup", List(outputTopic), readCommitted = true) + val recordsByPartition = new mutable.HashMap[TopicPartition, mutable.ListBuffer[Int]]() + TestUtils.pollUntilAtLeastNumRecords(verifyingConsumer, numInputRecords, waitTimeMs = consumeRecordTimeout).foreach { record => + val value = TestUtils.assertCommittedAndGetValue(record).toInt + val topicPartition = new TopicPartition(record.topic(), record.partition()) + recordsByPartition.getOrElseUpdate(topicPartition, new mutable.ListBuffer[Int]) + .append(value) + } + + val outputRecords = new mutable.ListBuffer[Int]() + recordsByPartition.values.foreach { partitionValues => + assertEquals(partitionValues, partitionValues.sorted, "Out of order messages detected") + outputRecords.appendAll(partitionValues) + } + + val recordSet = outputRecords.toSet + assertEquals(numInputRecords, recordSet.size) + + val expectedValues = (0 until numInputRecords).toSet + assertEquals(expectedValues, recordSet, s"Missing messages: ${expectedValues -- recordSet}") + } + + private def createTransactionalProducer(transactionalId: String) = { + val props = new Properties() + props.put(ProducerConfig.ACKS_CONFIG, "all") + props.put(ProducerConfig.BATCH_SIZE_CONFIG, "512") + props.put(ProducerConfig.TRANSACTIONAL_ID_CONFIG, transactionalId) + props.put(ProducerConfig.ENABLE_IDEMPOTENCE_CONFIG, "true") + createProducer(configOverrides = props) + } + + private def createConsumerAndSubscribe(groupId: String, + topics: List[String], + readCommitted: Boolean = false) = { + val consumerProps = new Properties + consumerProps.put(ConsumerConfig.GROUP_ID_CONFIG, groupId) + consumerProps.put(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG, "false") + consumerProps.put(ConsumerConfig.ISOLATION_LEVEL_CONFIG, + if (readCommitted) "read_committed" else "read_uncommitted") + val consumer = createConsumer(configOverrides = consumerProps) + consumer.subscribe(topics.asJava) + consumer + } + + private def createTopics() = { + val topicConfig = new Properties() + topicConfig.put(KafkaConfig.MinInSyncReplicasProp, 2.toString) + createTopic(inputTopic, numPartitions, 3, topicConfig) + createTopic(outputTopic, numPartitions, 3, topicConfig) + } + + private class BounceScheduler extends ShutdownableThread("daemon-broker-bouncer", false) { + override def doWork(): Unit = { + for (server <- servers) { + trace("Shutting down server : %s".format(server.config.brokerId)) + server.shutdown() + server.awaitShutdown() + Thread.sleep(500) + trace("Server %s shut down. Starting it up again.".format(server.config.brokerId)) + server.startup() + trace("Restarted server: %s".format(server.config.brokerId)) + Thread.sleep(500) + } + + (0 until numPartitions).foreach(partition => TestUtils.waitUntilLeaderIsElectedOrChanged(zkClient, outputTopic, partition)) + } + + override def shutdown(): Unit = { + super.shutdown() + } + } + +} diff --git a/core/src/test/scala/integration/kafka/api/TransactionsExpirationTest.scala b/core/src/test/scala/integration/kafka/api/TransactionsExpirationTest.scala new file mode 100644 index 0000000..1a656ef --- /dev/null +++ b/core/src/test/scala/integration/kafka/api/TransactionsExpirationTest.scala @@ -0,0 +1,122 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.api + +import java.util.Properties + +import kafka.integration.KafkaServerTestHarness +import kafka.server.KafkaConfig +import kafka.utils.TestUtils +import kafka.utils.TestUtils.consumeRecords +import org.apache.kafka.clients.consumer.KafkaConsumer +import org.apache.kafka.clients.producer.KafkaProducer +import org.apache.kafka.common.errors.InvalidPidMappingException +import org.junit.jupiter.api.{AfterEach, BeforeEach, Test, TestInfo} + +import scala.jdk.CollectionConverters._ +import scala.collection.Seq + +// Test class that uses a very small transaction timeout to trigger InvalidPidMapping errors +class TransactionsExpirationTest extends KafkaServerTestHarness { + val topic1 = "topic1" + val topic2 = "topic2" + val numPartitions = 4 + val replicationFactor = 3 + + var producer: KafkaProducer[Array[Byte], Array[Byte]] = _ + var consumer: KafkaConsumer[Array[Byte], Array[Byte]] = _ + + override def generateConfigs: Seq[KafkaConfig] = { + TestUtils.createBrokerConfigs(3, zkConnect).map(KafkaConfig.fromProps(_, serverProps())) + } + + @BeforeEach + override def setUp(testInfo: TestInfo): Unit = { + super.setUp(testInfo) + + producer = TestUtils.createTransactionalProducer("transactionalProducer", servers) + consumer = TestUtils.createConsumer(TestUtils.getBrokerListStrFromServers(servers), + enableAutoCommit = false, + readCommitted = true) + + TestUtils.createTopic(zkClient, topic1, numPartitions, 3, servers, new Properties()) + TestUtils.createTopic(zkClient, topic2, numPartitions, 3, servers, new Properties()) + } + + @AfterEach + override def tearDown(): Unit = { + producer.close() + consumer.close() + + super.tearDown() + } + + @Test + def testBumpTransactionalEpochAfterInvalidProducerIdMapping(): Unit = { + producer.initTransactions() + + // Start and then abort a transaction to allow the transactional ID to expire + producer.beginTransaction() + producer.send(TestUtils.producerRecordWithExpectedTransactionStatus(topic1, 0, "2", "2", willBeCommitted = false)) + producer.send(TestUtils.producerRecordWithExpectedTransactionStatus(topic2, 0, "4", "4", willBeCommitted = false)) + producer.abortTransaction() + + // Wait for the transactional ID to expire + Thread.sleep(3000) + + // Start a new transaction and attempt to send, which will trigger an AddPartitionsToTxnRequest, which will fail due to the expired producer ID + producer.beginTransaction() + val failedFuture = producer.send(TestUtils.producerRecordWithExpectedTransactionStatus(topic1, 3, "1", "1", willBeCommitted = false)) + Thread.sleep(500) + + org.apache.kafka.test.TestUtils.assertFutureThrows(failedFuture, classOf[InvalidPidMappingException]) + producer.abortTransaction() + + producer.beginTransaction() + producer.send(TestUtils.producerRecordWithExpectedTransactionStatus(topic2, null, "2", "2", willBeCommitted = 
true)) + producer.send(TestUtils.producerRecordWithExpectedTransactionStatus(topic1, 2, "4", "4", willBeCommitted = true)) + producer.send(TestUtils.producerRecordWithExpectedTransactionStatus(topic2, null, "1", "1", willBeCommitted = true)) + producer.send(TestUtils.producerRecordWithExpectedTransactionStatus(topic1, 3, "3", "3", willBeCommitted = true)) + producer.commitTransaction() + + consumer.subscribe(List(topic1, topic2).asJava) + + val records = consumeRecords(consumer, 4) + records.foreach { record => + TestUtils.assertCommittedAndGetValue(record) + } + } + private def serverProps() = { + val serverProps = new Properties() + serverProps.put(KafkaConfig.AutoCreateTopicsEnableProp, false.toString) + // Set a smaller value for the number of partitions for the __consumer_offsets topic + // so that the creation of that topic/partition(s) and subsequent leader assignment doesn't take relatively long + serverProps.put(KafkaConfig.OffsetsTopicPartitionsProp, 1.toString) + serverProps.put(KafkaConfig.TransactionsTopicPartitionsProp, 3.toString) + serverProps.put(KafkaConfig.TransactionsTopicReplicationFactorProp, 2.toString) + serverProps.put(KafkaConfig.TransactionsTopicMinISRProp, 2.toString) + serverProps.put(KafkaConfig.ControlledShutdownEnableProp, true.toString) + serverProps.put(KafkaConfig.UncleanLeaderElectionEnableProp, false.toString) + serverProps.put(KafkaConfig.AutoLeaderRebalanceEnableProp, false.toString) + serverProps.put(KafkaConfig.GroupInitialRebalanceDelayMsProp, "0") + serverProps.put(KafkaConfig.TransactionsAbortTimedOutTransactionCleanupIntervalMsProp, "200") + serverProps.put(KafkaConfig.TransactionalIdExpirationMsProp, "2000") + serverProps.put(KafkaConfig.TransactionsRemoveExpiredTransactionalIdCleanupIntervalMsProp, "500") + serverProps + } +} diff --git a/core/src/test/scala/integration/kafka/api/TransactionsTest.scala b/core/src/test/scala/integration/kafka/api/TransactionsTest.scala new file mode 100644 index 0000000..1fbba9e --- /dev/null +++ b/core/src/test/scala/integration/kafka/api/TransactionsTest.scala @@ -0,0 +1,785 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package kafka.api + +import java.lang.{Long => JLong} +import java.nio.charset.StandardCharsets +import java.time.Duration +import java.util.concurrent.TimeUnit +import java.util.{Optional, Properties} + +import kafka.integration.KafkaServerTestHarness +import kafka.server.KafkaConfig +import kafka.utils.TestUtils +import kafka.utils.TestUtils.consumeRecords +import org.apache.kafka.clients.consumer.{ConsumerConfig, ConsumerGroupMetadata, KafkaConsumer, OffsetAndMetadata} +import org.apache.kafka.clients.producer.{KafkaProducer, ProducerRecord} +import org.apache.kafka.common.errors.{InvalidProducerEpochException, ProducerFencedException, TimeoutException} +import org.apache.kafka.common.{KafkaException, TopicPartition} +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.{AfterEach, BeforeEach, Test, TestInfo} + +import scala.annotation.nowarn +import scala.jdk.CollectionConverters._ +import scala.collection.Seq +import scala.collection.mutable.Buffer +import scala.concurrent.ExecutionException + +class TransactionsTest extends KafkaServerTestHarness { + val numServers = 3 + val transactionalProducerCount = 2 + val transactionalConsumerCount = 1 + val nonTransactionalConsumerCount = 1 + + val topic1 = "topic1" + val topic2 = "topic2" + val numPartitions = 4 + + val transactionalProducers = Buffer[KafkaProducer[Array[Byte], Array[Byte]]]() + val transactionalConsumers = Buffer[KafkaConsumer[Array[Byte], Array[Byte]]]() + val nonTransactionalConsumers = Buffer[KafkaConsumer[Array[Byte], Array[Byte]]]() + + override def generateConfigs: Seq[KafkaConfig] = { + TestUtils.createBrokerConfigs(numServers, zkConnect).map(KafkaConfig.fromProps(_, serverProps())) + } + + @BeforeEach + override def setUp(testInfo: TestInfo): Unit = { + super.setUp(testInfo) + val topicConfig = new Properties() + topicConfig.put(KafkaConfig.MinInSyncReplicasProp, 2.toString) + createTopic(topic1, numPartitions, numServers, topicConfig) + createTopic(topic2, numPartitions, numServers, topicConfig) + + for (_ <- 0 until transactionalProducerCount) + createTransactionalProducer("transactional-producer") + for (_ <- 0 until transactionalConsumerCount) + createReadCommittedConsumer("transactional-group") + for (_ <- 0 until nonTransactionalConsumerCount) + createReadUncommittedConsumer("non-transactional-group") + } + + @AfterEach + override def tearDown(): Unit = { + transactionalProducers.foreach(_.close()) + transactionalConsumers.foreach(_.close()) + nonTransactionalConsumers.foreach(_.close()) + super.tearDown() + } + + @Test + def testBasicTransactions() = { + val producer = transactionalProducers.head + val consumer = transactionalConsumers.head + val unCommittedConsumer = nonTransactionalConsumers.head + + producer.initTransactions() + + producer.beginTransaction() + producer.send(TestUtils.producerRecordWithExpectedTransactionStatus(topic2, null, "2", "2", willBeCommitted = false)) + producer.send(TestUtils.producerRecordWithExpectedTransactionStatus(topic1, null, "4", "4", willBeCommitted = false)) + producer.flush() + producer.abortTransaction() + + producer.beginTransaction() + producer.send(TestUtils.producerRecordWithExpectedTransactionStatus(topic1, null, "1", "1", willBeCommitted = true)) + producer.send(TestUtils.producerRecordWithExpectedTransactionStatus(topic2, null, "3", "3", willBeCommitted = true)) + producer.commitTransaction() + + consumer.subscribe(List(topic1, topic2).asJava) + unCommittedConsumer.subscribe(List(topic1, topic2).asJava) + + val records = 
consumeRecords(consumer, 2) + records.foreach { record => + TestUtils.assertCommittedAndGetValue(record) + } + + val allRecords = consumeRecords(unCommittedConsumer, 4) + val expectedValues = List("1", "2", "3", "4").toSet + allRecords.foreach { record => + assertTrue(expectedValues.contains(TestUtils.recordValueAsString(record))) + } + } + + @Test + def testReadCommittedConsumerShouldNotSeeUndecidedData(): Unit = { + val producer1 = transactionalProducers.head + val producer2 = createTransactionalProducer("other") + val readCommittedConsumer = transactionalConsumers.head + val readUncommittedConsumer = nonTransactionalConsumers.head + + producer1.initTransactions() + producer2.initTransactions() + + producer1.beginTransaction() + producer2.beginTransaction() + + val latestVisibleTimestamp = System.currentTimeMillis() + producer2.send(new ProducerRecord(topic1, 0, latestVisibleTimestamp, "x".getBytes, "1".getBytes)) + producer2.send(new ProducerRecord(topic2, 0, latestVisibleTimestamp, "x".getBytes, "1".getBytes)) + producer2.flush() + + val latestWrittenTimestamp = latestVisibleTimestamp + 1 + producer1.send(new ProducerRecord(topic1, 0, latestWrittenTimestamp, "a".getBytes, "1".getBytes)) + producer1.send(new ProducerRecord(topic1, 0, latestWrittenTimestamp, "b".getBytes, "2".getBytes)) + producer1.send(new ProducerRecord(topic2, 0, latestWrittenTimestamp, "c".getBytes, "3".getBytes)) + producer1.send(new ProducerRecord(topic2, 0, latestWrittenTimestamp, "d".getBytes, "4".getBytes)) + producer1.flush() + + producer2.send(new ProducerRecord(topic1, 0, latestWrittenTimestamp, "x".getBytes, "2".getBytes)) + producer2.send(new ProducerRecord(topic2, 0, latestWrittenTimestamp, "x".getBytes, "2".getBytes)) + producer2.commitTransaction() + + // ensure the records are visible to the read uncommitted consumer + val tp1 = new TopicPartition(topic1, 0) + val tp2 = new TopicPartition(topic2, 0) + readUncommittedConsumer.assign(Set(tp1, tp2).asJava) + consumeRecords(readUncommittedConsumer, 8) + val readUncommittedOffsetsForTimes = readUncommittedConsumer.offsetsForTimes(Map( + tp1 -> (latestWrittenTimestamp: JLong), + tp2 -> (latestWrittenTimestamp: JLong) + ).asJava) + assertEquals(2, readUncommittedOffsetsForTimes.size) + assertEquals(latestWrittenTimestamp, readUncommittedOffsetsForTimes.get(tp1).timestamp) + assertEquals(latestWrittenTimestamp, readUncommittedOffsetsForTimes.get(tp2).timestamp) + readUncommittedConsumer.unsubscribe() + + // we should only see the first two records which come before the undecided second transaction + readCommittedConsumer.assign(Set(tp1, tp2).asJava) + val records = consumeRecords(readCommittedConsumer, 2) + records.foreach { record => + assertEquals("x", new String(record.key)) + assertEquals("1", new String(record.value)) + } + + // even if we seek to the end, we should not be able to see the undecided data + assertEquals(2, readCommittedConsumer.assignment.size) + readCommittedConsumer.seekToEnd(readCommittedConsumer.assignment) + readCommittedConsumer.assignment.forEach { tp => + assertEquals(1L, readCommittedConsumer.position(tp)) + } + + // undecided timestamps should not be searchable either + val readCommittedOffsetsForTimes = readCommittedConsumer.offsetsForTimes(Map( + tp1 -> (latestWrittenTimestamp: JLong), + tp2 -> (latestWrittenTimestamp: JLong) + ).asJava) + assertNull(readCommittedOffsetsForTimes.get(tp1)) + assertNull(readCommittedOffsetsForTimes.get(tp2)) + } + + @Test + def testDelayedFetchIncludesAbortedTransaction(): Unit = { + val producer1 
= transactionalProducers.head + val producer2 = createTransactionalProducer("other") + + producer1.initTransactions() + producer2.initTransactions() + + producer1.beginTransaction() + producer2.beginTransaction() + producer2.send(new ProducerRecord(topic1, 0, "x".getBytes, "1".getBytes)) + producer2.flush() + + producer1.send(new ProducerRecord(topic1, 0, "y".getBytes, "1".getBytes)) + producer1.send(new ProducerRecord(topic1, 0, "y".getBytes, "2".getBytes)) + producer1.flush() + + producer2.send(new ProducerRecord(topic1, 0, "x".getBytes, "2".getBytes)) + producer2.flush() + + producer1.abortTransaction() + producer2.commitTransaction() + + // ensure that the consumer's fetch will sit in purgatory + val consumerProps = new Properties() + consumerProps.put(ConsumerConfig.FETCH_MIN_BYTES_CONFIG, "100000") + consumerProps.put(ConsumerConfig.FETCH_MAX_WAIT_MS_CONFIG, "100") + val readCommittedConsumer = createReadCommittedConsumer(props = consumerProps) + + readCommittedConsumer.assign(Set(new TopicPartition(topic1, 0)).asJava) + val records = consumeRecords(readCommittedConsumer, numRecords = 2) + assertEquals(2, records.size) + + val first = records.head + assertEquals("x", new String(first.key)) + assertEquals("1", new String(first.value)) + assertEquals(0L, first.offset) + + val second = records.last + assertEquals("x", new String(second.key)) + assertEquals("2", new String(second.value)) + assertEquals(3L, second.offset) + } + + @nowarn("cat=deprecation") + @Test + def testSendOffsetsWithGroupId() = { + sendOffset((producer, groupId, consumer) => + producer.sendOffsetsToTransaction(TestUtils.consumerPositions(consumer).asJava, groupId)) + } + + @Test + def testSendOffsetsWithGroupMetadata() = { + sendOffset((producer, _, consumer) => + producer.sendOffsetsToTransaction(TestUtils.consumerPositions(consumer).asJava, consumer.groupMetadata())) + } + + private def sendOffset(commit: (KafkaProducer[Array[Byte], Array[Byte]], + String, KafkaConsumer[Array[Byte], Array[Byte]]) => Unit) = { + + // The basic plan for the test is as follows: + // 1. Seed topic1 with 1000 unique, numbered, messages. + // 2. Run a consume/process/produce loop to transactionally copy messages from topic1 to topic2 and commit + // offsets as part of the transaction. + // 3. Randomly abort transactions in step2. + // 4. Validate that we have 1000 unique committed messages in topic2. If the offsets were committed properly with the + // transactions, we should not have any duplicates or missing messages since we should process in the input + // messages exactly once. 
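+ // Aborted iterations are retried: after producer.abortTransaction() the consumer is rewound via TestUtils.resetToCommittedPositions, so every input record is eventually copied to topic2 in a committed transaction exactly once.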
+ + val consumerGroupId = "foobar-consumer-group" + val numSeedMessages = 500 + + TestUtils.seedTopicWithNumberedRecords(topic1, numSeedMessages, servers) + + val producer = transactionalProducers.head + + val consumer = createReadCommittedConsumer(consumerGroupId, maxPollRecords = numSeedMessages / 4) + consumer.subscribe(List(topic1).asJava) + producer.initTransactions() + + var shouldCommit = false + var recordsProcessed = 0 + try { + while (recordsProcessed < numSeedMessages) { + val records = TestUtils.pollUntilAtLeastNumRecords(consumer, Math.min(10, numSeedMessages - recordsProcessed)) + + producer.beginTransaction() + shouldCommit = !shouldCommit + + records.foreach { record => + val key = new String(record.key(), StandardCharsets.UTF_8) + val value = new String(record.value(), StandardCharsets.UTF_8) + producer.send(TestUtils.producerRecordWithExpectedTransactionStatus(topic2, null, key, value, willBeCommitted = shouldCommit)) + } + + commit(producer, consumerGroupId, consumer) + if (shouldCommit) { + producer.commitTransaction() + recordsProcessed += records.size + debug(s"committed transaction.. Last committed record: ${new String(records.last.value(), StandardCharsets.UTF_8)}. Num " + + s"records written to $topic2: $recordsProcessed") + } else { + producer.abortTransaction() + debug(s"aborted transaction Last committed record: ${new String(records.last.value(), StandardCharsets.UTF_8)}. Num " + + s"records written to $topic2: $recordsProcessed") + TestUtils.resetToCommittedPositions(consumer) + } + } + } finally { + consumer.close() + } + + // In spite of random aborts, we should still have exactly 1000 messages in topic2. I.e. we should not + // re-copy or miss any messages from topic1, since the consumed offsets were committed transactionally. + val verifyingConsumer = transactionalConsumers(0) + verifyingConsumer.subscribe(List(topic2).asJava) + val valueSeq = TestUtils.pollUntilAtLeastNumRecords(verifyingConsumer, numSeedMessages).map { record => + TestUtils.assertCommittedAndGetValue(record).toInt + } + val valueSet = valueSeq.toSet + assertEquals(numSeedMessages, valueSeq.size, s"Expected $numSeedMessages values in $topic2.") + assertEquals(valueSeq.size, valueSet.size, s"Expected ${valueSeq.size} unique messages in $topic2.") + } + + @Test + def testFencingOnCommit() = { + val producer1 = transactionalProducers(0) + val producer2 = transactionalProducers(1) + val consumer = transactionalConsumers(0) + + consumer.subscribe(List(topic1, topic2).asJava) + + producer1.initTransactions() + + producer1.beginTransaction() + producer1.send(TestUtils.producerRecordWithExpectedTransactionStatus(topic1, null, "1", "1", willBeCommitted = false)) + producer1.send(TestUtils.producerRecordWithExpectedTransactionStatus(topic2, null, "3", "3", willBeCommitted = false)) + + producer2.initTransactions() // ok, will abort the open transaction. 
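+ // producer2 uses the same transactional.id, so initTransactions() bumps the producer epoch on the coordinator and fences producer1, whose commitTransaction() below is expected to throw ProducerFencedException.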
+ producer2.beginTransaction() + producer2.send(TestUtils.producerRecordWithExpectedTransactionStatus(topic1, null, "2", "4", willBeCommitted = true)) + producer2.send(TestUtils.producerRecordWithExpectedTransactionStatus(topic2, null, "2", "4", willBeCommitted = true)) + + assertThrows(classOf[ProducerFencedException], () => producer1.commitTransaction()) + + producer2.commitTransaction() // ok + + val records = consumeRecords(consumer, 2) + records.foreach { record => + TestUtils.assertCommittedAndGetValue(record) + } + } + + @Test + def testFencingOnSendOffsets() = { + val producer1 = transactionalProducers(0) + val producer2 = transactionalProducers(1) + val consumer = transactionalConsumers(0) + + consumer.subscribe(List(topic1, topic2).asJava) + + producer1.initTransactions() + + producer1.beginTransaction() + producer1.send(TestUtils.producerRecordWithExpectedTransactionStatus(topic1, null, "1", "1", willBeCommitted = false)) + producer1.send(TestUtils.producerRecordWithExpectedTransactionStatus(topic2, null, "3", "3", willBeCommitted = false)) + producer1.flush() + + producer2.initTransactions() // ok, will abort the open transaction. + producer2.beginTransaction() + producer2.send(TestUtils.producerRecordWithExpectedTransactionStatus(topic1, null, "2", "4", willBeCommitted = true)) + producer2.send(TestUtils.producerRecordWithExpectedTransactionStatus(topic2, null, "2", "4", willBeCommitted = true)) + + assertThrows(classOf[ProducerFencedException], () => producer1.sendOffsetsToTransaction(Map(new TopicPartition("foobartopic", 0) + -> new OffsetAndMetadata(110L)).asJava, new ConsumerGroupMetadata("foobarGroup"))) + + producer2.commitTransaction() // ok + + val records = consumeRecords(consumer, 2) + records.foreach { record => + TestUtils.assertCommittedAndGetValue(record) + } + } + + @Test + def testOffsetMetadataInSendOffsetsToTransaction() = { + val tp = new TopicPartition(topic1, 0) + val groupId = "group" + + val producer = transactionalProducers.head + val consumer = createReadCommittedConsumer(groupId) + + consumer.subscribe(List(topic1).asJava) + + producer.initTransactions() + + producer.beginTransaction() + val offsetAndMetadata = new OffsetAndMetadata(110L, Optional.of(15), "some metadata") + producer.sendOffsetsToTransaction(Map(tp -> offsetAndMetadata).asJava, new ConsumerGroupMetadata(groupId)) + producer.commitTransaction() // ok + + // The call to commit the transaction may return before all markers are visible, so we initialize a second + // producer to ensure the transaction completes and the committed offsets are visible. 
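+ // producer2 shares the same transactional.id, so its initTransactions() should not return until the coordinator has completed the previous transaction, after which the committed offset is visible to consumer.committed().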
+ val producer2 = transactionalProducers(1) + producer2.initTransactions() + + TestUtils.waitUntilTrue(() => offsetAndMetadata.equals(consumer.committed(Set(tp).asJava).get(tp)), "cannot read committed offset") + } + + @Test + def testInitTransactionsTimeout(): Unit = { + testTimeout(false, producer => producer.initTransactions()) + } + + @Test + def testSendOffsetsToTransactionTimeout(): Unit = { + testTimeout(true, producer => producer.sendOffsetsToTransaction( + Map(new TopicPartition(topic1, 0) -> new OffsetAndMetadata(0)).asJava, new ConsumerGroupMetadata("test-group"))) + } + + @Test + def testCommitTransactionTimeout(): Unit = { + testTimeout(true, producer => producer.commitTransaction()) + } + + @Test + def testAbortTransactionTimeout(): Unit = { + testTimeout(true, producer => producer.abortTransaction()) + } + + private def testTimeout(needInitAndSendMsg: Boolean, + timeoutProcess: KafkaProducer[Array[Byte], Array[Byte]] => Unit): Unit = { + val producer = createTransactionalProducer("transactionProducer", maxBlockMs = 3000) + if (needInitAndSendMsg) { + producer.initTransactions() + producer.beginTransaction() + producer.send(new ProducerRecord[Array[Byte], Array[Byte]](topic1, "foo".getBytes, "bar".getBytes)) + } + + for (i <- servers.indices) killBroker(i) + + assertThrows(classOf[TimeoutException], () => timeoutProcess(producer)) + producer.close(Duration.ZERO) + } + + @Test + def testFencingOnSend(): Unit = { + val producer1 = transactionalProducers(0) + val producer2 = transactionalProducers(1) + val consumer = transactionalConsumers(0) + + consumer.subscribe(List(topic1, topic2).asJava) + + producer1.initTransactions() + + producer1.beginTransaction() + producer1.send(TestUtils.producerRecordWithExpectedTransactionStatus(topic1, null, "1", "1", willBeCommitted = false)) + producer1.send(TestUtils.producerRecordWithExpectedTransactionStatus(topic2, null, "3", "3", willBeCommitted = false)) + + producer2.initTransactions() // ok, will abort the open transaction. + producer2.beginTransaction() + producer2.send(TestUtils.producerRecordWithExpectedTransactionStatus(topic1, null, "2", "4", willBeCommitted = true)).get() + producer2.send(TestUtils.producerRecordWithExpectedTransactionStatus(topic2, null, "2", "4", willBeCommitted = true)).get() + + try { + val result = producer1.send(TestUtils.producerRecordWithExpectedTransactionStatus(topic1, null, "1", "5", willBeCommitted = false)) + val recordMetadata = result.get() + error(s"Missed a producer fenced exception when writing to ${recordMetadata.topic}-${recordMetadata.partition}. 
Grab the logs!!") + servers.foreach { server => + error(s"log dirs: ${server.logManager.liveLogDirs.map(_.getAbsolutePath).head}") + } + fail("Should not be able to send messages from a fenced producer.") + } catch { + case _: ProducerFencedException => + producer1.close() + case e: ExecutionException => + assertTrue(e.getCause.isInstanceOf[InvalidProducerEpochException]) + case e: Exception => + throw new AssertionError("Got an unexpected exception from a fenced producer.", e) + } + + producer2.commitTransaction() // ok + + val records = consumeRecords(consumer, 2) + records.foreach { record => + TestUtils.assertCommittedAndGetValue(record) + } + } + + @Test + def testFencingOnAddPartitions(): Unit = { + val producer1 = transactionalProducers(0) + val producer2 = transactionalProducers(1) + val consumer = transactionalConsumers(0) + + consumer.subscribe(List(topic1, topic2).asJava) + + producer1.initTransactions() + producer1.beginTransaction() + producer1.send(TestUtils.producerRecordWithExpectedTransactionStatus(topic1, null, "1", "1", willBeCommitted = false)) + producer1.send(TestUtils.producerRecordWithExpectedTransactionStatus(topic2, null, "3", "3", willBeCommitted = false)) + producer1.abortTransaction() + + producer2.initTransactions() // ok, will abort the open transaction. + producer2.beginTransaction() + producer2.send(TestUtils.producerRecordWithExpectedTransactionStatus(topic1, null, "2", "4", willBeCommitted = true)) + .get(20, TimeUnit.SECONDS) + producer2.send(TestUtils.producerRecordWithExpectedTransactionStatus(topic2, null, "2", "4", willBeCommitted = true)) + .get(20, TimeUnit.SECONDS) + + try { + producer1.beginTransaction() + val result = producer1.send(TestUtils.producerRecordWithExpectedTransactionStatus(topic1, null, "1", "5", willBeCommitted = false)) + val recordMetadata = result.get() + error(s"Missed a producer fenced exception when writing to ${recordMetadata.topic}-${recordMetadata.partition}. Grab the logs!!") + servers.foreach { server => + error(s"log dirs: ${server.logManager.liveLogDirs.map(_.getAbsolutePath).head}") + } + fail("Should not be able to send messages from a fenced producer.") + } catch { + case _: ProducerFencedException => + case e: ExecutionException => + assertTrue(e.getCause.isInstanceOf[ProducerFencedException]) + case e: Exception => + throw new AssertionError("Got an unexpected exception from a fenced producer.", e) + } + + producer2.commitTransaction() // ok + + val records = consumeRecords(consumer, 2) + records.foreach { record => + TestUtils.assertCommittedAndGetValue(record) + } + } + + @Test + def testFencingOnTransactionExpiration(): Unit = { + val producer = createTransactionalProducer("expiringProducer", transactionTimeoutMs = 100) + + producer.initTransactions() + producer.beginTransaction() + + // The first message and hence the first AddPartitions request should be successfully sent. + val firstMessageResult = producer.send(TestUtils.producerRecordWithExpectedTransactionStatus(topic1, null, "1", "1", willBeCommitted = false)).get() + assertTrue(firstMessageResult.hasOffset) + + // Wait for the expiration cycle to kick in. + Thread.sleep(600) + + try { + // Now that the transaction has expired, the second send should fail with a ProducerFencedException. 
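+ // (This producer was created with a 100 ms transaction timeout, so the coordinator's expiration pass aborts the open transaction and bumps the producer epoch, which is what fences this producer instance.)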
+ producer.send(TestUtils.producerRecordWithExpectedTransactionStatus(topic1, null, "2", "2", willBeCommitted = false)).get() + fail("should have raised a ProducerFencedException since the transaction has expired") + } catch { + case _: ProducerFencedException => + case e: ExecutionException => + assertTrue(e.getCause.isInstanceOf[ProducerFencedException]) + } + + // Verify that the first message was aborted and the second one was never written at all. + val nonTransactionalConsumer = nonTransactionalConsumers.head + nonTransactionalConsumer.subscribe(List(topic1).asJava) + + // Attempt to consume the one written record. We should not see the second. The + // assertion does not strictly guarantee that the record wasn't written, but the + // data is small enough that had it been written, it would have been in the first fetch. + val records = TestUtils.consumeRecords(nonTransactionalConsumer, numRecords = 1) + assertEquals(1, records.size) + assertEquals("1", TestUtils.recordValueAsString(records.head)) + + val transactionalConsumer = transactionalConsumers.head + transactionalConsumer.subscribe(List(topic1).asJava) + + val transactionalRecords = TestUtils.consumeRecordsFor(transactionalConsumer, 1000) + assertTrue(transactionalRecords.isEmpty) + } + + @Test + def testMultipleMarkersOneLeader(): Unit = { + val firstProducer = transactionalProducers.head + val consumer = transactionalConsumers.head + val unCommittedConsumer = nonTransactionalConsumers.head + val topicWith10Partitions = "largeTopic" + val topicWith10PartitionsAndOneReplica = "largeTopicOneReplica" + val topicConfig = new Properties() + topicConfig.put(KafkaConfig.MinInSyncReplicasProp, 2.toString) + + createTopic(topicWith10Partitions, 10, numServers, topicConfig) + createTopic(topicWith10PartitionsAndOneReplica, 10, 1, new Properties()) + + firstProducer.initTransactions() + + firstProducer.beginTransaction() + sendTransactionalMessagesWithValueRange(firstProducer, topicWith10Partitions, 0, 5000, willBeCommitted = false) + sendTransactionalMessagesWithValueRange(firstProducer, topicWith10PartitionsAndOneReplica, 5000, 10000, willBeCommitted = false) + firstProducer.abortTransaction() + + firstProducer.beginTransaction() + sendTransactionalMessagesWithValueRange(firstProducer, topicWith10Partitions, 10000, 11000, willBeCommitted = true) + firstProducer.commitTransaction() + + consumer.subscribe(List(topicWith10PartitionsAndOneReplica, topicWith10Partitions).asJava) + unCommittedConsumer.subscribe(List(topicWith10PartitionsAndOneReplica, topicWith10Partitions).asJava) + + val records = consumeRecords(consumer, 1000) + records.foreach { record => + TestUtils.assertCommittedAndGetValue(record) + } + + val allRecords = consumeRecords(unCommittedConsumer, 11000) + val expectedValues = Range(0, 11000).map(_.toString).toSet + allRecords.foreach { record => + assertTrue(expectedValues.contains(TestUtils.recordValueAsString(record))) + } + } + + @Test + def testConsecutivelyRunInitTransactions(): Unit = { + val producer = createTransactionalProducer(transactionalId = "normalProducer") + + producer.initTransactions() + assertThrows(classOf[KafkaException], () => producer.initTransactions()) + } + + @Test + def testBumpTransactionalEpoch(): Unit = { + val producer = createTransactionalProducer("transactionalProducer", + deliveryTimeoutMs = 5000, requestTimeoutMs = 5000) + val consumer = transactionalConsumers.head + try { + // Create a topic with RF=1 so that a single broker failure will render it unavailable + val testTopic = 
"test-topic" + createTopic(testTopic, numPartitions, 1, new Properties) + val partitionLeader = TestUtils.waitUntilLeaderIsKnown(servers, new TopicPartition(testTopic, 0)) + + producer.initTransactions() + + producer.beginTransaction() + producer.send(TestUtils.producerRecordWithExpectedTransactionStatus(testTopic, 0, "4", "4", willBeCommitted = true)) + producer.commitTransaction() + + var producerStateEntry = + servers(partitionLeader).logManager.getLog(new TopicPartition(testTopic, 0)).get.producerStateManager.activeProducers.head._2 + val producerId = producerStateEntry.producerId + val initialProducerEpoch = producerStateEntry.producerEpoch + + producer.beginTransaction() + producer.send(TestUtils.producerRecordWithExpectedTransactionStatus(topic1, null, "2", "2", willBeCommitted = false)) + + killBroker(partitionLeader) // kill the partition leader to prevent the batch from being submitted + val failedFuture = producer.send(TestUtils.producerRecordWithExpectedTransactionStatus(testTopic, 0, "3", "3", willBeCommitted = false)) + Thread.sleep(6000) // Wait for the record to time out + restartDeadBrokers() + + org.apache.kafka.test.TestUtils.assertFutureThrows(failedFuture, classOf[TimeoutException]) + producer.abortTransaction() + + producer.beginTransaction() + producer.send(TestUtils.producerRecordWithExpectedTransactionStatus(topic2, null, "2", "2", willBeCommitted = true)) + producer.send(TestUtils.producerRecordWithExpectedTransactionStatus(topic1, null, "4", "4", willBeCommitted = true)) + producer.send(TestUtils.producerRecordWithExpectedTransactionStatus(testTopic, 0, "1", "1", willBeCommitted = true)) + producer.send(TestUtils.producerRecordWithExpectedTransactionStatus(testTopic, 0, "3", "3", willBeCommitted = true)) + producer.commitTransaction() + + consumer.subscribe(List(topic1, topic2, testTopic).asJava) + + val records = consumeRecords(consumer, 5) + records.foreach { record => + TestUtils.assertCommittedAndGetValue(record) + } + + // Producers can safely abort and continue after the last record of a transaction timing out, so it's possible to + // get here without having bumped the epoch. If bumping the epoch is possible, the producer will attempt to, so + // check there that the epoch has actually increased + producerStateEntry = + servers(partitionLeader).logManager.getLog(new TopicPartition(testTopic, 0)).get.producerStateManager.activeProducers(producerId) + assertTrue(producerStateEntry.producerEpoch > initialProducerEpoch) + } finally { + producer.close(Duration.ZERO) + } + } + + @Test + def testFailureToFenceEpoch(): Unit = { + val producer1 = transactionalProducers.head + val producer2 = createTransactionalProducer("transactional-producer", maxBlockMs = 1000) + + producer1.initTransactions() + + producer1.beginTransaction() + producer1.send(TestUtils.producerRecordWithExpectedTransactionStatus(topic1, 0, "4", "4", willBeCommitted = true)) + producer1.commitTransaction() + + val partitionLeader = TestUtils.waitUntilLeaderIsKnown(servers, new TopicPartition(topic1, 0)) + var producerStateEntry = + servers(partitionLeader).logManager.getLog(new TopicPartition(topic1, 0)).get.producerStateManager.activeProducers.head._2 + val producerId = producerStateEntry.producerId + val initialProducerEpoch = producerStateEntry.producerEpoch + + // Kill two brokers to bring the transaction log under min-ISR + killBroker(0) + killBroker(1) + + try { + producer2.initTransactions() + } catch { + case _: TimeoutException => + // good! 
+ case e: Exception => + throw new AssertionError("Got an unexpected exception from initTransactions", e) + } finally { + producer2.close() + } + + restartDeadBrokers() + + // Because the epoch was bumped in memory, attempting to begin a transaction with producer 1 should fail + try { + producer1.beginTransaction() + } catch { + case _: ProducerFencedException => + // good! + case e: Exception => + throw new AssertionError("Got an unexpected exception from commitTransaction", e) + } finally { + producer1.close() + } + + val producer3 = createTransactionalProducer("transactional-producer", maxBlockMs = 5000) + producer3.initTransactions() + + producer3.beginTransaction() + producer3.send(TestUtils.producerRecordWithExpectedTransactionStatus(topic1, 0, "4", "4", willBeCommitted = true)) + producer3.commitTransaction() + + // Check that the epoch only increased by 1 + producerStateEntry = + servers(partitionLeader).logManager.getLog(new TopicPartition(topic1, 0)).get.producerStateManager.activeProducers(producerId) + assertEquals((initialProducerEpoch + 1).toShort, producerStateEntry.producerEpoch) + } + + private def sendTransactionalMessagesWithValueRange(producer: KafkaProducer[Array[Byte], Array[Byte]], topic: String, + start: Int, end: Int, willBeCommitted: Boolean): Unit = { + for (i <- start until end) { + producer.send(TestUtils.producerRecordWithExpectedTransactionStatus(topic, null, value = i.toString, willBeCommitted = willBeCommitted, key = i.toString)) + } + producer.flush() + } + + private def serverProps() = { + val serverProps = new Properties() + serverProps.put(KafkaConfig.AutoCreateTopicsEnableProp, false.toString) + // Set a smaller value for the number of partitions for the __consumer_offsets topic + // so that the creation of that topic/partition(s) and subsequent leader assignment doesn't take relatively long + serverProps.put(KafkaConfig.OffsetsTopicPartitionsProp, 1.toString) + serverProps.put(KafkaConfig.TransactionsTopicPartitionsProp, 3.toString) + serverProps.put(KafkaConfig.TransactionsTopicReplicationFactorProp, 2.toString) + serverProps.put(KafkaConfig.TransactionsTopicMinISRProp, 2.toString) + serverProps.put(KafkaConfig.ControlledShutdownEnableProp, true.toString) + serverProps.put(KafkaConfig.UncleanLeaderElectionEnableProp, false.toString) + serverProps.put(KafkaConfig.AutoLeaderRebalanceEnableProp, false.toString) + serverProps.put(KafkaConfig.GroupInitialRebalanceDelayMsProp, "0") + serverProps.put(KafkaConfig.TransactionsAbortTimedOutTransactionCleanupIntervalMsProp, "200") + serverProps + } + + private def createReadCommittedConsumer(group: String = "group", + maxPollRecords: Int = 500, + props: Properties = new Properties) = { + val consumer = TestUtils.createConsumer(TestUtils.getBrokerListStrFromServers(servers), + groupId = group, + enableAutoCommit = false, + readCommitted = true, + maxPollRecords = maxPollRecords) + transactionalConsumers += consumer + consumer + } + + private def createReadUncommittedConsumer(group: String) = { + val consumer = TestUtils.createConsumer(TestUtils.getBrokerListStrFromServers(servers), + groupId = group, + enableAutoCommit = false) + nonTransactionalConsumers += consumer + consumer + } + + private def createTransactionalProducer(transactionalId: String, + transactionTimeoutMs: Long = 60000, + maxBlockMs: Long = 60000, + deliveryTimeoutMs: Int = 120000, + requestTimeoutMs: Int = 30000): KafkaProducer[Array[Byte], Array[Byte]] = { + val producer = TestUtils.createTransactionalProducer(transactionalId, servers, + 
transactionTimeoutMs = transactionTimeoutMs, + maxBlockMs = maxBlockMs, + deliveryTimeoutMs = deliveryTimeoutMs, + requestTimeoutMs = requestTimeoutMs) + transactionalProducers += producer + producer + } +} diff --git a/core/src/test/scala/integration/kafka/api/TransactionsWithMaxInFlightOneTest.scala b/core/src/test/scala/integration/kafka/api/TransactionsWithMaxInFlightOneTest.scala new file mode 100644 index 0000000..267792f --- /dev/null +++ b/core/src/test/scala/integration/kafka/api/TransactionsWithMaxInFlightOneTest.scala @@ -0,0 +1,131 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.api + +import java.util.Properties + +import kafka.integration.KafkaServerTestHarness +import kafka.server.KafkaConfig +import kafka.utils.TestUtils +import kafka.utils.TestUtils.consumeRecords +import org.apache.kafka.clients.consumer.KafkaConsumer +import org.apache.kafka.clients.producer.KafkaProducer +import org.junit.jupiter.api.Assertions.assertEquals +import org.junit.jupiter.api.{AfterEach, BeforeEach, Test, TestInfo} + +import scala.collection.Seq +import scala.collection.mutable.Buffer +import scala.jdk.CollectionConverters._ + +/** + * This is used to test transactions with one broker and `max.in.flight.requests.per.connection=1`. + * A single broker is used to verify edge cases where different requests are queued on the same connection. 
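+ * The test below aborts one transaction and commits a second one, then verifies that a read_committed consumer sees only the committed records.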
+ */ +class TransactionsWithMaxInFlightOneTest extends KafkaServerTestHarness { + val numServers = 1 + + val topic1 = "topic1" + val topic2 = "topic2" + val numPartitions = 4 + + val transactionalProducers = Buffer[KafkaProducer[Array[Byte], Array[Byte]]]() + val transactionalConsumers = Buffer[KafkaConsumer[Array[Byte], Array[Byte]]]() + + override def generateConfigs: Seq[KafkaConfig] = { + TestUtils.createBrokerConfigs(numServers, zkConnect).map(KafkaConfig.fromProps(_, serverProps())) + } + + @BeforeEach + override def setUp(testInfo: TestInfo): Unit = { + super.setUp(testInfo) + val topicConfig = new Properties() + topicConfig.put(KafkaConfig.MinInSyncReplicasProp, 1.toString) + createTopic(topic1, numPartitions, numServers, topicConfig) + createTopic(topic2, numPartitions, numServers, topicConfig) + + createTransactionalProducer("transactional-producer") + createReadCommittedConsumer("transactional-group") + } + + @AfterEach + override def tearDown(): Unit = { + transactionalProducers.foreach(_.close()) + transactionalConsumers.foreach(_.close()) + super.tearDown() + } + + @Test + def testTransactionalProducerSingleBrokerMaxInFlightOne(): Unit = { + // We want to test with one broker to verify multiple requests queued on a connection + assertEquals(1, servers.size) + + val producer = transactionalProducers.head + val consumer = transactionalConsumers.head + + producer.initTransactions() + + producer.beginTransaction() + producer.send(TestUtils.producerRecordWithExpectedTransactionStatus(topic2, null, "2", "2", willBeCommitted = false)) + producer.send(TestUtils.producerRecordWithExpectedTransactionStatus(topic1, null, "4", "4", willBeCommitted = false)) + producer.flush() + producer.abortTransaction() + + producer.beginTransaction() + producer.send(TestUtils.producerRecordWithExpectedTransactionStatus(topic1, null, "1", "1", willBeCommitted = true)) + producer.send(TestUtils.producerRecordWithExpectedTransactionStatus(topic2, null, "3", "3", willBeCommitted = true)) + producer.commitTransaction() + + consumer.subscribe(List(topic1, topic2).asJava) + + val records = consumeRecords(consumer, 2) + records.foreach { record => + TestUtils.assertCommittedAndGetValue(record) + } + } + + private def serverProps() = { + val serverProps = new Properties() + serverProps.put(KafkaConfig.AutoCreateTopicsEnableProp, false.toString) + serverProps.put(KafkaConfig.OffsetsTopicPartitionsProp, 1.toString) + serverProps.put(KafkaConfig.OffsetsTopicReplicationFactorProp, 1.toString) + serverProps.put(KafkaConfig.TransactionsTopicPartitionsProp, 1.toString) + serverProps.put(KafkaConfig.TransactionsTopicReplicationFactorProp, 1.toString) + serverProps.put(KafkaConfig.TransactionsTopicMinISRProp, 1.toString) + serverProps.put(KafkaConfig.ControlledShutdownEnableProp, true.toString) + serverProps.put(KafkaConfig.UncleanLeaderElectionEnableProp, false.toString) + serverProps.put(KafkaConfig.AutoLeaderRebalanceEnableProp, false.toString) + serverProps.put(KafkaConfig.GroupInitialRebalanceDelayMsProp, "0") + serverProps.put(KafkaConfig.TransactionsAbortTimedOutTransactionCleanupIntervalMsProp, "200") + serverProps + } + + private def createReadCommittedConsumer(group: String) = { + val consumer = TestUtils.createConsumer(TestUtils.getBrokerListStrFromServers(servers), + groupId = group, + enableAutoCommit = false, + readCommitted = true) + transactionalConsumers += consumer + consumer + } + + private def createTransactionalProducer(transactionalId: String): KafkaProducer[Array[Byte], Array[Byte]] = { + val 
producer = TestUtils.createTransactionalProducer(transactionalId, servers, maxInFlight = 1) + transactionalProducers += producer + producer + } +} diff --git a/core/src/test/scala/integration/kafka/api/UserClientIdQuotaTest.scala b/core/src/test/scala/integration/kafka/api/UserClientIdQuotaTest.scala new file mode 100644 index 0000000..83c70da --- /dev/null +++ b/core/src/test/scala/integration/kafka/api/UserClientIdQuotaTest.scala @@ -0,0 +1,84 @@ +/** + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + **/ + +package kafka.api + +import java.io.File + +import kafka.server._ +import org.apache.kafka.common.security.auth.{KafkaPrincipal, SecurityProtocol} +import org.apache.kafka.common.utils.Sanitizer +import org.junit.jupiter.api.{BeforeEach, TestInfo} + +class UserClientIdQuotaTest extends BaseQuotaTest { + + override protected def securityProtocol = SecurityProtocol.SSL + override protected lazy val trustStoreFile = Some(File.createTempFile("truststore", ".jks")) + + override def producerClientId = "QuotasTestProducer-!@#$%^&*()" + override def consumerClientId = "QuotasTestConsumer-!@#$%^&*()" + + @BeforeEach + override def setUp(testInfo: TestInfo): Unit = { + this.serverConfig.setProperty(KafkaConfig.SslClientAuthProp, "required") + super.setUp(testInfo) + quotaTestClients.alterClientQuotas( + quotaTestClients.clientQuotaAlteration( + quotaTestClients.clientQuotaEntity(Some(QuotaTestClients.DefaultEntity), Some(QuotaTestClients.DefaultEntity)), + Some(defaultProducerQuota), Some(defaultConsumerQuota), Some(defaultRequestQuota) + ) + ) + quotaTestClients.waitForQuotaUpdate(defaultProducerQuota, defaultConsumerQuota, defaultRequestQuota) + } + + override def createQuotaTestClients(topic: String, leaderNode: KafkaServer): QuotaTestClients = { + val producer = createProducer() + val consumer = createConsumer() + val adminClient = createAdminClient() + + new QuotaTestClients(topic, leaderNode, producerClientId, consumerClientId, producer, consumer, adminClient) { + override def userPrincipal: KafkaPrincipal = new KafkaPrincipal(KafkaPrincipal.USER_TYPE, "O=A client,CN=localhost") + + override def quotaMetricTags(clientId: String): Map[String, String] = { + Map("user" -> Sanitizer.sanitize(userPrincipal.getName), "client-id" -> clientId) + } + + override def overrideQuotas(producerQuota: Long, consumerQuota: Long, requestQuota: Double): Unit = { + alterClientQuotas( + clientQuotaAlteration( + clientQuotaEntity(Some(userPrincipal.getName), Some(producerClientId)), + Some(producerQuota), None, Some(requestQuota) + ), + clientQuotaAlteration( + clientQuotaEntity(Some(userPrincipal.getName), Some(consumerClientId)), + None, Some(consumerQuota), Some(requestQuota) + ) + ) + } + + override def removeQuotaOverrides(): Unit = { + alterClientQuotas( + clientQuotaAlteration( + clientQuotaEntity(Some(userPrincipal.getName), Some(producerClientId)), + None, None, None + ), + clientQuotaAlteration( + clientQuotaEntity(Some(userPrincipal.getName), Some(consumerClientId)), + None, None, None + ) + ) + 
} + } + } +} diff --git a/core/src/test/scala/integration/kafka/api/UserQuotaTest.scala b/core/src/test/scala/integration/kafka/api/UserQuotaTest.scala new file mode 100644 index 0000000..ffbaebb --- /dev/null +++ b/core/src/test/scala/integration/kafka/api/UserQuotaTest.scala @@ -0,0 +1,83 @@ +/** + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + **/ + +package kafka.api + +import java.io.File + +import kafka.server.KafkaServer +import kafka.utils.JaasTestUtils +import org.apache.kafka.common.security.auth.{KafkaPrincipal, SecurityProtocol} +import org.junit.jupiter.api.{AfterEach, BeforeEach, TestInfo} + +class UserQuotaTest extends BaseQuotaTest with SaslSetup { + + override protected def securityProtocol = SecurityProtocol.SASL_SSL + override protected lazy val trustStoreFile = Some(File.createTempFile("truststore", ".jks")) + private val kafkaServerSaslMechanisms = Seq("GSSAPI") + private val kafkaClientSaslMechanism = "GSSAPI" + override protected val serverSaslProperties = Some(kafkaServerSaslProperties(kafkaServerSaslMechanisms, kafkaClientSaslMechanism)) + override protected val clientSaslProperties = Some(kafkaClientSaslProperties(kafkaClientSaslMechanism)) + + @BeforeEach + override def setUp(testInfo: TestInfo): Unit = { + startSasl(jaasSections(kafkaServerSaslMechanisms, Some("GSSAPI"), KafkaSasl, JaasTestUtils.KafkaServerContextName)) + super.setUp(testInfo) + quotaTestClients.alterClientQuotas( + quotaTestClients.clientQuotaAlteration( + quotaTestClients.clientQuotaEntity(Some(QuotaTestClients.DefaultEntity), None), + Some(defaultProducerQuota), Some(defaultConsumerQuota), Some(defaultRequestQuota) + ) + ) + quotaTestClients.waitForQuotaUpdate(defaultProducerQuota, defaultConsumerQuota, defaultRequestQuota) + } + + @AfterEach + override def tearDown(): Unit = { + super.tearDown() + closeSasl() + } + + override def createQuotaTestClients(topic: String, leaderNode: KafkaServer): QuotaTestClients = { + val producer = createProducer() + val consumer = createConsumer() + val adminClient = createAdminClient() + + new QuotaTestClients(topic, leaderNode, producerClientId, consumerClientId, producer, consumer, adminClient) { + override val userPrincipal = new KafkaPrincipal(KafkaPrincipal.USER_TYPE, JaasTestUtils.KafkaClientPrincipalUnqualifiedName2) + + override def quotaMetricTags(clientId: String): Map[String, String] = { + Map("user" -> userPrincipal.getName, "client-id" -> "") + } + + override def overrideQuotas(producerQuota: Long, consumerQuota: Long, requestQuota: Double): Unit = { + alterClientQuotas( + clientQuotaAlteration( + clientQuotaEntity(Some(userPrincipal.getName), None), + Some(producerQuota), Some(consumerQuota), Some(requestQuota) + ) + ) + } + + override def removeQuotaOverrides(): Unit = { + alterClientQuotas( + clientQuotaAlteration( + clientQuotaEntity(Some(userPrincipal.getName), None), + None, None, None + ) + ) + } + } + } +} diff --git a/core/src/test/scala/integration/kafka/coordinator/transaction/ProducerIdsIntegrationTest.scala 
b/core/src/test/scala/integration/kafka/coordinator/transaction/ProducerIdsIntegrationTest.scala new file mode 100644 index 0000000..be9f159 --- /dev/null +++ b/core/src/test/scala/integration/kafka/coordinator/transaction/ProducerIdsIntegrationTest.scala @@ -0,0 +1,86 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.coordinator.transaction + +import kafka.network.SocketServer +import kafka.server.{IntegrationTestUtils, KafkaConfig} +import kafka.test.annotation.{AutoStart, ClusterTest, ClusterTests, Type} +import kafka.test.junit.ClusterTestExtensions +import kafka.test.{ClusterConfig, ClusterInstance} +import org.apache.kafka.common.message.InitProducerIdRequestData +import org.apache.kafka.common.network.ListenerName +import org.apache.kafka.common.record.RecordBatch +import org.apache.kafka.common.requests.{InitProducerIdRequest, InitProducerIdResponse} +import org.junit.jupiter.api.Assertions.assertEquals +import org.junit.jupiter.api.BeforeEach +import org.junit.jupiter.api.extension.ExtendWith + +import java.util.stream.{Collectors, IntStream} +import scala.jdk.CollectionConverters._ + +@ExtendWith(value = Array(classOf[ClusterTestExtensions])) +class ProducerIdsIntegrationTest { + + @BeforeEach + def setup(clusterConfig: ClusterConfig): Unit = { + clusterConfig.serverProperties().put(KafkaConfig.TransactionsTopicPartitionsProp, "1") + clusterConfig.serverProperties().put(KafkaConfig.TransactionsTopicReplicationFactorProp, "3") + } + + @ClusterTests(Array( + new ClusterTest(clusterType = Type.ZK, brokers = 3, ibp = "2.8"), + new ClusterTest(clusterType = Type.ZK, brokers = 3, ibp = "3.0-IV0"), + new ClusterTest(clusterType = Type.KRAFT, brokers = 3, ibp = "3.0-IV0") + )) + def testUniqueProducerIds(clusterInstance: ClusterInstance): Unit = { + verifyUniqueIds(clusterInstance) + } + + @ClusterTest(clusterType = Type.ZK, brokers = 3, autoStart = AutoStart.NO) + def testUniqueProducerIdsBumpIBP(clusterInstance: ClusterInstance): Unit = { + clusterInstance.config().serverProperties().put(KafkaConfig.InterBrokerProtocolVersionProp, "2.8") + clusterInstance.config().brokerServerProperties(0).put(KafkaConfig.InterBrokerProtocolVersionProp, "3.0-IV0") + clusterInstance.start() + verifyUniqueIds(clusterInstance) + clusterInstance.stop() + } + + private def verifyUniqueIds(clusterInstance: ClusterInstance): Unit = { + // Request enough PIDs from each broker to ensure each broker generates two PID blocks + val ids = clusterInstance.brokerSocketServers().stream().flatMap( broker => { + IntStream.range(0, 1001).parallel().mapToObj( _ => nextProducerId(broker, clusterInstance.clientListener())) + }).collect(Collectors.toList[Long]).asScala.toSeq + + assertEquals(3003, ids.size, "Expected exactly 3003 IDs") + 
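// Each of the 3 brokers served 1001 InitProducerId requests, hence 3003 IDs in total; they must also all be distinct. +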
assertEquals(ids.size, ids.distinct.size, "Found duplicate producer IDs") + } + + private def nextProducerId(broker: SocketServer, listener: ListenerName): Long = { + val data = new InitProducerIdRequestData() + .setProducerEpoch(RecordBatch.NO_PRODUCER_EPOCH) + .setProducerId(RecordBatch.NO_PRODUCER_ID) + .setTransactionalId(null) + .setTransactionTimeoutMs(10) + val request = new InitProducerIdRequest.Builder(data).build() + + val response = IntegrationTestUtils.connectAndReceive[InitProducerIdResponse](request, + destination = broker, + listenerName = listener) + response.data().producerId() + } +} diff --git a/core/src/test/scala/integration/kafka/network/DynamicConnectionQuotaTest.scala b/core/src/test/scala/integration/kafka/network/DynamicConnectionQuotaTest.scala new file mode 100644 index 0000000..14f3e8f --- /dev/null +++ b/core/src/test/scala/integration/kafka/network/DynamicConnectionQuotaTest.scala @@ -0,0 +1,415 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.network + +import java.io.IOException +import java.net.{InetAddress, Socket} +import java.util.concurrent._ +import java.util.{Collections, Properties} +import kafka.server.{BaseRequestTest, KafkaConfig} +import kafka.utils.TestUtils +import org.apache.kafka.clients.admin.{Admin, AdminClientConfig} +import org.apache.kafka.common.config.internals.QuotaConfigs +import org.apache.kafka.common.message.ProduceRequestData +import org.apache.kafka.common.network.ListenerName +import org.apache.kafka.common.protocol.Errors +import org.apache.kafka.common.quota.ClientQuotaEntity +import org.apache.kafka.common.record.{CompressionType, MemoryRecords, SimpleRecord} +import org.apache.kafka.common.requests.{ProduceRequest, ProduceResponse} +import org.apache.kafka.common.security.auth.SecurityProtocol +import org.apache.kafka.common.{KafkaException, requests} +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.{AfterEach, BeforeEach, Test, TestInfo} + +import scala.jdk.CollectionConverters._ + +class DynamicConnectionQuotaTest extends BaseRequestTest { + + override def brokerCount = 1 + + val topic = "test" + val listener = ListenerName.forSecurityProtocol(SecurityProtocol.PLAINTEXT) + val localAddress = InetAddress.getByName("127.0.0.1") + val unknownHost = "255.255.0.1" + val plaintextListenerDefaultQuota = 30 + var executor: ExecutorService = _ + + override def brokerPropertyOverrides(properties: Properties): Unit = { + properties.put(KafkaConfig.NumQuotaSamplesProp, "2".toString) + properties.put("listener.name.plaintext.max.connection.creation.rate", plaintextListenerDefaultQuota.toString) + } + + @BeforeEach + override def setUp(testInfo: TestInfo): Unit = { + super.setUp(testInfo) + TestUtils.createTopic(zkClient, topic, 
brokerCount, brokerCount, servers) + } + + @AfterEach + override def tearDown(): Unit = { + try { + if (executor != null) { + executor.shutdownNow() + assertTrue(executor.awaitTermination(10, TimeUnit.SECONDS)) + } + } finally { + super.tearDown() + } + } + + @Test + def testDynamicConnectionQuota(): Unit = { + val maxConnectionsPerIP = 5 + + def connectAndVerify(): Unit = { + val socket = connect() + try { + sendAndReceive[ProduceResponse](produceRequest, socket) + } finally { + socket.close() + } + } + + val props = new Properties + props.put(KafkaConfig.MaxConnectionsPerIpProp, maxConnectionsPerIP.toString) + reconfigureServers(props, perBrokerConfig = false, (KafkaConfig.MaxConnectionsPerIpProp, maxConnectionsPerIP.toString)) + + verifyMaxConnections(maxConnectionsPerIP, connectAndVerify) + + // Increase MaxConnectionsPerIpOverrides for localhost to 7 + val maxConnectionsPerIPOverride = 7 + props.put(KafkaConfig.MaxConnectionsPerIpOverridesProp, s"localhost:$maxConnectionsPerIPOverride") + reconfigureServers(props, perBrokerConfig = false, (KafkaConfig.MaxConnectionsPerIpOverridesProp, s"localhost:$maxConnectionsPerIPOverride")) + + verifyMaxConnections(maxConnectionsPerIPOverride, connectAndVerify) + } + + @Test + def testDynamicListenerConnectionQuota(): Unit = { + val initialConnectionCount = connectionCount + + def connectAndVerify(): Unit = { + val socket = connect("PLAINTEXT") + socket.setSoTimeout(1000) + try { + sendAndReceive[ProduceResponse](produceRequest, socket) + } finally { + socket.close() + } + } + + // Reduce total broker MaxConnections to 5 at the cluster level + val props = new Properties + props.put(KafkaConfig.MaxConnectionsProp, "5") + reconfigureServers(props, perBrokerConfig = false, (KafkaConfig.MaxConnectionsProp, "5")) + verifyMaxConnections(5, connectAndVerify) + + // Create another listener and verify listener connection limit of 5 for each listener + val newListeners = "PLAINTEXT://localhost:0,INTERNAL://localhost:0" + props.put(KafkaConfig.ListenersProp, newListeners) + props.put(KafkaConfig.ListenerSecurityProtocolMapProp, "PLAINTEXT:PLAINTEXT,INTERNAL:PLAINTEXT") + props.put(KafkaConfig.MaxConnectionsProp, "10") + props.put("listener.name.internal.max.connections", "5") + props.put("listener.name.plaintext.max.connections", "5") + reconfigureServers(props, perBrokerConfig = true, (KafkaConfig.ListenersProp, newListeners)) + waitForListener("INTERNAL") + + var conns = (connectionCount until 5).map(_ => connect("PLAINTEXT")) + conns ++= (5 until 10).map(_ => connect("INTERNAL")) + conns.foreach(verifyConnection) + conns.foreach(_.close()) + TestUtils.waitUntilTrue(() => initialConnectionCount == connectionCount, "Connections not closed") + + // Increase MaxConnections for PLAINTEXT listener to 7 at the broker level + val maxConnectionsPlaintext = 7 + val listenerProp = s"${listener.configPrefix}${KafkaConfig.MaxConnectionsProp}" + props.put(listenerProp, maxConnectionsPlaintext.toString) + reconfigureServers(props, perBrokerConfig = true, (listenerProp, maxConnectionsPlaintext.toString)) + verifyMaxConnections(maxConnectionsPlaintext, connectAndVerify) + + // Verify that connection blocked on the limit connects successfully when an existing connection is closed + val plaintextConnections = (connectionCount until maxConnectionsPlaintext).map(_ => connect("PLAINTEXT")) + executor = Executors.newSingleThreadExecutor + val future = executor.submit((() => createAndVerifyConnection()): Runnable) + Thread.sleep(100) + assertFalse(future.isDone) + 
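// Closing an existing connection frees a slot, so the blocked connection attempt in `future` can now complete. +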
plaintextConnections.head.close() + future.get(30, TimeUnit.SECONDS) + plaintextConnections.foreach(_.close()) + TestUtils.waitUntilTrue(() => initialConnectionCount == connectionCount, "Connections not closed") + + // Verify that connections on inter-broker listener succeed even if broker max connections has been + // reached by closing connections on another listener + var plaintextConns = (connectionCount until 5).map(_ => connect("PLAINTEXT")) + val internalConns = (5 until 10).map(_ => connect("INTERNAL")) + plaintextConns.foreach(verifyConnection) + internalConns.foreach(verifyConnection) + plaintextConns ++= (0 until 2).map(_ => connect("PLAINTEXT")) + TestUtils.waitUntilTrue(() => connectionCount <= 10, "Internal connections not closed") + plaintextConns.foreach(verifyConnection) + assertThrows(classOf[IOException], () => internalConns.foreach { socket => + sendAndReceive[ProduceResponse](produceRequest, socket) + }) + plaintextConns.foreach(_.close()) + internalConns.foreach(_.close()) + TestUtils.waitUntilTrue(() => initialConnectionCount == connectionCount, "Connections not closed") + } + + @Test + def testDynamicListenerConnectionCreationRateQuota(): Unit = { + // Create another listener. PLAINTEXT is an inter-broker listener + // keep default limits + val newListenerNames = Seq("PLAINTEXT", "EXTERNAL") + val newListeners = "PLAINTEXT://localhost:0,EXTERNAL://localhost:0" + val props = new Properties + props.put(KafkaConfig.ListenersProp, newListeners) + props.put(KafkaConfig.ListenerSecurityProtocolMapProp, "PLAINTEXT:PLAINTEXT,EXTERNAL:PLAINTEXT") + reconfigureServers(props, perBrokerConfig = true, (KafkaConfig.ListenersProp, newListeners)) + waitForListener("EXTERNAL") + + // The expected connection count after each test run + val initialConnectionCount = connectionCount + + // new broker-wide connection rate limit + val connRateLimit = 9 + + // before setting connection rate to 10, verify we can do at least double that by default (no limit) + verifyConnectionRate(2 * connRateLimit, plaintextListenerDefaultQuota, "PLAINTEXT", ignoreIOExceptions = false) + waitForConnectionCount(initialConnectionCount) + + // Reduce total broker connection rate limit to 9 at the cluster level and verify the limit is enforced + props.clear() // so that we do not pass security protocol map which cannot be set at the cluster level + props.put(KafkaConfig.MaxConnectionCreationRateProp, connRateLimit.toString) + reconfigureServers(props, perBrokerConfig = false, (KafkaConfig.MaxConnectionCreationRateProp, connRateLimit.toString)) + // verify EXTERNAL listener is capped by broker-wide quota (PLAINTEXT is not capped by broker-wide limit, since it + // has limited quota set and is a protected listener) + verifyConnectionRate(8, connRateLimit, "EXTERNAL", ignoreIOExceptions = false) + waitForConnectionCount(initialConnectionCount) + + // Set 4 conn/sec rate limit for each listener and verify it gets enforced + val listenerConnRateLimit = 4 + val plaintextListenerProp = s"${listener.configPrefix}${KafkaConfig.MaxConnectionCreationRateProp}" + props.put(s"listener.name.external.${KafkaConfig.MaxConnectionCreationRateProp}", listenerConnRateLimit.toString) + props.put(plaintextListenerProp, listenerConnRateLimit.toString) + reconfigureServers(props, perBrokerConfig = true, (plaintextListenerProp, listenerConnRateLimit.toString)) + + executor = Executors.newFixedThreadPool(newListenerNames.size) + val futures = newListenerNames.map { listener => + executor.submit((() => verifyConnectionRate(3, 
listenerConnRateLimit, listener, ignoreIOExceptions = false)): Runnable) + } + futures.foreach(_.get(40, TimeUnit.SECONDS)) + waitForConnectionCount(initialConnectionCount) + + // increase connection rate limit on PLAINTEXT (inter-broker) listener to 12 and verify that it will be able to + // achieve this rate even though total connection rate may exceed broker-wide rate limit, while EXTERNAL listener + // should not exceed its listener limit + val newPlaintextRateLimit = 12 + props.put(plaintextListenerProp, newPlaintextRateLimit.toString) + reconfigureServers(props, perBrokerConfig = true, (plaintextListenerProp, newPlaintextRateLimit.toString)) + + val plaintextFuture = executor.submit((() => + verifyConnectionRate(10, newPlaintextRateLimit, "PLAINTEXT", ignoreIOExceptions = false)): Runnable) + val externalFuture = executor.submit((() => + verifyConnectionRate(3, listenerConnRateLimit, "EXTERNAL", ignoreIOExceptions = false)): Runnable) + + plaintextFuture.get(40, TimeUnit.SECONDS) + externalFuture.get(40, TimeUnit.SECONDS) + waitForConnectionCount(initialConnectionCount) + } + + @Test + def testDynamicIpConnectionRateQuota(): Unit = { + val connRateLimit = 10 + val initialConnectionCount = connectionCount + // before setting connection rate to 10, verify we can do at least double that by default (no limit) + verifyConnectionRate(2 * connRateLimit, plaintextListenerDefaultQuota, "PLAINTEXT", ignoreIOExceptions = false) + waitForConnectionCount(initialConnectionCount) + // set default IP connection rate quota, verify that we don't exceed the limit + updateIpConnectionRate(None, connRateLimit) + verifyConnectionRate(8, connRateLimit, "PLAINTEXT", ignoreIOExceptions = true) + waitForConnectionCount(initialConnectionCount) + // set a higher IP connection rate quota override, verify that the higher limit is now enforced + val newRateLimit = 18 + updateIpConnectionRate(Some(localAddress.getHostAddress), newRateLimit) + verifyConnectionRate(14, newRateLimit, "PLAINTEXT", ignoreIOExceptions = true) + waitForConnectionCount(initialConnectionCount) + } + + private def reconfigureServers(newProps: Properties, perBrokerConfig: Boolean, aPropToVerify: (String, String)): Unit = { + val initialConnectionCount = connectionCount + val adminClient = createAdminClient() + TestUtils.incrementalAlterConfigs(servers, adminClient, newProps, perBrokerConfig).all.get() + waitForConfigOnServer(aPropToVerify._1, aPropToVerify._2) + adminClient.close() + TestUtils.waitUntilTrue(() => initialConnectionCount == connectionCount, + s"Admin client connection not closed (initial = $initialConnectionCount, current = $connectionCount)") + } + + private def updateIpConnectionRate(ip: Option[String], updatedRate: Int): Unit = { + val initialConnectionCount = connectionCount + val adminClient = createAdminClient() + try { + val entity = new ClientQuotaEntity(Map(ClientQuotaEntity.IP -> ip.orNull).asJava) + val request = Map(entity -> Map(QuotaConfigs.IP_CONNECTION_RATE_OVERRIDE_CONFIG -> Some(updatedRate.toDouble))) + TestUtils.alterClientQuotas(adminClient, request).all.get() + // use a random throwaway address if ip isn't specified to get the default value + TestUtils.waitUntilTrue(() => servers.head.socketServer.connectionQuotas. 
+ connectionRateForIp(InetAddress.getByName(ip.getOrElse(unknownHost))) == updatedRate, + s"Timed out waiting for connection rate update to propagate" + ) + } finally { + adminClient.close() + } + TestUtils.waitUntilTrue(() => initialConnectionCount == connectionCount, + s"Admin client connection not closed (initial = $initialConnectionCount, current = $connectionCount)") + } + + private def waitForListener(listenerName: String): Unit = { + TestUtils.retry(maxWaitMs = 10000) { + try { + assertTrue(servers.head.socketServer.boundPort(ListenerName.normalised(listenerName)) > 0) + } catch { + case e: KafkaException => throw new AssertionError(e) + } + } + } + + private def createAdminClient(): Admin = { + val bootstrapServers = TestUtils.bootstrapServers(servers, new ListenerName(securityProtocol.name)) + val config = new Properties() + config.put(AdminClientConfig.BOOTSTRAP_SERVERS_CONFIG, bootstrapServers) + config.put(AdminClientConfig.METADATA_MAX_AGE_CONFIG, "10") + val adminClient = Admin.create(config) + adminClient + } + + private def waitForConfigOnServer(propName: String, propValue: String, maxWaitMs: Long = 10000): Unit = { + TestUtils.retry(maxWaitMs) { + assertEquals(propValue, servers.head.config.originals.get(propName)) + } + } + + private def produceRequest: ProduceRequest = + requests.ProduceRequest.forCurrentMagic(new ProduceRequestData() + .setTopicData(new ProduceRequestData.TopicProduceDataCollection( + Collections.singletonList(new ProduceRequestData.TopicProduceData() + .setName(topic) + .setPartitionData(Collections.singletonList(new ProduceRequestData.PartitionProduceData() + .setIndex(0) + .setRecords(MemoryRecords.withRecords(CompressionType.NONE, + new SimpleRecord(System.currentTimeMillis(), "key".getBytes, "value".getBytes)))))) + .iterator)) + .setAcks((-1).toShort) + .setTimeoutMs(3000) + .setTransactionalId(null)) + .build() + + def connectionCount: Int = servers.head.socketServer.connectionCount(localAddress) + + def connect(listener: String): Socket = { + val listenerName = ListenerName.normalised(listener) + new Socket("localhost", servers.head.socketServer.boundPort(listenerName)) + } + + private def createAndVerifyConnection(listener: String = "PLAINTEXT"): Unit = { + val socket = connect(listener) + try { + verifyConnection(socket) + } finally { + socket.close() + } + } + + private def verifyConnection(socket: Socket): Unit = { + val produceResponse = sendAndReceive[ProduceResponse](produceRequest, socket) + assertEquals(1, produceResponse.data.responses.size) + val topicProduceResponse = produceResponse.data.responses.asScala.head + assertEquals(1, topicProduceResponse.partitionResponses.size) + val partitionProduceResponse = topicProduceResponse.partitionResponses.asScala.head + assertEquals(Errors.NONE, Errors.forCode(partitionProduceResponse.errorCode)) + } + + private def verifyMaxConnections(maxConnections: Int, connectWithFailure: () => Unit): Unit = { + val initialConnectionCount = connectionCount + + //create connections up to maxConnectionsPerIP - 1, leave space for one connection + var conns = (connectionCount until (maxConnections - 1)).map(_ => connect("PLAINTEXT")) + + // produce should succeed on a new connection + createAndVerifyConnection() + + TestUtils.waitUntilTrue(() => connectionCount == (maxConnections - 1), "produce request connection is not closed") + conns = conns :+ connect("PLAINTEXT") + + // now try one more (should fail) + assertThrows(classOf[IOException], () => connectWithFailure.apply()) + + //close one connection + 
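// a new connection should then be accepted again +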
conns.head.close() + TestUtils.waitUntilTrue(() => connectionCount == (maxConnections - 1), "connection is not closed") + createAndVerifyConnection() + + conns.foreach(_.close()) + TestUtils.waitUntilTrue(() => initialConnectionCount == connectionCount, "Connections not closed") + } + + private def connectAndVerify(listener: String, ignoreIOExceptions: Boolean): Unit = { + val socket = connect(listener) + try { + sendAndReceive[ProduceResponse](produceRequest, socket) + } catch { + // IP rate throttling can lead to disconnected sockets on client's end + case e: IOException => if (!ignoreIOExceptions) throw e + } finally { + socket.close() + } + } + + private def waitForConnectionCount(expectedConnectionCount: Int): Unit = { + TestUtils.waitUntilTrue(() => expectedConnectionCount == connectionCount, + s"Connections not closed (expected = $expectedConnectionCount current = $connectionCount)") + } + + /** + * this method simulates a workload that creates connection, sends produce request, closes connection, + * and verifies that rate does not exceed the given maximum limit `maxConnectionRate` + * + * Since producing a request and closing a connection also takes time, this method does not verify that the lower bound + * of actual rate is close to `maxConnectionRate`. Instead, use `minConnectionRate` parameter to verify that the rate + * is at least certain value. Note that throttling is tested and verified more accurately in ConnectionQuotasTest + */ + private def verifyConnectionRate(minConnectionRate: Int, maxConnectionRate: Int, listener: String, ignoreIOExceptions: Boolean): Unit = { + // duration such that the maximum rate should be at most 20% higher than the rate limit. Since all connections + // can fall in the beginning of quota window, it is OK to create extra 2 seconds (window size) worth of connections + val runTimeMs = TimeUnit.SECONDS.toMillis(13) + val startTimeMs = System.currentTimeMillis + val endTimeMs = startTimeMs + runTimeMs + + var connCount = 0 + while (System.currentTimeMillis < endTimeMs) { + connectAndVerify(listener, ignoreIOExceptions) + connCount += 1 + } + val elapsedMs = System.currentTimeMillis - startTimeMs + val actualRate = (connCount.toDouble / elapsedMs) * 1000 + val rateCap = if (maxConnectionRate < Int.MaxValue) 1.2 * maxConnectionRate.toDouble else Int.MaxValue.toDouble + assertTrue(actualRate <= rateCap, s"Listener $listener connection rate $actualRate must be below $rateCap") + assertTrue(actualRate >= minConnectionRate, s"Listener $listener connection rate $actualRate must be above $minConnectionRate") + } +} diff --git a/core/src/test/scala/integration/kafka/server/DelayedFetchTest.scala b/core/src/test/scala/integration/kafka/server/DelayedFetchTest.scala new file mode 100644 index 0000000..7585862 --- /dev/null +++ b/core/src/test/scala/integration/kafka/server/DelayedFetchTest.scala @@ -0,0 +1,216 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package kafka.server + +import java.util.Optional + +import scala.collection.Seq +import kafka.cluster.Partition +import kafka.log.LogOffsetSnapshot +import org.apache.kafka.common.{TopicIdPartition, Uuid} +import org.apache.kafka.common.errors.{FencedLeaderEpochException, NotLeaderOrFollowerException} +import org.apache.kafka.common.message.OffsetForLeaderEpochResponseData.EpochEndOffset +import org.apache.kafka.common.protocol.Errors +import org.apache.kafka.common.record.MemoryRecords +import org.apache.kafka.common.requests.FetchRequest +import org.easymock.{EasyMock, EasyMockSupport} +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.Assertions._ + +class DelayedFetchTest extends EasyMockSupport { + private val maxBytes = 1024 + private val replicaManager: ReplicaManager = mock(classOf[ReplicaManager]) + private val replicaQuota: ReplicaQuota = mock(classOf[ReplicaQuota]) + + @Test + def testFetchWithFencedEpoch(): Unit = { + val topicIdPartition = new TopicIdPartition(Uuid.randomUuid(), 0, "topic") + val fetchOffset = 500L + val logStartOffset = 0L + val currentLeaderEpoch = Optional.of[Integer](10) + val replicaId = 1 + + val fetchStatus = FetchPartitionStatus( + startOffsetMetadata = LogOffsetMetadata(fetchOffset), + fetchInfo = new FetchRequest.PartitionData(Uuid.ZERO_UUID, fetchOffset, logStartOffset, maxBytes, currentLeaderEpoch)) + val fetchMetadata = buildFetchMetadata(replicaId, topicIdPartition, fetchStatus) + + var fetchResultOpt: Option[FetchPartitionData] = None + def callback(responses: Seq[(TopicIdPartition, FetchPartitionData)]): Unit = { + fetchResultOpt = Some(responses.head._2) + } + + val delayedFetch = new DelayedFetch( + delayMs = 500, + fetchMetadata = fetchMetadata, + replicaManager = replicaManager, + quota = replicaQuota, + clientMetadata = None, + responseCallback = callback) + + val partition: Partition = mock(classOf[Partition]) + + EasyMock.expect(replicaManager.getPartitionOrException(topicIdPartition.topicPartition)) + .andReturn(partition) + EasyMock.expect(partition.fetchOffsetSnapshot( + currentLeaderEpoch, + fetchOnlyFromLeader = true)) + .andThrow(new FencedLeaderEpochException("Requested epoch has been fenced")) + EasyMock.expect(replicaManager.isAddingReplica(EasyMock.anyObject(), EasyMock.anyInt())).andReturn(false) + + expectReadFromReplica(replicaId, topicIdPartition, fetchStatus.fetchInfo, Errors.FENCED_LEADER_EPOCH) + + replayAll() + + assertTrue(delayedFetch.tryComplete()) + assertTrue(delayedFetch.isCompleted) + assertTrue(fetchResultOpt.isDefined) + + val fetchResult = fetchResultOpt.get + assertEquals(Errors.FENCED_LEADER_EPOCH, fetchResult.error) + } + + @Test + def testNotLeaderOrFollower(): Unit = { + val topicIdPartition = new TopicIdPartition(Uuid.randomUuid(), 0, "topic") + val fetchOffset = 500L + val logStartOffset = 0L + val currentLeaderEpoch = Optional.of[Integer](10) + val replicaId = 1 + + val fetchStatus = FetchPartitionStatus( + startOffsetMetadata = LogOffsetMetadata(fetchOffset), + fetchInfo = new FetchRequest.PartitionData(Uuid.ZERO_UUID, fetchOffset, logStartOffset, maxBytes, 
currentLeaderEpoch)) + val fetchMetadata = buildFetchMetadata(replicaId, topicIdPartition, fetchStatus) + + var fetchResultOpt: Option[FetchPartitionData] = None + def callback(responses: Seq[(TopicIdPartition, FetchPartitionData)]): Unit = { + fetchResultOpt = Some(responses.head._2) + } + + val delayedFetch = new DelayedFetch( + delayMs = 500, + fetchMetadata = fetchMetadata, + replicaManager = replicaManager, + quota = replicaQuota, + clientMetadata = None, + responseCallback = callback) + + EasyMock.expect(replicaManager.getPartitionOrException(topicIdPartition.topicPartition)) + .andThrow(new NotLeaderOrFollowerException(s"Replica for $topicIdPartition not available")) + expectReadFromReplica(replicaId, topicIdPartition, fetchStatus.fetchInfo, Errors.NOT_LEADER_OR_FOLLOWER) + EasyMock.expect(replicaManager.isAddingReplica(EasyMock.anyObject(), EasyMock.anyInt())).andReturn(false) + + replayAll() + + assertTrue(delayedFetch.tryComplete()) + assertTrue(delayedFetch.isCompleted) + assertTrue(fetchResultOpt.isDefined) + } + + @Test + def testDivergingEpoch(): Unit = { + val topicIdPartition = new TopicIdPartition(Uuid.randomUuid(), 0, "topic") + val fetchOffset = 500L + val logStartOffset = 0L + val currentLeaderEpoch = Optional.of[Integer](10) + val lastFetchedEpoch = Optional.of[Integer](9) + val replicaId = 1 + + val fetchStatus = FetchPartitionStatus( + startOffsetMetadata = LogOffsetMetadata(fetchOffset), + fetchInfo = new FetchRequest.PartitionData(topicIdPartition.topicId, fetchOffset, logStartOffset, maxBytes, currentLeaderEpoch, lastFetchedEpoch)) + val fetchMetadata = buildFetchMetadata(replicaId, topicIdPartition, fetchStatus) + + var fetchResultOpt: Option[FetchPartitionData] = None + def callback(responses: Seq[(TopicIdPartition, FetchPartitionData)]): Unit = { + fetchResultOpt = Some(responses.head._2) + } + + val delayedFetch = new DelayedFetch( + delayMs = 500, + fetchMetadata = fetchMetadata, + replicaManager = replicaManager, + quota = replicaQuota, + clientMetadata = None, + responseCallback = callback) + + val partition: Partition = mock(classOf[Partition]) + EasyMock.expect(replicaManager.getPartitionOrException(topicIdPartition.topicPartition)).andReturn(partition) + val endOffsetMetadata = LogOffsetMetadata(messageOffset = 500L, segmentBaseOffset = 0L, relativePositionInSegment = 500) + EasyMock.expect(partition.fetchOffsetSnapshot( + currentLeaderEpoch, + fetchOnlyFromLeader = true)) + .andReturn(LogOffsetSnapshot(0L, endOffsetMetadata, endOffsetMetadata, endOffsetMetadata)) + EasyMock.expect(partition.lastOffsetForLeaderEpoch(currentLeaderEpoch, lastFetchedEpoch.get, fetchOnlyFromLeader = false)) + .andReturn(new EpochEndOffset() + .setPartition(topicIdPartition.partition) + .setErrorCode(Errors.NONE.code) + .setLeaderEpoch(lastFetchedEpoch.get) + .setEndOffset(fetchOffset - 1)) + EasyMock.expect(replicaManager.isAddingReplica(EasyMock.anyObject(), EasyMock.anyInt())).andReturn(false) + expectReadFromReplica(replicaId, topicIdPartition, fetchStatus.fetchInfo, Errors.NONE) + replayAll() + + assertTrue(delayedFetch.tryComplete()) + assertTrue(delayedFetch.isCompleted) + assertTrue(fetchResultOpt.isDefined) + } + + private def buildFetchMetadata(replicaId: Int, + topicIdPartition: TopicIdPartition, + fetchStatus: FetchPartitionStatus): FetchMetadata = { + FetchMetadata(fetchMinBytes = 1, + fetchMaxBytes = maxBytes, + hardMaxBytesLimit = false, + fetchOnlyLeader = true, + fetchIsolation = FetchLogEnd, + isFromFollower = true, + replicaId = replicaId, + 
fetchPartitionStatus = Seq((topicIdPartition, fetchStatus))) + } + + private def expectReadFromReplica(replicaId: Int, + topicIdPartition: TopicIdPartition, + fetchPartitionData: FetchRequest.PartitionData, + error: Errors): Unit = { + EasyMock.expect(replicaManager.readFromLocalLog( + replicaId = replicaId, + fetchOnlyFromLeader = true, + fetchIsolation = FetchLogEnd, + fetchMaxBytes = maxBytes, + hardMaxBytesLimit = false, + readPartitionInfo = Seq((topicIdPartition, fetchPartitionData)), + clientMetadata = None, + quota = replicaQuota)) + .andReturn(Seq((topicIdPartition, buildReadResult(error)))) + } + + private def buildReadResult(error: Errors): LogReadResult = { + LogReadResult( + exception = if (error != Errors.NONE) Some(error.exception) else None, + info = FetchDataInfo(LogOffsetMetadata.UnknownOffsetMetadata, MemoryRecords.EMPTY), + divergingEpoch = None, + highWatermark = -1L, + leaderLogStartOffset = -1L, + leaderLogEndOffset = -1L, + followerLogStartOffset = -1L, + fetchTimeMs = -1L, + lastStableOffset = None) + } + +} diff --git a/core/src/test/scala/integration/kafka/server/DynamicBrokerReconfigurationTest.scala b/core/src/test/scala/integration/kafka/server/DynamicBrokerReconfigurationTest.scala new file mode 100644 index 0000000..687cd9e --- /dev/null +++ b/core/src/test/scala/integration/kafka/server/DynamicBrokerReconfigurationTest.scala @@ -0,0 +1,1875 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package kafka.server + +import java.io.{Closeable, File, FileWriter, IOException, Reader, StringReader} +import java.nio.file.{Files, Paths, StandardCopyOption} +import java.lang.management.ManagementFactory +import java.security.KeyStore +import java.time.Duration +import java.util +import java.util.{Collections, Properties} +import java.util.concurrent._ +import javax.management.ObjectName +import com.yammer.metrics.core.MetricName +import kafka.admin.ConfigCommand +import kafka.api.{KafkaSasl, SaslSetup} +import kafka.controller.{ControllerBrokerStateInfo, ControllerChannelManager} +import kafka.log.LogConfig +import kafka.message.ProducerCompressionCodec +import kafka.metrics.KafkaYammerMetrics +import kafka.network.{Processor, RequestChannel} +import kafka.server.QuorumTestHarness +import kafka.utils._ +import kafka.utils.Implicits._ +import kafka.zk.{ConfigEntityChangeNotificationZNode} +import org.apache.kafka.clients.CommonClientConfigs +import org.apache.kafka.clients.admin.AlterConfigOp.OpType +import org.apache.kafka.clients.admin.ConfigEntry.{ConfigSource, ConfigSynonym} +import org.apache.kafka.clients.admin._ +import org.apache.kafka.clients.consumer.{ConsumerConfig, ConsumerRecord, ConsumerRecords, KafkaConsumer} +import org.apache.kafka.clients.producer.{KafkaProducer, ProducerConfig, ProducerRecord} +import org.apache.kafka.common.{ClusterResource, ClusterResourceListener, Reconfigurable, TopicPartition, TopicPartitionInfo} +import org.apache.kafka.common.config.{ConfigException, ConfigResource} +import org.apache.kafka.common.config.SslConfigs._ +import org.apache.kafka.common.config.types.Password +import org.apache.kafka.common.config.provider.FileConfigProvider +import org.apache.kafka.common.errors.{AuthenticationException, InvalidRequestException} +import org.apache.kafka.common.internals.Topic +import org.apache.kafka.common.metrics.Quota +import org.apache.kafka.common.metrics.{KafkaMetric, MetricsReporter} +import org.apache.kafka.common.network.{ListenerName, Mode} +import org.apache.kafka.common.network.CertStores.{KEYSTORE_PROPS, TRUSTSTORE_PROPS} +import org.apache.kafka.common.record.TimestampType +import org.apache.kafka.common.security.auth.SecurityProtocol +import org.apache.kafka.common.security.scram.ScramCredential +import org.apache.kafka.common.serialization.{StringDeserializer, StringSerializer} +import org.apache.kafka.test.{TestSslUtils, TestUtils => JTestUtils} +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.{AfterEach, BeforeEach, Disabled, Test, TestInfo} + +import scala.annotation.nowarn +import scala.collection._ +import scala.collection.mutable.ArrayBuffer +import scala.jdk.CollectionConverters._ +import scala.collection.Seq + +object DynamicBrokerReconfigurationTest { + val SecureInternal = "INTERNAL" + val SecureExternal = "EXTERNAL" +} + +class DynamicBrokerReconfigurationTest extends QuorumTestHarness with SaslSetup { + + import DynamicBrokerReconfigurationTest._ + + private val servers = new ArrayBuffer[KafkaServer] + private val numServers = 3 + private val numPartitions = 10 + private val producers = new ArrayBuffer[KafkaProducer[String, String]] + private val consumers = new ArrayBuffer[KafkaConsumer[String, String]] + private val adminClients = new ArrayBuffer[Admin]() + private val clientThreads = new ArrayBuffer[ShutdownableThread]() + private val executors = new ArrayBuffer[ExecutorService] + private val topic = "testtopic" + + private val kafkaClientSaslMechanism = "PLAIN" + private val 
kafkaServerSaslMechanisms = List("PLAIN") + + private val trustStoreFile1 = File.createTempFile("truststore", ".jks") + private val trustStoreFile2 = File.createTempFile("truststore", ".jks") + private val sslProperties1 = TestUtils.sslConfigs(Mode.SERVER, clientCert = false, Some(trustStoreFile1), "kafka") + private val sslProperties2 = TestUtils.sslConfigs(Mode.SERVER, clientCert = false, Some(trustStoreFile2), "kafka") + private val invalidSslProperties = invalidSslConfigs + + def addExtraProps(props: Properties): Unit = { + } + + @BeforeEach + override def setUp(testInfo: TestInfo): Unit = { + startSasl(jaasSections(kafkaServerSaslMechanisms, Some(kafkaClientSaslMechanism))) + super.setUp(testInfo) + + clearLeftOverProcessorMetrics() // clear metrics left over from other tests so that new ones can be tested + + (0 until numServers).foreach { brokerId => + + val props = TestUtils.createBrokerConfig(brokerId, zkConnect) + props ++= securityProps(sslProperties1, TRUSTSTORE_PROPS) + // Ensure that we can support multiple listeners per security protocol and multiple security protocols + props.put(KafkaConfig.ListenersProp, s"$SecureInternal://localhost:0, $SecureExternal://localhost:0") + props.put(KafkaConfig.ListenerSecurityProtocolMapProp, s"$SecureInternal:SSL, $SecureExternal:SASL_SSL") + props.put(KafkaConfig.InterBrokerListenerNameProp, SecureInternal) + props.put(KafkaConfig.SslClientAuthProp, "requested") + props.put(KafkaConfig.SaslMechanismInterBrokerProtocolProp, "PLAIN") + props.put(KafkaConfig.ZkEnableSecureAclsProp, "true") + props.put(KafkaConfig.SaslEnabledMechanismsProp, kafkaServerSaslMechanisms.mkString(",")) + props.put(KafkaConfig.LogSegmentBytesProp, "2000") // low value to test log rolling on config update + props.put(KafkaConfig.NumReplicaFetchersProp, "2") // greater than one to test reducing threads + props.put(KafkaConfig.PasswordEncoderSecretProp, "dynamic-config-secret") + props.put(KafkaConfig.LogRetentionTimeMillisProp, 1680000000.toString) + props.put(KafkaConfig.LogRetentionTimeHoursProp, 168.toString) + addExtraProps(props) + + props ++= sslProperties1 + props ++= securityProps(sslProperties1, KEYSTORE_PROPS, listenerPrefix(SecureInternal)) + + // Set invalid top-level properties to ensure that listener config is used + // Don't set any dynamic configs here since they get overridden in tests + props ++= invalidSslProperties + props ++= securityProps(invalidSslProperties, KEYSTORE_PROPS, "") + props ++= securityProps(sslProperties1, KEYSTORE_PROPS, listenerPrefix(SecureExternal)) + + val kafkaConfig = KafkaConfig.fromProps(props) + configureDynamicKeystoreInZooKeeper(kafkaConfig, sslProperties1) + + servers += TestUtils.createServer(kafkaConfig) + } + + TestUtils.createTopic(zkClient, topic, numPartitions, replicationFactor = numServers, servers) + TestUtils.createTopic(zkClient, Topic.GROUP_METADATA_TOPIC_NAME, servers.head.config.offsetsTopicPartitions, + replicationFactor = numServers, servers, servers.head.groupCoordinator.offsetsTopicConfigs) + + createAdminClient(SecurityProtocol.SSL, SecureInternal) + + TestMetricsReporter.testReporters.clear() + } + + @AfterEach + override def tearDown(): Unit = { + clientThreads.foreach(_.interrupt()) + clientThreads.foreach(_.initiateShutdown()) + clientThreads.foreach(_.join(5 * 1000)) + executors.foreach(_.shutdownNow()) + producers.foreach(_.close(Duration.ZERO)) + consumers.foreach(_.close(Duration.ofMillis(0))) + adminClients.foreach(_.close()) + TestUtils.shutdownServers(servers) + super.tearDown() + 
closeSasl() + } + + @Test + def testConfigDescribeUsingAdminClient(): Unit = { + + def verifyConfig(configName: String, configEntry: ConfigEntry, isSensitive: Boolean, isReadOnly: Boolean, + expectedProps: Properties): Unit = { + if (isSensitive) { + assertTrue(configEntry.isSensitive, s"Value is sensitive: $configName") + assertNull(configEntry.value, s"Sensitive value returned for $configName") + } else { + assertFalse(configEntry.isSensitive, s"Config is not sensitive: $configName") + assertEquals(expectedProps.getProperty(configName), configEntry.value) + } + assertEquals(isReadOnly, configEntry.isReadOnly, s"isReadOnly incorrect for $configName: $configEntry") + } + + def verifySynonym(configName: String, synonym: ConfigSynonym, isSensitive: Boolean, + expectedPrefix: String, expectedSource: ConfigSource, expectedProps: Properties): Unit = { + if (isSensitive) + assertNull(synonym.value, s"Sensitive value returned for $configName") + else + assertEquals(expectedProps.getProperty(configName), synonym.value) + assertTrue(synonym.name.startsWith(expectedPrefix), s"Expected listener config, got $synonym") + assertEquals(expectedSource, synonym.source) + } + + def verifySynonyms(configName: String, synonyms: util.List[ConfigSynonym], isSensitive: Boolean, + prefix: String, defaultValue: Option[String]): Unit = { + + val overrideCount = if (prefix.isEmpty) 0 else 2 + assertEquals(1 + overrideCount + defaultValue.size, synonyms.size, s"Wrong synonyms for $configName: $synonyms") + if (overrideCount > 0) { + val listenerPrefix = "listener.name.external.ssl." + verifySynonym(configName, synonyms.get(0), isSensitive, listenerPrefix, ConfigSource.DYNAMIC_BROKER_CONFIG, sslProperties1) + verifySynonym(configName, synonyms.get(1), isSensitive, listenerPrefix, ConfigSource.STATIC_BROKER_CONFIG, sslProperties1) + } + verifySynonym(configName, synonyms.get(overrideCount), isSensitive, "ssl.", ConfigSource.STATIC_BROKER_CONFIG, invalidSslProperties) + defaultValue.foreach { value => + val defaultProps = new Properties + defaultProps.setProperty(configName, value) + verifySynonym(configName, synonyms.get(overrideCount + 1), isSensitive, "ssl.", ConfigSource.DEFAULT_CONFIG, defaultProps) + } + } + + def verifySslConfig(prefix: String, expectedProps: Properties, configDesc: Config): Unit = { + // Validate file-based SSL keystore configs + val keyStoreProps = new util.HashSet[String](KEYSTORE_PROPS) + keyStoreProps.remove(SSL_KEYSTORE_KEY_CONFIG) + keyStoreProps.remove(SSL_KEYSTORE_CERTIFICATE_CHAIN_CONFIG) + keyStoreProps.forEach { configName => + val desc = configEntry(configDesc, s"$prefix$configName") + val isSensitive = configName.contains("password") + verifyConfig(configName, desc, isSensitive, isReadOnly = prefix.nonEmpty, if (prefix.isEmpty) invalidSslProperties else sslProperties1) + val defaultValue = if (configName == SSL_KEYSTORE_TYPE_CONFIG) Some("JKS") else None + verifySynonyms(configName, desc.synonyms, isSensitive, prefix, defaultValue) + } + } + + val adminClient = adminClients.head + alterSslKeystoreUsingConfigCommand(sslProperties1, SecureExternal) + + val configDesc = describeConfig(adminClient) + verifySslConfig("listener.name.external.", sslProperties1, configDesc) + verifySslConfig("", invalidSslProperties, configDesc) + + // Verify a few log configs with and without synonyms + val expectedProps = new Properties + expectedProps.setProperty(KafkaConfig.LogRetentionTimeMillisProp, "1680000000") + expectedProps.setProperty(KafkaConfig.LogRetentionTimeHoursProp, "168") + 
expectedProps.setProperty(KafkaConfig.LogRollTimeHoursProp, "168") + expectedProps.setProperty(KafkaConfig.LogCleanerThreadsProp, "1") + val logRetentionMs = configEntry(configDesc, KafkaConfig.LogRetentionTimeMillisProp) + verifyConfig(KafkaConfig.LogRetentionTimeMillisProp, logRetentionMs, + isSensitive = false, isReadOnly = false, expectedProps) + val logRetentionHours = configEntry(configDesc, KafkaConfig.LogRetentionTimeHoursProp) + verifyConfig(KafkaConfig.LogRetentionTimeHoursProp, logRetentionHours, + isSensitive = false, isReadOnly = true, expectedProps) + val logRollHours = configEntry(configDesc, KafkaConfig.LogRollTimeHoursProp) + verifyConfig(KafkaConfig.LogRollTimeHoursProp, logRollHours, + isSensitive = false, isReadOnly = true, expectedProps) + val logCleanerThreads = configEntry(configDesc, KafkaConfig.LogCleanerThreadsProp) + verifyConfig(KafkaConfig.LogCleanerThreadsProp, logCleanerThreads, + isSensitive = false, isReadOnly = false, expectedProps) + + def synonymsList(configEntry: ConfigEntry): List[(String, ConfigSource)] = + configEntry.synonyms.asScala.map(s => (s.name, s.source)).toList + assertEquals(List((KafkaConfig.LogRetentionTimeMillisProp, ConfigSource.STATIC_BROKER_CONFIG), + (KafkaConfig.LogRetentionTimeHoursProp, ConfigSource.STATIC_BROKER_CONFIG), + (KafkaConfig.LogRetentionTimeHoursProp, ConfigSource.DEFAULT_CONFIG)), + synonymsList(logRetentionMs)) + assertEquals(List((KafkaConfig.LogRetentionTimeHoursProp, ConfigSource.STATIC_BROKER_CONFIG), + (KafkaConfig.LogRetentionTimeHoursProp, ConfigSource.DEFAULT_CONFIG)), + synonymsList(logRetentionHours)) + assertEquals(List((KafkaConfig.LogRollTimeHoursProp, ConfigSource.DEFAULT_CONFIG)), synonymsList(logRollHours)) + assertEquals(List((KafkaConfig.LogCleanerThreadsProp, ConfigSource.DEFAULT_CONFIG)), synonymsList(logCleanerThreads)) + } + + @Test + def testUpdatesUsingConfigProvider(): Unit = { + val PollingIntervalVal = f"$${file:polling.interval:interval}" + val PollingIntervalUpdateVal = f"$${file:polling.interval:updinterval}" + val SslTruststoreTypeVal = f"$${file:ssl.truststore.type:storetype}" + val SslKeystorePasswordVal = f"$${file:ssl.keystore.password:password}" + + val configPrefix = listenerPrefix(SecureExternal) + val brokerConfigs = describeConfig(adminClients.head, servers).entries.asScala + // the following are the values before the update + assertFalse(brokerConfigs.exists(_.name == TestMetricsReporter.PollingIntervalProp), "Initial value of polling interval") + assertFalse(brokerConfigs.exists(_.name == configPrefix + KafkaConfig.SslTruststoreTypeProp), "Initial value of ssl truststore type") + assertNull(brokerConfigs.find(_.name == configPrefix+KafkaConfig.SslKeystorePasswordProp).get.value, "Initial value of ssl keystore password") + + // setup ssl properties + val secProps = securityProps(sslProperties1, KEYSTORE_PROPS, configPrefix) + + // configure config providers and the properties that need to be updated + val updatedProps = new Properties + updatedProps.setProperty("config.providers", "file") + updatedProps.setProperty("config.providers.file.class", "kafka.server.MockFileConfigProvider") + updatedProps.put(KafkaConfig.MetricReporterClassesProp, classOf[TestMetricsReporter].getName) + + // 1. update Integer property using config provider + updatedProps.put(TestMetricsReporter.PollingIntervalProp, PollingIntervalVal) + + // 2. 
update String property using config provider + updatedProps.put(configPrefix+KafkaConfig.SslTruststoreTypeProp, SslTruststoreTypeVal) + + // merge two properties + updatedProps ++= secProps + + // 3. update password property using config provider + updatedProps.put(configPrefix+KafkaConfig.SslKeystorePasswordProp, SslKeystorePasswordVal) + + alterConfigsUsingConfigCommand(updatedProps) + waitForConfig(TestMetricsReporter.PollingIntervalProp, "1000") + waitForConfig(configPrefix+KafkaConfig.SslTruststoreTypeProp, "JKS") + waitForConfig(configPrefix+KafkaConfig.SslKeystorePasswordProp, "ServerPassword") + + // wait for MetricsReporter + val reporters = TestMetricsReporter.waitForReporters(servers.size) + reporters.foreach { reporter => + reporter.verifyState(reconfigureCount = 0, deleteCount = 0, pollingInterval = 1000) + assertFalse(reporter.kafkaMetrics.isEmpty, "No metrics found") + } + + // fetch from ZK, values should be unresolved + val props = fetchBrokerConfigsFromZooKeeper(servers.head) + assertTrue(props.getProperty(TestMetricsReporter.PollingIntervalProp) == PollingIntervalVal, "polling interval is not updated in ZK") + assertTrue(props.getProperty(configPrefix+KafkaConfig.SslTruststoreTypeProp) == SslTruststoreTypeVal, "store type is not updated in ZK") + assertTrue(props.getProperty(configPrefix+KafkaConfig.SslKeystorePasswordProp) == SslKeystorePasswordVal, "keystore password is not updated in ZK") + + // verify the update + // 1. verify update not occurring if the value of property is same. + alterConfigsUsingConfigCommand(updatedProps) + waitForConfig(TestMetricsReporter.PollingIntervalProp, "1000") + reporters.foreach { reporter => + reporter.verifyState(reconfigureCount = 0, deleteCount = 0, pollingInterval = 1000) + } + + // 2. verify update occurring if the value of property changed. + updatedProps.put(TestMetricsReporter.PollingIntervalProp, PollingIntervalUpdateVal) + alterConfigsUsingConfigCommand(updatedProps) + waitForConfig(TestMetricsReporter.PollingIntervalProp, "2000") + reporters.foreach { reporter => + reporter.verifyState(reconfigureCount = 1, deleteCount = 0, pollingInterval = 2000) + } + } + + @Test + def testKeyStoreAlter(): Unit = { + val topic2 = "testtopic2" + TestUtils.createTopic(zkClient, topic2, numPartitions, replicationFactor = numServers, servers) + + // Start a producer and consumer that work with the current broker keystore. 
+ // This should continue working while changes are made + val (producerThread, consumerThread) = startProduceConsume(retries = 0) + TestUtils.waitUntilTrue(() => consumerThread.received >= 10, "Messages not received") + + // Producer with new truststore should fail to connect before keystore update + val producer1 = ProducerBuilder().trustStoreProps(sslProperties2).maxRetries(0).build() + verifyAuthenticationFailure(producer1) + + // Update broker keystore for external listener + alterSslKeystoreUsingConfigCommand(sslProperties2, SecureExternal) + + // New producer with old truststore should fail to connect + val producer2 = ProducerBuilder().trustStoreProps(sslProperties1).maxRetries(0).build() + verifyAuthenticationFailure(producer2) + + // Produce/consume should work with new truststore with new producer/consumer + val producer = ProducerBuilder().trustStoreProps(sslProperties2).maxRetries(0).build() + val consumer = ConsumerBuilder("group1").trustStoreProps(sslProperties2).topic(topic2).build() + verifyProduceConsume(producer, consumer, 10, topic2) + + // Broker keystore update for internal listener with incompatible keystore should fail without update + val adminClient = adminClients.head + alterSslKeystore(adminClient, sslProperties2, SecureInternal, expectFailure = true) + verifyProduceConsume(producer, consumer, 10, topic2) + + // Broker keystore update for internal listener with compatible keystore should succeed + val sslPropertiesCopy = sslProperties1.clone().asInstanceOf[Properties] + val oldFile = new File(sslProperties1.getProperty(SSL_KEYSTORE_LOCATION_CONFIG)) + val newFile = File.createTempFile("keystore", ".jks") + Files.copy(oldFile.toPath, newFile.toPath, StandardCopyOption.REPLACE_EXISTING) + sslPropertiesCopy.setProperty(SSL_KEYSTORE_LOCATION_CONFIG, newFile.getPath) + alterSslKeystore(adminClient, sslPropertiesCopy, SecureInternal) + verifyProduceConsume(producer, consumer, 10, topic2) + + // Verify that keystores can be updated using same file name. + val reusableProps = sslProperties2.clone().asInstanceOf[Properties] + val reusableFile = File.createTempFile("keystore", ".jks") + reusableProps.setProperty(SSL_KEYSTORE_LOCATION_CONFIG, reusableFile.getPath) + Files.copy(new File(sslProperties1.getProperty(SSL_KEYSTORE_LOCATION_CONFIG)).toPath, + reusableFile.toPath, StandardCopyOption.REPLACE_EXISTING) + alterSslKeystore(adminClient, reusableProps, SecureExternal) + val producer3 = ProducerBuilder().trustStoreProps(sslProperties2).maxRetries(0).build() + verifyAuthenticationFailure(producer3) + // Now alter using same file name. We can't check if the update has completed by comparing config on + // the broker, so we wait for producer operation to succeed to verify that the update has been performed. 
+ Files.copy(new File(sslProperties2.getProperty(SSL_KEYSTORE_LOCATION_CONFIG)).toPath, + reusableFile.toPath, StandardCopyOption.REPLACE_EXISTING) + reusableFile.setLastModified(System.currentTimeMillis() + 1000) + alterSslKeystore(adminClient, reusableProps, SecureExternal) + TestUtils.waitUntilTrue(() => { + try { + producer3.partitionsFor(topic).size() == numPartitions + } catch { + case _: Exception => false + } + }, "Keystore not updated") + + // Verify that all messages sent with retries=0 while keystores were being altered were consumed + stopAndVerifyProduceConsume(producerThread, consumerThread) + } + + @Test + def testTrustStoreAlter(): Unit = { + val producerBuilder = ProducerBuilder().listenerName(SecureInternal).securityProtocol(SecurityProtocol.SSL) + + // Producer with new keystore should fail to connect before truststore update + verifyAuthenticationFailure(producerBuilder.keyStoreProps(sslProperties2).build()) + + // Update broker truststore for SSL listener with both certificates + val combinedStoreProps = mergeTrustStores(sslProperties1, sslProperties2) + val prefix = listenerPrefix(SecureInternal) + val existingDynamicProps = new Properties + servers.head.config.dynamicConfig.currentDynamicBrokerConfigs.foreach { case (k, v) => + existingDynamicProps.put(k, v) + } + val newProps = new Properties + newProps ++= existingDynamicProps + newProps ++= securityProps(combinedStoreProps, TRUSTSTORE_PROPS, prefix) + reconfigureServers(newProps, perBrokerConfig = true, + (s"$prefix$SSL_TRUSTSTORE_LOCATION_CONFIG", combinedStoreProps.getProperty(SSL_TRUSTSTORE_LOCATION_CONFIG))) + + def verifySslProduceConsume(keyStoreProps: Properties, group: String): Unit = { + val producer = producerBuilder.keyStoreProps(keyStoreProps).build() + val consumer = ConsumerBuilder(group) + .listenerName(SecureInternal) + .securityProtocol(SecurityProtocol.SSL) + .keyStoreProps(keyStoreProps) + .autoOffsetReset("latest") + .build() + verifyProduceConsume(producer, consumer, 10, topic) + } + + // Produce/consume should work with old as well as new client keystore + verifySslProduceConsume(sslProperties1, "alter-truststore-1") + verifySslProduceConsume(sslProperties2, "alter-truststore-2") + + // Revert to old truststore with only one certificate and update. Clients should connect only with old keystore. + val oldTruststoreProps = new Properties + oldTruststoreProps ++= existingDynamicProps + oldTruststoreProps ++= securityProps(sslProperties1, TRUSTSTORE_PROPS, prefix) + reconfigureServers(oldTruststoreProps, perBrokerConfig = true, + (s"$prefix$SSL_TRUSTSTORE_LOCATION_CONFIG", sslProperties1.getProperty(SSL_TRUSTSTORE_LOCATION_CONFIG))) + verifyAuthenticationFailure(producerBuilder.keyStoreProps(sslProperties2).build()) + verifySslProduceConsume(sslProperties1, "alter-truststore-3") + + // Update same truststore file to contain both certificates without changing any configs. + // Clients should connect successfully with either keystore after admin client AlterConfigsRequest completes. 
+ Files.copy(Paths.get(combinedStoreProps.getProperty(SSL_TRUSTSTORE_LOCATION_CONFIG)), + Paths.get(sslProperties1.getProperty(SSL_TRUSTSTORE_LOCATION_CONFIG)), + StandardCopyOption.REPLACE_EXISTING) + TestUtils.incrementalAlterConfigs(servers, adminClients.head, oldTruststoreProps, perBrokerConfig = true).all.get() + verifySslProduceConsume(sslProperties1, "alter-truststore-4") + verifySslProduceConsume(sslProperties2, "alter-truststore-5") + + // Update internal keystore/truststore and validate new client connections from broker (e.g. controller). + // Alter internal keystore from `sslProperties1` to `sslProperties2`, force disconnect of a controller connection + // and verify that metadata is propagated for new topic. + val props2 = securityProps(sslProperties2, KEYSTORE_PROPS, prefix) + props2 ++= securityProps(combinedStoreProps, TRUSTSTORE_PROPS, prefix) + TestUtils.incrementalAlterConfigs(servers, adminClients.head, props2, perBrokerConfig = true).all.get(15, TimeUnit.SECONDS) + verifySslProduceConsume(sslProperties2, "alter-truststore-6") + props2 ++= securityProps(sslProperties2, TRUSTSTORE_PROPS, prefix) + TestUtils.incrementalAlterConfigs(servers, adminClients.head, props2, perBrokerConfig = true).all.get(15, TimeUnit.SECONDS) + verifySslProduceConsume(sslProperties2, "alter-truststore-7") + waitForAuthenticationFailure(producerBuilder.keyStoreProps(sslProperties1)) + + val controller = servers.find(_.config.brokerId == TestUtils.waitUntilControllerElected(zkClient)).get + val controllerChannelManager = controller.kafkaController.controllerChannelManager + val brokerStateInfo: mutable.HashMap[Int, ControllerBrokerStateInfo] = + JTestUtils.fieldValue(controllerChannelManager, classOf[ControllerChannelManager], "brokerStateInfo") + brokerStateInfo(0).networkClient.disconnect("0") + TestUtils.createTopic(zkClient, "testtopic2", numPartitions, replicationFactor = numServers, servers) + } + + @Test + def testLogCleanerConfig(): Unit = { + val (producerThread, consumerThread) = startProduceConsume(retries = 0) + + verifyThreads("kafka-log-cleaner-thread-", countPerBroker = 1) + + val props = new Properties + props.put(KafkaConfig.LogCleanerThreadsProp, "2") + props.put(KafkaConfig.LogCleanerDedupeBufferSizeProp, "20000000") + props.put(KafkaConfig.LogCleanerDedupeBufferLoadFactorProp, "0.8") + props.put(KafkaConfig.LogCleanerIoBufferSizeProp, "300000") + props.put(KafkaConfig.MessageMaxBytesProp, "40000") + props.put(KafkaConfig.LogCleanerIoMaxBytesPerSecondProp, "50000000") + props.put(KafkaConfig.LogCleanerBackoffMsProp, "6000") + reconfigureServers(props, perBrokerConfig = false, (KafkaConfig.LogCleanerThreadsProp, "2")) + + // Verify cleaner config was updated. 
Wait for one of the configs to be updated and verify + // that all the others were updated at the same time since they are reconfigured together + val newCleanerConfig = servers.head.logManager.cleaner.currentConfig + TestUtils.waitUntilTrue(() => newCleanerConfig.numThreads == 2, "Log cleaner not reconfigured") + assertEquals(20000000, newCleanerConfig.dedupeBufferSize) + assertEquals(0.8, newCleanerConfig.dedupeBufferLoadFactor, 0.001) + assertEquals(300000, newCleanerConfig.ioBufferSize) + assertEquals(40000, newCleanerConfig.maxMessageSize) + assertEquals(50000000, newCleanerConfig.maxIoBytesPerSecond, 50000000) + assertEquals(6000, newCleanerConfig.backOffMs) + + // Verify thread count + verifyThreads("kafka-log-cleaner-thread-", countPerBroker = 2) + + // Stop a couple of threads and verify they are recreated if any config is updated + def cleanerThreads = Thread.getAllStackTraces.keySet.asScala.filter(_.getName.startsWith("kafka-log-cleaner-thread-")) + cleanerThreads.take(2).foreach(_.interrupt()) + TestUtils.waitUntilTrue(() => cleanerThreads.size == (2 * numServers) - 2, "Threads did not exit") + props.put(KafkaConfig.LogCleanerBackoffMsProp, "8000") + reconfigureServers(props, perBrokerConfig = false, (KafkaConfig.LogCleanerBackoffMsProp, "8000")) + verifyThreads("kafka-log-cleaner-thread-", countPerBroker = 2) + + // Verify that produce/consume worked throughout this test without any retries in producer + stopAndVerifyProduceConsume(producerThread, consumerThread) + } + + @Test + def testConsecutiveConfigChange(): Unit = { + val topic2 = "testtopic2" + val topicProps = new Properties + topicProps.put(KafkaConfig.MinInSyncReplicasProp, "2") + TestUtils.createTopic(zkClient, topic2, 1, replicationFactor = numServers, servers, topicProps) + var log = servers.head.logManager.getLog(new TopicPartition(topic2, 0)).getOrElse(throw new IllegalStateException("Log not found")) + assertTrue(log.config.overriddenConfigs.contains(KafkaConfig.MinInSyncReplicasProp)) + assertEquals("2", log.config.originals().get(KafkaConfig.MinInSyncReplicasProp).toString) + + val props = new Properties + props.put(KafkaConfig.MinInSyncReplicasProp, "3") + // Make a broker-default config + reconfigureServers(props, perBrokerConfig = false, (KafkaConfig.MinInSyncReplicasProp, "3")) + // Verify that all broker defaults have been updated again + servers.foreach { server => + props.forEach { (k, v) => + assertEquals(v, server.config.originals.get(k).toString, s"Not reconfigured $k") + } + } + + log = servers.head.logManager.getLog(new TopicPartition(topic2, 0)).getOrElse(throw new IllegalStateException("Log not found")) + assertTrue(log.config.overriddenConfigs.contains(KafkaConfig.MinInSyncReplicasProp)) + assertEquals("2", log.config.originals().get(KafkaConfig.MinInSyncReplicasProp).toString) // Verify topic-level config survives + + // Make a second broker-default change + props.clear() + props.put(KafkaConfig.LogRetentionTimeMillisProp, "604800000") + reconfigureServers(props, perBrokerConfig = false, (KafkaConfig.LogRetentionTimeMillisProp, "604800000")) + log = servers.head.logManager.getLog(new TopicPartition(topic2, 0)).getOrElse(throw new IllegalStateException("Log not found")) + assertTrue(log.config.overriddenConfigs.contains(KafkaConfig.MinInSyncReplicasProp)) + assertEquals("2", log.config.originals().get(KafkaConfig.MinInSyncReplicasProp).toString) // Verify topic-level config still survives + } + + @Test + def testDefaultTopicConfig(): Unit = { + val (producerThread, consumerThread) = 
startProduceConsume(retries = 0) + + val props = new Properties + props.put(KafkaConfig.LogSegmentBytesProp, "4000") + props.put(KafkaConfig.LogRollTimeMillisProp, TimeUnit.HOURS.toMillis(2).toString) + props.put(KafkaConfig.LogRollTimeJitterMillisProp, TimeUnit.HOURS.toMillis(1).toString) + props.put(KafkaConfig.LogIndexSizeMaxBytesProp, "100000") + props.put(KafkaConfig.LogFlushIntervalMessagesProp, "1000") + props.put(KafkaConfig.LogFlushIntervalMsProp, "60000") + props.put(KafkaConfig.LogRetentionBytesProp, "10000000") + props.put(KafkaConfig.LogRetentionTimeMillisProp, TimeUnit.DAYS.toMillis(1).toString) + props.put(KafkaConfig.MessageMaxBytesProp, "100000") + props.put(KafkaConfig.LogIndexIntervalBytesProp, "10000") + props.put(KafkaConfig.LogCleanerDeleteRetentionMsProp, TimeUnit.DAYS.toMillis(1).toString) + props.put(KafkaConfig.LogCleanerMinCompactionLagMsProp, "60000") + props.put(KafkaConfig.LogDeleteDelayMsProp, "60000") + props.put(KafkaConfig.LogCleanerMinCleanRatioProp, "0.3") + props.put(KafkaConfig.LogCleanupPolicyProp, "delete") + props.put(KafkaConfig.UncleanLeaderElectionEnableProp, "false") + props.put(KafkaConfig.MinInSyncReplicasProp, "2") + props.put(KafkaConfig.CompressionTypeProp, "gzip") + props.put(KafkaConfig.LogPreAllocateProp, true.toString) + props.put(KafkaConfig.LogMessageTimestampTypeProp, TimestampType.LOG_APPEND_TIME.toString) + props.put(KafkaConfig.LogMessageTimestampDifferenceMaxMsProp, "1000") + props.put(KafkaConfig.LogMessageDownConversionEnableProp, "false") + reconfigureServers(props, perBrokerConfig = false, (KafkaConfig.LogSegmentBytesProp, "4000")) + + // Verify that all broker defaults have been updated + servers.foreach { server => + props.forEach { (k, v) => + assertEquals(server.config.originals.get(k).toString, v, s"Not reconfigured $k") + } + } + + // Verify that configs of existing logs have been updated + val newLogConfig = LogConfig(LogConfig.extractLogConfigMap(servers.head.config)) + TestUtils.waitUntilTrue(() => servers.head.logManager.currentDefaultConfig == newLogConfig, + "Config not updated in LogManager") + + val log = servers.head.logManager.getLog(new TopicPartition(topic, 0)).getOrElse(throw new IllegalStateException("Log not found")) + TestUtils.waitUntilTrue(() => log.config.segmentSize == 4000, "Existing topic config using defaults not updated") + props.asScala.foreach { case (k, v) => + val logConfigName = DynamicLogConfig.KafkaConfigToLogConfigName(k) + val expectedValue = if (k == KafkaConfig.LogCleanupPolicyProp) s"[$v]" else v + assertEquals(expectedValue, log.config.originals.get(logConfigName).toString, + s"Not reconfigured $logConfigName for existing log") + } + consumerThread.waitForMatchingRecords(record => record.timestampType == TimestampType.LOG_APPEND_TIME) + + // Verify that the new config is actually used for new segments of existing logs + TestUtils.waitUntilTrue(() => log.logSegments.exists(_.size > 3000), "Log segment size increase not applied") + + // Verify that overridden topic configs are not updated when broker default is updated + val log2 = servers.head.logManager.getLog(new TopicPartition(Topic.GROUP_METADATA_TOPIC_NAME, 0)) + .getOrElse(throw new IllegalStateException("Log not found")) + assertFalse(log2.config.delete, "Overridden clean up policy should not be updated") + assertEquals(ProducerCompressionCodec.name, log2.config.compressionType) + + // Verify that we can alter subset of log configs + props.clear() + props.put(KafkaConfig.LogMessageTimestampTypeProp, 
TimestampType.CREATE_TIME.toString) + props.put(KafkaConfig.LogMessageTimestampDifferenceMaxMsProp, "1000") + reconfigureServers(props, perBrokerConfig = false, (KafkaConfig.LogMessageTimestampTypeProp, TimestampType.CREATE_TIME.toString)) + consumerThread.waitForMatchingRecords(record => record.timestampType == TimestampType.CREATE_TIME) + // Verify that invalid configs are not applied + val invalidProps = Map( + KafkaConfig.LogMessageTimestampDifferenceMaxMsProp -> "abc", // Invalid type + KafkaConfig.LogMessageTimestampTypeProp -> "invalid", // Invalid value + KafkaConfig.LogRollTimeMillisProp -> "0" // Fails KafkaConfig validation + ) + invalidProps.foreach { case (k, v) => + val newProps = new Properties + newProps ++= props + props.put(k, v) + reconfigureServers(props, perBrokerConfig = false, (k, props.getProperty(k)), expectFailure = true) + } + + // Verify that even though broker defaults can be defined at default cluster level for consistent + // configuration across brokers, they can also be defined at per-broker level for testing + props.clear() + props.put(KafkaConfig.LogIndexSizeMaxBytesProp, "500000") + props.put(KafkaConfig.LogRetentionTimeMillisProp, TimeUnit.DAYS.toMillis(2).toString) + alterConfigsOnServer(servers.head, props) + assertEquals(500000, servers.head.config.values.get(KafkaConfig.LogIndexSizeMaxBytesProp)) + assertEquals(TimeUnit.DAYS.toMillis(2), servers.head.config.values.get(KafkaConfig.LogRetentionTimeMillisProp)) + servers.tail.foreach { server => + assertEquals(Defaults.LogIndexSizeMaxBytes, server.config.values.get(KafkaConfig.LogIndexSizeMaxBytesProp)) + assertEquals(1680000000L, server.config.values.get(KafkaConfig.LogRetentionTimeMillisProp)) + } + + // Verify that produce/consume worked throughout this test without any retries in producer + stopAndVerifyProduceConsume(producerThread, consumerThread) + + // Verify that configuration at both per-broker level and default cluster level could be deleted and + // the default value should be restored + props.clear() + props.put(KafkaConfig.LogRetentionTimeMillisProp, "") + props.put(KafkaConfig.LogIndexSizeMaxBytesProp, "") + TestUtils.incrementalAlterConfigs(servers.take(1), adminClients.head, props, perBrokerConfig = true, opType = OpType.DELETE).all.get + TestUtils.incrementalAlterConfigs(servers, adminClients.head, props, perBrokerConfig = false, opType = OpType.DELETE).all.get + servers.foreach { server => + waitForConfigOnServer(server, KafkaConfig.LogRetentionTimeMillisProp, 1680000000.toString) + } + servers.foreach { server => + val log = server.logManager.getLog(new TopicPartition(topic, 0)).getOrElse(throw new IllegalStateException("Log not found")) + // Verify default values for these two configurations are restored on all brokers + TestUtils.waitUntilTrue(() => log.config.maxIndexSize == Defaults.LogIndexSizeMaxBytes && log.config.retentionMs == 1680000000L, + "Existing topic config using defaults not updated") + } + } + + @Test + def testUncleanLeaderElectionEnable(): Unit = { + val controller = servers.find(_.config.brokerId == TestUtils.waitUntilControllerElected(zkClient)).get + val controllerId = controller.config.brokerId + + // Create a topic with two replicas on brokers other than the controller + val topic = "testtopic2" + val assignment = Map(0 -> Seq((controllerId + 1) % servers.size, (controllerId + 2) % servers.size)) + TestUtils.createTopic(zkClient, topic, assignment, servers) + + val producer = ProducerBuilder().acks(1).build() + val consumer = 
ConsumerBuilder("unclean-leader-test").enableAutoCommit(false).topic(topic).build() + verifyProduceConsume(producer, consumer, numRecords = 10, topic) + consumer.commitSync() + + def partitionInfo: TopicPartitionInfo = + adminClients.head.describeTopics(Collections.singleton(topic)).topicNameValues().get(topic).get().partitions().get(0) + + val partitionInfo0 = partitionInfo + assertEquals(partitionInfo0.replicas.get(0), partitionInfo0.leader) + val leaderBroker = servers.find(_.config.brokerId == partitionInfo0.replicas.get(0).id).get + val followerBroker = servers.find(_.config.brokerId == partitionInfo0.replicas.get(1).id).get + + // Stop follower + followerBroker.shutdown() + followerBroker.awaitShutdown() + + // Produce and consume some messages when the only follower is down, this should succeed since MinIsr is 1 + verifyProduceConsume(producer, consumer, numRecords = 10, topic) + consumer.commitSync() + + // Shutdown leader and startup follower + leaderBroker.shutdown() + leaderBroker.awaitShutdown() + followerBroker.startup() + + // Verify that new leader is not elected with unclean leader disabled since there are no ISRs + TestUtils.waitUntilTrue(() => partitionInfo.leader == null, "Unclean leader elected") + + // Enable unclean leader election + val newProps = new Properties + newProps.put(KafkaConfig.UncleanLeaderElectionEnableProp, "true") + TestUtils.incrementalAlterConfigs(servers, adminClients.head, newProps, perBrokerConfig = false).all.get + waitForConfigOnServer(controller, KafkaConfig.UncleanLeaderElectionEnableProp, "true") + + // Verify that the old follower with missing records is elected as the new leader + val (newLeader, elected) = TestUtils.computeUntilTrue(partitionInfo.leader)(leader => leader != null) + assertTrue(elected, "Unclean leader not elected") + assertEquals(followerBroker.config.brokerId, newLeader.id) + + // New leader doesn't have the last 10 records committed on the old leader that have already been consumed. + // With unclean leader election enabled, we should be able to produce to the new leader. The first 10 records + // produced will not be consumed since they have offsets less than the consumer's committed offset. + // Next 10 records produced should be consumed. 
+ (1 to 10).map(i => new ProducerRecord(topic, s"key$i", s"value$i")) + .map(producer.send) + .map(_.get(10, TimeUnit.SECONDS)) + verifyProduceConsume(producer, consumer, numRecords = 10, topic) + consumer.commitSync() + } + + @Test + def testThreadPoolResize(): Unit = { + val requestHandlerPrefix = "data-plane-kafka-request-handler-" + val networkThreadPrefix = "data-plane-kafka-network-thread-" + val fetcherThreadPrefix = "ReplicaFetcherThread-" + // Executor threads and recovery threads are not verified since threads may not be running + // For others, thread count should be configuredCount * threadMultiplier * numBrokers + val threadMultiplier = Map( + requestHandlerPrefix -> 1, + networkThreadPrefix -> 2, // 2 endpoints + fetcherThreadPrefix -> (servers.size - 1) + ) + + // Tolerate threads left over from previous tests + def leftOverThreadCount(prefix: String, perBrokerCount: Int): Int = { + val count = matchingThreads(prefix).size - perBrokerCount * servers.size * threadMultiplier(prefix) + if (count > 0) count else 0 + } + + val leftOverThreads = Map( + requestHandlerPrefix -> leftOverThreadCount(requestHandlerPrefix, servers.head.config.numIoThreads), + networkThreadPrefix -> leftOverThreadCount(networkThreadPrefix, servers.head.config.numNetworkThreads), + fetcherThreadPrefix -> leftOverThreadCount(fetcherThreadPrefix, servers.head.config.numReplicaFetchers) + ) + + def maybeVerifyThreadPoolSize(propName: String, size: Int, threadPrefix: String): Unit = { + val ignoreCount = leftOverThreads.getOrElse(threadPrefix, 0) + val expectedCountPerBroker = threadMultiplier.getOrElse(threadPrefix, 0) * size + if (expectedCountPerBroker > 0) + verifyThreads(threadPrefix, expectedCountPerBroker, ignoreCount) + } + + def reducePoolSize(propName: String, currentSize: => Int, threadPrefix: String): Int = { + val newSize = if (currentSize / 2 == 0) 1 else currentSize / 2 + resizeThreadPool(propName, newSize, threadPrefix) + newSize + } + + def increasePoolSize(propName: String, currentSize: => Int, threadPrefix: String): Int = { + val newSize = if (currentSize == 1) currentSize * 2 else currentSize * 2 - 1 + resizeThreadPool(propName, newSize, threadPrefix) + newSize + } + + def resizeThreadPool(propName: String, newSize: Int, threadPrefix: String): Unit = { + val props = new Properties + props.put(propName, newSize.toString) + reconfigureServers(props, perBrokerConfig = false, (propName, newSize.toString)) + maybeVerifyThreadPoolSize(propName, newSize, threadPrefix) + } + + def verifyThreadPoolResize(propName: String, currentSize: => Int, threadPrefix: String, mayReceiveDuplicates: Boolean): Unit = { + maybeVerifyThreadPoolSize(propName, currentSize, threadPrefix) + val numRetries = if (mayReceiveDuplicates) 100 else 0 + val (producerThread, consumerThread) = startProduceConsume(retries = numRetries) + var threadPoolSize = currentSize + (1 to 2).foreach { _ => + threadPoolSize = reducePoolSize(propName, threadPoolSize, threadPrefix) + Thread.sleep(100) + threadPoolSize = increasePoolSize(propName, threadPoolSize, threadPrefix) + Thread.sleep(100) + } + stopAndVerifyProduceConsume(producerThread, consumerThread, mayReceiveDuplicates) + // Verify that all threads are alive + maybeVerifyThreadPoolSize(propName, threadPoolSize, threadPrefix) + } + + val config = servers.head.config + verifyThreadPoolResize(KafkaConfig.NumIoThreadsProp, config.numIoThreads, + requestHandlerPrefix, mayReceiveDuplicates = false) + verifyThreadPoolResize(KafkaConfig.NumReplicaFetchersProp, config.numReplicaFetchers, + 
fetcherThreadPrefix, mayReceiveDuplicates = false) + verifyThreadPoolResize(KafkaConfig.BackgroundThreadsProp, config.backgroundThreads, + "kafka-scheduler-", mayReceiveDuplicates = false) + verifyThreadPoolResize(KafkaConfig.NumRecoveryThreadsPerDataDirProp, config.numRecoveryThreadsPerDataDir, + "", mayReceiveDuplicates = false) + verifyThreadPoolResize(KafkaConfig.NumNetworkThreadsProp, config.numNetworkThreads, + networkThreadPrefix, mayReceiveDuplicates = true) + verifyThreads("data-plane-kafka-socket-acceptor-", config.listeners.size) + + verifyProcessorMetrics() + verifyMarkPartitionsForTruncation() + } + + private def isProcessorMetric(metricName: MetricName): Boolean = { + val mbeanName = metricName.getMBeanName + mbeanName.contains(s"${Processor.NetworkProcessorMetricTag}=") || mbeanName.contains(s"${RequestChannel.ProcessorMetricTag}=") + } + + private def clearLeftOverProcessorMetrics(): Unit = { + val metricsFromOldTests = KafkaYammerMetrics.defaultRegistry.allMetrics.keySet.asScala.filter(isProcessorMetric) + metricsFromOldTests.foreach(KafkaYammerMetrics.defaultRegistry.removeMetric) + } + + // Verify that metrics from processors that were removed have been deleted. + // Since processor ids are not reused, it is sufficient to check metrics count + // based on the current number of processors + private def verifyProcessorMetrics(): Unit = { + val numProcessors = servers.head.config.numNetworkThreads * 2 // 2 listeners + + val kafkaMetrics = servers.head.metrics.metrics().keySet.asScala + .filter(_.tags.containsKey(Processor.NetworkProcessorMetricTag)) + .groupBy(_.tags.get(Processor.NetworkProcessorMetricTag)) + assertEquals(numProcessors, kafkaMetrics.size) + + KafkaYammerMetrics.defaultRegistry.allMetrics.keySet.asScala + .filter(isProcessorMetric) + .groupBy(_.getName) + .foreach { case (name, set) => assertEquals(numProcessors, set.size, s"Metrics not deleted $name") } + } + + // Verify that replicaFetcherManager.markPartitionsForTruncation uses the current fetcher thread size + // to obtain partition assignment + private def verifyMarkPartitionsForTruncation(): Unit = { + val leaderId = 0 + val partitions = (0 until numPartitions).map(i => new TopicPartition(topic, i)).filter { tp => + zkClient.getLeaderForPartition(tp).contains(leaderId) + } + assertTrue(partitions.nonEmpty, s"Partitions not found with leader $leaderId") + partitions.foreach { tp => + (1 to 2).foreach { i => + val replicaFetcherManager = servers(i).replicaManager.replicaFetcherManager + val truncationOffset = tp.partition + replicaFetcherManager.markPartitionsForTruncation(leaderId, tp, truncationOffset) + val fetcherThreads = replicaFetcherManager.fetcherThreadMap.filter(_._2.fetchState(tp).isDefined) + assertEquals(1, fetcherThreads.size) + assertEquals(replicaFetcherManager.getFetcherId(tp), fetcherThreads.head._1.fetcherId) + val thread = fetcherThreads.head._2 + assertEquals(Some(truncationOffset), thread.fetchState(tp).map(_.fetchOffset)) + assertEquals(Some(Truncating), thread.fetchState(tp).map(_.state)) + } + } + } + + @Test + def testMetricsReporterUpdate(): Unit = { + // Add a new metrics reporter + val newProps = new Properties + newProps.put(TestMetricsReporter.PollingIntervalProp, "100") + configureMetricsReporters(Seq(classOf[TestMetricsReporter]), newProps) + + val reporters = TestMetricsReporter.waitForReporters(servers.size) + reporters.foreach { reporter => + reporter.verifyState(reconfigureCount = 0, deleteCount = 0, pollingInterval = 100) + assertFalse(reporter.kafkaMetrics.isEmpty, 
"No metrics found") + reporter.verifyMetricValue("request-total", "socket-server-metrics") + } + assertEquals(servers.map(_.config.brokerId).toSet, TestMetricsReporter.configuredBrokers.toSet) + + // non-default value to trigger a new metric + val clientId = "test-client-1" + servers.foreach { server => + server.quotaManagers.produce.updateQuota(None, Some(clientId), Some(clientId), + Some(Quota.upperBound(10000000))) + } + val (producerThread, consumerThread) = startProduceConsume(retries = 0, clientId) + TestUtils.waitUntilTrue(() => consumerThread.received >= 5, "Messages not sent") + + // Verify that JMX reporter is still active (test a metric registered after the dynamic reporter update) + val mbeanServer = ManagementFactory.getPlatformMBeanServer + val byteRate = mbeanServer.getAttribute(new ObjectName(s"kafka.server:type=Produce,client-id=$clientId"), "byte-rate") + assertTrue(byteRate.asInstanceOf[Double] > 0, "JMX attribute not updated") + + // Property not related to the metrics reporter config should not reconfigure reporter + newProps.setProperty("some.prop", "some.value") + reconfigureServers(newProps, perBrokerConfig = false, (TestMetricsReporter.PollingIntervalProp, "100")) + reporters.foreach(_.verifyState(reconfigureCount = 0, deleteCount = 0, pollingInterval = 100)) + + // Update of custom config of metrics reporter should reconfigure reporter + newProps.put(TestMetricsReporter.PollingIntervalProp, "1000") + reconfigureServers(newProps, perBrokerConfig = false, (TestMetricsReporter.PollingIntervalProp, "1000")) + reporters.foreach(_.verifyState(reconfigureCount = 1, deleteCount = 0, pollingInterval = 1000)) + + // Verify removal of metrics reporter + configureMetricsReporters(Seq.empty[Class[_]], newProps) + reporters.foreach(_.verifyState(reconfigureCount = 1, deleteCount = 1, pollingInterval = 1000)) + TestMetricsReporter.testReporters.clear() + + // Verify recreation of metrics reporter + newProps.put(TestMetricsReporter.PollingIntervalProp, "2000") + configureMetricsReporters(Seq(classOf[TestMetricsReporter]), newProps) + val newReporters = TestMetricsReporter.waitForReporters(servers.size) + newReporters.foreach(_.verifyState(reconfigureCount = 0, deleteCount = 0, pollingInterval = 2000)) + + // Verify that validation failure of metrics reporter fails reconfiguration and leaves config unchanged + newProps.put(KafkaConfig.MetricReporterClassesProp, "unknownMetricsReporter") + reconfigureServers(newProps, perBrokerConfig = false, (TestMetricsReporter.PollingIntervalProp, "2000"), expectFailure = true) + servers.foreach { server => + assertEquals(classOf[TestMetricsReporter].getName, server.config.originals.get(KafkaConfig.MetricReporterClassesProp)) + } + newReporters.foreach(_.verifyState(reconfigureCount = 0, deleteCount = 0, pollingInterval = 2000)) + + // Verify that validation failure of custom config fails reconfiguration and leaves config unchanged + newProps.put(TestMetricsReporter.PollingIntervalProp, "invalid") + reconfigureServers(newProps, perBrokerConfig = false, (TestMetricsReporter.PollingIntervalProp, "2000"), expectFailure = true) + newReporters.foreach(_.verifyState(reconfigureCount = 0, deleteCount = 0, pollingInterval = 2000)) + + // Delete reporters + configureMetricsReporters(Seq.empty[Class[_]], newProps) + TestMetricsReporter.testReporters.clear() + + // Verify that even though metrics reporters can be defined at default cluster level for consistent + // configuration across brokers, they can also be defined at per-broker level for testing + 
newProps.put(KafkaConfig.MetricReporterClassesProp, classOf[TestMetricsReporter].getName) + newProps.put(TestMetricsReporter.PollingIntervalProp, "4000") + alterConfigsOnServer(servers.head, newProps) + TestUtils.waitUntilTrue(() => !TestMetricsReporter.testReporters.isEmpty, "Metrics reporter not created") + val perBrokerReporter = TestMetricsReporter.waitForReporters(1).head + perBrokerReporter.verifyState(reconfigureCount = 0, deleteCount = 0, pollingInterval = 4000) + + // update TestMetricsReporter.PollingIntervalProp to 3000 + newProps.put(TestMetricsReporter.PollingIntervalProp, "3000") + alterConfigsOnServer(servers.head, newProps) + perBrokerReporter.verifyState(reconfigureCount = 1, deleteCount = 0, pollingInterval = 3000) + + servers.tail.foreach { server => assertEquals("", server.config.originals.get(KafkaConfig.MetricReporterClassesProp)) } + + // Verify that produce/consume worked throughout this test without any retries in producer + stopAndVerifyProduceConsume(producerThread, consumerThread) + } + + @Test + def testAdvertisedListenerUpdate(): Unit = { + val adminClient = adminClients.head + val externalAdminClient = createAdminClient(SecurityProtocol.SASL_SSL, SecureExternal) + + // Ensure connections are made to brokers before external listener is made inaccessible + describeConfig(externalAdminClient) + + // Update broker external listener to use invalid listener address + // any address other than localhost is sufficient to fail (either connection or host name verification failure) + val invalidHost = "192.168.0.1" + alterAdvertisedListener(adminClient, externalAdminClient, "localhost", invalidHost) + + def validateEndpointsInZooKeeper(server: KafkaServer, endpointMatcher: String => Boolean): Unit = { + val brokerInfo = zkClient.getBroker(server.config.brokerId) + assertTrue(brokerInfo.nonEmpty, "Broker not registered") + val endpoints = brokerInfo.get.endPoints.toString + assertTrue(endpointMatcher(endpoints), s"Endpoint update not saved $endpoints") + } + + // Verify that endpoints have been updated in ZK for all brokers + servers.foreach(validateEndpointsInZooKeeper(_, endpoints => endpoints.contains(invalidHost))) + + // Trigger session expiry and ensure that controller registers new advertised listener after expiry + val controllerEpoch = zkClient.getControllerEpoch + val controllerServer = servers(zkClient.getControllerId.getOrElse(throw new IllegalStateException("No controller"))) + val controllerZkClient = controllerServer.zkClient + val sessionExpiringClient = createZooKeeperClientToTriggerSessionExpiry(controllerZkClient.currentZooKeeper) + sessionExpiringClient.close() + TestUtils.waitUntilTrue(() => zkClient.getControllerEpoch != controllerEpoch, + "Controller not re-elected after ZK session expiry") + TestUtils.retry(10000)(validateEndpointsInZooKeeper(controllerServer, endpoints => endpoints.contains(invalidHost))) + + // Verify that producer connections fail since advertised listener is invalid + val bootstrap = TestUtils.bootstrapServers(servers, new ListenerName(SecureExternal)) + .replaceAll(invalidHost, "localhost") // allow bootstrap connection to succeed + val producer1 = ProducerBuilder() + .trustStoreProps(sslProperties1) + .maxRetries(0) + .requestTimeoutMs(1000) + .deliveryTimeoutMs(1000) + .bootstrapServers(bootstrap) + .build() + + val future = producer1.send(new ProducerRecord(topic, "key", "value")) + assertTrue(assertThrows(classOf[ExecutionException], () => future.get(2, TimeUnit.SECONDS)) + 
.getCause.isInstanceOf[org.apache.kafka.common.errors.TimeoutException]) + + alterAdvertisedListener(adminClient, externalAdminClient, invalidHost, "localhost") + servers.foreach(validateEndpointsInZooKeeper(_, endpoints => !endpoints.contains(invalidHost))) + + // Verify that produce/consume work now + val topic2 = "testtopic2" + TestUtils.createTopic(zkClient, topic2, numPartitions, replicationFactor = numServers, servers) + val producer = ProducerBuilder().trustStoreProps(sslProperties1).maxRetries(0).build() + val consumer = ConsumerBuilder("group2").trustStoreProps(sslProperties1).topic(topic2).build() + verifyProduceConsume(producer, consumer, 10, topic2) + + // Verify updating inter-broker listener + val props = new Properties + props.put(KafkaConfig.InterBrokerListenerNameProp, SecureExternal) + val e = assertThrows(classOf[ExecutionException], () => reconfigureServers(props, perBrokerConfig = true, (KafkaConfig.InterBrokerListenerNameProp, SecureExternal))) + assertTrue(e.getCause.isInstanceOf[InvalidRequestException], s"Unexpected exception ${e.getCause}") + servers.foreach(server => assertEquals(SecureInternal, server.config.interBrokerListenerName.value)) + } + + @Test + @Disabled // Re-enable once we make it less flaky (KAFKA-6824) + def testAddRemoveSslListener(): Unit = { + verifyAddListener("SSL", SecurityProtocol.SSL, Seq.empty) + + // Restart servers and check secret rotation + servers.foreach(_.shutdown()) + servers.foreach(_.awaitShutdown()) + adminClients.foreach(_.close()) + adminClients.clear() + + // All passwords are currently encoded with password.encoder.secret. Encode with password.encoder.old.secret + // and update ZK. When each server is started, it should decode using password.encoder.old.secret and update + // ZK with newly encoded values using password.encoder.secret. 
+ servers.foreach { server => + val props = adminZkClient.fetchEntityConfig(ConfigType.Broker, server.config.brokerId.toString) + val propsEncodedWithOldSecret = props.clone().asInstanceOf[Properties] + val config = server.config + val oldSecret = "old-dynamic-config-secret" + config.dynamicConfig.staticBrokerConfigs.put(KafkaConfig.PasswordEncoderOldSecretProp, oldSecret) + val passwordConfigs = props.asScala.filter { case (k, _) => DynamicBrokerConfig.isPasswordConfig(k) } + assertTrue(passwordConfigs.nonEmpty, "Password configs not found") + val passwordDecoder = createPasswordEncoder(config, config.passwordEncoderSecret) + val passwordEncoder = createPasswordEncoder(config, Some(new Password(oldSecret))) + passwordConfigs.foreach { case (name, value) => + val decoded = passwordDecoder.decode(value).value + propsEncodedWithOldSecret.put(name, passwordEncoder.encode(new Password(decoded))) + } + val brokerId = server.config.brokerId + adminZkClient.changeBrokerConfig(Seq(brokerId), propsEncodedWithOldSecret) + val updatedProps = adminZkClient.fetchEntityConfig(ConfigType.Broker, brokerId.toString) + passwordConfigs.foreach { case (name, value) => assertNotEquals(props.get(value), updatedProps.get(name)) } + + server.startup() + TestUtils.retry(10000) { + val newProps = adminZkClient.fetchEntityConfig(ConfigType.Broker, brokerId.toString) + passwordConfigs.foreach { case (name, value) => + assertEquals(passwordDecoder.decode(value), passwordDecoder.decode(newProps.getProperty(name))) } + } + } + + verifyListener(SecurityProtocol.SSL, None, "add-ssl-listener-group2") + createAdminClient(SecurityProtocol.SSL, SecureInternal) + verifyRemoveListener("SSL", SecurityProtocol.SSL, Seq.empty) + } + + @Test + def testAddRemoveSaslListeners(): Unit = { + createScramCredentials(adminClients.head, JaasTestUtils.KafkaScramUser, JaasTestUtils.KafkaScramPassword) + createScramCredentials(adminClients.head, JaasTestUtils.KafkaScramAdmin, JaasTestUtils.KafkaScramAdminPassword) + initializeKerberos() + // make sure each server's credential cache has all the created credentials + // (check after initializing Kerberos to minimize delays) + List(JaasTestUtils.KafkaScramUser, JaasTestUtils.KafkaScramAdmin).foreach { scramUser => + servers.foreach { server => + ScramMechanism.values().filter(_ != ScramMechanism.UNKNOWN).foreach(mechanism => + TestUtils.waitUntilTrue(() => server.credentialProvider.credentialCache.cache( + mechanism.mechanismName(), classOf[ScramCredential]).get(scramUser) != null, + s"$mechanism credentials not created for $scramUser")) + }} + + //verifyAddListener("SASL_SSL", SecurityProtocol.SASL_SSL, Seq("SCRAM-SHA-512", "SCRAM-SHA-256", "PLAIN")) + verifyAddListener("SASL_PLAINTEXT", SecurityProtocol.SASL_PLAINTEXT, Seq("GSSAPI")) + //verifyRemoveListener("SASL_SSL", SecurityProtocol.SASL_SSL, Seq("SCRAM-SHA-512", "SCRAM-SHA-256", "PLAIN")) + verifyRemoveListener("SASL_PLAINTEXT", SecurityProtocol.SASL_PLAINTEXT, Seq("GSSAPI")) + + // Verify that a listener added to a subset of servers doesn't cause any issues + // when metadata is processed by the client. 
+ addListener(servers.tail, "SCRAM_LISTENER", SecurityProtocol.SASL_PLAINTEXT, Seq("SCRAM-SHA-256")) + val bootstrap = TestUtils.bootstrapServers(servers.tail, new ListenerName("SCRAM_LISTENER")) + val producer = ProducerBuilder().bootstrapServers(bootstrap) + .securityProtocol(SecurityProtocol.SASL_PLAINTEXT) + .saslMechanism("SCRAM-SHA-256") + .maxRetries(1000) + .build() + val partitions = producer.partitionsFor(topic).asScala + assertEquals(0, partitions.count(p => p.leader != null && p.leader.id == servers.head.config.brokerId)) + assertTrue(partitions.exists(_.leader == null), "Did not find partitions with no leader") + } + + private def addListener(servers: Seq[KafkaServer], listenerName: String, securityProtocol: SecurityProtocol, + saslMechanisms: Seq[String]): Unit = { + val config = servers.head.config + val existingListenerCount = config.listeners.size + val listeners = config.listeners + .map(e => s"${e.listenerName.value}://${e.host}:${e.port}") + .mkString(",") + s",$listenerName://localhost:0" + val listenerMap = config.effectiveListenerSecurityProtocolMap + .map { case (name, protocol) => s"${name.value}:${protocol.name}" } + .mkString(",") + s",$listenerName:${securityProtocol.name}" + + val props = fetchBrokerConfigsFromZooKeeper(servers.head) + props.put(KafkaConfig.ListenersProp, listeners) + props.put(KafkaConfig.ListenerSecurityProtocolMapProp, listenerMap) + securityProtocol match { + case SecurityProtocol.SSL => + addListenerPropsSsl(listenerName, props) + case SecurityProtocol.SASL_PLAINTEXT => + addListenerPropsSasl(listenerName, saslMechanisms, props) + case SecurityProtocol.SASL_SSL => + addListenerPropsSasl(listenerName, saslMechanisms, props) + addListenerPropsSsl(listenerName, props) + case SecurityProtocol.PLAINTEXT => // no additional props + } + + // Add a config to verify that configs whose types are not known are not returned by describeConfigs() + val unknownConfig = "some.config" + props.put(unknownConfig, "some.config.value") + + TestUtils.incrementalAlterConfigs(servers, adminClients.head, props, perBrokerConfig = true).all.get + + TestUtils.waitUntilTrue(() => servers.forall(server => server.config.listeners.size == existingListenerCount + 1), + "Listener config not updated") + TestUtils.waitUntilTrue(() => servers.forall(server => { + try { + server.socketServer.boundPort(new ListenerName(listenerName)) > 0 + } catch { + case _: Exception => false + } + }), "Listener not created") + + val brokerConfigs = describeConfig(adminClients.head, servers).entries.asScala + props.asScala.foreach { case (name, value) => + val entry = brokerConfigs.find(_.name == name).getOrElse(throw new IllegalArgumentException(s"Config not found $name")) + if (DynamicBrokerConfig.isPasswordConfig(name) || name == unknownConfig) + assertNull(entry.value, s"Password or unknown config returned $entry") + else + assertEquals(value, entry.value) + } + } + + private def verifyAddListener(listenerName: String, securityProtocol: SecurityProtocol, + saslMechanisms: Seq[String]): Unit = { + addListener(servers, listenerName, securityProtocol, saslMechanisms) + TestUtils.waitUntilTrue(() => servers.forall(hasListenerMetric(_, listenerName)), + "Processors not started for new listener") + if (saslMechanisms.nonEmpty) + saslMechanisms.foreach { mechanism => + verifyListener(securityProtocol, Some(mechanism), s"add-listener-group-$securityProtocol-$mechanism") + } + else + verifyListener(securityProtocol, None, s"add-listener-group-$securityProtocol") + } + + private def 
verifyRemoveListener(listenerName: String, securityProtocol: SecurityProtocol, + saslMechanisms: Seq[String]): Unit = { + val saslMechanism = if (saslMechanisms.isEmpty) "" else saslMechanisms.head + val producer1 = ProducerBuilder().listenerName(listenerName) + .securityProtocol(securityProtocol) + .saslMechanism(saslMechanism) + .maxRetries(1000) + .build() + val consumer1 = ConsumerBuilder(s"remove-listener-group-$securityProtocol") + .listenerName(listenerName) + .securityProtocol(securityProtocol) + .saslMechanism(saslMechanism) + .autoOffsetReset("latest") + .build() + verifyProduceConsume(producer1, consumer1, numRecords = 10, topic) + + val config = servers.head.config + val existingListenerCount = config.listeners.size + val listeners = config.listeners + .filter(e => e.listenerName.value != securityProtocol.name) + .map(e => s"${e.listenerName.value}://${e.host}:${e.port}") + .mkString(",") + val listenerMap = config.effectiveListenerSecurityProtocolMap + .filter { case (listenerName, _) => listenerName.value != securityProtocol.name } + .map { case (listenerName, protocol) => s"${listenerName.value}:${protocol.name}" } + .mkString(",") + + val props = fetchBrokerConfigsFromZooKeeper(servers.head) + val deleteListenerProps = new Properties() + deleteListenerProps ++= props.asScala.filter(entry => entry._1.startsWith(listenerPrefix(listenerName))) + TestUtils.incrementalAlterConfigs(servers, adminClients.head, deleteListenerProps, perBrokerConfig = true, opType = OpType.DELETE).all.get + + props.clear() + props.put(KafkaConfig.ListenersProp, listeners) + props.put(KafkaConfig.ListenerSecurityProtocolMapProp, listenerMap) + TestUtils.incrementalAlterConfigs(servers, adminClients.head, props, perBrokerConfig = true).all.get + + TestUtils.waitUntilTrue(() => servers.forall(server => server.config.listeners.size == existingListenerCount - 1), + "Listeners not updated") + // Wait until metrics of the listener have been removed to ensure that processors have been shutdown before + // verifying that connections to the removed listener fail. 
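+    // Requests sent over the removed listener are expected to hang rather than fail fast; verifyConnectionFailure
+    // and verifyTimeout below assert that those requests do not complete.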
+ TestUtils.waitUntilTrue(() => !servers.exists(hasListenerMetric(_, listenerName)), + "Processors not shutdown for removed listener") + + // Test that connections using deleted listener don't work + val producerFuture = verifyConnectionFailure(producer1) + val consumerFuture = verifyConnectionFailure(consumer1) + + // Test that other listeners still work + val topic2 = "testtopic2" + TestUtils.createTopic(zkClient, topic2, numPartitions, replicationFactor = numServers, servers) + val producer2 = ProducerBuilder().trustStoreProps(sslProperties1).maxRetries(0).build() + val consumer2 = ConsumerBuilder(s"remove-listener-group2-$securityProtocol") + .trustStoreProps(sslProperties1) + .topic(topic2) + .autoOffsetReset("latest") + .build() + verifyProduceConsume(producer2, consumer2, numRecords = 10, topic2) + + // Verify that producer/consumer using old listener don't work + verifyTimeout(producerFuture) + verifyTimeout(consumerFuture) + } + + private def verifyListener(securityProtocol: SecurityProtocol, saslMechanism: Option[String], groupId: String): Unit = { + val mechanism = saslMechanism.getOrElse("") + val retries = 1000 // since it may take time for metadata to be updated on all brokers + val producer = ProducerBuilder().listenerName(securityProtocol.name) + .securityProtocol(securityProtocol) + .saslMechanism(mechanism) + .maxRetries(retries) + .build() + val consumer = ConsumerBuilder(s"add-listener-group-$securityProtocol-$mechanism") + .listenerName(securityProtocol.name) + .securityProtocol(securityProtocol) + .saslMechanism(mechanism) + .autoOffsetReset("latest") + .build() + verifyProduceConsume(producer, consumer, numRecords = 10, topic) + } + + private def hasListenerMetric(server: KafkaServer, listenerName: String): Boolean = { + server.socketServer.metrics.metrics.keySet.asScala.exists(_.tags.get("listener") == listenerName) + } + + private def fetchBrokerConfigsFromZooKeeper(server: KafkaServer): Properties = { + val props = adminZkClient.fetchEntityConfig(ConfigType.Broker, server.config.brokerId.toString) + server.config.dynamicConfig.fromPersistentProps(props, perBrokerConfig = true) + } + + private def awaitInitialPositions(consumer: KafkaConsumer[_, _]): Unit = { + TestUtils.pollUntilTrue(consumer, () => !consumer.assignment.isEmpty, "Timed out while waiting for assignment") + consumer.assignment.forEach(consumer.position(_)) + } + + private def clientProps(securityProtocol: SecurityProtocol, saslMechanism: Option[String] = None): Properties = { + val props = new Properties + props.put(CommonClientConfigs.SECURITY_PROTOCOL_CONFIG, securityProtocol.name) + props.put(SSL_ENDPOINT_IDENTIFICATION_ALGORITHM_CONFIG, "HTTPS") + if (securityProtocol == SecurityProtocol.SASL_PLAINTEXT || securityProtocol == SecurityProtocol.SASL_SSL) + props ++= kafkaClientSaslProperties(saslMechanism.getOrElse(kafkaClientSaslMechanism), dynamicJaasConfig = true) + props ++= sslProperties1 + securityProps(props, props.keySet) + } + + private def createAdminClient(securityProtocol: SecurityProtocol, listenerName: String): Admin = { + val config = clientProps(securityProtocol) + val bootstrapServers = TestUtils.bootstrapServers(servers, new ListenerName(listenerName)) + config.put(AdminClientConfig.BOOTSTRAP_SERVERS_CONFIG, bootstrapServers) + config.put(AdminClientConfig.METADATA_MAX_AGE_CONFIG, "10") + val adminClient = Admin.create(config) + adminClients += adminClient + adminClient + } + + private def verifyProduceConsume(producer: KafkaProducer[String, String], + consumer: 
KafkaConsumer[String, String], + numRecords: Int, + topic: String): Unit = { + val producerRecords = (1 to numRecords).map(i => new ProducerRecord(topic, s"key$i", s"value$i")) + producerRecords.map(producer.send).map(_.get(10, TimeUnit.SECONDS)) + TestUtils.pollUntilAtLeastNumRecords(consumer, numRecords) + } + + private def verifyAuthenticationFailure(producer: KafkaProducer[_, _]): Unit = { + assertThrows(classOf[AuthenticationException], () => producer.partitionsFor(topic)) + } + + private def waitForAuthenticationFailure(producerBuilder: ProducerBuilder): Unit = { + TestUtils.waitUntilTrue(() => { + try { + verifyAuthenticationFailure(producerBuilder.build()) + true + } catch { + case e: Error => false + } + }, "Did not fail authentication with invalid config") + } + + private def describeConfig(adminClient: Admin, servers: Seq[KafkaServer] = this.servers): Config = { + val configResources = servers.map { server => + new ConfigResource(ConfigResource.Type.BROKER, server.config.brokerId.toString) + } + val describeOptions = new DescribeConfigsOptions().includeSynonyms(true) + val describeResult = adminClient.describeConfigs(configResources.asJava, describeOptions).all.get + assertEquals(servers.size, describeResult.values.size) + val configDescription = describeResult.values.iterator.next + assertFalse(configDescription.entries.isEmpty, "Configs are empty") + configDescription + } + + private def securityProps(srcProps: Properties, propNames: util.Set[_], listenerPrefix: String = ""): Properties = { + val resultProps = new Properties + propNames.asScala.filter(srcProps.containsKey).foreach { propName => + resultProps.setProperty(s"$listenerPrefix$propName", configValueAsString(srcProps.get(propName))) + } + resultProps + } + + // Creates a new truststore with certificates from the provided stores and returns the properties of the new store + private def mergeTrustStores(trustStore1Props: Properties, trustStore2Props: Properties): Properties = { + + def load(props: Properties): KeyStore = { + val ks = KeyStore.getInstance("JKS") + val password = props.get(SSL_TRUSTSTORE_PASSWORD_CONFIG).asInstanceOf[Password].value + val in = Files.newInputStream(Paths.get(props.getProperty(SSL_TRUSTSTORE_LOCATION_CONFIG))) + try { + ks.load(in, password.toCharArray) + ks + } finally { + in.close() + } + } + val cert1 = load(trustStore1Props).getCertificate("kafka") + val cert2 = load(trustStore2Props).getCertificate("kafka") + val certs = Map("kafka1" -> cert1, "kafka2" -> cert2) + + val combinedStorePath = File.createTempFile("truststore", ".jks").getAbsolutePath + val password = trustStore1Props.get(SSL_TRUSTSTORE_PASSWORD_CONFIG).asInstanceOf[Password] + TestSslUtils.createTrustStore(combinedStorePath, password, certs.asJava) + val newStoreProps = new Properties + newStoreProps.put(SSL_TRUSTSTORE_LOCATION_CONFIG, combinedStorePath) + newStoreProps.put(SSL_TRUSTSTORE_PASSWORD_CONFIG, password) + newStoreProps.put(SSL_TRUSTSTORE_TYPE_CONFIG, "JKS") + newStoreProps + } + + private def alterSslKeystore(adminClient: Admin, props: Properties, listener: String, expectFailure: Boolean = false): Unit = { + val configPrefix = listenerPrefix(listener) + val newProps = securityProps(props, KEYSTORE_PROPS, configPrefix) + reconfigureServers(newProps, perBrokerConfig = true, + (s"$configPrefix$SSL_KEYSTORE_LOCATION_CONFIG", props.getProperty(SSL_KEYSTORE_LOCATION_CONFIG)), expectFailure) + } + + private def alterSslKeystoreUsingConfigCommand(props: Properties, listener: String): Unit = { + val configPrefix = 
listenerPrefix(listener) + val newProps = securityProps(props, KEYSTORE_PROPS, configPrefix) + alterConfigsUsingConfigCommand(newProps) + waitForConfig(s"$configPrefix$SSL_KEYSTORE_LOCATION_CONFIG", props.getProperty(SSL_KEYSTORE_LOCATION_CONFIG)) + } + + private def serverEndpoints(adminClient: Admin): String = { + val nodes = adminClient.describeCluster().nodes().get + nodes.asScala.map { node => + s"${node.host}:${node.port}" + }.mkString(",") + } + + @nowarn("cat=deprecation") + private def alterAdvertisedListener(adminClient: Admin, externalAdminClient: Admin, oldHost: String, newHost: String): Unit = { + val configs = servers.map { server => + val resource = new ConfigResource(ConfigResource.Type.BROKER, server.config.brokerId.toString) + val newListeners = server.config.effectiveAdvertisedListeners.map { e => + if (e.listenerName.value == SecureExternal) + s"${e.listenerName.value}://$newHost:${server.boundPort(e.listenerName)}" + else + s"${e.listenerName.value}://${e.host}:${server.boundPort(e.listenerName)}" + }.mkString(",") + val configEntry = new ConfigEntry(KafkaConfig.AdvertisedListenersProp, newListeners) + (resource, new Config(Collections.singleton(configEntry))) + }.toMap.asJava + adminClient.alterConfigs(configs).all.get + servers.foreach { server => + TestUtils.retry(10000) { + val externalListener = server.config.effectiveAdvertisedListeners.find(_.listenerName.value == SecureExternal) + .getOrElse(throw new IllegalStateException("External listener not found")) + assertEquals(newHost, externalListener.host, "Config not updated") + } + } + val (endpoints, altered) = TestUtils.computeUntilTrue(serverEndpoints(externalAdminClient)) { endpoints => + !endpoints.contains(oldHost) + } + assertTrue(altered, s"Advertised listener update not propagated by controller: $endpoints") + } + + @nowarn("cat=deprecation") + private def alterConfigsOnServer(server: KafkaServer, props: Properties): Unit = { + val configEntries = props.asScala.map { case (k, v) => new ConfigEntry(k, v) }.toList.asJava + val newConfig = new Config(configEntries) + val configs = Map(new ConfigResource(ConfigResource.Type.BROKER, server.config.brokerId.toString) -> newConfig).asJava + adminClients.head.alterConfigs(configs).all.get + props.asScala.foreach { case (k, v) => waitForConfigOnServer(server, k, v) } + } + + @nowarn("cat=deprecation") + private def alterConfigs(servers: Seq[KafkaServer], adminClient: Admin, props: Properties, + perBrokerConfig: Boolean): AlterConfigsResult = { + val configEntries = props.asScala.map { case (k, v) => new ConfigEntry(k, v) }.toList.asJava + val newConfig = new Config(configEntries) + val configs = if (perBrokerConfig) { + servers.map { server => + val resource = new ConfigResource(ConfigResource.Type.BROKER, server.config.brokerId.toString) + (resource, newConfig) + }.toMap.asJava + } else { + Map(new ConfigResource(ConfigResource.Type.BROKER, "") -> newConfig).asJava + } + adminClient.alterConfigs(configs) + } + + private def reconfigureServers(newProps: Properties, perBrokerConfig: Boolean, aPropToVerify: (String, String), expectFailure: Boolean = false): Unit = { + val alterResult = alterConfigs(servers, adminClients.head, newProps, perBrokerConfig) + if (expectFailure) { + val oldProps = servers.head.config.values.asScala.filter { case (k, _) => newProps.containsKey(k) } + val brokerResources = if (perBrokerConfig) + servers.map(server => new ConfigResource(ConfigResource.Type.BROKER, server.config.brokerId.toString)) + else + Seq(new 
ConfigResource(ConfigResource.Type.BROKER, "")) + brokerResources.foreach { brokerResource => + val exception = assertThrows(classOf[ExecutionException], () => alterResult.values.get(brokerResource).get) + assertTrue(exception.getCause.isInstanceOf[InvalidRequestException]) + } + servers.foreach { server => + assertEquals(oldProps, server.config.values.asScala.filter { case (k, _) => newProps.containsKey(k) }) + } + } else { + alterResult.all.get + waitForConfig(aPropToVerify._1, aPropToVerify._2) + } + } + + private def configEntry(configDesc: Config, configName: String): ConfigEntry = { + configDesc.entries.asScala.find(cfg => cfg.name == configName) + .getOrElse(throw new IllegalStateException(s"Config not found $configName")) + } + + private def listenerPrefix(name: String): String = new ListenerName(name).configPrefix + + private def configureDynamicKeystoreInZooKeeper(kafkaConfig: KafkaConfig, sslProperties: Properties): Unit = { + val externalListenerPrefix = listenerPrefix(SecureExternal) + val sslStoreProps = new Properties + sslStoreProps ++= securityProps(sslProperties, KEYSTORE_PROPS, externalListenerPrefix) + sslStoreProps.put(KafkaConfig.PasswordEncoderSecretProp, kafkaConfig.passwordEncoderSecret.map(_.value).orNull) + zkClient.makeSurePersistentPathExists(ConfigEntityChangeNotificationZNode.path) + + val entityType = ConfigType.Broker + val entityName = kafkaConfig.brokerId.toString + + val passwordConfigs = sslStoreProps.asScala.keySet.filter(DynamicBrokerConfig.isPasswordConfig) + val passwordEncoder = createPasswordEncoder(kafkaConfig, kafkaConfig.passwordEncoderSecret) + + if (passwordConfigs.nonEmpty) { + passwordConfigs.foreach { configName => + val encodedValue = passwordEncoder.encode(new Password(sslStoreProps.getProperty(configName))) + sslStoreProps.setProperty(configName, encodedValue) + } + } + sslStoreProps.remove(KafkaConfig.PasswordEncoderSecretProp) + adminZkClient.changeConfigs(entityType, entityName, sslStoreProps) + + val brokerProps = adminZkClient.fetchEntityConfig("brokers", kafkaConfig.brokerId.toString) + assertEquals(4, brokerProps.size) + assertEquals(sslProperties.get(SSL_KEYSTORE_TYPE_CONFIG), + brokerProps.getProperty(s"$externalListenerPrefix$SSL_KEYSTORE_TYPE_CONFIG")) + assertEquals(sslProperties.get(SSL_KEYSTORE_LOCATION_CONFIG), + brokerProps.getProperty(s"$externalListenerPrefix$SSL_KEYSTORE_LOCATION_CONFIG")) + assertEquals(sslProperties.get(SSL_KEYSTORE_PASSWORD_CONFIG), + passwordEncoder.decode(brokerProps.getProperty(s"$externalListenerPrefix$SSL_KEYSTORE_PASSWORD_CONFIG"))) + assertEquals(sslProperties.get(SSL_KEY_PASSWORD_CONFIG), + passwordEncoder.decode(brokerProps.getProperty(s"$externalListenerPrefix$SSL_KEY_PASSWORD_CONFIG"))) + } + + private def createPasswordEncoder(config: KafkaConfig, secret: Option[Password]): PasswordEncoder = { + val encoderSecret = secret.getOrElse(throw new IllegalStateException("Password encoder secret not configured")) + new PasswordEncoder(encoderSecret, + config.passwordEncoderKeyFactoryAlgorithm, + config.passwordEncoderCipherAlgorithm, + config.passwordEncoderKeyLength, + config.passwordEncoderIterations) + } + + private def waitForConfig(propName: String, propValue: String, maxWaitMs: Long = 10000): Unit = { + servers.foreach { server => waitForConfigOnServer(server, propName, propValue, maxWaitMs) } + } + + private def waitForConfigOnServer(server: KafkaServer, propName: String, propValue: String, maxWaitMs: Long = 10000): Unit = { + TestUtils.retry(maxWaitMs) { + assertEquals(propValue, 
server.config.originals.get(propName)) + } + } + + private def configureMetricsReporters(reporters: Seq[Class[_]], props: Properties, + perBrokerConfig: Boolean = false): Unit = { + val reporterStr = reporters.map(_.getName).mkString(",") + props.put(KafkaConfig.MetricReporterClassesProp, reporterStr) + reconfigureServers(props, perBrokerConfig, (KafkaConfig.MetricReporterClassesProp, reporterStr)) + } + + private def invalidSslConfigs: Properties = { + val props = new Properties + props.put(SSL_KEYSTORE_LOCATION_CONFIG, "invalid/file/path") + props.put(SSL_KEYSTORE_PASSWORD_CONFIG, new Password("invalid")) + props.put(SSL_KEY_PASSWORD_CONFIG, new Password("invalid")) + props.put(SSL_KEYSTORE_TYPE_CONFIG, "PKCS12") + props + } + + private def currentThreads: List[String] = { + Thread.getAllStackTraces.keySet.asScala.toList.map(_.getName) + } + + private def matchingThreads(threadPrefix: String): List[String] = { + currentThreads.filter(_.startsWith(threadPrefix)) + } + + private def verifyThreads(threadPrefix: String, countPerBroker: Int, leftOverThreads: Int = 0): Unit = { + val expectedCount = countPerBroker * servers.size + val (threads, resized) = TestUtils.computeUntilTrue(matchingThreads(threadPrefix)) { matching => + matching.size >= expectedCount && matching.size <= expectedCount + leftOverThreads + } + assertTrue(resized, s"Invalid threads: expected $expectedCount, got ${threads.size}: $threads") + } + + private def startProduceConsume(retries: Int, producerClientId: String = "test-producer"): (ProducerThread, ConsumerThread) = { + val producerThread = new ProducerThread(producerClientId, retries) + clientThreads += producerThread + val consumerThread = new ConsumerThread(producerThread) + clientThreads += consumerThread + consumerThread.start() + producerThread.start() + TestUtils.waitUntilTrue(() => producerThread.sent >= 10, "Messages not sent") + (producerThread, consumerThread) + } + + private def stopAndVerifyProduceConsume(producerThread: ProducerThread, consumerThread: ConsumerThread, + mayReceiveDuplicates: Boolean = false): Unit = { + TestUtils.waitUntilTrue(() => producerThread.sent >= 10, "Messages not sent") + producerThread.shutdown() + consumerThread.initiateShutdown() + consumerThread.awaitShutdown() + assertEquals(producerThread.lastSent, consumerThread.lastReceived) + assertEquals(0, consumerThread.missingRecords.size) + if (!mayReceiveDuplicates) + assertFalse(consumerThread.duplicates, "Duplicates not expected") + assertFalse(consumerThread.outOfOrder, "Some messages received out of order") + } + + private def verifyConnectionFailure(producer: KafkaProducer[String, String]): Future[_] = { + val executor = Executors.newSingleThreadExecutor + executors += executor + val future = executor.submit(new Runnable() { + def run(): Unit = { + producer.send(new ProducerRecord(topic, "key", "value")).get + } + }) + verifyTimeout(future) + future + } + + private def verifyConnectionFailure(consumer: KafkaConsumer[String, String]): Future[_] = { + val executor = Executors.newSingleThreadExecutor + executors += executor + val future = executor.submit(new Runnable() { + def run(): Unit = { + consumer.commitSync() + } + }) + verifyTimeout(future) + future + } + + private def verifyTimeout(future: Future[_]): Unit = { + assertThrows(classOf[TimeoutException], () => future.get(100, TimeUnit.MILLISECONDS)) + } + + private def configValueAsString(value: Any): String = { + value match { + case password: Password => password.value + case list: util.List[_] => 
list.asScala.map(_.toString).mkString(",") + case _ => value.toString + } + } + + private def addListenerPropsSsl(listenerName: String, props: Properties): Unit = { + props ++= securityProps(sslProperties1, KEYSTORE_PROPS, listenerPrefix(listenerName)) + props ++= securityProps(sslProperties1, TRUSTSTORE_PROPS, listenerPrefix(listenerName)) + } + + private def addListenerPropsSasl(listener: String, mechanisms: Seq[String], props: Properties): Unit = { + val listenerName = new ListenerName(listener) + val prefix = listenerName.configPrefix + props.put(prefix + KafkaConfig.SaslEnabledMechanismsProp, mechanisms.mkString(",")) + props.put(prefix + KafkaConfig.SaslKerberosServiceNameProp, "kafka") + mechanisms.foreach { mechanism => + val jaasSection = jaasSections(Seq(mechanism), None, KafkaSasl, "").head + val jaasConfig = jaasSection.modules.head.toString + props.put(listenerName.saslMechanismConfigPrefix(mechanism) + KafkaConfig.SaslJaasConfigProp, jaasConfig) + } + } + + private def alterConfigsUsingConfigCommand(props: Properties): Unit = { + val propsFile = TestUtils.tempFile() + val propsWriter = new FileWriter(propsFile) + try { + clientProps(SecurityProtocol.SSL).forEach { + case (k, v) => propsWriter.write(s"$k=$v\n") + } + } finally { + propsWriter.close() + } + + servers.foreach { server => + val args = Array("--bootstrap-server", TestUtils.bootstrapServers(servers, new ListenerName(SecureInternal)), + "--command-config", propsFile.getAbsolutePath, + "--alter", "--add-config", props.asScala.map { case (k, v) => s"$k=$v" }.mkString(","), + "--entity-type", "brokers", + "--entity-name", server.config.brokerId.toString) + ConfigCommand.main(args) + } + } + + private abstract class ClientBuilder[T]() { + protected var _bootstrapServers: Option[String] = None + protected var _listenerName = SecureExternal + protected var _securityProtocol = SecurityProtocol.SASL_SSL + protected var _saslMechanism = kafkaClientSaslMechanism + protected var _clientId = "test-client" + protected val _propsOverride: Properties = new Properties + + def bootstrapServers(bootstrap: String): this.type = { _bootstrapServers = Some(bootstrap); this } + def listenerName(listener: String): this.type = { _listenerName = listener; this } + def securityProtocol(protocol: SecurityProtocol): this.type = { _securityProtocol = protocol; this } + def saslMechanism(mechanism: String): this.type = { _saslMechanism = mechanism; this } + def clientId(id: String): this.type = { _clientId = id; this } + def keyStoreProps(props: Properties): this.type = { _propsOverride ++= securityProps(props, KEYSTORE_PROPS); this } + def trustStoreProps(props: Properties): this.type = { _propsOverride ++= securityProps(props, TRUSTSTORE_PROPS); this } + + def bootstrapServers: String = + _bootstrapServers.getOrElse(TestUtils.bootstrapServers(servers, new ListenerName(_listenerName))) + + def propsOverride: Properties = { + val props = clientProps(_securityProtocol, Some(_saslMechanism)) + props.put(CommonClientConfigs.CLIENT_ID_CONFIG, _clientId) + props ++= _propsOverride + props + } + + def build(): T + } + + private case class ProducerBuilder() extends ClientBuilder[KafkaProducer[String, String]] { + private var _retries = Int.MaxValue + private var _acks = -1 + private var _requestTimeoutMs = 30000 + private var _deliveryTimeoutMs = 30000 + + def maxRetries(retries: Int): ProducerBuilder = { _retries = retries; this } + def acks(acks: Int): ProducerBuilder = { _acks = acks; this } + def requestTimeoutMs(timeoutMs: Int): ProducerBuilder = { 
_requestTimeoutMs = timeoutMs; this } + def deliveryTimeoutMs(timeoutMs: Int): ProducerBuilder = { _deliveryTimeoutMs= timeoutMs; this } + + override def build(): KafkaProducer[String, String] = { + val producerProps = propsOverride + producerProps.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, bootstrapServers) + producerProps.put(ProducerConfig.ACKS_CONFIG, _acks.toString) + producerProps.put(ProducerConfig.RETRIES_CONFIG, _retries.toString) + producerProps.put(ProducerConfig.DELIVERY_TIMEOUT_MS_CONFIG, _deliveryTimeoutMs.toString) + producerProps.put(ProducerConfig.REQUEST_TIMEOUT_MS_CONFIG, _requestTimeoutMs.toString) + + val producer = new KafkaProducer[String, String](producerProps, new StringSerializer, new StringSerializer) + producers += producer + producer + } + } + + private case class ConsumerBuilder(group: String) extends ClientBuilder[KafkaConsumer[String, String]] { + private var _autoOffsetReset = "earliest" + private var _enableAutoCommit = false + private var _topic = DynamicBrokerReconfigurationTest.this.topic + + def autoOffsetReset(reset: String): ConsumerBuilder = { _autoOffsetReset = reset; this } + def enableAutoCommit(enable: Boolean): ConsumerBuilder = { _enableAutoCommit = enable; this } + def topic(topic: String): ConsumerBuilder = { _topic = topic; this } + + override def build(): KafkaConsumer[String, String] = { + val consumerProps = propsOverride + consumerProps.put(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, bootstrapServers) + consumerProps.put(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, _autoOffsetReset) + consumerProps.put(ConsumerConfig.GROUP_ID_CONFIG, group) + consumerProps.put(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG, _enableAutoCommit.toString) + + val consumer = new KafkaConsumer[String, String](consumerProps, new StringDeserializer, new StringDeserializer) + consumers += consumer + + consumer.subscribe(Collections.singleton(_topic)) + if (_autoOffsetReset == "latest") + awaitInitialPositions(consumer) + consumer + } + } + + private class ProducerThread(clientId: String, retries: Int) + extends ShutdownableThread(clientId, isInterruptible = false) { + + private val producer = ProducerBuilder().maxRetries(retries).clientId(clientId).build() + val lastSent = new ConcurrentHashMap[Int, Int]() + @volatile var sent = 0 + override def doWork(): Unit = { + try { + while (isRunning) { + val key = sent.toString + val partition = sent % numPartitions + val record = new ProducerRecord(topic, partition, key, s"value$sent") + producer.send(record).get(10, TimeUnit.SECONDS) + lastSent.put(partition, sent) + sent += 1 + } + } finally { + producer.close() + } + } + } + + private class ConsumerThread(producerThread: ProducerThread) extends ShutdownableThread("test-consumer", isInterruptible = false) { + private val consumer = ConsumerBuilder("group1").enableAutoCommit(true).build() + val lastReceived = new ConcurrentHashMap[Int, Int]() + val missingRecords = new ConcurrentLinkedQueue[Int]() + @volatile var outOfOrder = false + @volatile var duplicates = false + @volatile var lastBatch: ConsumerRecords[String, String] = _ + @volatile private var endTimeMs = Long.MaxValue + @volatile var received = 0 + override def doWork(): Unit = { + try { + while (isRunning || (lastReceived != producerThread.lastSent && System.currentTimeMillis < endTimeMs)) { + val records = consumer.poll(Duration.ofMillis(50L)) + received += records.count + if (!records.isEmpty) { + lastBatch = records + records.partitions.forEach { tp => + val partition = tp.partition + 
records.records(tp).asScala.map(_.key.toInt).foreach { key => + val prevKey = lastReceived.asScala.getOrElse(partition, partition - numPartitions) + val expectedKey = prevKey + numPartitions + if (key < prevKey) + outOfOrder = true + else if (key == prevKey) + duplicates = true + else { + for (i <- expectedKey until key by numPartitions) + missingRecords.add(i) + } + lastReceived.put(partition, key) + missingRecords.remove(key) + } + } + } + } + } finally { + consumer.close() + } + } + + override def initiateShutdown(): Boolean = { + endTimeMs = System.currentTimeMillis + 10 * 1000 + super.initiateShutdown() + } + + def waitForMatchingRecords(predicate: ConsumerRecord[String, String] => Boolean): Unit = { + TestUtils.waitUntilTrue(() => { + val records = lastBatch + if (records == null || records.isEmpty) + false + else + records.asScala.toList.exists(predicate) + }, "Received records did not match") + } + } +} + +object TestMetricsReporter { + val PollingIntervalProp = "polling.interval" + val testReporters = new ConcurrentLinkedQueue[TestMetricsReporter]() + val configuredBrokers = mutable.Set[Int]() + + def waitForReporters(count: Int): List[TestMetricsReporter] = { + TestUtils.waitUntilTrue(() => testReporters.size == count, msg = "Metrics reporters not created") + + val reporters = testReporters.asScala.toList + TestUtils.waitUntilTrue(() => reporters.forall(_.configureCount == 1), msg = "Metrics reporters not configured") + reporters + } +} + +class TestMetricsReporter extends MetricsReporter with Reconfigurable with Closeable with ClusterResourceListener { + import TestMetricsReporter._ + val kafkaMetrics = ArrayBuffer[KafkaMetric]() + @volatile var initializeCount = 0 + @volatile var configureCount = 0 + @volatile var reconfigureCount = 0 + @volatile var closeCount = 0 + @volatile var clusterUpdateCount = 0 + @volatile var pollingInterval: Int = -1 + testReporters.add(this) + + override def init(metrics: util.List[KafkaMetric]): Unit = { + kafkaMetrics ++= metrics.asScala + initializeCount += 1 + } + + override def configure(configs: util.Map[String, _]): Unit = { + configuredBrokers += configs.get(KafkaConfig.BrokerIdProp).toString.toInt + configureCount += 1 + pollingInterval = configs.get(PollingIntervalProp).toString.toInt + } + + override def metricChange(metric: KafkaMetric): Unit = { + } + + override def metricRemoval(metric: KafkaMetric): Unit = { + kafkaMetrics -= metric + } + + override def onUpdate(clusterResource: ClusterResource): Unit = { + assertNotNull(clusterResource.clusterId, "Cluster id not set") + clusterUpdateCount += 1 + } + + override def reconfigurableConfigs(): util.Set[String] = { + Set(PollingIntervalProp).asJava + } + + override def validateReconfiguration(configs: util.Map[String, _]): Unit = { + val pollingInterval = configs.get(PollingIntervalProp).toString.toInt + if (pollingInterval <= 0) + throw new ConfigException(s"Invalid polling interval $pollingInterval") + } + + override def reconfigure(configs: util.Map[String, _]): Unit = { + reconfigureCount += 1 + pollingInterval = configs.get(PollingIntervalProp).toString.toInt + } + + override def close(): Unit = { + closeCount += 1 + } + + def verifyState(reconfigureCount: Int, deleteCount: Int, pollingInterval: Int): Unit = { + assertEquals(1, initializeCount) + assertEquals(1, configureCount) + assertEquals(reconfigureCount, this.reconfigureCount) + assertEquals(deleteCount, closeCount) + assertEquals(1, clusterUpdateCount) + assertEquals(pollingInterval, this.pollingInterval) + } + + def 
verifyMetricValue(name: String, group: String): Unit = { + val matchingMetrics = kafkaMetrics.filter(metric => metric.metricName.name == name && metric.metricName.group == group) + assertTrue(matchingMetrics.nonEmpty, "Metric not found") + val total = matchingMetrics.foldLeft(0.0)((total, metric) => total + metric.metricValue.asInstanceOf[Double]) + assertTrue(total > 0.0, "Invalid metric value") + } +} + + +class MockFileConfigProvider extends FileConfigProvider { + @throws(classOf[IOException]) + override def reader(path: String): Reader = { + new StringReader("key=testKey\npassword=ServerPassword\ninterval=1000\nupdinterval=2000\nstoretype=JKS"); + } +} diff --git a/core/src/test/scala/integration/kafka/server/FetchRequestBetweenDifferentIbpTest.scala b/core/src/test/scala/integration/kafka/server/FetchRequestBetweenDifferentIbpTest.scala new file mode 100644 index 0000000..37ac4a2 --- /dev/null +++ b/core/src/test/scala/integration/kafka/server/FetchRequestBetweenDifferentIbpTest.scala @@ -0,0 +1,152 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package integration.kafka.server + +import java.time.Duration +import java.util.Arrays.asList +import kafka.api.{ApiVersion, DefaultApiVersion, KAFKA_2_7_IV0, KAFKA_2_8_IV1, KAFKA_3_1_IV0} +import kafka.server.{BaseRequestTest, KafkaConfig} +import kafka.utils.TestUtils +import org.apache.kafka.clients.producer.ProducerRecord +import org.apache.kafka.common.TopicPartition +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.Test + +import scala.collection.{Map, Seq} + +class FetchRequestBetweenDifferentIbpTest extends BaseRequestTest { + + override def brokerCount: Int = 3 + override def generateConfigs: Seq[KafkaConfig] = { + // Brokers should be at most 2 different IBP versions, but for more test coverage, three are used here. + Seq( + createConfig(0, KAFKA_2_7_IV0), + createConfig(1, KAFKA_2_8_IV1), + createConfig(2, KAFKA_3_1_IV0) + ) + } + + @Test + def testControllerOldIBP(): Unit = { + // Ensure controller version < KAFKA_2_8_IV1, and then create a topic where leader of partition 0 is not the controller, + // leader of partition 1 is. + testControllerWithGivenIBP(KAFKA_2_7_IV0, 0) + } + + @Test + def testControllerNewIBP(): Unit = { + // Ensure controller version = KAFKA_3_1_IV0, and then create a topic where leader of partition 1 is the old version. 
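+    // Broker 2 (KAFKA_3_1_IV0) becomes the controller, while partition 1's leader (broker 0) runs the older KAFKA_2_7_IV0.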
+ testControllerWithGivenIBP(KAFKA_3_1_IV0, 2) + } + + def testControllerWithGivenIBP(version: DefaultApiVersion, controllerBroker: Int): Unit = { + val topic = "topic" + val producer = createProducer() + val consumer = createConsumer() + + ensureControllerWithIBP(version) + assertEquals(controllerBroker, controllerSocketServer.config.brokerId) + val partitionLeaders = createTopic(topic, Map(0 -> Seq(1, 0, 2), 1 -> Seq(0, 2, 1))) + TestUtils.waitForAllPartitionsMetadata(servers, topic, 2) + + assertEquals(1, partitionLeaders(0)) + assertEquals(0, partitionLeaders(1)) + + val record1 = new ProducerRecord(topic, 0, null, "key".getBytes, "value".getBytes) + val record2 = new ProducerRecord(topic, 1, null, "key".getBytes, "value".getBytes) + producer.send(record1) + producer.send(record2) + + consumer.assign(asList(new TopicPartition(topic, 0), new TopicPartition(topic, 1))) + val count = consumer.poll(Duration.ofMillis(1500)).count() + consumer.poll(Duration.ofMillis(1500)).count() + assertEquals(2, count) + } + + @Test + def testControllerNewToOldIBP(): Unit = { + testControllerSwitchingIBP(KAFKA_3_1_IV0, 2, KAFKA_2_7_IV0, 0) + } + + @Test + def testControllerOldToNewIBP(): Unit = { + testControllerSwitchingIBP(KAFKA_2_7_IV0, 0, KAFKA_3_1_IV0, 2) + } + + + def testControllerSwitchingIBP(version1: DefaultApiVersion, broker1: Int, version2: DefaultApiVersion, broker2: Int): Unit = { + val topic = "topic" + val topic2 = "topic2" + val producer = createProducer() + val consumer = createConsumer() + + // Ensure controller version = version1 + ensureControllerWithIBP(version1) + assertEquals(broker1, controllerSocketServer.config.brokerId) + val partitionLeaders = createTopic(topic, Map(0 -> Seq(1, 0, 2), 1 -> Seq(0, 2, 1))) + TestUtils.waitForAllPartitionsMetadata(servers, topic, 2) + assertEquals(1, partitionLeaders(0)) + assertEquals(0, partitionLeaders(1)) + + val record1 = new ProducerRecord(topic, 0, null, "key".getBytes, "value".getBytes) + val record2 = new ProducerRecord(topic, 1, null, "key".getBytes, "value".getBytes) + producer.send(record1) + producer.send(record2) + + consumer.assign(asList(new TopicPartition(topic, 0), new TopicPartition(topic, 1))) + + val count = consumer.poll(Duration.ofMillis(1500)).count() + consumer.poll(Duration.ofMillis(1500)).count() + assertEquals(2, count) + + // Make controller version2 + ensureControllerWithIBP(version2) + assertEquals(broker2, controllerSocketServer.config.brokerId) + // Create a new topic + createTopic(topic2, Map(0 -> Seq(1, 0, 2))) + TestUtils.waitForAllPartitionsMetadata(servers, topic2, 1) + TestUtils.waitForAllPartitionsMetadata(servers, topic, 2) + + val record3 = new ProducerRecord(topic2, 0, null, "key".getBytes, "value".getBytes) + val record4 = new ProducerRecord(topic, 1, null, "key".getBytes, "value".getBytes) + producer.send(record3) + producer.send(record4) + + // Assign this new topic in addition to the old topics. 
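+    // Only the two records produced after the controller change (record3 and record4) are expected here,
+    // since the earlier records were already consumed above.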
+ consumer.assign(asList(new TopicPartition(topic, 0), new TopicPartition(topic, 1), new TopicPartition(topic2, 0))) + + val count2 = consumer.poll(Duration.ofMillis(1500)).count() + consumer.poll(Duration.ofMillis(1500)).count() + assertEquals(2, count2) + } + + private def ensureControllerWithIBP(version: DefaultApiVersion): Unit = { + val nonControllerServers = servers.filter(_.config.interBrokerProtocolVersion != version) + nonControllerServers.iterator.foreach(server => { + server.shutdown() + }) + TestUtils.waitUntilControllerElected(zkClient) + nonControllerServers.iterator.foreach(server => { + server.startup() + }) + } + + private def createConfig(nodeId: Int, interBrokerVersion: ApiVersion): KafkaConfig = { + val props = TestUtils.createBrokerConfig(nodeId, zkConnect) + props.put(KafkaConfig.InterBrokerProtocolVersionProp, interBrokerVersion.version) + KafkaConfig.fromProps(props) + } + +} diff --git a/core/src/test/scala/integration/kafka/server/FetchRequestTestDowngrade.scala b/core/src/test/scala/integration/kafka/server/FetchRequestTestDowngrade.scala new file mode 100644 index 0000000..148e076 --- /dev/null +++ b/core/src/test/scala/integration/kafka/server/FetchRequestTestDowngrade.scala @@ -0,0 +1,81 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package integration.kafka.server + +import java.time.Duration +import java.util.Arrays.asList + +import kafka.api.{ApiVersion, KAFKA_2_7_IV0, KAFKA_3_1_IV0} +import kafka.server.{BaseRequestTest, KafkaConfig} +import kafka.utils.TestUtils +import kafka.zk.ZkVersion +import org.apache.kafka.clients.producer.ProducerRecord +import org.apache.kafka.common.TopicPartition +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.Test + +import scala.collection.{Map, Seq} + +class FetchRequestTestDowngrade extends BaseRequestTest { + + override def brokerCount: Int = 2 + override def generateConfigs: Seq[KafkaConfig] = { + // Controller should start with newer IBP and downgrade to the older one. 
+ Seq( + createConfig(0, KAFKA_3_1_IV0), + createConfig(1, KAFKA_2_7_IV0) + ) + } + + @Test + def testTopicIdIsRemovedFromFetcherWhenControllerDowngrades(): Unit = { + val tp = new TopicPartition("topic", 0) + val producer = createProducer() + val consumer = createConsumer() + + ensureControllerIn(Seq(0)) + assertEquals(0, controllerSocketServer.config.brokerId) + val partitionLeaders = createTopic(tp.topic, Map(tp.partition -> Seq(1, 0))) + TestUtils.waitForAllPartitionsMetadata(servers, tp.topic, 1) + ensureControllerIn(Seq(1)) + assertEquals(1, controllerSocketServer.config.brokerId) + + assertEquals(1, partitionLeaders(0)) + + val record1 = new ProducerRecord(tp.topic, tp.partition, null, "key".getBytes, "value".getBytes) + producer.send(record1) + + consumer.assign(asList(tp)) + val count = consumer.poll(Duration.ofMillis(5000)).count() + assertEquals(1, count) + } + + private def ensureControllerIn(brokerIds: Seq[Int]): Unit = { + while (!brokerIds.contains(controllerSocketServer.config.brokerId)) { + zkClient.deleteController(ZkVersion.MatchAnyVersion) + TestUtils.waitUntilControllerElected(zkClient) + } + } + + private def createConfig(nodeId: Int, interBrokerVersion: ApiVersion): KafkaConfig = { + val props = TestUtils.createBrokerConfig(nodeId, zkConnect) + props.put(KafkaConfig.InterBrokerProtocolVersionProp, interBrokerVersion.version) + KafkaConfig.fromProps(props) + } + +} diff --git a/core/src/test/scala/integration/kafka/server/GssapiAuthenticationTest.scala b/core/src/test/scala/integration/kafka/server/GssapiAuthenticationTest.scala new file mode 100644 index 0000000..2bda3ac --- /dev/null +++ b/core/src/test/scala/integration/kafka/server/GssapiAuthenticationTest.scala @@ -0,0 +1,314 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package kafka.server + +import java.net.InetSocketAddress +import java.time.Duration +import java.util.{Collections, Properties} +import java.util.concurrent.{CountDownLatch, Executors, TimeUnit} + +import javax.security.auth.login.LoginContext +import kafka.api.{Both, IntegrationTestHarness, SaslSetup} +import kafka.utils.TestUtils +import org.apache.kafka.clients.CommonClientConfigs +import org.apache.kafka.common.TopicPartition +import org.apache.kafka.common.config.SaslConfigs +import org.apache.kafka.common.errors.SaslAuthenticationException +import org.apache.kafka.common.message.ApiMessageType.ListenerType +import org.apache.kafka.common.network._ +import org.apache.kafka.common.requests.ApiVersionsResponse +import org.apache.kafka.common.security.{JaasContext, TestSecurityConfig} +import org.apache.kafka.common.security.auth.{Login, SecurityProtocol} +import org.apache.kafka.common.security.kerberos.KerberosLogin +import org.apache.kafka.common.utils.{LogContext, MockTime} +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.{AfterEach, BeforeEach, Test, TestInfo} + +import scala.jdk.CollectionConverters._ + +class GssapiAuthenticationTest extends IntegrationTestHarness with SaslSetup { + override val brokerCount = 1 + override protected def securityProtocol = SecurityProtocol.SASL_PLAINTEXT + + private val kafkaClientSaslMechanism = "GSSAPI" + private val kafkaServerSaslMechanisms = List("GSSAPI") + + private val numThreads = 10 + private val executor = Executors.newFixedThreadPool(numThreads) + private val clientConfig: Properties = new Properties + private var serverAddr: InetSocketAddress = _ + private val time = new MockTime(10) + val topic = "topic" + val part = 0 + val tp = new TopicPartition(topic, part) + private val failedAuthenticationDelayMs = 2000 + + @BeforeEach + override def setUp(testInfo: TestInfo): Unit = { + TestableKerberosLogin.reset() + startSasl(jaasSections(kafkaServerSaslMechanisms, Option(kafkaClientSaslMechanism), Both)) + serverConfig.put(KafkaConfig.SslClientAuthProp, "required") + serverConfig.put(KafkaConfig.FailedAuthenticationDelayMsProp, failedAuthenticationDelayMs.toString) + super.setUp(testInfo) + serverAddr = new InetSocketAddress("localhost", + servers.head.boundPort(ListenerName.forSecurityProtocol(SecurityProtocol.SASL_PLAINTEXT))) + + clientConfig.put(CommonClientConfigs.SECURITY_PROTOCOL_CONFIG, SecurityProtocol.SASL_PLAINTEXT.name) + clientConfig.put(SaslConfigs.SASL_MECHANISM, kafkaClientSaslMechanism) + clientConfig.put(SaslConfigs.SASL_JAAS_CONFIG, jaasClientLoginModule(kafkaClientSaslMechanism)) + clientConfig.put(CommonClientConfigs.CONNECTIONS_MAX_IDLE_MS_CONFIG, "5000") + + // create the test topic with all the brokers as replicas + createTopic(topic, 2, brokerCount) + } + + @AfterEach + override def tearDown(): Unit = { + executor.shutdownNow() + super.tearDown() + closeSasl() + TestableKerberosLogin.reset() + } + + /** + * Tests that Kerberos replay error `Request is a replay (34)` is not handled as an authentication exception + * since replay detection used to detect DoS attacks may occasionally reject valid concurrent requests. 
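+   * The test performs many concurrent GSSAPI authentications from multiple threads and asserts that none are
+   * recorded as failed on the broker.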
+ */ + @Test + def testRequestIsAReplay(): Unit = { + val successfulAuthsPerThread = 10 + val futures = (0 until numThreads).map(_ => executor.submit(new Runnable { + override def run(): Unit = verifyRetriableFailuresDuringAuthentication(successfulAuthsPerThread) + })) + futures.foreach(_.get(60, TimeUnit.SECONDS)) + assertEquals(0, TestUtils.totalMetricValue(servers.head, "failed-authentication-total")) + val successfulAuths = TestUtils.totalMetricValue(servers.head, "successful-authentication-total") + assertTrue(successfulAuths > successfulAuthsPerThread * numThreads, "Too few authentications: " + successfulAuths) + } + + /** + * Verifies that if login fails, subsequent re-login without failures works and clients + * are able to connect after the second re-login. Verifies that logout is performed only once + * since duplicate logouts without successful login results in NPE from Java 9 onwards. + */ + @Test + def testLoginFailure(): Unit = { + val selector = createSelectorWithRelogin() + try { + val login = TestableKerberosLogin.instance + assertNotNull(login) + login.loginException = Some(new RuntimeException("Test exception to fail login")) + executor.submit(() => login.reLogin(), 0) + executor.submit(() => login.reLogin(), 0) + + verifyRelogin(selector, login) + assertEquals(2, login.loginAttempts) + assertEquals(1, login.logoutAttempts) + } finally { + selector.close() + } + } + + /** + * Verifies that there are no authentication failures during Kerberos re-login. If authentication + * is performed when credentials are unavailable between logout and login, we handle it as a + * transient error and not an authentication failure so that clients may retry. + */ + @Test + def testReLogin(): Unit = { + val selector = createSelectorWithRelogin() + try { + val login = TestableKerberosLogin.instance + assertNotNull(login) + executor.submit(() => login.reLogin(), 0) + verifyRelogin(selector, login) + } finally { + selector.close() + } + } + + private def verifyRelogin(selector: Selector, login: TestableKerberosLogin): Unit = { + val node1 = "1" + selector.connect(node1, serverAddr, 1024, 1024) + login.logoutResumeLatch.countDown() + login.logoutCompleteLatch.await(15, TimeUnit.SECONDS) + assertFalse(pollUntilReadyOrDisconnected(selector, node1), "Authenticated during re-login") + + login.reLoginResumeLatch.countDown() + login.reLoginCompleteLatch.await(15, TimeUnit.SECONDS) + val node2 = "2" + selector.connect(node2, serverAddr, 1024, 1024) + assertTrue(pollUntilReadyOrDisconnected(selector, node2), "Authenticated failed after re-login") + } + + /** + * Tests that Kerberos error `Server not found in Kerberos database (7)` is handled + * as a fatal authentication failure. + */ + @Test + def testServerNotFoundInKerberosDatabase(): Unit = { + val jaasConfig = clientConfig.getProperty(SaslConfigs.SASL_JAAS_CONFIG) + val invalidServiceConfig = jaasConfig.replace("serviceName=\"kafka\"", "serviceName=\"invalid-service\"") + clientConfig.put(SaslConfigs.SASL_JAAS_CONFIG, invalidServiceConfig) + clientConfig.put(SaslConfigs.SASL_KERBEROS_SERVICE_NAME, "invalid-service") + verifyNonRetriableAuthenticationFailure() + } + + /** + * Test that when client fails to verify authenticity of the server, the resulting failed authentication exception + * is thrown immediately, and is not affected by connection.failed.authentication.delay.ms. 
+ */ + @Test + def testServerAuthenticationFailure(): Unit = { + // Setup client with a non-existent service principal, so that server authentication fails on the client + val clientLoginContext = jaasClientLoginModule(kafkaClientSaslMechanism, Some("another-kafka-service")) + val configOverrides = new Properties() + configOverrides.setProperty(SaslConfigs.SASL_JAAS_CONFIG, clientLoginContext) + val consumer = createConsumer(configOverrides = configOverrides) + consumer.assign(List(tp).asJava) + + val startMs = System.currentTimeMillis() + assertThrows(classOf[SaslAuthenticationException], () => consumer.poll(Duration.ofMillis(50))) + val endMs = System.currentTimeMillis() + require(endMs - startMs < failedAuthenticationDelayMs, "Failed authentication must not be delayed on the client") + consumer.close() + } + + /** + * Verifies that any exceptions during authentication with the current `clientConfig` are + * notified with disconnect state `AUTHENTICATE` (and not `AUTHENTICATION_FAILED`). This + * is to ensure that NetworkClient doesn't handle this as a fatal authentication failure, + * but as a transient I/O exception. So Producer/Consumer/AdminClient will retry + * any operation based on their configuration until timeout and will not propagate + * the exception to the application. + */ + private def verifyRetriableFailuresDuringAuthentication(numSuccessfulAuths: Int): Unit = { + val selector = createSelector() + try { + var actualSuccessfulAuths = 0 + while (actualSuccessfulAuths < numSuccessfulAuths) { + val nodeId = actualSuccessfulAuths.toString + selector.connect(nodeId, serverAddr, 1024, 1024) + val isReady = pollUntilReadyOrDisconnected(selector, nodeId) + if (isReady) + actualSuccessfulAuths += 1 + selector.close(nodeId) + } + } finally { + selector.close() + } + } + + private def pollUntilReadyOrDisconnected(selector: Selector, nodeId: String): Boolean = { + TestUtils.waitUntilTrue(() => { + selector.poll(100) + val disconnectState = selector.disconnected().get(nodeId) + // Verify that disconnect state is not AUTHENTICATION_FAILED + if (disconnectState != null) { + assertEquals(ChannelState.State.AUTHENTICATE, disconnectState.state(), + s"Authentication failed with exception ${disconnectState.exception()}") + } + selector.isChannelReady(nodeId) || disconnectState != null + }, "Client not ready or disconnected within timeout") + val isReady = selector.isChannelReady(nodeId) + selector.close(nodeId) + isReady + } + + /** + * Verifies that authentication with the current `clientConfig` results in disconnection and that + * the disconnection is notified with disconnect state `AUTHENTICATION_FAILED`. This is to ensure + * that NetworkClient handles this as a fatal authentication failure that is propagated to + * applications by Producer/Consumer/AdminClient without retrying and waiting for timeout. 
+ */ + private def verifyNonRetriableAuthenticationFailure(): Unit = { + val selector = createSelector() + val nodeId = "1" + selector.connect(nodeId, serverAddr, 1024, 1024) + TestUtils.waitUntilTrue(() => { + selector.poll(100) + val disconnectState = selector.disconnected().get(nodeId) + if (disconnectState != null) + assertEquals(ChannelState.State.AUTHENTICATION_FAILED, disconnectState.state()) + disconnectState != null + }, "Client not disconnected within timeout") + } + + private def createSelector(): Selector = { + val channelBuilder = ChannelBuilders.clientChannelBuilder(securityProtocol, + JaasContext.Type.CLIENT, new TestSecurityConfig(clientConfig), null, kafkaClientSaslMechanism, + time, true, new LogContext()) + NetworkTestUtils.createSelector(channelBuilder, time) + } + + private def createSelectorWithRelogin(): Selector = { + clientConfig.setProperty(SaslConfigs.SASL_KERBEROS_MIN_TIME_BEFORE_RELOGIN, "0") + val config = new TestSecurityConfig(clientConfig) + val jaasContexts = Collections.singletonMap("GSSAPI", JaasContext.loadClientContext(config.values())) + val channelBuilder = new SaslChannelBuilder(Mode.CLIENT, jaasContexts, securityProtocol, + null, false, kafkaClientSaslMechanism, true, null, null, null, time, new LogContext(), + () => ApiVersionsResponse.defaultApiVersionsResponse(ListenerType.ZK_BROKER)) { + override protected def defaultLoginClass(): Class[_ <: Login] = classOf[TestableKerberosLogin] + } + channelBuilder.configure(config.values()) + NetworkTestUtils.createSelector(channelBuilder, time) + } +} + +object TestableKerberosLogin { + @volatile var instance: TestableKerberosLogin = _ + def reset(): Unit = { + instance = null + } +} + +class TestableKerberosLogin extends KerberosLogin { + val logoutResumeLatch = new CountDownLatch(1) + val logoutCompleteLatch = new CountDownLatch(1) + val reLoginResumeLatch = new CountDownLatch(1) + val reLoginCompleteLatch = new CountDownLatch(1) + @volatile var loginException: Option[RuntimeException] = None + @volatile var loginAttempts = 0 + @volatile var logoutAttempts = 0 + + assertNull(TestableKerberosLogin.instance) + TestableKerberosLogin.instance = this + + override def reLogin(): Unit = { + super.reLogin() + reLoginCompleteLatch.countDown() + } + + override protected def login(loginContext: LoginContext): Unit = { + loginAttempts += 1 + loginException.foreach { e => + loginException = None + throw e + } + super.login(loginContext) + } + + override protected def logout(): Unit = { + logoutAttempts += 1 + logoutResumeLatch.await(15, TimeUnit.SECONDS) + super.logout() + logoutCompleteLatch.countDown() + reLoginResumeLatch.await(15, TimeUnit.SECONDS) + } +} diff --git a/core/src/test/scala/integration/kafka/server/IntegrationTestUtils.scala b/core/src/test/scala/integration/kafka/server/IntegrationTestUtils.scala new file mode 100644 index 0000000..203e181 --- /dev/null +++ b/core/src/test/scala/integration/kafka/server/IntegrationTestUtils.scala @@ -0,0 +1,145 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.server + +import java.io.{DataInputStream, DataOutputStream} +import java.net.Socket +import java.nio.ByteBuffer +import java.util.{Collections, Properties} + +import kafka.network.SocketServer +import kafka.utils.Implicits._ +import kafka.utils.{NotNothing, TestUtils} +import org.apache.kafka.clients.admin.{Admin, NewTopic} +import org.apache.kafka.common.network.{ListenerName, Mode} +import org.apache.kafka.common.protocol.ApiKeys +import org.apache.kafka.common.requests.{AbstractRequest, AbstractResponse, RequestHeader, ResponseHeader} +import org.apache.kafka.common.security.auth.SecurityProtocol +import org.apache.kafka.common.utils.Utils + +import scala.annotation.nowarn +import scala.jdk.CollectionConverters._ +import scala.reflect.ClassTag + +object IntegrationTestUtils { + + def sendRequest(socket: Socket, request: Array[Byte]): Unit = { + val outgoing = new DataOutputStream(socket.getOutputStream) + outgoing.writeInt(request.length) + outgoing.write(request) + outgoing.flush() + } + + private def sendWithHeader(request: AbstractRequest, header: RequestHeader, socket: Socket): Unit = { + val serializedBytes = Utils.toArray(request.serializeWithHeader(header)) + sendRequest(socket, serializedBytes) + } + + def nextRequestHeader[T <: AbstractResponse](apiKey: ApiKeys, + apiVersion: Short, + clientId: String = "client-id", + correlationIdOpt: Option[Int] = None): RequestHeader = { + val correlationId = correlationIdOpt.getOrElse { + this.correlationId += 1 + this.correlationId + } + new RequestHeader(apiKey, apiVersion, clientId, correlationId) + } + + def send(request: AbstractRequest, + socket: Socket, + clientId: String = "client-id", + correlationId: Option[Int] = None): Unit = { + val header = nextRequestHeader(request.apiKey, request.version, clientId, correlationId) + sendWithHeader(request, header, socket) + } + + def receive[T <: AbstractResponse](socket: Socket, apiKey: ApiKeys, version: Short) + (implicit classTag: ClassTag[T], @nowarn("cat=unused") nn: NotNothing[T]): T = { + val incoming = new DataInputStream(socket.getInputStream) + val len = incoming.readInt() + + val responseBytes = new Array[Byte](len) + incoming.readFully(responseBytes) + + val responseBuffer = ByteBuffer.wrap(responseBytes) + ResponseHeader.parse(responseBuffer, apiKey.responseHeaderVersion(version)) + + AbstractResponse.parseResponse(apiKey, responseBuffer, version) match { + case response: T => response + case response => + throw new ClassCastException(s"Expected response with type ${classTag.runtimeClass}, but found ${response.getClass}") + } + } + + def sendAndReceive[T <: AbstractResponse](request: AbstractRequest, + socket: Socket, + clientId: String = "client-id", + correlationId: Option[Int] = None) + (implicit classTag: ClassTag[T], nn: NotNothing[T]): T = { + send(request, socket, clientId, correlationId) + receive[T](socket, request.apiKey, request.version) + } + + def connectAndReceive[T <: AbstractResponse](request: AbstractRequest, + destination: SocketServer, + listenerName: ListenerName) + (implicit classTag: ClassTag[T], nn: NotNothing[T]): T = { 
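+    // Open a short-lived connection to the given listener, perform a single request/response
+    // round trip over the raw Kafka wire protocol, and always close the socket afterwards.
+    // Callers name the expected response type explicitly, e.g.
+    // connectAndReceive[DescribeClusterResponse](request, destination, listenerName).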
+ val socket = connect(destination, listenerName) + try sendAndReceive[T](request, socket) + finally socket.close() + } + + def createTopic( + admin: Admin, + topic: String, + numPartitions: Int, + replicationFactor: Short + ): Unit = { + val newTopics = Collections.singletonList(new NewTopic(topic, numPartitions, replicationFactor)) + val createTopicResult = admin.createTopics(newTopics) + createTopicResult.all().get() + } + + def createTopic( + admin: Admin, + topic: String, + replicaAssignment: Map[Int, Seq[Int]] + ): Unit = { + val javaAssignment = new java.util.HashMap[Integer, java.util.List[Integer]]() + replicaAssignment.forKeyValue { (partitionId, assignment) => + javaAssignment.put(partitionId, assignment.map(Int.box).asJava) + } + val newTopic = new NewTopic(topic, javaAssignment) + val newTopics = Collections.singletonList(newTopic) + val createTopicResult = admin.createTopics(newTopics) + createTopicResult.all().get() + } + + protected def securityProtocol: SecurityProtocol = SecurityProtocol.PLAINTEXT + private var correlationId = 0 + + def connect(socketServer: SocketServer, + listenerName: ListenerName): Socket = { + new Socket("localhost", socketServer.boundPort(listenerName)) + } + + def clientSecurityProps(certAlias: String): Properties = { + TestUtils.securityConfigs(Mode.CLIENT, securityProtocol, None, certAlias, TestUtils.SslCertificateCn, None) // TODO use real trust store and client SASL properties + } +} diff --git a/core/src/test/scala/integration/kafka/server/KRaftClusterTest.scala b/core/src/test/scala/integration/kafka/server/KRaftClusterTest.scala new file mode 100644 index 0000000..25e79ad --- /dev/null +++ b/core/src/test/scala/integration/kafka/server/KRaftClusterTest.scala @@ -0,0 +1,462 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package kafka.server + +import kafka.network.SocketServer +import kafka.server.IntegrationTestUtils.connectAndReceive +import kafka.testkit.{BrokerNode, KafkaClusterTestKit, TestKitNodes} +import kafka.utils.TestUtils +import org.apache.kafka.clients.admin.{Admin, NewPartitionReassignment, NewTopic} +import org.apache.kafka.common.{TopicPartition, TopicPartitionInfo} +import org.apache.kafka.common.message.DescribeClusterRequestData +import org.apache.kafka.common.network.ListenerName +import org.apache.kafka.common.quota.{ClientQuotaAlteration, ClientQuotaEntity, ClientQuotaFilter, ClientQuotaFilterComponent} +import org.apache.kafka.common.requests.{DescribeClusterRequest, DescribeClusterResponse} +import org.apache.kafka.metadata.BrokerState +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.{Tag, Test, Timeout} + +import java.util +import java.util.{Arrays, Collections, Optional} +import scala.collection.mutable +import scala.concurrent.duration.{FiniteDuration, MILLISECONDS, SECONDS} +import scala.jdk.CollectionConverters._ + +@Timeout(120) +@Tag("integration") +class KRaftClusterTest { + + @Test + def testCreateClusterAndClose(): Unit = { + val cluster = new KafkaClusterTestKit.Builder( + new TestKitNodes.Builder(). + setNumBrokerNodes(1). + setNumControllerNodes(1).build()).build() + try { + cluster.format() + cluster.startup() + } finally { + cluster.close() + } + } + + @Test + def testCreateClusterAndWaitForBrokerInRunningState(): Unit = { + val cluster = new KafkaClusterTestKit.Builder( + new TestKitNodes.Builder(). + setNumBrokerNodes(3). + setNumControllerNodes(3).build()).build() + try { + cluster.format() + cluster.startup() + TestUtils.waitUntilTrue(() => cluster.brokers().get(0).brokerState == BrokerState.RUNNING, + "Broker never made it to RUNNING state.") + TestUtils.waitUntilTrue(() => cluster.raftManagers().get(0).client.leaderAndEpoch().leaderId.isPresent, + "RaftManager was not initialized.") + val admin = Admin.create(cluster.clientProperties()) + try { + assertEquals(cluster.nodes().clusterId().toString, + admin.describeCluster().clusterId().get()) + } finally { + admin.close() + } + } finally { + cluster.close() + } + } + + @Test + def testCreateClusterAndCreateListDeleteTopic(): Unit = { + val cluster = new KafkaClusterTestKit.Builder( + new TestKitNodes.Builder(). + setNumBrokerNodes(3). + setNumControllerNodes(3).build()).build() + try { + cluster.format() + cluster.startup() + cluster.waitForReadyBrokers() + TestUtils.waitUntilTrue(() => cluster.brokers().get(0).brokerState == BrokerState.RUNNING, + "Broker never made it to RUNNING state.") + TestUtils.waitUntilTrue(() => cluster.raftManagers().get(0).client.leaderAndEpoch().leaderId.isPresent, + "RaftManager was not initialized.") + + val admin = Admin.create(cluster.clientProperties()) + try { + // Create a test topic + val newTopic = Collections.singletonList(new NewTopic("test-topic", 1, 3.toShort)) + val createTopicResult = admin.createTopics(newTopic) + createTopicResult.all().get() + waitForTopicListing(admin, Seq("test-topic"), Seq()) + + // Delete topic + val deleteResult = admin.deleteTopics(Collections.singletonList("test-topic")) + deleteResult.all().get() + + // List again + waitForTopicListing(admin, Seq(), Seq("test-topic")) + } finally { + admin.close() + } + } finally { + cluster.close() + } + } + + @Test + def testCreateClusterAndCreateAndManyTopics(): Unit = { + val cluster = new KafkaClusterTestKit.Builder( + new TestKitNodes.Builder(). + setNumBrokerNodes(3). 
+ setNumControllerNodes(3).build()).build() + try { + cluster.format() + cluster.startup() + cluster.waitForReadyBrokers() + TestUtils.waitUntilTrue(() => cluster.brokers().get(0).brokerState == BrokerState.RUNNING, + "Broker never made it to RUNNING state.") + TestUtils.waitUntilTrue(() => cluster.raftManagers().get(0).client.leaderAndEpoch().leaderId.isPresent, + "RaftManager was not initialized.") + val admin = Admin.create(cluster.clientProperties()) + try { + // Create many topics + val newTopic = new util.ArrayList[NewTopic]() + newTopic.add(new NewTopic("test-topic-1", 2, 3.toShort)) + newTopic.add(new NewTopic("test-topic-2", 2, 3.toShort)) + newTopic.add(new NewTopic("test-topic-3", 2, 3.toShort)) + val createTopicResult = admin.createTopics(newTopic) + createTopicResult.all().get() + + // List created topics + waitForTopicListing(admin, Seq("test-topic-1", "test-topic-2", "test-topic-3"), Seq()) + } finally { + admin.close() + } + } finally { + cluster.close() + } + } + + @Test + def testClientQuotas(): Unit = { + val cluster = new KafkaClusterTestKit.Builder( + new TestKitNodes.Builder(). + setNumBrokerNodes(1). + setNumControllerNodes(1).build()).build() + try { + cluster.format() + cluster.startup() + TestUtils.waitUntilTrue(() => cluster.brokers().get(0).brokerState == BrokerState.RUNNING, + "Broker never made it to RUNNING state.") + val admin = Admin.create(cluster.clientProperties()) + try { + val entity = new ClientQuotaEntity(Map("user" -> "testkit").asJava) + var filter = ClientQuotaFilter.containsOnly( + List(ClientQuotaFilterComponent.ofEntity("user", "testkit")).asJava) + + def alterThenDescribe(entity: ClientQuotaEntity, + quotas: Seq[ClientQuotaAlteration.Op], + filter: ClientQuotaFilter, + expectCount: Int): java.util.Map[ClientQuotaEntity, java.util.Map[String, java.lang.Double]] = { + val alterResult = admin.alterClientQuotas(Seq(new ClientQuotaAlteration(entity, quotas.asJava)).asJava) + try { + alterResult.all().get() + } catch { + case t: Throwable => fail("AlterClientQuotas request failed", t) + } + + def describeOrFail(filter: ClientQuotaFilter): java.util.Map[ClientQuotaEntity, java.util.Map[String, java.lang.Double]] = { + try { + admin.describeClientQuotas(filter).entities().get() + } catch { + case t: Throwable => fail("DescribeClientQuotas request failed", t) + } + } + + val (describeResult, ok) = TestUtils.computeUntilTrue(describeOrFail(filter)) { + results => results.getOrDefault(entity, java.util.Collections.emptyMap[String, java.lang.Double]()).size() == expectCount + } + assertTrue(ok, "Broker never saw new client quotas") + describeResult + } + + var describeResult = alterThenDescribe(entity, + Seq(new ClientQuotaAlteration.Op("request_percentage", 0.99)), filter, 1) + assertEquals(0.99, describeResult.get(entity).get("request_percentage"), 1e-6) + + describeResult = alterThenDescribe(entity, Seq( + new ClientQuotaAlteration.Op("request_percentage", 0.97), + new ClientQuotaAlteration.Op("producer_byte_rate", 10000), + new ClientQuotaAlteration.Op("consumer_byte_rate", 10001) + ), filter, 3) + assertEquals(0.97, describeResult.get(entity).get("request_percentage"), 1e-6) + assertEquals(10000.0, describeResult.get(entity).get("producer_byte_rate"), 1e-6) + assertEquals(10001.0, describeResult.get(entity).get("consumer_byte_rate"), 1e-6) + + describeResult = alterThenDescribe(entity, Seq( + new ClientQuotaAlteration.Op("request_percentage", 0.95), + new ClientQuotaAlteration.Op("producer_byte_rate", null), + new 
ClientQuotaAlteration.Op("consumer_byte_rate", null) + ), filter, 1) + assertEquals(0.95, describeResult.get(entity).get("request_percentage"), 1e-6) + + describeResult = alterThenDescribe(entity, Seq( + new ClientQuotaAlteration.Op("request_percentage", null)), filter, 0) + + describeResult = alterThenDescribe(entity, + Seq(new ClientQuotaAlteration.Op("producer_byte_rate", 9999)), filter, 1) + assertEquals(9999.0, describeResult.get(entity).get("producer_byte_rate"), 1e-6) + + // Add another quota for a different entity with same user part + val entity2 = new ClientQuotaEntity(Map("user" -> "testkit", "client-id" -> "some-client").asJava) + filter = ClientQuotaFilter.containsOnly( + List( + ClientQuotaFilterComponent.ofEntity("user", "testkit"), + ClientQuotaFilterComponent.ofEntity("client-id", "some-client"), + ).asJava) + describeResult = alterThenDescribe(entity2, + Seq(new ClientQuotaAlteration.Op("producer_byte_rate", 9998)), filter, 1) + assertEquals(9998.0, describeResult.get(entity2).get("producer_byte_rate"), 1e-6) + + // non-strict match + filter = ClientQuotaFilter.contains( + List(ClientQuotaFilterComponent.ofEntity("user", "testkit")).asJava) + + TestUtils.tryUntilNoAssertionError(){ + val results = admin.describeClientQuotas(filter).entities().get() + assertEquals(2, results.size(), "Broker did not see two client quotas") + assertEquals(9999.0, results.get(entity).get("producer_byte_rate"), 1e-6) + assertEquals(9998.0, results.get(entity2).get("producer_byte_rate"), 1e-6) + } + } finally { + admin.close() + } + } finally { + cluster.close() + } + } + + @Test + def testCreateClusterWithAdvertisedPortZero(): Unit = { + val brokerPropertyOverrides: (TestKitNodes, BrokerNode) => Map[String, String] = (nodes, _) => Map( + (KafkaConfig.ListenersProp, s"${nodes.externalListenerName.value}://localhost:0"), + (KafkaConfig.AdvertisedListenersProp, s"${nodes.externalListenerName.value}://localhost:0")) + + doOnStartedKafkaCluster(numBrokerNodes = 3, brokerPropertyOverrides = brokerPropertyOverrides) { implicit cluster => + sendDescribeClusterRequestToBoundPortUntilAllBrokersPropagated(cluster.nodes.externalListenerName, (15L, SECONDS)) + .nodes.values.forEach { broker => + assertEquals("localhost", broker.host, + "Did not advertise configured advertised host") + assertEquals(cluster.brokers.get(broker.id).socketServer.boundPort(cluster.nodes.externalListenerName), broker.port, + "Did not advertise bound socket port") + } + } + } + + @Test + def testCreateClusterWithAdvertisedHostAndPortDifferentFromSocketServer(): Unit = { + val brokerPropertyOverrides: (TestKitNodes, BrokerNode) => Map[String, String] = (nodes, broker) => Map( + (KafkaConfig.ListenersProp, s"${nodes.externalListenerName.value}://localhost:0"), + (KafkaConfig.AdvertisedListenersProp, s"${nodes.externalListenerName.value}://advertised-host-${broker.id}:${broker.id + 100}")) + + doOnStartedKafkaCluster(numBrokerNodes = 3, brokerPropertyOverrides = brokerPropertyOverrides) { implicit cluster => + sendDescribeClusterRequestToBoundPortUntilAllBrokersPropagated(cluster.nodes.externalListenerName, (15L, SECONDS)) + .nodes.values.forEach { broker => + assertEquals(s"advertised-host-${broker.id}", broker.host, "Did not advertise configured advertised host") + assertEquals(broker.id + 100, broker.port, "Did not advertise configured advertised port") + } + } + } + + private def doOnStartedKafkaCluster(numControllerNodes: Int = 1, + numBrokerNodes: Int, + brokerPropertyOverrides: (TestKitNodes, BrokerNode) => Map[String, 
String]) + (action: KafkaClusterTestKit => Unit): Unit = { + val nodes = new TestKitNodes.Builder() + .setNumControllerNodes(numControllerNodes) + .setNumBrokerNodes(numBrokerNodes) + .build() + nodes.brokerNodes.values.forEach { + broker => broker.propertyOverrides.putAll(brokerPropertyOverrides(nodes, broker).asJava) + } + val cluster = new KafkaClusterTestKit.Builder(nodes).build() + try { + cluster.format() + cluster.startup() + action(cluster) + } finally { + cluster.close() + } + } + + private def sendDescribeClusterRequestToBoundPortUntilAllBrokersPropagated(listenerName: ListenerName, + waitTime: FiniteDuration) + (implicit cluster: KafkaClusterTestKit): DescribeClusterResponse = { + val startTime = System.currentTimeMillis + val runningBrokerServers = waitForRunningBrokers(1, waitTime) + val remainingWaitTime = waitTime - (System.currentTimeMillis - startTime, MILLISECONDS) + sendDescribeClusterRequestToBoundPortUntilBrokersPropagated( + runningBrokerServers.head, listenerName, + cluster.nodes.brokerNodes.size, remainingWaitTime) + } + + private def waitForRunningBrokers(count: Int, waitTime: FiniteDuration) + (implicit cluster: KafkaClusterTestKit): Seq[BrokerServer] = { + def getRunningBrokerServers: Seq[BrokerServer] = cluster.brokers.values.asScala.toSeq + .filter(brokerServer => brokerServer.brokerState == BrokerState.RUNNING) + + val (runningBrokerServers, hasRunningBrokers) = TestUtils.computeUntilTrue(getRunningBrokerServers, waitTime.toMillis)(_.nonEmpty) + assertTrue(hasRunningBrokers, + s"After ${waitTime.toMillis} ms at least $count broker(s) should be in RUNNING state, " + + s"but only ${runningBrokerServers.size} broker(s) are.") + runningBrokerServers + } + + private def sendDescribeClusterRequestToBoundPortUntilBrokersPropagated(destination: BrokerServer, + listenerName: ListenerName, + expectedBrokerCount: Int, + waitTime: FiniteDuration): DescribeClusterResponse = { + val (describeClusterResponse, metadataUpToDate) = TestUtils.computeUntilTrue( + compute = sendDescribeClusterRequestToBoundPort(destination.socketServer, listenerName), + waitTime = waitTime.toMillis + ) { + response => response.nodes.size == expectedBrokerCount + } + + assertTrue(metadataUpToDate, + s"After ${waitTime.toMillis} ms Broker is only aware of ${describeClusterResponse.nodes.size} brokers, " + + s"but $expectedBrokerCount are expected.") + + describeClusterResponse + } + + private def sendDescribeClusterRequestToBoundPort(destination: SocketServer, + listenerName: ListenerName): DescribeClusterResponse = + connectAndReceive[DescribeClusterResponse]( + request = new DescribeClusterRequest.Builder(new DescribeClusterRequestData()).build(), + destination = destination, + listenerName = listenerName + ) + + @Test + def testCreateClusterAndPerformReassignment(): Unit = { + val cluster = new KafkaClusterTestKit.Builder( + new TestKitNodes.Builder(). + setNumBrokerNodes(4). + setNumControllerNodes(3).build()).build() + try { + cluster.format() + cluster.startup() + cluster.waitForReadyBrokers() + val admin = Admin.create(cluster.clientProperties()) + try { + // Create the topic. 
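+        // Explicit replica assignment: each partition id maps to an ordered list of broker ids,
+        // with the first entry acting as the preferred leader (this cluster has brokers 0-3).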
+ val assignments = new util.HashMap[Integer, util.List[Integer]] + assignments.put(0, Arrays.asList(0, 1, 2)) + assignments.put(1, Arrays.asList(1, 2, 3)) + assignments.put(2, Arrays.asList(2, 3, 0)) + assignments.put(3, Arrays.asList(3, 2, 1)) + val createTopicResult = admin.createTopics(Collections.singletonList( + new NewTopic("foo", assignments))) + createTopicResult.all().get() + waitForTopicListing(admin, Seq("foo"), Seq()) + + // Start some reassignments. + assertEquals(Collections.emptyMap(), admin.listPartitionReassignments().reassignments().get()) + val reassignments = new util.HashMap[TopicPartition, Optional[NewPartitionReassignment]] + reassignments.put(new TopicPartition("foo", 0), + Optional.of(new NewPartitionReassignment(Arrays.asList(2, 1, 0)))) + reassignments.put(new TopicPartition("foo", 1), + Optional.of(new NewPartitionReassignment(Arrays.asList(0, 1, 2)))) + reassignments.put(new TopicPartition("foo", 2), + Optional.of(new NewPartitionReassignment(Arrays.asList(2, 3)))) + reassignments.put(new TopicPartition("foo", 3), + Optional.of(new NewPartitionReassignment(Arrays.asList(3, 2, 0, 1)))) + admin.alterPartitionReassignments(reassignments).all().get() + TestUtils.waitUntilTrue( + () => admin.listPartitionReassignments().reassignments().get().isEmpty(), + "The reassignment never completed.") + var currentMapping: Seq[Seq[Int]] = Seq() + val expectedMapping = Seq(Seq(2, 1, 0), Seq(0, 1, 2), Seq(2, 3), Seq(3, 2, 0, 1)) + TestUtils.waitUntilTrue( () => { + val topicInfoMap = admin.describeTopics(Collections.singleton("foo")).allTopicNames().get() + if (topicInfoMap.containsKey("foo")) { + currentMapping = translatePartitionInfoToSeq(topicInfoMap.get("foo").partitions()) + expectedMapping.equals(currentMapping) + } else { + false + } + }, "Timed out waiting for replica assignments for topic foo. " + + s"Wanted: ${expectedMapping}. 
Got: ${currentMapping}") + + checkReplicaManager( + cluster, + List( + (0, List(true, true, false, true)), + (1, List(true, true, false, true)), + (2, List(true, true, true, true)), + (3, List(false, false, true, true)) + ) + ) + } finally { + admin.close() + } + } finally { + cluster.close() + } + } + + private def checkReplicaManager(cluster: KafkaClusterTestKit, expectedHosting: List[(Int, List[Boolean])]): Unit = { + for ((brokerId, partitionsIsHosted) <- expectedHosting) { + val broker = cluster.brokers().get(brokerId) + + for ((isHosted, partitionId) <- partitionsIsHosted.zipWithIndex) { + val topicPartition = new TopicPartition("foo", partitionId) + if (isHosted) { + assertNotEquals( + HostedPartition.None, + broker.replicaManager.getPartition(topicPartition), + s"topicPartition = $topicPartition" + ) + } else { + assertEquals( + HostedPartition.None, + broker.replicaManager.getPartition(topicPartition), + s"topicPartition = $topicPartition" + ) + } + } + } + } + + private def translatePartitionInfoToSeq(partitions: util.List[TopicPartitionInfo]): Seq[Seq[Int]] = { + partitions.asScala.map(partition => partition.replicas().asScala.map(_.id()).toSeq).toSeq + } + + private def waitForTopicListing(admin: Admin, + expectedPresent: Seq[String], + expectedAbsent: Seq[String]): Unit = { + val topicsNotFound = new util.HashSet[String] + var extraTopics: mutable.Set[String] = null + expectedPresent.foreach(topicsNotFound.add(_)) + TestUtils.waitUntilTrue(() => { + admin.listTopics().names().get().forEach(name => topicsNotFound.remove(name)) + extraTopics = admin.listTopics().names().get().asScala.filter(expectedAbsent.contains(_)) + topicsNotFound.isEmpty && extraTopics.isEmpty + }, s"Failed to find topic(s): ${topicsNotFound.asScala} and NOT find topic(s): ${extraTopics}") + } +} diff --git a/core/src/test/scala/integration/kafka/server/MetadataRequestBetweenDifferentIbpTest.scala b/core/src/test/scala/integration/kafka/server/MetadataRequestBetweenDifferentIbpTest.scala new file mode 100644 index 0000000..387c244 --- /dev/null +++ b/core/src/test/scala/integration/kafka/server/MetadataRequestBetweenDifferentIbpTest.scala @@ -0,0 +1,96 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package kafka.server + +import kafka.api.{ApiVersion, KAFKA_2_8_IV0} +import kafka.network.SocketServer +import kafka.utils.TestUtils +import kafka.zk.ZkVersion +import org.apache.kafka.common.Uuid +import org.apache.kafka.common.message.MetadataRequestData +import org.apache.kafka.common.protocol.Errors +import org.apache.kafka.common.requests.{MetadataRequest, MetadataResponse} +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.Test + +import scala.collection.{Map, Seq} + +class MetadataRequestBetweenDifferentIbpTest extends BaseRequestTest { + + override def brokerCount: Int = 3 + override def generateConfigs: Seq[KafkaConfig] = { + Seq( + createConfig(0, KAFKA_2_8_IV0), + createConfig(1, ApiVersion.latestVersion), + createConfig(2, ApiVersion.latestVersion) + ) + } + + @Test + def testUnknownTopicId(): Unit = { + val topic = "topic" + + // Kill controller and restart until broker with latest ibp become controller + ensureControllerIn(Seq(1, 2)) + createTopic(topic, Map(0 -> Seq(1, 2, 0), 1 -> Seq(2, 0, 1))) + + val resp1 = sendMetadataRequest(new MetadataRequest(requestData(topic, Uuid.ZERO_UUID), 12.toShort), controllerSocketServer) + val topicId = resp1.topicMetadata.iterator().next().topicId() + + // We could still get topic metadata by topicId + val topicMetadata = sendMetadataRequest(new MetadataRequest(requestData(null, topicId), 12.toShort), controllerSocketServer) + .topicMetadata.iterator().next() + assertEquals(topicId, topicMetadata.topicId()) + assertEquals(topic, topicMetadata.topic()) + + // Make the broker whose version=KAFKA_2_8_IV0 controller + ensureControllerIn(Seq(0)) + + // Restart the broker whose ibp is higher, and the controller will send metadata request to it + killBroker(1) + restartDeadBrokers() + + // Send request to a broker whose ibp is higher and restarted just now + val resp2 = sendMetadataRequest(new MetadataRequest(requestData(topic, topicId), 12.toShort), brokerSocketServer(1)) + assertEquals(Errors.UNKNOWN_TOPIC_ID, resp2.topicMetadata.iterator().next().error()) + } + + private def ensureControllerIn(brokerIds: Seq[Int]): Unit = { + while (!brokerIds.contains(controllerSocketServer.config.brokerId)) { + zkClient.deleteController(ZkVersion.MatchAnyVersion) + TestUtils.waitUntilControllerElected(zkClient) + } + } + + private def createConfig(nodeId: Int,interBrokerVersion: ApiVersion): KafkaConfig = { + val props = TestUtils.createBrokerConfig(nodeId, zkConnect) + props.put(KafkaConfig.InterBrokerProtocolVersionProp, interBrokerVersion.version) + KafkaConfig.fromProps(props) + } + + def requestData(topic: String, topicId: Uuid): MetadataRequestData = { + val data = new MetadataRequestData + data.topics.add(new MetadataRequestData.MetadataRequestTopic().setName(topic).setTopicId(topicId)) + data + } + + private def sendMetadataRequest(request: MetadataRequest, destination: SocketServer): MetadataResponse = { + connectAndReceive[MetadataResponse](request, destination) + } + +} diff --git a/core/src/test/scala/integration/kafka/server/MultipleListenersWithAdditionalJaasContextTest.scala b/core/src/test/scala/integration/kafka/server/MultipleListenersWithAdditionalJaasContextTest.scala new file mode 100644 index 0000000..39aa3c3 --- /dev/null +++ b/core/src/test/scala/integration/kafka/server/MultipleListenersWithAdditionalJaasContextTest.scala @@ -0,0 +1,45 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package kafka.server + +import java.util.Properties + +import scala.collection.Seq + +import kafka.utils.JaasTestUtils +import kafka.utils.JaasTestUtils.JaasSection + +class MultipleListenersWithAdditionalJaasContextTest extends MultipleListenersWithSameSecurityProtocolBaseTest { + + import MultipleListenersWithSameSecurityProtocolBaseTest._ + + override def staticJaasSections: Seq[JaasSection] = { + val (serverKeytabFile, _) = maybeCreateEmptyKeytabFiles() + JaasTestUtils.zkSections :+ + JaasTestUtils.kafkaServerSection("secure_external.KafkaServer", kafkaServerSaslMechanisms(SecureExternal), Some(serverKeytabFile)) + } + + override protected def dynamicJaasSections: Properties = { + val props = new Properties + kafkaServerSaslMechanisms(SecureInternal).foreach { mechanism => + addDynamicJaasSection(props, SecureInternal, mechanism, + JaasTestUtils.kafkaServerSection("secure_internal.KafkaServer", Seq(mechanism), None)) + } + props + } +} diff --git a/core/src/test/scala/integration/kafka/server/MultipleListenersWithDefaultJaasContextTest.scala b/core/src/test/scala/integration/kafka/server/MultipleListenersWithDefaultJaasContextTest.scala new file mode 100644 index 0000000..23746c4 --- /dev/null +++ b/core/src/test/scala/integration/kafka/server/MultipleListenersWithDefaultJaasContextTest.scala @@ -0,0 +1,34 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package kafka.server + +import java.util.Properties + +import scala.collection.Seq + +import kafka.api.Both +import kafka.utils.JaasTestUtils.JaasSection + +class MultipleListenersWithDefaultJaasContextTest extends MultipleListenersWithSameSecurityProtocolBaseTest { + + override def staticJaasSections: Seq[JaasSection] = + jaasSections(kafkaServerSaslMechanisms.values.flatMap(identity).toSeq, Some(kafkaClientSaslMechanism), Both) + + override protected def dynamicJaasSections: Properties = new Properties + +} diff --git a/core/src/test/scala/integration/kafka/server/MultipleListenersWithSameSecurityProtocolBaseTest.scala b/core/src/test/scala/integration/kafka/server/MultipleListenersWithSameSecurityProtocolBaseTest.scala new file mode 100644 index 0000000..1d865f9 --- /dev/null +++ b/core/src/test/scala/integration/kafka/server/MultipleListenersWithSameSecurityProtocolBaseTest.scala @@ -0,0 +1,191 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.server + +import java.io.File +import java.util.{Collections, Objects, Properties} +import java.util.concurrent.TimeUnit + +import kafka.api.SaslSetup +import kafka.coordinator.group.OffsetConfig +import kafka.utils.JaasTestUtils.JaasSection +import kafka.utils.{JaasTestUtils, TestUtils} +import kafka.utils.Implicits._ +import kafka.server.QuorumTestHarness +import org.apache.kafka.clients.consumer.KafkaConsumer +import org.apache.kafka.clients.producer.{KafkaProducer, ProducerRecord} +import org.apache.kafka.common.config.SslConfigs +import org.apache.kafka.common.internals.Topic +import org.apache.kafka.common.network.{ListenerName, Mode} +import org.junit.jupiter.api.Assertions.assertEquals +import org.junit.jupiter.api.{AfterEach, BeforeEach, Test, TestInfo} + +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer +import scala.collection.Seq + +object MultipleListenersWithSameSecurityProtocolBaseTest { + val SecureInternal = "SECURE_INTERNAL" + val SecureExternal = "SECURE_EXTERNAL" + val Internal = "INTERNAL" + val External = "EXTERNAL" + val GssApi = "GSSAPI" + val Plain = "PLAIN" +} + +abstract class MultipleListenersWithSameSecurityProtocolBaseTest extends QuorumTestHarness with SaslSetup { + + import MultipleListenersWithSameSecurityProtocolBaseTest._ + + private val trustStoreFile = File.createTempFile("truststore", ".jks") + private val servers = new ArrayBuffer[KafkaServer] + private val producers = mutable.Map[ClientMetadata, KafkaProducer[Array[Byte], Array[Byte]]]() + private val consumers = mutable.Map[ClientMetadata, KafkaConsumer[Array[Byte], Array[Byte]]]() + + protected val kafkaClientSaslMechanism = Plain + protected val kafkaServerSaslMechanisms = Map( + SecureExternal -> Seq("SCRAM-SHA-256", GssApi), + SecureInternal -> 
Seq(Plain, "SCRAM-SHA-512")) + + protected def staticJaasSections: Seq[JaasSection] + protected def dynamicJaasSections: Properties + + @BeforeEach + override def setUp(testInfo: TestInfo): Unit = { + startSasl(staticJaasSections) + super.setUp(testInfo) + // 2 brokers so that we can test that the data propagates correctly via UpdateMetadadaRequest + val numServers = 2 + + (0 until numServers).foreach { brokerId => + + val props = TestUtils.createBrokerConfig(brokerId, zkConnect, trustStoreFile = Some(trustStoreFile)) + // Ensure that we can support multiple listeners per security protocol and multiple security protocols + props.put(KafkaConfig.ListenersProp, s"$SecureInternal://localhost:0, $Internal://localhost:0, " + + s"$SecureExternal://localhost:0, $External://localhost:0") + props.put(KafkaConfig.ListenerSecurityProtocolMapProp, s"$Internal:PLAINTEXT, $SecureInternal:SASL_SSL," + + s"$External:PLAINTEXT, $SecureExternal:SASL_SSL") + props.put(KafkaConfig.InterBrokerListenerNameProp, Internal) + props.put(KafkaConfig.ZkEnableSecureAclsProp, "true") + props.put(KafkaConfig.SaslMechanismInterBrokerProtocolProp, kafkaClientSaslMechanism) + props.put(s"${new ListenerName(SecureInternal).configPrefix}${KafkaConfig.SaslEnabledMechanismsProp}", + kafkaServerSaslMechanisms(SecureInternal).mkString(",")) + props.put(s"${new ListenerName(SecureExternal).configPrefix}${KafkaConfig.SaslEnabledMechanismsProp}", + kafkaServerSaslMechanisms(SecureExternal).mkString(",")) + props.put(KafkaConfig.SaslKerberosServiceNameProp, "kafka") + props ++= dynamicJaasSections + + props ++= TestUtils.sslConfigs(Mode.SERVER, false, Some(trustStoreFile), s"server$brokerId") + + // set listener-specific configs and set an invalid path for the global config to verify that the overrides work + Seq(SecureInternal, SecureExternal).foreach { listenerName => + props.put(new ListenerName(listenerName).configPrefix + SslConfigs.SSL_KEYSTORE_LOCATION_CONFIG, + props.get(SslConfigs.SSL_KEYSTORE_LOCATION_CONFIG)) + } + props.put(SslConfigs.SSL_KEYSTORE_LOCATION_CONFIG, "invalid/file/path") + + servers += TestUtils.createServer(KafkaConfig.fromProps(props)) + } + + servers.map(_.config).foreach { config => + assertEquals(4, config.listeners.size, s"Unexpected listener count for broker ${config.brokerId}") + // KAFKA-5184 seems to show that this value can sometimes be PLAINTEXT, so verify it here + assertEquals(Internal, config.interBrokerListenerName.value, + s"Unexpected ${KafkaConfig.InterBrokerListenerNameProp} for broker ${config.brokerId}") + } + + TestUtils.createTopic(zkClient, Topic.GROUP_METADATA_TOPIC_NAME, OffsetConfig.DefaultOffsetsTopicNumPartitions, + replicationFactor = 2, servers, servers.head.groupCoordinator.offsetsTopicConfigs) + + createScramCredentials(zkConnect, JaasTestUtils.KafkaScramUser, JaasTestUtils.KafkaScramPassword) + + servers.head.config.listeners.foreach { endPoint => + val listenerName = endPoint.listenerName + + val trustStoreFile = + if (TestUtils.usesSslTransportLayer(endPoint.securityProtocol)) Some(this.trustStoreFile) + else None + + val bootstrapServers = TestUtils.bootstrapServers(servers, listenerName) + + def addProducerConsumer(listenerName: ListenerName, mechanism: String, saslProps: Option[Properties]): Unit = { + + val topic = s"${listenerName.value}${producers.size}" + TestUtils.createTopic(zkClient, topic, 2, 2, servers) + val clientMetadata = ClientMetadata(listenerName, mechanism, topic) + + producers(clientMetadata) = TestUtils.createProducer(bootstrapServers, acks = -1, + 
securityProtocol = endPoint.securityProtocol, trustStoreFile = trustStoreFile, saslProperties = saslProps) + + consumers(clientMetadata) = TestUtils.createConsumer(bootstrapServers, groupId = clientMetadata.toString, + securityProtocol = endPoint.securityProtocol, trustStoreFile = trustStoreFile, saslProperties = saslProps) + } + + if (TestUtils.usesSaslAuthentication(endPoint.securityProtocol)) { + kafkaServerSaslMechanisms(endPoint.listenerName.value).foreach { mechanism => + addProducerConsumer(listenerName, mechanism, Some(kafkaClientSaslProperties(mechanism, dynamicJaasConfig = true))) + } + } else { + addProducerConsumer(listenerName, "", saslProps = None) + } + } + } + + @AfterEach + override def tearDown(): Unit = { + producers.values.foreach(_.close()) + consumers.values.foreach(_.close()) + TestUtils.shutdownServers(servers) + super.tearDown() + closeSasl() + } + + /** + * Tests that we can produce and consume to/from all broker-defined listeners and security protocols. We produce + * with acks=-1 to ensure that replication is also working. + */ + @Test + def testProduceConsume(): Unit = { + producers.foreach { case (clientMetadata, producer) => + val producerRecords = (1 to 10).map(i => new ProducerRecord(clientMetadata.topic, s"key$i".getBytes, + s"value$i".getBytes)) + producerRecords.map(producer.send).map(_.get(10, TimeUnit.SECONDS)) + + val consumer = consumers(clientMetadata) + consumer.subscribe(Collections.singleton(clientMetadata.topic)) + TestUtils.consumeRecords(consumer, producerRecords.size) + } + } + + protected def addDynamicJaasSection(props: Properties, listener: String, mechanism: String, jaasSection: JaasSection): Unit = { + val listenerName = new ListenerName(listener) + val prefix = listenerName.saslMechanismConfigPrefix(mechanism) + val jaasConfig = jaasSection.modules.head.toString + props.put(s"${prefix}${KafkaConfig.SaslJaasConfigProp}", jaasConfig) + } + + case class ClientMetadata(val listenerName: ListenerName, val saslMechanism: String, topic: String) { + override def hashCode: Int = Objects.hash(listenerName, saslMechanism) + override def equals(obj: Any): Boolean = obj match { + case other: ClientMetadata => listenerName == other.listenerName && saslMechanism == other.saslMechanism && topic == other.topic + case _ => false + } + override def toString: String = s"${listenerName.value}:$saslMechanism:$topic" + } +} diff --git a/core/src/test/scala/integration/kafka/server/QuorumTestHarness.scala b/core/src/test/scala/integration/kafka/server/QuorumTestHarness.scala new file mode 100755 index 0000000..c520ffb --- /dev/null +++ b/core/src/test/scala/integration/kafka/server/QuorumTestHarness.scala @@ -0,0 +1,360 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package kafka.server + +import java.io.{ByteArrayOutputStream, File, PrintStream} +import java.net.InetSocketAddress +import java.util +import java.util.{Collections, Properties} +import java.util.concurrent.CompletableFuture + +import javax.security.auth.login.Configuration +import kafka.raft.KafkaRaftManager +import kafka.tools.StorageTool +import kafka.utils.{CoreUtils, Logging, TestInfoUtils, TestUtils} +import kafka.zk.{AdminZkClient, EmbeddedZookeeper, KafkaZkClient} +import org.apache.kafka.common.metrics.Metrics +import org.apache.kafka.common.{TopicPartition, Uuid} +import org.apache.kafka.common.security.JaasUtils +import org.apache.kafka.common.security.auth.SecurityProtocol +import org.apache.kafka.common.utils.{Exit, Time} +import org.apache.kafka.metadata.MetadataRecordSerde +import org.apache.kafka.raft.RaftConfig.{AddressSpec, InetAddressSpec} +import org.apache.kafka.server.common.ApiMessageAndVersion +import org.apache.zookeeper.client.ZKClientConfig +import org.apache.zookeeper.{WatchedEvent, Watcher, ZooKeeper} +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.{AfterAll, AfterEach, BeforeAll, BeforeEach, Tag, TestInfo} + +import scala.collection.{Seq, immutable} + +trait QuorumImplementation { + def createAndStartBroker(config: KafkaConfig, + time: Time): KafkaBroker + + def shutdown(): Unit +} + +class ZooKeeperQuorumImplementation(val zookeeper: EmbeddedZookeeper, + val zkClient: KafkaZkClient, + val adminZkClient: AdminZkClient, + val log: Logging) extends QuorumImplementation { + override def createAndStartBroker(config: KafkaConfig, + time: Time): KafkaBroker = { + val server = new KafkaServer(config, time, None, false) + server.startup() + server + } + + override def shutdown(): Unit = { + CoreUtils.swallow(zkClient.close(), log) + CoreUtils.swallow(zookeeper.shutdown(), log) + } +} + +class KRaftQuorumImplementation(val raftManager: KafkaRaftManager[ApiMessageAndVersion], + val controllerServer: ControllerServer, + val metadataDir: File, + val controllerQuorumVotersFuture: CompletableFuture[util.Map[Integer, AddressSpec]], + val clusterId: String, + val log: Logging) extends QuorumImplementation { + override def createAndStartBroker(config: KafkaConfig, + time: Time): KafkaBroker = { + val broker = new BrokerServer(config = config, + metaProps = new MetaProperties(clusterId, config.nodeId), + raftManager = raftManager, + time = time, + metrics = new Metrics(), + threadNamePrefix = Some("Broker%02d_".format(config.nodeId)), + initialOfflineDirs = Seq(), + controllerQuorumVotersFuture = controllerQuorumVotersFuture, + supportedFeatures = Collections.emptyMap()) + broker.startup() + broker + } + + override def shutdown(): Unit = { + CoreUtils.swallow(raftManager.shutdown(), log) + CoreUtils.swallow(controllerServer.shutdown(), log) + } +} + +@Tag("integration") +abstract class QuorumTestHarness extends Logging { + val zkConnectionTimeout = 10000 + val zkSessionTimeout = 15000 // Allows us to avoid ZK session expiration due to GC up to 2/3 * 15000ms = 10 secs + val zkMaxInFlightRequests = Int.MaxValue + + protected def zkAclsEnabled: Option[Boolean] = None + + /** + * When in KRaft mode, this test harness only support PLAINTEXT + */ + private val controllerListenerSecurityProtocol: SecurityProtocol = SecurityProtocol.PLAINTEXT + + protected def kraftControllerConfigs(): Seq[Properties] = { + Seq(new Properties()) + } + + private var implementation: QuorumImplementation = null + + def isKRaftTest(): Boolean = 
implementation.isInstanceOf[KRaftQuorumImplementation] + + def checkIsZKTest(): Unit = { + if (isKRaftTest()) { + throw new RuntimeException("This function can't be accessed when running the test " + + "in KRaft mode. ZooKeeper mode is required.") + } + } + + def checkIsKRaftTest(): Unit = { + if (!isKRaftTest()) { + throw new RuntimeException("This function can't be accessed when running the test " + + "in ZooKeeper mode. KRaft mode is required.") + } + } + + private def asZk(): ZooKeeperQuorumImplementation = { + checkIsZKTest() + implementation.asInstanceOf[ZooKeeperQuorumImplementation] + } + + private def asKRaft(): KRaftQuorumImplementation = { + checkIsKRaftTest() + implementation.asInstanceOf[KRaftQuorumImplementation] + } + + def zookeeper: EmbeddedZookeeper = asZk().zookeeper + + def zkClient: KafkaZkClient = asZk().zkClient + + def zkClientOrNull: KafkaZkClient = if (isKRaftTest()) null else asZk().zkClient + + def adminZkClient: AdminZkClient = asZk().adminZkClient + + def zkPort: Int = asZk().zookeeper.port + + def zkConnect: String = s"127.0.0.1:$zkPort" + + def zkConnectOrNull: String = if (isKRaftTest()) null else zkConnect + + def controllerServer: ControllerServer = asKRaft().controllerServer + + // Note: according to the junit documentation: "JUnit Jupiter does not guarantee the execution + // order of multiple @BeforeEach methods that are declared within a single test class or test + // interface." Therefore, if you have things you would like to do before each test case runs, it + // is best to override this function rather than declaring a new @BeforeEach function. + // That way you control the initialization order. + @BeforeEach + def setUp(testInfo: TestInfo): Unit = { + Exit.setExitProcedure((code, message) => { + try { + throw new RuntimeException(s"exit(${code}, ${message}) called!") + } catch { + case e: Throwable => error("test error", e) + throw e + } finally { + tearDown() + } + }) + Exit.setHaltProcedure((code, message) => { + try { + throw new RuntimeException(s"halt(${code}, ${message}) called!") + } catch { + case e: Throwable => error("test error", e) + throw e + } finally { + tearDown() + } + }) + val name = if (testInfo.getTestMethod().isPresent()) { + testInfo.getTestMethod().get().toString() + } else { + "[unspecified]" + } + if (TestInfoUtils.isKRaft(testInfo)) { + info(s"Running KRAFT test ${name}") + implementation = newKRaftQuorum(testInfo) + } else { + info(s"Running ZK test ${name}") + implementation = newZooKeeperQuorum() + } + } + + def createAndStartBroker(config: KafkaConfig, + time: Time = Time.SYSTEM): KafkaBroker = { + implementation.createAndStartBroker(config, + time) + } + + def shutdownZooKeeper(): Unit = asZk().shutdown() + + private def formatDirectories(directories: immutable.Seq[String], + metaProperties: MetaProperties): Unit = { + val stream = new ByteArrayOutputStream() + var out: PrintStream = null + try { + out = new PrintStream(stream) + if (StorageTool.formatCommand(out, directories, metaProperties, false) != 0) { + throw new RuntimeException(stream.toString()) + } + debug(s"Formatted storage directory(ies) ${directories}") + } finally { + if (out != null) out.close() + stream.close() + } + } + + private def newKRaftQuorum(testInfo: TestInfo): KRaftQuorumImplementation = { + val clusterId = Uuid.randomUuid().toString + val metadataDir = TestUtils.tempDir() + val metaProperties = new MetaProperties(clusterId, 0) + formatDirectories(immutable.Seq(metadataDir.getAbsolutePath()), metaProperties) + val controllerMetrics = new 
Metrics() + val propsList = kraftControllerConfigs() + if (propsList.size != 1) { + throw new RuntimeException("Only one KRaft controller is supported for now.") + } + val props = propsList(0) + props.setProperty(KafkaConfig.ProcessRolesProp, "controller") + props.setProperty(KafkaConfig.NodeIdProp, "1000") + props.setProperty(KafkaConfig.MetadataLogDirProp, metadataDir.getAbsolutePath()) + val proto = controllerListenerSecurityProtocol.toString() + props.setProperty(KafkaConfig.ListenerSecurityProtocolMapProp, s"${proto}:${proto}") + props.setProperty(KafkaConfig.ListenersProp, s"${proto}://localhost:0") + props.setProperty(KafkaConfig.ControllerListenerNamesProp, proto) + props.setProperty(KafkaConfig.QuorumVotersProp, "1000@localhost:0") + val config = new KafkaConfig(props) + val threadNamePrefix = "Controller_" + testInfo.getDisplayName + val controllerQuorumVotersFuture = new CompletableFuture[util.Map[Integer, AddressSpec]] + val raftManager = new KafkaRaftManager( + metaProperties = metaProperties, + config = config, + recordSerde = MetadataRecordSerde.INSTANCE, + topicPartition = new TopicPartition(KafkaRaftServer.MetadataTopic, 0), + topicId = KafkaRaftServer.MetadataTopicId, + time = Time.SYSTEM, + metrics = controllerMetrics, + threadNamePrefixOpt = Option(threadNamePrefix), + controllerQuorumVotersFuture = controllerQuorumVotersFuture) + var controllerServer: ControllerServer = null + try { + controllerServer = new ControllerServer( + metaProperties = metaProperties, + config = config, + raftManager = raftManager, + time = Time.SYSTEM, + metrics = controllerMetrics, + threadNamePrefix = Option(threadNamePrefix), + controllerQuorumVotersFuture = controllerQuorumVotersFuture) + controllerServer.socketServerFirstBoundPortFuture.whenComplete((port, e) => { + if (e != null) { + error("Error completing controller socket server future", e) + controllerQuorumVotersFuture.completeExceptionally(e) + } else { + controllerQuorumVotersFuture.complete(Collections.singletonMap(1000, + new InetAddressSpec(new InetSocketAddress("localhost", port)))) + } + }) + controllerServer.startup() + raftManager.startup() + controllerServer.startup() + } catch { + case e: Throwable => + CoreUtils.swallow(raftManager.shutdown(), this) + if (controllerServer != null) CoreUtils.swallow(controllerServer.shutdown(), this) + throw e + } + new KRaftQuorumImplementation(raftManager, + controllerServer, + metadataDir, + controllerQuorumVotersFuture, + clusterId, + this) + } + + private def newZooKeeperQuorum(): ZooKeeperQuorumImplementation = { + val zookeeper = new EmbeddedZookeeper() + var zkClient: KafkaZkClient = null + var adminZkClient: AdminZkClient = null + try { + zkClient = KafkaZkClient(s"127.0.0.1:${zookeeper.port}", + zkAclsEnabled.getOrElse(JaasUtils.isZkSaslEnabled), + zkSessionTimeout, + zkConnectionTimeout, + zkMaxInFlightRequests, + Time.SYSTEM, + name = "ZooKeeperTestHarness", + new ZKClientConfig) + adminZkClient = new AdminZkClient(zkClient) + } catch { + case t: Throwable => + CoreUtils.swallow(zookeeper.shutdown(), this) + if (zkClient != null) CoreUtils.swallow(zkClient.close(), this) + throw t + } + new ZooKeeperQuorumImplementation(zookeeper, + zkClient, + adminZkClient, + this) + } + + @AfterEach + def tearDown(): Unit = { + Exit.resetExitProcedure() + Exit.resetHaltProcedure() + if (implementation != null) { + implementation.shutdown() + } + Configuration.setConfiguration(null) + } + + // Trigger session expiry by reusing the session id in another client + def 
createZooKeeperClientToTriggerSessionExpiry(zooKeeper: ZooKeeper): ZooKeeper = { + val dummyWatcher = new Watcher { + override def process(event: WatchedEvent): Unit = {} + } + val anotherZkClient = new ZooKeeper(zkConnect, 1000, dummyWatcher, + zooKeeper.getSessionId, + zooKeeper.getSessionPasswd) + assertNull(anotherZkClient.exists("/nonexistent", false)) // Make sure new client works + anotherZkClient + } +} + +object QuorumTestHarness { + val ZkClientEventThreadSuffix = "-EventThread" + + /** + * Verify that a previous test that doesn't use QuorumTestHarness hasn't left behind an unexpected thread. + * This assumes that brokers, ZooKeeper clients, producers and consumers are not created in another @BeforeClass, + * which is true for core tests where this harness is used. + */ + @BeforeAll + def setUpClass(): Unit = { + TestUtils.verifyNoUnexpectedThreads("@BeforeAll") + } + + /** + * Verify that tests from the current test class using QuorumTestHarness haven't left behind an unexpected thread + */ + @AfterAll + def tearDownClass(): Unit = { + TestUtils.verifyNoUnexpectedThreads("@AfterAll") + } +} diff --git a/core/src/test/scala/integration/kafka/server/RaftClusterSnapshotTest.scala b/core/src/test/scala/integration/kafka/server/RaftClusterSnapshotTest.scala new file mode 100644 index 0000000..6cefccd --- /dev/null +++ b/core/src/test/scala/integration/kafka/server/RaftClusterSnapshotTest.scala @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package kafka.server + +import java.util.Collections +import kafka.testkit.KafkaClusterTestKit +import kafka.testkit.TestKitNodes +import kafka.utils.TestUtils +import org.apache.kafka.common.utils.BufferSupplier; +import org.apache.kafka.metadata.MetadataRecordSerde +import org.apache.kafka.snapshot.SnapshotReader +import org.junit.jupiter.api.Assertions.assertEquals +import org.junit.jupiter.api.Assertions.assertNotEquals +import org.junit.jupiter.api.Assertions.assertTrue +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.Timeout +import scala.jdk.CollectionConverters._ + +@Timeout(120) +class RaftClusterSnapshotTest { + + @Test + def testSnapshotsGenerated(): Unit = { + val numberOfBrokers = 3 + val numberOfControllers = 3 + val metadataSnapshotMaxNewRecordBytes = 100 + + TestUtils.resource( + new KafkaClusterTestKit + .Builder( + new TestKitNodes.Builder() + .setNumBrokerNodes(numberOfBrokers) + .setNumControllerNodes(numberOfControllers) + .build() + ) + .setConfigProp( + KafkaConfig.MetadataSnapshotMaxNewRecordBytesProp, + metadataSnapshotMaxNewRecordBytes.toString + ) + .build() + ) { cluster => + cluster.format() + cluster.startup() + + // Check that every controller and broker has a snapshot + TestUtils.waitUntilTrue( + () => { + cluster.raftManagers().asScala.forall { case (_, raftManager) => + raftManager.replicatedLog.latestSnapshotId.isPresent + } + }, + s"Expected for every controller and broker to generate a snapshot: ${ + cluster.raftManagers().asScala.map { case (id, raftManager) => + (id, raftManager.replicatedLog.latestSnapshotId) + } + }" + ) + + assertEquals(numberOfControllers + numberOfBrokers, cluster.raftManagers.size()) + + // For every controller and broker perform some sanity checks against the lastest snapshot + for ((_, raftManager) <- cluster.raftManagers().asScala) { + TestUtils.resource( + SnapshotReader.of( + raftManager.replicatedLog.latestSnapshot.get(), + new MetadataRecordSerde(), + BufferSupplier.create(), + 1 + ) + ) { snapshot => + // Check that the snapshot is non-empty + assertTrue(snapshot.hasNext()) + + // Check that we can read the entire snapshot + while (snapshot.hasNext()) { + val batch = snapshot.next() + assertTrue(batch.sizeInBytes > 0) + assertNotEquals(Collections.emptyList(), batch.records()) + } + } + } + } + } +} diff --git a/core/src/test/scala/integration/kafka/server/ScramServerStartupTest.scala b/core/src/test/scala/integration/kafka/server/ScramServerStartupTest.scala new file mode 100644 index 0000000..9190647 --- /dev/null +++ b/core/src/test/scala/integration/kafka/server/ScramServerStartupTest.scala @@ -0,0 +1,65 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package kafka.server + +import java.util.Collections + +import kafka.api.{IntegrationTestHarness, KafkaSasl, SaslSetup} +import kafka.utils._ +import kafka.zk.ConfigEntityChangeNotificationZNode +import org.apache.kafka.common.security.auth.SecurityProtocol +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.Test + +import scala.jdk.CollectionConverters._ + +/** + * Tests that there are no failed authentications during broker startup. This is to verify + * that SCRAM credentials are loaded by brokers before client connections can be made. + * For simplicity of testing, this test verifies authentications of controller connections. + */ +class ScramServerStartupTest extends IntegrationTestHarness with SaslSetup { + + override val brokerCount = 1 + + private val kafkaClientSaslMechanism = "SCRAM-SHA-256" + private val kafkaServerSaslMechanisms = Collections.singletonList("SCRAM-SHA-256").asScala + + override protected def securityProtocol = SecurityProtocol.SASL_PLAINTEXT + + override protected val serverSaslProperties = Some(kafkaServerSaslProperties(kafkaServerSaslMechanisms, kafkaClientSaslMechanism)) + override protected val clientSaslProperties = Some(kafkaClientSaslProperties(kafkaClientSaslMechanism)) + + override def configureSecurityBeforeServersStart(): Unit = { + super.configureSecurityBeforeServersStart() + zkClient.makeSurePersistentPathExists(ConfigEntityChangeNotificationZNode.path) + // Create credentials before starting brokers + createScramCredentials(zkConnect, JaasTestUtils.KafkaScramAdmin, JaasTestUtils.KafkaScramAdminPassword) + + startSasl(jaasSections(kafkaServerSaslMechanisms, Option(kafkaClientSaslMechanism), KafkaSasl)) + } + + @Test + def testAuthentications(): Unit = { + val successfulAuths = TestUtils.totalMetricValue(servers.head, "successful-authentication-total") + assertTrue(successfulAuths > 0, "No successful authentications") + val failedAuths = TestUtils.totalMetricValue(servers.head, "failed-authentication-total") + assertEquals(0, failedAuths) + } +} diff --git a/core/src/test/scala/integration/kafka/tools/MirrorMakerIntegrationTest.scala b/core/src/test/scala/integration/kafka/tools/MirrorMakerIntegrationTest.scala new file mode 100644 index 0000000..abcbebc --- /dev/null +++ b/core/src/test/scala/integration/kafka/tools/MirrorMakerIntegrationTest.scala @@ -0,0 +1,124 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
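ScramServerStartupTest above writes the SCRAM credentials to ZooKeeper before the brokers start, precisely because controller connections authenticate during startup. For comparison only, on an already-running cluster the same credentials could be created through the Admin API; this is a hedged sketch with placeholder bootstrap address, user, and password, not what the test does:

```scala
import java.util.{Collections, Properties}
import org.apache.kafka.clients.admin.{Admin, AdminClientConfig, ScramCredentialInfo, ScramMechanism, UserScramCredentialAlteration, UserScramCredentialUpsertion}

val props = new Properties()
props.put(AdminClientConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9092") // placeholder address
val admin = Admin.create(props)
try {
  val upsertion = new UserScramCredentialUpsertion(
    "admin",                                                     // example user
    new ScramCredentialInfo(ScramMechanism.SCRAM_SHA_256, 4096), // mechanism and iteration count
    "admin-secret"                                               // example password
  )
  admin.alterUserScramCredentials(
    Collections.singletonList[UserScramCredentialAlteration](upsertion)
  ).all().get()
} finally admin.close()
```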
+ */ +package kafka.tools + +import java.util.Properties +import java.util.concurrent.atomic.AtomicBoolean + +import scala.collection.Seq +import kafka.integration.KafkaServerTestHarness +import kafka.server.KafkaConfig +import kafka.tools.MirrorMaker.{ConsumerWrapper, MirrorMakerProducer, NoRecordsException} +import kafka.utils.TestUtils +import org.apache.kafka.clients.consumer.{ConsumerConfig, KafkaConsumer} +import org.apache.kafka.clients.producer.{ProducerConfig, ProducerRecord} +import org.apache.kafka.common.TopicPartition +import org.apache.kafka.common.errors.TimeoutException +import org.apache.kafka.common.serialization.{ByteArrayDeserializer, ByteArraySerializer} +import org.apache.kafka.common.utils.Exit +import org.junit.jupiter.api.{AfterEach, BeforeEach, Test, TestInfo} +import org.junit.jupiter.api.Assertions._ + +@deprecated(message = "Use the Connect-based MirrorMaker instead (aka MM2).", since = "3.0") +class MirrorMakerIntegrationTest extends KafkaServerTestHarness { + + override def generateConfigs: Seq[KafkaConfig] = + TestUtils.createBrokerConfigs(1, zkConnect).map(KafkaConfig.fromProps(_, new Properties())) + + val exited = new AtomicBoolean(false) + + @BeforeEach + override def setUp(testInfo: TestInfo): Unit = { + Exit.setExitProcedure((_, _) => exited.set(true)) + super.setUp(testInfo) + } + + @AfterEach + override def tearDown(): Unit = { + super.tearDown() + try { + assertFalse(exited.get()) + } finally { + Exit.resetExitProcedure() + } + } + + @Test + def testCommitOffsetsThrowTimeoutException(): Unit = { + val consumerProps = new Properties + consumerProps.put(ConsumerConfig.GROUP_ID_CONFIG, "test-group") + consumerProps.put(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "earliest") + consumerProps.put(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, brokerList) + consumerProps.put(ConsumerConfig.DEFAULT_API_TIMEOUT_MS_CONFIG, "1") + val consumer = new KafkaConsumer(consumerProps, new ByteArrayDeserializer, new ByteArrayDeserializer) + val mirrorMakerConsumer = new ConsumerWrapper(consumer, None, includeOpt = Some("any")) + mirrorMakerConsumer.offsets.put(new TopicPartition("test", 0), 0L) + assertThrows(classOf[TimeoutException], () => mirrorMakerConsumer.commit()) + } + + @Test + def testCommitOffsetsRemoveNonExistentTopics(): Unit = { + val consumerProps = new Properties + consumerProps.put(ConsumerConfig.GROUP_ID_CONFIG, "test-group") + consumerProps.put(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "earliest") + consumerProps.put(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, brokerList) + consumerProps.put(ConsumerConfig.DEFAULT_API_TIMEOUT_MS_CONFIG, "2000") + val consumer = new KafkaConsumer(consumerProps, new ByteArrayDeserializer, new ByteArrayDeserializer) + val mirrorMakerConsumer = new ConsumerWrapper(consumer, None, includeOpt = Some("any")) + mirrorMakerConsumer.offsets.put(new TopicPartition("nonexistent-topic1", 0), 0L) + mirrorMakerConsumer.offsets.put(new TopicPartition("nonexistent-topic2", 0), 0L) + MirrorMaker.commitOffsets(mirrorMakerConsumer) + assertTrue(mirrorMakerConsumer.offsets.isEmpty, "Offsets for non-existent topics should be removed") + } + + @Test + def testCommaSeparatedRegex(): Unit = { + val topic = "new-topic" + val msg = "a test message" + val brokerList = TestUtils.getBrokerListStrFromServers(servers) + + val producerProps = new Properties + producerProps.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, brokerList) + producerProps.put(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG, classOf[ByteArraySerializer]) + 
producerProps.put(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG, classOf[ByteArraySerializer]) + val producer = new MirrorMakerProducer(true, producerProps) + MirrorMaker.producer = producer + MirrorMaker.producer.send(new ProducerRecord(topic, msg.getBytes())) + MirrorMaker.producer.close() + + val consumerProps = new Properties + consumerProps.put(ConsumerConfig.GROUP_ID_CONFIG, "test-group") + consumerProps.put(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "earliest") + consumerProps.put(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, brokerList) + val consumer = new KafkaConsumer(consumerProps, new ByteArrayDeserializer, new ByteArrayDeserializer) + + val mirrorMakerConsumer = new ConsumerWrapper(consumer, None, includeOpt = Some("another_topic,new.*,foo")) + mirrorMakerConsumer.init() + try { + TestUtils.waitUntilTrue(() => { + try { + val data = mirrorMakerConsumer.receive() + data.topic == topic && new String(data.value) == msg + } catch { + // these exceptions are thrown if no records are returned within the timeout, so safe to ignore + case _: NoRecordsException => false + } + }, "MirrorMaker consumer should read the expected message from the expected topic within the timeout") + } finally consumer.close() + } + +} diff --git a/core/src/test/scala/kafka/common/InterBrokerSendThreadTest.scala b/core/src/test/scala/kafka/common/InterBrokerSendThreadTest.scala new file mode 100644 index 0000000..f2d9fb9 --- /dev/null +++ b/core/src/test/scala/kafka/common/InterBrokerSendThreadTest.scala @@ -0,0 +1,233 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
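The includeOpt value "another_topic,new.*,foo" in testCommaSeparatedRegex above is a comma-separated allowlist in which each entry may itself be a regex, so "new.*" matches the "new-topic" topic. A hedged sketch of that interpretation (the real ConsumerWrapper include handling may differ in detail):

```scala
import java.util.regex.Pattern

// Treat a comma-separated include option as regex alternation: "a,b.*,c" becomes "a|b.*|c".
def matchesInclude(includeOpt: String, topic: String): Boolean = {
  val pattern = Pattern.compile(includeOpt.split(",").map(_.trim).mkString("|"))
  pattern.matcher(topic).matches()
}

// matchesInclude("another_topic,new.*,foo", "new-topic") == true
```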
+ */ +package kafka.common + +import kafka.utils.MockTime +import org.apache.kafka.clients.{ClientRequest, ClientResponse, NetworkClient, RequestCompletionHandler} +import org.apache.kafka.common.Node +import org.apache.kafka.common.errors.{AuthenticationException, DisconnectException} +import org.apache.kafka.common.protocol.ApiKeys +import org.apache.kafka.common.requests.AbstractRequest +import org.easymock.EasyMock +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.Test +import org.mockito.{ArgumentMatchers, Mockito} + +import java.util +import scala.collection.mutable + +class InterBrokerSendThreadTest { + private val time = new MockTime() + private val networkClient: NetworkClient = EasyMock.createMock(classOf[NetworkClient]) + private val completionHandler = new StubCompletionHandler + private val requestTimeoutMs = 1000 + + class TestInterBrokerSendThread(networkClient: NetworkClient = networkClient, + exceptionCallback: Throwable => Unit = t => throw t) + extends InterBrokerSendThread("name", networkClient, requestTimeoutMs, time) { + private val queue = mutable.Queue[RequestAndCompletionHandler]() + + def enqueue(request: RequestAndCompletionHandler): Unit = { + queue += request + } + + override def generateRequests(): Iterable[RequestAndCompletionHandler] = { + if (queue.isEmpty) { + None + } else { + Some(queue.dequeue()) + } + } + override def pollOnce(maxTimeoutMs: Long): Unit = { + try super.pollOnce(maxTimeoutMs) + catch { + case e: Throwable => exceptionCallback(e) + } + } + + } + + @Test + def shutdownThreadShouldNotCauseException(): Unit = { + val networkClient = Mockito.mock(classOf[NetworkClient]) + // InterBrokerSendThread#shutdown calls NetworkClient#initiateClose first so NetworkClient#poll + // can throw DisconnectException when thread is running + Mockito.when(networkClient.poll(ArgumentMatchers.anyLong, ArgumentMatchers.anyLong)).thenThrow(new DisconnectException()) + var exception: Throwable = null + val thread = new TestInterBrokerSendThread(networkClient, e => exception = e) + thread.shutdown() + thread.pollOnce(100) + assertNull(exception) + } + + @Test + def shouldNotSendAnythingWhenNoRequests(): Unit = { + val sendThread = new TestInterBrokerSendThread() + + // poll is always called but there should be no further invocations on NetworkClient + EasyMock.expect(networkClient.poll(EasyMock.anyLong(), EasyMock.anyLong())) + .andReturn(new util.ArrayList()) + + EasyMock.replay(networkClient) + + sendThread.doWork() + + EasyMock.verify(networkClient) + assertFalse(completionHandler.executedWithDisconnectedResponse) + } + + @Test + def shouldCreateClientRequestAndSendWhenNodeIsReady(): Unit = { + val request = new StubRequestBuilder() + val node = new Node(1, "", 8080) + val handler = RequestAndCompletionHandler(time.milliseconds(), node, request, completionHandler) + val sendThread = new TestInterBrokerSendThread() + + val clientRequest = new ClientRequest("dest", request, 0, "1", 0, true, requestTimeoutMs, handler.handler) + + EasyMock.expect(networkClient.newClientRequest( + EasyMock.eq("1"), + EasyMock.same(handler.request), + EasyMock.anyLong(), + EasyMock.eq(true), + EasyMock.eq(requestTimeoutMs), + EasyMock.same(handler.handler))) + .andReturn(clientRequest) + + EasyMock.expect(networkClient.ready(node, time.milliseconds())) + .andReturn(true) + + EasyMock.expect(networkClient.send(clientRequest, time.milliseconds())) + + EasyMock.expect(networkClient.poll(EasyMock.anyLong(), EasyMock.anyLong())) + .andReturn(new util.ArrayList()) + + 
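This test class mixes EasyMock (most tests) with Mockito (shutdownThreadShouldNotCauseException). For contrast only, the "poll returns nothing" expectation used just above could be written with Mockito roughly as follows; this is an illustration of the two mocking styles, not part of the tests:

```scala
import java.util.Collections
import org.apache.kafka.clients.{ClientResponse, NetworkClient}
import org.mockito.ArgumentMatchers.anyLong
import org.mockito.Mockito

val client = Mockito.mock(classOf[NetworkClient])
// Equivalent of EasyMock.expect(networkClient.poll(anyLong, anyLong)).andReturn(new util.ArrayList())
Mockito.when(client.poll(anyLong(), anyLong()))
  .thenReturn(Collections.emptyList[ClientResponse]())
```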
EasyMock.replay(networkClient) + + sendThread.enqueue(handler) + sendThread.doWork() + + EasyMock.verify(networkClient) + assertFalse(completionHandler.executedWithDisconnectedResponse) + } + + @Test + def shouldCallCompletionHandlerWithDisconnectedResponseWhenNodeNotReady(): Unit = { + val request = new StubRequestBuilder + val node = new Node(1, "", 8080) + val handler = RequestAndCompletionHandler(time.milliseconds(), node, request, completionHandler) + val sendThread = new TestInterBrokerSendThread() + + val clientRequest = new ClientRequest("dest", request, 0, "1", 0, true, requestTimeoutMs, handler.handler) + + EasyMock.expect(networkClient.newClientRequest( + EasyMock.eq("1"), + EasyMock.same(handler.request), + EasyMock.anyLong(), + EasyMock.eq(true), + EasyMock.eq(requestTimeoutMs), + EasyMock.same(handler.handler))) + .andReturn(clientRequest) + + EasyMock.expect(networkClient.ready(node, time.milliseconds())) + .andReturn(false) + + EasyMock.expect(networkClient.connectionDelay(EasyMock.anyObject(), EasyMock.anyLong())) + .andReturn(0) + + EasyMock.expect(networkClient.poll(EasyMock.anyLong(), EasyMock.anyLong())) + .andReturn(new util.ArrayList()) + + EasyMock.expect(networkClient.connectionFailed(node)) + .andReturn(true) + + EasyMock.expect(networkClient.authenticationException(node)) + .andReturn(new AuthenticationException("")) + + EasyMock.replay(networkClient) + + sendThread.enqueue(handler) + sendThread.doWork() + + EasyMock.verify(networkClient) + assertTrue(completionHandler.executedWithDisconnectedResponse) + } + + @Test + def testFailingExpiredRequests(): Unit = { + val request = new StubRequestBuilder() + val node = new Node(1, "", 8080) + val handler = RequestAndCompletionHandler(time.milliseconds(), node, request, completionHandler) + val sendThread = new TestInterBrokerSendThread() + + val clientRequest = new ClientRequest("dest", + request, + 0, + "1", + time.milliseconds(), + true, + requestTimeoutMs, + handler.handler) + time.sleep(1500) + + EasyMock.expect(networkClient.newClientRequest( + EasyMock.eq("1"), + EasyMock.same(handler.request), + EasyMock.eq(handler.creationTimeMs), + EasyMock.eq(true), + EasyMock.eq(requestTimeoutMs), + EasyMock.same(handler.handler))) + .andReturn(clientRequest) + + // make the node unready so the request is not cleared + EasyMock.expect(networkClient.ready(node, time.milliseconds())) + .andReturn(false) + + EasyMock.expect(networkClient.connectionDelay(EasyMock.anyObject(), EasyMock.anyLong())) + .andReturn(0) + + EasyMock.expect(networkClient.poll(EasyMock.anyLong(), EasyMock.anyLong())) + .andReturn(new util.ArrayList()) + + // rule out disconnects so the request stays for the expiry check + EasyMock.expect(networkClient.connectionFailed(node)) + .andReturn(false) + + EasyMock.replay(networkClient) + + sendThread.enqueue(handler) + sendThread.doWork() + + EasyMock.verify(networkClient) + assertFalse(sendThread.hasUnsentRequests) + assertTrue(completionHandler.executedWithDisconnectedResponse) + } + + private class StubRequestBuilder extends AbstractRequest.Builder(ApiKeys.END_TXN) { + override def build(version: Short): Nothing = ??? 
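testFailingExpiredRequests above advances the MockTime 1500 ms past the request's creation time with a 1000 ms request timeout and expects the unsent request to be failed, with the handler seeing a disconnected response. A hedged sketch of the expiry predicate this is assumed to exercise (names are illustrative):

```scala
// Assumption: an unsent request expires once (now - creationTimeMs) >= requestTimeoutMs.
def isExpired(creationTimeMs: Long, nowMs: Long, requestTimeoutMs: Int): Boolean =
  nowMs - creationTimeMs >= requestTimeoutMs

// With the values from the test: isExpired(0L, 1500L, 1000) == true
```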
+ } + + private class StubCompletionHandler extends RequestCompletionHandler { + var executedWithDisconnectedResponse = false + var response: ClientResponse = _ + override def onComplete(response: ClientResponse): Unit = { + this.executedWithDisconnectedResponse = response.wasDisconnected() + this.response = response + } + } + +} diff --git a/core/src/test/scala/kafka/metrics/LinuxIoMetricsCollectorTest.scala b/core/src/test/scala/kafka/metrics/LinuxIoMetricsCollectorTest.scala new file mode 100644 index 0000000..5c208ed --- /dev/null +++ b/core/src/test/scala/kafka/metrics/LinuxIoMetricsCollectorTest.scala @@ -0,0 +1,81 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.metrics + +import java.nio.charset.StandardCharsets +import java.nio.file.Files +import kafka.utils.{Logging, MockTime} +import org.apache.kafka.test.TestUtils +import org.junit.jupiter.api.Assertions.{assertEquals, assertFalse, assertTrue} +import org.junit.jupiter.api.{Test, Timeout} + +@Timeout(120) +class LinuxIoMetricsCollectorTest extends Logging { + + class TestDirectory() { + val baseDir = TestUtils.tempDirectory() + val selfDir = Files.createDirectories(baseDir.toPath.resolve("self")) + + def writeProcFile(readBytes: Long, writeBytes: Long) = { + val bld = new StringBuilder() + bld.append("rchar: 0%n".format()) + bld.append("wchar: 0%n".format()) + bld.append("syschr: 0%n".format()) + bld.append("syscw: 0%n".format()) + bld.append("read_bytes: %d%n".format(readBytes)) + bld.append("write_bytes: %d%n".format(writeBytes)) + bld.append("cancelled_write_bytes: 0%n".format()) + Files.write(selfDir.resolve("io"), bld.toString().getBytes(StandardCharsets.UTF_8)) + } + } + + @Test + def testReadProcFile(): Unit = { + val testDirectory = new TestDirectory() + val time = new MockTime(100, 1000) + testDirectory.writeProcFile(123L, 456L) + val collector = new LinuxIoMetricsCollector(testDirectory.baseDir.getAbsolutePath, + time, logger.underlying) + + // Test that we can read the values we wrote. + assertTrue(collector.usable()) + assertEquals(123L, collector.readBytes()) + assertEquals(456L, collector.writeBytes()) + testDirectory.writeProcFile(124L, 457L) + + // The previous values should still be cached. + assertEquals(123L, collector.readBytes()) + assertEquals(456L, collector.writeBytes()) + + // Update the time, and the values should be re-read. 
+ time.sleep(1) + assertEquals(124L, collector.readBytes()) + assertEquals(457L, collector.writeBytes()) + } + + @Test + def testUnableToReadNonexistentProcFile(): Unit = { + val testDirectory = new TestDirectory() + val time = new MockTime(100, 1000) + val collector = new LinuxIoMetricsCollector(testDirectory.baseDir.getAbsolutePath, + time, logger.underlying) + + // Test that we can't read the file, since it hasn't been written. + assertFalse(collector.usable()) + } +} diff --git a/core/src/test/scala/kafka/raft/KafkaMetadataLogTest.scala b/core/src/test/scala/kafka/raft/KafkaMetadataLogTest.scala new file mode 100644 index 0000000..94bf453 --- /dev/null +++ b/core/src/test/scala/kafka/raft/KafkaMetadataLogTest.scala @@ -0,0 +1,962 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package kafka.raft + +import kafka.log.{Defaults, UnifiedLog, SegmentDeletion} +import kafka.server.KafkaConfig.{MetadataLogSegmentBytesProp, MetadataLogSegmentMillisProp, MetadataLogSegmentMinBytesProp, NodeIdProp, ProcessRolesProp, QuorumVotersProp} +import kafka.server.{KafkaConfig, KafkaRaftServer} +import kafka.utils.{MockTime, TestUtils} +import org.apache.kafka.common.errors.{InvalidConfigurationException, RecordTooLargeException} +import org.apache.kafka.common.protocol +import org.apache.kafka.common.protocol.{ObjectSerializationCache, Writable} +import org.apache.kafka.common.record.{CompressionType, MemoryRecords, SimpleRecord} +import org.apache.kafka.common.utils.Utils +import org.apache.kafka.raft.internals.BatchBuilder +import org.apache.kafka.raft._ +import org.apache.kafka.server.common.serialization.RecordSerde +import org.apache.kafka.snapshot.{RawSnapshotReader, RawSnapshotWriter, SnapshotPath, Snapshots} +import org.apache.kafka.test.TestUtils.assertOptional +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.{AfterEach, BeforeEach, Test} + +import java.io.File +import java.nio.ByteBuffer +import java.nio.file.{Files, Path} +import java.util +import java.util.{Collections, Optional, Properties} + +final class KafkaMetadataLogTest { + import KafkaMetadataLogTest._ + + var tempDir: File = _ + val mockTime = new MockTime() + + @BeforeEach + def setUp(): Unit = { + tempDir = TestUtils.tempDir() + } + + @AfterEach + def tearDown(): Unit = { + Utils.delete(tempDir) + } + + @Test + def testConfig(): Unit = { + val props = new Properties() + props.put(ProcessRolesProp, util.Arrays.asList("broker")) + props.put(QuorumVotersProp, "1@localhost:9093") + props.put(NodeIdProp, Int.box(2)) + props.put(KafkaConfig.ControllerListenerNamesProp, "SSL") + props.put(MetadataLogSegmentBytesProp, Int.box(10240)) + props.put(MetadataLogSegmentMillisProp, Int.box(10 * 1024)) + assertThrows(classOf[InvalidConfigurationException], () => { + val 
kafkaConfig = KafkaConfig.fromProps(props) + val metadataConfig = MetadataLogConfig.apply(kafkaConfig, KafkaRaftClient.MAX_BATCH_SIZE_BYTES, KafkaRaftClient.MAX_FETCH_SIZE_BYTES) + buildMetadataLog(tempDir, mockTime, metadataConfig) + }) + + props.put(MetadataLogSegmentMinBytesProp, Int.box(10240)) + val kafkaConfig = KafkaConfig.fromProps(props) + val metadataConfig = MetadataLogConfig.apply(kafkaConfig, KafkaRaftClient.MAX_BATCH_SIZE_BYTES, KafkaRaftClient.MAX_FETCH_SIZE_BYTES) + buildMetadataLog(tempDir, mockTime, metadataConfig) + } + + @Test + def testUnexpectedAppendOffset(): Unit = { + val log = buildMetadataLog(tempDir, mockTime) + + val recordFoo = new SimpleRecord("foo".getBytes()) + val currentEpoch = 3 + val initialOffset = log.endOffset().offset + + log.appendAsLeader( + MemoryRecords.withRecords(initialOffset, CompressionType.NONE, currentEpoch, recordFoo), + currentEpoch + ) + + // Throw exception for out of order records + assertThrows( + classOf[RuntimeException], + () => { + log.appendAsLeader( + MemoryRecords.withRecords(initialOffset, CompressionType.NONE, currentEpoch, recordFoo), + currentEpoch + ) + } + ) + + assertThrows( + classOf[RuntimeException], + () => { + log.appendAsFollower( + MemoryRecords.withRecords(initialOffset, CompressionType.NONE, currentEpoch, recordFoo) + ) + } + ) + } + + @Test + def testCreateSnapshot(): Unit = { + val numberOfRecords = 10 + val epoch = 1 + val snapshotId = new OffsetAndEpoch(numberOfRecords, epoch) + val log = buildMetadataLog(tempDir, mockTime) + + append(log, numberOfRecords, epoch) + log.updateHighWatermark(new LogOffsetMetadata(numberOfRecords)) + + TestUtils.resource(log.createNewSnapshot(snapshotId).get()) { snapshot => + snapshot.freeze() + } + + assertEquals(0, log.readSnapshot(snapshotId).get().sizeInBytes()) + } + + @Test + def testCreateSnapshotFromEndOffset(): Unit = { + val numberOfRecords = 10 + val firstEpoch = 1 + val secondEpoch = 3 + val log = buildMetadataLog(tempDir, mockTime) + + append(log, numberOfRecords, firstEpoch) + append(log, numberOfRecords, secondEpoch) + log.updateHighWatermark(new LogOffsetMetadata(2 * numberOfRecords)) + + // Test finding the first epoch + log.createNewSnapshot(new OffsetAndEpoch(numberOfRecords, firstEpoch)).get().close() + log.createNewSnapshot(new OffsetAndEpoch(numberOfRecords - 1, firstEpoch)).get().close() + log.createNewSnapshot(new OffsetAndEpoch(1, firstEpoch)).get().close() + + // Test finding the second epoch + log.createNewSnapshot(new OffsetAndEpoch(2 * numberOfRecords, secondEpoch)).get().close() + log.createNewSnapshot(new OffsetAndEpoch(2 * numberOfRecords - 1, secondEpoch)).get().close() + log.createNewSnapshot(new OffsetAndEpoch(numberOfRecords + 1, secondEpoch)).get().close() + } + + @Test + def testCreateSnapshotLaterThanHighWatermark(): Unit = { + val numberOfRecords = 10 + val epoch = 1 + val log = buildMetadataLog(tempDir, mockTime) + + append(log, numberOfRecords, epoch) + log.updateHighWatermark(new LogOffsetMetadata(numberOfRecords)) + + assertThrows( + classOf[IllegalArgumentException], + () => log.createNewSnapshot(new OffsetAndEpoch(numberOfRecords + 1, epoch)) + ) + } + + @Test + def testCreateSnapshotMuchLaterEpoch(): Unit = { + val numberOfRecords = 10 + val epoch = 1 + val log = buildMetadataLog(tempDir, mockTime) + + append(log, numberOfRecords, epoch) + log.updateHighWatermark(new LogOffsetMetadata(numberOfRecords)) + + assertThrows( + classOf[IllegalArgumentException], + () => log.createNewSnapshot(new OffsetAndEpoch(numberOfRecords, epoch + 
1)) + ) + } + + @Test + def testCreateSnapshotBeforeLogStartOffset(): Unit = { + val numberOfRecords = 10 + val epoch = 1 + val snapshotId = new OffsetAndEpoch(numberOfRecords-4, epoch) + val log = buildMetadataLog(tempDir, mockTime) + + append(log, numberOfRecords, epoch) + log.updateHighWatermark(new LogOffsetMetadata(numberOfRecords)) + TestUtils.resource(log.createNewSnapshot(snapshotId).get()) { snapshot => + snapshot.freeze() + } + + // Simulate log cleanup that advances the LSO + log.log.maybeIncrementLogStartOffset(snapshotId.offset - 1, SegmentDeletion) + + assertEquals(Optional.empty(), log.createNewSnapshot(new OffsetAndEpoch(snapshotId.offset - 2, snapshotId.epoch))) + } + + @Test + def testCreateSnapshotDivergingEpoch(): Unit = { + val numberOfRecords = 10 + val epoch = 2 + val snapshotId = new OffsetAndEpoch(numberOfRecords, epoch) + val log = buildMetadataLog(tempDir, mockTime) + + append(log, numberOfRecords, epoch) + log.updateHighWatermark(new LogOffsetMetadata(numberOfRecords)) + + assertThrows( + classOf[IllegalArgumentException], + () => log.createNewSnapshot(new OffsetAndEpoch(snapshotId.offset, snapshotId.epoch - 1)) + ) + } + + @Test + def testCreateSnapshotOlderEpoch(): Unit = { + val numberOfRecords = 10 + val epoch = 2 + val snapshotId = new OffsetAndEpoch(numberOfRecords, epoch) + val log = buildMetadataLog(tempDir, mockTime) + + append(log, numberOfRecords, epoch) + log.updateHighWatermark(new LogOffsetMetadata(numberOfRecords)) + + TestUtils.resource(log.createNewSnapshot(snapshotId).get()) { snapshot => + snapshot.freeze() + } + + assertThrows( + classOf[IllegalArgumentException], + () => log.createNewSnapshot(new OffsetAndEpoch(snapshotId.offset, snapshotId.epoch - 1)) + ) + } + + @Test + def testCreateSnapshotWithMissingEpoch(): Unit = { + val firstBatchRecords = 5 + val firstEpoch = 1 + val missingEpoch = firstEpoch + 1 + val secondBatchRecords = 5 + val secondEpoch = missingEpoch + 1 + + val numberOfRecords = firstBatchRecords + secondBatchRecords + val log = buildMetadataLog(tempDir, mockTime) + + append(log, firstBatchRecords, firstEpoch) + append(log, secondBatchRecords, secondEpoch) + log.updateHighWatermark(new LogOffsetMetadata(numberOfRecords)) + + assertThrows( + classOf[IllegalArgumentException], + () => log.createNewSnapshot(new OffsetAndEpoch(1, missingEpoch)) + ) + assertThrows( + classOf[IllegalArgumentException], + () => log.createNewSnapshot(new OffsetAndEpoch(firstBatchRecords, missingEpoch)) + ) + assertThrows( + classOf[IllegalArgumentException], + () => log.createNewSnapshot(new OffsetAndEpoch(secondBatchRecords, missingEpoch)) + ) + } + + @Test + def testCreateExistingSnapshot(): Unit = { + val numberOfRecords = 10 + val epoch = 1 + val snapshotId = new OffsetAndEpoch(numberOfRecords - 1, epoch) + val log = buildMetadataLog(tempDir, mockTime) + + append(log, numberOfRecords, epoch) + log.updateHighWatermark(new LogOffsetMetadata(numberOfRecords)) + + TestUtils.resource(log.createNewSnapshot(snapshotId).get()) { snapshot => + snapshot.freeze() + } + + assertEquals(Optional.empty(), log.createNewSnapshot(snapshotId), + "Creating an existing snapshot should not do anything") + } + + @Test + def testTopicId(): Unit = { + val log = buildMetadataLog(tempDir, mockTime) + + assertEquals(KafkaRaftServer.MetadataTopicId, log.topicId()) + } + + @Test + def testReadMissingSnapshot(): Unit = { + val log = buildMetadataLog(tempDir, mockTime) + + assertEquals(Optional.empty(), log.readSnapshot(new OffsetAndEpoch(10, 0))) + } + + @Test + def 
testDeleteNonExistentSnapshot(): Unit = { + val log = buildMetadataLog(tempDir, mockTime) + val offset = 10 + val epoch = 0 + + append(log, offset, epoch) + log.updateHighWatermark(new LogOffsetMetadata(offset)) + + assertFalse(log.deleteBeforeSnapshot(new OffsetAndEpoch(2L, epoch))) + assertEquals(0, log.startOffset) + assertEquals(epoch, log.lastFetchedEpoch) + assertEquals(offset, log.endOffset().offset) + assertEquals(offset, log.highWatermark.offset) + } + + @Test + def testTruncateFullyToLatestSnapshot(): Unit = { + val log = buildMetadataLog(tempDir, mockTime) + val numberOfRecords = 10 + val epoch = 0 + val sameEpochSnapshotId = new OffsetAndEpoch(2 * numberOfRecords, epoch) + + append(log, numberOfRecords, epoch) + + TestUtils.resource(log.storeSnapshot(sameEpochSnapshotId).get()) { snapshot => + snapshot.freeze() + } + + assertTrue(log.truncateToLatestSnapshot()) + assertEquals(sameEpochSnapshotId.offset, log.startOffset) + assertEquals(sameEpochSnapshotId.epoch, log.lastFetchedEpoch) + assertEquals(sameEpochSnapshotId.offset, log.endOffset().offset) + assertEquals(sameEpochSnapshotId.offset, log.highWatermark.offset) + + val greaterEpochSnapshotId = new OffsetAndEpoch(3 * numberOfRecords, epoch + 1) + + append(log, numberOfRecords, epoch) + + TestUtils.resource(log.storeSnapshot(greaterEpochSnapshotId).get()) { snapshot => + snapshot.freeze() + } + + assertTrue(log.truncateToLatestSnapshot()) + assertEquals(greaterEpochSnapshotId.offset, log.startOffset) + assertEquals(greaterEpochSnapshotId.epoch, log.lastFetchedEpoch) + assertEquals(greaterEpochSnapshotId.offset, log.endOffset().offset) + assertEquals(greaterEpochSnapshotId.offset, log.highWatermark.offset) + } + + @Test + def testTruncateWillRemoveOlderSnapshot(): Unit = { + val (logDir, log, config) = buildMetadataLogAndDir(tempDir, mockTime) + val numberOfRecords = 10 + val epoch = 1 + + append(log, 1, epoch - 1) + val oldSnapshotId1 = new OffsetAndEpoch(1, epoch - 1) + TestUtils.resource(log.storeSnapshot(oldSnapshotId1).get()) { snapshot => + snapshot.freeze() + } + + append(log, 1, epoch) + val oldSnapshotId2 = new OffsetAndEpoch(2, epoch) + TestUtils.resource(log.storeSnapshot(oldSnapshotId2).get()) { snapshot => + snapshot.freeze() + } + + append(log, numberOfRecords - 2, epoch) + val oldSnapshotId3 = new OffsetAndEpoch(numberOfRecords, epoch) + TestUtils.resource(log.storeSnapshot(oldSnapshotId3).get()) { snapshot => + snapshot.freeze() + } + + val greaterSnapshotId = new OffsetAndEpoch(3 * numberOfRecords, epoch) + append(log, numberOfRecords, epoch) + TestUtils.resource(log.storeSnapshot(greaterSnapshotId).get()) { snapshot => + snapshot.freeze() + } + + assertNotEquals(log.earliestSnapshotId(), log.latestSnapshotId()) + assertTrue(log.truncateToLatestSnapshot()) + assertEquals(log.earliestSnapshotId(), log.latestSnapshotId()) + log.close() + + mockTime.sleep(config.fileDeleteDelayMs) + // Assert that the log dir doesn't contain any older snapshots + Files + .walk(logDir, 1) + .map[Optional[SnapshotPath]](Snapshots.parse) + .filter(_.isPresent) + .forEach { path => + assertFalse(path.get.snapshotId.offset < log.startOffset) + } + } + + @Test + def testDoesntTruncateFully(): Unit = { + val log = buildMetadataLog(tempDir, mockTime) + val numberOfRecords = 10 + val epoch = 1 + + append(log, numberOfRecords, epoch) + + val olderEpochSnapshotId = new OffsetAndEpoch(numberOfRecords, epoch - 1) + TestUtils.resource(log.storeSnapshot(olderEpochSnapshotId).get()) { snapshot => + snapshot.freeze() + } + + 
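testTruncateFullyToLatestSnapshot above and testDoesntTruncateFully (continuing just below) together pin down when truncateToLatestSnapshot is expected to replace the whole log: only when the latest snapshot id is strictly ahead of the log, comparing epoch first and then offset. A hedged sketch of that comparison (illustrative names; the real check lives in KafkaMetadataLog):

```scala
final case class SnapshotId(offset: Long, epoch: Int)

// Assumption: truncate fully only when the latest snapshot is ahead of the log end.
def shouldTruncateFully(latestSnapshot: SnapshotId, logEndOffset: Long, logLastEpoch: Int): Boolean =
  latestSnapshot.epoch > logLastEpoch ||
    (latestSnapshot.epoch == logLastEpoch && latestSnapshot.offset > logEndOffset)

// shouldTruncateFully(SnapshotId(20, 0), logEndOffset = 10, logLastEpoch = 0) == true  (truncates)
// shouldTruncateFully(SnapshotId(10, 0), logEndOffset = 10, logLastEpoch = 1) == false (kept as-is)
```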
assertFalse(log.truncateToLatestSnapshot()) + + append(log, numberOfRecords, epoch) + + val olderOffsetSnapshotId = new OffsetAndEpoch(numberOfRecords, epoch) + TestUtils.resource(log.storeSnapshot(olderOffsetSnapshotId).get()) { snapshot => + snapshot.freeze() + } + + assertFalse(log.truncateToLatestSnapshot()) + } + + @Test + def testCleanupPartialSnapshots(): Unit = { + val (logDir, log, config) = buildMetadataLogAndDir(tempDir, mockTime) + val numberOfRecords = 10 + val epoch = 1 + val snapshotId = new OffsetAndEpoch(1, epoch) + + append(log, numberOfRecords, epoch) + TestUtils.resource(log.storeSnapshot(snapshotId).get()) { snapshot => + snapshot.freeze() + } + + log.close() + + // Create a few partial snapshots + Snapshots.createTempFile(logDir, new OffsetAndEpoch(0, epoch - 1)) + Snapshots.createTempFile(logDir, new OffsetAndEpoch(1, epoch)) + Snapshots.createTempFile(logDir, new OffsetAndEpoch(2, epoch + 1)) + + val secondLog = buildMetadataLog(tempDir, mockTime) + + assertEquals(snapshotId, secondLog.latestSnapshotId().get) + assertEquals(0, log.startOffset) + assertEquals(epoch, log.lastFetchedEpoch) + assertEquals(numberOfRecords, log.endOffset().offset) + assertEquals(0, secondLog.highWatermark.offset) + + // Assert that the log dir doesn't contain any partial snapshots + Files + .walk(logDir, 1) + .map[Optional[SnapshotPath]](Snapshots.parse) + .filter(_.isPresent) + .forEach { path => + assertFalse(path.get.partial) + } + } + + @Test + def testCleanupOlderSnapshots(): Unit = { + val (logDir, log, config) = buildMetadataLogAndDir(tempDir, mockTime) + val numberOfRecords = 10 + val epoch = 1 + + append(log, 1, epoch - 1) + val oldSnapshotId1 = new OffsetAndEpoch(1, epoch - 1) + TestUtils.resource(log.storeSnapshot(oldSnapshotId1).get()) { snapshot => + snapshot.freeze() + } + + append(log, 1, epoch) + val oldSnapshotId2 = new OffsetAndEpoch(2, epoch) + TestUtils.resource(log.storeSnapshot(oldSnapshotId2).get()) { snapshot => + snapshot.freeze() + } + + append(log, numberOfRecords - 2, epoch) + val oldSnapshotId3 = new OffsetAndEpoch(numberOfRecords, epoch) + TestUtils.resource(log.storeSnapshot(oldSnapshotId3).get()) { snapshot => + snapshot.freeze() + } + + val greaterSnapshotId = new OffsetAndEpoch(3 * numberOfRecords, epoch) + append(log, numberOfRecords, epoch) + TestUtils.resource(log.storeSnapshot(greaterSnapshotId).get()) { snapshot => + snapshot.freeze() + } + + log.close() + + val secondLog = buildMetadataLog(tempDir, mockTime) + + assertEquals(greaterSnapshotId, secondLog.latestSnapshotId().get) + assertEquals(3 * numberOfRecords, secondLog.startOffset) + assertEquals(epoch, secondLog.lastFetchedEpoch) + mockTime.sleep(config.fileDeleteDelayMs) + + // Assert that the log dir doesn't contain any older snapshots + Files + .walk(logDir, 1) + .map[Optional[SnapshotPath]](Snapshots.parse) + .filter(_.isPresent) + .forEach { path => + assertFalse(path.get.snapshotId.offset < log.startOffset) + } + } + + @Test + def testCreateReplicatedLogTruncatesFully(): Unit = { + val log = buildMetadataLog(tempDir, mockTime) + val numberOfRecords = 10 + val epoch = 1 + val snapshotId = new OffsetAndEpoch(numberOfRecords + 1, epoch + 1) + + append(log, numberOfRecords, epoch) + TestUtils.resource(log.storeSnapshot(snapshotId).get()) { snapshot => + snapshot.freeze() + } + + log.close() + + val secondLog = buildMetadataLog(tempDir, mockTime) + + assertEquals(snapshotId, secondLog.latestSnapshotId().get) + assertEquals(snapshotId.offset, secondLog.startOffset) + 
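The snapshot-cleaning tests further down in this file (testAdvanceLogStartOffsetAfterCleaning, testDeleteSnapshots, testSoftRetentionLimit) all rely on the same rule: maybeClean may drop older snapshots and advance the log start offset to the newest retained snapshot, and at least one snapshot is always kept, even if that leaves the log above retentionMaxBytes. A hedged sketch of how that plays out in these tests, using the offsets from testDeleteSnapshots:

```scala
// Assumption: in these tests only the newest snapshot survives cleaning,
// and the log start offset moves up to it.
def retainAfterClean(snapshotOffsets: Seq[Long]): (Seq[Long], Long) = {
  val newest = snapshotOffsets.max
  (Seq(newest), newest) // (snapshots kept, new log start offset)
}

// retainAfterClean(Seq(100L, 200L, 300L, 400L, 500L, 600L)) == (Seq(600L), 600L),
// matching the single remaining snapshot and startOffset of 600 asserted in testDeleteSnapshots.
```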
assertEquals(snapshotId.epoch, secondLog.lastFetchedEpoch) + assertEquals(snapshotId.offset, secondLog.endOffset().offset) + assertEquals(snapshotId.offset, secondLog.highWatermark.offset) + } + + @Test + def testMaxBatchSize(): Unit = { + val leaderEpoch = 5 + val maxBatchSizeInBytes = 16384 + val recordSize = 64 + val log = buildMetadataLog(tempDir, mockTime, DefaultMetadataLogConfig.copy(maxBatchSizeInBytes = maxBatchSizeInBytes)) + + val oversizeBatch = buildFullBatch(leaderEpoch, recordSize, maxBatchSizeInBytes + recordSize) + assertThrows(classOf[RecordTooLargeException], () => { + log.appendAsLeader(oversizeBatch, leaderEpoch) + }) + + val undersizeBatch = buildFullBatch(leaderEpoch, recordSize, maxBatchSizeInBytes) + val appendInfo = log.appendAsLeader(undersizeBatch, leaderEpoch) + assertEquals(0L, appendInfo.firstOffset) + } + + @Test + def testTruncateBelowHighWatermark(): Unit = { + val log = buildMetadataLog(tempDir, mockTime) + val numRecords = 10 + val epoch = 5 + + append(log, numRecords, epoch) + assertEquals(numRecords.toLong, log.endOffset.offset) + + log.updateHighWatermark(new LogOffsetMetadata(numRecords)) + assertEquals(numRecords.toLong, log.highWatermark.offset) + + assertThrows(classOf[IllegalArgumentException], () => log.truncateTo(5L)) + assertEquals(numRecords.toLong, log.highWatermark.offset) + } + + private def buildFullBatch( + leaderEpoch: Int, + recordSize: Int, + maxBatchSizeInBytes: Int + ): MemoryRecords = { + val buffer = ByteBuffer.allocate(maxBatchSizeInBytes) + val batchBuilder = new BatchBuilder[Array[Byte]]( + buffer, + new ByteArraySerde, + CompressionType.NONE, + 0L, + mockTime.milliseconds(), + false, + leaderEpoch, + maxBatchSizeInBytes + ) + + val serializationCache = new ObjectSerializationCache + val records = Collections.singletonList(new Array[Byte](recordSize)) + while (!batchBuilder.bytesNeeded(records, serializationCache).isPresent) { + batchBuilder.appendRecord(records.get(0), serializationCache) + } + + batchBuilder.build() + } + + @Test + def testValidateEpochGreaterThanLastKnownEpoch(): Unit = { + val log = buildMetadataLog(tempDir, mockTime) + + val numberOfRecords = 1 + val epoch = 1 + + append(log, numberOfRecords, epoch) + + val resultOffsetAndEpoch = log.validateOffsetAndEpoch(numberOfRecords, epoch + 1) + assertEquals(ValidOffsetAndEpoch.Kind.DIVERGING, resultOffsetAndEpoch.kind) + assertEquals(new OffsetAndEpoch(log.endOffset.offset, epoch), resultOffsetAndEpoch.offsetAndEpoch()) + } + + @Test + def testValidateEpochLessThanOldestSnapshotEpoch(): Unit = { + val log = buildMetadataLog(tempDir, mockTime) + + val numberOfRecords = 10 + val epoch = 1 + + append(log, numberOfRecords, epoch) + log.updateHighWatermark(new LogOffsetMetadata(numberOfRecords)) + + val snapshotId = new OffsetAndEpoch(numberOfRecords, epoch) + TestUtils.resource(log.createNewSnapshot(snapshotId).get()) { snapshot => + snapshot.freeze() + } + + val resultOffsetAndEpoch = log.validateOffsetAndEpoch(numberOfRecords, epoch - 1) + assertEquals(ValidOffsetAndEpoch.Kind.SNAPSHOT, resultOffsetAndEpoch.kind) + assertEquals(snapshotId, resultOffsetAndEpoch.offsetAndEpoch()) + } + + @Test + def testValidateOffsetLessThanOldestSnapshotOffset(): Unit = { + val log = buildMetadataLog(tempDir, mockTime) + + val offset = 2 + val epoch = 1 + + append(log, offset, epoch) + log.updateHighWatermark(new LogOffsetMetadata(offset)) + + val snapshotId = new OffsetAndEpoch(offset, epoch) + TestUtils.resource(log.createNewSnapshot(snapshotId).get()) { snapshot => + 
snapshot.freeze() + } + // Simulate log cleaning advancing the LSO + log.log.maybeIncrementLogStartOffset(offset, SegmentDeletion); + + val resultOffsetAndEpoch = log.validateOffsetAndEpoch(offset - 1, epoch) + assertEquals(ValidOffsetAndEpoch.Kind.SNAPSHOT, resultOffsetAndEpoch.kind) + assertEquals(snapshotId, resultOffsetAndEpoch.offsetAndEpoch()) + } + + @Test + def testValidateOffsetEqualToOldestSnapshotOffset(): Unit = { + val log = buildMetadataLog(tempDir, mockTime) + + val offset = 2 + val epoch = 1 + + append(log, offset, epoch) + log.updateHighWatermark(new LogOffsetMetadata(offset)) + + val snapshotId = new OffsetAndEpoch(offset, epoch) + TestUtils.resource(log.createNewSnapshot(snapshotId).get()) { snapshot => + snapshot.freeze() + } + + val resultOffsetAndEpoch = log.validateOffsetAndEpoch(offset, epoch) + assertEquals(ValidOffsetAndEpoch.Kind.VALID, resultOffsetAndEpoch.kind) + assertEquals(snapshotId, resultOffsetAndEpoch.offsetAndEpoch()) + } + + @Test + def testValidateUnknownEpochLessThanLastKnownGreaterThanOldestSnapshot(): Unit = { + val offset = 10 + val numOfRecords = 5 + + val log = buildMetadataLog(tempDir, mockTime) + log.updateHighWatermark(new LogOffsetMetadata(offset)) + val snapshotId = new OffsetAndEpoch(offset, 1) + TestUtils.resource(log.storeSnapshot(snapshotId).get()) { snapshot => + snapshot.freeze() + } + log.truncateToLatestSnapshot() + + + append(log, numOfRecords, epoch = 1) + append(log, numOfRecords, epoch = 2) + append(log, numOfRecords, epoch = 4) + + // offset is not equal to oldest snapshot's offset + val resultOffsetAndEpoch = log.validateOffsetAndEpoch(100, 3) + assertEquals(ValidOffsetAndEpoch.Kind.DIVERGING, resultOffsetAndEpoch.kind) + assertEquals(new OffsetAndEpoch(20, 2), resultOffsetAndEpoch.offsetAndEpoch()) + } + + @Test + def testValidateEpochLessThanFirstEpochInLog(): Unit = { + val offset = 10 + val numOfRecords = 5 + + val log = buildMetadataLog(tempDir, mockTime) + log.updateHighWatermark(new LogOffsetMetadata(offset)) + val snapshotId = new OffsetAndEpoch(offset, 1) + TestUtils.resource(log.storeSnapshot(snapshotId).get()) { snapshot => + snapshot.freeze() + } + log.truncateToLatestSnapshot() + + append(log, numOfRecords, epoch = 3) + + // offset is not equal to oldest snapshot's offset + val resultOffsetAndEpoch = log.validateOffsetAndEpoch(100, 2) + assertEquals(ValidOffsetAndEpoch.Kind.DIVERGING, resultOffsetAndEpoch.kind) + assertEquals(snapshotId, resultOffsetAndEpoch.offsetAndEpoch()) + } + + @Test + def testValidateOffsetGreatThanEndOffset(): Unit = { + val log = buildMetadataLog(tempDir, mockTime) + + val numberOfRecords = 1 + val epoch = 1 + + append(log, numberOfRecords, epoch) + + val resultOffsetAndEpoch = log.validateOffsetAndEpoch(numberOfRecords + 1, epoch) + assertEquals(ValidOffsetAndEpoch.Kind.DIVERGING, resultOffsetAndEpoch.kind) + assertEquals(new OffsetAndEpoch(log.endOffset.offset, epoch), resultOffsetAndEpoch.offsetAndEpoch()) + } + + @Test + def testValidateOffsetLessThanLEO(): Unit = { + val log = buildMetadataLog(tempDir, mockTime) + + val numberOfRecords = 10 + val epoch = 1 + + append(log, numberOfRecords, epoch) + append(log, numberOfRecords, epoch + 1) + + val resultOffsetAndEpoch = log.validateOffsetAndEpoch(11, epoch) + assertEquals(ValidOffsetAndEpoch.Kind.DIVERGING, resultOffsetAndEpoch.kind) + assertEquals(new OffsetAndEpoch(10, epoch), resultOffsetAndEpoch.offsetAndEpoch()) + } + + @Test + def testValidateValidEpochAndOffset(): Unit = { + val log = buildMetadataLog(tempDir, mockTime) + + val 
numberOfRecords = 5 + val epoch = 1 + + append(log, numberOfRecords, epoch) + + val resultOffsetAndEpoch = log.validateOffsetAndEpoch(numberOfRecords - 1, epoch) + assertEquals(ValidOffsetAndEpoch.Kind.VALID, resultOffsetAndEpoch.kind) + assertEquals(new OffsetAndEpoch(numberOfRecords - 1, epoch), resultOffsetAndEpoch.offsetAndEpoch()) + } + + @Test + def testAdvanceLogStartOffsetAfterCleaning(): Unit = { + val config = MetadataLogConfig( + logSegmentBytes = 512, + logSegmentMinBytes = 512, + logSegmentMillis = 10 * 1000, + retentionMaxBytes = 256, + retentionMillis = 60 * 1000, + maxBatchSizeInBytes = 512, + maxFetchSizeInBytes = DefaultMetadataLogConfig.maxFetchSizeInBytes, + fileDeleteDelayMs = Defaults.FileDeleteDelayMs, + nodeId = 1 + ) + config.copy() + val log = buildMetadataLog(tempDir, mockTime, config) + + // Generate some segments + for(_ <- 0 to 100) { + append(log, 47, 1) // An odd number of records to avoid offset alignment + } + assertFalse(log.maybeClean(), "Should not clean since HW was still 0") + + log.updateHighWatermark(new LogOffsetMetadata(4000)) + assertFalse(log.maybeClean(), "Should not clean since no snapshots exist") + + val snapshotId1 = new OffsetAndEpoch(1000, 1) + TestUtils.resource(log.storeSnapshot(snapshotId1).get()) { snapshot => + append(snapshot, 100) + snapshot.freeze() + } + + val snapshotId2 = new OffsetAndEpoch(2000, 1) + TestUtils.resource(log.storeSnapshot(snapshotId2).get()) { snapshot => + append(snapshot, 100) + snapshot.freeze() + } + + val lsoBefore = log.startOffset() + assertTrue(log.maybeClean(), "Expected to clean since there was at least one snapshot") + val lsoAfter = log.startOffset() + assertTrue(lsoAfter > lsoBefore, "Log Start Offset should have increased after cleaning") + assertTrue(lsoAfter == snapshotId2.offset, "Expected the Log Start Offset to be less than or equal to the snapshot offset") + } + + @Test + def testDeleteSnapshots(): Unit = { + // Generate some logs and a few snapshots, set retention low and verify that cleaning occurs + val config = DefaultMetadataLogConfig.copy( + logSegmentBytes = 1024, + logSegmentMinBytes = 1024, + logSegmentMillis = 10 * 1000, + retentionMaxBytes = 1024, + retentionMillis = 60 * 1000, + maxBatchSizeInBytes = 100 + ) + val log = buildMetadataLog(tempDir, mockTime, config) + + for(_ <- 0 to 1000) { + append(log, 1, 1) + } + log.updateHighWatermark(new LogOffsetMetadata(1001)) + + for(offset <- Seq(100, 200, 300, 400, 500, 600)) { + val snapshotId = new OffsetAndEpoch(offset, 1) + TestUtils.resource(log.storeSnapshot(snapshotId).get()) { snapshot => + append(snapshot, 10) + snapshot.freeze() + } + } + + assertEquals(6, log.snapshotCount()) + assertTrue(log.maybeClean()) + assertEquals(1, log.snapshotCount(), "Expected only one snapshot after cleaning") + assertOptional(log.latestSnapshotId(), (snapshotId: OffsetAndEpoch) => { + assertEquals(600, snapshotId.offset) + }) + assertEquals(log.startOffset, 600) + } + + @Test + def testSoftRetentionLimit(): Unit = { + // Set retention equal to the segment size and generate slightly more than one segment of logs + val config = DefaultMetadataLogConfig.copy( + logSegmentBytes = 10240, + logSegmentMinBytes = 10240, + logSegmentMillis = 10 * 1000, + retentionMaxBytes = 10240, + retentionMillis = 60 * 1000, + maxBatchSizeInBytes = 100 + ) + val log = buildMetadataLog(tempDir, mockTime, config) + + for(_ <- 0 to 2000) { + append(log, 1, 1) + } + log.updateHighWatermark(new LogOffsetMetadata(2000)) + + // Then generate two snapshots + val snapshotId1 = 
new OffsetAndEpoch(1000, 1) + TestUtils.resource(log.storeSnapshot(snapshotId1).get()) { snapshot => + append(snapshot, 500) + snapshot.freeze() + } + + // Then generate a snapshot + val snapshotId2 = new OffsetAndEpoch(2000, 1) + TestUtils.resource(log.storeSnapshot(snapshotId2).get()) { snapshot => + append(snapshot, 500) + snapshot.freeze() + } + + // Cleaning should occur, but resulting size will not be under retention limit since we have to keep one snapshot + assertTrue(log.maybeClean()) + assertEquals(1, log.snapshotCount(), "Expected one snapshot after cleaning") + assertOptional(log.latestSnapshotId(), (snapshotId: OffsetAndEpoch) => { + assertEquals(2000, snapshotId.offset, "Unexpected offset for latest snapshot") + assertOptional(log.readSnapshot(snapshotId), (reader: RawSnapshotReader) => { + assertTrue(reader.sizeInBytes() + log.log.size > config.retentionMaxBytes) + }) + }) + } +} + +object KafkaMetadataLogTest { + class ByteArraySerde extends RecordSerde[Array[Byte]] { + override def recordSize(data: Array[Byte], serializationCache: ObjectSerializationCache): Int = { + data.length + } + override def write(data: Array[Byte], serializationCache: ObjectSerializationCache, out: Writable): Unit = { + out.writeByteArray(data) + } + override def read(input: protocol.Readable, size: Int): Array[Byte] = { + val array = new Array[Byte](size) + input.readArray(array) + array + } + } + + val DefaultMetadataLogConfig = MetadataLogConfig( + logSegmentBytes = 100 * 1024, + logSegmentMinBytes = 100 * 1024, + logSegmentMillis = 10 * 1000, + retentionMaxBytes = 100 * 1024, + retentionMillis = 60 * 1000, + maxBatchSizeInBytes = KafkaRaftClient.MAX_BATCH_SIZE_BYTES, + maxFetchSizeInBytes = KafkaRaftClient.MAX_FETCH_SIZE_BYTES, + fileDeleteDelayMs = Defaults.FileDeleteDelayMs, + nodeId = 1 + ) + + def buildMetadataLogAndDir( + tempDir: File, + time: MockTime, + metadataLogConfig: MetadataLogConfig = DefaultMetadataLogConfig + ): (Path, KafkaMetadataLog, MetadataLogConfig) = { + + val logDir = createLogDirectory( + tempDir, + UnifiedLog.logDirName(KafkaRaftServer.MetadataPartition) + ) + + val metadataLog = KafkaMetadataLog( + KafkaRaftServer.MetadataPartition, + KafkaRaftServer.MetadataTopicId, + logDir, + time, + time.scheduler, + metadataLogConfig + ) + + (logDir.toPath, metadataLog, metadataLogConfig) + } + + def buildMetadataLog( + tempDir: File, + time: MockTime, + metadataLogConfig: MetadataLogConfig = DefaultMetadataLogConfig, + ): KafkaMetadataLog = { + val (_, log, _) = buildMetadataLogAndDir(tempDir, time, metadataLogConfig) + log + } + + def append(log: ReplicatedLog, numberOfRecords: Int, epoch: Int): LogAppendInfo = { + log.appendAsLeader( + MemoryRecords.withRecords( + log.endOffset().offset, + CompressionType.NONE, + epoch, + (0 until numberOfRecords).map(number => new SimpleRecord(number.toString.getBytes)): _* + ), + epoch + ) + } + + def append(snapshotWriter: RawSnapshotWriter, numberOfRecords: Int): Unit = { + snapshotWriter.append(MemoryRecords.withRecords( + 0, + CompressionType.NONE, + 0, + (0 until numberOfRecords).map(number => new SimpleRecord(number.toString.getBytes)): _* + )) + } + + private def createLogDirectory(logDir: File, logDirName: String): File = { + val logDirPath = logDir.getAbsolutePath + val dir = new File(logDirPath, logDirName) + if (!Files.exists(dir.toPath)) { + Files.createDirectories(dir.toPath) + } + dir + } +} diff --git a/core/src/test/scala/kafka/security/minikdc/MiniKdc.scala b/core/src/test/scala/kafka/security/minikdc/MiniKdc.scala new file 
mode 100644 index 0000000..ea8815d --- /dev/null +++ b/core/src/test/scala/kafka/security/minikdc/MiniKdc.scala @@ -0,0 +1,442 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.security.minikdc + +import java.io._ +import java.net.InetSocketAddress +import java.nio.charset.StandardCharsets +import java.nio.file.Files +import java.text.MessageFormat +import java.util.{Locale, Properties, UUID} + +import kafka.utils.{CoreUtils, Exit, Logging} + +import scala.jdk.CollectionConverters._ +import org.apache.commons.lang.text.StrSubstitutor +import org.apache.directory.api.ldap.model.entry.{DefaultEntry, Entry} +import org.apache.directory.api.ldap.model.ldif.LdifReader +import org.apache.directory.api.ldap.model.name.Dn +import org.apache.directory.api.ldap.schema.extractor.impl.DefaultSchemaLdifExtractor +import org.apache.directory.api.ldap.schema.loader.LdifSchemaLoader +import org.apache.directory.api.ldap.schema.manager.impl.DefaultSchemaManager +import org.apache.directory.server.constants.ServerDNConstants +import org.apache.directory.server.core.DefaultDirectoryService +import org.apache.directory.server.core.api.{CacheService, DirectoryService, InstanceLayout} +import org.apache.directory.server.core.api.schema.SchemaPartition +import org.apache.directory.server.core.kerberos.KeyDerivationInterceptor +import org.apache.directory.server.core.partition.impl.btree.jdbm.{JdbmIndex, JdbmPartition} +import org.apache.directory.server.core.partition.ldif.LdifPartition +import org.apache.directory.server.kerberos.KerberosConfig +import org.apache.directory.server.kerberos.kdc.KdcServer +import org.apache.directory.server.kerberos.shared.crypto.encryption.KerberosKeyFactory +import org.apache.directory.server.kerberos.shared.keytab.{Keytab, KeytabEntry} +import org.apache.directory.server.protocol.shared.transport.{TcpTransport, UdpTransport} +import org.apache.directory.server.xdbm.Index +import org.apache.directory.shared.kerberos.KerberosTime +import org.apache.kafka.common.utils.{Java, Utils} + +/** + * Mini KDC based on Apache Directory Server that can be embedded in tests or used from command line as a standalone + * KDC. + * + * MiniKdc sets 2 System properties when started and unsets them when stopped: + * + * - java.security.krb5.conf: set to the MiniKDC real/host/port + * - sun.security.krb5.debug: set to the debug value provided in the configuration + * + * As a result of this, multiple MiniKdc instances should not be started concurrently in the same JVM. 
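A hedged usage sketch assembled only from the pieces visible in this file: the property-name constants on the MiniKdc companion object, the (Properties, File) constructor, and start()/host/port. The values mirror the defaults documented just below; the temporary directory helper is from org.apache.kafka.test.TestUtils:

```scala
import java.util.Properties
import org.apache.kafka.test.TestUtils

val config = new Properties()
config.put(MiniKdc.OrgName, "EXAMPLE")
config.put(MiniKdc.OrgDomain, "COM")
config.put(MiniKdc.KdcBindAddress, "localhost")
config.put(MiniKdc.KdcPort, "0")                      // 0 = bind an ephemeral port
config.put(MiniKdc.Instance, "DefaultKrbServer")
config.put(MiniKdc.Transport, "TCP")
config.put(MiniKdc.MaxTicketLifetime, "86400000")     // 1 day
config.put(MiniKdc.MaxRenewableLifetime, "604800000") // 7 days
config.put(MiniKdc.Debug, "false")

val kdc = new MiniKdc(config, TestUtils.tempDirectory())
kdc.start()
println(s"MiniKdc listening at ${kdc.host}:${kdc.port}")
// The KDC has now set java.security.krb5.conf and sun.security.krb5.debug for this JVM.
```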
+ * + * MiniKdc default configuration values are: + * + * - org.name=EXAMPLE (used to create the REALM) + * - org.domain=COM (used to create the REALM) + * - kdc.bind.address=localhost + * - kdc.port=0 (ephemeral port) + * - instance=DefaultKrbServer + * - max.ticket.lifetime=86400000 (1 day) + * - max.renewable.lifetime604800000 (7 days) + * - transport=TCP + * - debug=false + * + * The generated krb5.conf forces TCP connections. + * + * Acknowledgements: this class is derived from the MiniKdc class in the hadoop-minikdc project (git commit + * d8d8ed35f00b15ee0f2f8aaf3fe7f7b42141286b). + * + * @constructor creates a new MiniKdc instance. + * @param config the MiniKdc configuration + * @param workDir the working directory which will contain krb5.conf, Apache DS files and any other files needed by + * MiniKdc. + * @throws Exception thrown if the MiniKdc could not be created. + */ +class MiniKdc(config: Properties, workDir: File) extends Logging { + + if (!config.keySet.containsAll(MiniKdc.RequiredProperties.asJava)) { + val missingProperties = MiniKdc.RequiredProperties.filterNot(config.keySet.asScala) + throw new IllegalArgumentException(s"Missing configuration properties: $missingProperties") + } + + info("Configuration:") + info("---------------------------------------------------------------") + config.forEach { (key, value) => + info(s"\t$key: $value") + } + info("---------------------------------------------------------------") + + private val orgName = config.getProperty(MiniKdc.OrgName) + private val orgDomain = config.getProperty(MiniKdc.OrgDomain) + private val realm = s"${orgName.toUpperCase(Locale.ENGLISH)}.${orgDomain.toUpperCase(Locale.ENGLISH)}" + private val krb5conf = new File(workDir, "krb5.conf") + + private var _port = config.getProperty(MiniKdc.KdcPort).toInt + private var ds: DirectoryService = null + private var kdc: KdcServer = null + private var closed = false + + def port: Int = _port + + def host: String = config.getProperty(MiniKdc.KdcBindAddress) + + def start(): Unit = { + if (kdc != null) + throw new RuntimeException("KDC already started") + if (closed) + throw new RuntimeException("KDC is closed") + initDirectoryService() + initKdcServer() + initJvmKerberosConfig() + } + + private def initDirectoryService(): Unit = { + ds = new DefaultDirectoryService + ds.setInstanceLayout(new InstanceLayout(workDir)) + ds.setCacheService(new CacheService) + + // first load the schema + val instanceLayout = ds.getInstanceLayout + val schemaPartitionDirectory = new File(instanceLayout.getPartitionsDirectory, "schema") + val extractor = new DefaultSchemaLdifExtractor(instanceLayout.getPartitionsDirectory) + extractor.extractOrCopy + + val loader = new LdifSchemaLoader(schemaPartitionDirectory) + val schemaManager = new DefaultSchemaManager(loader) + schemaManager.loadAllEnabled() + ds.setSchemaManager(schemaManager) + // Init the LdifPartition with schema + val schemaLdifPartition = new LdifPartition(schemaManager, ds.getDnFactory) + schemaLdifPartition.setPartitionPath(schemaPartitionDirectory.toURI) + + // The schema partition + val schemaPartition = new SchemaPartition(schemaManager) + schemaPartition.setWrappedPartition(schemaLdifPartition) + ds.setSchemaPartition(schemaPartition) + + val systemPartition = new JdbmPartition(ds.getSchemaManager, ds.getDnFactory) + systemPartition.setId("system") + systemPartition.setPartitionPath(new File(ds.getInstanceLayout.getPartitionsDirectory, systemPartition.getId).toURI) + systemPartition.setSuffixDn(new 
Dn(ServerDNConstants.SYSTEM_DN)) + systemPartition.setSchemaManager(ds.getSchemaManager) + ds.setSystemPartition(systemPartition) + + ds.getChangeLog.setEnabled(false) + ds.setDenormalizeOpAttrsEnabled(true) + ds.addLast(new KeyDerivationInterceptor) + + // create one partition + val orgName = config.getProperty(MiniKdc.OrgName).toLowerCase(Locale.ENGLISH) + val orgDomain = config.getProperty(MiniKdc.OrgDomain).toLowerCase(Locale.ENGLISH) + val partition = new JdbmPartition(ds.getSchemaManager, ds.getDnFactory) + partition.setId(orgName) + partition.setPartitionPath(new File(ds.getInstanceLayout.getPartitionsDirectory, orgName).toURI) + val dn = new Dn(s"dc=$orgName,dc=$orgDomain") + partition.setSuffixDn(dn) + ds.addPartition(partition) + + // indexes + val indexedAttributes = Set[Index[_, String]]( + new JdbmIndex[Entry]("objectClass", false), + new JdbmIndex[Entry]("dc", false), + new JdbmIndex[Entry]("ou", false) + ).asJava + partition.setIndexedAttributes(indexedAttributes) + + // And start the ds + ds.setInstanceId(config.getProperty(MiniKdc.Instance)) + ds.startup() + + // context entry, after ds.startup() + val entry = ds.newEntry(dn) + entry.add("objectClass", "top", "domain") + entry.add("dc", orgName) + ds.getAdminSession.add(entry) + } + + private def initKdcServer(): Unit = { + + def addInitialEntriesToDirectoryService(bindAddress: String): Unit = { + val map = Map ( + "0" -> orgName.toLowerCase(Locale.ENGLISH), + "1" -> orgDomain.toLowerCase(Locale.ENGLISH), + "2" -> orgName.toUpperCase(Locale.ENGLISH), + "3" -> orgDomain.toUpperCase(Locale.ENGLISH), + "4" -> bindAddress + ) + val reader = new BufferedReader(new InputStreamReader(MiniKdc.getResourceAsStream("minikdc.ldiff"))) + try { + var line: String = null + val builder = new StringBuilder + while ({line = reader.readLine(); line != null}) + builder.append(line).append("\n") + addEntriesToDirectoryService(StrSubstitutor.replace(builder, map.asJava)) + } + finally CoreUtils.swallow(reader.close(), this) + } + + val bindAddress = config.getProperty(MiniKdc.KdcBindAddress) + addInitialEntriesToDirectoryService(bindAddress) + + val kerberosConfig = new KerberosConfig + kerberosConfig.setMaximumRenewableLifetime(config.getProperty(MiniKdc.MaxRenewableLifetime).toLong) + kerberosConfig.setMaximumTicketLifetime(config.getProperty(MiniKdc.MaxTicketLifetime).toLong) + kerberosConfig.setSearchBaseDn(s"dc=$orgName,dc=$orgDomain") + kerberosConfig.setPaEncTimestampRequired(false) + kdc = new KdcServer(kerberosConfig) + kdc.setDirectoryService(ds) + + // transport + val transport = config.getProperty(MiniKdc.Transport) + val absTransport = transport.trim match { + case "TCP" => new TcpTransport(bindAddress, port, 3, 50) + case "UDP" => new UdpTransport(port) + case _ => throw new IllegalArgumentException(s"Invalid transport: $transport") + } + kdc.addTransports(absTransport) + kdc.setServiceName(config.getProperty(MiniKdc.Instance)) + kdc.start() + + // if using ephemeral port, update port number for binding + if (port == 0) + _port = absTransport.getAcceptor.getLocalAddress.asInstanceOf[InetSocketAddress].getPort + + info(s"MiniKdc listening at port: $port") + } + + private def initJvmKerberosConfig(): Unit = { + writeKrb5Conf() + System.setProperty(MiniKdc.JavaSecurityKrb5Conf, krb5conf.getAbsolutePath) + System.setProperty(MiniKdc.SunSecurityKrb5Debug, config.getProperty(MiniKdc.Debug, "false")) + info(s"MiniKdc setting JVM krb5.conf to: ${krb5conf.getAbsolutePath}") + refreshJvmKerberosConfig() + } + + private def writeKrb5Conf(): 
Unit = { + val stringBuilder = new StringBuilder + val reader = new BufferedReader( + new InputStreamReader(MiniKdc.getResourceAsStream("minikdc-krb5.conf"), StandardCharsets.UTF_8)) + try { + var line: String = null + while ({line = reader.readLine(); line != null}) { + stringBuilder.append(line).append("{3}") + } + } finally CoreUtils.swallow(reader.close(), this) + val output = MessageFormat.format(stringBuilder.toString, realm, host, port.toString, System.lineSeparator()) + Files.write(krb5conf.toPath, output.getBytes(StandardCharsets.UTF_8)) + } + + private def refreshJvmKerberosConfig(): Unit = { + val klass = + if (Java.isIbmJdk) + Class.forName("com.ibm.security.krb5.internal.Config") + else + Class.forName("sun.security.krb5.Config") + klass.getMethod("refresh").invoke(klass) + } + + def stop(): Unit = { + if (!closed) { + closed = true + if (kdc != null) { + System.clearProperty(MiniKdc.JavaSecurityKrb5Conf) + System.clearProperty(MiniKdc.SunSecurityKrb5Debug) + kdc.stop() + try ds.shutdown() + catch { + case ex: Exception => error("Could not shutdown ApacheDS properly", ex) + } + } + } + } + + /** + * Creates a principal in the KDC with the specified user and password. + * + * An exception will be thrown if the principal cannot be created. + * + * @param principal principal name, do not include the domain. + * @param password password. + */ + private def createPrincipal(principal: String, password: String): Unit = { + val ldifContent = s""" + |dn: uid=$principal,ou=users,dc=${orgName.toLowerCase(Locale.ENGLISH)},dc=${orgDomain.toLowerCase(Locale.ENGLISH)} + |objectClass: top + |objectClass: person + |objectClass: inetOrgPerson + |objectClass: krb5principal + |objectClass: krb5kdcentry + |cn: $principal + |sn: $principal + |uid: $principal + |userPassword: $password + |krb5PrincipalName: ${principal}@${realm} + |krb5KeyVersionNumber: 0""".stripMargin + addEntriesToDirectoryService(ldifContent) + } + + /** + * Creates multiple principals in the KDC and adds them to a keytab file. + * + * An exception will be thrown if the principal cannot be created. + * + * @param keytabFile keytab file to add the created principals + * @param principals principals to add to the KDC, do not include the domain. 
+ */ + def createPrincipal(keytabFile: File, principals: String*): Unit = { + val generatedPassword = UUID.randomUUID.toString + val keytab = new Keytab + val entries = principals.flatMap { principal => + createPrincipal(principal, generatedPassword) + val principalWithRealm = s"${principal}@${realm}" + val timestamp = new KerberosTime + KerberosKeyFactory.getKerberosKeys(principalWithRealm, generatedPassword).asScala.values.map { encryptionKey => + val keyVersion = encryptionKey.getKeyVersion.toByte + new KeytabEntry(principalWithRealm, 1, timestamp, keyVersion, encryptionKey) + } + } + keytab.setEntries(entries.asJava) + keytab.write(keytabFile) + } + + private def addEntriesToDirectoryService(ldifContent: String): Unit = { + val reader = new LdifReader(new StringReader(ldifContent)) + try { + for (ldifEntry <- reader.asScala) + ds.getAdminSession.add(new DefaultEntry(ds.getSchemaManager, ldifEntry.getEntry)) + } finally CoreUtils.swallow(reader.close(), this) + } + +} + +object MiniKdc { + + val JavaSecurityKrb5Conf = "java.security.krb5.conf" + val SunSecurityKrb5Debug = "sun.security.krb5.debug" + + def main(args: Array[String]): Unit = { + args match { + case Array(workDirPath, configPath, keytabPath, principals@ _*) if principals.nonEmpty => + val workDir = new File(workDirPath) + if (!workDir.exists) + throw new RuntimeException(s"Specified work directory does not exist: ${workDir.getAbsolutePath}") + val config = createConfig + val configFile = new File(configPath) + if (!configFile.exists) + throw new RuntimeException(s"Specified configuration does not exist: ${configFile.getAbsolutePath}") + + val userConfig = Utils.loadProps(configFile.getAbsolutePath) + userConfig.forEach { (key, value) => + config.put(key, value) + } + val keytabFile = new File(keytabPath).getAbsoluteFile + start(workDir, config, keytabFile, principals) + case _ => + println("Arguments: <WORKDIR> <MINIKDCPROPERTIES> <KEYTABFILE> [<PRINCIPALS>]+") + Exit.exit(1) + } + } + + private[minikdc] def start(workDir: File, config: Properties, keytabFile: File, principals: Seq[String]): MiniKdc = { + val miniKdc = new MiniKdc(config, workDir) + miniKdc.start() + miniKdc.createPrincipal(keytabFile, principals: _*) + val infoMessage = s""" + | + |Standalone MiniKdc Running + |--------------------------------------------------- + | Realm : ${miniKdc.realm} + | Running at : ${miniKdc.host}:${miniKdc.port} + | krb5conf : ${miniKdc.krb5conf} + | + | created keytab : $keytabFile + | with principals : ${principals.mkString(", ")} + | + |Hit <CTRL-C> or kill <PID> to stop it + |--------------------------------------------------- + | + """.stripMargin + println(infoMessage) + Exit.addShutdownHook("minikdc-shutdown-hook", miniKdc.stop()) + miniKdc + } + + val OrgName = "org.name" + val OrgDomain = "org.domain" + val KdcBindAddress = "kdc.bind.address" + val KdcPort = "kdc.port" + val Instance = "instance" + val MaxTicketLifetime = "max.ticket.lifetime" + val MaxRenewableLifetime = "max.renewable.lifetime" + val Transport = "transport" + val Debug = "debug" + + private val RequiredProperties = Set(OrgName, OrgDomain, KdcBindAddress, KdcPort, Instance, Transport, + MaxTicketLifetime, MaxRenewableLifetime) + + private val DefaultConfig = Map( + KdcBindAddress -> "localhost", + KdcPort -> "0", + Instance -> "DefaultKrbServer", + OrgName -> "Example", + OrgDomain -> "COM", + Transport -> "TCP", + MaxTicketLifetime -> "86400000", + MaxRenewableLifetime -> "604800000", + Debug -> "false" + ) + + /** + * Convenience method that returns MiniKdc default configuration. 
+ * + * The returned configuration is a copy, it can be customized before using + * it to create a MiniKdc. + */ + def createConfig: Properties = { + val properties = new Properties + DefaultConfig.foreach { case (k, v) => properties.setProperty(k, v) } + properties + } + + @throws[IOException] + def getResourceAsStream(resourceName: String): InputStream = { + val cl = Option(Thread.currentThread.getContextClassLoader).getOrElse(classOf[MiniKdc].getClassLoader) + Option(cl.getResourceAsStream(resourceName)).getOrElse { + throw new IOException(s"Can not read resource file `$resourceName`") + } + } + +} diff --git a/core/src/test/scala/kafka/security/minikdc/MiniKdcTest.scala b/core/src/test/scala/kafka/security/minikdc/MiniKdcTest.scala new file mode 100644 index 0000000..23263fc --- /dev/null +++ b/core/src/test/scala/kafka/security/minikdc/MiniKdcTest.scala @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.security.minikdc + +import java.util.Properties + +import kafka.utils.TestUtils +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.Assertions._ + +class MiniKdcTest { + @Test + def shouldNotStopImmediatelyWhenStarted(): Unit = { + val config = new Properties() + config.setProperty("kdc.bind.address", "0.0.0.0") + config.setProperty("transport", "TCP"); + config.setProperty("max.ticket.lifetime", "86400000") + config.setProperty("org.name", "Example") + config.setProperty("kdc.port", "0") + config.setProperty("org.domain", "COM") + config.setProperty("max.renewable.lifetime", "604800000") + config.setProperty("instance", "DefaultKrbServer") + val minikdc = MiniKdc.start(TestUtils.tempDir(), config, TestUtils.tempFile(), List("foo")) + val running = System.getProperty(MiniKdc.JavaSecurityKrb5Conf) != null + try { + assertTrue(running, "MiniKdc stopped immediately; it should not have") + } finally { + if (running) minikdc.stop() + } + } +} \ No newline at end of file diff --git a/core/src/test/scala/kafka/server/BrokerMetadataCheckpointTest.scala b/core/src/test/scala/kafka/server/BrokerMetadataCheckpointTest.scala new file mode 100644 index 0000000..c7ce0ac --- /dev/null +++ b/core/src/test/scala/kafka/server/BrokerMetadataCheckpointTest.scala @@ -0,0 +1,156 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more contributor license agreements. See the NOTICE + * file distributed with this work for additional information regarding copyright ownership. The ASF licenses this file + * to You under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the + * License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ +package kafka.server + +import java.io.File +import java.util.Properties + +import org.apache.kafka.common.utils.Utils +import org.apache.kafka.test.TestUtils +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.Test + +class BrokerMetadataCheckpointTest { + private val clusterIdBase64 = "H3KKO4NTRPaCWtEmm3vW7A" + + @Test + def testReadWithNonExistentFile(): Unit = { + assertEquals(None, new BrokerMetadataCheckpoint(new File("path/that/does/not/exist")).read()) + } + + @Test + def testCreateZkMetadataProperties(): Unit = { + val meta = ZkMetaProperties("7bc79ca1-9746-42a3-a35a-efb3cde44492", 3) + val properties = meta.toProperties + val parsed = new RawMetaProperties(properties) + assertEquals(0, parsed.version) + assertEquals(Some(meta.clusterId), parsed.clusterId) + assertEquals(Some(meta.brokerId), parsed.brokerId) + } + + @Test + def testParseRawMetaPropertiesWithoutVersion(): Unit = { + val brokerId = 1 + val clusterId = "7bc79ca1-9746-42a3-a35a-efb3cde44492" + + val properties = new Properties() + properties.put(RawMetaProperties.BrokerIdKey, brokerId.toString) + properties.put(RawMetaProperties.ClusterIdKey, clusterId) + + val parsed = new RawMetaProperties(properties) + assertEquals(Some(brokerId), parsed.brokerId) + assertEquals(Some(clusterId), parsed.clusterId) + assertEquals(0, parsed.version) + } + + @Test + def testRawPropertiesWithInvalidBrokerId(): Unit = { + val properties = new Properties() + properties.put(RawMetaProperties.BrokerIdKey, "oof") + val parsed = new RawMetaProperties(properties) + assertThrows(classOf[RuntimeException], () => parsed.brokerId) + } + + @Test + def testCreateMetadataProperties(): Unit = { + confirmValidForMetaProperties(clusterIdBase64) + } + + @Test + def testMetaPropertiesWithMissingVersion(): Unit = { + val properties = new RawMetaProperties() + properties.clusterId = clusterIdBase64 + properties.nodeId = 1 + assertThrows(classOf[RuntimeException], () => MetaProperties.parse(properties)) + } + + @Test + def testMetaPropertiesAllowsHexEncodedUUIDs(): Unit = { + val clusterId = "7bc79ca1-9746-42a3-a35a-efb3cde44492" + confirmValidForMetaProperties(clusterId) + } + + @Test + def testMetaPropertiesWithNonUuidClusterId(): Unit = { + val clusterId = "not a valid uuid" + confirmValidForMetaProperties(clusterId) + } + + private def confirmValidForMetaProperties(clusterId: String) = { + val meta = MetaProperties( + clusterId = clusterId, + nodeId = 5 + ) + val properties = new RawMetaProperties(meta.toProperties) + val meta2 = MetaProperties.parse(properties) + assertEquals(meta, meta2) + } + + @Test + def testMetaPropertiesWithMissingBrokerId(): Unit = { + val properties = new RawMetaProperties() + properties.version = 1 + properties.clusterId = clusterIdBase64 + assertThrows(classOf[RuntimeException], () => MetaProperties.parse(properties)) + } + + @Test + def testMetaPropertiesWithMissingControllerId(): Unit = { + val properties = new RawMetaProperties() + properties.version = 1 + properties.clusterId = clusterIdBase64 + assertThrows(classOf[RuntimeException], () => MetaProperties.parse(properties)) + } + + @Test + 
def testGetBrokerMetadataAndOfflineDirsWithNonexistentDirectories(): Unit = { + // Use a regular file as an invalid log dir to trigger an IO error + val invalidDir = TestUtils.tempFile("blah") + try { + // The `ignoreMissing` flag has no effect if there is an IO error + testEmptyGetBrokerMetadataAndOfflineDirs(invalidDir, + expectedOfflineDirs = Seq(invalidDir), ignoreMissing = true) + testEmptyGetBrokerMetadataAndOfflineDirs(invalidDir, + expectedOfflineDirs = Seq(invalidDir), ignoreMissing = false) + } finally { + Utils.delete(invalidDir) + } + } + + @Test + def testGetBrokerMetadataAndOfflineDirsIgnoreMissing(): Unit = { + val tempDir = TestUtils.tempDirectory() + try { + testEmptyGetBrokerMetadataAndOfflineDirs(tempDir, + expectedOfflineDirs = Seq(), ignoreMissing = true) + + assertThrows(classOf[RuntimeException], + () => BrokerMetadataCheckpoint.getBrokerMetadataAndOfflineDirs( + Seq(tempDir.getAbsolutePath), false)) + } finally { + Utils.delete(tempDir) + } + } + + private def testEmptyGetBrokerMetadataAndOfflineDirs( + logDir: File, + expectedOfflineDirs: Seq[File], + ignoreMissing: Boolean + ): Unit = { + val (metaProperties, offlineDirs) = BrokerMetadataCheckpoint.getBrokerMetadataAndOfflineDirs( + Seq(logDir.getAbsolutePath), ignoreMissing) + assertEquals(expectedOfflineDirs.map(_.getAbsolutePath), offlineDirs) + assertEquals(new Properties(), metaProperties.props) + } + +} diff --git a/core/src/test/scala/kafka/server/BrokerToControllerRequestThreadTest.scala b/core/src/test/scala/kafka/server/BrokerToControllerRequestThreadTest.scala new file mode 100644 index 0000000..3297ec0 --- /dev/null +++ b/core/src/test/scala/kafka/server/BrokerToControllerRequestThreadTest.scala @@ -0,0 +1,466 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package kafka.server + +import java.nio.ByteBuffer +import java.util.Collections +import java.util.concurrent.atomic.{AtomicBoolean, AtomicReference} +import kafka.utils.TestUtils +import org.apache.kafka.clients.{ClientResponse, ManualMetadataUpdater, Metadata, MockClient, NodeApiVersions} +import org.apache.kafka.common.Node +import org.apache.kafka.common.message.{EnvelopeResponseData, MetadataRequestData} +import org.apache.kafka.common.protocol.{ApiKeys, Errors} +import org.apache.kafka.common.requests.{AbstractRequest, EnvelopeRequest, EnvelopeResponse, MetadataRequest, MetadataResponse, RequestTestUtils} +import org.apache.kafka.common.security.auth.KafkaPrincipal +import org.apache.kafka.common.security.authenticator.DefaultKafkaPrincipalBuilder +import org.apache.kafka.common.utils.MockTime +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.Test +import org.mockito.Mockito._ + + +class BrokerToControllerRequestThreadTest { + + @Test + def testRetryTimeoutWhileControllerNotAvailable(): Unit = { + val time = new MockTime() + val config = new KafkaConfig(TestUtils.createBrokerConfig(1, "localhost:2181")) + val metadata = mock(classOf[Metadata]) + val mockClient = new MockClient(time, metadata) + val controllerNodeProvider = mock(classOf[ControllerNodeProvider]) + + when(controllerNodeProvider.get()).thenReturn(None) + + val retryTimeoutMs = 30000 + val testRequestThread = new BrokerToControllerRequestThread(mockClient, new ManualMetadataUpdater(), controllerNodeProvider, + config, time, "", retryTimeoutMs) + testRequestThread.started = true + + val completionHandler = new TestRequestCompletionHandler(None) + val queueItem = BrokerToControllerQueueItem( + time.milliseconds(), + new MetadataRequest.Builder(new MetadataRequestData()), + completionHandler + ) + + testRequestThread.enqueue(queueItem) + testRequestThread.doWork() + assertEquals(1, testRequestThread.queueSize) + + time.sleep(retryTimeoutMs) + testRequestThread.doWork() + assertEquals(0, testRequestThread.queueSize) + assertTrue(completionHandler.timedOut.get) + } + + @Test + def testRequestsSent(): Unit = { + // just a simple test that tests whether the request from 1 -> 2 is sent and the response callback is called + val time = new MockTime() + val config = new KafkaConfig(TestUtils.createBrokerConfig(1, "localhost:2181")) + val controllerId = 2 + + val metadata = mock(classOf[Metadata]) + val mockClient = new MockClient(time, metadata) + + val controllerNodeProvider = mock(classOf[ControllerNodeProvider]) + val activeController = new Node(controllerId, "host", 1234) + + when(controllerNodeProvider.get()).thenReturn(Some(activeController)) + + val expectedResponse = RequestTestUtils.metadataUpdateWith(2, Collections.singletonMap("a", 2)) + val testRequestThread = new BrokerToControllerRequestThread(mockClient, new ManualMetadataUpdater(), controllerNodeProvider, + config, time, "", retryTimeoutMs = Long.MaxValue) + testRequestThread.started = true + mockClient.prepareResponse(expectedResponse) + + val completionHandler = new TestRequestCompletionHandler(Some(expectedResponse)) + val queueItem = BrokerToControllerQueueItem( + time.milliseconds(), + new MetadataRequest.Builder(new MetadataRequestData()), + completionHandler + ) + + testRequestThread.enqueue(queueItem) + assertEquals(1, testRequestThread.queueSize) + + // initialize to the controller + testRequestThread.doWork() + // send and process the request + testRequestThread.doWork() + + assertEquals(0, testRequestThread.queueSize) + 
assertTrue(completionHandler.completed.get()) + } + + @Test + def testControllerChanged(): Unit = { + // in this test the current broker is 1, and the controller changes from 2 -> 3 then back: 3 -> 2 + val time = new MockTime() + val config = new KafkaConfig(TestUtils.createBrokerConfig(1, "localhost:2181")) + val oldControllerId = 1 + val newControllerId = 2 + + val metadata = mock(classOf[Metadata]) + val mockClient = new MockClient(time, metadata) + + val controllerNodeProvider = mock(classOf[ControllerNodeProvider]) + val oldController = new Node(oldControllerId, "host1", 1234) + val newController = new Node(newControllerId, "host2", 1234) + + when(controllerNodeProvider.get()).thenReturn(Some(oldController), Some(newController)) + + val expectedResponse = RequestTestUtils.metadataUpdateWith(3, Collections.singletonMap("a", 2)) + val testRequestThread = new BrokerToControllerRequestThread(mockClient, new ManualMetadataUpdater(), + controllerNodeProvider, config, time, "", retryTimeoutMs = Long.MaxValue) + testRequestThread.started = true + + val completionHandler = new TestRequestCompletionHandler(Some(expectedResponse)) + val queueItem = BrokerToControllerQueueItem( + time.milliseconds(), + new MetadataRequest.Builder(new MetadataRequestData()), + completionHandler, + ) + + testRequestThread.enqueue(queueItem) + mockClient.prepareResponse(expectedResponse) + // initialize the thread with oldController + testRequestThread.doWork() + assertFalse(completionHandler.completed.get()) + + // disconnect the node + mockClient.setUnreachable(oldController, time.milliseconds() + 5000) + // verify that the client closed the connection to the faulty controller + testRequestThread.doWork() + // should connect to the new controller + testRequestThread.doWork() + // should send the request and process the response + testRequestThread.doWork() + + assertTrue(completionHandler.completed.get()) + } + + @Test + def testNotController(): Unit = { + val time = new MockTime() + val config = new KafkaConfig(TestUtils.createBrokerConfig(1, "localhost:2181")) + val oldControllerId = 1 + val newControllerId = 2 + + val metadata = mock(classOf[Metadata]) + val mockClient = new MockClient(time, metadata) + + val controllerNodeProvider = mock(classOf[ControllerNodeProvider]) + val port = 1234 + val oldController = new Node(oldControllerId, "host1", port) + val newController = new Node(newControllerId, "host2", port) + + when(controllerNodeProvider.get()).thenReturn(Some(oldController), Some(newController)) + + val responseWithNotControllerError = RequestTestUtils.metadataUpdateWith("cluster1", 2, + Collections.singletonMap("a", Errors.NOT_CONTROLLER), + Collections.singletonMap("a", 2)) + val expectedResponse = RequestTestUtils.metadataUpdateWith(3, Collections.singletonMap("a", 2)) + val testRequestThread = new BrokerToControllerRequestThread(mockClient, new ManualMetadataUpdater(), controllerNodeProvider, + config, time, "", retryTimeoutMs = Long.MaxValue) + testRequestThread.started = true + + val completionHandler = new TestRequestCompletionHandler(Some(expectedResponse)) + val queueItem = BrokerToControllerQueueItem( + time.milliseconds(), + new MetadataRequest.Builder(new MetadataRequestData() + .setAllowAutoTopicCreation(true)), + completionHandler + ) + testRequestThread.enqueue(queueItem) + // initialize to the controller + testRequestThread.doWork() + + val oldBrokerNode = new Node(oldControllerId, "host1", port) + assertEquals(Some(oldBrokerNode), testRequestThread.activeControllerAddress()) + + // send 
and process the request + mockClient.prepareResponse((body: AbstractRequest) => { + body.isInstanceOf[MetadataRequest] && + body.asInstanceOf[MetadataRequest].allowAutoTopicCreation() + }, responseWithNotControllerError) + testRequestThread.doWork() + assertEquals(None, testRequestThread.activeControllerAddress()) + // reinitialize the controller to a different node + testRequestThread.doWork() + // process the request again + mockClient.prepareResponse(expectedResponse) + testRequestThread.doWork() + + val newControllerNode = new Node(newControllerId, "host2", port) + assertEquals(Some(newControllerNode), testRequestThread.activeControllerAddress()) + + assertTrue(completionHandler.completed.get()) + } + + @Test + def testEnvelopeResponseWithNotControllerError(): Unit = { + val time = new MockTime() + val config = new KafkaConfig(TestUtils.createBrokerConfig(1, "localhost:2181")) + val oldControllerId = 1 + val newControllerId = 2 + + val metadata = mock(classOf[Metadata]) + val mockClient = new MockClient(time, metadata) + // enable envelope API + mockClient.setNodeApiVersions(NodeApiVersions.create(ApiKeys.ENVELOPE.id, 0.toShort, 0.toShort)) + + val controllerNodeProvider = mock(classOf[ControllerNodeProvider]) + val port = 1234 + val oldController = new Node(oldControllerId, "host1", port) + val newController = new Node(newControllerId, "host2", port) + + when(controllerNodeProvider.get()).thenReturn(Some(oldController), Some(newController)) + + // create an envelopeResponse with NOT_CONTROLLER error + val envelopeResponseWithNotControllerError = new EnvelopeResponse( + new EnvelopeResponseData().setErrorCode(Errors.NOT_CONTROLLER.code())) + + // response for retry request after receiving NOT_CONTROLLER error + val expectedResponse = RequestTestUtils.metadataUpdateWith(3, Collections.singletonMap("a", 2)) + + val testRequestThread = new BrokerToControllerRequestThread(mockClient, new ManualMetadataUpdater(), controllerNodeProvider, + config, time, "", retryTimeoutMs = Long.MaxValue) + testRequestThread.started = true + + val completionHandler = new TestRequestCompletionHandler(Some(expectedResponse)) + val kafkaPrincipal = new KafkaPrincipal(KafkaPrincipal.USER_TYPE, "principal", true) + val kafkaPrincipalBuilder = new DefaultKafkaPrincipalBuilder(null, null) + + // build an EnvelopeRequest by dummy data + val envelopeRequestBuilder = new EnvelopeRequest.Builder(ByteBuffer.allocate(0), + kafkaPrincipalBuilder.serialize(kafkaPrincipal), "client-address".getBytes) + + val queueItem = BrokerToControllerQueueItem( + time.milliseconds(), + envelopeRequestBuilder, + completionHandler + ) + + testRequestThread.enqueue(queueItem) + // initialize to the controller + testRequestThread.doWork() + + val oldBrokerNode = new Node(oldControllerId, "host1", port) + assertEquals(Some(oldBrokerNode), testRequestThread.activeControllerAddress()) + + // send and process the envelope request + mockClient.prepareResponse((body: AbstractRequest) => { + body.isInstanceOf[EnvelopeRequest] + }, envelopeResponseWithNotControllerError) + testRequestThread.doWork() + // expect to reset the activeControllerAddress after finding the NOT_CONTROLLER error + assertEquals(None, testRequestThread.activeControllerAddress()) + // reinitialize the controller to a different node + testRequestThread.doWork() + // process the request again + mockClient.prepareResponse(expectedResponse) + testRequestThread.doWork() + + val newControllerNode = new Node(newControllerId, "host2", port) + assertEquals(Some(newControllerNode), 
testRequestThread.activeControllerAddress()) + + assertTrue(completionHandler.completed.get()) + } + + @Test + def testRetryTimeout(): Unit = { + val time = new MockTime() + val config = new KafkaConfig(TestUtils.createBrokerConfig(1, "localhost:2181")) + val controllerId = 1 + + val metadata = mock(classOf[Metadata]) + val mockClient = new MockClient(time, metadata) + + val controllerNodeProvider = mock(classOf[ControllerNodeProvider]) + val controller = new Node(controllerId, "host1", 1234) + + when(controllerNodeProvider.get()).thenReturn(Some(controller)) + + val retryTimeoutMs = 30000 + val responseWithNotControllerError = RequestTestUtils.metadataUpdateWith("cluster1", 2, + Collections.singletonMap("a", Errors.NOT_CONTROLLER), + Collections.singletonMap("a", 2)) + val testRequestThread = new BrokerToControllerRequestThread(mockClient, new ManualMetadataUpdater(), controllerNodeProvider, + config, time, "", retryTimeoutMs) + testRequestThread.started = true + + val completionHandler = new TestRequestCompletionHandler() + val queueItem = BrokerToControllerQueueItem( + time.milliseconds(), + new MetadataRequest.Builder(new MetadataRequestData() + .setAllowAutoTopicCreation(true)), + completionHandler + ) + + testRequestThread.enqueue(queueItem) + + // initialize to the controller + testRequestThread.doWork() + + time.sleep(retryTimeoutMs) + + // send and process the request + mockClient.prepareResponse((body: AbstractRequest) => { + body.isInstanceOf[MetadataRequest] && + body.asInstanceOf[MetadataRequest].allowAutoTopicCreation() + }, responseWithNotControllerError) + + testRequestThread.doWork() + + assertTrue(completionHandler.timedOut.get()) + } + + @Test + def testUnsupportedVersionHandling(): Unit = { + val time = new MockTime() + val config = new KafkaConfig(TestUtils.createBrokerConfig(1, "localhost:2181")) + val controllerId = 2 + + val metadata = mock(classOf[Metadata]) + val mockClient = new MockClient(time, metadata) + + val controllerNodeProvider = mock(classOf[ControllerNodeProvider]) + val activeController = new Node(controllerId, "host", 1234) + + when(controllerNodeProvider.get()).thenReturn(Some(activeController)) + + val callbackResponse = new AtomicReference[ClientResponse]() + val completionHandler = new ControllerRequestCompletionHandler { + override def onTimeout(): Unit = fail("Unexpected timeout exception") + override def onComplete(response: ClientResponse): Unit = callbackResponse.set(response) + } + + val queueItem = BrokerToControllerQueueItem( + time.milliseconds(), + new MetadataRequest.Builder(new MetadataRequestData()), + completionHandler + ) + + mockClient.prepareUnsupportedVersionResponse(request => request.apiKey == ApiKeys.METADATA) + + val testRequestThread = new BrokerToControllerRequestThread(mockClient, new ManualMetadataUpdater(), controllerNodeProvider, + config, time, "", retryTimeoutMs = Long.MaxValue) + testRequestThread.started = true + + testRequestThread.enqueue(queueItem) + pollUntil(testRequestThread, () => callbackResponse.get != null) + assertNotNull(callbackResponse.get.versionMismatch) + } + + @Test + def testAuthenticationExceptionHandling(): Unit = { + val time = new MockTime() + val config = new KafkaConfig(TestUtils.createBrokerConfig(1, "localhost:2181")) + val controllerId = 2 + + val metadata = mock(classOf[Metadata]) + val mockClient = new MockClient(time, metadata) + + val controllerNodeProvider = mock(classOf[ControllerNodeProvider]) + val activeController = new Node(controllerId, "host", 1234) + + 
when(controllerNodeProvider.get()).thenReturn(Some(activeController)) + + val callbackResponse = new AtomicReference[ClientResponse]() + val completionHandler = new ControllerRequestCompletionHandler { + override def onTimeout(): Unit = fail("Unexpected timeout exception") + override def onComplete(response: ClientResponse): Unit = callbackResponse.set(response) + } + + val queueItem = BrokerToControllerQueueItem( + time.milliseconds(), + new MetadataRequest.Builder(new MetadataRequestData()), + completionHandler + ) + + mockClient.createPendingAuthenticationError(activeController, 50) + + val testRequestThread = new BrokerToControllerRequestThread(mockClient, new ManualMetadataUpdater(), controllerNodeProvider, + config, time, "", retryTimeoutMs = Long.MaxValue) + testRequestThread.started = true + + testRequestThread.enqueue(queueItem) + pollUntil(testRequestThread, () => callbackResponse.get != null) + assertNotNull(callbackResponse.get.authenticationException) + } + + @Test + def testThreadNotStarted(): Unit = { + // Make sure we throw if we enqueue anything while the thread is not running + val time = new MockTime() + val config = new KafkaConfig(TestUtils.createBrokerConfig(1, "localhost:2181")) + + val metadata = mock(classOf[Metadata]) + val mockClient = new MockClient(time, metadata) + + val controllerNodeProvider = mock(classOf[ControllerNodeProvider]) + + val testRequestThread = new BrokerToControllerRequestThread(mockClient, new ManualMetadataUpdater(), controllerNodeProvider, + config, time, "", retryTimeoutMs = Long.MaxValue) + + val completionHandler = new TestRequestCompletionHandler(None) + val queueItem = BrokerToControllerQueueItem( + time.milliseconds(), + new MetadataRequest.Builder(new MetadataRequestData()), + completionHandler + ) + + assertThrows(classOf[IllegalStateException], () => testRequestThread.enqueue(queueItem)) + assertEquals(0, testRequestThread.queueSize) + } + + private def pollUntil( + requestThread: BrokerToControllerRequestThread, + condition: () => Boolean, + maxRetries: Int = 10 + ): Unit = { + var tries = 0 + do { + requestThread.doWork() + tries += 1 + } while (!condition.apply() && tries < maxRetries) + + if (!condition.apply()) { + fail(s"Condition failed to be met after polling $tries times") + } + } + + class TestRequestCompletionHandler( + expectedResponse: Option[MetadataResponse] = None + ) extends ControllerRequestCompletionHandler { + val completed: AtomicBoolean = new AtomicBoolean(false) + val timedOut: AtomicBoolean = new AtomicBoolean(false) + + override def onComplete(response: ClientResponse): Unit = { + expectedResponse.foreach { expected => + assertEquals(expected, response.responseBody()) + } + completed.set(true) + } + + override def onTimeout(): Unit = { + timedOut.set(true) + } + } +} diff --git a/core/src/test/scala/kafka/server/metadata/MockConfigRepository.scala b/core/src/test/scala/kafka/server/metadata/MockConfigRepository.scala new file mode 100644 index 0000000..9d42fea --- /dev/null +++ b/core/src/test/scala/kafka/server/metadata/MockConfigRepository.scala @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.server.metadata; + +import java.util +import java.util.Properties + +import org.apache.kafka.common.config.ConfigResource +import org.apache.kafka.common.config.ConfigResource.Type.TOPIC + +object MockConfigRepository { + def forTopic(topic: String, key: String, value: String): MockConfigRepository = { + val properties = new Properties() + properties.put(key, value) + forTopic(topic, properties) + } + + def forTopic(topic: String, properties: Properties): MockConfigRepository = { + val repository = new MockConfigRepository() + repository.configs.put(new ConfigResource(TOPIC, topic), properties) + repository + } +} + +class MockConfigRepository extends ConfigRepository { + val configs = new util.HashMap[ConfigResource, Properties]() + + override def config(configResource: ConfigResource): Properties = configs.synchronized { + configs.getOrDefault(configResource, new Properties()) + } + + def setConfig(configResource: ConfigResource, key: String, value: String): Unit = configs.synchronized { + val properties = configs.getOrDefault(configResource, new Properties()) + val newProperties = new Properties() + newProperties.putAll(properties) + if (value == null) { + newProperties.remove(key) + } else { + newProperties.put(key, value) + } + configs.put(configResource, newProperties) + } + + def setTopicConfig(topicName: String, key: String, value: String): Unit = configs.synchronized { + setConfig(new ConfigResource(TOPIC, topicName), key, value) + } +} diff --git a/core/src/test/scala/kafka/tools/CustomDeserializerTest.scala b/core/src/test/scala/kafka/tools/CustomDeserializerTest.scala new file mode 100644 index 0000000..244a9cf --- /dev/null +++ b/core/src/test/scala/kafka/tools/CustomDeserializerTest.scala @@ -0,0 +1,65 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package kafka.tools + +import java.io.PrintStream + +import kafka.utils.TestUtils +import org.apache.kafka.clients.consumer.ConsumerRecord +import org.apache.kafka.common.header.Headers +import org.apache.kafka.common.serialization.Deserializer +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.Test +import org.mockito.Mockito._ + +class CustomDeserializer extends Deserializer[String] { + + override def deserialize(topic: String, data: Array[Byte]): String = { + assertNotNull(topic, "topic must not be null") + new String(data) + } + + override def deserialize(topic: String, headers: Headers, data: Array[Byte]): String = { + println("WITH HEADERS") + new String(data) + } +} + +class CustomDeserializerTest { + + @Test + def checkFormatterCallDeserializerWithHeaders(): Unit = { + val formatter = new DefaultMessageFormatter() + formatter.valueDeserializer = Some(new CustomDeserializer) + val output = TestUtils.grabConsoleOutput(formatter.writeTo( + new ConsumerRecord("topic_test", 1, 1L, "key".getBytes, "value".getBytes), mock(classOf[PrintStream]))) + assertTrue(output.contains("WITH HEADERS"), "DefaultMessageFormatter should call `deserialize` method with headers.") + formatter.close() + } + + @Test + def checkDeserializerTopicIsNotNull(): Unit = { + val formatter = new DefaultMessageFormatter() + formatter.keyDeserializer = Some(new CustomDeserializer) + + formatter.writeTo(new ConsumerRecord("topic_test", 1, 1L, "key".getBytes, "value".getBytes), + mock(classOf[PrintStream])) + + formatter.close() + } +} diff --git a/core/src/test/scala/kafka/tools/DefaultMessageFormatterTest.scala b/core/src/test/scala/kafka/tools/DefaultMessageFormatterTest.scala new file mode 100644 index 0000000..12bbc94 --- /dev/null +++ b/core/src/test/scala/kafka/tools/DefaultMessageFormatterTest.scala @@ -0,0 +1,234 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package kafka.tools + +import java.io.{ByteArrayOutputStream, Closeable, PrintStream} +import java.nio.charset.StandardCharsets +import java.util + +import org.apache.kafka.clients.consumer.ConsumerRecord +import org.apache.kafka.common.header.Header +import org.apache.kafka.common.header.internals.{RecordHeader, RecordHeaders} +import org.apache.kafka.common.record.TimestampType +import org.apache.kafka.common.serialization.Deserializer +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.params.ParameterizedTest +import org.junit.jupiter.params.provider.{Arguments, MethodSource} + +import java.util.Optional +import scala.jdk.CollectionConverters._ + +class DefaultMessageFormatterTest { + import DefaultMessageFormatterTest._ + + @ParameterizedTest + @MethodSource(Array("parameters")) + def testWriteRecord(name: String, record: ConsumerRecord[Array[Byte], Array[Byte]], properties: Map[String, String], expected: String): Unit = { + withResource(new ByteArrayOutputStream()) { baos => + withResource(new PrintStream(baos)) { ps => + val formatter = buildFormatter(properties) + formatter.writeTo(record, ps) + val actual = new String(baos.toByteArray(), StandardCharsets.UTF_8) + assertEquals(expected, actual) + + } + } + } +} + +object DefaultMessageFormatterTest { + def parameters: java.util.stream.Stream[Arguments] = { + Seq( + Arguments.of( + "print nothing", + consumerRecord(), + Map("print.value" -> "false"), + ""), + Arguments.of( + "print key", + consumerRecord(), + Map("print.key" -> "true", + "print.value" -> "false"), + "someKey\n"), + Arguments.of( + "print value", + consumerRecord(), + Map(), + "someValue\n"), + Arguments.of( + "print empty timestamp", + consumerRecord(timestampType = TimestampType.NO_TIMESTAMP_TYPE), + Map("print.timestamp" -> "true", + "print.value" -> "false"), + "NO_TIMESTAMP\n"), + Arguments.of( + "print log append time timestamp", + consumerRecord(timestampType = TimestampType.LOG_APPEND_TIME), + Map("print.timestamp" -> "true", + "print.value" -> "false"), + "LogAppendTime:1234\n"), + Arguments.of( + "print create time timestamp", + consumerRecord(timestampType = TimestampType.CREATE_TIME), + Map("print.timestamp" -> "true", + "print.value" -> "false"), + "CreateTime:1234\n"), + Arguments.of( + "print partition", + consumerRecord(), + Map("print.partition" -> "true", + "print.value" -> "false"), + "Partition:9\n"), + Arguments.of( + "print offset", + consumerRecord(), + Map("print.offset" -> "true", + "print.value" -> "false"), + "Offset:9876\n"), + Arguments.of( + "print headers", + consumerRecord(), + Map("print.headers" -> "true", + "print.value" -> "false"), + "h1:v1,h2:v2\n"), + Arguments.of( + "print empty headers", + consumerRecord(headers = Nil), + Map("print.headers" -> "true", + "print.value" -> "false"), + "NO_HEADERS\n"), + Arguments.of( + "print all possible fields with default delimiters", + consumerRecord(), + Map("print.key" -> "true", + "print.timestamp" -> "true", + "print.partition" -> "true", + "print.offset" -> "true", + "print.headers" -> "true", + "print.value" -> "true"), + "CreateTime:1234\tPartition:9\tOffset:9876\th1:v1,h2:v2\tsomeKey\tsomeValue\n"), + Arguments.of( + "print all possible fields with custom delimiters", + consumerRecord(), + Map("key.separator" -> "|", + "line.separator" -> "^", + "headers.separator" -> "#", + "print.key" -> "true", + "print.timestamp" -> "true", + "print.partition" -> "true", + "print.offset" -> "true", + "print.headers" -> "true", + "print.value" -> "true"), + 
"CreateTime:1234|Partition:9|Offset:9876|h1:v1#h2:v2|someKey|someValue^"), + Arguments.of( + "print key with custom deserializer", + consumerRecord(), + Map("print.key" -> "true", + "print.headers" -> "true", + "print.value" -> "true", + "key.deserializer" -> "kafka.tools.UpperCaseDeserializer"), + "h1:v1,h2:v2\tSOMEKEY\tsomeValue\n"), + Arguments.of( + "print value with custom deserializer", + consumerRecord(), + Map("print.key" -> "true", + "print.headers" -> "true", + "print.value" -> "true", + "value.deserializer" -> "kafka.tools.UpperCaseDeserializer"), + "h1:v1,h2:v2\tsomeKey\tSOMEVALUE\n"), + Arguments.of( + "print headers with custom deserializer", + consumerRecord(), + Map("print.key" -> "true", + "print.headers" -> "true", + "print.value" -> "true", + "headers.deserializer" -> "kafka.tools.UpperCaseDeserializer"), + "h1:V1,h2:V2\tsomeKey\tsomeValue\n"), + Arguments.of( + "print key and value", + consumerRecord(), + Map("print.key" -> "true", + "print.value" -> "true"), + "someKey\tsomeValue\n"), + Arguments.of( + "print fields in the beginning, middle and the end", + consumerRecord(), + Map("print.key" -> "true", + "print.value" -> "true", + "print.partition" -> "true"), + "Partition:9\tsomeKey\tsomeValue\n"), + Arguments.of( + "null value without custom null literal", + consumerRecord(value = null), + Map("print.key" -> "true"), + "someKey\tnull\n"), + Arguments.of( + "null value with custom null literal", + consumerRecord(value = null), + Map("print.key" -> "true", + "null.literal" -> "NULL"), + "someKey\tNULL\n"), + ).asJava.stream() + } + + private def buildFormatter(propsToSet: Map[String, String]): DefaultMessageFormatter = { + val formatter = new DefaultMessageFormatter() + formatter.configure(propsToSet.asJava) + formatter + } + + + private def header(key: String, value: String) = { + new RecordHeader(key, value.getBytes(StandardCharsets.UTF_8)) + } + + private def consumerRecord(key: String = "someKey", + value: String = "someValue", + headers: Iterable[Header] = Seq(header("h1", "v1"), header("h2", "v2")), + partition: Int = 9, + offset: Long = 9876, + timestamp: Long = 1234, + timestampType: TimestampType = TimestampType.CREATE_TIME) = { + new ConsumerRecord[Array[Byte], Array[Byte]]( + "someTopic", + partition, + offset, + timestamp, + timestampType, + 0, + 0, + if (key == null) null else key.getBytes(StandardCharsets.UTF_8), + if (value == null) null else value.getBytes(StandardCharsets.UTF_8), + new RecordHeaders(headers.asJava), + Optional.empty[Integer]) + } + + private def withResource[Resource <: Closeable, Result](resource: Resource)(handler: Resource => Result): Result = { + try { + handler(resource) + } finally { + resource.close() + } + } +} + +class UpperCaseDeserializer extends Deserializer[String] { + override def configure(configs: util.Map[String, _], isKey: Boolean): Unit = {} + override def deserialize(topic: String, data: Array[Byte]): String = new String(data, StandardCharsets.UTF_8).toUpperCase + override def close(): Unit = {} +} diff --git a/core/src/test/scala/kafka/tools/GetOffsetShellParsingTest.scala b/core/src/test/scala/kafka/tools/GetOffsetShellParsingTest.scala new file mode 100644 index 0000000..edfadea --- /dev/null +++ b/core/src/test/scala/kafka/tools/GetOffsetShellParsingTest.scala @@ -0,0 +1,207 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.tools + +import org.apache.kafka.common.PartitionInfo +import org.junit.jupiter.api.Assertions.{assertEquals, assertFalse, assertThrows, assertTrue} +import org.junit.jupiter.api.Test +import org.junit.jupiter.params.ParameterizedTest +import org.junit.jupiter.params.provider.ValueSource + +class GetOffsetShellParsingTest { + @ParameterizedTest + @ValueSource(booleans = Array(true, false)) + def testTopicPartitionFilterForTopicName(excludeInternal: Boolean): Unit = { + val filter = GetOffsetShell.createTopicPartitionFilterWithPatternList("test", excludeInternal) + assertTrue(filter.apply(partitionInfo("test", 0))) + assertTrue(filter.apply(partitionInfo("test", 1))) + assertFalse(filter.apply(partitionInfo("test1", 0))) + assertFalse(filter.apply(partitionInfo("__consumer_offsets", 0))) + } + + @ParameterizedTest + @ValueSource(booleans = Array(true, false)) + def testTopicPartitionFilterForInternalTopicName(excludeInternal: Boolean): Unit = { + val filter = GetOffsetShell.createTopicPartitionFilterWithPatternList("__consumer_offsets", excludeInternal) + assertEquals(!excludeInternal, filter.apply(partitionInfo("__consumer_offsets", 0))) + assertEquals(!excludeInternal, filter.apply(partitionInfo("__consumer_offsets", 1))) + assertFalse(filter.apply(partitionInfo("test1", 0))) + assertFalse(filter.apply(partitionInfo("test2", 0))) + } + + @ParameterizedTest + @ValueSource(booleans = Array(true, false)) + def testTopicPartitionFilterForTopicNameList(excludeInternal: Boolean): Unit = { + val filter = GetOffsetShell.createTopicPartitionFilterWithPatternList("test,test1,__consumer_offsets", excludeInternal) + assertTrue(filter.apply(partitionInfo("test", 0))) + assertTrue(filter.apply(partitionInfo("test1", 1))) + assertFalse(filter.apply(partitionInfo("test2", 0))) + + assertEquals(!excludeInternal, filter.apply(partitionInfo("__consumer_offsets", 0))) + } + + @ParameterizedTest + @ValueSource(booleans = Array(true, false)) + def testTopicPartitionFilterForRegex(excludeInternal: Boolean): Unit = { + val filter = GetOffsetShell.createTopicPartitionFilterWithPatternList("test.*", excludeInternal) + assertTrue(filter.apply(partitionInfo("test", 0))) + assertTrue(filter.apply(partitionInfo("test1", 1))) + assertTrue(filter.apply(partitionInfo("test2", 0))) + assertFalse(filter.apply(partitionInfo("__consumer_offsets", 0))) + } + + @ParameterizedTest + @ValueSource(booleans = Array(true, false)) + def testTopicPartitionFilterForPartitionIndexSpec(excludeInternal: Boolean): Unit = { + val filter = GetOffsetShell.createTopicPartitionFilterWithPatternList(":0", excludeInternal) + assertTrue(filter.apply(partitionInfo("test", 0))) + assertTrue(filter.apply(partitionInfo("test1", 0))) + assertFalse(filter.apply(partitionInfo("test2", 1))) + + assertEquals(!excludeInternal, filter.apply(partitionInfo("__consumer_offsets", 0))) + assertFalse(filter.apply(partitionInfo("__consumer_offsets", 1))) + } + + 
@ParameterizedTest + @ValueSource(booleans = Array(true, false)) + def testTopicPartitionFilterForPartitionRangeSpec(excludeInternal: Boolean): Unit = { + val filter = GetOffsetShell.createTopicPartitionFilterWithPatternList(":1-3", excludeInternal) + assertTrue(filter.apply(partitionInfo("test", 1))) + assertTrue(filter.apply(partitionInfo("test1", 2))) + assertFalse(filter.apply(partitionInfo("test2", 0))) + assertFalse(filter.apply(partitionInfo("test2", 3))) + + assertEquals(!excludeInternal, filter.apply(partitionInfo("__consumer_offsets", 2))) + assertFalse(filter.apply(partitionInfo("__consumer_offsets", 3))) + } + + @ParameterizedTest + @ValueSource(booleans = Array(true, false)) + def testTopicPartitionFilterForPartitionLowerBoundSpec(excludeInternal: Boolean): Unit = { + val filter = GetOffsetShell.createTopicPartitionFilterWithPatternList(":1-", excludeInternal) + assertTrue(filter.apply(partitionInfo("test", 1))) + assertTrue(filter.apply(partitionInfo("test1", 2))) + assertFalse(filter.apply(partitionInfo("test2", 0))) + + assertEquals(!excludeInternal, filter.apply(partitionInfo("__consumer_offsets", 2))) + assertFalse(filter.apply(partitionInfo("__consumer_offsets", 0))) + } + + @ParameterizedTest + @ValueSource(booleans = Array(true, false)) + def testTopicPartitionFilterForPartitionUpperBoundSpec(excludeInternal: Boolean): Unit = { + val filter = GetOffsetShell.createTopicPartitionFilterWithPatternList(":-3", excludeInternal) + assertTrue(filter.apply(partitionInfo("test", 0))) + assertTrue(filter.apply(partitionInfo("test1", 1))) + assertTrue(filter.apply(partitionInfo("test2", 2))) + assertFalse(filter.apply(partitionInfo("test3", 3))) + + assertEquals(!excludeInternal, filter.apply(partitionInfo("__consumer_offsets", 2))) + assertFalse(filter.apply(partitionInfo("__consumer_offsets", 3))) + } + + @ParameterizedTest + @ValueSource(booleans = Array(true, false)) + def testTopicPartitionFilterComplex(excludeInternal: Boolean): Unit = { + val filter = GetOffsetShell.createTopicPartitionFilterWithPatternList("test.*:0,__consumer_offsets:1-2,.*:3", excludeInternal) + assertTrue(filter.apply(partitionInfo("test", 0))) + assertTrue(filter.apply(partitionInfo("test", 3))) + assertFalse(filter.apply(partitionInfo("test", 1))) + + assertTrue(filter.apply(partitionInfo("test1", 0))) + assertTrue(filter.apply(partitionInfo("test1", 3))) + assertFalse(filter.apply(partitionInfo("test1", 1))) + + assertTrue(filter.apply(partitionInfo("custom", 3))) + assertFalse(filter.apply(partitionInfo("custom", 0))) + + assertEquals(!excludeInternal, filter.apply(partitionInfo("__consumer_offsets", 1))) + assertEquals(!excludeInternal, filter.apply(partitionInfo("__consumer_offsets", 3))) + assertFalse(filter.apply(partitionInfo("__consumer_offsets", 0))) + assertFalse(filter.apply(partitionInfo("__consumer_offsets", 2))) + } + + @Test + def testPartitionFilterForSingleIndex(): Unit = { + val filter = GetOffsetShell.createTopicPartitionFilterWithPatternList(":1", excludeInternalTopics = false) + assertTrue(filter.apply(partitionInfo("test", 1))) + assertFalse(filter.apply(partitionInfo("test", 0))) + assertFalse(filter.apply(partitionInfo("test", 2))) + } + + @Test + def testPartitionFilterForRange(): Unit = { + val filter = GetOffsetShell.createTopicPartitionFilterWithPatternList(":1-3", excludeInternalTopics = false) + assertFalse(filter.apply(partitionInfo("test", 0))) + assertTrue(filter.apply(partitionInfo("test", 1))) + assertTrue(filter.apply(partitionInfo("test", 2))) + 
assertFalse(filter.apply(partitionInfo("test", 3))) + assertFalse(filter.apply(partitionInfo("test", 4))) + assertFalse(filter.apply(partitionInfo("test", 5))) + } + + @Test + def testPartitionFilterForLowerBound(): Unit = { + val filter = GetOffsetShell.createTopicPartitionFilterWithPatternList(":3-", excludeInternalTopics = false) + assertFalse(filter.apply(partitionInfo("test", 0))) + assertFalse(filter.apply(partitionInfo("test", 1))) + assertFalse(filter.apply(partitionInfo("test", 2))) + assertTrue(filter.apply(partitionInfo("test", 3))) + assertTrue(filter.apply(partitionInfo("test", 4))) + assertTrue(filter.apply(partitionInfo("test", 5))) + } + + @Test + def testPartitionFilterForUpperBound(): Unit = { + val filter = GetOffsetShell.createTopicPartitionFilterWithPatternList(":-3", excludeInternalTopics = false) + assertTrue(filter.apply(partitionInfo("test", 0))) + assertTrue(filter.apply(partitionInfo("test", 1))) + assertTrue(filter.apply(partitionInfo("test", 2))) + assertFalse(filter.apply(partitionInfo("test", 3))) + assertFalse(filter.apply(partitionInfo("test", 4))) + assertFalse(filter.apply(partitionInfo("test", 5))) + } + + @Test + def testPartitionFilterForInvalidSingleIndex(): Unit = { + assertThrows(classOf[IllegalArgumentException], + () => GetOffsetShell.createTopicPartitionFilterWithPatternList(":a", excludeInternalTopics = false)) + } + + @Test + def testPartitionFilterForInvalidRange(): Unit = { + assertThrows(classOf[IllegalArgumentException], + () => GetOffsetShell.createTopicPartitionFilterWithPatternList(":a-b", excludeInternalTopics = false)) + } + + @Test + def testPartitionFilterForInvalidLowerBound(): Unit = { + assertThrows(classOf[IllegalArgumentException], + () => GetOffsetShell.createTopicPartitionFilterWithPatternList(":a-", excludeInternalTopics = false)) + } + + @Test + def testPartitionFilterForInvalidUpperBound(): Unit = { + assertThrows(classOf[IllegalArgumentException], + () => GetOffsetShell.createTopicPartitionFilterWithPatternList(":-b", excludeInternalTopics = false)) + } + + private def partitionInfo(topic: String, partition: Int): PartitionInfo = { + new PartitionInfo(topic, partition, null, null, null) + } +} diff --git a/core/src/test/scala/kafka/tools/GetOffsetShellTest.scala b/core/src/test/scala/kafka/tools/GetOffsetShellTest.scala new file mode 100644 index 0000000..796ed36 --- /dev/null +++ b/core/src/test/scala/kafka/tools/GetOffsetShellTest.scala @@ -0,0 +1,208 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package kafka.tools + +import java.util.Properties +import kafka.integration.KafkaServerTestHarness +import kafka.server.KafkaConfig +import kafka.utils.{Exit, Logging, TestUtils} +import org.apache.kafka.clients.CommonClientConfigs +import org.apache.kafka.clients.producer.{KafkaProducer, ProducerConfig, ProducerRecord} +import org.apache.kafka.common.serialization.StringSerializer +import org.junit.jupiter.api.Assertions.assertEquals +import org.junit.jupiter.api.{BeforeEach, Test, TestInfo} + +class GetOffsetShellTest extends KafkaServerTestHarness with Logging { + private val topicCount = 4 + private val offsetTopicPartitionCount = 4 + + override def generateConfigs: collection.Seq[KafkaConfig] = TestUtils.createBrokerConfigs(1, zkConnect) + .map { p => + p.put(KafkaConfig.OffsetsTopicPartitionsProp, Int.box(offsetTopicPartitionCount)) + p + }.map(KafkaConfig.fromProps) + + @BeforeEach + override def setUp(testInfo: TestInfo): Unit = { + super.setUp(testInfo) + Range(1, topicCount + 1).foreach(i => createTopic(topicName(i), i)) + + val props = new Properties() + props.put(CommonClientConfigs.BOOTSTRAP_SERVERS_CONFIG, brokerList) + props.put(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG, classOf[StringSerializer]) + props.put(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG, classOf[StringSerializer]) + + // Send X messages to each partition of topicX + val producer = new KafkaProducer[String, String](props) + Range(1, topicCount + 1).foreach(i => Range(0, i*i) + .foreach(msgCount => producer.send(new ProducerRecord[String, String](topicName(i), msgCount % i, null, "val" + msgCount)))) + producer.close() + + TestUtils.createOffsetsTopic(zkClient, servers) + } + + @Test + def testNoFilterOptions(): Unit = { + val offsets = executeAndParse(Array()) + assertEquals(expectedOffsetsWithInternal(), offsets) + } + + @Test + def testInternalExcluded(): Unit = { + val offsets = executeAndParse(Array("--exclude-internal-topics")) + assertEquals(expectedTestTopicOffsets(), offsets) + } + + @Test + def testTopicNameArg(): Unit = { + Range(1, topicCount + 1).foreach(i => { + val offsets = executeAndParse(Array("--topic", topicName(i))) + assertEquals(expectedOffsetsForTopic(i), offsets, () => "Offset output did not match for " + topicName(i)) + }) + } + + @Test + def testTopicPatternArg(): Unit = { + val offsets = executeAndParse(Array("--topic", "topic.*")) + assertEquals(expectedTestTopicOffsets(), offsets) + } + + @Test + def testPartitionsArg(): Unit = { + val offsets = executeAndParse(Array("--partitions", "0,1")) + assertEquals(expectedOffsetsWithInternal().filter { case (_, partition, _) => partition <= 1 }, offsets) + } + + @Test + def testTopicPatternArgWithPartitionsArg(): Unit = { + val offsets = executeAndParse(Array("--topic", "topic.*", "--partitions", "0,1")) + assertEquals(expectedTestTopicOffsets().filter { case (_, partition, _) => partition <= 1 }, offsets) + } + + @Test + def testTopicPartitionsArg(): Unit = { + val offsets = executeAndParse(Array("--topic-partitions", "topic1:0,topic2:1,topic(3|4):2,__.*:3")) + assertEquals( + List( + ("__consumer_offsets", 3, Some(0)), + ("topic1", 0, Some(1)), + ("topic2", 1, Some(2)), + ("topic3", 2, Some(3)), + ("topic4", 2, Some(4)) + ), + offsets + ) + } + + @Test + def testTopicPartitionsArgWithInternalExcluded(): Unit = { + val offsets = executeAndParse(Array("--topic-partitions", + "topic1:0,topic2:1,topic(3|4):2,__.*:3", "--exclude-internal-topics")) + assertEquals( + List( + ("topic1", 0, Some(1)), + ("topic2", 1, Some(2)), + 
("topic3", 2, Some(3)), + ("topic4", 2, Some(4)) + ), + offsets + ) + } + + @Test + def testTopicPartitionsNotFoundForNonExistentTopic(): Unit = { + assertExitCodeIsOne(Array("--topic", "some_nonexistent_topic")) + } + + @Test + def testTopicPartitionsNotFoundForExcludedInternalTopic(): Unit = { + assertExitCodeIsOne(Array("--topic", "some_nonexistent_topic:*")) + } + + @Test + def testTopicPartitionsNotFoundForNonMatchingTopicPartitionPattern(): Unit = { + assertExitCodeIsOne(Array("--topic-partitions", "__consumer_offsets", "--exclude-internal-topics")) + } + + @Test + def testTopicPartitionsFlagWithTopicFlagCauseExit(): Unit = { + assertExitCodeIsOne(Array("--topic-partitions", "__consumer_offsets", "--topic", "topic1")) + } + + @Test + def testTopicPartitionsFlagWithPartitionsFlagCauseExit(): Unit = { + assertExitCodeIsOne(Array("--topic-partitions", "__consumer_offsets", "--partitions", "0")) + } + + private def expectedOffsetsWithInternal(): List[(String, Int, Option[Long])] = { + Range(0, offsetTopicPartitionCount).map(i => ("__consumer_offsets", i, Some(0L))).toList ++ expectedTestTopicOffsets() + } + + private def expectedTestTopicOffsets(): List[(String, Int, Option[Long])] = { + Range(1, topicCount + 1).flatMap(i => expectedOffsetsForTopic(i)).toList + } + + private def expectedOffsetsForTopic(i: Int): List[(String, Int, Option[Long])] = { + val name = topicName(i) + Range(0, i).map(p => (name, p, Some(i.toLong))).toList + } + + private def topicName(i: Int): String = "topic" + i + + private def assertExitCodeIsOne(args: Array[String]): Unit = { + var exitStatus: Option[Int] = None + Exit.setExitProcedure { (status, _) => + exitStatus = Some(status) + throw new RuntimeException + } + + try { + GetOffsetShell.main(addBootstrapServer(args)) + } catch { + case e: RuntimeException => + } finally { + Exit.resetExitProcedure() + } + + assertEquals(Some(1), exitStatus) + } + + private def executeAndParse(args: Array[String]): List[(String, Int, Option[Long])] = { + val output = executeAndGrabOutput(args) + output.split(System.lineSeparator()) + .map(_.split(":")) + .filter(_.length >= 2) + .map { line => + val topic = line(0) + val partition = line(1).toInt + val timestamp = if (line.length == 2 || line(2).isEmpty) None else Some(line(2).toLong) + (topic, partition, timestamp) + } + .toList + } + + private def executeAndGrabOutput(args: Array[String]): String = { + TestUtils.grabConsoleOutput(GetOffsetShell.main(addBootstrapServer(args))) + } + + private def addBootstrapServer(args: Array[String]): Array[String] = { + args ++ Array("--bootstrap-server", brokerList) + } +} + + diff --git a/core/src/test/scala/kafka/tools/LogCompactionTester.scala b/core/src/test/scala/kafka/tools/LogCompactionTester.scala new file mode 100755 index 0000000..da8a3c0 --- /dev/null +++ b/core/src/test/scala/kafka/tools/LogCompactionTester.scala @@ -0,0 +1,348 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.tools + +import java.io._ +import java.nio.ByteBuffer +import java.nio.charset.StandardCharsets.UTF_8 +import java.nio.file.{Files, Path} +import java.time.Duration +import java.util.{Properties, Random} + +import joptsimple.OptionParser +import kafka.utils._ +import org.apache.kafka.clients.admin.{Admin, NewTopic} +import org.apache.kafka.clients.CommonClientConfigs +import org.apache.kafka.clients.consumer.{ConsumerConfig, KafkaConsumer} +import org.apache.kafka.clients.producer.{KafkaProducer, ProducerConfig, ProducerRecord} +import org.apache.kafka.common.config.TopicConfig +import org.apache.kafka.common.serialization.{ByteArraySerializer, StringDeserializer} +import org.apache.kafka.common.utils.{AbstractIterator, Utils} + +import scala.jdk.CollectionConverters._ +
+/** + * This is a torture test that runs against an existing broker. + * + * Here is how it works: + * + * It produces a series of specially formatted messages to one or more partitions. Each message it produces is also + * logged to a text file. The messages have a limited set of keys, so there is duplication in the key space. + * + * The broker will clean its log as the test runs. + * + * When the specified number of messages has been produced, we create a consumer, consume all the messages in the topics, + * and write them out to another text file. + * + * Using a stable unix sort, we sort both the producer log of what was sent and the consumer log of what was retrieved by message key. + * Then we compare the final message in both logs for each key. If this final message is not the same for all keys, we + * print an error and exit with exit code 1; otherwise we print the size reduction and exit with exit code 0.
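+ *
+ * An illustrative set of arguments (the values here are examples only):
+ *   --bootstrap-server localhost:9092 --messages 1000000 --topics 2 --duplicates 5 --percent-deletes 10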
+ */ +object LogCompactionTester { + + //maximum line size while reading produced/consumed record text file + private val ReadAheadLimit = 4906 + + def main(args: Array[String]): Unit = { + val parser = new OptionParser(false) + val numMessagesOpt = parser.accepts("messages", "The number of messages to send or consume.") + .withRequiredArg + .describedAs("count") + .ofType(classOf[java.lang.Long]) + .defaultsTo(Long.MaxValue) + val messageCompressionOpt = parser.accepts("compression-type", "message compression type") + .withOptionalArg + .describedAs("compressionType") + .ofType(classOf[java.lang.String]) + .defaultsTo("none") + val numDupsOpt = parser.accepts("duplicates", "The number of duplicates for each key.") + .withRequiredArg + .describedAs("count") + .ofType(classOf[java.lang.Integer]) + .defaultsTo(5) + val brokerOpt = parser.accepts("bootstrap-server", "The server(s) to connect to.") + .withRequiredArg + .describedAs("url") + .ofType(classOf[String]) + val topicsOpt = parser.accepts("topics", "The number of topics to test.") + .withRequiredArg + .describedAs("count") + .ofType(classOf[java.lang.Integer]) + .defaultsTo(1) + val percentDeletesOpt = parser.accepts("percent-deletes", "The percentage of updates that are deletes.") + .withRequiredArg + .describedAs("percent") + .ofType(classOf[java.lang.Integer]) + .defaultsTo(0) + val sleepSecsOpt = parser.accepts("sleep", "Time in milliseconds to sleep between production and consumption.") + .withRequiredArg + .describedAs("ms") + .ofType(classOf[java.lang.Integer]) + .defaultsTo(0) + + val options = parser.parse(args: _*) + + if (args.length == 0) + CommandLineUtils.printUsageAndDie(parser, "A tool to test log compaction. Valid options are: ") + + CommandLineUtils.checkRequiredArgs(parser, options, brokerOpt, numMessagesOpt) + + // parse options + val messages = options.valueOf(numMessagesOpt).longValue + val compressionType = options.valueOf(messageCompressionOpt) + val percentDeletes = options.valueOf(percentDeletesOpt).intValue + val dups = options.valueOf(numDupsOpt).intValue + val brokerUrl = options.valueOf(brokerOpt) + val topicCount = options.valueOf(topicsOpt).intValue + val sleepSecs = options.valueOf(sleepSecsOpt).intValue + + val testId = new Random().nextLong + val topics = (0 until topicCount).map("log-cleaner-test-" + testId + "-" + _).toArray + createTopics(brokerUrl, topics.toSeq) + + println(s"Producing $messages messages..to topics ${topics.mkString(",")}") + val producedDataFilePath = produceMessages(brokerUrl, topics, messages, compressionType, dups, percentDeletes) + println(s"Sleeping for $sleepSecs seconds...") + Thread.sleep(sleepSecs * 1000) + println("Consuming messages...") + val consumedDataFilePath = consumeMessages(brokerUrl, topics) + + val producedLines = lineCount(producedDataFilePath) + val consumedLines = lineCount(consumedDataFilePath) + val reduction = 100 * (1.0 - consumedLines.toDouble / producedLines.toDouble) + println(f"$producedLines%d rows of data produced, $consumedLines%d rows of data consumed ($reduction%.1f%% reduction).") + + println("De-duplicating and validating output files...") + validateOutput(producedDataFilePath.toFile, consumedDataFilePath.toFile) + Utils.delete(producedDataFilePath.toFile) + Utils.delete(consumedDataFilePath.toFile) + //if you change this line, we need to update test_log_compaction_tool.py system test + println("Data verification is completed") + } + + def createTopics(brokerUrl: String, topics: Seq[String]): Unit = { + val adminConfig = new Properties + 
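// Each topic is created with cleanup.policy=compact so the broker's log cleaner compacts it while the test runs +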
adminConfig.put(CommonClientConfigs.BOOTSTRAP_SERVERS_CONFIG, brokerUrl) + val adminClient = Admin.create(adminConfig) + + try { + val topicConfigs = Map(TopicConfig.CLEANUP_POLICY_CONFIG -> TopicConfig.CLEANUP_POLICY_COMPACT) + val newTopics = topics.map(name => new NewTopic(name, 1, 1.toShort).configs(topicConfigs.asJava)).asJava + adminClient.createTopics(newTopics).all.get + + var pendingTopics: Seq[String] = Seq() + TestUtils.waitUntilTrue(() => { + val allTopics = adminClient.listTopics.names.get.asScala.toSeq + pendingTopics = topics.filter(topicName => !allTopics.contains(topicName)) + pendingTopics.isEmpty + }, s"timed out waiting for topics : $pendingTopics") + + } finally adminClient.close() + } + + def lineCount(filPath: Path): Int = Files.readAllLines(filPath).size + + def validateOutput(producedDataFile: File, consumedDataFile: File): Unit = { + val producedReader = externalSort(producedDataFile) + val consumedReader = externalSort(consumedDataFile) + val produced = valuesIterator(producedReader) + val consumed = valuesIterator(consumedReader) + + val producedDedupedFile = new File(producedDataFile.getAbsolutePath + ".deduped") + val producedDeduped : BufferedWriter = Files.newBufferedWriter(producedDedupedFile.toPath, UTF_8) + + val consumedDedupedFile = new File(consumedDataFile.getAbsolutePath + ".deduped") + val consumedDeduped : BufferedWriter = Files.newBufferedWriter(consumedDedupedFile.toPath, UTF_8) + var total = 0 + var mismatched = 0 + while (produced.hasNext && consumed.hasNext) { + val p = produced.next() + producedDeduped.write(p.toString) + producedDeduped.newLine() + val c = consumed.next() + consumedDeduped.write(c.toString) + consumedDeduped.newLine() + if (p != c) + mismatched += 1 + total += 1 + } + producedDeduped.close() + consumedDeduped.close() + println(s"Validated $total values, $mismatched mismatches.") + require(!produced.hasNext, "Additional values produced not found in consumer log.") + require(!consumed.hasNext, "Additional values consumed not found in producer log.") + require(mismatched == 0, "Non-zero number of row mismatches.") + // if all the checks worked out we can delete the deduped files + Utils.delete(producedDedupedFile) + Utils.delete(consumedDedupedFile) + } + + def require(requirement: Boolean, message: => Any): Unit = { + if (!requirement) { + System.err.println(s"Data validation failed : $message") + Exit.exit(1) + } + } + + def valuesIterator(reader: BufferedReader): Iterator[TestRecord] = { + new AbstractIterator[TestRecord] { + def makeNext(): TestRecord = { + var next = readNext(reader) + while (next != null && next.delete) + next = readNext(reader) + if (next == null) + allDone() + else + next + } + }.asScala + } + + def readNext(reader: BufferedReader): TestRecord = { + var line = reader.readLine() + if (line == null) + return null + var curr = TestRecord.parse(line) + while (true) { + line = peekLine(reader) + if (line == null) + return curr + val next = TestRecord.parse(line) + if (next == null || next.topicAndKey != curr.topicAndKey) + return curr + curr = next + reader.readLine() + } + null + } + + def peekLine(reader: BufferedReader) = { + reader.mark(ReadAheadLimit) + val line = reader.readLine + reader.reset() + line + } + + def externalSort(file: File): BufferedReader = { + val builder = new ProcessBuilder("sort", "--key=1,2", "--stable", "--buffer-size=20%", "--temporary-directory=" + Files.createTempDirectory("log_compaction_test"), file.getAbsolutePath) + val process = builder.start + new Thread() { + override 
def run(): Unit = { + val exitCode = process.waitFor() + if (exitCode != 0) { + System.err.println("Process exited abnormally.") + while (process.getErrorStream.available > 0) { + System.err.write(process.getErrorStream().read()) + } + } + } + }.start() + new BufferedReader(new InputStreamReader(process.getInputStream(), UTF_8), 10 * 1024 * 1024) + } + + def produceMessages(brokerUrl: String, + topics: Array[String], + messages: Long, + compressionType: String, + dups: Int, + percentDeletes: Int): Path = { + val producerProps = new Properties + producerProps.setProperty(ProducerConfig.MAX_BLOCK_MS_CONFIG, Long.MaxValue.toString) + producerProps.setProperty(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, brokerUrl) + producerProps.setProperty(ProducerConfig.COMPRESSION_TYPE_CONFIG, compressionType) + val producer = new KafkaProducer(producerProps, new ByteArraySerializer, new ByteArraySerializer) + try { + val rand = new Random(1) + val keyCount = (messages / dups).toInt + val producedFilePath = Files.createTempFile("kafka-log-cleaner-produced-", ".txt") + println(s"Logging produce requests to $producedFilePath") + val producedWriter: BufferedWriter = Files.newBufferedWriter(producedFilePath, UTF_8) + for (i <- 0L until (messages * topics.length)) { + val topic = topics((i % topics.length).toInt) + val key = rand.nextInt(keyCount) + val delete = (i % 100) < percentDeletes + val msg = + if (delete) + new ProducerRecord[Array[Byte], Array[Byte]](topic, key.toString.getBytes(UTF_8), null) + else + new ProducerRecord(topic, key.toString.getBytes(UTF_8), i.toString.getBytes(UTF_8)) + producer.send(msg) + producedWriter.write(TestRecord(topic, key, i, delete).toString) + producedWriter.newLine() + } + producedWriter.close() + producedFilePath + } finally { + producer.close() + } + } + + def createConsumer(brokerUrl: String): KafkaConsumer[String, String] = { + val consumerProps = new Properties + consumerProps.setProperty(ConsumerConfig.GROUP_ID_CONFIG, "log-cleaner-test-" + new Random().nextInt(Int.MaxValue)) + consumerProps.setProperty(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, brokerUrl) + consumerProps.setProperty(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "earliest") + new KafkaConsumer(consumerProps, new StringDeserializer, new StringDeserializer) + } + + def consumeMessages(brokerUrl: String, topics: Array[String]): Path = { + val consumer = createConsumer(brokerUrl) + consumer.subscribe(topics.toSeq.asJava) + val consumedFilePath = Files.createTempFile("kafka-log-cleaner-consumed-", ".txt") + println(s"Logging consumed messages to $consumedFilePath") + val consumedWriter: BufferedWriter = Files.newBufferedWriter(consumedFilePath, UTF_8) + + try { + var done = false + while (!done) { + val consumerRecords = consumer.poll(Duration.ofSeconds(20)) + if (!consumerRecords.isEmpty) { + for (record <- consumerRecords.asScala) { + val delete = record.value == null + val value = if (delete) -1L else record.value.toLong + consumedWriter.write(TestRecord(record.topic, record.key.toInt, value, delete).toString) + consumedWriter.newLine + } + } else { + done = true + } + } + consumedFilePath + } finally { + consumedWriter.close() + consumer.close() + } + } + + def readString(buffer: ByteBuffer): String = { + Utils.utf8(buffer) + } + +} + +case class TestRecord(topic: String, key: Int, value: Long, delete: Boolean) { + override def toString = topic + "\t" + key + "\t" + value + "\t" + (if (delete) "d" else "u") + def topicAndKey = topic + key +} + +object TestRecord { + def parse(line: String): TestRecord = { + 
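// Lines are tab-separated as written by TestRecord.toString: topic, key, value, then "d" for a delete or "u" for an update +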
val components = line.split("\t") + new TestRecord(components(0), components(1).toInt, components(2).toLong, components(3) == "d") + } +} diff --git a/core/src/test/scala/kafka/tools/ReplicaVerificationToolTest.scala b/core/src/test/scala/kafka/tools/ReplicaVerificationToolTest.scala new file mode 100644 index 0000000..2172604 --- /dev/null +++ b/core/src/test/scala/kafka/tools/ReplicaVerificationToolTest.scala @@ -0,0 +1,65 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.tools + +import org.apache.kafka.common.TopicPartition +import org.apache.kafka.common.message.FetchResponseData +import org.apache.kafka.common.record.{CompressionType, MemoryRecords, SimpleRecord} +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.Assertions.assertTrue + +class ReplicaVerificationToolTest { + + @Test + def testReplicaBufferVerifyChecksum(): Unit = { + val sb = new StringBuilder + + val expectedReplicasPerTopicAndPartition = Map( + new TopicPartition("a", 0) -> 3, + new TopicPartition("a", 1) -> 3, + new TopicPartition("b", 0) -> 2 + ) + + val replicaBuffer = new ReplicaBuffer(expectedReplicasPerTopicAndPartition, Map.empty, 2, 0) + expectedReplicasPerTopicAndPartition.foreach { case (tp, numReplicas) => + (0 until numReplicas).foreach { replicaId => + val records = (0 to 5).map { index => + new SimpleRecord(s"key $index".getBytes, s"value $index".getBytes) + } + val initialOffset = 4 + val memoryRecords = MemoryRecords.withRecords(initialOffset, CompressionType.NONE, records: _*) + val partitionData = new FetchResponseData.PartitionData() + .setPartitionIndex(tp.partition) + .setHighWatermark(20) + .setLastStableOffset(20) + .setLogStartOffset(0) + .setRecords(memoryRecords) + + replicaBuffer.addFetchedData(tp, replicaId, partitionData) + } + } + + replicaBuffer.verifyCheckSum(line => sb.append(s"$line\n")) + val output = sb.toString.trim + + // If you change this assertion, you should verify that the replica_verification_test.py system test still passes + assertTrue(output.endsWith(": max lag is 10 for partition a-1 at offset 10 among 3 partitions"), + s"Max lag information should be in output: `$output`") + } + +} diff --git a/core/src/test/scala/kafka/utils/ExitTest.scala b/core/src/test/scala/kafka/utils/ExitTest.scala new file mode 100644 index 0000000..fcb2e9a --- /dev/null +++ b/core/src/test/scala/kafka/utils/ExitTest.scala @@ -0,0 +1,121 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.utils + +import java.io.IOException + +import org.junit.jupiter.api.Assertions.{assertEquals, assertThrows} +import org.junit.jupiter.api.Test + +class ExitTest { + @Test + def shouldHaltImmediately(): Unit = { + val array:Array[Any] = Array("a", "b") + def haltProcedure(exitStatus: Int, message: Option[String]) : Nothing = { + array(0) = exitStatus + array(1) = message + throw new IOException() + } + Exit.setHaltProcedure(haltProcedure) + val statusCode = 0 + val message = Some("message") + try { + assertThrows(classOf[IOException], () => Exit.halt(statusCode)) + assertEquals(statusCode, array(0)) + assertEquals(None, array(1)) + + assertThrows(classOf[IOException], () => Exit.halt(statusCode, message)) + assertEquals(statusCode, array(0)) + assertEquals(message, array(1)) + } finally { + Exit.resetHaltProcedure() + } + } + + @Test + def shouldExitImmediately(): Unit = { + val array:Array[Any] = Array("a", "b") + def exitProcedure(exitStatus: Int, message: Option[String]) : Nothing = { + array(0) = exitStatus + array(1) = message + throw new IOException() + } + Exit.setExitProcedure(exitProcedure) + val statusCode = 0 + val message = Some("message") + try { + assertThrows(classOf[IOException], () => Exit.exit(statusCode)) + assertEquals(statusCode, array(0)) + assertEquals(None, array(1)) + + assertThrows(classOf[IOException], () => Exit.exit(statusCode, message)) + assertEquals(statusCode, array(0)) + assertEquals(message, array(1)) + } finally { + Exit.resetExitProcedure() + } + } + + @Test + def shouldAddShutdownHookImmediately(): Unit = { + val name = "name" + val array:Array[Any] = Array("", 0) + // immediately invoke the shutdown hook to mutate the data when a hook is added + def shutdownHookAdder(name: String, shutdownHook: => Unit) : Unit = { + // mutate the first element + array(0) = array(0).toString + name + // invoke the shutdown hook (see below, it mutates the second element) + shutdownHook + } + Exit.setShutdownHookAdder(shutdownHookAdder) + def sideEffect(): Unit = { + // mutate the second element + array(1) = array(1).asInstanceOf[Int] + 1 + } + try { + Exit.addShutdownHook(name, sideEffect()) // by-name parameter, only invoked due to above shutdownHookAdder + assertEquals(1, array(1)) + assertEquals(name * array(1).asInstanceOf[Int], array(0).toString) + Exit.addShutdownHook(name, array(1) = array(1).asInstanceOf[Int] + 1) // by-name parameter, only invoked due to above shutdownHookAdder + assertEquals(2, array(1)) + assertEquals(name * array(1).asInstanceOf[Int], array(0).toString) + } finally { + Exit.resetShutdownHookAdder() + } + } + + @Test + def shouldNotInvokeShutdownHookImmediately(): Unit = { + val name = "name" + val array:Array[String] = Array(name) + + def sideEffect(): Unit = { + // mutate the first element + array(0) = array(0) + name + } + Exit.addShutdownHook(name, sideEffect()) // by-name parameter, not invoked + // make sure the first element wasn't mutated 
+ assertEquals(name, array(0)) + Exit.addShutdownHook(name, sideEffect()) // by-name parameter, not invoked + // again make sure the first element wasn't mutated + assertEquals(name, array(0)) + Exit.addShutdownHook(name, array(0) = array(0) + name) // by-name parameter, not invoked + // again make sure the first element wasn't mutated + assertEquals(name, array(0)) + } +} diff --git a/core/src/test/scala/kafka/utils/LoggingTest.scala b/core/src/test/scala/kafka/utils/LoggingTest.scala new file mode 100644 index 0000000..5091fd6 --- /dev/null +++ b/core/src/test/scala/kafka/utils/LoggingTest.scala @@ -0,0 +1,88 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.utils + +import java.lang.management.ManagementFactory + +import javax.management.ObjectName +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.Assertions.{assertEquals, assertTrue} +import org.slf4j.LoggerFactory + + +class LoggingTest extends Logging { + + @Test + def testTypeOfGetLoggers(): Unit = { + val log4jController = new Log4jController + // the return object of getLoggers must be a collection instance from java standard library. + // That enables mbean client to deserialize it without extra libraries. 
+ assertEquals(classOf[java.util.ArrayList[String]], log4jController.getLoggers.getClass) + } + + @Test + def testLog4jControllerIsRegistered(): Unit = { + val mbs = ManagementFactory.getPlatformMBeanServer() + val log4jControllerName = ObjectName.getInstance("kafka:type=kafka.Log4jController") + assertTrue(mbs.isRegistered(log4jControllerName), "kafka.utils.Log4jController is not registered") + val instance = mbs.getObjectInstance(log4jControllerName) + assertEquals("kafka.utils.Log4jController", instance.getClassName) + } + + @Test + def testLogNameOverride(): Unit = { + class TestLogging(overriddenLogName: String) extends Logging { + // Expose logger + def log = logger + override def loggerName = overriddenLogName + } + val overriddenLogName = "OverriddenLogName" + val logging = new TestLogging(overriddenLogName) + + assertEquals(overriddenLogName, logging.log.underlying.getName) + } + + @Test + def testLogName(): Unit = { + class TestLogging extends Logging { + // Expose logger + def log = logger + } + val logging = new TestLogging + + assertEquals(logging.getClass.getName, logging.log.underlying.getName) + } + + @Test + def testLoggerLevelIsResolved(): Unit = { + val controller = new Log4jController() + val previousLevel = controller.getLogLevel("kafka") + try { + controller.setLogLevel("kafka", "TRACE") + // Do some logging so that the Logger is created within the hierarchy + // (until loggers are used only loggers in the config file exist) + LoggerFactory.getLogger("kafka.utils.Log4jControllerTest").trace("test") + assertEquals("TRACE", controller.getLogLevel("kafka")) + assertEquals("TRACE", controller.getLogLevel("kafka.utils.Log4jControllerTest")) + assertTrue(controller.getLoggers.contains("kafka=TRACE")) + assertTrue(controller.getLoggers.contains("kafka.utils.Log4jControllerTest=TRACE")) + } finally { + controller.setLogLevel("kafka", previousLevel) + } + } +} diff --git a/core/src/test/scala/kafka/utils/TestInfoUtils.scala b/core/src/test/scala/kafka/utils/TestInfoUtils.scala new file mode 100644 index 0000000..ecd656e --- /dev/null +++ b/core/src/test/scala/kafka/utils/TestInfoUtils.scala @@ -0,0 +1,46 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package kafka.utils + +import java.lang.reflect.Method +import java.util +import java.util.{Collections, Optional} + +import org.junit.jupiter.api.TestInfo + +class EmptyTestInfo extends TestInfo { + override def getDisplayName: String = "" + override def getTags: util.Set[String] = Collections.emptySet() + override def getTestClass: (Optional[Class[_]]) = Optional.empty() + override def getTestMethod: Optional[Method] = Optional.empty() +} + +object TestInfoUtils { + def isKRaft(testInfo: TestInfo): Boolean = { + if (testInfo.getDisplayName().contains("quorum=")) { + if (testInfo.getDisplayName().contains("quorum=kraft")) { + true + } else if (testInfo.getDisplayName().contains("quorum=zk")) { + false + } else { + throw new RuntimeException(s"Unknown quorum value") + } + } else { + false + } + } +} diff --git a/core/src/test/scala/kafka/utils/ToolsUtilsTest.scala b/core/src/test/scala/kafka/utils/ToolsUtilsTest.scala new file mode 100644 index 0000000..8549136 --- /dev/null +++ b/core/src/test/scala/kafka/utils/ToolsUtilsTest.scala @@ -0,0 +1,45 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package kafka.utils + +import java.io.ByteArrayOutputStream +import java.util.Collections + +import org.apache.kafka.common.MetricName +import org.apache.kafka.common.metrics.Metrics +import org.apache.kafka.common.metrics.internals.IntGaugeSuite +import org.junit.jupiter.api.Assertions.assertTrue +import org.junit.jupiter.api.Test +import org.slf4j.LoggerFactory + +import scala.jdk.CollectionConverters._ + +class ToolsUtilsTest { + private val log = LoggerFactory.getLogger(classOf[ToolsUtilsTest]) + + @Test def testIntegerMetric(): Unit = { + val outContent = new ByteArrayOutputStream() + val metrics = new Metrics + val suite = new IntGaugeSuite[String](log, "example", metrics, (k: String) => new MetricName(k + "-bar", "test", "A test metric", Collections.singletonMap("key", "value")), 1) + suite.increment("foo") + Console.withOut(outContent) { + ToolsUtils.printMetrics(metrics.metrics.asScala) + assertTrue(outContent.toString.split("\n").exists(line => line.trim.matches("^test:foo-bar:\\{key=value\\} : 1$"))) + } + } + +} diff --git a/core/src/test/scala/kafka/zk/ExtendedAclStoreTest.scala b/core/src/test/scala/kafka/zk/ExtendedAclStoreTest.scala new file mode 100644 index 0000000..1044e42 --- /dev/null +++ b/core/src/test/scala/kafka/zk/ExtendedAclStoreTest.scala @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.zk + +import org.apache.kafka.common.resource.PatternType.{LITERAL, PREFIXED} +import org.apache.kafka.common.resource.ResourcePattern +import org.apache.kafka.common.resource.ResourceType.TOPIC +import org.junit.jupiter.api.Assertions.{assertEquals, assertThrows} +import org.junit.jupiter.api.Test + +class ExtendedAclStoreTest { + private val literalResource = new ResourcePattern(TOPIC, "some-topic", LITERAL) + private val prefixedResource = new ResourcePattern(TOPIC, "some-topic", PREFIXED) + private val store = new ExtendedAclStore(PREFIXED) + + @Test + def shouldHaveCorrectPaths(): Unit = { + assertEquals("/kafka-acl-extended/prefixed", store.aclPath) + assertEquals("/kafka-acl-extended/prefixed/Topic", store.path(TOPIC)) + assertEquals("/kafka-acl-extended-changes", store.changeStore.aclChangePath) + } + + @Test + def shouldHaveCorrectPatternType(): Unit = { + assertEquals(PREFIXED, store.patternType) + } + + @Test + def shouldThrowIfConstructedWithLiteral(): Unit = { + assertThrows(classOf[IllegalArgumentException], () => new ExtendedAclStore(LITERAL)) + } + + @Test + def shouldThrowFromEncodeOnLiteral(): Unit = { + assertThrows(classOf[IllegalArgumentException], () => store.changeStore.createChangeNode(literalResource)) + } + + @Test + def shouldWriteChangesToTheWritePath(): Unit = { + val changeNode = store.changeStore.createChangeNode(prefixedResource) + + assertEquals("/kafka-acl-extended-changes/acl_changes_", changeNode.path) + } + + @Test + def shouldRoundTripChangeNode(): Unit = { + val changeNode = store.changeStore.createChangeNode(prefixedResource) + + val actual = store.changeStore.decode(changeNode.bytes) + + assertEquals(prefixedResource, actual) + } +} \ No newline at end of file diff --git a/core/src/test/scala/kafka/zk/FeatureZNodeTest.scala b/core/src/test/scala/kafka/zk/FeatureZNodeTest.scala new file mode 100644 index 0000000..9344724 --- /dev/null +++ b/core/src/test/scala/kafka/zk/FeatureZNodeTest.scala @@ -0,0 +1,102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package kafka.zk + +import java.nio.charset.StandardCharsets + +import org.apache.kafka.common.feature.{Features, FinalizedVersionRange} +import org.apache.kafka.common.feature.Features._ +import org.junit.jupiter.api.Assertions.{assertEquals, assertThrows} +import org.junit.jupiter.api.Test + +import scala.jdk.CollectionConverters._ + +class FeatureZNodeTest { + + @Test + def testEncodeDecode(): Unit = { + val featureZNode = FeatureZNode( + FeatureZNodeStatus.Enabled, + Features.finalizedFeatures( + Map[String, FinalizedVersionRange]( + "feature1" -> new FinalizedVersionRange(1, 2), + "feature2" -> new FinalizedVersionRange(2, 4)).asJava)) + val decoded = FeatureZNode.decode(FeatureZNode.encode(featureZNode)) + assertEquals(featureZNode, decoded) + } + + @Test + def testDecodeSuccess(): Unit = { + val featureZNodeStrTemplate = """{ + "version":1, + "status":1, + "features":%s + }""" + + val validFeatures = """{"feature1": {"min_version_level": 1, "max_version_level": 2}, "feature2": {"min_version_level": 2, "max_version_level": 4}}""" + val node1 = FeatureZNode.decode(featureZNodeStrTemplate.format(validFeatures).getBytes(StandardCharsets.UTF_8)) + assertEquals(FeatureZNodeStatus.Enabled, node1.status) + assertEquals( + Features.finalizedFeatures( + Map[String, FinalizedVersionRange]( + "feature1" -> new FinalizedVersionRange(1, 2), + "feature2" -> new FinalizedVersionRange(2, 4)).asJava), node1.features) + + val emptyFeatures = "{}" + val node2 = FeatureZNode.decode(featureZNodeStrTemplate.format(emptyFeatures).getBytes(StandardCharsets.UTF_8)) + assertEquals(FeatureZNodeStatus.Enabled, node2.status) + assertEquals(emptyFinalizedFeatures, node2.features) + } + + @Test + def testDecodeFailOnInvalidVersionAndStatus(): Unit = { + val featureZNodeStrTemplate = + """{ + "version":%d, + "status":%d, + "features":{"feature1": {"min_version_level": 1, "max_version_level": 2}, "feature2": {"min_version_level": 2, "max_version_level": 4}} + }""" + assertThrows(classOf[IllegalArgumentException], () => FeatureZNode.decode(featureZNodeStrTemplate.format(FeatureZNode.V1 - 1, 1).getBytes(StandardCharsets.UTF_8))) + val invalidStatus = FeatureZNodeStatus.Enabled.id + 1 + assertThrows(classOf[IllegalArgumentException], () => FeatureZNode.decode(featureZNodeStrTemplate.format(FeatureZNode.CurrentVersion, invalidStatus).getBytes(StandardCharsets.UTF_8))) + } + + @Test + def testDecodeFailOnInvalidFeatures(): Unit = { + val featureZNodeStrTemplate = + """{ + "version":1, + "status":1%s + }""" + + val missingFeatures = "" + assertThrows(classOf[IllegalArgumentException], () => FeatureZNode.decode(featureZNodeStrTemplate.format(missingFeatures).getBytes(StandardCharsets.UTF_8))) + + val malformedFeatures = ""","features":{"feature1": {"min_version_level": 1, "max_version_level": 2}, "partial"}""" + assertThrows(classOf[IllegalArgumentException], () => FeatureZNode.decode(featureZNodeStrTemplate.format(malformedFeatures).getBytes(StandardCharsets.UTF_8))) + + val invalidFeaturesMinVersionLevel = ""","features":{"feature1": {"min_version_level": 0, "max_version_level": 2}}""" + assertThrows(classOf[IllegalArgumentException], () => FeatureZNode.decode(featureZNodeStrTemplate.format(invalidFeaturesMinVersionLevel).getBytes(StandardCharsets.UTF_8))) + + val invalidFeaturesMaxVersionLevel = ""","features":{"feature1": {"min_version_level": 2, "max_version_level": 1}}""" + assertThrows(classOf[IllegalArgumentException], () => 
FeatureZNode.decode(featureZNodeStrTemplate.format(invalidFeaturesMaxVersionLevel).getBytes(StandardCharsets.UTF_8))) + + val invalidFeaturesMissingMinVersionLevel = ""","features":{"feature1": {"max_version_level": 1}}""" + assertThrows(classOf[IllegalArgumentException], () => FeatureZNode.decode(featureZNodeStrTemplate.format(invalidFeaturesMissingMinVersionLevel).getBytes(StandardCharsets.UTF_8))) + } +} diff --git a/core/src/test/scala/kafka/zk/LiteralAclStoreTest.scala b/core/src/test/scala/kafka/zk/LiteralAclStoreTest.scala new file mode 100644 index 0000000..bfddee2 --- /dev/null +++ b/core/src/test/scala/kafka/zk/LiteralAclStoreTest.scala @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.zk + +import java.nio.charset.StandardCharsets.UTF_8 +import kafka.security.authorizer.AclEntry +import org.apache.kafka.common.resource.PatternType.{LITERAL, PREFIXED} +import org.apache.kafka.common.resource.ResourcePattern +import org.apache.kafka.common.resource.ResourceType.{GROUP, TOPIC} +import org.junit.jupiter.api.Assertions.{assertEquals, assertThrows} +import org.junit.jupiter.api.Test + +class LiteralAclStoreTest { + private val literalResource = new ResourcePattern(TOPIC, "some-topic", LITERAL) + private val prefixedResource = new ResourcePattern(TOPIC, "some-topic", PREFIXED) + private val store = LiteralAclStore + + @Test + def shouldHaveCorrectPaths(): Unit = { + assertEquals("/kafka-acl", store.aclPath) + assertEquals("/kafka-acl/Topic", store.path(TOPIC)) + assertEquals("/kafka-acl-changes", store.changeStore.aclChangePath) + } + + @Test + def shouldHaveCorrectPatternType(): Unit = { + assertEquals(LITERAL, store.patternType) + } + + @Test + def shouldThrowFromEncodeOnNoneLiteral(): Unit = { + assertThrows(classOf[IllegalArgumentException], () => store.changeStore.createChangeNode(prefixedResource)) + } + + @Test + def shouldWriteChangesToTheWritePath(): Unit = { + val changeNode = store.changeStore.createChangeNode(literalResource) + + assertEquals("/kafka-acl-changes/acl_changes_", changeNode.path) + } + + @Test + def shouldRoundTripChangeNode(): Unit = { + val changeNode = store.changeStore.createChangeNode(literalResource) + + val actual = store.changeStore.decode(changeNode.bytes) + + assertEquals(literalResource, actual) + } + + @Test + def shouldDecodeResourceUsingTwoPartLogic(): Unit = { + val resource = new ResourcePattern(GROUP, "PREFIXED:this, including the PREFIXED part, is a valid two part group name", LITERAL) + val encoded = (resource.resourceType.toString + AclEntry.ResourceSeparator + resource.name).getBytes(UTF_8) + + val actual = store.changeStore.decode(encoded) + + assertEquals(resource, actual) + } +} diff --git a/core/src/test/scala/other/kafka.log4j.properties 
b/core/src/test/scala/other/kafka.log4j.properties new file mode 100644 index 0000000..1a53fd5 --- /dev/null +++ b/core/src/test/scala/other/kafka.log4j.properties @@ -0,0 +1,22 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +log4j.rootLogger=INFO, KAFKA + +log4j.appender.KAFKA=kafka.log4j.KafkaAppender + +log4j.appender.KAFKA.Port=9092 +log4j.appender.KAFKA.Host=localhost +log4j.appender.KAFKA.Topic=test-logger +log4j.appender.KAFKA.Serializer=kafka.AppenderStringSerializer diff --git a/core/src/test/scala/other/kafka/ReplicationQuotasTestRig.scala b/core/src/test/scala/other/kafka/ReplicationQuotasTestRig.scala new file mode 100644 index 0000000..42230e7 --- /dev/null +++ b/core/src/test/scala/other/kafka/ReplicationQuotasTestRig.scala @@ -0,0 +1,344 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka + +import java.io.{File, PrintWriter} +import java.nio.charset.StandardCharsets +import java.nio.file.{Files, StandardOpenOption} + +import javax.imageio.ImageIO +import kafka.admin.ReassignPartitionsCommand +import kafka.server.{KafkaConfig, KafkaServer, QuorumTestHarness, QuotaType} +import kafka.utils.TestUtils._ +import kafka.utils.{EmptyTestInfo, Exit, Logging, TestUtils} +import kafka.zk.ReassignPartitionsZNode +import org.apache.kafka.clients.admin.{Admin, AdminClientConfig} +import org.apache.kafka.clients.producer.ProducerRecord +import org.apache.kafka.common.TopicPartition +import org.apache.kafka.common.network.ListenerName +import org.apache.kafka.common.security.auth.SecurityProtocol +import org.apache.kafka.common.utils.Utils +import org.jfree.chart.plot.PlotOrientation +import org.jfree.chart.{ChartFactory, ChartFrame, JFreeChart} +import org.jfree.data.xy.{XYSeries, XYSeriesCollection} + +import scala.jdk.CollectionConverters._ +import scala.collection.{Map, Seq, mutable} + +/** + * Test rig for measuring throttling performance. Configure the parameters for a set of experiments, then execute them + * and view the html output file, with charts, that are produced. 
You can also render the charts to the screen if + * you wish. + * + * Currently you'll need about 40GB of disk space to run these experiments (largest data written x2). Tune the msgSize + * & #partitions and throttle to adjust this. + */ +object ReplicationQuotasTestRig { + new File("Experiments").mkdir() + private val dir = "Experiments/Run" + System.currentTimeMillis().toString.substring(8) + new File(dir).mkdir() + val k = 1000 * 1000 + + + def main(args: Array[String]): Unit = { + val displayChartsOnScreen = if (args.length > 0 && args(0) == "show-gui") true else false + val journal = new Journal() + + val experiments = Seq( + //1GB total data written, will take 210s + new ExperimentDef("Experiment1", brokers = 5, partitions = 20, throttle = 1 * k, msgsPerPartition = 500, msgSize = 100 * 1000), + //5GB total data written, will take 110s + new ExperimentDef("Experiment2", brokers = 5, partitions = 50, throttle = 10 * k, msgsPerPartition = 1000, msgSize = 100 * 1000), + //5GB total data written, will take 110s + new ExperimentDef("Experiment3", brokers = 50, partitions = 50, throttle = 2 * k, msgsPerPartition = 1000, msgSize = 100 * 1000), + //10GB total data written, will take 110s + new ExperimentDef("Experiment4", brokers = 25, partitions = 100, throttle = 4 * k, msgsPerPartition = 1000, msgSize = 100 * 1000), + //10GB total data written, will take 80s + new ExperimentDef("Experiment5", brokers = 5, partitions = 50, throttle = 50 * k, msgsPerPartition = 4000, msgSize = 100 * 1000) + ) + experiments.foreach(run(_, journal, displayChartsOnScreen)) + + if (!displayChartsOnScreen) + Exit.exit(0) + } + + def run(config: ExperimentDef, journal: Journal, displayChartsOnScreen: Boolean): Unit = { + val experiment = new Experiment() + try { + experiment.setUp(new EmptyTestInfo()) + experiment.run(config, journal, displayChartsOnScreen) + journal.footer() + } + catch { + case e: Exception => e.printStackTrace() + } + finally { + experiment.tearDown() + } + } + + case class ExperimentDef(name: String, brokers: Int, partitions: Int, throttle: Long, msgsPerPartition: Int, msgSize: Int) { + val targetBytesPerBrokerMB: Long = msgsPerPartition.toLong * msgSize.toLong * partitions.toLong / brokers.toLong / 1000000 + } + + class Experiment extends QuorumTestHarness with Logging { + val topicName = "my-topic" + var experimentName = "unset" + val partitionId = 0 + var servers: Seq[KafkaServer] = null + val leaderRates = mutable.Map[Int, Array[Double]]() + val followerRates = mutable.Map[Int, Array[Double]]() + var adminClient: Admin = null + + def startBrokers(brokerIds: Seq[Int]): Unit = { + println("Starting Brokers") + servers = brokerIds.map(i => createBrokerConfig(i, zkConnect)) + .map(c => createServer(KafkaConfig.fromProps(c))) + + TestUtils.waitUntilBrokerMetadataIsPropagated(servers) + val brokerList = TestUtils.bootstrapServers(servers, + ListenerName.forSecurityProtocol(SecurityProtocol.PLAINTEXT)) + adminClient = Admin.create(Map[String, Object]( + AdminClientConfig.BOOTSTRAP_SERVERS_CONFIG -> brokerList + ).asJava) + } + + override def tearDown(): Unit = { + Utils.closeQuietly(adminClient, "adminClient") + TestUtils.shutdownServers(servers) + super.tearDown() + } + + def run(config: ExperimentDef, journal: Journal, displayChartsOnScreen: Boolean): Unit = { + experimentName = config.name + val brokers = (100 to 100 + config.brokers) + var count = 0 + val shift = Math.round(config.brokers / 2f) + + def nextReplicaRoundRobin(): Int = { + count = count + 1 + 100 + (count + shift) % 
config.brokers + } + val replicas = (0 to config.partitions).map(partition => partition -> Seq(nextReplicaRoundRobin())).toMap + + startBrokers(brokers) + createTopic(zkClient, topicName, replicas, servers) + + println("Writing Data") + val producer = TestUtils.createProducer(TestUtils.getBrokerListStrFromServers(servers), acks = 0) + (0 until config.msgsPerPartition).foreach { x => + (0 until config.partitions).foreach { partition => + producer.send(new ProducerRecord(topicName, partition, null, new Array[Byte](config.msgSize))) + } + } + + println("Generating Reassignment") + val (newAssignment, _) = ReassignPartitionsCommand.generateAssignment(adminClient, + json(topicName), brokers.mkString(","), true) + + println("Starting Reassignment") + val start = System.currentTimeMillis() + ReassignPartitionsCommand.executeAssignment(adminClient, false, + new String(ReassignPartitionsZNode.encode(newAssignment), StandardCharsets.UTF_8), + config.throttle) + + //Await completion + waitForReassignmentToComplete() + println(s"Reassignment took ${(System.currentTimeMillis() - start)/1000}s") + + validateAllOffsetsMatch(config) + + journal.appendToJournal(config) + renderChart(leaderRates, "Leader", journal, displayChartsOnScreen) + renderChart(followerRates, "Follower", journal, displayChartsOnScreen) + logOutput(config, replicas, newAssignment) + + println("Output can be found here: " + journal.path()) + } + + def validateAllOffsetsMatch(config: ExperimentDef): Unit = { + //Validate that offsets are correct in all brokers + for (broker <- servers) { + (0 until config.partitions).foreach { partitionId => + val offset = broker.getLogManager.getLog(new TopicPartition(topicName, partitionId)).map(_.logEndOffset).getOrElse(-1L) + if (offset >= 0 && offset != config.msgsPerPartition) { + throw new RuntimeException(s"Run failed as offsets did not match for partition $partitionId on broker ${broker.config.brokerId}. 
Expected ${config.msgsPerPartition} but was $offset.") + } + } + } + } + + def logOutput(config: ExperimentDef, replicas: Map[Int, Seq[Int]], newAssignment: Map[TopicPartition, Seq[Int]]): Unit = { + val actual = zkClient.getPartitionAssignmentForTopics(Set(topicName))(topicName) + + //Long stats + println("The replicas are " + replicas.toSeq.sortBy(_._1).map("\n" + _)) + println("This is the current replica assignment:\n" + actual.map { case (k, v) => k -> v.replicas }) + println("proposed assignment is: \n" + newAssignment) + println("This is the assignment we ended up with" + actual.map { case (k, v) => k -> v.replicas }) + + //Test Stats + println(s"numBrokers: ${config.brokers}") + println(s"numPartitions: ${config.partitions}") + println(s"throttle: ${config.throttle}") + println(s"numMessagesPerPartition: ${config.msgsPerPartition}") + println(s"msgSize: ${config.msgSize}") + println(s"We will write ${config.targetBytesPerBrokerMB}MB of data per broker") + println(s"Worst case duration is ${config.targetBytesPerBrokerMB * 1000 * 1000/ config.throttle}") + } + + def waitForReassignmentToComplete(): Unit = { + waitUntilTrue(() => { + printRateMetrics() + adminClient.listPartitionReassignments().reassignments().get().isEmpty + }, s"Partition reassignments didn't complete.", 60 * 60 * 1000, pause = 1000L) + } + + def renderChart(data: mutable.Map[Int, Array[Double]], name: String, journal: Journal, displayChartsOnScreen: Boolean): Unit = { + val dataset = addDataToChart(data) + val chart = createChart(name, dataset) + + writeToFile(name, journal, chart) + maybeDisplayOnScreen(displayChartsOnScreen, chart) + println(s"Chart generated for $name") + } + + def maybeDisplayOnScreen(displayChartsOnScreen: Boolean, chart: JFreeChart): Unit = { + if (displayChartsOnScreen) { + val frame = new ChartFrame(experimentName, chart) + frame.pack() + frame.setVisible(true) + } + } + + def writeToFile(name: String, journal: Journal, chart: JFreeChart): Unit = { + val file = new File(dir, experimentName + "-" + name + ".png") + ImageIO.write(chart.createBufferedImage(1000, 700), "png", file) + journal.appendChart(file.getAbsolutePath, name.eq("Leader")) + } + + def createChart(name: String, dataset: XYSeriesCollection): JFreeChart = { + val chart: JFreeChart = ChartFactory.createXYLineChart( + experimentName + " - " + name + " Throttling Performance", + "Time (s)", + "Throttle Throughput (B/s)", + dataset + , PlotOrientation.VERTICAL, false, true, false + ) + chart + } + + def addDataToChart(data: mutable.Map[Int, Array[Double]]): XYSeriesCollection = { + val dataset = new XYSeriesCollection + data.foreach { case (broker, values) => + val series = new XYSeries("Broker:" + broker) + var x = 0 + values.foreach { value => + series.add(x, value) + x += 1 + } + dataset.addSeries(series) + } + dataset + } + + def record(rates: mutable.Map[Int, Array[Double]], brokerId: Int, currentRate: Double) = { + var leaderRatesBroker: Array[Double] = rates.getOrElse(brokerId, Array[Double]()) + leaderRatesBroker = leaderRatesBroker ++ Array(currentRate) + rates.put(brokerId, leaderRatesBroker) + } + + def printRateMetrics(): Unit = { + for (broker <- servers) { + val leaderRate: Double = measuredRate(broker, QuotaType.LeaderReplication) + if (broker.config.brokerId == 100) + info("waiting... 
Leader rate on 101 is " + leaderRate) + record(leaderRates, broker.config.brokerId, leaderRate) + if (leaderRate > 0) + trace("Leader Rate on " + broker.config.brokerId + " is " + leaderRate) + + val followerRate: Double = measuredRate(broker, QuotaType.FollowerReplication) + record(followerRates, broker.config.brokerId, followerRate) + if (followerRate > 0) + trace("Follower Rate on " + broker.config.brokerId + " is " + followerRate) + } + } + + private def measuredRate(broker: KafkaServer, repType: QuotaType): Double = { + val metricName = broker.metrics.metricName("byte-rate", repType.toString) + if (broker.metrics.metrics.asScala.contains(metricName)) + broker.metrics.metrics.asScala(metricName).metricValue.asInstanceOf[Double] + else -1 + } + + def json(topic: String*): String = { + val topicStr = topic.map { + t => "{\"topic\": \"" + t + "\"}" + }.mkString(",") + s"""{"topics": [$topicStr],"version":1}""" + } + } + + class Journal { + private val log = new File(dir, "Log.html") + header() + + def appendToJournal(config: ExperimentDef): Unit = { + val message = s"\n\n

<h3>${config.name}</h3>" + + s"<p>- BrokerCount: ${config.brokers}" + + s"<p>- PartitionCount: ${config.partitions}" + + f"<p>- Throttle: ${config.throttle.toDouble}%,.0f MB/s" + + f"<p>- MsgCount: ${config.msgsPerPartition}%,.0f " + + f"<p>- MsgSize: ${config.msgSize}%,.0f" + + s"<p>- TargetBytesPerBrokerMB: ${config.targetBytesPerBrokerMB}<p>" + append(message) + } + + def appendChart(path: String, first: Boolean): Unit = { + val message = new StringBuilder + if (first) + message.append("<p>") + message.append("<p><img src=\"" + path + "\" alt=\"Chart\" style=\"width:600px;height:400px;\">") + if (!first) + message.append("</p>") + append(message.toString()) + } + + def header(): Unit = { + append("<html><head><h1>Replication Quotas Test Rig</h1></head><body>
                ") + } + + def footer(): Unit = { + append("") + } + + def append(message: String): Unit = { + val stream = Files.newOutputStream(log.toPath, StandardOpenOption.CREATE, StandardOpenOption.APPEND) + new PrintWriter(stream) { + append(message) + close + } + } + + def path(): String = { + log.getAbsolutePath + } + } + +} + diff --git a/core/src/test/scala/other/kafka/StressTestLog.scala b/core/src/test/scala/other/kafka/StressTestLog.scala new file mode 100755 index 0000000..2422dcc --- /dev/null +++ b/core/src/test/scala/other/kafka/StressTestLog.scala @@ -0,0 +1,151 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. +*/ + +package kafka + +import java.util.Properties +import java.util.concurrent.atomic._ + +import kafka.log._ +import kafka.server.{BrokerTopicStats, FetchLogEnd, LogDirFailureChannel} +import kafka.utils._ +import org.apache.kafka.clients.consumer.OffsetOutOfRangeException +import org.apache.kafka.common.record.FileRecords +import org.apache.kafka.common.utils.Utils + +/** + * A stress test that instantiates a log and then runs continual appends against it from one thread and continual reads against it + * from another thread and checks a few basic assertions until the user kills the process. 
+ */ +object StressTestLog { + val running = new AtomicBoolean(true) + + def main(args: Array[String]): Unit = { + val dir = TestUtils.randomPartitionLogDir(TestUtils.tempDir()) + val time = new MockTime + val logProperties = new Properties() + logProperties.put(LogConfig.SegmentBytesProp, 64*1024*1024: java.lang.Integer) + logProperties.put(LogConfig.MaxMessageBytesProp, Int.MaxValue: java.lang.Integer) + logProperties.put(LogConfig.SegmentIndexBytesProp, 1024*1024: java.lang.Integer) + + val log = UnifiedLog(dir = dir, + config = LogConfig(logProperties), + logStartOffset = 0L, + recoveryPoint = 0L, + scheduler = time.scheduler, + time = time, + maxProducerIdExpirationMs = 60 * 60 * 1000, + producerIdExpirationCheckIntervalMs = LogManager.ProducerIdExpirationCheckIntervalMs, + brokerTopicStats = new BrokerTopicStats, + logDirFailureChannel = new LogDirFailureChannel(10), + topicId = None, + keepPartitionMetadataFile = true) + val writer = new WriterThread(log) + writer.start() + val reader = new ReaderThread(log) + reader.start() + + Exit.addShutdownHook("stress-test-shutdown-hook", { + running.set(false) + writer.join() + reader.join() + Utils.delete(dir) + }) + + while(running.get) { + Thread.sleep(1000) + println("Reader offset = %d, writer offset = %d".format(reader.currentOffset, writer.currentOffset)) + writer.checkProgress() + reader.checkProgress() + } + } + + abstract class WorkerThread extends Thread { + val threadInfo = "Thread: " + Thread.currentThread.getName + " Class: " + getClass.getName + + override def run(): Unit = { + try { + while(running.get) + work() + } catch { + case e: Exception => { + e.printStackTrace() + } + } finally { + running.set(false) + } + } + + def work(): Unit + def isMakingProgress(): Boolean + } + + trait LogProgress { + @volatile var currentOffset = 0 + private var lastOffsetCheckpointed = currentOffset + private var lastProgressCheckTime = System.currentTimeMillis + + def isMakingProgress(): Boolean = { + if (currentOffset > lastOffsetCheckpointed) { + lastOffsetCheckpointed = currentOffset + return true + } + + false + } + + def checkProgress(): Unit = { + // Check if we are making progress every 500ms + val curTime = System.currentTimeMillis + if ((curTime - lastProgressCheckTime) > 500) { + require(isMakingProgress(), "Thread not making progress") + lastProgressCheckTime = curTime + } + } + } + + class WriterThread(val log: UnifiedLog) extends WorkerThread with LogProgress { + override def work(): Unit = { + val logAppendInfo = log.appendAsLeader(TestUtils.singletonRecords(currentOffset.toString.getBytes), 0) + require(logAppendInfo.firstOffset.forall(_.messageOffset == currentOffset) && logAppendInfo.lastOffset == currentOffset) + currentOffset += 1 + if (currentOffset % 1000 == 0) + Thread.sleep(50) + } + } + + class ReaderThread(val log: UnifiedLog) extends WorkerThread with LogProgress { + override def work(): Unit = { + try { + log.read(currentOffset, + maxLength = 1, + isolation = FetchLogEnd, + minOneMessage = true).records match { + case read: FileRecords if read.sizeInBytes > 0 => { + val first = read.batches.iterator.next() + require(first.lastOffset == currentOffset, "We should either read nothing or the message we asked for.") + require(first.sizeInBytes == read.sizeInBytes, "Expected %d but got %d.".format(first.sizeInBytes, read.sizeInBytes)) + currentOffset += 1 + } + case _ => + } + } catch { + case _: OffsetOutOfRangeException => // this is okay + } + } + } +} diff --git 
a/core/src/test/scala/other/kafka/TestLinearWriteSpeed.scala b/core/src/test/scala/other/kafka/TestLinearWriteSpeed.scala new file mode 100755 index 0000000..f274954 --- /dev/null +++ b/core/src/test/scala/other/kafka/TestLinearWriteSpeed.scala @@ -0,0 +1,225 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka + +import java.io._ +import java.nio._ +import java.nio.channels._ +import java.nio.file.StandardOpenOption +import java.util.{Properties, Random} + +import joptsimple._ +import kafka.log._ +import kafka.message._ +import kafka.server.{BrokerTopicStats, LogDirFailureChannel} +import kafka.utils._ +import org.apache.kafka.common.record._ +import org.apache.kafka.common.utils.{Time, Utils} + +import scala.math._ + +/** + * This test does linear writes using either a kafka log or a file and measures throughput and latency. + */ +object TestLinearWriteSpeed { + + def main(args: Array[String]): Unit = { + val parser = new OptionParser(false) + val dirOpt = parser.accepts("dir", "The directory to write to.") + .withRequiredArg + .describedAs("path") + .ofType(classOf[java.lang.String]) + .defaultsTo(System.getProperty("java.io.tmpdir")) + val bytesOpt = parser.accepts("bytes", "REQUIRED: The total number of bytes to write.") + .withRequiredArg + .describedAs("num_bytes") + .ofType(classOf[java.lang.Long]) + val sizeOpt = parser.accepts("size", "REQUIRED: The size of each write.") + .withRequiredArg + .describedAs("num_bytes") + .ofType(classOf[java.lang.Integer]) + val messageSizeOpt = parser.accepts("message-size", "REQUIRED: The size of each message in the message set.") + .withRequiredArg + .describedAs("num_bytes") + .ofType(classOf[java.lang.Integer]) + .defaultsTo(1024) + val filesOpt = parser.accepts("files", "REQUIRED: The number of logs or files.") + .withRequiredArg + .describedAs("num_files") + .ofType(classOf[java.lang.Integer]) + .defaultsTo(1) + val reportingIntervalOpt = parser.accepts("reporting-interval", "The number of ms between updates.") + .withRequiredArg + .describedAs("ms") + .ofType(classOf[java.lang.Long]) + .defaultsTo(1000L) + val maxThroughputOpt = parser.accepts("max-throughput-mb", "The maximum throughput.") + .withRequiredArg + .describedAs("mb") + .ofType(classOf[java.lang.Integer]) + .defaultsTo(Integer.MAX_VALUE) + val flushIntervalOpt = parser.accepts("flush-interval", "The number of messages between flushes") + .withRequiredArg() + .describedAs("message_count") + .ofType(classOf[java.lang.Long]) + .defaultsTo(Long.MaxValue) + val compressionCodecOpt = parser.accepts("compression", "The compression codec to use") + .withRequiredArg + .describedAs("codec") + .ofType(classOf[java.lang.String]) + .defaultsTo(NoCompressionCodec.name) + val mmapOpt = parser.accepts("mmap", "Do writes to 
memory-mapped files.") + val channelOpt = parser.accepts("channel", "Do writes to file channels.") + val logOpt = parser.accepts("log", "Do writes to kafka logs.") + + val options = parser.parse(args : _*) + + CommandLineUtils.checkRequiredArgs(parser, options, bytesOpt, sizeOpt, filesOpt) + + var bytesToWrite = options.valueOf(bytesOpt).longValue + val bufferSize = options.valueOf(sizeOpt).intValue + val numFiles = options.valueOf(filesOpt).intValue + val reportingInterval = options.valueOf(reportingIntervalOpt).longValue + val dir = options.valueOf(dirOpt) + val maxThroughputBytes = options.valueOf(maxThroughputOpt).intValue * 1024L * 1024L + val buffer = ByteBuffer.allocate(bufferSize) + val messageSize = options.valueOf(messageSizeOpt).intValue + val flushInterval = options.valueOf(flushIntervalOpt).longValue + val compressionCodec = CompressionCodec.getCompressionCodec(options.valueOf(compressionCodecOpt)) + val rand = new Random + rand.nextBytes(buffer.array) + val numMessages = bufferSize / (messageSize + Records.LOG_OVERHEAD) + val createTime = System.currentTimeMillis + val messageSet = { + val compressionType = CompressionType.forId(compressionCodec.codec) + val records = (0 until numMessages).map(_ => new SimpleRecord(createTime, null, new Array[Byte](messageSize))) + MemoryRecords.withRecords(compressionType, records: _*) + } + + val writables = new Array[Writable](numFiles) + val scheduler = new KafkaScheduler(1) + scheduler.startup() + for(i <- 0 until numFiles) { + if(options.has(mmapOpt)) { + writables(i) = new MmapWritable(new File(dir, "kafka-test-" + i + ".dat"), bytesToWrite / numFiles, buffer) + } else if(options.has(channelOpt)) { + writables(i) = new ChannelWritable(new File(dir, "kafka-test-" + i + ".dat"), buffer) + } else if(options.has(logOpt)) { + val segmentSize = rand.nextInt(512)*1024*1024 + 64*1024*1024 // vary size to avoid herd effect + val logProperties = new Properties() + logProperties.put(LogConfig.SegmentBytesProp, segmentSize: java.lang.Integer) + logProperties.put(LogConfig.FlushMessagesProp, flushInterval: java.lang.Long) + writables(i) = new LogWritable(new File(dir, "kafka-test-" + i), new LogConfig(logProperties), scheduler, messageSet) + } else { + System.err.println("Must specify what to write to with one of --log, --channel, or --mmap") + Exit.exit(1) + } + } + bytesToWrite = (bytesToWrite / numFiles) * numFiles + + println("%10s\t%10s\t%10s".format("mb_sec", "avg_latency", "max_latency")) + + val beginTest = System.nanoTime + var maxLatency = 0L + var totalLatency = 0L + var count = 0L + var written = 0L + var totalWritten = 0L + var lastReport = beginTest + while(totalWritten + bufferSize < bytesToWrite) { + val start = System.nanoTime + val writeSize = writables((count % numFiles).toInt.abs).write() + val ellapsed = System.nanoTime - start + maxLatency = max(ellapsed, maxLatency) + totalLatency += ellapsed + written += writeSize + count += 1 + totalWritten += writeSize + if((start - lastReport)/(1000.0*1000.0) > reportingInterval.doubleValue) { + val ellapsedSecs = (start - lastReport) / (1000.0*1000.0*1000.0) + val mb = written / (1024.0*1024.0) + println("%10.3f\t%10.3f\t%10.3f".format(mb / ellapsedSecs, totalLatency / count.toDouble / (1000.0*1000.0), maxLatency / (1000.0 * 1000.0))) + lastReport = start + written = 0 + maxLatency = 0L + totalLatency = 0L + } else if(written > maxThroughputBytes * (reportingInterval / 1000.0)) { + // if we have written enough, just sit out this reporting interval + val lastReportMs = lastReport / 
(1000*1000) + val now = System.nanoTime / (1000*1000) + val sleepMs = lastReportMs + reportingInterval - now + if(sleepMs > 0) + Thread.sleep(sleepMs) + } + } + val elapsedSecs = (System.nanoTime - beginTest) / (1000.0*1000.0*1000.0) + println((bytesToWrite / (1024.0 * 1024.0 * elapsedSecs)).toString + " MB per sec") + scheduler.shutdown() + } + + trait Writable { + def write(): Int + def close(): Unit + } + + class MmapWritable(val file: File, size: Long, val content: ByteBuffer) extends Writable { + file.deleteOnExit() + val raf = new RandomAccessFile(file, "rw") + raf.setLength(size) + val buffer = raf.getChannel().map(FileChannel.MapMode.READ_WRITE, 0, raf.length()) + def write(): Int = { + buffer.put(content) + content.rewind() + content.limit() + } + def close(): Unit = { + raf.close() + Utils.delete(file) + } + } + + class ChannelWritable(val file: File, val content: ByteBuffer) extends Writable { + file.deleteOnExit() + val channel = FileChannel.open(file.toPath, StandardOpenOption.CREATE, StandardOpenOption.READ, + StandardOpenOption.WRITE) + def write(): Int = { + channel.write(content) + content.rewind() + content.limit() + } + def close(): Unit = { + channel.close() + Utils.delete(file) + } + } + + class LogWritable(val dir: File, config: LogConfig, scheduler: Scheduler, val messages: MemoryRecords) extends Writable { + Utils.delete(dir) + val log = UnifiedLog(dir, config, 0L, 0L, scheduler, new BrokerTopicStats, Time.SYSTEM, 60 * 60 * 1000, + LogManager.ProducerIdExpirationCheckIntervalMs, new LogDirFailureChannel(10), topicId = None, keepPartitionMetadataFile = true) + def write(): Int = { + log.appendAsLeader(messages, leaderEpoch = 0) + messages.sizeInBytes + } + def close(): Unit = { + log.close() + Utils.delete(log.dir) + } + } + +} diff --git a/core/src/test/scala/other/kafka/TestPurgatoryPerformance.scala b/core/src/test/scala/other/kafka/TestPurgatoryPerformance.scala new file mode 100644 index 0000000..507c3ff --- /dev/null +++ b/core/src/test/scala/other/kafka/TestPurgatoryPerformance.scala @@ -0,0 +1,292 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka + +import java.lang.management.ManagementFactory +import java.lang.management.OperatingSystemMXBean +import java.util.Random +import java.util.concurrent._ + +import joptsimple._ +import kafka.server.{DelayedOperation, DelayedOperationPurgatory} +import kafka.utils._ +import org.apache.kafka.common.utils.Time + +import scala.math._ +import scala.jdk.CollectionConverters._ + +/** + * This is a benchmark test of the purgatory. 
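 *
 * A hand-run example (illustrative values only; it assumes bin/kafka-run-class.sh from a built
 * distribution) exercising the options defined below:
 *
 *   bin/kafka-run-class.sh kafka.TestPurgatoryPerformance \
 *     --num 100000 --rate 1000 --size 64 --timeout 1000 --pct50 1.0 --pct75 2.0
 *
 * --num, --rate, --size, --pct50 and --pct75 are mandatory; --timeout should normally be supplied
 * as well, since the request timeout is read unconditionally.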
+ */ +object TestPurgatoryPerformance { + + def main(args: Array[String]): Unit = { + val parser = new OptionParser(false) + val keySpaceSizeOpt = parser.accepts("key-space-size", "The total number of possible keys") + .withRequiredArg + .describedAs("total_num_possible_keys") + .ofType(classOf[java.lang.Integer]) + .defaultsTo(100) + val numRequestsOpt = parser.accepts("num", "The number of requests") + .withRequiredArg + .describedAs("num_requests") + .ofType(classOf[java.lang.Double]) + val requestRateOpt = parser.accepts("rate", "The request rate per second") + .withRequiredArg + .describedAs("request_per_second") + .ofType(classOf[java.lang.Double]) + val requestDataSizeOpt = parser.accepts("size", "The request data size in bytes") + .withRequiredArg + .describedAs("num_bytes") + .ofType(classOf[java.lang.Long]) + val numKeysOpt = parser.accepts("keys", "The number of keys for each request") + .withRequiredArg + .describedAs("num_keys") + .ofType(classOf[java.lang.Integer]) + .defaultsTo(3) + val timeoutOpt = parser.accepts("timeout", "The request timeout in ms") + .withRequiredArg + .describedAs("timeout_milliseconds") + .ofType(classOf[java.lang.Long]) + val pct75Opt = parser.accepts("pct75", "75th percentile of request latency in ms (log-normal distribution)") + .withRequiredArg + .describedAs("75th_percentile") + .ofType(classOf[java.lang.Double]) + val pct50Opt = parser.accepts("pct50", "50th percentile of request latency in ms (log-normal distribution)") + .withRequiredArg + .describedAs("50th_percentile") + .ofType(classOf[java.lang.Double]) + val verboseOpt = parser.accepts("verbose", "show additional information") + .withRequiredArg + .describedAs("true|false") + .ofType(classOf[java.lang.Boolean]) + .defaultsTo(true) + + val options = parser.parse(args: _*) + + CommandLineUtils.checkRequiredArgs(parser, options, numRequestsOpt, requestRateOpt, requestDataSizeOpt, pct75Opt, pct50Opt) + + val numRequests = options.valueOf(numRequestsOpt).intValue + val requestRate = options.valueOf(requestRateOpt).doubleValue + val requestDataSize = options.valueOf(requestDataSizeOpt).intValue + val numPossibleKeys = options.valueOf(keySpaceSizeOpt).intValue + val numKeys = options.valueOf(numKeysOpt).intValue + val timeout = options.valueOf(timeoutOpt).longValue + val pct75 = options.valueOf(pct75Opt).doubleValue + val pct50 = options.valueOf(pct50Opt).doubleValue + val verbose = options.valueOf(verboseOpt).booleanValue + + val gcMXBeans = ManagementFactory.getGarbageCollectorMXBeans().asScala.sortBy(_.getName) + val osMXBean = ManagementFactory.getOperatingSystemMXBean + val latencySamples = new LatencySamples(1000000, pct75, pct50) + val intervalSamples = new IntervalSamples(1000000, requestRate) + + val purgatory = DelayedOperationPurgatory[FakeOperation]("fake purgatory") + val queue = new CompletionQueue() + + val gcNames = gcMXBeans.map(_.getName) + + val initialCpuTimeNano = getProcessCpuTimeNanos(osMXBean) + val latch = new CountDownLatch(numRequests) + val start = System.currentTimeMillis + val rand = new Random() + val keys = (0 until numKeys).map(_ => "fakeKey%d".format(rand.nextInt(numPossibleKeys))) + @volatile var requestArrivalTime = start + @volatile var end = 0L + val generator = new Runnable { + def run(): Unit = { + var i = numRequests + while (i > 0) { + i -= 1 + val requestArrivalInterval = intervalSamples.next() + val latencyToComplete = latencySamples.next() + val now = System.currentTimeMillis + requestArrivalTime = requestArrivalTime + requestArrivalInterval + + if 
(requestArrivalTime > now) Thread.sleep(requestArrivalTime - now) + + val request = new FakeOperation(timeout, requestDataSize, latencyToComplete, latch) + if (latencyToComplete < timeout) queue.add(request) + purgatory.tryCompleteElseWatch(request, keys) + } + end = System.currentTimeMillis + } + } + val generatorThread = new Thread(generator) + + generatorThread.start() + generatorThread.join() + latch.await() + val done = System.currentTimeMillis + queue.shutdown() + + if (verbose) { + latencySamples.printStats() + intervalSamples.printStats() + println("# enqueue rate (%d requests):".format(numRequests)) + val gcCountHeader = gcNames.map("<" + _ + " count>").mkString(" ") + val gcTimeHeader = gcNames.map("<" + _ + " time ms>").mkString(" ") + println("# \t\t\t\t%s\t%s".format(gcCountHeader, gcTimeHeader)) + } + + val targetRate = numRequests.toDouble * 1000d / (requestArrivalTime - start).toDouble + val actualRate = numRequests.toDouble * 1000d / (end - start).toDouble + + val cpuTime = getProcessCpuTimeNanos(osMXBean).map(x => (x - initialCpuTimeNano.get) / 1000000L) + val gcCounts = gcMXBeans.map(_.getCollectionCount) + val gcTimes = gcMXBeans.map(_.getCollectionTime) + + println("%d\t%f\t%f\t%d\t%s\t%s".format(done - start, targetRate, actualRate, cpuTime.getOrElse(-1L), gcCounts.mkString(" "), gcTimes.mkString(" "))) + + purgatory.shutdown() + } + + // Use JRE-specific class to get process CPU time + private def getProcessCpuTimeNanos(osMXBean : OperatingSystemMXBean) = { + try { + Some(Class.forName("com.sun.management.OperatingSystemMXBean").getMethod("getProcessCpuTime").invoke(osMXBean).asInstanceOf[Long]) + } catch { + case _: Throwable => try { + Some(Class.forName("com.ibm.lang.management.OperatingSystemMXBean").getMethod("getProcessCpuTimeByNS").invoke(osMXBean).asInstanceOf[Long]) + } catch { + case _: Throwable => None + } + } + } + + // log-normal distribution (http://en.wikipedia.org/wiki/Log-normal_distribution) + // mu: the mean of the underlying normal distribution (not the mean of this log-normal distribution) + // sigma: the standard deviation of the underlying normal distribution (not the stdev of this log-normal distribution) + private class LogNormalDistribution(mu: Double, sigma: Double) { + val rand = new Random + def next(): Double = { + val n = rand.nextGaussian() * sigma + mu + math.exp(n) + } + } + + // exponential distribution (http://en.wikipedia.org/wiki/Exponential_distribution) + // lambda : the rate parameter of the exponential distribution + private class ExponentialDistribution(lambda: Double) { + val rand = new Random + def next(): Double = { + math.log(1d - rand.nextDouble()) / (- lambda) + } + } + + // Samples of Latencies to completion + // They are drawn from a log normal distribution. + // A latency value can never be negative. A log-normal distribution is a convenient way to + // model such a random variable. 
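  // How LatencySamples (below) maps the two target percentiles onto the underlying normal
  // distribution -- a worked sketch of the reasoning, with illustrative numbers:
  //   * If Z ~ N(mu, sigma) then exp(Z) has median exp(mu), so matching the 50th percentile
  //     gives mu = ln(pct50).
  //   * P(exp(Z) <= pct75) = 0.75 requires (ln(pct75) - mu) / sigma = 0.67449, the 75th
  //     percentile point of N(0, 1), hence sigma = (ln(pct75) - ln(pct50)) / 0.67449.
  //   * For example, pct50 = 10 ms and pct75 = 20 ms give mu = ln(10) ~ 2.303 and
  //     sigma = ln(2) / 0.67449 ~ 1.028, and a sample is exp(rand.nextGaussian() * sigma + mu).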
+ private class LatencySamples(sampleSize: Int, pct75: Double, pct50: Double) { + private[this] val rand = new Random + private[this] val samples = { + val normalMean = math.log(pct50) + val normalStDev = (math.log(pct75) - normalMean) / 0.674490d // 0.674490 is 75th percentile point in N(0,1) + val dist = new LogNormalDistribution(normalMean, normalStDev) + (0 until sampleSize).map { _ => dist.next().toLong }.toArray + } + def next() = samples(rand.nextInt(sampleSize)) + + def printStats(): Unit = { + val p75 = samples.sorted.apply((sampleSize.toDouble * 0.75d).toInt) + val p50 = samples.sorted.apply((sampleSize.toDouble * 0.5d).toInt) + + println("# latency samples: pct75 = %d, pct50 = %d, min = %d, max = %d".format(p75, p50, samples.min, samples.max)) + } + } + + // Samples of Request arrival intervals + // The request arrival is modeled as a Poisson process. + // So, the internals are drawn from an exponential distribution. + private class IntervalSamples(sampleSize: Int, requestPerSecond: Double) { + private[this] val rand = new Random + private[this] val samples = { + val dist = new ExponentialDistribution(requestPerSecond / 1000d) + var residue = 0.0 + (0 until sampleSize).map { _ => + val interval = dist.next() + residue + val roundedInterval = interval.toLong + residue = interval - roundedInterval.toDouble + roundedInterval + }.toArray + } + + def next() = samples(rand.nextInt(sampleSize)) + + def printStats(): Unit = { + println( + "# interval samples: rate = %f, min = %d, max = %d" + .format(1000d / (samples.map(_.toDouble).sum / sampleSize.toDouble), samples.min, samples.max) + ) + } + } + + private class FakeOperation(delayMs: Long, size: Int, val latencyMs: Long, latch: CountDownLatch) extends DelayedOperation(delayMs) { + val completesAt = System.currentTimeMillis + latencyMs + + def onExpiration(): Unit = {} + + def onComplete(): Unit = { + latch.countDown() + } + + def tryComplete(): Boolean = { + if (System.currentTimeMillis >= completesAt) + forceComplete() + else + false + } + } + + private class CompletionQueue { + private[this] val delayQueue = new DelayQueue[Scheduled]() + private[this] val thread = new ShutdownableThread(name = "completion thread", isInterruptible = false) { + override def doWork(): Unit = { + val scheduled = delayQueue.poll(100, TimeUnit.MILLISECONDS) + if (scheduled != null) { + scheduled.operation.forceComplete() + } + } + } + thread.start() + + def add(operation: FakeOperation): Unit = { + delayQueue.offer(new Scheduled(operation)) + } + + def shutdown() = { + thread.shutdown() + } + + private class Scheduled(val operation: FakeOperation) extends Delayed { + def getDelay(unit: TimeUnit): Long = { + unit.convert(max(operation.completesAt - Time.SYSTEM.milliseconds, 0), TimeUnit.MILLISECONDS) + } + + def compareTo(d: Delayed): Int = { + + val other = d.asInstanceOf[Scheduled] + + if (operation.completesAt < other.operation.completesAt) -1 + else if (operation.completesAt > other.operation.completesAt) 1 + else 0 + } + } + } +} diff --git a/core/src/test/scala/other/kafka/TestTruncate.scala b/core/src/test/scala/other/kafka/TestTruncate.scala new file mode 100644 index 0000000..ba9c799 --- /dev/null +++ b/core/src/test/scala/other/kafka/TestTruncate.scala @@ -0,0 +1,41 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka + +import java.io._ +import java.nio._ +import java.nio.channels.FileChannel +import java.nio.file.StandardOpenOption + +/* This code tests the correct function of java's FileChannel.truncate--some platforms don't work. */ +object TestTruncate { + + def main(args: Array[String]): Unit = { + val name = File.createTempFile("kafka", ".test") + name.deleteOnExit() + val file = FileChannel.open(name.toPath, StandardOpenOption.READ, StandardOpenOption.WRITE) + val buffer = ByteBuffer.allocate(12) + buffer.putInt(4).putInt(4).putInt(4) + buffer.rewind() + file.write(buffer) + println("position prior to truncate: " + file.position) + file.truncate(4) + println("position after truncate to 4: " + file.position) + } + +} diff --git a/core/src/test/scala/unit/kafka/KafkaConfigTest.scala b/core/src/test/scala/unit/kafka/KafkaConfigTest.scala new file mode 100644 index 0000000..8c8d3e7 --- /dev/null +++ b/core/src/test/scala/unit/kafka/KafkaConfigTest.scala @@ -0,0 +1,405 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package kafka + +import java.io.File +import java.nio.file.Files +import java.util +import java.util.Properties +import kafka.server.KafkaConfig +import kafka.utils.Exit +import kafka.utils.TestUtils.assertBadConfigContainingMessage +import org.apache.kafka.common.config.internals.BrokerSecurityConfigs +import org.apache.kafka.common.config.types.Password +import org.apache.kafka.common.internals.FatalExitError +import org.junit.jupiter.api.{AfterEach, BeforeEach, Test} +import org.junit.jupiter.api.Assertions._ + +import scala.jdk.CollectionConverters._ + +class KafkaTest { + + @BeforeEach + def setUp(): Unit = Exit.setExitProcedure((status, _) => throw new FatalExitError(status)) + + @AfterEach + def tearDown(): Unit = Exit.resetExitProcedure() + + @Test + def testGetKafkaConfigFromArgs(): Unit = { + val propertiesFile = prepareDefaultConfig() + + // We should load configuration file without any arguments + val config1 = KafkaConfig.fromProps(Kafka.getPropsFromArgs(Array(propertiesFile))) + assertEquals(1, config1.brokerId) + + // We should be able to override given property on command line + val config2 = KafkaConfig.fromProps(Kafka.getPropsFromArgs(Array(propertiesFile, "--override", "broker.id=2"))) + assertEquals(2, config2.brokerId) + + // We should be also able to set completely new property + val config3 = KafkaConfig.fromProps(Kafka.getPropsFromArgs(Array(propertiesFile, "--override", "log.cleanup.policy=compact"))) + assertEquals(1, config3.brokerId) + assertEquals(util.Arrays.asList("compact"), config3.logCleanupPolicy) + + // We should be also able to set several properties + val config4 = KafkaConfig.fromProps(Kafka.getPropsFromArgs(Array(propertiesFile, "--override", "log.cleanup.policy=compact,delete", "--override", "broker.id=2"))) + assertEquals(2, config4.brokerId) + assertEquals(util.Arrays.asList("compact","delete"), config4.logCleanupPolicy) + } + + @Test + def testGetKafkaConfigFromArgsNonArgsAtTheEnd(): Unit = { + val propertiesFile = prepareDefaultConfig() + assertThrows(classOf[FatalExitError], () => KafkaConfig.fromProps(Kafka.getPropsFromArgs(Array(propertiesFile, "--override", "broker.id=1", "broker.id=2")))) + } + + @Test + def testGetKafkaConfigFromArgsNonArgsOnly(): Unit = { + val propertiesFile = prepareDefaultConfig() + assertThrows(classOf[FatalExitError], () => KafkaConfig.fromProps(Kafka.getPropsFromArgs(Array(propertiesFile, "broker.id=1", "broker.id=2")))) + } + + @Test + def testGetKafkaConfigFromArgsNonArgsAtTheBegging(): Unit = { + val propertiesFile = prepareDefaultConfig() + assertThrows(classOf[FatalExitError], () => KafkaConfig.fromProps(Kafka.getPropsFromArgs(Array(propertiesFile, "broker.id=1", "--override", "broker.id=2")))) + } + + @Test + def testBrokerRoleNodeIdValidation(): Unit = { + // Ensure that validation is happening at startup to check that brokers do not use their node.id as a voter in controller.quorum.voters + val propertiesFile = new Properties + propertiesFile.setProperty(KafkaConfig.ProcessRolesProp, "broker") + propertiesFile.setProperty(KafkaConfig.NodeIdProp, "1") + propertiesFile.setProperty(KafkaConfig.QuorumVotersProp, "1@localhost:9092") + setListenerProps(propertiesFile) + assertBadConfigContainingMessage(propertiesFile, + "If process.roles contains just the 'broker' role, the node id 1 must not be included in the set of voters") + + // Ensure that with a valid config no exception is thrown + propertiesFile.setProperty(KafkaConfig.NodeIdProp, "2") + KafkaConfig.fromProps(propertiesFile) + } + + @Test + def 
testControllerRoleNodeIdValidation(): Unit = { + // Ensure that validation is happening at startup to check that controllers use their node.id as a voter in controller.quorum.voters + val propertiesFile = new Properties + propertiesFile.setProperty(KafkaConfig.ProcessRolesProp, "controller") + propertiesFile.setProperty(KafkaConfig.NodeIdProp, "1") + propertiesFile.setProperty(KafkaConfig.QuorumVotersProp, "2@localhost:9092") + setListenerProps(propertiesFile) + assertBadConfigContainingMessage(propertiesFile, + "If process.roles contains the 'controller' role, the node id 1 must be included in the set of voters") + + // Ensure that with a valid config no exception is thrown + propertiesFile.setProperty(KafkaConfig.NodeIdProp, "2") + KafkaConfig.fromProps(propertiesFile) + } + + @Test + def testColocatedRoleNodeIdValidation(): Unit = { + // Ensure that validation is happening at startup to check that colocated processes use their node.id as a voter in controller.quorum.voters + val propertiesFile = new Properties + propertiesFile.setProperty(KafkaConfig.ProcessRolesProp, "controller,broker") + propertiesFile.setProperty(KafkaConfig.NodeIdProp, "1") + propertiesFile.setProperty(KafkaConfig.QuorumVotersProp, "2@localhost:9092") + setListenerProps(propertiesFile) + assertBadConfigContainingMessage(propertiesFile, + "If process.roles contains the 'controller' role, the node id 1 must be included in the set of voters") + + // Ensure that with a valid config no exception is thrown + propertiesFile.setProperty(KafkaConfig.NodeIdProp, "2") + KafkaConfig.fromProps(propertiesFile) + } + + @Test + def testMustContainQuorumVotersIfUsingProcessRoles(): Unit = { + // Ensure that validation is happening at startup to check that if process.roles is set controller.quorum.voters is not empty + val propertiesFile = new Properties + propertiesFile.setProperty(KafkaConfig.ProcessRolesProp, "controller,broker") + propertiesFile.setProperty(KafkaConfig.NodeIdProp, "1") + propertiesFile.setProperty(KafkaConfig.QuorumVotersProp, "") + setListenerProps(propertiesFile) + assertBadConfigContainingMessage(propertiesFile, + "If using process.roles, controller.quorum.voters must contain a parseable set of voters.") + + // Ensure that if neither process.roles nor controller.quorum.voters is populated, then an exception is thrown if zookeeper.connect is not defined + propertiesFile.setProperty(KafkaConfig.ProcessRolesProp, "") + assertBadConfigContainingMessage(propertiesFile, + "Missing required configuration `zookeeper.connect` which has no default value.") + + // Ensure that no exception is thrown once zookeeper.connect is defined (and we clear controller.listener.names) + propertiesFile.setProperty(KafkaConfig.ZkConnectProp, "localhost:2181") + propertiesFile.setProperty(KafkaConfig.ControllerListenerNamesProp, "") + KafkaConfig.fromProps(propertiesFile) + } + + private def setListenerProps(props: Properties): Unit = { + val hasBrokerRole = props.getProperty(KafkaConfig.ProcessRolesProp).contains("broker") + val hasControllerRole = props.getProperty(KafkaConfig.ProcessRolesProp).contains("controller") + val controllerListener = "SASL_PLAINTEXT://localhost:9092" + val brokerListener = "PLAINTEXT://localhost:9093" + + if (hasBrokerRole || hasControllerRole) { // KRaft + props.setProperty(KafkaConfig.ControllerListenerNamesProp, "SASL_PLAINTEXT") + if (hasBrokerRole && hasControllerRole) { + props.setProperty(KafkaConfig.ListenersProp, s"$brokerListener,$controllerListener") + } else if (hasControllerRole) { + 
props.setProperty(KafkaConfig.ListenersProp, controllerListener) + } else if (hasBrokerRole) { + props.setProperty(KafkaConfig.ListenersProp, brokerListener) + } + } else { // ZK-based + props.setProperty(KafkaConfig.ListenersProp, brokerListener) + } + if (!(hasControllerRole & !hasBrokerRole)) { // not controller-only + props.setProperty(KafkaConfig.InterBrokerListenerNameProp, "PLAINTEXT") + props.setProperty(KafkaConfig.AdvertisedListenersProp, "PLAINTEXT://localhost:9092") + } + } + + @Test + def testKafkaSslPasswords(): Unit = { + val propertiesFile = prepareDefaultConfig() + val config = KafkaConfig.fromProps(Kafka.getPropsFromArgs(Array(propertiesFile, "--override", "ssl.keystore.password=keystore_password", + "--override", "ssl.key.password=key_password", + "--override", "ssl.truststore.password=truststore_password", + "--override", "ssl.keystore.certificate.chain=certificate_chain", + "--override", "ssl.keystore.key=private_key", + "--override", "ssl.truststore.certificates=truststore_certificates"))) + assertEquals(Password.HIDDEN, config.getPassword(KafkaConfig.SslKeyPasswordProp).toString) + assertEquals(Password.HIDDEN, config.getPassword(KafkaConfig.SslKeystorePasswordProp).toString) + assertEquals(Password.HIDDEN, config.getPassword(KafkaConfig.SslTruststorePasswordProp).toString) + assertEquals(Password.HIDDEN, config.getPassword(KafkaConfig.SslKeystoreKeyProp).toString) + assertEquals(Password.HIDDEN, config.getPassword(KafkaConfig.SslKeystoreCertificateChainProp).toString) + assertEquals(Password.HIDDEN, config.getPassword(KafkaConfig.SslTruststoreCertificatesProp).toString) + + assertEquals("key_password", config.getPassword(KafkaConfig.SslKeyPasswordProp).value) + assertEquals("keystore_password", config.getPassword(KafkaConfig.SslKeystorePasswordProp).value) + assertEquals("truststore_password", config.getPassword(KafkaConfig.SslTruststorePasswordProp).value) + assertEquals("private_key", config.getPassword(KafkaConfig.SslKeystoreKeyProp).value) + assertEquals("certificate_chain", config.getPassword(KafkaConfig.SslKeystoreCertificateChainProp).value) + assertEquals("truststore_certificates", config.getPassword(KafkaConfig.SslTruststoreCertificatesProp).value) + } + + @Test + def testKafkaSslPasswordsWithSymbols(): Unit = { + val password = "=!#-+!?*/\"\'^%$=\\.,@:;=" + val propertiesFile = prepareDefaultConfig() + val config = KafkaConfig.fromProps(Kafka.getPropsFromArgs(Array(propertiesFile, + "--override", "ssl.keystore.password=" + password, + "--override", "ssl.key.password=" + password, + "--override", "ssl.truststore.password=" + password))) + assertEquals(Password.HIDDEN, config.getPassword(KafkaConfig.SslKeyPasswordProp).toString) + assertEquals(Password.HIDDEN, config.getPassword(KafkaConfig.SslKeystorePasswordProp).toString) + assertEquals(Password.HIDDEN, config.getPassword(KafkaConfig.SslTruststorePasswordProp).toString) + + assertEquals(password, config.getPassword(KafkaConfig.SslKeystorePasswordProp).value) + assertEquals(password, config.getPassword(KafkaConfig.SslKeyPasswordProp).value) + assertEquals(password, config.getPassword(KafkaConfig.SslTruststorePasswordProp).value) + } + + private val booleanPropValueToSet = true + private val stringPropValueToSet = "foo" + private val passwordPropValueToSet = "ThePa$$word!" 
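  // The testZk* cases below all exercise the same pattern: each zookeeper.ssl.* broker config has a
  // corresponding ZooKeeper client system property (for example zookeeper.ssl.client.enable maps to
  // zookeeper.client.secure), and setting either one must be reflected in the effective KafkaConfig
  // value, while only an explicitly set broker config shows up in originals/values.
  // Illustrative sketch (not an additional test): running the broker JVM with
  // -Dzookeeper.client.secure=true yields config.zkSslClientEnable == true even though
  // "zookeeper.ssl.client.enable" never appears in the properties file.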
+ private val listPropValueToSet = List("A", "B") + + @Test + def testZkSslClientEnable(): Unit = { + testZkConfig(KafkaConfig.ZkSslClientEnableProp, "zookeeper.ssl.client.enable", + "zookeeper.client.secure", booleanPropValueToSet, config => Some(config.zkSslClientEnable), booleanPropValueToSet, Some(false)) + } + + @Test + def testZkSslKeyStoreLocation(): Unit = { + testZkConfig(KafkaConfig.ZkSslKeyStoreLocationProp, "zookeeper.ssl.keystore.location", + "zookeeper.ssl.keyStore.location", stringPropValueToSet, config => config.zkSslKeyStoreLocation, stringPropValueToSet) + } + + @Test + def testZkSslTrustStoreLocation(): Unit = { + testZkConfig(KafkaConfig.ZkSslTrustStoreLocationProp, "zookeeper.ssl.truststore.location", + "zookeeper.ssl.trustStore.location", stringPropValueToSet, config => config.zkSslTrustStoreLocation, stringPropValueToSet) + } + + @Test + def testZookeeperKeyStorePassword(): Unit = { + testZkConfig(KafkaConfig.ZkSslKeyStorePasswordProp, "zookeeper.ssl.keystore.password", + "zookeeper.ssl.keyStore.password", passwordPropValueToSet, config => config.zkSslKeyStorePassword, new Password(passwordPropValueToSet)) + } + + @Test + def testZookeeperTrustStorePassword(): Unit = { + testZkConfig(KafkaConfig.ZkSslTrustStorePasswordProp, "zookeeper.ssl.truststore.password", + "zookeeper.ssl.trustStore.password", passwordPropValueToSet, config => config.zkSslTrustStorePassword, new Password(passwordPropValueToSet)) + } + + @Test + def testZkSslKeyStoreType(): Unit = { + testZkConfig(KafkaConfig.ZkSslKeyStoreTypeProp, "zookeeper.ssl.keystore.type", + "zookeeper.ssl.keyStore.type", stringPropValueToSet, config => config.zkSslKeyStoreType, stringPropValueToSet) + } + + @Test + def testZkSslTrustStoreType(): Unit = { + testZkConfig(KafkaConfig.ZkSslTrustStoreTypeProp, "zookeeper.ssl.truststore.type", + "zookeeper.ssl.trustStore.type", stringPropValueToSet, config => config.zkSslTrustStoreType, stringPropValueToSet) + } + + @Test + def testZkSslProtocol(): Unit = { + testZkConfig(KafkaConfig.ZkSslProtocolProp, "zookeeper.ssl.protocol", + "zookeeper.ssl.protocol", stringPropValueToSet, config => Some(config.ZkSslProtocol), stringPropValueToSet, Some("TLSv1.2")) + } + + @Test + def testZkSslEnabledProtocols(): Unit = { + testZkConfig(KafkaConfig.ZkSslEnabledProtocolsProp, "zookeeper.ssl.enabled.protocols", + "zookeeper.ssl.enabledProtocols", listPropValueToSet.mkString(","), config => config.ZkSslEnabledProtocols, listPropValueToSet.asJava) + } + + @Test + def testZkSslCipherSuites(): Unit = { + testZkConfig(KafkaConfig.ZkSslCipherSuitesProp, "zookeeper.ssl.cipher.suites", + "zookeeper.ssl.ciphersuites", listPropValueToSet.mkString(","), config => config.ZkSslCipherSuites, listPropValueToSet.asJava) + } + + @Test + def testZkSslEndpointIdentificationAlgorithm(): Unit = { + // this property is different than the others + // because the system property values and the Kafka property values don't match + val kafkaPropName = KafkaConfig.ZkSslEndpointIdentificationAlgorithmProp + assertEquals("zookeeper.ssl.endpoint.identification.algorithm", kafkaPropName) + val sysProp = "zookeeper.ssl.hostnameVerification" + val expectedDefaultValue = "HTTPS" + val propertiesFile = prepareDefaultConfig() + // first make sure there is the correct default value + val emptyConfig = KafkaConfig.fromProps(Kafka.getPropsFromArgs(Array(propertiesFile))) + assertNull(emptyConfig.originals.get(kafkaPropName)) // doesn't appear in the originals + assertEquals(expectedDefaultValue, 
emptyConfig.values.get(kafkaPropName)) // but default value appears in the values + assertEquals(expectedDefaultValue, emptyConfig.ZkSslEndpointIdentificationAlgorithm) // and has the correct default value + // next set system property alone + Map("true" -> "HTTPS", "false" -> "").foreach { case (sysPropValue, expected) => { + try { + System.setProperty(sysProp, sysPropValue) + val config = KafkaConfig.fromProps(Kafka.getPropsFromArgs(Array(propertiesFile))) + assertNull(config.originals.get(kafkaPropName)) // doesn't appear in the originals + assertEquals(expectedDefaultValue, config.values.get(kafkaPropName)) // default value appears in the values + assertEquals(expected, config.ZkSslEndpointIdentificationAlgorithm) // system property impacts the ultimate value of the property + } finally { + System.clearProperty(sysProp) + } + }} + // finally set Kafka config alone + List("https", "").foreach(expected => { + val config = KafkaConfig.fromProps(Kafka.getPropsFromArgs(Array(propertiesFile, "--override", s"$kafkaPropName=${expected}"))) + assertEquals(expected, config.originals.get(kafkaPropName)) // appears in the originals + assertEquals(expected, config.values.get(kafkaPropName)) // appears in the values + assertEquals(expected, config.ZkSslEndpointIdentificationAlgorithm) // is the ultimate value + }) + } + + @Test + def testZkSslCrlEnable(): Unit = { + testZkConfig(KafkaConfig.ZkSslCrlEnableProp, "zookeeper.ssl.crl.enable", + "zookeeper.ssl.crl", booleanPropValueToSet, config => Some(config.ZkSslCrlEnable), booleanPropValueToSet, Some(false)) + } + + @Test + def testZkSslOcspEnable(): Unit = { + testZkConfig(KafkaConfig.ZkSslOcspEnableProp, "zookeeper.ssl.ocsp.enable", + "zookeeper.ssl.ocsp", booleanPropValueToSet, config => Some(config.ZkSslOcspEnable), booleanPropValueToSet, Some(false)) + } + + @Test + def testConnectionsMaxReauthMsDefault(): Unit = { + val propertiesFile = prepareDefaultConfig() + val config = KafkaConfig.fromProps(Kafka.getPropsFromArgs(Array(propertiesFile))) + assertEquals(0L, config.valuesWithPrefixOverride("sasl_ssl.oauthbearer.").get(BrokerSecurityConfigs.CONNECTIONS_MAX_REAUTH_MS).asInstanceOf[Long]) + } + + @Test + def testConnectionsMaxReauthMsExplicit(): Unit = { + val propertiesFile = prepareDefaultConfig() + val expected = 3600000 + val config = KafkaConfig.fromProps(Kafka.getPropsFromArgs(Array(propertiesFile, "--override", s"sasl_ssl.oauthbearer.connections.max.reauth.ms=${expected}"))) + assertEquals(expected, config.valuesWithPrefixOverride("sasl_ssl.oauthbearer.").get(BrokerSecurityConfigs.CONNECTIONS_MAX_REAUTH_MS).asInstanceOf[Long]) + } + + private def testZkConfig[T, U](kafkaPropName: String, + expectedKafkaPropName: String, + sysPropName: String, + propValueToSet: T, + getPropValueFrom: (KafkaConfig) => Option[T], + expectedPropertyValue: U, + expectedDefaultValue: Option[T] = None): Unit = { + assertEquals(expectedKafkaPropName, kafkaPropName) + val propertiesFile = prepareDefaultConfig() + // first make sure there is the correct default value (if any) + val emptyConfig = KafkaConfig.fromProps(Kafka.getPropsFromArgs(Array(propertiesFile))) + assertNull(emptyConfig.originals.get(kafkaPropName)) // doesn't appear in the originals + if (expectedDefaultValue.isDefined) { + // confirm default value behavior + assertEquals(expectedDefaultValue.get, emptyConfig.values.get(kafkaPropName)) // default value appears in the values + assertEquals(expectedDefaultValue.get, getPropValueFrom(emptyConfig).get) // default value appears in the property + } else 
{ + // confirm no default value behavior + assertNull(emptyConfig.values.get(kafkaPropName)) // doesn't appear in the values + assertEquals(None, getPropValueFrom(emptyConfig)) // has no default value + } + // next set system property alone + try { + System.setProperty(sysPropName, s"$propValueToSet") + // need to create a new Kafka config for the system property to be recognized + val config = KafkaConfig.fromProps(Kafka.getPropsFromArgs(Array(propertiesFile))) + assertNull(config.originals.get(kafkaPropName)) // doesn't appear in the originals + // confirm default value (if any) overridden by system property + if (expectedDefaultValue.isDefined) + assertEquals(expectedDefaultValue.get, config.values.get(kafkaPropName)) // default value (different from system property) appears in the values + else + assertNull(config.values.get(kafkaPropName)) // doesn't appear in the values + // confirm system property appears in the property + assertEquals(Some(expectedPropertyValue), getPropValueFrom(config)) + } finally { + System.clearProperty(sysPropName) + } + // finally set Kafka config alone + val config = KafkaConfig.fromProps(Kafka.getPropsFromArgs(Array(propertiesFile, "--override", s"$kafkaPropName=${propValueToSet}"))) + assertEquals(expectedPropertyValue, config.values.get(kafkaPropName)) // appears in the values + assertEquals(Some(expectedPropertyValue), getPropValueFrom(config)) // appears in the property + } + + def prepareDefaultConfig(): String = { + prepareConfig(Array("broker.id=1", "zookeeper.connect=somewhere")) + } + + def prepareConfig(lines : Array[String]): String = { + val file = File.createTempFile("kafkatest", ".properties") + file.deleteOnExit() + + val writer = Files.newOutputStream(file.toPath) + try { + lines.foreach { l => + writer.write(l.getBytes) + writer.write("\n".getBytes) + } + file.getAbsolutePath + } finally writer.close() + } +} diff --git a/core/src/test/scala/unit/kafka/admin/AclCommandTest.scala b/core/src/test/scala/unit/kafka/admin/AclCommandTest.scala new file mode 100644 index 0000000..7cd5a18 --- /dev/null +++ b/core/src/test/scala/unit/kafka/admin/AclCommandTest.scala @@ -0,0 +1,332 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package kafka.admin + +import java.io.{File, PrintWriter} +import java.util.Properties +import javax.management.InstanceAlreadyExistsException +import kafka.admin.AclCommand.AclCommandOptions +import kafka.security.authorizer.{AclAuthorizer, AclEntry} +import kafka.server.{KafkaConfig, KafkaServer} +import kafka.utils.{Exit, LogCaptureAppender, Logging, TestUtils} +import kafka.server.QuorumTestHarness +import org.apache.kafka.common.acl.{AccessControlEntry, AclOperation, AclPermissionType} +import org.apache.kafka.common.acl.AclOperation._ +import org.apache.kafka.common.acl.AclPermissionType._ +import org.apache.kafka.common.resource.{PatternType, Resource, ResourcePattern} +import org.apache.kafka.common.resource.ResourceType._ +import org.apache.kafka.common.network.ListenerName +import org.apache.kafka.common.resource.PatternType.{LITERAL, PREFIXED} +import org.apache.kafka.common.security.auth.{KafkaPrincipal, SecurityProtocol} +import org.apache.kafka.common.utils.{AppInfoParser, SecurityUtils} +import org.apache.kafka.server.authorizer.Authorizer +import org.apache.log4j.Level +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.{AfterEach, BeforeEach, Test, TestInfo} + +class AclCommandTest extends QuorumTestHarness with Logging { + + var servers: Seq[KafkaServer] = Seq() + + private val principal: KafkaPrincipal = SecurityUtils.parseKafkaPrincipal("User:test2") + private val Users = Set(SecurityUtils.parseKafkaPrincipal("User:CN=writeuser,OU=Unknown,O=Unknown,L=Unknown,ST=Unknown,C=Unknown"), + principal, SecurityUtils.parseKafkaPrincipal("""User:CN=\#User with special chars in CN : (\, \+ \" \\ \< \> \; ')""")) + private val Hosts = Set("host1", "host2") + private val AllowHostCommand = Array("--allow-host", "host1", "--allow-host", "host2") + private val DenyHostCommand = Array("--deny-host", "host1", "--deny-host", "host2") + + private val ClusterResource = new ResourcePattern(CLUSTER, Resource.CLUSTER_NAME, LITERAL) + private val TopicResources = Set(new ResourcePattern(TOPIC, "test-1", LITERAL), new ResourcePattern(TOPIC, "test-2", LITERAL)) + private val GroupResources = Set(new ResourcePattern(GROUP, "testGroup-1", LITERAL), new ResourcePattern(GROUP, "testGroup-2", LITERAL)) + private val TransactionalIdResources = Set(new ResourcePattern(TRANSACTIONAL_ID, "t0", LITERAL), new ResourcePattern(TRANSACTIONAL_ID, "t1", LITERAL)) + private val TokenResources = Set(new ResourcePattern(DELEGATION_TOKEN, "token1", LITERAL), new ResourcePattern(DELEGATION_TOKEN, "token2", LITERAL)) + + private val ResourceToCommand = Map[Set[ResourcePattern], Array[String]]( + TopicResources -> Array("--topic", "test-1", "--topic", "test-2"), + Set(ClusterResource) -> Array("--cluster"), + GroupResources -> Array("--group", "testGroup-1", "--group", "testGroup-2"), + TransactionalIdResources -> Array("--transactional-id", "t0", "--transactional-id", "t1"), + TokenResources -> Array("--delegation-token", "token1", "--delegation-token", "token2") + ) + + private val ResourceToOperations = Map[Set[ResourcePattern], (Set[AclOperation], Array[String])]( + TopicResources -> (Set(READ, WRITE, CREATE, DESCRIBE, DELETE, DESCRIBE_CONFIGS, ALTER_CONFIGS, ALTER), + Array("--operation", "Read" , "--operation", "Write", "--operation", "Create", "--operation", "Describe", "--operation", "Delete", + "--operation", "DescribeConfigs", "--operation", "AlterConfigs", "--operation", "Alter")), + Set(ClusterResource) -> (Set(CREATE, CLUSTER_ACTION, DESCRIBE_CONFIGS, ALTER_CONFIGS, 
IDEMPOTENT_WRITE, ALTER, DESCRIBE), + Array("--operation", "Create", "--operation", "ClusterAction", "--operation", "DescribeConfigs", + "--operation", "AlterConfigs", "--operation", "IdempotentWrite", "--operation", "Alter", "--operation", "Describe")), + GroupResources -> (Set(READ, DESCRIBE, DELETE), Array("--operation", "Read", "--operation", "Describe", "--operation", "Delete")), + TransactionalIdResources -> (Set(DESCRIBE, WRITE), Array("--operation", "Describe", "--operation", "Write")), + TokenResources -> (Set(DESCRIBE), Array("--operation", "Describe")) + ) + + private def ProducerResourceToAcls(enableIdempotence: Boolean = false) = Map[Set[ResourcePattern], Set[AccessControlEntry]]( + TopicResources -> AclCommand.getAcls(Users, ALLOW, Set(WRITE, DESCRIBE, CREATE), Hosts), + TransactionalIdResources -> AclCommand.getAcls(Users, ALLOW, Set(WRITE, DESCRIBE), Hosts), + Set(ClusterResource) -> AclCommand.getAcls(Users, ALLOW, + Set(if (enableIdempotence) Some(IDEMPOTENT_WRITE) else None).flatten, Hosts) + ) + + private val ConsumerResourceToAcls = Map[Set[ResourcePattern], Set[AccessControlEntry]]( + TopicResources -> AclCommand.getAcls(Users, ALLOW, Set(READ, DESCRIBE), Hosts), + GroupResources -> AclCommand.getAcls(Users, ALLOW, Set(READ), Hosts) + ) + + private val CmdToResourcesToAcl = Map[Array[String], Map[Set[ResourcePattern], Set[AccessControlEntry]]]( + Array[String]("--producer") -> ProducerResourceToAcls(), + Array[String]("--producer", "--idempotent") -> ProducerResourceToAcls(enableIdempotence = true), + Array[String]("--consumer") -> ConsumerResourceToAcls, + Array[String]("--producer", "--consumer") -> ConsumerResourceToAcls.map { case (k, v) => k -> (v ++ + ProducerResourceToAcls().getOrElse(k, Set.empty[AccessControlEntry])) }, + Array[String]("--producer", "--idempotent", "--consumer") -> ConsumerResourceToAcls.map { case (k, v) => k -> (v ++ + ProducerResourceToAcls(enableIdempotence = true).getOrElse(k, Set.empty[AccessControlEntry])) } + ) + + private var brokerProps: Properties = _ + private var zkArgs: Array[String] = _ + private var adminArgs: Array[String] = _ + + @BeforeEach + override def setUp(testInfo: TestInfo): Unit = { + super.setUp(testInfo) + + brokerProps = TestUtils.createBrokerConfig(0, zkConnect) + brokerProps.put(KafkaConfig.AuthorizerClassNameProp, classOf[AclAuthorizer].getName) + brokerProps.put(AclAuthorizer.SuperUsersProp, "User:ANONYMOUS") + + zkArgs = Array("--authorizer-properties", "zookeeper.connect=" + zkConnect) + } + + @AfterEach + override def tearDown(): Unit = { + TestUtils.shutdownServers(servers) + super.tearDown() + } + + @Test + def testAclCliWithAuthorizer(): Unit = { + testAclCli(zkArgs) + } + + @Test + def testAclCliWithAdminAPI(): Unit = { + createServer() + testAclCli(adminArgs) + } + + private def createServer(commandConfig: Option[File] = None): Unit = { + servers = Seq(TestUtils.createServer(KafkaConfig.fromProps(brokerProps))) + val listenerName = ListenerName.forSecurityProtocol(SecurityProtocol.PLAINTEXT) + + var adminArgs = Array("--bootstrap-server", TestUtils.bootstrapServers(servers, listenerName)) + if (commandConfig.isDefined) { + adminArgs ++= Array("--command-config", commandConfig.get.getAbsolutePath) + } + this.adminArgs = adminArgs + } + + private def callMain(args: Array[String]): (String, String) = { + TestUtils.grabConsoleOutputAndError(AclCommand.main(args)) + } + + private def testAclCli(cmdArgs: Array[String]): Unit = { + for ((resources, resourceCmd) <- ResourceToCommand) { + for (permissionType 
<- Set(ALLOW, DENY)) { + val operationToCmd = ResourceToOperations(resources) + val (acls, cmd) = getAclToCommand(permissionType, operationToCmd._1) + val (addOut, addErr) = callMain(cmdArgs ++ cmd ++ resourceCmd ++ operationToCmd._2 :+ "--add") + assertOutputContains("Adding ACLs", resources, resourceCmd, addOut) + assertOutputContains("Current ACLs", resources, resourceCmd, addOut) + assertEquals("", addErr) + + for (resource <- resources) { + withAuthorizer() { authorizer => + TestUtils.waitAndVerifyAcls(acls, authorizer, resource) + } + } + + val (listOut, listErr) = callMain(cmdArgs :+ "--list") + assertOutputContains("Current ACLs", resources, resourceCmd, listOut) + assertEquals("", listErr) + + testRemove(cmdArgs, resources, resourceCmd) + } + } + } + + private def assertOutputContains(prefix: String, resources: Set[ResourcePattern], resourceCmd: Array[String], output: String): Unit = { + resources.foreach { resource => + val resourceType = resource.resourceType.toString + (if (resource == ClusterResource) Array("kafka-cluster") else resourceCmd.filter(!_.startsWith("--"))).foreach { name => + val expected = s"$prefix for resource `ResourcePattern(resourceType=$resourceType, name=$name, patternType=LITERAL)`:" + assertTrue(output.contains(expected), s"Substring $expected not in output:\n$output") + } + } + } + + @Test + def testProducerConsumerCliWithAuthorizer(): Unit = { + testProducerConsumerCli(zkArgs) + } + + @Test + def testProducerConsumerCliWithAdminAPI(): Unit = { + createServer() + testProducerConsumerCli(adminArgs) + } + + @Test + def testAclCliWithClientId(): Unit = { + val adminClientConfig = TestUtils.tempFile() + val pw = new PrintWriter(adminClientConfig) + pw.println("client.id=my-client") + pw.close() + + createServer(Some(adminClientConfig)) + + val appender = LogCaptureAppender.createAndRegister() + val previousLevel = LogCaptureAppender.setClassLoggerLevel(classOf[AppInfoParser], Level.WARN) + try { + testAclCli(adminArgs) + } finally { + LogCaptureAppender.setClassLoggerLevel(classOf[AppInfoParser], previousLevel) + LogCaptureAppender.unregister(appender) + } + val warning = appender.getMessages.find(e => e.getLevel == Level.WARN && + e.getThrowableInformation != null && + e.getThrowableInformation.getThrowable.getClass.getName == classOf[InstanceAlreadyExistsException].getName) + assertFalse(warning.isDefined, "There should be no warnings about multiple registration of mbeans") + + } + + private def testProducerConsumerCli(cmdArgs: Array[String]): Unit = { + for ((cmd, resourcesToAcls) <- CmdToResourcesToAcl) { + val resourceCommand: Array[String] = resourcesToAcls.keys.map(ResourceToCommand).foldLeft(Array[String]())(_ ++ _) + callMain(cmdArgs ++ getCmd(ALLOW) ++ resourceCommand ++ cmd :+ "--add") + for ((resources, acls) <- resourcesToAcls) { + for (resource <- resources) { + withAuthorizer() { authorizer => + TestUtils.waitAndVerifyAcls(acls, authorizer, resource) + } + } + } + testRemove(cmdArgs, resourcesToAcls.keys.flatten.toSet, resourceCommand ++ cmd) + } + } + + @Test + def testAclsOnPrefixedResourcesWithAuthorizer(): Unit = { + testAclsOnPrefixedResources(zkArgs) + } + + @Test + def testAclsOnPrefixedResourcesWithAdminAPI(): Unit = { + createServer() + testAclsOnPrefixedResources(adminArgs) + } + + private def testAclsOnPrefixedResources(cmdArgs: Array[String]): Unit = { + val cmd = Array("--allow-principal", principal.toString, "--producer", "--topic", "Test-", "--resource-pattern-type", "Prefixed") + + callMain(cmdArgs ++ cmd :+ "--add") + + 
withAuthorizer() { authorizer => + val writeAcl = new AccessControlEntry(principal.toString, AclEntry.WildcardHost, WRITE, ALLOW) + val describeAcl = new AccessControlEntry(principal.toString, AclEntry.WildcardHost, DESCRIBE, ALLOW) + val createAcl = new AccessControlEntry(principal.toString, AclEntry.WildcardHost, CREATE, ALLOW) + TestUtils.waitAndVerifyAcls(Set(writeAcl, describeAcl, createAcl), authorizer, + new ResourcePattern(TOPIC, "Test-", PREFIXED)) + } + + callMain(cmdArgs ++ cmd :+ "--remove" :+ "--force") + + withAuthorizer() { authorizer => + TestUtils.waitAndVerifyAcls(Set.empty[AccessControlEntry], authorizer, new ResourcePattern(CLUSTER, "kafka-cluster", LITERAL)) + TestUtils.waitAndVerifyAcls(Set.empty[AccessControlEntry], authorizer, new ResourcePattern(TOPIC, "Test-", PREFIXED)) + } + } + + @Test + def testInvalidAuthorizerProperty(): Unit = { + val args = Array("--authorizer-properties", "zookeeper.connect " + zkConnect) + val aclCommandService = new AclCommand.AuthorizerService(classOf[AclAuthorizer].getName, + new AclCommandOptions(args)) + assertThrows(classOf[IllegalArgumentException], () => aclCommandService.listAcls()) + } + + @Test + def testPatternTypes(): Unit = { + Exit.setExitProcedure { (status, _) => + if (status == 1) + throw new RuntimeException("Exiting command") + else + throw new AssertionError(s"Unexpected exit with status $status") + } + def verifyPatternType(cmd: Array[String], isValid: Boolean): Unit = { + if (isValid) + callMain(cmd) + else + assertThrows(classOf[RuntimeException], () => callMain(cmd)) + } + try { + PatternType.values.foreach { patternType => + val addCmd = zkArgs ++ Array("--allow-principal", principal.toString, "--producer", "--topic", "Test", + "--add", "--resource-pattern-type", patternType.toString) + verifyPatternType(addCmd, isValid = patternType.isSpecific) + val listCmd = zkArgs ++ Array("--topic", "Test", "--list", "--resource-pattern-type", patternType.toString) + verifyPatternType(listCmd, isValid = patternType != PatternType.UNKNOWN) + val removeCmd = zkArgs ++ Array("--topic", "Test", "--force", "--remove", "--resource-pattern-type", patternType.toString) + verifyPatternType(removeCmd, isValid = patternType != PatternType.UNKNOWN) + + } + } finally { + Exit.resetExitProcedure() + } + } + + private def testRemove(cmdArgs: Array[String], resources: Set[ResourcePattern], resourceCmd: Array[String]): Unit = { + val (out, err) = callMain(cmdArgs ++ resourceCmd :+ "--remove" :+ "--force") + assertEquals("", out) + assertEquals("", err) + for (resource <- resources) { + withAuthorizer() { authorizer => + TestUtils.waitAndVerifyAcls(Set.empty[AccessControlEntry], authorizer, resource) + } + } + } + + private def getAclToCommand(permissionType: AclPermissionType, operations: Set[AclOperation]): (Set[AccessControlEntry], Array[String]) = { + (AclCommand.getAcls(Users, permissionType, operations, Hosts), getCmd(permissionType)) + } + + private def getCmd(permissionType: AclPermissionType): Array[String] = { + val principalCmd = if (permissionType == ALLOW) "--allow-principal" else "--deny-principal" + val cmd = if (permissionType == ALLOW) AllowHostCommand else DenyHostCommand + + Users.foldLeft(cmd) ((cmd, user) => cmd ++ Array(principalCmd, user.toString)) + } + + private def withAuthorizer()(f: Authorizer => Unit): Unit = { + val kafkaConfig = KafkaConfig.fromProps(brokerProps, doLog = false) + val authZ = new AclAuthorizer + try { + authZ.configure(kafkaConfig.originals) + f(authZ) + } finally authZ.close() + } +} diff --git 
a/core/src/test/scala/unit/kafka/admin/AddPartitionsTest.scala b/core/src/test/scala/unit/kafka/admin/AddPartitionsTest.scala new file mode 100755 index 0000000..cd0bb1b --- /dev/null +++ b/core/src/test/scala/unit/kafka/admin/AddPartitionsTest.scala @@ -0,0 +1,187 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.admin + +import java.util.Optional +import kafka.controller.ReplicaAssignment +import kafka.server.BaseRequestTest +import kafka.utils.TestUtils +import kafka.utils.TestUtils._ +import org.apache.kafka.common.TopicPartition +import org.apache.kafka.common.errors.InvalidReplicaAssignmentException +import org.apache.kafka.common.requests.MetadataResponse.TopicMetadata +import org.apache.kafka.common.requests.{MetadataRequest, MetadataResponse} +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.{BeforeEach, Test, TestInfo} + +import scala.jdk.CollectionConverters._ + +class AddPartitionsTest extends BaseRequestTest { + + override def brokerCount: Int = 4 + + val partitionId = 0 + + val topic1 = "new-topic1" + val topic1Assignment = Map(0 -> ReplicaAssignment(Seq(0,1), List(), List())) + val topic2 = "new-topic2" + val topic2Assignment = Map(0 -> ReplicaAssignment(Seq(1,2), List(), List())) + val topic3 = "new-topic3" + val topic3Assignment = Map(0 -> ReplicaAssignment(Seq(2,3,0,1), List(), List())) + val topic4 = "new-topic4" + val topic4Assignment = Map(0 -> ReplicaAssignment(Seq(0,3), List(), List())) + val topic5 = "new-topic5" + val topic5Assignment = Map(1 -> ReplicaAssignment(Seq(0,1), List(), List())) + + @BeforeEach + override def setUp(testInfo: TestInfo): Unit = { + super.setUp(testInfo) + + createTopic(topic1, partitionReplicaAssignment = topic1Assignment.map { case (k, v) => k -> v.replicas }) + createTopic(topic2, partitionReplicaAssignment = topic2Assignment.map { case (k, v) => k -> v.replicas }) + createTopic(topic3, partitionReplicaAssignment = topic3Assignment.map { case (k, v) => k -> v.replicas }) + createTopic(topic4, partitionReplicaAssignment = topic4Assignment.map { case (k, v) => k -> v.replicas }) + } + + @Test + def testWrongReplicaCount(): Unit = { + assertThrows(classOf[InvalidReplicaAssignmentException], () => adminZkClient.addPartitions(topic1, topic1Assignment, adminZkClient.getBrokerMetadatas(), 2, + Some(Map(0 -> Seq(0, 1), 1 -> Seq(0, 1, 2))))) + } + + @Test + def testMissingPartition0(): Unit = { + val e = assertThrows(classOf[AdminOperationException], () => adminZkClient.addPartitions(topic5, topic5Assignment, adminZkClient.getBrokerMetadatas(), 2, + Some(Map(1 -> Seq(0, 1), 2 -> Seq(0, 1, 2))))) + assertTrue(e.getMessage.contains("Unexpected existing replica assignment for topic 'new-topic5', partition id 0 is missing")) + } + + @Test + def 
testIncrementPartitions(): Unit = { + adminZkClient.addPartitions(topic1, topic1Assignment, adminZkClient.getBrokerMetadatas(), 3) + // wait until leader is elected + val leader1 = waitUntilLeaderIsElectedOrChanged(zkClient, topic1, 1) + val leader2 = waitUntilLeaderIsElectedOrChanged(zkClient, topic1, 2) + val leader1FromZk = zkClient.getLeaderForPartition(new TopicPartition(topic1, 1)).get + val leader2FromZk = zkClient.getLeaderForPartition(new TopicPartition(topic1, 2)).get + assertEquals(leader1, leader1FromZk) + assertEquals(leader2, leader2FromZk) + + // read metadata from a broker and verify the new topic partitions exist + TestUtils.waitForPartitionMetadata(servers, topic1, 1) + TestUtils.waitForPartitionMetadata(servers, topic1, 2) + val response = connectAndReceive[MetadataResponse]( + new MetadataRequest.Builder(Seq(topic1).asJava, false).build) + assertEquals(1, response.topicMetadata.size) + val partitions = response.topicMetadata.asScala.head.partitionMetadata.asScala.sortBy(_.partition) + assertEquals(3, partitions.size) + assertEquals(1, partitions(1).partition) + assertEquals(2, partitions(2).partition) + + for (partition <- partitions) { + val replicas = partition.replicaIds + assertEquals(2, replicas.size) + assertTrue(partition.leaderId.isPresent) + val leaderId = partition.leaderId.get + assertTrue(replicas.contains(leaderId)) + } + } + + @Test + def testManualAssignmentOfReplicas(): Unit = { + // Add 2 partitions + adminZkClient.addPartitions(topic2, topic2Assignment, adminZkClient.getBrokerMetadatas(), 3, + Some(Map(0 -> Seq(1, 2), 1 -> Seq(0, 1), 2 -> Seq(2, 3)))) + // wait until leader is elected + val leader1 = waitUntilLeaderIsElectedOrChanged(zkClient, topic2, 1) + val leader2 = waitUntilLeaderIsElectedOrChanged(zkClient, topic2, 2) + val leader1FromZk = zkClient.getLeaderForPartition(new TopicPartition(topic2, 1)).get + val leader2FromZk = zkClient.getLeaderForPartition(new TopicPartition(topic2, 2)).get + assertEquals(leader1, leader1FromZk) + assertEquals(leader2, leader2FromZk) + + // read metadata from a broker and verify the new topic partitions exist + TestUtils.waitForPartitionMetadata(servers, topic2, 1) + TestUtils.waitForPartitionMetadata(servers, topic2, 2) + val response = connectAndReceive[MetadataResponse]( + new MetadataRequest.Builder(Seq(topic2).asJava, false).build) + assertEquals(1, response.topicMetadata.size) + val topicMetadata = response.topicMetadata.asScala.head + val partitionMetadata = topicMetadata.partitionMetadata.asScala.sortBy(_.partition) + assertEquals(3, topicMetadata.partitionMetadata.size) + assertEquals(0, partitionMetadata(0).partition) + assertEquals(1, partitionMetadata(1).partition) + assertEquals(2, partitionMetadata(2).partition) + val replicas = partitionMetadata(1).replicaIds + assertEquals(2, replicas.size) + assertEquals(Set(0, 1), replicas.asScala.toSet) + } + + @Test + def testReplicaPlacementAllServers(): Unit = { + adminZkClient.addPartitions(topic3, topic3Assignment, adminZkClient.getBrokerMetadatas(), 7) + + // read metadata from a broker and verify the new topic partitions exist + TestUtils.waitForPartitionMetadata(servers, topic3, 1) + TestUtils.waitForPartitionMetadata(servers, topic3, 2) + TestUtils.waitForPartitionMetadata(servers, topic3, 3) + TestUtils.waitForPartitionMetadata(servers, topic3, 4) + TestUtils.waitForPartitionMetadata(servers, topic3, 5) + TestUtils.waitForPartitionMetadata(servers, topic3, 6) + + val response = connectAndReceive[MetadataResponse]( + new 
MetadataRequest.Builder(Seq(topic3).asJava, false).build) + assertEquals(1, response.topicMetadata.size) + val topicMetadata = response.topicMetadata.asScala.head + validateLeaderAndReplicas(topicMetadata, 0, 2, Set(2, 3, 0, 1)) + validateLeaderAndReplicas(topicMetadata, 1, 3, Set(3, 2, 0, 1)) + validateLeaderAndReplicas(topicMetadata, 2, 0, Set(0, 3, 1, 2)) + validateLeaderAndReplicas(topicMetadata, 3, 1, Set(1, 0, 2, 3)) + validateLeaderAndReplicas(topicMetadata, 4, 2, Set(2, 3, 0, 1)) + validateLeaderAndReplicas(topicMetadata, 5, 3, Set(3, 0, 1, 2)) + validateLeaderAndReplicas(topicMetadata, 6, 0, Set(0, 1, 2, 3)) + } + + @Test + def testReplicaPlacementPartialServers(): Unit = { + adminZkClient.addPartitions(topic2, topic2Assignment, adminZkClient.getBrokerMetadatas(), 3) + + // read metadata from a broker and verify the new topic partitions exist + TestUtils.waitForPartitionMetadata(servers, topic2, 1) + TestUtils.waitForPartitionMetadata(servers, topic2, 2) + + val response = connectAndReceive[MetadataResponse]( + new MetadataRequest.Builder(Seq(topic2).asJava, false).build) + assertEquals(1, response.topicMetadata.size) + val topicMetadata = response.topicMetadata.asScala.head + validateLeaderAndReplicas(topicMetadata, 0, 1, Set(1, 2)) + validateLeaderAndReplicas(topicMetadata, 1, 2, Set(0, 2)) + validateLeaderAndReplicas(topicMetadata, 2, 3, Set(1, 3)) + } + + def validateLeaderAndReplicas(metadata: TopicMetadata, partitionId: Int, expectedLeaderId: Int, + expectedReplicas: Set[Int]): Unit = { + val partitionOpt = metadata.partitionMetadata.asScala.find(_.partition == partitionId) + assertTrue(partitionOpt.isDefined, s"Partition $partitionId should exist") + val partition = partitionOpt.get + + assertEquals(Optional.of(expectedLeaderId), partition.leaderId, "Partition leader id should match") + assertEquals(expectedReplicas, partition.replicaIds.asScala.toSet, "Replica set should match") + } + +} diff --git a/core/src/test/scala/unit/kafka/admin/AdminRackAwareTest.scala b/core/src/test/scala/unit/kafka/admin/AdminRackAwareTest.scala new file mode 100644 index 0000000..d2d665b --- /dev/null +++ b/core/src/test/scala/unit/kafka/admin/AdminRackAwareTest.scala @@ -0,0 +1,225 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package kafka.admin + +import kafka.utils.Logging +import org.apache.kafka.common.errors.InvalidReplicationFactorException +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.Test + +import scala.collection.Map + +class AdminRackAwareTest extends RackAwareTest with Logging { + + @Test + def testGetRackAlternatedBrokerListAndAssignReplicasToBrokers(): Unit = { + val rackMap = Map(0 -> "rack1", 1 -> "rack3", 2 -> "rack3", 3 -> "rack2", 4 -> "rack2", 5 -> "rack1") + val newList = AdminUtils.getRackAlternatedBrokerList(rackMap) + assertEquals(List(0, 3, 1, 5, 4, 2), newList) + val anotherList = AdminUtils.getRackAlternatedBrokerList(rackMap.toMap - 5) + assertEquals(List(0, 3, 1, 4, 2), anotherList) + val assignment = AdminUtils.assignReplicasToBrokers(toBrokerMetadata(rackMap), 7, 3, 0, 0) + val expected = Map(0 -> List(0, 3, 1), + 1 -> List(3, 1, 5), + 2 -> List(1, 5, 4), + 3 -> List(5, 4, 2), + 4 -> List(4, 2, 0), + 5 -> List(2, 0, 3), + 6 -> List(0, 4, 2)) + assertEquals(expected, assignment) + } + + @Test + def testAssignmentWithRackAware(): Unit = { + val brokerRackMapping = Map(0 -> "rack1", 1 -> "rack2", 2 -> "rack2", 3 -> "rack3", 4 -> "rack3", 5 -> "rack1") + val numPartitions = 6 + val replicationFactor = 3 + val assignment = AdminUtils.assignReplicasToBrokers(toBrokerMetadata(brokerRackMapping), numPartitions, + replicationFactor, 2, 0) + checkReplicaDistribution(assignment, brokerRackMapping, brokerRackMapping.size, numPartitions, + replicationFactor) + } + + @Test + def testAssignmentWithRackAwareWithRandomStartIndex(): Unit = { + val brokerRackMapping = Map(0 -> "rack1", 1 -> "rack2", 2 -> "rack2", 3 -> "rack3", 4 -> "rack3", 5 -> "rack1") + val numPartitions = 6 + val replicationFactor = 3 + val assignment = AdminUtils.assignReplicasToBrokers(toBrokerMetadata(brokerRackMapping), numPartitions, + replicationFactor) + checkReplicaDistribution(assignment, brokerRackMapping, brokerRackMapping.size, numPartitions, + replicationFactor) + } + + @Test + def testAssignmentWithRackAwareWithUnevenReplicas(): Unit = { + val brokerRackMapping = Map(0 -> "rack1", 1 -> "rack2", 2 -> "rack2", 3 -> "rack3", 4 -> "rack3", 5 -> "rack1") + val numPartitions = 13 + val replicationFactor = 3 + val assignment = AdminUtils.assignReplicasToBrokers(toBrokerMetadata(brokerRackMapping), numPartitions, + replicationFactor, 0, 0) + checkReplicaDistribution(assignment, brokerRackMapping, brokerRackMapping.size, numPartitions, + replicationFactor, verifyLeaderDistribution = false, verifyReplicasDistribution = false) + } + + @Test + def testAssignmentWithRackAwareWithUnevenRacks(): Unit = { + val brokerRackMapping = Map(0 -> "rack1", 1 -> "rack1", 2 -> "rack2", 3 -> "rack3", 4 -> "rack3", 5 -> "rack1") + val numPartitions = 12 + val replicationFactor = 3 + val assignment = AdminUtils.assignReplicasToBrokers(toBrokerMetadata(brokerRackMapping), numPartitions, + replicationFactor) + checkReplicaDistribution(assignment, brokerRackMapping, brokerRackMapping.size, numPartitions, + replicationFactor, verifyReplicasDistribution = false) + } + + @Test + def testAssignmentWith2ReplicasRackAware(): Unit = { + val brokerRackMapping = Map(0 -> "rack1", 1 -> "rack2", 2 -> "rack2", 3 -> "rack3", 4 -> "rack3", 5 -> "rack1") + val numPartitions = 12 + val replicationFactor = 2 + val assignment = AdminUtils.assignReplicasToBrokers(toBrokerMetadata(brokerRackMapping), numPartitions, + replicationFactor) + checkReplicaDistribution(assignment, brokerRackMapping, brokerRackMapping.size, numPartitions, + 
replicationFactor) + } + + @Test + def testRackAwareExpansion(): Unit = { + val brokerRackMapping = Map(6 -> "rack1", 7 -> "rack2", 8 -> "rack2", 9 -> "rack3", 10 -> "rack3", 11 -> "rack1") + val numPartitions = 12 + val replicationFactor = 2 + val assignment = AdminUtils.assignReplicasToBrokers(toBrokerMetadata(brokerRackMapping), numPartitions, + replicationFactor, startPartitionId = 12) + checkReplicaDistribution(assignment, brokerRackMapping, brokerRackMapping.size, numPartitions, + replicationFactor) + } + + @Test + def testAssignmentWith2ReplicasRackAwareWith6Partitions(): Unit = { + val brokerRackMapping = Map(0 -> "rack1", 1 -> "rack2", 2 -> "rack2", 3 -> "rack3", 4 -> "rack3", 5 -> "rack1") + val numPartitions = 6 + val replicationFactor = 2 + val assignment = AdminUtils.assignReplicasToBrokers(toBrokerMetadata(brokerRackMapping), numPartitions, + replicationFactor) + checkReplicaDistribution(assignment, brokerRackMapping, brokerRackMapping.size, numPartitions, + replicationFactor) + } + + @Test + def testAssignmentWith2ReplicasRackAwareWith6PartitionsAnd3Brokers(): Unit = { + val brokerRackMapping = Map(0 -> "rack1", 1 -> "rack2", 4 -> "rack3") + val numPartitions = 3 + val replicationFactor = 2 + val assignment = AdminUtils.assignReplicasToBrokers(toBrokerMetadata(brokerRackMapping), numPartitions, replicationFactor) + checkReplicaDistribution(assignment, brokerRackMapping, brokerRackMapping.size, numPartitions, replicationFactor) + } + + @Test + def testLargeNumberPartitionsAssignment(): Unit = { + val numPartitions = 96 + val replicationFactor = 3 + val brokerRackMapping = Map(0 -> "rack1", 1 -> "rack2", 2 -> "rack2", 3 -> "rack3", 4 -> "rack3", 5 -> "rack1", + 6 -> "rack1", 7 -> "rack2", 8 -> "rack2", 9 -> "rack3", 10 -> "rack1", 11 -> "rack3") + val assignment = AdminUtils.assignReplicasToBrokers(toBrokerMetadata(brokerRackMapping), numPartitions, + replicationFactor) + checkReplicaDistribution(assignment, brokerRackMapping, brokerRackMapping.size, numPartitions, + replicationFactor) + } + + @Test + def testMoreReplicasThanRacks(): Unit = { + val numPartitions = 6 + val replicationFactor = 5 + val brokerRackMapping = Map(0 -> "rack1", 1 -> "rack2", 2 -> "rack2", 3 -> "rack3", 4 -> "rack3", 5 -> "rack2") + val assignment = AdminUtils.assignReplicasToBrokers(toBrokerMetadata(brokerRackMapping), numPartitions, replicationFactor) + assertEquals(List.fill(assignment.size)(replicationFactor), assignment.values.toIndexedSeq.map(_.size)) + val distribution = getReplicaDistribution(assignment, brokerRackMapping) + for (partition <- 0 until numPartitions) + assertEquals(3, distribution.partitionRacks(partition).toSet.size) + } + + @Test + def testLessReplicasThanRacks(): Unit = { + val numPartitions = 6 + val replicationFactor = 2 + val brokerRackMapping = Map(0 -> "rack1", 1 -> "rack2", 2 -> "rack2", 3 -> "rack3", 4 -> "rack3", 5 -> "rack2") + val assignment = AdminUtils.assignReplicasToBrokers(toBrokerMetadata(brokerRackMapping), numPartitions, + replicationFactor) + assertEquals(List.fill(assignment.size)(replicationFactor), assignment.values.toIndexedSeq.map(_.size)) + val distribution = getReplicaDistribution(assignment, brokerRackMapping) + for (partition <- 0 to 5) + assertEquals(2, distribution.partitionRacks(partition).toSet.size) + } + + @Test + def testSingleRack(): Unit = { + val numPartitions = 6 + val replicationFactor = 3 + val brokerRackMapping = Map(0 -> "rack1", 1 -> "rack1", 2 -> "rack1", 3 -> "rack1", 4 -> "rack1", 5 -> "rack1") + val assignment = 
AdminUtils.assignReplicasToBrokers(toBrokerMetadata(brokerRackMapping), numPartitions, replicationFactor) + assertEquals(List.fill(assignment.size)(replicationFactor), assignment.values.toIndexedSeq.map(_.size)) + val distribution = getReplicaDistribution(assignment, brokerRackMapping) + for (partition <- 0 until numPartitions) + assertEquals(1, distribution.partitionRacks(partition).toSet.size) + for (broker <- brokerRackMapping.keys) + assertEquals(1, distribution.brokerLeaderCount(broker)) + } + + @Test + def testSkipBrokerWithReplicaAlreadyAssigned(): Unit = { + val rackInfo = Map(0 -> "a", 1 -> "b", 2 -> "c", 3 -> "a", 4 -> "a") + val brokerList = 0 to 4 + val numPartitions = 6 + val replicationFactor = 4 + val brokerMetadatas = toBrokerMetadata(rackInfo) + assertEquals(brokerList, brokerMetadatas.map(_.id)) + val assignment = AdminUtils.assignReplicasToBrokers(brokerMetadatas, numPartitions, replicationFactor, + fixedStartIndex = 2) + checkReplicaDistribution(assignment, rackInfo, 5, 6, 4, + verifyRackAware = false, verifyLeaderDistribution = false, verifyReplicasDistribution = false) + } + + @Test + def testReplicaAssignment(): Unit = { + val brokerMetadatas = (0 to 4).map(new BrokerMetadata(_, None)) + + // test 0 replication factor + assertThrows(classOf[InvalidReplicationFactorException], + () => AdminUtils.assignReplicasToBrokers(brokerMetadatas, 10, 0)) + + // test wrong replication factor + assertThrows(classOf[InvalidReplicationFactorException], + () => AdminUtils.assignReplicasToBrokers(brokerMetadatas, 10, 6)) + + // correct assignment + val expectedAssignment = Map( + 0 -> List(0, 1, 2), + 1 -> List(1, 2, 3), + 2 -> List(2, 3, 4), + 3 -> List(3, 4, 0), + 4 -> List(4, 0, 1), + 5 -> List(0, 2, 3), + 6 -> List(1, 3, 4), + 7 -> List(2, 4, 0), + 8 -> List(3, 0, 1), + 9 -> List(4, 1, 2)) + + val actualAssignment = AdminUtils.assignReplicasToBrokers(brokerMetadatas, 10, 3, 0) + assertEquals(expectedAssignment, actualAssignment) + } +} diff --git a/core/src/test/scala/unit/kafka/admin/ConfigCommandTest.scala b/core/src/test/scala/unit/kafka/admin/ConfigCommandTest.scala new file mode 100644 index 0000000..859b3d5 --- /dev/null +++ b/core/src/test/scala/unit/kafka/admin/ConfigCommandTest.scala @@ -0,0 +1,1727 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package kafka.admin + +import java.util +import java.util.Properties +import kafka.admin.ConfigCommand.ConfigCommandOptions +import kafka.api.ApiVersion +import kafka.cluster.{Broker, EndPoint} +import kafka.server.{ConfigEntityName, ConfigType, KafkaConfig, QuorumTestHarness} +import kafka.utils.{Exit, Logging} +import kafka.zk.{AdminZkClient, BrokerInfo, KafkaZkClient} +import org.apache.kafka.clients.admin._ +import org.apache.kafka.common.Node +import org.apache.kafka.common.config.{ConfigException, ConfigResource} +import org.apache.kafka.common.errors.InvalidConfigurationException +import org.apache.kafka.common.internals.KafkaFutureImpl +import org.apache.kafka.common.network.ListenerName +import org.apache.kafka.common.quota.{ClientQuotaAlteration, ClientQuotaEntity, ClientQuotaFilter, ClientQuotaFilterComponent} +import org.apache.kafka.common.security.auth.SecurityProtocol +import org.apache.kafka.common.security.scram.internals.ScramCredentialUtils +import org.apache.kafka.common.utils.Sanitizer +import org.apache.kafka.test.TestUtils +import org.easymock.EasyMock +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.Test + +import scala.collection.{Seq, mutable} +import scala.jdk.CollectionConverters._ + +class ConfigCommandTest extends QuorumTestHarness with Logging { + + @Test + def shouldExitWithNonZeroStatusOnArgError(): Unit = { + assertNonZeroStatusExit(Array("--blah")) + } + + @Test + def shouldExitWithNonZeroStatusOnUpdatingUnallowedConfigViaZk(): Unit = { + assertNonZeroStatusExit(Array( + "--zookeeper", zkConnect, + "--entity-name", "1", + "--entity-type", "brokers", + "--alter", + "--add-config", "security.inter.broker.protocol=PLAINTEXT")) + } + + @Test + def shouldExitWithNonZeroStatusOnZkCommandWithTopicsEntity(): Unit = { + assertNonZeroStatusExit(Array( + "--zookeeper", zkConnect, + "--entity-type", "topics", + "--describe")) + } + + @Test + def shouldExitWithNonZeroStatusOnZkCommandWithClientsEntity(): Unit = { + assertNonZeroStatusExit(Array( + "--zookeeper", zkConnect, + "--entity-type", "clients", + "--describe")) + } + + @Test + def shouldExitWithNonZeroStatusOnZkCommandWithIpsEntity(): Unit = { + assertNonZeroStatusExit(Array( + "--zookeeper", zkConnect, + "--entity-type", "ips", + "--describe")) + } + + @Test + def shouldExitWithNonZeroStatusOnZkCommandAlterUserQuota(): Unit = { + assertNonZeroStatusExit(Array( + "--zookeeper", zkConnect, + "--entity-type", "users", + "--entity-name", "admin", + "--alter", "--add-config", "consumer_byte_rate=20000")) + } + + @Test + def shouldExitWithNonZeroStatusAlterUserQuotaWithoutEntityName(): Unit = { + assertNonZeroStatusExit(Array( + "--bootstrap-server", "localhost:9092", + "--entity-type", "users", + "--alter", "--add-config", "consumer_byte_rate=20000")) + } + + + @Test + def shouldExitWithNonZeroStatusOnBrokerCommandError(): Unit = { + assertNonZeroStatusExit(Array( + "--bootstrap-server", "invalid host", + "--entity-type", "brokers", + "--entity-name", "1", + "--describe")) + } + + @Test + def shouldExitWithNonZeroStatusOnBrokerCommandWithZkTlsConfigFile(): Unit = { + assertNonZeroStatusExit(Array( + "--bootstrap-server", "invalid host", + "--entity-type", "users", + "--zk-tls-config-file", "zk_tls_config.properties", + "--describe")) + } + + private def assertNonZeroStatusExit(args: Array[String]): Unit = { + var exitStatus: Option[Int] = None + Exit.setExitProcedure { (status, _) => + exitStatus = Some(status) + throw new RuntimeException + } + + try { + ConfigCommand.main(args) + } 
catch { + case e: RuntimeException => + } finally { + Exit.resetExitProcedure() + } + + assertEquals(Some(1), exitStatus) + } + + @Test + def shouldFailParseArgumentsForClientsEntityTypeUsingZookeeper(): Unit = { + assertThrows(classOf[IllegalArgumentException], () => testArgumentParse("clients", zkConfig = true)) + } + + @Test + def shouldParseArgumentsForClientsEntityType(): Unit = { + testArgumentParse("clients", zkConfig = false) + } + + @Test + def shouldParseArgumentsForUsersEntityTypeUsingZookeeper(): Unit = { + testArgumentParse("users", zkConfig = true) + } + + @Test + def shouldParseArgumentsForUsersEntityType(): Unit = { + testArgumentParse("users", zkConfig = false) + } + + @Test + def shouldFailParseArgumentsForTopicsEntityTypeUsingZookeeper(): Unit = { + assertThrows(classOf[IllegalArgumentException], () => testArgumentParse("topics", zkConfig = true)) + } + + @Test + def shouldParseArgumentsForTopicsEntityType(): Unit = { + testArgumentParse("topics", zkConfig = false) + } + + @Test + def shouldParseArgumentsForBrokersEntityTypeUsingZookeeper(): Unit = { + testArgumentParse("brokers", zkConfig = true) + } + + @Test + def shouldParseArgumentsForBrokersEntityType(): Unit = { + testArgumentParse("brokers", zkConfig = false) + } + + @Test + def shouldParseArgumentsForBrokerLoggersEntityType(): Unit = { + testArgumentParse("broker-loggers", zkConfig = false) + } + + @Test + def shouldFailParseArgumentsForIpEntityTypeUsingZookeeper(): Unit = { + assertThrows(classOf[IllegalArgumentException], () => testArgumentParse("ips", zkConfig = true)) + } + + @Test + def shouldParseArgumentsForIpEntityType(): Unit = { + testArgumentParse("ips", zkConfig = false) + } + + def testArgumentParse(entityType: String, zkConfig: Boolean): Unit = { + val shortFlag: String = s"--${entityType.dropRight(1)}" + + val connectOpts = if (zkConfig) + ("--zookeeper", zkConnect) + else + ("--bootstrap-server", "localhost:9092") + + // Should parse correctly + var createOpts = new ConfigCommandOptions(Array(connectOpts._1, connectOpts._2, + "--entity-name", "1", + "--entity-type", entityType, + "--describe")) + createOpts.checkArgs() + + createOpts = new ConfigCommandOptions(Array(connectOpts._1, connectOpts._2, + shortFlag, "1", + "--describe")) + createOpts.checkArgs() + + // For --alter and added config + createOpts = new ConfigCommandOptions(Array(connectOpts._1, connectOpts._2, + "--entity-name", "1", + "--entity-type", entityType, + "--alter", + "--add-config", "a=b,c=d")) + createOpts.checkArgs() + + createOpts = new ConfigCommandOptions(Array(connectOpts._1, connectOpts._2, + "--entity-name", "1", + "--entity-type", entityType, + "--alter", + "--add-config-file", "/tmp/new.properties")) + createOpts.checkArgs() + + createOpts = new ConfigCommandOptions(Array(connectOpts._1, connectOpts._2, + shortFlag, "1", + "--alter", + "--add-config", "a=b,c=d")) + createOpts.checkArgs() + + createOpts = new ConfigCommandOptions(Array(connectOpts._1, connectOpts._2, + shortFlag, "1", + "--alter", + "--add-config-file", "/tmp/new.properties")) + createOpts.checkArgs() + + // For alter and deleted config + createOpts = new ConfigCommandOptions(Array(connectOpts._1, connectOpts._2, + "--entity-name", "1", + "--entity-type", entityType, + "--alter", + "--delete-config", "a,b,c")) + createOpts.checkArgs() + + createOpts = new ConfigCommandOptions(Array(connectOpts._1, connectOpts._2, + shortFlag, "1", + "--alter", + "--delete-config", "a,b,c")) + createOpts.checkArgs() + + // For alter and both added, deleted config + 
createOpts = new ConfigCommandOptions(Array(connectOpts._1, connectOpts._2, + "--entity-name", "1", + "--entity-type", entityType, + "--alter", + "--add-config", "a=b,c=d", + "--delete-config", "a")) + createOpts.checkArgs() + + createOpts = new ConfigCommandOptions(Array(connectOpts._1, connectOpts._2, + shortFlag, "1", + "--alter", + "--add-config", "a=b,c=d", + "--delete-config", "a")) + createOpts.checkArgs() + + val addedProps = ConfigCommand.parseConfigsToBeAdded(createOpts) + assertEquals(2, addedProps.size()) + assertEquals("b", addedProps.getProperty("a")) + assertEquals("d", addedProps.getProperty("c")) + + val deletedProps = ConfigCommand.parseConfigsToBeDeleted(createOpts) + assertEquals(1, deletedProps.size) + assertEquals("a", deletedProps.head) + + createOpts = new ConfigCommandOptions(Array(connectOpts._1, connectOpts._2, + "--entity-name", "1", + "--entity-type", entityType, + "--alter", + "--add-config", "a=b,c=,d=e,f=")) + createOpts.checkArgs() + + createOpts = new ConfigCommandOptions(Array(connectOpts._1, connectOpts._2, + shortFlag, "1", + "--alter", + "--add-config", "a=b,c=,d=e,f=")) + createOpts.checkArgs() + + val addedProps2 = ConfigCommand.parseConfigsToBeAdded(createOpts) + assertEquals(4, addedProps2.size()) + assertEquals("b", addedProps2.getProperty("a")) + assertEquals("e", addedProps2.getProperty("d")) + assertTrue(addedProps2.getProperty("c").isEmpty) + assertTrue(addedProps2.getProperty("f").isEmpty) + } + + @Test + def shouldFailIfAddAndAddFile(): Unit = { + // Should not parse correctly + val createOpts = new ConfigCommandOptions(Array("--bootstrap-server", "localhost:9092", + "--entity-name", "1", + "--entity-type", "brokers", + "--alter", + "--add-config", "a=b,c=d", + "--add-config-file", "/tmp/new.properties" + )) + assertThrows(classOf[IllegalArgumentException], () => createOpts.checkArgs()) + } + + @Test + def testParseConfigsToBeAddedForAddConfigFile(): Unit = { + val fileContents = + """a=b + |c = d + |json = {"key": "val"} + |nested = [[1, 2], [3, 4]] + |""".stripMargin + + val file = TestUtils.tempFile(fileContents) + + val addConfigFileArgs = Array("--add-config-file", file.getPath) + + val createOpts = new ConfigCommandOptions(Array("--bootstrap-server", "localhost:9092", + "--entity-name", "1", + "--entity-type", "brokers", + "--alter") + ++ addConfigFileArgs) + createOpts.checkArgs() + + val addedProps = ConfigCommand.parseConfigsToBeAdded(createOpts) + assertEquals(4, addedProps.size()) + assertEquals("b", addedProps.getProperty("a")) + assertEquals("d", addedProps.getProperty("c")) + assertEquals("{\"key\": \"val\"}", addedProps.getProperty("json")) + assertEquals("[[1, 2], [3, 4]]", addedProps.getProperty("nested")) + } + + def doTestOptionEntityTypeNames(zkConfig: Boolean): Unit = { + val connectOpts = if (zkConfig) + ("--zookeeper", zkConnect) + else + ("--bootstrap-server", "localhost:9092") + + def testExpectedEntityTypeNames(expectedTypes: List[String], expectedNames: List[String], args: String*): Unit = { + val createOpts = new ConfigCommandOptions(Array(connectOpts._1, connectOpts._2, "--describe") ++ args) + createOpts.checkArgs() + assertEquals(createOpts.entityTypes, expectedTypes) + assertEquals(createOpts.entityNames, expectedNames) + } + + // zookeeper config only supports "users" and "brokers" entity type + if (!zkConfig) { + testExpectedEntityTypeNames(List(ConfigType.Topic), List("A"), "--entity-type", "topics", "--entity-name", "A") + testExpectedEntityTypeNames(List(ConfigType.Ip), List("1.2.3.4"), "--entity-name", 
"1.2.3.4", "--entity-type", "ips") + testExpectedEntityTypeNames(List(ConfigType.User, ConfigType.Client), List("A", ""), + "--entity-type", "users", "--entity-type", "clients", "--entity-name", "A", "--entity-default") + testExpectedEntityTypeNames(List(ConfigType.User, ConfigType.Client), List("", "B"), + "--entity-default", "--entity-name", "B", "--entity-type", "users", "--entity-type", "clients") + testExpectedEntityTypeNames(List(ConfigType.Topic), List("A"), "--topic", "A") + testExpectedEntityTypeNames(List(ConfigType.Ip), List("1.2.3.4"), "--ip", "1.2.3.4") + testExpectedEntityTypeNames(List(ConfigType.Client, ConfigType.User), List("B", "A"), "--client", "B", "--user", "A") + testExpectedEntityTypeNames(List(ConfigType.Client, ConfigType.User), List("B", ""), "--client", "B", "--user-defaults") + testExpectedEntityTypeNames(List(ConfigType.Client, ConfigType.User), List("A"), + "--entity-type", "clients", "--entity-type", "users", "--entity-name", "A") + testExpectedEntityTypeNames(List(ConfigType.Topic), List.empty, "--entity-type", "topics") + testExpectedEntityTypeNames(List(ConfigType.Ip), List.empty, "--entity-type", "ips") + } + + testExpectedEntityTypeNames(List(ConfigType.Broker), List("0"), "--entity-name", "0", "--entity-type", "brokers") + testExpectedEntityTypeNames(List(ConfigType.Broker), List("0"), "--broker", "0") + testExpectedEntityTypeNames(List(ConfigType.User), List.empty, "--entity-type", "users") + testExpectedEntityTypeNames(List(ConfigType.Broker), List.empty, "--entity-type", "brokers") + } + + @Test + def testOptionEntityTypeNamesUsingZookeeper(): Unit = { + doTestOptionEntityTypeNames(zkConfig = true) + } + + @Test + def testOptionEntityTypeNames(): Unit = { + doTestOptionEntityTypeNames(zkConfig = false) + } + + @Test + def shouldFailIfUnrecognisedEntityTypeUsingZookeeper(): Unit = { + val createOpts = new ConfigCommandOptions(Array("--zookeeper", zkConnect, + "--entity-name", "client", "--entity-type", "not-recognised", "--alter", "--add-config", "a=b,c=d")) + assertThrows(classOf[IllegalArgumentException], () => ConfigCommand.alterConfigWithZk(null, createOpts, new DummyAdminZkClient(zkClient))) + } + + @Test + def shouldFailIfUnrecognisedEntityType(): Unit = { + val createOpts = new ConfigCommandOptions(Array("--bootstrap-server", "localhost:9092", + "--entity-name", "client", "--entity-type", "not-recognised", "--alter", "--add-config", "a=b,c=d")) + assertThrows(classOf[IllegalArgumentException], () => ConfigCommand.alterConfig(new DummyAdminClient(new Node(1, "localhost", 9092)), createOpts)) + } + + @Test + def shouldFailIfBrokerEntityTypeIsNotAnIntegerUsingZookeeper(): Unit = { + val createOpts = new ConfigCommandOptions(Array("--zookeeper", zkConnect, + "--entity-name", "A", "--entity-type", "brokers", "--alter", "--add-config", "a=b,c=d")) + assertThrows(classOf[IllegalArgumentException], () => ConfigCommand.alterConfigWithZk(null, createOpts, new DummyAdminZkClient(zkClient))) + } + + @Test + def shouldFailIfBrokerEntityTypeIsNotAnInteger(): Unit = { + val createOpts = new ConfigCommandOptions(Array("--bootstrap-server", "localhost:9092", + "--entity-name", "A", "--entity-type", "brokers", "--alter", "--add-config", "a=b,c=d")) + assertThrows(classOf[IllegalArgumentException], () => ConfigCommand.alterConfig(new DummyAdminClient(new Node(1, "localhost", 9092)), createOpts)) + } + + @Test + def shouldFailIfShortBrokerEntityTypeIsNotAnIntegerUsingZookeeper(): Unit = { + val createOpts = new ConfigCommandOptions(Array("--zookeeper", zkConnect, 
+ "--broker", "A", "--alter", "--add-config", "a=b,c=d")) + assertThrows(classOf[IllegalArgumentException], () => ConfigCommand.alterConfigWithZk(null, createOpts, new DummyAdminZkClient(zkClient))) + } + + @Test + def shouldFailIfShortBrokerEntityTypeIsNotAnInteger(): Unit = { + val createOpts = new ConfigCommandOptions(Array("--bootstrap-server", "localhost:9092", + "--broker", "A", "--alter", "--add-config", "a=b,c=d")) + assertThrows(classOf[IllegalArgumentException], () => ConfigCommand.alterConfig(new DummyAdminClient(new Node(1, "localhost", 9092)), createOpts)) + } + + @Test + def shouldFailIfMixedEntityTypeFlagsUsingZookeeper(): Unit = { + val createOpts = new ConfigCommandOptions(Array("--zookeeper", zkConnect, + "--entity-name", "A", "--entity-type", "users", "--client", "B", "--describe")) + assertThrows(classOf[IllegalArgumentException], () => createOpts.checkArgs()) + } + + @Test + def shouldFailIfMixedEntityTypeFlags(): Unit = { + val createOpts = new ConfigCommandOptions(Array("--bootstrap-server", "localhost:9092", + "--entity-name", "A", "--entity-type", "users", "--client", "B", "--describe")) + assertThrows(classOf[IllegalArgumentException], () => createOpts.checkArgs()) + } + + @Test + def shouldFailIfInvalidHost(): Unit = { + val createOpts = new ConfigCommandOptions(Array("--bootstrap-server", "localhost:9092", + "--entity-name", "A,B", "--entity-type", "ips", "--describe")) + assertThrows(classOf[IllegalArgumentException], () => createOpts.checkArgs()) + } + + @Test + def shouldFailIfInvalidHostUsingZookeeper(): Unit = { + val createOpts = new ConfigCommandOptions(Array("--zookeeper", zkConnect, + "--entity-name", "A,B", "--entity-type", "ips", "--describe")) + assertThrows(classOf[IllegalArgumentException], () => createOpts.checkArgs()) + } + + @Test + def shouldFailIfUnresolvableHost(): Unit = { + val createOpts = new ConfigCommandOptions(Array("--bootstrap-server", "localhost:9092", + "--entity-name", "admin", "--entity-type", "ips", "--describe")) + assertThrows(classOf[IllegalArgumentException], () => createOpts.checkArgs()) + } + + @Test + def shouldFailIfUnresolvableHostUsingZookeeper(): Unit = { + val createOpts = new ConfigCommandOptions(Array("--zookeeper", zkConnect, + "--entity-name", "admin", "--entity-type", "ips", "--describe")) + assertThrows(classOf[IllegalArgumentException], () => createOpts.checkArgs()) + } + + @Test + def shouldAddClientConfigUsingZookeeper(): Unit = { + val createOpts = new ConfigCommandOptions(Array("--zookeeper", zkConnect, + "--entity-name", "my-client-id", + "--entity-type", "clients", + "--alter", + "--add-config", "a=b,c=d")) + + class TestAdminZkClient(zkClient: KafkaZkClient) extends AdminZkClient(zkClient) { + override def changeClientIdConfig(clientId: String, configChange: Properties): Unit = { + assertEquals("my-client-id", clientId) + assertEquals("b", configChange.get("a")) + assertEquals("d", configChange.get("c")) + } + } + + ConfigCommand.alterConfigWithZk(null, createOpts, new TestAdminZkClient(zkClient)) + } + + @Test + def shouldAddIpConfigsUsingZookeeper(): Unit = { + val createOpts = new ConfigCommandOptions(Array("--zookeeper", zkConnect, + "--entity-name", "1.2.3.4", + "--entity-type", "ips", + "--alter", + "--add-config", "a=b,c=d")) + + class TestAdminZkClient(zkClient: KafkaZkClient) extends AdminZkClient(zkClient) { + override def changeIpConfig(ip: String, configChange: Properties): Unit = { + assertEquals("1.2.3.4", ip) + assertEquals("b", configChange.get("a")) + assertEquals("d", 
configChange.get("c")) + } + } + + ConfigCommand.alterConfigWithZk(null, createOpts, new TestAdminZkClient(zkClient)) + } + + private def toValues(entityName: Option[String], entityType: String): (Array[String], Map[String, String]) = { + val command = entityType match { + case ClientQuotaEntity.USER => "users" + case ClientQuotaEntity.CLIENT_ID => "clients" + case ClientQuotaEntity.IP => "ips" + } + entityName match { + case Some(null) => + (Array("--entity-type", command, "--entity-default"), Map(entityType -> null)) + case Some(name) => + (Array("--entity-type", command, "--entity-name", name), Map(entityType -> name)) + case None => (Array.empty, Map.empty) + } + } + + private def verifyAlterCommandFails(expectedErrorMessage: String, alterOpts: Seq[String]): Unit = { + val mockAdminClient: Admin = EasyMock.createStrictMock(classOf[Admin]) + val opts = new ConfigCommandOptions(Array("--bootstrap-server", "localhost:9092", + "--alter") ++ alterOpts) + val e = assertThrows(classOf[IllegalArgumentException], () => ConfigCommand.alterConfig(mockAdminClient, opts)) + assertTrue(e.getMessage.contains(expectedErrorMessage), s"Unexpected exception: $e") + } + + @Test + def shouldNotAlterNonQuotaIpConfigsUsingBootstrapServer(): Unit = { + // when using --bootstrap-server, it should be illegal to alter anything that is not a connection quota + // for ip entities + val ipEntityOpts = List("--entity-type", "ips", "--entity-name", "127.0.0.1") + val invalidProp = "some_config" + verifyAlterCommandFails(invalidProp, ipEntityOpts ++ List("--add-config", "connection_creation_rate=10000,some_config=10")) + verifyAlterCommandFails(invalidProp, ipEntityOpts ++ List("--add-config", "some_config=10")) + verifyAlterCommandFails(invalidProp, ipEntityOpts ++ List("--delete-config", "connection_creation_rate=10000,some_config=10")) + verifyAlterCommandFails(invalidProp, ipEntityOpts ++ List("--delete-config", "some_config=10")) + } + + private def verifyDescribeQuotas(describeArgs: List[String], expectedFilter: ClientQuotaFilter): Unit = { + val describeOpts = new ConfigCommandOptions(Array("--bootstrap-server", "localhost:9092", + "--describe") ++ describeArgs) + val describeFuture = new KafkaFutureImpl[util.Map[ClientQuotaEntity, util.Map[String, java.lang.Double]]] + describeFuture.complete(Map.empty[ClientQuotaEntity, util.Map[String, java.lang.Double]].asJava) + val describeResult: DescribeClientQuotasResult = EasyMock.createNiceMock(classOf[DescribeClientQuotasResult]) + EasyMock.expect(describeResult.entities()).andReturn(describeFuture) + + var describedConfigs = false + val node = new Node(1, "localhost", 9092) + val mockAdminClient = new MockAdminClient(util.Collections.singletonList(node), node) { + override def describeClientQuotas(filter: ClientQuotaFilter, options: DescribeClientQuotasOptions): DescribeClientQuotasResult = { + assertTrue(filter.strict) + assertEquals(expectedFilter.components().asScala.toSet, filter.components.asScala.toSet) + describedConfigs = true + describeResult + } + } + EasyMock.replay(describeResult) + ConfigCommand.describeConfig(mockAdminClient, describeOpts) + assertTrue(describedConfigs) + } + + @Test + def testDescribeIpConfigs(): Unit = { + val entityType = ClientQuotaEntity.IP + val knownHost = "1.2.3.4" + val defaultIpFilter = ClientQuotaFilter.containsOnly(List(ClientQuotaFilterComponent.ofDefaultEntity(entityType)).asJava) + val singleIpFilter = ClientQuotaFilter.containsOnly(List(ClientQuotaFilterComponent.ofEntity(entityType, knownHost)).asJava) + val 
allIpsFilter = ClientQuotaFilter.containsOnly(List(ClientQuotaFilterComponent.ofEntityType(entityType)).asJava) + verifyDescribeQuotas(List("--entity-default", "--entity-type", "ips"), defaultIpFilter) + verifyDescribeQuotas(List("--ip-defaults"), defaultIpFilter) + verifyDescribeQuotas(List("--entity-type", "ips", "--entity-name", knownHost), singleIpFilter) + verifyDescribeQuotas(List("--ip", knownHost), singleIpFilter) + verifyDescribeQuotas(List("--entity-type", "ips"), allIpsFilter) + } + + def verifyAlterQuotas(alterOpts: Seq[String], expectedAlterEntity: ClientQuotaEntity, + expectedProps: Map[String, java.lang.Double], expectedAlterOps: Set[ClientQuotaAlteration.Op]): Unit = { + val createOpts = new ConfigCommandOptions(Array("--bootstrap-server", "localhost:9092", + "--alter") ++ alterOpts) + + var describedConfigs = false + val describeFuture = new KafkaFutureImpl[util.Map[ClientQuotaEntity, util.Map[String, java.lang.Double]]] + describeFuture.complete(Map(expectedAlterEntity -> expectedProps.asJava).asJava) + val describeResult: DescribeClientQuotasResult = EasyMock.createNiceMock(classOf[DescribeClientQuotasResult]) + EasyMock.expect(describeResult.entities()).andReturn(describeFuture) + + val expectedFilterComponents = expectedAlterEntity.entries.asScala.map { case (entityType, entityName) => + if (entityName == null) + ClientQuotaFilterComponent.ofDefaultEntity(entityType) + else + ClientQuotaFilterComponent.ofEntity(entityType, entityName) + }.toSet + + var alteredConfigs = false + val alterFuture = new KafkaFutureImpl[Void] + alterFuture.complete(null) + val alterResult: AlterClientQuotasResult = EasyMock.createNiceMock(classOf[AlterClientQuotasResult]) + EasyMock.expect(alterResult.all()).andReturn(alterFuture) + + val node = new Node(1, "localhost", 9092) + val mockAdminClient = new MockAdminClient(util.Collections.singletonList(node), node) { + override def describeClientQuotas(filter: ClientQuotaFilter, options: DescribeClientQuotasOptions): DescribeClientQuotasResult = { + assertTrue(filter.strict) + assertEquals(expectedFilterComponents, filter.components().asScala.toSet) + describedConfigs = true + describeResult + } + + override def alterClientQuotas(entries: util.Collection[ClientQuotaAlteration], options: AlterClientQuotasOptions): AlterClientQuotasResult = { + assertFalse(options.validateOnly) + assertEquals(1, entries.size) + val alteration = entries.asScala.head + assertEquals(expectedAlterEntity, alteration.entity) + val ops = alteration.ops.asScala + assertEquals(expectedAlterOps, ops.toSet) + alteredConfigs = true + alterResult + } + } + EasyMock.replay(alterResult, describeResult) + ConfigCommand.alterConfig(mockAdminClient, createOpts) + assertTrue(describedConfigs) + assertTrue(alteredConfigs) + } + + @Test + def testAlterIpConfig(): Unit = { + val (singleIpArgs, singleIpEntry) = toValues(Some("1.2.3.4"), ClientQuotaEntity.IP) + val singleIpEntity = new ClientQuotaEntity(singleIpEntry.asJava) + val (defaultIpArgs, defaultIpEntry) = toValues(Some(null), ClientQuotaEntity.IP) + val defaultIpEntity = new ClientQuotaEntity(defaultIpEntry.asJava) + + val deleteArgs = List("--delete-config", "connection_creation_rate") + val deleteAlterationOps = Set(new ClientQuotaAlteration.Op("connection_creation_rate", null)) + val propsToDelete = Map("connection_creation_rate" -> Double.box(50.0)) + + val addArgs = List("--add-config", "connection_creation_rate=100") + val addAlterationOps = Set(new ClientQuotaAlteration.Op("connection_creation_rate", 100.0)) + + 
verifyAlterQuotas(singleIpArgs ++ deleteArgs, singleIpEntity, propsToDelete, deleteAlterationOps) + verifyAlterQuotas(singleIpArgs ++ addArgs, singleIpEntity, Map.empty, addAlterationOps) + verifyAlterQuotas(defaultIpArgs ++ deleteArgs, defaultIpEntity, propsToDelete, deleteAlterationOps) + verifyAlterQuotas(defaultIpArgs ++ addArgs, defaultIpEntity, Map.empty, addAlterationOps) + } + + @Test + def shouldAddClientConfig(): Unit = { + val alterArgs = List("--add-config", "consumer_byte_rate=20000,producer_byte_rate=10000", + "--delete-config", "request_percentage") + val propsToDelete = Map("request_percentage" -> Double.box(50.0)) + + val alterationOps = Set( + new ClientQuotaAlteration.Op("consumer_byte_rate", Double.box(20000)), + new ClientQuotaAlteration.Op("producer_byte_rate", Double.box(10000)), + new ClientQuotaAlteration.Op("request_percentage", null) + ) + + def verifyAlterUserClientQuotas(userOpt: Option[String], clientOpt: Option[String]): Unit = { + val (userArgs, userEntry) = toValues(userOpt, ClientQuotaEntity.USER) + val (clientArgs, clientEntry) = toValues(clientOpt, ClientQuotaEntity.CLIENT_ID) + + val commandArgs = alterArgs ++ userArgs ++ clientArgs + val clientQuotaEntity = new ClientQuotaEntity((userEntry ++ clientEntry).asJava) + verifyAlterQuotas(commandArgs, clientQuotaEntity, propsToDelete, alterationOps) + } + verifyAlterUserClientQuotas(Some("test-user-1"), Some("test-client-1")) + verifyAlterUserClientQuotas(Some("test-user-2"), Some(null)) + verifyAlterUserClientQuotas(Some("test-user-3"), None) + verifyAlterUserClientQuotas(Some(null), Some("test-client-2")) + verifyAlterUserClientQuotas(Some(null), Some(null)) + verifyAlterUserClientQuotas(Some(null), None) + verifyAlterUserClientQuotas(None, Some("test-client-3")) + verifyAlterUserClientQuotas(None, Some(null)) + } + + private val userEntityOpts = List("--entity-type", "users", "--entity-name", "admin") + private val clientEntityOpts = List("--entity-type", "clients", "--entity-name", "admin") + private val addScramOpts = List("--add-config", "SCRAM-SHA-256=[iterations=8192,password=foo-secret]") + private val deleteScramOpts = List("--delete-config", "SCRAM-SHA-256") + + @Test + def shouldNotAlterNonQuotaNonScramUserOrClientConfigUsingBootstrapServer(): Unit = { + // when using --bootstrap-server, it should be illegal to alter anything that is not a quota and not a SCRAM credential + // for both user and client entities + val invalidProp = "some_config" + verifyAlterCommandFails(invalidProp, userEntityOpts ++ + List("--add-config", "consumer_byte_rate=20000,producer_byte_rate=10000,some_config=10")) + verifyAlterCommandFails(invalidProp, userEntityOpts ++ + List("--add-config", "consumer_byte_rate=20000,producer_byte_rate=10000,some_config=10")) + verifyAlterCommandFails(invalidProp, clientEntityOpts ++ List("--add-config", "some_config=10")) + verifyAlterCommandFails(invalidProp, userEntityOpts ++ List("--delete-config", "consumer_byte_rate,some_config")) + verifyAlterCommandFails(invalidProp, userEntityOpts ++ List("--delete-config", "SCRAM-SHA-256,some_config")) + verifyAlterCommandFails(invalidProp, clientEntityOpts ++ List("--delete-config", "some_config")) + } + + @Test + def shouldNotAlterScramClientConfigUsingBootstrapServer(): Unit = { + // when using --bootstrap-server, it should be illegal to alter SCRAM credentials for client entities + verifyAlterCommandFails("SCRAM-SHA-256", clientEntityOpts ++ addScramOpts) + verifyAlterCommandFails("SCRAM-SHA-256", clientEntityOpts ++ deleteScramOpts) + } + 
+ @Test + def shouldNotCreateUserScramCredentialConfigWithUnderMinimumIterationsUsingBootstrapServer(): Unit = { + // when using --bootstrap-server, it should be illegal to create a SCRAM credential for a user + // with an iterations value less than the minimum + verifyAlterCommandFails("SCRAM-SHA-256", userEntityOpts ++ List("--add-config", "SCRAM-SHA-256=[iterations=100,password=foo-secret]")) + } + + @Test + def shouldNotAlterUserScramCredentialAndClientQuotaConfigsSimultaneouslyUsingBootstrapServer(): Unit = { + // when using --bootstrap-server, it should be illegal to alter both SCRAM credentials and quotas for user entities + val expectedErrorMessage = "SCRAM-SHA-256" + val secondUserEntityOpts = List("--entity-type", "users", "--entity-name", "admin1") + val addQuotaOpts = List("--add-config", "consumer_byte_rate=20000") + val deleteQuotaOpts = List("--delete-config", "consumer_byte_rate") + + verifyAlterCommandFails(expectedErrorMessage, userEntityOpts ++ addScramOpts ++ userEntityOpts ++ deleteQuotaOpts) + verifyAlterCommandFails(expectedErrorMessage, userEntityOpts ++ addScramOpts ++ secondUserEntityOpts ++ deleteQuotaOpts) + verifyAlterCommandFails(expectedErrorMessage, userEntityOpts ++ deleteScramOpts ++ userEntityOpts ++ addQuotaOpts) + verifyAlterCommandFails(expectedErrorMessage, userEntityOpts ++ deleteScramOpts ++ secondUserEntityOpts ++ addQuotaOpts) + + // change order of quota/SCRAM commands, verify alter still fails + verifyAlterCommandFails(expectedErrorMessage, userEntityOpts ++ deleteQuotaOpts ++ userEntityOpts ++ addScramOpts) + verifyAlterCommandFails(expectedErrorMessage, secondUserEntityOpts ++ deleteQuotaOpts ++ userEntityOpts ++ addScramOpts) + verifyAlterCommandFails(expectedErrorMessage, userEntityOpts ++ addQuotaOpts ++ userEntityOpts ++ deleteScramOpts) + verifyAlterCommandFails(expectedErrorMessage, secondUserEntityOpts ++ addQuotaOpts ++ userEntityOpts ++ deleteScramOpts) + } + + @Test + def shouldNotDescribeUserScramCredentialsWithEntityDefaultUsingBootstrapServer(): Unit = { + def verifyUserScramCredentialsNotDescribed(requestOpts: List[String]): Unit = { + // User SCRAM credentials should not be described when specifying + // --describe --entity-type users --entity-default (or --user-defaults) with --bootstrap-server + val describeFuture = new KafkaFutureImpl[util.Map[ClientQuotaEntity, util.Map[String, java.lang.Double]]] + describeFuture.complete(Map((new ClientQuotaEntity(Map("" -> "").asJava) -> Map(("request_percentage" -> Double.box(50.0))).asJava)).asJava) + val describeClientQuotasResult: DescribeClientQuotasResult = EasyMock.createNiceMock(classOf[DescribeClientQuotasResult]) + EasyMock.expect(describeClientQuotasResult.entities()).andReturn(describeFuture) + EasyMock.replay(describeClientQuotasResult) + val node = new Node(1, "localhost", 9092) + val mockAdminClient = new MockAdminClient(util.Collections.singletonList(node), node) { + override def describeClientQuotas(filter: ClientQuotaFilter, options: DescribeClientQuotasOptions): DescribeClientQuotasResult = { + describeClientQuotasResult + } + override def describeUserScramCredentials(users: util.List[String], options: DescribeUserScramCredentialsOptions): DescribeUserScramCredentialsResult = { + throw new IllegalStateException("Incorrectly described SCRAM credentials when specifying --entity-default with --bootstrap-server") + } + } + val opts = new ConfigCommandOptions(Array("--bootstrap-server", "localhost:9092", "--describe") ++ requestOpts) + 
ConfigCommand.describeConfig(mockAdminClient, opts) // fails if describeUserScramCredentials() is invoked + } + + val expectedMsg = "The use of --entity-default or --user-defaults is not allowed with User SCRAM Credentials using --bootstrap-server." + val defaultUserOpt = List("--user-defaults") + val verboseDefaultUserOpts = List("--entity-type", "users", "--entity-default") + verifyAlterCommandFails(expectedMsg, verboseDefaultUserOpts ++ addScramOpts) + verifyAlterCommandFails(expectedMsg, verboseDefaultUserOpts ++ deleteScramOpts) + verifyUserScramCredentialsNotDescribed(verboseDefaultUserOpts) + verifyAlterCommandFails(expectedMsg, defaultUserOpt ++ addScramOpts) + verifyAlterCommandFails(expectedMsg, defaultUserOpt ++ deleteScramOpts) + verifyUserScramCredentialsNotDescribed(defaultUserOpt) + } + + @Test + def shouldAddTopicConfigUsingZookeeper(): Unit = { + val createOpts = new ConfigCommandOptions(Array("--zookeeper", zkConnect, + "--entity-name", "my-topic", + "--entity-type", "topics", + "--alter", + "--add-config", "a=b,c=d")) + + class TestAdminZkClient(zkClient: KafkaZkClient) extends AdminZkClient(zkClient) { + override def changeTopicConfig(topic: String, configChange: Properties): Unit = { + assertEquals("my-topic", topic) + assertEquals("b", configChange.get("a")) + assertEquals("d", configChange.get("c")) + } + } + + ConfigCommand.alterConfigWithZk(null, createOpts, new TestAdminZkClient(zkClient)) + } + + @Test + def shouldAlterTopicConfig(): Unit = { + doShouldAlterTopicConfig(false) + } + + @Test + def shouldAlterTopicConfigFile(): Unit = { + doShouldAlterTopicConfig(true) + } + + def doShouldAlterTopicConfig(file: Boolean): Unit = { + var filePath = "" + val addedConfigs = Seq("delete.retention.ms=1000000", "min.insync.replicas=2") + if (file) { + val file = TestUtils.tempFile(addedConfigs.mkString("\n")) + filePath = file.getPath + } + + val resourceName = "my-topic" + val alterOpts = new ConfigCommandOptions(Array("--bootstrap-server", "localhost:9092", + "--entity-name", resourceName, + "--entity-type", "topics", + "--alter", + if (file) "--add-config-file" else "--add-config", + if (file) filePath else addedConfigs.mkString(","), + "--delete-config", "unclean.leader.election.enable")) + var alteredConfigs = false + + def newConfigEntry(name: String, value: String): ConfigEntry = + ConfigTest.newConfigEntry(name, value, ConfigEntry.ConfigSource.DYNAMIC_TOPIC_CONFIG, false, false, List.empty[ConfigEntry.ConfigSynonym].asJava) + + val resource = new ConfigResource(ConfigResource.Type.TOPIC, resourceName) + val configEntries = List(newConfigEntry("min.insync.replicas", "1"), newConfigEntry("unclean.leader.election.enable", "1")).asJava + val future = new KafkaFutureImpl[util.Map[ConfigResource, Config]] + future.complete(util.Collections.singletonMap(resource, new Config(configEntries))) + val describeResult: DescribeConfigsResult = EasyMock.createNiceMock(classOf[DescribeConfigsResult]) + EasyMock.expect(describeResult.all()).andReturn(future).once() + + val alterFuture = new KafkaFutureImpl[Void] + alterFuture.complete(null) + val alterResult: AlterConfigsResult = EasyMock.createNiceMock(classOf[AlterConfigsResult]) + EasyMock.expect(alterResult.all()).andReturn(alterFuture) + + val node = new Node(1, "localhost", 9092) + val mockAdminClient = new MockAdminClient(util.Collections.singletonList(node), node) { + override def describeConfigs(resources: util.Collection[ConfigResource], options: DescribeConfigsOptions): DescribeConfigsResult = { + 
assertFalse(options.includeSynonyms(), "Config synonyms requested unnecessarily") + assertEquals(1, resources.size) + val resource = resources.iterator.next + assertEquals(resource.`type`, ConfigResource.Type.TOPIC) + assertEquals(resource.name, resourceName) + describeResult + } + + override def incrementalAlterConfigs(configs: util.Map[ConfigResource, util.Collection[AlterConfigOp]], options: AlterConfigsOptions): AlterConfigsResult = { + assertEquals(1, configs.size) + val entry = configs.entrySet.iterator.next + val resource = entry.getKey + val alterConfigOps = entry.getValue + assertEquals(ConfigResource.Type.TOPIC, resource.`type`) + assertEquals(3, alterConfigOps.size) + + val expectedConfigOps = Set( + new AlterConfigOp(newConfigEntry("delete.retention.ms", "1000000"), AlterConfigOp.OpType.SET), + new AlterConfigOp(newConfigEntry("min.insync.replicas", "2"), AlterConfigOp.OpType.SET), + new AlterConfigOp(newConfigEntry("unclean.leader.election.enable", ""), AlterConfigOp.OpType.DELETE) + ) + assertEquals(expectedConfigOps.size, alterConfigOps.size) + expectedConfigOps.foreach { expectedOp => + val actual = alterConfigOps.asScala.find(_.configEntry.name == expectedOp.configEntry.name) + assertNotEquals(actual, None) + assertEquals(expectedOp.opType, actual.get.opType) + assertEquals(expectedOp.configEntry.name, actual.get.configEntry.name) + assertEquals(expectedOp.configEntry.value, actual.get.configEntry.value) + } + alteredConfigs = true + alterResult + } + } + EasyMock.replay(alterResult, describeResult) + ConfigCommand.alterConfig(mockAdminClient, alterOpts) + assertTrue(alteredConfigs) + EasyMock.reset(alterResult, describeResult) + } + + @Test + def shouldDescribeConfigSynonyms(): Unit = { + val resourceName = "my-topic" + val describeOpts = new ConfigCommandOptions(Array("--bootstrap-server", "localhost:9092", + "--entity-name", resourceName, + "--entity-type", "topics", + "--describe", + "--all")) + + val resource = new ConfigResource(ConfigResource.Type.TOPIC, resourceName) + val future = new KafkaFutureImpl[util.Map[ConfigResource, Config]] + future.complete(util.Collections.singletonMap(resource, new Config(util.Collections.emptyList[ConfigEntry]))) + val describeResult: DescribeConfigsResult = EasyMock.createNiceMock(classOf[DescribeConfigsResult]) + EasyMock.expect(describeResult.all()).andReturn(future).once() + + val node = new Node(1, "localhost", 9092) + val mockAdminClient = new MockAdminClient(util.Collections.singletonList(node), node) { + override def describeConfigs(resources: util.Collection[ConfigResource], options: DescribeConfigsOptions): DescribeConfigsResult = { + assertTrue(options.includeSynonyms(), "Synonyms not requested") + assertEquals(Set(resource), resources.asScala.toSet) + describeResult + } + } + EasyMock.replay(describeResult) + ConfigCommand.describeConfig(mockAdminClient, describeOpts) + EasyMock.reset(describeResult) + } + + @Test + def shouldNotAllowAddBrokerQuotaConfigWhileBrokerUpUsingZookeeper(): Unit = { + val alterOpts = new ConfigCommandOptions(Array("--zookeeper", zkConnect, + "--entity-name", "1", + "--entity-type", "brokers", + "--alter", + "--add-config", "leader.replication.throttled.rate=10,follower.replication.throttled.rate=20")) + + val mockZkClient: KafkaZkClient = EasyMock.createNiceMock(classOf[KafkaZkClient]) + val mockBroker: Broker = EasyMock.createNiceMock(classOf[Broker]) + EasyMock.expect(mockZkClient.getBroker(1)).andReturn(Option(mockBroker)) + EasyMock.replay(mockZkClient) + + 
assertThrows(classOf[IllegalArgumentException], + () => ConfigCommand.alterConfigWithZk(mockZkClient, alterOpts, new DummyAdminZkClient(zkClient))) + } + + @Test + def shouldNotAllowDescribeBrokerWhileBrokerUpUsingZookeeper(): Unit = { + val describeOpts = new ConfigCommandOptions(Array("--zookeeper", zkConnect, + "--entity-name", "1", + "--entity-type", "brokers", + "--describe")) + + val mockZkClient: KafkaZkClient = EasyMock.createNiceMock(classOf[KafkaZkClient]) + val mockBroker: Broker = EasyMock.createNiceMock(classOf[Broker]) + EasyMock.expect(mockZkClient.getBroker(1)).andReturn(Option(mockBroker)) + EasyMock.replay(mockZkClient) + + assertThrows(classOf[IllegalArgumentException], + () => ConfigCommand.describeConfigWithZk(mockZkClient, describeOpts, new DummyAdminZkClient(zkClient))) + } + + @Test + def shouldSupportDescribeBrokerBeforeBrokerUpUsingZookeeper(): Unit = { + val describeOpts = new ConfigCommandOptions(Array("--zookeeper", zkConnect, + "--entity-name", "1", + "--entity-type", "brokers", + "--describe")) + + class TestAdminZkClient(zkClient: KafkaZkClient) extends AdminZkClient(zkClient) { + override def fetchEntityConfig(rootEntityType: String, sanitizedEntityName: String): Properties = { + assertEquals("brokers", rootEntityType) + assertEquals("1", sanitizedEntityName) + + new Properties() + } + } + + val mockZkClient: KafkaZkClient = EasyMock.createNiceMock(classOf[KafkaZkClient]) + EasyMock.expect(mockZkClient.getBroker(1)).andReturn(None) + EasyMock.replay(mockZkClient) + + ConfigCommand.describeConfigWithZk(mockZkClient, describeOpts, new TestAdminZkClient(zkClient)) + } + + @Test + def shouldAddBrokerLoggerConfig(): Unit = { + val node = new Node(1, "localhost", 9092) + verifyAlterBrokerLoggerConfig(node, "1", "1", List( + new ConfigEntry("kafka.log.LogCleaner", "INFO"), + new ConfigEntry("kafka.server.ReplicaManager", "INFO"), + new ConfigEntry("kafka.server.KafkaApi", "INFO") + )) + } + + @Test + def testNoSpecifiedEntityOptionWithDescribeBrokersInZKIsAllowed(): Unit = { + val optsList = List("--zookeeper", zkConnect, + "--entity-type", ConfigType.Broker, + "--describe" + ) + + new ConfigCommandOptions(optsList.toArray).checkArgs() + } + + @Test + def testNoSpecifiedEntityOptionWithDescribeBrokersInBootstrapServerIsAllowed(): Unit = { + val optsList = List("--bootstrap-server", "localhost:9092", + "--entity-type", ConfigType.Broker, + "--describe" + ) + + new ConfigCommandOptions(optsList.toArray).checkArgs() + } + + @Test + def testDescribeAllBrokerConfig(): Unit = { + val optsList = List("--bootstrap-server", "localhost:9092", + "--entity-type", ConfigType.Broker, + "--entity-name", "1", + "--describe", + "--all") + + new ConfigCommandOptions(optsList.toArray).checkArgs() + } + + @Test + def testDescribeAllTopicConfig(): Unit = { + val optsList = List("--bootstrap-server", "localhost:9092", + "--entity-type", ConfigType.Topic, + "--entity-name", "foo", + "--describe", + "--all") + + new ConfigCommandOptions(optsList.toArray).checkArgs() + } + + @Test + def testDescribeAllBrokerConfigBootstrapServerRequired(): Unit = { + val optsList = List("--zookeeper", zkConnect, + "--entity-type", ConfigType.Broker, + "--entity-name", "1", + "--describe", + "--all") + + assertThrows(classOf[IllegalArgumentException], () => new ConfigCommandOptions(optsList.toArray).checkArgs()) + } + + @Test + def testEntityDefaultOptionWithDescribeBrokerLoggerIsNotAllowed(): Unit = { + val optsList = List("--bootstrap-server", "localhost:9092", + "--entity-type", 
ConfigCommand.BrokerLoggerConfigType, + "--entity-default", + "--describe" + ) + + assertThrows(classOf[IllegalArgumentException], () => new ConfigCommandOptions(optsList.toArray).checkArgs()) + } + + @Test + def testEntityDefaultOptionWithAlterBrokerLoggerIsNotAllowed(): Unit = { + val optsList = List("--bootstrap-server", "localhost:9092", + "--entity-type", ConfigCommand.BrokerLoggerConfigType, + "--entity-default", + "--alter", + "--add-config", "kafka.log.LogCleaner=DEBUG" + ) + + assertThrows(classOf[IllegalArgumentException], () => new ConfigCommandOptions(optsList.toArray).checkArgs()) + } + + @Test + def shouldRaiseInvalidConfigurationExceptionWhenAddingInvalidBrokerLoggerConfig(): Unit = { + val node = new Node(1, "localhost", 9092) + // verifyAlterBrokerLoggerConfig tries to alter kafka.log.LogCleaner, kafka.server.ReplicaManager and kafka.server.KafkaApi + // yet, we make it so DescribeConfigs returns only one logger, implying that kafka.server.ReplicaManager and kafka.log.LogCleaner are invalid + assertThrows(classOf[InvalidConfigurationException], () => verifyAlterBrokerLoggerConfig(node, "1", "1", List( + new ConfigEntry("kafka.server.KafkaApi", "INFO") + ))) + } + + @Test + def shouldAddDefaultBrokerDynamicConfig(): Unit = { + val node = new Node(1, "localhost", 9092) + verifyAlterBrokerConfig(node, "", List("--entity-default")) + } + + @Test + def shouldAddBrokerDynamicConfig(): Unit = { + val node = new Node(1, "localhost", 9092) + verifyAlterBrokerConfig(node, "1", List("--entity-name", "1")) + } + + def verifyAlterBrokerConfig(node: Node, resourceName: String, resourceOpts: List[String]): Unit = { + val optsList = List("--bootstrap-server", "localhost:9092", + "--entity-type", "brokers", + "--alter", + "--add-config", "message.max.bytes=10,leader.replication.throttled.rate=10") ++ resourceOpts + val alterOpts = new ConfigCommandOptions(optsList.toArray) + val brokerConfigs = mutable.Map[String, String]("num.io.threads" -> "5") + + val resource = new ConfigResource(ConfigResource.Type.BROKER, resourceName) + val configEntries = util.Collections.singletonList(new ConfigEntry("num.io.threads", "5")) + val future = new KafkaFutureImpl[util.Map[ConfigResource, Config]] + future.complete(util.Collections.singletonMap(resource, new Config(configEntries))) + val describeResult: DescribeConfigsResult = EasyMock.createNiceMock(classOf[DescribeConfigsResult]) + EasyMock.expect(describeResult.all()).andReturn(future).once() + + val alterFuture = new KafkaFutureImpl[Void] + alterFuture.complete(null) + val alterResult: AlterConfigsResult = EasyMock.createNiceMock(classOf[AlterConfigsResult]) + EasyMock.expect(alterResult.all()).andReturn(alterFuture) + + val mockAdminClient = new MockAdminClient(util.Collections.singletonList(node), node) { + override def describeConfigs(resources: util.Collection[ConfigResource], options: DescribeConfigsOptions): DescribeConfigsResult = { + assertFalse(options.includeSynonyms(), "Config synonyms requested unnecessarily") + assertEquals(1, resources.size) + val resource = resources.iterator.next + assertEquals(ConfigResource.Type.BROKER, resource.`type`) + assertEquals(resourceName, resource.name) + describeResult + } + + override def alterConfigs(configs: util.Map[ConfigResource, Config], options: AlterConfigsOptions): AlterConfigsResult = { + assertEquals(1, configs.size) + val entry = configs.entrySet.iterator.next + val resource = entry.getKey + val config = entry.getValue + assertEquals(ConfigResource.Type.BROKER, resource.`type`) + 
config.entries.forEach { e => brokerConfigs.put(e.name, e.value) } + alterResult + } + } + EasyMock.replay(alterResult, describeResult) + ConfigCommand.alterConfig(mockAdminClient, alterOpts) + assertEquals(Map("message.max.bytes" -> "10", "num.io.threads" -> "5", "leader.replication.throttled.rate" -> "10"), + brokerConfigs.toMap) + EasyMock.reset(alterResult, describeResult) + } + + @Test + def shouldDescribeConfigBrokerWithoutEntityName(): Unit = { + val describeOpts = new ConfigCommandOptions(Array("--bootstrap-server", "localhost:9092", + "--entity-type", "brokers", + "--describe")) + + val BrokerDefaultEntityName = "" + val resourceCustom = new ConfigResource(ConfigResource.Type.BROKER, "1") + val resourceDefault = new ConfigResource(ConfigResource.Type.BROKER, BrokerDefaultEntityName) + val future = new KafkaFutureImpl[util.Map[ConfigResource, Config]] + val emptyConfig = new Config(util.Collections.emptyList[ConfigEntry]) + val resultMap = Map(resourceCustom -> emptyConfig, resourceDefault -> emptyConfig).asJava + future.complete(resultMap) + val describeResult: DescribeConfigsResult = EasyMock.createNiceMock(classOf[DescribeConfigsResult]) + // make sure it will be called 2 times: (1) for broker "1" (2) for default broker "" + EasyMock.expect(describeResult.all()).andReturn(future).times(2) + + val node = new Node(1, "localhost", 9092) + val mockAdminClient = new MockAdminClient(util.Collections.singletonList(node), node) { + override def describeConfigs(resources: util.Collection[ConfigResource], options: DescribeConfigsOptions): DescribeConfigsResult = { + assertTrue(options.includeSynonyms(), "Synonyms not requested") + val resource = resources.iterator.next + assertEquals(ConfigResource.Type.BROKER, resource.`type`) + assertTrue(resourceCustom.name == resource.name || resourceDefault.name == resource.name) + assertEquals(1, resources.size) + describeResult + } + } + EasyMock.replay(describeResult) + ConfigCommand.describeConfig(mockAdminClient, describeOpts) + EasyMock.verify(describeResult) + EasyMock.reset(describeResult) + } + + private def verifyAlterBrokerLoggerConfig(node: Node, resourceName: String, entityName: String, + describeConfigEntries: List[ConfigEntry]): Unit = { + val optsList = List("--bootstrap-server", "localhost:9092", + "--entity-type", ConfigCommand.BrokerLoggerConfigType, + "--alter", + "--entity-name", entityName, + "--add-config", "kafka.log.LogCleaner=DEBUG", + "--delete-config", "kafka.server.ReplicaManager,kafka.server.KafkaApi") + val alterOpts = new ConfigCommandOptions(optsList.toArray) + var alteredConfigs = false + + val resource = new ConfigResource(ConfigResource.Type.BROKER_LOGGER, resourceName) + val future = new KafkaFutureImpl[util.Map[ConfigResource, Config]] + future.complete(util.Collections.singletonMap(resource, new Config(describeConfigEntries.asJava))) + val describeResult: DescribeConfigsResult = EasyMock.createNiceMock(classOf[DescribeConfigsResult]) + EasyMock.expect(describeResult.all()).andReturn(future).once() + + val alterFuture = new KafkaFutureImpl[Void] + alterFuture.complete(null) + val alterResult: AlterConfigsResult = EasyMock.createNiceMock(classOf[AlterConfigsResult]) + EasyMock.expect(alterResult.all()).andReturn(alterFuture) + + val mockAdminClient = new MockAdminClient(util.Collections.singletonList(node), node) { + override def describeConfigs(resources: util.Collection[ConfigResource], options: DescribeConfigsOptions): DescribeConfigsResult = { + assertEquals(1, resources.size) + val resource = 
resources.iterator.next + assertEquals(ConfigResource.Type.BROKER_LOGGER, resource.`type`) + assertEquals(resourceName, resource.name) + describeResult + } + + override def incrementalAlterConfigs(configs: util.Map[ConfigResource, util.Collection[AlterConfigOp]], options: AlterConfigsOptions): AlterConfigsResult = { + assertEquals(1, configs.size) + val entry = configs.entrySet.iterator.next + val resource = entry.getKey + val alterConfigOps = entry.getValue + assertEquals(ConfigResource.Type.BROKER_LOGGER, resource.`type`) + assertEquals(3, alterConfigOps.size) + + val expectedConfigOps = List( + new AlterConfigOp(new ConfigEntry("kafka.log.LogCleaner", "DEBUG"), AlterConfigOp.OpType.SET), + new AlterConfigOp(new ConfigEntry("kafka.server.ReplicaManager", ""), AlterConfigOp.OpType.DELETE), + new AlterConfigOp(new ConfigEntry("kafka.server.KafkaApi", ""), AlterConfigOp.OpType.DELETE) + ) + assertEquals(expectedConfigOps, alterConfigOps.asScala.toList) + alteredConfigs = true + alterResult + } + } + EasyMock.replay(alterResult, describeResult) + ConfigCommand.alterConfig(mockAdminClient, alterOpts) + assertTrue(alteredConfigs) + EasyMock.reset(alterResult, describeResult) + } + + @Test + def shouldSupportCommaSeparatedValuesUsingZookeeper(): Unit = { + val createOpts = new ConfigCommandOptions(Array("--zookeeper", zkConnect, + "--entity-name", "my-topic", + "--entity-type", "topics", + "--alter", + "--add-config", "a=b,c=[d,e ,f],g=[h,i]")) + + class TestAdminZkClient(zkClient: KafkaZkClient) extends AdminZkClient(zkClient) { + override def changeTopicConfig(topic: String, configChange: Properties): Unit = { + assertEquals("my-topic", topic) + assertEquals("b", configChange.get("a")) + assertEquals("d,e ,f", configChange.get("c")) + assertEquals("h,i", configChange.get("g")) + } + } + + ConfigCommand.alterConfigWithZk(null, createOpts, new TestAdminZkClient(zkClient)) + } + + @Test + def shouldNotUpdateBrokerConfigIfMalformedEntityNameUsingZookeeper(): Unit = { + val createOpts = new ConfigCommandOptions(Array("--zookeeper", zkConnect, + "--entity-name", "1,2,3", //Don't support multiple brokers currently + "--entity-type", "brokers", + "--alter", + "--add-config", "leader.replication.throttled.rate=10")) + assertThrows(classOf[IllegalArgumentException], () => ConfigCommand.alterConfigWithZk(null, createOpts, new DummyAdminZkClient(zkClient))) + } + + @Test + def shouldNotUpdateBrokerConfigIfMalformedEntityName(): Unit = { + val createOpts = new ConfigCommandOptions(Array("--bootstrap-server", "localhost:9092", + "--entity-name", "1,2,3", //Don't support multiple brokers currently + "--entity-type", "brokers", + "--alter", + "--add-config", "leader.replication.throttled.rate=10")) + assertThrows(classOf[IllegalArgumentException], () => ConfigCommand.alterConfig(new DummyAdminClient(new Node(1, "localhost", 9092)), createOpts)) + } + + @Test + def testDynamicBrokerConfigUpdateUsingZooKeeper(): Unit = { + val brokerId = "1" + val adminZkClient = new AdminZkClient(zkClient) + val alterOpts = Array("--zookeeper", zkConnect, "--entity-type", "brokers", "--alter") + + def entityOpt(brokerId: Option[String]): Array[String] = { + brokerId.map(id => Array("--entity-name", id)).getOrElse(Array("--entity-default")) + } + + def alterConfigWithZk(configs: Map[String, String], brokerId: Option[String], + encoderConfigs: Map[String, String] = Map.empty): Unit = { + val configStr = (configs ++ encoderConfigs).map { case (k, v) => s"$k=$v" }.mkString(",") + val addOpts = new ConfigCommandOptions(alterOpts 
++ entityOpt(brokerId) ++ Array("--add-config", configStr)) + ConfigCommand.alterConfigWithZk(zkClient, addOpts, adminZkClient) + } + + def verifyConfig(configs: Map[String, String], brokerId: Option[String]): Unit = { + val entityConfigs = zkClient.getEntityConfigs("brokers", brokerId.getOrElse(ConfigEntityName.Default)) + assertEquals(configs, entityConfigs.asScala) + } + + def alterAndVerifyConfig(configs: Map[String, String], brokerId: Option[String]): Unit = { + alterConfigWithZk(configs, brokerId) + verifyConfig(configs, brokerId) + } + + def deleteAndVerifyConfig(configNames: Set[String], brokerId: Option[String]): Unit = { + val deleteOpts = new ConfigCommandOptions(alterOpts ++ entityOpt(brokerId) ++ + Array("--delete-config", configNames.mkString(","))) + ConfigCommand.alterConfigWithZk(zkClient, deleteOpts, adminZkClient) + verifyConfig(Map.empty, brokerId) + } + + // Add config + alterAndVerifyConfig(Map("message.max.size" -> "110000"), Some(brokerId)) + alterAndVerifyConfig(Map("message.max.size" -> "120000"), None) + + // Change config + alterAndVerifyConfig(Map("message.max.size" -> "130000"), Some(brokerId)) + alterAndVerifyConfig(Map("message.max.size" -> "140000"), None) + + // Delete config + deleteAndVerifyConfig(Set("message.max.size"), Some(brokerId)) + deleteAndVerifyConfig(Set("message.max.size"), None) + + // Listener configs: should work only with listener name + alterAndVerifyConfig(Map("listener.name.external.ssl.keystore.location" -> "/tmp/test.jks"), Some(brokerId)) + assertThrows(classOf[ConfigException], () => alterConfigWithZk(Map("ssl.keystore.location" -> "/tmp/test.jks"), Some(brokerId))) + + // Per-broker config configured at default cluster-level should fail + assertThrows(classOf[ConfigException], () => alterConfigWithZk(Map("listener.name.external.ssl.keystore.location" -> "/tmp/test.jks"), None)) + deleteAndVerifyConfig(Set("listener.name.external.ssl.keystore.location"), Some(brokerId)) + + // Password config update without encoder secret should fail + assertThrows(classOf[IllegalArgumentException], () => alterConfigWithZk(Map("listener.name.external.ssl.keystore.password" -> "secret"), Some(brokerId))) + + // Password config update with encoder secret should succeed and encoded password must be stored in ZK + val configs = Map("listener.name.external.ssl.keystore.password" -> "secret", "log.cleaner.threads" -> "2") + val encoderConfigs = Map(KafkaConfig.PasswordEncoderSecretProp -> "encoder-secret") + alterConfigWithZk(configs, Some(brokerId), encoderConfigs) + val brokerConfigs = zkClient.getEntityConfigs("brokers", brokerId) + assertFalse(brokerConfigs.contains(KafkaConfig.PasswordEncoderSecretProp), "Encoder secret stored in ZooKeeper") + assertEquals("2", brokerConfigs.getProperty("log.cleaner.threads")) // not encoded + val encodedPassword = brokerConfigs.getProperty("listener.name.external.ssl.keystore.password") + val passwordEncoder = ConfigCommand.createPasswordEncoder(encoderConfigs) + assertEquals("secret", passwordEncoder.decode(encodedPassword).value) + assertEquals(configs.size, brokerConfigs.size) + + // Password config update with overrides for encoder parameters + val configs2 = Map("listener.name.internal.ssl.keystore.password" -> "secret2") + val encoderConfigs2 = Map(KafkaConfig.PasswordEncoderSecretProp -> "encoder-secret", + KafkaConfig.PasswordEncoderCipherAlgorithmProp -> "DES/CBC/PKCS5Padding", + KafkaConfig.PasswordEncoderIterationsProp -> "1024", + KafkaConfig.PasswordEncoderKeyFactoryAlgorithmProp -> 
"PBKDF2WithHmacSHA1", + KafkaConfig.PasswordEncoderKeyLengthProp -> "64") + alterConfigWithZk(configs2, Some(brokerId), encoderConfigs2) + val brokerConfigs2 = zkClient.getEntityConfigs("brokers", brokerId) + val encodedPassword2 = brokerConfigs2.getProperty("listener.name.internal.ssl.keystore.password") + assertEquals("secret2", ConfigCommand.createPasswordEncoder(encoderConfigs).decode(encodedPassword2).value) + assertEquals("secret2", ConfigCommand.createPasswordEncoder(encoderConfigs2).decode(encodedPassword2).value) + + + // Password config update at default cluster-level should fail + assertThrows(classOf[ConfigException], () => alterConfigWithZk(configs, None, encoderConfigs)) + + // Dynamic config updates using ZK should fail if broker is running. + registerBrokerInZk(brokerId.toInt) + assertThrows(classOf[IllegalArgumentException], () => alterConfigWithZk(Map("message.max.size" -> "210000"), Some(brokerId))) + assertThrows(classOf[IllegalArgumentException], () => alterConfigWithZk(Map("message.max.size" -> "220000"), None)) + + // Dynamic config updates using ZK should for a different broker that is not running should succeed + alterAndVerifyConfig(Map("message.max.size" -> "230000"), Some("2")) + } + + @Test + def shouldNotUpdateBrokerConfigIfMalformedConfigUsingZookeeper(): Unit = { + val createOpts = new ConfigCommandOptions(Array("--zookeeper", zkConnect, + "--entity-name", "1", + "--entity-type", "brokers", + "--alter", + "--add-config", "a==")) + assertThrows(classOf[IllegalArgumentException], () => ConfigCommand.alterConfigWithZk(null, createOpts, new DummyAdminZkClient(zkClient))) + } + + @Test + def shouldNotUpdateBrokerConfigIfMalformedConfig(): Unit = { + val createOpts = new ConfigCommandOptions(Array("--bootstrap-server", "localhost:9092", + "--entity-name", "1", + "--entity-type", "brokers", + "--alter", + "--add-config", "a==")) + assertThrows(classOf[IllegalArgumentException], () => ConfigCommand.alterConfig(new DummyAdminClient(new Node(1, "localhost", 9092)), createOpts)) + } + + @Test + def shouldNotUpdateBrokerConfigIfMalformedBracketConfigUsingZookeeper(): Unit = { + val createOpts = new ConfigCommandOptions(Array("--zookeeper", zkConnect, + "--entity-name", "1", + "--entity-type", "brokers", + "--alter", + "--add-config", "a=[b,c,d=e")) + assertThrows(classOf[IllegalArgumentException], () => ConfigCommand.alterConfigWithZk(null, createOpts, new DummyAdminZkClient(zkClient))) + } + + @Test + def shouldNotUpdateBrokerConfigIfMalformedBracketConfig(): Unit = { + val createOpts = new ConfigCommandOptions(Array("--bootstrap-server", "localhost:9092", + "--entity-name", "1", + "--entity-type", "brokers", + "--alter", + "--add-config", "a=[b,c,d=e")) + assertThrows(classOf[IllegalArgumentException], () => ConfigCommand.alterConfig(new DummyAdminClient(new Node(1, "localhost", 9092)), createOpts)) + } + + @Test + def shouldNotUpdateConfigIfNonExistingConfigIsDeletedUsingZookeper(): Unit = { + val createOpts = new ConfigCommandOptions(Array("--zookeeper", zkConnect, + "--entity-name", "my-topic", + "--entity-type", "topics", + "--alter", + "--delete-config", "missing_config1, missing_config2")) + assertThrows(classOf[InvalidConfigurationException], () => ConfigCommand.alterConfigWithZk(null, createOpts, new DummyAdminZkClient(zkClient))) + } + + @Test + def shouldNotUpdateConfigIfNonExistingConfigIsDeleted(): Unit = { + val resourceName = "my-topic" + val createOpts = new ConfigCommandOptions(Array("--bootstrap-server", "localhost:9092", + "--entity-name", 
resourceName, + "--entity-type", "topics", + "--alter", + "--delete-config", "missing_config1, missing_config2")) + + val resource = new ConfigResource(ConfigResource.Type.TOPIC, resourceName) + val configEntries = List.empty[ConfigEntry].asJava + val future = new KafkaFutureImpl[util.Map[ConfigResource, Config]] + future.complete(util.Collections.singletonMap(resource, new Config(configEntries))) + val describeResult: DescribeConfigsResult = EasyMock.createNiceMock(classOf[DescribeConfigsResult]) + EasyMock.expect(describeResult.all()).andReturn(future).once() + + val node = new Node(1, "localhost", 9092) + val mockAdminClient = new MockAdminClient(util.Collections.singletonList(node), node) { + override def describeConfigs(resources: util.Collection[ConfigResource], options: DescribeConfigsOptions): DescribeConfigsResult = { + assertEquals(1, resources.size) + val resource = resources.iterator.next + assertEquals(resource.`type`, ConfigResource.Type.TOPIC) + assertEquals(resource.name, resourceName) + describeResult + } + } + + EasyMock.replay(describeResult) + assertThrows(classOf[InvalidConfigurationException], () => ConfigCommand.alterConfig(mockAdminClient, createOpts)) + EasyMock.reset(describeResult) + } + + @Test + def shouldNotDeleteBrokerConfigWhileBrokerUpUsingZookeeper(): Unit = { + val createOpts = new ConfigCommandOptions(Array("--zookeeper", zkConnect, + "--entity-name", "1", + "--entity-type", "brokers", + "--alter", + "--delete-config", "a,c")) + + class TestAdminZkClient(zkClient: KafkaZkClient) extends AdminZkClient(zkClient) { + override def fetchEntityConfig(entityType: String, entityName: String): Properties = { + val properties: Properties = new Properties + properties.put("a", "b") + properties.put("c", "d") + properties.put("e", "f") + properties + } + + override def changeBrokerConfig(brokerIds: Seq[Int], configChange: Properties): Unit = { + assertEquals("f", configChange.get("e")) + assertEquals(1, configChange.size()) + } + } + + val mockZkClient: KafkaZkClient = EasyMock.createNiceMock(classOf[KafkaZkClient]) + val mockBroker: Broker = EasyMock.createNiceMock(classOf[Broker]) + EasyMock.expect(mockZkClient.getBroker(1)).andReturn(Option(mockBroker)) + EasyMock.replay(mockZkClient) + + assertThrows(classOf[IllegalArgumentException], () => ConfigCommand.alterConfigWithZk(mockZkClient, createOpts, new TestAdminZkClient(zkClient))) + } + + @Test + def testScramCredentials(): Unit = { + def createOpts(user: String, config: String): ConfigCommandOptions = { + new ConfigCommandOptions(Array("--zookeeper", zkConnect, + "--entity-name", user, + "--entity-type", "users", + "--alter", + "--add-config", config)) + } + + def deleteOpts(user: String, mechanism: String) = new ConfigCommandOptions(Array("--zookeeper", zkConnect, + "--entity-name", user, + "--entity-type", "users", + "--alter", + "--delete-config", mechanism)) + + val credentials = mutable.Map[String, Properties]() + case class CredentialChange(user: String, mechanisms: Set[String], iterations: Int) extends AdminZkClient(zkClient) { + override def fetchEntityConfig(entityType: String, entityName: String): Properties = { + credentials.getOrElse(entityName, new Properties()) + } + override def changeUserOrUserClientIdConfig(sanitizedEntityName: String, configChange: Properties): Unit = { + assertEquals(user, sanitizedEntityName) + assertEquals(mechanisms, configChange.keySet().asScala) + for (mechanism <- mechanisms) { + val value = configChange.getProperty(mechanism) + assertEquals(-1, 
value.indexOf("password=")) + val scramCredential = ScramCredentialUtils.credentialFromString(value) + assertEquals(iterations, scramCredential.iterations) + if (configChange != null) + credentials.put(user, configChange) + } + } + } + val optsA = createOpts("userA", "SCRAM-SHA-256=[iterations=8192,password=abc, def]") + ConfigCommand.alterConfigWithZk(null, optsA, CredentialChange("userA", Set("SCRAM-SHA-256"), 8192)) + val optsB = createOpts("userB", "SCRAM-SHA-256=[iterations=4096,password=abc, def],SCRAM-SHA-512=[password=1234=abc]") + ConfigCommand.alterConfigWithZk(null, optsB, CredentialChange("userB", Set("SCRAM-SHA-256", "SCRAM-SHA-512"), 4096)) + + val del256 = deleteOpts("userB", "SCRAM-SHA-256") + ConfigCommand.alterConfigWithZk(null, del256, CredentialChange("userB", Set("SCRAM-SHA-512"), 4096)) + val del512 = deleteOpts("userB", "SCRAM-SHA-512") + ConfigCommand.alterConfigWithZk(null, del512, CredentialChange("userB", Set(), 4096)) + } + + @Test + def testQuotaConfigEntityUsingZookeeperNotAllowed(): Unit = { + assertThrows(classOf[IllegalArgumentException], () => doTestQuotaConfigEntity(zkConfig = true)) + } + + def doTestQuotaConfigEntity(zkConfig: Boolean): Unit = { + val connectOpts = if (zkConfig) + ("--zookeeper", zkConnect) + else + ("--bootstrap-server", "localhost:9092") + + def createOpts(entityType: String, entityName: Option[String], otherArgs: Array[String]) : ConfigCommandOptions = { + val optArray = Array(connectOpts._1, connectOpts._2, "--entity-type", entityType) + val nameArray = entityName match { + case Some(name) => Array("--entity-name", name) + case None => Array[String]() + } + new ConfigCommandOptions(optArray ++ nameArray ++ otherArgs) + } + + def checkEntity(entityType: String, entityName: Option[String], expectedEntityName: String, otherArgs: Array[String]): Unit = { + val opts = createOpts(entityType, entityName, otherArgs) + opts.checkArgs() + val entity = ConfigCommand.parseEntity(opts) + assertEquals(entityType, entity.root.entityType) + assertEquals(expectedEntityName, entity.fullSanitizedName) + } + + def checkInvalidArgs(entityType: String, entityName: Option[String], otherArgs: Array[String]): Unit = { + val opts = createOpts(entityType, entityName, otherArgs) + assertThrows(classOf[IllegalArgumentException], () => opts.checkArgs()) + } + + def checkInvalidEntity(entityType: String, entityName: Option[String], otherArgs: Array[String]): Unit = { + val opts = createOpts(entityType, entityName, otherArgs) + opts.checkArgs() + assertThrows(classOf[IllegalArgumentException], () => ConfigCommand.parseEntity(opts)) + } + + val describeOpts = Array("--describe") + val alterOpts = Array("--alter", "--add-config", "a=b,c=d") + + // quota + val clientId = "client-1" + for (opts <- Seq(describeOpts, alterOpts)) { + checkEntity("clients", Some(clientId), clientId, opts) + checkEntity("clients", Some(""), ConfigEntityName.Default, opts) + } + checkEntity("clients", None, "", describeOpts) + checkInvalidArgs("clients", None, alterOpts) + + // quota + val principal = "CN=ConfigCommandTest,O=Apache,L=" + val sanitizedPrincipal = Sanitizer.sanitize(principal) + assertEquals(-1, sanitizedPrincipal.indexOf('=')) + assertEquals(principal, Sanitizer.desanitize(sanitizedPrincipal)) + for (opts <- Seq(describeOpts, alterOpts)) { + checkEntity("users", Some(principal), sanitizedPrincipal, opts) + checkEntity("users", Some(""), ConfigEntityName.Default, opts) + } + checkEntity("users", None, "", describeOpts) + checkInvalidArgs("users", None, alterOpts) + + // quota 
+    val userClient = sanitizedPrincipal + "/clients/" + clientId
+    def clientIdOpts(name: String) = Array("--entity-type", "clients", "--entity-name", name)
+    for (opts <- Seq(describeOpts, alterOpts)) {
+      checkEntity("users", Some(principal), userClient, opts ++ clientIdOpts(clientId))
+      checkEntity("users", Some(principal), sanitizedPrincipal + "/clients/" + ConfigEntityName.Default, opts ++ clientIdOpts(""))
+      checkEntity("users", Some(""), ConfigEntityName.Default + "/clients/" + clientId, describeOpts ++ clientIdOpts(clientId))
+      checkEntity("users", Some(""), ConfigEntityName.Default + "/clients/" + ConfigEntityName.Default, opts ++ clientIdOpts(""))
+    }
+    checkEntity("users", Some(principal), sanitizedPrincipal + "/clients", describeOpts ++ Array("--entity-type", "clients"))
+    // Both user and client-id must be provided for alter
+    checkInvalidEntity("users", Some(principal), alterOpts ++ Array("--entity-type", "clients"))
+    checkInvalidEntity("users", None, alterOpts ++ clientIdOpts(clientId))
+    checkInvalidArgs("users", None, alterOpts ++ Array("--entity-type", "clients"))
+  }
+
+  @Test
+  def testQuotaConfigEntity(): Unit = {
+    doTestQuotaConfigEntity(zkConfig = false)
+  }
+
+  @Test
+  def testUserClientQuotaOptsUsingZookeeperNotAllowed(): Unit = {
+    assertThrows(classOf[IllegalArgumentException], () => doTestUserClientQuotaOpts(zkConfig = true))
+  }
+
+  def doTestUserClientQuotaOpts(zkConfig: Boolean): Unit = {
+    val connectOpts = if (zkConfig)
+      ("--zookeeper", zkConnect)
+    else
+      ("--bootstrap-server", "localhost:9092")
+
+    def checkEntity(expectedEntityType: String, expectedEntityName: String, args: String*): Unit = {
+      val opts = new ConfigCommandOptions(Array(connectOpts._1, connectOpts._2) ++ args)
+      opts.checkArgs()
+      val entity = ConfigCommand.parseEntity(opts)
+      assertEquals(expectedEntityType, entity.root.entityType)
+      assertEquals(expectedEntityName, entity.fullSanitizedName)
+    }
+
+    // <default> is a valid user principal and client-id (can be handled with URL-encoding),
+    checkEntity("users", Sanitizer.sanitize("<default>"),
+      "--entity-type", "users", "--entity-name", "<default>",
+      "--alter", "--add-config", "a=b,c=d")
+    checkEntity("clients", Sanitizer.sanitize("<default>"),
+      "--entity-type", "clients", "--entity-name", "<default>",
+      "--alter", "--add-config", "a=b,c=d")
+
+    checkEntity("users", Sanitizer.sanitize("CN=user1") + "/clients/client1",
+      "--entity-type", "users", "--entity-name", "CN=user1", "--entity-type", "clients", "--entity-name", "client1",
+      "--alter", "--add-config", "a=b,c=d")
+    checkEntity("users", Sanitizer.sanitize("CN=user1") + "/clients/client1",
+      "--entity-name", "CN=user1", "--entity-type", "users", "--entity-name", "client1", "--entity-type", "clients",
+      "--alter", "--add-config", "a=b,c=d")
+    checkEntity("users", Sanitizer.sanitize("CN=user1") + "/clients/client1",
+      "--entity-type", "clients", "--entity-name", "client1", "--entity-type", "users", "--entity-name", "CN=user1",
+      "--alter", "--add-config", "a=b,c=d")
+    checkEntity("users", Sanitizer.sanitize("CN=user1") + "/clients/client1",
+      "--entity-name", "client1", "--entity-type", "clients", "--entity-name", "CN=user1", "--entity-type", "users",
+      "--alter", "--add-config", "a=b,c=d")
+    checkEntity("users", Sanitizer.sanitize("CN=user1") + "/clients",
+      "--entity-type", "clients", "--entity-name", "CN=user1", "--entity-type", "users",
+      "--describe")
+    checkEntity("users", "/clients",
+      "--entity-type", "clients", "--entity-type", "users",
+      "--describe")
+    checkEntity("users", Sanitizer.sanitize("CN=user1") + "/clients/" +
Sanitizer.sanitize("client1?@%"), + "--entity-name", "client1?@%", "--entity-type", "clients", "--entity-name", "CN=user1", "--entity-type", "users", + "--alter", "--add-config", "a=b,c=d") + } + + @Test + def testUserClientQuotaOpts(): Unit = { + doTestUserClientQuotaOpts(zkConfig = false) + } + + @Test + def testQuotaDescribeEntities(): Unit = { + val zkClient: KafkaZkClient = EasyMock.createNiceMock(classOf[KafkaZkClient]) + + def checkEntities(opts: Array[String], expectedFetches: Map[String, Seq[String]], expectedEntityNames: Seq[String]): Unit = { + val entity = ConfigCommand.parseEntity(new ConfigCommandOptions(opts :+ "--describe")) + expectedFetches.foreach { + case (name, values) => EasyMock.expect(zkClient.getAllEntitiesWithConfig(name)).andReturn(values) + } + EasyMock.replay(zkClient) + val entities = entity.getAllEntities(zkClient) + assertEquals(expectedEntityNames, entities.map(e => e.fullSanitizedName)) + EasyMock.reset(zkClient) + } + + val clientId = "a-client" + val principal = "CN=ConfigCommandTest.testQuotaDescribeEntities , O=Apache, L=" + val sanitizedPrincipal = Sanitizer.sanitize(principal) + val userClient = sanitizedPrincipal + "/clients/" + clientId + + var opts = Array("--entity-type", "clients", "--entity-name", clientId) + checkEntities(opts, Map.empty, Seq(clientId)) + + opts = Array("--entity-type", "clients", "--entity-default") + checkEntities(opts, Map.empty, Seq("")) + + opts = Array("--entity-type", "clients") + checkEntities(opts, Map("clients" -> Seq(clientId)), Seq(clientId)) + + opts = Array("--entity-type", "users", "--entity-name", principal) + checkEntities(opts, Map.empty, Seq(sanitizedPrincipal)) + + opts = Array("--entity-type", "users", "--entity-default") + checkEntities(opts, Map.empty, Seq("")) + + opts = Array("--entity-type", "users") + checkEntities(opts, Map("users" -> Seq("", sanitizedPrincipal)), Seq("", sanitizedPrincipal)) + + opts = Array("--entity-type", "users", "--entity-name", principal, "--entity-type", "clients", "--entity-name", clientId) + checkEntities(opts, Map.empty, Seq(userClient)) + + opts = Array("--entity-type", "users", "--entity-name", principal, "--entity-type", "clients", "--entity-default") + checkEntities(opts, Map.empty, Seq(sanitizedPrincipal + "/clients/")) + + opts = Array("--entity-type", "users", "--entity-name", principal, "--entity-type", "clients") + checkEntities(opts, + Map("users/" + sanitizedPrincipal + "/clients" -> Seq("client-4")), + Seq(sanitizedPrincipal + "/clients/client-4")) + + opts = Array("--entity-type", "users", "--entity-default", "--entity-type", "clients") + checkEntities(opts, + Map("users//clients" -> Seq("client-5")), + Seq("/clients/client-5")) + + opts = Array("--entity-type", "users", "--entity-type", "clients") + val userMap = Map("users/" + sanitizedPrincipal + "/clients" -> Seq("client-2")) + val defaultUserMap = Map("users//clients" -> Seq("client-3")) + checkEntities(opts, + Map("users" -> Seq("", sanitizedPrincipal)) ++ defaultUserMap ++ userMap, + Seq("/clients/client-3", sanitizedPrincipal + "/clients/client-2")) + } + + private def registerBrokerInZk(id: Int): Unit = { + zkClient.createTopLevelPaths() + val securityProtocol = SecurityProtocol.PLAINTEXT + val endpoint = new EndPoint("localhost", 9092, ListenerName.forSecurityProtocol(securityProtocol), securityProtocol) + val brokerInfo = BrokerInfo(Broker(id, Seq(endpoint), rack = None), ApiVersion.latestVersion, jmxPort = 9192) + zkClient.registerBroker(brokerInfo) + } + + class DummyAdminZkClient(zkClient: 
KafkaZkClient) extends AdminZkClient(zkClient) { + override def changeBrokerConfig(brokerIds: Seq[Int], configs: Properties): Unit = {} + override def fetchEntityConfig(entityType: String, entityName: String): Properties = {new Properties} + override def changeClientIdConfig(clientId: String, configs: Properties): Unit = {} + override def changeUserOrUserClientIdConfig(sanitizedEntityName: String, configs: Properties): Unit = {} + override def changeTopicConfig(topic: String, configs: Properties): Unit = {} + } + + class DummyAdminClient(node: Node) extends MockAdminClient(util.Collections.singletonList(node), node) { + override def describeConfigs(resources: util.Collection[ConfigResource], options: DescribeConfigsOptions): DescribeConfigsResult = + EasyMock.createNiceMock(classOf[DescribeConfigsResult]) + override def incrementalAlterConfigs(configs: util.Map[ConfigResource, util.Collection[AlterConfigOp]], + options: AlterConfigsOptions): AlterConfigsResult = EasyMock.createNiceMock(classOf[AlterConfigsResult]) + override def alterConfigs(configs: util.Map[ConfigResource, Config], options: AlterConfigsOptions): AlterConfigsResult = + EasyMock.createNiceMock(classOf[AlterConfigsResult]) + override def describeClientQuotas(filter: ClientQuotaFilter, options: DescribeClientQuotasOptions): DescribeClientQuotasResult = + EasyMock.createNiceMock(classOf[DescribeClientQuotasResult]) + override def alterClientQuotas(entries: util.Collection[ClientQuotaAlteration], + options: AlterClientQuotasOptions): AlterClientQuotasResult = + EasyMock.createNiceMock(classOf[AlterClientQuotasResult]) + } +} diff --git a/core/src/test/scala/unit/kafka/admin/ConsumerGroupCommandTest.scala b/core/src/test/scala/unit/kafka/admin/ConsumerGroupCommandTest.scala new file mode 100644 index 0000000..6415c16 --- /dev/null +++ b/core/src/test/scala/unit/kafka/admin/ConsumerGroupCommandTest.scala @@ -0,0 +1,211 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package kafka.admin + +import java.time.Duration +import java.util.concurrent.{ExecutorService, Executors, TimeUnit} +import java.util.{Collections, Properties} + +import kafka.admin.ConsumerGroupCommand.{ConsumerGroupCommandOptions, ConsumerGroupService} +import kafka.integration.KafkaServerTestHarness +import kafka.server.KafkaConfig +import kafka.utils.TestUtils +import org.apache.kafka.clients.admin.AdminClientConfig +import org.apache.kafka.clients.consumer.{KafkaConsumer, RangeAssignor} +import org.apache.kafka.common.{PartitionInfo, TopicPartition} +import org.apache.kafka.common.errors.WakeupException +import org.apache.kafka.common.serialization.StringDeserializer +import org.junit.jupiter.api.{AfterEach, BeforeEach, TestInfo} + +import scala.jdk.CollectionConverters._ +import scala.collection.mutable.ArrayBuffer + +class ConsumerGroupCommandTest extends KafkaServerTestHarness { + import ConsumerGroupCommandTest._ + + val topic = "foo" + val group = "test.group" + + private var consumerGroupService: List[ConsumerGroupService] = List() + private var consumerGroupExecutors: List[AbstractConsumerGroupExecutor] = List() + + // configure the servers and clients + override def generateConfigs = { + TestUtils.createBrokerConfigs(1, zkConnect, enableControlledShutdown = false).map { props => + KafkaConfig.fromProps(props) + } + } + + @BeforeEach + override def setUp(testInfo: TestInfo): Unit = { + super.setUp(testInfo) + createTopic(topic, 1, 1) + } + + @AfterEach + override def tearDown(): Unit = { + consumerGroupService.foreach(_.close()) + consumerGroupExecutors.foreach(_.shutdown()) + super.tearDown() + } + + def committedOffsets(topic: String = topic, group: String = group): collection.Map[TopicPartition, Long] = { + val consumer = createNoAutoCommitConsumer(group) + try { + val partitions: Set[TopicPartition] = consumer.partitionsFor(topic) + .asScala.toSet.map {partitionInfo : PartitionInfo => new TopicPartition(partitionInfo.topic, partitionInfo.partition)} + consumer.committed(partitions.asJava).asScala.filter(_._2 != null).map { case (k, v) => k -> v.offset } + } finally { + consumer.close() + } + } + + def createNoAutoCommitConsumer(group: String): KafkaConsumer[String, String] = { + val props = new Properties + props.put("bootstrap.servers", brokerList) + props.put("group.id", group) + props.put("enable.auto.commit", "false") + new KafkaConsumer(props, new StringDeserializer, new StringDeserializer) + } + + def getConsumerGroupService(args: Array[String]): ConsumerGroupService = { + val opts = new ConsumerGroupCommandOptions(args) + val service = new ConsumerGroupService(opts, Map(AdminClientConfig.RETRIES_CONFIG -> Int.MaxValue.toString)) + consumerGroupService = service :: consumerGroupService + service + } + + def addConsumerGroupExecutor(numConsumers: Int, + topic: String = topic, + group: String = group, + strategy: String = classOf[RangeAssignor].getName, + customPropsOpt: Option[Properties] = None, + syncCommit: Boolean = false): ConsumerGroupExecutor = { + val executor = new ConsumerGroupExecutor(brokerList, numConsumers, group, topic, strategy, customPropsOpt, syncCommit) + addExecutor(executor) + executor + } + + def addSimpleGroupExecutor(partitions: Iterable[TopicPartition] = Seq(new TopicPartition(topic, 0)), + group: String = group): SimpleConsumerGroupExecutor = { + val executor = new SimpleConsumerGroupExecutor(brokerList, group, partitions) + addExecutor(executor) + executor + } + + private def addExecutor(executor: AbstractConsumerGroupExecutor): 
AbstractConsumerGroupExecutor = { + consumerGroupExecutors = executor :: consumerGroupExecutors + executor + } + +} + +object ConsumerGroupCommandTest { + + abstract class AbstractConsumerRunnable(broker: String, groupId: String, customPropsOpt: Option[Properties] = None, + syncCommit: Boolean = false) extends Runnable { + val props = new Properties + configure(props) + customPropsOpt.foreach(props.asScala ++= _.asScala) + val consumer = new KafkaConsumer(props) + + def configure(props: Properties): Unit = { + props.put("bootstrap.servers", broker) + props.put("group.id", groupId) + props.put("key.deserializer", classOf[StringDeserializer].getName) + props.put("value.deserializer", classOf[StringDeserializer].getName) + } + + def subscribe(): Unit + + def run(): Unit = { + try { + subscribe() + while (true) { + consumer.poll(Duration.ofMillis(Long.MaxValue)) + if (syncCommit) + consumer.commitSync() + } + } catch { + case _: WakeupException => // OK + } finally { + consumer.close() + } + } + + def shutdown(): Unit = { + consumer.wakeup() + } + } + + class ConsumerRunnable(broker: String, groupId: String, topic: String, strategy: String, + customPropsOpt: Option[Properties] = None, syncCommit: Boolean = false) + extends AbstractConsumerRunnable(broker, groupId, customPropsOpt, syncCommit) { + + override def configure(props: Properties): Unit = { + super.configure(props) + props.put("partition.assignment.strategy", strategy) + } + + override def subscribe(): Unit = { + consumer.subscribe(Collections.singleton(topic)) + } + } + + class SimpleConsumerRunnable(broker: String, groupId: String, partitions: Iterable[TopicPartition]) + extends AbstractConsumerRunnable(broker, groupId) { + + override def subscribe(): Unit = { + consumer.assign(partitions.toList.asJava) + } + } + + class AbstractConsumerGroupExecutor(numThreads: Int) { + private val executor: ExecutorService = Executors.newFixedThreadPool(numThreads) + private val consumers = new ArrayBuffer[AbstractConsumerRunnable]() + + def submit(consumerThread: AbstractConsumerRunnable): Unit = { + consumers += consumerThread + executor.submit(consumerThread) + } + + def shutdown(): Unit = { + consumers.foreach(_.shutdown()) + executor.shutdown() + executor.awaitTermination(5000, TimeUnit.MILLISECONDS) + } + } + + class ConsumerGroupExecutor(broker: String, numConsumers: Int, groupId: String, topic: String, strategy: String, + customPropsOpt: Option[Properties] = None, syncCommit: Boolean = false) + extends AbstractConsumerGroupExecutor(numConsumers) { + + for (_ <- 1 to numConsumers) { + submit(new ConsumerRunnable(broker, groupId, topic, strategy, customPropsOpt, syncCommit)) + } + + } + + class SimpleConsumerGroupExecutor(broker: String, groupId: String, partitions: Iterable[TopicPartition]) + extends AbstractConsumerGroupExecutor(1) { + + submit(new SimpleConsumerRunnable(broker, groupId, partitions)) + } + +} + diff --git a/core/src/test/scala/unit/kafka/admin/ConsumerGroupServiceTest.scala b/core/src/test/scala/unit/kafka/admin/ConsumerGroupServiceTest.scala new file mode 100644 index 0000000..76a3855 --- /dev/null +++ b/core/src/test/scala/unit/kafka/admin/ConsumerGroupServiceTest.scala @@ -0,0 +1,220 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package kafka.admin
+
+import java.util
+import java.util.{Collections, Optional}
+
+import kafka.admin.ConsumerGroupCommand.{ConsumerGroupCommandOptions, ConsumerGroupService}
+import org.apache.kafka.clients.admin._
+import org.apache.kafka.clients.consumer.{OffsetAndMetadata, RangeAssignor}
+import org.apache.kafka.common.{ConsumerGroupState, KafkaFuture, Node, TopicPartition, TopicPartitionInfo}
+import org.junit.jupiter.api.Assertions.{assertEquals, assertTrue}
+import org.junit.jupiter.api.Test
+import org.mockito.ArgumentMatchers
+import org.mockito.ArgumentMatchers._
+import org.mockito.Mockito._
+import org.mockito.ArgumentMatcher
+
+import scala.jdk.CollectionConverters._
+import org.apache.kafka.common.internals.KafkaFutureImpl
+
+class ConsumerGroupServiceTest {
+
+  private val group = "testGroup"
+  private val topics = (0 until 5).map(i => s"testTopic$i")
+  private val numPartitions = 10
+  private val topicPartitions = topics.flatMap(topic => (0 until numPartitions).map(i => new TopicPartition(topic, i)))
+  private val admin = mock(classOf[Admin])
+
+  @Test
+  def testAdminRequestsForDescribeOffsets(): Unit = {
+    val args = Array("--bootstrap-server", "localhost:9092", "--group", group, "--describe", "--offsets")
+    val groupService = consumerGroupService(args)
+
+    when(admin.describeConsumerGroups(ArgumentMatchers.eq(Collections.singletonList(group)), any()))
+      .thenReturn(describeGroupsResult(ConsumerGroupState.STABLE))
+    when(admin.listConsumerGroupOffsets(ArgumentMatchers.eq(group), any()))
+      .thenReturn(listGroupOffsetsResult)
+    when(admin.listOffsets(offsetsArgMatcher, any()))
+      .thenReturn(listOffsetsResult)
+
+    val (state, assignments) = groupService.collectGroupOffsets(group)
+    assertEquals(Some("Stable"), state)
+    assertTrue(assignments.nonEmpty)
+    assertEquals(topicPartitions.size, assignments.get.size)
+
+    verify(admin, times(1)).describeConsumerGroups(ArgumentMatchers.eq(Collections.singletonList(group)), any())
+    verify(admin, times(1)).listConsumerGroupOffsets(ArgumentMatchers.eq(group), any())
+    verify(admin, times(1)).listOffsets(offsetsArgMatcher, any())
+  }
+
+  @Test
+  def testAdminRequestsForDescribeNegativeOffsets(): Unit = {
+    val args = Array("--bootstrap-server", "localhost:9092", "--group", group, "--describe", "--offsets")
+    val groupService = consumerGroupService(args)
+
+    val testTopicPartition0 = new TopicPartition("testTopic1", 0)
+    val testTopicPartition1 = new TopicPartition("testTopic1", 1)
+    val testTopicPartition2 = new TopicPartition("testTopic1", 2)
+    val testTopicPartition3 = new TopicPartition("testTopic2", 0)
+    val testTopicPartition4 = new TopicPartition("testTopic2", 1)
+    val testTopicPartition5 = new TopicPartition("testTopic2", 2)
+
+    // Some topic partitions get valid OffsetAndMetadata values, others get null values (negative integers), and others aren't defined
+    val commitedOffsets = Map(
+      testTopicPartition1 -> new OffsetAndMetadata(100),
+
testTopicPartition2 -> null, + testTopicPartition3 -> new OffsetAndMetadata(100), + testTopicPartition4 -> new OffsetAndMetadata(100), + testTopicPartition5 -> null, + ).asJava + + val resultInfo = new ListOffsetsResult.ListOffsetsResultInfo(100, System.currentTimeMillis, Optional.of(1)) + val endOffsets = Map( + testTopicPartition0 -> KafkaFuture.completedFuture(resultInfo), + testTopicPartition1 -> KafkaFuture.completedFuture(resultInfo), + testTopicPartition2 -> KafkaFuture.completedFuture(resultInfo), + testTopicPartition3 -> KafkaFuture.completedFuture(resultInfo), + testTopicPartition4 -> KafkaFuture.completedFuture(resultInfo), + testTopicPartition5 -> KafkaFuture.completedFuture(resultInfo), + ) + val assignedTopicPartitions = Set(testTopicPartition0, testTopicPartition1, testTopicPartition2) + val unassignedTopicPartitions = Set(testTopicPartition3, testTopicPartition4, testTopicPartition5) + + val consumerGroupDescription = new ConsumerGroupDescription(group, + true, + Collections.singleton(new MemberDescription("member1", Optional.of("instance1"), "client1", "host1", new MemberAssignment(assignedTopicPartitions.asJava))), + classOf[RangeAssignor].getName, + ConsumerGroupState.STABLE, + new Node(1, "localhost", 9092)) + + def offsetsArgMatcher(expectedPartitions: Set[TopicPartition]): ArgumentMatcher[util.Map[TopicPartition, OffsetSpec]] = { + topicPartitionOffsets => topicPartitionOffsets != null && topicPartitionOffsets.keySet.asScala.equals(expectedPartitions) + } + + val future = new KafkaFutureImpl[ConsumerGroupDescription]() + future.complete(consumerGroupDescription) + when(admin.describeConsumerGroups(ArgumentMatchers.eq(Collections.singletonList(group)), any())) + .thenReturn(new DescribeConsumerGroupsResult(Collections.singletonMap(group, future))) + when(admin.listConsumerGroupOffsets(ArgumentMatchers.eq(group), any())) + .thenReturn(AdminClientTestUtils.listConsumerGroupOffsetsResult(commitedOffsets)) + when(admin.listOffsets( + ArgumentMatchers.argThat(offsetsArgMatcher(assignedTopicPartitions)), + any() + )).thenReturn(new ListOffsetsResult(endOffsets.filter { case (tp, _) => assignedTopicPartitions.contains(tp) }.asJava)) + when(admin.listOffsets( + ArgumentMatchers.argThat(offsetsArgMatcher(unassignedTopicPartitions)), + any() + )).thenReturn(new ListOffsetsResult(endOffsets.filter { case (tp, _) => unassignedTopicPartitions.contains(tp) }.asJava)) + + val (state, assignments) = groupService.collectGroupOffsets(group) + val returnedOffsets = assignments.map { results => + results.map { assignment => + new TopicPartition(assignment.topic.get, assignment.partition.get) -> assignment.offset + }.toMap + }.getOrElse(Map.empty) + + val expectedOffsets = Map( + testTopicPartition0 -> None, + testTopicPartition1 -> Some(100), + testTopicPartition2 -> None, + testTopicPartition3 -> Some(100), + testTopicPartition4 -> Some(100), + testTopicPartition5 -> None + ) + assertEquals(Some("Stable"), state) + assertEquals(expectedOffsets, returnedOffsets) + + verify(admin, times(1)).describeConsumerGroups(ArgumentMatchers.eq(Collections.singletonList(group)), any()) + verify(admin, times(1)).listConsumerGroupOffsets(ArgumentMatchers.eq(group), any()) + verify(admin, times(1)).listOffsets(ArgumentMatchers.argThat(offsetsArgMatcher(assignedTopicPartitions)), any()) + verify(admin, times(1)).listOffsets(ArgumentMatchers.argThat(offsetsArgMatcher(unassignedTopicPartitions)), any()) + } + + @Test + def testAdminRequestsForResetOffsets(): Unit = { + val args = Seq("--bootstrap-server", 
"localhost:9092", "--group", group, "--reset-offsets", "--to-latest") + val topicsWithoutPartitionsSpecified = topics.tail + val topicArgs = Seq("--topic", s"${topics.head}:${(0 until numPartitions).mkString(",")}") ++ + topicsWithoutPartitionsSpecified.flatMap(topic => Seq("--topic", topic)) + val groupService = consumerGroupService((args ++ topicArgs).toArray) + + when(admin.describeConsumerGroups(ArgumentMatchers.eq(Collections.singletonList(group)), any())) + .thenReturn(describeGroupsResult(ConsumerGroupState.DEAD)) + when(admin.describeTopics(ArgumentMatchers.eq(topicsWithoutPartitionsSpecified.asJava), any())) + .thenReturn(describeTopicsResult(topicsWithoutPartitionsSpecified)) + when(admin.listOffsets(offsetsArgMatcher, any())) + .thenReturn(listOffsetsResult) + + val resetResult = groupService.resetOffsets() + assertEquals(Set(group), resetResult.keySet) + assertEquals(topicPartitions.toSet, resetResult(group).keySet) + + verify(admin, times(1)).describeConsumerGroups(ArgumentMatchers.eq(Collections.singletonList(group)), any()) + verify(admin, times(1)).describeTopics(ArgumentMatchers.eq(topicsWithoutPartitionsSpecified.asJava), any()) + verify(admin, times(1)).listOffsets(offsetsArgMatcher, any()) + } + + private def consumerGroupService(args: Array[String]): ConsumerGroupService = { + new ConsumerGroupService(new ConsumerGroupCommandOptions(args)) { + override protected def createAdminClient(configOverrides: collection.Map[String, String]): Admin = { + admin + } + } + } + + private def describeGroupsResult(groupState: ConsumerGroupState): DescribeConsumerGroupsResult = { + val member1 = new MemberDescription("member1", Optional.of("instance1"), "client1", "host1", null) + val description = new ConsumerGroupDescription(group, + true, + Collections.singleton(member1), + classOf[RangeAssignor].getName, + groupState, + new Node(1, "localhost", 9092)) + val future = new KafkaFutureImpl[ConsumerGroupDescription]() + future.complete(description) + new DescribeConsumerGroupsResult(Collections.singletonMap(group, future)) + } + + private def listGroupOffsetsResult: ListConsumerGroupOffsetsResult = { + val offsets = topicPartitions.map(_ -> new OffsetAndMetadata(100)).toMap.asJava + AdminClientTestUtils.listConsumerGroupOffsetsResult(offsets) + } + + private def offsetsArgMatcher: util.Map[TopicPartition, OffsetSpec] = { + val expectedOffsets = topicPartitions.map(tp => tp -> OffsetSpec.latest).toMap + ArgumentMatchers.argThat[util.Map[TopicPartition, OffsetSpec]] { map => + map.keySet.asScala == expectedOffsets.keySet && map.values.asScala.forall(_.isInstanceOf[OffsetSpec.LatestSpec]) + } + } + + private def listOffsetsResult: ListOffsetsResult = { + val resultInfo = new ListOffsetsResult.ListOffsetsResultInfo(100, System.currentTimeMillis, Optional.of(1)) + val futures = topicPartitions.map(_ -> KafkaFuture.completedFuture(resultInfo)).toMap + new ListOffsetsResult(futures.asJava) + } + + private def describeTopicsResult(topics: Seq[String]): DescribeTopicsResult = { + val topicDescriptions = topics.map { topic => + val partitions = (0 until numPartitions).map(i => new TopicPartitionInfo(i, null, Collections.emptyList[Node], Collections.emptyList[Node])) + topic -> new TopicDescription(topic, false, partitions.asJava) + }.toMap + AdminClientTestUtils.describeTopicsResult(topicDescriptions.asJava) + } +} diff --git a/core/src/test/scala/unit/kafka/admin/DelegationTokenCommandTest.scala b/core/src/test/scala/unit/kafka/admin/DelegationTokenCommandTest.scala new file mode 100644 index 
0000000..2071d08 --- /dev/null +++ b/core/src/test/scala/unit/kafka/admin/DelegationTokenCommandTest.scala @@ -0,0 +1,146 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package kafka.admin + +import java.util + +import kafka.admin.DelegationTokenCommand.DelegationTokenCommandOptions +import kafka.api.{KafkaSasl, SaslSetup} +import kafka.server.{BaseRequestTest, KafkaConfig} +import kafka.utils.{JaasTestUtils, TestUtils} +import org.apache.kafka.clients.admin.{Admin, AdminClientConfig} +import org.apache.kafka.common.security.auth.SecurityProtocol +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.{AfterEach, BeforeEach, Test, TestInfo} + +import scala.collection.mutable.ListBuffer +import scala.concurrent.ExecutionException + +class DelegationTokenCommandTest extends BaseRequestTest with SaslSetup { + override protected def securityProtocol = SecurityProtocol.SASL_PLAINTEXT + private val kafkaClientSaslMechanism = "PLAIN" + private val kafkaServerSaslMechanisms = List("PLAIN") + protected override val serverSaslProperties = Some(kafkaServerSaslProperties(kafkaServerSaslMechanisms, kafkaClientSaslMechanism)) + protected override val clientSaslProperties = Some(kafkaClientSaslProperties(kafkaClientSaslMechanism)) + var adminClient: Admin = null + + override def brokerCount = 1 + + @BeforeEach + override def setUp(testInfo: TestInfo): Unit = { + startSasl(jaasSections(kafkaServerSaslMechanisms, Some(kafkaClientSaslMechanism), KafkaSasl, JaasTestUtils.KafkaServerContextName)) + super.setUp(testInfo) + } + + override def generateConfigs = { + val props = TestUtils.createBrokerConfigs(brokerCount, zkConnect, + enableControlledShutdown = false, + interBrokerSecurityProtocol = Some(securityProtocol), + trustStoreFile = trustStoreFile, saslProperties = serverSaslProperties, enableToken = true) + props.foreach(brokerPropertyOverrides) + props.map(KafkaConfig.fromProps) + } + + private def createAdminConfig: util.Map[String, Object] = { + val config = new util.HashMap[String, Object] + config.put(AdminClientConfig.BOOTSTRAP_SERVERS_CONFIG, brokerList) + val securityProps: util.Map[Object, Object] = + TestUtils.adminClientSecurityConfigs(securityProtocol, trustStoreFile, clientSaslProperties) + securityProps.forEach { (key, value) => config.put(key.asInstanceOf[String], value) } + config + } + + @Test + def testDelegationTokenRequests(): Unit = { + adminClient = Admin.create(createAdminConfig) + val renewer1 = "User:renewer1" + val renewer2 = "User:renewer2" + + // create token1 with renewer1 + val tokenCreated = DelegationTokenCommand.createToken(adminClient, getCreateOpts(List(renewer1))) + + var tokens = DelegationTokenCommand.describeToken(adminClient, getDescribeOpts(List())) + assertTrue(tokens.size == 1) + val token1 
= tokens.head + assertEquals(token1, tokenCreated) + + // create token2 with renewer2 + val token2 = DelegationTokenCommand.createToken(adminClient, getCreateOpts(List(renewer2))) + + tokens = DelegationTokenCommand.describeToken(adminClient, getDescribeOpts(List())) + assertTrue(tokens.size == 2) + assertEquals(Set(token1, token2), tokens.toSet) + + //get tokens for renewer2 + tokens = DelegationTokenCommand.describeToken(adminClient, getDescribeOpts(List(renewer2))) + assertTrue(tokens.size == 1) + assertEquals(Set(token2), tokens.toSet) + + //test renewing tokens + val expiryTimestamp = DelegationTokenCommand.renewToken(adminClient, getRenewOpts(token1.hmacAsBase64String())) + val renewedToken = DelegationTokenCommand.describeToken(adminClient, getDescribeOpts(List(renewer1))).head + assertEquals(expiryTimestamp, renewedToken.tokenInfo().expiryTimestamp()) + + //test expire tokens + DelegationTokenCommand.expireToken(adminClient, getExpireOpts(token1.hmacAsBase64String())) + DelegationTokenCommand.expireToken(adminClient, getExpireOpts(token2.hmacAsBase64String())) + + tokens = DelegationTokenCommand.describeToken(adminClient, getDescribeOpts(List())) + assertTrue(tokens.size == 0) + + //create token with invalid renewer principal type + assertThrows(classOf[ExecutionException], () => DelegationTokenCommand.createToken(adminClient, getCreateOpts(List("Group:Renewer3")))) + + // try describing tokens for unknown owner + assertTrue(DelegationTokenCommand.describeToken(adminClient, getDescribeOpts(List("User:Unknown"))).isEmpty) + } + + private def getCreateOpts(renewers: List[String]): DelegationTokenCommandOptions = { + val opts = ListBuffer("--bootstrap-server", brokerList, "--max-life-time-period", "-1", + "--command-config", "testfile", "--create") + renewers.foreach(renewer => opts ++= ListBuffer("--renewer-principal", renewer)) + new DelegationTokenCommandOptions(opts.toArray) + } + + private def getDescribeOpts(owners: List[String]): DelegationTokenCommandOptions = { + val opts = ListBuffer("--bootstrap-server", brokerList, "--command-config", "testfile", "--describe") + owners.foreach(owner => opts ++= ListBuffer("--owner-principal", owner)) + new DelegationTokenCommandOptions(opts.toArray) + } + + private def getRenewOpts(hmac: String): DelegationTokenCommandOptions = { + val opts = Array("--bootstrap-server", brokerList, "--command-config", "testfile", "--renew", + "--renew-time-period", "-1", + "--hmac", hmac) + new DelegationTokenCommandOptions(opts) + } + + private def getExpireOpts(hmac: String): DelegationTokenCommandOptions = { + val opts = Array("--bootstrap-server", brokerList, "--command-config", "testfile", "--expire", + "--expiry-time-period", "-1", + "--hmac", hmac) + new DelegationTokenCommandOptions(opts) + } + + @AfterEach + override def tearDown(): Unit = { + if (adminClient != null) + adminClient.close() + super.tearDown() + closeSasl() + } +} diff --git a/core/src/test/scala/unit/kafka/admin/DeleteConsumerGroupsTest.scala b/core/src/test/scala/unit/kafka/admin/DeleteConsumerGroupsTest.scala new file mode 100644 index 0000000..15583b8 --- /dev/null +++ b/core/src/test/scala/unit/kafka/admin/DeleteConsumerGroupsTest.scala @@ -0,0 +1,246 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package kafka.admin + +import joptsimple.OptionException +import kafka.utils.TestUtils +import org.apache.kafka.common.errors.{GroupIdNotFoundException, GroupNotEmptyException} +import org.apache.kafka.common.protocol.Errors +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.Test + +class DeleteConsumerGroupsTest extends ConsumerGroupCommandTest { + + @Test + def testDeleteWithTopicOption(): Unit = { + TestUtils.createOffsetsTopic(zkClient, servers) + val cgcArgs = Array("--bootstrap-server", brokerList, "--delete", "--group", group, "--topic") + assertThrows(classOf[OptionException], () => getConsumerGroupService(cgcArgs)) + } + + @Test + def testDeleteCmdNonExistingGroup(): Unit = { + TestUtils.createOffsetsTopic(zkClient, servers) + val missingGroup = "missing.group" + + val cgcArgs = Array("--bootstrap-server", brokerList, "--delete", "--group", missingGroup) + val service = getConsumerGroupService(cgcArgs) + + val output = TestUtils.grabConsoleOutput(service.deleteGroups()) + assertTrue(output.contains(s"Group '$missingGroup' could not be deleted due to:") && output.contains(Errors.GROUP_ID_NOT_FOUND.message), + s"The expected error (${Errors.GROUP_ID_NOT_FOUND}) was not detected while deleting consumer group") + } + + @Test + def testDeleteNonExistingGroup(): Unit = { + TestUtils.createOffsetsTopic(zkClient, servers) + val missingGroup = "missing.group" + + // note the group to be deleted is a different (non-existing) group + val cgcArgs = Array("--bootstrap-server", brokerList, "--delete", "--group", missingGroup) + val service = getConsumerGroupService(cgcArgs) + + val result = service.deleteGroups() + assertTrue(result.size == 1 && result.keySet.contains(missingGroup) && result(missingGroup).getCause.isInstanceOf[GroupIdNotFoundException], + s"The expected error (${Errors.GROUP_ID_NOT_FOUND}) was not detected while deleting consumer group") + } + + @Test + def testDeleteCmdNonEmptyGroup(): Unit = { + TestUtils.createOffsetsTopic(zkClient, servers) + + // run one consumer in the group + addConsumerGroupExecutor(numConsumers = 1) + val cgcArgs = Array("--bootstrap-server", brokerList, "--delete", "--group", group) + val service = getConsumerGroupService(cgcArgs) + + TestUtils.waitUntilTrue(() => { + service.collectGroupMembers(group, false)._2.get.size == 1 + }, "The group did not initialize as expected.") + + val output = TestUtils.grabConsoleOutput(service.deleteGroups()) + assertTrue(output.contains(s"Group '$group' could not be deleted due to:") && output.contains(Errors.NON_EMPTY_GROUP.message), + s"The expected error (${Errors.NON_EMPTY_GROUP}) was not detected while deleting consumer group. 
Output was: (${output})") + } + + @Test + def testDeleteNonEmptyGroup(): Unit = { + TestUtils.createOffsetsTopic(zkClient, servers) + + // run one consumer in the group + addConsumerGroupExecutor(numConsumers = 1) + val cgcArgs = Array("--bootstrap-server", brokerList, "--delete", "--group", group) + val service = getConsumerGroupService(cgcArgs) + + TestUtils.waitUntilTrue(() => { + service.collectGroupMembers(group, false)._2.get.size == 1 + }, "The group did not initialize as expected.") + + val result = service.deleteGroups() + assertNotNull(result(group), + s"Group was deleted successfully, but it shouldn't have been. Result was:(${result})") + assertTrue(result.size == 1 && result.keySet.contains(group) && result(group).getCause.isInstanceOf[GroupNotEmptyException], + s"The expected error (${Errors.NON_EMPTY_GROUP}) was not detected while deleting consumer group. Result was:(${result})") + } + + @Test + def testDeleteCmdEmptyGroup(): Unit = { + TestUtils.createOffsetsTopic(zkClient, servers) + + // run one consumer in the group + val executor = addConsumerGroupExecutor(numConsumers = 1) + val cgcArgs = Array("--bootstrap-server", brokerList, "--delete", "--group", group) + val service = getConsumerGroupService(cgcArgs) + + TestUtils.waitUntilTrue(() => { + service.listConsumerGroups().contains(group) && service.collectGroupState(group).state == "Stable" + }, "The group did not initialize as expected.") + + executor.shutdown() + + TestUtils.waitUntilTrue(() => { + service.collectGroupState(group).state == "Empty" + }, "The group did not become empty as expected.") + + val output = TestUtils.grabConsoleOutput(service.deleteGroups()) + assertTrue(output.contains(s"Deletion of requested consumer groups ('$group') was successful."), + s"The consumer group could not be deleted as expected") + } + + @Test + def testDeleteCmdAllGroups(): Unit = { + TestUtils.createOffsetsTopic(zkClient, servers) + + // Create 3 groups with 1 consumer per each + val groups = + (for (i <- 1 to 3) yield { + val group = this.group + i + val executor = addConsumerGroupExecutor(numConsumers = 1, group = group) + group -> executor + }).toMap + + val cgcArgs = Array("--bootstrap-server", brokerList, "--delete", "--all-groups") + val service = getConsumerGroupService(cgcArgs) + + TestUtils.waitUntilTrue(() => { + service.listConsumerGroups().toSet == groups.keySet && + groups.keySet.forall(groupId => service.collectGroupState(groupId).state == "Stable") + }, "The group did not initialize as expected.") + + // Shutdown consumers to empty out groups + groups.values.foreach(executor => executor.shutdown()) + + TestUtils.waitUntilTrue(() => { + groups.keySet.forall(groupId => service.collectGroupState(groupId).state == "Empty") + }, "The group did not become empty as expected.") + + val output = TestUtils.grabConsoleOutput(service.deleteGroups()).trim + val expectedGroupsForDeletion = groups.keySet + val deletedGroupsGrepped = output.substring(output.indexOf('(') + 1, output.indexOf(')')).split(',') + .map(_.replaceAll("'", "").trim).toSet + + assertTrue(output.matches(s"Deletion of requested consumer groups (.*) was successful.") + && deletedGroupsGrepped == expectedGroupsForDeletion, s"The consumer group(s) could not be deleted as expected" + ) + } + + @Test + def testDeleteEmptyGroup(): Unit = { + TestUtils.createOffsetsTopic(zkClient, servers) + + // run one consumer in the group + val executor = addConsumerGroupExecutor(numConsumers = 1) + val cgcArgs = Array("--bootstrap-server", brokerList, "--delete", "--group", 
group) + val service = getConsumerGroupService(cgcArgs) + + TestUtils.waitUntilTrue(() => { + service.listConsumerGroups().contains(group) && service.collectGroupState(group).state == "Stable" + }, "The group did not initialize as expected.") + + executor.shutdown() + + TestUtils.waitUntilTrue(() => { + service.collectGroupState(group).state == "Empty" + }, "The group did not become empty as expected.") + + val result = service.deleteGroups() + assertTrue(result.size == 1 && result.keySet.contains(group) && result(group) == null, + s"The consumer group could not be deleted as expected") + } + + @Test + def testDeleteCmdWithMixOfSuccessAndError(): Unit = { + TestUtils.createOffsetsTopic(zkClient, servers) + val missingGroup = "missing.group" + + // run one consumer in the group + val executor = addConsumerGroupExecutor(numConsumers = 1) + val cgcArgs = Array("--bootstrap-server", brokerList, "--delete", "--group", group) + val service = getConsumerGroupService(cgcArgs) + + TestUtils.waitUntilTrue(() => { + service.listConsumerGroups().contains(group) && service.collectGroupState(group).state == "Stable" + }, "The group did not initialize as expected.") + + executor.shutdown() + + TestUtils.waitUntilTrue(() => { + service.collectGroupState(group).state == "Empty" + }, "The group did not become empty as expected.") + + val service2 = getConsumerGroupService(cgcArgs ++ Array("--group", missingGroup)) + val output = TestUtils.grabConsoleOutput(service2.deleteGroups()) + assertTrue(output.contains(s"Group '$missingGroup' could not be deleted due to:") && output.contains(Errors.GROUP_ID_NOT_FOUND.message) && + output.contains(s"These consumer groups were deleted successfully: '$group'"), s"The consumer group deletion did not work as expected") + } + + @Test + def testDeleteWithMixOfSuccessAndError(): Unit = { + TestUtils.createOffsetsTopic(zkClient, servers) + val missingGroup = "missing.group" + + // run one consumer in the group + val executor = addConsumerGroupExecutor(numConsumers = 1) + val cgcArgs = Array("--bootstrap-server", brokerList, "--delete", "--group", group) + val service = getConsumerGroupService(cgcArgs) + + TestUtils.waitUntilTrue(() => { + service.listConsumerGroups().contains(group) && service.collectGroupState(group).state == "Stable" + }, "The group did not initialize as expected.") + + executor.shutdown() + + TestUtils.waitUntilTrue(() => { + service.collectGroupState(group).state == "Empty" + }, "The group did not become empty as expected.") + + val service2 = getConsumerGroupService(cgcArgs ++ Array("--group", missingGroup)) + val result = service2.deleteGroups() + assertTrue(result.size == 2 && + result.keySet.contains(group) && result(group) == null && + result.keySet.contains(missingGroup) && + result(missingGroup).getMessage.contains(Errors.GROUP_ID_NOT_FOUND.message), + s"The consumer group deletion did not work as expected") + } + + + @Test + def testDeleteWithUnrecognizedNewConsumerOption(): Unit = { + val cgcArgs = Array("--new-consumer", "--bootstrap-server", brokerList, "--delete", "--group", group) + assertThrows(classOf[OptionException], () => getConsumerGroupService(cgcArgs)) + } +} diff --git a/core/src/test/scala/unit/kafka/admin/DeleteOffsetsConsumerGroupCommandIntegrationTest.scala b/core/src/test/scala/unit/kafka/admin/DeleteOffsetsConsumerGroupCommandIntegrationTest.scala new file mode 100644 index 0000000..2fa99b2 --- /dev/null +++ b/core/src/test/scala/unit/kafka/admin/DeleteOffsetsConsumerGroupCommandIntegrationTest.scala @@ -0,0 +1,197 @@ +/** + * 
Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.admin + +import java.util.Properties + +import kafka.server.Defaults +import kafka.utils.TestUtils +import org.apache.kafka.clients.consumer.ConsumerConfig +import org.apache.kafka.clients.consumer.KafkaConsumer +import org.apache.kafka.clients.producer.KafkaProducer +import org.apache.kafka.clients.producer.ProducerConfig +import org.apache.kafka.clients.producer.ProducerRecord +import org.apache.kafka.common.TopicPartition +import org.apache.kafka.common.protocol.Errors +import org.apache.kafka.common.serialization.ByteArrayDeserializer +import org.apache.kafka.common.serialization.ByteArraySerializer +import org.apache.kafka.common.utils.Utils +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.Assertions._ + +class DeleteOffsetsConsumerGroupCommandIntegrationTest extends ConsumerGroupCommandTest { + + def getArgs(group: String, topic: String): Array[String] = { + Array( + "--bootstrap-server", brokerList, + "--delete-offsets", + "--group", group, + "--topic", topic + ) + } + + @Test + def testDeleteOffsetsNonExistingGroup(): Unit = { + val group = "missing.group" + val topic = "foo:1" + val service = getConsumerGroupService(getArgs(group, topic)) + + val (error, _) = service.deleteOffsets(group, List(topic)) + assertEquals(Errors.GROUP_ID_NOT_FOUND, error) + } + + @Test + def testDeleteOffsetsOfStableConsumerGroupWithTopicPartition(): Unit = { + testWithStableConsumerGroup(topic, 0, 0, Errors.GROUP_SUBSCRIBED_TO_TOPIC) + } + + @Test + def testDeleteOffsetsOfStableConsumerGroupWithTopicOnly(): Unit = { + testWithStableConsumerGroup(topic, -1, 0, Errors.GROUP_SUBSCRIBED_TO_TOPIC) + } + + @Test + def testDeleteOffsetsOfStableConsumerGroupWithUnknownTopicPartition(): Unit = { + testWithStableConsumerGroup("foobar", 0, 0, Errors.UNKNOWN_TOPIC_OR_PARTITION) + } + + @Test + def testDeleteOffsetsOfStableConsumerGroupWithUnknownTopicOnly(): Unit = { + testWithStableConsumerGroup("foobar", -1, -1, Errors.UNKNOWN_TOPIC_OR_PARTITION) + } + + @Test + def testDeleteOffsetsOfEmptyConsumerGroupWithTopicPartition(): Unit = { + testWithEmptyConsumerGroup(topic, 0, 0, Errors.NONE) + } + + @Test + def testDeleteOffsetsOfEmptyConsumerGroupWithTopicOnly(): Unit = { + testWithEmptyConsumerGroup(topic, -1, 0, Errors.NONE) + } + + @Test + def testDeleteOffsetsOfEmptyConsumerGroupWithUnknownTopicPartition(): Unit = { + testWithEmptyConsumerGroup("foobar", 0, 0, Errors.UNKNOWN_TOPIC_OR_PARTITION) + } + + @Test + def testDeleteOffsetsOfEmptyConsumerGroupWithUnknownTopicOnly(): Unit = { + testWithEmptyConsumerGroup("foobar", -1, -1, Errors.UNKNOWN_TOPIC_OR_PARTITION) + } + + private def testWithStableConsumerGroup(inputTopic: String, + inputPartition: Int, + expectedPartition: Int, + 
expectedError: Errors): Unit = { + testWithConsumerGroup( + withStableConsumerGroup, + inputTopic, + inputPartition, + expectedPartition, + expectedError) + } + + private def testWithEmptyConsumerGroup(inputTopic: String, + inputPartition: Int, + expectedPartition: Int, + expectedError: Errors): Unit = { + testWithConsumerGroup( + withEmptyConsumerGroup, + inputTopic, + inputPartition, + expectedPartition, + expectedError) + } + + private def testWithConsumerGroup(withConsumerGroup: (=> Unit) => Unit, + inputTopic: String, + inputPartition: Int, + expectedPartition: Int, + expectedError: Errors): Unit = { + produceRecord() + withConsumerGroup { + val topic = if (inputPartition >= 0) inputTopic + ":" + inputPartition else inputTopic + val service = getConsumerGroupService(getArgs(group, topic)) + val (topLevelError, partitions) = service.deleteOffsets(group, List(topic)) + val tp = new TopicPartition(inputTopic, expectedPartition) + // Partition level error should propagate to top level, unless this is due to a missed partition attempt. + if (inputPartition >= 0) { + assertEquals(expectedError, topLevelError) + } + if (expectedError == Errors.NONE) + assertNull(partitions(tp)) + else + assertEquals(expectedError.exception, partitions(tp).getCause) + } + } + + private def produceRecord(): Unit = { + val producer = createProducer() + try { + producer.send(new ProducerRecord(topic, 0, null, null)).get() + } finally { + Utils.closeQuietly(producer, "producer") + } + } + + private def withStableConsumerGroup(body: => Unit): Unit = { + val consumer = createConsumer() + try { + TestUtils.subscribeAndWaitForRecords(this.topic, consumer) + consumer.commitSync() + body + } finally { + Utils.closeQuietly(consumer, "consumer") + } + } + + private def withEmptyConsumerGroup(body: => Unit): Unit = { + val consumer = createConsumer() + try { + TestUtils.subscribeAndWaitForRecords(this.topic, consumer) + consumer.commitSync() + } finally { + Utils.closeQuietly(consumer, "consumer") + } + body + } + + private def createProducer(config: Properties = new Properties()): KafkaProducer[Array[Byte], Array[Byte]] = { + config.putIfAbsent(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, brokerList) + config.putIfAbsent(ProducerConfig.ACKS_CONFIG, "-1") + config.putIfAbsent(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG, classOf[ByteArraySerializer].getName) + config.putIfAbsent(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG, classOf[ByteArraySerializer].getName) + + new KafkaProducer(config) + } + + private def createConsumer(config: Properties = new Properties()): KafkaConsumer[Array[Byte], Array[Byte]] = { + config.putIfAbsent(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, brokerList) + config.putIfAbsent(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "earliest") + config.putIfAbsent(ConsumerConfig.GROUP_ID_CONFIG, group) + config.putIfAbsent(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, classOf[ByteArrayDeserializer].getName) + config.putIfAbsent(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, classOf[ByteArrayDeserializer].getName) + // Increase timeouts to avoid having a rebalance during the test + config.putIfAbsent(ConsumerConfig.MAX_POLL_INTERVAL_MS_CONFIG, Integer.MAX_VALUE.toString) + config.putIfAbsent(ConsumerConfig.SESSION_TIMEOUT_MS_CONFIG, Defaults.GroupMaxSessionTimeoutMs.toString) + + new KafkaConsumer(config) + } + +} diff --git a/core/src/test/scala/unit/kafka/admin/DeleteTopicTest.scala b/core/src/test/scala/unit/kafka/admin/DeleteTopicTest.scala new file mode 100644 index 0000000..bb881f2 --- /dev/null +++ 
b/core/src/test/scala/unit/kafka/admin/DeleteTopicTest.scala @@ -0,0 +1,450 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package kafka.admin + +import java.util +import java.util.concurrent.ExecutionException +import java.util.{Collections, Optional, Properties} + +import scala.collection.Seq +import kafka.log.UnifiedLog +import kafka.zk.TopicPartitionZNode +import kafka.utils.TestUtils +import kafka.server.{KafkaConfig, KafkaServer, QuorumTestHarness} +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.{AfterEach, Test} +import kafka.common.TopicAlreadyMarkedForDeletionException +import kafka.controller.{OfflineReplica, PartitionAndReplica, ReplicaAssignment, ReplicaDeletionSuccessful} +import org.apache.kafka.clients.admin.{Admin, AdminClientConfig, NewPartitionReassignment, NewPartitions} +import org.apache.kafka.common.TopicPartition +import org.apache.kafka.common.errors.UnknownTopicOrPartitionException +import scala.jdk.CollectionConverters._ + +class DeleteTopicTest extends QuorumTestHarness { + + var servers: Seq[KafkaServer] = Seq() + + val expectedReplicaAssignment = Map(0 -> List(0, 1, 2)) + val expectedReplicaFullAssignment = expectedReplicaAssignment.map { case (k, v) => + k -> ReplicaAssignment(v, List(), List()) + } + + @AfterEach + override def tearDown(): Unit = { + TestUtils.shutdownServers(servers) + super.tearDown() + } + + @Test + def testDeleteTopicWithAllAliveReplicas(): Unit = { + val topic = "test" + servers = createTestTopicAndCluster(topic) + // start topic deletion + adminZkClient.deleteTopic(topic) + TestUtils.verifyTopicDeletion(zkClient, topic, 1, servers) + } + + @Test + def testResumeDeleteTopicWithRecoveredFollower(): Unit = { + val topicPartition = new TopicPartition("test", 0) + val topic = topicPartition.topic + servers = createTestTopicAndCluster(topic) + // shut down one follower replica + val leaderIdOpt = zkClient.getLeaderForPartition(new TopicPartition(topic, 0)) + assertTrue(leaderIdOpt.isDefined, "Leader should exist for partition [test,0]") + val follower = servers.filter(s => s.config.brokerId != leaderIdOpt.get).last + follower.shutdown() + // start topic deletion + adminZkClient.deleteTopic(topic) + // check if all replicas but the one that is shut down has deleted the log + TestUtils.waitUntilTrue(() => + servers.filter(s => s.config.brokerId != follower.config.brokerId) + .forall(_.getLogManager.getLog(topicPartition).isEmpty), "Replicas 0,1 have not deleted log.") + // ensure topic deletion is halted + TestUtils.waitUntilTrue(() => zkClient.isTopicMarkedForDeletion(topic), + "Admin path /admin/delete_topics/test path deleted even when a follower replica is down") + // restart follower replica + follower.startup() + TestUtils.verifyTopicDeletion(zkClient, topic, 
1, servers) + } + + @Test + def testResumeDeleteTopicOnControllerFailover(): Unit = { + val topicPartition = new TopicPartition("test", 0) + val topic = topicPartition.topic + servers = createTestTopicAndCluster(topic) + val controllerId = zkClient.getControllerId.getOrElse(fail("Controller doesn't exist")) + val controller = servers.filter(s => s.config.brokerId == controllerId).head + val leaderIdOpt = zkClient.getLeaderForPartition(new TopicPartition(topic, 0)) + val follower = servers.filter(s => s.config.brokerId != leaderIdOpt.get && s.config.brokerId != controllerId).last + follower.shutdown() + + // start topic deletion + adminZkClient.deleteTopic(topic) + // shut down the controller to trigger controller failover during delete topic + controller.shutdown() + + // ensure topic deletion is halted + TestUtils.waitUntilTrue(() => zkClient.isTopicMarkedForDeletion(topic), + "Admin path /admin/delete_topics/test path deleted even when a replica is down") + + controller.startup() + follower.startup() + + TestUtils.verifyTopicDeletion(zkClient, topic, 1, servers) + } + + @Test + def testPartitionReassignmentDuringDeleteTopic(): Unit = { + val topic = "test" + val topicPartition = new TopicPartition(topic, 0) + val brokerConfigs = TestUtils.createBrokerConfigs(4, zkConnect, false) + brokerConfigs.foreach(p => p.setProperty("delete.topic.enable", "true")) + // create brokers + val allServers = brokerConfigs.map(b => TestUtils.createServer(KafkaConfig.fromProps(b))) + this.servers = allServers + val servers = allServers.filter(s => expectedReplicaAssignment(0).contains(s.config.brokerId)) + // create the topic + TestUtils.createTopic(zkClient, topic, expectedReplicaAssignment, servers) + // wait until replica log is created on every broker + TestUtils.waitUntilTrue(() => servers.forall(_.getLogManager.getLog(topicPartition).isDefined), + "Replicas for topic test not created.") + val leaderIdOpt = zkClient.getLeaderForPartition(new TopicPartition(topic, 0)) + assertTrue(leaderIdOpt.isDefined, "Leader should exist for partition [test,0]") + val follower = servers.filter(s => s.config.brokerId != leaderIdOpt.get).last + follower.shutdown() + // start topic deletion + adminZkClient.deleteTopic(topic) + // verify that a partition from the topic cannot be reassigned + val props = new Properties() + props.setProperty(AdminClientConfig.BOOTSTRAP_SERVERS_CONFIG, TestUtils.getBrokerListStrFromServers(servers)) + val adminClient = Admin.create(props) + try { + waitUntilTopicGone(adminClient, "test") + verifyReassignmentFailsForMissing(adminClient, new TopicPartition(topic, 0), + new NewPartitionReassignment(util.Arrays.asList(1, 2, 3))) + } finally { + adminClient.close() + } + follower.startup() + TestUtils.verifyTopicDeletion(zkClient, topic, 1, servers) + } + + private def waitUntilTopicGone(adminClient: Admin, topicName: String): Unit = { + TestUtils.waitUntilTrue(() => { + try { + adminClient.describeTopics(util.Collections.singletonList(topicName)).allTopicNames().get() + false + } catch { + case e: ExecutionException => + classOf[UnknownTopicOrPartitionException].equals(e.getCause.getClass) + } + }, s"Topic ${topicName} should be deleted.") + } + + private def verifyReassignmentFailsForMissing(adminClient: Admin, + partition: TopicPartition, + reassignment: NewPartitionReassignment): Unit = { + val e = assertThrows(classOf[ExecutionException], () => adminClient.alterPartitionReassignments(Collections.singletonMap(partition, + Optional.of(reassignment))).all().get()) + 
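+ // the topic is in the middle of being deleted, so the reassignment attempt is expected to fail; the cause is checked below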
assertEquals(classOf[UnknownTopicOrPartitionException], e.getCause.getClass) + } + + private def getController() : (KafkaServer, Int) = { + val controllerId = zkClient.getControllerId.getOrElse(throw new AssertionError("Controller doesn't exist")) + val controller = servers.find(s => s.config.brokerId == controllerId).get + (controller, controllerId) + } + + private def ensureControllerExists() = { + TestUtils.waitUntilTrue(() => { + try { + getController() + true + } catch { + case _: Throwable => false + } + }, "Controller should eventually exist") + } + + private def getAllReplicasFromAssignment(topic : String, assignment : Map[Int, Seq[Int]]) : Set[PartitionAndReplica] = { + assignment.flatMap { case (partition, replicas) => + replicas.map {r => new PartitionAndReplica(new TopicPartition(topic, partition), r)} + }.toSet + } + + @Test + def testIncreasePartitionCountDuringDeleteTopic(): Unit = { + val topic = "test" + val topicPartition = new TopicPartition(topic, 0) + val brokerConfigs = TestUtils.createBrokerConfigs(4, zkConnect, false) + brokerConfigs.foreach(p => p.setProperty("delete.topic.enable", "true")) + // create brokers + val allServers = brokerConfigs.map(b => TestUtils.createServer(KafkaConfig.fromProps(b))) + this.servers = allServers + val servers = allServers.filter(s => expectedReplicaAssignment(0).contains(s.config.brokerId)) + // create the topic + TestUtils.createTopic(zkClient, topic, expectedReplicaAssignment, servers) + // wait until replica log is created on every broker + TestUtils.waitUntilTrue(() => servers.forall(_.getLogManager.getLog(topicPartition).isDefined), + "Replicas for topic test not created.") + // shutdown a broker to make sure the following topic deletion will be suspended + val leaderIdOpt = zkClient.getLeaderForPartition(topicPartition) + assertTrue(leaderIdOpt.isDefined, "Leader should exist for partition [test,0]") + val follower = servers.filter(s => s.config.brokerId != leaderIdOpt.get).last + follower.shutdown() + // start topic deletion + adminZkClient.deleteTopic(topic) + + // make sure deletion of all of the topic's replicas have been tried + ensureControllerExists() + val (controller, controllerId) = getController() + val allReplicasForTopic = getAllReplicasFromAssignment(topic, expectedReplicaAssignment) + TestUtils.waitUntilTrue(() => { + val replicasInDeletionSuccessful = controller.kafkaController.controllerContext.replicasInState(topic, ReplicaDeletionSuccessful) + val offlineReplicas = controller.kafkaController.controllerContext.replicasInState(topic, OfflineReplica) + allReplicasForTopic == (replicasInDeletionSuccessful union offlineReplicas) + }, s"Not all replicas for topic $topic are in states of either ReplicaDeletionSuccessful or OfflineReplica") + + // increase the partition count for topic + val props = new Properties() + props.setProperty(AdminClientConfig.BOOTSTRAP_SERVERS_CONFIG, TestUtils.getBrokerListStrFromServers(servers)) + val adminClient = Admin.create(props) + try { + adminClient.createPartitions(Map(topic -> NewPartitions.increaseTo(2)).asJava).all().get() + } catch { + case _: ExecutionException => + } + // trigger a controller switch now + val previousControllerId = controllerId + + controller.shutdown() + + ensureControllerExists() + // wait until a new controller to show up + TestUtils.waitUntilTrue(() => { + val (newController, newControllerId) = getController() + newControllerId != previousControllerId + }, "The new controller should not have the failed controller id") + + // bring back the failed 
brokers + follower.startup() + controller.startup() + TestUtils.verifyTopicDeletion(zkClient, topic, 2, servers) + adminClient.close() + } + + + @Test + def testDeleteTopicDuringAddPartition(): Unit = { + val topic = "test" + servers = createTestTopicAndCluster(topic) + val leaderIdOpt = zkClient.getLeaderForPartition(new TopicPartition(topic, 0)) + assertTrue(leaderIdOpt.isDefined, "Leader should exist for partition [test,0]") + val follower = servers.filter(_.config.brokerId != leaderIdOpt.get).last + val newPartition = new TopicPartition(topic, 1) + // capture the brokers before we shutdown so that we don't fail validation in `addPartitions` + val brokers = adminZkClient.getBrokerMetadatas() + follower.shutdown() + // wait until the broker has been removed from ZK to reduce non-determinism + TestUtils.waitUntilTrue(() => zkClient.getBroker(follower.config.brokerId).isEmpty, + s"Follower ${follower.config.brokerId} was not removed from ZK") + // add partitions to topic + adminZkClient.addPartitions(topic, expectedReplicaFullAssignment, brokers, 2, + Some(Map(1 -> Seq(0, 1, 2), 2 -> Seq(0, 1, 2)))) + // start topic deletion + adminZkClient.deleteTopic(topic) + follower.startup() + // test if topic deletion is resumed + TestUtils.verifyTopicDeletion(zkClient, topic, 1, servers) + // verify that new partition doesn't exist on any broker either + TestUtils.waitUntilTrue(() => + servers.forall(_.getLogManager.getLog(newPartition).isEmpty), + "Replica logs not for new partition [test,1] not deleted after delete topic is complete.") + } + + @Test + def testAddPartitionDuringDeleteTopic(): Unit = { + zkClient.createTopLevelPaths() + val topic = "test" + servers = createTestTopicAndCluster(topic) + val brokers = adminZkClient.getBrokerMetadatas() + // start topic deletion + adminZkClient.deleteTopic(topic) + // add partitions to topic + val newPartition = new TopicPartition(topic, 1) + adminZkClient.addPartitions(topic, expectedReplicaFullAssignment, brokers, 2, + Some(Map(1 -> Seq(0, 1, 2), 2 -> Seq(0, 1, 2)))) + TestUtils.verifyTopicDeletion(zkClient, topic, 1, servers) + // verify that new partition doesn't exist on any broker either + assertTrue(servers.forall(_.getLogManager.getLog(newPartition).isEmpty), "Replica logs not deleted after delete topic is complete") + } + + @Test + def testRecreateTopicAfterDeletion(): Unit = { + val expectedReplicaAssignment = Map(0 -> List(0, 1, 2)) + val topic = "test" + val topicPartition = new TopicPartition(topic, 0) + servers = createTestTopicAndCluster(topic) + // start topic deletion + adminZkClient.deleteTopic(topic) + TestUtils.verifyTopicDeletion(zkClient, topic, 1, servers) + // re-create topic on same replicas + TestUtils.createTopic(zkClient, topic, expectedReplicaAssignment, servers) + // check if all replica logs are created + TestUtils.waitUntilTrue(() => servers.forall(_.getLogManager.getLog(topicPartition).isDefined), + "Replicas for topic test not created.") + } + + @Test + def testDeleteNonExistingTopic(): Unit = { + val topicPartition = new TopicPartition("test", 0) + val topic = topicPartition.topic + servers = createTestTopicAndCluster(topic) + // start topic deletion + assertThrows(classOf[UnknownTopicOrPartitionException], () => adminZkClient.deleteTopic("test2")) + // verify delete topic path for test2 is removed from ZooKeeper + TestUtils.verifyTopicDeletion(zkClient, "test2", 1, servers) + // verify that topic test is untouched + TestUtils.waitUntilTrue(() => servers.forall(_.getLogManager.getLog(topicPartition).isDefined), + 
"Replicas for topic test not created") + // test the topic path exists + assertTrue(zkClient.topicExists(topic), "Topic test mistakenly deleted") + // topic test should have a leader + TestUtils.waitUntilLeaderIsElectedOrChanged(zkClient, topic, 0, 1000) + } + + @Test + def testDeleteTopicWithCleaner(): Unit = { + val topicName = "test" + val topicPartition = new TopicPartition(topicName, 0) + val topic = topicPartition.topic + + val brokerConfigs = TestUtils.createBrokerConfigs(3, zkConnect, false) + brokerConfigs.head.setProperty("delete.topic.enable", "true") + brokerConfigs.head.setProperty("log.cleaner.enable","true") + brokerConfigs.head.setProperty("log.cleanup.policy","compact") + brokerConfigs.head.setProperty("log.segment.bytes","100") + brokerConfigs.head.setProperty("log.cleaner.dedupe.buffer.size","1048577") + + servers = createTestTopicAndCluster(topic, brokerConfigs, expectedReplicaAssignment) + + // for simplicity, we are validating cleaner offsets on a single broker + val server = servers.head + val log = server.logManager.getLog(topicPartition).get + + // write to the topic to activate cleaner + writeDups(numKeys = 100, numDups = 3,log) + + // wait for cleaner to clean + server.logManager.cleaner.awaitCleaned(new TopicPartition(topicName, 0), 0) + + // delete topic + adminZkClient.deleteTopic("test") + TestUtils.verifyTopicDeletion(zkClient, "test", 1, servers) + } + + @Test + def testDeleteTopicAlreadyMarkedAsDeleted(): Unit = { + val topicPartition = new TopicPartition("test", 0) + val topic = topicPartition.topic + servers = createTestTopicAndCluster(topic) + // start topic deletion + adminZkClient.deleteTopic(topic) + // try to delete topic marked as deleted + assertThrows(classOf[TopicAlreadyMarkedForDeletionException], () => adminZkClient.deleteTopic(topic)) + + TestUtils.verifyTopicDeletion(zkClient, topic, 1, servers) + } + + private def createTestTopicAndCluster(topic: String, deleteTopicEnabled: Boolean = true, replicaAssignment: Map[Int, List[Int]] = expectedReplicaAssignment): Seq[KafkaServer] = { + val brokerConfigs = TestUtils.createBrokerConfigs(3, zkConnect, enableControlledShutdown = false) + brokerConfigs.foreach(_.setProperty("delete.topic.enable", deleteTopicEnabled.toString)) + createTestTopicAndCluster(topic, brokerConfigs, replicaAssignment) + } + + private def createTestTopicAndCluster(topic: String, brokerConfigs: Seq[Properties], replicaAssignment: Map[Int, List[Int]]): Seq[KafkaServer] = { + val topicPartition = new TopicPartition(topic, 0) + // create brokers + val servers = brokerConfigs.map(b => TestUtils.createServer(KafkaConfig.fromProps(b))) + // create the topic + TestUtils.createTopic(zkClient, topic, expectedReplicaAssignment, servers) + // wait until replica log is created on every broker + TestUtils.waitUntilTrue(() => servers.forall(_.getLogManager.getLog(topicPartition).isDefined), + "Replicas for topic test not created") + servers + } + + private def writeDups(numKeys: Int, numDups: Int, log: UnifiedLog): Seq[(Int, Int)] = { + var counter = 0 + for (_ <- 0 until numDups; key <- 0 until numKeys) yield { + val count = counter + log.appendAsLeader(TestUtils.singletonRecords(value = counter.toString.getBytes, key = key.toString.getBytes), leaderEpoch = 0) + counter += 1 + (key, count) + } + } + + @Test + def testDisableDeleteTopic(): Unit = { + val topicPartition = new TopicPartition("test", 0) + val topic = topicPartition.topic + servers = createTestTopicAndCluster(topic, deleteTopicEnabled = false) + // mark the topic for deletion + 
adminZkClient.deleteTopic("test") + TestUtils.waitUntilTrue(() => !zkClient.isTopicMarkedForDeletion(topic), + "Admin path /admin/delete_topics/%s path not deleted even if deleteTopic is disabled".format(topic)) + // verify that topic test is untouched + assertTrue(servers.forall(_.getLogManager.getLog(topicPartition).isDefined)) + // test the topic path exists + assertTrue(zkClient.topicExists(topic), "Topic path disappeared") + // topic test should have a leader + val leaderIdOpt = zkClient.getLeaderForPartition(new TopicPartition(topic, 0)) + assertTrue(leaderIdOpt.isDefined, "Leader should exist for topic test") + } + + @Test + def testDeletingPartiallyDeletedTopic(): Unit = { + /** + * A previous controller could have deleted some partitions of a topic from ZK, but not all partitions, and then crashed. + * In that case, the new controller should be able to handle the partially deleted topic, and finish the deletion. + */ + + val replicaAssignment = Map(0 -> List(0, 1, 2), 1 -> List(0, 1, 2)) + val topic = "test" + servers = createTestTopicAndCluster(topic, true, replicaAssignment) + + /** + * shutdown all brokers in order to create a partially deleted topic on ZK + */ + servers.foreach(_.shutdown()) + + /** + * delete the partition znode at /brokers/topics/test/partition/0 + * to simulate the case that a previous controller crashed right after deleting the partition znode + */ + zkClient.deleteRecursive(TopicPartitionZNode.path(new TopicPartition(topic, 0))) + adminZkClient.deleteTopic(topic) + + /** + * start up all brokers and verify that topic deletion eventually finishes. + */ + servers.foreach(_.startup()) + TestUtils.waitUntilTrue(() => servers.exists(_.kafkaController.isActive), "No controller is elected") + TestUtils.verifyTopicDeletion(zkClient, topic, 2, servers) + } +} diff --git a/core/src/test/scala/unit/kafka/admin/DescribeConsumerGroupTest.scala b/core/src/test/scala/unit/kafka/admin/DescribeConsumerGroupTest.scala new file mode 100644 index 0000000..e5ca106 --- /dev/null +++ b/core/src/test/scala/unit/kafka/admin/DescribeConsumerGroupTest.scala @@ -0,0 +1,698 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package kafka.admin + +import java.util.Properties + +import kafka.utils.{Exit, TestUtils} +import org.apache.kafka.clients.consumer.{ConsumerConfig, RoundRobinAssignor} +import org.apache.kafka.common.TopicPartition +import org.apache.kafka.common.errors.TimeoutException +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.Test + +import scala.concurrent.ExecutionException +import scala.util.Random + +class DescribeConsumerGroupTest extends ConsumerGroupCommandTest { + + private val describeTypeOffsets = Array(Array(""), Array("--offsets")) + private val describeTypeMembers = Array(Array("--members"), Array("--members", "--verbose")) + private val describeTypeState = Array(Array("--state")) + private val describeTypes = describeTypeOffsets ++ describeTypeMembers ++ describeTypeState + + @Test + def testDescribeNonExistingGroup(): Unit = { + TestUtils.createOffsetsTopic(zkClient, servers) + val missingGroup = "missing.group" + + for (describeType <- describeTypes) { + // note the group to be queried is a different (non-existing) group + val cgcArgs = Array("--bootstrap-server", brokerList, "--describe", "--group", missingGroup) ++ describeType + val service = getConsumerGroupService(cgcArgs) + + val output = TestUtils.grabConsoleOutput(service.describeGroups()) + assertTrue(output.contains(s"Consumer group '$missingGroup' does not exist."), + s"Expected error was not detected for describe option '${describeType.mkString(" ")}'") + } + } + + @Test + def testDescribeWithMultipleSubActions(): Unit = { + var exitStatus: Option[Int] = None + var exitMessage: Option[String] = None + Exit.setExitProcedure { (status, err) => + exitStatus = Some(status) + exitMessage = err + throw new RuntimeException + } + val cgcArgs = Array("--bootstrap-server", brokerList, "--describe", "--group", group, "--members", "--state") + try { + ConsumerGroupCommand.main(cgcArgs) + } catch { + case e: RuntimeException => //expected + } finally { + Exit.resetExitProcedure() + } + assertEquals(Some(1), exitStatus) + assertTrue(exitMessage.get.contains("Option [describe] takes at most one of these options")) + } + + @Test + def testDescribeWithStateValue(): Unit = { + var exitStatus: Option[Int] = None + var exitMessage: Option[String] = None + Exit.setExitProcedure { (status, err) => + exitStatus = Some(status) + exitMessage = err + throw new RuntimeException + } + val cgcArgs = Array("--bootstrap-server", brokerList, "--describe", "--all-groups", "--state", "Stable") + try { + ConsumerGroupCommand.main(cgcArgs) + } catch { + case e: RuntimeException => //expected + } finally { + Exit.resetExitProcedure() + } + assertEquals(Some(1), exitStatus) + assertTrue(exitMessage.get.contains("Option [describe] does not take a value for [state]")) + } + + @Test + def testDescribeOffsetsOfNonExistingGroup(): Unit = { + val group = "missing.group" + TestUtils.createOffsetsTopic(zkClient, servers) + + // run one consumer in the group consuming from a single-partition topic + addConsumerGroupExecutor(numConsumers = 1) + // note the group to be queried is a different (non-existing) group + val cgcArgs = Array("--bootstrap-server", brokerList, "--describe", "--group", group) + val service = getConsumerGroupService(cgcArgs) + + val (state, assignments) = service.collectGroupOffsets(group) + assertTrue(state.contains("Dead") && assignments.contains(List()), + s"Expected the state to be 'Dead', with no members in the group '$group'.") + } + + @Test + def testDescribeMembersOfNonExistingGroup(): Unit = { + val group 
= "missing.group" + TestUtils.createOffsetsTopic(zkClient, servers) + + // run one consumer in the group consuming from a single-partition topic + addConsumerGroupExecutor(numConsumers = 1) + // note the group to be queried is a different (non-existing) group + val cgcArgs = Array("--bootstrap-server", brokerList, "--describe", "--group", group) + val service = getConsumerGroupService(cgcArgs) + + val (state, assignments) = service.collectGroupMembers(group, false) + assertTrue(state.contains("Dead") && assignments.contains(List()), + s"Expected the state to be 'Dead', with no members in the group '$group'.") + + val (state2, assignments2) = service.collectGroupMembers(group, true) + assertTrue(state2.contains("Dead") && assignments2.contains(List()), + s"Expected the state to be 'Dead', with no members in the group '$group' (verbose option).") + } + + @Test + def testDescribeStateOfNonExistingGroup(): Unit = { + val group = "missing.group" + TestUtils.createOffsetsTopic(zkClient, servers) + + // run one consumer in the group consuming from a single-partition topic + addConsumerGroupExecutor(numConsumers = 1) + // note the group to be queried is a different (non-existing) group + val cgcArgs = Array("--bootstrap-server", brokerList, "--describe", "--group", group) + val service = getConsumerGroupService(cgcArgs) + + val state = service.collectGroupState(group) + assertTrue(state.state == "Dead" && state.numMembers == 0 && + state.coordinator != null && servers.map(_.config.brokerId).toList.contains(state.coordinator.id), + s"Expected the state to be 'Dead', with no members in the group '$group'." + ) + } + + @Test + def testDescribeExistingGroup(): Unit = { + TestUtils.createOffsetsTopic(zkClient, servers) + + for (describeType <- describeTypes) { + val group = this.group + describeType.mkString("") + // run one consumer in the group consuming from a single-partition topic + addConsumerGroupExecutor(numConsumers = 1, group = group) + val cgcArgs = Array("--bootstrap-server", brokerList, "--describe", "--group", group) ++ describeType + val service = getConsumerGroupService(cgcArgs) + + TestUtils.waitUntilTrue(() => { + val (output, error) = TestUtils.grabConsoleOutputAndError(service.describeGroups()) + output.trim.split("\n").length == 2 && error.isEmpty + }, s"Expected a data row and no error in describe results with describe type ${describeType.mkString(" ")}.") + } + } + + @Test + def testDescribeExistingGroups(): Unit = { + TestUtils.createOffsetsTopic(zkClient, servers) + + // Create N single-threaded consumer groups from a single-partition topic + val groups = (for (describeType <- describeTypes) yield { + val group = this.group + describeType.mkString("") + addConsumerGroupExecutor(numConsumers = 1, group = group) + Array("--group", group) + }).flatten + + val expectedNumLines = describeTypes.length * 2 + + for (describeType <- describeTypes) { + val cgcArgs = Array("--bootstrap-server", brokerList, "--describe") ++ groups ++ describeType + val service = getConsumerGroupService(cgcArgs) + + TestUtils.waitUntilTrue(() => { + val (output, error) = TestUtils.grabConsoleOutputAndError(service.describeGroups()) + val numLines = output.trim.split("\n").filterNot(line => line.isEmpty).length + (numLines == expectedNumLines) && error.isEmpty + }, s"Expected a data row and no error in describe results with describe type ${describeType.mkString(" ")}.") + } + } + + @Test + def testDescribeAllExistingGroups(): Unit = { + TestUtils.createOffsetsTopic(zkClient, servers) + + // Create N 
single-threaded consumer groups from a single-partition topic + for (describeType <- describeTypes) { + val group = this.group + describeType.mkString("") + addConsumerGroupExecutor(numConsumers = 1, group = group) + } + + val expectedNumLines = describeTypes.length * 2 + + for (describeType <- describeTypes) { + val cgcArgs = Array("--bootstrap-server", brokerList, "--describe", "--all-groups") ++ describeType + val service = getConsumerGroupService(cgcArgs) + + TestUtils.waitUntilTrue(() => { + val (output, error) = TestUtils.grabConsoleOutputAndError(service.describeGroups()) + val numLines = output.trim.split("\n").filterNot(line => line.isEmpty).length + (numLines == expectedNumLines) && error.isEmpty + }, s"Expected a data row and no error in describe results with describe type ${describeType.mkString(" ")}.") + } + } + + @Test + def testDescribeOffsetsOfExistingGroup(): Unit = { + TestUtils.createOffsetsTopic(zkClient, servers) + + // run one consumer in the group consuming from a single-partition topic + addConsumerGroupExecutor(numConsumers = 1) + + val cgcArgs = Array("--bootstrap-server", brokerList, "--describe", "--group", group) + val service = getConsumerGroupService(cgcArgs) + + TestUtils.waitUntilTrue(() => { + val (state, assignments) = service.collectGroupOffsets(group) + state.contains("Stable") && + assignments.isDefined && + assignments.get.count(_.group == group) == 1 && + assignments.get.filter(_.group == group).head.consumerId.exists(_.trim != ConsumerGroupCommand.MISSING_COLUMN_VALUE) && + assignments.get.filter(_.group == group).head.clientId.exists(_.trim != ConsumerGroupCommand.MISSING_COLUMN_VALUE) && + assignments.get.filter(_.group == group).head.host.exists(_.trim != ConsumerGroupCommand.MISSING_COLUMN_VALUE) + }, s"Expected a 'Stable' group status, rows and valid values for consumer id / client id / host columns in describe results for group $group.") + } + + @Test + def testDescribeMembersOfExistingGroup(): Unit = { + TestUtils.createOffsetsTopic(zkClient, servers) + + // run one consumer in the group consuming from a single-partition topic + addConsumerGroupExecutor(numConsumers = 1) + val cgcArgs = Array("--bootstrap-server", brokerList, "--describe", "--group", group) + val service = getConsumerGroupService(cgcArgs) + + TestUtils.waitUntilTrue(() => { + val (state, assignments) = service.collectGroupMembers(group, false) + state.contains("Stable") && + (assignments match { + case Some(memberAssignments) => + memberAssignments.count(_.group == group) == 1 && + memberAssignments.filter(_.group == group).head.consumerId != ConsumerGroupCommand.MISSING_COLUMN_VALUE && + memberAssignments.filter(_.group == group).head.clientId != ConsumerGroupCommand.MISSING_COLUMN_VALUE && + memberAssignments.filter(_.group == group).head.host != ConsumerGroupCommand.MISSING_COLUMN_VALUE + case None => + false + }) + }, s"Expected a 'Stable' group status, rows and valid member information for group $group.") + + val (_, assignments) = service.collectGroupMembers(group, true) + assignments match { + case None => + fail(s"Expected partition assignments for members of group $group") + case Some(memberAssignments) => + assertTrue(memberAssignments.size == 1 && memberAssignments.head.assignment.size == 1, + s"Expected a topic partition assigned to the single group member for group $group") + } + } + + @Test + def testDescribeStateOfExistingGroup(): Unit = { + TestUtils.createOffsetsTopic(zkClient, servers) + + // run one consumer in the group consuming from a single-partition 
topic + addConsumerGroupExecutor(numConsumers = 1) + val cgcArgs = Array("--bootstrap-server", brokerList, "--describe", "--group", group) + val service = getConsumerGroupService(cgcArgs) + + TestUtils.waitUntilTrue(() => { + val state = service.collectGroupState(group) + state.state == "Stable" && + state.numMembers == 1 && + state.assignmentStrategy == "range" && + state.coordinator != null && + servers.map(_.config.brokerId).toList.contains(state.coordinator.id) + }, s"Expected a 'Stable' group status, with one member and range assignment strategy for group $group.") + } + + @Test + def testDescribeStateOfExistingGroupWithRoundRobinAssignor(): Unit = { + TestUtils.createOffsetsTopic(zkClient, servers) + + // run one consumer in the group consuming from a single-partition topic + addConsumerGroupExecutor(numConsumers = 1, strategy = classOf[RoundRobinAssignor].getName) + val cgcArgs = Array("--bootstrap-server", brokerList, "--describe", "--group", group) + val service = getConsumerGroupService(cgcArgs) + + TestUtils.waitUntilTrue(() => { + val state = service.collectGroupState(group) + state.state == "Stable" && + state.numMembers == 1 && + state.assignmentStrategy == "roundrobin" && + state.coordinator != null && + servers.map(_.config.brokerId).toList.contains(state.coordinator.id) + }, s"Expected a 'Stable' group status, with one member and round robin assignment strategy for group $group.") + } + + @Test + def testDescribeExistingGroupWithNoMembers(): Unit = { + TestUtils.createOffsetsTopic(zkClient, servers) + + for (describeType <- describeTypes) { + val group = this.group + describeType.mkString("") + // run one consumer in the group consuming from a single-partition topic + val executor = addConsumerGroupExecutor(numConsumers = 1, group = group) + val cgcArgs = Array("--bootstrap-server", brokerList, "--describe", "--group", group) ++ describeType + val service = getConsumerGroupService(cgcArgs) + + TestUtils.waitUntilTrue(() => { + val (output, error) = TestUtils.grabConsoleOutputAndError(service.describeGroups()) + output.trim.split("\n").length == 2 && error.isEmpty + }, s"Expected describe group results with one data row for describe type '${describeType.mkString(" ")}'") + + // stop the consumer so the group has no active member anymore + executor.shutdown() + TestUtils.waitUntilTrue(() => { + TestUtils.grabConsoleError(service.describeGroups()).contains(s"Consumer group '$group' has no active members.") + }, s"Expected no active member in describe group results with describe type ${describeType.mkString(" ")}") + } + } + + @Test + def testDescribeOffsetsOfExistingGroupWithNoMembers(): Unit = { + TestUtils.createOffsetsTopic(zkClient, servers) + + // run one consumer in the group consuming from a single-partition topic + val executor = addConsumerGroupExecutor(numConsumers = 1) + + val cgcArgs = Array("--bootstrap-server", brokerList, "--describe", "--group", group) + val service = getConsumerGroupService(cgcArgs) + + TestUtils.waitUntilTrue(() => { + val (state, assignments) = service.collectGroupOffsets(group) + state.contains("Stable") && assignments.exists(_.exists(_.group == group)) + }, "Expected the group to initially become stable, and to find group in assignments after initial offset commit.") + + // stop the consumer so the group has no active member anymore + executor.shutdown() + + val (result, succeeded) = TestUtils.computeUntilTrue(service.collectGroupOffsets(group)) { + case (state, assignments) => + val testGroupAssignments =
assignments.toSeq.flatMap(_.filter(_.group == group)) + def assignment = testGroupAssignments.head + state.contains("Empty") && + testGroupAssignments.size == 1 && + assignment.consumerId.exists(_.trim == ConsumerGroupCommand.MISSING_COLUMN_VALUE) && // the member should be gone + assignment.clientId.exists(_.trim == ConsumerGroupCommand.MISSING_COLUMN_VALUE) && + assignment.host.exists(_.trim == ConsumerGroupCommand.MISSING_COLUMN_VALUE) + } + val (state, assignments) = result + assertTrue(succeeded, s"Expected no active member in describe group results, state: $state, assignments: $assignments") + } + + @Test + def testDescribeMembersOfExistingGroupWithNoMembers(): Unit = { + TestUtils.createOffsetsTopic(zkClient, servers) + + // run one consumer in the group consuming from a single-partition topic + val executor = addConsumerGroupExecutor(numConsumers = 1) + + val cgcArgs = Array("--bootstrap-server", brokerList, "--describe", "--group", group) + val service = getConsumerGroupService(cgcArgs) + + TestUtils.waitUntilTrue(() => { + val (state, assignments) = service.collectGroupMembers(group, false) + state.contains("Stable") && assignments.exists(_.exists(_.group == group)) + }, "Expected the group to initially become stable, and to find group in assignments after initial offset commit.") + + // stop the consumer so the group has no active member anymore + executor.shutdown() + + TestUtils.waitUntilTrue(() => { + val (state, assignments) = service.collectGroupMembers(group, false) + state.contains("Empty") && assignments.isDefined && assignments.get.isEmpty + }, s"Expected no member in describe group members results for group '$group'") + } + + @Test + def testDescribeStateOfExistingGroupWithNoMembers(): Unit = { + TestUtils.createOffsetsTopic(zkClient, servers) + + // run one consumer in the group consuming from a single-partition topic + val executor = addConsumerGroupExecutor(numConsumers = 1) + + val cgcArgs = Array("--bootstrap-server", brokerList, "--describe", "--group", group) + val service = getConsumerGroupService(cgcArgs) + + TestUtils.waitUntilTrue(() => { + val state = service.collectGroupState(group) + state.state == "Stable" && + state.numMembers == 1 && + state.coordinator != null && + servers.map(_.config.brokerId).toList.contains(state.coordinator.id) + }, s"Expected the group '$group' to initially become stable, and have a single member.") + + // stop the consumer so the group has no active member anymore + executor.shutdown() + + TestUtils.waitUntilTrue(() => { + val state = service.collectGroupState(group) + state.state == "Empty" && state.numMembers == 0 && state.assignmentStrategy == "" + }, s"Expected the group '$group' to become empty after the only member leaving.") + } + + @Test + def testDescribeWithConsumersWithoutAssignedPartitions(): Unit = { + TestUtils.createOffsetsTopic(zkClient, servers) + + for (describeType <- describeTypes) { + val group = this.group + describeType.mkString("") + // run two consumers in the group consuming from a single-partition topic + addConsumerGroupExecutor(numConsumers = 2, group = group) + val cgcArgs = Array("--bootstrap-server", brokerList, "--describe", "--group", group) ++ describeType + val service = getConsumerGroupService(cgcArgs) + + TestUtils.waitUntilTrue(() => { + val (output, error) = TestUtils.grabConsoleOutputAndError(service.describeGroups()) + val expectedNumRows = if (describeTypeMembers.contains(describeType)) 3 else 2 + error.isEmpty && output.trim.split("\n").size == expectedNumRows + }, s"Expected a single 
data row in describe group result with describe type '${describeType.mkString(" ")}'") + } + } + + @Test + def testDescribeOffsetsWithConsumersWithoutAssignedPartitions(): Unit = { + TestUtils.createOffsetsTopic(zkClient, servers) + + // run two consumers in the group consuming from a single-partition topic + addConsumerGroupExecutor(numConsumers = 2) + + val cgcArgs = Array("--bootstrap-server", brokerList, "--describe", "--group", group) + val service = getConsumerGroupService(cgcArgs) + + TestUtils.waitUntilTrue(() => { + val (state, assignments) = service.collectGroupOffsets(group) + state.contains("Stable") && + assignments.isDefined && + assignments.get.count(_.group == group) == 1 && + assignments.get.count { x => x.group == group && x.partition.isDefined } == 1 + }, "Expected rows for consumers with no assigned partitions in describe group results") + } + + @Test + def testDescribeMembersWithConsumersWithoutAssignedPartitions(): Unit = { + TestUtils.createOffsetsTopic(zkClient, servers) + + // run two consumers in the group consuming from a single-partition topic + addConsumerGroupExecutor(numConsumers = 2) + + val cgcArgs = Array("--bootstrap-server", brokerList, "--describe", "--group", group) + val service = getConsumerGroupService(cgcArgs) + + TestUtils.waitUntilTrue(() => { + val (state, assignments) = service.collectGroupMembers(group, false) + state.contains("Stable") && + assignments.isDefined && + assignments.get.count(_.group == group) == 2 && + assignments.get.count { x => x.group == group && x.numPartitions == 1 } == 1 && + assignments.get.count { x => x.group == group && x.numPartitions == 0 } == 1 && + assignments.get.count(_.assignment.nonEmpty) == 0 + }, "Expected rows for consumers with no assigned partitions in describe group results") + + val (state, assignments) = service.collectGroupMembers(group, true) + assertTrue(state.contains("Stable") && assignments.get.count(_.assignment.nonEmpty) > 0, + "Expected additional columns in verbose version of describe members") + } + + @Test + def testDescribeStateWithConsumersWithoutAssignedPartitions(): Unit = { + TestUtils.createOffsetsTopic(zkClient, servers) + + // run two consumers in the group consuming from a single-partition topic + addConsumerGroupExecutor(numConsumers = 2) + + val cgcArgs = Array("--bootstrap-server", brokerList, "--describe", "--group", group) + val service = getConsumerGroupService(cgcArgs) + + TestUtils.waitUntilTrue(() => { + val state = service.collectGroupState(group) + state.state == "Stable" && state.numMembers == 2 + }, "Expected two consumers in describe group results") + } + + @Test + def testDescribeWithMultiPartitionTopicAndMultipleConsumers(): Unit = { + TestUtils.createOffsetsTopic(zkClient, servers) + val topic2 = "foo2" + createTopic(topic2, 2, 1) + + for (describeType <- describeTypes) { + val group = this.group + describeType.mkString("") + // run two consumers in the group consuming from a two-partition topic + addConsumerGroupExecutor(2, topic2, group = group) + val cgcArgs = Array("--bootstrap-server", brokerList, "--describe", "--group", group) ++ describeType + val service = getConsumerGroupService(cgcArgs) + + TestUtils.waitUntilTrue(() => { + val (output, error) = TestUtils.grabConsoleOutputAndError(service.describeGroups()) + val expectedNumRows = if (describeTypeState.contains(describeType)) 2 else 3 + error.isEmpty && output.trim.split("\n").size == expectedNumRows + }, s"Expected the correct number of data rows in describe group result with describe type '${describeType.mkString(" 
")}'") + } + } + + @Test + def testDescribeOffsetsWithMultiPartitionTopicAndMultipleConsumers(): Unit = { + TestUtils.createOffsetsTopic(zkClient, servers) + val topic2 = "foo2" + createTopic(topic2, 2, 1) + + // run two consumers in the group consuming from a two-partition topic + addConsumerGroupExecutor(numConsumers = 2, topic2) + + val cgcArgs = Array("--bootstrap-server", brokerList, "--describe", "--group", group) + val service = getConsumerGroupService(cgcArgs) + + TestUtils.waitUntilTrue(() => { + val (state, assignments) = service.collectGroupOffsets(group) + state.contains("Stable") && + assignments.isDefined && + assignments.get.count(_.group == group) == 2 && + assignments.get.count{ x => x.group == group && x.partition.isDefined} == 2 && + assignments.get.count{ x => x.group == group && x.partition.isEmpty} == 0 + }, "Expected two rows (one row per consumer) in describe group results.") + } + + @Test + def testDescribeMembersWithMultiPartitionTopicAndMultipleConsumers(): Unit = { + TestUtils.createOffsetsTopic(zkClient, servers) + val topic2 = "foo2" + createTopic(topic2, 2, 1) + + // run two consumers in the group consuming from a two-partition topic + addConsumerGroupExecutor(numConsumers = 2, topic2) + + val cgcArgs = Array("--bootstrap-server", brokerList, "--describe", "--group", group) + val service = getConsumerGroupService(cgcArgs) + + TestUtils.waitUntilTrue(() => { + val (state, assignments) = service.collectGroupMembers(group, false) + state.contains("Stable") && + assignments.isDefined && + assignments.get.count(_.group == group) == 2 && + assignments.get.count{ x => x.group == group && x.numPartitions == 1 } == 2 && + assignments.get.count{ x => x.group == group && x.numPartitions == 0 } == 0 + }, "Expected two rows (one row per consumer) in describe group members results.") + + val (state, assignments) = service.collectGroupMembers(group, true) + assertTrue(state.contains("Stable") && assignments.get.count(_.assignment.isEmpty) == 0, + "Expected additional columns in verbose version of describe members") + } + + @Test + def testDescribeStateWithMultiPartitionTopicAndMultipleConsumers(): Unit = { + TestUtils.createOffsetsTopic(zkClient, servers) + val topic2 = "foo2" + createTopic(topic2, 2, 1) + + // run two consumers in the group consuming from a two-partition topic + addConsumerGroupExecutor(numConsumers = 2, topic2) + + val cgcArgs = Array("--bootstrap-server", brokerList, "--describe", "--group", group) + val service = getConsumerGroupService(cgcArgs) + + TestUtils.waitUntilTrue(() => { + val state = service.collectGroupState(group) + state.state == "Stable" && state.group == group && state.numMembers == 2 + }, "Expected a stable group with two members in describe group state result.") + } + + @Test + def testDescribeSimpleConsumerGroup(): Unit = { + // Ensure that the offsets of consumers which don't use group management are still displayed + + TestUtils.createOffsetsTopic(zkClient, servers) + val topic2 = "foo2" + createTopic(topic2, 2, 1) + addSimpleGroupExecutor(Seq(new TopicPartition(topic2, 0), new TopicPartition(topic2, 1))) + + val cgcArgs = Array("--bootstrap-server", brokerList, "--describe", "--group", group) + val service = getConsumerGroupService(cgcArgs) + + TestUtils.waitUntilTrue(() => { + val (state, assignments) = service.collectGroupOffsets(group) + state.contains("Empty") && assignments.isDefined && assignments.get.count(_.group == group) == 2 + }, "Expected an 'Empty' group with two partition offset rows in describe group results.") + } + + @Test + 
def testDescribeGroupWithShortInitializationTimeout(): Unit = { + // Let creation of the offsets topic happen during group initialization to ensure that initialization doesn't + // complete before the timeout expires + + val describeType = describeTypes(Random.nextInt(describeTypes.length)) + val group = this.group + describeType.mkString("") + // run one consumer in the group consuming from a single-partition topic + addConsumerGroupExecutor(numConsumers = 1) + // set the group initialization timeout too low for the group to stabilize + val cgcArgs = Array("--bootstrap-server", brokerList, "--describe", "--timeout", "1", "--group", group) ++ describeType + val service = getConsumerGroupService(cgcArgs) + + val e = assertThrows(classOf[ExecutionException], () => TestUtils.grabConsoleOutputAndError(service.describeGroups())) + assertEquals(classOf[TimeoutException], e.getCause.getClass) + } + + @Test + def testDescribeGroupOffsetsWithShortInitializationTimeout(): Unit = { + // Let creation of the offsets topic happen during group initialization to ensure that initialization doesn't + // complete before the timeout expires + + // run one consumer in the group consuming from a single-partition topic + addConsumerGroupExecutor(numConsumers = 1) + + // set the group initialization timeout too low for the group to stabilize + val cgcArgs = Array("--bootstrap-server", brokerList, "--describe", "--group", group, "--timeout", "1") + val service = getConsumerGroupService(cgcArgs) + + val e = assertThrows(classOf[ExecutionException], () => service.collectGroupOffsets(group)) + assertEquals(classOf[TimeoutException], e.getCause.getClass) + } + + @Test + def testDescribeGroupMembersWithShortInitializationTimeout(): Unit = { + // Let creation of the offsets topic happen during group initialization to ensure that initialization doesn't + // complete before the timeout expires + + // run one consumer in the group consuming from a single-partition topic + addConsumerGroupExecutor(numConsumers = 1) + + // set the group initialization timeout too low for the group to stabilize + val cgcArgs = Array("--bootstrap-server", brokerList, "--describe", "--group", group, "--timeout", "1") + val service = getConsumerGroupService(cgcArgs) + + var e = assertThrows(classOf[ExecutionException], () => service.collectGroupMembers(group, false)) + assertEquals(classOf[TimeoutException], e.getCause.getClass) + e = assertThrows(classOf[ExecutionException], () => service.collectGroupMembers(group, true)) + assertEquals(classOf[TimeoutException], e.getCause.getClass) + } + + @Test + def testDescribeGroupStateWithShortInitializationTimeout(): Unit = { + // Let creation of the offsets topic happen during group initialization to ensure that initialization doesn't + // complete before the timeout expires + + // run one consumer in the group consuming from a single-partition topic + addConsumerGroupExecutor(numConsumers = 1) + + // set the group initialization timeout too low for the group to stabilize + val cgcArgs = Array("--bootstrap-server", brokerList, "--describe", "--group", group, "--timeout", "1") + val service = getConsumerGroupService(cgcArgs) + + val e = assertThrows(classOf[ExecutionException], () => service.collectGroupState(group)) + assertEquals(classOf[TimeoutException], e.getCause.getClass) + } + + @Test + def testDescribeWithUnrecognizedNewConsumerOption(): Unit = { + val cgcArgs = Array("--new-consumer", "--bootstrap-server", brokerList, "--describe", "--group", group) + 
assertThrows(classOf[joptsimple.OptionException], () => getConsumerGroupService(cgcArgs)) + } + + @Test + def testDescribeNonOffsetCommitGroup(): Unit = { + TestUtils.createOffsetsTopic(zkClient, servers) + + val customProps = new Properties + // create a consumer group that never commits offsets + customProps.setProperty(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG, "false") + // run one consumer in the group consuming from a single-partition topic + addConsumerGroupExecutor(numConsumers = 1, customPropsOpt = Some(customProps)) + + val cgcArgs = Array("--bootstrap-server", brokerList, "--describe", "--group", group) + val service = getConsumerGroupService(cgcArgs) + + TestUtils.waitUntilTrue(() => { + val (state, assignments) = service.collectGroupOffsets(group) + state.contains("Stable") && + assignments.isDefined && + assignments.get.count(_.group == group) == 1 && + assignments.get.filter(_.group == group).head.consumerId.exists(_.trim != ConsumerGroupCommand.MISSING_COLUMN_VALUE) && + assignments.get.filter(_.group == group).head.clientId.exists(_.trim != ConsumerGroupCommand.MISSING_COLUMN_VALUE) && + assignments.get.filter(_.group == group).head.host.exists(_.trim != ConsumerGroupCommand.MISSING_COLUMN_VALUE) + }, s"Expected a 'Stable' group status, rows and valid values for consumer id / client id / host columns in describe results for non-offset-committing group $group.") + } + +} + diff --git a/core/src/test/scala/unit/kafka/admin/FeatureCommandTest.scala b/core/src/test/scala/unit/kafka/admin/FeatureCommandTest.scala new file mode 100644 index 0000000..3179957 --- /dev/null +++ b/core/src/test/scala/unit/kafka/admin/FeatureCommandTest.scala @@ -0,0 +1,241 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package kafka.admin + +import kafka.api.KAFKA_2_7_IV0 +import kafka.server.{BaseRequestTest, KafkaConfig, KafkaServer} +import kafka.utils.TestUtils +import kafka.utils.TestUtils.waitUntilTrue +import org.apache.kafka.common.feature.{Features, SupportedVersionRange} +import org.apache.kafka.common.utils.Utils + +import java.util.Properties + +import org.junit.jupiter.api.Assertions.{assertEquals, assertTrue, assertThrows} +import org.junit.jupiter.api.Test + +class FeatureCommandTest extends BaseRequestTest { + override def brokerCount: Int = 3 + + override def brokerPropertyOverrides(props: Properties): Unit = { + props.put(KafkaConfig.InterBrokerProtocolVersionProp, KAFKA_2_7_IV0.toString) + } + + private val defaultSupportedFeatures: Features[SupportedVersionRange] = + Features.supportedFeatures(Utils.mkMap(Utils.mkEntry("feature_1", new SupportedVersionRange(1, 3)), + Utils.mkEntry("feature_2", new SupportedVersionRange(1, 5)))) + + private def updateSupportedFeatures(features: Features[SupportedVersionRange], + targetServers: Set[KafkaServer]): Unit = { + targetServers.foreach(s => { + s.brokerFeatures.setSupportedFeatures(features) + s.zkClient.updateBrokerInfo(s.createBrokerInfo) + }) + + // Wait until updates to all BrokerZNode supported features propagate to the controller. + val brokerIds = targetServers.map(s => s.config.brokerId) + waitUntilTrue( + () => servers.exists(s => { + if (s.kafkaController.isActive) { + s.kafkaController.controllerContext.liveOrShuttingDownBrokers + .filter(b => brokerIds.contains(b.id)) + .forall(b => { + b.features.equals(features) + }) + } else { + false + } + }), + "Controller did not get broker updates") + } + + private def updateSupportedFeaturesInAllBrokers(features: Features[SupportedVersionRange]): Unit = { + updateSupportedFeatures(features, Set[KafkaServer]() ++ servers) + } + + /** + * Tests if the FeatureApis#describeFeatures API works as expected when describing features before and + * after upgrading features. + */ + @Test + def testDescribeFeaturesSuccess(): Unit = { + updateSupportedFeaturesInAllBrokers(defaultSupportedFeatures) + val featureApis = new FeatureApis(new FeatureCommandOptions(Array("--bootstrap-server", brokerList, "--describe"))) + featureApis.setSupportedFeatures(defaultSupportedFeatures) + try { + val initialDescribeOutput = TestUtils.grabConsoleOutput(featureApis.describeFeatures()) + val expectedInitialDescribeOutput = + "Feature: feature_1\tSupportedMinVersion: 1\tSupportedMaxVersion: 3\tFinalizedMinVersionLevel: -\tFinalizedMaxVersionLevel: -\tEpoch: 0\n" + + "Feature: feature_2\tSupportedMinVersion: 1\tSupportedMaxVersion: 5\tFinalizedMinVersionLevel: -\tFinalizedMaxVersionLevel: -\tEpoch: 0\n" + assertEquals(expectedInitialDescribeOutput, initialDescribeOutput) + featureApis.upgradeAllFeatures() + val finalDescribeOutput = TestUtils.grabConsoleOutput(featureApis.describeFeatures()) + val expectedFinalDescribeOutput = + "Feature: feature_1\tSupportedMinVersion: 1\tSupportedMaxVersion: 3\tFinalizedMinVersionLevel: 1\tFinalizedMaxVersionLevel: 3\tEpoch: 1\n" + + "Feature: feature_2\tSupportedMinVersion: 1\tSupportedMaxVersion: 5\tFinalizedMinVersionLevel: 1\tFinalizedMaxVersionLevel: 5\tEpoch: 1\n" + assertEquals(expectedFinalDescribeOutput, finalDescribeOutput) + } finally { + featureApis.close() + } + } + + /** + * Tests if the FeatureApis#upgradeAllFeatures API works as expected during a success case. 
+ */ + @Test + def testUpgradeAllFeaturesSuccess(): Unit = { + val upgradeOpts = new FeatureCommandOptions(Array("--bootstrap-server", brokerList, "--upgrade-all")) + val featureApis = new FeatureApis(upgradeOpts) + try { + // Step (1): + // - Update the supported features across all brokers. + // - Upgrade non-existing feature_1 to maxVersionLevel: 2. + // - Verify results. + val initialSupportedFeatures = Features.supportedFeatures(Utils.mkMap(Utils.mkEntry("feature_1", new SupportedVersionRange(1, 2)))) + updateSupportedFeaturesInAllBrokers(initialSupportedFeatures) + featureApis.setSupportedFeatures(initialSupportedFeatures) + var output = TestUtils.grabConsoleOutput(featureApis.upgradeAllFeatures()) + var expected = + " [Add]\tFeature: feature_1\tExistingFinalizedMaxVersion: -\tNewFinalizedMaxVersion: 2\tResult: OK\n" + assertEquals(expected, output) + + // Step (2): + // - Update the supported features across all brokers. + // - Upgrade existing feature_1 to maxVersionLevel: 3. + // - Upgrade non-existing feature_2 to maxVersionLevel: 5. + // - Verify results. + updateSupportedFeaturesInAllBrokers(defaultSupportedFeatures) + featureApis.setSupportedFeatures(defaultSupportedFeatures) + output = TestUtils.grabConsoleOutput(featureApis.upgradeAllFeatures()) + expected = + " [Upgrade]\tFeature: feature_1\tExistingFinalizedMaxVersion: 2\tNewFinalizedMaxVersion: 3\tResult: OK\n" + + " [Add]\tFeature: feature_2\tExistingFinalizedMaxVersion: -\tNewFinalizedMaxVersion: 5\tResult: OK\n" + assertEquals(expected, output) + + // Step (3): + // - Perform an upgrade of all features again. + // - Since supported features have not changed, expect that the above action does not yield + // any results. + output = TestUtils.grabConsoleOutput(featureApis.upgradeAllFeatures()) + assertTrue(output.isEmpty) + featureApis.setOptions(upgradeOpts) + output = TestUtils.grabConsoleOutput(featureApis.upgradeAllFeatures()) + assertTrue(output.isEmpty) + } finally { + featureApis.close() + } + } + + /** + * Tests if the FeatureApis#downgradeAllFeatures API works as expected during a success case. + */ + @Test + def testDowngradeFeaturesSuccess(): Unit = { + val downgradeOpts = new FeatureCommandOptions(Array("--bootstrap-server", brokerList, "--downgrade-all")) + val upgradeOpts = new FeatureCommandOptions(Array("--bootstrap-server", brokerList, "--upgrade-all")) + val featureApis = new FeatureApis(upgradeOpts) + try { + // Step (1): + // - Update the supported features across all brokers. + // - Upgrade non-existing feature_1 to maxVersionLevel: 3. + // - Upgrade non-existing feature_2 to maxVersionLevel: 5. + updateSupportedFeaturesInAllBrokers(defaultSupportedFeatures) + featureApis.setSupportedFeatures(defaultSupportedFeatures) + featureApis.upgradeAllFeatures() + + // Step (2): + // - Downgrade existing feature_1 to maxVersionLevel: 2. + // - Delete feature_2 since it is no longer supported by the FeatureApis object. + // - Verify results. 
+ val downgradedFeatures = Features.supportedFeatures(Utils.mkMap(Utils.mkEntry("feature_1", new SupportedVersionRange(1, 2)))) + featureApis.setSupportedFeatures(downgradedFeatures) + featureApis.setOptions(downgradeOpts) + var output = TestUtils.grabConsoleOutput(featureApis.downgradeAllFeatures()) + var expected = + "[Downgrade]\tFeature: feature_1\tExistingFinalizedMaxVersion: 3\tNewFinalizedMaxVersion: 2\tResult: OK\n" + + " [Delete]\tFeature: feature_2\tExistingFinalizedMaxVersion: 5\tNewFinalizedMaxVersion: -\tResult: OK\n" + assertEquals(expected, output) + + // Step (3): + // - Perform a downgrade of all features again. + // - Since supported features have not changed, expect that the above action does not yield + // any results. + updateSupportedFeaturesInAllBrokers(downgradedFeatures) + output = TestUtils.grabConsoleOutput(featureApis.downgradeAllFeatures()) + assertTrue(output.isEmpty) + + // Step (4): + // - Delete feature_1 since it is no longer supported by the FeatureApis object. + // - Verify results. + featureApis.setSupportedFeatures(Features.emptySupportedFeatures()) + output = TestUtils.grabConsoleOutput(featureApis.downgradeAllFeatures()) + expected = + " [Delete]\tFeature: feature_1\tExistingFinalizedMaxVersion: 2\tNewFinalizedMaxVersion: -\tResult: OK\n" + assertEquals(expected, output) + } finally { + featureApis.close() + } + } + + /** + * Tests if the FeatureApis#upgradeAllFeatures API works as expected during a partial failure case. + */ + @Test + def testUpgradeFeaturesFailure(): Unit = { + val upgradeOpts = new FeatureCommandOptions(Array("--bootstrap-server", brokerList, "--upgrade-all")) + val featureApis = new FeatureApis(upgradeOpts) + try { + // Step (1): Update the supported features across all brokers. + updateSupportedFeaturesInAllBrokers(defaultSupportedFeatures) + + // Step (2): + // - Intentionally setup the FeatureApis object such that it contains incompatible target + // features (viz. feature_2 and feature_3). + // - Upgrade non-existing feature_1 to maxVersionLevel: 4. Expect the operation to fail with + // an incompatibility failure. + // - Upgrade non-existing feature_2 to maxVersionLevel: 5. Expect the operation to succeed. + // - Upgrade non-existing feature_3 to maxVersionLevel: 3. Expect the operation to fail + // since the feature is not supported. 
+ val targetFeaturesWithIncompatibilities = + Features.supportedFeatures( + Utils.mkMap(Utils.mkEntry("feature_1", new SupportedVersionRange(1, 4)), + Utils.mkEntry("feature_2", new SupportedVersionRange(1, 5)), + Utils.mkEntry("feature_3", new SupportedVersionRange(1, 3)))) + featureApis.setSupportedFeatures(targetFeaturesWithIncompatibilities) + val output = TestUtils.grabConsoleOutput({ + val exception = assertThrows(classOf[UpdateFeaturesException], () => featureApis.upgradeAllFeatures()) + assertEquals("2 feature updates failed!", exception.getMessage) + }) + val expected = + " [Add]\tFeature: feature_1\tExistingFinalizedMaxVersion: -" + + "\tNewFinalizedMaxVersion: 4\tResult: FAILED due to" + + " org.apache.kafka.common.errors.InvalidRequestException: Could not apply finalized" + + " feature update because brokers were found to have incompatible versions for the" + + " feature.\n" + + " [Add]\tFeature: feature_2\tExistingFinalizedMaxVersion: -" + + "\tNewFinalizedMaxVersion: 5\tResult: OK\n" + + " [Add]\tFeature: feature_3\tExistingFinalizedMaxVersion: -" + + "\tNewFinalizedMaxVersion: 3\tResult: FAILED due to" + + " org.apache.kafka.common.errors.InvalidRequestException: Could not apply finalized" + + " feature update because the provided feature is not supported.\n" + assertEquals(expected, output) + } finally { + featureApis.close() + } + } +} diff --git a/core/src/test/scala/unit/kafka/admin/LeaderElectionCommandErrorTest.scala b/core/src/test/scala/unit/kafka/admin/LeaderElectionCommandErrorTest.scala new file mode 100644 index 0000000..eaef936 --- /dev/null +++ b/core/src/test/scala/unit/kafka/admin/LeaderElectionCommandErrorTest.scala @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package kafka.admin + +import kafka.common.AdminCommandFailedException +import org.apache.kafka.common.errors.TimeoutException +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.Test + +import scala.concurrent.duration._ + +/** + * For some error cases, we can save a little build time by avoiding the overhead for + * cluster creation and cleanup because the command is expected to fail immediately. 
+ */ +class LeaderElectionCommandErrorTest { + + @Test + def testTopicWithoutPartition(): Unit = { + val e = assertThrows(classOf[Throwable], () => LeaderElectionCommand.main( + Array( + "--bootstrap-server", "nohost:9092", + "--election-type", "unclean", + "--topic", "some-topic" + ) + )) + assertTrue(e.getMessage.startsWith("Missing required option(s)")) + assertTrue(e.getMessage.contains(" partition")) + } + + @Test + def testPartitionWithoutTopic(): Unit = { + val e = assertThrows(classOf[Throwable], () => LeaderElectionCommand.main( + Array( + "--bootstrap-server", "nohost:9092", + "--election-type", "unclean", + "--all-topic-partitions", + "--partition", "0" + ) + )) + assertEquals("Option partition is only allowed if topic is used", e.getMessage) + } + + @Test + def testMissingElectionType(): Unit = { + val e = assertThrows(classOf[Throwable], () => LeaderElectionCommand.main( + Array( + "--bootstrap-server", "nohost:9092", + "--topic", "some-topic", + "--partition", "0" + ) + )) + assertTrue(e.getMessage.startsWith("Missing required option(s)")) + assertTrue(e.getMessage.contains(" election-type")) + } + + @Test + def testMissingTopicPartitionSelection(): Unit = { + val e = assertThrows(classOf[Throwable], () => LeaderElectionCommand.main( + Array( + "--bootstrap-server", "nohost:9092", + "--election-type", "preferred" + ) + )) + assertTrue(e.getMessage.startsWith("One and only one of the following options is required: ")) + assertTrue(e.getMessage.contains(" all-topic-partitions")) + assertTrue(e.getMessage.contains(" topic")) + assertTrue(e.getMessage.contains(" path-to-json-file")) + } + + @Test + def testInvalidBroker(): Unit = { + val e = assertThrows(classOf[AdminCommandFailedException], () => LeaderElectionCommand.run( + Array( + "--bootstrap-server", "example.com:1234", + "--election-type", "unclean", + "--all-topic-partitions" + ), + 1.seconds + )) + assertTrue(e.getCause.isInstanceOf[TimeoutException]) + } +} diff --git a/core/src/test/scala/unit/kafka/admin/LeaderElectionCommandTest.scala b/core/src/test/scala/unit/kafka/admin/LeaderElectionCommandTest.scala new file mode 100644 index 0000000..b942f6f --- /dev/null +++ b/core/src/test/scala/unit/kafka/admin/LeaderElectionCommandTest.scala @@ -0,0 +1,274 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package kafka.admin + +import java.io.File +import java.nio.charset.StandardCharsets +import java.nio.file.{Files, Path} + +import kafka.common.AdminCommandFailedException +import kafka.server.IntegrationTestUtils.createTopic +import kafka.server.{KafkaConfig, KafkaServer} +import kafka.test.annotation.{ClusterTest, ClusterTestDefaults, Type} +import kafka.test.junit.ClusterTestExtensions +import kafka.test.{ClusterConfig, ClusterInstance} +import kafka.utils.TestUtils +import org.apache.kafka.clients.admin.AdminClientConfig +import org.apache.kafka.common.TopicPartition +import org.apache.kafka.common.errors.UnknownTopicOrPartitionException +import org.apache.kafka.common.network.ListenerName +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.extension.ExtendWith +import org.junit.jupiter.api.{BeforeEach, Tag} + +@ExtendWith(value = Array(classOf[ClusterTestExtensions])) +@ClusterTestDefaults(clusterType = Type.BOTH, brokers = 3) +@Tag("integration") +final class LeaderElectionCommandTest(cluster: ClusterInstance) { + import LeaderElectionCommandTest._ + + val broker1 = 0 + val broker2 = 1 + val broker3 = 2 + + @BeforeEach + def setup(clusterConfig: ClusterConfig): Unit = { + TestUtils.verifyNoUnexpectedThreads("@BeforeEach") + clusterConfig.serverProperties().put(KafkaConfig.AutoLeaderRebalanceEnableProp, "false") + clusterConfig.serverProperties().put(KafkaConfig.ControlledShutdownEnableProp, "true") + clusterConfig.serverProperties().put(KafkaConfig.ControlledShutdownMaxRetriesProp, "1") + clusterConfig.serverProperties().put(KafkaConfig.ControlledShutdownRetryBackoffMsProp, "1000") + clusterConfig.serverProperties().put(KafkaConfig.OffsetsTopicReplicationFactorProp, "2") + } + + @ClusterTest + def testAllTopicPartition(): Unit = { + val client = cluster.createAdminClient() + val topic = "unclean-topic" + val partition = 0 + val assignment = Seq(broker2, broker3) + + cluster.waitForReadyBrokers() + createTopic(client, topic, Map(partition -> assignment)) + + val topicPartition = new TopicPartition(topic, partition) + + TestUtils.assertLeader(client, topicPartition, broker2) + cluster.shutdownBroker(broker3) + TestUtils.waitForBrokersOutOfIsr(client, Set(topicPartition), Set(broker3)) + cluster.shutdownBroker(broker2) + TestUtils.assertNoLeader(client, topicPartition) + cluster.startBroker(broker3) + TestUtils.waitForOnlineBroker(client, broker3) + + LeaderElectionCommand.main( + Array( + "--bootstrap-server", cluster.bootstrapServers(), + "--election-type", "unclean", + "--all-topic-partitions" + ) + ) + + TestUtils.assertLeader(client, topicPartition, broker3) + } + + @ClusterTest + def testTopicPartition(): Unit = { + val client = cluster.createAdminClient() + val topic = "unclean-topic" + val partition = 0 + val assignment = Seq(broker2, broker3) + + cluster.waitForReadyBrokers() + createTopic(client, topic, Map(partition -> assignment)) + + val topicPartition = new TopicPartition(topic, partition) + + TestUtils.assertLeader(client, topicPartition, broker2) + + cluster.shutdownBroker(broker3) + TestUtils.waitForBrokersOutOfIsr(client, Set(topicPartition), Set(broker3)) + cluster.shutdownBroker(broker2) + TestUtils.assertNoLeader(client, topicPartition) + cluster.startBroker(broker3) + TestUtils.waitForOnlineBroker(client, broker3) + + LeaderElectionCommand.main( + Array( + "--bootstrap-server", cluster.bootstrapServers(), + "--election-type", "unclean", + "--topic", topic, + "--partition", partition.toString + ) + ) + + TestUtils.assertLeader(client, 
topicPartition, broker3) + } + + @ClusterTest + def testPathToJsonFile(): Unit = { + val client = cluster.createAdminClient() + val topic = "unclean-topic" + val partition = 0 + val assignment = Seq(broker2, broker3) + + cluster.waitForReadyBrokers() + createTopic(client, topic, Map(partition -> assignment)) + + val topicPartition = new TopicPartition(topic, partition) + + TestUtils.assertLeader(client, topicPartition, broker2) + + cluster.shutdownBroker(broker3) + TestUtils.waitForBrokersOutOfIsr(client, Set(topicPartition), Set(broker3)) + cluster.shutdownBroker(broker2) + TestUtils.assertNoLeader(client, topicPartition) + cluster.startBroker(broker3) + TestUtils.waitForOnlineBroker(client, broker3) + + val topicPartitionPath = tempTopicPartitionFile(Set(topicPartition)) + + LeaderElectionCommand.main( + Array( + "--bootstrap-server", cluster.bootstrapServers(), + "--election-type", "unclean", + "--path-to-json-file", topicPartitionPath.toString + ) + ) + + TestUtils.assertLeader(client, topicPartition, broker3) + } + + @ClusterTest + def testPreferredReplicaElection(): Unit = { + val client = cluster.createAdminClient() + val topic = "preferred-topic" + val partition = 0 + val assignment = Seq(broker2, broker3) + + cluster.waitForReadyBrokers() + createTopic(client, topic, Map(partition -> assignment)) + + val topicPartition = new TopicPartition(topic, partition) + + TestUtils.assertLeader(client, topicPartition, broker2) + + cluster.shutdownBroker(broker2) + TestUtils.assertLeader(client, topicPartition, broker3) + cluster.startBroker(broker2) + TestUtils.waitForBrokersInIsr(client, topicPartition, Set(broker2)) + + LeaderElectionCommand.main( + Array( + "--bootstrap-server", cluster.bootstrapServers(), + "--election-type", "preferred", + "--all-topic-partitions" + ) + ) + + TestUtils.assertLeader(client, topicPartition, broker2) + } + + @ClusterTest + def testTopicDoesNotExist(): Unit = { + val e = assertThrows(classOf[AdminCommandFailedException], () => LeaderElectionCommand.main( + Array( + "--bootstrap-server", cluster.bootstrapServers(), + "--election-type", "preferred", + "--topic", "unknown-topic-name", + "--partition", "0" + ) + )) + assertTrue(e.getSuppressed()(0).isInstanceOf[UnknownTopicOrPartitionException]) + } + + @ClusterTest + def testElectionResultOutput(): Unit = { + val client = cluster.createAdminClient() + val topic = "non-preferred-topic" + val partition0 = 0 + val partition1 = 1 + val assignment0 = Seq(broker2, broker3) + val assignment1 = Seq(broker3, broker2) + + cluster.waitForReadyBrokers() + createTopic(client, topic, Map( + partition0 -> assignment0, + partition1 -> assignment1 + )) + + val topicPartition0 = new TopicPartition(topic, partition0) + val topicPartition1 = new TopicPartition(topic, partition1) + + TestUtils.assertLeader(client, topicPartition0, broker2) + TestUtils.assertLeader(client, topicPartition1, broker3) + + cluster.shutdownBroker(broker2) + TestUtils.assertLeader(client, topicPartition0, broker3) + cluster.startBroker(broker2) + TestUtils.waitForBrokersInIsr(client, topicPartition0, Set(broker2)) + TestUtils.waitForBrokersInIsr(client, topicPartition1, Set(broker2)) + + val topicPartitionPath = tempTopicPartitionFile(Set(topicPartition0, topicPartition1)) + val output = TestUtils.grabConsoleOutput( + LeaderElectionCommand.main( + Array( + "--bootstrap-server", cluster.bootstrapServers(), + "--election-type", "preferred", + "--path-to-json-file", topicPartitionPath.toString + ) + ) + ) + + val electionResultOutputIter = 
output.split("\n").iterator + + assertTrue(electionResultOutputIter.hasNext) + val firstLine = electionResultOutputIter.next() + assertTrue(firstLine.contains(s"Successfully completed leader election (PREFERRED) for partitions $topicPartition0"), + s"Unexpected output: $firstLine") + + assertTrue(electionResultOutputIter.hasNext) + val secondLine = electionResultOutputIter.next() + assertTrue(secondLine.contains(s"Valid replica already elected for partitions $topicPartition1"), + s"Unexpected output: $secondLine") + } +} + +object LeaderElectionCommandTest { + def createConfig(servers: Seq[KafkaServer]): Map[String, Object] = { + Map( + AdminClientConfig.BOOTSTRAP_SERVERS_CONFIG -> bootstrapServers(servers), + AdminClientConfig.DEFAULT_API_TIMEOUT_MS_CONFIG -> "20000", + AdminClientConfig.REQUEST_TIMEOUT_MS_CONFIG -> "10000" + ) + } + + def bootstrapServers(servers: Seq[KafkaServer]): String = { + TestUtils.bootstrapServers(servers, new ListenerName("PLAINTEXT")) + } + + def tempTopicPartitionFile(partitions: Set[TopicPartition]): Path = { + val file = File.createTempFile("leader-election-command", ".json") + file.deleteOnExit() + + val jsonString = TestUtils.stringifyTopicPartitions(partitions) + + Files.write(file.toPath, jsonString.getBytes(StandardCharsets.UTF_8)) + + file.toPath + } +} diff --git a/core/src/test/scala/unit/kafka/admin/ListConsumerGroupTest.scala b/core/src/test/scala/unit/kafka/admin/ListConsumerGroupTest.scala new file mode 100644 index 0000000..f7f2d7e --- /dev/null +++ b/core/src/test/scala/unit/kafka/admin/ListConsumerGroupTest.scala @@ -0,0 +1,127 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package kafka.admin + +import joptsimple.OptionException +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.Test +import kafka.utils.TestUtils +import org.apache.kafka.common.ConsumerGroupState +import org.apache.kafka.clients.admin.ConsumerGroupListing +import java.util.Optional + +class ListConsumerGroupTest extends ConsumerGroupCommandTest { + + @Test + def testListConsumerGroups(): Unit = { + val simpleGroup = "simple-group" + addSimpleGroupExecutor(group = simpleGroup) + addConsumerGroupExecutor(numConsumers = 1) + + val cgcArgs = Array("--bootstrap-server", brokerList, "--list") + val service = getConsumerGroupService(cgcArgs) + + val expectedGroups = Set(group, simpleGroup) + var foundGroups = Set.empty[String] + TestUtils.waitUntilTrue(() => { + foundGroups = service.listConsumerGroups().toSet + expectedGroups == foundGroups + }, s"Expected --list to show groups $expectedGroups, but found $foundGroups.") + } + + @Test + def testListWithUnrecognizedNewConsumerOption(): Unit = { + val cgcArgs = Array("--new-consumer", "--bootstrap-server", brokerList, "--list") + assertThrows(classOf[OptionException], () => getConsumerGroupService(cgcArgs)) + } + + @Test + def testListConsumerGroupsWithStates(): Unit = { + val simpleGroup = "simple-group" + addSimpleGroupExecutor(group = simpleGroup) + addConsumerGroupExecutor(numConsumers = 1) + + val cgcArgs = Array("--bootstrap-server", brokerList, "--list", "--state") + val service = getConsumerGroupService(cgcArgs) + + val expectedListing = Set( + new ConsumerGroupListing(simpleGroup, true, Optional.of(ConsumerGroupState.EMPTY)), + new ConsumerGroupListing(group, false, Optional.of(ConsumerGroupState.STABLE))) + + var foundListing = Set.empty[ConsumerGroupListing] + TestUtils.waitUntilTrue(() => { + foundListing = service.listConsumerGroupsWithState(ConsumerGroupState.values.toSet).toSet + expectedListing == foundListing + }, s"Expected to show groups $expectedListing, but found $foundListing") + + val expectedListingStable = Set( + new ConsumerGroupListing(group, false, Optional.of(ConsumerGroupState.STABLE))) + + foundListing = Set.empty[ConsumerGroupListing] + TestUtils.waitUntilTrue(() => { + foundListing = service.listConsumerGroupsWithState(Set(ConsumerGroupState.STABLE)).toSet + expectedListingStable == foundListing + }, s"Expected to show groups $expectedListingStable, but found $foundListing") + } + + @Test + def testConsumerGroupStatesFromString(): Unit = { + var result = ConsumerGroupCommand.consumerGroupStatesFromString("Stable") + assertEquals(Set(ConsumerGroupState.STABLE), result) + + result = ConsumerGroupCommand.consumerGroupStatesFromString("Stable, PreparingRebalance") + assertEquals(Set(ConsumerGroupState.STABLE, ConsumerGroupState.PREPARING_REBALANCE), result) + + result = ConsumerGroupCommand.consumerGroupStatesFromString("Dead,CompletingRebalance,") + assertEquals(Set(ConsumerGroupState.DEAD, ConsumerGroupState.COMPLETING_REBALANCE), result) + + assertThrows(classOf[IllegalArgumentException], () => ConsumerGroupCommand.consumerGroupStatesFromString("bad, wrong")) + + assertThrows(classOf[IllegalArgumentException], () => ConsumerGroupCommand.consumerGroupStatesFromString("stable")) + + assertThrows(classOf[IllegalArgumentException], () => ConsumerGroupCommand.consumerGroupStatesFromString(" bad, Stable")) + + assertThrows(classOf[IllegalArgumentException], () => ConsumerGroupCommand.consumerGroupStatesFromString(" , ,")) + } + + @Test + def testListGroupCommand(): Unit = { + val simpleGroup = 
"simple-group" + addSimpleGroupExecutor(group = simpleGroup) + addConsumerGroupExecutor(numConsumers = 1) + var out = "" + + var cgcArgs = Array("--bootstrap-server", brokerList, "--list") + TestUtils.waitUntilTrue(() => { + out = TestUtils.grabConsoleOutput(ConsumerGroupCommand.main(cgcArgs)) + !out.contains("STATE") && out.contains(simpleGroup) && out.contains(group) + }, s"Expected to find $simpleGroup, $group and no header, but found $out") + + cgcArgs = Array("--bootstrap-server", brokerList, "--list", "--state") + TestUtils.waitUntilTrue(() => { + out = TestUtils.grabConsoleOutput(ConsumerGroupCommand.main(cgcArgs)) + out.contains("STATE") && out.contains(simpleGroup) && out.contains(group) + }, s"Expected to find $simpleGroup, $group and the header, but found $out") + + cgcArgs = Array("--bootstrap-server", brokerList, "--list", "--state", "Stable") + TestUtils.waitUntilTrue(() => { + out = TestUtils.grabConsoleOutput(ConsumerGroupCommand.main(cgcArgs)) + out.contains("STATE") && out.contains(group) && out.contains("Stable") + }, s"Expected to find $group in state Stable and the header, but found $out") + } + +} diff --git a/core/src/test/scala/unit/kafka/admin/LogDirsCommandTest.scala b/core/src/test/scala/unit/kafka/admin/LogDirsCommandTest.scala new file mode 100644 index 0000000..397d6d5 --- /dev/null +++ b/core/src/test/scala/unit/kafka/admin/LogDirsCommandTest.scala @@ -0,0 +1,76 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package kafka.admin + +import java.io.{ByteArrayOutputStream, PrintStream} +import java.nio.charset.StandardCharsets + +import kafka.integration.KafkaServerTestHarness +import kafka.server.KafkaConfig +import kafka.utils.TestUtils +import org.junit.jupiter.api.Assertions.assertTrue +import org.junit.jupiter.api.Test + +import scala.collection.Seq + +class LogDirsCommandTest extends KafkaServerTestHarness { + + def generateConfigs: Seq[KafkaConfig] = { + TestUtils.createBrokerConfigs(1, zkConnect) + .map(KafkaConfig.fromProps) + } + + @Test + def checkLogDirsCommandOutput(): Unit = { + val byteArrayOutputStream = new ByteArrayOutputStream + val printStream = new PrintStream(byteArrayOutputStream, false, StandardCharsets.UTF_8.name()) + // describe an existing broker given in --broker-list + LogDirsCommand.describe(Array("--bootstrap-server", brokerList, "--broker-list", "0", "--describe"), printStream) + val existingBrokersContent = new String(byteArrayOutputStream.toByteArray, StandardCharsets.UTF_8) + val existingBrokersLineIter = existingBrokersContent.split("\n").iterator + + assertTrue(existingBrokersLineIter.hasNext) + assertTrue(existingBrokersLineIter.next().contains(s"Querying brokers for log directories information")) + + // --broker-list contains brokers that do not exist in the cluster + byteArrayOutputStream.reset() + LogDirsCommand.describe(Array("--bootstrap-server", brokerList, "--broker-list", "0,1,2", "--describe"), printStream) + val nonExistingBrokersContent = new String(byteArrayOutputStream.toByteArray, StandardCharsets.UTF_8) + val nonExistingBrokersLineIter = nonExistingBrokersContent.split("\n").iterator + + assertTrue(nonExistingBrokersLineIter.hasNext) + assertTrue(nonExistingBrokersLineIter.next().contains(s"ERROR: The given brokers do not exist from --broker-list: 1,2. Current existent brokers: 0")) + + // --broker-list contains duplicate broker ids + byteArrayOutputStream.reset() + LogDirsCommand.describe(Array("--bootstrap-server", brokerList, "--broker-list", "0,0,1,2,2", "--describe"), printStream) + val duplicateBrokersContent = new String(byteArrayOutputStream.toByteArray, StandardCharsets.UTF_8) + val duplicateBrokersLineIter = duplicateBrokersContent.split("\n").iterator + + assertTrue(duplicateBrokersLineIter.hasNext) + assertTrue(duplicateBrokersLineIter.next().contains(s"ERROR: The given brokers do not exist from --broker-list: 1,2. Current existent brokers: 0")) + + // describe all brokers in the cluster when --broker-list is omitted + byteArrayOutputStream.reset() + LogDirsCommand.describe(Array("--bootstrap-server", brokerList, "--describe"), printStream) + val allBrokersContent = new String(byteArrayOutputStream.toByteArray, StandardCharsets.UTF_8) + val allBrokersLineIter = allBrokersContent.split("\n").iterator + + assertTrue(allBrokersLineIter.hasNext) + assertTrue(allBrokersLineIter.next().contains(s"Querying brokers for log directories information")) + } +} diff --git a/core/src/test/scala/unit/kafka/admin/RackAwareTest.scala b/core/src/test/scala/unit/kafka/admin/RackAwareTest.scala new file mode 100644 index 0000000..6ce4f7b --- /dev/null +++ b/core/src/test/scala/unit/kafka/admin/RackAwareTest.scala @@ -0,0 +1,85 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package kafka.admin + +import scala.collection.{Map, Seq, mutable} +import org.junit.jupiter.api.Assertions._ + +trait RackAwareTest { + + def checkReplicaDistribution(assignment: Map[Int, Seq[Int]], + brokerRackMapping: Map[Int, String], + numBrokers: Int, + numPartitions: Int, + replicationFactor: Int, + verifyRackAware: Boolean = true, + verifyLeaderDistribution: Boolean = true, + verifyReplicasDistribution: Boolean = true): Unit = { + // always verify that no broker will be assigned for more than one replica + for ((_, brokerList) <- assignment) { + assertEquals(brokerList.toSet.size, brokerList.size, + "More than one replica is assigned to same broker for the same partition") + } + val distribution = getReplicaDistribution(assignment, brokerRackMapping) + + if (verifyRackAware) { + val partitionRackMap = distribution.partitionRacks + assertEquals(List.fill(numPartitions)(replicationFactor), partitionRackMap.values.toList.map(_.distinct.size), + "More than one replica of the same partition is assigned to the same rack") + } + + if (verifyLeaderDistribution) { + val leaderCount = distribution.brokerLeaderCount + val leaderCountPerBroker = numPartitions / numBrokers + assertEquals(List.fill(numBrokers)(leaderCountPerBroker), leaderCount.values.toList, + "Preferred leader count is not even for brokers") + } + + if (verifyReplicasDistribution) { + val replicasCount = distribution.brokerReplicasCount + val numReplicasPerBroker = numPartitions * replicationFactor / numBrokers + assertEquals(List.fill(numBrokers)(numReplicasPerBroker), replicasCount.values.toList, + "Replica count is not even for broker") + } + } + + def getReplicaDistribution(assignment: Map[Int, Seq[Int]], brokerRackMapping: Map[Int, String]): ReplicaDistributions = { + val leaderCount = mutable.Map[Int, Int]() + val partitionCount = mutable.Map[Int, Int]() + val partitionRackMap = mutable.Map[Int, List[String]]() + assignment.foreach { case (partitionId, replicaList) => + val leader = replicaList.head + leaderCount(leader) = leaderCount.getOrElse(leader, 0) + 1 + for (brokerId <- replicaList) { + partitionCount(brokerId) = partitionCount.getOrElse(brokerId, 0) + 1 + val rack = brokerRackMapping.getOrElse(brokerId, sys.error(s"No mapping found for $brokerId in `brokerRackMapping`")) + partitionRackMap(partitionId) = rack :: partitionRackMap.getOrElse(partitionId, List()) + } + } + ReplicaDistributions(partitionRackMap, leaderCount, partitionCount) + } + + def toBrokerMetadata(rackMap: Map[Int, String], brokersWithoutRack: Seq[Int] = Seq.empty): Seq[BrokerMetadata] = + rackMap.toSeq.map { case (brokerId, rack) => + BrokerMetadata(brokerId, Some(rack)) + } ++ brokersWithoutRack.map { brokerId => + BrokerMetadata(brokerId, None) + }.sortBy(_.id) + +} + +case class ReplicaDistributions(partitionRacks: Map[Int, Seq[String]], brokerLeaderCount: Map[Int, Int], brokerReplicasCount: Map[Int, Int]) diff --git a/core/src/test/scala/unit/kafka/admin/ReassignPartitionsCommandArgsTest.scala b/core/src/test/scala/unit/kafka/admin/ReassignPartitionsCommandArgsTest.scala new file mode 100644 index 0000000..13fc262 --- 
/dev/null +++ b/core/src/test/scala/unit/kafka/admin/ReassignPartitionsCommandArgsTest.scala @@ -0,0 +1,284 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package kafka.admin + +import kafka.utils.Exit +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.{AfterEach, BeforeEach, Test, Timeout} + +@Timeout(60) +class ReassignPartitionsCommandArgsTest { + + val missingBootstrapServerMsg = "Please specify --bootstrap-server" + + @BeforeEach + def setUp(): Unit = { + Exit.setExitProcedure((_, message) => throw new IllegalArgumentException(message.orNull)) + } + + @AfterEach + def tearDown(): Unit = { + Exit.resetExitProcedure() + } + + ///// Test valid argument parsing + @Test + def shouldCorrectlyParseValidMinimumGenerateOptions(): Unit = { + val args = Array( + "--bootstrap-server", "localhost:1234", + "--generate", + "--broker-list", "101,102", + "--topics-to-move-json-file", "myfile.json") + ReassignPartitionsCommand.validateAndParseArgs(args) + } + + @Test + def shouldCorrectlyParseValidMinimumExecuteOptions(): Unit = { + val args = Array( + "--bootstrap-server", "localhost:1234", + "--execute", + "--reassignment-json-file", "myfile.json") + ReassignPartitionsCommand.validateAndParseArgs(args) + } + + @Test + def shouldCorrectlyParseValidMinimumVerifyOptions(): Unit = { + val args = Array( + "--bootstrap-server", "localhost:1234", + "--verify", + "--reassignment-json-file", "myfile.json") + ReassignPartitionsCommand.validateAndParseArgs(args) + } + + @Test + def shouldAllowThrottleOptionOnExecute(): Unit = { + val args = Array( + "--bootstrap-server", "localhost:1234", + "--execute", + "--throttle", "100", + "--reassignment-json-file", "myfile.json") + ReassignPartitionsCommand.validateAndParseArgs(args) + } + + @Test + def shouldUseDefaultsIfEnabled(): Unit = { + val args = Array( + "--bootstrap-server", "localhost:1234", + "--execute", + "--reassignment-json-file", "myfile.json") + val opts = ReassignPartitionsCommand.validateAndParseArgs(args) + assertEquals(10000L, opts.options.valueOf(opts.timeoutOpt)) + assertEquals(-1L, opts.options.valueOf(opts.interBrokerThrottleOpt)) + } + + @Test + def testList(): Unit = { + val args = Array( + "--list", + "--bootstrap-server", "localhost:1234") + ReassignPartitionsCommand.validateAndParseArgs(args) + } + + @Test + def testCancelWithPreserveThrottlesOption(): Unit = { + val args = Array( + "--cancel", + "--bootstrap-server", "localhost:1234", + "--reassignment-json-file", "myfile.json", + "--preserve-throttles") + ReassignPartitionsCommand.validateAndParseArgs(args) + } + + ///// Test handling missing or invalid actions + @Test + def shouldFailIfNoArgs(): Unit = { + val args: Array[String]= Array() + shouldFailWith(ReassignPartitionsCommand.helpText, args) + } + + @Test 
+ def shouldFailIfBlankArg(): Unit = { + val args = Array(" ") + shouldFailWith("Command must include exactly one action", args) + } + + @Test + def shouldFailIfMultipleActions(): Unit = { + val args = Array( + "--bootstrap-server", "localhost:1234", + "--execute", + "--verify", + "--reassignment-json-file", "myfile.json" + ) + shouldFailWith("Command must include exactly one action", args) + } + + ///// Test --execute + @Test + def shouldNotAllowExecuteWithTopicsOption(): Unit = { + val args = Array( + "--bootstrap-server", "localhost:1234", + "--execute", + "--reassignment-json-file", "myfile.json", + "--topics-to-move-json-file", "myfile.json") + shouldFailWith("Option \"[topics-to-move-json-file]\" can't be used with action \"[execute]\"", args) + } + + @Test + def shouldNotAllowExecuteWithBrokerList(): Unit = { + val args = Array( + "--bootstrap-server", "localhost:1234", + "--execute", + "--reassignment-json-file", "myfile.json", + "--broker-list", "101,102" + ) + shouldFailWith("Option \"[broker-list]\" can't be used with action \"[execute]\"", args) + } + + @Test + def shouldNotAllowExecuteWithoutReassignmentOption(): Unit = { + val args = Array( + "--bootstrap-server", "localhost:1234", + "--execute") + shouldFailWith("Missing required argument \"[reassignment-json-file]\"", args) + } + + @Test + def testMissingBootstrapServerArgumentForExecute(): Unit = { + val args = Array( + "--execute") + shouldFailWith(missingBootstrapServerMsg, args) + } + + ///// Test --generate + @Test + def shouldNotAllowGenerateWithoutBrokersAndTopicsOptions(): Unit = { + val args = Array( + "--bootstrap-server", "localhost:1234", + "--generate") + shouldFailWith("Missing required argument \"[topics-to-move-json-file]\"", args) + } + + @Test + def shouldNotAllowGenerateWithoutBrokersOption(): Unit = { + val args = Array( + "--bootstrap-server", "localhost:1234", + "--topics-to-move-json-file", "myfile.json", + "--generate") + shouldFailWith("Missing required argument \"[broker-list]\"", args) + } + + @Test + def shouldNotAllowGenerateWithoutTopicsOption(): Unit = { + val args = Array( + "--bootstrap-server", "localhost:1234", + "--broker-list", "101,102", + "--generate") + shouldFailWith("Missing required argument \"[topics-to-move-json-file]\"", args) + } + + @Test + def shouldNotAllowGenerateWithThrottleOption(): Unit = { + val args = Array( + "--bootstrap-server", "localhost:1234", + "--generate", + "--broker-list", "101,102", + "--throttle", "100", + "--topics-to-move-json-file", "myfile.json") + shouldFailWith("Option \"[throttle]\" can't be used with action \"[generate]\"", args) + } + + @Test + def shouldNotAllowGenerateWithReassignmentOption(): Unit = { + val args = Array( + "--bootstrap-server", "localhost:1234", + "--generate", + "--broker-list", "101,102", + "--topics-to-move-json-file", "myfile.json", + "--reassignment-json-file", "myfile.json") + shouldFailWith("Option \"[reassignment-json-file]\" can't be used with action \"[generate]\"", args) + } + + @Test + def shouldPrintHelpTextIfHelpArg(): Unit = { + val args: Array[String] = Array("--help") + // Note: this is not actually a failure case; --help is routed through the same `printUsageAndDie` path that handles invalid arguments + shouldFailWith(ReassignPartitionsCommand.helpText, args) + } + + ///// Test --verify + @Test + def shouldNotAllowVerifyWithoutReassignmentOption(): Unit = { + val args = Array( + "--bootstrap-server", "localhost:1234", + "--verify") + shouldFailWith("Missing required argument \"[reassignment-json-file]\"", args) + } + + 
@Test + def shouldNotAllowBrokersListWithVerifyOption(): Unit = { + val args = Array( + "--bootstrap-server", "localhost:1234", + "--verify", + "--broker-list", "100,101", + "--reassignment-json-file", "myfile.json") + shouldFailWith("Option \"[broker-list]\" can't be used with action \"[verify]\"", args) + } + + @Test + def shouldNotAllowThrottleWithVerifyOption(): Unit = { + val args = Array( + "--bootstrap-server", "localhost:1234", + "--verify", + "--throttle", "100", + "--reassignment-json-file", "myfile.json") + shouldFailWith("Option \"[throttle]\" can't be used with action \"[verify]\"", args) + } + + @Test + def shouldNotAllowTopicsOptionWithVerify(): Unit = { + val args = Array( + "--bootstrap-server", "localhost:1234", + "--verify", + "--reassignment-json-file", "myfile.json", + "--topics-to-move-json-file", "myfile.json") + shouldFailWith("Option \"[topics-to-move-json-file]\" can't be used with action \"[verify]\"", args) + } + + def shouldFailWith(msg: String, args: Array[String]): Unit = { + val e = assertThrows(classOf[Exception], () => ReassignPartitionsCommand.validateAndParseArgs(args), + () => s"Should have failed with [$msg] but no failure occurred.") + assertTrue(e.getMessage.startsWith(msg), s"Expected exception with message:\n[$msg]\nbut was\n[${e.getMessage}]") + } + + ///// Test --cancel + @Test + def shouldNotAllowCancelWithoutBootstrapServerOption(): Unit = { + val args = Array( + "--cancel") + shouldFailWith(missingBootstrapServerMsg, args) + } + + @Test + def shouldNotAllowCancelWithoutReassignmentJsonFile(): Unit = { + val args = Array( + "--cancel", + "--bootstrap-server", "localhost:1234", + "--preserve-throttles") + shouldFailWith("Missing required argument \"[reassignment-json-file]\"", args) + } +} diff --git a/core/src/test/scala/unit/kafka/admin/ReassignPartitionsUnitTest.scala b/core/src/test/scala/unit/kafka/admin/ReassignPartitionsUnitTest.scala new file mode 100644 index 0000000..cbbebe7 --- /dev/null +++ b/core/src/test/scala/unit/kafka/admin/ReassignPartitionsUnitTest.scala @@ -0,0 +1,670 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package kafka.admin + +import java.util.concurrent.ExecutionException +import java.util.{Arrays, Collections} +import kafka.admin.ReassignPartitionsCommand._ +import kafka.common.AdminCommandFailedException +import kafka.utils.Exit +import org.apache.kafka.clients.admin.{Config, MockAdminClient, PartitionReassignment} +import org.apache.kafka.common.config.ConfigResource +import org.apache.kafka.common.errors.{InvalidReplicationFactorException, UnknownTopicOrPartitionException} +import org.apache.kafka.common.{Node, TopicPartition, TopicPartitionInfo, TopicPartitionReplica} +import org.junit.jupiter.api.Assertions.{assertEquals, assertFalse, assertThrows, assertTrue} +import org.junit.jupiter.api.{AfterEach, BeforeEach, Test, Timeout} + +import scala.collection.mutable +import scala.jdk.CollectionConverters._ + +@Timeout(60) +class ReassignPartitionsUnitTest { + + @BeforeEach + def setUp(): Unit = { + Exit.setExitProcedure((_, message) => throw new IllegalArgumentException(message.orNull)) + } + + @AfterEach + def tearDown(): Unit = { + Exit.resetExitProcedure() + } + + @Test + def testCompareTopicPartitions(): Unit = { + assertTrue(compareTopicPartitions(new TopicPartition("abc", 0), + new TopicPartition("abc", 1))) + assertFalse(compareTopicPartitions(new TopicPartition("def", 0), + new TopicPartition("abc", 1))) + } + + @Test + def testCompareTopicPartitionReplicas(): Unit = { + assertTrue(compareTopicPartitionReplicas(new TopicPartitionReplica("def", 0, 0), + new TopicPartitionReplica("abc", 0, 1))) + assertFalse(compareTopicPartitionReplicas(new TopicPartitionReplica("def", 0, 0), + new TopicPartitionReplica("cde", 0, 0))) + } + + @Test + def testPartitionReassignStatesToString(): Unit = { + assertEquals(Seq( + "Status of partition reassignment:", + "Reassignment of partition bar-0 is still in progress.", + "Reassignment of partition foo-0 is complete.", + "Reassignment of partition foo-1 is still in progress."). + mkString(System.lineSeparator()), + partitionReassignmentStatesToString(Map( + new TopicPartition("foo", 0) -> + PartitionReassignmentState(Seq(1, 2, 3), Seq(1, 2, 3), true), + new TopicPartition("foo", 1) -> + PartitionReassignmentState(Seq(1, 2, 3), Seq(1, 2, 4), false), + new TopicPartition("bar", 0) -> + PartitionReassignmentState(Seq(1, 2, 3), Seq(1, 2, 4), false), + ))) + } + + private def addTopics(adminClient: MockAdminClient): Unit = { + val b = adminClient.brokers() + adminClient.addTopic(false, "foo", Arrays.asList( + new TopicPartitionInfo(0, b.get(0), + Arrays.asList(b.get(0), b.get(1), b.get(2)), + Arrays.asList(b.get(0), b.get(1))), + new TopicPartitionInfo(1, b.get(1), + Arrays.asList(b.get(1), b.get(2), b.get(3)), + Arrays.asList(b.get(1), b.get(2), b.get(3))) + ), Collections.emptyMap()) + adminClient.addTopic(false, "bar", Arrays.asList( + new TopicPartitionInfo(0, b.get(2), + Arrays.asList(b.get(2), b.get(3), b.get(0)), + Arrays.asList(b.get(2), b.get(3), b.get(0))) + ), Collections.emptyMap()) + } + + @Test + def testFindPartitionReassignmentStates(): Unit = { + val adminClient = new MockAdminClient.Builder().numBrokers(4).build() + try { + addTopics(adminClient) + // Create a reassignment and test findPartitionReassignmentStates. 
+ val reassignmentResult: Map[TopicPartition, Class[_ <: Throwable]] = alterPartitionReassignments(adminClient, Map( + new TopicPartition("foo", 0) -> Seq(0,1,3), + new TopicPartition("quux", 0) -> Seq(1,2,3))).map { case (k, v) => k -> v.getClass }.toMap + assertEquals(Map(new TopicPartition("quux", 0) -> classOf[UnknownTopicOrPartitionException]), + reassignmentResult) + assertEquals((Map( + new TopicPartition("foo", 0) -> PartitionReassignmentState(Seq(0,1,2), Seq(0,1,3), false), + new TopicPartition("foo", 1) -> PartitionReassignmentState(Seq(1,2,3), Seq(1,2,3), true) + ), true), + findPartitionReassignmentStates(adminClient, Seq( + (new TopicPartition("foo", 0), Seq(0,1,3)), + (new TopicPartition("foo", 1), Seq(1,2,3)) + ))) + // Cancel the reassignment and test findPartitionReassignmentStates again. + val cancelResult: Map[TopicPartition, Class[_ <: Throwable]] = cancelPartitionReassignments(adminClient, + Set(new TopicPartition("foo", 0), new TopicPartition("quux", 2))).map { case (k, v) => + k -> v.getClass + }.toMap + assertEquals(Map(new TopicPartition("quux", 2) -> classOf[UnknownTopicOrPartitionException]), + cancelResult) + assertEquals((Map( + new TopicPartition("foo", 0) -> PartitionReassignmentState(Seq(0,1,2), Seq(0,1,3), true), + new TopicPartition("foo", 1) -> PartitionReassignmentState(Seq(1,2,3), Seq(1,2,3), true) + ), false), + findPartitionReassignmentStates(adminClient, Seq( + (new TopicPartition("foo", 0), Seq(0,1,3)), + (new TopicPartition("foo", 1), Seq(1,2,3)) + ))) + } finally { + adminClient.close() + } + } + + @Test + def testFindLogDirMoveStates(): Unit = { + val adminClient = new MockAdminClient.Builder(). + numBrokers(4). + brokerLogDirs(Arrays.asList( + Arrays.asList("/tmp/kafka-logs0", "/tmp/kafka-logs1"), + Arrays.asList("/tmp/kafka-logs0", "/tmp/kafka-logs1"), + Arrays.asList("/tmp/kafka-logs0", "/tmp/kafka-logs1"), + Arrays.asList("/tmp/kafka-logs0", null))). 
+ build(); + try { + addTopics(adminClient) + val b = adminClient.brokers() + adminClient.addTopic(false, "quux", Arrays.asList( + new TopicPartitionInfo(0, b.get(2), + Arrays.asList(b.get(1), b.get(2), b.get(3)), + Arrays.asList(b.get(1), b.get(2), b.get(3)))), + Collections.emptyMap()) + adminClient.alterReplicaLogDirs(Map( + new TopicPartitionReplica("foo", 0, 0) -> "/tmp/kafka-logs1", + new TopicPartitionReplica("quux", 0, 0) -> "/tmp/kafka-logs1" + ).asJava).all().get() + assertEquals(Map( + new TopicPartitionReplica("bar", 0, 0) -> new CompletedMoveState("/tmp/kafka-logs0"), + new TopicPartitionReplica("foo", 0, 0) -> new ActiveMoveState("/tmp/kafka-logs0", + "/tmp/kafka-logs1", "/tmp/kafka-logs1"), + new TopicPartitionReplica("foo", 1, 0) -> new CancelledMoveState("/tmp/kafka-logs0", + "/tmp/kafka-logs1"), + new TopicPartitionReplica("quux", 1, 0) -> new MissingLogDirMoveState("/tmp/kafka-logs1"), + new TopicPartitionReplica("quuz", 0, 0) -> new MissingReplicaMoveState("/tmp/kafka-logs0") + ), findLogDirMoveStates(adminClient, Map( + new TopicPartitionReplica("bar", 0, 0) -> "/tmp/kafka-logs0", + new TopicPartitionReplica("foo", 0, 0) -> "/tmp/kafka-logs1", + new TopicPartitionReplica("foo", 1, 0) -> "/tmp/kafka-logs1", + new TopicPartitionReplica("quux", 1, 0) -> "/tmp/kafka-logs1", + new TopicPartitionReplica("quuz", 0, 0) -> "/tmp/kafka-logs0" + ))) + } finally { + adminClient.close() + } + } + + @Test + def testReplicaMoveStatesToString(): Unit = { + assertEquals(Seq( + "Reassignment of replica bar-0-0 completed successfully.", + "Reassignment of replica foo-0-0 is still in progress.", + "Partition foo-1 on broker 0 is not being moved from log dir /tmp/kafka-logs0 to /tmp/kafka-logs1.", + "Partition quux-0 cannot be found in any live log directory on broker 0.", + "Partition quux-1 on broker 1 is being moved to log dir /tmp/kafka-logs2 instead of /tmp/kafka-logs1.", + "Partition quux-2 is not found in any live log dir on broker 1. " + + "There is likely an offline log directory on the broker.").mkString(System.lineSeparator()), + replicaMoveStatesToString(Map( + new TopicPartitionReplica("bar", 0, 0) -> CompletedMoveState("/tmp/kafka-logs0"), + new TopicPartitionReplica("foo", 0, 0) -> ActiveMoveState("/tmp/kafka-logs0", + "/tmp/kafka-logs1", "/tmp/kafka-logs1"), + new TopicPartitionReplica("foo", 1, 0) -> CancelledMoveState("/tmp/kafka-logs0", + "/tmp/kafka-logs1"), + new TopicPartitionReplica("quux", 0, 0) -> MissingReplicaMoveState("/tmp/kafka-logs1"), + new TopicPartitionReplica("quux", 1, 1) -> ActiveMoveState("/tmp/kafka-logs0", + "/tmp/kafka-logs1", "/tmp/kafka-logs2"), + new TopicPartitionReplica("quux", 2, 1) -> MissingLogDirMoveState("/tmp/kafka-logs1") + ))) + } + + @Test + def testGetReplicaAssignments(): Unit = { + val adminClient = new MockAdminClient.Builder().numBrokers(4).build() + try { + addTopics(adminClient) + assertEquals(Map( + new TopicPartition("foo", 0) -> Seq(0, 1, 2), + new TopicPartition("foo", 1) -> Seq(1, 2, 3), + ), + getReplicaAssignmentForTopics(adminClient, Seq("foo"))) + assertEquals(Map( + new TopicPartition("foo", 0) -> Seq(0, 1, 2), + new TopicPartition("bar", 0) -> Seq(2, 3, 0), + ), + getReplicaAssignmentForPartitions(adminClient, Set( + new TopicPartition("foo", 0), new TopicPartition("bar", 0)))) + } finally { + adminClient.close() + } + } + + @Test + def testGetBrokerRackInformation(): Unit = { + val adminClient = new MockAdminClient.Builder(). 
+ brokers(Arrays.asList(new Node(0, "locahost", 9092, "rack0"), + new Node(1, "locahost", 9093, "rack1"), + new Node(2, "locahost", 9094, null))). + build() + try { + assertEquals(Seq( + BrokerMetadata(0, Some("rack0")), + BrokerMetadata(1, Some("rack1")) + ), getBrokerMetadata(adminClient, Seq(0, 1), true)) + assertEquals(Seq( + BrokerMetadata(0, None), + BrokerMetadata(1, None) + ), getBrokerMetadata(adminClient, Seq(0, 1), false)) + assertStartsWith("Not all brokers have rack information", + assertThrows(classOf[AdminOperationException], + () => getBrokerMetadata(adminClient, Seq(1, 2), true)).getMessage) + assertEquals(Seq( + BrokerMetadata(1, None), + BrokerMetadata(2, None) + ), getBrokerMetadata(adminClient, Seq(1, 2), false)) + } finally { + adminClient.close() + } + } + + @Test + def testParseGenerateAssignmentArgs(): Unit = { + assertStartsWith("Broker list contains duplicate entries", + assertThrows(classOf[AdminCommandFailedException], () => parseGenerateAssignmentArgs( + """{"topics": [{"topic": "foo"}], "version":1}""", "1,1,2"), + () => "Expected to detect duplicate broker list entries").getMessage) + assertStartsWith("Broker list contains duplicate entries", + assertThrows(classOf[AdminCommandFailedException], () => parseGenerateAssignmentArgs( + """{"topics": [{"topic": "foo"}], "version":1}""", "5,2,3,4,5"), + () => "Expected to detect duplicate broker list entries").getMessage) + assertEquals((Seq(5,2,3,4),Seq("foo")), + parseGenerateAssignmentArgs("""{"topics": [{"topic": "foo"}], "version":1}""", + "5,2,3,4")) + assertStartsWith("List of topics to reassign contains duplicate entries", + assertThrows(classOf[AdminCommandFailedException], () => parseGenerateAssignmentArgs( + """{"topics": [{"topic": "foo"},{"topic": "foo"}], "version":1}""", "5,2,3,4"), + () => "Expected to detect duplicate topic entries").getMessage) + assertEquals((Seq(5,3,4),Seq("foo","bar")), + parseGenerateAssignmentArgs( + """{"topics": [{"topic": "foo"},{"topic": "bar"}], "version":1}""", + "5,3,4")) + } + + @Test + def testGenerateAssignmentFailsWithoutEnoughReplicas(): Unit = { + val adminClient = new MockAdminClient.Builder().numBrokers(4).build() + try { + addTopics(adminClient) + assertStartsWith("Replication factor: 3 larger than available brokers: 2", + assertThrows(classOf[InvalidReplicationFactorException], + () => generateAssignment(adminClient, """{"topics":[{"topic":"foo"},{"topic":"bar"}]}""", "0,1", false), + () => "Expected generateAssignment to fail").getMessage) + } finally { + adminClient.close() + } + } + + @Test + def testGenerateAssignmentWithInvalidPartitionsFails(): Unit = { + val adminClient = new MockAdminClient.Builder().numBrokers(5).build() + try { + addTopics(adminClient) + assertStartsWith("Topic quux not found", + assertThrows(classOf[ExecutionException], + () => generateAssignment(adminClient, """{"topics":[{"topic":"foo"},{"topic":"quux"}]}""", "0,1", false), + () => "Expected generateAssignment to fail").getCause.getMessage) + } finally { + adminClient.close() + } + } + + @Test + def testGenerateAssignmentWithInconsistentRacks(): Unit = { + val adminClient = new MockAdminClient.Builder(). + brokers(Arrays.asList( + new Node(0, "locahost", 9092, "rack0"), + new Node(1, "locahost", 9093, "rack0"), + new Node(2, "locahost", 9094, null), + new Node(3, "locahost", 9095, "rack1"), + new Node(4, "locahost", 9096, "rack1"), + new Node(5, "locahost", 9097, "rack2"))). 
+ build() + try { + addTopics(adminClient) + assertStartsWith("Not all brokers have rack information.", + assertThrows(classOf[AdminOperationException], + () => generateAssignment(adminClient, """{"topics":[{"topic":"foo"}]}""", "0,1,2,3", true), + () => "Expected generateAssignment to fail").getMessage) + // It should succeed when --disable-rack-aware is used. + val (_, current) = generateAssignment(adminClient, + """{"topics":[{"topic":"foo"}]}""", "0,1,2,3", false) + assertEquals(Map( + new TopicPartition("foo", 0) -> Seq(0, 1, 2), + new TopicPartition("foo", 1) -> Seq(1, 2, 3), + ), current) + } finally { + adminClient.close() + } + } + + @Test + def testGenerateAssignmentWithFewerBrokers(): Unit = { + val adminClient = new MockAdminClient.Builder().numBrokers(4).build() + try { + addTopics(adminClient) + val goalBrokers = Set(0,1,3) + val (proposed, current) = generateAssignment(adminClient, + """{"topics":[{"topic":"foo"},{"topic":"bar"}]}""", + goalBrokers.mkString(","), false) + assertEquals(Map( + new TopicPartition("foo", 0) -> Seq(0, 1, 2), + new TopicPartition("foo", 1) -> Seq(1, 2, 3), + new TopicPartition("bar", 0) -> Seq(2, 3, 0) + ), current) + + // The proposed assignment should only span the provided brokers + proposed.values.foreach(replicas => assertTrue(replicas.forall(goalBrokers.contains), + s"Proposed assignment $proposed puts replicas on brokers other than $goalBrokers")) + } finally { + adminClient.close() + } + } + + @Test + def testCurrentPartitionReplicaAssignmentToString(): Unit = { + assertEquals(Seq( + """Current partition replica assignment""", + """""", + """{"version":1,"partitions":""" + + """[{"topic":"bar","partition":0,"replicas":[7,8],"log_dirs":["any","any"]},""" + + """{"topic":"foo","partition":1,"replicas":[4,5,6],"log_dirs":["any","any","any"]}]""" + + """}""", + """""", + """Save this to use as the --reassignment-json-file option during rollback""" + ).mkString(System.lineSeparator()), + currentPartitionReplicaAssignmentToString(Map( + new TopicPartition("foo", 1) -> Seq(1,2,3), + new TopicPartition("bar", 0) -> Seq(7,8,9) + ), + Map( + new TopicPartition("foo", 0) -> Seq(1,2,3), + new TopicPartition("foo", 1) -> Seq(4,5,6), + new TopicPartition("bar", 0) -> Seq(7,8), + new TopicPartition("baz", 0) -> Seq(10,11,12) + ), + )) + } + + @Test + def testMoveMap(): Unit = { + // overwrite foo-0 with different reassignments + // keep old reassignments of foo-1 + // overwrite foo-2 with same reassignments + // overwrite foo-3 with new reassignments without overlap of old reassignments + // overwrite foo-4 with a subset of old reassignments + // overwrite foo-5 with a superset of old reassignments + // add new reassignments to bar-0 + val moveMap = calculateProposedMoveMap(Map( + new TopicPartition("foo", 0) -> new PartitionReassignment( + Arrays.asList(1,2,3,4), Arrays.asList(4), Arrays.asList(3)), + new TopicPartition("foo", 1) -> new PartitionReassignment( + Arrays.asList(4,5,6,7,8), Arrays.asList(7, 8), Arrays.asList(4, 5)), + new TopicPartition("foo", 2) -> new PartitionReassignment( + Arrays.asList(1,2,3,4), Arrays.asList(3,4), Arrays.asList(1,2)), + new TopicPartition("foo", 3) -> new PartitionReassignment( + Arrays.asList(1,2,3,4), Arrays.asList(3,4), Arrays.asList(1,2)), + new TopicPartition("foo", 4) -> new PartitionReassignment( + Arrays.asList(1,2,3,4), Arrays.asList(3,4), Arrays.asList(1,2)), + new TopicPartition("foo", 5) -> new PartitionReassignment( + Arrays.asList(1,2,3,4), Arrays.asList(3,4), Arrays.asList(1,2)) + ), Map( + new 
TopicPartition("foo", 0) -> Seq(1,2,5), + new TopicPartition("foo", 2) -> Seq(3,4), + new TopicPartition("foo", 3) -> Seq(5,6), + new TopicPartition("foo", 4) -> Seq(3), + new TopicPartition("foo", 5) -> Seq(3,4,5,6), + new TopicPartition("bar", 0) -> Seq(1,2,3) + ), Map( + new TopicPartition("foo", 0) -> Seq(1,2,3,4), + new TopicPartition("foo", 1) -> Seq(4,5,6,7,8), + new TopicPartition("foo", 2) -> Seq(1,2,3,4), + new TopicPartition("foo", 3) -> Seq(1,2,3,4), + new TopicPartition("foo", 4) -> Seq(1,2,3,4), + new TopicPartition("foo", 5) -> Seq(1,2,3,4), + new TopicPartition("bar", 0) -> Seq(2,3,4), + new TopicPartition("baz", 0) -> Seq(1,2,3) + )) + + assertEquals( + mutable.Map("foo" -> mutable.Map( + 0 -> PartitionMove(mutable.Set(1,2,3), mutable.Set(5)), + 1 -> PartitionMove(mutable.Set(4,5,6), mutable.Set(7,8)), + 2 -> PartitionMove(mutable.Set(1,2), mutable.Set(3,4)), + 3 -> PartitionMove(mutable.Set(1,2), mutable.Set(5,6)), + 4 -> PartitionMove(mutable.Set(1,2), mutable.Set(3)), + 5 -> PartitionMove(mutable.Set(1,2), mutable.Set(3,4,5,6)) + ), "bar" -> mutable.Map( + 0 -> PartitionMove(mutable.Set(2,3,4), mutable.Set(1)), + )), moveMap) + + assertEquals(Map( + "foo" -> "0:1,0:2,0:3,1:4,1:5,1:6,2:1,2:2,3:1,3:2,4:1,4:2,5:1,5:2", + "bar" -> "0:2,0:3,0:4" + ), calculateLeaderThrottles(moveMap)) + + assertEquals(Map( + "foo" -> "0:5,1:7,1:8,2:3,2:4,3:5,3:6,4:3,5:3,5:4,5:5,5:6", + "bar" -> "0:1" + ), calculateFollowerThrottles(moveMap)) + + assertEquals(Set(1,2,3,4,5,6,7,8), calculateReassigningBrokers(moveMap)) + + assertEquals(Set(0,2), calculateMovingBrokers( + Set(new TopicPartitionReplica("quux", 0, 0), + new TopicPartitionReplica("quux", 1, 2)))) + } + + @Test + def testParseExecuteAssignmentArgs(): Unit = { + assertStartsWith("Partition reassignment list cannot be empty", + assertThrows(classOf[AdminCommandFailedException], + () => parseExecuteAssignmentArgs("""{"version":1,"partitions":[]}"""), + () => "Expected to detect empty partition reassignment list").getMessage) + assertStartsWith("Partition reassignment contains duplicate topic partitions", + assertThrows(classOf[AdminCommandFailedException], () => parseExecuteAssignmentArgs( + """{"version":1,"partitions":""" + + """[{"topic":"foo","partition":0,"replicas":[0,1],"log_dirs":["any","any"]},""" + + """{"topic":"foo","partition":0,"replicas":[2,3,4],"log_dirs":["any","any","any"]}""" + + """]}"""), () => "Expected to detect a partition list with duplicate entries").getMessage) + assertStartsWith("Partition reassignment contains duplicate topic partitions", + assertThrows(classOf[AdminCommandFailedException], () => parseExecuteAssignmentArgs( + """{"version":1,"partitions":""" + + """[{"topic":"foo","partition":0,"replicas":[0,1],"log_dirs":["/abc","/def"]},""" + + """{"topic":"foo","partition":0,"replicas":[2,3],"log_dirs":["/abc","/def"]}""" + + """]}"""), () => "Expected to detect a partition replica list with duplicate entries").getMessage) + assertStartsWith("Partition replica lists may not contain duplicate entries", + assertThrows(classOf[AdminCommandFailedException], () => parseExecuteAssignmentArgs( + """{"version":1,"partitions":""" + + """[{"topic":"foo","partition":0,"replicas":[0,0],"log_dirs":["/abc","/def"]},""" + + """{"topic":"foo","partition":1,"replicas":[2,3],"log_dirs":["/abc","/def"]}""" + + """]}"""), () => "Expected to detect a partition replica list with duplicate entries").getMessage) + assertEquals((Map( + new TopicPartition("foo", 0) -> Seq(1, 2, 3), + new TopicPartition("foo", 1) -> Seq(3, 4, 5), 
+ ), Map( + )), + parseExecuteAssignmentArgs( + """{"version":1,"partitions":""" + + """[{"topic":"foo","partition":0,"replicas":[1,2,3],"log_dirs":["any","any","any"]},""" + + """{"topic":"foo","partition":1,"replicas":[3,4,5],"log_dirs":["any","any","any"]}""" + + """]}""")) + assertEquals((Map( + new TopicPartition("foo", 0) -> Seq(1, 2, 3), + ), Map( + new TopicPartitionReplica("foo", 0, 1) -> "/tmp/a", + new TopicPartitionReplica("foo", 0, 2) -> "/tmp/b", + new TopicPartitionReplica("foo", 0, 3) -> "/tmp/c" + )), + parseExecuteAssignmentArgs( + """{"version":1,"partitions":""" + + """[{"topic":"foo","partition":0,"replicas":[1,2,3],"log_dirs":["/tmp/a","/tmp/b","/tmp/c"]}""" + + """]}""")) + } + + @Test + def testExecuteWithInvalidPartitionsFails(): Unit = { + val adminClient = new MockAdminClient.Builder().numBrokers(5).build() + try { + addTopics(adminClient) + assertStartsWith("Topic quux not found", + assertThrows(classOf[ExecutionException], () => executeAssignment(adminClient, false, + """{"version":1,"partitions":""" + + """[{"topic":"foo","partition":0,"replicas":[0,1],"log_dirs":["any","any"]},""" + + """{"topic":"quux","partition":0,"replicas":[2,3,4],"log_dirs":["any","any","any"]}""" + + """]}"""), () => "Expected reassignment with non-existent topic to fail").getCause.getMessage) + } finally { + adminClient.close() + } + } + + @Test + def testExecuteWithInvalidBrokerIdFails(): Unit = { + val adminClient = new MockAdminClient.Builder().numBrokers(4).build() + try { + addTopics(adminClient) + assertStartsWith("Unknown broker id 4", + assertThrows(classOf[AdminCommandFailedException], () => executeAssignment(adminClient, false, + """{"version":1,"partitions":""" + + """[{"topic":"foo","partition":0,"replicas":[0,1],"log_dirs":["any","any"]},""" + + """{"topic":"foo","partition":1,"replicas":[2,3,4],"log_dirs":["any","any","any"]}""" + + """]}"""), () => "Expected reassignment with non-existent broker id to fail").getMessage) + } finally { + adminClient.close() + } + } + + @Test + def testModifyBrokerInterBrokerThrottle(): Unit = { + val adminClient = new MockAdminClient.Builder().numBrokers(4).build() + try { + modifyInterBrokerThrottle(adminClient, Set(0, 1, 2), 1000) + modifyInterBrokerThrottle(adminClient, Set(0, 3), 100) + val brokers = Seq(0, 1, 2, 3).map( + id => new ConfigResource(ConfigResource.Type.BROKER, id.toString)) + val results = adminClient.describeConfigs(brokers.asJava).all().get() + verifyBrokerThrottleResults(results.get(brokers(0)), 100, -1) + verifyBrokerThrottleResults(results.get(brokers(1)), 1000, -1) + verifyBrokerThrottleResults(results.get(brokers(2)), 1000, -1) + verifyBrokerThrottleResults(results.get(brokers(3)), 100, -1) + } finally { + adminClient.close() + } + } + + @Test + def testModifyLogDirThrottle(): Unit = { + val adminClient = new MockAdminClient.Builder().numBrokers(4).build() + try { + modifyLogDirThrottle(adminClient, Set(0, 1, 2), 2000) + modifyLogDirThrottle(adminClient, Set(0, 3), -1) + val brokers = Seq(0, 1, 2, 3).map( + id => new ConfigResource(ConfigResource.Type.BROKER, id.toString)) + val results = adminClient.describeConfigs(brokers.asJava).all().get() + verifyBrokerThrottleResults(results.get(brokers(0)), -1, 2000) + verifyBrokerThrottleResults(results.get(brokers(1)), -1, 2000) + verifyBrokerThrottleResults(results.get(brokers(2)), -1, 2000) + verifyBrokerThrottleResults(results.get(brokers(3)), -1, -1) + } finally { + adminClient.close() + } + } + + @Test + def testCurReassignmentsToString(): Unit = { + val adminClient 
= new MockAdminClient.Builder().numBrokers(4).build() + try { + addTopics(adminClient) + assertEquals("No partition reassignments found.", curReassignmentsToString(adminClient)) + val reassignmentResult: Map[TopicPartition, Class[_ <: Throwable]] = alterPartitionReassignments(adminClient, + Map( + new TopicPartition("foo", 1) -> Seq(4,5,3), + new TopicPartition("foo", 0) -> Seq(0,1,4,2), + new TopicPartition("bar", 0) -> Seq(2,3) + ) + ).map { case (k, v) => k -> v.getClass }.toMap + assertEquals(Map(), reassignmentResult) + assertEquals(Seq("Current partition reassignments:", + "bar-0: replicas: 2,3,0. removing: 0.", + "foo-0: replicas: 0,1,2. adding: 4.", + "foo-1: replicas: 1,2,3. adding: 4,5. removing: 1,2.").mkString(System.lineSeparator()), + curReassignmentsToString(adminClient)) + } finally { + adminClient.close() + } + } + + private def verifyBrokerThrottleResults(config: Config, + expectedInterBrokerThrottle: Long, + expectedReplicaAlterLogDirsThrottle: Long): Unit = { + val configs = new mutable.HashMap[String, String] + config.entries.forEach(entry => configs.put(entry.name, entry.value)) + if (expectedInterBrokerThrottle >= 0) { + assertEquals(expectedInterBrokerThrottle.toString, + configs.getOrElse(brokerLevelLeaderThrottle, "")) + assertEquals(expectedInterBrokerThrottle.toString, + configs.getOrElse(brokerLevelFollowerThrottle, "")) + } + if (expectedReplicaAlterLogDirsThrottle >= 0) { + assertEquals(expectedReplicaAlterLogDirsThrottle.toString, + configs.getOrElse(brokerLevelLogDirThrottle, "")) + } + } + + @Test + def testModifyTopicThrottles(): Unit = { + val adminClient = new MockAdminClient.Builder().numBrokers(4).build() + try { + addTopics(adminClient) + modifyTopicThrottles(adminClient, + Map("foo" -> "leaderFoo", "bar" -> "leaderBar"), + Map("bar" -> "followerBar")) + val topics = Seq("bar", "foo").map( + id => new ConfigResource(ConfigResource.Type.TOPIC, id)) + val results = adminClient.describeConfigs(topics.asJava).all().get() + verifyTopicThrottleResults(results.get(topics(0)), "leaderBar", "followerBar") + verifyTopicThrottleResults(results.get(topics(1)), "leaderFoo", "") + } finally { + adminClient.close() + } + } + + private def verifyTopicThrottleResults(config: Config, + expectedLeaderThrottle: String, + expectedFollowerThrottle: String): Unit = { + val configs = new mutable.HashMap[String, String] + config.entries.forEach(entry => configs.put(entry.name, entry.value)) + assertEquals(expectedLeaderThrottle, + configs.getOrElse(topicLevelLeaderThrottle, "")) + assertEquals(expectedFollowerThrottle, + configs.getOrElse(topicLevelFollowerThrottle, "")) + } + + @Test + def testAlterReplicaLogDirs(): Unit = { + val adminClient = new MockAdminClient.Builder(). + numBrokers(4). + brokerLogDirs(Collections.nCopies(4, + Arrays.asList("/tmp/kafka-logs0", "/tmp/kafka-logs1"))). 
+ build() + try { + addTopics(adminClient) + assertEquals(Set( + new TopicPartitionReplica("foo", 0, 0) + ), + alterReplicaLogDirs(adminClient, Map( + new TopicPartitionReplica("foo", 0, 0) -> "/tmp/kafka-logs1", + new TopicPartitionReplica("quux", 1, 0) -> "/tmp/kafka-logs1" + ))) + } finally { + adminClient.close() + } + } + + def assertStartsWith(prefix: String, str: String): Unit = { + assertTrue(str.startsWith(prefix), "Expected the string to start with %s, but it was %s".format(prefix, str)) + } + + @Test + def testPropagateInvalidJsonError(): Unit = { + val adminClient = new MockAdminClient.Builder().numBrokers(4).build() + try { + addTopics(adminClient) + assertStartsWith("Unexpected character", + assertThrows(classOf[AdminOperationException], () => executeAssignment(adminClient, additional = false, "{invalid_json")).getMessage) + } finally { + adminClient.close() + } + } +} diff --git a/core/src/test/scala/unit/kafka/admin/ReplicationQuotaUtils.scala b/core/src/test/scala/unit/kafka/admin/ReplicationQuotaUtils.scala new file mode 100644 index 0000000..4ac64fe --- /dev/null +++ b/core/src/test/scala/unit/kafka/admin/ReplicationQuotaUtils.scala @@ -0,0 +1,56 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more contributor license agreements. See the NOTICE + * file distributed with this work for additional information regarding copyright ownership. The ASF licenses this file + * to You under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the + * License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. 
+ */ +package kafka.admin + +import kafka.log.LogConfig +import kafka.server.{ConfigType, DynamicConfig, KafkaServer} +import kafka.utils.TestUtils +import kafka.zk.AdminZkClient + +import scala.collection.Seq + +object ReplicationQuotaUtils { + + def checkThrottleConfigRemovedFromZK(adminZkClient: AdminZkClient, topic: String, servers: Seq[KafkaServer]): Unit = { + TestUtils.waitUntilTrue(() => { + val hasRateProp = servers.forall { server => + val brokerConfig = adminZkClient.fetchEntityConfig(ConfigType.Broker, server.config.brokerId.toString) + brokerConfig.contains(DynamicConfig.Broker.LeaderReplicationThrottledRateProp) || + brokerConfig.contains(DynamicConfig.Broker.FollowerReplicationThrottledRateProp) + } + val topicConfig = adminZkClient.fetchEntityConfig(ConfigType.Topic, topic) + val hasReplicasProp = topicConfig.contains(LogConfig.LeaderReplicationThrottledReplicasProp) || + topicConfig.contains(LogConfig.FollowerReplicationThrottledReplicasProp) + !hasRateProp && !hasReplicasProp + }, "Throttle limit/replicas was not unset") + } + + def checkThrottleConfigAddedToZK(adminZkClient: AdminZkClient, expectedThrottleRate: Long, servers: Seq[KafkaServer], topic: String, throttledLeaders: Set[String], throttledFollowers: Set[String]): Unit = { + TestUtils.waitUntilTrue(() => { + //Check for limit in ZK + val brokerConfigAvailable = servers.forall { server => + val configInZk = adminZkClient.fetchEntityConfig(ConfigType.Broker, server.config.brokerId.toString) + val zkLeaderRate = configInZk.getProperty(DynamicConfig.Broker.LeaderReplicationThrottledRateProp) + val zkFollowerRate = configInZk.getProperty(DynamicConfig.Broker.FollowerReplicationThrottledRateProp) + zkLeaderRate != null && expectedThrottleRate == zkLeaderRate.toLong && + zkFollowerRate != null && expectedThrottleRate == zkFollowerRate.toLong + } + //Check replicas assigned + val topicConfig = adminZkClient.fetchEntityConfig(ConfigType.Topic, topic) + val leader = topicConfig.getProperty(LogConfig.LeaderReplicationThrottledReplicasProp).split(",").toSet + val follower = topicConfig.getProperty(LogConfig.FollowerReplicationThrottledReplicasProp).split(",").toSet + val topicConfigAvailable = leader == throttledLeaders && follower == throttledFollowers + brokerConfigAvailable && topicConfigAvailable + }, "throttle limit/replicas was not set") + } +} diff --git a/core/src/test/scala/unit/kafka/admin/ResetConsumerGroupOffsetTest.scala b/core/src/test/scala/unit/kafka/admin/ResetConsumerGroupOffsetTest.scala new file mode 100644 index 0000000..8443fea --- /dev/null +++ b/core/src/test/scala/unit/kafka/admin/ResetConsumerGroupOffsetTest.scala @@ -0,0 +1,516 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more contributor license agreements. See the NOTICE + * file distributed with this work for additional information regarding copyright ownership. The ASF licenses this file + * to You under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the + * License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. 
+ */ +package kafka.admin + +import java.io.{BufferedWriter, File, FileWriter} +import java.text.{SimpleDateFormat} +import java.util.{Calendar, Date, Properties} + +import joptsimple.OptionException +import kafka.admin.ConsumerGroupCommand.ConsumerGroupService +import kafka.server.KafkaConfig +import kafka.utils.TestUtils +import org.apache.kafka.clients.producer.ProducerRecord +import org.apache.kafka.common.TopicPartition +import org.apache.kafka.test +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.Test + +import scala.jdk.CollectionConverters._ +import scala.collection.Seq + + + +/** + * Test cases by: + * - Non-existing consumer group + * - One for each scenario, with scope=all-topics + * - scope=one topic, scenario=to-earliest + * - scope=one topic+partitions, scenario=to-earliest + * - scope=topics, scenario=to-earliest + * - scope=topics+partitions, scenario=to-earliest + * - export/import + */ +class ResetConsumerGroupOffsetTest extends ConsumerGroupCommandTest { + + val overridingProps = new Properties() + val topic1 = "foo1" + val topic2 = "foo2" + + override def generateConfigs: Seq[KafkaConfig] = { + TestUtils.createBrokerConfigs(1, zkConnect, enableControlledShutdown = false) + .map(KafkaConfig.fromProps(_, overridingProps)) + } + + private def basicArgs: Array[String] = { + Array("--reset-offsets", + "--bootstrap-server", brokerList, + "--timeout", test.TestUtils.DEFAULT_MAX_WAIT_MS.toString) + } + + private def buildArgsForGroups(groups: Seq[String], args: String*): Array[String] = { + val groupArgs = groups.flatMap(group => Seq("--group", group)).toArray + basicArgs ++ groupArgs ++ args + } + + private def buildArgsForGroup(group: String, args: String*): Array[String] = { + buildArgsForGroups(Seq(group), args: _*) + } + + private def buildArgsForAllGroups(args: String*): Array[String] = { + basicArgs ++ Array("--all-groups") ++ args + } + + @Test + def testResetOffsetsNotExistingGroup(): Unit = { + val group = "missing.group" + val args = buildArgsForGroup(group, "--all-topics", "--to-current", "--execute") + val consumerGroupCommand = getConsumerGroupService(args) + // Make sure we got a coordinator + TestUtils.waitUntilTrue(() => { + consumerGroupCommand.collectGroupState(group).coordinator.host() == "localhost" + }, "Can't find a coordinator") + val resetOffsets = consumerGroupCommand.resetOffsets()(group) + assertEquals(Map.empty, resetOffsets) + assertEquals(resetOffsets, committedOffsets(group = group)) + } + + @Test + def testResetOffsetsExistingTopic(): Unit = { + val group = "new.group" + val args = buildArgsForGroup(group, "--topic", topic, "--to-offset", "50") + produceMessages(topic, 100) + resetAndAssertOffsets(args, expectedOffset = 50, dryRun = true) + resetAndAssertOffsets(args ++ Array("--dry-run"), expectedOffset = 50, dryRun = true) + resetAndAssertOffsets(args ++ Array("--execute"), expectedOffset = 50) + } + + @Test + def testResetOffsetsExistingTopicSelectedGroups(): Unit = { + produceMessages(topic, 100) + val groups = + for (id <- 1 to 3) yield { + val group = this.group + id + val executor = addConsumerGroupExecutor(numConsumers = 1, topic = topic, group = group) + awaitConsumerProgress(count = 100L, group = group) + executor.shutdown() + group + } + val args = buildArgsForGroups(groups,"--topic", topic, "--to-offset", "50") + resetAndAssertOffsets(args, expectedOffset = 50, dryRun = true) + resetAndAssertOffsets(args ++ Array("--dry-run"), expectedOffset = 50, dryRun = true) + resetAndAssertOffsets(args ++ 
Array("--execute"), expectedOffset = 50) + } + + @Test + def testResetOffsetsExistingTopicAllGroups(): Unit = { + val args = buildArgsForAllGroups("--topic", topic, "--to-offset", "50") + produceMessages(topic, 100) + for (group <- 1 to 3 map (group + _)) { + val executor = addConsumerGroupExecutor(numConsumers = 1, topic = topic, group = group) + awaitConsumerProgress(count = 100L, group = group) + executor.shutdown() + } + resetAndAssertOffsets(args, expectedOffset = 50, dryRun = true) + resetAndAssertOffsets(args ++ Array("--dry-run"), expectedOffset = 50, dryRun = true) + resetAndAssertOffsets(args ++ Array("--execute"), expectedOffset = 50) + } + + @Test + def testResetOffsetsAllTopicsAllGroups(): Unit = { + val args = buildArgsForAllGroups("--all-topics", "--to-offset", "50") + val topics = 1 to 3 map (topic + _) + val groups = 1 to 3 map (group + _) + topics foreach (topic => produceMessages(topic, 100)) + for { + topic <- topics + group <- groups + } { + val executor = addConsumerGroupExecutor(numConsumers = 3, topic = topic, group = group) + awaitConsumerProgress(topic = topic, count = 100L, group = group) + executor.shutdown() + } + resetAndAssertOffsets(args, expectedOffset = 50, dryRun = true, topics = topics) + resetAndAssertOffsets(args ++ Array("--dry-run"), expectedOffset = 50, dryRun = true, topics = topics) + resetAndAssertOffsets(args ++ Array("--execute"), expectedOffset = 50, topics = topics) + } + + @Test + def testResetOffsetsToLocalDateTime(): Unit = { + val format = new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss.SSS") + val calendar = Calendar.getInstance() + calendar.add(Calendar.DATE, -1) + + produceMessages(topic, 100) + + val executor = addConsumerGroupExecutor(numConsumers = 1, topic) + awaitConsumerProgress(count = 100L) + executor.shutdown() + + val args = buildArgsForGroup(group, "--all-topics", "--to-datetime", format.format(calendar.getTime), "--execute") + resetAndAssertOffsets(args, expectedOffset = 0) + } + + @Test + def testResetOffsetsToZonedDateTime(): Unit = { + val format = new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss.SSSXXX") + + produceMessages(topic, 50) + val checkpoint = new Date() + produceMessages(topic, 50) + + val executor = addConsumerGroupExecutor(numConsumers = 1, topic) + awaitConsumerProgress(count = 100L) + executor.shutdown() + + val args = buildArgsForGroup(group, "--all-topics", "--to-datetime", format.format(checkpoint), "--execute") + resetAndAssertOffsets(args, expectedOffset = 50) + } + + @Test + def testResetOffsetsByDuration(): Unit = { + val args = buildArgsForGroup(group, "--all-topics", "--by-duration", "PT1M", "--execute") + produceConsumeAndShutdown(topic, group, totalMessages = 100) + resetAndAssertOffsets(args, expectedOffset = 0) + } + + @Test + def testResetOffsetsByDurationToEarliest(): Unit = { + val args = buildArgsForGroup(group, "--all-topics", "--by-duration", "PT0.1S", "--execute") + produceConsumeAndShutdown(topic, group, totalMessages = 100) + resetAndAssertOffsets(args, expectedOffset = 100) + } + + @Test + def testResetOffsetsByDurationFallbackToLatestWhenNoRecords(): Unit = { + val topic = "foo2" + val args = buildArgsForGroup(group, "--topic", topic, "--by-duration", "PT1M", "--execute") + createTopic(topic) + resetAndAssertOffsets(args, expectedOffset = 0, topics = Seq("foo2")) + + adminZkClient.deleteTopic(topic) + } + + @Test + def testResetOffsetsToEarliest(): Unit = { + val args = buildArgsForGroup(group, "--all-topics", "--to-earliest", "--execute") + produceConsumeAndShutdown(topic, group, 
totalMessages = 100) + resetAndAssertOffsets(args, expectedOffset = 0) + } + + @Test + def testResetOffsetsToLatest(): Unit = { + val args = buildArgsForGroup(group, "--all-topics", "--to-latest", "--execute") + produceConsumeAndShutdown(topic, group, totalMessages = 100) + produceMessages(topic, 100) + resetAndAssertOffsets(args, expectedOffset = 200) + } + + @Test + def testResetOffsetsToCurrentOffset(): Unit = { + val args = buildArgsForGroup(group, "--all-topics", "--to-current", "--execute") + produceConsumeAndShutdown(topic, group, totalMessages = 100) + produceMessages(topic, 100) + resetAndAssertOffsets(args, expectedOffset = 100) + } + + @Test + def testResetOffsetsToSpecificOffset(): Unit = { + val args = buildArgsForGroup(group, "--all-topics", "--to-offset", "1", "--execute") + produceConsumeAndShutdown(topic, group, totalMessages = 100) + resetAndAssertOffsets(args, expectedOffset = 1) + } + + @Test + def testResetOffsetsShiftPlus(): Unit = { + val args = buildArgsForGroup(group, "--all-topics", "--shift-by", "50", "--execute") + produceConsumeAndShutdown(topic, group, totalMessages = 100) + produceMessages(topic, 100) + resetAndAssertOffsets(args, expectedOffset = 150) + } + + @Test + def testResetOffsetsShiftMinus(): Unit = { + val args = buildArgsForGroup(group, "--all-topics", "--shift-by", "-50", "--execute") + produceConsumeAndShutdown(topic, group, totalMessages = 100) + produceMessages(topic, 100) + resetAndAssertOffsets(args, expectedOffset = 50) + } + + @Test + def testResetOffsetsShiftByLowerThanEarliest(): Unit = { + val args = buildArgsForGroup(group, "--all-topics", "--shift-by", "-150", "--execute") + produceConsumeAndShutdown(topic, group, totalMessages = 100) + produceMessages(topic, 100) + resetAndAssertOffsets(args, expectedOffset = 0) + } + + @Test + def testResetOffsetsShiftByHigherThanLatest(): Unit = { + val args = buildArgsForGroup(group, "--all-topics", "--shift-by", "150", "--execute") + produceConsumeAndShutdown(topic, group, totalMessages = 100) + produceMessages(topic, 100) + resetAndAssertOffsets(args, expectedOffset = 200) + } + + @Test + def testResetOffsetsToEarliestOnOneTopic(): Unit = { + val args = buildArgsForGroup(group, "--topic", topic, "--to-earliest", "--execute") + produceConsumeAndShutdown(topic, group, totalMessages = 100) + resetAndAssertOffsets(args, expectedOffset = 0) + } + + @Test + def testResetOffsetsToEarliestOnOneTopicAndPartition(): Unit = { + val topic = "bar" + createTopic(topic, 2, 1) + + val args = buildArgsForGroup(group, "--topic", s"$topic:1", "--to-earliest", "--execute") + val consumerGroupCommand = getConsumerGroupService(args) + + produceConsumeAndShutdown(topic, group, totalMessages = 100, numConsumers = 2) + val priorCommittedOffsets = committedOffsets(topic = topic) + + val tp0 = new TopicPartition(topic, 0) + val tp1 = new TopicPartition(topic, 1) + val expectedOffsets = Map(tp0 -> priorCommittedOffsets(tp0), tp1 -> 0L) + resetAndAssertOffsetsCommitted(consumerGroupCommand, expectedOffsets, topic) + + adminZkClient.deleteTopic(topic) + } + + @Test + def testResetOffsetsToEarliestOnTopics(): Unit = { + val topic1 = "topic1" + val topic2 = "topic2" + createTopic(topic1, 1, 1) + createTopic(topic2, 1, 1) + + val args = buildArgsForGroup(group, "--topic", topic1, "--topic", topic2, "--to-earliest", "--execute") + val consumerGroupCommand = getConsumerGroupService(args) + + produceConsumeAndShutdown(topic1, group, 100, 1) + produceConsumeAndShutdown(topic2, group, 100, 1) + + val tp1 = new TopicPartition(topic1, 
0) + val tp2 = new TopicPartition(topic2, 0) + + val allResetOffsets = resetOffsets(consumerGroupCommand)(group).map { case (k, v) => k -> v.offset } + assertEquals(Map(tp1 -> 0L, tp2 -> 0L), allResetOffsets) + assertEquals(Map(tp1 -> 0L), committedOffsets(topic1)) + assertEquals(Map(tp2 -> 0L), committedOffsets(topic2)) + + adminZkClient.deleteTopic(topic1) + adminZkClient.deleteTopic(topic2) + } + + @Test + def testResetOffsetsToEarliestOnTopicsAndPartitions(): Unit = { + val topic1 = "topic1" + val topic2 = "topic2" + + createTopic(topic1, 2, 1) + createTopic(topic2, 2, 1) + + val args = buildArgsForGroup(group, "--topic", s"$topic1:1", "--topic", s"$topic2:1", "--to-earliest", "--execute") + val consumerGroupCommand = getConsumerGroupService(args) + + produceConsumeAndShutdown(topic1, group, 100, 2) + produceConsumeAndShutdown(topic2, group, 100, 2) + + val priorCommittedOffsets1 = committedOffsets(topic1) + val priorCommittedOffsets2 = committedOffsets(topic2) + + val tp1 = new TopicPartition(topic1, 1) + val tp2 = new TopicPartition(topic2, 1) + val allResetOffsets = resetOffsets(consumerGroupCommand)(group).map { case (k, v) => k -> v.offset } + assertEquals(Map(tp1 -> 0, tp2 -> 0), allResetOffsets) + + assertEquals(priorCommittedOffsets1.toMap + (tp1 -> 0L), committedOffsets(topic1)) + assertEquals(priorCommittedOffsets2.toMap + (tp2 -> 0L), committedOffsets(topic2)) + + adminZkClient.deleteTopic(topic1) + adminZkClient.deleteTopic(topic2) + } + + @Test + // This one deals with old CSV export/import format for a single --group arg: "topic,partition,offset" to support old behavior + def testResetOffsetsExportImportPlanSingleGroupArg(): Unit = { + val topic = "bar" + val tp0 = new TopicPartition(topic, 0) + val tp1 = new TopicPartition(topic, 1) + createTopic(topic, 2, 1) + + val cgcArgs = buildArgsForGroup(group, "--all-topics", "--to-offset", "2", "--export") + val consumerGroupCommand = getConsumerGroupService(cgcArgs) + + produceConsumeAndShutdown(topic = topic, group = group, totalMessages = 100, numConsumers = 2) + + val file = File.createTempFile("reset", ".csv") + file.deleteOnExit() + + val exportedOffsets = consumerGroupCommand.resetOffsets() + val bw = new BufferedWriter(new FileWriter(file)) + bw.write(consumerGroupCommand.exportOffsetsToCsv(exportedOffsets)) + bw.close() + assertEquals(Map(tp0 -> 2L, tp1 -> 2L), exportedOffsets(group).map { case (k, v) => k -> v.offset }) + + val cgcArgsExec = buildArgsForGroup(group, "--all-topics", "--from-file", file.getCanonicalPath, "--dry-run") + val consumerGroupCommandExec = getConsumerGroupService(cgcArgsExec) + val importedOffsets = consumerGroupCommandExec.resetOffsets() + assertEquals(Map(tp0 -> 2L, tp1 -> 2L), importedOffsets(group).map { case (k, v) => k -> v.offset }) + + adminZkClient.deleteTopic(topic) + } + + @Test + // This one deals with universal CSV export/import file format "group,topic,partition,offset", + // supporting multiple --group args or --all-groups arg + def testResetOffsetsExportImportPlan(): Unit = { + val group1 = group + "1" + val group2 = group + "2" + val topic1 = "bar1" + val topic2 = "bar2" + val t1p0 = new TopicPartition(topic1, 0) + val t1p1 = new TopicPartition(topic1, 1) + val t2p0 = new TopicPartition(topic2, 0) + val t2p1 = new TopicPartition(topic2, 1) + createTopic(topic1, 2, 1) + createTopic(topic2, 2, 1) + + val cgcArgs = buildArgsForGroups(Seq(group1, group2), "--all-topics", "--to-offset", "2", "--export") + val consumerGroupCommand = getConsumerGroupService(cgcArgs) + + 
produceConsumeAndShutdown(topic = topic1, group = group1, totalMessages = 100) + produceConsumeAndShutdown(topic = topic2, group = group2, totalMessages = 100) + + awaitConsumerGroupInactive(consumerGroupCommand, group1) + awaitConsumerGroupInactive(consumerGroupCommand, group2) + + val file = File.createTempFile("reset", ".csv") + file.deleteOnExit() + + val exportedOffsets = consumerGroupCommand.resetOffsets() + val bw = new BufferedWriter(new FileWriter(file)) + bw.write(consumerGroupCommand.exportOffsetsToCsv(exportedOffsets)) + bw.close() + assertEquals(Map(t1p0 -> 2L, t1p1 -> 2L), exportedOffsets(group1).map { case (k, v) => k -> v.offset }) + assertEquals(Map(t2p0 -> 2L, t2p1 -> 2L), exportedOffsets(group2).map { case (k, v) => k -> v.offset }) + + // Multiple --group's offset import + val cgcArgsExec = buildArgsForGroups(Seq(group1, group2), "--all-topics", "--from-file", file.getCanonicalPath, "--dry-run") + val consumerGroupCommandExec = getConsumerGroupService(cgcArgsExec) + val importedOffsets = consumerGroupCommandExec.resetOffsets() + assertEquals(Map(t1p0 -> 2L, t1p1 -> 2L), importedOffsets(group1).map { case (k, v) => k -> v.offset }) + assertEquals(Map(t2p0 -> 2L, t2p1 -> 2L), importedOffsets(group2).map { case (k, v) => k -> v.offset }) + + // Single --group offset import using "group,topic,partition,offset" csv format + val cgcArgsExec2 = buildArgsForGroup(group1, "--all-topics", "--from-file", file.getCanonicalPath, "--dry-run") + val consumerGroupCommandExec2 = getConsumerGroupService(cgcArgsExec2) + val importedOffsets2 = consumerGroupCommandExec2.resetOffsets() + assertEquals(Map(t1p0 -> 2L, t1p1 -> 2L), importedOffsets2(group1).map { case (k, v) => k -> v.offset }) + + adminZkClient.deleteTopic(topic) + } + + @Test + def testResetWithUnrecognizedNewConsumerOption(): Unit = { + val cgcArgs = Array("--new-consumer", "--bootstrap-server", brokerList, "--reset-offsets", "--group", group, "--all-topics", + "--to-offset", "2", "--export") + assertThrows(classOf[OptionException], () => getConsumerGroupService(cgcArgs)) + } + + private def produceMessages(topic: String, numMessages: Int): Unit = { + val records = (0 until numMessages).map(_ => new ProducerRecord[Array[Byte], Array[Byte]](topic, + new Array[Byte](100 * 1000))) + TestUtils.produceMessages(servers, records, acks = 1) + } + + private def produceConsumeAndShutdown(topic: String, group: String, totalMessages: Int, numConsumers: Int = 1): Unit = { + produceMessages(topic, totalMessages) + val executor = addConsumerGroupExecutor(numConsumers = numConsumers, topic = topic, group = group) + awaitConsumerProgress(topic, group, totalMessages) + executor.shutdown() + } + + private def awaitConsumerProgress(topic: String = topic, + group: String = group, + count: Long): Unit = { + val consumer = createNoAutoCommitConsumer(group) + try { + val partitions = consumer.partitionsFor(topic).asScala.map { partitionInfo => + new TopicPartition(partitionInfo.topic, partitionInfo.partition) + }.toSet + + TestUtils.waitUntilTrue(() => { + val committed = consumer.committed(partitions.asJava).values.asScala + val total = committed.foldLeft(0L) { case (currentSum, offsetAndMetadata) => + currentSum + Option(offsetAndMetadata).map(_.offset).getOrElse(0L) + } + total == count + }, "Expected that consumer group has consumed all messages from topic/partition. " + + s"Expected offset: $count. 
Actual offset: ${committedOffsets(topic, group).values.sum}") + + } finally { + consumer.close() + } + + } + + private def awaitConsumerGroupInactive(consumerGroupService: ConsumerGroupService, group: String): Unit = { + TestUtils.waitUntilTrue(() => { + val state = consumerGroupService.collectGroupState(group).state + state == "Empty" || state == "Dead" + }, s"Expected that consumer group is inactive. Actual state: ${consumerGroupService.collectGroupState(group).state}") + } + + private def resetAndAssertOffsets(args: Array[String], + expectedOffset: Long, + dryRun: Boolean = false, + topics: Seq[String] = Seq(topic)): Unit = { + val consumerGroupCommand = getConsumerGroupService(args) + val expectedOffsets = topics.map(topic => topic -> Map(new TopicPartition(topic, 0) -> expectedOffset)).toMap + val resetOffsetsResultByGroup = resetOffsets(consumerGroupCommand) + + try { + for { + topic <- topics + (group, partitionInfo) <- resetOffsetsResultByGroup + } { + val priorOffsets = committedOffsets(topic = topic, group = group) + assertEquals(expectedOffsets(topic), + partitionInfo.filter(partitionInfo => partitionInfo._1.topic() == topic).map { case (k, v) => k -> v.offset }) + assertEquals(if (dryRun) priorOffsets else expectedOffsets(topic), committedOffsets(topic = topic, group = group)) + } + } finally { + consumerGroupCommand.close() + } + } + + private def resetAndAssertOffsetsCommitted(consumerGroupService: ConsumerGroupService, + expectedOffsets: Map[TopicPartition, Long], + topic: String): Unit = { + val allResetOffsets = resetOffsets(consumerGroupService) + for { + (group, offsetsInfo) <- allResetOffsets + (tp, offsetMetadata) <- offsetsInfo + } { + assertEquals(offsetMetadata.offset(), expectedOffsets(tp)) + assertEquals(expectedOffsets, committedOffsets(topic, group)) + } + } + + private def resetOffsets(consumerGroupService: ConsumerGroupService) = { + consumerGroupService.resetOffsets() + } +} diff --git a/core/src/test/scala/unit/kafka/admin/TopicCommandIntegrationTest.scala b/core/src/test/scala/unit/kafka/admin/TopicCommandIntegrationTest.scala new file mode 100644 index 0000000..dfb4ae7 --- /dev/null +++ b/core/src/test/scala/unit/kafka/admin/TopicCommandIntegrationTest.scala @@ -0,0 +1,816 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package kafka.admin + +import java.util.{Collection, Collections, Optional, Properties} + +import kafka.admin.TopicCommand.{TopicCommandOptions, TopicService} +import kafka.integration.KafkaServerTestHarness +import kafka.server.{ConfigType, KafkaConfig} +import kafka.utils.{Logging, TestUtils} +import kafka.zk.{ConfigEntityChangeNotificationZNode, DeleteTopicsTopicZNode} +import org.apache.kafka.clients.CommonClientConfigs +import org.apache.kafka.clients.admin._ +import org.apache.kafka.common.config.{ConfigException, ConfigResource, TopicConfig} +import org.apache.kafka.common.errors.{ClusterAuthorizationException, ThrottlingQuotaExceededException, TopicExistsException} +import org.apache.kafka.common.internals.Topic +import org.apache.kafka.common.network.ListenerName +import org.apache.kafka.common.protocol.Errors +import org.apache.kafka.common.security.auth.SecurityProtocol +import org.apache.kafka.common.{Node, TopicPartition, TopicPartitionInfo} +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.{AfterEach, BeforeEach, Test, TestInfo} +import org.mockito.ArgumentMatcher +import org.mockito.ArgumentMatchers.{eq => eqThat, _} +import org.mockito.Mockito._ + +import scala.collection.Seq +import scala.concurrent.ExecutionException +import scala.jdk.CollectionConverters._ +import scala.util.Random + +class TopicCommandIntegrationTest extends KafkaServerTestHarness with Logging with RackAwareTest { + + /** + * Implementations must override this method to return a set of KafkaConfigs. This method will be invoked for every + * test and should not reuse previous configurations unless they select their ports randomly when servers are started. + * + * Note the replica fetch max bytes is set to `1` in order to throttle the rate of replication for test + * `testDescribeUnderReplicatedPartitionsWhenReassignmentIsInProgress`. 
+ */ + override def generateConfigs: Seq[KafkaConfig] = TestUtils.createBrokerConfigs( + numConfigs = 6, + zkConnect = zkConnect, + rackInfo = Map(0 -> "rack1", 1 -> "rack2", 2 -> "rack2", 3 -> "rack1", 4 -> "rack3", 5 -> "rack3"), + numPartitions = numPartitions, + defaultReplicationFactor = defaultReplicationFactor, + ).map { props => + props.put(KafkaConfig.ReplicaFetchMaxBytesProp, "1") + KafkaConfig.fromProps(props) + } + + private val numPartitions = 1 + private val defaultReplicationFactor = 1.toShort + + private var topicService: TopicService = _ + private var adminClient: Admin = _ + private var testTopicName: String = _ + + private[this] def createAndWaitTopic(opts: TopicCommandOptions): Unit = { + topicService.createTopic(opts) + waitForTopicCreated(opts.topic.get) + } + + private[this] def waitForTopicCreated(topicName: String, timeout: Int = 10000): Unit = { + TestUtils.waitForPartitionMetadata(servers, topicName, partition = 0, timeout) + } + + @BeforeEach + override def setUp(info: TestInfo): Unit = { + super.setUp(info) + + // create adminClient + val props = new Properties() + props.put(CommonClientConfigs.BOOTSTRAP_SERVERS_CONFIG, brokerList) + adminClient = Admin.create(props) + topicService = TopicService(adminClient) + testTopicName = s"${info.getTestMethod.get().getName}-${Random.alphanumeric.take(10).mkString}" + } + + @AfterEach + def close(): Unit = { + // adminClient is closed by topicService + if (topicService != null) + topicService.close() + } + + @Test + def testCreate(): Unit = { + createAndWaitTopic(new TopicCommandOptions( + Array("--partitions", "2", "--replication-factor", "1", "--topic", testTopicName))) + + adminClient.listTopics().names().get().contains(testTopicName) + } + + @Test + def testCreateWithDefaults(): Unit = { + createAndWaitTopic(new TopicCommandOptions(Array("--topic", testTopicName))) + + val partitions = adminClient + .describeTopics(Collections.singletonList(testTopicName)) + .allTopicNames() + .get() + .get(testTopicName) + .partitions() + assertEquals(partitions.size(), numPartitions) + assertEquals(partitions.get(0).replicas().size(), defaultReplicationFactor) + } + + @Test + def testCreateWithDefaultReplication(): Unit = { + createAndWaitTopic(new TopicCommandOptions( + Array("--topic", testTopicName, "--partitions", "2"))) + + val partitions = adminClient + .describeTopics(Collections.singletonList(testTopicName)) + .allTopicNames() + .get() + .get(testTopicName) + .partitions() + assertEquals(partitions.size(), 2) + assertEquals(partitions.get(0).replicas().size(), defaultReplicationFactor) + } + + @Test + def testCreateWithDefaultPartitions(): Unit = { + createAndWaitTopic(new TopicCommandOptions( + Array("--topic", testTopicName, "--replication-factor", "2"))) + + val partitions = adminClient + .describeTopics(Collections.singletonList(testTopicName)) + .allTopicNames() + .get() + .get(testTopicName) + .partitions() + + assertEquals(partitions.size(), numPartitions) + assertEquals(partitions.get(0).replicas().size(), 2) + } + + @Test + def testCreateWithConfigs(): Unit = { + val configResource = new ConfigResource(ConfigResource.Type.TOPIC, testTopicName) + createAndWaitTopic(new TopicCommandOptions( + Array("--partitions", "2", "--replication-factor", "2", "--topic", testTopicName, "--config", "delete.retention.ms=1000"))) + + val configs = adminClient + .describeConfigs(Collections.singleton(configResource)) + .all().get().get(configResource) + assertEquals(1000, Integer.valueOf(configs.get("delete.retention.ms").value())) + 
} + + @Test + def testCreateWhenAlreadyExists(): Unit = { + val numPartitions = 1 + + // create the topic + val createOpts = new TopicCommandOptions( + Array("--partitions", numPartitions.toString, "--replication-factor", "1", "--topic", testTopicName)) + createAndWaitTopic(createOpts) + + // try to re-create the topic + assertThrows(classOf[TopicExistsException], () => topicService.createTopic(createOpts)) + } + + @Test + def testCreateWhenAlreadyExistsWithIfNotExists(): Unit = { + val createOpts = new TopicCommandOptions(Array("--topic", testTopicName, "--if-not-exists")) + createAndWaitTopic(createOpts) + topicService.createTopic(createOpts) + } + + @Test + def testCreateWithReplicaAssignment(): Unit = { + // create the topic + val createOpts = new TopicCommandOptions( + Array("--replica-assignment", "5:4,3:2,1:0", "--topic", testTopicName)) + createAndWaitTopic(createOpts) + + val partitions = adminClient + .describeTopics(Collections.singletonList(testTopicName)) + .allTopicNames() + .get() + .get(testTopicName) + .partitions() + assertEquals(3, partitions.size()) + assertEquals(List(5, 4), partitions.get(0).replicas().asScala.map(_.id())) + assertEquals(List(3, 2), partitions.get(1).replicas().asScala.map(_.id())) + assertEquals(List(1, 0), partitions.get(2).replicas().asScala.map(_.id())) + } + + @Test + def testCreateWithInvalidReplicationFactor(): Unit = { + assertThrows(classOf[IllegalArgumentException], + () => topicService.createTopic(new TopicCommandOptions( + Array("--partitions", "2", "--replication-factor", (Short.MaxValue+1).toString, "--topic", testTopicName)))) + } + + @Test + def testCreateWithNegativeReplicationFactor(): Unit = { + assertThrows(classOf[IllegalArgumentException], + () => topicService.createTopic(new TopicCommandOptions( + Array("--partitions", "2", "--replication-factor", "-1", "--topic", testTopicName)))) + } + + @Test + def testCreateWithNegativePartitionCount(): Unit = { + assertThrows(classOf[IllegalArgumentException], + () => topicService.createTopic(new TopicCommandOptions( + Array("--partitions", "-1", "--replication-factor", "1", "--topic", testTopicName)))) + } + + @Test + def testInvalidTopicLevelConfig(): Unit = { + val createOpts = new TopicCommandOptions( + Array("--partitions", "1", "--replication-factor", "1", "--topic", testTopicName, + "--config", "message.timestamp.type=boom")) + assertThrows(classOf[ConfigException], () => topicService.createTopic(createOpts)) + } + + @Test + def testListTopics(): Unit = { + createAndWaitTopic(new TopicCommandOptions( + Array("--partitions", "1", "--replication-factor", "1", "--topic", testTopicName))) + + val output = TestUtils.grabConsoleOutput( + topicService.listTopics(new TopicCommandOptions(Array()))) + + assertTrue(output.contains(testTopicName)) + } + + @Test + def testListTopicsWithIncludeList(): Unit = { + val topic1 = "kafka.testTopic1" + val topic2 = "kafka.testTopic2" + val topic3 = "oooof.testTopic1" + adminClient.createTopics( + List(new NewTopic(topic1, 2, 2.toShort), + new NewTopic(topic2, 2, 2.toShort), + new NewTopic(topic3, 2, 2.toShort)).asJavaCollection) + .all().get() + waitForTopicCreated(topic1) + waitForTopicCreated(topic2) + waitForTopicCreated(topic3) + + val output = TestUtils.grabConsoleOutput( + topicService.listTopics(new TopicCommandOptions(Array("--topic", "kafka.*")))) + + assertTrue(output.contains(topic1)) + assertTrue(output.contains(topic2)) + assertFalse(output.contains(topic3)) + } + + @Test + def testListTopicsWithExcludeInternal(): Unit = { + val topic1 = 
"kafka.testTopic1" + adminClient.createTopics( + List(new NewTopic(topic1, 2, 2.toShort), + new NewTopic(Topic.GROUP_METADATA_TOPIC_NAME, 2, 2.toShort)).asJavaCollection) + .all().get() + waitForTopicCreated(topic1) + + val output = TestUtils.grabConsoleOutput( + topicService.listTopics(new TopicCommandOptions(Array("--exclude-internal")))) + + assertTrue(output.contains(topic1)) + assertFalse(output.contains(Topic.GROUP_METADATA_TOPIC_NAME)) + } + + @Test + def testAlterPartitionCount(): Unit = { + adminClient.createTopics( + List(new NewTopic(testTopicName, 2, 2.toShort)).asJavaCollection).all().get() + waitForTopicCreated(testTopicName) + + topicService.alterTopic(new TopicCommandOptions( + Array("--topic", testTopicName, "--partitions", "3"))) + + val topicDescription = adminClient.describeTopics(Collections.singletonList(testTopicName)).topicNameValues().get(testTopicName).get() + assertTrue(topicDescription.partitions().size() == 3) + } + + @Test + def testAlterAssignment(): Unit = { + adminClient.createTopics( + Collections.singletonList(new NewTopic(testTopicName, 2, 2.toShort))).all().get() + waitForTopicCreated(testTopicName) + + topicService.alterTopic(new TopicCommandOptions( + Array("--topic", testTopicName, "--replica-assignment", "5:3,3:1,4:2", "--partitions", "3"))) + + val topicDescription = adminClient.describeTopics(Collections.singletonList(testTopicName)).topicNameValues().get(testTopicName).get() + assertTrue(topicDescription.partitions().size() == 3) + assertEquals(List(4,2), topicDescription.partitions().get(2).replicas().asScala.map(_.id())) + } + + @Test + def testAlterAssignmentWithMoreAssignmentThanPartitions(): Unit = { + adminClient.createTopics( + List(new NewTopic(testTopicName, 2, 2.toShort)).asJavaCollection).all().get() + waitForTopicCreated(testTopicName) + + assertThrows(classOf[ExecutionException], + () => topicService.alterTopic(new TopicCommandOptions( + Array("--topic", testTopicName, "--replica-assignment", "5:3,3:1,4:2,3:2", "--partitions", "3")))) + } + + @Test + def testAlterAssignmentWithMorePartitionsThanAssignment(): Unit = { + adminClient.createTopics( + List(new NewTopic(testTopicName, 2, 2.toShort)).asJavaCollection).all().get() + waitForTopicCreated(testTopicName) + + assertThrows(classOf[ExecutionException], + () => topicService.alterTopic(new TopicCommandOptions( + Array("--topic", testTopicName, "--replica-assignment", "5:3,3:1,4:2", "--partitions", "6")))) + } + + @Test + def testAlterWithInvalidPartitionCount(): Unit = { + createAndWaitTopic(new TopicCommandOptions( + Array("--partitions", "1", "--replication-factor", "1", "--topic", testTopicName))) + + assertThrows(classOf[ExecutionException], + () => topicService.alterTopic(new TopicCommandOptions( + Array("--partitions", "-1", "--topic", testTopicName)))) + } + + @Test + def testAlterWhenTopicDoesntExist(): Unit = { + // alter a topic that does not exist without --if-exists + val alterOpts = new TopicCommandOptions(Array("--topic", testTopicName, "--partitions", "1")) + val topicService = TopicService(adminClient) + assertThrows(classOf[IllegalArgumentException], () => topicService.alterTopic(alterOpts)) + } + + @Test + def testAlterWhenTopicDoesntExistWithIfExists(): Unit = { + topicService.alterTopic(new TopicCommandOptions( + Array("--topic", testTopicName, "--partitions", "1", "--if-exists"))) + } + + @Test + def testCreateAlterTopicWithRackAware(): Unit = { + val rackInfo = Map(0 -> "rack1", 1 -> "rack2", 2 -> "rack2", 3 -> "rack1", 4 -> "rack3", 5 -> "rack3") + + val 
numPartitions = 18 + val replicationFactor = 3 + val createOpts = new TopicCommandOptions(Array( + "--partitions", numPartitions.toString, + "--replication-factor", replicationFactor.toString, + "--topic", testTopicName)) + createAndWaitTopic(createOpts) + + var assignment = zkClient.getReplicaAssignmentForTopics(Set(testTopicName)).map { case (tp, replicas) => + tp.partition -> replicas + } + checkReplicaDistribution(assignment, rackInfo, rackInfo.size, numPartitions, replicationFactor) + + val alteredNumPartitions = 36 + // verify that adding partitions will also be rack aware + val alterOpts = new TopicCommandOptions(Array( + "--partitions", alteredNumPartitions.toString, + "--topic", testTopicName)) + topicService.alterTopic(alterOpts) + assignment = zkClient.getReplicaAssignmentForTopics(Set(testTopicName)).map { case (tp, replicas) => + tp.partition -> replicas + } + checkReplicaDistribution(assignment, rackInfo, rackInfo.size, alteredNumPartitions, replicationFactor) + } + + @Test + def testConfigPreservationAcrossPartitionAlteration(): Unit = { + val numPartitionsOriginal = 1 + val cleanupKey = "cleanup.policy" + val cleanupVal = "compact" + + // create the topic + val createOpts = new TopicCommandOptions(Array( + "--partitions", numPartitionsOriginal.toString, + "--replication-factor", "1", + "--config", cleanupKey + "=" + cleanupVal, + "--topic", testTopicName)) + createAndWaitTopic(createOpts) + val props = adminZkClient.fetchEntityConfig(ConfigType.Topic, testTopicName) + assertTrue(props.containsKey(cleanupKey), "Properties after creation don't contain " + cleanupKey) + assertTrue(props.getProperty(cleanupKey).equals(cleanupVal), "Properties after creation have incorrect value") + + // pre-create the topic config changes path to avoid a NoNodeException + zkClient.makeSurePersistentPathExists(ConfigEntityChangeNotificationZNode.path) + + // modify the topic to add new partitions + val numPartitionsModified = 3 + val alterOpts = new TopicCommandOptions( + Array("--partitions", numPartitionsModified.toString, "--topic", testTopicName)) + topicService.alterTopic(alterOpts) + val newProps = adminZkClient.fetchEntityConfig(ConfigType.Topic, testTopicName) + assertTrue(newProps.containsKey(cleanupKey), "Updated properties do not contain " + cleanupKey) + assertTrue(newProps.getProperty(cleanupKey).equals(cleanupVal), "Updated properties have incorrect value") + } + + @Test + def testTopicDeletion(): Unit = { + // create the NormalTopic + val createOpts = new TopicCommandOptions(Array("--partitions", "1", + "--replication-factor", "1", + "--topic", testTopicName)) + createAndWaitTopic(createOpts) + + // delete the NormalTopic + val deleteOpts = new TopicCommandOptions(Array("--topic", testTopicName)) + + val deletePath = DeleteTopicsTopicZNode.path(testTopicName) + assertFalse(zkClient.pathExists(deletePath), "Delete path for topic shouldn't exist before deletion.") + topicService.deleteTopic(deleteOpts) + TestUtils.verifyTopicDeletion(zkClient, testTopicName, 1, servers) + } + + @Test + def testDeleteInternalTopic(): Unit = { + // create the offset topic + val createOffsetTopicOpts = new TopicCommandOptions(Array("--partitions", "1", + "--replication-factor", "1", + "--topic", Topic.GROUP_METADATA_TOPIC_NAME)) + createAndWaitTopic(createOffsetTopicOpts) + + // Try to delete the Topic.GROUP_METADATA_TOPIC_NAME which is allowed by default. + // This is a difference between the new and the old command as the old one didn't allow internal topic deletion. 
+ // If deleting internal topics is not desired, ACLS should be used to control it. + val deleteOffsetTopicOpts = new TopicCommandOptions( + Array("--topic", Topic.GROUP_METADATA_TOPIC_NAME)) + val deleteOffsetTopicPath = DeleteTopicsTopicZNode.path(Topic.GROUP_METADATA_TOPIC_NAME) + assertFalse(zkClient.pathExists(deleteOffsetTopicPath), "Delete path for topic shouldn't exist before deletion.") + topicService.deleteTopic(deleteOffsetTopicOpts) + TestUtils.verifyTopicDeletion(zkClient, Topic.GROUP_METADATA_TOPIC_NAME, 1, servers) + } + + @Test + def testDeleteWhenTopicDoesntExist(): Unit = { + // delete a topic that does not exist + val deleteOpts = new TopicCommandOptions(Array("--topic", testTopicName)) + assertThrows(classOf[IllegalArgumentException], () => topicService.deleteTopic(deleteOpts)) + } + + @Test + def testDeleteWhenTopicDoesntExistWithIfExists(): Unit = { + topicService.deleteTopic(new TopicCommandOptions(Array("--topic", testTopicName, "--if-exists"))) + } + + @Test + def testDescribe(): Unit = { + adminClient.createTopics( + Collections.singletonList(new NewTopic(testTopicName, 2, 2.toShort))).all().get() + waitForTopicCreated(testTopicName) + + val output = TestUtils.grabConsoleOutput( + topicService.describeTopic(new TopicCommandOptions(Array("--topic", testTopicName)))) + val rows = output.split("\n") + assertEquals(3, rows.size) + assertTrue(rows(0).startsWith(s"Topic: $testTopicName")) + } + + @Test + def testDescribeWhenTopicDoesntExist(): Unit = { + assertThrows(classOf[IllegalArgumentException], + () => topicService.describeTopic(new TopicCommandOptions(Array("--topic", testTopicName)))) + } + + @Test + def testDescribeWhenTopicDoesntExistWithIfExists(): Unit = { + topicService.describeTopic(new TopicCommandOptions(Array("--topic", testTopicName, "--if-exists"))) + } + + @Test + def testDescribeUnavailablePartitions(): Unit = { + adminClient.createTopics( + Collections.singletonList(new NewTopic(testTopicName, 6, 1.toShort))).all().get() + waitForTopicCreated(testTopicName) + + try { + // check which partition is on broker 0 which we'll kill + val testTopicDescription = adminClient.describeTopics(Collections.singletonList(testTopicName)) + .allTopicNames().get().asScala(testTopicName) + val partitionOnBroker0 = testTopicDescription.partitions().asScala.find(_.leader().id() == 0).get.partition() + + killBroker(0) + + // wait until the topic metadata for the test topic is propagated to each alive broker + TestUtils.waitUntilTrue(() => { + servers + .filterNot(_.config.brokerId == 0) + .foldLeft(true) { + (result, server) => { + val topicMetadatas = server.dataPlaneRequestProcessor.metadataCache + .getTopicMetadata(Set(testTopicName), ListenerName.forSecurityProtocol(SecurityProtocol.PLAINTEXT)) + val testPartitionMetadata = topicMetadatas.find(_.name.equals(testTopicName)).get.partitions.asScala.find(_.partitionIndex == partitionOnBroker0) + testPartitionMetadata match { + case None => throw new AssertionError(s"Partition metadata is not found in metadata cache") + case Some(metadata) => result && metadata.errorCode == Errors.LEADER_NOT_AVAILABLE.code + } + } + } + }, s"Partition metadata for $testTopicName is not propagated") + + // grab the console output and assert + val output = TestUtils.grabConsoleOutput( + topicService.describeTopic(new TopicCommandOptions( + Array("--topic", testTopicName, "--unavailable-partitions")))) + val rows = output.split("\n") + assertTrue(rows(0).startsWith(s"\tTopic: $testTopicName")) + assertTrue(rows(0).contains("Leader: 
none\tReplicas: 0\tIsr:")) + } finally { + restartDeadBrokers() + } + } + + @Test + def testDescribeUnderReplicatedPartitions(): Unit = { + adminClient.createTopics( + Collections.singletonList(new NewTopic(testTopicName, 1, 6.toShort))).all().get() + waitForTopicCreated(testTopicName) + + try { + killBroker(0) + val aliveServers = servers.filterNot(_.config.brokerId == 0) + TestUtils.waitForPartitionMetadata(aliveServers, testTopicName, 0) + val output = TestUtils.grabConsoleOutput( + topicService.describeTopic(new TopicCommandOptions(Array("--under-replicated-partitions")))) + val rows = output.split("\n") + assertTrue(rows(0).startsWith(s"\tTopic: $testTopicName")) + } finally { + restartDeadBrokers() + } + } + + @Test + def testDescribeUnderMinIsrPartitions(): Unit = { + val configMap = new java.util.HashMap[String, String]() + configMap.put(TopicConfig.MIN_IN_SYNC_REPLICAS_CONFIG, "6") + + adminClient.createTopics( + Collections.singletonList(new NewTopic(testTopicName, 1, 6.toShort).configs(configMap))).all().get() + waitForTopicCreated(testTopicName) + + try { + killBroker(0) + val aliveServers = servers.filterNot(_.config.brokerId == 0) + TestUtils.waitForPartitionMetadata(aliveServers, testTopicName, 0) + val output = TestUtils.grabConsoleOutput( + topicService.describeTopic(new TopicCommandOptions(Array("--under-min-isr-partitions")))) + val rows = output.split("\n") + assertTrue(rows(0).startsWith(s"\tTopic: $testTopicName")) + } finally { + restartDeadBrokers() + } + } + + @Test + def testDescribeUnderReplicatedPartitionsWhenReassignmentIsInProgress(): Unit = { + val configMap = new java.util.HashMap[String, String]() + val replicationFactor: Short = 1 + val partitions = 1 + val tp = new TopicPartition(testTopicName, 0) + + adminClient.createTopics( + Collections.singletonList(new NewTopic(testTopicName, partitions, replicationFactor).configs(configMap))).all().get() + waitForTopicCreated(testTopicName) + + // Produce multiple batches. + TestUtils.generateAndProduceMessages(servers, testTopicName, numMessages = 10, acks = -1) + TestUtils.generateAndProduceMessages(servers, testTopicName, numMessages = 10, acks = -1) + + // Enable throttling. Note the broker config sets the replica max fetch bytes to `1` upon to minimize replication + // throughput so the reassignment doesn't complete quickly. 
+ val brokerIds = servers.map(_.config.brokerId) + TestUtils.setReplicationThrottleForPartitions(adminClient, brokerIds, Set(tp), throttleBytes = 1) + + val testTopicDesc = adminClient.describeTopics(Collections.singleton(testTopicName)).allTopicNames().get().get(testTopicName) + val firstPartition = testTopicDesc.partitions().asScala.head + + val replicasOfFirstPartition = firstPartition.replicas().asScala.map(_.id()) + val targetReplica = brokerIds.diff(replicasOfFirstPartition).head + + adminClient.alterPartitionReassignments(Collections.singletonMap(tp, + Optional.of(new NewPartitionReassignment(Collections.singletonList(targetReplica))))).all().get() + + // let's wait until the LAIR is propagated + TestUtils.waitUntilTrue(() => { + val reassignments = adminClient.listPartitionReassignments(Collections.singleton(tp)).reassignments().get() + !reassignments.get(tp).addingReplicas().isEmpty + }, "Reassignment didn't add the second node") + + // describe the topic and test if it's under-replicated + val simpleDescribeOutput = TestUtils.grabConsoleOutput( + topicService.describeTopic(new TopicCommandOptions(Array("--topic", testTopicName)))) + val simpleDescribeOutputRows = simpleDescribeOutput.split("\n") + assertTrue(simpleDescribeOutputRows(0).startsWith(s"Topic: $testTopicName")) + assertEquals(2, simpleDescribeOutputRows.size) + + val underReplicatedOutput = TestUtils.grabConsoleOutput( + topicService.describeTopic(new TopicCommandOptions(Array("--under-replicated-partitions")))) + assertEquals("", underReplicatedOutput, s"--under-replicated-partitions shouldn't return anything: '$underReplicatedOutput'") + + // Verify reassignment is still ongoing. + val reassignments = adminClient.listPartitionReassignments(Collections.singleton(tp)).reassignments.get().get(tp) + assertFalse(Option(reassignments).forall(_.addingReplicas.isEmpty)) + + TestUtils.removeReplicationThrottleForPartitions(adminClient, brokerIds, Set(tp)) + TestUtils.waitForAllReassignmentsToComplete(adminClient) + } + + @Test + def testDescribeAtMinIsrPartitions(): Unit = { + val configMap = new java.util.HashMap[String, String]() + configMap.put(TopicConfig.MIN_IN_SYNC_REPLICAS_CONFIG, "4") + + adminClient.createTopics( + Collections.singletonList(new NewTopic(testTopicName, 1, 6.toShort).configs(configMap))).all().get() + waitForTopicCreated(testTopicName) + + try { + killBroker(0) + killBroker(1) + val output = TestUtils.grabConsoleOutput( + topicService.describeTopic(new TopicCommandOptions(Array("--at-min-isr-partitions")))) + val rows = output.split("\n") + assertTrue(rows(0).startsWith(s"\tTopic: $testTopicName")) + assertEquals(1, rows.length) + } finally { + restartDeadBrokers() + } + } + + /** + * Test describe --under-min-isr-partitions option with four topics: + * (1) topic with partition under the configured min ISR count + * (2) topic with under-replicated partition (but not under min ISR count) + * (3) topic with offline partition + * (4) topic with fully replicated partition + * + * Output should only display the (1) topic with partition under min ISR count and (3) topic with offline partition + */ + @Test + def testDescribeUnderMinIsrPartitionsMixed(): Unit = { + val underMinIsrTopic = "under-min-isr-topic" + val notUnderMinIsrTopic = "not-under-min-isr-topic" + val offlineTopic = "offline-topic" + val fullyReplicatedTopic = "fully-replicated-topic" + + val configMap = new java.util.HashMap[String, String]() + configMap.put(TopicConfig.MIN_IN_SYNC_REPLICAS_CONFIG, "6") + + adminClient.createTopics( + 
java.util.Arrays.asList( + new NewTopic(underMinIsrTopic, 1, 6.toShort).configs(configMap), + new NewTopic(notUnderMinIsrTopic, 1, 6.toShort), + new NewTopic(offlineTopic, Collections.singletonMap(0, Collections.singletonList(0))), + new NewTopic(fullyReplicatedTopic, Collections.singletonMap(0, java.util.Arrays.asList(1, 2, 3))))).all().get() + + waitForTopicCreated(underMinIsrTopic) + waitForTopicCreated(notUnderMinIsrTopic) + waitForTopicCreated(offlineTopic) + waitForTopicCreated(fullyReplicatedTopic) + + try { + killBroker(0) + val aliveServers = servers.filterNot(_.config.brokerId == 0) + TestUtils.waitForPartitionMetadata(aliveServers, underMinIsrTopic, 0) + val output = TestUtils.grabConsoleOutput( + topicService.describeTopic(new TopicCommandOptions(Array("--under-min-isr-partitions")))) + val rows = output.split("\n") + assertTrue(rows(0).startsWith(s"\tTopic: $underMinIsrTopic")) + assertTrue(rows(1).startsWith(s"\tTopic: $offlineTopic")) + assertEquals(2, rows.length) + } finally { + restartDeadBrokers() + } + } + + @Test + def testDescribeReportOverriddenConfigs(): Unit = { + val config = "file.delete.delay.ms=1000" + createAndWaitTopic(new TopicCommandOptions( + Array("--partitions", "2", "--replication-factor", "2", "--topic", testTopicName, "--config", config))) + val output = TestUtils.grabConsoleOutput( + topicService.describeTopic(new TopicCommandOptions(Array()))) + assertTrue(output.contains(config), s"Describe output should have contained $config") + } + + @Test + def testDescribeAndListTopicsWithoutInternalTopics(): Unit = { + createAndWaitTopic( + new TopicCommandOptions(Array("--partitions", "1", "--replication-factor", "1", "--topic", testTopicName))) + // create a internal topic + createAndWaitTopic( + new TopicCommandOptions(Array("--partitions", "1", "--replication-factor", "1", "--topic", Topic.GROUP_METADATA_TOPIC_NAME))) + + // test describe + var output = TestUtils.grabConsoleOutput(topicService.describeTopic(new TopicCommandOptions( + Array("--describe", "--exclude-internal")))) + assertTrue(output.contains(testTopicName), s"Output should have contained $testTopicName") + assertFalse(output.contains(Topic.GROUP_METADATA_TOPIC_NAME)) + + // test list + output = TestUtils.grabConsoleOutput(topicService.listTopics(new TopicCommandOptions(Array("--list", "--exclude-internal")))) + assertTrue(output.contains(testTopicName)) + assertFalse(output.contains(Topic.GROUP_METADATA_TOPIC_NAME)) + } + + @Test + def testDescribeDoesNotFailWhenListingReassignmentIsUnauthorized(): Unit = { + adminClient = spy(adminClient) + topicService = TopicService(adminClient) + + val result = AdminClientTestUtils.listPartitionReassignmentsResult( + new ClusterAuthorizationException("Unauthorized")) + + // Passing `null` here to help the compiler disambiguate the `doReturn` methods, + // compilation for scala 2.12 fails otherwise. 
+ doReturn(result, null).when(adminClient).listPartitionReassignments( + Set(new TopicPartition(testTopicName, 0)).asJava + ) + + adminClient.createTopics( + Collections.singletonList(new NewTopic(testTopicName, 1, 1.toShort)) + ).all().get() + waitForTopicCreated(testTopicName) + + val output = TestUtils.grabConsoleOutput( + topicService.describeTopic(new TopicCommandOptions(Array("--topic", testTopicName)))) + val rows = output.split("\n") + assertEquals(2, rows.size) + assertTrue(rows(0).startsWith(s"Topic: $testTopicName")) + } + + @Test + def testCreateTopicDoesNotRetryThrottlingQuotaExceededException(): Unit = { + val adminClient = mock(classOf[Admin]) + val topicService = TopicService(adminClient) + + val result = AdminClientTestUtils.createTopicsResult(testTopicName, Errors.THROTTLING_QUOTA_EXCEEDED.exception()) + when(adminClient.createTopics(any(), any())).thenReturn(result) + + assertThrows(classOf[ThrottlingQuotaExceededException], + () => topicService.createTopic(new TopicCommandOptions(Array("--topic", testTopicName)))) + + val expectedNewTopic = new NewTopic(testTopicName, Optional.empty[Integer](), Optional.empty[java.lang.Short]()) + .configs(Map.empty[String, String].asJava) + + verify(adminClient, times(1)).createTopics( + eqThat(Set(expectedNewTopic).asJava), + argThat((_.shouldRetryOnQuotaViolation() == false): ArgumentMatcher[CreateTopicsOptions]) + ) + } + + @Test + def testDeleteTopicDoesNotRetryThrottlingQuotaExceededException(): Unit = { + val adminClient = mock(classOf[Admin]) + val topicService = TopicService(adminClient) + + val listResult = AdminClientTestUtils.listTopicsResult(testTopicName) + when(adminClient.listTopics(any())).thenReturn(listResult) + + val result = AdminClientTestUtils.deleteTopicsResult(testTopicName, Errors.THROTTLING_QUOTA_EXCEEDED.exception()) + when(adminClient.deleteTopics(any[Collection[String]](), any())).thenReturn(result) + + val exception = assertThrows(classOf[ExecutionException], + () => topicService.deleteTopic(new TopicCommandOptions(Array("--topic", testTopicName)))) + assertTrue(exception.getCause.isInstanceOf[ThrottlingQuotaExceededException]) + + verify(adminClient, times(1)).deleteTopics( + eqThat(Seq(testTopicName).asJavaCollection), + argThat((_.shouldRetryOnQuotaViolation() == false): ArgumentMatcher[DeleteTopicsOptions]) + ) + } + + @Test + def testCreatePartitionsDoesNotRetryThrottlingQuotaExceededException(): Unit = { + val adminClient = mock(classOf[Admin]) + val topicService = TopicService(adminClient) + + val listResult = AdminClientTestUtils.listTopicsResult(testTopicName) + when(adminClient.listTopics(any())).thenReturn(listResult) + + val topicPartitionInfo = new TopicPartitionInfo(0, new Node(0, "", 0), + Collections.emptyList(), Collections.emptyList()) + val describeResult = AdminClientTestUtils.describeTopicsResult(testTopicName, new TopicDescription( + testTopicName, false, Collections.singletonList(topicPartitionInfo))) + when(adminClient.describeTopics(any(classOf[java.util.Collection[String]]))).thenReturn(describeResult) + + val result = AdminClientTestUtils.createPartitionsResult(testTopicName, Errors.THROTTLING_QUOTA_EXCEEDED.exception()) + when(adminClient.createPartitions(any(), any())).thenReturn(result) + + val exception = assertThrows(classOf[ExecutionException], + () => topicService.alterTopic(new TopicCommandOptions(Array("--topic", testTopicName, "--partitions", "3")))) + assertTrue(exception.getCause.isInstanceOf[ThrottlingQuotaExceededException]) + + verify(adminClient, 
times(1)).createPartitions( + argThat((_.get(testTopicName).totalCount() == 3): ArgumentMatcher[java.util.Map[String, NewPartitions]]), + argThat((_.shouldRetryOnQuotaViolation() == false): ArgumentMatcher[CreatePartitionsOptions]) + ) + } +} diff --git a/core/src/test/scala/unit/kafka/admin/TopicCommandTest.scala b/core/src/test/scala/unit/kafka/admin/TopicCommandTest.scala new file mode 100644 index 0000000..9586cf5 --- /dev/null +++ b/core/src/test/scala/unit/kafka/admin/TopicCommandTest.scala @@ -0,0 +1,170 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package kafka.admin + +import kafka.admin.TopicCommand.{PartitionDescription, TopicCommandOptions} +import kafka.common.AdminCommandFailedException +import kafka.utils.Exit +import org.apache.kafka.clients.admin.PartitionReassignment +import org.apache.kafka.common.Node +import org.apache.kafka.common.TopicPartitionInfo +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.Test + +import scala.jdk.CollectionConverters._ + +class TopicCommandTest { + + private[this] val brokerList = "localhost:9092" + private[this] val topicName = "topicName" + + @Test + def testIsNotUnderReplicatedWhenAdding(): Unit = { + val replicaIds = List(1, 2) + val replicas = replicaIds.map { id => + new Node(id, "localhost", 9090 + id) + } + + val partitionDescription = PartitionDescription( + "test-topic", + new TopicPartitionInfo( + 0, + new Node(1, "localhost", 9091), + replicas.asJava, + List(new Node(1, "localhost", 9091)).asJava + ), + None, + markedForDeletion = false, + Some( + new PartitionReassignment( + replicaIds.map(id => id: java.lang.Integer).asJava, + List(2: java.lang.Integer).asJava, + List.empty.asJava + ) + ) + ) + + assertFalse(partitionDescription.isUnderReplicated) + } + + @Test + def testAlterWithUnspecifiedPartitionCount(): Unit = { + assertCheckArgsExitCode(1, new TopicCommandOptions( + Array("--bootstrap-server", brokerList ,"--alter", "--topic", topicName))) + } + + @Test + def testConfigOptWithBootstrapServers(): Unit = { + assertCheckArgsExitCode(1, + new TopicCommandOptions(Array("--bootstrap-server", brokerList ,"--alter", "--topic", topicName, "--partitions", "3", "--config", "cleanup.policy=compact"))) + assertCheckArgsExitCode(1, + new TopicCommandOptions(Array("--bootstrap-server", brokerList ,"--alter", "--topic", topicName, "--partitions", "3", "--delete-config", "cleanup.policy"))) + val opts = + new TopicCommandOptions(Array("--bootstrap-server", brokerList ,"--create", "--topic", topicName, "--partitions", "3", "--replication-factor", "3", "--config", "cleanup.policy=compact")) + opts.checkArgs() + assertTrue(opts.hasCreateOption) + assertEquals(brokerList, opts.bootstrapServer.get) + assertEquals("cleanup.policy=compact", 
opts.topicConfig.get.get(0)) + } + + @Test + def testCreateWithPartitionCountWithoutReplicationFactorShouldSucceed(): Unit = { + val opts = new TopicCommandOptions( + Array("--bootstrap-server", brokerList, + "--create", + "--partitions", "2", + "--topic", topicName)) + opts.checkArgs() + } + + @Test + def testCreateWithReplicationFactorWithoutPartitionCountShouldSucceed(): Unit = { + val opts = new TopicCommandOptions( + Array("--bootstrap-server", brokerList, + "--create", + "--replication-factor", "3", + "--topic", topicName)) + opts.checkArgs() + } + + @Test + def testCreateWithAssignmentAndPartitionCount(): Unit = { + assertCheckArgsExitCode(1, + new TopicCommandOptions( + Array("--bootstrap-server", brokerList, + "--create", + "--replica-assignment", "3:0,5:1", + "--partitions", "2", + "--topic", topicName))) + } + + @Test + def testCreateWithAssignmentAndReplicationFactor(): Unit = { + assertCheckArgsExitCode(1, + new TopicCommandOptions( + Array("--bootstrap-server", brokerList, + "--create", + "--replica-assignment", "3:0,5:1", + "--replication-factor", "2", + "--topic", topicName))) + } + + @Test + def testCreateWithoutPartitionCountAndReplicationFactorShouldSucceed(): Unit = { + val opts = new TopicCommandOptions( + Array("--bootstrap-server", brokerList, + "--create", + "--topic", topicName)) + opts.checkArgs() + } + + @Test + def testDescribeShouldSucceed(): Unit = { + val opts = new TopicCommandOptions( + Array("--bootstrap-server", brokerList, + "--describe", + "--topic", topicName)) + opts.checkArgs() + } + + + @Test + def testParseAssignmentDuplicateEntries(): Unit = { + assertThrows(classOf[AdminCommandFailedException], () => TopicCommand.parseReplicaAssignment("5:5")) + } + + @Test + def testParseAssignmentPartitionsOfDifferentSize(): Unit = { + assertThrows(classOf[AdminOperationException], () => TopicCommand.parseReplicaAssignment("5:4:3,2:1")) + } + + @Test + def testParseAssignment(): Unit = { + val actualAssignment = TopicCommand.parseReplicaAssignment("5:4,3:2,1:0") + val expectedAssignment = Map(0 -> List(5, 4), 1 -> List(3, 2), 2 -> List(1, 0)) + assertEquals(expectedAssignment, actualAssignment) + } + + private[this] def assertCheckArgsExitCode(expected: Int, options: TopicCommandOptions): Unit = { + Exit.setExitProcedure { + (exitCode: Int, _: Option[String]) => + assertEquals(expected, exitCode) + throw new RuntimeException + } + try assertThrows(classOf[RuntimeException], () => options.checkArgs()) finally Exit.resetExitProcedure() + } +} diff --git a/core/src/test/scala/unit/kafka/admin/UserScramCredentialsCommandTest.scala b/core/src/test/scala/unit/kafka/admin/UserScramCredentialsCommandTest.scala new file mode 100644 index 0000000..2f823f8 --- /dev/null +++ b/core/src/test/scala/unit/kafka/admin/UserScramCredentialsCommandTest.scala @@ -0,0 +1,137 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package kafka.admin + +import java.io.{ByteArrayOutputStream, PrintStream} +import java.nio.charset.StandardCharsets + +import kafka.server.BaseRequestTest +import kafka.utils.Exit +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.Test + +class UserScramCredentialsCommandTest extends BaseRequestTest { + override def brokerCount = 1 + var exitStatus: Option[Int] = None + var exitMessage: Option[String] = None + + case class ConfigCommandResult(stdout: String, exitStatus: Option[Int] = None) + + private def runConfigCommandViaBroker(args: Array[String]) : ConfigCommandResult = { + val byteArrayOutputStream = new ByteArrayOutputStream() + val utf8 = StandardCharsets.UTF_8.name + val printStream = new PrintStream(byteArrayOutputStream, true, utf8) + var exitStatus: Option[Int] = None + Exit.setExitProcedure { (status, _) => + exitStatus = Some(status) + throw new RuntimeException + } + val commandArgs = Array("--bootstrap-server", brokerList) ++ args + try { + Console.withOut(printStream) { + ConfigCommand.main(commandArgs) + } + ConfigCommandResult(byteArrayOutputStream.toString(utf8)) + } catch { + case e: Exception => { + debug(s"Exception running ConfigCommand ${commandArgs.mkString(" ")}", e) + ConfigCommandResult("", exitStatus) + } + } finally { + printStream.close + Exit.resetExitProcedure() + } + } + + @Test + def testUserScramCredentialsRequests(): Unit = { + val user1 = "user1" + // create and describe a credential + var result = runConfigCommandViaBroker(Array("--user", user1, "--alter", "--add-config", "SCRAM-SHA-256=[iterations=4096,password=foo-secret]")) + val alterConfigsUser1Out = s"Completed updating config for user $user1.\n" + assertEquals(alterConfigsUser1Out, result.stdout) + result = runConfigCommandViaBroker(Array("--user", user1, "--describe")) + val scramCredentialConfigsUser1Out = s"SCRAM credential configs for user-principal '$user1' are SCRAM-SHA-256=iterations=4096\n" + assertEquals(scramCredentialConfigsUser1Out, result.stdout) + // create a user quota and describe the user again + result = runConfigCommandViaBroker(Array("--user", user1, "--alter", "--add-config", "consumer_byte_rate=20000")) + assertEquals(alterConfigsUser1Out, result.stdout) + result = runConfigCommandViaBroker(Array("--user", user1, "--describe")) + val quotaConfigsUser1Out = s"Quota configs for user-principal '$user1' are consumer_byte_rate=20000.0\n" + assertEquals(s"$quotaConfigsUser1Out$scramCredentialConfigsUser1Out", result.stdout) + + // now do the same thing for user2 + val user2 = "user2" + // create and describe a credential + result = runConfigCommandViaBroker(Array("--user", user2, "--alter", "--add-config", "SCRAM-SHA-256=[iterations=4096,password=foo-secret]")) + val alterConfigsUser2Out = s"Completed updating config for user $user2.\n" + assertEquals(alterConfigsUser2Out, result.stdout) + result = runConfigCommandViaBroker(Array("--user", user2, "--describe")) + val scramCredentialConfigsUser2Out = s"SCRAM credential configs for user-principal '$user2' are SCRAM-SHA-256=iterations=4096\n" + assertEquals(scramCredentialConfigsUser2Out, result.stdout) + // create a user quota and describe the user again + result = runConfigCommandViaBroker(Array("--user", user2, "--alter", "--add-config", "consumer_byte_rate=20000")) + assertEquals(alterConfigsUser2Out, result.stdout) + result = runConfigCommandViaBroker(Array("--user", user2, "--describe")) + 
val quotaConfigsUser2Out = s"Quota configs for user-principal '$user2' are consumer_byte_rate=20000.0\n" + assertEquals(s"$quotaConfigsUser2Out$scramCredentialConfigsUser2Out", result.stdout) + + // describe both + result = runConfigCommandViaBroker(Array("--entity-type", "users", "--describe")) + // we don't know the order that quota or scram users come out, so we have 2 possibilities for each, 4 total + val quotaPossibilityAOut = s"$quotaConfigsUser1Out$quotaConfigsUser2Out" + val quotaPossibilityBOut = s"$quotaConfigsUser2Out$quotaConfigsUser1Out" + val scramPossibilityAOut = s"$scramCredentialConfigsUser1Out$scramCredentialConfigsUser2Out" + val scramPossibilityBOut = s"$scramCredentialConfigsUser2Out$scramCredentialConfigsUser1Out" + assertTrue(result.stdout.equals(s"$quotaPossibilityAOut$scramPossibilityAOut") + || result.stdout.equals(s"$quotaPossibilityAOut$scramPossibilityBOut") + || result.stdout.equals(s"$quotaPossibilityBOut$scramPossibilityAOut") + || result.stdout.equals(s"$quotaPossibilityBOut$scramPossibilityBOut")) + + // now delete configs, in opposite order, for user1 and user2, and describe + result = runConfigCommandViaBroker(Array("--user", user1, "--alter", "--delete-config", "consumer_byte_rate")) + assertEquals(alterConfigsUser1Out, result.stdout) + result = runConfigCommandViaBroker(Array("--user", user2, "--alter", "--delete-config", "SCRAM-SHA-256")) + assertEquals(alterConfigsUser2Out, result.stdout) + result = runConfigCommandViaBroker(Array("--entity-type", "users", "--describe")) + assertEquals(s"$quotaConfigsUser2Out$scramCredentialConfigsUser1Out", result.stdout) + + // now delete the rest of the configs, for user1 and user2, and describe + result = runConfigCommandViaBroker(Array("--user", user1, "--alter", "--delete-config", "SCRAM-SHA-256")) + assertEquals(alterConfigsUser1Out, result.stdout) + result = runConfigCommandViaBroker(Array("--user", user2, "--alter", "--delete-config", "consumer_byte_rate")) + assertEquals(alterConfigsUser2Out, result.stdout) + result = runConfigCommandViaBroker(Array("--entity-type", "users", "--describe")) + assertEquals("", result.stdout) + } + + @Test + def testAlterWithEmptyPassword(): Unit = { + val user1 = "user1" + val result = runConfigCommandViaBroker(Array("--user", user1, "--alter", "--add-config", "SCRAM-SHA-256=[iterations=4096,password=]")) + assertTrue(result.exitStatus.isDefined, "Expected System.exit() to be called with an empty password") + assertEquals(1, result.exitStatus.get, "Expected empty password to cause failure with exit status=1") + } + + @Test + def testDescribeUnknownUser(): Unit = { + val unknownUser = "unknownUser" + val result = runConfigCommandViaBroker(Array("--user", unknownUser, "--describe")) + assertTrue(result.exitStatus.isEmpty, "Expected System.exit() to not be called with an unknown user") + assertEquals("", result.stdout) + } +} diff --git a/core/src/test/scala/unit/kafka/api/ApiUtilsTest.scala b/core/src/test/scala/unit/kafka/api/ApiUtilsTest.scala new file mode 100644 index 0000000..001aeb8 --- /dev/null +++ b/core/src/test/scala/unit/kafka/api/ApiUtilsTest.scala @@ -0,0 +1,71 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.api + +import org.junit.jupiter.api._ +import org.junit.jupiter.api.Assertions._ + +import scala.util.Random +import java.nio.ByteBuffer + +import kafka.utils.TestUtils +import org.apache.kafka.common.KafkaException + +object ApiUtilsTest { + val rnd: Random = new Random() +} + +class ApiUtilsTest { + + @Test + def testShortStringNonASCII(): Unit = { + // Random-length strings + for(_ <- 0 to 100) { + // Since we're using UTF-8 encoding, each encoded byte will be one to four bytes long + val s: String = ApiUtilsTest.rnd.nextString(math.abs(ApiUtilsTest.rnd.nextInt()) % (Short.MaxValue / 4)) + val bb: ByteBuffer = ByteBuffer.allocate(ApiUtils.shortStringLength(s)) + ApiUtils.writeShortString(bb, s) + bb.rewind() + assertEquals(s, ApiUtils.readShortString(bb)) + } + } + + @Test + def testShortStringASCII(): Unit = { + // Random-length strings + for(_ <- 0 to 100) { + val s: String = TestUtils.randomString(math.abs(ApiUtilsTest.rnd.nextInt()) % Short.MaxValue) + val bb: ByteBuffer = ByteBuffer.allocate(ApiUtils.shortStringLength(s)) + ApiUtils.writeShortString(bb, s) + bb.rewind() + assertEquals(s, ApiUtils.readShortString(bb)) + } + + // Max size string + val s1: String = TestUtils.randomString(Short.MaxValue) + val bb: ByteBuffer = ByteBuffer.allocate(ApiUtils.shortStringLength(s1)) + ApiUtils.writeShortString(bb, s1) + bb.rewind() + assertEquals(s1, ApiUtils.readShortString(bb)) + + // One byte too big + val s2: String = TestUtils.randomString(Short.MaxValue + 1) + assertThrows(classOf[KafkaException], () => ApiUtils.shortStringLength(s2)) + assertThrows(classOf[KafkaException], () => ApiUtils.writeShortString(bb, s2)) + } +} diff --git a/core/src/test/scala/unit/kafka/api/ApiVersionTest.scala b/core/src/test/scala/unit/kafka/api/ApiVersionTest.scala new file mode 100644 index 0000000..75dd682 --- /dev/null +++ b/core/src/test/scala/unit/kafka/api/ApiVersionTest.scala @@ -0,0 +1,283 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package kafka.api + +import java.util + +import org.apache.kafka.common.feature.{Features, FinalizedVersionRange, SupportedVersionRange} +import org.apache.kafka.common.message.ApiMessageType.ListenerType +import org.apache.kafka.common.protocol.ApiKeys +import org.apache.kafka.common.record.{RecordBatch, RecordVersion} +import org.apache.kafka.common.requests.{AbstractResponse, ApiVersionsResponse} +import org.apache.kafka.common.utils.Utils +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.Test + +import scala.jdk.CollectionConverters._ + +class ApiVersionTest { + + @Test + def testApply(): Unit = { + assertEquals(KAFKA_0_8_0, ApiVersion("0.8.0")) + assertEquals(KAFKA_0_8_0, ApiVersion("0.8.0.0")) + assertEquals(KAFKA_0_8_0, ApiVersion("0.8.0.1")) + + assertEquals(KAFKA_0_8_1, ApiVersion("0.8.1")) + assertEquals(KAFKA_0_8_1, ApiVersion("0.8.1.0")) + assertEquals(KAFKA_0_8_1, ApiVersion("0.8.1.1")) + + assertEquals(KAFKA_0_8_2, ApiVersion("0.8.2")) + assertEquals(KAFKA_0_8_2, ApiVersion("0.8.2.0")) + assertEquals(KAFKA_0_8_2, ApiVersion("0.8.2.1")) + + assertEquals(KAFKA_0_9_0, ApiVersion("0.9.0")) + assertEquals(KAFKA_0_9_0, ApiVersion("0.9.0.0")) + assertEquals(KAFKA_0_9_0, ApiVersion("0.9.0.1")) + + assertEquals(KAFKA_0_10_0_IV0, ApiVersion("0.10.0-IV0")) + + assertEquals(KAFKA_0_10_0_IV1, ApiVersion("0.10.0")) + assertEquals(KAFKA_0_10_0_IV1, ApiVersion("0.10.0.0")) + assertEquals(KAFKA_0_10_0_IV1, ApiVersion("0.10.0.0-IV0")) + assertEquals(KAFKA_0_10_0_IV1, ApiVersion("0.10.0.1")) + + assertEquals(KAFKA_0_10_1_IV0, ApiVersion("0.10.1-IV0")) + assertEquals(KAFKA_0_10_1_IV1, ApiVersion("0.10.1-IV1")) + + assertEquals(KAFKA_0_10_1_IV2, ApiVersion("0.10.1")) + assertEquals(KAFKA_0_10_1_IV2, ApiVersion("0.10.1.0")) + assertEquals(KAFKA_0_10_1_IV2, ApiVersion("0.10.1-IV2")) + assertEquals(KAFKA_0_10_1_IV2, ApiVersion("0.10.1.1")) + + assertEquals(KAFKA_0_10_2_IV0, ApiVersion("0.10.2")) + assertEquals(KAFKA_0_10_2_IV0, ApiVersion("0.10.2.0")) + assertEquals(KAFKA_0_10_2_IV0, ApiVersion("0.10.2-IV0")) + assertEquals(KAFKA_0_10_2_IV0, ApiVersion("0.10.2.1")) + + assertEquals(KAFKA_0_11_0_IV0, ApiVersion("0.11.0-IV0")) + assertEquals(KAFKA_0_11_0_IV1, ApiVersion("0.11.0-IV1")) + + assertEquals(KAFKA_0_11_0_IV2, ApiVersion("0.11.0")) + assertEquals(KAFKA_0_11_0_IV2, ApiVersion("0.11.0.0")) + assertEquals(KAFKA_0_11_0_IV2, ApiVersion("0.11.0-IV2")) + assertEquals(KAFKA_0_11_0_IV2, ApiVersion("0.11.0.1")) + + assertEquals(KAFKA_1_0_IV0, ApiVersion("1.0")) + assertEquals(KAFKA_1_0_IV0, ApiVersion("1.0.0")) + assertEquals(KAFKA_1_0_IV0, ApiVersion("1.0.0-IV0")) + assertEquals(KAFKA_1_0_IV0, ApiVersion("1.0.1")) + + assertEquals(KAFKA_1_1_IV0, ApiVersion("1.1-IV0")) + + assertEquals(KAFKA_2_0_IV1, ApiVersion("2.0")) + assertEquals(KAFKA_2_0_IV0, ApiVersion("2.0-IV0")) + assertEquals(KAFKA_2_0_IV1, ApiVersion("2.0-IV1")) + + assertEquals(KAFKA_2_1_IV2, ApiVersion("2.1")) + assertEquals(KAFKA_2_1_IV0, ApiVersion("2.1-IV0")) + assertEquals(KAFKA_2_1_IV1, ApiVersion("2.1-IV1")) + assertEquals(KAFKA_2_1_IV2, ApiVersion("2.1-IV2")) + + assertEquals(KAFKA_2_2_IV1, ApiVersion("2.2")) + assertEquals(KAFKA_2_2_IV0, ApiVersion("2.2-IV0")) + assertEquals(KAFKA_2_2_IV1, ApiVersion("2.2-IV1")) + + assertEquals(KAFKA_2_3_IV1, ApiVersion("2.3")) + assertEquals(KAFKA_2_3_IV0, ApiVersion("2.3-IV0")) + assertEquals(KAFKA_2_3_IV1, ApiVersion("2.3-IV1")) + + assertEquals(KAFKA_2_4_IV1, ApiVersion("2.4")) + assertEquals(KAFKA_2_4_IV0, ApiVersion("2.4-IV0")) + assertEquals(KAFKA_2_4_IV1, 
ApiVersion("2.4-IV1")) + + assertEquals(KAFKA_2_5_IV0, ApiVersion("2.5")) + assertEquals(KAFKA_2_5_IV0, ApiVersion("2.5-IV0")) + + assertEquals(KAFKA_2_6_IV0, ApiVersion("2.6")) + assertEquals(KAFKA_2_6_IV0, ApiVersion("2.6-IV0")) + + assertEquals(KAFKA_2_7_IV0, ApiVersion("2.7-IV0")) + assertEquals(KAFKA_2_7_IV1, ApiVersion("2.7-IV1")) + assertEquals(KAFKA_2_7_IV2, ApiVersion("2.7-IV2")) + + assertEquals(KAFKA_2_8_IV1, ApiVersion("2.8")) + assertEquals(KAFKA_2_8_IV0, ApiVersion("2.8-IV0")) + assertEquals(KAFKA_2_8_IV1, ApiVersion("2.8-IV1")) + + assertEquals(KAFKA_3_0_IV1, ApiVersion("3.0")) + assertEquals(KAFKA_3_0_IV0, ApiVersion("3.0-IV0")) + assertEquals(KAFKA_3_0_IV1, ApiVersion("3.0-IV1")) + + assertEquals(KAFKA_3_1_IV0, ApiVersion("3.1")) + assertEquals(KAFKA_3_1_IV0, ApiVersion("3.1-IV0")) + } + + @Test + def testApiVersionUniqueIds(): Unit = { + val allIds: Seq[Int] = ApiVersion.allVersions.map(apiVersion => { + apiVersion.id + }) + + val uniqueIds: Set[Int] = allIds.toSet + + assertEquals(allIds.size, uniqueIds.size) + } + + @Test + def testMinSupportedVersionFor(): Unit = { + assertEquals(KAFKA_0_8_0, ApiVersion.minSupportedFor(RecordVersion.V0)) + assertEquals(KAFKA_0_10_0_IV0, ApiVersion.minSupportedFor(RecordVersion.V1)) + assertEquals(KAFKA_0_11_0_IV0, ApiVersion.minSupportedFor(RecordVersion.V2)) + + // Ensure that all record versions have a defined min version so that we remember to update the method + for (recordVersion <- RecordVersion.values) + assertNotNull(ApiVersion.minSupportedFor(recordVersion)) + } + + @Test + def testShortVersion(): Unit = { + assertEquals("0.8.0", KAFKA_0_8_0.shortVersion) + assertEquals("0.10.0", KAFKA_0_10_0_IV0.shortVersion) + assertEquals("0.10.0", KAFKA_0_10_0_IV1.shortVersion) + assertEquals("0.11.0", KAFKA_0_11_0_IV0.shortVersion) + assertEquals("0.11.0", KAFKA_0_11_0_IV1.shortVersion) + assertEquals("0.11.0", KAFKA_0_11_0_IV2.shortVersion) + assertEquals("1.0", KAFKA_1_0_IV0.shortVersion) + assertEquals("1.1", KAFKA_1_1_IV0.shortVersion) + assertEquals("2.0", KAFKA_2_0_IV0.shortVersion) + assertEquals("2.0", KAFKA_2_0_IV1.shortVersion) + assertEquals("2.1", KAFKA_2_1_IV0.shortVersion) + assertEquals("2.1", KAFKA_2_1_IV1.shortVersion) + assertEquals("2.1", KAFKA_2_1_IV2.shortVersion) + assertEquals("2.2", KAFKA_2_2_IV0.shortVersion) + assertEquals("2.2", KAFKA_2_2_IV1.shortVersion) + assertEquals("2.3", KAFKA_2_3_IV0.shortVersion) + assertEquals("2.3", KAFKA_2_3_IV1.shortVersion) + assertEquals("2.4", KAFKA_2_4_IV0.shortVersion) + assertEquals("2.5", KAFKA_2_5_IV0.shortVersion) + assertEquals("2.6", KAFKA_2_6_IV0.shortVersion) + assertEquals("2.7", KAFKA_2_7_IV2.shortVersion) + assertEquals("2.8", KAFKA_2_8_IV0.shortVersion) + assertEquals("2.8", KAFKA_2_8_IV1.shortVersion) + assertEquals("3.0", KAFKA_3_0_IV0.shortVersion) + assertEquals("3.0", KAFKA_3_0_IV1.shortVersion) + assertEquals("3.1", KAFKA_3_1_IV0.shortVersion) + } + + @Test + def testApiVersionValidator(): Unit = { + val str = ApiVersionValidator.toString + val apiVersions = str.slice(1, str.length).split(",") + assertEquals(ApiVersion.allVersions.size, apiVersions.length) + } + + @Test + def shouldCreateApiResponseOnlyWithKeysSupportedByMagicValue(): Unit = { + val response = ApiVersion.apiVersionsResponse( + 10, + RecordVersion.V1, + Features.emptySupportedFeatures, + None, + ListenerType.ZK_BROKER + ) + verifyApiKeysForMagic(response, RecordBatch.MAGIC_VALUE_V1) + assertEquals(10, response.throttleTimeMs) + assertTrue(response.data.supportedFeatures.isEmpty) + 
assertTrue(response.data.finalizedFeatures.isEmpty) + assertEquals(ApiVersionsResponse.UNKNOWN_FINALIZED_FEATURES_EPOCH, response.data.finalizedFeaturesEpoch) + } + + @Test + def shouldReturnFeatureKeysWhenMagicIsCurrentValueAndThrottleMsIsDefaultThrottle(): Unit = { + val response = ApiVersion.apiVersionsResponse( + 10, + RecordVersion.V1, + Features.supportedFeatures( + Utils.mkMap(Utils.mkEntry("feature", new SupportedVersionRange(1.toShort, 4.toShort)))), + Features.finalizedFeatures( + Utils.mkMap(Utils.mkEntry("feature", new FinalizedVersionRange(2.toShort, 3.toShort)))), + 10, + None, + ListenerType.ZK_BROKER + ) + + verifyApiKeysForMagic(response, RecordBatch.MAGIC_VALUE_V1) + assertEquals(10, response.throttleTimeMs) + assertEquals(1, response.data.supportedFeatures.size) + val sKey = response.data.supportedFeatures.find("feature") + assertNotNull(sKey) + assertEquals(1, sKey.minVersion) + assertEquals(4, sKey.maxVersion) + assertEquals(1, response.data.finalizedFeatures.size) + val fKey = response.data.finalizedFeatures.find("feature") + assertNotNull(fKey) + assertEquals(2, fKey.minVersionLevel) + assertEquals(3, fKey.maxVersionLevel) + assertEquals(10, response.data.finalizedFeaturesEpoch) + } + + private def verifyApiKeysForMagic(response: ApiVersionsResponse, maxMagic: Byte): Unit = { + for (version <- response.data.apiKeys.asScala) { + assertTrue(ApiKeys.forId(version.apiKey).minRequiredInterBrokerMagic <= maxMagic) + } + } + + @Test + def shouldReturnAllKeysWhenMagicIsCurrentValueAndThrottleMsIsDefaultThrottle(): Unit = { + val response = ApiVersion.apiVersionsResponse( + AbstractResponse.DEFAULT_THROTTLE_TIME, + RecordVersion.current(), + Features.emptySupportedFeatures, + None, + ListenerType.ZK_BROKER + ) + assertEquals(new util.HashSet[ApiKeys](ApiKeys.zkBrokerApis), apiKeysInResponse(response)) + assertEquals(AbstractResponse.DEFAULT_THROTTLE_TIME, response.throttleTimeMs) + assertTrue(response.data.supportedFeatures.isEmpty) + assertTrue(response.data.finalizedFeatures.isEmpty) + assertEquals(ApiVersionsResponse.UNKNOWN_FINALIZED_FEATURES_EPOCH, response.data.finalizedFeaturesEpoch) + } + + @Test + def testMetadataQuorumApisAreDisabled(): Unit = { + val response = ApiVersion.apiVersionsResponse( + AbstractResponse.DEFAULT_THROTTLE_TIME, + RecordVersion.current(), + Features.emptySupportedFeatures, + None, + ListenerType.ZK_BROKER + ) + + // Ensure that APIs needed for the KRaft mode are not exposed through ApiVersions until we are ready for them + val exposedApis = apiKeysInResponse(response) + assertFalse(exposedApis.contains(ApiKeys.ENVELOPE)) + assertFalse(exposedApis.contains(ApiKeys.VOTE)) + assertFalse(exposedApis.contains(ApiKeys.BEGIN_QUORUM_EPOCH)) + assertFalse(exposedApis.contains(ApiKeys.END_QUORUM_EPOCH)) + assertFalse(exposedApis.contains(ApiKeys.DESCRIBE_QUORUM)) + } + + private def apiKeysInResponse(apiVersions: ApiVersionsResponse) = { + val apiKeys = new util.HashSet[ApiKeys] + for (version <- apiVersions.data.apiKeys.asScala) { + apiKeys.add(ApiKeys.forId(version.apiKey)) + } + apiKeys + } +} diff --git a/core/src/test/scala/unit/kafka/cluster/AbstractPartitionTest.scala b/core/src/test/scala/unit/kafka/cluster/AbstractPartitionTest.scala new file mode 100644 index 0000000..887f16c --- /dev/null +++ b/core/src/test/scala/unit/kafka/cluster/AbstractPartitionTest.scala @@ -0,0 +1,146 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package kafka.cluster + +import kafka.api.ApiVersion +import kafka.log.{CleanerConfig, LogConfig, LogManager} +import kafka.server.{Defaults, MetadataCache} +import kafka.server.checkpoints.OffsetCheckpoints +import kafka.server.metadata.MockConfigRepository +import kafka.utils.TestUtils.{MockAlterIsrManager, MockIsrChangeListener} +import kafka.utils.{MockTime, TestUtils} +import org.apache.kafka.common.TopicPartition +import org.apache.kafka.common.message.LeaderAndIsrRequestData.LeaderAndIsrPartitionState +import org.apache.kafka.common.utils.Utils +import org.junit.jupiter.api.Assertions.{assertEquals, assertTrue} +import org.junit.jupiter.api.{AfterEach, BeforeEach} +import org.mockito.ArgumentMatchers +import org.mockito.Mockito.{mock, when} + +import java.io.File +import java.util.Properties + +import scala.jdk.CollectionConverters._ + +object AbstractPartitionTest { + val brokerId = 101 +} + +class AbstractPartitionTest { + + val brokerId = AbstractPartitionTest.brokerId + val topicPartition = new TopicPartition("test-topic", 0) + val time = new MockTime() + var tmpDir: File = _ + var logDir1: File = _ + var logDir2: File = _ + var logManager: LogManager = _ + var alterIsrManager: MockAlterIsrManager = _ + var isrChangeListener: MockIsrChangeListener = _ + var logConfig: LogConfig = _ + var configRepository: MockConfigRepository = _ + val delayedOperations: DelayedOperations = mock(classOf[DelayedOperations]) + val metadataCache: MetadataCache = mock(classOf[MetadataCache]) + val offsetCheckpoints: OffsetCheckpoints = mock(classOf[OffsetCheckpoints]) + var partition: Partition = _ + + @BeforeEach + def setup(): Unit = { + TestUtils.clearYammerMetrics() + + val logProps = createLogProperties(Map.empty) + logConfig = LogConfig(logProps) + configRepository = MockConfigRepository.forTopic(topicPartition.topic(), logProps) + + tmpDir = TestUtils.tempDir() + logDir1 = TestUtils.randomPartitionLogDir(tmpDir) + logDir2 = TestUtils.randomPartitionLogDir(tmpDir) + logManager = TestUtils.createLogManager(Seq(logDir1, logDir2), logConfig, configRepository, + CleanerConfig(enableCleaner = false), time, interBrokerProtocolVersion) + logManager.startup(Set.empty) + + alterIsrManager = TestUtils.createAlterIsrManager() + isrChangeListener = TestUtils.createIsrChangeListener() + partition = new Partition(topicPartition, + replicaLagTimeMaxMs = Defaults.ReplicaLagTimeMaxMs, + interBrokerProtocolVersion = interBrokerProtocolVersion, + localBrokerId = brokerId, + time, + isrChangeListener, + delayedOperations, + metadataCache, + logManager, + alterIsrManager) + + when(offsetCheckpoints.fetch(ArgumentMatchers.anyString, ArgumentMatchers.eq(topicPartition))) + .thenReturn(None) + } + + protected def interBrokerProtocolVersion: ApiVersion = ApiVersion.latestVersion + + def createLogProperties(overrides: Map[String, 
String]): Properties = { + val logProps = new Properties() + logProps.put(LogConfig.SegmentBytesProp, 512: java.lang.Integer) + logProps.put(LogConfig.SegmentIndexBytesProp, 1000: java.lang.Integer) + logProps.put(LogConfig.RetentionMsProp, 999: java.lang.Integer) + overrides.foreach { case (k, v) => logProps.put(k, v) } + logProps + } + + @AfterEach + def tearDown(): Unit = { + if (tmpDir.exists()) { + logManager.shutdown() + Utils.delete(tmpDir) + TestUtils.clearYammerMetrics() + } + } + + protected def setupPartitionWithMocks(leaderEpoch: Int, + isLeader: Boolean): Partition = { + partition.createLogIfNotExists(isNew = false, isFutureReplica = false, offsetCheckpoints, None) + + val controllerEpoch = 0 + val replicas = List[Integer](brokerId, brokerId + 1).asJava + val isr = replicas + + if (isLeader) { + assertTrue(partition.makeLeader(new LeaderAndIsrPartitionState() + .setControllerEpoch(controllerEpoch) + .setLeader(brokerId) + .setLeaderEpoch(leaderEpoch) + .setIsr(isr) + .setZkVersion(1) + .setReplicas(replicas) + .setIsNew(true), offsetCheckpoints, None), "Expected become leader transition to succeed") + assertEquals(leaderEpoch, partition.getLeaderEpoch) + } else { + assertTrue(partition.makeFollower(new LeaderAndIsrPartitionState() + .setControllerEpoch(controllerEpoch) + .setLeader(brokerId + 1) + .setLeaderEpoch(leaderEpoch) + .setIsr(isr) + .setZkVersion(1) + .setReplicas(replicas) + .setIsNew(true), offsetCheckpoints, None), "Expected become follower transition to succeed") + assertEquals(leaderEpoch, partition.getLeaderEpoch) + assertEquals(None, partition.leaderLogIfLocal) + } + + partition + } +} diff --git a/core/src/test/scala/unit/kafka/cluster/AssignmentStateTest.scala b/core/src/test/scala/unit/kafka/cluster/AssignmentStateTest.scala new file mode 100644 index 0000000..a618825 --- /dev/null +++ b/core/src/test/scala/unit/kafka/cluster/AssignmentStateTest.scala @@ -0,0 +1,122 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package kafka.cluster + +import org.apache.kafka.common.message.LeaderAndIsrRequestData.LeaderAndIsrPartitionState +import org.junit.jupiter.api.Assertions.{assertEquals, assertFalse, assertTrue} +import org.junit.jupiter.params.ParameterizedTest +import org.junit.jupiter.params.provider.{Arguments, MethodSource} + +import scala.jdk.CollectionConverters._ + +object AssignmentStateTest { + import AbstractPartitionTest._ + + def parameters: java.util.stream.Stream[Arguments] = Seq[Arguments]( + Arguments.of( + List[Integer](brokerId, brokerId + 1, brokerId + 2), + List[Integer](brokerId, brokerId + 1, brokerId + 2), + List.empty[Integer], List.empty[Integer], Seq.empty[Int], Boolean.box(false)), + Arguments.of( + List[Integer](brokerId, brokerId + 1), + List[Integer](brokerId, brokerId + 1, brokerId + 2), + List.empty[Integer], List.empty[Integer], Seq.empty[Int], Boolean.box(true)), + Arguments.of( + List[Integer](brokerId, brokerId + 1, brokerId + 2), + List[Integer](brokerId, brokerId + 1, brokerId + 2), + List[Integer](brokerId + 3, brokerId + 4), + List[Integer](brokerId + 1), + Seq(brokerId, brokerId + 1, brokerId + 2), Boolean.box(false)), + Arguments.of( + List[Integer](brokerId, brokerId + 1, brokerId + 2), + List[Integer](brokerId, brokerId + 1, brokerId + 2), + List[Integer](brokerId + 3, brokerId + 4), + List.empty[Integer], + Seq(brokerId, brokerId + 1, brokerId + 2), Boolean.box(false)), + Arguments.of( + List[Integer](brokerId, brokerId + 1, brokerId + 2), + List[Integer](brokerId, brokerId + 1, brokerId + 2), + List.empty[Integer], + List[Integer](brokerId + 1), + Seq(brokerId, brokerId + 1, brokerId + 2), Boolean.box(false)), + Arguments.of( + List[Integer](brokerId + 1, brokerId + 2), + List[Integer](brokerId + 1, brokerId + 2), + List[Integer](brokerId), + List.empty[Integer], + Seq(brokerId + 1, brokerId + 2), Boolean.box(false)), + Arguments.of( + List[Integer](brokerId + 2, brokerId + 3, brokerId + 4), + List[Integer](brokerId, brokerId + 1, brokerId + 2), + List[Integer](brokerId + 3, brokerId + 4, brokerId + 5), + List.empty[Integer], + Seq(brokerId, brokerId + 1, brokerId + 2), Boolean.box(false)), + Arguments.of( + List[Integer](brokerId + 2, brokerId + 3, brokerId + 4), + List[Integer](brokerId, brokerId + 1, brokerId + 2), + List[Integer](brokerId + 3, brokerId + 4, brokerId + 5), + List.empty[Integer], + Seq(brokerId, brokerId + 1, brokerId + 2), Boolean.box(false)), + Arguments.of( + List[Integer](brokerId + 2, brokerId + 3), + List[Integer](brokerId, brokerId + 1, brokerId + 2), + List[Integer](brokerId + 3, brokerId + 4, brokerId + 5), + List.empty[Integer], + Seq(brokerId, brokerId + 1, brokerId + 2), Boolean.box(true)) + ).asJava.stream() +} + +class AssignmentStateTest extends AbstractPartitionTest { + + @ParameterizedTest + @MethodSource(Array("parameters")) + def testPartitionAssignmentStatus(isr: List[Integer], replicas: List[Integer], + adding: List[Integer], removing: List[Integer], + original: Seq[Int], isUnderReplicated: Boolean): Unit = { + val controllerEpoch = 3 + + val leaderState = new LeaderAndIsrPartitionState() + .setControllerEpoch(controllerEpoch) + .setLeader(brokerId) + .setLeaderEpoch(6) + .setIsr(isr.asJava) + .setZkVersion(1) + .setReplicas(replicas.asJava) + .setIsNew(false) + if (adding.nonEmpty) + leaderState.setAddingReplicas(adding.asJava) + if (removing.nonEmpty) + leaderState.setRemovingReplicas(removing.asJava) + + val isReassigning = adding.nonEmpty || removing.nonEmpty + + // set the original replicas as the URP 
calculation will need them + if (original.nonEmpty) + partition.assignmentState = SimpleAssignmentState(original) + // do the test + partition.makeLeader(leaderState, offsetCheckpoints, None) + assertEquals(isReassigning, partition.isReassigning) + if (adding.nonEmpty) + adding.foreach(r => assertTrue(partition.isAddingReplica(r))) + if (adding.contains(brokerId)) + assertTrue(partition.isAddingLocalReplica) + else + assertFalse(partition.isAddingLocalReplica) + + assertEquals(isUnderReplicated, partition.isUnderReplicated) + } +} diff --git a/core/src/test/scala/unit/kafka/cluster/BrokerEndPointTest.scala b/core/src/test/scala/unit/kafka/cluster/BrokerEndPointTest.scala new file mode 100644 index 0000000..f36ce9b --- /dev/null +++ b/core/src/test/scala/unit/kafka/cluster/BrokerEndPointTest.scala @@ -0,0 +1,271 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.cluster + +import java.nio.charset.StandardCharsets + +import kafka.zk.BrokerIdZNode +import org.apache.kafka.common.feature.{Features, SupportedVersionRange} +import org.apache.kafka.common.feature.Features._ +import org.apache.kafka.common.network.ListenerName +import org.apache.kafka.common.security.auth.SecurityProtocol +import org.junit.jupiter.api.Assertions.{assertEquals, assertNotEquals, assertNull} +import org.junit.jupiter.api.Test + +import scala.jdk.CollectionConverters._ + +class BrokerEndPointTest { + + @Test + def testHashAndEquals(): Unit = { + val broker1 = new BrokerEndPoint(1, "myhost", 9092) + val broker2 = new BrokerEndPoint(1, "myhost", 9092) + val broker3 = new BrokerEndPoint(2, "myhost", 1111) + val broker4 = new BrokerEndPoint(1, "other", 1111) + + assertEquals(broker1, broker2) + assertNotEquals(broker1, broker3) + assertNotEquals(broker1, broker4) + assertEquals(broker1.hashCode, broker2.hashCode) + assertNotEquals(broker1.hashCode, broker3.hashCode) + assertNotEquals(broker1.hashCode, broker4.hashCode) + + assertEquals(Some(1), Map(broker1 -> 1).get(broker1)) + } + + @Test + def testFromJsonFutureVersion(): Unit = { + // Future compatible versions should be supported, we use a hypothetical future version here + val brokerInfoStr = """{ + "foo":"bar", + "version":100, + "host":"localhost", + "port":9092, + "jmx_port":9999, + "timestamp":"1416974968782", + "endpoints":["SSL://localhost:9093"] + }""" + val broker = parseBrokerJson(1, brokerInfoStr) + assertEquals(1, broker.id) + val brokerEndPoint = broker.brokerEndPoint(ListenerName.forSecurityProtocol(SecurityProtocol.SSL)) + assertEquals("localhost", brokerEndPoint.host) + assertEquals(9093, brokerEndPoint.port) + } + + @Test + def testFromJsonV2(): Unit = { + val brokerInfoStr = """{ + "version":2, + "host":"localhost", + "port":9092, + "jmx_port":9999, + "timestamp":"1416974968782", + 
"endpoints":["PLAINTEXT://localhost:9092"] + }""" + val broker = parseBrokerJson(1, brokerInfoStr) + assertEquals(1, broker.id) + val brokerEndPoint = broker.brokerEndPoint(ListenerName.forSecurityProtocol(SecurityProtocol.PLAINTEXT)) + assertEquals("localhost", brokerEndPoint.host) + assertEquals(9092, brokerEndPoint.port) + } + + @Test + def testFromJsonV1(): Unit = { + val brokerInfoStr = """{"jmx_port":-1,"timestamp":"1420485325400","host":"172.16.8.243","version":1,"port":9091}""" + val broker = parseBrokerJson(1, brokerInfoStr) + assertEquals(1, broker.id) + val brokerEndPoint = broker.brokerEndPoint(ListenerName.forSecurityProtocol(SecurityProtocol.PLAINTEXT)) + assertEquals("172.16.8.243", brokerEndPoint.host) + assertEquals(9091, brokerEndPoint.port) + } + + @Test + def testFromJsonV3(): Unit = { + val json = """{ + "version":3, + "host":"localhost", + "port":9092, + "jmx_port":9999, + "timestamp":"2233345666", + "endpoints":["PLAINTEXT://host1:9092", "SSL://host1:9093"], + "rack":"dc1" + }""" + val broker = parseBrokerJson(1, json) + assertEquals(1, broker.id) + val brokerEndPoint = broker.brokerEndPoint(ListenerName.forSecurityProtocol(SecurityProtocol.SSL)) + assertEquals("host1", brokerEndPoint.host) + assertEquals(9093, brokerEndPoint.port) + assertEquals(Some("dc1"), broker.rack) + } + + @Test + def testFromJsonV4WithNullRack(): Unit = { + val json = """{ + "version":4, + "host":"localhost", + "port":9092, + "jmx_port":9999, + "timestamp":"2233345666", + "endpoints":["CLIENT://host1:9092", "REPLICATION://host1:9093"], + "listener_security_protocol_map":{"CLIENT":"SSL", "REPLICATION":"PLAINTEXT"}, + "rack":null + }""" + val broker = parseBrokerJson(1, json) + assertEquals(1, broker.id) + val brokerEndPoint = broker.brokerEndPoint(new ListenerName("CLIENT")) + assertEquals("host1", brokerEndPoint.host) + assertEquals(9092, brokerEndPoint.port) + assertEquals(None, broker.rack) + } + + @Test + def testFromJsonV4WithNoRack(): Unit = { + val json = """{ + "version":4, + "host":"localhost", + "port":9092, + "jmx_port":9999, + "timestamp":"2233345666", + "endpoints":["CLIENT://host1:9092", "REPLICATION://host1:9093"], + "listener_security_protocol_map":{"CLIENT":"SSL", "REPLICATION":"PLAINTEXT"} + }""" + val broker = parseBrokerJson(1, json) + assertEquals(1, broker.id) + val brokerEndPoint = broker.brokerEndPoint(new ListenerName("CLIENT")) + assertEquals("host1", brokerEndPoint.host) + assertEquals(9092, brokerEndPoint.port) + assertEquals(None, broker.rack) + } + + @Test + def testFromJsonV4WithNoFeatures(): Unit = { + val json = """{ + "version":4, + "host":"localhost", + "port":9092, + "jmx_port":9999, + "timestamp":"2233345666", + "endpoints":["CLIENT://host1:9092", "REPLICATION://host1:9093"], + "listener_security_protocol_map":{"CLIENT":"SSL", "REPLICATION":"PLAINTEXT"}, + "rack":"dc1" + }""" + val broker = parseBrokerJson(1, json) + assertEquals(1, broker.id) + val brokerEndPoint = broker.brokerEndPoint(new ListenerName("CLIENT")) + assertEquals("host1", brokerEndPoint.host) + assertEquals(9092, brokerEndPoint.port) + assertEquals(Some("dc1"), broker.rack) + assertEquals(emptySupportedFeatures, broker.features) + } + + @Test + def testFromJsonV5(): Unit = { + val json = """{ + "version":5, + "host":"localhost", + "port":9092, + "jmx_port":9999, + "timestamp":"2233345666", + "endpoints":["CLIENT://host1:9092", "REPLICATION://host1:9093"], + "listener_security_protocol_map":{"CLIENT":"SSL", "REPLICATION":"PLAINTEXT"}, + "rack":"dc1", + "features": {"feature1": 
{"min_version": 1, "max_version": 2}, "feature2": {"min_version": 2, "max_version": 4}} + }""" + val broker = parseBrokerJson(1, json) + assertEquals(1, broker.id) + val brokerEndPoint = broker.brokerEndPoint(new ListenerName("CLIENT")) + assertEquals("host1", brokerEndPoint.host) + assertEquals(9092, brokerEndPoint.port) + assertEquals(Some("dc1"), broker.rack) + assertEquals(Features.supportedFeatures( + Map[String, SupportedVersionRange]( + "feature1" -> new SupportedVersionRange(1, 2), + "feature2" -> new SupportedVersionRange(2, 4)).asJava), + broker.features) + } + + @Test + def testBrokerEndpointFromUri(): Unit = { + var connectionString = "localhost:9092" + var endpoint = BrokerEndPoint.createBrokerEndPoint(1, connectionString) + assertEquals("localhost", endpoint.host) + assertEquals(9092, endpoint.port) + //KAFKA-3719 + connectionString = "local_host:9092" + endpoint = BrokerEndPoint.createBrokerEndPoint(1, connectionString) + assertEquals("local_host", endpoint.host) + assertEquals(9092, endpoint.port) + // also test for ipv6 + connectionString = "[::1]:9092" + endpoint = BrokerEndPoint.createBrokerEndPoint(1, connectionString) + assertEquals("::1", endpoint.host) + assertEquals(9092, endpoint.port) + // test for ipv6 with % character + connectionString = "[fe80::b1da:69ca:57f7:63d8%3]:9092" + endpoint = BrokerEndPoint.createBrokerEndPoint(1, connectionString) + assertEquals("fe80::b1da:69ca:57f7:63d8%3", endpoint.host) + assertEquals(9092, endpoint.port) + // add test for uppercase in hostname + connectionString = "MyHostname:9092" + endpoint = BrokerEndPoint.createBrokerEndPoint(1, connectionString) + assertEquals("MyHostname", endpoint.host) + assertEquals(9092, endpoint.port) + } + + @Test + def testEndpointFromUri(): Unit = { + var connectionString = "PLAINTEXT://localhost:9092" + var endpoint = EndPoint.createEndPoint(connectionString, None) + assertEquals("localhost", endpoint.host) + assertEquals(9092, endpoint.port) + assertEquals("PLAINTEXT://localhost:9092", endpoint.connectionString) + // KAFKA-3719 + connectionString = "PLAINTEXT://local_host:9092" + endpoint = EndPoint.createEndPoint(connectionString, None) + assertEquals("local_host", endpoint.host) + assertEquals(9092, endpoint.port) + assertEquals("PLAINTEXT://local_host:9092", endpoint.connectionString) + // also test for default bind + connectionString = "PLAINTEXT://:9092" + endpoint = EndPoint.createEndPoint(connectionString, None) + assertNull(endpoint.host) + assertEquals(9092, endpoint.port) + assertEquals( "PLAINTEXT://:9092", endpoint.connectionString) + // also test for ipv6 + connectionString = "PLAINTEXT://[::1]:9092" + endpoint = EndPoint.createEndPoint(connectionString, None) + assertEquals("::1", endpoint.host) + assertEquals(9092, endpoint.port) + assertEquals("PLAINTEXT://[::1]:9092", endpoint.connectionString) + // test for ipv6 with % character + connectionString = "PLAINTEXT://[fe80::b1da:69ca:57f7:63d8%3]:9092" + endpoint = EndPoint.createEndPoint(connectionString, None) + assertEquals("fe80::b1da:69ca:57f7:63d8%3", endpoint.host) + assertEquals(9092, endpoint.port) + assertEquals("PLAINTEXT://[fe80::b1da:69ca:57f7:63d8%3]:9092", endpoint.connectionString) + // test hostname + connectionString = "PLAINTEXT://MyHostname:9092" + endpoint = EndPoint.createEndPoint(connectionString, None) + assertEquals("MyHostname", endpoint.host) + assertEquals(9092, endpoint.port) + assertEquals("PLAINTEXT://MyHostname:9092", endpoint.connectionString) + } + + private def parseBrokerJson(id: Int, jsonString: 
String): Broker = + BrokerIdZNode.decode(id, jsonString.getBytes(StandardCharsets.UTF_8)).broker +} diff --git a/core/src/test/scala/unit/kafka/cluster/PartitionLockTest.scala b/core/src/test/scala/unit/kafka/cluster/PartitionLockTest.scala new file mode 100644 index 0000000..9e5441d --- /dev/null +++ b/core/src/test/scala/unit/kafka/cluster/PartitionLockTest.scala @@ -0,0 +1,398 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.cluster + +import java.util.Properties +import java.util.concurrent._ +import java.util.concurrent.atomic.AtomicBoolean + +import kafka.api.{ApiVersion, LeaderAndIsr} +import kafka.log._ +import kafka.server._ +import kafka.server.checkpoints.OffsetCheckpoints +import kafka.server.epoch.LeaderEpochFileCache +import kafka.server.metadata.MockConfigRepository +import kafka.utils._ +import org.apache.kafka.common.message.LeaderAndIsrRequestData.LeaderAndIsrPartitionState +import org.apache.kafka.common.record.{MemoryRecords, SimpleRecord} +import org.apache.kafka.common.utils.Utils +import org.apache.kafka.common.{TopicPartition, Uuid} +import org.junit.jupiter.api.Assertions.{assertEquals, assertFalse, assertTrue} +import org.junit.jupiter.api.{AfterEach, BeforeEach, Test} +import org.mockito.ArgumentMatchers +import org.mockito.Mockito.{mock, when} + +import scala.concurrent.duration._ +import scala.jdk.CollectionConverters._ + +/** + * Verifies that slow appends to log don't block request threads processing replica fetch requests. + * + * Test simulates: + * 1) Produce request handling by performing append to log as leader. + * 2) Replica fetch request handling by processing update of ISR and HW based on log read result. 
+ * 3) Some tests also simulate a scheduler thread that checks or updates ISRs when a replica falls out of the ISR + */ +class PartitionLockTest extends Logging { + + val numReplicaFetchers = 2 + val numProducers = 3 + val numRecordsPerProducer = 5 + + val mockTime = new MockTime() // Used for the ISR shrink check + val tmpDir = TestUtils.tempDir() + val logDir = TestUtils.randomPartitionLogDir(tmpDir) + val executorService = Executors.newFixedThreadPool(numReplicaFetchers + numProducers + 1) + val appendSemaphore = new Semaphore(0) + val shrinkIsrSemaphore = new Semaphore(0) + val followerQueues = (0 until numReplicaFetchers).map(_ => new ArrayBlockingQueue[MemoryRecords](2)) + + var logManager: LogManager = _ + var partition: Partition = _ + + private val topicPartition = new TopicPartition("test-topic", 0) + + @BeforeEach + def setUp(): Unit = { + val logConfig = new LogConfig(new Properties) + val configRepository = MockConfigRepository.forTopic(topicPartition.topic, createLogProperties(Map.empty)) + logManager = TestUtils.createLogManager(Seq(logDir), logConfig, configRepository, + CleanerConfig(enableCleaner = false), mockTime) + partition = setupPartitionWithMocks(logManager) + } + + @AfterEach + def tearDown(): Unit = { + executorService.shutdownNow() + logManager.liveLogDirs.foreach(Utils.delete) + Utils.delete(tmpDir) + } + + /** + * Verifies that delays in appending to the leader while processing produce requests have no impact on the timing + * of log read result updates when processing replica fetch requests if no ISR update is required. + */ + @Test + def testNoLockContentionWithoutIsrUpdate(): Unit = { + concurrentProduceFetchWithReadLockOnly() + } + + /** + * Verifies that delays in appending to the leader while processing produce requests have no impact on the timing + * of log read result updates when processing replica fetch requests, even if a scheduler thread is checking + * for ISR shrink conditions, as long as no ISR update is required. + */ + @Test + def testAppendReplicaFetchWithSchedulerCheckForShrinkIsr(): Unit = { + val active = new AtomicBoolean(true) + + val future = scheduleShrinkIsr(active, mockTimeSleepMs = 0) + concurrentProduceFetchWithReadLockOnly() + active.set(false) + future.get(15, TimeUnit.SECONDS) + } + + /** + * Verifies concurrent produce and replica fetch log read result updates with ISR updates. This + * can result in delays in processing produce and replica fetch requests since the write lock is obtained, + * but it should complete without any failures.
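+ * Here the ISR update path is gated by shrinkIsrSemaphore (acquired in the overridden prepareIsrShrink below), so the test controls when the pending ISR shrink is allowed to proceed.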
+ */ + @Test + def testAppendReplicaFetchWithUpdateIsr(): Unit = { + val active = new AtomicBoolean(true) + + val future = scheduleShrinkIsr(active, mockTimeSleepMs = 10000) + TestUtils.waitUntilTrue(() => shrinkIsrSemaphore.hasQueuedThreads, "shrinkIsr not invoked") + concurrentProduceFetchWithWriteLock() + active.set(false) + future.get(15, TimeUnit.SECONDS) + } + + /** + * Concurrently calling updateAssignmentAndIsr should always ensure that non-lock access + * to the inner remoteReplicaMap (accessed by getReplica) cannot see an intermediate state + * where replicas present both in the old and new assignment are missing + */ + @Test + def testGetReplicaWithUpdateAssignmentAndIsr(): Unit = { + val active = new AtomicBoolean(true) + val replicaToCheck = 3 + val firstReplicaSet = Seq[Integer](3, 4, 5).asJava + val secondReplicaSet = Seq[Integer](1, 2, 3).asJava + def partitionState(replicas: java.util.List[Integer]) = new LeaderAndIsrPartitionState() + .setControllerEpoch(1) + .setLeader(replicas.get(0)) + .setLeaderEpoch(1) + .setIsr(replicas) + .setZkVersion(1) + .setReplicas(replicas) + .setIsNew(true) + val offsetCheckpoints: OffsetCheckpoints = mock(classOf[OffsetCheckpoints]) + // Update replica set synchronously first to avoid race conditions + partition.makeLeader(partitionState(secondReplicaSet), offsetCheckpoints, None) + assertTrue(partition.getReplica(replicaToCheck).isDefined, s"Expected replica $replicaToCheck to be defined") + + val future = executorService.submit((() => { + var i = 0 + // Flip assignment between two replica sets + while (active.get) { + val replicas = if (i % 2 == 0) { + firstReplicaSet + } else { + secondReplicaSet + } + + partition.makeLeader(partitionState(replicas), offsetCheckpoints, None) + + i += 1 + Thread.sleep(1) // just to avoid tight loop + } + }): Runnable) + + val deadline = 1.seconds.fromNow + while (deadline.hasTimeLeft()) { + assertTrue(partition.getReplica(replicaToCheck).isDefined, s"Expected replica $replicaToCheck to be defined") + } + active.set(false) + future.get(5, TimeUnit.SECONDS) + assertTrue(partition.getReplica(replicaToCheck).isDefined, s"Expected replica $replicaToCheck to be defined") + } + + /** + * Perform concurrent appends and replica fetch requests that don't require write lock to + * update follower state. Release sufficient append permits to complete all except one append. + * Verify that follower state updates complete even though an append holding read lock is in progress. + * Then release the permit for the final append and verify that all appends and follower updates complete. + */ + private def concurrentProduceFetchWithReadLockOnly(): Unit = { + val appendFutures = scheduleAppends() + val stateUpdateFutures = scheduleUpdateFollowers(numProducers * numRecordsPerProducer - 1) + + appendSemaphore.release(numProducers * numRecordsPerProducer - 1) + stateUpdateFutures.foreach(_.get(15, TimeUnit.SECONDS)) + + appendSemaphore.release(1) + scheduleUpdateFollowers(1).foreach(_.get(15, TimeUnit.SECONDS)) // just to make sure follower state update still works + appendFutures.foreach(_.get(15, TimeUnit.SECONDS)) + } + + /** + * Perform concurrent appends and replica fetch requests that may require write lock to update + * follower state. Threads waiting for write lock to update follower state while append thread is + * holding read lock will prevent other threads acquiring the read or write lock. So release sufficient + * permits for all appends to complete before verifying state updates. 
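+ * Only after shrinkIsrSemaphore is released can the blocked ISR update and the queued follower state updates acquire the lock and complete.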
+ */ + private def concurrentProduceFetchWithWriteLock(): Unit = { + + val appendFutures = scheduleAppends() + val stateUpdateFutures = scheduleUpdateFollowers(numProducers * numRecordsPerProducer) + + assertFalse(stateUpdateFutures.exists(_.isDone)) + appendSemaphore.release(numProducers * numRecordsPerProducer) + assertFalse(appendFutures.exists(_.isDone)) + + shrinkIsrSemaphore.release() + stateUpdateFutures.foreach(_.get(15, TimeUnit.SECONDS)) + appendFutures.foreach(_.get(15, TimeUnit.SECONDS)) + } + + private def scheduleAppends(): Seq[Future[_]] = { + (0 until numProducers).map { _ => + executorService.submit((() => { + try { + append(partition, numRecordsPerProducer, followerQueues) + } catch { + case e: Throwable => + error("Exception during append", e) + throw e + } + }): Runnable) + } + } + + private def scheduleUpdateFollowers(numRecords: Int): Seq[Future[_]] = { + (1 to numReplicaFetchers).map { index => + executorService.submit((() => { + try { + updateFollowerFetchState(partition, index, numRecords, followerQueues(index - 1)) + } catch { + case e: Throwable => + error("Exception during updateFollowerFetchState", e) + throw e + } + }): Runnable) + } + } + + private def scheduleShrinkIsr(activeFlag: AtomicBoolean, mockTimeSleepMs: Long): Future[_] = { + executorService.submit((() => { + while (activeFlag.get) { + if (mockTimeSleepMs > 0) + mockTime.sleep(mockTimeSleepMs) + partition.maybeShrinkIsr() + Thread.sleep(1) // just to avoid tight loop + } + }): Runnable) + } + + private def setupPartitionWithMocks(logManager: LogManager): Partition = { + val leaderEpoch = 1 + val brokerId = 0 + val isrChangeListener: IsrChangeListener = mock(classOf[IsrChangeListener]) + val delayedOperations: DelayedOperations = mock(classOf[DelayedOperations]) + val metadataCache: MetadataCache = mock(classOf[MetadataCache]) + val offsetCheckpoints: OffsetCheckpoints = mock(classOf[OffsetCheckpoints]) + val alterIsrManager: AlterIsrManager = mock(classOf[AlterIsrManager]) + + logManager.startup(Set.empty) + val partition = new Partition(topicPartition, + replicaLagTimeMaxMs = kafka.server.Defaults.ReplicaLagTimeMaxMs, + interBrokerProtocolVersion = ApiVersion.latestVersion, + localBrokerId = brokerId, + mockTime, + isrChangeListener, + delayedOperations, + metadataCache, + logManager, + alterIsrManager) { + + override def prepareIsrShrink(outOfSyncReplicaIds: Set[Int]): PendingShrinkIsr = { + shrinkIsrSemaphore.acquire() + try { + super.prepareIsrShrink(outOfSyncReplicaIds) + } finally { + shrinkIsrSemaphore.release() + } + } + + override def createLog(isNew: Boolean, isFutureReplica: Boolean, offsetCheckpoints: OffsetCheckpoints, topicId: Option[Uuid]): UnifiedLog = { + val log = super.createLog(isNew, isFutureReplica, offsetCheckpoints, None) + val logDirFailureChannel = new LogDirFailureChannel(1) + val segments = new LogSegments(log.topicPartition) + val leaderEpochCache = UnifiedLog.maybeCreateLeaderEpochCache(log.dir, log.topicPartition, logDirFailureChannel, log.config.recordVersion, "") + val maxProducerIdExpirationMs = 60 * 60 * 1000 + val producerStateManager = new ProducerStateManager(log.topicPartition, log.dir, maxProducerIdExpirationMs) + val offsets = LogLoader.load(LoadLogParams( + log.dir, + log.topicPartition, + log.config, + mockTime.scheduler, + mockTime, + logDirFailureChannel, + hadCleanShutdown = true, + segments, + 0L, + 0L, + maxProducerIdExpirationMs, + leaderEpochCache, + producerStateManager)) + val localLog = new LocalLog(log.dir, log.config, segments, 
offsets.recoveryPoint, + offsets.nextOffsetMetadata, mockTime.scheduler, mockTime, log.topicPartition, + logDirFailureChannel) + new SlowLog(log, offsets.logStartOffset, localLog, leaderEpochCache, producerStateManager, appendSemaphore) + } + } + when(offsetCheckpoints.fetch( + ArgumentMatchers.anyString, + ArgumentMatchers.eq(topicPartition) + )).thenReturn(None) + when(alterIsrManager.submit( + ArgumentMatchers.eq(topicPartition), + ArgumentMatchers.any[LeaderAndIsr], + ArgumentMatchers.anyInt() + )).thenReturn(new CompletableFuture[LeaderAndIsr]()) + + partition.createLogIfNotExists(isNew = false, isFutureReplica = false, offsetCheckpoints, None) + + val controllerEpoch = 0 + val replicas = (0 to numReplicaFetchers).map(i => Integer.valueOf(brokerId + i)).toList.asJava + val isr = replicas + + assertTrue(partition.makeLeader(new LeaderAndIsrPartitionState() + .setControllerEpoch(controllerEpoch) + .setLeader(brokerId) + .setLeaderEpoch(leaderEpoch) + .setIsr(isr) + .setZkVersion(1) + .setReplicas(replicas) + .setIsNew(true), offsetCheckpoints, None), "Expected become leader transition to succeed") + + partition + } + + private def createLogProperties(overrides: Map[String, String]): Properties = { + val logProps = new Properties() + logProps.put(LogConfig.SegmentBytesProp, 512: java.lang.Integer) + logProps.put(LogConfig.SegmentIndexBytesProp, 1000: java.lang.Integer) + logProps.put(LogConfig.RetentionMsProp, 999: java.lang.Integer) + overrides.foreach { case (k, v) => logProps.put(k, v) } + logProps + } + + private def append(partition: Partition, numRecords: Int, followerQueues: Seq[ArrayBlockingQueue[MemoryRecords]]): Unit = { + val requestLocal = RequestLocal.withThreadConfinedCaching + (0 until numRecords).foreach { _ => + val batch = TestUtils.records(records = List(new SimpleRecord("k1".getBytes, "v1".getBytes), + new SimpleRecord("k2".getBytes, "v2".getBytes))) + partition.appendRecordsToLeader(batch, origin = AppendOrigin.Client, requiredAcks = 0, requestLocal) + followerQueues.foreach(_.put(batch)) + } + } + + private def updateFollowerFetchState(partition: Partition, followerId: Int, numRecords: Int, followerQueue: ArrayBlockingQueue[MemoryRecords]): Unit = { + (1 to numRecords).foreach { i => + val batch = followerQueue.poll(15, TimeUnit.SECONDS) + if (batch == null) + throw new RuntimeException(s"Timed out waiting for next batch $i") + val batches = batch.batches.iterator.asScala.toList + assertEquals(1, batches.size) + val recordBatch = batches.head + partition.updateFollowerFetchState( + followerId, + followerFetchOffsetMetadata = LogOffsetMetadata(recordBatch.lastOffset + 1), + followerStartOffset = 0L, + followerFetchTimeMs = mockTime.milliseconds(), + leaderEndOffset = partition.localLogOrException.logEndOffset) + } + } + + private class SlowLog( + log: UnifiedLog, + logStartOffset: Long, + localLog: LocalLog, + leaderEpochCache: Option[LeaderEpochFileCache], + producerStateManager: ProducerStateManager, + appendSemaphore: Semaphore + ) extends UnifiedLog( + logStartOffset, + localLog, + new BrokerTopicStats, + log.producerIdExpirationCheckIntervalMs, + leaderEpochCache, + producerStateManager, + _topicId = None, + keepPartitionMetadataFile = true) { + + override def appendAsLeader(records: MemoryRecords, leaderEpoch: Int, origin: AppendOrigin, + interBrokerProtocolVersion: ApiVersion, requestLocal: RequestLocal): LogAppendInfo = { + val appendInfo = super.appendAsLeader(records, leaderEpoch, origin, interBrokerProtocolVersion, requestLocal) + appendSemaphore.acquire() 
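+ // The acquire() above blocks until the test releases a permit, keeping this append in progress for as long as the test requires before the append info is returned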
+ appendInfo + } + } +} diff --git a/core/src/test/scala/unit/kafka/cluster/PartitionTest.scala b/core/src/test/scala/unit/kafka/cluster/PartitionTest.scala new file mode 100644 index 0000000..ccbb102 --- /dev/null +++ b/core/src/test/scala/unit/kafka/cluster/PartitionTest.scala @@ -0,0 +1,2080 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package kafka.cluster + +import com.yammer.metrics.core.Metric +import kafka.api.{ApiVersion, KAFKA_2_6_IV0} +import kafka.common.UnexpectedAppendOffsetException +import kafka.log.{Defaults => _, _} +import kafka.metrics.KafkaYammerMetrics +import kafka.server._ +import kafka.server.checkpoints.OffsetCheckpoints +import kafka.utils._ +import kafka.zk.KafkaZkClient +import org.apache.kafka.common.errors.{ApiException, InconsistentTopicIdException, NotLeaderOrFollowerException, OffsetNotAvailableException, OffsetOutOfRangeException} +import org.apache.kafka.common.message.FetchResponseData +import org.apache.kafka.common.message.LeaderAndIsrRequestData.LeaderAndIsrPartitionState +import org.apache.kafka.common.protocol.Errors +import org.apache.kafka.common.record.FileRecords.TimestampAndOffset +import org.apache.kafka.common.record._ +import org.apache.kafka.common.requests.ListOffsetsRequest +import org.apache.kafka.common.utils.SystemTime +import org.apache.kafka.common.{IsolationLevel, TopicPartition, Uuid} +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.Test +import org.mockito.ArgumentMatchers +import org.mockito.ArgumentMatchers.{any, anyString} +import org.mockito.Mockito._ +import org.mockito.invocation.InvocationOnMock + +import java.nio.ByteBuffer +import java.util.Optional +import java.util.concurrent.{CountDownLatch, Semaphore} +import kafka.server.epoch.LeaderEpochFileCache + +import scala.jdk.CollectionConverters._ + +class PartitionTest extends AbstractPartitionTest { + + @Test + def testLastFetchedOffsetValidation(): Unit = { + val log = logManager.getOrCreateLog(topicPartition, topicId = None) + def append(leaderEpoch: Int, count: Int): Unit = { + val recordArray = (1 to count).map { i => + new SimpleRecord(s"$i".getBytes) + } + val records = MemoryRecords.withRecords(0L, CompressionType.NONE, leaderEpoch, + recordArray: _*) + log.appendAsLeader(records, leaderEpoch = leaderEpoch) + } + + append(leaderEpoch = 0, count = 2) // 0 + append(leaderEpoch = 3, count = 3) // 2 + append(leaderEpoch = 3, count = 3) // 5 + append(leaderEpoch = 4, count = 5) // 8 + append(leaderEpoch = 7, count = 1) // 13 + append(leaderEpoch = 9, count = 3) // 14 + assertEquals(17L, log.logEndOffset) + + val leaderEpoch = 10 + val partition = setupPartitionWithMocks(leaderEpoch = leaderEpoch, isLeader = true) + + def epochEndOffset(epoch: Int, endOffset: Long): 
FetchResponseData.EpochEndOffset = { + new FetchResponseData.EpochEndOffset() + .setEpoch(epoch) + .setEndOffset(endOffset) + } + + def read(lastFetchedEpoch: Int, fetchOffset: Long): LogReadInfo = { + partition.readRecords( + Optional.of(lastFetchedEpoch), + fetchOffset, + currentLeaderEpoch = Optional.of(leaderEpoch), + maxBytes = Int.MaxValue, + fetchIsolation = FetchLogEnd, + fetchOnlyFromLeader = true, + minOneMessage = true + ) + } + + def assertDivergence( + divergingEpoch: FetchResponseData.EpochEndOffset, + readInfo: LogReadInfo + ): Unit = { + assertEquals(Some(divergingEpoch), readInfo.divergingEpoch) + assertEquals(0, readInfo.fetchedData.records.sizeInBytes) + } + + def assertNoDivergence(readInfo: LogReadInfo): Unit = { + assertEquals(None, readInfo.divergingEpoch) + } + + assertDivergence(epochEndOffset(epoch = 0, endOffset = 2), read(lastFetchedEpoch = 2, fetchOffset = 5)) + assertDivergence(epochEndOffset(epoch = 0, endOffset= 2), read(lastFetchedEpoch = 0, fetchOffset = 4)) + assertDivergence(epochEndOffset(epoch = 4, endOffset = 13), read(lastFetchedEpoch = 6, fetchOffset = 6)) + assertDivergence(epochEndOffset(epoch = 4, endOffset = 13), read(lastFetchedEpoch = 5, fetchOffset = 9)) + assertDivergence(epochEndOffset(epoch = 10, endOffset = 17), read(lastFetchedEpoch = 10, fetchOffset = 18)) + assertNoDivergence(read(lastFetchedEpoch = 0, fetchOffset = 2)) + assertNoDivergence(read(lastFetchedEpoch = 7, fetchOffset = 14)) + assertNoDivergence(read(lastFetchedEpoch = 9, fetchOffset = 17)) + assertNoDivergence(read(lastFetchedEpoch = 10, fetchOffset = 17)) + + // Reads from epochs larger than we know about should cause an out of range error + assertThrows(classOf[OffsetOutOfRangeException], () => read(lastFetchedEpoch = 11, fetchOffset = 5)) + + // Move log start offset to the middle of epoch 3 + log.updateHighWatermark(log.logEndOffset) + log.maybeIncrementLogStartOffset(newLogStartOffset = 5L, ClientRecordDeletion) + + assertDivergence(epochEndOffset(epoch = 2, endOffset = 5), read(lastFetchedEpoch = 2, fetchOffset = 8)) + assertNoDivergence(read(lastFetchedEpoch = 0, fetchOffset = 5)) + assertNoDivergence(read(lastFetchedEpoch = 3, fetchOffset = 5)) + + assertThrows(classOf[OffsetOutOfRangeException], () => read(lastFetchedEpoch = 0, fetchOffset = 0)) + + // Fetch offset lower than start offset should throw OffsetOutOfRangeException + log.maybeIncrementLogStartOffset(newLogStartOffset = 10, ClientRecordDeletion) + assertThrows(classOf[OffsetOutOfRangeException], () => read(lastFetchedEpoch = 5, fetchOffset = 6)) // diverging + assertThrows(classOf[OffsetOutOfRangeException], () => read(lastFetchedEpoch = 3, fetchOffset = 6)) // not diverging + } + + @Test + def testMakeLeaderUpdatesEpochCache(): Unit = { + val leaderEpoch = 8 + + val log = logManager.getOrCreateLog(topicPartition, topicId = None) + log.appendAsLeader(MemoryRecords.withRecords(0L, CompressionType.NONE, 0, + new SimpleRecord("k1".getBytes, "v1".getBytes), + new SimpleRecord("k2".getBytes, "v2".getBytes) + ), leaderEpoch = 0) + log.appendAsLeader(MemoryRecords.withRecords(0L, CompressionType.NONE, 5, + new SimpleRecord("k3".getBytes, "v3".getBytes), + new SimpleRecord("k4".getBytes, "v4".getBytes) + ), leaderEpoch = 5) + assertEquals(4, log.logEndOffset) + + val partition = setupPartitionWithMocks(leaderEpoch = leaderEpoch, isLeader = true) + assertEquals(Some(4), partition.leaderLogIfLocal.map(_.logEndOffset)) + + val epochEndOffset = partition.lastOffsetForLeaderEpoch(currentLeaderEpoch = 
Optional.of[Integer](leaderEpoch), + leaderEpoch = leaderEpoch, fetchOnlyFromLeader = true) + assertEquals(4, epochEndOffset.endOffset) + assertEquals(leaderEpoch, epochEndOffset.leaderEpoch) + } + + // Verify that partition.removeFutureLocalReplica() and partition.maybeReplaceCurrentWithFutureReplica() can run concurrently + @Test + def testMaybeReplaceCurrentWithFutureReplica(): Unit = { + val latch = new CountDownLatch(1) + + logManager.maybeUpdatePreferredLogDir(topicPartition, logDir1.getAbsolutePath) + partition.createLogIfNotExists(isNew = true, isFutureReplica = false, offsetCheckpoints, None) + logManager.maybeUpdatePreferredLogDir(topicPartition, logDir2.getAbsolutePath) + partition.maybeCreateFutureReplica(logDir2.getAbsolutePath, offsetCheckpoints) + + val thread1 = new Thread { + override def run(): Unit = { + latch.await() + partition.removeFutureLocalReplica() + } + } + + val thread2 = new Thread { + override def run(): Unit = { + latch.await() + partition.maybeReplaceCurrentWithFutureReplica() + } + } + + thread1.start() + thread2.start() + + latch.countDown() + thread1.join() + thread2.join() + assertEquals(None, partition.futureLog) + } + + // Verify that partition.makeFollower() and partition.appendRecordsToFollowerOrFutureReplica() can run concurrently + @Test + def testMakeFollowerWithWithFollowerAppendRecords(): Unit = { + val appendSemaphore = new Semaphore(0) + val mockTime = new MockTime() + + partition = new Partition( + topicPartition, + replicaLagTimeMaxMs = Defaults.ReplicaLagTimeMaxMs, + interBrokerProtocolVersion = ApiVersion.latestVersion, + localBrokerId = brokerId, + time, + isrChangeListener, + delayedOperations, + metadataCache, + logManager, + alterIsrManager) { + + override def createLog(isNew: Boolean, isFutureReplica: Boolean, offsetCheckpoints: OffsetCheckpoints, topicId: Option[Uuid]): UnifiedLog = { + val log = super.createLog(isNew, isFutureReplica, offsetCheckpoints, None) + val logDirFailureChannel = new LogDirFailureChannel(1) + val segments = new LogSegments(log.topicPartition) + val leaderEpochCache = UnifiedLog.maybeCreateLeaderEpochCache(log.dir, log.topicPartition, logDirFailureChannel, log.config.recordVersion, "") + val maxProducerIdExpirationMs = 60 * 60 * 1000 + val producerStateManager = new ProducerStateManager(log.topicPartition, log.dir, maxProducerIdExpirationMs) + val offsets = LogLoader.load(LoadLogParams( + log.dir, + log.topicPartition, + log.config, + mockTime.scheduler, + mockTime, + logDirFailureChannel, + hadCleanShutdown = true, + segments, + 0L, + 0L, + maxProducerIdExpirationMs, + leaderEpochCache, + producerStateManager)) + val localLog = new LocalLog(log.dir, log.config, segments, offsets.recoveryPoint, + offsets.nextOffsetMetadata, mockTime.scheduler, mockTime, log.topicPartition, + logDirFailureChannel) + new SlowLog(log, offsets.logStartOffset, localLog, leaderEpochCache, producerStateManager, appendSemaphore) + } + } + + partition.createLogIfNotExists(isNew = true, isFutureReplica = false, offsetCheckpoints, None) + + val appendThread = new Thread { + override def run(): Unit = { + val records = createRecords(List(new SimpleRecord("k1".getBytes, "v1".getBytes), + new SimpleRecord("k2".getBytes, "v2".getBytes)), + baseOffset = 0) + partition.appendRecordsToFollowerOrFutureReplica(records, isFuture = false) + } + } + appendThread.start() + TestUtils.waitUntilTrue(() => appendSemaphore.hasQueuedThreads, "follower log append is not called.") + + val partitionState = new LeaderAndIsrPartitionState() + 
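+ // Describe a follower transition (leader = broker 2) that is applied while the follower append above is still blocked on appendSemaphore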
.setControllerEpoch(0) + .setLeader(2) + .setLeaderEpoch(1) + .setIsr(List[Integer](0, 1, 2, brokerId).asJava) + .setZkVersion(1) + .setReplicas(List[Integer](0, 1, 2, brokerId).asJava) + .setIsNew(false) + assertTrue(partition.makeFollower(partitionState, offsetCheckpoints, None)) + + appendSemaphore.release() + appendThread.join() + + assertEquals(2L, partition.localLogOrException.logEndOffset) + assertEquals(2L, partition.leaderReplicaIdOpt.get) + } + + @Test + // Verify that replacement works when the replicas have the same log end offset but different base offsets in the + // active segment + def testMaybeReplaceCurrentWithFutureReplicaDifferentBaseOffsets(): Unit = { + logManager.maybeUpdatePreferredLogDir(topicPartition, logDir1.getAbsolutePath) + partition.createLogIfNotExists(isNew = true, isFutureReplica = false, offsetCheckpoints, None) + logManager.maybeUpdatePreferredLogDir(topicPartition, logDir2.getAbsolutePath) + partition.maybeCreateFutureReplica(logDir2.getAbsolutePath, offsetCheckpoints) + + // Write records with duplicate keys to current replica and roll at offset 6 + val currentLog = partition.log.get + currentLog.appendAsLeader(MemoryRecords.withRecords(0L, CompressionType.NONE, 0, + new SimpleRecord("k1".getBytes, "v1".getBytes), + new SimpleRecord("k1".getBytes, "v2".getBytes), + new SimpleRecord("k1".getBytes, "v3".getBytes), + new SimpleRecord("k2".getBytes, "v4".getBytes), + new SimpleRecord("k2".getBytes, "v5".getBytes), + new SimpleRecord("k2".getBytes, "v6".getBytes) + ), leaderEpoch = 0) + currentLog.roll() + currentLog.appendAsLeader(MemoryRecords.withRecords(0L, CompressionType.NONE, 0, + new SimpleRecord("k3".getBytes, "v7".getBytes), + new SimpleRecord("k4".getBytes, "v8".getBytes) + ), leaderEpoch = 0) + + // Write to the future replica as if the log had been compacted, and do not roll the segment + + val buffer = ByteBuffer.allocate(1024) + val builder = MemoryRecords.builder(buffer, RecordBatch.CURRENT_MAGIC_VALUE, CompressionType.NONE, + TimestampType.CREATE_TIME, 0L, RecordBatch.NO_TIMESTAMP, 0) + builder.appendWithOffset(2L, new SimpleRecord("k1".getBytes, "v3".getBytes)) + builder.appendWithOffset(5L, new SimpleRecord("k2".getBytes, "v6".getBytes)) + builder.appendWithOffset(6L, new SimpleRecord("k3".getBytes, "v7".getBytes)) + builder.appendWithOffset(7L, new SimpleRecord("k4".getBytes, "v8".getBytes)) + + val futureLog = partition.futureLocalLogOrException + futureLog.appendAsFollower(builder.build()) + + assertTrue(partition.maybeReplaceCurrentWithFutureReplica()) + } + + @Test + def testFetchOffsetSnapshotEpochValidationForLeader(): Unit = { + val leaderEpoch = 5 + val partition = setupPartitionWithMocks(leaderEpoch, isLeader = true) + + def assertSnapshotError(expectedError: Errors, currentLeaderEpoch: Optional[Integer]): Unit = { + try { + partition.fetchOffsetSnapshot(currentLeaderEpoch, fetchOnlyFromLeader = true) + assertEquals(Errors.NONE, expectedError) + } catch { + case error: ApiException => assertEquals(expectedError, Errors.forException(error)) + } + } + + assertSnapshotError(Errors.FENCED_LEADER_EPOCH, Optional.of(leaderEpoch - 1)) + assertSnapshotError(Errors.UNKNOWN_LEADER_EPOCH, Optional.of(leaderEpoch + 1)) + assertSnapshotError(Errors.NONE, Optional.of(leaderEpoch)) + assertSnapshotError(Errors.NONE, Optional.empty()) + } + + @Test + def testFetchOffsetSnapshotEpochValidationForFollower(): Unit = { + val leaderEpoch = 5 + val partition = setupPartitionWithMocks(leaderEpoch, isLeader = false) + + def 
assertSnapshotError(expectedError: Errors, + currentLeaderEpoch: Optional[Integer], + fetchOnlyLeader: Boolean): Unit = { + try { + partition.fetchOffsetSnapshot(currentLeaderEpoch, fetchOnlyFromLeader = fetchOnlyLeader) + assertEquals(Errors.NONE, expectedError) + } catch { + case error: ApiException => assertEquals(expectedError, Errors.forException(error)) + } + } + + assertSnapshotError(Errors.NONE, Optional.of(leaderEpoch), fetchOnlyLeader = false) + assertSnapshotError(Errors.NONE, Optional.empty(), fetchOnlyLeader = false) + assertSnapshotError(Errors.FENCED_LEADER_EPOCH, Optional.of(leaderEpoch - 1), fetchOnlyLeader = false) + assertSnapshotError(Errors.UNKNOWN_LEADER_EPOCH, Optional.of(leaderEpoch + 1), fetchOnlyLeader = false) + + assertSnapshotError(Errors.NOT_LEADER_OR_FOLLOWER, Optional.of(leaderEpoch), fetchOnlyLeader = true) + assertSnapshotError(Errors.NOT_LEADER_OR_FOLLOWER, Optional.empty(), fetchOnlyLeader = true) + assertSnapshotError(Errors.FENCED_LEADER_EPOCH, Optional.of(leaderEpoch - 1), fetchOnlyLeader = true) + assertSnapshotError(Errors.UNKNOWN_LEADER_EPOCH, Optional.of(leaderEpoch + 1), fetchOnlyLeader = true) + } + + @Test + def testOffsetForLeaderEpochValidationForLeader(): Unit = { + val leaderEpoch = 5 + val partition = setupPartitionWithMocks(leaderEpoch, isLeader = true) + + def assertLastOffsetForLeaderError(error: Errors, currentLeaderEpochOpt: Optional[Integer]): Unit = { + val endOffset = partition.lastOffsetForLeaderEpoch(currentLeaderEpochOpt, 0, + fetchOnlyFromLeader = true) + assertEquals(error.code, endOffset.errorCode) + } + + assertLastOffsetForLeaderError(Errors.NONE, Optional.empty()) + assertLastOffsetForLeaderError(Errors.NONE, Optional.of(leaderEpoch)) + assertLastOffsetForLeaderError(Errors.FENCED_LEADER_EPOCH, Optional.of(leaderEpoch - 1)) + assertLastOffsetForLeaderError(Errors.UNKNOWN_LEADER_EPOCH, Optional.of(leaderEpoch + 1)) + } + + @Test + def testOffsetForLeaderEpochValidationForFollower(): Unit = { + val leaderEpoch = 5 + val partition = setupPartitionWithMocks(leaderEpoch, isLeader = false) + + def assertLastOffsetForLeaderError(error: Errors, + currentLeaderEpochOpt: Optional[Integer], + fetchOnlyLeader: Boolean): Unit = { + val endOffset = partition.lastOffsetForLeaderEpoch(currentLeaderEpochOpt, 0, + fetchOnlyFromLeader = fetchOnlyLeader) + assertEquals(error.code, endOffset.errorCode) + } + + assertLastOffsetForLeaderError(Errors.NONE, Optional.empty(), fetchOnlyLeader = false) + assertLastOffsetForLeaderError(Errors.NONE, Optional.of(leaderEpoch), fetchOnlyLeader = false) + assertLastOffsetForLeaderError(Errors.FENCED_LEADER_EPOCH, Optional.of(leaderEpoch - 1), fetchOnlyLeader = false) + assertLastOffsetForLeaderError(Errors.UNKNOWN_LEADER_EPOCH, Optional.of(leaderEpoch + 1), fetchOnlyLeader = false) + + assertLastOffsetForLeaderError(Errors.NOT_LEADER_OR_FOLLOWER, Optional.empty(), fetchOnlyLeader = true) + assertLastOffsetForLeaderError(Errors.NOT_LEADER_OR_FOLLOWER, Optional.of(leaderEpoch), fetchOnlyLeader = true) + assertLastOffsetForLeaderError(Errors.FENCED_LEADER_EPOCH, Optional.of(leaderEpoch - 1), fetchOnlyLeader = true) + assertLastOffsetForLeaderError(Errors.UNKNOWN_LEADER_EPOCH, Optional.of(leaderEpoch + 1), fetchOnlyLeader = true) + } + + @Test + def testReadRecordEpochValidationForLeader(): Unit = { + val leaderEpoch = 5 + val partition = setupPartitionWithMocks(leaderEpoch, isLeader = true) + + def assertReadRecordsError(error: Errors, + currentLeaderEpochOpt: Optional[Integer]): Unit = { + try { + 
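+ // readRecords validates the supplied currentLeaderEpoch first; per the expectations below, a stale epoch should map to FENCED_LEADER_EPOCH and a newer one to UNKNOWN_LEADER_EPOCH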
partition.readRecords( + lastFetchedEpoch = Optional.empty(), + fetchOffset = 0L, + currentLeaderEpoch = currentLeaderEpochOpt, + maxBytes = 1024, + fetchIsolation = FetchLogEnd, + fetchOnlyFromLeader = true, + minOneMessage = false) + if (error != Errors.NONE) + fail(s"Expected readRecords to fail with error $error") + } catch { + case e: Exception => + assertEquals(error, Errors.forException(e)) + } + } + + assertReadRecordsError(Errors.NONE, Optional.empty()) + assertReadRecordsError(Errors.NONE, Optional.of(leaderEpoch)) + assertReadRecordsError(Errors.FENCED_LEADER_EPOCH, Optional.of(leaderEpoch - 1)) + assertReadRecordsError(Errors.UNKNOWN_LEADER_EPOCH, Optional.of(leaderEpoch + 1)) + } + + @Test + def testReadRecordEpochValidationForFollower(): Unit = { + val leaderEpoch = 5 + val partition = setupPartitionWithMocks(leaderEpoch, isLeader = false) + + def assertReadRecordsError(error: Errors, + currentLeaderEpochOpt: Optional[Integer], + fetchOnlyLeader: Boolean): Unit = { + try { + partition.readRecords( + lastFetchedEpoch = Optional.empty(), + fetchOffset = 0L, + currentLeaderEpoch = currentLeaderEpochOpt, + maxBytes = 1024, + fetchIsolation = FetchLogEnd, + fetchOnlyFromLeader = fetchOnlyLeader, + minOneMessage = false) + if (error != Errors.NONE) + fail(s"Expected readRecords to fail with error $error") + } catch { + case e: Exception => + assertEquals(error, Errors.forException(e)) + } + } + + assertReadRecordsError(Errors.NONE, Optional.empty(), fetchOnlyLeader = false) + assertReadRecordsError(Errors.NONE, Optional.of(leaderEpoch), fetchOnlyLeader = false) + assertReadRecordsError(Errors.FENCED_LEADER_EPOCH, Optional.of(leaderEpoch - 1), fetchOnlyLeader = false) + assertReadRecordsError(Errors.UNKNOWN_LEADER_EPOCH, Optional.of(leaderEpoch + 1), fetchOnlyLeader = false) + + assertReadRecordsError(Errors.NOT_LEADER_OR_FOLLOWER, Optional.empty(), fetchOnlyLeader = true) + assertReadRecordsError(Errors.NOT_LEADER_OR_FOLLOWER, Optional.of(leaderEpoch), fetchOnlyLeader = true) + assertReadRecordsError(Errors.FENCED_LEADER_EPOCH, Optional.of(leaderEpoch - 1), fetchOnlyLeader = true) + assertReadRecordsError(Errors.UNKNOWN_LEADER_EPOCH, Optional.of(leaderEpoch + 1), fetchOnlyLeader = true) + } + + @Test + def testFetchOffsetForTimestampEpochValidationForLeader(): Unit = { + val leaderEpoch = 5 + val partition = setupPartitionWithMocks(leaderEpoch, isLeader = true) + + def assertFetchOffsetError(error: Errors, + currentLeaderEpochOpt: Optional[Integer]): Unit = { + try { + partition.fetchOffsetForTimestamp(0L, + isolationLevel = None, + currentLeaderEpoch = currentLeaderEpochOpt, + fetchOnlyFromLeader = true) + if (error != Errors.NONE) + fail(s"Expected readRecords to fail with error $error") + } catch { + case e: Exception => + assertEquals(error, Errors.forException(e)) + } + } + + assertFetchOffsetError(Errors.NONE, Optional.empty()) + assertFetchOffsetError(Errors.NONE, Optional.of(leaderEpoch)) + assertFetchOffsetError(Errors.FENCED_LEADER_EPOCH, Optional.of(leaderEpoch - 1)) + assertFetchOffsetError(Errors.UNKNOWN_LEADER_EPOCH, Optional.of(leaderEpoch + 1)) + } + + @Test + def testFetchOffsetForTimestampEpochValidationForFollower(): Unit = { + val leaderEpoch = 5 + val partition = setupPartitionWithMocks(leaderEpoch, isLeader = false) + + def assertFetchOffsetError(error: Errors, + currentLeaderEpochOpt: Optional[Integer], + fetchOnlyLeader: Boolean): Unit = { + try { + partition.fetchOffsetForTimestamp(0L, + isolationLevel = None, + currentLeaderEpoch = currentLeaderEpochOpt, + 
fetchOnlyFromLeader = fetchOnlyLeader) + if (error != Errors.NONE) + fail(s"Expected readRecords to fail with error $error") + } catch { + case e: Exception => + assertEquals(error, Errors.forException(e)) + } + } + + assertFetchOffsetError(Errors.NONE, Optional.empty(), fetchOnlyLeader = false) + assertFetchOffsetError(Errors.NONE, Optional.of(leaderEpoch), fetchOnlyLeader = false) + assertFetchOffsetError(Errors.FENCED_LEADER_EPOCH, Optional.of(leaderEpoch - 1), fetchOnlyLeader = false) + assertFetchOffsetError(Errors.UNKNOWN_LEADER_EPOCH, Optional.of(leaderEpoch + 1), fetchOnlyLeader = false) + + assertFetchOffsetError(Errors.NOT_LEADER_OR_FOLLOWER, Optional.empty(), fetchOnlyLeader = true) + assertFetchOffsetError(Errors.NOT_LEADER_OR_FOLLOWER, Optional.of(leaderEpoch), fetchOnlyLeader = true) + assertFetchOffsetError(Errors.FENCED_LEADER_EPOCH, Optional.of(leaderEpoch - 1), fetchOnlyLeader = true) + assertFetchOffsetError(Errors.UNKNOWN_LEADER_EPOCH, Optional.of(leaderEpoch + 1), fetchOnlyLeader = true) + } + + @Test + def testFetchLatestOffsetIncludesLeaderEpoch(): Unit = { + val leaderEpoch = 5 + val partition = setupPartitionWithMocks(leaderEpoch, isLeader = true) + + val timestampAndOffsetOpt = partition.fetchOffsetForTimestamp(ListOffsetsRequest.LATEST_TIMESTAMP, + isolationLevel = None, + currentLeaderEpoch = Optional.empty(), + fetchOnlyFromLeader = true) + + assertTrue(timestampAndOffsetOpt.isDefined) + + val timestampAndOffset = timestampAndOffsetOpt.get + assertEquals(leaderEpoch, timestampAndOffset.leaderEpoch.get) + } + + /** + * This test checks that after a new leader election, we don't answer any ListOffsetsRequest until + * the HW of the new leader has caught up to its startLogOffset for this epoch. From a client + * perspective this helps guarantee monotonic offsets. + * + * @see KIP-207 + */ + @Test + def testMonotonicOffsetsAfterLeaderChange(): Unit = { + val controllerEpoch = 3 + val leader = brokerId + val follower1 = brokerId + 1 + val follower2 = brokerId + 2 + val replicas = List(leader, follower1, follower2) + val isr = List[Integer](leader, follower2).asJava + val leaderEpoch = 8 + val batch1 = TestUtils.records(records = List( + new SimpleRecord(10, "k1".getBytes, "v1".getBytes), + new SimpleRecord(11,"k2".getBytes, "v2".getBytes))) + val batch2 = TestUtils.records(records = List(new SimpleRecord("k3".getBytes, "v1".getBytes), + new SimpleRecord(20,"k4".getBytes, "v2".getBytes), + new SimpleRecord(21,"k5".getBytes, "v3".getBytes))) + + val leaderState = new LeaderAndIsrPartitionState() + .setControllerEpoch(controllerEpoch) + .setLeader(leader) + .setLeaderEpoch(leaderEpoch) + .setIsr(isr) + .setZkVersion(1) + .setReplicas(replicas.map(Int.box).asJava) + .setIsNew(true) + + assertTrue(partition.makeLeader(leaderState, offsetCheckpoints, None), "Expected first makeLeader() to return 'leader changed'") + assertEquals(leaderEpoch, partition.getLeaderEpoch, "Current leader epoch") + assertEquals(Set[Integer](leader, follower2), partition.isrState.isr, "ISR") + + val requestLocal = RequestLocal.withThreadConfinedCaching + // after makeLeader() call, partition should know about all the replicas + // append records with initial leader epoch + partition.appendRecordsToLeader(batch1, origin = AppendOrigin.Client, requiredAcks = 0, requestLocal) + partition.appendRecordsToLeader(batch2, origin = AppendOrigin.Client, requiredAcks = 0, requestLocal) + assertEquals(partition.localLogOrException.logStartOffset, partition.localLogOrException.highWatermark, + "Expected 
leader's HW to not move") + + // let the follower in the ISR move the leader's HW further, but still below the LEO + def updateFollowerFetchState(followerId: Int, fetchOffsetMetadata: LogOffsetMetadata): Unit = { + partition.updateFollowerFetchState( + followerId, + followerFetchOffsetMetadata = fetchOffsetMetadata, + followerStartOffset = 0L, + followerFetchTimeMs = time.milliseconds(), + leaderEndOffset = partition.localLogOrException.logEndOffset) + } + + def fetchOffsetsForTimestamp(timestamp: Long, isolation: Option[IsolationLevel]): Either[ApiException, Option[TimestampAndOffset]] = { + try { + Right(partition.fetchOffsetForTimestamp( + timestamp = timestamp, + isolationLevel = isolation, + currentLeaderEpoch = Optional.of(partition.getLeaderEpoch), + fetchOnlyFromLeader = true + )) + } catch { + case e: ApiException => Left(e) + } + } + + updateFollowerFetchState(follower1, LogOffsetMetadata(0)) + updateFollowerFetchState(follower1, LogOffsetMetadata(2)) + + updateFollowerFetchState(follower2, LogOffsetMetadata(0)) + updateFollowerFetchState(follower2, LogOffsetMetadata(2)) + + // Simulate successful ISR update + alterIsrManager.completeIsrUpdate(2) + + // At this point, the leader has gotten 5 writes, but followers have only fetched two + assertEquals(2, partition.localLogOrException.highWatermark) + + // Get the LEO + fetchOffsetsForTimestamp(ListOffsetsRequest.LATEST_TIMESTAMP, None) match { + case Right(Some(offsetAndTimestamp)) => assertEquals(5, offsetAndTimestamp.offset) + case Right(None) => fail("Should have seen some offsets") + case Left(e) => fail("Should not have seen an error") + } + + // Get the HW + fetchOffsetsForTimestamp(ListOffsetsRequest.LATEST_TIMESTAMP, Some(IsolationLevel.READ_UNCOMMITTED)) match { + case Right(Some(offsetAndTimestamp)) => assertEquals(2, offsetAndTimestamp.offset) + case Right(None) => fail("Should have seen some offsets") + case Left(e) => fail("Should not have seen an error") + } + + // Get an offset beyond the HW by timestamp, get a None + assertEquals(Right(None), fetchOffsetsForTimestamp(30, Some(IsolationLevel.READ_UNCOMMITTED))) + + // Make the partition a follower + val followerState = new LeaderAndIsrPartitionState() + .setControllerEpoch(controllerEpoch) + .setLeader(follower2) + .setLeaderEpoch(leaderEpoch + 1) + .setIsr(isr) + .setZkVersion(4) + .setReplicas(replicas.map(Int.box).asJava) + .setIsNew(false) + + assertTrue(partition.makeFollower(followerState, offsetCheckpoints, None)) + + // Back to leader, this resets the startLogOffset for this epoch (to 2); we're now in the fault condition + val newLeaderState = new LeaderAndIsrPartitionState() + .setControllerEpoch(controllerEpoch) + .setLeader(leader) + .setLeaderEpoch(leaderEpoch + 2) + .setIsr(isr) + .setZkVersion(5) + .setReplicas(replicas.map(Int.box).asJava) + .setIsNew(false) + + assertTrue(partition.makeLeader(newLeaderState, offsetCheckpoints, None)) + + // Try to get offsets as a client + fetchOffsetsForTimestamp(ListOffsetsRequest.LATEST_TIMESTAMP, Some(IsolationLevel.READ_UNCOMMITTED)) match { + case Right(Some(offsetAndTimestamp)) => fail("Should have failed with OffsetNotAvailable") + case Right(None) => fail("Should have seen an error") + case Left(e: OffsetNotAvailableException) => // ok + case Left(e: ApiException) => fail(s"Expected OffsetNotAvailableException, got $e") + } + + // If request is not from a client, we skip the check + fetchOffsetsForTimestamp(ListOffsetsRequest.LATEST_TIMESTAMP, None) match { + case Right(Some(offsetAndTimestamp)) => assertEquals(5, 
offsetAndTimestamp.offset) + case Right(None) => fail("Should have seen some offsets") + case Left(e: ApiException) => fail(s"Got ApiException $e") + } + + // If we request the earliest timestamp, we skip the check + fetchOffsetsForTimestamp(ListOffsetsRequest.EARLIEST_TIMESTAMP, Some(IsolationLevel.READ_UNCOMMITTED)) match { + case Right(Some(offsetAndTimestamp)) => assertEquals(0, offsetAndTimestamp.offset) + case Right(None) => fail("Should have seen some offsets") + case Left(e: ApiException) => fail(s"Got ApiException $e") + } + + // If we request an offset by timestamp earlier than the HW, we are ok + fetchOffsetsForTimestamp(11, Some(IsolationLevel.READ_UNCOMMITTED)) match { + case Right(Some(offsetAndTimestamp)) => + assertEquals(1, offsetAndTimestamp.offset) + assertEquals(11, offsetAndTimestamp.timestamp) + case Right(None) => fail("Should have seen some offsets") + case Left(e: ApiException) => fail(s"Got ApiException $e") + } + + // Request an offset by timestamp beyond the HW, get an error now since we're in a bad state + fetchOffsetsForTimestamp(100, Some(IsolationLevel.READ_UNCOMMITTED)) match { + case Right(Some(offsetAndTimestamp)) => fail("Should have failed") + case Right(None) => fail("Should have failed") + case Left(e: OffsetNotAvailableException) => // ok + case Left(e: ApiException) => fail(s"Should have seen OffsetNotAvailableException, saw $e") + } + + // Next fetch from replicas, HW is moved up to 5 (the LEO) + updateFollowerFetchState(follower1, LogOffsetMetadata(5)) + updateFollowerFetchState(follower2, LogOffsetMetadata(5)) + + // Simulate successful ISR update + alterIsrManager.completeIsrUpdate(6) + + // Error goes away + fetchOffsetsForTimestamp(ListOffsetsRequest.LATEST_TIMESTAMP, Some(IsolationLevel.READ_UNCOMMITTED)) match { + case Right(Some(offsetAndTimestamp)) => assertEquals(5, offsetAndTimestamp.offset) + case Right(None) => fail("Should have seen some offsets") + case Left(e: ApiException) => fail(s"Got ApiException $e") + } + + // Now we see None instead of an error for out of range timestamp + assertEquals(Right(None), fetchOffsetsForTimestamp(100, Some(IsolationLevel.READ_UNCOMMITTED))) + } + + @Test + def testAppendRecordsAsFollowerBelowLogStartOffset(): Unit = { + partition.createLogIfNotExists(isNew = false, isFutureReplica = false, offsetCheckpoints, None) + val log = partition.localLogOrException + + val initialLogStartOffset = 5L + partition.truncateFullyAndStartAt(initialLogStartOffset, isFuture = false) + assertEquals(initialLogStartOffset, log.logEndOffset, + s"Log end offset after truncate fully and start at $initialLogStartOffset:") + assertEquals(initialLogStartOffset, log.logStartOffset, + s"Log start offset after truncate fully and start at $initialLogStartOffset:") + + // verify that we cannot append records that do not contain log start offset even if the log is empty + assertThrows(classOf[UnexpectedAppendOffsetException], () => + // append one record with offset = 3 + partition.appendRecordsToFollowerOrFutureReplica(createRecords(List(new SimpleRecord("k1".getBytes, "v1".getBytes)), baseOffset = 3L), isFuture = false) + ) + assertEquals(initialLogStartOffset, log.logEndOffset, + s"Log end offset should not change after failure to append") + + // verify that we can append records that contain log start offset, even when first + // offset < log start offset if the log is empty + val newLogStartOffset = 4L + val records = createRecords(List(new SimpleRecord("k1".getBytes, "v1".getBytes), + new SimpleRecord("k2".getBytes, 
"v2".getBytes), + new SimpleRecord("k3".getBytes, "v3".getBytes)), + baseOffset = newLogStartOffset) + partition.appendRecordsToFollowerOrFutureReplica(records, isFuture = false) + assertEquals(7L, log.logEndOffset, s"Log end offset after append of 3 records with base offset $newLogStartOffset:") + assertEquals(newLogStartOffset, log.logStartOffset, s"Log start offset after append of 3 records with base offset $newLogStartOffset:") + + // and we can append more records after that + partition.appendRecordsToFollowerOrFutureReplica(createRecords(List(new SimpleRecord("k1".getBytes, "v1".getBytes)), baseOffset = 7L), isFuture = false) + assertEquals(8L, log.logEndOffset, s"Log end offset after append of 1 record at offset 7:") + assertEquals(newLogStartOffset, log.logStartOffset, s"Log start offset not expected to change:") + + // but we cannot append to offset < log start if the log is not empty + val records2 = createRecords(List(new SimpleRecord("k1".getBytes, "v1".getBytes), + new SimpleRecord("k2".getBytes, "v2".getBytes)), + baseOffset = 3L) + assertThrows(classOf[UnexpectedAppendOffsetException], () => partition.appendRecordsToFollowerOrFutureReplica(records2, isFuture = false)) + assertEquals(8L, log.logEndOffset, s"Log end offset should not change after failure to append") + + // we still can append to next offset + partition.appendRecordsToFollowerOrFutureReplica(createRecords(List(new SimpleRecord("k1".getBytes, "v1".getBytes)), baseOffset = 8L), isFuture = false) + assertEquals(9L, log.logEndOffset, s"Log end offset after append of 1 record at offset 8:") + assertEquals(newLogStartOffset, log.logStartOffset, s"Log start offset not expected to change:") + } + + @Test + def testListOffsetIsolationLevels(): Unit = { + val controllerEpoch = 0 + val leaderEpoch = 5 + val replicas = List[Integer](brokerId, brokerId + 1).asJava + val isr = replicas + + partition.createLogIfNotExists(isNew = false, isFutureReplica = false, offsetCheckpoints, None) + + assertTrue(partition.makeLeader(new LeaderAndIsrPartitionState() + .setControllerEpoch(controllerEpoch) + .setLeader(brokerId) + .setLeaderEpoch(leaderEpoch) + .setIsr(isr) + .setZkVersion(1) + .setReplicas(replicas) + .setIsNew(true), offsetCheckpoints, None), "Expected become leader transition to succeed") + assertEquals(leaderEpoch, partition.getLeaderEpoch) + + val records = createTransactionalRecords(List( + new SimpleRecord("k1".getBytes, "v1".getBytes), + new SimpleRecord("k2".getBytes, "v2".getBytes), + new SimpleRecord("k3".getBytes, "v3".getBytes)), + baseOffset = 0L) + partition.appendRecordsToLeader(records, origin = AppendOrigin.Client, requiredAcks = 0, RequestLocal.withThreadConfinedCaching) + + def fetchLatestOffset(isolationLevel: Option[IsolationLevel]): TimestampAndOffset = { + val res = partition.fetchOffsetForTimestamp(ListOffsetsRequest.LATEST_TIMESTAMP, + isolationLevel = isolationLevel, + currentLeaderEpoch = Optional.empty(), + fetchOnlyFromLeader = true) + assertTrue(res.isDefined) + res.get + } + + def fetchEarliestOffset(isolationLevel: Option[IsolationLevel]): TimestampAndOffset = { + val res = partition.fetchOffsetForTimestamp(ListOffsetsRequest.EARLIEST_TIMESTAMP, + isolationLevel = isolationLevel, + currentLeaderEpoch = Optional.empty(), + fetchOnlyFromLeader = true) + assertTrue(res.isDefined) + res.get + } + + assertEquals(3L, fetchLatestOffset(isolationLevel = None).offset) + assertEquals(0L, fetchLatestOffset(isolationLevel = Some(IsolationLevel.READ_UNCOMMITTED)).offset) + assertEquals(0L, 
fetchLatestOffset(isolationLevel = Some(IsolationLevel.READ_COMMITTED)).offset) + + partition.log.get.updateHighWatermark(1L) + + assertEquals(3L, fetchLatestOffset(isolationLevel = None).offset) + assertEquals(1L, fetchLatestOffset(isolationLevel = Some(IsolationLevel.READ_UNCOMMITTED)).offset) + assertEquals(0L, fetchLatestOffset(isolationLevel = Some(IsolationLevel.READ_COMMITTED)).offset) + + assertEquals(0L, fetchEarliestOffset(isolationLevel = None).offset) + assertEquals(0L, fetchEarliestOffset(isolationLevel = Some(IsolationLevel.READ_UNCOMMITTED)).offset) + assertEquals(0L, fetchEarliestOffset(isolationLevel = Some(IsolationLevel.READ_COMMITTED)).offset) + } + + @Test + def testGetReplica(): Unit = { + assertEquals(None, partition.log) + assertThrows(classOf[NotLeaderOrFollowerException], () => + partition.localLogOrException + ) + } + + @Test + def testAppendRecordsToFollowerWithNoReplicaThrowsException(): Unit = { + assertThrows(classOf[NotLeaderOrFollowerException], () => + partition.appendRecordsToFollowerOrFutureReplica( + createRecords(List(new SimpleRecord("k1".getBytes, "v1".getBytes)), baseOffset = 0L), isFuture = false) + ) + } + + @Test + def testMakeFollowerWithNoLeaderIdChange(): Unit = { + // Start off as follower + var partitionState = new LeaderAndIsrPartitionState() + .setControllerEpoch(0) + .setLeader(1) + .setLeaderEpoch(1) + .setIsr(List[Integer](0, 1, 2, brokerId).asJava) + .setZkVersion(1) + .setReplicas(List[Integer](0, 1, 2, brokerId).asJava) + .setIsNew(false) + partition.makeFollower(partitionState, offsetCheckpoints, None) + + // Request with same leader and a newer epoch, do become-follower steps + partitionState = new LeaderAndIsrPartitionState() + .setControllerEpoch(0) + .setLeader(1) + .setLeaderEpoch(4) + .setIsr(List[Integer](0, 1, 2, brokerId).asJava) + .setZkVersion(1) + .setReplicas(List[Integer](0, 1, 2, brokerId).asJava) + .setIsNew(false) + assertTrue(partition.makeFollower(partitionState, offsetCheckpoints, None)) + + // Request with same leader and same epoch, skip become-follower steps + partitionState = new LeaderAndIsrPartitionState() + .setControllerEpoch(0) + .setLeader(1) + .setLeaderEpoch(4) + .setIsr(List[Integer](0, 1, 2, brokerId).asJava) + .setZkVersion(1) + .setReplicas(List[Integer](0, 1, 2, brokerId).asJava) + assertFalse(partition.makeFollower(partitionState, offsetCheckpoints, None)) + } + + @Test + def testFollowerDoesNotJoinISRUntilCaughtUpToOffsetWithinCurrentLeaderEpoch(): Unit = { + val controllerEpoch = 3 + val leader = brokerId + val follower1 = brokerId + 1 + val follower2 = brokerId + 2 + val replicas = List[Integer](leader, follower1, follower2).asJava + val isr = List[Integer](leader, follower2).asJava + val leaderEpoch = 8 + val batch1 = TestUtils.records(records = List(new SimpleRecord("k1".getBytes, "v1".getBytes), + new SimpleRecord("k2".getBytes, "v2".getBytes))) + val batch2 = TestUtils.records(records = List(new SimpleRecord("k3".getBytes, "v1".getBytes), + new SimpleRecord("k4".getBytes, "v2".getBytes), + new SimpleRecord("k5".getBytes, "v3".getBytes))) + val batch3 = TestUtils.records(records = List(new SimpleRecord("k6".getBytes, "v1".getBytes), + new SimpleRecord("k7".getBytes, "v2".getBytes))) + + val leaderState = new LeaderAndIsrPartitionState() + .setControllerEpoch(controllerEpoch) + .setLeader(leader) + .setLeaderEpoch(leaderEpoch) + .setIsr(isr) + .setZkVersion(1) + .setReplicas(replicas) + .setIsNew(true) + assertTrue(partition.makeLeader(leaderState, offsetCheckpoints, None), 
"Expected first makeLeader() to return 'leader changed'") + assertEquals(leaderEpoch, partition.getLeaderEpoch, "Current leader epoch") + assertEquals(Set[Integer](leader, follower2), partition.isrState.isr, "ISR") + + val requestLocal = RequestLocal.withThreadConfinedCaching + + // after makeLeader() call, partition should know about all the replicas + // append records with initial leader epoch + val lastOffsetOfFirstBatch = partition.appendRecordsToLeader(batch1, origin = AppendOrigin.Client, + requiredAcks = 0, requestLocal).lastOffset + partition.appendRecordsToLeader(batch2, origin = AppendOrigin.Client, requiredAcks = 0, requestLocal) + assertEquals(partition.localLogOrException.logStartOffset, partition.log.get.highWatermark, "Expected leader's HW to not move") + + // let the follower in the ISR move the leader's HW further, but still below the LEO + def updateFollowerFetchState(followerId: Int, fetchOffsetMetadata: LogOffsetMetadata): Unit = { + partition.updateFollowerFetchState( + followerId, + followerFetchOffsetMetadata = fetchOffsetMetadata, + followerStartOffset = 0L, + followerFetchTimeMs = time.milliseconds(), + leaderEndOffset = partition.localLogOrException.logEndOffset) + } + + updateFollowerFetchState(follower2, LogOffsetMetadata(0)) + updateFollowerFetchState(follower2, LogOffsetMetadata(lastOffsetOfFirstBatch)) + assertEquals(lastOffsetOfFirstBatch, partition.log.get.highWatermark, "Expected leader's HW") + + // current leader becomes follower and then leader again (without any new records appended) + val followerState = new LeaderAndIsrPartitionState() + .setControllerEpoch(controllerEpoch) + .setLeader(follower2) + .setLeaderEpoch(leaderEpoch + 1) + .setIsr(isr) + .setZkVersion(1) + .setReplicas(replicas) + .setIsNew(false) + partition.makeFollower(followerState, offsetCheckpoints, None) + + val newLeaderState = new LeaderAndIsrPartitionState() + .setControllerEpoch(controllerEpoch) + .setLeader(leader) + .setLeaderEpoch(leaderEpoch + 2) + .setIsr(isr) + .setZkVersion(1) + .setReplicas(replicas) + .setIsNew(false) + assertTrue(partition.makeLeader(newLeaderState, offsetCheckpoints, None), + "Expected makeLeader() to return 'leader changed' after makeFollower()") + val currentLeaderEpochStartOffset = partition.localLogOrException.logEndOffset + + // append records with the latest leader epoch + partition.appendRecordsToLeader(batch3, origin = AppendOrigin.Client, requiredAcks = 0, requestLocal) + + // fetch from follower not in ISR from log start offset should not add this follower to ISR + updateFollowerFetchState(follower1, LogOffsetMetadata(0)) + updateFollowerFetchState(follower1, LogOffsetMetadata(lastOffsetOfFirstBatch)) + assertEquals(Set[Integer](leader, follower2), partition.isrState.isr, "ISR") + + // fetch from the follower not in ISR from start offset of the current leader epoch should + // add this follower to ISR + updateFollowerFetchState(follower1, LogOffsetMetadata(currentLeaderEpochStartOffset)) + + // Expansion does not affect the ISR + assertEquals(Set[Integer](leader, follower2), partition.isrState.isr, "ISR") + assertEquals(Set[Integer](leader, follower1, follower2), partition.isrState.maximalIsr, "ISR") + assertEquals(alterIsrManager.isrUpdates.head.leaderAndIsr.isr.toSet, + Set(leader, follower1, follower2), "AlterIsr") + } + + def createRecords(records: Iterable[SimpleRecord], baseOffset: Long, partitionLeaderEpoch: Int = 0): MemoryRecords = { + val buf = ByteBuffer.allocate(DefaultRecordBatch.sizeInBytes(records.asJava)) + val builder = 
MemoryRecords.builder( + buf, RecordBatch.CURRENT_MAGIC_VALUE, CompressionType.NONE, TimestampType.LOG_APPEND_TIME, + baseOffset, time.milliseconds, partitionLeaderEpoch) + records.foreach(builder.append) + builder.build() + } + + def createTransactionalRecords(records: Iterable[SimpleRecord], + baseOffset: Long): MemoryRecords = { + val producerId = 1L + val producerEpoch = 0.toShort + val baseSequence = 0 + val isTransactional = true + val buf = ByteBuffer.allocate(DefaultRecordBatch.sizeInBytes(records.asJava)) + val builder = MemoryRecords.builder(buf, CompressionType.NONE, baseOffset, producerId, + producerEpoch, baseSequence, isTransactional) + records.foreach(builder.append) + builder.build() + } + + /** + * Test for AtMinIsr partition state. We set the partition replica set size as 3, but only set one replica as an ISR. + * As the default minIsr configuration is 1, then the partition should be at min ISR (isAtMinIsr = true). + */ + @Test + def testAtMinIsr(): Unit = { + val controllerEpoch = 3 + val leader = brokerId + val follower1 = brokerId + 1 + val follower2 = brokerId + 2 + val replicas = List[Integer](leader, follower1, follower2).asJava + val isr = List[Integer](leader).asJava + val leaderEpoch = 8 + + assertFalse(partition.isAtMinIsr) + // Make isr set to only have leader to trigger AtMinIsr (default min isr config is 1) + val leaderState = new LeaderAndIsrPartitionState() + .setControllerEpoch(controllerEpoch) + .setLeader(leader) + .setLeaderEpoch(leaderEpoch) + .setIsr(isr) + .setZkVersion(1) + .setReplicas(replicas) + .setIsNew(true) + partition.makeLeader(leaderState, offsetCheckpoints, None) + assertTrue(partition.isAtMinIsr) + } + + @Test + def testUpdateFollowerFetchState(): Unit = { + val log = logManager.getOrCreateLog(topicPartition, topicId = None) + seedLogData(log, numRecords = 6, leaderEpoch = 4) + + val controllerEpoch = 0 + val leaderEpoch = 5 + val remoteBrokerId = brokerId + 1 + val replicas = List[Integer](brokerId, remoteBrokerId).asJava + val isr = replicas + + partition.createLogIfNotExists(isNew = false, isFutureReplica = false, offsetCheckpoints, None) + + val initializeTimeMs = time.milliseconds() + assertTrue(partition.makeLeader( + new LeaderAndIsrPartitionState() + .setControllerEpoch(controllerEpoch) + .setLeader(brokerId) + .setLeaderEpoch(leaderEpoch) + .setIsr(isr) + .setZkVersion(1) + .setReplicas(replicas) + .setIsNew(true), + offsetCheckpoints, None), "Expected become leader transition to succeed") + + val remoteReplica = partition.getReplica(remoteBrokerId).get + assertEquals(initializeTimeMs, remoteReplica.lastCaughtUpTimeMs) + assertEquals(LogOffsetMetadata.UnknownOffsetMetadata.messageOffset, remoteReplica.logEndOffset) + assertEquals(UnifiedLog.UnknownOffset, remoteReplica.logStartOffset) + + time.sleep(500) + + partition.updateFollowerFetchState(remoteBrokerId, + followerFetchOffsetMetadata = LogOffsetMetadata(3), + followerStartOffset = 0L, + followerFetchTimeMs = time.milliseconds(), + leaderEndOffset = 6L) + + assertEquals(initializeTimeMs, remoteReplica.lastCaughtUpTimeMs) + assertEquals(3L, remoteReplica.logEndOffset) + assertEquals(0L, remoteReplica.logStartOffset) + + time.sleep(500) + + partition.updateFollowerFetchState(remoteBrokerId, + followerFetchOffsetMetadata = LogOffsetMetadata(6L), + followerStartOffset = 0L, + followerFetchTimeMs = time.milliseconds(), + leaderEndOffset = 6L) + + assertEquals(time.milliseconds(), remoteReplica.lastCaughtUpTimeMs) + assertEquals(6L, remoteReplica.logEndOffset) + assertEquals(0L, 
remoteReplica.logStartOffset) + + } + + @Test + def testIsrExpansion(): Unit = { + val log = logManager.getOrCreateLog(topicPartition, topicId = None) + seedLogData(log, numRecords = 10, leaderEpoch = 4) + + val controllerEpoch = 0 + val leaderEpoch = 5 + val remoteBrokerId = brokerId + 1 + val replicas = List(brokerId, remoteBrokerId) + val isr = List[Integer](brokerId).asJava + + partition.createLogIfNotExists(isNew = false, isFutureReplica = false, offsetCheckpoints, None) + assertTrue(partition.makeLeader( + new LeaderAndIsrPartitionState() + .setControllerEpoch(controllerEpoch) + .setLeader(brokerId) + .setLeaderEpoch(leaderEpoch) + .setIsr(isr) + .setZkVersion(1) + .setReplicas(replicas.map(Int.box).asJava) + .setIsNew(true), + offsetCheckpoints, None), "Expected become leader transition to succeed") + assertEquals(Set(brokerId), partition.isrState.isr) + + val remoteReplica = partition.getReplica(remoteBrokerId).get + assertEquals(LogOffsetMetadata.UnknownOffsetMetadata.messageOffset, remoteReplica.logEndOffset) + assertEquals(UnifiedLog.UnknownOffset, remoteReplica.logStartOffset) + + partition.updateFollowerFetchState(remoteBrokerId, + followerFetchOffsetMetadata = LogOffsetMetadata(3), + followerStartOffset = 0L, + followerFetchTimeMs = time.milliseconds(), + leaderEndOffset = 6L) + + assertEquals(Set(brokerId), partition.isrState.isr) + assertEquals(3L, remoteReplica.logEndOffset) + assertEquals(0L, remoteReplica.logStartOffset) + + partition.updateFollowerFetchState(remoteBrokerId, + followerFetchOffsetMetadata = LogOffsetMetadata(10), + followerStartOffset = 0L, + followerFetchTimeMs = time.milliseconds(), + leaderEndOffset = 6L) + + assertEquals(alterIsrManager.isrUpdates.size, 1) + val isrItem = alterIsrManager.isrUpdates.head + assertEquals(isrItem.leaderAndIsr.isr, List(brokerId, remoteBrokerId)) + assertEquals(Set(brokerId), partition.isrState.isr) + assertEquals(Set(brokerId, remoteBrokerId), partition.isrState.maximalIsr) + assertEquals(10L, remoteReplica.logEndOffset) + assertEquals(0L, remoteReplica.logStartOffset) + + // Complete the ISR expansion + alterIsrManager.completeIsrUpdate(2) + assertEquals(Set(brokerId, remoteBrokerId), partition.isrState.isr) + + assertEquals(isrChangeListener.expands.get, 1) + assertEquals(isrChangeListener.shrinks.get, 0) + assertEquals(isrChangeListener.failures.get, 0) + } + + @Test + def testIsrNotExpandedIfUpdateFails(): Unit = { + val log = logManager.getOrCreateLog(topicPartition, topicId = None) + seedLogData(log, numRecords = 10, leaderEpoch = 4) + + val controllerEpoch = 0 + val leaderEpoch = 5 + val remoteBrokerId = brokerId + 1 + val replicas = List[Integer](brokerId, remoteBrokerId).asJava + val isr = List[Integer](brokerId).asJava + + partition.createLogIfNotExists(isNew = false, isFutureReplica = false, offsetCheckpoints, None) + assertTrue(partition.makeLeader( + new LeaderAndIsrPartitionState() + .setControllerEpoch(controllerEpoch) + .setLeader(brokerId) + .setLeaderEpoch(leaderEpoch) + .setIsr(isr) + .setZkVersion(1) + .setReplicas(replicas) + .setIsNew(true), + offsetCheckpoints, None), "Expected become leader transition to succeed") + assertEquals(Set(brokerId), partition.isrState.isr) + + val remoteReplica = partition.getReplica(remoteBrokerId).get + assertEquals(LogOffsetMetadata.UnknownOffsetMetadata.messageOffset, remoteReplica.logEndOffset) + assertEquals(UnifiedLog.UnknownOffset, remoteReplica.logStartOffset) + + partition.updateFollowerFetchState(remoteBrokerId, + followerFetchOffsetMetadata = 
LogOffsetMetadata(10), + followerStartOffset = 0L, + followerFetchTimeMs = time.milliseconds(), + leaderEndOffset = 10L) + + // Follower state is updated, but the ISR has not expanded + assertEquals(Set(brokerId), partition.inSyncReplicaIds) + assertEquals(Set(brokerId, remoteBrokerId), partition.isrState.maximalIsr) + assertEquals(alterIsrManager.isrUpdates.size, 1) + assertEquals(10L, remoteReplica.logEndOffset) + assertEquals(0L, remoteReplica.logStartOffset) + + // Simulate failure callback + alterIsrManager.failIsrUpdate(Errors.INVALID_UPDATE_VERSION) + + // Still no ISR change + assertEquals(Set(brokerId), partition.inSyncReplicaIds) + assertEquals(Set(brokerId, remoteBrokerId), partition.isrState.maximalIsr) + assertEquals(alterIsrManager.isrUpdates.size, 0) + + assertEquals(isrChangeListener.expands.get, 0) + assertEquals(isrChangeListener.shrinks.get, 0) + assertEquals(isrChangeListener.failures.get, 1) + } + + @Test + def testRetryShrinkIsr(): Unit = { + val log = logManager.getOrCreateLog(topicPartition, topicId = None) + seedLogData(log, numRecords = 10, leaderEpoch = 4) + + val controllerEpoch = 0 + val leaderEpoch = 5 + val remoteBrokerId = brokerId + 1 + val replicas = Seq(brokerId, remoteBrokerId) + val isr = Seq(brokerId, remoteBrokerId) + val topicId = Uuid.randomUuid() + + assertTrue(makeLeader( + topicId = Some(topicId), + controllerEpoch = controllerEpoch, + leaderEpoch = leaderEpoch, + isr = isr, + replicas = replicas, + zkVersion = 1, + isNew = true + )) + assertEquals(0L, partition.localLogOrException.highWatermark) + + // Sleep enough time to shrink the ISR + time.sleep(partition.replicaLagTimeMaxMs + 1) + + // Try to shrink the ISR + partition.maybeShrinkIsr() + assertEquals(alterIsrManager.isrUpdates.size, 1) + assertEquals(alterIsrManager.isrUpdates.head.leaderAndIsr.isr, List(brokerId)) + assertEquals(Set(brokerId, remoteBrokerId), partition.isrState.isr) + assertEquals(Set(brokerId, remoteBrokerId), partition.isrState.maximalIsr) + + // The shrink fails and we retry + alterIsrManager.failIsrUpdate(Errors.NETWORK_EXCEPTION) + assertEquals(0, isrChangeListener.shrinks.get) + assertEquals(1, isrChangeListener.failures.get) + assertEquals(1, partition.getZkVersion) + assertEquals(alterIsrManager.isrUpdates.size, 1) + assertEquals(Set(brokerId, remoteBrokerId), partition.isrState.isr) + assertEquals(Set(brokerId, remoteBrokerId), partition.isrState.maximalIsr) + assertEquals(0L, partition.localLogOrException.highWatermark) + + // The shrink succeeds after retrying + alterIsrManager.completeIsrUpdate(newZkVersion = 2) + assertEquals(1, isrChangeListener.shrinks.get) + assertEquals(2, partition.getZkVersion) + assertEquals(alterIsrManager.isrUpdates.size, 0) + assertEquals(Set(brokerId), partition.isrState.isr) + assertEquals(Set(brokerId), partition.isrState.maximalIsr) + assertEquals(log.logEndOffset, partition.localLogOrException.highWatermark) + } + + @Test + def testMaybeShrinkIsr(): Unit = { + val log = logManager.getOrCreateLog(topicPartition, topicId = None) + seedLogData(log, numRecords = 10, leaderEpoch = 4) + + val controllerEpoch = 0 + val leaderEpoch = 5 + val remoteBrokerId = brokerId + 1 + val replicas = Seq(brokerId, remoteBrokerId) + val isr = Seq(brokerId, remoteBrokerId) + val initializeTimeMs = time.milliseconds() + + assertTrue(makeLeader( + topicId = None, + controllerEpoch = controllerEpoch, + leaderEpoch = leaderEpoch, + isr = isr, + replicas = replicas, + zkVersion = 1, + isNew = true + )) + assertEquals(0L, 
partition.localLogOrException.highWatermark) + + val remoteReplica = partition.getReplica(remoteBrokerId).get + assertEquals(initializeTimeMs, remoteReplica.lastCaughtUpTimeMs) + assertEquals(LogOffsetMetadata.UnknownOffsetMetadata.messageOffset, remoteReplica.logEndOffset) + assertEquals(UnifiedLog.UnknownOffset, remoteReplica.logStartOffset) + + // On initialization, the replica is considered caught up and should not be removed + partition.maybeShrinkIsr() + assertEquals(Set(brokerId, remoteBrokerId), partition.isrState.isr) + + // If enough time passes without a fetch update, the ISR should shrink + time.sleep(partition.replicaLagTimeMaxMs + 1) + + // Shrink the ISR + partition.maybeShrinkIsr() + assertEquals(0, isrChangeListener.shrinks.get) + assertEquals(alterIsrManager.isrUpdates.size, 1) + assertEquals(alterIsrManager.isrUpdates.head.leaderAndIsr.isr, List(brokerId)) + assertEquals(Set(brokerId, remoteBrokerId), partition.isrState.isr) + assertEquals(Set(brokerId, remoteBrokerId), partition.isrState.maximalIsr) + assertEquals(0L, partition.localLogOrException.highWatermark) + + // After the ISR shrink completes, the ISR state should be updated and the + // high watermark should be advanced + alterIsrManager.completeIsrUpdate(newZkVersion = 2) + assertEquals(1, isrChangeListener.shrinks.get) + assertEquals(2, partition.getZkVersion) + assertEquals(alterIsrManager.isrUpdates.size, 0) + assertEquals(Set(brokerId), partition.isrState.isr) + assertEquals(Set(brokerId), partition.isrState.maximalIsr) + assertEquals(log.logEndOffset, partition.localLogOrException.highWatermark) + } + + @Test + def testAlterIsrLeaderAndIsrRace(): Unit = { + val log = logManager.getOrCreateLog(topicPartition, topicId = None) + seedLogData(log, numRecords = 10, leaderEpoch = 4) + + val controllerEpoch = 0 + val leaderEpoch = 5 + val remoteBrokerId = brokerId + 1 + val replicas = Seq(brokerId, remoteBrokerId) + val isr = Seq(brokerId, remoteBrokerId) + val initializeTimeMs = time.milliseconds() + + assertTrue(makeLeader( + topicId = None, + controllerEpoch = controllerEpoch, + leaderEpoch = leaderEpoch, + isr = isr, + replicas = replicas, + zkVersion = 1, + isNew = true + )) + assertEquals(0L, partition.localLogOrException.highWatermark) + + val remoteReplica = partition.getReplica(remoteBrokerId).get + assertEquals(initializeTimeMs, remoteReplica.lastCaughtUpTimeMs) + assertEquals(LogOffsetMetadata.UnknownOffsetMetadata.messageOffset, remoteReplica.logEndOffset) + assertEquals(UnifiedLog.UnknownOffset, remoteReplica.logStartOffset) + + // Shrink the ISR + time.sleep(partition.replicaLagTimeMaxMs + 1) + partition.maybeShrinkIsr() + assertTrue(partition.isrState.isInflight) + + // Become leader again, reset the ISR state + assertFalse(makeLeader( + topicId = None, + controllerEpoch = controllerEpoch, + leaderEpoch = leaderEpoch, + isr = isr, + replicas = replicas, + zkVersion = 2, + isNew = false + )) + assertEquals(0L, partition.localLogOrException.highWatermark) + assertFalse(partition.isrState.isInflight, "ISR should be committed and not inflight") + + // Try the shrink again, should not submit until AlterIsr response arrives + time.sleep(partition.replicaLagTimeMaxMs + 1) + partition.maybeShrinkIsr() + assertFalse(partition.isrState.isInflight, "ISR should still be committed and not inflight") + + // Complete the AlterIsr update and now we can make modifications again + alterIsrManager.completeIsrUpdate(10) + partition.maybeShrinkIsr() + assertTrue(partition.isrState.isInflight, "ISR should be pending a 
shrink") + } + + @Test + def testShouldNotShrinkIsrIfPreviousFetchIsCaughtUp(): Unit = { + val log = logManager.getOrCreateLog(topicPartition, topicId = None) + seedLogData(log, numRecords = 10, leaderEpoch = 4) + + val controllerEpoch = 0 + val leaderEpoch = 5 + val remoteBrokerId = brokerId + 1 + val replicas = Seq(brokerId, remoteBrokerId) + val isr = Seq(brokerId, remoteBrokerId) + val initializeTimeMs = time.milliseconds() + + assertTrue(makeLeader( + topicId = None, + controllerEpoch = controllerEpoch, + leaderEpoch = leaderEpoch, + isr = isr, + replicas = replicas, + zkVersion = 1, + isNew = true + )) + assertEquals(0L, partition.localLogOrException.highWatermark) + + val remoteReplica = partition.getReplica(remoteBrokerId).get + assertEquals(initializeTimeMs, remoteReplica.lastCaughtUpTimeMs) + assertEquals(LogOffsetMetadata.UnknownOffsetMetadata.messageOffset, remoteReplica.logEndOffset) + assertEquals(UnifiedLog.UnknownOffset, remoteReplica.logStartOffset) + + // There is a short delay before the first fetch. The follower is not yet caught up to the log end. + time.sleep(5000) + val firstFetchTimeMs = time.milliseconds() + partition.updateFollowerFetchState(remoteBrokerId, + followerFetchOffsetMetadata = LogOffsetMetadata(5), + followerStartOffset = 0L, + followerFetchTimeMs = firstFetchTimeMs, + leaderEndOffset = 10L) + assertEquals(initializeTimeMs, remoteReplica.lastCaughtUpTimeMs) + assertEquals(5L, partition.localLogOrException.highWatermark) + assertEquals(5L, remoteReplica.logEndOffset) + assertEquals(0L, remoteReplica.logStartOffset) + + // Some new data is appended, but the follower catches up to the old end offset. + // The total elapsed time from initialization is larger than the max allowed replica lag. + time.sleep(5001) + seedLogData(log, numRecords = 5, leaderEpoch = leaderEpoch) + partition.updateFollowerFetchState(remoteBrokerId, + followerFetchOffsetMetadata = LogOffsetMetadata(10), + followerStartOffset = 0L, + followerFetchTimeMs = time.milliseconds(), + leaderEndOffset = 15L) + assertEquals(firstFetchTimeMs, remoteReplica.lastCaughtUpTimeMs) + assertEquals(10L, partition.localLogOrException.highWatermark) + assertEquals(10L, remoteReplica.logEndOffset) + assertEquals(0L, remoteReplica.logStartOffset) + + // The ISR should not be shrunk because the follower has caught up with the leader at the + // time of the first fetch. 
+ partition.maybeShrinkIsr() + assertEquals(Set(brokerId, remoteBrokerId), partition.isrState.isr) + assertEquals(alterIsrManager.isrUpdates.size, 0) + } + + @Test + def testShouldNotShrinkIsrIfFollowerCaughtUpToLogEnd(): Unit = { + val log = logManager.getOrCreateLog(topicPartition, topicId = None) + seedLogData(log, numRecords = 10, leaderEpoch = 4) + + val controllerEpoch = 0 + val leaderEpoch = 5 + val remoteBrokerId = brokerId + 1 + val replicas = Seq(brokerId, remoteBrokerId) + val isr = Seq(brokerId, remoteBrokerId) + val initializeTimeMs = time.milliseconds() + + assertTrue(makeLeader( + topicId = None, + controllerEpoch = controllerEpoch, + leaderEpoch = leaderEpoch, + isr = isr, + replicas = replicas, + zkVersion = 1, + isNew = true + )) + assertEquals(0L, partition.localLogOrException.highWatermark) + + val remoteReplica = partition.getReplica(remoteBrokerId).get + assertEquals(initializeTimeMs, remoteReplica.lastCaughtUpTimeMs) + assertEquals(LogOffsetMetadata.UnknownOffsetMetadata.messageOffset, remoteReplica.logEndOffset) + assertEquals(UnifiedLog.UnknownOffset, remoteReplica.logStartOffset) + + // The follower catches up to the log end immediately. + partition.updateFollowerFetchState(remoteBrokerId, + followerFetchOffsetMetadata = LogOffsetMetadata(10), + followerStartOffset = 0L, + followerFetchTimeMs = time.milliseconds(), + leaderEndOffset = 10L) + assertEquals(initializeTimeMs, remoteReplica.lastCaughtUpTimeMs) + assertEquals(10L, partition.localLogOrException.highWatermark) + assertEquals(10L, remoteReplica.logEndOffset) + assertEquals(0L, remoteReplica.logStartOffset) + + // Sleep longer than the max allowed follower lag + time.sleep(30001) + + // The ISR should not be shrunk because the follower is caught up to the leader's log end + partition.maybeShrinkIsr() + assertEquals(Set(brokerId, remoteBrokerId), partition.isrState.isr) + assertEquals(alterIsrManager.isrUpdates.size, 0) + } + + @Test + def testIsrNotShrunkIfUpdateFails(): Unit = { + val log = logManager.getOrCreateLog(topicPartition, topicId = None) + seedLogData(log, numRecords = 10, leaderEpoch = 4) + + val controllerEpoch = 0 + val leaderEpoch = 5 + val remoteBrokerId = brokerId + 1 + val replicas = Seq(brokerId, remoteBrokerId) + val isr = Seq(brokerId, remoteBrokerId) + val initializeTimeMs = time.milliseconds() + + assertTrue(makeLeader( + topicId = None, + controllerEpoch = controllerEpoch, + leaderEpoch = leaderEpoch, + isr = isr, + replicas = replicas, + zkVersion = 1, + isNew = true + )) + assertEquals(0L, partition.localLogOrException.highWatermark) + + val remoteReplica = partition.getReplica(remoteBrokerId).get + assertEquals(initializeTimeMs, remoteReplica.lastCaughtUpTimeMs) + assertEquals(LogOffsetMetadata.UnknownOffsetMetadata.messageOffset, remoteReplica.logEndOffset) + assertEquals(UnifiedLog.UnknownOffset, remoteReplica.logStartOffset) + + time.sleep(30001) + + // Enqueue an AlterIsr that will fail + partition.maybeShrinkIsr() + assertEquals(Set(brokerId, remoteBrokerId), partition.inSyncReplicaIds) + assertEquals(alterIsrManager.isrUpdates.size, 1) + assertEquals(0L, partition.localLogOrException.highWatermark) + + // Simulate failure callback + alterIsrManager.failIsrUpdate(Errors.INVALID_UPDATE_VERSION) + + // Ensure ISR hasn't changed + assertEquals(partition.isrState.getClass, classOf[PendingShrinkIsr]) + assertEquals(Set(brokerId, remoteBrokerId), partition.inSyncReplicaIds) + assertEquals(alterIsrManager.isrUpdates.size, 0) + assertEquals(0L, 
partition.localLogOrException.highWatermark) + } + + @Test + def testAlterIsrUnknownTopic(): Unit = { + handleAlterIsrFailure(Errors.UNKNOWN_TOPIC_OR_PARTITION, + (brokerId: Int, remoteBrokerId: Int, partition: Partition) => { + assertEquals(partition.isrState.isr, Set(brokerId)) + assertEquals(partition.isrState.maximalIsr, Set(brokerId, remoteBrokerId)) + assertEquals(alterIsrManager.isrUpdates.size, 0) + }) + } + + @Test + def testAlterIsrInvalidVersion(): Unit = { + handleAlterIsrFailure(Errors.INVALID_UPDATE_VERSION, + (brokerId: Int, remoteBrokerId: Int, partition: Partition) => { + assertEquals(partition.isrState.isr, Set(brokerId)) + assertEquals(partition.isrState.maximalIsr, Set(brokerId, remoteBrokerId)) + assertEquals(alterIsrManager.isrUpdates.size, 0) + }) + } + + @Test + def testAlterIsrUnexpectedError(): Unit = { + handleAlterIsrFailure(Errors.UNKNOWN_SERVER_ERROR, + (brokerId: Int, remoteBrokerId: Int, partition: Partition) => { + // We retry these + assertEquals(partition.isrState.isr, Set(brokerId)) + assertEquals(partition.isrState.maximalIsr, Set(brokerId, remoteBrokerId)) + assertEquals(alterIsrManager.isrUpdates.size, 1) + }) + } + + def handleAlterIsrFailure(error: Errors, callback: (Int, Int, Partition) => Unit): Unit = { + val log = logManager.getOrCreateLog(topicPartition, topicId = None) + seedLogData(log, numRecords = 10, leaderEpoch = 4) + + val controllerEpoch = 0 + val leaderEpoch = 5 + val remoteBrokerId = brokerId + 1 + val replicas = Seq(brokerId, remoteBrokerId) + val isr = Seq(brokerId) + + assertTrue(makeLeader( + topicId = None, + controllerEpoch = controllerEpoch, + leaderEpoch = leaderEpoch, + isr = isr, + replicas = replicas, + zkVersion = 1, + isNew = true + )) + assertEquals(10L, partition.localLogOrException.highWatermark) + + val remoteReplica = partition.getReplica(remoteBrokerId).get + assertEquals(LogOffsetMetadata.UnknownOffsetMetadata.messageOffset, remoteReplica.logEndOffset) + assertEquals(UnifiedLog.UnknownOffset, remoteReplica.logStartOffset) + + // This will attempt to expand the ISR + partition.updateFollowerFetchState(remoteBrokerId, + followerFetchOffsetMetadata = LogOffsetMetadata(10), + followerStartOffset = 0L, + followerFetchTimeMs = time.milliseconds(), + leaderEndOffset = 10L) + + // Follower state is updated, but the ISR has not expanded + assertEquals(Set(brokerId), partition.inSyncReplicaIds) + assertEquals(Set(brokerId, remoteBrokerId), partition.isrState.maximalIsr) + assertEquals(alterIsrManager.isrUpdates.size, 1) + assertEquals(10L, remoteReplica.logEndOffset) + assertEquals(0L, remoteReplica.logStartOffset) + + // Failure + alterIsrManager.failIsrUpdate(error) + callback(brokerId, remoteBrokerId, partition) + } + + @Test + def testSingleInFlightAlterIsr(): Unit = { + val log = logManager.getOrCreateLog(topicPartition, topicId = None) + seedLogData(log, numRecords = 10, leaderEpoch = 4) + + val controllerEpoch = 0 + val leaderEpoch = 5 + val follower1 = brokerId + 1 + val follower2 = brokerId + 2 + val follower3 = brokerId + 3 + val replicas = Seq(brokerId, follower1, follower2, follower3) + val isr = Seq(brokerId, follower1, follower2) + + doNothing().when(delayedOperations).checkAndCompleteAll() + + assertTrue(makeLeader( + topicId = None, + controllerEpoch = controllerEpoch, + leaderEpoch = leaderEpoch, + isr = isr, + replicas = replicas, + zkVersion = 1, + isNew = true + )) + assertEquals(0L, partition.localLogOrException.highWatermark) + + // Expand ISR + partition.updateFollowerFetchState( + followerId = 
follower3, + followerFetchOffsetMetadata = LogOffsetMetadata(10), + followerStartOffset = 0L, + followerFetchTimeMs = time.milliseconds(), + leaderEndOffset = 10 + ) + assertEquals(Set(brokerId, follower1, follower2), partition.isrState.isr) + assertEquals(Set(brokerId, follower1, follower2, follower3), partition.isrState.maximalIsr) + + // One AlterIsr request in-flight + assertEquals(alterIsrManager.isrUpdates.size, 1) + + // Try to modify ISR again, should do nothing + time.sleep(partition.replicaLagTimeMaxMs + 1) + partition.maybeShrinkIsr() + assertEquals(alterIsrManager.isrUpdates.size, 1) + } + + @Test + def testZkIsrManagerAsyncCallback(): Unit = { + // We need a real scheduler here so that the ISR write lock works properly + val scheduler = new KafkaScheduler(1, "zk-isr-test") + scheduler.startup() + val kafkaZkClient = mock(classOf[KafkaZkClient]) + + doAnswer(_ => (true, 2)) + .when(kafkaZkClient) + .conditionalUpdatePath(anyString(), any(), ArgumentMatchers.eq(1), any()) + + val zkIsrManager = AlterIsrManager(scheduler, time, kafkaZkClient) + zkIsrManager.start() + + val partition = new Partition(topicPartition, + replicaLagTimeMaxMs = Defaults.ReplicaLagTimeMaxMs, + interBrokerProtocolVersion = KAFKA_2_6_IV0, // shouldn't matter, but set this to a ZK isr version + localBrokerId = brokerId, + time, + isrChangeListener, + delayedOperations, + metadataCache, + logManager, + zkIsrManager) + + val log = logManager.getOrCreateLog(topicPartition, topicId = None) + seedLogData(log, numRecords = 10, leaderEpoch = 4) + + val controllerEpoch = 0 + val leaderEpoch = 5 + val follower1 = brokerId + 1 + val follower2 = brokerId + 2 + val follower3 = brokerId + 3 + val replicas = Seq(brokerId, follower1, follower2, follower3) + val isr = Seq(brokerId, follower1, follower2) + + doNothing().when(delayedOperations).checkAndCompleteAll() + + assertTrue(makeLeader( + partition = partition, + topicId = None, + controllerEpoch = controllerEpoch, + leaderEpoch = leaderEpoch, + isr = isr, + replicas = replicas, + zkVersion = 1, + isNew = true + )) + assertEquals(0L, partition.localLogOrException.highWatermark) + + // Expand ISR + partition.updateFollowerFetchState( + followerId = follower3, + followerFetchOffsetMetadata = LogOffsetMetadata(10), + followerStartOffset = 0L, + followerFetchTimeMs = time.milliseconds(), + leaderEndOffset = 10 + ) + + // Try avoiding a race + TestUtils.waitUntilTrue(() => !partition.isrState.isInflight, "Expected ISR state to be committed", 100) + + partition.isrState match { + case committed: CommittedIsr => assertEquals(Set(brokerId, follower1, follower2, follower3), committed.isr) + case _ => fail("Expected a committed ISR following Zk expansion") + } + + scheduler.shutdown() + } + + @Test + def testUseCheckpointToInitializeHighWatermark(): Unit = { + val log = logManager.getOrCreateLog(topicPartition, topicId = None) + seedLogData(log, numRecords = 6, leaderEpoch = 5) + + when(offsetCheckpoints.fetch(logDir1.getAbsolutePath, topicPartition)) + .thenReturn(Some(4L)) + + val controllerEpoch = 3 + val replicas = List[Integer](brokerId, brokerId + 1).asJava + val leaderState = new LeaderAndIsrPartitionState() + .setControllerEpoch(controllerEpoch) + .setLeader(brokerId) + .setLeaderEpoch(6) + .setIsr(replicas) + .setZkVersion(1) + .setReplicas(replicas) + .setIsNew(false) + partition.makeLeader(leaderState, offsetCheckpoints, None) + assertEquals(4, partition.localLogOrException.highWatermark) + } + + @Test + def testTopicIdAndPartitionMetadataFileForLeader(): Unit = { + 
val controllerEpoch = 3 + val leaderEpoch = 5 + val topicId = Uuid.randomUuid() + val replicas = List[Integer](brokerId, brokerId + 1).asJava + val leaderState = new LeaderAndIsrPartitionState() + .setControllerEpoch(controllerEpoch) + .setLeader(brokerId) + .setLeaderEpoch(leaderEpoch) + .setIsr(replicas) + .setZkVersion(1) + .setReplicas(replicas) + .setIsNew(false) + partition.makeLeader(leaderState, offsetCheckpoints, Some(topicId)) + + checkTopicId(topicId, partition) + + // Create new Partition object for same topicPartition + val partition2 = new Partition(topicPartition, + replicaLagTimeMaxMs = Defaults.ReplicaLagTimeMaxMs, + interBrokerProtocolVersion = ApiVersion.latestVersion, + localBrokerId = brokerId, + time, + isrChangeListener, + delayedOperations, + metadataCache, + logManager, + alterIsrManager) + + // partition2 should not yet be associated with the log, but should be able to get ID + assertTrue(partition2.topicId.isDefined) + assertEquals(topicId, partition2.topicId.get) + assertFalse(partition2.log.isDefined) + + // Calling makeLeader with a new topic ID should not overwrite the old topic ID. We should get an InconsistentTopicIdException. + // This scenario should not occur, since the topic ID check will fail. + assertThrows(classOf[InconsistentTopicIdException], () => partition2.makeLeader(leaderState, offsetCheckpoints, Some(Uuid.randomUuid()))) + + // Calling makeLeader with no topic ID should not overwrite the old topic ID. We should get the original log. + partition2.makeLeader(leaderState, offsetCheckpoints, None) + checkTopicId(topicId, partition2) + } + + @Test + def testTopicIdAndPartitionMetadataFileForFollower(): Unit = { + val controllerEpoch = 3 + val leaderEpoch = 5 + val topicId = Uuid.randomUuid() + val replicas = List[Integer](brokerId, brokerId + 1).asJava + val leaderState = new LeaderAndIsrPartitionState() + .setControllerEpoch(controllerEpoch) + .setLeader(brokerId) + .setLeaderEpoch(leaderEpoch) + .setIsr(replicas) + .setZkVersion(1) + .setReplicas(replicas) + .setIsNew(false) + partition.makeFollower(leaderState, offsetCheckpoints, Some(topicId)) + + checkTopicId(topicId, partition) + + // Create new Partition object for same topicPartition + val partition2 = new Partition(topicPartition, + replicaLagTimeMaxMs = Defaults.ReplicaLagTimeMaxMs, + interBrokerProtocolVersion = ApiVersion.latestVersion, + localBrokerId = brokerId, + time, + isrChangeListener, + delayedOperations, + metadataCache, + logManager, + alterIsrManager) + + // partition2 should not yet be associated with the log, but should be able to get ID + assertTrue(partition2.topicId.isDefined) + assertEquals(topicId, partition2.topicId.get) + assertFalse(partition2.log.isDefined) + + // Calling makeFollower with a new topic ID should not overwrite the old topic ID. We should get an InconsistentTopicIdException. + // This scenario should not occur, since the topic ID check will fail. + assertThrows(classOf[InconsistentTopicIdException], () => partition2.makeFollower(leaderState, offsetCheckpoints, Some(Uuid.randomUuid()))) + + // Calling makeFollower with no topic ID should not overwrite the old topic ID. We should get the original log. 
+ partition2.makeFollower(leaderState, offsetCheckpoints, None) + checkTopicId(topicId, partition2) + } + + def checkTopicId(expectedTopicId: Uuid, partition: Partition): Unit = { + assertTrue(partition.topicId.isDefined) + assertEquals(expectedTopicId, partition.topicId.get) + assertTrue(partition.log.isDefined) + val log = partition.log.get + assertEquals(expectedTopicId, log.topicId.get) + assertTrue(log.partitionMetadataFile.exists()) + assertEquals(expectedTopicId, log.partitionMetadataFile.read().topicId) + } + + @Test + def testAddAndRemoveMetrics(): Unit = { + val metricsToCheck = List( + "UnderReplicated", + "UnderMinIsr", + "InSyncReplicasCount", + "ReplicasCount", + "LastStableOffsetLag", + "AtMinIsr") + + def getMetric(metric: String): Option[Metric] = { + KafkaYammerMetrics.defaultRegistry().allMetrics().asScala.filter { case (metricName, _) => + metricName.getName == metric && metricName.getType == "Partition" + }.headOption.map(_._2) + } + + assertTrue(metricsToCheck.forall(getMetric(_).isDefined)) + + Partition.removeMetrics(topicPartition) + + assertEquals(Set(), KafkaYammerMetrics.defaultRegistry().allMetrics().asScala.keySet.filter(_.getType == "Partition")) + } + + @Test + def testUnderReplicatedPartitionsCorrectSemantics(): Unit = { + val controllerEpoch = 3 + val replicas = List[Integer](brokerId, brokerId + 1, brokerId + 2).asJava + val isr = List[Integer](brokerId, brokerId + 1).asJava + + var leaderState = new LeaderAndIsrPartitionState() + .setControllerEpoch(controllerEpoch) + .setLeader(brokerId) + .setLeaderEpoch(6) + .setIsr(isr) + .setZkVersion(1) + .setReplicas(replicas) + .setIsNew(false) + partition.makeLeader(leaderState, offsetCheckpoints, None) + assertTrue(partition.isUnderReplicated) + + leaderState = leaderState.setIsr(replicas) + partition.makeLeader(leaderState, offsetCheckpoints, None) + assertFalse(partition.isUnderReplicated) + } + + @Test + def testUpdateAssignmentAndIsr(): Unit = { + val topicPartition = new TopicPartition("test", 1) + val partition = new Partition( + topicPartition, 1000, ApiVersion.latestVersion, 0, + new SystemTime(), mock(classOf[IsrChangeListener]), mock(classOf[DelayedOperations]), + mock(classOf[MetadataCache]), mock(classOf[LogManager]), mock(classOf[AlterIsrManager])) + + val replicas = Seq(0, 1, 2, 3) + val isr = Set(0, 1, 2, 3) + val adding = Seq(4, 5) + val removing = Seq(1, 2) + + // Test with ongoing reassignment + partition.updateAssignmentAndIsr(replicas, isr, adding, removing) + + assertTrue(partition.assignmentState.isInstanceOf[OngoingReassignmentState], "The assignmentState is not OngoingReassignmentState") + assertEquals(replicas, partition.assignmentState.replicas) + assertEquals(isr, partition.isrState.isr) + assertEquals(adding, partition.assignmentState.asInstanceOf[OngoingReassignmentState].addingReplicas) + assertEquals(removing, partition.assignmentState.asInstanceOf[OngoingReassignmentState].removingReplicas) + assertEquals(Seq(1, 2, 3), partition.remoteReplicas.map(_.brokerId)) + + // Test with simple assignment + val replicas2 = Seq(0, 3, 4, 5) + val isr2 = Set(0, 3, 4, 5) + partition.updateAssignmentAndIsr(replicas2, isr2, Seq.empty, Seq.empty) + + assertTrue(partition.assignmentState.isInstanceOf[SimpleAssignmentState], "The assignmentState is not SimpleAssignmentState") + assertEquals(replicas2, partition.assignmentState.replicas) + assertEquals(isr2, partition.isrState.isr) + assertEquals(Seq(3, 4, 5), partition.remoteReplicas.map(_.brokerId)) + } + + /** + * Test when log is getting 
initialized, its config remains untouched after initialization is done. + */ + @Test + def testLogConfigNotDirty(): Unit = { + logManager.shutdown() + val spyConfigRepository = spy(configRepository) + logManager = TestUtils.createLogManager( + logDirs = Seq(logDir1, logDir2), defaultConfig = logConfig, configRepository = spyConfigRepository, + cleanerConfig = CleanerConfig(enableCleaner = false), time = time) + val spyLogManager = spy(logManager) + val partition = new Partition(topicPartition, + replicaLagTimeMaxMs = Defaults.ReplicaLagTimeMaxMs, + interBrokerProtocolVersion = ApiVersion.latestVersion, + localBrokerId = brokerId, + time, + isrChangeListener, + delayedOperations, + metadataCache, + spyLogManager, + alterIsrManager) + + partition.createLog(isNew = true, isFutureReplica = false, offsetCheckpoints, topicId = None) + + // Validate that initializingLog and finishedInitializingLog were called + verify(spyLogManager).initializingLog(ArgumentMatchers.eq(topicPartition)) + verify(spyLogManager).finishedInitializingLog(ArgumentMatchers.eq(topicPartition), ArgumentMatchers.any()) + + // We should retrieve configs only once + verify(spyConfigRepository, times(1)).topicConfig(topicPartition.topic()) + } + + /** + * Test when log is getting initialized, its config gets reloaded if the topic config is changed + * before initialization is done. + */ + @Test + def testLogConfigDirtyAsTopicUpdated(): Unit = { + logManager.shutdown() + val spyConfigRepository = spy(configRepository) + logManager = TestUtils.createLogManager( + logDirs = Seq(logDir1, logDir2), defaultConfig = logConfig, configRepository = spyConfigRepository, + cleanerConfig = CleanerConfig(enableCleaner = false), time = time) + val spyLogManager = spy(logManager) + doAnswer((_: InvocationOnMock) => { + logManager.initializingLog(topicPartition) + logManager.topicConfigUpdated(topicPartition.topic()) + }).when(spyLogManager).initializingLog(ArgumentMatchers.eq(topicPartition)) + + val partition = new Partition(topicPartition, + replicaLagTimeMaxMs = Defaults.ReplicaLagTimeMaxMs, + interBrokerProtocolVersion = ApiVersion.latestVersion, + localBrokerId = brokerId, + time, + isrChangeListener, + delayedOperations, + metadataCache, + spyLogManager, + alterIsrManager) + + partition.createLog(isNew = true, isFutureReplica = false, offsetCheckpoints, topicId = None) + + // Validate that initializingLog and finishedInitializingLog were called + verify(spyLogManager).initializingLog(ArgumentMatchers.eq(topicPartition)) + verify(spyLogManager).finishedInitializingLog(ArgumentMatchers.eq(topicPartition), ArgumentMatchers.any()) + + // We should retrieve configs twice: once before the log is created, and a second time when + // we find the log config is dirty and refresh it. + verify(spyConfigRepository, times(2)).topicConfig(topicPartition.topic()) + } + + /** + * Test when log is getting initialized, its config gets reloaded if the broker config is changed + * before initialization is done. 
+ */ + @Test + def testLogConfigDirtyAsBrokerUpdated(): Unit = { + logManager.shutdown() + val spyConfigRepository = spy(configRepository) + logManager = TestUtils.createLogManager( + logDirs = Seq(logDir1, logDir2), defaultConfig = logConfig, configRepository = spyConfigRepository, + cleanerConfig = CleanerConfig(enableCleaner = false), time = time) + logManager.startup(Set.empty) + + val spyLogManager = spy(logManager) + doAnswer((_: InvocationOnMock) => { + logManager.initializingLog(topicPartition) + logManager.brokerConfigUpdated() + }).when(spyLogManager).initializingLog(ArgumentMatchers.eq(topicPartition)) + + val partition = new Partition(topicPartition, + replicaLagTimeMaxMs = Defaults.ReplicaLagTimeMaxMs, + interBrokerProtocolVersion = ApiVersion.latestVersion, + localBrokerId = brokerId, + time, + isrChangeListener, + delayedOperations, + metadataCache, + spyLogManager, + alterIsrManager) + + partition.createLog(isNew = true, isFutureReplica = false, offsetCheckpoints, topicId = None) + + // Validate that initializingLog and finishedInitializingLog were called + verify(spyLogManager).initializingLog(ArgumentMatchers.eq(topicPartition)) + verify(spyLogManager).finishedInitializingLog(ArgumentMatchers.eq(topicPartition), ArgumentMatchers.any()) + + // We should get configs twice: once before the log is created, and a second time when + // we find the log config is dirty and refresh it. + verify(spyConfigRepository, times(2)).topicConfig(topicPartition.topic()) + } + + private def makeLeader( + topicId: Option[Uuid], + controllerEpoch: Int, + leaderEpoch: Int, + isr: Seq[Int], + replicas: Seq[Int], + zkVersion: Int, + isNew: Boolean, + partition: Partition = partition + ): Boolean = { + partition.createLogIfNotExists( + isNew = isNew, + isFutureReplica = false, + offsetCheckpoints, + topicId + ) + val newLeader = partition.makeLeader( + new LeaderAndIsrPartitionState() + .setControllerEpoch(controllerEpoch) + .setLeader(brokerId) + .setLeaderEpoch(leaderEpoch) + .setIsr(isr.map(Int.box).asJava) + .setZkVersion(zkVersion) + .setReplicas(replicas.map(Int.box).asJava) + .setIsNew(isNew), + offsetCheckpoints, + topicId + ) + assertTrue(partition.isLeader) + assertFalse(partition.isrState.isInflight) + assertEquals(topicId, partition.topicId) + assertEquals(leaderEpoch, partition.getLeaderEpoch) + assertEquals(isr.toSet, partition.isrState.isr) + assertEquals(isr.toSet, partition.isrState.maximalIsr) + assertEquals(zkVersion, partition.getZkVersion) + newLeader + } + + private def seedLogData(log: UnifiedLog, numRecords: Int, leaderEpoch: Int): Unit = { + for (i <- 0 until numRecords) { + val records = MemoryRecords.withRecords(0L, CompressionType.NONE, leaderEpoch, + new SimpleRecord(s"k$i".getBytes, s"v$i".getBytes)) + log.appendAsLeader(records, leaderEpoch) + } + } + + private class SlowLog( + log: UnifiedLog, + logStartOffset: Long, + localLog: LocalLog, + leaderEpochCache: Option[LeaderEpochFileCache], + producerStateManager: ProducerStateManager, + appendSemaphore: Semaphore + ) extends UnifiedLog( + logStartOffset, + localLog, + new BrokerTopicStats, + log.producerIdExpirationCheckIntervalMs, + leaderEpochCache, + producerStateManager, + _topicId = None, + keepPartitionMetadataFile = true) { + + override def appendAsFollower(records: MemoryRecords): LogAppendInfo = { + appendSemaphore.acquire() + val appendInfo = super.appendAsFollower(records) + appendInfo + } + } +} diff --git a/core/src/test/scala/unit/kafka/cluster/PartitionWithLegacyMessageFormatTest.scala 
b/core/src/test/scala/unit/kafka/cluster/PartitionWithLegacyMessageFormatTest.scala new file mode 100644 index 0000000..50b10fa --- /dev/null +++ b/core/src/test/scala/unit/kafka/cluster/PartitionWithLegacyMessageFormatTest.scala @@ -0,0 +1,64 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package kafka.cluster + +import kafka.api.{ApiVersion, KAFKA_2_8_IV1} +import kafka.log.LogConfig +import kafka.utils.TestUtils +import org.apache.kafka.common.record.{RecordVersion, SimpleRecord} +import org.apache.kafka.common.requests.OffsetsForLeaderEpochResponse.{UNDEFINED_EPOCH, UNDEFINED_EPOCH_OFFSET} +import org.junit.jupiter.api.Assertions.assertEquals +import org.junit.jupiter.api.Test + +import java.util.Optional +import scala.annotation.nowarn + +class PartitionWithLegacyMessageFormatTest extends AbstractPartitionTest { + + // legacy message formats are only supported with IBP < 3.0 + override protected def interBrokerProtocolVersion: ApiVersion = KAFKA_2_8_IV1 + + @nowarn("cat=deprecation") + @Test + def testMakeLeaderDoesNotUpdateEpochCacheForOldFormats(): Unit = { + val leaderEpoch = 8 + configRepository.setTopicConfig(topicPartition.topic(), + LogConfig.MessageFormatVersionProp, kafka.api.KAFKA_0_10_2_IV0.shortVersion) + val log = logManager.getOrCreateLog(topicPartition, topicId = None) + log.appendAsLeader(TestUtils.records(List( + new SimpleRecord("k1".getBytes, "v1".getBytes), + new SimpleRecord("k2".getBytes, "v2".getBytes)), + magicValue = RecordVersion.V1.value + ), leaderEpoch = 0) + log.appendAsLeader(TestUtils.records(List( + new SimpleRecord("k3".getBytes, "v3".getBytes), + new SimpleRecord("k4".getBytes, "v4".getBytes)), + magicValue = RecordVersion.V1.value + ), leaderEpoch = 5) + assertEquals(4, log.logEndOffset) + + val partition = setupPartitionWithMocks(leaderEpoch = leaderEpoch, isLeader = true) + assertEquals(Some(4), partition.leaderLogIfLocal.map(_.logEndOffset)) + assertEquals(None, log.latestEpoch) + + val epochEndOffset = partition.lastOffsetForLeaderEpoch(currentLeaderEpoch = Optional.of(leaderEpoch), + leaderEpoch = leaderEpoch, fetchOnlyFromLeader = true) + assertEquals(UNDEFINED_EPOCH_OFFSET, epochEndOffset.endOffset) + assertEquals(UNDEFINED_EPOCH, epochEndOffset.leaderEpoch) + } + +} diff --git a/core/src/test/scala/unit/kafka/cluster/ReplicaTest.scala b/core/src/test/scala/unit/kafka/cluster/ReplicaTest.scala new file mode 100644 index 0000000..08d0950 --- /dev/null +++ b/core/src/test/scala/unit/kafka/cluster/ReplicaTest.scala @@ -0,0 +1,129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package kafka.cluster + +import java.util.Properties + +import kafka.log.{ClientRecordDeletion, UnifiedLog, LogConfig, LogManager} +import kafka.server.{BrokerTopicStats, LogDirFailureChannel} +import kafka.utils.{MockTime, TestUtils} +import org.apache.kafka.common.errors.OffsetOutOfRangeException +import org.apache.kafka.common.utils.Utils +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.{AfterEach, BeforeEach, Test} + +class ReplicaTest { + + val tmpDir = TestUtils.tempDir() + val logDir = TestUtils.randomPartitionLogDir(tmpDir) + val time = new MockTime() + val brokerTopicStats = new BrokerTopicStats + var log: UnifiedLog = _ + + @BeforeEach + def setup(): Unit = { + val logProps = new Properties() + logProps.put(LogConfig.SegmentBytesProp, 512: java.lang.Integer) + logProps.put(LogConfig.SegmentIndexBytesProp, 1000: java.lang.Integer) + logProps.put(LogConfig.RetentionMsProp, 999: java.lang.Integer) + val config = LogConfig(logProps) + log = UnifiedLog(logDir, + config, + logStartOffset = 0L, + recoveryPoint = 0L, + scheduler = time.scheduler, + brokerTopicStats = brokerTopicStats, + time = time, + maxProducerIdExpirationMs = 60 * 60 * 1000, + producerIdExpirationCheckIntervalMs = LogManager.ProducerIdExpirationCheckIntervalMs, + logDirFailureChannel = new LogDirFailureChannel(10), + topicId = None, + keepPartitionMetadataFile = true) + } + + @AfterEach + def tearDown(): Unit = { + log.close() + brokerTopicStats.close() + Utils.delete(tmpDir) + } + + @Test + def testSegmentDeletionWithHighWatermarkInitialization(): Unit = { + val expiredTimestamp = time.milliseconds() - 1000 + for (i <- 0 until 100) { + val records = TestUtils.singletonRecords(value = s"test$i".getBytes, timestamp = expiredTimestamp) + log.appendAsLeader(records, leaderEpoch = 0) + } + + val initialHighWatermark = log.updateHighWatermark(25L) + assertEquals(25L, initialHighWatermark) + + val initialNumSegments = log.numberOfSegments + log.deleteOldSegments() + assertTrue(log.numberOfSegments < initialNumSegments) + assertTrue(log.logStartOffset <= initialHighWatermark) + } + + @Test + def testCannotDeleteSegmentsAtOrAboveHighWatermark(): Unit = { + val expiredTimestamp = time.milliseconds() - 1000 + for (i <- 0 until 100) { + val records = TestUtils.singletonRecords(value = s"test$i".getBytes, timestamp = expiredTimestamp) + log.appendAsLeader(records, leaderEpoch = 0) + } + + // ensure we have at least a few segments so the test case is not trivial + assertTrue(log.numberOfSegments > 5) + assertEquals(0L, log.highWatermark) + assertEquals(0L, log.logStartOffset) + assertEquals(100L, log.logEndOffset) + + for (hw <- 0 to 100) { + log.updateHighWatermark(hw) + assertEquals(hw, log.highWatermark) + log.deleteOldSegments() + assertTrue(log.logStartOffset <= hw) + + // verify that all segments up to the high watermark have been deleted + + log.logSegments.headOption.foreach { segment => + 
assertTrue(segment.baseOffset <= hw) + assertTrue(segment.baseOffset >= log.logStartOffset) + } + log.logSegments.tail.foreach { segment => + assertTrue(segment.baseOffset > hw) + assertTrue(segment.baseOffset >= log.logStartOffset) + } + } + + assertEquals(100L, log.logStartOffset) + assertEquals(1, log.numberOfSegments) + assertEquals(0, log.activeSegment.size) + } + + @Test + def testCannotIncrementLogStartOffsetPastHighWatermark(): Unit = { + for (i <- 0 until 100) { + val records = TestUtils.singletonRecords(value = s"test$i".getBytes) + log.appendAsLeader(records, leaderEpoch = 0) + } + + log.updateHighWatermark(25L) + assertThrows(classOf[OffsetOutOfRangeException], () => log.maybeIncrementLogStartOffset(26L, ClientRecordDeletion)) + } +} diff --git a/core/src/test/scala/unit/kafka/common/ZkNodeChangeNotificationListenerTest.scala b/core/src/test/scala/unit/kafka/common/ZkNodeChangeNotificationListenerTest.scala new file mode 100644 index 0000000..9b7f090 --- /dev/null +++ b/core/src/test/scala/unit/kafka/common/ZkNodeChangeNotificationListenerTest.scala @@ -0,0 +1,112 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package kafka.common + +import kafka.utils.TestUtils +import kafka.zk.{LiteralAclChangeStore, LiteralAclStore, ZkAclChangeStore} +import kafka.server.QuorumTestHarness +import org.apache.kafka.common.resource.PatternType.LITERAL +import org.apache.kafka.common.resource.ResourcePattern +import org.apache.kafka.common.resource.ResourceType.GROUP +import org.junit.jupiter.api.{AfterEach, BeforeEach, Test, TestInfo} + +import scala.collection.mutable.ArrayBuffer +import scala.collection.Seq + +class ZkNodeChangeNotificationListenerTest extends QuorumTestHarness { + + private val changeExpirationMs = 1000 + private var notificationListener: ZkNodeChangeNotificationListener = _ + private var notificationHandler: TestNotificationHandler = _ + + @BeforeEach + override def setUp(testInfo: TestInfo): Unit = { + super.setUp(testInfo) + zkClient.createAclPaths() + notificationHandler = new TestNotificationHandler() + } + + @AfterEach + override def tearDown(): Unit = { + if (notificationListener != null) { + notificationListener.close() + } + super.tearDown() + } + + @Test + def testProcessNotification(): Unit = { + val notificationMessage1 = new ResourcePattern(GROUP, "messageA", LITERAL) + val notificationMessage2 = new ResourcePattern(GROUP, "messageB", LITERAL) + + notificationListener = new ZkNodeChangeNotificationListener(zkClient, LiteralAclChangeStore.aclChangePath, + ZkAclChangeStore.SequenceNumberPrefix, notificationHandler, changeExpirationMs) + notificationListener.init() + + zkClient.createAclChangeNotification(notificationMessage1) + TestUtils.waitUntilTrue(() => notificationHandler.received().size == 1 && notificationHandler.received().last == notificationMessage1, + "Failed to send/process notification message in the timeout period.") + + /* + * There is no easy way to test purging. Even if we mock Kafka time with MockTime, the purging compares Kafka time + * with the time stored in the ZooKeeper stat, and the embedded ZooKeeper server does not provide a way to mock time. + * So to test purging we would have to use Time.SYSTEM.sleep(changeExpirationMs + 1), issue a write, and check + * Assert.assertEquals(1, KafkaZkClient.getChildren(seqNodeRoot).size). However, even then the assertion + * can fail, as the second node can be deleted depending on how threads get scheduled. 
+ */ + + zkClient.createAclChangeNotification(notificationMessage2) + TestUtils.waitUntilTrue(() => notificationHandler.received().size == 2 && notificationHandler.received().last == notificationMessage2, + "Failed to send/process notification message in the timeout period.") + + (3 to 10).foreach(i => zkClient.createAclChangeNotification(new ResourcePattern(GROUP, "message" + i, LITERAL))) + + TestUtils.waitUntilTrue(() => notificationHandler.received().size == 10, + s"Expected 10 invocations of processNotifications, but there were ${notificationHandler.received()}") + } + + @Test + def testSwallowsProcessorException(): Unit = { + notificationHandler.setThrowSize(2) + notificationListener = new ZkNodeChangeNotificationListener(zkClient, LiteralAclChangeStore.aclChangePath, + ZkAclChangeStore.SequenceNumberPrefix, notificationHandler, changeExpirationMs) + notificationListener.init() + + zkClient.createAclChangeNotification(new ResourcePattern(GROUP, "messageA", LITERAL)) + zkClient.createAclChangeNotification(new ResourcePattern(GROUP, "messageB", LITERAL)) + zkClient.createAclChangeNotification(new ResourcePattern(GROUP, "messageC", LITERAL)) + + TestUtils.waitUntilTrue(() => notificationHandler.received().size == 3, + s"Expected 3 invocations of processNotifications, but there were ${notificationHandler.received()}") + } + + private class TestNotificationHandler extends NotificationHandler { + private val messages = ArrayBuffer.empty[ResourcePattern] + @volatile private var throwSize = Option.empty[Int] + + override def processNotification(notificationMessage: Array[Byte]): Unit = { + messages += LiteralAclStore.changeStore.decode(notificationMessage) + + if (throwSize.contains(messages.size)) + throw new RuntimeException("Oh no, my processing failed!") + } + + def received(): Seq[ResourcePattern] = messages + + def setThrowSize(index: Int): Unit = throwSize = Option(index) + } +} diff --git a/core/src/test/scala/unit/kafka/controller/ControllerChannelManagerTest.scala b/core/src/test/scala/unit/kafka/controller/ControllerChannelManagerTest.scala new file mode 100644 index 0000000..495f819 --- /dev/null +++ b/core/src/test/scala/unit/kafka/controller/ControllerChannelManagerTest.scala @@ -0,0 +1,961 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package kafka.controller + +import java.util.Properties +import kafka.api.{ApiVersion, KAFKA_0_10_0_IV1, KAFKA_0_10_2_IV0, KAFKA_0_9_0, KAFKA_1_0_IV0, KAFKA_2_2_IV0, KAFKA_2_4_IV0, KAFKA_2_4_IV1, KAFKA_2_6_IV0, KAFKA_2_8_IV1, LeaderAndIsr} +import kafka.cluster.{Broker, EndPoint} +import kafka.server.KafkaConfig +import kafka.utils.TestUtils +import org.apache.kafka.common.{TopicPartition, Uuid} +import org.apache.kafka.common.message.{LeaderAndIsrResponseData, StopReplicaResponseData, UpdateMetadataResponseData} +import org.apache.kafka.common.message.LeaderAndIsrResponseData.LeaderAndIsrPartitionError +import org.apache.kafka.common.message.StopReplicaRequestData.StopReplicaPartitionState +import org.apache.kafka.common.message.StopReplicaResponseData.StopReplicaPartitionError +import org.apache.kafka.common.network.ListenerName +import org.apache.kafka.common.protocol.{ApiKeys, Errors} +import org.apache.kafka.common.requests.{AbstractControlRequest, AbstractResponse, LeaderAndIsrRequest, LeaderAndIsrResponse, StopReplicaRequest, StopReplicaResponse, UpdateMetadataRequest, UpdateMetadataResponse} +import org.apache.kafka.common.message.LeaderAndIsrResponseData.LeaderAndIsrTopicError +import org.apache.kafka.common.security.auth.SecurityProtocol +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.Test + +import scala.jdk.CollectionConverters._ +import scala.collection.mutable +import scala.collection.mutable.ListBuffer + +class ControllerChannelManagerTest { + private val controllerId = 1 + private val controllerEpoch = 1 + private val config = KafkaConfig.fromProps(TestUtils.createBrokerConfig(controllerId, "zkConnect")) + private val logger = new StateChangeLogger(controllerId, true, None) + + type ControlRequest = AbstractControlRequest.Builder[_ <: AbstractControlRequest] + + @Test + def testLeaderAndIsrRequestSent(): Unit = { + val context = initContext(Seq(1, 2, 3), 2, 3, Set("foo", "bar")) + val batch = new MockControllerBrokerRequestBatch(context) + + val partitions = Map( + new TopicPartition("foo", 0) -> LeaderAndIsr(1, List(1, 2)), + new TopicPartition("foo", 1) -> LeaderAndIsr(2, List(2, 3)), + new TopicPartition("bar", 1) -> LeaderAndIsr(3, List(1, 3)) + ) + + batch.newBatch() + partitions.foreach { case (partition, leaderAndIsr) => + val leaderIsrAndControllerEpoch = LeaderIsrAndControllerEpoch(leaderAndIsr, controllerEpoch) + context.putPartitionLeadershipInfo(partition, leaderIsrAndControllerEpoch) + batch.addLeaderAndIsrRequestForBrokers(Seq(2), partition, leaderIsrAndControllerEpoch, replicaAssignment(Seq(1, 2, 3)), isNew = false) + } + batch.sendRequestsToBrokers(controllerEpoch) + + val leaderAndIsrRequests = batch.collectLeaderAndIsrRequestsFor(2) + val updateMetadataRequests = batch.collectUpdateMetadataRequestsFor(2) + assertEquals(1, leaderAndIsrRequests.size) + assertEquals(1, updateMetadataRequests.size) + + val leaderAndIsrRequest = leaderAndIsrRequests.head + val topicIds = leaderAndIsrRequest.topicIds(); + val topicNames = topicIds.asScala.map { case (k, v) => (v, k) } + assertEquals(controllerId, leaderAndIsrRequest.controllerId) + assertEquals(controllerEpoch, leaderAndIsrRequest.controllerEpoch) + assertEquals(partitions.keySet, + leaderAndIsrRequest.partitionStates.asScala.map(p => new TopicPartition(p.topicName, p.partitionIndex)).toSet) + assertEquals(partitions.map { case (k, v) => (k, v.leader) }, + leaderAndIsrRequest.partitionStates.asScala.map(p => new TopicPartition(p.topicName, p.partitionIndex) -> p.leader).toMap) + 
assertEquals(partitions.map { case (k, v) => (k, v.isr) }, + leaderAndIsrRequest.partitionStates.asScala.map(p => new TopicPartition(p.topicName, p.partitionIndex) -> p.isr.asScala).toMap) + + applyLeaderAndIsrResponseCallbacks(Errors.NONE, batch.sentRequests(2).toList) + assertEquals(1, batch.sentEvents.size) + + val LeaderAndIsrResponseReceived(leaderAndIsrResponse, brokerId) = batch.sentEvents.head + assertEquals(2, brokerId) + assertEquals(partitions.keySet, + leaderAndIsrResponse.topics.asScala.flatMap(t => t.partitionErrors.asScala.map(p => + new TopicPartition(topicNames(t.topicId), p.partitionIndex))).toSet) + leaderAndIsrResponse.topics.forEach(topic => + assertEquals(topicIds.get(topicNames.get(topic.topicId).get), topic.topicId)) + } + + @Test + def testLeaderAndIsrRequestIsNew(): Unit = { + val context = initContext(Seq(1, 2, 3), 2, 3, Set("foo", "bar")) + val batch = new MockControllerBrokerRequestBatch(context) + + val partition = new TopicPartition("foo", 0) + val leaderAndIsr = LeaderAndIsr(1, List(1, 2)) + + val leaderIsrAndControllerEpoch = LeaderIsrAndControllerEpoch(leaderAndIsr, controllerEpoch) + context.putPartitionLeadershipInfo(partition, leaderIsrAndControllerEpoch) + + batch.newBatch() + batch.addLeaderAndIsrRequestForBrokers(Seq(2), partition, leaderIsrAndControllerEpoch, replicaAssignment(Seq(1, 2, 3)), isNew = true) + batch.addLeaderAndIsrRequestForBrokers(Seq(2), partition, leaderIsrAndControllerEpoch, replicaAssignment(Seq(1, 2, 3)), isNew = false) + batch.sendRequestsToBrokers(controllerEpoch) + + val leaderAndIsrRequests = batch.collectLeaderAndIsrRequestsFor(2) + val updateMetadataRequests = batch.collectUpdateMetadataRequestsFor(2) + assertEquals(1, leaderAndIsrRequests.size) + assertEquals(1, updateMetadataRequests.size) + + val leaderAndIsrRequest = leaderAndIsrRequests.head + val partitionStates = leaderAndIsrRequest.partitionStates.asScala + assertEquals(Seq(partition), partitionStates.map(p => new TopicPartition(p.topicName, p.partitionIndex))) + val partitionState = partitionStates.find(p => p.topicName == partition.topic && p.partitionIndex == partition.partition) + assertEquals(Some(true), partitionState.map(_.isNew)) + } + + @Test + def testLeaderAndIsrRequestSentToLiveOrShuttingDownBrokers(): Unit = { + val context = initContext(Seq(1, 2, 3), 2, 3, Set("foo", "bar")) + val batch = new MockControllerBrokerRequestBatch(context) + + // 2 is shutting down, 3 is dead + context.shuttingDownBrokerIds.add(2) + context.removeLiveBrokers(Set(3)) + + val partition = new TopicPartition("foo", 0) + val leaderAndIsr = LeaderAndIsr(1, List(1, 2)) + + val leaderIsrAndControllerEpoch = LeaderIsrAndControllerEpoch(leaderAndIsr, controllerEpoch) + context.putPartitionLeadershipInfo(partition, leaderIsrAndControllerEpoch) + + batch.newBatch() + batch.addLeaderAndIsrRequestForBrokers(Seq(1, 2, 3), partition, leaderIsrAndControllerEpoch, replicaAssignment(Seq(1, 2, 3)), isNew = false) + batch.sendRequestsToBrokers(controllerEpoch) + + assertEquals(0, batch.sentEvents.size) + assertEquals(2, batch.sentRequests.size) + assertEquals(Set(1, 2), batch.sentRequests.keySet) + + for (brokerId <- Set(1, 2)) { + val leaderAndIsrRequests = batch.collectLeaderAndIsrRequestsFor(brokerId) + val updateMetadataRequests = batch.collectUpdateMetadataRequestsFor(brokerId) + assertEquals(1, leaderAndIsrRequests.size) + assertEquals(1, updateMetadataRequests.size) + val leaderAndIsrRequest = leaderAndIsrRequests.head + assertEquals(Seq(partition), 
leaderAndIsrRequest.partitionStates.asScala.map(p => new TopicPartition(p.topicName, p.partitionIndex))) + } + } + + @Test + def testLeaderAndIsrInterBrokerProtocolVersion(): Unit = { + testLeaderAndIsrRequestFollowsInterBrokerProtocolVersion(ApiVersion.latestVersion, ApiKeys.LEADER_AND_ISR.latestVersion) + + for (apiVersion <- ApiVersion.allVersions) { + val leaderAndIsrRequestVersion: Short = + if (apiVersion >= KAFKA_2_8_IV1) 5 + else if (apiVersion >= KAFKA_2_4_IV1) 4 + else if (apiVersion >= KAFKA_2_4_IV0) 3 + else if (apiVersion >= KAFKA_2_2_IV0) 2 + else if (apiVersion >= KAFKA_1_0_IV0) 1 + else 0 + + testLeaderAndIsrRequestFollowsInterBrokerProtocolVersion(apiVersion, leaderAndIsrRequestVersion) + } + } + + private def testLeaderAndIsrRequestFollowsInterBrokerProtocolVersion(interBrokerProtocolVersion: ApiVersion, + expectedLeaderAndIsrVersion: Short): Unit = { + val context = initContext(Seq(1, 2, 3), 2, 3, Set("foo", "bar")) + val config = createConfig(interBrokerProtocolVersion) + val batch = new MockControllerBrokerRequestBatch(context, config) + + val partition = new TopicPartition("foo", 0) + val leaderAndIsr = LeaderAndIsr(1, List(1, 2)) + + val leaderIsrAndControllerEpoch = LeaderIsrAndControllerEpoch(leaderAndIsr, controllerEpoch) + context.putPartitionLeadershipInfo(partition, leaderIsrAndControllerEpoch) + + batch.newBatch() + batch.addLeaderAndIsrRequestForBrokers(Seq(2), partition, leaderIsrAndControllerEpoch, replicaAssignment(Seq(1, 2, 3)), isNew = false) + batch.sendRequestsToBrokers(controllerEpoch) + + val leaderAndIsrRequests = batch.collectLeaderAndIsrRequestsFor(2) + assertEquals(1, leaderAndIsrRequests.size) + assertEquals(expectedLeaderAndIsrVersion, leaderAndIsrRequests.head.version, + s"IBP $interBrokerProtocolVersion should use version $expectedLeaderAndIsrVersion") + + val request = leaderAndIsrRequests.head + val byteBuffer = request.serialize + val deserializedRequest = LeaderAndIsrRequest.parse(byteBuffer, expectedLeaderAndIsrVersion) + + if (interBrokerProtocolVersion >= KAFKA_2_8_IV1) { + assertFalse(request.topicIds().get("foo").equals(Uuid.ZERO_UUID)) + assertFalse(deserializedRequest.topicIds().get("foo").equals(Uuid.ZERO_UUID)) + } else if (interBrokerProtocolVersion >= KAFKA_2_2_IV0) { + assertFalse(request.topicIds().get("foo").equals(Uuid.ZERO_UUID)) + assertTrue(deserializedRequest.topicIds().get("foo").equals(Uuid.ZERO_UUID)) + } else { + assertTrue(request.topicIds().get("foo") == null) + assertTrue(deserializedRequest.topicIds().get("foo") == null) + } + } + + @Test + def testUpdateMetadataRequestSent(): Unit = { + + val topicIds = Map("foo" -> Uuid.randomUuid(), "bar" -> Uuid.randomUuid()) + val context = initContext(Seq(1, 2, 3), 2, 3, topicIds) + val batch = new MockControllerBrokerRequestBatch(context) + + val partitions = Map( + new TopicPartition("foo", 0) -> LeaderAndIsr(1, List(1, 2)), + new TopicPartition("foo", 1) -> LeaderAndIsr(2, List(2, 3)), + new TopicPartition("bar", 1) -> LeaderAndIsr(3, List(1, 3)) + ) + + partitions.foreach { case (partition, leaderAndIsr) => + context.putPartitionLeadershipInfo(partition, LeaderIsrAndControllerEpoch(leaderAndIsr, controllerEpoch)) + } + + batch.newBatch() + batch.addUpdateMetadataRequestForBrokers(Seq(2), partitions.keySet) + batch.sendRequestsToBrokers(controllerEpoch) + + val updateMetadataRequests = batch.collectUpdateMetadataRequestsFor(2) + assertEquals(1, updateMetadataRequests.size) + + val updateMetadataRequest = updateMetadataRequests.head + val partitionStates = 
updateMetadataRequest.partitionStates.asScala.toBuffer + assertEquals(3, partitionStates.size) + assertEquals(partitions.map { case (k, v) => (k, v.leader) }, + partitionStates.map(ps => (new TopicPartition(ps.topicName, ps.partitionIndex), ps.leader)).toMap) + assertEquals(partitions.map { case (k, v) => (k, v.isr) }, + partitionStates.map(ps => (new TopicPartition(ps.topicName, ps.partitionIndex), ps.isr.asScala)).toMap) + + val topicStates = updateMetadataRequest.topicStates() + assertEquals(2, topicStates.size) + for (topicState <- topicStates.asScala) { + assertEquals(topicState.topicId(), topicIds(topicState.topicName())) + } + + assertEquals(controllerId, updateMetadataRequest.controllerId) + assertEquals(controllerEpoch, updateMetadataRequest.controllerEpoch) + assertEquals(3, updateMetadataRequest.liveBrokers.size) + assertEquals(Set(1, 2, 3), updateMetadataRequest.liveBrokers.asScala.map(_.id).toSet) + + applyUpdateMetadataResponseCallbacks(Errors.STALE_BROKER_EPOCH, batch.sentRequests(2).toList) + assertEquals(1, batch.sentEvents.size) + + val UpdateMetadataResponseReceived(updateMetadataResponse, brokerId) = batch.sentEvents.head + assertEquals(2, brokerId) + assertEquals(Errors.STALE_BROKER_EPOCH, updateMetadataResponse.error) + } + + @Test + def testUpdateMetadataDoesNotIncludePartitionsWithoutLeaderAndIsr(): Unit = { + val context = initContext(Seq(1, 2, 3), 2, 3, Set("foo", "bar")) + val batch = new MockControllerBrokerRequestBatch(context) + + val partitions = Set( + new TopicPartition("foo", 0), + new TopicPartition("foo", 1), + new TopicPartition("bar", 1) + ) + + batch.newBatch() + batch.addUpdateMetadataRequestForBrokers(Seq(2), partitions) + batch.sendRequestsToBrokers(controllerEpoch) + + assertEquals(0, batch.sentEvents.size) + assertEquals(1, batch.sentRequests.size) + assertTrue(batch.sentRequests.contains(2)) + + val updateMetadataRequests = batch.collectUpdateMetadataRequestsFor(2) + assertEquals(1, updateMetadataRequests.size) + + val updateMetadataRequest = updateMetadataRequests.head + assertEquals(0, updateMetadataRequest.partitionStates.asScala.size) + assertEquals(3, updateMetadataRequest.liveBrokers.size) + assertEquals(Set(1, 2, 3), updateMetadataRequest.liveBrokers.asScala.map(_.id).toSet) + } + + @Test + def testUpdateMetadataRequestDuringTopicDeletion(): Unit = { + val context = initContext(Seq(1, 2, 3), 2, 3, Set("foo", "bar")) + val batch = new MockControllerBrokerRequestBatch(context) + + val partitions = Map( + new TopicPartition("foo", 0) -> LeaderAndIsr(1, List(1, 2)), + new TopicPartition("foo", 1) -> LeaderAndIsr(2, List(2, 3)), + new TopicPartition("bar", 1) -> LeaderAndIsr(3, List(1, 3)) + ) + + partitions.foreach { case (partition, leaderAndIsr) => + context.putPartitionLeadershipInfo(partition, LeaderIsrAndControllerEpoch(leaderAndIsr, controllerEpoch)) + } + + context.queueTopicDeletion(Set("foo")) + + batch.newBatch() + batch.addUpdateMetadataRequestForBrokers(Seq(2), partitions.keySet) + batch.sendRequestsToBrokers(controllerEpoch) + + val updateMetadataRequests = batch.collectUpdateMetadataRequestsFor(2) + assertEquals(1, updateMetadataRequests.size) + + val updateMetadataRequest = updateMetadataRequests.head + assertEquals(3, updateMetadataRequest.partitionStates.asScala.size) + + assertTrue(updateMetadataRequest.partitionStates.asScala + .filter(_.topicName == "foo") + .map(_.leader) + .forall(leaderId => leaderId == LeaderAndIsr.LeaderDuringDelete)) + + assertEquals(partitions.filter { case (k, _) => k.topic == "bar" }.map { case 
(k, v) => (k, v.leader) }, + updateMetadataRequest.partitionStates.asScala.filter(ps => ps.topicName == "bar").map { ps => + (new TopicPartition(ps.topicName, ps.partitionIndex), ps.leader) }.toMap) + assertEquals(partitions.map { case (k, v) => (k, v.isr) }, + updateMetadataRequest.partitionStates.asScala.map(ps => (new TopicPartition(ps.topicName, ps.partitionIndex), ps.isr.asScala)).toMap) + + assertEquals(3, updateMetadataRequest.liveBrokers.size) + assertEquals(Set(1, 2, 3), updateMetadataRequest.liveBrokers.asScala.map(_.id).toSet) + } + + @Test + def testUpdateMetadataIncludesLiveOrShuttingDownBrokers(): Unit = { + val context = initContext(Seq(1, 2, 3), 2, 3, Set("foo", "bar")) + val batch = new MockControllerBrokerRequestBatch(context) + + // 2 is shutting down, 3 is dead + context.shuttingDownBrokerIds.add(2) + context.removeLiveBrokers(Set(3)) + + batch.newBatch() + batch.addUpdateMetadataRequestForBrokers(Seq(1, 2, 3), Set.empty) + batch.sendRequestsToBrokers(controllerEpoch) + + assertEquals(Set(1, 2), batch.sentRequests.keySet) + + for (brokerId <- Set(1, 2)) { + val updateMetadataRequests = batch.collectUpdateMetadataRequestsFor(brokerId) + assertEquals(1, updateMetadataRequests.size) + + val updateMetadataRequest = updateMetadataRequests.head + assertEquals(0, updateMetadataRequest.partitionStates.asScala.size) + assertEquals(2, updateMetadataRequest.liveBrokers.size) + assertEquals(Set(1, 2), updateMetadataRequest.liveBrokers.asScala.map(_.id).toSet) + } + } + + @Test + def testUpdateMetadataInterBrokerProtocolVersion(): Unit = { + testUpdateMetadataFollowsInterBrokerProtocolVersion(ApiVersion.latestVersion, ApiKeys.UPDATE_METADATA.latestVersion) + + for (apiVersion <- ApiVersion.allVersions) { + val updateMetadataRequestVersion: Short = + if (apiVersion >= KAFKA_2_8_IV1) 7 + else if (apiVersion >= KAFKA_2_4_IV1) 6 + else if (apiVersion >= KAFKA_2_2_IV0) 5 + else if (apiVersion >= KAFKA_1_0_IV0) 4 + else if (apiVersion >= KAFKA_0_10_2_IV0) 3 + else if (apiVersion >= KAFKA_0_10_0_IV1) 2 + else if (apiVersion >= KAFKA_0_9_0) 1 + else 0 + + testUpdateMetadataFollowsInterBrokerProtocolVersion(apiVersion, updateMetadataRequestVersion) + } + } + + private def testUpdateMetadataFollowsInterBrokerProtocolVersion(interBrokerProtocolVersion: ApiVersion, + expectedUpdateMetadataVersion: Short): Unit = { + val context = initContext(Seq(1, 2, 3), 2, 3, Set("foo", "bar")) + val config = createConfig(interBrokerProtocolVersion) + val batch = new MockControllerBrokerRequestBatch(context, config) + + batch.newBatch() + batch.addUpdateMetadataRequestForBrokers(Seq(2), Set.empty) + batch.sendRequestsToBrokers(controllerEpoch) + + assertEquals(0, batch.sentEvents.size) + assertEquals(1, batch.sentRequests.size) + assertTrue(batch.sentRequests.contains(2)) + + val requests = batch.collectUpdateMetadataRequestsFor(2) + val allVersions = requests.map(_.version) + assertTrue(allVersions.forall(_ == expectedUpdateMetadataVersion), + s"IBP $interBrokerProtocolVersion should use version $expectedUpdateMetadataVersion, but found versions $allVersions") + } + + @Test + def testStopReplicaRequestSent(): Unit = { + val context = initContext(Seq(1, 2, 3), 2, 3, Set("foo", "bar")) + val batch = new MockControllerBrokerRequestBatch(context) + + val partitions = Map( + new TopicPartition("foo", 0) -> LeaderAndDelete(1, false), + new TopicPartition("foo", 1) -> LeaderAndDelete(2, false), + new TopicPartition("bar", 1) -> LeaderAndDelete(3, false) + ) + + batch.newBatch() + partitions.foreach { case 
(partition, LeaderAndDelete(leaderAndIsr, deletePartition)) => + context.putPartitionLeadershipInfo(partition, LeaderIsrAndControllerEpoch(leaderAndIsr, controllerEpoch)) + batch.addStopReplicaRequestForBrokers(Seq(2), partition, deletePartition) + } + batch.sendRequestsToBrokers(controllerEpoch) + + assertEquals(0, batch.sentEvents.size) + assertEquals(1, batch.sentRequests.size) + assertTrue(batch.sentRequests.contains(2)) + + val sentRequests = batch.sentRequests(2) + assertEquals(1, sentRequests.size) + + val sentStopReplicaRequests = batch.collectStopReplicaRequestsFor(2) + assertEquals(1, sentStopReplicaRequests.size) + + val stopReplicaRequest = sentStopReplicaRequests.head + assertEquals(partitionStates(partitions), stopReplicaRequest.partitionStates().asScala) + + applyStopReplicaResponseCallbacks(Errors.NONE, batch.sentRequests(2).toList) + assertEquals(0, batch.sentEvents.size) + } + + @Test + def testStopReplicaRequestWithAlreadyDefinedDeletedPartition(): Unit = { + val context = initContext(Seq(1, 2, 3), 2, 3, Set("foo", "bar")) + val batch = new MockControllerBrokerRequestBatch(context) + + val partition = new TopicPartition("foo", 0) + val leaderAndIsr = LeaderAndIsr(1, List(1, 2)) + context.putPartitionLeadershipInfo(partition, LeaderIsrAndControllerEpoch(leaderAndIsr, controllerEpoch)) + + batch.newBatch() + batch.addStopReplicaRequestForBrokers(Seq(2), partition, deletePartition = true) + batch.addStopReplicaRequestForBrokers(Seq(2), partition, deletePartition = false) + batch.sendRequestsToBrokers(controllerEpoch) + + val sentStopReplicaRequests = batch.collectStopReplicaRequestsFor(2) + assertEquals(1, sentStopReplicaRequests.size) + + val stopReplicaRequest = sentStopReplicaRequests.head + assertEquals(partitionStates(Map(partition -> LeaderAndDelete(leaderAndIsr, true))), + stopReplicaRequest.partitionStates().asScala) + } + + @Test + def testStopReplicaRequestsWhileTopicQueuedForDeletion(): Unit = { + for (apiVersion <- ApiVersion.allVersions) { + testStopReplicaRequestsWhileTopicQueuedForDeletion(apiVersion) + } + } + + private def testStopReplicaRequestsWhileTopicQueuedForDeletion(interBrokerProtocolVersion: ApiVersion): Unit = { + val context = initContext(Seq(1, 2, 3), 2, 3, Set("foo", "bar")) + val config = createConfig(interBrokerProtocolVersion) + val batch = new MockControllerBrokerRequestBatch(context, config) + + val partitions = Map( + new TopicPartition("foo", 0) -> LeaderAndDelete(1, true), + new TopicPartition("foo", 1) -> LeaderAndDelete(2, true), + new TopicPartition("bar", 1) -> LeaderAndDelete(3, true) + ) + + // Topic deletion is queued, but has not begun + context.queueTopicDeletion(Set("foo")) + + batch.newBatch() + partitions.foreach { case (partition, LeaderAndDelete(leaderAndIsr, deletePartition)) => + context.putPartitionLeadershipInfo(partition, LeaderIsrAndControllerEpoch(leaderAndIsr, controllerEpoch)) + batch.addStopReplicaRequestForBrokers(Seq(2), partition, deletePartition) + } + batch.sendRequestsToBrokers(controllerEpoch) + + assertEquals(0, batch.sentEvents.size) + assertEquals(1, batch.sentRequests.size) + assertTrue(batch.sentRequests.contains(2)) + + val sentRequests = batch.sentRequests(2) + assertEquals(1, sentRequests.size) + + val sentStopReplicaRequests = batch.collectStopReplicaRequestsFor(2) + assertEquals(1, sentStopReplicaRequests.size) + + val stopReplicaRequest = sentStopReplicaRequests.head + assertEquals(partitionStates(partitions, context.topicsQueuedForDeletion, stopReplicaRequest.version), + 
stopReplicaRequest.partitionStates().asScala) + + // No events will be sent after the response returns + applyStopReplicaResponseCallbacks(Errors.NONE, batch.sentRequests(2).toList) + assertEquals(0, batch.sentEvents.size) + } + + @Test + def testStopReplicaRequestsWhileTopicDeletionStarted(): Unit = { + for (apiVersion <- ApiVersion.allVersions) { + testStopReplicaRequestsWhileTopicDeletionStarted(apiVersion) + } + } + + private def testStopReplicaRequestsWhileTopicDeletionStarted(interBrokerProtocolVersion: ApiVersion): Unit = { + val context = initContext(Seq(1, 2, 3), 2, 3, Set("foo", "bar")) + val config = createConfig(interBrokerProtocolVersion) + val batch = new MockControllerBrokerRequestBatch(context, config) + + val partitions = Map( + new TopicPartition("foo", 0) -> LeaderAndDelete(1, true), + new TopicPartition("foo", 1) -> LeaderAndDelete(2, true), + new TopicPartition("bar", 1) -> LeaderAndDelete(3, true) + ) + + context.queueTopicDeletion(Set("foo")) + context.beginTopicDeletion(Set("foo")) + + batch.newBatch() + partitions.foreach { case (partition, LeaderAndDelete(leaderAndIsr, deletePartition)) => + context.putPartitionLeadershipInfo(partition, LeaderIsrAndControllerEpoch(leaderAndIsr, controllerEpoch)) + batch.addStopReplicaRequestForBrokers(Seq(2), partition, deletePartition) + } + batch.sendRequestsToBrokers(controllerEpoch) + + assertEquals(0, batch.sentEvents.size) + assertEquals(1, batch.sentRequests.size) + assertTrue(batch.sentRequests.contains(2)) + + val sentRequests = batch.sentRequests(2) + assertEquals(1, sentRequests.size) + + val sentStopReplicaRequests = batch.collectStopReplicaRequestsFor(2) + assertEquals(1, sentStopReplicaRequests.size) + + val stopReplicaRequest = sentStopReplicaRequests.head + assertEquals(partitionStates(partitions, context.topicsQueuedForDeletion, stopReplicaRequest.version), + stopReplicaRequest.partitionStates().asScala) + + // When the topic is being deleted, we should provide a callback which sends + // the received event for the StopReplica response + applyStopReplicaResponseCallbacks(Errors.NONE, batch.sentRequests(2).toList) + assertEquals(1, batch.sentEvents.size) + + // We should only receive events for the topic being deleted + val includedPartitions = batch.sentEvents.flatMap { + case event: TopicDeletionStopReplicaResponseReceived => event.partitionErrors.keySet + case otherEvent => throw new AssertionError(s"Unexpected sent event: $otherEvent") + }.toSet + assertEquals(partitions.keys.filter(_.topic == "foo"), includedPartitions) + } + + @Test + def testStopReplicaRequestWithoutDeletePartitionWhileTopicDeletionStarted(): Unit = { + for (apiVersion <- ApiVersion.allVersions) { + testStopReplicaRequestWithoutDeletePartitionWhileTopicDeletionStarted(apiVersion) + } + } + + private def testStopReplicaRequestWithoutDeletePartitionWhileTopicDeletionStarted(interBrokerProtocolVersion: ApiVersion): Unit = { + val context = initContext(Seq(1, 2, 3), 2, 3, Set("foo", "bar")) + val config = createConfig(interBrokerProtocolVersion) + val batch = new MockControllerBrokerRequestBatch(context, config) + + val partitions = Map( + new TopicPartition("foo", 0) -> LeaderAndDelete(1, false), + new TopicPartition("foo", 1) -> LeaderAndDelete(2, false), + new TopicPartition("bar", 1) -> LeaderAndDelete(3, false) + ) + + context.queueTopicDeletion(Set("foo")) + context.beginTopicDeletion(Set("foo")) + + batch.newBatch() + partitions.foreach { case (partition, LeaderAndDelete(leaderAndIsr, deletePartition)) => + 
context.putPartitionLeadershipInfo(partition, LeaderIsrAndControllerEpoch(leaderAndIsr, controllerEpoch)) + batch.addStopReplicaRequestForBrokers(Seq(2), partition, deletePartition) + } + batch.sendRequestsToBrokers(controllerEpoch) + + assertEquals(0, batch.sentEvents.size) + assertEquals(1, batch.sentRequests.size) + assertTrue(batch.sentRequests.contains(2)) + + val sentRequests = batch.sentRequests(2) + assertEquals(1, sentRequests.size) + + val sentStopReplicaRequests = batch.collectStopReplicaRequestsFor(2) + assertEquals(1, sentStopReplicaRequests.size) + + val stopReplicaRequest = sentStopReplicaRequests.head + assertEquals(partitionStates(partitions, context.topicsQueuedForDeletion, stopReplicaRequest.version), + stopReplicaRequest.partitionStates().asScala) + + // No events should be fired + applyStopReplicaResponseCallbacks(Errors.NONE, batch.sentRequests(2).toList) + assertEquals(0, batch.sentEvents.size) + } + + @Test + def testMixedDeleteAndNotDeleteStopReplicaRequests(): Unit = { + testMixedDeleteAndNotDeleteStopReplicaRequests(ApiVersion.latestVersion, + ApiKeys.STOP_REPLICA.latestVersion) + + for (apiVersion <- ApiVersion.allVersions) { + if (apiVersion < KAFKA_2_2_IV0) + testMixedDeleteAndNotDeleteStopReplicaRequests(apiVersion, 0.toShort) + else if (apiVersion < KAFKA_2_4_IV1) + testMixedDeleteAndNotDeleteStopReplicaRequests(apiVersion, 1.toShort) + else if (apiVersion < KAFKA_2_6_IV0) + testMixedDeleteAndNotDeleteStopReplicaRequests(apiVersion, 2.toShort) + else + testMixedDeleteAndNotDeleteStopReplicaRequests(apiVersion, 3.toShort) + } + } + + private def testMixedDeleteAndNotDeleteStopReplicaRequests(interBrokerProtocolVersion: ApiVersion, + expectedStopReplicaRequestVersion: Short): Unit = { + val context = initContext(Seq(1, 2, 3), 2, 3, Set("foo", "bar")) + val config = createConfig(interBrokerProtocolVersion) + val batch = new MockControllerBrokerRequestBatch(context, config) + + val deletePartitions = Map( + new TopicPartition("foo", 0) -> LeaderAndDelete(1, true), + new TopicPartition("foo", 1) -> LeaderAndDelete(2, true) + ) + + val nonDeletePartitions = Map( + new TopicPartition("bar", 0) -> LeaderAndDelete(1, false), + new TopicPartition("bar", 1) -> LeaderAndDelete(2, false) + ) + + batch.newBatch() + deletePartitions.foreach { case (partition, LeaderAndDelete(leaderAndIsr, deletePartition)) => + context.putPartitionLeadershipInfo(partition, LeaderIsrAndControllerEpoch(leaderAndIsr, controllerEpoch)) + batch.addStopReplicaRequestForBrokers(Seq(2), partition, deletePartition) + } + nonDeletePartitions.foreach { case (partition, LeaderAndDelete(leaderAndIsr, deletePartition)) => + context.putPartitionLeadershipInfo(partition, LeaderIsrAndControllerEpoch(leaderAndIsr, controllerEpoch)) + batch.addStopReplicaRequestForBrokers(Seq(2), partition, deletePartition) + } + batch.sendRequestsToBrokers(controllerEpoch) + + assertEquals(0, batch.sentEvents.size) + assertEquals(1, batch.sentRequests.size) + assertTrue(batch.sentRequests.contains(2)) + + // Since KAFKA_2_6_IV0, only one StopReplicaRequest is sent out + if (interBrokerProtocolVersion >= KAFKA_2_6_IV0) { + val sentRequests = batch.sentRequests(2) + assertEquals(1, sentRequests.size) + + val sentStopReplicaRequests = batch.collectStopReplicaRequestsFor(2) + assertEquals(1, sentStopReplicaRequests.size) + + val stopReplicaRequest = sentStopReplicaRequests.head + assertEquals(partitionStates(deletePartitions ++ nonDeletePartitions, version = stopReplicaRequest.version), + 
stopReplicaRequest.partitionStates().asScala) + } else { + val sentRequests = batch.sentRequests(2) + assertEquals(2, sentRequests.size) + + val sentStopReplicaRequests = batch.collectStopReplicaRequestsFor(2) + assertEquals(2, sentStopReplicaRequests.size) + + // StopReplicaRequest (deletePartitions = true) is sent first + val stopReplicaRequestWithDelete = sentStopReplicaRequests(0) + assertEquals(partitionStates(deletePartitions, version = stopReplicaRequestWithDelete.version), + stopReplicaRequestWithDelete.partitionStates().asScala) + val stopReplicaRequestWithoutDelete = sentStopReplicaRequests(1) + assertEquals(partitionStates(nonDeletePartitions, version = stopReplicaRequestWithoutDelete.version), + stopReplicaRequestWithoutDelete.partitionStates().asScala) + } + } + + @Test + def testStopReplicaGroupsByBroker(): Unit = { + val context = initContext(Seq(1, 2, 3), 2, 3, Set("foo", "bar")) + val batch = new MockControllerBrokerRequestBatch(context) + + val partitions = Map( + new TopicPartition("foo", 0) -> LeaderAndDelete(1, false), + new TopicPartition("foo", 1) -> LeaderAndDelete(2, false), + new TopicPartition("bar", 1) -> LeaderAndDelete(3, false) + ) + + batch.newBatch() + partitions.foreach { case (partition, LeaderAndDelete(leaderAndIsr, deletePartition)) => + context.putPartitionLeadershipInfo(partition, LeaderIsrAndControllerEpoch(leaderAndIsr, controllerEpoch)) + batch.addStopReplicaRequestForBrokers(Seq(2, 3), partition, deletePartition) + } + batch.sendRequestsToBrokers(controllerEpoch) + + assertEquals(0, batch.sentEvents.size) + assertEquals(2, batch.sentRequests.size) + assertTrue(batch.sentRequests.contains(2)) + assertTrue(batch.sentRequests.contains(3)) + + val sentRequests = batch.sentRequests(2) + assertEquals(1, sentRequests.size) + + for (brokerId <- Set(2, 3)) { + val sentStopReplicaRequests = batch.collectStopReplicaRequestsFor(brokerId) + assertEquals(1, sentStopReplicaRequests.size) + + val stopReplicaRequest = sentStopReplicaRequests.head + assertEquals(partitionStates(partitions), stopReplicaRequest.partitionStates().asScala) + + applyStopReplicaResponseCallbacks(Errors.NONE, batch.sentRequests(2).toList) + assertEquals(0, batch.sentEvents.size) + } + } + + @Test + def testStopReplicaSentOnlyToLiveAndShuttingDownBrokers(): Unit = { + val context = initContext(Seq(1, 2, 3), 2, 3, Set("foo", "bar")) + val batch = new MockControllerBrokerRequestBatch(context) + + // 2 is shutting down, 3 is dead + context.shuttingDownBrokerIds.add(2) + context.removeLiveBrokers(Set(3)) + + val partitions = Map( + new TopicPartition("foo", 0) -> LeaderAndDelete(1, false), + new TopicPartition("foo", 1) -> LeaderAndDelete(2, false), + new TopicPartition("bar", 1) -> LeaderAndDelete(3, false) + ) + + batch.newBatch() + partitions.foreach { case (partition, LeaderAndDelete(leaderAndIsr, deletePartition)) => + context.putPartitionLeadershipInfo(partition, LeaderIsrAndControllerEpoch(leaderAndIsr, controllerEpoch)) + batch.addStopReplicaRequestForBrokers(Seq(2, 3), partition, deletePartition) + } + batch.sendRequestsToBrokers(controllerEpoch) + + assertEquals(0, batch.sentEvents.size) + assertEquals(1, batch.sentRequests.size) + assertTrue(batch.sentRequests.contains(2)) + + val sentRequests = batch.sentRequests(2) + assertEquals(1, sentRequests.size) + + val sentStopReplicaRequests = batch.collectStopReplicaRequestsFor(2) + assertEquals(1, sentStopReplicaRequests.size) + + val stopReplicaRequest = sentStopReplicaRequests.head + assertEquals(partitionStates(partitions), 
stopReplicaRequest.partitionStates().asScala) + } + + @Test + def testStopReplicaInterBrokerProtocolVersion(): Unit = { + testStopReplicaFollowsInterBrokerProtocolVersion(ApiVersion.latestVersion, ApiKeys.STOP_REPLICA.latestVersion) + + for (apiVersion <- ApiVersion.allVersions) { + if (apiVersion < KAFKA_2_2_IV0) + testStopReplicaFollowsInterBrokerProtocolVersion(apiVersion, 0.toShort) + else if (apiVersion < KAFKA_2_4_IV1) + testStopReplicaFollowsInterBrokerProtocolVersion(apiVersion, 1.toShort) + else if (apiVersion < KAFKA_2_6_IV0) + testStopReplicaFollowsInterBrokerProtocolVersion(apiVersion, 2.toShort) + else + testStopReplicaFollowsInterBrokerProtocolVersion(apiVersion, 3.toShort) + } + } + + private def testStopReplicaFollowsInterBrokerProtocolVersion(interBrokerProtocolVersion: ApiVersion, + expectedStopReplicaRequestVersion: Short): Unit = { + val context = initContext(Seq(1, 2, 3), 2, 3, Set("foo")) + val config = createConfig(interBrokerProtocolVersion) + val batch = new MockControllerBrokerRequestBatch(context, config) + + val partition = new TopicPartition("foo", 0) + val leaderAndIsr = LeaderAndIsr(1, List(1, 2)) + + context.putPartitionLeadershipInfo(partition, LeaderIsrAndControllerEpoch(leaderAndIsr, controllerEpoch)) + + batch.newBatch() + batch.addStopReplicaRequestForBrokers(Seq(2), partition, deletePartition = false) + batch.sendRequestsToBrokers(controllerEpoch) + + assertEquals(0, batch.sentEvents.size) + assertEquals(1, batch.sentRequests.size) + assertTrue(batch.sentRequests.contains(2)) + + val requests = batch.collectStopReplicaRequestsFor(2) + val allVersions = requests.map(_.version) + assertTrue(allVersions.forall(_ == expectedStopReplicaRequestVersion), + s"IBP $interBrokerProtocolVersion should use version $expectedStopReplicaRequestVersion, but found versions $allVersions") + } + + private case class LeaderAndDelete(leaderAndIsr: LeaderAndIsr, + deletePartition: Boolean) + + private object LeaderAndDelete { + def apply(leader: Int, deletePartition: Boolean): LeaderAndDelete = + new LeaderAndDelete(LeaderAndIsr(leader, List()), deletePartition) + } + + private def partitionStates(partitions: Map[TopicPartition, LeaderAndDelete], + topicsQueuedForDeletion: collection.Set[String] = Set.empty[String], + version: Short = ApiKeys.STOP_REPLICA.latestVersion): Map[TopicPartition, StopReplicaPartitionState] = { + partitions.map { case (topicPartition, LeaderAndDelete(leaderAndIsr, deletePartition)) => + topicPartition -> { + val partitionState = new StopReplicaPartitionState() + .setPartitionIndex(topicPartition.partition) + .setDeletePartition(deletePartition) + + if (version >= 3) { + partitionState.setLeaderEpoch(if (topicsQueuedForDeletion.contains(topicPartition.topic)) + LeaderAndIsr.EpochDuringDelete + else + leaderAndIsr.leaderEpoch) + } + + partitionState + } + } + } + + private def applyStopReplicaResponseCallbacks(error: Errors, sentRequests: List[SentRequest]): Unit = { + sentRequests.filter(_.responseCallback != null).foreach { sentRequest => + val stopReplicaRequest = sentRequest.request.build().asInstanceOf[StopReplicaRequest] + val stopReplicaResponse = + if (error == Errors.NONE) { + val partitionErrors = stopReplicaRequest.topicStates.asScala.flatMap { topic => + topic.partitionStates.asScala.map { partition => + new StopReplicaPartitionError() + .setTopicName(topic.topicName) + .setPartitionIndex(partition.partitionIndex) + .setErrorCode(error.code) + } + }.toBuffer.asJava + new StopReplicaResponse(new 
StopReplicaResponseData().setPartitionErrors(partitionErrors)) + } else { + stopReplicaRequest.getErrorResponse(error.exception) + } + sentRequest.responseCallback.apply(stopReplicaResponse) + } + } + + private def applyLeaderAndIsrResponseCallbacks(error: Errors, sentRequests: List[SentRequest]): Unit = { + sentRequests.filter(_.request.apiKey == ApiKeys.LEADER_AND_ISR).filter(_.responseCallback != null).foreach { sentRequest => + val leaderAndIsrRequest = sentRequest.request.build().asInstanceOf[LeaderAndIsrRequest] + val topicIds = leaderAndIsrRequest.topicIds + val data = new LeaderAndIsrResponseData() + .setErrorCode(error.code) + leaderAndIsrRequest.data.topicStates.asScala.foreach { t => + data.topics.add(new LeaderAndIsrTopicError() + .setTopicId(topicIds.get(t.topicName)) + .setPartitionErrors(t.partitionStates.asScala.map(p => + new LeaderAndIsrPartitionError() + .setPartitionIndex(p.partitionIndex) + .setErrorCode(error.code)).asJava)) + } + val leaderAndIsrResponse = new LeaderAndIsrResponse(data, leaderAndIsrRequest.version) + sentRequest.responseCallback(leaderAndIsrResponse) + } + } + + private def applyUpdateMetadataResponseCallbacks(error: Errors, sentRequests: List[SentRequest]): Unit = { + sentRequests.filter(_.request.apiKey == ApiKeys.UPDATE_METADATA).filter(_.responseCallback != null).foreach { sentRequest => + val response = new UpdateMetadataResponse(new UpdateMetadataResponseData().setErrorCode(error.code)) + sentRequest.responseCallback(response) + } + } + + private def createConfig(interBrokerVersion: ApiVersion): KafkaConfig = { + val props = new Properties() + props.put(KafkaConfig.BrokerIdProp, controllerId.toString) + props.put(KafkaConfig.ZkConnectProp, "zkConnect") + TestUtils.setIbpAndMessageFormatVersions(props, interBrokerVersion) + KafkaConfig.fromProps(props) + } + + private def replicaAssignment(replicas: Seq[Int]): ReplicaAssignment = ReplicaAssignment(replicas, Seq(), Seq()) + + private def initContext(brokers: Seq[Int], + numPartitions: Int, + replicationFactor: Int, + topics: Set[String]): ControllerContext = initContext(brokers, numPartitions, + replicationFactor, topics.map(_ -> Uuid.randomUuid()).toMap) + + private def initContext(brokers: Seq[Int], + numPartitions: Int, + replicationFactor: Int, + topicIds: Map[String, Uuid]): ControllerContext = { + val context = new ControllerContext + val brokerEpochs = brokers.map { brokerId => + val endpoint = new EndPoint("localhost", 9900 + brokerId, new ListenerName("PLAINTEXT"), + SecurityProtocol.PLAINTEXT) + Broker(brokerId, Seq(endpoint), rack = None) -> 1L + }.toMap + + context.setLiveBrokers(brokerEpochs) + context.setAllTopics(topicIds.keySet) + topicIds.foreach { case (name, id) => context.addTopicId(name, id) } + + // Simple round-robin replica assignment + var leaderIndex = 0 + for (topic <- topicIds.keys; partitionId <- 0 until numPartitions) { + val partition = new TopicPartition(topic, partitionId) + val replicas = (0 until replicationFactor).map { i => + val replica = brokers((i + leaderIndex) % brokers.size) + replica + } + context.updatePartitionFullReplicaAssignment(partition, ReplicaAssignment(replicas)) + leaderIndex += 1 + } + + context + } + + private case class SentRequest(request: ControlRequest, responseCallback: AbstractResponse => Unit) + + private class MockControllerBrokerRequestBatch(context: ControllerContext, config: KafkaConfig = config) + extends AbstractControllerBrokerRequestBatch(config, context, logger) { + + val sentEvents = ListBuffer.empty[ControllerEvent] + val 
sentRequests = mutable.Map.empty[Int, ListBuffer[SentRequest]] + + override def sendEvent(event: ControllerEvent): Unit = { + sentEvents.append(event) + } + override def sendRequest(brokerId: Int, request: ControlRequest, callback: AbstractResponse => Unit): Unit = { + sentRequests.getOrElseUpdate(brokerId, ListBuffer.empty) + sentRequests(brokerId).append(SentRequest(request, callback)) + } + + def collectStopReplicaRequestsFor(brokerId: Int): List[StopReplicaRequest] = { + sentRequests.get(brokerId) match { + case Some(requests) => requests + .filter(_.request.apiKey == ApiKeys.STOP_REPLICA) + .map(_.request.build().asInstanceOf[StopReplicaRequest]).toList + case None => List.empty[StopReplicaRequest] + } + } + + def collectUpdateMetadataRequestsFor(brokerId: Int): List[UpdateMetadataRequest] = { + sentRequests.get(brokerId) match { + case Some(requests) => requests + .filter(_.request.apiKey == ApiKeys.UPDATE_METADATA) + .map(_.request.build().asInstanceOf[UpdateMetadataRequest]).toList + case None => List.empty[UpdateMetadataRequest] + } + } + + def collectLeaderAndIsrRequestsFor(brokerId: Int): List[LeaderAndIsrRequest] = { + sentRequests.get(brokerId) match { + case Some(requests) => requests + .filter(_.request.apiKey == ApiKeys.LEADER_AND_ISR) + .map(_.request.build().asInstanceOf[LeaderAndIsrRequest]).toList + case None => List.empty[LeaderAndIsrRequest] + } + } + } + +} diff --git a/core/src/test/scala/unit/kafka/controller/ControllerContextTest.scala b/core/src/test/scala/unit/kafka/controller/ControllerContextTest.scala new file mode 100644 index 0000000..e8efa5a --- /dev/null +++ b/core/src/test/scala/unit/kafka/controller/ControllerContextTest.scala @@ -0,0 +1,206 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package kafka.controller + +import kafka.api.LeaderAndIsr +import kafka.cluster.{Broker, EndPoint} +import org.apache.kafka.common.TopicPartition +import org.apache.kafka.common.network.ListenerName +import org.apache.kafka.common.security.auth.SecurityProtocol +import org.junit.jupiter.api.Assertions.{assertEquals, assertFalse, assertTrue} +import org.junit.jupiter.api.{BeforeEach, Test} + + +class ControllerContextTest { + + var context: ControllerContext = null + val brokers: Seq[Int] = Seq(1, 2, 3) + val tp1 = new TopicPartition("A", 0) + val tp2 = new TopicPartition("A", 1) + val tp3 = new TopicPartition("B", 0) + + @BeforeEach + def setUp(): Unit = { + context = new ControllerContext + + val brokerEpochs = Seq(1,2,3).map { brokerId => + val endpoint = new EndPoint("localhost", 9900 + brokerId, new ListenerName("PLAINTEXT"), + SecurityProtocol.PLAINTEXT) + Broker(brokerId, Seq(endpoint), rack = None) -> 1L + }.toMap + + context.setLiveBrokers(brokerEpochs) + + // Simple round-robin replica assignment + var leaderIndex = 0 + Seq(tp1, tp2, tp3).foreach { partition => + val replicas = brokers.indices.map { i => + brokers((i + leaderIndex) % brokers.size) + } + context.updatePartitionFullReplicaAssignment(partition, ReplicaAssignment(replicas)) + leaderIndex += 1 + } + } + + @Test + def testUpdatePartitionFullReplicaAssignmentUpdatesReplicaAssignment(): Unit = { + val initialReplicas = Seq(4) + context.updatePartitionFullReplicaAssignment(tp1, ReplicaAssignment(initialReplicas)) + val fullAssignment = context.partitionFullReplicaAssignment(tp1) + assertEquals(initialReplicas, fullAssignment.replicas) + assertEquals(Seq(), fullAssignment.addingReplicas) + assertEquals(Seq(), fullAssignment.removingReplicas) + + val expectedFullAssignment = ReplicaAssignment(Seq(3), Seq(1), Seq(2)) + context.updatePartitionFullReplicaAssignment(tp1, expectedFullAssignment) + val updatedFullAssignment = context.partitionFullReplicaAssignment(tp1) + assertEquals(expectedFullAssignment.replicas, updatedFullAssignment.replicas) + assertEquals(expectedFullAssignment.addingReplicas, updatedFullAssignment.addingReplicas) + assertEquals(expectedFullAssignment.removingReplicas, updatedFullAssignment.removingReplicas) + } + + @Test + def testPartitionReplicaAssignmentReturnsEmptySeqIfTopicOrPartitionDoesNotExist(): Unit = { + val noTopicReplicas = context.partitionReplicaAssignment(new TopicPartition("NONEXISTENT", 0)) + assertEquals(Seq.empty, noTopicReplicas) + val noPartitionReplicas = context.partitionReplicaAssignment(new TopicPartition("A", 100)) + assertEquals(Seq.empty, noPartitionReplicas) + } + + @Test + def testPartitionFullReplicaAssignmentReturnsEmptyAssignmentIfTopicOrPartitionDoesNotExist(): Unit = { + val expectedEmptyAssignment = ReplicaAssignment(Seq.empty, Seq.empty, Seq.empty) + + val noTopicAssignment = context.partitionFullReplicaAssignment(new TopicPartition("NONEXISTENT", 0)) + assertEquals(expectedEmptyAssignment, noTopicAssignment) + val noPartitionAssignment = context.partitionFullReplicaAssignment(new TopicPartition("A", 100)) + assertEquals(expectedEmptyAssignment, noPartitionAssignment) + } + + @Test + def testPartitionReplicaAssignmentForTopicReturnsEmptyMapIfTopicDoesNotExist(): Unit = { + assertEquals(Map.empty, context.partitionReplicaAssignmentForTopic("NONEXISTENT")) + } + + @Test + def testPartitionReplicaAssignmentForTopicReturnsExpectedReplicaAssignments(): Unit = { + val expectedAssignments = Map( + tp1 -> context.partitionReplicaAssignment(tp1), + tp2 -> 
context.partitionReplicaAssignment(tp2) + ) + val receivedAssignments = context.partitionReplicaAssignmentForTopic("A") + assertEquals(expectedAssignments, receivedAssignments) + } + + @Test + def testPartitionReplicaAssignment(): Unit = { + val reassigningPartition = ReplicaAssignment(List(1, 2, 3, 4, 5, 6), List(2, 3, 4), List(1, 5, 6)) + assertTrue(reassigningPartition.isBeingReassigned) + assertEquals(List(2, 3, 4), reassigningPartition.targetReplicas) + + val reassigningPartition2 = ReplicaAssignment(List(1, 2, 3, 4), List(), List(1, 4)) + assertTrue(reassigningPartition2.isBeingReassigned) + assertEquals(List(2, 3), reassigningPartition2.targetReplicas) + + val reassigningPartition3 = ReplicaAssignment(List(1, 2, 3, 4), List(4), List(2)) + assertTrue(reassigningPartition3.isBeingReassigned) + assertEquals(List(1, 3, 4), reassigningPartition3.targetReplicas) + + val partition = ReplicaAssignment(List(1, 2, 3, 4, 5, 6), List(), List()) + assertFalse(partition.isBeingReassigned) + assertEquals(List(1, 2, 3, 4, 5, 6), partition.targetReplicas) + + val reassigningPartition4 = ReplicaAssignment(Seq(1, 2, 3, 4)).reassignTo(Seq(4, 2, 5, 3)) + assertEquals(List(4, 2, 5, 3, 1), reassigningPartition4.replicas) + assertEquals(List(4, 2, 5, 3), reassigningPartition4.targetReplicas) + assertEquals(List(5), reassigningPartition4.addingReplicas) + assertEquals(List(1), reassigningPartition4.removingReplicas) + assertTrue(reassigningPartition4.isBeingReassigned) + + val reassigningPartition5 = ReplicaAssignment(Seq(1, 2, 3)).reassignTo(Seq(4, 5, 6)) + assertEquals(List(4, 5, 6, 1, 2, 3), reassigningPartition5.replicas) + assertEquals(List(4, 5, 6), reassigningPartition5.targetReplicas) + assertEquals(List(4, 5, 6), reassigningPartition5.addingReplicas) + assertEquals(List(1, 2, 3), reassigningPartition5.removingReplicas) + assertTrue(reassigningPartition5.isBeingReassigned) + + val nonReassigningPartition = ReplicaAssignment(Seq(1, 2, 3)).reassignTo(Seq(3, 1, 2)) + assertEquals(List(3, 1, 2), nonReassigningPartition.replicas) + assertEquals(List(3, 1, 2), nonReassigningPartition.targetReplicas) + assertEquals(List(), nonReassigningPartition.addingReplicas) + assertEquals(List(), nonReassigningPartition.removingReplicas) + assertFalse(nonReassigningPartition.isBeingReassigned) + } + + @Test + def testReassignToIdempotence(): Unit = { + val assignment1 = ReplicaAssignment(Seq(1, 2, 3)) + assertEquals(assignment1, assignment1.reassignTo(assignment1.targetReplicas)) + + val assignment2 = ReplicaAssignment(Seq(4, 5, 6, 1, 2, 3), + addingReplicas = Seq(4, 5, 6), removingReplicas = Seq(1, 2, 3)) + assertEquals(assignment2, assignment2.reassignTo(assignment2.targetReplicas)) + + val assignment3 = ReplicaAssignment(Seq(4, 2, 3, 1), + addingReplicas = Seq(4), removingReplicas = Seq(1)) + assertEquals(assignment3, assignment3.reassignTo(assignment3.targetReplicas)) + } + + @Test + def testReassignTo(): Unit = { + val assignment = ReplicaAssignment(Seq(1, 2, 3)) + val firstReassign = assignment.reassignTo(Seq(4, 5, 6)) + + assertEquals(ReplicaAssignment(Seq(4, 5, 6, 1, 2, 3), Seq(4, 5, 6), Seq(1, 2, 3)), firstReassign) + assertEquals(ReplicaAssignment(Seq(7, 8, 9, 1, 2, 3), Seq(7, 8, 9), Seq(1, 2, 3)), firstReassign.reassignTo(Seq(7, 8, 9))) + assertEquals(ReplicaAssignment(Seq(7, 8, 9, 1, 2, 3), Seq(7, 8, 9), Seq(1, 2, 3)), assignment.reassignTo(Seq(7, 8, 9))) + assertEquals(assignment, firstReassign.reassignTo(Seq(1,2,3))) + } + + @Test + def testPreferredReplicaImbalanceMetric(): Unit = { + 
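+ // Descriptive comment (added): a partition counts as imbalanced when its current leader differs from its
+ // preferred (first) replica in the assignment; topics queued for deletion or removed are excluded from
+ // preferredReplicaImbalanceCount, which is what the steps below exercise.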
context.updatePartitionFullReplicaAssignment(tp1, ReplicaAssignment(Seq(1, 2, 3))) + context.updatePartitionFullReplicaAssignment(tp2, ReplicaAssignment(Seq(1, 2, 3))) + context.updatePartitionFullReplicaAssignment(tp3, ReplicaAssignment(Seq(1, 2, 3))) + assertEquals(0, context.preferredReplicaImbalanceCount) + + context.putPartitionLeadershipInfo(tp1, LeaderIsrAndControllerEpoch(LeaderAndIsr(1, List(1, 2, 3)), 0)) + assertEquals(0, context.preferredReplicaImbalanceCount) + + context.putPartitionLeadershipInfo(tp2, LeaderIsrAndControllerEpoch(LeaderAndIsr(2, List(2, 3, 1)), 0)) + assertEquals(1, context.preferredReplicaImbalanceCount) + + context.putPartitionLeadershipInfo(tp3, LeaderIsrAndControllerEpoch(LeaderAndIsr(3, List(3, 1, 2)), 0)) + assertEquals(2, context.preferredReplicaImbalanceCount) + + context.updatePartitionFullReplicaAssignment(tp1, ReplicaAssignment(Seq(2, 3, 1))) + context.updatePartitionFullReplicaAssignment(tp2, ReplicaAssignment(Seq(2, 3, 1))) + assertEquals(2, context.preferredReplicaImbalanceCount) + + context.queueTopicDeletion(Set(tp3.topic)) + assertEquals(1, context.preferredReplicaImbalanceCount) + + context.putPartitionLeadershipInfo(tp3, LeaderIsrAndControllerEpoch(LeaderAndIsr(1, List(3, 1, 2)), 0)) + assertEquals(1, context.preferredReplicaImbalanceCount) + + context.removeTopic(tp1.topic) + context.removeTopic(tp2.topic) + context.removeTopic(tp3.topic) + assertEquals(0, context.preferredReplicaImbalanceCount) + } +} diff --git a/core/src/test/scala/unit/kafka/controller/ControllerEventManagerTest.scala b/core/src/test/scala/unit/kafka/controller/ControllerEventManagerTest.scala new file mode 100644 index 0000000..26bbf94 --- /dev/null +++ b/core/src/test/scala/unit/kafka/controller/ControllerEventManagerTest.scala @@ -0,0 +1,224 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package kafka.controller + +import java.util.concurrent.CountDownLatch +import java.util.concurrent.atomic.AtomicInteger + +import com.yammer.metrics.core.{Histogram, MetricName, Timer} +import kafka.controller +import kafka.metrics.KafkaYammerMetrics +import kafka.utils.TestUtils +import org.apache.kafka.common.message.UpdateMetadataResponseData +import org.apache.kafka.common.protocol.Errors +import org.apache.kafka.common.requests.UpdateMetadataResponse +import org.apache.kafka.common.utils.MockTime +import org.junit.jupiter.api.Assertions.{assertEquals, assertTrue, fail} +import org.junit.jupiter.api.{AfterEach, Test} + +import scala.jdk.CollectionConverters._ +import scala.collection.mutable + +class ControllerEventManagerTest { + + private var controllerEventManager: ControllerEventManager = _ + + @AfterEach + def tearDown(): Unit = { + if (controllerEventManager != null) + controllerEventManager.close() + } + + @Test + def testMetricsCleanedOnClose(): Unit = { + val time = new MockTime() + val controllerStats = new ControllerStats + val eventProcessor = new ControllerEventProcessor { + override def process(event: ControllerEvent): Unit = {} + override def preempt(event: ControllerEvent): Unit = {} + } + + def allEventManagerMetrics: Set[MetricName] = { + KafkaYammerMetrics.defaultRegistry.allMetrics.asScala.keySet + .filter(_.getMBeanName.startsWith("kafka.controller:type=ControllerEventManager")) + .toSet + } + + controllerEventManager = new ControllerEventManager(0, eventProcessor, + time, controllerStats.rateAndTimeMetrics) + controllerEventManager.start() + assertTrue(allEventManagerMetrics.nonEmpty) + + controllerEventManager.close() + assertTrue(allEventManagerMetrics.isEmpty) + } + + @Test + def testEventWithoutRateMetrics(): Unit = { + val time = new MockTime() + val controllerStats = new ControllerStats + val processedEvents = mutable.Set.empty[ControllerEvent] + + val eventProcessor = new ControllerEventProcessor { + override def process(event: ControllerEvent): Unit = { processedEvents += event } + override def preempt(event: ControllerEvent): Unit = {} + } + + controllerEventManager = new ControllerEventManager(0, eventProcessor, + time, controllerStats.rateAndTimeMetrics) + controllerEventManager.start() + + val updateMetadataResponse = new UpdateMetadataResponse( + new UpdateMetadataResponseData().setErrorCode(Errors.NONE.code) + ) + val updateMetadataResponseEvent = controller.UpdateMetadataResponseReceived(updateMetadataResponse, brokerId = 1) + controllerEventManager.put(updateMetadataResponseEvent) + TestUtils.waitUntilTrue(() => processedEvents.size == 1, + "Failed to process expected event before timing out") + assertEquals(updateMetadataResponseEvent, processedEvents.head) + } + + @Test + def testEventQueueTime(): Unit = { + val metricName = "kafka.controller:type=ControllerEventManager,name=EventQueueTimeMs" + val controllerStats = new ControllerStats + val time = new MockTime() + val latch = new CountDownLatch(1) + val processedEvents = new AtomicInteger() + + val eventProcessor = new ControllerEventProcessor { + override def process(event: ControllerEvent): Unit = { + latch.await() + time.sleep(500) + processedEvents.incrementAndGet() + } + override def preempt(event: ControllerEvent): Unit = {} + } + + // The metric should not already exist + assertTrue(KafkaYammerMetrics.defaultRegistry.allMetrics.asScala.filter { case (k, _) => + k.getMBeanName == metricName + }.values.isEmpty) + + controllerEventManager = new ControllerEventManager(0, 
eventProcessor, + time, controllerStats.rateAndTimeMetrics) + controllerEventManager.start() + + controllerEventManager.put(TopicChange) + controllerEventManager.put(TopicChange) + latch.countDown() + + TestUtils.waitUntilTrue(() => processedEvents.get() == 2, + "Timed out waiting for processing of all events") + + val queueTimeHistogram = KafkaYammerMetrics.defaultRegistry.allMetrics.asScala.filter { case (k, _) => + k.getMBeanName == metricName + }.values.headOption.getOrElse(fail(s"Unable to find metric $metricName")).asInstanceOf[Histogram] + + assertEquals(2, queueTimeHistogram.count) + assertEquals(0, queueTimeHistogram.min, 0.01) + assertEquals(500, queueTimeHistogram.max, 0.01) + } + + @Test + def testEventQueueTimeResetOnTimeout(): Unit = { + val metricName = "kafka.controller:type=ControllerEventManager,name=EventQueueTimeMs" + val controllerStats = new ControllerStats + val time = new MockTime() + val processedEvents = new AtomicInteger() + + val eventProcessor = new ControllerEventProcessor { + override def process(event: ControllerEvent): Unit = { + processedEvents.incrementAndGet() + } + override def preempt(event: ControllerEvent): Unit = {} + } + + controllerEventManager = new ControllerEventManager(0, eventProcessor, + time, controllerStats.rateAndTimeMetrics, 1) + controllerEventManager.start() + + controllerEventManager.put(TopicChange) + controllerEventManager.put(TopicChange) + + TestUtils.waitUntilTrue(() => processedEvents.get() == 2, + "Timed out waiting for processing of all events") + + val queueTimeHistogram = KafkaYammerMetrics.defaultRegistry.allMetrics.asScala.filter { case (k, _) => + k.getMBeanName == metricName + }.values.headOption.getOrElse(fail(s"Unable to find metric $metricName")).asInstanceOf[Histogram] + + TestUtils.waitUntilTrue(() => queueTimeHistogram.count == 0, + "Timed out on resetting queueTimeHistogram") + assertEquals(0, queueTimeHistogram.min, 0.1) + assertEquals(0, queueTimeHistogram.max, 0.1) + } + + @Test + def testSuccessfulEvent(): Unit = { + check("kafka.controller:type=ControllerStats,name=AutoLeaderBalanceRateAndTimeMs", + AutoPreferredReplicaLeaderElection, () => ()) + } + + @Test + def testEventThatThrowsException(): Unit = { + check("kafka.controller:type=ControllerStats,name=LeaderElectionRateAndTimeMs", + BrokerChange, () => throw new NullPointerException) + } + + private def check(metricName: String, + event: ControllerEvent, + func: () => Unit): Unit = { + val controllerStats = new ControllerStats + val eventProcessedListenerCount = new AtomicInteger + val latch = new CountDownLatch(1) + val eventProcessor = new ControllerEventProcessor { + override def process(event: ControllerEvent): Unit = { + // Only return from `process()` once we have checked `controllerEventManager.state` + latch.await() + eventProcessedListenerCount.incrementAndGet() + func() + } + override def preempt(event: ControllerEvent): Unit = {} + } + + controllerEventManager = new ControllerEventManager(0, eventProcessor, + new MockTime(), controllerStats.rateAndTimeMetrics) + controllerEventManager.start() + + val initialTimerCount = timer(metricName).count + + controllerEventManager.put(event) + TestUtils.waitUntilTrue(() => controllerEventManager.state == event.state, + s"Controller state is not ${event.state}") + latch.countDown() + + TestUtils.waitUntilTrue(() => controllerEventManager.state == ControllerState.Idle, + "Controller state has not changed back to Idle") + assertEquals(1, eventProcessedListenerCount.get) + + assertEquals(initialTimerCount + 
1, timer(metricName).count, "Timer has not been updated") + } + + private def timer(metricName: String): Timer = { + KafkaYammerMetrics.defaultRegistry.allMetrics.asScala.filter { case (k, _) => + k.getMBeanName == metricName + }.values.headOption.getOrElse(fail(s"Unable to find metric $metricName")).asInstanceOf[Timer] + } + +} diff --git a/core/src/test/scala/unit/kafka/controller/ControllerFailoverTest.scala b/core/src/test/scala/unit/kafka/controller/ControllerFailoverTest.scala new file mode 100644 index 0000000..eecc616 --- /dev/null +++ b/core/src/test/scala/unit/kafka/controller/ControllerFailoverTest.scala @@ -0,0 +1,99 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.controller + +import java.util.Properties +import java.util.concurrent.CountDownLatch +import java.util.concurrent.atomic.AtomicReference + +import kafka.integration.KafkaServerTestHarness +import kafka.server.KafkaConfig +import kafka.utils._ +import org.apache.kafka.common.TopicPartition +import org.apache.kafka.common.metrics.Metrics +import org.apache.log4j.Logger +import org.junit.jupiter.api.{AfterEach, Test} +import org.junit.jupiter.api.Assertions._ + +class ControllerFailoverTest extends KafkaServerTestHarness with Logging { + val log = Logger.getLogger(classOf[ControllerFailoverTest]) + val numNodes = 2 + val numParts = 1 + val msgQueueSize = 1 + val topic = "topic1" + val overridingProps = new Properties() + val metrics = new Metrics() + overridingProps.put(KafkaConfig.NumPartitionsProp, numParts.toString) + + override def generateConfigs = TestUtils.createBrokerConfigs(numNodes, zkConnect) + .map(KafkaConfig.fromProps(_, overridingProps)) + + @AfterEach + override def tearDown(): Unit = { + super.tearDown() + this.metrics.close() + } + + /** + * See @link{https://issues.apache.org/jira/browse/KAFKA-2300} + * for the background of this test case + */ + @Test + def testHandleIllegalStateException(): Unit = { + val initialController = servers.find(_.kafkaController.isActive).map(_.kafkaController).getOrElse { + throw new AssertionError("Could not find controller") + } + val initialEpoch = initialController.epoch + // Create topic with one partition + createTopic(topic, 1, 1) + val topicPartition = new TopicPartition("topic1", 0) + TestUtils.waitUntilTrue(() => + initialController.controllerContext.partitionsInState(OnlinePartition).contains(topicPartition), + s"Partition $topicPartition did not transition to online state") + + // Wait until we have verified that we have resigned + val latch = new CountDownLatch(1) + val exceptionThrown = new AtomicReference[Throwable]() + val illegalStateEvent = new MockEvent(ControllerState.BrokerChange) { + override def process(): Unit = { + try initialController.handleIllegalState(new 
IllegalStateException("Thrown for test purposes")) + catch { + case t: Throwable => exceptionThrown.set(t) + } + latch.await() + } + + override def preempt(): Unit = {} + } + initialController.eventManager.put(illegalStateEvent) + // Check that we have shutdown the scheduler (via onControllerResigned) + TestUtils.waitUntilTrue(() => !initialController.kafkaScheduler.isStarted, "Scheduler was not shutdown") + TestUtils.waitUntilTrue(() => !initialController.isActive, "Controller did not become inactive") + latch.countDown() + TestUtils.waitUntilTrue(() => Option(exceptionThrown.get()).isDefined, "handleIllegalState did not throw an exception") + assertTrue(exceptionThrown.get.isInstanceOf[IllegalStateException], + s"handleIllegalState should throw an IllegalStateException, but $exceptionThrown was thrown") + + TestUtils.waitUntilTrue(() => { + servers.exists { server => + server.kafkaController.isActive && server.kafkaController.epoch > initialEpoch + } + }, "Failed to find controller") + + } +} diff --git a/core/src/test/scala/unit/kafka/controller/ControllerIntegrationTest.scala b/core/src/test/scala/unit/kafka/controller/ControllerIntegrationTest.scala new file mode 100644 index 0000000..2302007 --- /dev/null +++ b/core/src/test/scala/unit/kafka/controller/ControllerIntegrationTest.scala @@ -0,0 +1,1405 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package kafka.controller + +import java.util.Properties +import java.util.concurrent.{CompletableFuture, CountDownLatch, LinkedBlockingQueue, TimeUnit} +import com.yammer.metrics.core.Timer +import kafka.api.{ApiVersion, KAFKA_2_6_IV0, KAFKA_2_7_IV0, LeaderAndIsr} +import kafka.controller.KafkaController.AlterIsrCallback +import kafka.metrics.KafkaYammerMetrics +import kafka.server.{KafkaConfig, KafkaServer, QuorumTestHarness} +import kafka.utils.{LogCaptureAppender, TestUtils} +import kafka.zk.{FeatureZNodeStatus, _} +import org.apache.kafka.common.errors.{ControllerMovedException, StaleBrokerEpochException} +import org.apache.kafka.common.feature.Features +import org.apache.kafka.common.metrics.KafkaMetric +import org.apache.kafka.common.protocol.Errors +import org.apache.kafka.common.{ElectionType, TopicPartition, Uuid} +import org.apache.log4j.Level +import org.junit.jupiter.api.Assertions.{assertEquals, assertNotEquals, assertTrue} +import org.junit.jupiter.api.{AfterEach, BeforeEach, Test, TestInfo} +import org.mockito.Mockito.{doAnswer, spy, verify} +import org.mockito.invocation.InvocationOnMock + +import scala.collection.{Map, Seq, mutable} +import scala.jdk.CollectionConverters._ +import scala.util.{Failure, Success, Try} + +class ControllerIntegrationTest extends QuorumTestHarness { + var servers = Seq.empty[KafkaServer] + val firstControllerEpoch = KafkaController.InitialControllerEpoch + 1 + val firstControllerEpochZkVersion = KafkaController.InitialControllerEpochZkVersion + 1 + + @BeforeEach + override def setUp(testInfo: TestInfo): Unit = { + super.setUp(testInfo) + servers = Seq.empty[KafkaServer] + } + + @AfterEach + override def tearDown(): Unit = { + TestUtils.shutdownServers(servers) + super.tearDown() + } + + @Test + def testEmptyCluster(): Unit = { + servers = makeServers(1) + TestUtils.waitUntilTrue(() => zkClient.getControllerId.isDefined, "failed to elect a controller") + waitUntilControllerEpoch(firstControllerEpoch, "broker failed to set controller epoch") + } + + @Test + def testControllerEpochPersistsWhenAllBrokersDown(): Unit = { + servers = makeServers(1) + TestUtils.waitUntilTrue(() => zkClient.getControllerId.isDefined, "failed to elect a controller") + waitUntilControllerEpoch(firstControllerEpoch, "broker failed to set controller epoch") + servers.head.shutdown() + servers.head.awaitShutdown() + TestUtils.waitUntilTrue(() => !zkClient.getControllerId.isDefined, "failed to kill controller") + waitUntilControllerEpoch(firstControllerEpoch, "controller epoch was not persisted after broker failure") + } + + @Test + def testControllerMoveIncrementsControllerEpoch(): Unit = { + servers = makeServers(1) + TestUtils.waitUntilTrue(() => zkClient.getControllerId.isDefined, "failed to elect a controller") + waitUntilControllerEpoch(firstControllerEpoch, "broker failed to set controller epoch") + servers.head.shutdown() + servers.head.awaitShutdown() + servers.head.startup() + TestUtils.waitUntilTrue(() => zkClient.getControllerId.isDefined, "failed to elect a controller") + waitUntilControllerEpoch(firstControllerEpoch + 1, "controller epoch was not incremented after controller move") + } + + @Test + def testMetadataPropagationOnControlPlane(): Unit = { + servers = makeServers(1, + listeners = Some("PLAINTEXT://localhost:0,CONTROLLER://localhost:0"), + listenerSecurityProtocolMap = Some("PLAINTEXT:PLAINTEXT,CONTROLLER:PLAINTEXT"), + controlPlaneListenerName = Some("CONTROLLER")) + TestUtils.waitUntilBrokerMetadataIsPropagated(servers) + val 
controlPlaneMetricMap = mutable.Map[String, KafkaMetric]() + val dataPlaneMetricMap = mutable.Map[String, KafkaMetric]() + servers.head.metrics.metrics.values.forEach { kafkaMetric => + if (kafkaMetric.metricName.tags.values.contains("CONTROLLER")) { + controlPlaneMetricMap.put(kafkaMetric.metricName().name(), kafkaMetric) + } + if (kafkaMetric.metricName.tags.values.contains("PLAINTEXT")) { + dataPlaneMetricMap.put(kafkaMetric.metricName.name, kafkaMetric) + } + } + assertEquals(1e-0, controlPlaneMetricMap("response-total").metricValue().asInstanceOf[Double], 0) + assertEquals(0e-0, dataPlaneMetricMap("response-total").metricValue().asInstanceOf[Double], 0) + assertEquals(1e-0, controlPlaneMetricMap("request-total").metricValue().asInstanceOf[Double], 0) + assertEquals(0e-0, dataPlaneMetricMap("request-total").metricValue().asInstanceOf[Double], 0) + assertTrue(controlPlaneMetricMap("incoming-byte-total").metricValue().asInstanceOf[Double] > 1.0) + assertTrue(dataPlaneMetricMap("incoming-byte-total").metricValue().asInstanceOf[Double] == 0.0) + assertTrue(controlPlaneMetricMap("network-io-total").metricValue().asInstanceOf[Double] == 2.0) + assertTrue(dataPlaneMetricMap("network-io-total").metricValue().asInstanceOf[Double] == 0.0) + } + + // This test case is used to ensure that there will be no correctness issue after we avoid sending out full + // UpdateMetadataRequest to all brokers in the cluster + @Test + def testMetadataPropagationOnBrokerChange(): Unit = { + servers = makeServers(3) + TestUtils.waitUntilBrokerMetadataIsPropagated(servers) + val controllerId = TestUtils.waitUntilControllerElected(zkClient) + // Need to make sure the broker we shut down and start up is not the controller. Otherwise we will send out + // a full UpdateMetadataRequest to all brokers during controller failover. 
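+ // Descriptive comment (added): the replica assignment below keeps the test broker out of every leadership
+ // position (follower on partition 0, not assigned to partition 1), so bouncing it never forces a leader change.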
+ val testBroker = servers.filter(e => e.config.brokerId != controllerId).head + val remainingBrokers = servers.filter(_.config.brokerId != testBroker.config.brokerId) + val topic = "topic1" + // Make sure that shutting down the test broker will not require any leadership change, so that we can verify that + // no full UpdateMetadataRequest is sent out on broker failure + val assignment = Map( + 0 -> Seq(remainingBrokers(0).config.brokerId, testBroker.config.brokerId), + 1 -> remainingBrokers.map(_.config.brokerId)) + + // Create topic + TestUtils.createTopic(zkClient, topic, assignment, servers) + + // Shutdown the broker + testBroker.shutdown() + testBroker.awaitShutdown() + TestUtils.waitUntilBrokerMetadataIsPropagated(remainingBrokers) + remainingBrokers.foreach { server => + val offlineReplicaPartitionInfo = server.metadataCache.getPartitionInfo(topic, 0).get + assertEquals(1, offlineReplicaPartitionInfo.offlineReplicas.size()) + assertEquals(testBroker.config.brokerId, offlineReplicaPartitionInfo.offlineReplicas.get(0)) + assertEquals(assignment(0).asJava, offlineReplicaPartitionInfo.replicas) + assertEquals(Seq(remainingBrokers.head.config.brokerId).asJava, offlineReplicaPartitionInfo.isr) + val onlinePartitionInfo = server.metadataCache.getPartitionInfo(topic, 1).get + assertEquals(assignment(1).asJava, onlinePartitionInfo.replicas) + assertTrue(onlinePartitionInfo.offlineReplicas.isEmpty) + } + + // Startup the broker + testBroker.startup() + TestUtils.waitUntilTrue( () => { + !servers.exists { server => + assignment.exists { case (partitionId, replicas) => + val partitionInfoOpt = server.metadataCache.getPartitionInfo(topic, partitionId) + if (partitionInfoOpt.isDefined) { + val partitionInfo = partitionInfoOpt.get + !partitionInfo.offlineReplicas.isEmpty || !partitionInfo.replicas.asScala.equals(replicas) + } else { + true + } + } + } + }, "Inconsistent metadata after broker startup") + } + + @Test + def testMetadataPropagationForOfflineReplicas(): Unit = { + servers = makeServers(3) + TestUtils.waitUntilBrokerMetadataIsPropagated(servers) + val controllerId = TestUtils.waitUntilControllerElected(zkClient) + + // Get the broker id for topic creation with a single partition and RF = 1 + val replicaBroker = servers.filter(e => e.config.brokerId != controllerId).head + + val controllerBroker = servers.filter(e => e.config.brokerId == controllerId).head + val otherBroker = servers.filter(e => e.config.brokerId != controllerId && + e.config.brokerId != replicaBroker.config.brokerId).head + + val topic = "topic1" + val assignment = Map(0 -> Seq(replicaBroker.config.brokerId)) + + // Create topic + TestUtils.createTopic(zkClient, topic, assignment, servers) + + // Shutdown the other broker + otherBroker.shutdown() + otherBroker.awaitShutdown() + + // Shutdown the broker with replica + replicaBroker.shutdown() + replicaBroker.awaitShutdown() + + // Shutdown the controller broker + controllerBroker.shutdown() + controllerBroker.awaitShutdown() + + def verifyMetadata(broker: KafkaServer): Unit = { + broker.startup() + TestUtils.waitUntilTrue(() => { + val partitionInfoOpt = broker.metadataCache.getPartitionInfo(topic, 0) + if (partitionInfoOpt.isDefined) { + val partitionInfo = partitionInfoOpt.get + (!partitionInfo.offlineReplicas.isEmpty && partitionInfo.leader == -1 + && !partitionInfo.replicas.isEmpty && !partitionInfo.isr.isEmpty) + } else { + false + } + }, "Inconsistent metadata after broker startup") + } + + // Start the controller broker and check metadata + verifyMetadata(controllerBroker) + + // Start the other broker and check 
metadata + verifyMetadata(otherBroker) + } + + @Test + def testMetadataPropagationOnBrokerShutdownWithNoReplicas(): Unit = { + servers = makeServers(3) + TestUtils.waitUntilBrokerMetadataIsPropagated(servers) + val controllerId = TestUtils.waitUntilControllerElected(zkClient) + val replicaBroker = servers.filter(e => e.config.brokerId != controllerId).head + + val controllerBroker = servers.filter(e => e.config.brokerId == controllerId).head + val otherBroker = servers.filter(e => e.config.brokerId != controllerId && + e.config.brokerId != replicaBroker.config.brokerId).head + + val topic = "topic1" + val assignment = Map(0 -> Seq(replicaBroker.config.brokerId)) + + // Create topic + TestUtils.createTopic(zkClient, topic, assignment, servers) + + // Shutdown the broker with replica + replicaBroker.shutdown() + replicaBroker.awaitShutdown() + + // Shutdown the other broker + otherBroker.shutdown() + otherBroker.awaitShutdown() + + // The controller should be the only alive broker + TestUtils.waitUntilBrokerMetadataIsPropagated(Seq(controllerBroker)) + } + + @Test + def testTopicCreation(): Unit = { + servers = makeServers(1) + val tp = new TopicPartition("t", 0) + val assignment = Map(tp.partition -> Seq(0)) + TestUtils.createTopic(zkClient, tp.topic, partitionReplicaAssignment = assignment, servers = servers) + waitForPartitionState(tp, firstControllerEpoch, 0, LeaderAndIsr.initialLeaderEpoch, + "failed to get expected partition state upon topic creation") + } + + @Test + def testTopicCreationWithOfflineReplica(): Unit = { + servers = makeServers(2) + val controllerId = TestUtils.waitUntilControllerElected(zkClient) + val otherBrokerId = servers.map(_.config.brokerId).filter(_ != controllerId).head + servers(otherBrokerId).shutdown() + servers(otherBrokerId).awaitShutdown() + val tp = new TopicPartition("t", 0) + val assignment = Map(tp.partition -> Seq(otherBrokerId, controllerId)) + TestUtils.createTopic(zkClient, tp.topic, partitionReplicaAssignment = assignment, servers = servers.take(1)) + waitForPartitionState(tp, firstControllerEpoch, controllerId, LeaderAndIsr.initialLeaderEpoch, + "failed to get expected partition state upon topic creation") + } + + @Test + def testTopicPartitionExpansion(): Unit = { + servers = makeServers(1) + val tp0 = new TopicPartition("t", 0) + val tp1 = new TopicPartition("t", 1) + val assignment = Map(tp0.partition -> Seq(0)) + val expandedAssignment = Map( + tp0 -> ReplicaAssignment(Seq(0), Seq(), Seq()), + tp1 -> ReplicaAssignment(Seq(0), Seq(), Seq())) + TestUtils.createTopic(zkClient, tp0.topic, partitionReplicaAssignment = assignment, servers = servers) + zkClient.setTopicAssignment(tp0.topic, Some(Uuid.randomUuid()), expandedAssignment, firstControllerEpochZkVersion) + waitForPartitionState(tp1, firstControllerEpoch, 0, LeaderAndIsr.initialLeaderEpoch, + "failed to get expected partition state upon topic partition expansion") + TestUtils.waitForPartitionMetadata(servers, tp1.topic, tp1.partition) + } + + @Test + def testTopicPartitionExpansionWithOfflineReplica(): Unit = { + servers = makeServers(2) + val controllerId = TestUtils.waitUntilControllerElected(zkClient) + val otherBrokerId = servers.map(_.config.brokerId).filter(_ != controllerId).head + val tp0 = new TopicPartition("t", 0) + val tp1 = new TopicPartition("t", 1) + val assignment = Map(tp0.partition -> Seq(otherBrokerId, controllerId)) + val expandedAssignment = Map( + tp0 -> ReplicaAssignment(Seq(otherBrokerId, controllerId), Seq(), Seq()), + tp1 -> ReplicaAssignment(Seq(otherBrokerId, 
controllerId), Seq(), Seq())) + TestUtils.createTopic(zkClient, tp0.topic, partitionReplicaAssignment = assignment, servers = servers) + servers(otherBrokerId).shutdown() + servers(otherBrokerId).awaitShutdown() + zkClient.setTopicAssignment(tp0.topic, Some(Uuid.randomUuid()), expandedAssignment, firstControllerEpochZkVersion) + waitForPartitionState(tp1, firstControllerEpoch, controllerId, LeaderAndIsr.initialLeaderEpoch, + "failed to get expected partition state upon topic partition expansion") + TestUtils.waitForPartitionMetadata(Seq(servers(controllerId)), tp1.topic, tp1.partition) + } + + @Test + def testPartitionReassignment(): Unit = { + servers = makeServers(2) + val controllerId = TestUtils.waitUntilControllerElected(zkClient) + + val metricName = s"kafka.controller:type=ControllerStats,name=${ControllerState.AlterPartitionReassignment.rateAndTimeMetricName.get}" + val timerCount = timer(metricName).count + + val otherBrokerId = servers.map(_.config.brokerId).filter(_ != controllerId).head + val tp = new TopicPartition("t", 0) + val assignment = Map(tp.partition -> Seq(controllerId)) + val reassignment = Map(tp -> ReplicaAssignment(Seq(otherBrokerId), List(), List())) + TestUtils.createTopic(zkClient, tp.topic, partitionReplicaAssignment = assignment, servers = servers) + zkClient.createPartitionReassignment(reassignment.map { case (k, v) => k -> v.replicas }) + waitForPartitionState(tp, firstControllerEpoch, otherBrokerId, LeaderAndIsr.initialLeaderEpoch + 3, + "failed to get expected partition state after partition reassignment") + TestUtils.waitUntilTrue(() => zkClient.getFullReplicaAssignmentForTopics(Set(tp.topic)) == reassignment, + "failed to get updated partition assignment on topic znode after partition reassignment") + TestUtils.waitUntilTrue(() => !zkClient.reassignPartitionsInProgress, + "failed to remove reassign partitions path after completion") + + val updatedTimerCount = timer(metricName).count + assertTrue(updatedTimerCount > timerCount, + s"Timer count $updatedTimerCount should be greater than $timerCount") + } + + @Test + def testPartitionReassignmentToBrokerWithOfflineLogDir(): Unit = { + servers = makeServers(2, logDirCount = 2) + val controllerId = TestUtils.waitUntilControllerElected(zkClient) + + val metricName = s"kafka.controller:type=ControllerStats,name=${ControllerState.AlterPartitionReassignment.rateAndTimeMetricName.get}" + val timerCount = timer(metricName).count + + val otherBroker = servers.filter(_.config.brokerId != controllerId).head + val otherBrokerId = otherBroker.config.brokerId + + // To have an offline log dir, we need a topicPartition assigned to it + val topicPartitionToPutOffline = new TopicPartition("filler", 0) + TestUtils.createTopic( + zkClient, + topicPartitionToPutOffline.topic, + partitionReplicaAssignment = Map(topicPartitionToPutOffline.partition -> Seq(otherBrokerId)), + servers = servers + ) + + TestUtils.causeLogDirFailure(TestUtils.Checkpoint, otherBroker, topicPartitionToPutOffline) + + val tp = new TopicPartition("t", 0) + val assignment = Map(tp.partition -> Seq(controllerId)) + val reassignment = Map(tp -> ReplicaAssignment(Seq(otherBrokerId), List(), List())) + TestUtils.createTopic(zkClient, tp.topic, partitionReplicaAssignment = assignment, servers = servers) + zkClient.createPartitionReassignment(reassignment.map { case (k, v) => k -> v.replicas }) + waitForPartitionState(tp, firstControllerEpoch, otherBrokerId, LeaderAndIsr.initialLeaderEpoch + 3, + "with an offline log directory on the target broker, the 
partition reassignment stalls") + TestUtils.waitUntilTrue(() => zkClient.getFullReplicaAssignmentForTopics(Set(tp.topic)) == reassignment, + "failed to get updated partition assignment on topic znode after partition reassignment") + TestUtils.waitUntilTrue(() => !zkClient.reassignPartitionsInProgress, + "failed to remove reassign partitions path after completion") + + val updatedTimerCount = timer(metricName).count + assertTrue(updatedTimerCount > timerCount, + s"Timer count $updatedTimerCount should be greater than $timerCount") + } + + @Test + def testPartitionReassignmentWithOfflineReplicaHaltingProgress(): Unit = { + servers = makeServers(2) + val controllerId = TestUtils.waitUntilControllerElected(zkClient) + val otherBrokerId = servers.map(_.config.brokerId).filter(_ != controllerId).head + val tp = new TopicPartition("t", 0) + val assignment = Map(tp.partition -> Seq(controllerId)) + val reassignment = Map(tp -> Seq(otherBrokerId)) + TestUtils.createTopic(zkClient, tp.topic, partitionReplicaAssignment = assignment, servers = servers) + servers(otherBrokerId).shutdown() + servers(otherBrokerId).awaitShutdown() + val controller = getController() + zkClient.setOrCreatePartitionReassignment(reassignment, controller.kafkaController.controllerContext.epochZkVersion) + waitForPartitionState(tp, firstControllerEpoch, controllerId, LeaderAndIsr.initialLeaderEpoch + 1, + "failed to get expected partition state during partition reassignment with offline replica") + TestUtils.waitUntilTrue(() => zkClient.reassignPartitionsInProgress, + "partition reassignment path should remain while reassignment in progress") + } + + @Test + def testPartitionReassignmentResumesAfterReplicaComesOnline(): Unit = { + servers = makeServers(2) + val controllerId = TestUtils.waitUntilControllerElected(zkClient) + val otherBrokerId = servers.map(_.config.brokerId).filter(_ != controllerId).head + val tp = new TopicPartition("t", 0) + val assignment = Map(tp.partition -> Seq(controllerId)) + val reassignment = Map(tp -> ReplicaAssignment(Seq(otherBrokerId), List(), List())) + TestUtils.createTopic(zkClient, tp.topic, partitionReplicaAssignment = assignment, servers = servers) + servers(otherBrokerId).shutdown() + servers(otherBrokerId).awaitShutdown() + zkClient.createPartitionReassignment(reassignment.map { case (k, v) => k -> v.replicas }) + waitForPartitionState(tp, firstControllerEpoch, controllerId, LeaderAndIsr.initialLeaderEpoch + 1, + "failed to get expected partition state during partition reassignment with offline replica") + servers(otherBrokerId).startup() + waitForPartitionState(tp, firstControllerEpoch, otherBrokerId, LeaderAndIsr.initialLeaderEpoch + 4, + "failed to get expected partition state after partition reassignment") + TestUtils.waitUntilTrue(() => zkClient.getFullReplicaAssignmentForTopics(Set(tp.topic)) == reassignment, + "failed to get updated partition assignment on topic znode after partition reassignment") + TestUtils.waitUntilTrue(() => !zkClient.reassignPartitionsInProgress, + "failed to remove reassign partitions path after completion") + } + + @Test + def testPreferredReplicaLeaderElection(): Unit = { + servers = makeServers(2) + val controllerId = TestUtils.waitUntilControllerElected(zkClient) + val otherBroker = servers.find(_.config.brokerId != controllerId).get + val tp = new TopicPartition("t", 0) + val assignment = Map(tp.partition -> Seq(otherBroker.config.brokerId, controllerId)) + TestUtils.createTopic(zkClient, tp.topic, partitionReplicaAssignment = assignment, servers = 
servers) + preferredReplicaLeaderElection(controllerId, otherBroker, tp, assignment(tp.partition).toSet, LeaderAndIsr.initialLeaderEpoch) + } + + @Test + def testBackToBackPreferredReplicaLeaderElections(): Unit = { + servers = makeServers(2) + val controllerId = TestUtils.waitUntilControllerElected(zkClient) + val otherBroker = servers.find(_.config.brokerId != controllerId).get + val tp = new TopicPartition("t", 0) + val assignment = Map(tp.partition -> Seq(otherBroker.config.brokerId, controllerId)) + TestUtils.createTopic(zkClient, tp.topic, partitionReplicaAssignment = assignment, servers = servers) + preferredReplicaLeaderElection(controllerId, otherBroker, tp, assignment(tp.partition).toSet, LeaderAndIsr.initialLeaderEpoch) + preferredReplicaLeaderElection(controllerId, otherBroker, tp, assignment(tp.partition).toSet, LeaderAndIsr.initialLeaderEpoch + 2) + } + + @Test + def testPreferredReplicaLeaderElectionWithOfflinePreferredReplica(): Unit = { + servers = makeServers(2) + val controllerId = TestUtils.waitUntilControllerElected(zkClient) + val otherBrokerId = servers.map(_.config.brokerId).filter(_ != controllerId).head + val tp = new TopicPartition("t", 0) + val assignment = Map(tp.partition -> Seq(otherBrokerId, controllerId)) + TestUtils.createTopic(zkClient, tp.topic, partitionReplicaAssignment = assignment, servers = servers) + servers(otherBrokerId).shutdown() + servers(otherBrokerId).awaitShutdown() + zkClient.createPreferredReplicaElection(Set(tp)) + TestUtils.waitUntilTrue(() => !zkClient.pathExists(PreferredReplicaElectionZNode.path), + "failed to remove preferred replica leader election path after giving up") + waitForPartitionState(tp, firstControllerEpoch, controllerId, LeaderAndIsr.initialLeaderEpoch + 1, + "failed to get expected partition state upon broker shutdown") + } + + @Test + def testAutoPreferredReplicaLeaderElection(): Unit = { + servers = makeServers(2, autoLeaderRebalanceEnable = true) + val controllerId = TestUtils.waitUntilControllerElected(zkClient) + val otherBrokerId = servers.map(_.config.brokerId).filter(_ != controllerId).head + val tp = new TopicPartition("t", 0) + val assignment = Map(tp.partition -> Seq(1, 0)) + TestUtils.createTopic(zkClient, tp.topic, partitionReplicaAssignment = assignment, servers = servers) + servers(otherBrokerId).shutdown() + servers(otherBrokerId).awaitShutdown() + waitForPartitionState(tp, firstControllerEpoch, controllerId, LeaderAndIsr.initialLeaderEpoch + 1, + "failed to get expected partition state upon broker shutdown") + servers(otherBrokerId).startup() + waitForPartitionState(tp, firstControllerEpoch, otherBrokerId, LeaderAndIsr.initialLeaderEpoch + 2, + "failed to get expected partition state upon broker startup") + } + + @Test + def testLeaderAndIsrWhenEntireIsrOfflineAndUncleanLeaderElectionDisabled(): Unit = { + servers = makeServers(2) + val controllerId = TestUtils.waitUntilControllerElected(zkClient) + val otherBrokerId = servers.map(_.config.brokerId).filter(_ != controllerId).head + val tp = new TopicPartition("t", 0) + val assignment = Map(tp.partition -> Seq(otherBrokerId)) + TestUtils.createTopic(zkClient, tp.topic, partitionReplicaAssignment = assignment, servers = servers) + waitForPartitionState(tp, firstControllerEpoch, otherBrokerId, LeaderAndIsr.initialLeaderEpoch, + "failed to get expected partition state upon topic creation") + servers(otherBrokerId).shutdown() + servers(otherBrokerId).awaitShutdown() + TestUtils.waitUntilTrue(() => { + val leaderIsrAndControllerEpochMap = 
zkClient.getTopicPartitionStates(Seq(tp)) + leaderIsrAndControllerEpochMap.contains(tp) && + isExpectedPartitionState(leaderIsrAndControllerEpochMap(tp), firstControllerEpoch, LeaderAndIsr.NoLeader, LeaderAndIsr.initialLeaderEpoch + 1) && + leaderIsrAndControllerEpochMap(tp).leaderAndIsr.isr == List(otherBrokerId) + }, "failed to get expected partition state after entire isr went offline") + } + + @Test + def testLeaderAndIsrWhenEntireIsrOfflineAndUncleanLeaderElectionEnabled(): Unit = { + servers = makeServers(2, uncleanLeaderElectionEnable = true) + val controllerId = TestUtils.waitUntilControllerElected(zkClient) + val otherBrokerId = servers.map(_.config.brokerId).filter(_ != controllerId).head + val tp = new TopicPartition("t", 0) + val assignment = Map(tp.partition -> Seq(otherBrokerId)) + TestUtils.createTopic(zkClient, tp.topic, partitionReplicaAssignment = assignment, servers = servers) + waitForPartitionState(tp, firstControllerEpoch, otherBrokerId, LeaderAndIsr.initialLeaderEpoch, + "failed to get expected partition state upon topic creation") + servers(otherBrokerId).shutdown() + servers(otherBrokerId).awaitShutdown() + TestUtils.waitUntilTrue(() => { + val leaderIsrAndControllerEpochMap = zkClient.getTopicPartitionStates(Seq(tp)) + leaderIsrAndControllerEpochMap.contains(tp) && + isExpectedPartitionState(leaderIsrAndControllerEpochMap(tp), firstControllerEpoch, LeaderAndIsr.NoLeader, LeaderAndIsr.initialLeaderEpoch + 1) && + leaderIsrAndControllerEpochMap(tp).leaderAndIsr.isr == List(otherBrokerId) + }, "failed to get expected partition state after entire isr went offline") + } + + @Test + def testControlledShutdown(): Unit = { + val expectedReplicaAssignment = Map(0 -> List(0, 1, 2)) + val topic = "test" + val partition = 0 + // create brokers + val serverConfigs = TestUtils.createBrokerConfigs(3, zkConnect, false).map(KafkaConfig.fromProps) + servers = serverConfigs.reverse.map(s => TestUtils.createServer(s)) + // create the topic + TestUtils.createTopic(zkClient, topic, partitionReplicaAssignment = expectedReplicaAssignment, servers = servers) + + val controllerId = zkClient.getControllerId.get + val controller = servers.find(p => p.config.brokerId == controllerId).get.kafkaController + val resultQueue = new LinkedBlockingQueue[Try[collection.Set[TopicPartition]]]() + val controlledShutdownCallback = (controlledShutdownResult: Try[collection.Set[TopicPartition]]) => resultQueue.put(controlledShutdownResult) + controller.controlledShutdown(2, servers.find(_.config.brokerId == 2).get.kafkaController.brokerEpoch, controlledShutdownCallback) + var partitionsRemaining = resultQueue.take().get + var activeServers = servers.filter(s => s.config.brokerId != 2) + // wait for the update metadata request to trickle to the brokers + TestUtils.waitUntilTrue(() => + activeServers.forall(_.dataPlaneRequestProcessor.metadataCache.getPartitionInfo(topic,partition).get.isr.size != 3), + "Topic test not created after timeout") + assertEquals(0, partitionsRemaining.size) + var partitionStateInfo = activeServers.head.dataPlaneRequestProcessor.metadataCache.getPartitionInfo(topic,partition).get + var leaderAfterShutdown = partitionStateInfo.leader + assertEquals(0, leaderAfterShutdown) + assertEquals(2, partitionStateInfo.isr.size) + assertEquals(List(0,1), partitionStateInfo.isr.asScala) + controller.controlledShutdown(1, servers.find(_.config.brokerId == 1).get.kafkaController.brokerEpoch, controlledShutdownCallback) + partitionsRemaining = resultQueue.take() match { + case Success(partitions) 
=> partitions + case Failure(exception) => throw new AssertionError("Controlled shutdown failed due to error", exception) + } + assertEquals(0, partitionsRemaining.size) + activeServers = servers.filter(s => s.config.brokerId == 0) + partitionStateInfo = activeServers.head.dataPlaneRequestProcessor.metadataCache.getPartitionInfo(topic,partition).get + leaderAfterShutdown = partitionStateInfo.leader + assertEquals(0, leaderAfterShutdown) + + assertTrue(servers.forall(_.dataPlaneRequestProcessor.metadataCache.getPartitionInfo(topic,partition).get.leader == 0)) + controller.controlledShutdown(0, servers.find(_.config.brokerId == 0).get.kafkaController.brokerEpoch, controlledShutdownCallback) + partitionsRemaining = resultQueue.take().get + assertEquals(1, partitionsRemaining.size) + // leader doesn't change since all the replicas are shut down + assertTrue(servers.forall(_.dataPlaneRequestProcessor.metadataCache.getPartitionInfo(topic,partition).get.leader == 0)) + } + + @Test + def testControllerRejectControlledShutdownRequestWithStaleBrokerEpoch(): Unit = { + // create brokers + val serverConfigs = TestUtils.createBrokerConfigs(2, zkConnect, false).map(KafkaConfig.fromProps) + servers = serverConfigs.reverse.map(s => TestUtils.createServer(s)) + + val controller = getController().kafkaController + val otherBroker = servers.find(e => e.config.brokerId != controller.config.brokerId).get + @volatile var staleBrokerEpochDetected = false + controller.controlledShutdown(otherBroker.config.brokerId, otherBroker.kafkaController.brokerEpoch - 1, { + case scala.util.Failure(exception) if exception.isInstanceOf[StaleBrokerEpochException] => staleBrokerEpochDetected = true + case _ => + }) + + TestUtils.waitUntilTrue(() => staleBrokerEpochDetected, "Fail to detect stale broker epoch") + } + + @Test + def testControllerMoveOnTopicCreation(): Unit = { + servers = makeServers(1) + TestUtils.waitUntilControllerElected(zkClient) + val tp = new TopicPartition("t", 0) + val assignment = Map(tp.partition -> Seq(0)) + + testControllerMove(() => { + val adminZkClient = new AdminZkClient(zkClient) + adminZkClient.createTopicWithAssignment(tp.topic, config = new Properties(), assignment) + }) + } + + @Test + def testControllerMoveOnTopicDeletion(): Unit = { + servers = makeServers(1) + TestUtils.waitUntilControllerElected(zkClient) + val tp = new TopicPartition("t", 0) + val assignment = Map(tp.partition -> Seq(0)) + TestUtils.createTopic(zkClient, tp.topic(), assignment, servers) + + testControllerMove(() => { + val adminZkClient = new AdminZkClient(zkClient) + adminZkClient.deleteTopic(tp.topic()) + }) + } + + @Test + def testControllerMoveOnPreferredReplicaElection(): Unit = { + servers = makeServers(1) + val tp = new TopicPartition("t", 0) + val assignment = Map(tp.partition -> Seq(0)) + TestUtils.createTopic(zkClient, tp.topic(), assignment, servers) + + testControllerMove(() => zkClient.createPreferredReplicaElection(Set(tp))) + } + + @Test + def testControllerMoveOnPartitionReassignment(): Unit = { + servers = makeServers(1) + TestUtils.waitUntilControllerElected(zkClient) + val tp = new TopicPartition("t", 0) + val assignment = Map(tp.partition -> Seq(0)) + TestUtils.createTopic(zkClient, tp.topic(), assignment, servers) + + val reassignment = Map(tp -> Seq(0)) + testControllerMove(() => zkClient.createPartitionReassignment(reassignment)) + } + + @Test + def testControllerFeatureZNodeSetupWhenFeatureVersioningIsEnabledWithNonExistingFeatureZNode(): Unit = { + testControllerFeatureZNodeSetup(Option.empty, 
KAFKA_2_7_IV0) + } + + @Test + def testControllerFeatureZNodeSetupWhenFeatureVersioningIsEnabledWithDisabledExistingFeatureZNode(): Unit = { + testControllerFeatureZNodeSetup(Some(new FeatureZNode(FeatureZNodeStatus.Disabled, Features.emptyFinalizedFeatures())), KAFKA_2_7_IV0) + } + + @Test + def testControllerFeatureZNodeSetupWhenFeatureVersioningIsEnabledWithEnabledExistingFeatureZNode(): Unit = { + testControllerFeatureZNodeSetup(Some(new FeatureZNode(FeatureZNodeStatus.Enabled, Features.emptyFinalizedFeatures())), KAFKA_2_7_IV0) + } + + @Test + def testControllerFeatureZNodeSetupWhenFeatureVersioningIsDisabledWithNonExistingFeatureZNode(): Unit = { + testControllerFeatureZNodeSetup(Option.empty, KAFKA_2_6_IV0) + } + + @Test + def testControllerFeatureZNodeSetupWhenFeatureVersioningIsDisabledWithDisabledExistingFeatureZNode(): Unit = { + testControllerFeatureZNodeSetup(Some(new FeatureZNode(FeatureZNodeStatus.Disabled, Features.emptyFinalizedFeatures())), KAFKA_2_6_IV0) + } + + @Test + def testControllerFeatureZNodeSetupWhenFeatureVersioningIsDisabledWithEnabledExistingFeatureZNode(): Unit = { + testControllerFeatureZNodeSetup(Some(new FeatureZNode(FeatureZNodeStatus.Enabled, Features.emptyFinalizedFeatures())), KAFKA_2_6_IV0) + } + + @Test + def testControllerDetectsBouncedBrokers(): Unit = { + servers = makeServers(2, enableControlledShutdown = false) + val controller = getController().kafkaController + val otherBroker = servers.find(e => e.config.brokerId != controller.config.brokerId).get + + // Create a topic + val tp = new TopicPartition("t", 0) + val assignment = Map(tp.partition -> Seq(0, 1)) + + TestUtils.createTopic(zkClient, tp.topic, partitionReplicaAssignment = assignment, servers = servers) + waitForPartitionState(tp, firstControllerEpoch, 0, LeaderAndIsr.initialLeaderEpoch, + "failed to get expected partition state upon topic creation") + + // Wait until the event thread is idle + TestUtils.waitUntilTrue(() => { + controller.eventManager.state == ControllerState.Idle + }, "Controller event thread is still busy") + + val latch = new CountDownLatch(1) + + // Let the controller event thread await on a latch until broker bounce finishes. 
+ // This is used to simulate fast broker bounce + + controller.eventManager.put(new MockEvent(ControllerState.TopicChange) { + override def process(): Unit = latch.await() + override def preempt(): Unit = {} + }) + + otherBroker.shutdown() + otherBroker.startup() + + assertEquals(0, otherBroker.replicaManager.partitionCount.value()) + + // Release the latch so that the controller can process the broker change event + latch.countDown() + TestUtils.waitUntilTrue(() => { + otherBroker.replicaManager.partitionCount.value() == 1 && + otherBroker.replicaManager.metadataCache.getAllTopics().size == 1 && + otherBroker.replicaManager.metadataCache.getAliveBrokers().size == 2 + }, "Broker failed to initialize after restart") + } + + @Test + def testPreemptionOnControllerShutdown(): Unit = { + servers = makeServers(1, enableControlledShutdown = false) + val controller = getController().kafkaController + var count = 2 + val latch = new CountDownLatch(1) + val spyThread = spy(controller.eventManager.thread) + controller.eventManager.thread = spyThread + val processedEvent = new MockEvent(ControllerState.TopicChange) { + override def process(): Unit = latch.await() + override def preempt(): Unit = {} + } + val preemptedEvent = new MockEvent(ControllerState.TopicChange) { + override def process(): Unit = {} + override def preempt(): Unit = count -= 1 + } + + controller.eventManager.put(processedEvent) + controller.eventManager.put(preemptedEvent) + controller.eventManager.put(preemptedEvent) + + doAnswer((_: InvocationOnMock) => { + latch.countDown() + }).doCallRealMethod().when(spyThread).awaitShutdown() + controller.shutdown() + TestUtils.waitUntilTrue(() => { + count == 0 + }, "preemption was not fully completed before shutdown") + + verify(spyThread).awaitShutdown() + } + + @Test + def testPreemptionWithCallbacks(): Unit = { + servers = makeServers(1, enableControlledShutdown = false) + val controller = getController().kafkaController + val latch = new CountDownLatch(1) + val spyThread = spy(controller.eventManager.thread) + controller.eventManager.thread = spyThread + val processedEvent = new MockEvent(ControllerState.TopicChange) { + override def process(): Unit = latch.await() + override def preempt(): Unit = {} + } + val tp0 = new TopicPartition("t", 0) + val tp1 = new TopicPartition("t", 1) + val partitions = Set(tp0, tp1) + val event1 = ReplicaLeaderElection(Some(partitions), ElectionType.PREFERRED, ZkTriggered, partitionsMap => { + for (partition <- partitionsMap) { + partition._2 match { + case Left(e) => assertEquals(Errors.NOT_CONTROLLER, e.error()) + case Right(_) => throw new AssertionError("replica leader election should error") + } + } + }) + val event2 = ControlledShutdown(0, 0, { + case Success(_) => throw new AssertionError("controlled shutdown should error") + case Failure(e) => + assertEquals(classOf[ControllerMovedException], e.getClass) + }) + val event3 = ApiPartitionReassignment(Map(tp0 -> None, tp1 -> None), { + case Left(_) => throw new AssertionError("api partition reassignment should error") + case Right(e) => assertEquals(Errors.NOT_CONTROLLER, e.error()) + }) + val event4 = ListPartitionReassignments(Some(partitions), { + case Left(_) => throw new AssertionError("api partition reassignment should error") + case Right(e) => assertEquals(Errors.NOT_CONTROLLER, e.error()) + }) + + controller.eventManager.put(processedEvent) + controller.eventManager.put(event1) + controller.eventManager.put(event2) + controller.eventManager.put(event3) + controller.eventManager.put(event4) + + 
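+ // The stubbed awaitShutdown below first releases the latch so the blocked event can finish; controller.shutdown() is then expected to preempt the still-queued events, completing their callbacks with the NOT_CONTROLLER / ControllerMovedException results asserted above.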
doAnswer((_: InvocationOnMock) => { + latch.countDown() + }).doCallRealMethod().when(spyThread).awaitShutdown() + controller.shutdown() + } + + private def testControllerFeatureZNodeSetup(initialZNode: Option[FeatureZNode], + interBrokerProtocolVersion: ApiVersion): Unit = { + val versionBeforeOpt = initialZNode match { + case Some(node) => + zkClient.createFeatureZNode(node) + Some(zkClient.getDataAndVersion(FeatureZNode.path)._2) + case None => + Option.empty + } + servers = makeServers(1, interBrokerProtocolVersion = Some(interBrokerProtocolVersion)) + TestUtils.waitUntilControllerElected(zkClient) + // Below we wait on a dummy event to finish processing in the controller event thread. + // We schedule this dummy event only after the controller is elected, which is a sign that the + // controller has already started processing the Startup event. Waiting on the dummy event is + // used to make sure that the event thread has completed processing Startup event, that triggers + // the setup of FeatureZNode. + val controller = getController().kafkaController + val latch = new CountDownLatch(1) + controller.eventManager.put(new MockEvent(ControllerState.TopicChange) { + override def process(): Unit = { + latch.countDown() + } + override def preempt(): Unit = {} + }) + latch.await() + + val (mayBeFeatureZNodeBytes, versionAfter) = zkClient.getDataAndVersion(FeatureZNode.path) + val newZNode = FeatureZNode.decode(mayBeFeatureZNodeBytes.get) + if (interBrokerProtocolVersion >= KAFKA_2_7_IV0) { + val emptyZNode = new FeatureZNode(FeatureZNodeStatus.Enabled, Features.emptyFinalizedFeatures) + initialZNode match { + case Some(node) => { + node.status match { + case FeatureZNodeStatus.Enabled => + assertEquals(versionBeforeOpt.get, versionAfter) + assertEquals(node, newZNode) + case FeatureZNodeStatus.Disabled => + assertEquals(versionBeforeOpt.get + 1, versionAfter) + assertEquals(emptyZNode, newZNode) + } + } + case None => + assertEquals(0, versionAfter) + assertEquals(new FeatureZNode(FeatureZNodeStatus.Enabled, Features.emptyFinalizedFeatures), newZNode) + } + } else { + val emptyZNode = new FeatureZNode(FeatureZNodeStatus.Disabled, Features.emptyFinalizedFeatures) + initialZNode match { + case Some(node) => { + node.status match { + case FeatureZNodeStatus.Enabled => + assertEquals(versionBeforeOpt.get + 1, versionAfter) + assertEquals(emptyZNode, newZNode) + case FeatureZNodeStatus.Disabled => + assertEquals(versionBeforeOpt.get, versionAfter) + assertEquals(emptyZNode, newZNode) + } + } + case None => + assertEquals(0, versionAfter) + assertEquals(new FeatureZNode(FeatureZNodeStatus.Disabled, Features.emptyFinalizedFeatures), newZNode) + } + } + } + + @Test + def testIdempotentAlterIsr(): Unit = { + servers = makeServers(2) + val controllerId = TestUtils.waitUntilControllerElected(zkClient) + val otherBroker = servers.find(_.config.brokerId != controllerId).get + val tp = new TopicPartition("t", 0) + val assignment = Map(tp.partition -> Seq(otherBroker.config.brokerId, controllerId)) + TestUtils.createTopic(zkClient, tp.topic, partitionReplicaAssignment = assignment, servers = servers) + + val latch = new CountDownLatch(1) + val controller = getController().kafkaController + + val leaderIsrAndControllerEpochMap = zkClient.getTopicPartitionStates(Seq(tp)) + val newLeaderAndIsr = leaderIsrAndControllerEpochMap(tp).leaderAndIsr + + val callback = (result: Either[Map[TopicPartition, Either[Errors, LeaderAndIsr]], Errors]) => { + result match { + case Left(partitionResults: Map[TopicPartition, 
Either[Errors, LeaderAndIsr]]) => + partitionResults.get(tp) match { + case Some(Left(error: Errors)) => throw new AssertionError(s"Should not have seen an error for $tp") + case Some(Right(leaderAndIsr: LeaderAndIsr)) => assertEquals(leaderAndIsr, newLeaderAndIsr, "ISR should remain unchanged") + case None => throw new AssertionError(s"Should have seen $tp in result") + } + case Right(_: Errors) => throw new AssertionError("Should not have had top-level error here") + } + latch.countDown() + } + + val brokerEpoch = controller.controllerContext.liveBrokerIdAndEpochs.get(otherBroker.config.brokerId).get + // When re-sending the current ISR, we should not get an error or any ISR changes + controller.eventManager.put(AlterIsrReceived(otherBroker.config.brokerId, brokerEpoch, Map(tp -> newLeaderAndIsr), callback)) + latch.await() + } + + @Test + def testAlterIsrErrors(): Unit = { + servers = makeServers(1) + val controllerId = TestUtils.waitUntilControllerElected(zkClient) + val tp = new TopicPartition("t", 0) + val assignment = Map(tp.partition -> Seq(controllerId)) + TestUtils.createTopic(zkClient, tp.topic, partitionReplicaAssignment = assignment, servers = servers) + val controller = getController().kafkaController + var future = captureAlterIsrError(controllerId, controller.brokerEpoch - 1, + Map(tp -> LeaderAndIsr(controllerId, List(controllerId)))) + var capturedError = future.get(5, TimeUnit.SECONDS) + assertEquals(Errors.STALE_BROKER_EPOCH, capturedError) + + future = captureAlterIsrError(99, controller.brokerEpoch, + Map(tp -> LeaderAndIsr(controllerId, List(controllerId)))) + capturedError = future.get(5, TimeUnit.SECONDS) + assertEquals(Errors.STALE_BROKER_EPOCH, capturedError) + + val unknownTopicPartition = new TopicPartition("unknown", 99) + future = captureAlterIsrPartitionError(controllerId, controller.brokerEpoch, + Map(unknownTopicPartition -> LeaderAndIsr(controllerId, List(controllerId))), unknownTopicPartition) + capturedError = future.get(5, TimeUnit.SECONDS) + assertEquals(Errors.UNKNOWN_TOPIC_OR_PARTITION, capturedError) + + future = captureAlterIsrPartitionError(controllerId, controller.brokerEpoch, + Map(tp -> LeaderAndIsr(controllerId, 1, List(controllerId), 99)), tp) + capturedError = future.get(5, TimeUnit.SECONDS) + assertEquals(Errors.INVALID_UPDATE_VERSION, capturedError) + } + + def captureAlterIsrError(brokerId: Int, brokerEpoch: Long, isrsToAlter: Map[TopicPartition, LeaderAndIsr]): CompletableFuture[Errors] = { + val future = new CompletableFuture[Errors]() + val controller = getController().kafkaController + val callback: AlterIsrCallback = { + case Left(_: Map[TopicPartition, Either[Errors, LeaderAndIsr]]) => + future.completeExceptionally(new AssertionError(s"Should have seen top-level error")) + case Right(error: Errors) => + future.complete(error) + } + controller.eventManager.put(AlterIsrReceived(brokerId, brokerEpoch, isrsToAlter, callback)) + future + } + + def captureAlterIsrPartitionError(brokerId: Int, brokerEpoch: Long, isrsToAlter: Map[TopicPartition, LeaderAndIsr], tp: TopicPartition): CompletableFuture[Errors] = { + val future = new CompletableFuture[Errors]() + val controller = getController().kafkaController + val callback: AlterIsrCallback = { + case Left(partitionResults: Map[TopicPartition, Either[Errors, LeaderAndIsr]]) => + partitionResults.get(tp) match { + case Some(Left(error: Errors)) => future.complete(error) + case Some(Right(_: LeaderAndIsr)) => future.completeExceptionally(new AssertionError(s"Should have seen an error for $tp in 
result")) + case None => future.completeExceptionally(new AssertionError(s"Should have seen $tp in result")) + } + case Right(_: Errors) => + future.completeExceptionally(new AssertionError(s"Should not seen top-level error")) + } + controller.eventManager.put(AlterIsrReceived(brokerId, brokerEpoch, isrsToAlter, callback)) + future + } + + @Test + def testTopicIdsAreAdded(): Unit = { + servers = makeServers(1) + TestUtils.waitUntilControllerElected(zkClient) + val controller = getController().kafkaController + val tp1 = new TopicPartition("t1", 0) + val assignment1 = Map(tp1.partition -> Seq(0)) + + // Before adding the topic, an attempt to get the ID should result in None. + assertEquals(None, controller.controllerContext.topicIds.get("t1")) + + TestUtils.createTopic(zkClient, tp1.topic(), assignment1, servers) + + // Test that the first topic has its ID added correctly + waitForPartitionState(tp1, firstControllerEpoch, 0, LeaderAndIsr.initialLeaderEpoch, + "failed to get expected partition state upon topic creation") + assertNotEquals(None, controller.controllerContext.topicIds.get("t1")) + val topicId1 = controller.controllerContext.topicIds("t1") + assertEquals("t1", controller.controllerContext.topicNames(topicId1)) + + val tp2 = new TopicPartition("t2", 0) + val assignment2 = Map(tp2.partition -> Seq(0)) + TestUtils.createTopic(zkClient, tp2.topic(), assignment2, servers) + + // Test that the second topic has its ID added correctly + waitForPartitionState(tp2, firstControllerEpoch, 0, LeaderAndIsr.initialLeaderEpoch, + "failed to get expected partition state upon topic creation") + assertNotEquals(None, controller.controllerContext.topicIds.get("t2")) + val topicId2 = controller.controllerContext.topicIds("t2") + assertEquals("t2", controller.controllerContext.topicNames(topicId2)) + + // The first topic ID has not changed + assertEquals(topicId1, controller.controllerContext.topicIds.get("t1").get) + assertNotEquals(topicId1, topicId2) + } + + @Test + def testTopicIdsAreNotAdded(): Unit = { + servers = makeServers(1, interBrokerProtocolVersion = Some(KAFKA_2_7_IV0)) + TestUtils.waitUntilControllerElected(zkClient) + val controller = getController().kafkaController + val tp1 = new TopicPartition("t1", 0) + val assignment1 = Map(tp1.partition -> Seq(0)) + + // Before adding the topic, an attempt to get the ID should result in None. + assertEquals(None, controller.controllerContext.topicIds.get("t1")) + + TestUtils.createTopic(zkClient, tp1.topic(), assignment1, servers) + + // Test that the first topic has no topic ID added. + waitForPartitionState(tp1, firstControllerEpoch, 0, LeaderAndIsr.initialLeaderEpoch, + "failed to get expected partition state upon topic creation") + assertEquals(None, controller.controllerContext.topicIds.get("t1")) + + val tp2 = new TopicPartition("t2", 0) + val assignment2 = Map(tp2.partition -> Seq(0)) + TestUtils.createTopic(zkClient, tp2.topic(), assignment2, servers) + + // Test that the second topic has no topic ID added. 
+ waitForPartitionState(tp2, firstControllerEpoch, 0, LeaderAndIsr.initialLeaderEpoch, + "failed to get expected partition state upon topic creation") + assertEquals(None, controller.controllerContext.topicIds.get("t2")) + + // The first topic ID has not changed + assertEquals(None, controller.controllerContext.topicIds.get("t1")) + } + + + @Test + def testTopicIdMigrationAndHandling(): Unit = { + val tp = new TopicPartition("t", 0) + val assignment = Map(tp.partition -> ReplicaAssignment(Seq(0), List(), List())) + val adminZkClient = new AdminZkClient(zkClient) + + servers = makeServers(1) + adminZkClient.createTopic(tp.topic, 1, 1) + waitForPartitionState(tp, firstControllerEpoch, 0, LeaderAndIsr.initialLeaderEpoch, + "failed to get expected partition state upon topic creation") + val topicIdAfterCreate = zkClient.getTopicIdsForTopics(Set(tp.topic())).get(tp.topic()) + assertTrue(topicIdAfterCreate.isDefined) + assertEquals(topicIdAfterCreate, servers.head.kafkaController.controllerContext.topicIds.get(tp.topic), + "correct topic ID cannot be found in the controller context") + + adminZkClient.addPartitions(tp.topic, assignment, adminZkClient.getBrokerMetadatas(), 2) + val topicIdAfterAddition = zkClient.getTopicIdsForTopics(Set(tp.topic())).get(tp.topic()) + assertEquals(topicIdAfterCreate, topicIdAfterAddition) + assertEquals(topicIdAfterCreate, servers.head.kafkaController.controllerContext.topicIds.get(tp.topic), + "topic ID changed after partition additions") + + adminZkClient.deleteTopic(tp.topic) + TestUtils.waitUntilTrue(() => servers.head.kafkaController.controllerContext.topicIds.get(tp.topic).isEmpty, + "topic ID for topic should have been removed from controller context after deletion") + } + + @Test + def testTopicIdMigrationAndHandlingWithOlderVersion(): Unit = { + val tp = new TopicPartition("t", 0) + val assignment = Map(tp.partition -> ReplicaAssignment(Seq(0), List(), List())) + val adminZkClient = new AdminZkClient(zkClient) + + servers = makeServers(1, interBrokerProtocolVersion = Some(KAFKA_2_7_IV0)) + adminZkClient.createTopic(tp.topic, 1, 1) + waitForPartitionState(tp, firstControllerEpoch, 0, LeaderAndIsr.initialLeaderEpoch, + "failed to get expected partition state upon topic creation") + val topicIdAfterCreate = zkClient.getTopicIdsForTopics(Set(tp.topic())).get(tp.topic()) + assertEquals(None, topicIdAfterCreate) + assertEquals(topicIdAfterCreate, servers.head.kafkaController.controllerContext.topicIds.get(tp.topic), + "incorrect topic ID can be found in the controller context") + + adminZkClient.addPartitions(tp.topic, assignment, adminZkClient.getBrokerMetadatas(), 2) + val topicIdAfterAddition = zkClient.getTopicIdsForTopics(Set(tp.topic())).get(tp.topic()) + assertEquals(topicIdAfterCreate, topicIdAfterAddition) + assertEquals(topicIdAfterCreate, servers.head.kafkaController.controllerContext.topicIds.get(tp.topic), + "topic ID changed after partition additions") + + adminZkClient.deleteTopic(tp.topic) + TestUtils.waitUntilTrue(() => !servers.head.kafkaController.controllerContext.allTopics.contains(tp.topic), + "topic should have been removed from controller context after deletion") + } + + @Test + def testTopicIdPersistsThroughControllerReelection(): Unit = { + servers = makeServers(2) + val controllerId = TestUtils.waitUntilControllerElected(zkClient) + val controller = getController().kafkaController + val tp = new TopicPartition("t", 0) + val assignment = Map(tp.partition -> Seq(controllerId)) + TestUtils.createTopic(zkClient, tp.topic, 
partitionReplicaAssignment = assignment, servers = servers) + waitForPartitionState(tp, firstControllerEpoch, controllerId, LeaderAndIsr.initialLeaderEpoch, + "failed to get expected partition state upon topic creation") + val topicId = controller.controllerContext.topicIds.get("t").get + + servers(controllerId).shutdown() + servers(controllerId).awaitShutdown() + TestUtils.waitUntilTrue(() => zkClient.getControllerId.isDefined, "failed to elect a controller") + val controller2 = getController().kafkaController + assertEquals(topicId, controller2.controllerContext.topicIds.get("t").get) + } + + @Test + def testNoTopicIdPersistsThroughControllerReelection(): Unit = { + servers = makeServers(2, interBrokerProtocolVersion = Some(KAFKA_2_7_IV0)) + val controllerId = TestUtils.waitUntilControllerElected(zkClient) + val controller = getController().kafkaController + val tp = new TopicPartition("t", 0) + val assignment = Map(tp.partition -> Seq(controllerId)) + TestUtils.createTopic(zkClient, tp.topic, partitionReplicaAssignment = assignment, servers = servers) + waitForPartitionState(tp, firstControllerEpoch, controllerId, LeaderAndIsr.initialLeaderEpoch, + "failed to get expected partition state upon topic creation") + val emptyTopicId = controller.controllerContext.topicIds.get("t") + assertEquals(None, emptyTopicId) + + servers(controllerId).shutdown() + servers(controllerId).awaitShutdown() + TestUtils.waitUntilTrue(() => zkClient.getControllerId.isDefined, "failed to elect a controller") + val controller2 = getController().kafkaController + assertEquals(emptyTopicId, controller2.controllerContext.topicIds.get("t")) + } + + @Test + def testTopicIdPersistsThroughControllerRestart(): Unit = { + servers = makeServers(1) + val controllerId = TestUtils.waitUntilControllerElected(zkClient) + val controller = getController().kafkaController + val tp = new TopicPartition("t", 0) + val assignment = Map(tp.partition -> Seq(controllerId)) + TestUtils.createTopic(zkClient, tp.topic, partitionReplicaAssignment = assignment, servers = servers) + waitForPartitionState(tp, firstControllerEpoch, controllerId, LeaderAndIsr.initialLeaderEpoch, + "failed to get expected partition state upon topic creation") + val topicId = controller.controllerContext.topicIds.get("t").get + + servers(controllerId).shutdown() + servers(controllerId).awaitShutdown() + servers(controllerId).startup() + TestUtils.waitUntilTrue(() => zkClient.getControllerId.isDefined, "failed to elect a controller") + val controller2 = getController().kafkaController + assertEquals(topicId, controller2.controllerContext.topicIds.get("t").get) + } + + @Test + def testTopicIdCreatedOnUpgrade(): Unit = { + servers = makeServers(1, interBrokerProtocolVersion = Some(KAFKA_2_7_IV0)) + val controllerId = TestUtils.waitUntilControllerElected(zkClient) + val controller = getController().kafkaController + val tp = new TopicPartition("t", 0) + val assignment = Map(tp.partition -> Seq(controllerId)) + TestUtils.createTopic(zkClient, tp.topic, partitionReplicaAssignment = assignment, servers = servers) + waitForPartitionState(tp, firstControllerEpoch, controllerId, LeaderAndIsr.initialLeaderEpoch, + "failed to get expected partition state upon topic creation") + val topicIdAfterCreate = zkClient.getTopicIdsForTopics(Set(tp.topic())).get(tp.topic()) + assertEquals(None, topicIdAfterCreate) + val emptyTopicId = controller.controllerContext.topicIds.get("t") + assertEquals(None, emptyTopicId) + + servers(controllerId).shutdown() + 
servers(controllerId).awaitShutdown() + servers = makeServers(1) + TestUtils.waitUntilTrue(() => zkClient.getControllerId.isDefined, "failed to elect a controller") + val topicIdAfterUpgrade = zkClient.getTopicIdsForTopics(Set(tp.topic())).get(tp.topic()) + assertNotEquals(emptyTopicId, topicIdAfterUpgrade) + val controller2 = getController().kafkaController + assertNotEquals(emptyTopicId, controller2.controllerContext.topicIds.get("t")) + val topicId = controller2.controllerContext.topicIds.get("t").get + assertEquals(topicIdAfterUpgrade.get, topicId) + assertEquals("t", controller2.controllerContext.topicNames(topicId)) + + TestUtils.waitUntilTrue(() => servers(0).logManager.getLog(tp).isDefined, "log was not created") + + val topicIdInLog = servers(0).logManager.getLog(tp).get.topicId + assertEquals(Some(topicId), topicIdInLog) + + adminZkClient.deleteTopic(tp.topic) + TestUtils.waitUntilTrue(() => !servers.head.kafkaController.controllerContext.allTopics.contains(tp.topic), + "topic should have been removed from controller context after deletion") + } + + @Test + def testTopicIdCreatedOnUpgradeMultiBrokerScenario(): Unit = { + // Simulate an upgrade scenario where the controller is still on a pre-topic ID IBP, but the other two brokers are upgraded. + servers = makeServers(1, interBrokerProtocolVersion = Some(KAFKA_2_7_IV0)) + servers = servers ++ makeServers(3, startingIdNumber = 1) + val originalControllerId = TestUtils.waitUntilControllerElected(zkClient) + assertEquals(0, originalControllerId) + val controller = getController().kafkaController + assertEquals(KAFKA_2_7_IV0, servers(originalControllerId).config.interBrokerProtocolVersion) + val remainingBrokers = servers.filter(_.config.brokerId != originalControllerId) + val tp = new TopicPartition("t", 0) + // Only the remaining brokers will have the replicas for the partition + val assignment = Map(tp.partition -> remainingBrokers.map(_.config.brokerId)) + TestUtils.createTopic(zkClient, tp.topic, partitionReplicaAssignment = assignment, servers = servers) + waitForPartitionState(tp, firstControllerEpoch, remainingBrokers(0).config.brokerId, LeaderAndIsr.initialLeaderEpoch, + "failed to get expected partition state upon topic creation") + val topicIdAfterCreate = zkClient.getTopicIdsForTopics(Set(tp.topic())).get(tp.topic()) + assertEquals(None, topicIdAfterCreate) + val emptyTopicId = controller.controllerContext.topicIds.get("t") + assertEquals(None, emptyTopicId) + + // All partition logs should not have topic IDs + remainingBrokers.foreach { server => + TestUtils.waitUntilTrue(() => server.logManager.getLog(tp).isDefined, "log was not created for server" + server.config.brokerId) + val topicIdInLog = server.logManager.getLog(tp).get.topicId + assertEquals(None, topicIdInLog) + } + + // Shut down the controller to transfer the controller to a new IBP broker. 
+ servers(originalControllerId).shutdown() + servers(originalControllerId).awaitShutdown() + // If we were upgrading, this server would be the latest IBP, but it doesn't matter in this test scenario + servers(originalControllerId).startup() + TestUtils.waitUntilTrue(() => zkClient.getControllerId.isDefined, "failed to elect a controller") + val topicIdAfterUpgrade = zkClient.getTopicIdsForTopics(Set(tp.topic())).get(tp.topic()) + assertNotEquals(emptyTopicId, topicIdAfterUpgrade) + val controller2 = getController().kafkaController + assertNotEquals(emptyTopicId, controller2.controllerContext.topicIds.get("t")) + val topicId = controller2.controllerContext.topicIds.get("t").get + assertEquals(topicIdAfterUpgrade.get, topicId) + assertEquals("t", controller2.controllerContext.topicNames(topicId)) + + // All partition logs should have topic IDs + remainingBrokers.foreach { server => + TestUtils.waitUntilTrue(() => server.logManager.getLog(tp).isDefined, "log was not created for server" + server.config.brokerId) + val topicIdInLog = server.logManager.getLog(tp).get.topicId + assertEquals(Some(topicId), topicIdInLog, + s"Server ${server.config.brokerId} had topic ID $topicIdInLog instead of ${Some(topicId)} as expected.") + } + + adminZkClient.deleteTopic(tp.topic) + TestUtils.waitUntilTrue(() => !servers.head.kafkaController.controllerContext.allTopics.contains(tp.topic), + "topic should have been removed from controller context after deletion") + } + + @Test + def testTopicIdUpgradeAfterReassigningPartitions(): Unit = { + val tp = new TopicPartition("t", 0) + val reassignment = Map(tp -> Some(Seq(0))) + val adminZkClient = new AdminZkClient(zkClient) + + // start server with old IBP + servers = makeServers(1, interBrokerProtocolVersion = Some(KAFKA_2_7_IV0)) + // use create topic with ZK client directly, without topic ID + adminZkClient.createTopic(tp.topic, 1, 1) + waitForPartitionState(tp, firstControllerEpoch, 0, LeaderAndIsr.initialLeaderEpoch, + "failed to get expected partition state upon topic creation") + val topicIdAfterCreate = zkClient.getTopicIdsForTopics(Set(tp.topic())).get(tp.topic()) + val id = servers.head.kafkaController.controllerContext.topicIds.get(tp.topic) + assertTrue(topicIdAfterCreate.isEmpty) + assertEquals(topicIdAfterCreate, id, + "expected no topic ID, but one existed") + + // Upgrade to IBP 2.8 + servers(0).shutdown() + servers(0).awaitShutdown() + servers = makeServers(1) + waitForPartitionState(tp, firstControllerEpoch, 0, LeaderAndIsr.initialLeaderEpoch, + "failed to get expected partition state upon controller restart") + val topicIdAfterUpgrade = zkClient.getTopicIdsForTopics(Set(tp.topic())).get(tp.topic()) + assertEquals(topicIdAfterUpgrade, servers.head.kafkaController.controllerContext.topicIds.get(tp.topic), + "expected same topic ID but it can not be found") + assertEquals(tp.topic(), servers.head.kafkaController.controllerContext.topicNames(topicIdAfterUpgrade.get), + "correct topic name expected but cannot be found in the controller context") + + // Downgrade back to 2.7 + servers(0).shutdown() + servers(0).awaitShutdown() + servers = makeServers(1, interBrokerProtocolVersion = Some(KAFKA_2_7_IV0)) + waitForPartitionState(tp, firstControllerEpoch, 0, LeaderAndIsr.initialLeaderEpoch, + "failed to get expected partition state upon topic creation") + val topicIdAfterDowngrade = zkClient.getTopicIdsForTopics(Set(tp.topic())).get(tp.topic()) + assertTrue(topicIdAfterDowngrade.isDefined) + assertEquals(topicIdAfterUpgrade, topicIdAfterDowngrade, + 
"expected same topic ID but it can not be found after downgrade") + assertEquals(topicIdAfterDowngrade, servers.head.kafkaController.controllerContext.topicIds.get(tp.topic), + "expected same topic ID in controller context but it is no longer found after downgrade") + assertEquals(tp.topic(), servers.head.kafkaController.controllerContext.topicNames(topicIdAfterUpgrade.get), + "correct topic name expected but cannot be found in the controller context") + + // Reassign partitions + servers(0).kafkaController.eventManager.put(ApiPartitionReassignment(reassignment, _ => ())) + waitForPartitionState(tp, 3, 0, 1, + "failed to get expected partition state upon controller restart") + val topicIdAfterReassignment = zkClient.getTopicIdsForTopics(Set(tp.topic())).get(tp.topic()) + assertTrue(topicIdAfterReassignment.isDefined) + assertEquals(topicIdAfterUpgrade, topicIdAfterReassignment, + "expected same topic ID but it can not be found after reassignment") + assertEquals(topicIdAfterUpgrade, servers.head.kafkaController.controllerContext.topicIds.get(tp.topic), + "expected same topic ID in controller context but is no longer found after reassignment") + assertEquals(tp.topic(), servers.head.kafkaController.controllerContext.topicNames(topicIdAfterUpgrade.get), + "correct topic name expected but cannot be found in the controller context") + + // Upgrade back to 2.8 + servers(0).shutdown() + servers(0).awaitShutdown() + servers = makeServers(1) + waitForPartitionState(tp, 3, 0, 1, + "failed to get expected partition state upon controller restart") + val topicIdAfterReUpgrade = zkClient.getTopicIdsForTopics(Set(tp.topic())).get(tp.topic()) + assertEquals(topicIdAfterUpgrade, topicIdAfterReUpgrade, + "expected same topic ID but it can not be found after re-upgrade") + assertEquals(topicIdAfterReUpgrade, servers.head.kafkaController.controllerContext.topicIds.get(tp.topic), + "topic ID can not be found in controller context after re-upgrading IBP") + assertEquals(tp.topic(), servers.head.kafkaController.controllerContext.topicNames(topicIdAfterReUpgrade.get), + "correct topic name expected but cannot be found in the controller context") + + adminZkClient.deleteTopic(tp.topic) + TestUtils.waitUntilTrue(() => servers.head.kafkaController.controllerContext.topicIds.get(tp.topic).isEmpty, + "topic ID for topic should have been removed from controller context after deletion") + assertTrue(servers.head.kafkaController.controllerContext.topicNames.get(topicIdAfterUpgrade.get).isEmpty) + } + + private def testControllerMove(fun: () => Unit): Unit = { + val controller = getController().kafkaController + val appender = LogCaptureAppender.createAndRegister() + val previousLevel = LogCaptureAppender.setClassLoggerLevel(controller.getClass, Level.INFO) + + try { + TestUtils.waitUntilTrue(() => { + controller.eventManager.state == ControllerState.Idle + }, "Controller event thread is still busy") + + val latch = new CountDownLatch(1) + + // Let the controller event thread await on a latch before the pre-defined logic is triggered. + // This is used to make sure that when the event thread resumes and starts processing events, the controller has already moved. + controller.eventManager.put(new MockEvent(ControllerState.TopicChange) { + override def process(): Unit = latch.await() + override def preempt(): Unit = {} + }) + + // Execute pre-defined logic. This can be topic creation/deletion, preferred leader election, etc. 
+ fun() + + // Delete the controller path, re-create /controller znode to emulate controller movement + zkClient.deleteController(controller.controllerContext.epochZkVersion) + zkClient.registerControllerAndIncrementControllerEpoch(servers.size) + + // Resume the controller event thread. At this point, the controller should see mismatch controller epoch zkVersion and resign + latch.countDown() + TestUtils.waitUntilTrue(() => !controller.isActive, "Controller fails to resign") + + // Expect to capture the ControllerMovedException in the log of ControllerEventThread + val event = appender.getMessages.find(e => e.getLevel == Level.INFO + && e.getThrowableInformation != null + && e.getThrowableInformation.getThrowable.getClass.getName.equals(classOf[ControllerMovedException].getName)) + assertTrue(event.isDefined) + + } finally { + LogCaptureAppender.unregister(appender) + LogCaptureAppender.setClassLoggerLevel(controller.eventManager.thread.getClass, previousLevel) + } + } + + private def preferredReplicaLeaderElection(controllerId: Int, otherBroker: KafkaServer, tp: TopicPartition, + replicas: Set[Int], leaderEpoch: Int): Unit = { + otherBroker.shutdown() + otherBroker.awaitShutdown() + waitForPartitionState(tp, firstControllerEpoch, controllerId, leaderEpoch + 1, + "failed to get expected partition state upon broker shutdown") + otherBroker.startup() + TestUtils.waitUntilTrue(() => zkClient.getInSyncReplicasForPartition(new TopicPartition(tp.topic, tp.partition)).get.toSet == replicas, "restarted broker failed to join in-sync replicas") + zkClient.createPreferredReplicaElection(Set(tp)) + TestUtils.waitUntilTrue(() => !zkClient.pathExists(PreferredReplicaElectionZNode.path), + "failed to remove preferred replica leader election path after completion") + waitForPartitionState(tp, firstControllerEpoch, otherBroker.config.brokerId, leaderEpoch + 2, + "failed to get expected partition state upon broker startup") + } + + private def waitUntilControllerEpoch(epoch: Int, message: String): Unit = { + TestUtils.waitUntilTrue(() => zkClient.getControllerEpoch.map(_._1).contains(epoch) , message) + } + + private def waitForPartitionState(tp: TopicPartition, + controllerEpoch: Int, + leader: Int, + leaderEpoch: Int, + message: String): Unit = { + TestUtils.waitUntilTrue(() => { + val leaderIsrAndControllerEpochMap = zkClient.getTopicPartitionStates(Seq(tp)) + leaderIsrAndControllerEpochMap.contains(tp) && + isExpectedPartitionState(leaderIsrAndControllerEpochMap(tp), controllerEpoch, leader, leaderEpoch) + }, message) + } + + private def isExpectedPartitionState(leaderIsrAndControllerEpoch: LeaderIsrAndControllerEpoch, + controllerEpoch: Int, + leader: Int, + leaderEpoch: Int) = + leaderIsrAndControllerEpoch.controllerEpoch == controllerEpoch && + leaderIsrAndControllerEpoch.leaderAndIsr.leader == leader && + leaderIsrAndControllerEpoch.leaderAndIsr.leaderEpoch == leaderEpoch + + private def makeServers(numConfigs: Int, + autoLeaderRebalanceEnable: Boolean = false, + uncleanLeaderElectionEnable: Boolean = false, + enableControlledShutdown: Boolean = true, + listeners : Option[String] = None, + listenerSecurityProtocolMap : Option[String] = None, + controlPlaneListenerName : Option[String] = None, + interBrokerProtocolVersion: Option[ApiVersion] = None, + logDirCount: Int = 1, + startingIdNumber: Int = 0) = { + val configs = TestUtils.createBrokerConfigs(numConfigs, zkConnect, enableControlledShutdown = enableControlledShutdown, logDirCount = logDirCount, startingIdNumber = startingIdNumber) + 
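+ // Apply the per-test overrides (auto leader rebalance, unclean leader election, imbalance check interval, optional listeners and inter-broker protocol version) to every generated broker config before starting the servers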
configs.foreach { config => + config.setProperty(KafkaConfig.AutoLeaderRebalanceEnableProp, autoLeaderRebalanceEnable.toString) + config.setProperty(KafkaConfig.UncleanLeaderElectionEnableProp, uncleanLeaderElectionEnable.toString) + config.setProperty(KafkaConfig.LeaderImbalanceCheckIntervalSecondsProp, "1") + listeners.foreach(listener => config.setProperty(KafkaConfig.ListenersProp, listener)) + listenerSecurityProtocolMap.foreach(listenerMap => config.setProperty(KafkaConfig.ListenerSecurityProtocolMapProp, listenerMap)) + controlPlaneListenerName.foreach(controlPlaneListener => config.setProperty(KafkaConfig.ControlPlaneListenerNameProp, controlPlaneListener)) + interBrokerProtocolVersion.foreach(ibp => config.setProperty(KafkaConfig.InterBrokerProtocolVersionProp, ibp.toString)) + } + configs.map(config => TestUtils.createServer(KafkaConfig.fromProps(config))) + } + + private def timer(metricName: String): Timer = { + KafkaYammerMetrics.defaultRegistry.allMetrics.asScala.filter { case (k, _) => + k.getMBeanName == metricName + }.values.headOption.getOrElse(throw new AssertionError(s"Unable to find metric $metricName")).asInstanceOf[Timer] + } + + private def getController(): KafkaServer = { + val controllerId = TestUtils.waitUntilControllerElected(zkClient) + servers.filter(s => s.config.brokerId == controllerId).head + } + +} diff --git a/core/src/test/scala/unit/kafka/controller/MockPartitionStateMachine.scala b/core/src/test/scala/unit/kafka/controller/MockPartitionStateMachine.scala new file mode 100644 index 0000000..b9a4d04 --- /dev/null +++ b/core/src/test/scala/unit/kafka/controller/MockPartitionStateMachine.scala @@ -0,0 +1,131 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package kafka.controller + +import kafka.api.LeaderAndIsr +import kafka.common.StateChangeFailedException +import kafka.controller.Election._ +import org.apache.kafka.common.TopicPartition + +import scala.collection.{Seq, mutable} + +class MockPartitionStateMachine(controllerContext: ControllerContext, + uncleanLeaderElectionEnabled: Boolean) + extends PartitionStateMachine(controllerContext) { + + var stateChangesByTargetState = mutable.Map.empty[PartitionState, Int].withDefaultValue(0) + + def stateChangesCalls(targetState: PartitionState): Int = { + stateChangesByTargetState(targetState) + } + + def clear(): Unit = { + stateChangesByTargetState.clear() + } + + override def handleStateChanges( + partitions: Seq[TopicPartition], + targetState: PartitionState, + leaderElectionStrategy: Option[PartitionLeaderElectionStrategy] + ): Map[TopicPartition, Either[Throwable, LeaderAndIsr]] = { + stateChangesByTargetState(targetState) = stateChangesByTargetState(targetState) + 1 + + partitions.foreach(partition => controllerContext.putPartitionStateIfNotExists(partition, NonExistentPartition)) + val (validPartitions, invalidPartitions) = controllerContext.checkValidPartitionStateChange(partitions, targetState) + if (invalidPartitions.nonEmpty) { + val currentStates = invalidPartitions.map(p => controllerContext.partitionStates.get(p)) + throw new IllegalStateException(s"Invalid state transition to $targetState for partitions $currentStates") + } + + if (targetState == OnlinePartition) { + val uninitializedPartitions = validPartitions.filter(partition => controllerContext.partitionState(partition) == NewPartition) + val partitionsToElectLeader = partitions.filter { partition => + val currentState = controllerContext.partitionState(partition) + currentState == OfflinePartition || currentState == OnlinePartition + } + + uninitializedPartitions.foreach { partition => + controllerContext.putPartitionState(partition, targetState) + } + + val electionResults = doLeaderElections(partitionsToElectLeader, leaderElectionStrategy.get) + electionResults.foreach { + case (partition, Right(_)) => controllerContext.putPartitionState(partition, targetState) + case (_, Left(_)) => // Ignore; No need to update the context if the election failed + } + + electionResults + } else { + validPartitions.foreach { partition => + controllerContext.putPartitionState(partition, targetState) + } + Map.empty + } + } + + private def doLeaderElections( + partitions: Seq[TopicPartition], + leaderElectionStrategy: PartitionLeaderElectionStrategy + ): Map[TopicPartition, Either[Throwable, LeaderAndIsr]] = { + val failedElections = mutable.Map.empty[TopicPartition, Either[Throwable, LeaderAndIsr]] + val validLeaderAndIsrs = mutable.Buffer.empty[(TopicPartition, LeaderAndIsr)] + + for (partition <- partitions) { + val leaderIsrAndControllerEpoch = controllerContext.partitionLeadershipInfo(partition).get + if (leaderIsrAndControllerEpoch.controllerEpoch > controllerContext.epoch) { + val failMsg = s"Aborted leader election for partition $partition since the LeaderAndIsr path was " + + s"already written by another controller. This probably means that the current controller went through " + + s"a soft failure and another controller was elected with epoch ${leaderIsrAndControllerEpoch.controllerEpoch}." 
+ failedElections.put(partition, Left(new StateChangeFailedException(failMsg))) + } else { + validLeaderAndIsrs.append((partition, leaderIsrAndControllerEpoch.leaderAndIsr)) + } + } + + val electionResults = leaderElectionStrategy match { + case OfflinePartitionLeaderElectionStrategy(isUnclean) => + val partitionsWithUncleanLeaderElectionState = validLeaderAndIsrs.map { case (partition, leaderAndIsr) => + (partition, Some(leaderAndIsr), isUnclean || uncleanLeaderElectionEnabled) + } + leaderForOffline(controllerContext, partitionsWithUncleanLeaderElectionState) + case ReassignPartitionLeaderElectionStrategy => + leaderForReassign(controllerContext, validLeaderAndIsrs) + case PreferredReplicaPartitionLeaderElectionStrategy => + leaderForPreferredReplica(controllerContext, validLeaderAndIsrs) + case ControlledShutdownPartitionLeaderElectionStrategy => + leaderForControlledShutdown(controllerContext, validLeaderAndIsrs) + } + + val results: Map[TopicPartition, Either[Exception, LeaderAndIsr]] = electionResults.map { electionResult => + val partition = electionResult.topicPartition + val value = electionResult.leaderAndIsr match { + case None => + val failMsg = s"Failed to elect leader for partition $partition under strategy $leaderElectionStrategy" + Left(new StateChangeFailedException(failMsg)) + case Some(leaderAndIsr) => + val leaderIsrAndControllerEpoch = LeaderIsrAndControllerEpoch(leaderAndIsr, controllerContext.epoch) + controllerContext.putPartitionLeadershipInfo(partition, leaderIsrAndControllerEpoch) + Right(leaderAndIsr) + } + + partition -> value + }.toMap + + results ++ failedElections + } + +} diff --git a/core/src/test/scala/unit/kafka/controller/MockReplicaStateMachine.scala b/core/src/test/scala/unit/kafka/controller/MockReplicaStateMachine.scala new file mode 100644 index 0000000..32bfc50 --- /dev/null +++ b/core/src/test/scala/unit/kafka/controller/MockReplicaStateMachine.scala @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package kafka.controller + +import scala.collection.Seq +import scala.collection.mutable + +class MockReplicaStateMachine(controllerContext: ControllerContext) extends ReplicaStateMachine(controllerContext) { + val stateChangesByTargetState = mutable.Map.empty[ReplicaState, Int].withDefaultValue(0) + + def stateChangesCalls(targetState: ReplicaState): Int = { + stateChangesByTargetState(targetState) + } + + def clear(): Unit = { + stateChangesByTargetState.clear() + } + + override def handleStateChanges(replicas: Seq[PartitionAndReplica], targetState: ReplicaState): Unit = { + stateChangesByTargetState(targetState) = stateChangesByTargetState(targetState) + 1 + + replicas.foreach(replica => controllerContext.putReplicaStateIfNotExists(replica, NonExistentReplica)) + val (validReplicas, invalidReplicas) = controllerContext.checkValidReplicaStateChange(replicas, targetState) + if (invalidReplicas.nonEmpty) { + val currentStates = invalidReplicas.map(replica => replica -> controllerContext.replicaStates.get(replica)).toMap + throw new IllegalStateException(s"Invalid state transition to $targetState for replicas $currentStates") + } + validReplicas.foreach { replica => + if (targetState == NonExistentReplica) + controllerContext.removeReplicaState(replica) + else + controllerContext.putReplicaState(replica, targetState) + } + } + +} diff --git a/core/src/test/scala/unit/kafka/controller/PartitionLeaderElectionAlgorithmsTest.scala b/core/src/test/scala/unit/kafka/controller/PartitionLeaderElectionAlgorithmsTest.scala new file mode 100644 index 0000000..4f3aec0 --- /dev/null +++ b/core/src/test/scala/unit/kafka/controller/PartitionLeaderElectionAlgorithmsTest.scala @@ -0,0 +1,187 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package kafka.controller + +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.{BeforeEach, Test} + +class PartitionLeaderElectionAlgorithmsTest { + private var controllerContext: ControllerContext = null + + @BeforeEach + def setUp(): Unit = { + controllerContext = new ControllerContext + controllerContext.stats.removeMetric("UncleanLeaderElectionsPerSec") + } + + @Test + def testOfflinePartitionLeaderElection(): Unit = { + val assignment = Seq(2, 4) + val isr = Seq(2, 4) + val liveReplicas = Set(4) + val leaderOpt = PartitionLeaderElectionAlgorithms.offlinePartitionLeaderElection(assignment, + isr, + liveReplicas, + uncleanLeaderElectionEnabled = false, + controllerContext) + assertEquals(Option(4), leaderOpt) + } + + @Test + def testOfflinePartitionLeaderElectionLastIsrOfflineUncleanLeaderElectionDisabled(): Unit = { + val assignment = Seq(2, 4) + val isr = Seq(2) + val liveReplicas = Set(4) + val leaderOpt = PartitionLeaderElectionAlgorithms.offlinePartitionLeaderElection(assignment, + isr, + liveReplicas, + uncleanLeaderElectionEnabled = false, + controllerContext) + assertEquals(None, leaderOpt) + assertEquals(0, controllerContext.stats.uncleanLeaderElectionRate.count()) + } + + @Test + def testOfflinePartitionLeaderElectionLastIsrOfflineUncleanLeaderElectionEnabled(): Unit = { + val assignment = Seq(2, 4) + val isr = Seq(2) + val liveReplicas = Set(4) + val leaderOpt = PartitionLeaderElectionAlgorithms.offlinePartitionLeaderElection(assignment, + isr, + liveReplicas, + uncleanLeaderElectionEnabled = true, + controllerContext) + assertEquals(Option(4), leaderOpt) + assertEquals(1, controllerContext.stats.uncleanLeaderElectionRate.count()) + } + + @Test + def testReassignPartitionLeaderElection(): Unit = { + val reassignment = Seq(2, 4) + val isr = Seq(2, 4) + val liveReplicas = Set(4) + val leaderOpt = PartitionLeaderElectionAlgorithms.reassignPartitionLeaderElection(reassignment, + isr, + liveReplicas) + assertEquals(Option(4), leaderOpt) + } + + @Test + def testReassignPartitionLeaderElectionWithNoLiveIsr(): Unit = { + val reassignment = Seq(2, 4) + val isr = Seq(2) + val liveReplicas = Set.empty[Int] + val leaderOpt = PartitionLeaderElectionAlgorithms.reassignPartitionLeaderElection(reassignment, + isr, + liveReplicas) + assertEquals(None, leaderOpt) + } + + @Test + def testReassignPartitionLeaderElectionWithEmptyIsr(): Unit = { + val reassignment = Seq(2, 4) + val isr = Seq.empty[Int] + val liveReplicas = Set(2) + val leaderOpt = PartitionLeaderElectionAlgorithms.reassignPartitionLeaderElection(reassignment, + isr, + liveReplicas) + assertEquals(None, leaderOpt) + } + + @Test + def testPreferredReplicaPartitionLeaderElection(): Unit = { + val assignment = Seq(2, 4) + val isr = Seq(2, 4) + val liveReplicas = Set(2, 4) + val leaderOpt = PartitionLeaderElectionAlgorithms.preferredReplicaPartitionLeaderElection(assignment, + isr, + liveReplicas) + assertEquals(Option(2), leaderOpt) + } + + @Test + def testPreferredReplicaPartitionLeaderElectionPreferredReplicaInIsrNotLive(): Unit = { + val assignment = Seq(2, 4) + val isr = Seq(2) + val liveReplicas = Set.empty[Int] + val leaderOpt = PartitionLeaderElectionAlgorithms.preferredReplicaPartitionLeaderElection(assignment, + isr, + liveReplicas) + assertEquals(None, leaderOpt) + } + + @Test + def testPreferredReplicaPartitionLeaderElectionPreferredReplicaNotInIsrLive(): Unit = { + val assignment = Seq(2, 4) + val isr = Seq(4) + val liveReplicas = Set(2, 4) + val leaderOpt = 
PartitionLeaderElectionAlgorithms.preferredReplicaPartitionLeaderElection(assignment, + isr, + liveReplicas) + assertEquals(None, leaderOpt) + } + + @Test + def testPreferredReplicaPartitionLeaderElectionPreferredReplicaNotInIsrNotLive(): Unit = { + val assignment = Seq(2, 4) + val isr = Seq.empty[Int] + val liveReplicas = Set.empty[Int] + val leaderOpt = PartitionLeaderElectionAlgorithms.preferredReplicaPartitionLeaderElection(assignment, + isr, + liveReplicas) + assertEquals(None, leaderOpt) + } + + @Test + def testControlledShutdownPartitionLeaderElection(): Unit = { + val assignment = Seq(2, 4) + val isr = Seq(2, 4) + val liveReplicas = Set(2, 4) + val shuttingDownBrokers = Set(2) + val leaderOpt = PartitionLeaderElectionAlgorithms.controlledShutdownPartitionLeaderElection(assignment, + isr, + liveReplicas, + shuttingDownBrokers) + assertEquals(Option(4), leaderOpt) + } + + @Test + def testControlledShutdownPartitionLeaderElectionLastIsrShuttingDown(): Unit = { + val assignment = Seq(2, 4) + val isr = Seq(2) + val liveReplicas = Set(2, 4) + val shuttingDownBrokers = Set(2) + val leaderOpt = PartitionLeaderElectionAlgorithms.controlledShutdownPartitionLeaderElection(assignment, + isr, + liveReplicas, + shuttingDownBrokers) + assertEquals(None, leaderOpt) + } + + @Test + def testControlledShutdownPartitionLeaderElectionAllIsrSimultaneouslyShutdown(): Unit = { + val assignment = Seq(2, 4) + val isr = Seq(2, 4) + val liveReplicas = Set(2, 4) + val shuttingDownBrokers = Set(2, 4) + val leaderOpt = PartitionLeaderElectionAlgorithms.controlledShutdownPartitionLeaderElection(assignment, + isr, + liveReplicas, + shuttingDownBrokers) + assertEquals(None, leaderOpt) + } +} diff --git a/core/src/test/scala/unit/kafka/controller/PartitionStateMachineTest.scala b/core/src/test/scala/unit/kafka/controller/PartitionStateMachineTest.scala new file mode 100644 index 0000000..2119506 --- /dev/null +++ b/core/src/test/scala/unit/kafka/controller/PartitionStateMachineTest.scala @@ -0,0 +1,533 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package kafka.controller + +import kafka.api.LeaderAndIsr +import kafka.log.LogConfig +import kafka.server.KafkaConfig +import kafka.utils.TestUtils +import kafka.zk.KafkaZkClient.UpdateLeaderAndIsrResult +import kafka.zk.{KafkaZkClient, TopicPartitionStateZNode} +import kafka.zookeeper._ +import org.apache.kafka.common.TopicPartition +import org.apache.zookeeper.KeeperException.Code +import org.apache.zookeeper.data.Stat +import org.easymock.EasyMock +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.{BeforeEach, Test} +import org.mockito.Mockito + +class PartitionStateMachineTest { + private var controllerContext: ControllerContext = null + private var mockZkClient: KafkaZkClient = null + private var mockControllerBrokerRequestBatch: ControllerBrokerRequestBatch = null + private var partitionStateMachine: PartitionStateMachine = null + + private val brokerId = 5 + private val config = KafkaConfig.fromProps(TestUtils.createBrokerConfig(brokerId, "zkConnect")) + private val controllerEpoch = 50 + private val partition = new TopicPartition("t", 0) + private val partitions = Seq(partition) + + @BeforeEach + def setUp(): Unit = { + controllerContext = new ControllerContext + controllerContext.epoch = controllerEpoch + mockZkClient = EasyMock.createMock(classOf[KafkaZkClient]) + mockControllerBrokerRequestBatch = EasyMock.createMock(classOf[ControllerBrokerRequestBatch]) + partitionStateMachine = new ZkPartitionStateMachine(config, new StateChangeLogger(brokerId, true, None), controllerContext, + mockZkClient, mockControllerBrokerRequestBatch) + } + + private def partitionState(partition: TopicPartition): PartitionState = { + controllerContext.partitionState(partition) + } + + @Test + def testNonexistentPartitionToNewPartitionTransition(): Unit = { + partitionStateMachine.handleStateChanges(partitions, NewPartition) + assertEquals(NewPartition, partitionState(partition)) + } + + @Test + def testInvalidNonexistentPartitionToOnlinePartitionTransition(): Unit = { + partitionStateMachine.handleStateChanges( + partitions, + OnlinePartition, + Option(OfflinePartitionLeaderElectionStrategy(false)) + ) + assertEquals(NonExistentPartition, partitionState(partition)) + } + + @Test + def testInvalidNonexistentPartitionToOfflinePartitionTransition(): Unit = { + partitionStateMachine.handleStateChanges(partitions, OfflinePartition) + assertEquals(NonExistentPartition, partitionState(partition)) + } + + @Test + def testNewPartitionToOnlinePartitionTransition(): Unit = { + controllerContext.setLiveBrokers(Map(TestUtils.createBrokerAndEpoch(brokerId, "host", 0))) + controllerContext.updatePartitionFullReplicaAssignment(partition, ReplicaAssignment(Seq(brokerId))) + controllerContext.putPartitionState(partition, NewPartition) + val leaderIsrAndControllerEpoch = LeaderIsrAndControllerEpoch(LeaderAndIsr(brokerId, List(brokerId)), controllerEpoch) + EasyMock.expect(mockControllerBrokerRequestBatch.newBatch()) + EasyMock.expect(mockZkClient.createTopicPartitionStatesRaw(Map(partition -> leaderIsrAndControllerEpoch), controllerContext.epochZkVersion)) + .andReturn(Seq(CreateResponse(Code.OK, null, Some(partition), null, ResponseMetadata(0, 0)))) + EasyMock.expect(mockControllerBrokerRequestBatch.addLeaderAndIsrRequestForBrokers(Seq(brokerId), + partition, leaderIsrAndControllerEpoch, replicaAssignment(Seq(brokerId)), isNew = true)) + EasyMock.expect(mockControllerBrokerRequestBatch.sendRequestsToBrokers(controllerEpoch)) + EasyMock.replay(mockZkClient, mockControllerBrokerRequestBatch) + 
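+ // Drive the NewPartition -> OnlinePartition transition; the mocked ZK client and request batch verify that the partition state is created in ZooKeeper and a LeaderAndIsr request is queued for the new leader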
partitionStateMachine.handleStateChanges( + partitions, + OnlinePartition, + Option(OfflinePartitionLeaderElectionStrategy(false)) + ) + EasyMock.verify(mockZkClient, mockControllerBrokerRequestBatch) + assertEquals(OnlinePartition, partitionState(partition)) + } + + @Test + def testNewPartitionToOnlinePartitionTransitionZooKeeperClientExceptionFromCreateStates(): Unit = { + controllerContext.setLiveBrokers(Map(TestUtils.createBrokerAndEpoch(brokerId, "host", 0))) + controllerContext.updatePartitionFullReplicaAssignment(partition, ReplicaAssignment(Seq(brokerId))) + controllerContext.putPartitionState(partition, NewPartition) + val leaderIsrAndControllerEpoch = LeaderIsrAndControllerEpoch(LeaderAndIsr(brokerId, List(brokerId)), controllerEpoch) + EasyMock.expect(mockControllerBrokerRequestBatch.newBatch()) + EasyMock.expect(mockZkClient.createTopicPartitionStatesRaw(Map(partition -> leaderIsrAndControllerEpoch), controllerContext.epochZkVersion)) + .andThrow(new ZooKeeperClientException("test")) + EasyMock.expect(mockControllerBrokerRequestBatch.sendRequestsToBrokers(controllerEpoch)) + EasyMock.replay(mockZkClient, mockControllerBrokerRequestBatch) + partitionStateMachine.handleStateChanges( + partitions, + OnlinePartition, + Option(OfflinePartitionLeaderElectionStrategy(false)) + ) + EasyMock.verify(mockZkClient, mockControllerBrokerRequestBatch) + assertEquals(NewPartition, partitionState(partition)) + } + + @Test + def testNewPartitionToOnlinePartitionTransitionErrorCodeFromCreateStates(): Unit = { + controllerContext.setLiveBrokers(Map(TestUtils.createBrokerAndEpoch(brokerId, "host", 0))) + controllerContext.updatePartitionFullReplicaAssignment(partition, ReplicaAssignment(Seq(brokerId))) + controllerContext.putPartitionState(partition, NewPartition) + val leaderIsrAndControllerEpoch = LeaderIsrAndControllerEpoch(LeaderAndIsr(brokerId, List(brokerId)), controllerEpoch) + EasyMock.expect(mockControllerBrokerRequestBatch.newBatch()) + EasyMock.expect(mockZkClient.createTopicPartitionStatesRaw(Map(partition -> leaderIsrAndControllerEpoch), controllerContext.epochZkVersion)) + .andReturn(Seq(CreateResponse(Code.NODEEXISTS, null, Some(partition), null, ResponseMetadata(0, 0)))) + EasyMock.expect(mockControllerBrokerRequestBatch.sendRequestsToBrokers(controllerEpoch)) + EasyMock.replay(mockZkClient, mockControllerBrokerRequestBatch) + partitionStateMachine.handleStateChanges( + partitions, + OnlinePartition, + Option(OfflinePartitionLeaderElectionStrategy(false)) + ) + EasyMock.verify(mockZkClient, mockControllerBrokerRequestBatch) + assertEquals(NewPartition, partitionState(partition)) + } + + @Test + def testNewPartitionToOfflinePartitionTransition(): Unit = { + controllerContext.putPartitionState(partition, NewPartition) + partitionStateMachine.handleStateChanges(partitions, OfflinePartition) + assertEquals(OfflinePartition, partitionState(partition)) + } + + @Test + def testInvalidNewPartitionToNonexistentPartitionTransition(): Unit = { + controllerContext.putPartitionState(partition, NewPartition) + partitionStateMachine.handleStateChanges(partitions, NonExistentPartition) + assertEquals(NewPartition, partitionState(partition)) + } + + @Test + def testOnlinePartitionToOnlineTransition(): Unit = { + controllerContext.setLiveBrokers(Map(TestUtils.createBrokerAndEpoch(brokerId, "host", 0))) + controllerContext.updatePartitionFullReplicaAssignment(partition, ReplicaAssignment(Seq(brokerId))) + controllerContext.putPartitionState(partition, OnlinePartition) + val leaderAndIsr = 
LeaderAndIsr(brokerId, List(brokerId)) + val leaderIsrAndControllerEpoch = LeaderIsrAndControllerEpoch(leaderAndIsr, controllerEpoch) + controllerContext.putPartitionLeadershipInfo(partition, leaderIsrAndControllerEpoch) + + val stat = new Stat(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) + EasyMock.expect(mockControllerBrokerRequestBatch.newBatch()) + EasyMock.expect(mockZkClient.getTopicPartitionStatesRaw(partitions)) + .andReturn(Seq(GetDataResponse(Code.OK, null, Some(partition), + TopicPartitionStateZNode.encode(leaderIsrAndControllerEpoch), stat, ResponseMetadata(0, 0)))) + + val leaderAndIsrAfterElection = leaderAndIsr.newLeader(brokerId) + val updatedLeaderAndIsr = leaderAndIsrAfterElection.withZkVersion(2) + EasyMock.expect(mockZkClient.updateLeaderAndIsr(Map(partition -> leaderAndIsrAfterElection), controllerEpoch, controllerContext.epochZkVersion)) + .andReturn(UpdateLeaderAndIsrResult(Map(partition -> Right(updatedLeaderAndIsr)), Seq.empty)) + EasyMock.expect(mockControllerBrokerRequestBatch.addLeaderAndIsrRequestForBrokers(Seq(brokerId), + partition, LeaderIsrAndControllerEpoch(updatedLeaderAndIsr, controllerEpoch), replicaAssignment(Seq(brokerId)), isNew = false)) + EasyMock.expect(mockControllerBrokerRequestBatch.sendRequestsToBrokers(controllerEpoch)) + EasyMock.replay(mockZkClient, mockControllerBrokerRequestBatch) + + partitionStateMachine.handleStateChanges(partitions, OnlinePartition, Option(PreferredReplicaPartitionLeaderElectionStrategy)) + EasyMock.verify(mockZkClient, mockControllerBrokerRequestBatch) + assertEquals(OnlinePartition, partitionState(partition)) + } + + @Test + def testOnlinePartitionToOnlineTransitionForControlledShutdown(): Unit = { + val otherBrokerId = brokerId + 1 + controllerContext.setLiveBrokers(Map( + TestUtils.createBrokerAndEpoch(brokerId, "host", 0), + TestUtils.createBrokerAndEpoch(otherBrokerId, "host", 0))) + controllerContext.shuttingDownBrokerIds.add(brokerId) + controllerContext.updatePartitionFullReplicaAssignment( + partition, + ReplicaAssignment(Seq(brokerId, otherBrokerId)) + ) + controllerContext.putPartitionState(partition, OnlinePartition) + val leaderAndIsr = LeaderAndIsr(brokerId, List(brokerId, otherBrokerId)) + val leaderIsrAndControllerEpoch = LeaderIsrAndControllerEpoch(leaderAndIsr, controllerEpoch) + controllerContext.putPartitionLeadershipInfo(partition, leaderIsrAndControllerEpoch) + + val stat = new Stat(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) + EasyMock.expect(mockControllerBrokerRequestBatch.newBatch()) + EasyMock.expect(mockZkClient.getTopicPartitionStatesRaw(partitions)) + .andReturn(Seq(GetDataResponse(Code.OK, null, Some(partition), + TopicPartitionStateZNode.encode(leaderIsrAndControllerEpoch), stat, ResponseMetadata(0, 0)))) + + val leaderAndIsrAfterElection = leaderAndIsr.newLeaderAndIsr(otherBrokerId, List(otherBrokerId)) + val updatedLeaderAndIsr = leaderAndIsrAfterElection.withZkVersion(2) + EasyMock.expect(mockZkClient.updateLeaderAndIsr(Map(partition -> leaderAndIsrAfterElection), controllerEpoch, controllerContext.epochZkVersion)) + .andReturn(UpdateLeaderAndIsrResult(Map(partition -> Right(updatedLeaderAndIsr)), Seq.empty)) + + // The leaderAndIsr request should be sent to both brokers, including the shutting down one + EasyMock.expect(mockControllerBrokerRequestBatch.addLeaderAndIsrRequestForBrokers(Seq(brokerId, otherBrokerId), + partition, LeaderIsrAndControllerEpoch(updatedLeaderAndIsr, controllerEpoch), replicaAssignment(Seq(brokerId, otherBrokerId)), + isNew = false)) + 
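// Flushing the batch sends the accumulated requests for this controller epoch +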
EasyMock.expect(mockControllerBrokerRequestBatch.sendRequestsToBrokers(controllerEpoch)) + EasyMock.replay(mockZkClient, mockControllerBrokerRequestBatch) + + partitionStateMachine.handleStateChanges(partitions, OnlinePartition, Option(ControlledShutdownPartitionLeaderElectionStrategy)) + EasyMock.verify(mockZkClient, mockControllerBrokerRequestBatch) + assertEquals(OnlinePartition, partitionState(partition)) + } + + @Test + def testOnlinePartitionToOfflineTransition(): Unit = { + controllerContext.putPartitionState(partition, OnlinePartition) + partitionStateMachine.handleStateChanges(partitions, OfflinePartition) + assertEquals(OfflinePartition, partitionState(partition)) + } + + @Test + def testInvalidOnlinePartitionToNonexistentPartitionTransition(): Unit = { + controllerContext.putPartitionState(partition, OnlinePartition) + partitionStateMachine.handleStateChanges(partitions, NonExistentPartition) + assertEquals(OnlinePartition, partitionState(partition)) + } + + @Test + def testInvalidOnlinePartitionToNewPartitionTransition(): Unit = { + controllerContext.putPartitionState(partition, OnlinePartition) + partitionStateMachine.handleStateChanges(partitions, NewPartition) + assertEquals(OnlinePartition, partitionState(partition)) + } + + @Test + def testOfflinePartitionToOnlinePartitionTransition(): Unit = { + controllerContext.setLiveBrokers(Map(TestUtils.createBrokerAndEpoch(brokerId, "host", 0))) + controllerContext.updatePartitionFullReplicaAssignment(partition, ReplicaAssignment(Seq(brokerId))) + controllerContext.putPartitionState(partition, OfflinePartition) + val leaderAndIsr = LeaderAndIsr(LeaderAndIsr.NoLeader, List(brokerId)) + val leaderIsrAndControllerEpoch = LeaderIsrAndControllerEpoch(leaderAndIsr, controllerEpoch) + controllerContext.putPartitionLeadershipInfo(partition, leaderIsrAndControllerEpoch) + + val stat = new Stat(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) + EasyMock.expect(mockControllerBrokerRequestBatch.newBatch()) + EasyMock.expect(mockZkClient.getTopicPartitionStatesRaw(partitions)) + .andReturn(Seq(GetDataResponse(Code.OK, null, Some(partition), + TopicPartitionStateZNode.encode(leaderIsrAndControllerEpoch), stat, ResponseMetadata(0, 0)))) + + EasyMock.expect(mockZkClient.getLogConfigs(Set.empty, config.originals())) + .andReturn((Map(partition.topic -> LogConfig()), Map.empty)) + val leaderAndIsrAfterElection = leaderAndIsr.newLeader(brokerId) + val updatedLeaderAndIsr = leaderAndIsrAfterElection.withZkVersion(2) + EasyMock.expect(mockZkClient.updateLeaderAndIsr(Map(partition -> leaderAndIsrAfterElection), controllerEpoch, controllerContext.epochZkVersion)) + .andReturn(UpdateLeaderAndIsrResult(Map(partition -> Right(updatedLeaderAndIsr)), Seq.empty)) + EasyMock.expect(mockControllerBrokerRequestBatch.addLeaderAndIsrRequestForBrokers(Seq(brokerId), + partition, LeaderIsrAndControllerEpoch(updatedLeaderAndIsr, controllerEpoch), replicaAssignment(Seq(brokerId)), isNew = false)) + EasyMock.expect(mockControllerBrokerRequestBatch.sendRequestsToBrokers(controllerEpoch)) + EasyMock.replay(mockZkClient, mockControllerBrokerRequestBatch) + + partitionStateMachine.handleStateChanges( + partitions, + OnlinePartition, + Option(OfflinePartitionLeaderElectionStrategy(false)) + ) + EasyMock.verify(mockZkClient, mockControllerBrokerRequestBatch) + assertEquals(OnlinePartition, partitionState(partition)) + } + + @Test + def testOfflinePartitionToUncleanOnlinePartitionTransition(): Unit = { + /* Starting scenario: Leader: X, Isr: [X], Replicas: [X, Y], LiveBrokers: [Y] + * Ending 
scenario: Leader: Y, Isr: [Y], Replicas: [X, Y], LiveBrokers: [Y] + * + * For the given starting scenario, verify that performing an unclean leader + * election on the offline partition results in the first live broker getting + * elected. + */ + val leaderBrokerId = brokerId + 1 + controllerContext.setLiveBrokers(Map(TestUtils.createBrokerAndEpoch(brokerId, "host", 0))) + controllerContext.updatePartitionFullReplicaAssignment( + partition, + ReplicaAssignment(Seq(leaderBrokerId, brokerId)) + ) + controllerContext.putPartitionState(partition, OfflinePartition) + + val leaderAndIsr = LeaderAndIsr(leaderBrokerId, List(leaderBrokerId)) + val leaderIsrAndControllerEpoch = LeaderIsrAndControllerEpoch(leaderAndIsr, controllerEpoch) + controllerContext.putPartitionLeadershipInfo(partition, leaderIsrAndControllerEpoch) + + EasyMock.expect(mockControllerBrokerRequestBatch.newBatch()) + EasyMock + .expect(mockZkClient.getTopicPartitionStatesRaw(partitions)) + .andReturn( + Seq( + GetDataResponse( + Code.OK, + null, + Option(partition), + TopicPartitionStateZNode.encode(leaderIsrAndControllerEpoch), + new Stat(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0), + ResponseMetadata(0, 0) + ) + ) + ) + + val leaderAndIsrAfterElection = leaderAndIsr.newLeaderAndIsr(brokerId, List(brokerId)) + val updatedLeaderAndIsr = leaderAndIsrAfterElection.withZkVersion(2) + + EasyMock + .expect( + mockZkClient.updateLeaderAndIsr( + Map(partition -> leaderAndIsrAfterElection), + controllerEpoch, + controllerContext.epochZkVersion + ) + ) + .andReturn(UpdateLeaderAndIsrResult(Map(partition -> Right(updatedLeaderAndIsr)), Seq.empty)) + EasyMock.expect( + mockControllerBrokerRequestBatch.addLeaderAndIsrRequestForBrokers( + Seq(brokerId), + partition, + LeaderIsrAndControllerEpoch(updatedLeaderAndIsr, controllerEpoch), + replicaAssignment(Seq(leaderBrokerId, brokerId)), + false + ) + ) + EasyMock.expect(mockControllerBrokerRequestBatch.sendRequestsToBrokers(controllerEpoch)) + EasyMock.replay(mockZkClient, mockControllerBrokerRequestBatch) + + partitionStateMachine.handleStateChanges( + partitions, + OnlinePartition, + Option(OfflinePartitionLeaderElectionStrategy(true)) + ) + EasyMock.verify(mockZkClient, mockControllerBrokerRequestBatch) + assertEquals(OnlinePartition, partitionState(partition)) + } + + @Test + def testOfflinePartitionToOnlinePartitionTransitionZooKeeperClientExceptionFromStateLookup(): Unit = { + controllerContext.setLiveBrokers(Map(TestUtils.createBrokerAndEpoch(brokerId, "host", 0))) + controllerContext.updatePartitionFullReplicaAssignment(partition, ReplicaAssignment(Seq(brokerId))) + controllerContext.putPartitionState(partition, OfflinePartition) + val leaderAndIsr = LeaderAndIsr(LeaderAndIsr.NoLeader, List(brokerId)) + val leaderIsrAndControllerEpoch = LeaderIsrAndControllerEpoch(leaderAndIsr, controllerEpoch) + controllerContext.putPartitionLeadershipInfo(partition, leaderIsrAndControllerEpoch) + + EasyMock.expect(mockControllerBrokerRequestBatch.newBatch()) + EasyMock.expect(mockZkClient.getTopicPartitionStatesRaw(partitions)) + .andThrow(new ZooKeeperClientException("")) + + EasyMock.expect(mockControllerBrokerRequestBatch.sendRequestsToBrokers(controllerEpoch)) + EasyMock.replay(mockZkClient, mockControllerBrokerRequestBatch) + + partitionStateMachine.handleStateChanges( + partitions, + OnlinePartition, + Option(OfflinePartitionLeaderElectionStrategy(false)) + ) + EasyMock.verify(mockZkClient, mockControllerBrokerRequestBatch) + assertEquals(OfflinePartition, partitionState(partition)) + } + + @Test + def 
testOfflinePartitionToOnlinePartitionTransitionErrorCodeFromStateLookup(): Unit = { + controllerContext.setLiveBrokers(Map(TestUtils.createBrokerAndEpoch(brokerId, "host", 0))) + controllerContext.updatePartitionFullReplicaAssignment(partition, ReplicaAssignment(Seq(brokerId))) + controllerContext.putPartitionState(partition, OfflinePartition) + val leaderAndIsr = LeaderAndIsr(LeaderAndIsr.NoLeader, List(brokerId)) + val leaderIsrAndControllerEpoch = LeaderIsrAndControllerEpoch(leaderAndIsr, controllerEpoch) + controllerContext.putPartitionLeadershipInfo(partition, leaderIsrAndControllerEpoch) + + val stat = new Stat(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) + EasyMock.expect(mockControllerBrokerRequestBatch.newBatch()) + EasyMock.expect(mockZkClient.getTopicPartitionStatesRaw(partitions)) + .andReturn(Seq(GetDataResponse(Code.NONODE, null, Some(partition), + TopicPartitionStateZNode.encode(leaderIsrAndControllerEpoch), stat, ResponseMetadata(0, 0)))) + + EasyMock.expect(mockControllerBrokerRequestBatch.sendRequestsToBrokers(controllerEpoch)) + EasyMock.replay(mockZkClient, mockControllerBrokerRequestBatch) + + partitionStateMachine.handleStateChanges( + partitions, + OnlinePartition, + Option(OfflinePartitionLeaderElectionStrategy(false)) + ) + EasyMock.verify(mockZkClient, mockControllerBrokerRequestBatch) + assertEquals(OfflinePartition, partitionState(partition)) + } + + @Test + def testOfflinePartitionToNonexistentPartitionTransition(): Unit = { + controllerContext.putPartitionState(partition, OfflinePartition) + partitionStateMachine.handleStateChanges(partitions, NonExistentPartition) + assertEquals(NonExistentPartition, partitionState(partition)) + } + + @Test + def testInvalidOfflinePartitionToNewPartitionTransition(): Unit = { + controllerContext.putPartitionState(partition, OfflinePartition) + partitionStateMachine.handleStateChanges(partitions, NewPartition) + assertEquals(OfflinePartition, partitionState(partition)) + } + + private def prepareMockToElectLeaderForPartitions(partitions: Seq[TopicPartition]): Unit = { + val leaderAndIsr = LeaderAndIsr(brokerId, List(brokerId)) + def prepareMockToGetTopicPartitionsStatesRaw(): Unit = { + val stat = new Stat(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) + val leaderIsrAndControllerEpoch = LeaderIsrAndControllerEpoch(leaderAndIsr, controllerEpoch) + val getDataResponses = partitions.map {p => GetDataResponse(Code.OK, null, Some(p), + TopicPartitionStateZNode.encode(leaderIsrAndControllerEpoch), stat, ResponseMetadata(0, 0))} + EasyMock.expect(mockZkClient.getTopicPartitionStatesRaw(partitions)) + .andReturn(getDataResponses) + } + prepareMockToGetTopicPartitionsStatesRaw() + + def prepareMockToGetLogConfigs(): Unit = { + EasyMock.expect(mockZkClient.getLogConfigs(Set.empty, config.originals())) + .andReturn(Map.empty, Map.empty) + } + prepareMockToGetLogConfigs() + + def prepareMockToUpdateLeaderAndIsr(): Unit = { + val updatedLeaderAndIsr: Map[TopicPartition, LeaderAndIsr] = partitions.map { partition => + partition -> leaderAndIsr.newLeaderAndIsr(brokerId, List(brokerId)) + }.toMap + EasyMock.expect(mockZkClient.updateLeaderAndIsr(updatedLeaderAndIsr, controllerEpoch, controllerContext.epochZkVersion)) + .andReturn(UpdateLeaderAndIsrResult(updatedLeaderAndIsr.map { case (k, v) => k -> Right(v) }, Seq.empty)) + } + prepareMockToUpdateLeaderAndIsr() + } + + /** + * This method tests changing partitions' state to OfflinePartition increments the offlinePartitionCount, + * and changing their state back to OnlinePartition decrements the offlinePartitionCount 
+ */ + @Test + def testUpdatingOfflinePartitionsCount(): Unit = { + controllerContext.setLiveBrokers(Map(TestUtils.createBrokerAndEpoch(brokerId, "host", 0))) + + val partitionIds = Seq(0, 1, 2, 3) + val topic = "test" + val partitions = partitionIds.map(new TopicPartition(topic, _)) + + partitions.foreach { partition => + controllerContext.updatePartitionFullReplicaAssignment(partition, ReplicaAssignment(Seq(brokerId))) + } + + prepareMockToElectLeaderForPartitions(partitions) + EasyMock.replay(mockZkClient) + + partitionStateMachine.handleStateChanges(partitions, NewPartition) + partitionStateMachine.handleStateChanges(partitions, OfflinePartition) + assertEquals(partitions.size, controllerContext.offlinePartitionCount, + s"There should be ${partitions.size} offline partition(s)") + + partitionStateMachine.handleStateChanges(partitions, OnlinePartition, Some(OfflinePartitionLeaderElectionStrategy(false))) + assertEquals(0, controllerContext.offlinePartitionCount, + s"There should be no offline partition(s)") + } + + /** + * This method tests that, if a topic is being deleted, changing its partitions' state to OfflinePartition makes no change + * to the offlinePartitionCount + */ + @Test + def testNoOfflinePartitionsChangeForTopicsBeingDeleted() = { + val partitionIds = Seq(0, 1, 2, 3) + val topic = "test" + val partitions = partitionIds.map(new TopicPartition(topic, _)) + + controllerContext.topicsToBeDeleted.add(topic) + controllerContext.topicsWithDeletionStarted.add(topic) + + partitionStateMachine.handleStateChanges(partitions, NewPartition) + partitionStateMachine.handleStateChanges(partitions, OfflinePartition) + assertEquals(0, controllerContext.offlinePartitionCount, + s"There should be no offline partition(s)") + } + + /** + * This method tests that, if some partitions are already in OfflinePartition state, + * deleting their topic decrements the offlinePartitionCount. + * For example, if partitions test-0, test-1, test-2, test-3 are in OfflinePartition state, + * and the offlinePartitionCount is 4, trying to delete the topic "test" means these + * partitions no longer qualify as offline partitions, and the offlinePartitionCount + * should be decremented to 0. 
+ */ + @Test + def testUpdatingOfflinePartitionsCountDuringTopicDeletion() = { + val partitionIds = Seq(0, 1, 2, 3) + val topic = "test" + val partitions = partitionIds.map(new TopicPartition("test", _)) + partitions.foreach { partition => + controllerContext.updatePartitionFullReplicaAssignment(partition, ReplicaAssignment(Seq(brokerId))) + } + + val partitionStateMachine = new MockPartitionStateMachine(controllerContext, uncleanLeaderElectionEnabled = false) + val replicaStateMachine = new MockReplicaStateMachine(controllerContext) + val deletionClient = Mockito.mock(classOf[DeletionClient]) + val topicDeletionManager = new TopicDeletionManager(config, controllerContext, + replicaStateMachine, partitionStateMachine, deletionClient) + + partitionStateMachine.handleStateChanges(partitions, NewPartition) + partitionStateMachine.handleStateChanges(partitions, OfflinePartition) + partitions.foreach { partition => + val replica = PartitionAndReplica(partition, brokerId) + controllerContext.putReplicaState(replica, OfflineReplica) + } + + assertEquals(partitions.size, controllerContext.offlinePartitionCount, + s"There should be ${partitions.size} offline partition(s)") + topicDeletionManager.enqueueTopicsForDeletion(Set(topic)) + assertEquals(0, controllerContext.offlinePartitionCount, + s"There should be no offline partition(s)") + } + + private def replicaAssignment(replicas: Seq[Int]): ReplicaAssignment = ReplicaAssignment(replicas, Seq(), Seq()) + +} diff --git a/core/src/test/scala/unit/kafka/controller/ReplicaStateMachineTest.scala b/core/src/test/scala/unit/kafka/controller/ReplicaStateMachineTest.scala new file mode 100644 index 0000000..de43e05 --- /dev/null +++ b/core/src/test/scala/unit/kafka/controller/ReplicaStateMachineTest.scala @@ -0,0 +1,414 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+*/ +package kafka.controller + +import kafka.api.LeaderAndIsr +import kafka.cluster.{Broker, EndPoint} +import kafka.server.KafkaConfig +import kafka.utils.TestUtils +import kafka.zk.KafkaZkClient.UpdateLeaderAndIsrResult +import kafka.zk.{KafkaZkClient, TopicPartitionStateZNode} +import kafka.zookeeper.{GetDataResponse, ResponseMetadata} +import org.apache.kafka.common.TopicPartition +import org.apache.kafka.common.network.ListenerName +import org.apache.kafka.common.security.auth.SecurityProtocol +import org.apache.zookeeper.KeeperException.Code +import org.apache.zookeeper.data.Stat +import org.easymock.EasyMock +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.{BeforeEach, Test} + +class ReplicaStateMachineTest { + private var controllerContext: ControllerContext = null + private var mockZkClient: KafkaZkClient = null + private var mockControllerBrokerRequestBatch: ControllerBrokerRequestBatch = null + private var replicaStateMachine: ReplicaStateMachine = null + + private val brokerId = 5 + private val config = KafkaConfig.fromProps(TestUtils.createBrokerConfig(brokerId, "zkConnect")) + private val controllerEpoch = 50 + private val partition = new TopicPartition("t", 0) + private val partitions = Seq(partition) + private val replica = PartitionAndReplica(partition, brokerId) + private val replicas = Seq(replica) + + @BeforeEach + def setUp(): Unit = { + controllerContext = new ControllerContext + controllerContext.epoch = controllerEpoch + mockZkClient = EasyMock.createMock(classOf[KafkaZkClient]) + mockControllerBrokerRequestBatch = EasyMock.createMock(classOf[ControllerBrokerRequestBatch]) + replicaStateMachine = new ZkReplicaStateMachine(config, new StateChangeLogger(brokerId, true, None), + controllerContext, mockZkClient, mockControllerBrokerRequestBatch) + } + + private def replicaState(replica: PartitionAndReplica): ReplicaState = { + controllerContext.replicaState(replica) + } + + @Test + def testStartupOnlinePartition(): Unit = { + val endpoint1 = new EndPoint("localhost", 9997, new ListenerName("blah"), + SecurityProtocol.PLAINTEXT) + val liveBrokerEpochs = Map(Broker(brokerId, Seq(endpoint1), rack = None) -> 1L) + controllerContext.setLiveBrokers(liveBrokerEpochs) + controllerContext.updatePartitionFullReplicaAssignment(partition, ReplicaAssignment(Seq(brokerId))) + assertEquals(None, controllerContext.replicaStates.get(replica)) + replicaStateMachine.startup() + assertEquals(OnlineReplica, replicaState(replica)) + } + + @Test + def testStartupOfflinePartition(): Unit = { + controllerContext.updatePartitionFullReplicaAssignment(partition, ReplicaAssignment(Seq(brokerId))) + assertEquals(None, controllerContext.replicaStates.get(replica)) + replicaStateMachine.startup() + assertEquals(OfflineReplica, replicaState(replica)) + } + + @Test + def testStartupWithReplicaWithoutLeader(): Unit = { + val shutdownBrokerId = 100 + val offlineReplica = PartitionAndReplica(partition, shutdownBrokerId) + val endpoint1 = new EndPoint("localhost", 9997, new ListenerName("blah"), + SecurityProtocol.PLAINTEXT) + val liveBrokerEpochs = Map(Broker(brokerId, Seq(endpoint1), rack = None) -> 1L) + controllerContext.setLiveBrokers(liveBrokerEpochs) + controllerContext.updatePartitionFullReplicaAssignment(partition, ReplicaAssignment(Seq(shutdownBrokerId))) + assertEquals(None, controllerContext.replicaStates.get(offlineReplica)) + replicaStateMachine.startup() + assertEquals(OfflineReplica, replicaState(offlineReplica)) + } + + @Test + def 
testNonexistentReplicaToNewReplicaTransition(): Unit = { + replicaStateMachine.handleStateChanges(replicas, NewReplica) + assertEquals(NewReplica, replicaState(replica)) + } + + @Test + def testInvalidNonexistentReplicaToOnlineReplicaTransition(): Unit = { + replicaStateMachine.handleStateChanges(replicas, OnlineReplica) + assertEquals(NonExistentReplica, replicaState(replica)) + } + + @Test + def testInvalidNonexistentReplicaToOfflineReplicaTransition(): Unit = { + replicaStateMachine.handleStateChanges(replicas, OfflineReplica) + assertEquals(NonExistentReplica, replicaState(replica)) + } + + @Test + def testInvalidNonexistentReplicaToReplicaDeletionStartedTransition(): Unit = { + replicaStateMachine.handleStateChanges(replicas, ReplicaDeletionStarted) + assertEquals(NonExistentReplica, replicaState(replica)) + } + + @Test + def testInvalidNonexistentReplicaToReplicaDeletionIneligibleTransition(): Unit = { + replicaStateMachine.handleStateChanges(replicas, ReplicaDeletionIneligible) + assertEquals(NonExistentReplica, replicaState(replica)) + } + + @Test + def testInvalidNonexistentReplicaToReplicaDeletionSuccessfulTransition(): Unit = { + replicaStateMachine.handleStateChanges(replicas, ReplicaDeletionSuccessful) + assertEquals(NonExistentReplica, replicaState(replica)) + } + + @Test + def testInvalidNewReplicaToNonexistentReplicaTransition(): Unit = { + testInvalidTransition(NewReplica, NonExistentReplica) + } + + @Test + def testNewReplicaToOnlineReplicaTransition(): Unit = { + controllerContext.putReplicaState(replica, NewReplica) + controllerContext.updatePartitionFullReplicaAssignment(partition, ReplicaAssignment(Seq(brokerId))) + replicaStateMachine.handleStateChanges(replicas, OnlineReplica) + assertEquals(OnlineReplica, replicaState(replica)) + } + + @Test + def testNewReplicaToOfflineReplicaTransition(): Unit = { + val endpoint1 = new EndPoint("localhost", 9997, new ListenerName("blah"), + SecurityProtocol.PLAINTEXT) + val liveBrokerEpochs = Map(Broker(brokerId, Seq(endpoint1), rack = None) -> 1L) + controllerContext.setLiveBrokers(liveBrokerEpochs) + controllerContext.putReplicaState(replica, NewReplica) + EasyMock.expect(mockControllerBrokerRequestBatch.newBatch()) + EasyMock.expect(mockControllerBrokerRequestBatch.addStopReplicaRequestForBrokers(EasyMock.eq(Seq(brokerId)), EasyMock.eq(partition), EasyMock.eq(false))) + EasyMock.expect(mockControllerBrokerRequestBatch.addUpdateMetadataRequestForBrokers(EasyMock.eq(Seq(brokerId)), EasyMock.eq(Set(partition)))) + EasyMock.expect(mockControllerBrokerRequestBatch.sendRequestsToBrokers(controllerEpoch)) + + EasyMock.replay(mockControllerBrokerRequestBatch) + replicaStateMachine.handleStateChanges(replicas, OfflineReplica) + EasyMock.verify(mockControllerBrokerRequestBatch) + assertEquals(OfflineReplica, replicaState(replica)) + } + + @Test + def testInvalidNewReplicaToReplicaDeletionStartedTransition(): Unit = { + testInvalidTransition(NewReplica, ReplicaDeletionStarted) + } + + @Test + def testInvalidNewReplicaToReplicaDeletionIneligibleTransition(): Unit = { + testInvalidTransition(NewReplica, ReplicaDeletionIneligible) + } + + @Test + def testInvalidNewReplicaToReplicaDeletionSuccessfulTransition(): Unit = { + testInvalidTransition(NewReplica, ReplicaDeletionSuccessful) + } + + @Test + def testInvalidOnlineReplicaToNonexistentReplicaTransition(): Unit = { + testInvalidTransition(OnlineReplica, NonExistentReplica) + } + + @Test + def testInvalidOnlineReplicaToNewReplicaTransition(): Unit = { + testInvalidTransition(OnlineReplica, 
NewReplica) + } + + @Test + def testOnlineReplicaToOnlineReplicaTransition(): Unit = { + controllerContext.putReplicaState(replica, OnlineReplica) + controllerContext.updatePartitionFullReplicaAssignment(partition, ReplicaAssignment(Seq(brokerId))) + val leaderIsrAndControllerEpoch = LeaderIsrAndControllerEpoch(LeaderAndIsr(brokerId, List(brokerId)), controllerEpoch) + controllerContext.putPartitionLeadershipInfo(partition, leaderIsrAndControllerEpoch) + EasyMock.expect(mockControllerBrokerRequestBatch.newBatch()) + EasyMock.expect(mockControllerBrokerRequestBatch.addLeaderAndIsrRequestForBrokers(Seq(brokerId), + partition, leaderIsrAndControllerEpoch, replicaAssignment(Seq(brokerId)), isNew = false)) + EasyMock.expect(mockControllerBrokerRequestBatch.sendRequestsToBrokers(controllerEpoch)) + EasyMock.replay(mockZkClient, mockControllerBrokerRequestBatch) + replicaStateMachine.handleStateChanges(replicas, OnlineReplica) + EasyMock.verify(mockZkClient, mockControllerBrokerRequestBatch) + assertEquals(OnlineReplica, replicaState(replica)) + } + + @Test + def testOnlineReplicaToOfflineReplicaTransition(): Unit = { + val otherBrokerId = brokerId + 1 + val replicaIds = List(brokerId, otherBrokerId) + controllerContext.putReplicaState(replica, OnlineReplica) + controllerContext.updatePartitionFullReplicaAssignment(partition, ReplicaAssignment(replicaIds)) + val leaderAndIsr = LeaderAndIsr(brokerId, replicaIds) + val leaderIsrAndControllerEpoch = LeaderIsrAndControllerEpoch(leaderAndIsr, controllerEpoch) + controllerContext.putPartitionLeadershipInfo(partition, leaderIsrAndControllerEpoch) + + val stat = new Stat(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) + EasyMock.expect(mockControllerBrokerRequestBatch.newBatch()) + EasyMock.expect(mockControllerBrokerRequestBatch.addStopReplicaRequestForBrokers(EasyMock.eq(Seq(brokerId)), EasyMock.eq(partition), EasyMock.eq(false))) + val adjustedLeaderAndIsr = leaderAndIsr.newLeaderAndIsr(LeaderAndIsr.NoLeader, List(otherBrokerId)) + val updatedLeaderAndIsr = adjustedLeaderAndIsr.withZkVersion(adjustedLeaderAndIsr.zkVersion + 1) + val updatedLeaderIsrAndControllerEpoch = LeaderIsrAndControllerEpoch(updatedLeaderAndIsr, controllerEpoch) + EasyMock.expect(mockZkClient.getTopicPartitionStatesRaw(partitions)).andReturn( + Seq(GetDataResponse(Code.OK, null, Some(partition), + TopicPartitionStateZNode.encode(leaderIsrAndControllerEpoch), stat, ResponseMetadata(0, 0)))) + EasyMock.expect(mockZkClient.updateLeaderAndIsr(Map(partition -> adjustedLeaderAndIsr), controllerEpoch, controllerContext.epochZkVersion)) + .andReturn(UpdateLeaderAndIsrResult(Map(partition -> Right(updatedLeaderAndIsr)), Seq.empty)) + EasyMock.expect(mockControllerBrokerRequestBatch.addLeaderAndIsrRequestForBrokers(Seq(otherBrokerId), + partition, updatedLeaderIsrAndControllerEpoch, replicaAssignment(replicaIds), isNew = false)) + EasyMock.expect(mockControllerBrokerRequestBatch.sendRequestsToBrokers(controllerEpoch)) + + EasyMock.replay(mockZkClient, mockControllerBrokerRequestBatch) + replicaStateMachine.handleStateChanges(replicas, OfflineReplica) + EasyMock.verify(mockZkClient, mockControllerBrokerRequestBatch) + assertEquals(updatedLeaderIsrAndControllerEpoch, controllerContext.partitionLeadershipInfo(partition).get) + assertEquals(OfflineReplica, replicaState(replica)) + } + + @Test + def testInvalidOnlineReplicaToReplicaDeletionStartedTransition(): Unit = { + testInvalidTransition(OnlineReplica, ReplicaDeletionStarted) + } + + @Test + def 
testInvalidOnlineReplicaToReplicaDeletionIneligibleTransition(): Unit = { + testInvalidTransition(OnlineReplica, ReplicaDeletionIneligible) + } + + @Test + def testInvalidOnlineReplicaToReplicaDeletionSuccessfulTransition(): Unit = { + testInvalidTransition(OnlineReplica, ReplicaDeletionSuccessful) + } + + @Test + def testInvalidOfflineReplicaToNonexistentReplicaTransition(): Unit = { + testInvalidTransition(OfflineReplica, NonExistentReplica) + } + + @Test + def testInvalidOfflineReplicaToNewReplicaTransition(): Unit = { + testInvalidTransition(OfflineReplica, NewReplica) + } + + @Test + def testOfflineReplicaToOnlineReplicaTransition(): Unit = { + controllerContext.putReplicaState(replica, OfflineReplica) + controllerContext.updatePartitionFullReplicaAssignment(partition, ReplicaAssignment(Seq(brokerId))) + val leaderIsrAndControllerEpoch = LeaderIsrAndControllerEpoch(LeaderAndIsr(brokerId, List(brokerId)), controllerEpoch) + controllerContext.putPartitionLeadershipInfo(partition, leaderIsrAndControllerEpoch) + EasyMock.expect(mockControllerBrokerRequestBatch.newBatch()) + EasyMock.expect(mockControllerBrokerRequestBatch.addLeaderAndIsrRequestForBrokers(Seq(brokerId), + partition, leaderIsrAndControllerEpoch, replicaAssignment(Seq(brokerId)), isNew = false)) + EasyMock.expect(mockControllerBrokerRequestBatch.sendRequestsToBrokers(controllerEpoch)) + EasyMock.replay(mockZkClient, mockControllerBrokerRequestBatch) + replicaStateMachine.handleStateChanges(replicas, OnlineReplica) + EasyMock.verify(mockZkClient, mockControllerBrokerRequestBatch) + assertEquals(OnlineReplica, replicaState(replica)) + } + + @Test + def testOfflineReplicaToReplicaDeletionStartedTransition(): Unit = { + controllerContext.putReplicaState(replica, OfflineReplica) + EasyMock.expect(mockControllerBrokerRequestBatch.newBatch()) + EasyMock.expect(mockControllerBrokerRequestBatch.addStopReplicaRequestForBrokers(Seq(brokerId), partition, true)) + EasyMock.expect(mockControllerBrokerRequestBatch.sendRequestsToBrokers(controllerEpoch)) + EasyMock.replay(mockZkClient, mockControllerBrokerRequestBatch) + replicaStateMachine.handleStateChanges(replicas, ReplicaDeletionStarted) + EasyMock.verify(mockZkClient, mockControllerBrokerRequestBatch) + assertEquals(ReplicaDeletionStarted, replicaState(replica)) + } + + @Test + def testOfflineReplicaToReplicaDeletionIneligibleTransition(): Unit = { + controllerContext.putReplicaState(replica, OfflineReplica) + replicaStateMachine.handleStateChanges(replicas, ReplicaDeletionIneligible) + assertEquals(ReplicaDeletionIneligible, replicaState(replica)) + } + + @Test + def testInvalidOfflineReplicaToReplicaDeletionSuccessfulTransition(): Unit = { + testInvalidTransition(OfflineReplica, ReplicaDeletionSuccessful) + } + + @Test + def testInvalidReplicaDeletionStartedToNonexistentReplicaTransition(): Unit = { + testInvalidTransition(ReplicaDeletionStarted, NonExistentReplica) + } + + @Test + def testInvalidReplicaDeletionStartedToNewReplicaTransition(): Unit = { + testInvalidTransition(ReplicaDeletionStarted, NewReplica) + } + + @Test + def testInvalidReplicaDeletionStartedToOnlineReplicaTransition(): Unit = { + testInvalidTransition(ReplicaDeletionStarted, OnlineReplica) + } + + @Test + def testInvalidReplicaDeletionStartedToOfflineReplicaTransition(): Unit = { + testInvalidTransition(ReplicaDeletionStarted, OfflineReplica) + } + + @Test + def testReplicaDeletionStartedToReplicaDeletionIneligibleTransition(): Unit = { + controllerContext.putReplicaState(replica, ReplicaDeletionStarted) + 
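// Aborting an in-progress deletion is allowed: ReplicaDeletionStarted -> ReplicaDeletionIneligible is a valid transition +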
replicaStateMachine.handleStateChanges(replicas, ReplicaDeletionIneligible) + assertEquals(ReplicaDeletionIneligible, replicaState(replica)) + } + + @Test + def testReplicaDeletionStartedToReplicaDeletionSuccessfulTransition(): Unit = { + controllerContext.putReplicaState(replica, ReplicaDeletionStarted) + replicaStateMachine.handleStateChanges(replicas, ReplicaDeletionSuccessful) + assertEquals(ReplicaDeletionSuccessful, replicaState(replica)) + } + + @Test + def testReplicaDeletionSuccessfulToNonexistentReplicaTransition(): Unit = { + controllerContext.putReplicaState(replica, ReplicaDeletionSuccessful) + controllerContext.updatePartitionFullReplicaAssignment(partition, ReplicaAssignment(Seq(brokerId))) + replicaStateMachine.handleStateChanges(replicas, NonExistentReplica) + assertEquals(Seq.empty, controllerContext.partitionReplicaAssignment(partition)) + assertEquals(None, controllerContext.replicaStates.get(replica)) + } + + @Test + def testInvalidReplicaDeletionSuccessfulToNewReplicaTransition(): Unit = { + testInvalidTransition(ReplicaDeletionSuccessful, NewReplica) + } + + @Test + def testInvalidReplicaDeletionSuccessfulToOnlineReplicaTransition(): Unit = { + testInvalidTransition(ReplicaDeletionSuccessful, OnlineReplica) + } + + @Test + def testInvalidReplicaDeletionSuccessfulToOfflineReplicaTransition(): Unit = { + testInvalidTransition(ReplicaDeletionSuccessful, OfflineReplica) + } + + @Test + def testInvalidReplicaDeletionSuccessfulToReplicaDeletionStartedTransition(): Unit = { + testInvalidTransition(ReplicaDeletionSuccessful, ReplicaDeletionStarted) + } + + @Test + def testInvalidReplicaDeletionSuccessfulToReplicaDeletionIneligibleTransition(): Unit = { + testInvalidTransition(ReplicaDeletionSuccessful, ReplicaDeletionIneligible) + } + + @Test + def testInvalidReplicaDeletionIneligibleToNonexistentReplicaTransition(): Unit = { + testInvalidTransition(ReplicaDeletionIneligible, NonExistentReplica) + } + + @Test + def testInvalidReplicaDeletionIneligibleToNewReplicaTransition(): Unit = { + testInvalidTransition(ReplicaDeletionIneligible, NewReplica) + } + + @Test + def testReplicaDeletionIneligibleToOnlineReplicaTransition(): Unit = { + controllerContext.putReplicaState(replica, ReplicaDeletionIneligible) + controllerContext.updatePartitionFullReplicaAssignment(partition, ReplicaAssignment(Seq(brokerId))) + val leaderIsrAndControllerEpoch = LeaderIsrAndControllerEpoch(LeaderAndIsr(brokerId, List(brokerId)), controllerEpoch) + controllerContext.putPartitionLeadershipInfo(partition, leaderIsrAndControllerEpoch) + EasyMock.expect(mockControllerBrokerRequestBatch.newBatch()) + EasyMock.expect(mockControllerBrokerRequestBatch.addLeaderAndIsrRequestForBrokers(Seq(brokerId), + partition, leaderIsrAndControllerEpoch, replicaAssignment(Seq(brokerId)), isNew = false)) + EasyMock.expect(mockControllerBrokerRequestBatch.sendRequestsToBrokers(controllerEpoch)) + EasyMock.replay(mockZkClient, mockControllerBrokerRequestBatch) + replicaStateMachine.handleStateChanges(replicas, OnlineReplica) + EasyMock.verify(mockZkClient, mockControllerBrokerRequestBatch) + assertEquals(OnlineReplica, replicaState(replica)) + } + + @Test + def testInvalidReplicaDeletionIneligibleToReplicaDeletionStartedTransition(): Unit = { + testInvalidTransition(ReplicaDeletionIneligible, ReplicaDeletionStarted) + } + + @Test + def testInvalidReplicaDeletionIneligibleToReplicaDeletionSuccessfulTransition(): Unit = { + testInvalidTransition(ReplicaDeletionIneligible, ReplicaDeletionSuccessful) + } + + private def 
testInvalidTransition(fromState: ReplicaState, toState: ReplicaState): Unit = { + controllerContext.putReplicaState(replica, fromState) + replicaStateMachine.handleStateChanges(replicas, toState) + assertEquals(fromState, replicaState(replica)) + } + + private def replicaAssignment(replicas: Seq[Int]): ReplicaAssignment = ReplicaAssignment(replicas, Seq(), Seq()) + +} diff --git a/core/src/test/scala/unit/kafka/controller/TopicDeletionManagerTest.scala b/core/src/test/scala/unit/kafka/controller/TopicDeletionManagerTest.scala new file mode 100644 index 0000000..ec2339d --- /dev/null +++ b/core/src/test/scala/unit/kafka/controller/TopicDeletionManagerTest.scala @@ -0,0 +1,274 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package kafka.controller + +import kafka.cluster.{Broker, EndPoint} +import kafka.server.KafkaConfig +import kafka.utils.TestUtils +import org.apache.kafka.common.TopicPartition +import org.apache.kafka.common.network.ListenerName +import org.apache.kafka.common.security.auth.SecurityProtocol +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.Test +import org.mockito.Mockito._ + +class TopicDeletionManagerTest { + + private val brokerId = 1 + private val config = KafkaConfig.fromProps(TestUtils.createBrokerConfig(brokerId, "zkConnect")) + private val deletionClient = mock(classOf[DeletionClient]) + + @Test + def testInitialization(): Unit = { + val controllerContext = initContext( + brokers = Seq(1, 2, 3), + topics = Set("foo", "bar", "baz"), + numPartitions = 2, + replicationFactor = 3) + + val replicaStateMachine = new MockReplicaStateMachine(controllerContext) + replicaStateMachine.startup() + + val partitionStateMachine = new MockPartitionStateMachine(controllerContext, uncleanLeaderElectionEnabled = false) + partitionStateMachine.startup() + + val deletionManager = new TopicDeletionManager(config, controllerContext, replicaStateMachine, + partitionStateMachine, deletionClient) + + assertTrue(deletionManager.isDeleteTopicEnabled) + deletionManager.init(initialTopicsToBeDeleted = Set("foo", "bar"), initialTopicsIneligibleForDeletion = Set("bar", "baz")) + + assertEquals(Set("foo", "bar"), controllerContext.topicsToBeDeleted.toSet) + assertEquals(Set("bar"), controllerContext.topicsIneligibleForDeletion.toSet) + } + + @Test + def testBasicDeletion(): Unit = { + val controllerContext = initContext( + brokers = Seq(1, 2, 3), + topics = Set("foo", "bar"), + numPartitions = 2, + replicationFactor = 3) + val replicaStateMachine = new MockReplicaStateMachine(controllerContext) + replicaStateMachine.startup() + + val partitionStateMachine = new MockPartitionStateMachine(controllerContext, uncleanLeaderElectionEnabled = false) + partitionStateMachine.startup() + + val deletionManager = new 
TopicDeletionManager(config, controllerContext, replicaStateMachine, + partitionStateMachine, deletionClient) + assertTrue(deletionManager.isDeleteTopicEnabled) + deletionManager.init(Set.empty, Set.empty) + + val fooPartitions = controllerContext.partitionsForTopic("foo") + val fooReplicas = controllerContext.replicasForPartition(fooPartitions).toSet + val barPartitions = controllerContext.partitionsForTopic("bar") + val barReplicas = controllerContext.replicasForPartition(barPartitions).toSet + + // Clean up state changes before starting the deletion + replicaStateMachine.clear() + partitionStateMachine.clear() + + // Queue the topic for deletion + deletionManager.enqueueTopicsForDeletion(Set("foo", "bar")) + + assertEquals(fooPartitions, controllerContext.partitionsInState("foo", NonExistentPartition)) + assertEquals(fooReplicas, controllerContext.replicasInState("foo", ReplicaDeletionStarted)) + assertEquals(barPartitions, controllerContext.partitionsInState("bar", NonExistentPartition)) + assertEquals(barReplicas, controllerContext.replicasInState("bar", ReplicaDeletionStarted)) + verify(deletionClient).sendMetadataUpdate(fooPartitions ++ barPartitions) + assertEquals(Set("foo", "bar"), controllerContext.topicsToBeDeleted) + assertEquals(Set("foo", "bar"), controllerContext.topicsWithDeletionStarted) + assertEquals(Set(), controllerContext.topicsIneligibleForDeletion) + + // Complete the deletion + deletionManager.completeReplicaDeletion(fooReplicas ++ barReplicas) + + assertEquals(Set.empty, controllerContext.partitionsForTopic("foo")) + assertEquals(Set.empty[PartitionAndReplica], controllerContext.replicaStates.keySet.filter(_.topic == "foo")) + assertEquals(Set.empty, controllerContext.partitionsForTopic("bar")) + assertEquals(Set.empty[PartitionAndReplica], controllerContext.replicaStates.keySet.filter(_.topic == "bar")) + assertEquals(Set(), controllerContext.topicsToBeDeleted) + assertEquals(Set(), controllerContext.topicsWithDeletionStarted) + assertEquals(Set(), controllerContext.topicsIneligibleForDeletion) + + assertEquals(1, partitionStateMachine.stateChangesCalls(OfflinePartition)) + assertEquals(1, partitionStateMachine.stateChangesCalls(NonExistentPartition)) + + assertEquals(1, replicaStateMachine.stateChangesCalls(ReplicaDeletionIneligible)) + assertEquals(1, replicaStateMachine.stateChangesCalls(OfflineReplica)) + assertEquals(1, replicaStateMachine.stateChangesCalls(ReplicaDeletionStarted)) + assertEquals(1, replicaStateMachine.stateChangesCalls(ReplicaDeletionSuccessful)) + } + + @Test + def testDeletionWithBrokerOffline(): Unit = { + val controllerContext = initContext( + brokers = Seq(1, 2, 3), + topics = Set("foo", "bar"), + numPartitions = 2, + replicationFactor = 3) + + val replicaStateMachine = new MockReplicaStateMachine(controllerContext) + replicaStateMachine.startup() + + val partitionStateMachine = new MockPartitionStateMachine(controllerContext, uncleanLeaderElectionEnabled = false) + partitionStateMachine.startup() + + val deletionManager = new TopicDeletionManager(config, controllerContext, replicaStateMachine, + partitionStateMachine, deletionClient) + assertTrue(deletionManager.isDeleteTopicEnabled) + deletionManager.init(Set.empty, Set.empty) + + val fooPartitions = controllerContext.partitionsForTopic("foo") + val fooReplicas = controllerContext.replicasForPartition(fooPartitions).toSet + + // Broker 2 is taken offline + val failedBrokerId = 2 + val offlineBroker = controllerContext.liveOrShuttingDownBroker(failedBrokerId).get + val lastEpoch = 
controllerContext.liveBrokerIdAndEpochs(failedBrokerId) + controllerContext.removeLiveBrokers(Set(failedBrokerId)) + assertEquals(Set(1, 3), controllerContext.liveBrokerIds) + + val (offlineReplicas, onlineReplicas) = fooReplicas.partition(_.replica == failedBrokerId) + replicaStateMachine.handleStateChanges(offlineReplicas.toSeq, OfflineReplica) + + // Start topic deletion + deletionManager.enqueueTopicsForDeletion(Set("foo")) + assertEquals(fooPartitions, controllerContext.partitionsInState("foo", NonExistentPartition)) + verify(deletionClient).sendMetadataUpdate(fooPartitions) + assertEquals(onlineReplicas, controllerContext.replicasInState("foo", ReplicaDeletionStarted)) + assertEquals(offlineReplicas, controllerContext.replicasInState("foo", ReplicaDeletionIneligible)) + + assertEquals(Set("foo"), controllerContext.topicsToBeDeleted) + assertEquals(Set("foo"), controllerContext.topicsWithDeletionStarted) + assertEquals(Set("foo"), controllerContext.topicsIneligibleForDeletion) + + // Deletion succeeds for online replicas + deletionManager.completeReplicaDeletion(onlineReplicas) + + assertEquals(fooPartitions, controllerContext.partitionsInState("foo", NonExistentPartition)) + assertEquals(Set("foo"), controllerContext.topicsToBeDeleted) + assertEquals(Set("foo"), controllerContext.topicsWithDeletionStarted) + assertEquals(Set("foo"), controllerContext.topicsIneligibleForDeletion) + assertEquals(onlineReplicas, controllerContext.replicasInState("foo", ReplicaDeletionSuccessful)) + assertEquals(offlineReplicas, controllerContext.replicasInState("foo", OfflineReplica)) + + // Broker 2 comes back online and deletion is resumed + controllerContext.addLiveBrokers(Map(offlineBroker -> (lastEpoch + 1L))) + deletionManager.resumeDeletionForTopics(Set("foo")) + + assertEquals(onlineReplicas, controllerContext.replicasInState("foo", ReplicaDeletionSuccessful)) + assertEquals(offlineReplicas, controllerContext.replicasInState("foo", ReplicaDeletionStarted)) + + deletionManager.completeReplicaDeletion(offlineReplicas) + assertEquals(Set.empty, controllerContext.partitionsForTopic("foo")) + assertEquals(Set.empty[PartitionAndReplica], controllerContext.replicaStates.keySet.filter(_.topic == "foo")) + assertEquals(Set(), controllerContext.topicsToBeDeleted) + assertEquals(Set(), controllerContext.topicsWithDeletionStarted) + assertEquals(Set(), controllerContext.topicsIneligibleForDeletion) + } + + @Test + def testBrokerFailureAfterDeletionStarted(): Unit = { + val controllerContext = initContext( + brokers = Seq(1, 2, 3), + topics = Set("foo", "bar"), + numPartitions = 2, + replicationFactor = 3) + + val replicaStateMachine = new MockReplicaStateMachine(controllerContext) + replicaStateMachine.startup() + + val partitionStateMachine = new MockPartitionStateMachine(controllerContext, uncleanLeaderElectionEnabled = false) + partitionStateMachine.startup() + + val deletionManager = new TopicDeletionManager(config, controllerContext, replicaStateMachine, + partitionStateMachine, deletionClient) + deletionManager.init(Set.empty, Set.empty) + + val fooPartitions = controllerContext.partitionsForTopic("foo") + val fooReplicas = controllerContext.replicasForPartition(fooPartitions).toSet + + // Queue the topic for deletion + deletionManager.enqueueTopicsForDeletion(Set("foo")) + assertEquals(fooPartitions, controllerContext.partitionsInState("foo", NonExistentPartition)) + assertEquals(fooReplicas, controllerContext.replicasInState("foo", ReplicaDeletionStarted)) + + // Broker 2 fails + val failedBrokerId 
= 2 + val offlineBroker = controllerContext.liveOrShuttingDownBroker(failedBrokerId).get + val lastEpoch = controllerContext.liveBrokerIdAndEpochs(failedBrokerId) + controllerContext.removeLiveBrokers(Set(failedBrokerId)) + assertEquals(Set(1, 3), controllerContext.liveBrokerIds) + val (offlineReplicas, onlineReplicas) = fooReplicas.partition(_.replica == failedBrokerId) + + // Fail replica deletion + deletionManager.failReplicaDeletion(offlineReplicas) + assertEquals(Set("foo"), controllerContext.topicsToBeDeleted) + assertEquals(Set("foo"), controllerContext.topicsWithDeletionStarted) + assertEquals(Set("foo"), controllerContext.topicsIneligibleForDeletion) + assertEquals(offlineReplicas, controllerContext.replicasInState("foo", ReplicaDeletionIneligible)) + assertEquals(onlineReplicas, controllerContext.replicasInState("foo", ReplicaDeletionStarted)) + + // Broker 2 is restarted. The offline replicas remain ineligible + // (TODO: this is probably not desired) + controllerContext.addLiveBrokers(Map(offlineBroker -> (lastEpoch + 1L))) + deletionManager.resumeDeletionForTopics(Set("foo")) + assertEquals(Set("foo"), controllerContext.topicsToBeDeleted) + assertEquals(Set("foo"), controllerContext.topicsWithDeletionStarted) + assertEquals(Set(), controllerContext.topicsIneligibleForDeletion) + assertEquals(onlineReplicas, controllerContext.replicasInState("foo", ReplicaDeletionStarted)) + assertEquals(offlineReplicas, controllerContext.replicasInState("foo", ReplicaDeletionIneligible)) + + // When deletion completes for the replicas where it had already started, deletion begins for the remaining ones + deletionManager.completeReplicaDeletion(onlineReplicas) + assertEquals(Set("foo"), controllerContext.topicsToBeDeleted) + assertEquals(Set("foo"), controllerContext.topicsWithDeletionStarted) + assertEquals(Set(), controllerContext.topicsIneligibleForDeletion) + assertEquals(onlineReplicas, controllerContext.replicasInState("foo", ReplicaDeletionSuccessful)) + assertEquals(offlineReplicas, controllerContext.replicasInState("foo", ReplicaDeletionStarted)) + + } + + def initContext(brokers: Seq[Int], + topics: Set[String], + numPartitions: Int, + replicationFactor: Int): ControllerContext = { + val context = new ControllerContext + val brokerEpochs = brokers.map { brokerId => + val endpoint = new EndPoint("localhost", 9900 + brokerId, new ListenerName("blah"), + SecurityProtocol.PLAINTEXT) + Broker(brokerId, Seq(endpoint), rack = None) -> 1L + }.toMap + context.setLiveBrokers(brokerEpochs) + + // Simple round-robin replica assignment + var leaderIndex = 0 + for (topic <- topics; partitionId <- 0 until numPartitions) { + val partition = new TopicPartition(topic, partitionId) + val replicas = (0 until replicationFactor).map { i => + val replica = brokers((i + leaderIndex) % brokers.size) + replica + } + context.updatePartitionFullReplicaAssignment(partition, ReplicaAssignment(replicas)) + leaderIndex += 1 + } + context + } + +} diff --git a/core/src/test/scala/unit/kafka/coordinator/AbstractCoordinatorConcurrencyTest.scala b/core/src/test/scala/unit/kafka/coordinator/AbstractCoordinatorConcurrencyTest.scala new file mode 100644 index 0000000..ddd3c18 --- /dev/null +++ b/core/src/test/scala/unit/kafka/coordinator/AbstractCoordinatorConcurrencyTest.scala @@ -0,0 +1,235 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.coordinator + +import java.util.concurrent.{ConcurrentHashMap, Executors} +import java.util.{Collections, Random} +import java.util.concurrent.atomic.AtomicInteger +import java.util.concurrent.locks.Lock + +import kafka.coordinator.AbstractCoordinatorConcurrencyTest._ +import kafka.log.{AppendOrigin, UnifiedLog, LogConfig} +import kafka.server._ +import kafka.utils._ +import kafka.utils.timer.MockTimer +import kafka.zk.KafkaZkClient +import org.apache.kafka.common.TopicPartition +import org.apache.kafka.common.protocol.Errors +import org.apache.kafka.common.record.{MemoryRecords, RecordBatch, RecordConversionStats} +import org.apache.kafka.common.requests.ProduceResponse.PartitionResponse +import org.easymock.EasyMock +import org.junit.jupiter.api.{AfterEach, BeforeEach} + +import scala.collection._ +import scala.jdk.CollectionConverters._ + +abstract class AbstractCoordinatorConcurrencyTest[M <: CoordinatorMember] { + + val nThreads = 5 + + val time = new MockTime + val timer = new MockTimer + val executor = Executors.newFixedThreadPool(nThreads) + val scheduler = new MockScheduler(time) + var replicaManager: TestReplicaManager = _ + var zkClient: KafkaZkClient = _ + val serverProps = TestUtils.createBrokerConfig(nodeId = 0, zkConnect = "") + val random = new Random + + @BeforeEach + def setUp(): Unit = { + + replicaManager = EasyMock.partialMockBuilder(classOf[TestReplicaManager]).createMock() + replicaManager.createDelayedProducePurgatory(timer) + + zkClient = EasyMock.createNiceMock(classOf[KafkaZkClient]) + } + + @AfterEach + def tearDown(): Unit = { + EasyMock.reset(replicaManager) + if (executor != null) + executor.shutdownNow() + } + + /** + * Verify that concurrent operations run in the normal sequence produce the expected results. + */ + def verifyConcurrentOperations(createMembers: String => Set[M], operations: Seq[Operation]): Unit = { + OrderedOperationSequence(createMembers("verifyConcurrentOperations"), operations).run() + } + + /** + * Verify that arbitrary operations run in some random sequence don't leave the coordinator + * in a bad state. Operations in the normal sequence should continue to work as expected. 
+ */ + def verifyConcurrentRandomSequences(createMembers: String => Set[M], operations: Seq[Operation]): Unit = { + EasyMock.reset(replicaManager) + for (i <- 0 to 10) { + // Run some random operations + RandomOperationSequence(createMembers(s"random$i"), operations).run() + + // Check that proper sequences still work correctly + OrderedOperationSequence(createMembers(s"ordered$i"), operations).run() + } + } + + def verifyConcurrentActions(actions: Set[Action]): Unit = { + val futures = actions.map(executor.submit) + futures.map(_.get) + enableCompletion() + actions.foreach(_.await()) + } + + def enableCompletion(): Unit = { + replicaManager.tryCompleteActions() + scheduler.tick() + } + + abstract class OperationSequence(members: Set[M], operations: Seq[Operation]) { + def actionSequence: Seq[Set[Action]] + def run(): Unit = { + actionSequence.foreach(verifyConcurrentActions) + } + } + + case class OrderedOperationSequence(members: Set[M], operations: Seq[Operation]) + extends OperationSequence(members, operations) { + override def actionSequence: Seq[Set[Action]] = { + operations.map { op => + members.map(op.actionWithVerify) + } + } + } + + case class RandomOperationSequence(members: Set[M], operations: Seq[Operation]) + extends OperationSequence(members, operations) { + val opCount = operations.length + def actionSequence: Seq[Set[Action]] = { + (0 to opCount).map { _ => + members.map { member => + val op = operations(random.nextInt(opCount)) + op.actionNoVerify(member) // Don't wait or verify since these operations may block + } + } + } + } + + abstract class Operation { + def run(member: M): Unit + def awaitAndVerify(member: M): Unit + def actionWithVerify(member: M): Action = { + new Action() { + def run(): Unit = Operation.this.run(member) + def await(): Unit = awaitAndVerify(member) + } + } + def actionNoVerify(member: M): Action = { + new Action() { + def run(): Unit = Operation.this.run(member) + def await(): Unit = timer.advanceClock(100) // Don't wait since operation may block + } + } + } +} + +object AbstractCoordinatorConcurrencyTest { + + trait Action extends Runnable { + def await(): Unit + } + + trait CoordinatorMember { + } + + class TestReplicaManager extends ReplicaManager( + null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, None, null) { + + @volatile var logs: mutable.Map[TopicPartition, (UnifiedLog, Long)] = _ + var producePurgatory: DelayedOperationPurgatory[DelayedProduce] = _ + var watchKeys: mutable.Set[TopicPartitionOperationKey] = _ + + def createDelayedProducePurgatory(timer: MockTimer): Unit = { + producePurgatory = new DelayedOperationPurgatory[DelayedProduce]("Produce", timer, 1, reaperEnabled = false) + watchKeys = Collections.newSetFromMap(new ConcurrentHashMap[TopicPartitionOperationKey, java.lang.Boolean]()).asScala + } + + override def tryCompleteActions(): Unit = watchKeys.map(producePurgatory.checkAndComplete) + + override def appendRecords(timeout: Long, + requiredAcks: Short, + internalTopicsAllowed: Boolean, + origin: AppendOrigin, + entriesPerPartition: Map[TopicPartition, MemoryRecords], + responseCallback: Map[TopicPartition, PartitionResponse] => Unit, + delayedProduceLock: Option[Lock] = None, + processingStatsCallback: Map[TopicPartition, RecordConversionStats] => Unit = _ => (), + requestLocal: RequestLocal = RequestLocal.NoCaching): Unit = { + + if (entriesPerPartition.isEmpty) + return + val produceMetadata = ProduceMetadata(1, entriesPerPartition.map { + case (tp, _) => + (tp, 
ProducePartitionStatus(0L, new PartitionResponse(Errors.NONE, 0L, RecordBatch.NO_TIMESTAMP, 0L))) + }) + val delayedProduce = new DelayedProduce(5, produceMetadata, this, responseCallback, delayedProduceLock) { + // Complete produce requests after a few attempts to trigger delayed produce from different threads + val completeAttempts = new AtomicInteger + override def tryComplete(): Boolean = { + if (completeAttempts.incrementAndGet() >= 3) + forceComplete() + else + false + } + override def onComplete(): Unit = { + responseCallback(entriesPerPartition.map { + case (tp, _) => + (tp, new PartitionResponse(Errors.NONE, 0L, RecordBatch.NO_TIMESTAMP, 0L)) + }) + } + } + val producerRequestKeys = entriesPerPartition.keys.map(TopicPartitionOperationKey(_)).toSeq + watchKeys ++= producerRequestKeys + producePurgatory.tryCompleteElseWatch(delayedProduce, producerRequestKeys) + } + + override def getMagic(topicPartition: TopicPartition): Option[Byte] = { + Some(RecordBatch.MAGIC_VALUE_V2) + } + + def getOrCreateLogs(): mutable.Map[TopicPartition, (UnifiedLog, Long)] = { + if (logs == null) + logs = mutable.Map[TopicPartition, (UnifiedLog, Long)]() + logs + } + + def updateLog(topicPartition: TopicPartition, log: UnifiedLog, endOffset: Long): Unit = { + getOrCreateLogs().put(topicPartition, (log, endOffset)) + } + + override def getLogConfig(topicPartition: TopicPartition): Option[LogConfig] = { + getOrCreateLogs().get(topicPartition).map(_._1.config) + } + + override def getLog(topicPartition: TopicPartition): Option[UnifiedLog] = + getOrCreateLogs().get(topicPartition).map(l => l._1) + + override def getLogEndOffset(topicPartition: TopicPartition): Option[Long] = + getOrCreateLogs().get(topicPartition).map(l => l._2) + } +} diff --git a/core/src/test/scala/unit/kafka/coordinator/group/GroupCoordinatorConcurrencyTest.scala b/core/src/test/scala/unit/kafka/coordinator/group/GroupCoordinatorConcurrencyTest.scala new file mode 100644 index 0000000..2ef487c --- /dev/null +++ b/core/src/test/scala/unit/kafka/coordinator/group/GroupCoordinatorConcurrencyTest.scala @@ -0,0 +1,408 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package kafka.coordinator.group + +import java.util.Properties +import java.util.concurrent.locks.{Lock, ReentrantLock} +import java.util.concurrent.{ConcurrentHashMap, TimeUnit} + +import kafka.common.OffsetAndMetadata +import kafka.coordinator.AbstractCoordinatorConcurrencyTest +import kafka.coordinator.AbstractCoordinatorConcurrencyTest._ +import kafka.coordinator.group.GroupCoordinatorConcurrencyTest._ +import kafka.server.{DelayedOperationPurgatory, KafkaConfig} +import org.apache.kafka.common.TopicPartition +import org.apache.kafka.common.internals.Topic +import org.apache.kafka.common.metrics.Metrics +import org.apache.kafka.common.message.LeaveGroupRequestData.MemberIdentity +import org.apache.kafka.common.protocol.Errors +import org.apache.kafka.common.requests.{JoinGroupRequest, OffsetFetchResponse} +import org.apache.kafka.common.utils.Time +import org.easymock.EasyMock +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.{AfterEach, BeforeEach, Test} + +import scala.collection._ +import scala.concurrent.duration.Duration +import scala.concurrent.{Await, Future, Promise, TimeoutException} + +class GroupCoordinatorConcurrencyTest extends AbstractCoordinatorConcurrencyTest[GroupMember] { + + private val protocolType = "consumer" + private val protocolName = "range" + private val metadata = Array[Byte]() + private val protocols = List((protocolName, metadata)) + + private val nGroups = nThreads * 10 + private val nMembersPerGroup = nThreads * 5 + private val numPartitions = 2 + + private val allOperations = Seq( + new JoinGroupOperation, + new SyncGroupOperation, + new OffsetFetchOperation, + new CommitOffsetsOperation, + new HeartbeatOperation, + new LeaveGroupOperation + ) + + var heartbeatPurgatory: DelayedOperationPurgatory[DelayedHeartbeat] = _ + var rebalancePurgatory: DelayedOperationPurgatory[DelayedRebalance] = _ + var groupCoordinator: GroupCoordinator = _ + + @BeforeEach + override def setUp(): Unit = { + super.setUp() + + EasyMock.expect(zkClient.getTopicPartitionCount(Topic.GROUP_METADATA_TOPIC_NAME)) + .andReturn(Some(numPartitions)) + .anyTimes() + EasyMock.replay(zkClient) + + serverProps.setProperty(KafkaConfig.GroupMinSessionTimeoutMsProp, ConsumerMinSessionTimeout.toString) + serverProps.setProperty(KafkaConfig.GroupMaxSessionTimeoutMsProp, ConsumerMaxSessionTimeout.toString) + serverProps.setProperty(KafkaConfig.GroupInitialRebalanceDelayMsProp, GroupInitialRebalanceDelay.toString) + + val config = KafkaConfig.fromProps(serverProps) + + heartbeatPurgatory = new DelayedOperationPurgatory[DelayedHeartbeat]("Heartbeat", timer, config.brokerId, reaperEnabled = false) + rebalancePurgatory = new DelayedOperationPurgatory[DelayedRebalance]("Rebalance", timer, config.brokerId, reaperEnabled = false) + + groupCoordinator = GroupCoordinator(config, replicaManager, heartbeatPurgatory, rebalancePurgatory, timer.time, new Metrics()) + groupCoordinator.startup(() => zkClient.getTopicPartitionCount(Topic.GROUP_METADATA_TOPIC_NAME).getOrElse(config.offsetsTopicPartitions), + false) + } + + @AfterEach + override def tearDown(): Unit = { + try { + if (groupCoordinator != null) + groupCoordinator.shutdown() + } finally { + super.tearDown() + } + } + + def createGroupMembers(groupPrefix: String): Set[GroupMember] = { + (0 until nGroups).flatMap { i => + new Group(s"$groupPrefix$i", nMembersPerGroup, groupCoordinator, replicaManager).members + }.toSet + } + + @Test + def testConcurrentGoodPathSequence(): Unit = { + 
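+    // Good path: every member of every group concurrently runs join, sync, offset fetch,
+    // commit, heartbeat and leave in the normal order (see allOperations above).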
verifyConcurrentOperations(createGroupMembers, allOperations) + } + + @Test + def testConcurrentTxnGoodPathSequence(): Unit = { + verifyConcurrentOperations(createGroupMembers, Seq( + new JoinGroupOperation, + new SyncGroupOperation, + new OffsetFetchOperation, + new CommitTxnOffsetsOperation, + new CompleteTxnOperation, + new HeartbeatOperation, + new LeaveGroupOperation + )) + } + + @Test + def testConcurrentRandomSequence(): Unit = { + /** + * handleTxnCommitOffsets does not complete delayed requests now so it causes error if handleTxnCompletion is executed + * before completing delayed request. In random mode, we use this global lock to prevent such an error. + */ + val lock = new ReentrantLock() + verifyConcurrentRandomSequences(createGroupMembers, Seq( + new JoinGroupOperation, + new SyncGroupOperation, + new OffsetFetchOperation, + new CommitTxnOffsetsOperation(lock = Some(lock)), + new CompleteTxnOperation(lock = Some(lock)), + new HeartbeatOperation, + new LeaveGroupOperation + )) + } + + @Test + def testConcurrentJoinGroupEnforceGroupMaxSize(): Unit = { + val groupMaxSize = 1 + val newProperties = new Properties + newProperties.put(KafkaConfig.GroupMaxSizeProp, groupMaxSize.toString) + val config = KafkaConfig.fromProps(serverProps, newProperties) + + if (groupCoordinator != null) + groupCoordinator.shutdown() + groupCoordinator = GroupCoordinator(config, replicaManager, heartbeatPurgatory, + rebalancePurgatory, timer.time, new Metrics()) + groupCoordinator.startup(() => zkClient.getTopicPartitionCount(Topic.GROUP_METADATA_TOPIC_NAME).getOrElse(config.offsetsTopicPartitions), + false) + + val members = new Group(s"group", nMembersPerGroup, groupCoordinator, replicaManager) + .members + val joinOp = new JoinGroupOperation() + + verifyConcurrentActions(members.toSet.map(joinOp.actionNoVerify)) + + val errors = members.map { member => + val joinGroupResult = joinOp.await(member, DefaultRebalanceTimeout) + joinGroupResult.error + } + + assertEquals(groupMaxSize, errors.count(_ == Errors.NONE)) + assertEquals(members.size-groupMaxSize, errors.count(_ == Errors.GROUP_MAX_SIZE_REACHED)) + } + + abstract class GroupOperation[R, C] extends Operation { + val responseFutures = new ConcurrentHashMap[GroupMember, Future[R]]() + + def setUpCallback(member: GroupMember): C = { + val responsePromise = Promise[R]() + val responseFuture = responsePromise.future + responseFutures.put(member, responseFuture) + responseCallback(responsePromise) + } + def responseCallback(responsePromise: Promise[R]): C + + override def run(member: GroupMember): Unit = { + val responseCallback = setUpCallback(member) + runWithCallback(member, responseCallback) + } + + def runWithCallback(member: GroupMember, responseCallback: C): Unit + + def await(member: GroupMember, timeoutMs: Long): R = { + var retries = (timeoutMs + 10) / 10 + val responseFuture = responseFutures.get(member) + while (retries > 0) { + timer.advanceClock(10) + try { + return Await.result(responseFuture, Duration(10, TimeUnit.MILLISECONDS)) + } catch { + case _: TimeoutException => + } + retries -= 1 + } + throw new TimeoutException(s"Operation did not complete within $timeoutMs millis") + } + } + + class JoinGroupOperation extends GroupOperation[JoinGroupCallbackParams, JoinGroupCallback] { + override def responseCallback(responsePromise: Promise[JoinGroupCallbackParams]): JoinGroupCallback = { + val callback: JoinGroupCallback = responsePromise.success(_) + callback + } + override def runWithCallback(member: GroupMember, responseCallback: 
JoinGroupCallback): Unit = { + groupCoordinator.handleJoinGroup(member.groupId, member.memberId, None, requireKnownMemberId = false, "clientId", "clientHost", + DefaultRebalanceTimeout, DefaultSessionTimeout, + protocolType, protocols, responseCallback) + replicaManager.tryCompleteActions() + } + override def awaitAndVerify(member: GroupMember): Unit = { + val joinGroupResult = await(member, DefaultRebalanceTimeout) + assertEquals(Errors.NONE, joinGroupResult.error) + member.memberId = joinGroupResult.memberId + member.generationId = joinGroupResult.generationId + } + } + + class SyncGroupOperation extends GroupOperation[SyncGroupCallbackParams, SyncGroupCallback] { + override def responseCallback(responsePromise: Promise[SyncGroupCallbackParams]): SyncGroupCallback = { + val callback: SyncGroupCallback = syncGroupResult => + responsePromise.success(syncGroupResult.error, syncGroupResult.memberAssignment) + callback + } + override def runWithCallback(member: GroupMember, responseCallback: SyncGroupCallback): Unit = { + if (member.leader) { + groupCoordinator.handleSyncGroup(member.groupId, member.generationId, member.memberId, + Some(protocolType), Some(protocolName), member.groupInstanceId, member.group.assignment, responseCallback) + } else { + groupCoordinator.handleSyncGroup(member.groupId, member.generationId, member.memberId, + Some(protocolType), Some(protocolName), member.groupInstanceId, Map.empty[String, Array[Byte]], responseCallback) + } + replicaManager.tryCompleteActions() + } + override def awaitAndVerify(member: GroupMember): Unit = { + val result = await(member, DefaultSessionTimeout) + assertEquals(Errors.NONE, result._1) + assertNotNull(result._2) + assertEquals(0, result._2.length) + } + } + + class HeartbeatOperation extends GroupOperation[HeartbeatCallbackParams, HeartbeatCallback] { + override def responseCallback(responsePromise: Promise[HeartbeatCallbackParams]): HeartbeatCallback = { + val callback: HeartbeatCallback = error => responsePromise.success(error) + callback + } + override def runWithCallback(member: GroupMember, responseCallback: HeartbeatCallback): Unit = { + groupCoordinator.handleHeartbeat(member.groupId, member.memberId, + member.groupInstanceId, member.generationId, responseCallback) + replicaManager.tryCompleteActions() + } + override def awaitAndVerify(member: GroupMember): Unit = { + val error = await(member, DefaultSessionTimeout) + assertEquals(Errors.NONE, error) + } + } + + class OffsetFetchOperation extends GroupOperation[OffsetFetchCallbackParams, OffsetFetchCallback] { + override def responseCallback(responsePromise: Promise[OffsetFetchCallbackParams]): OffsetFetchCallback = { + val callback: OffsetFetchCallback = (error, offsets) => responsePromise.success(error, offsets) + callback + } + override def runWithCallback(member: GroupMember, responseCallback: OffsetFetchCallback): Unit = { + val (error, partitionData) = groupCoordinator.handleFetchOffsets(member.groupId, requireStable = true, None) + replicaManager.tryCompleteActions() + responseCallback(error, partitionData) + } + override def awaitAndVerify(member: GroupMember): Unit = { + val result = await(member, 500) + assertEquals(Errors.NONE, result._1) + assertEquals(Map.empty, result._2) + } + } + + class CommitOffsetsOperation extends GroupOperation[CommitOffsetCallbackParams, CommitOffsetCallback] { + override def responseCallback(responsePromise: Promise[CommitOffsetCallbackParams]): CommitOffsetCallback = { + val callback: CommitOffsetCallback = offsets => 
responsePromise.success(offsets) + callback + } + override def runWithCallback(member: GroupMember, responseCallback: CommitOffsetCallback): Unit = { + val tp = new TopicPartition("topic", 0) + val offsets = immutable.Map(tp -> OffsetAndMetadata(1, "", Time.SYSTEM.milliseconds())) + groupCoordinator.handleCommitOffsets(member.groupId, member.memberId, + member.groupInstanceId, member.generationId, offsets, responseCallback) + replicaManager.tryCompleteActions() + } + override def awaitAndVerify(member: GroupMember): Unit = { + val offsets = await(member, 500) + offsets.foreach { case (_, error) => assertEquals(Errors.NONE, error) } + } + } + + class CommitTxnOffsetsOperation(lock: Option[Lock] = None) extends CommitOffsetsOperation { + override def runWithCallback(member: GroupMember, responseCallback: CommitOffsetCallback): Unit = { + val tp = new TopicPartition("topic", 0) + val offsets = immutable.Map(tp -> OffsetAndMetadata(1, "", Time.SYSTEM.milliseconds())) + val producerId = 1000L + val producerEpoch : Short = 2 + // When transaction offsets are appended to the log, transactions may be scheduled for + // completion. Since group metadata locks are acquired for transaction completion, include + // this in the callback to test that there are no deadlocks. + def callbackWithTxnCompletion(errors: Map[TopicPartition, Errors]): Unit = { + val offsetsPartitions = (0 to numPartitions).map(new TopicPartition(Topic.GROUP_METADATA_TOPIC_NAME, _)) + groupCoordinator.groupManager.scheduleHandleTxnCompletion(producerId, + offsetsPartitions.map(_.partition).toSet, isCommit = random.nextBoolean) + responseCallback(errors) + } + lock.foreach(_.lock()) + try { + groupCoordinator.handleTxnCommitOffsets(member.group.groupId, producerId, producerEpoch, + JoinGroupRequest.UNKNOWN_MEMBER_ID, Option.empty, JoinGroupRequest.UNKNOWN_GENERATION_ID, + offsets, callbackWithTxnCompletion) + replicaManager.tryCompleteActions() + } finally lock.foreach(_.unlock()) + } + } + + class CompleteTxnOperation(lock: Option[Lock] = None) extends GroupOperation[CompleteTxnCallbackParams, CompleteTxnCallback] { + override def responseCallback(responsePromise: Promise[CompleteTxnCallbackParams]): CompleteTxnCallback = { + val callback: CompleteTxnCallback = error => responsePromise.success(error) + callback + } + override def runWithCallback(member: GroupMember, responseCallback: CompleteTxnCallback): Unit = { + val producerId = 1000L + val offsetsPartitions = (0 to numPartitions).map(new TopicPartition(Topic.GROUP_METADATA_TOPIC_NAME, _)) + lock.foreach(_.lock()) + try { + groupCoordinator.groupManager.handleTxnCompletion(producerId, + offsetsPartitions.map(_.partition).toSet, isCommit = random.nextBoolean) + responseCallback(Errors.NONE) + } finally lock.foreach(_.unlock()) + + } + override def awaitAndVerify(member: GroupMember): Unit = { + val error = await(member, 500) + assertEquals(Errors.NONE, error) + } + } + + class LeaveGroupOperation extends GroupOperation[LeaveGroupCallbackParams, LeaveGroupCallback] { + override def responseCallback(responsePromise: Promise[LeaveGroupCallbackParams]): LeaveGroupCallback = { + val callback: LeaveGroupCallback = result => responsePromise.success(result) + callback + } + override def runWithCallback(member: GroupMember, responseCallback: LeaveGroupCallback): Unit = { + val memberIdentity = new MemberIdentity() + .setMemberId(member.memberId) + groupCoordinator.handleLeaveGroup(member.group.groupId, List(memberIdentity), responseCallback) + } + override def awaitAndVerify(member: 
GroupMember): Unit = { + val leaveGroupResult = await(member, DefaultSessionTimeout) + + val memberResponses = leaveGroupResult.memberResponses + GroupCoordinatorTest.verifyLeaveGroupResult(leaveGroupResult, Errors.NONE, List(Errors.NONE)) + assertEquals(member.memberId, memberResponses.head.memberId) + assertEquals(None, memberResponses.head.groupInstanceId) + } + } +} + +object GroupCoordinatorConcurrencyTest { + + type JoinGroupCallbackParams = JoinGroupResult + type JoinGroupCallback = JoinGroupResult => Unit + type SyncGroupCallbackParams = (Errors, Array[Byte]) + type SyncGroupCallback = SyncGroupResult => Unit + type HeartbeatCallbackParams = Errors + type HeartbeatCallback = Errors => Unit + type OffsetFetchCallbackParams = (Errors, Map[TopicPartition, OffsetFetchResponse.PartitionData]) + type OffsetFetchCallback = (Errors, Map[TopicPartition, OffsetFetchResponse.PartitionData]) => Unit + type CommitOffsetCallbackParams = Map[TopicPartition, Errors] + type CommitOffsetCallback = Map[TopicPartition, Errors] => Unit + type LeaveGroupCallbackParams = LeaveGroupResult + type LeaveGroupCallback = LeaveGroupResult => Unit + type CompleteTxnCallbackParams = Errors + type CompleteTxnCallback = Errors => Unit + + private val ConsumerMinSessionTimeout = 10 + private val ConsumerMaxSessionTimeout = 120 * 1000 + private val DefaultRebalanceTimeout = 60 * 1000 + private val DefaultSessionTimeout = 60 * 1000 + private val GroupInitialRebalanceDelay = 50 + + class Group(val groupId: String, nMembers: Int, + groupCoordinator: GroupCoordinator, replicaManager: TestReplicaManager) { + val groupPartitionId = groupCoordinator.partitionFor(groupId) + groupCoordinator.groupManager.addPartitionOwnership(groupPartitionId) + val members = (0 until nMembers).map { i => + new GroupMember(this, groupPartitionId, i == 0) + } + def assignment: Map[String, Array[Byte]] = members.map { m => (m.memberId, Array[Byte]()) }.toMap + } + + class GroupMember(val group: Group, val groupPartitionId: Int, val leader: Boolean) extends CoordinatorMember { + @volatile var memberId: String = JoinGroupRequest.UNKNOWN_MEMBER_ID + @volatile var groupInstanceId: Option[String] = None + @volatile var generationId: Int = -1 + def groupId: String = group.groupId + } + +} diff --git a/core/src/test/scala/unit/kafka/coordinator/group/GroupCoordinatorTest.scala b/core/src/test/scala/unit/kafka/coordinator/group/GroupCoordinatorTest.scala new file mode 100644 index 0000000..0784259 --- /dev/null +++ b/core/src/test/scala/unit/kafka/coordinator/group/GroupCoordinatorTest.scala @@ -0,0 +1,4279 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package kafka.coordinator.group + +import java.util.Optional +import kafka.common.OffsetAndMetadata +import kafka.server.{DelayedOperationPurgatory, HostedPartition, KafkaConfig, ReplicaManager, RequestLocal} +import kafka.utils._ +import kafka.utils.timer.MockTimer +import org.apache.kafka.common.TopicPartition +import org.apache.kafka.common.protocol.Errors +import org.apache.kafka.common.record.{MemoryRecords, RecordBatch} +import org.apache.kafka.common.requests.ProduceResponse.PartitionResponse +import org.apache.kafka.common.requests.{JoinGroupRequest, OffsetCommitRequest, OffsetFetchResponse, TransactionResult} +import org.easymock.{Capture, EasyMock, IAnswer} + +import java.util.concurrent.TimeUnit +import java.util.concurrent.locks.ReentrantLock +import kafka.cluster.Partition +import kafka.log.AppendOrigin +import kafka.zk.KafkaZkClient +import org.apache.kafka.clients.consumer.ConsumerPartitionAssignor.Subscription +import org.apache.kafka.clients.consumer.internals.ConsumerProtocol +import org.apache.kafka.common.internals.Topic +import org.apache.kafka.common.metrics.Metrics +import org.apache.kafka.common.message.LeaveGroupRequestData.MemberIdentity +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.{AfterEach, BeforeEach, Test} + +import scala.jdk.CollectionConverters._ +import scala.collection.{Seq, mutable} +import scala.collection.mutable.ArrayBuffer +import scala.concurrent.duration.Duration +import scala.concurrent.{Await, Future, Promise, TimeoutException} + +class GroupCoordinatorTest { + import GroupCoordinatorTest._ + + type JoinGroupCallback = JoinGroupResult => Unit + type SyncGroupCallback = SyncGroupResult => Unit + type HeartbeatCallbackParams = Errors + type HeartbeatCallback = Errors => Unit + type CommitOffsetCallbackParams = Map[TopicPartition, Errors] + type CommitOffsetCallback = Map[TopicPartition, Errors] => Unit + type LeaveGroupCallback = LeaveGroupResult => Unit + + val ClientId = "consumer-test" + val ClientHost = "localhost" + val GroupMinSessionTimeout = 10 + val GroupMaxSessionTimeout = 10 * 60 * 1000 + val GroupMaxSize = 4 + val DefaultRebalanceTimeout = 500 + val DefaultSessionTimeout = 500 + val GroupInitialRebalanceDelay = 50 + var timer: MockTimer = null + var groupCoordinator: GroupCoordinator = null + var replicaManager: ReplicaManager = null + var scheduler: KafkaScheduler = null + var zkClient: KafkaZkClient = null + + private val groupId = "groupId" + private val protocolType = "consumer" + private val protocolName = "range" + private val memberId = "memberId" + private val groupInstanceId = "groupInstanceId" + private val leaderInstanceId = "leader" + private val followerInstanceId = "follower" + private val invalidMemberId = "invalidMember" + private val metadata = Array[Byte]() + private val protocols = List((protocolName, metadata)) + private val protocolSuperset = List((protocolName, metadata), ("roundrobin", metadata)) + private val requireStable = true + private var groupPartitionId: Int = -1 + + // we use this string value since its hashcode % #.partitions is different + private val otherGroupId = "otherGroup" + + @BeforeEach + def setUp(): Unit = { + val props = TestUtils.createBrokerConfig(nodeId = 0, zkConnect = "") + props.setProperty(KafkaConfig.GroupMinSessionTimeoutMsProp, GroupMinSessionTimeout.toString) + props.setProperty(KafkaConfig.GroupMaxSessionTimeoutMsProp, GroupMaxSessionTimeout.toString) + props.setProperty(KafkaConfig.GroupMaxSizeProp, GroupMaxSize.toString) + 
props.setProperty(KafkaConfig.GroupInitialRebalanceDelayMsProp, GroupInitialRebalanceDelay.toString) + // make two partitions of the group topic to make sure some partitions are not owned by the coordinator + val ret = mutable.Map[String, Map[Int, Seq[Int]]]() + ret += (Topic.GROUP_METADATA_TOPIC_NAME -> Map(0 -> Seq(1), 1 -> Seq(1))) + + replicaManager = EasyMock.createNiceMock(classOf[ReplicaManager]) + + zkClient = EasyMock.createNiceMock(classOf[KafkaZkClient]) + // make two partitions of the group topic to make sure some partitions are not owned by the coordinator + EasyMock.expect(zkClient.getTopicPartitionCount(Topic.GROUP_METADATA_TOPIC_NAME)).andReturn(Some(2)) + EasyMock.replay(zkClient) + + timer = new MockTimer + + val config = KafkaConfig.fromProps(props) + + val heartbeatPurgatory = new DelayedOperationPurgatory[DelayedHeartbeat]("Heartbeat", timer, config.brokerId, reaperEnabled = false) + val rebalancePurgatory = new DelayedOperationPurgatory[DelayedRebalance]("Rebalance", timer, config.brokerId, reaperEnabled = false) + + groupCoordinator = GroupCoordinator(config, replicaManager, heartbeatPurgatory, rebalancePurgatory, timer.time, new Metrics()) + groupCoordinator.startup(() => zkClient.getTopicPartitionCount(Topic.GROUP_METADATA_TOPIC_NAME).getOrElse(config.offsetsTopicPartitions), + enableMetadataExpiration = false) + + // add the partition into the owned partition list + groupPartitionId = groupCoordinator.partitionFor(groupId) + groupCoordinator.groupManager.addPartitionOwnership(groupPartitionId) + } + + @AfterEach + def tearDown(): Unit = { + EasyMock.reset(replicaManager) + if (groupCoordinator != null) + groupCoordinator.shutdown() + } + + @Test + def testRequestHandlingWhileLoadingInProgress(): Unit = { + val otherGroupPartitionId = groupCoordinator.groupManager.partitionFor(otherGroupId) + assertTrue(otherGroupPartitionId != groupPartitionId) + + groupCoordinator.groupManager.addLoadingPartition(otherGroupPartitionId) + assertTrue(groupCoordinator.groupManager.isGroupLoading(otherGroupId)) + + // Dynamic Member JoinGroup + var joinGroupResponse: Option[JoinGroupResult] = None + groupCoordinator.handleJoinGroup(otherGroupId, memberId, None, true, "clientId", "clientHost", 60000, 10000, "consumer", + List("range" -> new Array[Byte](0)), result => { joinGroupResponse = Some(result)}) + assertEquals(Some(Errors.COORDINATOR_LOAD_IN_PROGRESS), joinGroupResponse.map(_.error)) + + // Static Member JoinGroup + groupCoordinator.handleJoinGroup(otherGroupId, memberId, Some("groupInstanceId"), false, "clientId", "clientHost", 60000, 10000, "consumer", + List("range" -> new Array[Byte](0)), result => { joinGroupResponse = Some(result)}) + assertEquals(Some(Errors.COORDINATOR_LOAD_IN_PROGRESS), joinGroupResponse.map(_.error)) + + // SyncGroup + var syncGroupResponse: Option[Errors] = None + groupCoordinator.handleSyncGroup(otherGroupId, 1, memberId, Some("consumer"), Some("range"), None, Map.empty[String, Array[Byte]], + syncGroupResult => syncGroupResponse = Some(syncGroupResult.error)) + assertEquals(Some(Errors.REBALANCE_IN_PROGRESS), syncGroupResponse) + + // OffsetCommit + val topicPartition = new TopicPartition("foo", 0) + var offsetCommitErrors = Map.empty[TopicPartition, Errors] + groupCoordinator.handleCommitOffsets(otherGroupId, memberId, None, 1, + Map(topicPartition -> offsetAndMetadata(15L)), result => { offsetCommitErrors = result }) + assertEquals(Some(Errors.COORDINATOR_LOAD_IN_PROGRESS), offsetCommitErrors.get(topicPartition)) + + // Heartbeat + var 
heartbeatError: Option[Errors] = None + groupCoordinator.handleHeartbeat(otherGroupId, memberId, None, 1, error => { heartbeatError = Some(error) }) + assertEquals(Some(Errors.NONE), heartbeatError) + + // DescribeGroups + val (describeGroupError, _) = groupCoordinator.handleDescribeGroup(otherGroupId) + assertEquals(Errors.COORDINATOR_LOAD_IN_PROGRESS, describeGroupError) + + // ListGroups + val (listGroupsError, _) = groupCoordinator.handleListGroups(Set()) + assertEquals(Errors.COORDINATOR_LOAD_IN_PROGRESS, listGroupsError) + + // DeleteGroups + val deleteGroupsErrors = groupCoordinator.handleDeleteGroups(Set(otherGroupId)) + assertEquals(Some(Errors.COORDINATOR_LOAD_IN_PROGRESS), deleteGroupsErrors.get(otherGroupId)) + + // Check that non-loading groups are still accessible + assertEquals(Errors.NONE, groupCoordinator.handleDescribeGroup(groupId)._1) + + // After loading, we should be able to access the group + val otherGroupMetadataTopicPartition = new TopicPartition(Topic.GROUP_METADATA_TOPIC_NAME, otherGroupPartitionId) + EasyMock.reset(replicaManager) + EasyMock.expect(replicaManager.getLog(otherGroupMetadataTopicPartition)).andReturn(None) + EasyMock.replay(replicaManager) + // Call removeGroupsAndOffsets so that partition removed from loadingPartitions + groupCoordinator.groupManager.removeGroupsAndOffsets(otherGroupMetadataTopicPartition, Some(1), group => {}) + groupCoordinator.groupManager.loadGroupsAndOffsets(otherGroupMetadataTopicPartition, 1, group => {}, 0L) + assertEquals(Errors.NONE, groupCoordinator.handleDescribeGroup(otherGroupId)._1) + } + + @Test + def testOffsetsRetentionMsIntegerOverflow(): Unit = { + val props = TestUtils.createBrokerConfig(nodeId = 0, zkConnect = "") + props.setProperty(KafkaConfig.OffsetsRetentionMinutesProp, Integer.MAX_VALUE.toString) + val config = KafkaConfig.fromProps(props) + val offsetConfig = GroupCoordinator.offsetConfig(config) + assertEquals(offsetConfig.offsetsRetentionMs, Integer.MAX_VALUE * 60L * 1000L) + } + + @Test + def testJoinGroupWrongCoordinator(): Unit = { + val memberId = JoinGroupRequest.UNKNOWN_MEMBER_ID + + var joinGroupResult = dynamicJoinGroup(otherGroupId, memberId, protocolType, protocols) + assertEquals(Errors.NOT_COORDINATOR, joinGroupResult.error) + + EasyMock.reset(replicaManager) + joinGroupResult = staticJoinGroup(otherGroupId, memberId, groupInstanceId, protocolType, protocols) + assertEquals(Errors.NOT_COORDINATOR, joinGroupResult.error) + } + + @Test + def testJoinGroupShouldReceiveErrorIfGroupOverMaxSize(): Unit = { + val futures = ArrayBuffer[Future[JoinGroupResult]]() + val rebalanceTimeout = GroupInitialRebalanceDelay * 2 + + for (i <- 1.to(GroupMaxSize)) { + futures += sendJoinGroup(groupId, JoinGroupRequest.UNKNOWN_MEMBER_ID, protocolType, protocols, rebalanceTimeout = rebalanceTimeout) + if (i != 1) + timer.advanceClock(GroupInitialRebalanceDelay) + EasyMock.reset(replicaManager) + } + // advance clock beyond rebalanceTimeout + timer.advanceClock(GroupInitialRebalanceDelay + 1) + for (future <- futures) { + assertEquals(Errors.NONE, await(future, 1).error) + } + + // Should receive an error since the group is full + val errorFuture = sendJoinGroup(groupId, JoinGroupRequest.UNKNOWN_MEMBER_ID, protocolType, protocols, rebalanceTimeout = rebalanceTimeout) + assertEquals(Errors.GROUP_MAX_SIZE_REACHED, await(errorFuture, 1).error) + } + + @Test + def testDynamicMembersJoinGroupWithMaxSizeAndRequiredKnownMember(): Unit = { + val requiredKnownMemberId = true + val nbMembers = GroupMaxSize + 1 + + // First 
JoinRequests + var futures = 1.to(nbMembers).map { _ => + EasyMock.reset(replicaManager) + sendJoinGroup(groupId, JoinGroupRequest.UNKNOWN_MEMBER_ID, protocolType, protocols, + None, DefaultSessionTimeout, DefaultRebalanceTimeout, requiredKnownMemberId) + } + + // Get back the assigned member ids + val memberIds = futures.map(await(_, 1).memberId) + + // Second JoinRequests + futures = memberIds.map { memberId => + EasyMock.reset(replicaManager) + sendJoinGroup(groupId, memberId, protocolType, protocols, + None, DefaultSessionTimeout, DefaultRebalanceTimeout, requiredKnownMemberId) + } + + // advance clock by GroupInitialRebalanceDelay to complete first InitialDelayedJoin + timer.advanceClock(GroupInitialRebalanceDelay + 1) + // advance clock by GroupInitialRebalanceDelay to complete second InitialDelayedJoin + timer.advanceClock(GroupInitialRebalanceDelay + 1) + + // Awaiting results + val errors = futures.map(await(_, DefaultRebalanceTimeout + 1).error) + + assertEquals(GroupMaxSize, errors.count(_ == Errors.NONE)) + assertEquals(nbMembers-GroupMaxSize, errors.count(_ == Errors.GROUP_MAX_SIZE_REACHED)) + + // Members which were accepted can rejoin, others are rejected, while + // completing rebalance + futures = memberIds.map { memberId => + EasyMock.reset(replicaManager) + sendJoinGroup(groupId, memberId, protocolType, protocols, + None, DefaultSessionTimeout, DefaultRebalanceTimeout, requiredKnownMemberId) + } + + // Awaiting results + val rejoinErrors = futures.map(await(_, 1).error) + + assertEquals(errors, rejoinErrors) + } + + @Test + def testDynamicMembersJoinGroupWithMaxSize(): Unit = { + val requiredKnownMemberId = false + val nbMembers = GroupMaxSize + 1 + + // JoinRequests + var futures = 1.to(nbMembers).map { _ => + EasyMock.reset(replicaManager) + sendJoinGroup(groupId, JoinGroupRequest.UNKNOWN_MEMBER_ID, protocolType, protocols, + None, DefaultSessionTimeout, DefaultRebalanceTimeout, requiredKnownMemberId) + } + + // advance clock by GroupInitialRebalanceDelay to complete first InitialDelayedJoin + timer.advanceClock(GroupInitialRebalanceDelay + 1) + // advance clock by GroupInitialRebalanceDelay to complete second InitialDelayedJoin + timer.advanceClock(GroupInitialRebalanceDelay + 1) + + // Awaiting results + val joinGroupResults = futures.map(await(_, DefaultRebalanceTimeout + 1)) + val errors = joinGroupResults.map(_.error) + + assertEquals(GroupMaxSize, errors.count(_ == Errors.NONE)) + assertEquals(nbMembers-GroupMaxSize, errors.count(_ == Errors.GROUP_MAX_SIZE_REACHED)) + + // Members which were accepted can rejoin, others are rejected, while + // completing rebalance + val memberIds = joinGroupResults.map(_.memberId) + futures = memberIds.map { memberId => + EasyMock.reset(replicaManager) + sendJoinGroup(groupId, memberId, protocolType, protocols, + None, DefaultSessionTimeout, DefaultRebalanceTimeout, requiredKnownMemberId) + } + + // Awaiting results + val rejoinErrors = futures.map(await(_, 1).error) + + assertEquals(errors, rejoinErrors) + } + + @Test + def testStaticMembersJoinGroupWithMaxSize(): Unit = { + val nbMembers = GroupMaxSize + 1 + val instanceIds = 1.to(nbMembers).map(i => Some(s"instance-id-$i")) + + // JoinRequests + var futures = instanceIds.map { instanceId => + EasyMock.reset(replicaManager) + sendJoinGroup(groupId, JoinGroupRequest.UNKNOWN_MEMBER_ID, protocolType, protocols, + instanceId, DefaultSessionTimeout, DefaultRebalanceTimeout) + } + + // advance clock by GroupInitialRebalanceDelay to complete first InitialDelayedJoin + 
timer.advanceClock(GroupInitialRebalanceDelay + 1) + // advance clock by GroupInitialRebalanceDelay to complete second InitialDelayedJoin + timer.advanceClock(GroupInitialRebalanceDelay + 1) + + // Awaiting results + val joinGroupResults = futures.map(await(_, DefaultRebalanceTimeout + 1)) + val errors = joinGroupResults.map(_.error) + + assertEquals(GroupMaxSize, errors.count(_ == Errors.NONE)) + assertEquals(nbMembers-GroupMaxSize, errors.count(_ == Errors.GROUP_MAX_SIZE_REACHED)) + + // Members which were accepted can rejoin, others are rejected, while + // completing rebalance + val memberIds = joinGroupResults.map(_.memberId) + futures = instanceIds.zip(memberIds).map { case (instanceId, memberId) => + EasyMock.reset(replicaManager) + sendJoinGroup(groupId, memberId, protocolType, protocols, + instanceId, DefaultSessionTimeout, DefaultRebalanceTimeout) + } + + // Awaiting results + val rejoinErrors = futures.map(await(_, 1).error) + + assertEquals(errors, rejoinErrors) + } + + @Test + def testDynamicMembersCanReJoinGroupWithMaxSizeWhileRebalancing(): Unit = { + val requiredKnownMemberId = true + val nbMembers = GroupMaxSize + 1 + + // First JoinRequests + var futures = 1.to(nbMembers).map { _ => + EasyMock.reset(replicaManager) + sendJoinGroup(groupId, JoinGroupRequest.UNKNOWN_MEMBER_ID, protocolType, protocols, + None, DefaultSessionTimeout, DefaultRebalanceTimeout, requiredKnownMemberId) + } + + // Get back the assigned member ids + val memberIds = futures.map(await(_, 1).memberId) + + // Second JoinRequests + memberIds.map { memberId => + EasyMock.reset(replicaManager) + sendJoinGroup(groupId, memberId, protocolType, protocols, + None, DefaultSessionTimeout, DefaultRebalanceTimeout, requiredKnownMemberId) + } + + // Members can rejoin while rebalancing + futures = memberIds.map { memberId => + EasyMock.reset(replicaManager) + sendJoinGroup(groupId, memberId, protocolType, protocols, + None, DefaultSessionTimeout, DefaultRebalanceTimeout, requiredKnownMemberId) + } + + // advance clock by GroupInitialRebalanceDelay to complete first InitialDelayedJoin + timer.advanceClock(GroupInitialRebalanceDelay + 1) + // advance clock by GroupInitialRebalanceDelay to complete second InitialDelayedJoin + timer.advanceClock(GroupInitialRebalanceDelay + 1) + + // Awaiting results + val errors = futures.map(await(_, DefaultRebalanceTimeout + 1).error) + + assertEquals(GroupMaxSize, errors.count(_ == Errors.NONE)) + assertEquals(nbMembers-GroupMaxSize, errors.count(_ == Errors.GROUP_MAX_SIZE_REACHED)) + } + + @Test + def testLastJoiningMembersAreKickedOutWhenReJoiningGroupWithMaxSize(): Unit = { + val nbMembers = GroupMaxSize + 2 + val group = new GroupMetadata(groupId, Stable, new MockTime()) + val memberIds = 1.to(nbMembers).map(_ => group.generateMemberId(ClientId, None)) + + memberIds.foreach { memberId => + group.add(new MemberMetadata(memberId, None, ClientId, ClientHost, + DefaultRebalanceTimeout, GroupMaxSessionTimeout, protocolType, protocols)) + } + groupCoordinator.groupManager.addGroup(group) + + groupCoordinator.prepareRebalance(group, "") + + val futures = memberIds.map { memberId => + EasyMock.reset(replicaManager) + sendJoinGroup(groupId, memberId, protocolType, protocols, + None, GroupMaxSessionTimeout, DefaultRebalanceTimeout) + } + + // advance clock by GroupInitialRebalanceDelay to complete first InitialDelayedJoin + timer.advanceClock(DefaultRebalanceTimeout + 1) + + // Awaiting results + val errors = futures.map(await(_, DefaultRebalanceTimeout + 1).error) + + 
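+    // Only the first GroupMaxSize members (in join order) survive; the members that joined
+    // last are removed from the group and get GROUP_MAX_SIZE_REACHED.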
assertEquals(Set(Errors.NONE), errors.take(GroupMaxSize).toSet) + assertEquals(Set(Errors.GROUP_MAX_SIZE_REACHED), errors.drop(GroupMaxSize).toSet) + + memberIds.drop(GroupMaxSize).foreach { memberId => + assertFalse(group.has(memberId)) + } + } + + @Test + def testJoinGroupSessionTimeoutTooSmall(): Unit = { + val memberId = JoinGroupRequest.UNKNOWN_MEMBER_ID + + val joinGroupResult = dynamicJoinGroup(groupId, memberId, protocolType, protocols, sessionTimeout = GroupMinSessionTimeout - 1) + assertEquals(Errors.INVALID_SESSION_TIMEOUT, joinGroupResult.error) + } + + @Test + def testJoinGroupSessionTimeoutTooLarge(): Unit = { + val memberId = JoinGroupRequest.UNKNOWN_MEMBER_ID + + val joinGroupResult = dynamicJoinGroup(groupId, memberId, protocolType, protocols, sessionTimeout = GroupMaxSessionTimeout + 1) + assertEquals(Errors.INVALID_SESSION_TIMEOUT, joinGroupResult.error) + } + + @Test + def testJoinGroupUnknownConsumerNewGroup(): Unit = { + var joinGroupResult = dynamicJoinGroup(groupId, memberId, protocolType, protocols) + assertEquals(Errors.UNKNOWN_MEMBER_ID, joinGroupResult.error) + + EasyMock.reset(replicaManager) + joinGroupResult = staticJoinGroup(groupId, memberId, groupInstanceId, protocolType, protocols) + assertEquals(Errors.UNKNOWN_MEMBER_ID, joinGroupResult.error) + } + + @Test + def testInvalidGroupId(): Unit = { + val groupId = "" + val memberId = JoinGroupRequest.UNKNOWN_MEMBER_ID + + val joinGroupResult = dynamicJoinGroup(groupId, memberId, protocolType, protocols) + assertEquals(Errors.INVALID_GROUP_ID, joinGroupResult.error) + } + + @Test + def testValidJoinGroup(): Unit = { + val memberId = JoinGroupRequest.UNKNOWN_MEMBER_ID + + val joinGroupResult = dynamicJoinGroup(groupId, memberId, protocolType, protocols) + assertEquals(Errors.NONE, joinGroupResult.error) + } + + @Test + def testJoinGroupInconsistentProtocolType(): Unit = { + val memberId = JoinGroupRequest.UNKNOWN_MEMBER_ID + val otherMemberId = JoinGroupRequest.UNKNOWN_MEMBER_ID + + val joinGroupResult = dynamicJoinGroup(groupId, memberId, protocolType, protocols) + assertEquals(Errors.NONE, joinGroupResult.error) + + EasyMock.reset(replicaManager) + val otherJoinGroupResult = await(sendJoinGroup(groupId, otherMemberId, "connect", protocols), 1) + assertEquals(Errors.INCONSISTENT_GROUP_PROTOCOL, otherJoinGroupResult.error) + } + + @Test + def testJoinGroupWithEmptyProtocolType(): Unit = { + val memberId = JoinGroupRequest.UNKNOWN_MEMBER_ID + + var joinGroupResult = dynamicJoinGroup(groupId, memberId, "", protocols) + assertEquals(Errors.INCONSISTENT_GROUP_PROTOCOL, joinGroupResult.error) + + EasyMock.reset(replicaManager) + joinGroupResult = staticJoinGroup(groupId, memberId, groupInstanceId, "", protocols) + assertEquals(Errors.INCONSISTENT_GROUP_PROTOCOL, joinGroupResult.error) + } + + @Test + def testJoinGroupWithEmptyGroupProtocol(): Unit = { + val memberId = JoinGroupRequest.UNKNOWN_MEMBER_ID + + val joinGroupResult = dynamicJoinGroup(groupId, memberId, protocolType, List()) + assertEquals(Errors.INCONSISTENT_GROUP_PROTOCOL, joinGroupResult.error) + } + + @Test + def testNewMemberTimeoutCompletion(): Unit = { + val sessionTimeout = GroupCoordinator.NewMemberJoinTimeoutMs + 5000 + val responseFuture = sendJoinGroup(groupId, JoinGroupRequest.UNKNOWN_MEMBER_ID, protocolType, protocols, None, sessionTimeout, DefaultRebalanceTimeout, false) + + timer.advanceClock(GroupInitialRebalanceDelay + 1) + + val joinResult = Await.result(responseFuture, Duration(DefaultRebalanceTimeout + 100, TimeUnit.MILLISECONDS)) + 
val group = groupCoordinator.groupManager.getGroup(groupId).get + val memberId = joinResult.memberId + + assertEquals(Errors.NONE, joinResult.error) + assertEquals(0, group.allMemberMetadata.count(_.isNew)) + + EasyMock.reset(replicaManager) + val syncGroupResult = syncGroupLeader(groupId, joinResult.generationId, memberId, Map(memberId -> Array[Byte]())) + assertEquals(Errors.NONE, syncGroupResult.error) + assertEquals(1, group.size) + + timer.advanceClock(GroupCoordinator.NewMemberJoinTimeoutMs + 100) + + // Make sure the NewMemberTimeout is not still in effect, and the member is not kicked + assertEquals(1, group.size) + + timer.advanceClock(sessionTimeout + 100) + assertEquals(0, group.size) + } + + @Test + def testNewMemberJoinExpiration(): Unit = { + // This tests new member expiration during a protracted rebalance. We first create a + // group with one member which uses a large value for session timeout and rebalance timeout. + // We then join with one new member and let the rebalance hang while we await the first member. + // The new member join timeout expires and its JoinGroup request is failed. + + val sessionTimeout = GroupCoordinator.NewMemberJoinTimeoutMs + 5000 + val rebalanceTimeout = GroupCoordinator.NewMemberJoinTimeoutMs * 2 + + val firstJoinResult = dynamicJoinGroup(groupId, JoinGroupRequest.UNKNOWN_MEMBER_ID, protocolType, protocols, + sessionTimeout, rebalanceTimeout) + val firstMemberId = firstJoinResult.memberId + assertEquals(firstMemberId, firstJoinResult.leaderId) + assertEquals(Errors.NONE, firstJoinResult.error) + + val groupOpt = groupCoordinator.groupManager.getGroup(groupId) + assertTrue(groupOpt.isDefined) + val group = groupOpt.get + assertEquals(0, group.allMemberMetadata.count(_.isNew)) + + EasyMock.reset(replicaManager) + val responseFuture = sendJoinGroup(groupId, JoinGroupRequest.UNKNOWN_MEMBER_ID, protocolType, protocols, None, sessionTimeout, rebalanceTimeout) + assertFalse(responseFuture.isCompleted) + + assertEquals(2, group.allMembers.size) + assertEquals(1, group.allMemberMetadata.count(_.isNew)) + + val newMember = group.allMemberMetadata.find(_.isNew).get + assertNotEquals(firstMemberId, newMember.memberId) + + timer.advanceClock(GroupCoordinator.NewMemberJoinTimeoutMs + 1) + assertTrue(responseFuture.isCompleted) + + val response = Await.result(responseFuture, Duration(0, TimeUnit.MILLISECONDS)) + assertEquals(Errors.UNKNOWN_MEMBER_ID, response.error) + assertEquals(1, group.allMembers.size) + assertEquals(0, group.allMemberMetadata.count(_.isNew)) + assertEquals(firstMemberId, group.allMembers.head) + } + + @Test + def testNewMemberFailureAfterJoinGroupCompletion(): Unit = { + // For old versions of the JoinGroup protocol, new members were subject + // to expiration if the rebalance took long enough. This test case ensures + // that following completion of the JoinGroup phase, new members follow + // normal heartbeat expiration logic. 
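+    // (verifySessionExpiration below advances the clock past the session timeout and
+    // expects the group to end up Empty with no members.)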
+ + val firstJoinResult = dynamicJoinGroup(groupId, JoinGroupRequest.UNKNOWN_MEMBER_ID, protocolType, protocols) + val firstMemberId = firstJoinResult.memberId + val firstGenerationId = firstJoinResult.generationId + assertEquals(firstMemberId, firstJoinResult.leaderId) + assertEquals(Errors.NONE, firstJoinResult.error) + + EasyMock.reset(replicaManager) + val firstSyncResult = syncGroupLeader(groupId, firstGenerationId, firstMemberId, + Map(firstMemberId -> Array[Byte]())) + assertEquals(Errors.NONE, firstSyncResult.error) + + EasyMock.reset(replicaManager) + val otherJoinFuture = sendJoinGroup(groupId, JoinGroupRequest.UNKNOWN_MEMBER_ID, protocolType, protocols) + + EasyMock.reset(replicaManager) + val joinFuture = sendJoinGroup(groupId, firstMemberId, protocolType, protocols, + requireKnownMemberId = false) + + val joinResult = await(joinFuture, DefaultSessionTimeout+100) + val otherJoinResult = await(otherJoinFuture, DefaultSessionTimeout+100) + assertEquals(Errors.NONE, joinResult.error) + assertEquals(Errors.NONE, otherJoinResult.error) + + verifySessionExpiration(groupId) + } + + @Test + def testNewMemberFailureAfterSyncGroupCompletion(): Unit = { + // For old versions of the JoinGroup protocol, new members were subject + // to expiration if the rebalance took long enough. This test case ensures + // that following completion of the SyncGroup phase, new members follow + // normal heartbeat expiration logic. + + val firstJoinResult = dynamicJoinGroup(groupId, JoinGroupRequest.UNKNOWN_MEMBER_ID, protocolType, protocols) + val firstMemberId = firstJoinResult.memberId + val firstGenerationId = firstJoinResult.generationId + assertEquals(firstMemberId, firstJoinResult.leaderId) + assertEquals(Errors.NONE, firstJoinResult.error) + + EasyMock.reset(replicaManager) + val firstSyncResult = syncGroupLeader(groupId, firstGenerationId, firstMemberId, + Map(firstMemberId -> Array[Byte]())) + assertEquals(Errors.NONE, firstSyncResult.error) + + EasyMock.reset(replicaManager) + val otherJoinFuture = sendJoinGroup(groupId, JoinGroupRequest.UNKNOWN_MEMBER_ID, protocolType, protocols) + + EasyMock.reset(replicaManager) + val joinFuture = sendJoinGroup(groupId, firstMemberId, protocolType, protocols, + requireKnownMemberId = false) + + val joinResult = await(joinFuture, DefaultSessionTimeout+100) + val otherJoinResult = await(otherJoinFuture, DefaultSessionTimeout+100) + assertEquals(Errors.NONE, joinResult.error) + assertEquals(Errors.NONE, otherJoinResult.error) + val secondGenerationId = joinResult.generationId + val secondMemberId = otherJoinResult.memberId + + EasyMock.reset(replicaManager) + sendSyncGroupFollower(groupId, secondGenerationId, secondMemberId) + + EasyMock.reset(replicaManager) + val syncGroupResult = syncGroupLeader(groupId, secondGenerationId, firstMemberId, + Map(firstMemberId -> Array.emptyByteArray, secondMemberId -> Array.emptyByteArray)) + assertEquals(Errors.NONE, syncGroupResult.error) + + verifySessionExpiration(groupId) + } + + private def verifySessionExpiration(groupId: String): Unit = { + EasyMock.reset(replicaManager) + EasyMock.expect(replicaManager.getMagic(EasyMock.anyObject())) + .andReturn(Some(RecordBatch.CURRENT_MAGIC_VALUE)).anyTimes() + EasyMock.replay(replicaManager) + + timer.advanceClock(DefaultSessionTimeout + 1) + + val groupMetadata = group(groupId) + assertEquals(Empty, groupMetadata.currentState) + assertTrue(groupMetadata.allMembers.isEmpty) + } + + @Test + def testJoinGroupInconsistentGroupProtocol(): Unit = { + val memberId = 
JoinGroupRequest.UNKNOWN_MEMBER_ID + val otherMemberId = JoinGroupRequest.UNKNOWN_MEMBER_ID + + val joinGroupFuture = sendJoinGroup(groupId, memberId, protocolType, List(("range", metadata))) + + EasyMock.reset(replicaManager) + val otherJoinGroupResult = dynamicJoinGroup(groupId, otherMemberId, protocolType, List(("roundrobin", metadata))) + timer.advanceClock(GroupInitialRebalanceDelay + 1) + + val joinGroupResult = await(joinGroupFuture, 1) + assertEquals(Errors.NONE, joinGroupResult.error) + assertEquals(Errors.INCONSISTENT_GROUP_PROTOCOL, otherJoinGroupResult.error) + } + + @Test + def testJoinGroupUnknownConsumerExistingGroup(): Unit = { + val memberId = JoinGroupRequest.UNKNOWN_MEMBER_ID + val otherMemberId = "memberId" + + val joinGroupResult = dynamicJoinGroup(groupId, memberId, protocolType, protocols) + assertEquals(Errors.NONE, joinGroupResult.error) + + EasyMock.reset(replicaManager) + val otherJoinGroupResult = await(sendJoinGroup(groupId, otherMemberId, protocolType, protocols), 1) + assertEquals(Errors.UNKNOWN_MEMBER_ID, otherJoinGroupResult.error) + } + + @Test + def testJoinGroupUnknownConsumerNewDeadGroup(): Unit = { + val memberId = JoinGroupRequest.UNKNOWN_MEMBER_ID + val deadGroupId = "deadGroupId" + + groupCoordinator.groupManager.addGroup(new GroupMetadata(deadGroupId, Dead, new MockTime())) + val joinGroupResult = dynamicJoinGroup(deadGroupId, memberId, protocolType, protocols) + assertEquals(Errors.COORDINATOR_NOT_AVAILABLE, joinGroupResult.error) + } + + @Test + def testSyncDeadGroup(): Unit = { + val memberId = "memberId" + val deadGroupId = "deadGroupId" + + groupCoordinator.groupManager.addGroup(new GroupMetadata(deadGroupId, Dead, new MockTime())) + val syncGroupResult = syncGroupFollower(deadGroupId, 1, memberId) + assertEquals(Errors.COORDINATOR_NOT_AVAILABLE, syncGroupResult.error) + } + + @Test + def testJoinGroupSecondJoinInconsistentProtocol(): Unit = { + var responseFuture = sendJoinGroup(groupId, JoinGroupRequest.UNKNOWN_MEMBER_ID, protocolType, protocols, requireKnownMemberId = true) + var joinGroupResult = Await.result(responseFuture, Duration(DefaultRebalanceTimeout + 1, TimeUnit.MILLISECONDS)) + assertEquals(Errors.MEMBER_ID_REQUIRED, joinGroupResult.error) + val memberId = joinGroupResult.memberId + + // Sending an inconsistent protocol shall be refused + EasyMock.reset(replicaManager) + responseFuture = sendJoinGroup(groupId, memberId, protocolType, List(), requireKnownMemberId = true) + joinGroupResult = Await.result(responseFuture, Duration(DefaultRebalanceTimeout + 1, TimeUnit.MILLISECONDS)) + assertEquals(Errors.INCONSISTENT_GROUP_PROTOCOL, joinGroupResult.error) + + // Sending consistent protocol shall be accepted + EasyMock.reset(replicaManager) + responseFuture = sendJoinGroup(groupId, memberId, protocolType, protocols, requireKnownMemberId = true) + timer.advanceClock(GroupInitialRebalanceDelay + 1) + joinGroupResult = Await.result(responseFuture, Duration(DefaultRebalanceTimeout + 1, TimeUnit.MILLISECONDS)) + assertEquals(Errors.NONE, joinGroupResult.error) + } + + @Test + def staticMemberJoinAsFirstMember(): Unit = { + val joinGroupResult = staticJoinGroup(groupId, JoinGroupRequest.UNKNOWN_MEMBER_ID, groupInstanceId, protocolType, protocols) + assertEquals(Errors.NONE, joinGroupResult.error) + } + + @Test + def staticMemberReJoinWithExplicitUnknownMemberId(): Unit = { + var joinGroupResult = staticJoinGroup(groupId, JoinGroupRequest.UNKNOWN_MEMBER_ID, groupInstanceId, protocolType, protocols) + assertEquals(Errors.NONE, 
joinGroupResult.error) + + EasyMock.reset(replicaManager) + val unknownMemberId = "unknown_member" + joinGroupResult = staticJoinGroup(groupId, unknownMemberId, groupInstanceId, protocolType, protocols) + assertEquals(Errors.FENCED_INSTANCE_ID, joinGroupResult.error) + } + + @Test + def staticMemberFenceDuplicateRejoinedFollower(): Unit = { + val rebalanceResult = staticMembersJoinAndRebalance(leaderInstanceId, followerInstanceId) + + // A third member joins will trigger rebalance. + sendJoinGroup(groupId, JoinGroupRequest.UNKNOWN_MEMBER_ID, protocolType, protocols) + timer.advanceClock(1) + assertTrue(getGroup(groupId).is(PreparingRebalance)) + + EasyMock.reset(replicaManager) + timer.advanceClock(1) + // Old follower rejoins group will be matching current member.id. + val oldFollowerJoinGroupFuture = + sendJoinGroup(groupId, rebalanceResult.followerId, protocolType, protocols, groupInstanceId = Some(followerInstanceId)) + + EasyMock.reset(replicaManager) + timer.advanceClock(1) + // Duplicate follower joins group with unknown member id will trigger member.id replacement. + val duplicateFollowerJoinFuture = + sendJoinGroup(groupId, JoinGroupRequest.UNKNOWN_MEMBER_ID, protocolType, protocols, groupInstanceId = Some(followerInstanceId)) + + timer.advanceClock(1) + // Old member shall be fenced immediately upon duplicate follower joins. + val oldFollowerJoinGroupResult = Await.result(oldFollowerJoinGroupFuture, Duration(1, TimeUnit.MILLISECONDS)) + checkJoinGroupResult(oldFollowerJoinGroupResult, + Errors.FENCED_INSTANCE_ID, + -1, + Set.empty, + PreparingRebalance, + None) + verifyDelayedTaskNotCompleted(duplicateFollowerJoinFuture) + } + + @Test + def staticMemberFenceDuplicateSyncingFollowerAfterMemberIdChanged(): Unit = { + val rebalanceResult = staticMembersJoinAndRebalance(leaderInstanceId, followerInstanceId) + + // Known leader rejoins will trigger rebalance. + val leaderJoinGroupFuture = + sendJoinGroup(groupId, rebalanceResult.leaderId, protocolType, protocols, groupInstanceId = Some(leaderInstanceId)) + timer.advanceClock(1) + assertTrue(getGroup(groupId).is(PreparingRebalance)) + + EasyMock.reset(replicaManager) + timer.advanceClock(1) + // Old follower rejoins group will match current member.id. + val oldFollowerJoinGroupFuture = + sendJoinGroup(groupId, rebalanceResult.followerId, protocolType, protocols, groupInstanceId = Some(followerInstanceId)) + + timer.advanceClock(1) + val leaderJoinGroupResult = Await.result(leaderJoinGroupFuture, Duration(1, TimeUnit.MILLISECONDS)) + checkJoinGroupResult(leaderJoinGroupResult, + Errors.NONE, + rebalanceResult.generation + 1, + Set(leaderInstanceId, followerInstanceId), + CompletingRebalance, + Some(protocolType)) + assertEquals(rebalanceResult.leaderId, leaderJoinGroupResult.memberId) + assertEquals(rebalanceResult.leaderId, leaderJoinGroupResult.leaderId) + + // Old follower shall be getting a successful join group response. 
+ val oldFollowerJoinGroupResult = Await.result(oldFollowerJoinGroupFuture, Duration(1, TimeUnit.MILLISECONDS)) + checkJoinGroupResult(oldFollowerJoinGroupResult, + Errors.NONE, + rebalanceResult.generation + 1, + Set.empty, + CompletingRebalance, + Some(protocolType), + expectedLeaderId = leaderJoinGroupResult.memberId) + assertEquals(rebalanceResult.followerId, oldFollowerJoinGroupResult.memberId) + assertEquals(rebalanceResult.leaderId, oldFollowerJoinGroupResult.leaderId) + assertTrue(getGroup(groupId).is(CompletingRebalance)) + + // Duplicate follower joins group with unknown member id will trigger member.id replacement, + // and will also trigger a rebalance under CompletingRebalance state; the old follower sync callback + // will return fenced exception while broker replaces the member identity with the duplicate follower joins. + EasyMock.reset(replicaManager) + val oldFollowerSyncGroupFuture = sendSyncGroupFollower(groupId, oldFollowerJoinGroupResult.generationId, + oldFollowerJoinGroupResult.memberId, Some(protocolType), Some(protocolName), Some(followerInstanceId)) + + EasyMock.reset(replicaManager) + val duplicateFollowerJoinFuture = + sendJoinGroup(groupId, JoinGroupRequest.UNKNOWN_MEMBER_ID, protocolType, protocols, groupInstanceId = Some(followerInstanceId)) + timer.advanceClock(1) + + val oldFollowerSyncGroupResult = Await.result(oldFollowerSyncGroupFuture, Duration(1, TimeUnit.MILLISECONDS)) + assertEquals(Errors.FENCED_INSTANCE_ID, oldFollowerSyncGroupResult.error) + assertTrue(getGroup(groupId).is(PreparingRebalance)) + + timer.advanceClock(GroupInitialRebalanceDelay + 1) + timer.advanceClock(DefaultRebalanceTimeout + 1) + + val duplicateFollowerJoinGroupResult = Await.result(duplicateFollowerJoinFuture, Duration(1, TimeUnit.MILLISECONDS)) + checkJoinGroupResult(duplicateFollowerJoinGroupResult, + Errors.NONE, + rebalanceResult.generation + 2, + Set(followerInstanceId), // this follower will become the new leader, and hence it would have the member list + CompletingRebalance, + Some(protocolType), + expectedLeaderId = duplicateFollowerJoinGroupResult.memberId) + assertTrue(getGroup(groupId).is(CompletingRebalance)) + } + + @Test + def staticMemberFenceDuplicateRejoiningFollowerAfterMemberIdChanged(): Unit = { + val rebalanceResult = staticMembersJoinAndRebalance(leaderInstanceId, followerInstanceId) + + // Known leader rejoins will trigger rebalance. + val leaderJoinGroupFuture = + sendJoinGroup(groupId, rebalanceResult.leaderId, protocolType, protocols, groupInstanceId = Some(leaderInstanceId)) + timer.advanceClock(1) + assertTrue(getGroup(groupId).is(PreparingRebalance)) + + EasyMock.reset(replicaManager) + // Duplicate follower joins group will trigger member.id replacement. + val duplicateFollowerJoinGroupFuture = + sendJoinGroup(groupId, JoinGroupRequest.UNKNOWN_MEMBER_ID, protocolType, protocols, groupInstanceId = Some(followerInstanceId)) + + EasyMock.reset(replicaManager) + timer.advanceClock(1) + // Old follower rejoins group will fail because member.id already updated. 
+ val oldFollowerJoinGroupFuture = + sendJoinGroup(groupId, rebalanceResult.followerId, protocolType, protocols, groupInstanceId = Some(followerInstanceId)) + + val leaderRejoinGroupResult = Await.result(leaderJoinGroupFuture, Duration(1, TimeUnit.MILLISECONDS)) + checkJoinGroupResult(leaderRejoinGroupResult, + Errors.NONE, + rebalanceResult.generation + 1, + Set(leaderInstanceId, followerInstanceId), + CompletingRebalance, + Some(protocolType)) + + val duplicateFollowerJoinGroupResult = Await.result(duplicateFollowerJoinGroupFuture, Duration(1, TimeUnit.MILLISECONDS)) + checkJoinGroupResult(duplicateFollowerJoinGroupResult, + Errors.NONE, + rebalanceResult.generation + 1, + Set.empty, + CompletingRebalance, + Some(protocolType)) + assertNotEquals(rebalanceResult.followerId, duplicateFollowerJoinGroupResult.memberId) + + val oldFollowerJoinGroupResult = Await.result(oldFollowerJoinGroupFuture, Duration(1, TimeUnit.MILLISECONDS)) + checkJoinGroupResult(oldFollowerJoinGroupResult, + Errors.FENCED_INSTANCE_ID, + -1, + Set.empty, + CompletingRebalance, + None) + } + + @Test + def staticMemberRejoinWithKnownMemberId(): Unit = { + var joinGroupResult = staticJoinGroup(groupId, JoinGroupRequest.UNKNOWN_MEMBER_ID, groupInstanceId, protocolType, protocols) + assertEquals(Errors.NONE, joinGroupResult.error) + + EasyMock.reset(replicaManager) + val assignedMemberId = joinGroupResult.memberId + // The second join group should return immediately since we are using the same metadata during CompletingRebalance. + val rejoinResponseFuture = sendJoinGroup(groupId, assignedMemberId, protocolType, protocols, Some(groupInstanceId)) + timer.advanceClock(1) + joinGroupResult = Await.result(rejoinResponseFuture, Duration(1, TimeUnit.MILLISECONDS)) + assertEquals(Errors.NONE, joinGroupResult.error) + assertTrue(getGroup(groupId).is(CompletingRebalance)) + + EasyMock.reset(replicaManager) + val syncGroupFuture = sendSyncGroupLeader(groupId, joinGroupResult.generationId, assignedMemberId, + Some(protocolType), Some(protocolName), Some(groupInstanceId), Map(assignedMemberId -> Array[Byte]())) + timer.advanceClock(1) + val syncGroupResult = Await.result(syncGroupFuture, Duration(1, TimeUnit.MILLISECONDS)) + assertEquals(Errors.NONE, syncGroupResult.error) + assertTrue(getGroup(groupId).is(Stable)) + } + + @Test + def staticMemberRejoinWithLeaderIdAndUnknownMemberId(): Unit = { + val rebalanceResult = staticMembersJoinAndRebalance(leaderInstanceId, followerInstanceId) + + // A static leader rejoin with unknown id will not trigger rebalance, and no assignment will be returned. + val joinGroupResult = staticJoinGroupWithPersistence(groupId, JoinGroupRequest.UNKNOWN_MEMBER_ID, + leaderInstanceId, protocolType, protocolSuperset, clockAdvance = 1) + + checkJoinGroupResult(joinGroupResult, + Errors.NONE, + rebalanceResult.generation, // The group should be at the same generation + Set.empty, + Stable, + Some(protocolType), + rebalanceResult.leaderId) + + EasyMock.reset(replicaManager) + val oldLeaderJoinGroupResult = staticJoinGroup(groupId, rebalanceResult.leaderId, leaderInstanceId, protocolType, protocolSuperset, clockAdvance = 1) + assertEquals(Errors.FENCED_INSTANCE_ID, oldLeaderJoinGroupResult.error) + + EasyMock.reset(replicaManager) + // Old leader will get fenced. 
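+    // The unknown-member rejoin above re-registered leaderInstanceId under a new member.id, so the old id is rejected on sync as well.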
+ val oldLeaderSyncGroupResult = syncGroupLeader(groupId, rebalanceResult.generation, rebalanceResult.leaderId, + Map.empty, None, None, Some(leaderInstanceId)) + assertEquals(Errors.FENCED_INSTANCE_ID, oldLeaderSyncGroupResult.error) + + // Calling sync on old leader.id will fail because that leader.id is no longer valid and replaced. + EasyMock.reset(replicaManager) + val newLeaderSyncGroupResult = syncGroupLeader(groupId, rebalanceResult.generation, joinGroupResult.leaderId, Map.empty) + assertEquals(Errors.UNKNOWN_MEMBER_ID, newLeaderSyncGroupResult.error) + } + + @Test + def staticMemberRejoinWithLeaderIdAndKnownMemberId(): Unit = { + val rebalanceResult = staticMembersJoinAndRebalance(leaderInstanceId, followerInstanceId, + sessionTimeout = DefaultRebalanceTimeout / 2) + + // A static leader with known id rejoin will trigger rebalance. + val joinGroupResult = staticJoinGroup(groupId, rebalanceResult.leaderId, leaderInstanceId, + protocolType, protocolSuperset, clockAdvance = DefaultRebalanceTimeout + 1) + // Timeout follower in the meantime. + assertFalse(getGroup(groupId).hasStaticMember(followerInstanceId)) + checkJoinGroupResult(joinGroupResult, + Errors.NONE, + rebalanceResult.generation + 1, // The group has promoted to the new generation. + Set(leaderInstanceId), + CompletingRebalance, + Some(protocolType), + rebalanceResult.leaderId, + rebalanceResult.leaderId) + } + + @Test + def staticMemberRejoinWithLeaderIdAndUnexpectedDeadGroup(): Unit = { + val rebalanceResult = staticMembersJoinAndRebalance(leaderInstanceId, followerInstanceId) + + getGroup(groupId).transitionTo(Dead) + + val joinGroupResult = staticJoinGroup(groupId, rebalanceResult.leaderId, leaderInstanceId, protocolType, protocols, clockAdvance = 1) + assertEquals(Errors.COORDINATOR_NOT_AVAILABLE, joinGroupResult.error) + } + + @Test + def staticMemberRejoinWithLeaderIdAndUnexpectedEmptyGroup(): Unit = { + val rebalanceResult = staticMembersJoinAndRebalance(leaderInstanceId, followerInstanceId) + + getGroup(groupId).transitionTo(PreparingRebalance) + getGroup(groupId).transitionTo(Empty) + + val joinGroupResult = staticJoinGroup(groupId, rebalanceResult.leaderId, leaderInstanceId, protocolType, protocols, clockAdvance = 1) + assertEquals(Errors.UNKNOWN_MEMBER_ID, joinGroupResult.error) + } + + @Test + def staticMemberRejoinWithFollowerIdAndChangeOfProtocol(): Unit = { + val rebalanceResult = staticMembersJoinAndRebalance(leaderInstanceId, followerInstanceId, sessionTimeout = DefaultSessionTimeout * 2) + + // A static follower rejoin with changed protocol will trigger rebalance. + val newProtocols = List(("roundrobin", metadata)) + // Old leader hasn't joined in the meantime, triggering a re-election. + val joinGroupResult = staticJoinGroup(groupId, rebalanceResult.followerId, followerInstanceId, protocolType, newProtocols, clockAdvance = DefaultSessionTimeout + 1) + + assertEquals(rebalanceResult.followerId, joinGroupResult.memberId) + assertTrue(getGroup(groupId).hasStaticMember(leaderInstanceId)) + assertTrue(getGroup(groupId).isLeader(rebalanceResult.followerId)) + checkJoinGroupResult(joinGroupResult, + Errors.NONE, + rebalanceResult.generation + 1, // The group has promoted to the new generation, and leader has changed because old one times out. 
+ Set(leaderInstanceId, followerInstanceId), + CompletingRebalance, + Some(protocolType), + rebalanceResult.followerId, + rebalanceResult.followerId) + } + + @Test + def staticMemberRejoinWithUnknownMemberIdAndChangeOfProtocolWithSelectedProtocolChanged(): Unit = { + val rebalanceResult = staticMembersJoinAndRebalance(leaderInstanceId, followerInstanceId) + + // A static follower rejoin with protocol changed and also cause updated group's selectedProtocol changed + // should trigger rebalance. + val selectedProtocols = getGroup(groupId).selectProtocol + val newProtocols = List(("roundrobin", metadata)) + assert(!newProtocols.map(_._1).contains(selectedProtocols)) + // Old leader hasn't joined in the meantime, triggering a re-election. + val joinGroupResult = staticJoinGroup(groupId, JoinGroupRequest.UNKNOWN_MEMBER_ID, followerInstanceId, protocolType, newProtocols, clockAdvance = DefaultSessionTimeout + 1) + + checkJoinGroupResult(joinGroupResult, + Errors.NONE, + rebalanceResult.generation + 1, + Set(leaderInstanceId, followerInstanceId), + CompletingRebalance, + Some(protocolType)) + + assertTrue(getGroup(groupId).isLeader(joinGroupResult.memberId)) + assertNotEquals(rebalanceResult.followerId, joinGroupResult.memberId) + assertEquals(joinGroupResult.protocolName, Some("roundrobin")) + } + + @Test + def staticMemberRejoinWithUnknownMemberIdAndChangeOfProtocolWhileSelectProtocolUnchangedPersistenceFailure(): Unit = { + val rebalanceResult = staticMembersJoinAndRebalance(leaderInstanceId, followerInstanceId) + + val selectedProtocol = getGroup(groupId).selectProtocol + val newProtocols = List((selectedProtocol, metadata)) + // Timeout old leader in the meantime. + val joinGroupResult = staticJoinGroupWithPersistence(groupId, JoinGroupRequest.UNKNOWN_MEMBER_ID, + followerInstanceId, protocolType, newProtocols, clockAdvance = 1, appendRecordError = Errors.MESSAGE_TOO_LARGE) + + checkJoinGroupResult(joinGroupResult, + Errors.UNKNOWN_SERVER_ERROR, + rebalanceResult.generation, + Set.empty, + Stable, + Some(protocolType)) + + EasyMock.reset(replicaManager) + // Join with old member id will not fail because the member id is not updated because of persistence failure + assertNotEquals(rebalanceResult.followerId, joinGroupResult.memberId) + val oldFollowerJoinGroupResult = staticJoinGroup(groupId, rebalanceResult.followerId, followerInstanceId, protocolType, newProtocols, clockAdvance = 1) + assertEquals(Errors.NONE, oldFollowerJoinGroupResult.error) + + EasyMock.reset(replicaManager) + // Sync with old member id will also not fail because the member id is not updated because of persistence failure + val syncGroupWithOldMemberIdResult = syncGroupFollower(groupId, rebalanceResult.generation, + rebalanceResult.followerId, None, None, Some(followerInstanceId)) + assertEquals(Errors.NONE, syncGroupWithOldMemberIdResult.error) + } + + @Test + def staticMemberRejoinWithUnknownMemberIdAndChangeOfProtocolWhileSelectProtocolUnchanged(): Unit = { + val rebalanceResult = staticMembersJoinAndRebalance(leaderInstanceId, followerInstanceId) + + // A static follower rejoin with protocol changing to leader protocol subset won't trigger rebalance if updated + // group's selectProtocol remain unchanged. + val selectedProtocol = getGroup(groupId).selectProtocol + val newProtocols = List((selectedProtocol, metadata)) + // Timeout old leader in the meantime. 
+ val joinGroupResult = staticJoinGroupWithPersistence(groupId, JoinGroupRequest.UNKNOWN_MEMBER_ID, + followerInstanceId, protocolType, newProtocols, clockAdvance = 1) + + checkJoinGroupResult(joinGroupResult, + Errors.NONE, + rebalanceResult.generation, + Set.empty, + Stable, + Some(protocolType)) + + EasyMock.reset(replicaManager) + // Join with old member id will fail because the member id is updated + assertNotEquals(rebalanceResult.followerId, joinGroupResult.memberId) + val oldFollowerJoinGroupResult = staticJoinGroup(groupId, rebalanceResult.followerId, followerInstanceId, protocolType, newProtocols, clockAdvance = 1) + assertEquals(Errors.FENCED_INSTANCE_ID, oldFollowerJoinGroupResult.error) + + EasyMock.reset(replicaManager) + // Sync with old member id will fail because the member id is updated + val syncGroupWithOldMemberIdResult = syncGroupFollower(groupId, rebalanceResult.generation, + rebalanceResult.followerId, None, None, Some(followerInstanceId)) + assertEquals(Errors.FENCED_INSTANCE_ID, syncGroupWithOldMemberIdResult.error) + + EasyMock.reset(replicaManager) + val syncGroupWithNewMemberIdResult = syncGroupFollower(groupId, rebalanceResult.generation, + joinGroupResult.memberId, None, None, Some(followerInstanceId)) + assertEquals(Errors.NONE, syncGroupWithNewMemberIdResult.error) + assertEquals(rebalanceResult.followerAssignment, syncGroupWithNewMemberIdResult.memberAssignment) + } + + @Test + def staticMemberRejoinWithKnownLeaderIdToTriggerRebalanceAndFollowerWithChangeofProtocol(): Unit = { + val rebalanceResult = staticMembersJoinAndRebalance(leaderInstanceId, followerInstanceId) + + // A static leader rejoin with known member id will trigger rebalance. + val leaderRejoinGroupFuture = sendJoinGroup(groupId, rebalanceResult.leaderId, protocolType, + protocolSuperset, Some(leaderInstanceId)) + // Rebalance complete immediately after follower rejoin. + EasyMock.reset(replicaManager) + val followerRejoinWithFuture = sendJoinGroup(groupId, rebalanceResult.followerId, protocolType, + protocolSuperset, Some(followerInstanceId)) + + timer.advanceClock(1) + + // Leader should get the same assignment as last round. + checkJoinGroupResult(await(leaderRejoinGroupFuture, 1), + Errors.NONE, + rebalanceResult.generation + 1, // The group has promoted to the new generation. + Set(leaderInstanceId, followerInstanceId), + CompletingRebalance, + Some(protocolType), + rebalanceResult.leaderId, + rebalanceResult.leaderId) + + checkJoinGroupResult(await(followerRejoinWithFuture, 1), + Errors.NONE, + rebalanceResult.generation + 1, // The group has promoted to the new generation. + Set.empty, + CompletingRebalance, + Some(protocolType), + rebalanceResult.leaderId, + rebalanceResult.followerId) + + EasyMock.reset(replicaManager) + // The follower protocol changed from protocolSuperset to general protocols. + val followerRejoinWithProtocolChangeFuture = sendJoinGroup(groupId, rebalanceResult.followerId, + protocolType, protocols, Some(followerInstanceId)) + // The group will transit to PreparingRebalance due to protocol change from follower. + assertTrue(getGroup(groupId).is(PreparingRebalance)) + + timer.advanceClock(DefaultRebalanceTimeout + 1) + checkJoinGroupResult(await(followerRejoinWithProtocolChangeFuture, 1), + Errors.NONE, + rebalanceResult.generation + 2, // The group has promoted to the new generation. 
+ Set(followerInstanceId), + CompletingRebalance, + Some(protocolType), + rebalanceResult.followerId, + rebalanceResult.followerId) + } + + @Test + def staticMemberRejoinAsFollowerWithUnknownMemberId(): Unit = { + val rebalanceResult = staticMembersJoinAndRebalance(leaderInstanceId, followerInstanceId) + + // A static follower rejoin with no protocol change will not trigger rebalance. + val joinGroupResult = staticJoinGroupWithPersistence(groupId, JoinGroupRequest.UNKNOWN_MEMBER_ID, followerInstanceId, protocolType, protocolSuperset, clockAdvance = 1) + + // Old leader shouldn't be timed out. + assertTrue(getGroup(groupId).hasStaticMember(leaderInstanceId)) + checkJoinGroupResult(joinGroupResult, + Errors.NONE, + rebalanceResult.generation, // The group has no change. + Set.empty, + Stable, + Some(protocolType)) + + assertNotEquals(rebalanceResult.followerId, joinGroupResult.memberId) + + EasyMock.reset(replicaManager) + val syncGroupResult = syncGroupFollower(groupId, rebalanceResult.generation, joinGroupResult.memberId) + assertEquals(Errors.NONE, syncGroupResult.error) + assertEquals(rebalanceResult.followerAssignment, syncGroupResult.memberAssignment) + } + + @Test + def staticMemberRejoinAsFollowerWithKnownMemberIdAndNoProtocolChange(): Unit = { + val rebalanceResult = staticMembersJoinAndRebalance(leaderInstanceId, followerInstanceId) + + // A static follower rejoin with no protocol change will not trigger rebalance. + val joinGroupResult = staticJoinGroup(groupId, rebalanceResult.followerId, followerInstanceId, protocolType, protocolSuperset, clockAdvance = 1) + + // Old leader shouldn't be timed out. + assertTrue(getGroup(groupId).hasStaticMember(leaderInstanceId)) + checkJoinGroupResult(joinGroupResult, + Errors.NONE, + rebalanceResult.generation, // The group has no change. 
+ Set.empty, + Stable, + Some(protocolType), + rebalanceResult.leaderId, + rebalanceResult.followerId) + } + + @Test + def staticMemberRejoinAsFollowerWithMismatchedMemberId(): Unit = { + val rebalanceResult = staticMembersJoinAndRebalance(leaderInstanceId, followerInstanceId) + + val joinGroupResult = staticJoinGroup(groupId, rebalanceResult.followerId, leaderInstanceId, protocolType, protocolSuperset, clockAdvance = 1) + assertEquals(Errors.FENCED_INSTANCE_ID, joinGroupResult.error) + } + + @Test + def staticMemberRejoinAsLeaderWithMismatchedMemberId(): Unit = { + val rebalanceResult = staticMembersJoinAndRebalance(leaderInstanceId, followerInstanceId) + + val joinGroupResult = staticJoinGroup(groupId, rebalanceResult.leaderId, followerInstanceId, protocolType, protocolSuperset, clockAdvance = 1) + assertEquals(Errors.FENCED_INSTANCE_ID, joinGroupResult.error) + } + + @Test + def staticMemberSyncAsLeaderWithInvalidMemberId(): Unit = { + val rebalanceResult = staticMembersJoinAndRebalance(leaderInstanceId, followerInstanceId) + + val syncGroupResult = syncGroupLeader(groupId, rebalanceResult.generation, "invalid", + Map.empty, None, None, Some(leaderInstanceId)) + assertEquals(Errors.FENCED_INSTANCE_ID, syncGroupResult.error) + } + + @Test + def staticMemberHeartbeatLeaderWithInvalidMemberId(): Unit = { + val rebalanceResult = staticMembersJoinAndRebalance(leaderInstanceId, followerInstanceId) + + val syncGroupResult = syncGroupLeader(groupId, rebalanceResult.generation, rebalanceResult.leaderId, Map.empty) + assertEquals(Errors.NONE, syncGroupResult.error) + + EasyMock.reset(replicaManager) + val validHeartbeatResult = heartbeat(groupId, rebalanceResult.leaderId, rebalanceResult.generation) + assertEquals(Errors.NONE, validHeartbeatResult) + + EasyMock.reset(replicaManager) + val invalidHeartbeatResult = heartbeat(groupId, invalidMemberId, rebalanceResult.generation, Some(leaderInstanceId)) + assertEquals(Errors.FENCED_INSTANCE_ID, invalidHeartbeatResult) + } + + @Test + def shouldGetDifferentStaticMemberIdAfterEachRejoin(): Unit = { + val initialResult = staticMembersJoinAndRebalance(leaderInstanceId, followerInstanceId) + + val timeAdvance = 1 + var lastMemberId = initialResult.leaderId + for (_ <- 1 to 5) { + EasyMock.reset(replicaManager) + val joinGroupResult = staticJoinGroupWithPersistence(groupId, JoinGroupRequest.UNKNOWN_MEMBER_ID, + leaderInstanceId, protocolType, protocols, clockAdvance = timeAdvance) + assertTrue(joinGroupResult.memberId.startsWith(leaderInstanceId)) + assertNotEquals(lastMemberId, joinGroupResult.memberId) + lastMemberId = joinGroupResult.memberId + } + } + + @Test + def testOffsetCommitDeadGroup(): Unit = { + val memberId = "memberId" + + val deadGroupId = "deadGroupId" + val tp = new TopicPartition("topic", 0) + val offset = offsetAndMetadata(0) + + groupCoordinator.groupManager.addGroup(new GroupMetadata(deadGroupId, Dead, new MockTime())) + val offsetCommitResult = commitOffsets(deadGroupId, memberId, 1, Map(tp -> offset)) + assertEquals(Errors.COORDINATOR_NOT_AVAILABLE, offsetCommitResult(tp)) + } + + @Test + def staticMemberCommitOffsetWithInvalidMemberId(): Unit = { + val rebalanceResult = staticMembersJoinAndRebalance(leaderInstanceId, followerInstanceId) + + val syncGroupResult = syncGroupLeader(groupId, rebalanceResult.generation, rebalanceResult.leaderId, Map.empty) + assertEquals(Errors.NONE, syncGroupResult.error) + + val tp = new TopicPartition("topic", 0) + val offset = offsetAndMetadata(0) + EasyMock.reset(replicaManager) + val 
validOffsetCommitResult = commitOffsets(groupId, rebalanceResult.leaderId, rebalanceResult.generation, Map(tp -> offset)) + assertEquals(Errors.NONE, validOffsetCommitResult(tp)) + + EasyMock.reset(replicaManager) + val invalidOffsetCommitResult = commitOffsets(groupId, invalidMemberId, rebalanceResult.generation, + Map(tp -> offset), Some(leaderInstanceId)) + assertEquals(Errors.FENCED_INSTANCE_ID, invalidOffsetCommitResult(tp)) + } + + @Test + def staticMemberJoinWithUnknownInstanceIdAndKnownMemberId(): Unit = { + val rebalanceResult = staticMembersJoinAndRebalance(leaderInstanceId, followerInstanceId) + + val joinGroupResult = staticJoinGroup(groupId, rebalanceResult.leaderId, "unknown_instance", + protocolType, protocolSuperset, clockAdvance = 1) + + assertEquals(Errors.UNKNOWN_MEMBER_ID, joinGroupResult.error) + } + + @Test + def staticMemberReJoinWithIllegalStateAsUnknownMember(): Unit = { + staticMembersJoinAndRebalance(leaderInstanceId, followerInstanceId) + val group = groupCoordinator.groupManager.getGroup(groupId).get + group.transitionTo(PreparingRebalance) + group.transitionTo(Empty) + + EasyMock.reset(replicaManager) + + // Illegal state exception shall trigger since follower id resides in pending member bucket. + val expectedException = assertThrows(classOf[IllegalStateException], + () => staticJoinGroup(groupId, JoinGroupRequest.UNKNOWN_MEMBER_ID, followerInstanceId, protocolType, protocolSuperset, clockAdvance = 1)) + + val message = expectedException.getMessage + assertTrue(message.contains(group.groupId)) + assertTrue(message.contains(followerInstanceId)) + } + + @Test + def testLeaderFailToRejoinBeforeFinalRebalanceTimeoutWithLongSessionTimeout(): Unit = { + groupStuckInRebalanceTimeoutDueToNonjoinedStaticMember() + + timer.advanceClock(DefaultRebalanceTimeout + 1) + // The static leader should already session timeout, moving group towards Empty + assertEquals(Set.empty, getGroup(groupId).allMembers) + assertNull(getGroup(groupId).leaderOrNull) + assertEquals(3, getGroup(groupId).generationId) + assertGroupState(groupState = Empty) + } + + @Test + def testLeaderRejoinBeforeFinalRebalanceTimeoutWithLongSessionTimeout(): Unit = { + groupStuckInRebalanceTimeoutDueToNonjoinedStaticMember() + + EasyMock.reset(replicaManager) + // The static leader should be back now, moving group towards CompletingRebalance + val leaderRejoinGroupResult = staticJoinGroup(groupId, JoinGroupRequest.UNKNOWN_MEMBER_ID, leaderInstanceId, protocolType, protocols) + checkJoinGroupResult(leaderRejoinGroupResult, + Errors.NONE, + 3, + Set(leaderInstanceId), + CompletingRebalance, + Some(protocolType) + ) + assertEquals(Set(leaderRejoinGroupResult.memberId), getGroup(groupId).allMembers) + assertNotNull(getGroup(groupId).leaderOrNull) + assertEquals(3, getGroup(groupId).generationId) + } + + def groupStuckInRebalanceTimeoutDueToNonjoinedStaticMember(): Unit = { + val longSessionTimeout = DefaultSessionTimeout * 2 + val rebalanceResult = staticMembersJoinAndRebalance(leaderInstanceId, followerInstanceId, sessionTimeout = longSessionTimeout) + + val dynamicJoinFuture = sendJoinGroup(groupId, JoinGroupRequest.UNKNOWN_MEMBER_ID, protocolType, protocolSuperset, sessionTimeout = longSessionTimeout) + timer.advanceClock(DefaultRebalanceTimeout + 1) + + val dynamicJoinResult = await(dynamicJoinFuture, 100) + // The new dynamic member has been elected as leader + assertEquals(dynamicJoinResult.leaderId, dynamicJoinResult.memberId) + assertEquals(Errors.NONE, dynamicJoinResult.error) + assertEquals(3, 
dynamicJoinResult.members.size) + assertEquals(2, dynamicJoinResult.generationId) + assertGroupState(groupState = CompletingRebalance) + + assertEquals(Set(rebalanceResult.leaderId, rebalanceResult.followerId, + dynamicJoinResult.memberId), getGroup(groupId).allMembers) + assertEquals(Set(leaderInstanceId, followerInstanceId), + getGroup(groupId).allStaticMembers) + assertEquals(Set(dynamicJoinResult.memberId), getGroup(groupId).allDynamicMembers) + + // Send a special leave group request from static follower, moving group towards PreparingRebalance + EasyMock.reset(replicaManager) + val followerLeaveGroupResults = singleLeaveGroup(groupId, rebalanceResult.followerId) + verifyLeaveGroupResult(followerLeaveGroupResults) + assertGroupState(groupState = PreparingRebalance) + + timer.advanceClock(DefaultRebalanceTimeout + 1) + // Only static leader is maintained, and group is stuck at PreparingRebalance stage + assertTrue(getGroup(groupId).allDynamicMembers.isEmpty) + assertEquals(Set(rebalanceResult.leaderId), getGroup(groupId).allMembers) + assertTrue(getGroup(groupId).allDynamicMembers.isEmpty) + assertEquals(2, getGroup(groupId).generationId) + assertGroupState(groupState = PreparingRebalance) + } + + @Test + def testStaticMemberFollowerFailToRejoinBeforeRebalanceTimeout(): Unit = { + // Increase session timeout so that the follower won't be evicted when rebalance timeout is reached. + val initialRebalanceResult = staticMembersJoinAndRebalance(leaderInstanceId, followerInstanceId, sessionTimeout = DefaultRebalanceTimeout * 2) + + val newMemberInstanceId = "newMember" + + val leaderId = initialRebalanceResult.leaderId + + val newMemberJoinGroupFuture = sendJoinGroup(groupId, JoinGroupRequest.UNKNOWN_MEMBER_ID, protocolType, + protocolSuperset, Some(newMemberInstanceId)) + assertGroupState(groupState = PreparingRebalance) + + EasyMock.reset(replicaManager) + val leaderRejoinGroupResult = staticJoinGroup(groupId, leaderId, leaderInstanceId, protocolType, protocolSuperset, clockAdvance = DefaultRebalanceTimeout + 1) + checkJoinGroupResult(leaderRejoinGroupResult, + Errors.NONE, + initialRebalanceResult.generation + 1, + Set(leaderInstanceId, followerInstanceId, newMemberInstanceId), + CompletingRebalance, + Some(protocolType), + expectedLeaderId = leaderId, + expectedMemberId = leaderId) + + val newMemberJoinGroupResult = Await.result(newMemberJoinGroupFuture, Duration(1, TimeUnit.MILLISECONDS)) + assertEquals(Errors.NONE, newMemberJoinGroupResult.error) + checkJoinGroupResult(newMemberJoinGroupResult, + Errors.NONE, + initialRebalanceResult.generation + 1, + Set.empty, + CompletingRebalance, + Some(protocolType), + expectedLeaderId = leaderId) + } + + @Test + def testStaticMemberLeaderFailToRejoinBeforeRebalanceTimeout(): Unit = { + // Increase session timeout so that the leader won't be evicted when rebalance timeout is reached. 
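+    // With sessionTimeout = 2 * DefaultRebalanceTimeout, the leader that never rejoins is still counted as a group member when the rebalance itself times out below.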
+ val initialRebalanceResult = staticMembersJoinAndRebalance(leaderInstanceId, followerInstanceId, sessionTimeout = DefaultRebalanceTimeout * 2) + + val newMemberInstanceId = "newMember" + + val newMemberJoinGroupFuture = sendJoinGroup(groupId, JoinGroupRequest.UNKNOWN_MEMBER_ID, protocolType, + protocolSuperset, Some(newMemberInstanceId)) + timer.advanceClock(1) + assertGroupState(groupState = PreparingRebalance) + + EasyMock.reset(replicaManager) + val oldFollowerRejoinGroupResult = staticJoinGroup(groupId, initialRebalanceResult.followerId, followerInstanceId, protocolType, protocolSuperset, clockAdvance = DefaultRebalanceTimeout + 1) + val newMemberJoinGroupResult = Await.result(newMemberJoinGroupFuture, Duration(1, TimeUnit.MILLISECONDS)) + + val (newLeaderResult, newFollowerResult) = if (oldFollowerRejoinGroupResult.leaderId == oldFollowerRejoinGroupResult.memberId) + (oldFollowerRejoinGroupResult, newMemberJoinGroupResult) + else + (newMemberJoinGroupResult, oldFollowerRejoinGroupResult) + + checkJoinGroupResult(newLeaderResult, + Errors.NONE, + initialRebalanceResult.generation + 1, + Set(leaderInstanceId, followerInstanceId, newMemberInstanceId), + CompletingRebalance, + Some(protocolType)) + + checkJoinGroupResult(newFollowerResult, + Errors.NONE, + initialRebalanceResult.generation + 1, + Set.empty, + CompletingRebalance, + Some(protocolType), + expectedLeaderId = newLeaderResult.memberId) + } + + @Test + def testJoinGroupProtocolTypeIsNotProvidedWhenAnErrorOccurs(): Unit = { + // JoinGroup(leader) + EasyMock.reset(replicaManager) + val leaderResponseFuture = sendJoinGroup(groupId, "fake-id", protocolType, + protocolSuperset, Some(leaderInstanceId), DefaultSessionTimeout) + + // The Protocol Type is None when there is an error + val leaderJoinGroupResult = await(leaderResponseFuture, 1) + assertEquals(Errors.UNKNOWN_MEMBER_ID, leaderJoinGroupResult.error) + assertEquals(None, leaderJoinGroupResult.protocolType) + } + + @Test + def testJoinGroupReturnsTheProtocolType(): Unit = { + // JoinGroup(leader) + EasyMock.reset(replicaManager) + val leaderResponseFuture = sendJoinGroup(groupId, JoinGroupRequest.UNKNOWN_MEMBER_ID, protocolType, + protocolSuperset, Some(leaderInstanceId), DefaultSessionTimeout) + + // JoinGroup(follower) + EasyMock.reset(replicaManager) + val followerResponseFuture = sendJoinGroup(groupId, JoinGroupRequest.UNKNOWN_MEMBER_ID, protocolType, + protocolSuperset, Some(followerInstanceId), DefaultSessionTimeout) + + timer.advanceClock(GroupInitialRebalanceDelay + 1) + timer.advanceClock(DefaultRebalanceTimeout + 1) + + // The Protocol Type is Defined when there is not error + val leaderJoinGroupResult = await(leaderResponseFuture, 1) + assertEquals(Errors.NONE, leaderJoinGroupResult.error) + assertEquals(protocolType, leaderJoinGroupResult.protocolType.orNull) + + // The Protocol Type is Defined when there is not error + val followerJoinGroupResult = await(followerResponseFuture, 1) + assertEquals(Errors.NONE, followerJoinGroupResult.error) + assertEquals(protocolType, followerJoinGroupResult.protocolType.orNull) + } + + @Test + def testSyncGroupReturnsAnErrorWhenProtocolTypeIsInconsistent(): Unit = { + testSyncGroupProtocolTypeAndNameWith(Some("whatever"), None, Errors.INCONSISTENT_GROUP_PROTOCOL, + None, None) + } + + @Test + def testSyncGroupReturnsAnErrorWhenProtocolNameIsInconsistent(): Unit = { + testSyncGroupProtocolTypeAndNameWith(None, Some("whatever"), Errors.INCONSISTENT_GROUP_PROTOCOL, + None, None) + } + + @Test + def 
testSyncGroupSucceedWhenProtocolTypeAndNameAreNotProvided(): Unit = { + testSyncGroupProtocolTypeAndNameWith(None, None, Errors.NONE, + Some(protocolType), Some(protocolName)) + } + + @Test + def testSyncGroupSucceedWhenProtocolTypeAndNameAreConsistent(): Unit = { + testSyncGroupProtocolTypeAndNameWith(Some(protocolType), Some(protocolName), + Errors.NONE, Some(protocolType), Some(protocolName)) + } + + private def testSyncGroupProtocolTypeAndNameWith(protocolType: Option[String], + protocolName: Option[String], + expectedError: Errors, + expectedProtocolType: Option[String], + expectedProtocolName: Option[String]): Unit = { + // JoinGroup(leader) with the Protocol Type of the group + EasyMock.reset(replicaManager) + val leaderResponseFuture = sendJoinGroup(groupId, JoinGroupRequest.UNKNOWN_MEMBER_ID, this.protocolType, + protocolSuperset, Some(leaderInstanceId), DefaultSessionTimeout) + + // JoinGroup(follower) with the Protocol Type of the group + EasyMock.reset(replicaManager) + val followerResponseFuture = sendJoinGroup(groupId, JoinGroupRequest.UNKNOWN_MEMBER_ID, this.protocolType, + protocolSuperset, Some(followerInstanceId), DefaultSessionTimeout) + + timer.advanceClock(GroupInitialRebalanceDelay + 1) + timer.advanceClock(DefaultRebalanceTimeout + 1) + + val leaderJoinGroupResult = await(leaderResponseFuture, 1) + val leaderId = leaderJoinGroupResult.memberId + val generationId = leaderJoinGroupResult.generationId + val followerJoinGroupResult = await(followerResponseFuture, 1) + val followerId = followerJoinGroupResult.memberId + + // SyncGroup with the provided Protocol Type and Name + EasyMock.reset(replicaManager) + val leaderSyncGroupResult = syncGroupLeader(groupId, generationId, leaderId, + Map(leaderId -> Array.empty), protocolType, protocolName) + assertEquals(expectedError, leaderSyncGroupResult.error) + assertEquals(expectedProtocolType, leaderSyncGroupResult.protocolType) + assertEquals(expectedProtocolName, leaderSyncGroupResult.protocolName) + + // SyncGroup with the provided Protocol Type and Name + EasyMock.reset(replicaManager) + val followerSyncGroupResult = syncGroupFollower(groupId, generationId, followerId, + protocolType, protocolName) + assertEquals(expectedError, followerSyncGroupResult.error) + assertEquals(expectedProtocolType, followerSyncGroupResult.protocolType) + assertEquals(expectedProtocolName, followerSyncGroupResult.protocolName) + } + + private class RebalanceResult(val generation: Int, + val leaderId: String, + val leaderAssignment: Array[Byte], + val followerId: String, + val followerAssignment: Array[Byte]) + /** + * Generate static member rebalance results, including: + * - generation + * - leader id + * - leader assignment + * - follower id + * - follower assignment + */ + private def staticMembersJoinAndRebalance(leaderInstanceId: String, + followerInstanceId: String, + sessionTimeout: Int = DefaultSessionTimeout): RebalanceResult = { + EasyMock.reset(replicaManager) + val leaderResponseFuture = sendJoinGroup(groupId, JoinGroupRequest.UNKNOWN_MEMBER_ID, protocolType, + protocolSuperset, Some(leaderInstanceId), sessionTimeout) + + EasyMock.reset(replicaManager) + val followerResponseFuture = sendJoinGroup(groupId, JoinGroupRequest.UNKNOWN_MEMBER_ID, protocolType, + protocolSuperset, Some(followerInstanceId), sessionTimeout) + // The goal for two timer advance is to let first group initial join complete and set newMemberAdded flag to false. Next advance is + // to trigger the rebalance as needed for follower delayed join. 
One large time advance won't help because we could only populate one + // delayed join from purgatory and the new delayed op is created at that time and never be triggered. + timer.advanceClock(GroupInitialRebalanceDelay + 1) + timer.advanceClock(DefaultRebalanceTimeout + 1) + val newGeneration = 1 + + val leaderJoinGroupResult = await(leaderResponseFuture, 1) + assertEquals(Errors.NONE, leaderJoinGroupResult.error) + assertEquals(newGeneration, leaderJoinGroupResult.generationId) + + val followerJoinGroupResult = await(followerResponseFuture, 1) + assertEquals(Errors.NONE, followerJoinGroupResult.error) + assertEquals(newGeneration, followerJoinGroupResult.generationId) + + EasyMock.reset(replicaManager) + val leaderId = leaderJoinGroupResult.memberId + val leaderSyncGroupResult = syncGroupLeader(groupId, leaderJoinGroupResult.generationId, leaderId, Map(leaderId -> Array[Byte]())) + assertEquals(Errors.NONE, leaderSyncGroupResult.error) + assertTrue(getGroup(groupId).is(Stable)) + + EasyMock.reset(replicaManager) + val followerId = followerJoinGroupResult.memberId + val followerSyncGroupResult = syncGroupFollower(groupId, leaderJoinGroupResult.generationId, followerId) + assertEquals(Errors.NONE, followerSyncGroupResult.error) + assertTrue(getGroup(groupId).is(Stable)) + + EasyMock.reset(replicaManager) + new RebalanceResult(newGeneration, + leaderId, + leaderSyncGroupResult.memberAssignment, + followerId, + followerSyncGroupResult.memberAssignment) + } + + private def checkJoinGroupResult(joinGroupResult: JoinGroupResult, + expectedError: Errors, + expectedGeneration: Int, + expectedGroupInstanceIds: Set[String], + expectedGroupState: GroupState, + expectedProtocolType: Option[String], + expectedLeaderId: String = JoinGroupRequest.UNKNOWN_MEMBER_ID, + expectedMemberId: String = JoinGroupRequest.UNKNOWN_MEMBER_ID): Unit = { + assertEquals(expectedError, joinGroupResult.error) + assertEquals(expectedGeneration, joinGroupResult.generationId) + assertEquals(expectedGroupInstanceIds.size, joinGroupResult.members.size) + val resultedGroupInstanceIds = joinGroupResult.members.map(member => member.groupInstanceId).toSet + assertEquals(expectedGroupInstanceIds, resultedGroupInstanceIds) + assertGroupState(groupState = expectedGroupState) + assertEquals(expectedProtocolType, joinGroupResult.protocolType) + + if (!expectedLeaderId.equals(JoinGroupRequest.UNKNOWN_MEMBER_ID)) { + assertEquals(expectedLeaderId, joinGroupResult.leaderId) + } + if (!expectedMemberId.equals(JoinGroupRequest.UNKNOWN_MEMBER_ID)) { + assertEquals(expectedMemberId, joinGroupResult.memberId) + } + } + + @Test + def testHeartbeatWrongCoordinator(): Unit = { + val heartbeatResult = heartbeat(otherGroupId, memberId, -1) + assertEquals(Errors.NOT_COORDINATOR, heartbeatResult) + } + + @Test + def testHeartbeatUnknownGroup(): Unit = { + val heartbeatResult = heartbeat(groupId, memberId, -1) + assertEquals(Errors.UNKNOWN_MEMBER_ID, heartbeatResult) + } + + @Test + def testHeartbeatDeadGroup(): Unit = { + val memberId = "memberId" + + val deadGroupId = "deadGroupId" + + groupCoordinator.groupManager.addGroup(new GroupMetadata(deadGroupId, Dead, new MockTime())) + val heartbeatResult = heartbeat(deadGroupId, memberId, 1) + assertEquals(Errors.COORDINATOR_NOT_AVAILABLE, heartbeatResult) + } + + @Test + def testHeartbeatEmptyGroup(): Unit = { + val memberId = "memberId" + + val group = new GroupMetadata(groupId, Empty, new MockTime()) + val member = new MemberMetadata(memberId, Some(groupInstanceId), + ClientId, ClientHost, 
DefaultRebalanceTimeout, DefaultSessionTimeout, + protocolType, List(("range", Array.empty[Byte]), ("roundrobin", Array.empty[Byte]))) + + group.add(member) + groupCoordinator.groupManager.addGroup(group) + val heartbeatResult = heartbeat(groupId, memberId, 0) + assertEquals(Errors.UNKNOWN_MEMBER_ID, heartbeatResult) + } + + @Test + def testHeartbeatUnknownConsumerExistingGroup(): Unit = { + val memberId = JoinGroupRequest.UNKNOWN_MEMBER_ID + val otherMemberId = "memberId" + + val joinGroupResult = dynamicJoinGroup(groupId, memberId, protocolType, protocols) + val assignedMemberId = joinGroupResult.memberId + val joinGroupError = joinGroupResult.error + assertEquals(Errors.NONE, joinGroupError) + + EasyMock.reset(replicaManager) + val syncGroupResult = syncGroupLeader(groupId, joinGroupResult.generationId, assignedMemberId, Map(assignedMemberId -> Array[Byte]())) + assertEquals(Errors.NONE, syncGroupResult.error) + + EasyMock.reset(replicaManager) + val heartbeatResult = heartbeat(groupId, otherMemberId, 1) + assertEquals(Errors.UNKNOWN_MEMBER_ID, heartbeatResult) + } + + @Test + def testHeartbeatRebalanceInProgress(): Unit = { + val memberId = JoinGroupRequest.UNKNOWN_MEMBER_ID + + val joinGroupResult = dynamicJoinGroup(groupId, memberId, protocolType, protocols) + val assignedMemberId = joinGroupResult.memberId + val joinGroupError = joinGroupResult.error + assertEquals(Errors.NONE, joinGroupError) + + EasyMock.reset(replicaManager) + val heartbeatResult = heartbeat(groupId, assignedMemberId, 1) + assertEquals(Errors.NONE, heartbeatResult) + } + + @Test + def testHeartbeatIllegalGeneration(): Unit = { + val memberId = JoinGroupRequest.UNKNOWN_MEMBER_ID + + val joinGroupResult = dynamicJoinGroup(groupId, memberId, protocolType, protocols) + val assignedMemberId = joinGroupResult.memberId + val joinGroupError = joinGroupResult.error + assertEquals(Errors.NONE, joinGroupError) + + EasyMock.reset(replicaManager) + val syncGroupResult = syncGroupLeader(groupId, joinGroupResult.generationId, assignedMemberId, Map(assignedMemberId -> Array[Byte]())) + assertEquals(Errors.NONE, syncGroupResult.error) + + EasyMock.reset(replicaManager) + val heartbeatResult = heartbeat(groupId, assignedMemberId, 2) + assertEquals(Errors.ILLEGAL_GENERATION, heartbeatResult) + } + + @Test + def testValidHeartbeat(): Unit = { + val memberId = JoinGroupRequest.UNKNOWN_MEMBER_ID + + val joinGroupResult = dynamicJoinGroup(groupId, memberId, protocolType, protocols) + val assignedConsumerId = joinGroupResult.memberId + val generationId = joinGroupResult.generationId + val joinGroupError = joinGroupResult.error + assertEquals(Errors.NONE, joinGroupError) + + EasyMock.reset(replicaManager) + val syncGroupResult = syncGroupLeader(groupId, generationId, assignedConsumerId, Map(assignedConsumerId -> Array[Byte]())) + assertEquals(Errors.NONE, syncGroupResult.error) + + EasyMock.reset(replicaManager) + val heartbeatResult = heartbeat(groupId, assignedConsumerId, 1) + assertEquals(Errors.NONE, heartbeatResult) + } + + @Test + def testSessionTimeout(): Unit = { + val memberId = JoinGroupRequest.UNKNOWN_MEMBER_ID + + val joinGroupResult = dynamicJoinGroup(groupId, memberId, protocolType, protocols) + val assignedConsumerId = joinGroupResult.memberId + val generationId = joinGroupResult.generationId + val joinGroupError = joinGroupResult.error + assertEquals(Errors.NONE, joinGroupError) + + EasyMock.reset(replicaManager) + val syncGroupResult = syncGroupLeader(groupId, generationId, assignedConsumerId, Map(assignedConsumerId -> 
Array[Byte]())) + assertEquals(Errors.NONE, syncGroupResult.error) + + EasyMock.reset(replicaManager) + EasyMock.expect(replicaManager.getPartition(new TopicPartition(Topic.GROUP_METADATA_TOPIC_NAME, groupPartitionId))) + .andReturn(HostedPartition.None) + EasyMock.expect(replicaManager.getMagic(EasyMock.anyObject())).andReturn(Some(RecordBatch.MAGIC_VALUE_V1)).anyTimes() + EasyMock.replay(replicaManager) + + timer.advanceClock(DefaultSessionTimeout + 100) + + EasyMock.reset(replicaManager) + val heartbeatResult = heartbeat(groupId, assignedConsumerId, 1) + assertEquals(Errors.UNKNOWN_MEMBER_ID, heartbeatResult) + } + + @Test + def testHeartbeatMaintainsSession(): Unit = { + val memberId = JoinGroupRequest.UNKNOWN_MEMBER_ID + val sessionTimeout = 1000 + + val joinGroupResult = dynamicJoinGroup(groupId, memberId, protocolType, protocols, + rebalanceTimeout = sessionTimeout, sessionTimeout = sessionTimeout) + val assignedConsumerId = joinGroupResult.memberId + val generationId = joinGroupResult.generationId + val joinGroupError = joinGroupResult.error + assertEquals(Errors.NONE, joinGroupError) + + EasyMock.reset(replicaManager) + val syncGroupResult = syncGroupLeader(groupId, generationId, assignedConsumerId, Map(assignedConsumerId -> Array[Byte]())) + assertEquals(Errors.NONE, syncGroupResult.error) + + timer.advanceClock(sessionTimeout / 2) + + EasyMock.reset(replicaManager) + var heartbeatResult = heartbeat(groupId, assignedConsumerId, 1) + assertEquals(Errors.NONE, heartbeatResult) + + timer.advanceClock(sessionTimeout / 2 + 100) + + EasyMock.reset(replicaManager) + heartbeatResult = heartbeat(groupId, assignedConsumerId, 1) + assertEquals(Errors.NONE, heartbeatResult) + } + + @Test + def testCommitMaintainsSession(): Unit = { + val sessionTimeout = 1000 + val memberId = JoinGroupRequest.UNKNOWN_MEMBER_ID + val tp = new TopicPartition("topic", 0) + val offset = offsetAndMetadata(0) + + val joinGroupResult = dynamicJoinGroup(groupId, memberId, protocolType, protocols, + rebalanceTimeout = sessionTimeout, sessionTimeout = sessionTimeout) + val assignedMemberId = joinGroupResult.memberId + val generationId = joinGroupResult.generationId + val joinGroupError = joinGroupResult.error + assertEquals(Errors.NONE, joinGroupError) + + EasyMock.reset(replicaManager) + val syncGroupResult = syncGroupLeader(groupId, generationId, assignedMemberId, Map(assignedMemberId -> Array[Byte]())) + assertEquals(Errors.NONE, syncGroupResult.error) + + timer.advanceClock(sessionTimeout / 2) + + EasyMock.reset(replicaManager) + val commitOffsetResult = commitOffsets(groupId, assignedMemberId, generationId, Map(tp -> offset)) + assertEquals(Errors.NONE, commitOffsetResult(tp)) + + timer.advanceClock(sessionTimeout / 2 + 100) + + EasyMock.reset(replicaManager) + val heartbeatResult = heartbeat(groupId, assignedMemberId, 1) + assertEquals(Errors.NONE, heartbeatResult) + } + + @Test + def testSessionTimeoutDuringRebalance(): Unit = { + // create a group with a single member + val firstJoinResult = dynamicJoinGroup(groupId, JoinGroupRequest.UNKNOWN_MEMBER_ID, protocolType, protocols, + rebalanceTimeout = 2000, sessionTimeout = 1000) + val firstMemberId = firstJoinResult.memberId + val firstGenerationId = firstJoinResult.generationId + assertEquals(firstMemberId, firstJoinResult.leaderId) + assertEquals(Errors.NONE, firstJoinResult.error) + + EasyMock.reset(replicaManager) + val firstSyncResult = syncGroupLeader(groupId, firstGenerationId, firstMemberId, Map(firstMemberId -> Array[Byte]())) + assertEquals(Errors.NONE, 
firstSyncResult.error) + + // now have a new member join to trigger a rebalance + EasyMock.reset(replicaManager) + val otherJoinFuture = sendJoinGroup(groupId, JoinGroupRequest.UNKNOWN_MEMBER_ID, protocolType, protocols) + + timer.advanceClock(500) + + EasyMock.reset(replicaManager) + var heartbeatResult = heartbeat(groupId, firstMemberId, firstGenerationId) + assertEquals(Errors.REBALANCE_IN_PROGRESS, heartbeatResult) + + // letting the session expire should make the member fall out of the group + timer.advanceClock(1100) + + EasyMock.reset(replicaManager) + heartbeatResult = heartbeat(groupId, firstMemberId, firstGenerationId) + assertEquals(Errors.UNKNOWN_MEMBER_ID, heartbeatResult) + + // and the rebalance should complete with only the new member + val otherJoinResult = await(otherJoinFuture, DefaultSessionTimeout+100) + assertEquals(Errors.NONE, otherJoinResult.error) + } + + @Test + def testRebalanceCompletesBeforeMemberJoins(): Unit = { + // create a group with a single member + val firstJoinResult = staticJoinGroup(groupId, JoinGroupRequest.UNKNOWN_MEMBER_ID, leaderInstanceId, protocolType, protocols, + rebalanceTimeout = 1200, sessionTimeout = 1000) + val firstMemberId = firstJoinResult.memberId + val firstGenerationId = firstJoinResult.generationId + assertEquals(firstMemberId, firstJoinResult.leaderId) + assertEquals(Errors.NONE, firstJoinResult.error) + + EasyMock.reset(replicaManager) + val firstSyncResult = syncGroupLeader(groupId, firstGenerationId, firstMemberId, Map(firstMemberId -> Array[Byte]())) + assertEquals(Errors.NONE, firstSyncResult.error) + + // now have a new member join to trigger a rebalance + EasyMock.reset(replicaManager) + val otherMemberSessionTimeout = DefaultSessionTimeout + val otherJoinFuture = sendJoinGroup(groupId, JoinGroupRequest.UNKNOWN_MEMBER_ID, protocolType, protocols) + + // send a couple heartbeats to keep the member alive while the rebalance finishes + var expectedResultList = List(Errors.REBALANCE_IN_PROGRESS, Errors.REBALANCE_IN_PROGRESS) + for (expectedResult <- expectedResultList) { + timer.advanceClock(otherMemberSessionTimeout) + EasyMock.reset(replicaManager) + val heartbeatResult = heartbeat(groupId, firstMemberId, firstGenerationId) + assertEquals(expectedResult, heartbeatResult) + } + + // now timeout the rebalance + timer.advanceClock(otherMemberSessionTimeout) + val otherJoinResult = await(otherJoinFuture, otherMemberSessionTimeout+100) + val otherMemberId = otherJoinResult.memberId + val otherGenerationId = otherJoinResult.generationId + EasyMock.reset(replicaManager) + val syncResult = syncGroupLeader(groupId, otherGenerationId, otherMemberId, Map(otherMemberId -> Array[Byte]())) + assertEquals(Errors.NONE, syncResult.error) + + // the unjoined static member should be remained in the group before session timeout. + assertEquals(Errors.NONE, otherJoinResult.error) + EasyMock.reset(replicaManager) + var heartbeatResult = heartbeat(groupId, firstMemberId, firstGenerationId) + assertEquals(Errors.ILLEGAL_GENERATION, heartbeatResult) + + expectedResultList = List(Errors.NONE, Errors.NONE, Errors.REBALANCE_IN_PROGRESS) + + // now session timeout the unjoined member. Still keeping the new member. 
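+    // Each pass below advances the clock by a full session timeout; the final REBALANCE_IN_PROGRESS is expected once the static member that never rejoined expires and triggers another rebalance.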
+ for (expectedResult <- expectedResultList) { + timer.advanceClock(otherMemberSessionTimeout) + EasyMock.reset(replicaManager) + heartbeatResult = heartbeat(groupId, otherMemberId, otherGenerationId) + assertEquals(expectedResult, heartbeatResult) + } + + EasyMock.reset(replicaManager) + val otherRejoinGroupFuture = sendJoinGroup(groupId, otherMemberId, protocolType, protocols) + val otherReJoinResult = await(otherRejoinGroupFuture, otherMemberSessionTimeout+100) + assertEquals(Errors.NONE, otherReJoinResult.error) + + EasyMock.reset(replicaManager) + val otherRejoinGenerationId = otherReJoinResult.generationId + val reSyncResult = syncGroupLeader(groupId, otherRejoinGenerationId, otherMemberId, Map(otherMemberId -> Array[Byte]())) + assertEquals(Errors.NONE, reSyncResult.error) + + // the joined member should get heart beat response with no error. Let the new member keep heartbeating for a while + // to verify that no new rebalance is triggered unexpectedly + for ( _ <- 1 to 20) { + timer.advanceClock(500) + EasyMock.reset(replicaManager) + heartbeatResult = heartbeat(groupId, otherMemberId, otherRejoinGenerationId) + assertEquals(Errors.NONE, heartbeatResult) + } + } + + @Test + def testSyncGroupEmptyAssignment(): Unit = { + val memberId = JoinGroupRequest.UNKNOWN_MEMBER_ID + + val joinGroupResult = dynamicJoinGroup(groupId, memberId, protocolType, protocols) + val assignedConsumerId = joinGroupResult.memberId + val generationId = joinGroupResult.generationId + val joinGroupError = joinGroupResult.error + assertEquals(Errors.NONE, joinGroupError) + + EasyMock.reset(replicaManager) + val syncGroupResult = syncGroupLeader(groupId, generationId, assignedConsumerId, Map()) + assertEquals(Errors.NONE, syncGroupResult.error) + assertTrue(syncGroupResult.memberAssignment.isEmpty) + + EasyMock.reset(replicaManager) + val heartbeatResult = heartbeat(groupId, assignedConsumerId, 1) + assertEquals(Errors.NONE, heartbeatResult) + } + + @Test + def testSyncGroupNotCoordinator(): Unit = { + val generation = 1 + + val syncGroupResult = syncGroupFollower(otherGroupId, generation, memberId) + assertEquals(Errors.NOT_COORDINATOR, syncGroupResult.error) + } + + @Test + def testSyncGroupFromUnknownGroup(): Unit = { + val syncGroupResult = syncGroupFollower(groupId, 1, memberId) + assertEquals(Errors.UNKNOWN_MEMBER_ID, syncGroupResult.error) + } + + @Test + def testSyncGroupFromUnknownMember(): Unit = { + val memberId = JoinGroupRequest.UNKNOWN_MEMBER_ID + + val joinGroupResult = dynamicJoinGroup(groupId, memberId, protocolType, protocols) + val assignedConsumerId = joinGroupResult.memberId + val generationId = joinGroupResult.generationId + assertEquals(Errors.NONE, joinGroupResult.error) + + EasyMock.reset(replicaManager) + val syncGroupResult = syncGroupLeader(groupId, generationId, assignedConsumerId, Map(assignedConsumerId -> Array[Byte]())) + val syncGroupError = syncGroupResult.error + assertEquals(Errors.NONE, syncGroupError) + + EasyMock.reset(replicaManager) + val unknownMemberId = "blah" + val unknownMemberSyncResult = syncGroupFollower(groupId, generationId, unknownMemberId) + assertEquals(Errors.UNKNOWN_MEMBER_ID, unknownMemberSyncResult.error) + } + + @Test + def testSyncGroupFromIllegalGeneration(): Unit = { + val memberId = JoinGroupRequest.UNKNOWN_MEMBER_ID + + val joinGroupResult = dynamicJoinGroup(groupId, memberId, protocolType, protocols) + val assignedConsumerId = joinGroupResult.memberId + val generationId = joinGroupResult.generationId + assertEquals(Errors.NONE, 
joinGroupResult.error) + + EasyMock.reset(replicaManager) + // send the sync group with an invalid generation + val syncGroupResult = syncGroupLeader(groupId, generationId+1, assignedConsumerId, Map(assignedConsumerId -> Array[Byte]())) + assertEquals(Errors.ILLEGAL_GENERATION, syncGroupResult.error) + } + + @Test + def testJoinGroupFromUnchangedFollowerDoesNotRebalance(): Unit = { + // to get a group of two members: + // 1. join and sync with a single member (because we can't immediately join with two members) + // 2. join and sync with the first member and a new member + + val firstJoinResult = dynamicJoinGroup(groupId, JoinGroupRequest.UNKNOWN_MEMBER_ID, protocolType, protocols) + val firstMemberId = firstJoinResult.memberId + val firstGenerationId = firstJoinResult.generationId + assertEquals(firstMemberId, firstJoinResult.leaderId) + assertEquals(Errors.NONE, firstJoinResult.error) + + EasyMock.reset(replicaManager) + val firstSyncResult = syncGroupLeader(groupId, firstGenerationId, firstMemberId, Map(firstMemberId -> Array[Byte]())) + assertEquals(Errors.NONE, firstSyncResult.error) + + EasyMock.reset(replicaManager) + val otherJoinFuture = sendJoinGroup(groupId, JoinGroupRequest.UNKNOWN_MEMBER_ID, protocolType, protocols) + + EasyMock.reset(replicaManager) + val joinFuture = sendJoinGroup(groupId, firstMemberId, protocolType, protocols) + + val joinResult = await(joinFuture, DefaultSessionTimeout+100) + val otherJoinResult = await(otherJoinFuture, DefaultSessionTimeout+100) + assertEquals(Errors.NONE, joinResult.error) + assertEquals(Errors.NONE, otherJoinResult.error) + assertTrue(joinResult.generationId == otherJoinResult.generationId) + + assertEquals(firstMemberId, joinResult.leaderId) + assertEquals(firstMemberId, otherJoinResult.leaderId) + + val nextGenerationId = joinResult.generationId + + // this shouldn't cause a rebalance since protocol information hasn't changed + EasyMock.reset(replicaManager) + val followerJoinResult = await(sendJoinGroup(groupId, otherJoinResult.memberId, protocolType, protocols), 1) + + assertEquals(Errors.NONE, followerJoinResult.error) + assertEquals(nextGenerationId, followerJoinResult.generationId) + } + + @Test + def testJoinGroupFromUnchangedLeaderShouldRebalance(): Unit = { + val firstJoinResult = dynamicJoinGroup(groupId, JoinGroupRequest.UNKNOWN_MEMBER_ID, protocolType, protocols) + val firstMemberId = firstJoinResult.memberId + val firstGenerationId = firstJoinResult.generationId + assertEquals(firstMemberId, firstJoinResult.leaderId) + assertEquals(Errors.NONE, firstJoinResult.error) + + EasyMock.reset(replicaManager) + val firstSyncResult = syncGroupLeader(groupId, firstGenerationId, firstMemberId, Map(firstMemberId -> Array[Byte]())) + assertEquals(Errors.NONE, firstSyncResult.error) + + // join groups from the leader should force the group to rebalance, which allows the + // leader to push new assignments when local metadata changes + + EasyMock.reset(replicaManager) + val secondJoinResult = await(sendJoinGroup(groupId, firstMemberId, protocolType, protocols), 1) + + assertEquals(Errors.NONE, secondJoinResult.error) + assertNotEquals(firstGenerationId, secondJoinResult.generationId) + } + + /** + * Test if the following scenario completes a rebalance correctly: A new member starts a JoinGroup request with + * an UNKNOWN_MEMBER_ID, attempting to join a stable group. But never initiates the second JoinGroup request with + * the provided member ID and times out. 
The test checks if original member remains the sole member in this group, + * which should remain stable throughout this test. + */ + @Test + def testSecondMemberPartiallyJoinAndTimeout(): Unit = { + val firstJoinResult = dynamicJoinGroup(groupId, JoinGroupRequest.UNKNOWN_MEMBER_ID, protocolType, protocols) + val firstMemberId = firstJoinResult.memberId + val firstGenerationId = firstJoinResult.generationId + assertEquals(firstMemberId, firstJoinResult.leaderId) + assertEquals(Errors.NONE, firstJoinResult.error) + + // Starting sync group leader + EasyMock.reset(replicaManager) + val firstSyncResult = syncGroupLeader(groupId, firstGenerationId, firstMemberId, Map(firstMemberId -> Array[Byte]())) + assertEquals(Errors.NONE, firstSyncResult.error) + timer.advanceClock(100) + assertEquals(Set(firstMemberId), groupCoordinator.groupManager.getGroup(groupId).get.allMembers) + assertEquals(groupCoordinator.groupManager.getGroup(groupId).get.allMembers, + groupCoordinator.groupManager.getGroup(groupId).get.allDynamicMembers) + assertEquals(0, groupCoordinator.groupManager.getGroup(groupId).get.numPending) + val group = groupCoordinator.groupManager.getGroup(groupId).get + + // ensure the group is stable before a new member initiates join request + assertEquals(Stable, group.currentState) + + // new member initiates join group + EasyMock.reset(replicaManager) + val secondJoinResult = joinGroupPartial(groupId, JoinGroupRequest.UNKNOWN_MEMBER_ID, protocolType, protocols) + assertEquals(Errors.MEMBER_ID_REQUIRED, secondJoinResult.error) + assertEquals(1, group.numPending) + assertEquals(Stable, group.currentState) + + EasyMock.reset(replicaManager) + EasyMock.expect(replicaManager.getMagic(EasyMock.anyObject())).andReturn(Some(RecordBatch.MAGIC_VALUE_V1)).anyTimes() + EasyMock.replay(replicaManager) + + // advance clock to timeout the pending member + assertEquals(Set(firstMemberId), group.allMembers) + assertEquals(1, group.numPending) + timer.advanceClock(300) + + // original (firstMember) member sends heartbeats to prevent session timeouts. + EasyMock.reset(replicaManager) + val heartbeatResult = heartbeat(groupId, firstMemberId, 1) + assertEquals(Errors.NONE, heartbeatResult) + + // timeout the pending member + timer.advanceClock(300) + + // at this point the second member should have been removed from pending list (session timeout), + // and the group should be in Stable state with only the first member in it. + assertEquals(Set(firstMemberId), group.allMembers) + assertEquals(0, group.numPending) + assertEquals(Stable, group.currentState) + assertTrue(group.has(firstMemberId)) + } + + /** + * Create a group with two members in Stable state. Create a third pending member by completing it's first JoinGroup + * request without a member id. 
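+   * (A pending member is one that has been handed a member.id by its first JoinGroup round but has not yet rejoined with it.)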
+   */
+  private def setupGroupWithPendingMember(): JoinGroupResult = {
+    // add the first member
+    val joinResult1 = dynamicJoinGroup(groupId, JoinGroupRequest.UNKNOWN_MEMBER_ID, protocolType, protocols)
+    assertGroupState(groupState = CompletingRebalance)
+
+    // now the group is stable, with the one member that joined above
+    EasyMock.reset(replicaManager)
+    val firstSyncResult = syncGroupLeader(groupId, joinResult1.generationId, joinResult1.memberId, Map(joinResult1.memberId -> Array[Byte]()))
+    assertEquals(Errors.NONE, firstSyncResult.error)
+    assertGroupState(groupState = Stable)
+
+    // start the join for the second member
+    EasyMock.reset(replicaManager)
+    val secondJoinFuture = sendJoinGroup(groupId, JoinGroupRequest.UNKNOWN_MEMBER_ID, protocolType, protocols)
+
+    // rejoin the first member back into the group
+    EasyMock.reset(replicaManager)
+    val firstJoinFuture = sendJoinGroup(groupId, joinResult1.memberId, protocolType, protocols)
+    val firstMemberJoinResult = await(firstJoinFuture, DefaultSessionTimeout+100)
+    val secondMemberJoinResult = await(secondJoinFuture, DefaultSessionTimeout+100)
+    assertGroupState(groupState = CompletingRebalance)
+
+    // stabilize the group
+    EasyMock.reset(replicaManager)
+    val secondSyncResult = syncGroupLeader(groupId, firstMemberJoinResult.generationId, joinResult1.memberId, Map(joinResult1.memberId -> Array[Byte]()))
+    assertEquals(Errors.NONE, secondSyncResult.error)
+    assertGroupState(groupState = Stable)
+
+    // re-join an existing member, to transition the group to PreparingRebalance state.
+    EasyMock.reset(replicaManager)
+    sendJoinGroup(groupId, firstMemberJoinResult.memberId, protocolType, protocols)
+    assertGroupState(groupState = PreparingRebalance)
+
+    // create a pending member in the group
+    EasyMock.reset(replicaManager)
+    val pendingMember = joinGroupPartial(groupId, JoinGroupRequest.UNKNOWN_MEMBER_ID, protocolType, protocols, sessionTimeout=100)
+    assertEquals(1, groupCoordinator.groupManager.getGroup(groupId).get.numPending)
+
+    // re-join the second existing member
+    EasyMock.reset(replicaManager)
+    sendJoinGroup(groupId, secondMemberJoinResult.memberId, protocolType, protocols)
+    assertGroupState(groupState = PreparingRebalance)
+    assertEquals(1, groupCoordinator.groupManager.getGroup(groupId).get.numPending)
+
+    pendingMember
+  }
+
+  /**
+   * Set up a group with a pending member. The test checks whether the pending member joining
+   * completes the rebalancing operation.
+   */
+  @Test
+  def testJoinGroupCompletionWhenPendingMemberJoins(): Unit = {
+    val pendingMember = setupGroupWithPendingMember()
+
+    // complete the join group for the pending member
+    EasyMock.reset(replicaManager)
+    val pendingMemberJoinFuture = sendJoinGroup(groupId, pendingMember.memberId, protocolType, protocols)
+    await(pendingMemberJoinFuture, DefaultSessionTimeout+100)
+
+    assertGroupState(groupState = CompletingRebalance)
+    assertEquals(3, group().allMembers.size)
+    assertEquals(0, group().numPending)
+  }
+
+  /**
+   * Set up a group with a pending member. The test checks whether the timeout of the pending member
+   * will cause the group to return to a CompletingRebalance state.
+   */
+  @Test
+  def testJoinGroupCompletionWhenPendingMemberTimesOut(): Unit = {
+    setupGroupWithPendingMember()
+
+    // Advancing the clock by > 100 (session timeout for third and fourth member)
+    // and < 500 (for first and second members). This will force the coordinator to attempt join
+    // completion on heartbeat expiration (since we are in the PreparingRebalance stage).
+ EasyMock.reset(replicaManager) + EasyMock.expect(replicaManager.getMagic(EasyMock.anyObject())).andReturn(Some(RecordBatch.MAGIC_VALUE_V1)).anyTimes() + EasyMock.replay(replicaManager) + timer.advanceClock(120) + + assertGroupState(groupState = CompletingRebalance) + assertEquals(2, group().allMembers.size) + assertEquals(0, group().numPending) + } + + @Test + def testPendingMembersLeavesGroup(): Unit = { + val pending = setupGroupWithPendingMember() + + EasyMock.reset(replicaManager) + val leaveGroupResults = singleLeaveGroup(groupId, pending.memberId) + verifyLeaveGroupResult(leaveGroupResults) + + assertGroupState(groupState = CompletingRebalance) + assertEquals(2, group().allMembers.size) + assertEquals(2, group().allDynamicMembers.size) + assertEquals(0, group().numPending) + } + + private def verifyHeartbeat( + joinGroupResult: JoinGroupResult, + expectedError: Errors + ): Unit = { + EasyMock.reset(replicaManager) + val heartbeatResult = heartbeat( + groupId, + joinGroupResult.memberId, + joinGroupResult.generationId + ) + assertEquals(expectedError, heartbeatResult) + } + + private def joinWithNMembers(nbMembers: Int): Seq[JoinGroupResult] = { + val requiredKnownMemberId = true + + // First JoinRequests + var futures = 1.to(nbMembers).map { _ => + EasyMock.reset(replicaManager) + sendJoinGroup(groupId, JoinGroupRequest.UNKNOWN_MEMBER_ID, protocolType, protocols, + None, DefaultSessionTimeout, DefaultRebalanceTimeout, requiredKnownMemberId) + } + + // Get back the assigned member ids + val memberIds = futures.map(await(_, 1).memberId) + + // Second JoinRequests + futures = memberIds.map { memberId => + EasyMock.reset(replicaManager) + sendJoinGroup(groupId, memberId, protocolType, protocols, + None, DefaultSessionTimeout, DefaultRebalanceTimeout, requiredKnownMemberId) + } + + timer.advanceClock(GroupInitialRebalanceDelay + 1) + timer.advanceClock(DefaultRebalanceTimeout + 1) + + futures.map(await(_, 1)) + } + + @Test + def testRebalanceTimesOutWhenSyncRequestIsNotReceived(): Unit = { + // This test case ensures that the DelayedSync does kick out all members + // if they don't send a sync request before the rebalance timeout. The + // group is still in the CompletingRebalance state in this case. + val results = joinWithNMembers(nbMembers = 3) + assertEquals(Set(Errors.NONE), results.map(_.error).toSet) + + // Advance time + timer.advanceClock(DefaultRebalanceTimeout / 2) + + // Heartbeats to ensure that heartbeating does not interfere with the + // delayed sync operation. + results.foreach { joinGroupResult => + verifyHeartbeat(joinGroupResult, Errors.NONE) + } + + // Advance past the rebalance timeout to trigger the delayed operation. + EasyMock.reset(replicaManager) + EasyMock.expect(replicaManager.getMagic(EasyMock.anyObject())) + .andReturn(Some(RecordBatch.MAGIC_VALUE_V1)) + .anyTimes() + EasyMock.replay(replicaManager) + + timer.advanceClock(DefaultRebalanceTimeout / 2 + 1) + + // Heartbeats fail because none of the members have sent the sync request + results.foreach { joinGroupResult => + verifyHeartbeat(joinGroupResult, Errors.UNKNOWN_MEMBER_ID) + } + } + + @Test + def testRebalanceTimesOutWhenSyncRequestIsNotReceivedFromFollowers(): Unit = { + // This test case ensures that the DelayedSync does kick out the followers + // if they don't send a sync request before the rebalance timeout. The + // group is in the Stable state in this case.
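+ // Only the leader completes its SyncGroup below; the followers never do, so they are the ones removed once the rebalance timeout expires.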
+ val results = joinWithNMembers(nbMembers = 3) + assertEquals(Set(Errors.NONE), results.map(_.error).toSet) + + // Advance time + timer.advanceClock(DefaultRebalanceTimeout / 2) + + // Heartbeats to ensure that heartbeating does not interfere with the + // delayed sync operation. + results.foreach { joinGroupResult => + verifyHeartbeat(joinGroupResult, Errors.NONE) + } + + // Leader sends Sync + EasyMock.reset(replicaManager) + val assignments = results.map(result => result.memberId -> Array.empty[Byte]).toMap + val leaderResult = sendSyncGroupLeader(groupId, results.head.generationId, results.head.memberId, + Some(protocolType), Some(protocolName), None, assignments) + + assertEquals(Errors.NONE, await(leaderResult, 1).error) + + // Leader should be able to heartbeat + verifyHeartbeat(results.head, Errors.NONE) + + // Advance past the rebalance timeout to trigger the delayed operation. + timer.advanceClock(DefaultRebalanceTimeout / 2 + 1) + + // Leader should be able to heartbeat + verifyHeartbeat(results.head, Errors.REBALANCE_IN_PROGRESS) + + // Followers should have been removed. + results.tail.foreach { joinGroupResult => + verifyHeartbeat(joinGroupResult, Errors.UNKNOWN_MEMBER_ID) + } + } + + @Test + def testRebalanceTimesOutWhenSyncRequestIsNotReceivedFromLeaders(): Unit = { + // This test case ensures that the DelayedSync does kick out the leader + // if it does not send a sync request before the rebalance timeout. The + // group is in the CompletingRebalance state in this case. + val results = joinWithNMembers(nbMembers = 3) + assertEquals(Set(Errors.NONE), results.map(_.error).toSet) + + // Advance time + timer.advanceClock(DefaultRebalanceTimeout / 2) + + // Heartbeats to ensure that heartbeating does not interfere with the + // delayed sync operation. + results.foreach { joinGroupResult => + verifyHeartbeat(joinGroupResult, Errors.NONE) + } + + // Followers send Sync + EasyMock.reset(replicaManager) + val followerResults = results.tail.map { joinGroupResult => + EasyMock.reset(replicaManager) + sendSyncGroupFollower(groupId, joinGroupResult.generationId, joinGroupResult.memberId, + Some(protocolType), Some(protocolName), None) + } + + // Advance past the rebalance timeout to trigger the delayed operation. + timer.advanceClock(DefaultRebalanceTimeout / 2 + 1) + + val followerErrors = followerResults.map(await(_, 1).error) + assertEquals(Set(Errors.REBALANCE_IN_PROGRESS), followerErrors.toSet) + + // Leader should have been removed. + verifyHeartbeat(results.head, Errors.UNKNOWN_MEMBER_ID) + + // Followers should be able to heartbeat. + results.tail.foreach { joinGroupResult => + verifyHeartbeat(joinGroupResult, Errors.REBALANCE_IN_PROGRESS) + } + } + + @Test + def testRebalanceDoesNotTimeOutWhenAllSyncAreReceived(): Unit = { + // This test case ensures that the DelayedSync does not kick any + // members out when they have all sent their sync requests. + val results = joinWithNMembers(nbMembers = 3) + assertEquals(Set(Errors.NONE), results.map(_.error).toSet) + + // Advance time + timer.advanceClock(DefaultRebalanceTimeout / 2) + + // Heartbeats to ensure that heartbeating does not interfere with the + // delayed sync operation.
+ results.foreach { joinGroupResult => + verifyHeartbeat(joinGroupResult, Errors.NONE) + } + + EasyMock.reset(replicaManager) + val assignments = results.map(result => result.memberId -> Array.empty[Byte]).toMap + val leaderResult = sendSyncGroupLeader(groupId, results.head.generationId, results.head.memberId, + Some(protocolType), Some(protocolName), None, assignments) + + assertEquals(Errors.NONE, await(leaderResult, 1).error) + + // Followers send Sync + EasyMock.reset(replicaManager) + val followerResults = results.tail.map { joinGroupResult => + EasyMock.reset(replicaManager) + sendSyncGroupFollower(groupId, joinGroupResult.generationId, joinGroupResult.memberId, + Some(protocolType), Some(protocolName), None) + } + + val followerErrors = followerResults.map(await(_, 1).error) + assertEquals(Set(Errors.NONE), followerErrors.toSet) + + // Advance past the rebalance timeout to expire the Sync timeout. All + // members should remain and the group should not rebalance. + timer.advanceClock(DefaultRebalanceTimeout / 2 + 1) + + // All members should be able to heartbeat. + results.foreach { joinGroupResult => + verifyHeartbeat(joinGroupResult, Errors.NONE) + } + + // Advance a bit more. + timer.advanceClock(DefaultRebalanceTimeout / 2) + + // All members should be able to heartbeat. + results.foreach { joinGroupResult => + verifyHeartbeat(joinGroupResult, Errors.NONE) + } + } + + private def group(groupId: String = groupId) = { + groupCoordinator.groupManager.getGroup(groupId) match { + case Some(g) => g + case None => null + } + } + + private def assertGroupState(groupId: String = groupId, + groupState: GroupState): Unit = { + groupCoordinator.groupManager.getGroup(groupId) match { + case Some(group) => assertEquals(groupState, group.currentState) + case None => fail(s"Group $groupId not found in coordinator") + } + } + + private def joinGroupPartial(groupId: String, + memberId: String, + protocolType: String, + protocols: List[(String, Array[Byte])], + sessionTimeout: Int = DefaultSessionTimeout, + rebalanceTimeout: Int = DefaultRebalanceTimeout): JoinGroupResult = { + val requireKnownMemberId = true + val responseFuture = sendJoinGroup(groupId, memberId, protocolType, protocols, None, sessionTimeout, rebalanceTimeout, requireKnownMemberId) + Await.result(responseFuture, Duration(rebalanceTimeout + 100, TimeUnit.MILLISECONDS)) + } + + @Test + def testLeaderFailureInSyncGroup(): Unit = { + // to get a group of two members: + // 1. join and sync with a single member (because we can't immediately join with two members) + // 2.
join and sync with the first member and a new member + + val firstJoinResult = dynamicJoinGroup(groupId, JoinGroupRequest.UNKNOWN_MEMBER_ID, protocolType, protocols) + val firstMemberId = firstJoinResult.memberId + val firstGenerationId = firstJoinResult.generationId + assertEquals(firstMemberId, firstJoinResult.leaderId) + assertEquals(Errors.NONE, firstJoinResult.error) + + EasyMock.reset(replicaManager) + val firstSyncResult = syncGroupLeader(groupId, firstGenerationId, firstMemberId, Map(firstMemberId -> Array[Byte]())) + assertEquals(Errors.NONE, firstSyncResult.error) + + EasyMock.reset(replicaManager) + val otherJoinFuture = sendJoinGroup(groupId, JoinGroupRequest.UNKNOWN_MEMBER_ID, protocolType, protocols) + + EasyMock.reset(replicaManager) + val joinFuture = sendJoinGroup(groupId, firstMemberId, protocolType, protocols) + + val joinResult = await(joinFuture, DefaultSessionTimeout+100) + val otherJoinResult = await(otherJoinFuture, DefaultSessionTimeout+100) + assertEquals(Errors.NONE, joinResult.error) + assertEquals(Errors.NONE, otherJoinResult.error) + assertTrue(joinResult.generationId == otherJoinResult.generationId) + + assertEquals(firstMemberId, joinResult.leaderId) + assertEquals(firstMemberId, otherJoinResult.leaderId) + + val nextGenerationId = joinResult.generationId + + // with no leader SyncGroup, the follower's request should fail with an error indicating + // that it should rejoin + EasyMock.reset(replicaManager) + val followerSyncFuture = sendSyncGroupFollower(groupId, nextGenerationId, otherJoinResult.memberId, None, None, None) + + timer.advanceClock(DefaultSessionTimeout + 100) + + val followerSyncResult = await(followerSyncFuture, DefaultSessionTimeout+100) + assertEquals(Errors.REBALANCE_IN_PROGRESS, followerSyncResult.error) + } + + @Test + def testSyncGroupFollowerAfterLeader(): Unit = { + // to get a group of two members: + // 1. join and sync with a single member (because we can't immediately join with two members) + // 2. 
join and sync with the first member and a new member + + val firstJoinResult = dynamicJoinGroup(groupId, JoinGroupRequest.UNKNOWN_MEMBER_ID, protocolType, protocols) + val firstMemberId = firstJoinResult.memberId + val firstGenerationId = firstJoinResult.generationId + assertEquals(firstMemberId, firstJoinResult.leaderId) + assertEquals(Errors.NONE, firstJoinResult.error) + + EasyMock.reset(replicaManager) + val firstSyncResult = syncGroupLeader(groupId, firstGenerationId, firstMemberId, Map(firstMemberId -> Array[Byte]())) + assertEquals(Errors.NONE, firstSyncResult.error) + + EasyMock.reset(replicaManager) + val otherJoinFuture = sendJoinGroup(groupId, JoinGroupRequest.UNKNOWN_MEMBER_ID, protocolType, protocols) + + EasyMock.reset(replicaManager) + val joinFuture = sendJoinGroup(groupId, firstMemberId, protocolType, protocols) + + val joinResult = await(joinFuture, DefaultSessionTimeout+100) + val otherJoinResult = await(otherJoinFuture, DefaultSessionTimeout+100) + assertEquals(Errors.NONE, joinResult.error) + assertEquals(Errors.NONE, otherJoinResult.error) + assertTrue(joinResult.generationId == otherJoinResult.generationId) + + assertEquals(firstMemberId, joinResult.leaderId) + assertEquals(firstMemberId, otherJoinResult.leaderId) + + val nextGenerationId = joinResult.generationId + val leaderId = firstMemberId + val leaderAssignment = Array[Byte](0) + val followerId = otherJoinResult.memberId + val followerAssignment = Array[Byte](1) + + EasyMock.reset(replicaManager) + val leaderSyncResult = syncGroupLeader(groupId, nextGenerationId, leaderId, + Map(leaderId -> leaderAssignment, followerId -> followerAssignment)) + assertEquals(Errors.NONE, leaderSyncResult.error) + assertEquals(leaderAssignment, leaderSyncResult.memberAssignment) + + EasyMock.reset(replicaManager) + val followerSyncResult = syncGroupFollower(groupId, nextGenerationId, otherJoinResult.memberId) + assertEquals(Errors.NONE, followerSyncResult.error) + assertEquals(followerAssignment, followerSyncResult.memberAssignment) + } + + @Test + def testSyncGroupLeaderAfterFollower(): Unit = { + // to get a group of two members: + // 1. join and sync with a single member (because we can't immediately join with two members) + // 2. 
join and sync with the first member and a new member + + val joinGroupResult = dynamicJoinGroup(groupId, JoinGroupRequest.UNKNOWN_MEMBER_ID, protocolType, protocols) + val firstMemberId = joinGroupResult.memberId + val firstGenerationId = joinGroupResult.generationId + assertEquals(firstMemberId, joinGroupResult.leaderId) + assertEquals(Errors.NONE, joinGroupResult.error) + + EasyMock.reset(replicaManager) + val syncGroupResult = syncGroupLeader(groupId, firstGenerationId, firstMemberId, Map(firstMemberId -> Array[Byte]())) + assertEquals(Errors.NONE, syncGroupResult.error) + + EasyMock.reset(replicaManager) + val otherJoinFuture = sendJoinGroup(groupId, JoinGroupRequest.UNKNOWN_MEMBER_ID, protocolType, protocols) + + EasyMock.reset(replicaManager) + val joinFuture = sendJoinGroup(groupId, firstMemberId, protocolType, protocols) + + val joinResult = await(joinFuture, DefaultSessionTimeout+100) + val otherJoinResult = await(otherJoinFuture, DefaultSessionTimeout+100) + assertEquals(Errors.NONE, joinResult.error) + assertEquals(Errors.NONE, otherJoinResult.error) + assertTrue(joinResult.generationId == otherJoinResult.generationId) + + val nextGenerationId = joinResult.generationId + val leaderId = joinResult.leaderId + val leaderAssignment = Array[Byte](0) + val followerId = otherJoinResult.memberId + val followerAssignment = Array[Byte](1) + + assertEquals(firstMemberId, joinResult.leaderId) + assertEquals(firstMemberId, otherJoinResult.leaderId) + + EasyMock.reset(replicaManager) + val followerSyncFuture = sendSyncGroupFollower(groupId, nextGenerationId, followerId, None, None, None) + + EasyMock.reset(replicaManager) + val leaderSyncResult = syncGroupLeader(groupId, nextGenerationId, leaderId, + Map(leaderId -> leaderAssignment, followerId -> followerAssignment)) + assertEquals(Errors.NONE, leaderSyncResult.error) + assertEquals(leaderAssignment, leaderSyncResult.memberAssignment) + + val followerSyncResult = await(followerSyncFuture, DefaultSessionTimeout+100) + assertEquals(Errors.NONE, followerSyncResult.error) + assertEquals(followerAssignment, followerSyncResult.memberAssignment) + } + + @Test + def testCommitOffsetFromUnknownGroup(): Unit = { + val generationId = 1 + val tp = new TopicPartition("topic", 0) + val offset = offsetAndMetadata(0) + + val commitOffsetResult = commitOffsets(groupId, memberId, generationId, Map(tp -> offset)) + assertEquals(Errors.ILLEGAL_GENERATION, commitOffsetResult(tp)) + } + + @Test + def testCommitOffsetWithDefaultGeneration(): Unit = { + val tp = new TopicPartition("topic", 0) + val offset = offsetAndMetadata(0) + + val commitOffsetResult = commitOffsets(groupId, OffsetCommitRequest.DEFAULT_MEMBER_ID, + OffsetCommitRequest.DEFAULT_GENERATION_ID, Map(tp -> offset)) + assertEquals(Errors.NONE, commitOffsetResult(tp)) + } + + @Test + def testCommitOffsetsAfterGroupIsEmpty(): Unit = { + // Tests the scenario where the reset offset tool modifies the offsets + // of a group after it becomes empty + + // A group member joins + val memberId = JoinGroupRequest.UNKNOWN_MEMBER_ID + val joinGroupResult = dynamicJoinGroup(groupId, memberId, protocolType, protocols) + val assignedMemberId = joinGroupResult.memberId + val joinGroupError = joinGroupResult.error + assertEquals(Errors.NONE, joinGroupError) + + // and leaves. 
+ EasyMock.reset(replicaManager) + val leaveGroupResults = singleLeaveGroup(groupId, assignedMemberId) + verifyLeaveGroupResult(leaveGroupResults) + + // The simple offset commit should still succeed + EasyMock.reset(replicaManager) + val tp = new TopicPartition("topic", 0) + val offset = offsetAndMetadata(0) + val commitOffsetResult = commitOffsets(groupId, OffsetCommitRequest.DEFAULT_MEMBER_ID, + OffsetCommitRequest.DEFAULT_GENERATION_ID, Map(tp -> offset)) + assertEquals(Errors.NONE, commitOffsetResult(tp)) + + val (error, partitionData) = groupCoordinator.handleFetchOffsets(groupId, requireStable, Some(Seq(tp))) + assertEquals(Errors.NONE, error) + assertEquals(Some(0), partitionData.get(tp).map(_.offset)) + } + + @Test + def testFetchOffsets(): Unit = { + val tp = new TopicPartition("topic", 0) + val offset = 97L + val metadata = "some metadata" + val leaderEpoch = Optional.of[Integer](15) + val offsetAndMetadata = OffsetAndMetadata(offset, leaderEpoch, metadata, timer.time.milliseconds(), None) + + val commitOffsetResult = commitOffsets(groupId, OffsetCommitRequest.DEFAULT_MEMBER_ID, + OffsetCommitRequest.DEFAULT_GENERATION_ID, Map(tp -> offsetAndMetadata)) + assertEquals(Errors.NONE, commitOffsetResult(tp)) + + val (error, partitionData) = groupCoordinator.handleFetchOffsets(groupId, requireStable, Some(Seq(tp))) + assertEquals(Errors.NONE, error) + + val maybePartitionData = partitionData.get(tp) + assertTrue(maybePartitionData.isDefined) + assertEquals(offset, maybePartitionData.get.offset) + assertEquals(metadata, maybePartitionData.get.metadata) + assertEquals(leaderEpoch, maybePartitionData.get.leaderEpoch) + } + + @Test + def testCommitAndFetchOffsetsWithEmptyGroup(): Unit = { + // For backwards compatibility, the coordinator supports committing/fetching offsets with an empty groupId.
+ // To allow inspection and removal of the empty group, we must also support DescribeGroups and DeleteGroups + + val tp = new TopicPartition("topic", 0) + val offset = offsetAndMetadata(0) + val groupId = "" + + val commitOffsetResult = commitOffsets(groupId, OffsetCommitRequest.DEFAULT_MEMBER_ID, + OffsetCommitRequest.DEFAULT_GENERATION_ID, Map(tp -> offset)) + assertEquals(Errors.NONE, commitOffsetResult(tp)) + + val (fetchError, partitionData) = groupCoordinator.handleFetchOffsets(groupId, requireStable, Some(Seq(tp))) + assertEquals(Errors.NONE, fetchError) + assertEquals(Some(0), partitionData.get(tp).map(_.offset)) + + val (describeError, summary) = groupCoordinator.handleDescribeGroup(groupId) + assertEquals(Errors.NONE, describeError) + assertEquals(Empty.toString, summary.state) + + val groupTopicPartition = new TopicPartition(Topic.GROUP_METADATA_TOPIC_NAME, groupPartitionId) + val partition: Partition = EasyMock.niceMock(classOf[Partition]) + + EasyMock.reset(replicaManager) + EasyMock.expect(replicaManager.getMagic(EasyMock.anyObject())).andStubReturn(Some(RecordBatch.CURRENT_MAGIC_VALUE)) + EasyMock.expect(replicaManager.getPartition(groupTopicPartition)).andStubReturn(HostedPartition.Online(partition)) + EasyMock.expect(replicaManager.onlinePartition(groupTopicPartition)).andStubReturn(Some(partition)) + EasyMock.replay(replicaManager, partition) + + val deleteErrors = groupCoordinator.handleDeleteGroups(Set(groupId)) + assertEquals(Errors.NONE, deleteErrors(groupId)) + + val (err, data) = groupCoordinator.handleFetchOffsets(groupId, requireStable, Some(Seq(tp))) + assertEquals(Errors.NONE, err) + assertEquals(Some(OffsetFetchResponse.INVALID_OFFSET), data.get(tp).map(_.offset)) + } + + @Test + def testBasicFetchTxnOffsets(): Unit = { + val tp = new TopicPartition("topic", 0) + val offset = offsetAndMetadata(0) + val producerId = 1000L + val producerEpoch : Short = 2 + + val commitOffsetResult = commitTransactionalOffsets(groupId, producerId, producerEpoch, Map(tp -> offset)) + assertEquals(Errors.NONE, commitOffsetResult(tp)) + + val (error, partitionData) = groupCoordinator.handleFetchOffsets(groupId, requireStable, Some(Seq(tp))) + + // Validate that the offset isn't materialized yet. + assertEquals(Errors.NONE, error) + assertEquals(Some(OffsetFetchResponse.INVALID_OFFSET), partitionData.get(tp).map(_.offset)) + + val offsetsTopic = new TopicPartition(Topic.GROUP_METADATA_TOPIC_NAME, groupPartitionId) + + // Send commit marker. + handleTxnCompletion(producerId, List(offsetsTopic), TransactionResult.COMMIT) + + // Validate that the committed offset is materialized.
+ val (secondReqError, secondReqPartitionData) = groupCoordinator.handleFetchOffsets(groupId, requireStable, Some(Seq(tp))) + assertEquals(Errors.NONE, secondReqError) + assertEquals(Some(0), secondReqPartitionData.get(tp).map(_.offset)) + } + + @Test + def testFetchTxnOffsetsWithAbort(): Unit = { + val tp = new TopicPartition("topic", 0) + val offset = offsetAndMetadata(0) + val producerId = 1000L + val producerEpoch : Short = 2 + + val commitOffsetResult = commitTransactionalOffsets(groupId, producerId, producerEpoch, Map(tp -> offset)) + assertEquals(Errors.NONE, commitOffsetResult(tp)) + + val (error, partitionData) = groupCoordinator.handleFetchOffsets(groupId, requireStable, Some(Seq(tp))) + assertEquals(Errors.NONE, error) + assertEquals(Some(OffsetFetchResponse.INVALID_OFFSET), partitionData.get(tp).map(_.offset)) + + val offsetsTopic = new TopicPartition(Topic.GROUP_METADATA_TOPIC_NAME, groupPartitionId) + + // Validate that the pending commit is discarded. + handleTxnCompletion(producerId, List(offsetsTopic), TransactionResult.ABORT) + + val (secondReqError, secondReqPartitionData) = groupCoordinator.handleFetchOffsets(groupId, requireStable, Some(Seq(tp))) + assertEquals(Errors.NONE, secondReqError) + assertEquals(Some(OffsetFetchResponse.INVALID_OFFSET), secondReqPartitionData.get(tp).map(_.offset)) + } + + @Test + def testFetchPendingTxnOffsetsWithAbort(): Unit = { + val tp = new TopicPartition("topic", 0) + val offset = offsetAndMetadata(0) + val producerId = 1000L + val producerEpoch : Short = 2 + + val commitOffsetResult = commitTransactionalOffsets(groupId, producerId, producerEpoch, Map(tp -> offset)) + assertEquals(Errors.NONE, commitOffsetResult(tp)) + + val nonExistTp = new TopicPartition("non-exist-topic", 0) + val (error, partitionData) = groupCoordinator.handleFetchOffsets(groupId, requireStable, Some(Seq(tp, nonExistTp))) + assertEquals(Errors.NONE, error) + assertEquals(Some(OffsetFetchResponse.INVALID_OFFSET), partitionData.get(tp).map(_.offset)) + assertEquals(Some(Errors.UNSTABLE_OFFSET_COMMIT), partitionData.get(tp).map(_.error)) + assertEquals(Some(OffsetFetchResponse.INVALID_OFFSET), partitionData.get(nonExistTp).map(_.offset)) + assertEquals(Some(Errors.NONE), partitionData.get(nonExistTp).map(_.error)) + + val offsetsTopic = new TopicPartition(Topic.GROUP_METADATA_TOPIC_NAME, groupPartitionId) + + // Validate that the pending commit is discarded. 
+ handleTxnCompletion(producerId, List(offsetsTopic), TransactionResult.ABORT) + + val (secondReqError, secondReqPartitionData) = groupCoordinator.handleFetchOffsets(groupId, requireStable, Some(Seq(tp))) + assertEquals(Errors.NONE, secondReqError) + assertEquals(Some(OffsetFetchResponse.INVALID_OFFSET), secondReqPartitionData.get(tp).map(_.offset)) + assertEquals(Some(Errors.NONE), secondReqPartitionData.get(tp).map(_.error)) + } + + @Test + def testFetchPendingTxnOffsetsWithCommit(): Unit = { + val tp = new TopicPartition("topic", 0) + val offset = offsetAndMetadata(25) + val producerId = 1000L + val producerEpoch : Short = 2 + + val commitOffsetResult = commitTransactionalOffsets(groupId, producerId, producerEpoch, Map(tp -> offset)) + assertEquals(Errors.NONE, commitOffsetResult(tp)) + + val (error, partitionData) = groupCoordinator.handleFetchOffsets(groupId, requireStable, Some(Seq(tp))) + assertEquals(Errors.NONE, error) + assertEquals(Some(OffsetFetchResponse.INVALID_OFFSET), partitionData.get(tp).map(_.offset)) + assertEquals(Some(Errors.UNSTABLE_OFFSET_COMMIT), partitionData.get(tp).map(_.error)) + + val offsetsTopic = new TopicPartition(Topic.GROUP_METADATA_TOPIC_NAME, groupPartitionId) + + // Validate that the pending commit is committed + handleTxnCompletion(producerId, List(offsetsTopic), TransactionResult.COMMIT) + + val (secondReqError, secondReqPartitionData) = groupCoordinator.handleFetchOffsets(groupId, requireStable, Some(Seq(tp))) + assertEquals(Errors.NONE, secondReqError) + assertEquals(Some(25), secondReqPartitionData.get(tp).map(_.offset)) + assertEquals(Some(Errors.NONE), secondReqPartitionData.get(tp).map(_.error)) + } + + @Test + def testFetchTxnOffsetsIgnoreSpuriousCommit(): Unit = { + val tp = new TopicPartition("topic", 0) + val offset = offsetAndMetadata(0) + val producerId = 1000L + val producerEpoch : Short = 2 + + val commitOffsetResult = commitTransactionalOffsets(groupId, producerId, producerEpoch, Map(tp -> offset)) + assertEquals(Errors.NONE, commitOffsetResult(tp)) + + val (error, partitionData) = groupCoordinator.handleFetchOffsets(groupId, requireStable, Some(Seq(tp))) + assertEquals(Errors.NONE, error) + assertEquals(Some(OffsetFetchResponse.INVALID_OFFSET), partitionData.get(tp).map(_.offset)) + + val offsetsTopic = new TopicPartition(Topic.GROUP_METADATA_TOPIC_NAME, groupPartitionId) + handleTxnCompletion(producerId, List(offsetsTopic), TransactionResult.ABORT) + + val (secondReqError, secondReqPartitionData) = groupCoordinator.handleFetchOffsets(groupId, requireStable, Some(Seq(tp))) + assertEquals(Errors.NONE, secondReqError) + assertEquals(Some(OffsetFetchResponse.INVALID_OFFSET), secondReqPartitionData.get(tp).map(_.offset)) + + // Ignore spurious commit. + handleTxnCompletion(producerId, List(offsetsTopic), TransactionResult.COMMIT) + + val (thirdReqError, thirdReqPartitionData) = groupCoordinator.handleFetchOffsets(groupId, requireStable, Some(Seq(tp))) + assertEquals(Errors.NONE, thirdReqError) + assertEquals(Some(OffsetFetchResponse.INVALID_OFFSET), thirdReqPartitionData.get(tp).map(_.offset)) + } + + @Test + def testFetchTxnOffsetsOneProducerMultipleGroups(): Unit = { + // One producer, two groups located on separate offsets topic partitions. + // Both groups have pending offset commits. + // Marker for only one partition is received. That commit should be materialized while the other should not.
+ + val partitions = List(new TopicPartition("topic1", 0), new TopicPartition("topic2", 0)) + val offsets = List(offsetAndMetadata(10), offsetAndMetadata(15)) + val producerId = 1000L + val producerEpoch: Short = 3 + + val groupIds = List(groupId, otherGroupId) + val offsetTopicPartitions = List(new TopicPartition(Topic.GROUP_METADATA_TOPIC_NAME, groupCoordinator.partitionFor(groupId)), + new TopicPartition(Topic.GROUP_METADATA_TOPIC_NAME, groupCoordinator.partitionFor(otherGroupId))) + + groupCoordinator.groupManager.addPartitionOwnership(offsetTopicPartitions(1).partition) + val errors = mutable.ArrayBuffer[Errors]() + val partitionData = mutable.ArrayBuffer[scala.collection.Map[TopicPartition, OffsetFetchResponse.PartitionData]]() + + val commitOffsetResults = mutable.ArrayBuffer[CommitOffsetCallbackParams]() + + // Ensure that the two groups map to different partitions. + assertNotEquals(offsetTopicPartitions(0), offsetTopicPartitions(1)) + + commitOffsetResults.append(commitTransactionalOffsets(groupId, producerId, producerEpoch, Map(partitions(0) -> offsets(0)))) + assertEquals(Errors.NONE, commitOffsetResults(0)(partitions(0))) + commitOffsetResults.append(commitTransactionalOffsets(otherGroupId, producerId, producerEpoch, Map(partitions(1) -> offsets(1)))) + assertEquals(Errors.NONE, commitOffsetResults(1)(partitions(1))) + + // We got a commit for only one __consumer_offsets partition. We should only materialize its group offsets. + handleTxnCompletion(producerId, List(offsetTopicPartitions(0)), TransactionResult.COMMIT) + groupCoordinator.handleFetchOffsets(groupIds(0), requireStable, Some(partitions)) match { + case (error, partData) => + errors.append(error) + partitionData.append(partData) + case _ => + } + + groupCoordinator.handleFetchOffsets(groupIds(1), requireStable, Some(partitions)) match { + case (error, partData) => + errors.append(error) + partitionData.append(partData) + case _ => + } + + assertEquals(2, errors.size) + assertEquals(Errors.NONE, errors(0)) + assertEquals(Errors.NONE, errors(1)) + + // Exactly one offset commit should have been materialized. + assertEquals(Some(offsets(0).offset), partitionData(0).get(partitions(0)).map(_.offset)) + assertEquals(Some(OffsetFetchResponse.INVALID_OFFSET), partitionData(0).get(partitions(1)).map(_.offset)) + assertEquals(Some(OffsetFetchResponse.INVALID_OFFSET), partitionData(1).get(partitions(0)).map(_.offset)) + assertEquals(Some(OffsetFetchResponse.INVALID_OFFSET), partitionData(1).get(partitions(1)).map(_.offset)) + + // Now we receive the other marker.
+ handleTxnCompletion(producerId, List(offsetTopicPartitions(1)), TransactionResult.COMMIT) + errors.clear() + partitionData.clear() + groupCoordinator.handleFetchOffsets(groupIds(0), requireStable, Some(partitions)) match { + case (error, partData) => + errors.append(error) + partitionData.append(partData) + case _ => + } + + groupCoordinator.handleFetchOffsets(groupIds(1), requireStable, Some(partitions)) match { + case (error, partData) => + errors.append(error) + partitionData.append(partData) + case _ => + } + // Two offsets should have been materialized + assertEquals(Some(offsets(0).offset), partitionData(0).get(partitions(0)).map(_.offset)) + assertEquals(Some(OffsetFetchResponse.INVALID_OFFSET), partitionData(0).get(partitions(1)).map(_.offset)) + assertEquals(Some(OffsetFetchResponse.INVALID_OFFSET), partitionData(1).get(partitions(0)).map(_.offset)) + assertEquals(Some(offsets(1).offset), partitionData(1).get(partitions(1)).map(_.offset)) + } + + @Test + def testFetchTxnOffsetsMultipleProducersOneGroup(): Unit = { + // One group, two producers + // Different producers will commit offsets for different partitions. + // Each partition's offsets should be materialized when the corresponding producer's marker is received. + + val partitions = List(new TopicPartition("topic1", 0), new TopicPartition("topic2", 0)) + val offsets = List(offsetAndMetadata(10), offsetAndMetadata(15)) + val producerIds = List(1000L, 1005L) + val producerEpochs: Seq[Short] = List(3, 4) + + val offsetTopicPartition = new TopicPartition(Topic.GROUP_METADATA_TOPIC_NAME, groupCoordinator.partitionFor(groupId)) + + val errors = mutable.ArrayBuffer[Errors]() + val partitionData = mutable.ArrayBuffer[scala.collection.Map[TopicPartition, OffsetFetchResponse.PartitionData]]() + + val commitOffsetResults = mutable.ArrayBuffer[CommitOffsetCallbackParams]() + + // producer0 commits the offsets for partition0 + commitOffsetResults.append(commitTransactionalOffsets(groupId, producerIds(0), producerEpochs(0), Map(partitions(0) -> offsets(0)))) + assertEquals(Errors.NONE, commitOffsetResults(0)(partitions(0))) + + // producer1 commits the offsets for partition1 + commitOffsetResults.append(commitTransactionalOffsets(groupId, producerIds(1), producerEpochs(1), Map(partitions(1) -> offsets(1)))) + assertEquals(Errors.NONE, commitOffsetResults(1)(partitions(1))) + + // producer0 commits its transaction. + handleTxnCompletion(producerIds(0), List(offsetTopicPartition), TransactionResult.COMMIT) + groupCoordinator.handleFetchOffsets(groupId, requireStable, Some(partitions)) match { + case (error, partData) => + errors.append(error) + partitionData.append(partData) + case _ => + } + + assertEquals(Errors.NONE, errors(0)) + + // We should only see the offset commit for producer0 + assertEquals(Some(offsets(0).offset), partitionData(0).get(partitions(0)).map(_.offset)) + assertEquals(Some(OffsetFetchResponse.INVALID_OFFSET), partitionData(0).get(partitions(1)).map(_.offset)) + + // producer1 now commits its transaction. + handleTxnCompletion(producerIds(1), List(offsetTopicPartition), TransactionResult.COMMIT) + + groupCoordinator.handleFetchOffsets(groupId, requireStable, Some(partitions)) match { + case (error, partData) => + errors.append(error) + partitionData.append(partData) + case _ => + } + + assertEquals(Errors.NONE, errors(1)) + + // We should now see the offset commits for both producers. 
+ assertEquals(Some(offsets(0).offset), partitionData(1).get(partitions(0)).map(_.offset)) + assertEquals(Some(offsets(1).offset), partitionData(1).get(partitions(1)).map(_.offset)) + } + + @Test + def testFetchOffsetForUnknownPartition(): Unit = { + val tp = new TopicPartition("topic", 0) + val (error, partitionData) = groupCoordinator.handleFetchOffsets(groupId, requireStable, Some(Seq(tp))) + assertEquals(Errors.NONE, error) + assertEquals(Some(OffsetFetchResponse.INVALID_OFFSET), partitionData.get(tp).map(_.offset)) + } + + @Test + def testFetchOffsetNotCoordinatorForGroup(): Unit = { + val tp = new TopicPartition("topic", 0) + val (error, partitionData) = groupCoordinator.handleFetchOffsets(otherGroupId, requireStable, Some(Seq(tp))) + assertEquals(Errors.NOT_COORDINATOR, error) + assertTrue(partitionData.isEmpty) + } + + @Test + def testFetchAllOffsets(): Unit = { + val tp1 = new TopicPartition("topic", 0) + val tp2 = new TopicPartition("topic", 1) + val tp3 = new TopicPartition("other-topic", 0) + val offset1 = offsetAndMetadata(15) + val offset2 = offsetAndMetadata(16) + val offset3 = offsetAndMetadata(17) + + assertEquals((Errors.NONE, Map.empty), groupCoordinator.handleFetchOffsets(groupId, requireStable)) + + val commitOffsetResult = commitOffsets(groupId, OffsetCommitRequest.DEFAULT_MEMBER_ID, + OffsetCommitRequest.DEFAULT_GENERATION_ID, Map(tp1 -> offset1, tp2 -> offset2, tp3 -> offset3)) + assertEquals(Errors.NONE, commitOffsetResult(tp1)) + assertEquals(Errors.NONE, commitOffsetResult(tp2)) + assertEquals(Errors.NONE, commitOffsetResult(tp3)) + + val (error, partitionData) = groupCoordinator.handleFetchOffsets(groupId, requireStable) + assertEquals(Errors.NONE, error) + assertEquals(3, partitionData.size) + assertTrue(partitionData.forall(_._2.error == Errors.NONE)) + assertEquals(Some(offset1.offset), partitionData.get(tp1).map(_.offset)) + assertEquals(Some(offset2.offset), partitionData.get(tp2).map(_.offset)) + assertEquals(Some(offset3.offset), partitionData.get(tp3).map(_.offset)) + } + + @Test + def testCommitOffsetInCompletingRebalance(): Unit = { + val memberId = JoinGroupRequest.UNKNOWN_MEMBER_ID + val tp = new TopicPartition("topic", 0) + val offset = offsetAndMetadata(0) + + val joinGroupResult = dynamicJoinGroup(groupId, memberId, protocolType, protocols) + val assignedMemberId = joinGroupResult.memberId + val generationId = joinGroupResult.generationId + val joinGroupError = joinGroupResult.error + assertEquals(Errors.NONE, joinGroupError) + + EasyMock.reset(replicaManager) + val commitOffsetResult = commitOffsets(groupId, assignedMemberId, generationId, Map(tp -> offset)) + assertEquals(Errors.REBALANCE_IN_PROGRESS, commitOffsetResult(tp)) + } + + @Test + def testCommitOffsetInCompletingRebalanceFromUnknownMemberId(): Unit = { + val memberId = JoinGroupRequest.UNKNOWN_MEMBER_ID + val tp = new TopicPartition("topic", 0) + val offset = offsetAndMetadata(0) + + val joinGroupResult = dynamicJoinGroup(groupId, memberId, protocolType, protocols) + val generationId = joinGroupResult.generationId + val joinGroupError = joinGroupResult.error + assertEquals(Errors.NONE, joinGroupError) + + EasyMock.reset(replicaManager) + val commitOffsetResult = commitOffsets(groupId, memberId, generationId, Map(tp -> offset)) + assertEquals(Errors.UNKNOWN_MEMBER_ID, commitOffsetResult(tp)) + } + + @Test + def testCommitOffsetInCompletingRebalanceFromIllegalGeneration(): Unit = { + val memberId = JoinGroupRequest.UNKNOWN_MEMBER_ID + val tp = new TopicPartition("topic", 0) + val offset 
= offsetAndMetadata(0) + + val joinGroupResult = dynamicJoinGroup(groupId, memberId, protocolType, protocols) + val assignedMemberId = joinGroupResult.memberId + val generationId = joinGroupResult.generationId + val joinGroupError = joinGroupResult.error + assertEquals(Errors.NONE, joinGroupError) + + EasyMock.reset(replicaManager) + val commitOffsetResult = commitOffsets(groupId, assignedMemberId, generationId + 1, Map(tp -> offset)) + assertEquals(Errors.ILLEGAL_GENERATION, commitOffsetResult(tp)) + } + + @Test + def testTxnCommitOffsetWithFencedInstanceId(): Unit = { + val tp = new TopicPartition("topic", 0) + val offset = offsetAndMetadata(0) + val producerId = 1000L + val producerEpoch : Short = 2 + + val rebalanceResult = staticMembersJoinAndRebalance(leaderInstanceId, followerInstanceId) + + val leaderNoMemberIdCommitOffsetResult = commitTransactionalOffsets(groupId, producerId, producerEpoch, + Map(tp -> offset), memberId = JoinGroupRequest.UNKNOWN_MEMBER_ID, groupInstanceId = Some(leaderInstanceId)) + assertEquals(Errors.FENCED_INSTANCE_ID, leaderNoMemberIdCommitOffsetResult(tp)) + + val leaderInvalidMemberIdCommitOffsetResult = commitTransactionalOffsets(groupId, producerId, producerEpoch, + Map(tp -> offset), memberId = "invalid-member", groupInstanceId = Some(leaderInstanceId)) + assertEquals(Errors.FENCED_INSTANCE_ID, leaderInvalidMemberIdCommitOffsetResult (tp)) + + val leaderCommitOffsetResult = commitTransactionalOffsets(groupId, producerId, producerEpoch, + Map(tp -> offset), rebalanceResult.leaderId, Some(leaderInstanceId), rebalanceResult.generation) + assertEquals(Errors.NONE, leaderCommitOffsetResult (tp)) + } + + @Test + def testTxnCommitOffsetWithInvalidMemberId(): Unit = { + val tp = new TopicPartition("topic", 0) + val offset = offsetAndMetadata(0) + val producerId = 1000L + val producerEpoch : Short = 2 + + val joinGroupResult = dynamicJoinGroup(groupId, JoinGroupRequest.UNKNOWN_MEMBER_ID, protocolType, protocols) + val joinGroupError = joinGroupResult.error + assertEquals(Errors.NONE, joinGroupError) + + EasyMock.reset(replicaManager) + val invalidIdCommitOffsetResult = commitTransactionalOffsets(groupId, producerId, producerEpoch, + Map(tp -> offset), "invalid-member") + assertEquals(Errors.UNKNOWN_MEMBER_ID, invalidIdCommitOffsetResult (tp)) + } + + @Test + def testTxnCommitOffsetWithKnownMemberId(): Unit = { + val tp = new TopicPartition("topic", 0) + val offset = offsetAndMetadata(0) + val producerId = 1000L + val producerEpoch : Short = 2 + + val joinGroupResult = dynamicJoinGroup(groupId, JoinGroupRequest.UNKNOWN_MEMBER_ID, protocolType, protocols) + val joinGroupError = joinGroupResult.error + assertEquals(Errors.NONE, joinGroupError) + + + EasyMock.reset(replicaManager) + val assignedConsumerId = joinGroupResult.memberId + val leaderCommitOffsetResult = commitTransactionalOffsets(groupId, producerId, producerEpoch, + Map(tp -> offset), assignedConsumerId, generationId = joinGroupResult.generationId) + assertEquals(Errors.NONE, leaderCommitOffsetResult (tp)) + } + + @Test + def testTxnCommitOffsetWithIllegalGeneration(): Unit = { + val tp = new TopicPartition("topic", 0) + val offset = offsetAndMetadata(0) + val producerId = 1000L + val producerEpoch : Short = 2 + + val joinGroupResult = dynamicJoinGroup(groupId, JoinGroupRequest.UNKNOWN_MEMBER_ID, protocolType, protocols) + val joinGroupError = joinGroupResult.error + assertEquals(Errors.NONE, joinGroupError) + + EasyMock.reset(replicaManager) + + val assignedConsumerId = joinGroupResult.memberId + val 
initialGenerationId = joinGroupResult.generationId + val illegalGenerationCommitOffsetResult = commitTransactionalOffsets(groupId, producerId, producerEpoch, + Map(tp -> offset), memberId = assignedConsumerId, generationId = initialGenerationId + 5) + assertEquals(Errors.ILLEGAL_GENERATION, illegalGenerationCommitOffsetResult(tp)) + } + + @Test + def testTxnCommitOffsetWithLegalGeneration(): Unit = { + val tp = new TopicPartition("topic", 0) + val offset = offsetAndMetadata(0) + val producerId = 1000L + val producerEpoch : Short = 2 + + val joinGroupResult = dynamicJoinGroup(groupId, JoinGroupRequest.UNKNOWN_MEMBER_ID, protocolType, protocols) + val joinGroupError = joinGroupResult.error + assertEquals(Errors.NONE, joinGroupError) + + EasyMock.reset(replicaManager) + + val assignedConsumerId = joinGroupResult.memberId + val initialGenerationId = joinGroupResult.generationId + val leaderCommitOffsetResult = commitTransactionalOffsets(groupId, producerId, producerEpoch, + Map(tp -> offset), memberId = assignedConsumerId, generationId = initialGenerationId) + assertEquals(Errors.NONE, leaderCommitOffsetResult (tp)) + } + + @Test + def testHeartbeatDuringRebalanceCausesRebalanceInProgress(): Unit = { + // First start up a group (with a slightly larger timeout to give us time to heartbeat when the rebalance starts) + val joinGroupResult = dynamicJoinGroup(groupId, JoinGroupRequest.UNKNOWN_MEMBER_ID, protocolType, protocols) + val assignedConsumerId = joinGroupResult.memberId + val initialGenerationId = joinGroupResult.generationId + val joinGroupError = joinGroupResult.error + assertEquals(Errors.NONE, joinGroupError) + + // Then join with a new consumer to trigger a rebalance + EasyMock.reset(replicaManager) + sendJoinGroup(groupId, JoinGroupRequest.UNKNOWN_MEMBER_ID, protocolType, protocols) + + // We should be in the middle of a rebalance, so the heartbeat should return rebalance in progress + EasyMock.reset(replicaManager) + val heartbeatResult = heartbeat(groupId, assignedConsumerId, initialGenerationId) + assertEquals(Errors.REBALANCE_IN_PROGRESS, heartbeatResult) + } + + @Test + def testGenerationIdIncrementsOnRebalance(): Unit = { + val joinGroupResult = dynamicJoinGroup(groupId, JoinGroupRequest.UNKNOWN_MEMBER_ID, protocolType, protocols) + val initialGenerationId = joinGroupResult.generationId + val joinGroupError = joinGroupResult.error + val memberId = joinGroupResult.memberId + assertEquals(1, initialGenerationId) + assertEquals(Errors.NONE, joinGroupError) + + EasyMock.reset(replicaManager) + val syncGroupResult = syncGroupLeader(groupId, initialGenerationId, memberId, Map(memberId -> Array[Byte]())) + assertEquals(Errors.NONE, syncGroupResult.error) + + EasyMock.reset(replicaManager) + val joinGroupFuture = sendJoinGroup(groupId, memberId, protocolType, protocols) + val otherJoinGroupResult = await(joinGroupFuture, 1) + + val nextGenerationId = otherJoinGroupResult.generationId + val otherJoinGroupError = otherJoinGroupResult.error + assertEquals(2, nextGenerationId) + assertEquals(Errors.NONE, otherJoinGroupError) + } + + @Test + def testLeaveGroupWrongCoordinator(): Unit = { + val memberId = JoinGroupRequest.UNKNOWN_MEMBER_ID + + val leaveGroupResults = singleLeaveGroup(otherGroupId, memberId) + verifyLeaveGroupResult(leaveGroupResults, Errors.NOT_COORDINATOR) + } + + @Test + def testLeaveGroupUnknownGroup(): Unit = { + val leaveGroupResults = singleLeaveGroup(groupId, memberId) + verifyLeaveGroupResult(leaveGroupResults, Errors.NONE, List(Errors.UNKNOWN_MEMBER_ID)) + } + + 
@Test + def testLeaveGroupUnknownConsumerExistingGroup(): Unit = { + val memberId = JoinGroupRequest.UNKNOWN_MEMBER_ID + val otherMemberId = "consumerId" + + val joinGroupResult = dynamicJoinGroup(groupId, memberId, protocolType, protocols) + val joinGroupError = joinGroupResult.error + assertEquals(Errors.NONE, joinGroupError) + + EasyMock.reset(replicaManager) + val leaveGroupResults = singleLeaveGroup(groupId, otherMemberId) + verifyLeaveGroupResult(leaveGroupResults, Errors.NONE, List(Errors.UNKNOWN_MEMBER_ID)) + } + + @Test + def testSingleLeaveDeadGroup(): Unit = { + val deadGroupId = "deadGroupId" + + groupCoordinator.groupManager.addGroup(new GroupMetadata(deadGroupId, Dead, new MockTime())) + val leaveGroupResults = singleLeaveGroup(deadGroupId, memberId) + verifyLeaveGroupResult(leaveGroupResults, Errors.COORDINATOR_NOT_AVAILABLE) + } + + @Test + def testBatchLeaveDeadGroup(): Unit = { + val deadGroupId = "deadGroupId" + + groupCoordinator.groupManager.addGroup(new GroupMetadata(deadGroupId, Dead, new MockTime())) + val leaveGroupResults = batchLeaveGroup(deadGroupId, + List(new MemberIdentity().setMemberId(memberId), new MemberIdentity().setMemberId(memberId))) + verifyLeaveGroupResult(leaveGroupResults, Errors.COORDINATOR_NOT_AVAILABLE) + } + + @Test + def testValidLeaveGroup(): Unit = { + val memberId = JoinGroupRequest.UNKNOWN_MEMBER_ID + + val joinGroupResult = dynamicJoinGroup(groupId, memberId, protocolType, protocols) + val assignedMemberId = joinGroupResult.memberId + val joinGroupError = joinGroupResult.error + assertEquals(Errors.NONE, joinGroupError) + + EasyMock.reset(replicaManager) + val leaveGroupResults = singleLeaveGroup(groupId, assignedMemberId) + verifyLeaveGroupResult(leaveGroupResults) + } + + @Test + def testLeaveGroupWithFencedInstanceId(): Unit = { + val joinGroupResult = staticJoinGroup(groupId, JoinGroupRequest.UNKNOWN_MEMBER_ID, leaderInstanceId, protocolType, protocolSuperset) + assertEquals(Errors.NONE, joinGroupResult.error) + + EasyMock.reset(replicaManager) + val leaveGroupResults = singleLeaveGroup(groupId, "some_member", Some(leaderInstanceId)) + verifyLeaveGroupResult(leaveGroupResults, Errors.NONE, List(Errors.FENCED_INSTANCE_ID)) + } + + @Test + def testLeaveGroupStaticMemberWithUnknownMemberId(): Unit = { + val joinGroupResult = staticJoinGroup(groupId, JoinGroupRequest.UNKNOWN_MEMBER_ID, leaderInstanceId, protocolType, protocolSuperset) + assertEquals(Errors.NONE, joinGroupResult.error) + + EasyMock.reset(replicaManager) + // Having unknown member id will not affect the request processing. 
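+ // The static member is resolved by its group.instance.id, so leaving with an unknown member id still succeeds.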
+ val leaveGroupResults = singleLeaveGroup(groupId, JoinGroupRequest.UNKNOWN_MEMBER_ID, Some(leaderInstanceId)) + verifyLeaveGroupResult(leaveGroupResults, Errors.NONE, List(Errors.NONE)) + } + + @Test + def testStaticMembersValidBatchLeaveGroup(): Unit = { + staticMembersJoinAndRebalance(leaderInstanceId, followerInstanceId) + + val leaveGroupResults = batchLeaveGroup(groupId, List(new MemberIdentity() + .setGroupInstanceId(leaderInstanceId), new MemberIdentity().setGroupInstanceId(followerInstanceId))) + + verifyLeaveGroupResult(leaveGroupResults, Errors.NONE, List(Errors.NONE, Errors.NONE)) + } + + @Test + def testStaticMembersWrongCoordinatorBatchLeaveGroup(): Unit = { + staticMembersJoinAndRebalance(leaderInstanceId, followerInstanceId) + + val leaveGroupResults = batchLeaveGroup("invalid-group", List(new MemberIdentity() + .setGroupInstanceId(leaderInstanceId), new MemberIdentity().setGroupInstanceId(followerInstanceId))) + + verifyLeaveGroupResult(leaveGroupResults, Errors.NOT_COORDINATOR) + } + + @Test + def testStaticMembersUnknownGroupBatchLeaveGroup(): Unit = { + val leaveGroupResults = batchLeaveGroup(groupId, List(new MemberIdentity() + .setGroupInstanceId(leaderInstanceId), new MemberIdentity().setGroupInstanceId(followerInstanceId))) + + verifyLeaveGroupResult(leaveGroupResults, Errors.NONE, List(Errors.UNKNOWN_MEMBER_ID, Errors.UNKNOWN_MEMBER_ID)) + } + + @Test + def testStaticMembersFencedInstanceBatchLeaveGroup(): Unit = { + staticMembersJoinAndRebalance(leaderInstanceId, followerInstanceId) + + val leaveGroupResults = batchLeaveGroup(groupId, List(new MemberIdentity() + .setGroupInstanceId(leaderInstanceId), new MemberIdentity() + .setGroupInstanceId(followerInstanceId) + .setMemberId("invalid-member"))) + + verifyLeaveGroupResult(leaveGroupResults, Errors.NONE, List(Errors.NONE, Errors.FENCED_INSTANCE_ID)) + } + + @Test + def testStaticMembersUnknownInstanceBatchLeaveGroup(): Unit = { + staticMembersJoinAndRebalance(leaderInstanceId, followerInstanceId) + + val leaveGroupResults = batchLeaveGroup(groupId, List(new MemberIdentity() + .setGroupInstanceId("unknown-instance"), new MemberIdentity() + .setGroupInstanceId(followerInstanceId))) + + verifyLeaveGroupResult(leaveGroupResults, Errors.NONE, List(Errors.UNKNOWN_MEMBER_ID, Errors.NONE)) + } + + @Test + def testPendingMemberBatchLeaveGroup(): Unit = { + val pendingMember = setupGroupWithPendingMember() + + EasyMock.reset(replicaManager) + val leaveGroupResults = batchLeaveGroup(groupId, List(new MemberIdentity() + .setGroupInstanceId("unknown-instance"), new MemberIdentity() + .setMemberId(pendingMember.memberId))) + + verifyLeaveGroupResult(leaveGroupResults, Errors.NONE, List(Errors.UNKNOWN_MEMBER_ID, Errors.NONE)) + } + + @Test + def testListGroupsIncludesStableGroups(): Unit = { + val memberId = JoinGroupRequest.UNKNOWN_MEMBER_ID + val joinGroupResult = dynamicJoinGroup(groupId, memberId, protocolType, protocols) + val assignedMemberId = joinGroupResult.memberId + val generationId = joinGroupResult.generationId + assertEquals(Errors.NONE, joinGroupResult.error) + + EasyMock.reset(replicaManager) + val syncGroupResult = syncGroupLeader(groupId, generationId, assignedMemberId, Map(assignedMemberId -> Array[Byte]())) + assertEquals(Errors.NONE, syncGroupResult.error) + + val (error, groups) = groupCoordinator.handleListGroups(Set()) + assertEquals(Errors.NONE, error) + assertEquals(1, groups.size) + assertEquals(GroupOverview("groupId", "consumer", Stable.toString), groups.head) + } + + @Test + def 
testListGroupsIncludesRebalancingGroups(): Unit = { + val memberId = JoinGroupRequest.UNKNOWN_MEMBER_ID + val joinGroupResult = dynamicJoinGroup(groupId, memberId, protocolType, protocols) + assertEquals(Errors.NONE, joinGroupResult.error) + + val (error, groups) = groupCoordinator.handleListGroups(Set()) + assertEquals(Errors.NONE, error) + assertEquals(1, groups.size) + assertEquals(GroupOverview("groupId", "consumer", CompletingRebalance.toString), groups.head) + } + + @Test + def testListGroupsWithStates(): Unit = { + val allStates = Set(PreparingRebalance, CompletingRebalance, Stable, Dead, Empty).map(s => s.toString) + val memberId = JoinGroupRequest.UNKNOWN_MEMBER_ID + + // Member joins the group + val joinGroupResult = dynamicJoinGroup(groupId, memberId, protocolType, protocols) + val assignedMemberId = joinGroupResult.memberId + val generationId = joinGroupResult.generationId + assertEquals(Errors.NONE, joinGroupResult.error) + + // The group should be in CompletingRebalance + val (error, groups) = groupCoordinator.handleListGroups(Set(CompletingRebalance.toString)) + assertEquals(Errors.NONE, error) + assertEquals(1, groups.size) + val (error2, groups2) = groupCoordinator.handleListGroups(allStates.filterNot(s => s == CompletingRebalance.toString)) + assertEquals(Errors.NONE, error2) + assertEquals(0, groups2.size) + + // Member syncs + EasyMock.reset(replicaManager) + val syncGroupResult = syncGroupLeader(groupId, generationId, assignedMemberId, Map(assignedMemberId -> Array[Byte]())) + assertEquals(Errors.NONE, syncGroupResult.error) + + // The group is now stable + val (error3, groups3) = groupCoordinator.handleListGroups(Set(Stable.toString)) + assertEquals(Errors.NONE, error3) + assertEquals(1, groups3.size) + val (error4, groups4) = groupCoordinator.handleListGroups(allStates.filterNot(s => s == Stable.toString)) + assertEquals(Errors.NONE, error4) + assertEquals(0, groups4.size) + + // Member leaves + EasyMock.reset(replicaManager) + val leaveGroupResults = singleLeaveGroup(groupId, assignedMemberId) + verifyLeaveGroupResult(leaveGroupResults) + + // The group is now empty + val (error5, groups5) = groupCoordinator.handleListGroups(Set(Empty.toString)) + assertEquals(Errors.NONE, error5) + assertEquals(1, groups5.size) + val (error6, groups6) = groupCoordinator.handleListGroups(allStates.filterNot(s => s == Empty.toString)) + assertEquals(Errors.NONE, error6) + assertEquals(0, groups6.size) + } + + @Test + def testDescribeGroupWrongCoordinator(): Unit = { + EasyMock.reset(replicaManager) + val (error, _) = groupCoordinator.handleDescribeGroup(otherGroupId) + assertEquals(Errors.NOT_COORDINATOR, error) + } + + @Test + def testDescribeGroupInactiveGroup(): Unit = { + EasyMock.reset(replicaManager) + val (error, summary) = groupCoordinator.handleDescribeGroup(groupId) + assertEquals(Errors.NONE, error) + assertEquals(GroupCoordinator.DeadGroup, summary) + } + + @Test + def testDescribeGroupStableForDynamicMember(): Unit = { + val joinGroupResult = dynamicJoinGroup(groupId, JoinGroupRequest.UNKNOWN_MEMBER_ID, protocolType, protocols) + val assignedMemberId = joinGroupResult.memberId + val generationId = joinGroupResult.generationId + val joinGroupError = joinGroupResult.error + assertEquals(Errors.NONE, joinGroupError) + + EasyMock.reset(replicaManager) + val syncGroupResult = syncGroupLeader(groupId, generationId, assignedMemberId, Map(assignedMemberId -> Array[Byte]())) + assertEquals(Errors.NONE, syncGroupResult.error) + + EasyMock.reset(replicaManager) + val (error, 
summary) = groupCoordinator.handleDescribeGroup(groupId) + assertEquals(Errors.NONE, error) + assertEquals(protocolType, summary.protocolType) + assertEquals("range", summary.protocol) + assertEquals(List(assignedMemberId), summary.members.map(_.memberId)) + } + + @Test + def testDescribeGroupStableForStaticMember(): Unit = { + val joinGroupResult = staticJoinGroup(groupId, JoinGroupRequest.UNKNOWN_MEMBER_ID, leaderInstanceId, protocolType, protocols) + val assignedMemberId = joinGroupResult.memberId + val generationId = joinGroupResult.generationId + val joinGroupError = joinGroupResult.error + assertEquals(Errors.NONE, joinGroupError) + + EasyMock.reset(replicaManager) + val syncGroupResult = syncGroupLeader(groupId, generationId, assignedMemberId, Map(assignedMemberId -> Array[Byte]())) + assertEquals(Errors.NONE, syncGroupResult.error) + + EasyMock.reset(replicaManager) + val (error, summary) = groupCoordinator.handleDescribeGroup(groupId) + assertEquals(Errors.NONE, error) + assertEquals(protocolType, summary.protocolType) + assertEquals("range", summary.protocol) + assertEquals(List(assignedMemberId), summary.members.map(_.memberId)) + assertEquals(List(leaderInstanceId), summary.members.flatMap(_.groupInstanceId)) + } + + @Test + def testDescribeGroupRebalancing(): Unit = { + val memberId = JoinGroupRequest.UNKNOWN_MEMBER_ID + val joinGroupResult = dynamicJoinGroup(groupId, memberId, protocolType, protocols) + val joinGroupError = joinGroupResult.error + assertEquals(Errors.NONE, joinGroupError) + + EasyMock.reset(replicaManager) + val (error, summary) = groupCoordinator.handleDescribeGroup(groupId) + assertEquals(Errors.NONE, error) + assertEquals(protocolType, summary.protocolType) + assertEquals(GroupCoordinator.NoProtocol, summary.protocol) + assertEquals(CompletingRebalance.toString, summary.state) + assertTrue(summary.members.map(_.memberId).contains(joinGroupResult.memberId)) + assertTrue(summary.members.forall(_.metadata.isEmpty)) + assertTrue(summary.members.forall(_.assignment.isEmpty)) + } + + @Test + def testDeleteNonEmptyGroup(): Unit = { + val memberId = JoinGroupRequest.UNKNOWN_MEMBER_ID + dynamicJoinGroup(groupId, memberId, protocolType, protocols) + + val result = groupCoordinator.handleDeleteGroups(Set(groupId)) + assert(result.size == 1 && result.contains(groupId) && result.get(groupId).contains(Errors.NON_EMPTY_GROUP)) + } + + @Test + def testDeleteGroupWithInvalidGroupId(): Unit = { + val invalidGroupId = null + val result = groupCoordinator.handleDeleteGroups(Set(invalidGroupId)) + assert(result.size == 1 && result.contains(invalidGroupId) && result.get(invalidGroupId).contains(Errors.INVALID_GROUP_ID)) + } + + @Test + def testDeleteGroupWithWrongCoordinator(): Unit = { + val result = groupCoordinator.handleDeleteGroups(Set(otherGroupId)) + assert(result.size == 1 && result.contains(otherGroupId) && result.get(otherGroupId).contains(Errors.NOT_COORDINATOR)) + } + + @Test + def testDeleteEmptyGroup(): Unit = { + val memberId = JoinGroupRequest.UNKNOWN_MEMBER_ID + val joinGroupResult = dynamicJoinGroup(groupId, memberId, protocolType, protocols) + + EasyMock.reset(replicaManager) + val leaveGroupResults = singleLeaveGroup(groupId, joinGroupResult.memberId) + verifyLeaveGroupResult(leaveGroupResults) + + val groupTopicPartition = new TopicPartition(Topic.GROUP_METADATA_TOPIC_NAME, groupPartitionId) + val partition: Partition = EasyMock.niceMock(classOf[Partition]) + + EasyMock.reset(replicaManager) + 
EasyMock.expect(replicaManager.getMagic(EasyMock.anyObject())).andStubReturn(Some(RecordBatch.CURRENT_MAGIC_VALUE)) + EasyMock.expect(replicaManager.getPartition(groupTopicPartition)).andStubReturn(HostedPartition.Online(partition)) + EasyMock.expect(replicaManager.onlinePartition(groupTopicPartition)).andStubReturn(Some(partition)) + EasyMock.replay(replicaManager, partition) + + val result = groupCoordinator.handleDeleteGroups(Set(groupId)) + assert(result.size == 1 && result.contains(groupId) && result.get(groupId).contains(Errors.NONE)) + } + + @Test + def testDeleteEmptyGroupWithStoredOffsets(): Unit = { + val memberId = JoinGroupRequest.UNKNOWN_MEMBER_ID + + val joinGroupResult = dynamicJoinGroup(groupId, memberId, protocolType, protocols) + val assignedMemberId = joinGroupResult.memberId + val joinGroupError = joinGroupResult.error + assertEquals(Errors.NONE, joinGroupError) + + EasyMock.reset(replicaManager) + val syncGroupResult = syncGroupLeader(groupId, joinGroupResult.generationId, assignedMemberId, Map(assignedMemberId -> Array[Byte]())) + assertEquals(Errors.NONE, syncGroupResult.error) + + EasyMock.reset(replicaManager) + val tp = new TopicPartition("topic", 0) + val offset = offsetAndMetadata(0) + val commitOffsetResult = commitOffsets(groupId, assignedMemberId, joinGroupResult.generationId, Map(tp -> offset)) + assertEquals(Errors.NONE, commitOffsetResult(tp)) + + val describeGroupResult = groupCoordinator.handleDescribeGroup(groupId) + assertEquals(Stable.toString, describeGroupResult._2.state) + assertEquals(assignedMemberId, describeGroupResult._2.members.head.memberId) + + EasyMock.reset(replicaManager) + val leaveGroupResults = singleLeaveGroup(groupId, assignedMemberId) + verifyLeaveGroupResult(leaveGroupResults) + + val groupTopicPartition = new TopicPartition(Topic.GROUP_METADATA_TOPIC_NAME, groupPartitionId) + val partition: Partition = EasyMock.niceMock(classOf[Partition]) + + EasyMock.reset(replicaManager) + EasyMock.expect(replicaManager.getMagic(EasyMock.anyObject())).andStubReturn(Some(RecordBatch.CURRENT_MAGIC_VALUE)) + EasyMock.expect(replicaManager.getPartition(groupTopicPartition)).andStubReturn(HostedPartition.Online(partition)) + EasyMock.expect(replicaManager.onlinePartition(groupTopicPartition)).andStubReturn(Some(partition)) + EasyMock.replay(replicaManager, partition) + + val result = groupCoordinator.handleDeleteGroups(Set(groupId)) + assert(result.size == 1 && result.contains(groupId) && result.get(groupId).contains(Errors.NONE)) + + assertEquals(Dead.toString, groupCoordinator.handleDescribeGroup(groupId)._2.state) + } + + @Test + def testDeleteOffsetOfNonExistingGroup(): Unit = { + val tp = new TopicPartition("foo", 0) + val (groupError, topics) = groupCoordinator.handleDeleteOffsets(groupId, Seq(tp), + RequestLocal.NoCaching) + + assertEquals(Errors.GROUP_ID_NOT_FOUND, groupError) + assertTrue(topics.isEmpty) + } + + @Test + def testDeleteOffsetOfNonEmptyNonConsumerGroup(): Unit = { + val memberId = JoinGroupRequest.UNKNOWN_MEMBER_ID + dynamicJoinGroup(groupId, memberId, "My Protocol", protocols) + val tp = new TopicPartition("foo", 0) + val (groupError, topics) = groupCoordinator.handleDeleteOffsets(groupId, Seq(tp), + RequestLocal.NoCaching) + + assertEquals(Errors.NON_EMPTY_GROUP, groupError) + assertTrue(topics.isEmpty) + } + + @Test + def testDeleteOffsetOfEmptyNonConsumerGroup(): Unit = { + // join the group + val memberId = JoinGroupRequest.UNKNOWN_MEMBER_ID + + val joinGroupResult = dynamicJoinGroup(groupId, memberId, "My Protocol", 
protocols) + assertEquals(Errors.NONE, joinGroupResult.error) + + EasyMock.reset(replicaManager) + val syncGroupResult = syncGroupLeader(groupId, joinGroupResult.generationId, joinGroupResult.leaderId, Map.empty) + assertEquals(Errors.NONE, syncGroupResult.error) + + val t1p0 = new TopicPartition("foo", 0) + val t2p0 = new TopicPartition("bar", 0) + val offset = offsetAndMetadata(37) + + EasyMock.reset(replicaManager) + val validOffsetCommitResult = commitOffsets(groupId, joinGroupResult.memberId, joinGroupResult.generationId, + Map(t1p0 -> offset, t2p0 -> offset)) + assertEquals(Errors.NONE, validOffsetCommitResult(t1p0)) + assertEquals(Errors.NONE, validOffsetCommitResult(t2p0)) + + // and leaves. + EasyMock.reset(replicaManager) + val leaveGroupResults = singleLeaveGroup(groupId, joinGroupResult.memberId) + verifyLeaveGroupResult(leaveGroupResults) + + assertTrue(groupCoordinator.groupManager.getGroup(groupId).exists(_.is(Empty))) + + val groupTopicPartition = new TopicPartition(Topic.GROUP_METADATA_TOPIC_NAME, groupPartitionId) + val partition: Partition = EasyMock.niceMock(classOf[Partition]) + + EasyMock.reset(replicaManager) + EasyMock.expect(replicaManager.getMagic(EasyMock.anyObject())).andStubReturn(Some(RecordBatch.CURRENT_MAGIC_VALUE)) + EasyMock.expect(replicaManager.getPartition(groupTopicPartition)).andStubReturn(HostedPartition.Online(partition)) + EasyMock.expect(replicaManager.onlinePartition(groupTopicPartition)).andStubReturn(Some(partition)) + EasyMock.replay(replicaManager, partition) + + val (groupError, topics) = groupCoordinator.handleDeleteOffsets(groupId, Seq(t1p0), + RequestLocal.NoCaching) + + assertEquals(Errors.NONE, groupError) + assertEquals(1, topics.size) + assertEquals(Some(Errors.NONE), topics.get(t1p0)) + + val cachedOffsets = groupCoordinator.groupManager.getOffsets(groupId, requireStable, Some(Seq(t1p0, t2p0))) + + assertEquals(Some(OffsetFetchResponse.INVALID_OFFSET), cachedOffsets.get(t1p0).map(_.offset)) + assertEquals(Some(offset.offset), cachedOffsets.get(t2p0).map(_.offset)) + } + + @Test + def testDeleteOffsetOfConsumerGroupWithUnparsableProtocol(): Unit = { + val memberId = JoinGroupRequest.UNKNOWN_MEMBER_ID + val joinGroupResult = dynamicJoinGroup(groupId, memberId, protocolType, protocols) + + EasyMock.reset(replicaManager) + val syncGroupResult = syncGroupLeader(groupId, joinGroupResult.generationId, joinGroupResult.leaderId, Map.empty) + assertEquals(Errors.NONE, syncGroupResult.error) + + val tp = new TopicPartition("foo", 0) + val offset = offsetAndMetadata(37) + + EasyMock.reset(replicaManager) + val validOffsetCommitResult = commitOffsets(groupId, joinGroupResult.memberId, joinGroupResult.generationId, + Map(tp -> offset)) + assertEquals(Errors.NONE, validOffsetCommitResult(tp)) + + val (groupError, topics) = groupCoordinator.handleDeleteOffsets(groupId, Seq(tp), + RequestLocal.NoCaching) + + assertEquals(Errors.NONE, groupError) + assertEquals(1, topics.size) + assertEquals(Some(Errors.GROUP_SUBSCRIBED_TO_TOPIC), topics.get(tp)) + } + + @Test + def testDeleteOffsetOfDeadConsumerGroup(): Unit = { + val group = new GroupMetadata(groupId, Dead, new MockTime()) + group.protocolType = Some(protocolType) + groupCoordinator.groupManager.addGroup(group) + + val tp = new TopicPartition("foo", 0) + val (groupError, topics) = groupCoordinator.handleDeleteOffsets(groupId, Seq(tp), + RequestLocal.NoCaching) + + assertEquals(Errors.GROUP_ID_NOT_FOUND, groupError) + assertTrue(topics.isEmpty) + } + + @Test + def 
testDeleteOffsetOfEmptyConsumerGroup(): Unit = { + // join the group + val memberId = JoinGroupRequest.UNKNOWN_MEMBER_ID + val joinGroupResult = dynamicJoinGroup(groupId, memberId, protocolType, protocols) + assertEquals(Errors.NONE, joinGroupResult.error) + + EasyMock.reset(replicaManager) + val syncGroupResult = syncGroupLeader(groupId, joinGroupResult.generationId, joinGroupResult.leaderId, Map.empty) + assertEquals(Errors.NONE, syncGroupResult.error) + + val t1p0 = new TopicPartition("foo", 0) + val t2p0 = new TopicPartition("bar", 0) + val offset = offsetAndMetadata(37) + + EasyMock.reset(replicaManager) + val validOffsetCommitResult = commitOffsets(groupId, joinGroupResult.memberId, joinGroupResult.generationId, + Map(t1p0 -> offset, t2p0 -> offset)) + assertEquals(Errors.NONE, validOffsetCommitResult(t1p0)) + assertEquals(Errors.NONE, validOffsetCommitResult(t2p0)) + + // and leaves. + EasyMock.reset(replicaManager) + val leaveGroupResults = singleLeaveGroup(groupId, joinGroupResult.memberId) + verifyLeaveGroupResult(leaveGroupResults) + + assertTrue(groupCoordinator.groupManager.getGroup(groupId).exists(_.is(Empty))) + + val groupTopicPartition = new TopicPartition(Topic.GROUP_METADATA_TOPIC_NAME, groupPartitionId) + val partition: Partition = EasyMock.niceMock(classOf[Partition]) + + EasyMock.reset(replicaManager) + EasyMock.expect(replicaManager.getMagic(EasyMock.anyObject())).andStubReturn(Some(RecordBatch.CURRENT_MAGIC_VALUE)) + EasyMock.expect(replicaManager.getPartition(groupTopicPartition)).andStubReturn(HostedPartition.Online(partition)) + EasyMock.expect(replicaManager.onlinePartition(groupTopicPartition)).andStubReturn(Some(partition)) + EasyMock.replay(replicaManager, partition) + + val (groupError, topics) = groupCoordinator.handleDeleteOffsets(groupId, Seq(t1p0), + RequestLocal.NoCaching) + + assertEquals(Errors.NONE, groupError) + assertEquals(1, topics.size) + assertEquals(Some(Errors.NONE), topics.get(t1p0)) + + val cachedOffsets = groupCoordinator.groupManager.getOffsets(groupId, requireStable, Some(Seq(t1p0, t2p0))) + + assertEquals(Some(OffsetFetchResponse.INVALID_OFFSET), cachedOffsets.get(t1p0).map(_.offset)) + assertEquals(Some(offset.offset), cachedOffsets.get(t2p0).map(_.offset)) + } + + @Test + def testDeleteOffsetOfStableConsumerGroup(): Unit = { + // join the group + val memberId = JoinGroupRequest.UNKNOWN_MEMBER_ID + val subscription = new Subscription(List("bar").asJava) + + val joinGroupResult = dynamicJoinGroup(groupId, memberId, protocolType, + List(("protocol", ConsumerProtocol.serializeSubscription(subscription).array()))) + assertEquals(Errors.NONE, joinGroupResult.error) + + EasyMock.reset(replicaManager) + val syncGroupResult = syncGroupLeader(groupId, joinGroupResult.generationId, joinGroupResult.leaderId, Map.empty) + assertEquals(Errors.NONE, syncGroupResult.error) + + val t1p0 = new TopicPartition("foo", 0) + val t2p0 = new TopicPartition("bar", 0) + val offset = offsetAndMetadata(37) + + EasyMock.reset(replicaManager) + val validOffsetCommitResult = commitOffsets(groupId, joinGroupResult.memberId, joinGroupResult.generationId, + Map(t1p0 -> offset, t2p0 -> offset)) + assertEquals(Errors.NONE, validOffsetCommitResult(t1p0)) + assertEquals(Errors.NONE, validOffsetCommitResult(t2p0)) + + assertTrue(groupCoordinator.groupManager.getGroup(groupId).exists(_.is(Stable))) + + val groupTopicPartition = new TopicPartition(Topic.GROUP_METADATA_TOPIC_NAME, groupPartitionId) + val partition: Partition = EasyMock.niceMock(classOf[Partition]) + + 
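+    // Offset deletion also writes tombstones for the removed partitions, so the replica manager is stubbed the same way as in the group-deletion tests before calling handleDeleteOffsets.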
EasyMock.reset(replicaManager) + EasyMock.expect(replicaManager.getMagic(EasyMock.anyObject())).andStubReturn(Some(RecordBatch.CURRENT_MAGIC_VALUE)) + EasyMock.expect(replicaManager.getPartition(groupTopicPartition)).andStubReturn(HostedPartition.Online(partition)) + EasyMock.expect(replicaManager.onlinePartition(groupTopicPartition)).andStubReturn(Some(partition)) + EasyMock.replay(replicaManager, partition) + + val (groupError, topics) = groupCoordinator.handleDeleteOffsets(groupId, Seq(t1p0, t2p0), + RequestLocal.NoCaching) + + assertEquals(Errors.NONE, groupError) + assertEquals(2, topics.size) + assertEquals(Some(Errors.NONE), topics.get(t1p0)) + assertEquals(Some(Errors.GROUP_SUBSCRIBED_TO_TOPIC), topics.get(t2p0)) + + val cachedOffsets = groupCoordinator.groupManager.getOffsets(groupId, requireStable, Some(Seq(t1p0, t2p0))) + + assertEquals(Some(OffsetFetchResponse.INVALID_OFFSET), cachedOffsets.get(t1p0).map(_.offset)) + assertEquals(Some(offset.offset), cachedOffsets.get(t2p0).map(_.offset)) + } + + @Test + def shouldDelayInitialRebalanceByGroupInitialRebalanceDelayOnEmptyGroup(): Unit = { + val firstJoinFuture = sendJoinGroup(groupId, JoinGroupRequest.UNKNOWN_MEMBER_ID, protocolType, protocols) + timer.advanceClock(GroupInitialRebalanceDelay / 2) + verifyDelayedTaskNotCompleted(firstJoinFuture) + timer.advanceClock((GroupInitialRebalanceDelay / 2) + 1) + val joinGroupResult = await(firstJoinFuture, 1) + assertEquals(Errors.NONE, joinGroupResult.error) + } + + private def verifyDelayedTaskNotCompleted(firstJoinFuture: Future[JoinGroupResult]) = { + assertThrows(classOf[TimeoutException], () => await(firstJoinFuture, 1), + () => "should have timed out as rebalance delay not expired") + } + + @Test + def shouldResetRebalanceDelayWhenNewMemberJoinsGroupInInitialRebalance(): Unit = { + val rebalanceTimeout = GroupInitialRebalanceDelay * 3 + val firstMemberJoinFuture = sendJoinGroup(groupId, JoinGroupRequest.UNKNOWN_MEMBER_ID, protocolType, protocols, rebalanceTimeout = rebalanceTimeout) + EasyMock.reset(replicaManager) + timer.advanceClock(GroupInitialRebalanceDelay - 1) + val secondMemberJoinFuture = sendJoinGroup(groupId, JoinGroupRequest.UNKNOWN_MEMBER_ID, protocolType, protocols, rebalanceTimeout = rebalanceTimeout) + EasyMock.reset(replicaManager) + timer.advanceClock(2) + + // advance past initial rebalance delay and make sure that tasks + // haven't been completed + timer.advanceClock(GroupInitialRebalanceDelay / 2 + 1) + verifyDelayedTaskNotCompleted(firstMemberJoinFuture) + verifyDelayedTaskNotCompleted(secondMemberJoinFuture) + // advance clock beyond updated delay and make sure the + // tasks have completed + timer.advanceClock(GroupInitialRebalanceDelay / 2) + val firstResult = await(firstMemberJoinFuture, 1) + val secondResult = await(secondMemberJoinFuture, 1) + assertEquals(Errors.NONE, firstResult.error) + assertEquals(Errors.NONE, secondResult.error) + } + + @Test + def shouldDelayRebalanceUptoRebalanceTimeout(): Unit = { + val rebalanceTimeout = GroupInitialRebalanceDelay * 2 + val firstMemberJoinFuture = sendJoinGroup(groupId, JoinGroupRequest.UNKNOWN_MEMBER_ID, protocolType, protocols, rebalanceTimeout = rebalanceTimeout) + EasyMock.reset(replicaManager) + val secondMemberJoinFuture = sendJoinGroup(groupId, JoinGroupRequest.UNKNOWN_MEMBER_ID, protocolType, protocols, rebalanceTimeout = rebalanceTimeout) + timer.advanceClock(GroupInitialRebalanceDelay + 1) + EasyMock.reset(replicaManager) + val thirdMemberJoinFuture = sendJoinGroup(groupId, 
JoinGroupRequest.UNKNOWN_MEMBER_ID, protocolType, protocols, rebalanceTimeout = rebalanceTimeout) + timer.advanceClock(GroupInitialRebalanceDelay) + EasyMock.reset(replicaManager) + + verifyDelayedTaskNotCompleted(firstMemberJoinFuture) + verifyDelayedTaskNotCompleted(secondMemberJoinFuture) + verifyDelayedTaskNotCompleted(thirdMemberJoinFuture) + + // advance clock beyond rebalanceTimeout + timer.advanceClock(1) + + val firstResult = await(firstMemberJoinFuture, 1) + val secondResult = await(secondMemberJoinFuture, 1) + val thirdResult = await(thirdMemberJoinFuture, 1) + assertEquals(Errors.NONE, firstResult.error) + assertEquals(Errors.NONE, secondResult.error) + assertEquals(Errors.NONE, thirdResult.error) + } + + @Test + def testCompleteHeartbeatWithGroupDead(): Unit = { + val rebalanceResult = staticMembersJoinAndRebalance(leaderInstanceId, followerInstanceId) + heartbeat(groupId, rebalanceResult.leaderId, rebalanceResult.generation) + val group = getGroup(groupId) + group.transitionTo(Dead) + val leaderMemberId = rebalanceResult.leaderId + assertTrue(groupCoordinator.tryCompleteHeartbeat(group, leaderMemberId, false, () => true)) + groupCoordinator.onExpireHeartbeat(group, leaderMemberId, false) + assertTrue(group.has(leaderMemberId)) + } + + @Test + def testCompleteHeartbeatWithMemberAlreadyRemoved(): Unit = { + val rebalanceResult = staticMembersJoinAndRebalance(leaderInstanceId, followerInstanceId) + heartbeat(groupId, rebalanceResult.leaderId, rebalanceResult.generation) + val group = getGroup(groupId) + val leaderMemberId = rebalanceResult.leaderId + group.remove(leaderMemberId) + assertTrue(groupCoordinator.tryCompleteHeartbeat(group, leaderMemberId, false, () => true)) + } + + private def getGroup(groupId: String): GroupMetadata = { + val groupOpt = groupCoordinator.groupManager.getGroup(groupId) + assertTrue(groupOpt.isDefined) + groupOpt.get + } + + private def setupJoinGroupCallback: (Future[JoinGroupResult], JoinGroupCallback) = { + val responsePromise = Promise[JoinGroupResult]() + val responseFuture = responsePromise.future + val responseCallback: JoinGroupCallback = responsePromise.success + (responseFuture, responseCallback) + } + + private def setupSyncGroupCallback: (Future[SyncGroupResult], SyncGroupCallback) = { + val responsePromise = Promise[SyncGroupResult]() + val responseFuture = responsePromise.future + val responseCallback: SyncGroupCallback = responsePromise.success + (responseFuture, responseCallback) + } + + private def setupHeartbeatCallback: (Future[HeartbeatCallbackParams], HeartbeatCallback) = { + val responsePromise = Promise[HeartbeatCallbackParams]() + val responseFuture = responsePromise.future + val responseCallback: HeartbeatCallback = error => responsePromise.success(error) + (responseFuture, responseCallback) + } + + private def setupCommitOffsetsCallback: (Future[CommitOffsetCallbackParams], CommitOffsetCallback) = { + val responsePromise = Promise[CommitOffsetCallbackParams]() + val responseFuture = responsePromise.future + val responseCallback: CommitOffsetCallback = offsets => responsePromise.success(offsets) + (responseFuture, responseCallback) + } + + private def setupLeaveGroupCallback: (Future[LeaveGroupResult], LeaveGroupCallback) = { + val responsePromise = Promise[LeaveGroupResult]() + val responseFuture = responsePromise.future + val responseCallback: LeaveGroupCallback = result => responsePromise.success(result) + (responseFuture, responseCallback) + } + + private def sendJoinGroup(groupId: String, + memberId: String, + 
protocolType: String, + protocols: List[(String, Array[Byte])], + groupInstanceId: Option[String] = None, + sessionTimeout: Int = DefaultSessionTimeout, + rebalanceTimeout: Int = DefaultRebalanceTimeout, + requireKnownMemberId: Boolean = false): Future[JoinGroupResult] = { + val (responseFuture, responseCallback) = setupJoinGroupCallback + + EasyMock.expect(replicaManager.getMagic(EasyMock.anyObject())).andReturn(Some(RecordBatch.MAGIC_VALUE_V1)).anyTimes() + EasyMock.replay(replicaManager) + + groupCoordinator.handleJoinGroup(groupId, memberId, groupInstanceId, + requireKnownMemberId, "clientId", "clientHost", rebalanceTimeout, sessionTimeout, protocolType, protocols, responseCallback) + responseFuture + } + + private def sendStaticJoinGroupWithPersistence(groupId: String, + memberId: String, + protocolType: String, + protocols: List[(String, Array[Byte])], + groupInstanceId: String, + sessionTimeout: Int, + rebalanceTimeout: Int, + appendRecordError: Errors, + requireKnownMemberId: Boolean = false): Future[JoinGroupResult] = { + val (responseFuture, responseCallback) = setupJoinGroupCallback + + val capturedArgument: Capture[scala.collection.Map[TopicPartition, PartitionResponse] => Unit] = EasyMock.newCapture() + + EasyMock.expect(replicaManager.appendRecords(EasyMock.anyLong(), + EasyMock.anyShort(), + internalTopicsAllowed = EasyMock.eq(true), + origin = EasyMock.eq(AppendOrigin.Coordinator), + EasyMock.anyObject().asInstanceOf[Map[TopicPartition, MemoryRecords]], + EasyMock.capture(capturedArgument), + EasyMock.anyObject().asInstanceOf[Option[ReentrantLock]], + EasyMock.anyObject(), + EasyMock.anyObject() + )).andAnswer(new IAnswer[Unit] { + override def answer: Unit = capturedArgument.getValue.apply( + Map(new TopicPartition(Topic.GROUP_METADATA_TOPIC_NAME, groupPartitionId) -> + new PartitionResponse(appendRecordError, 0L, RecordBatch.NO_TIMESTAMP, 0L) + )) + }) + EasyMock.expect(replicaManager.getMagic(EasyMock.anyObject())).andReturn(Some(RecordBatch.MAGIC_VALUE_V1)).anyTimes() + EasyMock.replay(replicaManager) + + groupCoordinator.handleJoinGroup(groupId, memberId, Some(groupInstanceId), + requireKnownMemberId, "clientId", "clientHost", rebalanceTimeout, sessionTimeout, protocolType, protocols, responseCallback) + responseFuture + } + + private def sendSyncGroupLeader(groupId: String, + generation: Int, + leaderId: String, + protocolType: Option[String], + protocolName: Option[String], + groupInstanceId: Option[String], + assignment: Map[String, Array[Byte]]): Future[SyncGroupResult] = { + val (responseFuture, responseCallback) = setupSyncGroupCallback + + val capturedArgument: Capture[scala.collection.Map[TopicPartition, PartitionResponse] => Unit] = EasyMock.newCapture() + + EasyMock.expect(replicaManager.appendRecords(EasyMock.anyLong(), + EasyMock.anyShort(), + internalTopicsAllowed = EasyMock.eq(true), + origin = EasyMock.eq(AppendOrigin.Coordinator), + EasyMock.anyObject().asInstanceOf[Map[TopicPartition, MemoryRecords]], + EasyMock.capture(capturedArgument), + EasyMock.anyObject().asInstanceOf[Option[ReentrantLock]], + EasyMock.anyObject(), + EasyMock.anyObject())).andAnswer(new IAnswer[Unit] { + override def answer = capturedArgument.getValue.apply( + Map(new TopicPartition(Topic.GROUP_METADATA_TOPIC_NAME, groupPartitionId) -> + new PartitionResponse(Errors.NONE, 0L, RecordBatch.NO_TIMESTAMP, 0L) + ) + )}) + EasyMock.expect(replicaManager.getMagic(EasyMock.anyObject())).andReturn(Some(RecordBatch.MAGIC_VALUE_V1)).anyTimes() + EasyMock.replay(replicaManager) + + 
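+    // With the expectations armed, drive the coordinator; the callback captured above completes the assignment write with Errors.NONE.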
groupCoordinator.handleSyncGroup(groupId, generation, leaderId, protocolType, protocolName, + groupInstanceId, assignment, responseCallback) + responseFuture + } + + private def sendSyncGroupFollower(groupId: String, + generation: Int, + memberId: String, + prototolType: Option[String] = None, + prototolName: Option[String] = None, + groupInstanceId: Option[String] = None): Future[SyncGroupResult] = { + val (responseFuture, responseCallback) = setupSyncGroupCallback + + EasyMock.replay(replicaManager) + + groupCoordinator.handleSyncGroup(groupId, generation, memberId, + prototolType, prototolName, groupInstanceId, Map.empty[String, Array[Byte]], responseCallback) + responseFuture + } + + private def dynamicJoinGroup(groupId: String, + memberId: String, + protocolType: String, + protocols: List[(String, Array[Byte])], + sessionTimeout: Int = DefaultSessionTimeout, + rebalanceTimeout: Int = DefaultRebalanceTimeout): JoinGroupResult = { + val requireKnownMemberId = true + var responseFuture = sendJoinGroup(groupId, memberId, protocolType, protocols, None, sessionTimeout, rebalanceTimeout, requireKnownMemberId) + + // Since member id is required, we need another bounce to get the successful join group result. + if (memberId == JoinGroupRequest.UNKNOWN_MEMBER_ID && requireKnownMemberId) { + val joinGroupResult = Await.result(responseFuture, Duration(rebalanceTimeout + 100, TimeUnit.MILLISECONDS)) + // If some other error is triggered, return the error immediately for caller to handle. + if (joinGroupResult.error != Errors.MEMBER_ID_REQUIRED) { + return joinGroupResult + } + EasyMock.reset(replicaManager) + responseFuture = sendJoinGroup(groupId, joinGroupResult.memberId, protocolType, protocols, None, sessionTimeout, rebalanceTimeout, requireKnownMemberId) + } + timer.advanceClock(GroupInitialRebalanceDelay + 1) + // should only have to wait as long as session timeout, but allow some extra time in case of an unexpected delay + Await.result(responseFuture, Duration(rebalanceTimeout + 100, TimeUnit.MILLISECONDS)) + } + + private def staticJoinGroup(groupId: String, + memberId: String, + groupInstanceId: String, + protocolType: String, + protocols: List[(String, Array[Byte])], + clockAdvance: Int = GroupInitialRebalanceDelay + 1, + sessionTimeout: Int = DefaultSessionTimeout, + rebalanceTimeout: Int = DefaultRebalanceTimeout): JoinGroupResult = { + val responseFuture = sendJoinGroup(groupId, memberId, protocolType, protocols, Some(groupInstanceId), sessionTimeout, rebalanceTimeout) + + timer.advanceClock(clockAdvance) + // should only have to wait as long as session timeout, but allow some extra time in case of an unexpected delay + Await.result(responseFuture, Duration(rebalanceTimeout + 100, TimeUnit.MILLISECONDS)) + } + + private def staticJoinGroupWithPersistence(groupId: String, + memberId: String, + groupInstanceId: String, + protocolType: String, + protocols: List[(String, Array[Byte])], + clockAdvance: Int, + sessionTimeout: Int = DefaultSessionTimeout, + rebalanceTimeout: Int = DefaultRebalanceTimeout, + appendRecordError: Errors = Errors.NONE): JoinGroupResult = { + val responseFuture = sendStaticJoinGroupWithPersistence(groupId, memberId, protocolType, protocols, + groupInstanceId, sessionTimeout, rebalanceTimeout, appendRecordError) + + timer.advanceClock(clockAdvance) + // should only have to wait as long as session timeout, but allow some extra time in case of an unexpected delay + Await.result(responseFuture, Duration(rebalanceTimeout + 100, TimeUnit.MILLISECONDS)) + } + + private 
def syncGroupFollower(groupId: String, + generationId: Int, + memberId: String, + protocolType: Option[String] = None, + protocolName: Option[String] = None, + groupInstanceId: Option[String] = None, + sessionTimeout: Int = DefaultSessionTimeout): SyncGroupResult = { + val responseFuture = sendSyncGroupFollower(groupId, generationId, memberId, protocolType, + protocolName, groupInstanceId) + Await.result(responseFuture, Duration(sessionTimeout + 100, TimeUnit.MILLISECONDS)) + } + + private def syncGroupLeader(groupId: String, + generationId: Int, + memberId: String, + assignment: Map[String, Array[Byte]], + protocolType: Option[String] = None, + protocolName: Option[String] = None, + groupInstanceId: Option[String] = None, + sessionTimeout: Int = DefaultSessionTimeout): SyncGroupResult = { + val responseFuture = sendSyncGroupLeader(groupId, generationId, memberId, protocolType, + protocolName, groupInstanceId, assignment) + Await.result(responseFuture, Duration(sessionTimeout + 100, TimeUnit.MILLISECONDS)) + } + + private def heartbeat(groupId: String, + consumerId: String, + generationId: Int, + groupInstanceId: Option[String] = None): HeartbeatCallbackParams = { + val (responseFuture, responseCallback) = setupHeartbeatCallback + + EasyMock.replay(replicaManager) + + groupCoordinator.handleHeartbeat(groupId, consumerId, groupInstanceId, generationId, responseCallback) + Await.result(responseFuture, Duration(40, TimeUnit.MILLISECONDS)) + } + + private def await[T](future: Future[T], millis: Long): T = { + Await.result(future, Duration(millis, TimeUnit.MILLISECONDS)) + } + + private def commitOffsets(groupId: String, + memberId: String, + generationId: Int, + offsets: Map[TopicPartition, OffsetAndMetadata], + groupInstanceId: Option[String] = None): CommitOffsetCallbackParams = { + val (responseFuture, responseCallback) = setupCommitOffsetsCallback + + val capturedArgument: Capture[scala.collection.Map[TopicPartition, PartitionResponse] => Unit] = EasyMock.newCapture() + + EasyMock.expect(replicaManager.appendRecords(EasyMock.anyLong(), + EasyMock.anyShort(), + internalTopicsAllowed = EasyMock.eq(true), + origin = EasyMock.eq(AppendOrigin.Coordinator), + EasyMock.anyObject().asInstanceOf[Map[TopicPartition, MemoryRecords]], + EasyMock.capture(capturedArgument), + EasyMock.anyObject().asInstanceOf[Option[ReentrantLock]], + EasyMock.anyObject(), + EasyMock.anyObject()) + ).andAnswer(new IAnswer[Unit] { + override def answer = capturedArgument.getValue.apply( + Map(new TopicPartition(Topic.GROUP_METADATA_TOPIC_NAME, groupPartitionId) -> + new PartitionResponse(Errors.NONE, 0L, RecordBatch.NO_TIMESTAMP, 0L) + ) + ) + }) + EasyMock.expect(replicaManager.getMagic(EasyMock.anyObject())).andReturn(Some(RecordBatch.MAGIC_VALUE_V1)).anyTimes() + EasyMock.replay(replicaManager) + + groupCoordinator.handleCommitOffsets(groupId, memberId, groupInstanceId, generationId, offsets, responseCallback) + Await.result(responseFuture, Duration(40, TimeUnit.MILLISECONDS)) + } + + private def commitTransactionalOffsets(groupId: String, + producerId: Long, + producerEpoch: Short, + offsets: Map[TopicPartition, OffsetAndMetadata], + memberId: String = JoinGroupRequest.UNKNOWN_MEMBER_ID, + groupInstanceId: Option[String] = Option.empty, + generationId: Int = JoinGroupRequest.UNKNOWN_GENERATION_ID) = { + val (responseFuture, responseCallback) = setupCommitOffsetsCallback + + val capturedArgument: Capture[scala.collection.Map[TopicPartition, PartitionResponse] => Unit] = EasyMock.newCapture() + + 
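+    // Capture the append callback so the stubbed appendRecords can acknowledge the transactional offset write; note the V2 magic value below, which transactional records require.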
EasyMock.expect(replicaManager.appendRecords(EasyMock.anyLong(), + EasyMock.anyShort(), + internalTopicsAllowed = EasyMock.eq(true), + origin = EasyMock.eq(AppendOrigin.Coordinator), + EasyMock.anyObject().asInstanceOf[Map[TopicPartition, MemoryRecords]], + EasyMock.capture(capturedArgument), + EasyMock.anyObject().asInstanceOf[Option[ReentrantLock]], + EasyMock.anyObject(), + EasyMock.anyObject()) + ).andAnswer(new IAnswer[Unit] { + override def answer = capturedArgument.getValue.apply( + Map(new TopicPartition(Topic.GROUP_METADATA_TOPIC_NAME, groupCoordinator.partitionFor(groupId)) -> + new PartitionResponse(Errors.NONE, 0L, RecordBatch.NO_TIMESTAMP, 0L) + ) + )}) + EasyMock.expect(replicaManager.getMagic(EasyMock.anyObject())).andReturn(Some(RecordBatch.MAGIC_VALUE_V2)).anyTimes() + EasyMock.replay(replicaManager) + + groupCoordinator.handleTxnCommitOffsets(groupId, producerId, producerEpoch, + memberId, groupInstanceId, generationId, offsets, responseCallback) + val result = Await.result(responseFuture, Duration(40, TimeUnit.MILLISECONDS)) + EasyMock.reset(replicaManager) + result + } + + private def singleLeaveGroup(groupId: String, + consumerId: String, + groupInstanceId: Option[String] = None): LeaveGroupResult = { + val singleMemberIdentity = List( + new MemberIdentity() + .setMemberId(consumerId) + .setGroupInstanceId(groupInstanceId.orNull)) + batchLeaveGroup(groupId, singleMemberIdentity) + } + + private def batchLeaveGroup(groupId: String, + memberIdentities: List[MemberIdentity]): LeaveGroupResult = { + val (responseFuture, responseCallback) = setupLeaveGroupCallback + + EasyMock.expect(replicaManager.getPartition(new TopicPartition(Topic.GROUP_METADATA_TOPIC_NAME, groupPartitionId))) + .andReturn(HostedPartition.None) + EasyMock.expect(replicaManager.getMagic(EasyMock.anyObject())).andReturn(Some(RecordBatch.MAGIC_VALUE_V1)).anyTimes() + EasyMock.replay(replicaManager) + + groupCoordinator.handleLeaveGroup(groupId, memberIdentities, responseCallback) + Await.result(responseFuture, Duration(40, TimeUnit.MILLISECONDS)) + } + + def handleTxnCompletion(producerId: Long, + offsetsPartitions: Iterable[TopicPartition], + transactionResult: TransactionResult): Unit = { + val isCommit = transactionResult == TransactionResult.COMMIT + groupCoordinator.groupManager.handleTxnCompletion(producerId, offsetsPartitions.map(_.partition).toSet, isCommit) + } + + private def offsetAndMetadata(offset: Long): OffsetAndMetadata = { + OffsetAndMetadata(offset, "", timer.time.milliseconds()) + } +} + +object GroupCoordinatorTest { + def verifyLeaveGroupResult(leaveGroupResult: LeaveGroupResult, + expectedTopLevelError: Errors = Errors.NONE, + expectedMemberLevelErrors: List[Errors] = List.empty): Unit = { + assertEquals(expectedTopLevelError, leaveGroupResult.topLevelError) + if (expectedMemberLevelErrors.nonEmpty) { + assertEquals(expectedMemberLevelErrors.size, leaveGroupResult.memberResponses.size) + for (i <- expectedMemberLevelErrors.indices) { + assertEquals(expectedMemberLevelErrors(i), leaveGroupResult.memberResponses(i).error) + } + } + } +} diff --git a/core/src/test/scala/unit/kafka/coordinator/group/GroupMetadataManagerTest.scala b/core/src/test/scala/unit/kafka/coordinator/group/GroupMetadataManagerTest.scala new file mode 100644 index 0000000..5fe4bf9 --- /dev/null +++ b/core/src/test/scala/unit/kafka/coordinator/group/GroupMetadataManagerTest.scala @@ -0,0 +1,2651 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.coordinator.group + +import java.lang.management.ManagementFactory +import java.nio.ByteBuffer +import java.util.concurrent.locks.ReentrantLock +import java.util.{Collections, Optional} +import com.yammer.metrics.core.Gauge + +import javax.management.ObjectName +import kafka.api._ +import kafka.cluster.Partition +import kafka.common.OffsetAndMetadata +import kafka.log.{AppendOrigin, UnifiedLog, LogAppendInfo} +import kafka.metrics.KafkaYammerMetrics +import kafka.server.{FetchDataInfo, FetchLogEnd, HostedPartition, KafkaConfig, LogOffsetMetadata, ReplicaManager, RequestLocal} +import kafka.utils.{KafkaScheduler, MockTime, TestUtils} +import org.apache.kafka.clients.consumer.ConsumerPartitionAssignor +import org.apache.kafka.clients.consumer.ConsumerPartitionAssignor.Subscription +import org.apache.kafka.clients.consumer.internals.ConsumerProtocol +import org.apache.kafka.common.TopicPartition +import org.apache.kafka.common.internals.Topic +import org.apache.kafka.common.metrics.{JmxReporter, KafkaMetricsContext, Metrics => kMetrics} +import org.apache.kafka.common.protocol.Errors +import org.apache.kafka.common.record._ +import org.apache.kafka.common.requests.OffsetFetchResponse +import org.apache.kafka.common.requests.ProduceResponse.PartitionResponse +import org.apache.kafka.common.utils.Utils +import org.easymock.{Capture, EasyMock, IAnswer} +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.{AfterEach, BeforeEach, Test} + +import scala.jdk.CollectionConverters._ +import scala.collection._ + +class GroupMetadataManagerTest { + + var time: MockTime = null + var replicaManager: ReplicaManager = null + var groupMetadataManager: GroupMetadataManager = null + var scheduler: KafkaScheduler = null + var partition: Partition = null + var defaultOffsetRetentionMs = Long.MaxValue + var metrics: kMetrics = null + + val groupId = "foo" + val groupInstanceId = "bar" + val groupPartitionId = 0 + val groupTopicPartition = new TopicPartition(Topic.GROUP_METADATA_TOPIC_NAME, groupPartitionId) + val protocolType = "protocolType" + val rebalanceTimeout = 60000 + val sessionTimeout = 10000 + val defaultRequireStable = false + val numOffsetsPartitions = 2 + + private val offsetConfig = { + val config = KafkaConfig.fromProps(TestUtils.createBrokerConfig(nodeId = 0, zkConnect = "")) + OffsetConfig(maxMetadataSize = config.offsetMetadataMaxSize, + loadBufferSize = config.offsetsLoadBufferSize, + offsetsRetentionMs = config.offsetsRetentionMinutes * 60 * 1000L, + offsetsRetentionCheckIntervalMs = config.offsetsRetentionCheckIntervalMs, + offsetsTopicNumPartitions = config.offsetsTopicPartitions, + offsetsTopicSegmentBytes = config.offsetsTopicSegmentBytes, + offsetsTopicReplicationFactor = config.offsetsTopicReplicationFactor, + offsetsTopicCompressionCodec = 
config.offsetsTopicCompressionCodec, + offsetCommitTimeoutMs = config.offsetCommitTimeoutMs, + offsetCommitRequiredAcks = config.offsetCommitRequiredAcks) + } + + @BeforeEach + def setUp(): Unit = { + defaultOffsetRetentionMs = offsetConfig.offsetsRetentionMs + metrics = new kMetrics() + time = new MockTime + replicaManager = EasyMock.createNiceMock(classOf[ReplicaManager]) + groupMetadataManager = new GroupMetadataManager(0, ApiVersion.latestVersion, offsetConfig, replicaManager, + time, metrics) + groupMetadataManager.startup(() => numOffsetsPartitions, false) + partition = EasyMock.niceMock(classOf[Partition]) + } + + @AfterEach + def tearDown(): Unit = { + groupMetadataManager.shutdown() + } + + @Test + def testLogInfoFromCleanupGroupMetadata(): Unit = { + var expiredOffsets: Int = 0 + var infoCount = 0 + val gmm = new GroupMetadataManager(0, ApiVersion.latestVersion, offsetConfig, replicaManager, time, metrics) { + override def cleanupGroupMetadata(groups: Iterable[GroupMetadata], requestLocal: RequestLocal, + selector: GroupMetadata => Map[TopicPartition, OffsetAndMetadata]): Int = expiredOffsets + + override def info(msg: => String): Unit = infoCount += 1 + } + gmm.startup(() => numOffsetsPartitions, false) + try { + // if there are no offsets to expire, we skip to log + gmm.cleanupGroupMetadata() + assertEquals(0, infoCount) + // if there are offsets to expire, we should log info + expiredOffsets = 100 + gmm.cleanupGroupMetadata() + assertEquals(1, infoCount) + } finally { + gmm.shutdown() + } + } + + @Test + def testLoadOffsetsWithoutGroup(): Unit = { + val groupMetadataTopicPartition = groupTopicPartition + val startOffset = 15L + val groupEpoch = 2 + + val committedOffsets = Map( + new TopicPartition("foo", 0) -> 23L, + new TopicPartition("foo", 1) -> 455L, + new TopicPartition("bar", 0) -> 8992L + ) + + val offsetCommitRecords = createCommittedOffsetRecords(committedOffsets) + val records = MemoryRecords.withRecords(startOffset, CompressionType.NONE, offsetCommitRecords.toArray: _*) + expectGroupMetadataLoad(groupMetadataTopicPartition, startOffset, records) + + EasyMock.replay(replicaManager) + + groupMetadataManager.loadGroupsAndOffsets(groupMetadataTopicPartition, groupEpoch, _ => (), 0L) + + val group = groupMetadataManager.getGroup(groupId).getOrElse(throw new AssertionError("Group was not loaded into the cache")) + assertEquals(groupId, group.groupId) + assertEquals(Empty, group.currentState) + assertEquals(committedOffsets.size, group.allOffsets.size) + committedOffsets.foreach { case (topicPartition, offset) => + assertEquals(Some(offset), group.offset(topicPartition).map(_.offset)) + } + } + + @Test + def testLoadEmptyGroupWithOffsets(): Unit = { + val groupMetadataTopicPartition = groupTopicPartition + val generation = 15 + val protocolType = "consumer" + val startOffset = 15L + val groupEpoch = 2 + val committedOffsets = Map( + new TopicPartition("foo", 0) -> 23L, + new TopicPartition("foo", 1) -> 455L, + new TopicPartition("bar", 0) -> 8992L + ) + + val offsetCommitRecords = createCommittedOffsetRecords(committedOffsets) + val groupMetadataRecord = buildEmptyGroupRecord(generation, protocolType) + val records = MemoryRecords.withRecords(startOffset, CompressionType.NONE, + (offsetCommitRecords ++ Seq(groupMetadataRecord)).toArray: _*) + + expectGroupMetadataLoad(groupMetadataTopicPartition, startOffset, records) + + EasyMock.replay(replicaManager) + + groupMetadataManager.loadGroupsAndOffsets(groupMetadataTopicPartition, groupEpoch, _ => (), 0L) + + val group = 
groupMetadataManager.getGroup(groupId).getOrElse(throw new AssertionError("Group was not loaded into the cache")) + assertEquals(groupId, group.groupId) + assertEquals(Empty, group.currentState) + assertEquals(generation, group.generationId) + assertEquals(Some(protocolType), group.protocolType) + assertNull(group.leaderOrNull) + assertNull(group.protocolName.orNull) + committedOffsets.foreach { case (topicPartition, offset) => + assertEquals(Some(offset), group.offset(topicPartition).map(_.offset)) + } + } + + @Test + def testLoadTransactionalOffsetsWithoutGroup(): Unit = { + val groupMetadataTopicPartition = groupTopicPartition + val producerId = 1000L + val producerEpoch: Short = 2 + val groupEpoch = 2 + + val committedOffsets = Map( + new TopicPartition("foo", 0) -> 23L, + new TopicPartition("foo", 1) -> 455L, + new TopicPartition("bar", 0) -> 8992L + ) + + val buffer = ByteBuffer.allocate(1024) + var nextOffset = 0 + nextOffset += appendTransactionalOffsetCommits(buffer, producerId, producerEpoch, nextOffset, committedOffsets) + nextOffset += completeTransactionalOffsetCommit(buffer, producerId, producerEpoch, nextOffset, isCommit = true) + buffer.flip() + + val records = MemoryRecords.readableRecords(buffer) + expectGroupMetadataLoad(groupMetadataTopicPartition, 0, records) + + EasyMock.replay(replicaManager) + + groupMetadataManager.loadGroupsAndOffsets(groupMetadataTopicPartition, groupEpoch, _ => (), 0L) + + val group = groupMetadataManager.getGroup(groupId).getOrElse(throw new AssertionError("Group was not loaded into the cache")) + assertEquals(groupId, group.groupId) + assertEquals(Empty, group.currentState) + assertEquals(committedOffsets.size, group.allOffsets.size) + committedOffsets.foreach { case (topicPartition, offset) => + assertEquals(Some(offset), group.offset(topicPartition).map(_.offset)) + } + } + + @Test + def testDoNotLoadAbortedTransactionalOffsetCommits(): Unit = { + val groupMetadataTopicPartition = groupTopicPartition + val producerId = 1000L + val producerEpoch: Short = 2 + val groupEpoch = 2 + + val abortedOffsets = Map( + new TopicPartition("foo", 0) -> 23L, + new TopicPartition("foo", 1) -> 455L, + new TopicPartition("bar", 0) -> 8992L + ) + + val buffer = ByteBuffer.allocate(1024) + var nextOffset = 0 + nextOffset += appendTransactionalOffsetCommits(buffer, producerId, producerEpoch, nextOffset, abortedOffsets) + nextOffset += completeTransactionalOffsetCommit(buffer, producerId, producerEpoch, nextOffset, isCommit = false) + buffer.flip() + + val records = MemoryRecords.readableRecords(buffer) + expectGroupMetadataLoad(groupMetadataTopicPartition, 0, records) + + EasyMock.replay(replicaManager) + + groupMetadataManager.loadGroupsAndOffsets(groupMetadataTopicPartition, groupEpoch, _ => (), 0L) + + // Since there are no committed offsets for the group, and there is no other group metadata, we don't expect the + // group to be loaded. 
+ assertEquals(None, groupMetadataManager.getGroup(groupId)) + } + + @Test + def testGroupLoadedWithPendingCommits(): Unit = { + val groupMetadataTopicPartition = groupTopicPartition + val producerId = 1000L + val producerEpoch: Short = 2 + val groupEpoch = 2 + + val foo0 = new TopicPartition("foo", 0) + val foo1 = new TopicPartition("foo", 1) + val bar0 = new TopicPartition("bar", 0) + val pendingOffsets = Map( + foo0 -> 23L, + foo1 -> 455L, + bar0 -> 8992L + ) + + val buffer = ByteBuffer.allocate(1024) + var nextOffset = 0 + nextOffset += appendTransactionalOffsetCommits(buffer, producerId, producerEpoch, nextOffset, pendingOffsets) + buffer.flip() + + val records = MemoryRecords.readableRecords(buffer) + expectGroupMetadataLoad(groupMetadataTopicPartition, 0, records) + + EasyMock.replay(replicaManager) + + groupMetadataManager.loadGroupsAndOffsets(groupMetadataTopicPartition, groupEpoch, _ => (), 0L) + + // The group should be loaded with pending offsets. + val group = groupMetadataManager.getGroup(groupId).getOrElse(throw new AssertionError("Group was not loaded into the cache")) + assertEquals(groupId, group.groupId) + assertEquals(Empty, group.currentState) + // Ensure that no offsets are materialized, but that we have offsets pending. + assertEquals(0, group.allOffsets.size) + assertTrue(group.hasOffsets) + assertTrue(group.hasPendingOffsetCommitsFromProducer(producerId)) + assertTrue(group.hasPendingOffsetCommitsForTopicPartition(foo0)) + assertTrue(group.hasPendingOffsetCommitsForTopicPartition(foo1)) + assertTrue(group.hasPendingOffsetCommitsForTopicPartition(bar0)) + } + + @Test + def testLoadWithCommittedAndAbortedTransactionalOffsetCommits(): Unit = { + // A test which loads a log with a mix of committed and aborted transactional offset committed messages. + val groupMetadataTopicPartition = groupTopicPartition + val producerId = 1000L + val producerEpoch: Short = 2 + val groupEpoch = 2 + + val committedOffsets = Map( + new TopicPartition("foo", 0) -> 23L, + new TopicPartition("foo", 1) -> 455L, + new TopicPartition("bar", 0) -> 8992L + ) + + val abortedOffsets = Map( + new TopicPartition("foo", 2) -> 231L, + new TopicPartition("foo", 3) -> 4551L, + new TopicPartition("bar", 1) -> 89921L + ) + + val buffer = ByteBuffer.allocate(1024) + var nextOffset = 0 + nextOffset += appendTransactionalOffsetCommits(buffer, producerId, producerEpoch, nextOffset, abortedOffsets) + nextOffset += completeTransactionalOffsetCommit(buffer, producerId, producerEpoch, nextOffset, isCommit = false) + nextOffset += appendTransactionalOffsetCommits(buffer, producerId, producerEpoch, nextOffset, committedOffsets) + nextOffset += completeTransactionalOffsetCommit(buffer, producerId, producerEpoch, nextOffset, isCommit = true) + buffer.flip() + + val records = MemoryRecords.readableRecords(buffer) + expectGroupMetadataLoad(groupMetadataTopicPartition, 0, records) + + EasyMock.replay(replicaManager) + + groupMetadataManager.loadGroupsAndOffsets(groupMetadataTopicPartition, groupEpoch, _ => (), 0L) + + val group = groupMetadataManager.getGroup(groupId).getOrElse(throw new AssertionError("Group was not loaded into the cache")) + assertEquals(groupId, group.groupId) + assertEquals(Empty, group.currentState) + // Ensure that only the committed offsets are materialized, and that there are no pending commits for the producer. + // This allows us to be certain that the aborted offset commits are truly discarded. 
+ assertEquals(committedOffsets.size, group.allOffsets.size) + committedOffsets.foreach { case (topicPartition, offset) => + assertEquals(Some(offset), group.offset(topicPartition).map(_.offset)) + } + assertFalse(group.hasPendingOffsetCommitsFromProducer(producerId)) + } + + @Test + def testLoadWithCommittedAndAbortedAndPendingTransactionalOffsetCommits(): Unit = { + val groupMetadataTopicPartition = groupTopicPartition + val producerId = 1000L + val producerEpoch: Short = 2 + val groupEpoch = 2 + + val committedOffsets = Map( + new TopicPartition("foo", 0) -> 23L, + new TopicPartition("foo", 1) -> 455L, + new TopicPartition("bar", 0) -> 8992L + ) + + val foo3 = new TopicPartition("foo", 3) + + val abortedOffsets = Map( + new TopicPartition("foo", 2) -> 231L, + foo3 -> 4551L, + new TopicPartition("bar", 1) -> 89921L + ) + + val pendingOffsets = Map( + foo3 -> 2312L, + new TopicPartition("foo", 4) -> 45512L, + new TopicPartition("bar", 2) -> 899212L + ) + + val buffer = ByteBuffer.allocate(1024) + var nextOffset = 0 + val commitOffsetsLogPosition = nextOffset + nextOffset += appendTransactionalOffsetCommits(buffer, producerId, producerEpoch, nextOffset, committedOffsets) + nextOffset += completeTransactionalOffsetCommit(buffer, producerId, producerEpoch, nextOffset, isCommit = true) + nextOffset += appendTransactionalOffsetCommits(buffer, producerId, producerEpoch, nextOffset, abortedOffsets) + nextOffset += completeTransactionalOffsetCommit(buffer, producerId, producerEpoch, nextOffset, isCommit = false) + nextOffset += appendTransactionalOffsetCommits(buffer, producerId, producerEpoch, nextOffset, pendingOffsets) + buffer.flip() + + val records = MemoryRecords.readableRecords(buffer) + expectGroupMetadataLoad(groupMetadataTopicPartition, 0, records) + + EasyMock.replay(replicaManager) + + groupMetadataManager.loadGroupsAndOffsets(groupMetadataTopicPartition, groupEpoch, _ => (), 0L) + + val group = groupMetadataManager.getGroup(groupId).getOrElse(throw new AssertionError("Group was not loaded into the cache")) + assertEquals(groupId, group.groupId) + assertEquals(Empty, group.currentState) + + // Ensure that only the committed offsets are materialized, and that there are no pending commits for the producer. + // This allows us to be certain that the aborted offset commits are truly discarded. + assertEquals(committedOffsets.size, group.allOffsets.size) + committedOffsets.foreach { case (topicPartition, offset) => + assertEquals(Some(offset), group.offset(topicPartition).map(_.offset)) + assertEquals(Some(commitOffsetsLogPosition), group.offsetWithRecordMetadata(topicPartition).head.appendedBatchOffset) + } + + // We should have pending commits. + assertTrue(group.hasPendingOffsetCommitsFromProducer(producerId)) + assertTrue(group.hasPendingOffsetCommitsForTopicPartition(foo3)) + + // The loaded pending commits should materialize after a commit marker comes in. 
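+    // (Calling handleTxnCompletion with isCommit = true stands in for the COMMIT marker that was never written to the loaded log segment.)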
+ groupMetadataManager.handleTxnCompletion(producerId, List(groupMetadataTopicPartition.partition).toSet, isCommit = true) + assertFalse(group.hasPendingOffsetCommitsFromProducer(producerId)) + pendingOffsets.foreach { case (topicPartition, offset) => + assertEquals(Some(offset), group.offset(topicPartition).map(_.offset)) + } + } + + @Test + def testLoadTransactionalOffsetCommitsFromMultipleProducers(): Unit = { + val groupMetadataTopicPartition = groupTopicPartition + val firstProducerId = 1000L + val firstProducerEpoch: Short = 2 + val secondProducerId = 1001L + val secondProducerEpoch: Short = 3 + val groupEpoch = 2 + + val committedOffsetsFirstProducer = Map( + new TopicPartition("foo", 0) -> 23L, + new TopicPartition("foo", 1) -> 455L, + new TopicPartition("bar", 0) -> 8992L + ) + + val committedOffsetsSecondProducer = Map( + new TopicPartition("foo", 2) -> 231L, + new TopicPartition("foo", 3) -> 4551L, + new TopicPartition("bar", 1) -> 89921L + ) + + val buffer = ByteBuffer.allocate(1024) + var nextOffset = 0L + + val firstProduceRecordOffset = nextOffset + nextOffset += appendTransactionalOffsetCommits(buffer, firstProducerId, firstProducerEpoch, nextOffset, committedOffsetsFirstProducer) + nextOffset += completeTransactionalOffsetCommit(buffer, firstProducerId, firstProducerEpoch, nextOffset, isCommit = true) + + val secondProducerRecordOffset = nextOffset + nextOffset += appendTransactionalOffsetCommits(buffer, secondProducerId, secondProducerEpoch, nextOffset, committedOffsetsSecondProducer) + nextOffset += completeTransactionalOffsetCommit(buffer, secondProducerId, secondProducerEpoch, nextOffset, isCommit = true) + buffer.flip() + + val records = MemoryRecords.readableRecords(buffer) + expectGroupMetadataLoad(groupMetadataTopicPartition, 0, records) + + EasyMock.replay(replicaManager) + + groupMetadataManager.loadGroupsAndOffsets(groupMetadataTopicPartition, groupEpoch, _ => (), 0L) + + val group = groupMetadataManager.getGroup(groupId).getOrElse(throw new AssertionError("Group was not loaded into the cache")) + assertEquals(groupId, group.groupId) + assertEquals(Empty, group.currentState) + + // Ensure that only the committed offsets are materialized, and that there are no pending commits for the producer. + // This allows us to be certain that the aborted offset commits are truly discarded. 
+ assertEquals(committedOffsetsFirstProducer.size + committedOffsetsSecondProducer.size, group.allOffsets.size) + committedOffsetsFirstProducer.foreach { case (topicPartition, offset) => + assertEquals(Some(offset), group.offset(topicPartition).map(_.offset)) + assertEquals(Some(firstProduceRecordOffset), group.offsetWithRecordMetadata(topicPartition).head.appendedBatchOffset) + } + committedOffsetsSecondProducer.foreach { case (topicPartition, offset) => + assertEquals(Some(offset), group.offset(topicPartition).map(_.offset)) + assertEquals(Some(secondProducerRecordOffset), group.offsetWithRecordMetadata(topicPartition).head.appendedBatchOffset) + } + } + + @Test + def testGroupLoadWithConsumerAndTransactionalOffsetCommitsConsumerWins(): Unit = { + val groupMetadataTopicPartition = groupTopicPartition + val producerId = 1000L + val producerEpoch: Short = 2 + val groupEpoch = 2 + + val transactionalOffsetCommits = Map( + new TopicPartition("foo", 0) -> 23L + ) + + val consumerOffsetCommits = Map( + new TopicPartition("foo", 0) -> 24L + ) + + val buffer = ByteBuffer.allocate(1024) + var nextOffset = 0 + nextOffset += appendTransactionalOffsetCommits(buffer, producerId, producerEpoch, nextOffset, transactionalOffsetCommits) + val consumerRecordOffset = nextOffset + nextOffset += appendConsumerOffsetCommit(buffer, nextOffset, consumerOffsetCommits) + nextOffset += completeTransactionalOffsetCommit(buffer, producerId, producerEpoch, nextOffset, isCommit = true) + buffer.flip() + + val records = MemoryRecords.readableRecords(buffer) + expectGroupMetadataLoad(groupMetadataTopicPartition, 0, records) + + EasyMock.replay(replicaManager) + + groupMetadataManager.loadGroupsAndOffsets(groupMetadataTopicPartition, groupEpoch, _ => (), 0L) + + // The group should be loaded with pending offsets. 
+ val group = groupMetadataManager.getGroup(groupId).getOrElse(throw new AssertionError("Group was not loaded into the cache")) + assertEquals(groupId, group.groupId) + assertEquals(Empty, group.currentState) + assertEquals(1, group.allOffsets.size) + assertTrue(group.hasOffsets) + assertFalse(group.hasPendingOffsetCommitsFromProducer(producerId)) + assertEquals(consumerOffsetCommits.size, group.allOffsets.size) + consumerOffsetCommits.foreach { case (topicPartition, offset) => + assertEquals(Some(offset), group.offset(topicPartition).map(_.offset)) + assertEquals(Some(consumerRecordOffset), group.offsetWithRecordMetadata(topicPartition).head.appendedBatchOffset) + } + } + + @Test + def testGroupLoadWithConsumerAndTransactionalOffsetCommitsTransactionWins(): Unit = { + val groupMetadataTopicPartition = groupTopicPartition + val producerId = 1000L + val producerEpoch: Short = 2 + val groupEpoch = 2 + + val transactionalOffsetCommits = Map( + new TopicPartition("foo", 0) -> 23L + ) + + val consumerOffsetCommits = Map( + new TopicPartition("foo", 0) -> 24L + ) + + val buffer = ByteBuffer.allocate(1024) + var nextOffset = 0 + nextOffset += appendConsumerOffsetCommit(buffer, nextOffset, consumerOffsetCommits) + nextOffset += appendTransactionalOffsetCommits(buffer, producerId, producerEpoch, nextOffset, transactionalOffsetCommits) + nextOffset += completeTransactionalOffsetCommit(buffer, producerId, producerEpoch, nextOffset, isCommit = true) + buffer.flip() + + val records = MemoryRecords.readableRecords(buffer) + expectGroupMetadataLoad(groupMetadataTopicPartition, 0, records) + + EasyMock.replay(replicaManager) + + groupMetadataManager.loadGroupsAndOffsets(groupMetadataTopicPartition, groupEpoch, _ => (), 0L) + + // The group should be loaded with pending offsets. 
+ val group = groupMetadataManager.getGroup(groupId).getOrElse(throw new AssertionError("Group was not loaded into the cache")) + assertEquals(groupId, group.groupId) + assertEquals(Empty, group.currentState) + assertEquals(1, group.allOffsets.size) + assertTrue(group.hasOffsets) + assertFalse(group.hasPendingOffsetCommitsFromProducer(producerId)) + assertEquals(consumerOffsetCommits.size, group.allOffsets.size) + transactionalOffsetCommits.foreach { case (topicPartition, offset) => + assertEquals(Some(offset), group.offset(topicPartition).map(_.offset)) + } + } + + @Test + def testGroupNotExists(): Unit = { + // group is not owned + assertFalse(groupMetadataManager.groupNotExists(groupId)) + + groupMetadataManager.addPartitionOwnership(groupPartitionId) + // group is owned but does not exist yet + assertTrue(groupMetadataManager.groupNotExists(groupId)) + + val group = new GroupMetadata(groupId, Empty, time) + groupMetadataManager.addGroup(group) + + // group is owned but not Dead + assertFalse(groupMetadataManager.groupNotExists(groupId)) + + group.transitionTo(Dead) + // group is owned and Dead + assertTrue(groupMetadataManager.groupNotExists(groupId)) + } + + private def appendConsumerOffsetCommit(buffer: ByteBuffer, baseOffset: Long, offsets: Map[TopicPartition, Long]) = { + val builder = MemoryRecords.builder(buffer, CompressionType.NONE, TimestampType.LOG_APPEND_TIME, baseOffset) + val commitRecords = createCommittedOffsetRecords(offsets) + commitRecords.foreach(builder.append) + builder.build() + offsets.size + } + + private def appendTransactionalOffsetCommits(buffer: ByteBuffer, producerId: Long, producerEpoch: Short, + baseOffset: Long, offsets: Map[TopicPartition, Long]): Int = { + val builder = MemoryRecords.builder(buffer, CompressionType.NONE, baseOffset, producerId, producerEpoch, 0, true) + val commitRecords = createCommittedOffsetRecords(offsets) + commitRecords.foreach(builder.append) + builder.build() + offsets.size + } + + private def completeTransactionalOffsetCommit(buffer: ByteBuffer, producerId: Long, producerEpoch: Short, baseOffset: Long, + isCommit: Boolean): Int = { + val builder = MemoryRecords.builder(buffer, RecordBatch.MAGIC_VALUE_V2, CompressionType.NONE, + TimestampType.LOG_APPEND_TIME, baseOffset, time.milliseconds(), producerId, producerEpoch, 0, true, true, + RecordBatch.NO_PARTITION_LEADER_EPOCH) + val controlRecordType = if (isCommit) ControlRecordType.COMMIT else ControlRecordType.ABORT + builder.appendEndTxnMarker(time.milliseconds(), new EndTransactionMarker(controlRecordType, 0)) + builder.build() + 1 + } + + @Test + def testLoadOffsetsWithTombstones(): Unit = { + val groupMetadataTopicPartition = groupTopicPartition + val startOffset = 15L + val groupEpoch = 2 + + val tombstonePartition = new TopicPartition("foo", 1) + val committedOffsets = Map( + new TopicPartition("foo", 0) -> 23L, + tombstonePartition -> 455L, + new TopicPartition("bar", 0) -> 8992L + ) + + val offsetCommitRecords = createCommittedOffsetRecords(committedOffsets) + val tombstone = new SimpleRecord(GroupMetadataManager.offsetCommitKey(groupId, tombstonePartition), null) + val records = MemoryRecords.withRecords(startOffset, CompressionType.NONE, + (offsetCommitRecords ++ Seq(tombstone)).toArray: _*) + + expectGroupMetadataLoad(groupMetadataTopicPartition, startOffset, records) + + EasyMock.replay(replicaManager) + + groupMetadataManager.loadGroupsAndOffsets(groupMetadataTopicPartition, groupEpoch, _ => (), 0L) + + val group = 
groupMetadataManager.getGroup(groupId).getOrElse(throw new AssertionError("Group was not loaded into the cache")) + assertEquals(groupId, group.groupId) + assertEquals(Empty, group.currentState) + assertEquals(committedOffsets.size - 1, group.allOffsets.size) + committedOffsets.foreach { case (topicPartition, offset) => + if (topicPartition == tombstonePartition) + assertEquals(None, group.offset(topicPartition)) + else + assertEquals(Some(offset), group.offset(topicPartition).map(_.offset)) + } + } + + @Test + def testLoadOffsetsAndGroup(): Unit = { + loadOffsetsAndGroup(groupTopicPartition, 2) + } + + def loadOffsetsAndGroup(groupMetadataTopicPartition: TopicPartition, groupEpoch: Int): GroupMetadata = { + val generation = 935 + val protocolType = "consumer" + val protocol = "range" + val startOffset = 15L + val committedOffsets = Map( + new TopicPartition("foo", 0) -> 23L, + new TopicPartition("foo", 1) -> 455L, + new TopicPartition("bar", 0) -> 8992L + ) + + val offsetCommitRecords = createCommittedOffsetRecords(committedOffsets) + val memberId = "98098230493" + val groupMetadataRecord = buildStableGroupRecordWithMember(generation, protocolType, protocol, memberId) + val records = MemoryRecords.withRecords(startOffset, CompressionType.NONE, + (offsetCommitRecords ++ Seq(groupMetadataRecord)).toArray: _*) + + expectGroupMetadataLoad(groupMetadataTopicPartition, startOffset, records) + + EasyMock.replay(replicaManager) + + groupMetadataManager.loadGroupsAndOffsets(groupMetadataTopicPartition, groupEpoch, _ => (), 0L) + + val group = groupMetadataManager.getGroup(groupId).getOrElse(throw new AssertionError("Group was not loaded into the cache")) + assertEquals(groupId, group.groupId) + assertEquals(Stable, group.currentState) + assertEquals(memberId, group.leaderOrNull) + assertEquals(generation, group.generationId) + assertEquals(Some(protocolType), group.protocolType) + assertEquals(protocol, group.protocolName.orNull) + assertEquals(Set(memberId), group.allMembers) + assertEquals(committedOffsets.size, group.allOffsets.size) + committedOffsets.foreach { case (topicPartition, offset) => + assertEquals(Some(offset), group.offset(topicPartition).map(_.offset)) + assertTrue(group.offset(topicPartition).map(_.expireTimestamp).contains(None)) + } + group + } + + @Test + def testLoadOffsetsAndGroupIgnored(): Unit = { + val groupEpoch = 2 + loadOffsetsAndGroup(groupTopicPartition, groupEpoch) + assertEquals(groupEpoch, groupMetadataManager.epochForPartitionId.get(groupTopicPartition.partition())) + + groupMetadataManager.removeGroupsAndOffsets(groupTopicPartition, Some(groupEpoch), _ => ()) + assertTrue(groupMetadataManager.getGroup(groupId).isEmpty, + "Removed group remained in cache") + assertEquals(groupEpoch, groupMetadataManager.epochForPartitionId.get(groupTopicPartition.partition())) + + groupMetadataManager.loadGroupsAndOffsets(groupTopicPartition, groupEpoch - 1, _ => (), 0L) + assertTrue(groupMetadataManager.getGroup(groupId).isEmpty, + "Removed group remained in cache") + assertEquals(groupEpoch, groupMetadataManager.epochForPartitionId.get(groupTopicPartition.partition())) + } + + @Test + def testUnloadOffsetsAndGroup(): Unit = { + val groupEpoch = 2 + loadOffsetsAndGroup(groupTopicPartition, groupEpoch) + + groupMetadataManager.removeGroupsAndOffsets(groupTopicPartition, Some(groupEpoch), _ => ()) + assertEquals(groupEpoch, groupMetadataManager.epochForPartitionId.get(groupTopicPartition.partition())) + assertTrue(groupMetadataManager.getGroup(groupId).isEmpty, + "Removed group 
remained in cache") + } + + @Test + def testUnloadOffsetsAndGroupIgnored(): Unit = { + val groupEpoch = 2 + val initiallyLoaded = loadOffsetsAndGroup(groupTopicPartition, groupEpoch) + + groupMetadataManager.removeGroupsAndOffsets(groupTopicPartition, Some(groupEpoch - 1), _ => ()) + assertEquals(groupEpoch, groupMetadataManager.epochForPartitionId.get(groupTopicPartition.partition())) + val group = groupMetadataManager.getGroup(groupId).getOrElse(throw new AssertionError("Group was not loaded into the cache")) + assertEquals(initiallyLoaded.groupId, group.groupId) + assertEquals(initiallyLoaded.currentState, group.currentState) + assertEquals(initiallyLoaded.leaderOrNull, group.leaderOrNull) + assertEquals(initiallyLoaded.generationId, group.generationId) + assertEquals(initiallyLoaded.protocolType, group.protocolType) + assertEquals(initiallyLoaded.protocolName.orNull, group.protocolName.orNull) + assertEquals(initiallyLoaded.allMembers, group.allMembers) + assertEquals(initiallyLoaded.allOffsets.size, group.allOffsets.size) + initiallyLoaded.allOffsets.foreach { case (topicPartition, offset) => + assertEquals(Some(offset), group.offset(topicPartition)) + assertTrue(group.offset(topicPartition).map(_.expireTimestamp).contains(None)) + } + } + + @Test + def testUnloadOffsetsAndGroupIgnoredAfterStopReplica(): Unit = { + val groupEpoch = 2 + val initiallyLoaded = loadOffsetsAndGroup(groupTopicPartition, groupEpoch) + + groupMetadataManager.removeGroupsAndOffsets(groupTopicPartition, None, _ => ()) + assertTrue(groupMetadataManager.getGroup(groupId).isEmpty, + "Removed group remained in cache") + assertEquals(groupEpoch, groupMetadataManager.epochForPartitionId.get(groupTopicPartition.partition()), + "Replica which was stopped still in epochForPartitionId") + + EasyMock.reset(replicaManager) + loadOffsetsAndGroup(groupTopicPartition, groupEpoch + 1) + val group = groupMetadataManager.getGroup(groupId).getOrElse(throw new AssertionError("Group was not loaded into the cache")) + assertEquals(initiallyLoaded.groupId, group.groupId) + assertEquals(initiallyLoaded.currentState, group.currentState) + assertEquals(initiallyLoaded.leaderOrNull, group.leaderOrNull) + assertEquals(initiallyLoaded.generationId, group.generationId) + assertEquals(initiallyLoaded.protocolType, group.protocolType) + assertEquals(initiallyLoaded.protocolName.orNull, group.protocolName.orNull) + assertEquals(initiallyLoaded.allMembers, group.allMembers) + assertEquals(initiallyLoaded.allOffsets.size, group.allOffsets.size) + initiallyLoaded.allOffsets.foreach { case (topicPartition, offset) => + assertEquals(Some(offset), group.offset(topicPartition)) + assertTrue(group.offset(topicPartition).map(_.expireTimestamp).contains(None)) + } + } + + @Test + def testLoadGroupWithTombstone(): Unit = { + val groupMetadataTopicPartition = groupTopicPartition + val startOffset = 15L + val groupEpoch = 2 + val memberId = "98098230493" + val groupMetadataRecord = buildStableGroupRecordWithMember(generation = 15, + protocolType = "consumer", protocol = "range", memberId) + val groupMetadataTombstone = new SimpleRecord(GroupMetadataManager.groupMetadataKey(groupId), null) + val records = MemoryRecords.withRecords(startOffset, CompressionType.NONE, + Seq(groupMetadataRecord, groupMetadataTombstone).toArray: _*) + + expectGroupMetadataLoad(groupMetadataTopicPartition, startOffset, records) + + EasyMock.replay(replicaManager) + + groupMetadataManager.loadGroupsAndOffsets(groupMetadataTopicPartition, groupEpoch, _ => (), 0L) + + 
assertEquals(None, groupMetadataManager.getGroup(groupId)) + } + + @Test + def testLoadGroupWithLargeGroupMetadataRecord(): Unit = { + val groupMetadataTopicPartition = groupTopicPartition + val startOffset = 15L + val groupEpoch = 2 + val committedOffsets = Map( + new TopicPartition("foo", 0) -> 23L, + new TopicPartition("foo", 1) -> 455L, + new TopicPartition("bar", 0) -> 8992L + ) + + // create a GroupMetadata record larger than offsets.load.buffer.size (here at least 16 bytes larger) + val assignmentSize = OffsetConfig.DefaultLoadBufferSize + 16 + val memberId = "98098230493" + + val offsetCommitRecords = createCommittedOffsetRecords(committedOffsets) + val groupMetadataRecord = buildStableGroupRecordWithMember(generation = 15, + protocolType = "consumer", protocol = "range", memberId, new Array[Byte](assignmentSize)) + val records = MemoryRecords.withRecords(startOffset, CompressionType.NONE, + (offsetCommitRecords ++ Seq(groupMetadataRecord)).toArray: _*) + + expectGroupMetadataLoad(groupMetadataTopicPartition, startOffset, records) + + EasyMock.replay(replicaManager) + + groupMetadataManager.loadGroupsAndOffsets(groupMetadataTopicPartition, groupEpoch, _ => (), 0L) + + val group = groupMetadataManager.getGroup(groupId).getOrElse(throw new AssertionError("Group was not loaded into the cache")) + committedOffsets.foreach { case (topicPartition, offset) => + assertEquals(Some(offset), group.offset(topicPartition).map(_.offset)) + } + } + + @Test + def testLoadGroupAndOffsetsWithCorruptedLog(): Unit = { + // Simulate a case where startOffset < endOffset but log is empty. This could theoretically happen + // when all the records are expired and the active segment is truncated or when the partition + // is accidentally corrupted. + val startOffset = 0L + val endOffset = 10L + val groupEpoch = 2 + + val logMock: UnifiedLog = EasyMock.mock(classOf[UnifiedLog]) + EasyMock.expect(replicaManager.getLog(groupTopicPartition)).andStubReturn(Some(logMock)) + expectGroupMetadataLoad(logMock, startOffset, MemoryRecords.EMPTY) + EasyMock.expect(replicaManager.getLogEndOffset(groupTopicPartition)).andStubReturn(Some(endOffset)) + EasyMock.replay(logMock) + + EasyMock.replay(replicaManager) + + groupMetadataManager.loadGroupsAndOffsets(groupTopicPartition, groupEpoch, _ => (), 0L) + + EasyMock.verify(logMock) + EasyMock.verify(replicaManager) + + assertFalse(groupMetadataManager.isPartitionLoading(groupTopicPartition.partition())) + } + + @Test + def testOffsetWriteAfterGroupRemoved(): Unit = { + // this test case checks the following scenario: + // 1. the group exists at some point in time, but is later removed (because all members left) + // 2. a "simple" consumer (i.e. 
not a consumer group) then uses the same groupId to commit some offsets + + val groupMetadataTopicPartition = groupTopicPartition + val generation = 293 + val protocolType = "consumer" + val protocol = "range" + val startOffset = 15L + val groupEpoch = 2 + + val committedOffsets = Map( + new TopicPartition("foo", 0) -> 23L, + new TopicPartition("foo", 1) -> 455L, + new TopicPartition("bar", 0) -> 8992L + ) + val offsetCommitRecords = createCommittedOffsetRecords(committedOffsets) + val memberId = "98098230493" + val groupMetadataRecord = buildStableGroupRecordWithMember(generation, protocolType, protocol, memberId) + val groupMetadataTombstone = new SimpleRecord(GroupMetadataManager.groupMetadataKey(groupId), null) + val records = MemoryRecords.withRecords(startOffset, CompressionType.NONE, + (Seq(groupMetadataRecord, groupMetadataTombstone) ++ offsetCommitRecords).toArray: _*) + + expectGroupMetadataLoad(groupMetadataTopicPartition, startOffset, records) + + EasyMock.replay(replicaManager) + + groupMetadataManager.loadGroupsAndOffsets(groupMetadataTopicPartition, groupEpoch, _ => (), 0L) + + val group = groupMetadataManager.getGroup(groupId).getOrElse(throw new AssertionError("Group was not loaded into the cache")) + assertEquals(groupId, group.groupId) + assertEquals(Empty, group.currentState) + assertEquals(committedOffsets.size, group.allOffsets.size) + committedOffsets.foreach { case (topicPartition, offset) => + assertEquals(Some(offset), group.offset(topicPartition).map(_.offset)) + } + } + + @Test + def testLoadGroupAndOffsetsFromDifferentSegments(): Unit = { + val generation = 293 + val protocolType = "consumer" + val protocol = "range" + val startOffset = 15L + val groupEpoch = 2 + val tp0 = new TopicPartition("foo", 0) + val tp1 = new TopicPartition("foo", 1) + val tp2 = new TopicPartition("bar", 0) + val tp3 = new TopicPartition("xxx", 0) + + val logMock: UnifiedLog = EasyMock.mock(classOf[UnifiedLog]) + EasyMock.expect(replicaManager.getLog(groupTopicPartition)).andStubReturn(Some(logMock)) + + val segment1MemberId = "a" + val segment1Offsets = Map(tp0 -> 23L, tp1 -> 455L, tp3 -> 42L) + val segment1Records = MemoryRecords.withRecords(startOffset, CompressionType.NONE, + (createCommittedOffsetRecords(segment1Offsets) ++ Seq(buildStableGroupRecordWithMember( + generation, protocolType, protocol, segment1MemberId))).toArray: _*) + val segment1End = expectGroupMetadataLoad(logMock, startOffset, segment1Records) + + val segment2MemberId = "b" + val segment2Offsets = Map(tp0 -> 33L, tp2 -> 8992L, tp3 -> 10L) + val segment2Records = MemoryRecords.withRecords(segment1End, CompressionType.NONE, + (createCommittedOffsetRecords(segment2Offsets) ++ Seq(buildStableGroupRecordWithMember( + generation, protocolType, protocol, segment2MemberId))).toArray: _*) + val segment2End = expectGroupMetadataLoad(logMock, segment1End, segment2Records) + + EasyMock.expect(replicaManager.getLogEndOffset(groupTopicPartition)).andStubReturn(Some(segment2End)) + + EasyMock.replay(logMock, replicaManager) + + groupMetadataManager.loadGroupsAndOffsets(groupTopicPartition, groupEpoch, _ => (), 0L) + + val group = groupMetadataManager.getGroup(groupId).getOrElse(throw new AssertionError("Group was not loaded into the cache")) + assertEquals(groupId, group.groupId) + assertEquals(Stable, group.currentState) + + assertEquals(segment2MemberId, group.leaderOrNull, "segment2 group record member should be elected") + assertEquals(Set(segment2MemberId), group.allMembers, "segment2 group record member should be only 
member") + + // offsets of segment1 should be overridden by segment2 offsets of the same topic partitions + val committedOffsets = segment1Offsets ++ segment2Offsets + assertEquals(committedOffsets.size, group.allOffsets.size) + committedOffsets.foreach { case (topicPartition, offset) => + assertEquals(Some(offset), group.offset(topicPartition).map(_.offset)) + } + } + + @Test + def testAddGroup(): Unit = { + val group = new GroupMetadata("foo", Empty, time) + assertEquals(group, groupMetadataManager.addGroup(group)) + assertEquals(group, groupMetadataManager.addGroup(new GroupMetadata("foo", Empty, time))) + } + + @Test + def testloadGroupWithStaticMember(): Unit = { + val generation = 27 + val protocolType = "consumer" + val staticMemberId = "staticMemberId" + val dynamicMemberId = "dynamicMemberId" + + val staticMember = new MemberMetadata(staticMemberId, Some(groupInstanceId), "", "", rebalanceTimeout, sessionTimeout, + protocolType, List(("protocol", Array[Byte]()))) + + val dynamicMember = new MemberMetadata(dynamicMemberId, None, "", "", rebalanceTimeout, sessionTimeout, + protocolType, List(("protocol", Array[Byte]()))) + + val members = Seq(staticMember, dynamicMember) + + val group = GroupMetadata.loadGroup(groupId, Empty, generation, protocolType, null, null, None, members, time) + + assertTrue(group.is(Empty)) + assertEquals(generation, group.generationId) + assertEquals(Some(protocolType), group.protocolType) + assertTrue(group.has(staticMemberId)) + assertTrue(group.has(dynamicMemberId)) + assertTrue(group.hasStaticMember(groupInstanceId)) + assertEquals(Some(staticMemberId), group.currentStaticMemberId(groupInstanceId)) + } + + @Test + def testLoadConsumerGroup(): Unit = { + val generation = 27 + val protocolType = "consumer" + val protocol = "protocol" + val memberId = "member1" + val topic = "foo" + + val subscriptions = List( + ("protocol", ConsumerProtocol.serializeSubscription(new Subscription(List(topic).asJava)).array()) + ) + + val member = new MemberMetadata(memberId, Some(groupInstanceId), "", "", rebalanceTimeout, + sessionTimeout, protocolType, subscriptions) + + val members = Seq(member) + + val group = GroupMetadata.loadGroup(groupId, Stable, generation, protocolType, protocol, null, None, + members, time) + + assertTrue(group.is(Stable)) + assertEquals(generation, group.generationId) + assertEquals(Some(protocolType), group.protocolType) + assertEquals(protocol, group.protocolName.orNull) + assertEquals(Some(Set(topic)), group.getSubscribedTopics) + assertTrue(group.has(memberId)) + } + + @Test + def testLoadEmptyConsumerGroup(): Unit = { + val generation = 27 + val protocolType = "consumer" + + val group = GroupMetadata.loadGroup(groupId, Empty, generation, protocolType, null, null, None, + Seq(), time) + + assertTrue(group.is(Empty)) + assertEquals(generation, group.generationId) + assertEquals(Some(protocolType), group.protocolType) + assertNull(group.protocolName.orNull) + assertEquals(Some(Set.empty), group.getSubscribedTopics) + } + + @Test + def testLoadConsumerGroupWithFaultyConsumerProtocol(): Unit = { + val generation = 27 + val protocolType = "consumer" + val protocol = "protocol" + val memberId = "member1" + + val subscriptions = List(("protocol", Array[Byte]())) + + val member = new MemberMetadata(memberId, Some(groupInstanceId), "", "", rebalanceTimeout, + sessionTimeout, protocolType, subscriptions) + + val members = Seq(member) + + val group = GroupMetadata.loadGroup(groupId, Stable, generation, protocolType, protocol, null, None, + members, 
time) + + assertTrue(group.is(Stable)) + assertEquals(generation, group.generationId) + assertEquals(Some(protocolType), group.protocolType) + assertEquals(protocol, group.protocolName.orNull) + assertEquals(None, group.getSubscribedTopics) + assertTrue(group.has(memberId)) + } + + @Test + def testShouldThrowExceptionForUnsupportedGroupMetadataVersion(): Unit = { + val generation = 1 + val protocol = "range" + val memberId = "memberId" + val unsupportedVersion = Short.MinValue + + // put the unsupported version as the version value + val groupMetadataRecordValue = buildStableGroupRecordWithMember(generation, protocolType, protocol, memberId) + .value().putShort(unsupportedVersion) + // reset the position to the starting position 0 so that it can read the data in correct order + groupMetadataRecordValue.position(0) + + val e = assertThrows(classOf[IllegalStateException], + () => GroupMetadataManager.readGroupMessageValue(groupId, groupMetadataRecordValue, time)) + assertEquals(s"Unknown group metadata message version: $unsupportedVersion", e.getMessage) + } + + @Test + def testCurrentStateTimestampForAllGroupMetadataVersions(): Unit = { + val generation = 1 + val protocol = "range" + val memberId = "memberId" + + for (apiVersion <- ApiVersion.allVersions) { + val groupMetadataRecord = buildStableGroupRecordWithMember(generation, protocolType, protocol, memberId, apiVersion = apiVersion) + + val deserializedGroupMetadata = GroupMetadataManager.readGroupMessageValue(groupId, groupMetadataRecord.value(), time) + // GROUP_METADATA_VALUE_SCHEMA_V2 or higher should correctly set the currentStateTimestamp + if (apiVersion >= KAFKA_2_1_IV0) + assertEquals(Some(time.milliseconds()), deserializedGroupMetadata.currentStateTimestamp, + s"the apiVersion $apiVersion doesn't set the currentStateTimestamp correctly.") + else + assertTrue(deserializedGroupMetadata.currentStateTimestamp.isEmpty, + s"the apiVersion $apiVersion should not set the currentStateTimestamp.") + } + } + + @Test + def testReadFromOldGroupMetadata(): Unit = { + val generation = 1 + val protocol = "range" + val memberId = "memberId" + val oldApiVersions = Array(KAFKA_0_9_0, KAFKA_0_10_1_IV0, KAFKA_2_1_IV0) + + for (apiVersion <- oldApiVersions) { + val groupMetadataRecord = buildStableGroupRecordWithMember(generation, protocolType, protocol, memberId, apiVersion = apiVersion) + + val deserializedGroupMetadata = GroupMetadataManager.readGroupMessageValue(groupId, groupMetadataRecord.value(), time) + assertEquals(groupId, deserializedGroupMetadata.groupId) + assertEquals(generation, deserializedGroupMetadata.generationId) + assertEquals(protocolType, deserializedGroupMetadata.protocolType.get) + assertEquals(protocol, deserializedGroupMetadata.protocolName.orNull) + assertEquals(1, deserializedGroupMetadata.allMembers.size) + assertEquals(deserializedGroupMetadata.allMembers, deserializedGroupMetadata.allDynamicMembers) + assertTrue(deserializedGroupMetadata.allMembers.contains(memberId)) + assertTrue(deserializedGroupMetadata.allStaticMembers.isEmpty) + } + } + + @Test + def testStoreEmptyGroup(): Unit = { + val generation = 27 + val protocolType = "consumer" + + val group = GroupMetadata.loadGroup(groupId, Empty, generation, protocolType, null, null, None, Seq.empty, time) + groupMetadataManager.addGroup(group) + + val capturedRecords = expectAppendMessage(Errors.NONE) + EasyMock.replay(replicaManager) + + var maybeError: Option[Errors] = None + def callback(error: Errors): Unit = { + maybeError = Some(error) + } + + 
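// storing the empty group should append a single group metadata record reflecting the Empty state +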
groupMetadataManager.storeGroup(group, Map.empty, callback) + assertEquals(Some(Errors.NONE), maybeError) + assertTrue(capturedRecords.hasCaptured) + val records = capturedRecords.getValue()(new TopicPartition(Topic.GROUP_METADATA_TOPIC_NAME, groupPartitionId)) + .records.asScala.toList + assertEquals(1, records.size) + + val record = records.head + val groupMetadata = GroupMetadataManager.readGroupMessageValue(groupId, record.value, time) + assertTrue(groupMetadata.is(Empty)) + assertEquals(generation, groupMetadata.generationId) + assertEquals(Some(protocolType), groupMetadata.protocolType) + } + + @Test + def testStoreEmptySimpleGroup(): Unit = { + val group = new GroupMetadata(groupId, Empty, time) + groupMetadataManager.addGroup(group) + + val capturedRecords = expectAppendMessage(Errors.NONE) + EasyMock.replay(replicaManager) + + var maybeError: Option[Errors] = None + def callback(error: Errors): Unit = { + maybeError = Some(error) + } + + groupMetadataManager.storeGroup(group, Map.empty, callback) + assertEquals(Some(Errors.NONE), maybeError) + assertTrue(capturedRecords.hasCaptured) + + assertTrue(capturedRecords.hasCaptured) + val records = capturedRecords.getValue()(new TopicPartition(Topic.GROUP_METADATA_TOPIC_NAME, groupPartitionId)) + .records.asScala.toList + assertEquals(1, records.size) + + val record = records.head + val groupMetadata = GroupMetadataManager.readGroupMessageValue(groupId, record.value, time) + assertTrue(groupMetadata.is(Empty)) + assertEquals(0, groupMetadata.generationId) + assertEquals(None, groupMetadata.protocolType) + } + + @Test + def testStoreGroupErrorMapping(): Unit = { + assertStoreGroupErrorMapping(Errors.NONE, Errors.NONE) + assertStoreGroupErrorMapping(Errors.UNKNOWN_TOPIC_OR_PARTITION, Errors.COORDINATOR_NOT_AVAILABLE) + assertStoreGroupErrorMapping(Errors.NOT_ENOUGH_REPLICAS, Errors.COORDINATOR_NOT_AVAILABLE) + assertStoreGroupErrorMapping(Errors.NOT_ENOUGH_REPLICAS_AFTER_APPEND, Errors.COORDINATOR_NOT_AVAILABLE) + assertStoreGroupErrorMapping(Errors.NOT_LEADER_OR_FOLLOWER, Errors.NOT_COORDINATOR) + assertStoreGroupErrorMapping(Errors.MESSAGE_TOO_LARGE, Errors.UNKNOWN_SERVER_ERROR) + assertStoreGroupErrorMapping(Errors.RECORD_LIST_TOO_LARGE, Errors.UNKNOWN_SERVER_ERROR) + assertStoreGroupErrorMapping(Errors.INVALID_FETCH_SIZE, Errors.UNKNOWN_SERVER_ERROR) + assertStoreGroupErrorMapping(Errors.CORRUPT_MESSAGE, Errors.CORRUPT_MESSAGE) + } + + private def assertStoreGroupErrorMapping(appendError: Errors, expectedError: Errors): Unit = { + EasyMock.reset(replicaManager) + + val group = new GroupMetadata(groupId, Empty, time) + groupMetadataManager.addGroup(group) + + expectAppendMessage(appendError) + EasyMock.replay(replicaManager) + + var maybeError: Option[Errors] = None + def callback(error: Errors): Unit = { + maybeError = Some(error) + } + + groupMetadataManager.storeGroup(group, Map.empty, callback) + assertEquals(Some(expectedError), maybeError) + + EasyMock.verify(replicaManager) + } + + @Test + def testStoreNonEmptyGroup(): Unit = { + val memberId = "memberId" + val clientId = "clientId" + val clientHost = "localhost" + + val group = new GroupMetadata(groupId, Empty, time) + groupMetadataManager.addGroup(group) + + val member = new MemberMetadata(memberId, Some(groupInstanceId), clientId, clientHost, rebalanceTimeout, sessionTimeout, + protocolType, List(("protocol", Array[Byte]()))) + group.add(member, _ => ()) + group.transitionTo(PreparingRebalance) + group.initNextGeneration() + + expectAppendMessage(Errors.NONE) + 
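// the mocked append succeeds, so the store callback should report Errors.NONE +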
EasyMock.replay(replicaManager) + + var maybeError: Option[Errors] = None + def callback(error: Errors): Unit = { + maybeError = Some(error) + } + + groupMetadataManager.storeGroup(group, Map(memberId -> Array[Byte]()), callback) + assertEquals(Some(Errors.NONE), maybeError) + + EasyMock.verify(replicaManager) + } + + @Test + def testStoreNonEmptyGroupWhenCoordinatorHasMoved(): Unit = { + EasyMock.expect(replicaManager.getMagic(EasyMock.anyObject())).andReturn(None) + val memberId = "memberId" + val clientId = "clientId" + val clientHost = "localhost" + + val group = new GroupMetadata(groupId, Empty, time) + + val member = new MemberMetadata(memberId, Some(groupInstanceId), clientId, clientHost, rebalanceTimeout, sessionTimeout, + protocolType, List(("protocol", Array[Byte]()))) + group.add(member, _ => ()) + group.transitionTo(PreparingRebalance) + group.initNextGeneration() + + EasyMock.replay(replicaManager) + + var maybeError: Option[Errors] = None + def callback(error: Errors): Unit = { + maybeError = Some(error) + } + + groupMetadataManager.storeGroup(group, Map(memberId -> Array[Byte]()), callback) + assertEquals(Some(Errors.NOT_COORDINATOR), maybeError) + + EasyMock.verify(replicaManager) + } + + @Test + def testCommitOffset(): Unit = { + val memberId = "" + val topicPartition = new TopicPartition("foo", 0) + val offset = 37 + + groupMetadataManager.addPartitionOwnership(groupPartitionId) + + val group = new GroupMetadata(groupId, Empty, time) + groupMetadataManager.addGroup(group) + + val offsets = immutable.Map(topicPartition -> OffsetAndMetadata(offset, "", time.milliseconds())) + + expectAppendMessage(Errors.NONE) + EasyMock.replay(replicaManager) + + var commitErrors: Option[immutable.Map[TopicPartition, Errors]] = None + def callback(errors: immutable.Map[TopicPartition, Errors]): Unit = { + commitErrors = Some(errors) + } + + assertEquals(0, TestUtils.totalMetricValue(metrics, "offset-commit-count")) + groupMetadataManager.storeOffsets(group, memberId, offsets, callback) + assertTrue(group.hasOffsets) + + assertFalse(commitErrors.isEmpty) + val maybeError = commitErrors.get.get(topicPartition) + assertEquals(Some(Errors.NONE), maybeError) + assertTrue(group.hasOffsets) + + val cachedOffsets = groupMetadataManager.getOffsets(groupId, defaultRequireStable, Some(Seq(topicPartition))) + val maybePartitionResponse = cachedOffsets.get(topicPartition) + assertFalse(maybePartitionResponse.isEmpty) + + val partitionResponse = maybePartitionResponse.get + assertEquals(Errors.NONE, partitionResponse.error) + assertEquals(offset, partitionResponse.offset) + + EasyMock.verify(replicaManager) + // Will update sensor after commit + assertEquals(1, TestUtils.totalMetricValue(metrics, "offset-commit-count")) + } + + @Test + def testTransactionalCommitOffsetCommitted(): Unit = { + val memberId = "" + val topicPartition = new TopicPartition("foo", 0) + val offset = 37 + val producerId = 232L + val producerEpoch = 0.toShort + + groupMetadataManager.addPartitionOwnership(groupPartitionId) + + val group = new GroupMetadata(groupId, Empty, time) + groupMetadataManager.addGroup(group) + + val offsetAndMetadata = OffsetAndMetadata(offset, "", time.milliseconds()) + val offsets = immutable.Map(topicPartition -> offsetAndMetadata) + + val capturedResponseCallback = appendAndCaptureCallback() + EasyMock.replay(replicaManager) + + var commitErrors: Option[immutable.Map[TopicPartition, Errors]] = None + def callback(errors: immutable.Map[TopicPartition, Errors]): Unit = { + commitErrors = Some(errors) + 
} + + groupMetadataManager.storeOffsets(group, memberId, offsets, callback, producerId, producerEpoch) + assertTrue(group.hasOffsets) + assertTrue(group.allOffsets.isEmpty) + capturedResponseCallback.getValue.apply(Map(groupTopicPartition -> + new PartitionResponse(Errors.NONE, 0L, RecordBatch.NO_TIMESTAMP, 0L))) + + assertTrue(group.hasOffsets) + assertTrue(group.allOffsets.isEmpty) + + group.completePendingTxnOffsetCommit(producerId, isCommit = true) + assertTrue(group.hasOffsets) + assertFalse(group.allOffsets.isEmpty) + assertEquals(Some(offsetAndMetadata), group.offset(topicPartition)) + + EasyMock.verify(replicaManager) + } + + @Test + def testTransactionalCommitOffsetAppendFailure(): Unit = { + val memberId = "" + val topicPartition = new TopicPartition("foo", 0) + val offset = 37 + val producerId = 232L + val producerEpoch = 0.toShort + + groupMetadataManager.addPartitionOwnership(groupPartitionId) + + val group = new GroupMetadata(groupId, Empty, time) + groupMetadataManager.addGroup(group) + + val offsets = immutable.Map(topicPartition -> OffsetAndMetadata(offset, "", time.milliseconds())) + + val capturedResponseCallback = appendAndCaptureCallback() + EasyMock.replay(replicaManager) + + var commitErrors: Option[immutable.Map[TopicPartition, Errors]] = None + def callback(errors: immutable.Map[TopicPartition, Errors]): Unit = { + commitErrors = Some(errors) + } + + groupMetadataManager.storeOffsets(group, memberId, offsets, callback, producerId, producerEpoch) + assertTrue(group.hasOffsets) + assertTrue(group.allOffsets.isEmpty) + capturedResponseCallback.getValue.apply(Map(groupTopicPartition -> + new PartitionResponse(Errors.NOT_ENOUGH_REPLICAS, 0L, RecordBatch.NO_TIMESTAMP, 0L))) + + assertFalse(group.hasOffsets) + assertTrue(group.allOffsets.isEmpty) + + group.completePendingTxnOffsetCommit(producerId, isCommit = false) + assertFalse(group.hasOffsets) + assertTrue(group.allOffsets.isEmpty) + + EasyMock.verify(replicaManager) + } + + @Test + def testTransactionalCommitOffsetAborted(): Unit = { + val memberId = "" + val topicPartition = new TopicPartition("foo", 0) + val offset = 37 + val producerId = 232L + val producerEpoch = 0.toShort + + groupMetadataManager.addPartitionOwnership(groupPartitionId) + + val group = new GroupMetadata(groupId, Empty, time) + groupMetadataManager.addGroup(group) + + val offsets = immutable.Map(topicPartition -> OffsetAndMetadata(offset, "", time.milliseconds())) + + val capturedResponseCallback = appendAndCaptureCallback() + EasyMock.replay(replicaManager) + + var commitErrors: Option[immutable.Map[TopicPartition, Errors]] = None + def callback(errors: immutable.Map[TopicPartition, Errors]): Unit = { + commitErrors = Some(errors) + } + + groupMetadataManager.storeOffsets(group, memberId, offsets, callback, producerId, producerEpoch) + assertTrue(group.hasOffsets) + assertTrue(group.allOffsets.isEmpty) + capturedResponseCallback.getValue.apply(Map(groupTopicPartition -> + new PartitionResponse(Errors.NONE, 0L, RecordBatch.NO_TIMESTAMP, 0L))) + + assertTrue(group.hasOffsets) + assertTrue(group.allOffsets.isEmpty) + + group.completePendingTxnOffsetCommit(producerId, isCommit = false) + assertFalse(group.hasOffsets) + assertTrue(group.allOffsets.isEmpty) + + EasyMock.verify(replicaManager) + } + + @Test + def testCommitOffsetWhenCoordinatorHasMoved(): Unit = { + EasyMock.expect(replicaManager.getMagic(EasyMock.anyObject())).andReturn(None) + val memberId = "" + val topicPartition = new TopicPartition("foo", 0) + val offset = 37 + + 
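// getMagic returns None above, so the commit should be rejected with NOT_COORDINATOR +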
groupMetadataManager.addPartitionOwnership(groupPartitionId) + + val group = new GroupMetadata(groupId, Empty, time) + groupMetadataManager.addGroup(group) + + val offsets = immutable.Map(topicPartition -> OffsetAndMetadata(offset, "", time.milliseconds())) + + EasyMock.replay(replicaManager) + + var commitErrors: Option[immutable.Map[TopicPartition, Errors]] = None + def callback(errors: immutable.Map[TopicPartition, Errors]): Unit = { + commitErrors = Some(errors) + } + + groupMetadataManager.storeOffsets(group, memberId, offsets, callback) + + assertFalse(commitErrors.isEmpty) + val maybeError = commitErrors.get.get(topicPartition) + assertEquals(Some(Errors.NOT_COORDINATOR), maybeError) + + EasyMock.verify(replicaManager) + } + + @Test + def testCommitOffsetFailure(): Unit = { + assertCommitOffsetErrorMapping(Errors.UNKNOWN_TOPIC_OR_PARTITION, Errors.COORDINATOR_NOT_AVAILABLE) + assertCommitOffsetErrorMapping(Errors.NOT_ENOUGH_REPLICAS, Errors.COORDINATOR_NOT_AVAILABLE) + assertCommitOffsetErrorMapping(Errors.NOT_ENOUGH_REPLICAS_AFTER_APPEND, Errors.COORDINATOR_NOT_AVAILABLE) + assertCommitOffsetErrorMapping(Errors.NOT_LEADER_OR_FOLLOWER, Errors.NOT_COORDINATOR) + assertCommitOffsetErrorMapping(Errors.MESSAGE_TOO_LARGE, Errors.INVALID_COMMIT_OFFSET_SIZE) + assertCommitOffsetErrorMapping(Errors.RECORD_LIST_TOO_LARGE, Errors.INVALID_COMMIT_OFFSET_SIZE) + assertCommitOffsetErrorMapping(Errors.INVALID_FETCH_SIZE, Errors.INVALID_COMMIT_OFFSET_SIZE) + assertCommitOffsetErrorMapping(Errors.CORRUPT_MESSAGE, Errors.CORRUPT_MESSAGE) + } + + private def assertCommitOffsetErrorMapping(appendError: Errors, expectedError: Errors): Unit = { + EasyMock.reset(replicaManager) + + val memberId = "" + val topicPartition = new TopicPartition("foo", 0) + val offset = 37 + + groupMetadataManager.addPartitionOwnership(groupPartitionId) + + val group = new GroupMetadata(groupId, Empty, time) + groupMetadataManager.addGroup(group) + + val offsets = immutable.Map(topicPartition -> OffsetAndMetadata(offset, "", time.milliseconds())) + + val capturedResponseCallback = appendAndCaptureCallback() + EasyMock.replay(replicaManager) + + var commitErrors: Option[immutable.Map[TopicPartition, Errors]] = None + def callback(errors: immutable.Map[TopicPartition, Errors]): Unit = { + commitErrors = Some(errors) + } + + assertEquals(0, TestUtils.totalMetricValue(metrics, "offset-commit-count")) + groupMetadataManager.storeOffsets(group, memberId, offsets, callback) + assertTrue(group.hasOffsets) + capturedResponseCallback.getValue.apply(Map(groupTopicPartition -> + new PartitionResponse(appendError, 0L, RecordBatch.NO_TIMESTAMP, 0L))) + + assertFalse(commitErrors.isEmpty) + val maybeError = commitErrors.get.get(topicPartition) + assertEquals(Some(expectedError), maybeError) + assertFalse(group.hasOffsets) + + val cachedOffsets = groupMetadataManager.getOffsets(groupId, defaultRequireStable, Some(Seq(topicPartition))) + assertEquals(Some(OffsetFetchResponse.INVALID_OFFSET), cachedOffsets.get(topicPartition).map(_.offset)) + + EasyMock.verify(replicaManager) + // Will not update sensor if failed + assertEquals(0, TestUtils.totalMetricValue(metrics, "offset-commit-count")) + } + + @Test + def testCommitOffsetPartialFailure(): Unit = { + EasyMock.reset(replicaManager) + + val memberId = "" + val topicPartition = new TopicPartition("foo", 0) + val topicPartitionFailed = new TopicPartition("foo", 1) + val offset = 37 + + groupMetadataManager.addPartitionOwnership(groupPartitionId) + + val group = new GroupMetadata(groupId, Empty, 
time) + groupMetadataManager.addGroup(group) + + val offsets = immutable.Map( + topicPartition -> OffsetAndMetadata(offset, "", time.milliseconds()), + // This one will fail because its metadata is too large + topicPartitionFailed -> OffsetAndMetadata(offset, "s" * (offsetConfig.maxMetadataSize + 1), time.milliseconds()) + ) + + val capturedResponseCallback = appendAndCaptureCallback() + EasyMock.replay(replicaManager) + + var commitErrors: Option[immutable.Map[TopicPartition, Errors]] = None + def callback(errors: immutable.Map[TopicPartition, Errors]): Unit = { + commitErrors = Some(errors) + } + + assertEquals(0, TestUtils.totalMetricValue(metrics, "offset-commit-count")) + groupMetadataManager.storeOffsets(group, memberId, offsets, callback) + assertTrue(group.hasOffsets) + capturedResponseCallback.getValue.apply(Map(groupTopicPartition -> + new PartitionResponse(Errors.NONE, 0L, RecordBatch.NO_TIMESTAMP, 0L))) + + assertFalse(commitErrors.isEmpty) + assertEquals(Some(Errors.NONE), commitErrors.get.get(topicPartition)) + assertEquals(Some(Errors.OFFSET_METADATA_TOO_LARGE), commitErrors.get.get(topicPartitionFailed)) + assertTrue(group.hasOffsets) + + val cachedOffsets = groupMetadataManager.getOffsets(groupId, defaultRequireStable, Some(Seq(topicPartition, topicPartitionFailed))) + assertEquals(Some(offset), cachedOffsets.get(topicPartition).map(_.offset)) + assertEquals(Some(OffsetFetchResponse.INVALID_OFFSET), cachedOffsets.get(topicPartitionFailed).map(_.offset)) + + EasyMock.verify(replicaManager) + assertEquals(1, TestUtils.totalMetricValue(metrics, "offset-commit-count")) + } + + @Test + def testOffsetMetadataTooLarge(): Unit = { + val memberId = "" + val topicPartition = new TopicPartition("foo", 0) + val offset = 37 + + groupMetadataManager.addPartitionOwnership(groupPartitionId) + val group = new GroupMetadata(groupId, Empty, time) + groupMetadataManager.addGroup(group) + + val offsets = immutable.Map( + topicPartition -> OffsetAndMetadata(offset, "s" * (offsetConfig.maxMetadataSize + 1), time.milliseconds()) + ) + + var commitErrors: Option[immutable.Map[TopicPartition, Errors]] = None + def callback(errors: immutable.Map[TopicPartition, Errors]): Unit = { + commitErrors = Some(errors) + } + + assertEquals(0, TestUtils.totalMetricValue(metrics, "offset-commit-count")) + groupMetadataManager.storeOffsets(group, memberId, offsets, callback) + assertFalse(group.hasOffsets) + + assertFalse(commitErrors.isEmpty) + val maybeError = commitErrors.get.get(topicPartition) + assertEquals(Some(Errors.OFFSET_METADATA_TOO_LARGE), maybeError) + assertFalse(group.hasOffsets) + + val cachedOffsets = groupMetadataManager.getOffsets(groupId, defaultRequireStable, Some(Seq(topicPartition))) + assertEquals(Some(OffsetFetchResponse.INVALID_OFFSET), cachedOffsets.get(topicPartition).map(_.offset)) + + assertEquals(0, TestUtils.totalMetricValue(metrics, "offset-commit-count")) + } + + @Test + def testExpireOffset(): Unit = { + val memberId = "" + val topicPartition1 = new TopicPartition("foo", 0) + val topicPartition2 = new TopicPartition("foo", 1) + val offset = 37 + + groupMetadataManager.addPartitionOwnership(groupPartitionId) + + val group = new GroupMetadata(groupId, Empty, time) + groupMetadataManager.addGroup(group) + + // expire the offsets after 1 and 3 milliseconds + val startMs = time.milliseconds + val offsets = immutable.Map( + topicPartition1 -> OffsetAndMetadata(offset, "", startMs, startMs + 1), + topicPartition2 -> OffsetAndMetadata(offset, "", startMs, startMs + 3)) + + mockGetPartition() + expectAppendMessage(Errors.NONE) 
+ EasyMock.replay(replicaManager) + + var commitErrors: Option[immutable.Map[TopicPartition, Errors]] = None + def callback(errors: immutable.Map[TopicPartition, Errors]): Unit = { + commitErrors = Some(errors) + } + + groupMetadataManager.storeOffsets(group, memberId, offsets, callback) + assertTrue(group.hasOffsets) + + assertFalse(commitErrors.isEmpty) + assertEquals(Some(Errors.NONE), commitErrors.get.get(topicPartition1)) + + // expire only one of the offsets + time.sleep(2) + + EasyMock.reset(partition) + EasyMock.expect(partition.appendRecordsToLeader(EasyMock.anyObject(classOf[MemoryRecords]), + origin = EasyMock.eq(AppendOrigin.Coordinator), requiredAcks = EasyMock.anyInt(), + EasyMock.anyObject())).andReturn(LogAppendInfo.UnknownLogAppendInfo) + EasyMock.replay(partition) + + groupMetadataManager.cleanupGroupMetadata() + + assertEquals(Some(group), groupMetadataManager.getGroup(groupId)) + assertEquals(None, group.offset(topicPartition1)) + assertEquals(Some(offset), group.offset(topicPartition2).map(_.offset)) + + val cachedOffsets = groupMetadataManager.getOffsets(groupId, defaultRequireStable, Some(Seq(topicPartition1, topicPartition2))) + assertEquals(Some(OffsetFetchResponse.INVALID_OFFSET), cachedOffsets.get(topicPartition1).map(_.offset)) + assertEquals(Some(offset), cachedOffsets.get(topicPartition2).map(_.offset)) + + EasyMock.verify(replicaManager) + } + + @Test + def testGroupMetadataRemoval(): Unit = { + val topicPartition1 = new TopicPartition("foo", 0) + val topicPartition2 = new TopicPartition("foo", 1) + + groupMetadataManager.addPartitionOwnership(groupPartitionId) + + val group = new GroupMetadata(groupId, Empty, time) + groupMetadataManager.addGroup(group) + group.generationId = 5 + + // expect the group metadata tombstone + EasyMock.reset(partition) + val recordsCapture: Capture[MemoryRecords] = EasyMock.newCapture() + + EasyMock.expect(replicaManager.getMagic(EasyMock.anyObject())).andStubReturn(Some(RecordBatch.CURRENT_MAGIC_VALUE)) + mockGetPartition() + EasyMock.expect(partition.appendRecordsToLeader(EasyMock.capture(recordsCapture), + origin = EasyMock.eq(AppendOrigin.Coordinator), requiredAcks = EasyMock.anyInt(), + EasyMock.anyObject())).andReturn(LogAppendInfo.UnknownLogAppendInfo) + EasyMock.replay(replicaManager, partition) + + groupMetadataManager.cleanupGroupMetadata() + + assertTrue(recordsCapture.hasCaptured) + + val records = recordsCapture.getValue.records.asScala.toList + recordsCapture.getValue.batches.forEach { batch => + assertEquals(RecordBatch.CURRENT_MAGIC_VALUE, batch.magic) + assertEquals(TimestampType.CREATE_TIME, batch.timestampType) + } + assertEquals(1, records.size) + + val metadataTombstone = records.head + assertTrue(metadataTombstone.hasKey) + assertFalse(metadataTombstone.hasValue) + assertTrue(metadataTombstone.timestamp > 0) + + val groupKey = GroupMetadataManager.readMessageKey(metadataTombstone.key).asInstanceOf[GroupMetadataKey] + assertEquals(groupId, groupKey.key) + + // the full group should be gone since all offsets were removed + assertEquals(None, groupMetadataManager.getGroup(groupId)) + val cachedOffsets = groupMetadataManager.getOffsets(groupId, defaultRequireStable, Some(Seq(topicPartition1, topicPartition2))) + assertEquals(Some(OffsetFetchResponse.INVALID_OFFSET), cachedOffsets.get(topicPartition1).map(_.offset)) + assertEquals(Some(OffsetFetchResponse.INVALID_OFFSET), cachedOffsets.get(topicPartition2).map(_.offset)) + } + + @Test + def testGroupMetadataRemovalWithLogAppendTime(): Unit = { + val 
topicPartition1 = new TopicPartition("foo", 0) + val topicPartition2 = new TopicPartition("foo", 1) + + groupMetadataManager.addPartitionOwnership(groupPartitionId) + + val group = new GroupMetadata(groupId, Empty, time) + groupMetadataManager.addGroup(group) + group.generationId = 5 + + // expect the group metadata tombstone + EasyMock.reset(partition) + val recordsCapture: Capture[MemoryRecords] = EasyMock.newCapture() + + EasyMock.expect(replicaManager.getMagic(EasyMock.anyObject())).andStubReturn(Some(RecordBatch.CURRENT_MAGIC_VALUE)) + mockGetPartition() + EasyMock.expect(partition.appendRecordsToLeader(EasyMock.capture(recordsCapture), + origin = EasyMock.eq(AppendOrigin.Coordinator), requiredAcks = EasyMock.anyInt(), + EasyMock.anyObject())).andReturn(LogAppendInfo.UnknownLogAppendInfo) + EasyMock.replay(replicaManager, partition) + + groupMetadataManager.cleanupGroupMetadata() + + assertTrue(recordsCapture.hasCaptured) + + val records = recordsCapture.getValue.records.asScala.toList + recordsCapture.getValue.batches.forEach { batch => + assertEquals(RecordBatch.CURRENT_MAGIC_VALUE, batch.magic) + // Use CREATE_TIME, like the producer. The conversion to LOG_APPEND_TIME (if necessary) happens automatically. + assertEquals(TimestampType.CREATE_TIME, batch.timestampType) + } + assertEquals(1, records.size) + + val metadataTombstone = records.head + assertTrue(metadataTombstone.hasKey) + assertFalse(metadataTombstone.hasValue) + assertTrue(metadataTombstone.timestamp > 0) + + val groupKey = GroupMetadataManager.readMessageKey(metadataTombstone.key).asInstanceOf[GroupMetadataKey] + assertEquals(groupId, groupKey.key) + + // the full group should be gone since all offsets were removed + assertEquals(None, groupMetadataManager.getGroup(groupId)) + val cachedOffsets = groupMetadataManager.getOffsets(groupId, defaultRequireStable, Some(Seq(topicPartition1, topicPartition2))) + assertEquals(Some(OffsetFetchResponse.INVALID_OFFSET), cachedOffsets.get(topicPartition1).map(_.offset)) + assertEquals(Some(OffsetFetchResponse.INVALID_OFFSET), cachedOffsets.get(topicPartition2).map(_.offset)) + } + + @Test + def testExpireGroupWithOffsetsOnly(): Unit = { + // verify that the group is removed properly, but no tombstone is written if + // this is a group which is only using kafka for offset storage + + val memberId = "" + val topicPartition1 = new TopicPartition("foo", 0) + val topicPartition2 = new TopicPartition("foo", 1) + val offset = 37 + + groupMetadataManager.addPartitionOwnership(groupPartitionId) + + val group = new GroupMetadata(groupId, Empty, time) + groupMetadataManager.addGroup(group) + + // expire the offset after 1 millisecond + val startMs = time.milliseconds + val offsets = immutable.Map( + topicPartition1 -> OffsetAndMetadata(offset, Optional.empty(), "", startMs, Some(startMs + 1)), + topicPartition2 -> OffsetAndMetadata(offset, "", startMs, startMs + 3)) + + mockGetPartition() + expectAppendMessage(Errors.NONE) + EasyMock.replay(replicaManager) + + var commitErrors: Option[immutable.Map[TopicPartition, Errors]] = None + def callback(errors: immutable.Map[TopicPartition, Errors]): Unit = { + commitErrors = Some(errors) + } + + groupMetadataManager.storeOffsets(group, memberId, offsets, callback) + assertTrue(group.hasOffsets) + + assertFalse(commitErrors.isEmpty) + assertEquals(Some(Errors.NONE), commitErrors.get.get(topicPartition1)) + + // expire all of the offsets + time.sleep(4) + + // expect the offset tombstone + EasyMock.reset(partition) + val recordsCapture: 
Capture[MemoryRecords] = EasyMock.newCapture() + + EasyMock.expect(partition.appendRecordsToLeader(EasyMock.capture(recordsCapture), + origin = EasyMock.eq(AppendOrigin.Coordinator), requiredAcks = EasyMock.anyInt(), + EasyMock.anyObject())).andReturn(LogAppendInfo.UnknownLogAppendInfo) + EasyMock.replay(partition) + + groupMetadataManager.cleanupGroupMetadata() + + assertTrue(recordsCapture.hasCaptured) + + // verify the tombstones are correct and only for the expired offsets + val records = recordsCapture.getValue.records.asScala.toList + assertEquals(2, records.size) + records.foreach { message => + assertTrue(message.hasKey) + assertFalse(message.hasValue) + val offsetKey = GroupMetadataManager.readMessageKey(message.key).asInstanceOf[OffsetKey] + assertEquals(groupId, offsetKey.key.group) + assertEquals("foo", offsetKey.key.topicPartition.topic) + } + + // the full group should be gone since all offsets were removed + assertEquals(None, groupMetadataManager.getGroup(groupId)) + val cachedOffsets = groupMetadataManager.getOffsets(groupId, defaultRequireStable, Some(Seq(topicPartition1, topicPartition2))) + assertEquals(Some(OffsetFetchResponse.INVALID_OFFSET), cachedOffsets.get(topicPartition1).map(_.offset)) + assertEquals(Some(OffsetFetchResponse.INVALID_OFFSET), cachedOffsets.get(topicPartition2).map(_.offset)) + + EasyMock.verify(replicaManager) + } + + @Test + def testOffsetExpirationSemantics(): Unit = { + val memberId = "memberId" + val clientId = "clientId" + val clientHost = "localhost" + val topic = "foo" + val topicPartition1 = new TopicPartition(topic, 0) + val topicPartition2 = new TopicPartition(topic, 1) + val topicPartition3 = new TopicPartition(topic, 2) + val offset = 37 + + groupMetadataManager.addPartitionOwnership(groupPartitionId) + + val group = new GroupMetadata(groupId, Empty, time) + groupMetadataManager.addGroup(group) + + val subscription = new Subscription(List(topic).asJava) + val member = new MemberMetadata(memberId, Some(groupInstanceId), clientId, clientHost, rebalanceTimeout, sessionTimeout, + protocolType, List(("protocol", ConsumerProtocol.serializeSubscription(subscription).array()))) + group.add(member, _ => ()) + group.transitionTo(PreparingRebalance) + group.initNextGeneration() + + val startMs = time.milliseconds + // old clients, expiry timestamp is explicitly set + val tp1OffsetAndMetadata = OffsetAndMetadata(offset, "", startMs, startMs + 1) + val tp2OffsetAndMetadata = OffsetAndMetadata(offset, "", startMs, startMs + 3) + // new clients, no per-partition expiry timestamp, offsets of group expire together + val tp3OffsetAndMetadata = OffsetAndMetadata(offset, "", startMs) + val offsets = immutable.Map( + topicPartition1 -> tp1OffsetAndMetadata, + topicPartition2 -> tp2OffsetAndMetadata, + topicPartition3 -> tp3OffsetAndMetadata) + + mockGetPartition() + expectAppendMessage(Errors.NONE) + EasyMock.replay(replicaManager) + + var commitErrors: Option[immutable.Map[TopicPartition, Errors]] = None + def callback(errors: immutable.Map[TopicPartition, Errors]): Unit = { + commitErrors = Some(errors) + } + + groupMetadataManager.storeOffsets(group, memberId, offsets, callback) + assertTrue(group.hasOffsets) + + assertFalse(commitErrors.isEmpty) + assertEquals(Some(Errors.NONE), commitErrors.get.get(topicPartition1)) + + // do not expire any offset even though expiration timestamp is reached for one (due to group still being active) + time.sleep(2) + + groupMetadataManager.cleanupGroupMetadata() + + // group and offsets should still be there + 
assertEquals(Some(group), groupMetadataManager.getGroup(groupId)) + assertEquals(Some(tp1OffsetAndMetadata), group.offset(topicPartition1)) + assertEquals(Some(tp2OffsetAndMetadata), group.offset(topicPartition2)) + assertEquals(Some(tp3OffsetAndMetadata), group.offset(topicPartition3)) + + var cachedOffsets = groupMetadataManager.getOffsets(groupId, defaultRequireStable, Some(Seq(topicPartition1, topicPartition2, topicPartition3))) + assertEquals(Some(offset), cachedOffsets.get(topicPartition1).map(_.offset)) + assertEquals(Some(offset), cachedOffsets.get(topicPartition2).map(_.offset)) + assertEquals(Some(offset), cachedOffsets.get(topicPartition3).map(_.offset)) + + EasyMock.verify(replicaManager) + + group.transitionTo(PreparingRebalance) + group.transitionTo(Empty) + + // expect the offset tombstone + EasyMock.reset(partition) + EasyMock.expect(partition.appendRecordsToLeader(EasyMock.anyObject(classOf[MemoryRecords]), + origin = EasyMock.eq(AppendOrigin.Coordinator), requiredAcks = EasyMock.anyInt(), + EasyMock.anyObject())).andReturn(LogAppendInfo.UnknownLogAppendInfo) + EasyMock.replay(partition) + + groupMetadataManager.cleanupGroupMetadata() + + // group is empty now, only one offset should expire + assertEquals(Some(group), groupMetadataManager.getGroup(groupId)) + assertEquals(None, group.offset(topicPartition1)) + assertEquals(Some(tp2OffsetAndMetadata), group.offset(topicPartition2)) + assertEquals(Some(tp3OffsetAndMetadata), group.offset(topicPartition3)) + + cachedOffsets = groupMetadataManager.getOffsets(groupId, defaultRequireStable, Some(Seq(topicPartition1, topicPartition2, topicPartition3))) + assertEquals(Some(OffsetFetchResponse.INVALID_OFFSET), cachedOffsets.get(topicPartition1).map(_.offset)) + assertEquals(Some(offset), cachedOffsets.get(topicPartition2).map(_.offset)) + assertEquals(Some(offset), cachedOffsets.get(topicPartition3).map(_.offset)) + + EasyMock.verify(replicaManager) + + time.sleep(2) + + // expect the offset tombstone + EasyMock.reset(partition) + EasyMock.expect(partition.appendRecordsToLeader(EasyMock.anyObject(classOf[MemoryRecords]), + origin = EasyMock.eq(AppendOrigin.Coordinator), requiredAcks = EasyMock.anyInt(), + EasyMock.anyObject())).andReturn(LogAppendInfo.UnknownLogAppendInfo) + EasyMock.replay(partition) + + groupMetadataManager.cleanupGroupMetadata() + + // one more offset should expire + assertEquals(Some(group), groupMetadataManager.getGroup(groupId)) + assertEquals(None, group.offset(topicPartition1)) + assertEquals(None, group.offset(topicPartition2)) + assertEquals(Some(tp3OffsetAndMetadata), group.offset(topicPartition3)) + + cachedOffsets = groupMetadataManager.getOffsets(groupId, defaultRequireStable, Some(Seq(topicPartition1, topicPartition2, topicPartition3))) + assertEquals(Some(OffsetFetchResponse.INVALID_OFFSET), cachedOffsets.get(topicPartition1).map(_.offset)) + assertEquals(Some(OffsetFetchResponse.INVALID_OFFSET), cachedOffsets.get(topicPartition2).map(_.offset)) + assertEquals(Some(offset), cachedOffsets.get(topicPartition3).map(_.offset)) + + EasyMock.verify(replicaManager) + + // advance time to just before the offset of last partition is to be expired, no offset should expire + time.sleep(group.currentStateTimestamp.get + defaultOffsetRetentionMs - time.milliseconds() - 1) + + groupMetadataManager.cleanupGroupMetadata() + + // no offset should be expired yet + assertEquals(Some(group), groupMetadataManager.getGroup(groupId)) + assertEquals(None, group.offset(topicPartition1)) + assertEquals(None, 
group.offset(topicPartition2)) + assertEquals(Some(tp3OffsetAndMetadata), group.offset(topicPartition3)) + + cachedOffsets = groupMetadataManager.getOffsets(groupId, defaultRequireStable, Some(Seq(topicPartition1, topicPartition2, topicPartition3))) + assertEquals(Some(OffsetFetchResponse.INVALID_OFFSET), cachedOffsets.get(topicPartition1).map(_.offset)) + assertEquals(Some(OffsetFetchResponse.INVALID_OFFSET), cachedOffsets.get(topicPartition2).map(_.offset)) + assertEquals(Some(offset), cachedOffsets.get(topicPartition3).map(_.offset)) + + EasyMock.verify(replicaManager) + + // advance time enough for that last offset to expire + time.sleep(2) + + // expect the offset tombstone + EasyMock.reset(partition) + EasyMock.expect(partition.appendRecordsToLeader(EasyMock.anyObject(classOf[MemoryRecords]), + origin = EasyMock.eq(AppendOrigin.Coordinator), requiredAcks = EasyMock.anyInt(), + EasyMock.anyObject())).andReturn(LogAppendInfo.UnknownLogAppendInfo) + EasyMock.replay(partition) + + groupMetadataManager.cleanupGroupMetadata() + + // group and all its offsets should be gone now + assertEquals(None, groupMetadataManager.getGroup(groupId)) + assertEquals(None, group.offset(topicPartition1)) + assertEquals(None, group.offset(topicPartition2)) + assertEquals(None, group.offset(topicPartition3)) + + cachedOffsets = groupMetadataManager.getOffsets(groupId, defaultRequireStable, Some(Seq(topicPartition1, topicPartition2, topicPartition3))) + assertEquals(Some(OffsetFetchResponse.INVALID_OFFSET), cachedOffsets.get(topicPartition1).map(_.offset)) + assertEquals(Some(OffsetFetchResponse.INVALID_OFFSET), cachedOffsets.get(topicPartition2).map(_.offset)) + assertEquals(Some(OffsetFetchResponse.INVALID_OFFSET), cachedOffsets.get(topicPartition3).map(_.offset)) + + EasyMock.verify(replicaManager) + + assert(group.is(Dead)) + } + + @Test + def testOffsetExpirationOfSimpleConsumer(): Unit = { + val memberId = "memberId" + val topic = "foo" + val topicPartition1 = new TopicPartition(topic, 0) + val offset = 37 + + groupMetadataManager.addPartitionOwnership(groupPartitionId) + + val group = new GroupMetadata(groupId, Empty, time) + groupMetadataManager.addGroup(group) + + // no explicit expiry timestamp is set, so the offset expires after the default retention period + val startMs = time.milliseconds + val tp1OffsetAndMetadata = OffsetAndMetadata(offset, "", startMs) + val offsets = immutable.Map( + topicPartition1 -> tp1OffsetAndMetadata) + + mockGetPartition() + expectAppendMessage(Errors.NONE) + EasyMock.replay(replicaManager) + + var commitErrors: Option[immutable.Map[TopicPartition, Errors]] = None + def callback(errors: immutable.Map[TopicPartition, Errors]): Unit = { + commitErrors = Some(errors) + } + + groupMetadataManager.storeOffsets(group, memberId, offsets, callback) + assertTrue(group.hasOffsets) + + assertFalse(commitErrors.isEmpty) + assertEquals(Some(Errors.NONE), commitErrors.get.get(topicPartition1)) + + // offsets should not be expired while still within the retention period measured from the commit timestamp + val expiryTimestamp = offsets(topicPartition1).commitTimestamp + defaultOffsetRetentionMs + time.sleep(expiryTimestamp - time.milliseconds() - 1) + + groupMetadataManager.cleanupGroupMetadata() + + // group and offsets should still be there + assertEquals(Some(group), groupMetadataManager.getGroup(groupId)) + assertEquals(Some(tp1OffsetAndMetadata), 
group.offset(topicPartition1)) + + var cachedOffsets = groupMetadataManager.getOffsets(groupId, defaultRequireStable, Some(Seq(topicPartition1))) + assertEquals(Some(offset), cachedOffsets.get(topicPartition1).map(_.offset)) + + EasyMock.verify(replicaManager) + + // advance time enough for the offset to expire + time.sleep(2) + + // expect the offset tombstone + EasyMock.reset(partition) + EasyMock.expect(partition.appendRecordsToLeader(EasyMock.anyObject(classOf[MemoryRecords]), + origin = EasyMock.eq(AppendOrigin.Coordinator), requiredAcks = EasyMock.anyInt(), + EasyMock.anyObject())).andReturn(LogAppendInfo.UnknownLogAppendInfo) + EasyMock.replay(partition) + + groupMetadataManager.cleanupGroupMetadata() + + // group and all its offsets should be gone now + assertEquals(None, groupMetadataManager.getGroup(groupId)) + assertEquals(None, group.offset(topicPartition1)) + + cachedOffsets = groupMetadataManager.getOffsets(groupId, defaultRequireStable, Some(Seq(topicPartition1))) + assertEquals(Some(OffsetFetchResponse.INVALID_OFFSET), cachedOffsets.get(topicPartition1).map(_.offset)) + + EasyMock.verify(replicaManager) + + assert(group.is(Dead)) + } + + @Test + def testOffsetExpirationOfActiveGroupSemantics(): Unit = { + val memberId = "memberId" + val clientId = "clientId" + val clientHost = "localhost" + + val topic1 = "foo" + val topic1Partition0 = new TopicPartition(topic1, 0) + val topic1Partition1 = new TopicPartition(topic1, 1) + + val topic2 = "bar" + val topic2Partition0 = new TopicPartition(topic2, 0) + val topic2Partition1 = new TopicPartition(topic2, 1) + + val offset = 37 + + groupMetadataManager.addPartitionOwnership(groupPartitionId) + + val group = new GroupMetadata(groupId, Empty, time) + groupMetadataManager.addGroup(group) + + // Subscribe to topic1 and topic2 + val subscriptionTopic1AndTopic2 = new Subscription(List(topic1, topic2).asJava) + + val member = new MemberMetadata( + memberId, + Some(groupInstanceId), + clientId, + clientHost, + rebalanceTimeout, + sessionTimeout, + ConsumerProtocol.PROTOCOL_TYPE, + List(("protocol", ConsumerProtocol.serializeSubscription(subscriptionTopic1AndTopic2).array())) + ) + + group.add(member, _ => ()) + group.transitionTo(PreparingRebalance) + group.initNextGeneration() + group.transitionTo(Stable) + + val startMs = time.milliseconds + + val t1p0OffsetAndMetadata = OffsetAndMetadata(offset, "", startMs) + val t1p1OffsetAndMetadata = OffsetAndMetadata(offset, "", startMs) + + val t2p0OffsetAndMetadata = OffsetAndMetadata(offset, "", startMs) + val t2p1OffsetAndMetadata = OffsetAndMetadata(offset, "", startMs) + + val offsets = immutable.Map( + topic1Partition0 -> t1p0OffsetAndMetadata, + topic1Partition1 -> t1p1OffsetAndMetadata, + topic2Partition0 -> t2p0OffsetAndMetadata, + topic2Partition1 -> t2p1OffsetAndMetadata) + + mockGetPartition() + expectAppendMessage(Errors.NONE) + EasyMock.replay(replicaManager) + + var commitErrors: Option[immutable.Map[TopicPartition, Errors]] = None + def callback(errors: immutable.Map[TopicPartition, Errors]): Unit = { + commitErrors = Some(errors) + } + + groupMetadataManager.storeOffsets(group, memberId, offsets, callback) + assertTrue(group.hasOffsets) + + assertFalse(commitErrors.isEmpty) + assertEquals(Some(Errors.NONE), commitErrors.get.get(topic1Partition0)) + + // advance time to just after the offset of last partition is to be expired + time.sleep(defaultOffsetRetentionMs + 2) + + // no offset should expire because all topics are actively consumed + 
groupMetadataManager.cleanupGroupMetadata() + + assertEquals(Some(group), groupMetadataManager.getGroup(groupId)) + assert(group.is(Stable)) + + assertEquals(Some(t1p0OffsetAndMetadata), group.offset(topic1Partition0)) + assertEquals(Some(t1p1OffsetAndMetadata), group.offset(topic1Partition1)) + assertEquals(Some(t2p0OffsetAndMetadata), group.offset(topic2Partition0)) + assertEquals(Some(t2p1OffsetAndMetadata), group.offset(topic2Partition1)) + + var cachedOffsets = groupMetadataManager.getOffsets(groupId, defaultRequireStable, Some(Seq(topic1Partition0, topic1Partition1, topic2Partition0, topic2Partition1))) + + assertEquals(Some(offset), cachedOffsets.get(topic1Partition0).map(_.offset)) + assertEquals(Some(offset), cachedOffsets.get(topic1Partition1).map(_.offset)) + assertEquals(Some(offset), cachedOffsets.get(topic2Partition0).map(_.offset)) + assertEquals(Some(offset), cachedOffsets.get(topic2Partition1).map(_.offset)) + + EasyMock.verify(replicaManager) + + group.transitionTo(PreparingRebalance) + + // Subscribe to topic1, offsets of topic2 should be removed + val subscriptionTopic1 = new Subscription(List(topic1).asJava) + + group.updateMember( + member, + List(("protocol", ConsumerProtocol.serializeSubscription(subscriptionTopic1).array())), + null + ) + + group.initNextGeneration() + group.transitionTo(Stable) + + // expect the offset tombstone + EasyMock.expect(partition.appendRecordsToLeader(EasyMock.anyObject(classOf[MemoryRecords]), + origin = EasyMock.eq(AppendOrigin.Coordinator), requiredAcks = EasyMock.anyInt(), + EasyMock.anyObject())).andReturn(LogAppendInfo.UnknownLogAppendInfo) + EasyMock.expectLastCall().times(1) + + EasyMock.replay(partition) + + groupMetadataManager.cleanupGroupMetadata() + + EasyMock.verify(partition) + EasyMock.verify(replicaManager) + + assertEquals(Some(group), groupMetadataManager.getGroup(groupId)) + assert(group.is(Stable)) + + assertEquals(Some(t1p0OffsetAndMetadata), group.offset(topic1Partition0)) + assertEquals(Some(t1p1OffsetAndMetadata), group.offset(topic1Partition1)) + assertEquals(None, group.offset(topic2Partition0)) + assertEquals(None, group.offset(topic2Partition1)) + + cachedOffsets = groupMetadataManager.getOffsets(groupId, defaultRequireStable, Some(Seq(topic1Partition0, topic1Partition1, topic2Partition0, topic2Partition1))) + + assertEquals(Some(offset), cachedOffsets.get(topic1Partition0).map(_.offset)) + assertEquals(Some(offset), cachedOffsets.get(topic1Partition1).map(_.offset)) + assertEquals(Some(OffsetFetchResponse.INVALID_OFFSET), cachedOffsets.get(topic2Partition0).map(_.offset)) + assertEquals(Some(OffsetFetchResponse.INVALID_OFFSET), cachedOffsets.get(topic2Partition1).map(_.offset)) + } + + @Test + def testLoadOffsetFromOldCommit(): Unit = { + val groupMetadataTopicPartition = groupTopicPartition + val generation = 935 + val protocolType = "consumer" + val protocol = "range" + val startOffset = 15L + val groupEpoch = 2 + val committedOffsets = Map( + new TopicPartition("foo", 0) -> 23L, + new TopicPartition("foo", 1) -> 455L, + new TopicPartition("bar", 0) -> 8992L + ) + + val apiVersion = KAFKA_1_1_IV0 + val offsetCommitRecords = createCommittedOffsetRecords(committedOffsets, apiVersion = apiVersion, retentionTimeOpt = Some(100)) + val memberId = "98098230493" + val groupMetadataRecord = buildStableGroupRecordWithMember(generation, protocolType, protocol, memberId, apiVersion = apiVersion) + val records = MemoryRecords.withRecords(startOffset, CompressionType.NONE, + (offsetCommitRecords ++ 
Seq(groupMetadataRecord)).toArray: _*) + + expectGroupMetadataLoad(groupMetadataTopicPartition, startOffset, records) + + EasyMock.replay(replicaManager) + + groupMetadataManager.loadGroupsAndOffsets(groupMetadataTopicPartition, groupEpoch, _ => (), 0L) + + val group = groupMetadataManager.getGroup(groupId).getOrElse(throw new AssertionError("Group was not loaded into the cache")) + assertEquals(groupId, group.groupId) + assertEquals(Stable, group.currentState) + assertEquals(memberId, group.leaderOrNull) + assertEquals(generation, group.generationId) + assertEquals(Some(protocolType), group.protocolType) + assertEquals(protocol, group.protocolName.orNull) + assertEquals(Set(memberId), group.allMembers) + assertEquals(committedOffsets.size, group.allOffsets.size) + committedOffsets.foreach { case (topicPartition, offset) => + assertEquals(Some(offset), group.offset(topicPartition).map(_.offset)) + assertTrue(group.offset(topicPartition).map(_.expireTimestamp).get.nonEmpty) + } + } + + @Test + def testLoadOffsetWithExplicitRetention(): Unit = { + val groupMetadataTopicPartition = groupTopicPartition + val generation = 935 + val protocolType = "consumer" + val protocol = "range" + val startOffset = 15L + val groupEpoch = 2 + val committedOffsets = Map( + new TopicPartition("foo", 0) -> 23L, + new TopicPartition("foo", 1) -> 455L, + new TopicPartition("bar", 0) -> 8992L + ) + + val offsetCommitRecords = createCommittedOffsetRecords(committedOffsets, retentionTimeOpt = Some(100)) + val memberId = "98098230493" + val groupMetadataRecord = buildStableGroupRecordWithMember(generation, protocolType, protocol, memberId) + val records = MemoryRecords.withRecords(startOffset, CompressionType.NONE, + (offsetCommitRecords ++ Seq(groupMetadataRecord)).toArray: _*) + + expectGroupMetadataLoad(groupMetadataTopicPartition, startOffset, records) + + EasyMock.replay(replicaManager) + + groupMetadataManager.loadGroupsAndOffsets(groupMetadataTopicPartition, groupEpoch, _ => (), 0L) + + val group = groupMetadataManager.getGroup(groupId).getOrElse(throw new AssertionError("Group was not loaded into the cache")) + assertEquals(groupId, group.groupId) + assertEquals(Stable, group.currentState) + assertEquals(memberId, group.leaderOrNull) + assertEquals(generation, group.generationId) + assertEquals(Some(protocolType), group.protocolType) + assertEquals(protocol, group.protocolName.orNull) + assertEquals(Set(memberId), group.allMembers) + assertEquals(committedOffsets.size, group.allOffsets.size) + committedOffsets.foreach { case (topicPartition, offset) => + assertEquals(Some(offset), group.offset(topicPartition).map(_.offset)) + assertTrue(group.offset(topicPartition).map(_.expireTimestamp).get.nonEmpty) + } + } + + @Test + def testSerdeOffsetCommitValue(): Unit = { + val offsetAndMetadata = OffsetAndMetadata( + offset = 537L, + leaderEpoch = Optional.of(15), + metadata = "metadata", + commitTimestamp = time.milliseconds(), + expireTimestamp = None) + + def verifySerde(apiVersion: ApiVersion, expectedOffsetCommitValueVersion: Int): Unit = { + val bytes = GroupMetadataManager.offsetCommitValue(offsetAndMetadata, apiVersion) + val buffer = ByteBuffer.wrap(bytes) + + assertEquals(expectedOffsetCommitValueVersion, buffer.getShort(0).toInt) + + val deserializedOffsetAndMetadata = GroupMetadataManager.readOffsetMessageValue(buffer) + assertEquals(offsetAndMetadata.offset, deserializedOffsetAndMetadata.offset) + assertEquals(offsetAndMetadata.metadata, deserializedOffsetAndMetadata.metadata) + 
assertEquals(offsetAndMetadata.commitTimestamp, deserializedOffsetAndMetadata.commitTimestamp) + + // Serialization drops the leader epoch silently if an older inter-broker protocol is in use + val expectedLeaderEpoch = if (expectedOffsetCommitValueVersion >= 3) + offsetAndMetadata.leaderEpoch + else + Optional.empty() + + assertEquals(expectedLeaderEpoch, deserializedOffsetAndMetadata.leaderEpoch) + } + + for (version <- ApiVersion.allVersions) { + val expectedSchemaVersion = version match { + case v if v < KAFKA_2_1_IV0 => 1 + case v if v < KAFKA_2_1_IV1 => 2 + case _ => 3 + } + verifySerde(version, expectedSchemaVersion) + } + } + + @Test + def testSerdeOffsetCommitValueWithExpireTimestamp(): Unit = { + // If expire timestamp is set, we should always use version 1 of the offset commit + // value schema since later versions do not support it + + val offsetAndMetadata = OffsetAndMetadata( + offset = 537L, + leaderEpoch = Optional.empty(), + metadata = "metadata", + commitTimestamp = time.milliseconds(), + expireTimestamp = Some(time.milliseconds() + 1000)) + + def verifySerde(apiVersion: ApiVersion): Unit = { + val bytes = GroupMetadataManager.offsetCommitValue(offsetAndMetadata, apiVersion) + val buffer = ByteBuffer.wrap(bytes) + assertEquals(1, buffer.getShort(0).toInt) + + val deserializedOffsetAndMetadata = GroupMetadataManager.readOffsetMessageValue(buffer) + assertEquals(offsetAndMetadata, deserializedOffsetAndMetadata) + } + + for (version <- ApiVersion.allVersions) + verifySerde(version) + } + + @Test + def testSerdeOffsetCommitValueWithNoneExpireTimestamp(): Unit = { + val offsetAndMetadata = OffsetAndMetadata( + offset = 537L, + leaderEpoch = Optional.empty(), + metadata = "metadata", + commitTimestamp = time.milliseconds(), + expireTimestamp = None) + + def verifySerde(apiVersion: ApiVersion): Unit = { + val bytes = GroupMetadataManager.offsetCommitValue(offsetAndMetadata, apiVersion) + val buffer = ByteBuffer.wrap(bytes) + val version = buffer.getShort(0).toInt + if (apiVersion < KAFKA_2_1_IV0) + assertEquals(1, version) + else if (apiVersion < KAFKA_2_1_IV1) + assertEquals(2, version) + else + assertEquals(3, version) + + val deserializedOffsetAndMetadata = GroupMetadataManager.readOffsetMessageValue(buffer) + assertEquals(offsetAndMetadata, deserializedOffsetAndMetadata) + } + + for (version <- ApiVersion.allVersions) + verifySerde(version) + } + + @Test + def testLoadOffsetsWithEmptyControlBatch(): Unit = { + val groupMetadataTopicPartition = groupTopicPartition + val startOffset = 15L + val generation = 15 + val groupEpoch = 2 + + val committedOffsets = Map( + new TopicPartition("foo", 0) -> 23L, + new TopicPartition("foo", 1) -> 455L, + new TopicPartition("bar", 0) -> 8992L + ) + + val offsetCommitRecords = createCommittedOffsetRecords(committedOffsets) + val groupMetadataRecord = buildEmptyGroupRecord(generation, protocolType) + val records = MemoryRecords.withRecords(startOffset, CompressionType.NONE, + (offsetCommitRecords ++ Seq(groupMetadataRecord)).toArray: _*) + + // Prepend empty control batch to valid records + val mockBatch: MutableRecordBatch = EasyMock.createMock(classOf[MutableRecordBatch]) + EasyMock.expect(mockBatch.iterator).andReturn(Collections.emptyIterator[Record]) + EasyMock.expect(mockBatch.isControlBatch).andReturn(true) + EasyMock.expect(mockBatch.isTransactional).andReturn(true) + EasyMock.expect(mockBatch.nextOffset).andReturn(16L) + EasyMock.replay(mockBatch) + val mockRecords: MemoryRecords = EasyMock.createMock(classOf[MemoryRecords]) + 
EasyMock.expect(mockRecords.batches).andReturn((Iterable[MutableRecordBatch](mockBatch) ++ records.batches.asScala).asJava).anyTimes() + EasyMock.expect(mockRecords.records).andReturn(records.records()).anyTimes() + EasyMock.expect(mockRecords.sizeInBytes()).andReturn(DefaultRecordBatch.RECORD_BATCH_OVERHEAD + records.sizeInBytes()).anyTimes() + EasyMock.replay(mockRecords) + + val logMock: UnifiedLog = EasyMock.mock(classOf[UnifiedLog]) + EasyMock.expect(logMock.logStartOffset).andReturn(startOffset).anyTimes() + EasyMock.expect(logMock.read(EasyMock.eq(startOffset), + maxLength = EasyMock.anyInt(), + isolation = EasyMock.eq(FetchLogEnd), + minOneMessage = EasyMock.eq(true))) + .andReturn(FetchDataInfo(LogOffsetMetadata(startOffset), mockRecords)) + EasyMock.expect(replicaManager.getLog(groupMetadataTopicPartition)).andStubReturn(Some(logMock)) + EasyMock.expect(replicaManager.getLogEndOffset(groupMetadataTopicPartition)).andStubReturn(Some(18)) + EasyMock.replay(logMock) + EasyMock.replay(replicaManager) + + groupMetadataManager.loadGroupsAndOffsets(groupMetadataTopicPartition, groupEpoch, _ => (), 0L) + + // Empty control batch should not have caused the load to fail + val group = groupMetadataManager.getGroup(groupId).getOrElse(throw new AssertionError("Group was not loaded into the cache")) + assertEquals(groupId, group.groupId) + assertEquals(Empty, group.currentState) + assertEquals(generation, group.generationId) + assertEquals(Some(protocolType), group.protocolType) + assertNull(group.leaderOrNull) + assertNull(group.protocolName.orNull) + committedOffsets.foreach { case (topicPartition, offset) => + assertEquals(Some(offset), group.offset(topicPartition).map(_.offset)) + } + } + + @Test + def testCommittedOffsetParsing(): Unit = { + val groupId = "group" + val topicPartition = new TopicPartition("topic", 0) + val offsetCommitRecord = TestUtils.records(Seq( + new SimpleRecord( + GroupMetadataManager.offsetCommitKey(groupId, topicPartition), + GroupMetadataManager.offsetCommitValue(OffsetAndMetadata(35L, "", time.milliseconds()), ApiVersion.latestVersion) + ) + )).records.asScala.head + val (keyStringOpt, valueStringOpt) = GroupMetadataManager.formatRecordKeyAndValue(offsetCommitRecord) + assertEquals(Some(s"offset_commit::group=$groupId,partition=$topicPartition"), keyStringOpt) + assertEquals(Some("offset=35"), valueStringOpt) + } + + @Test + def testCommittedOffsetTombstoneParsing(): Unit = { + val groupId = "group" + val topicPartition = new TopicPartition("topic", 0) + val offsetCommitRecord = TestUtils.records(Seq( + new SimpleRecord(GroupMetadataManager.offsetCommitKey(groupId, topicPartition), null) + )).records.asScala.head + val (keyStringOpt, valueStringOpt) = GroupMetadataManager.formatRecordKeyAndValue(offsetCommitRecord) + assertEquals(Some(s"offset_commit::group=$groupId,partition=$topicPartition"), keyStringOpt) + assertEquals(Some(""), valueStringOpt) + } + + @Test + def testGroupMetadataParsingWithNullUserData(): Unit = { + val generation = 935 + val protocolType = "consumer" + val protocol = "range" + val memberId = "98098230493" + val assignmentBytes = Utils.toArray(ConsumerProtocol.serializeAssignment( + new ConsumerPartitionAssignor.Assignment(List(new TopicPartition("topic", 0)).asJava, null) + )) + val groupMetadataRecord = TestUtils.records(Seq( + buildStableGroupRecordWithMember(generation, protocolType, protocol, memberId, assignmentBytes) + )).records.asScala.head + val (keyStringOpt, valueStringOpt) = 
GroupMetadataManager.formatRecordKeyAndValue(groupMetadataRecord) + assertEquals(Some(s"group_metadata::group=$groupId"), keyStringOpt) + assertEquals(Some("{\"protocolType\":\"consumer\",\"protocol\":\"range\"," + + "\"generationId\":935,\"assignment\":\"{98098230493=[topic-0]}\"}"), valueStringOpt) + } + + @Test + def testGroupMetadataTombstoneParsing(): Unit = { + val groupId = "group" + val groupMetadataRecord = TestUtils.records(Seq( + new SimpleRecord(GroupMetadataManager.groupMetadataKey(groupId), null) + )).records.asScala.head + val (keyStringOpt, valueStringOpt) = GroupMetadataManager.formatRecordKeyAndValue(groupMetadataRecord) + assertEquals(Some(s"group_metadata::group=$groupId"), keyStringOpt) + assertEquals(Some(""), valueStringOpt) + } + + private def appendAndCaptureCallback(): Capture[Map[TopicPartition, PartitionResponse] => Unit] = { + val capturedArgument: Capture[Map[TopicPartition, PartitionResponse] => Unit] = EasyMock.newCapture() + EasyMock.expect(replicaManager.appendRecords(EasyMock.anyLong(), + EasyMock.anyShort(), + internalTopicsAllowed = EasyMock.eq(true), + origin = EasyMock.eq(AppendOrigin.Coordinator), + EasyMock.anyObject().asInstanceOf[Map[TopicPartition, MemoryRecords]], + EasyMock.capture(capturedArgument), + EasyMock.anyObject().asInstanceOf[Option[ReentrantLock]], + EasyMock.anyObject(), + EasyMock.anyObject()) + ) + EasyMock.expect(replicaManager.getMagic(EasyMock.anyObject())).andStubReturn(Some(RecordBatch.CURRENT_MAGIC_VALUE)) + capturedArgument + } + + private def expectAppendMessage(error: Errors): Capture[Map[TopicPartition, MemoryRecords]] = { + val capturedCallback: Capture[Map[TopicPartition, PartitionResponse] => Unit] = EasyMock.newCapture() + val capturedRecords: Capture[Map[TopicPartition, MemoryRecords]] = EasyMock.newCapture() + EasyMock.expect(replicaManager.appendRecords(EasyMock.anyLong(), + EasyMock.anyShort(), + internalTopicsAllowed = EasyMock.eq(true), + origin = EasyMock.eq(AppendOrigin.Coordinator), + EasyMock.capture(capturedRecords), + EasyMock.capture(capturedCallback), + EasyMock.anyObject().asInstanceOf[Option[ReentrantLock]], + EasyMock.anyObject(), + EasyMock.anyObject()) + ).andAnswer(new IAnswer[Unit] { + override def answer: Unit = capturedCallback.getValue.apply( + Map(groupTopicPartition -> + new PartitionResponse(error, 0L, RecordBatch.NO_TIMESTAMP, 0L) + ) + )}) + EasyMock.expect(replicaManager.getMagic(EasyMock.anyObject())).andStubReturn(Some(RecordBatch.CURRENT_MAGIC_VALUE)) + capturedRecords + } + + private def buildStableGroupRecordWithMember(generation: Int, + protocolType: String, + protocol: String, + memberId: String, + assignmentBytes: Array[Byte] = Array.emptyByteArray, + apiVersion: ApiVersion = ApiVersion.latestVersion): SimpleRecord = { + val memberProtocols = List((protocol, Array.emptyByteArray)) + val member = new MemberMetadata(memberId, Some(groupInstanceId), "clientId", "clientHost", 30000, 10000, protocolType, memberProtocols) + val group = GroupMetadata.loadGroup(groupId, Stable, generation, protocolType, protocol, memberId, + if (apiVersion >= KAFKA_2_1_IV0) Some(time.milliseconds()) else None, Seq(member), time) + val groupMetadataKey = GroupMetadataManager.groupMetadataKey(groupId) + val groupMetadataValue = GroupMetadataManager.groupMetadataValue(group, Map(memberId -> assignmentBytes), apiVersion) + new SimpleRecord(groupMetadataKey, groupMetadataValue) + } + + private def buildEmptyGroupRecord(generation: Int, protocolType: String): SimpleRecord = { + val group = 
GroupMetadata.loadGroup(groupId, Empty, generation, protocolType, null, null, None, Seq.empty, time) + val groupMetadataKey = GroupMetadataManager.groupMetadataKey(groupId) + val groupMetadataValue = GroupMetadataManager.groupMetadataValue(group, Map.empty, ApiVersion.latestVersion) + new SimpleRecord(groupMetadataKey, groupMetadataValue) + } + + private def expectGroupMetadataLoad(groupMetadataTopicPartition: TopicPartition, + startOffset: Long, + records: MemoryRecords): Unit = { + val logMock: UnifiedLog = EasyMock.mock(classOf[UnifiedLog]) + EasyMock.expect(replicaManager.getLog(groupMetadataTopicPartition)).andStubReturn(Some(logMock)) + val endOffset = expectGroupMetadataLoad(logMock, startOffset, records) + EasyMock.expect(replicaManager.getLogEndOffset(groupMetadataTopicPartition)).andStubReturn(Some(endOffset)) + EasyMock.replay(logMock) + } + + /** + * mock records into a mocked log + * + * @return the calculated end offset to be mocked into [[ReplicaManager.getLogEndOffset]] + */ + private def expectGroupMetadataLoad(logMock: UnifiedLog, + startOffset: Long, + records: MemoryRecords): Long = { + val endOffset = startOffset + records.records.asScala.size + val fileRecordsMock: FileRecords = EasyMock.mock(classOf[FileRecords]) + + EasyMock.expect(logMock.logStartOffset).andStubReturn(startOffset) + EasyMock.expect(logMock.read(EasyMock.eq(startOffset), + maxLength = EasyMock.anyInt(), + isolation = EasyMock.eq(FetchLogEnd), + minOneMessage = EasyMock.eq(true))) + .andReturn(FetchDataInfo(LogOffsetMetadata(startOffset), fileRecordsMock)) + + EasyMock.expect(fileRecordsMock.sizeInBytes()).andStubReturn(records.sizeInBytes) + + val bufferCapture = EasyMock.newCapture[ByteBuffer] + fileRecordsMock.readInto(EasyMock.capture(bufferCapture), EasyMock.anyInt()) + EasyMock.expectLastCall().andAnswer(new IAnswer[Unit] { + override def answer: Unit = { + val buffer = bufferCapture.getValue + buffer.put(records.buffer.duplicate) + buffer.flip() + } + }) + + EasyMock.replay(fileRecordsMock) + + endOffset + } + + private def createCommittedOffsetRecords(committedOffsets: Map[TopicPartition, Long], + groupId: String = groupId, + apiVersion: ApiVersion = ApiVersion.latestVersion, + retentionTimeOpt: Option[Long] = None): Seq[SimpleRecord] = { + committedOffsets.map { case (topicPartition, offset) => + val commitTimestamp = time.milliseconds() + val offsetAndMetadata = retentionTimeOpt match { + case Some(retentionTimeMs) => + val expirationTime = commitTimestamp + retentionTimeMs + OffsetAndMetadata(offset, "", commitTimestamp, expirationTime) + case None => + OffsetAndMetadata(offset, "", commitTimestamp) + } + val offsetCommitKey = GroupMetadataManager.offsetCommitKey(groupId, topicPartition) + val offsetCommitValue = GroupMetadataManager.offsetCommitValue(offsetAndMetadata, apiVersion) + new SimpleRecord(offsetCommitKey, offsetCommitValue) + }.toSeq + } + + private def mockGetPartition(): Unit = { + EasyMock.expect(replicaManager.getPartition(groupTopicPartition)).andStubReturn(HostedPartition.Online(partition)) + EasyMock.expect(replicaManager.onlinePartition(groupTopicPartition)).andStubReturn(Some(partition)) + } + + private def getGauge(manager: GroupMetadataManager, name: String): Gauge[Int] = { + KafkaYammerMetrics.defaultRegistry().allMetrics().get(manager.metricName(name, Map.empty)).asInstanceOf[Gauge[Int]] + } + + private def expectMetrics(manager: GroupMetadataManager, + expectedNumGroups: Int, + expectedNumGroupsPreparingRebalance: Int, + expectedNumGroupsCompletingRebalance: Int): 
Unit = { + assertEquals(expectedNumGroups, getGauge(manager, "NumGroups").value) + assertEquals(expectedNumGroupsPreparingRebalance, getGauge(manager, "NumGroupsPreparingRebalance").value) + assertEquals(expectedNumGroupsCompletingRebalance, getGauge(manager, "NumGroupsCompletingRebalance").value) + } + + @Test + def testMetrics(): Unit = { + groupMetadataManager.cleanupGroupMetadata() + expectMetrics(groupMetadataManager, 0, 0, 0) + val group = new GroupMetadata("foo2", Stable, time) + groupMetadataManager.addGroup(group) + expectMetrics(groupMetadataManager, 1, 0, 0) + group.transitionTo(PreparingRebalance) + expectMetrics(groupMetadataManager, 1, 1, 0) + group.transitionTo(CompletingRebalance) + expectMetrics(groupMetadataManager, 1, 0, 1) + } + + @Test + def testPartitionLoadMetric(): Unit = { + val server = ManagementFactory.getPlatformMBeanServer + val mBeanName = "kafka.server:type=group-coordinator-metrics" + val reporter = new JmxReporter + val metricsContext = new KafkaMetricsContext("kafka.server") + reporter.contextChange(metricsContext) + metrics.addReporter(reporter) + + def partitionLoadTime(attribute: String): Double = { + server.getAttribute(new ObjectName(mBeanName), attribute).asInstanceOf[Double] + } + + assertTrue(server.isRegistered(new ObjectName(mBeanName))) + assertEquals(Double.NaN, partitionLoadTime( "partition-load-time-max"), 0) + assertEquals(Double.NaN, partitionLoadTime("partition-load-time-avg"), 0) + assertTrue(reporter.containsMbean(mBeanName)) + + val groupMetadataTopicPartition = groupTopicPartition + val startOffset = 15L + val memberId = "98098230493" + val committedOffsets = Map( + new TopicPartition("foo", 0) -> 23L, + new TopicPartition("foo", 1) -> 455L, + new TopicPartition("bar", 0) -> 8992L + ) + + val offsetCommitRecords = createCommittedOffsetRecords(committedOffsets) + val groupMetadataRecord = buildStableGroupRecordWithMember(generation = 15, + protocolType = "consumer", protocol = "range", memberId) + val records = MemoryRecords.withRecords(startOffset, CompressionType.NONE, + (offsetCommitRecords ++ Seq(groupMetadataRecord)).toArray: _*) + + expectGroupMetadataLoad(groupMetadataTopicPartition, startOffset, records) + EasyMock.replay(replicaManager) + + // When passed a specific start offset, assert that the measured values are in excess of that. + val now = time.milliseconds() + val diff = 1000 + val groupEpoch = 2 + groupMetadataManager.loadGroupsAndOffsets(groupMetadataTopicPartition, groupEpoch, _ => (), now - diff) + assertTrue(partitionLoadTime("partition-load-time-max") >= diff) + assertTrue(partitionLoadTime("partition-load-time-avg") >= diff) + } +} diff --git a/core/src/test/scala/unit/kafka/coordinator/group/GroupMetadataTest.scala b/core/src/test/scala/unit/kafka/coordinator/group/GroupMetadataTest.scala new file mode 100644 index 0000000..275b7f6 --- /dev/null +++ b/core/src/test/scala/unit/kafka/coordinator/group/GroupMetadataTest.scala @@ -0,0 +1,725 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.coordinator.group + +import kafka.common.OffsetAndMetadata +import org.apache.kafka.clients.consumer.ConsumerPartitionAssignor.Subscription +import org.apache.kafka.clients.consumer.internals.ConsumerProtocol +import org.apache.kafka.common.TopicPartition +import org.apache.kafka.common.protocol.Errors +import org.apache.kafka.common.utils.Time +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.{BeforeEach, Test} + +import scala.jdk.CollectionConverters._ + +/** + * Test group state transitions and other GroupMetadata functionality + */ +class GroupMetadataTest { + private val protocolType = "consumer" + private val groupInstanceId = "groupInstanceId" + private val memberId = "memberId" + private val clientId = "clientId" + private val clientHost = "clientHost" + private val rebalanceTimeoutMs = 60000 + private val sessionTimeoutMs = 10000 + + private var group: GroupMetadata = null + + @BeforeEach + def setUp(): Unit = { + group = new GroupMetadata("groupId", Empty, Time.SYSTEM) + } + + @Test + def testCanRebalanceWhenStable(): Unit = { + assertTrue(group.canRebalance) + } + + @Test + def testCanRebalanceWhenCompletingRebalance(): Unit = { + group.transitionTo(PreparingRebalance) + group.transitionTo(CompletingRebalance) + assertTrue(group.canRebalance) + } + + @Test + def testCannotRebalanceWhenPreparingRebalance(): Unit = { + group.transitionTo(PreparingRebalance) + assertFalse(group.canRebalance) + } + + @Test + def testCannotRebalanceWhenDead(): Unit = { + group.transitionTo(PreparingRebalance) + group.transitionTo(Empty) + group.transitionTo(Dead) + assertFalse(group.canRebalance) + } + + @Test + def testStableToPreparingRebalanceTransition(): Unit = { + group.transitionTo(PreparingRebalance) + assertState(group, PreparingRebalance) + } + + @Test + def testStableToDeadTransition(): Unit = { + group.transitionTo(Dead) + assertState(group, Dead) + } + + @Test + def testAwaitingRebalanceToPreparingRebalanceTransition(): Unit = { + group.transitionTo(PreparingRebalance) + group.transitionTo(CompletingRebalance) + group.transitionTo(PreparingRebalance) + assertState(group, PreparingRebalance) + } + + @Test + def testPreparingRebalanceToDeadTransition(): Unit = { + group.transitionTo(PreparingRebalance) + group.transitionTo(Dead) + assertState(group, Dead) + } + + @Test + def testPreparingRebalanceToEmptyTransition(): Unit = { + group.transitionTo(PreparingRebalance) + group.transitionTo(Empty) + assertState(group, Empty) + } + + @Test + def testEmptyToDeadTransition(): Unit = { + group.transitionTo(PreparingRebalance) + group.transitionTo(Empty) + group.transitionTo(Dead) + assertState(group, Dead) + } + + @Test + def testAwaitingRebalanceToStableTransition(): Unit = { + group.transitionTo(PreparingRebalance) + group.transitionTo(CompletingRebalance) + group.transitionTo(Stable) + assertState(group, Stable) + } + + @Test + def testEmptyToStableIllegalTransition(): Unit = { + assertThrows(classOf[IllegalStateException], () => group.transitionTo(Stable)) + } + + @Test + def testStableToStableIllegalTransition(): Unit = { + 
group.transitionTo(PreparingRebalance) + group.transitionTo(CompletingRebalance) + group.transitionTo(Stable) + assertThrows(classOf[IllegalStateException], () => group.transitionTo(Stable)) + } + + @Test + def testEmptyToAwaitingRebalanceIllegalTransition(): Unit = { + assertThrows(classOf[IllegalStateException], () => group.transitionTo(CompletingRebalance)) + } + + @Test + def testPreparingRebalanceToPreparingRebalanceIllegalTransition(): Unit = { + group.transitionTo(PreparingRebalance) + assertThrows(classOf[IllegalStateException], () => group.transitionTo(PreparingRebalance)) + } + + @Test + def testPreparingRebalanceToStableIllegalTransition(): Unit = { + group.transitionTo(PreparingRebalance) + assertThrows(classOf[IllegalStateException], () => group.transitionTo(Stable)) + } + + @Test + def testAwaitingRebalanceToAwaitingRebalanceIllegalTransition(): Unit = { + group.transitionTo(PreparingRebalance) + group.transitionTo(CompletingRebalance) + assertThrows(classOf[IllegalStateException], () => group.transitionTo(CompletingRebalance)) + } + + @Test + def testDeadToDeadIllegalTransition(): Unit = { + group.transitionTo(PreparingRebalance) + group.transitionTo(Dead) + group.transitionTo(Dead) + assertState(group, Dead) + } + + @Test + def testDeadToStableIllegalTransition(): Unit = { + group.transitionTo(PreparingRebalance) + group.transitionTo(Dead) + assertThrows(classOf[IllegalStateException], () => group.transitionTo(Stable)) + } + + @Test + def testDeadToPreparingRebalanceIllegalTransition(): Unit = { + group.transitionTo(PreparingRebalance) + group.transitionTo(Dead) + assertThrows(classOf[IllegalStateException], () => group.transitionTo(PreparingRebalance)) + } + + @Test + def testDeadToAwaitingRebalanceIllegalTransition(): Unit = { + group.transitionTo(PreparingRebalance) + group.transitionTo(Dead) + assertThrows(classOf[IllegalStateException], () => group.transitionTo(CompletingRebalance)) + } + + @Test + def testSelectProtocol(): Unit = { + val memberId = "memberId" + val member = new MemberMetadata(memberId, None, clientId, clientHost, rebalanceTimeoutMs, sessionTimeoutMs, + protocolType, List(("range", Array.empty[Byte]), ("roundrobin", Array.empty[Byte]))) + + group.add(member) + assertEquals("range", group.selectProtocol) + + val otherMemberId = "otherMemberId" + val otherMember = new MemberMetadata(otherMemberId, None, clientId, clientHost, rebalanceTimeoutMs, + sessionTimeoutMs, protocolType, List(("roundrobin", Array.empty[Byte]), ("range", Array.empty[Byte]))) + + group.add(otherMember) + // now could be either range or roundrobin since there is no majority preference + assertTrue(Set("range", "roundrobin")(group.selectProtocol)) + + val lastMemberId = "lastMemberId" + val lastMember = new MemberMetadata(lastMemberId, None, clientId, clientHost, rebalanceTimeoutMs, + sessionTimeoutMs, protocolType, List(("roundrobin", Array.empty[Byte]), ("range", Array.empty[Byte]))) + + group.add(lastMember) + // now we should prefer 'roundrobin' + assertEquals("roundrobin", group.selectProtocol) + } + + @Test + def testSelectProtocolRaisesIfNoMembers(): Unit = { + assertThrows(classOf[IllegalStateException], () => group.selectProtocol) + } + + @Test + def testSelectProtocolChoosesCompatibleProtocol(): Unit = { + val memberId = "memberId" + val member = new MemberMetadata(memberId, None, clientId, clientHost, rebalanceTimeoutMs, sessionTimeoutMs, + protocolType, List(("range", Array.empty[Byte]), ("roundrobin", Array.empty[Byte]))) + + val otherMemberId = "otherMemberId" + val otherMember = new 
MemberMetadata(otherMemberId, None, clientId, clientHost, rebalanceTimeoutMs, + sessionTimeoutMs, protocolType, List(("roundrobin", Array.empty[Byte]), ("blah", Array.empty[Byte]))) + + group.add(member) + group.add(otherMember) + assertEquals("roundrobin", group.selectProtocol) + } + + @Test + def testSupportsProtocols(): Unit = { + val member = new MemberMetadata(memberId, None, clientId, clientHost, rebalanceTimeoutMs, sessionTimeoutMs, + protocolType, List(("range", Array.empty[Byte]), ("roundrobin", Array.empty[Byte]))) + + // by default, the group supports everything + assertTrue(group.supportsProtocols(protocolType, Set("roundrobin", "range"))) + + group.add(member) + group.transitionTo(PreparingRebalance) + assertTrue(group.supportsProtocols(protocolType, Set("roundrobin", "foo"))) + assertTrue(group.supportsProtocols(protocolType, Set("range", "foo"))) + assertFalse(group.supportsProtocols(protocolType, Set("foo", "bar"))) + + val otherMemberId = "otherMemberId" + val otherMember = new MemberMetadata(otherMemberId, None, clientId, clientHost, rebalanceTimeoutMs, + sessionTimeoutMs, protocolType, List(("roundrobin", Array.empty[Byte]), ("blah", Array.empty[Byte]))) + + group.add(otherMember) + + assertTrue(group.supportsProtocols(protocolType, Set("roundrobin", "foo"))) + assertFalse(group.supportsProtocols("invalid_type", Set("roundrobin", "foo"))) + assertFalse(group.supportsProtocols(protocolType, Set("range", "foo"))) + } + + @Test + def testSubscribedTopics(): Unit = { + // not able to compute it for a newly created group + assertEquals(None, group.getSubscribedTopics) + + val memberId = "memberId" + val member = new MemberMetadata(memberId, None, clientId, clientHost, rebalanceTimeoutMs, + sessionTimeoutMs, protocolType, List(("range", ConsumerProtocol.serializeSubscription(new Subscription(List("foo").asJava)).array()))) + + group.transitionTo(PreparingRebalance) + group.add(member) + + group.initNextGeneration() + + assertEquals(Some(Set("foo")), group.getSubscribedTopics) + + group.transitionTo(PreparingRebalance) + group.remove(memberId) + + group.initNextGeneration() + + assertEquals(Some(Set.empty), group.getSubscribedTopics) + + val memberWithFaultyProtocol = new MemberMetadata(memberId, None, clientId, clientHost, rebalanceTimeoutMs, + sessionTimeoutMs, protocolType, List(("range", Array.empty[Byte]))) + + group.transitionTo(PreparingRebalance) + group.add(memberWithFaultyProtocol) + + group.initNextGeneration() + + assertEquals(None, group.getSubscribedTopics) + } + + @Test + def testSubscribedTopicsNonConsumerGroup(): Unit = { + // not able to compute it for a newly created group + assertEquals(None, group.getSubscribedTopics) + + val memberId = "memberId" + val member = new MemberMetadata(memberId, None, clientId, clientHost, rebalanceTimeoutMs, + sessionTimeoutMs, "My Protocol", List(("range", Array.empty[Byte]))) + + group.transitionTo(PreparingRebalance) + group.add(member) + + group.initNextGeneration() + + assertEquals(None, group.getSubscribedTopics) + } + + @Test + def testInitNextGeneration(): Unit = { + val member = new MemberMetadata(memberId, None, clientId, clientHost, rebalanceTimeoutMs, sessionTimeoutMs, + protocolType, List(("range", Array.empty[Byte]), ("roundrobin", Array.empty[Byte]))) + + member.supportedProtocols = List(("roundrobin", Array.empty[Byte])) + + group.transitionTo(PreparingRebalance) + group.add(member, _ => ()) + + assertEquals(0, group.generationId) + assertNull(group.protocolName.orNull) + + group.initNextGeneration() + + 
assertEquals(1, group.generationId) + assertEquals("roundrobin", group.protocolName.orNull) + } + + @Test + def testInitNextGenerationEmptyGroup(): Unit = { + assertEquals(Empty, group.currentState) + assertEquals(0, group.generationId) + assertNull(group.protocolName.orNull) + + group.transitionTo(PreparingRebalance) + group.initNextGeneration() + + assertEquals(1, group.generationId) + assertNull(group.protocolName.orNull) + } + + @Test + def testOffsetCommit(): Unit = { + val partition = new TopicPartition("foo", 0) + val offset = offsetAndMetadata(37) + val commitRecordOffset = 3 + + group.prepareOffsetCommit(Map(partition -> offset)) + assertTrue(group.hasOffsets) + assertEquals(None, group.offset(partition)) + + group.onOffsetCommitAppend(partition, CommitRecordMetadataAndOffset(Some(commitRecordOffset), offset)) + assertTrue(group.hasOffsets) + assertEquals(Some(offset), group.offset(partition)) + } + + @Test + def testOffsetCommitFailure(): Unit = { + val partition = new TopicPartition("foo", 0) + val offset = offsetAndMetadata(37) + + group.prepareOffsetCommit(Map(partition -> offset)) + assertTrue(group.hasOffsets) + assertEquals(None, group.offset(partition)) + + group.failPendingOffsetWrite(partition, offset) + assertFalse(group.hasOffsets) + assertEquals(None, group.offset(partition)) + } + + @Test + def testOffsetCommitFailureWithAnotherPending(): Unit = { + val partition = new TopicPartition("foo", 0) + val firstOffset = offsetAndMetadata(37) + val secondOffset = offsetAndMetadata(57) + + group.prepareOffsetCommit(Map(partition -> firstOffset)) + assertTrue(group.hasOffsets) + assertEquals(None, group.offset(partition)) + + group.prepareOffsetCommit(Map(partition -> secondOffset)) + assertTrue(group.hasOffsets) + + group.failPendingOffsetWrite(partition, firstOffset) + assertTrue(group.hasOffsets) + assertEquals(None, group.offset(partition)) + + group.onOffsetCommitAppend(partition, CommitRecordMetadataAndOffset(Some(3L), secondOffset)) + assertTrue(group.hasOffsets) + assertEquals(Some(secondOffset), group.offset(partition)) + } + + @Test + def testOffsetCommitWithAnotherPending(): Unit = { + val partition = new TopicPartition("foo", 0) + val firstOffset = offsetAndMetadata(37) + val secondOffset = offsetAndMetadata(57) + + group.prepareOffsetCommit(Map(partition -> firstOffset)) + assertTrue(group.hasOffsets) + assertEquals(None, group.offset(partition)) + + group.prepareOffsetCommit(Map(partition -> secondOffset)) + assertTrue(group.hasOffsets) + + group.onOffsetCommitAppend(partition, CommitRecordMetadataAndOffset(Some(4L), firstOffset)) + assertTrue(group.hasOffsets) + assertEquals(Some(firstOffset), group.offset(partition)) + + group.onOffsetCommitAppend(partition, CommitRecordMetadataAndOffset(Some(5L), secondOffset)) + assertTrue(group.hasOffsets) + assertEquals(Some(secondOffset), group.offset(partition)) + } + + @Test + def testConsumerBeatsTransactionalOffsetCommit(): Unit = { + val partition = new TopicPartition("foo", 0) + val producerId = 13232L + val txnOffsetCommit = offsetAndMetadata(37) + val consumerOffsetCommit = offsetAndMetadata(57) + + group.prepareTxnOffsetCommit(producerId, Map(partition -> txnOffsetCommit)) + assertTrue(group.hasOffsets) + assertEquals(None, group.offset(partition)) + + group.prepareOffsetCommit(Map(partition -> consumerOffsetCommit)) + assertTrue(group.hasOffsets) + + group.onTxnOffsetCommitAppend(producerId, partition, CommitRecordMetadataAndOffset(Some(3L), txnOffsetCommit)) + group.onOffsetCommitAppend(partition, 
CommitRecordMetadataAndOffset(Some(4L), consumerOffsetCommit)) + assertTrue(group.hasOffsets) + assertEquals(Some(consumerOffsetCommit), group.offset(partition)) + + group.completePendingTxnOffsetCommit(producerId, isCommit = true) + assertTrue(group.hasOffsets) + // This is the crucial assertion which validates that we materialize offsets in offset order, not transactional order. + assertEquals(Some(consumerOffsetCommit), group.offset(partition)) + } + + @Test + def testTransactionBeatsConsumerOffsetCommit(): Unit = { + val partition = new TopicPartition("foo", 0) + val producerId = 13232L + val txnOffsetCommit = offsetAndMetadata(37) + val consumerOffsetCommit = offsetAndMetadata(57) + + group.prepareTxnOffsetCommit(producerId, Map(partition -> txnOffsetCommit)) + assertTrue(group.hasOffsets) + assertEquals(None, group.offset(partition)) + + group.prepareOffsetCommit(Map(partition -> consumerOffsetCommit)) + assertTrue(group.hasOffsets) + + group.onOffsetCommitAppend(partition, CommitRecordMetadataAndOffset(Some(3L), consumerOffsetCommit)) + group.onTxnOffsetCommitAppend(producerId, partition, CommitRecordMetadataAndOffset(Some(4L), txnOffsetCommit)) + assertTrue(group.hasOffsets) + // The transactional offset commit hasn't been committed yet, so we should materialize the consumer offset commit. + assertEquals(Some(consumerOffsetCommit), group.offset(partition)) + + group.completePendingTxnOffsetCommit(producerId, isCommit = true) + assertTrue(group.hasOffsets) + // The transactional offset commit has been materialized and the transactional commit record is later in the log, + // so it should be materialized. + assertEquals(Some(txnOffsetCommit), group.offset(partition)) + } + + @Test + def testTransactionalCommitIsAbortedAndConsumerCommitWins(): Unit = { + val partition = new TopicPartition("foo", 0) + val producerId = 13232L + val txnOffsetCommit = offsetAndMetadata(37) + val consumerOffsetCommit = offsetAndMetadata(57) + + group.prepareTxnOffsetCommit(producerId, Map(partition -> txnOffsetCommit)) + assertTrue(group.hasOffsets) + assertEquals(None, group.offset(partition)) + + group.prepareOffsetCommit(Map(partition -> consumerOffsetCommit)) + assertTrue(group.hasOffsets) + + group.onOffsetCommitAppend(partition, CommitRecordMetadataAndOffset(Some(3L), consumerOffsetCommit)) + group.onTxnOffsetCommitAppend(producerId, partition, CommitRecordMetadataAndOffset(Some(4L), txnOffsetCommit)) + assertTrue(group.hasOffsets) + // The transactional offset commit hasn't been committed yet, so we should materialize the consumer offset commit. + assertEquals(Some(consumerOffsetCommit), group.offset(partition)) + + group.completePendingTxnOffsetCommit(producerId, isCommit = false) + assertTrue(group.hasOffsets) + // The transactional offset commit should be discarded and the consumer offset commit should continue to be + // materialized. 
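+    // Completing the pending transaction with isCommit = false (i.e. an abort) should also clear the
+    // pending transactional offset state tracked for this producer, which the next assertion verifies.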
+ assertFalse(group.hasPendingOffsetCommitsFromProducer(producerId)) + assertEquals(Some(consumerOffsetCommit), group.offset(partition)) + } + + @Test + def testFailedTxnOffsetCommitLeavesNoPendingState(): Unit = { + val partition = new TopicPartition("foo", 0) + val producerId = 13232L + val txnOffsetCommit = offsetAndMetadata(37) + + group.prepareTxnOffsetCommit(producerId, Map(partition -> txnOffsetCommit)) + assertTrue(group.hasPendingOffsetCommitsFromProducer(producerId)) + assertTrue(group.hasOffsets) + assertEquals(None, group.offset(partition)) + group.failPendingTxnOffsetCommit(producerId, partition) + assertFalse(group.hasOffsets) + assertFalse(group.hasPendingOffsetCommitsFromProducer(producerId)) + + // The commit marker should now have no effect. + group.completePendingTxnOffsetCommit(producerId, isCommit = true) + assertFalse(group.hasOffsets) + assertFalse(group.hasPendingOffsetCommitsFromProducer(producerId)) + } + + @Test + def testReplaceGroupInstanceWithNonExistingMember(): Unit = { + val newMemberId = "newMemberId" + assertThrows(classOf[IllegalArgumentException], () => group.replaceStaticMember(groupInstanceId, memberId, newMemberId)) + } + + @Test + def testReplaceGroupInstance(): Unit = { + val member = new MemberMetadata(memberId, Some(groupInstanceId), clientId, clientHost, rebalanceTimeoutMs, sessionTimeoutMs, + protocolType, List(("range", Array.empty[Byte]), ("roundrobin", Array.empty[Byte]))) + + var joinAwaitingMemberFenced = false + group.add(member, joinGroupResult => { + joinAwaitingMemberFenced = joinGroupResult.error == Errors.FENCED_INSTANCE_ID + }) + var syncAwaitingMemberFenced = false + member.awaitingSyncCallback = syncGroupResult => { + syncAwaitingMemberFenced = syncGroupResult.error == Errors.FENCED_INSTANCE_ID + } + assertTrue(group.isLeader(memberId)) + assertEquals(Some(memberId), group.currentStaticMemberId(groupInstanceId)) + + val newMemberId = "newMemberId" + group.replaceStaticMember(groupInstanceId, memberId, newMemberId) + assertTrue(group.isLeader(newMemberId)) + assertEquals(Some(newMemberId), group.currentStaticMemberId(groupInstanceId)) + assertTrue(joinAwaitingMemberFenced) + assertTrue(syncAwaitingMemberFenced) + assertFalse(member.isAwaitingJoin) + assertFalse(member.isAwaitingSync) + } + + @Test + def testInvokeJoinCallback(): Unit = { + val member = new MemberMetadata(memberId, None, clientId, clientHost, rebalanceTimeoutMs, sessionTimeoutMs, + protocolType, List(("range", Array.empty[Byte]), ("roundrobin", Array.empty[Byte]))) + + var invoked = false + group.add(member, _ => { + invoked = true + }) + + assertTrue(group.hasAllMembersJoined) + group.maybeInvokeJoinCallback(member, JoinGroupResult(member.memberId, Errors.NONE)) + assertTrue(invoked) + assertFalse(member.isAwaitingJoin) + } + + @Test + def testNotInvokeJoinCallback(): Unit = { + val member = new MemberMetadata(memberId, None, clientId, clientHost, rebalanceTimeoutMs, sessionTimeoutMs, + protocolType, List(("range", Array.empty[Byte]), ("roundrobin", Array.empty[Byte]))) + group.add(member) + + assertFalse(member.isAwaitingJoin) + group.maybeInvokeJoinCallback(member, JoinGroupResult(member.memberId, Errors.NONE)) + assertFalse(member.isAwaitingJoin) + } + + @Test + def testInvokeSyncCallback(): Unit = { + val member = new MemberMetadata(memberId, None, clientId, clientHost, rebalanceTimeoutMs, sessionTimeoutMs, + protocolType, List(("range", Array.empty[Byte]), ("roundrobin", Array.empty[Byte]))) + + group.add(member) + member.awaitingSyncCallback = _ => {} + + val 
invoked = group.maybeInvokeSyncCallback(member, SyncGroupResult(Errors.NONE)) + assertTrue(invoked) + assertFalse(member.isAwaitingSync) + } + + @Test + def testNotInvokeSyncCallback(): Unit = { + val member = new MemberMetadata(memberId, None, clientId, clientHost, rebalanceTimeoutMs, sessionTimeoutMs, + protocolType, List(("range", Array.empty[Byte]), ("roundrobin", Array.empty[Byte]))) + group.add(member) + + val invoked = group.maybeInvokeSyncCallback(member, SyncGroupResult(Errors.NONE)) + assertFalse(invoked) + assertFalse(member.isAwaitingSync) + } + + @Test + def testHasPendingNonTxnOffsets(): Unit = { + val partition = new TopicPartition("foo", 0) + val offset = offsetAndMetadata(37) + + group.prepareOffsetCommit(Map(partition -> offset)) + assertTrue(group.hasPendingOffsetCommitsForTopicPartition(partition)) + } + + @Test + def testHasPendingTxnOffsets(): Unit = { + val txnPartition = new TopicPartition("foo", 1) + val offset = offsetAndMetadata(37) + val producerId = 5 + + group.prepareTxnOffsetCommit(producerId, Map(txnPartition -> offset)) + assertTrue(group.hasPendingOffsetCommitsForTopicPartition(txnPartition)) + + assertFalse(group.hasPendingOffsetCommitsForTopicPartition(new TopicPartition("non-exist", 0))) + } + + @Test + def testCannotAddPendingMemberIfStable(): Unit = { + val member = new MemberMetadata(memberId, None, clientId, clientHost, rebalanceTimeoutMs, sessionTimeoutMs, + protocolType, List(("range", Array.empty[Byte]), ("roundrobin", Array.empty[Byte]))) + group.add(member) + assertThrows(classOf[IllegalStateException], () => group.addPendingMember(memberId)) + } + + @Test + def testRemovalFromPendingAfterMemberIsStable(): Unit = { + group.addPendingMember(memberId) + assertFalse(group.has(memberId)) + assertTrue(group.isPendingMember(memberId)) + + val member = new MemberMetadata(memberId, None, clientId, clientHost, rebalanceTimeoutMs, sessionTimeoutMs, + protocolType, List(("range", Array.empty[Byte]), ("roundrobin", Array.empty[Byte]))) + group.add(member) + assertTrue(group.has(memberId)) + assertFalse(group.isPendingMember(memberId)) + } + + @Test + def testRemovalFromPendingWhenMemberIsRemoved(): Unit = { + group.addPendingMember(memberId) + assertFalse(group.has(memberId)) + assertTrue(group.isPendingMember(memberId)) + + group.remove(memberId) + assertFalse(group.has(memberId)) + assertFalse(group.isPendingMember(memberId)) + } + + @Test + def testCannotAddStaticMemberIfAlreadyPresent(): Unit = { + val member = new MemberMetadata(memberId, Some(groupInstanceId), clientId, clientHost, + rebalanceTimeoutMs, sessionTimeoutMs, protocolType, List(("range", Array.empty[Byte]))) + group.add(member) + assertTrue(group.has(memberId)) + assertTrue(group.hasStaticMember(groupInstanceId)) + + // We are not permitted to add the member again if it is already present + assertThrows(classOf[IllegalStateException], () => group.add(member)) + } + + @Test + def testCannotAddPendingSyncOfUnknownMember(): Unit = { + assertThrows(classOf[IllegalStateException], + () => group.addPendingSyncMember(memberId)) + } + + @Test + def testCannotRemovePendingSyncOfUnknownMember(): Unit = { + assertThrows(classOf[IllegalStateException], + () => group.removePendingSyncMember(memberId)) + } + + @Test + def testCanAddAndRemovePendingSyncMember(): Unit = { + val member = new MemberMetadata(memberId, Some(groupInstanceId), clientId, clientHost, + rebalanceTimeoutMs, sessionTimeoutMs, protocolType, List(("range", Array.empty[Byte]))) + group.add(member) + 
group.addPendingSyncMember(memberId) + assertEquals(Set(memberId), group.allPendingSyncMembers) + group.removePendingSyncMember(memberId) + assertEquals(Set(), group.allPendingSyncMembers) + } + + @Test + def testRemovalFromPendingSyncWhenMemberIsRemoved(): Unit = { + val member = new MemberMetadata(memberId, Some(groupInstanceId), clientId, clientHost, + rebalanceTimeoutMs, sessionTimeoutMs, protocolType, List(("range", Array.empty[Byte]))) + group.add(member) + group.addPendingSyncMember(memberId) + assertEquals(Set(memberId), group.allPendingSyncMembers) + group.remove(memberId) + assertEquals(Set(), group.allPendingSyncMembers) + } + + @Test + def testNewGenerationClearsPendingSyncMembers(): Unit = { + val member = new MemberMetadata(memberId, Some(groupInstanceId), clientId, clientHost, + rebalanceTimeoutMs, sessionTimeoutMs, protocolType, List(("range", Array.empty[Byte]))) + group.add(member) + group.transitionTo(PreparingRebalance) + group.addPendingSyncMember(memberId) + assertEquals(Set(memberId), group.allPendingSyncMembers) + group.initNextGeneration() + assertEquals(Set(), group.allPendingSyncMembers) + } + + private def assertState(group: GroupMetadata, targetState: GroupState): Unit = { + val states: Set[GroupState] = Set(Stable, PreparingRebalance, CompletingRebalance, Dead) + val otherStates = states - targetState + otherStates.foreach { otherState => + assertFalse(group.is(otherState)) + } + assertTrue(group.is(targetState)) + } + + private def offsetAndMetadata(offset: Long): OffsetAndMetadata = { + OffsetAndMetadata(offset, "", Time.SYSTEM.milliseconds()) + } + +} diff --git a/core/src/test/scala/unit/kafka/coordinator/group/MemberMetadataTest.scala b/core/src/test/scala/unit/kafka/coordinator/group/MemberMetadataTest.scala new file mode 100644 index 0000000..a2b8023 --- /dev/null +++ b/core/src/test/scala/unit/kafka/coordinator/group/MemberMetadataTest.scala @@ -0,0 +1,94 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package kafka.coordinator.group + +import java.util.Arrays + +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.Test + +class MemberMetadataTest { + val groupId = "groupId" + val groupInstanceId = Some("groupInstanceId") + val clientId = "clientId" + val clientHost = "clientHost" + val memberId = "memberId" + val protocolType = "consumer" + val rebalanceTimeoutMs = 60000 + val sessionTimeoutMs = 10000 + + + @Test + def testMatchesSupportedProtocols(): Unit = { + val protocols = List(("range", Array.empty[Byte])) + + val member = new MemberMetadata(memberId, groupInstanceId, clientId, clientHost, rebalanceTimeoutMs, sessionTimeoutMs, + protocolType, protocols) + assertTrue(member.matches(protocols)) + assertFalse(member.matches(List(("range", Array[Byte](0))))) + assertFalse(member.matches(List(("roundrobin", Array.empty[Byte])))) + assertFalse(member.matches(List(("range", Array.empty[Byte]), ("roundrobin", Array.empty[Byte])))) + } + + @Test + def testVoteForPreferredProtocol(): Unit = { + val protocols = List(("range", Array.empty[Byte]), ("roundrobin", Array.empty[Byte])) + + val member = new MemberMetadata(memberId, groupInstanceId, clientId, clientHost, rebalanceTimeoutMs, sessionTimeoutMs, + protocolType, protocols) + assertEquals("range", member.vote(Set("range", "roundrobin"))) + assertEquals("roundrobin", member.vote(Set("blah", "roundrobin"))) + } + + @Test + def testMetadata(): Unit = { + val protocols = List(("range", Array[Byte](0)), ("roundrobin", Array[Byte](1))) + + val member = new MemberMetadata(memberId, groupInstanceId, clientId, clientHost, rebalanceTimeoutMs, sessionTimeoutMs, + protocolType, protocols) + assertTrue(Arrays.equals(Array[Byte](0), member.metadata("range"))) + assertTrue(Arrays.equals(Array[Byte](1), member.metadata("roundrobin"))) + } + + @Test + def testMetadataRaisesOnUnsupportedProtocol(): Unit = { + val protocols = List(("range", Array.empty[Byte]), ("roundrobin", Array.empty[Byte])) + + val member = new MemberMetadata(memberId, groupInstanceId, clientId, clientHost, rebalanceTimeoutMs, sessionTimeoutMs, + protocolType, protocols) + assertThrows(classOf[IllegalArgumentException], () => member.metadata("blah")) + } + + @Test + def testVoteRaisesOnNoSupportedProtocols(): Unit = { + val protocols = List(("range", Array.empty[Byte]), ("roundrobin", Array.empty[Byte])) + + val member = new MemberMetadata(memberId, groupInstanceId, clientId, clientHost, rebalanceTimeoutMs, sessionTimeoutMs, + protocolType, protocols) + assertThrows(classOf[IllegalArgumentException], () => member.vote(Set("blah"))) + } + + @Test + def testHasValidGroupInstanceId(): Unit = { + val protocols = List(("range", Array[Byte](0)), ("roundrobin", Array[Byte](1))) + + val member = new MemberMetadata(memberId, groupInstanceId, clientId, clientHost, rebalanceTimeoutMs, sessionTimeoutMs, + protocolType, protocols) + assertTrue(member.isStaticMember) + assertEquals(groupInstanceId, member.groupInstanceId) + } +} diff --git a/core/src/test/scala/unit/kafka/coordinator/transaction/ProducerIdManagerTest.scala b/core/src/test/scala/unit/kafka/coordinator/transaction/ProducerIdManagerTest.scala new file mode 100644 index 0000000..9232bf0 --- /dev/null +++ b/core/src/test/scala/unit/kafka/coordinator/transaction/ProducerIdManagerTest.scala @@ -0,0 +1,144 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package kafka.coordinator.transaction + +import kafka.server.BrokerToControllerChannelManager +import kafka.zk.{KafkaZkClient, ProducerIdBlockZNode} +import org.apache.kafka.common.KafkaException +import org.apache.kafka.common.message.AllocateProducerIdsResponseData +import org.apache.kafka.common.protocol.Errors +import org.apache.kafka.common.requests.AllocateProducerIdsResponse +import org.apache.kafka.server.common.ProducerIdsBlock +import org.easymock.{Capture, EasyMock} +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.Test +import org.junit.jupiter.params.ParameterizedTest +import org.junit.jupiter.params.provider.{EnumSource, ValueSource} + +import java.util.stream.IntStream + +class ProducerIdManagerTest { + + var brokerToController: BrokerToControllerChannelManager = EasyMock.niceMock(classOf[BrokerToControllerChannelManager]) + val zkClient: KafkaZkClient = EasyMock.createNiceMock(classOf[KafkaZkClient]) + + // Mutable test implementation that lets us easily set the idStart and error + class MockProducerIdManager(val brokerId: Int, var idStart: Long, val idLen: Int, var error: Errors = Errors.NONE) + extends RPCProducerIdManager(brokerId, () => 1, brokerToController, 100) { + + override private[transaction] def sendRequest(): Unit = { + if (error == Errors.NONE) { + handleAllocateProducerIdsResponse(new AllocateProducerIdsResponse( + new AllocateProducerIdsResponseData().setProducerIdStart(idStart).setProducerIdLen(idLen))) + idStart += idLen + } else { + handleAllocateProducerIdsResponse(new AllocateProducerIdsResponse( + new AllocateProducerIdsResponseData().setErrorCode(error.code))) + } + } + } + + @Test + def testGetProducerIdZk(): Unit = { + var zkVersion: Option[Int] = None + var data: Array[Byte] = null + EasyMock.expect(zkClient.getDataAndVersion(EasyMock.anyString)).andAnswer(() => + zkVersion.map(Some(data) -> _).getOrElse(None, 0)).anyTimes() + + val capturedVersion: Capture[Int] = EasyMock.newCapture() + val capturedData: Capture[Array[Byte]] = EasyMock.newCapture() + EasyMock.expect(zkClient.conditionalUpdatePath(EasyMock.anyString(), + EasyMock.capture(capturedData), + EasyMock.capture(capturedVersion), + EasyMock.anyObject[Option[(KafkaZkClient, String, Array[Byte]) => (Boolean, Int)]]) + ).andAnswer(() => { + val newZkVersion = capturedVersion.getValue + 1 + zkVersion = Some(newZkVersion) + data = capturedData.getValue + (true, newZkVersion) + }).anyTimes() + + EasyMock.replay(zkClient) + + val manager1 = new ZkProducerIdManager(0, zkClient) + val manager2 = new ZkProducerIdManager(1, zkClient) + + val pid1 = manager1.generateProducerId() + val pid2 = manager2.generateProducerId() + + assertEquals(0, pid1) + assertEquals(ProducerIdsBlock.PRODUCER_ID_BLOCK_SIZE, pid2) + + for (i <- 1L until ProducerIdsBlock.PRODUCER_ID_BLOCK_SIZE) + assertEquals(pid1 
+ i, manager1.generateProducerId()) + + for (i <- 1L until ProducerIdsBlock.PRODUCER_ID_BLOCK_SIZE) + assertEquals(pid2 + i, manager2.generateProducerId()) + + assertEquals(pid2 + ProducerIdsBlock.PRODUCER_ID_BLOCK_SIZE, manager1.generateProducerId()) + assertEquals(pid2 + ProducerIdsBlock.PRODUCER_ID_BLOCK_SIZE * 2, manager2.generateProducerId()) + + EasyMock.reset(zkClient) + } + + @Test + def testExceedProducerIdLimitZk(): Unit = { + EasyMock.expect(zkClient.getDataAndVersion(EasyMock.anyString)).andAnswer(() => { + val json = ProducerIdBlockZNode.generateProducerIdBlockJson( + new ProducerIdsBlock(0, Long.MaxValue - ProducerIdsBlock.PRODUCER_ID_BLOCK_SIZE, ProducerIdsBlock.PRODUCER_ID_BLOCK_SIZE)) + (Some(json), 0) + }).anyTimes() + EasyMock.replay(zkClient) + assertThrows(classOf[KafkaException], () => new ZkProducerIdManager(0, zkClient)) + } + + @ParameterizedTest + @ValueSource(ints = Array(1, 2, 10)) + def testContiguousIds(idBlockLen: Int): Unit = { + val manager = new MockProducerIdManager(0, 0, idBlockLen) + + IntStream.range(0, idBlockLen * 3).forEach { i => + assertEquals(i, manager.generateProducerId()) + } + } + + @ParameterizedTest + @EnumSource(value = classOf[Errors], names = Array("UNKNOWN_SERVER_ERROR", "INVALID_REQUEST")) + def testUnrecoverableErrors(error: Errors): Unit = { + val manager = new MockProducerIdManager(0, 0, 1) + assertEquals(0, manager.generateProducerId()) + + manager.error = error + assertThrows(classOf[Throwable], () => manager.generateProducerId()) + + manager.error = Errors.NONE + assertEquals(1, manager.generateProducerId()) + } + + @Test + def testInvalidRanges(): Unit = { + var manager = new MockProducerIdManager(0, -1, 10) + assertThrows(classOf[KafkaException], () => manager.generateProducerId()) + + manager = new MockProducerIdManager(0, 0, -1) + assertThrows(classOf[KafkaException], () => manager.generateProducerId()) + + manager = new MockProducerIdManager(0, Long.MaxValue-1, 10) + assertThrows(classOf[KafkaException], () => manager.generateProducerId()) + } +} + diff --git a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorConcurrencyTest.scala b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorConcurrencyTest.scala new file mode 100644 index 0000000..85778d5 --- /dev/null +++ b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorConcurrencyTest.scala @@ -0,0 +1,630 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package kafka.coordinator.transaction + +import java.nio.ByteBuffer +import java.util.Collections +import java.util.concurrent.atomic.AtomicBoolean + +import kafka.coordinator.AbstractCoordinatorConcurrencyTest +import kafka.coordinator.AbstractCoordinatorConcurrencyTest._ +import kafka.coordinator.transaction.TransactionCoordinatorConcurrencyTest._ +import kafka.log.{UnifiedLog, LogConfig} +import kafka.server.{FetchDataInfo, FetchLogEnd, KafkaConfig, LogOffsetMetadata, MetadataCache, RequestLocal} +import kafka.utils.{Pool, TestUtils} +import org.apache.kafka.clients.{ClientResponse, NetworkClient} +import org.apache.kafka.common.internals.Topic.TRANSACTION_STATE_TOPIC_NAME +import org.apache.kafka.common.metrics.Metrics +import org.apache.kafka.common.protocol.{ApiKeys, Errors} +import org.apache.kafka.common.record.{CompressionType, FileRecords, MemoryRecords, RecordBatch, SimpleRecord} +import org.apache.kafka.common.requests._ +import org.apache.kafka.common.utils.{LogContext, MockTime, ProducerIdAndEpoch} +import org.apache.kafka.common.{Node, TopicPartition} +import org.easymock.{EasyMock, IAnswer} +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.{AfterEach, BeforeEach, Test} + +import scala.jdk.CollectionConverters._ +import scala.collection.{Map, mutable} + +class TransactionCoordinatorConcurrencyTest extends AbstractCoordinatorConcurrencyTest[Transaction] { + private val nTransactions = nThreads * 10 + private val coordinatorEpoch = 10 + private val numPartitions = nThreads * 5 + + private val txnConfig = TransactionConfig() + private var transactionCoordinator: TransactionCoordinator = _ + private var txnStateManager: TransactionStateManager = _ + private var txnMarkerChannelManager: TransactionMarkerChannelManager = _ + + private val allOperations = Seq( + new InitProducerIdOperation, + new AddPartitionsToTxnOperation(Set(new TopicPartition("topic", 0))), + new EndTxnOperation) + + private val allTransactions = mutable.Set[Transaction]() + private val txnRecordsByPartition: Map[Int, mutable.ArrayBuffer[SimpleRecord]] = + (0 until numPartitions).map { i => (i, mutable.ArrayBuffer[SimpleRecord]()) }.toMap + + val producerId = 11 + private var bumpProducerId = false + + @BeforeEach + override def setUp(): Unit = { + super.setUp() + + EasyMock.expect(zkClient.getTopicPartitionCount(TRANSACTION_STATE_TOPIC_NAME)) + .andReturn(Some(numPartitions)) + .anyTimes() + EasyMock.replay(zkClient) + + txnStateManager = new TransactionStateManager(0, scheduler, replicaManager, txnConfig, time, + new Metrics()) + txnStateManager.startup(() => zkClient.getTopicPartitionCount(TRANSACTION_STATE_TOPIC_NAME).get, + enableTransactionalIdExpiration = true) + for (i <- 0 until numPartitions) + txnStateManager.addLoadedTransactionsToCache(i, coordinatorEpoch, new Pool[String, TransactionMetadata]()) + + val pidGenerator: ProducerIdManager = EasyMock.createNiceMock(classOf[ProducerIdManager]) + EasyMock.expect(pidGenerator.generateProducerId()) + .andAnswer(() => if (bumpProducerId) producerId + 1 else producerId) + .anyTimes() + val brokerNode = new Node(0, "host", 10) + val metadataCache: MetadataCache = EasyMock.createNiceMock(classOf[MetadataCache]) + EasyMock.expect(metadataCache.getPartitionLeaderEndpoint( + EasyMock.anyString(), + EasyMock.anyInt(), + EasyMock.anyObject()) + ).andReturn(Some(brokerNode)).anyTimes() + val networkClient: NetworkClient = EasyMock.createNiceMock(classOf[NetworkClient]) + txnMarkerChannelManager = new TransactionMarkerChannelManager( + 
KafkaConfig.fromProps(serverProps), + metadataCache, + networkClient, + txnStateManager, + time) + + transactionCoordinator = new TransactionCoordinator(brokerId = 0, + txnConfig, + scheduler, + () => pidGenerator, + txnStateManager, + txnMarkerChannelManager, + time, + new LogContext) + EasyMock.replay(pidGenerator) + EasyMock.replay(metadataCache) + EasyMock.replay(networkClient) + } + + @AfterEach + override def tearDown(): Unit = { + try { + EasyMock.reset(zkClient, replicaManager) + transactionCoordinator.shutdown() + } finally { + super.tearDown() + } + } + + @Test + def testConcurrentGoodPathWithConcurrentPartitionLoading(): Unit = { + // This is a somewhat contrived test case which reproduces the bug in KAFKA-9777. + // When a new partition needs to be loaded, we acquire the write lock in order to + // add the partition to the set of loading partitions. We should still be able to + // make progress with transactions even while this is ongoing. + + val keepRunning = new AtomicBoolean(true) + val t = new Thread() { + override def run(): Unit = { + while (keepRunning.get()) { + txnStateManager.addLoadingPartition(numPartitions + 1, coordinatorEpoch) + } + } + } + t.start() + + verifyConcurrentOperations(createTransactions, allOperations) + keepRunning.set(false) + t.join() + } + + @Test + def testConcurrentGoodPathSequence(): Unit = { + verifyConcurrentOperations(createTransactions, allOperations) + } + + @Test + def testConcurrentRandomSequences(): Unit = { + verifyConcurrentRandomSequences(createTransactions, allOperations) + } + + /** + * Concurrently load one set of transaction state topic partitions and unload another + * set of partitions. This tests partition leader changes of transaction state topic + * that are handled by different threads concurrently. Verifies that the metadata of + * unloaded partitions are removed from the transaction manager and that the transactions + * from the newly loaded partitions are loaded correctly. + */ + @Test + def testConcurrentLoadUnloadPartitions(): Unit = { + val partitionsToLoad = (0 until numPartitions / 2).toSet + val partitionsToUnload = (numPartitions / 2 until numPartitions).toSet + verifyConcurrentActions(loadUnloadActions(partitionsToLoad, partitionsToUnload)) + } + + /** + * Concurrently load one set of transaction state topic partitions, unload a second set + * of partitions and expire transactions on a third set of partitions. This tests partition + * leader changes of transaction state topic that are handled by different threads concurrently + * while expiry is performed on another thread. Verifies the state of transactions on all the partitions. 
+ */ + @Test + def testConcurrentTransactionExpiration(): Unit = { + val partitionsToLoad = (0 until numPartitions / 3).toSet + val partitionsToUnload = (numPartitions / 3 until numPartitions * 2 / 3).toSet + val partitionsWithExpiringTxn = (numPartitions * 2 / 3 until numPartitions).toSet + val expiringTransactions = allTransactions.filter { txn => + partitionsWithExpiringTxn.contains(txnStateManager.partitionFor(txn.transactionalId)) + }.toSet + val expireAction = new ExpireTransactionsAction(expiringTransactions) + verifyConcurrentActions(loadUnloadActions(partitionsToLoad, partitionsToUnload) + expireAction) + } + + @Test + def testConcurrentNewInitProducerIdRequests(): Unit = { + val transactions = (1 to 100).flatMap(i => createTransactions(s"testConcurrentInitProducerID$i-")) + bumpProducerId = true + transactions.foreach { txn => + val txnMetadata = prepareExhaustedEpochTxnMetadata(txn) + txnStateManager.putTransactionStateIfNotExists(txnMetadata) + + // Test simultaneous requests from an existing producer trying to bump the epoch and a new producer initializing + val newProducerOp1 = new InitProducerIdOperation() + val newProducerOp2 = new InitProducerIdOperation() + verifyConcurrentActions(Set(newProducerOp1, newProducerOp2).map(_.actionNoVerify(txn))) + + // If only one request succeeds, assert that the epoch was successfully increased + // If both requests succeed, the new producer must have run after the existing one and should have the higher epoch + (newProducerOp1.result.get.error, newProducerOp2.result.get.error) match { + case (Errors.NONE, Errors.NONE) => + assertNotEquals(newProducerOp1.result.get.producerEpoch, newProducerOp2.result.get.producerEpoch) + // assertEquals(0, newProducerOp1.result.get.producerEpoch) + // assertEquals(0, newProducerOp2.result.get.producerEpoch) + case (Errors.NONE, _) => + assertEquals(0, newProducerOp1.result.get.producerEpoch) + case (_, Errors.NONE) => + assertEquals(0, newProducerOp2.result.get.producerEpoch) + case (_, _) => fail("One of two InitProducerId requests should succeed") + } + } + } + + @Test + def testConcurrentInitProducerIdRequestsOneNewOneContinuing(): Unit = { + val transactions = (1 to 10).flatMap(i => createTransactions(s"testConcurrentInitProducerID$i-")) + transactions.foreach { txn => + val firstInitReq = new InitProducerIdOperation() + firstInitReq.run(txn) + firstInitReq.awaitAndVerify(txn) + + // Test simultaneous requests from an existing producer trying to bump the epoch and a new producer initializing + val producerIdAndEpoch = new ProducerIdAndEpoch(firstInitReq.result.get.producerId, firstInitReq.result.get.producerEpoch) + val bumpEpochOp = new InitProducerIdOperation(Some(producerIdAndEpoch)) + val newProducerOp = new InitProducerIdOperation() + verifyConcurrentActions(Set(bumpEpochOp, newProducerOp).map(_.actionNoVerify(txn))) + + // If only one request succeeds, assert that the epoch was successfully increased + // If both requests succeed, the new producer must have run after the existing one and should have the higher epoch + (bumpEpochOp.result.get.error, newProducerOp.result.get.error) match { + case (Errors.NONE, Errors.NONE) => + assertEquals(producerIdAndEpoch.epoch + 2, newProducerOp.result.get.producerEpoch) + assertEquals(producerIdAndEpoch.epoch + 1, bumpEpochOp.result.get.producerEpoch) + case (Errors.NONE, _) => + assertEquals(producerIdAndEpoch.epoch + 1, bumpEpochOp.result.get.producerEpoch) + case (_, Errors.NONE) => + assertEquals(producerIdAndEpoch.epoch + 1, 
newProducerOp.result.get.producerEpoch) + case (_, _) => fail("One of two InitProducerId requests should succeed") + } + } + } + + @Test + def testConcurrentContinuingInitProducerIdRequests(): Unit = { + val transactions = (1 to 100).flatMap(i => createTransactions(s"testConcurrentInitProducerID$i-")) + transactions.foreach { txn => + // Test simultaneous requests from an existing producers trying to re-initialize when no state is present + val producerIdAndEpoch = new ProducerIdAndEpoch(producerId, 10) + val bumpEpochOp1 = new InitProducerIdOperation(Some(producerIdAndEpoch)) + val bumpEpochOp2 = new InitProducerIdOperation(Some(producerIdAndEpoch)) + verifyConcurrentActions(Set(bumpEpochOp1, bumpEpochOp2).map(_.actionNoVerify(txn))) + + // If only one request succeeds, assert that the epoch was successfully increased + // If both requests succeed, the new producer must have run after the existing one and should have the higher epoch + (bumpEpochOp1.result.get.error, bumpEpochOp2.result.get.error) match { + case (Errors.NONE, Errors.NONE) => + fail("One of two InitProducerId requests should fail due to concurrent requests or non-matching epochs") + case (Errors.NONE, _) => + assertEquals(0, bumpEpochOp1.result.get.producerEpoch) + case (_, Errors.NONE) => + assertEquals(0, bumpEpochOp2.result.get.producerEpoch) + case (_, _) => fail("One of two InitProducerId requests should succeed") + } + } + } + + @Test + def testConcurrentInitProducerIdRequestsWithRetry(): Unit = { + val transactions = (1 to 10).flatMap(i => createTransactions(s"testConcurrentInitProducerID$i-")) + transactions.foreach { txn => + val firstInitReq = new InitProducerIdOperation() + firstInitReq.run(txn) + firstInitReq.awaitAndVerify(txn) + + val initialProducerIdAndEpoch = new ProducerIdAndEpoch(firstInitReq.result.get.producerId, firstInitReq.result.get.producerEpoch) + val bumpEpochReq = new InitProducerIdOperation(Some(initialProducerIdAndEpoch)) + bumpEpochReq.run(txn) + bumpEpochReq.awaitAndVerify(txn) + + // Test simultaneous requests from an existing producer retrying the epoch bump and a new producer initializing + val bumpedProducerIdAndEpoch = new ProducerIdAndEpoch(bumpEpochReq.result.get.producerId, bumpEpochReq.result.get.producerEpoch) + val retryBumpEpochOp = new InitProducerIdOperation(Some(initialProducerIdAndEpoch)) + val newProducerOp = new InitProducerIdOperation() + verifyConcurrentActions(Set(retryBumpEpochOp, newProducerOp).map(_.actionNoVerify(txn))) + + // If both requests succeed, the new producer must have run after the existing one and should have the higher epoch + // If the retry succeeds and the new producer doesn't, assert that the already-bumped epoch was returned + // If the new producer succeeds and the retry doesn't, assert the epoch was bumped + (retryBumpEpochOp.result.get.error, newProducerOp.result.get.error) match { + case (Errors.NONE, Errors.NONE) => + assertEquals(bumpedProducerIdAndEpoch.epoch + 1, newProducerOp.result.get.producerEpoch) + assertEquals(bumpedProducerIdAndEpoch.epoch, retryBumpEpochOp.result.get.producerEpoch) + case (Errors.NONE, _) => + assertEquals(bumpedProducerIdAndEpoch.epoch, retryBumpEpochOp.result.get.producerEpoch) + case (_, Errors.NONE) => + assertEquals(bumpedProducerIdAndEpoch.epoch + 1, newProducerOp.result.get.producerEpoch) + case (_, _) => fail("At least one InitProducerId request should succeed") + } + } + } + + @Test + def testConcurrentInitProducerRequestsAtPidBoundary(): Unit = { + val transactions = (1 to 10).flatMap(i => 
createTransactions(s"testConcurrentInitProducerID$i-")) + bumpProducerId = true + transactions.foreach { txn => + val txnMetadata = prepareExhaustedEpochTxnMetadata(txn) + txnStateManager.putTransactionStateIfNotExists(txnMetadata) + + // Test simultaneous requests from an existing producer attempting to bump the epoch and a new producer initializing + val bumpEpochOp = new InitProducerIdOperation(Some(new ProducerIdAndEpoch(producerId, (Short.MaxValue - 1).toShort))) + val newProducerOp = new InitProducerIdOperation() + verifyConcurrentActions(Set(bumpEpochOp, newProducerOp).map(_.actionNoVerify(txn))) + + // If the retry succeeds and the new producer doesn't, assert that the already-bumped epoch was returned + // If the new producer succeeds and the retry doesn't, assert the epoch was bumped + // If both requests succeed, the new producer must have run after the existing one and should have the higher epoch + (bumpEpochOp.result.get.error, newProducerOp.result.get.error) match { + case (Errors.NONE, Errors.NONE) => + assertEquals(0, bumpEpochOp.result.get.producerEpoch) + assertEquals(producerId + 1, bumpEpochOp.result.get.producerId) + + assertEquals(1, newProducerOp.result.get.producerEpoch) + assertEquals(producerId + 1, newProducerOp.result.get.producerId) + case (Errors.NONE, _) => + assertEquals(0, bumpEpochOp.result.get.producerEpoch) + assertEquals(producerId + 1, bumpEpochOp.result.get.producerId) + case (_, Errors.NONE) => + assertEquals(0, newProducerOp.result.get.producerEpoch) + assertEquals(producerId + 1, newProducerOp.result.get.producerId) + case (_, _) => fail("One of two InitProducerId requests should succeed") + } + } + + bumpProducerId = false + } + + @Test + def testConcurrentInitProducerRequestsWithRetryAtPidBoundary(): Unit = { + val transactions = (1 to 10).flatMap(i => createTransactions(s"testConcurrentInitProducerID$i-")) + bumpProducerId = true + transactions.foreach { txn => + val txnMetadata = prepareExhaustedEpochTxnMetadata(txn) + txnStateManager.putTransactionStateIfNotExists(txnMetadata) + + val bumpEpochReq = new InitProducerIdOperation(Some(new ProducerIdAndEpoch(producerId, (Short.MaxValue - 1).toShort))) + bumpEpochReq.run(txn) + bumpEpochReq.awaitAndVerify(txn) + + // Test simultaneous requests from an existing producer attempting to bump the epoch and a new producer initializing + val retryBumpEpochOp = new InitProducerIdOperation(Some(new ProducerIdAndEpoch(producerId, (Short.MaxValue - 1).toShort))) + val newProducerOp = new InitProducerIdOperation() + verifyConcurrentActions(Set(retryBumpEpochOp, newProducerOp).map(_.actionNoVerify(txn))) + + // If the retry succeeds and the new producer doesn't, assert that the already-bumped epoch was returned + // If the new producer succeeds and the retry doesn't, assert the epoch was bumped + // If both requests succeed, the new producer must have run after the existing one and should have the higher epoch + (retryBumpEpochOp.result.get.error, newProducerOp.result.get.error) match { + case (Errors.NONE, Errors.NONE) => + assertEquals(0, retryBumpEpochOp.result.get.producerEpoch) + assertEquals(producerId + 1, retryBumpEpochOp.result.get.producerId) + + assertEquals(1, newProducerOp.result.get.producerEpoch) + assertEquals(producerId + 1, newProducerOp.result.get.producerId) + case (Errors.NONE, _) => + assertEquals(0, retryBumpEpochOp.result.get.producerEpoch) + assertEquals(producerId + 1, retryBumpEpochOp.result.get.producerId) + case (_, Errors.NONE) => + assertEquals(1, 
newProducerOp.result.get.producerEpoch) + assertEquals(producerId + 1, newProducerOp.result.get.producerId) + case (_, _) => fail("One of two InitProducerId requests should succeed") + } + } + + bumpProducerId = false + } + + override def enableCompletion(): Unit = { + super.enableCompletion() + + def createResponse(request: WriteTxnMarkersRequest): WriteTxnMarkersResponse = { + val pidErrorMap = request.markers.asScala.map { marker => + (marker.producerId.asInstanceOf[java.lang.Long], marker.partitions.asScala.map { tp => (tp, Errors.NONE) }.toMap.asJava) + }.toMap.asJava + new WriteTxnMarkersResponse(pidErrorMap) + } + synchronized { + txnMarkerChannelManager.generateRequests().foreach { requestAndHandler => + val request = requestAndHandler.request.asInstanceOf[WriteTxnMarkersRequest.Builder].build() + val response = createResponse(request) + requestAndHandler.handler.onComplete(new ClientResponse(new RequestHeader(ApiKeys.PRODUCE, 0, "client", 1), + null, null, 0, 0, false, null, null, response)) + } + } + } + + /** + * Concurrently load `partitionsToLoad` and unload `partitionsToUnload`. Before the concurrent operations + * are run `partitionsToLoad` must be unloaded first since all partitions were loaded during setUp. + */ + private def loadUnloadActions(partitionsToLoad: Set[Int], partitionsToUnload: Set[Int]): Set[Action] = { + val transactions = (1 to 10).flatMap(i => createTransactions(s"testConcurrentLoadUnloadPartitions$i-")).toSet + transactions.foreach(txn => prepareTransaction(txn)) + val unload = partitionsToLoad.map(new UnloadTxnPartitionAction(_)) + unload.foreach(_.run()) + unload.foreach(_.await()) + partitionsToLoad.map(new LoadTxnPartitionAction(_)) ++ partitionsToUnload.map(new UnloadTxnPartitionAction(_)) + } + + private def createTransactions(txnPrefix: String): Set[Transaction] = { + val transactions = (0 until nTransactions).map { i => new Transaction(s"$txnPrefix$i", i, time) } + allTransactions ++= transactions + transactions.toSet + } + + private def verifyTransaction(txn: Transaction, expectedState: TransactionState): Unit = { + val (metadata, success) = TestUtils.computeUntilTrue({ + enableCompletion() + transactionMetadata(txn) + })(metadata => metadata.nonEmpty && metadata.forall(m => m.state == expectedState && m.pendingState.isEmpty)) + assertTrue(success, s"Invalid metadata state $metadata") + } + + private def transactionMetadata(txn: Transaction): Option[TransactionMetadata] = { + txnStateManager.getTransactionState(txn.transactionalId) match { + case Left(error) => + if (error == Errors.NOT_COORDINATOR) + None + else + throw new AssertionError(s"Unexpected transaction error $error for $txn") + case Right(Some(metadata)) => + Some(metadata.transactionMetadata) + case Right(None) => + None + } + } + + private def prepareTransaction(txn: Transaction): Unit = { + val partitionId = txnStateManager.partitionFor(txn.transactionalId) + val txnRecords = txnRecordsByPartition(partitionId) + val initPidOp = new InitProducerIdOperation() + val addPartitionsOp = new AddPartitionsToTxnOperation(Set(new TopicPartition("topic", 0))) + initPidOp.run(txn) + initPidOp.awaitAndVerify(txn) + addPartitionsOp.run(txn) + addPartitionsOp.awaitAndVerify(txn) + + val txnMetadata = transactionMetadata(txn).getOrElse(throw new IllegalStateException(s"Transaction not found $txn")) + txnRecords += new SimpleRecord(txn.txnMessageKeyBytes, TransactionLog.valueToBytes(txnMetadata.prepareNoTransit())) + + txnMetadata.state = PrepareCommit + txnRecords += new 
SimpleRecord(txn.txnMessageKeyBytes, TransactionLog.valueToBytes(txnMetadata.prepareNoTransit())) + + prepareTxnLog(partitionId) + } + + private def prepareTxnLog(partitionId: Int): Unit = { + val logMock: UnifiedLog = EasyMock.mock(classOf[UnifiedLog]) + EasyMock.expect(logMock.config).andStubReturn(new LogConfig(Collections.emptyMap())) + + val fileRecordsMock: FileRecords = EasyMock.mock(classOf[FileRecords]) + + val topicPartition = new TopicPartition(TRANSACTION_STATE_TOPIC_NAME, partitionId) + val startOffset = replicaManager.getLogEndOffset(topicPartition).getOrElse(20L) + val records = MemoryRecords.withRecords(startOffset, CompressionType.NONE, txnRecordsByPartition(partitionId).toArray: _*) + val endOffset = startOffset + records.records.asScala.size + + EasyMock.expect(logMock.logStartOffset).andStubReturn(startOffset) + EasyMock.expect(logMock.read(EasyMock.eq(startOffset), + maxLength = EasyMock.anyInt(), + isolation = EasyMock.eq(FetchLogEnd), + minOneMessage = EasyMock.eq(true))) + .andReturn(FetchDataInfo(LogOffsetMetadata(startOffset), fileRecordsMock)) + + EasyMock.expect(fileRecordsMock.sizeInBytes()).andStubReturn(records.sizeInBytes) + + val bufferCapture = EasyMock.newCapture[ByteBuffer] + fileRecordsMock.readInto(EasyMock.capture(bufferCapture), EasyMock.anyInt()) + EasyMock.expectLastCall().andAnswer(new IAnswer[Unit] { + override def answer: Unit = { + val buffer = bufferCapture.getValue + buffer.put(records.buffer.duplicate) + buffer.flip() + } + }) + + EasyMock.replay(logMock, fileRecordsMock) + synchronized { + replicaManager.updateLog(topicPartition, logMock, endOffset) + } + } + + private def prepareExhaustedEpochTxnMetadata(txn: Transaction): TransactionMetadata = { + new TransactionMetadata(transactionalId = txn.transactionalId, + producerId = producerId, + lastProducerId = RecordBatch.NO_PRODUCER_ID, + producerEpoch = (Short.MaxValue - 1).toShort, + lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH, + txnTimeoutMs = 60000, + state = Empty, + topicPartitions = collection.mutable.Set.empty[TopicPartition], + txnLastUpdateTimestamp = time.milliseconds()) + } + + abstract class TxnOperation[R] extends Operation { + @volatile var result: Option[R] = None + def resultCallback(r: R): Unit = this.result = Some(r) + } + + class InitProducerIdOperation(val producerIdAndEpoch: Option[ProducerIdAndEpoch] = None) extends TxnOperation[InitProducerIdResult] { + override def run(txn: Transaction): Unit = { + transactionCoordinator.handleInitProducerId(txn.transactionalId, 60000, producerIdAndEpoch, resultCallback, + RequestLocal.withThreadConfinedCaching) + replicaManager.tryCompleteActions() + } + override def awaitAndVerify(txn: Transaction): Unit = { + val initPidResult = result.getOrElse(throw new IllegalStateException("InitProducerId has not completed")) + assertEquals(Errors.NONE, initPidResult.error) + verifyTransaction(txn, Empty) + } + } + + class AddPartitionsToTxnOperation(partitions: Set[TopicPartition]) extends TxnOperation[Errors] { + override def run(txn: Transaction): Unit = { + transactionMetadata(txn).foreach { txnMetadata => + transactionCoordinator.handleAddPartitionsToTransaction(txn.transactionalId, + txnMetadata.producerId, + txnMetadata.producerEpoch, + partitions, + resultCallback, + RequestLocal.withThreadConfinedCaching) + replicaManager.tryCompleteActions() + } + } + override def awaitAndVerify(txn: Transaction): Unit = { + val error = result.getOrElse(throw new IllegalStateException("AddPartitionsToTransaction has not completed")) + 
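+      // A successful AddPartitionsToTxn should leave the transaction in the Ongoing state, which verifyTransaction checks below.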
assertEquals(Errors.NONE, error) + verifyTransaction(txn, Ongoing) + } + } + + class EndTxnOperation extends TxnOperation[Errors] { + override def run(txn: Transaction): Unit = { + transactionMetadata(txn).foreach { txnMetadata => + transactionCoordinator.handleEndTransaction(txn.transactionalId, + txnMetadata.producerId, + txnMetadata.producerEpoch, + transactionResult(txn), + resultCallback, + RequestLocal.withThreadConfinedCaching) + } + } + override def awaitAndVerify(txn: Transaction): Unit = { + val error = result.getOrElse(throw new IllegalStateException("EndTransaction has not completed")) + if (!txn.ended) { + txn.ended = true + assertEquals(Errors.NONE, error) + val expectedState = if (transactionResult(txn) == TransactionResult.COMMIT) CompleteCommit else CompleteAbort + verifyTransaction(txn, expectedState) + } else + assertEquals(Errors.INVALID_TXN_STATE, error) + } + // Test both commit and abort. Transactional ids used in the test have the format + // Use the last digit of the index to decide between commit and abort. + private def transactionResult(txn: Transaction): TransactionResult = { + val txnId = txn.transactionalId + val lastDigit = txnId(txnId.length - 1).toInt + if (lastDigit % 2 == 0) TransactionResult.COMMIT else TransactionResult.ABORT + } + } + + class LoadTxnPartitionAction(txnTopicPartitionId: Int) extends Action { + override def run(): Unit = { + transactionCoordinator.onElection(txnTopicPartitionId, coordinatorEpoch) + } + override def await(): Unit = { + allTransactions.foreach { txn => + if (txnStateManager.partitionFor(txn.transactionalId) == txnTopicPartitionId) { + verifyTransaction(txn, CompleteCommit) + } + } + } + } + + class UnloadTxnPartitionAction(txnTopicPartitionId: Int) extends Action { + val txnRecords: mutable.ArrayBuffer[SimpleRecord] = mutable.ArrayBuffer[SimpleRecord]() + override def run(): Unit = { + transactionCoordinator.onResignation(txnTopicPartitionId, Some(coordinatorEpoch)) + } + override def await(): Unit = { + allTransactions.foreach { txn => + if (txnStateManager.partitionFor(txn.transactionalId) == txnTopicPartitionId) + assertTrue(transactionMetadata(txn).isEmpty, "Transaction metadata not removed") + } + } + } + + class ExpireTransactionsAction(transactions: Set[Transaction]) extends Action { + override def run(): Unit = { + transactions.foreach { txn => + transactionMetadata(txn).foreach { txnMetadata => + txnMetadata.txnLastUpdateTimestamp = time.milliseconds() - txnConfig.transactionalIdExpirationMs + } + } + txnStateManager.enableTransactionalIdExpiration() + replicaManager.tryCompleteActions() + time.sleep(txnConfig.removeExpiredTransactionalIdsIntervalMs + 1) + } + + override def await(): Unit = { + val (_, success) = TestUtils.computeUntilTrue({ + replicaManager.tryCompleteActions() + transactions.forall(txn => transactionMetadata(txn).isEmpty) + })(identity) + assertTrue(success, "Transaction not expired") + } + } +} + +object TransactionCoordinatorConcurrencyTest { + + class Transaction(val transactionalId: String, producerId: Long, time: MockTime) extends CoordinatorMember { + val txnMessageKeyBytes: Array[Byte] = TransactionLog.keyToBytes(transactionalId) + @volatile var ended = false + override def toString: String = transactionalId + } +} diff --git a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorTest.scala b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorTest.scala new file mode 100644 index 0000000..38e8e71 --- /dev/null +++ 
b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorTest.scala @@ -0,0 +1,1249 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package kafka.coordinator.transaction + +import kafka.utils.MockScheduler +import org.apache.kafka.common.TopicPartition +import org.apache.kafka.common.protocol.Errors +import org.apache.kafka.common.record.RecordBatch +import org.apache.kafka.common.requests.TransactionResult +import org.apache.kafka.common.utils.{LogContext, MockTime, ProducerIdAndEpoch} +import org.easymock.{Capture, EasyMock} +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.Test + +import scala.collection.mutable +import scala.jdk.CollectionConverters._ + +class TransactionCoordinatorTest { + + val time = new MockTime() + + var nextPid: Long = 0L + val pidGenerator: ProducerIdManager = EasyMock.createNiceMock(classOf[ProducerIdManager]) + val transactionManager: TransactionStateManager = EasyMock.createNiceMock(classOf[TransactionStateManager]) + val transactionMarkerChannelManager: TransactionMarkerChannelManager = EasyMock.createNiceMock(classOf[TransactionMarkerChannelManager]) + val capturedTxn: Capture[TransactionMetadata] = EasyMock.newCapture() + val capturedErrorsCallback: Capture[Errors => Unit] = EasyMock.newCapture() + val capturedTxnTransitMetadata: Capture[TxnTransitMetadata] = EasyMock.newCapture() + val brokerId = 0 + val coordinatorEpoch = 0 + private val transactionalId = "known" + private val producerId = 10 + private val producerEpoch: Short = 1 + private val txnTimeoutMs = 1 + + private val partitions = mutable.Set[TopicPartition](new TopicPartition("topic1", 0)) + private val scheduler = new MockScheduler(time) + + val coordinator = new TransactionCoordinator( + brokerId, + TransactionConfig(), + scheduler, + () => pidGenerator, + transactionManager, + transactionMarkerChannelManager, + time, + new LogContext) + val transactionStatePartitionCount = 1 + var result: InitProducerIdResult = _ + var error: Errors = Errors.NONE + + private def mockPidGenerator(): Unit = { + EasyMock.expect(pidGenerator.generateProducerId()).andAnswer(() => { + nextPid += 1 + nextPid - 1 + }).anyTimes() + } + + private def initPidGenericMocks(transactionalId: String): Unit = { + mockPidGenerator() + EasyMock.expect(transactionManager.validateTransactionTimeoutMs(EasyMock.anyInt())) + .andReturn(true) + .anyTimes() + } + + @Test + def shouldReturnInvalidRequestWhenTransactionalIdIsEmpty(): Unit = { + mockPidGenerator() + EasyMock.replay(pidGenerator) + + coordinator.handleInitProducerId("", txnTimeoutMs, None, initProducerIdMockCallback) + assertEquals(InitProducerIdResult(-1L, -1, Errors.INVALID_REQUEST), result) + coordinator.handleInitProducerId("", txnTimeoutMs, None, 
initProducerIdMockCallback) + assertEquals(InitProducerIdResult(-1L, -1, Errors.INVALID_REQUEST), result) + } + + @Test + def shouldAcceptInitPidAndReturnNextPidWhenTransactionalIdIsNull(): Unit = { + mockPidGenerator() + EasyMock.replay(pidGenerator) + + coordinator.handleInitProducerId(null, txnTimeoutMs, None, initProducerIdMockCallback) + assertEquals(InitProducerIdResult(0L, 0, Errors.NONE), result) + coordinator.handleInitProducerId(null, txnTimeoutMs, None, initProducerIdMockCallback) + assertEquals(InitProducerIdResult(1L, 0, Errors.NONE), result) + } + + @Test + def shouldInitPidWithEpochZeroForNewTransactionalId(): Unit = { + initPidGenericMocks(transactionalId) + + EasyMock.expect(transactionManager.getTransactionState(EasyMock.eq(transactionalId))) + .andReturn(Right(None)) + .once() + + EasyMock.expect(transactionManager.putTransactionStateIfNotExists(EasyMock.capture(capturedTxn))) + .andAnswer(() => { + assertTrue(capturedTxn.hasCaptured) + Right(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, capturedTxn.getValue)) + }).once() + + EasyMock.expect(transactionManager.appendTransactionToLog( + EasyMock.eq(transactionalId), + EasyMock.eq(coordinatorEpoch), + EasyMock.anyObject().asInstanceOf[TxnTransitMetadata], + EasyMock.capture(capturedErrorsCallback), + EasyMock.anyObject(), + EasyMock.anyObject()) + ).andAnswer(() => capturedErrorsCallback.getValue.apply(Errors.NONE)).anyTimes() + EasyMock.replay(pidGenerator, transactionManager) + + coordinator.handleInitProducerId(transactionalId, txnTimeoutMs, None, initProducerIdMockCallback) + assertEquals(InitProducerIdResult(nextPid - 1, 0, Errors.NONE), result) + } + + @Test + def shouldGenerateNewProducerIdIfNoStateAndProducerIdAndEpochProvided(): Unit = { + initPidGenericMocks(transactionalId) + + EasyMock.expect(transactionManager.getTransactionState(EasyMock.eq(transactionalId))) + .andReturn(Right(None)) + .once() + + EasyMock.expect(transactionManager.putTransactionStateIfNotExists(EasyMock.capture(capturedTxn))) + .andAnswer(() => { + assertTrue(capturedTxn.hasCaptured) + Right(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, capturedTxn.getValue)) + }).once() + + EasyMock.expect(transactionManager.appendTransactionToLog( + EasyMock.eq(transactionalId), + EasyMock.eq(coordinatorEpoch), + EasyMock.anyObject().asInstanceOf[TxnTransitMetadata], + EasyMock.capture(capturedErrorsCallback), + EasyMock.anyObject(), + EasyMock.anyObject()) + ).andAnswer(() => capturedErrorsCallback.getValue.apply(Errors.NONE)).anyTimes() + EasyMock.replay(pidGenerator, transactionManager) + + coordinator.handleInitProducerId(transactionalId, txnTimeoutMs, Some(new ProducerIdAndEpoch(producerId, producerEpoch)), + initProducerIdMockCallback) + assertEquals(InitProducerIdResult(nextPid - 1, 0, Errors.NONE), result) + } + + @Test + def shouldGenerateNewProducerIdIfEpochsExhausted(): Unit = { + initPidGenericMocks(transactionalId) + + val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, (Short.MaxValue - 1).toShort, + (Short.MaxValue - 2).toShort, txnTimeoutMs, Empty, mutable.Set.empty, time.milliseconds(), time.milliseconds()) + + EasyMock.expect(transactionManager.getTransactionState(EasyMock.eq(transactionalId))) + .andReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata)))) + + EasyMock.expect(transactionManager.appendTransactionToLog( + EasyMock.eq(transactionalId), + EasyMock.eq(coordinatorEpoch), + EasyMock.anyObject().asInstanceOf[TxnTransitMetadata], + 
EasyMock.capture(capturedErrorsCallback), + EasyMock.anyObject(), + EasyMock.anyObject() + )).andAnswer(() => capturedErrorsCallback.getValue.apply(Errors.NONE)) + + EasyMock.replay(pidGenerator, transactionManager) + + coordinator.handleInitProducerId(transactionalId, txnTimeoutMs, None, initProducerIdMockCallback) + assertNotEquals(producerId, result.producerId) + assertEquals(0, result.producerEpoch) + assertEquals(Errors.NONE, result.error) + } + + @Test + def shouldRespondWithNotCoordinatorOnInitPidWhenNotCoordinator(): Unit = { + EasyMock.expect(transactionManager.validateTransactionTimeoutMs(EasyMock.anyInt())) + .andReturn(true) + .anyTimes() + EasyMock.expect(transactionManager.getTransactionState(EasyMock.eq(transactionalId))) + .andReturn(Left(Errors.NOT_COORDINATOR)) + EasyMock.replay(transactionManager) + + coordinator.handleInitProducerId(transactionalId, txnTimeoutMs, None, initProducerIdMockCallback) + assertEquals(InitProducerIdResult(-1, -1, Errors.NOT_COORDINATOR), result) + } + + @Test + def shouldRespondWithCoordinatorLoadInProgressOnInitPidWhenCoordintorLoading(): Unit = { + EasyMock.expect(transactionManager.validateTransactionTimeoutMs(EasyMock.anyInt())) + .andReturn(true) + .anyTimes() + EasyMock.expect(transactionManager.getTransactionState(EasyMock.eq(transactionalId))) + .andReturn(Left(Errors.COORDINATOR_LOAD_IN_PROGRESS)) + EasyMock.replay(transactionManager) + + coordinator.handleInitProducerId(transactionalId, txnTimeoutMs, None, initProducerIdMockCallback) + assertEquals(InitProducerIdResult(-1, -1, Errors.COORDINATOR_LOAD_IN_PROGRESS), result) + } + + @Test + def shouldRespondWithInvalidPidMappingOnAddPartitionsToTransactionWhenTransactionalIdNotPresent(): Unit = { + EasyMock.expect(transactionManager.getTransactionState(EasyMock.eq(transactionalId))) + .andReturn(Right(None)) + EasyMock.replay(transactionManager) + + coordinator.handleAddPartitionsToTransaction(transactionalId, 0L, 1, partitions, errorsCallback) + assertEquals(Errors.INVALID_PRODUCER_ID_MAPPING, error) + } + + @Test + def shouldRespondWithInvalidRequestAddPartitionsToTransactionWhenTransactionalIdIsEmpty(): Unit = { + coordinator.handleAddPartitionsToTransaction("", 0L, 1, partitions, errorsCallback) + assertEquals(Errors.INVALID_REQUEST, error) + } + + @Test + def shouldRespondWithInvalidRequestAddPartitionsToTransactionWhenTransactionalIdIsNull(): Unit = { + coordinator.handleAddPartitionsToTransaction(null, 0L, 1, partitions, errorsCallback) + assertEquals(Errors.INVALID_REQUEST, error) + } + + @Test + def shouldRespondWithNotCoordinatorOnAddPartitionsWhenNotCoordinator(): Unit = { + EasyMock.expect(transactionManager.getTransactionState(EasyMock.eq(transactionalId))) + .andReturn(Left(Errors.NOT_COORDINATOR)) + EasyMock.replay(transactionManager) + + coordinator.handleAddPartitionsToTransaction(transactionalId, 0L, 1, partitions, errorsCallback) + assertEquals(Errors.NOT_COORDINATOR, error) + } + + @Test + def shouldRespondWithCoordinatorLoadInProgressOnAddPartitionsWhenCoordintorLoading(): Unit = { + EasyMock.expect(transactionManager.getTransactionState(EasyMock.eq(transactionalId))) + .andReturn(Left(Errors.COORDINATOR_LOAD_IN_PROGRESS)) + + EasyMock.replay(transactionManager) + + coordinator.handleAddPartitionsToTransaction(transactionalId, 0L, 1, partitions, errorsCallback) + assertEquals(Errors.COORDINATOR_LOAD_IN_PROGRESS, error) + } + + @Test + def shouldRespondWithConcurrentTransactionsOnAddPartitionsWhenStateIsPrepareCommit(): Unit = { + 
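+    // A transaction still completing its commit cannot accept new partitions, so the coordinator reports CONCURRENT_TRANSACTIONS.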
validateConcurrentTransactions(PrepareCommit) + } + + @Test + def shouldRespondWithConcurrentTransactionOnAddPartitionsWhenStateIsPrepareAbort(): Unit = { + validateConcurrentTransactions(PrepareAbort) + } + + def validateConcurrentTransactions(state: TransactionState): Unit = { + EasyMock.expect(transactionManager.getTransactionState(EasyMock.eq(transactionalId))) + .andReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, + new TransactionMetadata(transactionalId, 0, 0, 0, RecordBatch.NO_PRODUCER_EPOCH, 0, state, mutable.Set.empty, 0, 0))))) + + EasyMock.replay(transactionManager) + + coordinator.handleAddPartitionsToTransaction(transactionalId, 0L, 0, partitions, errorsCallback) + assertEquals(Errors.CONCURRENT_TRANSACTIONS, error) + } + + @Test + def shouldRespondWithProducerFencedOnAddPartitionsWhenEpochsAreDifferent(): Unit = { + EasyMock.expect(transactionManager.getTransactionState(EasyMock.eq(transactionalId))) + .andReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, + new TransactionMetadata(transactionalId, 0, 0, 10, 9, 0, PrepareCommit, mutable.Set.empty, 0, 0))))) + + EasyMock.replay(transactionManager) + + coordinator.handleAddPartitionsToTransaction(transactionalId, 0L, 0, partitions, errorsCallback) + assertEquals(Errors.PRODUCER_FENCED, error) + } + + @Test + def shouldAppendNewMetadataToLogOnAddPartitionsWhenPartitionsAdded(): Unit = { + validateSuccessfulAddPartitions(Empty) + } + + @Test + def shouldRespondWithSuccessOnAddPartitionsWhenStateIsOngoing(): Unit = { + validateSuccessfulAddPartitions(Ongoing) + } + + @Test + def shouldRespondWithSuccessOnAddPartitionsWhenStateIsCompleteCommit(): Unit = { + validateSuccessfulAddPartitions(CompleteCommit) + } + + @Test + def shouldRespondWithSuccessOnAddPartitionsWhenStateIsCompleteAbort(): Unit = { + validateSuccessfulAddPartitions(CompleteAbort) + } + + def validateSuccessfulAddPartitions(previousState: TransactionState): Unit = { + val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, producerEpoch, (producerEpoch - 1).toShort, + txnTimeoutMs, previousState, mutable.Set.empty, time.milliseconds(), time.milliseconds()) + + EasyMock.expect(transactionManager.getTransactionState(EasyMock.eq(transactionalId))) + .andReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata)))) + + EasyMock.expect(transactionManager.appendTransactionToLog( + EasyMock.eq(transactionalId), + EasyMock.eq(coordinatorEpoch), + EasyMock.anyObject().asInstanceOf[TxnTransitMetadata], + EasyMock.capture(capturedErrorsCallback), + EasyMock.anyObject(), + EasyMock.anyObject() + )) + + EasyMock.replay(transactionManager) + + coordinator.handleAddPartitionsToTransaction(transactionalId, producerId, producerEpoch, partitions, errorsCallback) + + EasyMock.verify(transactionManager) + } + + @Test + def shouldRespondWithErrorsNoneOnAddPartitionWhenNoErrorsAndPartitionsTheSame(): Unit = { + EasyMock.expect(transactionManager.getTransactionState(EasyMock.eq(transactionalId))) + .andReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, + new TransactionMetadata(transactionalId, 0, 0, 0, RecordBatch.NO_PRODUCER_EPOCH, 0, Empty, partitions, 0, 0))))) + + EasyMock.replay(transactionManager) + + coordinator.handleAddPartitionsToTransaction(transactionalId, 0L, 0, partitions, errorsCallback) + assertEquals(Errors.NONE, error) + EasyMock.verify(transactionManager) + + } + + @Test + def shouldReplyWithInvalidPidMappingOnEndTxnWhenTxnIdDoesntExist(): Unit = { + 
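+    // With no cached state for the transactional id there is no producer id mapping, so EndTxn fails with INVALID_PRODUCER_ID_MAPPING.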
EasyMock.expect(transactionManager.getTransactionState(EasyMock.eq(transactionalId))) + .andReturn(Right(None)) + EasyMock.replay(transactionManager) + + coordinator.handleEndTransaction(transactionalId, 0, 0, TransactionResult.COMMIT, errorsCallback) + assertEquals(Errors.INVALID_PRODUCER_ID_MAPPING, error) + EasyMock.verify(transactionManager) + } + + @Test + def shouldReplyWithInvalidPidMappingOnEndTxnWhenPidDosentMatchMapped(): Unit = { + EasyMock.expect(transactionManager.getTransactionState(EasyMock.eq(transactionalId))) + .andReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, + new TransactionMetadata(transactionalId, 10, 10, 0, RecordBatch.NO_PRODUCER_EPOCH, 0, Ongoing, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds()))))) + EasyMock.replay(transactionManager) + + coordinator.handleEndTransaction(transactionalId, 0, 0, TransactionResult.COMMIT, errorsCallback) + assertEquals(Errors.INVALID_PRODUCER_ID_MAPPING, error) + EasyMock.verify(transactionManager) + } + + @Test + def shouldReplyWithProducerFencedOnEndTxnWhenEpochIsNotSameAsTransaction(): Unit = { + EasyMock.expect(transactionManager.getTransactionState(EasyMock.eq(transactionalId))) + .andReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, + new TransactionMetadata(transactionalId, producerId, producerId, producerEpoch, (producerEpoch - 1).toShort, 1, Ongoing, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds()))))) + EasyMock.replay(transactionManager) + + coordinator.handleEndTransaction(transactionalId, producerId, 0, TransactionResult.COMMIT, errorsCallback) + assertEquals(Errors.PRODUCER_FENCED, error) + EasyMock.verify(transactionManager) + } + + @Test + def shouldReturnOkOnEndTxnWhenStatusIsCompleteCommitAndResultIsCommit(): Unit ={ + EasyMock.expect(transactionManager.getTransactionState(EasyMock.eq(transactionalId))) + .andReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, + new TransactionMetadata(transactionalId, producerId, producerId, producerEpoch, (producerEpoch - 1).toShort, 1, CompleteCommit, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds()))))) + EasyMock.replay(transactionManager) + + coordinator.handleEndTransaction(transactionalId, producerId, 1, TransactionResult.COMMIT, errorsCallback) + assertEquals(Errors.NONE, error) + EasyMock.verify(transactionManager) + } + + @Test + def shouldReturnOkOnEndTxnWhenStatusIsCompleteAbortAndResultIsAbort(): Unit ={ + val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, producerEpoch, (producerEpoch - 1).toShort, 1, CompleteAbort, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds()) + EasyMock.expect(transactionManager.getTransactionState(EasyMock.eq(transactionalId))) + .andReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata)))) + EasyMock.replay(transactionManager) + + coordinator.handleEndTransaction(transactionalId, producerId, 1, TransactionResult.ABORT, errorsCallback) + assertEquals(Errors.NONE, error) + EasyMock.verify(transactionManager) + } + + @Test + def shouldReturnInvalidTxnRequestOnEndTxnRequestWhenStatusIsCompleteAbortAndResultIsNotAbort(): Unit = { + val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, producerEpoch, (producerEpoch - 1).toShort, 1, CompleteAbort, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds()) + EasyMock.expect(transactionManager.getTransactionState(EasyMock.eq(transactionalId))) + 
.andReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata)))) + EasyMock.replay(transactionManager) + + coordinator.handleEndTransaction(transactionalId, producerId, 1, TransactionResult.COMMIT, errorsCallback) + assertEquals(Errors.INVALID_TXN_STATE, error) + EasyMock.verify(transactionManager) + } + + @Test + def shouldReturnInvalidTxnRequestOnEndTxnRequestWhenStatusIsCompleteCommitAndResultIsNotCommit(): Unit = { + val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, producerEpoch, (producerEpoch - 1).toShort,1, CompleteCommit, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds()) + EasyMock.expect(transactionManager.getTransactionState(EasyMock.eq(transactionalId))) + .andReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata)))) + EasyMock.replay(transactionManager) + + coordinator.handleEndTransaction(transactionalId, producerId, 1, TransactionResult.ABORT, errorsCallback) + assertEquals(Errors.INVALID_TXN_STATE, error) + EasyMock.verify(transactionManager) + } + + @Test + def shouldReturnConcurrentTxnRequestOnEndTxnRequestWhenStatusIsPrepareCommit(): Unit = { + EasyMock.expect(transactionManager.getTransactionState(EasyMock.eq(transactionalId))) + .andReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId, producerId, producerId, producerEpoch, (producerEpoch - 1).toShort, 1, PrepareCommit, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds()))))) + EasyMock.replay(transactionManager) + + coordinator.handleEndTransaction(transactionalId, producerId, 1, TransactionResult.COMMIT, errorsCallback) + assertEquals(Errors.CONCURRENT_TRANSACTIONS, error) + EasyMock.verify(transactionManager) + } + + @Test + def shouldReturnInvalidTxnRequestOnEndTxnRequestWhenStatusIsPrepareAbort(): Unit = { + EasyMock.expect(transactionManager.getTransactionState(EasyMock.eq(transactionalId))) + .andReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, new TransactionMetadata(transactionalId, producerId, producerId, 1, RecordBatch.NO_PRODUCER_EPOCH, 1, PrepareAbort, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds()))))) + EasyMock.replay(transactionManager) + + coordinator.handleEndTransaction(transactionalId, producerId, 1, TransactionResult.COMMIT, errorsCallback) + assertEquals(Errors.INVALID_TXN_STATE, error) + EasyMock.verify(transactionManager) + } + + @Test + def shouldAppendPrepareCommitToLogOnEndTxnWhenStatusIsOngoingAndResultIsCommit(): Unit = { + mockPrepare(PrepareCommit) + + EasyMock.replay(transactionManager) + + coordinator.handleEndTransaction(transactionalId, producerId, producerEpoch, TransactionResult.COMMIT, errorsCallback) + + EasyMock.verify(transactionManager) + } + + @Test + def shouldAppendPrepareAbortToLogOnEndTxnWhenStatusIsOngoingAndResultIsAbort(): Unit = { + mockPrepare(PrepareAbort) + + EasyMock.replay(transactionManager) + + coordinator.handleEndTransaction(transactionalId, producerId, producerEpoch, TransactionResult.ABORT, errorsCallback) + EasyMock.verify(transactionManager) + } + + @Test + def shouldRespondWithInvalidRequestOnEndTxnWhenTransactionalIdIsNull(): Unit = { + coordinator.handleEndTransaction(null, 0, 0, TransactionResult.COMMIT, errorsCallback) + assertEquals(Errors.INVALID_REQUEST, error) + } + + @Test + def shouldRespondWithInvalidRequestOnEndTxnWhenTransactionalIdIsEmpty(): Unit = { + 
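+    // An empty transactional id is rejected with INVALID_REQUEST before any state lookup, so the stub below is never exercised.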
EasyMock.expect(transactionManager.getTransactionState(EasyMock.eq(transactionalId))) + .andReturn(Left(Errors.NOT_COORDINATOR)) + EasyMock.replay(transactionManager) + + coordinator.handleEndTransaction("", 0, 0, TransactionResult.COMMIT, errorsCallback) + assertEquals(Errors.INVALID_REQUEST, error) + } + + @Test + def shouldRespondWithNotCoordinatorOnEndTxnWhenIsNotCoordinatorForId(): Unit = { + EasyMock.expect(transactionManager.getTransactionState(EasyMock.eq(transactionalId))) + .andReturn(Left(Errors.NOT_COORDINATOR)) + EasyMock.replay(transactionManager) + + coordinator.handleEndTransaction(transactionalId, 0, 0, TransactionResult.COMMIT, errorsCallback) + assertEquals(Errors.NOT_COORDINATOR, error) + } + + @Test + def shouldRespondWithCoordinatorLoadInProgressOnEndTxnWhenCoordinatorIsLoading(): Unit = { + EasyMock.expect(transactionManager.getTransactionState(EasyMock.eq(transactionalId))) + .andReturn(Left(Errors.COORDINATOR_LOAD_IN_PROGRESS)) + + EasyMock.replay(transactionManager) + + coordinator.handleEndTransaction(transactionalId, 0, 0, TransactionResult.COMMIT, errorsCallback) + assertEquals(Errors.COORDINATOR_LOAD_IN_PROGRESS, error) + } + + @Test + def shouldReturnInvalidEpochOnEndTxnWhenEpochIsLarger(): Unit = { + val serverProducerEpoch = 1.toShort + verifyEndTxnEpoch(serverProducerEpoch, (serverProducerEpoch + 1).toShort) + } + + @Test + def shouldReturnInvalidEpochOnEndTxnWhenEpochIsSmaller(): Unit = { + val serverProducerEpoch = 1.toShort + verifyEndTxnEpoch(serverProducerEpoch, (serverProducerEpoch - 1).toShort) + } + + private def verifyEndTxnEpoch(metadataEpoch: Short, requestEpoch: Short): Unit = { + EasyMock.expect(transactionManager.getTransactionState(EasyMock.eq(transactionalId))) + .andReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, + new TransactionMetadata(transactionalId, producerId, producerId, metadataEpoch, 0, 1, CompleteCommit, collection.mutable.Set.empty[TopicPartition], 0, time.milliseconds()))))) + EasyMock.replay(transactionManager) + + coordinator.handleEndTransaction(transactionalId, producerId, requestEpoch, TransactionResult.COMMIT, errorsCallback) + assertEquals(Errors.PRODUCER_FENCED, error) + EasyMock.verify(transactionManager) + } + + @Test + def shouldIncrementEpochAndUpdateMetadataOnHandleInitPidWhenExistingEmptyTransaction(): Unit = { + validateIncrementEpochAndUpdateMetadata(Empty) + } + + @Test + def shouldIncrementEpochAndUpdateMetadataOnHandleInitPidWhenExistingCompleteTransaction(): Unit = { + validateIncrementEpochAndUpdateMetadata(CompleteAbort) + } + + @Test + def shouldIncrementEpochAndUpdateMetadataOnHandleInitPidWhenExistingCompleteCommitTransaction(): Unit = { + validateIncrementEpochAndUpdateMetadata(CompleteCommit) + } + + @Test + def shouldWaitForCommitToCompleteOnHandleInitPidAndExistingTransactionInPrepareCommitState(): Unit ={ + validateRespondsWithConcurrentTransactionsOnInitPidWhenInPrepareState(PrepareCommit) + } + + @Test + def shouldWaitForCommitToCompleteOnHandleInitPidAndExistingTransactionInPrepareAbortState(): Unit ={ + validateRespondsWithConcurrentTransactionsOnInitPidWhenInPrepareState(PrepareAbort) + } + + @Test + def shouldAbortTransactionOnHandleInitPidWhenExistingTransactionInOngoingState(): Unit = { + val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, producerEpoch, + (producerEpoch - 1).toShort, txnTimeoutMs, Ongoing, partitions, time.milliseconds(), time.milliseconds()) + + 
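+    // InitProducerId over an ongoing transaction bumps the epoch and appends a PrepareAbort transition to the log;
+    // the caller receives CONCURRENT_TRANSACTIONS and is expected to retry once the abort completes.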
EasyMock.expect(transactionManager.validateTransactionTimeoutMs(EasyMock.anyInt())) + .andReturn(true) + + EasyMock.expect(transactionManager.putTransactionStateIfNotExists(EasyMock.anyObject[TransactionMetadata]())) + .andReturn(Right(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata))) + .anyTimes() + + EasyMock.expect(transactionManager.getTransactionState(EasyMock.eq(transactionalId))) + .andReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata)))) + .anyTimes() + + val originalMetadata = new TransactionMetadata(transactionalId, producerId, producerId, (producerEpoch + 1).toShort, + RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Ongoing, partitions, time.milliseconds(), time.milliseconds()) + EasyMock.expect(transactionManager.appendTransactionToLog( + EasyMock.eq(transactionalId), + EasyMock.eq(coordinatorEpoch), + EasyMock.eq(originalMetadata.prepareAbortOrCommit(PrepareAbort, time.milliseconds())), + EasyMock.capture(capturedErrorsCallback), + EasyMock.anyObject(), + EasyMock.anyObject()) + ).andAnswer(() => capturedErrorsCallback.getValue.apply(Errors.NONE)) + + EasyMock.replay(transactionManager) + + coordinator.handleInitProducerId(transactionalId, txnTimeoutMs, None, initProducerIdMockCallback) + + assertEquals(InitProducerIdResult(-1, -1, Errors.CONCURRENT_TRANSACTIONS), result) + + EasyMock.verify(transactionManager) + } + + @Test + def shouldFailToAbortTransactionOnHandleInitPidWhenProducerEpochIsSmaller(): Unit = { + val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, producerEpoch, + (producerEpoch - 1).toShort, txnTimeoutMs, Ongoing, partitions, time.milliseconds(), time.milliseconds()) + + EasyMock.expect(transactionManager.validateTransactionTimeoutMs(EasyMock.anyInt())) + .andReturn(true) + + EasyMock.expect(transactionManager.putTransactionStateIfNotExists(EasyMock.anyObject[TransactionMetadata]())) + .andReturn(Right(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata))) + .anyTimes() + + EasyMock.expect(transactionManager.getTransactionState(EasyMock.eq(transactionalId))) + .andReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata)))) + .times(1) + + val bumpedTxnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, (producerEpoch + 2).toShort, + (producerEpoch - 1).toShort, txnTimeoutMs, Ongoing, partitions, time.milliseconds(), time.milliseconds()) + + EasyMock.expect(transactionManager.getTransactionState(EasyMock.eq(transactionalId))) + .andReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, bumpedTxnMetadata)))) + .times(1) + + EasyMock.replay(transactionManager) + + coordinator.handleInitProducerId(transactionalId, txnTimeoutMs, None, initProducerIdMockCallback) + + assertEquals(InitProducerIdResult(-1, -1, Errors.PRODUCER_FENCED), result) + + EasyMock.verify(transactionManager) + } + + @Test + def shouldNotRepeatedlyBumpEpochDueToInitPidDuringOngoingTxnIfAppendToLogFails(): Unit = { + val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, producerEpoch, + RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Ongoing, partitions, time.milliseconds(), time.milliseconds()) + + EasyMock.expect(transactionManager.validateTransactionTimeoutMs(EasyMock.anyInt())) + .andReturn(true) + .anyTimes() + + EasyMock.expect(transactionManager.putTransactionStateIfNotExists(EasyMock.anyObject[TransactionMetadata]())) + .andReturn(Right(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata))) + .anyTimes() + + 
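+    // getTransactionState is stubbed with andAnswer below so that each retry observes the live txnMetadata,
+    // including the epoch bumped by the first failed append.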
EasyMock.expect(transactionManager.getTransactionState(EasyMock.eq(transactionalId))) + .andAnswer(() => Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata)))) + .anyTimes() + + val originalMetadata = new TransactionMetadata(transactionalId, producerId, producerId, (producerEpoch + 1).toShort, + RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Ongoing, partitions, time.milliseconds(), time.milliseconds()) + val txnTransitMetadata = originalMetadata.prepareAbortOrCommit(PrepareAbort, time.milliseconds()) + EasyMock.expect(transactionManager.appendTransactionToLog( + EasyMock.eq(transactionalId), + EasyMock.eq(coordinatorEpoch), + EasyMock.eq(txnTransitMetadata), + EasyMock.capture(capturedErrorsCallback), + EasyMock.anyObject(), + EasyMock.anyObject()) + ).andAnswer(() => { + capturedErrorsCallback.getValue.apply(Errors.NOT_ENOUGH_REPLICAS) + txnMetadata.pendingState = None + }).times(2) + + EasyMock.expect(transactionManager.appendTransactionToLog( + EasyMock.eq(transactionalId), + EasyMock.eq(coordinatorEpoch), + EasyMock.eq(txnTransitMetadata), + EasyMock.capture(capturedErrorsCallback), + EasyMock.anyObject(), + EasyMock.anyObject()) + ).andAnswer(() => { + capturedErrorsCallback.getValue.apply(Errors.NONE) + + // For the successful call, execute the state transitions that would happen in appendTransactionToLog() + txnMetadata.completeTransitionTo(txnTransitMetadata) + txnMetadata.prepareComplete(time.milliseconds()) + }).once() + + EasyMock.replay(transactionManager) + + // For the first two calls, verify that the epoch was only bumped once + coordinator.handleInitProducerId(transactionalId, txnTimeoutMs, None, initProducerIdMockCallback) + assertEquals(InitProducerIdResult(-1, -1, Errors.NOT_ENOUGH_REPLICAS), result) + + assertEquals((producerEpoch + 1).toShort, txnMetadata.producerEpoch) + assertTrue(txnMetadata.hasFailedEpochFence) + + coordinator.handleInitProducerId(transactionalId, txnTimeoutMs, None, initProducerIdMockCallback) + assertEquals(InitProducerIdResult(-1, -1, Errors.NOT_ENOUGH_REPLICAS), result) + + assertEquals((producerEpoch + 1).toShort, txnMetadata.producerEpoch) + assertTrue(txnMetadata.hasFailedEpochFence) + + // For the last, successful call, verify that the epoch was not bumped further + coordinator.handleInitProducerId(transactionalId, txnTimeoutMs, None, initProducerIdMockCallback) + assertEquals(InitProducerIdResult(-1, -1, Errors.CONCURRENT_TRANSACTIONS), result) + + assertEquals((producerEpoch + 1).toShort, txnMetadata.producerEpoch) + assertFalse(txnMetadata.hasFailedEpochFence) + + EasyMock.verify(transactionManager) + } + + @Test + def shouldUseLastEpochToFenceWhenEpochsAreExhausted(): Unit = { + val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, (Short.MaxValue - 1).toShort, + (Short.MaxValue - 2).toShort, txnTimeoutMs, Ongoing, partitions, time.milliseconds(), time.milliseconds()) + assertTrue(txnMetadata.isProducerEpochExhausted) + + EasyMock.expect(transactionManager.validateTransactionTimeoutMs(EasyMock.anyInt())) + .andReturn(true) + + EasyMock.expect(transactionManager.putTransactionStateIfNotExists(EasyMock.anyObject[TransactionMetadata]())) + .andReturn(Right(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata))) + .anyTimes() + + EasyMock.expect(transactionManager.getTransactionState(EasyMock.eq(transactionalId))) + .andReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata)))) + .times(2) + + val postFenceTxnMetadata = new TransactionMetadata(transactionalId, 
producerId, producerId, Short.MaxValue, + RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, PrepareAbort, partitions, time.milliseconds(), time.milliseconds()) + + EasyMock.expect(transactionManager.getTransactionState(EasyMock.eq(transactionalId))) + .andReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, postFenceTxnMetadata)))) + .once() + + EasyMock.expect(transactionManager.appendTransactionToLog( + EasyMock.eq(transactionalId), + EasyMock.eq(coordinatorEpoch), + EasyMock.eq(TxnTransitMetadata( + producerId = producerId, + lastProducerId = producerId, + producerEpoch = Short.MaxValue, + lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH, + txnTimeoutMs = txnTimeoutMs, + txnState = PrepareAbort, + topicPartitions = partitions.toSet, + txnStartTimestamp = time.milliseconds(), + txnLastUpdateTimestamp = time.milliseconds())), + EasyMock.capture(capturedErrorsCallback), + EasyMock.anyObject(), + EasyMock.anyObject()) + ).andAnswer(() => capturedErrorsCallback.getValue.apply(Errors.NONE)) + + EasyMock.replay(transactionManager) + + coordinator.handleInitProducerId(transactionalId, txnTimeoutMs, None, initProducerIdMockCallback) + assertEquals(Short.MaxValue, txnMetadata.producerEpoch) + + assertEquals(InitProducerIdResult(-1, -1, Errors.CONCURRENT_TRANSACTIONS), result) + EasyMock.verify(transactionManager) + } + + @Test + def testInitProducerIdWithNoLastProducerData(): Unit = { + // If the metadata doesn't include the previous producer data (for example, if it was written to the log by a broker + // on an old version), the retry case should fail + val txnMetadata = new TransactionMetadata(transactionalId, producerId, RecordBatch.NO_PRODUCER_ID, (producerEpoch + 1).toShort, + RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Empty, partitions, time.milliseconds, time.milliseconds) + + EasyMock.expect(transactionManager.validateTransactionTimeoutMs(EasyMock.anyInt())) + .andReturn(true).anyTimes() + EasyMock.expect(transactionManager.getTransactionState(EasyMock.eq(transactionalId))) + .andReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata)))).once + EasyMock.replay(transactionManager) + + // Simulate producer trying to continue after new producer has already been initialized + coordinator.handleInitProducerId(transactionalId, txnTimeoutMs, Some(new ProducerIdAndEpoch(producerId, producerEpoch)), + initProducerIdMockCallback) + assertEquals(InitProducerIdResult(RecordBatch.NO_PRODUCER_ID, RecordBatch.NO_PRODUCER_EPOCH, Errors.PRODUCER_FENCED), result) + } + + @Test + def testFenceProducerWhenMappingExistsWithDifferentProducerId(): Unit = { + // Existing transaction ID maps to new producer ID + val txnMetadata = new TransactionMetadata(transactionalId, producerId + 1, producerId, producerEpoch, + (producerEpoch - 1).toShort, txnTimeoutMs, Empty, partitions, time.milliseconds, time.milliseconds) + + EasyMock.expect(transactionManager.validateTransactionTimeoutMs(EasyMock.anyInt())) + .andReturn(true).anyTimes() + EasyMock.expect(transactionManager.getTransactionState(EasyMock.eq(transactionalId))) + .andReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata)))).once + EasyMock.replay(transactionManager) + + // Simulate producer trying to continue after new producer has already been initialized + coordinator.handleInitProducerId(transactionalId, txnTimeoutMs, Some(new ProducerIdAndEpoch(producerId, producerEpoch)), + initProducerIdMockCallback) + assertEquals(InitProducerIdResult(RecordBatch.NO_PRODUCER_ID, RecordBatch.NO_PRODUCER_EPOCH, 
Errors.PRODUCER_FENCED), result) + } + + @Test + def testInitProducerIdWithCurrentEpochProvided(): Unit = { + mockPidGenerator() + + val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, 10, + 9, txnTimeoutMs, Empty, partitions, time.milliseconds, time.milliseconds) + + EasyMock.expect(transactionManager.validateTransactionTimeoutMs(EasyMock.anyInt())) + .andReturn(true).anyTimes() + EasyMock.expect(transactionManager.getTransactionState(EasyMock.eq(transactionalId))) + .andReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata)))).times(2) + + EasyMock.expect(transactionManager.appendTransactionToLog( + EasyMock.eq(transactionalId), + EasyMock.eq(coordinatorEpoch), + EasyMock.anyObject().asInstanceOf[TxnTransitMetadata], + EasyMock.capture(capturedErrorsCallback), + EasyMock.anyObject(), + EasyMock.anyObject()) + ).andAnswer(() => { + capturedErrorsCallback.getValue.apply(Errors.NONE) + txnMetadata.pendingState = None + }).times(2) + + EasyMock.replay(pidGenerator, transactionManager) + + // Re-initialization should succeed and bump the producer epoch + coordinator.handleInitProducerId(transactionalId, txnTimeoutMs, Some(new ProducerIdAndEpoch(producerId, 10)), + initProducerIdMockCallback) + assertEquals(InitProducerIdResult(producerId, 11, Errors.NONE), result) + + // Simulate producer retrying after successfully re-initializing but failing to receive the response + coordinator.handleInitProducerId(transactionalId, txnTimeoutMs, Some(new ProducerIdAndEpoch(producerId, 10)), + initProducerIdMockCallback) + assertEquals(InitProducerIdResult(producerId, 11, Errors.NONE), result) + } + + @Test + def testInitProducerIdStaleCurrentEpochProvided(): Unit = { + mockPidGenerator() + + val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, 10, + 9, txnTimeoutMs, Empty, partitions, time.milliseconds, time.milliseconds) + + EasyMock.expect(transactionManager.validateTransactionTimeoutMs(EasyMock.anyInt())) + .andReturn(true).anyTimes() + EasyMock.expect(transactionManager.getTransactionState(EasyMock.eq(transactionalId))) + .andReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata)))).times(2) + + val capturedTxnTransitMetadata = Capture.newInstance[TxnTransitMetadata] + EasyMock.expect(transactionManager.appendTransactionToLog( + EasyMock.eq(transactionalId), + EasyMock.eq(coordinatorEpoch), + EasyMock.capture(capturedTxnTransitMetadata), + EasyMock.capture(capturedErrorsCallback), + EasyMock.anyObject(), + EasyMock.anyObject()) + ).andAnswer(() => { + capturedErrorsCallback.getValue.apply(Errors.NONE) + txnMetadata.pendingState = None + txnMetadata.producerEpoch = capturedTxnTransitMetadata.getValue.producerEpoch + txnMetadata.lastProducerEpoch = capturedTxnTransitMetadata.getValue.lastProducerEpoch + }).times(2) + + EasyMock.replay(pidGenerator, transactionManager) + + // With producer epoch at 10, new producer calls InitProducerId and should get epoch 11 + coordinator.handleInitProducerId(transactionalId, txnTimeoutMs, None, initProducerIdMockCallback) + assertEquals(InitProducerIdResult(producerId, 11, Errors.NONE), result) + + // Simulate old producer trying to continue from epoch 10 + coordinator.handleInitProducerId(transactionalId, txnTimeoutMs, Some(new ProducerIdAndEpoch(producerId, 10)), + initProducerIdMockCallback) + assertEquals(InitProducerIdResult(RecordBatch.NO_PRODUCER_ID, RecordBatch.NO_PRODUCER_EPOCH, Errors.PRODUCER_FENCED), result) + } + + @Test + def 
testRetryInitProducerIdAfterProducerIdRotation(): Unit = { + // Existing transaction ID maps to new producer ID + val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, (Short.MaxValue - 1).toShort, + (Short.MaxValue - 2).toShort, txnTimeoutMs, Empty, partitions, time.milliseconds, time.milliseconds) + + EasyMock.expect(pidGenerator.generateProducerId()) + .andReturn(producerId + 1) + .anyTimes() + + EasyMock.expect(transactionManager.validateTransactionTimeoutMs(EasyMock.anyInt())) + .andReturn(true).anyTimes() + EasyMock.expect(transactionManager.getTransactionState(EasyMock.eq(transactionalId))) + .andReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata)))).times(2) + + EasyMock.expect(transactionManager.appendTransactionToLog( + EasyMock.eq(transactionalId), + EasyMock.eq(coordinatorEpoch), + EasyMock.capture(capturedTxnTransitMetadata), + EasyMock.capture(capturedErrorsCallback), + EasyMock.anyObject(), + EasyMock.anyObject()) + ).andAnswer(() => { + capturedErrorsCallback.getValue.apply(Errors.NONE) + txnMetadata.pendingState = None + txnMetadata.producerId = capturedTxnTransitMetadata.getValue.producerId + txnMetadata.lastProducerId = capturedTxnTransitMetadata.getValue.lastProducerId + txnMetadata.producerEpoch = capturedTxnTransitMetadata.getValue.producerEpoch + txnMetadata.lastProducerEpoch = capturedTxnTransitMetadata.getValue.lastProducerEpoch + }).once + + EasyMock.replay(pidGenerator, transactionManager) + + // Bump epoch and cause producer ID to be rotated + coordinator.handleInitProducerId(transactionalId, txnTimeoutMs, Some(new ProducerIdAndEpoch(producerId, + (Short.MaxValue - 1).toShort)), initProducerIdMockCallback) + assertEquals(InitProducerIdResult(producerId + 1, 0, Errors.NONE), result) + + // Simulate producer retrying old request after producer bump + coordinator.handleInitProducerId(transactionalId, txnTimeoutMs, Some(new ProducerIdAndEpoch(producerId, + (Short.MaxValue - 1).toShort)), initProducerIdMockCallback) + assertEquals(InitProducerIdResult(producerId + 1, 0, Errors.NONE), result) + } + + @Test + def testInitProducerIdWithInvalidEpochAfterProducerIdRotation(): Unit = { + // Existing transaction ID maps to new producer ID + val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, (Short.MaxValue - 1).toShort, + (Short.MaxValue - 2).toShort, txnTimeoutMs, Empty, partitions, time.milliseconds, time.milliseconds) + + EasyMock.expect(pidGenerator.generateProducerId()) + .andReturn(producerId + 1) + .anyTimes() + + EasyMock.expect(transactionManager.validateTransactionTimeoutMs(EasyMock.anyInt())) + .andReturn(true).anyTimes() + EasyMock.expect(transactionManager.getTransactionState(EasyMock.eq(transactionalId))) + .andReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata)))).times(2) + + EasyMock.expect(transactionManager.appendTransactionToLog( + EasyMock.eq(transactionalId), + EasyMock.eq(coordinatorEpoch), + EasyMock.capture(capturedTxnTransitMetadata), + EasyMock.capture(capturedErrorsCallback), + EasyMock.anyObject(), + EasyMock.anyObject()) + ).andAnswer(() => { + capturedErrorsCallback.getValue.apply(Errors.NONE) + txnMetadata.pendingState = None + txnMetadata.producerId = capturedTxnTransitMetadata.getValue.producerId + txnMetadata.lastProducerId = capturedTxnTransitMetadata.getValue.lastProducerId + txnMetadata.producerEpoch = capturedTxnTransitMetadata.getValue.producerEpoch + txnMetadata.lastProducerEpoch = 
capturedTxnTransitMetadata.getValue.lastProducerEpoch + }).once + + EasyMock.replay(pidGenerator, transactionManager) + + // Bump epoch and cause producer ID to be rotated + coordinator.handleInitProducerId(transactionalId, txnTimeoutMs, Some(new ProducerIdAndEpoch(producerId, + (Short.MaxValue - 1).toShort)), initProducerIdMockCallback) + assertEquals(InitProducerIdResult(producerId + 1, 0, Errors.NONE), result) + + // Validate that producer with old producer ID and stale epoch is fenced + coordinator.handleInitProducerId(transactionalId, txnTimeoutMs, Some(new ProducerIdAndEpoch(producerId, + (Short.MaxValue - 2).toShort)), initProducerIdMockCallback) + assertEquals(InitProducerIdResult(RecordBatch.NO_PRODUCER_ID, RecordBatch.NO_PRODUCER_EPOCH, Errors.PRODUCER_FENCED), result) + } + + @Test + def shouldRemoveTransactionsForPartitionOnEmigration(): Unit = { + EasyMock.expect(transactionManager.removeTransactionsForTxnTopicPartition(0, coordinatorEpoch)) + EasyMock.expect(transactionMarkerChannelManager.removeMarkersForTxnTopicPartition(0)) + EasyMock.replay(transactionManager, transactionMarkerChannelManager) + + coordinator.onResignation(0, Some(coordinatorEpoch)) + + EasyMock.verify(transactionManager, transactionMarkerChannelManager) + } + + @Test + def shouldAbortExpiredTransactionsInOngoingStateAndBumpEpoch(): Unit = { + val now = time.milliseconds() + val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, producerEpoch, + RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Ongoing, partitions, now, now) + + EasyMock.expect(transactionManager.timedOutTransactions()) + .andReturn(List(TransactionalIdAndProducerIdEpoch(transactionalId, producerId, producerEpoch))) + EasyMock.expect(transactionManager.getTransactionState(EasyMock.eq(transactionalId))) + .andReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata)))) + .times(2) + + val expectedTransition = TxnTransitMetadata(producerId, producerId, (producerEpoch + 1).toShort, RecordBatch.NO_PRODUCER_EPOCH, + txnTimeoutMs, PrepareAbort, partitions.toSet, now, now + TransactionStateManager.DefaultAbortTimedOutTransactionsIntervalMs) + + EasyMock.expect(transactionManager.appendTransactionToLog(EasyMock.eq(transactionalId), + EasyMock.eq(coordinatorEpoch), + EasyMock.eq(expectedTransition), + EasyMock.capture(capturedErrorsCallback), + EasyMock.anyObject(), + EasyMock.anyObject()) + ).andAnswer(() => {}).once() + + EasyMock.replay(transactionManager, transactionMarkerChannelManager) + + coordinator.startup(() => transactionStatePartitionCount, false) + time.sleep(TransactionStateManager.DefaultAbortTimedOutTransactionsIntervalMs) + scheduler.tick() + EasyMock.verify(transactionManager) + } + + @Test + def shouldNotAcceptSmallerEpochDuringTransactionExpiration(): Unit = { + val now = time.milliseconds() + val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, producerEpoch, + RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Ongoing, partitions, now, now) + + EasyMock.expect(transactionManager.timedOutTransactions()) + .andReturn(List(TransactionalIdAndProducerIdEpoch(transactionalId, producerId, producerEpoch))) + EasyMock.expect(transactionManager.getTransactionState(EasyMock.eq(transactionalId))) + .andReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata)))) + + val bumpedTxnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, (producerEpoch + 2).toShort, + RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Ongoing, partitions, 
now, now) + EasyMock.expect(transactionManager.getTransactionState(EasyMock.eq(transactionalId))) + .andReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, bumpedTxnMetadata)))) + + EasyMock.replay(transactionManager, transactionMarkerChannelManager) + + def checkOnEndTransactionComplete(txnIdAndPidEpoch: TransactionalIdAndProducerIdEpoch)(error: Errors): Unit = { + assertEquals(Errors.PRODUCER_FENCED, error) + } + coordinator.abortTimedOutTransactions(checkOnEndTransactionComplete) + + EasyMock.verify(transactionManager) + } + + @Test + def shouldNotAbortExpiredTransactionsThatHaveAPendingStateTransition(): Unit = { + val metadata = new TransactionMetadata(transactionalId, producerId, producerId, producerEpoch, + RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Ongoing, partitions, time.milliseconds(), time.milliseconds()) + metadata.prepareAbortOrCommit(PrepareCommit, time.milliseconds()) + + EasyMock.expect(transactionManager.timedOutTransactions()) + .andReturn(List(TransactionalIdAndProducerIdEpoch(transactionalId, producerId, producerEpoch))) + EasyMock.expect(transactionManager.getTransactionState(EasyMock.eq(transactionalId))) + .andReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, metadata)))).once() + + EasyMock.replay(transactionManager, transactionMarkerChannelManager) + + coordinator.startup(() => transactionStatePartitionCount, false) + time.sleep(TransactionStateManager.DefaultAbortTimedOutTransactionsIntervalMs) + scheduler.tick() + EasyMock.verify(transactionManager) + } + + @Test + def shouldNotBumpEpochWhenAbortingExpiredTransactionIfAppendToLogFails(): Unit = { + val now = time.milliseconds() + val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, producerEpoch, + RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Ongoing, partitions, now, now) + + + EasyMock.expect(transactionManager.timedOutTransactions()) + .andReturn(List(TransactionalIdAndProducerIdEpoch(transactionalId, producerId, producerEpoch))) + EasyMock.expect(transactionManager.getTransactionState(EasyMock.eq(transactionalId))) + .andReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata)))) + .times(2) + + val txnMetadataAfterAppendFailure = new TransactionMetadata(transactionalId, producerId, producerId, (producerEpoch + 1).toShort, + RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Ongoing, partitions, now, now) + EasyMock.expect(transactionManager.getTransactionState(EasyMock.eq(transactionalId))) + .andAnswer(() => Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadataAfterAppendFailure)))) + .once + + val bumpedEpoch = (producerEpoch + 1).toShort + val expectedTransition = TxnTransitMetadata(producerId, producerId, bumpedEpoch, RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, + PrepareAbort, partitions.toSet, now, now + TransactionStateManager.DefaultAbortTimedOutTransactionsIntervalMs) + + EasyMock.expect(transactionManager.appendTransactionToLog(EasyMock.eq(transactionalId), + EasyMock.eq(coordinatorEpoch), + EasyMock.eq(expectedTransition), + EasyMock.capture(capturedErrorsCallback), + EasyMock.anyObject(), + EasyMock.anyObject()) + ).andAnswer(() => capturedErrorsCallback.getValue.apply(Errors.NOT_ENOUGH_REPLICAS)).once() + + EasyMock.replay(transactionManager, transactionMarkerChannelManager) + + coordinator.startup(() => transactionStatePartitionCount, false) + time.sleep(TransactionStateManager.DefaultAbortTimedOutTransactionsIntervalMs) + scheduler.tick() + EasyMock.verify(transactionManager) + + 
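+    // After the failed append, the cached metadata should still carry the bumped epoch and be flagged with hasFailedEpochFence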
assertEquals((producerEpoch + 1).toShort, txnMetadataAfterAppendFailure.producerEpoch) + assertTrue(txnMetadataAfterAppendFailure.hasFailedEpochFence) + } + + @Test + def shouldNotBumpEpochWithPendingTransaction(): Unit = { + val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, producerEpoch, + RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Ongoing, partitions, time.milliseconds(), time.milliseconds()) + txnMetadata.prepareAbortOrCommit(PrepareCommit, time.milliseconds()) + + EasyMock.expect(transactionManager.validateTransactionTimeoutMs(EasyMock.anyInt())) + .andReturn(true).anyTimes() + EasyMock.expect(transactionManager.getTransactionState(EasyMock.eq(transactionalId))) + .andReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata)))) + + EasyMock.replay(transactionManager) + + coordinator.handleInitProducerId(transactionalId, txnTimeoutMs, Some(new ProducerIdAndEpoch(producerId, 10)), + initProducerIdMockCallback) + assertEquals(InitProducerIdResult(RecordBatch.NO_PRODUCER_ID, RecordBatch.NO_PRODUCER_EPOCH, Errors.CONCURRENT_TRANSACTIONS), result) + + EasyMock.verify(transactionManager) + } + + @Test + def testDescribeTransactionsWithEmptyTransactionalId(): Unit = { + coordinator.startup(() => transactionStatePartitionCount, enableTransactionalIdExpiration = false) + val result = coordinator.handleDescribeTransactions("") + assertEquals("", result.transactionalId) + assertEquals(Errors.INVALID_REQUEST, Errors.forCode(result.errorCode)) + } + + @Test + def testDescribeTransactionsWithExpiringTransactionalId(): Unit = { + coordinator.startup(() => transactionStatePartitionCount, enableTransactionalIdExpiration = false) + + val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, producerEpoch, + RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Dead, mutable.Set.empty, time.milliseconds(), + time.milliseconds()) + EasyMock.expect(transactionManager.getTransactionState(EasyMock.eq(transactionalId))) + .andReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata)))) + EasyMock.replay(transactionManager) + + val result = coordinator.handleDescribeTransactions(transactionalId) + assertEquals(transactionalId, result.transactionalId) + assertEquals(Errors.TRANSACTIONAL_ID_NOT_FOUND, Errors.forCode(result.errorCode)) + } + + @Test + def testDescribeTransactionsWhileCoordinatorLoading(): Unit = { + EasyMock.expect(transactionManager.getTransactionState(EasyMock.eq(transactionalId))) + .andReturn(Left(Errors.COORDINATOR_LOAD_IN_PROGRESS)) + + EasyMock.replay(transactionManager) + + coordinator.startup(() => transactionStatePartitionCount, enableTransactionalIdExpiration = false) + val result = coordinator.handleDescribeTransactions(transactionalId) + assertEquals(transactionalId, result.transactionalId) + assertEquals(Errors.COORDINATOR_LOAD_IN_PROGRESS, Errors.forCode(result.errorCode)) + + EasyMock.verify(transactionManager) + } + + @Test + def testDescribeTransactions(): Unit = { + val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, producerEpoch, + RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, Ongoing, partitions, time.milliseconds(), time.milliseconds()) + + EasyMock.expect(transactionManager.getTransactionState(EasyMock.eq(transactionalId))) + .andReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata)))) + + EasyMock.replay(transactionManager) + + coordinator.startup(() => transactionStatePartitionCount, enableTransactionalIdExpiration = 
false) + val result = coordinator.handleDescribeTransactions(transactionalId) + assertEquals(Errors.NONE, Errors.forCode(result.errorCode)) + assertEquals(transactionalId, result.transactionalId) + assertEquals(producerId, result.producerId) + assertEquals(producerEpoch, result.producerEpoch) + assertEquals(txnTimeoutMs, result.transactionTimeoutMs) + assertEquals(time.milliseconds(), result.transactionStartTimeMs) + + val addedPartitions = result.topics.asScala.flatMap { topicData => + topicData.partitions.asScala.map(partition => new TopicPartition(topicData.topic, partition)) + }.toSet + assertEquals(partitions, addedPartitions) + + EasyMock.verify(transactionManager) + } + + private def validateRespondsWithConcurrentTransactionsOnInitPidWhenInPrepareState(state: TransactionState): Unit = { + EasyMock.expect(transactionManager.validateTransactionTimeoutMs(EasyMock.anyInt())) + .andReturn(true).anyTimes() + + val metadata = new TransactionMetadata(transactionalId, 0, 0, 0, RecordBatch.NO_PRODUCER_EPOCH, 0, state, mutable.Set[TopicPartition](new TopicPartition("topic", 1)), 0, 0) + EasyMock.expect(transactionManager.getTransactionState(EasyMock.eq(transactionalId))) + .andReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, metadata)))).anyTimes() + + EasyMock.replay(transactionManager) + + coordinator.handleInitProducerId(transactionalId, 10, None, initProducerIdMockCallback) + + assertEquals(InitProducerIdResult(-1, -1, Errors.CONCURRENT_TRANSACTIONS), result) + } + + private def validateIncrementEpochAndUpdateMetadata(state: TransactionState): Unit = { + EasyMock.expect(pidGenerator.generateProducerId()) + .andReturn(producerId) + .anyTimes() + + EasyMock.expect(transactionManager.validateTransactionTimeoutMs(EasyMock.anyInt())) + .andReturn(true) + + val metadata = new TransactionMetadata(transactionalId, producerId, producerId, producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, state, mutable.Set.empty[TopicPartition], time.milliseconds(), time.milliseconds()) + EasyMock.expect(transactionManager.getTransactionState(EasyMock.eq(transactionalId))) + .andReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, metadata)))) + + val capturedNewMetadata: Capture[TxnTransitMetadata] = EasyMock.newCapture() + EasyMock.expect(transactionManager.appendTransactionToLog( + EasyMock.eq(transactionalId), + EasyMock.eq(coordinatorEpoch), + EasyMock.capture(capturedNewMetadata), + EasyMock.capture(capturedErrorsCallback), + EasyMock.anyObject(), + EasyMock.anyObject() + )).andAnswer(() => { + metadata.completeTransitionTo(capturedNewMetadata.getValue) + capturedErrorsCallback.getValue.apply(Errors.NONE) + }) + + EasyMock.replay(pidGenerator, transactionManager) + + val newTxnTimeoutMs = 10 + coordinator.handleInitProducerId(transactionalId, newTxnTimeoutMs, None, initProducerIdMockCallback) + + assertEquals(InitProducerIdResult(producerId, (producerEpoch + 1).toShort, Errors.NONE), result) + assertEquals(newTxnTimeoutMs, metadata.txnTimeoutMs) + assertEquals(time.milliseconds(), metadata.txnLastUpdateTimestamp) + assertEquals((producerEpoch + 1).toShort, metadata.producerEpoch) + assertEquals(producerId, metadata.producerId) + } + + private def mockPrepare(transactionState: TransactionState, runCallback: Boolean = false): TransactionMetadata = { + val now = time.milliseconds() + val originalMetadata = new TransactionMetadata(transactionalId, producerId, producerId, producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, + txnTimeoutMs, Ongoing, partitions, now, now) + + val 
transition = TxnTransitMetadata(producerId, producerId, producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, txnTimeoutMs, + transactionState, partitions.toSet, now, now) + + EasyMock.expect(transactionManager.getTransactionState(EasyMock.eq(transactionalId))) + .andReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, originalMetadata)))) + .once() + EasyMock.expect(transactionManager.appendTransactionToLog( + EasyMock.eq(transactionalId), + EasyMock.eq(coordinatorEpoch), + EasyMock.eq(transition), + EasyMock.capture(capturedErrorsCallback), + EasyMock.anyObject(), + EasyMock.anyObject()) + ).andAnswer(() => { + if (runCallback) + capturedErrorsCallback.getValue.apply(Errors.NONE) + }).once() + + new TransactionMetadata(transactionalId, producerId, producerId, producerEpoch, RecordBatch.NO_PRODUCER_EPOCH, + txnTimeoutMs, transactionState, partitions, time.milliseconds(), time.milliseconds()) + } + + def initProducerIdMockCallback(ret: InitProducerIdResult): Unit = { + result = ret + } + + def errorsCallback(ret: Errors): Unit = { + error = ret + } +} diff --git a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionLogTest.scala b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionLogTest.scala new file mode 100644 index 0000000..32e17d8 --- /dev/null +++ b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionLogTest.scala @@ -0,0 +1,138 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package kafka.coordinator.transaction + + +import kafka.utils.TestUtils +import org.apache.kafka.common.TopicPartition +import org.apache.kafka.common.record.{CompressionType, MemoryRecords, SimpleRecord} +import org.junit.jupiter.api.Assertions.{assertEquals, assertThrows} +import org.junit.jupiter.api.Test + +import scala.jdk.CollectionConverters._ + +class TransactionLogTest { + + val producerEpoch: Short = 0 + val transactionTimeoutMs: Int = 1000 + + val topicPartitions: Set[TopicPartition] = Set[TopicPartition](new TopicPartition("topic1", 0), + new TopicPartition("topic1", 1), + new TopicPartition("topic2", 0), + new TopicPartition("topic2", 1), + new TopicPartition("topic2", 2)) + + @Test + def shouldThrowExceptionWriteInvalidTxn(): Unit = { + val transactionalId = "transactionalId" + val producerId = 23423L + + val txnMetadata = TransactionMetadata(transactionalId, producerId, producerEpoch, transactionTimeoutMs, 0) + txnMetadata.addPartitions(topicPartitions) + + assertThrows(classOf[IllegalStateException], () => TransactionLog.valueToBytes(txnMetadata.prepareNoTransit())) + } + + @Test + def shouldReadWriteMessages(): Unit = { + val pidMappings = Map[String, Long]("zero" -> 0L, + "one" -> 1L, + "two" -> 2L, + "three" -> 3L, + "four" -> 4L, + "five" -> 5L) + + val transactionStates = Map[Long, TransactionState](0L -> Empty, + 1L -> Ongoing, + 2L -> PrepareCommit, + 3L -> CompleteCommit, + 4L -> PrepareAbort, + 5L -> CompleteAbort) + + // generate transaction log messages + val txnRecords = pidMappings.map { case (transactionalId, producerId) => + val txnMetadata = TransactionMetadata(transactionalId, producerId, producerEpoch, transactionTimeoutMs, + transactionStates(producerId), 0) + + if (!txnMetadata.state.equals(Empty)) + txnMetadata.addPartitions(topicPartitions) + + val keyBytes = TransactionLog.keyToBytes(transactionalId) + val valueBytes = TransactionLog.valueToBytes(txnMetadata.prepareNoTransit()) + + new SimpleRecord(keyBytes, valueBytes) + }.toSeq + + val records = MemoryRecords.withRecords(0, CompressionType.NONE, txnRecords: _*) + + var count = 0 + for (record <- records.records.asScala) { + val txnKey = TransactionLog.readTxnRecordKey(record.key) + val transactionalId = txnKey.transactionalId + val txnMetadata = TransactionLog.readTxnRecordValue(transactionalId, record.value).get + + assertEquals(pidMappings(transactionalId), txnMetadata.producerId) + assertEquals(producerEpoch, txnMetadata.producerEpoch) + assertEquals(transactionTimeoutMs, txnMetadata.txnTimeoutMs) + assertEquals(transactionStates(txnMetadata.producerId), txnMetadata.state) + + if (txnMetadata.state.equals(Empty)) + assertEquals(Set.empty[TopicPartition], txnMetadata.topicPartitions) + else + assertEquals(topicPartitions, txnMetadata.topicPartitions) + + count = count + 1 + } + + assertEquals(pidMappings.size, count) + } + + @Test + def testTransactionMetadataParsing(): Unit = { + val transactionalId = "id" + val producerId = 1334L + val topicPartition = new TopicPartition("topic", 0) + + val txnMetadata = TransactionMetadata(transactionalId, producerId, producerEpoch, + transactionTimeoutMs, Ongoing, 0) + txnMetadata.addPartitions(Set(topicPartition)) + + val keyBytes = TransactionLog.keyToBytes(transactionalId) + val valueBytes = TransactionLog.valueToBytes(txnMetadata.prepareNoTransit()) + val transactionMetadataRecord = TestUtils.records(Seq( + new SimpleRecord(keyBytes, valueBytes) + )).records.asScala.head + + val (keyStringOpt, valueStringOpt) = 
TransactionLog.formatRecordKeyAndValue(transactionMetadataRecord) + assertEquals(Some(s"transaction_metadata::transactionalId=$transactionalId"), keyStringOpt) + assertEquals(Some(s"producerId:$producerId,producerEpoch:$producerEpoch,state=Ongoing," + + s"partitions=[$topicPartition],txnLastUpdateTimestamp=0,txnTimeoutMs=$transactionTimeoutMs"), valueStringOpt) + } + + @Test + def testTransactionMetadataTombstoneParsing(): Unit = { + val transactionalId = "id" + val transactionMetadataRecord = TestUtils.records(Seq( + new SimpleRecord(TransactionLog.keyToBytes(transactionalId), null) + )).records.asScala.head + + val (keyStringOpt, valueStringOpt) = TransactionLog.formatRecordKeyAndValue(transactionMetadataRecord) + assertEquals(Some(s"transaction_metadata::transactionalId=$transactionalId"), keyStringOpt) + assertEquals(Some(""), valueStringOpt) + } + +} diff --git a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMarkerChannelManagerTest.scala b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMarkerChannelManagerTest.scala new file mode 100644 index 0000000..0a0ec51 --- /dev/null +++ b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMarkerChannelManagerTest.scala @@ -0,0 +1,495 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package kafka.coordinator.transaction + +import java.util +import java.util.Arrays.asList +import java.util.Collections +import java.util.concurrent.{Callable, Executors, Future} + +import kafka.common.RequestAndCompletionHandler +import kafka.metrics.KafkaYammerMetrics +import kafka.server.{KafkaConfig, MetadataCache} +import kafka.utils.TestUtils +import org.apache.kafka.clients.{ClientResponse, NetworkClient} +import org.apache.kafka.common.protocol.{ApiKeys, Errors} +import org.apache.kafka.common.record.RecordBatch +import org.apache.kafka.common.requests.{RequestHeader, TransactionResult, WriteTxnMarkersRequest, WriteTxnMarkersResponse} +import org.apache.kafka.common.utils.MockTime +import org.apache.kafka.common.{Node, TopicPartition} +import org.easymock.{Capture, EasyMock} +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.Test + +import scala.jdk.CollectionConverters._ +import scala.collection.mutable +import scala.util.Try + +class TransactionMarkerChannelManagerTest { + private val metadataCache: MetadataCache = EasyMock.createNiceMock(classOf[MetadataCache]) + private val networkClient: NetworkClient = EasyMock.createNiceMock(classOf[NetworkClient]) + private val txnStateManager: TransactionStateManager = EasyMock.mock(classOf[TransactionStateManager]) + + private val partition1 = new TopicPartition("topic1", 0) + private val partition2 = new TopicPartition("topic1", 1) + private val broker1 = new Node(1, "host", 10) + private val broker2 = new Node(2, "otherhost", 10) + + private val transactionalId1 = "txnId1" + private val transactionalId2 = "txnId2" + private val producerId1 = 0.asInstanceOf[Long] + private val producerId2 = 1.asInstanceOf[Long] + private val producerEpoch = 0.asInstanceOf[Short] + private val lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH + private val txnTopicPartition1 = 0 + private val txnTopicPartition2 = 1 + private val coordinatorEpoch = 0 + private val txnTimeoutMs = 0 + private val txnResult = TransactionResult.COMMIT + private val txnMetadata1 = new TransactionMetadata(transactionalId1, producerId1, producerId1, producerEpoch, lastProducerEpoch, + txnTimeoutMs, PrepareCommit, mutable.Set[TopicPartition](partition1, partition2), 0L, 0L) + private val txnMetadata2 = new TransactionMetadata(transactionalId2, producerId2, producerId2, producerEpoch, lastProducerEpoch, + txnTimeoutMs, PrepareCommit, mutable.Set[TopicPartition](partition1), 0L, 0L) + + private val capturedErrorsCallback: Capture[Errors => Unit] = EasyMock.newCapture() + private val time = new MockTime + + private val channelManager = new TransactionMarkerChannelManager( + KafkaConfig.fromProps(TestUtils.createBrokerConfig(1, "localhost:2181")), + metadataCache, + networkClient, + txnStateManager, + time) + + private def mockCache(): Unit = { + EasyMock.expect(txnStateManager.partitionFor(transactionalId1)) + .andReturn(txnTopicPartition1) + .anyTimes() + EasyMock.expect(txnStateManager.partitionFor(transactionalId2)) + .andReturn(txnTopicPartition2) + .anyTimes() + EasyMock.expect(txnStateManager.getTransactionState(EasyMock.eq(transactionalId1))) + .andReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata1)))) + .anyTimes() + EasyMock.expect(txnStateManager.getTransactionState(EasyMock.eq(transactionalId2))) + .andReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata2)))) + .anyTimes() + } + + @Test + def shouldOnlyWriteTxnCompletionOnce(): Unit = { + mockCache() + + val expectedTransition = 
txnMetadata2.prepareComplete(time.milliseconds()) + + EasyMock.expect(metadataCache.getPartitionLeaderEndpoint( + EasyMock.eq(partition1.topic), + EasyMock.eq(partition1.partition), + EasyMock.anyObject()) + ).andReturn(Some(broker1)).anyTimes() + + EasyMock.expect(txnStateManager.appendTransactionToLog( + EasyMock.eq(transactionalId2), + EasyMock.eq(coordinatorEpoch), + EasyMock.eq(expectedTransition), + EasyMock.capture(capturedErrorsCallback), + EasyMock.anyObject(), + EasyMock.anyObject())) + .andAnswer(() => { + txnMetadata2.completeTransitionTo(expectedTransition) + capturedErrorsCallback.getValue.apply(Errors.NONE) + }).once() + + EasyMock.replay(txnStateManager, metadataCache) + + var addMarkerFuture: Future[Try[Unit]] = null + val executor = Executors.newFixedThreadPool(1) + txnMetadata2.lock.lock() + try { + addMarkerFuture = executor.submit((() => { + Try(channelManager.addTxnMarkersToSend(coordinatorEpoch, txnResult, + txnMetadata2, expectedTransition)) + }): Callable[Try[Unit]]) + + val header = new RequestHeader(ApiKeys.WRITE_TXN_MARKERS, 0, "client", 1) + val response = new WriteTxnMarkersResponse( + Collections.singletonMap(producerId2: java.lang.Long, Collections.singletonMap(partition1, Errors.NONE))) + val clientResponse = new ClientResponse(header, null, null, + time.milliseconds(), time.milliseconds(), false, null, null, + response) + + TestUtils.waitUntilTrue(() => { + val requests = channelManager.generateRequests() + if (requests.nonEmpty) { + assertEquals(1, requests.size) + val request = requests.head + request.handler.onComplete(clientResponse) + true + } else { + false + } + }, "Timed out waiting for expected WriteTxnMarkers request") + } finally { + txnMetadata2.lock.unlock() + executor.shutdown() + } + + assertNotNull(addMarkerFuture) + assertTrue(addMarkerFuture.get().isSuccess, + "Add marker task failed with exception " + addMarkerFuture.get().get) + + EasyMock.verify(txnStateManager) + } + + @Test + def shouldGenerateEmptyMapWhenNoRequestsOutstanding(): Unit = { + assertTrue(channelManager.generateRequests().isEmpty) + } + + @Test + def shouldGenerateRequestPerPartitionPerBroker(): Unit = { + mockCache() + EasyMock.replay(txnStateManager) + + EasyMock.expect(metadataCache.getPartitionLeaderEndpoint( + EasyMock.eq(partition1.topic), + EasyMock.eq(partition1.partition), + EasyMock.anyObject()) + ).andReturn(Some(broker1)).anyTimes() + EasyMock.expect(metadataCache.getPartitionLeaderEndpoint( + EasyMock.eq(partition2.topic), + EasyMock.eq(partition2.partition), + EasyMock.anyObject()) + ).andReturn(Some(broker2)).anyTimes() + + EasyMock.replay(metadataCache) + + channelManager.addTxnMarkersToSend(coordinatorEpoch, txnResult, txnMetadata1, txnMetadata1.prepareComplete(time.milliseconds())) + channelManager.addTxnMarkersToSend(coordinatorEpoch, txnResult, txnMetadata2, txnMetadata2.prepareComplete(time.milliseconds())) + + assertEquals(2, channelManager.numTxnsWithPendingMarkers) + assertEquals(2, channelManager.queueForBroker(broker1.id).get.totalNumMarkers) + assertEquals(1, channelManager.queueForBroker(broker1.id).get.totalNumMarkers(txnTopicPartition1)) + assertEquals(1, channelManager.queueForBroker(broker1.id).get.totalNumMarkers(txnTopicPartition2)) + assertEquals(1, channelManager.queueForBroker(broker2.id).get.totalNumMarkers) + assertEquals(1, channelManager.queueForBroker(broker2.id).get.totalNumMarkers(txnTopicPartition1)) + assertEquals(0, channelManager.queueForBroker(broker2.id).get.totalNumMarkers(txnTopicPartition2)) + + val 
expectedBroker1Request = new WriteTxnMarkersRequest.Builder(ApiKeys.WRITE_TXN_MARKERS.latestVersion(), + asList(new WriteTxnMarkersRequest.TxnMarkerEntry(producerId1, producerEpoch, coordinatorEpoch, txnResult, asList(partition1)), + new WriteTxnMarkersRequest.TxnMarkerEntry(producerId2, producerEpoch, coordinatorEpoch, txnResult, asList(partition1)))).build() + val expectedBroker2Request = new WriteTxnMarkersRequest.Builder(ApiKeys.WRITE_TXN_MARKERS.latestVersion(), + asList(new WriteTxnMarkersRequest.TxnMarkerEntry(producerId1, producerEpoch, coordinatorEpoch, txnResult, asList(partition2)))).build() + + val requests: Map[Node, WriteTxnMarkersRequest] = channelManager.generateRequests().map { handler => + (handler.destination, handler.request.asInstanceOf[WriteTxnMarkersRequest.Builder].build()) + }.toMap + + assertEquals(Map(broker1 -> expectedBroker1Request, broker2 -> expectedBroker2Request), requests) + assertTrue(channelManager.generateRequests().isEmpty) + } + + @Test + def shouldSkipSendMarkersWhenLeaderNotFound(): Unit = { + mockCache() + EasyMock.replay(txnStateManager) + + EasyMock.expect(metadataCache.getPartitionLeaderEndpoint( + EasyMock.eq(partition1.topic), + EasyMock.eq(partition1.partition), + EasyMock.anyObject()) + ).andReturn(None).anyTimes() + EasyMock.expect(metadataCache.getPartitionLeaderEndpoint( + EasyMock.eq(partition2.topic), + EasyMock.eq(partition2.partition), + EasyMock.anyObject()) + ).andReturn(Some(broker2)).anyTimes() + + EasyMock.replay(metadataCache) + + channelManager.addTxnMarkersToSend(coordinatorEpoch, txnResult, txnMetadata1, txnMetadata1.prepareComplete(time.milliseconds())) + + assertEquals(1, channelManager.numTxnsWithPendingMarkers) + assertEquals(1, channelManager.queueForBroker(broker2.id).get.totalNumMarkers) + assertTrue(channelManager.queueForBroker(broker1.id).isEmpty) + assertEquals(1, channelManager.queueForBroker(broker2.id).get.totalNumMarkers(txnTopicPartition1)) + assertEquals(0, channelManager.queueForBroker(broker2.id).get.totalNumMarkers(txnTopicPartition2)) + } + + @Test + def shouldSaveForLaterWhenLeaderUnknownButNotAvailable(): Unit = { + mockCache() + EasyMock.replay(txnStateManager) + + EasyMock.expect(metadataCache.getPartitionLeaderEndpoint( + EasyMock.eq(partition1.topic), + EasyMock.eq(partition1.partition), + EasyMock.anyObject()) + ).andReturn(Some(Node.noNode)) + .andReturn(Some(Node.noNode)) + .andReturn(Some(Node.noNode)) + .andReturn(Some(Node.noNode)) + .andReturn(Some(broker1)) + .andReturn(Some(broker1)) + EasyMock.expect(metadataCache.getPartitionLeaderEndpoint( + EasyMock.eq(partition2.topic), + EasyMock.eq(partition2.partition), + EasyMock.anyObject()) + ).andReturn(Some(broker2)).anyTimes() + + EasyMock.replay(metadataCache) + + channelManager.addTxnMarkersToSend(coordinatorEpoch, txnResult, txnMetadata1, txnMetadata1.prepareComplete(time.milliseconds())) + channelManager.addTxnMarkersToSend(coordinatorEpoch, txnResult, txnMetadata2, txnMetadata2.prepareComplete(time.milliseconds())) + + assertEquals(2, channelManager.numTxnsWithPendingMarkers) + assertEquals(1, channelManager.queueForBroker(broker2.id).get.totalNumMarkers) + assertTrue(channelManager.queueForBroker(broker1.id).isEmpty) + assertEquals(1, channelManager.queueForBroker(broker2.id).get.totalNumMarkers(txnTopicPartition1)) + assertEquals(0, channelManager.queueForBroker(broker2.id).get.totalNumMarkers(txnTopicPartition2)) + assertEquals(2, channelManager.queueForUnknownBroker.totalNumMarkers) + assertEquals(1, 
channelManager.queueForUnknownBroker.totalNumMarkers(txnTopicPartition1)) + assertEquals(1, channelManager.queueForUnknownBroker.totalNumMarkers(txnTopicPartition2)) + + val expectedBroker1Request = new WriteTxnMarkersRequest.Builder(ApiKeys.WRITE_TXN_MARKERS.latestVersion(), + asList(new WriteTxnMarkersRequest.TxnMarkerEntry(producerId1, producerEpoch, coordinatorEpoch, txnResult, asList(partition1)), + new WriteTxnMarkersRequest.TxnMarkerEntry(producerId2, producerEpoch, coordinatorEpoch, txnResult, asList(partition1)))).build() + val expectedBroker2Request = new WriteTxnMarkersRequest.Builder(ApiKeys.WRITE_TXN_MARKERS.latestVersion(), + asList(new WriteTxnMarkersRequest.TxnMarkerEntry(producerId1, producerEpoch, coordinatorEpoch, txnResult, asList(partition2)))).build() + + val firstDrainedRequests: Map[Node, WriteTxnMarkersRequest] = channelManager.generateRequests().map { handler => + (handler.destination, handler.request.asInstanceOf[WriteTxnMarkersRequest.Builder].build()) + }.toMap + + assertEquals(Map(broker2 -> expectedBroker2Request), firstDrainedRequests) + + val secondDrainedRequests: Map[Node, WriteTxnMarkersRequest] = channelManager.generateRequests().map { handler => + (handler.destination, handler.request.asInstanceOf[WriteTxnMarkersRequest.Builder].build()) + }.toMap + + assertEquals(Map(broker1 -> expectedBroker1Request), secondDrainedRequests) + } + + @Test + def shouldRemoveMarkersForTxnPartitionWhenPartitionEmigrated(): Unit = { + mockCache() + EasyMock.replay(txnStateManager) + + EasyMock.expect(metadataCache.getPartitionLeaderEndpoint( + EasyMock.eq(partition1.topic), + EasyMock.eq(partition1.partition), + EasyMock.anyObject()) + ).andReturn(Some(broker1)).anyTimes() + EasyMock.expect(metadataCache.getPartitionLeaderEndpoint( + EasyMock.eq(partition2.topic), + EasyMock.eq(partition2.partition), + EasyMock.anyObject()) + ).andReturn(Some(broker2)).anyTimes() + + EasyMock.replay(metadataCache) + + channelManager.addTxnMarkersToSend(coordinatorEpoch, txnResult, txnMetadata1, txnMetadata1.prepareComplete(time.milliseconds())) + channelManager.addTxnMarkersToSend(coordinatorEpoch, txnResult, txnMetadata2, txnMetadata2.prepareComplete(time.milliseconds())) + + assertEquals(2, channelManager.numTxnsWithPendingMarkers) + assertEquals(2, channelManager.queueForBroker(broker1.id).get.totalNumMarkers) + assertEquals(1, channelManager.queueForBroker(broker1.id).get.totalNumMarkers(txnTopicPartition1)) + assertEquals(1, channelManager.queueForBroker(broker1.id).get.totalNumMarkers(txnTopicPartition2)) + assertEquals(1, channelManager.queueForBroker(broker2.id).get.totalNumMarkers) + assertEquals(1, channelManager.queueForBroker(broker2.id).get.totalNumMarkers(txnTopicPartition1)) + assertEquals(0, channelManager.queueForBroker(broker2.id).get.totalNumMarkers(txnTopicPartition2)) + + channelManager.removeMarkersForTxnTopicPartition(txnTopicPartition1) + + assertEquals(1, channelManager.numTxnsWithPendingMarkers) + assertEquals(1, channelManager.queueForBroker(broker1.id).get.totalNumMarkers) + assertEquals(0, channelManager.queueForBroker(broker1.id).get.totalNumMarkers(txnTopicPartition1)) + assertEquals(1, channelManager.queueForBroker(broker1.id).get.totalNumMarkers(txnTopicPartition2)) + assertEquals(0, channelManager.queueForBroker(broker2.id).get.totalNumMarkers) + assertEquals(0, channelManager.queueForBroker(broker2.id).get.totalNumMarkers(txnTopicPartition1)) + assertEquals(0, channelManager.queueForBroker(broker2.id).get.totalNumMarkers(txnTopicPartition2)) + } + + 
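+  // A successful WriteTxnMarkers response should let the channel manager append the completion record to the transaction log and transition the transaction to CompleteCommit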
@Test + def shouldCompleteAppendToLogOnEndTxnWhenSendMarkersSucceed(): Unit = { + mockCache() + + EasyMock.expect(metadataCache.getPartitionLeaderEndpoint( + EasyMock.eq(partition1.topic), + EasyMock.eq(partition1.partition), + EasyMock.anyObject()) + ).andReturn(Some(broker1)).anyTimes() + EasyMock.expect(metadataCache.getPartitionLeaderEndpoint( + EasyMock.eq(partition2.topic), + EasyMock.eq(partition2.partition), + EasyMock.anyObject()) + ).andReturn(Some(broker2)).anyTimes() + + val txnTransitionMetadata2 = txnMetadata2.prepareComplete(time.milliseconds()) + + EasyMock.expect(txnStateManager.appendTransactionToLog( + EasyMock.eq(transactionalId2), + EasyMock.eq(coordinatorEpoch), + EasyMock.eq(txnTransitionMetadata2), + EasyMock.capture(capturedErrorsCallback), + EasyMock.anyObject(), + EasyMock.anyObject())) + .andAnswer(() => { + txnMetadata2.completeTransitionTo(txnTransitionMetadata2) + capturedErrorsCallback.getValue.apply(Errors.NONE) + }).once() + EasyMock.replay(txnStateManager, metadataCache) + + channelManager.addTxnMarkersToSend(coordinatorEpoch, txnResult, txnMetadata2, txnTransitionMetadata2) + + val requestAndHandlers: Iterable[RequestAndCompletionHandler] = channelManager.generateRequests() + + val response = new WriteTxnMarkersResponse(createPidErrorMap(Errors.NONE)) + for (requestAndHandler <- requestAndHandlers) { + requestAndHandler.handler.onComplete(new ClientResponse(new RequestHeader(ApiKeys.WRITE_TXN_MARKERS, 0, "client", 1), + null, null, 0, 0, false, null, null, response)) + } + + EasyMock.verify(txnStateManager) + + assertEquals(0, channelManager.numTxnsWithPendingMarkers) + assertEquals(0, channelManager.queueForBroker(broker1.id).get.totalNumMarkers) + assertEquals(None, txnMetadata2.pendingState) + assertEquals(CompleteCommit, txnMetadata2.state) + } + + @Test + def shouldAbortAppendToLogOnEndTxnWhenNotCoordinatorError(): Unit = { + mockCache() + + EasyMock.expect(metadataCache.getPartitionLeaderEndpoint( + EasyMock.eq(partition1.topic), + EasyMock.eq(partition1.partition), + EasyMock.anyObject()) + ).andReturn(Some(broker1)).anyTimes() + EasyMock.expect(metadataCache.getPartitionLeaderEndpoint( + EasyMock.eq(partition2.topic), + EasyMock.eq(partition2.partition), + EasyMock.anyObject()) + ).andReturn(Some(broker2)).anyTimes() + + val txnTransitionMetadata2 = txnMetadata2.prepareComplete(time.milliseconds()) + + EasyMock.expect(txnStateManager.appendTransactionToLog( + EasyMock.eq(transactionalId2), + EasyMock.eq(coordinatorEpoch), + EasyMock.eq(txnTransitionMetadata2), + EasyMock.capture(capturedErrorsCallback), + EasyMock.anyObject(), + EasyMock.anyObject())) + .andAnswer(() => { + txnMetadata2.pendingState = None + capturedErrorsCallback.getValue.apply(Errors.NOT_COORDINATOR) + }).once() + EasyMock.replay(txnStateManager, metadataCache) + + channelManager.addTxnMarkersToSend(coordinatorEpoch, txnResult, txnMetadata2, txnTransitionMetadata2) + + val requestAndHandlers: Iterable[RequestAndCompletionHandler] = channelManager.generateRequests() + + val response = new WriteTxnMarkersResponse(createPidErrorMap(Errors.NONE)) + for (requestAndHandler <- requestAndHandlers) { + requestAndHandler.handler.onComplete(new ClientResponse(new RequestHeader(ApiKeys.WRITE_TXN_MARKERS, 0, "client", 1), + null, null, 0, 0, false, null, null, response)) + } + + EasyMock.verify(txnStateManager) + + assertEquals(0, channelManager.numTxnsWithPendingMarkers) + assertEquals(0, channelManager.queueForBroker(broker1.id).get.totalNumMarkers) + assertEquals(None, 
txnMetadata2.pendingState) + assertEquals(PrepareCommit, txnMetadata2.state) + } + + @Test + def shouldRetryAppendToLogOnEndTxnWhenCoordinatorNotAvailableError(): Unit = { + mockCache() + + EasyMock.expect(metadataCache.getPartitionLeaderEndpoint( + EasyMock.eq(partition1.topic), + EasyMock.eq(partition1.partition), + EasyMock.anyObject()) + ).andReturn(Some(broker1)).anyTimes() + EasyMock.expect(metadataCache.getPartitionLeaderEndpoint( + EasyMock.eq(partition2.topic), + EasyMock.eq(partition2.partition), + EasyMock.anyObject()) + ).andReturn(Some(broker2)).anyTimes() + + val txnTransitionMetadata2 = txnMetadata2.prepareComplete(time.milliseconds()) + + EasyMock.expect(txnStateManager.appendTransactionToLog( + EasyMock.eq(transactionalId2), + EasyMock.eq(coordinatorEpoch), + EasyMock.eq(txnTransitionMetadata2), + EasyMock.capture(capturedErrorsCallback), + EasyMock.anyObject(), + EasyMock.anyObject())) + .andAnswer(() => capturedErrorsCallback.getValue.apply(Errors.COORDINATOR_NOT_AVAILABLE)) + .andAnswer(() => { + txnMetadata2.completeTransitionTo(txnTransitionMetadata2) + capturedErrorsCallback.getValue.apply(Errors.NONE) + }) + + EasyMock.replay(txnStateManager, metadataCache) + + channelManager.addTxnMarkersToSend(coordinatorEpoch, txnResult, txnMetadata2, txnTransitionMetadata2) + + val requestAndHandlers: Iterable[RequestAndCompletionHandler] = channelManager.generateRequests() + + val response = new WriteTxnMarkersResponse(createPidErrorMap(Errors.NONE)) + for (requestAndHandler <- requestAndHandlers) { + requestAndHandler.handler.onComplete(new ClientResponse(new RequestHeader(ApiKeys.WRITE_TXN_MARKERS, 0, "client", 1), + null, null, 0, 0, false, null, null, response)) + } + + // call this again so that append log will be retried + channelManager.generateRequests() + + EasyMock.verify(txnStateManager) + + assertEquals(0, channelManager.numTxnsWithPendingMarkers) + assertEquals(0, channelManager.queueForBroker(broker1.id).get.totalNumMarkers) + assertEquals(None, txnMetadata2.pendingState) + assertEquals(CompleteCommit, txnMetadata2.state) + } + + private def createPidErrorMap(errors: Errors): util.HashMap[java.lang.Long, util.Map[TopicPartition, Errors]] = { + val pidMap = new java.util.HashMap[java.lang.Long, java.util.Map[TopicPartition, Errors]]() + val errorsMap = new java.util.HashMap[TopicPartition, Errors]() + errorsMap.put(partition1, errors) + pidMap.put(producerId2, errorsMap) + pidMap + } + + @Test + def shouldCreateMetricsOnStarting(): Unit = { + val metrics = KafkaYammerMetrics.defaultRegistry.allMetrics.asScala + + assertEquals(1, metrics.count { case (k, _) => + k.getMBeanName == "kafka.coordinator.transaction:type=TransactionMarkerChannelManager,name=UnknownDestinationQueueSize" + }) + assertEquals(1, metrics.count { case (k, _) => + k.getMBeanName == "kafka.coordinator.transaction:type=TransactionMarkerChannelManager,name=LogAppendRetryQueueSize" + }) + } +} diff --git a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMarkerRequestCompletionHandlerTest.scala b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMarkerRequestCompletionHandlerTest.scala new file mode 100644 index 0000000..d15f3b2 --- /dev/null +++ b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMarkerRequestCompletionHandlerTest.scala @@ -0,0 +1,262 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package kafka.coordinator.transaction + +import java.{lang, util} +import java.util.Arrays.asList + +import org.apache.kafka.clients.ClientResponse +import org.apache.kafka.common.TopicPartition +import org.apache.kafka.common.protocol.{ApiKeys, Errors} +import org.apache.kafka.common.record.RecordBatch +import org.apache.kafka.common.requests.{RequestHeader, TransactionResult, WriteTxnMarkersRequest, WriteTxnMarkersResponse} +import org.easymock.EasyMock +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.Test + +import scala.collection.mutable + +class TransactionMarkerRequestCompletionHandlerTest { + + private val brokerId = 0 + private val txnTopicPartition = 0 + private val transactionalId = "txnId1" + private val producerId = 0.asInstanceOf[Long] + private val producerEpoch = 0.asInstanceOf[Short] + private val lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH + private val txnTimeoutMs = 0 + private val coordinatorEpoch = 0 + private val txnResult = TransactionResult.COMMIT + private val topicPartition = new TopicPartition("topic1", 0) + private val txnIdAndMarkers = asList( + TxnIdAndMarkerEntry(transactionalId, new WriteTxnMarkersRequest.TxnMarkerEntry(producerId, producerEpoch, coordinatorEpoch, txnResult, asList(topicPartition)))) + + private val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, producerEpoch, lastProducerEpoch, + txnTimeoutMs, PrepareCommit, mutable.Set[TopicPartition](topicPartition), 0L, 0L) + + private val markerChannelManager: TransactionMarkerChannelManager = + EasyMock.createNiceMock(classOf[TransactionMarkerChannelManager]) + + private val txnStateManager: TransactionStateManager = EasyMock.createNiceMock(classOf[TransactionStateManager]) + + private val handler = new TransactionMarkerRequestCompletionHandler(brokerId, txnStateManager, markerChannelManager, txnIdAndMarkers) + + private def mockCache(): Unit = { + EasyMock.expect(txnStateManager.partitionFor(transactionalId)) + .andReturn(txnTopicPartition) + .anyTimes() + EasyMock.expect(txnStateManager.getTransactionState(EasyMock.eq(transactionalId))) + .andReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata)))) + .anyTimes() + EasyMock.replay(txnStateManager) + } + + @Test + def shouldReEnqueuePartitionsWhenBrokerDisconnected(): Unit = { + mockCache() + + EasyMock.expect(markerChannelManager.addTxnMarkersToBrokerQueue(transactionalId, + producerId, producerEpoch, txnResult, coordinatorEpoch, Set[TopicPartition](topicPartition))) + EasyMock.replay(markerChannelManager) + + handler.onComplete(new ClientResponse(new RequestHeader(ApiKeys.PRODUCE, 0, "client", 1), + null, null, 0, 0, true, null, null, null)) + + EasyMock.verify(markerChannelManager) + } + + @Test + def shouldThrowIllegalStateExceptionIfErrorCodeNotAvailableForPid(): Unit = { 
+ mockCache() + EasyMock.replay(markerChannelManager) + + val response = new WriteTxnMarkersResponse(new java.util.HashMap[java.lang.Long, java.util.Map[TopicPartition, Errors]]()) + + assertThrows(classOf[IllegalStateException], () => handler.onComplete(new ClientResponse(new RequestHeader( + ApiKeys.PRODUCE, 0, "client", 1), null, null, 0, 0, false, null, null, response))) + } + + @Test + def shouldCompleteDelayedOperationWhenNoErrors(): Unit = { + mockCache() + + verifyCompleteDelayedOperationOnError(Errors.NONE) + } + + @Test + def shouldCompleteDelayedOperationWhenNotCoordinator(): Unit = { + EasyMock.expect(txnStateManager.getTransactionState(EasyMock.eq(transactionalId))) + .andReturn(Left(Errors.NOT_COORDINATOR)) + .anyTimes() + EasyMock.replay(txnStateManager) + + verifyRemoveDelayedOperationOnError(Errors.NONE) + } + + @Test + def shouldCompleteDelayedOperationWhenCoordinatorLoading(): Unit = { + EasyMock.expect(txnStateManager.getTransactionState(EasyMock.eq(transactionalId))) + .andReturn(Left(Errors.COORDINATOR_LOAD_IN_PROGRESS)) + .anyTimes() + EasyMock.replay(txnStateManager) + + verifyRemoveDelayedOperationOnError(Errors.NONE) + } + + @Test + def shouldCompleteDelayedOperationWhenCoordinatorEpochChanged(): Unit = { + EasyMock.expect(txnStateManager.getTransactionState(EasyMock.eq(transactionalId))) + .andReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch+1, txnMetadata)))) + .anyTimes() + EasyMock.replay(txnStateManager) + + verifyRemoveDelayedOperationOnError(Errors.NONE) + } + + @Test + def shouldCompleteDelayedOperationWhenInvalidProducerEpoch(): Unit = { + mockCache() + + verifyRemoveDelayedOperationOnError(Errors.INVALID_PRODUCER_EPOCH) + } + + @Test + def shouldCompleteDelayedOperationWheCoordinatorEpochFenced(): Unit = { + mockCache() + + verifyRemoveDelayedOperationOnError(Errors.TRANSACTION_COORDINATOR_FENCED) + } + + @Test + def shouldThrowIllegalStateExceptionWhenUnknownError(): Unit = { + verifyThrowIllegalStateExceptionOnError(Errors.UNKNOWN_SERVER_ERROR) + } + + @Test + def shouldThrowIllegalStateExceptionWhenCorruptMessageError(): Unit = { + verifyThrowIllegalStateExceptionOnError(Errors.CORRUPT_MESSAGE) + } + + @Test + def shouldThrowIllegalStateExceptionWhenMessageTooLargeError(): Unit = { + verifyThrowIllegalStateExceptionOnError(Errors.MESSAGE_TOO_LARGE) + } + + @Test + def shouldThrowIllegalStateExceptionWhenRecordListTooLargeError(): Unit = { + verifyThrowIllegalStateExceptionOnError(Errors.RECORD_LIST_TOO_LARGE) + } + + @Test + def shouldThrowIllegalStateExceptionWhenInvalidRequiredAcksError(): Unit = { + verifyThrowIllegalStateExceptionOnError(Errors.INVALID_REQUIRED_ACKS) + } + + @Test + def shouldRetryPartitionWhenUnknownTopicOrPartitionError(): Unit = { + verifyRetriesPartitionOnError(Errors.UNKNOWN_TOPIC_OR_PARTITION) + } + + @Test + def shouldRetryPartitionWhenNotLeaderOrFollowerError(): Unit = { + verifyRetriesPartitionOnError(Errors.NOT_LEADER_OR_FOLLOWER) + } + + @Test + def shouldRetryPartitionWhenNotEnoughReplicasError(): Unit = { + verifyRetriesPartitionOnError(Errors.NOT_ENOUGH_REPLICAS) + } + + @Test + def shouldRetryPartitionWhenNotEnoughReplicasAfterAppendError(): Unit = { + verifyRetriesPartitionOnError(Errors.NOT_ENOUGH_REPLICAS_AFTER_APPEND) + } + + @Test + def shouldRetryPartitionWhenKafkaStorageError(): Unit = { + verifyRetriesPartitionOnError(Errors.KAFKA_STORAGE_ERROR) + } + + @Test + def shouldRemoveTopicPartitionFromWaitingSetOnUnsupportedForMessageFormat(): Unit = { + mockCache() + 
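+    // UNSUPPORTED_FOR_MESSAGE_FORMAT is not retried: the partition should be removed from the
+    // waiting set and the pending txn completion attempted.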
verifyCompleteDelayedOperationOnError(Errors.UNSUPPORTED_FOR_MESSAGE_FORMAT) + } + + private def verifyRetriesPartitionOnError(error: Errors) = { + mockCache() + + EasyMock.expect(markerChannelManager.addTxnMarkersToBrokerQueue(transactionalId, + producerId, producerEpoch, txnResult, coordinatorEpoch, Set[TopicPartition](topicPartition))) + EasyMock.replay(markerChannelManager) + + val response = new WriteTxnMarkersResponse(createProducerIdErrorMap(error)) + handler.onComplete(new ClientResponse(new RequestHeader(ApiKeys.PRODUCE, 0, "client", 1), + null, null, 0, 0, false, null, null, response)) + + assertEquals(txnMetadata.topicPartitions, mutable.Set[TopicPartition](topicPartition)) + EasyMock.verify(markerChannelManager) + } + + private def verifyThrowIllegalStateExceptionOnError(error: Errors) = { + mockCache() + + val response = new WriteTxnMarkersResponse(createProducerIdErrorMap(error)) + assertThrows(classOf[IllegalStateException], () => handler.onComplete(new ClientResponse(new RequestHeader( + ApiKeys.PRODUCE, 0, "client", 1), null, null, 0, 0, false, null, null, response))) + } + + private def verifyCompleteDelayedOperationOnError(error: Errors): Unit = { + + var completed = false + EasyMock.expect(markerChannelManager.maybeWriteTxnCompletion(transactionalId)) + .andAnswer(() => completed = true) + .once() + EasyMock.replay(markerChannelManager) + + val response = new WriteTxnMarkersResponse(createProducerIdErrorMap(error)) + handler.onComplete(new ClientResponse(new RequestHeader(ApiKeys.PRODUCE, 0, "client", 1), + null, null, 0, 0, false, null, null, response)) + + assertTrue(txnMetadata.topicPartitions.isEmpty) + assertTrue(completed) + } + + private def verifyRemoveDelayedOperationOnError(error: Errors): Unit = { + + var removed = false + EasyMock.expect(markerChannelManager.removeMarkersForTxnId(transactionalId)) + .andAnswer(() => removed = true) + .once() + EasyMock.replay(markerChannelManager) + + val response = new WriteTxnMarkersResponse(createProducerIdErrorMap(error)) + handler.onComplete(new ClientResponse(new RequestHeader(ApiKeys.PRODUCE, 0, "client", 1), + null, null, 0, 0, false, null, null, response)) + + assertTrue(removed) + } + + + private def createProducerIdErrorMap(errors: Errors) = { + val pidMap = new java.util.HashMap[lang.Long, util.Map[TopicPartition, Errors]]() + val errorsMap = new util.HashMap[TopicPartition, Errors]() + errorsMap.put(topicPartition, errors) + pidMap.put(producerId, errorsMap) + pidMap + } +} diff --git a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMetadataTest.scala b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMetadataTest.scala new file mode 100644 index 0000000..4c0f456 --- /dev/null +++ b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMetadataTest.scala @@ -0,0 +1,531 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package kafka.coordinator.transaction + +import kafka.utils.MockTime +import org.apache.kafka.common.TopicPartition +import org.apache.kafka.common.protocol.Errors +import org.apache.kafka.common.record.RecordBatch +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.Test + +import scala.collection.mutable + +class TransactionMetadataTest { + + val time = new MockTime() + val producerId = 23423L + val transactionalId = "txnlId" + + @Test + def testInitializeEpoch(): Unit = { + val producerEpoch = RecordBatch.NO_PRODUCER_EPOCH + + val txnMetadata = new TransactionMetadata( + transactionalId = transactionalId, + producerId = producerId, + lastProducerId = RecordBatch.NO_PRODUCER_ID, + producerEpoch = producerEpoch, + lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH, + txnTimeoutMs = 30000, + state = Empty, + topicPartitions = mutable.Set.empty, + txnLastUpdateTimestamp = time.milliseconds()) + + val transitMetadata = prepareSuccessfulIncrementProducerEpoch(txnMetadata, None) + txnMetadata.completeTransitionTo(transitMetadata) + assertEquals(producerId, txnMetadata.producerId) + assertEquals(0, txnMetadata.producerEpoch) + assertEquals(RecordBatch.NO_PRODUCER_EPOCH, txnMetadata.lastProducerEpoch) + } + + @Test + def testNormalEpochBump(): Unit = { + val producerEpoch = 735.toShort + + val txnMetadata = new TransactionMetadata( + transactionalId = transactionalId, + producerId = producerId, + lastProducerId = RecordBatch.NO_PRODUCER_ID, + producerEpoch = producerEpoch, + lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH, + txnTimeoutMs = 30000, + state = Empty, + topicPartitions = mutable.Set.empty, + txnLastUpdateTimestamp = time.milliseconds()) + + val transitMetadata = prepareSuccessfulIncrementProducerEpoch(txnMetadata, None) + txnMetadata.completeTransitionTo(transitMetadata) + assertEquals(producerId, txnMetadata.producerId) + assertEquals(producerEpoch + 1, txnMetadata.producerEpoch) + assertEquals(RecordBatch.NO_PRODUCER_EPOCH, txnMetadata.lastProducerEpoch) + } + + @Test + def testBumpEpochNotAllowedIfEpochsExhausted(): Unit = { + val producerEpoch = (Short.MaxValue - 1).toShort + + val txnMetadata = new TransactionMetadata( + transactionalId = transactionalId, + producerId = producerId, + lastProducerId = RecordBatch.NO_PRODUCER_ID, + producerEpoch = producerEpoch, + lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH, + txnTimeoutMs = 30000, + state = Empty, + topicPartitions = mutable.Set.empty, + txnLastUpdateTimestamp = time.milliseconds()) + assertTrue(txnMetadata.isProducerEpochExhausted) + + assertThrows(classOf[IllegalStateException], () => txnMetadata.prepareIncrementProducerEpoch(30000, + None, time.milliseconds())) + } + + @Test + def testTolerateUpdateTimeShiftDuringEpochBump(): Unit = { + val producerEpoch: Short = 1 + val txnMetadata = new TransactionMetadata( + transactionalId = transactionalId, + producerId = producerId, + lastProducerId = RecordBatch.NO_PRODUCER_ID, + producerEpoch = producerEpoch, + lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH, + txnTimeoutMs = 30000, + state = Empty, + topicPartitions = mutable.Set.empty, + txnStartTimestamp = 1L, + txnLastUpdateTimestamp = time.milliseconds()) + + // let new time be smaller + val transitMetadata = prepareSuccessfulIncrementProducerEpoch(txnMetadata, Option(producerEpoch), + Some(time.milliseconds() - 1)) + txnMetadata.completeTransitionTo(transitMetadata) + 
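+    // Even though the update time moved backwards, the epoch bump should still be applied
+    // and the previous epoch recorded as lastProducerEpoch.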
assertEquals(producerId, txnMetadata.producerId) + assertEquals(producerEpoch + 1, txnMetadata.producerEpoch) + assertEquals(producerEpoch, txnMetadata.lastProducerEpoch) + assertEquals(1L, txnMetadata.txnStartTimestamp) + assertEquals(time.milliseconds() - 1, txnMetadata.txnLastUpdateTimestamp) + } + + @Test + def testTolerateUpdateTimeResetDuringProducerIdRotation(): Unit = { + val producerEpoch: Short = 1 + val txnMetadata = new TransactionMetadata( + transactionalId = transactionalId, + producerId = producerId, + lastProducerId = RecordBatch.NO_PRODUCER_ID, + producerEpoch = producerEpoch, + lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH, + txnTimeoutMs = 30000, + state = Empty, + topicPartitions = mutable.Set.empty, + txnStartTimestamp = 1L, + txnLastUpdateTimestamp = time.milliseconds()) + + // let new time be smaller + val transitMetadata = txnMetadata.prepareProducerIdRotation(producerId + 1, 30000, time.milliseconds() - 1, recordLastEpoch = true) + txnMetadata.completeTransitionTo(transitMetadata) + assertEquals(producerId + 1, txnMetadata.producerId) + assertEquals(producerEpoch, txnMetadata.lastProducerEpoch) + assertEquals(0, txnMetadata.producerEpoch) + assertEquals(1L, txnMetadata.txnStartTimestamp) + assertEquals(time.milliseconds() - 1, txnMetadata.txnLastUpdateTimestamp) + } + + @Test + def testTolerateTimeShiftDuringAddPartitions(): Unit = { + val producerEpoch: Short = 1 + val txnMetadata = new TransactionMetadata( + transactionalId = transactionalId, + producerId = producerId, + lastProducerId = RecordBatch.NO_PRODUCER_ID, + producerEpoch = producerEpoch, + lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH, + txnTimeoutMs = 30000, + state = Empty, + topicPartitions = mutable.Set.empty, + txnStartTimestamp = time.milliseconds(), + txnLastUpdateTimestamp = time.milliseconds()) + + // let new time be smaller; when transiting from Empty the start time would be updated to the update-time + var transitMetadata = txnMetadata.prepareAddPartitions(Set[TopicPartition](new TopicPartition("topic1", 0)), time.milliseconds() - 1) + txnMetadata.completeTransitionTo(transitMetadata) + assertEquals(Set[TopicPartition](new TopicPartition("topic1", 0)), txnMetadata.topicPartitions) + assertEquals(producerId, txnMetadata.producerId) + assertEquals(RecordBatch.NO_PRODUCER_EPOCH, txnMetadata.lastProducerEpoch) + assertEquals(producerEpoch, txnMetadata.producerEpoch) + assertEquals(time.milliseconds() - 1, txnMetadata.txnStartTimestamp) + assertEquals(time.milliseconds() - 1, txnMetadata.txnLastUpdateTimestamp) + + // add another partition, check that in Ongoing state the start timestamp would not change to update time + transitMetadata = txnMetadata.prepareAddPartitions(Set[TopicPartition](new TopicPartition("topic2", 0)), time.milliseconds() - 2) + txnMetadata.completeTransitionTo(transitMetadata) + assertEquals(Set[TopicPartition](new TopicPartition("topic1", 0), new TopicPartition("topic2", 0)), txnMetadata.topicPartitions) + assertEquals(producerId, txnMetadata.producerId) + assertEquals(RecordBatch.NO_PRODUCER_EPOCH, txnMetadata.lastProducerEpoch) + assertEquals(producerEpoch, txnMetadata.producerEpoch) + assertEquals(time.milliseconds() - 1, txnMetadata.txnStartTimestamp) + assertEquals(time.milliseconds() - 2, txnMetadata.txnLastUpdateTimestamp) + } + + @Test + def testTolerateTimeShiftDuringPrepareCommit(): Unit = { + val producerEpoch: Short = 1 + val txnMetadata = new TransactionMetadata( + transactionalId = transactionalId, + producerId = producerId, + lastProducerId = 
RecordBatch.NO_PRODUCER_ID, + producerEpoch = producerEpoch, + lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH, + txnTimeoutMs = 30000, + state = Ongoing, + topicPartitions = mutable.Set.empty, + txnStartTimestamp = 1L, + txnLastUpdateTimestamp = time.milliseconds()) + + // let new time be smaller + val transitMetadata = txnMetadata.prepareAbortOrCommit(PrepareCommit, time.milliseconds() - 1) + txnMetadata.completeTransitionTo(transitMetadata) + assertEquals(PrepareCommit, txnMetadata.state) + assertEquals(producerId, txnMetadata.producerId) + assertEquals(RecordBatch.NO_PRODUCER_EPOCH, txnMetadata.lastProducerEpoch) + assertEquals(producerEpoch, txnMetadata.producerEpoch) + assertEquals(1L, txnMetadata.txnStartTimestamp) + assertEquals(time.milliseconds() - 1, txnMetadata.txnLastUpdateTimestamp) + } + + @Test + def testTolerateTimeShiftDuringPrepareAbort(): Unit = { + val producerEpoch: Short = 1 + val txnMetadata = new TransactionMetadata( + transactionalId = transactionalId, + producerId = producerId, + lastProducerId = RecordBatch.NO_PRODUCER_ID, + producerEpoch = producerEpoch, + lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH, + txnTimeoutMs = 30000, + state = Ongoing, + topicPartitions = mutable.Set.empty, + txnStartTimestamp = 1L, + txnLastUpdateTimestamp = time.milliseconds()) + + // let new time be smaller + val transitMetadata = txnMetadata.prepareAbortOrCommit(PrepareAbort, time.milliseconds() - 1) + txnMetadata.completeTransitionTo(transitMetadata) + assertEquals(PrepareAbort, txnMetadata.state) + assertEquals(producerId, txnMetadata.producerId) + assertEquals(RecordBatch.NO_PRODUCER_EPOCH, txnMetadata.lastProducerEpoch) + assertEquals(producerEpoch, txnMetadata.producerEpoch) + assertEquals(1L, txnMetadata.txnStartTimestamp) + assertEquals(time.milliseconds() - 1, txnMetadata.txnLastUpdateTimestamp) + } + + @Test + def testTolerateTimeShiftDuringCompleteCommit(): Unit = { + val producerEpoch: Short = 1 + val txnMetadata = new TransactionMetadata( + transactionalId = transactionalId, + producerId = producerId, + lastProducerId = RecordBatch.NO_PRODUCER_ID, + producerEpoch = producerEpoch, + lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH, + txnTimeoutMs = 30000, + state = PrepareCommit, + topicPartitions = mutable.Set.empty, + txnStartTimestamp = 1L, + txnLastUpdateTimestamp = time.milliseconds()) + + // let new time be smaller + val transitMetadata = txnMetadata.prepareComplete(time.milliseconds() - 1) + txnMetadata.completeTransitionTo(transitMetadata) + assertEquals(CompleteCommit, txnMetadata.state) + assertEquals(producerId, txnMetadata.producerId) + assertEquals(RecordBatch.NO_PRODUCER_EPOCH, txnMetadata.lastProducerEpoch) + assertEquals(producerEpoch, txnMetadata.producerEpoch) + assertEquals(1L, txnMetadata.txnStartTimestamp) + assertEquals(time.milliseconds() - 1, txnMetadata.txnLastUpdateTimestamp) + } + + @Test + def testTolerateTimeShiftDuringCompleteAbort(): Unit = { + val producerEpoch: Short = 1 + val txnMetadata = new TransactionMetadata( + transactionalId = transactionalId, + producerId = producerId, + lastProducerId = RecordBatch.NO_PRODUCER_ID, + producerEpoch = producerEpoch, + lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH, + txnTimeoutMs = 30000, + state = PrepareAbort, + topicPartitions = mutable.Set.empty, + txnStartTimestamp = 1L, + txnLastUpdateTimestamp = time.milliseconds()) + + // let new time be smaller + val transitMetadata = txnMetadata.prepareComplete(time.milliseconds() - 1) + txnMetadata.completeTransitionTo(transitMetadata) + 
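+    // The original start timestamp must survive the transition; only the last-update
+    // timestamp follows the (earlier) completion time.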
assertEquals(CompleteAbort, txnMetadata.state) + assertEquals(producerId, txnMetadata.producerId) + assertEquals(RecordBatch.NO_PRODUCER_EPOCH, txnMetadata.lastProducerEpoch) + assertEquals(producerEpoch, txnMetadata.producerEpoch) + assertEquals(1L, txnMetadata.txnStartTimestamp) + assertEquals(time.milliseconds() - 1, txnMetadata.txnLastUpdateTimestamp) + } + + @Test + def testFenceProducerAfterEpochsExhausted(): Unit = { + val producerEpoch = (Short.MaxValue - 1).toShort + + val txnMetadata = new TransactionMetadata( + transactionalId = transactionalId, + producerId = producerId, + lastProducerId = RecordBatch.NO_PRODUCER_ID, + producerEpoch = producerEpoch, + lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH, + txnTimeoutMs = 30000, + state = Ongoing, + topicPartitions = mutable.Set.empty, + txnLastUpdateTimestamp = time.milliseconds()) + assertTrue(txnMetadata.isProducerEpochExhausted) + + val fencingTransitMetadata = txnMetadata.prepareFenceProducerEpoch() + assertEquals(Short.MaxValue, fencingTransitMetadata.producerEpoch) + assertEquals(RecordBatch.NO_PRODUCER_EPOCH, fencingTransitMetadata.lastProducerEpoch) + assertEquals(Some(PrepareEpochFence), txnMetadata.pendingState) + + // We should reset the pending state to make way for the abort transition. + txnMetadata.pendingState = None + + val transitMetadata = txnMetadata.prepareAbortOrCommit(PrepareAbort, time.milliseconds()) + txnMetadata.completeTransitionTo(transitMetadata) + assertEquals(producerId, transitMetadata.producerId) + } + + @Test + def testFenceProducerNotAllowedIfItWouldOverflow(): Unit = { + val producerEpoch = Short.MaxValue + + val txnMetadata = new TransactionMetadata( + transactionalId = transactionalId, + producerId = producerId, + lastProducerId = RecordBatch.NO_PRODUCER_ID, + producerEpoch = producerEpoch, + lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH, + txnTimeoutMs = 30000, + state = Ongoing, + topicPartitions = mutable.Set.empty, + txnLastUpdateTimestamp = time.milliseconds()) + assertTrue(txnMetadata.isProducerEpochExhausted) + assertThrows(classOf[IllegalStateException], () => txnMetadata.prepareFenceProducerEpoch()) + } + + @Test + def testRotateProducerId(): Unit = { + val producerEpoch = (Short.MaxValue - 1).toShort + + val txnMetadata = new TransactionMetadata( + transactionalId = transactionalId, + producerId = producerId, + lastProducerId = RecordBatch.NO_PRODUCER_ID, + producerEpoch = producerEpoch, + lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH, + txnTimeoutMs = 30000, + state = Empty, + topicPartitions = mutable.Set.empty, + txnLastUpdateTimestamp = time.milliseconds()) + + val newProducerId = 9893L + val transitMetadata = txnMetadata.prepareProducerIdRotation(newProducerId, 30000, time.milliseconds(), recordLastEpoch = true) + txnMetadata.completeTransitionTo(transitMetadata) + assertEquals(newProducerId, txnMetadata.producerId) + assertEquals(producerId, txnMetadata.lastProducerId) + assertEquals(0, txnMetadata.producerEpoch) + assertEquals(producerEpoch, txnMetadata.lastProducerEpoch) + } + + @Test + def testRotateProducerIdInOngoingState(): Unit = { + assertThrows(classOf[IllegalStateException], () => testRotateProducerIdInOngoingState(Ongoing)) + } + + @Test + def testRotateProducerIdInPrepareAbortState(): Unit = { + assertThrows(classOf[IllegalStateException], () => testRotateProducerIdInOngoingState(PrepareAbort)) + } + + @Test + def testRotateProducerIdInPrepareCommitState(): Unit = { + assertThrows(classOf[IllegalStateException], () => 
testRotateProducerIdInOngoingState(PrepareCommit)) + } + + @Test + def testAttemptedEpochBumpWithNewlyCreatedMetadata(): Unit = { + val producerEpoch = 735.toShort + + val txnMetadata = new TransactionMetadata( + transactionalId = transactionalId, + producerId = producerId, + lastProducerId = RecordBatch.NO_PRODUCER_ID, + producerEpoch = RecordBatch.NO_PRODUCER_EPOCH, + lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH, + txnTimeoutMs = 30000, + state = Empty, + topicPartitions = mutable.Set.empty, + txnLastUpdateTimestamp = time.milliseconds()) + + val transitMetadata = prepareSuccessfulIncrementProducerEpoch(txnMetadata, Some(producerEpoch)) + txnMetadata.completeTransitionTo(transitMetadata) + assertEquals(producerId, txnMetadata.producerId) + assertEquals(0, txnMetadata.producerEpoch) + assertEquals(RecordBatch.NO_PRODUCER_EPOCH, txnMetadata.lastProducerEpoch) + } + + @Test + def testEpochBumpWithCurrentEpochProvided(): Unit = { + val producerEpoch = 735.toShort + + val txnMetadata = new TransactionMetadata( + transactionalId = transactionalId, + producerId = producerId, + lastProducerId = RecordBatch.NO_PRODUCER_ID, + producerEpoch = producerEpoch, + lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH, + txnTimeoutMs = 30000, + state = Empty, + topicPartitions = mutable.Set.empty, + txnLastUpdateTimestamp = time.milliseconds()) + + val transitMetadata = prepareSuccessfulIncrementProducerEpoch(txnMetadata, Some(producerEpoch)) + txnMetadata.completeTransitionTo(transitMetadata) + assertEquals(producerId, txnMetadata.producerId) + assertEquals(producerEpoch + 1, txnMetadata.producerEpoch) + assertEquals(producerEpoch, txnMetadata.lastProducerEpoch) + } + + @Test + def testAttemptedEpochBumpWithLastEpoch(): Unit = { + val producerEpoch = 735.toShort + val lastProducerEpoch = (producerEpoch - 1).toShort + + val txnMetadata = new TransactionMetadata( + transactionalId = transactionalId, + producerId = producerId, + lastProducerId = RecordBatch.NO_PRODUCER_ID, + producerEpoch = producerEpoch, + lastProducerEpoch = lastProducerEpoch, + txnTimeoutMs = 30000, + state = Empty, + topicPartitions = mutable.Set.empty, + txnLastUpdateTimestamp = time.milliseconds()) + + val transitMetadata = prepareSuccessfulIncrementProducerEpoch(txnMetadata, Some(lastProducerEpoch)) + txnMetadata.completeTransitionTo(transitMetadata) + assertEquals(producerId, txnMetadata.producerId) + assertEquals(producerEpoch, txnMetadata.producerEpoch) + assertEquals(lastProducerEpoch, txnMetadata.lastProducerEpoch) + } + + @Test + def testAttemptedEpochBumpWithFencedEpoch(): Unit = { + val producerEpoch = 735.toShort + val lastProducerEpoch = (producerEpoch - 1).toShort + + val txnMetadata = new TransactionMetadata( + transactionalId = transactionalId, + producerId = producerId, + lastProducerId = producerId, + producerEpoch = producerEpoch, + lastProducerEpoch = lastProducerEpoch, + txnTimeoutMs = 30000, + state = Empty, + topicPartitions = mutable.Set.empty, + txnLastUpdateTimestamp = time.milliseconds()) + + val result = txnMetadata.prepareIncrementProducerEpoch(30000, Some((lastProducerEpoch - 1).toShort), + time.milliseconds()) + assertEquals(Left(Errors.PRODUCER_FENCED), result) + } + + @Test + def testTransactionStateIdAndNameMapping(): Unit = { + for (state <- TransactionState.AllStates) { + assertEquals(state, TransactionState.fromId(state.id)) + assertEquals(Some(state), TransactionState.fromName(state.name)) + + if (state != Dead) { + val clientTransactionState = 
org.apache.kafka.clients.admin.TransactionState.parse(state.name) + assertEquals(state.name, clientTransactionState.toString) + assertNotEquals(org.apache.kafka.clients.admin.TransactionState.UNKNOWN, clientTransactionState) + } + } + } + + @Test + def testAllTransactionStatesAreMapped(): Unit = { + val unmatchedStates = mutable.Set( + Empty, + Ongoing, + PrepareCommit, + PrepareAbort, + CompleteCommit, + CompleteAbort, + PrepareEpochFence, + Dead + ) + + // The exhaustive match is intentional here to ensure that we are + // forced to update the test case if a new state is added. + TransactionState.AllStates.foreach { + case Empty => assertTrue(unmatchedStates.remove(Empty)) + case Ongoing => assertTrue(unmatchedStates.remove(Ongoing)) + case PrepareCommit => assertTrue(unmatchedStates.remove(PrepareCommit)) + case PrepareAbort => assertTrue(unmatchedStates.remove(PrepareAbort)) + case CompleteCommit => assertTrue(unmatchedStates.remove(CompleteCommit)) + case CompleteAbort => assertTrue(unmatchedStates.remove(CompleteAbort)) + case PrepareEpochFence => assertTrue(unmatchedStates.remove(PrepareEpochFence)) + case Dead => assertTrue(unmatchedStates.remove(Dead)) + } + + assertEquals(Set.empty, unmatchedStates) + } + + private def testRotateProducerIdInOngoingState(state: TransactionState): Unit = { + val producerEpoch = (Short.MaxValue - 1).toShort + + val txnMetadata = new TransactionMetadata( + transactionalId = transactionalId, + producerId = producerId, + lastProducerId = producerId, + producerEpoch = producerEpoch, + lastProducerEpoch = RecordBatch.NO_PRODUCER_EPOCH, + txnTimeoutMs = 30000, + state = state, + topicPartitions = mutable.Set.empty, + txnLastUpdateTimestamp = time.milliseconds()) + val newProducerId = 9893L + txnMetadata.prepareProducerIdRotation(newProducerId, 30000, time.milliseconds(), recordLastEpoch = false) + } + + private def prepareSuccessfulIncrementProducerEpoch(txnMetadata: TransactionMetadata, + expectedProducerEpoch: Option[Short], + now: Option[Long] = None): TxnTransitMetadata = { + val result = txnMetadata.prepareIncrementProducerEpoch(30000, expectedProducerEpoch, + now.getOrElse(time.milliseconds())) + result.getOrElse(throw new AssertionError(s"prepareIncrementProducerEpoch failed with $result")) + } + +} diff --git a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionStateManagerTest.scala b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionStateManagerTest.scala new file mode 100644 index 0000000..32e41cd --- /dev/null +++ b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionStateManagerTest.scala @@ -0,0 +1,1075 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package kafka.coordinator.transaction + +import java.lang.management.ManagementFactory +import java.nio.ByteBuffer +import java.util.concurrent.CountDownLatch +import java.util.concurrent.locks.ReentrantLock + +import javax.management.ObjectName +import kafka.log.{AppendOrigin, Defaults, UnifiedLog, LogConfig} +import kafka.server.{FetchDataInfo, FetchLogEnd, LogOffsetMetadata, ReplicaManager, RequestLocal} +import kafka.utils.{MockScheduler, Pool, TestUtils} +import kafka.zk.KafkaZkClient +import org.apache.kafka.common.TopicPartition +import org.apache.kafka.common.internals.Topic.TRANSACTION_STATE_TOPIC_NAME +import org.apache.kafka.common.metrics.{JmxReporter, KafkaMetricsContext, Metrics} +import org.apache.kafka.common.protocol.Errors +import org.apache.kafka.common.record._ +import org.apache.kafka.common.requests.ProduceResponse.PartitionResponse +import org.apache.kafka.common.requests.TransactionResult +import org.apache.kafka.common.utils.MockTime +import org.easymock.{Capture, EasyMock, IAnswer} +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.{AfterEach, BeforeEach, Test} + +import scala.collection.{Map, mutable} +import scala.jdk.CollectionConverters._ + +class TransactionStateManagerTest { + + val partitionId = 0 + val numPartitions = 2 + val transactionTimeoutMs: Int = 1000 + val topicPartition = new TopicPartition(TRANSACTION_STATE_TOPIC_NAME, partitionId) + val coordinatorEpoch = 10 + + val txnRecords: mutable.ArrayBuffer[SimpleRecord] = mutable.ArrayBuffer[SimpleRecord]() + + val time = new MockTime() + val scheduler = new MockScheduler(time) + val zkClient: KafkaZkClient = EasyMock.createNiceMock(classOf[KafkaZkClient]) + val replicaManager: ReplicaManager = EasyMock.createNiceMock(classOf[ReplicaManager]) + + EasyMock.expect(zkClient.getTopicPartitionCount(TRANSACTION_STATE_TOPIC_NAME)) + .andReturn(Some(numPartitions)) + .anyTimes() + + EasyMock.replay(zkClient) + val metrics = new Metrics() + + val txnConfig = TransactionConfig() + val transactionManager: TransactionStateManager = new TransactionStateManager(0, scheduler, + replicaManager, txnConfig, time, metrics) + + val transactionalId1: String = "one" + val transactionalId2: String = "two" + val txnMessageKeyBytes1: Array[Byte] = TransactionLog.keyToBytes(transactionalId1) + val txnMessageKeyBytes2: Array[Byte] = TransactionLog.keyToBytes(transactionalId2) + val producerIds: Map[String, Long] = Map[String, Long](transactionalId1 -> 1L, transactionalId2 -> 2L) + var txnMetadata1: TransactionMetadata = transactionMetadata(transactionalId1, producerIds(transactionalId1)) + var txnMetadata2: TransactionMetadata = transactionMetadata(transactionalId2, producerIds(transactionalId2)) + + var expectedError: Errors = Errors.NONE + + @BeforeEach + def setUp(): Unit = { + transactionManager.startup(() => numPartitions, enableTransactionalIdExpiration = false) + // make sure the transactional id hashes to the assigning partition id + assertEquals(partitionId, transactionManager.partitionFor(transactionalId1)) + assertEquals(partitionId, transactionManager.partitionFor(transactionalId2)) + } + + @AfterEach + def tearDown(): Unit = { + EasyMock.reset(zkClient, replicaManager) + transactionManager.shutdown() + } + + @Test + def testValidateTransactionTimeout(): Unit = { + assertTrue(transactionManager.validateTransactionTimeoutMs(1)) + assertFalse(transactionManager.validateTransactionTimeoutMs(-1)) + assertFalse(transactionManager.validateTransactionTimeoutMs(0)) + 
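+    // The configured maximum is accepted; anything beyond it is rejected.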
assertTrue(transactionManager.validateTransactionTimeoutMs(txnConfig.transactionMaxTimeoutMs)) + assertFalse(transactionManager.validateTransactionTimeoutMs(txnConfig.transactionMaxTimeoutMs + 1)) + } + + @Test + def testAddGetPids(): Unit = { + transactionManager.addLoadedTransactionsToCache(partitionId, coordinatorEpoch, new Pool[String, TransactionMetadata]()) + + assertEquals(Right(None), transactionManager.getTransactionState(transactionalId1)) + assertEquals(Right(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata1)), + transactionManager.putTransactionStateIfNotExists(txnMetadata1)) + assertEquals(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata1))), + transactionManager.getTransactionState(transactionalId1)) + assertEquals(Right(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata2)), + transactionManager.putTransactionStateIfNotExists(txnMetadata2)) + } + + @Test + def testDeletePartition(): Unit = { + val metadata1 = transactionMetadata("b", 5L) + val metadata2 = transactionMetadata("a", 10L) + + assertEquals(0, transactionManager.partitionFor(metadata1.transactionalId)) + assertEquals(1, transactionManager.partitionFor(metadata2.transactionalId)) + + transactionManager.addLoadedTransactionsToCache(0, coordinatorEpoch, new Pool[String, TransactionMetadata]()) + transactionManager.putTransactionStateIfNotExists(metadata1) + + transactionManager.addLoadedTransactionsToCache(1, coordinatorEpoch, new Pool[String, TransactionMetadata]()) + transactionManager.putTransactionStateIfNotExists(metadata2) + + def cachedProducerEpoch(transactionalId: String): Option[Short] = { + transactionManager.getTransactionState(transactionalId).toOption.flatten + .map(_.transactionMetadata.producerEpoch) + } + + assertEquals(Some(metadata1.producerEpoch), cachedProducerEpoch(metadata1.transactionalId)) + assertEquals(Some(metadata2.producerEpoch), cachedProducerEpoch(metadata2.transactionalId)) + + transactionManager.removeTransactionsForTxnTopicPartition(0) + + assertEquals(None, cachedProducerEpoch(metadata1.transactionalId)) + assertEquals(Some(metadata2.producerEpoch), cachedProducerEpoch(metadata2.transactionalId)) + } + + @Test + def testDeleteLoadingPartition(): Unit = { + // Verify the handling of a call to delete state for a partition while it is in the + // process of being loaded. Basically should be treated as a no-op. + + val startOffset = 0L + val endOffset = 1L + + val fileRecordsMock = EasyMock.mock[FileRecords](classOf[FileRecords]) + val logMock = EasyMock.mock[UnifiedLog](classOf[UnifiedLog]) + EasyMock.expect(replicaManager.getLog(topicPartition)).andStubReturn(Some(logMock)) + EasyMock.expect(logMock.logStartOffset).andStubReturn(startOffset) + EasyMock.expect(logMock.read(EasyMock.eq(startOffset), + maxLength = EasyMock.anyInt(), + isolation = EasyMock.eq(FetchLogEnd), + minOneMessage = EasyMock.eq(true)) + ).andReturn(FetchDataInfo(LogOffsetMetadata(startOffset), fileRecordsMock)) + EasyMock.expect(replicaManager.getLogEndOffset(topicPartition)).andStubReturn(Some(endOffset)) + + txnMetadata1.state = PrepareCommit + txnMetadata1.addPartitions(Set[TopicPartition]( + new TopicPartition("topic1", 0), + new TopicPartition("topic1", 1))) + val records = MemoryRecords.withRecords(startOffset, CompressionType.NONE, + new SimpleRecord(txnMessageKeyBytes1, TransactionLog.valueToBytes(txnMetadata1.prepareNoTransit()))) + + // We create a latch which is awaited while the log is loading. 
This ensures that the deletion + // is triggered before the loading returns + val latch = new CountDownLatch(1) + + EasyMock.expect(fileRecordsMock.sizeInBytes()).andStubReturn(records.sizeInBytes) + val bufferCapture = EasyMock.newCapture[ByteBuffer] + fileRecordsMock.readInto(EasyMock.capture(bufferCapture), EasyMock.anyInt()) + EasyMock.expectLastCall().andAnswer(new IAnswer[Unit] { + override def answer: Unit = { + latch.await() + val buffer = bufferCapture.getValue + buffer.put(records.buffer.duplicate) + buffer.flip() + } + }) + + EasyMock.replay(logMock, fileRecordsMock, replicaManager) + + val coordinatorEpoch = 0 + val partitionAndLeaderEpoch = TransactionPartitionAndLeaderEpoch(partitionId, coordinatorEpoch) + + val loadingThread = new Thread(() => { + transactionManager.loadTransactionsForTxnTopicPartition(partitionId, coordinatorEpoch, (_, _, _, _) => ()) + }) + loadingThread.start() + TestUtils.waitUntilTrue(() => transactionManager.loadingPartitions.contains(partitionAndLeaderEpoch), + "Timed out waiting for loading partition", pause = 10) + + transactionManager.removeTransactionsForTxnTopicPartition(partitionId) + assertFalse(transactionManager.loadingPartitions.contains(partitionAndLeaderEpoch)) + + latch.countDown() + loadingThread.join() + + // Verify that transaction state was not loaded + assertEquals(Left(Errors.NOT_COORDINATOR), transactionManager.getTransactionState(txnMetadata1.transactionalId)) + } + + @Test + def testLoadAndRemoveTransactionsForPartition(): Unit = { + // generate transaction log messages for two pids traces: + + // pid1's transaction started with two partitions + txnMetadata1.state = Ongoing + txnMetadata1.addPartitions(Set[TopicPartition](new TopicPartition("topic1", 0), + new TopicPartition("topic1", 1))) + + txnRecords += new SimpleRecord(txnMessageKeyBytes1, TransactionLog.valueToBytes(txnMetadata1.prepareNoTransit())) + + // pid1's transaction adds three more partitions + txnMetadata1.addPartitions(Set[TopicPartition](new TopicPartition("topic2", 0), + new TopicPartition("topic2", 1), + new TopicPartition("topic2", 2))) + + txnRecords += new SimpleRecord(txnMessageKeyBytes1, TransactionLog.valueToBytes(txnMetadata1.prepareNoTransit())) + + // pid1's transaction is preparing to commit + txnMetadata1.state = PrepareCommit + + txnRecords += new SimpleRecord(txnMessageKeyBytes1, TransactionLog.valueToBytes(txnMetadata1.prepareNoTransit())) + + // pid2's transaction started with three partitions + txnMetadata2.state = Ongoing + txnMetadata2.addPartitions(Set[TopicPartition](new TopicPartition("topic3", 0), + new TopicPartition("topic3", 1), + new TopicPartition("topic3", 2))) + + txnRecords += new SimpleRecord(txnMessageKeyBytes2, TransactionLog.valueToBytes(txnMetadata2.prepareNoTransit())) + + // pid2's transaction is preparing to abort + txnMetadata2.state = PrepareAbort + + txnRecords += new SimpleRecord(txnMessageKeyBytes2, TransactionLog.valueToBytes(txnMetadata2.prepareNoTransit())) + + // pid2's transaction has aborted + txnMetadata2.state = CompleteAbort + + txnRecords += new SimpleRecord(txnMessageKeyBytes2, TransactionLog.valueToBytes(txnMetadata2.prepareNoTransit())) + + // pid2's epoch has advanced, with no ongoing transaction yet + txnMetadata2.state = Empty + txnMetadata2.topicPartitions.clear() + + txnRecords += new SimpleRecord(txnMessageKeyBytes2, TransactionLog.valueToBytes(txnMetadata2.prepareNoTransit())) + + val startOffset = 15L // it should work for any start offset + val records = MemoryRecords.withRecords(startOffset, 
CompressionType.NONE, txnRecords.toArray: _*) + + prepareTxnLog(topicPartition, startOffset, records) + + // this partition should not be part of the owned partitions + transactionManager.getTransactionState(transactionalId1).fold( + err => assertEquals(Errors.NOT_COORDINATOR, err), + _ => fail(transactionalId1 + "'s transaction state is already in the cache") + ) + transactionManager.getTransactionState(transactionalId2).fold( + err => assertEquals(Errors.NOT_COORDINATOR, err), + _ => fail(transactionalId2 + "'s transaction state is already in the cache") + ) + + transactionManager.loadTransactionsForTxnTopicPartition(partitionId, 0, (_, _, _, _) => ()) + + // let the time advance to trigger the background thread loading + scheduler.tick() + + transactionManager.getTransactionState(transactionalId1).fold( + err => fail(transactionalId1 + "'s transaction state access returns error " + err), + entry => entry.getOrElse(fail(transactionalId1 + "'s transaction state was not loaded into the cache")) + ) + + val cachedPidMetadata1 = transactionManager.getTransactionState(transactionalId1).fold( + err => throw new AssertionError(transactionalId1 + "'s transaction state access returns error " + err), + entry => entry.getOrElse(throw new AssertionError(transactionalId1 + "'s transaction state was not loaded into the cache")) + ) + val cachedPidMetadata2 = transactionManager.getTransactionState(transactionalId2).fold( + err => throw new AssertionError(transactionalId2 + "'s transaction state access returns error " + err), + entry => entry.getOrElse(throw new AssertionError(transactionalId2 + "'s transaction state was not loaded into the cache")) + ) + + // they should be equal to the latest status of the transaction + assertEquals(txnMetadata1, cachedPidMetadata1.transactionMetadata) + assertEquals(txnMetadata2, cachedPidMetadata2.transactionMetadata) + + transactionManager.removeTransactionsForTxnTopicPartition(partitionId, coordinatorEpoch) + + // let the time advance to trigger the background thread removing + scheduler.tick() + + transactionManager.getTransactionState(transactionalId1).fold( + err => assertEquals(Errors.NOT_COORDINATOR, err), + _ => fail(transactionalId1 + "'s transaction state is still in the cache") + ) + transactionManager.getTransactionState(transactionalId2).fold( + err => assertEquals(Errors.NOT_COORDINATOR, err), + _ => fail(transactionalId2 + "'s transaction state is still in the cache") + ) + } + + @Test + def testCompleteTransitionWhenAppendSucceeded(): Unit = { + transactionManager.addLoadedTransactionsToCache(partitionId, coordinatorEpoch, new Pool[String, TransactionMetadata]()) + + // first insert the initial transaction metadata + transactionManager.putTransactionStateIfNotExists(txnMetadata1) + + prepareForTxnMessageAppend(Errors.NONE) + expectedError = Errors.NONE + + // update the metadata to ongoing with two partitions + val newMetadata = txnMetadata1.prepareAddPartitions(Set[TopicPartition](new TopicPartition("topic1", 0), + new TopicPartition("topic1", 1)), time.milliseconds()) + + // append the new metadata into log + transactionManager.appendTransactionToLog(transactionalId1, coordinatorEpoch, newMetadata, assertCallback, requestLocal = RequestLocal.withThreadConfinedCaching) + + assertEquals(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata1))), transactionManager.getTransactionState(transactionalId1)) + assertTrue(txnMetadata1.pendingState.isEmpty) + } + + @Test + def testAppendFailToCoordinatorNotAvailableError(): Unit = { + 
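+    // Each retriable append error exercised below should be surfaced to the caller as
+    // COORDINATOR_NOT_AVAILABLE and must leave no pending state behind.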
transactionManager.addLoadedTransactionsToCache(partitionId, coordinatorEpoch, new Pool[String, TransactionMetadata]()) + transactionManager.putTransactionStateIfNotExists(txnMetadata1) + + expectedError = Errors.COORDINATOR_NOT_AVAILABLE + var failedMetadata = txnMetadata1.prepareAddPartitions(Set[TopicPartition](new TopicPartition("topic2", 0)), time.milliseconds()) + + prepareForTxnMessageAppend(Errors.UNKNOWN_TOPIC_OR_PARTITION) + val requestLocal = RequestLocal.withThreadConfinedCaching + transactionManager.appendTransactionToLog(transactionalId1, coordinatorEpoch = 10, failedMetadata, assertCallback, requestLocal = requestLocal) + assertEquals(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata1))), transactionManager.getTransactionState(transactionalId1)) + assertTrue(txnMetadata1.pendingState.isEmpty) + + failedMetadata = txnMetadata1.prepareAddPartitions(Set[TopicPartition](new TopicPartition("topic2", 0)), time.milliseconds()) + prepareForTxnMessageAppend(Errors.NOT_ENOUGH_REPLICAS) + transactionManager.appendTransactionToLog(transactionalId1, coordinatorEpoch = 10, failedMetadata, assertCallback, requestLocal = requestLocal) + assertEquals(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata1))), transactionManager.getTransactionState(transactionalId1)) + assertTrue(txnMetadata1.pendingState.isEmpty) + + failedMetadata = txnMetadata1.prepareAddPartitions(Set[TopicPartition](new TopicPartition("topic2", 0)), time.milliseconds()) + prepareForTxnMessageAppend(Errors.NOT_ENOUGH_REPLICAS_AFTER_APPEND) + transactionManager.appendTransactionToLog(transactionalId1, coordinatorEpoch = 10, failedMetadata, assertCallback, requestLocal = requestLocal) + assertEquals(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata1))), transactionManager.getTransactionState(transactionalId1)) + assertTrue(txnMetadata1.pendingState.isEmpty) + + failedMetadata = txnMetadata1.prepareAddPartitions(Set[TopicPartition](new TopicPartition("topic2", 0)), time.milliseconds()) + prepareForTxnMessageAppend(Errors.REQUEST_TIMED_OUT) + transactionManager.appendTransactionToLog(transactionalId1, coordinatorEpoch = 10, failedMetadata, assertCallback, requestLocal = requestLocal) + assertEquals(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata1))), transactionManager.getTransactionState(transactionalId1)) + assertTrue(txnMetadata1.pendingState.isEmpty) + } + + @Test + def testAppendFailToNotCoordinatorError(): Unit = { + transactionManager.addLoadedTransactionsToCache(partitionId, coordinatorEpoch, new Pool[String, TransactionMetadata]()) + transactionManager.putTransactionStateIfNotExists(txnMetadata1) + + expectedError = Errors.NOT_COORDINATOR + var failedMetadata = txnMetadata1.prepareAddPartitions(Set[TopicPartition](new TopicPartition("topic2", 0)), time.milliseconds()) + + prepareForTxnMessageAppend(Errors.NOT_LEADER_OR_FOLLOWER) + val requestLocal = RequestLocal.withThreadConfinedCaching + transactionManager.appendTransactionToLog(transactionalId1, coordinatorEpoch = 10, failedMetadata, assertCallback, requestLocal = requestLocal) + assertEquals(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata1))), transactionManager.getTransactionState(transactionalId1)) + assertTrue(txnMetadata1.pendingState.isEmpty) + + failedMetadata = txnMetadata1.prepareAddPartitions(Set[TopicPartition](new TopicPartition("topic2", 0)), time.milliseconds()) + prepareForTxnMessageAppend(Errors.NONE) + 
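+    // Emulate emigration of the partition before the append completes; the callback
+    // should observe NOT_COORDINATOR.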
transactionManager.removeTransactionsForTxnTopicPartition(partitionId, coordinatorEpoch) + transactionManager.appendTransactionToLog(transactionalId1, coordinatorEpoch = 10, failedMetadata, assertCallback, requestLocal = requestLocal) + + prepareForTxnMessageAppend(Errors.NONE) + transactionManager.removeTransactionsForTxnTopicPartition(partitionId, coordinatorEpoch) + transactionManager.addLoadedTransactionsToCache(partitionId, coordinatorEpoch + 1, new Pool[String, TransactionMetadata]()) + transactionManager.putTransactionStateIfNotExists(txnMetadata1) + transactionManager.appendTransactionToLog(transactionalId1, coordinatorEpoch = 10, failedMetadata, assertCallback, requestLocal = requestLocal) + + prepareForTxnMessageAppend(Errors.NONE) + transactionManager.removeTransactionsForTxnTopicPartition(partitionId, coordinatorEpoch) + transactionManager.addLoadedTransactionsToCache(partitionId, coordinatorEpoch, new Pool[String, TransactionMetadata]()) + transactionManager.appendTransactionToLog(transactionalId1, coordinatorEpoch = 10, failedMetadata, assertCallback, requestLocal = requestLocal) + } + + @Test + def testAppendFailToCoordinatorLoadingError(): Unit = { + transactionManager.addLoadedTransactionsToCache(partitionId, coordinatorEpoch, new Pool[String, TransactionMetadata]()) + transactionManager.putTransactionStateIfNotExists(txnMetadata1) + + expectedError = Errors.COORDINATOR_LOAD_IN_PROGRESS + val failedMetadata = txnMetadata1.prepareAddPartitions(Set[TopicPartition](new TopicPartition("topic2", 0)), time.milliseconds()) + + prepareForTxnMessageAppend(Errors.NONE) + transactionManager.removeTransactionsForTxnTopicPartition(partitionId, coordinatorEpoch) + transactionManager.addLoadingPartition(partitionId, coordinatorEpoch + 1) + transactionManager.appendTransactionToLog(transactionalId1, coordinatorEpoch = 10, failedMetadata, assertCallback, requestLocal = RequestLocal.withThreadConfinedCaching) + } + + @Test + def testAppendFailToUnknownError(): Unit = { + transactionManager.addLoadedTransactionsToCache(partitionId, coordinatorEpoch, new Pool[String, TransactionMetadata]()) + transactionManager.putTransactionStateIfNotExists(txnMetadata1) + + expectedError = Errors.UNKNOWN_SERVER_ERROR + var failedMetadata = txnMetadata1.prepareAddPartitions(Set[TopicPartition](new TopicPartition("topic2", 0)), time.milliseconds()) + + prepareForTxnMessageAppend(Errors.MESSAGE_TOO_LARGE) + val requestLocal = RequestLocal.withThreadConfinedCaching + transactionManager.appendTransactionToLog(transactionalId1, coordinatorEpoch = 10, failedMetadata, assertCallback, requestLocal = requestLocal) + assertEquals(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata1))), transactionManager.getTransactionState(transactionalId1)) + assertTrue(txnMetadata1.pendingState.isEmpty) + + failedMetadata = txnMetadata1.prepareAddPartitions(Set[TopicPartition](new TopicPartition("topic2", 0)), time.milliseconds()) + prepareForTxnMessageAppend(Errors.RECORD_LIST_TOO_LARGE) + transactionManager.appendTransactionToLog(transactionalId1, coordinatorEpoch = 10, failedMetadata, assertCallback, requestLocal = requestLocal) + assertEquals(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata1))), transactionManager.getTransactionState(transactionalId1)) + assertTrue(txnMetadata1.pendingState.isEmpty) + } + + @Test + def testPendingStateNotResetOnRetryAppend(): Unit = { + transactionManager.addLoadedTransactionsToCache(partitionId, coordinatorEpoch, new Pool[String, 
TransactionMetadata]()) + transactionManager.putTransactionStateIfNotExists(txnMetadata1) + + expectedError = Errors.COORDINATOR_NOT_AVAILABLE + val failedMetadata = txnMetadata1.prepareAddPartitions(Set[TopicPartition](new TopicPartition("topic2", 0)), time.milliseconds()) + + prepareForTxnMessageAppend(Errors.UNKNOWN_TOPIC_OR_PARTITION) + transactionManager.appendTransactionToLog(transactionalId1, coordinatorEpoch = 10, failedMetadata, assertCallback, _ => true, RequestLocal.withThreadConfinedCaching) + assertEquals(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, txnMetadata1))), transactionManager.getTransactionState(transactionalId1)) + assertEquals(Some(Ongoing), txnMetadata1.pendingState) + } + + @Test + def testAppendTransactionToLogWhileProducerFenced(): Unit = { + transactionManager.addLoadedTransactionsToCache(partitionId, 0, new Pool[String, TransactionMetadata]()) + + // first insert the initial transaction metadata + transactionManager.putTransactionStateIfNotExists(txnMetadata1) + + prepareForTxnMessageAppend(Errors.NONE) + expectedError = Errors.NOT_COORDINATOR + + val newMetadata = txnMetadata1.prepareAddPartitions(Set[TopicPartition](new TopicPartition("topic1", 0), + new TopicPartition("topic1", 1)), time.milliseconds()) + + // modify the cache while trying to append the new metadata + txnMetadata1.producerEpoch = (txnMetadata1.producerEpoch + 1).toShort + + // append the new metadata into log + transactionManager.appendTransactionToLog(transactionalId1, coordinatorEpoch = 10, newMetadata, assertCallback, requestLocal = RequestLocal.withThreadConfinedCaching) + } + + @Test + def testAppendTransactionToLogWhilePendingStateChanged(): Unit = { + // first insert the initial transaction metadata + transactionManager.addLoadedTransactionsToCache(partitionId, coordinatorEpoch, new Pool[String, TransactionMetadata]()) + transactionManager.putTransactionStateIfNotExists(txnMetadata1) + + prepareForTxnMessageAppend(Errors.NONE) + expectedError = Errors.INVALID_PRODUCER_EPOCH + + val newMetadata = txnMetadata1.prepareAddPartitions(Set[TopicPartition](new TopicPartition("topic1", 0), + new TopicPartition("topic1", 1)), time.milliseconds()) + + // modify the cache while trying to append the new metadata + txnMetadata1.pendingState = None + + // append the new metadata into log + assertThrows(classOf[IllegalStateException], () => transactionManager.appendTransactionToLog(transactionalId1, + coordinatorEpoch = 10, newMetadata, assertCallback, requestLocal = RequestLocal.withThreadConfinedCaching)) + } + + @Test + def shouldReturnNotCoordinatorErrorIfTransactionIdPartitionNotOwned(): Unit = { + transactionManager.getTransactionState(transactionalId1).fold( + err => assertEquals(Errors.NOT_COORDINATOR, err), + _ => fail(transactionalId1 + "'s transaction state is already in the cache") + ) + } + + @Test + def testListTransactionsWithCoordinatorLoadingInProgress(): Unit = { + transactionManager.addLoadingPartition(partitionId = 0, coordinatorEpoch = 15) + val listResponse = transactionManager.listTransactionStates( + filterProducerIds = Set.empty, + filterStateNames = Set.empty + ) + assertEquals(Errors.COORDINATOR_LOAD_IN_PROGRESS, Errors.forCode(listResponse.errorCode)) + } + + @Test + def testListTransactionsFiltering(): Unit = { + for (partitionId <- 0 until numPartitions) { + transactionManager.addLoadedTransactionsToCache(partitionId, 0, new Pool[String, TransactionMetadata]()) + } + + def putTransaction( + transactionalId: String, + producerId: Long, + state: 
TransactionState + ): Unit = { + val txnMetadata = transactionMetadata(transactionalId, producerId, state) + transactionManager.putTransactionStateIfNotExists(txnMetadata).left.toOption.foreach { error => + fail(s"Failed to insert transaction $txnMetadata due to error $error") + } + } + + putTransaction(transactionalId = "t0", producerId = 0, state = Ongoing) + putTransaction(transactionalId = "t1", producerId = 1, state = Ongoing) + putTransaction(transactionalId = "t2", producerId = 2, state = PrepareCommit) + putTransaction(transactionalId = "t3", producerId = 3, state = PrepareAbort) + putTransaction(transactionalId = "t4", producerId = 4, state = CompleteCommit) + putTransaction(transactionalId = "t5", producerId = 5, state = CompleteAbort) + putTransaction(transactionalId = "t6", producerId = 6, state = CompleteAbort) + putTransaction(transactionalId = "t7", producerId = 7, state = PrepareEpochFence) + // Note that `Dead` transactions are never returned. This is a transient state + // which is used when the transaction state is in the process of being deleted + // (whether though expiration or coordinator unloading). + putTransaction(transactionalId = "t8", producerId = 8, state = Dead) + + def assertListTransactions( + expectedTransactionalIds: Set[String], + filterProducerIds: Set[Long] = Set.empty, + filterStates: Set[String] = Set.empty + ): Unit = { + val listResponse = transactionManager.listTransactionStates(filterProducerIds, filterStates) + assertEquals(Errors.NONE, Errors.forCode(listResponse.errorCode)) + assertEquals(expectedTransactionalIds, listResponse.transactionStates.asScala.map(_.transactionalId).toSet) + val expectedUnknownStates = filterStates.filter(state => TransactionState.fromName(state).isEmpty) + assertEquals(expectedUnknownStates, listResponse.unknownStateFilters.asScala.toSet) + } + + assertListTransactions(Set("t0", "t1", "t2", "t3", "t4", "t5", "t6", "t7")) + assertListTransactions(Set("t0", "t1"), filterStates = Set("Ongoing")) + assertListTransactions(Set("t0", "t1"), filterStates = Set("Ongoing", "UnknownState")) + assertListTransactions(Set("t2", "t4"), filterStates = Set("PrepareCommit", "CompleteCommit")) + assertListTransactions(Set(), filterStates = Set("UnknownState")) + assertListTransactions(Set("t5"), filterProducerIds = Set(5L)) + assertListTransactions(Set("t5", "t6"), filterProducerIds = Set(5L, 6L, 8L, 9L)) + assertListTransactions(Set("t4"), filterProducerIds = Set(4L, 5L), filterStates = Set("CompleteCommit")) + assertListTransactions(Set("t4", "t5"), filterProducerIds = Set(4L, 5L), filterStates = Set("CompleteCommit", "CompleteAbort")) + assertListTransactions(Set(), filterProducerIds = Set(3L, 6L), filterStates = Set("UnknownState")) + assertListTransactions(Set(), filterProducerIds = Set(10L), filterStates = Set("CompleteCommit")) + assertListTransactions(Set(), filterStates = Set("Dead")) + } + + @Test + def shouldOnlyConsiderTransactionsInTheOngoingStateToAbort(): Unit = { + for (partitionId <- 0 until numPartitions) { + transactionManager.addLoadedTransactionsToCache(partitionId, 0, new Pool[String, TransactionMetadata]()) + } + + transactionManager.putTransactionStateIfNotExists(transactionMetadata("ongoing", producerId = 0, state = Ongoing)) + transactionManager.putTransactionStateIfNotExists(transactionMetadata("not-expiring", producerId = 1, state = Ongoing, txnTimeout = 10000)) + transactionManager.putTransactionStateIfNotExists(transactionMetadata("prepare-commit", producerId = 2, state = PrepareCommit)) + 
transactionManager.putTransactionStateIfNotExists(transactionMetadata("prepare-abort", producerId = 3, state = PrepareAbort)) + transactionManager.putTransactionStateIfNotExists(transactionMetadata("complete-commit", producerId = 4, state = CompleteCommit)) + transactionManager.putTransactionStateIfNotExists(transactionMetadata("complete-abort", producerId = 5, state = CompleteAbort)) + + time.sleep(2000) + val expiring = transactionManager.timedOutTransactions() + assertEquals(List(TransactionalIdAndProducerIdEpoch("ongoing", 0, 0)), expiring) + } + + @Test + def shouldWriteTxnMarkersForTransactionInPreparedCommitState(): Unit = { + verifyWritesTxnMarkersInPrepareState(PrepareCommit) + } + + @Test + def shouldWriteTxnMarkersForTransactionInPreparedAbortState(): Unit = { + verifyWritesTxnMarkersInPrepareState(PrepareAbort) + } + + @Test + def shouldRemoveCompleteCommitExpiredTransactionalIds(): Unit = { + setupAndRunTransactionalIdExpiration(Errors.NONE, CompleteCommit) + verifyMetadataDoesntExist(transactionalId1) + verifyMetadataDoesExistAndIsUsable(transactionalId2) + } + + @Test + def shouldRemoveCompleteAbortExpiredTransactionalIds(): Unit = { + setupAndRunTransactionalIdExpiration(Errors.NONE, CompleteAbort) + verifyMetadataDoesntExist(transactionalId1) + verifyMetadataDoesExistAndIsUsable(transactionalId2) + } + + @Test + def shouldRemoveEmptyExpiredTransactionalIds(): Unit = { + setupAndRunTransactionalIdExpiration(Errors.NONE, Empty) + verifyMetadataDoesntExist(transactionalId1) + verifyMetadataDoesExistAndIsUsable(transactionalId2) + } + + @Test + def shouldNotRemoveExpiredTransactionalIdsIfLogAppendFails(): Unit = { + setupAndRunTransactionalIdExpiration(Errors.NOT_ENOUGH_REPLICAS, CompleteAbort) + verifyMetadataDoesExistAndIsUsable(transactionalId1) + verifyMetadataDoesExistAndIsUsable(transactionalId2) + } + + @Test + def shouldNotRemoveOngoingTransactionalIds(): Unit = { + setupAndRunTransactionalIdExpiration(Errors.NONE, Ongoing) + verifyMetadataDoesExistAndIsUsable(transactionalId1) + verifyMetadataDoesExistAndIsUsable(transactionalId2) + } + + @Test + def shouldNotRemovePrepareAbortTransactionalIds(): Unit = { + setupAndRunTransactionalIdExpiration(Errors.NONE, PrepareAbort) + verifyMetadataDoesExistAndIsUsable(transactionalId1) + verifyMetadataDoesExistAndIsUsable(transactionalId2) + } + + @Test + def shouldNotRemovePrepareCommitTransactionalIds(): Unit = { + setupAndRunTransactionalIdExpiration(Errors.NONE, PrepareCommit) + verifyMetadataDoesExistAndIsUsable(transactionalId1) + verifyMetadataDoesExistAndIsUsable(transactionalId2) + } + + @Test + def testTransactionalExpirationWithTooSmallBatchSize(): Unit = { + // The batch size is too small, but we nevertheless expect the + // coordinator to attempt the append. This test mainly ensures + // that the expiration task does not get stuck. 
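+    // After the failed (too large) append, all ids should still be reported as expirable.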
+
+    val partitionIds = 0 until numPartitions
+    val maxBatchSize = 16
+
+    loadTransactionsForPartitions(partitionIds)
+    val allTransactionalIds = loadExpiredTransactionalIds(numTransactionalIds = 20)
+
+    EasyMock.reset(replicaManager)
+    expectLogConfig(partitionIds, maxBatchSize)
+
+    val attemptedAppends = mutable.Map.empty[TopicPartition, mutable.Buffer[MemoryRecords]]
+    expectTransactionalIdExpiration(Errors.MESSAGE_TOO_LARGE, attemptedAppends)
+    EasyMock.replay(replicaManager)
+
+    assertEquals(allTransactionalIds, listExpirableTransactionalIds())
+    transactionManager.removeExpiredTransactionalIds()
+    EasyMock.verify(replicaManager)
+
+    for (batches <- attemptedAppends.values; batch <- batches) {
+      assertTrue(batch.sizeInBytes() > maxBatchSize)
+    }
+
+    assertEquals(allTransactionalIds, listExpirableTransactionalIds())
+  }
+
+  @Test
+  def testTransactionalExpirationWithOfflineLogDir(): Unit = {
+    val onlinePartitionId = 0
+    val offlinePartitionId = 1
+
+    val partitionIds = Seq(onlinePartitionId, offlinePartitionId)
+    val maxBatchSize = 512
+
+    loadTransactionsForPartitions(partitionIds)
+    val allTransactionalIds = loadExpiredTransactionalIds(numTransactionalIds = 20)
+
+    EasyMock.reset(replicaManager)
+
+    // Partition 0 returns log config as normal
+    expectLogConfig(Seq(onlinePartitionId), maxBatchSize)
+    // No log config returned for partition 1 since it is offline
+    EasyMock.expect(replicaManager.getLogConfig(new TopicPartition(TRANSACTION_STATE_TOPIC_NAME, offlinePartitionId)))
+      .andStubReturn(None)
+
+    val appendedRecords = mutable.Map.empty[TopicPartition, mutable.Buffer[MemoryRecords]]
+    expectTransactionalIdExpiration(Errors.NONE, appendedRecords)
+    EasyMock.replay(replicaManager)
+
+    assertEquals(allTransactionalIds, listExpirableTransactionalIds())
+    transactionManager.removeExpiredTransactionalIds()
+    EasyMock.verify(replicaManager)
+
+    assertEquals(Set(onlinePartitionId), appendedRecords.keySet.map(_.partition))
+
+    val (transactionalIdsForOnlinePartition, transactionalIdsForOfflinePartition) =
+      allTransactionalIds.partition { transactionalId =>
+        transactionManager.partitionFor(transactionalId) == onlinePartitionId
+      }
+
+    val expiredTransactionalIds = collectTransactionalIdsFromTombstones(appendedRecords)
+    assertEquals(transactionalIdsForOnlinePartition, expiredTransactionalIds)
+    assertEquals(transactionalIdsForOfflinePartition, listExpirableTransactionalIds())
+  }
+
+  @Test
+  def testTransactionExpirationShouldRespectBatchSize(): Unit = {
+    val partitionIds = 0 until numPartitions
+    val maxBatchSize = 512
+
+    loadTransactionsForPartitions(partitionIds)
+    val allTransactionalIds = loadExpiredTransactionalIds(numTransactionalIds = 1000)
+
+    EasyMock.reset(replicaManager)
+    expectLogConfig(partitionIds, maxBatchSize)
+
+    val appendedRecords = mutable.Map.empty[TopicPartition, mutable.Buffer[MemoryRecords]]
+    expectTransactionalIdExpiration(Errors.NONE, appendedRecords)
+    EasyMock.replay(replicaManager)
+
+    assertEquals(allTransactionalIds, listExpirableTransactionalIds())
+    transactionManager.removeExpiredTransactionalIds()
+    EasyMock.verify(replicaManager)
+
+    assertEquals(Set.empty, listExpirableTransactionalIds())
+    assertEquals(partitionIds.toSet, appendedRecords.keys.map(_.partition))
+
+    appendedRecords.values.foreach { batches =>
+      assertTrue(batches.size > 1) // Ensure a non-trivial test case
+      assertTrue(batches.forall(_.sizeInBytes() < maxBatchSize))
+    }
+
+    val expiredTransactionalIds = collectTransactionalIdsFromTombstones(appendedRecords)
+
assertEquals(allTransactionalIds, expiredTransactionalIds) + } + + private def collectTransactionalIdsFromTombstones( + appendedRecords: mutable.Map[TopicPartition, mutable.Buffer[MemoryRecords]] + ): Set[String] = { + val expiredTransactionalIds = mutable.Set.empty[String] + appendedRecords.values.foreach { batches => + batches.foreach { records => + records.records.forEach { record => + val transactionalId = TransactionLog.readTxnRecordKey(record.key).transactionalId + assertNull(record.value) + expiredTransactionalIds += transactionalId + assertEquals(Right(None), transactionManager.getTransactionState(transactionalId)) + } + } + } + expiredTransactionalIds.toSet + } + + private def loadExpiredTransactionalIds( + numTransactionalIds: Int + ): Set[String] = { + val allTransactionalIds = mutable.Set.empty[String] + for (i <- 0 to numTransactionalIds) { + val txnlId = s"id_$i" + val producerId = i + val txnMetadata = transactionMetadata(txnlId, producerId) + txnMetadata.txnLastUpdateTimestamp = time.milliseconds() - txnConfig.transactionalIdExpirationMs + transactionManager.putTransactionStateIfNotExists(txnMetadata) + allTransactionalIds += txnlId + } + allTransactionalIds.toSet + } + + private def listExpirableTransactionalIds(): Set[String] = { + val activeTransactionalIds = transactionManager.listTransactionStates(Set.empty, Set.empty) + .transactionStates + .asScala + .map(_.transactionalId) + + activeTransactionalIds.filter { transactionalId => + transactionManager.getTransactionState(transactionalId) match { + case Right(Some(epochAndMetadata)) => + val txnMetadata = epochAndMetadata.transactionMetadata + val timeSinceLastUpdate = time.milliseconds() - txnMetadata.txnLastUpdateTimestamp + timeSinceLastUpdate >= txnConfig.transactionalIdExpirationMs && + txnMetadata.state.isExpirationAllowed && + txnMetadata.pendingState.isEmpty + case _ => false + } + }.toSet + } + + @Test + def testSuccessfulReimmigration(): Unit = { + txnMetadata1.state = PrepareCommit + txnMetadata1.addPartitions(Set[TopicPartition](new TopicPartition("topic1", 0), + new TopicPartition("topic1", 1))) + + txnRecords += new SimpleRecord(txnMessageKeyBytes1, TransactionLog.valueToBytes(txnMetadata1.prepareNoTransit())) + val startOffset = 0L + val records = MemoryRecords.withRecords(startOffset, CompressionType.NONE, txnRecords.toArray: _*) + + prepareTxnLog(topicPartition, 0, records) + + // immigrate partition at epoch 0 + transactionManager.loadTransactionsForTxnTopicPartition(partitionId, coordinatorEpoch = 0, (_, _, _, _) => ()) + assertEquals(0, transactionManager.loadingPartitions.size) + + // Re-immigrate partition at epoch 1. This should be successful even though we didn't get to emigrate the partition. + prepareTxnLog(topicPartition, 0, records) + transactionManager.loadTransactionsForTxnTopicPartition(partitionId, coordinatorEpoch = 1, (_, _, _, _) => ()) + assertEquals(0, transactionManager.loadingPartitions.size) + assertTrue(transactionManager.transactionMetadataCache.get(partitionId).isDefined) + assertEquals(1, transactionManager.transactionMetadataCache.get(partitionId).get.coordinatorEpoch) + } + + @Test + def testLoadTransactionMetadataWithCorruptedLog(): Unit = { + // Simulate a case where startOffset < endOffset but log is empty. This could theoretically happen + // when all the records are expired and the active segment is truncated or when the partition + // is accidentally corrupted. 
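+    // The mocked log below reports an end offset of 10 but returns MemoryRecords.EMPTY from read(),
+    // and the test verifies that loading still completes (loadingPartitions ends up empty) rather than spinning.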
+ val startOffset = 0L + val endOffset = 10L + + val logMock: UnifiedLog = EasyMock.mock(classOf[UnifiedLog]) + EasyMock.expect(replicaManager.getLog(topicPartition)).andStubReturn(Some(logMock)) + EasyMock.expect(logMock.logStartOffset).andStubReturn(startOffset) + EasyMock.expect(logMock.read(EasyMock.eq(startOffset), + maxLength = EasyMock.anyInt(), + isolation = EasyMock.eq(FetchLogEnd), + minOneMessage = EasyMock.eq(true)) + ).andReturn(FetchDataInfo(LogOffsetMetadata(startOffset), MemoryRecords.EMPTY)) + EasyMock.expect(replicaManager.getLogEndOffset(topicPartition)).andStubReturn(Some(endOffset)) + + EasyMock.replay(logMock) + EasyMock.replay(replicaManager) + + transactionManager.loadTransactionsForTxnTopicPartition(partitionId, coordinatorEpoch = 0, (_, _, _, _) => ()) + + // let the time advance to trigger the background thread loading + scheduler.tick() + + EasyMock.verify(logMock) + EasyMock.verify(replicaManager) + assertEquals(0, transactionManager.loadingPartitions.size) + } + + private def verifyMetadataDoesExistAndIsUsable(transactionalId: String): Unit = { + transactionManager.getTransactionState(transactionalId) match { + case Left(_) => fail("shouldn't have been any errors") + case Right(None) => fail("metadata should have been removed") + case Right(Some(metadata)) => + assertTrue(metadata.transactionMetadata.pendingState.isEmpty, "metadata shouldn't be in a pending state") + } + } + + private def verifyMetadataDoesntExist(transactionalId: String): Unit = { + transactionManager.getTransactionState(transactionalId) match { + case Left(_) => fail("shouldn't have been any errors") + case Right(Some(_)) => fail("metadata should have been removed") + case Right(None) => // ok + } + } + + private def expectTransactionalIdExpiration( + appendError: Errors, + capturedAppends: mutable.Map[TopicPartition, mutable.Buffer[MemoryRecords]] + ): Unit = { + val recordsCapture: Capture[Map[TopicPartition, MemoryRecords]] = EasyMock.newCapture() + val callbackCapture: Capture[Map[TopicPartition, PartitionResponse] => Unit] = EasyMock.newCapture() + + EasyMock.expect(replicaManager.appendRecords( + EasyMock.anyLong(), + EasyMock.eq((-1).toShort), + EasyMock.eq(true), + EasyMock.eq(AppendOrigin.Coordinator), + EasyMock.capture(recordsCapture), + EasyMock.capture(callbackCapture), + EasyMock.anyObject().asInstanceOf[Option[ReentrantLock]], + EasyMock.anyObject(), + EasyMock.anyObject() + )).andAnswer(() => callbackCapture.getValue.apply( + recordsCapture.getValue.map { case (topicPartition, records) => + val batches = capturedAppends.getOrElse(topicPartition, { + val batches = mutable.Buffer.empty[MemoryRecords] + capturedAppends += topicPartition -> batches + batches + }) + + batches += records + + topicPartition -> new PartitionResponse(appendError, 0L, RecordBatch.NO_TIMESTAMP, 0L) + }.toMap + )).anyTimes() + } + + private def loadTransactionsForPartitions( + partitionIds: Seq[Int], + ): Unit = { + for (partitionId <- partitionIds) { + transactionManager.addLoadedTransactionsToCache(partitionId, 0, new Pool[String, TransactionMetadata]()) + } + } + + private def expectLogConfig( + partitionIds: Seq[Int], + maxBatchSize: Int + ): Unit = { + val logConfig: LogConfig = EasyMock.mock(classOf[LogConfig]) + EasyMock.expect(logConfig.maxMessageSize).andStubReturn(maxBatchSize) + + for (partitionId <- partitionIds) { + EasyMock.expect(replicaManager.getLogConfig(new TopicPartition(TRANSACTION_STATE_TOPIC_NAME, partitionId))) + .andStubReturn(Some(logConfig)) + } + + EasyMock.replay(logConfig) + 
} + + private def setupAndRunTransactionalIdExpiration(error: Errors, txnState: TransactionState): Unit = { + val partitionIds = 0 until numPartitions + + loadTransactionsForPartitions(partitionIds) + expectLogConfig(partitionIds, Defaults.MaxMessageSize) + + txnMetadata1.txnLastUpdateTimestamp = time.milliseconds() - txnConfig.transactionalIdExpirationMs + txnMetadata1.state = txnState + transactionManager.putTransactionStateIfNotExists(txnMetadata1) + + txnMetadata2.txnLastUpdateTimestamp = time.milliseconds() + transactionManager.putTransactionStateIfNotExists(txnMetadata2) + + val appendedRecords = mutable.Map.empty[TopicPartition, mutable.Buffer[MemoryRecords]] + expectTransactionalIdExpiration(error, appendedRecords) + + EasyMock.replay(replicaManager) + transactionManager.removeExpiredTransactionalIds() + EasyMock.verify(replicaManager) + + val stateAllowsExpiration = txnState match { + case Empty | CompleteCommit | CompleteAbort => true + case _ => false + } + + if (stateAllowsExpiration) { + val partitionId = transactionManager.partitionFor(transactionalId1) + val topicPartition = new TopicPartition(TRANSACTION_STATE_TOPIC_NAME, partitionId) + val expectedTombstone = new SimpleRecord(time.milliseconds(), TransactionLog.keyToBytes(transactionalId1), null) + val expectedRecords = MemoryRecords.withRecords(TransactionLog.EnforcedCompressionType, expectedTombstone) + assertEquals(Set(topicPartition), appendedRecords.keySet) + assertEquals(Seq(expectedRecords), appendedRecords(topicPartition).toSeq) + } else { + assertEquals(Map.empty, appendedRecords) + } + } + + private def verifyWritesTxnMarkersInPrepareState(state: TransactionState): Unit = { + txnMetadata1.state = state + txnMetadata1.addPartitions(Set[TopicPartition](new TopicPartition("topic1", 0), + new TopicPartition("topic1", 1))) + + txnRecords += new SimpleRecord(txnMessageKeyBytes1, TransactionLog.valueToBytes(txnMetadata1.prepareNoTransit())) + val startOffset = 0L + val records = MemoryRecords.withRecords(startOffset, CompressionType.NONE, txnRecords.toArray: _*) + + prepareTxnLog(topicPartition, 0, records) + + var txnId: String = null + def rememberTxnMarkers(coordinatorEpoch: Int, + command: TransactionResult, + metadata: TransactionMetadata, + newMetadata: TxnTransitMetadata): Unit = { + txnId = metadata.transactionalId + } + + transactionManager.loadTransactionsForTxnTopicPartition(partitionId, 0, rememberTxnMarkers) + scheduler.tick() + + assertEquals(transactionalId1, txnId) + } + + private def assertCallback(error: Errors): Unit = { + assertEquals(expectedError, error) + } + + private def transactionMetadata(transactionalId: String, + producerId: Long, + state: TransactionState = Empty, + txnTimeout: Int = transactionTimeoutMs): TransactionMetadata = { + TransactionMetadata(transactionalId, producerId, 0.toShort, txnTimeout, state, time.milliseconds()) + } + + private def prepareTxnLog(topicPartition: TopicPartition, + startOffset: Long, + records: MemoryRecords): Unit = { + EasyMock.reset(replicaManager) + + val logMock: UnifiedLog = EasyMock.mock(classOf[UnifiedLog]) + val fileRecordsMock: FileRecords = EasyMock.mock(classOf[FileRecords]) + + val endOffset = startOffset + records.records.asScala.size + + EasyMock.expect(replicaManager.getLog(topicPartition)).andStubReturn(Some(logMock)) + EasyMock.expect(replicaManager.getLogEndOffset(topicPartition)).andStubReturn(Some(endOffset)) + + EasyMock.expect(logMock.logStartOffset).andStubReturn(startOffset) + EasyMock.expect(logMock.read(EasyMock.eq(startOffset), + 
maxLength = EasyMock.anyInt(), + isolation = EasyMock.eq(FetchLogEnd), + minOneMessage = EasyMock.eq(true))) + .andReturn(FetchDataInfo(LogOffsetMetadata(startOffset), fileRecordsMock)) + + EasyMock.expect(fileRecordsMock.sizeInBytes()).andStubReturn(records.sizeInBytes) + + val bufferCapture = EasyMock.newCapture[ByteBuffer] + fileRecordsMock.readInto(EasyMock.capture(bufferCapture), EasyMock.anyInt()) + EasyMock.expectLastCall().andAnswer(new IAnswer[Unit] { + override def answer: Unit = { + val buffer = bufferCapture.getValue + buffer.put(records.buffer.duplicate) + buffer.flip() + } + }) + EasyMock.replay(logMock, fileRecordsMock, replicaManager) + } + + private def prepareForTxnMessageAppend(error: Errors): Unit = { + EasyMock.reset(replicaManager) + + val capturedArgument: Capture[Map[TopicPartition, PartitionResponse] => Unit] = EasyMock.newCapture() + EasyMock.expect(replicaManager.appendRecords(EasyMock.anyLong(), + EasyMock.anyShort(), + internalTopicsAllowed = EasyMock.eq(true), + origin = EasyMock.eq(AppendOrigin.Coordinator), + EasyMock.anyObject().asInstanceOf[Map[TopicPartition, MemoryRecords]], + EasyMock.capture(capturedArgument), + EasyMock.anyObject().asInstanceOf[Option[ReentrantLock]], + EasyMock.anyObject(), + EasyMock.anyObject()) + ).andAnswer(() => capturedArgument.getValue.apply( + Map(new TopicPartition(TRANSACTION_STATE_TOPIC_NAME, partitionId) -> + new PartitionResponse(error, 0L, RecordBatch.NO_TIMESTAMP, 0L))) + ) + EasyMock.expect(replicaManager.getMagic(EasyMock.anyObject())) + .andStubReturn(Some(RecordBatch.MAGIC_VALUE_V1)) + + EasyMock.replay(replicaManager) + } + + @Test + def testPartitionLoadMetric(): Unit = { + val server = ManagementFactory.getPlatformMBeanServer + val mBeanName = "kafka.server:type=transaction-coordinator-metrics" + val reporter = new JmxReporter + val metricsContext = new KafkaMetricsContext("kafka.server") + reporter.contextChange(metricsContext) + metrics.addReporter(reporter) + + def partitionLoadTime(attribute: String): Double = { + server.getAttribute(new ObjectName(mBeanName), attribute).asInstanceOf[Double] + } + + assertTrue(server.isRegistered(new ObjectName(mBeanName))) + assertEquals(Double.NaN, partitionLoadTime( "partition-load-time-max"), 0) + assertEquals(Double.NaN, partitionLoadTime("partition-load-time-avg"), 0) + assertTrue(reporter.containsMbean(mBeanName)) + + txnMetadata1.state = Ongoing + txnMetadata1.addPartitions(Set[TopicPartition](new TopicPartition("topic1", 1), + new TopicPartition("topic1", 1))) + + txnRecords += new SimpleRecord(txnMessageKeyBytes1, TransactionLog.valueToBytes(txnMetadata1.prepareNoTransit())) + + val startOffset = 15L + val records = MemoryRecords.withRecords(startOffset, CompressionType.NONE, txnRecords.toArray: _*) + + prepareTxnLog(topicPartition, startOffset, records) + transactionManager.loadTransactionsForTxnTopicPartition(partitionId, 0, (_, _, _, _) => ()) + scheduler.tick() + + assertTrue(partitionLoadTime("partition-load-time-max") >= 0) + assertTrue(partitionLoadTime( "partition-load-time-avg") >= 0) + } +} diff --git a/core/src/test/scala/unit/kafka/integration/KafkaServerTestHarness.scala b/core/src/test/scala/unit/kafka/integration/KafkaServerTestHarness.scala new file mode 100755 index 0000000..e8fdaf5 --- /dev/null +++ b/core/src/test/scala/unit/kafka/integration/KafkaServerTestHarness.scala @@ -0,0 +1,255 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package kafka.integration
+
+import java.io.File
+import java.util
+import java.util.Arrays
+
+import kafka.server.QuorumTestHarness
+import kafka.server._
+import kafka.utils.TestUtils
+import org.apache.kafka.common.security.auth.SecurityProtocol
+import org.junit.jupiter.api.{AfterEach, BeforeEach, TestInfo}
+
+import scala.collection.{Seq, mutable}
+import scala.jdk.CollectionConverters._
+import java.util.Properties
+
+import org.apache.kafka.common.{KafkaException, Uuid}
+import org.apache.kafka.common.network.ListenerName
+import org.apache.kafka.common.security.scram.ScramCredential
+import org.apache.kafka.common.utils.Time
+
+/**
+ * A test harness that brings up some number of broker nodes
+ */
+abstract class KafkaServerTestHarness extends QuorumTestHarness {
+  var instanceConfigs: Seq[KafkaConfig] = null
+
+  private val _brokers = new mutable.ArrayBuffer[KafkaBroker]
+
+  /**
+   * Get the list of brokers, which could be either BrokerServer objects or KafkaServer objects.
+   */
+  def brokers: mutable.Buffer[KafkaBroker] = _brokers
+
+  /**
+   * Get the list of brokers, as instances of KafkaServer.
+   * This method should only be used when dealing with brokers that use ZooKeeper.
+   */
+  def servers: mutable.Buffer[KafkaServer] = {
+    checkIsZKTest()
+    _brokers.map(_.asInstanceOf[KafkaServer])
+  }
+
+  var brokerList: String = null
+  var alive: Array[Boolean] = null
+
+  /**
+   * Implementations must override this method to return a set of KafkaConfigs. This method will be invoked for every
+   * test and should not reuse previous configurations unless they select their ports randomly when servers are started.
+   */
+  def generateConfigs: Seq[KafkaConfig]
+
+  /**
+   * Override this in case ACLs or security credentials must be set before `servers` are started.
+   *
+   * This is required in some cases because of the topic creation in the setup of `IntegrationTestHarness`. If the ACLs
+   * are only set later, tests may fail. The failure could manifest itself as a cluster action
+   * authorization exception when processing an update metadata request (controller -> broker) or in more obscure
+   * ways (e.g. __consumer_offsets topic replication fails because the metadata cache has no brokers as a previous
+   * update metadata request failed due to an authorization exception).
+   *
+   * The default implementation of this method is a no-op.
+   */
+  def configureSecurityBeforeServersStart(): Unit = {}
+
+  /**
+   * Override this in case tokens or security credentials need to be created after `servers` are started.
+   * The default implementation of this method is a no-op.
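+   * Subclasses typically use this to create SCRAM credentials or delegation tokens, which require a running cluster.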
+ */ + def configureSecurityAfterServersStart(): Unit = {} + + def configs: Seq[KafkaConfig] = { + if (instanceConfigs == null) + instanceConfigs = generateConfigs + instanceConfigs + } + + def serverForId(id: Int): Option[KafkaServer] = servers.find(s => s.config.brokerId == id) + + def boundPort(server: KafkaServer): Int = server.boundPort(listenerName) + + protected def securityProtocol: SecurityProtocol = SecurityProtocol.PLAINTEXT + protected def listenerName: ListenerName = ListenerName.forSecurityProtocol(securityProtocol) + protected def trustStoreFile: Option[File] = None + protected def serverSaslProperties: Option[Properties] = None + protected def clientSaslProperties: Option[Properties] = None + protected def brokerTime(brokerId: Int): Time = Time.SYSTEM + protected def enableForwarding: Boolean = false + + @BeforeEach + override def setUp(testInfo: TestInfo): Unit = { + super.setUp(testInfo) + + if (configs.isEmpty) + throw new KafkaException("Must supply at least one server config.") + + // default implementation is a no-op, it is overridden by subclasses if required + configureSecurityBeforeServersStart() + + // Add each broker to `servers` buffer as soon as it is created to ensure that brokers + // are shutdown cleanly in tearDown even if a subsequent broker fails to start + for (config <- configs) { + if (isKRaftTest()) { + _brokers += createAndStartBroker(config, brokerTime(config.brokerId)) + } else { + _brokers += TestUtils.createServer( + config, + time = brokerTime(config.brokerId), + threadNamePrefix = None, + enableForwarding + ) + } + } + brokerList = TestUtils.bootstrapServers(_brokers, listenerName) + alive = new Array[Boolean](_brokers.length) + Arrays.fill(alive, true) + + // default implementation is a no-op, it is overridden by subclasses if required + configureSecurityAfterServersStart() + } + + @AfterEach + override def tearDown(): Unit = { + TestUtils.shutdownServers(_brokers) + super.tearDown() + } + + /** + * Create a topic. + * Wait until the leader is elected and the metadata is propagated to all brokers. + * Return the leader for each partition. + */ + def createTopic(topic: String, + numPartitions: Int = 1, + replicationFactor: Int = 1, + topicConfig: Properties = new Properties, + adminClientConfig: Properties = new Properties): scala.collection.immutable.Map[Int, Int] = { + if (isKRaftTest()) { + TestUtils.createTopicWithAdmin(topic, numPartitions, replicationFactor, brokers, topicConfig, adminClientConfig) + } else { + TestUtils.createTopic(zkClient, topic, numPartitions, replicationFactor, servers, topicConfig) + } + } + + /** + * Create a topic in ZooKeeper using a customized replica assignment. + * Wait until the leader is elected and the metadata is propagated to all brokers. + * Return the leader for each partition. 
+ */ + def createTopic(topic: String, partitionReplicaAssignment: collection.Map[Int, Seq[Int]]): scala.collection.immutable.Map[Int, Int] = + TestUtils.createTopic(zkClient, topic, partitionReplicaAssignment, servers) + + def deleteTopic(topic: String): Unit = { + if (isKRaftTest()) { + TestUtils.deleteTopicWithAdmin(topic, brokers) + } else { + adminZkClient.deleteTopic(topic) + } + } + + /** + * Pick a broker at random and kill it if it isn't already dead + * Return the id of the broker killed + */ + def killRandomBroker(): Int = { + val index = TestUtils.random.nextInt(_brokers.length) + killBroker(index) + index + } + + def killBroker(index: Int): Unit = { + if(alive(index)) { + _brokers(index).shutdown() + _brokers(index).awaitShutdown() + alive(index) = false + } + } + + /** + * Restart any dead brokers + */ + def restartDeadBrokers(reconfigure: Boolean = false): Unit = { + if (reconfigure) { + instanceConfigs = null + } + for(i <- _brokers.indices if !alive(i)) { + if (reconfigure) { + _brokers(i) = TestUtils.createServer( + configs(i), + time = brokerTime(configs(i).brokerId), + threadNamePrefix = None, + enableForwarding + ) + } + _brokers(i).startup() + alive(i) = true + } + } + + def waitForUserScramCredentialToAppearOnAllBrokers(clientPrincipal: String, mechanismName: String): Unit = { + _brokers.foreach { server => + val cache = server.credentialProvider.credentialCache.cache(mechanismName, classOf[ScramCredential]) + TestUtils.waitUntilTrue(() => cache.get(clientPrincipal) != null, s"SCRAM credentials not created for $clientPrincipal") + } + } + + def getController(): KafkaServer = { + checkIsZKTest() + val controllerId = TestUtils.waitUntilControllerElected(zkClient) + servers.filter(s => s.config.brokerId == controllerId).head + } + + def getTopicIds(names: Seq[String]): Map[String, Uuid] = { + val result = new util.HashMap[String, Uuid]() + if (isKRaftTest()) { + val topicIdsMap = controllerServer.controller.findTopicIds(Long.MaxValue, names.asJava).get() + names.foreach { name => + val response = topicIdsMap.get(name) + result.put(name, response.result()) + } + } else { + val topicIdsMap = getController().kafkaController.controllerContext.topicIds.toMap + names.foreach { name => + if (topicIdsMap.contains(name)) result.put(name, topicIdsMap.get(name).get) + } + } + result.asScala.toMap + } + + def getTopicIds(): Map[String, Uuid] = { + getController().kafkaController.controllerContext.topicIds.toMap + } + + def getTopicNames(): Map[Uuid, String] = { + getController().kafkaController.controllerContext.topicNames.toMap + } + +} diff --git a/core/src/test/scala/unit/kafka/integration/MetricsDuringTopicCreationDeletionTest.scala b/core/src/test/scala/unit/kafka/integration/MetricsDuringTopicCreationDeletionTest.scala new file mode 100644 index 0000000..e045ea9 --- /dev/null +++ b/core/src/test/scala/unit/kafka/integration/MetricsDuringTopicCreationDeletionTest.scala @@ -0,0 +1,153 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package kafka.integration
+
+import java.util.Properties
+
+import kafka.server.KafkaConfig
+import kafka.utils.{Logging, TestUtils}
+
+import scala.jdk.CollectionConverters._
+import org.junit.jupiter.api.{BeforeEach, Test, TestInfo}
+import com.yammer.metrics.core.Gauge
+import kafka.metrics.KafkaYammerMetrics
+
+class MetricsDuringTopicCreationDeletionTest extends KafkaServerTestHarness with Logging {
+
+  private val nodesNum = 3
+  private val topicName = "topic"
+  private val topicNum = 2
+  private val replicationFactor = 3
+  private val partitionNum = 3
+  private val createDeleteIterations = 3
+
+  private val overridingProps = new Properties
+  overridingProps.put(KafkaConfig.DeleteTopicEnableProp, "true")
+  overridingProps.put(KafkaConfig.AutoCreateTopicsEnableProp, "false")
+  // speed up the test for UnderReplicatedPartitions, which relies on the ISR expiry thread to execute concurrently with topic creation
+  // But the replica.lag.time.max.ms value still needs to account for the slow Jenkins testing environment
+  overridingProps.put(KafkaConfig.ReplicaLagTimeMaxMsProp, "4000")
+
+  private val testedMetrics = List("OfflinePartitionsCount","PreferredReplicaImbalanceCount","UnderReplicatedPartitions")
+  private val topics = List.tabulate(topicNum)(n => topicName + n)
+
+  @volatile private var running = true
+
+  override def generateConfigs = TestUtils.createBrokerConfigs(nodesNum, zkConnect)
+    .map(KafkaConfig.fromProps(_, overridingProps))
+
+  @BeforeEach
+  override def setUp(testInfo: TestInfo): Unit = {
+    // Do some Metrics Registry cleanup by removing the metrics that this test checks.
+    // This is a test workaround for the issue that prior harness runs may have left a populated registry.
+    // see https://issues.apache.org/jira/browse/KAFKA-4605
+    for (m <- testedMetrics) {
+      val metricName = KafkaYammerMetrics.defaultRegistry.allMetrics.asScala.keys.find(_.getName.endsWith(m))
+      metricName.foreach(KafkaYammerMetrics.defaultRegistry.removeMetric)
+    }
+
+    super.setUp(testInfo)
+  }
+
+  /*
+   * Checking all the metrics we care about in a single test is faster, though it would be more elegant to have 3 @Test methods
+   */
+  @Test
+  def testMetricsDuringTopicCreateDelete(): Unit = {
+
+    // For UnderReplicatedPartitions, because of https://issues.apache.org/jira/browse/KAFKA-4605
+    // we can't access the metric value of each server. So instead we directly invoke the method
+    // replicaManager.underReplicatedPartitionCount() that backs the metric value.
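+    // The polling thread below updates this from each server's replicaManager while topics are created and
+    // deleted; the assertions at the end of the test fail if a non-zero value is ever observed.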
+
+    @volatile var underReplicatedPartitionCount = 0
+
+    // For OfflinePartitionsCount and PreferredReplicaImbalanceCount even with https://issues.apache.org/jira/browse/KAFKA-4605
+    // the test has worked reliably because the metric that gets triggered is the one generated by the first started server (controller)
+    val offlinePartitionsCountGauge = getGauge("OfflinePartitionsCount")
+    @volatile var offlinePartitionsCount = offlinePartitionsCountGauge.value
+    assert(offlinePartitionsCount == 0)
+
+    val preferredReplicaImbalanceCountGauge = getGauge("PreferredReplicaImbalanceCount")
+    @volatile var preferredReplicaImbalanceCount = preferredReplicaImbalanceCountGauge.value
+    assert(preferredReplicaImbalanceCount == 0)
+
+    // Thread checking the metric continuously
+    running = true
+    val thread = new Thread(() => {
+      while (running) {
+        for (s <- servers if running) {
+          underReplicatedPartitionCount = s.replicaManager.underReplicatedPartitionCount
+          if (underReplicatedPartitionCount > 0) {
+            running = false
+          }
+        }
+
+        preferredReplicaImbalanceCount = preferredReplicaImbalanceCountGauge.value
+        if (preferredReplicaImbalanceCount > 0) {
+          running = false
+        }
+
+        offlinePartitionsCount = offlinePartitionsCountGauge.value
+        if (offlinePartitionsCount > 0) {
+          running = false
+        }
+      }
+    })
+    thread.start()
+
+    // breakable loop that creates and deletes topics
+    createDeleteTopics()
+
+    // if the thread checking the gauge is still running, stop it
+    running = false
+    thread.join()
+
+    assert(offlinePartitionsCount == 0, s"Expect offlinePartitionsCount to be 0, but got: $offlinePartitionsCount")
+    assert(preferredReplicaImbalanceCount == 0, s"Expect PreferredReplicaImbalanceCount to be 0, but got: $preferredReplicaImbalanceCount")
+    assert(underReplicatedPartitionCount == 0, s"Expect UnderReplicatedPartitionCount to be 0, but got: $underReplicatedPartitionCount")
+  }
+
+  private def getGauge(metricName: String) = {
+    KafkaYammerMetrics.defaultRegistry.allMetrics.asScala
+      .find { case (k, _) => k.getName.endsWith(metricName) }
+      .getOrElse(throw new AssertionError("Unable to find metric " + metricName))
+      ._2.asInstanceOf[Gauge[Int]]
+  }
+
+  private def createDeleteTopics(): Unit = {
+    for (l <- 1 to createDeleteIterations if running) {
+      // Create topics
+      for (t <- topics if running) {
+        try {
+          createTopic(t, partitionNum, replicationFactor)
+        } catch {
+          case e: Exception => e.printStackTrace()
+        }
+      }
+
+      // Delete topics
+      for (t <- topics if running) {
+        try {
+          adminZkClient.deleteTopic(t)
+          TestUtils.verifyTopicDeletion(zkClient, t, partitionNum, servers)
+        } catch {
+          case e: Exception => e.printStackTrace()
+        }
+      }
+    }
+  }
+}
diff --git a/core/src/test/scala/unit/kafka/integration/MinIsrConfigTest.scala b/core/src/test/scala/unit/kafka/integration/MinIsrConfigTest.scala
new file mode 100644
index 0000000..8ee0647
--- /dev/null
+++ b/core/src/test/scala/unit/kafka/integration/MinIsrConfigTest.scala
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.integration + +import java.util.Properties + +import kafka.server.KafkaConfig +import kafka.utils.TestUtils +import org.junit.jupiter.api.Test + +class MinIsrConfigTest extends KafkaServerTestHarness { + + val overridingProps = new Properties() + overridingProps.put(KafkaConfig.MinInSyncReplicasProp, "5") + def generateConfigs = TestUtils.createBrokerConfigs(1, zkConnect).map(KafkaConfig.fromProps(_, overridingProps)) + + @Test + def testDefaultKafkaConfig(): Unit = { + assert(servers.head.getLogManager.initialDefaultConfig.minInSyncReplicas == 5) + } + +} diff --git a/core/src/test/scala/unit/kafka/integration/UncleanLeaderElectionTest.scala b/core/src/test/scala/unit/kafka/integration/UncleanLeaderElectionTest.scala new file mode 100755 index 0000000..9db9e3f --- /dev/null +++ b/core/src/test/scala/unit/kafka/integration/UncleanLeaderElectionTest.scala @@ -0,0 +1,359 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package kafka.integration + +import org.apache.kafka.common.config.{ConfigException, ConfigResource} +import org.junit.jupiter.api.{AfterEach, BeforeEach, Test, TestInfo} + +import scala.util.Random +import scala.jdk.CollectionConverters._ +import scala.collection.{Map, Seq} +import org.apache.log4j.{Level, Logger} +import java.util.Properties +import java.util.concurrent.ExecutionException + +import kafka.server.{KafkaConfig, KafkaServer} +import kafka.utils.{CoreUtils, TestUtils} +import kafka.utils.TestUtils._ +import kafka.server.QuorumTestHarness +import org.apache.kafka.common.TopicPartition +import org.apache.kafka.common.errors.TimeoutException +import org.apache.kafka.common.network.ListenerName +import org.apache.kafka.common.security.auth.SecurityProtocol +import org.apache.kafka.common.serialization.StringDeserializer +import org.apache.kafka.clients.admin.{Admin, AdminClientConfig, AlterConfigsResult, Config, ConfigEntry} +import org.junit.jupiter.api.Assertions._ + +import scala.annotation.nowarn + +class UncleanLeaderElectionTest extends QuorumTestHarness { + val brokerId1 = 0 + val brokerId2 = 1 + + // controlled shutdown is needed for these tests, but we can trim the retry count and backoff interval to + // reduce test execution time + val enableControlledShutdown = true + + var configProps1: Properties = null + var configProps2: Properties = null + + var configs: Seq[KafkaConfig] = Seq.empty[KafkaConfig] + var servers: Seq[KafkaServer] = Seq.empty[KafkaServer] + + val random = new Random() + val topic = "topic" + random.nextLong() + val partitionId = 0 + + val kafkaApisLogger = Logger.getLogger(classOf[kafka.server.KafkaApis]) + val networkProcessorLogger = Logger.getLogger(classOf[kafka.network.Processor]) + + @BeforeEach + override def setUp(testInfo: TestInfo): Unit = { + super.setUp(testInfo) + + configProps1 = createBrokerConfig(brokerId1, zkConnect) + configProps2 = createBrokerConfig(brokerId2, zkConnect) + + for (configProps <- List(configProps1, configProps2)) { + configProps.put("controlled.shutdown.enable", enableControlledShutdown.toString) + configProps.put("controlled.shutdown.max.retries", "1") + configProps.put("controlled.shutdown.retry.backoff.ms", "1000") + } + + // temporarily set loggers to a higher level so that tests run quietly + kafkaApisLogger.setLevel(Level.FATAL) + networkProcessorLogger.setLevel(Level.FATAL) + } + + @AfterEach + override def tearDown(): Unit = { + servers.foreach(server => shutdownServer(server)) + servers.foreach(server => CoreUtils.delete(server.config.logDirs)) + + // restore log levels + kafkaApisLogger.setLevel(Level.ERROR) + networkProcessorLogger.setLevel(Level.ERROR) + + super.tearDown() + } + + private def startBrokers(cluster: Seq[Properties]): Unit = { + for (props <- cluster) { + val config = KafkaConfig.fromProps(props) + val server = createServer(config) + configs ++= List(config) + servers ++= List(server) + } + } + + @Test + def testUncleanLeaderElectionEnabled(): Unit = { + // enable unclean leader election + configProps1.put("unclean.leader.election.enable", "true") + configProps2.put("unclean.leader.election.enable", "true") + startBrokers(Seq(configProps1, configProps2)) + + // create topic with 1 partition, 2 replicas, one on each broker + TestUtils.createTopic(zkClient, topic, Map(partitionId -> Seq(brokerId1, brokerId2)), servers) + + verifyUncleanLeaderElectionEnabled() + } + + @Test + def testUncleanLeaderElectionDisabled(): Unit = { + // unclean leader election is disabled by default + 
startBrokers(Seq(configProps1, configProps2)) + + // create topic with 1 partition, 2 replicas, one on each broker + TestUtils.createTopic(zkClient, topic, Map(partitionId -> Seq(brokerId1, brokerId2)), servers) + + verifyUncleanLeaderElectionDisabled() + } + + @Test + def testUncleanLeaderElectionEnabledByTopicOverride(): Unit = { + // disable unclean leader election globally, but enable for our specific test topic + configProps1.put("unclean.leader.election.enable", "false") + configProps2.put("unclean.leader.election.enable", "false") + startBrokers(Seq(configProps1, configProps2)) + + // create topic with 1 partition, 2 replicas, one on each broker, and unclean leader election enabled + val topicProps = new Properties() + topicProps.put("unclean.leader.election.enable", "true") + TestUtils.createTopic(zkClient, topic, Map(partitionId -> Seq(brokerId1, brokerId2)), servers, topicProps) + + verifyUncleanLeaderElectionEnabled() + } + + @Test + def testUncleanLeaderElectionDisabledByTopicOverride(): Unit = { + // enable unclean leader election globally, but disable for our specific test topic + configProps1.put("unclean.leader.election.enable", "true") + configProps2.put("unclean.leader.election.enable", "true") + startBrokers(Seq(configProps1, configProps2)) + + // create topic with 1 partition, 2 replicas, one on each broker, and unclean leader election disabled + val topicProps = new Properties() + topicProps.put("unclean.leader.election.enable", "false") + TestUtils.createTopic(zkClient, topic, Map(partitionId -> Seq(brokerId1, brokerId2)), servers, topicProps) + + verifyUncleanLeaderElectionDisabled() + } + + @Test + def testUncleanLeaderElectionInvalidTopicOverride(): Unit = { + startBrokers(Seq(configProps1)) + + // create topic with an invalid value for unclean leader election + val topicProps = new Properties() + topicProps.put("unclean.leader.election.enable", "invalid") + + assertThrows(classOf[ConfigException], + () => TestUtils.createTopic(zkClient, topic, Map(partitionId -> Seq(brokerId1)), servers, topicProps)) + } + + def verifyUncleanLeaderElectionEnabled(): Unit = { + // wait until leader is elected + val leaderId = waitUntilLeaderIsElectedOrChanged(zkClient, topic, partitionId) + debug("Leader for " + topic + " is elected to be: %s".format(leaderId)) + assertTrue(leaderId == brokerId1 || leaderId == brokerId2, + "Leader id is set to expected value for topic: " + topic) + + // the non-leader broker is the follower + val followerId = if (leaderId == brokerId1) brokerId2 else brokerId1 + debug("Follower for " + topic + " is: %s".format(followerId)) + + produceMessage(servers, topic, "first") + waitForPartitionMetadata(servers, topic, partitionId) + assertEquals(List("first"), consumeAllMessages(topic, 1)) + + // shutdown follower server + servers.filter(server => server.config.brokerId == followerId).map(server => shutdownServer(server)) + + produceMessage(servers, topic, "second") + assertEquals(List("first", "second"), consumeAllMessages(topic, 2)) + + //remove any previous unclean election metric + servers.map(_.kafkaController.controllerContext.stats.removeMetric("UncleanLeaderElectionsPerSec")) + + // shutdown leader and then restart follower + servers.filter(_.config.brokerId == leaderId).map(shutdownServer) + val followerServer = servers.find(_.config.brokerId == followerId).get + followerServer.startup() + + // wait until new leader is (uncleanly) elected + waitUntilLeaderIsElectedOrChanged(zkClient, topic, partitionId, newLeaderOpt = Some(followerId)) + 
assertEquals(1, followerServer.kafkaController.controllerContext.stats.uncleanLeaderElectionRate.count()) + + produceMessage(servers, topic, "third") + + // second message was lost due to unclean election + assertEquals(List("first", "third"), consumeAllMessages(topic, 2)) + } + + def verifyUncleanLeaderElectionDisabled(): Unit = { + // wait until leader is elected + val leaderId = waitUntilLeaderIsElectedOrChanged(zkClient, topic, partitionId) + debug("Leader for " + topic + " is elected to be: %s".format(leaderId)) + assertTrue(leaderId == brokerId1 || leaderId == brokerId2, + "Leader id is set to expected value for topic: " + topic) + + // the non-leader broker is the follower + val followerId = if (leaderId == brokerId1) brokerId2 else brokerId1 + debug("Follower for " + topic + " is: %s".format(followerId)) + + produceMessage(servers, topic, "first") + waitForPartitionMetadata(servers, topic, partitionId) + assertEquals(List("first"), consumeAllMessages(topic, 1)) + + // shutdown follower server + servers.filter(server => server.config.brokerId == followerId).foreach(server => shutdownServer(server)) + + produceMessage(servers, topic, "second") + assertEquals(List("first", "second"), consumeAllMessages(topic, 2)) + + //remove any previous unclean election metric + servers.foreach(server => server.kafkaController.controllerContext.stats.removeMetric("UncleanLeaderElectionsPerSec")) + + // shutdown leader and then restart follower + servers.filter(server => server.config.brokerId == leaderId).foreach(server => shutdownServer(server)) + val followerServer = servers.find(_.config.brokerId == followerId).get + followerServer.startup() + + // verify that unclean election to non-ISR follower does not occur + waitUntilLeaderIsElectedOrChanged(zkClient, topic, partitionId, newLeaderOpt = Some(-1)) + assertEquals(0, followerServer.kafkaController.controllerContext.stats.uncleanLeaderElectionRate.count()) + + // message production and consumption should both fail while leader is down + val e = assertThrows(classOf[ExecutionException], () => produceMessage(servers, topic, "third", deliveryTimeoutMs = 1000, requestTimeoutMs = 1000)) + assertEquals(classOf[TimeoutException], e.getCause.getClass) + + assertEquals(List.empty[String], consumeAllMessages(topic, 0)) + + // restart leader temporarily to send a successfully replicated message + servers.filter(server => server.config.brokerId == leaderId).foreach(server => server.startup()) + waitUntilLeaderIsElectedOrChanged(zkClient, topic, partitionId, newLeaderOpt = Some(leaderId)) + + produceMessage(servers, topic, "third") + //make sure follower server joins the ISR + TestUtils.waitUntilTrue(() => { + val partitionInfoOpt = followerServer.metadataCache.getPartitionInfo(topic, partitionId) + partitionInfoOpt.isDefined && partitionInfoOpt.get.isr.contains(followerId) + }, "Inconsistent metadata after first server startup") + + servers.filter(server => server.config.brokerId == leaderId).foreach(server => shutdownServer(server)) + // verify clean leader transition to ISR follower + waitUntilLeaderIsElectedOrChanged(zkClient, topic, partitionId, newLeaderOpt = Some(followerId)) + + // verify messages can be consumed from ISR follower that was just promoted to leader + assertEquals(List("first", "second", "third"), consumeAllMessages(topic, 3)) + } + + private def shutdownServer(server: KafkaServer) = { + server.shutdown() + server.awaitShutdown() + } + + private def consumeAllMessages(topic: String, numMessages: Int): Seq[String] = { + val brokerList = 
TestUtils.bootstrapServers(servers, ListenerName.forSecurityProtocol(SecurityProtocol.PLAINTEXT)) + // Don't rely on coordinator as it may be down when this method is called + val consumer = TestUtils.createConsumer(brokerList, + groupId = "group" + random.nextLong(), + enableAutoCommit = false, + valueDeserializer = new StringDeserializer) + try { + val tp = new TopicPartition(topic, partitionId) + consumer.assign(Seq(tp).asJava) + consumer.seek(tp, 0) + TestUtils.consumeRecords(consumer, numMessages).map(_.value) + } finally consumer.close() + } + + @Test + def testTopicUncleanLeaderElectionEnable(): Unit = { + // unclean leader election is disabled by default + startBrokers(Seq(configProps1, configProps2)) + + // create topic with 1 partition, 2 replicas, one on each broker + adminZkClient.createTopicWithAssignment(topic, config = new Properties(), Map(partitionId -> Seq(brokerId1, brokerId2))) + + // wait until leader is elected + val leaderId = waitUntilLeaderIsElectedOrChanged(zkClient, topic, partitionId) + + // the non-leader broker is the follower + val followerId = if (leaderId == brokerId1) brokerId2 else brokerId1 + + produceMessage(servers, topic, "first") + waitForPartitionMetadata(servers, topic, partitionId) + assertEquals(List("first"), consumeAllMessages(topic, 1)) + + // shutdown follower server + servers.filter(server => server.config.brokerId == followerId).map(server => shutdownServer(server)) + + produceMessage(servers, topic, "second") + assertEquals(List("first", "second"), consumeAllMessages(topic, 2)) + + //remove any previous unclean election metric + servers.map(server => server.kafkaController.controllerContext.stats.removeMetric("UncleanLeaderElectionsPerSec")) + + // shutdown leader and then restart follower + servers.filter(server => server.config.brokerId == leaderId).map(server => shutdownServer(server)) + val followerServer = servers.find(_.config.brokerId == followerId).get + followerServer.startup() + + assertEquals(0, followerServer.kafkaController.controllerContext.stats.uncleanLeaderElectionRate.count()) + + // message production and consumption should both fail while leader is down + val e = assertThrows(classOf[ExecutionException], () => produceMessage(servers, topic, "third", deliveryTimeoutMs = 1000, requestTimeoutMs = 1000)) + assertEquals(classOf[TimeoutException], e.getCause.getClass) + + assertEquals(List.empty[String], consumeAllMessages(topic, 0)) + + // Enable unclean leader election for topic + val adminClient = createAdminClient() + val newProps = new Properties + newProps.put(KafkaConfig.UncleanLeaderElectionEnableProp, "true") + alterTopicConfigs(adminClient, topic, newProps).all.get + adminClient.close() + + // wait until new leader is (uncleanly) elected + waitUntilLeaderIsElectedOrChanged(zkClient, topic, partitionId, newLeaderOpt = Some(followerId)) + assertEquals(1, followerServer.kafkaController.controllerContext.stats.uncleanLeaderElectionRate.count()) + + produceMessage(servers, topic, "third") + + // second message was lost due to unclean election + assertEquals(List("first", "third"), consumeAllMessages(topic, 2)) + } + + @nowarn("cat=deprecation") + private def alterTopicConfigs(adminClient: Admin, topic: String, topicConfigs: Properties): AlterConfigsResult = { + val configEntries = topicConfigs.asScala.map { case (k, v) => new ConfigEntry(k, v) }.toList.asJava + val newConfig = new Config(configEntries) + val configs = Map(new ConfigResource(ConfigResource.Type.TOPIC, topic) -> newConfig).asJava + 
adminClient.alterConfigs(configs) + } + + private def createAdminClient(): Admin = { + val config = new Properties + val bootstrapServers = TestUtils.bootstrapServers(servers, new ListenerName("PLAINTEXT")) + config.put(AdminClientConfig.BOOTSTRAP_SERVERS_CONFIG, bootstrapServers) + config.put(AdminClientConfig.METADATA_MAX_AGE_CONFIG, "10") + Admin.create(config) + } +} diff --git a/core/src/test/scala/unit/kafka/log/AbstractLogCleanerIntegrationTest.scala b/core/src/test/scala/unit/kafka/log/AbstractLogCleanerIntegrationTest.scala new file mode 100644 index 0000000..ba10336 --- /dev/null +++ b/core/src/test/scala/unit/kafka/log/AbstractLogCleanerIntegrationTest.scala @@ -0,0 +1,159 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package kafka.log + +import java.io.File +import java.nio.file.Files +import java.util.Properties +import kafka.server.{BrokerTopicStats, LogDirFailureChannel} +import kafka.utils.{MockTime, Pool, TestUtils} +import kafka.utils.Implicits._ +import org.apache.kafka.common.TopicPartition +import org.apache.kafka.common.record.{CompressionType, MemoryRecords, RecordBatch} +import org.apache.kafka.common.utils.Utils +import org.junit.jupiter.api.{AfterEach, Tag} + +import scala.collection.Seq +import scala.collection.mutable.ListBuffer +import scala.util.Random + +@Tag("integration") +abstract class AbstractLogCleanerIntegrationTest { + + var cleaner: LogCleaner = _ + val logDir = TestUtils.tempDir() + + private val logs = ListBuffer.empty[UnifiedLog] + private val defaultMaxMessageSize = 128 + private val defaultMinCleanableDirtyRatio = 0.0F + private val defaultMinCompactionLagMS = 0L + private val defaultDeleteDelay = 1000 + private val defaultSegmentSize = 2048 + private val defaultMaxCompactionLagMs = Long.MaxValue + + def time: MockTime + + @AfterEach + def teardown(): Unit = { + if (cleaner != null) + cleaner.shutdown() + time.scheduler.shutdown() + logs.foreach(_.close()) + Utils.delete(logDir) + } + + def logConfigProperties(propertyOverrides: Properties = new Properties(), + maxMessageSize: Int, + minCleanableDirtyRatio: Float = defaultMinCleanableDirtyRatio, + minCompactionLagMs: Long = defaultMinCompactionLagMS, + deleteDelay: Int = defaultDeleteDelay, + segmentSize: Int = defaultSegmentSize, + maxCompactionLagMs: Long = defaultMaxCompactionLagMs): Properties = { + val props = new Properties() + props.put(LogConfig.MaxMessageBytesProp, maxMessageSize: java.lang.Integer) + props.put(LogConfig.SegmentBytesProp, segmentSize: java.lang.Integer) + props.put(LogConfig.SegmentIndexBytesProp, 100*1024: java.lang.Integer) + props.put(LogConfig.FileDeleteDelayMsProp, deleteDelay: java.lang.Integer) + props.put(LogConfig.CleanupPolicyProp, LogConfig.Compact) + props.put(LogConfig.MinCleanableDirtyRatioProp, 
minCleanableDirtyRatio: java.lang.Float) + props.put(LogConfig.MessageTimestampDifferenceMaxMsProp, Long.MaxValue.toString) + props.put(LogConfig.MinCompactionLagMsProp, minCompactionLagMs: java.lang.Long) + props.put(LogConfig.MaxCompactionLagMsProp, maxCompactionLagMs: java.lang.Long) + props ++= propertyOverrides + props + } + + def makeCleaner(partitions: Iterable[TopicPartition], + minCleanableDirtyRatio: Float = defaultMinCleanableDirtyRatio, + numThreads: Int = 1, + backOffMs: Long = 15000L, + maxMessageSize: Int = defaultMaxMessageSize, + minCompactionLagMs: Long = defaultMinCompactionLagMS, + deleteDelay: Int = defaultDeleteDelay, + segmentSize: Int = defaultSegmentSize, + maxCompactionLagMs: Long = defaultMaxCompactionLagMs, + cleanerIoBufferSize: Option[Int] = None, + propertyOverrides: Properties = new Properties()): LogCleaner = { + + val logMap = new Pool[TopicPartition, UnifiedLog]() + for (partition <- partitions) { + val dir = new File(logDir, s"${partition.topic}-${partition.partition}") + Files.createDirectories(dir.toPath) + + val logConfig = LogConfig(logConfigProperties(propertyOverrides, + maxMessageSize = maxMessageSize, + minCleanableDirtyRatio = minCleanableDirtyRatio, + minCompactionLagMs = minCompactionLagMs, + deleteDelay = deleteDelay, + segmentSize = segmentSize, + maxCompactionLagMs = maxCompactionLagMs)) + val log = UnifiedLog(dir, + logConfig, + logStartOffset = 0L, + recoveryPoint = 0L, + scheduler = time.scheduler, + time = time, + brokerTopicStats = new BrokerTopicStats, + maxProducerIdExpirationMs = 60 * 60 * 1000, + producerIdExpirationCheckIntervalMs = LogManager.ProducerIdExpirationCheckIntervalMs, + logDirFailureChannel = new LogDirFailureChannel(10), + topicId = None, + keepPartitionMetadataFile = true) + logMap.put(partition, log) + this.logs += log + } + + val cleanerConfig = CleanerConfig( + numThreads = numThreads, + ioBufferSize = cleanerIoBufferSize.getOrElse(maxMessageSize / 2), + maxMessageSize = maxMessageSize, + backOffMs = backOffMs) + new LogCleaner(cleanerConfig, + logDirs = Array(logDir), + logs = logMap, + logDirFailureChannel = new LogDirFailureChannel(1), + time = time) + } + + private var ctr = 0 + def counter: Int = ctr + def incCounter(): Unit = ctr += 1 + + def writeDups(numKeys: Int, numDups: Int, log: UnifiedLog, codec: CompressionType, + startKey: Int = 0, magicValue: Byte = RecordBatch.CURRENT_MAGIC_VALUE): Seq[(Int, String, Long)] = { + for(_ <- 0 until numDups; key <- startKey until (startKey + numKeys)) yield { + val value = counter.toString + val appendInfo = log.appendAsLeader(TestUtils.singletonRecords(value = value.toString.getBytes, codec = codec, + key = key.toString.getBytes, magicValue = magicValue), leaderEpoch = 0) + // move LSO forward to increase compaction bound + log.updateHighWatermark(log.logEndOffset) + incCounter() + (key, value, appendInfo.firstOffset.get.messageOffset) + } + } + + def createLargeSingleMessageSet(key: Int, messageFormatVersion: Byte, codec: CompressionType): (String, MemoryRecords) = { + def messageValue(length: Int): String = { + val random = new Random(0) + new String(random.alphanumeric.take(length).toArray) + } + val value = messageValue(128) + val messageSet = TestUtils.singletonRecords(value = value.getBytes, codec = codec, key = key.toString.getBytes, + magicValue = messageFormatVersion) + (value, messageSet) + } +} diff --git a/core/src/test/scala/unit/kafka/log/BrokerCompressionTest.scala b/core/src/test/scala/unit/kafka/log/BrokerCompressionTest.scala new file mode 100755 
index 0000000..f308b54 --- /dev/null +++ b/core/src/test/scala/unit/kafka/log/BrokerCompressionTest.scala @@ -0,0 +1,88 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.log + +import kafka.message._ +import kafka.server.{BrokerTopicStats, FetchLogEnd, LogDirFailureChannel} +import kafka.utils._ +import org.apache.kafka.common.record.{CompressionType, MemoryRecords, RecordBatch, SimpleRecord} +import org.apache.kafka.common.utils.Utils +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api._ +import org.junit.jupiter.params.ParameterizedTest +import org.junit.jupiter.params.provider.{Arguments, MethodSource} + +import java.util.Properties +import scala.jdk.CollectionConverters._ + +class BrokerCompressionTest { + + val tmpDir = TestUtils.tempDir() + val logDir = TestUtils.randomPartitionLogDir(tmpDir) + val time = new MockTime(0, 0) + val logConfig = LogConfig() + + @AfterEach + def tearDown(): Unit = { + Utils.delete(tmpDir) + } + + /** + * Test broker-side compression configuration + */ + @ParameterizedTest + @MethodSource(Array("parameters")) + def testBrokerSideCompression(messageCompression: String, brokerCompression: String): Unit = { + val messageCompressionCode = CompressionCodec.getCompressionCodec(messageCompression) + val logProps = new Properties() + logProps.put(LogConfig.CompressionTypeProp, brokerCompression) + /*configure broker-side compression */ + val log = UnifiedLog(logDir, LogConfig(logProps), logStartOffset = 0L, recoveryPoint = 0L, scheduler = time.scheduler, + time = time, brokerTopicStats = new BrokerTopicStats, maxProducerIdExpirationMs = 60 * 60 * 1000, + producerIdExpirationCheckIntervalMs = LogManager.ProducerIdExpirationCheckIntervalMs, + logDirFailureChannel = new LogDirFailureChannel(10), topicId = None, keepPartitionMetadataFile = true) + + /* append two messages */ + log.appendAsLeader(MemoryRecords.withRecords(CompressionType.forId(messageCompressionCode.codec), 0, + new SimpleRecord("hello".getBytes), new SimpleRecord("there".getBytes)), leaderEpoch = 0) + + def readBatch(offset: Int): RecordBatch = { + val fetchInfo = log.read(offset, + maxLength = 4096, + isolation = FetchLogEnd, + minOneMessage = true) + fetchInfo.records.batches.iterator.next() + } + + if (!brokerCompression.equals("producer")) { + val brokerCompressionCode = BrokerCompressionCodec.getCompressionCodec(brokerCompression) + assertEquals(brokerCompressionCode.codec, readBatch(0).compressionType.id, "Compression at offset 0 should produce " + brokerCompressionCode.name) + } + else + assertEquals(messageCompressionCode.codec, readBatch(0).compressionType.id, "Compression at offset 0 should produce " + messageCompressionCode.name) + } + +} + +object BrokerCompressionTest { + def parameters: 
java.util.stream.Stream[Arguments] = { + (for (brokerCompression <- BrokerCompressionCodec.brokerCompressionOptions; + messageCompression <- CompressionType.values + ) yield Arguments.of(messageCompression.name, brokerCompression)).asJava.stream() + } +} diff --git a/core/src/test/scala/unit/kafka/log/LocalLogTest.scala b/core/src/test/scala/unit/kafka/log/LocalLogTest.scala new file mode 100644 index 0000000..67e3a79 --- /dev/null +++ b/core/src/test/scala/unit/kafka/log/LocalLogTest.scala @@ -0,0 +1,698 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.log + +import java.io.File +import java.nio.channels.ClosedChannelException +import java.nio.charset.StandardCharsets +import java.util.regex.Pattern +import java.util.Collections +import kafka.server.{FetchDataInfo, KafkaConfig, LogDirFailureChannel, LogOffsetMetadata} +import kafka.utils.{MockTime, Scheduler, TestUtils} +import org.apache.kafka.common.{KafkaException, TopicPartition} +import org.apache.kafka.common.errors.KafkaStorageException +import org.apache.kafka.common.record.{CompressionType, MemoryRecords, Record, SimpleRecord} +import org.apache.kafka.common.utils.{Time, Utils} +import org.junit.jupiter.api.Assertions.{assertFalse, _} +import org.junit.jupiter.api.{AfterEach, BeforeEach, Test} + +import scala.jdk.CollectionConverters._ + +class LocalLogTest { + + var config: KafkaConfig = null + val tmpDir: File = TestUtils.tempDir() + val logDir: File = TestUtils.randomPartitionLogDir(tmpDir) + val topicPartition = new TopicPartition("test_topic", 1) + val logDirFailureChannel = new LogDirFailureChannel(10) + val mockTime = new MockTime() + val log: LocalLog = createLocalLogWithActiveSegment(config = LogTestUtils.createLogConfig()) + + @BeforeEach + def setUp(): Unit = { + val props = TestUtils.createBrokerConfig(0, "127.0.0.1:1", port = -1) + config = KafkaConfig.fromProps(props) + } + + @AfterEach + def tearDown(): Unit = { + try { + log.close() + } catch { + case _: KafkaStorageException => { + // ignore + } + } + Utils.delete(tmpDir) + } + + case class KeyValue(key: String, value: String) { + def toRecord(timestamp: => Long = mockTime.milliseconds): SimpleRecord = { + new SimpleRecord(timestamp, key.getBytes, value.getBytes) + } + } + + object KeyValue { + def fromRecord(record: Record): KeyValue = { + val key = + if (record.hasKey) + StandardCharsets.UTF_8.decode(record.key()).toString + else + "" + val value = + if (record.hasValue) + StandardCharsets.UTF_8.decode(record.value()).toString + else + "" + KeyValue(key, value) + } + } + + private def kvsToRecords(keyValues: Iterable[KeyValue]): Iterable[SimpleRecord] = { + keyValues.map(kv => kv.toRecord()) + } + + private def recordsToKvs(records: Iterable[Record]): Iterable[KeyValue] = { + records.map(r => 
KeyValue.fromRecord(r)) + } + + private def appendRecords(records: Iterable[SimpleRecord], + log: LocalLog = log, + initialOffset: Long = 0L): Unit = { + log.append(lastOffset = initialOffset + records.size - 1, + largestTimestamp = records.head.timestamp, + shallowOffsetOfMaxTimestamp = initialOffset, + records = MemoryRecords.withRecords(initialOffset, CompressionType.NONE, 0, records.toList : _*)) + } + + private def readRecords(log: LocalLog = log, + startOffset: Long = 0L, + maxLength: => Int = log.segments.activeSegment.size, + minOneMessage: Boolean = false, + maxOffsetMetadata: => LogOffsetMetadata = log.logEndOffsetMetadata, + includeAbortedTxns: Boolean = false): FetchDataInfo = { + log.read(startOffset, + maxLength, + minOneMessage = minOneMessage, + maxOffsetMetadata, + includeAbortedTxns = includeAbortedTxns) + } + + @Test + def testLogDeleteSegmentsSuccess(): Unit = { + val record = new SimpleRecord(mockTime.milliseconds, "a".getBytes) + appendRecords(List(record)) + log.roll() + assertEquals(2, log.segments.numberOfSegments) + assertFalse(logDir.listFiles.isEmpty) + val segmentsBeforeDelete = List[LogSegment]() ++ log.segments.values + val deletedSegments = log.deleteAllSegments() + assertTrue(log.segments.isEmpty) + assertEquals(segmentsBeforeDelete, deletedSegments) + assertThrows(classOf[KafkaStorageException], () => log.checkIfMemoryMappedBufferClosed()) + assertTrue(logDir.exists) + } + + @Test + def testLogDeleteDirSuccessWhenEmptyAndFailureWhenNonEmpty(): Unit ={ + val record = new SimpleRecord(mockTime.milliseconds, "a".getBytes) + appendRecords(List(record)) + log.roll() + assertEquals(2, log.segments.numberOfSegments) + assertFalse(logDir.listFiles.isEmpty) + + assertThrows(classOf[IllegalStateException], () => log.deleteEmptyDir()) + assertTrue(logDir.exists) + + log.deleteAllSegments() + log.deleteEmptyDir() + assertFalse(logDir.exists) + } + + @Test + def testUpdateConfig(): Unit = { + val oldConfig = log.config + assertEquals(oldConfig, log.config) + + val newConfig = LogTestUtils.createLogConfig(segmentBytes = oldConfig.segmentSize + 1) + log.updateConfig(newConfig) + assertEquals(newConfig, log.config) + } + + @Test + def testLogDirRenameToNewDir(): Unit = { + val record = new SimpleRecord(mockTime.milliseconds, "a".getBytes) + appendRecords(List(record)) + log.roll() + assertEquals(2, log.segments.numberOfSegments) + val newLogDir = TestUtils.randomPartitionLogDir(tmpDir) + assertTrue(log.renameDir(newLogDir.getName)) + assertFalse(logDir.exists()) + assertTrue(newLogDir.exists()) + assertEquals(newLogDir, log.dir) + assertEquals(newLogDir.getParent, log.parentDir) + assertEquals(newLogDir.getParent, log.dir.getParent) + log.segments.values.foreach(segment => assertEquals(newLogDir.getPath, segment.log.file().getParentFile.getPath)) + assertEquals(2, log.segments.numberOfSegments) + } + + @Test + def testLogDirRenameToExistingDir(): Unit = { + assertFalse(log.renameDir(log.dir.getName)) + } + + @Test + def testLogFlush(): Unit = { + assertEquals(0L, log.recoveryPoint) + assertEquals(mockTime.milliseconds, log.lastFlushTime) + + val record = new SimpleRecord(mockTime.milliseconds, "a".getBytes) + appendRecords(List(record)) + mockTime.sleep(1) + val newSegment = log.roll() + log.flush(newSegment.baseOffset) + log.markFlushed(newSegment.baseOffset) + assertEquals(1L, log.recoveryPoint) + assertEquals(mockTime.milliseconds, log.lastFlushTime) + } + + @Test + def testLogAppend(): Unit = { + val fetchDataInfoBeforeAppend = readRecords(maxLength = 1) + 
assertTrue(fetchDataInfoBeforeAppend.records.records.asScala.isEmpty) + + mockTime.sleep(1) + val keyValues = Seq(KeyValue("abc", "ABC"), KeyValue("de", "DE")) + appendRecords(kvsToRecords(keyValues)) + assertEquals(2L, log.logEndOffset) + assertEquals(0L, log.recoveryPoint) + val fetchDataInfo = readRecords() + assertEquals(2L, fetchDataInfo.records.records.asScala.size) + assertEquals(keyValues, recordsToKvs(fetchDataInfo.records.records.asScala)) + } + + @Test + def testLogCloseSuccess(): Unit = { + val keyValues = Seq(KeyValue("abc", "ABC"), KeyValue("de", "DE")) + appendRecords(kvsToRecords(keyValues)) + log.close() + assertThrows(classOf[ClosedChannelException], () => appendRecords(kvsToRecords(keyValues), initialOffset = 2L)) + } + + @Test + def testLogCloseIdempotent(): Unit = { + log.close() + // Check that LocalLog.close() is idempotent + log.close() + } + + @Test + def testLogCloseFailureWhenInMemoryBufferClosed(): Unit = { + val keyValues = Seq(KeyValue("abc", "ABC"), KeyValue("de", "DE")) + appendRecords(kvsToRecords(keyValues)) + log.closeHandlers() + assertThrows(classOf[KafkaStorageException], () => log.close()) + } + + @Test + def testLogCloseHandlers(): Unit = { + val keyValues = Seq(KeyValue("abc", "ABC"), KeyValue("de", "DE")) + appendRecords(kvsToRecords(keyValues)) + log.closeHandlers() + assertThrows(classOf[ClosedChannelException], + () => appendRecords(kvsToRecords(keyValues), initialOffset = 2L)) + } + + @Test + def testLogCloseHandlersIdempotent(): Unit = { + log.closeHandlers() + // Check that LocalLog.closeHandlers() is idempotent + log.closeHandlers() + } + + private def testRemoveAndDeleteSegments(asyncDelete: Boolean): Unit = { + for (offset <- 0 to 8) { + val record = new SimpleRecord(mockTime.milliseconds, "a".getBytes) + appendRecords(List(record), initialOffset = offset) + log.roll() + } + + assertEquals(10L, log.segments.numberOfSegments) + + class TestDeletionReason extends SegmentDeletionReason { + private var _deletedSegments: Iterable[LogSegment] = List[LogSegment]() + + override def logReason(toDelete: List[LogSegment]): Unit = { + _deletedSegments = List[LogSegment]() ++ toDelete + } + + def deletedSegments: Iterable[LogSegment] = _deletedSegments + } + val reason = new TestDeletionReason() + val toDelete = List[LogSegment]() ++ log.segments.values + log.removeAndDeleteSegments(toDelete, asyncDelete = asyncDelete, reason) + if (asyncDelete) { + mockTime.sleep(log.config.fileDeleteDelayMs + 1) + } + assertTrue(log.segments.isEmpty) + assertEquals(toDelete, reason.deletedSegments) + toDelete.foreach(segment => assertTrue(segment.deleted())) + } + + @Test + def testRemoveAndDeleteSegmentsSync(): Unit = { + testRemoveAndDeleteSegments(asyncDelete = false) + } + + @Test + def testRemoveAndDeleteSegmentsAsync(): Unit = { + testRemoveAndDeleteSegments(asyncDelete = true) + } + + private def testDeleteSegmentFiles(asyncDelete: Boolean): Unit = { + for (offset <- 0 to 8) { + val record = new SimpleRecord(mockTime.milliseconds, "a".getBytes) + appendRecords(List(record), initialOffset = offset) + log.roll() + } + + assertEquals(10L, log.segments.numberOfSegments) + + val toDelete = List[LogSegment]() ++ log.segments.values + LocalLog.deleteSegmentFiles(toDelete, asyncDelete = asyncDelete, log.dir, log.topicPartition, log.config, log.scheduler, log.logDirFailureChannel, "") + if (asyncDelete) { + toDelete.foreach { + segment => + assertFalse(segment.deleted()) + assertTrue(segment.hasSuffix(LocalLog.DeletedFileSuffix)) + } + 
mockTime.sleep(log.config.fileDeleteDelayMs + 1) + } + toDelete.foreach(segment => assertTrue(segment.deleted())) + } + + @Test + def testDeleteSegmentFilesSync(): Unit = { + testDeleteSegmentFiles(asyncDelete = false) + } + + @Test + def testDeleteSegmentFilesAsync(): Unit = { + testDeleteSegmentFiles(asyncDelete = true) + } + + @Test + def testDeletableSegmentsFilter(): Unit = { + for (offset <- 0 to 8) { + val record = new SimpleRecord(mockTime.milliseconds, "a".getBytes) + appendRecords(List(record), initialOffset = offset) + log.roll() + } + + assertEquals(10, log.segments.numberOfSegments) + + { + val deletable = log.deletableSegments( + (segment: LogSegment, _: Option[LogSegment]) => segment.baseOffset <= 5) + val expected = log.segments.nonActiveLogSegmentsFrom(0L).filter(segment => segment.baseOffset <= 5).toList + assertEquals(6, expected.length) + assertEquals(expected, deletable.toList) + } + + { + val deletable = log.deletableSegments((_: LogSegment, _: Option[LogSegment]) => true) + val expected = log.segments.nonActiveLogSegmentsFrom(0L).toList + assertEquals(9, expected.length) + assertEquals(expected, deletable.toList) + } + + { + val record = new SimpleRecord(mockTime.milliseconds, "a".getBytes) + appendRecords(List(record), initialOffset = 9L) + val deletable = log.deletableSegments((_: LogSegment, _: Option[LogSegment]) => true) + val expected = log.segments.values.toList + assertEquals(10, expected.length) + assertEquals(expected, deletable.toList) + } + } + + @Test + def testDeletableSegmentsIteration(): Unit = { + for (offset <- 0 to 8) { + val record = new SimpleRecord(mockTime.milliseconds, "a".getBytes) + appendRecords(List(record), initialOffset = offset) + log.roll() + } + + assertEquals(10L, log.segments.numberOfSegments) + + var offset = 0 + val deletableSegments = log.deletableSegments( + (segment: LogSegment, nextSegmentOpt: Option[LogSegment]) => { + assertEquals(offset, segment.baseOffset) + val floorSegmentOpt = log.segments.floorSegment(offset) + assertTrue(floorSegmentOpt.isDefined) + assertEquals(floorSegmentOpt.get, segment) + if (offset == log.logEndOffset) { + assertFalse(nextSegmentOpt.isDefined) + } else { + assertTrue(nextSegmentOpt.isDefined) + val higherSegmentOpt = log.segments.higherSegment(segment.baseOffset) + assertTrue(higherSegmentOpt.isDefined) + assertEquals(segment.baseOffset + 1, higherSegmentOpt.get.baseOffset) + assertEquals(higherSegmentOpt.get, nextSegmentOpt.get) + } + offset += 1 + true + }) + assertEquals(10L, log.segments.numberOfSegments) + assertEquals(log.segments.nonActiveLogSegmentsFrom(0L).toSeq, deletableSegments.toSeq) + } + + @Test + def testTruncateFullyAndStartAt(): Unit = { + val record = new SimpleRecord(mockTime.milliseconds, "a".getBytes) + for (offset <- 0 to 7) { + appendRecords(List(record), initialOffset = offset) + if (offset % 2 != 0) + log.roll() + } + for (offset <- 8 to 12) { + val record = new SimpleRecord(mockTime.milliseconds, "a".getBytes) + appendRecords(List(record), initialOffset = offset) + } + assertEquals(5, log.segments.numberOfSegments) + val expected = List[LogSegment]() ++ log.segments.values + val deleted = log.truncateFullyAndStartAt(10L) + assertEquals(expected, deleted) + assertEquals(1, log.segments.numberOfSegments) + assertEquals(10L, log.segments.activeSegment.baseOffset) + assertEquals(0L, log.recoveryPoint) + assertEquals(10L, log.logEndOffset) + val fetchDataInfo = readRecords(startOffset = 10L) + assertTrue(fetchDataInfo.records.records.asScala.isEmpty) + } + + @Test + def 
testTruncateTo(): Unit = { + for (offset <- 0 to 11) { + val record = new SimpleRecord(mockTime.milliseconds, "a".getBytes) + appendRecords(List(record), initialOffset = offset) + if (offset % 3 == 2) + log.roll() + } + assertEquals(5, log.segments.numberOfSegments) + assertEquals(12L, log.logEndOffset) + + val expected = List[LogSegment]() ++ log.segments.values(9L, log.logEndOffset + 1) + // Truncate to an offset before the base offset of the active segment + val deleted = log.truncateTo(7L) + assertEquals(expected, deleted) + assertEquals(3, log.segments.numberOfSegments) + assertEquals(6L, log.segments.activeSegment.baseOffset) + assertEquals(0L, log.recoveryPoint) + assertEquals(7L, log.logEndOffset) + val fetchDataInfo = readRecords(startOffset = 6L) + assertEquals(1, fetchDataInfo.records.records.asScala.size) + assertEquals(Seq(KeyValue("", "a")), recordsToKvs(fetchDataInfo.records.records.asScala)) + + // Verify that we can still append to the active segment + val record = new SimpleRecord(mockTime.milliseconds, "a".getBytes) + appendRecords(List(record), initialOffset = 7L) + assertEquals(8L, log.logEndOffset) + } + + @Test + def testNonActiveSegmentsFrom(): Unit = { + for (i <- 0 until 5) { + val keyValues = Seq(KeyValue(i.toString, i.toString)) + appendRecords(kvsToRecords(keyValues), initialOffset = i) + log.roll() + } + + def nonActiveBaseOffsetsFrom(startOffset: Long): Seq[Long] = { + log.segments.nonActiveLogSegmentsFrom(startOffset).map(_.baseOffset).toSeq + } + + assertEquals(5L, log.segments.activeSegment.baseOffset) + assertEquals(0 until 5, nonActiveBaseOffsetsFrom(0L)) + assertEquals(Seq.empty, nonActiveBaseOffsetsFrom(5L)) + assertEquals(2 until 5, nonActiveBaseOffsetsFrom(2L)) + assertEquals(Seq.empty, nonActiveBaseOffsetsFrom(6L)) + } + + private def topicPartitionName(topic: String, partition: String): String = topic + "-" + partition + + @Test + def testParseTopicPartitionName(): Unit = { + val topic = "test_topic" + val partition = "143" + val dir = new File(logDir, topicPartitionName(topic, partition)) + val topicPartition = LocalLog.parseTopicPartitionName(dir) + assertEquals(topic, topicPartition.topic) + assertEquals(partition.toInt, topicPartition.partition) + } + + /** + * Tests that log directories with a period in their name that have been marked for deletion + * are parsed correctly by `Log.parseTopicPartitionName` (see KAFKA-5232 for details). 
+ */ + @Test + def testParseTopicPartitionNameWithPeriodForDeletedTopic(): Unit = { + val topic = "foo.bar-testtopic" + val partition = "42" + val dir = new File(logDir, LocalLog.logDeleteDirName(new TopicPartition(topic, partition.toInt))) + val topicPartition = LocalLog.parseTopicPartitionName(dir) + assertEquals(topic, topicPartition.topic, "Unexpected topic name parsed") + assertEquals(partition.toInt, topicPartition.partition, "Unexpected partition number parsed") + } + + @Test + def testParseTopicPartitionNameForEmptyName(): Unit = { + val dir = new File("") + assertThrows(classOf[KafkaException], () => LocalLog.parseTopicPartitionName(dir), + () => "KafkaException should have been thrown for dir: " + dir.getCanonicalPath) + } + + @Test + def testParseTopicPartitionNameForNull(): Unit = { + val dir: File = null + assertThrows(classOf[KafkaException], () => LocalLog.parseTopicPartitionName(dir), + () => "KafkaException should have been thrown for dir: " + dir) + } + + @Test + def testParseTopicPartitionNameForMissingSeparator(): Unit = { + val topic = "test_topic" + val partition = "1999" + val dir = new File(logDir, topic + partition) + assertThrows(classOf[KafkaException], () => LocalLog.parseTopicPartitionName(dir), + () => "KafkaException should have been thrown for dir: " + dir.getCanonicalPath) + // also test the "-delete" marker case + val deleteMarkerDir = new File(logDir, topic + partition + "." + LocalLog.DeleteDirSuffix) + assertThrows(classOf[KafkaException], () => LocalLog.parseTopicPartitionName(deleteMarkerDir), + () => "KafkaException should have been thrown for dir: " + deleteMarkerDir.getCanonicalPath) + } + + @Test + def testParseTopicPartitionNameForMissingTopic(): Unit = { + val topic = "" + val partition = "1999" + val dir = new File(logDir, topicPartitionName(topic, partition)) + assertThrows(classOf[KafkaException], () => LocalLog.parseTopicPartitionName(dir), + () => "KafkaException should have been thrown for dir: " + dir.getCanonicalPath) + + // also test the "-delete" marker case + val deleteMarkerDir = new File(logDir, LocalLog.logDeleteDirName(new TopicPartition(topic, partition.toInt))) + + assertThrows(classOf[KafkaException], () => LocalLog.parseTopicPartitionName(deleteMarkerDir), + () => "KafkaException should have been thrown for dir: " + deleteMarkerDir.getCanonicalPath) + } + + @Test + def testParseTopicPartitionNameForMissingPartition(): Unit = { + val topic = "test_topic" + val partition = "" + val dir = new File(logDir.getPath + topicPartitionName(topic, partition)) + assertThrows(classOf[KafkaException], () => LocalLog.parseTopicPartitionName(dir), + () => "KafkaException should have been thrown for dir: " + dir.getCanonicalPath) + + // also test the "-delete" marker case + val deleteMarkerDir = new File(logDir, topicPartitionName(topic, partition) + "." + LocalLog.DeleteDirSuffix) + assertThrows(classOf[KafkaException], () => LocalLog.parseTopicPartitionName(deleteMarkerDir), + () => "KafkaException should have been thrown for dir: " + deleteMarkerDir.getCanonicalPath) + } + + @Test + def testParseTopicPartitionNameForInvalidPartition(): Unit = { + val topic = "test_topic" + val partition = "1999a" + val dir = new File(logDir, topicPartitionName(topic, partition)) + assertThrows(classOf[KafkaException], () => LocalLog.parseTopicPartitionName(dir), + () => "KafkaException should have been thrown for dir: " + dir.getCanonicalPath) + + // also test the "-delete" marker case + val deleteMarkerDir = new File(logDir, topic + partition + "." 
+ LocalLog.DeleteDirSuffix) + assertThrows(classOf[KafkaException], () => LocalLog.parseTopicPartitionName(deleteMarkerDir), + () => "KafkaException should have been thrown for dir: " + deleteMarkerDir.getCanonicalPath) + } + + @Test + def testParseTopicPartitionNameForExistingInvalidDir(): Unit = { + val dir1 = new File(logDir.getPath + "/non_kafka_dir") + assertThrows(classOf[KafkaException], () => LocalLog.parseTopicPartitionName(dir1), + () => "KafkaException should have been thrown for dir: " + dir1.getCanonicalPath) + val dir2 = new File(logDir.getPath + "/non_kafka_dir-delete") + assertThrows(classOf[KafkaException], () => LocalLog.parseTopicPartitionName(dir2), + () => "KafkaException should have been thrown for dir: " + dir2.getCanonicalPath) + } + + @Test + def testLogDeleteDirName(): Unit = { + val name1 = LocalLog.logDeleteDirName(new TopicPartition("foo", 3)) + assertTrue(name1.length <= 255) + assertTrue(Pattern.compile("foo-3\\.[0-9a-z]{32}-delete").matcher(name1).matches()) + assertTrue(LocalLog.DeleteDirPattern.matcher(name1).matches()) + assertFalse(LocalLog.FutureDirPattern.matcher(name1).matches()) + val name2 = LocalLog.logDeleteDirName( + new TopicPartition("n" + String.join("", Collections.nCopies(248, "o")), 5)) + assertEquals(255, name2.length) + assertTrue(Pattern.compile("n[o]{212}-5\\.[0-9a-z]{32}-delete").matcher(name2).matches()) + assertTrue(LocalLog.DeleteDirPattern.matcher(name2).matches()) + assertFalse(LocalLog.FutureDirPattern.matcher(name2).matches()) + } + + @Test + def testOffsetFromFile(): Unit = { + val offset = 23423423L + + val logFile = LocalLog.logFile(tmpDir, offset) + assertEquals(offset, LocalLog.offsetFromFile(logFile)) + + val offsetIndexFile = LocalLog.offsetIndexFile(tmpDir, offset) + assertEquals(offset, LocalLog.offsetFromFile(offsetIndexFile)) + + val timeIndexFile = LocalLog.timeIndexFile(tmpDir, offset) + assertEquals(offset, LocalLog.offsetFromFile(timeIndexFile)) + } + + @Test + def testRollSegmentThatAlreadyExists(): Unit = { + assertEquals(1, log.segments.numberOfSegments, "Log begins with a single empty segment.") + + // roll active segment with the same base offset of size zero should recreate the segment + log.roll(Some(0L)) + assertEquals(1, log.segments.numberOfSegments, "Expect 1 segment after roll() empty segment with base offset.") + + // should be able to append records to active segment + val keyValues1 = List(KeyValue("k1", "v1")) + appendRecords(kvsToRecords(keyValues1)) + assertEquals(0L, log.segments.activeSegment.baseOffset) + // make sure we can append more records + val keyValues2 = List(KeyValue("k2", "v2")) + appendRecords(keyValues2.map(_.toRecord(mockTime.milliseconds + 10)), initialOffset = 1L) + assertEquals(2, log.logEndOffset, "Expect two records in the log") + val readResult = readRecords() + assertEquals(2L, readResult.records.records.asScala.size) + assertEquals(keyValues1 ++ keyValues2, recordsToKvs(readResult.records.records.asScala)) + + // roll so that active segment is empty + log.roll() + assertEquals(2L, log.segments.activeSegment.baseOffset, "Expect base offset of active segment to be LEO") + assertEquals(2, log.segments.numberOfSegments, "Expect two segments.") + assertEquals(2L, log.logEndOffset) + } + + @Test + def testNewSegmentsAfterRoll(): Unit = { + assertEquals(1, log.segments.numberOfSegments, "Log begins with a single empty segment.") + + // roll active segment with the same base offset of size zero should recreate the segment + { + val newSegment = log.roll() + assertEquals(0L, 
newSegment.baseOffset) + assertEquals(1, log.segments.numberOfSegments) + assertEquals(0L, log.logEndOffset) + } + + appendRecords(List(KeyValue("k1", "v1").toRecord())) + + { + val newSegment = log.roll() + assertEquals(1L, newSegment.baseOffset) + assertEquals(2, log.segments.numberOfSegments) + assertEquals(1L, log.logEndOffset) + } + + appendRecords(List(KeyValue("k2", "v2").toRecord()), initialOffset = 1L) + + { + val newSegment = log.roll(Some(1L)) + assertEquals(2L, newSegment.baseOffset) + assertEquals(3, log.segments.numberOfSegments) + assertEquals(2L, log.logEndOffset) + } + } + + @Test + def testRollSegmentErrorWhenNextOffsetIsIllegal(): Unit = { + assertEquals(1, log.segments.numberOfSegments, "Log begins with a single empty segment.") + + val keyValues = List(KeyValue("k1", "v1"), KeyValue("k2", "v2"), KeyValue("k3", "v3")) + appendRecords(kvsToRecords(keyValues)) + assertEquals(0L, log.segments.activeSegment.baseOffset) + assertEquals(3, log.logEndOffset, "Expect three records in the log") + + // roll to create an empty active segment + log.roll() + assertEquals(3L, log.segments.activeSegment.baseOffset) + + // intentionally set up the logEndOffset to introduce an error later + log.updateLogEndOffset(1L) + + // expect an error because of an attempt to roll to a new offset (1L) that's lower than the + // base offset (3L) of the active segment + assertThrows(classOf[KafkaException], () => log.roll()) + } + + private def createLocalLogWithActiveSegment(dir: File = logDir, + config: LogConfig, + segments: LogSegments = new LogSegments(topicPartition), + recoveryPoint: Long = 0L, + nextOffsetMetadata: LogOffsetMetadata = new LogOffsetMetadata(0L, 0L, 0), + scheduler: Scheduler = mockTime.scheduler, + time: Time = mockTime, + topicPartition: TopicPartition = topicPartition, + logDirFailureChannel: LogDirFailureChannel = logDirFailureChannel): LocalLog = { + segments.add(LogSegment.open(dir = dir, + baseOffset = 0L, + config, + time = time, + initFileSize = config.initFileSize, + preallocate = config.preallocate)) + new LocalLog(_dir = dir, + config = config, + segments = segments, + recoveryPoint = recoveryPoint, + nextOffsetMetadata = nextOffsetMetadata, + scheduler = scheduler, + time = time, + topicPartition = topicPartition, + logDirFailureChannel = logDirFailureChannel) + } +} diff --git a/core/src/test/scala/unit/kafka/log/LogCleanerIntegrationTest.scala b/core/src/test/scala/unit/kafka/log/LogCleanerIntegrationTest.scala new file mode 100644 index 0000000..c979743 --- /dev/null +++ b/core/src/test/scala/unit/kafka/log/LogCleanerIntegrationTest.scala @@ -0,0 +1,233 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package kafka.log + +import java.io.PrintWriter + +import com.yammer.metrics.core.{Gauge, MetricName} +import kafka.metrics.{KafkaMetricsGroup, KafkaYammerMetrics} +import kafka.utils.{MockTime, TestUtils} +import org.apache.kafka.common.TopicPartition +import org.apache.kafka.common.record.{CompressionType, RecordBatch} +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.{AfterEach, Test} + +import scala.collection.{Iterable, Seq} +import scala.jdk.CollectionConverters._ + +/** + * This is an integration test that tests the fully integrated log cleaner + */ +class LogCleanerIntegrationTest extends AbstractLogCleanerIntegrationTest with KafkaMetricsGroup { + + val codec: CompressionType = CompressionType.LZ4 + + val time = new MockTime() + val topicPartitions = Array(new TopicPartition("log", 0), new TopicPartition("log", 1), new TopicPartition("log", 2)) + + @AfterEach + def cleanup(): Unit = { + TestUtils.clearYammerMetrics() + } + + @Test + def testMarksPartitionsAsOfflineAndPopulatesUncleanableMetrics(): Unit = { + val largeMessageKey = 20 + val (_, largeMessageSet) = createLargeSingleMessageSet(largeMessageKey, RecordBatch.CURRENT_MAGIC_VALUE, codec) + val maxMessageSize = largeMessageSet.sizeInBytes + cleaner = makeCleaner(partitions = topicPartitions, maxMessageSize = maxMessageSize, backOffMs = 100) + + def breakPartitionLog(tp: TopicPartition): Unit = { + val log = cleaner.logs.get(tp) + writeDups(numKeys = 20, numDups = 3, log = log, codec = codec) + + val partitionFile = log.logSegments.last.log.file() + val writer = new PrintWriter(partitionFile) + writer.write("jogeajgoea") + writer.close() + + writeDups(numKeys = 20, numDups = 3, log = log, codec = codec) + } + + breakPartitionLog(topicPartitions(0)) + breakPartitionLog(topicPartitions(1)) + + cleaner.startup() + + val log = cleaner.logs.get(topicPartitions(0)) + val log2 = cleaner.logs.get(topicPartitions(1)) + val uncleanableDirectory = log.dir.getParent + val uncleanablePartitionsCountGauge = getGauge[Int]("uncleanable-partitions-count", uncleanableDirectory) + val uncleanableBytesGauge = getGauge[Long]("uncleanable-bytes", uncleanableDirectory) + + TestUtils.waitUntilTrue(() => uncleanablePartitionsCountGauge.value() == 2, "There should be 2 uncleanable partitions", 2000L) + val expectedTotalUncleanableBytes = LogCleanerManager.calculateCleanableBytes(log, 0, log.logSegments.last.baseOffset)._2 + + LogCleanerManager.calculateCleanableBytes(log2, 0, log2.logSegments.last.baseOffset)._2 + TestUtils.waitUntilTrue(() => uncleanableBytesGauge.value() == expectedTotalUncleanableBytes, + s"There should be $expectedTotalUncleanableBytes uncleanable bytes", 1000L) + + val uncleanablePartitions = cleaner.cleanerManager.uncleanablePartitions(uncleanableDirectory) + assertTrue(uncleanablePartitions.contains(topicPartitions(0))) + assertTrue(uncleanablePartitions.contains(topicPartitions(1))) + assertFalse(uncleanablePartitions.contains(topicPartitions(2))) + + // Delete one partition + cleaner.logs.remove(topicPartitions(0)) + TestUtils.waitUntilTrue( + () => { + time.sleep(1000) + uncleanablePartitionsCountGauge.value() == 1 + }, + "There should be 1 uncleanable partitions", + 2000L) + + val uncleanablePartitions2 = cleaner.cleanerManager.uncleanablePartitions(uncleanableDirectory) + assertFalse(uncleanablePartitions2.contains(topicPartitions(0))) + assertTrue(uncleanablePartitions2.contains(topicPartitions(1))) + assertFalse(uncleanablePartitions2.contains(topicPartitions(2))) + } + + private def 
getGauge[T](filter: MetricName => Boolean): Gauge[T] = { + KafkaYammerMetrics.defaultRegistry.allMetrics.asScala + .filter { case (k, _) => filter(k) } + .headOption + .getOrElse { fail(s"Unable to find metric") } + .asInstanceOf[(Any, Gauge[Any])] + ._2 + .asInstanceOf[Gauge[T]] + } + + private def getGauge[T](metricName: String): Gauge[T] = { + getGauge(mName => mName.getName.endsWith(metricName) && mName.getScope == null) + } + + private def getGauge[T](metricName: String, metricScope: String): Gauge[T] = { + getGauge(k => k.getName.endsWith(metricName) && k.getScope.endsWith(metricScope)) + } + + @Test + def testMaxLogCompactionLag(): Unit = { + val msPerHour = 60 * 60 * 1000 + + val minCompactionLagMs = 1 * msPerHour + val maxCompactionLagMs = 6 * msPerHour + + val cleanerBackOffMs = 200L + val segmentSize = 512 + val topicPartitions = Array(new TopicPartition("log", 0), new TopicPartition("log", 1), new TopicPartition("log", 2)) + val minCleanableDirtyRatio = 1.0F + + cleaner = makeCleaner(partitions = topicPartitions, + backOffMs = cleanerBackOffMs, + minCompactionLagMs = minCompactionLagMs, + segmentSize = segmentSize, + maxCompactionLagMs= maxCompactionLagMs, + minCleanableDirtyRatio = minCleanableDirtyRatio) + val log = cleaner.logs.get(topicPartitions(0)) + + val T0 = time.milliseconds + writeKeyDups(numKeys = 100, numDups = 3, log, CompressionType.NONE, timestamp = T0, startValue = 0, step = 1) + + val startSizeBlock0 = log.size + + val activeSegAtT0 = log.activeSegment + + cleaner.startup() + + // advance to a time still less than maxCompactionLagMs from start + time.sleep(maxCompactionLagMs/2) + Thread.sleep(5 * cleanerBackOffMs) // give cleaning thread a chance to _not_ clean + assertEquals(startSizeBlock0, log.size, "There should be no cleaning until the max compaction lag has passed") + + // advance to time a bit more than one maxCompactionLagMs from start + time.sleep(maxCompactionLagMs/2 + 1) + val T1 = time.milliseconds + + // write the second block of data: all zero keys + val appends1 = writeKeyDups(numKeys = 100, numDups = 1, log, CompressionType.NONE, timestamp = T1, startValue = 0, step = 0) + + // roll the active segment + log.roll() + val activeSegAtT1 = log.activeSegment + val firstBlockCleanableSegmentOffset = activeSegAtT0.baseOffset + + // the first block should get cleaned + cleaner.awaitCleaned(new TopicPartition("log", 0), firstBlockCleanableSegmentOffset) + + val read1 = readFromLog(log) + val lastCleaned = cleaner.cleanerManager.allCleanerCheckpoints(new TopicPartition("log", 0)) + assertTrue(lastCleaned >= firstBlockCleanableSegmentOffset, + s"log cleaner should have processed at least to offset $firstBlockCleanableSegmentOffset, but lastCleaned=$lastCleaned") + + //minCleanableDirtyRatio will prevent second block of data from compacting + assertNotEquals(appends1, read1, s"log should still contain non-zero keys") + + time.sleep(maxCompactionLagMs + 1) + // the second block should get cleaned. 
only zero keys left + cleaner.awaitCleaned(new TopicPartition("log", 0), activeSegAtT1.baseOffset) + + val read2 = readFromLog(log) + + assertEquals(appends1, read2, s"log should only contain zero keys now") + + val lastCleaned2 = cleaner.cleanerManager.allCleanerCheckpoints(new TopicPartition("log", 0)) + val secondBlockCleanableSegmentOffset = activeSegAtT1.baseOffset + assertTrue(lastCleaned2 >= secondBlockCleanableSegmentOffset, + s"log cleaner should have processed at least to offset $secondBlockCleanableSegmentOffset, but lastCleaned=$lastCleaned2") + } + + private def readFromLog(log: UnifiedLog): Iterable[(Int, Int)] = { + for (segment <- log.logSegments; record <- segment.log.records.asScala) yield { + val key = TestUtils.readString(record.key).toInt + val value = TestUtils.readString(record.value).toInt + key -> value + } + } + + private def writeKeyDups(numKeys: Int, numDups: Int, log: UnifiedLog, codec: CompressionType, timestamp: Long, + startValue: Int, step: Int): Seq[(Int, Int)] = { + var valCounter = startValue + for (_ <- 0 until numDups; key <- 0 until numKeys) yield { + val curValue = valCounter + log.appendAsLeader(TestUtils.singletonRecords(value = curValue.toString.getBytes, codec = codec, + key = key.toString.getBytes, timestamp = timestamp), leaderEpoch = 0) + // move LSO forward to increase compaction bound + log.updateHighWatermark(log.logEndOffset) + valCounter += step + (key, curValue) + } + } + + @Test + def testIsThreadFailed(): Unit = { + val metricName = "DeadThreadCount" + cleaner = makeCleaner(partitions = topicPartitions, maxMessageSize = 100000, backOffMs = 100) + cleaner.startup() + assertEquals(0, cleaner.deadThreadCount) + // we simulate the unexpected error with an interrupt + cleaner.cleaners.foreach(_.interrupt()) + // wait until interruption is propagated to all the threads + TestUtils.waitUntilTrue( + () => cleaner.cleaners.foldLeft(true)((result, thread) => { + thread.isThreadFailed && result + }), "Threads didn't terminate unexpectedly" + ) + assertEquals(cleaner.cleaners.size, getGauge[Int](metricName).value()) + assertEquals(cleaner.cleaners.size, cleaner.deadThreadCount) + } +} diff --git a/core/src/test/scala/unit/kafka/log/LogCleanerLagIntegrationTest.scala b/core/src/test/scala/unit/kafka/log/LogCleanerLagIntegrationTest.scala new file mode 100644 index 0000000..7e0a33d --- /dev/null +++ b/core/src/test/scala/unit/kafka/log/LogCleanerLagIntegrationTest.scala @@ -0,0 +1,126 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package kafka.log + +import kafka.utils._ +import org.apache.kafka.common.TopicPartition +import org.apache.kafka.common.record.CompressionType +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.params.ParameterizedTest +import org.junit.jupiter.params.provider.{Arguments, MethodSource} + +import scala.collection._ +import scala.jdk.CollectionConverters._ + +/** + * This is an integration test that tests the fully integrated log cleaner + */ +class LogCleanerLagIntegrationTest extends AbstractLogCleanerIntegrationTest with Logging { + val msPerHour = 60 * 60 * 1000 + + val minCompactionLag = 1 * msPerHour + assertTrue(minCompactionLag % 2 == 0, "compactionLag must be divisible by 2 for this test") + + val time = new MockTime(1400000000000L, 1000L) // Tue May 13 16:53:20 UTC 2014 for `currentTimeMs` + val cleanerBackOffMs = 200L + val segmentSize = 512 + + val topicPartitions = Array(new TopicPartition("log", 0), new TopicPartition("log", 1), new TopicPartition("log", 2)) + + @ParameterizedTest + @MethodSource(Array("parameters")) + def cleanerTest(codec: CompressionType): Unit = { + cleaner = makeCleaner(partitions = topicPartitions, + backOffMs = cleanerBackOffMs, + minCompactionLagMs = minCompactionLag, + segmentSize = segmentSize) + val log = cleaner.logs.get(topicPartitions(0)) + + // t = T0 + val T0 = time.milliseconds + val appends0 = writeDups(numKeys = 100, numDups = 3, log, codec, timestamp = T0) + val startSizeBlock0 = log.size + debug(s"total log size at T0: $startSizeBlock0") + + val activeSegAtT0 = log.activeSegment + debug(s"active segment at T0 has base offset: ${activeSegAtT0.baseOffset}") + val sizeUpToActiveSegmentAtT0 = log.logSegments(0L, activeSegAtT0.baseOffset).map(_.size).sum + debug(s"log size up to base offset of active segment at T0: $sizeUpToActiveSegmentAtT0") + + cleaner.startup() + + // T0 < t < T1 + // advance to a time still less than one compaction lag from start + time.sleep(minCompactionLag/2) + Thread.sleep(5 * cleanerBackOffMs) // give cleaning thread a chance to _not_ clean + assertEquals(startSizeBlock0, log.size, "There should be no cleaning until the compaction lag has passed") + + // t = T1 > T0 + compactionLag + // advance to time a bit more than one compaction lag from start + time.sleep(minCompactionLag/2 + 1) + val T1 = time.milliseconds + + // write another block of data + val appends1 = appends0 ++ writeDups(numKeys = 100, numDups = 3, log, codec, timestamp = T1) + val firstBlock1SegmentBaseOffset = activeSegAtT0.baseOffset + + // the first block should get cleaned + cleaner.awaitCleaned(new TopicPartition("log", 0), activeSegAtT0.baseOffset) + + // check the data is the same + val read1 = readFromLog(log) + assertEquals(appends1.toMap, read1.toMap, "Contents of the map shouldn't change.") + + val compactedSize = log.logSegments(0L, activeSegAtT0.baseOffset).map(_.size).sum + debug(s"after cleaning the compacted size up to active segment at T0: $compactedSize") + val lastCleaned = cleaner.cleanerManager.allCleanerCheckpoints(new TopicPartition("log", 0)) + assertTrue(lastCleaned >= firstBlock1SegmentBaseOffset, s"log cleaner should have processed up to offset $firstBlock1SegmentBaseOffset, but lastCleaned=$lastCleaned") + assertTrue(sizeUpToActiveSegmentAtT0 > compactedSize, s"log should have been compacted: size up to offset of active segment at T0=$sizeUpToActiveSegmentAtT0 compacted size=$compactedSize") + } + + private def readFromLog(log: UnifiedLog): Iterable[(Int, Int)] = { + for (segment <- log.logSegments; 
record <- segment.log.records.asScala) yield { + val key = TestUtils.readString(record.key).toInt + val value = TestUtils.readString(record.value).toInt + key -> value + } + } + + private def writeDups(numKeys: Int, numDups: Int, log: UnifiedLog, codec: CompressionType, timestamp: Long): Seq[(Int, Int)] = { + for (_ <- 0 until numDups; key <- 0 until numKeys) yield { + val count = counter + log.appendAsLeader(TestUtils.singletonRecords(value = counter.toString.getBytes, codec = codec, + key = key.toString.getBytes, timestamp = timestamp), leaderEpoch = 0) + // move LSO forward to increase compaction bound + log.updateHighWatermark(log.logEndOffset) + incCounter() + (key, count) + } + } +} + +object LogCleanerLagIntegrationTest { + def oneParameter: java.util.Collection[Array[String]] = { + val l = new java.util.ArrayList[Array[String]]() + l.add(Array("NONE")) + l + } + + def parameters: java.util.stream.Stream[Arguments] = + java.util.Arrays.stream(CompressionType.values.map(codec => Arguments.of(codec))) +} diff --git a/core/src/test/scala/unit/kafka/log/LogCleanerManagerTest.scala b/core/src/test/scala/unit/kafka/log/LogCleanerManagerTest.scala new file mode 100644 index 0000000..c4a71cc --- /dev/null +++ b/core/src/test/scala/unit/kafka/log/LogCleanerManagerTest.scala @@ -0,0 +1,858 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package kafka.log + +import java.io.File +import java.nio.file.Files +import java.util.Properties + +import kafka.server.{BrokerTopicStats, LogDirFailureChannel} +import kafka.utils._ +import org.apache.kafka.common.TopicPartition +import org.apache.kafka.common.record._ +import org.apache.kafka.common.utils.Utils +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.{AfterEach, Test} + +import scala.collection.mutable + +/** + * Unit tests for the log cleaning logic + */ +class LogCleanerManagerTest extends Logging { + + val tmpDir = TestUtils.tempDir() + val tmpDir2 = TestUtils.tempDir() + val logDir = TestUtils.randomPartitionLogDir(tmpDir) + val logDir2 = TestUtils.randomPartitionLogDir(tmpDir) + val topicPartition = new TopicPartition("log", 0) + val topicPartition2 = new TopicPartition("log2", 0) + val logProps = new Properties() + logProps.put(LogConfig.SegmentBytesProp, 1024: java.lang.Integer) + logProps.put(LogConfig.SegmentIndexBytesProp, 1024: java.lang.Integer) + logProps.put(LogConfig.CleanupPolicyProp, LogConfig.Compact) + val logConfig = LogConfig(logProps) + val time = new MockTime(1400000000000L, 1000L) // Tue May 13 16:53:20 UTC 2014 for `currentTimeMs` + val offset = 999 + + val cleanerCheckpoints: mutable.Map[TopicPartition, Long] = mutable.Map[TopicPartition, Long]() + + class LogCleanerManagerMock(logDirs: Seq[File], + logs: Pool[TopicPartition, UnifiedLog], + logDirFailureChannel: LogDirFailureChannel) extends LogCleanerManager(logDirs, logs, logDirFailureChannel) { + override def allCleanerCheckpoints: Map[TopicPartition, Long] = { + cleanerCheckpoints.toMap + } + + override def updateCheckpoints(dataDir: File, partitionToUpdateOrAdd: Option[(TopicPartition,Long)] = None, + partitionToRemove: Option[TopicPartition] = None): Unit = { + assert(partitionToRemove.isEmpty, "partitionToRemove argument with value not yet handled") + val (tp, offset) = partitionToUpdateOrAdd.getOrElse( + throw new IllegalArgumentException("partitionToUpdateOrAdd==None argument not yet handled")) + cleanerCheckpoints.put(tp, offset) + } + } + + @AfterEach + def tearDown(): Unit = { + Utils.delete(tmpDir) + } + + private def setupIncreasinglyFilthyLogs(partitions: Seq[TopicPartition], + startNumBatches: Int, + batchIncrement: Int): Pool[TopicPartition, UnifiedLog] = { + val logs = new Pool[TopicPartition, UnifiedLog]() + var numBatches = startNumBatches + + for (tp <- partitions) { + val log = createLog(2048, LogConfig.Compact, topicPartition = tp) + logs.put(tp, log) + + writeRecords(log, numBatches = numBatches, recordsPerBatch = 1, batchesPerSegment = 5) + numBatches += batchIncrement + } + logs + } + + @Test + def testGrabFilthiestCompactedLogThrowsException(): Unit = { + val tp = new TopicPartition("A", 1) + val logSegmentSize = TestUtils.singletonRecords("test".getBytes).sizeInBytes * 10 + val logSegmentsCount = 2 + val tpDir = new File(logDir, "A-1") + Files.createDirectories(tpDir.toPath) + val logDirFailureChannel = new LogDirFailureChannel(10) + val config = createLowRetentionLogConfig(logSegmentSize, LogConfig.Compact) + val maxProducerIdExpirationMs = 60 * 60 * 1000 + val segments = new LogSegments(tp) + val leaderEpochCache = UnifiedLog.maybeCreateLeaderEpochCache(tpDir, topicPartition, logDirFailureChannel, config.recordVersion, "") + val producerStateManager = new ProducerStateManager(topicPartition, tpDir, maxProducerIdExpirationMs, time) + val offsets = LogLoader.load(LoadLogParams( + tpDir, + tp, + config, + time.scheduler, + time, + 
logDirFailureChannel, + hadCleanShutdown = true, + segments, + 0L, + 0L, + maxProducerIdExpirationMs, + leaderEpochCache, + producerStateManager)) + val localLog = new LocalLog(tpDir, config, segments, offsets.recoveryPoint, + offsets.nextOffsetMetadata, time.scheduler, time, tp, logDirFailureChannel) + // the exception should be caught and the partition that caused it marked as uncleanable + class LogMock extends UnifiedLog(offsets.logStartOffset, localLog, new BrokerTopicStats, + LogManager.ProducerIdExpirationCheckIntervalMs, leaderEpochCache, + producerStateManager, _topicId = None, keepPartitionMetadataFile = true) { + // Throw an error in getFirstBatchTimestampForSegments since it is called in grabFilthiestLog() + override def getFirstBatchTimestampForSegments(segments: Iterable[LogSegment]): Iterable[Long] = + throw new IllegalStateException("Error!") + } + + val log: UnifiedLog = new LogMock() + writeRecords(log = log, + numBatches = logSegmentsCount * 2, + recordsPerBatch = 10, + batchesPerSegment = 2 + ) + + val logsPool = new Pool[TopicPartition, UnifiedLog]() + logsPool.put(tp, log) + val cleanerManager = createCleanerManagerMock(logsPool) + cleanerCheckpoints.put(tp, 1) + + val thrownException = assertThrows(classOf[LogCleaningException], () => cleanerManager.grabFilthiestCompactedLog(time).get) + assertEquals(log, thrownException.log) + assertTrue(thrownException.getCause.isInstanceOf[IllegalStateException]) + } + + @Test + def testGrabFilthiestCompactedLogReturnsLogWithDirtiestRatio(): Unit = { + val tp0 = new TopicPartition("wishing-well", 0) + val tp1 = new TopicPartition("wishing-well", 1) + val tp2 = new TopicPartition("wishing-well", 2) + val partitions = Seq(tp0, tp1, tp2) + + // setup logs with cleanable range: [20, 20], [20, 25], [20, 30] + val logs = setupIncreasinglyFilthyLogs(partitions, startNumBatches = 20, batchIncrement = 5) + val cleanerManager = createCleanerManagerMock(logs) + partitions.foreach(partition => cleanerCheckpoints.put(partition, 20)) + + val filthiestLog: LogToClean = cleanerManager.grabFilthiestCompactedLog(time).get + assertEquals(tp2, filthiestLog.topicPartition) + assertEquals(tp2, filthiestLog.log.topicPartition) + } + + @Test + def testGrabFilthiestCompactedLogIgnoresUncleanablePartitions(): Unit = { + val tp0 = new TopicPartition("wishing-well", 0) + val tp1 = new TopicPartition("wishing-well", 1) + val tp2 = new TopicPartition("wishing-well", 2) + val partitions = Seq(tp0, tp1, tp2) + + // setup logs with cleanable range: [20, 20], [20, 25], [20, 30] + val logs = setupIncreasinglyFilthyLogs(partitions, startNumBatches = 20, batchIncrement = 5) + val cleanerManager = createCleanerManagerMock(logs) + partitions.foreach(partition => cleanerCheckpoints.put(partition, 20)) + + cleanerManager.markPartitionUncleanable(logs.get(tp2).dir.getParent, tp2) + + val filthiestLog: LogToClean = cleanerManager.grabFilthiestCompactedLog(time).get + assertEquals(tp1, filthiestLog.topicPartition) + assertEquals(tp1, filthiestLog.log.topicPartition) + } + + @Test + def testGrabFilthiestCompactedLogIgnoresInProgressPartitions(): Unit = { + val tp0 = new TopicPartition("wishing-well", 0) + val tp1 = new TopicPartition("wishing-well", 1) + val tp2 = new TopicPartition("wishing-well", 2) + val partitions = Seq(tp0, tp1, tp2) + + // setup logs with cleanable range: [20, 20], [20, 25], [20, 30] + val logs = setupIncreasinglyFilthyLogs(partitions, startNumBatches = 20, batchIncrement = 5) + val cleanerManager = createCleanerManagerMock(logs) + 
partitions.foreach(partition => cleanerCheckpoints.put(partition, 20)) + + cleanerManager.setCleaningState(tp2, LogCleaningInProgress) + + val filthiestLog: LogToClean = cleanerManager.grabFilthiestCompactedLog(time).get + assertEquals(tp1, filthiestLog.topicPartition) + assertEquals(tp1, filthiestLog.log.topicPartition) + } + + @Test + def testGrabFilthiestCompactedLogIgnoresBothInProgressPartitionsAndUncleanablePartitions(): Unit = { + val tp0 = new TopicPartition("wishing-well", 0) + val tp1 = new TopicPartition("wishing-well", 1) + val tp2 = new TopicPartition("wishing-well", 2) + val partitions = Seq(tp0, tp1, tp2) + + // setup logs with cleanable range: [20, 20], [20, 25], [20, 30] + val logs = setupIncreasinglyFilthyLogs(partitions, startNumBatches = 20, batchIncrement = 5) + val cleanerManager = createCleanerManagerMock(logs) + partitions.foreach(partition => cleanerCheckpoints.put(partition, 20)) + + cleanerManager.setCleaningState(tp2, LogCleaningInProgress) + cleanerManager.markPartitionUncleanable(logs.get(tp1).dir.getParent, tp1) + + val filthiestLog: Option[LogToClean] = cleanerManager.grabFilthiestCompactedLog(time) + assertEquals(None, filthiestLog) + } + + @Test + def testDirtyOffsetResetIfLargerThanEndOffset(): Unit = { + val tp = new TopicPartition("foo", 0) + val logs = setupIncreasinglyFilthyLogs(Seq(tp), startNumBatches = 20, batchIncrement = 5) + val cleanerManager = createCleanerManagerMock(logs) + cleanerCheckpoints.put(tp, 200) + + val filthiestLog = cleanerManager.grabFilthiestCompactedLog(time).get + assertEquals(0L, filthiestLog.firstDirtyOffset) + } + + @Test + def testDirtyOffsetResetIfSmallerThanStartOffset(): Unit = { + val tp = new TopicPartition("foo", 0) + val logs = setupIncreasinglyFilthyLogs(Seq(tp), startNumBatches = 20, batchIncrement = 5) + + logs.get(tp).maybeIncrementLogStartOffset(10L, ClientRecordDeletion) + + val cleanerManager = createCleanerManagerMock(logs) + cleanerCheckpoints.put(tp, 0L) + + val filthiestLog = cleanerManager.grabFilthiestCompactedLog(time).get + assertEquals(10L, filthiestLog.firstDirtyOffset) + } + + @Test + def testLogStartOffsetLargerThanActiveSegmentBaseOffset(): Unit = { + val tp = new TopicPartition("foo", 0) + val log = createLog(segmentSize = 2048, LogConfig.Compact, tp) + + val logs = new Pool[TopicPartition, UnifiedLog]() + logs.put(tp, log) + + appendRecords(log, numRecords = 3) + appendRecords(log, numRecords = 3) + appendRecords(log, numRecords = 3) + + assertEquals(1, log.logSegments.size) + + log.maybeIncrementLogStartOffset(2L, ClientRecordDeletion) + + val cleanerManager = createCleanerManagerMock(logs) + cleanerCheckpoints.put(tp, 0L) + + // The active segment is uncleanable and hence not filthy from the POV of the CleanerManager. + val filthiestLog = cleanerManager.grabFilthiestCompactedLog(time) + assertEquals(None, filthiestLog) + } + + @Test + def testDirtyOffsetLargerThanActiveSegmentBaseOffset(): Unit = { + // It is possible in the case of an unclean leader election for the checkpoint + // dirty offset to get ahead of the active segment base offset, but still be + // within the range of the log. 
+ + val tp = new TopicPartition("foo", 0) + + val logs = new Pool[TopicPartition, UnifiedLog]() + val log = createLog(2048, LogConfig.Compact, topicPartition = tp) + logs.put(tp, log) + + appendRecords(log, numRecords = 3) + appendRecords(log, numRecords = 3) + + assertEquals(1, log.logSegments.size) + assertEquals(0L, log.activeSegment.baseOffset) + + val cleanerManager = createCleanerManagerMock(logs) + cleanerCheckpoints.put(tp, 3L) + + // These segments are uncleanable and hence not filthy + val filthiestLog = cleanerManager.grabFilthiestCompactedLog(time) + assertEquals(None, filthiestLog) + } + + /** + * When checking for logs with segments ready for deletion + * we shouldn't consider logs where cleanup.policy=delete + * as they are handled by the LogManager + */ + @Test + def testLogsWithSegmentsToDeleteShouldNotConsiderCleanupPolicyDeleteLogs(): Unit = { + val records = TestUtils.singletonRecords("test".getBytes) + val log: UnifiedLog = createLog(records.sizeInBytes * 5, LogConfig.Delete) + val cleanerManager: LogCleanerManager = createCleanerManager(log) + + val readyToDelete = cleanerManager.deletableLogs().size + assertEquals(0, readyToDelete, "should have 0 logs ready to be deleted") + } + + /** + * We should find logs with segments ready to be deleted when cleanup.policy=compact,delete + */ + @Test + def testLogsWithSegmentsToDeleteShouldConsiderCleanupPolicyCompactDeleteLogs(): Unit = { + val records = TestUtils.singletonRecords("test".getBytes, key="test".getBytes) + val log: UnifiedLog = createLog(records.sizeInBytes * 5, LogConfig.Compact + "," + LogConfig.Delete) + val cleanerManager: LogCleanerManager = createCleanerManager(log) + + val readyToDelete = cleanerManager.deletableLogs().size + assertEquals(1, readyToDelete, "should have 1 logs ready to be deleted") + } + + /** + * When looking for logs with segments ready to be deleted we should consider + * logs with cleanup.policy=compact because they may have segments from before the log start offset + */ + @Test + def testLogsWithSegmentsToDeleteShouldConsiderCleanupPolicyCompactLogs(): Unit = { + val records = TestUtils.singletonRecords("test".getBytes, key="test".getBytes) + val log: UnifiedLog = createLog(records.sizeInBytes * 5, LogConfig.Compact) + val cleanerManager: LogCleanerManager = createCleanerManager(log) + + val readyToDelete = cleanerManager.deletableLogs().size + assertEquals(1, readyToDelete, "should have 1 logs ready to be deleted") + } + + /** + * log under cleanup should be ineligible for compaction + */ + @Test + def testLogsUnderCleanupIneligibleForCompaction(): Unit = { + val records = TestUtils.singletonRecords("test".getBytes, key="test".getBytes) + val log: UnifiedLog = createLog(records.sizeInBytes * 5, LogConfig.Delete) + val cleanerManager: LogCleanerManager = createCleanerManager(log) + + log.appendAsLeader(records, leaderEpoch = 0) + log.roll() + log.appendAsLeader(records, leaderEpoch = 0) + log.updateHighWatermark(2L) + + // simulate cleanup thread working on the log partition + val deletableLog = cleanerManager.pauseCleaningForNonCompactedPartitions() + assertEquals(1, deletableLog.size, "should have 1 logs ready to be deleted") + + // change cleanup policy from delete to compact + val logProps = new Properties() + logProps.put(LogConfig.SegmentBytesProp, log.config.segmentSize) + logProps.put(LogConfig.RetentionMsProp, log.config.retentionMs) + logProps.put(LogConfig.CleanupPolicyProp, LogConfig.Compact) + logProps.put(LogConfig.MinCleanableDirtyRatioProp, 0: Integer) + val config = 
LogConfig(logProps) + log.updateConfig(config) + + // log cleanup inprogress, the log is not available for compaction + val cleanable = cleanerManager.grabFilthiestCompactedLog(time) + assertEquals(0, cleanable.size, "should have 0 logs ready to be compacted") + + // log cleanup finished, and log can be picked up for compaction + cleanerManager.resumeCleaning(deletableLog.map(_._1)) + val cleanable2 = cleanerManager.grabFilthiestCompactedLog(time) + assertEquals(1, cleanable2.size, "should have 1 logs ready to be compacted") + + // update cleanup policy to delete + logProps.put(LogConfig.CleanupPolicyProp, LogConfig.Delete) + val config2 = LogConfig(logProps) + log.updateConfig(config2) + + // compaction in progress, should have 0 log eligible for log cleanup + val deletableLog2 = cleanerManager.pauseCleaningForNonCompactedPartitions() + assertEquals(0, deletableLog2.size, "should have 0 logs ready to be deleted") + + // compaction done, should have 1 log eligible for log cleanup + cleanerManager.doneDeleting(Seq(cleanable2.get.topicPartition)) + val deletableLog3 = cleanerManager.pauseCleaningForNonCompactedPartitions() + assertEquals(1, deletableLog3.size, "should have 1 logs ready to be deleted") + } + + @Test + def testUpdateCheckpointsShouldAddOffsetToPartition(): Unit = { + val records = TestUtils.singletonRecords("test".getBytes, key="test".getBytes) + val log: UnifiedLog = createLog(records.sizeInBytes * 5, LogConfig.Compact) + val cleanerManager: LogCleanerManager = createCleanerManager(log) + + // expect the checkpoint offset is not the expectedOffset before doing updateCheckpoints + assertNotEquals(offset, cleanerManager.allCleanerCheckpoints.get(topicPartition).getOrElse(0)) + + cleanerManager.updateCheckpoints(logDir, partitionToUpdateOrAdd = Option(topicPartition, offset)) + // expect the checkpoint offset is now updated to the expected offset after doing updateCheckpoints + assertEquals(offset, cleanerManager.allCleanerCheckpoints(topicPartition)) + } + + @Test + def testUpdateCheckpointsShouldRemovePartitionData(): Unit = { + val records = TestUtils.singletonRecords("test".getBytes, key="test".getBytes) + val log: UnifiedLog = createLog(records.sizeInBytes * 5, LogConfig.Compact) + val cleanerManager: LogCleanerManager = createCleanerManager(log) + + // write some data into the cleaner-offset-checkpoint file + cleanerManager.updateCheckpoints(logDir, partitionToUpdateOrAdd = Option(topicPartition, offset)) + assertEquals(offset, cleanerManager.allCleanerCheckpoints(topicPartition)) + + // updateCheckpoints should remove the topicPartition data in the logDir + cleanerManager.updateCheckpoints(logDir, partitionToRemove = Option(topicPartition)) + assertTrue(cleanerManager.allCleanerCheckpoints.get(topicPartition).isEmpty) + } + + @Test + def testHandleLogDirFailureShouldRemoveDirAndData(): Unit = { + val records = TestUtils.singletonRecords("test".getBytes, key="test".getBytes) + val log: UnifiedLog = createLog(records.sizeInBytes * 5, LogConfig.Compact) + val cleanerManager: LogCleanerManager = createCleanerManager(log) + + // write some data into the cleaner-offset-checkpoint file in logDir and logDir2 + cleanerManager.updateCheckpoints(logDir, partitionToUpdateOrAdd = Option(topicPartition, offset)) + cleanerManager.updateCheckpoints(logDir2, partitionToUpdateOrAdd = Option(topicPartition2, offset)) + assertEquals(offset, cleanerManager.allCleanerCheckpoints(topicPartition)) + assertEquals(offset, cleanerManager.allCleanerCheckpoints(topicPartition2)) + + 
cleanerManager.handleLogDirFailure(logDir.getAbsolutePath)
+    // verify the partition data in logDir is gone, and data in logDir2 is still there
+    assertEquals(offset, cleanerManager.allCleanerCheckpoints(topicPartition2))
+    assertTrue(cleanerManager.allCleanerCheckpoints.get(topicPartition).isEmpty)
+  }
+
+  @Test
+  def testMaybeTruncateCheckpointShouldTruncateData(): Unit = {
+    val records = TestUtils.singletonRecords("test".getBytes, key="test".getBytes)
+    val log: UnifiedLog = createLog(records.sizeInBytes * 5, LogConfig.Compact)
+    val cleanerManager: LogCleanerManager = createCleanerManager(log)
+    val lowerOffset = 1L
+    val higherOffset = 1000L
+
+    // write some data into the cleaner-offset-checkpoint file in logDir
+    cleanerManager.updateCheckpoints(logDir, partitionToUpdateOrAdd = Option(topicPartition, offset))
+    assertEquals(offset, cleanerManager.allCleanerCheckpoints(topicPartition))
+
+    // we should not truncate the checkpoint data for checkpointed offset <= the given offset (higherOffset)
+    cleanerManager.maybeTruncateCheckpoint(logDir, topicPartition, higherOffset)
+    assertEquals(offset, cleanerManager.allCleanerCheckpoints(topicPartition))
+    // we should truncate the checkpoint data for checkpointed offset > the given offset (lowerOffset)
+    cleanerManager.maybeTruncateCheckpoint(logDir, topicPartition, lowerOffset)
+    assertEquals(lowerOffset, cleanerManager.allCleanerCheckpoints(topicPartition))
+  }
+
+  @Test
+  def testAlterCheckpointDirShouldRemoveDataInSrcDirAndAddInNewDir(): Unit = {
+    val records = TestUtils.singletonRecords("test".getBytes, key="test".getBytes)
+    val log: UnifiedLog = createLog(records.sizeInBytes * 5, LogConfig.Compact)
+    val cleanerManager: LogCleanerManager = createCleanerManager(log)
+
+    // write some data into the cleaner-offset-checkpoint file in logDir
+    cleanerManager.updateCheckpoints(logDir, partitionToUpdateOrAdd = Option(topicPartition, offset))
+    assertEquals(offset, cleanerManager.allCleanerCheckpoints(topicPartition))
+
+    cleanerManager.alterCheckpointDir(topicPartition, logDir, logDir2)
+    // verify we can still get the partition offset after alterCheckpointDir
+    // The data should now be located in logDir2, not logDir
+    assertEquals(offset, cleanerManager.allCleanerCheckpoints(topicPartition))
+
+    // force-delete logDir2 from the checkpoints, so that the partition data is deleted as well
+    cleanerManager.handleLogDirFailure(logDir2.getAbsolutePath)
+    assertTrue(cleanerManager.allCleanerCheckpoints.get(topicPartition).isEmpty)
+  }
+
+  /**
+    * log under cleanup should still be eligible for log truncation
+    */
+  @Test
+  def testConcurrentLogCleanupAndLogTruncation(): Unit = {
+    val records = TestUtils.singletonRecords("test".getBytes, key="test".getBytes)
+    val log: UnifiedLog = createLog(records.sizeInBytes * 5, LogConfig.Delete)
+    val cleanerManager: LogCleanerManager = createCleanerManager(log)
+
+    // log cleanup starts
+    val pausedPartitions = cleanerManager.pauseCleaningForNonCompactedPartitions()
+    // Log truncation happens due to unclean leader election
+    cleanerManager.abortAndPauseCleaning(log.topicPartition)
+    cleanerManager.resumeCleaning(Seq(log.topicPartition))
+    // log cleanup finishes and pausedPartitions are resumed
+    cleanerManager.resumeCleaning(pausedPartitions.map(_._1))
+
+    assertEquals(None, cleanerManager.cleaningState(log.topicPartition))
+  }
+
+  /**
+    * log under cleanup should still be eligible for topic deletion
+    */
+  @Test
+  def testConcurrentLogCleanupAndTopicDeletion(): Unit = {
+    val records = 
TestUtils.singletonRecords("test".getBytes, key = "test".getBytes) + val log: UnifiedLog = createLog(records.sizeInBytes * 5, LogConfig.Delete) + val cleanerManager: LogCleanerManager = createCleanerManager(log) + + // log cleanup starts + val pausedPartitions = cleanerManager.pauseCleaningForNonCompactedPartitions() + // Broker processes StopReplicaRequest with delete=true + cleanerManager.abortCleaning(log.topicPartition) + // log cleanup finishes and pausedPartitions are resumed + cleanerManager.resumeCleaning(pausedPartitions.map(_._1)) + + assertEquals(None, cleanerManager.cleaningState(log.topicPartition)) + } + + /** + * When looking for logs with segments ready to be deleted we shouldn't consider + * logs that have had their partition marked as uncleanable. + */ + @Test + def testLogsWithSegmentsToDeleteShouldNotConsiderUncleanablePartitions(): Unit = { + val records = TestUtils.singletonRecords("test".getBytes, key="test".getBytes) + val log: UnifiedLog = createLog(records.sizeInBytes * 5, LogConfig.Compact) + val cleanerManager: LogCleanerManager = createCleanerManager(log) + cleanerManager.markPartitionUncleanable(log.dir.getParent, topicPartition) + + val readyToDelete = cleanerManager.deletableLogs().size + assertEquals(0, readyToDelete, "should have 0 logs ready to be deleted") + } + + /** + * Test computation of cleanable range with no minimum compaction lag settings active where bounded by LSO + */ + @Test + def testCleanableOffsetsForNone(): Unit = { + val logProps = new Properties() + logProps.put(LogConfig.SegmentBytesProp, 1024: java.lang.Integer) + + val log = makeLog(config = LogConfig.fromProps(logConfig.originals, logProps)) + + while(log.numberOfSegments < 8) + log.appendAsLeader(records(log.logEndOffset.toInt, log.logEndOffset.toInt, time.milliseconds()), leaderEpoch = 0) + + log.updateHighWatermark(50) + + val lastCleanOffset = Some(0L) + val cleanableOffsets = LogCleanerManager.cleanableOffsets(log, lastCleanOffset, time.milliseconds) + assertEquals(0L, cleanableOffsets.firstDirtyOffset, "The first cleanable offset starts at the beginning of the log.") + assertEquals(log.highWatermark, log.lastStableOffset, "The high watermark equals the last stable offset as no transactions are in progress") + assertEquals(log.lastStableOffset, cleanableOffsets.firstUncleanableDirtyOffset, "The first uncleanable offset is bounded by the last stable offset.") + } + + /** + * Test computation of cleanable range with no minimum compaction lag settings active where bounded by active segment + */ + @Test + def testCleanableOffsetsActiveSegment(): Unit = { + val logProps = new Properties() + logProps.put(LogConfig.SegmentBytesProp, 1024: java.lang.Integer) + + val log = makeLog(config = LogConfig.fromProps(logConfig.originals, logProps)) + + while(log.numberOfSegments < 8) + log.appendAsLeader(records(log.logEndOffset.toInt, log.logEndOffset.toInt, time.milliseconds()), leaderEpoch = 0) + + log.updateHighWatermark(log.logEndOffset) + + val lastCleanOffset = Some(0L) + val cleanableOffsets = LogCleanerManager.cleanableOffsets(log, lastCleanOffset, time.milliseconds) + assertEquals(0L, cleanableOffsets.firstDirtyOffset, "The first cleanable offset starts at the beginning of the log.") + assertEquals(log.activeSegment.baseOffset, cleanableOffsets.firstUncleanableDirtyOffset, "The first uncleanable offset begins with the active segment.") + } + + /** + * Test computation of cleanable range with a minimum compaction lag time + */ + @Test + def testCleanableOffsetsForTime(): Unit = { + val 
compactionLag = 60 * 60 * 1000 + val logProps = new Properties() + logProps.put(LogConfig.SegmentBytesProp, 1024: java.lang.Integer) + logProps.put(LogConfig.MinCompactionLagMsProp, compactionLag: java.lang.Integer) + + val log = makeLog(config = LogConfig.fromProps(logConfig.originals, logProps)) + + val t0 = time.milliseconds + while(log.numberOfSegments < 4) + log.appendAsLeader(records(log.logEndOffset.toInt, log.logEndOffset.toInt, t0), leaderEpoch = 0) + + val activeSegAtT0 = log.activeSegment + + time.sleep(compactionLag + 1) + val t1 = time.milliseconds + + while (log.numberOfSegments < 8) + log.appendAsLeader(records(log.logEndOffset.toInt, log.logEndOffset.toInt, t1), leaderEpoch = 0) + + log.updateHighWatermark(log.logEndOffset) + + val lastCleanOffset = Some(0L) + val cleanableOffsets = LogCleanerManager.cleanableOffsets(log, lastCleanOffset, time.milliseconds) + assertEquals(0L, cleanableOffsets.firstDirtyOffset, "The first cleanable offset starts at the beginning of the log.") + assertEquals(activeSegAtT0.baseOffset, cleanableOffsets.firstUncleanableDirtyOffset, "The first uncleanable offset begins with the second block of log entries.") + } + + /** + * Test computation of cleanable range with a minimum compaction lag time that is small enough that + * the active segment contains it. + */ + @Test + def testCleanableOffsetsForShortTime(): Unit = { + val compactionLag = 60 * 60 * 1000 + val logProps = new Properties() + logProps.put(LogConfig.SegmentBytesProp, 1024: java.lang.Integer) + logProps.put(LogConfig.MinCompactionLagMsProp, compactionLag: java.lang.Integer) + + val log = makeLog(config = LogConfig.fromProps(logConfig.originals, logProps)) + + val t0 = time.milliseconds + while (log.numberOfSegments < 8) + log.appendAsLeader(records(log.logEndOffset.toInt, log.logEndOffset.toInt, t0), leaderEpoch = 0) + + log.updateHighWatermark(log.logEndOffset) + + time.sleep(compactionLag + 1) + + val lastCleanOffset = Some(0L) + val cleanableOffsets = LogCleanerManager.cleanableOffsets(log, lastCleanOffset, time.milliseconds) + assertEquals(0L, cleanableOffsets.firstDirtyOffset, "The first cleanable offset starts at the beginning of the log.") + assertEquals(log.activeSegment.baseOffset, cleanableOffsets.firstUncleanableDirtyOffset, "The first uncleanable offset begins with active segment.") + } + + @Test + def testCleanableOffsetsNeedsCheckpointReset(): Unit = { + val tp = new TopicPartition("foo", 0) + val logs = setupIncreasinglyFilthyLogs(Seq(tp), startNumBatches = 20, batchIncrement = 5) + logs.get(tp).maybeIncrementLogStartOffset(10L, ClientRecordDeletion) + + var lastCleanOffset = Some(15L) + var cleanableOffsets = LogCleanerManager.cleanableOffsets(logs.get(tp), lastCleanOffset, time.milliseconds) + assertFalse(cleanableOffsets.forceUpdateCheckpoint, "Checkpoint offset should not be reset if valid") + + logs.get(tp).maybeIncrementLogStartOffset(20L, ClientRecordDeletion) + cleanableOffsets = LogCleanerManager.cleanableOffsets(logs.get(tp), lastCleanOffset, time.milliseconds) + assertTrue(cleanableOffsets.forceUpdateCheckpoint, "Checkpoint offset needs to be reset if less than log start offset") + + lastCleanOffset = Some(25L) + cleanableOffsets = LogCleanerManager.cleanableOffsets(logs.get(tp), lastCleanOffset, time.milliseconds) + assertTrue(cleanableOffsets.forceUpdateCheckpoint, "Checkpoint offset needs to be reset if greater than log end offset") + } + + @Test + def testUndecidedTransactionalDataNotCleanable(): Unit = { + val compactionLag = 60 * 60 * 1000 + val logProps 
= new Properties() + logProps.put(LogConfig.SegmentBytesProp, 1024: java.lang.Integer) + logProps.put(LogConfig.MinCompactionLagMsProp, compactionLag: java.lang.Integer) + + val log = makeLog(config = LogConfig.fromProps(logConfig.originals, logProps)) + + val producerId = 15L + val producerEpoch = 0.toShort + val sequence = 0 + log.appendAsLeader(MemoryRecords.withTransactionalRecords(CompressionType.NONE, producerId, producerEpoch, sequence, + new SimpleRecord(time.milliseconds(), "1".getBytes, "a".getBytes), + new SimpleRecord(time.milliseconds(), "2".getBytes, "b".getBytes)), leaderEpoch = 0) + log.appendAsLeader(MemoryRecords.withTransactionalRecords(CompressionType.NONE, producerId, producerEpoch, sequence + 2, + new SimpleRecord(time.milliseconds(), "3".getBytes, "c".getBytes)), leaderEpoch = 0) + log.roll() + log.updateHighWatermark(3L) + + time.sleep(compactionLag + 1) + // although the compaction lag has been exceeded, the undecided data should not be cleaned + var cleanableOffsets = LogCleanerManager.cleanableOffsets(log, Some(0L), time.milliseconds()) + assertEquals(0L, cleanableOffsets.firstDirtyOffset) + assertEquals(0L, cleanableOffsets.firstUncleanableDirtyOffset) + + log.appendAsLeader(MemoryRecords.withEndTransactionMarker(time.milliseconds(), producerId, producerEpoch, + new EndTransactionMarker(ControlRecordType.ABORT, 15)), leaderEpoch = 0, + origin = AppendOrigin.Coordinator) + log.roll() + log.updateHighWatermark(4L) + + // the first segment should now become cleanable immediately + cleanableOffsets = LogCleanerManager.cleanableOffsets(log, Some(0L), time.milliseconds()) + assertEquals(0L, cleanableOffsets.firstDirtyOffset) + assertEquals(3L, cleanableOffsets.firstUncleanableDirtyOffset) + + time.sleep(compactionLag + 1) + + // the second segment becomes cleanable after the compaction lag + cleanableOffsets = LogCleanerManager.cleanableOffsets(log, Some(0L), time.milliseconds()) + assertEquals(0L, cleanableOffsets.firstDirtyOffset) + assertEquals(4L, cleanableOffsets.firstUncleanableDirtyOffset) + } + + @Test + def testDoneCleaning(): Unit = { + val logProps = new Properties() + logProps.put(LogConfig.SegmentBytesProp, 1024: java.lang.Integer) + val log = makeLog(config = LogConfig.fromProps(logConfig.originals, logProps)) + while(log.numberOfSegments < 8) + log.appendAsLeader(records(log.logEndOffset.toInt, log.logEndOffset.toInt, time.milliseconds()), leaderEpoch = 0) + + val cleanerManager: LogCleanerManager = createCleanerManager(log) + + assertThrows(classOf[IllegalStateException], () => cleanerManager.doneCleaning(topicPartition, log.dir, 1)) + + cleanerManager.setCleaningState(topicPartition, LogCleaningPaused(1)) + assertThrows(classOf[IllegalStateException], () => cleanerManager.doneCleaning(topicPartition, log.dir, 1)) + + cleanerManager.setCleaningState(topicPartition, LogCleaningInProgress) + cleanerManager.doneCleaning(topicPartition, log.dir, 1) + assertTrue(cleanerManager.cleaningState(topicPartition).isEmpty) + assertTrue(cleanerManager.allCleanerCheckpoints.get(topicPartition).nonEmpty) + + cleanerManager.setCleaningState(topicPartition, LogCleaningAborted) + cleanerManager.doneCleaning(topicPartition, log.dir, 1) + assertEquals(LogCleaningPaused(1), cleanerManager.cleaningState(topicPartition).get) + assertTrue(cleanerManager.allCleanerCheckpoints.get(topicPartition).nonEmpty) + } + + @Test + def testDoneDeleting(): Unit = { + val records = TestUtils.singletonRecords("test".getBytes, key="test".getBytes) + val log: UnifiedLog = 
createLog(records.sizeInBytes * 5, LogConfig.Compact + "," + LogConfig.Delete) + val cleanerManager: LogCleanerManager = createCleanerManager(log) + val tp = new TopicPartition("log", 0) + + assertThrows(classOf[IllegalStateException], () => cleanerManager.doneDeleting(Seq(tp))) + + cleanerManager.setCleaningState(tp, LogCleaningPaused(1)) + assertThrows(classOf[IllegalStateException], () => cleanerManager.doneDeleting(Seq(tp))) + + cleanerManager.setCleaningState(tp, LogCleaningInProgress) + cleanerManager.doneDeleting(Seq(tp)) + assertTrue(cleanerManager.cleaningState(tp).isEmpty) + + cleanerManager.setCleaningState(tp, LogCleaningAborted) + cleanerManager.doneDeleting(Seq(tp)) + assertEquals(LogCleaningPaused(1), cleanerManager.cleaningState(tp).get) + } + + /** + * Logs with invalid checkpoint offsets should update their checkpoint offset even if the log doesn't need cleaning + */ + @Test + def testCheckpointUpdatedForInvalidOffsetNoCleaning(): Unit = { + val tp = new TopicPartition("foo", 0) + val logs = setupIncreasinglyFilthyLogs(Seq(tp), startNumBatches = 20, batchIncrement = 5) + + logs.get(tp).maybeIncrementLogStartOffset(20L, ClientRecordDeletion) + val cleanerManager = createCleanerManagerMock(logs) + cleanerCheckpoints.put(tp, 15L) + + val filthiestLog = cleanerManager.grabFilthiestCompactedLog(time) + assertEquals(None, filthiestLog, "Log should not be selected for cleaning") + assertEquals(20L, cleanerCheckpoints.get(tp).get, "Unselected log should have checkpoint offset updated") + } + + /** + * Logs with invalid checkpoint offsets should update their checkpoint offset even if they aren't selected + * for immediate cleaning + */ + @Test + def testCheckpointUpdatedForInvalidOffsetNotSelected(): Unit = { + val tp0 = new TopicPartition("foo", 0) + val tp1 = new TopicPartition("foo", 1) + val partitions = Seq(tp0, tp1) + + // create two logs, one with an invalid offset, and one that is dirtier than the log with an invalid offset + val logs = setupIncreasinglyFilthyLogs(partitions, startNumBatches = 20, batchIncrement = 5) + logs.get(tp0).maybeIncrementLogStartOffset(15L, ClientRecordDeletion) + val cleanerManager = createCleanerManagerMock(logs) + cleanerCheckpoints.put(tp0, 10L) + cleanerCheckpoints.put(tp1, 5L) + + val filthiestLog = cleanerManager.grabFilthiestCompactedLog(time).get + assertEquals(tp1, filthiestLog.topicPartition, "Dirtier log should be selected") + assertEquals(15L, cleanerCheckpoints.get(tp0).get, "Unselected log should have checkpoint offset updated") + } + + private def createCleanerManager(log: UnifiedLog): LogCleanerManager = { + val logs = new Pool[TopicPartition, UnifiedLog]() + logs.put(topicPartition, log) + new LogCleanerManager(Seq(logDir, logDir2), logs, null) + } + + private def createCleanerManagerMock(pool: Pool[TopicPartition, UnifiedLog]): LogCleanerManagerMock = { + new LogCleanerManagerMock(Seq(logDir), pool, null) + } + + private def createLog(segmentSize: Int, + cleanupPolicy: String, + topicPartition: TopicPartition = new TopicPartition("log", 0)): UnifiedLog = { + val config = createLowRetentionLogConfig(segmentSize, cleanupPolicy) + val partitionDir = new File(logDir, UnifiedLog.logDirName(topicPartition)) + + UnifiedLog(partitionDir, + config, + logStartOffset = 0L, + recoveryPoint = 0L, + scheduler = time.scheduler, + time = time, + brokerTopicStats = new BrokerTopicStats, + maxProducerIdExpirationMs = 60 * 60 * 1000, + producerIdExpirationCheckIntervalMs = LogManager.ProducerIdExpirationCheckIntervalMs, + logDirFailureChannel = 
new LogDirFailureChannel(10), + topicId = None, + keepPartitionMetadataFile = true) + } + + private def createLowRetentionLogConfig(segmentSize: Int, cleanupPolicy: String): LogConfig = { + val logProps = new Properties() + logProps.put(LogConfig.SegmentBytesProp, segmentSize: Integer) + logProps.put(LogConfig.RetentionMsProp, 1: Integer) + logProps.put(LogConfig.CleanupPolicyProp, cleanupPolicy) + logProps.put(LogConfig.MinCleanableDirtyRatioProp, 0.05: java.lang.Double) // small for easier and clearer tests + + LogConfig(logProps) + } + + private def writeRecords(log: UnifiedLog, + numBatches: Int, + recordsPerBatch: Int, + batchesPerSegment: Int): Unit = { + for (i <- 0 until numBatches) { + appendRecords(log, recordsPerBatch) + if (i % batchesPerSegment == 0) + log.roll() + } + log.roll() + } + + private def appendRecords(log: UnifiedLog, numRecords: Int): Unit = { + val startOffset = log.logEndOffset + val endOffset = startOffset + numRecords + var lastTimestamp = 0L + val records = (startOffset until endOffset).map { offset => + val currentTimestamp = time.milliseconds() + if (offset == endOffset - 1) + lastTimestamp = currentTimestamp + new SimpleRecord(currentTimestamp, s"key-$offset".getBytes, s"value-$offset".getBytes) + } + + log.appendAsLeader(MemoryRecords.withRecords(CompressionType.NONE, records:_*), leaderEpoch = 1) + log.maybeIncrementHighWatermark(log.logEndOffsetMetadata) + } + + private def makeLog(dir: File = logDir, config: LogConfig) = + UnifiedLog(dir = dir, config = config, logStartOffset = 0L, recoveryPoint = 0L, scheduler = time.scheduler, + time = time, brokerTopicStats = new BrokerTopicStats, maxProducerIdExpirationMs = 60 * 60 * 1000, + producerIdExpirationCheckIntervalMs = LogManager.ProducerIdExpirationCheckIntervalMs, + logDirFailureChannel = new LogDirFailureChannel(10), topicId = None, keepPartitionMetadataFile = true) + + private def records(key: Int, value: Int, timestamp: Long) = + MemoryRecords.withRecords(CompressionType.NONE, new SimpleRecord(timestamp, key.toString.getBytes, value.toString.getBytes)) + +} diff --git a/core/src/test/scala/unit/kafka/log/LogCleanerParameterizedIntegrationTest.scala b/core/src/test/scala/unit/kafka/log/LogCleanerParameterizedIntegrationTest.scala new file mode 100755 index 0000000..70ac47e --- /dev/null +++ b/core/src/test/scala/unit/kafka/log/LogCleanerParameterizedIntegrationTest.scala @@ -0,0 +1,339 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package kafka.log + +import java.io.File +import java.util.Properties +import kafka.api.KAFKA_0_11_0_IV0 +import kafka.api.{KAFKA_0_10_0_IV1, KAFKA_0_9_0} +import kafka.server.KafkaConfig +import kafka.server.checkpoints.OffsetCheckpointFile +import kafka.utils._ +import org.apache.kafka.common.TopicPartition +import org.apache.kafka.common.record._ +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.extension.ExtensionContext +import org.junit.jupiter.params.ParameterizedTest +import org.junit.jupiter.params.provider.{Arguments, ArgumentsProvider, ArgumentsSource} + +import scala.annotation.nowarn +import scala.collection._ +import scala.jdk.CollectionConverters._ + +/** + * This is an integration test that tests the fully integrated log cleaner + */ +class LogCleanerParameterizedIntegrationTest extends AbstractLogCleanerIntegrationTest { + + val time = new MockTime() + + val topicPartitions = Array(new TopicPartition("log", 0), new TopicPartition("log", 1), new TopicPartition("log", 2)) + + @ParameterizedTest + @ArgumentsSource(classOf[LogCleanerParameterizedIntegrationTest.AllCompressions]) + def cleanerTest(codec: CompressionType): Unit = { + val largeMessageKey = 20 + val (largeMessageValue, largeMessageSet) = createLargeSingleMessageSet(largeMessageKey, RecordBatch.CURRENT_MAGIC_VALUE, codec) + val maxMessageSize = largeMessageSet.sizeInBytes + + cleaner = makeCleaner(partitions = topicPartitions, maxMessageSize = maxMessageSize) + val log = cleaner.logs.get(topicPartitions(0)) + + val appends = writeDups(numKeys = 100, numDups = 3, log = log, codec = codec) + val startSize = log.size + cleaner.startup() + + val firstDirty = log.activeSegment.baseOffset + checkLastCleaned("log", 0, firstDirty) + val compactedSize = log.logSegments.map(_.size).sum + assertTrue(startSize > compactedSize, s"log should have been compacted: startSize=$startSize compactedSize=$compactedSize") + + checkLogAfterAppendingDups(log, startSize, appends) + + val appendInfo = log.appendAsLeader(largeMessageSet, leaderEpoch = 0) + // move LSO forward to increase compaction bound + log.updateHighWatermark(log.logEndOffset) + val largeMessageOffset = appendInfo.firstOffset.get.messageOffset + + val dups = writeDups(startKey = largeMessageKey + 1, numKeys = 100, numDups = 3, log = log, codec = codec) + val appends2 = appends ++ Seq((largeMessageKey, largeMessageValue, largeMessageOffset)) ++ dups + val firstDirty2 = log.activeSegment.baseOffset + checkLastCleaned("log", 0, firstDirty2) + + checkLogAfterAppendingDups(log, startSize, appends2) + + // simulate deleting a partition, by removing it from logs + // force a checkpoint + // and make sure its gone from checkpoint file + cleaner.logs.remove(topicPartitions(0)) + cleaner.updateCheckpoints(logDir, partitionToRemove = Option(topicPartitions(0))) + val checkpoints = new OffsetCheckpointFile(new File(logDir, cleaner.cleanerManager.offsetCheckpointFile)).read() + // we expect partition 0 to be gone + assertFalse(checkpoints.contains(topicPartitions(0))) + } + + @ParameterizedTest + @ArgumentsSource(classOf[LogCleanerParameterizedIntegrationTest.AllCompressions]) + def testCleansCombinedCompactAndDeleteTopic(codec: CompressionType): Unit = { + val logProps = new Properties() + val retentionMs: Integer = 100000 + logProps.put(LogConfig.RetentionMsProp, retentionMs: Integer) + logProps.put(LogConfig.CleanupPolicyProp, "compact,delete") + + def runCleanerAndCheckCompacted(numKeys: Int): (UnifiedLog, Seq[(Int, String, Long)]) = { + cleaner = 
makeCleaner(partitions = topicPartitions.take(1), propertyOverrides = logProps, backOffMs = 100L)
+      val log = cleaner.logs.get(topicPartitions(0))
+
+      val messages = writeDups(numKeys = numKeys, numDups = 3, log = log, codec = codec)
+      val startSize = log.size
+
+      log.updateHighWatermark(log.logEndOffset)
+
+      val firstDirty = log.activeSegment.baseOffset
+      cleaner.startup()
+
+      // should compact the log
+      checkLastCleaned("log", 0, firstDirty)
+      val compactedSize = log.logSegments.map(_.size).sum
+      assertTrue(startSize > compactedSize, s"log should have been compacted: startSize=$startSize compactedSize=$compactedSize")
+      (log, messages)
+    }
+
+    val (log, _) = runCleanerAndCheckCompacted(100)
+
+    // Set the last modified time to an old value to force deletion of old segments
+    val endOffset = log.logEndOffset
+    log.logSegments.foreach(_.lastModified = time.milliseconds - (2 * retentionMs))
+    TestUtils.waitUntilTrue(() => log.logStartOffset == endOffset,
+      "Timed out waiting for deletion of old segments")
+    assertEquals(1, log.numberOfSegments)
+
+    cleaner.shutdown()
+
+    // run the cleaner again to make sure there are no issues after the deletion
+    val (log2, messages) = runCleanerAndCheckCompacted(20)
+    val read = readFromLog(log2)
+    assertEquals(toMap(messages), toMap(read), "Contents of the map shouldn't change")
+  }
+
+  @nowarn("cat=deprecation")
+  @ParameterizedTest
+  @ArgumentsSource(classOf[LogCleanerParameterizedIntegrationTest.ExcludeZstd])
+  def testCleanerWithMessageFormatV0(codec: CompressionType): Unit = {
+    val largeMessageKey = 20
+    val (largeMessageValue, largeMessageSet) = createLargeSingleMessageSet(largeMessageKey, RecordBatch.MAGIC_VALUE_V0, codec)
+    val maxMessageSize = codec match {
+      case CompressionType.NONE => largeMessageSet.sizeInBytes
+      case _ =>
+        // the broker assigns absolute offsets for message format 0 which potentially causes the compressed size to
+        // increase because the broker offsets are larger than the ones assigned by the client
+        // adding `5` to the message set size is good enough for this test: it covers the increased message size while
+        // still being less than the overhead introduced by the conversion from message format version 0 to 1
+        largeMessageSet.sizeInBytes + 5
+    }
+
+    cleaner = makeCleaner(partitions = topicPartitions, maxMessageSize = maxMessageSize)
+
+    val log = cleaner.logs.get(topicPartitions(0))
+    val props = logConfigProperties(maxMessageSize = maxMessageSize)
+    props.put(LogConfig.MessageFormatVersionProp, KAFKA_0_9_0.version)
+    log.updateConfig(new LogConfig(props))
+
+    val appends = writeDups(numKeys = 100, numDups = 3, log = log, codec = codec, magicValue = RecordBatch.MAGIC_VALUE_V0)
+    val startSize = log.size
+    cleaner.startup()
+
+    val firstDirty = log.activeSegment.baseOffset
+    checkLastCleaned("log", 0, firstDirty)
+    val compactedSize = log.logSegments.map(_.size).sum
+    assertTrue(startSize > compactedSize, s"log should have been compacted: startSize=$startSize compactedSize=$compactedSize")
+
+    checkLogAfterAppendingDups(log, startSize, appends)
+
+    val appends2: Seq[(Int, String, Long)] = {
+      val dupsV0 = writeDups(numKeys = 40, numDups = 3, log = log, codec = codec, magicValue = RecordBatch.MAGIC_VALUE_V0)
+      val appendInfo = log.appendAsLeader(largeMessageSet, leaderEpoch = 0)
+      // move LSO forward to increase compaction bound
+      log.updateHighWatermark(log.logEndOffset)
+      val largeMessageOffset = appendInfo.firstOffset.map(_.messageOffset).get
+
+      // also add some messages with version 1 and version 2 to check that we 
handle mixed format versions correctly + props.put(LogConfig.MessageFormatVersionProp, KAFKA_0_11_0_IV0.version) + log.updateConfig(new LogConfig(props)) + val dupsV1 = writeDups(startKey = 30, numKeys = 40, numDups = 3, log = log, codec = codec, magicValue = RecordBatch.MAGIC_VALUE_V1) + val dupsV2 = writeDups(startKey = 15, numKeys = 5, numDups = 3, log = log, codec = codec, magicValue = RecordBatch.MAGIC_VALUE_V2) + appends ++ dupsV0 ++ Seq((largeMessageKey, largeMessageValue, largeMessageOffset)) ++ dupsV1 ++ dupsV2 + } + val firstDirty2 = log.activeSegment.baseOffset + checkLastCleaned("log", 0, firstDirty2) + + checkLogAfterAppendingDups(log, startSize, appends2) + } + + @nowarn("cat=deprecation") + @ParameterizedTest + @ArgumentsSource(classOf[LogCleanerParameterizedIntegrationTest.ExcludeZstd]) + def testCleaningNestedMessagesWithV0AndV1(codec: CompressionType): Unit = { + val maxMessageSize = 192 + cleaner = makeCleaner(partitions = topicPartitions, maxMessageSize = maxMessageSize, segmentSize = 256) + + val log = cleaner.logs.get(topicPartitions(0)) + val props = logConfigProperties(maxMessageSize = maxMessageSize, segmentSize = 256) + props.put(LogConfig.MessageFormatVersionProp, KAFKA_0_9_0.version) + log.updateConfig(new LogConfig(props)) + + // with compression enabled, these messages will be written as a single message containing + // all of the individual messages + var appendsV0 = writeDupsSingleMessageSet(numKeys = 2, numDups = 3, log = log, codec = codec, magicValue = RecordBatch.MAGIC_VALUE_V0) + appendsV0 ++= writeDupsSingleMessageSet(numKeys = 2, startKey = 3, numDups = 2, log = log, codec = codec, magicValue = RecordBatch.MAGIC_VALUE_V0) + + props.put(LogConfig.MessageFormatVersionProp, KAFKA_0_10_0_IV1.version) + log.updateConfig(new LogConfig(props)) + + var appendsV1 = writeDupsSingleMessageSet(startKey = 4, numKeys = 2, numDups = 2, log = log, codec = codec, magicValue = RecordBatch.MAGIC_VALUE_V1) + appendsV1 ++= writeDupsSingleMessageSet(startKey = 4, numKeys = 2, numDups = 2, log = log, codec = codec, magicValue = RecordBatch.MAGIC_VALUE_V1) + appendsV1 ++= writeDupsSingleMessageSet(startKey = 6, numKeys = 2, numDups = 2, log = log, codec = codec, magicValue = RecordBatch.MAGIC_VALUE_V1) + + val appends = appendsV0 ++ appendsV1 + + val startSize = log.size + cleaner.startup() + + val firstDirty = log.activeSegment.baseOffset + assertTrue(firstDirty > appendsV0.size) // ensure we clean data from V0 and V1 + + checkLastCleaned("log", 0, firstDirty) + val compactedSize = log.logSegments.map(_.size).sum + assertTrue(startSize > compactedSize, s"log should have been compacted: startSize=$startSize compactedSize=$compactedSize") + + checkLogAfterAppendingDups(log, startSize, appends) + } + + @ParameterizedTest + @ArgumentsSource(classOf[LogCleanerParameterizedIntegrationTest.AllCompressions]) + def cleanerConfigUpdateTest(codec: CompressionType): Unit = { + val largeMessageKey = 20 + val (largeMessageValue, largeMessageSet) = createLargeSingleMessageSet(largeMessageKey, RecordBatch.CURRENT_MAGIC_VALUE, codec) + val maxMessageSize = largeMessageSet.sizeInBytes + + cleaner = makeCleaner(partitions = topicPartitions, backOffMs = 1, maxMessageSize = maxMessageSize, + cleanerIoBufferSize = Some(1)) + val log = cleaner.logs.get(topicPartitions(0)) + + writeDups(numKeys = 100, numDups = 3, log = log, codec = codec) + val startSize = log.size + cleaner.startup() + assertEquals(1, cleaner.cleanerCount) + + // Verify no cleaning with LogCleanerIoBufferSizeProp=1 + val 
firstDirty = log.activeSegment.baseOffset + val topicPartition = new TopicPartition("log", 0) + cleaner.awaitCleaned(topicPartition, firstDirty, maxWaitMs = 10) + assertTrue(cleaner.cleanerManager.allCleanerCheckpoints.isEmpty, "Should not have cleaned") + + def kafkaConfigWithCleanerConfig(cleanerConfig: CleanerConfig): KafkaConfig = { + val props = TestUtils.createBrokerConfig(0, "localhost:2181") + props.put(KafkaConfig.LogCleanerThreadsProp, cleanerConfig.numThreads.toString) + props.put(KafkaConfig.LogCleanerDedupeBufferSizeProp, cleanerConfig.dedupeBufferSize.toString) + props.put(KafkaConfig.LogCleanerDedupeBufferLoadFactorProp, cleanerConfig.dedupeBufferLoadFactor.toString) + props.put(KafkaConfig.LogCleanerIoBufferSizeProp, cleanerConfig.ioBufferSize.toString) + props.put(KafkaConfig.MessageMaxBytesProp, cleanerConfig.maxMessageSize.toString) + props.put(KafkaConfig.LogCleanerBackoffMsProp, cleanerConfig.backOffMs.toString) + props.put(KafkaConfig.LogCleanerIoMaxBytesPerSecondProp, cleanerConfig.maxIoBytesPerSecond.toString) + KafkaConfig.fromProps(props) + } + + // Verify cleaning done with larger LogCleanerIoBufferSizeProp + val oldConfig = kafkaConfigWithCleanerConfig(cleaner.currentConfig) + val newConfig = kafkaConfigWithCleanerConfig(CleanerConfig(numThreads = 2, + dedupeBufferSize = cleaner.currentConfig.dedupeBufferSize, + dedupeBufferLoadFactor = cleaner.currentConfig.dedupeBufferLoadFactor, + ioBufferSize = 100000, + maxMessageSize = cleaner.currentConfig.maxMessageSize, + maxIoBytesPerSecond = cleaner.currentConfig.maxIoBytesPerSecond, + backOffMs = cleaner.currentConfig.backOffMs)) + cleaner.reconfigure(oldConfig, newConfig) + + assertEquals(2, cleaner.cleanerCount) + checkLastCleaned("log", 0, firstDirty) + val compactedSize = log.logSegments.map(_.size).sum + assertTrue(startSize > compactedSize, s"log should have been compacted: startSize=$startSize compactedSize=$compactedSize") + } + + private def checkLastCleaned(topic: String, partitionId: Int, firstDirty: Long): Unit = { + // wait until cleaning up to base_offset, note that cleaning happens only when "log dirty ratio" is higher than + // LogConfig.MinCleanableDirtyRatioProp + val topicPartition = new TopicPartition(topic, partitionId) + cleaner.awaitCleaned(topicPartition, firstDirty) + val lastCleaned = cleaner.cleanerManager.allCleanerCheckpoints(topicPartition) + assertTrue(lastCleaned >= firstDirty, s"log cleaner should have processed up to offset $firstDirty, but lastCleaned=$lastCleaned") + } + + private def checkLogAfterAppendingDups(log: UnifiedLog, startSize: Long, appends: Seq[(Int, String, Long)]): Unit = { + val read = readFromLog(log) + assertEquals(toMap(appends), toMap(read), "Contents of the map shouldn't change") + assertTrue(startSize > log.size) + } + + private def toMap(messages: Iterable[(Int, String, Long)]): Map[Int, (String, Long)] = { + messages.map { case (key, value, offset) => key -> (value, offset) }.toMap + } + + private def readFromLog(log: UnifiedLog): Iterable[(Int, String, Long)] = { + for (segment <- log.logSegments; deepLogEntry <- segment.log.records.asScala) yield { + val key = TestUtils.readString(deepLogEntry.key).toInt + val value = TestUtils.readString(deepLogEntry.value) + (key, value, deepLogEntry.offset) + } + } + + private def writeDupsSingleMessageSet(numKeys: Int, numDups: Int, log: UnifiedLog, codec: CompressionType, + startKey: Int = 0, magicValue: Byte): Seq[(Int, String, Long)] = { + val kvs = for (_ <- 0 until numDups; key <- startKey until (startKey + 
numKeys)) yield {
+      val payload = counter.toString
+      incCounter()
+      (key, payload)
+    }
+
+    val records = kvs.map { case (key, payload) =>
+      new SimpleRecord(key.toString.getBytes, payload.toString.getBytes)
+    }
+
+    val appendInfo = log.appendAsLeader(MemoryRecords.withRecords(magicValue, codec, records: _*), leaderEpoch = 0)
+    // move LSO forward to increase compaction bound
+    log.updateHighWatermark(log.logEndOffset)
+    val offsets = appendInfo.firstOffset.get.messageOffset to appendInfo.lastOffset
+
+    kvs.zip(offsets).map { case (kv, offset) => (kv._1, kv._2, offset) }
+  }
+
+}
+
+object LogCleanerParameterizedIntegrationTest {
+
+  class AllCompressions extends ArgumentsProvider {
+    override def provideArguments(context: ExtensionContext): java.util.stream.Stream[_ <: Arguments] =
+      java.util.Arrays.stream(CompressionType.values.map(codec => Arguments.of(codec)))
+  }
+
+  // zstd compression is not supported with the older message formats exercised here (i.e. V0 and V1)
+  class ExcludeZstd extends ArgumentsProvider {
+    override def provideArguments(context: ExtensionContext): java.util.stream.Stream[_ <: Arguments] =
+      java.util.Arrays.stream(CompressionType.values.filter(_ != CompressionType.ZSTD).map(codec => Arguments.of(codec)))
+  }
+}
diff --git a/core/src/test/scala/unit/kafka/log/LogCleanerTest.scala b/core/src/test/scala/unit/kafka/log/LogCleanerTest.scala
new file mode 100755
index 0000000..8f1d241
--- /dev/null
+++ b/core/src/test/scala/unit/kafka/log/LogCleanerTest.scala
@@ -0,0 +1,1912 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License. 
+ */ + +package kafka.log + +import java.io.{File, RandomAccessFile} +import java.nio._ +import java.nio.charset.StandardCharsets +import java.nio.file.Paths +import java.util.Properties +import java.util.concurrent.{CountDownLatch, TimeUnit} + +import kafka.common._ +import kafka.server.{BrokerTopicStats, LogDirFailureChannel} +import kafka.utils._ +import org.apache.kafka.common.TopicPartition +import org.apache.kafka.common.errors.CorruptRecordException +import org.apache.kafka.common.record._ +import org.apache.kafka.common.utils.Utils +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.{AfterEach, Test} + +import scala.collection._ +import scala.jdk.CollectionConverters._ + +/** + * Unit tests for the log cleaning logic + */ +class LogCleanerTest { + + val tmpdir = TestUtils.tempDir() + val dir = TestUtils.randomPartitionLogDir(tmpdir) + val logProps = new Properties() + logProps.put(LogConfig.SegmentBytesProp, 1024: java.lang.Integer) + logProps.put(LogConfig.SegmentIndexBytesProp, 1024: java.lang.Integer) + logProps.put(LogConfig.CleanupPolicyProp, LogConfig.Compact) + logProps.put(LogConfig.MessageTimestampDifferenceMaxMsProp, Long.MaxValue.toString) + val logConfig = LogConfig(logProps) + val time = new MockTime() + val throttler = new Throttler(desiredRatePerSec = Double.MaxValue, checkIntervalMs = Long.MaxValue, time = time) + val tombstoneRetentionMs = 86400000 + val largeTimestamp = Long.MaxValue - tombstoneRetentionMs - 1 + + @AfterEach + def teardown(): Unit = { + Utils.delete(tmpdir) + } + + /** + * Test simple log cleaning + */ + @Test + def testCleanSegments(): Unit = { + val cleaner = makeCleaner(Int.MaxValue) + val logProps = new Properties() + logProps.put(LogConfig.SegmentBytesProp, 1024: java.lang.Integer) + + val log = makeLog(config = LogConfig.fromProps(logConfig.originals, logProps)) + + // append messages to the log until we have four segments + while(log.numberOfSegments < 4) + log.appendAsLeader(record(log.logEndOffset.toInt, log.logEndOffset.toInt), leaderEpoch = 0) + val keysFound = LogTestUtils.keysInLog(log) + assertEquals(0L until log.logEndOffset, keysFound) + + // pretend we have the following keys + val keys = immutable.ListSet(1L, 3L, 5L, 7L, 9L) + val map = new FakeOffsetMap(Int.MaxValue) + keys.foreach(k => map.put(key(k), Long.MaxValue)) + + // clean the log + val segments = log.logSegments.take(3).toSeq + val stats = new CleanerStats() + val expectedBytesRead = segments.map(_.size).sum + val shouldRemain = LogTestUtils.keysInLog(log).filter(!keys.contains(_)) + cleaner.cleanSegments(log, segments, map, 0L, stats, new CleanedTransactionMetadata, -1) + assertEquals(shouldRemain, LogTestUtils.keysInLog(log)) + assertEquals(expectedBytesRead, stats.bytesRead) + } + + @Test + def testCleanSegmentsWithConcurrentSegmentDeletion(): Unit = { + val deleteStartLatch = new CountDownLatch(1) + val deleteCompleteLatch = new CountDownLatch(1) + + // Construct a log instance. 
The replaceSegments() method of the log instance is overridden so that
+    // it waits for another thread to execute deleteOldSegments()
+    val logProps = new Properties()
+    logProps.put(LogConfig.SegmentBytesProp, 1024 : java.lang.Integer)
+    logProps.put(LogConfig.CleanupPolicyProp, LogConfig.Compact + "," + LogConfig.Delete)
+    val config = LogConfig.fromProps(logConfig.originals, logProps)
+    val topicPartition = UnifiedLog.parseTopicPartitionName(dir)
+    val logDirFailureChannel = new LogDirFailureChannel(10)
+    val maxProducerIdExpirationMs = 60 * 60 * 1000
+    val logSegments = new LogSegments(topicPartition)
+    val leaderEpochCache = UnifiedLog.maybeCreateLeaderEpochCache(dir, topicPartition, logDirFailureChannel, config.recordVersion, "")
+    val producerStateManager = new ProducerStateManager(topicPartition, dir, maxProducerIdExpirationMs, time)
+    val offsets = LogLoader.load(LoadLogParams(
+      dir,
+      topicPartition,
+      config,
+      time.scheduler,
+      time,
+      logDirFailureChannel,
+      hadCleanShutdown = true,
+      logSegments,
+      0L,
+      0L,
+      maxProducerIdExpirationMs,
+      leaderEpochCache,
+      producerStateManager))
+    val localLog = new LocalLog(dir, config, logSegments, offsets.recoveryPoint,
+      offsets.nextOffsetMetadata, time.scheduler, time, topicPartition, logDirFailureChannel)
+    val log = new UnifiedLog(offsets.logStartOffset,
+                             localLog,
+                             brokerTopicStats = new BrokerTopicStats,
+                             producerIdExpirationCheckIntervalMs = LogManager.ProducerIdExpirationCheckIntervalMs,
+                             leaderEpochCache = leaderEpochCache,
+                             producerStateManager = producerStateManager,
+                             _topicId = None,
+                             keepPartitionMetadataFile = true) {
+      override def replaceSegments(newSegments: Seq[LogSegment], oldSegments: Seq[LogSegment]): Unit = {
+        deleteStartLatch.countDown()
+        if (!deleteCompleteLatch.await(5000, TimeUnit.MILLISECONDS)) {
+          throw new IllegalStateException("Log segment deletion timed out")
+        }
+        super.replaceSegments(newSegments, oldSegments)
+      }
+    }
+
+    // Start a thread which executes log.deleteOldSegments() right before replaceSegments() is executed
+    val t = new Thread() {
+      override def run(): Unit = {
+        deleteStartLatch.await(5000, TimeUnit.MILLISECONDS)
+        log.updateHighWatermark(log.activeSegment.baseOffset)
+        log.maybeIncrementLogStartOffset(log.activeSegment.baseOffset, LeaderOffsetIncremented)
+        log.updateHighWatermark(log.activeSegment.baseOffset)
+        log.deleteOldSegments()
+        deleteCompleteLatch.countDown()
+      }
+    }
+    t.start()
+
+    // Append records so that the number of segments increases to 3
+    while (log.numberOfSegments < 3) {
+      log.appendAsLeader(record(key = 0, log.logEndOffset.toInt), leaderEpoch = 0)
+      log.roll()
+    }
+    assertEquals(3, log.numberOfSegments)
+
+    // Remember a reference to the first segment's log file and determine the file name expected for async deletion
+    val firstLogFile = log.logSegments.head.log
+    val expectedFileName = CoreUtils.replaceSuffix(firstLogFile.file.getPath, "", UnifiedLog.DeletedFileSuffix)
+
+    // Clean the log. 
This should trigger replaceSegments() and deleteOldSegments(); + val offsetMap = new FakeOffsetMap(Int.MaxValue) + val cleaner = makeCleaner(Int.MaxValue) + val segments = log.logSegments(0, log.activeSegment.baseOffset).toSeq + val stats = new CleanerStats() + cleaner.buildOffsetMap(log, 0, log.activeSegment.baseOffset, offsetMap, stats) + cleaner.cleanSegments(log, segments, offsetMap, 0L, stats, new CleanedTransactionMetadata, -1) + + // Validate based on the file name that log segment file is renamed exactly once for async deletion + assertEquals(expectedFileName, firstLogFile.file().getPath) + assertEquals(2, log.numberOfSegments) + } + + @Test + def testSizeTrimmedForPreallocatedAndCompactedTopic(): Unit = { + val originalMaxFileSize = 1024; + val cleaner = makeCleaner(2) + val logProps = new Properties() + logProps.put(LogConfig.SegmentBytesProp, originalMaxFileSize: java.lang.Integer) + logProps.put(LogConfig.CleanupPolicyProp, "compact": java.lang.String) + logProps.put(LogConfig.PreAllocateEnableProp, "true": java.lang.String) + val log = makeLog(config = LogConfig.fromProps(logConfig.originals, logProps)) + + log.appendAsLeader(record(0,0), leaderEpoch = 0) // offset 0 + log.appendAsLeader(record(1,1), leaderEpoch = 0) // offset 1 + log.appendAsLeader(record(0,0), leaderEpoch = 0) // offset 2 + log.appendAsLeader(record(1,1), leaderEpoch = 0) // offset 3 + log.appendAsLeader(record(0,0), leaderEpoch = 0) // offset 4 + // roll the segment, so we can clean the messages already appended + log.roll() + + // clean the log with only one message removed + cleaner.clean(LogToClean(new TopicPartition("test", 0), log, 2, log.activeSegment.baseOffset)) + + assertTrue(log.logSegments.iterator.next().log.channel.size < originalMaxFileSize, + "Cleaned segment file should be trimmed to its real size.") + } + + @Test + def testDuplicateCheckAfterCleaning(): Unit = { + val cleaner = makeCleaner(Int.MaxValue) + val logProps = new Properties() + logProps.put(LogConfig.SegmentBytesProp, 2048: java.lang.Integer) + var log = makeLog(config = LogConfig.fromProps(logConfig.originals, logProps)) + + val producerEpoch = 0.toShort + val pid1 = 1 + val pid2 = 2 + val pid3 = 3 + val pid4 = 4 + + appendIdempotentAsLeader(log, pid1, producerEpoch)(Seq(1, 2, 3)) + appendIdempotentAsLeader(log, pid2, producerEpoch)(Seq(3, 1, 4)) + appendIdempotentAsLeader(log, pid3, producerEpoch)(Seq(1, 4)) + + log.roll() + cleaner.clean(LogToClean(new TopicPartition("test", 0), log, 0L, log.activeSegment.baseOffset)) + assertEquals(List(2, 5, 7), lastOffsetsPerBatchInLog(log)) + assertEquals(Map(pid1 -> 2, pid2 -> 2, pid3 -> 1), lastSequencesInLog(log)) + assertEquals(List(2, 3, 1, 4), LogTestUtils.keysInLog(log)) + assertEquals(List(1, 3, 6, 7), offsetsInLog(log)) + + // we have to reload the log to validate that the cleaner maintained sequence numbers correctly + def reloadLog(): Unit = { + log.close() + log = makeLog(config = LogConfig.fromProps(logConfig.originals, logProps), recoveryPoint = 0L) + } + + reloadLog() + + // check duplicate append from producer 1 + var logAppendInfo = appendIdempotentAsLeader(log, pid1, producerEpoch)(Seq(1, 2, 3)) + assertEquals(0L, logAppendInfo.firstOffset.get.messageOffset) + assertEquals(2L, logAppendInfo.lastOffset) + + // check duplicate append from producer 3 + logAppendInfo = appendIdempotentAsLeader(log, pid3, producerEpoch)(Seq(1, 4)) + assertEquals(6L, logAppendInfo.firstOffset.get.messageOffset) + assertEquals(7L, logAppendInfo.lastOffset) + + // check duplicate append from 
producer 2 + logAppendInfo = appendIdempotentAsLeader(log, pid2, producerEpoch)(Seq(3, 1, 4)) + assertEquals(3L, logAppendInfo.firstOffset.get.messageOffset) + assertEquals(5L, logAppendInfo.lastOffset) + + // do one more append and a round of cleaning to force another deletion from producer 1's batch + appendIdempotentAsLeader(log, pid4, producerEpoch)(Seq(2)) + log.roll() + cleaner.clean(LogToClean(new TopicPartition("test", 0), log, 0L, log.activeSegment.baseOffset)) + assertEquals(Map(pid1 -> 2, pid2 -> 2, pid3 -> 1, pid4 -> 0), lastSequencesInLog(log)) + assertEquals(List(2, 5, 7, 8), lastOffsetsPerBatchInLog(log)) + assertEquals(List(3, 1, 4, 2), LogTestUtils.keysInLog(log)) + assertEquals(List(3, 6, 7, 8), offsetsInLog(log)) + + reloadLog() + + // duplicate append from producer1 should still be fine + logAppendInfo = appendIdempotentAsLeader(log, pid1, producerEpoch)(Seq(1, 2, 3)) + assertEquals(0L, logAppendInfo.firstOffset.get.messageOffset) + assertEquals(2L, logAppendInfo.lastOffset) + } + + @Test + def testBasicTransactionAwareCleaning(): Unit = { + val cleaner = makeCleaner(Int.MaxValue) + val logProps = new Properties() + logProps.put(LogConfig.SegmentBytesProp, 2048: java.lang.Integer) + val log = makeLog(config = LogConfig.fromProps(logConfig.originals, logProps)) + + val producerEpoch = 0.toShort + val pid1 = 1 + val pid2 = 2 + + val appendProducer1 = appendTransactionalAsLeader(log, pid1, producerEpoch) + val appendProducer2 = appendTransactionalAsLeader(log, pid2, producerEpoch) + + appendProducer1(Seq(1, 2)) + appendProducer2(Seq(2, 3)) + appendProducer1(Seq(3, 4)) + log.appendAsLeader(abortMarker(pid1, producerEpoch), leaderEpoch = 0, origin = AppendOrigin.Coordinator) + log.appendAsLeader(commitMarker(pid2, producerEpoch), leaderEpoch = 0, origin = AppendOrigin.Coordinator) + appendProducer1(Seq(2)) + log.appendAsLeader(commitMarker(pid1, producerEpoch), leaderEpoch = 0, origin = AppendOrigin.Coordinator) + + val abortedTransactions = log.collectAbortedTransactions(log.logStartOffset, log.logEndOffset) + + log.roll() + cleaner.clean(LogToClean(new TopicPartition("test", 0), log, 0L, log.activeSegment.baseOffset)) + assertEquals(List(3, 2), LogTestUtils.keysInLog(log)) + assertEquals(List(3, 6, 7, 8, 9), offsetsInLog(log)) + + // ensure the transaction index is still correct + assertEquals(abortedTransactions, log.collectAbortedTransactions(log.logStartOffset, log.logEndOffset)) + } + + @Test + def testCleanWithTransactionsSpanningSegments(): Unit = { + val cleaner = makeCleaner(Int.MaxValue) + val logProps = new Properties() + logProps.put(LogConfig.SegmentBytesProp, 1024: java.lang.Integer) + val log = makeLog(config = LogConfig.fromProps(logConfig.originals, logProps)) + + val producerEpoch = 0.toShort + val pid1 = 1 + val pid2 = 2 + val pid3 = 3 + + val appendProducer1 = appendTransactionalAsLeader(log, pid1, producerEpoch) + val appendProducer2 = appendTransactionalAsLeader(log, pid2, producerEpoch) + val appendProducer3 = appendTransactionalAsLeader(log, pid3, producerEpoch) + + appendProducer1(Seq(1, 2)) + appendProducer3(Seq(2, 3)) + appendProducer2(Seq(3, 4)) + + log.roll() + + appendProducer2(Seq(5, 6)) + appendProducer3(Seq(6, 7)) + appendProducer1(Seq(7, 8)) + log.appendAsLeader(abortMarker(pid2, producerEpoch), leaderEpoch = 0, origin = AppendOrigin.Coordinator) + appendProducer3(Seq(8, 9)) + log.appendAsLeader(commitMarker(pid3, producerEpoch), leaderEpoch = 0, origin = AppendOrigin.Coordinator) + appendProducer1(Seq(9, 10)) + 
log.appendAsLeader(abortMarker(pid1, producerEpoch), leaderEpoch = 0, origin = AppendOrigin.Coordinator)
+
+    // we have only cleaned the records in the first segment
+    val dirtyOffset = cleaner.clean(LogToClean(new TopicPartition("test", 0), log, 0L, log.activeSegment.baseOffset))._1
+    assertEquals(List(2, 3, 5, 6, 6, 7, 7, 8, 8, 9, 9, 10), LogTestUtils.keysInLog(log))
+
+    log.roll()
+
+    // append a couple of extra records in the new segment to ensure we have sequence numbers
+    appendProducer2(Seq(11))
+    appendProducer1(Seq(12))
+
+    // finally only the keys from pid3 should remain
+    cleaner.clean(LogToClean(new TopicPartition("test", 0), log, dirtyOffset, log.activeSegment.baseOffset))
+    assertEquals(List(2, 3, 6, 7, 8, 9, 11, 12), LogTestUtils.keysInLog(log))
+  }
+
+  @Test
+  def testCommitMarkerRemoval(): Unit = {
+    val tp = new TopicPartition("test", 0)
+    val cleaner = makeCleaner(Int.MaxValue)
+    val logProps = new Properties()
+    logProps.put(LogConfig.SegmentBytesProp, 256: java.lang.Integer)
+    val log = makeLog(config = LogConfig.fromProps(logConfig.originals, logProps))
+
+    val producerEpoch = 0.toShort
+    val producerId = 1L
+    val appendProducer = appendTransactionalAsLeader(log, producerId, producerEpoch)
+
+    appendProducer(Seq(1))
+    appendProducer(Seq(2, 3))
+    log.appendAsLeader(commitMarker(producerId, producerEpoch), leaderEpoch = 0, origin = AppendOrigin.Coordinator)
+    appendProducer(Seq(2))
+    log.appendAsLeader(commitMarker(producerId, producerEpoch), leaderEpoch = 0, origin = AppendOrigin.Coordinator)
+    log.roll()
+
+    // cannot remove the marker in this pass because there are still valid records
+    var dirtyOffset = cleaner.doClean(LogToClean(tp, log, 0L, log.activeSegment.baseOffset), currentTime = largeTimestamp)._1
+    assertEquals(List(1, 3, 2), LogTestUtils.keysInLog(log))
+    assertEquals(List(0, 2, 3, 4, 5), offsetsInLog(log))
+
+    appendProducer(Seq(1, 3))
+    log.appendAsLeader(commitMarker(producerId, producerEpoch), leaderEpoch = 0, origin = AppendOrigin.Coordinator)
+    log.roll()
+
+    // the first cleaning preserves the commit marker (at offset 3) since there were still records for the transaction
+    dirtyOffset = cleaner.doClean(LogToClean(tp, log, dirtyOffset, log.activeSegment.baseOffset), currentTime = largeTimestamp)._1
+    assertEquals(List(2, 1, 3), LogTestUtils.keysInLog(log))
+    assertEquals(List(3, 4, 5, 6, 7, 8), offsetsInLog(log))
+
+    // clean again with same timestamp to verify marker is not removed early
+    dirtyOffset = cleaner.doClean(LogToClean(tp, log, dirtyOffset, log.activeSegment.baseOffset), currentTime = largeTimestamp)._1
+    assertEquals(List(2, 1, 3), LogTestUtils.keysInLog(log))
+    assertEquals(List(3, 4, 5, 6, 7, 8), offsetsInLog(log))
+
+    // clean again with max timestamp to verify the marker is removed
+    dirtyOffset = cleaner.doClean(LogToClean(tp, log, dirtyOffset, log.activeSegment.baseOffset), currentTime = Long.MaxValue)._1
+    assertEquals(List(2, 1, 3), LogTestUtils.keysInLog(log))
+    assertEquals(List(4, 5, 6, 7, 8), offsetsInLog(log))
+  }
+
+  /**
+   * Tests log cleaning with batches that are deleted where no additional messages
+   * are available to read in the buffer. Cleaning should continue from the next offset. 
+ */ + @Test + def testDeletedBatchesWithNoMessagesRead(): Unit = { + val tp = new TopicPartition("test", 0) + val cleaner = makeCleaner(capacity = Int.MaxValue, maxMessageSize = 100) + val logProps = new Properties() + logProps.put(LogConfig.MaxMessageBytesProp, 100: java.lang.Integer) + logProps.put(LogConfig.SegmentBytesProp, 1000: java.lang.Integer) + val log = makeLog(config = LogConfig.fromProps(logConfig.originals, logProps)) + + val producerEpoch = 0.toShort + val producerId = 1L + val appendProducer = appendTransactionalAsLeader(log, producerId, producerEpoch) + + appendProducer(Seq(1)) + log.appendAsLeader(abortMarker(producerId, producerEpoch), leaderEpoch = 0, origin = AppendOrigin.Coordinator) + appendProducer(Seq(2)) + appendProducer(Seq(2)) + log.appendAsLeader(commitMarker(producerId, producerEpoch), leaderEpoch = 0, origin = AppendOrigin.Coordinator) + log.roll() + + cleaner.doClean(LogToClean(tp, log, 0L, log.activeSegment.baseOffset), currentTime = largeTimestamp) + assertEquals(List(2), LogTestUtils.keysInLog(log)) + assertEquals(List(1, 3, 4), offsetsInLog(log)) + + // In the first pass, the deleteHorizon for {Producer2: Commit} is set. In the second pass, it's removed. + runTwoPassClean(cleaner, LogToClean(tp, log, 0L, log.activeSegment.baseOffset), currentTime = largeTimestamp) + assertEquals(List(2), LogTestUtils.keysInLog(log)) + assertEquals(List(3, 4), offsetsInLog(log)) + } + + @Test + def testCommitMarkerRetentionWithEmptyBatch(): Unit = { + val tp = new TopicPartition("test", 0) + val cleaner = makeCleaner(Int.MaxValue) + val logProps = new Properties() + logProps.put(LogConfig.SegmentBytesProp, 256: java.lang.Integer) + val log = makeLog(config = LogConfig.fromProps(logConfig.originals, logProps)) + + val producerEpoch = 0.toShort + val producer1 = appendTransactionalAsLeader(log, 1L, producerEpoch) + val producer2 = appendTransactionalAsLeader(log, 2L, producerEpoch) + + // [{Producer1: 2, 3}] + producer1(Seq(2, 3)) // offsets 0, 1 + log.roll() + + // [{Producer1: 2, 3}], [{Producer2: 2, 3}, {Producer2: Commit}] + producer2(Seq(2, 3)) // offsets 2, 3 + log.appendAsLeader(commitMarker(2L, producerEpoch), leaderEpoch = 0, + origin = AppendOrigin.Coordinator) // offset 4 + log.roll() + + // [{Producer1: 2, 3}], [{Producer2: 2, 3}, {Producer2: Commit}], [{2}, {3}, {Producer1: Commit}] + // {0, 1}, {2, 3}, {4}, {5}, {6}, {7} ==> Offsets + log.appendAsLeader(record(2, 2), leaderEpoch = 0) // offset 5 + log.appendAsLeader(record(3, 3), leaderEpoch = 0) // offset 6 + log.appendAsLeader(commitMarker(1L, producerEpoch), leaderEpoch = 0, + origin = AppendOrigin.Coordinator) // offset 7 + log.roll() + + // first time through the records are removed + // Expected State: [{Producer1: EmptyBatch}, {Producer2: EmptyBatch}, {Producer2: Commit}, {2}, {3}, {Producer1: Commit}] + var dirtyOffset = cleaner.doClean(LogToClean(tp, log, 0L, log.activeSegment.baseOffset), currentTime = largeTimestamp)._1 + assertEquals(List(2, 3), LogTestUtils.keysInLog(log)) + assertEquals(List(4, 5, 6, 7), offsetsInLog(log)) + assertEquals(List(1, 3, 4, 5, 6, 7), lastOffsetsPerBatchInLog(log)) + + // the empty batch remains if cleaned again because it still holds the last sequence + // Expected State: [{Producer1: EmptyBatch}, {Producer2: EmptyBatch}, {Producer2: Commit}, {2}, {3}, {Producer1: Commit}] + dirtyOffset = cleaner.doClean(LogToClean(tp, log, dirtyOffset, log.activeSegment.baseOffset), currentTime = largeTimestamp)._1 + assertEquals(List(2, 3), LogTestUtils.keysInLog(log)) + 
assertEquals(List(4, 5, 6, 7), offsetsInLog(log)) + assertEquals(List(1, 3, 4, 5, 6, 7), lastOffsetsPerBatchInLog(log)) + + // append a new record from the producer to allow cleaning of the empty batch + // [{Producer1: EmptyBatch}, {Producer2: EmptyBatch}, {Producer2: Commit}, {2}, {3}, {Producer1: Commit}, {Producer2: 1}, {Producer2: Commit}] + // {1}, {3}, {4}, {5}, {6}, {7}, {8}, {9} ==> Offsets + producer2(Seq(1)) // offset 8 + log.appendAsLeader(commitMarker(2L, producerEpoch), leaderEpoch = 0, + origin = AppendOrigin.Coordinator) // offset 9 + log.roll() + + // Expected State: [{Producer1: EmptyBatch}, {Producer2: Commit}, {2}, {3}, {Producer1: Commit}, {Producer2: 1}, {Producer2: Commit}] + // The deleteHorizon for {Producer2: Commit} is still not set yet. + dirtyOffset = cleaner.doClean(LogToClean(tp, log, dirtyOffset, log.activeSegment.baseOffset), currentTime = largeTimestamp)._1 + assertEquals(List(2, 3, 1), LogTestUtils.keysInLog(log)) + assertEquals(List(4, 5, 6, 7, 8, 9), offsetsInLog(log)) + assertEquals(List(1, 4, 5, 6, 7, 8, 9), lastOffsetsPerBatchInLog(log)) + + // Expected State: [{Producer1: EmptyBatch}, {2}, {3}, {Producer1: Commit}, {Producer2: 1}, {Producer2: Commit}] + // In the first pass, the deleteHorizon for {Producer2: Commit} is set. In the second pass, it's removed. + dirtyOffset = runTwoPassClean(cleaner, LogToClean(tp, log, dirtyOffset, log.activeSegment.baseOffset), currentTime = largeTimestamp) + assertEquals(List(2, 3, 1), LogTestUtils.keysInLog(log)) + assertEquals(List(5, 6, 7, 8, 9), offsetsInLog(log)) + assertEquals(List(1, 5, 6, 7, 8, 9), lastOffsetsPerBatchInLog(log)) + } + + @Test + def testCleanEmptyControlBatch(): Unit = { + val tp = new TopicPartition("test", 0) + val cleaner = makeCleaner(Int.MaxValue) + val logProps = new Properties() + logProps.put(LogConfig.SegmentBytesProp, 256: java.lang.Integer) + val log = makeLog(config = LogConfig.fromProps(logConfig.originals, logProps)) + + val producerEpoch = 0.toShort + + // [{Producer1: Commit}, {2}, {3}] + log.appendAsLeader(commitMarker(1L, producerEpoch), leaderEpoch = 0, + origin = AppendOrigin.Coordinator) // offset 1 + log.appendAsLeader(record(2, 2), leaderEpoch = 0) // offset 2 + log.appendAsLeader(record(3, 3), leaderEpoch = 0) // offset 3 + log.roll() + + // first time through the control batch is retained as an empty batch + // Expected State: [{Producer1: EmptyBatch}], [{2}, {3}] + // In the first pass, the deleteHorizon for the commit marker is set. In the second pass, the commit marker is removed + // but the empty batch is retained for preserving the producer epoch. 
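+ // Note: runTwoPassClean (see its scaladoc at the end of this file) runs doClean twice, advancing the mock clock between the passes so that the delete horizon set in the first pass has already elapsed by the second.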
+ var dirtyOffset = runTwoPassClean(cleaner, LogToClean(tp, log, 0L, log.activeSegment.baseOffset), currentTime = largeTimestamp) + assertEquals(List(2, 3), LogTestUtils.keysInLog(log)) + assertEquals(List(1, 2), offsetsInLog(log)) + assertEquals(List(0, 1, 2), lastOffsetsPerBatchInLog(log)) + + // the empty control batch does not cause an exception when cleaned + // Expected State: [{Producer1: EmptyBatch}], [{2}, {3}] + dirtyOffset = cleaner.doClean(LogToClean(tp, log, dirtyOffset, log.activeSegment.baseOffset), currentTime = Long.MaxValue)._1 + assertEquals(List(2, 3), LogTestUtils.keysInLog(log)) + assertEquals(List(1, 2), offsetsInLog(log)) + assertEquals(List(0, 1, 2), lastOffsetsPerBatchInLog(log)) + } + + @Test + def testCommittedTransactionSpanningSegments(): Unit = { + val tp = new TopicPartition("test", 0) + val cleaner = makeCleaner(Int.MaxValue) + val logProps = new Properties() + logProps.put(LogConfig.SegmentBytesProp, 128: java.lang.Integer) + val log = makeLog(config = LogConfig.fromProps(logConfig.originals, logProps)) + val producerEpoch = 0.toShort + val producerId = 1L + + val appendTransaction = appendTransactionalAsLeader(log, producerId, producerEpoch) + appendTransaction(Seq(1)) + log.roll() + + log.appendAsLeader(commitMarker(producerId, producerEpoch), leaderEpoch = 0, origin = AppendOrigin.Coordinator) + log.roll() + + // Both the record and the marker should remain after cleaning + runTwoPassClean(cleaner, LogToClean(tp, log, 0L, log.activeSegment.baseOffset), currentTime = largeTimestamp) + assertEquals(List(0, 1), offsetsInLog(log)) + assertEquals(List(0, 1), lastOffsetsPerBatchInLog(log)) + } + + @Test + def testAbortedTransactionSpanningSegments(): Unit = { + val tp = new TopicPartition("test", 0) + val cleaner = makeCleaner(Int.MaxValue) + val logProps = new Properties() + logProps.put(LogConfig.SegmentBytesProp, 128: java.lang.Integer) + val log = makeLog(config = LogConfig.fromProps(logConfig.originals, logProps)) + val producerEpoch = 0.toShort + val producerId = 1L + + val appendTransaction = appendTransactionalAsLeader(log, producerId, producerEpoch) + appendTransaction(Seq(1)) + log.roll() + + log.appendAsLeader(abortMarker(producerId, producerEpoch), leaderEpoch = 0, origin = AppendOrigin.Coordinator) + log.roll() + + // Both the batch and the marker should remain after cleaning. The batch is retained + // because it is the last entry for this producerId. The marker is retained because + // there are still batches remaining from this transaction. + cleaner.doClean(LogToClean(tp, log, 0L, log.activeSegment.baseOffset), currentTime = largeTimestamp) + assertEquals(List(1), offsetsInLog(log)) + assertEquals(List(0, 1), lastOffsetsPerBatchInLog(log)) + + // The empty batch and the marker is still retained after a second cleaning. 
+ cleaner.doClean(LogToClean(tp, log, 0L, log.activeSegment.baseOffset), currentTime = Long.MaxValue) + assertEquals(List(1), offsetsInLog(log)) + assertEquals(List(0, 1), lastOffsetsPerBatchInLog(log)) + } + + @Test + def testAbortMarkerRemoval(): Unit = { + val tp = new TopicPartition("test", 0) + val cleaner = makeCleaner(Int.MaxValue) + val logProps = new Properties() + logProps.put(LogConfig.SegmentBytesProp, 256: java.lang.Integer) + val log = makeLog(config = LogConfig.fromProps(logConfig.originals, logProps)) + + val producerEpoch = 0.toShort + val producerId = 1L + val appendProducer = appendTransactionalAsLeader(log, producerId, producerEpoch) + + appendProducer(Seq(1)) + appendProducer(Seq(2, 3)) + log.appendAsLeader(abortMarker(producerId, producerEpoch), leaderEpoch = 0, origin = AppendOrigin.Coordinator) + appendProducer(Seq(3)) + log.appendAsLeader(commitMarker(producerId, producerEpoch), leaderEpoch = 0, origin = AppendOrigin.Coordinator) + log.roll() + + // Aborted records are removed, but the abort marker is still preserved. + val dirtyOffset = cleaner.doClean(LogToClean(tp, log, 0L, log.activeSegment.baseOffset), currentTime = largeTimestamp)._1 + assertEquals(List(3), LogTestUtils.keysInLog(log)) + assertEquals(List(3, 4, 5), offsetsInLog(log)) + + // In the first pass, the delete horizon for the abort marker is set. In the second pass, the abort marker is removed. + runTwoPassClean(cleaner, LogToClean(tp, log, dirtyOffset, log.activeSegment.baseOffset), currentTime = largeTimestamp) + assertEquals(List(3), LogTestUtils.keysInLog(log)) + assertEquals(List(4, 5), offsetsInLog(log)) + } + + @Test + def testEmptyBatchRemovalWithSequenceReuse(): Unit = { + // The group coordinator always writes batches beginning with sequence number 0. This test + // ensures that we still remove old empty batches and transaction markers under this expectation. + + val producerEpoch = 0.toShort + val producerId = 1L + val tp = new TopicPartition("test", 0) + val cleaner = makeCleaner(Int.MaxValue) + val logProps = new Properties() + logProps.put(LogConfig.SegmentBytesProp, 2048: java.lang.Integer) + val log = makeLog(config = LogConfig.fromProps(logConfig.originals, logProps)) + + val appendFirstTransaction = appendTransactionalAsLeader(log, producerId, producerEpoch, + origin = AppendOrigin.Replication) + appendFirstTransaction(Seq(1)) + log.appendAsLeader(commitMarker(producerId, producerEpoch), leaderEpoch = 0, origin = AppendOrigin.Coordinator) + + val appendSecondTransaction = appendTransactionalAsLeader(log, producerId, producerEpoch, + origin = AppendOrigin.Replication) + appendSecondTransaction(Seq(2)) + log.appendAsLeader(commitMarker(producerId, producerEpoch), leaderEpoch = 0, origin = AppendOrigin.Coordinator) + + log.appendAsLeader(record(1, 1), leaderEpoch = 0) + log.appendAsLeader(record(2, 1), leaderEpoch = 0) + + // Roll the log to ensure that the data is cleanable. + log.roll() + + // Both transactional batches will be cleaned. The last one will remain in the log + // as an empty batch in order to preserve the producer sequence number and epoch + cleaner.doClean(LogToClean(tp, log, 0L, log.activeSegment.baseOffset), currentTime = largeTimestamp) + assertEquals(List(1, 3, 4, 5), offsetsInLog(log)) + assertEquals(List(1, 2, 3, 4, 5), lastOffsetsPerBatchInLog(log)) + + // In the first pass, the delete horizon for the first marker is set. In the second pass, the first marker is removed. 
+ runTwoPassClean(cleaner, LogToClean(tp, log, 0L, log.activeSegment.baseOffset), currentTime = largeTimestamp) + assertEquals(List(3, 4, 5), offsetsInLog(log)) + assertEquals(List(2, 3, 4, 5), lastOffsetsPerBatchInLog(log)) + } + + @Test + def testAbortMarkerRetentionWithEmptyBatch(): Unit = { + val tp = new TopicPartition("test", 0) + val cleaner = makeCleaner(Int.MaxValue) + val logProps = new Properties() + logProps.put(LogConfig.SegmentBytesProp, 256: java.lang.Integer) + val log = makeLog(config = LogConfig.fromProps(logConfig.originals, logProps)) + + val producerEpoch = 0.toShort + val producerId = 1L + val appendProducer = appendTransactionalAsLeader(log, producerId, producerEpoch) + + appendProducer(Seq(2, 3)) // batch last offset is 1 + log.appendAsLeader(abortMarker(producerId, producerEpoch), leaderEpoch = 0, origin = AppendOrigin.Coordinator) + log.roll() + + def assertAbortedTransactionIndexed(): Unit = { + val abortedTxns = log.collectAbortedTransactions(0L, 100L) + assertEquals(1, abortedTxns.size) + assertEquals(producerId, abortedTxns.head.producerId) + assertEquals(0, abortedTxns.head.firstOffset) + assertEquals(2, abortedTxns.head.lastOffset) + } + + assertAbortedTransactionIndexed() + + // first time through the records are removed + var dirtyOffset = cleaner.doClean(LogToClean(tp, log, 0L, log.activeSegment.baseOffset), currentTime = largeTimestamp)._1 + assertAbortedTransactionIndexed() + assertEquals(List(), LogTestUtils.keysInLog(log)) + assertEquals(List(2), offsetsInLog(log)) // abort marker is retained + assertEquals(List(1, 2), lastOffsetsPerBatchInLog(log)) // empty batch is retained + + // the empty batch remains if cleaned again because it still holds the last sequence + dirtyOffset = runTwoPassClean(cleaner, LogToClean(tp, log, dirtyOffset, log.activeSegment.baseOffset), currentTime = largeTimestamp) + assertAbortedTransactionIndexed() + assertEquals(List(), LogTestUtils.keysInLog(log)) + assertEquals(List(2), offsetsInLog(log)) // abort marker is still retained + assertEquals(List(1, 2), lastOffsetsPerBatchInLog(log)) // empty batch is retained + + // now update the last sequence so that the empty batch can be removed + appendProducer(Seq(1)) + log.roll() + + dirtyOffset = cleaner.doClean(LogToClean(tp, log, dirtyOffset, log.activeSegment.baseOffset), currentTime = largeTimestamp)._1 + assertAbortedTransactionIndexed() + assertEquals(List(1), LogTestUtils.keysInLog(log)) + assertEquals(List(2, 3), offsetsInLog(log)) // abort marker is not yet gone because we read the empty batch + assertEquals(List(2, 3), lastOffsetsPerBatchInLog(log)) // but we do not preserve the empty batch + + // In the first pass, the delete horizon for the abort marker is set. In the second pass, the abort marker is removed. 
+ dirtyOffset = runTwoPassClean(cleaner, LogToClean(tp, log, dirtyOffset, log.activeSegment.baseOffset), currentTime = largeTimestamp) + assertEquals(List(1), LogTestUtils.keysInLog(log)) + assertEquals(List(3), offsetsInLog(log)) // abort marker is gone + assertEquals(List(3), lastOffsetsPerBatchInLog(log)) + + // we do not bother retaining the aborted transaction in the index + assertEquals(0, log.collectAbortedTransactions(0L, 100L).size) + } + + /** + * Test log cleaning with logs containing messages larger than default message size + */ + @Test + def testLargeMessage(): Unit = { + val largeMessageSize = 1024 * 1024 + // Create cleaner with very small default max message size + val cleaner = makeCleaner(Int.MaxValue, maxMessageSize=1024) + val logProps = new Properties() + logProps.put(LogConfig.SegmentBytesProp, largeMessageSize * 16: java.lang.Integer) + logProps.put(LogConfig.MaxMessageBytesProp, largeMessageSize * 2: java.lang.Integer) + + val log = makeLog(config = LogConfig.fromProps(logConfig.originals, logProps)) + + while(log.numberOfSegments < 2) + log.appendAsLeader(record(log.logEndOffset.toInt, Array.fill(largeMessageSize)(0: Byte)), leaderEpoch = 0) + val keysFound = LogTestUtils.keysInLog(log) + assertEquals(0L until log.logEndOffset, keysFound) + + // pretend we have the following keys + val keys = immutable.ListSet(1L, 3L, 5L, 7L, 9L) + val map = new FakeOffsetMap(Int.MaxValue) + keys.foreach(k => map.put(key(k), Long.MaxValue)) + + // clean the log + val stats = new CleanerStats() + cleaner.cleanSegments(log, Seq(log.logSegments.head), map, 0L, stats, new CleanedTransactionMetadata, -1) + val shouldRemain = LogTestUtils.keysInLog(log).filter(!keys.contains(_)) + assertEquals(shouldRemain, LogTestUtils.keysInLog(log)) + } + + /** + * Test log cleaning with logs containing messages larger than topic's max message size + */ + @Test + def testMessageLargerThanMaxMessageSize(): Unit = { + val (log, offsetMap) = createLogWithMessagesLargerThanMaxSize(largeMessageSize = 1024 * 1024) + + val cleaner = makeCleaner(Int.MaxValue, maxMessageSize=1024) + cleaner.cleanSegments(log, Seq(log.logSegments.head), offsetMap, 0L, new CleanerStats, new CleanedTransactionMetadata, -1) + val shouldRemain = LogTestUtils.keysInLog(log).filter(k => !offsetMap.map.containsKey(k.toString)) + assertEquals(shouldRemain, LogTestUtils.keysInLog(log)) + } + + /** + * Test log cleaning with logs containing messages larger than topic's max message size + * where header is corrupt + */ + @Test + def testMessageLargerThanMaxMessageSizeWithCorruptHeader(): Unit = { + val (log, offsetMap) = createLogWithMessagesLargerThanMaxSize(largeMessageSize = 1024 * 1024) + val file = new RandomAccessFile(log.logSegments.head.log.file, "rw") + file.seek(Records.MAGIC_OFFSET) + file.write(0xff) + file.close() + + val cleaner = makeCleaner(Int.MaxValue, maxMessageSize=1024) + assertThrows(classOf[CorruptRecordException], () => + cleaner.cleanSegments(log, Seq(log.logSegments.head), offsetMap, 0L, new CleanerStats, new CleanedTransactionMetadata, -1) + ) + } + + /** + * Test log cleaning with logs containing messages larger than topic's max message size + * where message size is corrupt and larger than bytes available in log segment. 
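+ * (The segment file is truncated to 1024 bytes, well below the recorded message size, so cleanSegments is expected to fail with a CorruptRecordException.)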
+ */ + @Test + def testCorruptMessageSizeLargerThanBytesAvailable(): Unit = { + val (log, offsetMap) = createLogWithMessagesLargerThanMaxSize(largeMessageSize = 1024 * 1024) + val file = new RandomAccessFile(log.logSegments.head.log.file, "rw") + file.setLength(1024) + file.close() + + val cleaner = makeCleaner(Int.MaxValue, maxMessageSize=1024) + assertThrows(classOf[CorruptRecordException], () => + cleaner.cleanSegments(log, Seq(log.logSegments.head), offsetMap, 0L, new CleanerStats, new CleanedTransactionMetadata, -1) + ) + } + + def createLogWithMessagesLargerThanMaxSize(largeMessageSize: Int): (UnifiedLog, FakeOffsetMap) = { + val logProps = new Properties() + logProps.put(LogConfig.SegmentBytesProp, largeMessageSize * 16: java.lang.Integer) + logProps.put(LogConfig.MaxMessageBytesProp, largeMessageSize * 2: java.lang.Integer) + + val log = makeLog(config = LogConfig.fromProps(logConfig.originals, logProps)) + + while(log.numberOfSegments < 2) + log.appendAsLeader(record(log.logEndOffset.toInt, Array.fill(largeMessageSize)(0: Byte)), leaderEpoch = 0) + val keysFound = LogTestUtils.keysInLog(log) + assertEquals(0L until log.logEndOffset, keysFound) + + // Decrease the log's max message size + logProps.put(LogConfig.MaxMessageBytesProp, largeMessageSize / 2: java.lang.Integer) + log.updateConfig(LogConfig.fromProps(logConfig.originals, logProps)) + + // pretend we have the following keys + val keys = immutable.ListSet(1, 3, 5, 7, 9) + val map = new FakeOffsetMap(Int.MaxValue) + keys.foreach(k => map.put(key(k), Long.MaxValue)) + + (log, map) + } + + @Test + def testCleaningWithDeletes(): Unit = { + val cleaner = makeCleaner(Int.MaxValue) + val logProps = new Properties() + logProps.put(LogConfig.SegmentBytesProp, 1024: java.lang.Integer) + + val log = makeLog(config = LogConfig.fromProps(logConfig.originals, logProps)) + + // append messages with the keys 0 through N + while(log.numberOfSegments < 2) + log.appendAsLeader(record(log.logEndOffset.toInt, log.logEndOffset.toInt), leaderEpoch = 0) + + // delete all even keys between 0 and N + val leo = log.logEndOffset + for(key <- 0 until leo.toInt by 2) + log.appendAsLeader(tombstoneRecord(key), leaderEpoch = 0) + + // append some new unique keys to pad out to a new active segment + while(log.numberOfSegments < 4) + log.appendAsLeader(record(log.logEndOffset.toInt, log.logEndOffset.toInt), leaderEpoch = 0) + + cleaner.clean(LogToClean(new TopicPartition("test", 0), log, 0, log.activeSegment.baseOffset)) + val keys = LogTestUtils.keysInLog(log).toSet + assertTrue((0 until leo.toInt by 2).forall(!keys.contains(_)), "None of the keys we deleted should still exist.") + } + + @Test + def testLogCleanerStats(): Unit = { + // because loadFactor is 0.75, this means we can fit 3 messages in the map + val cleaner = makeCleaner(4) + val logProps = new Properties() + logProps.put(LogConfig.SegmentBytesProp, 1024: java.lang.Integer) + + val log = makeLog(config = LogConfig.fromProps(logConfig.originals, logProps)) + + log.appendAsLeader(record(0,0), leaderEpoch = 0) // offset 0 + log.appendAsLeader(record(1,1), leaderEpoch = 0) // offset 1 + log.appendAsLeader(record(0,0), leaderEpoch = 0) // offset 2 + log.appendAsLeader(record(1,1), leaderEpoch = 0) // offset 3 + log.appendAsLeader(record(0,0), leaderEpoch = 0) // offset 4 + // roll the segment, so we can clean the messages already appended + log.roll() + + val initialLogSize = log.size + + val (endOffset, stats) = cleaner.clean(LogToClean(new TopicPartition("test", 0), log, 2, 
log.activeSegment.baseOffset)) + assertEquals(5, endOffset) + assertEquals(5, stats.messagesRead) + assertEquals(initialLogSize, stats.bytesRead) + assertEquals(2, stats.messagesWritten) + assertEquals(log.size, stats.bytesWritten) + assertEquals(0, stats.invalidMessagesRead) + assertTrue(stats.endTime >= stats.startTime) + } + + @Test + def testLogCleanerRetainsProducerLastSequence(): Unit = { + val cleaner = makeCleaner(10) + val logProps = new Properties() + logProps.put(LogConfig.SegmentBytesProp, 1024: java.lang.Integer) + + val log = makeLog(config = LogConfig.fromProps(logConfig.originals, logProps)) + log.appendAsLeader(record(0, 0), leaderEpoch = 0) // offset 0 + log.appendAsLeader(record(0, 1, producerId = 1, producerEpoch = 0, sequence = 0), leaderEpoch = 0) // offset 1 + log.appendAsLeader(record(0, 2, producerId = 2, producerEpoch = 0, sequence = 0), leaderEpoch = 0) // offset 2 + log.appendAsLeader(record(0, 3, producerId = 3, producerEpoch = 0, sequence = 0), leaderEpoch = 0) // offset 3 + log.appendAsLeader(record(1, 1, producerId = 2, producerEpoch = 0, sequence = 1), leaderEpoch = 0) // offset 4 + + // roll the segment, so we can clean the messages already appended + log.roll() + + cleaner.clean(LogToClean(new TopicPartition("test", 0), log, 0L, log.activeSegment.baseOffset)) + assertEquals(List(1, 3, 4), lastOffsetsPerBatchInLog(log)) + assertEquals(Map(1L -> 0, 2L -> 1, 3L -> 0), lastSequencesInLog(log)) + assertEquals(List(0, 1), LogTestUtils.keysInLog(log)) + assertEquals(List(3, 4), offsetsInLog(log)) + } + + @Test + def testLogCleanerRetainsLastSequenceEvenIfTransactionAborted(): Unit = { + val cleaner = makeCleaner(10) + val logProps = new Properties() + logProps.put(LogConfig.SegmentBytesProp, 1024: java.lang.Integer) + val log = makeLog(config = LogConfig.fromProps(logConfig.originals, logProps)) + + val producerEpoch = 0.toShort + val producerId = 1L + val appendProducer = appendTransactionalAsLeader(log, producerId, producerEpoch) + + appendProducer(Seq(1)) + appendProducer(Seq(2, 3)) + log.appendAsLeader(abortMarker(producerId, producerEpoch), leaderEpoch = 0, origin = AppendOrigin.Coordinator) + log.roll() + + cleaner.clean(LogToClean(new TopicPartition("test", 0), log, 0L, log.activeSegment.baseOffset)) + assertEquals(List(2, 3), lastOffsetsPerBatchInLog(log)) + assertEquals(Map(producerId -> 2), lastSequencesInLog(log)) + assertEquals(List(), LogTestUtils.keysInLog(log)) + assertEquals(List(3), offsetsInLog(log)) + + // Append a new entry from the producer and verify that the empty batch is cleaned up + appendProducer(Seq(1, 5)) + log.roll() + cleaner.clean(LogToClean(new TopicPartition("test", 0), log, 0L, log.activeSegment.baseOffset)) + + assertEquals(List(3, 5), lastOffsetsPerBatchInLog(log)) + assertEquals(Map(producerId -> 4), lastSequencesInLog(log)) + assertEquals(List(1, 5), LogTestUtils.keysInLog(log)) + assertEquals(List(3, 4, 5), offsetsInLog(log)) + } + + @Test + def testPartialSegmentClean(): Unit = { + // because loadFactor is 0.75, this means we can fit 1 message in the map + val cleaner = makeCleaner(2) + val logProps = new Properties() + logProps.put(LogConfig.SegmentBytesProp, 1024: java.lang.Integer) + + val log = makeLog(config = LogConfig.fromProps(logConfig.originals, logProps)) + + log.appendAsLeader(record(0,0), leaderEpoch = 0) // offset 0 + log.appendAsLeader(record(1,1), leaderEpoch = 0) // offset 1 + log.appendAsLeader(record(0,0), leaderEpoch = 0) // offset 2 + log.appendAsLeader(record(1,1), leaderEpoch = 0) // offset 3 + 
log.appendAsLeader(record(0,0), leaderEpoch = 0) // offset 4 + // roll the segment, so we can clean the messages already appended + log.roll() + + // clean the log with only one message removed + cleaner.clean(LogToClean(new TopicPartition("test", 0), log, 2, log.activeSegment.baseOffset)) + assertEquals(List(1,0,1,0), LogTestUtils.keysInLog(log)) + assertEquals(List(1,2,3,4), offsetsInLog(log)) + + // continue to make progress, even though we can only clean one message at a time + cleaner.clean(LogToClean(new TopicPartition("test", 0), log, 3, log.activeSegment.baseOffset)) + assertEquals(List(0,1,0), LogTestUtils.keysInLog(log)) + assertEquals(List(2,3,4), offsetsInLog(log)) + + cleaner.clean(LogToClean(new TopicPartition("test", 0), log, 4, log.activeSegment.baseOffset)) + assertEquals(List(1,0), LogTestUtils.keysInLog(log)) + assertEquals(List(3,4), offsetsInLog(log)) + } + + @Test + def testCleaningWithUncleanableSection(): Unit = { + val cleaner = makeCleaner(Int.MaxValue) + val logProps = new Properties() + logProps.put(LogConfig.SegmentBytesProp, 1024: java.lang.Integer) + + val log = makeLog(config = LogConfig.fromProps(logConfig.originals, logProps)) + + // Number of distinct keys. For an effective test this should be small enough such that each log segment contains some duplicates. + val N = 10 + val numCleanableSegments = 2 + val numTotalSegments = 7 + + // append messages with the keys 0 through N-1, values equal offset + while(log.numberOfSegments <= numCleanableSegments) + log.appendAsLeader(record(log.logEndOffset.toInt % N, log.logEndOffset.toInt), leaderEpoch = 0) + + // at this point one message past the cleanable segments has been added + // the entire segment containing the first uncleanable offset should not be cleaned. + val firstUncleanableOffset = log.logEndOffset + 1 // +1 so it is past the baseOffset + + while(log.numberOfSegments < numTotalSegments - 1) + log.appendAsLeader(record(log.logEndOffset.toInt % N, log.logEndOffset.toInt), leaderEpoch = 0) + + // the last (active) segment has just one message + + def distinctValuesBySegment = log.logSegments.map(s => s.log.records.asScala.map(record => TestUtils.readString(record.value)).toSet.size).toSeq + + val disctinctValuesBySegmentBeforeClean = distinctValuesBySegment + assertTrue(distinctValuesBySegment.reverse.tail.forall(_ > N), + "Test is not effective unless each segment contains duplicates. 
Increase segment size or decrease number of keys.") + + cleaner.clean(LogToClean(new TopicPartition("test", 0), log, 0, firstUncleanableOffset)) + + val distinctValuesBySegmentAfterClean = distinctValuesBySegment + + assertTrue(disctinctValuesBySegmentBeforeClean.zip(distinctValuesBySegmentAfterClean) + .take(numCleanableSegments).forall { case (before, after) => after < before }, + "The cleanable segments should have fewer number of values after cleaning") + assertTrue(disctinctValuesBySegmentBeforeClean.zip(distinctValuesBySegmentAfterClean) + .slice(numCleanableSegments, numTotalSegments).forall { x => x._1 == x._2 }, "The uncleanable segments should have the same number of values after cleaning") + } + + @Test + def testLogToClean(): Unit = { + // create a log with small segment size + val logProps = new Properties() + logProps.put(LogConfig.SegmentBytesProp, 100: java.lang.Integer) + val log = makeLog(config = LogConfig.fromProps(logConfig.originals, logProps)) + + // create 6 segments with only one message in each segment + def createRecorcs = TestUtils.singletonRecords(value = Array.fill[Byte](25)(0), key = 1.toString.getBytes) + for (_ <- 0 until 6) + log.appendAsLeader(createRecorcs, leaderEpoch = 0) + + val logToClean = LogToClean(new TopicPartition("test", 0), log, log.activeSegment.baseOffset, log.activeSegment.baseOffset) + + assertEquals(logToClean.totalBytes, log.size - log.activeSegment.size, + "Total bytes of LogToClean should equal size of all segments excluding the active segment") + } + + @Test + def testLogToCleanWithUncleanableSection(): Unit = { + // create a log with small segment size + val logProps = new Properties() + logProps.put(LogConfig.SegmentBytesProp, 100: java.lang.Integer) + val log = makeLog(config = LogConfig.fromProps(logConfig.originals, logProps)) + + // create 6 segments with only one message in each segment + def createRecords = TestUtils.singletonRecords(value = Array.fill[Byte](25)(0), key = 1.toString.getBytes) + for (_ <- 0 until 6) + log.appendAsLeader(createRecords, leaderEpoch = 0) + + // segments [0,1] are clean; segments [2, 3] are cleanable; segments [4,5] are uncleanable + val segs = log.logSegments.toSeq + val logToClean = LogToClean(new TopicPartition("test", 0), log, segs(2).baseOffset, segs(4).baseOffset) + + val expectedCleanSize = segs.take(2).map(_.size).sum + val expectedCleanableSize = segs.slice(2, 4).map(_.size).sum + + assertEquals(logToClean.cleanBytes, expectedCleanSize, + "Uncleanable bytes of LogToClean should equal size of all segments prior the one containing first dirty") + assertEquals(logToClean.cleanableBytes, expectedCleanableSize, + "Cleanable bytes of LogToClean should equal size of all segments from the one containing first dirty offset" + + " to the segment prior to the one with the first uncleanable offset") + assertEquals(logToClean.totalBytes, expectedCleanSize + expectedCleanableSize, + "Total bytes should be the sum of the clean and cleanable segments") + assertEquals(logToClean.cleanableRatio, + expectedCleanableSize / (expectedCleanSize + expectedCleanableSize).toDouble, 1.0e-6d, + "Total cleanable ratio should be the ratio of cleanable size to clean plus cleanable") + } + + @Test + def testCleaningWithUnkeyedMessages(): Unit = { + val cleaner = makeCleaner(Int.MaxValue) + + // create a log with compaction turned off so we can append unkeyed messages + val logProps = new Properties() + logProps.put(LogConfig.SegmentBytesProp, 1024: java.lang.Integer) + logProps.put(LogConfig.CleanupPolicyProp, 
LogConfig.Delete) + + val log = makeLog(config = LogConfig.fromProps(logConfig.originals, logProps)) + + // append unkeyed messages + while(log.numberOfSegments < 2) + log.appendAsLeader(unkeyedRecord(log.logEndOffset.toInt), leaderEpoch = 0) + val numInvalidMessages = unkeyedMessageCountInLog(log) + + val sizeWithUnkeyedMessages = log.size + + // append keyed messages + while(log.numberOfSegments < 3) + log.appendAsLeader(record(log.logEndOffset.toInt, log.logEndOffset.toInt), leaderEpoch = 0) + + val expectedSizeAfterCleaning = log.size - sizeWithUnkeyedMessages + val (_, stats) = cleaner.clean(LogToClean(new TopicPartition("test", 0), log, 0, log.activeSegment.baseOffset)) + + assertEquals(0, unkeyedMessageCountInLog(log), "Log should only contain keyed messages after cleaning.") + assertEquals(expectedSizeAfterCleaning, log.size, "Log should only contain keyed messages after cleaning.") + assertEquals(numInvalidMessages, stats.invalidMessagesRead, "Cleaner should have seen %d invalid messages.") + } + + def lastOffsetsPerBatchInLog(log: UnifiedLog): Iterable[Long] = { + for (segment <- log.logSegments; batch <- segment.log.batches.asScala) + yield batch.lastOffset + } + + def lastSequencesInLog(log: UnifiedLog): Map[Long, Int] = { + (for (segment <- log.logSegments; + batch <- segment.log.batches.asScala if !batch.isControlBatch && batch.hasProducerId) + yield batch.producerId -> batch.lastSequence).toMap + } + + /* extract all the offsets from a log */ + def offsetsInLog(log: UnifiedLog): Iterable[Long] = + log.logSegments.flatMap(s => s.log.records.asScala.filter(_.hasValue).filter(_.hasKey).map(m => m.offset)) + + def unkeyedMessageCountInLog(log: UnifiedLog) = + log.logSegments.map(s => s.log.records.asScala.filter(_.hasValue).count(m => !m.hasKey)).sum + + def abortCheckDone(topicPartition: TopicPartition): Unit = { + throw new LogCleaningAbortedException() + } + + /** + * Test that abortion during cleaning throws a LogCleaningAbortedException + */ + @Test + def testCleanSegmentsWithAbort(): Unit = { + val cleaner = makeCleaner(Int.MaxValue, abortCheckDone) + val logProps = new Properties() + logProps.put(LogConfig.SegmentBytesProp, 1024: java.lang.Integer) + + val log = makeLog(config = LogConfig.fromProps(logConfig.originals, logProps)) + + // append messages to the log until we have four segments + while(log.numberOfSegments < 4) + log.appendAsLeader(record(log.logEndOffset.toInt, log.logEndOffset.toInt), leaderEpoch = 0) + + val keys = LogTestUtils.keysInLog(log) + val map = new FakeOffsetMap(Int.MaxValue) + keys.foreach(k => map.put(key(k), Long.MaxValue)) + assertThrows(classOf[LogCleaningAbortedException], () => + cleaner.cleanSegments(log, log.logSegments.take(3).toSeq, map, 0L, new CleanerStats(), + new CleanedTransactionMetadata, -1) + ) + } + + /** + * Validate the logic for grouping log segments together for cleaning + */ + @Test + def testSegmentGrouping(): Unit = { + val cleaner = makeCleaner(Int.MaxValue) + val logProps = new Properties() + logProps.put(LogConfig.SegmentBytesProp, 300: java.lang.Integer) + logProps.put(LogConfig.IndexIntervalBytesProp, 1: java.lang.Integer) + + val log = makeLog(config = LogConfig.fromProps(logConfig.originals, logProps)) + + // append some messages to the log + var i = 0 + while(log.numberOfSegments < 10) { + log.appendAsLeader(TestUtils.singletonRecords(value = "hello".getBytes, key = "hello".getBytes), leaderEpoch = 0) + i += 1 + } + + // grouping by very large values should result in a single group with all the segments in it + 
var groups = cleaner.groupSegmentsBySize(log.logSegments, maxSize = Int.MaxValue, maxIndexSize = Int.MaxValue, log.logEndOffset) + assertEquals(1, groups.size) + assertEquals(log.numberOfSegments, groups.head.size) + checkSegmentOrder(groups) + + // grouping by very small values should result in all groups having one entry + groups = cleaner.groupSegmentsBySize(log.logSegments, maxSize = 1, maxIndexSize = Int.MaxValue, log.logEndOffset) + assertEquals(log.numberOfSegments, groups.size) + assertTrue(groups.forall(_.size == 1), "All groups should be singletons.") + checkSegmentOrder(groups) + groups = cleaner.groupSegmentsBySize(log.logSegments, maxSize = Int.MaxValue, maxIndexSize = 1, log.logEndOffset) + assertEquals(log.numberOfSegments, groups.size) + assertTrue(groups.forall(_.size == 1), "All groups should be singletons.") + checkSegmentOrder(groups) + + val groupSize = 3 + + // check grouping by log size + val logSize = log.logSegments.take(groupSize).map(_.size).sum.toInt + 1 + groups = cleaner.groupSegmentsBySize(log.logSegments, maxSize = logSize, maxIndexSize = Int.MaxValue, log.logEndOffset) + checkSegmentOrder(groups) + assertTrue(groups.dropRight(1).forall(_.size == groupSize), "All but the last group should be the target size.") + + // check grouping by index size + val indexSize = log.logSegments.take(groupSize).map(_.offsetIndex.sizeInBytes).sum + 1 + groups = cleaner.groupSegmentsBySize(log.logSegments, maxSize = Int.MaxValue, maxIndexSize = indexSize, log.logEndOffset) + checkSegmentOrder(groups) + assertTrue(groups.dropRight(1).forall(_.size == groupSize), + "All but the last group should be the target size.") + } + + @Test + def testSegmentGroupingWithSparseOffsetsAndEmptySegments(): Unit ={ + val cleaner = makeCleaner(Int.MaxValue) + val logProps = new Properties() + val log = makeLog(config = LogConfig.fromProps(logConfig.originals, logProps)) + + val k="key".getBytes() + val v="val".getBytes() + + //create 3 segments + for(i <- 0 until 3){ + log.appendAsLeader(TestUtils.singletonRecords(value = v, key = k), leaderEpoch = 0) + //0 to Int.MaxValue is Int.MaxValue+1 message, -1 will be the last message of i-th segment + val records = messageWithOffset(k, v, (i + 1L) * (Int.MaxValue + 1L) -1 ) + log.appendAsFollower(records) + assertEquals(i + 1, log.numberOfSegments) + } + + //4th active segment, not clean + log.appendAsLeader(TestUtils.singletonRecords(value = v, key = k), leaderEpoch = 0) + + val totalSegments = 4 + //last segment not cleanable + val firstUncleanableOffset = log.logEndOffset - 1 + val notCleanableSegments = 1 + + assertEquals(totalSegments, log.numberOfSegments) + var groups = cleaner.groupSegmentsBySize(log.logSegments, maxSize = Int.MaxValue, maxIndexSize = Int.MaxValue, firstUncleanableOffset) + //because index file uses 4 byte relative index offset and current segments all none empty, + //segments will not group even their size is very small. 
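+ // Each non-empty segment here spans an offset range of roughly Int.MaxValue, so merging any two of them would overflow the 4-byte relative offsets of a combined index.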
+ assertEquals(totalSegments - notCleanableSegments, groups.size) + //do clean to clean first 2 segments to empty + cleaner.clean(LogToClean(log.topicPartition, log, 0, firstUncleanableOffset)) + assertEquals(totalSegments, log.numberOfSegments) + assertEquals(0, log.logSegments.head.size) + + //after clean we got 2 empty segment, they will group together this time + groups = cleaner.groupSegmentsBySize(log.logSegments, maxSize = Int.MaxValue, maxIndexSize = Int.MaxValue, firstUncleanableOffset) + val noneEmptySegment = 1 + assertEquals(noneEmptySegment + 1, groups.size) + + //trigger a clean and 2 empty segments should cleaned to 1 + cleaner.clean(LogToClean(log.topicPartition, log, 0, firstUncleanableOffset)) + assertEquals(totalSegments - 1, log.numberOfSegments) + } + + /** + * Validate the logic for grouping log segments together for cleaning when only a small number of + * messages are retained, but the range of offsets is greater than Int.MaxValue. A group should not + * contain a range of offsets greater than Int.MaxValue to ensure that relative offsets can be + * stored in 4 bytes. + */ + @Test + def testSegmentGroupingWithSparseOffsets(): Unit = { + val cleaner = makeCleaner(Int.MaxValue) + + val logProps = new Properties() + logProps.put(LogConfig.SegmentBytesProp, 400: java.lang.Integer) + logProps.put(LogConfig.IndexIntervalBytesProp, 1: java.lang.Integer) + + val log = makeLog(config = LogConfig.fromProps(logConfig.originals, logProps)) + + // fill up first segment + while (log.numberOfSegments == 1) + log.appendAsLeader(TestUtils.singletonRecords(value = "hello".getBytes, key = "hello".getBytes), leaderEpoch = 0) + + // forward offset and append message to next segment at offset Int.MaxValue + val records = messageWithOffset("hello".getBytes, "hello".getBytes, Int.MaxValue - 1) + log.appendAsFollower(records) + log.appendAsLeader(TestUtils.singletonRecords(value = "hello".getBytes, key = "hello".getBytes), leaderEpoch = 0) + assertEquals(Int.MaxValue, log.activeSegment.offsetIndex.lastOffset) + + // grouping should result in a single group with maximum relative offset of Int.MaxValue + var groups = cleaner.groupSegmentsBySize(log.logSegments, maxSize = Int.MaxValue, maxIndexSize = Int.MaxValue, log.logEndOffset) + assertEquals(1, groups.size) + + // append another message, making last offset of second segment > Int.MaxValue + log.appendAsLeader(TestUtils.singletonRecords(value = "hello".getBytes, key = "hello".getBytes), leaderEpoch = 0) + + // grouping should not group the two segments to ensure that maximum relative offset in each group <= Int.MaxValue + groups = cleaner.groupSegmentsBySize(log.logSegments, maxSize = Int.MaxValue, maxIndexSize = Int.MaxValue, log.logEndOffset) + assertEquals(2, groups.size) + checkSegmentOrder(groups) + + // append more messages, creating new segments, further grouping should still occur + while (log.numberOfSegments < 4) + log.appendAsLeader(TestUtils.singletonRecords(value = "hello".getBytes, key = "hello".getBytes), leaderEpoch = 0) + + groups = cleaner.groupSegmentsBySize(log.logSegments, maxSize = Int.MaxValue, maxIndexSize = Int.MaxValue, log.logEndOffset) + assertEquals(log.numberOfSegments - 1, groups.size) + for (group <- groups) + assertTrue(group.last.offsetIndex.lastOffset - group.head.offsetIndex.baseOffset <= Int.MaxValue, + "Relative offset greater than Int.MaxValue") + checkSegmentOrder(groups) + } + + /** + * Following the loading of a log segment where the index file is zero sized, + * the index returned would be the base 
offset. Sometimes the log file would + * contain data with offsets in excess of the baseOffset which would cause + * the log cleaner to group together segments with a range of > Int.MaxValue + * this test replicates that scenario to ensure that the segments are grouped + * correctly. + */ + @Test + def testSegmentGroupingFollowingLoadOfZeroIndex(): Unit = { + val cleaner = makeCleaner(Int.MaxValue) + + val logProps = new Properties() + logProps.put(LogConfig.SegmentBytesProp, 400: java.lang.Integer) + + //mimic the effect of loading an empty index file + logProps.put(LogConfig.IndexIntervalBytesProp, 400: java.lang.Integer) + + val log = makeLog(config = LogConfig.fromProps(logConfig.originals, logProps)) + + val record1 = messageWithOffset("hello".getBytes, "hello".getBytes, 0) + log.appendAsFollower(record1) + val record2 = messageWithOffset("hello".getBytes, "hello".getBytes, 1) + log.appendAsFollower(record2) + log.roll(Some(Int.MaxValue/2)) // starting a new log segment at offset Int.MaxValue/2 + val record3 = messageWithOffset("hello".getBytes, "hello".getBytes, Int.MaxValue/2) + log.appendAsFollower(record3) + val record4 = messageWithOffset("hello".getBytes, "hello".getBytes, Int.MaxValue.toLong + 1) + log.appendAsFollower(record4) + + assertTrue(log.logEndOffset - 1 - log.logStartOffset > Int.MaxValue, "Actual offset range should be > Int.MaxValue") + assertTrue(log.logSegments.last.offsetIndex.lastOffset - log.logStartOffset <= Int.MaxValue, + "index.lastOffset is reporting the wrong last offset") + + // grouping should result in two groups because the second segment takes the offset range > MaxInt + val groups = cleaner.groupSegmentsBySize(log.logSegments, maxSize = Int.MaxValue, maxIndexSize = Int.MaxValue, log.logEndOffset) + assertEquals(2, groups.size) + + for (group <- groups) + assertTrue(group.last.readNextOffset - 1 - group.head.baseOffset <= Int.MaxValue, + "Relative offset greater than Int.MaxValue") + checkSegmentOrder(groups) + } + + private def checkSegmentOrder(groups: Seq[Seq[LogSegment]]): Unit = { + val offsets = groups.flatMap(_.map(_.baseOffset)) + assertEquals(offsets.sorted, offsets, "Offsets should be in increasing order.") + } + + /** + * Test building an offset map off the log + */ + @Test + def testBuildOffsetMap(): Unit = { + val map = new FakeOffsetMap(1000) + val log = makeLog() + val cleaner = makeCleaner(Int.MaxValue) + val start = 0 + val end = 500 + writeToLog(log, (start until end) zip (start until end)) + + def checkRange(map: FakeOffsetMap, start: Int, end: Int): Unit = { + val stats = new CleanerStats() + cleaner.buildOffsetMap(log, start, end, map, stats) + val endOffset = map.latestOffset + 1 + assertEquals(end, endOffset, "Last offset should be the end offset.") + assertEquals(end-start, map.size, "Should have the expected number of messages in the map.") + for(i <- start until end) + assertEquals(i.toLong, map.get(key(i)), "Should find all the keys") + assertEquals(-1L, map.get(key(start - 1)), "Should not find a value too small") + assertEquals(-1L, map.get(key(end)), "Should not find a value too large") + assertEquals(end - start, stats.mapMessagesRead) + } + + val segments = log.logSegments.toSeq + checkRange(map, 0, segments(1).baseOffset.toInt) + checkRange(map, segments(1).baseOffset.toInt, segments(3).baseOffset.toInt) + checkRange(map, segments(3).baseOffset.toInt, log.logEndOffset.toInt) + } + + @Test + def testSegmentWithOffsetOverflow(): Unit = { + val cleaner = makeCleaner(Int.MaxValue) + val logProps = new Properties() + 
logProps.put(LogConfig.IndexIntervalBytesProp, 1: java.lang.Integer) + logProps.put(LogConfig.FileDeleteDelayMsProp, 1000: java.lang.Integer) + val config = LogConfig.fromProps(logConfig.originals, logProps) + + LogTestUtils.initializeLogDirWithOverflowedSegment(dir) + + val log = makeLog(config = config, recoveryPoint = Long.MaxValue) + val segmentWithOverflow = LogTestUtils.firstOverflowSegment(log).getOrElse { + throw new AssertionError("Failed to create log with a segment which has overflowed offsets") + } + + val numSegmentsInitial = log.logSegments.size + val allKeys = LogTestUtils.keysInLog(log).toList + val expectedKeysAfterCleaning = new mutable.ArrayBuffer[Long]() + + // pretend we want to clean every alternate key + val offsetMap = new FakeOffsetMap(Int.MaxValue) + for (k <- 1 until allKeys.size by 2) { + expectedKeysAfterCleaning += allKeys(k - 1) + offsetMap.put(key(allKeys(k)), Long.MaxValue) + } + + // Try to clean segment with offset overflow. This will trigger log split and the cleaning itself must abort. + assertThrows(classOf[LogCleaningAbortedException], () => + cleaner.cleanSegments(log, Seq(segmentWithOverflow), offsetMap, 0L, new CleanerStats(), + new CleanedTransactionMetadata, -1) + ) + assertEquals(numSegmentsInitial + 1, log.logSegments.size) + assertEquals(allKeys, LogTestUtils.keysInLog(log)) + assertFalse(LogTestUtils.hasOffsetOverflow(log)) + + // Clean each segment now that split is complete. + for (segmentToClean <- log.logSegments) + cleaner.cleanSegments(log, List(segmentToClean), offsetMap, 0L, new CleanerStats(), + new CleanedTransactionMetadata, -1) + assertEquals(expectedKeysAfterCleaning, LogTestUtils.keysInLog(log)) + assertFalse(LogTestUtils.hasOffsetOverflow(log)) + log.close() + } + + /** + * Tests recovery if broker crashes at the following stages during the cleaning sequence + *
+ * 1. Cleaner has created .cleaned log containing multiple segments, swap sequence not yet started
+ * 2. .cleaned log renamed to .swap, old segment files not yet renamed to .deleted
+ * 3. .cleaned log renamed to .swap, old segment files renamed to .deleted, but not yet deleted
+ * 4. .swap suffix removed, completing the swap, but async delete of .deleted files not yet complete
                + */ + @Test + def testRecoveryAfterCrash(): Unit = { + val cleaner = makeCleaner(Int.MaxValue) + val logProps = new Properties() + logProps.put(LogConfig.SegmentBytesProp, 300: java.lang.Integer) + logProps.put(LogConfig.IndexIntervalBytesProp, 1: java.lang.Integer) + logProps.put(LogConfig.FileDeleteDelayMsProp, 10: java.lang.Integer) + + val config = LogConfig.fromProps(logConfig.originals, logProps) + + // create a log and append some messages + var log = makeLog(config = config) + var messageCount = 0 + while (log.numberOfSegments < 10) { + log.appendAsLeader(record(log.logEndOffset.toInt, log.logEndOffset.toInt), leaderEpoch = 0) + messageCount += 1 + } + val allKeys = LogTestUtils.keysInLog(log) + + // pretend we have odd-numbered keys + val offsetMap = new FakeOffsetMap(Int.MaxValue) + for (k <- 1 until messageCount by 2) + offsetMap.put(key(k), Long.MaxValue) + + // clean the log + cleaner.cleanSegments(log, log.logSegments.take(9).toSeq, offsetMap, 0L, new CleanerStats(), + new CleanedTransactionMetadata, -1) + // clear scheduler so that async deletes don't run + time.scheduler.clear() + var cleanedKeys = LogTestUtils.keysInLog(log) + log.close() + + // 1) Simulate recovery just after .cleaned file is created, before rename to .swap + // On recovery, clean operation is aborted. All messages should be present in the log + log.logSegments.head.changeFileSuffixes("", UnifiedLog.CleanedFileSuffix) + for (file <- dir.listFiles if file.getName.endsWith(UnifiedLog.DeletedFileSuffix)) { + Utils.atomicMoveWithFallback(file.toPath, Paths.get(CoreUtils.replaceSuffix(file.getPath, UnifiedLog.DeletedFileSuffix, "")), false) + } + log = recoverAndCheck(config, allKeys) + + // clean again + cleaner.cleanSegments(log, log.logSegments.take(9).toSeq, offsetMap, 0L, new CleanerStats(), + new CleanedTransactionMetadata, -1) + // clear scheduler so that async deletes don't run + time.scheduler.clear() + cleanedKeys = LogTestUtils.keysInLog(log) + log.close() + + // 2) Simulate recovery just after .cleaned file is created, and a subset of them are renamed to .swap + // On recovery, clean operation is aborted. All messages should be present in the log + log.logSegments.head.changeFileSuffixes("", UnifiedLog.CleanedFileSuffix) + log.logSegments.head.log.renameTo(new File(CoreUtils.replaceSuffix(log.logSegments.head.log.file.getPath, UnifiedLog.CleanedFileSuffix, UnifiedLog.SwapFileSuffix))) + for (file <- dir.listFiles if file.getName.endsWith(UnifiedLog.DeletedFileSuffix)) { + Utils.atomicMoveWithFallback(file.toPath, Paths.get(CoreUtils.replaceSuffix(file.getPath, UnifiedLog.DeletedFileSuffix, "")), false) + } + log = recoverAndCheck(config, allKeys) + + // clean again + cleaner.cleanSegments(log, log.logSegments.take(9).toSeq, offsetMap, 0L, new CleanerStats(), + new CleanedTransactionMetadata, -1) + // clear scheduler so that async deletes don't run + time.scheduler.clear() + cleanedKeys = LogTestUtils.keysInLog(log) + log.close() + + // 3) Simulate recovery just after swap file is created, before old segment files are + // renamed to .deleted. Clean operation is resumed during recovery. 
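+ // Note the switch from recoverAndCheck(config, allKeys) to recoverAndCheck(config, cleanedKeys): from this stage onward recovery is expected to complete the interrupted swap and keep the compacted data.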
+ log.logSegments.head.changeFileSuffixes("", UnifiedLog.SwapFileSuffix) + for (file <- dir.listFiles if file.getName.endsWith(UnifiedLog.DeletedFileSuffix)) { + Utils.atomicMoveWithFallback(file.toPath, Paths.get(CoreUtils.replaceSuffix(file.getPath, UnifiedLog.DeletedFileSuffix, "")), false) + } + log = recoverAndCheck(config, cleanedKeys) + + // add some more messages and clean the log again + while (log.numberOfSegments < 10) { + log.appendAsLeader(record(log.logEndOffset.toInt, log.logEndOffset.toInt), leaderEpoch = 0) + messageCount += 1 + } + for (k <- 1 until messageCount by 2) + offsetMap.put(key(k), Long.MaxValue) + cleaner.cleanSegments(log, log.logSegments.take(9).toSeq, offsetMap, 0L, new CleanerStats(), + new CleanedTransactionMetadata, -1) + // clear scheduler so that async deletes don't run + time.scheduler.clear() + cleanedKeys = LogTestUtils.keysInLog(log) + + // 4) Simulate recovery after swap file is created and old segments files are renamed + // to .deleted. Clean operation is resumed during recovery. + log.logSegments.head.changeFileSuffixes("", UnifiedLog.SwapFileSuffix) + log = recoverAndCheck(config, cleanedKeys) + + // add some more messages and clean the log again + while (log.numberOfSegments < 10) { + log.appendAsLeader(record(log.logEndOffset.toInt, log.logEndOffset.toInt), leaderEpoch = 0) + messageCount += 1 + } + for (k <- 1 until messageCount by 2) + offsetMap.put(key(k), Long.MaxValue) + cleaner.cleanSegments(log, log.logSegments.take(9).toSeq, offsetMap, 0L, new CleanerStats(), + new CleanedTransactionMetadata, -1) + // clear scheduler so that async deletes don't run + time.scheduler.clear() + cleanedKeys = LogTestUtils.keysInLog(log) + + // 5) Simulate recovery after a subset of swap files are renamed to regular files and old segments files are renamed + // to .deleted. Clean operation is resumed during recovery. + log.logSegments.head.timeIndex.file.renameTo(new File(CoreUtils.replaceSuffix(log.logSegments.head.timeIndex.file.getPath, "", UnifiedLog.SwapFileSuffix))) + log = recoverAndCheck(config, cleanedKeys) + + // add some more messages and clean the log again + while (log.numberOfSegments < 10) { + log.appendAsLeader(record(log.logEndOffset.toInt, log.logEndOffset.toInt), leaderEpoch = 0) + messageCount += 1 + } + for (k <- 1 until messageCount by 2) + offsetMap.put(key(k), Long.MaxValue) + cleaner.cleanSegments(log, log.logSegments.take(9).toSeq, offsetMap, 0L, new CleanerStats(), + new CleanedTransactionMetadata, -1) + // clear scheduler so that async deletes don't run + time.scheduler.clear() + cleanedKeys = LogTestUtils.keysInLog(log) + log.close() + + // 6) Simulate recovery after swap is complete, but async deletion + // is not yet complete. Clean operation is resumed during recovery. 
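+ // The scheduler was cleared after the last cleanSegments call, so the replaced segments still sit on disk with the .deleted suffix when this recovery runs.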
+ log = recoverAndCheck(config, cleanedKeys) + log.close() + } + + @Test + def testBuildOffsetMapFakeLarge(): Unit = { + val map = new FakeOffsetMap(1000) + val logProps = new Properties() + logProps.put(LogConfig.SegmentBytesProp, 120: java.lang.Integer) + logProps.put(LogConfig.SegmentIndexBytesProp, 120: java.lang.Integer) + logProps.put(LogConfig.CleanupPolicyProp, LogConfig.Compact) + val logConfig = LogConfig(logProps) + val log = makeLog(config = logConfig) + val cleaner = makeCleaner(Int.MaxValue) + val keyStart = 0 + val keyEnd = 2 + val offsetStart = 0L + val offsetEnd = 7206178L + val offsetSeq = Seq(offsetStart, offsetEnd) + writeToLog(log, (keyStart until keyEnd) zip (keyStart until keyEnd), offsetSeq) + cleaner.buildOffsetMap(log, keyStart, offsetEnd + 1L, map, new CleanerStats()) + assertEquals(offsetEnd, map.latestOffset, "Last offset should be the end offset.") + assertEquals(keyEnd - keyStart, map.size, "Should have the expected number of messages in the map.") + assertEquals(0L, map.get(key(0)), "Map should contain first value") + assertEquals(offsetEnd, map.get(key(1)), "Map should contain second value") + } + + /** + * Test building a partial offset map of part of a log segment + */ + @Test + def testBuildPartialOffsetMap(): Unit = { + // because loadFactor is 0.75, this means we can fit 2 messages in the map + val log = makeLog() + val cleaner = makeCleaner(3) + val map = cleaner.offsetMap + + log.appendAsLeader(record(0,0), leaderEpoch = 0) + log.appendAsLeader(record(1,1), leaderEpoch = 0) + log.appendAsLeader(record(2,2), leaderEpoch = 0) + log.appendAsLeader(record(3,3), leaderEpoch = 0) + log.appendAsLeader(record(4,4), leaderEpoch = 0) + log.roll() + + val stats = new CleanerStats() + cleaner.buildOffsetMap(log, 2, Int.MaxValue, map, stats) + assertEquals(2, map.size) + assertEquals(-1, map.get(key(0))) + assertEquals(2, map.get(key(2))) + assertEquals(3, map.get(key(3))) + assertEquals(-1, map.get(key(4))) + assertEquals(4, stats.mapMessagesRead) + } + + /** + * This test verifies that messages corrupted by KAFKA-4298 are fixed by the cleaner + */ + @Test + def testCleanCorruptMessageSet(): Unit = { + val codec = CompressionType.GZIP + + val logProps = new Properties() + logProps.put(LogConfig.CompressionTypeProp, codec.name) + val logConfig = LogConfig(logProps) + + val log = makeLog(config = logConfig) + val cleaner = makeCleaner(10) + + // messages are constructed so that the payload matches the expecting offset to + // make offset validation easier after cleaning + + // one compressed log entry with duplicates + val dupSetKeys = (0 until 2) ++ (0 until 2) + val dupSetOffset = 25 + val dupSet = dupSetKeys zip (dupSetOffset until dupSetOffset + dupSetKeys.size) + + // and one without (should still be fixed by the cleaner) + val noDupSetKeys = 3 until 5 + val noDupSetOffset = 50 + val noDupSet = noDupSetKeys zip (noDupSetOffset until noDupSetOffset + noDupSetKeys.size) + + log.appendAsFollower(invalidCleanedMessage(dupSetOffset, dupSet, codec)) + log.appendAsFollower(invalidCleanedMessage(noDupSetOffset, noDupSet, codec)) + + log.roll() + + cleaner.clean(LogToClean(new TopicPartition("test", 0), log, 0, log.activeSegment.baseOffset)) + + for (segment <- log.logSegments; batch <- segment.log.batches.asScala; record <- batch.asScala) { + assertTrue(record.hasMagic(batch.magic)) + val value = TestUtils.readString(record.value).toLong + assertEquals(record.offset, value) + } + } + + /** + * Verify that the client can handle corrupted messages. 
Located here for now since the client + * does not support writing messages with the old magic. + */ + @Test + def testClientHandlingOfCorruptMessageSet(): Unit = { + val keys = 1 until 10 + val offset = 50 + val set = keys zip (offset until offset + keys.size) + + val corruptedMessage = invalidCleanedMessage(offset, set) + val records = MemoryRecords.readableRecords(corruptedMessage.buffer) + + for (logEntry <- records.records.asScala) { + val offset = logEntry.offset + val value = TestUtils.readString(logEntry.value).toLong + assertEquals(offset, value) + } + } + + @Test + def testCleanTombstone(): Unit = { + val logConfig = LogConfig(new Properties()) + + val log = makeLog(config = logConfig) + val cleaner = makeCleaner(10) + + // Append a message with a large timestamp. + log.appendAsLeader(TestUtils.singletonRecords(value = "0".getBytes, + key = "0".getBytes, + timestamp = time.milliseconds() + logConfig.deleteRetentionMs + 10000), leaderEpoch = 0) + log.roll() + cleaner.clean(LogToClean(new TopicPartition("test", 0), log, 0, log.activeSegment.baseOffset)) + // Append a tombstone with a small timestamp and roll out a new log segment. + log.appendAsLeader(TestUtils.singletonRecords(value = null, + key = "0".getBytes, + timestamp = time.milliseconds() - logConfig.deleteRetentionMs - 10000), leaderEpoch = 0) + log.roll() + + cleaner.clean(LogToClean(new TopicPartition("test", 0), log, 1, log.activeSegment.baseOffset)) + assertEquals(1, log.logSegments.head.log.batches.iterator.next().lastOffset, + "The tombstone should be retained.") + // Append a message and roll out another log segment. + log.appendAsLeader(TestUtils.singletonRecords(value = "1".getBytes, + key = "1".getBytes, + timestamp = time.milliseconds()), leaderEpoch = 0) + log.roll() + cleaner.clean(LogToClean(new TopicPartition("test", 0), log, 2, log.activeSegment.baseOffset)) + assertEquals(1, log.logSegments.head.log.batches.iterator.next().lastOffset, + "The tombstone should be retained.") + } + + /** + * Verify that the clean is able to move beyond missing offsets records in dirty log + */ + @Test + def testCleaningBeyondMissingOffsets(): Unit = { + val logProps = new Properties() + logProps.put(LogConfig.SegmentBytesProp, 1024*1024: java.lang.Integer) + logProps.put(LogConfig.CleanupPolicyProp, LogConfig.Compact) + val logConfig = LogConfig(logProps) + val cleaner = makeCleaner(Int.MaxValue) + + { + val log = makeLog(dir = TestUtils.randomPartitionLogDir(tmpdir), config = logConfig) + writeToLog(log, (0 to 9) zip (0 to 9), (0L to 9L)) + // roll new segment with baseOffset 11, leaving previous with holes in offset range [9,10] + log.roll(Some(11L)) + + // active segment record + log.appendAsFollower(messageWithOffset(1015, 1015, 11L)) + + val (nextDirtyOffset, _) = cleaner.clean(LogToClean(log.topicPartition, log, 0L, log.activeSegment.baseOffset, needCompactionNow = true)) + assertEquals(log.activeSegment.baseOffset, nextDirtyOffset, + "Cleaning point should pass offset gap") + } + + + { + val log = makeLog(dir = TestUtils.randomPartitionLogDir(tmpdir), config = logConfig) + writeToLog(log, (0 to 9) zip (0 to 9), (0L to 9L)) + // roll new segment with baseOffset 15, leaving previous with holes in offset rage [10, 14] + log.roll(Some(15L)) + + writeToLog(log, (15 to 24) zip (15 to 24), (15L to 24L)) + // roll new segment with baseOffset 30, leaving previous with holes in offset range [25, 29] + log.roll(Some(30L)) + + // active segment record + log.appendAsFollower(messageWithOffset(1015, 1015, 30L)) + + val 
(nextDirtyOffset, _) = cleaner.clean(LogToClean(log.topicPartition, log, 0L, log.activeSegment.baseOffset, needCompactionNow = true)) + assertEquals(log.activeSegment.baseOffset, nextDirtyOffset, + "Cleaning point should pass offset gap in multiple segments") + } + } + + @Test + def testMaxCleanTimeSecs(): Unit = { + val logCleaner = new LogCleaner(new CleanerConfig, + logDirs = Array(TestUtils.tempDir()), + logs = new Pool[TopicPartition, UnifiedLog](), + logDirFailureChannel = new LogDirFailureChannel(1), + time = time) + + def checkGauge(name: String): Unit = { + val gauge = logCleaner.newGauge(name, () => 999) + // if there is no cleaners, 0 is default value + assertEquals(0, gauge.value()) + } + + try { + checkGauge("max-buffer-utilization-percent") + checkGauge("max-clean-time-secs") + checkGauge("max-compaction-delay-secs") + } finally logCleaner.shutdown() + } + + + private def writeToLog(log: UnifiedLog, keysAndValues: Iterable[(Int, Int)], offsetSeq: Iterable[Long]): Iterable[Long] = { + for(((key, value), offset) <- keysAndValues.zip(offsetSeq)) + yield log.appendAsFollower(messageWithOffset(key, value, offset)).lastOffset + } + + private def invalidCleanedMessage(initialOffset: Long, + keysAndValues: Iterable[(Int, Int)], + codec: CompressionType = CompressionType.GZIP): MemoryRecords = { + // this function replicates the old versions of the cleaner which under some circumstances + // would write invalid compressed message sets with the outer magic set to 1 and the inner + // magic set to 0 + val records = keysAndValues.map(kv => + LegacyRecord.create(RecordBatch.MAGIC_VALUE_V0, + RecordBatch.NO_TIMESTAMP, + kv._1.toString.getBytes, + kv._2.toString.getBytes)) + + val buffer = ByteBuffer.allocate(math.min(math.max(records.map(_.sizeInBytes()).sum / 2, 1024), 1 << 16)) + val builder = MemoryRecords.builder(buffer, RecordBatch.MAGIC_VALUE_V1, codec, TimestampType.CREATE_TIME, initialOffset) + + var offset = initialOffset + records.foreach { record => + builder.appendUncheckedWithOffset(offset, record) + offset += 1 + } + + builder.build() + } + + private def messageWithOffset(key: Array[Byte], value: Array[Byte], offset: Long): MemoryRecords = + MemoryRecords.withRecords(offset, CompressionType.NONE, 0, new SimpleRecord(key, value)) + + private def messageWithOffset(key: Int, value: Int, offset: Long): MemoryRecords = + messageWithOffset(key.toString.getBytes, value.toString.getBytes, offset) + + private def makeLog(dir: File = dir, config: LogConfig = logConfig, recoveryPoint: Long = 0L) = + UnifiedLog(dir = dir, config = config, logStartOffset = 0L, recoveryPoint = recoveryPoint, scheduler = time.scheduler, + time = time, brokerTopicStats = new BrokerTopicStats, maxProducerIdExpirationMs = 60 * 60 * 1000, + producerIdExpirationCheckIntervalMs = LogManager.ProducerIdExpirationCheckIntervalMs, + logDirFailureChannel = new LogDirFailureChannel(10), topicId = None, keepPartitionMetadataFile = true) + + private def makeCleaner(capacity: Int, checkDone: TopicPartition => Unit = _ => (), maxMessageSize: Int = 64*1024) = + new Cleaner(id = 0, + offsetMap = new FakeOffsetMap(capacity), + ioBufferSize = maxMessageSize, + maxIoBufferSize = maxMessageSize, + dupBufferLoadFactor = 0.75, + throttler = throttler, + time = time, + checkDone = checkDone) + + private def writeToLog(log: UnifiedLog, seq: Iterable[(Int, Int)]): Iterable[Long] = { + for ((key, value) <- seq) + yield log.appendAsLeader(record(key, value), leaderEpoch = 0).firstOffset.get.messageOffset + } + + private def key(id: 
Long) = ByteBuffer.wrap(id.toString.getBytes) + + private def record(key: Int, value: Int, + producerId: Long = RecordBatch.NO_PRODUCER_ID, + producerEpoch: Short = RecordBatch.NO_PRODUCER_EPOCH, + sequence: Int = RecordBatch.NO_SEQUENCE, + partitionLeaderEpoch: Int = RecordBatch.NO_PARTITION_LEADER_EPOCH): MemoryRecords = { + MemoryRecords.withIdempotentRecords(RecordBatch.CURRENT_MAGIC_VALUE, 0L, CompressionType.NONE, producerId, producerEpoch, sequence, + partitionLeaderEpoch, new SimpleRecord(key.toString.getBytes, value.toString.getBytes)) + } + + private def appendTransactionalAsLeader(log: UnifiedLog, + producerId: Long, + producerEpoch: Short, + leaderEpoch: Int = 0, + origin: AppendOrigin = AppendOrigin.Client): Seq[Int] => LogAppendInfo = { + appendIdempotentAsLeader(log, producerId, producerEpoch, isTransactional = true, origin = origin) + } + + private def appendIdempotentAsLeader(log: UnifiedLog, + producerId: Long, + producerEpoch: Short, + isTransactional: Boolean = false, + leaderEpoch: Int = 0, + origin: AppendOrigin = AppendOrigin.Client): Seq[Int] => LogAppendInfo = { + var sequence = 0 + keys: Seq[Int] => { + val simpleRecords = keys.map { key => + val keyBytes = key.toString.getBytes + new SimpleRecord(time.milliseconds(), keyBytes, keyBytes) // the value doesn't matter since we validate offsets + } + val records = if (isTransactional) + MemoryRecords.withTransactionalRecords(CompressionType.NONE, producerId, producerEpoch, sequence, simpleRecords.toArray: _*) + else + MemoryRecords.withIdempotentRecords(CompressionType.NONE, producerId, producerEpoch, sequence, simpleRecords.toArray: _*) + sequence += simpleRecords.size + log.appendAsLeader(records, leaderEpoch, origin) + } + } + + private def commitMarker(producerId: Long, producerEpoch: Short, timestamp: Long = time.milliseconds()): MemoryRecords = + endTxnMarker(producerId, producerEpoch, ControlRecordType.COMMIT, 0L, timestamp) + + private def abortMarker(producerId: Long, producerEpoch: Short, timestamp: Long = time.milliseconds()): MemoryRecords = + endTxnMarker(producerId, producerEpoch, ControlRecordType.ABORT, 0L, timestamp) + + private def endTxnMarker(producerId: Long, producerEpoch: Short, controlRecordType: ControlRecordType, + offset: Long, timestamp: Long): MemoryRecords = { + val endTxnMarker = new EndTransactionMarker(controlRecordType, 0) + MemoryRecords.withEndTransactionMarker(offset, timestamp, RecordBatch.NO_PARTITION_LEADER_EPOCH, + producerId, producerEpoch, endTxnMarker) + } + + private def record(key: Int, value: Array[Byte]): MemoryRecords = + TestUtils.singletonRecords(key = key.toString.getBytes, value = value) + + private def unkeyedRecord(value: Int): MemoryRecords = + TestUtils.singletonRecords(value = value.toString.getBytes) + + private def tombstoneRecord(key: Int): MemoryRecords = record(key, null) + + private def recoverAndCheck(config: LogConfig, expectedKeys: Iterable[Long]): UnifiedLog = { + LogTestUtils.recoverAndCheck(dir, config, expectedKeys, new BrokerTopicStats(), time, time.scheduler) + } + + /** + * We need to run a two pass clean to perform the following steps to stimulate a proper clean: + * 1. On the first run, set the delete horizon in the batches with tombstone or markers with empty txn records. + * 2. For the second pass, we will advance the current time by tombstoneRetentionMs, which will cause the + * tombstones to expire, leading to their prompt removal from the log. + * Returns the first dirty offset in the log as a result of the second cleaning. 
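+   * For example, with the default tombstoneRetentionMs of 86400000 ms (24 hours), the second pass runs at
+   * currentTime + 86400001 ms, i.e. just past the delete horizon established by the first pass.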
+ */ + private def runTwoPassClean(cleaner: Cleaner, logToClean: LogToClean, currentTime: Long, + tombstoneRetentionMs: Long = 86400000) : Long = { + cleaner.doClean(logToClean, currentTime) + cleaner.doClean(logToClean, currentTime + tombstoneRetentionMs + 1)._1 + } +} + +class FakeOffsetMap(val slots: Int) extends OffsetMap { + val map = new java.util.HashMap[String, Long]() + var lastOffset = -1L + + private def keyFor(key: ByteBuffer) = + new String(Utils.readBytes(key.duplicate), StandardCharsets.UTF_8) + + override def put(key: ByteBuffer, offset: Long): Unit = { + lastOffset = offset + map.put(keyFor(key), offset) + } + + override def get(key: ByteBuffer): Long = { + val k = keyFor(key) + if(map.containsKey(k)) + map.get(k) + else + -1L + } + + override def clear(): Unit = map.clear() + + override def size: Int = map.size + + override def latestOffset: Long = lastOffset + + override def updateLatestOffset(offset: Long): Unit = { + lastOffset = offset + } + + override def toString: String = map.toString +} diff --git a/core/src/test/scala/unit/kafka/log/LogConcurrencyTest.scala b/core/src/test/scala/unit/kafka/log/LogConcurrencyTest.scala new file mode 100644 index 0000000..e10b5ab --- /dev/null +++ b/core/src/test/scala/unit/kafka/log/LogConcurrencyTest.scala @@ -0,0 +1,185 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package kafka.log + +import java.util.Properties +import java.util.concurrent.{Callable, Executors} + +import kafka.server.{BrokerTopicStats, FetchHighWatermark, LogDirFailureChannel} +import kafka.utils.{KafkaScheduler, TestUtils} +import org.apache.kafka.common.record.SimpleRecord +import org.apache.kafka.common.utils.{Time, Utils} +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.{AfterEach, BeforeEach, Test} + +import scala.collection.mutable.ListBuffer +import scala.util.Random + +class LogConcurrencyTest { + private val brokerTopicStats = new BrokerTopicStats + private val random = new Random() + private val scheduler = new KafkaScheduler(1) + private val tmpDir = TestUtils.tempDir() + private val logDir = TestUtils.randomPartitionLogDir(tmpDir) + + @BeforeEach + def setup(): Unit = { + scheduler.startup() + } + + @AfterEach + def shutdown(): Unit = { + scheduler.shutdown() + Utils.delete(tmpDir) + } + + @Test + def testUncommittedDataNotConsumed(): Unit = { + testUncommittedDataNotConsumed(createLog()) + } + + @Test + def testUncommittedDataNotConsumedFrequentSegmentRolls(): Unit = { + val logProps = new Properties() + logProps.put(LogConfig.SegmentBytesProp, 237: Integer) + val logConfig = LogConfig(logProps) + testUncommittedDataNotConsumed(createLog(logConfig)) + } + + def testUncommittedDataNotConsumed(log: UnifiedLog): Unit = { + val executor = Executors.newFixedThreadPool(2) + try { + val maxOffset = 5000 + val consumer = new ConsumerTask(log, maxOffset) + val appendTask = new LogAppendTask(log, maxOffset) + + val consumerFuture = executor.submit(consumer) + val fetcherTaskFuture = executor.submit(appendTask) + + fetcherTaskFuture.get() + consumerFuture.get() + + validateConsumedData(log, consumer.consumedBatches) + } finally executor.shutdownNow() + } + + /** + * Simple consumption task which reads the log in ascending order and collects + * consumed batches for validation + */ + private class ConsumerTask(log: UnifiedLog, lastOffset: Int) extends Callable[Unit] { + val consumedBatches = ListBuffer.empty[FetchedBatch] + + override def call(): Unit = { + var fetchOffset = 0L + while (log.highWatermark < lastOffset) { + val readInfo = log.read( + startOffset = fetchOffset, + maxLength = 1, + isolation = FetchHighWatermark, + minOneMessage = true + ) + readInfo.records.batches().forEach { batch => + consumedBatches += FetchedBatch(batch.baseOffset, batch.partitionLeaderEpoch) + fetchOffset = batch.lastOffset + 1 + } + } + } + } + + /** + * This class simulates basic leader/follower behavior. 
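+   * On each iteration it randomly either appends a small batch (as leader via appendAsLeader, or as follower
+   * via appendAsFollower with an explicit base offset) and advances the high watermark, or it flips the
+   * leader/follower role, bumps the leader epoch, and truncates to the high watermark when becoming a follower.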
+ */ + private class LogAppendTask(log: UnifiedLog, lastOffset: Long) extends Callable[Unit] { + override def call(): Unit = { + var leaderEpoch = 1 + var isLeader = true + + while (log.highWatermark < lastOffset) { + random.nextInt(2) match { + case 0 => + val logEndOffsetMetadata = log.logEndOffsetMetadata + val logEndOffset = logEndOffsetMetadata.messageOffset + val batchSize = random.nextInt(9) + 1 + val records = (0 to batchSize).map(i => new SimpleRecord(s"$i".getBytes)) + + if (isLeader) { + log.appendAsLeader(TestUtils.records(records), leaderEpoch) + log.maybeIncrementHighWatermark(logEndOffsetMetadata) + } else { + log.appendAsFollower(TestUtils.records(records, + baseOffset = logEndOffset, + partitionLeaderEpoch = leaderEpoch)) + log.updateHighWatermark(logEndOffset) + } + + case 1 => + isLeader = !isLeader + leaderEpoch += 1 + + if (!isLeader) { + log.truncateTo(log.highWatermark) + } + } + } + } + } + + private def createLog(config: LogConfig = LogConfig(new Properties())): UnifiedLog = { + UnifiedLog(dir = logDir, + config = config, + logStartOffset = 0L, + recoveryPoint = 0L, + scheduler = scheduler, + brokerTopicStats = brokerTopicStats, + time = Time.SYSTEM, + maxProducerIdExpirationMs = 60 * 60 * 1000, + producerIdExpirationCheckIntervalMs = LogManager.ProducerIdExpirationCheckIntervalMs, + logDirFailureChannel = new LogDirFailureChannel(10), + topicId = None, + keepPartitionMetadataFile = true) + } + + private def validateConsumedData(log: UnifiedLog, consumedBatches: Iterable[FetchedBatch]): Unit = { + val iter = consumedBatches.iterator + log.logSegments.foreach { segment => + segment.log.batches.forEach { batch => + if (iter.hasNext) { + val consumedBatch = iter.next() + try { + assertEquals(batch.partitionLeaderEpoch, + consumedBatch.epoch, "Consumed batch with unexpected leader epoch") + assertEquals(batch.baseOffset, + consumedBatch.baseOffset, "Consumed batch with unexpected base offset") + } catch { + case t: Throwable => + throw new AssertionError(s"Consumed batch $consumedBatch " + + s"does not match next expected batch in log $batch", t) + } + } + } + } + } + + private case class FetchedBatch(baseOffset: Long, epoch: Int) { + override def toString: String = { + s"FetchedBatch(baseOffset=$baseOffset, epoch=$epoch)" + } + } + +} diff --git a/core/src/test/scala/unit/kafka/log/LogConfigTest.scala b/core/src/test/scala/unit/kafka/log/LogConfigTest.scala new file mode 100644 index 0000000..f72bb92 --- /dev/null +++ b/core/src/test/scala/unit/kafka/log/LogConfigTest.scala @@ -0,0 +1,284 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package kafka.log + +import kafka.api.KAFKA_3_0_IV1 +import kafka.server.{KafkaConfig, ThrottledReplicaListValidator} +import kafka.utils.TestUtils +import org.apache.kafka.common.config.ConfigDef.Importance.MEDIUM +import org.apache.kafka.common.config.ConfigDef.Type.INT +import org.apache.kafka.common.config.{ConfigException, TopicConfig} +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.Test + +import java.util.{Collections, Properties} +import scala.annotation.nowarn + +class LogConfigTest { + + /** + * This test verifies that KafkaConfig object initialization does not depend on + * LogConfig initialization. Bad things happen due to static initialization + * order dependencies. For example, LogConfig.configDef ends up adding null + * values in serverDefaultConfigNames. This test ensures that the mapping of + * keys from LogConfig to KafkaConfig are not missing values. + */ + @Test + def ensureNoStaticInitializationOrderDependency(): Unit = { + // Access any KafkaConfig val to load KafkaConfig object before LogConfig. + assertNotNull(KafkaConfig.LogRetentionTimeMillisProp) + assertTrue(LogConfig.configNames.filter(config => !LogConfig.configsWithNoServerDefaults.contains(config)) + .forall { config => + val serverConfigOpt = LogConfig.serverConfigName(config) + serverConfigOpt.isDefined && (serverConfigOpt.get != null) + }) + } + + @nowarn("cat=deprecation") + @Test + def testKafkaConfigToProps(): Unit = { + val millisInHour = 60L * 60L * 1000L + val kafkaProps = TestUtils.createBrokerConfig(nodeId = 0, zkConnect = "") + kafkaProps.put(KafkaConfig.LogRollTimeHoursProp, "2") + kafkaProps.put(KafkaConfig.LogRollTimeJitterHoursProp, "2") + kafkaProps.put(KafkaConfig.LogRetentionTimeHoursProp, "2") + kafkaProps.put(KafkaConfig.LogMessageFormatVersionProp, "0.11.0") + + val kafkaConfig = KafkaConfig.fromProps(kafkaProps) + val logProps = LogConfig.extractLogConfigMap(kafkaConfig) + assertEquals(2 * millisInHour, logProps.get(LogConfig.SegmentMsProp)) + assertEquals(2 * millisInHour, logProps.get(LogConfig.SegmentJitterMsProp)) + assertEquals(2 * millisInHour, logProps.get(LogConfig.RetentionMsProp)) + // The message format version should always be 3.0 if the inter-broker protocol version is 3.0 or higher + assertEquals(KAFKA_3_0_IV1.version, logProps.get(LogConfig.MessageFormatVersionProp)) + } + + @Test + def testFromPropsEmpty(): Unit = { + val p = new Properties() + val config = LogConfig(p) + assertEquals(LogConfig(), config) + } + + @nowarn("cat=deprecation") + @Test + def testFromPropsInvalid(): Unit = { + LogConfig.configNames.foreach(name => name match { + case LogConfig.UncleanLeaderElectionEnableProp => assertPropertyInvalid(name, "not a boolean") + case LogConfig.RetentionBytesProp => assertPropertyInvalid(name, "not_a_number") + case LogConfig.RetentionMsProp => assertPropertyInvalid(name, "not_a_number" ) + case LogConfig.CleanupPolicyProp => assertPropertyInvalid(name, "true", "foobar") + case LogConfig.MinCleanableDirtyRatioProp => assertPropertyInvalid(name, "not_a_number", "-0.1", "1.2") + case LogConfig.MinInSyncReplicasProp => assertPropertyInvalid(name, "not_a_number", "0", "-1") + case LogConfig.MessageFormatVersionProp => assertPropertyInvalid(name, "") + case LogConfig.RemoteLogStorageEnableProp => assertPropertyInvalid(name, "not_a_boolean") + case LogConfig.LocalLogRetentionMsProp => assertPropertyInvalid(name, "not_a_number", "-3") + case LogConfig.LocalLogRetentionBytesProp => assertPropertyInvalid(name, "not_a_number", "-3") + + case 
_ => assertPropertyInvalid(name, "not_a_number", "-1") + }) + } + + @Test + def testInvalidCompactionLagConfig(): Unit = { + val props = new Properties + props.setProperty(LogConfig.MaxCompactionLagMsProp, "100") + props.setProperty(LogConfig.MinCompactionLagMsProp, "200") + assertThrows(classOf[Exception], () => LogConfig.validate(props)) + } + + @Test + def shouldValidateThrottledReplicasConfig(): Unit = { + assertTrue(isValid("*")) + assertTrue(isValid("* ")) + assertTrue(isValid("")) + assertTrue(isValid(" ")) + assertTrue(isValid("100:10")) + assertTrue(isValid("100:10,12:10")) + assertTrue(isValid("100:10,12:10,15:1")) + assertTrue(isValid("100:10,12:10,15:1 ")) + assertTrue(isValid("100:0,")) + + assertFalse(isValid("100")) + assertFalse(isValid("100:")) + assertFalse(isValid("100:0,10")) + assertFalse(isValid("100:0,10:")) + assertFalse(isValid("100:0,10: ")) + assertFalse(isValid("100 :0,10: ")) + assertFalse(isValid("100: 0,10: ")) + assertFalse(isValid("100:0,10 : ")) + assertFalse(isValid("*,100:10")) + assertFalse(isValid("* ,100:10")) + } + + /* Sanity check that toHtmlTable produces one of the expected configs */ + @Test + def testToHtmlTable(): Unit = { + val html = LogConfig.configDefCopy.toHtmlTable + val expectedConfig = "file.delete.delay.ms" + assertTrue(html.contains(expectedConfig), s"Could not find `$expectedConfig` in:\n $html") + } + + /* Sanity check that toHtml produces one of the expected configs */ + @Test + def testToHtml(): Unit = { + val html = LogConfig.configDefCopy.toHtml(4, (key: String) => "prefix_" + key, Collections.emptyMap()) + val expectedConfig = "

file.delete.delay.ms
                " + assertTrue(html.contains(expectedConfig), s"Could not find `$expectedConfig` in:\n $html") + } + + /* Sanity check that toEnrichedRst produces one of the expected configs */ + @Test + def testToEnrichedRst(): Unit = { + val rst = LogConfig.configDefCopy.toEnrichedRst + val expectedConfig = "``file.delete.delay.ms``" + assertTrue(rst.contains(expectedConfig), s"Could not find `$expectedConfig` in:\n $rst") + } + + /* Sanity check that toEnrichedRst produces one of the expected configs */ + @Test + def testToRst(): Unit = { + val rst = LogConfig.configDefCopy.toRst + val expectedConfig = "``file.delete.delay.ms``" + assertTrue(rst.contains(expectedConfig), s"Could not find `$expectedConfig` in:\n $rst") + } + + @Test + def testGetConfigValue(): Unit = { + // Add a config that doesn't set the `serverDefaultConfigName` + val configDef = LogConfig.configDefCopy + val configNameWithNoServerMapping = "log.foo" + configDef.define(configNameWithNoServerMapping, INT, 1, MEDIUM, s"$configNameWithNoServerMapping doc") + + val deleteDelayKey = configDef.configKeys.get(TopicConfig.FILE_DELETE_DELAY_MS_CONFIG) + val deleteDelayServerDefault = configDef.getConfigValue(deleteDelayKey, LogConfig.ServerDefaultHeaderName) + assertEquals(KafkaConfig.LogDeleteDelayMsProp, deleteDelayServerDefault) + + val keyWithNoServerMapping = configDef.configKeys.get(configNameWithNoServerMapping) + val nullServerDefault = configDef.getConfigValue(keyWithNoServerMapping, LogConfig.ServerDefaultHeaderName) + assertNull(nullServerDefault) + } + + @Test + def testOverriddenConfigsAsLoggableString(): Unit = { + val kafkaProps = TestUtils.createBrokerConfig(nodeId = 0, zkConnect = "") + kafkaProps.put("unknown.broker.password.config", "aaaaa") + kafkaProps.put(KafkaConfig.SslKeyPasswordProp, "somekeypassword") + kafkaProps.put(KafkaConfig.LogRetentionBytesProp, "50") + val kafkaConfig = KafkaConfig.fromProps(kafkaProps) + val topicOverrides = new Properties + // Only set as a topic config + topicOverrides.setProperty(LogConfig.MinInSyncReplicasProp, "2") + // Overrides value from broker config + topicOverrides.setProperty(LogConfig.RetentionBytesProp, "100") + // Unknown topic config, but known broker config + topicOverrides.setProperty(KafkaConfig.SslTruststorePasswordProp, "sometrustpasswrd") + // Unknown config + topicOverrides.setProperty("unknown.topic.password.config", "bbbb") + // We don't currently have any sensitive topic configs, if we add them, we should set one here + val logConfig = LogConfig.fromProps(LogConfig.extractLogConfigMap(kafkaConfig), topicOverrides) + assertEquals("{min.insync.replicas=2, retention.bytes=100, ssl.truststore.password=(redacted), unknown.topic.password.config=(redacted)}", + logConfig.overriddenConfigsAsLoggableString) + } + + private def isValid(configValue: String): Boolean = { + try { + ThrottledReplicaListValidator.ensureValidString("", configValue) + true + } catch { + case _: ConfigException => false + } + } + + private def assertPropertyInvalid(name: String, values: AnyRef*): Unit = { + values.foreach((value) => { + val props = new Properties + props.setProperty(name, value.toString) + assertThrows(classOf[Exception], () => LogConfig(props)) + }) + } + + @Test + def testLocalLogRetentionDerivedProps(): Unit = { + val props = new Properties() + val retentionBytes = 1024 + val retentionMs = 1000L + props.put(LogConfig.RetentionBytesProp, retentionBytes.toString) + props.put(LogConfig.RetentionMsProp, retentionMs.toString) + val logConfig = new LogConfig(props) + + 
assertEquals(retentionMs, logConfig.remoteLogConfig.localRetentionMs) + assertEquals(retentionBytes, logConfig.remoteLogConfig.localRetentionBytes) + } + + @Test + def testLocalLogRetentionDerivedDefaultProps(): Unit = { + val logConfig = new LogConfig( new Properties()) + + // Local retention defaults are derived from retention properties which can be default or custom. + assertEquals(Defaults.RetentionMs, logConfig.remoteLogConfig.localRetentionMs) + assertEquals(Defaults.RetentionSize, logConfig.remoteLogConfig.localRetentionBytes) + } + + @Test + def testLocalLogRetentionProps(): Unit = { + val props = new Properties() + val localRetentionMs = 500 + val localRetentionBytes = 1000 + props.put(LogConfig.RetentionBytesProp, 2000.toString) + props.put(LogConfig.RetentionMsProp, 1000.toString) + + props.put(LogConfig.LocalLogRetentionMsProp, localRetentionMs.toString) + props.put(LogConfig.LocalLogRetentionBytesProp, localRetentionBytes.toString) + val logConfig = new LogConfig(props) + + assertEquals(localRetentionMs, logConfig.remoteLogConfig.localRetentionMs) + assertEquals(localRetentionBytes, logConfig.remoteLogConfig.localRetentionBytes) + } + + @Test + def testInvalidLocalLogRetentionProps(): Unit = { + // Check for invalid localRetentionMs, < -2 + doTestInvalidLocalLogRetentionProps(-3, 10, 2, 500L) + + // Check for invalid localRetentionBytes < -2 + doTestInvalidLocalLogRetentionProps(500L, -3, 2, 1000L) + + // Check for invalid case of localRetentionMs > retentionMs + doTestInvalidLocalLogRetentionProps(2000L, 2, 100, 1000L) + + // Check for invalid case of localRetentionBytes > retentionBytes + doTestInvalidLocalLogRetentionProps(500L, 200, 100, 1000L) + + // Check for invalid case of localRetentionMs (-1 viz unlimited) > retentionMs, + doTestInvalidLocalLogRetentionProps(-1, 200, 100, 1000L) + + // Check for invalid case of localRetentionBytes(-1 viz unlimited) > retentionBytes + doTestInvalidLocalLogRetentionProps(2000L, -1, 100, 1000L) + } + + private def doTestInvalidLocalLogRetentionProps(localRetentionMs: Long, localRetentionBytes: Int, retentionBytes: Int, retentionMs: Long) = { + val props = new Properties() + props.put(LogConfig.RetentionBytesProp, retentionBytes.toString) + props.put(LogConfig.RetentionMsProp, retentionMs.toString) + + props.put(LogConfig.LocalLogRetentionMsProp, localRetentionMs.toString) + props.put(LogConfig.LocalLogRetentionBytesProp, localRetentionBytes.toString) + assertThrows(classOf[ConfigException], () => new LogConfig(props)); + } +} diff --git a/core/src/test/scala/unit/kafka/log/LogLoaderTest.scala b/core/src/test/scala/unit/kafka/log/LogLoaderTest.scala new file mode 100644 index 0000000..8e0c484 --- /dev/null +++ b/core/src/test/scala/unit/kafka/log/LogLoaderTest.scala @@ -0,0 +1,1680 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.log + +import java.io.{BufferedWriter, File, FileWriter} +import java.nio.ByteBuffer +import java.nio.file.{Files, Paths} +import java.util.Properties +import kafka.api.{ApiVersion, KAFKA_0_11_0_IV0} +import kafka.server.epoch.{EpochEntry, LeaderEpochFileCache} +import kafka.server.{BrokerTopicStats, FetchDataInfo, KafkaConfig, LogDirFailureChannel} +import kafka.server.metadata.MockConfigRepository +import kafka.utils.{CoreUtils, MockTime, Scheduler, TestUtils} +import org.apache.kafka.common.TopicPartition +import org.apache.kafka.common.record.{CompressionType, ControlRecordType, DefaultRecordBatch, MemoryRecords, RecordBatch, RecordVersion, SimpleRecord, TimestampType} +import org.apache.kafka.common.utils.{Time, Utils} +import org.easymock.EasyMock +import org.junit.jupiter.api.Assertions.{assertEquals, assertFalse, assertNotEquals, assertThrows, assertTrue} +import org.junit.jupiter.api.{AfterEach, BeforeEach, Test} + +import scala.annotation.nowarn +import scala.collection.mutable.ListBuffer +import scala.collection.{Iterable, Map, mutable} +import scala.jdk.CollectionConverters._ + +class LogLoaderTest { + var config: KafkaConfig = null + val brokerTopicStats = new BrokerTopicStats() + val maxProducerIdExpirationMs: Int = 60 * 60 * 1000 + val tmpDir = TestUtils.tempDir() + val logDir = TestUtils.randomPartitionLogDir(tmpDir) + val mockTime = new MockTime() + + @BeforeEach + def setUp(): Unit = { + val props = TestUtils.createBrokerConfig(0, "127.0.0.1:1", port = -1) + config = KafkaConfig.fromProps(props) + } + + @AfterEach + def tearDown(): Unit = { + brokerTopicStats.close() + Utils.delete(tmpDir) + } + + @Test + def testLogRecoveryIsCalledUponBrokerCrash(): Unit = { + // LogManager must realize correctly if the last shutdown was not clean and the logs need + // to run recovery while loading upon subsequent broker boot up. 
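+    // The test exercises this via the clean shutdown marker file (LogLoader.CleanShutdownFile): creating the
+    // marker simulates a clean shutdown, and LogManager.loadLogs is expected to consume (delete) it and pass
+    // the resulting hadCleanShutdown flag down to loadLog, which is intercepted below.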
+ val logDir: File = TestUtils.tempDir() + val logProps = new Properties() + val logConfig = LogConfig(logProps) + val logDirs = Seq(logDir) + val topicPartition = new TopicPartition("foo", 0) + var log: UnifiedLog = null + val time = new MockTime() + var cleanShutdownInterceptedValue = false + case class SimulateError(var hasError: Boolean = false) + val simulateError = SimulateError() + + // Create a LogManager with some overridden methods to facilitate interception of clean shutdown + // flag and to inject a runtime error + def interceptedLogManager(logConfig: LogConfig, logDirs: Seq[File], simulateError: SimulateError): LogManager = + new LogManager( + logDirs = logDirs.map(_.getAbsoluteFile), + initialOfflineDirs = Array.empty[File], + configRepository = new MockConfigRepository(), + initialDefaultConfig = logConfig, + cleanerConfig = CleanerConfig(enableCleaner = false), + recoveryThreadsPerDataDir = 4, + flushCheckMs = 1000L, + flushRecoveryOffsetCheckpointMs = 10000L, + flushStartOffsetCheckpointMs = 10000L, + retentionCheckMs = 1000L, + maxPidExpirationMs = 60 * 60 * 1000, + interBrokerProtocolVersion = config.interBrokerProtocolVersion, + scheduler = time.scheduler, + brokerTopicStats = new BrokerTopicStats(), + logDirFailureChannel = new LogDirFailureChannel(logDirs.size), + time = time, + keepPartitionMetadataFile = config.usesTopicId) { + + override def loadLog(logDir: File, hadCleanShutdown: Boolean, recoveryPoints: Map[TopicPartition, Long], + logStartOffsets: Map[TopicPartition, Long], defaultConfig: LogConfig, + topicConfigs: Map[String, LogConfig]): UnifiedLog = { + if (simulateError.hasError) { + throw new RuntimeException("Simulated error") + } + cleanShutdownInterceptedValue = hadCleanShutdown + val topicPartition = UnifiedLog.parseTopicPartitionName(logDir) + val config = topicConfigs.getOrElse(topicPartition.topic, defaultConfig) + val logRecoveryPoint = recoveryPoints.getOrElse(topicPartition, 0L) + val logStartOffset = logStartOffsets.getOrElse(topicPartition, 0L) + val logDirFailureChannel: LogDirFailureChannel = new LogDirFailureChannel(1) + val maxProducerIdExpirationMs = 60 * 60 * 1000 + val segments = new LogSegments(topicPartition) + val leaderEpochCache = UnifiedLog.maybeCreateLeaderEpochCache(logDir, topicPartition, logDirFailureChannel, config.recordVersion, "") + val producerStateManager = new ProducerStateManager(topicPartition, logDir, maxProducerIdExpirationMs, time) + val loadLogParams = LoadLogParams(logDir, topicPartition, config, time.scheduler, time, + logDirFailureChannel, hadCleanShutdown, segments, logStartOffset, logRecoveryPoint, + maxProducerIdExpirationMs, leaderEpochCache, producerStateManager) + val offsets = LogLoader.load(loadLogParams) + val localLog = new LocalLog(logDir, logConfig, segments, offsets.recoveryPoint, + offsets.nextOffsetMetadata, mockTime.scheduler, mockTime, topicPartition, + logDirFailureChannel) + new UnifiedLog(offsets.logStartOffset, localLog, brokerTopicStats, + LogManager.ProducerIdExpirationCheckIntervalMs, leaderEpochCache, + producerStateManager, None, true) + } + } + + val cleanShutdownFile = new File(logDir, LogLoader.CleanShutdownFile) + locally { + val logManager: LogManager = interceptedLogManager(logConfig, logDirs, simulateError) + log = logManager.getOrCreateLog(topicPartition, isNew = true, topicId = None) + + // Load logs after a clean shutdown + Files.createFile(cleanShutdownFile.toPath) + cleanShutdownInterceptedValue = false + var defaultConfig = logManager.currentDefaultConfig + 
logManager.loadLogs(defaultConfig, logManager.fetchTopicConfigOverrides(defaultConfig, Set.empty)) + assertTrue(cleanShutdownInterceptedValue, "Unexpected value intercepted for clean shutdown flag") + assertFalse(cleanShutdownFile.exists(), "Clean shutdown file must not exist after loadLogs has completed") + // Load logs without clean shutdown file + cleanShutdownInterceptedValue = true + defaultConfig = logManager.currentDefaultConfig + logManager.loadLogs(defaultConfig, logManager.fetchTopicConfigOverrides(defaultConfig, Set.empty)) + assertFalse(cleanShutdownInterceptedValue, "Unexpected value intercepted for clean shutdown flag") + assertFalse(cleanShutdownFile.exists(), "Clean shutdown file must not exist after loadLogs has completed") + // Create clean shutdown file and then simulate error while loading logs such that log loading does not complete. + Files.createFile(cleanShutdownFile.toPath) + logManager.shutdown() + } + + locally { + simulateError.hasError = true + val logManager: LogManager = interceptedLogManager(logConfig, logDirs, simulateError) + log = logManager.getOrCreateLog(topicPartition, isNew = true, topicId = None) + + // Simulate error + assertThrows(classOf[RuntimeException], () => { + val defaultConfig = logManager.currentDefaultConfig + logManager.loadLogs(defaultConfig, logManager.fetchTopicConfigOverrides(defaultConfig, Set.empty)) + }) + assertFalse(cleanShutdownFile.exists(), "Clean shutdown file must not have existed") + // Do not simulate error on next call to LogManager#loadLogs. LogManager must understand that log had unclean shutdown the last time. + simulateError.hasError = false + cleanShutdownInterceptedValue = true + val defaultConfig = logManager.currentDefaultConfig + logManager.loadLogs(defaultConfig, logManager.fetchTopicConfigOverrides(defaultConfig, Set.empty)) + assertFalse(cleanShutdownInterceptedValue, "Unexpected value for clean shutdown flag") + } + } + + @Test + def testProducerSnapshotsRecoveryAfterUncleanShutdownV1(): Unit = { + testProducerSnapshotsRecoveryAfterUncleanShutdown(ApiVersion.minSupportedFor(RecordVersion.V1).version) + } + + @Test + def testProducerSnapshotsRecoveryAfterUncleanShutdownCurrentMessageFormat(): Unit = { + testProducerSnapshotsRecoveryAfterUncleanShutdown(ApiVersion.latestVersion.version) + } + + private def createLog(dir: File, + config: LogConfig, + brokerTopicStats: BrokerTopicStats = brokerTopicStats, + logStartOffset: Long = 0L, + recoveryPoint: Long = 0L, + scheduler: Scheduler = mockTime.scheduler, + time: Time = mockTime, + maxProducerIdExpirationMs: Int = maxProducerIdExpirationMs, + producerIdExpirationCheckIntervalMs: Int = LogManager.ProducerIdExpirationCheckIntervalMs, + lastShutdownClean: Boolean = true): UnifiedLog = { + LogTestUtils.createLog(dir, config, brokerTopicStats, scheduler, time, logStartOffset, recoveryPoint, + maxProducerIdExpirationMs, producerIdExpirationCheckIntervalMs, lastShutdownClean) + } + + private def createLogWithOffsetOverflow(logConfig: LogConfig): (UnifiedLog, LogSegment) = { + LogTestUtils.initializeLogDirWithOverflowedSegment(logDir) + + val log = createLog(logDir, logConfig, recoveryPoint = Long.MaxValue) + val segmentWithOverflow = LogTestUtils.firstOverflowSegment(log).getOrElse { + throw new AssertionError("Failed to create log with a segment which has overflowed offsets") + } + + (log, segmentWithOverflow) + } + + private def recoverAndCheck(config: LogConfig, expectedKeys: Iterable[Long]): UnifiedLog = { + // method is called only in case of recovery from hard 
reset + LogTestUtils.recoverAndCheck(logDir, config, expectedKeys, brokerTopicStats, mockTime, mockTime.scheduler) + } + + /** + * Wrap a single record log buffer with leader epoch. + */ + private def singletonRecordsWithLeaderEpoch(value: Array[Byte], + key: Array[Byte] = null, + leaderEpoch: Int, + offset: Long, + codec: CompressionType = CompressionType.NONE, + timestamp: Long = RecordBatch.NO_TIMESTAMP, + magicValue: Byte = RecordBatch.CURRENT_MAGIC_VALUE): MemoryRecords = { + val records = Seq(new SimpleRecord(timestamp, key, value)) + + val buf = ByteBuffer.allocate(DefaultRecordBatch.sizeInBytes(records.asJava)) + val builder = MemoryRecords.builder(buf, magicValue, codec, TimestampType.CREATE_TIME, offset, + mockTime.milliseconds, leaderEpoch) + records.foreach(builder.append) + builder.build() + } + + @nowarn("cat=deprecation") + private def testProducerSnapshotsRecoveryAfterUncleanShutdown(messageFormatVersion: String): Unit = { + val logProps = new Properties() + logProps.put(LogConfig.SegmentBytesProp, "640") + logProps.put(LogConfig.MessageFormatVersionProp, messageFormatVersion) + val logConfig = LogConfig(logProps) + var log = createLog(logDir, logConfig) + assertEquals(None, log.oldestProducerSnapshotOffset) + + for (i <- 0 to 100) { + val record = new SimpleRecord(mockTime.milliseconds, i.toString.getBytes) + log.appendAsLeader(TestUtils.records(List(record)), leaderEpoch = 0) + } + + assertTrue(log.logSegments.size >= 5) + val segmentOffsets = log.logSegments.toVector.map(_.baseOffset) + val activeSegmentOffset = segmentOffsets.last + + // We want the recovery point to be past the segment offset and before the last 2 segments including a gap of + // 1 segment. We collect the data before closing the log. + val offsetForSegmentAfterRecoveryPoint = segmentOffsets(segmentOffsets.size - 3) + val offsetForRecoveryPointSegment = segmentOffsets(segmentOffsets.size - 4) + val (segOffsetsBeforeRecovery, segOffsetsAfterRecovery) = segmentOffsets.toSet.partition(_ < offsetForRecoveryPointSegment) + val recoveryPoint = offsetForRecoveryPointSegment + 1 + assertTrue(recoveryPoint < offsetForSegmentAfterRecoveryPoint) + log.close() + + val segmentsWithReads = mutable.Set[LogSegment]() + val recoveredSegments = mutable.Set[LogSegment]() + val expectedSegmentsWithReads = mutable.Set[Long]() + val expectedSnapshotOffsets = mutable.Set[Long]() + + if (logConfig.messageFormatVersion < KAFKA_0_11_0_IV0) { + expectedSegmentsWithReads += activeSegmentOffset + expectedSnapshotOffsets ++= log.logSegments.map(_.baseOffset).toVector.takeRight(2) :+ log.logEndOffset + } else { + expectedSegmentsWithReads ++= segOffsetsBeforeRecovery ++ Set(activeSegmentOffset) + expectedSnapshotOffsets ++= log.logSegments.map(_.baseOffset).toVector.takeRight(4) :+ log.logEndOffset + } + + def createLogWithInterceptedReads(recoveryPoint: Long) = { + val maxProducerIdExpirationMs = 60 * 60 * 1000 + val topicPartition = UnifiedLog.parseTopicPartitionName(logDir) + val logDirFailureChannel = new LogDirFailureChannel(10) + // Intercept all segment read calls + val interceptedLogSegments = new LogSegments(topicPartition) { + override def add(segment: LogSegment): LogSegment = { + val wrapper = new LogSegment(segment.log, segment.lazyOffsetIndex, segment.lazyTimeIndex, segment.txnIndex, segment.baseOffset, + segment.indexIntervalBytes, segment.rollJitterMs, mockTime) { + + override def read(startOffset: Long, maxSize: Int, maxPosition: Long, minOneMessage: Boolean): FetchDataInfo = { + segmentsWithReads += this + 
super.read(startOffset, maxSize, maxPosition, minOneMessage) + } + + override def recover(producerStateManager: ProducerStateManager, + leaderEpochCache: Option[LeaderEpochFileCache]): Int = { + recoveredSegments += this + super.recover(producerStateManager, leaderEpochCache) + } + } + super.add(wrapper) + } + } + val leaderEpochCache = UnifiedLog.maybeCreateLeaderEpochCache(logDir, topicPartition, logDirFailureChannel, logConfig.recordVersion, "") + val producerStateManager = new ProducerStateManager(topicPartition, logDir, maxProducerIdExpirationMs, mockTime) + val loadLogParams = LoadLogParams( + logDir, + topicPartition, + logConfig, + mockTime.scheduler, + mockTime, + logDirFailureChannel, + hadCleanShutdown = false, + interceptedLogSegments, + 0L, + recoveryPoint, + maxProducerIdExpirationMs, + leaderEpochCache, + producerStateManager) + val offsets = LogLoader.load(loadLogParams) + val localLog = new LocalLog(logDir, logConfig, interceptedLogSegments, offsets.recoveryPoint, + offsets.nextOffsetMetadata, mockTime.scheduler, mockTime, topicPartition, + logDirFailureChannel) + new UnifiedLog(offsets.logStartOffset, localLog, brokerTopicStats, + LogManager.ProducerIdExpirationCheckIntervalMs, leaderEpochCache, producerStateManager, + None, keepPartitionMetadataFile = true) + } + + // Retain snapshots for the last 2 segments + log.producerStateManager.deleteSnapshotsBefore(segmentOffsets(segmentOffsets.size - 2)) + log = createLogWithInterceptedReads(offsetForRecoveryPointSegment) + // We will reload all segments because the recovery point is behind the producer snapshot files (pre KAFKA-5829 behaviour) + assertEquals(expectedSegmentsWithReads, segmentsWithReads.map(_.baseOffset)) + assertEquals(segOffsetsAfterRecovery, recoveredSegments.map(_.baseOffset)) + assertEquals(expectedSnapshotOffsets, LogTestUtils.listProducerSnapshotOffsets(logDir).toSet) + log.close() + segmentsWithReads.clear() + recoveredSegments.clear() + + // Only delete snapshots before the base offset of the recovery point segment (post KAFKA-5829 behaviour) to + // avoid reading all segments + log.producerStateManager.deleteSnapshotsBefore(offsetForRecoveryPointSegment) + log = createLogWithInterceptedReads(recoveryPoint = recoveryPoint) + assertEquals(Set(activeSegmentOffset), segmentsWithReads.map(_.baseOffset)) + assertEquals(segOffsetsAfterRecovery, recoveredSegments.map(_.baseOffset)) + assertEquals(expectedSnapshotOffsets, LogTestUtils.listProducerSnapshotOffsets(logDir).toSet) + + log.close() + } + + @Test + def testSkipLoadingIfEmptyProducerStateBeforeTruncation(): Unit = { + val stateManager: ProducerStateManager = EasyMock.mock(classOf[ProducerStateManager]) + EasyMock.expect(stateManager.removeStraySnapshots(EasyMock.anyObject())).anyTimes() + // Load the log + EasyMock.expect(stateManager.latestSnapshotOffset).andReturn(None) + + stateManager.updateMapEndOffset(0L) + EasyMock.expectLastCall().anyTimes() + + EasyMock.expect(stateManager.mapEndOffset).andStubReturn(0L) + EasyMock.expect(stateManager.isEmpty).andStubReturn(true) + + stateManager.takeSnapshot() + EasyMock.expectLastCall().anyTimes() + + stateManager.truncateAndReload(EasyMock.eq(0L), EasyMock.eq(0L), EasyMock.anyLong) + EasyMock.expectLastCall() + + EasyMock.expect(stateManager.firstUnstableOffset).andStubReturn(None) + + EasyMock.replay(stateManager) + + val topicPartition = UnifiedLog.parseTopicPartitionName(logDir) + val logDirFailureChannel: LogDirFailureChannel = new LogDirFailureChannel(1) + val config = LogConfig(new Properties()) + val 
maxProducerIdExpirationMs = 300000 + val segments = new LogSegments(topicPartition) + val leaderEpochCache = UnifiedLog.maybeCreateLeaderEpochCache(logDir, topicPartition, logDirFailureChannel, config.recordVersion, "") + val offsets = LogLoader.load(LoadLogParams( + logDir, + topicPartition, + config, + mockTime.scheduler, + mockTime, + logDirFailureChannel, + hadCleanShutdown = false, + segments, + 0L, + 0L, + maxProducerIdExpirationMs, + leaderEpochCache, + stateManager)) + val localLog = new LocalLog(logDir, config, segments, offsets.recoveryPoint, + offsets.nextOffsetMetadata, mockTime.scheduler, mockTime, topicPartition, + logDirFailureChannel) + val log = new UnifiedLog(offsets.logStartOffset, + localLog, + brokerTopicStats = brokerTopicStats, + producerIdExpirationCheckIntervalMs = 30000, + leaderEpochCache = leaderEpochCache, + producerStateManager = stateManager, + _topicId = None, + keepPartitionMetadataFile = true) + + EasyMock.verify(stateManager) + + // Append some messages + EasyMock.reset(stateManager) + EasyMock.expect(stateManager.firstUnstableOffset).andStubReturn(None) + + stateManager.updateMapEndOffset(1L) + EasyMock.expectLastCall() + stateManager.updateMapEndOffset(2L) + EasyMock.expectLastCall() + + EasyMock.replay(stateManager) + + log.appendAsLeader(TestUtils.records(List(new SimpleRecord("a".getBytes))), leaderEpoch = 0) + log.appendAsLeader(TestUtils.records(List(new SimpleRecord("b".getBytes))), leaderEpoch = 0) + + EasyMock.verify(stateManager) + + // Now truncate + EasyMock.reset(stateManager) + EasyMock.expect(stateManager.firstUnstableOffset).andStubReturn(None) + EasyMock.expect(stateManager.latestSnapshotOffset).andReturn(None) + EasyMock.expect(stateManager.isEmpty).andStubReturn(true) + EasyMock.expect(stateManager.mapEndOffset).andReturn(2L) + stateManager.truncateAndReload(EasyMock.eq(0L), EasyMock.eq(1L), EasyMock.anyLong) + EasyMock.expectLastCall() + // Truncation causes the map end offset to reset to 0 + EasyMock.expect(stateManager.mapEndOffset).andReturn(0L) + // We skip directly to updating the map end offset + EasyMock.expect(stateManager.updateMapEndOffset(1L)) + + // Finally, we take a snapshot + stateManager.takeSnapshot() + EasyMock.expectLastCall().once() + + EasyMock.replay(stateManager) + + log.truncateTo(1L) + + EasyMock.verify(stateManager) + } + + @Test + def testRecoverAfterNonMonotonicCoordinatorEpochWrite(): Unit = { + // Due to KAFKA-9144, we may encounter a coordinator epoch which goes backwards. + // This test case verifies that recovery logic relaxes validation in this case and + // just takes the latest write. 
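+    // The second end-transaction marker below is written with coordinatorEpoch - 1; after reopening the log
+    // following the unclean shutdown, recovery should accept it and report the later write rather than
+    // rejecting the non-monotonic coordinator epoch.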
+ + val producerId = 1L + val coordinatorEpoch = 5 + val logConfig = LogTestUtils.createLogConfig(segmentBytes = 1024 * 1024 * 5) + var log = createLog(logDir, logConfig) + val epoch = 0.toShort + + val firstAppendTimestamp = mockTime.milliseconds() + LogTestUtils.appendEndTxnMarkerAsLeader(log, producerId, epoch, ControlRecordType.ABORT, + firstAppendTimestamp, coordinatorEpoch = coordinatorEpoch) + assertEquals(firstAppendTimestamp, log.producerStateManager.lastEntry(producerId).get.lastTimestamp) + + val maxProducerIdExpirationMs = 60 * 60 * 1000 + mockTime.sleep(maxProducerIdExpirationMs) + assertEquals(None, log.producerStateManager.lastEntry(producerId)) + + val secondAppendTimestamp = mockTime.milliseconds() + LogTestUtils.appendEndTxnMarkerAsLeader(log, producerId, epoch, ControlRecordType.ABORT, + secondAppendTimestamp, coordinatorEpoch = coordinatorEpoch - 1) + + log.close() + + // Force recovery by setting the recoveryPoint to the log start + log = createLog(logDir, logConfig, recoveryPoint = 0L, lastShutdownClean = false) + assertEquals(secondAppendTimestamp, log.producerStateManager.lastEntry(producerId).get.lastTimestamp) + log.close() + } + + @nowarn("cat=deprecation") + @Test + def testSkipTruncateAndReloadIfOldMessageFormatAndNoCleanShutdown(): Unit = { + val stateManager: ProducerStateManager = EasyMock.mock(classOf[ProducerStateManager]) + EasyMock.expect(stateManager.removeStraySnapshots(EasyMock.anyObject())).anyTimes() + + stateManager.updateMapEndOffset(0L) + EasyMock.expectLastCall().anyTimes() + + stateManager.takeSnapshot() + EasyMock.expectLastCall().anyTimes() + + EasyMock.expect(stateManager.isEmpty).andReturn(true) + EasyMock.expectLastCall().once() + + EasyMock.expect(stateManager.firstUnstableOffset).andReturn(None) + EasyMock.expectLastCall().once() + + EasyMock.replay(stateManager) + + val topicPartition = UnifiedLog.parseTopicPartitionName(logDir) + val logProps = new Properties() + logProps.put(LogConfig.MessageFormatVersionProp, "0.10.2") + val config = LogConfig(logProps) + val maxProducerIdExpirationMs = 300000 + val logDirFailureChannel = null + val segments = new LogSegments(topicPartition) + val leaderEpochCache = UnifiedLog.maybeCreateLeaderEpochCache(logDir, topicPartition, logDirFailureChannel, config.recordVersion, "") + val offsets = LogLoader.load(LoadLogParams( + logDir, + topicPartition, + config, + mockTime.scheduler, + mockTime, + logDirFailureChannel, + hadCleanShutdown = false, + segments, + 0L, + 0L, + maxProducerIdExpirationMs, + leaderEpochCache, + stateManager)) + val localLog = new LocalLog(logDir, config, segments, offsets.recoveryPoint, + offsets.nextOffsetMetadata, mockTime.scheduler, mockTime, topicPartition, + logDirFailureChannel) + new UnifiedLog(offsets.logStartOffset, + localLog, + brokerTopicStats = brokerTopicStats, + producerIdExpirationCheckIntervalMs = 30000, + leaderEpochCache = leaderEpochCache, + producerStateManager = stateManager, + _topicId = None, + keepPartitionMetadataFile = true) + + EasyMock.verify(stateManager) + } + + @nowarn("cat=deprecation") + @Test + def testSkipTruncateAndReloadIfOldMessageFormatAndCleanShutdown(): Unit = { + val stateManager: ProducerStateManager = EasyMock.mock(classOf[ProducerStateManager]) + EasyMock.expect(stateManager.removeStraySnapshots(EasyMock.anyObject())).anyTimes() + + stateManager.updateMapEndOffset(0L) + EasyMock.expectLastCall().anyTimes() + + stateManager.takeSnapshot() + EasyMock.expectLastCall().anyTimes() + + EasyMock.expect(stateManager.isEmpty).andReturn(true) + 
EasyMock.expectLastCall().once() + + EasyMock.expect(stateManager.firstUnstableOffset).andReturn(None) + EasyMock.expectLastCall().once() + + EasyMock.replay(stateManager) + + val topicPartition = UnifiedLog.parseTopicPartitionName(logDir) + val logProps = new Properties() + logProps.put(LogConfig.MessageFormatVersionProp, "0.10.2") + val config = LogConfig(logProps) + val maxProducerIdExpirationMs = 300000 + val logDirFailureChannel = null + val segments = new LogSegments(topicPartition) + val leaderEpochCache = UnifiedLog.maybeCreateLeaderEpochCache(logDir, topicPartition, logDirFailureChannel, config.recordVersion, "") + val offsets = LogLoader.load(LoadLogParams( + logDir, + topicPartition, + config, + mockTime.scheduler, + mockTime, + logDirFailureChannel, + hadCleanShutdown = true, + segments, + 0L, + 0L, + maxProducerIdExpirationMs, + leaderEpochCache, + stateManager)) + val localLog = new LocalLog(logDir, config, segments, offsets.recoveryPoint, + offsets.nextOffsetMetadata, mockTime.scheduler, mockTime, topicPartition, + logDirFailureChannel) + new UnifiedLog(offsets.logStartOffset, + localLog, + brokerTopicStats = brokerTopicStats, + producerIdExpirationCheckIntervalMs = 30000, + leaderEpochCache = leaderEpochCache, + producerStateManager = stateManager, + _topicId = None, + keepPartitionMetadataFile = true) + + EasyMock.verify(stateManager) + } + + @nowarn("cat=deprecation") + @Test + def testSkipTruncateAndReloadIfNewMessageFormatAndCleanShutdown(): Unit = { + val stateManager: ProducerStateManager = EasyMock.mock(classOf[ProducerStateManager]) + EasyMock.expect(stateManager.removeStraySnapshots(EasyMock.anyObject())).anyTimes() + + EasyMock.expect(stateManager.latestSnapshotOffset).andReturn(None) + + stateManager.updateMapEndOffset(0L) + EasyMock.expectLastCall().anyTimes() + + stateManager.takeSnapshot() + EasyMock.expectLastCall().anyTimes() + + EasyMock.expect(stateManager.isEmpty).andReturn(true) + EasyMock.expectLastCall().once() + + EasyMock.expect(stateManager.firstUnstableOffset).andReturn(None) + EasyMock.expectLastCall().once() + + EasyMock.replay(stateManager) + + val topicPartition = UnifiedLog.parseTopicPartitionName(logDir) + val logProps = new Properties() + logProps.put(LogConfig.MessageFormatVersionProp, "0.11.0") + val config = LogConfig(logProps) + val maxProducerIdExpirationMs = 300000 + val logDirFailureChannel = null + val segments = new LogSegments(topicPartition) + val leaderEpochCache = UnifiedLog.maybeCreateLeaderEpochCache(logDir, topicPartition, logDirFailureChannel, config.recordVersion, "") + val offsets = LogLoader.load(LoadLogParams( + logDir, + topicPartition, + config, + mockTime.scheduler, + mockTime, + logDirFailureChannel, + hadCleanShutdown = true, + segments, + 0L, + 0L, + maxProducerIdExpirationMs, + leaderEpochCache, + stateManager)) + val localLog = new LocalLog(logDir, config, segments, offsets.recoveryPoint, + offsets.nextOffsetMetadata, mockTime.scheduler, mockTime, topicPartition, + logDirFailureChannel) + new UnifiedLog(offsets.logStartOffset, + localLog, + brokerTopicStats = brokerTopicStats, + producerIdExpirationCheckIntervalMs = 30000, + leaderEpochCache = leaderEpochCache, + producerStateManager = stateManager, + _topicId = None, + keepPartitionMetadataFile = true) + + EasyMock.verify(stateManager) + } + + @Test + def testLoadProducersAfterDeleteRecordsMidSegment(): Unit = { + val logConfig = LogTestUtils.createLogConfig(segmentBytes = 2048 * 5) + val log = createLog(logDir, logConfig) + val pid1 = 1L + val pid2 = 2L + val 
epoch = 0.toShort + + log.appendAsLeader(TestUtils.records(List(new SimpleRecord(mockTime.milliseconds(), "a".getBytes)), producerId = pid1, + producerEpoch = epoch, sequence = 0), leaderEpoch = 0) + log.appendAsLeader(TestUtils.records(List(new SimpleRecord(mockTime.milliseconds(), "b".getBytes)), producerId = pid2, + producerEpoch = epoch, sequence = 0), leaderEpoch = 0) + assertEquals(2, log.activeProducersWithLastSequence.size) + + log.updateHighWatermark(log.logEndOffset) + log.maybeIncrementLogStartOffset(1L, ClientRecordDeletion) + + // Deleting records should not remove producer state + assertEquals(2, log.activeProducersWithLastSequence.size) + val retainedLastSeqOpt = log.activeProducersWithLastSequence.get(pid2) + assertTrue(retainedLastSeqOpt.isDefined) + assertEquals(0, retainedLastSeqOpt.get) + + log.close() + + // Because the log start offset did not advance, producer snapshots will still be present and the state will be rebuilt + val reloadedLog = createLog(logDir, logConfig, logStartOffset = 1L, lastShutdownClean = false) + assertEquals(2, reloadedLog.activeProducersWithLastSequence.size) + val reloadedLastSeqOpt = log.activeProducersWithLastSequence.get(pid2) + assertEquals(retainedLastSeqOpt, reloadedLastSeqOpt) + } + + @Test + def testLoadingLogKeepsLargestStrayProducerStateSnapshot(): Unit = { + val logConfig = LogTestUtils.createLogConfig(segmentBytes = 2048 * 5, retentionBytes = 0, retentionMs = 1000 * 60, fileDeleteDelayMs = 0) + val log = createLog(logDir, logConfig) + val pid1 = 1L + val epoch = 0.toShort + + log.appendAsLeader(TestUtils.records(List(new SimpleRecord("a".getBytes)), producerId = pid1, producerEpoch = epoch, sequence = 0), leaderEpoch = 0) + log.roll() + log.appendAsLeader(TestUtils.records(List(new SimpleRecord("b".getBytes)), producerId = pid1, producerEpoch = epoch, sequence = 1), leaderEpoch = 0) + log.roll() + + log.appendAsLeader(TestUtils.records(List(new SimpleRecord("c".getBytes)), producerId = pid1, producerEpoch = epoch, sequence = 2), leaderEpoch = 0) + log.appendAsLeader(TestUtils.records(List(new SimpleRecord("d".getBytes)), producerId = pid1, producerEpoch = epoch, sequence = 3), leaderEpoch = 0) + + // Close the log, we should now have 3 segments + log.close() + assertEquals(log.logSegments.size, 3) + // We expect 3 snapshot files, two of which are for the first two segments, the last was written out during log closing. 
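+    // (i.e. snapshots at offsets 1 and 2 written by the two roll() calls above, plus one at the log end
+    // offset 4 written when the log was closed)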
+ assertEquals(Seq(1, 2, 4), ProducerStateManager.listSnapshotFiles(logDir).map(_.offset).sorted) + // Inject a stray snapshot file within the bounds of the log at offset 3, it should be cleaned up after loading the log + val straySnapshotFile = UnifiedLog.producerSnapshotFile(logDir, 3).toPath + Files.createFile(straySnapshotFile) + assertEquals(Seq(1, 2, 3, 4), ProducerStateManager.listSnapshotFiles(logDir).map(_.offset).sorted) + + createLog(logDir, logConfig, lastShutdownClean = false) + // We should clean up the stray producer state snapshot file, but keep the largest snapshot file (4) + assertEquals(Seq(1, 2, 4), ProducerStateManager.listSnapshotFiles(logDir).map(_.offset).sorted) + } + + @Test + def testLoadProducersAfterDeleteRecordsOnSegment(): Unit = { + val logConfig = LogTestUtils.createLogConfig(segmentBytes = 2048 * 5) + val log = createLog(logDir, logConfig) + val pid1 = 1L + val pid2 = 2L + val epoch = 0.toShort + + log.appendAsLeader(TestUtils.records(List(new SimpleRecord(mockTime.milliseconds(), "a".getBytes)), producerId = pid1, + producerEpoch = epoch, sequence = 0), leaderEpoch = 0) + log.roll() + log.appendAsLeader(TestUtils.records(List(new SimpleRecord(mockTime.milliseconds(), "b".getBytes)), producerId = pid2, + producerEpoch = epoch, sequence = 0), leaderEpoch = 0) + + assertEquals(2, log.logSegments.size) + assertEquals(2, log.activeProducersWithLastSequence.size) + + log.updateHighWatermark(log.logEndOffset) + log.maybeIncrementLogStartOffset(1L, ClientRecordDeletion) + log.deleteOldSegments() + + // Deleting records should not remove producer state + assertEquals(1, log.logSegments.size) + assertEquals(2, log.activeProducersWithLastSequence.size) + val retainedLastSeqOpt = log.activeProducersWithLastSequence.get(pid2) + assertTrue(retainedLastSeqOpt.isDefined) + assertEquals(0, retainedLastSeqOpt.get) + + log.close() + + // After reloading log, producer state should not be regenerated + val reloadedLog = createLog(logDir, logConfig, logStartOffset = 1L, lastShutdownClean = false) + assertEquals(1, reloadedLog.activeProducersWithLastSequence.size) + val reloadedEntryOpt = log.activeProducersWithLastSequence.get(pid2) + assertEquals(retainedLastSeqOpt, reloadedEntryOpt) + } + + /** + * Append a bunch of messages to a log and then re-open it both with and without recovery and check that the log re-initializes correctly. + */ + @Test + def testLogRecoversToCorrectOffset(): Unit = { + val numMessages = 100 + val messageSize = 100 + val segmentSize = 7 * messageSize + val indexInterval = 3 * messageSize + val logConfig = LogTestUtils.createLogConfig(segmentBytes = segmentSize, indexIntervalBytes = indexInterval, segmentIndexBytes = 4096) + var log = createLog(logDir, logConfig) + for(i <- 0 until numMessages) + log.appendAsLeader(TestUtils.singletonRecords(value = TestUtils.randomBytes(messageSize), + timestamp = mockTime.milliseconds + i * 10), leaderEpoch = 0) + assertEquals(numMessages, log.logEndOffset, + "After appending %d messages to an empty log, the log end offset should be %d".format(numMessages, numMessages)) + val lastIndexOffset = log.activeSegment.offsetIndex.lastOffset + val numIndexEntries = log.activeSegment.offsetIndex.entries + val lastOffset = log.logEndOffset + // After segment is closed, the last entry in the time index should be (largest timestamp -> last offset). 
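+    // (closing the segment may append one final time index entry for the largest timestamp; the
+    // numTimeIndexEntries calculation below allows for both the case where that entry is appended and the
+    // case where an entry for the last offset already exists)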
+ val lastTimeIndexOffset = log.logEndOffset - 1 + val lastTimeIndexTimestamp = log.activeSegment.largestTimestamp + // Depending on when the last time index entry is inserted, an entry may or may not be inserted into the time index. + val numTimeIndexEntries = log.activeSegment.timeIndex.entries + { + if (log.activeSegment.timeIndex.lastEntry.offset == log.logEndOffset - 1) 0 else 1 + } + log.close() + + def verifyRecoveredLog(log: UnifiedLog, expectedRecoveryPoint: Long): Unit = { + assertEquals(expectedRecoveryPoint, log.recoveryPoint, s"Unexpected recovery point") + assertEquals(numMessages, log.logEndOffset, s"Should have $numMessages messages when log is reopened w/o recovery") + assertEquals(lastIndexOffset, log.activeSegment.offsetIndex.lastOffset, "Should have same last index offset as before.") + assertEquals(numIndexEntries, log.activeSegment.offsetIndex.entries, "Should have same number of index entries as before.") + assertEquals(lastTimeIndexTimestamp, log.activeSegment.timeIndex.lastEntry.timestamp, "Should have same last time index timestamp") + assertEquals(lastTimeIndexOffset, log.activeSegment.timeIndex.lastEntry.offset, "Should have same last time index offset") + assertEquals(numTimeIndexEntries, log.activeSegment.timeIndex.entries, "Should have same number of time index entries as before.") + } + + log = createLog(logDir, logConfig, recoveryPoint = lastOffset, lastShutdownClean = false) + verifyRecoveredLog(log, lastOffset) + log.close() + + // test recovery case + val recoveryPoint = 10 + log = createLog(logDir, logConfig, recoveryPoint = recoveryPoint, lastShutdownClean = false) + // the recovery point should not be updated after unclean shutdown until the log is flushed + verifyRecoveredLog(log, recoveryPoint) + log.flush() + verifyRecoveredLog(log, lastOffset) + log.close() + } + + /** + * Test that if we manually delete an index segment it is rebuilt when the log is re-opened + */ + @Test + def testIndexRebuild(): Unit = { + // publish the messages and close the log + val numMessages = 200 + val logConfig = LogTestUtils.createLogConfig(segmentBytes = 200, indexIntervalBytes = 1) + var log = createLog(logDir, logConfig) + for(i <- 0 until numMessages) + log.appendAsLeader(TestUtils.singletonRecords(value = TestUtils.randomBytes(10), timestamp = mockTime.milliseconds + i * 10), leaderEpoch = 0) + val indexFiles = log.logSegments.map(_.lazyOffsetIndex.file) + val timeIndexFiles = log.logSegments.map(_.lazyTimeIndex.file) + log.close() + + // delete all the index files + indexFiles.foreach(_.delete()) + timeIndexFiles.foreach(_.delete()) + + // reopen the log + log = createLog(logDir, logConfig, lastShutdownClean = false) + assertEquals(numMessages, log.logEndOffset, "Should have %d messages when log is reopened".format(numMessages)) + assertTrue(log.logSegments.head.offsetIndex.entries > 0, "The index should have been rebuilt") + assertTrue(log.logSegments.head.timeIndex.entries > 0, "The time index should have been rebuilt") + for(i <- 0 until numMessages) { + assertEquals(i, LogTestUtils.readLog(log, i, 100).records.batches.iterator.next().lastOffset) + if (i == 0) + assertEquals(log.logSegments.head.baseOffset, log.fetchOffsetByTimestamp(mockTime.milliseconds + i * 10).get.offset) + else + assertEquals(i, log.fetchOffsetByTimestamp(mockTime.milliseconds + i * 10).get.offset) + } + log.close() + } + + /** + * Test that if messages format version of the messages in a segment is before 0.10.0, the time index should be empty. 
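+   * (With message.format.version set to 0.9.0 the appended records end up in magic v0 format, which carries no
+   * timestamps, so rebuilding the index should produce no time index entries.)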
+ */ + @nowarn("cat=deprecation") + @Test + def testRebuildTimeIndexForOldMessages(): Unit = { + val numMessages = 200 + val segmentSize = 200 + val logProps = new Properties() + logProps.put(LogConfig.SegmentBytesProp, segmentSize.toString) + logProps.put(LogConfig.IndexIntervalBytesProp, "1") + logProps.put(LogConfig.MessageFormatVersionProp, "0.9.0") + val logConfig = LogConfig(logProps) + var log = createLog(logDir, logConfig) + for (i <- 0 until numMessages) + log.appendAsLeader(TestUtils.singletonRecords(value = TestUtils.randomBytes(10), + timestamp = mockTime.milliseconds + i * 10, magicValue = RecordBatch.MAGIC_VALUE_V1), leaderEpoch = 0) + val timeIndexFiles = log.logSegments.map(_.lazyTimeIndex.file) + log.close() + + // Delete the time index. + timeIndexFiles.foreach(file => Files.delete(file.toPath)) + + // The rebuilt time index should be empty + log = createLog(logDir, logConfig, recoveryPoint = numMessages + 1, lastShutdownClean = false) + for (segment <- log.logSegments.init) { + assertEquals(0, segment.timeIndex.entries, "The time index should be empty") + assertEquals(0, segment.lazyTimeIndex.file.length, "The time index file size should be 0") + } + } + + + /** + * Test that if we have corrupted an index segment it is rebuilt when the log is re-opened + */ + @Test + def testCorruptIndexRebuild(): Unit = { + // publish the messages and close the log + val numMessages = 200 + val logConfig = LogTestUtils.createLogConfig(segmentBytes = 200, indexIntervalBytes = 1) + var log = createLog(logDir, logConfig) + for(i <- 0 until numMessages) + log.appendAsLeader(TestUtils.singletonRecords(value = TestUtils.randomBytes(10), timestamp = mockTime.milliseconds + i * 10), leaderEpoch = 0) + val indexFiles = log.logSegments.map(_.lazyOffsetIndex.file) + val timeIndexFiles = log.logSegments.map(_.lazyTimeIndex.file) + log.close() + + // corrupt all the index files + for( file <- indexFiles) { + val bw = new BufferedWriter(new FileWriter(file)) + bw.write(" ") + bw.close() + } + + // corrupt all the index files + for( file <- timeIndexFiles) { + val bw = new BufferedWriter(new FileWriter(file)) + bw.write(" ") + bw.close() + } + + // reopen the log with recovery point=0 so that the segment recovery can be triggered + log = createLog(logDir, logConfig, lastShutdownClean = false) + assertEquals(numMessages, log.logEndOffset, "Should have %d messages when log is reopened".format(numMessages)) + for(i <- 0 until numMessages) { + assertEquals(i, LogTestUtils.readLog(log, i, 100).records.batches.iterator.next().lastOffset) + if (i == 0) + assertEquals(log.logSegments.head.baseOffset, log.fetchOffsetByTimestamp(mockTime.milliseconds + i * 10).get.offset) + else + assertEquals(i, log.fetchOffsetByTimestamp(mockTime.milliseconds + i * 10).get.offset) + } + log.close() + } + + /** + * When we open a log any index segments without an associated log segment should be deleted. + */ + @Test + def testBogusIndexSegmentsAreRemoved(): Unit = { + val bogusIndex1 = UnifiedLog.offsetIndexFile(logDir, 0) + val bogusTimeIndex1 = UnifiedLog.timeIndexFile(logDir, 0) + val bogusIndex2 = UnifiedLog.offsetIndexFile(logDir, 5) + val bogusTimeIndex2 = UnifiedLog.timeIndexFile(logDir, 5) + + // The files remain absent until we first access it because we are doing lazy loading for time index and offset index + // files but in this test case we need to create these files in order to test we will remove them. 
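+    // (The offset-0 index files get a real backing segment once the log is created and accessed, so they should be
+    // rebuilt rather than deleted; the offset-5 files have no corresponding .log segment and should be removed.)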
+ bogusIndex2.createNewFile() + bogusTimeIndex2.createNewFile() + + def createRecords = TestUtils.singletonRecords(value = "test".getBytes, timestamp = mockTime.milliseconds) + val logConfig = LogTestUtils.createLogConfig(segmentBytes = createRecords.sizeInBytes * 5, segmentIndexBytes = 1000, indexIntervalBytes = 1) + val log = createLog(logDir, logConfig) + + // Force the segment to access the index files because we are doing index lazy loading. + log.logSegments.toSeq.head.offsetIndex + log.logSegments.toSeq.head.timeIndex + + assertTrue(bogusIndex1.length > 0, + "The first index file should have been replaced with a larger file") + assertTrue(bogusTimeIndex1.length > 0, + "The first time index file should have been replaced with a larger file") + assertFalse(bogusIndex2.exists, + "The second index file should have been deleted.") + assertFalse(bogusTimeIndex2.exists, + "The second time index file should have been deleted.") + + // check that we can append to the log + for (_ <- 0 until 10) + log.appendAsLeader(createRecords, leaderEpoch = 0) + + log.delete() + } + + /** + * Verify that truncation works correctly after re-opening the log + */ + @Test + def testReopenThenTruncate(): Unit = { + def createRecords = TestUtils.singletonRecords(value = "test".getBytes, timestamp = mockTime.milliseconds) + // create a log + val logConfig = LogTestUtils.createLogConfig(segmentBytes = createRecords.sizeInBytes * 5, segmentIndexBytes = 1000, indexIntervalBytes = 10000) + var log = createLog(logDir, logConfig) + + // add enough messages to roll over several segments then close and re-open and attempt to truncate + for (_ <- 0 until 100) + log.appendAsLeader(createRecords, leaderEpoch = 0) + log.close() + log = createLog(logDir, logConfig, lastShutdownClean = false) + log.truncateTo(3) + assertEquals(1, log.numberOfSegments, "All but one segment should be deleted.") + assertEquals(3, log.logEndOffset, "Log end offset should be 3.") + } + + /** + * Any files ending in .deleted should be removed when the log is re-opened. 
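+   * (Deleted segments are first renamed with a .deleted suffix and only removed asynchronously after
+   * file.delete.delay.ms, so re-opening the log before that delay elapses should purge the leftovers directly.)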
+ */ + @Test + def testOpenDeletesObsoleteFiles(): Unit = { + def createRecords = TestUtils.singletonRecords(value = "test".getBytes, timestamp = mockTime.milliseconds - 1000) + val logConfig = LogTestUtils.createLogConfig(segmentBytes = createRecords.sizeInBytes * 5, segmentIndexBytes = 1000, retentionMs = 999) + var log = createLog(logDir, logConfig) + + // append some messages to create some segments + for (_ <- 0 until 100) + log.appendAsLeader(createRecords, leaderEpoch = 0) + + // expire all segments + log.updateHighWatermark(log.logEndOffset) + log.deleteOldSegments() + log.close() + log = createLog(logDir, logConfig, lastShutdownClean = false) + assertEquals(1, log.numberOfSegments, "The deleted segments should be gone.") + } + + @Test + def testCorruptLog(): Unit = { + // append some messages to create some segments + val logConfig = LogTestUtils.createLogConfig(segmentBytes = 1000, indexIntervalBytes = 1, maxMessageBytes = 64 * 1024) + def createRecords = TestUtils.singletonRecords(value = "test".getBytes, timestamp = mockTime.milliseconds) + val recoveryPoint = 50L + for (_ <- 0 until 10) { + // create a log and write some messages to it + logDir.mkdirs() + var log = createLog(logDir, logConfig) + val numMessages = 50 + TestUtils.random.nextInt(50) + for (_ <- 0 until numMessages) + log.appendAsLeader(createRecords, leaderEpoch = 0) + val records = log.logSegments.flatMap(_.log.records.asScala.toList).toList + log.close() + + // corrupt index and log by appending random bytes + TestUtils.appendNonsenseToFile(log.activeSegment.lazyOffsetIndex.file, TestUtils.random.nextInt(1024) + 1) + TestUtils.appendNonsenseToFile(log.activeSegment.log.file, TestUtils.random.nextInt(1024) + 1) + + // attempt recovery + log = createLog(logDir, logConfig, brokerTopicStats, 0L, recoveryPoint, lastShutdownClean = false) + assertEquals(numMessages, log.logEndOffset) + + val recovered = log.logSegments.flatMap(_.log.records.asScala.toList).toList + assertEquals(records.size, recovered.size) + + for (i <- records.indices) { + val expected = records(i) + val actual = recovered(i) + assertEquals(expected.key, actual.key, s"Keys not equal") + assertEquals(expected.value, actual.value, s"Values not equal") + assertEquals(expected.timestamp, actual.timestamp, s"Timestamps not equal") + } + + Utils.delete(logDir) + } + } + + @Test + def testOverCompactedLogRecovery(): Unit = { + // append some messages to create some segments + val logConfig = LogTestUtils.createLogConfig(segmentBytes = 1000, indexIntervalBytes = 1, maxMessageBytes = 64 * 1024) + val log = createLog(logDir, logConfig) + val set1 = MemoryRecords.withRecords(0, CompressionType.NONE, 0, new SimpleRecord("v1".getBytes(), "k1".getBytes())) + val set2 = MemoryRecords.withRecords(Integer.MAX_VALUE.toLong + 2, CompressionType.NONE, 0, new SimpleRecord("v3".getBytes(), "k3".getBytes())) + val set3 = MemoryRecords.withRecords(Integer.MAX_VALUE.toLong + 3, CompressionType.NONE, 0, new SimpleRecord("v4".getBytes(), "k4".getBytes())) + val set4 = MemoryRecords.withRecords(Integer.MAX_VALUE.toLong + 4, CompressionType.NONE, 0, new SimpleRecord("v5".getBytes(), "k5".getBytes())) + //Writes into an empty log with baseOffset 0 + log.appendAsFollower(set1) + assertEquals(0L, log.activeSegment.baseOffset) + //This write will roll the segment, yielding a new segment with base offset = max(1, Integer.MAX_VALUE+2) = Integer.MAX_VALUE+2 + log.appendAsFollower(set2) + assertEquals(Integer.MAX_VALUE.toLong + 2, log.activeSegment.baseOffset) + 
assertTrue(UnifiedLog.producerSnapshotFile(logDir, Integer.MAX_VALUE.toLong + 2).exists) + //This will go into the existing log + log.appendAsFollower(set3) + assertEquals(Integer.MAX_VALUE.toLong + 2, log.activeSegment.baseOffset) + //This will go into the existing log + log.appendAsFollower(set4) + assertEquals(Integer.MAX_VALUE.toLong + 2, log.activeSegment.baseOffset) + log.close() + val indexFiles = logDir.listFiles.filter(file => file.getName.contains(".index")) + assertEquals(2, indexFiles.length) + for (file <- indexFiles) { + val offsetIndex = new OffsetIndex(file, file.getName.replace(".index","").toLong) + assertTrue(offsetIndex.lastOffset >= 0) + offsetIndex.close() + } + Utils.delete(logDir) + } + + @nowarn("cat=deprecation") + @Test + def testLeaderEpochCacheClearedAfterStaticMessageFormatDowngrade(): Unit = { + val logConfig = LogTestUtils.createLogConfig(segmentBytes = 1000, indexIntervalBytes = 1, maxMessageBytes = 65536) + val log = createLog(logDir, logConfig) + log.appendAsLeader(TestUtils.records(List(new SimpleRecord("foo".getBytes()))), leaderEpoch = 5) + assertEquals(Some(5), log.latestEpoch) + log.close() + + // reopen the log with an older message format version and check the cache + val logProps = new Properties() + logProps.put(LogConfig.SegmentBytesProp, "1000") + logProps.put(LogConfig.IndexIntervalBytesProp, "1") + logProps.put(LogConfig.MaxMessageBytesProp, "65536") + logProps.put(LogConfig.MessageFormatVersionProp, "0.10.2") + val downgradedLogConfig = LogConfig(logProps) + val reopened = createLog(logDir, downgradedLogConfig, lastShutdownClean = false) + LogTestUtils.assertLeaderEpochCacheEmpty(reopened) + + reopened.appendAsLeader(TestUtils.records(List(new SimpleRecord("bar".getBytes())), + magicValue = RecordVersion.V1.value), leaderEpoch = 5) + LogTestUtils.assertLeaderEpochCacheEmpty(reopened) + } + + @Test + def testOverCompactedLogRecoveryMultiRecord(): Unit = { + // append some messages to create some segments + val logConfig = LogTestUtils.createLogConfig(segmentBytes = 1000, indexIntervalBytes = 1, maxMessageBytes = 64 * 1024) + val log = createLog(logDir, logConfig) + val set1 = MemoryRecords.withRecords(0, CompressionType.NONE, 0, new SimpleRecord("v1".getBytes(), "k1".getBytes())) + val set2 = MemoryRecords.withRecords(Integer.MAX_VALUE.toLong + 2, CompressionType.GZIP, 0, + new SimpleRecord("v3".getBytes(), "k3".getBytes()), + new SimpleRecord("v4".getBytes(), "k4".getBytes())) + val set3 = MemoryRecords.withRecords(Integer.MAX_VALUE.toLong + 4, CompressionType.GZIP, 0, + new SimpleRecord("v5".getBytes(), "k5".getBytes()), + new SimpleRecord("v6".getBytes(), "k6".getBytes())) + val set4 = MemoryRecords.withRecords(Integer.MAX_VALUE.toLong + 6, CompressionType.GZIP, 0, + new SimpleRecord("v7".getBytes(), "k7".getBytes()), + new SimpleRecord("v8".getBytes(), "k8".getBytes())) + //Writes into an empty log with baseOffset 0 + log.appendAsFollower(set1) + assertEquals(0L, log.activeSegment.baseOffset) + //This write will roll the segment, yielding a new segment with base offset = max(1, Integer.MAX_VALUE+2) = Integer.MAX_VALUE+2 + log.appendAsFollower(set2) + assertEquals(Integer.MAX_VALUE.toLong + 2, log.activeSegment.baseOffset) + assertTrue(UnifiedLog.producerSnapshotFile(logDir, Integer.MAX_VALUE.toLong + 2).exists) + //This will go into the existing log + log.appendAsFollower(set3) + assertEquals(Integer.MAX_VALUE.toLong + 2, log.activeSegment.baseOffset) + //This will go into the existing log + log.appendAsFollower(set4) + 
assertEquals(Integer.MAX_VALUE.toLong + 2, log.activeSegment.baseOffset) + log.close() + val indexFiles = logDir.listFiles.filter(file => file.getName.contains(".index")) + assertEquals(2, indexFiles.length) + for (file <- indexFiles) { + val offsetIndex = new OffsetIndex(file, file.getName.replace(".index","").toLong) + assertTrue(offsetIndex.lastOffset >= 0) + offsetIndex.close() + } + Utils.delete(logDir) + } + + @Test + def testOverCompactedLogRecoveryMultiRecordV1(): Unit = { + // append some messages to create some segments + val logConfig = LogTestUtils.createLogConfig(segmentBytes = 1000, indexIntervalBytes = 1, maxMessageBytes = 64 * 1024) + val log = createLog(logDir, logConfig) + val set1 = MemoryRecords.withRecords(RecordBatch.MAGIC_VALUE_V1, 0, CompressionType.NONE, + new SimpleRecord("v1".getBytes(), "k1".getBytes())) + val set2 = MemoryRecords.withRecords(RecordBatch.MAGIC_VALUE_V1, Integer.MAX_VALUE.toLong + 2, CompressionType.GZIP, + new SimpleRecord("v3".getBytes(), "k3".getBytes()), + new SimpleRecord("v4".getBytes(), "k4".getBytes())) + val set3 = MemoryRecords.withRecords(RecordBatch.MAGIC_VALUE_V1, Integer.MAX_VALUE.toLong + 4, CompressionType.GZIP, + new SimpleRecord("v5".getBytes(), "k5".getBytes()), + new SimpleRecord("v6".getBytes(), "k6".getBytes())) + val set4 = MemoryRecords.withRecords(RecordBatch.MAGIC_VALUE_V1, Integer.MAX_VALUE.toLong + 6, CompressionType.GZIP, + new SimpleRecord("v7".getBytes(), "k7".getBytes()), + new SimpleRecord("v8".getBytes(), "k8".getBytes())) + //Writes into an empty log with baseOffset 0 + log.appendAsFollower(set1) + assertEquals(0L, log.activeSegment.baseOffset) + //This write will roll the segment, yielding a new segment with base offset = max(1, 3) = 3 + log.appendAsFollower(set2) + assertEquals(3, log.activeSegment.baseOffset) + assertTrue(UnifiedLog.producerSnapshotFile(logDir, 3).exists) + //This will also roll the segment, yielding a new segment with base offset = max(5, Integer.MAX_VALUE+4) = Integer.MAX_VALUE+4 + log.appendAsFollower(set3) + assertEquals(Integer.MAX_VALUE.toLong + 4, log.activeSegment.baseOffset) + assertTrue(UnifiedLog.producerSnapshotFile(logDir, Integer.MAX_VALUE.toLong + 4).exists) + //This will go into the existing log + log.appendAsFollower(set4) + assertEquals(Integer.MAX_VALUE.toLong + 4, log.activeSegment.baseOffset) + log.close() + val indexFiles = logDir.listFiles.filter(file => file.getName.contains(".index")) + assertEquals(3, indexFiles.length) + for (file <- indexFiles) { + val offsetIndex = new OffsetIndex(file, file.getName.replace(".index","").toLong) + assertTrue(offsetIndex.lastOffset >= 0) + offsetIndex.close() + } + Utils.delete(logDir) + } + + @Test + def testRecoveryOfSegmentWithOffsetOverflow(): Unit = { + val logConfig = LogTestUtils.createLogConfig(indexIntervalBytes = 1, fileDeleteDelayMs = 1000) + val (log, _) = createLogWithOffsetOverflow(logConfig) + val expectedKeys = LogTestUtils.keysInLog(log) + + // Run recovery on the log. This should split the segment underneath. Ignore .deleted files as we could have still + // have them lying around after the split. 
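+    // (A segment is considered overflowed when a record offset minus the segment's base offset no longer fits in
+    // an Int; recovery detects this and splits the segment so that relative offsets in the index remain valid.)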
+ val recoveredLog = recoverAndCheck(logConfig, expectedKeys) + assertEquals(expectedKeys, LogTestUtils.keysInLog(recoveredLog)) + + // Running split again would throw an error + + for (segment <- recoveredLog.logSegments) { + assertThrows(classOf[IllegalArgumentException], () => log.splitOverflowedSegment(segment)) + } + } + + @Test + def testRecoveryAfterCrashDuringSplitPhase1(): Unit = { + val logConfig = LogTestUtils.createLogConfig(indexIntervalBytes = 1, fileDeleteDelayMs = 1000) + val (log, segmentWithOverflow) = createLogWithOffsetOverflow(logConfig) + val expectedKeys = LogTestUtils.keysInLog(log) + val numSegmentsInitial = log.logSegments.size + + // Split the segment + val newSegments = log.splitOverflowedSegment(segmentWithOverflow) + + // Simulate recovery just after .cleaned file is created, before rename to .swap. On recovery, existing split + // operation is aborted but the recovery process itself kicks off split which should complete. + newSegments.reverse.foreach(segment => { + segment.changeFileSuffixes("", UnifiedLog.CleanedFileSuffix) + segment.truncateTo(0) + }) + for (file <- logDir.listFiles if file.getName.endsWith(UnifiedLog.DeletedFileSuffix)) + Utils.atomicMoveWithFallback(file.toPath, Paths.get(CoreUtils.replaceSuffix(file.getPath, UnifiedLog.DeletedFileSuffix, ""))) + + val recoveredLog = recoverAndCheck(logConfig, expectedKeys) + assertEquals(expectedKeys, LogTestUtils.keysInLog(recoveredLog)) + assertEquals(numSegmentsInitial + 1, recoveredLog.logSegments.size) + recoveredLog.close() + } + + @Test + def testRecoveryAfterCrashDuringSplitPhase2(): Unit = { + val logConfig = LogTestUtils.createLogConfig(indexIntervalBytes = 1, fileDeleteDelayMs = 1000) + val (log, segmentWithOverflow) = createLogWithOffsetOverflow(logConfig) + val expectedKeys = LogTestUtils.keysInLog(log) + val numSegmentsInitial = log.logSegments.size + + // Split the segment + val newSegments = log.splitOverflowedSegment(segmentWithOverflow) + + // Simulate recovery just after one of the new segments has been renamed to .swap. On recovery, existing split + // operation is aborted but the recovery process itself kicks off split which should complete. + newSegments.reverse.foreach { segment => + if (segment != newSegments.last) + segment.changeFileSuffixes("", UnifiedLog.CleanedFileSuffix) + else + segment.changeFileSuffixes("", UnifiedLog.SwapFileSuffix) + segment.truncateTo(0) + } + for (file <- logDir.listFiles if file.getName.endsWith(UnifiedLog.DeletedFileSuffix)) + Utils.atomicMoveWithFallback(file.toPath, Paths.get(CoreUtils.replaceSuffix(file.getPath, UnifiedLog.DeletedFileSuffix, ""))) + + val recoveredLog = recoverAndCheck(logConfig, expectedKeys) + assertEquals(expectedKeys, LogTestUtils.keysInLog(recoveredLog)) + assertEquals(numSegmentsInitial + 1, recoveredLog.logSegments.size) + recoveredLog.close() + } + + @Test + def testRecoveryAfterCrashDuringSplitPhase3(): Unit = { + val logConfig = LogTestUtils.createLogConfig(indexIntervalBytes = 1, fileDeleteDelayMs = 1000) + val (log, segmentWithOverflow) = createLogWithOffsetOverflow(logConfig) + val expectedKeys = LogTestUtils.keysInLog(log) + val numSegmentsInitial = log.logSegments.size + + // Split the segment + val newSegments = log.splitOverflowedSegment(segmentWithOverflow) + + // Simulate recovery right after all new segments have been renamed to .swap. On recovery, existing split operation + // is completed and the old segment must be deleted. 
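+    // (The split reuses the .cleaned -> .swap -> replace sequence: .cleaned files found during recovery are
+    // discarded, whereas .swap files are treated as committed and the pending swap is completed.)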
+ newSegments.reverse.foreach(segment => { + segment.changeFileSuffixes("", UnifiedLog.SwapFileSuffix) + }) + for (file <- logDir.listFiles if file.getName.endsWith(UnifiedLog.DeletedFileSuffix)) + Utils.atomicMoveWithFallback(file.toPath, Paths.get(CoreUtils.replaceSuffix(file.getPath, UnifiedLog.DeletedFileSuffix, ""))) + + // Truncate the old segment + segmentWithOverflow.truncateTo(0) + + val recoveredLog = recoverAndCheck(logConfig, expectedKeys) + assertEquals(expectedKeys, LogTestUtils.keysInLog(recoveredLog)) + assertEquals(numSegmentsInitial + 1, recoveredLog.logSegments.size) + log.close() + } + + @Test + def testRecoveryAfterCrashDuringSplitPhase4(): Unit = { + val logConfig = LogTestUtils.createLogConfig(indexIntervalBytes = 1, fileDeleteDelayMs = 1000) + val (log, segmentWithOverflow) = createLogWithOffsetOverflow(logConfig) + val expectedKeys = LogTestUtils.keysInLog(log) + val numSegmentsInitial = log.logSegments.size + + // Split the segment + val newSegments = log.splitOverflowedSegment(segmentWithOverflow) + + // Simulate recovery right after all new segments have been renamed to .swap and old segment has been deleted. On + // recovery, existing split operation is completed. + newSegments.reverse.foreach(_.changeFileSuffixes("", UnifiedLog.SwapFileSuffix)) + + for (file <- logDir.listFiles if file.getName.endsWith(UnifiedLog.DeletedFileSuffix)) + Utils.delete(file) + + // Truncate the old segment + segmentWithOverflow.truncateTo(0) + + val recoveredLog = recoverAndCheck(logConfig, expectedKeys) + assertEquals(expectedKeys, LogTestUtils.keysInLog(recoveredLog)) + assertEquals(numSegmentsInitial + 1, recoveredLog.logSegments.size) + recoveredLog.close() + } + + @Test + def testRecoveryAfterCrashDuringSplitPhase5(): Unit = { + val logConfig = LogTestUtils.createLogConfig(indexIntervalBytes = 1, fileDeleteDelayMs = 1000) + val (log, segmentWithOverflow) = createLogWithOffsetOverflow(logConfig) + val expectedKeys = LogTestUtils.keysInLog(log) + val numSegmentsInitial = log.logSegments.size + + // Split the segment + val newSegments = log.splitOverflowedSegment(segmentWithOverflow) + + // Simulate recovery right after one of the new segment has been renamed to .swap and the other to .log. On + // recovery, existing split operation is completed. + newSegments.last.changeFileSuffixes("", UnifiedLog.SwapFileSuffix) + + // Truncate the old segment + segmentWithOverflow.truncateTo(0) + + val recoveredLog = recoverAndCheck(logConfig, expectedKeys) + assertEquals(expectedKeys, LogTestUtils.keysInLog(recoveredLog)) + assertEquals(numSegmentsInitial + 1, recoveredLog.logSegments.size) + recoveredLog.close() + } + + @Test + def testCleanShutdownFile(): Unit = { + // append some messages to create some segments + val logConfig = LogTestUtils.createLogConfig(segmentBytes = 1000, indexIntervalBytes = 1, maxMessageBytes = 64 * 1024) + def createRecords = TestUtils.singletonRecords(value = "test".getBytes, timestamp = mockTime.milliseconds) + + var recoveryPoint = 0L + // create a log and write some messages to it + var log = createLog(logDir, logConfig) + for (_ <- 0 until 100) + log.appendAsLeader(createRecords, leaderEpoch = 0) + log.close() + + // check if recovery was attempted. Even if the recovery point is 0L, recovery should not be attempted as the + // clean shutdown file exists. Note: Earlier, Log layer relied on the presence of clean shutdown file to determine the status + // of last shutdown. Now, LogManager checks for the presence of this file and immediately deletes the same. 
It passes + // down a clean shutdown flag to the Log layer as log is loaded. Recovery is attempted based on this flag. + recoveryPoint = log.logEndOffset + log = createLog(logDir, logConfig) + assertEquals(recoveryPoint, log.logEndOffset) + } + + /** + * Append a bunch of messages to a log and then re-open it with recovery and check that the leader epochs are recovered properly. + */ + @Test + def testLogRecoversForLeaderEpoch(): Unit = { + val log = createLog(logDir, LogConfig()) + val leaderEpochCache = log.leaderEpochCache.get + val firstBatch = singletonRecordsWithLeaderEpoch(value = "random".getBytes, leaderEpoch = 1, offset = 0) + log.appendAsFollower(records = firstBatch) + + val secondBatch = singletonRecordsWithLeaderEpoch(value = "random".getBytes, leaderEpoch = 2, offset = 1) + log.appendAsFollower(records = secondBatch) + + val thirdBatch = singletonRecordsWithLeaderEpoch(value = "random".getBytes, leaderEpoch = 2, offset = 2) + log.appendAsFollower(records = thirdBatch) + + val fourthBatch = singletonRecordsWithLeaderEpoch(value = "random".getBytes, leaderEpoch = 3, offset = 3) + log.appendAsFollower(records = fourthBatch) + + assertEquals(ListBuffer(EpochEntry(1, 0), EpochEntry(2, 1), EpochEntry(3, 3)), leaderEpochCache.epochEntries) + + // deliberately remove some of the epoch entries + leaderEpochCache.truncateFromEnd(2) + assertNotEquals(ListBuffer(EpochEntry(1, 0), EpochEntry(2, 1), EpochEntry(3, 3)), leaderEpochCache.epochEntries) + log.close() + + // reopen the log and recover from the beginning + val recoveredLog = createLog(logDir, LogConfig(), lastShutdownClean = false) + val recoveredLeaderEpochCache = recoveredLog.leaderEpochCache.get + + // epoch entries should be recovered + assertEquals(ListBuffer(EpochEntry(1, 0), EpochEntry(2, 1), EpochEntry(3, 3)), recoveredLeaderEpochCache.epochEntries) + recoveredLog.close() + } + + @Test + def testFullTransactionIndexRecovery(): Unit = { + val logConfig = LogTestUtils.createLogConfig(segmentBytes = 128 * 5) + val log = createLog(logDir, logConfig) + val epoch = 0.toShort + + val pid1 = 1L + val pid2 = 2L + val pid3 = 3L + val pid4 = 4L + + val appendPid1 = LogTestUtils.appendTransactionalAsLeader(log, pid1, epoch, mockTime) + val appendPid2 = LogTestUtils.appendTransactionalAsLeader(log, pid2, epoch, mockTime) + val appendPid3 = LogTestUtils.appendTransactionalAsLeader(log, pid3, epoch, mockTime) + val appendPid4 = LogTestUtils.appendTransactionalAsLeader(log, pid4, epoch, mockTime) + + // mix transactional and non-transactional data + appendPid1(5) // nextOffset: 5 + LogTestUtils.appendNonTransactionalAsLeader(log, 3) // 8 + appendPid2(2) // 10 + appendPid1(4) // 14 + appendPid3(3) // 17 + LogTestUtils.appendNonTransactionalAsLeader(log, 2) // 19 + appendPid1(10) // 29 + LogTestUtils.appendEndTxnMarkerAsLeader(log, pid1, epoch, ControlRecordType.ABORT, mockTime.milliseconds()) // 30 + appendPid2(6) // 36 + appendPid4(3) // 39 + LogTestUtils.appendNonTransactionalAsLeader(log, 10) // 49 + appendPid3(9) // 58 + LogTestUtils.appendEndTxnMarkerAsLeader(log, pid3, epoch, ControlRecordType.COMMIT, mockTime.milliseconds()) // 59 + appendPid4(8) // 67 + appendPid2(7) // 74 + LogTestUtils.appendEndTxnMarkerAsLeader(log, pid2, epoch, ControlRecordType.ABORT, mockTime.milliseconds()) // 75 + LogTestUtils.appendNonTransactionalAsLeader(log, 10) // 85 + appendPid4(4) // 89 + LogTestUtils.appendEndTxnMarkerAsLeader(log, pid4, epoch, ControlRecordType.COMMIT, mockTime.milliseconds()) // 90 + + // delete all the offset and transaction 
index files to force recovery + log.logSegments.foreach { segment => + segment.offsetIndex.deleteIfExists() + segment.txnIndex.deleteIfExists() + } + + log.close() + + val reloadedLogConfig = LogTestUtils.createLogConfig(segmentBytes = 1024 * 5) + val reloadedLog = createLog(logDir, reloadedLogConfig, lastShutdownClean = false) + val abortedTransactions = LogTestUtils.allAbortedTransactions(reloadedLog) + assertEquals(List(new AbortedTxn(pid1, 0L, 29L, 8L), new AbortedTxn(pid2, 8L, 74L, 36L)), abortedTransactions) + } + + @Test + def testRecoverOnlyLastSegment(): Unit = { + val logConfig = LogTestUtils.createLogConfig(segmentBytes = 128 * 5) + val log = createLog(logDir, logConfig) + val epoch = 0.toShort + + val pid1 = 1L + val pid2 = 2L + val pid3 = 3L + val pid4 = 4L + + val appendPid1 = LogTestUtils.appendTransactionalAsLeader(log, pid1, epoch, mockTime) + val appendPid2 = LogTestUtils.appendTransactionalAsLeader(log, pid2, epoch, mockTime) + val appendPid3 = LogTestUtils.appendTransactionalAsLeader(log, pid3, epoch, mockTime) + val appendPid4 = LogTestUtils.appendTransactionalAsLeader(log, pid4, epoch, mockTime) + + // mix transactional and non-transactional data + appendPid1(5) // nextOffset: 5 + LogTestUtils.appendNonTransactionalAsLeader(log, 3) // 8 + appendPid2(2) // 10 + appendPid1(4) // 14 + appendPid3(3) // 17 + LogTestUtils.appendNonTransactionalAsLeader(log, 2) // 19 + appendPid1(10) // 29 + LogTestUtils.appendEndTxnMarkerAsLeader(log, pid1, epoch, ControlRecordType.ABORT, mockTime.milliseconds()) // 30 + appendPid2(6) // 36 + appendPid4(3) // 39 + LogTestUtils.appendNonTransactionalAsLeader(log, 10) // 49 + appendPid3(9) // 58 + LogTestUtils.appendEndTxnMarkerAsLeader(log, pid3, epoch, ControlRecordType.COMMIT, mockTime.milliseconds()) // 59 + appendPid4(8) // 67 + appendPid2(7) // 74 + LogTestUtils.appendEndTxnMarkerAsLeader(log, pid2, epoch, ControlRecordType.ABORT, mockTime.milliseconds()) // 75 + LogTestUtils.appendNonTransactionalAsLeader(log, 10) // 85 + appendPid4(4) // 89 + LogTestUtils.appendEndTxnMarkerAsLeader(log, pid4, epoch, ControlRecordType.COMMIT, mockTime.milliseconds()) // 90 + + // delete the last offset and transaction index files to force recovery + val lastSegment = log.logSegments.last + val recoveryPoint = lastSegment.baseOffset + lastSegment.offsetIndex.deleteIfExists() + lastSegment.txnIndex.deleteIfExists() + + log.close() + + val reloadedLogConfig = LogTestUtils.createLogConfig(segmentBytes = 1024 * 5) + val reloadedLog = createLog(logDir, reloadedLogConfig, recoveryPoint = recoveryPoint, lastShutdownClean = false) + val abortedTransactions = LogTestUtils.allAbortedTransactions(reloadedLog) + assertEquals(List(new AbortedTxn(pid1, 0L, 29L, 8L), new AbortedTxn(pid2, 8L, 74L, 36L)), abortedTransactions) + } + + @Test + def testRecoverLastSegmentWithNoSnapshots(): Unit = { + val logConfig = LogTestUtils.createLogConfig(segmentBytes = 128 * 5) + val log = createLog(logDir, logConfig) + val epoch = 0.toShort + + val pid1 = 1L + val pid2 = 2L + val pid3 = 3L + val pid4 = 4L + + val appendPid1 = LogTestUtils.appendTransactionalAsLeader(log, pid1, epoch, mockTime) + val appendPid2 = LogTestUtils.appendTransactionalAsLeader(log, pid2, epoch, mockTime) + val appendPid3 = LogTestUtils.appendTransactionalAsLeader(log, pid3, epoch, mockTime) + val appendPid4 = LogTestUtils.appendTransactionalAsLeader(log, pid4, epoch, mockTime) + + // mix transactional and non-transactional data + appendPid1(5) // nextOffset: 5 + 
LogTestUtils.appendNonTransactionalAsLeader(log, 3) // 8 + appendPid2(2) // 10 + appendPid1(4) // 14 + appendPid3(3) // 17 + LogTestUtils.appendNonTransactionalAsLeader(log, 2) // 19 + appendPid1(10) // 29 + LogTestUtils.appendEndTxnMarkerAsLeader(log, pid1, epoch, ControlRecordType.ABORT, mockTime.milliseconds()) // 30 + appendPid2(6) // 36 + appendPid4(3) // 39 + LogTestUtils.appendNonTransactionalAsLeader(log, 10) // 49 + appendPid3(9) // 58 + LogTestUtils.appendEndTxnMarkerAsLeader(log, pid3, epoch, ControlRecordType.COMMIT, mockTime.milliseconds()) // 59 + appendPid4(8) // 67 + appendPid2(7) // 74 + LogTestUtils.appendEndTxnMarkerAsLeader(log, pid2, epoch, ControlRecordType.ABORT, mockTime.milliseconds()) // 75 + LogTestUtils.appendNonTransactionalAsLeader(log, 10) // 85 + appendPid4(4) // 89 + LogTestUtils.appendEndTxnMarkerAsLeader(log, pid4, epoch, ControlRecordType.COMMIT, mockTime.milliseconds()) // 90 + + LogTestUtils.deleteProducerSnapshotFiles(logDir) + + // delete the last offset and transaction index files to force recovery. this should force us to rebuild + // the producer state from the start of the log + val lastSegment = log.logSegments.last + val recoveryPoint = lastSegment.baseOffset + lastSegment.offsetIndex.deleteIfExists() + lastSegment.txnIndex.deleteIfExists() + + log.close() + + val reloadedLogConfig = LogTestUtils.createLogConfig(segmentBytes = 1024 * 5) + val reloadedLog = createLog(logDir, reloadedLogConfig, recoveryPoint = recoveryPoint, lastShutdownClean = false) + val abortedTransactions = LogTestUtils.allAbortedTransactions(reloadedLog) + assertEquals(List(new AbortedTxn(pid1, 0L, 29L, 8L), new AbortedTxn(pid2, 8L, 74L, 36L)), abortedTransactions) + } + + @Test + def testLogEndLessThanStartAfterReopen(): Unit = { + val logConfig = LogTestUtils.createLogConfig() + var log = createLog(logDir, logConfig) + for (i <- 0 until 5) { + val record = new SimpleRecord(mockTime.milliseconds, i.toString.getBytes) + log.appendAsLeader(TestUtils.records(List(record)), leaderEpoch = 0) + log.roll() + } + assertEquals(6, log.logSegments.size) + + // Increment the log start offset + val startOffset = 4 + log.updateHighWatermark(log.logEndOffset) + log.maybeIncrementLogStartOffset(startOffset, ClientRecordDeletion) + assertTrue(log.logEndOffset > log.logStartOffset) + + // Append garbage to a segment below the current log start offset + val segmentToForceTruncation = log.logSegments.take(2).last + val bw = new BufferedWriter(new FileWriter(segmentToForceTruncation.log.file)) + bw.write("corruptRecord") + bw.close() + log.close() + + // Reopen the log. This will cause truncate the segment to which we appended garbage and delete all other segments. + // All remaining segments will be lower than the current log start offset, which will force deletion of all segments + // and recreation of a single, active segment starting at logStartOffset. + log = createLog(logDir, logConfig, logStartOffset = startOffset, lastShutdownClean = false) + // Wait for segment deletions (if any) to complete. 
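+    // (Advancing the mock clock past file.delete.delay.ms lets the scheduled asynchronous delete tasks run before
+    // the assertions below.)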
+ mockTime.sleep(logConfig.fileDeleteDelayMs) + assertEquals(1, log.numberOfSegments) + assertEquals(startOffset, log.logStartOffset) + assertEquals(startOffset, log.logEndOffset) + // Validate that the remaining segment matches our expectations + val onlySegment = log.logSegments.head + assertEquals(startOffset, onlySegment.baseOffset) + assertTrue(onlySegment.log.file().exists()) + assertTrue(onlySegment.lazyOffsetIndex.file.exists()) + assertTrue(onlySegment.lazyTimeIndex.file.exists()) + } + + @Test + def testCorruptedLogRecoveryDoesNotDeleteProducerStateSnapshotsPostRecovery(): Unit = { + val logConfig = LogTestUtils.createLogConfig() + var log = createLog(logDir, logConfig) + // Create segments: [0-0], [1-1], [2-2], [3-3], [4-4], [5-5], [6-6], [7-7], [8-8], [9-] + // |---> logStartOffset |---> active segment (empty) + // |---> logEndOffset + for (i <- 0 until 9) { + val record = new SimpleRecord(mockTime.milliseconds, i.toString.getBytes) + log.appendAsLeader(TestUtils.records(List(record)), leaderEpoch = 0) + log.roll() + } + assertEquals(10, log.logSegments.size) + assertEquals(0, log.logStartOffset) + assertEquals(9, log.activeSegment.baseOffset) + assertEquals(9, log.logEndOffset) + for (offset <- 1 until 10) { + val snapshotFileBeforeDeletion = log.producerStateManager.snapshotFileForOffset(offset) + assertTrue(snapshotFileBeforeDeletion.isDefined) + assertTrue(snapshotFileBeforeDeletion.get.file.exists) + } + + // Increment the log start offset to 4. + // After this step, the segments should be: + // |---> logStartOffset + // [0-0], [1-1], [2-2], [3-3], [4-4], [5-5], [6-6], [7-7], [8-8], [9-] + // |---> active segment (empty) + // |---> logEndOffset + val newLogStartOffset = 4 + log.updateHighWatermark(log.logEndOffset) + log.maybeIncrementLogStartOffset(newLogStartOffset, ClientRecordDeletion) + assertEquals(4, log.logStartOffset) + assertEquals(9, log.logEndOffset) + + // Append garbage to a segment at baseOffset 1, which is below the current log start offset 4. + // After this step, the segments should be: + // + // [0-0], [1-1], [2-2], [3-3], [4-4], [5-5], [6-6], [7-7], [8-8], [9-] + // | |---> logStartOffset |---> active segment (empty) + // | |---> logEndOffset + // corrupt record inserted + // + val segmentToForceTruncation = log.logSegments.take(2).last + assertEquals(1, segmentToForceTruncation.baseOffset) + val bw = new BufferedWriter(new FileWriter(segmentToForceTruncation.log.file)) + bw.write("corruptRecord") + bw.close() + log.close() + + // Reopen the log. This will do the following: + // - Truncate the segment above to which we appended garbage and will schedule async deletion of all other + // segments from base offsets 2 to 9. + // - The remaining segments at base offsets 0 and 1 will be lower than the current logStartOffset 4. + // This will cause async deletion of both remaining segments. Finally a single, active segment is created + // starting at logStartOffset 4. 
+ // + // Expected segments after the log is opened again: + // [4-] + // |---> active segment (empty) + // |---> logStartOffset + // |---> logEndOffset + log = createLog(logDir, logConfig, logStartOffset = newLogStartOffset, lastShutdownClean = false) + assertEquals(1, log.logSegments.size) + assertEquals(4, log.logStartOffset) + assertEquals(4, log.activeSegment.baseOffset) + assertEquals(4, log.logEndOffset) + + val offsetsWithSnapshotFiles = (1 until 5) + .map(offset => SnapshotFile(UnifiedLog.producerSnapshotFile(logDir, offset))) + .filter(snapshotFile => snapshotFile.file.exists()) + .map(_.offset) + val inMemorySnapshotFiles = (1 until 5) + .flatMap(offset => log.producerStateManager.snapshotFileForOffset(offset)) + + assertTrue(offsetsWithSnapshotFiles.isEmpty, s"Found offsets with producer state snapshot files: $offsetsWithSnapshotFiles while none were expected.") + assertTrue(inMemorySnapshotFiles.isEmpty, s"Found in-memory producer state snapshot files: $inMemorySnapshotFiles while none were expected.") + + // Append records, roll the segments and check that the producer state snapshots are defined. + // The expected segments and producer state snapshots, after the appends are complete and segments are rolled, + // is as shown below: + // [4-4], [5-5], [6-6], [7-7], [8-8], [9-] + // | | | | | |---> active segment (empty) + // | | | | | |---> logEndOffset + // | | | | | | + // | |------.------.------.------.-----> producer state snapshot files are DEFINED for each offset in: [5-9] + // |----------------------------------------> logStartOffset + for (i <- 0 until 5) { + val record = new SimpleRecord(mockTime.milliseconds, i.toString.getBytes) + log.appendAsLeader(TestUtils.records(List(record)), leaderEpoch = 0) + log.roll() + } + assertEquals(9, log.activeSegment.baseOffset) + assertEquals(9, log.logEndOffset) + for (offset <- 5 until 10) { + val snapshotFileBeforeDeletion = log.producerStateManager.snapshotFileForOffset(offset) + assertTrue(snapshotFileBeforeDeletion.isDefined) + assertTrue(snapshotFileBeforeDeletion.get.file.exists) + } + + // Wait for all async segment deletions scheduled during Log recovery to complete. + // The expected segments and producer state snapshot after the deletions, is as shown below: + // [4-4], [5-5], [6-6], [7-7], [8-8], [9-] + // | | | | | |---> active segment (empty) + // | | | | | |---> logEndOffset + // | | | | | | + // | |------.------.------.------.-----> producer state snapshot files should be defined for each offset in: [5-9]. 
+ // |----------------------------------------> logStartOffset + mockTime.sleep(logConfig.fileDeleteDelayMs) + assertEquals(newLogStartOffset, log.logStartOffset) + assertEquals(9, log.logEndOffset) + val offsetsWithMissingSnapshotFiles = ListBuffer[Long]() + for (offset <- 5 until 10) { + val snapshotFile = log.producerStateManager.snapshotFileForOffset(offset) + if (snapshotFile.isEmpty || !snapshotFile.get.file.exists) { + offsetsWithMissingSnapshotFiles.append(offset) + } + } + assertTrue(offsetsWithMissingSnapshotFiles.isEmpty, + s"Found offsets with missing producer state snapshot files: $offsetsWithMissingSnapshotFiles") + assertFalse(logDir.list().exists(_.endsWith(UnifiedLog.DeletedFileSuffix)), "Expected no files to be present with the deleted file suffix") + } +} diff --git a/core/src/test/scala/unit/kafka/log/LogManagerTest.scala b/core/src/test/scala/unit/kafka/log/LogManagerTest.scala new file mode 100755 index 0000000..937d80c --- /dev/null +++ b/core/src/test/scala/unit/kafka/log/LogManagerTest.scala @@ -0,0 +1,741 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package kafka.log + +import com.yammer.metrics.core.MetricName +import kafka.metrics.KafkaYammerMetrics +import kafka.server.checkpoints.OffsetCheckpointFile +import kafka.server.metadata.{ConfigRepository, MockConfigRepository} +import kafka.server.{FetchDataInfo, FetchLogEnd} +import kafka.utils._ +import org.apache.directory.api.util.FileUtils +import org.apache.kafka.common.errors.OffsetOutOfRangeException +import org.apache.kafka.common.utils.Utils +import org.apache.kafka.common.{KafkaException, TopicPartition} +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.{AfterEach, BeforeEach, Test} +import org.mockito.ArgumentMatchers.any +import org.mockito.{ArgumentMatchers, Mockito} +import org.mockito.Mockito.{doAnswer, mock, never, spy, times, verify} + +import java.io._ +import java.nio.file.Files +import java.util.concurrent.Future +import java.util.{Collections, Properties} +import scala.collection.mutable +import scala.jdk.CollectionConverters._ +import scala.util.{Failure, Try} + +class LogManagerTest { + + val time = new MockTime() + val maxRollInterval = 100 + val maxLogAgeMs = 10 * 60 * 1000 + val logProps = new Properties() + logProps.put(LogConfig.SegmentBytesProp, 1024: java.lang.Integer) + logProps.put(LogConfig.SegmentIndexBytesProp, 4096: java.lang.Integer) + logProps.put(LogConfig.RetentionMsProp, maxLogAgeMs: java.lang.Integer) + logProps.put(LogConfig.MessageTimestampDifferenceMaxMsProp, Long.MaxValue.toString) + val logConfig = LogConfig(logProps) + var logDir: File = null + var logManager: LogManager = null + val name = "kafka" + val veryLargeLogFlushInterval = 10000000L + + @BeforeEach + def setUp(): Unit = { + logDir = TestUtils.tempDir() + logManager = createLogManager() + logManager.startup(Set.empty) + } + + @AfterEach + def tearDown(): Unit = { + if (logManager != null) + logManager.shutdown() + Utils.delete(logDir) + // Some tests assign a new LogManager + if (logManager != null) + logManager.liveLogDirs.foreach(Utils.delete) + } + + /** + * Test that getOrCreateLog on a non-existent log creates a new log and that we can append to the new log. + */ + @Test + def testCreateLog(): Unit = { + val log = logManager.getOrCreateLog(new TopicPartition(name, 0), topicId = None) + assertEquals(1, logManager.liveLogDirs.size) + + val logFile = new File(logDir, name + "-0") + assertTrue(logFile.exists) + log.appendAsLeader(TestUtils.singletonRecords("test".getBytes()), leaderEpoch = 0) + } + + /** + * Tests that all internal futures are completed before LogManager.shutdown() returns to the + * caller during error situations. + */ + @Test + def testHandlingExceptionsDuringShutdown(): Unit = { + // We create two directories logDir1 and logDir2 to help effectively test error handling + // during LogManager.shutdown(). 
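+    // The partition directory under logDir1 is deleted before shutdown so that closing log1 fails; the clean
+    // shutdown marker should then be written only for the healthy directory, logDir2.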
+ val logDir1 = TestUtils.tempDir() + val logDir2 = TestUtils.tempDir() + var logManagerForTest: Option[LogManager] = Option.empty + try { + logManagerForTest = Some(createLogManager(Seq(logDir1, logDir2))) + + assertEquals(2, logManagerForTest.get.liveLogDirs.size) + logManagerForTest.get.startup(Set.empty) + + val log1 = logManagerForTest.get.getOrCreateLog(new TopicPartition(name, 0), topicId = None) + val log2 = logManagerForTest.get.getOrCreateLog(new TopicPartition(name, 1), topicId = None) + + val logFile1 = new File(logDir1, name + "-0") + assertTrue(logFile1.exists) + val logFile2 = new File(logDir2, name + "-1") + assertTrue(logFile2.exists) + + log1.appendAsLeader(TestUtils.singletonRecords("test1".getBytes()), leaderEpoch = 0) + log1.takeProducerSnapshot() + log1.appendAsLeader(TestUtils.singletonRecords("test1".getBytes()), leaderEpoch = 0) + + log2.appendAsLeader(TestUtils.singletonRecords("test2".getBytes()), leaderEpoch = 0) + log2.takeProducerSnapshot() + log2.appendAsLeader(TestUtils.singletonRecords("test2".getBytes()), leaderEpoch = 0) + + // This should cause log1.close() to fail during LogManger shutdown sequence. + FileUtils.deleteDirectory(logFile1) + + logManagerForTest.get.shutdown() + + assertFalse(Files.exists(new File(logDir1, LogLoader.CleanShutdownFile).toPath)) + assertTrue(Files.exists(new File(logDir2, LogLoader.CleanShutdownFile).toPath)) + } finally { + logManagerForTest.foreach(manager => manager.liveLogDirs.foreach(Utils.delete)) + } + } + + /** + * Test that getOrCreateLog on a non-existent log creates a new log and that we can append to the new log. + * The LogManager is configured with one invalid log directory which should be marked as offline. + */ + @Test + def testCreateLogWithInvalidLogDir(): Unit = { + // Configure the log dir with the Nul character as the path, which causes dir.getCanonicalPath() to throw an + // IOException. This simulates the scenario where the disk is not properly mounted (which is hard to achieve in + // a unit test) + val dirs = Seq(logDir, new File("\u0000")) + + logManager.shutdown() + logManager = createLogManager(dirs) + logManager.startup(Set.empty) + + val log = logManager.getOrCreateLog(new TopicPartition(name, 0), isNew = true, topicId = None) + val logFile = new File(logDir, name + "-0") + assertTrue(logFile.exists) + log.appendAsLeader(TestUtils.singletonRecords("test".getBytes()), leaderEpoch = 0) + } + + @Test + def testCreateLogWithLogDirFallback(): Unit = { + // Configure a number of directories one level deeper in logDir, + // so they all get cleaned up in tearDown(). + val dirs = (0 to 4) + .map(_.toString) + .map(logDir.toPath.resolve(_).toFile) + + // Create a new LogManager with the configured directories and an overridden createLogDirectory. + logManager.shutdown() + logManager = spy(createLogManager(dirs)) + val brokenDirs = mutable.Set[File]() + doAnswer { invocation => + // The first half of directories tried will fail, the rest goes through. + val logDir = invocation.getArgument[File](0) + if (brokenDirs.contains(logDir) || brokenDirs.size < dirs.length / 2) { + brokenDirs.add(logDir) + Failure(new Throwable("broken dir")) + } else { + invocation.callRealMethod().asInstanceOf[Try[File]] + } + }.when(logManager).createLogDirectory(any(), any()) + logManager.startup(Set.empty) + + // Request creating a new log. + // LogManager should try using all configured log directories until one succeeds. 
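+    // (createLogDirectory is stubbed above so that the first half of the candidate directories fail; getOrCreateLog
+    // should keep falling back until one of the remaining directories accepts the new log.)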
+ logManager.getOrCreateLog(new TopicPartition(name, 0), isNew = true, topicId = None) + + // Verify that half the directories were considered broken, + assertEquals(dirs.length / 2, brokenDirs.size) + + // and that exactly one log file was created, + val containsLogFile: File => Boolean = dir => new File(dir, name + "-0").exists() + assertEquals(1, dirs.count(containsLogFile), "More than one log file created") + + // and that it wasn't created in one of the broken directories. + assertFalse(brokenDirs.exists(containsLogFile)) + } + + /** + * Test that get on a non-existent returns None and no log is created. + */ + @Test + def testGetNonExistentLog(): Unit = { + val log = logManager.getLog(new TopicPartition(name, 0)) + assertEquals(None, log, "No log should be found.") + val logFile = new File(logDir, name + "-0") + assertFalse(logFile.exists) + } + + /** + * Test time-based log cleanup. First append messages, then set the time into the future and run cleanup. + */ + @Test + def testCleanupExpiredSegments(): Unit = { + val log = logManager.getOrCreateLog(new TopicPartition(name, 0), topicId = None) + var offset = 0L + for(_ <- 0 until 200) { + val set = TestUtils.singletonRecords("test".getBytes()) + val info = log.appendAsLeader(set, leaderEpoch = 0) + offset = info.lastOffset + } + assertTrue(log.numberOfSegments > 1, "There should be more than one segment now.") + log.updateHighWatermark(log.logEndOffset) + + log.logSegments.foreach(_.log.file.setLastModified(time.milliseconds)) + + time.sleep(maxLogAgeMs + 1) + assertEquals(1, log.numberOfSegments, "Now there should only be only one segment in the index.") + time.sleep(log.config.fileDeleteDelayMs + 1) + + log.logSegments.foreach(s => { + s.lazyOffsetIndex.get + s.lazyTimeIndex.get + }) + + // there should be a log file, two indexes, one producer snapshot, and the leader epoch checkpoint + assertEquals(log.numberOfSegments * 4 + 1, log.dir.list.length, "Files should have been deleted") + assertEquals(0, readLog(log, offset + 1).records.sizeInBytes, "Should get empty fetch off new log.") + + assertThrows(classOf[OffsetOutOfRangeException], () => readLog(log, 0), () => "Should get exception from fetching earlier.") + // log should still be appendable + log.appendAsLeader(TestUtils.singletonRecords("test".getBytes()), leaderEpoch = 0) + } + + /** + * Test size-based cleanup. Append messages, then run cleanup and check that segments are deleted. 
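+   * (retention.bytes is set to roughly five segments' worth of data, so after the cleanup pass about six segments,
+   * including the active one, are expected to remain.)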
+ */ + @Test + def testCleanupSegmentsToMaintainSize(): Unit = { + val setSize = TestUtils.singletonRecords("test".getBytes()).sizeInBytes + logManager.shutdown() + val segmentBytes = 10 * setSize + val properties = new Properties() + properties.put(LogConfig.SegmentBytesProp, segmentBytes.toString) + properties.put(LogConfig.RetentionBytesProp, (5L * 10L * setSize + 10L).toString) + val configRepository = MockConfigRepository.forTopic(name, properties) + + logManager = createLogManager(configRepository = configRepository) + logManager.startup(Set.empty) + + // create a log + val log = logManager.getOrCreateLog(new TopicPartition(name, 0), topicId = None) + var offset = 0L + + // add a bunch of messages that should be larger than the retentionSize + val numMessages = 200 + for (_ <- 0 until numMessages) { + val set = TestUtils.singletonRecords("test".getBytes()) + val info = log.appendAsLeader(set, leaderEpoch = 0) + offset = info.firstOffset.get.messageOffset + } + + log.updateHighWatermark(log.logEndOffset) + assertEquals(numMessages * setSize / segmentBytes, log.numberOfSegments, "Check we have the expected number of segments.") + + // this cleanup shouldn't find any expired segments but should delete some to reduce size + time.sleep(logManager.InitialTaskDelayMs) + assertEquals(6, log.numberOfSegments, "Now there should be exactly 6 segments") + time.sleep(log.config.fileDeleteDelayMs + 1) + + // there should be a log file, two indexes (the txn index is created lazily), + // and a producer snapshot file per segment, and the leader epoch checkpoint. + assertEquals(log.numberOfSegments * 4 + 1, log.dir.list.length, "Files should have been deleted") + assertEquals(0, readLog(log, offset + 1).records.sizeInBytes, "Should get empty fetch off new log.") + assertThrows(classOf[OffsetOutOfRangeException], () => readLog(log, 0)) + // log should still be appendable + log.appendAsLeader(TestUtils.singletonRecords("test".getBytes()), leaderEpoch = 0) + } + + /** + * Ensures that LogManager doesn't run on logs with cleanup.policy=compact,delete + * LogCleaner.CleanerThread handles all logs where compaction is enabled. + */ + @Test + def testDoesntCleanLogsWithCompactDeletePolicy(): Unit = { + testDoesntCleanLogs(LogConfig.Compact + "," + LogConfig.Delete) + } + + /** + * Ensures that LogManager doesn't run on logs with cleanup.policy=compact + * LogCleaner.CleanerThread handles all logs where compaction is enabled. + */ + @Test + def testDoesntCleanLogsWithCompactPolicy(): Unit = { + testDoesntCleanLogs(LogConfig.Compact) + } + + private def testDoesntCleanLogs(policy: String): Unit = { + logManager.shutdown() + val configRepository = MockConfigRepository.forTopic(name, LogConfig.CleanupPolicyProp, policy) + + logManager = createLogManager(configRepository = configRepository) + val log = logManager.getOrCreateLog(new TopicPartition(name, 0), topicId = None) + var offset = 0L + for (_ <- 0 until 200) { + val set = TestUtils.singletonRecords("test".getBytes(), key="test".getBytes()) + val info = log.appendAsLeader(set, leaderEpoch = 0) + offset = info.lastOffset + } + + val numSegments = log.numberOfSegments + assertTrue(log.numberOfSegments > 1, "There should be more than one segment now.") + + log.logSegments.foreach(_.log.file.setLastModified(time.milliseconds)) + + time.sleep(maxLogAgeMs + 1) + assertEquals(numSegments, log.numberOfSegments, "number of segments shouldn't have changed") + } + + /** + * Test that flush is invoked by the background scheduler thread. 
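+   * (flush.ms is set to 1000 for the topic and the mock clock is advanced past the scheduler's initial task delay,
+   * so lastFlushTime should change without an explicit flush() call.)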
+ */ + @Test + def testTimeBasedFlush(): Unit = { + logManager.shutdown() + val configRepository = MockConfigRepository.forTopic(name, LogConfig.FlushMsProp, "1000") + + logManager = createLogManager(configRepository = configRepository) + logManager.startup(Set.empty) + val log = logManager.getOrCreateLog(new TopicPartition(name, 0), topicId = None) + val lastFlush = log.lastFlushTime + for (_ <- 0 until 200) { + val set = TestUtils.singletonRecords("test".getBytes()) + log.appendAsLeader(set, leaderEpoch = 0) + } + time.sleep(logManager.InitialTaskDelayMs) + assertTrue(lastFlush != log.lastFlushTime, "Time based flush should have been triggered") + } + + /** + * Test that new logs that are created are assigned to the least loaded log directory + */ + @Test + def testLeastLoadedAssignment(): Unit = { + // create a log manager with multiple data directories + val dirs = Seq(TestUtils.tempDir(), + TestUtils.tempDir(), + TestUtils.tempDir()) + logManager.shutdown() + logManager = createLogManager(dirs) + + // verify that logs are always assigned to the least loaded partition + for(partition <- 0 until 20) { + logManager.getOrCreateLog(new TopicPartition("test", partition), topicId = None) + assertEquals(partition + 1, logManager.allLogs.size, "We should have created the right number of logs") + val counts = logManager.allLogs.groupBy(_.dir.getParent).values.map(_.size) + assertTrue(counts.max <= counts.min + 1, "Load should balance evenly") + } + } + + /** + * Test that it is not possible to open two log managers using the same data directory + */ + @Test + def testTwoLogManagersUsingSameDirFails(): Unit = { + assertThrows(classOf[KafkaException], () => createLogManager()) + } + + /** + * Test that recovery points are correctly written out to disk + */ + @Test + def testCheckpointRecoveryPoints(): Unit = { + verifyCheckpointRecovery(Seq(new TopicPartition("test-a", 1), new TopicPartition("test-b", 1)), logManager, logDir) + } + + /** + * Test that recovery points directory checking works with trailing slash + */ + @Test + def testRecoveryDirectoryMappingWithTrailingSlash(): Unit = { + logManager.shutdown() + logManager = TestUtils.createLogManager(logDirs = Seq(new File(TestUtils.tempDir().getAbsolutePath + File.separator))) + logManager.startup(Set.empty) + verifyCheckpointRecovery(Seq(new TopicPartition("test-a", 1)), logManager, logManager.liveLogDirs.head) + } + + /** + * Test that recovery points directory checking works with relative directory + */ + @Test + def testRecoveryDirectoryMappingWithRelativeDirectory(): Unit = { + logManager.shutdown() + logManager = createLogManager(Seq(new File("data", logDir.getName).getAbsoluteFile)) + logManager.startup(Set.empty) + verifyCheckpointRecovery(Seq(new TopicPartition("test-a", 1)), logManager, logManager.liveLogDirs.head) + } + + private def verifyCheckpointRecovery(topicPartitions: Seq[TopicPartition], logManager: LogManager, logDir: File): Unit = { + val logs = topicPartitions.map(logManager.getOrCreateLog(_, topicId = None)) + logs.foreach { log => + for (_ <- 0 until 50) + log.appendAsLeader(TestUtils.singletonRecords("test".getBytes()), leaderEpoch = 0) + + log.flush() + } + + logManager.checkpointLogRecoveryOffsets() + val checkpoints = new OffsetCheckpointFile(new File(logDir, LogManager.RecoveryPointCheckpointFile)).read() + + topicPartitions.zip(logs).foreach { case (tp, log) => + assertEquals(checkpoints(tp), log.recoveryPoint, "Recovery point should equal checkpoint") + } + } + + private def createLogManager(logDirs: Seq[File] = 
Seq(this.logDir), + configRepository: ConfigRepository = new MockConfigRepository): LogManager = { + TestUtils.createLogManager( + defaultConfig = logConfig, + configRepository = configRepository, + logDirs = logDirs, + time = this.time) + } + + @Test + def testFileReferencesAfterAsyncDelete(): Unit = { + val log = logManager.getOrCreateLog(new TopicPartition(name, 0), topicId = None) + val activeSegment = log.activeSegment + val logName = activeSegment.log.file.getName + val indexName = activeSegment.offsetIndex.file.getName + val timeIndexName = activeSegment.timeIndex.file.getName + val txnIndexName = activeSegment.txnIndex.file.getName + val indexFilesOnDiskBeforeDelete = activeSegment.log.file.getParentFile.listFiles.filter(_.getName.endsWith("index")) + + val removedLog = logManager.asyncDelete(new TopicPartition(name, 0)).get + val removedSegment = removedLog.activeSegment + val indexFilesAfterDelete = Seq(removedSegment.lazyOffsetIndex.file, removedSegment.lazyTimeIndex.file, + removedSegment.txnIndex.file) + + assertEquals(new File(removedLog.dir, logName), removedSegment.log.file) + assertEquals(new File(removedLog.dir, indexName), removedSegment.lazyOffsetIndex.file) + assertEquals(new File(removedLog.dir, timeIndexName), removedSegment.lazyTimeIndex.file) + assertEquals(new File(removedLog.dir, txnIndexName), removedSegment.txnIndex.file) + + // Try to detect the case where a new index type was added and we forgot to update the pointer + // This will only catch cases where the index file is created eagerly instead of lazily + indexFilesOnDiskBeforeDelete.foreach { fileBeforeDelete => + val fileInIndex = indexFilesAfterDelete.find(_.getName == fileBeforeDelete.getName) + assertEquals(Some(fileBeforeDelete.getName), fileInIndex.map(_.getName), + s"Could not find index file ${fileBeforeDelete.getName} in indexFilesAfterDelete") + assertNotEquals(fileBeforeDelete.getAbsolutePath, fileInIndex.get.getAbsolutePath, + "File reference was not updated in index") + } + + time.sleep(logManager.InitialTaskDelayMs) + assertTrue(logManager.hasLogsToBeDeleted, "Logs deleted too early") + time.sleep(logManager.currentDefaultConfig.fileDeleteDelayMs - logManager.InitialTaskDelayMs) + assertFalse(logManager.hasLogsToBeDeleted, "Logs not deleted") + } + + @Test + def testCreateAndDeleteOverlyLongTopic(): Unit = { + val invalidTopicName = String.join("", Collections.nCopies(253, "x")) + logManager.getOrCreateLog(new TopicPartition(invalidTopicName, 0), topicId = None) + logManager.asyncDelete(new TopicPartition(invalidTopicName, 0)) + } + + @Test + def testCheckpointForOnlyAffectedLogs(): Unit = { + val tps = Seq( + new TopicPartition("test-a", 0), + new TopicPartition("test-a", 1), + new TopicPartition("test-a", 2), + new TopicPartition("test-b", 0), + new TopicPartition("test-b", 1)) + + val allLogs = tps.map(logManager.getOrCreateLog(_, topicId = None)) + allLogs.foreach { log => + for (_ <- 0 until 50) + log.appendAsLeader(TestUtils.singletonRecords("test".getBytes), leaderEpoch = 0) + log.flush() + } + + logManager.checkpointRecoveryOffsetsInDir(logDir) + + val checkpoints = new OffsetCheckpointFile(new File(logDir, LogManager.RecoveryPointCheckpointFile)).read() + + tps.zip(allLogs).foreach { case (tp, log) => + assertEquals(checkpoints(tp), log.recoveryPoint, + "Recovery point should equal checkpoint") + } + } + + private def readLog(log: UnifiedLog, offset: Long, maxLength: Int = 1024): FetchDataInfo = { + log.read(offset, maxLength, isolation = FetchLogEnd, minOneMessage = true) + } + +
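+  // The config-propagation tests below wrap the LogManager and a MockConfigRepository in Mockito spies:
+  // initializingLog/finishedInitializingLog bracket log initialization, topicConfigUpdated/brokerConfigUpdated
+  // mark pending updates, and the spies verify whether topicConfig is re-read from the repository afterwards.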
/** + * Test that when a topic's configuration is updated while its log is being initialized, + * the config is refreshed once log initialization is finished. + */ + @Test + def testTopicConfigChangeUpdatesLogConfig(): Unit = { + logManager.shutdown() + val spyConfigRepository = spy(new MockConfigRepository) + logManager = createLogManager(configRepository = spyConfigRepository) + val spyLogManager = spy(logManager) + val mockLog = mock(classOf[UnifiedLog]) + + val testTopicOne = "test-topic-one" + val testTopicTwo = "test-topic-two" + val testTopicOnePartition = new TopicPartition(testTopicOne, 1) + val testTopicTwoPartition = new TopicPartition(testTopicTwo, 1) + + spyLogManager.initializingLog(testTopicOnePartition) + spyLogManager.initializingLog(testTopicTwoPartition) + + spyLogManager.topicConfigUpdated(testTopicOne) + + spyLogManager.finishedInitializingLog(testTopicOnePartition, Some(mockLog)) + spyLogManager.finishedInitializingLog(testTopicTwoPartition, Some(mockLog)) + + // testTopicOne configs loaded again due to the update + verify(spyLogManager).initializingLog(ArgumentMatchers.eq(testTopicOnePartition)) + verify(spyLogManager).finishedInitializingLog(ArgumentMatchers.eq(testTopicOnePartition), ArgumentMatchers.any()) + verify(spyConfigRepository, times(1)).topicConfig(testTopicOne) + + // testTopicTwo configs not loaded again since there was no update + verify(spyLogManager).initializingLog(ArgumentMatchers.eq(testTopicTwoPartition)) + verify(spyLogManager).finishedInitializingLog(ArgumentMatchers.eq(testTopicTwoPartition), ArgumentMatchers.any()) + verify(spyConfigRepository, never).topicConfig(testTopicTwo) + } + + /** + * Test that if an error occurs while creating a log, the log manager removes the corresponding + * topic partition from the list of initializing partitions and no configs are retrieved. + */ + @Test + def testConfigChangeGetsCleanedUp(): Unit = { + logManager.shutdown() + val spyConfigRepository = spy(new MockConfigRepository) + logManager = createLogManager(configRepository = spyConfigRepository) + val spyLogManager = spy(logManager) + + val testTopicPartition = new TopicPartition("test-topic", 1) + spyLogManager.initializingLog(testTopicPartition) + spyLogManager.finishedInitializingLog(testTopicPartition, None) + + assertTrue(logManager.partitionsInitializing.isEmpty) + verify(spyConfigRepository, never).topicConfig(testTopicPartition.topic) + } + + /** + * Test that when a broker configuration change happens, all logs in the process of initialization + * pick up the latest config once they finish initializing.
+ */ + @Test + def testBrokerConfigChangeDeliveredToAllLogs(): Unit = { + logManager.shutdown() + val spyConfigRepository = spy(new MockConfigRepository) + logManager = createLogManager(configRepository = spyConfigRepository) + val spyLogManager = spy(logManager) + val mockLog = mock(classOf[UnifiedLog]) + + val testTopicOne = "test-topic-one" + val testTopicTwo = "test-topic-two" + val testTopicOnePartition = new TopicPartition(testTopicOne, 1) + val testTopicTwoPartition = new TopicPartition(testTopicTwo, 1) + + spyLogManager.initializingLog(testTopicOnePartition) + spyLogManager.initializingLog(testTopicTwoPartition) + + spyLogManager.brokerConfigUpdated() + + spyLogManager.finishedInitializingLog(testTopicOnePartition, Some(mockLog)) + spyLogManager.finishedInitializingLog(testTopicTwoPartition, Some(mockLog)) + + verify(spyConfigRepository, times(1)).topicConfig(testTopicOne) + verify(spyConfigRepository, times(1)).topicConfig(testTopicTwo) + } + + /** + * Test when compact is removed that cleaning of the partitions is aborted. + */ + @Test + def testTopicConfigChangeStopCleaningIfCompactIsRemoved(): Unit = { + logManager.shutdown() + logManager = createLogManager(configRepository = new MockConfigRepository) + val spyLogManager = spy(logManager) + + val topic = "topic" + val tp0 = new TopicPartition(topic, 0) + val tp1 = new TopicPartition(topic, 1) + + val oldProperties = new Properties() + oldProperties.put(LogConfig.CleanupPolicyProp, LogConfig.Compact) + val oldLogConfig = LogConfig.fromProps(logConfig.originals, oldProperties) + + val log0 = spyLogManager.getOrCreateLog(tp0, topicId = None) + log0.updateConfig(oldLogConfig) + val log1 = spyLogManager.getOrCreateLog(tp1, topicId = None) + log1.updateConfig(oldLogConfig) + + assertEquals(Set(log0, log1), spyLogManager.logsByTopic(topic).toSet) + + val newProperties = new Properties() + newProperties.put(LogConfig.CleanupPolicyProp, LogConfig.Delete) + + spyLogManager.updateTopicConfig(topic, newProperties) + + assertTrue(log0.config.delete) + assertTrue(log1.config.delete) + assertFalse(log0.config.compact) + assertFalse(log1.config.compact) + + verify(spyLogManager, times(1)).topicConfigUpdated(topic) + verify(spyLogManager, times(1)).abortCleaning(tp0) + verify(spyLogManager, times(1)).abortCleaning(tp1) + } + + /** + * Test even if no log is getting initialized, if config change events are delivered + * things continue to work correctly. This test should not throw. + * + * This makes sure that events can be delivered even when no log is getting initialized. + */ + @Test + def testConfigChangesWithNoLogGettingInitialized(): Unit = { + logManager.brokerConfigUpdated() + logManager.topicConfigUpdated("test-topic") + assertTrue(logManager.partitionsInitializing.isEmpty) + } + + @Test + def testMetricsExistWhenLogIsRecreatedBeforeDeletion(): Unit = { + val topicName = "metric-test" + def logMetrics: mutable.Set[MetricName] = KafkaYammerMetrics.defaultRegistry.allMetrics.keySet.asScala. 
+ filter(metric => metric.getType == "Log" && metric.getScope.contains(topicName)) + + val tp = new TopicPartition(topicName, 0) + val metricTag = s"topic=${tp.topic},partition=${tp.partition}" + + def verifyMetrics(): Unit = { + assertEquals(LogMetricNames.allMetricNames.size, logMetrics.size) + logMetrics.foreach { metric => + assertTrue(metric.getMBeanName.contains(metricTag)) + } + } + + // Create the Log and assert that the metrics are present + logManager.getOrCreateLog(tp, topicId = None) + verifyMetrics() + + // Trigger the deletion and assert that the metrics have been removed + val removedLog = logManager.asyncDelete(tp).get + assertTrue(logMetrics.isEmpty) + + // Recreate the Log and assert that the metrics are present + logManager.getOrCreateLog(tp, topicId = None) + verifyMetrics() + + // Advance time past the file deletion delay and assert that the removed log has been deleted but the metrics + // are still present + time.sleep(logConfig.fileDeleteDelayMs + 1) + assertTrue(removedLog.logSegments.isEmpty) + verifyMetrics() + } + + @Test + def testMetricsAreRemovedWhenMovingCurrentToFutureLog(): Unit = { + val dir1 = TestUtils.tempDir() + val dir2 = TestUtils.tempDir() + logManager = createLogManager(Seq(dir1, dir2)) + logManager.startup(Set.empty) + + val topicName = "future-log" + def logMetrics: mutable.Set[MetricName] = KafkaYammerMetrics.defaultRegistry.allMetrics.keySet.asScala. + filter(metric => metric.getType == "Log" && metric.getScope.contains(topicName)) + + val tp = new TopicPartition(topicName, 0) + val metricTag = s"topic=${tp.topic},partition=${tp.partition}" + + def verifyMetrics(logCount: Int): Unit = { + assertEquals(LogMetricNames.allMetricNames.size * logCount, logMetrics.size) + logMetrics.foreach { metric => + assertTrue(metric.getMBeanName.contains(metricTag)) + } + } + + // Create the current and future logs and verify that metrics are present for both current and future logs + logManager.maybeUpdatePreferredLogDir(tp, dir1.getAbsolutePath) + logManager.getOrCreateLog(tp, topicId = None) + logManager.maybeUpdatePreferredLogDir(tp, dir2.getAbsolutePath) + logManager.getOrCreateLog(tp, isFuture = true, topicId = None) + verifyMetrics(2) + + // Replace the current log with the future one and verify that only one set of metrics are present + logManager.replaceCurrentWithFutureLog(tp) + verifyMetrics(1) + + // Trigger the deletion of the former current directory and verify that one set of metrics is still present + time.sleep(logConfig.fileDeleteDelayMs + 1) + verifyMetrics(1) + } + + @Test + def testWaitForAllToComplete(): Unit = { + var invokedCount = 0 + val success: Future[Boolean] = Mockito.mock(classOf[Future[Boolean]]) + Mockito.when(success.get()).thenAnswer { _ => + invokedCount += 1 + true + } + val failure: Future[Boolean] = Mockito.mock(classOf[Future[Boolean]]) + Mockito.when(failure.get()).thenAnswer{ _ => + invokedCount += 1 + throw new RuntimeException + } + + var failureCount = 0 + // all futures should be evaluated + assertFalse(LogManager.waitForAllToComplete(Seq(success, failure), _ => failureCount += 1)) + assertEquals(2, invokedCount) + assertEquals(1, failureCount) + assertFalse(LogManager.waitForAllToComplete(Seq(failure, success), _ => failureCount += 1)) + assertEquals(4, invokedCount) + assertEquals(2, failureCount) + assertTrue(LogManager.waitForAllToComplete(Seq(success, success), _ => failureCount += 1)) + assertEquals(6, invokedCount) + assertEquals(2, failureCount) + assertFalse(LogManager.waitForAllToComplete(Seq(failure, 
failure), _ => failureCount += 1)) + assertEquals(8, invokedCount) + assertEquals(4, failureCount) + } +} diff --git a/core/src/test/scala/unit/kafka/log/LogSegmentTest.scala b/core/src/test/scala/unit/kafka/log/LogSegmentTest.scala new file mode 100644 index 0000000..9884576 --- /dev/null +++ b/core/src/test/scala/unit/kafka/log/LogSegmentTest.scala @@ -0,0 +1,589 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package kafka.log + +import java.io.File + +import kafka.server.checkpoints.LeaderEpochCheckpoint +import kafka.server.epoch.EpochEntry +import kafka.server.epoch.LeaderEpochFileCache +import kafka.utils.TestUtils +import kafka.utils.TestUtils.checkEquals +import org.apache.kafka.common.TopicPartition +import org.apache.kafka.common.record._ +import org.apache.kafka.common.utils.{MockTime, Time, Utils} +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.{AfterEach, BeforeEach, Test} + +import scala.jdk.CollectionConverters._ +import scala.collection._ +import scala.collection.mutable.ArrayBuffer + +class LogSegmentTest { + + val topicPartition = new TopicPartition("topic", 0) + val segments = mutable.ArrayBuffer[LogSegment]() + var logDir: File = _ + + /* create a segment with the given base offset */ + def createSegment(offset: Long, + indexIntervalBytes: Int = 10, + time: Time = Time.SYSTEM): LogSegment = { + val seg = LogTestUtils.createSegment(offset, logDir, indexIntervalBytes, time) + segments += seg + seg + } + + /* create a ByteBufferMessageSet for the given messages starting from the given offset */ + def records(offset: Long, records: String*): MemoryRecords = { + MemoryRecords.withRecords(RecordBatch.MAGIC_VALUE_V1, offset, CompressionType.NONE, TimestampType.CREATE_TIME, + records.map { s => new SimpleRecord(offset * 10, s.getBytes) }: _*) + } + + @BeforeEach + def setup(): Unit = { + logDir = TestUtils.tempDir() + } + + @AfterEach + def teardown(): Unit = { + segments.foreach(_.close()) + Utils.delete(logDir) + } + + /** + * A read on an empty log segment should return null + */ + @Test + def testReadOnEmptySegment(): Unit = { + val seg = createSegment(40) + val read = seg.read(startOffset = 40, maxSize = 300) + assertNull(read, "Read beyond the last offset in the segment should be null") + } + + /** + * Reading from before the first offset in the segment should return messages + * beginning with the first message in the segment + */ + @Test + def testReadBeforeFirstOffset(): Unit = { + val seg = createSegment(40) + val ms = records(50, "hello", "there", "little", "bee") + seg.append(53, RecordBatch.NO_TIMESTAMP, -1L, ms) + val read = seg.read(startOffset = 41, maxSize = 300).records + checkEquals(ms.records.iterator, read.records.iterator) + } + + /** + * If we read from an offset beyond the 
last offset in the segment we should get null + */ + @Test + def testReadAfterLast(): Unit = { + val seg = createSegment(40) + val ms = records(50, "hello", "there") + seg.append(51, RecordBatch.NO_TIMESTAMP, -1L, ms) + val read = seg.read(startOffset = 52, maxSize = 200) + assertNull(read, "Read beyond the last offset in the segment should give null") + } + + /** + * If we read from an offset which doesn't exist we should get a message set beginning + * with the least offset greater than the given startOffset. + */ + @Test + def testReadFromGap(): Unit = { + val seg = createSegment(40) + val ms = records(50, "hello", "there") + seg.append(51, RecordBatch.NO_TIMESTAMP, -1L, ms) + val ms2 = records(60, "alpha", "beta") + seg.append(61, RecordBatch.NO_TIMESTAMP, -1L, ms2) + val read = seg.read(startOffset = 55, maxSize = 200) + checkEquals(ms2.records.iterator, read.records.records.iterator) + } + + /** + * In a loop append two messages then truncate off the second of those messages and check that we can read + * the first but not the second message. + */ + @Test + def testTruncate(): Unit = { + val seg = createSegment(40) + var offset = 40 + for (_ <- 0 until 30) { + val ms1 = records(offset, "hello") + seg.append(offset, RecordBatch.NO_TIMESTAMP, -1L, ms1) + val ms2 = records(offset + 1, "hello") + seg.append(offset + 1, RecordBatch.NO_TIMESTAMP, -1L, ms2) + // check that we can read back both messages + val read = seg.read(offset, 10000) + assertEquals(List(ms1.records.iterator.next(), ms2.records.iterator.next()), read.records.records.asScala.toList) + // now truncate off the last message + seg.truncateTo(offset + 1) + val read2 = seg.read(offset, 10000) + assertEquals(1, read2.records.records.asScala.size) + checkEquals(ms1.records.iterator, read2.records.records.iterator) + offset += 1 + } + } + + @Test + def testTruncateEmptySegment(): Unit = { + // This tests the scenario in which the follower truncates to an empty segment. 
In this + // case we must ensure that the index is resized so that the log segment is not mistakenly + // rolled due to a full index + + val maxSegmentMs = 300000 + val time = new MockTime + val seg = createSegment(0, time = time) + // Force load indexes before closing the segment + seg.timeIndex + seg.offsetIndex + seg.close() + + val reopened = createSegment(0, time = time) + assertEquals(0, seg.timeIndex.sizeInBytes) + assertEquals(0, seg.offsetIndex.sizeInBytes) + + time.sleep(500) + reopened.truncateTo(57) + assertEquals(0, reopened.timeWaitedForRoll(time.milliseconds(), RecordBatch.NO_TIMESTAMP)) + assertFalse(reopened.timeIndex.isFull) + assertFalse(reopened.offsetIndex.isFull) + + var rollParams = RollParams(maxSegmentMs, maxSegmentBytes = Int.MaxValue, RecordBatch.NO_TIMESTAMP, + maxOffsetInMessages = 100L, messagesSize = 1024, time.milliseconds()) + assertFalse(reopened.shouldRoll(rollParams)) + + // The segment should not be rolled even if maxSegmentMs has been exceeded + time.sleep(maxSegmentMs + 1) + assertEquals(maxSegmentMs + 1, reopened.timeWaitedForRoll(time.milliseconds(), RecordBatch.NO_TIMESTAMP)) + rollParams = RollParams(maxSegmentMs, maxSegmentBytes = Int.MaxValue, RecordBatch.NO_TIMESTAMP, + maxOffsetInMessages = 100L, messagesSize = 1024, time.milliseconds()) + assertFalse(reopened.shouldRoll(rollParams)) + + // But we should still roll the segment if we cannot fit the next offset + rollParams = RollParams(maxSegmentMs, maxSegmentBytes = Int.MaxValue, RecordBatch.NO_TIMESTAMP, + maxOffsetInMessages = Int.MaxValue.toLong + 200L, messagesSize = 1024, time.milliseconds()) + assertTrue(reopened.shouldRoll(rollParams)) + } + + @Test + def testReloadLargestTimestampAndNextOffsetAfterTruncation(): Unit = { + val numMessages = 30 + val seg = createSegment(40, 2 * records(0, "hello").sizeInBytes - 1) + var offset = 40 + for (_ <- 0 until numMessages) { + seg.append(offset, offset, offset, records(offset, "hello")) + offset += 1 + } + assertEquals(offset, seg.readNextOffset) + + val expectedNumEntries = numMessages / 2 - 1 + assertEquals(expectedNumEntries, seg.timeIndex.entries, s"Should have $expectedNumEntries time indexes") + + seg.truncateTo(41) + assertEquals(0, seg.timeIndex.entries, s"Should have 0 time indexes") + assertEquals(400L, seg.largestTimestamp, s"Largest timestamp should be 400") + assertEquals(41, seg.readNextOffset) + } + + /** + * Test truncating the whole segment, and check that we can reappend with the original offset. + */ + @Test + def testTruncateFull(): Unit = { + // test the case where we fully truncate the log + val time = new MockTime + val seg = createSegment(40, time = time) + seg.append(41, RecordBatch.NO_TIMESTAMP, -1L, records(40, "hello", "there")) + + // If the segment is empty after truncation, the create time should be reset + time.sleep(500) + assertEquals(500, seg.timeWaitedForRoll(time.milliseconds(), RecordBatch.NO_TIMESTAMP)) + + seg.truncateTo(0) + assertEquals(0, seg.timeWaitedForRoll(time.milliseconds(), RecordBatch.NO_TIMESTAMP)) + assertFalse(seg.timeIndex.isFull) + assertFalse(seg.offsetIndex.isFull) + assertNull(seg.read(0, 1024), "Segment should be empty.") + + seg.append(41, RecordBatch.NO_TIMESTAMP, -1L, records(40, "hello", "there")) + } + + /** + * Append messages with timestamp and search message by timestamp. 
+ */ + @Test + def testFindOffsetByTimestamp(): Unit = { + val messageSize = records(0, s"msg00").sizeInBytes + val seg = createSegment(40, messageSize * 2 - 1) + // Produce some messages + for (i <- 40 until 50) + seg.append(i, i * 10, i, records(i, s"msg$i")) + + assertEquals(490, seg.largestTimestamp) + // Search for an indexed timestamp + assertEquals(42, seg.findOffsetByTimestamp(420).get.offset) + assertEquals(43, seg.findOffsetByTimestamp(421).get.offset) + // Search for an un-indexed timestamp + assertEquals(43, seg.findOffsetByTimestamp(430).get.offset) + assertEquals(44, seg.findOffsetByTimestamp(431).get.offset) + // Search beyond the last timestamp + assertEquals(None, seg.findOffsetByTimestamp(491)) + // Search before the first indexed timestamp + assertEquals(41, seg.findOffsetByTimestamp(401).get.offset) + // Search before the first timestamp + assertEquals(40, seg.findOffsetByTimestamp(399).get.offset) + } + + /** + * Test that offsets are assigned sequentially and that the nextOffset variable is incremented + */ + @Test + def testNextOffsetCalculation(): Unit = { + val seg = createSegment(40) + assertEquals(40, seg.readNextOffset) + seg.append(52, RecordBatch.NO_TIMESTAMP, -1L, records(50, "hello", "there", "you")) + assertEquals(53, seg.readNextOffset) + } + + /** + * Test that we can change the file suffixes for the log and index files + */ + @Test + def testChangeFileSuffixes(): Unit = { + val seg = createSegment(40) + val logFile = seg.log.file + val indexFile = seg.lazyOffsetIndex.file + val timeIndexFile = seg.lazyTimeIndex.file + // Ensure that files for offset and time indices have not been created eagerly. + assertFalse(seg.lazyOffsetIndex.file.exists) + assertFalse(seg.lazyTimeIndex.file.exists) + seg.changeFileSuffixes("", ".deleted") + // Ensure that attempt to change suffixes for non-existing offset and time indices does not create new files. + assertFalse(seg.lazyOffsetIndex.file.exists) + assertFalse(seg.lazyTimeIndex.file.exists) + // Ensure that file names are updated accordingly. + assertEquals(logFile.getAbsolutePath + ".deleted", seg.log.file.getAbsolutePath) + assertEquals(indexFile.getAbsolutePath + ".deleted", seg.lazyOffsetIndex.file.getAbsolutePath) + assertEquals(timeIndexFile.getAbsolutePath + ".deleted", seg.lazyTimeIndex.file.getAbsolutePath) + assertTrue(seg.log.file.exists) + // Ensure lazy creation of offset index file upon accessing it. + seg.lazyOffsetIndex.get + assertTrue(seg.lazyOffsetIndex.file.exists) + // Ensure lazy creation of time index file upon accessing it. + seg.lazyTimeIndex.get + assertTrue(seg.lazyTimeIndex.file.exists) + } + + /** + * Create a segment with some data and an index. Then corrupt the index, + * and recover the segment, the entries should all be readable. 
+ */ + @Test + def testRecoveryFixesCorruptIndex(): Unit = { + val seg = createSegment(0) + for(i <- 0 until 100) + seg.append(i, RecordBatch.NO_TIMESTAMP, -1L, records(i, i.toString)) + val indexFile = seg.lazyOffsetIndex.file + TestUtils.writeNonsenseToFile(indexFile, 5, indexFile.length.toInt) + seg.recover(new ProducerStateManager(topicPartition, logDir)) + for(i <- 0 until 100) { + val records = seg.read(i, 1, minOneMessage = true).records.records + assertEquals(i, records.iterator.next().offset) + } + } + + @Test + def testRecoverTransactionIndex(): Unit = { + val segment = createSegment(100) + val producerEpoch = 0.toShort + val partitionLeaderEpoch = 15 + val sequence = 100 + + val pid1 = 5L + val pid2 = 10L + + // append transactional records from pid1 + segment.append(largestOffset = 101L, largestTimestamp = RecordBatch.NO_TIMESTAMP, + shallowOffsetOfMaxTimestamp = 100L, records = MemoryRecords.withTransactionalRecords(100L, CompressionType.NONE, + pid1, producerEpoch, sequence, partitionLeaderEpoch, new SimpleRecord("a".getBytes), new SimpleRecord("b".getBytes))) + + // append transactional records from pid2 + segment.append(largestOffset = 103L, largestTimestamp = RecordBatch.NO_TIMESTAMP, + shallowOffsetOfMaxTimestamp = 102L, records = MemoryRecords.withTransactionalRecords(102L, CompressionType.NONE, + pid2, producerEpoch, sequence, partitionLeaderEpoch, new SimpleRecord("a".getBytes), new SimpleRecord("b".getBytes))) + + // append non-transactional records + segment.append(largestOffset = 105L, largestTimestamp = RecordBatch.NO_TIMESTAMP, + shallowOffsetOfMaxTimestamp = 104L, records = MemoryRecords.withRecords(104L, CompressionType.NONE, + partitionLeaderEpoch, new SimpleRecord("a".getBytes), new SimpleRecord("b".getBytes))) + + // abort the transaction from pid2 (note LSO should be 100L since the txn from pid1 has not completed) + segment.append(largestOffset = 106L, largestTimestamp = RecordBatch.NO_TIMESTAMP, + shallowOffsetOfMaxTimestamp = 106L, records = endTxnRecords(ControlRecordType.ABORT, pid2, producerEpoch, offset = 106L)) + + // commit the transaction from pid1 + segment.append(largestOffset = 107L, largestTimestamp = RecordBatch.NO_TIMESTAMP, + shallowOffsetOfMaxTimestamp = 107L, records = endTxnRecords(ControlRecordType.COMMIT, pid1, producerEpoch, offset = 107L)) + + var stateManager = new ProducerStateManager(topicPartition, logDir) + segment.recover(stateManager) + assertEquals(108L, stateManager.mapEndOffset) + + + var abortedTxns = segment.txnIndex.allAbortedTxns + assertEquals(1, abortedTxns.size) + var abortedTxn = abortedTxns.head + assertEquals(pid2, abortedTxn.producerId) + assertEquals(102L, abortedTxn.firstOffset) + assertEquals(106L, abortedTxn.lastOffset) + assertEquals(100L, abortedTxn.lastStableOffset) + + // recover again, but this time assuming the transaction from pid2 began on a previous segment + stateManager = new ProducerStateManager(topicPartition, logDir) + stateManager.loadProducerEntry(new ProducerStateEntry(pid2, + mutable.Queue[BatchMetadata](BatchMetadata(10, 10L, 5, RecordBatch.NO_TIMESTAMP)), producerEpoch, + 0, RecordBatch.NO_TIMESTAMP, Some(75L))) + segment.recover(stateManager) + assertEquals(108L, stateManager.mapEndOffset) + + abortedTxns = segment.txnIndex.allAbortedTxns + assertEquals(1, abortedTxns.size) + abortedTxn = abortedTxns.head + assertEquals(pid2, abortedTxn.producerId) + assertEquals(75L, abortedTxn.firstOffset) + assertEquals(106L, abortedTxn.lastOffset) + assertEquals(100L, abortedTxn.lastStableOffset) + } + 
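+  // The recovery test below supplies an in-memory LeaderEpochCheckpoint stub so that the epoch entries
+  // rebuilt by LogSegment.recover can be asserted directly, without writing a checkpoint file to disk.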
+ /** + * Create a segment with some data, then recover the segment. + * The epoch cache entries should reflect the segment. + */ + @Test + def testRecoveryRebuildsEpochCache(): Unit = { + val seg = createSegment(0) + + val checkpoint: LeaderEpochCheckpoint = new LeaderEpochCheckpoint { + private var epochs = Seq.empty[EpochEntry] + + override def write(epochs: Iterable[EpochEntry]): Unit = { + this.epochs = epochs.toVector + } + + override def read(): Seq[EpochEntry] = this.epochs + } + + val cache = new LeaderEpochFileCache(topicPartition, checkpoint) + seg.append(largestOffset = 105L, largestTimestamp = RecordBatch.NO_TIMESTAMP, + shallowOffsetOfMaxTimestamp = 104L, records = MemoryRecords.withRecords(104L, CompressionType.NONE, 0, + new SimpleRecord("a".getBytes), new SimpleRecord("b".getBytes))) + + seg.append(largestOffset = 107L, largestTimestamp = RecordBatch.NO_TIMESTAMP, + shallowOffsetOfMaxTimestamp = 106L, records = MemoryRecords.withRecords(106L, CompressionType.NONE, 1, + new SimpleRecord("a".getBytes), new SimpleRecord("b".getBytes))) + + seg.append(largestOffset = 109L, largestTimestamp = RecordBatch.NO_TIMESTAMP, + shallowOffsetOfMaxTimestamp = 108L, records = MemoryRecords.withRecords(108L, CompressionType.NONE, 1, + new SimpleRecord("a".getBytes), new SimpleRecord("b".getBytes))) + + seg.append(largestOffset = 111L, largestTimestamp = RecordBatch.NO_TIMESTAMP, + shallowOffsetOfMaxTimestamp = 110, records = MemoryRecords.withRecords(110L, CompressionType.NONE, 2, + new SimpleRecord("a".getBytes), new SimpleRecord("b".getBytes))) + + seg.recover(new ProducerStateManager(topicPartition, logDir), Some(cache)) + assertEquals(ArrayBuffer(EpochEntry(epoch = 0, startOffset = 104L), + EpochEntry(epoch = 1, startOffset = 106), + EpochEntry(epoch = 2, startOffset = 110)), + cache.epochEntries) + } + + private def endTxnRecords(controlRecordType: ControlRecordType, + producerId: Long, + producerEpoch: Short, + offset: Long, + partitionLeaderEpoch: Int = 0, + coordinatorEpoch: Int = 0, + timestamp: Long = RecordBatch.NO_TIMESTAMP): MemoryRecords = { + val marker = new EndTransactionMarker(controlRecordType, coordinatorEpoch) + MemoryRecords.withEndTransactionMarker(offset, timestamp, partitionLeaderEpoch, producerId, producerEpoch, marker) + } + + /** + * Create a segment with some data and an index. Then corrupt the index, + * and recover the segment, the entries should all be readable. + */ + @Test + def testRecoveryFixesCorruptTimeIndex(): Unit = { + val seg = createSegment(0) + for(i <- 0 until 100) + seg.append(i, i * 10, i, records(i, i.toString)) + val timeIndexFile = seg.lazyTimeIndex.file + TestUtils.writeNonsenseToFile(timeIndexFile, 5, timeIndexFile.length.toInt) + seg.recover(new ProducerStateManager(topicPartition, logDir)) + for(i <- 0 until 100) { + assertEquals(i, seg.findOffsetByTimestamp(i * 10).get.offset) + if (i < 99) + assertEquals(i + 1, seg.findOffsetByTimestamp(i * 10 + 1).get.offset) + } + } + + /** + * Randomly corrupt a log a number of times and attempt recovery. 
+ */ + @Test + def testRecoveryWithCorruptMessage(): Unit = { + val messagesAppended = 20 + for (_ <- 0 until 10) { + val seg = createSegment(0) + for (i <- 0 until messagesAppended) + seg.append(i, RecordBatch.NO_TIMESTAMP, -1L, records(i, i.toString)) + val offsetToBeginCorruption = TestUtils.random.nextInt(messagesAppended) + // start corrupting somewhere in the middle of the chosen record all the way to the end + + val recordPosition = seg.log.searchForOffsetWithSize(offsetToBeginCorruption, 0) + val position = recordPosition.position + TestUtils.random.nextInt(15) + TestUtils.writeNonsenseToFile(seg.log.file, position, (seg.log.file.length - position).toInt) + seg.recover(new ProducerStateManager(topicPartition, logDir)) + assertEquals((0 until offsetToBeginCorruption).toList, seg.log.batches.asScala.map(_.lastOffset).toList, + "Should have truncated off bad messages.") + seg.deleteIfExists() + } + } + + private def createSegment(baseOffset: Long, fileAlreadyExists: Boolean, initFileSize: Int, preallocate: Boolean): LogSegment = { + val tempDir = TestUtils.tempDir() + val logConfig = LogConfig(Map( + LogConfig.IndexIntervalBytesProp -> 10, + LogConfig.SegmentIndexBytesProp -> 1000, + LogConfig.SegmentJitterMsProp -> 0 + ).asJava) + val seg = LogSegment.open(tempDir, baseOffset, logConfig, Time.SYSTEM, fileAlreadyExists = fileAlreadyExists, + initFileSize = initFileSize, preallocate = preallocate) + segments += seg + seg + } + + /* create a segment with pre allocate, put message to it and verify */ + @Test + def testCreateWithInitFileSizeAppendMessage(): Unit = { + val seg = createSegment(40, false, 512*1024*1024, true) + val ms = records(50, "hello", "there") + seg.append(51, RecordBatch.NO_TIMESTAMP, -1L, ms) + val ms2 = records(60, "alpha", "beta") + seg.append(61, RecordBatch.NO_TIMESTAMP, -1L, ms2) + val read = seg.read(startOffset = 55, maxSize = 200) + checkEquals(ms2.records.iterator, read.records.records.iterator) + } + + /* create a segment with pre allocate and clearly shut down*/ + @Test + def testCreateWithInitFileSizeClearShutdown(): Unit = { + val tempDir = TestUtils.tempDir() + val logConfig = LogConfig(Map( + LogConfig.IndexIntervalBytesProp -> 10, + LogConfig.SegmentIndexBytesProp -> 1000, + LogConfig.SegmentJitterMsProp -> 0 + ).asJava) + + val seg = LogSegment.open(tempDir, baseOffset = 40, logConfig, Time.SYSTEM, + initFileSize = 512 * 1024 * 1024, preallocate = true) + + val ms = records(50, "hello", "there") + seg.append(51, RecordBatch.NO_TIMESTAMP, -1L, ms) + val ms2 = records(60, "alpha", "beta") + seg.append(61, RecordBatch.NO_TIMESTAMP, -1L, ms2) + val read = seg.read(startOffset = 55, maxSize = 200) + checkEquals(ms2.records.iterator, read.records.records.iterator) + val oldSize = seg.log.sizeInBytes() + val oldPosition = seg.log.channel.position + val oldFileSize = seg.log.file.length + assertEquals(512*1024*1024, oldFileSize) + seg.close() + //After close, file should be trimmed + assertEquals(oldSize, seg.log.file.length) + + val segReopen = LogSegment.open(tempDir, baseOffset = 40, logConfig, Time.SYSTEM, fileAlreadyExists = true, + initFileSize = 512 * 1024 * 1024, preallocate = true) + segments += segReopen + + val readAgain = segReopen.read(startOffset = 55, maxSize = 200) + checkEquals(ms2.records.iterator, readAgain.records.records.iterator) + val size = segReopen.log.sizeInBytes() + val position = segReopen.log.channel.position + val fileSize = segReopen.log.file.length + assertEquals(oldPosition, position) + assertEquals(oldSize, size) + 
assertEquals(size, fileSize) + } + + @Test + def shouldTruncateEvenIfOffsetPointsToAGapInTheLog(): Unit = { + val seg = createSegment(40) + val offset = 40 + + def records(offset: Long, record: String): MemoryRecords = + MemoryRecords.withRecords(RecordBatch.MAGIC_VALUE_V2, offset, CompressionType.NONE, TimestampType.CREATE_TIME, + new SimpleRecord(offset * 1000, record.getBytes)) + + //Given two messages with a gap between them (e.g. mid offset compacted away) + val ms1 = records(offset, "first message") + seg.append(offset, RecordBatch.NO_TIMESTAMP, -1L, ms1) + val ms2 = records(offset + 3, "message after gap") + seg.append(offset + 3, RecordBatch.NO_TIMESTAMP, -1L, ms2) + + // When we truncate to an offset without a corresponding log entry + seg.truncateTo(offset + 1) + + //Then we should still truncate the record that was present (i.e. offset + 3 is gone) + val log = seg.read(offset, 10000) + assertEquals(offset, log.records.batches.iterator.next().baseOffset()) + assertEquals(1, log.records.batches.asScala.size) + } + + @Test + def testAppendFromFile(): Unit = { + def records(offset: Long, size: Int): MemoryRecords = + MemoryRecords.withRecords(RecordBatch.MAGIC_VALUE_V2, offset, CompressionType.NONE, TimestampType.CREATE_TIME, + new SimpleRecord(new Array[Byte](size))) + + // create a log file in a separate directory to avoid conflicting with created segments + val tempDir = TestUtils.tempDir() + val fileRecords = FileRecords.open(UnifiedLog.logFile(tempDir, 0)) + + // Simulate a scenario where we have a single log with an offset range exceeding Int.MaxValue + fileRecords.append(records(0, 1024)) + fileRecords.append(records(500, 1024 * 1024 + 1)) + val sizeBeforeOverflow = fileRecords.sizeInBytes() + fileRecords.append(records(Int.MaxValue + 5L, 1024)) + val sizeAfterOverflow = fileRecords.sizeInBytes() + + val segment = createSegment(0) + val bytesAppended = segment.appendFromFile(fileRecords, 0) + assertEquals(sizeBeforeOverflow, bytesAppended) + assertEquals(sizeBeforeOverflow, segment.size) + + val overflowSegment = createSegment(Int.MaxValue) + val overflowBytesAppended = overflowSegment.appendFromFile(fileRecords, sizeBeforeOverflow) + assertEquals(sizeAfterOverflow - sizeBeforeOverflow, overflowBytesAppended) + assertEquals(overflowBytesAppended, overflowSegment.size) + + Utils.delete(tempDir) + } + +} diff --git a/core/src/test/scala/unit/kafka/log/LogSegmentsTest.scala b/core/src/test/scala/unit/kafka/log/LogSegmentsTest.scala new file mode 100644 index 0000000..7e01e3b --- /dev/null +++ b/core/src/test/scala/unit/kafka/log/LogSegmentsTest.scala @@ -0,0 +1,240 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package kafka.log + +import java.io.File + +import kafka.utils.TestUtils +import org.apache.kafka.common.TopicPartition +import org.apache.kafka.common.utils.{Time, Utils} +import org.easymock.EasyMock +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.{AfterEach, BeforeEach, Test} + +class LogSegmentsTest { + + val topicPartition = new TopicPartition("topic", 0) + var logDir: File = _ + + /* create a segment with the given base offset */ + private def createSegment(offset: Long, + indexIntervalBytes: Int = 10, + time: Time = Time.SYSTEM): LogSegment = { + LogTestUtils.createSegment(offset, logDir, indexIntervalBytes, time) + } + + @BeforeEach + def setup(): Unit = { + logDir = TestUtils.tempDir() + } + + @AfterEach + def teardown(): Unit = { + Utils.delete(logDir) + } + + private def assertEntry(segment: LogSegment, tested: java.util.Map.Entry[Long, LogSegment]): Unit = { + assertEquals(segment.baseOffset, tested.getKey()) + assertEquals(segment, tested.getValue()) + } + + @Test + def testBasicOperations(): Unit = { + val segments = new LogSegments(topicPartition) + assertTrue(segments.isEmpty) + assertFalse(segments.nonEmpty) + + val offset1 = 40 + val seg1 = createSegment(offset1) + val offset2 = 80 + val seg2 = createSegment(offset2) + val seg3 = createSegment(offset1) + + // Add seg1 + segments.add(seg1) + assertFalse(segments.isEmpty) + assertTrue(segments.nonEmpty) + assertEquals(1, segments.numberOfSegments) + assertTrue(segments.contains(offset1)) + assertEquals(Some(seg1), segments.get(offset1)) + + // Add seg2 + segments.add(seg2) + assertFalse(segments.isEmpty) + assertTrue(segments.nonEmpty) + assertEquals(2, segments.numberOfSegments) + assertTrue(segments.contains(offset2)) + assertEquals(Some(seg2), segments.get(offset2)) + + // Replace seg1 with seg3 + segments.add(seg3) + assertFalse(segments.isEmpty) + assertTrue(segments.nonEmpty) + assertEquals(2, segments.numberOfSegments) + assertTrue(segments.contains(offset1)) + assertEquals(Some(seg3), segments.get(offset1)) + + // Remove seg2 + segments.remove(offset2) + assertFalse(segments.isEmpty) + assertTrue(segments.nonEmpty) + assertEquals(1, segments.numberOfSegments) + assertFalse(segments.contains(offset2)) + + // Clear all segments including seg3 + segments.clear() + assertTrue(segments.isEmpty) + assertFalse(segments.nonEmpty) + assertEquals(0, segments.numberOfSegments) + assertFalse(segments.contains(offset1)) + + segments.close() + } + + @Test + def testSegmentAccess(): Unit = { + val segments = new LogSegments(topicPartition) + + val offset1 = 1 + val seg1 = createSegment(offset1) + val offset2 = 2 + val seg2 = createSegment(offset2) + val offset3 = 3 + val seg3 = createSegment(offset3) + val offset4 = 4 + val seg4 = createSegment(offset4) + + // Test firstEntry, lastEntry + List(seg1, seg2, seg3, seg4).foreach { + seg => + segments.add(seg) + assertEntry(seg1, segments.firstEntry.get) + assertEquals(Some(seg1), segments.firstSegment) + assertEntry(seg, segments.lastEntry.get) + assertEquals(Some(seg), segments.lastSegment) + } + + // Test baseOffsets + assertEquals(Seq(offset1, offset2, offset3, offset4), segments.baseOffsets) + + // Test values + assertEquals(Seq(seg1, seg2, seg3, seg4), segments.values.toSeq) + + // Test values(to, from) + assertThrows(classOf[IllegalArgumentException], () => segments.values(2, 1)) + assertEquals(Seq(), segments.values(1, 1).toSeq) + assertEquals(Seq(seg1), segments.values(1, 2).toSeq) + assertEquals(Seq(seg1, seg2), segments.values(1, 3).toSeq) + 
assertEquals(Seq(seg1, seg2, seg3), segments.values(1, 4).toSeq) + assertEquals(Seq(seg2, seg3), segments.values(2, 4).toSeq) + assertEquals(Seq(seg3), segments.values(3, 4).toSeq) + assertEquals(Seq(), segments.values(4, 4).toSeq) + assertEquals(Seq(seg4), segments.values(4, 5).toSeq) + + segments.close() + } + + @Test + def testClosestMatchOperations(): Unit = { + val segments = new LogSegments(topicPartition) + + val seg1 = createSegment(1) + val seg2 = createSegment(3) + val seg3 = createSegment(5) + val seg4 = createSegment(7) + + List(seg1, seg2, seg3, seg4).foreach(segments.add) + + // Test floorSegment + assertEquals(Some(seg1), segments.floorSegment(2)) + assertEquals(Some(seg2), segments.floorSegment(3)) + + // Test lowerSegment + assertEquals(Some(seg1), segments.lowerSegment(3)) + assertEquals(Some(seg2), segments.lowerSegment(4)) + + // Test higherSegment, higherEntry + assertEquals(Some(seg3), segments.higherSegment(4)) + assertEntry(seg3, segments.higherEntry(4).get) + assertEquals(Some(seg4), segments.higherSegment(5)) + assertEntry(seg4, segments.higherEntry(5).get) + + segments.close() + } + + @Test + def testHigherSegments(): Unit = { + val segments = new LogSegments(topicPartition) + + val seg1 = createSegment(1) + val seg2 = createSegment(3) + val seg3 = createSegment(5) + val seg4 = createSegment(7) + val seg5 = createSegment(9) + + List(seg1, seg2, seg3, seg4, seg5).foreach(segments.add) + + // higherSegments(0) should return all segments in order + { + val iterator = segments.higherSegments(0).iterator + List(seg1, seg2, seg3, seg4, seg5).foreach { + segment => + assertTrue(iterator.hasNext) + assertEquals(segment, iterator.next()) + } + assertFalse(iterator.hasNext) + } + + // higherSegments(1) should return all segments in order except seg1 + { + val iterator = segments.higherSegments(1).iterator + List(seg2, seg3, seg4, seg5).foreach { + segment => + assertTrue(iterator.hasNext) + assertEquals(segment, iterator.next()) + } + assertFalse(iterator.hasNext) + } + + // higherSegments(8) should return only seg5 + { + val iterator = segments.higherSegments(8).iterator + assertTrue(iterator.hasNext) + assertEquals(seg5, iterator.next()) + assertFalse(iterator.hasNext) + } + + // higherSegments(9) should return no segments + { + val iterator = segments.higherSegments(9).iterator + assertFalse(iterator.hasNext) + } + } + + @Test + def testSizeForLargeLogs(): Unit = { + val largeSize = Int.MaxValue.toLong * 2 + val logSegment: LogSegment = EasyMock.createMock(classOf[LogSegment]) + + EasyMock.expect(logSegment.size).andReturn(Int.MaxValue).anyTimes + EasyMock.replay(logSegment) + + assertEquals(Int.MaxValue, LogSegments.sizeInBytes(Seq(logSegment))) + assertEquals(largeSize, LogSegments.sizeInBytes(Seq(logSegment, logSegment))) + assertTrue(UnifiedLog.sizeInBytes(Seq(logSegment, logSegment)) > Int.MaxValue) + } +} diff --git a/core/src/test/scala/unit/kafka/log/LogTestUtils.scala b/core/src/test/scala/unit/kafka/log/LogTestUtils.scala new file mode 100644 index 0000000..1f32ed8 --- /dev/null +++ b/core/src/test/scala/unit/kafka/log/LogTestUtils.scala @@ -0,0 +1,275 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.log + +import java.io.File +import java.util.Properties +import kafka.server.checkpoints.LeaderEpochCheckpointFile +import kafka.server.{BrokerTopicStats, FetchDataInfo, FetchIsolation, FetchLogEnd, LogDirFailureChannel} +import kafka.utils.{Scheduler, TestUtils} +import org.apache.kafka.common.Uuid +import org.apache.kafka.common.record.{CompressionType, ControlRecordType, EndTransactionMarker, FileRecords, MemoryRecords, RecordBatch, SimpleRecord} +import org.apache.kafka.common.utils.{Time, Utils} +import org.junit.jupiter.api.Assertions.{assertEquals, assertFalse} + +import scala.collection.Iterable +import scala.jdk.CollectionConverters._ + +object LogTestUtils { + /** + * Create a segment with the given base offset + */ + def createSegment(offset: Long, + logDir: File, + indexIntervalBytes: Int = 10, + time: Time = Time.SYSTEM): LogSegment = { + val ms = FileRecords.open(UnifiedLog.logFile(logDir, offset)) + val idx = LazyIndex.forOffset(UnifiedLog.offsetIndexFile(logDir, offset), offset, maxIndexSize = 1000) + val timeIdx = LazyIndex.forTime(UnifiedLog.timeIndexFile(logDir, offset), offset, maxIndexSize = 1500) + val txnIndex = new TransactionIndex(offset, UnifiedLog.transactionIndexFile(logDir, offset)) + + new LogSegment(ms, idx, timeIdx, txnIndex, offset, indexIntervalBytes, 0, time) + } + + def createLogConfig(segmentMs: Long = Defaults.SegmentMs, + segmentBytes: Int = Defaults.SegmentSize, + retentionMs: Long = Defaults.RetentionMs, + retentionBytes: Long = Defaults.RetentionSize, + segmentJitterMs: Long = Defaults.SegmentJitterMs, + cleanupPolicy: String = Defaults.CleanupPolicy, + maxMessageBytes: Int = Defaults.MaxMessageSize, + indexIntervalBytes: Int = Defaults.IndexInterval, + segmentIndexBytes: Int = Defaults.MaxIndexSize, + fileDeleteDelayMs: Long = Defaults.FileDeleteDelayMs): LogConfig = { + val logProps = new Properties() + logProps.put(LogConfig.SegmentMsProp, segmentMs: java.lang.Long) + logProps.put(LogConfig.SegmentBytesProp, segmentBytes: Integer) + logProps.put(LogConfig.RetentionMsProp, retentionMs: java.lang.Long) + logProps.put(LogConfig.RetentionBytesProp, retentionBytes: java.lang.Long) + logProps.put(LogConfig.SegmentJitterMsProp, segmentJitterMs: java.lang.Long) + logProps.put(LogConfig.CleanupPolicyProp, cleanupPolicy) + logProps.put(LogConfig.MaxMessageBytesProp, maxMessageBytes: Integer) + logProps.put(LogConfig.IndexIntervalBytesProp, indexIntervalBytes: Integer) + logProps.put(LogConfig.SegmentIndexBytesProp, segmentIndexBytes: Integer) + logProps.put(LogConfig.FileDeleteDelayMsProp, fileDeleteDelayMs: java.lang.Long) + LogConfig(logProps) + } + + def createLog(dir: File, + config: LogConfig, + brokerTopicStats: BrokerTopicStats, + scheduler: Scheduler, + time: Time, + logStartOffset: Long = 0L, + recoveryPoint: Long = 0L, + maxProducerIdExpirationMs: Int = 60 * 60 * 1000, + producerIdExpirationCheckIntervalMs: Int = LogManager.ProducerIdExpirationCheckIntervalMs, + lastShutdownClean: Boolean = true, + topicId: Option[Uuid] = None, + keepPartitionMetadataFile: Boolean = true): UnifiedLog = { + UnifiedLog(dir = dir, 
+ config = config, + logStartOffset = logStartOffset, + recoveryPoint = recoveryPoint, + scheduler = scheduler, + brokerTopicStats = brokerTopicStats, + time = time, + maxProducerIdExpirationMs = maxProducerIdExpirationMs, + producerIdExpirationCheckIntervalMs = producerIdExpirationCheckIntervalMs, + logDirFailureChannel = new LogDirFailureChannel(10), + lastShutdownClean = lastShutdownClean, + topicId = topicId, + keepPartitionMetadataFile = keepPartitionMetadataFile) + } + + /** + * Check if the given log contains any segment with records that cause offset overflow. + * @param log Log to check + * @return true if log contains at least one segment with offset overflow; false otherwise + */ + def hasOffsetOverflow(log: UnifiedLog): Boolean = firstOverflowSegment(log).isDefined + + def firstOverflowSegment(log: UnifiedLog): Option[LogSegment] = { + def hasOverflow(baseOffset: Long, batch: RecordBatch): Boolean = + batch.lastOffset > baseOffset + Int.MaxValue || batch.baseOffset < baseOffset + + for (segment <- log.logSegments) { + val overflowBatch = segment.log.batches.asScala.find(batch => hasOverflow(segment.baseOffset, batch)) + if (overflowBatch.isDefined) + return Some(segment) + } + None + } + + def rawSegment(logDir: File, baseOffset: Long): FileRecords = + FileRecords.open(UnifiedLog.logFile(logDir, baseOffset)) + + /** + * Initialize the given log directory with a set of segments, one of which will have an + * offset which overflows the segment + */ + def initializeLogDirWithOverflowedSegment(logDir: File): Unit = { + def writeSampleBatches(baseOffset: Long, segment: FileRecords): Long = { + def record(offset: Long) = { + val data = offset.toString.getBytes + new SimpleRecord(data, data) + } + + segment.append(MemoryRecords.withRecords(baseOffset, CompressionType.NONE, 0, + record(baseOffset))) + segment.append(MemoryRecords.withRecords(baseOffset + 1, CompressionType.NONE, 0, + record(baseOffset + 1), + record(baseOffset + 2))) + segment.append(MemoryRecords.withRecords(baseOffset + Int.MaxValue - 1, CompressionType.NONE, 0, + record(baseOffset + Int.MaxValue - 1))) + // Need to create the offset files explicitly to avoid triggering segment recovery to truncate segment. 
+ UnifiedLog.offsetIndexFile(logDir, baseOffset).createNewFile() + UnifiedLog.timeIndexFile(logDir, baseOffset).createNewFile() + baseOffset + Int.MaxValue + } + + def writeNormalSegment(baseOffset: Long): Long = { + val segment = rawSegment(logDir, baseOffset) + try writeSampleBatches(baseOffset, segment) + finally segment.close() + } + + def writeOverflowSegment(baseOffset: Long): Long = { + val segment = rawSegment(logDir, baseOffset) + try { + val nextOffset = writeSampleBatches(baseOffset, segment) + writeSampleBatches(nextOffset, segment) + } finally segment.close() + } + + // We create three segments, the second of which contains offsets which overflow + var nextOffset = 0L + nextOffset = writeNormalSegment(nextOffset) + nextOffset = writeOverflowSegment(nextOffset) + writeNormalSegment(nextOffset) + } + + /* extract all the keys from a log */ + def keysInLog(log: UnifiedLog): Iterable[Long] = { + for (logSegment <- log.logSegments; + batch <- logSegment.log.batches.asScala if !batch.isControlBatch; + record <- batch.asScala if record.hasValue && record.hasKey) + yield TestUtils.readString(record.key).toLong + } + + def recoverAndCheck(logDir: File, config: LogConfig, expectedKeys: Iterable[Long], brokerTopicStats: BrokerTopicStats, time: Time, scheduler: Scheduler): UnifiedLog = { + // Recover log file and check that after recovery, keys are as expected + // and all temporary files have been deleted + val recoveredLog = createLog(logDir, config, brokerTopicStats, scheduler, time, lastShutdownClean = false) + time.sleep(config.fileDeleteDelayMs + 1) + for (file <- logDir.listFiles) { + assertFalse(file.getName.endsWith(UnifiedLog.DeletedFileSuffix), "Unexpected .deleted file after recovery") + assertFalse(file.getName.endsWith(UnifiedLog.CleanedFileSuffix), "Unexpected .cleaned file after recovery") + assertFalse(file.getName.endsWith(UnifiedLog.SwapFileSuffix), "Unexpected .swap file after recovery") + } + assertEquals(expectedKeys, keysInLog(recoveredLog)) + assertFalse(hasOffsetOverflow(recoveredLog)) + recoveredLog + } + + def appendEndTxnMarkerAsLeader(log: UnifiedLog, + producerId: Long, + producerEpoch: Short, + controlType: ControlRecordType, + timestamp: Long, + coordinatorEpoch: Int = 0, + leaderEpoch: Int = 0): LogAppendInfo = { + val records = endTxnRecords(controlType, producerId, producerEpoch, + coordinatorEpoch = coordinatorEpoch, timestamp = timestamp) + log.appendAsLeader(records, origin = AppendOrigin.Coordinator, leaderEpoch = leaderEpoch) + } + + private def endTxnRecords(controlRecordType: ControlRecordType, + producerId: Long, + epoch: Short, + offset: Long = 0L, + coordinatorEpoch: Int, + partitionLeaderEpoch: Int = 0, + timestamp: Long): MemoryRecords = { + val marker = new EndTransactionMarker(controlRecordType, coordinatorEpoch) + MemoryRecords.withEndTransactionMarker(offset, timestamp, partitionLeaderEpoch, producerId, epoch, marker) + } + + def readLog(log: UnifiedLog, + startOffset: Long, + maxLength: Int, + isolation: FetchIsolation = FetchLogEnd, + minOneMessage: Boolean = true): FetchDataInfo = { + log.read(startOffset, maxLength, isolation, minOneMessage) + } + + def allAbortedTransactions(log: UnifiedLog): Iterable[AbortedTxn] = log.logSegments.flatMap(_.txnIndex.allAbortedTxns) + + def deleteProducerSnapshotFiles(logDir: File): Unit = { + val files = logDir.listFiles.filter(f => f.isFile && f.getName.endsWith(UnifiedLog.ProducerSnapshotFileSuffix)) + files.foreach(Utils.delete) + } + + def listProducerSnapshotOffsets(logDir: File): Seq[Long] = + 
ProducerStateManager.listSnapshotFiles(logDir).map(_.offset).sorted + + def assertLeaderEpochCacheEmpty(log: UnifiedLog): Unit = { + assertEquals(None, log.leaderEpochCache) + assertEquals(None, log.latestEpoch) + assertFalse(LeaderEpochCheckpointFile.newFile(log.dir).exists()) + } + + def appendNonTransactionalAsLeader(log: UnifiedLog, numRecords: Int): Unit = { + val simpleRecords = (0 until numRecords).map { seq => + new SimpleRecord(s"$seq".getBytes) + } + val records = MemoryRecords.withRecords(CompressionType.NONE, simpleRecords: _*) + log.appendAsLeader(records, leaderEpoch = 0) + } + + def appendTransactionalAsLeader(log: UnifiedLog, + producerId: Long, + producerEpoch: Short, + time: Time): Int => Unit = { + appendIdempotentAsLeader(log, producerId, producerEpoch, time, isTransactional = true) + } + + def appendIdempotentAsLeader(log: UnifiedLog, + producerId: Long, + producerEpoch: Short, + time: Time, + isTransactional: Boolean = false): Int => Unit = { + var sequence = 0 + numRecords: Int => { + val simpleRecords = (sequence until sequence + numRecords).map { seq => + new SimpleRecord(time.milliseconds(), s"$seq".getBytes) + } + + val records = if (isTransactional) { + MemoryRecords.withTransactionalRecords(CompressionType.NONE, producerId, + producerEpoch, sequence, simpleRecords: _*) + } else { + MemoryRecords.withIdempotentRecords(CompressionType.NONE, producerId, + producerEpoch, sequence, simpleRecords: _*) + } + + log.appendAsLeader(records, leaderEpoch = 0) + sequence += numRecords + } + } +} diff --git a/core/src/test/scala/unit/kafka/log/LogValidatorTest.scala b/core/src/test/scala/unit/kafka/log/LogValidatorTest.scala new file mode 100644 index 0000000..4275684 --- /dev/null +++ b/core/src/test/scala/unit/kafka/log/LogValidatorTest.scala @@ -0,0 +1,1560 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package kafka.log + +import java.nio.ByteBuffer +import java.util.concurrent.TimeUnit +import kafka.api.{ApiVersion, KAFKA_2_0_IV1, KAFKA_2_3_IV1} +import kafka.common.{LongRef, RecordValidationException} +import kafka.log.LogValidator.ValidationAndOffsetAssignResult +import kafka.message._ +import kafka.metrics.KafkaYammerMetrics +import kafka.server.{BrokerTopicStats, RequestLocal} +import kafka.utils.TestUtils.meterCount +import org.apache.kafka.common.errors.{InvalidTimestampException, UnsupportedCompressionTypeException, UnsupportedForMessageFormatException} +import org.apache.kafka.common.record._ +import org.apache.kafka.common.utils.Time +import org.apache.kafka.common.{InvalidRecordException, TopicPartition} +import org.apache.kafka.test.TestUtils +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.Test + +import scala.jdk.CollectionConverters._ + +class LogValidatorTest { + + val time = Time.SYSTEM + val topicPartition = new TopicPartition("topic", 0) + val brokerTopicStats = new BrokerTopicStats + val metricsKeySet = KafkaYammerMetrics.defaultRegistry.allMetrics.keySet.asScala + + @Test + def testOnlyOneBatch(): Unit = { + checkOnlyOneBatch(RecordBatch.MAGIC_VALUE_V0, CompressionType.GZIP, CompressionType.GZIP) + checkOnlyOneBatch(RecordBatch.MAGIC_VALUE_V1, CompressionType.GZIP, CompressionType.GZIP) + checkOnlyOneBatch(RecordBatch.MAGIC_VALUE_V2, CompressionType.GZIP, CompressionType.GZIP) + checkOnlyOneBatch(RecordBatch.MAGIC_VALUE_V0, CompressionType.GZIP, CompressionType.NONE) + checkOnlyOneBatch(RecordBatch.MAGIC_VALUE_V1, CompressionType.GZIP, CompressionType.NONE) + checkOnlyOneBatch(RecordBatch.MAGIC_VALUE_V2, CompressionType.GZIP, CompressionType.NONE) + checkOnlyOneBatch(RecordBatch.MAGIC_VALUE_V2, CompressionType.NONE, CompressionType.NONE) + checkOnlyOneBatch(RecordBatch.MAGIC_VALUE_V2, CompressionType.NONE, CompressionType.GZIP) + } + + @Test + def testAllowMultiBatch(): Unit = { + checkAllowMultiBatch(RecordBatch.MAGIC_VALUE_V0, CompressionType.NONE, CompressionType.NONE) + checkAllowMultiBatch(RecordBatch.MAGIC_VALUE_V1, CompressionType.NONE, CompressionType.NONE) + checkAllowMultiBatch(RecordBatch.MAGIC_VALUE_V0, CompressionType.NONE, CompressionType.GZIP) + checkAllowMultiBatch(RecordBatch.MAGIC_VALUE_V1, CompressionType.NONE, CompressionType.GZIP) + } + + @Test + def testValidationOfBatchesWithNonSequentialInnerOffsets(): Unit = { + def testMessageValidation(magicValue: Byte): Unit = { + val numRecords = 20 + val invalidRecords = recordsWithNonSequentialInnerOffsets(magicValue, CompressionType.GZIP, numRecords) + + // Validation for v2 and above is strict for this case. For older formats, we fix invalid + // internal offsets by rewriting the batch. 
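+      // Concretely: a v2 batch with non-sequential inner offsets is rejected with
+      // InvalidRecordException, while a v0/v1 batch is rewritten, which is why the else-branch
+      // below expects the validated records to come back with consecutive offsets 0 until numRecords.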
+ if (magicValue >= RecordBatch.MAGIC_VALUE_V2) { + assertThrows(classOf[InvalidRecordException], + () => validateMessages(invalidRecords, magicValue, CompressionType.GZIP, CompressionType.GZIP) + ) + } else { + val result = validateMessages(invalidRecords, magicValue, CompressionType.GZIP, CompressionType.GZIP) + assertEquals(0 until numRecords, result.validatedRecords.records.asScala.map(_.offset)) + } + } + + for (version <- RecordVersion.values) { + testMessageValidation(version.value) + } + } + + @Test + def testMisMatchMagic(): Unit = { + checkMismatchMagic(RecordBatch.MAGIC_VALUE_V0, RecordBatch.MAGIC_VALUE_V1, CompressionType.GZIP) + checkMismatchMagic(RecordBatch.MAGIC_VALUE_V1, RecordBatch.MAGIC_VALUE_V0, CompressionType.GZIP) + } + + private def checkOnlyOneBatch(magic: Byte, sourceCompressionType: CompressionType, targetCompressionType: CompressionType): Unit = { + assertThrows(classOf[InvalidRecordException], + () => validateMessages(createTwoBatchedRecords(magic, 0L, sourceCompressionType), magic, sourceCompressionType, targetCompressionType) + ) + } + + private def checkAllowMultiBatch(magic: Byte, sourceCompressionType: CompressionType, targetCompressionType: CompressionType): Unit = { + validateMessages(createTwoBatchedRecords(magic, 0L, sourceCompressionType), magic, sourceCompressionType, targetCompressionType) + } + + private def checkMismatchMagic(batchMagic: Byte, recordMagic: Byte, compressionType: CompressionType): Unit = { + assertThrows(classOf[RecordValidationException], + () => validateMessages(recordsWithInvalidInnerMagic(batchMagic, recordMagic, compressionType), batchMagic, compressionType, compressionType) + ) + assertEquals(metricsKeySet.count(_.getMBeanName.endsWith(s"${BrokerTopicStats.InvalidMagicNumberRecordsPerSec}")), 1) + assertTrue(meterCount(s"${BrokerTopicStats.InvalidMagicNumberRecordsPerSec}") > 0) + } + + private def validateMessages(records: MemoryRecords, + magic: Byte, + sourceCompressionType: CompressionType, + targetCompressionType: CompressionType): ValidationAndOffsetAssignResult = { + LogValidator.validateMessagesAndAssignOffsets(records, + topicPartition, + new LongRef(0L), + time, + now = 0L, + CompressionCodec.getCompressionCodec(sourceCompressionType.name), + CompressionCodec.getCompressionCodec(targetCompressionType.name), + compactedTopic = false, + magic, + TimestampType.CREATE_TIME, + 1000L, + RecordBatch.NO_PRODUCER_EPOCH, + origin = AppendOrigin.Client, + KAFKA_2_3_IV1, + brokerTopicStats, + RequestLocal.withThreadConfinedCaching) + } + + @Test + def testLogAppendTimeNonCompressedV1(): Unit = { + checkLogAppendTimeNonCompressed(RecordBatch.MAGIC_VALUE_V1) + } + + @Test + def testLogAppendTimeNonCompressedV2(): Unit = { + checkLogAppendTimeNonCompressed(RecordBatch.MAGIC_VALUE_V2) + } + + private def checkLogAppendTimeNonCompressed(magic: Byte): Unit = { + val now = System.currentTimeMillis() + // The timestamps should be overwritten + val records = createRecords(magicValue = magic, timestamp = 1234L, codec = CompressionType.NONE) + val validatedResults = LogValidator.validateMessagesAndAssignOffsets(records, + topicPartition, + offsetCounter = new LongRef(0), + time= time, + now = now, + sourceCodec = NoCompressionCodec, + targetCodec = NoCompressionCodec, + compactedTopic = false, + magic = magic, + timestampType = TimestampType.LOG_APPEND_TIME, + timestampDiffMaxMs = 1000L, + partitionLeaderEpoch = RecordBatch.NO_PARTITION_LEADER_EPOCH, + origin = AppendOrigin.Client, + interBrokerProtocolVersion = ApiVersion.latestVersion, 
+ brokerTopicStats = brokerTopicStats, + requestLocal = RequestLocal.withThreadConfinedCaching) + val validatedRecords = validatedResults.validatedRecords + assertEquals(records.records.asScala.size, validatedRecords.records.asScala.size, "message set size should not change") + validatedRecords.batches.forEach(batch => validateLogAppendTime(now, 1234L, batch)) + assertEquals(now, validatedResults.maxTimestamp, s"Max timestamp should be $now") + assertFalse(validatedResults.messageSizeMaybeChanged, "Message size should not have been changed") + + // we index from last offset in version 2 instead of base offset + val expectedMaxTimestampOffset = if (magic >= RecordBatch.MAGIC_VALUE_V2) 2 else 0 + assertEquals(expectedMaxTimestampOffset, validatedResults.shallowOffsetOfMaxTimestamp, + s"The offset of max timestamp should be $expectedMaxTimestampOffset") + verifyRecordConversionStats(validatedResults.recordConversionStats, numConvertedRecords = 0, records, + compressed = false) + } + + @Test + def testLogAppendTimeWithRecompressionV1(): Unit = { + checkLogAppendTimeWithRecompression(RecordBatch.MAGIC_VALUE_V1) + } + + private def checkLogAppendTimeWithRecompression(targetMagic: Byte): Unit = { + val now = System.currentTimeMillis() + // The timestamps should be overwritten + val records = createRecords(magicValue = RecordBatch.MAGIC_VALUE_V0, codec = CompressionType.GZIP) + val validatedResults = LogValidator.validateMessagesAndAssignOffsets( + records, + topicPartition, + offsetCounter = new LongRef(0), + time = time, + now = now, + sourceCodec = DefaultCompressionCodec, + targetCodec = DefaultCompressionCodec, + compactedTopic = false, + magic = targetMagic, + timestampType = TimestampType.LOG_APPEND_TIME, + timestampDiffMaxMs = 1000L, + partitionLeaderEpoch = RecordBatch.NO_PARTITION_LEADER_EPOCH, + origin = AppendOrigin.Client, + interBrokerProtocolVersion = ApiVersion.latestVersion, + brokerTopicStats = brokerTopicStats, + requestLocal = RequestLocal.withThreadConfinedCaching) + val validatedRecords = validatedResults.validatedRecords + + assertEquals(records.records.asScala.size, validatedRecords.records.asScala.size, + "message set size should not change") + validatedRecords.batches.forEach(batch => validateLogAppendTime(now, -1, batch)) + assertTrue(validatedRecords.batches.iterator.next().isValid, + "MessageSet should still valid") + assertEquals(now, validatedResults.maxTimestamp, + s"Max timestamp should be $now") + assertEquals(records.records.asScala.size - 1, validatedResults.shallowOffsetOfMaxTimestamp, + s"The offset of max timestamp should be ${records.records.asScala.size - 1}") + assertTrue(validatedResults.messageSizeMaybeChanged, + "Message size may have been changed") + + val stats = validatedResults.recordConversionStats + verifyRecordConversionStats(stats, numConvertedRecords = 3, records, compressed = true) + } + + @Test + def testLogAppendTimeWithRecompressionV2(): Unit = { + checkLogAppendTimeWithRecompression(RecordBatch.MAGIC_VALUE_V2) + } + + @Test + def testLogAppendTimeWithoutRecompressionV1(): Unit = { + checkLogAppendTimeWithoutRecompression(RecordBatch.MAGIC_VALUE_V1) + } + + private def checkLogAppendTimeWithoutRecompression(magic: Byte): Unit = { + val now = System.currentTimeMillis() + // The timestamps should be overwritten + val records = createRecords(magicValue = magic, timestamp = 1234L, codec = CompressionType.GZIP) + val validatedResults = LogValidator.validateMessagesAndAssignOffsets( + records, + topicPartition, + offsetCounter = new LongRef(0), + 
time = time, + now = now, + sourceCodec = DefaultCompressionCodec, + targetCodec = DefaultCompressionCodec, + compactedTopic = false, + magic = magic, + timestampType = TimestampType.LOG_APPEND_TIME, + timestampDiffMaxMs = 1000L, + partitionLeaderEpoch = RecordBatch.NO_PARTITION_LEADER_EPOCH, + origin = AppendOrigin.Client, + interBrokerProtocolVersion = ApiVersion.latestVersion, + brokerTopicStats = brokerTopicStats, + requestLocal = RequestLocal.withThreadConfinedCaching) + val validatedRecords = validatedResults.validatedRecords + + assertEquals(records.records.asScala.size, validatedRecords.records.asScala.size, + "message set size should not change") + validatedRecords.batches.forEach(batch => validateLogAppendTime(now, 1234L, batch)) + assertTrue(validatedRecords.batches.iterator.next().isValid, + "MessageSet should still valid") + assertEquals(now, validatedResults.maxTimestamp, + s"Max timestamp should be $now") + assertEquals(records.records.asScala.size - 1, validatedResults.shallowOffsetOfMaxTimestamp, + s"The offset of max timestamp should be ${records.records.asScala.size - 1}") + assertFalse(validatedResults.messageSizeMaybeChanged, + "Message size should not have been changed") + + verifyRecordConversionStats(validatedResults.recordConversionStats, numConvertedRecords = 0, records, + compressed = true) + } + + @Test + def testInvalidOffsetRangeAndRecordCount(): Unit = { + // The batch to be written contains 3 records, so the correct lastOffsetDelta is 2 + validateRecordBatchWithCountOverrides(lastOffsetDelta = 2, count = 3) + + // Count and offset range are inconsistent or invalid + assertInvalidBatchCountOverrides(lastOffsetDelta = 0, count = 3) + assertInvalidBatchCountOverrides(lastOffsetDelta = 15, count = 3) + assertInvalidBatchCountOverrides(lastOffsetDelta = -3, count = 3) + assertInvalidBatchCountOverrides(lastOffsetDelta = 2, count = -3) + assertInvalidBatchCountOverrides(lastOffsetDelta = 2, count = 6) + assertInvalidBatchCountOverrides(lastOffsetDelta = 2, count = 0) + assertInvalidBatchCountOverrides(lastOffsetDelta = -3, count = -2) + + // Count and offset range are consistent, but do not match the actual number of records + assertInvalidBatchCountOverrides(lastOffsetDelta = 5, count = 6) + assertInvalidBatchCountOverrides(lastOffsetDelta = 1, count = 2) + } + + private def assertInvalidBatchCountOverrides(lastOffsetDelta: Int, count: Int): Unit = { + assertThrows(classOf[InvalidRecordException], + () => validateRecordBatchWithCountOverrides(lastOffsetDelta, count)) + } + + private def validateRecordBatchWithCountOverrides(lastOffsetDelta: Int, count: Int): Unit = { + val records = createRecords(magicValue = RecordBatch.MAGIC_VALUE_V2, timestamp = 1234L, codec = CompressionType.NONE) + records.buffer.putInt(DefaultRecordBatch.RECORDS_COUNT_OFFSET, count) + records.buffer.putInt(DefaultRecordBatch.LAST_OFFSET_DELTA_OFFSET, lastOffsetDelta) + LogValidator.validateMessagesAndAssignOffsets( + records, + topicPartition, + offsetCounter = new LongRef(0), + time = time, + now = time.milliseconds(), + sourceCodec = DefaultCompressionCodec, + targetCodec = DefaultCompressionCodec, + compactedTopic = false, + magic = RecordBatch.MAGIC_VALUE_V2, + timestampType = TimestampType.LOG_APPEND_TIME, + timestampDiffMaxMs = 1000L, + partitionLeaderEpoch = RecordBatch.NO_PARTITION_LEADER_EPOCH, + origin = AppendOrigin.Client, + interBrokerProtocolVersion = ApiVersion.latestVersion, + brokerTopicStats = brokerTopicStats, + requestLocal = RequestLocal.withThreadConfinedCaching) + } + 
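+  // Note on validateRecordBatchWithCountOverrides above: a well-formed batch satisfies
+  // count == lastOffsetDelta + 1 (hence lastOffsetDelta = 2 for the 3-record batch), and the
+  // overrides are patched directly into the batch header buffer, so the validator sees header
+  // fields that can disagree with the records actually present.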
+ @Test + def testLogAppendTimeWithoutRecompressionV2(): Unit = { + checkLogAppendTimeWithoutRecompression(RecordBatch.MAGIC_VALUE_V2) + } + + @Test + def testNonCompressedV1(): Unit = { + checkNonCompressed(RecordBatch.MAGIC_VALUE_V1) + } + + private def checkNonCompressed(magic: Byte): Unit = { + val now = System.currentTimeMillis() + val timestampSeq = Seq(now - 1, now + 1, now) + + val (producerId, producerEpoch, baseSequence, isTransactional, partitionLeaderEpoch) = + if (magic >= RecordBatch.MAGIC_VALUE_V2) + (1324L, 10.toShort, 984, true, 40) + else + (RecordBatch.NO_PRODUCER_ID, RecordBatch.NO_PRODUCER_EPOCH, RecordBatch.NO_SEQUENCE, false, + RecordBatch.NO_PARTITION_LEADER_EPOCH) + + val records = MemoryRecords.withRecords(magic, 0L, CompressionType.GZIP, TimestampType.CREATE_TIME, producerId, + producerEpoch, baseSequence, partitionLeaderEpoch, isTransactional, + new SimpleRecord(timestampSeq(0), "hello".getBytes), + new SimpleRecord(timestampSeq(1), "there".getBytes), + new SimpleRecord(timestampSeq(2), "beautiful".getBytes)) + + val validatingResults = LogValidator.validateMessagesAndAssignOffsets(records, + topicPartition, + offsetCounter = new LongRef(0), + time = time, + now = System.currentTimeMillis(), + sourceCodec = NoCompressionCodec, + targetCodec = NoCompressionCodec, + compactedTopic = false, + magic = magic, + timestampType = TimestampType.CREATE_TIME, + timestampDiffMaxMs = 1000L, + partitionLeaderEpoch = partitionLeaderEpoch, + origin = AppendOrigin.Client, + interBrokerProtocolVersion = ApiVersion.latestVersion, + brokerTopicStats = brokerTopicStats, + requestLocal = RequestLocal.withThreadConfinedCaching) + val validatedRecords = validatingResults.validatedRecords + + var i = 0 + for (batch <- validatedRecords.batches.asScala) { + assertTrue(batch.isValid) + assertEquals(batch.timestampType, TimestampType.CREATE_TIME) + maybeCheckBaseTimestamp(timestampSeq(0), batch) + assertEquals(batch.maxTimestamp, batch.asScala.map(_.timestamp).max) + assertEquals(producerEpoch, batch.producerEpoch) + assertEquals(producerId, batch.producerId) + assertEquals(baseSequence, batch.baseSequence) + assertEquals(isTransactional, batch.isTransactional) + assertEquals(partitionLeaderEpoch, batch.partitionLeaderEpoch) + for (record <- batch.asScala) { + record.ensureValid() + assertEquals(timestampSeq(i), record.timestamp) + i += 1 + } + } + assertEquals(now + 1, validatingResults.maxTimestamp, + s"Max timestamp should be ${now + 1}") + assertEquals(1, validatingResults.shallowOffsetOfMaxTimestamp, + s"Offset of max timestamp should be 1") + assertFalse(validatingResults.messageSizeMaybeChanged, + "Message size should not have been changed") + + verifyRecordConversionStats(validatingResults.recordConversionStats, numConvertedRecords = 0, records, + compressed = false) + } + + @Test + def testNonCompressedV2(): Unit = { + checkNonCompressed(RecordBatch.MAGIC_VALUE_V2) + } + + @Test + def testRecompressionV1(): Unit = { + checkRecompression(RecordBatch.MAGIC_VALUE_V1) + } + + private def checkRecompression(magic: Byte): Unit = { + val now = System.currentTimeMillis() + val timestampSeq = Seq(now - 1, now + 1, now) + + val (producerId, producerEpoch, baseSequence, isTransactional, partitionLeaderEpoch) = + if (magic >= RecordBatch.MAGIC_VALUE_V2) + (1324L, 10.toShort, 984, true, 40) + else + (RecordBatch.NO_PRODUCER_ID, RecordBatch.NO_PRODUCER_EPOCH, RecordBatch.NO_SEQUENCE, false, + RecordBatch.NO_PARTITION_LEADER_EPOCH) + + val records = MemoryRecords.withRecords(magic, 0L, 
CompressionType.GZIP, TimestampType.CREATE_TIME, producerId, + producerEpoch, baseSequence, partitionLeaderEpoch, isTransactional, + new SimpleRecord(timestampSeq(0), "hello".getBytes), + new SimpleRecord(timestampSeq(1), "there".getBytes), + new SimpleRecord(timestampSeq(2), "beautiful".getBytes)) + + val validatingResults = LogValidator.validateMessagesAndAssignOffsets(records, + topicPartition, + offsetCounter = new LongRef(0), + time = time, + now = System.currentTimeMillis(), + sourceCodec = NoCompressionCodec, + targetCodec = GZIPCompressionCodec, + compactedTopic = false, + magic = magic, + timestampType = TimestampType.CREATE_TIME, + timestampDiffMaxMs = 1000L, + partitionLeaderEpoch = partitionLeaderEpoch, + origin = AppendOrigin.Client, + interBrokerProtocolVersion = ApiVersion.latestVersion, + brokerTopicStats = brokerTopicStats, + requestLocal = RequestLocal.withThreadConfinedCaching) + val validatedRecords = validatingResults.validatedRecords + + var i = 0 + for (batch <- validatedRecords.batches.asScala) { + assertTrue(batch.isValid) + assertEquals(batch.timestampType, TimestampType.CREATE_TIME) + maybeCheckBaseTimestamp(timestampSeq(0), batch) + assertEquals(batch.maxTimestamp, batch.asScala.map(_.timestamp).max) + assertEquals(producerEpoch, batch.producerEpoch) + assertEquals(producerId, batch.producerId) + assertEquals(baseSequence, batch.baseSequence) + assertEquals(partitionLeaderEpoch, batch.partitionLeaderEpoch) + for (record <- batch.asScala) { + record.ensureValid() + assertEquals(timestampSeq(i), record.timestamp) + i += 1 + } + } + assertEquals(now + 1, validatingResults.maxTimestamp, + s"Max timestamp should be ${now + 1}") + assertEquals(2, validatingResults.shallowOffsetOfMaxTimestamp, + "Offset of max timestamp should be 2") + assertTrue(validatingResults.messageSizeMaybeChanged, + "Message size should have been changed") + + verifyRecordConversionStats(validatingResults.recordConversionStats, numConvertedRecords = 3, records, + compressed = true) + } + + @Test + def testRecompressionV2(): Unit = { + checkRecompression(RecordBatch.MAGIC_VALUE_V2) + } + + @Test + def testCreateTimeUpConversionV0ToV1(): Unit = { + checkCreateTimeUpConversionFromV0(RecordBatch.MAGIC_VALUE_V1) + } + + private def checkCreateTimeUpConversionFromV0(toMagic: Byte): Unit = { + val records = createRecords(magicValue = RecordBatch.MAGIC_VALUE_V0, codec = CompressionType.GZIP) + val validatedResults = LogValidator.validateMessagesAndAssignOffsets(records, + topicPartition, + offsetCounter = new LongRef(0), + time = time, + now = System.currentTimeMillis(), + sourceCodec = DefaultCompressionCodec, + targetCodec = DefaultCompressionCodec, + magic = toMagic, + compactedTopic = false, + timestampType = TimestampType.CREATE_TIME, + timestampDiffMaxMs = 1000L, + partitionLeaderEpoch = RecordBatch.NO_PARTITION_LEADER_EPOCH, + origin = AppendOrigin.Client, + interBrokerProtocolVersion = ApiVersion.latestVersion, + brokerTopicStats = brokerTopicStats, + requestLocal = RequestLocal.withThreadConfinedCaching) + val validatedRecords = validatedResults.validatedRecords + + for (batch <- validatedRecords.batches.asScala) { + assertTrue(batch.isValid) + maybeCheckBaseTimestamp(RecordBatch.NO_TIMESTAMP, batch) + assertEquals(RecordBatch.NO_TIMESTAMP, batch.maxTimestamp) + assertEquals(TimestampType.CREATE_TIME, batch.timestampType) + assertEquals(RecordBatch.NO_PRODUCER_EPOCH, batch.producerEpoch) + assertEquals(RecordBatch.NO_PRODUCER_ID, batch.producerId) + assertEquals(RecordBatch.NO_SEQUENCE, 
batch.baseSequence) + } + assertEquals(validatedResults.maxTimestamp, RecordBatch.NO_TIMESTAMP, + s"Max timestamp should be ${RecordBatch.NO_TIMESTAMP}") + assertEquals(validatedRecords.records.asScala.size - 1, validatedResults.shallowOffsetOfMaxTimestamp, + s"Offset of max timestamp should be ${validatedRecords.records.asScala.size - 1}") + assertTrue(validatedResults.messageSizeMaybeChanged, "Message size should have been changed") + + verifyRecordConversionStats(validatedResults.recordConversionStats, numConvertedRecords = 3, records, + compressed = true) + } + + @Test + def testCreateTimeUpConversionV0ToV2(): Unit = { + checkCreateTimeUpConversionFromV0(RecordBatch.MAGIC_VALUE_V2) + } + + @Test + def testCreateTimeUpConversionV1ToV2(): Unit = { + val timestamp = System.currentTimeMillis() + val records = createRecords(magicValue = RecordBatch.MAGIC_VALUE_V1, codec = CompressionType.GZIP, timestamp = timestamp) + val validatedResults = LogValidator.validateMessagesAndAssignOffsets(records, + topicPartition, + offsetCounter = new LongRef(0), + time = time, + now = timestamp, + sourceCodec = DefaultCompressionCodec, + targetCodec = DefaultCompressionCodec, + magic = RecordBatch.MAGIC_VALUE_V2, + compactedTopic = false, + timestampType = TimestampType.CREATE_TIME, + timestampDiffMaxMs = 1000L, + partitionLeaderEpoch = RecordBatch.NO_PARTITION_LEADER_EPOCH, + origin = AppendOrigin.Client, + interBrokerProtocolVersion = ApiVersion.latestVersion, + brokerTopicStats = brokerTopicStats, + requestLocal = RequestLocal.withThreadConfinedCaching) + val validatedRecords = validatedResults.validatedRecords + + for (batch <- validatedRecords.batches.asScala) { + assertTrue(batch.isValid) + maybeCheckBaseTimestamp(timestamp, batch) + assertEquals(timestamp, batch.maxTimestamp) + assertEquals(TimestampType.CREATE_TIME, batch.timestampType) + assertEquals(RecordBatch.NO_PRODUCER_EPOCH, batch.producerEpoch) + assertEquals(RecordBatch.NO_PRODUCER_ID, batch.producerId) + assertEquals(RecordBatch.NO_SEQUENCE, batch.baseSequence) + } + assertEquals(timestamp, validatedResults.maxTimestamp) + assertEquals(validatedRecords.records.asScala.size - 1, validatedResults.shallowOffsetOfMaxTimestamp, + s"Offset of max timestamp should be ${validatedRecords.records.asScala.size - 1}") + assertTrue(validatedResults.messageSizeMaybeChanged, "Message size should have been changed") + + verifyRecordConversionStats(validatedResults.recordConversionStats, numConvertedRecords = 3, records, + compressed = true) + } + + @Test + def testCompressedV1(): Unit = { + checkCompressed(RecordBatch.MAGIC_VALUE_V1) + } + + private def checkCompressed(magic: Byte): Unit = { + val now = System.currentTimeMillis() + val timestampSeq = Seq(now - 1, now + 1, now) + + val (producerId, producerEpoch, baseSequence, isTransactional, partitionLeaderEpoch) = + if (magic >= RecordBatch.MAGIC_VALUE_V2) + (1324L, 10.toShort, 984, true, 40) + else + (RecordBatch.NO_PRODUCER_ID, RecordBatch.NO_PRODUCER_EPOCH, RecordBatch.NO_SEQUENCE, false, + RecordBatch.NO_PARTITION_LEADER_EPOCH) + + val records = MemoryRecords.withRecords(magic, 0L, CompressionType.GZIP, TimestampType.CREATE_TIME, producerId, + producerEpoch, baseSequence, partitionLeaderEpoch, isTransactional, + new SimpleRecord(timestampSeq(0), "hello".getBytes), + new SimpleRecord(timestampSeq(1), "there".getBytes), + new SimpleRecord(timestampSeq(2), "beautiful".getBytes)) + + val validatedResults = LogValidator.validateMessagesAndAssignOffsets(records, + topicPartition, + offsetCounter = new 
LongRef(0), + time = time, + now = System.currentTimeMillis(), + sourceCodec = DefaultCompressionCodec, + targetCodec = DefaultCompressionCodec, + magic = magic, + compactedTopic = false, + timestampType = TimestampType.CREATE_TIME, + timestampDiffMaxMs = 1000L, + partitionLeaderEpoch = partitionLeaderEpoch, + origin = AppendOrigin.Client, + interBrokerProtocolVersion = ApiVersion.latestVersion, + brokerTopicStats = brokerTopicStats, + requestLocal = RequestLocal.withThreadConfinedCaching) + val validatedRecords = validatedResults.validatedRecords + + var i = 0 + for (batch <- validatedRecords.batches.asScala) { + assertTrue(batch.isValid) + assertEquals(batch.timestampType, TimestampType.CREATE_TIME) + maybeCheckBaseTimestamp(timestampSeq(0), batch) + assertEquals(batch.maxTimestamp, batch.asScala.map(_.timestamp).max) + assertEquals(producerEpoch, batch.producerEpoch) + assertEquals(producerId, batch.producerId) + assertEquals(baseSequence, batch.baseSequence) + assertEquals(partitionLeaderEpoch, batch.partitionLeaderEpoch) + for (record <- batch.asScala) { + record.ensureValid() + assertEquals(timestampSeq(i), record.timestamp) + i += 1 + } + } + assertEquals(now + 1, validatedResults.maxTimestamp, s"Max timestamp should be ${now + 1}") + assertEquals(validatedRecords.records.asScala.size - 1, validatedResults.shallowOffsetOfMaxTimestamp, + s"Offset of max timestamp should be ${validatedRecords.records.asScala.size - 1}") + assertFalse(validatedResults.messageSizeMaybeChanged, "Message size should not have been changed") + + verifyRecordConversionStats(validatedResults.recordConversionStats, numConvertedRecords = 0, records, + compressed = true) + } + + @Test + def testCompressedV2(): Unit = { + checkCompressed(RecordBatch.MAGIC_VALUE_V2) + } + + @Test + def testInvalidCreateTimeNonCompressedV1(): Unit = { + val now = System.currentTimeMillis() + val records = createRecords(magicValue = RecordBatch.MAGIC_VALUE_V1, timestamp = now - 1001L, + codec = CompressionType.NONE) + assertThrows(classOf[RecordValidationException], () => LogValidator.validateMessagesAndAssignOffsets( + records, + topicPartition, + offsetCounter = new LongRef(0), + time = time, + now = System.currentTimeMillis(), + sourceCodec = NoCompressionCodec, + targetCodec = NoCompressionCodec, + compactedTopic = false, + magic = RecordBatch.MAGIC_VALUE_V1, + timestampType = TimestampType.CREATE_TIME, + timestampDiffMaxMs = 1000L, + partitionLeaderEpoch = RecordBatch.NO_PARTITION_LEADER_EPOCH, + origin = AppendOrigin.Client, + interBrokerProtocolVersion = ApiVersion.latestVersion, + brokerTopicStats = brokerTopicStats, + requestLocal = RequestLocal.withThreadConfinedCaching)) + } + + @Test + def testInvalidCreateTimeNonCompressedV2(): Unit = { + val now = System.currentTimeMillis() + val records = createRecords(magicValue = RecordBatch.MAGIC_VALUE_V2, timestamp = now - 1001L, + codec = CompressionType.NONE) + assertThrows(classOf[RecordValidationException], () => LogValidator.validateMessagesAndAssignOffsets( + records, + topicPartition, + offsetCounter = new LongRef(0), + time = time, + now = System.currentTimeMillis(), + sourceCodec = NoCompressionCodec, + targetCodec = NoCompressionCodec, + compactedTopic = false, + magic = RecordBatch.MAGIC_VALUE_V2, + timestampType = TimestampType.CREATE_TIME, + timestampDiffMaxMs = 1000L, + partitionLeaderEpoch = RecordBatch.NO_PARTITION_LEADER_EPOCH, + origin = AppendOrigin.Client, + interBrokerProtocolVersion = ApiVersion.latestVersion, + brokerTopicStats = brokerTopicStats, + 
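+      // (the record timestamp above is now - 1001L while timestampDiffMaxMs is 1000L, so
+      // CREATE_TIME validation is expected to reject the batch)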
requestLocal = RequestLocal.withThreadConfinedCaching)) + } + + @Test + def testInvalidCreateTimeCompressedV1(): Unit = { + val now = System.currentTimeMillis() + val records = createRecords(magicValue = RecordBatch.MAGIC_VALUE_V1, timestamp = now - 1001L, + codec = CompressionType.GZIP) + assertThrows(classOf[RecordValidationException], () => LogValidator.validateMessagesAndAssignOffsets( + records, + topicPartition, + offsetCounter = new LongRef(0), + time = time, + now = System.currentTimeMillis(), + sourceCodec = DefaultCompressionCodec, + targetCodec = DefaultCompressionCodec, + magic = RecordBatch.MAGIC_VALUE_V1, + compactedTopic = false, + timestampType = TimestampType.CREATE_TIME, + timestampDiffMaxMs = 1000L, + partitionLeaderEpoch = RecordBatch.NO_PARTITION_LEADER_EPOCH, + origin = AppendOrigin.Client, + interBrokerProtocolVersion = ApiVersion.latestVersion, + brokerTopicStats = brokerTopicStats, + requestLocal = RequestLocal.withThreadConfinedCaching)) + } + + @Test + def testInvalidCreateTimeCompressedV2(): Unit = { + val now = System.currentTimeMillis() + val records = createRecords(magicValue = RecordBatch.MAGIC_VALUE_V2, timestamp = now - 1001L, + codec = CompressionType.GZIP) + assertThrows(classOf[RecordValidationException], () => LogValidator.validateMessagesAndAssignOffsets( + records, + topicPartition, + offsetCounter = new LongRef(0), + time = time, + now = System.currentTimeMillis(), + sourceCodec = DefaultCompressionCodec, + targetCodec = DefaultCompressionCodec, + magic = RecordBatch.MAGIC_VALUE_V1, + compactedTopic = false, + timestampType = TimestampType.CREATE_TIME, + timestampDiffMaxMs = 1000L, + partitionLeaderEpoch = RecordBatch.NO_PARTITION_LEADER_EPOCH, + origin = AppendOrigin.Client, + interBrokerProtocolVersion = ApiVersion.latestVersion, + brokerTopicStats = brokerTopicStats, + requestLocal = RequestLocal.withThreadConfinedCaching)) + } + + @Test + def testAbsoluteOffsetAssignmentNonCompressed(): Unit = { + val records = createRecords(magicValue = RecordBatch.MAGIC_VALUE_V0, codec = CompressionType.NONE) + val offset = 1234567 + checkOffsets(records, 0) + checkOffsets(LogValidator.validateMessagesAndAssignOffsets(records, + topicPartition, + offsetCounter = new LongRef(offset), + time = time, + now = System.currentTimeMillis(), + sourceCodec = NoCompressionCodec, + targetCodec = NoCompressionCodec, + magic = RecordBatch.MAGIC_VALUE_V0, + compactedTopic = false, + timestampType = TimestampType.CREATE_TIME, + timestampDiffMaxMs = 1000L, + partitionLeaderEpoch = RecordBatch.NO_PARTITION_LEADER_EPOCH, + origin = AppendOrigin.Client, + interBrokerProtocolVersion = ApiVersion.latestVersion, + brokerTopicStats = brokerTopicStats, + requestLocal = RequestLocal.withThreadConfinedCaching).validatedRecords, offset) + } + + @Test + def testAbsoluteOffsetAssignmentCompressed(): Unit = { + val records = createRecords(magicValue = RecordBatch.MAGIC_VALUE_V0, codec = CompressionType.GZIP) + val offset = 1234567 + checkOffsets(records, 0) + checkOffsets(LogValidator.validateMessagesAndAssignOffsets(records, + topicPartition, + offsetCounter = new LongRef(offset), + time = time, + now = System.currentTimeMillis(), + sourceCodec = DefaultCompressionCodec, + targetCodec = DefaultCompressionCodec, + compactedTopic = false, + magic = RecordBatch.MAGIC_VALUE_V0, + timestampType = TimestampType.CREATE_TIME, + timestampDiffMaxMs = 1000L, + partitionLeaderEpoch = RecordBatch.NO_PARTITION_LEADER_EPOCH, + origin = AppendOrigin.Client, + interBrokerProtocolVersion = 
ApiVersion.latestVersion, + brokerTopicStats = brokerTopicStats, + requestLocal = RequestLocal.withThreadConfinedCaching).validatedRecords, offset) + } + + @Test + def testRelativeOffsetAssignmentNonCompressedV1(): Unit = { + val now = System.currentTimeMillis() + val records = createRecords(magicValue = RecordBatch.MAGIC_VALUE_V1, timestamp = now, codec = CompressionType.NONE) + val offset = 1234567 + checkOffsets(records, 0) + val messageWithOffset = LogValidator.validateMessagesAndAssignOffsets(records, + topicPartition, + offsetCounter = new LongRef(offset), + time = time, + now = System.currentTimeMillis(), + sourceCodec = NoCompressionCodec, + targetCodec = NoCompressionCodec, + compactedTopic = false, + magic = RecordBatch.MAGIC_VALUE_V1, + timestampType = TimestampType.CREATE_TIME, + timestampDiffMaxMs = 5000L, + partitionLeaderEpoch = RecordBatch.NO_PARTITION_LEADER_EPOCH, + origin = AppendOrigin.Client, + interBrokerProtocolVersion = ApiVersion.latestVersion, + brokerTopicStats = brokerTopicStats, + requestLocal = RequestLocal.withThreadConfinedCaching).validatedRecords + checkOffsets(messageWithOffset, offset) + } + + @Test + def testRelativeOffsetAssignmentNonCompressedV2(): Unit = { + val now = System.currentTimeMillis() + val records = createRecords(magicValue = RecordBatch.MAGIC_VALUE_V2, timestamp = now, codec = CompressionType.NONE) + val offset = 1234567 + checkOffsets(records, 0) + val messageWithOffset = LogValidator.validateMessagesAndAssignOffsets(records, + topicPartition, + offsetCounter = new LongRef(offset), + time = time, + now = System.currentTimeMillis(), + sourceCodec = NoCompressionCodec, + targetCodec = NoCompressionCodec, + compactedTopic = false, + magic = RecordBatch.MAGIC_VALUE_V2, + timestampType = TimestampType.CREATE_TIME, + timestampDiffMaxMs = 5000L, + partitionLeaderEpoch = RecordBatch.NO_PARTITION_LEADER_EPOCH, + origin = AppendOrigin.Client, + interBrokerProtocolVersion = ApiVersion.latestVersion, + brokerTopicStats = brokerTopicStats, + requestLocal = RequestLocal.withThreadConfinedCaching).validatedRecords + checkOffsets(messageWithOffset, offset) + } + + @Test + def testRelativeOffsetAssignmentCompressedV1(): Unit = { + val now = System.currentTimeMillis() + val records = createRecords(magicValue = RecordBatch.MAGIC_VALUE_V1, timestamp = now, codec = CompressionType.GZIP) + val offset = 1234567 + checkOffsets(records, 0) + val compressedMessagesWithOffset = LogValidator.validateMessagesAndAssignOffsets( + records, + topicPartition, + offsetCounter = new LongRef(offset), + time = time, + now = System.currentTimeMillis(), + sourceCodec = DefaultCompressionCodec, + targetCodec = DefaultCompressionCodec, + compactedTopic = false, + magic = RecordBatch.MAGIC_VALUE_V1, + timestampType = TimestampType.CREATE_TIME, + timestampDiffMaxMs = 5000L, + partitionLeaderEpoch = RecordBatch.NO_PARTITION_LEADER_EPOCH, + origin = AppendOrigin.Client, + interBrokerProtocolVersion = ApiVersion.latestVersion, + brokerTopicStats = brokerTopicStats, + requestLocal = RequestLocal.withThreadConfinedCaching).validatedRecords + checkOffsets(compressedMessagesWithOffset, offset) + } + + @Test + def testRelativeOffsetAssignmentCompressedV2(): Unit = { + val now = System.currentTimeMillis() + val records = createRecords(magicValue = RecordBatch.MAGIC_VALUE_V2, timestamp = now, codec = CompressionType.GZIP) + val offset = 1234567 + checkOffsets(records, 0) + val compressedMessagesWithOffset = LogValidator.validateMessagesAndAssignOffsets( + records, + topicPartition, + 
offsetCounter = new LongRef(offset), + time = time, + now = System.currentTimeMillis(), + sourceCodec = DefaultCompressionCodec, + targetCodec = DefaultCompressionCodec, + compactedTopic = false, + magic = RecordBatch.MAGIC_VALUE_V2, + timestampType = TimestampType.CREATE_TIME, + timestampDiffMaxMs = 5000L, + partitionLeaderEpoch = RecordBatch.NO_PARTITION_LEADER_EPOCH, + origin = AppendOrigin.Client, + interBrokerProtocolVersion = ApiVersion.latestVersion, + brokerTopicStats = brokerTopicStats, + requestLocal = RequestLocal.withThreadConfinedCaching).validatedRecords + checkOffsets(compressedMessagesWithOffset, offset) + } + + @Test + def testOffsetAssignmentAfterUpConversionV0ToV1NonCompressed(): Unit = { + val records = createRecords(magicValue = RecordBatch.MAGIC_VALUE_V0, codec = CompressionType.NONE) + checkOffsets(records, 0) + val offset = 1234567 + val validatedResults = LogValidator.validateMessagesAndAssignOffsets(records, + topicPartition, + offsetCounter = new LongRef(offset), + time = time, + now = System.currentTimeMillis(), + sourceCodec = NoCompressionCodec, + targetCodec = NoCompressionCodec, + compactedTopic = false, + magic = RecordBatch.MAGIC_VALUE_V1, + timestampType = TimestampType.LOG_APPEND_TIME, + timestampDiffMaxMs = 1000L, + partitionLeaderEpoch = RecordBatch.NO_PARTITION_LEADER_EPOCH, + origin = AppendOrigin.Client, + interBrokerProtocolVersion = ApiVersion.latestVersion, + brokerTopicStats = brokerTopicStats, + requestLocal = RequestLocal.withThreadConfinedCaching) + checkOffsets(validatedResults.validatedRecords, offset) + verifyRecordConversionStats(validatedResults.recordConversionStats, numConvertedRecords = 3, records, + compressed = false) + } + + @Test + def testOffsetAssignmentAfterUpConversionV0ToV2NonCompressed(): Unit = { + val records = createRecords(magicValue = RecordBatch.MAGIC_VALUE_V0, codec = CompressionType.NONE) + checkOffsets(records, 0) + val offset = 1234567 + val validatedResults = LogValidator.validateMessagesAndAssignOffsets(records, + topicPartition, + offsetCounter = new LongRef(offset), + time = time, + now = System.currentTimeMillis(), + sourceCodec = NoCompressionCodec, + targetCodec = NoCompressionCodec, + compactedTopic = false, + magic = RecordBatch.MAGIC_VALUE_V2, + timestampType = TimestampType.LOG_APPEND_TIME, + timestampDiffMaxMs = 1000L, + partitionLeaderEpoch = RecordBatch.NO_PARTITION_LEADER_EPOCH, + origin = AppendOrigin.Client, + interBrokerProtocolVersion = ApiVersion.latestVersion, + brokerTopicStats = brokerTopicStats, + requestLocal = RequestLocal.withThreadConfinedCaching) + checkOffsets(validatedResults.validatedRecords, offset) + verifyRecordConversionStats(validatedResults.recordConversionStats, numConvertedRecords = 3, records, + compressed = false) + } + + @Test + def testOffsetAssignmentAfterUpConversionV0ToV1Compressed(): Unit = { + val records = createRecords(magicValue = RecordBatch.MAGIC_VALUE_V0, codec = CompressionType.GZIP) + val offset = 1234567 + checkOffsets(records, 0) + val validatedResults = LogValidator.validateMessagesAndAssignOffsets(records, + topicPartition, + offsetCounter = new LongRef(offset), + time = time, + now = System.currentTimeMillis(), + sourceCodec = DefaultCompressionCodec, + targetCodec = DefaultCompressionCodec, + compactedTopic = false, + magic = RecordBatch.MAGIC_VALUE_V1, + timestampType = TimestampType.LOG_APPEND_TIME, + timestampDiffMaxMs = 1000L, + partitionLeaderEpoch = RecordBatch.NO_PARTITION_LEADER_EPOCH, + origin = AppendOrigin.Client, + interBrokerProtocolVersion 
= ApiVersion.latestVersion, + brokerTopicStats = brokerTopicStats, + requestLocal = RequestLocal.withThreadConfinedCaching) + checkOffsets(validatedResults.validatedRecords, offset) + verifyRecordConversionStats(validatedResults.recordConversionStats, numConvertedRecords = 3, records, + compressed = true) + } + + @Test + def testOffsetAssignmentAfterUpConversionV0ToV2Compressed(): Unit = { + val records = createRecords(magicValue = RecordBatch.MAGIC_VALUE_V0, codec = CompressionType.GZIP) + val offset = 1234567 + checkOffsets(records, 0) + val validatedResults = LogValidator.validateMessagesAndAssignOffsets(records, + topicPartition, + offsetCounter = new LongRef(offset), + time = time, + now = System.currentTimeMillis(), + sourceCodec = DefaultCompressionCodec, + targetCodec = DefaultCompressionCodec, + compactedTopic = false, + magic = RecordBatch.MAGIC_VALUE_V2, + timestampType = TimestampType.LOG_APPEND_TIME, + timestampDiffMaxMs = 1000L, + partitionLeaderEpoch = RecordBatch.NO_PARTITION_LEADER_EPOCH, + origin = AppendOrigin.Client, + interBrokerProtocolVersion = ApiVersion.latestVersion, + brokerTopicStats = brokerTopicStats, + requestLocal = RequestLocal.withThreadConfinedCaching) + checkOffsets(validatedResults.validatedRecords, offset) + verifyRecordConversionStats(validatedResults.recordConversionStats, numConvertedRecords = 3, records, + compressed = true) + } + + @Test + def testControlRecordsNotAllowedFromClients(): Unit = { + val offset = 1234567 + val endTxnMarker = new EndTransactionMarker(ControlRecordType.COMMIT, 0) + val records = MemoryRecords.withEndTransactionMarker(23423L, 5, endTxnMarker) + assertThrows(classOf[InvalidRecordException], () => LogValidator.validateMessagesAndAssignOffsets(records, + topicPartition, + offsetCounter = new LongRef(offset), + time = time, + now = System.currentTimeMillis(), + sourceCodec = NoCompressionCodec, + targetCodec = NoCompressionCodec, + compactedTopic = false, + magic = RecordBatch.CURRENT_MAGIC_VALUE, + timestampType = TimestampType.CREATE_TIME, + timestampDiffMaxMs = 5000L, + partitionLeaderEpoch = RecordBatch.NO_PARTITION_LEADER_EPOCH, + origin = AppendOrigin.Client, + interBrokerProtocolVersion = ApiVersion.latestVersion, + brokerTopicStats = brokerTopicStats, + requestLocal = RequestLocal.withThreadConfinedCaching)) + } + + @Test + def testControlRecordsNotCompressed(): Unit = { + val offset = 1234567 + val endTxnMarker = new EndTransactionMarker(ControlRecordType.COMMIT, 0) + val records = MemoryRecords.withEndTransactionMarker(23423L, 5, endTxnMarker) + val result = LogValidator.validateMessagesAndAssignOffsets(records, + topicPartition, + offsetCounter = new LongRef(offset), + time = time, + now = System.currentTimeMillis(), + sourceCodec = NoCompressionCodec, + targetCodec = SnappyCompressionCodec, + compactedTopic = false, + magic = RecordBatch.CURRENT_MAGIC_VALUE, + timestampType = TimestampType.CREATE_TIME, + timestampDiffMaxMs = 5000L, + partitionLeaderEpoch = RecordBatch.NO_PARTITION_LEADER_EPOCH, + origin = AppendOrigin.Coordinator, + interBrokerProtocolVersion = ApiVersion.latestVersion, + brokerTopicStats = brokerTopicStats, + requestLocal = RequestLocal.withThreadConfinedCaching) + val batches = TestUtils.toList(result.validatedRecords.batches) + assertEquals(1, batches.size) + val batch = batches.get(0) + assertFalse(batch.isCompressed) + } + + @Test + def testOffsetAssignmentAfterDownConversionV1ToV0NonCompressed(): Unit = { + val offset = 1234567 + val now = System.currentTimeMillis() + val records = 
createRecords(RecordBatch.MAGIC_VALUE_V1, now, codec = CompressionType.NONE) + checkOffsets(records, 0) + checkOffsets(LogValidator.validateMessagesAndAssignOffsets(records, + topicPartition, + offsetCounter = new LongRef(offset), + time = time, + now = System.currentTimeMillis(), + sourceCodec = NoCompressionCodec, + targetCodec = NoCompressionCodec, + compactedTopic = false, + magic = RecordBatch.MAGIC_VALUE_V0, + timestampType = TimestampType.CREATE_TIME, + timestampDiffMaxMs = 5000L, + partitionLeaderEpoch = RecordBatch.NO_PARTITION_LEADER_EPOCH, + origin = AppendOrigin.Client, + interBrokerProtocolVersion = ApiVersion.latestVersion, + brokerTopicStats = brokerTopicStats, + requestLocal = RequestLocal.withThreadConfinedCaching).validatedRecords, offset) + } + + @Test + def testOffsetAssignmentAfterDownConversionV1ToV0Compressed(): Unit = { + val offset = 1234567 + val now = System.currentTimeMillis() + val records = createRecords(RecordBatch.MAGIC_VALUE_V1, now, CompressionType.GZIP) + checkOffsets(records, 0) + checkOffsets(LogValidator.validateMessagesAndAssignOffsets(records, + topicPartition, + offsetCounter = new LongRef(offset), + time = time, + now = System.currentTimeMillis(), + sourceCodec = DefaultCompressionCodec, + targetCodec = DefaultCompressionCodec, + compactedTopic = false, + magic = RecordBatch.MAGIC_VALUE_V0, + timestampType = TimestampType.CREATE_TIME, + timestampDiffMaxMs = 5000L, + partitionLeaderEpoch = RecordBatch.NO_PARTITION_LEADER_EPOCH, + origin = AppendOrigin.Client, + interBrokerProtocolVersion = ApiVersion.latestVersion, + brokerTopicStats = brokerTopicStats, + requestLocal = RequestLocal.withThreadConfinedCaching).validatedRecords, offset) + } + + @Test + def testOffsetAssignmentAfterUpConversionV1ToV2NonCompressed(): Unit = { + val records = createRecords(magicValue = RecordBatch.MAGIC_VALUE_V1, codec = CompressionType.NONE) + checkOffsets(records, 0) + val offset = 1234567 + checkOffsets(LogValidator.validateMessagesAndAssignOffsets(records, + topicPartition, + offsetCounter = new LongRef(offset), + time = time, + now = System.currentTimeMillis(), + sourceCodec = NoCompressionCodec, + targetCodec = NoCompressionCodec, + compactedTopic = false, + magic = RecordBatch.MAGIC_VALUE_V2, + timestampType = TimestampType.LOG_APPEND_TIME, + timestampDiffMaxMs = 1000L, + partitionLeaderEpoch = RecordBatch.NO_PARTITION_LEADER_EPOCH, + origin = AppendOrigin.Client, + interBrokerProtocolVersion = ApiVersion.latestVersion, + brokerTopicStats = brokerTopicStats, + requestLocal = RequestLocal.withThreadConfinedCaching).validatedRecords, offset) + } + + @Test + def testOffsetAssignmentAfterUpConversionV1ToV2Compressed(): Unit = { + val records = createRecords(magicValue = RecordBatch.MAGIC_VALUE_V1, codec = CompressionType.GZIP) + val offset = 1234567 + checkOffsets(records, 0) + checkOffsets(LogValidator.validateMessagesAndAssignOffsets(records, + topicPartition, + offsetCounter = new LongRef(offset), + time = time, + now = System.currentTimeMillis(), + sourceCodec = DefaultCompressionCodec, + targetCodec = DefaultCompressionCodec, + compactedTopic = false, + magic = RecordBatch.MAGIC_VALUE_V2, + timestampType = TimestampType.LOG_APPEND_TIME, + timestampDiffMaxMs = 1000L, + partitionLeaderEpoch = RecordBatch.NO_PARTITION_LEADER_EPOCH, + origin = AppendOrigin.Client, + interBrokerProtocolVersion = ApiVersion.latestVersion, + brokerTopicStats = brokerTopicStats, + requestLocal = RequestLocal.withThreadConfinedCaching).validatedRecords, offset) + } + + @Test + def 
testOffsetAssignmentAfterDownConversionV2ToV1NonCompressed(): Unit = { + val offset = 1234567 + val now = System.currentTimeMillis() + val records = createRecords(RecordBatch.MAGIC_VALUE_V2, now, codec = CompressionType.NONE) + checkOffsets(records, 0) + checkOffsets(LogValidator.validateMessagesAndAssignOffsets(records, + topicPartition, + offsetCounter = new LongRef(offset), + time = time, + now = System.currentTimeMillis(), + sourceCodec = NoCompressionCodec, + targetCodec = NoCompressionCodec, + compactedTopic = false, + magic = RecordBatch.MAGIC_VALUE_V1, + timestampType = TimestampType.CREATE_TIME, + timestampDiffMaxMs = 5000L, + partitionLeaderEpoch = RecordBatch.NO_PARTITION_LEADER_EPOCH, + origin = AppendOrigin.Client, + interBrokerProtocolVersion = ApiVersion.latestVersion, + brokerTopicStats = brokerTopicStats, + requestLocal = RequestLocal.withThreadConfinedCaching).validatedRecords, offset) + } + + @Test + def testOffsetAssignmentAfterDownConversionV2ToV1Compressed(): Unit = { + val offset = 1234567 + val now = System.currentTimeMillis() + val records = createRecords(RecordBatch.MAGIC_VALUE_V2, now, CompressionType.GZIP) + checkOffsets(records, 0) + checkOffsets(LogValidator.validateMessagesAndAssignOffsets(records, + topicPartition, + offsetCounter = new LongRef(offset), + time = time, + now = System.currentTimeMillis(), + sourceCodec = DefaultCompressionCodec, + targetCodec = DefaultCompressionCodec, + compactedTopic = false, + magic = RecordBatch.MAGIC_VALUE_V1, + timestampType = TimestampType.CREATE_TIME, + timestampDiffMaxMs = 5000L, + partitionLeaderEpoch = RecordBatch.NO_PARTITION_LEADER_EPOCH, + origin = AppendOrigin.Client, + interBrokerProtocolVersion = ApiVersion.latestVersion, + brokerTopicStats = brokerTopicStats, + requestLocal = RequestLocal.withThreadConfinedCaching).validatedRecords, offset) + } + + @Test + def testDownConversionOfTransactionalRecordsNotPermitted(): Unit = { + val offset = 1234567 + val producerId = 1344L + val producerEpoch = 16.toShort + val sequence = 0 + val records = MemoryRecords.withTransactionalRecords(CompressionType.NONE, producerId, producerEpoch, sequence, + new SimpleRecord("hello".getBytes), new SimpleRecord("there".getBytes), new SimpleRecord("beautiful".getBytes)) + assertThrows(classOf[UnsupportedForMessageFormatException], () => LogValidator.validateMessagesAndAssignOffsets(records, + topicPartition, + offsetCounter = new LongRef(offset), + time = time, + now = System.currentTimeMillis(), + sourceCodec = DefaultCompressionCodec, + targetCodec = DefaultCompressionCodec, + compactedTopic = false, + magic = RecordBatch.MAGIC_VALUE_V1, + timestampType = TimestampType.CREATE_TIME, + timestampDiffMaxMs = 5000L, + partitionLeaderEpoch = RecordBatch.NO_PARTITION_LEADER_EPOCH, + origin = AppendOrigin.Client, + interBrokerProtocolVersion = ApiVersion.latestVersion, + brokerTopicStats = brokerTopicStats, + requestLocal = RequestLocal.withThreadConfinedCaching)) + } + + @Test + def testDownConversionOfIdempotentRecordsNotPermitted(): Unit = { + val offset = 1234567 + val producerId = 1344L + val producerEpoch = 16.toShort + val sequence = 0 + val records = MemoryRecords.withIdempotentRecords(CompressionType.NONE, producerId, producerEpoch, sequence, + new SimpleRecord("hello".getBytes), new SimpleRecord("there".getBytes), new SimpleRecord("beautiful".getBytes)) + assertThrows(classOf[UnsupportedForMessageFormatException], () => LogValidator.validateMessagesAndAssignOffsets(records, + topicPartition, + offsetCounter = new LongRef(offset), 
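+      // (idempotent batches carry producerId/epoch/sequence, which the v1 message format cannot
+      // represent, so down-conversion is expected to fail with UnsupportedForMessageFormatException)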
+ time = time, + now = System.currentTimeMillis(), + sourceCodec = DefaultCompressionCodec, + targetCodec = DefaultCompressionCodec, + compactedTopic = false, + magic = RecordBatch.MAGIC_VALUE_V1, + timestampType = TimestampType.CREATE_TIME, + timestampDiffMaxMs = 5000L, + partitionLeaderEpoch = RecordBatch.NO_PARTITION_LEADER_EPOCH, + origin = AppendOrigin.Client, + interBrokerProtocolVersion = ApiVersion.latestVersion, + brokerTopicStats = brokerTopicStats, + requestLocal = RequestLocal.withThreadConfinedCaching)) + } + + @Test + def testOffsetAssignmentAfterDownConversionV2ToV0NonCompressed(): Unit = { + val offset = 1234567 + val now = System.currentTimeMillis() + val records = createRecords(RecordBatch.MAGIC_VALUE_V2, now, codec = CompressionType.NONE) + checkOffsets(records, 0) + checkOffsets(LogValidator.validateMessagesAndAssignOffsets(records, + topicPartition, + offsetCounter = new LongRef(offset), + time = time, + now = System.currentTimeMillis(), + sourceCodec = NoCompressionCodec, + targetCodec = NoCompressionCodec, + compactedTopic = false, + magic = RecordBatch.MAGIC_VALUE_V0, + timestampType = TimestampType.CREATE_TIME, + timestampDiffMaxMs = 5000L, + partitionLeaderEpoch = RecordBatch.NO_PARTITION_LEADER_EPOCH, + origin = AppendOrigin.Client, + interBrokerProtocolVersion = ApiVersion.latestVersion, + brokerTopicStats = brokerTopicStats, + requestLocal = RequestLocal.withThreadConfinedCaching).validatedRecords, offset) + } + + @Test + def testOffsetAssignmentAfterDownConversionV2ToV0Compressed(): Unit = { + val offset = 1234567 + val now = System.currentTimeMillis() + val records = createRecords(RecordBatch.MAGIC_VALUE_V2, now, CompressionType.GZIP) + checkOffsets(records, 0) + checkOffsets(LogValidator.validateMessagesAndAssignOffsets(records, + topicPartition, + offsetCounter = new LongRef(offset), + time = time, + now = System.currentTimeMillis(), + sourceCodec = DefaultCompressionCodec, + targetCodec = DefaultCompressionCodec, + compactedTopic = false, + magic = RecordBatch.MAGIC_VALUE_V0, + timestampType = TimestampType.CREATE_TIME, + timestampDiffMaxMs = 5000L, + partitionLeaderEpoch = RecordBatch.NO_PARTITION_LEADER_EPOCH, + origin = AppendOrigin.Client, + interBrokerProtocolVersion = ApiVersion.latestVersion, + brokerTopicStats = brokerTopicStats, + requestLocal = RequestLocal.withThreadConfinedCaching).validatedRecords, offset) + } + + @Test + def testNonIncreasingOffsetRecordBatchHasMetricsLogged(): Unit = { + val records = createNonIncreasingOffsetRecords(RecordBatch.MAGIC_VALUE_V2) + records.batches().asScala.head.setLastOffset(2) + assertThrows(classOf[InvalidRecordException], () => LogValidator.validateMessagesAndAssignOffsets(records, + topicPartition, + offsetCounter = new LongRef(0L), + time = time, + now = System.currentTimeMillis(), + sourceCodec = DefaultCompressionCodec, + targetCodec = DefaultCompressionCodec, + compactedTopic = false, + magic = RecordBatch.MAGIC_VALUE_V0, + timestampType = TimestampType.CREATE_TIME, + timestampDiffMaxMs = 5000L, + partitionLeaderEpoch = RecordBatch.NO_PARTITION_LEADER_EPOCH, + origin = AppendOrigin.Client, + interBrokerProtocolVersion = ApiVersion.latestVersion, + brokerTopicStats = brokerTopicStats, + requestLocal = RequestLocal.withThreadConfinedCaching) + ) + assertEquals(metricsKeySet.count(_.getMBeanName.endsWith(s"${BrokerTopicStats.InvalidOffsetOrSequenceRecordsPerSec}")), 1) + assertTrue(meterCount(s"${BrokerTopicStats.InvalidOffsetOrSequenceRecordsPerSec}") > 0) + } + + @Test + def 
testCompressedBatchWithoutRecordsNotAllowed(): Unit = { + testBatchWithoutRecordsNotAllowed(DefaultCompressionCodec, DefaultCompressionCodec) + } + + @Test + def testZStdCompressedWithUnavailableIBPVersion(): Unit = { + val now = System.currentTimeMillis() + // The timestamps should be overwritten + val records = createRecords(magicValue = RecordBatch.MAGIC_VALUE_V2, timestamp = 1234L, codec = CompressionType.NONE) + assertThrows(classOf[UnsupportedCompressionTypeException], () => LogValidator.validateMessagesAndAssignOffsets(records, + topicPartition, + offsetCounter = new LongRef(0), + time= time, + now = now, + sourceCodec = NoCompressionCodec, + targetCodec = ZStdCompressionCodec, + compactedTopic = false, + magic = RecordBatch.MAGIC_VALUE_V2, + timestampType = TimestampType.LOG_APPEND_TIME, + timestampDiffMaxMs = 1000L, + partitionLeaderEpoch = RecordBatch.NO_PARTITION_LEADER_EPOCH, + origin = AppendOrigin.Client, + interBrokerProtocolVersion = KAFKA_2_0_IV1, + brokerTopicStats = brokerTopicStats, + requestLocal = RequestLocal.withThreadConfinedCaching)) + } + + @Test + def testUncompressedBatchWithoutRecordsNotAllowed(): Unit = { + testBatchWithoutRecordsNotAllowed(NoCompressionCodec, NoCompressionCodec) + } + + @Test + def testRecompressedBatchWithoutRecordsNotAllowed(): Unit = { + testBatchWithoutRecordsNotAllowed(NoCompressionCodec, DefaultCompressionCodec) + } + + @Test + def testInvalidTimestampExceptionHasBatchIndex(): Unit = { + val now = System.currentTimeMillis() + val records = createRecords(magicValue = RecordBatch.MAGIC_VALUE_V2, timestamp = now - 1001L, + codec = CompressionType.GZIP) + val e = assertThrows(classOf[RecordValidationException], + () => LogValidator.validateMessagesAndAssignOffsets( + records, + topicPartition, + offsetCounter = new LongRef(0), + time = time, + now = System.currentTimeMillis(), + sourceCodec = DefaultCompressionCodec, + targetCodec = DefaultCompressionCodec, + magic = RecordBatch.MAGIC_VALUE_V1, + compactedTopic = false, + timestampType = TimestampType.CREATE_TIME, + timestampDiffMaxMs = 1000L, + partitionLeaderEpoch = RecordBatch.NO_PARTITION_LEADER_EPOCH, + origin = AppendOrigin.Client, + interBrokerProtocolVersion = ApiVersion.latestVersion, + brokerTopicStats = brokerTopicStats, + requestLocal = RequestLocal.withThreadConfinedCaching) + ) + + assertTrue(e.invalidException.isInstanceOf[InvalidTimestampException]) + assertTrue(e.recordErrors.nonEmpty) + assertEquals(e.recordErrors.size, 3) + } + + @Test + def testInvalidRecordExceptionHasBatchIndex(): Unit = { + val e = assertThrows(classOf[RecordValidationException], + () => validateMessages(recordsWithInvalidInnerMagic( + RecordBatch.MAGIC_VALUE_V0, RecordBatch.MAGIC_VALUE_V1, CompressionType.GZIP), + RecordBatch.MAGIC_VALUE_V0, CompressionType.GZIP, CompressionType.GZIP) + ) + + assertTrue(e.invalidException.isInstanceOf[InvalidRecordException]) + assertTrue(e.recordErrors.nonEmpty) + // recordsWithInvalidInnerMagic creates 20 records + assertEquals(e.recordErrors.size, 20) + e.recordErrors.foreach(assertNotNull(_)) + } + + @Test + def testBatchWithInvalidRecordsAndInvalidTimestamp(): Unit = { + val records = (0 until 5).map(id => + LegacyRecord.create(RecordBatch.MAGIC_VALUE_V0, 0L, null, id.toString.getBytes()) + ) + + val buffer = ByteBuffer.allocate(1024) + val builder = MemoryRecords.builder(buffer, RecordBatch.MAGIC_VALUE_V1, CompressionType.GZIP, + TimestampType.CREATE_TIME, 0L) + var offset = 0 + + // we want to mix in a record with invalid timestamp range + 
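+    // (1200L is more than timestampDiffMaxMs = 1000L away from the now = 0L used by
+    // validateMessages, so this record is reported as an invalid timestamp; together with the
+    // five other records it accounts for the 6 recordErrors asserted below)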
builder.appendUncheckedWithOffset(offset, LegacyRecord.create(RecordBatch.MAGIC_VALUE_V1, + 1200L, null, "timestamp".getBytes)) + records.foreach { record => + offset += 30 + builder.appendUncheckedWithOffset(offset, record) + } + val invalidOffsetTimestampRecords = builder.build() + + val e = assertThrows(classOf[RecordValidationException], + () => validateMessages(invalidOffsetTimestampRecords, + RecordBatch.MAGIC_VALUE_V0, CompressionType.GZIP, CompressionType.GZIP) + ) + // if there is a mix of both regular InvalidRecordException and InvalidTimestampException, + // InvalidTimestampException takes precedence + assertTrue(e.invalidException.isInstanceOf[InvalidTimestampException]) + assertTrue(e.recordErrors.nonEmpty) + assertEquals(6, e.recordErrors.size) + } + + private def testBatchWithoutRecordsNotAllowed(sourceCodec: CompressionCodec, targetCodec: CompressionCodec): Unit = { + val offset = 1234567 + val (producerId, producerEpoch, baseSequence, isTransactional, partitionLeaderEpoch) = + (1324L, 10.toShort, 984, true, 40) + val buffer = ByteBuffer.allocate(DefaultRecordBatch.RECORD_BATCH_OVERHEAD) + DefaultRecordBatch.writeEmptyHeader(buffer, RecordBatch.CURRENT_MAGIC_VALUE, producerId, producerEpoch, + baseSequence, 0L, 5L, partitionLeaderEpoch, TimestampType.CREATE_TIME, System.currentTimeMillis(), + isTransactional, false) + buffer.flip() + val records = MemoryRecords.readableRecords(buffer) + assertThrows(classOf[InvalidRecordException], () => LogValidator.validateMessagesAndAssignOffsets(records, + topicPartition, + offsetCounter = new LongRef(offset), + time = time, + now = System.currentTimeMillis(), + sourceCodec = sourceCodec, + targetCodec = targetCodec, + compactedTopic = false, + magic = RecordBatch.CURRENT_MAGIC_VALUE, + timestampType = TimestampType.CREATE_TIME, + timestampDiffMaxMs = 5000L, + partitionLeaderEpoch = RecordBatch.NO_PARTITION_LEADER_EPOCH, + origin = AppendOrigin.Client, + interBrokerProtocolVersion = ApiVersion.latestVersion, + brokerTopicStats = brokerTopicStats, + requestLocal = RequestLocal.withThreadConfinedCaching)) + } + + private def createRecords(magicValue: Byte, + timestamp: Long = RecordBatch.NO_TIMESTAMP, + codec: CompressionType): MemoryRecords = { + val buf = ByteBuffer.allocate(512) + val builder = MemoryRecords.builder(buf, magicValue, codec, TimestampType.CREATE_TIME, 0L) + builder.appendWithOffset(0, timestamp, null, "hello".getBytes) + builder.appendWithOffset(1, timestamp, null, "there".getBytes) + builder.appendWithOffset(2, timestamp, null, "beautiful".getBytes) + builder.build() + } + + private def createNonIncreasingOffsetRecords(magicValue: Byte, + timestamp: Long = RecordBatch.NO_TIMESTAMP, + codec: CompressionType = CompressionType.NONE): MemoryRecords = { + val buf = ByteBuffer.allocate(512) + val builder = MemoryRecords.builder(buf, magicValue, codec, TimestampType.CREATE_TIME, 0L) + builder.appendWithOffset(0, timestamp, null, "hello".getBytes) + builder.appendWithOffset(2, timestamp, null, "there".getBytes) + builder.appendWithOffset(3, timestamp, null, "beautiful".getBytes) + builder.build() + } + + private def createTwoBatchedRecords(magicValue: Byte, + timestamp: Long, + codec: CompressionType): MemoryRecords = { + val buf = ByteBuffer.allocate(2048) + var builder = MemoryRecords.builder(buf, magicValue, codec, TimestampType.CREATE_TIME, 0L) + builder.append(10L, "1".getBytes(), "a".getBytes()) + builder.close() + builder = MemoryRecords.builder(buf, magicValue, codec, TimestampType.CREATE_TIME, 1L) + builder.append(11L, 
"2".getBytes(), "b".getBytes()) + builder.append(12L, "3".getBytes(), "c".getBytes()) + builder.close() + + buf.flip() + MemoryRecords.readableRecords(buf.slice()) + } + + /* check that offsets are assigned consecutively from the given base offset */ + def checkOffsets(records: MemoryRecords, baseOffset: Long): Unit = { + assertTrue(records.records.asScala.nonEmpty, "Message set should not be empty") + var offset = baseOffset + for (entry <- records.records.asScala) { + assertEquals(offset, entry.offset, "Unexpected offset in message set iterator") + offset += 1 + } + } + + private def recordsWithNonSequentialInnerOffsets(magicValue: Byte, + codec: CompressionType, + numRecords: Int): MemoryRecords = { + val records = (0 until numRecords).map { id => + new SimpleRecord(id.toString.getBytes) + } + + val buffer = ByteBuffer.allocate(1024) + val builder = MemoryRecords.builder(buffer, magicValue, codec, TimestampType.CREATE_TIME, 0L) + + records.foreach { record => + builder.appendUncheckedWithOffset(0, record) + } + + builder.build() + } + + private def recordsWithInvalidInnerMagic(batchMagicValue: Byte, + recordMagicValue: Byte, + codec: CompressionType): MemoryRecords = { + val records = (0 until 20).map(id => + LegacyRecord.create(recordMagicValue, + RecordBatch.NO_TIMESTAMP, + id.toString.getBytes, + id.toString.getBytes)) + + val buffer = ByteBuffer.allocate(math.min(math.max(records.map(_.sizeInBytes()).sum / 2, 1024), 1 << 16)) + val builder = MemoryRecords.builder(buffer, batchMagicValue, codec, + TimestampType.CREATE_TIME, 0L) + + var offset = 1234567 + records.foreach { record => + builder.appendUncheckedWithOffset(offset, record) + offset += 1 + } + + builder.build() + } + + def maybeCheckBaseTimestamp(expected: Long, batch: RecordBatch): Unit = { + batch match { + case b: DefaultRecordBatch => + assertEquals(expected, b.baseTimestamp, s"Unexpected base timestamp of batch $batch") + case _ => // no-op + } + } + + /** + * expectedLogAppendTime is only checked if batch.magic is V2 or higher + */ + def validateLogAppendTime(expectedLogAppendTime: Long, expectedBaseTimestamp: Long, batch: RecordBatch): Unit = { + assertTrue(batch.isValid) + assertTrue(batch.timestampType == TimestampType.LOG_APPEND_TIME) + assertEquals(expectedLogAppendTime, batch.maxTimestamp, s"Unexpected max timestamp of batch $batch") + maybeCheckBaseTimestamp(expectedBaseTimestamp, batch) + for (record <- batch.asScala) { + record.ensureValid() + assertEquals(expectedLogAppendTime, record.timestamp, s"Unexpected timestamp of record $record") + } + } + + def verifyRecordConversionStats(stats: RecordConversionStats, numConvertedRecords: Int, records: MemoryRecords, + compressed: Boolean): Unit = { + assertNotNull(stats, "Records processing info is null") + assertEquals(numConvertedRecords, stats.numRecordsConverted) + if (numConvertedRecords > 0) { + assertTrue(stats.conversionTimeNanos >= 0, s"Conversion time not recorded $stats") + assertTrue(stats.conversionTimeNanos <= TimeUnit.MINUTES.toNanos(1), s"Conversion time not valid $stats") + } + val originalSize = records.sizeInBytes + val tempBytes = stats.temporaryMemoryBytes + if (numConvertedRecords > 0 && compressed) + assertTrue(tempBytes > originalSize, s"Temp bytes too small, orig=$originalSize actual=$tempBytes") + else if (numConvertedRecords > 0 || compressed) + assertTrue(tempBytes > 0, "Temp bytes not updated") + else + assertEquals(0, tempBytes) + } +} diff --git a/core/src/test/scala/unit/kafka/log/OffsetIndexTest.scala 
b/core/src/test/scala/unit/kafka/log/OffsetIndexTest.scala new file mode 100644 index 0000000..4a9dcd0 --- /dev/null +++ b/core/src/test/scala/unit/kafka/log/OffsetIndexTest.scala @@ -0,0 +1,240 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.log + +import java.io._ +import java.nio.file.Files + +import org.junit.jupiter.api.Assertions._ +import java.util.{Arrays, Collections} + +import org.junit.jupiter.api._ + +import scala.collection._ +import scala.util.Random +import kafka.utils.TestUtils +import org.apache.kafka.common.errors.InvalidOffsetException + +import scala.annotation.nowarn + +class OffsetIndexTest { + + var idx: OffsetIndex = null + val maxEntries = 30 + val baseOffset = 45L + + @BeforeEach + def setup(): Unit = { + this.idx = new OffsetIndex(nonExistentTempFile(), baseOffset, maxIndexSize = 30 * 8) + } + + @AfterEach + def teardown(): Unit = { + if(this.idx != null) + this.idx.file.delete() + } + + @nowarn("cat=deprecation") + @Test + def randomLookupTest(): Unit = { + assertEquals(OffsetPosition(idx.baseOffset, 0), idx.lookup(92L), + "Not present value should return physical offset 0.") + + // append some random values + val base = idx.baseOffset.toInt + 1 + val size = idx.maxEntries + val vals: Seq[(Long, Int)] = monotonicSeq(base, size).map(_.toLong).zip(monotonicSeq(0, size)) + vals.foreach{x => idx.append(x._1, x._2)} + + // should be able to find all those values + for((logical, physical) <- vals) + assertEquals(OffsetPosition(logical, physical), idx.lookup(logical), + "Should be able to find values that are present.") + + // for non-present values we should find the offset of the largest value less than or equal to this + val valMap = new immutable.TreeMap[Long, (Long, Int)]() ++ vals.map(p => (p._1, p)) + val offsets = (idx.baseOffset until vals.last._1.toInt).toArray + Collections.shuffle(Arrays.asList(offsets)) + for(offset <- offsets.take(30)) { + val rightAnswer = + if(offset < valMap.firstKey) + OffsetPosition(idx.baseOffset, 0) + else + OffsetPosition(valMap.to(offset).last._1, valMap.to(offset).last._2._2) + assertEquals(rightAnswer, idx.lookup(offset), + "The index should give the same answer as the sorted map") + } + } + + @Test + def lookupExtremeCases(): Unit = { + assertEquals(OffsetPosition(idx.baseOffset, 0), idx.lookup(idx.baseOffset), + "Lookup on empty file") + for(i <- 0 until idx.maxEntries) + idx.append(idx.baseOffset + i + 1, i) + // check first and last entry + assertEquals(OffsetPosition(idx.baseOffset, 0), idx.lookup(idx.baseOffset)) + assertEquals(OffsetPosition(idx.baseOffset + idx.maxEntries, idx.maxEntries - 1), idx.lookup(idx.baseOffset + idx.maxEntries)) + } + + @Test + def testEntry(): Unit = { + for (i <- 0 until idx.maxEntries) + idx.append(idx.baseOffset + i + 1, 
i) + for (i <- 0 until idx.maxEntries) + assertEquals(OffsetPosition(idx.baseOffset + i + 1, i), idx.entry(i)) + } + + @Test + def testEntryOverflow(): Unit = { + assertThrows(classOf[IllegalArgumentException], () => idx.entry(0)) + } + + @Test + def appendTooMany(): Unit = { + for(i <- 0 until idx.maxEntries) { + val offset = idx.baseOffset + i + 1 + idx.append(offset, i) + } + assertWriteFails("Append should fail on a full index", idx, idx.maxEntries + 1, classOf[IllegalArgumentException]) + } + + @Test + def appendOutOfOrder(): Unit = { + idx.append(51, 0) + assertThrows(classOf[InvalidOffsetException], () => idx.append(50, 1)) + } + + @Test + def testFetchUpperBoundOffset(): Unit = { + val first = OffsetPosition(baseOffset + 0, 0) + val second = OffsetPosition(baseOffset + 1, 10) + val third = OffsetPosition(baseOffset + 2, 23) + val fourth = OffsetPosition(baseOffset + 3, 37) + + assertEquals(None, idx.fetchUpperBoundOffset(first, 5)) + + for (offsetPosition <- Seq(first, second, third, fourth)) + idx.append(offsetPosition.offset, offsetPosition.position) + + assertEquals(Some(second), idx.fetchUpperBoundOffset(first, 5)) + assertEquals(Some(second), idx.fetchUpperBoundOffset(first, 10)) + assertEquals(Some(third), idx.fetchUpperBoundOffset(first, 23)) + assertEquals(Some(third), idx.fetchUpperBoundOffset(first, 22)) + assertEquals(Some(fourth), idx.fetchUpperBoundOffset(second, 24)) + assertEquals(None, idx.fetchUpperBoundOffset(fourth, 1)) + assertEquals(None, idx.fetchUpperBoundOffset(first, 200)) + assertEquals(None, idx.fetchUpperBoundOffset(second, 200)) + } + + @Test + def testReopen(): Unit = { + val first = OffsetPosition(51, 0) + val sec = OffsetPosition(52, 1) + idx.append(first.offset, first.position) + idx.append(sec.offset, sec.position) + idx.close() + val idxRo = new OffsetIndex(idx.file, baseOffset = idx.baseOffset) + assertEquals(first, idxRo.lookup(first.offset)) + assertEquals(sec, idxRo.lookup(sec.offset)) + assertEquals(sec.offset, idxRo.lastOffset) + assertEquals(2, idxRo.entries) + assertWriteFails("Append should fail on read-only index", idxRo, 53, classOf[IllegalArgumentException]) + } + + @Test + def truncate(): Unit = { + val idx = new OffsetIndex(nonExistentTempFile(), baseOffset = 0L, maxIndexSize = 10 * 8) + idx.truncate() + for(i <- 1 until 10) + idx.append(i, i) + + // now check the last offset after various truncate points and validate that we can still append to the index. 
+ idx.truncateTo(12) + assertEquals(OffsetPosition(9, 9), idx.lookup(10), + "Index should be unchanged by truncate past the end") + assertEquals(9, idx.lastOffset, + "9 should be the last entry in the index") + + idx.append(10, 10) + idx.truncateTo(10) + assertEquals(OffsetPosition(9, 9), idx.lookup(10), + "Index should be unchanged by truncate at the end") + assertEquals(9, idx.lastOffset, + "9 should be the last entry in the index") + idx.append(10, 10) + + idx.truncateTo(9) + assertEquals(OffsetPosition(8, 8), idx.lookup(10), + "Index should truncate off last entry") + assertEquals(8, idx.lastOffset, + "8 should be the last entry in the index") + idx.append(9, 9) + + idx.truncateTo(5) + assertEquals(OffsetPosition(4, 4), idx.lookup(10), + "4 should be the last entry in the index") + assertEquals(4, idx.lastOffset, + "4 should be the last entry in the index") + idx.append(5, 5) + + idx.truncate() + assertEquals(0, idx.entries, "Full truncation should leave no entries") + idx.append(0, 0) + } + + @Test + def forceUnmapTest(): Unit = { + val idx = new OffsetIndex(nonExistentTempFile(), baseOffset = 0L, maxIndexSize = 10 * 8) + idx.forceUnmap() + // mmap should be null after unmap causing lookup to throw a NPE + assertThrows(classOf[NullPointerException], () => idx.lookup(1)) + } + + @Test + def testSanityLastOffsetEqualToBaseOffset(): Unit = { + // Test index sanity for the case where the last offset appended to the index is equal to the base offset + val baseOffset = 20L + val idx = new OffsetIndex(nonExistentTempFile(), baseOffset = baseOffset, maxIndexSize = 10 * 8) + idx.append(baseOffset, 0) + idx.sanityCheck() + } + + def assertWriteFails[T](message: String, idx: OffsetIndex, offset: Int, klass: Class[T]): Unit = { + val e = assertThrows(classOf[Exception], () => idx.append(offset, 1), () => message) + assertEquals(klass, e.getClass, "Got an unexpected exception.") + } + + def monotonicSeq(base: Int, len: Int): Seq[Int] = { + val rand = new Random(1L) + val vals = new mutable.ArrayBuffer[Int](len) + var last = base + for (_ <- 0 until len) { + last += rand.nextInt(15) + 1 + vals += last + } + vals + } + + def nonExistentTempFile(): File = { + val file = TestUtils.tempFile() + Files.delete(file.toPath) + file + } + +} diff --git a/core/src/test/scala/unit/kafka/log/OffsetMapTest.scala b/core/src/test/scala/unit/kafka/log/OffsetMapTest.scala new file mode 100644 index 0000000..dd5294b --- /dev/null +++ b/core/src/test/scala/unit/kafka/log/OffsetMapTest.scala @@ -0,0 +1,88 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package kafka.log + +import java.nio._ + +import kafka.utils.Exit +import org.junit.jupiter.api._ +import org.junit.jupiter.api.Assertions._ + +class OffsetMapTest { + + @Test + def testBasicValidation(): Unit = { + validateMap(10) + validateMap(100) + validateMap(1000) + validateMap(5000) + } + + @Test + def testClear(): Unit = { + val map = new SkimpyOffsetMap(4000) + for(i <- 0 until 10) + map.put(key(i), i) + for(i <- 0 until 10) + assertEquals(i.toLong, map.get(key(i))) + map.clear() + for(i <- 0 until 10) + assertEquals(map.get(key(i)), -1L) + } + + @Test + def testGetWhenFull(): Unit = { + val map = new SkimpyOffsetMap(4096) + var i = 37L //any value would do + while (map.size < map.slots) { + map.put(key(i), i) + i = i + 1L + } + assertEquals(map.get(key(i)), -1L) + assertEquals(map.get(key(i-1L)), i-1L) + } + + def key(key: Long) = ByteBuffer.wrap(key.toString.getBytes) + + def validateMap(items: Int, loadFactor: Double = 0.5): SkimpyOffsetMap = { + val map = new SkimpyOffsetMap((items/loadFactor * 24).toInt) + for(i <- 0 until items) + map.put(key(i), i) + for(i <- 0 until items) + assertEquals(map.get(key(i)), i.toLong) + map + } + +} + +object OffsetMapTest { + def main(args: Array[String]): Unit = { + if(args.length != 2) { + System.err.println("USAGE: java OffsetMapTest size load") + Exit.exit(1) + } + val test = new OffsetMapTest() + val size = args(0).toInt + val load = args(1).toDouble + val start = System.nanoTime + val map = test.validateMap(size, load) + val ellapsedMs = (System.nanoTime - start) / 1000.0 / 1000.0 + println(s"${map.size} entries in map of size ${map.slots} in $ellapsedMs ms") + println("Collision rate: %.1f%%".format(100*map.collisionRate)) + } +} diff --git a/core/src/test/scala/unit/kafka/log/ProducerStateManagerTest.scala b/core/src/test/scala/unit/kafka/log/ProducerStateManagerTest.scala new file mode 100644 index 0000000..0c2fb6b --- /dev/null +++ b/core/src/test/scala/unit/kafka/log/ProducerStateManagerTest.scala @@ -0,0 +1,1010 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package kafka.log + +import java.io.File +import java.nio.ByteBuffer +import java.nio.channels.FileChannel +import java.nio.file.{Files, StandardOpenOption} +import java.util.Collections +import java.util.concurrent.atomic.AtomicInteger + +import kafka.server.LogOffsetMetadata +import kafka.utils.TestUtils +import org.apache.kafka.common.TopicPartition +import org.apache.kafka.common.errors._ +import org.apache.kafka.common.internals.Topic +import org.apache.kafka.common.record._ +import org.apache.kafka.common.utils.{MockTime, Utils} +import org.easymock.EasyMock +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.{AfterEach, BeforeEach, Test} + +class ProducerStateManagerTest { + var logDir: File = null + var stateManager: ProducerStateManager = null + val partition = new TopicPartition("test", 0) + val producerId = 1L + val maxPidExpirationMs = 60 * 1000 + val time = new MockTime + + @BeforeEach + def setUp(): Unit = { + logDir = TestUtils.tempDir() + stateManager = new ProducerStateManager(partition, logDir, maxPidExpirationMs, time) + } + + @AfterEach + def tearDown(): Unit = { + Utils.delete(logDir) + } + + @Test + def testBasicIdMapping(): Unit = { + val epoch = 0.toShort + + // First entry for id 0 added + append(stateManager, producerId, epoch, 0, 0L, 0L) + + // Second entry for id 0 added + append(stateManager, producerId, epoch, 1, 0L, 1L) + + // Duplicates are checked separately and should result in OutOfOrderSequence if appended + assertThrows(classOf[OutOfOrderSequenceException], () => append(stateManager, producerId, epoch, 1, 0L, 1L)) + + // Invalid sequence number (greater than next expected sequence number) + assertThrows(classOf[OutOfOrderSequenceException], () => append(stateManager, producerId, epoch, 5, 0L, 2L)) + + // Change epoch + append(stateManager, producerId, (epoch + 1).toShort, 0, 0L, 3L) + + // Incorrect epoch + assertThrows(classOf[InvalidProducerEpochException], () => append(stateManager, producerId, epoch, 0, 0L, 4L)) + } + + @Test + def testAppendTxnMarkerWithNoProducerState(): Unit = { + val producerEpoch = 2.toShort + appendEndTxnMarker(stateManager, producerId, producerEpoch, ControlRecordType.COMMIT, offset = 27L) + + val firstEntry = stateManager.lastEntry(producerId).getOrElse(throw new RuntimeException("Expected last entry to be defined")) + assertEquals(producerEpoch, firstEntry.producerEpoch) + assertEquals(producerId, firstEntry.producerId) + assertEquals(RecordBatch.NO_SEQUENCE, firstEntry.lastSeq) + + // Fencing should continue to work even if the marker is the only thing left + assertThrows(classOf[InvalidProducerEpochException], () => append(stateManager, producerId, 0.toShort, 0, 0L, 4L)) + + // If the transaction marker is the only thing left in the log, then an attempt to write using a + // non-zero sequence number should cause an OutOfOrderSequenceException, so that the producer can reset its state + assertThrows(classOf[OutOfOrderSequenceException], () => append(stateManager, producerId, producerEpoch, 17, 0L, 4L)) + + // The broker should accept the request if the sequence number is reset to 0 + append(stateManager, producerId, producerEpoch, 0, 39L, 4L) + val secondEntry = stateManager.lastEntry(producerId).getOrElse(throw new RuntimeException("Expected last entry to be defined")) + assertEquals(producerEpoch, secondEntry.producerEpoch) + assertEquals(producerId, secondEntry.producerId) + assertEquals(0, secondEntry.lastSeq) + } + + @Test + def testProducerSequenceWrapAround(): Unit = { + val epoch = 
15.toShort + val sequence = Int.MaxValue + val offset = 735L + append(stateManager, producerId, epoch, sequence, offset, origin = AppendOrigin.Replication) + + append(stateManager, producerId, epoch, 0, offset + 500) + + val maybeLastEntry = stateManager.lastEntry(producerId) + assertTrue(maybeLastEntry.isDefined) + + val lastEntry = maybeLastEntry.get + assertEquals(epoch, lastEntry.producerEpoch) + + assertEquals(Int.MaxValue, lastEntry.firstSeq) + assertEquals(0, lastEntry.lastSeq) + } + + @Test + def testProducerSequenceWithWrapAroundBatchRecord(): Unit = { + val epoch = 15.toShort + + val appendInfo = stateManager.prepareUpdate(producerId, origin = AppendOrigin.Replication) + // Sequence number wrap around + appendInfo.appendDataBatch(epoch, Int.MaxValue - 10, 9, time.milliseconds(), + LogOffsetMetadata(2000L), 2020L, isTransactional = false) + assertEquals(None, stateManager.lastEntry(producerId)) + stateManager.update(appendInfo) + assertTrue(stateManager.lastEntry(producerId).isDefined) + + val lastEntry = stateManager.lastEntry(producerId).get + assertEquals(Int.MaxValue-10, lastEntry.firstSeq) + assertEquals(9, lastEntry.lastSeq) + assertEquals(2000L, lastEntry.firstDataOffset) + assertEquals(2020L, lastEntry.lastDataOffset) + } + + @Test + def testProducerSequenceInvalidWrapAround(): Unit = { + val epoch = 15.toShort + val sequence = Int.MaxValue + val offset = 735L + append(stateManager, producerId, epoch, sequence, offset, origin = AppendOrigin.Replication) + assertThrows(classOf[OutOfOrderSequenceException], () => append(stateManager, producerId, epoch, 1, offset + 500)) + } + + @Test + def testNoValidationOnFirstEntryWhenLoadingLog(): Unit = { + val epoch = 5.toShort + val sequence = 16 + val offset = 735L + append(stateManager, producerId, epoch, sequence, offset, origin = AppendOrigin.Replication) + + val maybeLastEntry = stateManager.lastEntry(producerId) + assertTrue(maybeLastEntry.isDefined) + + val lastEntry = maybeLastEntry.get + assertEquals(epoch, lastEntry.producerEpoch) + assertEquals(sequence, lastEntry.firstSeq) + assertEquals(sequence, lastEntry.lastSeq) + assertEquals(offset, lastEntry.lastDataOffset) + assertEquals(offset, lastEntry.firstDataOffset) + } + + @Test + def testControlRecordBumpsProducerEpoch(): Unit = { + val producerEpoch = 0.toShort + append(stateManager, producerId, producerEpoch, 0, 0L) + + val bumpedProducerEpoch = 1.toShort + appendEndTxnMarker(stateManager, producerId, bumpedProducerEpoch, ControlRecordType.ABORT, 1L) + + val maybeLastEntry = stateManager.lastEntry(producerId) + assertTrue(maybeLastEntry.isDefined) + + val lastEntry = maybeLastEntry.get + assertEquals(bumpedProducerEpoch, lastEntry.producerEpoch) + assertEquals(None, lastEntry.currentTxnFirstOffset) + assertEquals(RecordBatch.NO_SEQUENCE, lastEntry.firstSeq) + assertEquals(RecordBatch.NO_SEQUENCE, lastEntry.lastSeq) + + // should be able to append with the new epoch if we start at sequence 0 + append(stateManager, producerId, bumpedProducerEpoch, 0, 2L) + assertEquals(Some(0), stateManager.lastEntry(producerId).map(_.firstSeq)) + } + + @Test + def testTxnFirstOffsetMetadataCached(): Unit = { + val producerEpoch = 0.toShort + val offset = 992342L + val seq = 0 + val producerAppendInfo = new ProducerAppendInfo(partition, producerId, ProducerStateEntry.empty(producerId), AppendOrigin.Client) + + val firstOffsetMetadata = LogOffsetMetadata(messageOffset = offset, segmentBaseOffset = 990000L, + relativePositionInSegment = 234224) + 
producerAppendInfo.appendDataBatch(producerEpoch, seq, seq, time.milliseconds(), + firstOffsetMetadata, offset, isTransactional = true) + stateManager.update(producerAppendInfo) + + assertEquals(Some(firstOffsetMetadata), stateManager.firstUnstableOffset) + } + + @Test + def testSkipEmptyTransactions(): Unit = { + val producerEpoch = 0.toShort + val coordinatorEpoch = 27 + val seq = new AtomicInteger(0) + + def appendEndTxn( + recordType: ControlRecordType, + offset: Long, + appendInfo: ProducerAppendInfo + ): Option[CompletedTxn] = { + appendInfo.appendEndTxnMarker(new EndTransactionMarker(recordType, coordinatorEpoch), + producerEpoch, offset, time.milliseconds()) + } + + def appendData( + startOffset: Long, + endOffset: Long, + appendInfo: ProducerAppendInfo + ): Unit = { + val count = (endOffset - startOffset).toInt + appendInfo.appendDataBatch(producerEpoch, seq.get(), seq.addAndGet(count), time.milliseconds(), + LogOffsetMetadata(startOffset), endOffset, isTransactional = true) + seq.incrementAndGet() + } + + // Start one transaction in a separate append + val firstAppend = stateManager.prepareUpdate(producerId, origin = AppendOrigin.Client) + appendData(16L, 20L, firstAppend) + assertEquals(new TxnMetadata(producerId, 16L), firstAppend.startedTransactions.head) + stateManager.update(firstAppend) + stateManager.onHighWatermarkUpdated(21L) + assertEquals(Some(LogOffsetMetadata(16L)), stateManager.firstUnstableOffset) + + // Now do a single append which completes the old transaction, mixes in + // some empty transactions, one non-empty complete transaction, and one + // incomplete transaction + val secondAppend = stateManager.prepareUpdate(producerId, origin = AppendOrigin.Client) + val firstCompletedTxn = appendEndTxn(ControlRecordType.COMMIT, 21, secondAppend) + assertEquals(Some(CompletedTxn(producerId, 16L, 21, isAborted = false)), firstCompletedTxn) + assertEquals(None, appendEndTxn(ControlRecordType.COMMIT, 22, secondAppend)) + assertEquals(None, appendEndTxn(ControlRecordType.ABORT, 23, secondAppend)) + appendData(24L, 27L, secondAppend) + val secondCompletedTxn = appendEndTxn(ControlRecordType.ABORT, 28L, secondAppend) + assertTrue(secondCompletedTxn.isDefined) + assertEquals(None, appendEndTxn(ControlRecordType.ABORT, 29L, secondAppend)) + appendData(30L, 31L, secondAppend) + + assertEquals(2, secondAppend.startedTransactions.size) + assertEquals(TxnMetadata(producerId, LogOffsetMetadata(24L)), secondAppend.startedTransactions.head) + assertEquals(TxnMetadata(producerId, LogOffsetMetadata(30L)), secondAppend.startedTransactions.last) + stateManager.update(secondAppend) + stateManager.completeTxn(firstCompletedTxn.get) + stateManager.completeTxn(secondCompletedTxn.get) + stateManager.onHighWatermarkUpdated(32L) + assertEquals(Some(LogOffsetMetadata(30L)), stateManager.firstUnstableOffset) + } + + @Test + def testLastStableOffsetCompletedTxn(): Unit = { + val producerEpoch = 0.toShort + val segmentBaseOffset = 990000L + + def beginTxn(producerId: Long, startOffset: Long): Unit = { + val relativeOffset = (startOffset - segmentBaseOffset).toInt + val producerAppendInfo = new ProducerAppendInfo( + partition, + producerId, + ProducerStateEntry.empty(producerId), + AppendOrigin.Client + ) + val firstOffsetMetadata = LogOffsetMetadata(messageOffset = startOffset, segmentBaseOffset = segmentBaseOffset, + relativePositionInSegment = 50 * relativeOffset) + producerAppendInfo.appendDataBatch(producerEpoch, 0, 0, time.milliseconds(), + firstOffsetMetadata, startOffset, isTransactional = 
true) + stateManager.update(producerAppendInfo) + } + + val producerId1 = producerId + val startOffset1 = 992342L + beginTxn(producerId1, startOffset1) + + val producerId2 = producerId + 1 + val startOffset2 = startOffset1 + 25 + beginTxn(producerId2, startOffset2) + + val producerId3 = producerId + 2 + val startOffset3 = startOffset1 + 57 + beginTxn(producerId3, startOffset3) + + val lastOffset1 = startOffset3 + 15 + val completedTxn1 = CompletedTxn(producerId1, startOffset1, lastOffset1, isAborted = false) + assertEquals(startOffset2, stateManager.lastStableOffset(completedTxn1)) + stateManager.completeTxn(completedTxn1) + stateManager.onHighWatermarkUpdated(lastOffset1 + 1) + assertEquals(Some(startOffset2), stateManager.firstUnstableOffset.map(_.messageOffset)) + + val lastOffset3 = lastOffset1 + 20 + val completedTxn3 = CompletedTxn(producerId3, startOffset3, lastOffset3, isAborted = false) + assertEquals(startOffset2, stateManager.lastStableOffset(completedTxn3)) + stateManager.completeTxn(completedTxn3) + stateManager.onHighWatermarkUpdated(lastOffset3 + 1) + assertEquals(Some(startOffset2), stateManager.firstUnstableOffset.map(_.messageOffset)) + + val lastOffset2 = lastOffset3 + 78 + val completedTxn2 = CompletedTxn(producerId2, startOffset2, lastOffset2, isAborted = false) + assertEquals(lastOffset2 + 1, stateManager.lastStableOffset(completedTxn2)) + stateManager.completeTxn(completedTxn2) + stateManager.onHighWatermarkUpdated(lastOffset2 + 1) + assertEquals(None, stateManager.firstUnstableOffset) + } + + @Test + def testPrepareUpdateDoesNotMutate(): Unit = { + val producerEpoch = 0.toShort + + val appendInfo = stateManager.prepareUpdate(producerId, origin = AppendOrigin.Client) + appendInfo.appendDataBatch(producerEpoch, 0, 5, time.milliseconds(), + LogOffsetMetadata(15L), 20L, isTransactional = false) + assertEquals(None, stateManager.lastEntry(producerId)) + stateManager.update(appendInfo) + assertTrue(stateManager.lastEntry(producerId).isDefined) + + val nextAppendInfo = stateManager.prepareUpdate(producerId, origin = AppendOrigin.Client) + nextAppendInfo.appendDataBatch(producerEpoch, 6, 10, time.milliseconds(), + LogOffsetMetadata(26L), 30L, isTransactional = false) + assertTrue(stateManager.lastEntry(producerId).isDefined) + + var lastEntry = stateManager.lastEntry(producerId).get + assertEquals(0, lastEntry.firstSeq) + assertEquals(5, lastEntry.lastSeq) + assertEquals(20L, lastEntry.lastDataOffset) + + stateManager.update(nextAppendInfo) + lastEntry = stateManager.lastEntry(producerId).get + assertEquals(0, lastEntry.firstSeq) + assertEquals(10, lastEntry.lastSeq) + assertEquals(30L, lastEntry.lastDataOffset) + } + + @Test + def updateProducerTransactionState(): Unit = { + val producerEpoch = 0.toShort + val coordinatorEpoch = 15 + val offset = 9L + append(stateManager, producerId, producerEpoch, 0, offset) + + val appendInfo = stateManager.prepareUpdate(producerId, origin = AppendOrigin.Client) + appendInfo.appendDataBatch(producerEpoch, 1, 5, time.milliseconds(), + LogOffsetMetadata(16L), 20L, isTransactional = true) + var lastEntry = appendInfo.toEntry + assertEquals(producerEpoch, lastEntry.producerEpoch) + assertEquals(1, lastEntry.firstSeq) + assertEquals(5, lastEntry.lastSeq) + assertEquals(16L, lastEntry.firstDataOffset) + assertEquals(20L, lastEntry.lastDataOffset) + assertEquals(Some(16L), lastEntry.currentTxnFirstOffset) + assertEquals(List(new TxnMetadata(producerId, 16L)), appendInfo.startedTransactions) + + appendInfo.appendDataBatch(producerEpoch, 6, 10, 
time.milliseconds(), + LogOffsetMetadata(26L), 30L, isTransactional = true) + lastEntry = appendInfo.toEntry + assertEquals(producerEpoch, lastEntry.producerEpoch) + assertEquals(1, lastEntry.firstSeq) + assertEquals(10, lastEntry.lastSeq) + assertEquals(16L, lastEntry.firstDataOffset) + assertEquals(30L, lastEntry.lastDataOffset) + assertEquals(Some(16L), lastEntry.currentTxnFirstOffset) + assertEquals(List(new TxnMetadata(producerId, 16L)), appendInfo.startedTransactions) + + val endTxnMarker = new EndTransactionMarker(ControlRecordType.COMMIT, coordinatorEpoch) + val completedTxnOpt = appendInfo.appendEndTxnMarker(endTxnMarker, producerEpoch, 40L, time.milliseconds()) + assertTrue(completedTxnOpt.isDefined) + + val completedTxn = completedTxnOpt.get + assertEquals(producerId, completedTxn.producerId) + assertEquals(16L, completedTxn.firstOffset) + assertEquals(40L, completedTxn.lastOffset) + assertFalse(completedTxn.isAborted) + + lastEntry = appendInfo.toEntry + assertEquals(producerEpoch, lastEntry.producerEpoch) + // verify that appending the transaction marker doesn't affect the metadata of the cached record batches. + assertEquals(1, lastEntry.firstSeq) + assertEquals(10, lastEntry.lastSeq) + assertEquals(16L, lastEntry.firstDataOffset) + assertEquals(30L, lastEntry.lastDataOffset) + assertEquals(coordinatorEpoch, lastEntry.coordinatorEpoch) + assertEquals(None, lastEntry.currentTxnFirstOffset) + assertEquals(List(new TxnMetadata(producerId, 16L)), appendInfo.startedTransactions) + } + + @Test + def testOutOfSequenceAfterControlRecordEpochBump(): Unit = { + val epoch = 0.toShort + append(stateManager, producerId, epoch, 0, 0L, isTransactional = true) + append(stateManager, producerId, epoch, 1, 1L, isTransactional = true) + + val bumpedEpoch = 1.toShort + appendEndTxnMarker(stateManager, producerId, bumpedEpoch, ControlRecordType.ABORT, 1L) + + // next append is invalid since we expect the sequence to be reset + assertThrows(classOf[OutOfOrderSequenceException], + () => append(stateManager, producerId, bumpedEpoch, 2, 2L, isTransactional = true)) + + assertThrows(classOf[OutOfOrderSequenceException], + () => append(stateManager, producerId, (bumpedEpoch + 1).toShort, 2, 2L, isTransactional = true)) + + // Append with the bumped epoch should be fine if starting from sequence 0 + append(stateManager, producerId, bumpedEpoch, 0, 0L, isTransactional = true) + assertEquals(bumpedEpoch, stateManager.lastEntry(producerId).get.producerEpoch) + assertEquals(0, stateManager.lastEntry(producerId).get.lastSeq) + } + + @Test + def testNonTransactionalAppendWithOngoingTransaction(): Unit = { + val epoch = 0.toShort + append(stateManager, producerId, epoch, 0, 0L, isTransactional = true) + assertThrows(classOf[InvalidTxnStateException], () => append(stateManager, producerId, epoch, 1, 1L, isTransactional = false)) + } + + @Test + def testTruncateAndReloadRemovesOutOfRangeSnapshots(): Unit = { + val epoch = 0.toShort + append(stateManager, producerId, epoch, 0, 0L) + stateManager.takeSnapshot() + append(stateManager, producerId, epoch, 1, 1L) + stateManager.takeSnapshot() + append(stateManager, producerId, epoch, 2, 2L) + stateManager.takeSnapshot() + append(stateManager, producerId, epoch, 3, 3L) + stateManager.takeSnapshot() + append(stateManager, producerId, epoch, 4, 4L) + stateManager.takeSnapshot() + + stateManager.truncateAndReload(1L, 3L, time.milliseconds()) + + assertEquals(Some(2L), stateManager.oldestSnapshotOffset) + assertEquals(Some(3L), stateManager.latestSnapshotOffset) + } + + 
@Test + def testTakeSnapshot(): Unit = { + val epoch = 0.toShort + append(stateManager, producerId, epoch, 0, 0L, 0L) + append(stateManager, producerId, epoch, 1, 1L, 1L) + + // Take snapshot + stateManager.takeSnapshot() + + // Check that file exists and it is not empty + assertEquals(1, logDir.list().length, "Directory doesn't contain a single file as expected") + assertTrue(logDir.list().head.nonEmpty, "Snapshot file is empty") + } + + @Test + def testRecoverFromSnapshotUnfinishedTransaction(): Unit = { + val epoch = 0.toShort + append(stateManager, producerId, epoch, 0, 0L, isTransactional = true) + append(stateManager, producerId, epoch, 1, 1L, isTransactional = true) + + stateManager.takeSnapshot() + val recoveredMapping = new ProducerStateManager(partition, logDir, maxPidExpirationMs, time) + recoveredMapping.truncateAndReload(0L, 3L, time.milliseconds) + + // The snapshot only persists the last appended batch metadata + val loadedEntry = recoveredMapping.lastEntry(producerId) + assertEquals(1, loadedEntry.get.firstDataOffset) + assertEquals(1, loadedEntry.get.firstSeq) + assertEquals(1, loadedEntry.get.lastDataOffset) + assertEquals(1, loadedEntry.get.lastSeq) + assertEquals(Some(0), loadedEntry.get.currentTxnFirstOffset) + + // entry added after recovery + append(recoveredMapping, producerId, epoch, 2, 2L, isTransactional = true) + } + + @Test + def testRecoverFromSnapshotFinishedTransaction(): Unit = { + val epoch = 0.toShort + append(stateManager, producerId, epoch, 0, 0L, isTransactional = true) + append(stateManager, producerId, epoch, 1, 1L, isTransactional = true) + appendEndTxnMarker(stateManager, producerId, epoch, ControlRecordType.ABORT, offset = 2L) + + stateManager.takeSnapshot() + val recoveredMapping = new ProducerStateManager(partition, logDir, maxPidExpirationMs, time) + recoveredMapping.truncateAndReload(0L, 3L, time.milliseconds) + + // The snapshot only persists the last appended batch metadata + val loadedEntry = recoveredMapping.lastEntry(producerId) + assertEquals(1, loadedEntry.get.firstDataOffset) + assertEquals(1, loadedEntry.get.firstSeq) + assertEquals(1, loadedEntry.get.lastDataOffset) + assertEquals(1, loadedEntry.get.lastSeq) + assertEquals(None, loadedEntry.get.currentTxnFirstOffset) + } + + @Test + def testRecoverFromSnapshotEmptyTransaction(): Unit = { + val epoch = 0.toShort + val appendTimestamp = time.milliseconds() + appendEndTxnMarker(stateManager, producerId, epoch, ControlRecordType.ABORT, + offset = 0L, timestamp = appendTimestamp) + stateManager.takeSnapshot() + + val recoveredMapping = new ProducerStateManager(partition, logDir, maxPidExpirationMs, time) + recoveredMapping.truncateAndReload(logStartOffset = 0L, logEndOffset = 1L, time.milliseconds) + + val lastEntry = recoveredMapping.lastEntry(producerId) + assertTrue(lastEntry.isDefined) + assertEquals(appendTimestamp, lastEntry.get.lastTimestamp) + assertEquals(None, lastEntry.get.currentTxnFirstOffset) + } + + @Test + def testProducerStateAfterFencingAbortMarker(): Unit = { + val epoch = 0.toShort + append(stateManager, producerId, epoch, 0, 0L, isTransactional = true) + appendEndTxnMarker(stateManager, producerId, (epoch + 1).toShort, ControlRecordType.ABORT, offset = 1L) + + val lastEntry = stateManager.lastEntry(producerId).get + assertEquals(None, lastEntry.currentTxnFirstOffset) + assertEquals(-1, lastEntry.lastDataOffset) + assertEquals(-1, lastEntry.firstDataOffset) + + // The producer should not be expired because we want to preserve fencing epochs + 
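// Even though the retained entry carries no data offsets (both are -1 above), it must survive expiration so that the bumped epoch continues to fence older producer epochs. +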
stateManager.removeExpiredProducers(time.milliseconds()) + assertTrue(stateManager.lastEntry(producerId).isDefined) + } + + @Test + def testRemoveExpiredPidsOnReload(): Unit = { + val epoch = 0.toShort + append(stateManager, producerId, epoch, 0, 0L, 0) + append(stateManager, producerId, epoch, 1, 1L, 1) + + stateManager.takeSnapshot() + val recoveredMapping = new ProducerStateManager(partition, logDir, maxPidExpirationMs, time) + recoveredMapping.truncateAndReload(0L, 1L, 70000) + + // entry added after recovery. The pid should be expired now, and would not exist in the pid mapping. Hence + // we should accept the append and add the pid back in + append(recoveredMapping, producerId, epoch, 2, 2L, 70001) + + assertEquals(1, recoveredMapping.activeProducers.size) + assertEquals(2, recoveredMapping.activeProducers.head._2.lastSeq) + assertEquals(3L, recoveredMapping.mapEndOffset) + } + + @Test + def testAcceptAppendWithoutProducerStateOnReplica(): Unit = { + val epoch = 0.toShort + append(stateManager, producerId, epoch, 0, 0L, 0) + append(stateManager, producerId, epoch, 1, 1L, 1) + + stateManager.takeSnapshot() + val recoveredMapping = new ProducerStateManager(partition, logDir, maxPidExpirationMs, time) + recoveredMapping.truncateAndReload(0L, 1L, 70000) + + val sequence = 2 + // entry added after recovery. The pid should be expired now, and would not exist in the pid mapping. Nonetheless + // the append on a replica should be accepted with the local producer state updated to the appended value. + assertFalse(recoveredMapping.activeProducers.contains(producerId)) + append(recoveredMapping, producerId, epoch, sequence, 2L, 70001, origin = AppendOrigin.Replication) + assertTrue(recoveredMapping.activeProducers.contains(producerId)) + val producerStateEntry = recoveredMapping.activeProducers.get(producerId).head + assertEquals(epoch, producerStateEntry.producerEpoch) + assertEquals(sequence, producerStateEntry.firstSeq) + assertEquals(sequence, producerStateEntry.lastSeq) + } + + @Test + def testAcceptAppendWithSequenceGapsOnReplica(): Unit = { + val epoch = 0.toShort + append(stateManager, producerId, epoch, 0, 0L, 0) + val outOfOrderSequence = 3 + + // First we ensure that an OutOfOrderSequenceException is raised when the append comes from a client.
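+ // A replica append with the same sequence gap is accepted further below, since sequence validation is skipped for appends with origin AppendOrigin.Replication.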
+ assertThrows(classOf[OutOfOrderSequenceException], () => append(stateManager, producerId, epoch, outOfOrderSequence, 1L, 1, origin = AppendOrigin.Client)) + + assertEquals(0L, stateManager.activeProducers(producerId).lastSeq) + append(stateManager, producerId, epoch, outOfOrderSequence, 1L, 1, origin = AppendOrigin.Replication) + assertEquals(outOfOrderSequence, stateManager.activeProducers(producerId).lastSeq) + } + + @Test + def testDeleteSnapshotsBefore(): Unit = { + val epoch = 0.toShort + append(stateManager, producerId, epoch, 0, 0L) + append(stateManager, producerId, epoch, 1, 1L) + stateManager.takeSnapshot() + assertEquals(1, logDir.listFiles().length) + assertEquals(Set(2), currentSnapshotOffsets) + + append(stateManager, producerId, epoch, 2, 2L) + stateManager.takeSnapshot() + assertEquals(2, logDir.listFiles().length) + assertEquals(Set(2, 3), currentSnapshotOffsets) + + stateManager.deleteSnapshotsBefore(3L) + assertEquals(1, logDir.listFiles().length) + assertEquals(Set(3), currentSnapshotOffsets) + + stateManager.deleteSnapshotsBefore(4L) + assertEquals(0, logDir.listFiles().length) + assertEquals(Set(), currentSnapshotOffsets) + } + + @Test + def testTruncateFullyAndStartAt(): Unit = { + val epoch = 0.toShort + + append(stateManager, producerId, epoch, 0, 0L) + append(stateManager, producerId, epoch, 1, 1L) + stateManager.takeSnapshot() + assertEquals(1, logDir.listFiles().length) + assertEquals(Set(2), currentSnapshotOffsets) + + append(stateManager, producerId, epoch, 2, 2L) + stateManager.takeSnapshot() + assertEquals(2, logDir.listFiles().length) + assertEquals(Set(2, 3), currentSnapshotOffsets) + + stateManager.truncateFullyAndStartAt(0L) + + assertEquals(0, logDir.listFiles().length) + assertEquals(Set(), currentSnapshotOffsets) + + append(stateManager, producerId, epoch, 0, 0L) + stateManager.takeSnapshot() + assertEquals(1, logDir.listFiles().length) + assertEquals(Set(1), currentSnapshotOffsets) + } + + @Test + def testFirstUnstableOffsetAfterTruncation(): Unit = { + val epoch = 0.toShort + val sequence = 0 + + append(stateManager, producerId, epoch, sequence, offset = 99, isTransactional = true) + assertEquals(Some(99), stateManager.firstUnstableOffset.map(_.messageOffset)) + stateManager.takeSnapshot() + + appendEndTxnMarker(stateManager, producerId, epoch, ControlRecordType.COMMIT, offset = 105) + stateManager.onHighWatermarkUpdated(106) + assertEquals(None, stateManager.firstUnstableOffset.map(_.messageOffset)) + stateManager.takeSnapshot() + + append(stateManager, producerId, epoch, sequence + 1, offset = 106) + stateManager.truncateAndReload(0L, 106, time.milliseconds()) + assertEquals(None, stateManager.firstUnstableOffset.map(_.messageOffset)) + + stateManager.truncateAndReload(0L, 100L, time.milliseconds()) + assertEquals(Some(99), stateManager.firstUnstableOffset.map(_.messageOffset)) + } + + @Test + def testLoadFromSnapshotRetainsNonExpiredProducers(): Unit = { + val epoch = 0.toShort + val pid1 = 1L + val pid2 = 2L + + append(stateManager, pid1, epoch, 0, 0L) + append(stateManager, pid2, epoch, 0, 1L) + stateManager.takeSnapshot() + assertEquals(2, stateManager.activeProducers.size) + + stateManager.truncateAndReload(1L, 2L, time.milliseconds()) + assertEquals(2, stateManager.activeProducers.size) + + val entry1 = stateManager.lastEntry(pid1) + assertTrue(entry1.isDefined) + assertEquals(0, entry1.get.lastSeq) + assertEquals(0L, entry1.get.lastDataOffset) + + val entry2 = stateManager.lastEntry(pid2) + assertTrue(entry2.isDefined) + assertEquals(0, 
entry2.get.lastSeq) + assertEquals(1L, entry2.get.lastDataOffset) + } + + @Test + def testSkipSnapshotIfOffsetUnchanged(): Unit = { + val epoch = 0.toShort + append(stateManager, producerId, epoch, 0, 0L, 0L) + + stateManager.takeSnapshot() + assertEquals(1, logDir.listFiles().length) + assertEquals(Set(1), currentSnapshotOffsets) + + // nothing changed so there should be no new snapshot + stateManager.takeSnapshot() + assertEquals(1, logDir.listFiles().length) + assertEquals(Set(1), currentSnapshotOffsets) + } + + @Test + def testPidExpirationTimeout(): Unit = { + val epoch = 5.toShort + val sequence = 37 + append(stateManager, producerId, epoch, sequence, 1L) + time.sleep(maxPidExpirationMs + 1) + stateManager.removeExpiredProducers(time.milliseconds) + append(stateManager, producerId, epoch, sequence + 1, 2L) + assertEquals(1, stateManager.activeProducers.size) + assertEquals(sequence + 1, stateManager.activeProducers.head._2.lastSeq) + assertEquals(3L, stateManager.mapEndOffset) + } + + @Test + def testFirstUnstableOffset(): Unit = { + val epoch = 5.toShort + val sequence = 0 + + assertEquals(None, stateManager.firstUndecidedOffset) + + append(stateManager, producerId, epoch, sequence, offset = 99, isTransactional = true) + assertEquals(Some(99L), stateManager.firstUndecidedOffset) + assertEquals(Some(99L), stateManager.firstUnstableOffset.map(_.messageOffset)) + + val anotherPid = 2L + append(stateManager, anotherPid, epoch, sequence, offset = 105, isTransactional = true) + assertEquals(Some(99L), stateManager.firstUndecidedOffset) + assertEquals(Some(99L), stateManager.firstUnstableOffset.map(_.messageOffset)) + + appendEndTxnMarker(stateManager, producerId, epoch, ControlRecordType.COMMIT, offset = 109) + assertEquals(Some(105L), stateManager.firstUndecidedOffset) + assertEquals(Some(99L), stateManager.firstUnstableOffset.map(_.messageOffset)) + + stateManager.onHighWatermarkUpdated(100L) + assertEquals(Some(99L), stateManager.firstUnstableOffset.map(_.messageOffset)) + + stateManager.onHighWatermarkUpdated(110L) + assertEquals(Some(105L), stateManager.firstUnstableOffset.map(_.messageOffset)) + + appendEndTxnMarker(stateManager, anotherPid, epoch, ControlRecordType.ABORT, offset = 112) + assertEquals(None, stateManager.firstUndecidedOffset) + assertEquals(Some(105L), stateManager.firstUnstableOffset.map(_.messageOffset)) + + stateManager.onHighWatermarkUpdated(113L) + assertEquals(None, stateManager.firstUnstableOffset.map(_.messageOffset)) + } + + @Test + def testProducersWithOngoingTransactionsDontExpire(): Unit = { + val epoch = 5.toShort + val sequence = 0 + + append(stateManager, producerId, epoch, sequence, offset = 99, isTransactional = true) + assertEquals(Some(99L), stateManager.firstUndecidedOffset) + + time.sleep(maxPidExpirationMs + 1) + stateManager.removeExpiredProducers(time.milliseconds) + + assertTrue(stateManager.lastEntry(producerId).isDefined) + assertEquals(Some(99L), stateManager.firstUndecidedOffset) + + stateManager.removeExpiredProducers(time.milliseconds) + assertTrue(stateManager.lastEntry(producerId).isDefined) + } + + @Test + def testSequenceNotValidatedForGroupMetadataTopic(): Unit = { + val partition = new TopicPartition(Topic.GROUP_METADATA_TOPIC_NAME, 0) + val stateManager = new ProducerStateManager(partition, logDir, maxPidExpirationMs, time) + + val epoch = 0.toShort + append(stateManager, producerId, epoch, RecordBatch.NO_SEQUENCE, offset = 99, + isTransactional = true, origin = AppendOrigin.Coordinator) + append(stateManager, producerId, epoch, 
RecordBatch.NO_SEQUENCE, offset = 100, + isTransactional = true, origin = AppendOrigin.Coordinator) + } + + @Test + def testOldEpochForControlRecord(): Unit = { + val epoch = 5.toShort + val sequence = 0 + + assertEquals(None, stateManager.firstUndecidedOffset) + + append(stateManager, producerId, epoch, sequence, offset = 99, isTransactional = true) + assertThrows(classOf[InvalidProducerEpochException], () => appendEndTxnMarker(stateManager, producerId, 3.toShort, + ControlRecordType.COMMIT, offset=100)) + } + + @Test + def testCoordinatorFencing(): Unit = { + val epoch = 5.toShort + val sequence = 0 + + append(stateManager, producerId, epoch, sequence, offset = 99, isTransactional = true) + appendEndTxnMarker(stateManager, producerId, epoch, ControlRecordType.COMMIT, offset = 100, coordinatorEpoch = 1) + + val lastEntry = stateManager.lastEntry(producerId) + assertEquals(Some(1), lastEntry.map(_.coordinatorEpoch)) + + // writing with the current epoch is allowed + appendEndTxnMarker(stateManager, producerId, epoch, ControlRecordType.COMMIT, offset = 101, coordinatorEpoch = 1) + + // bumping the epoch is allowed + appendEndTxnMarker(stateManager, producerId, epoch, ControlRecordType.COMMIT, offset = 102, coordinatorEpoch = 2) + + // old epochs are not allowed + assertThrows(classOf[TransactionCoordinatorFencedException], () => appendEndTxnMarker(stateManager, producerId, epoch, ControlRecordType.COMMIT, offset = 103, coordinatorEpoch = 1)) + } + + @Test + def testCoordinatorFencedAfterReload(): Unit = { + val producerEpoch = 0.toShort + append(stateManager, producerId, producerEpoch, 0, offset = 99, isTransactional = true) + appendEndTxnMarker(stateManager, producerId, producerEpoch, ControlRecordType.COMMIT, offset = 100, coordinatorEpoch = 1) + stateManager.takeSnapshot() + + val recoveredMapping = new ProducerStateManager(partition, logDir, maxPidExpirationMs, time) + recoveredMapping.truncateAndReload(0L, 2L, 70000) + + // append from old coordinator should be rejected + assertThrows(classOf[TransactionCoordinatorFencedException], () => appendEndTxnMarker(stateManager, producerId, + producerEpoch, ControlRecordType.COMMIT, offset = 100, coordinatorEpoch = 0)) + } + + @Test + def testLoadFromEmptySnapshotFile(): Unit = { + testLoadFromCorruptSnapshot { file => + file.truncate(0L) + } + } + + @Test + def testLoadFromTruncatedSnapshotFile(): Unit = { + testLoadFromCorruptSnapshot { file => + // truncate to some arbitrary point in the middle of the snapshot + assertTrue(file.size > 2) + file.truncate(file.size / 2) + } + } + + @Test + def testLoadFromCorruptSnapshotFile(): Unit = { + testLoadFromCorruptSnapshot { file => + // write some garbage somewhere in the file + assertTrue(file.size > 2) + file.write(ByteBuffer.wrap(Array[Byte](37)), file.size / 2) + } + } + + @Test + def testAppendEmptyControlBatch(): Unit = { + val producerId = 23423L + val baseOffset = 15 + + val batch: RecordBatch = EasyMock.createMock(classOf[RecordBatch]) + EasyMock.expect(batch.isControlBatch).andReturn(true).once + EasyMock.expect(batch.iterator).andReturn(Collections.emptyIterator[Record]).once + EasyMock.replay(batch) + + // Appending the empty control batch should not throw and a new transaction shouldn't be started + append(stateManager, producerId, baseOffset, batch, origin = AppendOrigin.Client) + assertEquals(None, stateManager.lastEntry(producerId).get.currentTxnFirstOffset) + } + + @Test + def testRemoveStraySnapshotsKeepCleanShutdownSnapshot(): Unit = { + // Test that when stray snapshots are 
removed, the largest stray snapshot is kept around. This covers the case where + // the broker shut down cleanly and emitted a snapshot file at an offset larger than the base offset of the active segment. + + // Create 3 snapshot files at different offsets. + UnifiedLog.producerSnapshotFile(logDir, 5).createNewFile() // not stray + UnifiedLog.producerSnapshotFile(logDir, 2).createNewFile() // stray + UnifiedLog.producerSnapshotFile(logDir, 42).createNewFile() // not stray + + // claim that we only have one segment with a base offset of 5 + stateManager.removeStraySnapshots(Seq(5)) + + // The snapshot file at offset 2 should be considered a stray, but the snapshot at 42 should be kept + // around because it is the largest snapshot. + assertEquals(Some(42), stateManager.latestSnapshotOffset) + assertEquals(Some(5), stateManager.oldestSnapshotOffset) + assertEquals(Seq(5, 42), ProducerStateManager.listSnapshotFiles(logDir).map(_.offset).sorted) + } + + @Test + def testRemoveAllStraySnapshots(): Unit = { + // Test that when stray snapshots are removed, we remove only the stray snapshots below the largest segment base offset. + // Snapshots associated with an offset in the list of segment base offsets should remain. + + // Create 3 snapshot files at different offsets. + UnifiedLog.producerSnapshotFile(logDir, 5).createNewFile() // stray + UnifiedLog.producerSnapshotFile(logDir, 2).createNewFile() // stray + UnifiedLog.producerSnapshotFile(logDir, 42).createNewFile() // not stray + + stateManager.removeStraySnapshots(Seq(42)) + assertEquals(Seq(42), ProducerStateManager.listSnapshotFiles(logDir).map(_.offset).sorted) + } + + /** + * Test that removeAndMarkSnapshotForDeletion will rename the SnapshotFile with + * the deletion suffix and remove it from the producer state. + */ + @Test + def testRemoveAndMarkSnapshotForDeletion(): Unit = { + UnifiedLog.producerSnapshotFile(logDir, 5).createNewFile() + val manager = new ProducerStateManager(partition, logDir, time = time) + assertTrue(manager.latestSnapshotOffset.isDefined) + val snapshot = manager.removeAndMarkSnapshotForDeletion(5).get + assertTrue(snapshot.file.toPath.toString.endsWith(UnifiedLog.DeletedFileSuffix)) + assertTrue(manager.latestSnapshotOffset.isEmpty) + } + + /** + * Test that marking a snapshot for deletion when the file has already been deleted + * returns None instead of the SnapshotFile. The snapshot file should be removed from + * the in-memory state of the ProducerStateManager. This scenario can occur during log + * recovery when the intermediate ProducerStateManager instance deletes a file without + * updating the state of the "real" ProducerStateManager instance which is passed to the Log.
+ */ + @Test + def testRemoveAndMarkSnapshotForDeletionAlreadyDeleted(): Unit = { + val file = UnifiedLog.producerSnapshotFile(logDir, 5) + file.createNewFile() + val manager = new ProducerStateManager(partition, logDir, time = time) + assertTrue(manager.latestSnapshotOffset.isDefined) + Files.delete(file.toPath) + assertTrue(manager.removeAndMarkSnapshotForDeletion(5).isEmpty) + assertTrue(manager.latestSnapshotOffset.isEmpty) + } + + private def testLoadFromCorruptSnapshot(makeFileCorrupt: FileChannel => Unit): Unit = { + val epoch = 0.toShort + val producerId = 1L + + append(stateManager, producerId, epoch, seq = 0, offset = 0L) + stateManager.takeSnapshot() + + append(stateManager, producerId, epoch, seq = 1, offset = 1L) + stateManager.takeSnapshot() + + // Truncate the last snapshot + val latestSnapshotOffset = stateManager.latestSnapshotOffset + assertEquals(Some(2L), latestSnapshotOffset) + val snapshotToTruncate = UnifiedLog.producerSnapshotFile(logDir, latestSnapshotOffset.get) + val channel = FileChannel.open(snapshotToTruncate.toPath, StandardOpenOption.WRITE) + try { + makeFileCorrupt(channel) + } finally { + channel.close() + } + + // Ensure that the truncated snapshot is deleted and producer state is loaded from the previous snapshot + val reloadedStateManager = new ProducerStateManager(partition, logDir, maxPidExpirationMs, time) + reloadedStateManager.truncateAndReload(0L, 20L, time.milliseconds()) + assertFalse(snapshotToTruncate.exists()) + + val loadedProducerState = reloadedStateManager.activeProducers(producerId) + assertEquals(0L, loadedProducerState.lastDataOffset) + } + + private def appendEndTxnMarker(mapping: ProducerStateManager, + producerId: Long, + producerEpoch: Short, + controlType: ControlRecordType, + offset: Long, + coordinatorEpoch: Int = 0, + timestamp: Long = time.milliseconds()): Option[CompletedTxn] = { + val producerAppendInfo = stateManager.prepareUpdate(producerId, origin = AppendOrigin.Coordinator) + val endTxnMarker = new EndTransactionMarker(controlType, coordinatorEpoch) + val completedTxnOpt = producerAppendInfo.appendEndTxnMarker(endTxnMarker, producerEpoch, offset, timestamp) + mapping.update(producerAppendInfo) + completedTxnOpt.foreach(mapping.completeTxn) + mapping.updateMapEndOffset(offset + 1) + completedTxnOpt + } + + private def append(stateManager: ProducerStateManager, + producerId: Long, + producerEpoch: Short, + seq: Int, + offset: Long, + timestamp: Long = time.milliseconds(), + isTransactional: Boolean = false, + origin : AppendOrigin = AppendOrigin.Client): Unit = { + val producerAppendInfo = stateManager.prepareUpdate(producerId, origin) + producerAppendInfo.appendDataBatch(producerEpoch, seq, seq, timestamp, + LogOffsetMetadata(offset), offset, isTransactional) + stateManager.update(producerAppendInfo) + stateManager.updateMapEndOffset(offset + 1) + } + + private def append(stateManager: ProducerStateManager, + producerId: Long, + offset: Long, + batch: RecordBatch, + origin: AppendOrigin): Unit = { + val producerAppendInfo = stateManager.prepareUpdate(producerId, origin) + producerAppendInfo.append(batch, firstOffsetMetadataOpt = None) + stateManager.update(producerAppendInfo) + stateManager.updateMapEndOffset(offset + 1) + } + + private def currentSnapshotOffsets: Set[Long] = + logDir.listFiles.map(UnifiedLog.offsetFromFile).toSet + +} diff --git a/core/src/test/scala/unit/kafka/log/TimeIndexTest.scala b/core/src/test/scala/unit/kafka/log/TimeIndexTest.scala new file mode 100644 index 0000000..3318a52 --- /dev/null +++ 
b/core/src/test/scala/unit/kafka/log/TimeIndexTest.scala @@ -0,0 +1,147 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.log + +import java.io.File + +import kafka.utils.TestUtils +import org.apache.kafka.common.errors.InvalidOffsetException +import org.junit.jupiter.api.{AfterEach, BeforeEach, Test} +import org.junit.jupiter.api.Assertions.{assertEquals, assertThrows} + +/** + * Unit test for time index. + */ +class TimeIndexTest { + var idx: TimeIndex = null + val maxEntries = 30 + val baseOffset = 45L + + @BeforeEach + def setup(): Unit = { + this.idx = new TimeIndex(nonExistantTempFile(), baseOffset = baseOffset, maxIndexSize = maxEntries * 12) + } + + @AfterEach + def teardown(): Unit = { + if(this.idx != null) + this.idx.file.delete() + } + + @Test + def testLookUp(): Unit = { + // Empty time index + assertEquals(TimestampOffset(-1L, baseOffset), idx.lookup(100L)) + + // Add several time index entries. + appendEntries(maxEntries - 1) + + // look for timestamp smaller than the earliest entry + assertEquals(TimestampOffset(-1L, baseOffset), idx.lookup(9)) + // look for timestamp in the middle of two entries. 
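+ // lookup returns the last entry whose timestamp is less than or equal to the target, so 25 resolves to the entry appended at timestamp 20.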
+ assertEquals(TimestampOffset(20L, 65L), idx.lookup(25)) + // look for timestamp same as the one in the entry + assertEquals(TimestampOffset(30L, 75L), idx.lookup(30)) + } + + @Test + def testEntry(): Unit = { + appendEntries(maxEntries - 1) + assertEquals(TimestampOffset(10L, 55L), idx.entry(0)) + assertEquals(TimestampOffset(20L, 65L), idx.entry(1)) + assertEquals(TimestampOffset(30L, 75L), idx.entry(2)) + assertEquals(TimestampOffset(40L, 85L), idx.entry(3)) + } + + @Test + def testEntryOverflow(): Unit = { + assertThrows(classOf[IllegalArgumentException], () => idx.entry(0)) + } + + @Test + def testTruncate(): Unit = { + appendEntries(maxEntries - 1) + idx.truncate() + assertEquals(0, idx.entries) + + appendEntries(maxEntries - 1) + idx.truncateTo(10 + baseOffset) + assertEquals(0, idx.entries) + } + + @Test + def testAppend(): Unit = { + appendEntries(maxEntries - 1) + assertThrows(classOf[IllegalArgumentException], () => idx.maybeAppend(10000L, 1000L)) + assertThrows(classOf[InvalidOffsetException], () => idx.maybeAppend(10000L, (maxEntries - 2) * 10, true)) + idx.maybeAppend(10000L, 1000L, true) + } + + private def appendEntries(numEntries: Int): Unit = { + for (i <- 1 to numEntries) + idx.maybeAppend(i * 10, i * 10 + baseOffset) + } + + def nonExistantTempFile(): File = { + val file = TestUtils.tempFile() + file.delete() + file + } + + @Test + def testSanityCheck(): Unit = { + idx.sanityCheck() + appendEntries(5) + val firstEntry = idx.entry(0) + idx.sanityCheck() + idx.close() + + var shouldCorruptOffset = false + var shouldCorruptTimestamp = false + var shouldCorruptLength = false + idx = new TimeIndex(idx.file, baseOffset = baseOffset, maxIndexSize = maxEntries * 12) { + override def lastEntry = { + val superLastEntry = super.lastEntry + val offset = if (shouldCorruptOffset) baseOffset - 1 else superLastEntry.offset + val timestamp = if (shouldCorruptTimestamp) firstEntry.timestamp - 1 else superLastEntry.timestamp + new TimestampOffset(timestamp, offset) + } + override def length = { + val superLength = super.length + if (shouldCorruptLength) superLength - 1 else superLength + } + } + + shouldCorruptOffset = true + assertThrows(classOf[CorruptIndexException], () => idx.sanityCheck()) + shouldCorruptOffset = false + + shouldCorruptTimestamp = true + assertThrows(classOf[CorruptIndexException], () => idx.sanityCheck()) + shouldCorruptTimestamp = false + + shouldCorruptLength = true + assertThrows(classOf[CorruptIndexException], () => idx.sanityCheck()) + shouldCorruptLength = false + + idx.sanityCheck() + idx.close() + } + +} + diff --git a/core/src/test/scala/unit/kafka/log/TransactionIndexTest.scala b/core/src/test/scala/unit/kafka/log/TransactionIndexTest.scala new file mode 100644 index 0000000..790bcd8 --- /dev/null +++ b/core/src/test/scala/unit/kafka/log/TransactionIndexTest.scala @@ -0,0 +1,171 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package kafka.log + +import kafka.utils.TestUtils +import org.apache.kafka.common.message.FetchResponseData +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.{AfterEach, BeforeEach, Test} + +import java.io.File + +class TransactionIndexTest { + var file: File = _ + var index: TransactionIndex = _ + val offset = 0L + + @BeforeEach + def setup(): Unit = { + file = TestUtils.tempFile() + index = new TransactionIndex(offset, file) + } + + @AfterEach + def teardown(): Unit = { + index.close() + } + + @Test + def testPositionSetCorrectlyWhenOpened(): Unit = { + val abortedTxns = List( + new AbortedTxn(producerId = 0L, firstOffset = 0, lastOffset = 10, lastStableOffset = 11), + new AbortedTxn(producerId = 1L, firstOffset = 5, lastOffset = 15, lastStableOffset = 13), + new AbortedTxn(producerId = 2L, firstOffset = 18, lastOffset = 35, lastStableOffset = 25), + new AbortedTxn(producerId = 3L, firstOffset = 32, lastOffset = 50, lastStableOffset = 40)) + abortedTxns.foreach(index.append) + index.close() + + val reopenedIndex = new TransactionIndex(0L, file) + val anotherAbortedTxn = new AbortedTxn(producerId = 3L, firstOffset = 50, lastOffset = 60, lastStableOffset = 55) + reopenedIndex.append(anotherAbortedTxn) + assertEquals(abortedTxns ++ List(anotherAbortedTxn), reopenedIndex.allAbortedTxns) + } + + @Test + def testSanityCheck(): Unit = { + val abortedTxns = List( + new AbortedTxn(producerId = 0L, firstOffset = 0, lastOffset = 10, lastStableOffset = 11), + new AbortedTxn(producerId = 1L, firstOffset = 5, lastOffset = 15, lastStableOffset = 13), + new AbortedTxn(producerId = 2L, firstOffset = 18, lastOffset = 35, lastStableOffset = 25), + new AbortedTxn(producerId = 3L, firstOffset = 32, lastOffset = 50, lastStableOffset = 40)) + abortedTxns.foreach(index.append) + index.close() + + // open the index with a different starting offset to fake invalid data + val reopenedIndex = new TransactionIndex(100L, file) + assertThrows(classOf[CorruptIndexException], () => reopenedIndex.sanityCheck()) + } + + @Test + def testLastOffsetMustIncrease(): Unit = { + index.append(new AbortedTxn(producerId = 1L, firstOffset = 5, lastOffset = 15, lastStableOffset = 13)) + assertThrows(classOf[IllegalArgumentException], () => index.append(new AbortedTxn(producerId = 0L, firstOffset = 0, + lastOffset = 15, lastStableOffset = 11))) + } + + @Test + def testLastOffsetCannotDecrease(): Unit = { + index.append(new AbortedTxn(producerId = 1L, firstOffset = 5, lastOffset = 15, lastStableOffset = 13)) + assertThrows(classOf[IllegalArgumentException], () => index.append(new AbortedTxn(producerId = 0L, firstOffset = 0, + lastOffset = 10, lastStableOffset = 11))) + } + + @Test + def testCollectAbortedTransactions(): Unit = { + val abortedTransactions = List( + new AbortedTxn(producerId = 0L, firstOffset = 0, lastOffset = 10, lastStableOffset = 11), + new AbortedTxn(producerId = 1L, firstOffset = 5, lastOffset = 15, lastStableOffset = 13), + new AbortedTxn(producerId = 2L, firstOffset = 18, lastOffset = 35, lastStableOffset = 25), + new AbortedTxn(producerId = 3L, firstOffset = 32, 
lastOffset = 50, lastStableOffset = 40)) + + abortedTransactions.foreach(index.append) + + var result = index.collectAbortedTxns(0L, 100L) + assertEquals(abortedTransactions, result.abortedTransactions) + assertFalse(result.isComplete) + + result = index.collectAbortedTxns(0L, 32) + assertEquals(abortedTransactions.take(3), result.abortedTransactions) + assertTrue(result.isComplete) + + result = index.collectAbortedTxns(0L, 35) + assertEquals(abortedTransactions, result.abortedTransactions) + assertTrue(result.isComplete) + + result = index.collectAbortedTxns(10, 35) + assertEquals(abortedTransactions, result.abortedTransactions) + assertTrue(result.isComplete) + + result = index.collectAbortedTxns(11, 35) + assertEquals(abortedTransactions.slice(1, 4), result.abortedTransactions) + assertTrue(result.isComplete) + + result = index.collectAbortedTxns(20, 41) + assertEquals(abortedTransactions.slice(2, 4), result.abortedTransactions) + assertFalse(result.isComplete) + } + + @Test + def testTruncate(): Unit = { + val abortedTransactions = List( + new AbortedTxn(producerId = 0L, firstOffset = 0, lastOffset = 10, lastStableOffset = 2), + new AbortedTxn(producerId = 1L, firstOffset = 5, lastOffset = 15, lastStableOffset = 16), + new AbortedTxn(producerId = 2L, firstOffset = 18, lastOffset = 35, lastStableOffset = 25), + new AbortedTxn(producerId = 3L, firstOffset = 32, lastOffset = 50, lastStableOffset = 40)) + + abortedTransactions.foreach(index.append) + + index.truncateTo(51) + assertEquals(abortedTransactions, index.collectAbortedTxns(0L, 100L).abortedTransactions) + + index.truncateTo(50) + assertEquals(abortedTransactions.take(3), index.collectAbortedTxns(0L, 100L).abortedTransactions) + + index.reset() + assertEquals(List.empty[FetchResponseData.AbortedTransaction], index.collectAbortedTxns(0L, 100L).abortedTransactions) + } + + @Test + def testAbortedTxnSerde(): Unit = { + val pid = 983493L + val firstOffset = 137L + val lastOffset = 299L + val lastStableOffset = 200L + + val abortedTxn = new AbortedTxn(pid, firstOffset, lastOffset, lastStableOffset) + assertEquals(AbortedTxn.CurrentVersion, abortedTxn.version) + assertEquals(pid, abortedTxn.producerId) + assertEquals(firstOffset, abortedTxn.firstOffset) + assertEquals(lastOffset, abortedTxn.lastOffset) + assertEquals(lastStableOffset, abortedTxn.lastStableOffset) + } + + @Test + def testRenameIndex(): Unit = { + val renamed = TestUtils.tempFile() + index.append(new AbortedTxn(producerId = 0L, firstOffset = 0, lastOffset = 10, lastStableOffset = 2)) + + index.renameTo(renamed) + index.append(new AbortedTxn(producerId = 1L, firstOffset = 5, lastOffset = 15, lastStableOffset = 16)) + + val abortedTxns = index.collectAbortedTxns(0L, 100L).abortedTransactions + assertEquals(2, abortedTxns.size) + assertEquals(0, abortedTxns(0).firstOffset) + assertEquals(5, abortedTxns(1).firstOffset) + } + +} diff --git a/core/src/test/scala/unit/kafka/log/UnifiedLogTest.scala b/core/src/test/scala/unit/kafka/log/UnifiedLogTest.scala new file mode 100755 index 0000000..be63413 --- /dev/null +++ b/core/src/test/scala/unit/kafka/log/UnifiedLogTest.scala @@ -0,0 +1,3396 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.log + +import java.io._ +import java.nio.ByteBuffer +import java.nio.file.Files +import java.util.concurrent.{Callable, Executors} +import java.util.{Optional, Properties} +import kafka.common.{OffsetsOutOfOrderException, RecordValidationException, UnexpectedAppendOffsetException} +import kafka.metrics.KafkaYammerMetrics +import kafka.server.checkpoints.LeaderEpochCheckpointFile +import kafka.server.epoch.{EpochEntry, LeaderEpochFileCache} +import kafka.server.{BrokerTopicStats, FetchHighWatermark, FetchIsolation, FetchLogEnd, FetchTxnCommitted, KafkaConfig, LogOffsetMetadata, PartitionMetadataFile} +import kafka.utils._ +import org.apache.kafka.common.{InvalidRecordException, TopicPartition, Uuid} +import org.apache.kafka.common.errors._ +import org.apache.kafka.common.message.FetchResponseData +import org.apache.kafka.common.record.FileRecords.TimestampAndOffset +import org.apache.kafka.common.record.MemoryRecords.RecordFilter +import org.apache.kafka.common.record._ +import org.apache.kafka.common.requests.{ListOffsetsRequest, ListOffsetsResponse} +import org.apache.kafka.common.utils.{BufferSupplier, Time, Utils} +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.{AfterEach, BeforeEach, Test} + +import scala.annotation.nowarn +import scala.collection.Map +import scala.jdk.CollectionConverters._ +import scala.collection.mutable.ListBuffer + +class UnifiedLogTest { + var config: KafkaConfig = null + val brokerTopicStats = new BrokerTopicStats + val tmpDir = TestUtils.tempDir() + val logDir = TestUtils.randomPartitionLogDir(tmpDir) + val mockTime = new MockTime() + def metricsKeySet = KafkaYammerMetrics.defaultRegistry.allMetrics.keySet.asScala + + @BeforeEach + def setUp(): Unit = { + val props = TestUtils.createBrokerConfig(0, "127.0.0.1:1", port = -1) + config = KafkaConfig.fromProps(props) + } + + @AfterEach + def tearDown(): Unit = { + brokerTopicStats.close() + Utils.delete(tmpDir) + } + + def createEmptyLogs(dir: File, offsets: Int*): Unit = { + for(offset <- offsets) { + UnifiedLog.logFile(dir, offset).createNewFile() + UnifiedLog.offsetIndexFile(dir, offset).createNewFile() + } + } + + @Test + def testHighWatermarkMetadataUpdatedAfterSegmentRoll(): Unit = { + val logConfig = LogTestUtils.createLogConfig(segmentBytes = 1024 * 1024) + val log = createLog(logDir, logConfig) + + def assertFetchSizeAndOffsets(fetchOffset: Long, + expectedSize: Int, + expectedOffsets: Seq[Long]): Unit = { + val readInfo = log.read( + startOffset = fetchOffset, + maxLength = 2048, + isolation = FetchHighWatermark, + minOneMessage = false) + assertEquals(expectedSize, readInfo.records.sizeInBytes) + assertEquals(expectedOffsets, readInfo.records.records.asScala.map(_.offset)) + } + + val records = TestUtils.records(List( + new SimpleRecord(mockTime.milliseconds, "a".getBytes, "value".getBytes), + new SimpleRecord(mockTime.milliseconds, "b".getBytes, "value".getBytes), + new SimpleRecord(mockTime.milliseconds, "c".getBytes, "value".getBytes) + )) + + log.appendAsLeader(records, leaderEpoch = 0) + assertFetchSizeAndOffsets(fetchOffset = 0L, 0, Seq()) 
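+ // The high watermark is still 0, so a fetch bounded by it returns no records until the watermark is advanced.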
+ + log.maybeIncrementHighWatermark(log.logEndOffsetMetadata) + assertFetchSizeAndOffsets(fetchOffset = 0L, records.sizeInBytes, Seq(0, 1, 2)) + + log.roll() + assertFetchSizeAndOffsets(fetchOffset = 0L, records.sizeInBytes, Seq(0, 1, 2)) + + log.appendAsLeader(records, leaderEpoch = 0) + assertFetchSizeAndOffsets(fetchOffset = 3L, 0, Seq()) + } + + @Test + def testAppendAsLeaderWithRaftLeader(): Unit = { + val logConfig = LogTestUtils.createLogConfig(segmentBytes = 1024 * 1024) + val log = createLog(logDir, logConfig) + val leaderEpoch = 0 + + def records(offset: Long): MemoryRecords = TestUtils.records(List( + new SimpleRecord(mockTime.milliseconds, "a".getBytes, "value".getBytes), + new SimpleRecord(mockTime.milliseconds, "b".getBytes, "value".getBytes), + new SimpleRecord(mockTime.milliseconds, "c".getBytes, "value".getBytes) + ), baseOffset = offset, partitionLeaderEpoch = leaderEpoch) + + log.appendAsLeader(records(0), leaderEpoch, AppendOrigin.RaftLeader) + assertEquals(0, log.logStartOffset) + assertEquals(3L, log.logEndOffset) + + // Since the raft leader is responsible for assigning offsets and the LogValidator is bypassed for performance reasons, + // the first offset of the MemoryRecords to be appended must equal the next offset in the log + assertThrows(classOf[UnexpectedAppendOffsetException], () => (log.appendAsLeader(records(1), leaderEpoch, AppendOrigin.RaftLeader))) + + // When the first offset of the MemoryRecords to be appended equals the next offset in the log, the append succeeds + log.appendAsLeader(records(3), leaderEpoch, AppendOrigin.RaftLeader) + assertEquals(6, log.logEndOffset) + } + + @Test + def testAppendInfoFirstOffset(): Unit = { + val logConfig = LogTestUtils.createLogConfig(segmentBytes = 1024 * 1024) + val log = createLog(logDir, logConfig) + + val simpleRecords = List( + new SimpleRecord(mockTime.milliseconds, "a".getBytes, "value".getBytes), + new SimpleRecord(mockTime.milliseconds, "b".getBytes, "value".getBytes), + new SimpleRecord(mockTime.milliseconds, "c".getBytes, "value".getBytes) + ) + + val records = TestUtils.records(simpleRecords) + + val firstAppendInfo = log.appendAsLeader(records, leaderEpoch = 0) + assertEquals(LogOffsetMetadata(0, 0, 0), firstAppendInfo.firstOffset.get) + + val secondAppendInfo = log.appendAsLeader( + TestUtils.records(simpleRecords), + leaderEpoch = 0 + ) + assertEquals(LogOffsetMetadata(simpleRecords.size, 0, records.sizeInBytes), secondAppendInfo.firstOffset.get) + + log.roll() + val afterRollAppendInfo = log.appendAsLeader(TestUtils.records(simpleRecords), leaderEpoch = 0) + assertEquals(LogOffsetMetadata(simpleRecords.size * 2, simpleRecords.size * 2, 0), afterRollAppendInfo.firstOffset.get) + } + + @Test + def testTruncateBelowFirstUnstableOffset(): Unit = { + testTruncateBelowFirstUnstableOffset(_.truncateTo) + } + + @Test + def testTruncateFullyAndStartBelowFirstUnstableOffset(): Unit = { + testTruncateBelowFirstUnstableOffset(_.truncateFullyAndStartAt) + } + + private def testTruncateBelowFirstUnstableOffset( + truncateFunc: UnifiedLog => (Long => Unit) + ): Unit = { + // Verify that truncation below the first unstable offset correctly + // resets the producer state. Specifically we are testing the case when + // the segment position of the first unstable offset is unknown.
+ + val logConfig = LogTestUtils.createLogConfig(segmentBytes = 1024 * 1024) + val log = createLog(logDir, logConfig) + + val producerId = 17L + val producerEpoch: Short = 10 + val sequence = 0 + + log.appendAsLeader(TestUtils.records(List( + new SimpleRecord("0".getBytes), + new SimpleRecord("1".getBytes), + new SimpleRecord("2".getBytes) + )), leaderEpoch = 0) + + log.appendAsLeader(MemoryRecords.withTransactionalRecords( + CompressionType.NONE, + producerId, + producerEpoch, + sequence, + new SimpleRecord("3".getBytes), + new SimpleRecord("4".getBytes) + ), leaderEpoch = 0) + + assertEquals(Some(3L), log.firstUnstableOffset) + + // We close and reopen the log to ensure that the first unstable offset segment + // position will be undefined when we truncate the log. + log.close() + + val reopened = createLog(logDir, logConfig) + assertEquals(Some(LogOffsetMetadata(3L)), reopened.producerStateManager.firstUnstableOffset) + + truncateFunc(reopened)(0L) + assertEquals(None, reopened.firstUnstableOffset) + assertEquals(Map.empty, reopened.producerStateManager.activeProducers) + } + + @Test + def testHighWatermarkMaintenance(): Unit = { + val logConfig = LogTestUtils.createLogConfig(segmentBytes = 1024 * 1024) + val log = createLog(logDir, logConfig) + val leaderEpoch = 0 + + def records(offset: Long): MemoryRecords = TestUtils.records(List( + new SimpleRecord(mockTime.milliseconds, "a".getBytes, "value".getBytes), + new SimpleRecord(mockTime.milliseconds, "b".getBytes, "value".getBytes), + new SimpleRecord(mockTime.milliseconds, "c".getBytes, "value".getBytes) + ), baseOffset = offset, partitionLeaderEpoch= leaderEpoch) + + def assertHighWatermark(offset: Long): Unit = { + assertEquals(offset, log.highWatermark) + assertValidLogOffsetMetadata(log, log.fetchOffsetSnapshot.highWatermark) + } + + // High watermark initialized to 0 + assertHighWatermark(0L) + + // High watermark not changed by append + log.appendAsLeader(records(0), leaderEpoch) + assertHighWatermark(0L) + + // Update high watermark as leader + log.maybeIncrementHighWatermark(LogOffsetMetadata(1L)) + assertHighWatermark(1L) + + // Cannot update past the log end offset + log.updateHighWatermark(5L) + assertHighWatermark(3L) + + // Update high watermark as follower + log.appendAsFollower(records(3L)) + log.updateHighWatermark(6L) + assertHighWatermark(6L) + + // High watermark should be adjusted by truncation + log.truncateTo(3L) + assertHighWatermark(3L) + + log.appendAsLeader(records(0L), leaderEpoch = 0) + assertHighWatermark(3L) + assertEquals(6L, log.logEndOffset) + assertEquals(0L, log.logStartOffset) + + // Full truncation should also reset high watermark + log.truncateFullyAndStartAt(4L) + assertEquals(4L, log.logEndOffset) + assertEquals(4L, log.logStartOffset) + assertHighWatermark(4L) + } + + private def assertNonEmptyFetch(log: UnifiedLog, offset: Long, isolation: FetchIsolation): Unit = { + val readInfo = log.read(startOffset = offset, + maxLength = Int.MaxValue, + isolation = isolation, + minOneMessage = true) + + assertFalse(readInfo.firstEntryIncomplete) + assertTrue(readInfo.records.sizeInBytes > 0) + + val upperBoundOffset = isolation match { + case FetchLogEnd => log.logEndOffset + case FetchHighWatermark => log.highWatermark + case FetchTxnCommitted => log.lastStableOffset + } + + for (record <- readInfo.records.records.asScala) + assertTrue(record.offset < upperBoundOffset) + + assertEquals(offset, readInfo.fetchOffsetMetadata.messageOffset) + assertValidLogOffsetMetadata(log, readInfo.fetchOffsetMetadata) + } 
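+ + // Note: the exclusive upper bound on readable offsets depends on the fetch isolation level: + // FetchLogEnd -> log end offset, FetchHighWatermark -> high watermark, FetchTxnCommitted -> last stable offset.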
+ + private def assertEmptyFetch(log: UnifiedLog, offset: Long, isolation: FetchIsolation): Unit = { + val readInfo = log.read(startOffset = offset, + maxLength = Int.MaxValue, + isolation = isolation, + minOneMessage = true) + assertFalse(readInfo.firstEntryIncomplete) + assertEquals(0, readInfo.records.sizeInBytes) + assertEquals(offset, readInfo.fetchOffsetMetadata.messageOffset) + assertValidLogOffsetMetadata(log, readInfo.fetchOffsetMetadata) + } + + @Test + def testFetchUpToLogEndOffset(): Unit = { + val logConfig = LogTestUtils.createLogConfig(segmentBytes = 1024 * 1024) + val log = createLog(logDir, logConfig) + + log.appendAsLeader(TestUtils.records(List( + new SimpleRecord("0".getBytes), + new SimpleRecord("1".getBytes), + new SimpleRecord("2".getBytes) + )), leaderEpoch = 0) + log.appendAsLeader(TestUtils.records(List( + new SimpleRecord("3".getBytes), + new SimpleRecord("4".getBytes) + )), leaderEpoch = 0) + + (log.logStartOffset until log.logEndOffset).foreach { offset => + assertNonEmptyFetch(log, offset, FetchLogEnd) + } + } + + @Test + def testFetchUpToHighWatermark(): Unit = { + val logConfig = LogTestUtils.createLogConfig(segmentBytes = 1024 * 1024) + val log = createLog(logDir, logConfig) + + log.appendAsLeader(TestUtils.records(List( + new SimpleRecord("0".getBytes), + new SimpleRecord("1".getBytes), + new SimpleRecord("2".getBytes) + )), leaderEpoch = 0) + log.appendAsLeader(TestUtils.records(List( + new SimpleRecord("3".getBytes), + new SimpleRecord("4".getBytes) + )), leaderEpoch = 0) + + def assertHighWatermarkBoundedFetches(): Unit = { + (log.logStartOffset until log.highWatermark).foreach { offset => + assertNonEmptyFetch(log, offset, FetchHighWatermark) + } + + (log.highWatermark to log.logEndOffset).foreach { offset => + assertEmptyFetch(log, offset, FetchHighWatermark) + } + } + + assertHighWatermarkBoundedFetches() + + log.updateHighWatermark(3L) + assertHighWatermarkBoundedFetches() + + log.updateHighWatermark(5L) + assertHighWatermarkBoundedFetches() + } + + @Test + def testActiveProducers(): Unit = { + val logConfig = LogTestUtils.createLogConfig(segmentBytes = 1024 * 1024) + val log = createLog(logDir, logConfig) + + def assertProducerState( + producerId: Long, + producerEpoch: Short, + lastSequence: Int, + currentTxnStartOffset: Option[Long], + coordinatorEpoch: Option[Int] + ): Unit = { + val producerStateOpt = log.activeProducers.find(_.producerId == producerId) + assertTrue(producerStateOpt.isDefined) + + val producerState = producerStateOpt.get + assertEquals(producerEpoch, producerState.producerEpoch) + assertEquals(lastSequence, producerState.lastSequence) + assertEquals(currentTxnStartOffset.getOrElse(-1L), producerState.currentTxnStartOffset) + assertEquals(coordinatorEpoch.getOrElse(-1), producerState.coordinatorEpoch) + } + + // Test transactional producer state (open transaction) + val producer1Epoch = 5.toShort + val producerId1 = 1L + LogTestUtils.appendTransactionalAsLeader(log, producerId1, producer1Epoch, mockTime)(5) + assertProducerState( + producerId1, + producer1Epoch, + lastSequence = 4, + currentTxnStartOffset = Some(0L), + coordinatorEpoch = None + ) + + // Test transactional producer state (closed transaction) + val coordinatorEpoch = 15 + LogTestUtils.appendEndTxnMarkerAsLeader(log, producerId1, producer1Epoch, ControlRecordType.COMMIT, mockTime.milliseconds(), coordinatorEpoch) + assertProducerState( + producerId1, + producer1Epoch, + lastSequence = 4, + currentTxnStartOffset = None, + coordinatorEpoch = Some(coordinatorEpoch) + 
) + + // Test idempotent producer state + val producer2Epoch = 5.toShort + val producerId2 = 2L + LogTestUtils.appendIdempotentAsLeader(log, producerId2, producer2Epoch, mockTime)(3) + assertProducerState( + producerId2, + producer2Epoch, + lastSequence = 2, + currentTxnStartOffset = None, + coordinatorEpoch = None + ) + } + + @Test + def testFetchUpToLastStableOffset(): Unit = { + val logConfig = LogTestUtils.createLogConfig(segmentBytes = 1024 * 1024) + val log = createLog(logDir, logConfig) + val epoch = 0.toShort + + val producerId1 = 1L + val producerId2 = 2L + + val appendProducer1 = LogTestUtils.appendTransactionalAsLeader(log, producerId1, epoch, mockTime) + val appendProducer2 = LogTestUtils.appendTransactionalAsLeader(log, producerId2, epoch, mockTime) + + appendProducer1(5) + LogTestUtils.appendNonTransactionalAsLeader(log, 3) + appendProducer2(2) + appendProducer1(4) + LogTestUtils.appendNonTransactionalAsLeader(log, 2) + appendProducer1(10) + + def assertLsoBoundedFetches(): Unit = { + (log.logStartOffset until log.lastStableOffset).foreach { offset => + assertNonEmptyFetch(log, offset, FetchTxnCommitted) + } + + (log.lastStableOffset to log.logEndOffset).foreach { offset => + assertEmptyFetch(log, offset, FetchTxnCommitted) + } + } + + assertLsoBoundedFetches() + + log.updateHighWatermark(log.logEndOffset) + assertLsoBoundedFetches() + + LogTestUtils.appendEndTxnMarkerAsLeader(log, producerId1, epoch, ControlRecordType.COMMIT, mockTime.milliseconds()) + assertEquals(0L, log.lastStableOffset) + + log.updateHighWatermark(log.logEndOffset) + assertEquals(8L, log.lastStableOffset) + assertLsoBoundedFetches() + + LogTestUtils.appendEndTxnMarkerAsLeader(log, producerId2, epoch, ControlRecordType.ABORT, mockTime.milliseconds()) + assertEquals(8L, log.lastStableOffset) + + log.updateHighWatermark(log.logEndOffset) + assertEquals(log.logEndOffset, log.lastStableOffset) + assertLsoBoundedFetches() + } + + @Test + def testOffsetFromProducerSnapshotFile(): Unit = { + val offset = 23423423L + val snapshotFile = UnifiedLog.producerSnapshotFile(tmpDir, offset) + assertEquals(offset, UnifiedLog.offsetFromFile(snapshotFile)) + } + + /** + * Tests for time based log roll. This test appends messages then changes the time + * using the mock clock to force the log to roll and checks the number of segments. + */ + @Test + def testTimeBasedLogRollDuringAppend(): Unit = { + def createRecords = TestUtils.singletonRecords("test".getBytes) + val logConfig = LogTestUtils.createLogConfig(segmentMs = 1 * 60 * 60L) + + // create a log + val log = createLog(logDir, logConfig, maxProducerIdExpirationMs = 24 * 60) + assertEquals(1, log.numberOfSegments, "Log begins with a single empty segment.") + // Test the segment rolling behavior when messages do not have a timestamp. + mockTime.sleep(log.config.segmentMs + 1) + log.appendAsLeader(createRecords, leaderEpoch = 0) + assertEquals(1, log.numberOfSegments, "Log doesn't roll if doing so creates an empty segment.") + + log.appendAsLeader(createRecords, leaderEpoch = 0) + assertEquals(2, log.numberOfSegments, "Log rolls on this append since time has expired.") + + for (numSegments <- 3 until 5) { + mockTime.sleep(log.config.segmentMs + 1) + log.appendAsLeader(createRecords, leaderEpoch = 0) + assertEquals(numSegments, log.numberOfSegments, "Changing time beyond rollMs and appending should create a new segment.") + } + + // Append a message with timestamp to a segment whose first message do not have a timestamp. 
+ val timestamp = mockTime.milliseconds + log.config.segmentMs + 1 + def createRecordsWithTimestamp = TestUtils.singletonRecords(value = "test".getBytes, timestamp = timestamp) + log.appendAsLeader(createRecordsWithTimestamp, leaderEpoch = 0) + assertEquals(4, log.numberOfSegments, "Segment should not have been rolled out because the log rolling should be based on wall clock.") + + // Test the segment rolling behavior when messages have timestamps. + mockTime.sleep(log.config.segmentMs + 1) + log.appendAsLeader(createRecordsWithTimestamp, leaderEpoch = 0) + assertEquals(5, log.numberOfSegments, "A new segment should have been rolled out") + + // move the wall clock beyond log rolling time + mockTime.sleep(log.config.segmentMs + 1) + log.appendAsLeader(createRecordsWithTimestamp, leaderEpoch = 0) + assertEquals(5, log.numberOfSegments, "Log should not roll because the roll should depend on timestamp of the first message.") + + val recordWithExpiredTimestamp = TestUtils.singletonRecords(value = "test".getBytes, timestamp = mockTime.milliseconds) + log.appendAsLeader(recordWithExpiredTimestamp, leaderEpoch = 0) + assertEquals(6, log.numberOfSegments, "Log should roll because the timestamp in the message should make the log segment expire.") + + val numSegments = log.numberOfSegments + mockTime.sleep(log.config.segmentMs + 1) + log.appendAsLeader(MemoryRecords.withRecords(CompressionType.NONE), leaderEpoch = 0) + assertEquals(numSegments, log.numberOfSegments, "Appending an empty message set should not roll log even if sufficient time has passed.") + } + + @Test + def testRollSegmentThatAlreadyExists(): Unit = { + val logConfig = LogTestUtils.createLogConfig(segmentMs = 1 * 60 * 60L) + + // create a log + val log = createLog(logDir, logConfig) + assertEquals(1, log.numberOfSegments, "Log begins with a single empty segment.") + + // roll active segment with the same base offset of size zero should recreate the segment + log.roll(Some(0L)) + assertEquals(1, log.numberOfSegments, "Expect 1 segment after roll() empty segment with base offset.") + + // should be able to append records to active segment + val records = TestUtils.records( + List(new SimpleRecord(mockTime.milliseconds, "k1".getBytes, "v1".getBytes)), + baseOffset = 0L, partitionLeaderEpoch = 0) + log.appendAsFollower(records) + assertEquals(1, log.numberOfSegments, "Expect one segment.") + assertEquals(0L, log.activeSegment.baseOffset) + + // make sure we can append more records + val records2 = TestUtils.records( + List(new SimpleRecord(mockTime.milliseconds + 10, "k2".getBytes, "v2".getBytes)), + baseOffset = 1L, partitionLeaderEpoch = 0) + log.appendAsFollower(records2) + + assertEquals(2, log.logEndOffset, "Expect two records in the log") + assertEquals(0, LogTestUtils.readLog(log, 0, 1).records.batches.iterator.next().lastOffset) + assertEquals(1, LogTestUtils.readLog(log, 1, 1).records.batches.iterator.next().lastOffset) + + // roll so that active segment is empty + log.roll() + assertEquals(2L, log.activeSegment.baseOffset, "Expect base offset of active segment to be LEO") + assertEquals(2, log.numberOfSegments, "Expect two segments.") + + // manually resize offset index to force roll of an empty active segment on next append + log.activeSegment.offsetIndex.resize(0) + val records3 = TestUtils.records( + List(new SimpleRecord(mockTime.milliseconds + 12, "k3".getBytes, "v3".getBytes)), + baseOffset = 2L, partitionLeaderEpoch = 0) + log.appendAsFollower(records3) + assertTrue(log.activeSegment.offsetIndex.maxEntries > 1) + 
assertEquals(2, LogTestUtils.readLog(log, 2, 1).records.batches.iterator.next().lastOffset) + assertEquals(2, log.numberOfSegments, "Expect two segments.") + } + + @Test + def testNonSequentialAppend(): Unit = { + // create a log + val log = createLog(logDir, LogConfig()) + val pid = 1L + val epoch: Short = 0 + + val records = TestUtils.records(List(new SimpleRecord(mockTime.milliseconds, "key".getBytes, "value".getBytes)), producerId = pid, producerEpoch = epoch, sequence = 0) + log.appendAsLeader(records, leaderEpoch = 0) + + val nextRecords = TestUtils.records(List(new SimpleRecord(mockTime.milliseconds, "key".getBytes, "value".getBytes)), producerId = pid, producerEpoch = epoch, sequence = 2) + assertThrows(classOf[OutOfOrderSequenceException], () => log.appendAsLeader(nextRecords, leaderEpoch = 0)) + } + + @Test + def testTruncateToEndOffsetClearsEpochCache(): Unit = { + val log = createLog(logDir, LogConfig()) + + // Seed some initial data in the log + val records = TestUtils.records(List(new SimpleRecord("a".getBytes), new SimpleRecord("b".getBytes)), + baseOffset = 27) + appendAsFollower(log, records, leaderEpoch = 19) + assertEquals(Some(EpochEntry(epoch = 19, startOffset = 27)), + log.leaderEpochCache.flatMap(_.latestEntry)) + assertEquals(29, log.logEndOffset) + + def verifyTruncationClearsEpochCache(epoch: Int, truncationOffset: Long): Unit = { + // Simulate becoming a leader + log.maybeAssignEpochStartOffset(leaderEpoch = epoch, startOffset = log.logEndOffset) + assertEquals(Some(EpochEntry(epoch = epoch, startOffset = 29)), + log.leaderEpochCache.flatMap(_.latestEntry)) + assertEquals(29, log.logEndOffset) + + // Now we become the follower and truncate to an offset greater + // than or equal to the log end offset. The trivial epoch entry + // at the end of the log should be gone + log.truncateTo(truncationOffset) + assertEquals(Some(EpochEntry(epoch = 19, startOffset = 27)), + log.leaderEpochCache.flatMap(_.latestEntry)) + assertEquals(29, log.logEndOffset) + } + + // Truncations greater than or equal to the log end offset should + // clear the epoch cache + verifyTruncationClearsEpochCache(epoch = 20, truncationOffset = log.logEndOffset) + verifyTruncationClearsEpochCache(epoch = 24, truncationOffset = log.logEndOffset + 1) + } + + /** + * Test the values returned by the logSegments call + */ + @Test + def testLogSegmentsCallCorrect(): Unit = { + // Create 3 segments and make sure we get the right values from various logSegments calls. 
+ def createRecords = TestUtils.singletonRecords(value = "test".getBytes, timestamp = mockTime.milliseconds) + def getSegmentOffsets(log: UnifiedLog, from: Long, to: Long) = log.logSegments(from, to).map { _.baseOffset } + val setSize = createRecords.sizeInBytes + val msgPerSeg = 10 + val segmentSize = msgPerSeg * setSize // each segment will be 10 messages + // create a log + val logConfig = LogTestUtils.createLogConfig(segmentBytes = segmentSize) + val log = createLog(logDir, logConfig) + assertEquals(1, log.numberOfSegments, "There should be exactly 1 segment.") + + // append enough messages to fill three size-based segments + for (_ <- 1 to (2 * msgPerSeg + 2)) + log.appendAsLeader(createRecords, leaderEpoch = 0) + assertEquals(3, log.numberOfSegments, "There should be exactly 3 segments.") + + // from == to should always yield an empty result + assertEquals(List.empty[LogSegment], getSegmentOffsets(log, 10, 10)) + assertEquals(List.empty[LogSegment], getSegmentOffsets(log, 15, 15)) + + assertEquals(List[Long](0, 10, 20), getSegmentOffsets(log, 0, 21)) + + assertEquals(List[Long](0), getSegmentOffsets(log, 1, 5)) + assertEquals(List[Long](10, 20), getSegmentOffsets(log, 13, 21)) + assertEquals(List[Long](10), getSegmentOffsets(log, 13, 17)) + + // from > to is invalid + assertThrows(classOf[IllegalArgumentException], () => log.logSegments(10, 0)) + } + + @Test + def testInitializationOfProducerSnapshotsUpgradePath(): Unit = { + // simulate the upgrade path by creating a new log with several segments, deleting the + // snapshot files, and then reloading the log + val logConfig = LogTestUtils.createLogConfig(segmentBytes = 64 * 10) + var log = createLog(logDir, logConfig) + assertEquals(None, log.oldestProducerSnapshotOffset) + + for (i <- 0 to 100) { + val record = new SimpleRecord(mockTime.milliseconds, i.toString.getBytes) + log.appendAsLeader(TestUtils.records(List(record)), leaderEpoch = 0) + } + assertTrue(log.logSegments.size >= 2) + val logEndOffset = log.logEndOffset + log.close() + + LogTestUtils.deleteProducerSnapshotFiles(logDir) + + // Reload after clean shutdown + log = createLog(logDir, logConfig, recoveryPoint = logEndOffset) + var expectedSnapshotOffsets = log.logSegments.map(_.baseOffset).takeRight(2).toVector :+ log.logEndOffset + assertEquals(expectedSnapshotOffsets, LogTestUtils.listProducerSnapshotOffsets(logDir)) + log.close() + + LogTestUtils.deleteProducerSnapshotFiles(logDir) + + // Reload after unclean shutdown with recoveryPoint set to log end offset + log = createLog(logDir, logConfig, recoveryPoint = logEndOffset, lastShutdownClean = false) + assertEquals(expectedSnapshotOffsets, LogTestUtils.listProducerSnapshotOffsets(logDir)) + log.close() + + LogTestUtils.deleteProducerSnapshotFiles(logDir) + + // Reload after unclean shutdown with recoveryPoint set to 0 + log = createLog(logDir, logConfig, recoveryPoint = 0L, lastShutdownClean = false) + // We progressively create a snapshot for each segment after the recovery point + expectedSnapshotOffsets = log.logSegments.map(_.baseOffset).tail.toVector :+ log.logEndOffset + assertEquals(expectedSnapshotOffsets, LogTestUtils.listProducerSnapshotOffsets(logDir)) + log.close() + } + + @Test + def testLogReinitializeAfterManualDelete(): Unit = { + val logConfig = LogTestUtils.createLogConfig() + // simulate a case where log data does not exist but the start offset is non-zero + val log = createLog(logDir, logConfig, logStartOffset = 500) + assertEquals(500, log.logStartOffset) + assertEquals(500, log.logEndOffset) + } + + /** + * Test that
"PeriodicProducerExpirationCheck" scheduled task gets canceled after log + * is deleted. + */ + @Test + def testProducerExpireCheckAfterDelete(): Unit = { + val scheduler = new KafkaScheduler(1) + try { + scheduler.startup() + val logConfig = LogTestUtils.createLogConfig() + val log = createLog(logDir, logConfig, scheduler = scheduler) + + val producerExpireCheck = log.producerExpireCheck + assertTrue(scheduler.taskRunning(producerExpireCheck), "producerExpireCheck isn't as part of scheduled tasks") + + log.delete() + assertFalse(scheduler.taskRunning(producerExpireCheck), + "producerExpireCheck is part of scheduled tasks even after log deletion") + } finally { + scheduler.shutdown() + } + } + + @Test + def testProducerIdMapOffsetUpdatedForNonIdempotentData(): Unit = { + val logConfig = LogTestUtils.createLogConfig(segmentBytes = 2048 * 5) + val log = createLog(logDir, logConfig) + val records = TestUtils.records(List(new SimpleRecord(mockTime.milliseconds, "key".getBytes, "value".getBytes))) + log.appendAsLeader(records, leaderEpoch = 0) + log.takeProducerSnapshot() + assertEquals(Some(1), log.latestProducerSnapshotOffset) + } + + @Test + def testRebuildProducerIdMapWithCompactedData(): Unit = { + val logConfig = LogTestUtils.createLogConfig(segmentBytes = 2048 * 5) + val log = createLog(logDir, logConfig) + val pid = 1L + val epoch = 0.toShort + val seq = 0 + val baseOffset = 23L + + // create a batch with a couple gaps to simulate compaction + val records = TestUtils.records(producerId = pid, producerEpoch = epoch, sequence = seq, baseOffset = baseOffset, records = List( + new SimpleRecord(mockTime.milliseconds(), "a".getBytes), + new SimpleRecord(mockTime.milliseconds(), "key".getBytes, "b".getBytes), + new SimpleRecord(mockTime.milliseconds(), "c".getBytes), + new SimpleRecord(mockTime.milliseconds(), "key".getBytes, "d".getBytes))) + records.batches.forEach(_.setPartitionLeaderEpoch(0)) + + val filtered = ByteBuffer.allocate(2048) + records.filterTo(new TopicPartition("foo", 0), new RecordFilter(0, 0) { + override def checkBatchRetention(batch: RecordBatch): RecordFilter.BatchRetentionResult = + new RecordFilter.BatchRetentionResult(RecordFilter.BatchRetention.DELETE_EMPTY, false) + override def shouldRetainRecord(recordBatch: RecordBatch, record: Record): Boolean = !record.hasKey + }, filtered, Int.MaxValue, BufferSupplier.NO_CACHING) + filtered.flip() + val filteredRecords = MemoryRecords.readableRecords(filtered) + + log.appendAsFollower(filteredRecords) + + // append some more data and then truncate to force rebuilding of the PID map + val moreRecords = TestUtils.records(baseOffset = baseOffset + 4, records = List( + new SimpleRecord(mockTime.milliseconds(), "e".getBytes), + new SimpleRecord(mockTime.milliseconds(), "f".getBytes))) + moreRecords.batches.forEach(_.setPartitionLeaderEpoch(0)) + log.appendAsFollower(moreRecords) + + log.truncateTo(baseOffset + 4) + + val activeProducers = log.activeProducersWithLastSequence + assertTrue(activeProducers.contains(pid)) + + val lastSeq = activeProducers(pid) + assertEquals(3, lastSeq) + } + + @Test + def testRebuildProducerStateWithEmptyCompactedBatch(): Unit = { + val logConfig = LogTestUtils.createLogConfig(segmentBytes = 2048 * 5) + val log = createLog(logDir, logConfig) + val pid = 1L + val epoch = 0.toShort + val seq = 0 + val baseOffset = 23L + + // create an empty batch + val records = TestUtils.records(producerId = pid, producerEpoch = epoch, sequence = seq, baseOffset = baseOffset, records = List( + new 
SimpleRecord(mockTime.milliseconds(), "key".getBytes, "a".getBytes), + new SimpleRecord(mockTime.milliseconds(), "key".getBytes, "b".getBytes))) + records.batches.forEach(_.setPartitionLeaderEpoch(0)) + + val filtered = ByteBuffer.allocate(2048) + records.filterTo(new TopicPartition("foo", 0), new RecordFilter(0, 0) { + override def checkBatchRetention(batch: RecordBatch): RecordFilter.BatchRetentionResult = + new RecordFilter.BatchRetentionResult(RecordFilter.BatchRetention.RETAIN_EMPTY, true) + override def shouldRetainRecord(recordBatch: RecordBatch, record: Record): Boolean = false + }, filtered, Int.MaxValue, BufferSupplier.NO_CACHING) + filtered.flip() + val filteredRecords = MemoryRecords.readableRecords(filtered) + + log.appendAsFollower(filteredRecords) + + // append some more data and then truncate to force rebuilding of the PID map + val moreRecords = TestUtils.records(baseOffset = baseOffset + 2, records = List( + new SimpleRecord(mockTime.milliseconds(), "e".getBytes), + new SimpleRecord(mockTime.milliseconds(), "f".getBytes))) + moreRecords.batches.forEach(_.setPartitionLeaderEpoch(0)) + log.appendAsFollower(moreRecords) + + log.truncateTo(baseOffset + 2) + + val activeProducers = log.activeProducersWithLastSequence + assertTrue(activeProducers.contains(pid)) + + val lastSeq = activeProducers(pid) + assertEquals(1, lastSeq) + } + + @Test + def testUpdateProducerIdMapWithCompactedData(): Unit = { + val logConfig = LogTestUtils.createLogConfig(segmentBytes = 2048 * 5) + val log = createLog(logDir, logConfig) + val pid = 1L + val epoch = 0.toShort + val seq = 0 + val baseOffset = 23L + + // create a batch with a couple gaps to simulate compaction + val records = TestUtils.records(producerId = pid, producerEpoch = epoch, sequence = seq, baseOffset = baseOffset, records = List( + new SimpleRecord(mockTime.milliseconds(), "a".getBytes), + new SimpleRecord(mockTime.milliseconds(), "key".getBytes, "b".getBytes), + new SimpleRecord(mockTime.milliseconds(), "c".getBytes), + new SimpleRecord(mockTime.milliseconds(), "key".getBytes, "d".getBytes))) + records.batches.forEach(_.setPartitionLeaderEpoch(0)) + + val filtered = ByteBuffer.allocate(2048) + records.filterTo(new TopicPartition("foo", 0), new RecordFilter(0, 0) { + override def checkBatchRetention(batch: RecordBatch): RecordFilter.BatchRetentionResult = + new RecordFilter.BatchRetentionResult(RecordFilter.BatchRetention.DELETE_EMPTY, false) + override def shouldRetainRecord(recordBatch: RecordBatch, record: Record): Boolean = !record.hasKey + }, filtered, Int.MaxValue, BufferSupplier.NO_CACHING) + filtered.flip() + val filteredRecords = MemoryRecords.readableRecords(filtered) + + log.appendAsFollower(filteredRecords) + val activeProducers = log.activeProducersWithLastSequence + assertTrue(activeProducers.contains(pid)) + + val lastSeq = activeProducers(pid) + assertEquals(3, lastSeq) + } + + @Test + def testProducerIdMapTruncateTo(): Unit = { + val logConfig = LogTestUtils.createLogConfig(segmentBytes = 2048 * 5) + val log = createLog(logDir, logConfig) + log.appendAsLeader(TestUtils.records(List(new SimpleRecord("a".getBytes))), leaderEpoch = 0) + log.appendAsLeader(TestUtils.records(List(new SimpleRecord("b".getBytes))), leaderEpoch = 0) + log.takeProducerSnapshot() + + log.appendAsLeader(TestUtils.records(List(new SimpleRecord("c".getBytes))), leaderEpoch = 0) + log.takeProducerSnapshot() + + log.truncateTo(2) + assertEquals(Some(2), log.latestProducerSnapshotOffset) + assertEquals(2, log.latestProducerStateEndOffset) + + 
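+ // Each further truncation falls back to the next older snapshot; truncating to 0 leaves no snapshot at all.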
log.truncateTo(1) + assertEquals(Some(1), log.latestProducerSnapshotOffset) + assertEquals(1, log.latestProducerStateEndOffset) + + log.truncateTo(0) + assertEquals(None, log.latestProducerSnapshotOffset) + assertEquals(0, log.latestProducerStateEndOffset) + } + + @Test + def testProducerIdMapTruncateToWithNoSnapshots(): Unit = { + // This ensures that the upgrade optimization path cannot be hit after initial loading + val logConfig = LogTestUtils.createLogConfig(segmentBytes = 2048 * 5) + val log = createLog(logDir, logConfig) + val pid = 1L + val epoch = 0.toShort + + log.appendAsLeader(TestUtils.records(List(new SimpleRecord("a".getBytes)), producerId = pid, + producerEpoch = epoch, sequence = 0), leaderEpoch = 0) + log.appendAsLeader(TestUtils.records(List(new SimpleRecord("b".getBytes)), producerId = pid, + producerEpoch = epoch, sequence = 1), leaderEpoch = 0) + + LogTestUtils.deleteProducerSnapshotFiles(logDir) + + log.truncateTo(1L) + assertEquals(1, log.activeProducersWithLastSequence.size) + + val lastSeqOpt = log.activeProducersWithLastSequence.get(pid) + assertTrue(lastSeqOpt.isDefined) + + val lastSeq = lastSeqOpt.get + assertEquals(0, lastSeq) + } + + @Test + def testRetentionDeletesProducerStateSnapshots(): Unit = { + val logConfig = LogTestUtils.createLogConfig(segmentBytes = 2048 * 5, retentionBytes = 0, retentionMs = 1000 * 60, fileDeleteDelayMs = 0) + val log = createLog(logDir, logConfig) + val pid1 = 1L + val epoch = 0.toShort + + log.appendAsLeader(TestUtils.records(List(new SimpleRecord("a".getBytes)), producerId = pid1, + producerEpoch = epoch, sequence = 0), leaderEpoch = 0) + log.roll() + log.appendAsLeader(TestUtils.records(List(new SimpleRecord("b".getBytes)), producerId = pid1, + producerEpoch = epoch, sequence = 1), leaderEpoch = 0) + log.roll() + log.appendAsLeader(TestUtils.records(List(new SimpleRecord("c".getBytes)), producerId = pid1, + producerEpoch = epoch, sequence = 2), leaderEpoch = 0) + + log.updateHighWatermark(log.logEndOffset) + + assertEquals(2, ProducerStateManager.listSnapshotFiles(logDir).size) + // Sleep to breach the retention period + mockTime.sleep(1000 * 60 + 1) + log.deleteOldSegments() + // Sleep to breach the file delete delay and run scheduled file deletion tasks + mockTime.sleep(1) + assertEquals(1, ProducerStateManager.listSnapshotFiles(logDir).size, + "expect a single producer state snapshot remaining") + } + + @Test + def testRetentionIdempotency(): Unit = { + val logConfig = LogTestUtils.createLogConfig(segmentBytes = 2048 * 5, retentionBytes = -1, retentionMs = 900, fileDeleteDelayMs = 0) + val log = createLog(logDir, logConfig) + + log.appendAsLeader(TestUtils.records(List(new SimpleRecord(mockTime.milliseconds() + 100, "a".getBytes))), leaderEpoch = 0) + log.roll() + log.appendAsLeader(TestUtils.records(List(new SimpleRecord(mockTime.milliseconds(), "b".getBytes))), leaderEpoch = 0) + log.roll() + log.appendAsLeader(TestUtils.records(List(new SimpleRecord(mockTime.milliseconds() + 100, "c".getBytes))), leaderEpoch = 0) + + mockTime.sleep(901) + + log.updateHighWatermark(log.logEndOffset) + log.maybeIncrementLogStartOffset(1L, ClientRecordDeletion) + assertEquals(2, log.deleteOldSegments(), + "Expecting two segment deletions as log start offset retention should unblock time based retention") + assertEquals(0, log.deleteOldSegments()) + } + + + @Test + def testLogStartOffsetMovementDeletesSnapshots(): Unit = { + val logConfig = LogTestUtils.createLogConfig(segmentBytes = 2048 * 5, retentionBytes = -1, fileDeleteDelayMs = 0) + 
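+ // fileDeleteDelayMs = 0 lets the scheduled file deletion task run as soon as the mock clock advances.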
val log = createLog(logDir, logConfig) + val pid1 = 1L + val epoch = 0.toShort + + log.appendAsLeader(TestUtils.records(List(new SimpleRecord("a".getBytes)), producerId = pid1, + producerEpoch = epoch, sequence = 0), leaderEpoch = 0) + log.roll() + log.appendAsLeader(TestUtils.records(List(new SimpleRecord("b".getBytes)), producerId = pid1, + producerEpoch = epoch, sequence = 1), leaderEpoch = 0) + log.roll() + log.appendAsLeader(TestUtils.records(List(new SimpleRecord("c".getBytes)), producerId = pid1, + producerEpoch = epoch, sequence = 2), leaderEpoch = 0) + log.updateHighWatermark(log.logEndOffset) + assertEquals(2, ProducerStateManager.listSnapshotFiles(logDir).size) + + // Increment the log start offset to exclude the first two segments. + log.maybeIncrementLogStartOffset(log.logEndOffset - 1, ClientRecordDeletion) + log.deleteOldSegments() + // Sleep to breach the file delete delay and run scheduled file deletion tasks + mockTime.sleep(1) + assertEquals(1, ProducerStateManager.listSnapshotFiles(logDir).size, + "expect a single producer state snapshot remaining") + } + + @Test + def testCompactionDeletesProducerStateSnapshots(): Unit = { + val logConfig = LogTestUtils.createLogConfig(segmentBytes = 2048 * 5, cleanupPolicy = LogConfig.Compact, fileDeleteDelayMs = 0) + val log = createLog(logDir, logConfig) + val pid1 = 1L + val epoch = 0.toShort + val cleaner = new Cleaner(id = 0, + offsetMap = new FakeOffsetMap(Int.MaxValue), + ioBufferSize = 64 * 1024, + maxIoBufferSize = 64 * 1024, + dupBufferLoadFactor = 0.75, + throttler = new Throttler(Double.MaxValue, Long.MaxValue, false, time = mockTime), + time = mockTime, + checkDone = _ => {}) + + log.appendAsLeader(TestUtils.records(List(new SimpleRecord("a".getBytes, "a".getBytes())), producerId = pid1, + producerEpoch = epoch, sequence = 0), leaderEpoch = 0) + log.roll() + log.appendAsLeader(TestUtils.records(List(new SimpleRecord("a".getBytes, "b".getBytes())), producerId = pid1, + producerEpoch = epoch, sequence = 1), leaderEpoch = 0) + log.roll() + log.appendAsLeader(TestUtils.records(List(new SimpleRecord("a".getBytes, "c".getBytes())), producerId = pid1, + producerEpoch = epoch, sequence = 2), leaderEpoch = 0) + log.updateHighWatermark(log.logEndOffset) + assertEquals(log.logSegments.map(_.baseOffset).toSeq.sorted.drop(1), ProducerStateManager.listSnapshotFiles(logDir).map(_.offset).sorted, + "expected a snapshot file per segment base offset, except the first segment") + assertEquals(2, ProducerStateManager.listSnapshotFiles(logDir).size) + + // Clean segments, this should delete everything except the active segment since there only + // exists the key "a". + cleaner.clean(LogToClean(log.topicPartition, log, 0, log.logEndOffset)) + log.deleteOldSegments() + // Sleep to breach the file delete delay and run scheduled file deletion tasks + mockTime.sleep(1) + assertEquals(log.logSegments.map(_.baseOffset).toSeq.sorted.drop(1), ProducerStateManager.listSnapshotFiles(logDir).map(_.offset).sorted, + "expected a snapshot file per segment base offset, excluding the first") + } + + /** + * After loading the log, producer state is truncated such that there are no producer state snapshot files which + * exceed the log end offset. This test verifies that these are removed. 
+ */ + @Test + def testLoadingLogDeletesProducerStateSnapshotsPastLogEndOffset(): Unit = { + val straySnapshotFile = UnifiedLog.producerSnapshotFile(logDir, 42).toPath + Files.createFile(straySnapshotFile) + val logConfig = LogTestUtils.createLogConfig(segmentBytes = 2048 * 5, retentionBytes = -1, fileDeleteDelayMs = 0) + createLog(logDir, logConfig) + assertEquals(0, ProducerStateManager.listSnapshotFiles(logDir).size, + "expected producer state snapshots greater than the log end offset to be cleaned up") + } + + @Test + def testProducerIdMapTruncateFullyAndStartAt(): Unit = { + val records = TestUtils.singletonRecords("foo".getBytes) + val logConfig = LogTestUtils.createLogConfig(segmentBytes = records.sizeInBytes, retentionBytes = records.sizeInBytes * 2) + val log = createLog(logDir, logConfig) + log.appendAsLeader(records, leaderEpoch = 0) + log.takeProducerSnapshot() + + log.appendAsLeader(TestUtils.singletonRecords("bar".getBytes), leaderEpoch = 0) + log.appendAsLeader(TestUtils.singletonRecords("baz".getBytes), leaderEpoch = 0) + log.takeProducerSnapshot() + + assertEquals(3, log.logSegments.size) + assertEquals(3, log.latestProducerStateEndOffset) + assertEquals(Some(3), log.latestProducerSnapshotOffset) + + log.truncateFullyAndStartAt(29) + assertEquals(1, log.logSegments.size) + assertEquals(None, log.latestProducerSnapshotOffset) + assertEquals(29, log.latestProducerStateEndOffset) + } + + @Test + def testProducerIdExpirationOnSegmentDeletion(): Unit = { + val pid1 = 1L + val records = TestUtils.records(Seq(new SimpleRecord("foo".getBytes)), producerId = pid1, producerEpoch = 0, sequence = 0) + val logConfig = LogTestUtils.createLogConfig(segmentBytes = records.sizeInBytes, retentionBytes = records.sizeInBytes * 2) + val log = createLog(logDir, logConfig) + log.appendAsLeader(records, leaderEpoch = 0) + log.takeProducerSnapshot() + + val pid2 = 2L + log.appendAsLeader(TestUtils.records(Seq(new SimpleRecord("bar".getBytes)), producerId = pid2, producerEpoch = 0, sequence = 0), + leaderEpoch = 0) + log.appendAsLeader(TestUtils.records(Seq(new SimpleRecord("baz".getBytes)), producerId = pid2, producerEpoch = 0, sequence = 1), + leaderEpoch = 0) + log.takeProducerSnapshot() + + assertEquals(3, log.logSegments.size) + assertEquals(Set(pid1, pid2), log.activeProducersWithLastSequence.keySet) + + log.updateHighWatermark(log.logEndOffset) + log.deleteOldSegments() + + // Producer state should not be removed when deleting log segment + assertEquals(2, log.logSegments.size) + assertEquals(Set(pid1, pid2), log.activeProducersWithLastSequence.keySet) + } + + @Test + def testTakeSnapshotOnRollAndDeleteSnapshotOnRecoveryPointCheckpoint(): Unit = { + val logConfig = LogTestUtils.createLogConfig(segmentBytes = 2048 * 5) + val log = createLog(logDir, logConfig) + log.appendAsLeader(TestUtils.singletonRecords("a".getBytes), leaderEpoch = 0) + log.roll(Some(1L)) + assertEquals(Some(1L), log.latestProducerSnapshotOffset) + assertEquals(Some(1L), log.oldestProducerSnapshotOffset) + + log.appendAsLeader(TestUtils.singletonRecords("b".getBytes), leaderEpoch = 0) + log.roll(Some(2L)) + assertEquals(Some(2L), log.latestProducerSnapshotOffset) + assertEquals(Some(1L), log.oldestProducerSnapshotOffset) + + log.appendAsLeader(TestUtils.singletonRecords("c".getBytes), leaderEpoch = 0) + log.roll(Some(3L)) + assertEquals(Some(3L), log.latestProducerSnapshotOffset) + + // roll triggers a flush at the starting offset of the new segment, we should retain all snapshots + assertEquals(Some(1L), 
log.oldestProducerSnapshotOffset) + + // even if we flush within the active segment, the snapshot should remain + log.appendAsLeader(TestUtils.singletonRecords("baz".getBytes), leaderEpoch = 0) + log.flush(4L) + assertEquals(Some(3L), log.latestProducerSnapshotOffset) + } + + @Test + def testProducerSnapshotAfterSegmentRollOnAppend(): Unit = { + val producerId = 1L + val logConfig = LogTestUtils.createLogConfig(segmentBytes = 1024) + val log = createLog(logDir, logConfig) + + log.appendAsLeader(TestUtils.records(Seq(new SimpleRecord(mockTime.milliseconds(), new Array[Byte](512))), + producerId = producerId, producerEpoch = 0, sequence = 0), + leaderEpoch = 0) + + // The next append should overflow the segment and cause it to roll + log.appendAsLeader(TestUtils.records(Seq(new SimpleRecord(mockTime.milliseconds(), new Array[Byte](512))), + producerId = producerId, producerEpoch = 0, sequence = 1), + leaderEpoch = 0) + + assertEquals(2, log.logSegments.size) + assertEquals(1L, log.activeSegment.baseOffset) + assertEquals(Some(1L), log.latestProducerSnapshotOffset) + + // Force a reload from the snapshot to check its consistency + log.truncateTo(1L) + + assertEquals(2, log.logSegments.size) + assertEquals(1L, log.activeSegment.baseOffset) + assertTrue(log.activeSegment.log.batches.asScala.isEmpty) + assertEquals(Some(1L), log.latestProducerSnapshotOffset) + + val lastEntry = log.producerStateManager.lastEntry(producerId) + assertTrue(lastEntry.isDefined) + assertEquals(0L, lastEntry.get.firstDataOffset) + assertEquals(0L, lastEntry.get.lastDataOffset) + } + + @Test + def testRebuildTransactionalState(): Unit = { + val logConfig = LogTestUtils.createLogConfig(segmentBytes = 1024 * 1024 * 5) + val log = createLog(logDir, logConfig) + + val pid = 137L + val epoch = 5.toShort + val seq = 0 + + // add some transactional records + val records = MemoryRecords.withTransactionalRecords(CompressionType.NONE, pid, epoch, seq, + new SimpleRecord("foo".getBytes), + new SimpleRecord("bar".getBytes), + new SimpleRecord("baz".getBytes)) + log.appendAsLeader(records, leaderEpoch = 0) + val abortAppendInfo = LogTestUtils.appendEndTxnMarkerAsLeader(log, pid, epoch, ControlRecordType.ABORT, mockTime.milliseconds()) + log.updateHighWatermark(abortAppendInfo.lastOffset + 1) + + // now there should be no first unstable offset + assertEquals(None, log.firstUnstableOffset) + + log.close() + + val reopenedLog = createLog(logDir, logConfig, lastShutdownClean = false) + reopenedLog.updateHighWatermark(abortAppendInfo.lastOffset + 1) + assertEquals(None, reopenedLog.firstUnstableOffset) + } + + @Test + def testPeriodicProducerIdExpiration(): Unit = { + val maxProducerIdExpirationMs = 200 + val producerIdExpirationCheckIntervalMs = 100 + + val pid = 23L + val logConfig = LogTestUtils.createLogConfig(segmentBytes = 2048 * 5) + val log = createLog(logDir, logConfig, maxProducerIdExpirationMs = maxProducerIdExpirationMs, + producerIdExpirationCheckIntervalMs = producerIdExpirationCheckIntervalMs) + val records = Seq(new SimpleRecord(mockTime.milliseconds(), "foo".getBytes)) + log.appendAsLeader(TestUtils.records(records, producerId = pid, producerEpoch = 0, sequence = 0), leaderEpoch = 0) + + assertEquals(Set(pid), log.activeProducersWithLastSequence.keySet) + + mockTime.sleep(producerIdExpirationCheckIntervalMs) + assertEquals(Set(pid), log.activeProducersWithLastSequence.keySet) + + mockTime.sleep(producerIdExpirationCheckIntervalMs) + assertEquals(Set(), log.activeProducersWithLastSequence.keySet) + } + + @Test + def 
testDuplicateAppends(): Unit = { + // create a log + val log = createLog(logDir, LogConfig()) + val pid = 1L + val epoch: Short = 0 + + var seq = 0 + // Pad the beginning of the log. + for (_ <- 0 to 5) { + val record = TestUtils.records(List(new SimpleRecord(mockTime.milliseconds, "key".getBytes, "value".getBytes)), + producerId = pid, producerEpoch = epoch, sequence = seq) + log.appendAsLeader(record, leaderEpoch = 0) + seq = seq + 1 + } + // Append an entry with multiple log records. + def createRecords = TestUtils.records(List( + new SimpleRecord(mockTime.milliseconds, s"key-$seq".getBytes, s"value-$seq".getBytes), + new SimpleRecord(mockTime.milliseconds, s"key-$seq".getBytes, s"value-$seq".getBytes), + new SimpleRecord(mockTime.milliseconds, s"key-$seq".getBytes, s"value-$seq".getBytes) + ), producerId = pid, producerEpoch = epoch, sequence = seq) + val multiEntryAppendInfo = log.appendAsLeader(createRecords, leaderEpoch = 0) + assertEquals( + multiEntryAppendInfo.lastOffset - multiEntryAppendInfo.firstOffset.get.messageOffset + 1, + 3, + "should have appended 3 entries" + ) + + // Append a duplicate of the tail, when the entry at the tail has multiple records. + val dupMultiEntryAppendInfo = log.appendAsLeader(createRecords, leaderEpoch = 0) + assertEquals( + multiEntryAppendInfo.firstOffset.get.messageOffset, + dupMultiEntryAppendInfo.firstOffset.get.messageOffset, + "Somehow appended a duplicate entry with multiple log records to the tail" + ) + assertEquals(multiEntryAppendInfo.lastOffset, dupMultiEntryAppendInfo.lastOffset, + "Somehow appended a duplicate entry with multiple log records to the tail") + + seq = seq + 3 + + // Append a partial duplicate of the tail. This is not allowed. + var records = TestUtils.records( + List( + new SimpleRecord(mockTime.milliseconds, s"key-$seq".getBytes, s"value-$seq".getBytes), + new SimpleRecord(mockTime.milliseconds, s"key-$seq".getBytes, s"value-$seq".getBytes)), + producerId = pid, producerEpoch = epoch, sequence = seq - 2) + assertThrows(classOf[OutOfOrderSequenceException], () => log.appendAsLeader(records, leaderEpoch = 0), + () => "Should have received an OutOfOrderSequenceException since we attempted to append a duplicate of records in the middle of the log.") + + // Append a duplicate of the batch which is 4th from the tail. This should succeed without error since we + // retain the batch metadata of the last 5 batches. + val duplicateOfFourth = TestUtils.records(List(new SimpleRecord(mockTime.milliseconds, "key".getBytes, "value".getBytes)), + producerId = pid, producerEpoch = epoch, sequence = 2) + log.appendAsLeader(duplicateOfFourth, leaderEpoch = 0) + + // Duplicates at older entries are reported as OutOfOrderSequence errors + records = TestUtils.records( + List(new SimpleRecord(mockTime.milliseconds, s"key-1".getBytes, s"value-1".getBytes)), + producerId = pid, producerEpoch = epoch, sequence = 1) + assertThrows(classOf[OutOfOrderSequenceException], () => log.appendAsLeader(records, leaderEpoch = 0), + () => "Should have received an OutOfOrderSequenceException since we attempted to append a duplicate of a batch which is older than the last 5 appended batches.") + + // Append a duplicate entry with a single record at the tail of the log. This should return the appendInfo of the original entry.
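+ // Note: the duplicate is recognized from the batch metadata (producer id, epoch and sequence range) cached for the last 5 batches, so the original offsets are returned instead of appending the batch again.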
+ def createRecordsWithDuplicate = TestUtils.records(List(new SimpleRecord(mockTime.milliseconds, "key".getBytes, "value".getBytes)), + producerId = pid, producerEpoch = epoch, sequence = seq) + val origAppendInfo = log.appendAsLeader(createRecordsWithDuplicate, leaderEpoch = 0) + val newAppendInfo = log.appendAsLeader(createRecordsWithDuplicate, leaderEpoch = 0) + assertEquals( + origAppendInfo.firstOffset.get.messageOffset, + newAppendInfo.firstOffset.get.messageOffset, + "Inserted a duplicate records into the log" + ) + assertEquals(origAppendInfo.lastOffset, newAppendInfo.lastOffset, + "Inserted a duplicate records into the log") + } + + @Test + def testMultipleProducerIdsPerMemoryRecord(): Unit = { + // create a log + val log = createLog(logDir, LogConfig()) + + val epoch: Short = 0 + val buffer = ByteBuffer.allocate(512) + + var builder = MemoryRecords.builder(buffer, RecordBatch.MAGIC_VALUE_V2, CompressionType.NONE, + TimestampType.LOG_APPEND_TIME, 0L, mockTime.milliseconds(), 1L, epoch, 0, false, 0) + builder.append(new SimpleRecord("key".getBytes, "value".getBytes)) + builder.close() + + builder = MemoryRecords.builder(buffer, RecordBatch.MAGIC_VALUE_V2, CompressionType.NONE, + TimestampType.LOG_APPEND_TIME, 1L, mockTime.milliseconds(), 2L, epoch, 0, false, 0) + builder.append(new SimpleRecord("key".getBytes, "value".getBytes)) + builder.close() + + builder = MemoryRecords.builder(buffer, RecordBatch.MAGIC_VALUE_V2, CompressionType.NONE, + TimestampType.LOG_APPEND_TIME, 2L, mockTime.milliseconds(), 3L, epoch, 0, false, 0) + builder.append(new SimpleRecord("key".getBytes, "value".getBytes)) + builder.close() + + builder = MemoryRecords.builder(buffer, RecordBatch.MAGIC_VALUE_V2, CompressionType.NONE, + TimestampType.LOG_APPEND_TIME, 3L, mockTime.milliseconds(), 4L, epoch, 0, false, 0) + builder.append(new SimpleRecord("key".getBytes, "value".getBytes)) + builder.close() + + buffer.flip() + val memoryRecords = MemoryRecords.readableRecords(buffer) + + log.appendAsFollower(memoryRecords) + log.flush() + + val fetchedData = LogTestUtils.readLog(log, 0, Int.MaxValue) + + val origIterator = memoryRecords.batches.iterator() + for (batch <- fetchedData.records.batches.asScala) { + assertTrue(origIterator.hasNext) + val origEntry = origIterator.next() + assertEquals(origEntry.producerId, batch.producerId) + assertEquals(origEntry.baseOffset, batch.baseOffset) + assertEquals(origEntry.baseSequence, batch.baseSequence) + } + } + + @Test + def testDuplicateAppendToFollower(): Unit = { + val logConfig = LogTestUtils.createLogConfig(segmentBytes = 1024 * 1024 * 5) + val log = createLog(logDir, logConfig) + val epoch: Short = 0 + val pid = 1L + val baseSequence = 0 + val partitionLeaderEpoch = 0 + // The point of this test is to ensure that validation isn't performed on the follower. + // this is a bit contrived. to trigger the duplicate case for a follower append, we have to append + // a batch with matching sequence numbers, but valid increasing offsets + assertEquals(0L, log.logEndOffset) + log.appendAsFollower(MemoryRecords.withIdempotentRecords(0L, CompressionType.NONE, pid, epoch, baseSequence, + partitionLeaderEpoch, new SimpleRecord("a".getBytes), new SimpleRecord("b".getBytes))) + log.appendAsFollower(MemoryRecords.withIdempotentRecords(2L, CompressionType.NONE, pid, epoch, baseSequence, + partitionLeaderEpoch, new SimpleRecord("a".getBytes), new SimpleRecord("b".getBytes))) + + // Ensure that even the duplicate sequences are accepted on the follower. 
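+ // Two two-record batches were appended above with increasing base offsets (0 and 2), so the log end offset advances to 4 even though both batches carry the same base sequence.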
+ assertEquals(4L, log.logEndOffset) + } + + @Test + def testMultipleProducersWithDuplicatesInSingleAppend(): Unit = { + val logConfig = LogTestUtils.createLogConfig(segmentBytes = 1024 * 1024 * 5) + val log = createLog(logDir, logConfig) + + val pid1 = 1L + val pid2 = 2L + val epoch: Short = 0 + + val buffer = ByteBuffer.allocate(512) + + // pid1 seq = 0 + var builder = MemoryRecords.builder(buffer, RecordBatch.CURRENT_MAGIC_VALUE, CompressionType.NONE, + TimestampType.LOG_APPEND_TIME, 0L, mockTime.milliseconds(), pid1, epoch, 0) + builder.append(new SimpleRecord("key".getBytes, "value".getBytes)) + builder.close() + + // pid2 seq = 0 + builder = MemoryRecords.builder(buffer, RecordBatch.CURRENT_MAGIC_VALUE, CompressionType.NONE, + TimestampType.LOG_APPEND_TIME, 1L, mockTime.milliseconds(), pid2, epoch, 0) + builder.append(new SimpleRecord("key".getBytes, "value".getBytes)) + builder.close() + + // pid1 seq = 1 + builder = MemoryRecords.builder(buffer, RecordBatch.CURRENT_MAGIC_VALUE, CompressionType.NONE, + TimestampType.LOG_APPEND_TIME, 2L, mockTime.milliseconds(), pid1, epoch, 1) + builder.append(new SimpleRecord("key".getBytes, "value".getBytes)) + builder.close() + + // pid2 seq = 1 + builder = MemoryRecords.builder(buffer, RecordBatch.CURRENT_MAGIC_VALUE, CompressionType.NONE, + TimestampType.LOG_APPEND_TIME, 3L, mockTime.milliseconds(), pid2, epoch, 1) + builder.append(new SimpleRecord("key".getBytes, "value".getBytes)) + builder.close() + + // // pid1 seq = 1 (duplicate) + builder = MemoryRecords.builder(buffer, RecordBatch.CURRENT_MAGIC_VALUE, CompressionType.NONE, + TimestampType.LOG_APPEND_TIME, 4L, mockTime.milliseconds(), pid1, epoch, 1) + builder.append(new SimpleRecord("key".getBytes, "value".getBytes)) + builder.close() + + buffer.flip() + + val records = MemoryRecords.readableRecords(buffer) + records.batches.forEach(_.setPartitionLeaderEpoch(0)) + + // Ensure that batches with duplicates are accepted on the follower. 
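+ // Five single-record batches were built above (including the duplicate of pid1 seq 1), so the follower log end offset should advance to 5.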
+ assertEquals(0L, log.logEndOffset) + log.appendAsFollower(records) + assertEquals(5L, log.logEndOffset) + } + + @Test + def testOldProducerEpoch(): Unit = { + // create a log + val log = createLog(logDir, LogConfig()) + val pid = 1L + val newEpoch: Short = 1 + val oldEpoch: Short = 0 + + val records = TestUtils.records(List(new SimpleRecord(mockTime.milliseconds, "key".getBytes, "value".getBytes)), producerId = pid, producerEpoch = newEpoch, sequence = 0) + log.appendAsLeader(records, leaderEpoch = 0) + + val nextRecords = TestUtils.records(List(new SimpleRecord(mockTime.milliseconds, "key".getBytes, "value".getBytes)), producerId = pid, producerEpoch = oldEpoch, sequence = 0) + assertThrows(classOf[InvalidProducerEpochException], () => log.appendAsLeader(nextRecords, leaderEpoch = 0)) + } + + @Test + def testDeleteSnapshotsOnIncrementLogStartOffset(): Unit = { + val logConfig = LogTestUtils.createLogConfig(segmentBytes = 2048 * 5) + val log = createLog(logDir, logConfig) + val pid1 = 1L + val pid2 = 2L + val epoch = 0.toShort + + log.appendAsLeader(TestUtils.records(List(new SimpleRecord(mockTime.milliseconds(), "a".getBytes)), producerId = pid1, + producerEpoch = epoch, sequence = 0), leaderEpoch = 0) + log.roll() + log.appendAsLeader(TestUtils.records(List(new SimpleRecord(mockTime.milliseconds(), "b".getBytes)), producerId = pid2, + producerEpoch = epoch, sequence = 0), leaderEpoch = 0) + log.roll() + + assertEquals(2, log.activeProducersWithLastSequence.size) + assertEquals(2, ProducerStateManager.listSnapshotFiles(log.dir).size) + + log.updateHighWatermark(log.logEndOffset) + log.maybeIncrementLogStartOffset(2L, ClientRecordDeletion) + log.deleteOldSegments() // force retention to kick in so that the snapshot files are cleaned up. + mockTime.sleep(logConfig.fileDeleteDelayMs + 1000) // advance the clock so file deletion takes place + + // Deleting records should not remove producer state but should delete snapshots after the file deletion delay. + assertEquals(2, log.activeProducersWithLastSequence.size) + assertEquals(1, ProducerStateManager.listSnapshotFiles(log.dir).size) + val retainedLastSeqOpt = log.activeProducersWithLastSequence.get(pid2) + assertTrue(retainedLastSeqOpt.isDefined) + assertEquals(0, retainedLastSeqOpt.get) + } + + /** + * Test for jitter s for time based log roll. This test appends messages then changes the time + * using the mock clock to force the log to roll and checks the number of segments. 
+ */ + @Test + def testTimeBasedLogRollJitter(): Unit = { + var set = TestUtils.singletonRecords(value = "test".getBytes, timestamp = mockTime.milliseconds) + val maxJitter = 20 * 60L + // create a log + val logConfig = LogTestUtils.createLogConfig(segmentMs = 1 * 60 * 60L, segmentJitterMs = maxJitter) + val log = createLog(logDir, logConfig) + assertEquals(1, log.numberOfSegments, "Log begins with a single empty segment.") + log.appendAsLeader(set, leaderEpoch = 0) + + mockTime.sleep(log.config.segmentMs - maxJitter) + set = TestUtils.singletonRecords(value = "test".getBytes, timestamp = mockTime.milliseconds) + log.appendAsLeader(set, leaderEpoch = 0) + assertEquals(1, log.numberOfSegments, + "Log does not roll on this append because it occurs earlier than max jitter") + mockTime.sleep(maxJitter - log.activeSegment.rollJitterMs + 1) + set = TestUtils.singletonRecords(value = "test".getBytes, timestamp = mockTime.milliseconds) + log.appendAsLeader(set, leaderEpoch = 0) + assertEquals(2, log.numberOfSegments, + "Log should roll after segmentMs adjusted by random jitter") + } + + /** + * Test that appending more than the maximum segment size rolls the log + */ + @Test + def testSizeBasedLogRoll(): Unit = { + def createRecords = TestUtils.singletonRecords(value = "test".getBytes, timestamp = mockTime.milliseconds) + val setSize = createRecords.sizeInBytes + val msgPerSeg = 10 + val segmentSize = msgPerSeg * (setSize - 1) // each segment will be 10 messages + // create a log + val logConfig = LogTestUtils.createLogConfig(segmentBytes = segmentSize) + val log = createLog(logDir, logConfig) + assertEquals(1, log.numberOfSegments, "There should be exactly 1 segment.") + + // segments expire in size + for (_ <- 1 to (msgPerSeg + 1)) + log.appendAsLeader(createRecords, leaderEpoch = 0) + assertEquals(2, log.numberOfSegments, + "There should be exactly 2 segments.") + } + + /** + * Test that we can open and append to an empty log + */ + @Test + def testLoadEmptyLog(): Unit = { + createEmptyLogs(logDir, 0) + val log = createLog(logDir, LogConfig()) + log.appendAsLeader(TestUtils.singletonRecords(value = "test".getBytes, timestamp = mockTime.milliseconds), leaderEpoch = 0) + } + + /** + * This test case appends a bunch of messages and checks that we can read them all back using sequential offsets. + */ + @Test + def testAppendAndReadWithSequentialOffsets(): Unit = { + val logConfig = LogTestUtils.createLogConfig(segmentBytes = 71) + val log = createLog(logDir, logConfig) + val values = (0 until 100 by 2).map(id => id.toString.getBytes).toArray + + for(value <- values) + log.appendAsLeader(TestUtils.singletonRecords(value = value), leaderEpoch = 0) + + for(i <- values.indices) { + val read = LogTestUtils.readLog(log, i, 1).records.batches.iterator.next() + assertEquals(i, read.lastOffset, "Offset read should match order appended.") + val actual = read.iterator.next() + assertNull(actual.key, "Key should be null") + assertEquals(ByteBuffer.wrap(values(i)), actual.value, "Values not equal") + } + assertEquals(0, LogTestUtils.readLog(log, values.length, 100).records.batches.asScala.size, + "Reading beyond the last message returns nothing.") + } + + /** + * This test appends a bunch of messages with non-sequential offsets and checks that we can an the correct message + * from any offset less than the logEndOffset including offsets not appended. 
+ */ + @Test + def testAppendAndReadWithNonSequentialOffsets(): Unit = { + val logConfig = LogTestUtils.createLogConfig(segmentBytes = 72) + val log = createLog(logDir, logConfig) + val messageIds = ((0 until 50) ++ (50 until 200 by 7)).toArray + val records = messageIds.map(id => new SimpleRecord(id.toString.getBytes)) + + // now test the case that we give the offsets and use non-sequential offsets + for(i <- records.indices) + log.appendAsFollower(MemoryRecords.withRecords(messageIds(i), CompressionType.NONE, 0, records(i))) + for(i <- 50 until messageIds.max) { + val idx = messageIds.indexWhere(_ >= i) + val read = LogTestUtils.readLog(log, i, 100).records.records.iterator.next() + assertEquals(messageIds(idx), read.offset, "Offset read should match message id.") + assertEquals(records(idx), new SimpleRecord(read), "Message should match appended.") + } + } + + /** + * This test covers an odd case where we have a gap in the offsets that falls at the end of a log segment. + * Specifically we create a log where the last message in the first segment has offset 0. If we + * then read offset 1, we should expect this read to come from the second segment, even though the + * first segment has the greatest lower bound on the offset. + */ + @Test + def testReadAtLogGap(): Unit = { + val logConfig = LogTestUtils.createLogConfig(segmentBytes = 300) + val log = createLog(logDir, logConfig) + + // keep appending until we have two segments with only a single message in the second segment + while(log.numberOfSegments == 1) + log.appendAsLeader(TestUtils.singletonRecords(value = "42".getBytes), leaderEpoch = 0) + + // now manually truncate off all but one message from the first segment to create a gap in the messages + log.logSegments.head.truncateTo(1) + + assertEquals(log.logEndOffset - 1, LogTestUtils.readLog(log, 1, 200).records.batches.iterator.next().lastOffset, + "A read should now return the last message in the log") + } + + @Test + def testLogRollAfterLogHandlerClosed(): Unit = { + val logConfig = LogTestUtils.createLogConfig() + val log = createLog(logDir, logConfig) + log.closeHandlers() + assertThrows(classOf[KafkaStorageException], () => log.roll(Some(1L))) + } + + @Test + def testReadWithMinMessage(): Unit = { + val logConfig = LogTestUtils.createLogConfig(segmentBytes = 72) + val log = createLog(logDir, logConfig) + val messageIds = ((0 until 50) ++ (50 until 200 by 7)).toArray + val records = messageIds.map(id => new SimpleRecord(id.toString.getBytes)) + + // now test the case that we give the offsets and use non-sequential offsets + for (i <- records.indices) + log.appendAsFollower(MemoryRecords.withRecords(messageIds(i), CompressionType.NONE, 0, records(i))) + + for (i <- 50 until messageIds.max) { + val idx = messageIds.indexWhere(_ >= i) + val reads = Seq( + LogTestUtils.readLog(log, i, 1), + LogTestUtils.readLog(log, i, 100000), + LogTestUtils.readLog(log, i, 100) + ).map(_.records.records.iterator.next()) + reads.foreach { read => + assertEquals(messageIds(idx), read.offset, "Offset read should match message id.") + assertEquals(records(idx), new SimpleRecord(read), "Message should match appended.") + } + } + } + + @Test + def testReadWithTooSmallMaxLength(): Unit = { + val logConfig = LogTestUtils.createLogConfig(segmentBytes = 72) + val log = createLog(logDir, logConfig) + val messageIds = ((0 until 50) ++ (50 until 200 by 7)).toArray + val records = messageIds.map(id => new SimpleRecord(id.toString.getBytes)) + + // now test the case that we give the offsets and use 
non-sequential offsets + for (i <- records.indices) + log.appendAsFollower(MemoryRecords.withRecords(messageIds(i), CompressionType.NONE, 0, records(i))) + + for (i <- 50 until messageIds.max) { + assertEquals(MemoryRecords.EMPTY, LogTestUtils.readLog(log, i, maxLength = 0, minOneMessage = false).records) + + // we return an incomplete message instead of an empty one for the case below + // we use this mechanism to tell consumers of the fetch request version 2 and below that the message size is + // larger than the fetch size + // in fetch request version 3, we no longer need this as we return oversized messages from the first non-empty + // partition + val fetchInfo = LogTestUtils.readLog(log, i, maxLength = 1, minOneMessage = false) + assertTrue(fetchInfo.firstEntryIncomplete) + assertTrue(fetchInfo.records.isInstanceOf[FileRecords]) + assertEquals(1, fetchInfo.records.sizeInBytes) + } + } + + /** + * Test reading at the boundary of the log, specifically + * - reading from the logEndOffset should give an empty message set + * - reading from the maxOffset should give an empty message set + * - reading beyond the log end offset should throw an OffsetOutOfRangeException + */ + @Test + def testReadOutOfRange(): Unit = { + createEmptyLogs(logDir, 1024) + // set up replica log starting with offset 1024 and with one message (at offset 1024) + val logConfig = LogTestUtils.createLogConfig(segmentBytes = 1024) + val log = createLog(logDir, logConfig) + log.appendAsLeader(TestUtils.singletonRecords(value = "42".getBytes), leaderEpoch = 0) + + assertEquals(0, LogTestUtils.readLog(log, 1025, 1000).records.sizeInBytes, + "Reading at the log end offset should produce 0 byte read.") + + assertThrows(classOf[OffsetOutOfRangeException], () => LogTestUtils.readLog(log, 0, 1000)) + assertThrows(classOf[OffsetOutOfRangeException], () => LogTestUtils.readLog(log, 1026, 1000)) + } + + /** + * Test that covers reads and writes on a multisegment log. This test appends a bunch of messages + * and then reads them all back and checks that the message read and offset matches what was appended. 
+ */ + @Test + def testLogRolls(): Unit = { + /* create a multipart log with 100 messages */ + val logConfig = LogTestUtils.createLogConfig(segmentBytes = 100) + val log = createLog(logDir, logConfig) + val numMessages = 100 + val messageSets = (0 until numMessages).map(i => TestUtils.singletonRecords(value = i.toString.getBytes, + timestamp = mockTime.milliseconds)) + messageSets.foreach(log.appendAsLeader(_, leaderEpoch = 0)) + log.flush() + + /* do successive reads to ensure all our messages are there */ + var offset = 0L + for(i <- 0 until numMessages) { + val messages = LogTestUtils.readLog(log, offset, 1024*1024).records.batches + val head = messages.iterator.next() + assertEquals(offset, head.lastOffset, "Offsets not equal") + + val expected = messageSets(i).records.iterator.next() + val actual = head.iterator.next() + assertEquals(expected.key, actual.key, s"Keys not equal at offset $offset") + assertEquals(expected.value, actual.value, s"Values not equal at offset $offset") + assertEquals(expected.timestamp, actual.timestamp, s"Timestamps not equal at offset $offset") + offset = head.lastOffset + 1 + } + val lastRead = LogTestUtils.readLog(log, startOffset = numMessages, maxLength = 1024*1024).records + assertEquals(0, lastRead.records.asScala.size, "Should be no more messages") + + // check that rolling the log forced a flush; the flush is async, so retry in case of failure + TestUtils.retry(1000L){ + assertTrue(log.recoveryPoint >= log.activeSegment.baseOffset, "Log roll should have forced a flush") + } + } + + /** + * Test reads at offsets that fall within compressed message set boundaries. + */ + @Test + def testCompressedMessages(): Unit = { + /* this log should roll after every messageset */ + val logConfig = LogTestUtils.createLogConfig(segmentBytes = 110) + val log = createLog(logDir, logConfig) + + /* append 2 compressed message sets, each with two messages giving offsets 0, 1, 2, 3 */ + log.appendAsLeader(MemoryRecords.withRecords(CompressionType.GZIP, new SimpleRecord("hello".getBytes), new SimpleRecord("there".getBytes)), leaderEpoch = 0) + log.appendAsLeader(MemoryRecords.withRecords(CompressionType.GZIP, new SimpleRecord("alpha".getBytes), new SimpleRecord("beta".getBytes)), leaderEpoch = 0) + + def read(offset: Int) = LogTestUtils.readLog(log, offset, 4096).records.records + + /* we should always get the first message in the compressed set when reading any offset in the set */ + assertEquals(0, read(0).iterator.next().offset, "Read at offset 0 should produce 0") + assertEquals(0, read(1).iterator.next().offset, "Read at offset 1 should produce 0") + assertEquals(2, read(2).iterator.next().offset, "Read at offset 2 should produce 2") + assertEquals(2, read(3).iterator.next().offset, "Read at offset 3 should produce 2") + } + + /** + * Test garbage collecting old segments + */ + @Test + def testThatGarbageCollectingSegmentsDoesntChangeOffset(): Unit = { + for(messagesToAppend <- List(0, 1, 25)) { + logDir.mkdirs() + // first test a log segment starting at 0 + val logConfig = LogTestUtils.createLogConfig(segmentBytes = 100, retentionMs = 0) + val log = createLog(logDir, logConfig) + for(i <- 0 until messagesToAppend) + log.appendAsLeader(TestUtils.singletonRecords(value = i.toString.getBytes, timestamp = mockTime.milliseconds - 10), leaderEpoch = 0) + + val currOffset = log.logEndOffset + assertEquals(currOffset, messagesToAppend) + + // time goes by; the log file is deleted + log.updateHighWatermark(currOffset) + log.deleteOldSegments() + + assertEquals(currOffset,
log.logEndOffset, "Deleting segments shouldn't have changed the logEndOffset") + assertEquals(1, log.numberOfSegments, "We should still have one segment left") + assertEquals(0, log.deleteOldSegments(), "Further collection shouldn't delete anything") + assertEquals(currOffset, log.logEndOffset, "Still no change in the logEndOffset") + assertEquals( + currOffset, + log.appendAsLeader( + TestUtils.singletonRecords(value = "hello".getBytes, timestamp = mockTime.milliseconds), + leaderEpoch = 0 + ).firstOffset.get.messageOffset, + "Should still be able to append and should get the logEndOffset assigned to the new append") + + // cleanup the log + log.delete() + } + } + + /** + * MessageSet size shouldn't exceed the config.segmentSize, check that it is properly enforced by + * appending a message set larger than the config.segmentSize setting and checking that an exception is thrown. + */ + @Test + def testMessageSetSizeCheck(): Unit = { + val messageSet = MemoryRecords.withRecords(CompressionType.NONE, new SimpleRecord("You".getBytes), new SimpleRecord("bethe".getBytes)) + // append messages to log + val configSegmentSize = messageSet.sizeInBytes - 1 + val logConfig = LogTestUtils.createLogConfig(segmentBytes = configSegmentSize) + val log = createLog(logDir, logConfig) + + assertThrows(classOf[RecordBatchTooLargeException], () => log.appendAsLeader(messageSet, leaderEpoch = 0)) + } + + @Test + def testCompactedTopicConstraints(): Unit = { + val keyedMessage = new SimpleRecord("and here it is".getBytes, "this message has a key".getBytes) + val anotherKeyedMessage = new SimpleRecord("another key".getBytes, "this message also has a key".getBytes) + val unkeyedMessage = new SimpleRecord("this message does not have a key".getBytes) + + val messageSetWithUnkeyedMessage = MemoryRecords.withRecords(CompressionType.NONE, unkeyedMessage, keyedMessage) + val messageSetWithOneUnkeyedMessage = MemoryRecords.withRecords(CompressionType.NONE, unkeyedMessage) + val messageSetWithCompressedKeyedMessage = MemoryRecords.withRecords(CompressionType.GZIP, keyedMessage) + val messageSetWithCompressedUnkeyedMessage = MemoryRecords.withRecords(CompressionType.GZIP, keyedMessage, unkeyedMessage) + + val messageSetWithKeyedMessage = MemoryRecords.withRecords(CompressionType.NONE, keyedMessage) + val messageSetWithKeyedMessages = MemoryRecords.withRecords(CompressionType.NONE, keyedMessage, anotherKeyedMessage) + + val logConfig = LogTestUtils.createLogConfig(cleanupPolicy = LogConfig.Compact) + val log = createLog(logDir, logConfig) + + val errorMsgPrefix = "Compacted topic cannot accept message without key" + + var e = assertThrows(classOf[RecordValidationException], + () => log.appendAsLeader(messageSetWithUnkeyedMessage, leaderEpoch = 0)) + assertTrue(e.invalidException.isInstanceOf[InvalidRecordException]) + assertEquals(1, e.recordErrors.size) + assertEquals(0, e.recordErrors.head.batchIndex) + assertTrue(e.recordErrors.head.message.startsWith(errorMsgPrefix)) + + e = assertThrows(classOf[RecordValidationException], + () => log.appendAsLeader(messageSetWithOneUnkeyedMessage, leaderEpoch = 0)) + assertTrue(e.invalidException.isInstanceOf[InvalidRecordException]) + assertEquals(1, e.recordErrors.size) + assertEquals(0, e.recordErrors.head.batchIndex) + assertTrue(e.recordErrors.head.message.startsWith(errorMsgPrefix)) + + e = assertThrows(classOf[RecordValidationException], + () => log.appendAsLeader(messageSetWithCompressedUnkeyedMessage, leaderEpoch = 0)) + 
assertTrue(e.invalidException.isInstanceOf[InvalidRecordException]) + assertEquals(1, e.recordErrors.size) + assertEquals(1, e.recordErrors.head.batchIndex) // batch index is 1 + assertTrue(e.recordErrors.head.message.startsWith(errorMsgPrefix)) + + // check that the metric for NoKeyCompactedTopicRecordsPerSec is recorded + assertEquals(metricsKeySet.count(_.getMBeanName.endsWith(s"${BrokerTopicStats.NoKeyCompactedTopicRecordsPerSec}")), 1) + assertTrue(TestUtils.meterCount(s"${BrokerTopicStats.NoKeyCompactedTopicRecordsPerSec}") > 0) + + // the following should succeed without any InvalidRecordException + log.appendAsLeader(messageSetWithKeyedMessage, leaderEpoch = 0) + log.appendAsLeader(messageSetWithKeyedMessages, leaderEpoch = 0) + log.appendAsLeader(messageSetWithCompressedKeyedMessage, leaderEpoch = 0) + } + + /** + * We have a max size limit on message appends; check that it is properly enforced by appending a message larger than the + * setting and checking that an exception is thrown. + */ + @Test + def testMessageSizeCheck(): Unit = { + val first = MemoryRecords.withRecords(CompressionType.NONE, new SimpleRecord("You".getBytes), new SimpleRecord("bethe".getBytes)) + val second = MemoryRecords.withRecords(CompressionType.NONE, + new SimpleRecord("change (I need more bytes)... blah blah blah.".getBytes), + new SimpleRecord("More padding boo hoo".getBytes)) + + // append messages to log + val maxMessageSize = second.sizeInBytes - 1 + val logConfig = LogTestUtils.createLogConfig(maxMessageBytes = maxMessageSize) + val log = createLog(logDir, logConfig) + + // should be able to append the small message + log.appendAsLeader(first, leaderEpoch = 0) + + assertThrows(classOf[RecordTooLargeException], () => log.appendAsLeader(second, leaderEpoch = 0), + () => "Second message set should throw RecordTooLargeException.") + } + + @Test + def testMessageSizeCheckInAppendAsFollower(): Unit = { + val first = MemoryRecords.withRecords(0, CompressionType.NONE, 0, + new SimpleRecord("You".getBytes), new SimpleRecord("bethe".getBytes)) + val second = MemoryRecords.withRecords(5, CompressionType.NONE, 0, + new SimpleRecord("change (I need more bytes)... blah blah blah.".getBytes), + new SimpleRecord("More padding boo hoo".getBytes)) + + val log = createLog(logDir, LogTestUtils.createLogConfig(maxMessageBytes = second.sizeInBytes - 1)) + + log.appendAsFollower(first) + // the second record is larger than the limit, but appendAsFollower does not validate the size. + log.appendAsFollower(second) + } + + @Test + def testLogFlushesPartitionMetadataOnAppend(): Unit = { + val logConfig = LogTestUtils.createLogConfig() + val log = createLog(logDir, logConfig) + val record = MemoryRecords.withRecords(CompressionType.NONE, new SimpleRecord("simpleValue".getBytes)) + + val topicId = Uuid.randomUuid() + log.partitionMetadataFile.record(topicId) + + // Should trigger a synchronous flush + log.appendAsLeader(record, leaderEpoch = 0) + assertTrue(log.partitionMetadataFile.exists()) + assertEquals(topicId, log.partitionMetadataFile.read().topicId) + } + + @Test + def testLogFlushesPartitionMetadataOnClose(): Unit = { + val logConfig = LogTestUtils.createLogConfig() + var log = createLog(logDir, logConfig) + + val topicId = Uuid.randomUuid() + log.partitionMetadataFile.record(topicId) + + // Should trigger a synchronous flush + log.close() + + // We open the log again, and the partition metadata file should exist with the same ID.
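+ // Reopening the log re-reads the partition metadata file from disk, so the topic ID recorded before close should still be present.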
+ log = createLog(logDir, logConfig) + assertTrue(log.partitionMetadataFile.exists()) + assertEquals(topicId, log.partitionMetadataFile.read().topicId) + } + + @Test + def testLogRecoversTopicId(): Unit = { + val logConfig = LogTestUtils.createLogConfig() + var log = createLog(logDir, logConfig) + + val topicId = Uuid.randomUuid() + log.assignTopicId(topicId) + log.close() + + // test recovery case + log = createLog(logDir, logConfig) + assertTrue(log.topicId.isDefined) + assertTrue(log.topicId.get == topicId) + log.close() + } + + @Test + def testNoOpWhenKeepPartitionMetadataFileIsFalse(): Unit = { + val logConfig = LogTestUtils.createLogConfig() + val log = createLog(logDir, logConfig, keepPartitionMetadataFile = false) + + val topicId = Uuid.randomUuid() + log.assignTopicId(topicId) + // We should not write to this file or set the topic ID + assertFalse(log.partitionMetadataFile.exists()) + assertEquals(None, log.topicId) + log.close() + + val log2 = createLog(logDir, logConfig, topicId = Some(Uuid.randomUuid()), keepPartitionMetadataFile = false) + + // We should not write to this file or set the topic ID + assertFalse(log2.partitionMetadataFile.exists()) + assertEquals(None, log2.topicId) + log2.close() + } + + @Test + def testLogFailsWhenInconsistentTopicIdSet(): Unit = { + val logConfig = LogTestUtils.createLogConfig() + var log = createLog(logDir, logConfig) + + val topicId = Uuid.randomUuid() + log.assignTopicId(topicId) + log.close() + + // test creating a log with a new ID + try { + log = createLog(logDir, logConfig, topicId = Some(Uuid.randomUuid())) + log.close() + } catch { + case e: Throwable => assertTrue(e.isInstanceOf[InconsistentTopicIdException]) + } + } + + /** + * Test building the time index on the follower by setting assignOffsets to false. 
+ */ + @Test + def testBuildTimeIndexWhenNotAssigningOffsets(): Unit = { + val numMessages = 100 + val logConfig = LogTestUtils.createLogConfig(segmentBytes = 10000, indexIntervalBytes = 1) + val log = createLog(logDir, logConfig) + + val messages = (0 until numMessages).map { i => + MemoryRecords.withRecords(100 + i, CompressionType.NONE, 0, new SimpleRecord(mockTime.milliseconds + i, i.toString.getBytes())) + } + messages.foreach(log.appendAsFollower) + val timeIndexEntries = log.logSegments.foldLeft(0) { (entries, segment) => entries + segment.timeIndex.entries } + assertEquals(numMessages - 1, timeIndexEntries, s"There should be ${numMessages - 1} time index entries") + assertEquals(mockTime.milliseconds + numMessages - 1, log.activeSegment.timeIndex.lastEntry.timestamp, + s"The last time index entry should have timestamp ${mockTime.milliseconds + numMessages - 1}") + } + + @Test + def testFetchOffsetByTimestampIncludesLeaderEpoch(): Unit = { + val logConfig = LogTestUtils.createLogConfig(segmentBytes = 200, indexIntervalBytes = 1) + val log = createLog(logDir, logConfig) + + assertEquals(None, log.fetchOffsetByTimestamp(0L)) + + val firstTimestamp = mockTime.milliseconds + val firstLeaderEpoch = 0 + log.appendAsLeader(TestUtils.singletonRecords( + value = TestUtils.randomBytes(10), + timestamp = firstTimestamp), + leaderEpoch = firstLeaderEpoch) + + val secondTimestamp = firstTimestamp + 1 + val secondLeaderEpoch = 1 + log.appendAsLeader(TestUtils.singletonRecords( + value = TestUtils.randomBytes(10), + timestamp = secondTimestamp), + leaderEpoch = secondLeaderEpoch) + + assertEquals(Some(new TimestampAndOffset(firstTimestamp, 0L, Optional.of(firstLeaderEpoch))), + log.fetchOffsetByTimestamp(firstTimestamp)) + assertEquals(Some(new TimestampAndOffset(secondTimestamp, 1L, Optional.of(secondLeaderEpoch))), + log.fetchOffsetByTimestamp(secondTimestamp)) + + assertEquals(Some(new TimestampAndOffset(ListOffsetsResponse.UNKNOWN_TIMESTAMP, 0L, Optional.of(firstLeaderEpoch))), + log.fetchOffsetByTimestamp(ListOffsetsRequest.EARLIEST_TIMESTAMP)) + assertEquals(Some(new TimestampAndOffset(ListOffsetsResponse.UNKNOWN_TIMESTAMP, 2L, Optional.of(secondLeaderEpoch))), + log.fetchOffsetByTimestamp(ListOffsetsRequest.LATEST_TIMESTAMP)) + + // The cache can be updated directly after a leader change. + // The new latest offset should reflect the updated epoch. 
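+ // Assign epoch 2 starting at the current log end offset (2); the LATEST_TIMESTAMP lookup below should now report this epoch.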
+ log.maybeAssignEpochStartOffset(2, 2L) + + assertEquals(Some(new TimestampAndOffset(ListOffsetsResponse.UNKNOWN_TIMESTAMP, 2L, Optional.of(2))), + log.fetchOffsetByTimestamp(ListOffsetsRequest.LATEST_TIMESTAMP)) + } + + @Test + def testFetchOffsetByTimestampWithMaxTimestampIncludesTimestamp(): Unit = { + val logConfig = LogTestUtils.createLogConfig(segmentBytes = 200, indexIntervalBytes = 1) + val log = createLog(logDir, logConfig) + + assertEquals(None, log.fetchOffsetByTimestamp(0L)) + + val firstTimestamp = mockTime.milliseconds + val leaderEpoch = 0 + log.appendAsLeader(TestUtils.singletonRecords( + value = TestUtils.randomBytes(10), + timestamp = firstTimestamp), + leaderEpoch = leaderEpoch) + + val secondTimestamp = firstTimestamp + 1 + log.appendAsLeader(TestUtils.singletonRecords( + value = TestUtils.randomBytes(10), + timestamp = secondTimestamp), + leaderEpoch = leaderEpoch) + + log.appendAsLeader(TestUtils.singletonRecords( + value = TestUtils.randomBytes(10), + timestamp = firstTimestamp), + leaderEpoch = leaderEpoch) + + assertEquals(Some(new TimestampAndOffset(secondTimestamp, 1L, Optional.of(leaderEpoch))), + log.fetchOffsetByTimestamp(ListOffsetsRequest.MAX_TIMESTAMP)) + } + + /** + * Test the Log truncate operations + */ + @Test + def testTruncateTo(): Unit = { + def createRecords = TestUtils.singletonRecords(value = "test".getBytes, timestamp = mockTime.milliseconds) + val setSize = createRecords.sizeInBytes + val msgPerSeg = 10 + val segmentSize = msgPerSeg * setSize // each segment will be 10 messages + + // create a log + val logConfig = LogTestUtils.createLogConfig(segmentBytes = segmentSize) + val log = createLog(logDir, logConfig) + assertEquals(1, log.numberOfSegments, "There should be exactly 1 segment.") + + for (_ <- 1 to msgPerSeg) + log.appendAsLeader(createRecords, leaderEpoch = 0) + + assertEquals(1, log.numberOfSegments, "There should be exactly 1 segments.") + assertEquals(msgPerSeg, log.logEndOffset, "Log end offset should be equal to number of messages") + + val lastOffset = log.logEndOffset + val size = log.size + log.truncateTo(log.logEndOffset) // keep the entire log + assertEquals(lastOffset, log.logEndOffset, "Should not change offset") + assertEquals(size, log.size, "Should not change log size") + log.truncateTo(log.logEndOffset + 1) // try to truncate beyond lastOffset + assertEquals(lastOffset, log.logEndOffset, "Should not change offset but should log error") + assertEquals(size, log.size, "Should not change log size") + log.truncateTo(msgPerSeg/2) // truncate somewhere in between + assertEquals(log.logEndOffset, msgPerSeg/2, "Should change offset") + assertTrue(log.size < size, "Should change log size") + log.truncateTo(0) // truncate the entire log + assertEquals(0, log.logEndOffset, "Should change offset") + assertEquals(0, log.size, "Should change log size") + + for (_ <- 1 to msgPerSeg) + log.appendAsLeader(createRecords, leaderEpoch = 0) + + assertEquals(log.logEndOffset, lastOffset, "Should be back to original offset") + assertEquals(log.size, size, "Should be back to original size") + log.truncateFullyAndStartAt(log.logEndOffset - (msgPerSeg - 1)) + assertEquals(log.logEndOffset, lastOffset - (msgPerSeg - 1), "Should change offset") + assertEquals(log.size, 0, "Should change log size") + + for (_ <- 1 to msgPerSeg) + log.appendAsLeader(createRecords, leaderEpoch = 0) + + assertTrue(log.logEndOffset > msgPerSeg, "Should be ahead of to original offset") + assertEquals(size, log.size, "log size should be same as before") + 
log.truncateTo(0) // truncate before first start offset in the log + assertEquals(0, log.logEndOffset, "Should change offset") + assertEquals(log.size, 0, "Should change log size") + } + + /** + * Verify that when we truncate a log the index of the last segment is resized to the max index size to allow more appends + */ + @Test + def testIndexResizingAtTruncation(): Unit = { + val setSize = TestUtils.singletonRecords(value = "test".getBytes, timestamp = mockTime.milliseconds).sizeInBytes + val msgPerSeg = 10 + val segmentSize = msgPerSeg * setSize // each segment will be 10 messages + val logConfig = LogTestUtils.createLogConfig(segmentBytes = segmentSize, indexIntervalBytes = setSize - 1) + val log = createLog(logDir, logConfig) + assertEquals(1, log.numberOfSegments, "There should be exactly 1 segment.") + + for (i<- 1 to msgPerSeg) + log.appendAsLeader(TestUtils.singletonRecords(value = "test".getBytes, timestamp = mockTime.milliseconds + i), leaderEpoch = 0) + assertEquals(1, log.numberOfSegments, "There should be exactly 1 segment.") + + mockTime.sleep(msgPerSeg) + for (i<- 1 to msgPerSeg) + log.appendAsLeader(TestUtils.singletonRecords(value = "test".getBytes, timestamp = mockTime.milliseconds + i), leaderEpoch = 0) + assertEquals(2, log.numberOfSegments, "There should be exactly 2 segment.") + val expectedEntries = msgPerSeg - 1 + + assertEquals(expectedEntries, log.logSegments.toList.head.offsetIndex.maxEntries, + s"The index of the first segment should have $expectedEntries entries") + assertEquals(expectedEntries, log.logSegments.toList.head.timeIndex.maxEntries, + s"The time index of the first segment should have $expectedEntries entries") + + log.truncateTo(0) + assertEquals(1, log.numberOfSegments, "There should be exactly 1 segment.") + assertEquals(log.config.maxIndexSize/8, log.logSegments.toList.head.offsetIndex.maxEntries, + "The index of segment 1 should be resized to maxIndexSize") + assertEquals(log.config.maxIndexSize/12, log.logSegments.toList.head.timeIndex.maxEntries, + "The time index of segment 1 should be resized to maxIndexSize") + + mockTime.sleep(msgPerSeg) + for (i<- 1 to msgPerSeg) + log.appendAsLeader(TestUtils.singletonRecords(value = "test".getBytes, timestamp = mockTime.milliseconds + i), leaderEpoch = 0) + assertEquals(1, log.numberOfSegments, + "There should be exactly 1 segment.") + } + + /** + * Test that deleted files are deleted after the appropriate time. 
+ */ + @Test + def testAsyncDelete(): Unit = { + def createRecords = TestUtils.singletonRecords(value = "test".getBytes, timestamp = mockTime.milliseconds - 1000L) + val asyncDeleteMs = 1000 + val logConfig = LogTestUtils.createLogConfig(segmentBytes = createRecords.sizeInBytes * 5, segmentIndexBytes = 1000, indexIntervalBytes = 10000, + retentionMs = 999, fileDeleteDelayMs = asyncDeleteMs) + val log = createLog(logDir, logConfig) + + // append some messages to create some segments + for (_ <- 0 until 100) + log.appendAsLeader(createRecords, leaderEpoch = 0) + + // files should be renamed + val segments = log.logSegments.toArray + val oldFiles = segments.map(_.log.file) ++ segments.map(_.lazyOffsetIndex.file) + + log.updateHighWatermark(log.logEndOffset) + log.deleteOldSegments() + + assertEquals(1, log.numberOfSegments, "Only one segment should remain.") + assertTrue(segments.forall(_.log.file.getName.endsWith(UnifiedLog.DeletedFileSuffix)) && + segments.forall(_.lazyOffsetIndex.file.getName.endsWith(UnifiedLog.DeletedFileSuffix)), + "All log and index files should end in .deleted") + assertTrue(segments.forall(_.log.file.exists) && segments.forall(_.lazyOffsetIndex.file.exists), + "The .deleted files should still be there.") + assertTrue(oldFiles.forall(!_.exists), "The original file should be gone.") + + // when enough time passes the files should be deleted + val deletedFiles = segments.map(_.log.file) ++ segments.map(_.lazyOffsetIndex.file) + mockTime.sleep(asyncDeleteMs + 1) + assertTrue(deletedFiles.forall(!_.exists), "Files should all be gone.") + } + + @Test + def testAppendMessageWithNullPayload(): Unit = { + val log = createLog(logDir, LogConfig()) + log.appendAsLeader(TestUtils.singletonRecords(value = null), leaderEpoch = 0) + val head = LogTestUtils.readLog(log, 0, 4096).records.records.iterator.next() + assertEquals(0, head.offset) + assertFalse(head.hasValue, "Message payload should be null.") + } + + @Test + def testAppendWithOutOfOrderOffsetsThrowsException(): Unit = { + val log = createLog(logDir, LogConfig()) + + val appendOffsets = Seq(0L, 1L, 3L, 2L, 4L) + val buffer = ByteBuffer.allocate(512) + for (offset <- appendOffsets) { + val builder = MemoryRecords.builder(buffer, RecordBatch.MAGIC_VALUE_V2, CompressionType.NONE, + TimestampType.LOG_APPEND_TIME, offset, mockTime.milliseconds(), + 1L, 0, 0, false, 0) + builder.append(new SimpleRecord("key".getBytes, "value".getBytes)) + builder.close() + } + buffer.flip() + val memoryRecords = MemoryRecords.readableRecords(buffer) + + assertThrows(classOf[OffsetsOutOfOrderException], () => + log.appendAsFollower(memoryRecords) + ) + } + + @Test + def testAppendBelowExpectedOffsetThrowsException(): Unit = { + val log = createLog(logDir, LogConfig()) + val records = (0 until 2).map(id => new SimpleRecord(id.toString.getBytes)).toArray + records.foreach(record => log.appendAsLeader(MemoryRecords.withRecords(CompressionType.NONE, record), leaderEpoch = 0)) + + val magicVals = Seq(RecordBatch.MAGIC_VALUE_V0, RecordBatch.MAGIC_VALUE_V1, RecordBatch.MAGIC_VALUE_V2) + val compressionTypes = Seq(CompressionType.NONE, CompressionType.LZ4) + for (magic <- magicVals; compression <- compressionTypes) { + val invalidRecord = MemoryRecords.withRecords(magic, compression, new SimpleRecord(1.toString.getBytes)) + assertThrows(classOf[UnexpectedAppendOffsetException], + () => log.appendAsFollower(invalidRecord), + () => s"Magic=$magic, compressionType=$compression") + } + } + + @Test + def testAppendEmptyLogBelowLogStartOffsetThrowsException(): 
Unit = { + createEmptyLogs(logDir, 7) + val log = createLog(logDir, LogConfig(), brokerTopicStats = brokerTopicStats) + assertEquals(7L, log.logStartOffset) + assertEquals(7L, log.logEndOffset) + + val firstOffset = 4L + val magicVals = Seq(RecordBatch.MAGIC_VALUE_V0, RecordBatch.MAGIC_VALUE_V1, RecordBatch.MAGIC_VALUE_V2) + val compressionTypes = Seq(CompressionType.NONE, CompressionType.LZ4) + for (magic <- magicVals; compression <- compressionTypes) { + val batch = TestUtils.records(List(new SimpleRecord("k1".getBytes, "v1".getBytes), + new SimpleRecord("k2".getBytes, "v2".getBytes), + new SimpleRecord("k3".getBytes, "v3".getBytes)), + magicValue = magic, codec = compression, + baseOffset = firstOffset) + + val exception = assertThrows(classOf[UnexpectedAppendOffsetException], () => log.appendAsFollower(records = batch)) + assertEquals(firstOffset, exception.firstOffset, s"Magic=$magic, compressionType=$compression, UnexpectedAppendOffsetException#firstOffset") + assertEquals(firstOffset + 2, exception.lastOffset, s"Magic=$magic, compressionType=$compression, UnexpectedAppendOffsetException#lastOffset") + } + } + + @Test + def testAppendWithNoTimestamp(): Unit = { + val log = createLog(logDir, LogConfig()) + log.appendAsLeader(MemoryRecords.withRecords(CompressionType.NONE, + new SimpleRecord(RecordBatch.NO_TIMESTAMP, "key".getBytes, "value".getBytes)), leaderEpoch = 0) + } + + @Test + def testAppendToOrReadFromLogInFailedLogDir(): Unit = { + val pid = 1L + val epoch = 0.toShort + val log = createLog(logDir, LogConfig()) + log.appendAsLeader(TestUtils.singletonRecords(value = null), leaderEpoch = 0) + assertEquals(0, LogTestUtils.readLog(log, 0, 4096).records.records.iterator.next().offset) + val append = LogTestUtils.appendTransactionalAsLeader(log, pid, epoch, mockTime) + append(10) + // Kind of a hack, but renaming the index to a directory ensures that the append + // to the index will fail. 
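+ // Once that append fails, the log directory is treated as failed, so the subsequent appends and reads below should all throw KafkaStorageException.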
+ log.activeSegment.txnIndex.renameTo(log.dir) + assertThrows(classOf[KafkaStorageException], () => LogTestUtils.appendEndTxnMarkerAsLeader(log, pid, epoch, ControlRecordType.ABORT, mockTime.milliseconds(), coordinatorEpoch = 1)) + assertThrows(classOf[KafkaStorageException], () => log.appendAsLeader(TestUtils.singletonRecords(value = null), leaderEpoch = 0)) + assertThrows(classOf[KafkaStorageException], () => LogTestUtils.readLog(log, 0, 4096).records.records.iterator.next().offset) + } + + @Test + def testWriteLeaderEpochCheckpointAfterDirectoryRename(): Unit = { + val logConfig = LogTestUtils.createLogConfig(segmentBytes = 1000, indexIntervalBytes = 1, maxMessageBytes = 64 * 1024) + val log = createLog(logDir, logConfig) + log.appendAsLeader(TestUtils.records(List(new SimpleRecord("foo".getBytes()))), leaderEpoch = 5) + assertEquals(Some(5), log.latestEpoch) + + // Ensure that after a directory rename, the epoch cache is written to the right location + val tp = UnifiedLog.parseTopicPartitionName(log.dir) + log.renameDir(UnifiedLog.logDeleteDirName(tp)) + log.appendAsLeader(TestUtils.records(List(new SimpleRecord("foo".getBytes()))), leaderEpoch = 10) + assertEquals(Some(10), log.latestEpoch) + assertTrue(LeaderEpochCheckpointFile.newFile(log.dir).exists()) + assertFalse(LeaderEpochCheckpointFile.newFile(this.logDir).exists()) + } + + @Test + def testTopicIdTransfersAfterDirectoryRename(): Unit = { + val logConfig = LogTestUtils.createLogConfig(segmentBytes = 1000, indexIntervalBytes = 1, maxMessageBytes = 64 * 1024) + val log = createLog(logDir, logConfig) + + // Write a topic ID to the partition metadata file to ensure it is transferred correctly. + val topicId = Uuid.randomUuid() + log.assignTopicId(topicId) + + log.appendAsLeader(TestUtils.records(List(new SimpleRecord("foo".getBytes()))), leaderEpoch = 5) + assertEquals(Some(5), log.latestEpoch) + + // Ensure that after a directory rename, the partition metadata file is written to the right location. + val tp = UnifiedLog.parseTopicPartitionName(log.dir) + log.renameDir(UnifiedLog.logDeleteDirName(tp)) + log.appendAsLeader(TestUtils.records(List(new SimpleRecord("foo".getBytes()))), leaderEpoch = 10) + assertEquals(Some(10), log.latestEpoch) + assertTrue(PartitionMetadataFile.newFile(log.dir).exists()) + assertFalse(PartitionMetadataFile.newFile(this.logDir).exists()) + + // Check the topic ID remains in memory and was copied correctly. + assertTrue(log.topicId.isDefined) + assertEquals(topicId, log.topicId.get) + assertEquals(topicId, log.partitionMetadataFile.read().topicId) + } + + @Test + def testTopicIdFlushesBeforeDirectoryRename(): Unit = { + val logConfig = LogTestUtils.createLogConfig(segmentBytes = 1000, indexIntervalBytes = 1, maxMessageBytes = 64 * 1024) + val log = createLog(logDir, logConfig) + + // Write a topic ID to the partition metadata file to ensure it is transferred correctly. + val topicId = Uuid.randomUuid() + log.partitionMetadataFile.record(topicId) + + // Ensure that after a directory rename, the partition metadata file is written to the right location. + val tp = UnifiedLog.parseTopicPartitionName(log.dir) + log.renameDir(UnifiedLog.logDeleteDirName(tp)) + assertTrue(PartitionMetadataFile.newFile(log.dir).exists()) + assertFalse(PartitionMetadataFile.newFile(this.logDir).exists()) + + // Check the file holds the correct contents. 
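+ // The topic ID was flushed before the rename, so it should still be readable from the renamed directory.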
+ assertTrue(log.partitionMetadataFile.exists()) + assertEquals(topicId, log.partitionMetadataFile.read().topicId) + } + + @Test + def testLeaderEpochCacheClearedAfterDowngradeInAppendedMessages(): Unit = { + val logConfig = LogTestUtils.createLogConfig(segmentBytes = 1000, indexIntervalBytes = 1, maxMessageBytes = 64 * 1024) + val log = createLog(logDir, logConfig) + log.appendAsLeader(TestUtils.records(List(new SimpleRecord("foo".getBytes()))), leaderEpoch = 5) + assertEquals(Some(5), log.leaderEpochCache.flatMap(_.latestEpoch)) + + log.appendAsFollower(TestUtils.records(List(new SimpleRecord("foo".getBytes())), + baseOffset = 1L, + magicValue = RecordVersion.V1.value)) + assertEquals(None, log.leaderEpochCache.flatMap(_.latestEpoch)) + } + + @nowarn("cat=deprecation") + @Test + def testLeaderEpochCacheClearedAfterDynamicMessageFormatDowngrade(): Unit = { + val logConfig = LogTestUtils.createLogConfig(segmentBytes = 1000, indexIntervalBytes = 1, maxMessageBytes = 64 * 1024) + val log = createLog(logDir, logConfig) + log.appendAsLeader(TestUtils.records(List(new SimpleRecord("foo".getBytes()))), leaderEpoch = 5) + assertEquals(Some(5), log.latestEpoch) + + val logProps = new Properties() + logProps.put(LogConfig.SegmentBytesProp, "1000") + logProps.put(LogConfig.IndexIntervalBytesProp, "1") + logProps.put(LogConfig.MaxMessageBytesProp, "65536") + logProps.put(LogConfig.MessageFormatVersionProp, "0.10.2") + val downgradedLogConfig = LogConfig(logProps) + log.updateConfig(downgradedLogConfig) + LogTestUtils.assertLeaderEpochCacheEmpty(log) + + log.appendAsLeader(TestUtils.records(List(new SimpleRecord("bar".getBytes())), + magicValue = RecordVersion.V1.value), leaderEpoch = 5) + LogTestUtils.assertLeaderEpochCacheEmpty(log) + } + + @nowarn("cat=deprecation") + @Test + def testLeaderEpochCacheCreatedAfterMessageFormatUpgrade(): Unit = { + val logProps = new Properties() + logProps.put(LogConfig.SegmentBytesProp, "1000") + logProps.put(LogConfig.IndexIntervalBytesProp, "1") + logProps.put(LogConfig.MaxMessageBytesProp, "65536") + logProps.put(LogConfig.MessageFormatVersionProp, "0.10.2") + val logConfig = LogConfig(logProps) + val log = createLog(logDir, logConfig) + log.appendAsLeader(TestUtils.records(List(new SimpleRecord("bar".getBytes())), + magicValue = RecordVersion.V1.value), leaderEpoch = 5) + LogTestUtils.assertLeaderEpochCacheEmpty(log) + + logProps.put(LogConfig.MessageFormatVersionProp, "0.11.0") + val upgradedLogConfig = LogConfig(logProps) + log.updateConfig(upgradedLogConfig) + log.appendAsLeader(TestUtils.records(List(new SimpleRecord("foo".getBytes()))), leaderEpoch = 5) + assertEquals(Some(5), log.latestEpoch) + } + + @Test + def testSplitOnOffsetOverflow(): Unit = { + // create a log such that one log segment has offsets that overflow, and call the split API on that segment + val logConfig = LogTestUtils.createLogConfig(indexIntervalBytes = 1, fileDeleteDelayMs = 1000) + val (log, segmentWithOverflow) = createLogWithOffsetOverflow(logConfig) + assertTrue(LogTestUtils.hasOffsetOverflow(log), "At least one segment must have offset overflow") + + val allRecordsBeforeSplit = UnifiedLogTest.allRecords(log) + + // split the segment with overflow + log.splitOverflowedSegment(segmentWithOverflow) + + // assert we were successfully able to split the segment + assertEquals(4, log.numberOfSegments) + UnifiedLogTest.verifyRecordsInLog(log, allRecordsBeforeSplit) + + // verify we do not have offset overflow anymore + assertFalse(LogTestUtils.hasOffsetOverflow(log)) + } + + @Test + def 
testDegenerateSegmentSplit(): Unit = { + // This tests a scenario where all of the batches appended to a segment have overflowed. + // When we split the overflowed segment, only one new segment will be created. + + val overflowOffset = Int.MaxValue + 1L + val batch1 = MemoryRecords.withRecords(overflowOffset, CompressionType.NONE, 0, + new SimpleRecord("a".getBytes)) + val batch2 = MemoryRecords.withRecords(overflowOffset + 1, CompressionType.NONE, 0, + new SimpleRecord("b".getBytes)) + + testDegenerateSplitSegmentWithOverflow(segmentBaseOffset = 0L, List(batch1, batch2)) + } + + @Test + def testDegenerateSegmentSplitWithOutOfRangeBatchLastOffset(): Unit = { + // Degenerate case where the only batch in the segment overflows. In this scenario, + // the first offset of the batch is valid, but the last overflows. + + val firstBatchBaseOffset = Int.MaxValue - 1 + val records = MemoryRecords.withRecords(firstBatchBaseOffset, CompressionType.NONE, 0, + new SimpleRecord("a".getBytes), + new SimpleRecord("b".getBytes), + new SimpleRecord("c".getBytes)) + + testDegenerateSplitSegmentWithOverflow(segmentBaseOffset = 0L, List(records)) + } + + private def testDegenerateSplitSegmentWithOverflow(segmentBaseOffset: Long, records: List[MemoryRecords]): Unit = { + val segment = LogTestUtils.rawSegment(logDir, segmentBaseOffset) + // Need to create the offset files explicitly to avoid triggering segment recovery to truncate segment. + UnifiedLog.offsetIndexFile(logDir, segmentBaseOffset).createNewFile() + UnifiedLog.timeIndexFile(logDir, segmentBaseOffset).createNewFile() + records.foreach(segment.append _) + segment.close() + + val logConfig = LogTestUtils.createLogConfig(indexIntervalBytes = 1, fileDeleteDelayMs = 1000) + val log = createLog(logDir, logConfig, recoveryPoint = Long.MaxValue) + + val segmentWithOverflow = LogTestUtils.firstOverflowSegment(log).getOrElse { + throw new AssertionError("Failed to create log with a segment which has overflowed offsets") + } + + val allRecordsBeforeSplit = UnifiedLogTest.allRecords(log) + log.splitOverflowedSegment(segmentWithOverflow) + + assertEquals(1, log.numberOfSegments) + + val firstBatchBaseOffset = records.head.batches.asScala.head.baseOffset + assertEquals(firstBatchBaseOffset, log.activeSegment.baseOffset) + UnifiedLogTest.verifyRecordsInLog(log, allRecordsBeforeSplit) + + assertFalse(LogTestUtils.hasOffsetOverflow(log)) + } + + @Test + def testDeleteOldSegments(): Unit = { + def createRecords = TestUtils.singletonRecords(value = "test".getBytes, timestamp = mockTime.milliseconds - 1000) + val logConfig = LogTestUtils.createLogConfig(segmentBytes = createRecords.sizeInBytes * 5, segmentIndexBytes = 1000, retentionMs = 999) + val log = createLog(logDir, logConfig) + + // append some messages to create some segments + for (_ <- 0 until 100) + log.appendAsLeader(createRecords, leaderEpoch = 0) + + log.maybeAssignEpochStartOffset(0, 40) + log.maybeAssignEpochStartOffset(1, 90) + + // segments are not eligible for deletion if no high watermark has been set + val numSegments = log.numberOfSegments + log.deleteOldSegments() + assertEquals(numSegments, log.numberOfSegments) + assertEquals(0L, log.logStartOffset) + + // only segments with offset before the current high watermark are eligible for deletion + for (hw <- 25 to 30) { + log.updateHighWatermark(hw) + log.deleteOldSegments() + assertTrue(log.logStartOffset <= hw) + log.logSegments.foreach { segment => + val segmentFetchInfo = segment.read(startOffset = segment.baseOffset, maxSize = Int.MaxValue) + val 
segmentLastOffsetOpt = segmentFetchInfo.records.records.asScala.lastOption.map(_.offset) + segmentLastOffsetOpt.foreach { lastOffset => + assertTrue(lastOffset >= hw) + } + } + } + + // expire all segments + log.updateHighWatermark(log.logEndOffset) + log.deleteOldSegments() + assertEquals(1, log.numberOfSegments, "The deleted segments should be gone.") + assertEquals(1, epochCache(log).epochEntries.size, "Epoch entries should have gone.") + assertEquals(EpochEntry(1, 100), epochCache(log).epochEntries.head, "Epoch entry should be the latest epoch and the LEO.") + + // append some messages to create some segments + for (_ <- 0 until 100) + log.appendAsLeader(createRecords, leaderEpoch = 0) + + log.delete() + assertEquals(0, log.numberOfSegments, "The number of segments should be 0") + assertEquals(0, log.deleteOldSegments(), "The number of deleted segments should be zero.") + assertEquals(0, epochCache(log).epochEntries.size, "Epoch entries should have gone.") + } + + @Test + def testLogDeletionAfterClose(): Unit = { + def createRecords = TestUtils.singletonRecords(value = "test".getBytes, timestamp = mockTime.milliseconds - 1000) + val logConfig = LogTestUtils.createLogConfig(segmentBytes = createRecords.sizeInBytes * 5, segmentIndexBytes = 1000, retentionMs = 999) + val log = createLog(logDir, logConfig) + + // append some messages to create some segments + log.appendAsLeader(createRecords, leaderEpoch = 0) + + assertEquals(1, log.numberOfSegments, "There should be 1 segment.") + assertEquals(1, epochCache(log).epochEntries.size, "There should be 1 epoch entry.") + + log.close() + log.delete() + assertEquals(0, log.numberOfSegments, "The number of segments should be 0") + assertEquals(0, epochCache(log).epochEntries.size, "Epoch entries should have gone.") + } + + @Test + def testLogDeletionAfterDeleteRecords(): Unit = { + def createRecords = TestUtils.singletonRecords("test".getBytes) + val logConfig = LogTestUtils.createLogConfig(segmentBytes = createRecords.sizeInBytes * 5) + val log = createLog(logDir, logConfig) + + for (_ <- 0 until 15) + log.appendAsLeader(createRecords, leaderEpoch = 0) + assertEquals(3, log.numberOfSegments, "should have 3 segments") + assertEquals(0, log.logStartOffset) + log.updateHighWatermark(log.logEndOffset) + + log.maybeIncrementLogStartOffset(1, ClientRecordDeletion) + log.deleteOldSegments() + assertEquals(3, log.numberOfSegments, "should have 3 segments") + assertEquals(1, log.logStartOffset) + + log.maybeIncrementLogStartOffset(6, ClientRecordDeletion) + log.deleteOldSegments() + assertEquals(2, log.numberOfSegments, "should have 2 segments") + assertEquals(6, log.logStartOffset) + + log.maybeIncrementLogStartOffset(15, ClientRecordDeletion) + log.deleteOldSegments() + assertEquals(1, log.numberOfSegments, "should have 1 segment") + assertEquals(15, log.logStartOffset) + } + + def epochCache(log: UnifiedLog): LeaderEpochFileCache = { + log.leaderEpochCache.get + } + + @Test + def shouldDeleteSizeBasedSegments(): Unit = { + def createRecords = TestUtils.singletonRecords("test".getBytes) + val logConfig = LogTestUtils.createLogConfig(segmentBytes = createRecords.sizeInBytes * 5, retentionBytes = createRecords.sizeInBytes * 10) + val log = createLog(logDir, logConfig) + + // append some messages to create some segments + for (_ <- 0 until 15) + log.appendAsLeader(createRecords, leaderEpoch = 0) + + log.updateHighWatermark(log.logEndOffset) + log.deleteOldSegments() + assertEquals(2, log.numberOfSegments, "should have 2 segments") + } + + @Test 
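+ // Appends 15 records, which roll into three segments of five records each; the total log size stays within retention.bytes, so size-based retention is expected to delete nothing.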
+ def shouldNotDeleteSizeBasedSegmentsWhenUnderRetentionSize(): Unit = { + def createRecords = TestUtils.singletonRecords("test".getBytes) + val logConfig = LogTestUtils.createLogConfig(segmentBytes = createRecords.sizeInBytes * 5, retentionBytes = createRecords.sizeInBytes * 15) + val log = createLog(logDir, logConfig) + + // append some messages to create some segments + for (_ <- 0 until 15) + log.appendAsLeader(createRecords, leaderEpoch = 0) + + log.updateHighWatermark(log.logEndOffset) + log.deleteOldSegments() + assertEquals(3, log.numberOfSegments, "should have 3 segments") + } + + @Test + def shouldDeleteTimeBasedSegmentsReadyToBeDeleted(): Unit = { + def createRecords = TestUtils.singletonRecords("test".getBytes, timestamp = 10) + val logConfig = LogTestUtils.createLogConfig(segmentBytes = createRecords.sizeInBytes * 5, retentionMs = 10000) + val log = createLog(logDir, logConfig) + + // append some messages to create some segments + for (_ <- 0 until 15) + log.appendAsLeader(createRecords, leaderEpoch = 0) + + log.updateHighWatermark(log.logEndOffset) + log.deleteOldSegments() + assertEquals(1, log.numberOfSegments, "There should be 1 segment remaining") + } + + @Test + def shouldNotDeleteTimeBasedSegmentsWhenNoneReadyToBeDeleted(): Unit = { + def createRecords = TestUtils.singletonRecords("test".getBytes, timestamp = mockTime.milliseconds) + val logConfig = LogTestUtils.createLogConfig(segmentBytes = createRecords.sizeInBytes * 5, retentionMs = 10000000) + val log = createLog(logDir, logConfig) + + // append some messages to create some segments + for (_ <- 0 until 15) + log.appendAsLeader(createRecords, leaderEpoch = 0) + + log.updateHighWatermark(log.logEndOffset) + log.deleteOldSegments() + assertEquals(3, log.numberOfSegments, "There should be 3 segments remaining") + } + + @Test + def shouldNotDeleteSegmentsWhenPolicyDoesNotIncludeDelete(): Unit = { + def createRecords = TestUtils.singletonRecords("test".getBytes, key = "test".getBytes(), timestamp = 10L) + val logConfig = LogTestUtils.createLogConfig(segmentBytes = createRecords.sizeInBytes * 5, retentionMs = 10000, cleanupPolicy = "compact") + val log = createLog(logDir, logConfig) + + // append some messages to create some segments + for (_ <- 0 until 15) + log.appendAsLeader(createRecords, leaderEpoch = 0) + + // mark the oldest segment as older than retention.ms + log.logSegments.head.lastModified = mockTime.milliseconds - 20000 + + val segments = log.numberOfSegments + log.updateHighWatermark(log.logEndOffset) + log.deleteOldSegments() + assertEquals(segments, log.numberOfSegments, "The number of segments should not have changed") + } + + @Test + def shouldDeleteSegmentsReadyToBeDeletedWhenCleanupPolicyIsCompactAndDelete(): Unit = { + def createRecords = TestUtils.singletonRecords("test".getBytes, key = "test".getBytes, timestamp = 10L) + val logConfig = LogTestUtils.createLogConfig(segmentBytes = createRecords.sizeInBytes * 5, retentionMs = 10000, cleanupPolicy = "compact,delete") + val log = createLog(logDir, logConfig) + + // append some messages to create some segments + for (_ <- 0 until 15) + log.appendAsLeader(createRecords, leaderEpoch = 0) + + log.updateHighWatermark(log.logEndOffset) + log.deleteOldSegments() + assertEquals(1, log.numberOfSegments, "There should be 1 segment remaining") + } + + @Test + def shouldDeleteStartOffsetBreachedSegmentsWhenPolicyDoesNotIncludeDelete(): Unit = { + def createRecords = TestUtils.singletonRecords("test".getBytes, key = "test".getBytes, timestamp = 10L) + val recordsPerSegment = 5 + 
val logConfig = LogTestUtils.createLogConfig(segmentBytes = createRecords.sizeInBytes * recordsPerSegment, retentionMs = 10000, cleanupPolicy = "compact") + val log = createLog(logDir, logConfig, brokerTopicStats) + + // append some messages to create some segments + for (_ <- 0 until 15) + log.appendAsLeader(createRecords, leaderEpoch = 0) + + // Three segments should be created + assertEquals(3, log.logSegments.count(_ => true)) + log.updateHighWatermark(log.logEndOffset) + log.maybeIncrementLogStartOffset(recordsPerSegment, ClientRecordDeletion) + + // The first segment, which is entirely before the log start offset, should be deleted + // Of the remaining the segments, the first can overlap the log start offset and the rest must have a base offset + // greater than the start offset + log.updateHighWatermark(log.logEndOffset) + log.deleteOldSegments() + assertEquals(2, log.numberOfSegments, "There should be 2 segments remaining") + assertTrue(log.logSegments.head.baseOffset <= log.logStartOffset) + assertTrue(log.logSegments.tail.forall(s => s.baseOffset > log.logStartOffset)) + } + + @Test + def shouldApplyEpochToMessageOnAppendIfLeader(): Unit = { + val records = (0 until 50).toArray.map(id => new SimpleRecord(id.toString.getBytes)) + + //Given this partition is on leader epoch 72 + val epoch = 72 + val log = createLog(logDir, LogConfig()) + log.maybeAssignEpochStartOffset(epoch, records.length) + + //When appending messages as a leader (i.e. assignOffsets = true) + for (record <- records) + log.appendAsLeader( + MemoryRecords.withRecords(CompressionType.NONE, record), + leaderEpoch = epoch + ) + + //Then leader epoch should be set on messages + for (i <- records.indices) { + val read = LogTestUtils.readLog(log, i, 1).records.batches.iterator.next() + assertEquals(72, read.partitionLeaderEpoch, "Should have set leader epoch") + } + } + + @Test + def followerShouldSaveEpochInformationFromReplicatedMessagesToTheEpochCache(): Unit = { + val messageIds = (0 until 50).toArray + val records = messageIds.map(id => new SimpleRecord(id.toString.getBytes)) + + //Given each message has an offset & epoch, as msgs from leader would + def recordsForEpoch(i: Int): MemoryRecords = { + val recs = MemoryRecords.withRecords(messageIds(i), CompressionType.NONE, records(i)) + recs.batches.forEach{record => + record.setPartitionLeaderEpoch(42) + record.setLastOffset(i) + } + recs + } + + val log = createLog(logDir, LogConfig()) + + //When appending as follower (assignOffsets = false) + for (i <- records.indices) + log.appendAsFollower(recordsForEpoch(i)) + + assertEquals(Some(42), log.latestEpoch) + } + + @Test + def shouldTruncateLeaderEpochsWhenDeletingSegments(): Unit = { + def createRecords = TestUtils.singletonRecords("test".getBytes) + val logConfig = LogTestUtils.createLogConfig(segmentBytes = createRecords.sizeInBytes * 5, retentionBytes = createRecords.sizeInBytes * 10) + val log = createLog(logDir, logConfig) + val cache = epochCache(log) + + // Given three segments of 5 messages each + for (_ <- 0 until 15) { + log.appendAsLeader(createRecords, leaderEpoch = 0) + } + + //Given epochs + cache.assign(0, 0) + cache.assign(1, 5) + cache.assign(2, 10) + + //When first segment is removed + log.updateHighWatermark(log.logEndOffset) + log.deleteOldSegments() + + //The oldest epoch entry should have been removed + assertEquals(ListBuffer(EpochEntry(1, 5), EpochEntry(2, 10)), cache.epochEntries) + } + + @Test + def shouldUpdateOffsetForLeaderEpochsWhenDeletingSegments(): Unit = { + def createRecords = 
TestUtils.singletonRecords("test".getBytes) + val logConfig = LogTestUtils.createLogConfig(segmentBytes = createRecords.sizeInBytes * 5, retentionBytes = createRecords.sizeInBytes * 10) + val log = createLog(logDir, logConfig) + val cache = epochCache(log) + + // Given three segments of 5 messages each + for (_ <- 0 until 15) { + log.appendAsLeader(createRecords, leaderEpoch = 0) + } + + //Given epochs + cache.assign(0, 0) + cache.assign(1, 7) + cache.assign(2, 10) + + //When first segment removed (up to offset 5) + log.updateHighWatermark(log.logEndOffset) + log.deleteOldSegments() + + //The first entry should have gone from (0,0) => (0,5) + assertEquals(ListBuffer(EpochEntry(0, 5), EpochEntry(1, 7), EpochEntry(2, 10)), cache.epochEntries) + } + + @Test + def shouldTruncateLeaderEpochCheckpointFileWhenTruncatingLog(): Unit = { + def createRecords(startOffset: Long, epoch: Int): MemoryRecords = { + TestUtils.records(Seq(new SimpleRecord("value".getBytes)), + baseOffset = startOffset, partitionLeaderEpoch = epoch) + } + + val logConfig = LogTestUtils.createLogConfig(segmentBytes = 10 * createRecords(0, 0).sizeInBytes) + val log = createLog(logDir, logConfig) + val cache = epochCache(log) + + def append(epoch: Int, startOffset: Long, count: Int): Unit = { + for (i <- 0 until count) + log.appendAsFollower(createRecords(startOffset + i, epoch)) + } + + //Given 2 segments, 10 messages per segment + append(epoch = 0, startOffset = 0, count = 10) + append(epoch = 1, startOffset = 10, count = 6) + append(epoch = 2, startOffset = 16, count = 4) + + assertEquals(2, log.numberOfSegments) + assertEquals(20, log.logEndOffset) + + //When truncate to LEO (no op) + log.truncateTo(log.logEndOffset) + + //Then no change + assertEquals(3, cache.epochEntries.size) + + //When truncate + log.truncateTo(11) + + //Then no change + assertEquals(2, cache.epochEntries.size) + + //When truncate + log.truncateTo(10) + + //Then + assertEquals(1, cache.epochEntries.size) + + //When truncate all + log.truncateTo(0) + + //Then + assertEquals(0, cache.epochEntries.size) + } + + @Test + def testFirstUnstableOffsetNoTransactionalData(): Unit = { + val logConfig = LogTestUtils.createLogConfig(segmentBytes = 1024 * 1024 * 5) + val log = createLog(logDir, logConfig) + + val records = MemoryRecords.withRecords(CompressionType.NONE, + new SimpleRecord("foo".getBytes), + new SimpleRecord("bar".getBytes), + new SimpleRecord("baz".getBytes)) + + log.appendAsLeader(records, leaderEpoch = 0) + assertEquals(None, log.firstUnstableOffset) + } + + @Test + def testFirstUnstableOffsetWithTransactionalData(): Unit = { + val logConfig = LogTestUtils.createLogConfig(segmentBytes = 1024 * 1024 * 5) + val log = createLog(logDir, logConfig) + + val pid = 137L + val epoch = 5.toShort + var seq = 0 + + // add some transactional records + val records = MemoryRecords.withTransactionalRecords(CompressionType.NONE, pid, epoch, seq, + new SimpleRecord("foo".getBytes), + new SimpleRecord("bar".getBytes), + new SimpleRecord("baz".getBytes)) + + val firstAppendInfo = log.appendAsLeader(records, leaderEpoch = 0) + assertEquals(firstAppendInfo.firstOffset.map(_.messageOffset), log.firstUnstableOffset) + + // add more transactional records + seq += 3 + log.appendAsLeader(MemoryRecords.withTransactionalRecords(CompressionType.NONE, pid, epoch, seq, + new SimpleRecord("blah".getBytes)), leaderEpoch = 0) + + // LSO should not have changed + assertEquals(firstAppendInfo.firstOffset.map(_.messageOffset), log.firstUnstableOffset) + + // now transaction is 
committed + val commitAppendInfo = LogTestUtils.appendEndTxnMarkerAsLeader(log, pid, epoch, ControlRecordType.COMMIT, mockTime.milliseconds()) + + // first unstable offset is not updated until the high watermark is advanced + assertEquals(firstAppendInfo.firstOffset.map(_.messageOffset), log.firstUnstableOffset) + log.updateHighWatermark(commitAppendInfo.lastOffset + 1) + + // now there should be no first unstable offset + assertEquals(None, log.firstUnstableOffset) + } + + @Test + def testReadCommittedWithConcurrentHighWatermarkUpdates(): Unit = { + val logConfig = LogTestUtils.createLogConfig(segmentBytes = 1024 * 1024 * 5) + val log = createLog(logDir, logConfig) + val lastOffset = 50L + + val producerEpoch = 0.toShort + val producerId = 15L + val appendProducer = LogTestUtils.appendTransactionalAsLeader(log, producerId, producerEpoch, mockTime) + + // Thread 1 writes single-record transactions and attempts to read them + // before they have been aborted, and then aborts them + val txnWriteAndReadLoop: Callable[Int] = () => { + var nonEmptyReads = 0 + while (log.logEndOffset < lastOffset) { + val currentLogEndOffset = log.logEndOffset + + appendProducer(1) + + val readInfo = log.read( + startOffset = currentLogEndOffset, + maxLength = Int.MaxValue, + isolation = FetchTxnCommitted, + minOneMessage = false) + + if (readInfo.records.sizeInBytes() > 0) + nonEmptyReads += 1 + + LogTestUtils.appendEndTxnMarkerAsLeader(log, producerId, producerEpoch, ControlRecordType.ABORT, mockTime.milliseconds()) + } + nonEmptyReads + } + + // Thread 2 watches the log and updates the high watermark + val hwUpdateLoop: Runnable = () => { + while (log.logEndOffset < lastOffset) { + log.updateHighWatermark(log.logEndOffset) + } + } + + val executor = Executors.newFixedThreadPool(2) + try { + executor.submit(hwUpdateLoop) + + val future = executor.submit(txnWriteAndReadLoop) + val nonEmptyReads = future.get() + + assertEquals(0, nonEmptyReads) + } finally { + executor.shutdownNow() + } + } + + @Test + def testTransactionIndexUpdated(): Unit = { + val logConfig = LogTestUtils.createLogConfig(segmentBytes = 1024 * 1024 * 5) + val log = createLog(logDir, logConfig) + val epoch = 0.toShort + + val pid1 = 1L + val pid2 = 2L + val pid3 = 3L + val pid4 = 4L + + val appendPid1 = LogTestUtils.appendTransactionalAsLeader(log, pid1, epoch, mockTime) + val appendPid2 = LogTestUtils.appendTransactionalAsLeader(log, pid2, epoch, mockTime) + val appendPid3 = LogTestUtils.appendTransactionalAsLeader(log, pid3, epoch, mockTime) + val appendPid4 = LogTestUtils.appendTransactionalAsLeader(log, pid4, epoch, mockTime) + + // mix transactional and non-transactional data + appendPid1(5) // nextOffset: 5 + LogTestUtils.appendNonTransactionalAsLeader(log, 3) // 8 + appendPid2(2) // 10 + appendPid1(4) // 14 + appendPid3(3) // 17 + LogTestUtils.appendNonTransactionalAsLeader(log, 2) // 19 + appendPid1(10) // 29 + LogTestUtils.appendEndTxnMarkerAsLeader(log, pid1, epoch, ControlRecordType.ABORT, mockTime.milliseconds()) // 30 + appendPid2(6) // 36 + appendPid4(3) // 39 + LogTestUtils.appendNonTransactionalAsLeader(log, 10) // 49 + appendPid3(9) // 58 + LogTestUtils.appendEndTxnMarkerAsLeader(log, pid3, epoch, ControlRecordType.COMMIT, mockTime.milliseconds()) // 59 + appendPid4(8) // 67 + appendPid2(7) // 74 + LogTestUtils.appendEndTxnMarkerAsLeader(log, pid2, epoch, ControlRecordType.ABORT, mockTime.milliseconds()) // 75 + LogTestUtils.appendNonTransactionalAsLeader(log, 10) // 85 + appendPid4(4) // 89 + 
LogTestUtils.appendEndTxnMarkerAsLeader(log, pid4, epoch, ControlRecordType.COMMIT, mockTime.milliseconds()) // 90 + + val abortedTransactions = LogTestUtils.allAbortedTransactions(log) + val expectedTransactions = List( + new AbortedTxn(pid1, 0L, 29L, 8L), + new AbortedTxn(pid2, 8L, 74L, 36L) + ) + assertEquals(expectedTransactions, abortedTransactions) + + // Verify caching of the segment position of the first unstable offset + log.updateHighWatermark(30L) + assertCachedFirstUnstableOffset(log, expectedOffset = 8L) + + log.updateHighWatermark(75L) + assertCachedFirstUnstableOffset(log, expectedOffset = 36L) + + log.updateHighWatermark(log.logEndOffset) + assertEquals(None, log.firstUnstableOffset) + } + + @Test + def testTransactionIndexUpdatedThroughReplication(): Unit = { + val epoch = 0.toShort + val logConfig = LogTestUtils.createLogConfig(segmentBytes = 1024 * 1024 * 5) + val log = createLog(logDir, logConfig) + val buffer = ByteBuffer.allocate(2048) + + val pid1 = 1L + val pid2 = 2L + val pid3 = 3L + val pid4 = 4L + + val appendPid1 = appendTransactionalToBuffer(buffer, pid1, epoch) + val appendPid2 = appendTransactionalToBuffer(buffer, pid2, epoch) + val appendPid3 = appendTransactionalToBuffer(buffer, pid3, epoch) + val appendPid4 = appendTransactionalToBuffer(buffer, pid4, epoch) + + appendPid1(0L, 5) + appendNonTransactionalToBuffer(buffer, 5L, 3) + appendPid2(8L, 2) + appendPid1(10L, 4) + appendPid3(14L, 3) + appendNonTransactionalToBuffer(buffer, 17L, 2) + appendPid1(19L, 10) + appendEndTxnMarkerToBuffer(buffer, pid1, epoch, 29L, ControlRecordType.ABORT) + appendPid2(30L, 6) + appendPid4(36L, 3) + appendNonTransactionalToBuffer(buffer, 39L, 10) + appendPid3(49L, 9) + appendEndTxnMarkerToBuffer(buffer, pid3, epoch, 58L, ControlRecordType.COMMIT) + appendPid4(59L, 8) + appendPid2(67L, 7) + appendEndTxnMarkerToBuffer(buffer, pid2, epoch, 74L, ControlRecordType.ABORT) + appendNonTransactionalToBuffer(buffer, 75L, 10) + appendPid4(85L, 4) + appendEndTxnMarkerToBuffer(buffer, pid4, epoch, 89L, ControlRecordType.COMMIT) + + buffer.flip() + + appendAsFollower(log, MemoryRecords.readableRecords(buffer)) + + val abortedTransactions = LogTestUtils.allAbortedTransactions(log) + val expectedTransactions = List( + new AbortedTxn(pid1, 0L, 29L, 8L), + new AbortedTxn(pid2, 8L, 74L, 36L) + ) + + assertEquals(expectedTransactions, abortedTransactions) + + // Verify caching of the segment position of the first unstable offset + log.updateHighWatermark(30L) + assertCachedFirstUnstableOffset(log, expectedOffset = 8L) + + log.updateHighWatermark(75L) + assertCachedFirstUnstableOffset(log, expectedOffset = 36L) + + log.updateHighWatermark(log.logEndOffset) + assertEquals(None, log.firstUnstableOffset) + } + + private def assertCachedFirstUnstableOffset(log: UnifiedLog, expectedOffset: Long): Unit = { + assertTrue(log.producerStateManager.firstUnstableOffset.isDefined) + val firstUnstableOffset = log.producerStateManager.firstUnstableOffset.get + assertEquals(expectedOffset, firstUnstableOffset.messageOffset) + assertFalse(firstUnstableOffset.messageOffsetOnly) + assertValidLogOffsetMetadata(log, firstUnstableOffset) + } + + private def assertValidLogOffsetMetadata(log: UnifiedLog, offsetMetadata: LogOffsetMetadata): Unit = { + assertFalse(offsetMetadata.messageOffsetOnly) + + val segmentBaseOffset = offsetMetadata.segmentBaseOffset + val segmentOpt = log.logSegments(segmentBaseOffset, segmentBaseOffset + 1).headOption + assertTrue(segmentOpt.isDefined) + + val segment = segmentOpt.get + 
assertEquals(segmentBaseOffset, segment.baseOffset) + assertTrue(offsetMetadata.relativePositionInSegment <= segment.size) + + val readInfo = segment.read(offsetMetadata.messageOffset, + maxSize = 2048, + maxPosition = segment.size, + minOneMessage = false) + + if (offsetMetadata.relativePositionInSegment < segment.size) + assertEquals(offsetMetadata, readInfo.fetchOffsetMetadata) + else + assertNull(readInfo) + } + + @Test + def testZombieCoordinatorFenced(): Unit = { + val pid = 1L + val epoch = 0.toShort + val logConfig = LogTestUtils.createLogConfig(segmentBytes = 1024 * 1024 * 5) + val log = createLog(logDir, logConfig) + + val append = LogTestUtils.appendTransactionalAsLeader(log, pid, epoch, mockTime) + + append(10) + LogTestUtils.appendEndTxnMarkerAsLeader(log, pid, epoch, ControlRecordType.ABORT, mockTime.milliseconds(), coordinatorEpoch = 1) + + append(5) + LogTestUtils.appendEndTxnMarkerAsLeader(log, pid, epoch, ControlRecordType.COMMIT, mockTime.milliseconds(), coordinatorEpoch = 2) + + assertThrows( + classOf[TransactionCoordinatorFencedException], + () => LogTestUtils.appendEndTxnMarkerAsLeader(log, pid, epoch, ControlRecordType.ABORT, mockTime.milliseconds(), coordinatorEpoch = 1)) + } + + @Test + def testZombieCoordinatorFencedEmptyTransaction(): Unit = { + val pid = 1L + val epoch = 0.toShort + val logConfig = LogTestUtils.createLogConfig(segmentBytes = 1024 * 1024 * 5) + val log = createLog(logDir, logConfig) + + val buffer = ByteBuffer.allocate(256) + val append = appendTransactionalToBuffer(buffer, pid, epoch, leaderEpoch = 1) + append(0, 10) + appendEndTxnMarkerToBuffer(buffer, pid, epoch, 10L, ControlRecordType.COMMIT, leaderEpoch = 1) + + buffer.flip() + log.appendAsFollower(MemoryRecords.readableRecords(buffer)) + + LogTestUtils.appendEndTxnMarkerAsLeader(log, pid, epoch, ControlRecordType.ABORT, mockTime.milliseconds(), coordinatorEpoch = 2, leaderEpoch = 1) + LogTestUtils.appendEndTxnMarkerAsLeader(log, pid, epoch, ControlRecordType.ABORT, mockTime.milliseconds(), coordinatorEpoch = 2, leaderEpoch = 1) + assertThrows(classOf[TransactionCoordinatorFencedException], + () => LogTestUtils.appendEndTxnMarkerAsLeader(log, pid, epoch, ControlRecordType.ABORT, mockTime.milliseconds(), coordinatorEpoch = 1, leaderEpoch = 1)) + } + + @Test + def testEndTxnWithFencedProducerEpoch(): Unit = { + val producerId = 1L + val epoch = 5.toShort + val logConfig = LogTestUtils.createLogConfig(segmentBytes = 1024 * 1024 * 5) + val log = createLog(logDir, logConfig) + LogTestUtils.appendEndTxnMarkerAsLeader(log, producerId, epoch, ControlRecordType.ABORT, mockTime.milliseconds(), coordinatorEpoch = 1) + + assertThrows(classOf[InvalidProducerEpochException], + () => LogTestUtils.appendEndTxnMarkerAsLeader(log, producerId, (epoch - 1).toShort, ControlRecordType.ABORT, mockTime.milliseconds(), coordinatorEpoch = 1)) + } + + @Test + def testLastStableOffsetDoesNotExceedLogStartOffsetMidSegment(): Unit = { + val logConfig = LogTestUtils.createLogConfig(segmentBytes = 1024 * 1024 * 5) + val log = createLog(logDir, logConfig) + val epoch = 0.toShort + val pid = 1L + val appendPid = LogTestUtils.appendTransactionalAsLeader(log, pid, epoch, mockTime) + + appendPid(5) + LogTestUtils.appendNonTransactionalAsLeader(log, 3) + assertEquals(8L, log.logEndOffset) + + log.roll() + assertEquals(2, log.logSegments.size) + appendPid(5) + + assertEquals(Some(0L), log.firstUnstableOffset) + + log.updateHighWatermark(log.logEndOffset) + log.maybeIncrementLogStartOffset(5L, ClientRecordDeletion) + + // the 
first unstable offset should be lower bounded by the log start offset + assertEquals(Some(5L), log.firstUnstableOffset) + } + + @Test + def testLastStableOffsetDoesNotExceedLogStartOffsetAfterSegmentDeletion(): Unit = { + val logConfig = LogTestUtils.createLogConfig(segmentBytes = 1024 * 1024 * 5) + val log = createLog(logDir, logConfig) + val epoch = 0.toShort + val pid = 1L + val appendPid = LogTestUtils.appendTransactionalAsLeader(log, pid, epoch, mockTime) + + appendPid(5) + LogTestUtils.appendNonTransactionalAsLeader(log, 3) + assertEquals(8L, log.logEndOffset) + + log.roll() + assertEquals(2, log.logSegments.size) + appendPid(5) + + assertEquals(Some(0L), log.firstUnstableOffset) + + log.updateHighWatermark(log.logEndOffset) + log.maybeIncrementLogStartOffset(8L, ClientRecordDeletion) + log.updateHighWatermark(log.logEndOffset) + log.deleteOldSegments() + assertEquals(1, log.logSegments.size) + + // the first unstable offset should be lower bounded by the log start offset + assertEquals(Some(8L), log.firstUnstableOffset) + } + + @Test + def testAppendToTransactionIndexFailure(): Unit = { + val pid = 1L + val epoch = 0.toShort + val logConfig = LogTestUtils.createLogConfig(segmentBytes = 1024 * 1024 * 5) + val log = createLog(logDir, logConfig) + + val append = LogTestUtils.appendTransactionalAsLeader(log, pid, epoch, mockTime) + append(10) + + // Kind of a hack, but renaming the index to a directory ensures that the append + // to the index will fail. + log.activeSegment.txnIndex.renameTo(log.dir) + + // The append will be written to the log successfully, but the write to the index will fail + assertThrows( + classOf[KafkaStorageException], + () => LogTestUtils.appendEndTxnMarkerAsLeader(log, pid, epoch, ControlRecordType.ABORT, mockTime.milliseconds(), coordinatorEpoch = 1)) + assertEquals(11L, log.logEndOffset) + assertEquals(0L, log.lastStableOffset) + + // Try the append a second time. The appended offset in the log should not increase + // because the log dir is marked as failed. Nor will there be a write to the transaction + // index. 
+ assertThrows( + classOf[KafkaStorageException], + () => LogTestUtils.appendEndTxnMarkerAsLeader(log, pid, epoch, ControlRecordType.ABORT, mockTime.milliseconds(), coordinatorEpoch = 1)) + assertEquals(11L, log.logEndOffset) + assertEquals(0L, log.lastStableOffset) + + // Even if the high watermark is updated, the first unstable offset does not move + log.updateHighWatermark(12L) + assertEquals(0L, log.lastStableOffset) + + assertThrows(classOf[KafkaStorageException], () => log.close()) + val reopenedLog = createLog(logDir, logConfig, lastShutdownClean = false) + assertEquals(11L, reopenedLog.logEndOffset) + assertEquals(1, reopenedLog.activeSegment.txnIndex.allAbortedTxns.size) + reopenedLog.updateHighWatermark(12L) + assertEquals(None, reopenedLog.firstUnstableOffset) + } + + @Test + def testOffsetSnapshot(): Unit = { + val logConfig = LogTestUtils.createLogConfig(segmentBytes = 1024 * 1024 * 5) + val log = createLog(logDir, logConfig) + + // append a few records + appendAsFollower(log, MemoryRecords.withRecords(CompressionType.NONE, + new SimpleRecord("a".getBytes), + new SimpleRecord("b".getBytes), + new SimpleRecord("c".getBytes)), 5) + + + log.updateHighWatermark(2L) + var offsets: LogOffsetSnapshot = log.fetchOffsetSnapshot + assertEquals(offsets.highWatermark.messageOffset, 2L) + assertFalse(offsets.highWatermark.messageOffsetOnly) + + offsets = log.fetchOffsetSnapshot + assertEquals(offsets.highWatermark.messageOffset, 2L) + assertFalse(offsets.highWatermark.messageOffsetOnly) + } + + @Test + def testLastStableOffsetWithMixedProducerData(): Unit = { + val logConfig = LogTestUtils.createLogConfig(segmentBytes = 1024 * 1024 * 5) + val log = createLog(logDir, logConfig) + + // for convenience, both producers share the same epoch + val epoch = 5.toShort + + val pid1 = 137L + val seq1 = 0 + val pid2 = 983L + val seq2 = 0 + + // add some transactional records + val firstAppendInfo = log.appendAsLeader(MemoryRecords.withTransactionalRecords(CompressionType.NONE, pid1, epoch, seq1, + new SimpleRecord("a".getBytes), + new SimpleRecord("b".getBytes), + new SimpleRecord("c".getBytes)), leaderEpoch = 0) + assertEquals(firstAppendInfo.firstOffset.map(_.messageOffset), log.firstUnstableOffset) + + // mix in some non-transactional data + log.appendAsLeader(MemoryRecords.withRecords(CompressionType.NONE, + new SimpleRecord("g".getBytes), + new SimpleRecord("h".getBytes), + new SimpleRecord("i".getBytes)), leaderEpoch = 0) + + // append data from a second transactional producer + val secondAppendInfo = log.appendAsLeader(MemoryRecords.withTransactionalRecords(CompressionType.NONE, pid2, epoch, seq2, + new SimpleRecord("d".getBytes), + new SimpleRecord("e".getBytes), + new SimpleRecord("f".getBytes)), leaderEpoch = 0) + + // LSO should not have changed + assertEquals(firstAppendInfo.firstOffset.map(_.messageOffset), log.firstUnstableOffset) + + // now first producer's transaction is aborted + val abortAppendInfo = LogTestUtils.appendEndTxnMarkerAsLeader(log, pid1, epoch, ControlRecordType.ABORT, mockTime.milliseconds()) + log.updateHighWatermark(abortAppendInfo.lastOffset + 1) + + // LSO should now point to one less than the first offset of the second transaction + assertEquals(secondAppendInfo.firstOffset.map(_.messageOffset), log.firstUnstableOffset) + + // commit the second transaction + val commitAppendInfo = LogTestUtils.appendEndTxnMarkerAsLeader(log, pid2, epoch, ControlRecordType.COMMIT, mockTime.milliseconds()) + log.updateHighWatermark(commitAppendInfo.lastOffset + 1) + + // now there 
should be no first unstable offset + assertEquals(None, log.firstUnstableOffset) + } + + @Test + def testAbortedTransactionSpanningMultipleSegments(): Unit = { + val pid = 137L + val epoch = 5.toShort + var seq = 0 + + val records = MemoryRecords.withTransactionalRecords(CompressionType.NONE, pid, epoch, seq, + new SimpleRecord("a".getBytes), + new SimpleRecord("b".getBytes), + new SimpleRecord("c".getBytes)) + + val logConfig = LogTestUtils.createLogConfig(segmentBytes = records.sizeInBytes) + val log = createLog(logDir, logConfig) + + val firstAppendInfo = log.appendAsLeader(records, leaderEpoch = 0) + assertEquals(firstAppendInfo.firstOffset.map(_.messageOffset), log.firstUnstableOffset) + + // this write should spill to the second segment + seq = 3 + log.appendAsLeader(MemoryRecords.withTransactionalRecords(CompressionType.NONE, pid, epoch, seq, + new SimpleRecord("d".getBytes), + new SimpleRecord("e".getBytes), + new SimpleRecord("f".getBytes)), leaderEpoch = 0) + assertEquals(firstAppendInfo.firstOffset.map(_.messageOffset), log.firstUnstableOffset) + assertEquals(3L, log.logEndOffsetMetadata.segmentBaseOffset) + + // now abort the transaction + val abortAppendInfo = LogTestUtils.appendEndTxnMarkerAsLeader(log, pid, epoch, ControlRecordType.ABORT, mockTime.milliseconds()) + log.updateHighWatermark(abortAppendInfo.lastOffset + 1) + assertEquals(None, log.firstUnstableOffset) + + // now check that a fetch includes the aborted transaction + val fetchDataInfo = log.read(0L, + maxLength = 2048, + isolation = FetchTxnCommitted, + minOneMessage = true) + assertEquals(1, fetchDataInfo.abortedTransactions.size) + + assertTrue(fetchDataInfo.abortedTransactions.isDefined) + assertEquals(new FetchResponseData.AbortedTransaction().setProducerId(pid).setFirstOffset(0), + fetchDataInfo.abortedTransactions.get.head) + } + + @Test + def testLoadPartitionDirWithNoSegmentsShouldNotThrow(): Unit = { + val dirName = UnifiedLog.logDeleteDirName(new TopicPartition("foo", 3)) + val logDir = new File(tmpDir, dirName) + logDir.mkdirs() + val logConfig = LogTestUtils.createLogConfig() + val log = createLog(logDir, logConfig) + assertEquals(1, log.numberOfSegments) + } + + private def appendTransactionalToBuffer(buffer: ByteBuffer, + producerId: Long, + producerEpoch: Short, + leaderEpoch: Int = 0): (Long, Int) => Unit = { + var sequence = 0 + (offset: Long, numRecords: Int) => { + val builder = MemoryRecords.builder(buffer, RecordBatch.CURRENT_MAGIC_VALUE, CompressionType.NONE, TimestampType.CREATE_TIME, + offset, mockTime.milliseconds(), producerId, producerEpoch, sequence, true, leaderEpoch) + for (seq <- sequence until sequence + numRecords) { + val record = new SimpleRecord(s"$seq".getBytes) + builder.append(record) + } + + sequence += numRecords + builder.close() + } + } + + private def appendEndTxnMarkerToBuffer(buffer: ByteBuffer, + producerId: Long, + producerEpoch: Short, + offset: Long, + controlType: ControlRecordType, + coordinatorEpoch: Int = 0, + leaderEpoch: Int = 0): Unit = { + val marker = new EndTransactionMarker(controlType, coordinatorEpoch) + MemoryRecords.writeEndTransactionalMarker(buffer, offset, mockTime.milliseconds(), leaderEpoch, producerId, producerEpoch, marker) + } + + private def appendNonTransactionalToBuffer(buffer: ByteBuffer, offset: Long, numRecords: Int): Unit = { + val builder = MemoryRecords.builder(buffer, CompressionType.NONE, TimestampType.CREATE_TIME, offset) + (0 until numRecords).foreach { seq => + builder.append(new SimpleRecord(s"$seq".getBytes)) + } + 
builder.close() + } + + private def appendAsFollower(log: UnifiedLog, records: MemoryRecords, leaderEpoch: Int = 0): Unit = { + records.batches.forEach(_.setPartitionLeaderEpoch(leaderEpoch)) + log.appendAsFollower(records) + } + + private def createLog(dir: File, + config: LogConfig, + brokerTopicStats: BrokerTopicStats = brokerTopicStats, + logStartOffset: Long = 0L, + recoveryPoint: Long = 0L, + scheduler: Scheduler = mockTime.scheduler, + time: Time = mockTime, + maxProducerIdExpirationMs: Int = 60 * 60 * 1000, + producerIdExpirationCheckIntervalMs: Int = LogManager.ProducerIdExpirationCheckIntervalMs, + lastShutdownClean: Boolean = true, + topicId: Option[Uuid] = None, + keepPartitionMetadataFile: Boolean = true): UnifiedLog = { + LogTestUtils.createLog(dir, config, brokerTopicStats, scheduler, time, logStartOffset, recoveryPoint, + maxProducerIdExpirationMs, producerIdExpirationCheckIntervalMs, lastShutdownClean, topicId = topicId, keepPartitionMetadataFile = keepPartitionMetadataFile) + } + + private def createLogWithOffsetOverflow(logConfig: LogConfig): (UnifiedLog, LogSegment) = { + LogTestUtils.initializeLogDirWithOverflowedSegment(logDir) + + val log = createLog(logDir, logConfig, recoveryPoint = Long.MaxValue) + val segmentWithOverflow = LogTestUtils.firstOverflowSegment(log).getOrElse { + throw new AssertionError("Failed to create log with a segment which has overflowed offsets") + } + + (log, segmentWithOverflow) + } +} + +object UnifiedLogTest { + def allRecords(log: UnifiedLog): List[Record] = { + val recordsFound = ListBuffer[Record]() + for (logSegment <- log.logSegments) { + for (batch <- logSegment.log.batches.asScala) { + recordsFound ++= batch.iterator().asScala + } + } + recordsFound.toList + } + + def verifyRecordsInLog(log: UnifiedLog, expectedRecords: List[Record]): Unit = { + assertEquals(expectedRecords, allRecords(log)) + } +} diff --git a/core/src/test/scala/unit/kafka/metrics/KafkaTimerTest.scala b/core/src/test/scala/unit/kafka/metrics/KafkaTimerTest.scala new file mode 100644 index 0000000..826c7f7 --- /dev/null +++ b/core/src/test/scala/unit/kafka/metrics/KafkaTimerTest.scala @@ -0,0 +1,59 @@ +package kafka.metrics + +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +import org.junit.jupiter.api.Test +import java.util.concurrent.TimeUnit +import org.junit.jupiter.api.Assertions._ +import com.yammer.metrics.core.{MetricsRegistry, Clock} + +class KafkaTimerTest { + + @Test + def testKafkaTimer(): Unit = { + val clock = new ManualClock + val testRegistry = new MetricsRegistry(clock) + val metric = testRegistry.newTimer(this.getClass, "TestTimer") + val Epsilon = java.lang.Double.longBitsToDouble(0x3ca0000000000000L) + + val timer = new KafkaTimer(metric) + timer.time { + clock.addMillis(1000) + } + assertEquals(1, metric.count()) + assertTrue((metric.max() - 1000).abs <= Epsilon) + assertTrue((metric.min() - 1000).abs <= Epsilon) + } + + private class ManualClock extends Clock { + + private var ticksInNanos = 0L + + override def tick() = { + ticksInNanos + } + + override def time() = { + TimeUnit.NANOSECONDS.toMillis(ticksInNanos) + } + + def addMillis(millis: Long): Unit = { + ticksInNanos += TimeUnit.MILLISECONDS.toNanos(millis) + } + } +} diff --git a/core/src/test/scala/unit/kafka/metrics/MetricsTest.scala b/core/src/test/scala/unit/kafka/metrics/MetricsTest.scala new file mode 100644 index 0000000..f4e69f9 --- /dev/null +++ b/core/src/test/scala/unit/kafka/metrics/MetricsTest.scala @@ -0,0 +1,262 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.metrics + +import java.lang.management.ManagementFactory +import java.util.Properties + +import javax.management.ObjectName +import com.yammer.metrics.core.MetricPredicate +import org.junit.jupiter.api.Assertions._ +import kafka.integration.KafkaServerTestHarness +import kafka.server._ +import kafka.utils._ + +import scala.collection._ +import scala.jdk.CollectionConverters._ +import kafka.log.LogConfig +import org.apache.kafka.common.TopicPartition +import org.apache.kafka.common.metrics.JmxReporter +import org.apache.kafka.common.utils.Time +import org.junit.jupiter.api.Timeout; +import org.junit.jupiter.params.ParameterizedTest +import org.junit.jupiter.params.provider.ValueSource + +@Timeout(120) +class MetricsTest extends KafkaServerTestHarness with Logging { + val numNodes = 2 + val numParts = 2 + + val requiredKafkaServerPrefix = "kafka.server:type=KafkaServer,name" + val overridingProps = new Properties + overridingProps.put(KafkaConfig.NumPartitionsProp, numParts.toString) + overridingProps.put(JmxReporter.EXCLUDE_CONFIG, s"$requiredKafkaServerPrefix=ClusterId") + + def generateConfigs: Seq[KafkaConfig] = + TestUtils.createBrokerConfigs(numNodes, zkConnectOrNull, enableControlledShutdown = false). 
+ map(KafkaConfig.fromProps(_, overridingProps)) + + val nMessages = 2 + + @ParameterizedTest + @ValueSource(strings = Array("zk", "kraft")) + def testMetricsReporterAfterDeletingTopic(quorum: String): Unit = { + val topic = "test-topic-metric" + createTopic(topic, 1, 1) + deleteTopic(topic) + TestUtils.verifyTopicDeletion(zkClientOrNull, topic, 1, brokers) + assertEquals(Set.empty, topicMetricGroups(topic), "Topic metrics exists after deleteTopic") + } + + @ParameterizedTest + @ValueSource(strings = Array("zk", "kraft")) + def testBrokerTopicMetricsUnregisteredAfterDeletingTopic(quorum: String): Unit = { + val topic = "test-broker-topic-metric" + createTopic(topic, 2, 1) + // Produce a few messages to create the metrics + // Don't consume messages as it may cause metrics to be re-created causing the test to fail, see KAFKA-5238 + TestUtils.generateAndProduceMessages(brokers, topic, nMessages) + assertTrue(topicMetricGroups(topic).nonEmpty, "Topic metrics don't exist") + brokers.foreach(b => assertNotNull(b.brokerTopicStats.topicStats(topic))) + deleteTopic(topic) + TestUtils.verifyTopicDeletion(zkClientOrNull, topic, 1, brokers) + assertEquals(Set.empty, topicMetricGroups(topic), "Topic metrics exists after deleteTopic") + } + + @ParameterizedTest + @ValueSource(strings = Array("zk", "kraft")) + def testClusterIdMetric(quorum: String): Unit = { + // Check if clusterId metric exists. + val metrics = KafkaYammerMetrics.defaultRegistry.allMetrics + assertEquals(metrics.keySet.asScala.count(_.getMBeanName == s"$requiredKafkaServerPrefix=ClusterId"), 1) + } + + @ParameterizedTest + @ValueSource(strings = Array("zk", "kraft")) + def testBrokerStateMetric(quorum: String): Unit = { + // Check if BrokerState metric exists. + val metrics = KafkaYammerMetrics.defaultRegistry.allMetrics + assertEquals(metrics.keySet.asScala.count(_.getMBeanName == s"$requiredKafkaServerPrefix=BrokerState"), 1) + } + + @ParameterizedTest + @ValueSource(strings = Array("zk", "kraft")) + def testYammerMetricsCountMetric(quorum: String): Unit = { + // Check if yammer-metrics-count metric exists. + val metrics = KafkaYammerMetrics.defaultRegistry.allMetrics + assertEquals(metrics.keySet.asScala.count(_.getMBeanName == s"$requiredKafkaServerPrefix=yammer-metrics-count"), 1) + } + + @ParameterizedTest + @ValueSource(strings = Array("zk", "kraft")) + def testLinuxIoMetrics(quorum: String): Unit = { + // Check if linux-disk-{read,write}-bytes metrics either do or do not exist depending on whether we are or are not + // able to collect those metrics on the platform where this test is running. 
+ val usable = new LinuxIoMetricsCollector("/proc", Time.SYSTEM, logger.underlying).usable() + val expectedCount = if (usable) 1 else 0 + val metrics = KafkaYammerMetrics.defaultRegistry.allMetrics + Set("linux-disk-read-bytes", "linux-disk-write-bytes").foreach(name => + assertEquals(metrics.keySet.asScala.count(_.getMBeanName == s"$requiredKafkaServerPrefix=$name"), expectedCount)) + } + + @ParameterizedTest + @ValueSource(strings = Array("zk", "kraft")) + def testJMXFilter(quorum: String): Unit = { + // Check if cluster id metrics is not exposed in JMX + assertTrue(ManagementFactory.getPlatformMBeanServer + .isRegistered(new ObjectName("kafka.controller:type=KafkaController,name=ActiveControllerCount"))) + assertFalse(ManagementFactory.getPlatformMBeanServer + .isRegistered(new ObjectName(s"$requiredKafkaServerPrefix=ClusterId"))) + } + + @ParameterizedTest + @ValueSource(strings = Array("zk", "kraft")) + def testUpdateJMXFilter(quorum: String): Unit = { + // verify previously exposed metrics are removed and existing matching metrics are added + brokers.foreach(broker => broker.kafkaYammerMetrics.reconfigure( + Map(JmxReporter.EXCLUDE_CONFIG -> "kafka.controller:type=KafkaController,name=ActiveControllerCount").asJava + )) + assertFalse(ManagementFactory.getPlatformMBeanServer + .isRegistered(new ObjectName("kafka.controller:type=KafkaController,name=ActiveControllerCount"))) + assertTrue(ManagementFactory.getPlatformMBeanServer + .isRegistered(new ObjectName(s"$requiredKafkaServerPrefix=ClusterId"))) + } + + @ParameterizedTest + @ValueSource(strings = Array("zk", "kraft")) + def testGeneralBrokerTopicMetricsAreGreedilyRegistered(quorum: String): Unit = { + val topic = "test-broker-topic-metric" + createTopic(topic, 2, 1) + + // The broker metrics for all topics should be greedily registered + assertTrue(topicMetrics(None).nonEmpty, "General topic metrics don't exist") + assertEquals(brokers.head.brokerTopicStats.allTopicsStats.metricMap.size, topicMetrics(None).size) + // topic metrics should be lazily registered + assertTrue(topicMetricGroups(topic).isEmpty, "Topic metrics aren't lazily registered") + TestUtils.generateAndProduceMessages(brokers, topic, nMessages) + assertTrue(topicMetricGroups(topic).nonEmpty, "Topic metrics aren't registered") + } + + @ParameterizedTest + @ValueSource(strings = Array("zk", "kraft")) + def testWindowsStyleTagNames(quorum: String): Unit = { + val path = "C:\\windows-path\\kafka-logs" + val tags = Map("dir" -> path) + val expectedMBeanName = Set(tags.keySet.head, ObjectName.quote(path)).mkString("=") + val metric = KafkaMetricsGroup.metricName("test-metric", tags) + assert(metric.getMBeanName.endsWith(expectedMBeanName)) + } + + @ParameterizedTest + @ValueSource(strings = Array("zk", "kraft")) + def testBrokerTopicMetricsBytesInOut(quorum: String): Unit = { + val topic = "test-bytes-in-out" + val replicationBytesIn = BrokerTopicStats.ReplicationBytesInPerSec + val replicationBytesOut = BrokerTopicStats.ReplicationBytesOutPerSec + val bytesIn = s"${BrokerTopicStats.BytesInPerSec},topic=$topic" + val bytesOut = s"${BrokerTopicStats.BytesOutPerSec},topic=$topic" + + val topicConfig = new Properties + topicConfig.setProperty(LogConfig.MinInSyncReplicasProp, "2") + createTopic(topic, 1, numNodes, topicConfig) + // Produce a few messages to create the metrics + TestUtils.generateAndProduceMessages(brokers, topic, nMessages) + + // Check the log size for each broker so that we can distinguish between failures caused by replication issues + // versus failures 
caused by the metrics + val topicPartition = new TopicPartition(topic, 0) + brokers.foreach { broker => + val log = broker.logManager.getLog(new TopicPartition(topic, 0)) + val brokerId = broker.config.brokerId + val logSize = log.map(_.size) + assertTrue(logSize.exists(_ > 0), s"Expected broker $brokerId to have a Log for $topicPartition with positive size, actual: $logSize") + } + + // Consume messages to make bytesOut tick + TestUtils.consumeTopicRecords(brokers, topic, nMessages) + val initialReplicationBytesIn = TestUtils.meterCount(replicationBytesIn) + val initialReplicationBytesOut = TestUtils.meterCount(replicationBytesOut) + val initialBytesIn = TestUtils.meterCount(bytesIn) + val initialBytesOut = TestUtils.meterCount(bytesOut) + + // BytesOut doesn't include replication, so it shouldn't have changed + assertEquals(initialBytesOut, TestUtils.meterCount(bytesOut)) + + // Produce a few messages to make the metrics tick + TestUtils.generateAndProduceMessages(brokers, topic, nMessages) + + assertTrue(TestUtils.meterCount(replicationBytesIn) > initialReplicationBytesIn) + assertTrue(TestUtils.meterCount(replicationBytesOut) > initialReplicationBytesOut) + assertTrue(TestUtils.meterCount(bytesIn) > initialBytesIn) + + // Consume messages to make bytesOut tick + TestUtils.consumeTopicRecords(brokers, topic, nMessages) + + assertTrue(TestUtils.meterCount(bytesOut) > initialBytesOut) + } + + @ParameterizedTest + @ValueSource(strings = Array("zk")) + def testZkControllerMetrics(quorum: String): Unit = { + val metrics = KafkaYammerMetrics.defaultRegistry.allMetrics + + assertEquals(metrics.keySet.asScala.count(_.getMBeanName == "kafka.controller:type=KafkaController,name=ActiveControllerCount"), 1) + assertEquals(metrics.keySet.asScala.count(_.getMBeanName == "kafka.controller:type=KafkaController,name=OfflinePartitionsCount"), 1) + assertEquals(metrics.keySet.asScala.count(_.getMBeanName == "kafka.controller:type=KafkaController,name=PreferredReplicaImbalanceCount"), 1) + assertEquals(metrics.keySet.asScala.count(_.getMBeanName == "kafka.controller:type=KafkaController,name=GlobalTopicCount"), 1) + assertEquals(metrics.keySet.asScala.count(_.getMBeanName == "kafka.controller:type=KafkaController,name=GlobalPartitionCount"), 1) + assertEquals(metrics.keySet.asScala.count(_.getMBeanName == "kafka.controller:type=KafkaController,name=TopicsToDeleteCount"), 1) + assertEquals(metrics.keySet.asScala.count(_.getMBeanName == "kafka.controller:type=KafkaController,name=ReplicasToDeleteCount"), 1) + assertEquals(metrics.keySet.asScala.count(_.getMBeanName == "kafka.controller:type=KafkaController,name=TopicsIneligibleToDeleteCount"), 1) + assertEquals(metrics.keySet.asScala.count(_.getMBeanName == "kafka.controller:type=KafkaController,name=ReplicasIneligibleToDeleteCount"), 1) + assertEquals(metrics.keySet.asScala.count(_.getMBeanName == "kafka.controller:type=KafkaController,name=ActiveBrokerCount"), 1) + assertEquals(metrics.keySet.asScala.count(_.getMBeanName == "kafka.controller:type=KafkaController,name=FencedBrokerCount"), 1) + } + + /** + * Test that the metrics are created with the right name, testZooKeeperStateChangeRateMetrics + * and testZooKeeperSessionStateMetric in ZooKeeperClientTest test the metrics behaviour. 
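+ * In KRaft mode there is no ZooKeeper session, so these SessionExpireListener metrics are expected to be absent.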
+ */ + @ParameterizedTest + @ValueSource(strings = Array("zk", "kraft")) + def testSessionExpireListenerMetrics(quorum: String): Unit = { + val metrics = KafkaYammerMetrics.defaultRegistry.allMetrics + val expectedNumMetrics = if (isKRaftTest()) 0 else 1 + assertEquals(expectedNumMetrics, metrics.keySet.asScala. + count(_.getMBeanName == "kafka.server:type=SessionExpireListener,name=SessionState")) + assertEquals(expectedNumMetrics, metrics.keySet.asScala. + count(_.getMBeanName == "kafka.server:type=SessionExpireListener,name=ZooKeeperExpiresPerSec")) + assertEquals(expectedNumMetrics, metrics.keySet.asScala. + count(_.getMBeanName == "kafka.server:type=SessionExpireListener,name=ZooKeeperDisconnectsPerSec")) + } + + private def topicMetrics(topic: Option[String]): Set[String] = { + val metricNames = KafkaYammerMetrics.defaultRegistry.allMetrics().keySet.asScala.map(_.getMBeanName) + filterByTopicMetricRegex(metricNames, topic) + } + + private def topicMetricGroups(topic: String): Set[String] = { + val metricGroups = KafkaYammerMetrics.defaultRegistry.groupedMetrics(MetricPredicate.ALL).keySet.asScala + filterByTopicMetricRegex(metricGroups, Some(topic)) + } + + private def filterByTopicMetricRegex(metrics: Set[String], topic: Option[String]): Set[String] = { + val pattern = (".*BrokerTopicMetrics.*" + topic.map(t => s"($t)$$").getOrElse("")).r.pattern + metrics.filter(pattern.matcher(_).matches()) + } +} diff --git a/core/src/test/scala/unit/kafka/network/ConnectionQuotasTest.scala b/core/src/test/scala/unit/kafka/network/ConnectionQuotasTest.scala new file mode 100644 index 0000000..2ebcc37 --- /dev/null +++ b/core/src/test/scala/unit/kafka/network/ConnectionQuotasTest.scala @@ -0,0 +1,953 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package kafka.network + +import java.net.InetAddress +import java.util +import java.util.concurrent.{Callable, ExecutorService, Executors, TimeUnit} +import java.util.{Collections, Properties} +import com.yammer.metrics.core.Meter +import kafka.metrics.KafkaMetricsGroup +import kafka.network.Processor.ListenerMetricTag +import kafka.server.KafkaConfig +import kafka.utils.Implicits.MapExtensionMethods +import kafka.utils.{MockTime, TestUtils} +import org.apache.kafka.common.config.ConfigException +import org.apache.kafka.common.config.internals.QuotaConfigs +import org.apache.kafka.common.metrics.internals.MetricsUtils +import org.apache.kafka.common.metrics.{KafkaMetric, MetricConfig, Metrics} +import org.apache.kafka.common.network._ +import org.apache.kafka.common.utils.Time +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api._ + +import scala.jdk.CollectionConverters._ +import scala.collection.{Map, mutable} +import scala.concurrent.TimeoutException + +class ConnectionQuotasTest { + private var metrics: Metrics = _ + private var executor: ExecutorService = _ + private var connectionQuotas: ConnectionQuotas = _ + private var time: Time = _ + + private val listeners = Map( + "EXTERNAL" -> ListenerDesc(new ListenerName("EXTERNAL"), InetAddress.getByName("192.168.1.1")), + "ADMIN" -> ListenerDesc(new ListenerName("ADMIN"), InetAddress.getByName("192.168.1.2")), + "REPLICATION" -> ListenerDesc(new ListenerName("REPLICATION"), InetAddress.getByName("192.168.1.3"))) + private val blockedPercentMeters = mutable.Map[String, Meter]() + private val knownHost = InetAddress.getByName("192.168.10.0") + private val unknownHost = InetAddress.getByName("192.168.2.0") + + private val numQuotaSamples = 2 + private val quotaWindowSizeSeconds = 1 + private val eps = 0.01 + + case class ListenerDesc(listenerName: ListenerName, defaultIp: InetAddress) { + override def toString: String = { + s"(listener=${listenerName.value}, client=${defaultIp.getHostAddress})" + } + } + + def brokerPropsWithDefaultConnectionLimits: Properties = { + val props = TestUtils.createBrokerConfig(0, TestUtils.MockZkConnect, port = 0) + props.put(KafkaConfig.ListenersProp, "EXTERNAL://localhost:0,REPLICATION://localhost:1,ADMIN://localhost:2") + // ConnectionQuotas does not limit inter-broker listener even when broker-wide connection limit is reached + props.put(KafkaConfig.InterBrokerListenerNameProp, "REPLICATION") + props.put(KafkaConfig.ListenerSecurityProtocolMapProp, "EXTERNAL:PLAINTEXT,REPLICATION:PLAINTEXT,ADMIN:PLAINTEXT") + props.put(KafkaConfig.NumQuotaSamplesProp, numQuotaSamples.toString) + props.put(KafkaConfig.QuotaWindowSizeSecondsProp, quotaWindowSizeSeconds.toString) + props + } + + private def setupMockTime(): Unit = { + // clean up metrics initialized with Time.SYSTEM + metrics.close() + time = new MockTime() + metrics = new Metrics(time) + } + + @BeforeEach + def setUp(): Unit = { + // Clean-up any metrics left around by previous tests + TestUtils.clearYammerMetrics() + + listeners.keys.foreach { name => + blockedPercentMeters.put(name, KafkaMetricsGroup.newMeter( + s"${name}BlockedPercent", "blocked time", TimeUnit.NANOSECONDS, Map(ListenerMetricTag -> name))) + } + // use system time, because ConnectionQuota causes the current thread to wait with timeout, which waits based on + // system time; so using mock time will likely result in test flakiness due to a mixed use of mock and system time + time = Time.SYSTEM + metrics = new Metrics(new MetricConfig(), Collections.emptyList(), 
time) + executor = Executors.newFixedThreadPool(listeners.size) + } + + @AfterEach + def tearDown(): Unit = { + executor.shutdownNow() + if (connectionQuotas != null) { + connectionQuotas.close() + } + metrics.close() + TestUtils.clearYammerMetrics() + blockedPercentMeters.clear() + } + + @Test + def testFailWhenNoListeners(): Unit = { + val config = KafkaConfig.fromProps(brokerPropsWithDefaultConnectionLimits) + connectionQuotas = new ConnectionQuotas(config, time, metrics) + + // inc() on a separate thread in case it blocks + val listener = listeners("EXTERNAL") + executor.submit((() => + assertThrows(classOf[RuntimeException], + () => connectionQuotas.inc(listener.listenerName, listener.defaultIp, blockedPercentMeters("EXTERNAL")) + )): Runnable + ).get(5, TimeUnit.SECONDS) + } + + @Test + def testFailDecrementForUnknownIp(): Unit = { + val config = KafkaConfig.fromProps(brokerPropsWithDefaultConnectionLimits) + connectionQuotas = new ConnectionQuotas(config, time, metrics) + addListenersAndVerify(config, connectionQuotas) + + // calling dec() for an IP for which we didn't call inc() should throw an exception + assertThrows(classOf[IllegalArgumentException], () => connectionQuotas.dec(listeners("EXTERNAL").listenerName, unknownHost)) + } + + @Test + def testNoConnectionLimitsByDefault(): Unit = { + val config = KafkaConfig.fromProps(brokerPropsWithDefaultConnectionLimits) + connectionQuotas = new ConnectionQuotas(config, time, metrics) + addListenersAndVerify(config, connectionQuotas) + + // verify there is no limit by accepting 10000 connections as fast as possible + val numConnections = 10000 + val futures = listeners.values.map { listener => + executor.submit((() => acceptConnections(connectionQuotas, listener, numConnections)): Runnable) + } + futures.foreach(_.get(10, TimeUnit.SECONDS)) + assertTrue(metricValue(brokerConnRateMetric())> 0, "Expected broker-connection-accept-rate metric to get recorded") + listeners.values.foreach { listener => + assertEquals(numConnections, connectionQuotas.get(listener.defaultIp), s"Number of connections on $listener:") + + assertTrue(metricValue(listenerConnRateMetric(listener.listenerName.value)) > 0, + s"Expected connection-accept-rate metric to get recorded for listener $listener") + + // verify removing one connection + connectionQuotas.dec(listener.listenerName, listener.defaultIp) + assertEquals(numConnections - 1, connectionQuotas.get(listener.defaultIp), + s"Number of connections on $listener:") + } + // the blocked percent should still be 0, because no limits were reached + verifyNoBlockedPercentRecordedOnAllListeners() + } + + @Test + def testMaxConnectionsPerIp(): Unit = { + val maxConnectionsPerIp = 17 + val props = brokerPropsWithDefaultConnectionLimits + props.put(KafkaConfig.MaxConnectionsPerIpProp, maxConnectionsPerIp.toString) + val config = KafkaConfig.fromProps(props) + connectionQuotas = new ConnectionQuotas(config, time, metrics) + + addListenersAndVerify(config, connectionQuotas) + + val externalListener = listeners("EXTERNAL") + executor.submit((() => + acceptConnections(connectionQuotas, externalListener, maxConnectionsPerIp)): Runnable + ).get(5, TimeUnit.SECONDS) + assertEquals(maxConnectionsPerIp, connectionQuotas.get(externalListener.defaultIp), + s"Number of connections on $externalListener:") + + // all subsequent connections will be added to the counters, but inc() will throw TooManyConnectionsException for each + executor.submit((() => + acceptConnectionsAboveIpLimit(connectionQuotas, externalListener, 2)): Runnable 
+ ).get(5, TimeUnit.SECONDS) + assertEquals(maxConnectionsPerIp + 2, connectionQuotas.get(externalListener.defaultIp), + s"Number of connections on $externalListener:") + + // connections on the same listener but from a different IP should be accepted + executor.submit((() => + acceptConnections(connectionQuotas, externalListener.listenerName, knownHost, maxConnectionsPerIp, + 0, expectIpThrottle = false)): Runnable + ).get(5, TimeUnit.SECONDS) + + // remove two "rejected" connections and remove 2 more connections to free up the space for another 2 connections + for (_ <- 0 until 4) connectionQuotas.dec(externalListener.listenerName, externalListener.defaultIp) + assertEquals(maxConnectionsPerIp - 2, connectionQuotas.get(externalListener.defaultIp), + s"Number of connections on $externalListener:") + + executor.submit((() => + acceptConnections(connectionQuotas, externalListener, 2)): Runnable + ).get(5, TimeUnit.SECONDS) + assertEquals(maxConnectionsPerIp, connectionQuotas.get(externalListener.defaultIp), + s"Number of connections on $externalListener:") + } + + @Test + def testMaxBrokerWideConnectionLimit(): Unit = { + val maxConnections = 800 + val props = brokerPropsWithDefaultConnectionLimits + props.put(KafkaConfig.MaxConnectionsProp, maxConnections.toString) + val config = KafkaConfig.fromProps(props) + connectionQuotas = new ConnectionQuotas(config, time, metrics) + + addListenersAndVerify(config, connectionQuotas) + + // verify that ConnectionQuota can give all connections to one listener + executor.submit((() => + acceptConnections(connectionQuotas, listeners("EXTERNAL"), maxConnections)): Runnable + ).get(5, TimeUnit.SECONDS) + assertEquals(maxConnections, connectionQuotas.get(listeners("EXTERNAL").defaultIp), + s"Number of connections on ${listeners("EXTERNAL")}:") + + // the blocked percent should still be 0, because there should be no wait for a connection slot + assertEquals(0, blockedPercentMeters("EXTERNAL").count()) + + // the number of connections should be above max for maxConnectionsExceeded to return true + assertFalse(connectionQuotas.maxConnectionsExceeded(listeners("EXTERNAL").listenerName), + "Total number of connections is exactly the maximum.") + + // adding one more connection will block ConnectionQuota.inc() + val future = executor.submit((() => + acceptConnections(connectionQuotas, listeners("EXTERNAL"), 1)): Runnable + ) + assertThrows(classOf[TimeoutException], () => future.get(100, TimeUnit.MILLISECONDS)) + + // removing one connection should make the waiting connection to succeed + connectionQuotas.dec(listeners("EXTERNAL").listenerName, listeners("EXTERNAL").defaultIp) + future.get(1, TimeUnit.SECONDS) + assertEquals(maxConnections, connectionQuotas.get(listeners("EXTERNAL").defaultIp), + s"Number of connections on ${listeners("EXTERNAL")}:") + // metric is recorded in nanoseconds + assertTrue(blockedPercentMeters("EXTERNAL").count() > 0, + "Expected BlockedPercentMeter metric to be recorded") + + // adding inter-broker connections should succeed even when the total number of connections reached the max + executor.submit((() => + acceptConnections(connectionQuotas, listeners("REPLICATION"), 1)): Runnable + ).get(5, TimeUnit.SECONDS) + assertTrue(connectionQuotas.maxConnectionsExceeded(listeners("EXTERNAL").listenerName), + "Expected the number of connections to exceed the maximum.") + + // adding one more connection on another non-inter-broker will block ConnectionQuota.inc() + val future1 = executor.submit((() => + acceptConnections(connectionQuotas, 
listeners("ADMIN"), 1)): Runnable + ) + assertThrows(classOf[TimeoutException], () => future1.get(1, TimeUnit.SECONDS)) + + // adding inter-broker connection should still succeed, even though a connection from another listener is waiting + executor.submit((() => + acceptConnections(connectionQuotas, listeners("REPLICATION"), 1)): Runnable + ).get(5, TimeUnit.SECONDS) + + // at this point, we need to remove 3 connections for the waiting connection to succeed + // remove 2 first -- should not be enough to accept the waiting connection + for (_ <- 0 until 2) connectionQuotas.dec(listeners("EXTERNAL").listenerName, listeners("EXTERNAL").defaultIp) + assertThrows(classOf[TimeoutException], () => future1.get(100, TimeUnit.MILLISECONDS)) + connectionQuotas.dec(listeners("EXTERNAL").listenerName, listeners("EXTERNAL").defaultIp) + future1.get(1, TimeUnit.SECONDS) + } + + @Test + def testMaxListenerConnectionLimits(): Unit = { + val maxConnections = 800 + // sum of per-listener connection limits is below total connection limit + val listenerMaxConnections = 200 + val props = brokerPropsWithDefaultConnectionLimits + props.put(KafkaConfig.MaxConnectionsProp, maxConnections.toString) + val config = KafkaConfig.fromProps(props) + connectionQuotas = new ConnectionQuotas(config, time, metrics) + + addListenersAndVerify(config, connectionQuotas) + + val listenerConfig = Map(KafkaConfig.MaxConnectionsProp -> listenerMaxConnections.toString).asJava + listeners.values.foreach { listener => + connectionQuotas.maxConnectionsPerListener(listener.listenerName).configure(listenerConfig) + } + + // verify each listener can create up to max connections configured for that listener + val futures = listeners.values.map { listener => + executor.submit((() => acceptConnections(connectionQuotas, listener, listenerMaxConnections)): Runnable) + } + futures.foreach(_.get(5, TimeUnit.SECONDS)) + listeners.values.foreach { listener => + assertEquals(listenerMaxConnections, connectionQuotas.get(listener.defaultIp), + s"Number of connections on $listener:") + assertFalse(connectionQuotas.maxConnectionsExceeded(listener.listenerName), + s"Total number of connections on $listener should be exactly the maximum.") + } + + // since every listener has exactly the max number of listener connections, + // every listener should block on the next connection creation, even the inter-broker listener + val overLimitFutures = listeners.values.map { listener => + executor.submit((() => acceptConnections(connectionQuotas, listener, 1)): Runnable) + } + overLimitFutures.foreach { future => + assertThrows(classOf[TimeoutException], () => future.get(1, TimeUnit.SECONDS)) + } + listeners.values.foreach { listener => + // free up one connection slot + connectionQuotas.dec(listener.listenerName, listener.defaultIp) + } + // all connections should get added + overLimitFutures.foreach(_.get(5, TimeUnit.SECONDS)) + verifyConnectionCountOnEveryListener(connectionQuotas, listenerMaxConnections) + } + + @Test + def testBrokerConnectionRateLimitWhenActualRateBelowLimit(): Unit = { + val brokerRateLimit = 125 + // create connections with the total rate < broker-wide quota, and verify there is no throttling + val connCreateIntervalMs = 25 // connection creation rate = 40/sec per listener (3 * 40 = 120/sec total) + val connectionsPerListener = 200 // should take 5 seconds to create 200 connections with rate = 40/sec + val props = brokerPropsWithDefaultConnectionLimits + props.put(KafkaConfig.MaxConnectionCreationRateProp, brokerRateLimit.toString) + val 
config = KafkaConfig.fromProps(props) + connectionQuotas = new ConnectionQuotas(config, time, metrics) + + addListenersAndVerify(config, connectionQuotas) + + val futures = listeners.values.map { listener => + executor.submit((() => acceptConnections(connectionQuotas, listener, connectionsPerListener, connCreateIntervalMs)): Runnable) + } + futures.foreach(_.get(10, TimeUnit.SECONDS)) + + // the blocked percent should still be 0, because no limits were reached + verifyNoBlockedPercentRecordedOnAllListeners() + verifyConnectionCountOnEveryListener(connectionQuotas, connectionsPerListener) + } + + @Test + def testBrokerConnectionRateLimitWhenActualRateAboveLimit(): Unit = { + val brokerRateLimit = 90 + val props = brokerPropsWithDefaultConnectionLimits + props.put(KafkaConfig.MaxConnectionCreationRateProp, brokerRateLimit.toString) + val config = KafkaConfig.fromProps(props) + connectionQuotas = new ConnectionQuotas(config, time, metrics) + + addListenersAndVerify(config, connectionQuotas) + + // each listener creates connections such that the total connection rate > broker-wide quota + val connCreateIntervalMs = 10 // connection creation rate = 100 + val connectionsPerListener = 400 + val futures = listeners.values.map { listener => + executor.submit((() => acceptConnections(connectionQuotas, listener, connectionsPerListener, connCreateIntervalMs)): Runnable) + } + futures.foreach(_.get(20, TimeUnit.SECONDS)) + + // verify that connections on non-inter-broker listener are throttled + verifyOnlyNonInterBrokerListenersBlockedPercentRecorded() + + // expect all connections to be created (no limit on the number of connections) + verifyConnectionCountOnEveryListener(connectionQuotas, connectionsPerListener) + } + + @Test + def testListenerConnectionRateLimitWhenActualRateBelowLimit(): Unit = { + val brokerRateLimit = 125 + val listenerRateLimit = 50 + val connCreateIntervalMs = 25 // connection creation rate = 40/sec per listener (3 * 40 = 120/sec total) + val props = brokerPropsWithDefaultConnectionLimits + props.put(KafkaConfig.MaxConnectionCreationRateProp, brokerRateLimit.toString) + val config = KafkaConfig.fromProps(props) + connectionQuotas = new ConnectionQuotas(config, time, metrics) + + val listenerConfig = Map(KafkaConfig.MaxConnectionCreationRateProp -> listenerRateLimit.toString).asJava + addListenersAndVerify(config, listenerConfig, connectionQuotas) + + // create connections with the rate < listener quota on every listener, and verify there is no throttling + val connectionsPerListener = 200 // should take 5 seconds to create 200 connections with rate = 40/sec + val futures = listeners.values.map { listener => + executor.submit((() => acceptConnections(connectionQuotas, listener, connectionsPerListener, connCreateIntervalMs)): Runnable) + } + futures.foreach(_.get(10, TimeUnit.SECONDS)) + + // the blocked percent should still be 0, because no limits were reached + verifyNoBlockedPercentRecordedOnAllListeners() + + verifyConnectionCountOnEveryListener(connectionQuotas, connectionsPerListener) + } + + @Test + def testListenerConnectionRateLimitWhenActualRateAboveLimit(): Unit = { + val brokerRateLimit = 125 + val listenerRateLimit = 30 + val connCreateIntervalMs = 25 // connection creation rate = 40/sec per listener (3 * 40 = 120/sec total) + val props = brokerPropsWithDefaultConnectionLimits + props.put(KafkaConfig.MaxConnectionCreationRateProp, brokerRateLimit.toString) + val config = KafkaConfig.fromProps(props) + connectionQuotas = new ConnectionQuotas(config, time, metrics) + + 
val listenerConfig = Map(KafkaConfig.MaxConnectionCreationRateProp -> listenerRateLimit.toString).asJava + addListenersAndVerify(config, listenerConfig, connectionQuotas) + + // create connections with the rate > listener quota on every listener + // run a bit longer (20 seconds) to also verify the throttle rate + val connectionsPerListener = 600 // should take 20 seconds to create 600 connections with rate = 30/sec + val futures = listeners.values.map { listener => + executor.submit((() => + // epsilon is set to account for the worst-case where the measurement is taken just before or after the quota window + acceptConnectionsAndVerifyRate(connectionQuotas, listener, connectionsPerListener, connCreateIntervalMs, listenerRateLimit, 7)): Runnable) + } + futures.foreach(_.get(30, TimeUnit.SECONDS)) + + // verify that every listener was throttled + verifyNonZeroBlockedPercentAndThrottleTimeOnAllListeners() + + // while the connection creation rate was throttled, + // expect all connections to be created (no limit on the number of connections) + verifyConnectionCountOnEveryListener(connectionQuotas, connectionsPerListener) + } + + @Test + def testIpConnectionRateWhenActualRateBelowLimit(): Unit = { + val ipConnectionRateLimit = 30 + val connCreateIntervalMs = 40 // connection creation rate = 25/sec + val props = brokerPropsWithDefaultConnectionLimits + val config = KafkaConfig.fromProps(props) + // use MockTime for IP connection rate quota tests that don't expect to block + setupMockTime() + connectionQuotas = new ConnectionQuotas(config, time, metrics) + addListenersAndVerify(config, connectionQuotas) + val externalListener = listeners("EXTERNAL") + connectionQuotas.updateIpConnectionRateQuota(Some(externalListener.defaultIp), Some(ipConnectionRateLimit)) + val numConnections = 200 + // create connections with the rate < ip quota and verify there is no throttling + acceptConnectionsAndVerifyRate(connectionQuotas, externalListener, numConnections, connCreateIntervalMs, + expectedRate = 25, epsilon = 0) + assertEquals(numConnections, connectionQuotas.get(externalListener.defaultIp), + s"Number of connections on $externalListener:") + + val adminListener = listeners("ADMIN") + val unthrottledConnectionCreateInterval = 20 // connection creation rate = 50/s + // create connections with an IP with no quota and verify there is no throttling + acceptConnectionsAndVerifyRate(connectionQuotas, adminListener, numConnections, unthrottledConnectionCreateInterval, + expectedRate = 50, epsilon = 0) + + assertEquals(numConnections, connectionQuotas.get(adminListener.defaultIp), + s"Number of connections on $adminListener:") + + // acceptor shouldn't block for IP rate throttling + verifyNoBlockedPercentRecordedOnAllListeners() + // no IP throttle time should be recorded on any listeners + listeners.values.map(_.listenerName).foreach(verifyIpThrottleTimeOnListener(_, expectThrottle = false)) + } + + @Test + def testIpConnectionRateWhenActualRateAboveLimit(): Unit = { + val ipConnectionRateLimit = 20 + val connCreateIntervalMs = 25 // connection creation rate = 40/sec + val props = brokerPropsWithDefaultConnectionLimits + val config = KafkaConfig.fromProps(props) + // use MockTime for IP connection rate quota tests that don't expect to block + setupMockTime() + connectionQuotas = new ConnectionQuotas(config, time, metrics) + addListenersAndVerify(config, connectionQuotas) + val externalListener = listeners("EXTERNAL") + connectionQuotas.updateIpConnectionRateQuota(Some(externalListener.defaultIp), 
Some(ipConnectionRateLimit)) + // create connections with the rate > ip quota + val numConnections = 80 + acceptConnectionsAndVerifyRate(connectionQuotas, externalListener, numConnections, connCreateIntervalMs, ipConnectionRateLimit, + 1, expectIpThrottle = true) + verifyIpThrottleTimeOnListener(externalListener.listenerName, expectThrottle = true) + + // verify that default quota applies to IPs without a quota override + connectionQuotas.updateIpConnectionRateQuota(None, Some(ipConnectionRateLimit)) + val adminListener = listeners("ADMIN") + // listener shouldn't have any IP throttle time recorded + verifyIpThrottleTimeOnListener(adminListener.listenerName, expectThrottle = false) + acceptConnectionsAndVerifyRate(connectionQuotas, adminListener, numConnections, connCreateIntervalMs, ipConnectionRateLimit, + 1, expectIpThrottle = true) + verifyIpThrottleTimeOnListener(adminListener.listenerName, expectThrottle = true) + + // acceptor shouldn't block for IP rate throttling + verifyNoBlockedPercentRecordedOnAllListeners() + // replication listener shouldn't have any IP throttling recorded + verifyIpThrottleTimeOnListener(listeners("REPLICATION").listenerName, expectThrottle = false) + } + + @Test + def testIpConnectionRateWithListenerConnectionRate(): Unit = { + val ipConnectionRateLimit = 25 + val listenerRateLimit = 35 + val props = brokerPropsWithDefaultConnectionLimits + val config = KafkaConfig.fromProps(props) + connectionQuotas = new ConnectionQuotas(config, time, metrics) + // with a default per-IP limit of 25 and a listener rate of 35, only one IP should be able to saturate its IP rate + // limit, the other IP will hit listener rate limits and block + connectionQuotas.updateIpConnectionRateQuota(None, Some(ipConnectionRateLimit)) + val listenerConfig = Map(KafkaConfig.MaxConnectionCreationRateProp -> listenerRateLimit.toString).asJava + addListenersAndVerify(config, listenerConfig, connectionQuotas) + val listener = listeners("EXTERNAL").listenerName + // use a small number of connections because a longer-running test will have both IPs throttle at different times + val numConnections = 35 + val futures = List( + executor.submit((() => acceptConnections(connectionQuotas, listener, knownHost, numConnections, + 0, true)): Callable[Boolean]), + executor.submit((() => acceptConnections(connectionQuotas, listener, unknownHost, numConnections, + 0, true)): Callable[Boolean]) + ) + + val ipsThrottledResults = futures.map(_.get(3, TimeUnit.SECONDS)) + val throttledIps = ipsThrottledResults.filter(identity) + // at most one IP should get IP throttled before the acceptor blocks on listener quota + assertTrue(blockedPercentMeters("EXTERNAL").count() > 0, + "Expected BlockedPercentMeter metric for EXTERNAL listener to be recorded") + assertTrue(throttledIps.size < 2, + "Expect at most one IP to get throttled") + } + + @Test + def testRejectedIpConnectionUnrecordedFromConnectionRateQuotas(): Unit = { + val config = KafkaConfig.fromProps(brokerPropsWithDefaultConnectionLimits) + connectionQuotas = new ConnectionQuotas(config, new MockTime(), metrics) + addListenersAndVerify(config, connectionQuotas) + val externalListener = listeners("EXTERNAL") + val protectedListener = listeners("REPLICATION") + connectionQuotas.updateIpConnectionRateQuota(Some(externalListener.defaultIp), Some(0)) + connectionQuotas.updateIpConnectionRateQuota(Some(protectedListener.defaultIp), Some(0)) + + assertThrows(classOf[ConnectionThrottledException], + () => connectionQuotas.inc(externalListener.listenerName, 
externalListener.defaultIp, blockedPercentMeters("EXTERNAL")) + ) + + val brokerRateMetric = brokerConnRateMetric() + // rejected connection shouldn't be recorded for any of the connection accepted rate metrics + assertEquals(0, metricValue(ipConnRateMetric(externalListener.defaultIp.getHostAddress)), eps) + assertEquals(0, metricValue(listenerConnRateMetric(externalListener.listenerName.value)), eps) + assertEquals(0, metricValue(brokerRateMetric), eps) + + assertThrows(classOf[ConnectionThrottledException], + () => connectionQuotas.inc(protectedListener.listenerName, protectedListener.defaultIp, blockedPercentMeters("REPLICATION")) + ) + + assertEquals(0, metricValue(ipConnRateMetric(protectedListener.defaultIp.getHostAddress)), eps) + assertEquals(0, metricValue(listenerConnRateMetric(protectedListener.listenerName.value)), eps) + assertEquals(0, metricValue(brokerRateMetric), eps) + } + + @Test + def testMaxListenerConnectionListenerMustBeAboveZero(): Unit = { + val config = KafkaConfig.fromProps(brokerPropsWithDefaultConnectionLimits) + connectionQuotas = new ConnectionQuotas(config, time, metrics) + + connectionQuotas.addListener(config, listeners("EXTERNAL").listenerName) + + val maxListenerConnectionRate = 0 + val listenerConfig = Map(KafkaConfig.MaxConnectionCreationRateProp -> maxListenerConnectionRate.toString).asJava + assertThrows(classOf[ConfigException], + () => connectionQuotas.maxConnectionsPerListener(listeners("EXTERNAL").listenerName).validateReconfiguration(listenerConfig) + ) + } + + @Test + def testMaxListenerConnectionRateReconfiguration(): Unit = { + val config = KafkaConfig.fromProps(brokerPropsWithDefaultConnectionLimits) + connectionQuotas = new ConnectionQuotas(config, time, metrics) + connectionQuotas.addListener(config, listeners("EXTERNAL").listenerName) + + val listenerRateLimit = 20 + val listenerConfig = Map(KafkaConfig.MaxConnectionCreationRateProp -> listenerRateLimit.toString).asJava + connectionQuotas.maxConnectionsPerListener(listeners("EXTERNAL").listenerName).configure(listenerConfig) + + // remove connection rate limit + connectionQuotas.maxConnectionsPerListener(listeners("EXTERNAL").listenerName).reconfigure(Map.empty.asJava) + + // create connections as fast as possible, will timeout if connections get throttled with previous rate + // (50s to create 1000 connections) + executor.submit((() => + acceptConnections(connectionQuotas, listeners("EXTERNAL"), 1000)): Runnable + ).get(10, TimeUnit.SECONDS) + // verify no throttling + assertEquals(0, blockedPercentMeters("EXTERNAL").count(), + s"BlockedPercentMeter metric for EXTERNAL listener") + + // configure a 10 connection/second rate limit + val newMaxListenerConnectionRate = 10 + val newListenerConfig = Map(KafkaConfig.MaxConnectionCreationRateProp -> newMaxListenerConnectionRate.toString).asJava + connectionQuotas.maxConnectionsPerListener(listeners("EXTERNAL").listenerName).reconfigure(newListenerConfig) + + // verify rate limit + val connectionsPerListener = 200 // should take 20 seconds to create 200 connections with rate = 10/sec + executor.submit((() => + acceptConnectionsAndVerifyRate(connectionQuotas, listeners("EXTERNAL"), connectionsPerListener, 5, newMaxListenerConnectionRate, 3)): Runnable + ).get(30, TimeUnit.SECONDS) + assertTrue(blockedPercentMeters("EXTERNAL").count() > 0, + "Expected BlockedPercentMeter metric for EXTERNAL listener to be recorded") + } + + @Test + def testMaxBrokerConnectionRateReconfiguration(): Unit = { + val config = 
KafkaConfig.fromProps(brokerPropsWithDefaultConnectionLimits) + connectionQuotas = new ConnectionQuotas(config, time, metrics) + connectionQuotas.addListener(config, listeners("EXTERNAL").listenerName) + + addListenersAndVerify(config, connectionQuotas) + + val maxBrokerConnectionRate = 50 + connectionQuotas.updateBrokerMaxConnectionRate(maxBrokerConnectionRate) + + // create connections with rate = 200 conn/sec (5ms interval), so that connection rate gets throttled + val totalConnections = 400 + executor.submit((() => + // this is a short run, so setting epsilon higher (enough to check that the rate is not unlimited) + acceptConnectionsAndVerifyRate(connectionQuotas, listeners("EXTERNAL"), totalConnections, 5, maxBrokerConnectionRate, 20)): Runnable + ).get(10, TimeUnit.SECONDS) + assertTrue(blockedPercentMeters("EXTERNAL").count() > 0, + "Expected BlockedPercentMeter metric for EXTERNAL listener to be recorded") + } + + @Test + def testIpConnectionRateMetricUpdate(): Unit = { + val config = KafkaConfig.fromProps(brokerPropsWithDefaultConnectionLimits) + connectionQuotas = new ConnectionQuotas(config, time, metrics) + connectionQuotas.addListener(config, listeners("EXTERNAL").listenerName) + connectionQuotas.addListener(config, listeners("ADMIN").listenerName) + val defaultIpRate = 50 + val defaultOverrideRate = 20 + val overrideIpRate = 30 + val externalListener = listeners("EXTERNAL") + val adminListener = listeners("ADMIN") + // set a non-unlimited default quota so that we create ip rate sensors/metrics + connectionQuotas.updateIpConnectionRateQuota(None, Some(defaultIpRate)) + connectionQuotas.inc(externalListener.listenerName, externalListener.defaultIp, blockedPercentMeters("EXTERNAL")) + connectionQuotas.inc(adminListener.listenerName, adminListener.defaultIp, blockedPercentMeters("ADMIN")) + + // both IPs should have the default rate + verifyIpConnectionQuota(externalListener.defaultIp, defaultIpRate) + verifyIpConnectionQuota(adminListener.defaultIp, defaultIpRate) + + // external listener should have its in-memory quota and metric config updated + connectionQuotas.updateIpConnectionRateQuota(Some(externalListener.defaultIp), Some(overrideIpRate)) + verifyIpConnectionQuota(externalListener.defaultIp, overrideIpRate) + + // update default + connectionQuotas.updateIpConnectionRateQuota(None, Some(defaultOverrideRate)) + + // external listener IP should not have its quota updated to the new default + verifyIpConnectionQuota(externalListener.defaultIp, overrideIpRate) + // admin listener IP should have its quota updated to the new default + verifyIpConnectionQuota(adminListener.defaultIp, defaultOverrideRate) + + // remove default connection rate quota + connectionQuotas.updateIpConnectionRateQuota(None, None) + verifyIpConnectionQuota(adminListener.defaultIp, QuotaConfigs.IP_CONNECTION_RATE_DEFAULT) + verifyIpConnectionQuota(externalListener.defaultIp, overrideIpRate) + + // remove override for external listener IP + connectionQuotas.updateIpConnectionRateQuota(Some(externalListener.defaultIp), None) + verifyIpConnectionQuota(externalListener.defaultIp, QuotaConfigs.IP_CONNECTION_RATE_DEFAULT) + } + + @Test + def testEnforcedIpConnectionRateQuotaUpdate(): Unit = { + val ipConnectionRateLimit = 20 + val props = brokerPropsWithDefaultConnectionLimits + val config = KafkaConfig.fromProps(props) + // use MockTime for IP connection rate quota tests that don't expect to block + setupMockTime() + connectionQuotas = new ConnectionQuotas(config, time, metrics) + 
addListenersAndVerify(config, connectionQuotas) + val externalListener = listeners("EXTERNAL") + connectionQuotas.updateIpConnectionRateQuota(Some(externalListener.defaultIp), Some(ipConnectionRateLimit)) + // create connections with the rate > ip quota + val connectionRate = 40 + assertThrows(classOf[ConnectionThrottledException], + () => acceptConnections(connectionQuotas, externalListener, connectionRate) + ) + assertEquals(ipConnectionRateLimit, connectionQuotas.get(externalListener.defaultIp), + s"Number of connections on $externalListener:") + + // increase ip quota, we should accept connections up to the new quota limit + val updatedRateLimit = 30 + connectionQuotas.updateIpConnectionRateQuota(Some(externalListener.defaultIp), Some(updatedRateLimit)) + assertThrows(classOf[ConnectionThrottledException], + () => acceptConnections(connectionQuotas, externalListener, connectionRate) + ) + assertEquals(updatedRateLimit, connectionQuotas.get(externalListener.defaultIp), + s"Number of connections on $externalListener:") + + // remove IP quota, all connections should get accepted + connectionQuotas.updateIpConnectionRateQuota(Some(externalListener.defaultIp), None) + acceptConnections(connectionQuotas, externalListener, connectionRate) + assertEquals(connectionRate + updatedRateLimit, connectionQuotas.get(externalListener.defaultIp), + s"Number of connections on $externalListener:") + + // create connections on a different IP, + val adminListener = listeners("ADMIN") + acceptConnections(connectionQuotas, adminListener, connectionRate) + assertEquals(connectionRate, connectionQuotas.get(adminListener.defaultIp), + s"Number of connections on $adminListener:") + + // set a default IP quota, verify that quota gets propagated + connectionQuotas.updateIpConnectionRateQuota(None, Some(ipConnectionRateLimit)) + assertThrows(classOf[ConnectionThrottledException], + () => acceptConnections(connectionQuotas, adminListener, connectionRate) + ) + assertEquals(connectionRate + ipConnectionRateLimit, connectionQuotas.get(adminListener.defaultIp), + s"Number of connections on $adminListener:") + + // acceptor shouldn't block for IP rate throttling + verifyNoBlockedPercentRecordedOnAllListeners() + } + + @Test + def testNonDefaultConnectionCountLimitAndRateLimit(): Unit = { + val brokerRateLimit = 25 + val maxConnections = 350 // with rate == 25, will run out of connections in 14 seconds + val props = brokerPropsWithDefaultConnectionLimits + props.put(KafkaConfig.MaxConnectionsProp, maxConnections.toString) + props.put(KafkaConfig.MaxConnectionCreationRateProp, brokerRateLimit.toString) + val config = KafkaConfig.fromProps(props) + connectionQuotas = new ConnectionQuotas(config, time, metrics) + connectionQuotas.addListener(config, listeners("EXTERNAL").listenerName) + + addListenersAndVerify(config, connectionQuotas) + + // create connections with rate = 100 conn/sec (10ms interval), so that connection rate gets throttled + val listener = listeners("EXTERNAL") + executor.submit((() => + acceptConnectionsAndVerifyRate(connectionQuotas, listener, maxConnections, 10, brokerRateLimit, 8)): Runnable + ).get(20, TimeUnit.SECONDS) + assertTrue(blockedPercentMeters("EXTERNAL").count() > 0, + "Expected BlockedPercentMeter metric for EXTERNAL listener to be recorded") + assertEquals(maxConnections, connectionQuotas.get(listener.defaultIp), + s"Number of connections on EXTERNAL listener:") + + // adding one more connection will block ConnectionQuota.inc() + val future = executor.submit((() => + 
acceptConnections(connectionQuotas, listeners("EXTERNAL"), 1)): Runnable + ) + assertThrows(classOf[TimeoutException], () => future.get(100, TimeUnit.MILLISECONDS)) + + // removing one connection should make the waiting connection to succeed + connectionQuotas.dec(listener.listenerName, listener.defaultIp) + future.get(1, TimeUnit.SECONDS) + assertEquals(maxConnections, connectionQuotas.get(listener.defaultIp), + s"Number of connections on EXTERNAL listener:") + } + + private def addListenersAndVerify(config: KafkaConfig, connectionQuotas: ConnectionQuotas) : Unit = { + addListenersAndVerify(config, Map.empty.asJava, connectionQuotas) + } + + private def addListenersAndVerify(config: KafkaConfig, + listenerConfig: util.Map[String, _], + connectionQuotas: ConnectionQuotas) : Unit = { + assertNotNull(brokerConnRateMetric(), + "Expected broker-connection-accept-rate metric to exist") + + // add listeners and verify connection limits not exceeded + listeners.forKeyValue { (name, listener) => + val listenerName = listener.listenerName + connectionQuotas.addListener(config, listenerName) + connectionQuotas.maxConnectionsPerListener(listenerName).configure(listenerConfig) + assertFalse(connectionQuotas.maxConnectionsExceeded(listenerName), + s"Should not exceed max connection limit on $name listener after initialization") + assertEquals(0, connectionQuotas.get(listener.defaultIp), + s"Number of connections on $listener listener:") + assertNotNull(listenerConnRateMetric(listenerName.value), + s"Expected connection-accept-rate metric to exist for listener ${listenerName.value}") + assertEquals(0, metricValue(listenerConnRateMetric(listenerName.value)), eps, + s"Connection acceptance rate metric for listener ${listenerName.value}") + assertNotNull(listenerConnThrottleMetric(listenerName.value), + s"Expected connection-accept-throttle-time metric to exist for listener ${listenerName.value}") + assertEquals(0, metricValue(listenerConnThrottleMetric(listenerName.value)).toLong, + s"Listener connection throttle metric for listener ${listenerName.value}") + assertEquals(0, metricValue(ipConnThrottleMetric(listenerName.value)).toLong, + s"Ip connection throttle metric for listener ${listenerName.value}") + } + verifyNoBlockedPercentRecordedOnAllListeners() + assertEquals(0, metricValue(brokerConnRateMetric()), eps, + "Broker-wide connection acceptance rate metric") + } + + private def verifyNoBlockedPercentRecordedOnAllListeners(): Unit = { + blockedPercentMeters.forKeyValue { (name, meter) => + assertEquals(0, meter.count(), + s"BlockedPercentMeter metric for $name listener") + } + } + + private def verifyNonZeroBlockedPercentAndThrottleTimeOnAllListeners(): Unit = { + blockedPercentMeters.forKeyValue { (name, meter) => + assertTrue(meter.count() > 0, + s"Expected BlockedPercentMeter metric for $name listener to be recorded") + } + listeners.values.foreach { listener => + assertTrue(metricValue(listenerConnThrottleMetric(listener.listenerName.value)).toLong > 0, + s"Connection throttle metric for listener ${listener.listenerName.value}") + } + } + + private def verifyIpThrottleTimeOnListener(listener: ListenerName, expectThrottle: Boolean): Unit = { + assertEquals(expectThrottle, metricValue(ipConnThrottleMetric(listener.value)).toLong > 0, + s"IP connection throttle recorded for listener ${listener.value}") + } + + private def verifyOnlyNonInterBrokerListenersBlockedPercentRecorded(): Unit = { + blockedPercentMeters.forKeyValue { (name, meter) => + name match { + case "REPLICATION" => + assertEquals(0, 
meter.count(), s"BlockedPercentMeter metric for $name listener") + case _ => + assertTrue(meter.count() > 0, s"Expected BlockedPercentMeter metric for $name listener to be recorded") + } + } + } + + private def verifyConnectionCountOnEveryListener(connectionQuotas: ConnectionQuotas, expectedConnectionCount: Int): Unit = { + listeners.values.foreach { listener => + assertEquals(expectedConnectionCount, connectionQuotas.get(listener.defaultIp), + s"Number of connections on $listener:") + } + } + + private def listenerConnThrottleMetric(listener: String) : KafkaMetric = { + val metricName = metrics.metricName( + "connection-accept-throttle-time", + SocketServer.MetricsGroup, + Collections.singletonMap(Processor.ListenerMetricTag, listener)) + metrics.metric(metricName) + } + + private def ipConnThrottleMetric(listener: String): KafkaMetric = { + val metricName = metrics.metricName( + "ip-connection-accept-throttle-time", + SocketServer.MetricsGroup, + Collections.singletonMap(Processor.ListenerMetricTag, listener)) + metrics.metric(metricName) + } + + private def listenerConnRateMetric(listener: String) : KafkaMetric = { + val metricName = metrics.metricName( + "connection-accept-rate", + SocketServer.MetricsGroup, + Collections.singletonMap(Processor.ListenerMetricTag, listener)) + metrics.metric(metricName) + } + + private def brokerConnRateMetric() : KafkaMetric = { + val metricName = metrics.metricName( + s"broker-connection-accept-rate", + SocketServer.MetricsGroup) + metrics.metric(metricName) + } + + private def ipConnRateMetric(ip: String): KafkaMetric = { + val metricName = metrics.metricName( + s"connection-accept-rate", + SocketServer.MetricsGroup, + Collections.singletonMap("ip", ip)) + metrics.metric(metricName) + } + + private def metricValue(metric: KafkaMetric): Double = { + metric.metricValue.asInstanceOf[Double] + } + + private def verifyIpConnectionQuota(ip: InetAddress, quota: Int): Unit = { + // verify connection quota in-memory rate and metric + assertEquals(quota, connectionQuotas.connectionRateForIp(ip)) + Option(ipConnRateMetric(ip.getHostAddress)) match { + case Some(metric) => assertEquals(quota, metric.config.quota.bound, 0.1) + case None => fail(s"Expected $ip connection rate metric to be defined") + } + } + + // this method must be called on a separate thread, because connectionQuotas.inc() may block + private def acceptConnections(connectionQuotas: ConnectionQuotas, + listenerDesc: ListenerDesc, + numConnections: Long, + timeIntervalMs: Long = 0L, + expectIpThrottle: Boolean = false) : Unit = { + acceptConnections(connectionQuotas, listenerDesc.listenerName, listenerDesc.defaultIp, numConnections, + timeIntervalMs, expectIpThrottle) + } + + // this method must be called on a separate thread, because connectionQuotas.inc() may block + private def acceptConnectionsAndVerifyRate(connectionQuotas: ConnectionQuotas, + listenerDesc: ListenerDesc, + numConnections: Long, + timeIntervalMs: Long, + expectedRate: Int, + epsilon: Int, + expectIpThrottle: Boolean = false) : Unit = { + val startTimeMs = time.milliseconds + val startNumConnections = connectionQuotas.get(listenerDesc.defaultIp) + acceptConnections(connectionQuotas, listenerDesc.listenerName, listenerDesc.defaultIp, numConnections, + timeIntervalMs, expectIpThrottle) + val elapsedSeconds = MetricsUtils.convert(time.milliseconds - startTimeMs, TimeUnit.SECONDS) + val createdConnections = connectionQuotas.get(listenerDesc.defaultIp) - startNumConnections + val actualRate = createdConnections.toDouble / 
elapsedSeconds + assertEquals(expectedRate.toDouble, actualRate, epsilon, + s"Expected rate ($expectedRate +- $epsilon), but got $actualRate ($createdConnections connections / $elapsedSeconds sec)") + } + + /** + * This method will "create" connections every 'timeIntervalMs' which translates to 1000/timeIntervalMs connection rate, + * as long as the rate is below the connection rate limit. Otherwise, connections will be essentially created as + * fast as possible, which would result in the maximum connection creation rate. + * + * This method must be called on a separate thread, because connectionQuotas.inc() may block + */ + private def acceptConnections(connectionQuotas: ConnectionQuotas, + listenerName: ListenerName, + address: InetAddress, + numConnections: Long, + timeIntervalMs: Long, + expectIpThrottle: Boolean): Boolean = { + var nextSendTime = time.milliseconds + timeIntervalMs + var ipThrottled = false + for (_ <- 0L until numConnections) { + // this method may block if broker-wide or listener limit on the number of connections is reached + try { + connectionQuotas.inc(listenerName, address, blockedPercentMeters(listenerName.value)) + } catch { + case e: ConnectionThrottledException => + if (!expectIpThrottle) + throw e + ipThrottled = true + } + val sleepMs = math.max(nextSendTime - time.milliseconds, 0) + if (sleepMs > 0) + time.sleep(sleepMs) + + nextSendTime = nextSendTime + timeIntervalMs + } + ipThrottled + } + + // this method must be called on a separate thread, because connectionQuotas.inc() may block + private def acceptConnectionsAboveIpLimit(connectionQuotas: ConnectionQuotas, + listenerDesc: ListenerDesc, + numConnections: Long) : Unit = { + val listenerName = listenerDesc.listenerName + for (i <- 0L until numConnections) { + // this method may block if broker-wide or listener limit is reached + assertThrows(classOf[TooManyConnectionsException], + () => connectionQuotas.inc(listenerName, listenerDesc.defaultIp, blockedPercentMeters(listenerName.value)) + ) + } + } +} diff --git a/core/src/test/scala/unit/kafka/network/RequestChannelTest.scala b/core/src/test/scala/unit/kafka/network/RequestChannelTest.scala new file mode 100644 index 0000000..881ea65 --- /dev/null +++ b/core/src/test/scala/unit/kafka/network/RequestChannelTest.scala @@ -0,0 +1,315 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package kafka.network + + +import java.io.IOException +import java.net.InetAddress +import java.nio.ByteBuffer +import java.util.Collections +import com.fasterxml.jackson.databind.ObjectMapper +import kafka.network +import kafka.utils.TestUtils +import org.apache.kafka.clients.admin.AlterConfigOp.OpType +import org.apache.kafka.common.config.types.Password +import org.apache.kafka.common.config.{ConfigResource, SaslConfigs, SslConfigs, TopicConfig} +import org.apache.kafka.common.memory.MemoryPool +import org.apache.kafka.common.message.IncrementalAlterConfigsRequestData +import org.apache.kafka.common.message.IncrementalAlterConfigsRequestData._ +import org.apache.kafka.common.network.{ByteBufferSend, ClientInformation, ListenerName} +import org.apache.kafka.common.protocol.{ApiKeys, Errors} +import org.apache.kafka.common.requests.{AbstractRequest, MetadataRequest, RequestTestUtils} +import org.apache.kafka.common.requests.AlterConfigsRequest._ +import org.apache.kafka.common.requests._ +import org.apache.kafka.common.security.auth.{KafkaPrincipal, KafkaPrincipalSerde, SecurityProtocol} +import org.apache.kafka.common.utils.{SecurityUtils, Utils} +import org.easymock.EasyMock._ +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api._ +import org.mockito.{ArgumentCaptor, Mockito} + +import scala.collection.{Map, Seq} +import scala.jdk.CollectionConverters._ + +class RequestChannelTest { + private val requestChannelMetrics: RequestChannel.Metrics = mock(classOf[RequestChannel.Metrics]) + private val clientId = "id" + private val principalSerde = new KafkaPrincipalSerde() { + override def serialize(principal: KafkaPrincipal): Array[Byte] = Utils.utf8(principal.toString) + override def deserialize(bytes: Array[Byte]): KafkaPrincipal = SecurityUtils.parseKafkaPrincipal(Utils.utf8(bytes)) + } + private val mockSend: ByteBufferSend = Mockito.mock(classOf[ByteBufferSend]) + + @Test + def testAlterRequests(): Unit = { + + val sensitiveValue = "secret" + def verifyConfig(resource: ConfigResource, entries: Seq[ConfigEntry], expectedValues: Map[String, String]): Unit = { + val alterConfigs = request(new AlterConfigsRequest.Builder( + Collections.singletonMap(resource, new Config(entries.asJavaCollection)), true).build()) + + val loggableAlterConfigs = alterConfigs.loggableRequest.asInstanceOf[AlterConfigsRequest] + val loggedConfig = loggableAlterConfigs.configs.get(resource) + assertEquals(expectedValues, toMap(loggedConfig)) + val alterConfigsDesc = RequestConvertToJson.requestDesc(alterConfigs.header, alterConfigs.requestLog, alterConfigs.isForwarded).toString + assertFalse(alterConfigsDesc.contains(sensitiveValue), s"Sensitive config logged $alterConfigsDesc") + } + + val brokerResource = new ConfigResource(ConfigResource.Type.BROKER, "1") + val keystorePassword = new ConfigEntry(SslConfigs.SSL_KEYSTORE_PASSWORD_CONFIG, sensitiveValue) + verifyConfig(brokerResource, Seq(keystorePassword), Map(SslConfigs.SSL_KEYSTORE_PASSWORD_CONFIG -> Password.HIDDEN)) + + val keystoreLocation = new ConfigEntry(SslConfigs.SSL_KEYSTORE_LOCATION_CONFIG, "/path/to/keystore") + verifyConfig(brokerResource, Seq(keystoreLocation), Map(SslConfigs.SSL_KEYSTORE_LOCATION_CONFIG -> "/path/to/keystore")) + verifyConfig(brokerResource, Seq(keystoreLocation, keystorePassword), + Map(SslConfigs.SSL_KEYSTORE_LOCATION_CONFIG -> "/path/to/keystore", SslConfigs.SSL_KEYSTORE_PASSWORD_CONFIG -> Password.HIDDEN)) + + val listenerKeyPassword = new 
ConfigEntry(s"listener.name.internal.${SslConfigs.SSL_KEY_PASSWORD_CONFIG}", sensitiveValue) + verifyConfig(brokerResource, Seq(listenerKeyPassword), Map(listenerKeyPassword.name -> Password.HIDDEN)) + + val listenerKeystore = new ConfigEntry(s"listener.name.internal.${SslConfigs.SSL_KEYSTORE_LOCATION_CONFIG}", "/path/to/keystore") + verifyConfig(brokerResource, Seq(listenerKeystore), Map(listenerKeystore.name -> "/path/to/keystore")) + + val plainJaasConfig = new ConfigEntry(s"listener.name.internal.plain.${SaslConfigs.SASL_JAAS_CONFIG}", sensitiveValue) + verifyConfig(brokerResource, Seq(plainJaasConfig), Map(plainJaasConfig.name -> Password.HIDDEN)) + + val plainLoginCallback = new ConfigEntry(s"listener.name.internal.plain.${SaslConfigs.SASL_LOGIN_CALLBACK_HANDLER_CLASS}", "test.LoginClass") + verifyConfig(brokerResource, Seq(plainLoginCallback), Map(plainLoginCallback.name -> plainLoginCallback.value)) + + val customConfig = new ConfigEntry("custom.config", sensitiveValue) + verifyConfig(brokerResource, Seq(customConfig), Map(customConfig.name -> Password.HIDDEN)) + + val topicResource = new ConfigResource(ConfigResource.Type.TOPIC, "testTopic") + val compressionType = new ConfigEntry(TopicConfig.COMPRESSION_TYPE_CONFIG, "lz4") + verifyConfig(topicResource, Seq(compressionType), Map(TopicConfig.COMPRESSION_TYPE_CONFIG -> "lz4")) + verifyConfig(topicResource, Seq(customConfig), Map(customConfig.name -> Password.HIDDEN)) + + // Verify empty request + val alterConfigs = request(new AlterConfigsRequest.Builder( + Collections.emptyMap[ConfigResource, Config], true).build()) + assertEquals(Collections.emptyMap, alterConfigs.loggableRequest.asInstanceOf[AlterConfigsRequest].configs) + } + + @Test + def testIncrementalAlterRequests(): Unit = { + + def incrementalAlterConfigs(resource: ConfigResource, + entries: Map[String, String], op: OpType): IncrementalAlterConfigsRequest = { + val data = new IncrementalAlterConfigsRequestData() + val alterableConfigs = new AlterableConfigCollection() + entries.foreach { case (name, value) => + alterableConfigs.add(new AlterableConfig().setName(name).setValue(value).setConfigOperation(op.id)) + } + data.resources.add(new AlterConfigsResource() + .setResourceName(resource.name).setResourceType(resource.`type`.id) + .setConfigs(alterableConfigs)) + new IncrementalAlterConfigsRequest.Builder(data).build() + } + + val sensitiveValue = "secret" + def verifyConfig(resource: ConfigResource, + op: OpType, + entries: Map[String, String], + expectedValues: Map[String, String]): Unit = { + val alterConfigs = request(incrementalAlterConfigs(resource, entries, op)) + val loggableAlterConfigs = alterConfigs.loggableRequest.asInstanceOf[IncrementalAlterConfigsRequest] + val loggedConfig = loggableAlterConfigs.data.resources.find(resource.`type`.id, resource.name).configs + assertEquals(expectedValues, toMap(loggedConfig)) + val alterConfigsDesc = RequestConvertToJson.requestDesc(alterConfigs.header, alterConfigs.requestLog, alterConfigs.isForwarded).toString + assertFalse(alterConfigsDesc.contains(sensitiveValue), s"Sensitive config logged $alterConfigsDesc") + } + + val brokerResource = new ConfigResource(ConfigResource.Type.BROKER, "1") + val keystorePassword = Map(SslConfigs.SSL_KEYSTORE_PASSWORD_CONFIG -> sensitiveValue) + verifyConfig(brokerResource, OpType.SET, keystorePassword, Map(SslConfigs.SSL_KEYSTORE_PASSWORD_CONFIG -> Password.HIDDEN)) + + val keystoreLocation = Map(SslConfigs.SSL_KEYSTORE_LOCATION_CONFIG -> "/path/to/keystore") + 
verifyConfig(brokerResource, OpType.SET, keystoreLocation, keystoreLocation) + verifyConfig(brokerResource, OpType.SET, keystoreLocation ++ keystorePassword, + Map(SslConfigs.SSL_KEYSTORE_LOCATION_CONFIG -> "/path/to/keystore", SslConfigs.SSL_KEYSTORE_PASSWORD_CONFIG -> Password.HIDDEN)) + + val listenerKeyPassword = Map(s"listener.name.internal.${SslConfigs.SSL_KEY_PASSWORD_CONFIG}" -> sensitiveValue) + verifyConfig(brokerResource, OpType.SET, listenerKeyPassword, + Map(s"listener.name.internal.${SslConfigs.SSL_KEY_PASSWORD_CONFIG}" -> Password.HIDDEN)) + + val listenerKeystore = Map(s"listener.name.internal.${SslConfigs.SSL_KEYSTORE_LOCATION_CONFIG}" -> "/path/to/keystore") + verifyConfig(brokerResource, OpType.SET, listenerKeystore, listenerKeystore) + + val plainJaasConfig = Map(s"listener.name.internal.plain.${SaslConfigs.SASL_JAAS_CONFIG}" -> sensitiveValue) + verifyConfig(brokerResource, OpType.SET, plainJaasConfig, + Map(s"listener.name.internal.plain.${SaslConfigs.SASL_JAAS_CONFIG}" -> Password.HIDDEN)) + + val plainLoginCallback = Map(s"listener.name.internal.plain.${SaslConfigs.SASL_LOGIN_CALLBACK_HANDLER_CLASS}" -> "test.LoginClass") + verifyConfig(brokerResource, OpType.SET, plainLoginCallback, plainLoginCallback) + + val sslProtocols = Map(SslConfigs.SSL_ENABLED_PROTOCOLS_CONFIG -> "TLSv1.1") + verifyConfig(brokerResource, OpType.APPEND, sslProtocols, Map(SslConfigs.SSL_ENABLED_PROTOCOLS_CONFIG -> "TLSv1.1")) + verifyConfig(brokerResource, OpType.SUBTRACT, sslProtocols, Map(SslConfigs.SSL_ENABLED_PROTOCOLS_CONFIG -> "TLSv1.1")) + val cipherSuites = Map(SslConfigs.SSL_CIPHER_SUITES_CONFIG -> null) + verifyConfig(brokerResource, OpType.DELETE, cipherSuites, cipherSuites) + + val customConfig = Map("custom.config" -> sensitiveValue) + verifyConfig(brokerResource, OpType.SET, customConfig, Map("custom.config" -> Password.HIDDEN)) + + val topicResource = new ConfigResource(ConfigResource.Type.TOPIC, "testTopic") + val compressionType = Map(TopicConfig.COMPRESSION_TYPE_CONFIG -> "lz4") + verifyConfig(topicResource, OpType.SET, compressionType, compressionType) + verifyConfig(topicResource, OpType.SET, customConfig, Map("custom.config" -> Password.HIDDEN)) + } + + @Test + def testNonAlterRequestsNotTransformed(): Unit = { + val metadataRequest = request(new MetadataRequest.Builder(List("topic").asJava, true).build()) + assertSame(metadataRequest.body[MetadataRequest], metadataRequest.loggableRequest) + } + + @Test + def testJsonRequests(): Unit = { + val sensitiveValue = "secret" + val resource = new ConfigResource(ConfigResource.Type.BROKER, "1") + val keystorePassword = new ConfigEntry(SslConfigs.SSL_KEYSTORE_PASSWORD_CONFIG, sensitiveValue) + val entries = Seq(keystorePassword) + + val alterConfigs = request(new AlterConfigsRequest.Builder(Collections.singletonMap(resource, + new Config(entries.asJavaCollection)), true).build()) + + assertTrue(isValidJson(RequestConvertToJson.request(alterConfigs.loggableRequest).toString)) + } + + @Test + def testEnvelopeBuildResponseSendShouldReturnNoErrorIfInnerResponseHasNoError(): Unit = { + val channelRequest = buildForwardRequestWithEnvelopeRequestAttached(buildMetadataRequest()) + + val envelopeResponseArgumentCaptor = ArgumentCaptor.forClass(classOf[EnvelopeResponse]) + + Mockito.doAnswer(_ => mockSend) + .when(channelRequest.envelope.get.context).buildResponseSend(envelopeResponseArgumentCaptor.capture()) + + // create an inner response without error + val responseWithoutError = RequestTestUtils.metadataUpdateWith(2, 
Collections.singletonMap("a", 2)) + + // build an envelope response + channelRequest.buildResponseSend(responseWithoutError) + + // expect the envelopeResponse result without error + val capturedValue: EnvelopeResponse = envelopeResponseArgumentCaptor.getValue + assertTrue(capturedValue.error().equals(Errors.NONE)) + } + + @Test + def testEnvelopeBuildResponseSendShouldReturnNoErrorIfInnerResponseHasNoNotControllerError(): Unit = { + val channelRequest = buildForwardRequestWithEnvelopeRequestAttached(buildMetadataRequest()) + + val envelopeResponseArgumentCaptor = ArgumentCaptor.forClass(classOf[EnvelopeResponse]) + + Mockito.doAnswer(_ => mockSend) + .when(channelRequest.envelope.get.context).buildResponseSend(envelopeResponseArgumentCaptor.capture()) + + // create an inner response with REQUEST_TIMED_OUT error + val responseWithTimeoutError = RequestTestUtils.metadataUpdateWith("cluster1", 2, + Collections.singletonMap("a", Errors.REQUEST_TIMED_OUT), + Collections.singletonMap("a", 2)) + + // build an envelope response + channelRequest.buildResponseSend(responseWithTimeoutError) + + // expect the envelopeResponse result without error + val capturedValue: EnvelopeResponse = envelopeResponseArgumentCaptor.getValue + assertTrue(capturedValue.error().equals(Errors.NONE)) + } + + @Test + def testEnvelopeBuildResponseSendShouldReturnNotControllerErrorIfInnerResponseHasOne(): Unit = { + val channelRequest = buildForwardRequestWithEnvelopeRequestAttached(buildMetadataRequest()) + + val envelopeResponseArgumentCaptor = ArgumentCaptor.forClass(classOf[EnvelopeResponse]) + + Mockito.doAnswer(_ => mockSend) + .when(channelRequest.envelope.get.context).buildResponseSend(envelopeResponseArgumentCaptor.capture()) + + // create an inner response with NOT_CONTROLLER error + val responseWithNotControllerError = RequestTestUtils.metadataUpdateWith("cluster1", 2, + Collections.singletonMap("a", Errors.NOT_CONTROLLER), + Collections.singletonMap("a", 2)) + + // build an envelope response + channelRequest.buildResponseSend(responseWithNotControllerError) + + // expect the envelopeResponse result has NOT_CONTROLLER error + val capturedValue: EnvelopeResponse = envelopeResponseArgumentCaptor.getValue + assertTrue(capturedValue.error().equals(Errors.NOT_CONTROLLER)) + } + + private def buildMetadataRequest(): AbstractRequest = { + val resourceName = "topic-1" + val header = new RequestHeader(ApiKeys.METADATA, ApiKeys.METADATA.latestVersion, + clientId, 0) + + new MetadataRequest.Builder(Collections.singletonList(resourceName), true).build(header.apiVersion) + } + + private def buildForwardRequestWithEnvelopeRequestAttached(request: AbstractRequest): RequestChannel.Request = { + val envelopeRequest = TestUtils.buildRequestWithEnvelope( + request, principalSerde, requestChannelMetrics, System.nanoTime(), shouldSpyRequestContext = true) + + TestUtils.buildRequestWithEnvelope( + request, principalSerde, requestChannelMetrics, System.nanoTime(), envelope = Option(envelopeRequest)) + } + + private def isValidJson(str: String): Boolean = { + try { + val mapper = new ObjectMapper + mapper.readTree(str) + true + } catch { + case _: IOException => false + } + } + + def request(req: AbstractRequest): RequestChannel.Request = { + val buffer = req.serializeWithHeader(new RequestHeader(req.apiKey, req.version, "client-id", 1)) + val requestContext = newRequestContext(buffer) + new network.RequestChannel.Request(processor = 1, + requestContext, + startTimeNanos = 0, + createNiceMock(classOf[MemoryPool]), + buffer, + 
createNiceMock(classOf[RequestChannel.Metrics]) + ) + } + + private def newRequestContext(buffer: ByteBuffer): RequestContext = { + new RequestContext( + RequestHeader.parse(buffer), + "connection-id", + InetAddress.getLoopbackAddress, + new KafkaPrincipal(KafkaPrincipal.USER_TYPE, "user"), + ListenerName.forSecurityProtocol(SecurityProtocol.PLAINTEXT), + SecurityProtocol.PLAINTEXT, + new ClientInformation("name", "version"), + false) + } + + private def toMap(config: Config): Map[String, String] = { + config.entries.asScala.map(e => e.name -> e.value).toMap + } + + private def toMap(config: IncrementalAlterConfigsRequestData.AlterableConfigCollection): Map[String, String] = { + config.asScala.map(e => e.name -> e.value).toMap + } +} diff --git a/core/src/test/scala/unit/kafka/network/RequestConvertToJsonTest.scala b/core/src/test/scala/unit/kafka/network/RequestConvertToJsonTest.scala new file mode 100644 index 0000000..9b8db57 --- /dev/null +++ b/core/src/test/scala/unit/kafka/network/RequestConvertToJsonTest.scala @@ -0,0 +1,188 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package kafka.network + +import java.net.InetAddress +import java.nio.ByteBuffer + +import com.fasterxml.jackson.databind.node.{BooleanNode, DoubleNode, JsonNodeFactory, LongNode, ObjectNode, TextNode} +import kafka.network +import kafka.network.RequestConvertToJson.requestHeaderNode +import org.apache.kafka.common.memory.MemoryPool +import org.apache.kafka.common.message._ +import org.apache.kafka.common.network.{ClientInformation, ListenerName, NetworkSend} +import org.apache.kafka.common.protocol.{ApiKeys, MessageUtil} +import org.apache.kafka.common.requests._ +import org.apache.kafka.common.security.auth.{KafkaPrincipal, SecurityProtocol} +import org.easymock.EasyMock.createNiceMock +import org.junit.jupiter.api.Assertions.assertEquals +import org.junit.jupiter.api.Test + +import scala.collection.mutable.ArrayBuffer + +class RequestConvertToJsonTest { + + @Test + def testAllRequestTypesHandled(): Unit = { + val unhandledKeys = ArrayBuffer[String]() + ApiKeys.values().foreach { key => { + val version: Short = key.latestVersion() + val message = key match { + case ApiKeys.DESCRIBE_ACLS => + ApiMessageType.fromApiKey(key.id).newRequest().asInstanceOf[DescribeAclsRequestData] + .setPatternTypeFilter(1).setResourceTypeFilter(1).setPermissionType(1).setOperation(1) + case _ => + ApiMessageType.fromApiKey(key.id).newRequest() + } + + val bytes = MessageUtil.toByteBuffer(message, version) + val req = AbstractRequest.parseRequest(key, version, bytes).request + try { + RequestConvertToJson.request(req) + } catch { + case _ : IllegalStateException => unhandledKeys += key.toString + } + }} + assertEquals(ArrayBuffer.empty, unhandledKeys, "Unhandled request keys") + } + + @Test + def testAllResponseTypesHandled(): Unit = { + val unhandledKeys = ArrayBuffer[String]() + ApiKeys.values().foreach { key => { + val version: Short = key.latestVersion() + val message = ApiMessageType.fromApiKey(key.id).newResponse() + val bytes = MessageUtil.toByteBuffer(message, version) + val res = AbstractResponse.parseResponse(key, bytes, version) + try { + RequestConvertToJson.response(res, version) + } catch { + case _ : IllegalStateException => unhandledKeys += key.toString + } + }} + assertEquals(ArrayBuffer.empty, unhandledKeys, "Unhandled response keys") + } + + @Test + def testRequestHeaderNode(): Unit = { + val alterIsrRequest = new AlterIsrRequest(new AlterIsrRequestData(), 0) + val req = request(alterIsrRequest) + val header = req.header + + val expectedNode = RequestHeaderDataJsonConverter.write(header.data, header.headerVersion, false).asInstanceOf[ObjectNode] + expectedNode.set("requestApiKeyName", new TextNode(header.apiKey.toString)) + + val actualNode = RequestConvertToJson.requestHeaderNode(header) + + assertEquals(expectedNode, actualNode); + } + + @Test + def testClientInfoNode(): Unit = { + val clientInfo = new ClientInformation("name", "1") + + val expectedNode = new ObjectNode(JsonNodeFactory.instance) + expectedNode.set("softwareName", new TextNode(clientInfo.softwareName)) + expectedNode.set("softwareVersion", new TextNode(clientInfo.softwareVersion)) + + val actualNode = RequestConvertToJson.clientInfoNode(clientInfo) + + assertEquals(expectedNode, actualNode) + } + + @Test + def testRequestDesc(): Unit = { + val alterIsrRequest = new AlterIsrRequest(new AlterIsrRequestData(), 0) + val req = request(alterIsrRequest) + + val expectedNode = new ObjectNode(JsonNodeFactory.instance) + expectedNode.set("isForwarded", if (req.isForwarded) BooleanNode.TRUE else BooleanNode.FALSE) + 
expectedNode.set("requestHeader", requestHeaderNode(req.header)) + expectedNode.set("request", req.requestLog.getOrElse(new TextNode(""))) + + val actualNode = RequestConvertToJson.requestDesc(req.header, req.requestLog, req.isForwarded) + + assertEquals(expectedNode, actualNode) + } + + @Test + def testRequestDescMetrics(): Unit = { + val alterIsrRequest = new AlterIsrRequest(new AlterIsrRequestData(), 0) + val req = request(alterIsrRequest) + val send = new NetworkSend(req.context.connectionId, alterIsrRequest.toSend(req.header)) + val headerLog = RequestConvertToJson.requestHeaderNode(req.header) + val res = new RequestChannel.SendResponse(req, send, Some(headerLog), None) + + val totalTimeMs = 1 + val requestQueueTimeMs = 2 + val apiLocalTimeMs = 3 + val apiRemoteTimeMs = 4 + val apiThrottleTimeMs = 5 + val responseQueueTimeMs = 6 + val responseSendTimeMs = 7 + val temporaryMemoryBytes = 8 + val messageConversionsTimeMs = 9 + + val expectedNode = RequestConvertToJson.requestDesc(req.header, req.requestLog, req.isForwarded).asInstanceOf[ObjectNode] + expectedNode.set("response", res.responseLog.getOrElse(new TextNode(""))) + expectedNode.set("connection", new TextNode(req.context.connectionId)) + expectedNode.set("totalTimeMs", new DoubleNode(totalTimeMs)) + expectedNode.set("requestQueueTimeMs", new DoubleNode(requestQueueTimeMs)) + expectedNode.set("localTimeMs", new DoubleNode(apiLocalTimeMs)) + expectedNode.set("remoteTimeMs", new DoubleNode(apiRemoteTimeMs)) + expectedNode.set("throttleTimeMs", new LongNode(apiThrottleTimeMs)) + expectedNode.set("responseQueueTimeMs", new DoubleNode(responseQueueTimeMs)) + expectedNode.set("sendTimeMs", new DoubleNode(responseSendTimeMs)) + expectedNode.set("securityProtocol", new TextNode(req.context.securityProtocol.toString)) + expectedNode.set("principal", new TextNode(req.session.principal.toString)) + expectedNode.set("listener", new TextNode(req.context.listenerName.value)) + expectedNode.set("clientInformation", RequestConvertToJson.clientInfoNode(req.context.clientInformation)) + expectedNode.set("temporaryMemoryBytes", new LongNode(temporaryMemoryBytes)) + expectedNode.set("messageConversionsTime", new DoubleNode(messageConversionsTimeMs)) + + val actualNode = RequestConvertToJson.requestDescMetrics(req.header, req.requestLog, res.responseLog, req.context, req.session, req.isForwarded, + totalTimeMs, requestQueueTimeMs, apiLocalTimeMs, apiRemoteTimeMs, apiThrottleTimeMs, responseQueueTimeMs, + responseSendTimeMs, temporaryMemoryBytes, messageConversionsTimeMs).asInstanceOf[ObjectNode] + + assertEquals(expectedNode, actualNode) + } + + def request(req: AbstractRequest): RequestChannel.Request = { + val buffer = req.serializeWithHeader(new RequestHeader(req.apiKey, req.version, "client-id", 1)) + val requestContext = newRequestContext(buffer) + new network.RequestChannel.Request(processor = 1, + requestContext, + startTimeNanos = 0, + createNiceMock(classOf[MemoryPool]), + buffer, + createNiceMock(classOf[RequestChannel.Metrics]) + ) + } + + private def newRequestContext(buffer: ByteBuffer): RequestContext = { + new RequestContext( + RequestHeader.parse(buffer), + "connection-id", + InetAddress.getLoopbackAddress, + new KafkaPrincipal(KafkaPrincipal.USER_TYPE, "user"), + ListenerName.forSecurityProtocol(SecurityProtocol.PLAINTEXT), + SecurityProtocol.PLAINTEXT, + new ClientInformation("name", "version"), + false) + } +} diff --git a/core/src/test/scala/unit/kafka/network/SocketServerTest.scala 
b/core/src/test/scala/unit/kafka/network/SocketServerTest.scala new file mode 100644 index 0000000..af5631f --- /dev/null +++ b/core/src/test/scala/unit/kafka/network/SocketServerTest.scala @@ -0,0 +1,2249 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.network + +import java.io._ +import java.net._ +import java.nio.ByteBuffer +import java.nio.channels.{SelectionKey, SocketChannel} +import java.nio.charset.StandardCharsets +import java.util +import java.util.concurrent.{CompletableFuture, ConcurrentLinkedQueue, Executors, TimeUnit} +import java.util.{Properties, Random} + +import com.fasterxml.jackson.databind.node.{JsonNodeFactory, ObjectNode, TextNode} +import com.yammer.metrics.core.{Gauge, Meter} +import javax.net.ssl._ +import kafka.metrics.KafkaYammerMetrics +import kafka.security.CredentialProvider +import kafka.server.{KafkaConfig, SimpleApiVersionManager, ThrottleCallback, ThrottledChannel} +import kafka.utils.Implicits._ +import kafka.utils.TestUtils +import org.apache.kafka.common.memory.MemoryPool +import org.apache.kafka.common.message.ApiMessageType.ListenerType +import org.apache.kafka.common.message.{ProduceRequestData, SaslAuthenticateRequestData, SaslHandshakeRequestData, VoteRequestData} +import org.apache.kafka.common.metrics.Metrics +import org.apache.kafka.common.network.KafkaChannel.ChannelMuteState +import org.apache.kafka.common.network.{ClientInformation, _} +import org.apache.kafka.common.protocol.{ApiKeys, Errors} +import org.apache.kafka.common.requests +import org.apache.kafka.common.requests._ +import org.apache.kafka.common.security.auth.{KafkaPrincipal, SecurityProtocol} +import org.apache.kafka.common.security.scram.internals.ScramMechanism +import org.apache.kafka.common.utils.{AppInfoParser, LogContext, MockTime, Time, Utils} +import org.apache.kafka.test.{TestSslUtils, TestUtils => JTestUtils} +import org.apache.log4j.Level +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api._ + +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer +import scala.jdk.CollectionConverters._ +import scala.util.control.ControlThrowable + +class SocketServerTest { + val props = TestUtils.createBrokerConfig(0, TestUtils.MockZkConnect, port = 0) + props.put("listeners", "PLAINTEXT://localhost:0") + props.put("num.network.threads", "1") + props.put("socket.send.buffer.bytes", "300000") + props.put("socket.receive.buffer.bytes", "300000") + props.put("queued.max.requests", "50") + props.put("socket.request.max.bytes", "100") + props.put("max.connections.per.ip", "5") + props.put("connections.max.idle.ms", "60000") + val config = KafkaConfig.fromProps(props) + val metrics = new Metrics + val credentialProvider = new CredentialProvider(ScramMechanism.mechanismNames, 
null) + val localAddress = InetAddress.getLoopbackAddress + + // Clean-up any metrics left around by previous tests + TestUtils.clearYammerMetrics() + + private val apiVersionManager = new SimpleApiVersionManager(ListenerType.ZK_BROKER) + val server = new SocketServer(config, metrics, Time.SYSTEM, credentialProvider, apiVersionManager) + server.startup() + val sockets = new ArrayBuffer[Socket] + + private val kafkaLogger = org.apache.log4j.LogManager.getLogger("kafka") + private var logLevelToRestore: Level = _ + + @BeforeEach + def setUp(): Unit = { + // Run the tests with TRACE logging to exercise request logging path + logLevelToRestore = kafkaLogger.getLevel + kafkaLogger.setLevel(Level.TRACE) + + assertTrue(server.controlPlaneRequestChannelOpt.isEmpty) + } + + @AfterEach + def tearDown(): Unit = { + shutdownServerAndMetrics(server) + sockets.foreach(_.close()) + sockets.clear() + kafkaLogger.setLevel(logLevelToRestore) + } + + def sendRequest(socket: Socket, request: Array[Byte], id: Option[Short] = None, flush: Boolean = true): Unit = { + val outgoing = new DataOutputStream(socket.getOutputStream) + id match { + case Some(id) => + outgoing.writeInt(request.length + 2) + outgoing.writeShort(id) + case None => + outgoing.writeInt(request.length) + } + outgoing.write(request) + if (flush) + outgoing.flush() + } + + def sendApiRequest(socket: Socket, request: AbstractRequest, header: RequestHeader): Unit = { + val serializedBytes = Utils.toArray(request.serializeWithHeader(header)) + sendRequest(socket, serializedBytes) + } + + def receiveResponse(socket: Socket): Array[Byte] = { + val incoming = new DataInputStream(socket.getInputStream) + val len = incoming.readInt() + val response = new Array[Byte](len) + incoming.readFully(response) + response + } + + private def receiveRequest(channel: RequestChannel, timeout: Long = 2000L): RequestChannel.Request = { + channel.receiveRequest(timeout) match { + case request: RequestChannel.Request => request + case RequestChannel.ShutdownRequest => throw new AssertionError("Unexpected shutdown received") + case null => throw new AssertionError("receiveRequest timed out") + } + } + + /* A simple request handler that just echos back the response */ + def processRequest(channel: RequestChannel): Unit = { + processRequest(channel, receiveRequest(channel)) + } + + def processRequest(channel: RequestChannel, request: RequestChannel.Request): Unit = { + val byteBuffer = request.body[AbstractRequest].serializeWithHeader(request.header) + val send = new NetworkSend(request.context.connectionId, ByteBufferSend.sizePrefixed(byteBuffer)) + val headerLog = RequestConvertToJson.requestHeaderNode(request.header) + channel.sendResponse(new RequestChannel.SendResponse(request, send, Some(headerLog), None)) + } + + def processRequestNoOpResponse(channel: RequestChannel, request: RequestChannel.Request): Unit = { + channel.sendNoOpResponse(request) + } + + def connect(s: SocketServer = server, + listenerName: ListenerName = ListenerName.forSecurityProtocol(SecurityProtocol.PLAINTEXT), + localAddr: InetAddress = null, + port: Int = 0): Socket = { + val socket = new Socket("localhost", s.boundPort(listenerName), localAddr, port) + sockets += socket + socket + } + + def sslConnect(s: SocketServer = server): Socket = { + val socket = sslClientSocket(s.boundPort(ListenerName.forSecurityProtocol(SecurityProtocol.SSL))) + sockets += socket + socket + } + + private def sslClientSocket(port: Int): Socket = { + val sslContext = 
SSLContext.getInstance(TestSslUtils.DEFAULT_TLS_PROTOCOL_FOR_TESTS) + sslContext.init(null, Array(TestUtils.trustAllCerts), new java.security.SecureRandom()) + val socketFactory = sslContext.getSocketFactory + val socket = socketFactory.createSocket("localhost", port) + socket.asInstanceOf[SSLSocket].setNeedClientAuth(false) + socket + } + + // Create a client connection, process one request and return (client socket, connectionId) + def connectAndProcessRequest(s: SocketServer): (Socket, String) = { + val securityProtocol = s.dataPlaneAcceptors.asScala.head._1.securityProtocol + val socket = securityProtocol match { + case SecurityProtocol.PLAINTEXT | SecurityProtocol.SASL_PLAINTEXT => + connect(s) + case SecurityProtocol.SSL | SecurityProtocol.SASL_SSL => + sslConnect(s) + case _ => + throw new IllegalStateException(s"Unexpected security protocol $securityProtocol") + } + val request = sendAndReceiveRequest(socket, s) + processRequest(s.dataPlaneRequestChannel, request) + (socket, request.context.connectionId) + } + + def sendAndReceiveRequest(socket: Socket, server: SocketServer): RequestChannel.Request = { + sendRequest(socket, producerRequestBytes()) + receiveRequest(server.dataPlaneRequestChannel) + } + + def shutdownServerAndMetrics(server: SocketServer): Unit = { + server.shutdown() + server.metrics.close() + } + + private def producerRequestBytes(ack: Short = 0): Array[Byte] = { + val correlationId = -1 + val clientId = "" + val ackTimeoutMs = 10000 + + val emptyRequest = requests.ProduceRequest.forCurrentMagic(new ProduceRequestData() + .setTopicData(new ProduceRequestData.TopicProduceDataCollection()) + .setAcks(ack) + .setTimeoutMs(ackTimeoutMs) + .setTransactionalId(null)) + .build() + val emptyHeader = new RequestHeader(ApiKeys.PRODUCE, emptyRequest.version, clientId, correlationId) + Utils.toArray(emptyRequest.serializeWithHeader(emptyHeader)) + } + + private def apiVersionRequestBytes(clientId: String, version: Short): Array[Byte] = { + val request = new ApiVersionsRequest.Builder().build(version) + val header = new RequestHeader(ApiKeys.API_VERSIONS, request.version(), clientId, -1) + Utils.toArray(request.serializeWithHeader(header)) + } + + @Test + def simpleRequest(): Unit = { + val plainSocket = connect() + val serializedBytes = producerRequestBytes() + + // Test PLAINTEXT socket + sendRequest(plainSocket, serializedBytes) + processRequest(server.dataPlaneRequestChannel) + assertEquals(serializedBytes.toSeq, receiveResponse(plainSocket).toSeq) + verifyAcceptorBlockedPercent("PLAINTEXT", expectBlocked = false) + } + + + private def testClientInformation(version: Short, expectedClientSoftwareName: String, + expectedClientSoftwareVersion: String): Unit = { + val plainSocket = connect() + val address = plainSocket.getLocalAddress + val clientId = "clientId" + + // Send ApiVersionsRequest - unknown expected + sendRequest(plainSocket, apiVersionRequestBytes(clientId, version)) + var receivedReq = receiveRequest(server.dataPlaneRequestChannel) + + assertEquals(ClientInformation.UNKNOWN_NAME_OR_VERSION, receivedReq.context.clientInformation.softwareName) + assertEquals(ClientInformation.UNKNOWN_NAME_OR_VERSION, receivedReq.context.clientInformation.softwareVersion) + + server.dataPlaneRequestChannel.sendNoOpResponse(receivedReq) + + // Send ProduceRequest - client info expected + sendRequest(plainSocket, producerRequestBytes()) + receivedReq = receiveRequest(server.dataPlaneRequestChannel) + + assertEquals(expectedClientSoftwareName, 
receivedReq.context.clientInformation.softwareName) + assertEquals(expectedClientSoftwareVersion, receivedReq.context.clientInformation.softwareVersion) + + server.dataPlaneRequestChannel.sendNoOpResponse(receivedReq) + + // Close the socket + plainSocket.setSoLinger(true, 0) + plainSocket.close() + + TestUtils.waitUntilTrue(() => server.connectionCount(address) == 0, msg = "Connection not closed") + } + + @Test + def testClientInformationWithLatestApiVersionsRequest(): Unit = { + testClientInformation( + ApiKeys.API_VERSIONS.latestVersion, + "apache-kafka-java", + AppInfoParser.getVersion + ) + } + + @Test + def testClientInformationWithOldestApiVersionsRequest(): Unit = { + testClientInformation( + ApiKeys.API_VERSIONS.oldestVersion, + ClientInformation.UNKNOWN_NAME_OR_VERSION, + ClientInformation.UNKNOWN_NAME_OR_VERSION + ) + } + + @Test + def testStagedListenerStartup(): Unit = { + val testProps = new Properties + testProps ++= props + testProps.put("listeners", "EXTERNAL://localhost:0,INTERNAL://localhost:0,CONTROLLER://localhost:0") + testProps.put("listener.security.protocol.map", "EXTERNAL:PLAINTEXT,INTERNAL:PLAINTEXT,CONTROLLER:PLAINTEXT") + testProps.put("control.plane.listener.name", "CONTROLLER") + testProps.put("inter.broker.listener.name", "INTERNAL") + val config = KafkaConfig.fromProps(testProps) + val testableServer = new TestableSocketServer(config) + testableServer.startup(startProcessingRequests = false) + + val updatedEndPoints = config.effectiveAdvertisedListeners.map { endpoint => + endpoint.copy(port = testableServer.boundPort(endpoint.listenerName)) + }.map(_.toJava) + + val externalReadyFuture = new CompletableFuture[Void]() + val executor = Executors.newSingleThreadExecutor() + + def controlPlaneListenerStarted() = { + try { + val socket = connect(testableServer, config.controlPlaneListenerName.get, localAddr = InetAddress.getLocalHost) + sendAndReceiveControllerRequest(socket, testableServer) + true + } catch { + case _: Throwable => false + } + } + + def listenerStarted(listenerName: ListenerName) = { + try { + val socket = connect(testableServer, listenerName, localAddr = InetAddress.getLocalHost) + sendAndReceiveRequest(socket, testableServer) + true + } catch { + case _: Throwable => false + } + } + + try { + val externalListener = new ListenerName("EXTERNAL") + val externalEndpoint = updatedEndPoints.find(e => e.listenerName.get == externalListener.value).get + val futures = Map(externalEndpoint -> externalReadyFuture) + val startFuture = executor.submit((() => testableServer.startProcessingRequests(futures)): Runnable) + TestUtils.waitUntilTrue(() => controlPlaneListenerStarted(), "Control plane listener not started") + TestUtils.waitUntilTrue(() => listenerStarted(config.interBrokerListenerName), "Inter-broker listener not started") + assertFalse(startFuture.isDone, "Socket server startup did not wait for future to complete") + + assertFalse(listenerStarted(externalListener)) + + externalReadyFuture.complete(null) + TestUtils.waitUntilTrue(() => listenerStarted(externalListener), "External listener not started") + } finally { + executor.shutdownNow() + shutdownServerAndMetrics(testableServer) + } + } + + @Test + def testStagedListenerShutdownWhenConnectionQueueIsFull(): Unit = { + val testProps = new Properties + testProps ++= props + testProps.put("listeners", "EXTERNAL://localhost:0,INTERNAL://localhost:0,CONTROLLER://localhost:0") + testProps.put("listener.security.protocol.map", "EXTERNAL:PLAINTEXT,INTERNAL:PLAINTEXT,CONTROLLER:PLAINTEXT") + 
testProps.put("control.plane.listener.name", "CONTROLLER") + testProps.put("inter.broker.listener.name", "INTERNAL") + val config = KafkaConfig.fromProps(testProps) + val connectionQueueSize = 1 + val testableServer = new TestableSocketServer(config, connectionQueueSize) + testableServer.startup(startProcessingRequests = false) + + val socket1 = connect(testableServer, new ListenerName("EXTERNAL"), localAddr = InetAddress.getLocalHost) + sendRequest(socket1, producerRequestBytes()) + val socket2 = connect(testableServer, new ListenerName("EXTERNAL"), localAddr = InetAddress.getLocalHost) + sendRequest(socket2, producerRequestBytes()) + + testableServer.shutdown() + } + + @Test + def testDisabledRequestIsRejected(): Unit = { + val correlationId = 57 + val header = new RequestHeader(ApiKeys.VOTE, 0, "", correlationId) + val request = new VoteRequest.Builder(new VoteRequestData()).build() + val serializedBytes = Utils.toArray(request.serializeWithHeader(header)) + + val socket = connect() + + val outgoing = new DataOutputStream(socket.getOutputStream) + try { + outgoing.writeInt(serializedBytes.length) + outgoing.write(serializedBytes) + outgoing.flush() + receiveResponse(socket) + } catch { + case _: IOException => // we expect the server to close the socket + } finally { + outgoing.close() + } + } + + @Test + def tooBigRequestIsRejected(): Unit = { + val tooManyBytes = new Array[Byte](server.config.socketRequestMaxBytes + 1) + new Random().nextBytes(tooManyBytes) + val socket = connect() + val outgoing = new DataOutputStream(socket.getOutputStream) + outgoing.writeInt(tooManyBytes.length) + try { + // Server closes client connection when it processes the request length because + // it is too big. The write of request body may fail if the connection has been closed. 
+ outgoing.write(tooManyBytes) + outgoing.flush() + receiveResponse(socket) + } catch { + case _: IOException => // thats fine + } + } + + @Test + def testGracefulClose(): Unit = { + val plainSocket = connect() + val serializedBytes = producerRequestBytes() + + for (_ <- 0 until 10) + sendRequest(plainSocket, serializedBytes) + plainSocket.close() + for (_ <- 0 until 10) { + val request = receiveRequest(server.dataPlaneRequestChannel) + assertNotNull(request, "receiveRequest timed out") + processRequestNoOpResponse(server.dataPlaneRequestChannel, request) + } + } + + @Test + def testNoOpAction(): Unit = { + val plainSocket = connect() + val serializedBytes = producerRequestBytes() + + for (_ <- 0 until 3) + sendRequest(plainSocket, serializedBytes) + for (_ <- 0 until 3) { + val request = receiveRequest(server.dataPlaneRequestChannel) + assertNotNull(request, "receiveRequest timed out") + processRequestNoOpResponse(server.dataPlaneRequestChannel, request) + } + } + + @Test + def testConnectionId(): Unit = { + val sockets = (1 to 5).map(_ => connect()) + val serializedBytes = producerRequestBytes() + + val requests = sockets.map{socket => + sendRequest(socket, serializedBytes) + receiveRequest(server.dataPlaneRequestChannel) + } + requests.zipWithIndex.foreach { case (request, i) => + val index = request.context.connectionId.split("-").last + assertEquals(i.toString, index) + } + + sockets.foreach(_.close) + } + + @Test + def testIdleConnection(): Unit = { + val idleTimeMs = 60000 + val time = new MockTime() + props.put(KafkaConfig.ConnectionsMaxIdleMsProp, idleTimeMs.toString) + val serverMetrics = new Metrics + val overrideServer = new SocketServer(KafkaConfig.fromProps(props), serverMetrics, + time, credentialProvider, apiVersionManager) + + try { + overrideServer.startup() + val serializedBytes = producerRequestBytes() + + // Connection with no outstanding requests + val socket0 = connect(overrideServer) + sendRequest(socket0, serializedBytes) + val request0 = receiveRequest(overrideServer.dataPlaneRequestChannel) + processRequest(overrideServer.dataPlaneRequestChannel, request0) + assertTrue(openChannel(request0, overrideServer).nonEmpty, "Channel not open") + assertEquals(openChannel(request0, overrideServer), openOrClosingChannel(request0, overrideServer)) + TestUtils.waitUntilTrue(() => !openChannel(request0, overrideServer).get.isMuted, "Failed to unmute channel") + time.sleep(idleTimeMs + 1) + TestUtils.waitUntilTrue(() => openOrClosingChannel(request0, overrideServer).isEmpty, "Failed to close idle channel") + assertTrue(openChannel(request0, overrideServer).isEmpty, "Channel not removed") + + // Connection with one request being processed (channel is muted), no other in-flight requests + val socket1 = connect(overrideServer) + sendRequest(socket1, serializedBytes) + val request1 = receiveRequest(overrideServer.dataPlaneRequestChannel) + assertTrue(openChannel(request1, overrideServer).nonEmpty, "Channel not open") + assertEquals(openChannel(request1, overrideServer), openOrClosingChannel(request1, overrideServer)) + time.sleep(idleTimeMs + 1) + TestUtils.waitUntilTrue(() => openOrClosingChannel(request1, overrideServer).isEmpty, "Failed to close idle channel") + assertTrue(openChannel(request1, overrideServer).isEmpty, "Channel not removed") + processRequest(overrideServer.dataPlaneRequestChannel, request1) + + // Connection with one request being processed (channel is muted), more in-flight requests + val socket2 = connect(overrideServer) + val request2 = 
sendRequestsReceiveOne(overrideServer, socket2, serializedBytes, 3) + time.sleep(idleTimeMs + 1) + TestUtils.waitUntilTrue(() => openOrClosingChannel(request2, overrideServer).isEmpty, "Failed to close idle channel") + assertTrue(openChannel(request1, overrideServer).isEmpty, "Channel not removed") + processRequest(overrideServer.dataPlaneRequestChannel, request2) // this triggers a failed send since channel has been closed + assertNull(overrideServer.dataPlaneRequestChannel.receiveRequest(200), "Received request on expired channel") + + } finally { + shutdownServerAndMetrics(overrideServer) + } + } + + @Test + def testConnectionIdReuse(): Unit = { + val idleTimeMs = 60000 + val time = new MockTime() + props.put(KafkaConfig.ConnectionsMaxIdleMsProp, idleTimeMs.toString) + props ++= sslServerProps + val serverMetrics = new Metrics + @volatile var selector: TestableSelector = null + val overrideConnectionId = "127.0.0.1:1-127.0.0.1:2-0" + val overrideServer = new SocketServer( + KafkaConfig.fromProps(props), serverMetrics, time, credentialProvider, apiVersionManager + ) { + override def newProcessor(id: Int, requestChannel: RequestChannel, connectionQuotas: ConnectionQuotas, listenerName: ListenerName, + protocol: SecurityProtocol, memoryPool: MemoryPool, isPrivilegedListener: Boolean): Processor = { + new Processor(id, time, config.socketRequestMaxBytes, dataPlaneRequestChannel, connectionQuotas, + config.connectionsMaxIdleMs, config.failedAuthenticationDelayMs, listenerName, protocol, config, metrics, + credentialProvider, memoryPool, new LogContext(), Processor.ConnectionQueueSize, isPrivilegedListener, apiVersionManager) { + override protected[network] def connectionId(socket: Socket): String = overrideConnectionId + override protected[network] def createSelector(channelBuilder: ChannelBuilder): Selector = { + val testableSelector = new TestableSelector(config, channelBuilder, time, metrics) + selector = testableSelector + testableSelector + } + } + } + } + + def openChannel: Option[KafkaChannel] = overrideServer.dataPlaneProcessor(0).channel(overrideConnectionId) + def openOrClosingChannel: Option[KafkaChannel] = overrideServer.dataPlaneProcessor(0).openOrClosingChannel(overrideConnectionId) + def connectionCount = overrideServer.connectionCount(InetAddress.getByName("127.0.0.1")) + + // Create a client connection and wait for server to register the connection with the selector. For + // test scenarios below where `Selector.register` fails, the wait ensures that checks are performed + // only after `register` is processed by the server. 
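+    // TestableSelector counts selector operations, which lets the helper below wait until the Register count has
+    // increased for the newly created connection before the test proceeds.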
+ def connectAndWaitForConnectionRegister(): Socket = { + val connections = selector.operationCounts(SelectorOperation.Register) + val socket = sslConnect(overrideServer) + TestUtils.waitUntilTrue(() => + selector.operationCounts(SelectorOperation.Register) == connections + 1, "Connection not registered") + socket + } + + try { + overrideServer.startup() + val socket1 = connectAndWaitForConnectionRegister() + TestUtils.waitUntilTrue(() => connectionCount == 1 && openChannel.isDefined, "Failed to create channel") + val channel1 = openChannel.getOrElse(throw new RuntimeException("Channel not found")) + + // Create new connection with same id when `channel1` is still open and in Selector.channels + // Check that new connection is closed and openChannel still contains `channel1` + connectAndWaitForConnectionRegister() + TestUtils.waitUntilTrue(() => connectionCount == 1, "Failed to close channel") + assertSame(channel1, openChannel.getOrElse(throw new RuntimeException("Channel not found"))) + socket1.close() + TestUtils.waitUntilTrue(() => openChannel.isEmpty, "Channel not closed") + + // Create a channel with buffered receive and close remote connection + val request = makeChannelWithBufferedRequestsAndCloseRemote(overrideServer, selector) + val channel2 = openChannel.getOrElse(throw new RuntimeException("Channel not found")) + + // Create new connection with same id when `channel2` is closing, but still in Selector.channels + // Check that new connection is closed and openOrClosingChannel still contains `channel2` + connectAndWaitForConnectionRegister() + TestUtils.waitUntilTrue(() => connectionCount == 1, "Failed to close channel") + assertSame(channel2, openOrClosingChannel.getOrElse(throw new RuntimeException("Channel not found"))) + + // Complete request with failed send so that `channel2` is removed from Selector.channels + processRequest(overrideServer.dataPlaneRequestChannel, request) + TestUtils.waitUntilTrue(() => connectionCount == 0 && openOrClosingChannel.isEmpty, "Failed to remove channel with failed send") + + // Check that new connections can be created with the same id since `channel1` is no longer in Selector + connectAndWaitForConnectionRegister() + TestUtils.waitUntilTrue(() => connectionCount == 1 && openChannel.isDefined, "Failed to open new channel") + val newChannel = openChannel.getOrElse(throw new RuntimeException("Channel not found")) + assertNotSame(channel1, newChannel) + newChannel.disconnect() + + } finally { + shutdownServerAndMetrics(overrideServer) + } + } + + private def makeSocketWithBufferedRequests(server: SocketServer, + serverSelector: Selector, + proxyServer: ProxyServer, + numBufferedRequests: Int = 2): (Socket, RequestChannel.Request) = { + + val requestBytes = producerRequestBytes() + val socket = sslClientSocket(proxyServer.localPort) + sendRequest(socket, requestBytes) + val request1 = receiveRequest(server.dataPlaneRequestChannel) + + val connectionId = request1.context.connectionId + val channel = server.dataPlaneProcessor(0).channel(connectionId).getOrElse(throw new IllegalStateException("Channel not found")) + val transportLayer: SslTransportLayer = JTestUtils.fieldValue(channel, classOf[KafkaChannel], "transportLayer") + val netReadBuffer: ByteBuffer = JTestUtils.fieldValue(transportLayer, classOf[SslTransportLayer], "netReadBuffer") + + proxyServer.enableBuffering(netReadBuffer) + (1 to numBufferedRequests).foreach { _ => sendRequest(socket, requestBytes) } + + val keysWithBufferedRead: util.Set[SelectionKey] = 
JTestUtils.fieldValue(serverSelector, classOf[Selector], "keysWithBufferedRead")
+    keysWithBufferedRead.add(channel.selectionKey)
+    JTestUtils.setFieldValue(transportLayer, "hasBytesBuffered", true)
+
+    (socket, request1)
+  }
+
+  /**
+   * Create a channel with data in SSL buffers and close the remote connection.
+   * The channel should remain open in SocketServer even if it detects that the peer has closed
+   * the connection since there is pending data to be processed.
+   */
+  private def makeChannelWithBufferedRequestsAndCloseRemote(server: SocketServer,
+                                                            serverSelector: Selector,
+                                                            makeClosing: Boolean = false): RequestChannel.Request = {
+
+    val proxyServer = new ProxyServer(server)
+    try {
+      val (socket, request1) = makeSocketWithBufferedRequests(server, serverSelector, proxyServer)
+
+      socket.close()
+      proxyServer.serverConnSocket.close()
+      TestUtils.waitUntilTrue(() => proxyServer.clientConnSocket.isClosed, "Client socket not closed", waitTimeMs = 10000)
+
+      processRequestNoOpResponse(server.dataPlaneRequestChannel, request1)
+      val channel = openOrClosingChannel(request1, server).getOrElse(throw new IllegalStateException("Channel closed too early"))
+      if (makeClosing)
+        serverSelector.asInstanceOf[TestableSelector].pendingClosingChannels.add(channel)
+
+      receiveRequest(server.dataPlaneRequestChannel, timeout = 10000)
+    } finally {
+      proxyServer.close()
+    }
+  }
+
+  def sendRequestsReceiveOne(server: SocketServer, socket: Socket, requestBytes: Array[Byte], numRequests: Int): RequestChannel.Request = {
+    (1 to numRequests).foreach(i => sendRequest(socket, requestBytes, flush = i == numRequests))
+    receiveRequest(server.dataPlaneRequestChannel)
+  }
+
+  private def closeSocketWithPendingRequest(server: SocketServer,
+                                            createSocket: () => Socket): RequestChannel.Request = {
+
+    def maybeReceiveRequest(): Option[RequestChannel.Request] = {
+      try {
+        Some(receiveRequest(server.dataPlaneRequestChannel, timeout = 1000))
+      } catch {
+        case e: Exception => None
+      }
+    }
+
+    def closedChannelWithPendingRequest(): Option[RequestChannel.Request] = {
+      val socket = createSocket.apply()
+      val req1 = sendRequestsReceiveOne(server, socket, producerRequestBytes(ack = 0), numRequests = 100)
+      processRequestNoOpResponse(server.dataPlaneRequestChannel, req1)
+      // Set SoLinger to 0 to force a hard disconnect via TCP RST
+      socket.setSoLinger(true, 0)
+      socket.close()
+
+      maybeReceiveRequest().flatMap { req =>
+        processRequestNoOpResponse(server.dataPlaneRequestChannel, req)
+        maybeReceiveRequest()
+      }
+    }
+
+    val (request, _) = TestUtils.computeUntilTrue(closedChannelWithPendingRequest()) { req => req.nonEmpty }
+    request.getOrElse(throw new IllegalStateException("Could not create a closed channel with a pending request"))
+  }
+
+  // Prepares the test setup for throttled channel tests. throttlingInProgress controls whether throttling is
+  // still in progress in the quota manager when the response is sent.
+  def throttledChannelTestSetUp(socket: Socket, serializedBytes: Array[Byte], noOpResponse: Boolean,
+                                throttlingInProgress: Boolean): RequestChannel.Request = {
+    sendRequest(socket, serializedBytes)
+
+    // Mimic a primitive request handler that fetches the request from RequestChannel and places a response on a
+    // throttled channel.
+    val request = receiveRequest(server.dataPlaneRequestChannel)
+    val byteBuffer = request.body[AbstractRequest].serializeWithHeader(request.header)
+    val send = new NetworkSend(request.context.connectionId, ByteBufferSend.sizePrefixed(byteBuffer))
+
+    val channelThrottlingCallback = new ThrottleCallback {
+      override def startThrottling(): Unit = server.dataPlaneRequestChannel.startThrottling(request)
+      override def endThrottling(): Unit = server.dataPlaneRequestChannel.endThrottling(request)
+    }
+    val throttledChannel = new ThrottledChannel(new MockTime(), 100, channelThrottlingCallback)
+    val headerLog = RequestConvertToJson.requestHeaderNode(request.header)
+    val response =
+      if (!noOpResponse)
+        new RequestChannel.SendResponse(request, send, Some(headerLog), None)
+      else
+        new RequestChannel.NoOpResponse(request)
+    server.dataPlaneRequestChannel.sendResponse(response)
+
+    // The quota manager would call notifyThrottlingDone() on throttling completion. Simulate it if
+    // throttlingInProgress is false.
+    if (!throttlingInProgress)
+      throttledChannel.notifyThrottlingDone()
+
+    request
+  }
+
+  def openChannel(request: RequestChannel.Request, server: SocketServer = this.server): Option[KafkaChannel] =
+    server.dataPlaneProcessor(0).channel(request.context.connectionId)
+
+  def openOrClosingChannel(request: RequestChannel.Request, server: SocketServer = this.server): Option[KafkaChannel] =
+    server.dataPlaneProcessor(0).openOrClosingChannel(request.context.connectionId)
+
+  @Test
+  def testSendActionResponseWithThrottledChannelWhereThrottlingInProgress(): Unit = {
+    val socket = connect()
+    val serializedBytes = producerRequestBytes()
+    // SendAction with throttling in progress
+    val request = throttledChannelTestSetUp(socket, serializedBytes, false, true)
+
+    // receive response
+    assertEquals(serializedBytes.toSeq, receiveResponse(socket).toSeq)
+    TestUtils.waitUntilTrue(() => openOrClosingChannel(request).exists(c => c.muteState() == ChannelMuteState.MUTED_AND_THROTTLED), "fail")
+    // Channel should still be muted.
+    assertTrue(openOrClosingChannel(request).exists(c => c.isMuted()))
+  }
+
+  @Test
+  def testSendActionResponseWithThrottledChannelWhereThrottlingAlreadyDone(): Unit = {
+    val socket = connect()
+    val serializedBytes = producerRequestBytes()
+    // SendAction with throttling already done
+    val request = throttledChannelTestSetUp(socket, serializedBytes, false, false)
+
+    // receive response
+    assertEquals(serializedBytes.toSeq, receiveResponse(socket).toSeq)
+    // Since throttling is already done, the channel can be unmuted after sending out the response.
+    TestUtils.waitUntilTrue(() => openOrClosingChannel(request).exists(c => c.muteState() == ChannelMuteState.NOT_MUTED), "fail")
+    // Channel is now unmuted.
+    assertFalse(openOrClosingChannel(request).exists(c => c.isMuted()))
+  }
+
+  @Test
+  def testNoOpActionResponseWithThrottledChannelWhereThrottlingInProgress(): Unit = {
+    val socket = connect()
+    val serializedBytes = producerRequestBytes()
+    // NoOpAction with throttling in progress
+    val request = throttledChannelTestSetUp(socket, serializedBytes, true, true)
+
+    TestUtils.waitUntilTrue(() => openOrClosingChannel(request).exists(c => c.muteState() == ChannelMuteState.MUTED_AND_THROTTLED), "fail")
+    // Channel should still be muted.
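+    // Throttling is still in progress, so notifyThrottlingDone() has not been called and the channel is expected
+    // to stay in the MUTED_AND_THROTTLED state rather than being unmuted after the response is sent.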
+ assertTrue(openOrClosingChannel(request).exists(c => c.isMuted())) + } + + @Test + def testNoOpActionResponseWithThrottledChannelWhereThrottlingAlreadyDone(): Unit = { + val socket = connect() + val serializedBytes = producerRequestBytes() + // SendAction with throttling in progress + val request = throttledChannelTestSetUp(socket, serializedBytes, true, false) + + // Since throttling is already done, the channel can be unmuted. + TestUtils.waitUntilTrue(() => openOrClosingChannel(request).exists(c => c.muteState() == ChannelMuteState.NOT_MUTED), "fail") + // Channel is now unmuted. + assertFalse(openOrClosingChannel(request).exists(c => c.isMuted())) + } + + @Test + def testSocketsCloseOnShutdown(): Unit = { + // open a connection + val plainSocket = connect() + plainSocket.setTcpNoDelay(true) + val bytes = new Array[Byte](40) + // send a request first to make sure the connection has been picked up by the socket server + sendRequest(plainSocket, bytes, Some(0)) + processRequest(server.dataPlaneRequestChannel) + // the following sleep is necessary to reliably detect the connection close when we send data below + Thread.sleep(200L) + // make sure the sockets are open + server.dataPlaneAcceptors.asScala.values.foreach(acceptor => assertFalse(acceptor.serverChannel.socket.isClosed)) + // then shutdown the server + shutdownServerAndMetrics(server) + + verifyRemoteConnectionClosed(plainSocket) + } + + @Test + def testMaxConnectionsPerIp(): Unit = { + // make the maximum allowable number of connections + val conns = (0 until server.config.maxConnectionsPerIp).map(_ => connect()) + // now try one more (should fail) + val conn = connect() + conn.setSoTimeout(3000) + assertEquals(-1, conn.getInputStream.read()) + conn.close() + + // it should succeed after closing one connection + val address = conns.head.getInetAddress + conns.head.close() + TestUtils.waitUntilTrue(() => server.connectionCount(address) < conns.length, + "Failed to decrement connection count after close") + val conn2 = connect() + val serializedBytes = producerRequestBytes() + sendRequest(conn2, serializedBytes) + val request = server.dataPlaneRequestChannel.receiveRequest(2000) + assertNotNull(request) + } + + @Test + def testZeroMaxConnectionsPerIp(): Unit = { + val newProps = TestUtils.createBrokerConfig(0, TestUtils.MockZkConnect, port = 0) + newProps.setProperty(KafkaConfig.MaxConnectionsPerIpProp, "0") + newProps.setProperty(KafkaConfig.MaxConnectionsPerIpOverridesProp, "%s:%s".format("127.0.0.1", "5")) + val server = new SocketServer(KafkaConfig.fromProps(newProps), new Metrics(), + Time.SYSTEM, credentialProvider, apiVersionManager) + try { + server.startup() + // make the maximum allowable number of connections + val conns = (0 until 5).map(_ => connect(server)) + // now try one more (should fail) + val conn = connect(server) + conn.setSoTimeout(3000) + assertEquals(-1, conn.getInputStream.read()) + conn.close() + + // it should succeed after closing one connection + val address = conns.head.getInetAddress + conns.head.close() + TestUtils.waitUntilTrue(() => server.connectionCount(address) < conns.length, + "Failed to decrement connection count after close") + val conn2 = connect(server) + val serializedBytes = producerRequestBytes() + sendRequest(conn2, serializedBytes) + val request = server.dataPlaneRequestChannel.receiveRequest(2000) + assertNotNull(request) + + // now try to connect from the external facing interface, which should fail + val conn3 = connect(s = server, localAddr = InetAddress.getLocalHost) + 
conn3.setSoTimeout(3000)
+      assertEquals(-1, conn3.getInputStream.read())
+      conn3.close()
+    } finally {
+      shutdownServerAndMetrics(server)
+    }
+  }
+
+  @Test
+  def testMaxConnectionsPerIpOverrides(): Unit = {
+    val overrideNum = server.config.maxConnectionsPerIp + 1
+    val overrideProps = TestUtils.createBrokerConfig(0, TestUtils.MockZkConnect, port = 0)
+    overrideProps.put(KafkaConfig.MaxConnectionsPerIpOverridesProp, s"localhost:$overrideNum")
+    val serverMetrics = new Metrics()
+    val overrideServer = new SocketServer(KafkaConfig.fromProps(overrideProps), serverMetrics,
+      Time.SYSTEM, credentialProvider, apiVersionManager)
+    try {
+      overrideServer.startup()
+      // make the maximum allowable number of connections
+      val conns = (0 until overrideNum).map(_ => connect(overrideServer))
+
+      // it should succeed
+      val serializedBytes = producerRequestBytes()
+      sendRequest(conns.last, serializedBytes)
+      val request = overrideServer.dataPlaneRequestChannel.receiveRequest(2000)
+      assertNotNull(request)
+
+      // now try one more (should fail)
+      val conn = connect(overrideServer)
+      conn.setSoTimeout(3000)
+      assertEquals(-1, conn.getInputStream.read())
+    } finally {
+      shutdownServerAndMetrics(overrideServer)
+    }
+  }
+
+  @Test
+  def testConnectionRatePerIp(): Unit = {
+    val defaultTimeoutMs = 2000
+    val overrideProps = TestUtils.createBrokerConfig(0, TestUtils.MockZkConnect, port = 0)
+    overrideProps.remove(KafkaConfig.MaxConnectionsPerIpProp)
+    overrideProps.put(KafkaConfig.NumQuotaSamplesProp, String.valueOf(2))
+    val connectionRate = 5
+    val time = new MockTime()
+    val overrideServer = new SocketServer(KafkaConfig.fromProps(overrideProps), new Metrics(),
+      time, credentialProvider, apiVersionManager)
+    // update the connection rate to 5
+    overrideServer.connectionQuotas.updateIpConnectionRateQuota(None, Some(connectionRate))
+    try {
+      overrideServer.startup()
+      // make the (maximum allowable number + 1) of connections
+      (0 to connectionRate).map(_ => connect(overrideServer))
+
+      val acceptors = overrideServer.dataPlaneAcceptors.asScala.values
+      // wait for 5 connections to be accepted and 1 connection to be throttled
+      TestUtils.waitUntilTrue(
+        () => acceptors.foldLeft(0)((accumulator, acceptor) => accumulator + acceptor.throttledSockets.size) == 1,
+        "timeout waiting for 1 connection to get throttled",
+        defaultTimeoutMs)
+
+      // now try one more, so that we can make sure this connection will get throttled
+      var conn = connect(overrideServer)
+      // there should be a total of 2 throttled connections now
+      TestUtils.waitUntilTrue(
+        () => acceptors.foldLeft(0)((accumulator, acceptor) => accumulator + acceptor.throttledSockets.size) == 2,
+        "timeout waiting for 2 connections to get throttled",
+        defaultTimeoutMs)
+      // advance time to unthrottle connections
+      time.sleep(defaultTimeoutMs)
+      acceptors.foreach(_.wakeup())
+      // make sure no connections are throttled anymore (and the previously throttled connections have been closed)
+      TestUtils.waitUntilTrue(() => acceptors.forall(_.throttledSockets.isEmpty),
+        "timeout waiting for connection to be unthrottled",
+        defaultTimeoutMs)
+      // verify the connection is closed now
+      verifyRemoteConnectionClosed(conn)
+
+      // a new connection should succeed after the previous connection is closed and the previous quota samples have expired
+      conn = connect(overrideServer)
+      val serializedBytes = producerRequestBytes()
+      sendRequest(conn, serializedBytes)
+      val request = overrideServer.dataPlaneRequestChannel.receiveRequest(defaultTimeoutMs)
+      assertNotNull(request)
+    } finally {
+
shutdownServerAndMetrics(overrideServer) + } + } + + @Test + def testThrottledSocketsClosedOnShutdown(): Unit = { + val overrideProps = TestUtils.createBrokerConfig(0, TestUtils.MockZkConnect, port = 0) + overrideProps.remove("max.connections.per.ip") + overrideProps.put(KafkaConfig.NumQuotaSamplesProp, String.valueOf(2)) + val connectionRate = 5 + val time = new MockTime() + val overrideServer = new SocketServer(KafkaConfig.fromProps(overrideProps), new Metrics(), + time, credentialProvider, apiVersionManager) + overrideServer.connectionQuotas.updateIpConnectionRateQuota(None, Some(connectionRate)) + overrideServer.startup() + // make the maximum allowable number of connections + (0 until connectionRate).map(_ => connect(overrideServer)) + // now try one more (should get throttled) + val conn = connect(overrideServer) + // don't advance time so that connection never gets unthrottled + shutdownServerAndMetrics(overrideServer) + verifyRemoteConnectionClosed(conn) + } + + private def verifyRemoteConnectionClosed(connection: Socket): Unit = { + val largeChunkOfBytes = new Array[Byte](1000000) + // doing a subsequent send should throw an exception as the connection should be closed. + // send a large chunk of bytes to trigger a socket flush + assertThrows(classOf[IOException], () => sendRequest(connection, largeChunkOfBytes, Some(0))) + } + + @Test + def testSslSocketServer(): Unit = { + val serverMetrics = new Metrics + val overrideServer = new SocketServer(KafkaConfig.fromProps(sslServerProps), serverMetrics, + Time.SYSTEM, credentialProvider, apiVersionManager) + try { + overrideServer.startup() + val sslContext = SSLContext.getInstance(TestSslUtils.DEFAULT_TLS_PROTOCOL_FOR_TESTS) + sslContext.init(null, Array(TestUtils.trustAllCerts), new java.security.SecureRandom()) + val socketFactory = sslContext.getSocketFactory + val sslSocket = socketFactory.createSocket("localhost", + overrideServer.boundPort(ListenerName.forSecurityProtocol(SecurityProtocol.SSL))).asInstanceOf[SSLSocket] + sslSocket.setNeedClientAuth(false) + + val correlationId = -1 + val clientId = "" + val ackTimeoutMs = 10000 + val ack = 0: Short + val emptyRequest = requests.ProduceRequest.forCurrentMagic(new ProduceRequestData() + .setTopicData(new ProduceRequestData.TopicProduceDataCollection()) + .setAcks(ack) + .setTimeoutMs(ackTimeoutMs) + .setTransactionalId(null)) + .build() + val emptyHeader = new RequestHeader(ApiKeys.PRODUCE, emptyRequest.version, clientId, correlationId) + val serializedBytes = Utils.toArray(emptyRequest.serializeWithHeader(emptyHeader)) + + sendRequest(sslSocket, serializedBytes) + processRequest(overrideServer.dataPlaneRequestChannel) + assertEquals(serializedBytes.toSeq, receiveResponse(sslSocket).toSeq) + sslSocket.close() + } finally { + shutdownServerAndMetrics(overrideServer) + } + } + + @Test + def testSaslReauthenticationFailureWithKip152SaslAuthenticate(): Unit = { + checkSaslReauthenticationFailure(true) + } + + @Test + def testSaslReauthenticationFailureNoKip152SaslAuthenticate(): Unit = { + checkSaslReauthenticationFailure(false) + } + + def checkSaslReauthenticationFailure(leverageKip152SaslAuthenticateRequest : Boolean): Unit = { + shutdownServerAndMetrics(server) // we will use our own instance because we require custom configs + val username = "admin" + val password = "admin-secret" + val reauthMs = 1500 + val brokerProps = new Properties + brokerProps.setProperty("listeners", "SASL_PLAINTEXT://localhost:0") + brokerProps.setProperty("security.inter.broker.protocol", 
"SASL_PLAINTEXT") + brokerProps.setProperty("listener.name.sasl_plaintext.plain.sasl.jaas.config", + "org.apache.kafka.common.security.plain.PlainLoginModule required " + + "username=\"%s\" password=\"%s\" user_%s=\"%s\";".format(username, password, username, password)) + brokerProps.setProperty("sasl.mechanism.inter.broker.protocol", "PLAIN") + brokerProps.setProperty("listener.name.sasl_plaintext.sasl.enabled.mechanisms", "PLAIN") + brokerProps.setProperty("num.network.threads", "1") + brokerProps.setProperty("connections.max.reauth.ms", reauthMs.toString) + val overrideProps = TestUtils.createBrokerConfig(0, TestUtils.MockZkConnect, + saslProperties = Some(brokerProps), enableSaslPlaintext = true) + val time = new MockTime() + val overrideServer = new TestableSocketServer(KafkaConfig.fromProps(overrideProps), time = time) + try { + overrideServer.startup() + val socket = connect(overrideServer, ListenerName.forSecurityProtocol(SecurityProtocol.SASL_PLAINTEXT)) + + val correlationId = -1 + val clientId = "" + // send a SASL handshake request + val version : Short = if (leverageKip152SaslAuthenticateRequest) ApiKeys.SASL_HANDSHAKE.latestVersion else 0 + val saslHandshakeRequest = new SaslHandshakeRequest.Builder(new SaslHandshakeRequestData().setMechanism("PLAIN")) + .build(version) + val saslHandshakeHeader = new RequestHeader(ApiKeys.SASL_HANDSHAKE, saslHandshakeRequest.version, clientId, + correlationId) + sendApiRequest(socket, saslHandshakeRequest, saslHandshakeHeader) + receiveResponse(socket) + + // now send credentials + val authBytes = "admin\u0000admin\u0000admin-secret".getBytes(StandardCharsets.UTF_8) + if (leverageKip152SaslAuthenticateRequest) { + // send credentials within a SaslAuthenticateRequest + val saslAuthenticateRequest = new SaslAuthenticateRequest.Builder(new SaslAuthenticateRequestData() + .setAuthBytes(authBytes)).build() + val saslAuthenticateHeader = new RequestHeader(ApiKeys.SASL_AUTHENTICATE, saslAuthenticateRequest.version, + clientId, correlationId) + sendApiRequest(socket, saslAuthenticateRequest, saslAuthenticateHeader) + } else { + // send credentials directly, without a SaslAuthenticateRequest + sendRequest(socket, authBytes) + } + receiveResponse(socket) + assertEquals(1, overrideServer.testableSelector.channels.size) + + // advance the clock long enough to cause server-side disconnection upon next send... 
+ time.sleep(reauthMs * 2) + // ...and now send something to trigger the disconnection + val ackTimeoutMs = 10000 + val ack = 0: Short + val emptyRequest = requests.ProduceRequest.forCurrentMagic(new ProduceRequestData() + .setTopicData(new ProduceRequestData.TopicProduceDataCollection()) + .setAcks(ack) + .setTimeoutMs(ackTimeoutMs) + .setTransactionalId(null)) + .build() + val emptyHeader = new RequestHeader(ApiKeys.PRODUCE, emptyRequest.version, clientId, correlationId) + sendApiRequest(socket, emptyRequest, emptyHeader) + // wait a little bit for the server-side disconnection to occur since it happens asynchronously + try { + TestUtils.waitUntilTrue(() => overrideServer.testableSelector.channels.isEmpty, + "Expired connection was not closed", 1000, 100) + } finally { + socket.close() + } + } finally { + shutdownServerAndMetrics(overrideServer) + } + } + + @Test + def testSessionPrincipal(): Unit = { + val socket = connect() + val bytes = new Array[Byte](40) + sendRequest(socket, bytes, Some(0)) + assertEquals(KafkaPrincipal.ANONYMOUS, receiveRequest(server.dataPlaneRequestChannel).session.principal) + } + + /* Test that we update request metrics if the client closes the connection while the broker response is in flight. */ + @Test + def testClientDisconnectionUpdatesRequestMetrics(): Unit = { + // The way we detect a connection close from the client depends on the response size. If it's small, an + // IOException ("Connection reset by peer") is thrown when the Selector reads from the socket. If + // it's large, an IOException ("Broken pipe") is thrown when the Selector writes to the socket. We test + // both paths to ensure they are handled correctly. + checkClientDisconnectionUpdatesRequestMetrics(0) + checkClientDisconnectionUpdatesRequestMetrics(550000) + } + + private def checkClientDisconnectionUpdatesRequestMetrics(responseBufferSize: Int): Unit = { + val props = TestUtils.createBrokerConfig(0, TestUtils.MockZkConnect, port = 0) + val serverMetrics = new Metrics + var conn: Socket = null + val overrideServer = new SocketServer( + KafkaConfig.fromProps(props), serverMetrics, Time.SYSTEM, credentialProvider, apiVersionManager + ) { + override def newProcessor(id: Int, requestChannel: RequestChannel, connectionQuotas: ConnectionQuotas, listenerName: ListenerName, + protocol: SecurityProtocol, memoryPool: MemoryPool, isPrivilegedListener: Boolean = false): Processor = { + new Processor(id, time, config.socketRequestMaxBytes, dataPlaneRequestChannel, connectionQuotas, + config.connectionsMaxIdleMs, config.failedAuthenticationDelayMs, listenerName, protocol, config, metrics, + credentialProvider, MemoryPool.NONE, new LogContext(), Processor.ConnectionQueueSize, isPrivilegedListener, apiVersionManager) { + override protected[network] def sendResponse(response: RequestChannel.Response, responseSend: Send): Unit = { + conn.close() + super.sendResponse(response, responseSend) + } + } + } + } + try { + overrideServer.startup() + conn = connect(overrideServer) + val serializedBytes = producerRequestBytes() + sendRequest(conn, serializedBytes) + + val channel = overrideServer.dataPlaneRequestChannel + val request = receiveRequest(channel) + + val requestMetrics = channel.metrics(request.header.apiKey.name) + def totalTimeHistCount(): Long = requestMetrics.totalTimeHist.count + val send = new NetworkSend(request.context.connectionId, ByteBufferSend.sizePrefixed(ByteBuffer.allocate(responseBufferSize))) + val headerLog = new ObjectNode(JsonNodeFactory.instance) + headerLog.set("response", new 
TextNode("someResponse")) + channel.sendResponse(new RequestChannel.SendResponse(request, send, Some(headerLog), None)) + + val expectedTotalTimeCount = totalTimeHistCount() + 1 + TestUtils.waitUntilTrue(() => totalTimeHistCount() == expectedTotalTimeCount, + s"request metrics not updated, expected: $expectedTotalTimeCount, actual: ${totalTimeHistCount()}") + + } finally { + shutdownServerAndMetrics(overrideServer) + } + } + + @Test + def testClientDisconnectionWithOutstandingReceivesProcessedUntilFailedSend(): Unit = { + val serverMetrics = new Metrics + @volatile var selector: TestableSelector = null + val overrideServer = new SocketServer( + KafkaConfig.fromProps(props), serverMetrics, Time.SYSTEM, credentialProvider, apiVersionManager + ) { + override def newProcessor(id: Int, requestChannel: RequestChannel, connectionQuotas: ConnectionQuotas, listenerName: ListenerName, + protocol: SecurityProtocol, memoryPool: MemoryPool, isPrivilegedListener: Boolean): Processor = { + new Processor(id, time, config.socketRequestMaxBytes, dataPlaneRequestChannel, connectionQuotas, + config.connectionsMaxIdleMs, config.failedAuthenticationDelayMs, listenerName, protocol, config, metrics, + credentialProvider, memoryPool, new LogContext(), Processor.ConnectionQueueSize, isPrivilegedListener, apiVersionManager) { + override protected[network] def createSelector(channelBuilder: ChannelBuilder): Selector = { + val testableSelector = new TestableSelector(config, channelBuilder, time, metrics) + selector = testableSelector + testableSelector + } + } + } + } + + try { + overrideServer.startup() + + // Create a channel, send some requests and close socket. Receive one pending request after socket was closed. + val request = closeSocketWithPendingRequest(overrideServer, () => connect(overrideServer)) + + // Complete request with socket exception so that the channel is closed + processRequest(overrideServer.dataPlaneRequestChannel, request) + TestUtils.waitUntilTrue(() => openOrClosingChannel(request, overrideServer).isEmpty, "Channel not closed after failed send") + assertTrue(selector.completedSends.isEmpty, "Unexpected completed send") + } finally { + overrideServer.shutdown() + serverMetrics.close() + } + } + + /* + * Test that we update request metrics if the channel has been removed from the selector when the broker calls + * `selector.send` (selector closes old connections, for example). 
+ */ + @Test + def testBrokerSendAfterChannelClosedUpdatesRequestMetrics(): Unit = { + val props = TestUtils.createBrokerConfig(0, TestUtils.MockZkConnect, port = 0) + props.setProperty(KafkaConfig.ConnectionsMaxIdleMsProp, "110") + val serverMetrics = new Metrics + var conn: Socket = null + val overrideServer = new SocketServer(KafkaConfig.fromProps(props), serverMetrics, + Time.SYSTEM, credentialProvider, apiVersionManager) + try { + overrideServer.startup() + conn = connect(overrideServer) + val serializedBytes = producerRequestBytes() + sendRequest(conn, serializedBytes) + val channel = overrideServer.dataPlaneRequestChannel + val request = receiveRequest(channel) + + TestUtils.waitUntilTrue(() => overrideServer.dataPlaneProcessor(request.processor).channel(request.context.connectionId).isEmpty, + s"Idle connection `${request.context.connectionId}` was not closed by selector") + + val requestMetrics = channel.metrics(request.header.apiKey.name) + def totalTimeHistCount(): Long = requestMetrics.totalTimeHist.count + val expectedTotalTimeCount = totalTimeHistCount() + 1 + + processRequest(channel, request) + + TestUtils.waitUntilTrue(() => totalTimeHistCount() == expectedTotalTimeCount, + s"request metrics not updated, expected: $expectedTotalTimeCount, actual: ${totalTimeHistCount()}") + + } finally { + shutdownServerAndMetrics(overrideServer) + } + } + + @Test + def testRequestMetricsAfterStop(): Unit = { + server.stopProcessingRequests() + val version = ApiKeys.PRODUCE.latestVersion + val version2 = (version - 1).toShort + for (_ <- 0 to 1) server.dataPlaneRequestChannel.metrics(ApiKeys.PRODUCE.name).requestRate(version).mark() + server.dataPlaneRequestChannel.metrics(ApiKeys.PRODUCE.name).requestRate(version2).mark() + assertEquals(2, server.dataPlaneRequestChannel.metrics(ApiKeys.PRODUCE.name).requestRate(version).count()) + server.dataPlaneRequestChannel.updateErrorMetrics(ApiKeys.PRODUCE, Map(Errors.NONE -> 1)) + val nonZeroMeters = Map(s"kafka.network:type=RequestMetrics,name=RequestsPerSec,request=Produce,version=$version" -> 2, + s"kafka.network:type=RequestMetrics,name=RequestsPerSec,request=Produce,version=$version2" -> 1, + "kafka.network:type=RequestMetrics,name=ErrorsPerSec,request=Produce,error=NONE" -> 1) + + def requestMetricMeters = KafkaYammerMetrics + .defaultRegistry + .allMetrics.asScala + .collect { case (k, metric: Meter) if k.getType == "RequestMetrics" => (k.toString, metric.count) } + + assertEquals(nonZeroMeters, requestMetricMeters.filter { case (_, value) => value != 0 }) + server.shutdown() + assertEquals(Map.empty, requestMetricMeters) + } + + @Test + def testMetricCollectionAfterShutdown(): Unit = { + server.shutdown() + + val nonZeroMetricNamesAndValues = KafkaYammerMetrics + .defaultRegistry + .allMetrics.asScala + .filter { case (k, _) => k.getName.endsWith("IdlePercent") || k.getName.endsWith("NetworkProcessorAvgIdlePercent") } + .collect { case (k, metric: Gauge[_]) => (k, metric.value().asInstanceOf[Double]) } + .filter { case (_, value) => value != 0.0 && !value.equals(Double.NaN) } + + assertEquals(Map.empty, nonZeroMetricNamesAndValues) + } + + @Test + def testProcessorMetricsTags(): Unit = { + val kafkaMetricNames = metrics.metrics.keySet.asScala.filter(_.tags.asScala.get("listener").nonEmpty) + assertFalse(kafkaMetricNames.isEmpty) + + val expectedListeners = Set("PLAINTEXT") + kafkaMetricNames.foreach { kafkaMetricName => + assertTrue(expectedListeners.contains(kafkaMetricName.tags.get("listener"))) + } + + // legacy metrics not tagged + val 
yammerMetricsNames = KafkaYammerMetrics.defaultRegistry.allMetrics.asScala + .filter { case (k, _) => k.getType.equals("Processor") } + .collect { case (k, _: Gauge[_]) => k } + assertFalse(yammerMetricsNames.isEmpty) + + yammerMetricsNames.foreach { yammerMetricName => + assertFalse(yammerMetricName.getMBeanName.contains("listener=")) + } + } + + /** + * Tests exception handling in [[Processor.configureNewConnections]]. Exception is + * injected into [[Selector.register]] which is used to register each new connection. + * Test creates two connections in a single iteration by waking up the selector only + * when two connections are ready. + * Verifies that + * - first failed connection is closed + * - second connection is processed successfully after the first fails with an exception + * - processor is healthy after the exception + */ + @Test + def configureNewConnectionException(): Unit = { + withTestableServer (testWithServer = { testableServer => + val testableSelector = testableServer.testableSelector + + testableSelector.updateMinWakeup(2) + testableSelector.addFailure(SelectorOperation.Register) + val sockets = (1 to 2).map(_ => connect(testableServer)) + testableSelector.waitForOperations(SelectorOperation.Register, 2) + TestUtils.waitUntilTrue(() => testableServer.connectionCount(localAddress) == 1, "Failed channel not removed") + + assertProcessorHealthy(testableServer, testableSelector.notFailed(sockets)) + }) + } + + /** + * Tests exception handling in [[Processor.processNewResponses]]. Exception is + * injected into [[Selector.send]] which is used to send the new response. + * Test creates two responses in a single iteration by waking up the selector only + * when two responses are ready. + * Verifies that + * - first failed channel is closed + * - second response is processed successfully after the first fails with an exception + * - processor is healthy after the exception + */ + @Test + def processNewResponseException(): Unit = { + withTestableServer (testWithServer = { testableServer => + val testableSelector = testableServer.testableSelector + testableSelector.updateMinWakeup(2) + + val sockets = (1 to 2).map(_ => connect(testableServer)) + sockets.foreach(sendRequest(_, producerRequestBytes())) + + testableServer.testableSelector.addFailure(SelectorOperation.Send) + sockets.foreach(_ => processRequest(testableServer.dataPlaneRequestChannel)) + testableSelector.waitForOperations(SelectorOperation.Send, 2) + testableServer.waitForChannelClose(testableSelector.allFailedChannels.head, locallyClosed = true) + + assertProcessorHealthy(testableServer, testableSelector.notFailed(sockets)) + }) + } + + /** + * Tests exception handling in [[Processor.processNewResponses]] when [[Selector.send]] + * fails with `CancelledKeyException`, which is handled by the selector using a different + * code path. Test scenario is similar to [[SocketServerTest.processNewResponseException]]. 
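+   * Verifies that the channel whose key was cancelled is closed and that the other connection and the processor remain healthy.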
+ */ + @Test + def sendCancelledKeyException(): Unit = { + withTestableServer (testWithServer = { testableServer => + val testableSelector = testableServer.testableSelector + testableSelector.updateMinWakeup(2) + + val sockets = (1 to 2).map(_ => connect(testableServer)) + sockets.foreach(sendRequest(_, producerRequestBytes())) + val requestChannel = testableServer.dataPlaneRequestChannel + + val requests = sockets.map(_ => receiveRequest(requestChannel)) + val failedConnectionId = requests(0).context.connectionId + // `KafkaChannel.disconnect()` cancels the selection key, triggering CancelledKeyException during send + testableSelector.channel(failedConnectionId).disconnect() + requests.foreach(processRequest(requestChannel, _)) + testableSelector.waitForOperations(SelectorOperation.Send, 2) + testableServer.waitForChannelClose(failedConnectionId, locallyClosed = false) + + val successfulSocket = if (isSocketConnectionId(failedConnectionId, sockets(0))) sockets(1) else sockets(0) + assertProcessorHealthy(testableServer, Seq(successfulSocket)) + }) + } + + /** + * Tests channel send failure handling when send failure is triggered by [[Selector.send]] + * to a channel whose peer has closed its connection. + */ + @Test + def remoteCloseSendFailure(): Unit = { + verifySendFailureAfterRemoteClose(makeClosing = false) + } + + /** + * Tests channel send failure handling when send failure is triggered by [[Selector.send]] + * to a channel whose peer has closed its connection and the channel is in `closingChannels`. + */ + @Test + def closingChannelSendFailure(): Unit = { + verifySendFailureAfterRemoteClose(makeClosing = true) + } + + private def verifySendFailureAfterRemoteClose(makeClosing: Boolean): Unit = { + props ++= sslServerProps + withTestableServer (testWithServer = { testableServer => + val testableSelector = testableServer.testableSelector + + val serializedBytes = producerRequestBytes() + val request = makeChannelWithBufferedRequestsAndCloseRemote(testableServer, testableSelector, makeClosing) + val otherSocket = sslConnect(testableServer) + sendRequest(otherSocket, serializedBytes) + + processRequest(testableServer.dataPlaneRequestChannel, request) + processRequest(testableServer.dataPlaneRequestChannel) // Also process request from other socket + testableSelector.waitForOperations(SelectorOperation.Send, 2) + testableServer.waitForChannelClose(request.context.connectionId, locallyClosed = false) + + assertProcessorHealthy(testableServer, Seq(otherSocket)) + }) + } + + /** + * Verifies that all pending buffered receives are processed even if remote connection is closed. + * The channel must be closed after pending receives are processed. + */ + @Test + def remoteCloseWithBufferedReceives(): Unit = { + verifyRemoteCloseWithBufferedReceives(numComplete = 3, hasIncomplete = false) + } + + /** + * Verifies that channel is closed when remote client closes its connection if there is no + * buffered receive. + */ + @Test + def remoteCloseWithoutBufferedReceives(): Unit = { + verifyRemoteCloseWithBufferedReceives(numComplete = 0, hasIncomplete = false) + } + + /** + * Verifies that channel is closed when remote client closes its connection if there is a pending + * receive that is incomplete. + */ + @Test + def remoteCloseWithIncompleteBufferedReceive(): Unit = { + verifyRemoteCloseWithBufferedReceives(numComplete = 0, hasIncomplete = true) + } + + /** + * Verifies that all pending buffered receives are processed even if remote connection is closed. 
+ * The channel must be closed after complete receives are processed, even if there is an incomplete + * receive remaining in the buffers. + */ + @Test + def remoteCloseWithCompleteAndIncompleteBufferedReceives(): Unit = { + verifyRemoteCloseWithBufferedReceives(numComplete = 3, hasIncomplete = true) + } + + /** + * Verifies that pending buffered receives are processed when remote connection is closed + * until a response send fails. + */ + @Test + def remoteCloseWithBufferedReceivesFailedSend(): Unit = { + verifyRemoteCloseWithBufferedReceives(numComplete = 3, hasIncomplete = false, responseRequiredIndex = 1) + } + + /** + * Verifies that all pending buffered receives are processed for channel in closing state. + * The channel must be closed after pending receives are processed. + */ + @Test + def closingChannelWithBufferedReceives(): Unit = { + verifyRemoteCloseWithBufferedReceives(numComplete = 3, hasIncomplete = false, makeClosing = true) + } + + /** + * Verifies that all pending buffered receives are processed for channel in closing state. + * The channel must be closed after complete receives are processed, even if there is an incomplete + * receive remaining in the buffers. + */ + @Test + def closingChannelWithCompleteAndIncompleteBufferedReceives(): Unit = { + verifyRemoteCloseWithBufferedReceives(numComplete = 3, hasIncomplete = true, makeClosing = false) + } + + /** + * Verifies that pending buffered receives are processed for a channel in closing state + * until a response send fails. + */ + @Test + def closingChannelWithBufferedReceivesFailedSend(): Unit = { + verifyRemoteCloseWithBufferedReceives(numComplete = 3, hasIncomplete = false, responseRequiredIndex = 1, makeClosing = false) + } + + /** + * Verifies handling of client disconnections when the server-side channel is in the state + * specified using the parameters. + * + * @param numComplete Number of complete buffered requests + * @param hasIncomplete If true, add an additional partial buffered request + * @param responseRequiredIndex Index of the buffered request for which a response is sent. Previous requests + * are completed without a response. If set to -1, all `numComplete` requests + * are completed without a response. + * @param makeClosing If true, put the channel into closing state in the server Selector. + */ + private def verifyRemoteCloseWithBufferedReceives(numComplete: Int, + hasIncomplete: Boolean, + responseRequiredIndex: Int = -1, + makeClosing: Boolean = false): Unit = { + props ++= sslServerProps + + // Truncates the last request in the SSL buffers by directly updating the buffers to simulate partial buffered request + def truncateBufferedRequest(channel: KafkaChannel): Unit = { + val transportLayer: SslTransportLayer = JTestUtils.fieldValue(channel, classOf[KafkaChannel], "transportLayer") + val netReadBuffer: ByteBuffer = JTestUtils.fieldValue(transportLayer, classOf[SslTransportLayer], "netReadBuffer") + val appReadBuffer: ByteBuffer = JTestUtils.fieldValue(transportLayer, classOf[SslTransportLayer], "appReadBuffer") + if (appReadBuffer.position() > 4) { + appReadBuffer.position(4) + netReadBuffer.position(0) + } else { + netReadBuffer.position(20) + } + } + withTestableServer (testWithServer = { testableServer => + val testableSelector = testableServer.testableSelector + + val proxyServer = new ProxyServer(testableServer) + try { + // Step 1: Send client requests. + // a) request1 is sent by the client to ProxyServer and this is directly sent to the server. 
This + // ensures that the server-side channel is in a muted state until this request is processed in Step 3. + // b) `numComplete` requests are sent and buffered in the server-side channel's SSL buffers + // c) If `hasIncomplete=true`, an extra request is sent and buffered as in b). This will be truncated later + // when previous requests have been processed and only one request remains in the SSL buffer, + // making it easy to truncate. + val numBufferedRequests = numComplete + (if (hasIncomplete) 1 else 0) + val (socket, request1) = makeSocketWithBufferedRequests(testableServer, testableSelector, proxyServer, numBufferedRequests) + val channel = openChannel(request1, testableServer).getOrElse(throw new IllegalStateException("Channel closed too early")) + + // Step 2: Close the client-side socket and the proxy socket to the server, triggering a close notification in the + // server when the client is unmuted in Step 3. Get the channel into its desired closing/buffered state. + socket.close() + proxyServer.serverConnSocket.close() + TestUtils.waitUntilTrue(() => proxyServer.clientConnSocket.isClosed, "Client socket not closed") + if (makeClosing) + testableSelector.pendingClosingChannels.add(channel) + if (numComplete == 0 && hasIncomplete) + truncateBufferedRequest(channel) + + // Step 3: Process the first request. Verify that the channel is not removed since the channel + // should be retained to process buffered data. + processRequestNoOpResponse(testableServer.dataPlaneRequestChannel, request1) + assertSame(channel, openOrClosingChannel(request1, testableServer).getOrElse(throw new IllegalStateException("Channel closed too early"))) + + // Step 4: Process buffered data. If `responseRequiredIndex >= 0`, the channel should be failed and removed when + // attempting to send a response. Otherwise, the channel should be removed when all completed buffers are processed. + // The channel should be closed and removed even if there is a partial buffered request when `hasIncomplete=true`. + val numRequests = if (responseRequiredIndex >= 0) responseRequiredIndex + 1 else numComplete + (0 until numRequests).foreach { i => + val request = receiveRequest(testableServer.dataPlaneRequestChannel) + if (i == numComplete - 1 && hasIncomplete) + truncateBufferedRequest(channel) + if (responseRequiredIndex == i) + processRequest(testableServer.dataPlaneRequestChannel, request) + else + processRequestNoOpResponse(testableServer.dataPlaneRequestChannel, request) + } + testableServer.waitForChannelClose(channel.id, locallyClosed = false) + + // Verify that SocketServer is healthy + val anotherSocket = sslConnect(testableServer) + assertProcessorHealthy(testableServer, Seq(anotherSocket)) + } finally { + proxyServer.close() + } + }) + } + + /** + * Tests idle channel expiry for SSL channels with buffered data. Muted channels are expired + * immediately even if there is pending data to be processed. This is consistent with PLAINTEXT where + * we expire muted channels even if there is data available on the socket. This scenario occurs if the broker + * takes longer than the idle timeout to process a client request. In this case, the client would typically have + * expired its connection and would potentially reconnect to retry the request, so immediate expiry enables + * the old connection and its associated resources to be freed sooner. 
+ */ + @Test + def idleExpiryWithBufferedReceives(): Unit = { + val idleTimeMs = 60000 + val time = new MockTime() + props.put(KafkaConfig.ConnectionsMaxIdleMsProp, idleTimeMs.toString) + props ++= sslServerProps + val testableServer = new TestableSocketServer(time = time) + testableServer.startup() + + assertTrue(testableServer.controlPlaneRequestChannelOpt.isEmpty) + + val proxyServer = new ProxyServer(testableServer) + try { + val testableSelector = testableServer.testableSelector + testableSelector.updateMinWakeup(2) + + val sleepTimeMs = idleTimeMs / 2 + 1 + val (socket, request) = makeSocketWithBufferedRequests(testableServer, testableSelector, proxyServer) + // Advance mock time in increments to verify that muted sockets with buffered data don't have their idle time updated; + // additional calls to poll() should not update the channel's last idle time + for (_ <- 0 to 3) { + time.sleep(sleepTimeMs) + testableSelector.operationCounts.clear() + testableSelector.waitForOperations(SelectorOperation.Poll, 1) + } + testableServer.waitForChannelClose(request.context.connectionId, locallyClosed = false) + + val otherSocket = sslConnect(testableServer) + assertProcessorHealthy(testableServer, Seq(otherSocket)) + + socket.close() + } finally { + proxyServer.close() + shutdownServerAndMetrics(testableServer) + } + } + + @Test + def testUnmuteChannelWithBufferedReceives(): Unit = { + val time = new MockTime() + props ++= sslServerProps + val testableServer = new TestableSocketServer(time = time) + testableServer.startup() + val proxyServer = new ProxyServer(testableServer) + try { + val testableSelector = testableServer.testableSelector + val (socket, request) = makeSocketWithBufferedRequests(testableServer, testableSelector, proxyServer) + testableSelector.operationCounts.clear() + testableSelector.waitForOperations(SelectorOperation.Poll, 1) + val keysWithBufferedRead: util.Set[SelectionKey] = JTestUtils.fieldValue(testableSelector, classOf[Selector], "keysWithBufferedRead") + assertEquals(Set.empty, keysWithBufferedRead.asScala) + processRequest(testableServer.dataPlaneRequestChannel, request) + // Buffered requests should be processed after the channel is unmuted + receiveRequest(testableServer.dataPlaneRequestChannel) + socket.close() + } finally { + proxyServer.close() + shutdownServerAndMetrics(testableServer) + } + } + + /** + * Tests exception handling in [[Processor.processCompletedReceives]]. Exception is + * injected into [[Selector.mute]] which is used to mute the channel when a receive is complete. + * Test creates two receives in a single iteration by caching completed receives until two receives + * are complete. 
+ * Verifies that + * - first failed channel is closed + * - second receive is processed successfully after the first fails with an exception + * - processor is healthy after the exception + */ + @Test + def processCompletedReceiveException(): Unit = { + withTestableServer (testWithServer = { testableServer => + val sockets = (1 to 2).map(_ => connect(testableServer)) + val testableSelector = testableServer.testableSelector + val requestChannel = testableServer.dataPlaneRequestChannel + + testableSelector.cachedCompletedReceives.minPerPoll = 2 + testableSelector.addFailure(SelectorOperation.Mute) + sockets.foreach(sendRequest(_, producerRequestBytes())) + val requests = sockets.map(_ => receiveRequest(requestChannel)) + testableSelector.waitForOperations(SelectorOperation.Mute, 2) + testableServer.waitForChannelClose(testableSelector.allFailedChannels.head, locallyClosed = true) + requests.foreach(processRequest(requestChannel, _)) + + assertProcessorHealthy(testableServer, testableSelector.notFailed(sockets)) + }) + } + + /** + * Tests exception handling in [[Processor.processCompletedSends]]. Exception is + * injected into [[Selector.unmute]] which is used to unmute the channel after send is complete. + * Test creates two completed sends in a single iteration by caching completed sends until two + * sends are complete. + * Verifies that + * - first failed channel is closed + * - second send is processed successfully after the first fails with an exception + * - processor is healthy after the exception + */ + @Test + def processCompletedSendException(): Unit = { + withTestableServer (testWithServer = { testableServer => + val testableSelector = testableServer.testableSelector + val sockets = (1 to 2).map(_ => connect(testableServer)) + val requests = sockets.map(sendAndReceiveRequest(_, testableServer)) + + testableSelector.addFailure(SelectorOperation.Unmute) + requests.foreach(processRequest(testableServer.dataPlaneRequestChannel, _)) + testableSelector.waitForOperations(SelectorOperation.Unmute, 2) + testableServer.waitForChannelClose(testableSelector.allFailedChannels.head, locallyClosed = true) + + assertProcessorHealthy(testableServer, testableSelector.notFailed(sockets)) + }) + } + + /** + * Tests exception handling in [[Processor.processDisconnected]]. An invalid connectionId + * is inserted to the disconnected list just before the actual valid one. + * Verifies that + * - first invalid connectionId is ignored + * - second disconnected channel is processed successfully after the first fails with an exception + * - processor is healthy after the exception + */ + @Test + def processDisconnectedException(): Unit = { + withTestableServer (testWithServer = { testableServer => + val (socket, connectionId) = connectAndProcessRequest(testableServer) + val testableSelector = testableServer.testableSelector + + // Add an invalid connectionId to `Selector.disconnected` list before the actual disconnected channel + // and check that the actual id is processed and the invalid one ignored. + testableSelector.cachedDisconnected.minPerPoll = 2 + testableSelector.cachedDisconnected.deferredValues += "notAValidConnectionId" -> ChannelState.EXPIRED + socket.close() + testableSelector.operationCounts.clear() + testableSelector.waitForOperations(SelectorOperation.Poll, 1) + testableServer.waitForChannelClose(connectionId, locallyClosed = false) + + assertProcessorHealthy(testableServer) + }) + } + + /** + * Tests that `Processor` continues to function correctly after a failed [[Selector.poll]]. 
+ */ + @Test + def pollException(): Unit = { + withTestableServer (testWithServer = { testableServer => + val (socket, _) = connectAndProcessRequest(testableServer) + val testableSelector = testableServer.testableSelector + + testableSelector.addFailure(SelectorOperation.Poll) + testableSelector.operationCounts.clear() + testableSelector.waitForOperations(SelectorOperation.Poll, 2) + + assertProcessorHealthy(testableServer, Seq(socket)) + }) + } + + /** + * Tests handling of `ControlThrowable`. Verifies that the selector is closed. + */ + @Test + def controlThrowable(): Unit = { + withTestableServer (testWithServer = { testableServer => + connectAndProcessRequest(testableServer) + val testableSelector = testableServer.testableSelector + + testableSelector.operationCounts.clear() + testableSelector.addFailure(SelectorOperation.Poll, + Some(new ControlThrowable() {})) + testableSelector.waitForOperations(SelectorOperation.Poll, 1) + + testableSelector.waitForOperations(SelectorOperation.CloseSelector, 1) + assertEquals(1, testableServer.uncaughtExceptions) + testableServer.uncaughtExceptions = 0 + }) + } + + @Test + def testConnectionRateLimit(): Unit = { + shutdownServerAndMetrics(server) + val numConnections = 5 + props.put("max.connections.per.ip", numConnections.toString) + val testableServer = new TestableSocketServer(KafkaConfig.fromProps(props), connectionQueueSize = 1) + testableServer.startup() + val testableSelector = testableServer.testableSelector + val errors = new mutable.HashSet[String] + + def acceptorStackTraces: scala.collection.Map[Thread, String] = { + Thread.getAllStackTraces.asScala.collect { + case (thread, stacktraceElement) if thread.getName.contains("kafka-socket-acceptor") => + thread -> stacktraceElement.mkString("\n") + } + } + + def acceptorBlocked: Boolean = { + val stackTraces = acceptorStackTraces + if (stackTraces.isEmpty) + errors.add(s"Acceptor thread not found, threads=${Thread.getAllStackTraces.keySet}") + stackTraces.exists { case (thread, stackTrace) => + thread.getState == Thread.State.WAITING && stackTrace.contains("ArrayBlockingQueue") + } + } + + def registeredConnectionCount: Int = testableSelector.operationCounts.getOrElse(SelectorOperation.Register, 0) + + try { + // Block selector until Acceptor is blocked while connections are pending + testableSelector.pollCallback = () => { + try { + TestUtils.waitUntilTrue(() => errors.nonEmpty || registeredConnectionCount >= numConnections - 1 || acceptorBlocked, + "Acceptor not blocked", waitTimeMs = 10000) + } catch { + case _: Throwable => errors.add(s"Acceptor not blocked: $acceptorStackTraces") + } + } + testableSelector.operationCounts.clear() + val sockets = (1 to numConnections).map(_ => connect(testableServer)) + TestUtils.waitUntilTrue(() => errors.nonEmpty || registeredConnectionCount == numConnections, + "Connections not registered", waitTimeMs = 15000) + assertEquals(Set.empty, errors) + testableSelector.waitForOperations(SelectorOperation.Register, numConnections) + + // In each iteration, SocketServer processes at most connectionQueueSize (1 in this test) + // new connections and then does poll() to process data from existing connections. So for + // 5 connections, we expect 5 iterations. Since we stop when the 5th connection is processed, + // we can safely check that there were at least 4 polls prior to the 5th connection. 
+ val pollCount = testableSelector.operationCounts(SelectorOperation.Poll) + assertTrue(pollCount >= numConnections - 1, s"Connections created too quickly: $pollCount") + verifyAcceptorBlockedPercent("PLAINTEXT", expectBlocked = true) + + assertProcessorHealthy(testableServer, sockets) + } finally { + shutdownServerAndMetrics(testableServer) + } + } + + + @Test + def testControlPlaneAsPrivilegedListener(): Unit = { + val testProps = new Properties + testProps ++= props + testProps.put("listeners", "PLAINTEXT://localhost:0,CONTROLLER://localhost:0") + testProps.put("listener.security.protocol.map", "PLAINTEXT:PLAINTEXT,CONTROLLER:PLAINTEXT") + testProps.put("control.plane.listener.name", "CONTROLLER") + val config = KafkaConfig.fromProps(testProps) + withTestableServer(config, { testableServer => + val controlPlaneSocket = connect(testableServer, config.controlPlaneListenerName.get, + localAddr = InetAddress.getLocalHost) + val sentRequest = sendAndReceiveControllerRequest(controlPlaneSocket, testableServer) + assertTrue(sentRequest.context.fromPrivilegedListener) + + val plainSocket = connect(testableServer, localAddr = InetAddress.getLocalHost) + val plainRequest = sendAndReceiveRequest(plainSocket, testableServer) + assertFalse(plainRequest.context.fromPrivilegedListener) + }) + } + + @Test + def testInterBrokerListenerAsPrivilegedListener(): Unit = { + val testProps = new Properties + testProps ++= props + testProps.put("listeners", "EXTERNAL://localhost:0,INTERNAL://localhost:0") + testProps.put("listener.security.protocol.map", "EXTERNAL:PLAINTEXT,INTERNAL:PLAINTEXT") + testProps.put("inter.broker.listener.name", "INTERNAL") + val config = KafkaConfig.fromProps(testProps) + withTestableServer(config, { testableServer => + val interBrokerSocket = connect(testableServer, config.interBrokerListenerName, + localAddr = InetAddress.getLocalHost) + val sentRequest = sendAndReceiveRequest(interBrokerSocket, testableServer) + assertTrue(sentRequest.context.fromPrivilegedListener) + + val externalSocket = connect(testableServer, new ListenerName("EXTERNAL"), + localAddr = InetAddress.getLocalHost) + val externalRequest = sendAndReceiveRequest(externalSocket, testableServer) + assertFalse(externalRequest.context.fromPrivilegedListener) + }) + } + + @Test + def testControlPlaneTakePrecedenceOverInterBrokerListenerAsPrivilegedListener(): Unit = { + val testProps = new Properties + testProps ++= props + testProps.put("listeners", "EXTERNAL://localhost:0,INTERNAL://localhost:0,CONTROLLER://localhost:0") + testProps.put("listener.security.protocol.map", "EXTERNAL:PLAINTEXT,INTERNAL:PLAINTEXT,CONTROLLER:PLAINTEXT") + testProps.put("control.plane.listener.name", "CONTROLLER") + testProps.put("inter.broker.listener.name", "INTERNAL") + val config = KafkaConfig.fromProps(testProps) + withTestableServer(config, { testableServer => + val controlPlaneSocket = connect(testableServer, config.controlPlaneListenerName.get, + localAddr = InetAddress.getLocalHost) + val controlPlaneRequest = sendAndReceiveControllerRequest(controlPlaneSocket, testableServer) + assertTrue(controlPlaneRequest.context.fromPrivilegedListener) + + val interBrokerSocket = connect(testableServer, config.interBrokerListenerName, + localAddr = InetAddress.getLocalHost) + val interBrokerRequest = sendAndReceiveRequest(interBrokerSocket, testableServer) + assertFalse(interBrokerRequest.context.fromPrivilegedListener) + + val externalSocket = connect(testableServer, new ListenerName("EXTERNAL"), + localAddr = InetAddress.getLocalHost) + val 
externalRequest = sendAndReceiveRequest(externalSocket, testableServer) + assertFalse(externalRequest.context.fromPrivilegedListener) + }) + } + + private def sslServerProps: Properties = { + val trustStoreFile = File.createTempFile("truststore", ".jks") + val sslProps = TestUtils.createBrokerConfig(0, TestUtils.MockZkConnect, interBrokerSecurityProtocol = Some(SecurityProtocol.SSL), + trustStoreFile = Some(trustStoreFile)) + sslProps.put(KafkaConfig.ListenersProp, "SSL://localhost:0") + sslProps + } + + private def withTestableServer(config : KafkaConfig = KafkaConfig.fromProps(props), + testWithServer: TestableSocketServer => Unit): Unit = { + val testableServer = new TestableSocketServer(config) + testableServer.startup() + try { + testWithServer(testableServer) + } finally { + shutdownServerAndMetrics(testableServer) + assertEquals(0, testableServer.uncaughtExceptions) + } + } + + def sendAndReceiveControllerRequest(socket: Socket, server: SocketServer): RequestChannel.Request = { + sendRequest(socket, producerRequestBytes()) + receiveRequest(server.controlPlaneRequestChannelOpt.get) + } + + private def assertProcessorHealthy(testableServer: TestableSocketServer, healthySockets: Seq[Socket] = Seq.empty): Unit = { + val selector = testableServer.testableSelector + selector.reset() + val requestChannel = testableServer.dataPlaneRequestChannel + + // Check that existing channels behave as expected + healthySockets.foreach { socket => + val request = sendAndReceiveRequest(socket, testableServer) + processRequest(requestChannel, request) + socket.close() + } + TestUtils.waitUntilTrue(() => testableServer.connectionCount(localAddress) == 0, "Channels not removed") + + // Check new channel behaves as expected + val (socket, connectionId) = connectAndProcessRequest(testableServer) + assertArrayEquals(producerRequestBytes(), receiveResponse(socket)) + assertNotNull(selector.channel(connectionId), "Channel should not have been closed") + assertNull(selector.closingChannel(connectionId), "Channel should not be closing") + socket.close() + TestUtils.waitUntilTrue(() => testableServer.connectionCount(localAddress) == 0, "Channels not removed") + } + + // Since all sockets use the same local host, it is sufficient to check the local port + def isSocketConnectionId(connectionId: String, socket: Socket): Boolean = + connectionId.contains(s":${socket.getLocalPort}-") + + private def verifyAcceptorBlockedPercent(listenerName: String, expectBlocked: Boolean): Unit = { + val blockedPercentMetricMBeanName = "kafka.network:type=Acceptor,name=AcceptorBlockedPercent,listener=PLAINTEXT" + val blockedPercentMetrics = KafkaYammerMetrics.defaultRegistry.allMetrics.asScala.filter { case (k, _) => + k.getMBeanName == blockedPercentMetricMBeanName + }.values + assertEquals(1, blockedPercentMetrics.size) + val blockedPercentMetric = blockedPercentMetrics.head.asInstanceOf[Meter] + val blockedPercent = blockedPercentMetric.meanRate + if (expectBlocked) { + assertTrue(blockedPercent > 0.0, s"Acceptor blocked percent not recorded: $blockedPercent") + assertTrue(blockedPercent <= 1.0, s"Unexpected blocked percent in acceptor: $blockedPercent") + } else { + assertEquals(0.0, blockedPercent, 0.001) + } + } + + class TestableSocketServer( + config : KafkaConfig = KafkaConfig.fromProps(props), + connectionQueueSize: Int = 20, + time: Time = Time.SYSTEM + ) extends SocketServer( + config, new Metrics, time, credentialProvider, apiVersionManager, + ) { + + @volatile var selector: Option[TestableSelector] = None + @volatile var 
uncaughtExceptions = 0 + + override def newProcessor(id: Int, requestChannel: RequestChannel, connectionQuotas: ConnectionQuotas, listenerName: ListenerName, + protocol: SecurityProtocol, memoryPool: MemoryPool, isPrivilegedListener: Boolean = false): Processor = { + new Processor(id, time, config.socketRequestMaxBytes, requestChannel, connectionQuotas, config.connectionsMaxIdleMs, + config.failedAuthenticationDelayMs, listenerName, protocol, config, metrics, credentialProvider, + memoryPool, new LogContext(), connectionQueueSize, isPrivilegedListener, apiVersionManager) { + + override protected[network] def createSelector(channelBuilder: ChannelBuilder): Selector = { + val testableSelector = new TestableSelector(config, channelBuilder, time, metrics, metricTags.asScala) + selector = Some(testableSelector) + testableSelector + } + + override private[network] def processException(errorMessage: String, throwable: Throwable): Unit = { + if (errorMessage.contains("uncaught exception")) + uncaughtExceptions += 1 + super.processException(errorMessage, throwable) + } + } + } + + def testableSelector: TestableSelector = + selector.getOrElse(throw new IllegalStateException("Selector not created")) + + def waitForChannelClose(connectionId: String, locallyClosed: Boolean): Unit = { + val selector = testableSelector + if (locallyClosed) { + TestUtils.waitUntilTrue(() => selector.allLocallyClosedChannels.contains(connectionId), + s"Channel not closed: $connectionId") + assertTrue(testableSelector.allDisconnectedChannels.isEmpty, "Unexpected disconnect notification") + } else { + TestUtils.waitUntilTrue(() => selector.allDisconnectedChannels.contains(connectionId), + s"Disconnect notification not received: $connectionId") + assertTrue(testableSelector.allLocallyClosedChannels.isEmpty, "Channel closed locally") + } + val openCount = selector.allChannels.size - 1 // minus one for the channel just closed above + TestUtils.waitUntilTrue(() => connectionCount(localAddress) == openCount, "Connection count not decremented") + TestUtils.waitUntilTrue(() => + dataPlaneProcessor(0).inflightResponseCount == 0, "Inflight responses not cleared") + assertNull(selector.channel(connectionId), "Channel not removed") + assertNull(selector.closingChannel(connectionId), "Closing channel not removed") + } + } + + sealed trait SelectorOperation + object SelectorOperation { + case object Register extends SelectorOperation + case object Poll extends SelectorOperation + case object Send extends SelectorOperation + case object Mute extends SelectorOperation + case object Unmute extends SelectorOperation + case object Wakeup extends SelectorOperation + case object Close extends SelectorOperation + case object CloseSelector extends SelectorOperation + } + + class TestableSelector(config: KafkaConfig, channelBuilder: ChannelBuilder, time: Time, metrics: Metrics, metricTags: mutable.Map[String, String] = mutable.Map.empty) + extends Selector(config.socketRequestMaxBytes, config.connectionsMaxIdleMs, config.failedAuthenticationDelayMs, + metrics, time, "socket-server", metricTags.asJava, false, true, channelBuilder, MemoryPool.NONE, new LogContext()) { + + val failures = mutable.Map[SelectorOperation, Throwable]() + val operationCounts = mutable.Map[SelectorOperation, Int]().withDefaultValue(0) + val allChannels = mutable.Set[String]() + val allLocallyClosedChannels = mutable.Set[String]() + val allDisconnectedChannels = mutable.Set[String]() + val allFailedChannels = mutable.Set[String]() + + // Enable data from `Selector.poll()` to 
be deferred to a subsequent poll() until + // the number of elements of that type reaches `minPerPoll`. This enables tests to verify + // that failed processing doesn't impact subsequent processing within the same iteration. + abstract class PollData[T] { + var minPerPoll = 1 + val deferredValues = mutable.Buffer[T]() + + /** + * Process new results and return the results for the current poll if at least + * `minPerPoll` results are available including any deferred results. Otherwise + * add the provided values to the deferred set and return an empty buffer. This allows + * tests to process `minPerPoll` elements as the results of a single poll iteration. + */ + protected def update(newValues: mutable.Buffer[T]): mutable.Buffer[T] = { + val currentPollValues = mutable.Buffer[T]() + if (deferredValues.size + newValues.size >= minPerPoll) { + if (deferredValues.nonEmpty) { + currentPollValues ++= deferredValues + deferredValues.clear() + } + currentPollValues ++= newValues + } else + deferredValues ++= newValues + + currentPollValues + } + + /** + * Process results from the appropriate buffer in Selector and update the buffer to either + * defer and return nothing or return all results including previously deferred values. + */ + def updateResults(): Unit + } + + class CompletedReceivesPollData(selector: TestableSelector) extends PollData[NetworkReceive] { + val completedReceivesMap: util.Map[String, NetworkReceive] = JTestUtils.fieldValue(selector, classOf[Selector], "completedReceives") + + override def updateResults(): Unit = { + val currentReceives = update(selector.completedReceives.asScala.toBuffer) + completedReceivesMap.clear() + currentReceives.foreach { receive => + val channelOpt = Option(selector.channel(receive.source)).orElse(Option(selector.closingChannel(receive.source))) + channelOpt.foreach { channel => completedReceivesMap.put(channel.id, receive) } + } + } + } + + class CompletedSendsPollData(selector: TestableSelector) extends PollData[NetworkSend] { + override def updateResults(): Unit = { + val currentSends = update(selector.completedSends.asScala) + selector.completedSends.clear() + currentSends.foreach { selector.completedSends.add } + } + } + + class DisconnectedPollData(selector: TestableSelector) extends PollData[(String, ChannelState)] { + override def updateResults(): Unit = { + val currentDisconnected = update(selector.disconnected.asScala.toBuffer) + selector.disconnected.clear() + currentDisconnected.foreach { case (channelId, state) => selector.disconnected.put(channelId, state) } + } + } + + val cachedCompletedReceives = new CompletedReceivesPollData(this) + val cachedCompletedSends = new CompletedSendsPollData(this) + val cachedDisconnected = new DisconnectedPollData(this) + val allCachedPollData = Seq(cachedCompletedReceives, cachedCompletedSends, cachedDisconnected) + val pendingClosingChannels = new ConcurrentLinkedQueue[KafkaChannel]() + @volatile var minWakeupCount = 0 + @volatile var pollTimeoutOverride: Option[Long] = None + @volatile var pollCallback: () => Unit = () => {} + + def addFailure(operation: SelectorOperation, exception: Option[Throwable] = None): Unit = { + failures += operation -> + exception.getOrElse(new IllegalStateException(s"Test exception during $operation")) + } + + private def onOperation(operation: SelectorOperation, connectionId: Option[String], onFailure: => Unit): Unit = { + operationCounts(operation) += 1 + failures.remove(operation).foreach { e => + connectionId.foreach(allFailedChannels.add) + onFailure + throw e + } + } 
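+     // waitForOperations below blocks until the cumulative count recorded in `operationCounts` for the given operation + // reaches `minExpectedTotal`, while runOp executes the wrapped Selector call and then records it via onOperation above, + // rethrowing any failure injected with `addFailure`.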
+ + def waitForOperations(operation: SelectorOperation, minExpectedTotal: Int): Unit = { + TestUtils.waitUntilTrue(() => + operationCounts.getOrElse(operation, 0) >= minExpectedTotal, "Operations not performed within timeout") + } + + def runOp[T](operation: SelectorOperation, connectionId: Option[String], + onFailure: => Unit = {})(code: => T): T = { + // If a failure is set on `operation`, throw that exception even if `code` fails + try code + finally onOperation(operation, connectionId, onFailure) + } + + override def register(id: String, socketChannel: SocketChannel): Unit = { + runOp(SelectorOperation.Register, Some(id), onFailure = close(id)) { + super.register(id, socketChannel) + } + } + + override def send(s: NetworkSend): Unit = { + runOp(SelectorOperation.Send, Some(s.destinationId)) { + super.send(s) + } + } + + override def poll(timeout: Long): Unit = { + try { + assertEquals(0, super.completedReceives().size) + assertEquals(0, super.completedSends().size) + + pollCallback.apply() + while (!pendingClosingChannels.isEmpty) { + makeClosing(pendingClosingChannels.poll()) + } + runOp(SelectorOperation.Poll, None) { + super.poll(pollTimeoutOverride.getOrElse(timeout)) + } + } finally { + super.channels.forEach(allChannels += _.id) + allDisconnectedChannels ++= super.disconnected.asScala.keys + + cachedCompletedReceives.updateResults() + cachedCompletedSends.updateResults() + cachedDisconnected.updateResults() + } + } + + override def mute(id: String): Unit = { + runOp(SelectorOperation.Mute, Some(id)) { + super.mute(id) + } + } + + override def unmute(id: String): Unit = { + runOp(SelectorOperation.Unmute, Some(id)) { + super.unmute(id) + } + } + + override def wakeup(): Unit = { + runOp(SelectorOperation.Wakeup, None) { + if (minWakeupCount > 0) + minWakeupCount -= 1 + if (minWakeupCount <= 0) + super.wakeup() + } + } + + override def close(id: String): Unit = { + runOp(SelectorOperation.Close, Some(id)) { + super.close(id) + allLocallyClosedChannels += id + } + } + + override def close(): Unit = { + runOp(SelectorOperation.CloseSelector, None) { + super.close() + } + } + + def updateMinWakeup(count: Int): Unit = { + minWakeupCount = count + // For tests that ignore wakeup to process responses together, increase poll timeout + // to ensure that poll doesn't complete before the responses are ready + pollTimeoutOverride = Some(1000L) + // Wakeup current poll to force new poll timeout to take effect + super.wakeup() + } + + def reset(): Unit = { + failures.clear() + allCachedPollData.foreach(_.minPerPoll = 1) + } + + def notFailed(sockets: Seq[Socket]): Seq[Socket] = { + // Each test generates failure for exactly one failed channel + assertEquals(1, allFailedChannels.size) + val failedConnectionId = allFailedChannels.head + sockets.filterNot(socket => isSocketConnectionId(failedConnectionId, socket)) + } + + private def makeClosing(channel: KafkaChannel): Unit = { + val channels: util.Map[String, KafkaChannel] = JTestUtils.fieldValue(this, classOf[Selector], "channels") + val closingChannels: util.Map[String, KafkaChannel] = JTestUtils.fieldValue(this, classOf[Selector], "closingChannels") + closingChannels.put(channel.id, channel) + channels.remove(channel.id) + } + } + + /** + * Proxy server used to intercept connections to SocketServer. This is used for testing SSL channels + * with buffered data. A single SSL client is expected to be created by the test using this ProxyServer. 
+ * By default, data between the client and the server is simply transferred across to the destination by ProxyServer. + * Tests can enable buffering in ProxyServer to directly copy incoming data from the client to the server-side + * channel's `netReadBuffer` to simulate scenarios with SSL buffered data. + */ + private class ProxyServer(socketServer: SocketServer) { + val serverSocket = new ServerSocket(0) + val localPort = serverSocket.getLocalPort + val serverConnSocket = new Socket("localhost", socketServer.boundPort(ListenerName.forSecurityProtocol(SecurityProtocol.SSL))) + val executor = Executors.newFixedThreadPool(2) + @volatile var clientConnSocket: Socket = _ + @volatile var buffer: Option[ByteBuffer] = None + + executor.submit((() => { + try { + clientConnSocket = serverSocket.accept() + val serverOut = serverConnSocket.getOutputStream + val clientIn = clientConnSocket.getInputStream + var b: Int = -1 + while ({b = clientIn.read(); b != -1}) { + buffer match { + case Some(buf) => + buf.put(b.asInstanceOf[Byte]) + case None => + serverOut.write(b) + serverOut.flush() + } + } + } finally { + clientConnSocket.close() + } + }): Runnable) + + executor.submit((() => { + var b: Int = -1 + val serverIn = serverConnSocket.getInputStream + while ({b = serverIn.read(); b != -1}) { + clientConnSocket.getOutputStream.write(b) + } + }): Runnable) + + def enableBuffering(buffer: ByteBuffer): Unit = this.buffer = Some(buffer) + + def close(): Unit = { + serverSocket.close() + serverConnSocket.close() + clientConnSocket.close() + executor.shutdownNow() + assertTrue(executor.awaitTermination(10, TimeUnit.SECONDS)) + } + + } +} diff --git a/core/src/test/scala/unit/kafka/raft/KafkaNetworkChannelTest.scala b/core/src/test/scala/unit/kafka/raft/KafkaNetworkChannelTest.scala new file mode 100644 index 0000000..eaa21dd --- /dev/null +++ b/core/src/test/scala/unit/kafka/raft/KafkaNetworkChannelTest.scala @@ -0,0 +1,291 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package kafka.raft + +import java.net.InetSocketAddress +import java.util +import java.util.Collections +import org.apache.kafka.clients.MockClient.MockMetadataUpdater +import org.apache.kafka.clients.{MockClient, NodeApiVersions} +import org.apache.kafka.common.message.{BeginQuorumEpochResponseData, EndQuorumEpochResponseData, FetchResponseData, VoteResponseData} +import org.apache.kafka.common.protocol.{ApiKeys, ApiMessage, Errors} +import org.apache.kafka.common.requests.{AbstractResponse, ApiVersionsResponse, BeginQuorumEpochRequest, BeginQuorumEpochResponse, EndQuorumEpochRequest, EndQuorumEpochResponse, FetchResponse, VoteRequest, VoteResponse} +import org.apache.kafka.common.utils.{MockTime, Time} +import org.apache.kafka.common.{Node, TopicPartition, Uuid} +import org.apache.kafka.raft.RaftConfig.InetAddressSpec +import org.apache.kafka.raft.{RaftRequest, RaftUtil} +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.{BeforeEach, Test} + +import scala.jdk.CollectionConverters._ + +class KafkaNetworkChannelTest { + import KafkaNetworkChannelTest._ + + private val clusterId = "clusterId" + private val requestTimeoutMs = 30000 + private val time = new MockTime() + private val client = new MockClient(time, new StubMetadataUpdater) + private val topicPartition = new TopicPartition("topic", 0) + private val topicId = Uuid.randomUuid() + private val channel = new KafkaNetworkChannel(time, client, requestTimeoutMs, threadNamePrefix = "test-raft") + + @BeforeEach + def setupSupportedApis(): Unit = { + val supportedApis = RaftApis.map(ApiVersionsResponse.toApiVersion) + client.setNodeApiVersions(NodeApiVersions.create(supportedApis.asJava)) + } + + @Test + def testSendToUnknownDestination(): Unit = { + val destinationId = 2 + assertBrokerNotAvailable(destinationId) + } + + @Test + def testSendToBlackedOutDestination(): Unit = { + val destinationId = 2 + val destinationNode = new Node(destinationId, "127.0.0.1", 9092) + channel.updateEndpoint(destinationId, new InetAddressSpec( + new InetSocketAddress(destinationNode.host, destinationNode.port))) + client.backoff(destinationNode, 500) + assertBrokerNotAvailable(destinationId) + } + + @Test + def testWakeupClientOnSend(): Unit = { + val destinationId = 2 + val destinationNode = new Node(destinationId, "127.0.0.1", 9092) + channel.updateEndpoint(destinationId, new InetAddressSpec( + new InetSocketAddress(destinationNode.host, destinationNode.port))) + + client.enableBlockingUntilWakeup(1) + + val ioThread = new Thread() { + override def run(): Unit = { + // Block in poll until we get the expected wakeup + channel.pollOnce() + + // Poll a second time to send request and receive response + channel.pollOnce() + } + } + + val response = buildResponse(buildTestErrorResponse(ApiKeys.FETCH, Errors.INVALID_REQUEST)) + client.prepareResponseFrom(response, destinationNode, false) + + ioThread.start() + val request = sendTestRequest(ApiKeys.FETCH, destinationId) + + ioThread.join() + assertResponseCompleted(request, Errors.INVALID_REQUEST) + } + + @Test + def testSendAndDisconnect(): Unit = { + val destinationId = 2 + val destinationNode = new Node(destinationId, "127.0.0.1", 9092) + channel.updateEndpoint(destinationId, new InetAddressSpec( + new InetSocketAddress(destinationNode.host, destinationNode.port))) + + for (apiKey <- RaftApis) { + val response = buildResponse(buildTestErrorResponse(apiKey, Errors.INVALID_REQUEST)) + client.prepareResponseFrom(response, destinationNode, true) + sendAndAssertErrorResponse(apiKey, 
destinationId, Errors.BROKER_NOT_AVAILABLE) + } + } + + @Test + def testSendAndFailAuthentication(): Unit = { + val destinationId = 2 + val destinationNode = new Node(destinationId, "127.0.0.1", 9092) + channel.updateEndpoint(destinationId, new InetAddressSpec( + new InetSocketAddress(destinationNode.host, destinationNode.port))) + + for (apiKey <- RaftApis) { + client.createPendingAuthenticationError(destinationNode, 100) + sendAndAssertErrorResponse(apiKey, destinationId, Errors.NETWORK_EXCEPTION) + + // reset to clear backoff time + client.reset() + } + } + + private def assertBrokerNotAvailable(destinationId: Int): Unit = { + for (apiKey <- RaftApis) { + sendAndAssertErrorResponse(apiKey, destinationId, Errors.BROKER_NOT_AVAILABLE) + } + } + + @Test + def testSendAndReceiveOutboundRequest(): Unit = { + val destinationId = 2 + val destinationNode = new Node(destinationId, "127.0.0.1", 9092) + channel.updateEndpoint(destinationId, new InetAddressSpec( + new InetSocketAddress(destinationNode.host, destinationNode.port))) + + for (apiKey <- RaftApis) { + val expectedError = Errors.INVALID_REQUEST + val response = buildResponse(buildTestErrorResponse(apiKey, expectedError)) + client.prepareResponseFrom(response, destinationNode) + sendAndAssertErrorResponse(apiKey, destinationId, expectedError) + } + } + + @Test + def testUnsupportedVersionError(): Unit = { + val destinationId = 2 + val destinationNode = new Node(destinationId, "127.0.0.1", 9092) + channel.updateEndpoint(destinationId, new InetAddressSpec( + new InetSocketAddress(destinationNode.host, destinationNode.port))) + + for (apiKey <- RaftApis) { + client.prepareUnsupportedVersionResponse(request => request.apiKey == apiKey) + sendAndAssertErrorResponse(apiKey, destinationId, Errors.UNSUPPORTED_VERSION) + } + } + + private def sendTestRequest( + apiKey: ApiKeys, + destinationId: Int, + ): RaftRequest.Outbound = { + val correlationId = channel.newCorrelationId() + val createdTimeMs = time.milliseconds() + val apiRequest = buildTestRequest(apiKey) + val request = new RaftRequest.Outbound(correlationId, apiRequest, destinationId, createdTimeMs) + channel.send(request) + request + } + + private def assertResponseCompleted( + request: RaftRequest.Outbound, + expectedError: Errors + ): Unit = { + assertTrue(request.completion.isDone) + + val response = request.completion.get() + assertEquals(request.destinationId, response.sourceId) + assertEquals(request.correlationId, response.correlationId) + assertEquals(request.data.apiKey, response.data.apiKey) + assertEquals(expectedError, extractError(response.data)) + } + + private def sendAndAssertErrorResponse( + apiKey: ApiKeys, + destinationId: Int, + error: Errors + ): Unit = { + val request = sendTestRequest(apiKey, destinationId) + channel.pollOnce() + assertResponseCompleted(request, error) + } + + private def buildTestRequest(key: ApiKeys): ApiMessage = { + val leaderEpoch = 5 + val leaderId = 1 + key match { + case ApiKeys.BEGIN_QUORUM_EPOCH => + BeginQuorumEpochRequest.singletonRequest(topicPartition, clusterId, leaderEpoch, leaderId) + + case ApiKeys.END_QUORUM_EPOCH => + EndQuorumEpochRequest.singletonRequest(topicPartition, clusterId, leaderId, + leaderEpoch, Collections.singletonList(2)) + + case ApiKeys.VOTE => + val lastEpoch = 4 + VoteRequest.singletonRequest(topicPartition, clusterId, leaderEpoch, leaderId, lastEpoch, 329) + + case ApiKeys.FETCH => + val request = RaftUtil.singletonFetchRequest(topicPartition, topicId, fetchPartition => { + fetchPartition + 
.setCurrentLeaderEpoch(5) + .setFetchOffset(333) + .setLastFetchedEpoch(5) + }) + request.setReplicaId(1) + + case _ => + throw new AssertionError(s"Unexpected api $key") + } + } + + private def buildTestErrorResponse(key: ApiKeys, error: Errors): ApiMessage = { + key match { + case ApiKeys.BEGIN_QUORUM_EPOCH => + new BeginQuorumEpochResponseData() + .setErrorCode(error.code) + + case ApiKeys.END_QUORUM_EPOCH => + new EndQuorumEpochResponseData() + .setErrorCode(error.code) + + case ApiKeys.VOTE => + VoteResponse.singletonResponse(error, topicPartition, Errors.NONE, 1, 5, false); + + case ApiKeys.FETCH => + new FetchResponseData() + .setErrorCode(error.code) + + case _ => + throw new AssertionError(s"Unexpected api $key") + } + } + + private def extractError(response: ApiMessage): Errors = { + val code = (response: @unchecked) match { + case res: BeginQuorumEpochResponseData => res.errorCode + case res: EndQuorumEpochResponseData => res.errorCode + case res: FetchResponseData => res.errorCode + case res: VoteResponseData => res.errorCode + } + Errors.forCode(code) + } + + + def buildResponse(responseData: ApiMessage): AbstractResponse = { + responseData match { + case voteResponse: VoteResponseData => + new VoteResponse(voteResponse) + case beginEpochResponse: BeginQuorumEpochResponseData => + new BeginQuorumEpochResponse(beginEpochResponse) + case endEpochResponse: EndQuorumEpochResponseData => + new EndQuorumEpochResponse(endEpochResponse) + case fetchResponse: FetchResponseData => + new FetchResponse(fetchResponse) + case _ => + throw new IllegalArgumentException(s"Unexpected type for responseData: $responseData") + } + } + +} + +object KafkaNetworkChannelTest { + val RaftApis = Seq( + ApiKeys.VOTE, + ApiKeys.BEGIN_QUORUM_EPOCH, + ApiKeys.END_QUORUM_EPOCH, + ApiKeys.FETCH, + ) + + private class StubMetadataUpdater extends MockMetadataUpdater { + override def fetchNodes(): util.List[Node] = Collections.emptyList() + + override def isUpdateNeeded: Boolean = false + + override def update(time: Time, update: MockClient.MetadataUpdate): Unit = {} + } +} diff --git a/core/src/test/scala/unit/kafka/raft/RaftManagerTest.scala b/core/src/test/scala/unit/kafka/raft/RaftManagerTest.scala new file mode 100644 index 0000000..a7a9519 --- /dev/null +++ b/core/src/test/scala/unit/kafka/raft/RaftManagerTest.scala @@ -0,0 +1,145 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package kafka.raft + +import java.util.concurrent.CompletableFuture +import java.util.Properties + +import kafka.raft.KafkaRaftManager.RaftIoThread +import kafka.server.{KafkaConfig, MetaProperties} +import kafka.tools.TestRaftServer.ByteArraySerde +import org.apache.kafka.common.TopicPartition +import org.apache.kafka.common.Uuid +import org.apache.kafka.common.metrics.Metrics +import org.apache.kafka.common.utils.Time +import org.apache.kafka.raft.KafkaRaftClient +import org.apache.kafka.raft.RaftConfig +import org.apache.kafka.test.TestUtils +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.Test +import org.mockito.Mockito._ + +import java.io.File + +class RaftManagerTest { + + private def instantiateRaftManagerWithConfigs(topicPartition: TopicPartition, processRoles: String, nodeId: String) = { + def configWithProcessRolesAndNodeId(processRoles: String, nodeId: String, logDir: File): KafkaConfig = { + val props = new Properties + props.setProperty(KafkaConfig.MetadataLogDirProp, logDir.getPath) + props.setProperty(KafkaConfig.ProcessRolesProp, processRoles) + props.setProperty(KafkaConfig.NodeIdProp, nodeId) + props.setProperty(KafkaConfig.ControllerListenerNamesProp, "SSL") + if (processRoles.contains("broker")) { + props.setProperty(KafkaConfig.InterBrokerListenerNameProp, "PLAINTEXT") + if (processRoles.contains("controller")) { // co-located + props.setProperty(KafkaConfig.ListenersProp, "PLAINTEXT://localhost:9092,SSL://localhost:9093") + props.setProperty(KafkaConfig.QuorumVotersProp, s"${nodeId}@localhost:9093") + } else { // broker-only + val voterId = (nodeId.toInt + 1) + props.setProperty(KafkaConfig.QuorumVotersProp, s"${voterId}@localhost:9093") + } + } else if (processRoles.contains("controller")) { // controller-only + props.setProperty(KafkaConfig.ListenersProp, "SSL://localhost:9093") + props.setProperty(KafkaConfig.QuorumVotersProp, s"${nodeId}@localhost:9093") + } + + new KafkaConfig(props) + } + + val logDir = TestUtils.tempDirectory() + val config = configWithProcessRolesAndNodeId(processRoles, nodeId, logDir) + val topicId = new Uuid(0L, 2L) + val metaProperties = MetaProperties( + clusterId = Uuid.randomUuid.toString, + nodeId = config.nodeId + ) + + new KafkaRaftManager[Array[Byte]]( + metaProperties, + config, + new ByteArraySerde, + topicPartition, + topicId, + Time.SYSTEM, + new Metrics(Time.SYSTEM), + Option.empty, + CompletableFuture.completedFuture(RaftConfig.parseVoterConnections(config.quorumVoters)) + ) + } + + @Test + def testSentinelNodeIdIfBrokerRoleOnly(): Unit = { + val raftManager = instantiateRaftManagerWithConfigs(new TopicPartition("__raft_id_test", 0), "broker", "1") + assertFalse(raftManager.client.nodeId.isPresent) + raftManager.shutdown() + } + + @Test + def testNodeIdPresentIfControllerRoleOnly(): Unit = { + val raftManager = instantiateRaftManagerWithConfigs(new TopicPartition("__raft_id_test", 0), "controller", "1") + assertTrue(raftManager.client.nodeId.getAsInt == 1) + raftManager.shutdown() + } + + @Test + def testNodeIdPresentIfColocated(): Unit = { + val raftManager = instantiateRaftManagerWithConfigs(new TopicPartition("__raft_id_test", 0), "controller,broker", "1") + assertTrue(raftManager.client.nodeId.getAsInt == 1) + raftManager.shutdown() + } + + @Test + def testShutdownIoThread(): Unit = { + val raftClient = mock(classOf[KafkaRaftClient[String]]) + val ioThread = new RaftIoThread(raftClient, threadNamePrefix = "test-raft") + + when(raftClient.isRunning).thenReturn(true) + 
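+     // While the mocked client reports that it is running, the IO thread should also report that it is running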
assertTrue(ioThread.isRunning) + + val shutdownFuture = new CompletableFuture[Void] + when(raftClient.shutdown(5000)).thenReturn(shutdownFuture) + + ioThread.initiateShutdown() + assertTrue(ioThread.isRunning) + assertTrue(ioThread.isShutdownInitiated) + verify(raftClient).shutdown(5000) + + shutdownFuture.complete(null) + when(raftClient.isRunning).thenReturn(false) + ioThread.run() + assertFalse(ioThread.isRunning) + assertTrue(ioThread.isShutdownComplete) + } + + @Test + def testUncaughtExceptionInIoThread(): Unit = { + val raftClient = mock(classOf[KafkaRaftClient[String]]) + val ioThread = new RaftIoThread(raftClient, threadNamePrefix = "test-raft") + + when(raftClient.isRunning).thenReturn(true) + assertTrue(ioThread.isRunning) + + when(raftClient.poll()).thenThrow(new RuntimeException) + ioThread.run() + + assertTrue(ioThread.isShutdownComplete) + assertTrue(ioThread.isThreadFailed) + assertFalse(ioThread.isRunning) + } + +} diff --git a/core/src/test/scala/unit/kafka/security/auth/ZkAuthorizationTest.scala b/core/src/test/scala/unit/kafka/security/auth/ZkAuthorizationTest.scala new file mode 100644 index 0000000..3bbce4d --- /dev/null +++ b/core/src/test/scala/unit/kafka/security/auth/ZkAuthorizationTest.scala @@ -0,0 +1,342 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package kafka.security.auth + +import java.nio.charset.StandardCharsets +import kafka.admin.ZkSecurityMigrator +import kafka.server.QuorumTestHarness +import kafka.utils.{Logging, TestUtils} +import kafka.zk._ +import org.apache.kafka.common.{KafkaException, TopicPartition, Uuid} +import org.apache.kafka.common.security.JaasUtils +import org.apache.zookeeper.data.{ACL, Stat} +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.{AfterEach, BeforeEach, Test, TestInfo} + +import scala.util.{Failure, Success, Try} +import javax.security.auth.login.Configuration +import kafka.api.ApiVersion +import kafka.cluster.{Broker, EndPoint} +import kafka.controller.ReplicaAssignment +import org.apache.kafka.common.network.ListenerName +import org.apache.kafka.common.security.auth.SecurityProtocol +import org.apache.kafka.common.utils.Time +import org.apache.zookeeper.client.ZKClientConfig + +import scala.jdk.CollectionConverters._ +import scala.collection.Seq + +class ZkAuthorizationTest extends QuorumTestHarness with Logging { + val jaasFile = kafka.utils.JaasTestUtils.writeJaasContextsToFile(kafka.utils.JaasTestUtils.zkSections) + val authProvider = "zookeeper.authProvider.1" + + @BeforeEach + override def setUp(testInfo: TestInfo): Unit = { + System.setProperty(JaasUtils.JAVA_LOGIN_CONFIG_PARAM, jaasFile.getAbsolutePath) + Configuration.setConfiguration(null) + System.setProperty(authProvider, "org.apache.zookeeper.server.auth.SASLAuthenticationProvider") + super.setUp(testInfo) + } + + @AfterEach + override def tearDown(): Unit = { + super.tearDown() + System.clearProperty(JaasUtils.JAVA_LOGIN_CONFIG_PARAM) + System.clearProperty(authProvider) + Configuration.setConfiguration(null) + } + + /** + * Tests the method in JaasUtils that checks whether to use + * secure ACLs and authentication with ZooKeeper. + */ + @Test + def testIsZkSecurityEnabled(): Unit = { + assertTrue(JaasUtils.isZkSaslEnabled()) + Configuration.setConfiguration(null) + System.clearProperty(JaasUtils.JAVA_LOGIN_CONFIG_PARAM) + assertFalse(JaasUtils.isZkSaslEnabled()) + Configuration.setConfiguration(null) + System.setProperty(JaasUtils.JAVA_LOGIN_CONFIG_PARAM, "no-such-file-exists.conf") + assertThrows(classOf[KafkaException], () => JaasUtils.isZkSaslEnabled()) + } + + /** + * Exercises the code in KafkaZkClient. The goal is mainly + * to verify that the behavior of KafkaZkClient is correct + * when isSecure is set to true. 
+ */ + @Test + def testKafkaZkClient(): Unit = { + assertTrue(zkClient.secure) + for (path <- ZkData.PersistentZkPaths) { + zkClient.makeSurePersistentPathExists(path) + if (ZkData.sensitivePath(path)) { + val aclList = zkClient.getAcl(path) + assertEquals(1, aclList.size, s"Unexpected acl list size for $path") + for (acl <- aclList) + assertTrue(TestUtils.isAclSecure(acl, sensitive = true)) + } else if (!path.equals(ConsumerPathZNode.path)) { + val aclList = zkClient.getAcl(path) + assertEquals(2, aclList.size, s"Unexpected acl list size for $path") + for (acl <- aclList) + assertTrue(TestUtils.isAclSecure(acl, sensitive = false)) + } + } + + // Test that creates Ephemeral node + val brokerInfo = createBrokerInfo(1, "test.host", 9999, SecurityProtocol.PLAINTEXT) + zkClient.registerBroker(brokerInfo) + verify(brokerInfo.path) + + // Test that creates persistent nodes + val topic1 = "topic1" + val topicId = Some(Uuid.randomUuid()) + val assignment = Map( + new TopicPartition(topic1, 0) -> Seq(0, 1), + new TopicPartition(topic1, 1) -> Seq(0, 1), + new TopicPartition(topic1, 2) -> Seq(1, 2, 3) + ) + + // create a topic assignment + zkClient.createTopicAssignment(topic1, topicId, assignment) + verify(TopicZNode.path(topic1)) + + // Test that can create: createSequentialPersistentPath + val seqPath = zkClient.createSequentialPersistentPath("/c", "".getBytes(StandardCharsets.UTF_8)) + verify(seqPath) + + // Test that can update Ephemeral node + val updatedBrokerInfo = createBrokerInfo(1, "test.host2", 9995, SecurityProtocol.SSL) + zkClient.updateBrokerInfo(updatedBrokerInfo) + assertEquals(Some(updatedBrokerInfo.broker), zkClient.getBroker(1)) + + // Test that can update persistent nodes + val updatedAssignment = assignment - new TopicPartition(topic1, 2) + zkClient.setTopicAssignment(topic1, topicId, + updatedAssignment.map { case (k, v) => k -> ReplicaAssignment(v, List(), List()) }) + assertEquals(updatedAssignment.size, zkClient.getTopicPartitionCount(topic1).get) + } + + private def createBrokerInfo(id: Int, host: String, port: Int, securityProtocol: SecurityProtocol, + rack: Option[String] = None): BrokerInfo = + BrokerInfo(Broker(id, Seq(new EndPoint(host, port, ListenerName.forSecurityProtocol + (securityProtocol), securityProtocol)), rack = rack), ApiVersion.latestVersion, jmxPort = port + 10) + + private def newKafkaZkClient(connectionString: String, isSecure: Boolean) = + KafkaZkClient(connectionString, isSecure, 6000, 6000, Int.MaxValue, Time.SYSTEM, "ZkAuthorizationTest", + new ZKClientConfig) + + /** + * Tests the migration tool when making an unsecure + * cluster secure. + */ + @Test + def testZkMigration(): Unit = { + val unsecureZkClient = newKafkaZkClient(zkConnect, isSecure = false) + try { + testMigration(zkConnect, unsecureZkClient, zkClient) + } finally { + unsecureZkClient.close() + } + } + + /** + * Tests the migration tool when making a secure + * cluster unsecure. + */ + @Test + def testZkAntiMigration(): Unit = { + val unsecureZkClient = newKafkaZkClient(zkConnect, isSecure = false) + try { + testMigration(zkConnect, zkClient, unsecureZkClient) + } finally { + unsecureZkClient.close() + } + } + + /** + * Tests that the persistent paths cannot be deleted. + */ + @Test + def testDelete(): Unit = { + info(s"zkConnect string: $zkConnect") + ZkSecurityMigrator.run(Array("--zookeeper.acl=secure", s"--zookeeper.connect=$zkConnect")) + deleteAllUnsecure() + } + + /** + * Tests that znodes cannot be deleted when the + * persistent paths have children. 
+ */ + @Test + def testDeleteRecursive(): Unit = { + info(s"zkConnect string: $zkConnect") + for (path <- ZkData.SecureRootPaths) { + info(s"Creating $path") + zkClient.makeSurePersistentPathExists(path) + zkClient.createRecursive(s"$path/fpjwashere", "".getBytes(StandardCharsets.UTF_8)) + } + zkClient.setAcl("/", zkClient.defaultAcls("/")) + deleteAllUnsecure() + } + + /** + * Tests the migration tool when chroot is being used. + */ + @Test + def testChroot(): Unit = { + val zkUrl = zkConnect + "/kafka" + zkClient.createRecursive("/kafka") + val unsecureZkClient = newKafkaZkClient(zkUrl, isSecure = false) + val secureZkClient = newKafkaZkClient(zkUrl, isSecure = true) + try { + testMigration(zkUrl, unsecureZkClient, secureZkClient) + } finally { + unsecureZkClient.close() + secureZkClient.close() + } + } + + /** + * Exercises the migration tool. It is used in these test cases: + * testZkMigration, testZkAntiMigration, testChroot. + */ + private def testMigration(zkUrl: String, firstZk: KafkaZkClient, secondZk: KafkaZkClient): Unit = { + info(s"zkConnect string: $zkUrl") + for (path <- ZkData.SecureRootPaths ++ ZkData.SensitiveRootPaths) { + info(s"Creating $path") + firstZk.makeSurePersistentPathExists(path) + // Create a child for each znode to exercise the recurrent + // traversal of the data tree + firstZk.createRecursive(s"$path/fpjwashere", "".getBytes(StandardCharsets.UTF_8)) + } + // Getting security option to determine how to verify ACLs. + // Additionally, we create the consumers znode (not in + // securePersistentZkPaths) to make sure that we don't + // add ACLs to it. + val secureOpt: String = + if (secondZk.secure) { + firstZk.createRecursive(ConsumerPathZNode.path) + "secure" + } else { + secondZk.createRecursive(ConsumerPathZNode.path) + "unsecure" + } + ZkSecurityMigrator.run(Array(s"--zookeeper.acl=$secureOpt", s"--zookeeper.connect=$zkUrl")) + info("Done with migration") + for (path <- ZkData.SecureRootPaths ++ ZkData.SensitiveRootPaths) { + val sensitive = ZkData.sensitivePath(path) + val listParent = secondZk.getAcl(path) + assertTrue(isAclCorrect(listParent, secondZk.secure, sensitive), path) + + val childPath = path + "/fpjwashere" + val listChild = secondZk.getAcl(childPath) + assertTrue(isAclCorrect(listChild, secondZk.secure, sensitive), childPath) + } + // Check consumers path. + val consumersAcl = firstZk.getAcl(ConsumerPathZNode.path) + assertTrue(isAclCorrect(consumersAcl, false, false), ConsumerPathZNode.path) + assertTrue(isAclCorrect(firstZk.getAcl("/kafka-acl-extended"), secondZk.secure, + ZkData.sensitivePath(ExtendedAclZNode.path)), "/kafka-acl-extended") + } + + /** + * Verifies that the path has the appropriate secure ACL. + */ + private def verify(path: String): Unit = { + val sensitive = ZkData.sensitivePath(path) + val list = zkClient.getAcl(path) + assertTrue(list.forall(TestUtils.isAclSecure(_, sensitive))) + } + + /** + * Verifies ACL. + */ + private def isAclCorrect(list: Seq[ACL], secure: Boolean, sensitive: Boolean): Boolean = { + val isListSizeCorrect = + if (secure && !sensitive) + list.size == 2 + else + list.size == 1 + isListSizeCorrect && list.forall( + if (secure) + TestUtils.isAclSecure(_, sensitive) + else + TestUtils.isAclUnsecure + ) + } + + /** + * Sets up and starts the recursive execution of deletes. + * This is used in the testDelete and testDeleteRecursive + * test cases. 
+ */ + private def deleteAllUnsecure(): Unit = { + System.setProperty(JaasUtils.ZK_SASL_CLIENT, "false") + val unsecureZkClient = newKafkaZkClient(zkConnect, isSecure = false) + val result: Try[Boolean] = { + deleteRecursive(unsecureZkClient, "/") + } + // Clean up before leaving the test case + unsecureZkClient.close() + System.clearProperty(JaasUtils.ZK_SASL_CLIENT) + + // Fail the test if able to delete + result match { + case Success(_) => // All done + case Failure(e) => fail(e.getMessage) + } + } + + /** + * Tries to delete znodes recursively + */ + private def deleteRecursive(zkClient: KafkaZkClient, path: String): Try[Boolean] = { + info(s"Deleting $path") + var result: Try[Boolean] = Success(true) + for (child <- zkClient.getChildren(path)) + result = (path match { + case "/" => deleteRecursive(zkClient, s"/$child") + case path => deleteRecursive(zkClient, s"$path/$child") + }) match { + case Success(_) => result + case Failure(e) => Failure(e) + } + path match { + // Do not try to delete the root + case "/" => result + // For all other paths, try to delete it + case path => + try { + zkClient.deletePath(path, recursiveDelete = false) + Failure(new Exception(s"Have been able to delete $path")) + } catch { + case _: Exception => result + } + } + } + + @Test + def testConsumerOffsetPathAcls(): Unit = { + zkClient.makeSurePersistentPathExists(ConsumerPathZNode.path) + + val consumerPathAcls = zkClient.currentZooKeeper.getACL(ConsumerPathZNode.path, new Stat()) + assertTrue(consumerPathAcls.asScala.forall(TestUtils.isAclUnsecure), "old consumer znode path acls are not open") + } +} diff --git a/core/src/test/scala/unit/kafka/security/authorizer/AclAuthorizerTest.scala b/core/src/test/scala/unit/kafka/security/authorizer/AclAuthorizerTest.scala new file mode 100644 index 0000000..9011eb6 --- /dev/null +++ b/core/src/test/scala/unit/kafka/security/authorizer/AclAuthorizerTest.scala @@ -0,0 +1,1082 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package kafka.security.authorizer + +import java.io.File +import java.net.InetAddress +import java.nio.charset.StandardCharsets.UTF_8 +import java.nio.file.Files +import java.util.{Collections, UUID} +import java.util.concurrent.{Executors, Semaphore, TimeUnit} +import kafka.Kafka +import kafka.api.{ApiVersion, KAFKA_2_0_IV0, KAFKA_2_0_IV1} +import kafka.security.authorizer.AclEntry.{WildcardHost, WildcardPrincipalString} +import kafka.server.{KafkaConfig, QuorumTestHarness} +import kafka.utils.TestUtils +import kafka.zk.ZkAclStore +import kafka.zookeeper.{GetChildrenRequest, GetDataRequest, ZooKeeperClient} +import org.apache.kafka.common.acl._ +import org.apache.kafka.common.acl.AclOperation._ +import org.apache.kafka.common.acl.AclPermissionType.{ALLOW, DENY} +import org.apache.kafka.common.errors.{ApiException, UnsupportedVersionException} +import org.apache.kafka.common.requests.RequestContext +import org.apache.kafka.common.resource.{PatternType, ResourcePattern, ResourcePatternFilter, ResourceType} +import org.apache.kafka.common.resource.Resource.CLUSTER_NAME +import org.apache.kafka.common.resource.ResourcePattern.WILDCARD_RESOURCE +import org.apache.kafka.common.resource.ResourceType._ +import org.apache.kafka.common.resource.PatternType.{LITERAL, MATCH, PREFIXED} +import org.apache.kafka.common.security.auth.KafkaPrincipal +import org.apache.kafka.server.authorizer._ +import org.apache.kafka.common.utils.{Time, SecurityUtils => JSecurityUtils} +import org.apache.zookeeper.client.ZKClientConfig +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.{AfterEach, BeforeEach, Test, TestInfo} + +import scala.jdk.CollectionConverters._ +import scala.collection.mutable + +class AclAuthorizerTest extends QuorumTestHarness with BaseAuthorizerTest { + + private val allowReadAcl = new AccessControlEntry(WildcardPrincipalString, WildcardHost, READ, ALLOW) + private val allowWriteAcl = new AccessControlEntry(WildcardPrincipalString, WildcardHost, WRITE, ALLOW) + private val denyReadAcl = new AccessControlEntry(WildcardPrincipalString, WildcardHost, READ, DENY) + + private val wildCardResource = new ResourcePattern(TOPIC, WILDCARD_RESOURCE, LITERAL) + private val prefixedResource = new ResourcePattern(TOPIC, "foo", PREFIXED) + private val clusterResource = new ResourcePattern(CLUSTER, CLUSTER_NAME, LITERAL) + private val wildcardPrincipal = JSecurityUtils.parseKafkaPrincipal(WildcardPrincipalString) + + private val aclAuthorizer = new AclAuthorizer + private val aclAuthorizer2 = new AclAuthorizer + + class CustomPrincipal(principalType: String, name: String) extends KafkaPrincipal(principalType, name) { + override def equals(o: scala.Any): Boolean = false + } + + override def authorizer: Authorizer = aclAuthorizer + + @BeforeEach + override def setUp(testInfo: TestInfo): Unit = { + super.setUp(testInfo) + + // Increase maxUpdateRetries to avoid transient failures + aclAuthorizer.maxUpdateRetries = Int.MaxValue + aclAuthorizer2.maxUpdateRetries = Int.MaxValue + + val props = TestUtils.createBrokerConfig(0, zkConnect) + props.put(AclAuthorizer.SuperUsersProp, superUsers) + + config = KafkaConfig.fromProps(props) + aclAuthorizer.configure(config.originals) + aclAuthorizer2.configure(config.originals) + resource = new ResourcePattern(TOPIC, "foo-" + UUID.randomUUID(), LITERAL) + + zooKeeperClient = new ZooKeeperClient(zkConnect, zkSessionTimeout, zkConnectionTimeout, zkMaxInFlightRequests, + Time.SYSTEM, "kafka.test", "AclAuthorizerTest", new ZKClientConfig, 
"AclAuthorizerTest") + } + + @AfterEach + override def tearDown(): Unit = { + aclAuthorizer.close() + aclAuthorizer2.close() + zooKeeperClient.close() + super.tearDown() + } + + @Test + def testAuthorizeThrowsOnNonLiteralResource(): Unit = { + assertThrows(classOf[IllegalArgumentException], () => authorize(aclAuthorizer, requestContext, READ, + new ResourcePattern(TOPIC, "something", PREFIXED))) + } + + @Test + def testAuthorizeWithEmptyResourceName(): Unit = { + assertFalse(authorize(aclAuthorizer, requestContext, READ, new ResourcePattern(GROUP, "", LITERAL))) + addAcls(aclAuthorizer, Set(allowReadAcl), new ResourcePattern(GROUP, WILDCARD_RESOURCE, LITERAL)) + assertTrue(authorize(aclAuthorizer, requestContext, READ, new ResourcePattern(GROUP, "", LITERAL))) + } + + // Authorizing the empty resource is not supported because we create a znode with the resource name. + @Test + def testEmptyAclThrowsException(): Unit = { + val e = assertThrows(classOf[ApiException], + () => addAcls(aclAuthorizer, Set(allowReadAcl), new ResourcePattern(GROUP, "", LITERAL))) + assertTrue(e.getCause.isInstanceOf[IllegalArgumentException], s"Unexpected exception $e") + } + + @Test + def testTopicAcl(): Unit = { + val user1 = new KafkaPrincipal(KafkaPrincipal.USER_TYPE, username) + val user2 = new KafkaPrincipal(KafkaPrincipal.USER_TYPE, "rob") + val user3 = new KafkaPrincipal(KafkaPrincipal.USER_TYPE, "batman") + val host1 = InetAddress.getByName("192.168.1.1") + val host2 = InetAddress.getByName("192.168.1.2") + + //user1 has READ access from host1 and host2. + val acl1 = new AccessControlEntry(user1.toString, host1.getHostAddress, READ, ALLOW) + val acl2 = new AccessControlEntry(user1.toString, host2.getHostAddress, READ, ALLOW) + + //user1 does not have READ access from host1. + val acl3 = new AccessControlEntry(user1.toString, host1.getHostAddress, READ, DENY) + + //user1 has WRITE access from host1 only. + val acl4 = new AccessControlEntry(user1.toString, host1.getHostAddress, WRITE, ALLOW) + + //user1 has DESCRIBE access from all hosts. + val acl5 = new AccessControlEntry(user1.toString, WildcardHost, DESCRIBE, ALLOW) + + //user2 has READ access from all hosts. + val acl6 = new AccessControlEntry(user2.toString, WildcardHost, READ, ALLOW) + + //user3 has WRITE access from all hosts. 
+ val acl7 = new AccessControlEntry(user3.toString, WildcardHost, WRITE, ALLOW) + + val acls = Set(acl1, acl2, acl3, acl4, acl5, acl6, acl7) + + changeAclAndVerify(Set.empty, acls, Set.empty) + + val host1Context = newRequestContext(user1, host1) + val host2Context = newRequestContext(user1, host2) + + assertTrue(authorize(aclAuthorizer, host2Context, READ, resource), "User1 should have READ access from host2") + assertFalse(authorize(aclAuthorizer, host1Context, READ, resource), "User1 should not have READ access from host1 due to denyAcl") + assertTrue(authorize(aclAuthorizer, host1Context, WRITE, resource), "User1 should have WRITE access from host1") + assertFalse(authorize(aclAuthorizer, host2Context, WRITE, resource), "User1 should not have WRITE access from host2 as no allow acl is defined") + assertTrue(authorize(aclAuthorizer, host1Context, DESCRIBE, resource), "User1 should have DESCRIBE access from host1") + assertTrue(authorize(aclAuthorizer, host2Context, DESCRIBE, resource), "User1 should have DESCRIBE access from host2") + assertFalse(authorize(aclAuthorizer, host1Context, ALTER, resource), "User1 should not have ALTER access from host1") + assertFalse(authorize(aclAuthorizer, host2Context, ALTER, resource), "User1 should not have ALTER access from host2") + + //test that users with READ or WRITE access also get DESCRIBE access + val user2Context = newRequestContext(user2, host1) + val user3Context = newRequestContext(user3, host1) + assertTrue(authorize(aclAuthorizer, user2Context, DESCRIBE, resource), "User2 should have DESCRIBE access from host1") + assertTrue(authorize(aclAuthorizer, user3Context, DESCRIBE, resource), "User3 should have DESCRIBE access from host1") + assertTrue(authorize(aclAuthorizer, user2Context, READ, resource), "User2 should have READ access from host1") + assertTrue(authorize(aclAuthorizer, user3Context, WRITE, resource), "User3 should have WRITE access from host1") + } + + /** + CustomPrincipals should be compared with their principal type and name + */ + @Test + def testAllowAccessWithCustomPrincipal(): Unit = { + val user = new KafkaPrincipal(KafkaPrincipal.USER_TYPE, username) + val customUserPrincipal = new CustomPrincipal(KafkaPrincipal.USER_TYPE, username) + val host1 = InetAddress.getByName("192.168.1.1") + val host2 = InetAddress.getByName("192.168.1.2") + + // user has READ access from host2 but not from host1 + val acl1 = new AccessControlEntry(user.toString, host1.getHostAddress, READ, DENY) + val acl2 = new AccessControlEntry(user.toString, host2.getHostAddress, READ, ALLOW) + val acls = Set(acl1, acl2) + changeAclAndVerify(Set.empty, acls, Set.empty) + + val host1Context = newRequestContext(customUserPrincipal, host1) + val host2Context = newRequestContext(customUserPrincipal, host2) + + assertTrue(authorize(aclAuthorizer, host2Context, READ, resource), "User1 should have READ access from host2") + assertFalse(authorize(aclAuthorizer, host1Context, READ, resource), "User1 should not have READ access from host1 due to denyAcl") + } + + @Test + def testDenyTakesPrecedence(): Unit = { + val user = new KafkaPrincipal(KafkaPrincipal.USER_TYPE, username) + val host = InetAddress.getByName("192.168.2.1") + val session = newRequestContext(user, host) + + val allowAll = new AccessControlEntry(WildcardPrincipalString, WildcardHost, AclOperation.ALL, ALLOW) + val denyAcl = new AccessControlEntry(user.toString, host.getHostAddress, AclOperation.ALL, DENY) + val acls = Set(allowAll, denyAcl) + + changeAclAndVerify(Set.empty, acls, Set.empty) + 
+ assertFalse(authorize(aclAuthorizer, session, READ, resource), "deny should take precedence over allow.") + } + + @Test + def testAllowAllAccess(): Unit = { + val allowAllAcl = new AccessControlEntry(WildcardPrincipalString, WildcardHost, AclOperation.ALL, ALLOW) + + changeAclAndVerify(Set.empty, Set(allowAllAcl), Set.empty) + + val context = newRequestContext(new KafkaPrincipal(KafkaPrincipal.USER_TYPE, "random"), InetAddress.getByName("192.0.4.4")) + assertTrue(authorize(aclAuthorizer, context, READ, resource), "allow all acl should allow access to all.") + } + + @Test + def testSuperUserHasAccess(): Unit = { + val denyAllAcl = new AccessControlEntry(WildcardPrincipalString, WildcardHost, AclOperation.ALL, DENY) + + changeAclAndVerify(Set.empty, Set(denyAllAcl), Set.empty) + + val session1 = newRequestContext(new KafkaPrincipal(KafkaPrincipal.USER_TYPE, "superuser1"), InetAddress.getByName("192.0.4.4")) + val session2 = newRequestContext(new KafkaPrincipal(KafkaPrincipal.USER_TYPE, "superuser2"), InetAddress.getByName("192.0.4.4")) + + assertTrue(authorize(aclAuthorizer, session1, READ, resource), "superuser always has access, no matter what acls.") + assertTrue(authorize(aclAuthorizer, session2, READ, resource), "superuser always has access, no matter what acls.") + } + + /** + CustomPrincipals should be compared with their principal type and name + */ + @Test + def testSuperUserWithCustomPrincipalHasAccess(): Unit = { + val denyAllAcl = new AccessControlEntry(WildcardPrincipalString, WildcardHost, AclOperation.ALL, DENY) + changeAclAndVerify(Set.empty, Set(denyAllAcl), Set.empty) + + val session = newRequestContext(new CustomPrincipal(KafkaPrincipal.USER_TYPE, "superuser1"), InetAddress.getByName("192.0.4.4")) + + assertTrue(authorize(aclAuthorizer, session, READ, resource), "superuser with custom principal always has access, no matter what acls.") + } + + @Test + def testWildCardAcls(): Unit = { + assertFalse(authorize(aclAuthorizer, requestContext, READ, resource), "when acls = [], authorizer should fail close.") + + val user1 = new KafkaPrincipal(KafkaPrincipal.USER_TYPE, username) + val host1 = InetAddress.getByName("192.168.3.1") + val readAcl = new AccessControlEntry(user1.toString, host1.getHostAddress, READ, ALLOW) + + val acls = changeAclAndVerify(Set.empty, Set(readAcl), Set.empty, wildCardResource) + + val host1Context = newRequestContext(user1, host1) + assertTrue(authorize(aclAuthorizer, host1Context, READ, resource), "User1 should have READ access from host1") + + //allow WRITE to specific topic. + val writeAcl = new AccessControlEntry(user1.toString, host1.getHostAddress, WRITE, ALLOW) + changeAclAndVerify(Set.empty, Set(writeAcl), Set.empty) + + //deny WRITE to wild card topic. 
+ val denyWriteOnWildCardResourceAcl = new AccessControlEntry(user1.toString, host1.getHostAddress, WRITE, DENY) + changeAclAndVerify(acls, Set(denyWriteOnWildCardResourceAcl), Set.empty, wildCardResource) + + assertFalse(authorize(aclAuthorizer, host1Context, WRITE, resource), "User1 should not have WRITE access from host1") + } + + @Test + def testNoAclFound(): Unit = { + assertFalse(authorize(aclAuthorizer, requestContext, READ, resource), "when acls = [], authorizer should deny op.") + } + + @Test + def testNoAclFoundOverride(): Unit = { + val props = TestUtils.createBrokerConfig(1, zkConnect) + props.put(AclAuthorizer.AllowEveryoneIfNoAclIsFoundProp, "true") + + val cfg = KafkaConfig.fromProps(props) + val testAuthorizer = new AclAuthorizer + try { + testAuthorizer.configure(cfg.originals) + assertTrue(authorize(testAuthorizer, requestContext, READ, resource), + "when acls = null or [], authorizer should allow op with allow.everyone = true.") + } finally { + testAuthorizer.close() + } + } + + @Test + def testAclManagementAPIs(): Unit = { + val user1 = new KafkaPrincipal(KafkaPrincipal.USER_TYPE, username) + val user2 = new KafkaPrincipal(KafkaPrincipal.USER_TYPE, "bob") + val host1 = "host1" + val host2 = "host2" + + val acl1 = new AccessControlEntry(user1.toString, host1, READ, ALLOW) + val acl2 = new AccessControlEntry(user1.toString, host1, WRITE, ALLOW) + val acl3 = new AccessControlEntry(user2.toString, host2, READ, ALLOW) + val acl4 = new AccessControlEntry(user2.toString, host2, WRITE, ALLOW) + + var acls = changeAclAndVerify(Set.empty, Set(acl1, acl2, acl3, acl4), Set.empty) + + //test addAcl is additive + val acl5 = new AccessControlEntry(user2.toString, WildcardHost, READ, ALLOW) + acls = changeAclAndVerify(acls, Set(acl5), Set.empty) + + //test get by principal name. + TestUtils.waitUntilTrue(() => Set(acl1, acl2).map(acl => new AclBinding(resource, acl)) == getAcls(aclAuthorizer, user1), + "changes not propagated in timeout period") + TestUtils.waitUntilTrue(() => Set(acl3, acl4, acl5).map(acl => new AclBinding(resource, acl)) == getAcls(aclAuthorizer, user2), + "changes not propagated in timeout period") + + val resourceToAcls = Map[ResourcePattern, Set[AccessControlEntry]]( + new ResourcePattern(TOPIC, WILDCARD_RESOURCE, LITERAL) -> Set(new AccessControlEntry(user2.toString, WildcardHost, READ, ALLOW)), + new ResourcePattern(CLUSTER , WILDCARD_RESOURCE, LITERAL) -> Set(new AccessControlEntry(user2.toString, host1, READ, ALLOW)), + new ResourcePattern(GROUP, WILDCARD_RESOURCE, LITERAL) -> acls, + new ResourcePattern(GROUP, "test-ConsumerGroup", LITERAL) -> acls + ) + + resourceToAcls foreach { case (key, value) => changeAclAndVerify(Set.empty, value, Set.empty, key) } + val expectedAcls = (resourceToAcls + (resource -> acls)).flatMap { + case (res, resAcls) => resAcls.map { acl => new AclBinding(res, acl) } + }.toSet + TestUtils.waitUntilTrue(() => expectedAcls == getAcls(aclAuthorizer), "changes not propagated in timeout period.") + + //test remove acl from existing acls. 
+ acls = changeAclAndVerify(acls, Set.empty, Set(acl1, acl5)) + + //test remove all acls for resource + removeAcls(aclAuthorizer, Set.empty, resource) + TestUtils.waitAndVerifyAcls(Set.empty[AccessControlEntry], aclAuthorizer, resource) + assertFalse(zkClient.resourceExists(resource)) + + //test removing last acl also deletes ZooKeeper path + acls = changeAclAndVerify(Set.empty, Set(acl1), Set.empty) + changeAclAndVerify(acls, Set.empty, acls) + assertFalse(zkClient.resourceExists(resource)) + } + + @Test + def testLoadCache(): Unit = { + val user1 = new KafkaPrincipal(KafkaPrincipal.USER_TYPE, username) + val acl1 = new AccessControlEntry(user1.toString, "host-1", READ, ALLOW) + val acls = Set(acl1) + addAcls(aclAuthorizer, acls, resource) + + val user2 = new KafkaPrincipal(KafkaPrincipal.USER_TYPE, "bob") + val resource1 = new ResourcePattern(TOPIC, "test-2", LITERAL) + val acl2 = new AccessControlEntry(user2.toString, "host3", READ, DENY) + val acls1 = Set(acl2) + addAcls(aclAuthorizer, acls1, resource1) + + zkClient.deleteAclChangeNotifications() + val authorizer = new AclAuthorizer + try { + authorizer.configure(config.originals) + + assertEquals(acls, getAcls(authorizer, resource)) + assertEquals(acls1, getAcls(authorizer, resource1)) + } finally { + authorizer.close() + } + } + + /** + * Verify that there is no timing window between loading ACL cache and setting + * up ZK change listener. Cache must be loaded before creating change listener + * in the authorizer to avoid the timing window. + */ + @Test + def testChangeListenerTiming(): Unit = { + val configureSemaphore = new Semaphore(0) + val listenerSemaphore = new Semaphore(0) + val executor = Executors.newSingleThreadExecutor + val aclAuthorizer3 = new AclAuthorizer { + override private[authorizer] def startZkChangeListeners(): Unit = { + configureSemaphore.release() + listenerSemaphore.acquireUninterruptibly() + super.startZkChangeListeners() + } + } + try { + val future = executor.submit((() => aclAuthorizer3.configure(config.originals)): Runnable) + configureSemaphore.acquire() + val user1 = new KafkaPrincipal(KafkaPrincipal.USER_TYPE, username) + val acls = Set(new AccessControlEntry(user1.toString, "host-1", READ, DENY)) + addAcls(aclAuthorizer, acls, resource) + + listenerSemaphore.release() + future.get(10, TimeUnit.SECONDS) + + assertEquals(acls, getAcls(aclAuthorizer3, resource)) + } finally { + aclAuthorizer3.close() + executor.shutdownNow() + } + } + + @Test + def testLocalConcurrentModificationOfResourceAcls(): Unit = { + val commonResource = new ResourcePattern(TOPIC, "test", LITERAL) + + val user1 = new KafkaPrincipal(KafkaPrincipal.USER_TYPE, username) + val acl1 = new AccessControlEntry(user1.toString, WildcardHost, READ, ALLOW) + + val user2 = new KafkaPrincipal(KafkaPrincipal.USER_TYPE, "bob") + val acl2 = new AccessControlEntry(user2.toString, WildcardHost, READ, DENY) + + addAcls(aclAuthorizer, Set(acl1), commonResource) + addAcls(aclAuthorizer, Set(acl2), commonResource) + + TestUtils.waitAndVerifyAcls(Set(acl1, acl2), aclAuthorizer, commonResource) + } + + @Test + def testDistributedConcurrentModificationOfResourceAcls(): Unit = { + val commonResource = new ResourcePattern(TOPIC, "test", LITERAL) + + val user1 = new KafkaPrincipal(KafkaPrincipal.USER_TYPE, username) + val acl1 = new AccessControlEntry(user1.toString, WildcardHost, READ, ALLOW) + + val user2 = new KafkaPrincipal(KafkaPrincipal.USER_TYPE, "bob") + val acl2 = new AccessControlEntry(user2.toString, WildcardHost, READ, DENY) + + // Add on each 
instance + addAcls(aclAuthorizer, Set(acl1), commonResource) + addAcls(aclAuthorizer2, Set(acl2), commonResource) + + TestUtils.waitAndVerifyAcls(Set(acl1, acl2), aclAuthorizer, commonResource) + TestUtils.waitAndVerifyAcls(Set(acl1, acl2), aclAuthorizer2, commonResource) + + val user3 = new KafkaPrincipal(KafkaPrincipal.USER_TYPE, "joe") + val acl3 = new AccessControlEntry(user3.toString, WildcardHost, READ, DENY) + + // Add on one instance and delete on another + addAcls(aclAuthorizer, Set(acl3), commonResource) + val deleted = removeAcls(aclAuthorizer2, Set(acl3), commonResource) + + assertTrue(deleted, "The authorizer should see a value that needs to be deleted") + + TestUtils.waitAndVerifyAcls(Set(acl1, acl2), aclAuthorizer, commonResource) + TestUtils.waitAndVerifyAcls(Set(acl1, acl2), aclAuthorizer2, commonResource) + } + + @Test + def testHighConcurrencyModificationOfResourceAcls(): Unit = { + val commonResource = new ResourcePattern(TOPIC, "test", LITERAL) + + val acls= (0 to 50).map { i => + val useri = new KafkaPrincipal(KafkaPrincipal.USER_TYPE, i.toString) + (new AccessControlEntry(useri.toString, WildcardHost, READ, ALLOW), i) + } + + // Alternate authorizer, Remove all acls that end in 0 + val concurrentFuctions = acls.map { case (acl, aclId) => + () => { + if (aclId % 2 == 0) { + addAcls(aclAuthorizer, Set(acl), commonResource) + } else { + addAcls(aclAuthorizer2, Set(acl), commonResource) + } + if (aclId % 10 == 0) { + removeAcls(aclAuthorizer2, Set(acl), commonResource) + } + } + } + + val expectedAcls = acls.filter { case (acl, aclId) => + aclId % 10 != 0 + }.map(_._1).toSet + + TestUtils.assertConcurrent("Should support many concurrent calls", concurrentFuctions, 30 * 1000) + + TestUtils.waitAndVerifyAcls(expectedAcls, aclAuthorizer, commonResource) + TestUtils.waitAndVerifyAcls(expectedAcls, aclAuthorizer2, commonResource) + } + + /** + * Test ACL inheritance, as described in #{org.apache.kafka.common.acl.AclOperation} + */ + @Test + def testAclInheritance(): Unit = { + testImplicationsOfAllow(AclOperation.ALL, Set(READ, WRITE, CREATE, DELETE, ALTER, DESCRIBE, + CLUSTER_ACTION, DESCRIBE_CONFIGS, ALTER_CONFIGS, IDEMPOTENT_WRITE)) + testImplicationsOfDeny(AclOperation.ALL, Set(READ, WRITE, CREATE, DELETE, ALTER, DESCRIBE, + CLUSTER_ACTION, DESCRIBE_CONFIGS, ALTER_CONFIGS, IDEMPOTENT_WRITE)) + testImplicationsOfAllow(READ, Set(DESCRIBE)) + testImplicationsOfAllow(WRITE, Set(DESCRIBE)) + testImplicationsOfAllow(DELETE, Set(DESCRIBE)) + testImplicationsOfAllow(ALTER, Set(DESCRIBE)) + testImplicationsOfDeny(DESCRIBE, Set()) + testImplicationsOfAllow(ALTER_CONFIGS, Set(DESCRIBE_CONFIGS)) + testImplicationsOfDeny(DESCRIBE_CONFIGS, Set()) + } + + private def testImplicationsOfAllow(parentOp: AclOperation, allowedOps: Set[AclOperation]): Unit = { + val user = new KafkaPrincipal(KafkaPrincipal.USER_TYPE, username) + val host = InetAddress.getByName("192.168.3.1") + val hostContext = newRequestContext(user, host) + val acl = new AccessControlEntry(user.toString, WildcardHost, parentOp, ALLOW) + addAcls(aclAuthorizer, Set(acl), clusterResource) + AclOperation.values.filter(validOp).foreach { op => + val authorized = authorize(aclAuthorizer, hostContext, op, clusterResource) + if (allowedOps.contains(op) || op == parentOp) + assertTrue(authorized, s"ALLOW $parentOp should imply ALLOW $op") + else + assertFalse(authorized, s"ALLOW $parentOp should not imply ALLOW $op") + } + removeAcls(aclAuthorizer, Set(acl), clusterResource) + } + + private def testImplicationsOfDeny(parentOp: 
AclOperation, deniedOps: Set[AclOperation]): Unit = { + val user1 = new KafkaPrincipal(KafkaPrincipal.USER_TYPE, username) + val host1 = InetAddress.getByName("192.168.3.1") + val host1Context = newRequestContext(user1, host1) + val acls = Set(new AccessControlEntry(user1.toString, WildcardHost, parentOp, DENY), + new AccessControlEntry(user1.toString, WildcardHost, AclOperation.ALL, ALLOW)) + addAcls(aclAuthorizer, acls, clusterResource) + AclOperation.values.filter(validOp).foreach { op => + val authorized = authorize(aclAuthorizer, host1Context, op, clusterResource) + if (deniedOps.contains(op) || op == parentOp) + assertFalse(authorized, s"DENY $parentOp should imply DENY $op") + else + assertTrue(authorized, s"DENY $parentOp should not imply DENY $op") + } + removeAcls(aclAuthorizer, acls, clusterResource) + } + + @Test + def testHighConcurrencyDeletionOfResourceAcls(): Unit = { + val acl = new AccessControlEntry(new KafkaPrincipal(KafkaPrincipal.USER_TYPE, username).toString, WildcardHost, AclOperation.ALL, ALLOW) + + // Alternate authorizer to keep adding and removing ZooKeeper path + val concurrentFuctions = (0 to 50).map { _ => + () => { + addAcls(aclAuthorizer, Set(acl), resource) + removeAcls(aclAuthorizer2, Set(acl), resource) + } + } + + TestUtils.assertConcurrent("Should support many concurrent calls", concurrentFuctions, 30 * 1000) + + TestUtils.waitAndVerifyAcls(Set.empty[AccessControlEntry], aclAuthorizer, resource) + TestUtils.waitAndVerifyAcls(Set.empty[AccessControlEntry], aclAuthorizer2, resource) + } + + @Test + def testAccessAllowedIfAllowAclExistsOnWildcardResource(): Unit = { + addAcls(aclAuthorizer, Set(allowReadAcl), wildCardResource) + + assertTrue(authorize(aclAuthorizer, requestContext, READ, resource)) + } + + @Test + def testDeleteAclOnWildcardResource(): Unit = { + addAcls(aclAuthorizer, Set(allowReadAcl, allowWriteAcl), wildCardResource) + + removeAcls(aclAuthorizer, Set(allowReadAcl), wildCardResource) + + assertEquals(Set(allowWriteAcl), getAcls(aclAuthorizer, wildCardResource)) + } + + @Test + def testDeleteAllAclOnWildcardResource(): Unit = { + addAcls(aclAuthorizer, Set(allowReadAcl), wildCardResource) + + removeAcls(aclAuthorizer, Set.empty, wildCardResource) + + assertEquals(Set.empty, getAcls(aclAuthorizer)) + } + + @Test + def testAccessAllowedIfAllowAclExistsOnPrefixedResource(): Unit = { + addAcls(aclAuthorizer, Set(allowReadAcl), prefixedResource) + + assertTrue(authorize(aclAuthorizer, requestContext, READ, resource)) + } + + @Test + def testDeleteAclOnPrefixedResource(): Unit = { + addAcls(aclAuthorizer, Set(allowReadAcl, allowWriteAcl), prefixedResource) + + removeAcls(aclAuthorizer, Set(allowReadAcl), prefixedResource) + + assertEquals(Set(allowWriteAcl), getAcls(aclAuthorizer, prefixedResource)) + } + + @Test + def testDeleteAllAclOnPrefixedResource(): Unit = { + addAcls(aclAuthorizer, Set(allowReadAcl, allowWriteAcl), prefixedResource) + + removeAcls(aclAuthorizer, Set.empty, prefixedResource) + + assertEquals(Set.empty, getAcls(aclAuthorizer)) + } + + @Test + def testAddAclsOnLiteralResource(): Unit = { + addAcls(aclAuthorizer, Set(allowReadAcl, allowWriteAcl), resource) + addAcls(aclAuthorizer, Set(allowWriteAcl, denyReadAcl), resource) + + assertEquals(Set(allowReadAcl, allowWriteAcl, denyReadAcl), getAcls(aclAuthorizer, resource)) + assertEquals(Set.empty, getAcls(aclAuthorizer, wildCardResource)) + assertEquals(Set.empty, getAcls(aclAuthorizer, prefixedResource)) + } + + @Test + def testAddAclsOnWildcardResource(): Unit = { + 
addAcls(aclAuthorizer, Set(allowReadAcl, allowWriteAcl), wildCardResource) + addAcls(aclAuthorizer, Set(allowWriteAcl, denyReadAcl), wildCardResource) + + assertEquals(Set(allowReadAcl, allowWriteAcl, denyReadAcl), getAcls(aclAuthorizer, wildCardResource)) + assertEquals(Set.empty, getAcls(aclAuthorizer, resource)) + assertEquals(Set.empty, getAcls(aclAuthorizer, prefixedResource)) + } + + @Test + def testAddAclsOnPrefixedResource(): Unit = { + addAcls(aclAuthorizer, Set(allowReadAcl, allowWriteAcl), prefixedResource) + addAcls(aclAuthorizer, Set(allowWriteAcl, denyReadAcl), prefixedResource) + + assertEquals(Set(allowReadAcl, allowWriteAcl, denyReadAcl), getAcls(aclAuthorizer, prefixedResource)) + assertEquals(Set.empty, getAcls(aclAuthorizer, wildCardResource)) + assertEquals(Set.empty, getAcls(aclAuthorizer, resource)) + } + + @Test + def testAuthorizeWithPrefixedResource(): Unit = { + addAcls(aclAuthorizer, Set(denyReadAcl), new ResourcePattern(TOPIC, "a_other", LITERAL)) + addAcls(aclAuthorizer, Set(denyReadAcl), new ResourcePattern(TOPIC, "a_other", PREFIXED)) + addAcls(aclAuthorizer, Set(denyReadAcl), new ResourcePattern(TOPIC, "foo-" + UUID.randomUUID(), PREFIXED)) + addAcls(aclAuthorizer, Set(denyReadAcl), new ResourcePattern(TOPIC, "foo-" + UUID.randomUUID(), PREFIXED)) + addAcls(aclAuthorizer, Set(denyReadAcl), new ResourcePattern(TOPIC, "foo-" + UUID.randomUUID() + "-zzz", PREFIXED)) + addAcls(aclAuthorizer, Set(denyReadAcl), new ResourcePattern(TOPIC, "fooo-" + UUID.randomUUID(), PREFIXED)) + addAcls(aclAuthorizer, Set(denyReadAcl), new ResourcePattern(TOPIC, "fo-" + UUID.randomUUID(), PREFIXED)) + addAcls(aclAuthorizer, Set(denyReadAcl), new ResourcePattern(TOPIC, "fop-" + UUID.randomUUID(), PREFIXED)) + addAcls(aclAuthorizer, Set(denyReadAcl), new ResourcePattern(TOPIC, "fon-" + UUID.randomUUID(), PREFIXED)) + addAcls(aclAuthorizer, Set(denyReadAcl), new ResourcePattern(TOPIC, "fon-", PREFIXED)) + addAcls(aclAuthorizer, Set(denyReadAcl), new ResourcePattern(TOPIC, "z_other", PREFIXED)) + addAcls(aclAuthorizer, Set(denyReadAcl), new ResourcePattern(TOPIC, "z_other", LITERAL)) + + addAcls(aclAuthorizer, Set(allowReadAcl), prefixedResource) + + assertTrue(authorize(aclAuthorizer, requestContext, READ, resource)) + } + + @Test + def testSingleCharacterResourceAcls(): Unit = { + addAcls(aclAuthorizer, Set(allowReadAcl), new ResourcePattern(TOPIC, "f", LITERAL)) + assertTrue(authorize(aclAuthorizer, requestContext, READ, new ResourcePattern(TOPIC, "f", LITERAL))) + assertFalse(authorize(aclAuthorizer, requestContext, READ, new ResourcePattern(TOPIC, "foo", LITERAL))) + + addAcls(aclAuthorizer, Set(allowReadAcl), new ResourcePattern(TOPIC, "_", PREFIXED)) + assertTrue(authorize(aclAuthorizer, requestContext, READ, new ResourcePattern(TOPIC, "_foo", LITERAL))) + assertTrue(authorize(aclAuthorizer, requestContext, READ, new ResourcePattern(TOPIC, "_", LITERAL))) + assertFalse(authorize(aclAuthorizer, requestContext, READ, new ResourcePattern(TOPIC, "foo_", LITERAL))) + } + + @Test + def testGetAclsPrincipal(): Unit = { + val aclOnSpecificPrincipal = new AccessControlEntry(principal.toString, WildcardHost, WRITE, ALLOW) + addAcls(aclAuthorizer, Set(aclOnSpecificPrincipal), resource) + + assertEquals(0, + getAcls(aclAuthorizer, wildcardPrincipal).size, "acl on specific should not be returned for wildcard request") + assertEquals(1, + getAcls(aclAuthorizer, principal).size, "acl on specific should be returned for specific request") + assertEquals(1, + getAcls(aclAuthorizer, new 
KafkaPrincipal(principal.getPrincipalType, principal.getName)).size, "acl on specific should be returned for different principal instance") + + removeAcls(aclAuthorizer, Set.empty, resource) + val aclOnWildcardPrincipal = new AccessControlEntry(WildcardPrincipalString, WildcardHost, WRITE, ALLOW) + addAcls(aclAuthorizer, Set(aclOnWildcardPrincipal), resource) + + assertEquals(1, getAcls(aclAuthorizer, wildcardPrincipal).size, "acl on wildcard should be returned for wildcard request") + assertEquals(0, getAcls(aclAuthorizer, principal).size, "acl on wildcard should not be returned for specific request") + } + + @Test + def testAclsFilter(): Unit = { + val resource1 = new ResourcePattern(TOPIC, "foo-" + UUID.randomUUID(), LITERAL) + val resource2 = new ResourcePattern(TOPIC, "bar-" + UUID.randomUUID(), LITERAL) + val prefixedResource = new ResourcePattern(TOPIC, "bar-", PREFIXED) + + val acl1 = new AclBinding(resource1, new AccessControlEntry(principal.toString, WildcardHost, READ, ALLOW)) + val acl2 = new AclBinding(resource1, new AccessControlEntry(principal.toString, "192.168.0.1", WRITE, ALLOW)) + val acl3 = new AclBinding(resource2, new AccessControlEntry(principal.toString, WildcardHost, DESCRIBE, ALLOW)) + val acl4 = new AclBinding(prefixedResource, new AccessControlEntry(wildcardPrincipal.toString, WildcardHost, READ, ALLOW)) + + aclAuthorizer.createAcls(requestContext, List(acl1, acl2, acl3, acl4).asJava) + assertEquals(Set(acl1, acl2, acl3, acl4), aclAuthorizer.acls(AclBindingFilter.ANY).asScala.toSet) + assertEquals(Set(acl1, acl2), aclAuthorizer.acls(new AclBindingFilter(resource1.toFilter, AccessControlEntryFilter.ANY)).asScala.toSet) + assertEquals(Set(acl4), aclAuthorizer.acls(new AclBindingFilter(prefixedResource.toFilter, AccessControlEntryFilter.ANY)).asScala.toSet) + val matchingFilter = new AclBindingFilter(new ResourcePatternFilter(ResourceType.ANY, resource2.name, MATCH), AccessControlEntryFilter.ANY) + assertEquals(Set(acl3, acl4), aclAuthorizer.acls(matchingFilter).asScala.toSet) + + val filters = List(matchingFilter, + acl1.toFilter, + new AclBindingFilter(resource2.toFilter, AccessControlEntryFilter.ANY), + new AclBindingFilter(new ResourcePatternFilter(TOPIC, "baz", PatternType.ANY), AccessControlEntryFilter.ANY)) + val deleteResults = aclAuthorizer.deleteAcls(requestContext, filters.asJava).asScala.map(_.toCompletableFuture.get) + assertEquals(List.empty, deleteResults.filter(_.exception.isPresent)) + filters.indices.foreach { i => + assertEquals(Set.empty, deleteResults(i).aclBindingDeleteResults.asScala.toSet.filter(_.exception.isPresent)) + } + assertEquals(Set(acl3, acl4), deleteResults(0).aclBindingDeleteResults.asScala.map(_.aclBinding).toSet) + assertEquals(Set(acl1), deleteResults(1).aclBindingDeleteResults.asScala.map(_.aclBinding).toSet) + assertEquals(Set.empty, deleteResults(2).aclBindingDeleteResults.asScala.map(_.aclBinding).toSet) + assertEquals(Set.empty, deleteResults(3).aclBindingDeleteResults.asScala.map(_.aclBinding).toSet) + } + + @Test + def testThrowsOnAddPrefixedAclIfInterBrokerProtocolVersionTooLow(): Unit = { + givenAuthorizerWithProtocolVersion(Option(KAFKA_2_0_IV0)) + val e = assertThrows(classOf[ApiException], + () => addAcls(aclAuthorizer, Set(denyReadAcl), new ResourcePattern(TOPIC, "z_other", PREFIXED))) + assertTrue(e.getCause.isInstanceOf[UnsupportedVersionException], s"Unexpected exception $e") + } + + @Test + def testWritesExtendedAclChangeEventIfInterBrokerProtocolNotSet(): Unit = { + 
givenAuthorizerWithProtocolVersion(Option.empty) + val resource = new ResourcePattern(TOPIC, "z_other", PREFIXED) + val expected = new String(ZkAclStore(PREFIXED).changeStore + .createChangeNode(resource).bytes, UTF_8) + + addAcls(aclAuthorizer, Set(denyReadAcl), resource) + + val actual = getAclChangeEventAsString(PREFIXED) + + assertEquals(expected, actual) + } + + @Test + def testWritesExtendedAclChangeEventWhenInterBrokerProtocolAtLeastKafkaV2(): Unit = { + givenAuthorizerWithProtocolVersion(Option(KAFKA_2_0_IV1)) + val resource = new ResourcePattern(TOPIC, "z_other", PREFIXED) + val expected = new String(ZkAclStore(PREFIXED).changeStore + .createChangeNode(resource).bytes, UTF_8) + + addAcls(aclAuthorizer, Set(denyReadAcl), resource) + + val actual = getAclChangeEventAsString(PREFIXED) + + assertEquals(expected, actual) + } + + @Test + def testWritesLiteralAclChangeEventWhenInterBrokerProtocolLessThanKafkaV2(): Unit = { + givenAuthorizerWithProtocolVersion(Option(KAFKA_2_0_IV0)) + val resource = new ResourcePattern(TOPIC, "z_other", LITERAL) + val expected = new String(ZkAclStore(LITERAL).changeStore + .createChangeNode(resource).bytes, UTF_8) + + addAcls(aclAuthorizer, Set(denyReadAcl), resource) + + val actual = getAclChangeEventAsString(LITERAL) + + assertEquals(expected, actual) + } + + @Test + def testWritesLiteralAclChangeEventWhenInterBrokerProtocolIsKafkaV2(): Unit = { + givenAuthorizerWithProtocolVersion(Option(KAFKA_2_0_IV1)) + val resource = new ResourcePattern(TOPIC, "z_other", LITERAL) + val expected = new String(ZkAclStore(LITERAL).changeStore + .createChangeNode(resource).bytes, UTF_8) + + addAcls(aclAuthorizer, Set(denyReadAcl), resource) + + val actual = getAclChangeEventAsString(LITERAL) + + assertEquals(expected, actual) + } + + @Test + def testAuthorizerNoZkConfig(): Unit = { + val noTlsProps = Kafka.getPropsFromArgs(Array(prepareDefaultConfig)) + val zkClientConfig = AclAuthorizer.zkClientConfigFromKafkaConfigAndMap( + KafkaConfig.fromProps(noTlsProps), + noTlsProps.asInstanceOf[java.util.Map[String, Any]].asScala) + KafkaConfig.ZkSslConfigToSystemPropertyMap.keys.foreach { propName => + assertNull(zkClientConfig.getProperty(propName)) + } + } + + @Test + def testAuthorizerZkConfigFromKafkaConfigWithDefaults(): Unit = { + val props = new java.util.Properties() + val kafkaValue = "kafkaValue" + val configs = Map("zookeeper.connect" -> "somewhere", // required, otherwise we would omit it + KafkaConfig.ZkSslClientEnableProp -> "true", + KafkaConfig.ZkClientCnxnSocketProp -> kafkaValue, + KafkaConfig.ZkSslKeyStoreLocationProp -> kafkaValue, + KafkaConfig.ZkSslKeyStorePasswordProp -> kafkaValue, + KafkaConfig.ZkSslKeyStoreTypeProp -> kafkaValue, + KafkaConfig.ZkSslTrustStoreLocationProp -> kafkaValue, + KafkaConfig.ZkSslTrustStorePasswordProp -> kafkaValue, + KafkaConfig.ZkSslTrustStoreTypeProp -> kafkaValue, + KafkaConfig.ZkSslEnabledProtocolsProp -> kafkaValue, + KafkaConfig.ZkSslCipherSuitesProp -> kafkaValue) + configs.foreach { case (key, value) => props.put(key, value) } + + val zkClientConfig = AclAuthorizer.zkClientConfigFromKafkaConfigAndMap( + KafkaConfig.fromProps(props), mutable.Map(configs.toSeq: _*)) + // confirm we get all the values we expect + KafkaConfig.ZkSslConfigToSystemPropertyMap.keys.foreach(prop => prop match { + case KafkaConfig.ZkSslClientEnableProp | KafkaConfig.ZkSslEndpointIdentificationAlgorithmProp => + assertEquals("true", KafkaConfig.zooKeeperClientProperty(zkClientConfig, 
prop).getOrElse("")) + case KafkaConfig.ZkSslCrlEnableProp | KafkaConfig.ZkSslOcspEnableProp => + assertEquals("false", KafkaConfig.zooKeeperClientProperty(zkClientConfig, prop).getOrElse("")) + case KafkaConfig.ZkSslProtocolProp => + assertEquals("TLSv1.2", KafkaConfig.zooKeeperClientProperty(zkClientConfig, prop).getOrElse("")) + case _ => assertEquals(kafkaValue, KafkaConfig.zooKeeperClientProperty(zkClientConfig, prop).getOrElse("")) + }) + } + + @Test + def testAuthorizerZkConfigFromKafkaConfig(): Unit = { + val props = new java.util.Properties() + val kafkaValue = "kafkaValue" + val configs = Map("zookeeper.connect" -> "somewhere", // required, otherwise we would omit it + KafkaConfig.ZkSslClientEnableProp -> "true", + KafkaConfig.ZkClientCnxnSocketProp -> kafkaValue, + KafkaConfig.ZkSslKeyStoreLocationProp -> kafkaValue, + KafkaConfig.ZkSslKeyStorePasswordProp -> kafkaValue, + KafkaConfig.ZkSslKeyStoreTypeProp -> kafkaValue, + KafkaConfig.ZkSslTrustStoreLocationProp -> kafkaValue, + KafkaConfig.ZkSslTrustStorePasswordProp -> kafkaValue, + KafkaConfig.ZkSslTrustStoreTypeProp -> kafkaValue, + KafkaConfig.ZkSslProtocolProp -> kafkaValue, + KafkaConfig.ZkSslEnabledProtocolsProp -> kafkaValue, + KafkaConfig.ZkSslCipherSuitesProp -> kafkaValue, + KafkaConfig.ZkSslEndpointIdentificationAlgorithmProp -> "HTTPS", + KafkaConfig.ZkSslCrlEnableProp -> "false", + KafkaConfig.ZkSslOcspEnableProp -> "false") + configs.foreach{case (key, value) => props.put(key, value.toString) } + + val zkClientConfig = AclAuthorizer.zkClientConfigFromKafkaConfigAndMap( + KafkaConfig.fromProps(props), mutable.Map(configs.toSeq: _*)) + // confirm we get all the values we expect + KafkaConfig.ZkSslConfigToSystemPropertyMap.keys.foreach(prop => prop match { + case KafkaConfig.ZkSslClientEnableProp | KafkaConfig.ZkSslEndpointIdentificationAlgorithmProp => + assertEquals("true", KafkaConfig.zooKeeperClientProperty(zkClientConfig, prop).getOrElse("")) + case KafkaConfig.ZkSslCrlEnableProp | KafkaConfig.ZkSslOcspEnableProp => + assertEquals("false", KafkaConfig.zooKeeperClientProperty(zkClientConfig, prop).getOrElse("")) + case _ => assertEquals(kafkaValue, KafkaConfig.zooKeeperClientProperty(zkClientConfig, prop).getOrElse("")) + }) + } + + @Test + def testAuthorizerZkConfigFromPrefixOverrides(): Unit = { + val props = new java.util.Properties() + val kafkaValue = "kafkaValue" + val prefixedValue = "prefixedValue" + val prefix = "authorizer." 
+ val configs = Map("zookeeper.connect" -> "somewhere", // required, otherwise we would omit it + KafkaConfig.ZkSslClientEnableProp -> "false", + KafkaConfig.ZkClientCnxnSocketProp -> kafkaValue, + KafkaConfig.ZkSslKeyStoreLocationProp -> kafkaValue, + KafkaConfig.ZkSslKeyStorePasswordProp -> kafkaValue, + KafkaConfig.ZkSslKeyStoreTypeProp -> kafkaValue, + KafkaConfig.ZkSslTrustStoreLocationProp -> kafkaValue, + KafkaConfig.ZkSslTrustStorePasswordProp -> kafkaValue, + KafkaConfig.ZkSslTrustStoreTypeProp -> kafkaValue, + KafkaConfig.ZkSslProtocolProp -> kafkaValue, + KafkaConfig.ZkSslEnabledProtocolsProp -> kafkaValue, + KafkaConfig.ZkSslCipherSuitesProp -> kafkaValue, + KafkaConfig.ZkSslEndpointIdentificationAlgorithmProp -> "HTTPS", + KafkaConfig.ZkSslCrlEnableProp -> "false", + KafkaConfig.ZkSslOcspEnableProp -> "false", + prefix + KafkaConfig.ZkSslClientEnableProp -> "true", + prefix + KafkaConfig.ZkClientCnxnSocketProp -> prefixedValue, + prefix + KafkaConfig.ZkSslKeyStoreLocationProp -> prefixedValue, + prefix + KafkaConfig.ZkSslKeyStorePasswordProp -> prefixedValue, + prefix + KafkaConfig.ZkSslKeyStoreTypeProp -> prefixedValue, + prefix + KafkaConfig.ZkSslTrustStoreLocationProp -> prefixedValue, + prefix + KafkaConfig.ZkSslTrustStorePasswordProp -> prefixedValue, + prefix + KafkaConfig.ZkSslTrustStoreTypeProp -> prefixedValue, + prefix + KafkaConfig.ZkSslProtocolProp -> prefixedValue, + prefix + KafkaConfig.ZkSslEnabledProtocolsProp -> prefixedValue, + prefix + KafkaConfig.ZkSslCipherSuitesProp -> prefixedValue, + prefix + KafkaConfig.ZkSslEndpointIdentificationAlgorithmProp -> "", + prefix + KafkaConfig.ZkSslCrlEnableProp -> "true", + prefix + KafkaConfig.ZkSslOcspEnableProp -> "true") + configs.foreach{case (key, value) => props.put(key, value.toString) } + + val zkClientConfig = AclAuthorizer.zkClientConfigFromKafkaConfigAndMap( + KafkaConfig.fromProps(props), mutable.Map(configs.toSeq: _*)) + // confirm we get all the values we expect + KafkaConfig.ZkSslConfigToSystemPropertyMap.keys.foreach(prop => prop match { + case KafkaConfig.ZkSslClientEnableProp | KafkaConfig.ZkSslCrlEnableProp | KafkaConfig.ZkSslOcspEnableProp => + assertEquals("true", KafkaConfig.zooKeeperClientProperty(zkClientConfig, prop).getOrElse("")) + case KafkaConfig.ZkSslEndpointIdentificationAlgorithmProp => + assertEquals("false", KafkaConfig.zooKeeperClientProperty(zkClientConfig, prop).getOrElse("")) + case _ => assertEquals(prefixedValue, KafkaConfig.zooKeeperClientProperty(zkClientConfig, prop).getOrElse("")) + }) + } + + @Test + def testCreateDeleteTiming(): Unit = { + val literalResource = new ResourcePattern(TOPIC, "foo-" + UUID.randomUUID(), LITERAL) + val prefixedResource = new ResourcePattern(TOPIC, "bar-", PREFIXED) + val wildcardResource = new ResourcePattern(TOPIC, "*", LITERAL) + val ace = new AccessControlEntry(principal.toString, WildcardHost, READ, ALLOW) + val updateSemaphore = new Semaphore(1) + + def createAcl(createAuthorizer: AclAuthorizer, resource: ResourcePattern): AclBinding = { + val acl = new AclBinding(resource, ace) + createAuthorizer.createAcls(requestContext, Collections.singletonList(acl)).asScala + .foreach(_.toCompletableFuture.get(15, TimeUnit.SECONDS)) + acl + } + + def deleteAcl(deleteAuthorizer: AclAuthorizer, + resource: ResourcePattern, + deletePatternType: PatternType): List[AclBinding] = { + + val filter = new AclBindingFilter( + new ResourcePatternFilter(resource.resourceType(), resource.name(), deletePatternType), + AccessControlEntryFilter.ANY) + 
deleteAuthorizer.deleteAcls(requestContext, Collections.singletonList(filter)).asScala + .map(_.toCompletableFuture.get(15, TimeUnit.SECONDS)) + .flatMap(_.aclBindingDeleteResults.asScala) + .map(_.aclBinding) + .toList + } + + def listAcls(authorizer: AclAuthorizer): List[AclBinding] = { + authorizer.acls(AclBindingFilter.ANY).asScala.toList + } + + def verifyCreateDeleteAcl(deleteAuthorizer: AclAuthorizer, + resource: ResourcePattern, + deletePatternType: PatternType): Unit = { + updateSemaphore.acquire() + assertEquals(List.empty, listAcls(deleteAuthorizer)) + val acl = createAcl(aclAuthorizer, resource) + val deleted = deleteAcl(deleteAuthorizer, resource, deletePatternType) + if (deletePatternType != PatternType.MATCH) { + assertEquals(List(acl), deleted) + } else { + assertEquals(List.empty[AclBinding], deleted) + } + updateSemaphore.release() + if (deletePatternType == PatternType.MATCH) { + TestUtils.waitUntilTrue(() => listAcls(deleteAuthorizer).nonEmpty, "ACL not propagated") + assertEquals(List(acl), deleteAcl(deleteAuthorizer, resource, deletePatternType)) + } + TestUtils.waitUntilTrue(() => listAcls(deleteAuthorizer).isEmpty, "ACL delete not propagated") + } + + val deleteAuthorizer = new AclAuthorizer { + override def processAclChangeNotification(resource: ResourcePattern): Unit = { + updateSemaphore.acquire() + try { + super.processAclChangeNotification(resource) + } finally { + updateSemaphore.release() + } + } + } + + try { + deleteAuthorizer.configure(config.originals) + List(literalResource, prefixedResource, wildcardResource).foreach { resource => + verifyCreateDeleteAcl(deleteAuthorizer, resource, resource.patternType()) + verifyCreateDeleteAcl(deleteAuthorizer, resource, PatternType.ANY) + verifyCreateDeleteAcl(deleteAuthorizer, resource, PatternType.MATCH) + } + } finally { + deleteAuthorizer.close() + } + } + + @Test + def testAuthorizeByResourceTypeNoAclFoundOverride(): Unit = { + val props = TestUtils.createBrokerConfig(1, zkConnect) + props.put(AclAuthorizer.AllowEveryoneIfNoAclIsFoundProp, "true") + + val cfg = KafkaConfig.fromProps(props) + val aclAuthorizer = new AclAuthorizer + try { + aclAuthorizer.configure(cfg.originals) + assertTrue(authorizeByResourceType(aclAuthorizer, requestContext, READ, resource.resourceType()), + "If allow.everyone.if.no.acl.found = true, caller should have read access to at least one topic") + assertTrue(authorizeByResourceType(aclAuthorizer, requestContext, WRITE, resource.resourceType()), + "If allow.everyone.if.no.acl.found = true, caller should have write access to at least one topic") + } finally { + aclAuthorizer.close() + } + } + + private def givenAuthorizerWithProtocolVersion(protocolVersion: Option[ApiVersion]): Unit = { + aclAuthorizer.close() + + val props = TestUtils.createBrokerConfig(0, zkConnect) + props.put(AclAuthorizer.SuperUsersProp, superUsers) + protocolVersion.foreach(version => props.put(KafkaConfig.InterBrokerProtocolVersionProp, version.toString)) + + config = KafkaConfig.fromProps(props) + + aclAuthorizer.configure(config.originals) + } + + private def getAclChangeEventAsString(patternType: PatternType) = { + val store = ZkAclStore(patternType) + val children = zooKeeperClient.handleRequest(GetChildrenRequest(store.changeStore.aclChangePath, registerWatch = true)) + children.maybeThrow() + assertEquals(1, children.children.size, "Expecting 1 change event") + + val data = zooKeeperClient.handleRequest(GetDataRequest(s"${store.changeStore.aclChangePath}/${children.children.head}")) + data.maybeThrow() + + 
new String(data.data, UTF_8) + } + + private def changeAclAndVerify(originalAcls: Set[AccessControlEntry], + addedAcls: Set[AccessControlEntry], + removedAcls: Set[AccessControlEntry], + resource: ResourcePattern = resource): Set[AccessControlEntry] = { + var acls = originalAcls + + if(addedAcls.nonEmpty) { + addAcls(aclAuthorizer, addedAcls, resource) + acls ++= addedAcls + } + + if(removedAcls.nonEmpty) { + removeAcls(aclAuthorizer, removedAcls, resource) + acls --=removedAcls + } + + TestUtils.waitAndVerifyAcls(acls, aclAuthorizer, resource) + + acls + } + + private def authorize(authorizer: AclAuthorizer, requestContext: RequestContext, operation: AclOperation, resource: ResourcePattern): Boolean = { + val action = new Action(operation, resource, 1, true, true) + authorizer.authorize(requestContext, List(action).asJava).asScala.head == AuthorizationResult.ALLOWED + } + + private def getAcls(authorizer: AclAuthorizer, resourcePattern: ResourcePattern): Set[AccessControlEntry] = { + val acls = authorizer.acls(new AclBindingFilter(resourcePattern.toFilter, AccessControlEntryFilter.ANY)).asScala.toSet + acls.map(_.entry) + } + + private def getAcls(authorizer: AclAuthorizer, principal: KafkaPrincipal): Set[AclBinding] = { + val filter = new AclBindingFilter(ResourcePatternFilter.ANY, + new AccessControlEntryFilter(principal.toString, null, AclOperation.ANY, AclPermissionType.ANY)) + authorizer.acls(filter).asScala.toSet + } + + private def getAcls(authorizer: AclAuthorizer): Set[AclBinding] = { + authorizer.acls(AclBindingFilter.ANY).asScala.toSet + } + + private def validOp(op: AclOperation): Boolean = { + op != AclOperation.ANY && op != AclOperation.UNKNOWN + } + + private def prepareDefaultConfig: String = + prepareConfig(Array("broker.id=1", "zookeeper.connect=somewhere")) + + private def prepareConfig(lines : Array[String]): String = { + val file = File.createTempFile("kafkatest", ".properties") + file.deleteOnExit() + + val writer = Files.newOutputStream(file.toPath) + try { + lines.foreach { l => + writer.write(l.getBytes) + writer.write("\n".getBytes) + } + file.getAbsolutePath + } finally writer.close() + } +} diff --git a/core/src/test/scala/unit/kafka/security/authorizer/AclAuthorizerWithZkSaslTest.scala b/core/src/test/scala/unit/kafka/security/authorizer/AclAuthorizerWithZkSaslTest.scala new file mode 100644 index 0000000..e38cafd --- /dev/null +++ b/core/src/test/scala/unit/kafka/security/authorizer/AclAuthorizerWithZkSaslTest.scala @@ -0,0 +1,186 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package kafka.security.authorizer + +import java.net.InetAddress +import java.util +import java.util.UUID +import java.util.concurrent.{Executors, TimeUnit} + +import javax.security.auth.Subject +import javax.security.auth.callback.CallbackHandler +import kafka.api.SaslSetup +import kafka.security.authorizer.AclEntry.WildcardHost +import kafka.server.{KafkaConfig, QuorumTestHarness} +import kafka.utils.JaasTestUtils.{JaasModule, JaasSection} +import kafka.utils.{JaasTestUtils, TestUtils} +import kafka.zk.KafkaZkClient +import kafka.zookeeper.ZooKeeperClient +import org.apache.kafka.common.acl.{AccessControlEntry, AccessControlEntryFilter, AclBinding, AclBindingFilter} +import org.apache.kafka.common.acl.AclOperation.{READ, WRITE} +import org.apache.kafka.common.acl.AclPermissionType.ALLOW +import org.apache.kafka.common.network.{ClientInformation, ListenerName} +import org.apache.kafka.common.protocol.ApiKeys +import org.apache.kafka.common.requests.{RequestContext, RequestHeader} +import org.apache.kafka.common.resource.PatternType.LITERAL +import org.apache.kafka.common.resource.ResourcePattern +import org.apache.kafka.common.resource.ResourceType.TOPIC +import org.apache.kafka.common.security.auth.{KafkaPrincipal, SecurityProtocol} +import org.apache.kafka.test.{TestUtils => JTestUtils} +import org.apache.zookeeper.server.auth.DigestLoginModule +import org.junit.jupiter.api.Assertions.assertEquals +import org.junit.jupiter.api.{AfterEach, BeforeEach, Test, TestInfo} + +import scala.jdk.CollectionConverters._ +import scala.collection.Seq + +class AclAuthorizerWithZkSaslTest extends QuorumTestHarness with SaslSetup { + + private val aclAuthorizer = new AclAuthorizer + private val aclAuthorizer2 = new AclAuthorizer + private val resource: ResourcePattern = new ResourcePattern(TOPIC, "foo-" + UUID.randomUUID(), LITERAL) + private val username = "alice" + private val principal = new KafkaPrincipal(KafkaPrincipal.USER_TYPE, username) + private val requestContext = newRequestContext(principal, InetAddress.getByName("192.168.0.1")) + private val executor = Executors.newSingleThreadScheduledExecutor + private var config: KafkaConfig = _ + + @BeforeEach + override def setUp(testInfo: TestInfo): Unit = { + // Allow failed clients to avoid server closing the connection before reporting AuthFailed. 
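+    // The tests below inject invalid SASL credentials, so the client must be able to observe the
+    // authentication failure rather than an abrupt disconnect.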
+ System.setProperty("zookeeper.allowSaslFailedClients", "true") + + // Configure ZK SASL with TestableDigestLoginModule for clients to inject failures + TestableDigestLoginModule.reset() + val jaasSections = JaasTestUtils.zkSections + val serverJaas = jaasSections.filter(_.contextName == "Server") + val clientJaas = jaasSections.filter(_.contextName == "Client") + .map(section => new TestableJaasSection(section.contextName, section.modules)) + startSasl(serverJaas ++ clientJaas) + + // Increase maxUpdateRetries to avoid transient failures + aclAuthorizer.maxUpdateRetries = Int.MaxValue + aclAuthorizer2.maxUpdateRetries = Int.MaxValue + + super.setUp(testInfo) + config = KafkaConfig.fromProps(TestUtils.createBrokerConfig(0, zkConnect)) + + aclAuthorizer.configure(config.originals) + aclAuthorizer2.configure(config.originals) + } + + @AfterEach + override def tearDown(): Unit = { + System.clearProperty("zookeeper.allowSaslFailedClients") + TestableDigestLoginModule.reset() + executor.shutdownNow() + aclAuthorizer.close() + aclAuthorizer2.close() + super.tearDown() + } + + @Test + def testAclUpdateWithSessionExpiration(): Unit = { + zkClient(aclAuthorizer).currentZooKeeper.getTestable.injectSessionExpiration() + zkClient(aclAuthorizer2).currentZooKeeper.getTestable.injectSessionExpiration() + verifyAclUpdate() + } + + @Test + def testAclUpdateWithAuthFailure(): Unit = { + injectTransientAuthenticationFailure() + verifyAclUpdate() + } + + private def injectTransientAuthenticationFailure(): Unit = { + TestableDigestLoginModule.injectInvalidCredentials() + zkClient(aclAuthorizer).currentZooKeeper.getTestable.injectSessionExpiration() + zkClient(aclAuthorizer2).currentZooKeeper.getTestable.injectSessionExpiration() + executor.schedule((() => TestableDigestLoginModule.reset()): Runnable, + ZooKeeperClient.RetryBackoffMs * 2, TimeUnit.MILLISECONDS) + } + + private def verifyAclUpdate(): Unit = { + val allowReadAcl = new AccessControlEntry(principal.toString, WildcardHost, READ, ALLOW) + val allowWriteAcl = new AccessControlEntry(principal.toString, WildcardHost, WRITE, ALLOW) + val acls = Set(allowReadAcl, allowWriteAcl) + + TestUtils.retry(maxWaitMs = 15000) { + try { + addAcls(aclAuthorizer, acls, resource) + } catch { + case _: Exception => // Ignore error and retry + } + assertEquals(acls, getAcls(aclAuthorizer, resource)) + } + val (acls2, _) = TestUtils.computeUntilTrue(getAcls(aclAuthorizer2, resource)) { _ == acls } + assertEquals(acls, acls2) + } + + private def zkClient(authorizer: AclAuthorizer): KafkaZkClient = { + JTestUtils.fieldValue(authorizer, classOf[AclAuthorizer], "zkClient") + } + + private def addAcls(authorizer: AclAuthorizer, aces: Set[AccessControlEntry], resourcePattern: ResourcePattern): Unit = { + val bindings = aces.map { ace => new AclBinding(resourcePattern, ace) } + authorizer.createAcls(requestContext, bindings.toList.asJava).asScala + .map(_.toCompletableFuture.get) + .foreach { result => result.exception.ifPresent { e => throw e } } + } + + private def getAcls(authorizer: AclAuthorizer, resourcePattern: ResourcePattern): Set[AccessControlEntry] = { + val acls = authorizer.acls(new AclBindingFilter(resourcePattern.toFilter, AccessControlEntryFilter.ANY)).asScala.toSet + acls.map(_.entry) + } + + private def newRequestContext(principal: KafkaPrincipal, clientAddress: InetAddress, apiKey: ApiKeys = ApiKeys.PRODUCE): RequestContext = { + val securityProtocol = SecurityProtocol.SASL_PLAINTEXT + val header = new RequestHeader(apiKey, 2, "", 1) //ApiKeys apiKey, short 
version, String clientId, int correlation + new RequestContext(header, "", clientAddress, principal, ListenerName.forSecurityProtocol(securityProtocol), + securityProtocol, ClientInformation.EMPTY, false) + } +} + +object TestableDigestLoginModule { + @volatile var injectedPassword: Option[String] = None + + def reset(): Unit = { + injectedPassword = None + } + + def injectInvalidCredentials(): Unit = { + injectedPassword = Some("invalidPassword") + } +} + +class TestableDigestLoginModule extends DigestLoginModule { + override def initialize(subject: Subject, callbackHandler: CallbackHandler, sharedState: util.Map[String, _], options: util.Map[String, _]): Unit = { + super.initialize(subject, callbackHandler, sharedState, options) + val injectedPassword = TestableDigestLoginModule.injectedPassword + injectedPassword.foreach { newPassword => + val oldPassword = subject.getPrivateCredentials.asScala.head + subject.getPrivateCredentials.add(newPassword) + subject.getPrivateCredentials.remove(oldPassword) + } + } +} + +class TestableJaasSection(contextName: String, modules: Seq[JaasModule]) extends JaasSection(contextName, modules) { + override def toString: String = { + super.toString.replaceFirst(classOf[DigestLoginModule].getName, classOf[TestableDigestLoginModule].getName) + } +} diff --git a/core/src/test/scala/unit/kafka/security/authorizer/AclEntryTest.scala b/core/src/test/scala/unit/kafka/security/authorizer/AclEntryTest.scala new file mode 100644 index 0000000..ddf00a5 --- /dev/null +++ b/core/src/test/scala/unit/kafka/security/authorizer/AclEntryTest.scala @@ -0,0 +1,47 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package kafka.security.authorizer + +import java.nio.charset.StandardCharsets.UTF_8 + +import kafka.utils.Json +import org.apache.kafka.common.acl.AclOperation.READ +import org.apache.kafka.common.acl.AclPermissionType.{ALLOW, DENY} +import org.apache.kafka.common.security.auth.KafkaPrincipal +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.Test + +import scala.jdk.CollectionConverters._ + +class AclEntryTest { + + val AclJson = """{"version": 1, "acls": [{"host": "host1","permissionType": "Deny","operation": "READ", "principal": "User:alice" }, + { "host": "*" , "permissionType": "Allow", "operation": "Read", "principal": "User:bob" }, + { "host": "host1", "permissionType": "Deny", "operation": "Read" , "principal": "User:bob"}]}""" + + @Test + def testAclJsonConversion(): Unit = { + val acl1 = AclEntry(new KafkaPrincipal(KafkaPrincipal.USER_TYPE, "alice"), DENY, "host1" , READ) + val acl2 = AclEntry(new KafkaPrincipal(KafkaPrincipal.USER_TYPE, "bob"), ALLOW, "*", READ) + val acl3 = AclEntry(new KafkaPrincipal(KafkaPrincipal.USER_TYPE, "bob"), DENY, "host1", READ) + + val acls = Set[AclEntry](acl1, acl2, acl3) + + assertEquals(acls, AclEntry.fromBytes(Json.encodeAsBytes(AclEntry.toJsonCompatibleMap(acls).asJava))) + assertEquals(acls, AclEntry.fromBytes(AclJson.getBytes(UTF_8))) + } +} diff --git a/core/src/test/scala/unit/kafka/security/authorizer/AuthorizerInterfaceDefaultTest.scala b/core/src/test/scala/unit/kafka/security/authorizer/AuthorizerInterfaceDefaultTest.scala new file mode 100644 index 0000000..6852ab3 --- /dev/null +++ b/core/src/test/scala/unit/kafka/security/authorizer/AuthorizerInterfaceDefaultTest.scala @@ -0,0 +1,95 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package kafka.security.authorizer + +import java.util.concurrent.CompletionStage +import java.{lang, util} +import kafka.server.KafkaConfig +import kafka.utils.TestUtils +import kafka.server.QuorumTestHarness +import kafka.zookeeper.ZooKeeperClient +import org.apache.kafka.common.Endpoint +import org.apache.kafka.common.acl._ +import org.apache.kafka.common.utils.Time +import org.apache.kafka.server.authorizer._ +import org.apache.zookeeper.client.ZKClientConfig +import org.junit.jupiter.api.{AfterEach, BeforeEach, TestInfo} + +class AuthorizerInterfaceDefaultTest extends QuorumTestHarness with BaseAuthorizerTest { + + private val interfaceDefaultAuthorizer = new DelegateAuthorizer + + override def authorizer: Authorizer = interfaceDefaultAuthorizer + + @BeforeEach + override def setUp(testInfo: TestInfo): Unit = { + super.setUp(testInfo) + + // Increase maxUpdateRetries to avoid transient failures + interfaceDefaultAuthorizer.authorizer.maxUpdateRetries = Int.MaxValue + + val props = TestUtils.createBrokerConfig(0, zkConnect) + props.put(AclAuthorizer.SuperUsersProp, superUsers) + + config = KafkaConfig.fromProps(props) + interfaceDefaultAuthorizer.authorizer.configure(config.originals) + + zooKeeperClient = new ZooKeeperClient(zkConnect, zkSessionTimeout, zkConnectionTimeout, zkMaxInFlightRequests, + Time.SYSTEM, "kafka.test", "AuthorizerInterfaceDefaultTest", new ZKClientConfig, + "AuthorizerInterfaceDefaultTest") + } + + @AfterEach + override def tearDown(): Unit = { + interfaceDefaultAuthorizer.close() + zooKeeperClient.close() + super.tearDown() + } + + class DelegateAuthorizer extends Authorizer { + val authorizer = new AclAuthorizer + + override def start(serverInfo: AuthorizerServerInfo): util.Map[Endpoint, _ <: CompletionStage[Void]] = { + authorizer.start(serverInfo) + } + + override def authorize(requestContext: AuthorizableRequestContext, actions: util.List[Action]): util.List[AuthorizationResult] = { + authorizer.authorize(requestContext, actions) + } + + override def createAcls(requestContext: AuthorizableRequestContext, aclBindings: util.List[AclBinding]): util.List[_ <: CompletionStage[AclCreateResult]] = { + authorizer.createAcls(requestContext, aclBindings) + } + + override def deleteAcls(requestContext: AuthorizableRequestContext, aclBindingFilters: util.List[AclBindingFilter]): util.List[_ <: CompletionStage[AclDeleteResult]] = { + authorizer.deleteAcls(requestContext, aclBindingFilters) + } + + override def acls(filter: AclBindingFilter): lang.Iterable[AclBinding] = { + authorizer.acls(filter) + } + + override def configure(configs: util.Map[String, _]): Unit = { + authorizer.configure(configs) + } + + override def close(): Unit = { + authorizer.close() + } + } + +} diff --git a/core/src/test/scala/unit/kafka/security/authorizer/BaseAuthorizerTest.scala b/core/src/test/scala/unit/kafka/security/authorizer/BaseAuthorizerTest.scala new file mode 100644 index 0000000..c502b48 --- /dev/null +++ b/core/src/test/scala/unit/kafka/security/authorizer/BaseAuthorizerTest.scala @@ -0,0 +1,375 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.security.authorizer + +import java.net.InetAddress +import java.util.UUID + +import kafka.security.authorizer.AclEntry.{WildcardHost, WildcardPrincipalString} +import kafka.server.KafkaConfig +import kafka.zookeeper.ZooKeeperClient +import org.apache.kafka.common.acl.AclOperation.{ALL, READ, WRITE} +import org.apache.kafka.common.acl.AclPermissionType.{ALLOW, DENY} +import org.apache.kafka.common.acl.{AccessControlEntry, AccessControlEntryFilter, AclBinding, AclBindingFilter, AclOperation} +import org.apache.kafka.common.network.{ClientInformation, ListenerName} +import org.apache.kafka.common.protocol.ApiKeys +import org.apache.kafka.common.requests.{RequestContext, RequestHeader} +import org.apache.kafka.common.resource.PatternType.{LITERAL, PREFIXED} +import org.apache.kafka.common.resource.ResourcePattern.WILDCARD_RESOURCE +import org.apache.kafka.common.resource.ResourceType.{CLUSTER, GROUP, TOPIC, TRANSACTIONAL_ID} +import org.apache.kafka.common.resource.{ResourcePattern, ResourceType} +import org.apache.kafka.common.security.auth.{KafkaPrincipal, SecurityProtocol} +import org.apache.kafka.server.authorizer.{AuthorizationResult, Authorizer} +import org.junit.jupiter.api.Assertions.{assertFalse, assertTrue} +import org.junit.jupiter.api.Test + +import scala.jdk.CollectionConverters._ + +trait BaseAuthorizerTest { + + def authorizer: Authorizer + + val superUsers = "User:superuser1; User:superuser2" + val username = "alice" + val principal = new KafkaPrincipal(KafkaPrincipal.USER_TYPE, username) + val requestContext: RequestContext = newRequestContext(principal, InetAddress.getByName("192.168.0.1")) + val superUserName = "superuser1" + var config: KafkaConfig = _ + var zooKeeperClient: ZooKeeperClient = _ + var resource: ResourcePattern = _ + + @Test + def testAuthorizeByResourceTypeMultipleAddAndRemove(): Unit = { + val user1 = new KafkaPrincipal(KafkaPrincipal.USER_TYPE, "user1") + val host1 = InetAddress.getByName("192.168.1.1") + val resource1 = new ResourcePattern(TOPIC, "sb1" + UUID.randomUUID(), LITERAL) + val denyRead = new AccessControlEntry(user1.toString, host1.getHostAddress, READ, DENY) + val allowRead = new AccessControlEntry(user1.toString, host1.getHostAddress, READ, ALLOW) + val u1h1Context = newRequestContext(user1, host1) + + for (_ <- 1 to 10) { + assertFalse(authorizeByResourceType(authorizer, u1h1Context, READ, ResourceType.TOPIC), + "User1 from host1 should not have READ access to any topic when no ACL exists") + + addAcls(authorizer, Set(allowRead), resource1) + assertTrue(authorizeByResourceType(authorizer, u1h1Context, READ, ResourceType.TOPIC), + "User1 from host1 now should have READ access to at least one topic") + + for (_ <- 1 to 10) { + addAcls(authorizer, Set(denyRead), resource1) + assertFalse(authorizeByResourceType(authorizer, u1h1Context, READ, ResourceType.TOPIC), + "User1 from host1 now should not have READ access to any topic") + + removeAcls(authorizer, Set(denyRead), resource1) + addAcls(authorizer, Set(allowRead), resource1) + assertTrue(authorizeByResourceType(authorizer, u1h1Context, READ, 
ResourceType.TOPIC),
+ "User1 from host1 now should have READ access to at least one topic")
+ }
+
+ removeAcls(authorizer, Set(allowRead), resource1)
+ assertFalse(authorizeByResourceType(authorizer, u1h1Context, READ, ResourceType.TOPIC),
+ "User1 from host1 now should not have READ access to any topic")
+ }
+ }
+
+ @Test
+ def testAuthorizeByResourceTypeIsolationUnrelatedDenyWontDominateAllow(): Unit = {
+ val user1 = new KafkaPrincipal(KafkaPrincipal.USER_TYPE, "user1")
+ val user2 = new KafkaPrincipal(KafkaPrincipal.USER_TYPE, "user2")
+ val host1 = InetAddress.getByName("192.168.1.1")
+ val host2 = InetAddress.getByName("192.168.1.2")
+ val resource1 = new ResourcePattern(TOPIC, "sb1" + UUID.randomUUID(), LITERAL)
+ val resource2 = new ResourcePattern(TOPIC, "sb2" + UUID.randomUUID(), LITERAL)
+ val resource3 = new ResourcePattern(GROUP, "s", PREFIXED)
+
+ val acl1 = new AccessControlEntry(user1.toString, host1.getHostAddress, READ, DENY)
+ val acl2 = new AccessControlEntry(user2.toString, host1.getHostAddress, READ, DENY)
+ val acl3 = new AccessControlEntry(user1.toString, host2.getHostAddress, WRITE, DENY)
+ val acl4 = new AccessControlEntry(user1.toString, host2.getHostAddress, READ, DENY)
+ val acl5 = new AccessControlEntry(user1.toString, host2.getHostAddress, READ, DENY)
+ val acl6 = new AccessControlEntry(user2.toString, host2.getHostAddress, READ, DENY)
+ val acl7 = new AccessControlEntry(user1.toString, host2.getHostAddress, READ, ALLOW)
+
+ addAcls(authorizer, Set(acl1, acl2, acl3, acl6, acl7), resource1)
+ addAcls(authorizer, Set(acl4), resource2)
+ addAcls(authorizer, Set(acl5), resource3)
+
+ val u1h1Context = newRequestContext(user1, host1)
+ val u1h2Context = newRequestContext(user1, host2)
+
+ assertFalse(authorizeByResourceType(authorizer, u1h1Context, READ, ResourceType.TOPIC),
+ "User1 from host1 should not have READ access to any topic")
+ assertFalse(authorizeByResourceType(authorizer, u1h1Context, READ, ResourceType.GROUP),
+ "User1 from host1 should not have READ access to any consumer group")
+ assertFalse(authorizeByResourceType(authorizer, u1h1Context, READ, ResourceType.TRANSACTIONAL_ID),
+ "User1 from host1 should not have READ access to any transactional id")
+ assertFalse(authorizeByResourceType(authorizer, u1h1Context, READ, ResourceType.CLUSTER),
+ "User1 from host1 should not have READ access to any cluster resource")
+ assertTrue(authorizeByResourceType(authorizer, u1h2Context, READ, ResourceType.TOPIC),
+ "User1 from host2 should have READ access to at least one topic")
+ }
+
+ @Test
+ def testAuthorizeByResourceTypeDenyTakesPrecedence(): Unit = {
+ val user1 = new KafkaPrincipal(KafkaPrincipal.USER_TYPE, "user1")
+ val host1 = InetAddress.getByName("192.168.1.1")
+ val resource1 = new ResourcePattern(TOPIC, "sb1" + UUID.randomUUID(), LITERAL)
+
+ val u1h1Context = newRequestContext(user1, host1)
+ val acl1 = new AccessControlEntry(user1.toString, host1.getHostAddress, WRITE, ALLOW)
+ val acl2 = new AccessControlEntry(user1.toString, host1.getHostAddress, WRITE, DENY)
+
+ addAcls(authorizer, Set(acl1), resource1)
+ assertTrue(authorizeByResourceType(authorizer, u1h1Context, WRITE, ResourceType.TOPIC),
+ "User1 from host1 should have WRITE access to at least one topic")
+
+ addAcls(authorizer, Set(acl2), resource1)
+ assertFalse(authorizeByResourceType(authorizer, u1h1Context, WRITE, ResourceType.TOPIC),
+ "User1 from host1 should not have WRITE access to any topic")
+ }
+
+ @Test
+ def testAuthorizeByResourceTypePrefixedResourceDenyDominate(): Unit = {
+ val user1 = new 
KafkaPrincipal(KafkaPrincipal.USER_TYPE, "user1") + val host1 = InetAddress.getByName("192.168.1.1") + val a = new ResourcePattern(GROUP, "a", PREFIXED) + val ab = new ResourcePattern(GROUP, "ab", PREFIXED) + val abc = new ResourcePattern(GROUP, "abc", PREFIXED) + val abcd = new ResourcePattern(GROUP, "abcd", PREFIXED) + val abcde = new ResourcePattern(GROUP, "abcde", PREFIXED) + + val u1h1Context = newRequestContext(user1, host1) + val allowAce = new AccessControlEntry(user1.toString, host1.getHostAddress, READ, ALLOW) + val denyAce = new AccessControlEntry(user1.toString, host1.getHostAddress, READ, DENY) + + addAcls(authorizer, Set(allowAce), abcde) + assertTrue(authorizeByResourceType(authorizer, u1h1Context, READ, ResourceType.GROUP), + "User1 from host1 should have READ access to at least one group") + + addAcls(authorizer, Set(denyAce), abcd) + assertFalse(authorizeByResourceType(authorizer, u1h1Context, READ, ResourceType.GROUP), + "User1 from host1 now should not have READ access to any group") + + addAcls(authorizer, Set(allowAce), abc) + assertTrue(authorizeByResourceType(authorizer, u1h1Context, READ, ResourceType.GROUP), + "User1 from host1 now should have READ access to any group") + + addAcls(authorizer, Set(denyAce), a) + assertFalse(authorizeByResourceType(authorizer, u1h1Context, READ, ResourceType.GROUP), + "User1 from host1 now should not have READ access to any group") + + addAcls(authorizer, Set(allowAce), ab) + assertFalse(authorizeByResourceType(authorizer, u1h1Context, READ, ResourceType.GROUP), + "User1 from host1 still should not have READ access to any group") + } + + @Test + def testAuthorizeByResourceTypeWildcardResourceDenyDominate(): Unit = { + val user1 = new KafkaPrincipal(KafkaPrincipal.USER_TYPE, "user1") + val host1 = InetAddress.getByName("192.168.1.1") + val wildcard = new ResourcePattern(GROUP, ResourcePattern.WILDCARD_RESOURCE, LITERAL) + val prefixed = new ResourcePattern(GROUP, "hello", PREFIXED) + val literal = new ResourcePattern(GROUP, "aloha", LITERAL) + + val u1h1Context = newRequestContext(user1, host1) + val allowAce = new AccessControlEntry(user1.toString, host1.getHostAddress, WRITE, ALLOW) + val denyAce = new AccessControlEntry(user1.toString, host1.getHostAddress, WRITE, DENY) + + addAcls(authorizer, Set(allowAce), prefixed) + assertTrue(authorizeByResourceType(authorizer, u1h1Context, WRITE, ResourceType.GROUP), + "User1 from host1 should have WRITE access to at least one group") + + addAcls(authorizer, Set(denyAce), wildcard) + assertFalse(authorizeByResourceType(authorizer, u1h1Context, WRITE, ResourceType.GROUP), + "User1 from host1 now should not have WRITE access to any group") + + addAcls(authorizer, Set(allowAce), wildcard) + assertFalse(authorizeByResourceType(authorizer, u1h1Context, WRITE, ResourceType.GROUP), + "User1 from host1 still should not have WRITE access to any group") + + addAcls(authorizer, Set(allowAce), literal) + assertFalse(authorizeByResourceType(authorizer, u1h1Context, WRITE, ResourceType.GROUP), + "User1 from host1 still should not have WRITE access to any group") + } + + @Test + def testAuthorizeByResourceTypeWithAllOperationAce(): Unit = { + val user1 = new KafkaPrincipal(KafkaPrincipal.USER_TYPE, "user1") + val host1 = InetAddress.getByName("192.168.1.1") + val resource1 = new ResourcePattern(TOPIC, "sb1" + UUID.randomUUID(), LITERAL) + val denyAll = new AccessControlEntry(user1.toString, host1.getHostAddress, ALL, DENY) + val allowAll = new AccessControlEntry(user1.toString, host1.getHostAddress, ALL, 
ALLOW) + val denyWrite = new AccessControlEntry(user1.toString, host1.getHostAddress, WRITE, DENY) + val u1h1Context = newRequestContext(user1, host1) + + assertFalse(authorizeByResourceType(authorizer, u1h1Context, READ, ResourceType.TOPIC), + "User1 from host1 should not have READ access to any topic when no ACL exists") + + addAcls(authorizer, Set(denyWrite, allowAll), resource1) + assertTrue(authorizeByResourceType(authorizer, u1h1Context, READ, ResourceType.TOPIC), + "User1 from host1 now should have READ access to at least one topic") + + addAcls(authorizer, Set(denyAll), resource1) + assertFalse(authorizeByResourceType(authorizer, u1h1Context, READ, ResourceType.TOPIC), + "User1 from host1 now should not have READ access to any topic") + } + + @Test + def testAuthorizeByResourceTypeWithAllHostAce(): Unit = { + val user1 = new KafkaPrincipal(KafkaPrincipal.USER_TYPE, "user1") + val host1 = InetAddress.getByName("192.168.1.1") + val host2 = InetAddress.getByName("192.168.1.2") + val allHost = AclEntry.WildcardHost + val resource1 = new ResourcePattern(TOPIC, "sb1" + UUID.randomUUID(), LITERAL) + val resource2 = new ResourcePattern(TOPIC, "sb2" + UUID.randomUUID(), LITERAL) + val allowHost1 = new AccessControlEntry(user1.toString, host1.getHostAddress, READ, ALLOW) + val denyHost1 = new AccessControlEntry(user1.toString, host1.getHostAddress, READ, DENY) + val denyAllHost = new AccessControlEntry(user1.toString, allHost, READ, DENY) + val allowAllHost = new AccessControlEntry(user1.toString, allHost, READ, ALLOW) + val u1h1Context = newRequestContext(user1, host1) + val u1h2Context = newRequestContext(user1, host2) + + assertFalse(authorizeByResourceType(authorizer, u1h1Context, READ, ResourceType.TOPIC), + "User1 from host1 should not have READ access to any topic when no ACL exists") + + addAcls(authorizer, Set(allowHost1), resource1) + assertTrue(authorizeByResourceType(authorizer, u1h1Context, READ, ResourceType.TOPIC), + "User1 from host1 should now have READ access to at least one topic") + + addAcls(authorizer, Set(denyAllHost), resource1) + assertFalse(authorizeByResourceType(authorizer, u1h1Context, READ, ResourceType.TOPIC), + "User1 from host1 now shouldn't have READ access to any topic") + + addAcls(authorizer, Set(denyHost1), resource2) + assertFalse(authorizeByResourceType(authorizer, u1h1Context, READ, ResourceType.TOPIC), + "User1 from host1 still should not have READ access to any topic") + assertFalse(authorizeByResourceType(authorizer, u1h2Context, READ, ResourceType.TOPIC), + "User1 from host2 should not have READ access to any topic") + + addAcls(authorizer, Set(allowAllHost), resource2) + assertTrue(authorizeByResourceType(authorizer, u1h2Context, READ, ResourceType.TOPIC), + "User1 from host2 should now have READ access to at least one topic") + + addAcls(authorizer, Set(denyAllHost), resource2) + assertFalse(authorizeByResourceType(authorizer, u1h2Context, READ, ResourceType.TOPIC), + "User1 from host2 now shouldn't have READ access to any topic") + } + + @Test + def testAuthorizeByResourceTypeWithAllPrincipalAce(): Unit = { + val user1 = new KafkaPrincipal(KafkaPrincipal.USER_TYPE, "user1") + val user2 = new KafkaPrincipal(KafkaPrincipal.USER_TYPE, "user2") + val allUser = AclEntry.WildcardPrincipalString + val host1 = InetAddress.getByName("192.168.1.1") + val resource1 = new ResourcePattern(TOPIC, "sb1" + UUID.randomUUID(), LITERAL) + val resource2 = new ResourcePattern(TOPIC, "sb2" + UUID.randomUUID(), LITERAL) + val allowUser1 = new 
AccessControlEntry(user1.toString, host1.getHostAddress, READ, ALLOW)
+ val denyUser1 = new AccessControlEntry(user1.toString, host1.getHostAddress, READ, DENY)
+ val denyAllUser = new AccessControlEntry(allUser, host1.getHostAddress, READ, DENY)
+ val allowAllUser = new AccessControlEntry(allUser, host1.getHostAddress, READ, ALLOW)
+ val u1h1Context = newRequestContext(user1, host1)
+ val u2h1Context = newRequestContext(user2, host1)
+
+ assertFalse(authorizeByResourceType(authorizer, u1h1Context, READ, ResourceType.TOPIC),
+ "User1 from host1 should not have READ access to any topic when no ACL exists")
+
+ addAcls(authorizer, Set(allowUser1), resource1)
+ assertTrue(authorizeByResourceType(authorizer, u1h1Context, READ, ResourceType.TOPIC),
+ "User1 from host1 should now have READ access to at least one topic")
+
+ addAcls(authorizer, Set(denyAllUser), resource1)
+ assertFalse(authorizeByResourceType(authorizer, u1h1Context, READ, ResourceType.TOPIC),
+ "User1 from host1 now shouldn't have READ access to any topic")
+
+ addAcls(authorizer, Set(denyUser1), resource2)
+ assertFalse(authorizeByResourceType(authorizer, u1h1Context, READ, ResourceType.TOPIC),
+ "User1 from host1 still should not have READ access to any topic")
+ assertFalse(authorizeByResourceType(authorizer, u2h1Context, READ, ResourceType.TOPIC),
+ "User2 from host1 should not have READ access to any topic")
+
+ addAcls(authorizer, Set(allowAllUser), resource2)
+ assertTrue(authorizeByResourceType(authorizer, u2h1Context, READ, ResourceType.TOPIC),
+ "User2 from host1 should now have READ access to at least one topic")
+
+ addAcls(authorizer, Set(denyAllUser), resource2)
+ assertFalse(authorizeByResourceType(authorizer, u2h1Context, READ, ResourceType.TOPIC),
+ "User2 from host1 now shouldn't have READ access to any topic")
+ }
+
+ @Test
+ def testAuthorizeByResourceTypeSuperUserHasAccess(): Unit = {
+ val denyAllAce = new AccessControlEntry(WildcardPrincipalString, WildcardHost, AclOperation.ALL, DENY)
+ val superUser1 = new KafkaPrincipal(KafkaPrincipal.USER_TYPE, superUserName)
+ val host1 = InetAddress.getByName("192.0.4.4")
+ val allTopicsResource = new ResourcePattern(TOPIC, WILDCARD_RESOURCE, LITERAL)
+ val clusterResource = new ResourcePattern(CLUSTER, WILDCARD_RESOURCE, LITERAL)
+ val groupResource = new ResourcePattern(GROUP, WILDCARD_RESOURCE, LITERAL)
+ val transactionIdResource = new ResourcePattern(TRANSACTIONAL_ID, WILDCARD_RESOURCE, LITERAL)
+
+ addAcls(authorizer, Set(denyAllAce), allTopicsResource)
+ addAcls(authorizer, Set(denyAllAce), clusterResource)
+ addAcls(authorizer, Set(denyAllAce), groupResource)
+ addAcls(authorizer, Set(denyAllAce), transactionIdResource)
+
+ val superUserContext = newRequestContext(superUser1, host1)
+
+ assertTrue(authorizeByResourceType(authorizer, superUserContext, READ, ResourceType.TOPIC),
+ "A superuser always has access, regardless of ACLs.")
+ assertTrue(authorizeByResourceType(authorizer, superUserContext, READ, ResourceType.CLUSTER),
+ "A superuser always has access, regardless of ACLs.")
+ assertTrue(authorizeByResourceType(authorizer, superUserContext, READ, ResourceType.GROUP),
+ "A superuser always has access, regardless of ACLs.")
+ assertTrue(authorizeByResourceType(authorizer, superUserContext, READ, ResourceType.TRANSACTIONAL_ID),
+ "A superuser always has access, regardless of ACLs.")
+ }
+
+ def newRequestContext(principal: KafkaPrincipal, clientAddress: InetAddress, apiKey: ApiKeys = ApiKeys.PRODUCE): RequestContext = {
+ val securityProtocol = 
SecurityProtocol.SASL_PLAINTEXT + val header = new RequestHeader(apiKey, 2, "", 1) //ApiKeys apiKey, short version, String clientId, int correlation + new RequestContext(header, "", clientAddress, principal, ListenerName.forSecurityProtocol(securityProtocol), + securityProtocol, ClientInformation.EMPTY, false) + } + + def authorizeByResourceType(authorizer: Authorizer, requestContext: RequestContext, operation: AclOperation, resourceType: ResourceType) : Boolean = { + authorizer.authorizeByResourceType(requestContext, operation, resourceType) == AuthorizationResult.ALLOWED + } + + def addAcls(authorizer: Authorizer, aces: Set[AccessControlEntry], resourcePattern: ResourcePattern): Unit = { + val bindings = aces.map { ace => new AclBinding(resourcePattern, ace) } + authorizer.createAcls(requestContext, bindings.toList.asJava).asScala + .map(_.toCompletableFuture.get) + .foreach { result => result.exception.ifPresent { e => throw e } } + } + + def removeAcls(authorizer: Authorizer, aces: Set[AccessControlEntry], resourcePattern: ResourcePattern): Boolean = { + val bindings = if (aces.isEmpty) + Set(new AclBindingFilter(resourcePattern.toFilter, AccessControlEntryFilter.ANY) ) + else + aces.map { ace => new AclBinding(resourcePattern, ace).toFilter } + authorizer.deleteAcls(requestContext, bindings.toList.asJava).asScala + .map(_.toCompletableFuture.get) + .forall { result => + result.exception.ifPresent { e => throw e } + result.aclBindingDeleteResults.forEach { r => + r.exception.ifPresent { e => throw e } + } + !result.aclBindingDeleteResults.isEmpty + } + } + +} diff --git a/core/src/test/scala/unit/kafka/security/token/delegation/DelegationTokenManagerTest.scala b/core/src/test/scala/unit/kafka/security/token/delegation/DelegationTokenManagerTest.scala new file mode 100644 index 0000000..523b6a7 --- /dev/null +++ b/core/src/test/scala/unit/kafka/security/token/delegation/DelegationTokenManagerTest.scala @@ -0,0 +1,363 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package kafka.security.token.delegation + +import java.net.InetAddress +import java.nio.ByteBuffer +import java.util.{Base64, Properties} + +import kafka.network.RequestChannel.Session +import kafka.security.authorizer.{AclAuthorizer, AuthorizerUtils} +import kafka.security.authorizer.AclEntry.WildcardHost +import kafka.server.{CreateTokenResult, Defaults, DelegationTokenManager, KafkaConfig, QuorumTestHarness} +import kafka.utils.TestUtils +import kafka.zk.KafkaZkClient +import org.apache.kafka.common.acl.{AccessControlEntry, AclBinding, AclOperation} +import org.apache.kafka.common.acl.AclOperation._ +import org.apache.kafka.common.acl.AclPermissionType.ALLOW +import org.apache.kafka.common.protocol.Errors +import org.apache.kafka.common.resource.PatternType.LITERAL +import org.apache.kafka.common.resource.ResourcePattern +import org.apache.kafka.common.resource.ResourceType.DELEGATION_TOKEN +import org.apache.kafka.common.security.auth.KafkaPrincipal +import org.apache.kafka.common.security.scram.internals.ScramMechanism +import org.apache.kafka.common.security.token.delegation.internals.DelegationTokenCache +import org.apache.kafka.common.security.token.delegation.{DelegationToken, TokenInformation} +import org.apache.kafka.common.utils.{MockTime, SecurityUtils, Time} +import org.apache.kafka.server.authorizer._ +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.{AfterEach, BeforeEach, Test, TestInfo} + +import scala.jdk.CollectionConverters._ +import scala.collection.mutable.Buffer + +class DelegationTokenManagerTest extends QuorumTestHarness { + + val time = new MockTime() + val owner = SecurityUtils.parseKafkaPrincipal("User:owner") + val renewer = List(SecurityUtils.parseKafkaPrincipal("User:renewer1")) + val tokenManagers = Buffer[DelegationTokenManager]() + + val secretKey = "secretKey" + val maxLifeTimeMsDefault = Defaults.DelegationTokenMaxLifeTimeMsDefault + val renewTimeMsDefault = Defaults.DelegationTokenExpiryTimeMsDefault + var tokenCache: DelegationTokenCache = null + var props: Properties = null + + var createTokenResult: CreateTokenResult = _ + var error: Errors = Errors.NONE + var expiryTimeStamp: Long = 0 + + @BeforeEach + override def setUp(testInfo: TestInfo): Unit = { + super.setUp(testInfo) + props = TestUtils.createBrokerConfig(0, zkConnect, enableToken = true) + props.put(KafkaConfig.SaslEnabledMechanismsProp, ScramMechanism.mechanismNames().asScala.mkString(",")) + props.put(KafkaConfig.DelegationTokenSecretKeyProp, secretKey) + tokenCache = new DelegationTokenCache(ScramMechanism.mechanismNames()) + } + + @AfterEach + override def tearDown(): Unit = { + tokenManagers.foreach(_.shutdown()) + super.tearDown() + } + + @Test + def testTokenRequestsWithDelegationTokenDisabled(): Unit = { + val props: Properties = TestUtils.createBrokerConfig(0, zkConnect) + val config = KafkaConfig.fromProps(props) + val tokenManager = createDelegationTokenManager(config, tokenCache, time, zkClient) + + tokenManager.createToken(owner, renewer, -1, createTokenResultCallBack) + assertEquals(Errors.DELEGATION_TOKEN_AUTH_DISABLED, createTokenResult.error) + assert(Array[Byte]() sameElements createTokenResult.hmac) + + tokenManager.renewToken(owner, ByteBuffer.wrap("test".getBytes), 1000000, renewResponseCallback) + assertEquals(Errors.DELEGATION_TOKEN_AUTH_DISABLED, error) + + tokenManager.expireToken(owner, ByteBuffer.wrap("test".getBytes), 1000000, renewResponseCallback) + assertEquals(Errors.DELEGATION_TOKEN_AUTH_DISABLED, error) + } + + @Test + def 
testCreateToken(): Unit = { + val config = KafkaConfig.fromProps(props) + val tokenManager = createDelegationTokenManager(config, tokenCache, time, zkClient) + tokenManager.startup() + + tokenManager.createToken(owner, renewer, -1 , createTokenResultCallBack) + val issueTime = time.milliseconds + val tokenId = createTokenResult.tokenId + val password = DelegationTokenManager.createHmac(tokenId, secretKey) + assertEquals(CreateTokenResult(issueTime, issueTime + renewTimeMsDefault, issueTime + maxLifeTimeMsDefault, tokenId, password, Errors.NONE), createTokenResult) + + val token = tokenManager.getToken(tokenId) + assertFalse(token.isEmpty ) + assertTrue(password sameElements token.get.hmac) + } + + @Test + def testRenewToken(): Unit = { + val config = KafkaConfig.fromProps(props) + val tokenManager = createDelegationTokenManager(config, tokenCache, time, zkClient) + tokenManager.startup() + + tokenManager.createToken(owner, renewer, -1 , createTokenResultCallBack) + val issueTime = time.milliseconds + val maxLifeTime = issueTime + maxLifeTimeMsDefault + val tokenId = createTokenResult.tokenId + val password = DelegationTokenManager.createHmac(tokenId, secretKey) + assertEquals(CreateTokenResult(issueTime, issueTime + renewTimeMsDefault, maxLifeTime, tokenId, password, Errors.NONE), createTokenResult) + + //try renewing non-existing token + tokenManager.renewToken(owner, ByteBuffer.wrap("test".getBytes), -1 , renewResponseCallback) + assertEquals(Errors.DELEGATION_TOKEN_NOT_FOUND, error) + + //try renew non-owned tokens + val unknownOwner = SecurityUtils.parseKafkaPrincipal("User:Unknown") + tokenManager.renewToken(unknownOwner, ByteBuffer.wrap(password), -1 , renewResponseCallback) + assertEquals(Errors.DELEGATION_TOKEN_OWNER_MISMATCH, error) + + // try renew with default time period + time.sleep(24 * 60 * 60 * 1000L) + var expectedExpiryStamp = time.milliseconds + renewTimeMsDefault + tokenManager.renewToken(owner, ByteBuffer.wrap(password), -1 , renewResponseCallback) + assertEquals(expectedExpiryStamp, expiryTimeStamp) + assertEquals(Errors.NONE, error) + + // try renew with specific time period + time.sleep(24 * 60 * 60 * 1000L) + expectedExpiryStamp = time.milliseconds + 1 * 60 * 60 * 1000L + tokenManager.renewToken(owner, ByteBuffer.wrap(password), 1 * 60 * 60 * 1000L , renewResponseCallback) + assertEquals(expectedExpiryStamp, expiryTimeStamp) + assertEquals(Errors.NONE, error) + + //try renewing more than max time period + time.sleep( 1 * 60 * 60 * 1000L) + tokenManager.renewToken(owner, ByteBuffer.wrap(password), 8 * 24 * 60 * 60 * 1000L, renewResponseCallback) + assertEquals(maxLifeTime, expiryTimeStamp) + assertEquals(Errors.NONE, error) + + //try renewing expired token + time.sleep(8 * 24 * 60 * 60 * 1000L) + tokenManager.renewToken(owner, ByteBuffer.wrap(password), -1 , renewResponseCallback) + assertEquals(Errors.DELEGATION_TOKEN_EXPIRED, error) + } + + @Test + def testExpireToken(): Unit = { + val config = KafkaConfig.fromProps(props) + val tokenManager = createDelegationTokenManager(config, tokenCache, time, zkClient) + tokenManager.startup() + + tokenManager.createToken(owner, renewer, -1 , createTokenResultCallBack) + val issueTime = time.milliseconds + val tokenId = createTokenResult.tokenId + val password = DelegationTokenManager.createHmac(tokenId, secretKey) + assertEquals(CreateTokenResult(issueTime, issueTime + renewTimeMsDefault, issueTime + maxLifeTimeMsDefault, tokenId, password, Errors.NONE), createTokenResult) + + //try expire non-existing token + 
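+    // "test" is not an HMAC issued by this manager, so the expiration is expected to fail with DELEGATION_TOKEN_NOT_FOUND.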
tokenManager.expireToken(owner, ByteBuffer.wrap("test".getBytes), -1 , renewResponseCallback) + assertEquals(Errors.DELEGATION_TOKEN_NOT_FOUND, error) + + //try expire non-owned tokens + val unknownOwner = SecurityUtils.parseKafkaPrincipal("User:Unknown") + tokenManager.expireToken(unknownOwner, ByteBuffer.wrap(password), -1 , renewResponseCallback) + assertEquals(Errors.DELEGATION_TOKEN_OWNER_MISMATCH, error) + + //try expire token at a timestamp + time.sleep(24 * 60 * 60 * 1000L) + val expectedExpiryStamp = time.milliseconds + 2 * 60 * 60 * 1000L + tokenManager.expireToken(owner, ByteBuffer.wrap(password), 2 * 60 * 60 * 1000L, renewResponseCallback) + assertEquals(expectedExpiryStamp, expiryTimeStamp) + + //try expire token immediately + time.sleep(1 * 60 * 60 * 1000L) + tokenManager.expireToken(owner, ByteBuffer.wrap(password), -1, renewResponseCallback) + assert(tokenManager.getToken(tokenId).isEmpty) + assertEquals(Errors.NONE, error) + assertEquals(time.milliseconds, expiryTimeStamp) + } + + @Test + def testRemoveTokenHmac():Unit = { + val config = KafkaConfig.fromProps(props) + val tokenManager = createDelegationTokenManager(config, tokenCache, time, zkClient) + tokenManager.startup() + + tokenManager.createToken(owner, renewer, -1 , createTokenResultCallBack) + val issueTime = time.milliseconds + val tokenId = createTokenResult.tokenId + val password = DelegationTokenManager.createHmac(tokenId, secretKey) + assertEquals(CreateTokenResult(issueTime, issueTime + renewTimeMsDefault, issueTime + maxLifeTimeMsDefault, tokenId, password, Errors.NONE), createTokenResult) + + // expire the token immediately + tokenManager.expireToken(owner, ByteBuffer.wrap(password), -1, renewResponseCallback) + + val encodedHmac = Base64.getEncoder.encodeToString(password) + // check respective hmac map entry is removed for the expired tokenId. 
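+    // Looking up the expired token by its base64-encoded HMAC should no longer return a token id.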
+ val tokenInformation = tokenManager.tokenCache.tokenIdForHmac(encodedHmac) + assertNull(tokenInformation) + + //check that the token is removed + assert(tokenManager.getToken(tokenId).isEmpty) + } + + @Test + def testDescribeToken(): Unit = { + + val config = KafkaConfig.fromProps(props) + + val owner1 = SecurityUtils.parseKafkaPrincipal("User:owner1") + val owner2 = SecurityUtils.parseKafkaPrincipal("User:owner2") + val owner3 = SecurityUtils.parseKafkaPrincipal("User:owner3") + val owner4 = SecurityUtils.parseKafkaPrincipal("User:owner4") + + val renewer1 = SecurityUtils.parseKafkaPrincipal("User:renewer1") + val renewer2 = SecurityUtils.parseKafkaPrincipal("User:renewer2") + val renewer3 = SecurityUtils.parseKafkaPrincipal("User:renewer3") + val renewer4 = SecurityUtils.parseKafkaPrincipal("User:renewer4") + + val aclAuthorizer = new AclAuthorizer + aclAuthorizer.configure(config.originals) + + var hostSession = new Session(owner1, InetAddress.getByName("192.168.1.1")) + + val tokenManager = createDelegationTokenManager(config, tokenCache, time, zkClient) + tokenManager.startup() + + //create tokens + tokenManager.createToken(owner1, List(renewer1, renewer2), 1 * 60 * 60 * 1000L, createTokenResultCallBack) + + tokenManager.createToken(owner2, List(renewer3), 1 * 60 * 60 * 1000L, createTokenResultCallBack) + val tokenId2 = createTokenResult.tokenId + + tokenManager.createToken(owner3, List(renewer4), 2 * 60 * 60 * 1000L, createTokenResultCallBack) + val tokenId3 = createTokenResult.tokenId + + tokenManager.createToken(owner4, List(owner1, renewer4), 2 * 60 * 60 * 1000L, createTokenResultCallBack) + + assert(tokenManager.getAllTokenInformation.size == 4 ) + + //get tokens non-exiting owner + var tokens = getTokens(tokenManager, aclAuthorizer, hostSession, owner1, List(SecurityUtils.parseKafkaPrincipal("User:unknown"))) + assert(tokens.size == 0) + + //get all tokens for empty owner list + tokens = getTokens(tokenManager, aclAuthorizer, hostSession, owner1, List()) + assert(tokens.size == 0) + + //get all tokens for owner1 + tokens = getTokens(tokenManager, aclAuthorizer, hostSession, owner1, List(owner1)) + assert(tokens.size == 2) + + //get all tokens for owner1 + tokens = getTokens(tokenManager, aclAuthorizer, hostSession, owner1, null) + assert(tokens.size == 2) + + //get all tokens for unknown owner + tokens = getTokens(tokenManager, aclAuthorizer, hostSession, SecurityUtils.parseKafkaPrincipal("User:unknown"), null) + assert(tokens.size == 0) + + //get all tokens for multiple owners (owner1, renewer4) and without permission for renewer4 + tokens = getTokens(tokenManager, aclAuthorizer, hostSession, owner1, List(owner1, renewer4)) + assert(tokens.size == 2) + + def createAcl(aclBinding: AclBinding): Unit = { + val result = aclAuthorizer.createAcls(null, List(aclBinding).asJava).get(0).toCompletableFuture.get + result.exception.ifPresent { e => throw e } + } + + //get all tokens for multiple owners (owner1, renewer4) and with permission + createAcl(new AclBinding(new ResourcePattern(DELEGATION_TOKEN, tokenId3, LITERAL), + new AccessControlEntry(owner1.toString, WildcardHost, DESCRIBE, ALLOW))) + tokens = getTokens(tokenManager, aclAuthorizer, hostSession, owner1, List(owner1, renewer4)) + assert(tokens.size == 3) + + //get all tokens for renewer4 which is a renewer principal for some tokens + tokens = getTokens(tokenManager, aclAuthorizer, hostSession, renewer4, List(renewer4)) + assert(tokens.size == 2) + + //get all tokens for multiple owners (renewer2, renewer3) which are token 
renewers principals and without permissions for renewer3 + tokens = getTokens(tokenManager, aclAuthorizer, hostSession, renewer2, List(renewer2, renewer3)) + assert(tokens.size == 1) + + //get all tokens for multiple owners (renewer2, renewer3) which are token renewers principals and with permissions + hostSession = Session(renewer2, InetAddress.getByName("192.168.1.1")) + createAcl(new AclBinding(new ResourcePattern(DELEGATION_TOKEN, tokenId2, LITERAL), + new AccessControlEntry(renewer2.toString, WildcardHost, DESCRIBE, ALLOW))) + tokens = getTokens(tokenManager, aclAuthorizer, hostSession, renewer2, List(renewer2, renewer3)) + assert(tokens.size == 2) + + aclAuthorizer.close() + } + + private def getTokens(tokenManager: DelegationTokenManager, aclAuthorizer: AclAuthorizer, hostSession: Session, + requestPrincipal: KafkaPrincipal, requestedOwners: List[KafkaPrincipal]): List[DelegationToken] = { + + if (requestedOwners != null && requestedOwners.isEmpty) { + List() + } + else { + def authorizeToken(tokenId: String) = { + val requestContext = AuthorizerUtils.sessionToRequestContext(hostSession) + val action = new Action(AclOperation.DESCRIBE, + new ResourcePattern(DELEGATION_TOKEN, tokenId, LITERAL), 1, true, true) + aclAuthorizer.authorize(requestContext, List(action).asJava).asScala.head == AuthorizationResult.ALLOWED + } + def eligible(token: TokenInformation) = DelegationTokenManager.filterToken(requestPrincipal, Option(requestedOwners), token, authorizeToken) + tokenManager.getTokens(eligible) + } + } + + @Test + def testPeriodicTokenExpiry(): Unit = { + val config = KafkaConfig.fromProps(props) + val tokenManager = createDelegationTokenManager(config, tokenCache, time, zkClient) + tokenManager.startup() + + //create tokens + tokenManager.createToken(owner, renewer, 1 * 60 * 60 * 1000L, createTokenResultCallBack) + tokenManager.createToken(owner, renewer, 1 * 60 * 60 * 1000L, createTokenResultCallBack) + tokenManager.createToken(owner, renewer, 2 * 60 * 60 * 1000L, createTokenResultCallBack) + tokenManager.createToken(owner, renewer, 2 * 60 * 60 * 1000L, createTokenResultCallBack) + assert(tokenManager.getAllTokenInformation.size == 4 ) + + time.sleep(2 * 60 * 60 * 1000L) + tokenManager.expireTokens() + assert(tokenManager.getAllTokenInformation.size == 2 ) + + } + + private def createTokenResultCallBack(ret: CreateTokenResult): Unit = { + createTokenResult = ret + } + + private def renewResponseCallback(ret: Errors, timeStamp: Long): Unit = { + error = ret + expiryTimeStamp = timeStamp + } + + private def createDelegationTokenManager(config: KafkaConfig, tokenCache: DelegationTokenCache, + time: Time, zkClient: KafkaZkClient): DelegationTokenManager = { + val tokenManager = new DelegationTokenManager(config, tokenCache, time, zkClient) + tokenManagers += tokenManager + tokenManager + } +} diff --git a/core/src/test/scala/unit/kafka/server/AbstractApiVersionsRequestTest.scala b/core/src/test/scala/unit/kafka/server/AbstractApiVersionsRequestTest.scala new file mode 100644 index 0000000..1e6c05d --- /dev/null +++ b/core/src/test/scala/unit/kafka/server/AbstractApiVersionsRequestTest.scala @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package kafka.server + +import java.util.Properties + +import kafka.test.ClusterInstance +import org.apache.kafka.common.message.ApiMessageType.ListenerType +import org.apache.kafka.common.message.ApiVersionsResponseData.ApiVersion +import org.apache.kafka.common.network.ListenerName +import org.apache.kafka.common.protocol.ApiKeys +import org.apache.kafka.common.requests.{ApiVersionsRequest, ApiVersionsResponse, RequestUtils} +import org.apache.kafka.common.utils.Utils +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.Tag + +import scala.jdk.CollectionConverters._ + +@Tag("integration") +abstract class AbstractApiVersionsRequestTest(cluster: ClusterInstance) { + + def sendApiVersionsRequest(request: ApiVersionsRequest, listenerName: ListenerName): ApiVersionsResponse = { + IntegrationTestUtils.connectAndReceive[ApiVersionsResponse](request, cluster.brokerSocketServers().asScala.head, listenerName) + } + + def controlPlaneListenerName = new ListenerName("CONTROLLER") + + // Configure control plane listener to make sure we have separate listeners for testing. + def brokerPropertyOverrides(properties: Properties): Unit = { + val securityProtocol = cluster.config().securityProtocol() + properties.setProperty(KafkaConfig.ControlPlaneListenerNameProp, controlPlaneListenerName.value()) + properties.setProperty(KafkaConfig.ListenerSecurityProtocolMapProp, s"${controlPlaneListenerName.value()}:$securityProtocol,$securityProtocol:$securityProtocol") + properties.setProperty("listeners", s"$securityProtocol://localhost:0,${controlPlaneListenerName.value()}://localhost:0") + properties.setProperty(KafkaConfig.AdvertisedListenersProp, s"$securityProtocol://localhost:0,${controlPlaneListenerName.value()}://localhost:0") + } + + def sendUnsupportedApiVersionRequest(request: ApiVersionsRequest): ApiVersionsResponse = { + val overrideHeader = IntegrationTestUtils.nextRequestHeader(ApiKeys.API_VERSIONS, Short.MaxValue) + val socket = IntegrationTestUtils.connect(cluster.brokerSocketServers().asScala.head, cluster.clientListener()) + try { + val serializedBytes = Utils.toArray( + RequestUtils.serialize(overrideHeader.data, overrideHeader.headerVersion, request.data, request.version)) + IntegrationTestUtils.sendRequest(socket, serializedBytes) + IntegrationTestUtils.receive[ApiVersionsResponse](socket, ApiKeys.API_VERSIONS, 0.toShort) + } finally socket.close() + } + + def validateApiVersionsResponse(apiVersionsResponse: ApiVersionsResponse): Unit = { + val expectedApis = ApiKeys.zkBrokerApis() + assertEquals(expectedApis.size(), apiVersionsResponse.data.apiKeys().size(), + "API keys in ApiVersionsResponse must match API keys supported by broker.") + + val defaultApiVersionsResponse = ApiVersionsResponse.defaultApiVersionsResponse(ListenerType.ZK_BROKER) + for (expectedApiVersion: ApiVersion <- defaultApiVersionsResponse.data.apiKeys().asScala) { + val actualApiVersion = apiVersionsResponse.apiVersion(expectedApiVersion.apiKey) + assertNotNull(actualApiVersion, s"API key ${actualApiVersion.apiKey} is supported by broker, but not received in ApiVersionsResponse.") + 
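+      // The advertised version range for each API key must exactly match the broker's default ApiVersionsResponse.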
assertEquals(expectedApiVersion.apiKey, actualApiVersion.apiKey, "API key must be supported by the broker.") + assertEquals(expectedApiVersion.minVersion, actualApiVersion.minVersion, s"Received unexpected min version for API key ${actualApiVersion.apiKey}.") + assertEquals(expectedApiVersion.maxVersion, actualApiVersion.maxVersion, s"Received unexpected max version for API key ${actualApiVersion.apiKey}.") + } + } +} diff --git a/core/src/test/scala/unit/kafka/server/AbstractCreateTopicsRequestTest.scala b/core/src/test/scala/unit/kafka/server/AbstractCreateTopicsRequestTest.scala new file mode 100644 index 0000000..b6036b7 --- /dev/null +++ b/core/src/test/scala/unit/kafka/server/AbstractCreateTopicsRequestTest.scala @@ -0,0 +1,191 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.server + +import java.util +import java.util.Properties + +import kafka.network.SocketServer +import kafka.utils.TestUtils +import org.apache.kafka.common.message.CreateTopicsRequestData +import org.apache.kafka.common.message.CreateTopicsRequestData._ +import org.apache.kafka.common.protocol.Errors +import org.apache.kafka.common.requests._ +import org.junit.jupiter.api.Assertions.{assertEquals, assertFalse, assertNotNull, assertTrue} + +import scala.jdk.CollectionConverters._ + +abstract class AbstractCreateTopicsRequestTest extends BaseRequestTest { + + override def brokerPropertyOverrides(properties: Properties): Unit = + properties.put(KafkaConfig.AutoCreateTopicsEnableProp, false.toString) + + def topicsReq(topics: Seq[CreatableTopic], + timeout: Integer = 10000, + validateOnly: Boolean = false) = { + val req = new CreateTopicsRequestData() + req.setTimeoutMs(timeout) + req.setTopics(new CreatableTopicCollection(topics.asJava.iterator())) + req.setValidateOnly(validateOnly) + new CreateTopicsRequest.Builder(req).build() + } + + def topicReq(name: String, + numPartitions: Integer = null, + replicationFactor: Integer = null, + config: Map[String, String] = null, + assignment: Map[Int, Seq[Int]] = null): CreatableTopic = { + val topic = new CreatableTopic() + topic.setName(name) + if (numPartitions != null) { + topic.setNumPartitions(numPartitions) + } else if (assignment != null) { + topic.setNumPartitions(-1) + } else { + topic.setNumPartitions(1) + } + if (replicationFactor != null) { + topic.setReplicationFactor(replicationFactor.toShort) + } else if (assignment != null) { + topic.setReplicationFactor((-1).toShort) + } else { + topic.setReplicationFactor(1.toShort) + } + if (config != null) { + val effectiveConfigs = new CreateableTopicConfigCollection() + config.foreach { + case (name, value) => + effectiveConfigs.add(new CreateableTopicConfig().setName(name).setValue(value)) + } + topic.setConfigs(effectiveConfigs) + } + if 
(assignment != null) { + val effectiveAssignments = new CreatableReplicaAssignmentCollection() + assignment.foreach { + case (partitionIndex, brokerIdList) => { + val effectiveAssignment = new CreatableReplicaAssignment() + effectiveAssignment.setPartitionIndex(partitionIndex) + val brokerIds = new util.ArrayList[java.lang.Integer]() + brokerIdList.foreach(brokerId => brokerIds.add(brokerId)) + effectiveAssignment.setBrokerIds(brokerIds) + effectiveAssignments.add(effectiveAssignment) + } + } + topic.setAssignments(effectiveAssignments) + } + topic + } + + protected def validateValidCreateTopicsRequests(request: CreateTopicsRequest): Unit = { + val response = sendCreateTopicRequest(request) + + assertFalse(response.errorCounts().keySet().asScala.exists(_.code() > 0), + s"There should be no errors, found ${response.errorCounts().keySet().asScala.mkString(", ")},") + + request.data.topics.forEach { topic => + def verifyMetadata(socketServer: SocketServer) = { + val metadata = sendMetadataRequest( + new MetadataRequest.Builder(List(topic.name()).asJava, false).build()).topicMetadata.asScala + val metadataForTopic = metadata.filter(_.topic == topic.name()).head + + val partitions = if (!topic.assignments().isEmpty) + topic.assignments().size + else + topic.numPartitions + + val replication = if (!topic.assignments().isEmpty) + topic.assignments().iterator().next().brokerIds().size() + else + topic.replicationFactor + + if (request.data.validateOnly) { + assertNotNull(metadataForTopic, s"Topic $topic should be created") + assertFalse(metadataForTopic.error == Errors.NONE, s"Error ${metadataForTopic.error} for topic $topic") + assertTrue(metadataForTopic.partitionMetadata.isEmpty, "The topic should have no partitions") + } + else { + assertNotNull(metadataForTopic, "The topic should be created") + assertEquals(Errors.NONE, metadataForTopic.error) + if (partitions == -1) { + assertEquals(configs.head.numPartitions, metadataForTopic.partitionMetadata.size, "The topic should have the default number of partitions") + } else { + assertEquals(partitions, metadataForTopic.partitionMetadata.size, "The topic should have the correct number of partitions") + } + + if (replication == -1) { + assertEquals(configs.head.defaultReplicationFactor, + metadataForTopic.partitionMetadata.asScala.head.replicaIds.size, "The topic should have the default replication factor") + } else { + assertEquals(replication, metadataForTopic.partitionMetadata.asScala.head.replicaIds.size, "The topic should have the correct replication factor") + } + } + } + + // Verify controller broker has the correct metadata + verifyMetadata(controllerSocketServer) + if (!request.data.validateOnly) { + // Wait until metadata is propagated and validate non-controller broker has the correct metadata + TestUtils.waitForPartitionMetadata(servers, topic.name(), 0) + } + verifyMetadata(notControllerSocketServer) + } + } + + protected def error(error: Errors, errorMessage: Option[String] = None): ApiError = + new ApiError(error, errorMessage.orNull) + + protected def validateErrorCreateTopicsRequests(request: CreateTopicsRequest, + expectedResponse: Map[String, ApiError], + checkErrorMessage: Boolean = true): Unit = { + val response = sendCreateTopicRequest(request) + assertEquals(expectedResponse.size, response.data().topics().size, "The response size should match") + + expectedResponse.foreach { case (topicName, expectedError) => + val expected = expectedResponse(topicName) + val actual = response.data().topics().find(topicName) + if (actual == 
null) { + throw new RuntimeException(s"No response data found for topic $topicName") + } + assertEquals(expected.error.code(), actual.errorCode(), "The response error should match") + if (checkErrorMessage) { + assertEquals(expected.message, actual.errorMessage()) + } + // If no error validate topic exists + if (expectedError.isSuccess && !request.data.validateOnly) { + validateTopicExists(topicName) + } + } + } + + protected def validateTopicExists(topic: String): Unit = { + TestUtils.waitForPartitionMetadata(servers, topic, 0) + val metadata = sendMetadataRequest( + new MetadataRequest.Builder(List(topic).asJava, true).build()).topicMetadata.asScala + assertTrue(metadata.exists(p => p.topic.equals(topic) && p.error == Errors.NONE), "The topic should be created") + } + + protected def sendCreateTopicRequest(request: CreateTopicsRequest, + socketServer: SocketServer = controllerSocketServer): CreateTopicsResponse = { + connectAndReceive[CreateTopicsResponse](request, socketServer) + } + + protected def sendMetadataRequest(request: MetadataRequest): MetadataResponse = { + connectAndReceive[MetadataResponse](request) + } + +} diff --git a/core/src/test/scala/unit/kafka/server/AbstractFetcherManagerTest.scala b/core/src/test/scala/unit/kafka/server/AbstractFetcherManagerTest.scala new file mode 100644 index 0000000..6708c51 --- /dev/null +++ b/core/src/test/scala/unit/kafka/server/AbstractFetcherManagerTest.scala @@ -0,0 +1,223 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package kafka.server + +import com.yammer.metrics.core.Gauge +import kafka.cluster.BrokerEndPoint +import kafka.metrics.KafkaYammerMetrics +import kafka.utils.TestUtils +import org.apache.kafka.common.{TopicPartition, Uuid} +import org.easymock.EasyMock +import org.junit.jupiter.api.{BeforeEach, Test} +import org.junit.jupiter.api.Assertions._ + +import scala.jdk.CollectionConverters._ + +class AbstractFetcherManagerTest { + + @BeforeEach + def cleanMetricRegistry(): Unit = { + TestUtils.clearYammerMetrics() + } + + private def getMetricValue(name: String): Any = { + KafkaYammerMetrics.defaultRegistry.allMetrics.asScala.filter { case (k, _) => k.getName == name }.values.headOption.get. 
+ asInstanceOf[Gauge[Int]].value() + } + + @Test + def testAddAndRemovePartition(): Unit = { + val fetcher: AbstractFetcherThread = EasyMock.mock(classOf[AbstractFetcherThread]) + val fetcherManager = new AbstractFetcherManager[AbstractFetcherThread]("fetcher-manager", "fetcher-manager", 2) { + override def createFetcherThread(fetcherId: Int, sourceBroker: BrokerEndPoint): AbstractFetcherThread = { + fetcher + } + } + + val fetchOffset = 10L + val leaderEpoch = 15 + val tp = new TopicPartition("topic", 0) + val topicId = Some(Uuid.randomUuid()) + val initialFetchState = InitialFetchState( + topicId = topicId, + leader = new BrokerEndPoint(0, "localhost", 9092), + currentLeaderEpoch = leaderEpoch, + initOffset = fetchOffset) + + EasyMock.expect(fetcher.start()) + EasyMock.expect(fetcher.addPartitions(Map(tp -> initialFetchState))) + .andReturn(Set(tp)) + EasyMock.expect(fetcher.fetchState(tp)) + .andReturn(Some(PartitionFetchState(topicId, fetchOffset, None, leaderEpoch, Truncating, lastFetchedEpoch = None))) + EasyMock.expect(fetcher.removePartitions(Set(tp))).andReturn(Map.empty) + EasyMock.expect(fetcher.fetchState(tp)).andReturn(None) + EasyMock.replay(fetcher) + + fetcherManager.addFetcherForPartitions(Map(tp -> initialFetchState)) + assertEquals(Some(fetcher), fetcherManager.getFetcher(tp)) + + fetcherManager.removeFetcherForPartitions(Set(tp)) + assertEquals(None, fetcherManager.getFetcher(tp)) + + EasyMock.verify(fetcher) + } + + @Test + def testMetricFailedPartitionCount(): Unit = { + val fetcher: AbstractFetcherThread = EasyMock.mock(classOf[AbstractFetcherThread]) + val fetcherManager = new AbstractFetcherManager[AbstractFetcherThread]("fetcher-manager", "fetcher-manager", 2) { + override def createFetcherThread(fetcherId: Int, sourceBroker: BrokerEndPoint): AbstractFetcherThread = { + fetcher + } + } + + val tp = new TopicPartition("topic", 0) + val metricName = "FailedPartitionsCount" + + // initial value for failed partition count + assertEquals(0, getMetricValue(metricName)) + + // partition marked as failed increments the count for failed partitions + fetcherManager.failedPartitions.add(tp) + assertEquals(1, getMetricValue(metricName)) + + // removing fetcher for the partition would remove the partition from set of failed partitions and decrement the + // count for failed partitions + fetcherManager.removeFetcherForPartitions(Set(tp)) + assertEquals(0, getMetricValue(metricName)) + } + @Test + def testDeadThreadCountMetric(): Unit = { + val fetcher: AbstractFetcherThread = EasyMock.mock(classOf[AbstractFetcherThread]) + val fetcherManager = new AbstractFetcherManager[AbstractFetcherThread]("fetcher-manager", "fetcher-manager", 2) { + override def createFetcherThread(fetcherId: Int, sourceBroker: BrokerEndPoint): AbstractFetcherThread = { + fetcher + } + } + + val fetchOffset = 10L + val leaderEpoch = 15 + val tp = new TopicPartition("topic", 0) + val topicId = Some(Uuid.randomUuid()) + val initialFetchState = InitialFetchState( + topicId = topicId, + leader = new BrokerEndPoint(0, "localhost", 9092), + currentLeaderEpoch = leaderEpoch, + initOffset = fetchOffset) + + EasyMock.expect(fetcher.start()) + EasyMock.expect(fetcher.addPartitions(Map(tp -> initialFetchState))) + .andReturn(Set(tp)) + EasyMock.expect(fetcher.isThreadFailed).andReturn(true) + EasyMock.replay(fetcher) + + fetcherManager.addFetcherForPartitions(Map(tp -> initialFetchState)) + + assertEquals(1, fetcherManager.deadThreadCount) + EasyMock.verify(fetcher) + + EasyMock.reset(fetcher) + 
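+    // After resetting the mock, report the fetcher thread as healthy so the dead thread count drops back to zero.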
EasyMock.expect(fetcher.isThreadFailed).andReturn(false) + EasyMock.replay(fetcher) + + assertEquals(0, fetcherManager.deadThreadCount) + EasyMock.verify(fetcher) + } + + @Test + def testMaybeUpdateTopicIds(): Unit = { + val fetcher: AbstractFetcherThread = EasyMock.mock(classOf[AbstractFetcherThread]) + val fetcherManager = new AbstractFetcherManager[AbstractFetcherThread]("fetcher-manager", "fetcher-manager", 2) { + override def createFetcherThread(fetcherId: Int, sourceBroker: BrokerEndPoint): AbstractFetcherThread = { + fetcher + } + } + + val fetchOffset = 10L + val leaderEpoch = 15 + val tp1 = new TopicPartition("topic1", 0) + val tp2 = new TopicPartition("topic2", 0) + val unknownTp = new TopicPartition("topic2", 1) + val topicId1 = Some(Uuid.randomUuid()) + val topicId2 = Some(Uuid.randomUuid()) + + // Start out with no topic ID. + val initialFetchState1 = InitialFetchState( + topicId = None, + leader = new BrokerEndPoint(0, "localhost", 9092), + currentLeaderEpoch = leaderEpoch, + initOffset = fetchOffset) + + // Include a partition on a different leader + val initialFetchState2 = InitialFetchState( + topicId = None, + leader = new BrokerEndPoint(1, "localhost", 9092), + currentLeaderEpoch = leaderEpoch, + initOffset = fetchOffset) + + // Simulate calls to different fetchers due to different leaders + EasyMock.expect(fetcher.start()) + EasyMock.expect(fetcher.start()) + + EasyMock.expect(fetcher.addPartitions(Map(tp1 -> initialFetchState1))) + .andReturn(Set(tp1)) + EasyMock.expect(fetcher.addPartitions(Map(tp2 -> initialFetchState2))) + .andReturn(Set(tp2)) + + EasyMock.expect(fetcher.fetchState(tp1)) + .andReturn(Some(PartitionFetchState(None, fetchOffset, None, leaderEpoch, Truncating, lastFetchedEpoch = None))) + EasyMock.expect(fetcher.fetchState(tp2)) + .andReturn(Some(PartitionFetchState(None, fetchOffset, None, leaderEpoch, Truncating, lastFetchedEpoch = None))) + + val topicIds = Map(tp1.topic -> topicId1, tp2.topic -> topicId2) + EasyMock.expect(fetcher.maybeUpdateTopicIds(Set(tp1), topicIds)) + EasyMock.expect(fetcher.maybeUpdateTopicIds(Set(tp2), topicIds)) + + EasyMock.expect(fetcher.fetchState(tp1)) + .andReturn(Some(PartitionFetchState(topicId1, fetchOffset, None, leaderEpoch, Truncating, lastFetchedEpoch = None))) + EasyMock.expect(fetcher.fetchState(tp2)) + .andReturn(Some(PartitionFetchState(topicId2, fetchOffset, None, leaderEpoch, Truncating, lastFetchedEpoch = None))) + + // When targeting a fetcher that doesn't exist, we will not see fetcher.maybeUpdateTopicIds called. + // We will see it for a topic partition that does not exist. 
+ EasyMock.expect(fetcher.maybeUpdateTopicIds(Set(unknownTp), topicIds)) + EasyMock.expect(fetcher.fetchState(unknownTp)) + .andReturn(None) + EasyMock.replay(fetcher) + + def verifyFetchState(fetchState: Option[PartitionFetchState], expectedTopicId: Option[Uuid]): Unit = { + assertTrue(fetchState.isDefined) + assertEquals(expectedTopicId, fetchState.get.topicId) + } + + fetcherManager.addFetcherForPartitions(Map(tp1 -> initialFetchState1, tp2 -> initialFetchState2)) + verifyFetchState(fetcher.fetchState(tp1), None) + verifyFetchState(fetcher.fetchState(tp2), None) + + val partitionsToUpdate = Map(tp1 -> initialFetchState1.leader.id, tp2 -> initialFetchState2.leader.id) + fetcherManager.maybeUpdateTopicIds(partitionsToUpdate, topicIds) + verifyFetchState(fetcher.fetchState(tp1), topicId1) + verifyFetchState(fetcher.fetchState(tp2), topicId2) + + // Try an invalid fetcher and an invalid topic partition + val invalidPartitionsToUpdate = Map(tp1 -> 2, unknownTp -> initialFetchState1.leader.id) + fetcherManager.maybeUpdateTopicIds(invalidPartitionsToUpdate, topicIds) + assertTrue(fetcher.fetchState(unknownTp).isEmpty) + + EasyMock.verify(fetcher) + } +} diff --git a/core/src/test/scala/unit/kafka/server/AbstractFetcherThreadTest.scala b/core/src/test/scala/unit/kafka/server/AbstractFetcherThreadTest.scala new file mode 100644 index 0000000..148a903 --- /dev/null +++ b/core/src/test/scala/unit/kafka/server/AbstractFetcherThreadTest.scala @@ -0,0 +1,1336 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package kafka.server + +import java.nio.ByteBuffer +import java.util.Optional +import java.util.concurrent.atomic.AtomicInteger + +import kafka.cluster.BrokerEndPoint +import kafka.log.LogAppendInfo +import kafka.message.NoCompressionCodec +import kafka.metrics.KafkaYammerMetrics +import kafka.server.AbstractFetcherThread.ReplicaFetch +import kafka.server.AbstractFetcherThread.ResultWithPartitions +import kafka.utils.Implicits.MapExtensionMethods +import kafka.utils.TestUtils +import org.apache.kafka.common.{KafkaException, TopicPartition, Uuid} +import org.apache.kafka.common.errors.{FencedLeaderEpochException, UnknownLeaderEpochException, UnknownTopicIdException} +import org.apache.kafka.common.message.FetchResponseData +import org.apache.kafka.common.message.OffsetForLeaderEpochResponseData.EpochEndOffset +import org.apache.kafka.common.protocol.{ApiKeys, Errors} +import org.apache.kafka.common.record._ +import org.apache.kafka.common.requests.OffsetsForLeaderEpochResponse.{UNDEFINED_EPOCH, UNDEFINED_EPOCH_OFFSET} +import org.apache.kafka.common.requests.{FetchRequest, FetchResponse} +import org.apache.kafka.common.utils.Time +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.Assumptions.assumeTrue +import org.junit.jupiter.api.{BeforeEach, Test} + +import scala.jdk.CollectionConverters._ +import scala.collection.{Map, Set, mutable} +import scala.util.Random +import scala.collection.mutable.ArrayBuffer +import scala.compat.java8.OptionConverters._ + +class AbstractFetcherThreadTest { + + val truncateOnFetch = true + val topicIds = Map("topic1" -> Uuid.randomUuid(), "topic2" -> Uuid.randomUuid()) + val version = ApiKeys.FETCH.latestVersion() + private val partition1 = new TopicPartition("topic1", 0) + private val partition2 = new TopicPartition("topic2", 0) + private val failedPartitions = new FailedPartitions + + @BeforeEach + def cleanMetricRegistry(): Unit = { + TestUtils.clearYammerMetrics() + } + + private def allMetricsNames: Set[String] = KafkaYammerMetrics.defaultRegistry().allMetrics().asScala.keySet.map(_.getName) + + private def mkBatch(baseOffset: Long, leaderEpoch: Int, records: SimpleRecord*): RecordBatch = { + MemoryRecords.withRecords(baseOffset, CompressionType.NONE, leaderEpoch, records: _*) + .batches.asScala.head + } + + private def initialFetchState(topicId: Option[Uuid], fetchOffset: Long, leaderEpoch: Int): InitialFetchState = { + InitialFetchState(topicId = topicId, leader = new BrokerEndPoint(0, "localhost", 9092), + initOffset = fetchOffset, currentLeaderEpoch = leaderEpoch) + } + + @Test + def testMetricsRemovedOnShutdown(): Unit = { + val partition = new TopicPartition("topic", 0) + val fetcher = new MockFetcherThread + + // add one partition to create the consumer lag metric + fetcher.setReplicaState(partition, MockFetcherThread.PartitionState(leaderEpoch = 0)) + fetcher.addPartitions(Map(partition -> initialFetchState(topicIds.get(partition.topic), 0L, leaderEpoch = 0))) + fetcher.setLeaderState(partition, MockFetcherThread.PartitionState(leaderEpoch = 0)) + + fetcher.start() + + val brokerTopicStatsMetrics = fetcher.brokerTopicStats.allTopicsStats.metricMap.keySet + val fetcherMetrics = Set(FetcherMetrics.BytesPerSec, FetcherMetrics.RequestsPerSec, FetcherMetrics.ConsumerLag) + + // wait until all fetcher metrics are present + TestUtils.waitUntilTrue(() => allMetricsNames == brokerTopicStatsMetrics ++ fetcherMetrics, + "Failed waiting for all fetcher metrics to be registered") + + fetcher.shutdown() + + // verify that all the 
fetcher metrics are removed and only brokerTopicStats left + val metricNames = KafkaYammerMetrics.defaultRegistry().allMetrics().asScala.keySet.map(_.getName).toSet + assertTrue(metricNames.intersect(fetcherMetrics).isEmpty) + assertEquals(brokerTopicStatsMetrics, metricNames.intersect(brokerTopicStatsMetrics)) + } + + @Test + def testConsumerLagRemovedWithPartition(): Unit = { + val partition = new TopicPartition("topic", 0) + val fetcher = new MockFetcherThread + + // add one partition to create the consumer lag metric + fetcher.setReplicaState(partition, MockFetcherThread.PartitionState(leaderEpoch = 0)) + fetcher.addPartitions(Map(partition -> initialFetchState(topicIds.get(partition.topic), 0L, leaderEpoch = 0))) + fetcher.setLeaderState(partition, MockFetcherThread.PartitionState(leaderEpoch = 0)) + + fetcher.doWork() + + assertTrue(allMetricsNames(FetcherMetrics.ConsumerLag), + "Failed waiting for consumer lag metric") + + // remove the partition to simulate leader migration + fetcher.removePartitions(Set(partition)) + + // the lag metric should now be gone + assertFalse(allMetricsNames(FetcherMetrics.ConsumerLag)) + } + + @Test + def testSimpleFetch(): Unit = { + val partition = new TopicPartition("topic", 0) + val fetcher = new MockFetcherThread + + fetcher.setReplicaState(partition, MockFetcherThread.PartitionState(leaderEpoch = 0)) + fetcher.addPartitions(Map(partition -> initialFetchState(topicIds.get(partition.topic), 0L, leaderEpoch = 0))) + + val batch = mkBatch(baseOffset = 0L, leaderEpoch = 0, + new SimpleRecord("a".getBytes), new SimpleRecord("b".getBytes)) + val leaderState = MockFetcherThread.PartitionState(Seq(batch), leaderEpoch = 0, highWatermark = 2L) + fetcher.setLeaderState(partition, leaderState) + + fetcher.doWork() + + val replicaState = fetcher.replicaPartitionState(partition) + assertEquals(2L, replicaState.logEndOffset) + assertEquals(2L, replicaState.highWatermark) + } + + @Test + def testDelay(): Unit = { + val partition = new TopicPartition("topic", 0) + val fetchBackOffMs = 250 + + val fetcher = new MockFetcherThread(fetchBackOffMs = fetchBackOffMs) { + override def fetchFromLeader(fetchRequest: FetchRequest.Builder): Map[TopicPartition, FetchData] = { + throw new UnknownTopicIdException("Topic ID was unknown as expected for this test") + } + } + + fetcher.setReplicaState(partition, MockFetcherThread.PartitionState(leaderEpoch = 0)) + fetcher.addPartitions(Map(partition -> initialFetchState(Some(Uuid.randomUuid()), 0L, leaderEpoch = 0))) + + val batch = mkBatch(baseOffset = 0L, leaderEpoch = 0, + new SimpleRecord("a".getBytes), new SimpleRecord("b".getBytes)) + val leaderState = MockFetcherThread.PartitionState(Seq(batch), leaderEpoch = 0, highWatermark = 2L) + fetcher.setLeaderState(partition, leaderState) + + // Do work for the first time. This should result in all partitions in error. 
+ val timeBeforeFirst = System.currentTimeMillis() + fetcher.doWork() + val timeAfterFirst = System.currentTimeMillis() + val firstWorkDuration = timeAfterFirst - timeBeforeFirst + + // The second doWork will pause for fetchBackOffMs since all partitions will be delayed + val timeBeforeSecond = System.currentTimeMillis() + fetcher.doWork() + val timeAfterSecond = System.currentTimeMillis() + val secondWorkDuration = timeAfterSecond - timeBeforeSecond + + assertTrue(firstWorkDuration < secondWorkDuration) + // The second call should have taken more than fetchBackOffMs + assertTrue(fetchBackOffMs <= secondWorkDuration, + "secondWorkDuration: " + secondWorkDuration + " was not greater than or equal to fetchBackOffMs: " + fetchBackOffMs) + } + + @Test + def testPartitionsInError(): Unit = { + val partition1 = new TopicPartition("topic1", 0) + val partition2 = new TopicPartition("topic2", 0) + val partition3 = new TopicPartition("topic3", 0) + val fetchBackOffMs = 250 + + val fetcher = new MockFetcherThread(fetchBackOffMs = fetchBackOffMs) { + override def fetchFromLeader(fetchRequest: FetchRequest.Builder): Map[TopicPartition, FetchData] = { + Map(partition1 -> new FetchData().setErrorCode(Errors.UNKNOWN_TOPIC_ID.code), + partition2 -> new FetchData().setErrorCode(Errors.INCONSISTENT_TOPIC_ID.code), + partition3 -> new FetchData().setErrorCode(Errors.NONE.code)) + } + } + + fetcher.setReplicaState(partition1, MockFetcherThread.PartitionState(leaderEpoch = 0)) + fetcher.addPartitions(Map(partition1 -> initialFetchState(Some(Uuid.randomUuid()), 0L, leaderEpoch = 0))) + fetcher.setReplicaState(partition2, MockFetcherThread.PartitionState(leaderEpoch = 0)) + fetcher.addPartitions(Map(partition2 -> initialFetchState(Some(Uuid.randomUuid()), 0L, leaderEpoch = 0))) + fetcher.setReplicaState(partition3, MockFetcherThread.PartitionState(leaderEpoch = 0)) + fetcher.addPartitions(Map(partition3 -> initialFetchState(Some(Uuid.randomUuid()), 0L, leaderEpoch = 0))) + + val batch = mkBatch(baseOffset = 0L, leaderEpoch = 0, + new SimpleRecord("a".getBytes), new SimpleRecord("b".getBytes)) + val leaderState = MockFetcherThread.PartitionState(Seq(batch), leaderEpoch = 0, highWatermark = 2L) + fetcher.setLeaderState(partition1, leaderState) + fetcher.setLeaderState(partition2, leaderState) + fetcher.setLeaderState(partition3, leaderState) + + fetcher.doWork() + + val partition1FetchState = fetcher.fetchState(partition1) + val partition2FetchState = fetcher.fetchState(partition2) + val partition3FetchState = fetcher.fetchState(partition3) + assertTrue(partition1FetchState.isDefined) + assertTrue(partition2FetchState.isDefined) + assertTrue(partition3FetchState.isDefined) + + // Only the partitions with errors should be delayed. 
+    assertTrue(partition1FetchState.get.isDelayed)
+    assertTrue(partition2FetchState.get.isDelayed)
+    assertFalse(partition3FetchState.get.isDelayed)
+  }
+
+  @Test
+  def testFencedTruncation(): Unit = {
+    val partition = new TopicPartition("topic", 0)
+    val fetcher = new MockFetcherThread
+
+    fetcher.setReplicaState(partition, MockFetcherThread.PartitionState(leaderEpoch = 0))
+    fetcher.addPartitions(Map(partition -> initialFetchState(topicIds.get(partition.topic), 0L, leaderEpoch = 0)))
+
+    val batch = mkBatch(baseOffset = 0L, leaderEpoch = 1,
+      new SimpleRecord("a".getBytes), new SimpleRecord("b".getBytes))
+    val leaderState = MockFetcherThread.PartitionState(Seq(batch), leaderEpoch = 1, highWatermark = 2L)
+    fetcher.setLeaderState(partition, leaderState)
+
+    fetcher.doWork()
+
+    // No progress should be made
+    val replicaState = fetcher.replicaPartitionState(partition)
+    assertEquals(0L, replicaState.logEndOffset)
+    assertEquals(0L, replicaState.highWatermark)
+
+    // After fencing, the fetcher should remove the partition from tracking and mark as failed
+    assertTrue(fetcher.fetchState(partition).isEmpty)
+    assertTrue(failedPartitions.contains(partition))
+  }
+
+  @Test
+  def testFencedFetch(): Unit = {
+    val partition = new TopicPartition("topic", 0)
+    val fetcher = new MockFetcherThread
+
+    val replicaState = MockFetcherThread.PartitionState(leaderEpoch = 0)
+    fetcher.setReplicaState(partition, replicaState)
+    fetcher.addPartitions(Map(partition -> initialFetchState(topicIds.get(partition.topic), 0L, leaderEpoch = 0)))
+
+    val batch = mkBatch(baseOffset = 0L, leaderEpoch = 0,
+      new SimpleRecord("a".getBytes),
+      new SimpleRecord("b".getBytes))
+    val leaderState = MockFetcherThread.PartitionState(Seq(batch), leaderEpoch = 0, highWatermark = 2L)
+    fetcher.setLeaderState(partition, leaderState)
+
+    fetcher.doWork()
+
+    // Verify we have caught up
+    assertEquals(2, replicaState.logEndOffset)
+
+    // Bump the epoch on the leader
+    fetcher.leaderPartitionState(partition).leaderEpoch += 1
+
+    fetcher.doWork()
+
+    // After fencing, the fetcher should remove the partition from tracking and mark as failed
+    assertTrue(fetcher.fetchState(partition).isEmpty)
+    assertTrue(failedPartitions.contains(partition))
+  }
+
+  @Test
+  def testUnknownLeaderEpochInTruncation(): Unit = {
+    val partition = new TopicPartition("topic", 0)
+    val fetcher = new MockFetcherThread
+
+    // The replica's leader epoch is ahead of the leader
+    val replicaState = MockFetcherThread.PartitionState(leaderEpoch = 1)
+    fetcher.setReplicaState(partition, replicaState)
+    fetcher.addPartitions(Map(partition -> initialFetchState(topicIds.get(partition.topic), 0L, leaderEpoch = 1)), forceTruncation = true)
+
+    val batch = mkBatch(baseOffset = 0L, leaderEpoch = 0, new SimpleRecord("a".getBytes))
+    val leaderState = MockFetcherThread.PartitionState(Seq(batch), leaderEpoch = 0, highWatermark = 2L)
+    fetcher.setLeaderState(partition, leaderState)
+
+    fetcher.doWork()
+
+    // No data has been fetched and the follower is still truncating
+    assertEquals(0, replicaState.logEndOffset)
+    assertEquals(Some(Truncating), fetcher.fetchState(partition).map(_.state))
+
+    // Bump the epoch on the leader
+    fetcher.leaderPartitionState(partition).leaderEpoch += 1
+
+    // Now we can make progress
+    fetcher.doWork()
+
+    assertEquals(1, replicaState.logEndOffset)
+    assertEquals(Some(Fetching), fetcher.fetchState(partition).map(_.state))
+  }
+
+  @Test
+  def testUnknownLeaderEpochWhileFetching(): Unit = {
+    val partition = new TopicPartition("topic", 0)
+    val fetcher = new MockFetcherThread
+
+    // This test is contrived because it shouldn't be possible to see an unknown leader epoch
+    // in the Fetching state as the leader must validate the follower's epoch when it checks
+    // the truncation offset.
+
+    val replicaState = MockFetcherThread.PartitionState(leaderEpoch = 1)
+    fetcher.setReplicaState(partition, replicaState)
+    fetcher.addPartitions(Map(partition -> initialFetchState(topicIds.get(partition.topic), 0L, leaderEpoch = 1)))
+
+    val leaderState = MockFetcherThread.PartitionState(Seq(
+      mkBatch(baseOffset = 0L, leaderEpoch = 0, new SimpleRecord("a".getBytes)),
+      mkBatch(baseOffset = 1L, leaderEpoch = 0, new SimpleRecord("b".getBytes))
+    ), leaderEpoch = 1, highWatermark = 2L)
+    fetcher.setLeaderState(partition, leaderState)
+
+    fetcher.doWork()
+
+    // We have fetched one batch and gotten out of the truncation phase
+    assertEquals(1, replicaState.logEndOffset)
+    assertEquals(Some(Fetching), fetcher.fetchState(partition).map(_.state))
+
+    // Somehow the leader epoch rewinds
+    fetcher.leaderPartitionState(partition).leaderEpoch = 0
+
+    // We are stuck at the current offset
+    fetcher.doWork()
+    assertEquals(1, replicaState.logEndOffset)
+    assertEquals(Some(Fetching), fetcher.fetchState(partition).map(_.state))
+
+    // After returning to the right epoch, we can continue fetching
+    fetcher.leaderPartitionState(partition).leaderEpoch = 1
+    fetcher.doWork()
+    assertEquals(2, replicaState.logEndOffset)
+    assertEquals(Some(Fetching), fetcher.fetchState(partition).map(_.state))
+  }
+
+  @Test
+  def testTruncation(): Unit = {
+    val partition = new TopicPartition("topic", 0)
+    val fetcher = new MockFetcherThread
+
+    val replicaLog = Seq(
+      mkBatch(baseOffset = 0, leaderEpoch = 0, new SimpleRecord("a".getBytes)),
+      mkBatch(baseOffset = 1, leaderEpoch = 2, new SimpleRecord("b".getBytes)),
+      mkBatch(baseOffset = 2, leaderEpoch = 4, new SimpleRecord("c".getBytes)))
+
+    val replicaState = MockFetcherThread.PartitionState(replicaLog, leaderEpoch = 5, highWatermark = 0L)
+    fetcher.setReplicaState(partition, replicaState)
+    fetcher.addPartitions(Map(partition -> initialFetchState(topicIds.get(partition.topic), 3L, leaderEpoch = 5)))
+
+    val leaderLog = Seq(
+      mkBatch(baseOffset = 0, leaderEpoch = 1, new SimpleRecord("a".getBytes)),
+      mkBatch(baseOffset = 1, leaderEpoch = 3, new SimpleRecord("b".getBytes)),
+      mkBatch(baseOffset = 2, leaderEpoch = 5, new SimpleRecord("c".getBytes)))
+
+    val leaderState = MockFetcherThread.PartitionState(leaderLog, leaderEpoch = 5, highWatermark = 2L)
+    fetcher.setLeaderState(partition, leaderState)
+
+    TestUtils.waitUntilTrue(() => {
+      fetcher.doWork()
+      fetcher.replicaPartitionState(partition).log == fetcher.leaderPartitionState(partition).log
+    }, "Failed to reconcile leader and follower logs")
+
+    assertEquals(leaderState.logStartOffset, replicaState.logStartOffset)
+    assertEquals(leaderState.logEndOffset, replicaState.logEndOffset)
+    assertEquals(leaderState.highWatermark, replicaState.highWatermark)
+  }
+
+  @Test
+  def testTruncateToHighWatermarkIfLeaderEpochRequestNotSupported(): Unit = {
+    val highWatermark = 2L
+    val partition = new TopicPartition("topic", 0)
+    val fetcher = new MockFetcherThread {
+      override def truncate(topicPartition: TopicPartition, truncationState: OffsetTruncationState): Unit = {
+        assertEquals(highWatermark, truncationState.offset)
+        assertTrue(truncationState.truncationCompleted)
+        super.truncate(topicPartition, truncationState)
+      }
+
+      override def fetchEpochEndOffsets(partitions:
Map[TopicPartition, EpochData]): Map[TopicPartition, EpochEndOffset] = + throw new UnsupportedOperationException + + override protected val isOffsetForLeaderEpochSupported: Boolean = false + + override protected val isTruncationOnFetchSupported: Boolean = false + } + + val replicaLog = Seq( + mkBatch(baseOffset = 0, leaderEpoch = 0, new SimpleRecord("a".getBytes)), + mkBatch(baseOffset = 1, leaderEpoch = 2, new SimpleRecord("b".getBytes)), + mkBatch(baseOffset = 2, leaderEpoch = 4, new SimpleRecord("c".getBytes))) + + val replicaState = MockFetcherThread.PartitionState(replicaLog, leaderEpoch = 5, highWatermark) + fetcher.setReplicaState(partition, replicaState) + fetcher.addPartitions(Map(partition -> initialFetchState(topicIds.get(partition.topic), highWatermark, leaderEpoch = 5))) + + fetcher.doWork() + + assertEquals(highWatermark, replicaState.logEndOffset) + assertEquals(highWatermark, fetcher.fetchState(partition).get.fetchOffset) + assertTrue(fetcher.fetchState(partition).get.isReadyForFetch) + } + + @Test + def testTruncateToHighWatermarkIfLeaderEpochInfoNotAvailable(): Unit = { + val highWatermark = 2L + val partition = new TopicPartition("topic", 0) + val fetcher = new MockFetcherThread { + override def truncate(topicPartition: TopicPartition, truncationState: OffsetTruncationState): Unit = { + assertEquals(highWatermark, truncationState.offset) + assertTrue(truncationState.truncationCompleted) + super.truncate(topicPartition, truncationState) + } + + override def fetchEpochEndOffsets(partitions: Map[TopicPartition, EpochData]): Map[TopicPartition, EpochEndOffset] = + throw new UnsupportedOperationException + + override def latestEpoch(topicPartition: TopicPartition): Option[Int] = None + } + + val replicaLog = Seq( + mkBatch(baseOffset = 0, leaderEpoch = 0, new SimpleRecord("a".getBytes)), + mkBatch(baseOffset = 1, leaderEpoch = 2, new SimpleRecord("b".getBytes)), + mkBatch(baseOffset = 2, leaderEpoch = 4, new SimpleRecord("c".getBytes))) + + val replicaState = MockFetcherThread.PartitionState(replicaLog, leaderEpoch = 5, highWatermark) + fetcher.setReplicaState(partition, replicaState) + fetcher.addPartitions(Map(partition -> initialFetchState(topicIds.get(partition.topic), highWatermark, leaderEpoch = 5))) + + fetcher.doWork() + + assertEquals(highWatermark, replicaState.logEndOffset) + assertEquals(highWatermark, fetcher.fetchState(partition).get.fetchOffset) + assertTrue(fetcher.fetchState(partition).get.isReadyForFetch) + } + + @Test + def testTruncateToHighWatermarkDuringRemovePartitions(): Unit = { + val highWatermark = 2L + val partition = new TopicPartition("topic", 0) + val fetcher = new MockFetcherThread { + override def truncateToHighWatermark(partitions: Set[TopicPartition]): Unit = { + removePartitions(Set(partition)) + super.truncateToHighWatermark(partitions) + } + + override def latestEpoch(topicPartition: TopicPartition): Option[Int] = None + } + + val replicaLog = Seq( + mkBatch(baseOffset = 0, leaderEpoch = 0, new SimpleRecord("a".getBytes)), + mkBatch(baseOffset = 1, leaderEpoch = 2, new SimpleRecord("b".getBytes)), + mkBatch(baseOffset = 2, leaderEpoch = 4, new SimpleRecord("c".getBytes))) + + val replicaState = MockFetcherThread.PartitionState(replicaLog, leaderEpoch = 5, highWatermark) + fetcher.setReplicaState(partition, replicaState) + fetcher.addPartitions(Map(partition -> initialFetchState(topicIds.get(partition.topic), highWatermark, leaderEpoch = 5))) + + fetcher.doWork() + + assertEquals(replicaLog.last.nextOffset(), replicaState.logEndOffset) + 
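+    // The partition was removed inside the overridden truncateToHighWatermark, so no truncation happened and the fetcher no longer tracks it.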
assertTrue(fetcher.fetchState(partition).isEmpty) + } + + @Test + def testTruncationSkippedIfNoEpochChange(): Unit = { + val partition = new TopicPartition("topic", 0) + + var truncations = 0 + val fetcher = new MockFetcherThread { + override def truncate(topicPartition: TopicPartition, truncationState: OffsetTruncationState): Unit = { + truncations += 1 + super.truncate(topicPartition, truncationState) + } + } + + val replicaState = MockFetcherThread.PartitionState(leaderEpoch = 5) + fetcher.setReplicaState(partition, replicaState) + fetcher.addPartitions(Map(partition -> initialFetchState(topicIds.get(partition.topic), 0L, leaderEpoch = 5)), forceTruncation = true) + + val leaderLog = Seq( + mkBatch(baseOffset = 0, leaderEpoch = 1, new SimpleRecord("a".getBytes)), + mkBatch(baseOffset = 1, leaderEpoch = 3, new SimpleRecord("b".getBytes)), + mkBatch(baseOffset = 2, leaderEpoch = 5, new SimpleRecord("c".getBytes))) + + val leaderState = MockFetcherThread.PartitionState(leaderLog, leaderEpoch = 5, highWatermark = 2L) + fetcher.setLeaderState(partition, leaderState) + + // Do one round of truncation + fetcher.doWork() + + // We only fetch one record at a time with mock fetcher + assertEquals(1, replicaState.logEndOffset) + assertEquals(1, truncations) + + // Add partitions again with the same epoch + fetcher.addPartitions(Map(partition -> initialFetchState(topicIds.get(partition.topic), 3L, leaderEpoch = 5))) + + // Verify we did not truncate + fetcher.doWork() + + // No truncations occurred and we have fetched another record + assertEquals(1, truncations) + assertEquals(2, replicaState.logEndOffset) + } + + @Test + def testTruncationOnFetchSkippedIfPartitionRemoved(): Unit = { + assumeTrue(truncateOnFetch) + val partition = new TopicPartition("topic", 0) + var truncations = 0 + val fetcher = new MockFetcherThread { + override def truncate(topicPartition: TopicPartition, truncationState: OffsetTruncationState): Unit = { + truncations += 1 + super.truncate(topicPartition, truncationState) + } + } + val replicaLog = Seq( + mkBatch(baseOffset = 0, leaderEpoch = 0, new SimpleRecord("a".getBytes)), + mkBatch(baseOffset = 1, leaderEpoch = 2, new SimpleRecord("b".getBytes)), + mkBatch(baseOffset = 2, leaderEpoch = 4, new SimpleRecord("c".getBytes))) + + val replicaState = MockFetcherThread.PartitionState(replicaLog, leaderEpoch = 5, highWatermark = 2L) + fetcher.setReplicaState(partition, replicaState) + + // Verify that truncation based on fetch response is performed if partition is owned by fetcher thread + fetcher.addPartitions(Map(partition -> initialFetchState(Some(Uuid.randomUuid()), 6L, leaderEpoch = 4))) + val endOffset = new EpochEndOffset() + .setPartition(partition.partition) + .setErrorCode(Errors.NONE.code) + .setLeaderEpoch(4) + .setEndOffset(3L) + fetcher.truncateOnFetchResponse(Map(partition -> endOffset)) + assertEquals(1, truncations) + + // Verify that truncation based on fetch response is not performed if partition is removed from fetcher thread + val offsets = fetcher.removePartitions(Set(partition)) + assertEquals(Set(partition), offsets.keySet) + assertEquals(3L, offsets(partition).fetchOffset) + val newEndOffset = new EpochEndOffset() + .setPartition(partition.partition) + .setErrorCode(Errors.NONE.code) + .setLeaderEpoch(4) + .setEndOffset(2L) + fetcher.truncateOnFetchResponse(Map(partition -> newEndOffset)) + assertEquals(1, truncations) + } + + @Test + def testFollowerFetchOutOfRangeHigh(): Unit = { + val partition = new TopicPartition("topic", 0) + val fetcher = new 
MockFetcherThread() + + val replicaLog = Seq( + mkBatch(baseOffset = 0, leaderEpoch = 0, new SimpleRecord("a".getBytes)), + mkBatch(baseOffset = 1, leaderEpoch = 2, new SimpleRecord("b".getBytes)), + mkBatch(baseOffset = 2, leaderEpoch = 4, new SimpleRecord("c".getBytes))) + + val replicaState = MockFetcherThread.PartitionState(replicaLog, leaderEpoch = 4, highWatermark = 0L) + fetcher.setReplicaState(partition, replicaState) + fetcher.addPartitions(Map(partition -> initialFetchState(topicIds.get(partition.topic), 3L, leaderEpoch = 4))) + + val leaderLog = Seq( + mkBatch(baseOffset = 0, leaderEpoch = 0, new SimpleRecord("a".getBytes)), + mkBatch(baseOffset = 1, leaderEpoch = 2, new SimpleRecord("b".getBytes)), + mkBatch(baseOffset = 2, leaderEpoch = 4, new SimpleRecord("c".getBytes))) + + val leaderState = MockFetcherThread.PartitionState(leaderLog, leaderEpoch = 4, highWatermark = 2L) + fetcher.setLeaderState(partition, leaderState) + + // initial truncation and verify that the log end offset is updated + fetcher.doWork() + assertEquals(3L, replicaState.logEndOffset) + assertEquals(Option(Fetching), fetcher.fetchState(partition).map(_.state)) + + // To hit this case, we have to change the leader log without going through the truncation phase + leaderState.log.clear() + leaderState.logEndOffset = 0L + leaderState.logStartOffset = 0L + leaderState.highWatermark = 0L + + fetcher.doWork() + + assertEquals(0L, replicaState.logEndOffset) + assertEquals(0L, replicaState.logStartOffset) + assertEquals(0L, replicaState.highWatermark) + } + + @Test + def testFencedOffsetResetAfterOutOfRange(): Unit = { + val partition = new TopicPartition("topic", 0) + var fetchedEarliestOffset = false + val fetcher = new MockFetcherThread() { + override protected def fetchEarliestOffsetFromLeader(topicPartition: TopicPartition, leaderEpoch: Int): Long = { + fetchedEarliestOffset = true + throw new FencedLeaderEpochException(s"Epoch $leaderEpoch is fenced") + } + } + + val replicaLog = Seq() + val replicaState = MockFetcherThread.PartitionState(replicaLog, leaderEpoch = 4, highWatermark = 0L) + fetcher.setReplicaState(partition, replicaState) + fetcher.addPartitions(Map(partition -> initialFetchState(topicIds.get(partition.topic), 0L, leaderEpoch = 4))) + + val leaderLog = Seq( + mkBatch(baseOffset = 1, leaderEpoch = 2, new SimpleRecord("b".getBytes)), + mkBatch(baseOffset = 2, leaderEpoch = 4, new SimpleRecord("c".getBytes))) + val leaderState = MockFetcherThread.PartitionState(leaderLog, leaderEpoch = 4, highWatermark = 2L) + fetcher.setLeaderState(partition, leaderState) + + // After the out of range error, we get a fenced error and remove the partition and mark as failed + fetcher.doWork() + assertEquals(0, replicaState.logEndOffset) + assertTrue(fetchedEarliestOffset) + assertTrue(fetcher.fetchState(partition).isEmpty) + assertTrue(failedPartitions.contains(partition)) + } + + @Test + def testFollowerFetchOutOfRangeLow(): Unit = { + val partition = new TopicPartition("topic", 0) + val fetcher = new MockFetcherThread + + // The follower begins from an offset which is behind the leader's log start offset + val replicaLog = Seq( + mkBatch(baseOffset = 0, leaderEpoch = 0, new SimpleRecord("a".getBytes))) + + val replicaState = MockFetcherThread.PartitionState(replicaLog, leaderEpoch = 0, highWatermark = 0L) + fetcher.setReplicaState(partition, replicaState) + fetcher.addPartitions(Map(partition -> initialFetchState(topicIds.get(partition.topic), 3L, leaderEpoch = 0))) + + val leaderLog = Seq( + mkBatch(baseOffset 
= 2, leaderEpoch = 4, new SimpleRecord("c".getBytes)))
+
+    val leaderState = MockFetcherThread.PartitionState(leaderLog, leaderEpoch = 0, highWatermark = 2L)
+    fetcher.setLeaderState(partition, leaderState)
+
+    // initial truncation and verify that the log start offset is updated
+    fetcher.doWork()
+    if (truncateOnFetch) {
+      // Second iteration required here since first iteration is required to
+      // perform initial truncation based on diverging epoch.
+      fetcher.doWork()
+    }
+    assertEquals(Option(Fetching), fetcher.fetchState(partition).map(_.state))
+    assertEquals(2, replicaState.logStartOffset)
+    assertEquals(List(), replicaState.log.toList)
+
+    TestUtils.waitUntilTrue(() => {
+      fetcher.doWork()
+      fetcher.replicaPartitionState(partition).log == fetcher.leaderPartitionState(partition).log
+    }, "Failed to reconcile leader and follower logs")
+
+    assertEquals(leaderState.logStartOffset, replicaState.logStartOffset)
+    assertEquals(leaderState.logEndOffset, replicaState.logEndOffset)
+    assertEquals(leaderState.highWatermark, replicaState.highWatermark)
+  }
+
+  @Test
+  def testRetryAfterUnknownLeaderEpochInLatestOffsetFetch(): Unit = {
+    val partition = new TopicPartition("topic", 0)
+    val fetcher: MockFetcherThread = new MockFetcherThread {
+      val tries = new AtomicInteger(0)
+      override protected def fetchLatestOffsetFromLeader(topicPartition: TopicPartition, leaderEpoch: Int): Long = {
+        if (tries.getAndIncrement() == 0)
+          throw new UnknownLeaderEpochException("Unexpected leader epoch")
+        super.fetchLatestOffsetFromLeader(topicPartition, leaderEpoch)
+      }
+    }
+
+    // The follower begins from an offset which is behind the leader's log start offset
+    val replicaLog = Seq(
+      mkBatch(baseOffset = 0, leaderEpoch = 0, new SimpleRecord("a".getBytes)))
+
+    val replicaState = MockFetcherThread.PartitionState(replicaLog, leaderEpoch = 0, highWatermark = 0L)
+    fetcher.setReplicaState(partition, replicaState)
+    fetcher.addPartitions(Map(partition -> initialFetchState(topicIds.get(partition.topic), 3L, leaderEpoch = 0)))
+
+    val leaderLog = Seq(
+      mkBatch(baseOffset = 2, leaderEpoch = 4, new SimpleRecord("c".getBytes)))
+
+    val leaderState = MockFetcherThread.PartitionState(leaderLog, leaderEpoch = 0, highWatermark = 2L)
+    fetcher.setLeaderState(partition, leaderState)
+
+    // initial truncation and initial error response handling
+    fetcher.doWork()
+    assertEquals(Option(Fetching), fetcher.fetchState(partition).map(_.state))
+
+    TestUtils.waitUntilTrue(() => {
+      fetcher.doWork()
+      fetcher.replicaPartitionState(partition).log == fetcher.leaderPartitionState(partition).log
+    }, "Failed to reconcile leader and follower logs")
+
+    assertEquals(leaderState.logStartOffset, replicaState.logStartOffset)
+    assertEquals(leaderState.logEndOffset, replicaState.logEndOffset)
+    assertEquals(leaderState.highWatermark, replicaState.highWatermark)
+  }
+
+  @Test
+  def testCorruptMessage(): Unit = {
+    val partition = new TopicPartition("topic", 0)
+
+    val fetcher = new MockFetcherThread {
+      var fetchedOnce = false
+      override def fetchFromLeader(fetchRequest: FetchRequest.Builder): Map[TopicPartition, FetchData] = {
+        val fetchedData = super.fetchFromLeader(fetchRequest)
+        if (!fetchedOnce) {
+          val records = fetchedData.head._2.records.asInstanceOf[MemoryRecords]
+          val buffer = records.buffer()
+          buffer.putInt(15, buffer.getInt(15) ^ 23422)
+          buffer.putInt(30, buffer.getInt(30) ^ 93242)
+          fetchedOnce = true
+        }
+        fetchedData
+      }
+    }
+
+    fetcher.setReplicaState(partition, MockFetcherThread.PartitionState(leaderEpoch = 0))
+    fetcher.addPartitions(Map(partition -> initialFetchState(topicIds.get(partition.topic), 0L, leaderEpoch = 0)))
+
+    val batch = mkBatch(baseOffset = 0L, leaderEpoch = 0,
+      new SimpleRecord("a".getBytes), new SimpleRecord("b".getBytes))
+    val leaderState = MockFetcherThread.PartitionState(Seq(batch), leaderEpoch = 0, highWatermark = 2L)
+    fetcher.setLeaderState(partition, leaderState)
+
+    fetcher.doWork() // fails with corrupt record
+    fetcher.doWork() // should succeed
+
+    val replicaState = fetcher.replicaPartitionState(partition)
+    assertEquals(2L, replicaState.logEndOffset)
+  }
+
+  @Test
+  def testLeaderEpochChangeDuringFencedFetchEpochsFromLeader(): Unit = {
+    // The leader is on the new epoch when the OffsetsForLeaderEpoch with the old epoch is sent, so it
+    // returns the fence error. Validate that the response is ignored if the leader epoch changes on
+    // the follower while the OffsetsForLeaderEpoch request is in flight, but that we are still able to
+    // truncate and fetch in the next round of "doWork"
+    testLeaderEpochChangeDuringFetchEpochsFromLeader(leaderEpochOnLeader = 1)
+  }
+
+  @Test
+  def testLeaderEpochChangeDuringSuccessfulFetchEpochsFromLeader(): Unit = {
+    // The leader is on the old epoch when the OffsetsForLeaderEpoch with the old epoch is sent
+    // and returns a valid response. Validate that the response is ignored if the leader epoch changes
+    // on the follower while the OffsetsForLeaderEpoch request is in flight, but that we are still able to
+    // truncate and fetch once the leader is on the newer epoch (same as follower)
+    testLeaderEpochChangeDuringFetchEpochsFromLeader(leaderEpochOnLeader = 0)
+  }
+
+  private def testLeaderEpochChangeDuringFetchEpochsFromLeader(leaderEpochOnLeader: Int): Unit = {
+    val partition = new TopicPartition("topic", 1)
+    val initialLeaderEpochOnFollower = 0
+    val nextLeaderEpochOnFollower = initialLeaderEpochOnFollower + 1
+
+    val fetcher = new MockFetcherThread {
+      var fetchEpochsFromLeaderOnce = false
+      override def fetchEpochEndOffsets(partitions: Map[TopicPartition, EpochData]): Map[TopicPartition, EpochEndOffset] = {
+        val fetchedEpochs = super.fetchEpochEndOffsets(partitions)
+        if (!fetchEpochsFromLeaderOnce) {
+          // leader epoch changes while fetching epochs from leader
+          removePartitions(Set(partition))
+          setReplicaState(partition, MockFetcherThread.PartitionState(leaderEpoch = nextLeaderEpochOnFollower))
+          addPartitions(Map(partition -> initialFetchState(topicIds.get(partition.topic), 0L, leaderEpoch = nextLeaderEpochOnFollower)), forceTruncation = true)
+          fetchEpochsFromLeaderOnce = true
+        }
+        fetchedEpochs
+      }
+    }
+
+    fetcher.setReplicaState(partition, MockFetcherThread.PartitionState(leaderEpoch = initialLeaderEpochOnFollower))
+    fetcher.addPartitions(Map(partition -> initialFetchState(topicIds.get(partition.topic), 0L, leaderEpoch = initialLeaderEpochOnFollower)), forceTruncation = true)
+
+    val leaderLog = Seq(
+      mkBatch(baseOffset = 0, leaderEpoch = initialLeaderEpochOnFollower, new SimpleRecord("c".getBytes)))
+    val leaderState = MockFetcherThread.PartitionState(leaderLog, leaderEpochOnLeader, highWatermark = 0L)
+    fetcher.setLeaderState(partition, leaderState)
+
+    // first round of truncation
+    fetcher.doWork()
+
+    // Since leader epoch changed, fetch epochs response is ignored due to partition being in
+    // truncating state with the updated leader epoch
+    assertEquals(Option(Truncating), fetcher.fetchState(partition).map(_.state))
+    assertEquals(Option(nextLeaderEpochOnFollower), fetcher.fetchState(partition).map(_.currentLeaderEpoch))
+
+    if (leaderEpochOnLeader <
nextLeaderEpochOnFollower) { + fetcher.setLeaderState( + partition, MockFetcherThread.PartitionState(leaderLog, nextLeaderEpochOnFollower, highWatermark = 0L)) + } + + // make sure the fetcher is now able to truncate and fetch + fetcher.doWork() + assertEquals(fetcher.leaderPartitionState(partition).log, fetcher.replicaPartitionState(partition).log) + } + + @Test + def testTruncateToEpochEndOffsetsDuringRemovePartitions(): Unit = { + val partition = new TopicPartition("topic", 0) + val leaderEpochOnLeader = 0 + val initialLeaderEpochOnFollower = 0 + val nextLeaderEpochOnFollower = initialLeaderEpochOnFollower + 1 + + val fetcher = new MockFetcherThread { + override def fetchEpochEndOffsets(partitions: Map[TopicPartition, EpochData]): Map[TopicPartition, EpochEndOffset] = { + val fetchedEpochs = super.fetchEpochEndOffsets(partitions) + // leader epoch changes while fetching epochs from leader + // at the same time, the replica fetcher manager removes the partition + removePartitions(Set(partition)) + setReplicaState(partition, MockFetcherThread.PartitionState(leaderEpoch = nextLeaderEpochOnFollower)) + fetchedEpochs + } + } + + fetcher.setReplicaState(partition, MockFetcherThread.PartitionState(leaderEpoch = initialLeaderEpochOnFollower)) + fetcher.addPartitions(Map(partition -> initialFetchState(topicIds.get(partition.topic), 0L, leaderEpoch = initialLeaderEpochOnFollower))) + + val leaderLog = Seq( + mkBatch(baseOffset = 0, leaderEpoch = initialLeaderEpochOnFollower, new SimpleRecord("c".getBytes))) + val leaderState = MockFetcherThread.PartitionState(leaderLog, leaderEpochOnLeader, highWatermark = 0L) + fetcher.setLeaderState(partition, leaderState) + + // first round of work + fetcher.doWork() + + // since the partition was removed before the fetched endOffsets were filtered against the leader epoch, + // we do not expect the partition to be in Truncating state + assertEquals(None, fetcher.fetchState(partition).map(_.state)) + assertEquals(None, fetcher.fetchState(partition).map(_.currentLeaderEpoch)) + + fetcher.setLeaderState( + partition, MockFetcherThread.PartitionState(leaderLog, nextLeaderEpochOnFollower, highWatermark = 0L)) + + // make sure the fetcher is able to continue work + fetcher.doWork() + assertEquals(ArrayBuffer.empty, fetcher.replicaPartitionState(partition).log) + } + + @Test + def testTruncationThrowsExceptionIfLeaderReturnsPartitionsNotRequestedInFetchEpochs(): Unit = { + val partition = new TopicPartition("topic", 0) + val fetcher = new MockFetcherThread { + override def fetchEpochEndOffsets(partitions: Map[TopicPartition, EpochData]): Map[TopicPartition, EpochEndOffset] = { + val unrequestedTp = new TopicPartition("topic2", 0) + super.fetchEpochEndOffsets(partitions).toMap + (unrequestedTp -> new EpochEndOffset() + .setPartition(unrequestedTp.partition) + .setErrorCode(Errors.NONE.code) + .setLeaderEpoch(0) + .setEndOffset(0)) + } + } + + fetcher.setReplicaState(partition, MockFetcherThread.PartitionState(leaderEpoch = 0)) + fetcher.addPartitions(Map(partition -> initialFetchState(topicIds.get(partition.topic), 0L, leaderEpoch = 0)), forceTruncation = true) + fetcher.setLeaderState(partition, MockFetcherThread.PartitionState(leaderEpoch = 0)) + + // first round of truncation should throw an exception + assertThrows(classOf[IllegalStateException], () => fetcher.doWork()) + } + + @Test + def testFetcherThreadHandlingPartitionFailureDuringAppending(): Unit = { + val fetcherForAppend = new MockFetcherThread { + override def processPartitionData(topicPartition: 
TopicPartition, fetchOffset: Long, partitionData: FetchData): Option[LogAppendInfo] = { + if (topicPartition == partition1) { + throw new KafkaException() + } else { + super.processPartitionData(topicPartition, fetchOffset, partitionData) + } + } + } + verifyFetcherThreadHandlingPartitionFailure(fetcherForAppend) + } + + @Test + def testFetcherThreadHandlingPartitionFailureDuringTruncation(): Unit = { + val fetcherForTruncation = new MockFetcherThread { + override def truncate(topicPartition: TopicPartition, truncationState: OffsetTruncationState): Unit = { + if(topicPartition == partition1) + throw new Exception() + else { + super.truncate(topicPartition: TopicPartition, truncationState: OffsetTruncationState) + } + } + } + verifyFetcherThreadHandlingPartitionFailure(fetcherForTruncation) + } + + private def verifyFetcherThreadHandlingPartitionFailure(fetcher: MockFetcherThread): Unit = { + + fetcher.setReplicaState(partition1, MockFetcherThread.PartitionState(leaderEpoch = 0)) + fetcher.addPartitions(Map(partition1 -> initialFetchState(topicIds.get(partition1.topic), 0L, leaderEpoch = 0)), forceTruncation = true) + fetcher.setLeaderState(partition1, MockFetcherThread.PartitionState(leaderEpoch = 0)) + + fetcher.setReplicaState(partition2, MockFetcherThread.PartitionState(leaderEpoch = 0)) + fetcher.addPartitions(Map(partition2 -> initialFetchState(topicIds.get(partition2.topic), 0L, leaderEpoch = 0)), forceTruncation = true) + fetcher.setLeaderState(partition2, MockFetcherThread.PartitionState(leaderEpoch = 0)) + + // processing data fails for partition1 + fetcher.doWork() + + // partition1 marked as failed + assertTrue(failedPartitions.contains(partition1)) + assertEquals(None, fetcher.fetchState(partition1)) + + // make sure the fetcher continues to work with rest of the partitions + fetcher.doWork() + assertEquals(Some(Fetching), fetcher.fetchState(partition2).map(_.state)) + assertFalse(failedPartitions.contains(partition2)) + + // simulate a leader change + fetcher.removePartitions(Set(partition1)) + failedPartitions.removeAll(Set(partition1)) + fetcher.addPartitions(Map(partition1 -> initialFetchState(topicIds.get(partition1.topic), 0L, leaderEpoch = 1)), forceTruncation = true) + + // partition1 added back + assertEquals(Some(Truncating), fetcher.fetchState(partition1).map(_.state)) + assertFalse(failedPartitions.contains(partition1)) + + } + + @Test + def testDivergingEpochs(): Unit = { + val partition = new TopicPartition("topic", 0) + val fetcher = new MockFetcherThread + + val replicaLog = Seq( + mkBatch(baseOffset = 0, leaderEpoch = 0, new SimpleRecord("a".getBytes)), + mkBatch(baseOffset = 1, leaderEpoch = 2, new SimpleRecord("b".getBytes)), + mkBatch(baseOffset = 2, leaderEpoch = 4, new SimpleRecord("c".getBytes))) + + val replicaState = MockFetcherThread.PartitionState(replicaLog, leaderEpoch = 5, highWatermark = 0L) + fetcher.setReplicaState(partition, replicaState) + fetcher.addPartitions(Map(partition -> initialFetchState(topicIds.get(partition.topic), 3L, leaderEpoch = 5))) + assertEquals(3L, replicaState.logEndOffset) + fetcher.verifyLastFetchedEpoch(partition, expectedEpoch = Some(4)) + + val leaderLog = Seq( + mkBatch(baseOffset = 0, leaderEpoch = 0, new SimpleRecord("a".getBytes)), + mkBatch(baseOffset = 1, leaderEpoch = 2, new SimpleRecord("b".getBytes)), + mkBatch(baseOffset = 2, leaderEpoch = 5, new SimpleRecord("d".getBytes))) + + val leaderState = MockFetcherThread.PartitionState(leaderLog, leaderEpoch = 5, highWatermark = 2L) + 
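+    // The leader's batch at offset 2 has epoch 5 while the replica's has epoch 4, so the next fetch should report a diverging epoch and trigger truncation.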
fetcher.setLeaderState(partition, leaderState) + + fetcher.doWork() + fetcher.verifyLastFetchedEpoch(partition, Some(2)) + + TestUtils.waitUntilTrue(() => { + fetcher.doWork() + fetcher.replicaPartitionState(partition).log == fetcher.leaderPartitionState(partition).log + }, "Failed to reconcile leader and follower logs") + fetcher.verifyLastFetchedEpoch(partition, Some(5)) + } + + @Test + def testMaybeUpdateTopicIds(): Unit = { + val partition = new TopicPartition("topic1", 0) + val fetcher = new MockFetcherThread + + // Start with no topic IDs + fetcher.setReplicaState(partition, MockFetcherThread.PartitionState(leaderEpoch = 0)) + fetcher.addPartitions(Map(partition -> initialFetchState(None, 0L, leaderEpoch = 0))) + + def verifyFetchState(fetchState: Option[PartitionFetchState], expectedTopicId: Option[Uuid]): Unit = { + assertTrue(fetchState.isDefined) + assertEquals(expectedTopicId, fetchState.get.topicId) + } + + verifyFetchState(fetcher.fetchState(partition), None) + + // Add topic ID + fetcher.maybeUpdateTopicIds(Set(partition), topicName => topicIds.get(topicName)) + verifyFetchState(fetcher.fetchState(partition), topicIds.get(partition.topic)) + + // Try to update topic ID for non-existent topic partition + val unknownPartition = new TopicPartition("unknown", 0) + fetcher.maybeUpdateTopicIds(Set(unknownPartition), topicName => topicIds.get(topicName)) + assertTrue(fetcher.fetchState(unknownPartition).isEmpty) + } + + object MockFetcherThread { + class PartitionState(var log: mutable.Buffer[RecordBatch], + var leaderEpoch: Int, + var logStartOffset: Long, + var logEndOffset: Long, + var highWatermark: Long) + + object PartitionState { + def apply(log: Seq[RecordBatch], leaderEpoch: Int, highWatermark: Long): PartitionState = { + val logStartOffset = log.headOption.map(_.baseOffset).getOrElse(0L) + val logEndOffset = log.lastOption.map(_.nextOffset).getOrElse(0L) + new PartitionState(log.toBuffer, leaderEpoch, logStartOffset, logEndOffset, highWatermark) + } + + def apply(leaderEpoch: Int): PartitionState = { + apply(Seq(), leaderEpoch = leaderEpoch, highWatermark = 0L) + } + } + } + + class MockFetcherThread(val replicaId: Int = 0, val leaderId: Int = 1, fetchBackOffMs: Int = 0) + extends AbstractFetcherThread("mock-fetcher", + clientId = "mock-fetcher", + sourceBroker = new BrokerEndPoint(leaderId, host = "localhost", port = Random.nextInt()), + failedPartitions, + fetchBackOffMs = fetchBackOffMs, + brokerTopicStats = new BrokerTopicStats) { + + import MockFetcherThread.PartitionState + + private val replicaPartitionStates = mutable.Map[TopicPartition, PartitionState]() + private val leaderPartitionStates = mutable.Map[TopicPartition, PartitionState]() + private var latestEpochDefault: Option[Int] = Some(0) + + def setLeaderState(topicPartition: TopicPartition, state: PartitionState): Unit = { + leaderPartitionStates.put(topicPartition, state) + } + + def setReplicaState(topicPartition: TopicPartition, state: PartitionState): Unit = { + replicaPartitionStates.put(topicPartition, state) + } + + def replicaPartitionState(topicPartition: TopicPartition): PartitionState = { + replicaPartitionStates.getOrElse(topicPartition, + throw new IllegalArgumentException(s"Unknown partition $topicPartition")) + } + + def leaderPartitionState(topicPartition: TopicPartition): PartitionState = { + leaderPartitionStates.getOrElse(topicPartition, + throw new IllegalArgumentException(s"Unknown partition $topicPartition")) + } + + def addPartitions(initialFetchStates: Map[TopicPartition, 
InitialFetchState], forceTruncation: Boolean): Set[TopicPartition] = { + latestEpochDefault = if (forceTruncation) None else Some(0) + val partitions = super.addPartitions(initialFetchStates) + latestEpochDefault = Some(0) + partitions + } + + override def processPartitionData(topicPartition: TopicPartition, + fetchOffset: Long, + partitionData: FetchData): Option[LogAppendInfo] = { + val state = replicaPartitionState(topicPartition) + + if (isTruncationOnFetchSupported && FetchResponse.isDivergingEpoch(partitionData)) { + val divergingEpoch = partitionData.divergingEpoch + truncateOnFetchResponse(Map(topicPartition -> new EpochEndOffset() + .setPartition(topicPartition.partition) + .setErrorCode(Errors.NONE.code) + .setLeaderEpoch(divergingEpoch.epoch) + .setEndOffset(divergingEpoch.endOffset))) + return None + } + + // Throw exception if the fetchOffset does not match the fetcherThread partition state + if (fetchOffset != state.logEndOffset) + throw new RuntimeException(s"Offset mismatch for partition $topicPartition: " + + s"fetched offset = $fetchOffset, log end offset = ${state.logEndOffset}.") + + // Now check message's crc + val batches = FetchResponse.recordsOrFail(partitionData).batches.asScala + var maxTimestamp = RecordBatch.NO_TIMESTAMP + var offsetOfMaxTimestamp = -1L + var lastOffset = state.logEndOffset + var lastEpoch: Option[Int] = None + + for (batch <- batches) { + batch.ensureValid() + if (batch.maxTimestamp > maxTimestamp) { + maxTimestamp = batch.maxTimestamp + offsetOfMaxTimestamp = batch.baseOffset + } + state.log.append(batch) + state.logEndOffset = batch.nextOffset + lastOffset = batch.lastOffset + lastEpoch = Some(batch.partitionLeaderEpoch) + } + + state.logStartOffset = partitionData.logStartOffset + state.highWatermark = partitionData.highWatermark + + Some(LogAppendInfo(firstOffset = Some(LogOffsetMetadata(fetchOffset)), + lastOffset = lastOffset, + lastLeaderEpoch = lastEpoch, + maxTimestamp = maxTimestamp, + offsetOfMaxTimestamp = offsetOfMaxTimestamp, + logAppendTime = Time.SYSTEM.milliseconds(), + logStartOffset = state.logStartOffset, + recordConversionStats = RecordConversionStats.EMPTY, + sourceCodec = NoCompressionCodec, + targetCodec = NoCompressionCodec, + shallowCount = batches.size, + validBytes = FetchResponse.recordsSize(partitionData), + offsetsMonotonic = true, + lastOffsetOfFirstBatch = batches.headOption.map(_.lastOffset).getOrElse(-1))) + } + + override def truncate(topicPartition: TopicPartition, truncationState: OffsetTruncationState): Unit = { + val state = replicaPartitionState(topicPartition) + state.log = state.log.takeWhile { batch => + batch.lastOffset < truncationState.offset + } + state.logEndOffset = state.log.lastOption.map(_.lastOffset + 1).getOrElse(state.logStartOffset) + state.highWatermark = math.min(state.highWatermark, state.logEndOffset) + } + + override def truncateFullyAndStartAt(topicPartition: TopicPartition, offset: Long): Unit = { + val state = replicaPartitionState(topicPartition) + state.log.clear() + state.logStartOffset = offset + state.logEndOffset = offset + state.highWatermark = offset + } + + override def buildFetch(partitionMap: Map[TopicPartition, PartitionFetchState]): ResultWithPartitions[Option[ReplicaFetch]] = { + val fetchData = mutable.Map.empty[TopicPartition, FetchRequest.PartitionData] + partitionMap.foreach { case (partition, state) => + if (state.isReadyForFetch) { + val replicaState = replicaPartitionState(partition) + val lastFetchedEpoch = if (isTruncationOnFetchSupported) + 
state.lastFetchedEpoch.map(_.asInstanceOf[Integer]).asJava + else + Optional.empty[Integer] + fetchData.put(partition, + new FetchRequest.PartitionData(state.topicId.getOrElse(Uuid.ZERO_UUID), state.fetchOffset, replicaState.logStartOffset, + 1024 * 1024, Optional.of[Integer](state.currentLeaderEpoch), lastFetchedEpoch)) + } + } + val fetchRequest = FetchRequest.Builder.forReplica(version, replicaId, 0, 1, fetchData.asJava) + val fetchRequestOpt = + if (fetchData.isEmpty) + None + else + Some(ReplicaFetch(fetchData.asJava, fetchRequest)) + ResultWithPartitions(fetchRequestOpt, Set.empty) + } + + override def latestEpoch(topicPartition: TopicPartition): Option[Int] = { + val state = replicaPartitionState(topicPartition) + state.log.lastOption.map(_.partitionLeaderEpoch).orElse(latestEpochDefault) + } + + override def logStartOffset(topicPartition: TopicPartition): Long = replicaPartitionState(topicPartition).logStartOffset + + override def logEndOffset(topicPartition: TopicPartition): Long = replicaPartitionState(topicPartition).logEndOffset + + override def endOffsetForEpoch(topicPartition: TopicPartition, epoch: Int): Option[OffsetAndEpoch] = { + val epochData = new EpochData() + .setPartition(topicPartition.partition) + .setLeaderEpoch(epoch) + val result = lookupEndOffsetForEpoch(topicPartition, epochData, replicaPartitionState(topicPartition)) + if (result.endOffset == UNDEFINED_EPOCH_OFFSET) + None + else + Some(OffsetAndEpoch(result.endOffset, result.leaderEpoch)) + } + + private def checkExpectedLeaderEpoch(expectedEpochOpt: Optional[Integer], + partitionState: PartitionState): Option[Errors] = { + if (expectedEpochOpt.isPresent) { + checkExpectedLeaderEpoch(expectedEpochOpt.get, partitionState) + } else { + None + } + } + + private def checkExpectedLeaderEpoch(expectedEpoch: Int, + partitionState: PartitionState): Option[Errors] = { + if (expectedEpoch != RecordBatch.NO_PARTITION_LEADER_EPOCH) { + if (expectedEpoch < partitionState.leaderEpoch) + Some(Errors.FENCED_LEADER_EPOCH) + else if (expectedEpoch > partitionState.leaderEpoch) + Some(Errors.UNKNOWN_LEADER_EPOCH) + else + None + } else { + None + } + } + + def verifyLastFetchedEpoch(partition: TopicPartition, expectedEpoch: Option[Int]): Unit = { + if (isTruncationOnFetchSupported) { + assertEquals(Some(Fetching), fetchState(partition).map(_.state)) + assertEquals(expectedEpoch, fetchState(partition).flatMap(_.lastFetchedEpoch)) + } + } + + private def divergingEpochAndOffset(topicPartition: TopicPartition, + lastFetchedEpoch: Optional[Integer], + fetchOffset: Long, + partitionState: PartitionState): Option[FetchResponseData.EpochEndOffset] = { + lastFetchedEpoch.asScala.flatMap { fetchEpoch => + val epochEndOffset = fetchEpochEndOffsets( + Map(topicPartition -> new EpochData() + .setPartition(topicPartition.partition) + .setLeaderEpoch(fetchEpoch)))(topicPartition) + + if (partitionState.log.isEmpty + || epochEndOffset.endOffset == UNDEFINED_EPOCH_OFFSET + || epochEndOffset.leaderEpoch == UNDEFINED_EPOCH) + None + else if (epochEndOffset.leaderEpoch < fetchEpoch || epochEndOffset.endOffset < fetchOffset) { + Some(new FetchResponseData.EpochEndOffset() + .setEpoch(epochEndOffset.leaderEpoch) + .setEndOffset(epochEndOffset.endOffset)) + } else + None + } + } + + private def lookupEndOffsetForEpoch(topicPartition: TopicPartition, + epochData: EpochData, + partitionState: PartitionState): EpochEndOffset = { + checkExpectedLeaderEpoch(epochData.currentLeaderEpoch, partitionState).foreach { error => + return new EpochEndOffset() + 
.setPartition(topicPartition.partition) + .setErrorCode(error.code) + } + + var epochLowerBound = UNDEFINED_EPOCH + for (batch <- partitionState.log) { + if (batch.partitionLeaderEpoch > epochData.leaderEpoch) { + // If we don't have the requested epoch, return the next higher entry + if (epochLowerBound == UNDEFINED_EPOCH) + return new EpochEndOffset() + .setPartition(topicPartition.partition) + .setErrorCode(Errors.NONE.code) + .setLeaderEpoch(batch.partitionLeaderEpoch) + .setEndOffset(batch.baseOffset) + else + return new EpochEndOffset() + .setPartition(topicPartition.partition) + .setErrorCode(Errors.NONE.code) + .setLeaderEpoch(epochLowerBound) + .setEndOffset(batch.baseOffset) + } + epochLowerBound = batch.partitionLeaderEpoch + } + new EpochEndOffset() + .setPartition(topicPartition.partition) + .setErrorCode(Errors.NONE.code) + } + + override def fetchEpochEndOffsets(partitions: Map[TopicPartition, EpochData]): Map[TopicPartition, EpochEndOffset] = { + val endOffsets = mutable.Map[TopicPartition, EpochEndOffset]() + partitions.forKeyValue { (partition, epochData) => + assert(partition.partition == epochData.partition, + "Partition must be consistent between TopicPartition and EpochData") + val leaderState = leaderPartitionState(partition) + val epochEndOffset = lookupEndOffsetForEpoch(partition, epochData, leaderState) + endOffsets.put(partition, epochEndOffset) + } + endOffsets + } + + override protected val isOffsetForLeaderEpochSupported: Boolean = true + + override protected val isTruncationOnFetchSupported: Boolean = truncateOnFetch + + override def fetchFromLeader(fetchRequest: FetchRequest.Builder): Map[TopicPartition, FetchData] = { + fetchRequest.fetchData.asScala.map { case (partition, fetchData) => + val leaderState = leaderPartitionState(partition) + val epochCheckError = checkExpectedLeaderEpoch(fetchData.currentLeaderEpoch, leaderState) + val divergingEpoch = divergingEpochAndOffset(partition, fetchData.lastFetchedEpoch, fetchData.fetchOffset, leaderState) + + val (error, records) = if (epochCheckError.isDefined) { + (epochCheckError.get, MemoryRecords.EMPTY) + } else if (fetchData.fetchOffset > leaderState.logEndOffset || fetchData.fetchOffset < leaderState.logStartOffset) { + (Errors.OFFSET_OUT_OF_RANGE, MemoryRecords.EMPTY) + } else if (divergingEpoch.nonEmpty) { + (Errors.NONE, MemoryRecords.EMPTY) + } else { + // for simplicity, we fetch only one batch at a time + val records = leaderState.log.find(_.baseOffset >= fetchData.fetchOffset) match { + case Some(batch) => + val buffer = ByteBuffer.allocate(batch.sizeInBytes) + batch.writeTo(buffer) + buffer.flip() + MemoryRecords.readableRecords(buffer) + + case None => + MemoryRecords.EMPTY + } + + (Errors.NONE, records) + } + val partitionData = new FetchData() + .setPartitionIndex(partition.partition) + .setErrorCode(error.code) + .setHighWatermark(leaderState.highWatermark) + .setLastStableOffset(leaderState.highWatermark) + .setLogStartOffset(leaderState.logStartOffset) + .setRecords(records) + divergingEpoch.foreach(partitionData.setDivergingEpoch) + + (partition, partitionData) + }.toMap + } + + private def checkLeaderEpochAndThrow(expectedEpoch: Int, partitionState: PartitionState): Unit = { + checkExpectedLeaderEpoch(expectedEpoch, partitionState).foreach { error => + throw error.exception() + } + } + + override protected def fetchEarliestOffsetFromLeader(topicPartition: TopicPartition, leaderEpoch: Int): Long = { + val leaderState = leaderPartitionState(topicPartition) + 
checkLeaderEpochAndThrow(leaderEpoch, leaderState) + leaderState.logStartOffset + } + + override protected def fetchLatestOffsetFromLeader(topicPartition: TopicPartition, leaderEpoch: Int): Long = { + val leaderState = leaderPartitionState(topicPartition) + checkLeaderEpochAndThrow(leaderEpoch, leaderState) + leaderState.logEndOffset + } + + } + +} diff --git a/core/src/test/scala/unit/kafka/server/AbstractFetcherThreadWithIbp26Test.scala b/core/src/test/scala/unit/kafka/server/AbstractFetcherThreadWithIbp26Test.scala new file mode 100644 index 0000000..f2e04a4 --- /dev/null +++ b/core/src/test/scala/unit/kafka/server/AbstractFetcherThreadWithIbp26Test.scala @@ -0,0 +1,28 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.server + +import org.apache.kafka.common.Uuid + +class AbstractFetcherThreadWithIbp26Test extends AbstractFetcherThreadTest { + + override val truncateOnFetch = false + override val version = 11 + override val topicIds = Map.empty[String, Uuid] + +} diff --git a/core/src/test/scala/unit/kafka/server/AbstractMetadataRequestTest.scala b/core/src/test/scala/unit/kafka/server/AbstractMetadataRequestTest.scala new file mode 100644 index 0000000..b140851 --- /dev/null +++ b/core/src/test/scala/unit/kafka/server/AbstractMetadataRequestTest.scala @@ -0,0 +1,61 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package kafka.server + +import java.util.Properties + +import kafka.network.SocketServer +import kafka.utils.TestUtils +import org.apache.kafka.common.message.MetadataRequestData +import org.apache.kafka.common.protocol.Errors +import org.apache.kafka.common.requests.{MetadataRequest, MetadataResponse} +import org.junit.jupiter.api.Assertions.assertEquals + +abstract class AbstractMetadataRequestTest extends BaseRequestTest { + + override def brokerPropertyOverrides(properties: Properties): Unit = { + properties.setProperty(KafkaConfig.OffsetsTopicPartitionsProp, "1") + properties.setProperty(KafkaConfig.DefaultReplicationFactorProp, "2") + properties.setProperty(KafkaConfig.RackProp, s"rack/${properties.getProperty(KafkaConfig.BrokerIdProp)}") + } + + protected def requestData(topics: List[String], allowAutoTopicCreation: Boolean): MetadataRequestData = { + val data = new MetadataRequestData + if (topics == null) + data.setTopics(null) + else + topics.foreach(topic => + data.topics.add( + new MetadataRequestData.MetadataRequestTopic() + .setName(topic))) + + data.setAllowAutoTopicCreation(allowAutoTopicCreation) + data + } + + protected def sendMetadataRequest(request: MetadataRequest, destination: Option[SocketServer] = None): MetadataResponse = { + connectAndReceive[MetadataResponse](request, destination = destination.getOrElse(anySocketServer)) + } + + protected def checkAutoCreatedTopic(autoCreatedTopic: String, response: MetadataResponse): Unit = { + assertEquals(Errors.LEADER_NOT_AVAILABLE, response.errors.get(autoCreatedTopic)) + assertEquals(Some(servers.head.config.numPartitions), zkClient.getTopicPartitionCount(autoCreatedTopic)) + for (i <- 0 until servers.head.config.numPartitions) + TestUtils.waitForPartitionMetadata(servers, autoCreatedTopic, i) + } +} diff --git a/core/src/test/scala/unit/kafka/server/AddPartitionsToTxnRequestServerTest.scala b/core/src/test/scala/unit/kafka/server/AddPartitionsToTxnRequestServerTest.scala new file mode 100644 index 0000000..0a98d26 --- /dev/null +++ b/core/src/test/scala/unit/kafka/server/AddPartitionsToTxnRequestServerTest.scala @@ -0,0 +1,72 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package kafka.server + +import java.util.Properties + +import org.apache.kafka.common.TopicPartition +import org.apache.kafka.common.protocol.Errors +import org.apache.kafka.common.requests.{AddPartitionsToTxnRequest, AddPartitionsToTxnResponse} +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.{BeforeEach, Test, TestInfo} + +import scala.jdk.CollectionConverters._ + +class AddPartitionsToTxnRequestServerTest extends BaseRequestTest { + private val topic1 = "topic1" + val numPartitions = 1 + + override def brokerPropertyOverrides(properties: Properties): Unit = + properties.put(KafkaConfig.AutoCreateTopicsEnableProp, false.toString) + + @BeforeEach + override def setUp(testInfo: TestInfo): Unit = { + super.setUp(testInfo) + createTopic(topic1, numPartitions, servers.size, new Properties()) + } + + @Test + def shouldReceiveOperationNotAttemptedWhenOtherPartitionHasError(): Unit = { + // The basic idea is that we have one unknown topic and one created topic. We should get the 'UNKNOWN_TOPIC_OR_PARTITION' + // error for the unknown topic and the 'OPERATION_NOT_ATTEMPTED' error for the known and authorized topic. + val nonExistentTopic = new TopicPartition("unknownTopic", 0) + val createdTopicPartition = new TopicPartition(topic1, 0) + + val transactionalId = "foobar" + val producerId = 1000L + val producerEpoch: Short = 0 + + val request = new AddPartitionsToTxnRequest.Builder( + transactionalId, + producerId, + producerEpoch, + List(createdTopicPartition, nonExistentTopic).asJava) + .build() + + val leaderId = servers.head.config.brokerId + val response = connectAndReceive[AddPartitionsToTxnResponse](request, brokerSocketServer(leaderId)) + + assertEquals(2, response.errors.size) + + assertTrue(response.errors.containsKey(createdTopicPartition)) + assertEquals(Errors.OPERATION_NOT_ATTEMPTED, response.errors.get(createdTopicPartition)) + + assertTrue(response.errors.containsKey(nonExistentTopic)) + assertEquals(Errors.UNKNOWN_TOPIC_OR_PARTITION, response.errors.get(nonExistentTopic)) + } +} diff --git a/core/src/test/scala/unit/kafka/server/AdvertiseBrokerTest.scala b/core/src/test/scala/unit/kafka/server/AdvertiseBrokerTest.scala new file mode 100755 index 0000000..3f74863 --- /dev/null +++ b/core/src/test/scala/unit/kafka/server/AdvertiseBrokerTest.scala @@ -0,0 +1,77 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package kafka.server + +import org.junit.jupiter.api.Assertions._ +import kafka.utils.TestUtils +import kafka.server.QuorumTestHarness +import org.apache.kafka.common.security.auth.SecurityProtocol +import org.junit.jupiter.api.{AfterEach, Test} + +import scala.collection.mutable.ArrayBuffer + +class AdvertiseBrokerTest extends QuorumTestHarness { + val servers = ArrayBuffer[KafkaServer]() + + val brokerId = 0 + + @AfterEach + override def tearDown(): Unit = { + TestUtils.shutdownServers(servers) + super.tearDown() + } + + @Test + def testBrokerAdvertiseListenersToZK(): Unit = { + val props = TestUtils.createBrokerConfig(brokerId, zkConnect, enableControlledShutdown = false) + props.put("advertised.listeners", "PLAINTEXT://routable-listener:3334") + servers += TestUtils.createServer(KafkaConfig.fromProps(props)) + + val brokerInfo = zkClient.getBroker(brokerId).get + assertEquals(1, brokerInfo.endPoints.size) + val endpoint = brokerInfo.endPoints.head + assertEquals("routable-listener", endpoint.host) + assertEquals(3334, endpoint.port) + assertEquals(SecurityProtocol.PLAINTEXT, endpoint.securityProtocol) + assertEquals(SecurityProtocol.PLAINTEXT.name, endpoint.listenerName.value) + } + + @Test + def testBrokerAdvertiseListenersWithCustomNamesToZK(): Unit = { + val props = TestUtils.createBrokerConfig(brokerId, zkConnect, enableControlledShutdown = false) + props.put("listeners", "INTERNAL://:0,EXTERNAL://:0") + props.put("advertised.listeners", "EXTERNAL://external-listener:9999,INTERNAL://internal-listener:10999") + props.put("listener.security.protocol.map", "INTERNAL:PLAINTEXT,EXTERNAL:PLAINTEXT") + props.put("inter.broker.listener.name", "INTERNAL") + servers += TestUtils.createServer(KafkaConfig.fromProps(props)) + + val brokerInfo = zkClient.getBroker(brokerId).get + assertEquals(2, brokerInfo.endPoints.size) + val endpoint = brokerInfo.endPoints.head + assertEquals("external-listener", endpoint.host) + assertEquals(9999, endpoint.port) + assertEquals(SecurityProtocol.PLAINTEXT, endpoint.securityProtocol) + assertEquals("EXTERNAL", endpoint.listenerName.value) + val endpoint2 = brokerInfo.endPoints(1) + assertEquals("internal-listener", endpoint2.host) + assertEquals(10999, endpoint2.port) + assertEquals(SecurityProtocol.PLAINTEXT, endpoint.securityProtocol) + assertEquals("INTERNAL", endpoint2.listenerName.value) + } + +} diff --git a/core/src/test/scala/unit/kafka/server/AlterIsrManagerTest.scala b/core/src/test/scala/unit/kafka/server/AlterIsrManagerTest.scala new file mode 100644 index 0000000..86f0dd2 --- /dev/null +++ b/core/src/test/scala/unit/kafka/server/AlterIsrManagerTest.scala @@ -0,0 +1,405 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package kafka.server + +import java.util.Collections + +import kafka.api.LeaderAndIsr +import kafka.utils.{MockScheduler, MockTime} +import kafka.zk.KafkaZkClient +import org.apache.kafka.clients.ClientResponse +import org.apache.kafka.common.TopicPartition +import org.apache.kafka.common.errors.{AuthenticationException, InvalidUpdateVersionException, OperationNotAttemptedException, UnknownServerException, UnsupportedVersionException} +import org.apache.kafka.common.message.AlterIsrResponseData +import org.apache.kafka.common.metrics.Metrics +import org.apache.kafka.common.protocol.{ApiKeys, Errors} +import org.apache.kafka.common.requests.{AbstractRequest, AlterIsrRequest, AlterIsrResponse} +import org.apache.kafka.test.TestUtils.assertFutureThrows +import org.easymock.EasyMock +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.{BeforeEach, Test} +import org.mockito.ArgumentMatchers.{any, anyString} +import org.mockito.{ArgumentCaptor, ArgumentMatchers, Mockito} + +import scala.jdk.CollectionConverters._ + +class AlterIsrManagerTest { + + val topic = "test-topic" + val time = new MockTime + val metrics = new Metrics + val brokerId = 1 + + var brokerToController: BrokerToControllerChannelManager = _ + + val tp0 = new TopicPartition(topic, 0) + val tp1 = new TopicPartition(topic, 1) + val tp2 = new TopicPartition(topic, 2) + + @BeforeEach + def setup(): Unit = { + brokerToController = EasyMock.createMock(classOf[BrokerToControllerChannelManager]) + } + + @Test + def testBasic(): Unit = { + EasyMock.expect(brokerToController.start()) + EasyMock.expect(brokerToController.sendRequest(EasyMock.anyObject(), EasyMock.anyObject())).once() + EasyMock.replay(brokerToController) + + val scheduler = new MockScheduler(time) + val alterIsrManager = new DefaultAlterIsrManager(brokerToController, scheduler, time, brokerId, () => 2) + alterIsrManager.start() + alterIsrManager.submit(tp0, new LeaderAndIsr(1, 1, List(1,2,3), 10), 0) + EasyMock.verify(brokerToController) + } + + @Test + def testOverwriteWithinBatch(): Unit = { + val capture = EasyMock.newCapture[AbstractRequest.Builder[AlterIsrRequest]]() + val callbackCapture = EasyMock.newCapture[ControllerRequestCompletionHandler]() + + EasyMock.expect(brokerToController.start()) + EasyMock.expect(brokerToController.sendRequest(EasyMock.capture(capture), EasyMock.capture(callbackCapture))).times(2) + EasyMock.replay(brokerToController) + + val scheduler = new MockScheduler(time) + val alterIsrManager = new DefaultAlterIsrManager(brokerToController, scheduler, time, brokerId, () => 2) + alterIsrManager.start() + + // Only send one ISR update for a given topic+partition + val firstSubmitFuture = alterIsrManager.submit(tp0, new LeaderAndIsr(1, 1, List(1,2,3), 10), 0) + assertFalse(firstSubmitFuture.isDone) + + val failedSubmitFuture = alterIsrManager.submit(tp0, new LeaderAndIsr(1, 1, List(1,2), 10), 0) + assertTrue(failedSubmitFuture.isCompletedExceptionally) + assertFutureThrows(failedSubmitFuture, classOf[OperationNotAttemptedException]) + + // Simulate response + val alterIsrResp = partitionResponse(tp0, Errors.NONE) + val resp = new ClientResponse(null, null, "", 0L, 0L, + false, null, null, alterIsrResp) + callbackCapture.getValue.onComplete(resp) + + // Now we can submit this partition again + val newSubmitFuture = alterIsrManager.submit(tp0, new LeaderAndIsr(1, 1, List(1), 10), 0) + assertFalse(newSubmitFuture.isDone) + + EasyMock.verify(brokerToController) + + // Make sure we sent the right request ISR={1} + val request 
= capture.getValue.build() + assertEquals(request.data().topics().size(), 1) + assertEquals(request.data().topics().get(0).partitions().get(0).newIsr().size(), 1) + } + + @Test + def testSingleBatch(): Unit = { + val capture = EasyMock.newCapture[AbstractRequest.Builder[AlterIsrRequest]]() + val callbackCapture = EasyMock.newCapture[ControllerRequestCompletionHandler]() + + EasyMock.expect(brokerToController.start()) + EasyMock.expect(brokerToController.sendRequest(EasyMock.capture(capture), EasyMock.capture(callbackCapture))).times(2) + EasyMock.replay(brokerToController) + + val scheduler = new MockScheduler(time) + val alterIsrManager = new DefaultAlterIsrManager(brokerToController, scheduler, time, brokerId, () => 2) + alterIsrManager.start() + + // First request will send batch of one + alterIsrManager.submit(new TopicPartition(topic, 0), + new LeaderAndIsr(1, 1, List(1,2,3), 10), 0) + + // Other submissions will queue up until a response + for (i <- 1 to 9) { + alterIsrManager.submit(new TopicPartition(topic, i), + new LeaderAndIsr(1, 1, List(1,2,3), 10), 0) + } + + // Simulate response, omitting partition 0 will allow it to stay in unsent queue + val alterIsrResp = new AlterIsrResponse(new AlterIsrResponseData()) + val resp = new ClientResponse(null, null, "", 0L, 0L, + false, null, null, alterIsrResp) + + // On the callback, we check for unsent items and send another request + callbackCapture.getValue.onComplete(resp) + + EasyMock.verify(brokerToController) + + // Verify the last request sent had all 10 items + val request = capture.getValue.build() + assertEquals(request.data().topics().size(), 1) + assertEquals(request.data().topics().get(0).partitions().size(), 10) + } + + @Test + def testAuthorizationFailed(): Unit = { + testRetryOnTopLevelError(Errors.CLUSTER_AUTHORIZATION_FAILED) + } + + @Test + def testStaleBrokerEpoch(): Unit = { + testRetryOnTopLevelError(Errors.STALE_BROKER_EPOCH) + } + + @Test + def testUnknownServer(): Unit = { + testRetryOnTopLevelError(Errors.UNKNOWN_SERVER_ERROR) + } + + @Test + def testRetryOnAuthenticationFailure(): Unit = { + testRetryOnErrorResponse(new ClientResponse(null, null, "", 0L, 0L, + false, null, new AuthenticationException("authentication failed"), null)) + } + + @Test + def testRetryOnUnsupportedVersionError(): Unit = { + testRetryOnErrorResponse(new ClientResponse(null, null, "", 0L, 0L, + false, new UnsupportedVersionException("unsupported version"), null, null)) + } + + private def testRetryOnTopLevelError(error: Errors): Unit = { + val alterIsrResp = new AlterIsrResponse(new AlterIsrResponseData().setErrorCode(error.code)) + val response = new ClientResponse(null, null, "", 0L, 0L, + false, null, null, alterIsrResp) + testRetryOnErrorResponse(response) + } + + private def testRetryOnErrorResponse(response: ClientResponse): Unit = { + val leaderAndIsr = new LeaderAndIsr(1, 1, List(1,2,3), 10) + val callbackCapture = EasyMock.newCapture[ControllerRequestCompletionHandler]() + + EasyMock.expect(brokerToController.start()) + EasyMock.expect(brokerToController.sendRequest(EasyMock.anyObject(), EasyMock.capture(callbackCapture))).times(1) + EasyMock.replay(brokerToController) + + val scheduler = new MockScheduler(time) + val alterIsrManager = new DefaultAlterIsrManager(brokerToController, scheduler, time, brokerId, () => 2) + alterIsrManager.start() + alterIsrManager.submit(tp0, leaderAndIsr, 0) + + EasyMock.verify(brokerToController) + + callbackCapture.getValue.onComplete(response) + + // Any top-level error, we want to retry, so we 
don't clear items from the pending map + assertTrue(alterIsrManager.unsentIsrUpdates.containsKey(tp0)) + + EasyMock.reset(brokerToController) + EasyMock.expect(brokerToController.sendRequest(EasyMock.anyObject(), EasyMock.capture(callbackCapture))).times(1) + EasyMock.replay(brokerToController) + + // After some time, we will retry failed requests + time.sleep(100) + scheduler.tick() + + // After a successful response, we can submit another AlterIsrItem + val retryAlterIsrResponse = partitionResponse(tp0, Errors.NONE) + val retryResponse = new ClientResponse(null, null, "", 0L, 0L, + false, null, null, retryAlterIsrResponse) + callbackCapture.getValue.onComplete(retryResponse) + + EasyMock.verify(brokerToController) + + assertFalse(alterIsrManager.unsentIsrUpdates.containsKey(tp0)) + } + + @Test + def testInvalidUpdateVersion(): Unit = { + checkPartitionError(Errors.INVALID_UPDATE_VERSION) + } + + @Test + def testUnknownTopicPartition(): Unit = { + checkPartitionError(Errors.UNKNOWN_TOPIC_OR_PARTITION) + } + + @Test + def testNotLeaderOrFollower(): Unit = { + checkPartitionError(Errors.NOT_LEADER_OR_FOLLOWER) + } + + private def checkPartitionError(error: Errors): Unit = { + val alterIsrManager = testPartitionError(tp0, error) + // Any partition-level error should clear the item from the pending queue allowing for future updates + val future = alterIsrManager.submit(tp0, new LeaderAndIsr(1, 1, List(1,2,3), 10), 0) + assertFalse(future.isDone) + } + + private def testPartitionError(tp: TopicPartition, error: Errors): AlterIsrManager = { + val callbackCapture = EasyMock.newCapture[ControllerRequestCompletionHandler]() + EasyMock.reset(brokerToController) + EasyMock.expect(brokerToController.start()) + EasyMock.expect(brokerToController.sendRequest(EasyMock.anyObject(), EasyMock.capture(callbackCapture))).once() + EasyMock.replay(brokerToController) + + val scheduler = new MockScheduler(time) + val alterIsrManager = new DefaultAlterIsrManager(brokerToController, scheduler, time, brokerId, () => 2) + alterIsrManager.start() + + val future = alterIsrManager.submit(tp, new LeaderAndIsr(1, 1, List(1,2,3), 10), 0) + + EasyMock.verify(brokerToController) + EasyMock.reset(brokerToController) + + val alterIsrResp = partitionResponse(tp, error) + val resp = new ClientResponse(null, null, "", 0L, 0L, + false, null, null, alterIsrResp) + callbackCapture.getValue.onComplete(resp) + assertTrue(future.isCompletedExceptionally) + assertFutureThrows(future, error.exception.getClass) + alterIsrManager + } + + @Test + def testOneInFlight(): Unit = { + val callbackCapture = EasyMock.newCapture[ControllerRequestCompletionHandler]() + EasyMock.reset(brokerToController) + EasyMock.expect(brokerToController.start()) + EasyMock.expect(brokerToController.sendRequest(EasyMock.anyObject(), EasyMock.capture(callbackCapture))).once() + EasyMock.replay(brokerToController) + + val scheduler = new MockScheduler(time) + val alterIsrManager = new DefaultAlterIsrManager(brokerToController, scheduler, time, brokerId, () => 2) + alterIsrManager.start() + + // First submit will send the request + alterIsrManager.submit(tp0, new LeaderAndIsr(1, 1, List(1,2,3), 10), 0) + + // These will become pending unsent items + alterIsrManager.submit(tp1, new LeaderAndIsr(1, 1, List(1,2,3), 10), 0) + alterIsrManager.submit(tp2, new LeaderAndIsr(1, 1, List(1,2,3), 10), 0) + + EasyMock.verify(brokerToController) + + // Once the callback runs, another request will be sent + EasyMock.reset(brokerToController) + 
EasyMock.expect(brokerToController.sendRequest(EasyMock.anyObject(), EasyMock.capture(callbackCapture))).once() + EasyMock.replay(brokerToController) + val alterIsrResp = new AlterIsrResponse(new AlterIsrResponseData()) + val resp = new ClientResponse(null, null, "", 0L, 0L, + false, null, null, alterIsrResp) + callbackCapture.getValue.onComplete(resp) + EasyMock.verify(brokerToController) + } + + @Test + def testPartitionMissingInResponse(): Unit = { + brokerToController = Mockito.mock(classOf[BrokerToControllerChannelManager]) + + val brokerEpoch = 2 + val scheduler = new MockScheduler(time) + val alterIsrManager = new DefaultAlterIsrManager(brokerToController, scheduler, time, brokerId, () => brokerEpoch) + alterIsrManager.start() + + def matchesAlterIsr(topicPartitions: Set[TopicPartition]): AbstractRequest.Builder[_ <: AbstractRequest] = { + ArgumentMatchers.argThat[AbstractRequest.Builder[_ <: AbstractRequest]] { request => + assertEquals(ApiKeys.ALTER_ISR, request.apiKey()) + val alterIsrRequest = request.asInstanceOf[AlterIsrRequest.Builder].build() + + val requestTopicPartitions = alterIsrRequest.data.topics.asScala.flatMap { topicData => + val topic = topicData.name + topicData.partitions.asScala.map(partitionData => new TopicPartition(topic, partitionData.partitionIndex)) + }.toSet + + topicPartitions == requestTopicPartitions + } + } + + def verifySendAlterIsr(topicPartitions: Set[TopicPartition]): ControllerRequestCompletionHandler = { + val callbackCapture: ArgumentCaptor[ControllerRequestCompletionHandler] = + ArgumentCaptor.forClass(classOf[ControllerRequestCompletionHandler]) + Mockito.verify(brokerToController).sendRequest( + matchesAlterIsr(topicPartitions), + callbackCapture.capture() + ) + Mockito.reset(brokerToController) + callbackCapture.getValue + } + + def clientResponse(topicPartition: TopicPartition, error: Errors): ClientResponse = { + val alterIsrResponse = partitionResponse(topicPartition, error) + new ClientResponse(null, null, "", 0L, 0L, + false, null, null, alterIsrResponse) + } + + // The first `submit` will send the `AlterIsr` request + val future1 = alterIsrManager.submit(tp0, new LeaderAndIsr(1, 1, List(1,2,3), 10), 0) + val callback1 = verifySendAlterIsr(Set(tp0)) + + // Additional calls while the `AlterIsr` request is inflight will be queued + val future2 = alterIsrManager.submit(tp1, new LeaderAndIsr(1, 1, List(1,2,3), 10), 0) + val future3 = alterIsrManager.submit(tp2, new LeaderAndIsr(1, 1, List(1,2,3), 10), 0) + + // Respond to the first request, which will also allow the next request to get sent + callback1.onComplete(clientResponse(tp0, Errors.UNKNOWN_SERVER_ERROR)) + assertFutureThrows(future1, classOf[UnknownServerException]) + assertFalse(future2.isDone) + assertFalse(future3.isDone) + + // Verify the second request includes both expected partitions, but only respond with one of them + val callback2 = verifySendAlterIsr(Set(tp1, tp2)) + callback2.onComplete(clientResponse(tp2, Errors.UNKNOWN_SERVER_ERROR)) + assertFutureThrows(future3, classOf[UnknownServerException]) + assertFalse(future2.isDone) + + // The missing partition should be retried + val callback3 = verifySendAlterIsr(Set(tp1)) + callback3.onComplete(clientResponse(tp1, Errors.UNKNOWN_SERVER_ERROR)) + assertFutureThrows(future2, classOf[UnknownServerException]) + } + + @Test + def testZkBasic(): Unit = { + val scheduler = new MockScheduler(time) + scheduler.startup() + + val kafkaZkClient = Mockito.mock(classOf[KafkaZkClient]) + Mockito.doAnswer(_ => (true, 2)) + 
.when(kafkaZkClient) + .conditionalUpdatePath(anyString(), any(), ArgumentMatchers.eq(1), any()) + Mockito.doAnswer(_ => (false, 2)) + .when(kafkaZkClient) + .conditionalUpdatePath(anyString(), any(), ArgumentMatchers.eq(3), any()) + + val zkIsrManager = new ZkIsrManager(scheduler, time, kafkaZkClient) + zkIsrManager.start() + + // Correct ZK version + val future1 = zkIsrManager.submit(tp0, new LeaderAndIsr(1, 1, List(1,2,3), 1), 0) + assertTrue(future1.isDone) + assertEquals(new LeaderAndIsr(1, 1, List(1,2,3), 2), future1.get) + + // Wrong ZK version + val future2 = zkIsrManager.submit(tp0, new LeaderAndIsr(1, 1, List(1,2,3), 3), 0) + assertTrue(future2.isCompletedExceptionally) + assertFutureThrows(future2, classOf[InvalidUpdateVersionException]) + } + + private def partitionResponse(tp: TopicPartition, error: Errors): AlterIsrResponse = { + new AlterIsrResponse(new AlterIsrResponseData() + .setTopics(Collections.singletonList( + new AlterIsrResponseData.TopicData() + .setName(tp.topic()) + .setPartitions(Collections.singletonList( + new AlterIsrResponseData.PartitionData() + .setPartitionIndex(tp.partition()) + .setErrorCode(error.code)))))) + } +} diff --git a/core/src/test/scala/unit/kafka/server/AlterReplicaLogDirsRequestTest.scala b/core/src/test/scala/unit/kafka/server/AlterReplicaLogDirsRequestTest.scala new file mode 100644 index 0000000..2653d3f --- /dev/null +++ b/core/src/test/scala/unit/kafka/server/AlterReplicaLogDirsRequestTest.scala @@ -0,0 +1,137 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package kafka.server + +import java.io.File + +import kafka.utils._ +import org.apache.kafka.common.TopicPartition +import org.apache.kafka.common.message.AlterReplicaLogDirsRequestData +import org.apache.kafka.common.protocol.Errors +import org.apache.kafka.common.requests.{AlterReplicaLogDirsRequest, AlterReplicaLogDirsResponse} +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.Test + +import scala.jdk.CollectionConverters._ +import scala.collection.mutable +import scala.util.Random + +class AlterReplicaLogDirsRequestTest extends BaseRequestTest { + override val logDirCount = 5 + override val brokerCount = 1 + + val topic = "topic" + + private def findErrorForPartition(response: AlterReplicaLogDirsResponse, tp: TopicPartition): Errors = { + Errors.forCode(response.data.results.asScala + .find(x => x.topicName == tp.topic).get.partitions.asScala + .find(p => p.partitionIndex == tp.partition).get.errorCode) + } + + @Test + def testAlterReplicaLogDirsRequest(): Unit = { + val partitionNum = 5 + + // Alter replica dir before topic creation + val logDir1 = new File(servers.head.config.logDirs(Random.nextInt(logDirCount))).getAbsolutePath + val partitionDirs1 = (0 until partitionNum).map(partition => new TopicPartition(topic, partition) -> logDir1).toMap + val alterReplicaLogDirsResponse1 = sendAlterReplicaLogDirsRequest(partitionDirs1) + + // The response should show error UNKNOWN_TOPIC_OR_PARTITION for all partitions + (0 until partitionNum).foreach { partition => + val tp = new TopicPartition(topic, partition) + assertEquals(Errors.UNKNOWN_TOPIC_OR_PARTITION, findErrorForPartition(alterReplicaLogDirsResponse1, tp)) + assertTrue(servers.head.logManager.getLog(tp).isEmpty) + } + + createTopic(topic, partitionNum, 1) + (0 until partitionNum).foreach { partition => + assertEquals(logDir1, servers.head.logManager.getLog(new TopicPartition(topic, partition)).get.dir.getParent) + } + + // Alter replica dir again after topic creation + val logDir2 = new File(servers.head.config.logDirs(Random.nextInt(logDirCount))).getAbsolutePath + val partitionDirs2 = (0 until partitionNum).map(partition => new TopicPartition(topic, partition) -> logDir2).toMap + val alterReplicaLogDirsResponse2 = sendAlterReplicaLogDirsRequest(partitionDirs2) + // The response should succeed for all partitions + (0 until partitionNum).foreach { partition => + val tp = new TopicPartition(topic, partition) + assertEquals(Errors.NONE, findErrorForPartition(alterReplicaLogDirsResponse2, tp)) + TestUtils.waitUntilTrue(() => { + logDir2 == servers.head.logManager.getLog(new TopicPartition(topic, partition)).get.dir.getParent + }, "timed out waiting for replica movement") + } + } + + @Test + def testAlterReplicaLogDirsRequestErrorCode(): Unit = { + val offlineDir = new File(servers.head.config.logDirs.tail.head).getAbsolutePath + val validDir1 = new File(servers.head.config.logDirs(1)).getAbsolutePath + val validDir2 = new File(servers.head.config.logDirs(2)).getAbsolutePath + val validDir3 = new File(servers.head.config.logDirs(3)).getAbsolutePath + + // Test AlterReplicaDirRequest before topic creation + val partitionDirs1 = mutable.Map.empty[TopicPartition, String] + partitionDirs1.put(new TopicPartition(topic, 0), "invalidDir") + partitionDirs1.put(new TopicPartition(topic, 1), validDir1) + val alterReplicaDirResponse1 = sendAlterReplicaLogDirsRequest(partitionDirs1.toMap) + assertEquals(Errors.LOG_DIR_NOT_FOUND, findErrorForPartition(alterReplicaDirResponse1, new TopicPartition(topic, 0))) + 
assertEquals(Errors.UNKNOWN_TOPIC_OR_PARTITION, findErrorForPartition(alterReplicaDirResponse1, new TopicPartition(topic, 1))) + + createTopic(topic, 3, 1) + + // Test AlterReplicaDirRequest after topic creation + val partitionDirs2 = mutable.Map.empty[TopicPartition, String] + partitionDirs2.put(new TopicPartition(topic, 0), "invalidDir") + partitionDirs2.put(new TopicPartition(topic, 1), validDir2) + val alterReplicaDirResponse2 = sendAlterReplicaLogDirsRequest(partitionDirs2.toMap) + assertEquals(Errors.LOG_DIR_NOT_FOUND, findErrorForPartition(alterReplicaDirResponse2, new TopicPartition(topic, 0))) + assertEquals(Errors.NONE, findErrorForPartition(alterReplicaDirResponse2, new TopicPartition(topic, 1))) + + // Test AlterReplicaDirRequest after topic creation and log directory failure + servers.head.logDirFailureChannel.maybeAddOfflineLogDir(offlineDir, "", new java.io.IOException()) + TestUtils.waitUntilTrue(() => !servers.head.logManager.isLogDirOnline(offlineDir), s"timed out waiting for $offlineDir to be offline", 3000) + val partitionDirs3 = mutable.Map.empty[TopicPartition, String] + partitionDirs3.put(new TopicPartition(topic, 0), "invalidDir") + partitionDirs3.put(new TopicPartition(topic, 1), validDir3) + partitionDirs3.put(new TopicPartition(topic, 2), offlineDir) + val alterReplicaDirResponse3 = sendAlterReplicaLogDirsRequest(partitionDirs3.toMap) + assertEquals(Errors.LOG_DIR_NOT_FOUND, findErrorForPartition(alterReplicaDirResponse3, new TopicPartition(topic, 0))) + assertEquals(Errors.KAFKA_STORAGE_ERROR, findErrorForPartition(alterReplicaDirResponse3, new TopicPartition(topic, 1))) + assertEquals(Errors.KAFKA_STORAGE_ERROR, findErrorForPartition(alterReplicaDirResponse3, new TopicPartition(topic, 2))) + } + + private def sendAlterReplicaLogDirsRequest(partitionDirs: Map[TopicPartition, String]): AlterReplicaLogDirsResponse = { + val logDirs = partitionDirs.groupBy{case (_, dir) => dir}.map{ case(dir, tps) => + new AlterReplicaLogDirsRequestData.AlterReplicaLogDir() + .setPath(dir) + .setTopics(new AlterReplicaLogDirsRequestData.AlterReplicaLogDirTopicCollection( + tps.groupBy { case (tp, _) => tp.topic } + .map { case (topic, tpPartitions) => + new AlterReplicaLogDirsRequestData.AlterReplicaLogDirTopic() + .setName(topic) + .setPartitions(tpPartitions.map{case (tp, _) => tp.partition.asInstanceOf[Integer]}.toList.asJava) + }.toList.asJava.iterator)) + } + val data = new AlterReplicaLogDirsRequestData() + .setDirs(new AlterReplicaLogDirsRequestData.AlterReplicaLogDirCollection(logDirs.asJava.iterator)) + val request = new AlterReplicaLogDirsRequest.Builder(data).build() + connectAndReceive[AlterReplicaLogDirsResponse](request, destination = controllerSocketServer) + } + +} diff --git a/core/src/test/scala/unit/kafka/server/AlterUserScramCredentialsRequestNotAuthorizedTest.scala b/core/src/test/scala/unit/kafka/server/AlterUserScramCredentialsRequestNotAuthorizedTest.scala new file mode 100644 index 0000000..b981e7e --- /dev/null +++ b/core/src/test/scala/unit/kafka/server/AlterUserScramCredentialsRequestNotAuthorizedTest.scala @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package kafka.server + +import kafka.network.SocketServer +import org.apache.kafka.clients.admin.ScramMechanism +import org.apache.kafka.common.message.AlterUserScramCredentialsRequestData +import org.apache.kafka.common.message.AlterUserScramCredentialsResponseData.AlterUserScramCredentialsResult +import org.apache.kafka.common.protocol.Errors +import org.apache.kafka.common.requests.{AlterUserScramCredentialsRequest, AlterUserScramCredentialsResponse} +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.Test + +import java.util +import java.util.Properties +import scala.jdk.CollectionConverters._ + +/** + * see AlterUserScramCredentialsRequestTest + */ +class AlterUserScramCredentialsRequestNotAuthorizedTest extends BaseRequestTest { + + override def brokerPropertyOverrides(properties: Properties): Unit = { + properties.put(KafkaConfig.ControlledShutdownEnableProp, "false") + properties.put(KafkaConfig.AuthorizerClassNameProp, classOf[AlterCredentialsTest.TestAuthorizer].getName) + properties.put(KafkaConfig.PrincipalBuilderClassProp, classOf[AlterCredentialsTest.TestPrincipalBuilderReturningUnauthorized].getName) + } + + private val user1 = "user1" + private val user2 = "user2" + + @Test + def testAlterNothingNotAuthorized(): Unit = { + val request = new AlterUserScramCredentialsRequest.Builder( + new AlterUserScramCredentialsRequestData() + .setDeletions(new util.ArrayList[AlterUserScramCredentialsRequestData.ScramCredentialDeletion]) + .setUpsertions(new util.ArrayList[AlterUserScramCredentialsRequestData.ScramCredentialUpsertion])).build() + val response = sendAlterUserScramCredentialsRequest(request) + + val results = response.data.results + assertEquals(0, results.size) + } + + @Test + def testAlterSomethingNotAuthorized(): Unit = { + val request = new AlterUserScramCredentialsRequest.Builder( + new AlterUserScramCredentialsRequestData() + .setDeletions(util.Arrays.asList(new AlterUserScramCredentialsRequestData.ScramCredentialDeletion().setName(user1).setMechanism(ScramMechanism.SCRAM_SHA_256.`type`))) + .setUpsertions(util.Arrays.asList(new AlterUserScramCredentialsRequestData.ScramCredentialUpsertion().setName(user2).setMechanism(ScramMechanism.SCRAM_SHA_512.`type`)))).build() + val response = sendAlterUserScramCredentialsRequest(request) + + val results = response.data.results + assertEquals(2, results.size) + checkAllErrorsAlteringCredentials(results, Errors.CLUSTER_AUTHORIZATION_FAILED, "when not authorized") + } + + private def sendAlterUserScramCredentialsRequest(request: AlterUserScramCredentialsRequest, socketServer: SocketServer = controllerSocketServer): AlterUserScramCredentialsResponse = { + connectAndReceive[AlterUserScramCredentialsResponse](request, destination = socketServer) + } + + private def checkAllErrorsAlteringCredentials(resultsToCheck: util.List[AlterUserScramCredentialsResult], expectedError: Errors, contextMsg: String): Unit = { + assertEquals(0, resultsToCheck.asScala.filterNot(_.errorCode == expectedError.code).size, + s"Expected all '${expectedError.name}' errors when altering credentials $contextMsg") + } +} \ 
No newline at end of file diff --git a/core/src/test/scala/unit/kafka/server/AlterUserScramCredentialsRequestTest.scala b/core/src/test/scala/unit/kafka/server/AlterUserScramCredentialsRequestTest.scala new file mode 100644 index 0000000..e0121b1 --- /dev/null +++ b/core/src/test/scala/unit/kafka/server/AlterUserScramCredentialsRequestTest.scala @@ -0,0 +1,398 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package kafka.server + + +import java.nio.charset.StandardCharsets +import java.util +import java.util.Properties +import kafka.network.SocketServer +import kafka.security.authorizer.AclAuthorizer +import org.apache.kafka.clients.admin.ScramMechanism +import org.apache.kafka.common.message.AlterUserScramCredentialsResponseData.AlterUserScramCredentialsResult +import org.apache.kafka.common.message.DescribeUserScramCredentialsResponseData.DescribeUserScramCredentialsResult +import org.apache.kafka.common.message.{AlterUserScramCredentialsRequestData, DescribeUserScramCredentialsRequestData} +import org.apache.kafka.common.protocol.{ApiKeys, Errors} +import org.apache.kafka.common.requests.{AlterUserScramCredentialsRequest, AlterUserScramCredentialsResponse, DescribeUserScramCredentialsRequest, DescribeUserScramCredentialsResponse} +import org.apache.kafka.common.security.auth.{AuthenticationContext, KafkaPrincipal} +import org.apache.kafka.common.security.authenticator.DefaultKafkaPrincipalBuilder +import org.apache.kafka.server.authorizer.{Action, AuthorizableRequestContext, AuthorizationResult} +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.Test + +import scala.jdk.CollectionConverters._ + +/** + * Test AlterUserScramCredentialsRequest/Response API for the cases where either no credentials are altered + * or failure is expected due to lack of authorization, sending the request to a non-controller broker, or some other issue. + * Also tests the Alter and Describe APIs for the case where credentials are successfully altered/described. 
+ */ +class AlterUserScramCredentialsRequestTest extends BaseRequestTest { + override def brokerPropertyOverrides(properties: Properties): Unit = { + properties.put(KafkaConfig.ControlledShutdownEnableProp, "false") + properties.put(KafkaConfig.AuthorizerClassNameProp, classOf[AlterCredentialsTest.TestAuthorizer].getName) + properties.put(KafkaConfig.PrincipalBuilderClassProp, classOf[AlterCredentialsTest.TestPrincipalBuilderReturningAuthorized].getName) + } + + private val saltedPasswordBytes = "saltedPassword".getBytes(StandardCharsets.UTF_8) + private val saltBytes = "salt".getBytes(StandardCharsets.UTF_8) + private val user1 = "user1" + private val user2 = "user2" + private val unknownUser = "unknownUser" + + @Test + def testAlterNothing(): Unit = { + val request = new AlterUserScramCredentialsRequest.Builder( + new AlterUserScramCredentialsRequestData() + .setDeletions(new util.ArrayList[AlterUserScramCredentialsRequestData.ScramCredentialDeletion]) + .setUpsertions(new util.ArrayList[AlterUserScramCredentialsRequestData.ScramCredentialUpsertion])).build() + val response = sendAlterUserScramCredentialsRequest(request) + + val results = response.data.results + assertEquals(0, results.size) + } + + @Test + def testAlterSameThingTwice(): Unit = { + val deletion1 = new AlterUserScramCredentialsRequestData.ScramCredentialDeletion().setName(user1).setMechanism(ScramMechanism.SCRAM_SHA_256.`type`) + val deletion2 = new AlterUserScramCredentialsRequestData.ScramCredentialDeletion().setName(user2).setMechanism(ScramMechanism.SCRAM_SHA_256.`type`) + val upsertion1 = new AlterUserScramCredentialsRequestData.ScramCredentialUpsertion().setName(user1).setMechanism(ScramMechanism.SCRAM_SHA_256.`type`) + .setIterations(4096).setSalt(saltBytes).setSaltedPassword(saltedPasswordBytes) + val upsertion2 = new AlterUserScramCredentialsRequestData.ScramCredentialUpsertion().setName(user2).setMechanism(ScramMechanism.SCRAM_SHA_256.`type`) + .setIterations(4096).setSalt(saltBytes).setSaltedPassword(saltedPasswordBytes) + val requests = List ( + new AlterUserScramCredentialsRequest.Builder( + new AlterUserScramCredentialsRequestData() + .setDeletions(util.Arrays.asList(deletion1, deletion1)) + .setUpsertions(util.Arrays.asList(upsertion2, upsertion2))).build(), + new AlterUserScramCredentialsRequest.Builder( + new AlterUserScramCredentialsRequestData() + .setDeletions(util.Arrays.asList(deletion1, deletion2)) + .setUpsertions(util.Arrays.asList(upsertion1, upsertion2))).build(), + ) + requests.foreach(request => { + val response = sendAlterUserScramCredentialsRequest(request) + val results = response.data.results + assertEquals(2, results.size) + checkAllErrorsAlteringCredentials(results, Errors.DUPLICATE_RESOURCE, "when altering the same credential twice in a single request") + }) + } + + @Test + def testAlterEmptyUser(): Unit = { + val deletionEmpty = new AlterUserScramCredentialsRequestData.ScramCredentialDeletion().setName("").setMechanism(ScramMechanism.SCRAM_SHA_256.`type`) + val upsertionEmpty = new AlterUserScramCredentialsRequestData.ScramCredentialUpsertion().setName("").setMechanism(ScramMechanism.SCRAM_SHA_256.`type`) + .setIterations(4096).setSalt(saltBytes).setSaltedPassword(saltedPasswordBytes) + val requests = List ( + new AlterUserScramCredentialsRequest.Builder( + new AlterUserScramCredentialsRequestData() + .setDeletions(util.Arrays.asList(deletionEmpty)) + .setUpsertions(new util.ArrayList[AlterUserScramCredentialsRequestData.ScramCredentialUpsertion])).build(), + new 
AlterUserScramCredentialsRequest.Builder( + new AlterUserScramCredentialsRequestData() + .setDeletions(new util.ArrayList[AlterUserScramCredentialsRequestData.ScramCredentialDeletion]) + .setUpsertions(util.Arrays.asList(upsertionEmpty))).build(), + new AlterUserScramCredentialsRequest.Builder( + new AlterUserScramCredentialsRequestData() + .setDeletions(util.Arrays.asList(deletionEmpty, deletionEmpty)) + .setUpsertions(util.Arrays.asList(upsertionEmpty))).build(), + ) + requests.foreach(request => { + val response = sendAlterUserScramCredentialsRequest(request) + val results = response.data.results + assertEquals(1, results.size) + checkAllErrorsAlteringCredentials(results, Errors.UNACCEPTABLE_CREDENTIAL, "when altering an empty user") + assertEquals("Username must not be empty", results.get(0).errorMessage) + }) + } + + @Test + def testAlterUnknownMechanism(): Unit = { + val deletionUnknown1 = new AlterUserScramCredentialsRequestData.ScramCredentialDeletion().setName(user1).setMechanism(ScramMechanism.UNKNOWN.`type`) + val deletionValid1 = new AlterUserScramCredentialsRequestData.ScramCredentialDeletion().setName(user1).setMechanism(ScramMechanism.SCRAM_SHA_256.`type`) + val deletionUnknown2 = new AlterUserScramCredentialsRequestData.ScramCredentialDeletion().setName(user2).setMechanism(10.toByte) + val user3 = "user3" + val upsertionUnknown3 = new AlterUserScramCredentialsRequestData.ScramCredentialUpsertion().setName(user3).setMechanism(ScramMechanism.UNKNOWN.`type`) + .setIterations(8192).setSalt(saltBytes).setSaltedPassword(saltedPasswordBytes) + val upsertionValid3 = new AlterUserScramCredentialsRequestData.ScramCredentialUpsertion().setName(user3).setMechanism(ScramMechanism.SCRAM_SHA_256.`type`) + .setIterations(8192).setSalt(saltBytes).setSaltedPassword(saltedPasswordBytes) + val user4 = "user4" + val upsertionUnknown4 = new AlterUserScramCredentialsRequestData.ScramCredentialUpsertion().setName(user4).setMechanism(10.toByte) + .setIterations(8192).setSalt(saltBytes).setSaltedPassword(saltedPasswordBytes) + val user5 = "user5" + val upsertionUnknown5 = new AlterUserScramCredentialsRequestData.ScramCredentialUpsertion().setName(user5).setMechanism(ScramMechanism.UNKNOWN.`type`) + .setIterations(8192).setSalt(saltBytes).setSaltedPassword(saltedPasswordBytes) + val request = new AlterUserScramCredentialsRequest.Builder( + new AlterUserScramCredentialsRequestData() + .setDeletions(util.Arrays.asList(deletionUnknown1, deletionValid1, deletionUnknown2)) + .setUpsertions(util.Arrays.asList(upsertionUnknown3, upsertionValid3, upsertionUnknown4, upsertionUnknown5))).build() + val response = sendAlterUserScramCredentialsRequest(request) + val results = response.data.results + assertEquals(5, results.size) + checkAllErrorsAlteringCredentials(results, Errors.UNSUPPORTED_SASL_MECHANISM, "when altering the credentials with unknown SCRAM mechanisms") + results.asScala.foreach(result => assertEquals("Unknown SCRAM mechanism", result.errorMessage)) + } + + @Test + def testAlterTooFewIterations(): Unit = { + val upsertionTooFewIterations = new AlterUserScramCredentialsRequestData.ScramCredentialUpsertion().setName(user1) + .setMechanism(ScramMechanism.SCRAM_SHA_256.`type`).setIterations(1) + .setSalt(saltBytes).setSaltedPassword(saltedPasswordBytes) + val request = new AlterUserScramCredentialsRequest.Builder( + new AlterUserScramCredentialsRequestData() + .setDeletions(util.Collections.emptyList()) + .setUpsertions(util.Arrays.asList(upsertionTooFewIterations))).build() + val response = 
sendAlterUserScramCredentialsRequest(request) + val results = response.data.results + assertEquals(1, results.size) + checkAllErrorsAlteringCredentials(results, Errors.UNACCEPTABLE_CREDENTIAL, "when altering the credentials with too few iterations") + assertEquals("Too few iterations", results.get(0).errorMessage) + } + + @Test + def testAlterTooManyIterations(): Unit = { + val upsertionTooFewIterations = new AlterUserScramCredentialsRequestData.ScramCredentialUpsertion().setName(user1) + .setMechanism(ScramMechanism.SCRAM_SHA_256.`type`).setIterations(Integer.MAX_VALUE) + .setSalt(saltBytes).setSaltedPassword(saltedPasswordBytes) + val request = new AlterUserScramCredentialsRequest.Builder( + new AlterUserScramCredentialsRequestData() + .setDeletions(util.Collections.emptyList()) + .setUpsertions(util.Arrays.asList(upsertionTooFewIterations))).build() + val response = sendAlterUserScramCredentialsRequest(request) + val results = response.data.results + assertEquals(1, results.size) + checkAllErrorsAlteringCredentials(results, Errors.UNACCEPTABLE_CREDENTIAL, "when altering the credentials with too many iterations") + assertEquals("Too many iterations", results.get(0).errorMessage) + } + + @Test + def testDeleteSomethingThatDoesNotExist(): Unit = { + val request = new AlterUserScramCredentialsRequest.Builder( + new AlterUserScramCredentialsRequestData() + .setDeletions(util.Arrays.asList(new AlterUserScramCredentialsRequestData.ScramCredentialDeletion().setName(user1).setMechanism(ScramMechanism.SCRAM_SHA_256.`type`))) + .setUpsertions(new util.ArrayList[AlterUserScramCredentialsRequestData.ScramCredentialUpsertion])).build() + val response = sendAlterUserScramCredentialsRequest(request) + + val results = response.data.results + assertEquals(1, results.size) + checkAllErrorsAlteringCredentials(results, Errors.RESOURCE_NOT_FOUND, "when deleting a non-existing credential") + } + + @Test + def testAlterNotController(): Unit = { + val request = new AlterUserScramCredentialsRequest.Builder( + new AlterUserScramCredentialsRequestData() + .setDeletions(util.Arrays.asList(new AlterUserScramCredentialsRequestData.ScramCredentialDeletion().setName(user1).setMechanism(ScramMechanism.SCRAM_SHA_256.`type`))) + .setUpsertions(util.Arrays.asList(new AlterUserScramCredentialsRequestData.ScramCredentialUpsertion().setName(user2).setMechanism(ScramMechanism.SCRAM_SHA_512.`type`)))).build() + val response = sendAlterUserScramCredentialsRequest(request, notControllerSocketServer) + + val results = response.data.results + assertEquals(2, results.size) + checkAllErrorsAlteringCredentials(results, Errors.NOT_CONTROLLER, "when routed incorrectly to a non-Controller broker") + } + + @Test + def testAlterAndDescribe(): Unit = { + // create a bunch of credentials + val request1 = new AlterUserScramCredentialsRequest.Builder( + new AlterUserScramCredentialsRequestData() + .setUpsertions(util.Arrays.asList( + new AlterUserScramCredentialsRequestData.ScramCredentialUpsertion() + .setName(user1).setMechanism(ScramMechanism.SCRAM_SHA_256.`type`) + .setIterations(4096) + .setSalt(saltBytes) + .setSaltedPassword(saltedPasswordBytes), + new AlterUserScramCredentialsRequestData.ScramCredentialUpsertion() + .setName(user1).setMechanism(ScramMechanism.SCRAM_SHA_512.`type`) + .setIterations(8192) + .setSalt(saltBytes) + .setSaltedPassword(saltedPasswordBytes), + new AlterUserScramCredentialsRequestData.ScramCredentialUpsertion() + .setName(user2).setMechanism(ScramMechanism.SCRAM_SHA_512.`type`) + .setIterations(8192) + 
.setSalt(saltBytes) + .setSaltedPassword(saltedPasswordBytes), + ))).build() + val results1 = sendAlterUserScramCredentialsRequest(request1).data.results + assertEquals(2, results1.size) + checkNoErrorsAlteringCredentials(results1) + checkUserAppearsInAlterResults(results1, user1) + checkUserAppearsInAlterResults(results1, user2) + + // now describe them all + val results2 = describeAllWithNoTopLevelErrorConfirmed().data.results + assertEquals(2, results2.size) + checkUserHasTwoCredentials(results2, user1) + checkForSingleSha512Iterations8192Credential(results2, user2) + + // now describe just one + val request3 = new DescribeUserScramCredentialsRequest.Builder( + new DescribeUserScramCredentialsRequestData().setUsers(util.Arrays.asList( + new DescribeUserScramCredentialsRequestData.UserName().setName(user1)))).build() + val response3 = sendDescribeUserScramCredentialsRequest(request3) + checkNoTopLevelErrorDescribingCredentials(response3) + val results3 = response3.data.results + assertEquals(1, results3.size) + checkUserHasTwoCredentials(results3, user1) + + // now test per-user errors by describing user1 and an unknown + val requestUnknown = new DescribeUserScramCredentialsRequest.Builder( + new DescribeUserScramCredentialsRequestData().setUsers(util.Arrays.asList( + new DescribeUserScramCredentialsRequestData.UserName().setName(user1), + new DescribeUserScramCredentialsRequestData.UserName().setName(unknownUser)))).build() + val responseUnknown = sendDescribeUserScramCredentialsRequest(requestUnknown) + checkNoTopLevelErrorDescribingCredentials(responseUnknown) + val resultsUnknown = responseUnknown.data.results + assertEquals(2, resultsUnknown.size) + checkUserHasTwoCredentials(resultsUnknown, user1) + checkDescribeForError(resultsUnknown, unknownUser, Errors.RESOURCE_NOT_FOUND) + + // now test per-user errors again by describing user1 along with user2 twice + val requestDuplicateUser = new DescribeUserScramCredentialsRequest.Builder( + new DescribeUserScramCredentialsRequestData().setUsers(util.Arrays.asList( + new DescribeUserScramCredentialsRequestData.UserName().setName(user1), + new DescribeUserScramCredentialsRequestData.UserName().setName(user2), + new DescribeUserScramCredentialsRequestData.UserName().setName(user2)))).build() + val responseDuplicateUser = sendDescribeUserScramCredentialsRequest(requestDuplicateUser) + checkNoTopLevelErrorDescribingCredentials(responseDuplicateUser) + val resultsDuplicateUser = responseDuplicateUser.data.results + assertEquals(2, resultsDuplicateUser.size) + checkUserHasTwoCredentials(resultsDuplicateUser, user1) + checkDescribeForError(resultsDuplicateUser, user2, Errors.DUPLICATE_RESOURCE) + + // now delete a couple of credentials + val request4 = new AlterUserScramCredentialsRequest.Builder( + new AlterUserScramCredentialsRequestData() + .setDeletions(util.Arrays.asList( + new AlterUserScramCredentialsRequestData.ScramCredentialDeletion() + .setName(user1).setMechanism(ScramMechanism.SCRAM_SHA_256.`type`), + new AlterUserScramCredentialsRequestData.ScramCredentialDeletion() + .setName(user2).setMechanism(ScramMechanism.SCRAM_SHA_512.`type`), + ))).build() + val response4 = sendAlterUserScramCredentialsRequest(request4) + val results4 = response4.data.results + assertEquals(2, results4.size) + checkNoErrorsAlteringCredentials(results4) + checkUserAppearsInAlterResults(results4, user1) + checkUserAppearsInAlterResults(results4, user2) + + // now describe them all, which should just yield 1 credential + val results5 = 
describeAllWithNoTopLevelErrorConfirmed().data.results + assertEquals(1, results5.size) + checkForSingleSha512Iterations8192Credential(results5, user1) + + // now delete the last one + val request6 = new AlterUserScramCredentialsRequest.Builder( + new AlterUserScramCredentialsRequestData() + .setDeletions(util.Arrays.asList( + new AlterUserScramCredentialsRequestData.ScramCredentialDeletion() + .setName(user1).setMechanism(ScramMechanism.SCRAM_SHA_512.`type`), + ))).build() + val results6 = sendAlterUserScramCredentialsRequest(request6).data.results + assertEquals(1, results6.size) + checkNoErrorsAlteringCredentials(results6) + checkUserAppearsInAlterResults(results6, user1) + + // now describe them all, which should yield 0 credentials + val results7 = describeAllWithNoTopLevelErrorConfirmed().data.results + assertEquals(0, results7.size) + } + + private def sendAlterUserScramCredentialsRequest(request: AlterUserScramCredentialsRequest, socketServer: SocketServer = controllerSocketServer): AlterUserScramCredentialsResponse = { + connectAndReceive[AlterUserScramCredentialsResponse](request, destination = socketServer) + } + + private def sendDescribeUserScramCredentialsRequest(request: DescribeUserScramCredentialsRequest, socketServer: SocketServer = controllerSocketServer): DescribeUserScramCredentialsResponse = { + connectAndReceive[DescribeUserScramCredentialsResponse](request, destination = socketServer) + } + + private def checkAllErrorsAlteringCredentials(resultsToCheck: util.List[AlterUserScramCredentialsResult], expectedError: Errors, contextMsg: String) = { + assertEquals(0, resultsToCheck.asScala.filterNot(_.errorCode == expectedError.code).size, + s"Expected all '${expectedError.name}' errors when altering credentials $contextMsg") + } + + private def checkNoErrorsAlteringCredentials(resultsToCheck: util.List[AlterUserScramCredentialsResult]) = { + assertEquals(0, resultsToCheck.asScala.filterNot(_.errorCode == Errors.NONE.code).size, + "Expected no error when altering credentials") + } + + private def checkUserAppearsInAlterResults(resultsToCheck: util.List[AlterUserScramCredentialsResult], user: String) = { + assertTrue(resultsToCheck.asScala.exists(_.user == user), s"Expected result to contain '$user'") + } + + private def describeAllWithNoTopLevelErrorConfirmed() = { + val response = sendDescribeUserScramCredentialsRequest( + new DescribeUserScramCredentialsRequest.Builder(new DescribeUserScramCredentialsRequestData()).build()) + checkNoTopLevelErrorDescribingCredentials(response) + response + } + + private def checkNoTopLevelErrorDescribingCredentials(responseToCheck: DescribeUserScramCredentialsResponse) = { + assertEquals(Errors.NONE.code, responseToCheck.data.errorCode, "Expected no top-level error when describing the credentials") + } + + private def checkUserHasTwoCredentials(resultsToCheck: util.List[DescribeUserScramCredentialsResult], user: String) = { + assertTrue(resultsToCheck.asScala.exists(result => result.user == user && result.credentialInfos.size == 2 && result.errorCode == Errors.NONE.code), + s"Expected result to contain '$user' with 2 credentials: $resultsToCheck") + assertTrue(resultsToCheck.asScala.exists(result => result.user == user && result.credentialInfos.asScala.exists(info => + info.mechanism == ScramMechanism.SCRAM_SHA_256.`type` && info.iterations == 4096) + && result.credentialInfos.asScala.exists(info => + info.mechanism == ScramMechanism.SCRAM_SHA_512.`type` && info.iterations == 8192)), + s"Expected result to contain '$user' with 
SCRAM_SHA_256/4096 and SCRAM_SHA_512/8192 credentials: $resultsToCheck") + } + + private def checkForSingleSha512Iterations8192Credential(resultsToCheck: util.List[DescribeUserScramCredentialsResult], user: String) = { + assertTrue(resultsToCheck.asScala.exists(result => result.user == user && result.credentialInfos.size == 1 && result.errorCode == Errors.NONE.code), + s"Expected result to contain '$user' with 1 credential: $resultsToCheck") + assertTrue(resultsToCheck.asScala.exists(result => result.user == user && result.credentialInfos.asScala.exists(info => + info.mechanism == ScramMechanism.SCRAM_SHA_512.`type` && info.iterations == 8192)), + s"Expected result to contain '$user' with SCRAM_SHA_512/8192 credential: $resultsToCheck") + } + + private def checkDescribeForError(resultsToCheck: util.List[DescribeUserScramCredentialsResult], user: String, expectedError: Errors) = { + assertTrue(resultsToCheck.asScala.exists(result => result.user == user && result.credentialInfos.size == 0 && result.errorCode == expectedError.code), + s"Expected result to contain '$user' with a ${expectedError.name} error: $resultsToCheck") + } +} + +object AlterCredentialsTest { + val UnauthorizedPrincipal = new KafkaPrincipal(KafkaPrincipal.USER_TYPE, "Unauthorized") + val AuthorizedPrincipal = KafkaPrincipal.ANONYMOUS + + class TestAuthorizer extends AclAuthorizer { + override def authorize(requestContext: AuthorizableRequestContext, actions: util.List[Action]): util.List[AuthorizationResult] = { + actions.asScala.map { _ => + if (requestContext.requestType == ApiKeys.ALTER_USER_SCRAM_CREDENTIALS.id && requestContext.principal == UnauthorizedPrincipal) + AuthorizationResult.DENIED + else + AuthorizationResult.ALLOWED + }.asJava + } + } + + class TestPrincipalBuilderReturningAuthorized extends DefaultKafkaPrincipalBuilder(null, null) { + override def build(context: AuthenticationContext): KafkaPrincipal = { + AuthorizedPrincipal + } + } + + class TestPrincipalBuilderReturningUnauthorized extends DefaultKafkaPrincipalBuilder(null, null) { + override def build(context: AuthenticationContext): KafkaPrincipal = { + UnauthorizedPrincipal + } + } +} diff --git a/core/src/test/scala/unit/kafka/server/ApiVersionManagerTest.scala b/core/src/test/scala/unit/kafka/server/ApiVersionManagerTest.scala new file mode 100644 index 0000000..a93cc90 --- /dev/null +++ b/core/src/test/scala/unit/kafka/server/ApiVersionManagerTest.scala @@ -0,0 +1,115 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package kafka.server + +import kafka.api.ApiVersion +import org.apache.kafka.clients.NodeApiVersions +import org.apache.kafka.common.message.ApiMessageType.ListenerType +import org.apache.kafka.common.protocol.ApiKeys +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.params.ParameterizedTest +import org.junit.jupiter.params.provider.EnumSource +import org.mockito.Mockito + +import scala.jdk.CollectionConverters._ + +class ApiVersionManagerTest { + private val brokerFeatures = BrokerFeatures.createDefault() + private val featureCache = new FinalizedFeatureCache(brokerFeatures) + + @ParameterizedTest + @EnumSource(classOf[ListenerType]) + def testApiScope(apiScope: ListenerType): Unit = { + val versionManager = new DefaultApiVersionManager( + listenerType = apiScope, + interBrokerProtocolVersion = ApiVersion.latestVersion, + forwardingManager = None, + features = brokerFeatures, + featureCache = featureCache + ) + assertEquals(ApiKeys.apisForListener(apiScope).asScala, versionManager.enabledApis) + assertTrue(ApiKeys.apisForListener(apiScope).asScala.forall(versionManager.isApiEnabled)) + } + + @Test + def testControllerApiIntersection(): Unit = { + val controllerMinVersion: Short = 1 + val controllerMaxVersion: Short = 5 + + val forwardingManager = Mockito.mock(classOf[ForwardingManager]) + + Mockito.when(forwardingManager.controllerApiVersions).thenReturn(Some(NodeApiVersions.create( + ApiKeys.CREATE_TOPICS.id, + controllerMinVersion, + controllerMaxVersion + ))) + + val versionManager = new DefaultApiVersionManager( + listenerType = ListenerType.ZK_BROKER, + interBrokerProtocolVersion = ApiVersion.latestVersion, + forwardingManager = Some(forwardingManager), + features = brokerFeatures, + featureCache = featureCache + ) + + val apiVersionsResponse = versionManager.apiVersionResponse(throttleTimeMs = 0) + val alterConfigVersion = apiVersionsResponse.data.apiKeys.find(ApiKeys.CREATE_TOPICS.id) + assertNotNull(alterConfigVersion) + assertEquals(controllerMinVersion, alterConfigVersion.minVersion) + assertEquals(controllerMaxVersion, alterConfigVersion.maxVersion) + } + + @Test + def testEnvelopeEnabledWhenForwardingManagerPresent(): Unit = { + val forwardingManager = Mockito.mock(classOf[ForwardingManager]) + Mockito.when(forwardingManager.controllerApiVersions).thenReturn(None) + + val versionManager = new DefaultApiVersionManager( + listenerType = ListenerType.ZK_BROKER, + interBrokerProtocolVersion = ApiVersion.latestVersion, + forwardingManager = Some(forwardingManager), + features = brokerFeatures, + featureCache = featureCache + ) + assertTrue(versionManager.isApiEnabled(ApiKeys.ENVELOPE)) + assertTrue(versionManager.enabledApis.contains(ApiKeys.ENVELOPE)) + + val apiVersionsResponse = versionManager.apiVersionResponse(throttleTimeMs = 0) + val envelopeVersion = apiVersionsResponse.data.apiKeys.find(ApiKeys.ENVELOPE.id) + assertNotNull(envelopeVersion) + assertEquals(ApiKeys.ENVELOPE.oldestVersion, envelopeVersion.minVersion) + assertEquals(ApiKeys.ENVELOPE.latestVersion, envelopeVersion.maxVersion) + } + + @Test + def testEnvelopeDisabledWhenForwardingManagerEmpty(): Unit = { + val versionManager = new DefaultApiVersionManager( + listenerType = ListenerType.ZK_BROKER, + interBrokerProtocolVersion = ApiVersion.latestVersion, + forwardingManager = None, + features = brokerFeatures, + featureCache = featureCache + ) + assertFalse(versionManager.isApiEnabled(ApiKeys.ENVELOPE)) + 
assertFalse(versionManager.enabledApis.contains(ApiKeys.ENVELOPE)) + + val apiVersionsResponse = versionManager.apiVersionResponse(throttleTimeMs = 0) + assertNull(apiVersionsResponse.data.apiKeys.find(ApiKeys.ENVELOPE.id)) + } + +} diff --git a/core/src/test/scala/unit/kafka/server/ApiVersionsRequestTest.scala b/core/src/test/scala/unit/kafka/server/ApiVersionsRequestTest.scala new file mode 100644 index 0000000..34ee74a --- /dev/null +++ b/core/src/test/scala/unit/kafka/server/ApiVersionsRequestTest.scala @@ -0,0 +1,86 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.server + +import kafka.test.{ClusterConfig, ClusterInstance} +import org.apache.kafka.common.message.ApiVersionsRequestData +import org.apache.kafka.common.protocol.{ApiKeys, Errors} +import org.apache.kafka.common.requests.ApiVersionsRequest +import kafka.test.annotation.ClusterTest +import kafka.test.junit.ClusterTestExtensions +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.BeforeEach +import org.junit.jupiter.api.extension.ExtendWith + + +@ExtendWith(value = Array(classOf[ClusterTestExtensions])) +class ApiVersionsRequestTest(cluster: ClusterInstance) extends AbstractApiVersionsRequestTest(cluster) { + + @BeforeEach + def setup(config: ClusterConfig): Unit = { + super.brokerPropertyOverrides(config.serverProperties()) + } + + @ClusterTest + def testApiVersionsRequest(): Unit = { + val request = new ApiVersionsRequest.Builder().build() + val apiVersionsResponse = sendApiVersionsRequest(request, cluster.clientListener()) + validateApiVersionsResponse(apiVersionsResponse) + } + + @ClusterTest + def testApiVersionsRequestThroughControlPlaneListener(): Unit = { + val request = new ApiVersionsRequest.Builder().build() + val apiVersionsResponse = sendApiVersionsRequest(request, super.controlPlaneListenerName) + validateApiVersionsResponse(apiVersionsResponse) + } + + @ClusterTest + def testApiVersionsRequestWithUnsupportedVersion(): Unit = { + val apiVersionsRequest = new ApiVersionsRequest.Builder().build() + val apiVersionsResponse = sendUnsupportedApiVersionRequest(apiVersionsRequest) + assertEquals(Errors.UNSUPPORTED_VERSION.code(), apiVersionsResponse.data.errorCode()) + assertFalse(apiVersionsResponse.data.apiKeys().isEmpty) + val apiVersion = apiVersionsResponse.data.apiKeys().find(ApiKeys.API_VERSIONS.id) + assertEquals(ApiKeys.API_VERSIONS.id, apiVersion.apiKey()) + assertEquals(ApiKeys.API_VERSIONS.oldestVersion(), apiVersion.minVersion()) + assertEquals(ApiKeys.API_VERSIONS.latestVersion(), apiVersion.maxVersion()) + } + + @ClusterTest + def testApiVersionsRequestValidationV0(): Unit = { + val apiVersionsRequest = new ApiVersionsRequest.Builder().build(0.asInstanceOf[Short]) + val apiVersionsResponse = 
sendApiVersionsRequest(apiVersionsRequest, cluster.clientListener()) + validateApiVersionsResponse(apiVersionsResponse) + } + + @ClusterTest + def testApiVersionsRequestValidationV0ThroughControlPlaneListener(): Unit = { + val apiVersionsRequest = new ApiVersionsRequest.Builder().build(0.asInstanceOf[Short]) + val apiVersionsResponse = sendApiVersionsRequest(apiVersionsRequest, super.controlPlaneListenerName) + validateApiVersionsResponse(apiVersionsResponse) + } + + @ClusterTest + def testApiVersionsRequestValidationV3(): Unit = { + // Invalid request because Name and Version are empty by default + val apiVersionsRequest = new ApiVersionsRequest(new ApiVersionsRequestData(), 3.asInstanceOf[Short]) + val apiVersionsResponse = sendApiVersionsRequest(apiVersionsRequest, cluster.clientListener()) + assertEquals(Errors.INVALID_REQUEST.code(), apiVersionsResponse.data.errorCode()) + } +} diff --git a/core/src/test/scala/unit/kafka/server/AuthHelperTest.scala b/core/src/test/scala/unit/kafka/server/AuthHelperTest.scala new file mode 100644 index 0000000..194b2c6 --- /dev/null +++ b/core/src/test/scala/unit/kafka/server/AuthHelperTest.scala @@ -0,0 +1,142 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package kafka.server + +import java.net.InetAddress +import java.util +import org.apache.kafka.common.acl.AclOperation +import org.apache.kafka.common.network.{ClientInformation, ListenerName} +import org.apache.kafka.common.protocol.ApiKeys +import org.apache.kafka.common.requests.{RequestContext, RequestHeader} +import org.apache.kafka.common.resource.{PatternType, ResourcePattern, ResourceType} +import org.apache.kafka.common.security.auth.{KafkaPrincipal, SecurityProtocol} +import org.apache.kafka.server.authorizer.{Action, AuthorizationResult, Authorizer} +import org.easymock.EasyMock._ +import org.easymock.{EasyMock, IArgumentMatcher} +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.Test + +import scala.collection.Seq +import scala.jdk.CollectionConverters._ + +class AuthHelperTest { + import AuthHelperTest._ + + private val clientId = "" + + @Test + def testAuthorize(): Unit = { + val authorizer: Authorizer = EasyMock.niceMock(classOf[Authorizer]) + + val operation = AclOperation.WRITE + val resourceType = ResourceType.TOPIC + val resourceName = "topic-1" + val requestHeader = new RequestHeader(ApiKeys.PRODUCE, ApiKeys.PRODUCE.latestVersion, clientId, 0) + val requestContext = new RequestContext(requestHeader, "1", InetAddress.getLocalHost, + KafkaPrincipal.ANONYMOUS, ListenerName.forSecurityProtocol(SecurityProtocol.PLAINTEXT), + SecurityProtocol.PLAINTEXT, ClientInformation.EMPTY, false) + + val expectedActions = Seq( + new Action(operation, new ResourcePattern(resourceType, resourceName, PatternType.LITERAL), + 1, true, true) + ) + + EasyMock.expect(authorizer.authorize(requestContext, expectedActions.asJava)) + .andReturn(Seq(AuthorizationResult.ALLOWED).asJava) + .once() + + EasyMock.replay(authorizer) + + val result = new AuthHelper(Some(authorizer)).authorize( + requestContext, operation, resourceType, resourceName) + + verify(authorizer) + + assertEquals(true, result) + } + + @Test + def testFilterByAuthorized(): Unit = { + val authorizer: Authorizer = EasyMock.niceMock(classOf[Authorizer]) + + val operation = AclOperation.WRITE + val resourceType = ResourceType.TOPIC + val resourceName1 = "topic-1" + val resourceName2 = "topic-2" + val resourceName3 = "topic-3" + val requestHeader = new RequestHeader(ApiKeys.PRODUCE, ApiKeys.PRODUCE.latestVersion, + clientId, 0) + val requestContext = new RequestContext(requestHeader, "1", InetAddress.getLocalHost, + KafkaPrincipal.ANONYMOUS, ListenerName.forSecurityProtocol(SecurityProtocol.PLAINTEXT), + SecurityProtocol.PLAINTEXT, ClientInformation.EMPTY, false) + + val expectedActions = Seq( + new Action(operation, new ResourcePattern(resourceType, resourceName1, PatternType.LITERAL), + 2, true, true), + new Action(operation, new ResourcePattern(resourceType, resourceName2, PatternType.LITERAL), + 1, true, true), + new Action(operation, new ResourcePattern(resourceType, resourceName3, PatternType.LITERAL), + 1, true, true), + ) + + EasyMock.expect(authorizer.authorize( + EasyMock.eq(requestContext), matchSameElements(expectedActions.asJava) + )).andAnswer { () => + val actions = EasyMock.getCurrentArguments.apply(1).asInstanceOf[util.List[Action]].asScala + actions.map { action => + if (Set(resourceName1, resourceName3).contains(action.resourcePattern.name)) + AuthorizationResult.ALLOWED + else + AuthorizationResult.DENIED + }.asJava + }.once() + + EasyMock.replay(authorizer) + + val result = new AuthHelper(Some(authorizer)).filterByAuthorized( + requestContext, + operation, + resourceType, + // Duplicate 
resource names should not trigger multiple calls to authorize + Seq(resourceName1, resourceName2, resourceName1, resourceName3) + )(identity) + + verify(authorizer) + + assertEquals(Set(resourceName1, resourceName3), result) + } + +} + +object AuthHelperTest { + + /** + * Similar to `EasyMock.eq`, but matches if both lists have the same elements irrespective of ordering. + */ + def matchSameElements[T](list: java.util.List[T]): java.util.List[T] = { + EasyMock.reportMatcher(new IArgumentMatcher { + def matches(argument: Any): Boolean = argument match { + case l: java.util.List[_] => list.asScala.toSet == l.asScala.toSet + case _ => false + } + def appendTo(buffer: StringBuffer): Unit = buffer.append(s"list($list)") + }) + null + } + +} diff --git a/core/src/test/scala/unit/kafka/server/AutoTopicCreationManagerTest.scala b/core/src/test/scala/unit/kafka/server/AutoTopicCreationManagerTest.scala new file mode 100644 index 0000000..8c6b922 --- /dev/null +++ b/core/src/test/scala/unit/kafka/server/AutoTopicCreationManagerTest.scala @@ -0,0 +1,402 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package kafka.server + +import java.net.InetAddress +import java.nio.ByteBuffer +import java.util.concurrent.atomic.AtomicBoolean +import java.util.{Collections, Optional, Properties} + +import kafka.controller.KafkaController +import kafka.coordinator.group.GroupCoordinator +import kafka.coordinator.transaction.TransactionCoordinator +import kafka.utils.TestUtils +import org.apache.kafka.clients.{ClientResponse, NodeApiVersions, RequestCompletionHandler} +import org.apache.kafka.common.Node +import org.apache.kafka.common.internals.Topic +import org.apache.kafka.common.internals.Topic.{GROUP_METADATA_TOPIC_NAME, TRANSACTION_STATE_TOPIC_NAME} +import org.apache.kafka.common.message.{ApiVersionsResponseData, CreateTopicsRequestData} +import org.apache.kafka.common.message.CreateTopicsRequestData.CreatableTopic +import org.apache.kafka.common.message.MetadataResponseData.MetadataResponseTopic +import org.apache.kafka.common.network.{ClientInformation, ListenerName} +import org.apache.kafka.common.protocol.{ApiKeys, Errors} +import org.apache.kafka.common.requests._ +import org.apache.kafka.common.security.auth.{KafkaPrincipal, KafkaPrincipalSerde, SecurityProtocol} +import org.apache.kafka.common.utils.{SecurityUtils, Utils} +import org.junit.jupiter.api.Assertions.{assertEquals, assertThrows, assertTrue} +import org.junit.jupiter.api.{BeforeEach, Test} +import org.mockito.ArgumentMatchers.any +import org.mockito.invocation.InvocationOnMock +import org.mockito.{ArgumentCaptor, ArgumentMatchers, Mockito} + +import scala.collection.{Map, Seq} + +class AutoTopicCreationManagerTest { + + private val requestTimeout = 100 + private var config: KafkaConfig = _ + private val metadataCache = Mockito.mock(classOf[MetadataCache]) + private val brokerToController = Mockito.mock(classOf[BrokerToControllerChannelManager]) + private val adminManager = Mockito.mock(classOf[ZkAdminManager]) + private val controller = Mockito.mock(classOf[KafkaController]) + private val groupCoordinator = Mockito.mock(classOf[GroupCoordinator]) + private val transactionCoordinator = Mockito.mock(classOf[TransactionCoordinator]) + private var autoTopicCreationManager: AutoTopicCreationManager = _ + + private val internalTopicPartitions = 2 + private val internalTopicReplicationFactor: Short = 2 + + @BeforeEach + def setup(): Unit = { + val props = TestUtils.createBrokerConfig(1, "localhost") + props.setProperty(KafkaConfig.RequestTimeoutMsProp, requestTimeout.toString) + + props.setProperty(KafkaConfig.OffsetsTopicReplicationFactorProp, internalTopicPartitions.toString) + props.setProperty(KafkaConfig.TransactionsTopicReplicationFactorProp, internalTopicPartitions.toString) + + props.setProperty(KafkaConfig.OffsetsTopicPartitionsProp, internalTopicReplicationFactor.toString) + props.setProperty(KafkaConfig.TransactionsTopicPartitionsProp, internalTopicReplicationFactor.toString) + + config = KafkaConfig.fromProps(props) + val aliveBrokers = Seq(new Node(0, "host0", 0), new Node(1, "host1", 1)) + + Mockito.reset(metadataCache, controller, brokerToController, groupCoordinator, transactionCoordinator) + + Mockito.when(metadataCache.getAliveBrokerNodes(any(classOf[ListenerName]))).thenReturn(aliveBrokers) + } + + @Test + def testCreateOffsetTopic(): Unit = { + Mockito.when(groupCoordinator.offsetsTopicConfigs).thenReturn(new Properties) + testCreateTopic(GROUP_METADATA_TOPIC_NAME, true, internalTopicPartitions, internalTopicReplicationFactor) + } + + @Test + def testCreateTxnTopic(): Unit = { + 
Mockito.when(transactionCoordinator.transactionTopicConfigs).thenReturn(new Properties) + testCreateTopic(TRANSACTION_STATE_TOPIC_NAME, true, internalTopicPartitions, internalTopicReplicationFactor) + } + + @Test + def testCreateNonInternalTopic(): Unit = { + testCreateTopic("topic", false) + } + + private def testCreateTopic(topicName: String, + isInternal: Boolean, + numPartitions: Int = 1, + replicationFactor: Short = 1): Unit = { + autoTopicCreationManager = new DefaultAutoTopicCreationManager( + config, + Some(brokerToController), + Some(adminManager), + Some(controller), + groupCoordinator, + transactionCoordinator) + + val topicsCollection = new CreateTopicsRequestData.CreatableTopicCollection + topicsCollection.add(getNewTopic(topicName, numPartitions, replicationFactor)) + val requestBody = new CreateTopicsRequest.Builder( + new CreateTopicsRequestData() + .setTopics(topicsCollection) + .setTimeoutMs(requestTimeout)) + + Mockito.when(controller.isActive).thenReturn(false) + + // Calling twice with the same topic will only trigger one forwarding. + createTopicAndVerifyResult(Errors.UNKNOWN_TOPIC_OR_PARTITION, topicName, isInternal) + createTopicAndVerifyResult(Errors.UNKNOWN_TOPIC_OR_PARTITION, topicName, isInternal) + + Mockito.verify(brokerToController).sendRequest( + ArgumentMatchers.eq(requestBody), + any(classOf[ControllerRequestCompletionHandler])) + } + + @Test + def testCreateTopicsWithForwardingDisabled(): Unit = { + autoTopicCreationManager = new DefaultAutoTopicCreationManager( + config, + None, + Some(adminManager), + Some(controller), + groupCoordinator, + transactionCoordinator) + + val topicName = "topic" + + Mockito.when(controller.isActive).thenReturn(false) + + createTopicAndVerifyResult(Errors.UNKNOWN_TOPIC_OR_PARTITION, topicName, false) + + Mockito.verify(adminManager).createTopics( + ArgumentMatchers.eq(0), + ArgumentMatchers.eq(false), + ArgumentMatchers.eq(Map(topicName -> getNewTopic(topicName))), + ArgumentMatchers.eq(Map.empty), + any(classOf[ControllerMutationQuota]), + any(classOf[Map[String, ApiError] => Unit])) + } + + @Test + def testInvalidReplicationFactorForNonInternalTopics(): Unit = { + testErrorWithCreationInZk(Errors.INVALID_REPLICATION_FACTOR, "topic", isInternal = false) + } + + @Test + def testInvalidReplicationFactorForConsumerOffsetsTopic(): Unit = { + Mockito.when(groupCoordinator.offsetsTopicConfigs).thenReturn(new Properties) + testErrorWithCreationInZk(Errors.INVALID_REPLICATION_FACTOR, Topic.GROUP_METADATA_TOPIC_NAME, isInternal = true) + } + + @Test + def testInvalidReplicationFactorForTxnOffsetTopic(): Unit = { + Mockito.when(transactionCoordinator.transactionTopicConfigs).thenReturn(new Properties) + testErrorWithCreationInZk(Errors.INVALID_REPLICATION_FACTOR, Topic.TRANSACTION_STATE_TOPIC_NAME, isInternal = true) + } + + @Test + def testTopicExistsErrorSwapForNonInternalTopics(): Unit = { + testErrorWithCreationInZk(Errors.TOPIC_ALREADY_EXISTS, "topic", isInternal = false, + expectedError = Some(Errors.LEADER_NOT_AVAILABLE)) + } + + @Test + def testTopicExistsErrorSwapForConsumerOffsetsTopic(): Unit = { + Mockito.when(groupCoordinator.offsetsTopicConfigs).thenReturn(new Properties) + testErrorWithCreationInZk(Errors.TOPIC_ALREADY_EXISTS, Topic.GROUP_METADATA_TOPIC_NAME, isInternal = true, + expectedError = Some(Errors.LEADER_NOT_AVAILABLE)) + } + + @Test + def testTopicExistsErrorSwapForTxnOffsetTopic(): Unit = { + Mockito.when(transactionCoordinator.transactionTopicConfigs).thenReturn(new Properties) + 
testErrorWithCreationInZk(Errors.TOPIC_ALREADY_EXISTS, Topic.TRANSACTION_STATE_TOPIC_NAME, isInternal = true, + expectedError = Some(Errors.LEADER_NOT_AVAILABLE)) + } + + @Test + def testRequestTimeoutErrorSwapForNonInternalTopics(): Unit = { + testErrorWithCreationInZk(Errors.REQUEST_TIMED_OUT, "topic", isInternal = false, + expectedError = Some(Errors.LEADER_NOT_AVAILABLE)) + } + + @Test + def testRequestTimeoutErrorSwapForConsumerOffsetTopic(): Unit = { + Mockito.when(groupCoordinator.offsetsTopicConfigs).thenReturn(new Properties) + testErrorWithCreationInZk(Errors.REQUEST_TIMED_OUT, Topic.GROUP_METADATA_TOPIC_NAME, isInternal = true, + expectedError = Some(Errors.LEADER_NOT_AVAILABLE)) + } + + @Test + def testRequestTimeoutErrorSwapForTxnOffsetTopic(): Unit = { + Mockito.when(transactionCoordinator.transactionTopicConfigs).thenReturn(new Properties) + testErrorWithCreationInZk(Errors.REQUEST_TIMED_OUT, Topic.TRANSACTION_STATE_TOPIC_NAME, isInternal = true, + expectedError = Some(Errors.LEADER_NOT_AVAILABLE)) + } + + @Test + def testUnknownTopicPartitionForNonIntervalTopic(): Unit = { + testErrorWithCreationInZk(Errors.UNKNOWN_TOPIC_OR_PARTITION, "topic", isInternal = false) + } + + @Test + def testUnknownTopicPartitionForConsumerOffsetTopic(): Unit = { + Mockito.when(groupCoordinator.offsetsTopicConfigs).thenReturn(new Properties) + testErrorWithCreationInZk(Errors.UNKNOWN_TOPIC_OR_PARTITION, Topic.GROUP_METADATA_TOPIC_NAME, isInternal = true) + } + + @Test + def testUnknownTopicPartitionForTxnOffsetTopic(): Unit = { + Mockito.when(transactionCoordinator.transactionTopicConfigs).thenReturn(new Properties) + testErrorWithCreationInZk(Errors.UNKNOWN_TOPIC_OR_PARTITION, Topic.TRANSACTION_STATE_TOPIC_NAME, isInternal = true) + } + + @Test + def testTopicCreationWithMetadataContextPassPrincipal(): Unit = { + val topicName = "topic" + + val userPrincipal = new KafkaPrincipal(KafkaPrincipal.USER_TYPE, "user") + val serializeIsCalled = new AtomicBoolean(false) + val principalSerde = new KafkaPrincipalSerde { + override def serialize(principal: KafkaPrincipal): Array[Byte] = { + assertEquals(principal, userPrincipal) + serializeIsCalled.set(true) + Utils.utf8(principal.toString) + } + override def deserialize(bytes: Array[Byte]): KafkaPrincipal = SecurityUtils.parseKafkaPrincipal(Utils.utf8(bytes)) + } + + val requestContext = initializeRequestContext(topicName, userPrincipal, Optional.of(principalSerde)) + + autoTopicCreationManager.createTopics( + Set(topicName), UnboundedControllerMutationQuota, Some(requestContext)) + + assertTrue(serializeIsCalled.get()) + + val argumentCaptor = ArgumentCaptor.forClass(classOf[AbstractRequest.Builder[_ <: AbstractRequest]]) + Mockito.verify(brokerToController).sendRequest( + argumentCaptor.capture(), + any(classOf[ControllerRequestCompletionHandler])) + val capturedRequest = argumentCaptor.getValue.asInstanceOf[EnvelopeRequest.Builder].build(ApiKeys.ENVELOPE.latestVersion()) + assertEquals(userPrincipal, SecurityUtils.parseKafkaPrincipal(Utils.utf8(capturedRequest.requestPrincipal))) + } + + @Test + def testTopicCreationWithMetadataContextWhenPrincipalSerdeNotDefined(): Unit = { + val topicName = "topic" + + val requestContext = initializeRequestContext(topicName, KafkaPrincipal.ANONYMOUS, Optional.empty()) + + // Throw upon undefined principal serde when building the forward request + assertThrows(classOf[IllegalArgumentException], () => autoTopicCreationManager.createTopics( + Set(topicName), UnboundedControllerMutationQuota, Some(requestContext))) + } 
+ + @Test + def testTopicCreationWithMetadataContextNoRetryUponUnsupportedVersion(): Unit = { + val topicName = "topic" + + val principalSerde = new KafkaPrincipalSerde { + override def serialize(principal: KafkaPrincipal): Array[Byte] = { + Utils.utf8(principal.toString) + } + override def deserialize(bytes: Array[Byte]): KafkaPrincipal = SecurityUtils.parseKafkaPrincipal(Utils.utf8(bytes)) + } + + val requestContext = initializeRequestContext(topicName, KafkaPrincipal.ANONYMOUS, Optional.of(principalSerde)) + autoTopicCreationManager.createTopics( + Set(topicName), UnboundedControllerMutationQuota, Some(requestContext)) + autoTopicCreationManager.createTopics( + Set(topicName), UnboundedControllerMutationQuota, Some(requestContext)) + + // Should only trigger once + val argumentCaptor = ArgumentCaptor.forClass(classOf[ControllerRequestCompletionHandler]) + Mockito.verify(brokerToController).sendRequest( + any(classOf[AbstractRequest.Builder[_ <: AbstractRequest]]), + argumentCaptor.capture()) + + // Complete with unsupported version will not trigger a retry, but cleanup the inflight topics instead + val header = new RequestHeader(ApiKeys.ENVELOPE, 0, "client", 1) + val response = new EnvelopeResponse(ByteBuffer.allocate(0), Errors.UNSUPPORTED_VERSION) + val clientResponse = new ClientResponse(header, null, null, + 0, 0, false, null, null, response) + argumentCaptor.getValue.asInstanceOf[RequestCompletionHandler].onComplete(clientResponse) + Mockito.verify(brokerToController, Mockito.times(1)).sendRequest( + any(classOf[AbstractRequest.Builder[_ <: AbstractRequest]]), + argumentCaptor.capture()) + + // Could do the send again as inflight topics are cleared. + autoTopicCreationManager.createTopics( + Set(topicName), UnboundedControllerMutationQuota, Some(requestContext)) + Mockito.verify(brokerToController, Mockito.times(2)).sendRequest( + any(classOf[AbstractRequest.Builder[_ <: AbstractRequest]]), + argumentCaptor.capture()) + } + + private def initializeRequestContext(topicName: String, + kafkaPrincipal: KafkaPrincipal, + principalSerde: Optional[KafkaPrincipalSerde]): RequestContext = { + + autoTopicCreationManager = new DefaultAutoTopicCreationManager( + config, + Some(brokerToController), + Some(adminManager), + Some(controller), + groupCoordinator, + transactionCoordinator) + + val topicsCollection = new CreateTopicsRequestData.CreatableTopicCollection + topicsCollection.add(getNewTopic(topicName)) + val createTopicApiVersion = new ApiVersionsResponseData.ApiVersion() + .setApiKey(ApiKeys.CREATE_TOPICS.id) + .setMinVersion(0) + .setMaxVersion(0) + Mockito.when(brokerToController.controllerApiVersions()) + .thenReturn(Some(NodeApiVersions.create(Collections.singleton(createTopicApiVersion)))) + + Mockito.when(controller.isActive).thenReturn(false) + + val requestHeader = new RequestHeader(ApiKeys.METADATA, ApiKeys.METADATA.latestVersion, + "clientId", 0) + new RequestContext(requestHeader, "1", InetAddress.getLocalHost, + kafkaPrincipal, ListenerName.forSecurityProtocol(SecurityProtocol.PLAINTEXT), + SecurityProtocol.PLAINTEXT, ClientInformation.EMPTY, false, principalSerde) + } + + private def testErrorWithCreationInZk(error: Errors, + topicName: String, + isInternal: Boolean, + expectedError: Option[Errors] = None): Unit = { + autoTopicCreationManager = new DefaultAutoTopicCreationManager( + config, + None, + Some(adminManager), + Some(controller), + groupCoordinator, + transactionCoordinator) + + Mockito.when(controller.isActive).thenReturn(false) + val newTopic = if (isInternal) 
{ + topicName match { + case Topic.GROUP_METADATA_TOPIC_NAME => getNewTopic(topicName, + numPartitions = config.offsetsTopicPartitions, replicationFactor = config.offsetsTopicReplicationFactor) + case Topic.TRANSACTION_STATE_TOPIC_NAME => getNewTopic(topicName, + numPartitions = config.transactionTopicPartitions, replicationFactor = config.transactionTopicReplicationFactor) + } + } else { + getNewTopic(topicName) + } + + val topicErrors = if (error == Errors.UNKNOWN_TOPIC_OR_PARTITION) null else + Map(topicName -> new ApiError(error)) + Mockito.when(adminManager.createTopics( + ArgumentMatchers.eq(0), + ArgumentMatchers.eq(false), + ArgumentMatchers.eq(Map(topicName -> newTopic)), + ArgumentMatchers.eq(Map.empty), + any(classOf[ControllerMutationQuota]), + any(classOf[Map[String, ApiError] => Unit]))).thenAnswer((invocation: InvocationOnMock) => { + invocation.getArgument(5).asInstanceOf[Map[String, ApiError] => Unit] + .apply(topicErrors) + }) + + createTopicAndVerifyResult(expectedError.getOrElse(error), topicName, isInternal = isInternal) + } + + private def createTopicAndVerifyResult(error: Errors, + topicName: String, + isInternal: Boolean, + metadataContext: Option[RequestContext] = None): Unit = { + val topicResponses = autoTopicCreationManager.createTopics( + Set(topicName), UnboundedControllerMutationQuota, metadataContext) + + val expectedResponses = Seq(new MetadataResponseTopic() + .setErrorCode(error.code()) + .setIsInternal(isInternal) + .setName(topicName)) + + assertEquals(expectedResponses, topicResponses) + } + + private def getNewTopic(topicName: String, numPartitions: Int = 1, replicationFactor: Short = 1): CreatableTopic = { + new CreatableTopic() + .setName(topicName) + .setNumPartitions(numPartitions) + .setReplicationFactor(replicationFactor) + } +} diff --git a/core/src/test/scala/unit/kafka/server/BaseClientQuotaManagerTest.scala b/core/src/test/scala/unit/kafka/server/BaseClientQuotaManagerTest.scala new file mode 100644 index 0000000..526e6b5 --- /dev/null +++ b/core/src/test/scala/unit/kafka/server/BaseClientQuotaManagerTest.scala @@ -0,0 +1,85 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package kafka.server + +import java.net.InetAddress +import java.util +import java.util.Collections + +import kafka.network.RequestChannel +import kafka.network.RequestChannel.Session +import org.apache.kafka.common.TopicPartition +import org.apache.kafka.common.memory.MemoryPool +import org.apache.kafka.common.metrics.{MetricConfig, Metrics} +import org.apache.kafka.common.network.{ClientInformation, ListenerName} +import org.apache.kafka.common.protocol.ApiKeys +import org.apache.kafka.common.requests.FetchRequest.PartitionData +import org.apache.kafka.common.requests.{AbstractRequest, FetchRequest, RequestContext, RequestHeader} +import org.apache.kafka.common.security.auth.{KafkaPrincipal, SecurityProtocol} +import org.apache.kafka.common.utils.MockTime +import org.easymock.EasyMock +import org.junit.jupiter.api.AfterEach + +class BaseClientQuotaManagerTest { + protected val time = new MockTime + protected var numCallbacks: Int = 0 + protected val metrics = new Metrics(new MetricConfig(), Collections.emptyList(), time) + + @AfterEach + def tearDown(): Unit = { + metrics.close() + } + + protected def callback: ThrottleCallback = new ThrottleCallback { + override def startThrottling(): Unit = {} + override def endThrottling(): Unit = { + // Count how many times this callback is called for notifyThrottlingDone(). + numCallbacks += 1 + } + } + + protected def buildRequest[T <: AbstractRequest](builder: AbstractRequest.Builder[T], + listenerName: ListenerName = ListenerName.forSecurityProtocol(SecurityProtocol.PLAINTEXT)): (T, RequestChannel.Request) = { + + val request = builder.build() + val buffer = request.serializeWithHeader(new RequestHeader(builder.apiKey, request.version, "", 0)) + val requestChannelMetrics: RequestChannel.Metrics = EasyMock.createNiceMock(classOf[RequestChannel.Metrics]) + + // read the header from the buffer first so that the body can be read next from the Request constructor + val header = RequestHeader.parse(buffer) + val context = new RequestContext(header, "1", InetAddress.getLocalHost, KafkaPrincipal.ANONYMOUS, + listenerName, SecurityProtocol.PLAINTEXT, ClientInformation.EMPTY, false) + (request, new RequestChannel.Request(processor = 1, context = context, startTimeNanos = 0, MemoryPool.NONE, buffer, + requestChannelMetrics)) + } + + protected def buildSession(user: String): Session = { + val principal = new KafkaPrincipal(KafkaPrincipal.USER_TYPE, user) + Session(principal, null) + } + + protected def maybeRecord(quotaManager: ClientQuotaManager, user: String, clientId: String, value: Double): Int = { + + quotaManager.maybeRecordAndGetThrottleTimeMs(buildSession(user), clientId, value, time.milliseconds) + } + + protected def throttle(quotaManager: ClientQuotaManager, user: String, clientId: String, throttleTimeMs: Int, + channelThrottlingCallback: ThrottleCallback): Unit = { + val (_, request) = buildRequest(FetchRequest.Builder.forConsumer(ApiKeys.FETCH.latestVersion, 0, 1000, new util.HashMap[TopicPartition, PartitionData])) + quotaManager.throttle(request, channelThrottlingCallback, throttleTimeMs) + } +} diff --git a/core/src/test/scala/unit/kafka/server/BaseFetchRequestTest.scala b/core/src/test/scala/unit/kafka/server/BaseFetchRequestTest.scala new file mode 100644 index 0000000..8ef7424 --- /dev/null +++ b/core/src/test/scala/unit/kafka/server/BaseFetchRequestTest.scala @@ -0,0 +1,103 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package kafka.server + +import kafka.log.LogConfig +import kafka.utils.TestUtils +import org.apache.kafka.clients.producer.{KafkaProducer, ProducerRecord, RecordMetadata} +import org.apache.kafka.common.{TopicPartition, Uuid} +import org.apache.kafka.common.message.FetchResponseData +import org.apache.kafka.common.record.Record +import org.apache.kafka.common.requests.{FetchRequest, FetchResponse} +import org.apache.kafka.common.serialization.StringSerializer +import org.junit.jupiter.api.AfterEach +import java.util +import java.util.{Optional, Properties} + +import scala.collection.Seq +import scala.jdk.CollectionConverters._ + +class BaseFetchRequestTest extends BaseRequestTest { + + protected var producer: KafkaProducer[String, String] = null + + override def brokerPropertyOverrides(properties: Properties): Unit = { + properties.put(KafkaConfig.FetchMaxBytes, Int.MaxValue.toString) + } + + @AfterEach + override def tearDown(): Unit = { + if (producer != null) + producer.close() + super.tearDown() + } + + protected def createFetchRequest(maxResponseBytes: Int, maxPartitionBytes: Int, topicPartitions: Seq[TopicPartition], + offsetMap: Map[TopicPartition, Long], + version: Short): FetchRequest = { + FetchRequest.Builder.forConsumer(version, Int.MaxValue, 0, createPartitionMap(maxPartitionBytes, topicPartitions, offsetMap)) + .setMaxBytes(maxResponseBytes).build() + } + + protected def createPartitionMap(maxPartitionBytes: Int, topicPartitions: Seq[TopicPartition], + offsetMap: Map[TopicPartition, Long] = Map.empty): util.LinkedHashMap[TopicPartition, FetchRequest.PartitionData] = { + val partitionMap = new util.LinkedHashMap[TopicPartition, FetchRequest.PartitionData] + topicPartitions.foreach { tp => + partitionMap.put(tp, + new FetchRequest.PartitionData(getTopicIds().getOrElse(tp.topic, Uuid.ZERO_UUID), offsetMap.getOrElse(tp, 0), 0L, maxPartitionBytes, + Optional.empty())) + } + partitionMap + } + + protected def sendFetchRequest(leaderId: Int, request: FetchRequest): FetchResponse = { + connectAndReceive[FetchResponse](request, destination = brokerSocketServer(leaderId)) + } + + protected def initProducer(): Unit = { + producer = TestUtils.createProducer(TestUtils.getBrokerListStrFromServers(servers), + keySerializer = new StringSerializer, valueSerializer = new StringSerializer) + } + + protected def createTopics(numTopics: Int, numPartitions: Int, configs: Map[String, String] = Map.empty): Map[TopicPartition, Int] = { + val topics = (0 until numTopics).map(t => s"topic$t") + val topicConfig = new Properties + topicConfig.setProperty(LogConfig.MinInSyncReplicasProp, 2.toString) + configs.foreach { case (k, v) => topicConfig.setProperty(k, v) } + topics.flatMap { topic => + val partitionToLeader = createTopic(topic, numPartitions = numPartitions, replicationFactor = 2, + topicConfig = 
topicConfig) + partitionToLeader.map { case (partition, leader) => new TopicPartition(topic, partition) -> leader } + }.toMap + } + + protected def produceData(topicPartitions: Iterable[TopicPartition], numMessagesPerPartition: Int): Seq[RecordMetadata] = { + val records = for { + tp <- topicPartitions.toSeq + messageIndex <- 0 until numMessagesPerPartition + } yield { + val suffix = s"$tp-$messageIndex" + new ProducerRecord(tp.topic, tp.partition, s"key $suffix", s"value $suffix") + } + records.map(producer.send(_).get) + } + + protected def records(partitionData: FetchResponseData.PartitionData): Seq[Record] = { + FetchResponse.recordsOrFail(partitionData).records.asScala.toBuffer + } + +} diff --git a/core/src/test/scala/unit/kafka/server/BaseRequestTest.scala b/core/src/test/scala/unit/kafka/server/BaseRequestTest.scala new file mode 100644 index 0000000..3d3d0ca --- /dev/null +++ b/core/src/test/scala/unit/kafka/server/BaseRequestTest.scala @@ -0,0 +1,153 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package kafka.server + +import kafka.api.IntegrationTestHarness +import kafka.network.SocketServer +import kafka.utils.NotNothing +import org.apache.kafka.common.network.ListenerName +import org.apache.kafka.common.protocol.ApiKeys +import org.apache.kafka.common.requests.{AbstractRequest, AbstractResponse, RequestHeader, ResponseHeader} +import org.apache.kafka.common.utils.Utils +import org.apache.kafka.metadata.BrokerState + +import java.io.{DataInputStream, DataOutputStream} +import java.net.Socket +import java.nio.ByteBuffer +import java.util.Properties +import scala.annotation.nowarn +import scala.collection.Seq +import scala.reflect.ClassTag + +abstract class BaseRequestTest extends IntegrationTestHarness { + private var correlationId = 0 + + // If required, set number of brokers + override def brokerCount: Int = 3 + + // If required, override properties by mutating the passed Properties object + protected def brokerPropertyOverrides(properties: Properties): Unit = {} + + override def modifyConfigs(props: Seq[Properties]): Unit = { + props.foreach { p => + p.put(KafkaConfig.ControlledShutdownEnableProp, "false") + brokerPropertyOverrides(p) + } + } + + def anySocketServer: SocketServer = { + servers.find { server => + val state = server.brokerState + state != BrokerState.NOT_RUNNING && state != BrokerState.SHUTTING_DOWN + }.map(_.socketServer).getOrElse(throw new IllegalStateException("No live broker is available")) + } + + def controllerSocketServer: SocketServer = { + servers.find { server => + server.kafkaController.isActive + }.map(_.socketServer).getOrElse(throw new IllegalStateException("No controller broker is available")) + } + + def notControllerSocketServer: SocketServer = { + servers.find { server => + !server.kafkaController.isActive + }.map(_.socketServer).getOrElse(throw new IllegalStateException("No non-controller broker is available")) + } + + def brokerSocketServer(brokerId: Int): SocketServer = { + servers.find { server => + server.config.brokerId == brokerId + }.map(_.socketServer).getOrElse(throw new IllegalStateException(s"Could not find broker with id $brokerId")) + } + + def connect(socketServer: SocketServer = anySocketServer, + listenerName: ListenerName = listenerName): Socket = { + new Socket("localhost", socketServer.boundPort(listenerName)) + } + + private def sendRequest(socket: Socket, request: Array[Byte]): Unit = { + val outgoing = new DataOutputStream(socket.getOutputStream) + outgoing.writeInt(request.length) + outgoing.write(request) + outgoing.flush() + } + + def receive[T <: AbstractResponse](socket: Socket, apiKey: ApiKeys, version: Short) + (implicit classTag: ClassTag[T], @nowarn("cat=unused") nn: NotNothing[T]): T = { + val incoming = new DataInputStream(socket.getInputStream) + val len = incoming.readInt() + + val responseBytes = new Array[Byte](len) + incoming.readFully(responseBytes) + + val responseBuffer = ByteBuffer.wrap(responseBytes) + ResponseHeader.parse(responseBuffer, apiKey.responseHeaderVersion(version)) + + AbstractResponse.parseResponse(apiKey, responseBuffer, version) match { + case response: T => response + case response => + throw new ClassCastException(s"Expected response with type ${classTag.runtimeClass}, but found ${response.getClass}") + } + } + + def sendAndReceive[T <: AbstractResponse](request: AbstractRequest, + socket: Socket, + clientId: String = "client-id", + correlationId: Option[Int] = None) + (implicit classTag: ClassTag[T], nn: NotNothing[T]): T = { + send(request, socket, clientId, correlationId) 
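+    // The response is then read back from the same socket; receive() uses the request's apiKey and version to pick the parser.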
+ receive[T](socket, request.apiKey, request.version) + } + + def connectAndReceive[T <: AbstractResponse](request: AbstractRequest, + destination: SocketServer = anySocketServer, + listenerName: ListenerName = listenerName) + (implicit classTag: ClassTag[T], nn: NotNothing[T]): T = { + val socket = connect(destination, listenerName) + try sendAndReceive[T](request, socket) + finally socket.close() + } + + /** + * Serializes and sends the request to the given api. + */ + def send(request: AbstractRequest, + socket: Socket, + clientId: String = "client-id", + correlationId: Option[Int] = None): Unit = { + val header = nextRequestHeader(request.apiKey, request.version, clientId, correlationId) + sendWithHeader(request, header, socket) + } + + def sendWithHeader(request: AbstractRequest, header: RequestHeader, socket: Socket): Unit = { + val serializedBytes = Utils.toArray(request.serializeWithHeader(header)) + sendRequest(socket, serializedBytes) + } + + def nextRequestHeader[T <: AbstractResponse](apiKey: ApiKeys, + apiVersion: Short, + clientId: String = "client-id", + correlationIdOpt: Option[Int] = None): RequestHeader = { + val correlationId = correlationIdOpt.getOrElse { + this.correlationId += 1 + this.correlationId + } + new RequestHeader(apiKey, apiVersion, clientId, correlationId) + } + +} diff --git a/core/src/test/scala/unit/kafka/server/BrokerEpochIntegrationTest.scala b/core/src/test/scala/unit/kafka/server/BrokerEpochIntegrationTest.scala new file mode 100755 index 0000000..0e1b148 --- /dev/null +++ b/core/src/test/scala/unit/kafka/server/BrokerEpochIntegrationTest.scala @@ -0,0 +1,283 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package kafka.server + +import java.util.Collections + +import kafka.api.LeaderAndIsr +import kafka.cluster.Broker +import kafka.controller.{ControllerChannelManager, ControllerContext, StateChangeLogger} +import kafka.utils.TestUtils +import kafka.utils.TestUtils.createTopic +import kafka.server.QuorumTestHarness +import org.apache.kafka.common.TopicPartition +import org.apache.kafka.common.message.LeaderAndIsrRequestData.LeaderAndIsrPartitionState +import org.apache.kafka.common.message.StopReplicaRequestData.{StopReplicaPartitionState, StopReplicaTopicState} +import org.apache.kafka.common.message.UpdateMetadataRequestData.{UpdateMetadataBroker, UpdateMetadataEndpoint, UpdateMetadataPartitionState} +import org.apache.kafka.common.metrics.Metrics +import org.apache.kafka.common.network.ListenerName +import org.apache.kafka.common.protocol.{ApiKeys, Errors} +import org.apache.kafka.common.requests._ +import org.apache.kafka.common.security.auth.SecurityProtocol +import org.apache.kafka.common.utils.Time +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.{AfterEach, BeforeEach, Test, TestInfo} + +import scala.jdk.CollectionConverters._ + +class BrokerEpochIntegrationTest extends QuorumTestHarness { + val brokerId1 = 0 + val brokerId2 = 1 + + var servers: Seq[KafkaServer] = Seq.empty[KafkaServer] + + @BeforeEach + override def setUp(testInfo: TestInfo): Unit = { + super.setUp(testInfo) + val configs = Seq( + TestUtils.createBrokerConfig(brokerId1, zkConnect), + TestUtils.createBrokerConfig(brokerId2, zkConnect)) + + configs.foreach { config => + config.setProperty(KafkaConfig.AutoLeaderRebalanceEnableProp, false.toString)} + + // start both servers + servers = configs.map(config => TestUtils.createServer(KafkaConfig.fromProps(config))) + } + + @AfterEach + override def tearDown(): Unit = { + TestUtils.shutdownServers(servers) + super.tearDown() + } + + @Test + def testReplicaManagerBrokerEpochMatchesWithZk(): Unit = { + val brokerAndEpochs = zkClient.getAllBrokerAndEpochsInCluster + assertEquals(brokerAndEpochs.size, servers.size) + brokerAndEpochs.foreach { + case (broker, epoch) => + val brokerServer = servers.find(e => e.config.brokerId == broker.id) + assertTrue(brokerServer.isDefined) + assertEquals(epoch, brokerServer.get.kafkaController.brokerEpoch) + } + } + + @Test + def testControllerBrokerEpochCacheMatchesWithZk(): Unit = { + val controller = getController + val otherBroker = servers.find(e => e.config.brokerId != controller.config.brokerId).get + + // Broker epochs cache matches with zk in steady state + checkControllerBrokerEpochsCacheMatchesWithZk(controller.kafkaController.controllerContext) + + // Shutdown a broker and make sure broker epochs cache still matches with zk state + otherBroker.shutdown() + checkControllerBrokerEpochsCacheMatchesWithZk(controller.kafkaController.controllerContext) + + // Restart a broker and make sure broker epochs cache still matches with zk state + otherBroker.startup() + checkControllerBrokerEpochsCacheMatchesWithZk(controller.kafkaController.controllerContext) + } + + @Test + def testControlRequestWithCorrectBrokerEpoch(): Unit = { + testControlRequestWithBrokerEpoch(0) + } + + @Test + def testControlRequestWithStaleBrokerEpoch(): Unit = { + testControlRequestWithBrokerEpoch(-1) + } + + @Test + def testControlRequestWithNewerBrokerEpoch(): Unit = { + testControlRequestWithBrokerEpoch(1) + } + + private def testControlRequestWithBrokerEpoch(epochInRequestDiffFromCurrentEpoch: Long): Unit = { + val tp = new 
TopicPartition("new-topic", 0) + + // create topic with 1 partition, 2 replicas, one on each broker + createTopic(zkClient, tp.topic(), partitionReplicaAssignment = Map(0 -> Seq(brokerId1, brokerId2)), servers = servers) + val topicIds = getController.kafkaController.controllerContext.topicIds.toMap.asJava + + val controllerId = 2 + val controllerEpoch = zkClient.getControllerEpoch.get._1 + + val controllerConfig = KafkaConfig.fromProps(TestUtils.createBrokerConfig(controllerId, zkConnect)) + val securityProtocol = SecurityProtocol.PLAINTEXT + val listenerName = ListenerName.forSecurityProtocol(securityProtocol) + val brokerAndEpochs = servers.map(s => + (new Broker(s.config.brokerId, "localhost", TestUtils.boundPort(s), listenerName, securityProtocol), + s.kafkaController.brokerEpoch)).toMap + val nodes = brokerAndEpochs.keys.map(_.node(listenerName)) + + val controllerContext = new ControllerContext + controllerContext.setLiveBrokers(brokerAndEpochs) + val metrics = new Metrics + val controllerChannelManager = new ControllerChannelManager(controllerContext, controllerConfig, Time.SYSTEM, + metrics, new StateChangeLogger(controllerId, inControllerContext = true, None)) + controllerChannelManager.startup() + + val broker2 = servers(brokerId2) + val epochInRequest = broker2.kafkaController.brokerEpoch + epochInRequestDiffFromCurrentEpoch + + try { + // Send LeaderAndIsr request with correct broker epoch + { + val partitionStates = Seq( + new LeaderAndIsrPartitionState() + .setTopicName(tp.topic) + .setPartitionIndex(tp.partition) + .setControllerEpoch(controllerEpoch) + .setLeader(brokerId2) + .setLeaderEpoch(LeaderAndIsr.initialLeaderEpoch + 1) + .setIsr(Seq(brokerId1, brokerId2).map(Integer.valueOf).asJava) + .setZkVersion(LeaderAndIsr.initialZKVersion) + .setReplicas(Seq(0, 1).map(Integer.valueOf).asJava) + .setIsNew(false) + ) + val requestBuilder = new LeaderAndIsrRequest.Builder( + ApiKeys.LEADER_AND_ISR.latestVersion, controllerId, controllerEpoch, + epochInRequest, + partitionStates.asJava, topicIds, nodes.toSet.asJava) + + if (epochInRequestDiffFromCurrentEpoch < 0) { + // stale broker epoch in LEADER_AND_ISR + sendAndVerifyStaleBrokerEpochInResponse(controllerChannelManager, requestBuilder) + } + else { + // broker epoch in LEADER_AND_ISR >= current broker epoch + sendAndVerifySuccessfulResponse(controllerChannelManager, requestBuilder) + TestUtils.waitUntilLeaderIsKnown(Seq(broker2), tp, 10000) + } + } + + // Send UpdateMetadata request with correct broker epoch + { + val partitionStates = Seq( + new UpdateMetadataPartitionState() + .setTopicName(tp.topic) + .setPartitionIndex(tp.partition) + .setControllerEpoch(controllerEpoch) + .setLeader(brokerId2) + .setLeaderEpoch(LeaderAndIsr.initialLeaderEpoch + 1) + .setIsr(Seq(brokerId1, brokerId2).map(Integer.valueOf).asJava) + .setZkVersion(LeaderAndIsr.initialZKVersion) + .setReplicas(Seq(0, 1).map(Integer.valueOf).asJava)) + val liveBrokers = brokerAndEpochs.map { case (broker, _) => + val securityProtocol = SecurityProtocol.PLAINTEXT + val listenerName = ListenerName.forSecurityProtocol(securityProtocol) + val node = broker.node(listenerName) + val endpoints = Seq(new UpdateMetadataEndpoint() + .setHost(node.host) + .setPort(node.port) + .setSecurityProtocol(securityProtocol.id) + .setListener(listenerName.value)) + new UpdateMetadataBroker() + .setId(broker.id) + .setEndpoints(endpoints.asJava) + .setRack(broker.rack.orNull) + }.toBuffer + val requestBuilder = new UpdateMetadataRequest.Builder( + 
ApiKeys.UPDATE_METADATA.latestVersion, controllerId, controllerEpoch, + epochInRequest, + partitionStates.asJava, liveBrokers.asJava, Collections.emptyMap()) + + if (epochInRequestDiffFromCurrentEpoch < 0) { + // stale broker epoch in UPDATE_METADATA + sendAndVerifyStaleBrokerEpochInResponse(controllerChannelManager, requestBuilder) + } + else { + // broker epoch in UPDATE_METADATA >= current broker epoch + sendAndVerifySuccessfulResponse(controllerChannelManager, requestBuilder) + TestUtils.waitForPartitionMetadata(Seq(broker2), tp.topic, tp.partition, 10000) + assertEquals(brokerId2, + broker2.metadataCache.getPartitionInfo(tp.topic, tp.partition).get.leader) + } + } + + // Send StopReplica request with correct broker epoch + { + val topicStates = Seq( + new StopReplicaTopicState() + .setTopicName(tp.topic()) + .setPartitionStates(Seq(new StopReplicaPartitionState() + .setPartitionIndex(tp.partition()) + .setLeaderEpoch(LeaderAndIsr.initialLeaderEpoch + 2) + .setDeletePartition(true)).asJava) + ).asJava + val requestBuilder = new StopReplicaRequest.Builder( + ApiKeys.STOP_REPLICA.latestVersion, controllerId, controllerEpoch, + epochInRequest, // Correct broker epoch + false, topicStates) + + if (epochInRequestDiffFromCurrentEpoch < 0) { + // stale broker epoch in STOP_REPLICA + sendAndVerifyStaleBrokerEpochInResponse(controllerChannelManager, requestBuilder) + } else { + // broker epoch in STOP_REPLICA >= current broker epoch + sendAndVerifySuccessfulResponse(controllerChannelManager, requestBuilder) + assertEquals(HostedPartition.None, broker2.replicaManager.getPartition(tp)) + } + } + } finally { + controllerChannelManager.shutdown() + metrics.close() + } + } + + private def getController: KafkaServer = { + val controllerId = TestUtils.waitUntilControllerElected(zkClient) + servers.filter(s => s.config.brokerId == controllerId).head + } + + private def checkControllerBrokerEpochsCacheMatchesWithZk(controllerContext: ControllerContext): Unit = { + val brokerAndEpochs = zkClient.getAllBrokerAndEpochsInCluster + TestUtils.waitUntilTrue(() => { + val brokerEpochsInControllerContext = controllerContext.liveBrokerIdAndEpochs + if (brokerAndEpochs.size != brokerEpochsInControllerContext.size) false + else { + brokerAndEpochs.forall { + case (broker, epoch) => brokerEpochsInControllerContext.get(broker.id).contains(epoch) + } + } + }, "Broker epoch mismatches") + } + + private def sendAndVerifyStaleBrokerEpochInResponse(controllerChannelManager: ControllerChannelManager, + builder: AbstractControlRequest.Builder[_ <: AbstractControlRequest]): Unit = { + var staleBrokerEpochDetected = false + controllerChannelManager.sendRequest(brokerId2, builder, response => { + staleBrokerEpochDetected = response.errorCounts().containsKey(Errors.STALE_BROKER_EPOCH) + }) + TestUtils.waitUntilTrue(() => staleBrokerEpochDetected, "Broker epoch should be stale") + assertTrue(staleBrokerEpochDetected, "Stale broker epoch not detected by the broker") + } + + private def sendAndVerifySuccessfulResponse(controllerChannelManager: ControllerChannelManager, + builder: AbstractControlRequest.Builder[_ <: AbstractControlRequest]): Unit = { + @volatile var succeed = false + controllerChannelManager.sendRequest(brokerId2, builder, response => { + succeed = response.errorCounts().isEmpty || + (response.errorCounts().containsKey(Errors.NONE) && response.errorCounts().size() == 1) + }) + TestUtils.waitUntilTrue(() => succeed, "Should receive response with no errors") + } +} diff --git 
a/core/src/test/scala/unit/kafka/server/BrokerFeaturesTest.scala b/core/src/test/scala/unit/kafka/server/BrokerFeaturesTest.scala new file mode 100644 index 0000000..c4cc52c --- /dev/null +++ b/core/src/test/scala/unit/kafka/server/BrokerFeaturesTest.scala @@ -0,0 +1,106 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.server + +import org.apache.kafka.common.feature.{Features, FinalizedVersionRange, SupportedVersionRange} +import org.junit.jupiter.api.Assertions.{assertEquals, assertFalse, assertTrue} +import org.junit.jupiter.api.Test + +import scala.jdk.CollectionConverters._ + +class BrokerFeaturesTest { + + @Test + def testEmpty(): Unit = { + assertTrue(BrokerFeatures.createDefault().supportedFeatures.empty) + } + + @Test + def testIncompatibilitiesDueToAbsentFeature(): Unit = { + val brokerFeatures = BrokerFeatures.createDefault() + val supportedFeatures = Features.supportedFeatures(Map[String, SupportedVersionRange]( + "test_feature_1" -> new SupportedVersionRange(1, 4), + "test_feature_2" -> new SupportedVersionRange(1, 3)).asJava) + brokerFeatures.setSupportedFeatures(supportedFeatures) + + val compatibleFeatures = Map[String, FinalizedVersionRange]( + "test_feature_1" -> new FinalizedVersionRange(2, 3)) + val inCompatibleFeatures = Map[String, FinalizedVersionRange]( + "test_feature_3" -> new FinalizedVersionRange(3, 4)) + val features = compatibleFeatures++inCompatibleFeatures + val finalizedFeatures = Features.finalizedFeatures(features.asJava) + + assertEquals( + Features.finalizedFeatures(inCompatibleFeatures.asJava), + brokerFeatures.incompatibleFeatures(finalizedFeatures)) + assertTrue(BrokerFeatures.hasIncompatibleFeatures(supportedFeatures, finalizedFeatures)) + } + + @Test + def testIncompatibilitiesDueToIncompatibleFeature(): Unit = { + val brokerFeatures = BrokerFeatures.createDefault() + val supportedFeatures = Features.supportedFeatures(Map[String, SupportedVersionRange]( + "test_feature_1" -> new SupportedVersionRange(1, 4), + "test_feature_2" -> new SupportedVersionRange(1, 3)).asJava) + brokerFeatures.setSupportedFeatures(supportedFeatures) + + val compatibleFeatures = Map[String, FinalizedVersionRange]( + "test_feature_1" -> new FinalizedVersionRange(2, 3)) + val inCompatibleFeatures = Map[String, FinalizedVersionRange]( + "test_feature_2" -> new FinalizedVersionRange(1, 4)) + val features = compatibleFeatures++inCompatibleFeatures + val finalizedFeatures = Features.finalizedFeatures(features.asJava) + + assertEquals( + Features.finalizedFeatures(inCompatibleFeatures.asJava), + brokerFeatures.incompatibleFeatures(finalizedFeatures)) + assertTrue(BrokerFeatures.hasIncompatibleFeatures(supportedFeatures, finalizedFeatures)) + } + + @Test + def testCompatibleFeatures(): Unit = { + val 
brokerFeatures = BrokerFeatures.createDefault() + val supportedFeatures = Features.supportedFeatures(Map[String, SupportedVersionRange]( + "test_feature_1" -> new SupportedVersionRange(1, 4), + "test_feature_2" -> new SupportedVersionRange(1, 3)).asJava) + brokerFeatures.setSupportedFeatures(supportedFeatures) + + val compatibleFeatures = Map[String, FinalizedVersionRange]( + "test_feature_1" -> new FinalizedVersionRange(2, 3), + "test_feature_2" -> new FinalizedVersionRange(1, 3)) + val finalizedFeatures = Features.finalizedFeatures(compatibleFeatures.asJava) + assertTrue(brokerFeatures.incompatibleFeatures(finalizedFeatures).empty()) + assertFalse(BrokerFeatures.hasIncompatibleFeatures(supportedFeatures, finalizedFeatures)) + } + + @Test + def testDefaultFinalizedFeatures(): Unit = { + val brokerFeatures = BrokerFeatures.createDefault() + val supportedFeatures = Features.supportedFeatures(Map[String, SupportedVersionRange]( + "test_feature_1" -> new SupportedVersionRange(1, 4), + "test_feature_2" -> new SupportedVersionRange(1, 3), + "test_feature_3" -> new SupportedVersionRange(3, 7)).asJava) + brokerFeatures.setSupportedFeatures(supportedFeatures) + + val expectedFeatures = Map[String, FinalizedVersionRange]( + "test_feature_1" -> new FinalizedVersionRange(1, 4), + "test_feature_2" -> new FinalizedVersionRange(1, 3), + "test_feature_3" -> new FinalizedVersionRange(3, 7)) + assertEquals(Features.finalizedFeatures(expectedFeatures.asJava), brokerFeatures.defaultFinalizedFeatures) + } +} diff --git a/core/src/test/scala/unit/kafka/server/BrokerLifecycleManagerTest.scala b/core/src/test/scala/unit/kafka/server/BrokerLifecycleManagerTest.scala new file mode 100644 index 0000000..dd3e49d --- /dev/null +++ b/core/src/test/scala/unit/kafka/server/BrokerLifecycleManagerTest.scala @@ -0,0 +1,235 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+*/ + +package kafka.server + +import java.util.{Collections, Properties} +import java.util.concurrent.atomic.{AtomicLong, AtomicReference} + +import kafka.utils.{MockTime, TestUtils} +import org.apache.kafka.clients.{Metadata, MockClient, NodeApiVersions} +import org.apache.kafka.common.config.SaslConfigs +import org.apache.kafka.common.Node +import org.apache.kafka.common.internals.ClusterResourceListeners +import org.apache.kafka.common.message.ApiVersionsResponseData.ApiVersion +import org.apache.kafka.common.message.BrokerRegistrationRequestData.{Listener, ListenerCollection} +import org.apache.kafka.common.message.{BrokerHeartbeatResponseData, BrokerRegistrationResponseData} +import org.apache.kafka.common.network.ListenerName +import org.apache.kafka.common.protocol.ApiKeys.{BROKER_HEARTBEAT, BROKER_REGISTRATION} +import org.apache.kafka.common.protocol.Errors +import org.apache.kafka.common.requests.{AbstractRequest, BrokerHeartbeatRequest, BrokerHeartbeatResponse, BrokerRegistrationResponse} +import org.apache.kafka.common.security.auth.SecurityProtocol +import org.apache.kafka.common.utils.LogContext +import org.apache.kafka.metadata.BrokerState +import org.junit.jupiter.api.{Test, Timeout} +import org.junit.jupiter.api.Assertions._ + +import scala.jdk.CollectionConverters._ + + +@Timeout(value = 12) +class BrokerLifecycleManagerTest { + def configProperties = { + val properties = new Properties() + properties.setProperty(KafkaConfig.LogDirsProp, "/tmp/foo") + properties.setProperty(KafkaConfig.ProcessRolesProp, "broker") + properties.setProperty(KafkaConfig.NodeIdProp, "1") + properties.setProperty(KafkaConfig.QuorumVotersProp, s"2@localhost:9093") + properties.setProperty(KafkaConfig.ControllerListenerNamesProp, "SSL") + properties.setProperty(KafkaConfig.InitialBrokerRegistrationTimeoutMsProp, "300000") + properties + } + + class SimpleControllerNodeProvider extends ControllerNodeProvider { + val node = new AtomicReference[Node](null) + + override def get(): Option[Node] = Option(node.get()) + + override def listenerName: ListenerName = new ListenerName("PLAINTEXT") + + override def securityProtocol: SecurityProtocol = SecurityProtocol.PLAINTEXT; + + override def saslMechanism: String = SaslConfigs.DEFAULT_SASL_MECHANISM + } + + class BrokerLifecycleManagerTestContext(properties: Properties) { + val config = new KafkaConfig(properties) + val time = new MockTime(1, 1) + val highestMetadataOffset = new AtomicLong(0) + val metadata = new Metadata(1000, 1000, new LogContext(), new ClusterResourceListeners()) + val mockClient = new MockClient(time, metadata) + val controllerNodeProvider = new SimpleControllerNodeProvider() + val nodeApiVersions = new NodeApiVersions(Seq(BROKER_REGISTRATION, BROKER_HEARTBEAT).map { + apiKey => new ApiVersion().setApiKey(apiKey.id). + setMinVersion(apiKey.oldestVersion()).setMaxVersion(apiKey.latestVersion()) + }.toList.asJava) + val mockChannelManager = new MockBrokerToControllerChannelManager(mockClient, + time, controllerNodeProvider, nodeApiVersions) + val clusterId = "x4AJGXQSRnephtTZzujw4w" + val advertisedListeners = new ListenerCollection() + config.effectiveAdvertisedListeners.foreach { ep => + advertisedListeners.add(new Listener().setHost(ep.host). + setName(ep.listenerName.value()). + setPort(ep.port.shortValue()). 
+ setSecurityProtocol(ep.securityProtocol.id)) + } + + def poll(): Unit = { + mockClient.wakeup() + mockChannelManager.poll() + } + } + + @Test + def testCreateAndClose(): Unit = { + val context = new BrokerLifecycleManagerTestContext(configProperties) + val manager = new BrokerLifecycleManager(context.config, context.time, None) + manager.close() + } + + @Test + def testCreateStartAndClose(): Unit = { + val context = new BrokerLifecycleManagerTestContext(configProperties) + val manager = new BrokerLifecycleManager(context.config, context.time, None) + assertEquals(BrokerState.NOT_RUNNING, manager.state) + manager.start(() => context.highestMetadataOffset.get(), + context.mockChannelManager, context.clusterId, context.advertisedListeners, + Collections.emptyMap()) + TestUtils.retry(60000) { + assertEquals(BrokerState.STARTING, manager.state) + } + manager.close() + assertEquals(BrokerState.SHUTTING_DOWN, manager.state) + } + + @Test + def testSuccessfulRegistration(): Unit = { + val context = new BrokerLifecycleManagerTestContext(configProperties) + val manager = new BrokerLifecycleManager(context.config, context.time, None) + val controllerNode = new Node(3000, "localhost", 8021) + context.controllerNodeProvider.node.set(controllerNode) + context.mockClient.prepareResponseFrom(new BrokerRegistrationResponse( + new BrokerRegistrationResponseData().setBrokerEpoch(1000)), controllerNode) + manager.start(() => context.highestMetadataOffset.get(), + context.mockChannelManager, context.clusterId, context.advertisedListeners, + Collections.emptyMap()) + TestUtils.retry(10000) { + context.poll() + assertEquals(1000L, manager.brokerEpoch) + } + manager.close() + + } + + @Test + def testRegistrationTimeout(): Unit = { + val context = new BrokerLifecycleManagerTestContext(configProperties) + val controllerNode = new Node(3000, "localhost", 8021) + val manager = new BrokerLifecycleManager(context.config, context.time, None) + context.controllerNodeProvider.node.set(controllerNode) + def newDuplicateRegistrationResponse(): Unit = { + context.mockClient.prepareResponseFrom(new BrokerRegistrationResponse( + new BrokerRegistrationResponseData(). + setErrorCode(Errors.DUPLICATE_BROKER_REGISTRATION.code())), controllerNode) + context.mockChannelManager.poll() + } + newDuplicateRegistrationResponse() + assertEquals(1, context.mockClient.futureResponses().size) + manager.start(() => context.highestMetadataOffset.get(), + context.mockChannelManager, context.clusterId, context.advertisedListeners, + Collections.emptyMap()) + // We should send the first registration request and get a failure immediately + TestUtils.retry(60000) { + context.poll() + assertEquals(0, context.mockClient.futureResponses().size) + } + // Verify that we resend the registration request. + newDuplicateRegistrationResponse() + TestUtils.retry(60000) { + context.time.sleep(100) + context.poll() + manager.eventQueue.wakeup() + assertEquals(0, context.mockClient.futureResponses().size) + } + // Verify that we time out eventually. 
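+    // The mock clock is advanced by the 300000 ms initial registration timeout set in configProperties, which should move the manager to SHUTTING_DOWN.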
+ context.time.sleep(300000) + TestUtils.retry(60000) { + context.poll() + manager.eventQueue.wakeup() + assertEquals(BrokerState.SHUTTING_DOWN, manager.state) + assertTrue(manager.initialCatchUpFuture.isCompletedExceptionally()) + assertEquals(-1L, manager.brokerEpoch) + } + manager.close() + } + + @Test + def testControlledShutdown(): Unit = { + val context = new BrokerLifecycleManagerTestContext(configProperties) + val manager = new BrokerLifecycleManager(context.config, context.time, None) + val controllerNode = new Node(3000, "localhost", 8021) + context.controllerNodeProvider.node.set(controllerNode) + context.mockClient.prepareResponseFrom(new BrokerRegistrationResponse( + new BrokerRegistrationResponseData().setBrokerEpoch(1000)), controllerNode) + context.mockClient.prepareResponseFrom(new BrokerHeartbeatResponse( + new BrokerHeartbeatResponseData().setIsCaughtUp(true)), controllerNode) + manager.start(() => context.highestMetadataOffset.get(), + context.mockChannelManager, context.clusterId, context.advertisedListeners, + Collections.emptyMap()) + TestUtils.retry(10000) { + context.poll() + manager.eventQueue.wakeup() + assertEquals(BrokerState.RECOVERY, manager.state) + } + context.mockClient.prepareResponseFrom(new BrokerHeartbeatResponse( + new BrokerHeartbeatResponseData().setIsFenced(false)), controllerNode) + context.time.sleep(20) + TestUtils.retry(10000) { + context.poll() + manager.eventQueue.wakeup() + assertEquals(BrokerState.RUNNING, manager.state) + } + manager.beginControlledShutdown() + TestUtils.retry(10000) { + context.poll() + manager.eventQueue.wakeup() + assertEquals(BrokerState.PENDING_CONTROLLED_SHUTDOWN, manager.state) + assertTrue(context.mockClient.hasInFlightRequests) + } + + context.mockClient.respond( + (body: AbstractRequest) => { + body match { + case heartbeatRequest: BrokerHeartbeatRequest => + assertTrue(heartbeatRequest.data.wantShutDown) + true + case _ => + false + } + }, + new BrokerHeartbeatResponse(new BrokerHeartbeatResponseData().setShouldShutDown(true)) + ) + + TestUtils.retry(10000) { + context.poll() + manager.eventQueue.wakeup() + assertEquals(BrokerState.SHUTTING_DOWN, manager.state) + } + manager.controlledShutdownFuture.get() + manager.close() + } +} diff --git a/core/src/test/scala/unit/kafka/server/BrokerMetricNamesTest.scala b/core/src/test/scala/unit/kafka/server/BrokerMetricNamesTest.scala new file mode 100644 index 0000000..3bd9c6d --- /dev/null +++ b/core/src/test/scala/unit/kafka/server/BrokerMetricNamesTest.scala @@ -0,0 +1,56 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package kafka.server + +import kafka.metrics.KafkaYammerMetrics +import kafka.test.ClusterInstance +import kafka.test.annotation.{ClusterTest, ClusterTestDefaults, Type} +import kafka.test.junit.ClusterTestExtensions +import kafka.utils.TestUtils +import org.junit.jupiter.api.AfterEach +import org.junit.jupiter.api.Assertions.assertEquals +import org.junit.jupiter.api.extension.ExtendWith + +import scala.jdk.CollectionConverters._ + +@ClusterTestDefaults(clusterType = Type.BOTH) +@ExtendWith(value = Array(classOf[ClusterTestExtensions])) +class BrokerMetricNamesTest(cluster: ClusterInstance) { + @AfterEach + def tearDown(): Unit = { + TestUtils.clearYammerMetrics() + } + + @ClusterTest + def testMetrics(): Unit = { + checkReplicaManagerMetrics() + } + + def checkReplicaManagerMetrics(): Unit = { + val metrics = KafkaYammerMetrics.defaultRegistry.allMetrics + val expectedPrefix = "kafka.server:type=ReplicaManager,name" + val expectedMetricNames = Set( + "LeaderCount", "PartitionCount", "OfflineReplicaCount", "UnderReplicatedPartitions", + "UnderMinIsrPartitionCount", "AtMinIsrPartitionCount", "ReassigningPartitions", + "IsrExpandsPerSec", "IsrShrinksPerSec", "FailedIsrUpdatesPerSec", + ) + expectedMetricNames.foreach { metricName => + assertEquals(1, metrics.keySet.asScala.count(_.getMBeanName == s"$expectedPrefix=$metricName")) + } + } +} diff --git a/core/src/test/scala/unit/kafka/server/ClientQuotaManagerTest.scala b/core/src/test/scala/unit/kafka/server/ClientQuotaManagerTest.scala new file mode 100644 index 0000000..4159a1b --- /dev/null +++ b/core/src/test/scala/unit/kafka/server/ClientQuotaManagerTest.scala @@ -0,0 +1,427 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package kafka.server + +import java.net.InetAddress + +import kafka.network.RequestChannel.Session +import kafka.server.QuotaType._ +import org.apache.kafka.common.metrics.Quota +import org.apache.kafka.common.security.auth.KafkaPrincipal +import org.apache.kafka.common.utils.Sanitizer + +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.Test + +class ClientQuotaManagerTest extends BaseClientQuotaManagerTest { + private val config = ClientQuotaManagerConfig() + + private def testQuotaParsing(config: ClientQuotaManagerConfig, client1: UserClient, client2: UserClient, randomClient: UserClient, defaultConfigClient: UserClient): Unit = { + val clientQuotaManager = new ClientQuotaManager(config, metrics, Produce, time, "") + + try { + // Case 1: Update the quota. 
Assert that the new quota value is returned + clientQuotaManager.updateQuota(client1.configUser, client1.configClientId, client1.sanitizedConfigClientId, Some(new Quota(2000, true))) + clientQuotaManager.updateQuota(client2.configUser, client2.configClientId, client2.sanitizedConfigClientId, Some(new Quota(4000, true))) + + assertEquals(Long.MaxValue.toDouble, clientQuotaManager.quota(randomClient.user, randomClient.clientId).bound, 0.0, + "Default producer quota should be " + Long.MaxValue.toDouble) + assertEquals(2000, clientQuotaManager.quota(client1.user, client1.clientId).bound, 0.0, + "Should return the overridden value (2000)") + assertEquals(4000, clientQuotaManager.quota(client2.user, client2.clientId).bound, 0.0, + "Should return the overridden value (4000)") + + // p1 should be throttled using the overridden quota + var throttleTimeMs = maybeRecord(clientQuotaManager, client1.user, client1.clientId, 2500 * config.numQuotaSamples) + assertTrue(throttleTimeMs > 0, s"throttleTimeMs should be > 0. was $throttleTimeMs") + + // Case 2: Change quota again. The quota should be updated within KafkaMetrics as well since the sensor was created. + // p1 should not longer be throttled after the quota change + clientQuotaManager.updateQuota(client1.configUser, client1.configClientId, client1.sanitizedConfigClientId, Some(new Quota(3000, true))) + assertEquals(3000, clientQuotaManager.quota(client1.user, client1.clientId).bound, 0.0, "Should return the newly overridden value (3000)") + + throttleTimeMs = maybeRecord(clientQuotaManager, client1.user, client1.clientId, 0) + assertEquals(0, throttleTimeMs, s"throttleTimeMs should be 0. was $throttleTimeMs") + + // Case 3: Change quota back to default. Should be throttled again + clientQuotaManager.updateQuota(client1.configUser, client1.configClientId, client1.sanitizedConfigClientId, Some(new Quota(500, true))) + assertEquals(500, clientQuotaManager.quota(client1.user, client1.clientId).bound, 0.0, "Should return the default value (500)") + + throttleTimeMs = maybeRecord(clientQuotaManager, client1.user, client1.clientId, 0) + assertTrue(throttleTimeMs > 0, s"throttleTimeMs should be > 0. was $throttleTimeMs") + + // Case 4: Set high default quota, remove p1 quota. p1 should no longer be throttled + clientQuotaManager.updateQuota(client1.configUser, client1.configClientId, client1.sanitizedConfigClientId, None) + clientQuotaManager.updateQuota(defaultConfigClient.configUser, defaultConfigClient.configClientId, defaultConfigClient.sanitizedConfigClientId, Some(new Quota(4000, true))) + assertEquals(4000, clientQuotaManager.quota(client1.user, client1.clientId).bound, 0.0, "Should return the newly overridden value (4000)") + + throttleTimeMs = maybeRecord(clientQuotaManager, client1.user, client1.clientId, 1000 * config.numQuotaSamples) + assertEquals(0, throttleTimeMs, s"throttleTimeMs should be 0. was $throttleTimeMs") + + } finally { + clientQuotaManager.shutdown() + } + } + + /** + * Tests parsing for quotas. 
+ * Quota overrides persisted in ZooKeeper in /config/clients/<client-id>, default persisted in /config/clients/<default> + */ + @Test + def testClientIdQuotaParsing(): Unit = { + val client1 = UserClient("ANONYMOUS", "p1", None, Some("p1")) + val client2 = UserClient("ANONYMOUS", "p2", None, Some("p2")) + val randomClient = UserClient("ANONYMOUS", "random-client-id", None, None) + val defaultConfigClient = UserClient("", "", None, Some(ConfigEntityName.Default)) + testQuotaParsing(config, client1, client2, randomClient, defaultConfigClient) + } + + /** + * Tests parsing for <user> quotas. + * Quota overrides persisted in ZooKeeper in /config/users/<user>, default persisted in /config/users/<default> + */ + @Test + def testUserQuotaParsing(): Unit = { + val client1 = UserClient("User1", "p1", Some("User1"), None) + val client2 = UserClient("User2", "p2", Some("User2"), None) + val randomClient = UserClient("RandomUser", "random-client-id", None, None) + val defaultConfigClient = UserClient("", "", Some(ConfigEntityName.Default), None) + val config = ClientQuotaManagerConfig() + testQuotaParsing(config, client1, client2, randomClient, defaultConfigClient) + } + + /** + * Tests parsing for <user, client-id> quotas. + * Quotas persisted in ZooKeeper in /config/users/<user>/clients/<client-id>, default in /config/users/<default>/clients/<default> + */ + @Test + def testUserClientIdQuotaParsing(): Unit = { + val client1 = UserClient("User1", "p1", Some("User1"), Some("p1")) + val client2 = UserClient("User2", "p2", Some("User2"), Some("p2")) + val randomClient = UserClient("RandomUser", "random-client-id", None, None) + val defaultConfigClient = UserClient("", "", Some(ConfigEntityName.Default), Some(ConfigEntityName.Default)) + val config = ClientQuotaManagerConfig() + testQuotaParsing(config, client1, client2, randomClient, defaultConfigClient) + } + + /** + * Tests parsing for <user> quotas when client-id default quota properties are set. + */ + @Test + def testUserQuotaParsingWithDefaultClientIdQuota(): Unit = { + val client1 = UserClient("User1", "p1", Some("User1"), None) + val client2 = UserClient("User2", "p2", Some("User2"), None) + val randomClient = UserClient("RandomUser", "random-client-id", None, None) + val defaultConfigClient = UserClient("", "", Some(ConfigEntityName.Default), None) + testQuotaParsing(config, client1, client2, randomClient, defaultConfigClient) + } + + /** + * Tests parsing for <user, client-id> quotas when client-id default quota properties are set. 
+ */ + @Test + def testUserClientQuotaParsingIdWithDefaultClientIdQuota(): Unit = { + val client1 = UserClient("User1", "p1", Some("User1"), Some("p1")) + val client2 = UserClient("User2", "p2", Some("User2"), Some("p2")) + val randomClient = UserClient("RandomUser", "random-client-id", None, None) + val defaultConfigClient = UserClient("", "", Some(ConfigEntityName.Default), Some(ConfigEntityName.Default)) + testQuotaParsing(config, client1, client2, randomClient, defaultConfigClient) + } + + private def checkQuota(quotaManager: ClientQuotaManager, user: String, clientId: String, expectedBound: Long, value: Int, expectThrottle: Boolean): Unit = { + assertEquals(expectedBound.toDouble, quotaManager.quota(user, clientId).bound, 0.0) + val session = Session(new KafkaPrincipal(KafkaPrincipal.USER_TYPE, user), InetAddress.getLocalHost) + val expectedMaxValueInQuotaWindow = + if (expectedBound < Long.MaxValue) config.quotaWindowSizeSeconds * (config.numQuotaSamples - 1) * expectedBound.toDouble + else Double.MaxValue + assertEquals(expectedMaxValueInQuotaWindow, quotaManager.getMaxValueInQuotaWindow(session, clientId), 0.01) + + val throttleTimeMs = maybeRecord(quotaManager, user, clientId, value * config.numQuotaSamples) + if (expectThrottle) + assertTrue(throttleTimeMs > 0, s"throttleTimeMs should be > 0. was $throttleTimeMs") + else + assertEquals(0, throttleTimeMs, s"throttleTimeMs should be 0. was $throttleTimeMs") + } + + @Test + def testGetMaxValueInQuotaWindowWithNonDefaultQuotaWindow(): Unit = { + val numFullQuotaWindows = 3 // 3 seconds window (vs. 10 seconds default) + val nonDefaultConfig = ClientQuotaManagerConfig(numQuotaSamples = numFullQuotaWindows + 1) + val clientQuotaManager = new ClientQuotaManager(nonDefaultConfig, metrics, Fetch, time, "") + val userSession = Session(new KafkaPrincipal(KafkaPrincipal.USER_TYPE, "userA"), InetAddress.getLocalHost) + + try { + // no quota set + assertEquals(Double.MaxValue, clientQuotaManager.getMaxValueInQuotaWindow(userSession, "client1"), 0.01) + + // Set default quota config + clientQuotaManager.updateQuota(Some(ConfigEntityName.Default), None, None, Some(new Quota(10, true))) + assertEquals(10 * numFullQuotaWindows, clientQuotaManager.getMaxValueInQuotaWindow(userSession, "client1"), 0.01) + } finally { + clientQuotaManager.shutdown() + } + } + + @Test + def testSetAndRemoveDefaultUserQuota(): Unit = { + // quotaTypesEnabled will be QuotaTypes.NoQuotas initially + val clientQuotaManager = new ClientQuotaManager(ClientQuotaManagerConfig(), + metrics, Produce, time, "") + + try { + // no quota set yet, should not throttle + checkQuota(clientQuotaManager, "userA", "client1", Long.MaxValue, 1000, false) + + // Set default quota config + clientQuotaManager.updateQuota(Some(ConfigEntityName.Default), None, None, Some(new Quota(10, true))) + checkQuota(clientQuotaManager, "userA", "client1", 10, 1000, true) + + // Remove default quota config, back to no quotas + clientQuotaManager.updateQuota(Some(ConfigEntityName.Default), None, None, None) + checkQuota(clientQuotaManager, "userA", "client1", Long.MaxValue, 1000, false) + } finally { + clientQuotaManager.shutdown() + } + } + + @Test + def testSetAndRemoveUserQuota(): Unit = { + // quotaTypesEnabled will be QuotaTypes.NoQuotas initially + val clientQuotaManager = new ClientQuotaManager(ClientQuotaManagerConfig(), + metrics, Produce, time, "") + + try { + // Set quota config + clientQuotaManager.updateQuota(Some("userA"), None, None, Some(new Quota(10, true))) + checkQuota(clientQuotaManager, 
"userA", "client1", 10, 1000, true) + + // Remove quota config, back to no quotas + clientQuotaManager.updateQuota(Some("userA"), None, None, None) + checkQuota(clientQuotaManager, "userA", "client1", Long.MaxValue, 1000, false) + } finally { + clientQuotaManager.shutdown() + } + } + + @Test + def testSetAndRemoveUserClientQuota(): Unit = { + // quotaTypesEnabled will be QuotaTypes.NoQuotas initially + val clientQuotaManager = new ClientQuotaManager(ClientQuotaManagerConfig(), + metrics, Produce, time, "") + + try { + // Set quota config + clientQuotaManager.updateQuota(Some("userA"), Some("client1"), Some("client1"), Some(new Quota(10, true))) + checkQuota(clientQuotaManager, "userA", "client1", 10, 1000, true) + + // Remove quota config, back to no quotas + clientQuotaManager.updateQuota(Some("userA"), Some("client1"), Some("client1"), None) + checkQuota(clientQuotaManager, "userA", "client1", Long.MaxValue, 1000, false) + } finally { + clientQuotaManager.shutdown() + } + } + + @Test + def testQuotaConfigPrecedence(): Unit = { + val clientQuotaManager = new ClientQuotaManager(ClientQuotaManagerConfig(), + metrics, Produce, time, "") + + try { + clientQuotaManager.updateQuota(Some(ConfigEntityName.Default), None, None, Some(new Quota(1000, true))) + clientQuotaManager.updateQuota(None, Some(ConfigEntityName.Default), Some(ConfigEntityName.Default), Some(new Quota(2000, true))) + clientQuotaManager.updateQuota(Some(ConfigEntityName.Default), Some(ConfigEntityName.Default), Some(ConfigEntityName.Default), Some(new Quota(3000, true))) + clientQuotaManager.updateQuota(Some("userA"), None, None, Some(new Quota(4000, true))) + clientQuotaManager.updateQuota(Some("userA"), Some("client1"), Some("client1"), Some(new Quota(5000, true))) + clientQuotaManager.updateQuota(Some("userB"), None, None, Some(new Quota(6000, true))) + clientQuotaManager.updateQuota(Some("userB"), Some("client1"), Some("client1"), Some(new Quota(7000, true))) + clientQuotaManager.updateQuota(Some("userB"), Some(ConfigEntityName.Default), Some(ConfigEntityName.Default), Some(new Quota(8000, true))) + clientQuotaManager.updateQuota(Some("userC"), None, None, Some(new Quota(10000, true))) + clientQuotaManager.updateQuota(None, Some("client1"), Some("client1"), Some(new Quota(9000, true))) + + checkQuota(clientQuotaManager, "userA", "client1", 5000, 4500, false) // quota takes precedence over + checkQuota(clientQuotaManager, "userA", "client2", 4000, 4500, true) // quota takes precedence over and defaults + checkQuota(clientQuotaManager, "userA", "client3", 4000, 0, true) // quota is shared across clients of user + checkQuota(clientQuotaManager, "userA", "client1", 5000, 0, false) // is exclusive use, unaffected by other clients + + checkQuota(clientQuotaManager, "userB", "client1", 7000, 8000, true) + checkQuota(clientQuotaManager, "userB", "client2", 8000, 7000, false) // Default per-client quota for exclusive use of + checkQuota(clientQuotaManager, "userB", "client3", 8000, 7000, false) + + checkQuota(clientQuotaManager, "userD", "client1", 3000, 3500, true) // Default quota + checkQuota(clientQuotaManager, "userD", "client2", 3000, 2500, false) + checkQuota(clientQuotaManager, "userE", "client1", 3000, 2500, false) + + // Remove default quota config, revert to default + clientQuotaManager.updateQuota(Some(ConfigEntityName.Default), Some(ConfigEntityName.Default), Some(ConfigEntityName.Default), None) + checkQuota(clientQuotaManager, "userD", "client1", 1000, 0, false) // Metrics tags changed, restart counter + 
checkQuota(clientQuotaManager, "userE", "client4", 1000, 1500, true) + checkQuota(clientQuotaManager, "userF", "client4", 1000, 800, false) // Default quota shared across clients of user + checkQuota(clientQuotaManager, "userF", "client5", 1000, 800, true) + + // Remove default quota config, revert to default + clientQuotaManager.updateQuota(Some(ConfigEntityName.Default), None, None, None) + checkQuota(clientQuotaManager, "userF", "client4", 2000, 0, false) // Default quota shared across client-id of all users + checkQuota(clientQuotaManager, "userF", "client5", 2000, 0, false) + checkQuota(clientQuotaManager, "userF", "client5", 2000, 2500, true) + checkQuota(clientQuotaManager, "userG", "client5", 2000, 0, true) + + // Update quotas + clientQuotaManager.updateQuota(Some("userA"), None, None, Some(new Quota(8000, true))) + clientQuotaManager.updateQuota(Some("userA"), Some("client1"), Some("client1"), Some(new Quota(10000, true))) + checkQuota(clientQuotaManager, "userA", "client2", 8000, 0, false) + checkQuota(clientQuotaManager, "userA", "client2", 8000, 4500, true) // Throttled due to sum of new and earlier values + checkQuota(clientQuotaManager, "userA", "client1", 10000, 0, false) + checkQuota(clientQuotaManager, "userA", "client1", 10000, 6000, true) + clientQuotaManager.updateQuota(Some("userA"), Some("client1"), Some("client1"), None) + checkQuota(clientQuotaManager, "userA", "client6", 8000, 0, true) // Throttled due to shared user quota + clientQuotaManager.updateQuota(Some("userA"), Some("client6"), Some("client6"), Some(new Quota(11000, true))) + checkQuota(clientQuotaManager, "userA", "client6", 11000, 8500, false) + clientQuotaManager.updateQuota(Some("userA"), Some(ConfigEntityName.Default), Some(ConfigEntityName.Default), Some(new Quota(12000, true))) + clientQuotaManager.updateQuota(Some("userA"), Some("client6"), Some("client6"), None) + checkQuota(clientQuotaManager, "userA", "client6", 12000, 4000, true) // Throttled due to sum of new and earlier values + + } finally { + clientQuotaManager.shutdown() + } + } + + @Test + def testQuotaViolation(): Unit = { + val clientQuotaManager = new ClientQuotaManager(config, metrics, Produce, time, "") + val queueSizeMetric = metrics.metrics().get(metrics.metricName("queue-size", "Produce", "")) + try { + clientQuotaManager.updateQuota(None, Some(ConfigEntityName.Default), Some(ConfigEntityName.Default), + Some(new Quota(500, true))) + + // We have 10 second windows. Make sure that there is no quota violation + // if we produce under the quota + for (_ <- 0 until 10) { + assertEquals(0, maybeRecord(clientQuotaManager, "ANONYMOUS", "unknown", 400)) + time.sleep(1000) + } + assertEquals(0, queueSizeMetric.metricValue.asInstanceOf[Double].toInt) + + // Create a spike. + // 400*10 + 2000 + 300 = 6300/10.5 = 600 bytes per second. 
+ // (600 - quota)/quota*window-size = (600-500)/500*10.5 seconds = 2100 + // 10.5 seconds because the last window is half complete + time.sleep(500) + val throttleTime = maybeRecord(clientQuotaManager, "ANONYMOUS", "unknown", 2300) + + assertEquals(2100, throttleTime, "Should be throttled") + throttle(clientQuotaManager, "ANONYMOUS", "unknown", throttleTime, callback) + assertEquals(1, queueSizeMetric.metricValue.asInstanceOf[Double].toInt) + // After a request is delayed, the callback cannot be triggered immediately + clientQuotaManager.throttledChannelReaper.doWork() + assertEquals(0, numCallbacks) + time.sleep(throttleTime) + + // Callback can only be triggered after the delay time passes + clientQuotaManager.throttledChannelReaper.doWork() + assertEquals(0, queueSizeMetric.metricValue.asInstanceOf[Double].toInt) + assertEquals(1, numCallbacks) + + // Could continue to see delays until the bursty sample disappears + for (_ <- 0 until 10) { + maybeRecord(clientQuotaManager, "ANONYMOUS", "unknown", 400) + time.sleep(1000) + } + + assertEquals(0, maybeRecord(clientQuotaManager, "ANONYMOUS", "unknown", 0), + "Should be unthrottled since bursty sample has rolled over") + } finally { + clientQuotaManager.shutdown() + } + } + + @Test + def testExpireThrottleTimeSensor(): Unit = { + val clientQuotaManager = new ClientQuotaManager(config, metrics, Produce, time, "") + try { + clientQuotaManager.updateQuota(None, Some(ConfigEntityName.Default), Some(ConfigEntityName.Default), + Some(new Quota(500, true))) + + maybeRecord(clientQuotaManager, "ANONYMOUS", "client1", 100) + // remove the throttle time sensor + metrics.removeSensor("ProduceThrottleTime-:client1") + // should not throw an exception even if the throttle time sensor does not exist. + val throttleTime = maybeRecord(clientQuotaManager, "ANONYMOUS", "client1", 10000) + assertTrue(throttleTime > 0, "Should be throttled") + // the sensor should get recreated + val throttleTimeSensor = metrics.getSensor("ProduceThrottleTime-:client1") + assertNotNull(throttleTimeSensor, "Throttle time sensor should exist") + assertNotNull(throttleTimeSensor, "Throttle time sensor should exist") + } finally { + clientQuotaManager.shutdown() + } + } + + @Test + def testExpireQuotaSensors(): Unit = { + val clientQuotaManager = new ClientQuotaManager(config, metrics, Produce, time, "") + try { + clientQuotaManager.updateQuota(None, Some(ConfigEntityName.Default), Some(ConfigEntityName.Default), + Some(new Quota(500, true))) + + maybeRecord(clientQuotaManager, "ANONYMOUS", "client1", 100) + // remove all the sensors + metrics.removeSensor("ProduceThrottleTime-:client1") + metrics.removeSensor("Produce-ANONYMOUS:client1") + // should not throw an exception + val throttleTime = maybeRecord(clientQuotaManager, "ANONYMOUS", "client1", 10000) + assertTrue(throttleTime > 0, "Should be throttled") + + // all the sensors should get recreated + val throttleTimeSensor = metrics.getSensor("ProduceThrottleTime-:client1") + assertNotNull(throttleTimeSensor, "Throttle time sensor should exist") + + val byteRateSensor = metrics.getSensor("Produce-:client1") + assertNotNull(byteRateSensor, "Byte rate sensor should exist") + } finally { + clientQuotaManager.shutdown() + } + } + + @Test + def testClientIdNotSanitized(): Unit = { + val clientQuotaManager = new ClientQuotaManager(config, metrics, Produce, time, "") + val clientId = "client@#$%" + try { + clientQuotaManager.updateQuota(None, Some(ConfigEntityName.Default), Some(ConfigEntityName.Default), + Some(new Quota(500, true))) 
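+      // Only a default client-id quota is set here; the point of this test is that the sensor names looked up below keep the raw, unsanitized client id.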
+ + maybeRecord(clientQuotaManager, "ANONYMOUS", clientId, 100) + + // The metrics should use the raw client ID, even if the reporters internally sanitize them + val throttleTimeSensor = metrics.getSensor("ProduceThrottleTime-:" + clientId) + assertNotNull(throttleTimeSensor, "Throttle time sensor should exist") + + val byteRateSensor = metrics.getSensor("Produce-:" + clientId) + assertNotNull(byteRateSensor, "Byte rate sensor should exist") + } finally { + clientQuotaManager.shutdown() + } + } + + private case class UserClient(val user: String, val clientId: String, val configUser: Option[String] = None, val configClientId: Option[String] = None) { + // The class under test expects only sanitized client configs. We pass both the default value (which should not be + // sanitized to ensure it remains unique) and non-default values, so we need to take care in generating the sanitized + // client ID + def sanitizedConfigClientId = configClientId.map(x => if (x == ConfigEntityName.Default) ConfigEntityName.Default else Sanitizer.sanitize(x)) + } +} diff --git a/core/src/test/scala/unit/kafka/server/ClientQuotasRequestTest.scala b/core/src/test/scala/unit/kafka/server/ClientQuotasRequestTest.scala new file mode 100644 index 0000000..573bd95 --- /dev/null +++ b/core/src/test/scala/unit/kafka/server/ClientQuotasRequestTest.scala @@ -0,0 +1,611 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package kafka.server + +import java.net.InetAddress +import java.util +import java.util.concurrent.{ExecutionException, TimeUnit} + +import kafka.test.ClusterInstance +import kafka.test.annotation.{ClusterTest, ClusterTestDefaults, Type} +import kafka.test.junit.ClusterTestExtensions +import kafka.utils.TestUtils +import org.apache.kafka.clients.admin.{ScramCredentialInfo, ScramMechanism, UserScramCredentialUpsertion} +import org.apache.kafka.common.config.internals.QuotaConfigs +import org.apache.kafka.common.errors.{InvalidRequestException, UnsupportedVersionException} +import org.apache.kafka.common.internals.KafkaFutureImpl +import org.apache.kafka.common.quota.{ClientQuotaAlteration, ClientQuotaEntity, ClientQuotaFilter, ClientQuotaFilterComponent} +import org.apache.kafka.common.requests.{AlterClientQuotasRequest, AlterClientQuotasResponse, DescribeClientQuotasRequest, DescribeClientQuotasResponse} +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.Tag +import org.junit.jupiter.api.extension.ExtendWith + +import scala.jdk.CollectionConverters._ + +@ClusterTestDefaults(clusterType = Type.BOTH) +@ExtendWith(value = Array(classOf[ClusterTestExtensions])) +@Tag("integration") +class ClientQuotasRequestTest(cluster: ClusterInstance) { + private val ConsumerByteRateProp = QuotaConfigs.CONSUMER_BYTE_RATE_OVERRIDE_CONFIG + private val ProducerByteRateProp = QuotaConfigs.PRODUCER_BYTE_RATE_OVERRIDE_CONFIG + private val RequestPercentageProp = QuotaConfigs.REQUEST_PERCENTAGE_OVERRIDE_CONFIG + private val IpConnectionRateProp = QuotaConfigs.IP_CONNECTION_RATE_OVERRIDE_CONFIG + + @ClusterTest + def testAlterClientQuotasRequest(): Unit = { + + val entity = new ClientQuotaEntity(Map((ClientQuotaEntity.USER -> "user"), (ClientQuotaEntity.CLIENT_ID -> "client-id")).asJava) + + // Expect an empty configuration. + verifyDescribeEntityQuotas(entity, Map.empty) + + // Add two configuration entries. + alterEntityQuotas(entity, Map( + (ProducerByteRateProp -> Some(10000.0)), + (ConsumerByteRateProp -> Some(20000.0)) + ), validateOnly = false) + + verifyDescribeEntityQuotas(entity, Map( + (ProducerByteRateProp -> 10000.0), + (ConsumerByteRateProp -> 20000.0) + )) + + // Update an existing entry. + alterEntityQuotas(entity, Map( + (ProducerByteRateProp -> Some(15000.0)) + ), validateOnly = false) + + verifyDescribeEntityQuotas(entity, Map( + (ProducerByteRateProp -> 15000.0), + (ConsumerByteRateProp -> 20000.0) + )) + + // Remove an existing configuration entry. + alterEntityQuotas(entity, Map( + (ProducerByteRateProp -> None) + ), validateOnly = false) + + verifyDescribeEntityQuotas(entity, Map( + (ConsumerByteRateProp -> 20000.0) + )) + + // Remove a non-existent configuration entry. This should make no changes. + alterEntityQuotas(entity, Map( + (RequestPercentageProp -> None) + ), validateOnly = false) + + verifyDescribeEntityQuotas(entity, Map( + (ConsumerByteRateProp -> 20000.0) + )) + + // Add back a deleted configuration entry. + alterEntityQuotas(entity, Map( + (ProducerByteRateProp -> Some(5000.0)) + ), validateOnly = false) + + verifyDescribeEntityQuotas(entity, Map( + (ProducerByteRateProp -> 5000.0), + (ConsumerByteRateProp -> 20000.0) + )) + + // Perform a mixed update. 
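+    // (raise the producer byte rate, remove the consumer byte rate, and add a request percentage in a single alteration)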
+ alterEntityQuotas(entity, Map( + (ProducerByteRateProp -> Some(20000.0)), + (ConsumerByteRateProp -> None), + (RequestPercentageProp -> Some(12.3)) + ), validateOnly = false) + + verifyDescribeEntityQuotas(entity, Map( + (ProducerByteRateProp -> 20000.0), + (RequestPercentageProp -> 12.3) + )) + } + + @ClusterTest + def testAlterClientQuotasRequestValidateOnly(): Unit = { + val entity = new ClientQuotaEntity(Map((ClientQuotaEntity.USER -> "user")).asJava) + + // Set up a configuration. + alterEntityQuotas(entity, Map( + (ProducerByteRateProp -> Some(20000.0)), + (RequestPercentageProp -> Some(23.45)) + ), validateOnly = false) + + verifyDescribeEntityQuotas(entity, Map( + (ProducerByteRateProp -> 20000.0), + (RequestPercentageProp -> 23.45) + )) + + // Validate-only addition. + alterEntityQuotas(entity, Map( + (ConsumerByteRateProp -> Some(50000.0)) + ), validateOnly = true) + + verifyDescribeEntityQuotas(entity, Map( + (ProducerByteRateProp -> 20000.0), + (RequestPercentageProp -> 23.45) + )) + + // Validate-only modification. + alterEntityQuotas(entity, Map( + (ProducerByteRateProp -> Some(10000.0)) + ), validateOnly = true) + + verifyDescribeEntityQuotas(entity, Map( + (ProducerByteRateProp -> 20000.0), + (RequestPercentageProp -> 23.45) + )) + + // Validate-only removal. + alterEntityQuotas(entity, Map( + (RequestPercentageProp -> None) + ), validateOnly = true) + + verifyDescribeEntityQuotas(entity, Map( + (ProducerByteRateProp -> 20000.0), + (RequestPercentageProp -> 23.45) + )) + + // Validate-only mixed update. + alterEntityQuotas(entity, Map( + (ProducerByteRateProp -> Some(10000.0)), + (ConsumerByteRateProp -> Some(50000.0)), + (RequestPercentageProp -> None) + ), validateOnly = true) + + verifyDescribeEntityQuotas(entity, Map( + (ProducerByteRateProp -> 20000.0), + (RequestPercentageProp -> 23.45) + )) + } + + @ClusterTest(clusterType = Type.ZK) // No SCRAM for Raft yet + def testClientQuotasForScramUsers(): Unit = { + val userName = "user" + + val results = cluster.createAdminClient().alterUserScramCredentials(util.Arrays.asList( + new UserScramCredentialUpsertion(userName, new ScramCredentialInfo(ScramMechanism.SCRAM_SHA_256, 4096), "password"))) + results.all.get + + val entity = new ClientQuotaEntity(Map(ClientQuotaEntity.USER -> userName).asJava) + + verifyDescribeEntityQuotas(entity, Map.empty) + + alterEntityQuotas(entity, Map( + (ProducerByteRateProp -> Some(10000.0)), + (ConsumerByteRateProp -> Some(20000.0)) + ), validateOnly = false) + + verifyDescribeEntityQuotas(entity, Map( + (ProducerByteRateProp -> 10000.0), + (ConsumerByteRateProp -> 20000.0) + )) + } + + @ClusterTest + def testAlterIpQuotasRequest(): Unit = { + val knownHost = "1.2.3.4" + val unknownHost = "2.3.4.5" + val entity = toIpEntity(Some(knownHost)) + val defaultEntity = toIpEntity(Some(null)) + val entityFilter = ClientQuotaFilterComponent.ofEntity(ClientQuotaEntity.IP, knownHost) + val defaultEntityFilter = ClientQuotaFilterComponent.ofDefaultEntity(ClientQuotaEntity.IP) + val allIpEntityFilter = ClientQuotaFilterComponent.ofEntityType(ClientQuotaEntity.IP) + + def verifyIpQuotas(entityFilter: ClientQuotaFilterComponent, expectedMatches: Map[ClientQuotaEntity, Double]): Unit = { + TestUtils.tryUntilNoAssertionError() { + val result = describeClientQuotas(ClientQuotaFilter.containsOnly(List(entityFilter).asJava)) + assertEquals(expectedMatches.keySet, result.asScala.keySet) + result.asScala.foreach { case (entity, props) => + assertEquals(Set(IpConnectionRateProp), props.asScala.keySet) + 
assertEquals(expectedMatches(entity), props.get(IpConnectionRateProp)) + val entityName = entity.entries.get(ClientQuotaEntity.IP) + // ClientQuotaEntity with null name maps to default entity + val entityIp = if (entityName == null) + InetAddress.getByName(unknownHost) + else + InetAddress.getByName(entityName) + var currentServerQuota = 0 + currentServerQuota = cluster.brokerSocketServers().asScala.head.connectionQuotas.connectionRateForIp(entityIp) + assertTrue(Math.abs(expectedMatches(entity) - currentServerQuota) < 0.01, + s"Connection quota of $entity is not ${expectedMatches(entity)} but $currentServerQuota") + } + } + } + + // Expect an empty configuration. + verifyIpQuotas(allIpEntityFilter, Map.empty) + + // Add a configuration entry. + alterEntityQuotas(entity, Map(IpConnectionRateProp -> Some(100.0)), validateOnly = false) + verifyIpQuotas(entityFilter, Map(entity -> 100.0)) + + // update existing entry + alterEntityQuotas(entity, Map(IpConnectionRateProp -> Some(150.0)), validateOnly = false) + verifyIpQuotas(entityFilter, Map(entity -> 150.0)) + + // update default value + alterEntityQuotas(defaultEntity, Map(IpConnectionRateProp -> Some(200.0)), validateOnly = false) + verifyIpQuotas(defaultEntityFilter, Map(defaultEntity -> 200.0)) + + // describe all IP quotas + verifyIpQuotas(allIpEntityFilter, Map(entity -> 150.0, defaultEntity -> 200.0)) + + // remove entry + alterEntityQuotas(entity, Map(IpConnectionRateProp -> None), validateOnly = false) + verifyIpQuotas(entityFilter, Map.empty) + + // remove default value + alterEntityQuotas(defaultEntity, Map(IpConnectionRateProp -> None), validateOnly = false) + verifyIpQuotas(allIpEntityFilter, Map.empty) + } + + @ClusterTest + def testAlterClientQuotasInvalidRequests(): Unit = { + var entity = new ClientQuotaEntity(Map((ClientQuotaEntity.USER -> "")).asJava) + assertThrows(classOf[InvalidRequestException], () => alterEntityQuotas(entity, Map((RequestPercentageProp -> Some(12.34))), validateOnly = true)) + + entity = new ClientQuotaEntity(Map((ClientQuotaEntity.CLIENT_ID -> "")).asJava) + assertThrows(classOf[InvalidRequestException], () => alterEntityQuotas(entity, Map((RequestPercentageProp -> Some(12.34))), validateOnly = true)) + + entity = new ClientQuotaEntity(Map(("" -> "name")).asJava) + assertThrows(classOf[InvalidRequestException], () => alterEntityQuotas(entity, Map((RequestPercentageProp -> Some(12.34))), validateOnly = true)) + + entity = new ClientQuotaEntity(Map.empty.asJava) + assertThrows(classOf[InvalidRequestException], () => alterEntityQuotas(entity, Map((ProducerByteRateProp -> Some(10000.5))), validateOnly = true)) + + entity = new ClientQuotaEntity(Map((ClientQuotaEntity.USER -> "user")).asJava) + assertThrows(classOf[InvalidRequestException], () => alterEntityQuotas(entity, Map(("bad" -> Some(1.0))), validateOnly = true)) + + entity = new ClientQuotaEntity(Map((ClientQuotaEntity.USER -> "user")).asJava) + assertThrows(classOf[InvalidRequestException], () => alterEntityQuotas(entity, Map((ProducerByteRateProp -> Some(10000.5))), validateOnly = true)) + } + + private def expectInvalidRequestWithMessage(runnable: => Unit, expectedMessage: String): Unit = { + val exception = assertThrows(classOf[InvalidRequestException], () => runnable) + assertTrue(exception.getMessage.contains(expectedMessage), s"Expected message $exception to contain $expectedMessage") + } + + @ClusterTest + def testAlterClientQuotasInvalidEntityCombination(): Unit = { + val userAndIpEntity = new ClientQuotaEntity(Map(ClientQuotaEntity.USER 
-> "user", ClientQuotaEntity.IP -> "1.2.3.4").asJava) + val clientAndIpEntity = new ClientQuotaEntity(Map(ClientQuotaEntity.CLIENT_ID -> "client", ClientQuotaEntity.IP -> "1.2.3.4").asJava) + val expectedExceptionMessage = "Invalid quota entity combination" + expectInvalidRequestWithMessage(alterEntityQuotas(userAndIpEntity, Map(RequestPercentageProp -> Some(12.34)), + validateOnly = true), expectedExceptionMessage) + expectInvalidRequestWithMessage(alterEntityQuotas(clientAndIpEntity, Map(RequestPercentageProp -> Some(12.34)), + validateOnly = true), expectedExceptionMessage) + } + + @ClusterTest + def testAlterClientQuotasBadIp(): Unit = { + val invalidHostPatternEntity = new ClientQuotaEntity(Map(ClientQuotaEntity.IP -> "abc-123").asJava) + val unresolvableHostEntity = new ClientQuotaEntity(Map(ClientQuotaEntity.IP -> "ip").asJava) + val expectedExceptionMessage = "not a valid IP" + expectInvalidRequestWithMessage(alterEntityQuotas(invalidHostPatternEntity, Map(IpConnectionRateProp -> Some(50.0)), + validateOnly = true), expectedExceptionMessage) + expectInvalidRequestWithMessage(alterEntityQuotas(unresolvableHostEntity, Map(IpConnectionRateProp -> Some(50.0)), + validateOnly = true), expectedExceptionMessage) + } + + @ClusterTest + def testDescribeClientQuotasInvalidFilterCombination(): Unit = { + val ipFilterComponent = ClientQuotaFilterComponent.ofEntityType(ClientQuotaEntity.IP) + val userFilterComponent = ClientQuotaFilterComponent.ofEntityType(ClientQuotaEntity.USER) + val clientIdFilterComponent = ClientQuotaFilterComponent.ofEntityType(ClientQuotaEntity.CLIENT_ID) + val expectedExceptionMessage = "Invalid entity filter component combination" + expectInvalidRequestWithMessage(describeClientQuotas(ClientQuotaFilter.contains(List(ipFilterComponent, userFilterComponent).asJava)), + expectedExceptionMessage) + expectInvalidRequestWithMessage(describeClientQuotas(ClientQuotaFilter.contains(List(ipFilterComponent, clientIdFilterComponent).asJava)), + expectedExceptionMessage) + } + + // Entities to be matched against. 
+ private val matchUserClientEntities = List( + (Some("user-1"), Some("client-id-1"), 50.50), + (Some("user-2"), Some("client-id-1"), 51.51), + (Some("user-3"), Some("client-id-2"), 52.52), + (Some(null), Some("client-id-1"), 53.53), + (Some("user-1"), Some(null), 54.54), + (Some("user-3"), Some(null), 55.55), + (Some("user-1"), None, 56.56), + (Some("user-2"), None, 57.57), + (Some("user-3"), None, 58.58), + (Some(null), None, 59.59), + (None, Some("client-id-2"), 60.60) + ).map { case (u, c, v) => (toClientEntity(u, c), v) } + + private val matchIpEntities = List( + (Some("1.2.3.4"), 10.0), + (Some("2.3.4.5"), 20.0) + ).map { case (ip, quota) => (toIpEntity(ip), quota)} + + private def setupDescribeClientQuotasMatchTest() = { + val userClientQuotas = matchUserClientEntities.map { case (e, v) => + e -> Map((RequestPercentageProp, Some(v))) + }.toMap + val ipQuotas = matchIpEntities.map { case (e, v) => + e -> Map((IpConnectionRateProp, Some(v))) + }.toMap + val result = alterClientQuotas(userClientQuotas ++ ipQuotas, validateOnly = false) + (matchUserClientEntities ++ matchIpEntities).foreach(e => result(e._1).get(10, TimeUnit.SECONDS)) + } + + @ClusterTest + def testDescribeClientQuotasMatchExact(): Unit = { + setupDescribeClientQuotasMatchTest() + + def matchEntity(entity: ClientQuotaEntity) = { + val components = entity.entries.asScala.map { case (entityType, entityName) => + entityName match { + case null => ClientQuotaFilterComponent.ofDefaultEntity(entityType) + case name => ClientQuotaFilterComponent.ofEntity(entityType, name) + } + } + describeClientQuotas(ClientQuotaFilter.containsOnly(components.toList.asJava)) + } + + // Test exact matches. + matchUserClientEntities.foreach { case (e, v) => + TestUtils.tryUntilNoAssertionError() { + val result = matchEntity(e) + assertEquals(1, result.size) + assertTrue(result.get(e) != null) + val value = result.get(e).get(RequestPercentageProp) + assertNotNull(value) + assertEquals(value, v, 1e-6) + } + } + + // Entities not contained in `matchEntityList`. + val notMatchEntities = List( + (Some("user-1"), Some("client-id-2")), + (Some("user-3"), Some("client-id-1")), + (Some("user-2"), Some(null)), + (Some("user-4"), None), + (Some(null), Some("client-id-2")), + (None, Some("client-id-1")), + (None, Some("client-id-3")), + ).map { case (u, c) => + new ClientQuotaEntity((u.map((ClientQuotaEntity.USER, _)) ++ + c.map((ClientQuotaEntity.CLIENT_ID, _))).toMap.asJava) + } + + // Verify exact matches of the non-matches returns empty. 
+ notMatchEntities.foreach { e => + val result = matchEntity(e) + assertEquals(0, result.size) + } + } + + @ClusterTest + def testDescribeClientQuotasMatchPartial(): Unit = { + setupDescribeClientQuotasMatchTest() + + def testMatchEntities(filter: ClientQuotaFilter, expectedMatchSize: Int, partition: ClientQuotaEntity => Boolean): Unit = { + TestUtils.tryUntilNoAssertionError() { + val result = describeClientQuotas(filter) + val (expectedMatches, _) = (matchUserClientEntities ++ matchIpEntities).partition(e => partition(e._1)) + assertEquals(expectedMatchSize, expectedMatches.size) // for test verification + assertEquals(expectedMatchSize, result.size, s"Failed to match $expectedMatchSize entities for $filter") + val expectedMatchesMap = expectedMatches.toMap + matchUserClientEntities.foreach { case (entity, expectedValue) => + if (expectedMatchesMap.contains(entity)) { + val config = result.get(entity) + assertNotNull(config) + val value = config.get(RequestPercentageProp) + assertNotNull(value) + assertEquals(expectedValue, value, 1e-6) + } else { + assertNull(result.get(entity)) + } + } + matchIpEntities.foreach { case (entity, expectedValue) => + if (expectedMatchesMap.contains(entity)) { + val config = result.get(entity) + assertNotNull(config) + val value = config.get(IpConnectionRateProp) + assertNotNull(value) + assertEquals(expectedValue, value, 1e-6) + } else { + assertNull(result.get(entity)) + } + } + } + } + + // Match open-ended existing user. + testMatchEntities( + ClientQuotaFilter.contains(List(ClientQuotaFilterComponent.ofEntity(ClientQuotaEntity.USER, "user-1")).asJava), 3, + entity => entity.entries.get(ClientQuotaEntity.USER) == "user-1" + ) + + // Match open-ended non-existent user. + testMatchEntities( + ClientQuotaFilter.contains(List(ClientQuotaFilterComponent.ofEntity(ClientQuotaEntity.USER, "unknown")).asJava), 0, + entity => false + ) + + // Match open-ended existing client ID. + testMatchEntities( + ClientQuotaFilter.contains(List(ClientQuotaFilterComponent.ofEntity(ClientQuotaEntity.CLIENT_ID, "client-id-2")).asJava), 2, + entity => entity.entries.get(ClientQuotaEntity.CLIENT_ID) == "client-id-2" + ) + + // Match open-ended default user. + testMatchEntities( + ClientQuotaFilter.contains(List(ClientQuotaFilterComponent.ofDefaultEntity(ClientQuotaEntity.USER)).asJava), 2, + entity => entity.entries.containsKey(ClientQuotaEntity.USER) && entity.entries.get(ClientQuotaEntity.USER) == null + ) + + // Match close-ended existing user. + testMatchEntities( + ClientQuotaFilter.containsOnly(List(ClientQuotaFilterComponent.ofEntity(ClientQuotaEntity.USER, "user-2")).asJava), 1, + entity => entity.entries.get(ClientQuotaEntity.USER) == "user-2" && !entity.entries.containsKey(ClientQuotaEntity.CLIENT_ID) + ) + + // Match close-ended existing client ID that has no matching entity. + testMatchEntities( + ClientQuotaFilter.containsOnly(List(ClientQuotaFilterComponent.ofEntity(ClientQuotaEntity.CLIENT_ID, "client-id-1")).asJava), 0, + entity => false + ) + + // Match against all entities with the user type in a close-ended match. + testMatchEntities( + ClientQuotaFilter.containsOnly(List(ClientQuotaFilterComponent.ofEntityType(ClientQuotaEntity.USER)).asJava), 4, + entity => entity.entries.containsKey(ClientQuotaEntity.USER) && !entity.entries.containsKey(ClientQuotaEntity.CLIENT_ID) + ) + + // Match against all entities with the user type in an open-ended match. 
+ testMatchEntities( + ClientQuotaFilter.contains(List(ClientQuotaFilterComponent.ofEntityType(ClientQuotaEntity.USER)).asJava), 10, + entity => entity.entries.containsKey(ClientQuotaEntity.USER) + ) + + // Match against all entities with the client ID type in a close-ended match. + testMatchEntities( + ClientQuotaFilter.containsOnly(List(ClientQuotaFilterComponent.ofEntityType(ClientQuotaEntity.CLIENT_ID)).asJava), 1, + entity => entity.entries.containsKey(ClientQuotaEntity.CLIENT_ID) && !entity.entries.containsKey(ClientQuotaEntity.USER) + ) + + // Match against all entities with the client ID type in an open-ended match. + testMatchEntities( + ClientQuotaFilter.contains(List(ClientQuotaFilterComponent.ofEntityType(ClientQuotaEntity.CLIENT_ID)).asJava), 7, + entity => entity.entries.containsKey(ClientQuotaEntity.CLIENT_ID) + ) + + // Match against all entities with IP type in an open-ended match. + testMatchEntities( + ClientQuotaFilter.contains(List(ClientQuotaFilterComponent.ofEntityType(ClientQuotaEntity.IP)).asJava), 2, + entity => entity.entries.containsKey(ClientQuotaEntity.IP) + ) + + // Match open-ended empty filter list. This should match all entities. + testMatchEntities(ClientQuotaFilter.contains(List.empty.asJava), 13, entity => true) + + // Match close-ended empty filter list. This should match no entities. + testMatchEntities(ClientQuotaFilter.containsOnly(List.empty.asJava), 0, entity => false) + } + + @ClusterTest + def testClientQuotasUnsupportedEntityTypes(): Unit = { + val entity = new ClientQuotaEntity(Map(("other" -> "name")).asJava) + assertThrows(classOf[UnsupportedVersionException], () => verifyDescribeEntityQuotas(entity, Map.empty)) + } + + @ClusterTest + def testClientQuotasSanitized(): Unit = { + // An entity with name that must be sanitized when writing to Zookeeper. + val entity = new ClientQuotaEntity(Map((ClientQuotaEntity.USER -> "user with spaces")).asJava) + + alterEntityQuotas(entity, Map( + (ProducerByteRateProp -> Some(20000.0)), + ), validateOnly = false) + + verifyDescribeEntityQuotas(entity, Map( + (ProducerByteRateProp -> 20000.0), + )) + } + + @ClusterTest + def testClientQuotasWithDefaultName(): Unit = { + // An entity using the name associated with the default entity name. The entity's name should be sanitized so + // that it does not conflict with the default entity name. + val entity = new ClientQuotaEntity(Map((ClientQuotaEntity.CLIENT_ID -> ConfigEntityName.Default)).asJava) + alterEntityQuotas(entity, Map((ProducerByteRateProp -> Some(20000.0))), validateOnly = false) + verifyDescribeEntityQuotas(entity, Map((ProducerByteRateProp -> 20000.0))) + + // This should not match. 
+ val result = describeClientQuotas( + ClientQuotaFilter.containsOnly(List(ClientQuotaFilterComponent.ofDefaultEntity(ClientQuotaEntity.CLIENT_ID)).asJava)) + assert(result.isEmpty) + } + + private def verifyDescribeEntityQuotas(entity: ClientQuotaEntity, quotas: Map[String, Double]) = { + TestUtils.tryUntilNoAssertionError(waitTime = 5000L) { + val components = entity.entries.asScala.map { case (entityType, entityName) => + Option(entityName).map{ name => ClientQuotaFilterComponent.ofEntity(entityType, name)} + .getOrElse(ClientQuotaFilterComponent.ofDefaultEntity(entityType) + ) + } + val describe = describeClientQuotas(ClientQuotaFilter.containsOnly(components.toList.asJava)) + if (quotas.isEmpty) { + assertEquals(0, describe.size) + } else { + assertEquals(1, describe.size) + val configs = describe.get(entity) + assertNotNull(configs) + assertEquals(quotas.size, configs.size) + quotas.foreach { case (k, v) => + val value = configs.get(k) + assertNotNull(value) + assertEquals(v, value, 1e-6) + } + } + } + } + + private def toClientEntity(user: Option[String], clientId: Option[String]) = + new ClientQuotaEntity((user.map((ClientQuotaEntity.USER -> _)) ++ clientId.map((ClientQuotaEntity.CLIENT_ID -> _))).toMap.asJava) + + private def toIpEntity(ip: Option[String]) = new ClientQuotaEntity(ip.map(ClientQuotaEntity.IP -> _).toMap.asJava) + + private def describeClientQuotas(filter: ClientQuotaFilter) = { + val result = new KafkaFutureImpl[java.util.Map[ClientQuotaEntity, java.util.Map[String, java.lang.Double]]] + sendDescribeClientQuotasRequest(filter).complete(result) + try result.get catch { + case e: ExecutionException => throw e.getCause + } + } + + private def sendDescribeClientQuotasRequest(filter: ClientQuotaFilter): DescribeClientQuotasResponse = { + val request = new DescribeClientQuotasRequest.Builder(filter).build() + IntegrationTestUtils.connectAndReceive[DescribeClientQuotasResponse](request, + destination = cluster.anyBrokerSocketServer(), + listenerName = cluster.clientListener()) + } + + private def alterEntityQuotas(entity: ClientQuotaEntity, alter: Map[String, Option[Double]], validateOnly: Boolean) = + try alterClientQuotas(Map(entity -> alter), validateOnly).get(entity).get.get(10, TimeUnit.SECONDS) catch { + case e: ExecutionException => throw e.getCause + } + + private def alterClientQuotas(request: Map[ClientQuotaEntity, Map[String, Option[Double]]], validateOnly: Boolean) = { + val entries = request.map { case (entity, alter) => + val ops = alter.map { case (key, value) => + new ClientQuotaAlteration.Op(key, value.map(Double.box).getOrElse(null)) + }.asJavaCollection + new ClientQuotaAlteration(entity, ops) + } + + val response = request.map(e => (e._1 -> new KafkaFutureImpl[Void])).asJava + sendAlterClientQuotasRequest(entries, validateOnly).complete(response) + val result = response.asScala + assertEquals(request.size, result.size) + request.foreach(e => assertTrue(result.contains(e._1))) + result + } + + private def sendAlterClientQuotasRequest(entries: Iterable[ClientQuotaAlteration], validateOnly: Boolean): AlterClientQuotasResponse = { + val request = new AlterClientQuotasRequest.Builder(entries.asJavaCollection, validateOnly).build() + IntegrationTestUtils.connectAndReceive[AlterClientQuotasResponse](request, + destination = cluster.anyBrokerSocketServer(), + listenerName = cluster.clientListener()) + } + +} diff --git a/core/src/test/scala/unit/kafka/server/ClientRequestQuotaManagerTest.scala 
b/core/src/test/scala/unit/kafka/server/ClientRequestQuotaManagerTest.scala new file mode 100644 index 0000000..db2dcea --- /dev/null +++ b/core/src/test/scala/unit/kafka/server/ClientRequestQuotaManagerTest.scala @@ -0,0 +1,89 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package kafka.server + +import kafka.server.QuotaType.Request +import org.apache.kafka.common.metrics.Quota + +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.Test + +class ClientRequestQuotaManagerTest extends BaseClientQuotaManagerTest { + private val config = ClientQuotaManagerConfig() + + @Test + def testRequestPercentageQuotaViolation(): Unit = { + val clientRequestQuotaManager = new ClientRequestQuotaManager(config, metrics, time, "", None) + clientRequestQuotaManager.updateQuota(Some("ANONYMOUS"), Some("test-client"), Some("test-client"), Some(Quota.upperBound(1))) + val queueSizeMetric = metrics.metrics().get(metrics.metricName("queue-size", Request.toString, "")) + def millisToPercent(millis: Double) = millis * 1000 * 1000 * ClientRequestQuotaManager.NanosToPercentagePerSecond + try { + // We have 10 second windows. Make sure that there is no quota violation + // if we are under the quota + for (_ <- 0 until 10) { + assertEquals(0, maybeRecord(clientRequestQuotaManager, "ANONYMOUS", "test-client", millisToPercent(4))) + time.sleep(1000) + } + assertEquals(0, queueSizeMetric.metricValue.asInstanceOf[Double].toInt) + + // Create a spike. + // quota = 1% (10ms per second) + // 4*10 + 67.1 = 107.1/10.5 = 10.2ms per second. 
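+      // i.e. total recorded usage is 4ms * 10 + 67.1ms = 107.1ms over the 10.5 second span,
+      // roughly 10.2ms per second, which exceeds the 10ms-per-second (1%) quota.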
+ // (10.2 - quota)/quota*window-size = (10.2-10)/10*10.5 seconds = 210ms + // 10.5 seconds interval because the last window is half complete + time.sleep(500) + val throttleTime = maybeRecord(clientRequestQuotaManager, "ANONYMOUS", "test-client", millisToPercent(67.1)) + + assertEquals(210, throttleTime, "Should be throttled") + + throttle(clientRequestQuotaManager, "ANONYMOUS", "test-client", throttleTime, callback) + assertEquals(1, queueSizeMetric.metricValue.asInstanceOf[Double].toInt) + // After a request is delayed, the callback cannot be triggered immediately + clientRequestQuotaManager.throttledChannelReaper.doWork() + assertEquals(0, numCallbacks) + time.sleep(throttleTime) + + // Callback can only be triggered after the delay time passes + clientRequestQuotaManager.throttledChannelReaper.doWork() + assertEquals(0, queueSizeMetric.metricValue.asInstanceOf[Double].toInt) + assertEquals(1, numCallbacks) + + // Could continue to see delays until the bursty sample disappears + for (_ <- 0 until 11) { + maybeRecord(clientRequestQuotaManager, "ANONYMOUS", "test-client", millisToPercent(4)) + time.sleep(1000) + } + + assertEquals(0, + maybeRecord(clientRequestQuotaManager, "ANONYMOUS", "test-client", 0), "Should be unthrottled since bursty sample has rolled over") + + // Create a very large spike which requires > one quota window to bring within quota + assertEquals(1000, maybeRecord(clientRequestQuotaManager, "ANONYMOUS", "test-client", millisToPercent(500))) + for (_ <- 0 until 10) { + time.sleep(1000) + assertEquals(1000, maybeRecord(clientRequestQuotaManager, "ANONYMOUS", "test-client", 0)) + } + time.sleep(1000) + assertEquals(0, + maybeRecord(clientRequestQuotaManager, "ANONYMOUS", "test-client", 0), "Should be unthrottled since bursty sample has rolled over") + + } finally { + clientRequestQuotaManager.shutdown() + } + } +} + diff --git a/core/src/test/scala/unit/kafka/server/ControllerApisTest.scala b/core/src/test/scala/unit/kafka/server/ControllerApisTest.scala new file mode 100644 index 0000000..2176d23 --- /dev/null +++ b/core/src/test/scala/unit/kafka/server/ControllerApisTest.scala @@ -0,0 +1,810 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package kafka.server + +import java.net.InetAddress +import java.util +import java.util.Collections.singletonList +import java.util.Properties +import java.util.concurrent.{CompletableFuture, ExecutionException} + +import kafka.network.RequestChannel +import kafka.raft.RaftManager +import kafka.server.QuotaFactory.QuotaManagers +import kafka.test.MockController +import kafka.utils.{MockTime, NotNothing} +import org.apache.kafka.clients.admin.AlterConfigOp +import org.apache.kafka.common.Uuid.ZERO_UUID +import org.apache.kafka.common.acl.AclOperation +import org.apache.kafka.common.config.{ConfigResource, TopicConfig} +import org.apache.kafka.common.errors._ +import org.apache.kafka.common.memory.MemoryPool +import org.apache.kafka.common.message.AlterConfigsRequestData.{AlterConfigsResource => OldAlterConfigsResource, AlterConfigsResourceCollection => OldAlterConfigsResourceCollection, AlterableConfig => OldAlterableConfig, AlterableConfigCollection => OldAlterableConfigCollection} +import org.apache.kafka.common.message.AlterConfigsResponseData.{AlterConfigsResourceResponse => OldAlterConfigsResourceResponse} +import org.apache.kafka.common.message.ApiMessageType.ListenerType +import org.apache.kafka.common.message.CreatePartitionsRequestData.CreatePartitionsTopic +import org.apache.kafka.common.message.CreatePartitionsResponseData.CreatePartitionsTopicResult +import org.apache.kafka.common.message.CreateTopicsRequestData.{CreatableTopic, CreatableTopicCollection} +import org.apache.kafka.common.message.CreateTopicsResponseData.CreatableTopicResult +import org.apache.kafka.common.message.DeleteTopicsRequestData.DeleteTopicState +import org.apache.kafka.common.message.DeleteTopicsResponseData.DeletableTopicResult +import org.apache.kafka.common.message.IncrementalAlterConfigsRequestData.{AlterConfigsResource, AlterConfigsResourceCollection, AlterableConfig, AlterableConfigCollection} +import org.apache.kafka.common.message.IncrementalAlterConfigsResponseData.AlterConfigsResourceResponse +import org.apache.kafka.common.message.{CreateTopicsRequestData, _} +import org.apache.kafka.common.network.{ClientInformation, ListenerName} +import org.apache.kafka.common.protocol.Errors._ +import org.apache.kafka.common.protocol.{ApiKeys, ApiMessage, Errors} +import org.apache.kafka.common.requests._ +import org.apache.kafka.common.resource.{PatternType, Resource, ResourcePattern, ResourceType} +import org.apache.kafka.common.security.auth.{KafkaPrincipal, SecurityProtocol} +import org.apache.kafka.common.{ElectionType, Uuid} +import org.apache.kafka.controller.Controller +import org.apache.kafka.server.authorizer.{Action, AuthorizableRequestContext, AuthorizationResult, Authorizer} +import org.apache.kafka.server.common.ApiMessageAndVersion +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.{AfterEach, Test} +import org.mockito.ArgumentMatchers._ +import org.mockito.Mockito._ +import org.mockito.{ArgumentCaptor, ArgumentMatchers} + +import scala.annotation.nowarn +import scala.jdk.CollectionConverters._ +import scala.reflect.ClassTag + +class ControllerApisTest { + private val nodeId = 1 + private val brokerRack = "Rack1" + private val clientID = "Client1" + private val requestChannelMetrics: RequestChannel.Metrics = mock(classOf[RequestChannel.Metrics]) + private val requestChannel: RequestChannel = mock(classOf[RequestChannel]) + private val time = new MockTime + private val clientQuotaManager: ClientQuotaManager = mock(classOf[ClientQuotaManager]) + private val 
clientRequestQuotaManager: ClientRequestQuotaManager = mock(classOf[ClientRequestQuotaManager]) + private val clientControllerQuotaManager: ControllerMutationQuotaManager = mock(classOf[ControllerMutationQuotaManager]) + private val replicaQuotaManager: ReplicationQuotaManager = mock(classOf[ReplicationQuotaManager]) + private val raftManager: RaftManager[ApiMessageAndVersion] = mock(classOf[RaftManager[ApiMessageAndVersion]]) + + private val quotas = QuotaManagers( + clientQuotaManager, + clientQuotaManager, + clientRequestQuotaManager, + clientControllerQuotaManager, + replicaQuotaManager, + replicaQuotaManager, + replicaQuotaManager, + None) + + private def createControllerApis(authorizer: Option[Authorizer], + controller: Controller, + props: Properties = new Properties()): ControllerApis = { + props.put(KafkaConfig.NodeIdProp, nodeId: java.lang.Integer) + props.put(KafkaConfig.ProcessRolesProp, "controller") + props.put(KafkaConfig.ControllerListenerNamesProp, "PLAINTEXT") + props.put(KafkaConfig.QuorumVotersProp, s"$nodeId@localhost:9092") + new ControllerApis( + requestChannel, + authorizer, + quotas, + time, + Map.empty, + controller, + raftManager, + new KafkaConfig(props), + MetaProperties("JgxuGe9URy-E-ceaL04lEw", nodeId = nodeId), + Seq.empty, + new SimpleApiVersionManager(ListenerType.CONTROLLER) + ) + } + + /** + * Build a RequestChannel.Request from the AbstractRequest + * + * @param request - AbstractRequest + * @param listenerName - Default listener for the RequestChannel + * @tparam T - Type of AbstractRequest + * @return + */ + private def buildRequest[T <: AbstractRequest]( + request: AbstractRequest, + listenerName: ListenerName = ListenerName.forSecurityProtocol(SecurityProtocol.PLAINTEXT) + ): RequestChannel.Request = { + val buffer = request.serializeWithHeader(new RequestHeader(request.apiKey, request.version, clientID, 0)) + + // read the header from the buffer first so that the body can be read next from the Request constructor + val header = RequestHeader.parse(buffer) + val context = new RequestContext(header, "1", InetAddress.getLocalHost, KafkaPrincipal.ANONYMOUS, + listenerName, SecurityProtocol.PLAINTEXT, ClientInformation.EMPTY, false) + new RequestChannel.Request(processor = 1, context = context, startTimeNanos = 0, MemoryPool.NONE, buffer, + requestChannelMetrics) + } + + def createDenyAllAuthorizer(): Authorizer = { + val authorizer = mock(classOf[Authorizer]) + when(authorizer.authorize( + any(classOf[AuthorizableRequestContext]), + any(classOf[java.util.List[Action]]) + )).thenReturn( + singletonList(AuthorizationResult.DENIED) + ) + authorizer + } + + @Test + def testUnauthorizedFetch(): Unit = { + assertThrows(classOf[ClusterAuthorizationException], () => createControllerApis( + Some(createDenyAllAuthorizer()), new MockController.Builder().build()). 
+ handleFetch(buildRequest(new FetchRequest(new FetchRequestData(), 12)))) + } + + @Test + def testFetchSentToKRaft(): Unit = { + when( + raftManager.handleRequest( + any(classOf[RequestHeader]), + any(classOf[ApiMessage]), + any(classOf[Long]) + ) + ).thenReturn( + new CompletableFuture[ApiMessage]() + ) + + createControllerApis(None, new MockController.Builder().build()) + .handleFetch(buildRequest(new FetchRequest(new FetchRequestData(), 12))) + + verify(raftManager).handleRequest( + ArgumentMatchers.any(), + ArgumentMatchers.any(), + ArgumentMatchers.any() + ) + } + + @Test + def testUnauthorizedFetchSnapshot(): Unit = { + assertThrows(classOf[ClusterAuthorizationException], () => createControllerApis( + Some(createDenyAllAuthorizer()), new MockController.Builder().build()). + handleFetchSnapshot(buildRequest(new FetchSnapshotRequest(new FetchSnapshotRequestData(), 0)))) + } + + @Test + def testFetchSnapshotSentToKRaft(): Unit = { + when( + raftManager.handleRequest( + any(classOf[RequestHeader]), + any(classOf[ApiMessage]), + any(classOf[Long]) + ) + ).thenReturn( + new CompletableFuture[ApiMessage]() + ) + + createControllerApis(None, new MockController.Builder().build()) + .handleFetchSnapshot(buildRequest(new FetchSnapshotRequest(new FetchSnapshotRequestData(), 0))) + + verify(raftManager).handleRequest( + ArgumentMatchers.any(), + ArgumentMatchers.any(), + ArgumentMatchers.any() + ) + } + + @Test + def testUnauthorizedVote(): Unit = { + assertThrows(classOf[ClusterAuthorizationException], () => createControllerApis( + Some(createDenyAllAuthorizer()), new MockController.Builder().build()). + handleVote(buildRequest(new VoteRequest.Builder(new VoteRequestData()).build(0)))) + } + + @Test + def testHandleLegacyAlterConfigsErrors(): Unit = { + val requestData = new AlterConfigsRequestData().setResources( + new OldAlterConfigsResourceCollection(util.Arrays.asList( + new OldAlterConfigsResource(). + setResourceName("1"). + setResourceType(ConfigResource.Type.BROKER.id()). + setConfigs(new OldAlterableConfigCollection(util.Arrays.asList(new OldAlterableConfig(). + setName(KafkaConfig.LogCleanerBackoffMsProp). + setValue("100000")).iterator())), + new OldAlterConfigsResource(). + setResourceName("2"). + setResourceType(ConfigResource.Type.BROKER.id()). + setConfigs(new OldAlterableConfigCollection(util.Arrays.asList(new OldAlterableConfig(). + setName(KafkaConfig.LogCleanerBackoffMsProp). + setValue("100000")).iterator())), + new OldAlterConfigsResource(). + setResourceName("2"). + setResourceType(ConfigResource.Type.BROKER.id()). + setConfigs(new OldAlterableConfigCollection(util.Arrays.asList(new OldAlterableConfig(). + setName(KafkaConfig.LogCleanerBackoffMsProp). + setValue("100000")).iterator())), + new OldAlterConfigsResource(). + setResourceName("baz"). + setResourceType(123.toByte). + setConfigs(new OldAlterableConfigCollection(util.Arrays.asList(new OldAlterableConfig(). + setName("foo"). 
+ setValue("bar")).iterator())), + ).iterator())) + val request = buildRequest(new AlterConfigsRequest(requestData, 0)) + createControllerApis(Some(createDenyAllAuthorizer()), + new MockController.Builder().build()).handleLegacyAlterConfigs(request) + val capturedResponse: ArgumentCaptor[AbstractResponse] = + ArgumentCaptor.forClass(classOf[AbstractResponse]) + verify(requestChannel).sendResponse( + ArgumentMatchers.eq(request), + capturedResponse.capture(), + ArgumentMatchers.eq(None)) + assertNotNull(capturedResponse.getValue) + val response = capturedResponse.getValue.asInstanceOf[AlterConfigsResponse] + assertEquals(Set( + new OldAlterConfigsResourceResponse(). + setErrorCode(INVALID_REQUEST.code()). + setErrorMessage("Duplicate resource."). + setResourceName("2"). + setResourceType(ConfigResource.Type.BROKER.id()), + new OldAlterConfigsResourceResponse(). + setErrorCode(UNSUPPORTED_VERSION.code()). + setErrorMessage("Unknown resource type 123."). + setResourceName("baz"). + setResourceType(123.toByte), + new OldAlterConfigsResourceResponse(). + setErrorCode(CLUSTER_AUTHORIZATION_FAILED.code()). + setErrorMessage("Cluster authorization failed."). + setResourceName("1"). + setResourceType(ConfigResource.Type.BROKER.id())), + response.data().responses().asScala.toSet) + } + + @Test + def testUnauthorizedBeginQuorumEpoch(): Unit = { + assertThrows(classOf[ClusterAuthorizationException], () => createControllerApis( + Some(createDenyAllAuthorizer()), new MockController.Builder().build()). + handleBeginQuorumEpoch(buildRequest(new BeginQuorumEpochRequest.Builder( + new BeginQuorumEpochRequestData()).build(0)))) + } + + @Test + def testUnauthorizedEndQuorumEpoch(): Unit = { + assertThrows(classOf[ClusterAuthorizationException], () => createControllerApis( + Some(createDenyAllAuthorizer()), new MockController.Builder().build()). + handleEndQuorumEpoch(buildRequest(new EndQuorumEpochRequest.Builder( + new EndQuorumEpochRequestData()).build(0)))) + } + + @Test + def testUnauthorizedDescribeQuorum(): Unit = { + assertThrows(classOf[ClusterAuthorizationException], () => createControllerApis( + Some(createDenyAllAuthorizer()), new MockController.Builder().build()). + handleDescribeQuorum(buildRequest(new DescribeQuorumRequest.Builder( + new DescribeQuorumRequestData()).build(0)))) + } + + @Test + def testUnauthorizedHandleAlterIsrRequest(): Unit = { + assertThrows(classOf[ClusterAuthorizationException], () => createControllerApis( + Some(createDenyAllAuthorizer()), new MockController.Builder().build()). + handleAlterIsrRequest(buildRequest(new AlterIsrRequest.Builder( + new AlterIsrRequestData()).build(0)))) + } + + @Test + def testUnauthorizedHandleBrokerHeartBeatRequest(): Unit = { + assertThrows(classOf[ClusterAuthorizationException], () => createControllerApis( + Some(createDenyAllAuthorizer()), new MockController.Builder().build()). + handleBrokerHeartBeatRequest(buildRequest(new BrokerHeartbeatRequest.Builder( + new BrokerHeartbeatRequestData()).build(0)))) + } + + @Test + def testUnauthorizedHandleUnregisterBroker(): Unit = { + assertThrows(classOf[ClusterAuthorizationException], () => createControllerApis( + Some(createDenyAllAuthorizer()), new MockController.Builder().build()). 
+ handleUnregisterBroker(buildRequest(new UnregisterBrokerRequest.Builder( + new UnregisterBrokerRequestData()).build(0)))) + } + + @Test + def testClose(): Unit = { + val apis = createControllerApis(Some(createDenyAllAuthorizer()), mock(classOf[Controller])) + apis.close() + assertTrue(apis.isClosed) + } + + @Test + def testUnauthorizedBrokerRegistration(): Unit = { + val brokerRegistrationRequest = new BrokerRegistrationRequest.Builder( + new BrokerRegistrationRequestData() + .setBrokerId(nodeId) + .setRack(brokerRack) + ).build() + + val request = buildRequest(brokerRegistrationRequest) + val capturedResponse: ArgumentCaptor[AbstractResponse] = ArgumentCaptor.forClass(classOf[AbstractResponse]) + + createControllerApis(Some(createDenyAllAuthorizer()), mock(classOf[Controller])).handle(request, + RequestLocal.withThreadConfinedCaching) + verify(requestChannel).sendResponse( + ArgumentMatchers.eq(request), + capturedResponse.capture(), + ArgumentMatchers.eq(None)) + + assertNotNull(capturedResponse.getValue) + + val brokerRegistrationResponse = capturedResponse.getValue.asInstanceOf[BrokerRegistrationResponse] + assertEquals(Map(CLUSTER_AUTHORIZATION_FAILED -> 1), + brokerRegistrationResponse.errorCounts().asScala) + } + + @Test + def testUnauthorizedHandleAlterClientQuotas(): Unit = { + assertThrows(classOf[ClusterAuthorizationException], () => createControllerApis( + Some(createDenyAllAuthorizer()), new MockController.Builder().build()). + handleAlterClientQuotas(buildRequest(new AlterClientQuotasRequest( + new AlterClientQuotasRequestData(), 0)))) + } + + @Test + def testUnauthorizedHandleIncrementalAlterConfigs(): Unit = { + val requestData = new IncrementalAlterConfigsRequestData().setResources( + new AlterConfigsResourceCollection( + util.Arrays.asList(new AlterConfigsResource(). + setResourceName("1"). + setResourceType(ConfigResource.Type.BROKER.id()). + setConfigs(new AlterableConfigCollection(util.Arrays.asList(new AlterableConfig(). + setName(KafkaConfig.LogCleanerBackoffMsProp). + setValue("100000"). + setConfigOperation(AlterConfigOp.OpType.SET.id())).iterator())), + new AlterConfigsResource(). + setResourceName("foo"). + setResourceType(ConfigResource.Type.TOPIC.id()). + setConfigs(new AlterableConfigCollection(util.Arrays.asList(new AlterableConfig(). + setName(TopicConfig.FLUSH_MS_CONFIG). + setValue("1000"). + setConfigOperation(AlterConfigOp.OpType.SET.id())).iterator())), + ).iterator())) + val request = buildRequest(new IncrementalAlterConfigsRequest.Builder(requestData).build(0)) + createControllerApis(Some(createDenyAllAuthorizer()), + new MockController.Builder().build()).handleIncrementalAlterConfigs(request) + val capturedResponse: ArgumentCaptor[AbstractResponse] = + ArgumentCaptor.forClass(classOf[AbstractResponse]) + verify(requestChannel).sendResponse( + ArgumentMatchers.eq(request), + capturedResponse.capture(), + ArgumentMatchers.eq(None)) + assertNotNull(capturedResponse.getValue) + val response = capturedResponse.getValue.asInstanceOf[IncrementalAlterConfigsResponse] + assertEquals(Set(new AlterConfigsResourceResponse(). + setErrorCode(CLUSTER_AUTHORIZATION_FAILED.code()). + setErrorMessage(CLUSTER_AUTHORIZATION_FAILED.message()). + setResourceName("1"). + setResourceType(ConfigResource.Type.BROKER.id()), + new AlterConfigsResourceResponse(). + setErrorCode(TOPIC_AUTHORIZATION_FAILED.code()). + setErrorMessage(TOPIC_AUTHORIZATION_FAILED.message()). + setResourceName("foo"). 
+ setResourceType(ConfigResource.Type.TOPIC.id())), + response.data().responses().asScala.toSet) + } + + @Test + def testInvalidIncrementalAlterConfigsResources(): Unit = { + val requestData = new IncrementalAlterConfigsRequestData().setResources( + new AlterConfigsResourceCollection(util.Arrays.asList( + new AlterConfigsResource(). + setResourceName("1"). + setResourceType(ConfigResource.Type.BROKER_LOGGER.id()). + setConfigs(new AlterableConfigCollection(util.Arrays.asList(new AlterableConfig(). + setName("kafka.server.KafkaConfig"). + setValue("TRACE"). + setConfigOperation(AlterConfigOp.OpType.SET.id())).iterator())), + new AlterConfigsResource(). + setResourceName("3"). + setResourceType(ConfigResource.Type.BROKER.id()). + setConfigs(new AlterableConfigCollection(util.Arrays.asList(new AlterableConfig(). + setName(KafkaConfig.LogCleanerBackoffMsProp). + setValue("100000"). + setConfigOperation(AlterConfigOp.OpType.SET.id())).iterator())), + new AlterConfigsResource(). + setResourceName("3"). + setResourceType(ConfigResource.Type.BROKER.id()). + setConfigs(new AlterableConfigCollection(util.Arrays.asList(new AlterableConfig(). + setName(KafkaConfig.LogCleanerBackoffMsProp). + setValue("100000"). + setConfigOperation(AlterConfigOp.OpType.SET.id())).iterator())), + new AlterConfigsResource(). + setResourceName("foo"). + setResourceType(124.toByte). + setConfigs(new AlterableConfigCollection(util.Arrays.asList(new AlterableConfig(). + setName("foo"). + setValue("bar"). + setConfigOperation(AlterConfigOp.OpType.SET.id())).iterator())), + ).iterator())) + val request = buildRequest(new IncrementalAlterConfigsRequest.Builder(requestData).build(0)) + createControllerApis(Some(createDenyAllAuthorizer()), + new MockController.Builder().build()).handleIncrementalAlterConfigs(request) + val capturedResponse: ArgumentCaptor[AbstractResponse] = + ArgumentCaptor.forClass(classOf[AbstractResponse]) + verify(requestChannel).sendResponse( + ArgumentMatchers.eq(request), + capturedResponse.capture(), + ArgumentMatchers.eq(None)) + assertNotNull(capturedResponse.getValue) + val response = capturedResponse.getValue.asInstanceOf[IncrementalAlterConfigsResponse] + assertEquals(Set( + new AlterConfigsResourceResponse(). + setErrorCode(INVALID_REQUEST.code()). + setErrorMessage("Unexpected resource type BROKER_LOGGER."). + setResourceName("1"). + setResourceType(ConfigResource.Type.BROKER_LOGGER.id()), + new AlterConfigsResourceResponse(). + setErrorCode(INVALID_REQUEST.code()). + setErrorMessage("Duplicate resource."). + setResourceName("3"). + setResourceType(ConfigResource.Type.BROKER.id()), + new AlterConfigsResourceResponse(). + setErrorCode(UNSUPPORTED_VERSION.code()). + setErrorMessage("Unknown resource type 124."). + setResourceName("foo"). + setResourceType(124.toByte)), + response.data().responses().asScala.toSet) + } + + @Test + def testUnauthorizedHandleAlterPartitionReassignments(): Unit = { + assertThrows(classOf[ClusterAuthorizationException], () => createControllerApis( + Some(createDenyAllAuthorizer()), new MockController.Builder().build()). + handleAlterPartitionReassignments(buildRequest(new AlterPartitionReassignmentsRequest.Builder( + new AlterPartitionReassignmentsRequestData()).build()))) + } + + @Test + def testUnauthorizedHandleAllocateProducerIds(): Unit = { + assertThrows(classOf[ClusterAuthorizationException], () => createControllerApis( + Some(createDenyAllAuthorizer()), new MockController.Builder().build()). 
+ handleAllocateProducerIdsRequest(buildRequest(new AllocateProducerIdsRequest.Builder( + new AllocateProducerIdsRequestData()).build()))) + } + + @Test + def testUnauthorizedHandleListPartitionReassignments(): Unit = { + assertThrows(classOf[ClusterAuthorizationException], () => createControllerApis( + Some(createDenyAllAuthorizer()), new MockController.Builder().build()). + handleListPartitionReassignments(buildRequest(new ListPartitionReassignmentsRequest.Builder( + new ListPartitionReassignmentsRequestData()).build()))) + } + + @Test + def testCreateTopics(): Unit = { + val controller = new MockController.Builder().build() + val controllerApis = createControllerApis(None, controller) + val request = new CreateTopicsRequestData().setTopics(new CreatableTopicCollection( + util.Arrays.asList(new CreatableTopic().setName("foo").setNumPartitions(1).setReplicationFactor(3), + new CreatableTopic().setName("foo").setNumPartitions(2).setReplicationFactor(3), + new CreatableTopic().setName("bar").setNumPartitions(2).setReplicationFactor(3), + new CreatableTopic().setName("bar").setNumPartitions(2).setReplicationFactor(3), + new CreatableTopic().setName("bar").setNumPartitions(2).setReplicationFactor(3), + new CreatableTopic().setName("baz").setNumPartitions(2).setReplicationFactor(3), + new CreatableTopic().setName("quux").setNumPartitions(2).setReplicationFactor(3), + ).iterator())) + val expectedResponse = Set(new CreatableTopicResult().setName("foo"). + setErrorCode(INVALID_REQUEST.code()). + setErrorMessage("Duplicate topic name."), + new CreatableTopicResult().setName("bar"). + setErrorCode(INVALID_REQUEST.code()). + setErrorMessage("Duplicate topic name."), + new CreatableTopicResult().setName("baz"). + setErrorCode(NONE.code()). + setTopicId(new Uuid(0L, 1L)), + new CreatableTopicResult().setName("quux"). + setErrorCode(TOPIC_AUTHORIZATION_FAILED.code())) + assertEquals(expectedResponse, controllerApis.createTopics(request, + false, + _ => Set("baz")).get().topics().asScala.toSet) + } + + @Test + def testDeleteTopicsByName(): Unit = { + val fooId = Uuid.fromString("vZKYST0pSA2HO5x_6hoO2Q") + val controller = new MockController.Builder().newInitialTopic("foo", fooId).build() + val controllerApis = createControllerApis(None, controller) + val request = new DeleteTopicsRequestData().setTopicNames( + util.Arrays.asList("foo", "bar", "quux", "quux")) + val expectedResponse = Set(new DeletableTopicResult().setName("quux"). + setErrorCode(INVALID_REQUEST.code()). + setErrorMessage("Duplicate topic name."), + new DeletableTopicResult().setName("bar"). + setErrorCode(UNKNOWN_TOPIC_OR_PARTITION.code()). 
+ setErrorMessage("This server does not host this topic-partition."), + new DeletableTopicResult().setName("foo").setTopicId(fooId)) + assertEquals(expectedResponse, controllerApis.deleteTopics(request, + ApiKeys.DELETE_TOPICS.latestVersion().toInt, + true, + _ => Set.empty, + _ => Set.empty).get().asScala.toSet) + } + + @Test + def testDeleteTopicsById(): Unit = { + val fooId = Uuid.fromString("vZKYST0pSA2HO5x_6hoO2Q") + val barId = Uuid.fromString("VlFu5c51ToiNx64wtwkhQw") + val quuxId = Uuid.fromString("ObXkLhL_S5W62FAE67U3MQ") + val controller = new MockController.Builder().newInitialTopic("foo", fooId).build() + val controllerApis = createControllerApis(None, controller) + val request = new DeleteTopicsRequestData() + request.topics().add(new DeleteTopicState().setName(null).setTopicId(fooId)) + request.topics().add(new DeleteTopicState().setName(null).setTopicId(barId)) + request.topics().add(new DeleteTopicState().setName(null).setTopicId(quuxId)) + request.topics().add(new DeleteTopicState().setName(null).setTopicId(quuxId)) + val response = Set(new DeletableTopicResult().setName(null).setTopicId(quuxId). + setErrorCode(INVALID_REQUEST.code()). + setErrorMessage("Duplicate topic id."), + new DeletableTopicResult().setName(null).setTopicId(barId). + setErrorCode(UNKNOWN_TOPIC_ID.code()). + setErrorMessage("This server does not host this topic ID."), + new DeletableTopicResult().setName("foo").setTopicId(fooId)) + assertEquals(response, controllerApis.deleteTopics(request, + ApiKeys.DELETE_TOPICS.latestVersion().toInt, + true, + _ => Set.empty, + _ => Set.empty).get().asScala.toSet) + } + + @Test + def testInvalidDeleteTopicsRequest(): Unit = { + val fooId = Uuid.fromString("vZKYST0pSA2HO5x_6hoO2Q") + val barId = Uuid.fromString("VlFu5c51ToiNx64wtwkhQw") + val bazId = Uuid.fromString("YOS4oQ3UT9eSAZahN1ysSA") + val controller = new MockController.Builder(). + newInitialTopic("foo", fooId). + newInitialTopic("bar", barId).build() + val controllerApis = createControllerApis(None, controller) + val request = new DeleteTopicsRequestData() + request.topics().add(new DeleteTopicState().setName(null).setTopicId(ZERO_UUID)) + request.topics().add(new DeleteTopicState().setName("foo").setTopicId(fooId)) + request.topics().add(new DeleteTopicState().setName("bar").setTopicId(ZERO_UUID)) + request.topics().add(new DeleteTopicState().setName(null).setTopicId(barId)) + request.topics().add(new DeleteTopicState().setName("quux").setTopicId(ZERO_UUID)) + request.topics().add(new DeleteTopicState().setName("quux").setTopicId(ZERO_UUID)) + request.topics().add(new DeleteTopicState().setName("quux").setTopicId(ZERO_UUID)) + request.topics().add(new DeleteTopicState().setName(null).setTopicId(bazId)) + request.topics().add(new DeleteTopicState().setName(null).setTopicId(bazId)) + request.topics().add(new DeleteTopicState().setName(null).setTopicId(bazId)) + val response = Set(new DeletableTopicResult().setName(null).setTopicId(ZERO_UUID). + setErrorCode(INVALID_REQUEST.code()). + setErrorMessage("Neither topic name nor id were specified."), + new DeletableTopicResult().setName("foo").setTopicId(fooId). + setErrorCode(INVALID_REQUEST.code()). + setErrorMessage("You may not specify both topic name and topic id."), + new DeletableTopicResult().setName("bar").setTopicId(barId). + setErrorCode(INVALID_REQUEST.code()). + setErrorMessage("The provided topic name maps to an ID that was already supplied."), + new DeletableTopicResult().setName("quux").setTopicId(ZERO_UUID). 
+ setErrorCode(INVALID_REQUEST.code()). + setErrorMessage("Duplicate topic name."), + new DeletableTopicResult().setName(null).setTopicId(bazId). + setErrorCode(INVALID_REQUEST.code()). + setErrorMessage("Duplicate topic id.")) + assertEquals(response, controllerApis.deleteTopics(request, + ApiKeys.DELETE_TOPICS.latestVersion().toInt, + false, + names => names.toSet, + names => names.toSet).get().asScala.toSet) + } + + @Test + def testNotAuthorizedToDeleteWithTopicExisting(): Unit = { + val fooId = Uuid.fromString("vZKYST0pSA2HO5x_6hoO2Q") + val barId = Uuid.fromString("VlFu5c51ToiNx64wtwkhQw") + val bazId = Uuid.fromString("hr4TVh3YQiu3p16Awkka6w") + val quuxId = Uuid.fromString("5URoQzW_RJiERVZXJgUVLg") + val controller = new MockController.Builder(). + newInitialTopic("foo", fooId). + newInitialTopic("bar", barId). + newInitialTopic("baz", bazId). + newInitialTopic("quux", quuxId).build() + val controllerApis = createControllerApis(None, controller) + val request = new DeleteTopicsRequestData() + request.topics().add(new DeleteTopicState().setName(null).setTopicId(fooId)) + request.topics().add(new DeleteTopicState().setName(null).setTopicId(barId)) + request.topics().add(new DeleteTopicState().setName("baz").setTopicId(ZERO_UUID)) + request.topics().add(new DeleteTopicState().setName("quux").setTopicId(ZERO_UUID)) + val response = Set(new DeletableTopicResult().setName(null).setTopicId(barId). + setErrorCode(TOPIC_AUTHORIZATION_FAILED.code). + setErrorMessage(TOPIC_AUTHORIZATION_FAILED.message), + new DeletableTopicResult().setName("quux").setTopicId(ZERO_UUID). + setErrorCode(TOPIC_AUTHORIZATION_FAILED.code). + setErrorMessage(TOPIC_AUTHORIZATION_FAILED.message), + new DeletableTopicResult().setName("baz").setTopicId(ZERO_UUID). + setErrorCode(TOPIC_AUTHORIZATION_FAILED.code). + setErrorMessage(TOPIC_AUTHORIZATION_FAILED.message), + new DeletableTopicResult().setName("foo").setTopicId(fooId). + setErrorCode(TOPIC_AUTHORIZATION_FAILED.code). + setErrorMessage(TOPIC_AUTHORIZATION_FAILED.message)) + assertEquals(response, controllerApis.deleteTopics(request, + ApiKeys.DELETE_TOPICS.latestVersion().toInt, + false, + _ => Set("foo", "baz"), + _ => Set.empty).get().asScala.toSet) + } + + @Test + def testNotAuthorizedToDeleteWithTopicNotExisting(): Unit = { + val barId = Uuid.fromString("VlFu5c51ToiNx64wtwkhQw") + val controller = new MockController.Builder().build() + val controllerApis = createControllerApis(None, controller) + val request = new DeleteTopicsRequestData() + request.topics().add(new DeleteTopicState().setName("foo").setTopicId(ZERO_UUID)) + request.topics().add(new DeleteTopicState().setName("bar").setTopicId(ZERO_UUID)) + request.topics().add(new DeleteTopicState().setName(null).setTopicId(barId)) + val expectedResponse = Set(new DeletableTopicResult().setName("foo"). + setErrorCode(UNKNOWN_TOPIC_OR_PARTITION.code). + setErrorMessage(UNKNOWN_TOPIC_OR_PARTITION.message), + new DeletableTopicResult().setName("bar"). + setErrorCode(TOPIC_AUTHORIZATION_FAILED.code). + setErrorMessage(TOPIC_AUTHORIZATION_FAILED.message), + new DeletableTopicResult().setName(null).setTopicId(barId). + setErrorCode(UNKNOWN_TOPIC_ID.code). 
+ setErrorMessage(UNKNOWN_TOPIC_ID.message)) + assertEquals(expectedResponse, controllerApis.deleteTopics(request, + ApiKeys.DELETE_TOPICS.latestVersion().toInt, + false, + _ => Set("foo"), + _ => Set.empty).get().asScala.toSet) + } + + @Test + def testNotControllerErrorPreventsDeletingTopics(): Unit = { + val fooId = Uuid.fromString("vZKYST0pSA2HO5x_6hoO2Q") + val barId = Uuid.fromString("VlFu5c51ToiNx64wtwkhQw") + val controller = new MockController.Builder(). + newInitialTopic("foo", fooId).build() + controller.setActive(false) + val controllerApis = createControllerApis(None, controller) + val request = new DeleteTopicsRequestData() + request.topics().add(new DeleteTopicState().setName(null).setTopicId(fooId)) + request.topics().add(new DeleteTopicState().setName(null).setTopicId(barId)) + assertEquals(classOf[NotControllerException], assertThrows( + classOf[ExecutionException], () => controllerApis.deleteTopics(request, + ApiKeys.DELETE_TOPICS.latestVersion().toInt, + false, + _ => Set("foo", "bar"), + _ => Set("foo", "bar")).get()).getCause.getClass) + } + + @Test + def testDeleteTopicsDisabled(): Unit = { + val fooId = Uuid.fromString("vZKYST0pSA2HO5x_6hoO2Q") + val controller = new MockController.Builder(). + newInitialTopic("foo", fooId).build() + val props = new Properties() + props.put(KafkaConfig.DeleteTopicEnableProp, "false") + val controllerApis = createControllerApis(None, controller, props) + val request = new DeleteTopicsRequestData() + request.topics().add(new DeleteTopicState().setName("foo").setTopicId(ZERO_UUID)) + assertThrows(classOf[TopicDeletionDisabledException], () => controllerApis.deleteTopics(request, + ApiKeys.DELETE_TOPICS.latestVersion().toInt, + false, + _ => Set("foo", "bar"), + _ => Set("foo", "bar"))) + assertThrows(classOf[InvalidRequestException], () => controllerApis.deleteTopics(request, + 1, + false, + _ => Set("foo", "bar"), + _ => Set("foo", "bar"))) + } + + @Test + def testCreatePartitionsRequest(): Unit = { + val controller = new MockController.Builder(). + newInitialTopic("foo", Uuid.fromString("vZKYST0pSA2HO5x_6hoO2Q")). + newInitialTopic("bar", Uuid.fromString("VlFu5c51ToiNx64wtwkhQw")).build() + val controllerApis = createControllerApis(None, controller) + val request = new CreatePartitionsRequestData() + request.topics().add(new CreatePartitionsTopic().setName("foo").setAssignments(null).setCount(5)) + request.topics().add(new CreatePartitionsTopic().setName("bar").setAssignments(null).setCount(5)) + request.topics().add(new CreatePartitionsTopic().setName("bar").setAssignments(null).setCount(5)) + request.topics().add(new CreatePartitionsTopic().setName("bar").setAssignments(null).setCount(5)) + request.topics().add(new CreatePartitionsTopic().setName("baz").setAssignments(null).setCount(5)) + assertEquals(Set(new CreatePartitionsTopicResult().setName("foo"). + setErrorCode(NONE.code()). + setErrorMessage(null), + new CreatePartitionsTopicResult().setName("bar"). + setErrorCode(INVALID_REQUEST.code()). + setErrorMessage("Duplicate topic name."), + new CreatePartitionsTopicResult().setName("baz"). + setErrorCode(TOPIC_AUTHORIZATION_FAILED.code()). 
+ setErrorMessage(null)), + controllerApis.createPartitions(request, false, _ => Set("foo", "bar")).get().asScala.toSet) + } + + @Test + def testElectLeadersAuthorization(): Unit = { + val authorizer = mock(classOf[Authorizer]) + val controller = mock(classOf[Controller]) + val controllerApis = createControllerApis(Some(authorizer), controller) + + val request = new ElectLeadersRequest.Builder( + ElectionType.PREFERRED, + null, + 30000 + ).build() + + val resource = new ResourcePattern(ResourceType.CLUSTER, Resource.CLUSTER_NAME, PatternType.LITERAL) + val actions = singletonList(new Action(AclOperation.ALTER, resource, 1, true, true)) + + when(authorizer.authorize( + any[RequestContext], + ArgumentMatchers.eq(actions) + )).thenReturn(singletonList(AuthorizationResult.DENIED)) + + val response = handleRequest[ElectLeadersResponse](request, controllerApis) + assertEquals(Errors.CLUSTER_AUTHORIZATION_FAILED, Errors.forCode(response.data.errorCode)) + } + + @Test + def testElectLeadersHandledByController(): Unit = { + val controller = mock(classOf[Controller]) + val controllerApis = createControllerApis(None, controller) + + val request = new ElectLeadersRequest.Builder( + ElectionType.PREFERRED, + null, + 30000 + ).build() + + val responseData = new ElectLeadersResponseData() + .setErrorCode(Errors.NOT_CONTROLLER.code) + + when(controller.electLeaders( + request.data + )).thenReturn(CompletableFuture.completedFuture(responseData)) + + val response = handleRequest[ElectLeadersResponse](request, controllerApis) + assertEquals(Errors.NOT_CONTROLLER, Errors.forCode(response.data.errorCode)) + } + + private def handleRequest[T <: AbstractResponse]( + request: AbstractRequest, + controllerApis: ControllerApis + )( + implicit classTag: ClassTag[T], + @nowarn("cat=unused") nn: NotNothing[T] + ): T = { + val req = buildRequest(request) + + controllerApis.handle(req, RequestLocal.NoCaching) + + val capturedResponse: ArgumentCaptor[AbstractResponse] = + ArgumentCaptor.forClass(classOf[AbstractResponse]) + verify(requestChannel).sendResponse( + ArgumentMatchers.eq(req), + capturedResponse.capture(), + ArgumentMatchers.eq(None) + ) + + capturedResponse.getValue match { + case response: T => response + case response => + throw new ClassCastException(s"Expected response with type ${classTag.runtimeClass}, " + + s"but found ${response.getClass}") + } + } + + @AfterEach + def tearDown(): Unit = { + quotas.shutdown() + } +} diff --git a/core/src/test/scala/unit/kafka/server/ControllerConfigurationValidatorTest.scala b/core/src/test/scala/unit/kafka/server/ControllerConfigurationValidatorTest.scala new file mode 100644 index 0000000..3c85299 --- /dev/null +++ b/core/src/test/scala/unit/kafka/server/ControllerConfigurationValidatorTest.scala @@ -0,0 +1,71 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. +*/ + +package kafka.server + +import java.util.TreeMap +import java.util.Collections.emptyMap + +import org.junit.jupiter.api.Test +import org.apache.kafka.common.config.ConfigResource +import org.apache.kafka.common.config.ConfigResource.Type.{BROKER_LOGGER, TOPIC} +import org.apache.kafka.common.config.TopicConfig.{SEGMENT_BYTES_CONFIG, SEGMENT_JITTER_MS_CONFIG, SEGMENT_MS_CONFIG} +import org.apache.kafka.common.errors.{InvalidConfigurationException, InvalidRequestException} +import org.junit.jupiter.api.Assertions.{assertEquals, assertThrows} + +class ControllerConfigurationValidatorTest { + @Test + def testUnknownResourceType(): Unit = { + val validator = new ControllerConfigurationValidator() + assertEquals("Unknown resource type BROKER_LOGGER", + assertThrows(classOf[InvalidRequestException], () => validator.validate( + new ConfigResource(BROKER_LOGGER, "foo"), emptyMap())). getMessage()) + } + + @Test + def testNullTopicConfigValue(): Unit = { + val validator = new ControllerConfigurationValidator() + val config = new TreeMap[String, String]() + config.put(SEGMENT_JITTER_MS_CONFIG, "10") + config.put(SEGMENT_BYTES_CONFIG, null) + config.put(SEGMENT_MS_CONFIG, null) + assertEquals("Null value not supported for topic configs : segment.bytes,segment.ms", + assertThrows(classOf[InvalidRequestException], () => validator.validate( + new ConfigResource(TOPIC, "foo"), config)). getMessage()) + } + + @Test + def testValidTopicConfig(): Unit = { + val validator = new ControllerConfigurationValidator() + val config = new TreeMap[String, String]() + config.put(SEGMENT_JITTER_MS_CONFIG, "1000") + config.put(SEGMENT_BYTES_CONFIG, "67108864") + validator.validate(new ConfigResource(TOPIC, "foo"), config) + } + + @Test + def testInvalidTopicConfig(): Unit = { + val validator = new ControllerConfigurationValidator() + val config = new TreeMap[String, String]() + config.put(SEGMENT_JITTER_MS_CONFIG, "1000") + config.put(SEGMENT_BYTES_CONFIG, "67108864") + config.put("foobar", "abc") + assertEquals("Unknown topic config name: foobar", + assertThrows(classOf[InvalidConfigurationException], () => validator.validate( + new ConfigResource(TOPIC, "foo"), config)). getMessage()) + } +} \ No newline at end of file diff --git a/core/src/test/scala/unit/kafka/server/ControllerMutationQuotaManagerTest.scala b/core/src/test/scala/unit/kafka/server/ControllerMutationQuotaManagerTest.scala new file mode 100644 index 0000000..5479e15 --- /dev/null +++ b/core/src/test/scala/unit/kafka/server/ControllerMutationQuotaManagerTest.scala @@ -0,0 +1,235 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+package kafka.server
+
+import java.util.concurrent.TimeUnit
+
+import kafka.server.QuotaType.ControllerMutation
+import org.apache.kafka.common.errors.ThrottlingQuotaExceededException
+import org.apache.kafka.common.metrics.MetricConfig
+import org.apache.kafka.common.metrics.Metrics
+import org.apache.kafka.common.metrics.Quota
+import org.apache.kafka.common.metrics.QuotaViolationException
+import org.apache.kafka.common.metrics.stats.TokenBucket
+import org.apache.kafka.common.utils.MockTime
+import org.junit.jupiter.api.Assertions._
+import org.junit.jupiter.api.Assertions.assertEquals
+import org.junit.jupiter.api.Assertions.assertFalse
+import org.junit.jupiter.api.Test
+
+class StrictControllerMutationQuotaTest {
+  @Test
+  def testControllerMutationQuotaViolation(): Unit = {
+    val time = new MockTime(0, System.currentTimeMillis, 0)
+    val metrics = new Metrics(time)
+    val sensor = metrics.sensor("sensor", new MetricConfig()
+      .quota(Quota.upperBound(10))
+      .timeWindow(1, TimeUnit.SECONDS)
+      .samples(10))
+    val metricName = metrics.metricName("rate", "test-group")
+    assertTrue(sensor.add(metricName, new TokenBucket))
+
+    val quota = new StrictControllerMutationQuota(time, sensor)
+    assertFalse(quota.isExceeded)
+
+    // Recording a first value at T to bring the tokens to 10. Value is accepted
+    // because the quota is not exhausted yet.
+    quota.record(90)
+    assertFalse(quota.isExceeded)
+    assertEquals(0, quota.throttleTime)
+
+    // Recording a second value at T to bring the tokens to -80. Value is accepted
+    quota.record(90)
+    assertFalse(quota.isExceeded)
+    assertEquals(0, quota.throttleTime)
+
+    // Recording a third value at T is rejected immediately because there are no
+    // tokens available in the bucket.
+    assertThrows(classOf[ThrottlingQuotaExceededException], () => quota.record(90))
+    assertTrue(quota.isExceeded)
+    assertEquals(8000, quota.throttleTime)
+
+    // Throttle time is adjusted with time
+    time.sleep(5000)
+    assertEquals(3000, quota.throttleTime)
+
+    metrics.close()
+  }
+}
+
+class PermissiveControllerMutationQuotaTest {
+  @Test
+  def testControllerMutationQuotaViolation(): Unit = {
+    val time = new MockTime(0, System.currentTimeMillis, 0)
+    val metrics = new Metrics(time)
+    val sensor = metrics.sensor("sensor", new MetricConfig()
+      .quota(Quota.upperBound(10))
+      .timeWindow(1, TimeUnit.SECONDS)
+      .samples(10))
+    val metricName = metrics.metricName("rate", "test-group")
+    assertTrue(sensor.add(metricName, new TokenBucket))
+
+    val quota = new PermissiveControllerMutationQuota(time, sensor)
+    assertFalse(quota.isExceeded)
+
+    // Recording a first value at T to bring the tokens to 10. Value is accepted
+    // because the quota is not exhausted yet.
+    quota.record(90)
+    assertFalse(quota.isExceeded)
+    assertEquals(0, quota.throttleTime)
+
+    // Recording a second value at T to bring the tokens to -80. Value is accepted
+    quota.record(90)
+    assertFalse(quota.isExceeded)
+    assertEquals(8000, quota.throttleTime)
+
+    // Recording a third value at T to bring the tokens to -170. Value is accepted
+    // even though the quota is exhausted.
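+    // (These numbers follow from the token bucket settings above: a quota of 10/sec with
+    // 10 samples over a 1 second window gives a burst capacity of 10 * 10 * 1 = 100 tokens,
+    // and the throttle time is the token deficit divided by the rate, e.g. -(-170) / 10 = 17s.)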
+ quota.record(90) + assertFalse(quota.isExceeded) // quota is never exceeded + assertEquals(17000, quota.throttleTime) + + // Throttle time is adjusted with time + time.sleep(5000) + assertEquals(12000, quota.throttleTime) + + metrics.close() + } +} + +class ControllerMutationQuotaManagerTest extends BaseClientQuotaManagerTest { + private val User = "ANONYMOUS" + private val ClientId = "test-client" + + private val config = ClientQuotaManagerConfig( + numQuotaSamples = 10, + quotaWindowSizeSeconds = 1 + ) + + private def withQuotaManager(f: ControllerMutationQuotaManager => Unit): Unit = { + val quotaManager = new ControllerMutationQuotaManager(config, metrics, time,"", None) + try { + f(quotaManager) + } finally { + quotaManager.shutdown() + } + } + + @Test + def testThrottleTime(): Unit = { + import ControllerMutationQuotaManager._ + + val time = new MockTime(0, System.currentTimeMillis, 0) + val metrics = new Metrics(time) + val sensor = metrics.sensor("sensor") + val metricName = metrics.metricName("tokens", "test-group") + sensor.add(metricName, new TokenBucket) + val metric = metrics.metric(metricName) + + assertEquals(0, throttleTimeMs(new QuotaViolationException(metric, 0, 10), time.milliseconds())) + assertEquals(500, throttleTimeMs(new QuotaViolationException(metric, -5, 10), time.milliseconds())) + assertEquals(1000, throttleTimeMs(new QuotaViolationException(metric, -10, 10), time.milliseconds())) + } + + @Test + def testControllerMutationQuotaViolation(): Unit = { + withQuotaManager { quotaManager => + quotaManager.updateQuota(Some(User), Some(ClientId), Some(ClientId), + Some(Quota.upperBound(10))) + val queueSizeMetric = metrics.metrics().get( + metrics.metricName("queue-size", ControllerMutation.toString, "")) + + // Verify that there is no quota violation if we remain under the quota. + for (_ <- 0 until 10) { + assertEquals(0, maybeRecord(quotaManager, User, ClientId, 10)) + time.sleep(1000) + } + assertEquals(0, queueSizeMetric.metricValue.asInstanceOf[Double].toInt) + + // Create a spike worth of 110 mutations. + // Current tokens in the bucket = 100 + // As we use the Strict enforcement, the quota is checked before updating the rate. Hence, + // the spike is accepted and no quota violation error is raised. + var throttleTime = maybeRecord(quotaManager, User, ClientId, 110) + assertEquals(0, throttleTime, "Should not be throttled") + + // Create a spike worth of 110 mutations. + // Current tokens in the bucket = 100 - 110 = -10 + // As the quota is already violated, the spike is rejected immediately without updating the + // rate. The client must wait: + // 10 / 10 = 1s + throttleTime = maybeRecord(quotaManager, User, ClientId, 110) + assertEquals(1000, throttleTime, "Should be throttled") + + // Throttle + throttle(quotaManager, User, ClientId, throttleTime, callback) + assertEquals(1, queueSizeMetric.metricValue.asInstanceOf[Double].toInt) + + // After a request is delayed, the callback cannot be triggered immediately + quotaManager.throttledChannelReaper.doWork() + assertEquals(0, numCallbacks) + + // Callback can only be triggered after the delay time passes + time.sleep(throttleTime) + quotaManager.throttledChannelReaper.doWork() + assertEquals(0, queueSizeMetric.metricValue.asInstanceOf[Double].toInt) + assertEquals(1, numCallbacks) + + // Retry to spike worth of 110 mutations after having waited the required throttle time. 
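+      // Waiting out the full throttle time refills the bucket by rate * elapsed = 10 tokens, so the
+      // balance is back to zero and the strict check lets the next spike through without throttling.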
+      // Current tokens in the bucket = 0
+      throttleTime = maybeRecord(quotaManager, User, ClientId, 110)
+      assertEquals(0, throttleTime, "Should not be throttled")
+    }
+  }
+
+  @Test
+  def testNewStrictQuotaForReturnsUnboundedQuotaWhenQuotaIsDisabled(): Unit = {
+    withQuotaManager { quotaManager =>
+      assertEquals(UnboundedControllerMutationQuota,
+        quotaManager.newStrictQuotaFor(buildSession(User), ClientId))
+    }
+  }
+
+  @Test
+  def testNewStrictQuotaForReturnsStrictQuotaWhenQuotaIsEnabled(): Unit = {
+    withQuotaManager { quotaManager =>
+      quotaManager.updateQuota(Some(User), Some(ClientId), Some(ClientId),
+        Some(Quota.upperBound(10)))
+      val quota = quotaManager.newStrictQuotaFor(buildSession(User), ClientId)
+      assertTrue(quota.isInstanceOf[StrictControllerMutationQuota])
+    }
+  }
+
+  @Test
+  def testNewPermissiveQuotaForReturnsUnboundedQuotaWhenQuotaIsDisabled(): Unit = {
+    withQuotaManager { quotaManager =>
+      assertEquals(UnboundedControllerMutationQuota,
+        quotaManager.newPermissiveQuotaFor(buildSession(User), ClientId))
+    }
+  }
+
+  @Test
+  def testNewPermissiveQuotaForReturnsPermissiveQuotaWhenQuotaIsEnabled(): Unit = {
+    withQuotaManager { quotaManager =>
+      quotaManager.updateQuota(Some(User), Some(ClientId), Some(ClientId),
+        Some(Quota.upperBound(10)))
+      val quota = quotaManager.newPermissiveQuotaFor(buildSession(User), ClientId)
+      assertTrue(quota.isInstanceOf[PermissiveControllerMutationQuota])
+    }
+  }
+}
diff --git a/core/src/test/scala/unit/kafka/server/ControllerMutationQuotaTest.scala b/core/src/test/scala/unit/kafka/server/ControllerMutationQuotaTest.scala
new file mode 100644
index 0000000..c05044e
--- /dev/null
+++ b/core/src/test/scala/unit/kafka/server/ControllerMutationQuotaTest.scala
@@ -0,0 +1,419 @@
+/**
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ **/ +package kafka.server + +import java.util.Properties +import java.util.concurrent.ExecutionException +import java.util.concurrent.TimeUnit +import kafka.server.ClientQuotaManager.DefaultTags +import kafka.utils.TestUtils +import org.apache.kafka.common.config.internals.QuotaConfigs +import org.apache.kafka.common.internals.KafkaFutureImpl +import org.apache.kafka.common.message.CreatePartitionsRequestData +import org.apache.kafka.common.message.CreatePartitionsRequestData.CreatePartitionsTopic +import org.apache.kafka.common.message.CreateTopicsRequestData +import org.apache.kafka.common.message.CreateTopicsRequestData.CreatableTopic +import org.apache.kafka.common.message.DeleteTopicsRequestData +import org.apache.kafka.common.metrics.KafkaMetric +import org.apache.kafka.common.protocol.ApiKeys +import org.apache.kafka.common.protocol.Errors +import org.apache.kafka.common.quota.ClientQuotaAlteration +import org.apache.kafka.common.quota.ClientQuotaEntity +import org.apache.kafka.common.requests.AlterClientQuotasRequest +import org.apache.kafka.common.requests.AlterClientQuotasResponse +import org.apache.kafka.common.requests.CreatePartitionsRequest +import org.apache.kafka.common.requests.CreatePartitionsResponse +import org.apache.kafka.common.requests.CreateTopicsRequest +import org.apache.kafka.common.requests.CreateTopicsResponse +import org.apache.kafka.common.requests.DeleteTopicsRequest +import org.apache.kafka.common.requests.DeleteTopicsResponse +import org.apache.kafka.common.security.auth.AuthenticationContext +import org.apache.kafka.common.security.auth.KafkaPrincipal +import org.apache.kafka.common.security.authenticator.DefaultKafkaPrincipalBuilder +import org.apache.kafka.test.{TestUtils => JTestUtils} +import org.junit.jupiter.api.Assertions.assertEquals +import org.junit.jupiter.api.Assertions.assertTrue +import org.junit.jupiter.api.Assertions.fail +import org.junit.jupiter.api.{BeforeEach, Test, TestInfo} + +import scala.jdk.CollectionConverters._ + +object ControllerMutationQuotaTest { + // Principal used for all client connections. This is updated by each test. 
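+  // The brokers are configured with TestPrincipalBuilder (see brokerPropertyOverrides below), so new
+  // client connections are authenticated as whatever principal is currently stored here; asPrincipal
+  // swaps it temporarily so a test can act as either the throttled or the unbounded user.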
+ var principal = KafkaPrincipal.ANONYMOUS + class TestPrincipalBuilder extends DefaultKafkaPrincipalBuilder(null, null) { + override def build(context: AuthenticationContext): KafkaPrincipal = { + principal + } + } + + def asPrincipal(newPrincipal: KafkaPrincipal)(f: => Unit): Unit = { + val currentPrincipal = principal + principal = newPrincipal + try f + finally principal = currentPrincipal + } + + val ThrottledPrincipal = new KafkaPrincipal(KafkaPrincipal.USER_TYPE, "ThrottledPrincipal") + val UnboundedPrincipal = new KafkaPrincipal(KafkaPrincipal.USER_TYPE, "UnboundedPrincipal") + + val StrictCreateTopicsRequestVersion = ApiKeys.CREATE_TOPICS.latestVersion + val PermissiveCreateTopicsRequestVersion = 5.toShort + + val StrictDeleteTopicsRequestVersion = ApiKeys.DELETE_TOPICS.latestVersion + val PermissiveDeleteTopicsRequestVersion = 4.toShort + + val StrictCreatePartitionsRequestVersion = ApiKeys.CREATE_PARTITIONS.latestVersion + val PermissiveCreatePartitionsRequestVersion = 2.toShort + + val Topic1 = "topic-1" + val Topic2 = "topic-2" + val TopicsWithOnePartition = Map(Topic1 -> 1, Topic2 -> 1) + val TopicsWith30Partitions = Map(Topic1 -> 30, Topic2 -> 30) + val TopicsWith31Partitions = Map(Topic1 -> 31, Topic2 -> 31) + + val ControllerQuotaSamples = 10 + val ControllerQuotaWindowSizeSeconds = 1 + val ControllerMutationRate = 2.0 +} + +class ControllerMutationQuotaTest extends BaseRequestTest { + import ControllerMutationQuotaTest._ + + override def brokerCount: Int = 1 + + override def brokerPropertyOverrides(properties: Properties): Unit = { + properties.put(KafkaConfig.ControlledShutdownEnableProp, "false") + properties.put(KafkaConfig.OffsetsTopicReplicationFactorProp, "1") + properties.put(KafkaConfig.OffsetsTopicPartitionsProp, "1") + properties.put(KafkaConfig.PrincipalBuilderClassProp, + classOf[ControllerMutationQuotaTest.TestPrincipalBuilder].getName) + // Specify number of samples and window size. 
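+    // With ControllerMutationRate = 2/sec, 10 samples and a 1 second window, the token bucket holds
+    // at most 2 * 10 * 1 = 20 tokens, which the 30-partition test topics are sized to exceed.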
+ properties.put(KafkaConfig.NumControllerQuotaSamplesProp, ControllerQuotaSamples.toString) + properties.put(KafkaConfig.ControllerQuotaWindowSizeSecondsProp, ControllerQuotaWindowSizeSeconds.toString) + } + + @BeforeEach + override def setUp(testInfo: TestInfo): Unit = { + super.setUp(testInfo) + + // Define a quota for ThrottledPrincipal + defineUserQuota(ThrottledPrincipal.getName, Some(ControllerMutationRate)) + waitUserQuota(ThrottledPrincipal.getName, ControllerMutationRate) + } + + @Test + def testSetUnsetQuota(): Unit = { + val rate = 1.5 + val principal = new KafkaPrincipal(KafkaPrincipal.USER_TYPE, "User") + // Default Value + waitUserQuota(principal.getName, Long.MaxValue) + // Define a new quota + defineUserQuota(principal.getName, Some(rate)) + // Check it + waitUserQuota(principal.getName, rate) + // Remove it + defineUserQuota(principal.getName, None) + // Back to the default + waitUserQuota(principal.getName, Long.MaxValue) + } + + @Test + def testQuotaMetric(): Unit = { + asPrincipal(ThrottledPrincipal) { + // Metric is lazily created + assertTrue(quotaMetric(principal.getName).isEmpty) + + // Create a topic to create the metrics + val (_, errors) = createTopics(Map("topic" -> 1), StrictDeleteTopicsRequestVersion) + assertEquals(Set(Errors.NONE), errors.values.toSet) + + // Metric must be there with the correct config + waitQuotaMetric(principal.getName, ControllerMutationRate) + + // Update quota + defineUserQuota(ThrottledPrincipal.getName, Some(ControllerMutationRate * 2)) + waitUserQuota(ThrottledPrincipal.getName, ControllerMutationRate * 2) + + // Metric must be there with the updated config + waitQuotaMetric(principal.getName, ControllerMutationRate * 2) + } + } + + @Test + def testStrictCreateTopicsRequest(): Unit = { + asPrincipal(ThrottledPrincipal) { + // Create two topics worth of 30 partitions each. As we use a strict quota, we + // expect one to be created and one to be rejected. + // Theoretically, the throttle time should be below or equal to: + // -(-10) / 2 = 5s + val (throttleTimeMs1, errors1) = createTopics(TopicsWith30Partitions, StrictCreateTopicsRequestVersion) + assertThrottleTime(5000, throttleTimeMs1) + // Ordering is not guaranteed so we only check the errors + assertEquals(Set(Errors.NONE, Errors.THROTTLING_QUOTA_EXCEEDED), errors1.values.toSet) + + // Retry the rejected topic. It should succeed after the throttling delay is passed and the + // throttle time should be zero. + val rejectedTopicName = errors1.filter(_._2 == Errors.THROTTLING_QUOTA_EXCEEDED).keys.head + val rejectedTopicSpec = TopicsWith30Partitions.filter(_._1 == rejectedTopicName) + TestUtils.waitUntilTrue(() => { + val (throttleTimeMs2, errors2) = createTopics(rejectedTopicSpec, StrictCreateTopicsRequestVersion) + throttleTimeMs2 == 0 && errors2 == Map(rejectedTopicName -> Errors.NONE) + }, "Failed to create topics after having been throttled") + } + } + + @Test + def testPermissiveCreateTopicsRequest(): Unit = { + asPrincipal(ThrottledPrincipal) { + // Create two topics worth of 30 partitions each. As we use a permissive quota, we + // expect both topics to be created. 
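+      // Creating both topics consumes 60 tokens from a bucket holding at most 20, leaving the
+      // balance at 20 - 60 = -40 tokens.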
+ // Theoretically, the throttle time should be below or equal to: + // -(-40) / 2 = 20s + val (throttleTimeMs, errors) = createTopics(TopicsWith30Partitions, PermissiveCreateTopicsRequestVersion) + assertThrottleTime(20000, throttleTimeMs) + assertEquals(Map(Topic1 -> Errors.NONE, Topic2 -> Errors.NONE), errors) + } + } + + @Test + def testUnboundedCreateTopicsRequest(): Unit = { + asPrincipal(UnboundedPrincipal) { + // Create two topics worth of 30 partitions each. As we use an user without quota, we + // expect both topics to be created. The throttle time should be equal to 0. + val (throttleTimeMs, errors) = createTopics(TopicsWith30Partitions, StrictCreateTopicsRequestVersion) + assertEquals(0, throttleTimeMs) + assertEquals(Map(Topic1 -> Errors.NONE, Topic2 -> Errors.NONE), errors) + } + } + + @Test + def testStrictDeleteTopicsRequest(): Unit = { + asPrincipal(UnboundedPrincipal) { + createTopics(TopicsWith30Partitions, StrictCreateTopicsRequestVersion) + } + + asPrincipal(ThrottledPrincipal) { + // Delete two topics worth of 30 partitions each. As we use a strict quota, we + // expect the first topic to be deleted and the second to be rejected. + // Theoretically, the throttle time should be below or equal to: + // -(-10) / 2 = 5s + val (throttleTimeMs1, errors1) = deleteTopics(TopicsWith30Partitions, StrictDeleteTopicsRequestVersion) + assertThrottleTime(5000, throttleTimeMs1) + // Ordering is not guaranteed so we only check the errors + assertEquals(Set(Errors.NONE, Errors.THROTTLING_QUOTA_EXCEEDED), errors1.values.toSet) + + // Retry the rejected topic. It should succeed after the throttling delay is passed and the + // throttle time should be zero. + val rejectedTopicName = errors1.filter(_._2 == Errors.THROTTLING_QUOTA_EXCEEDED).keys.head + val rejectedTopicSpec = TopicsWith30Partitions.filter(_._1 == rejectedTopicName) + TestUtils.waitUntilTrue(() => { + val (throttleTimeMs2, errors2) = deleteTopics(rejectedTopicSpec, StrictDeleteTopicsRequestVersion) + throttleTimeMs2 == 0 && errors2 == Map(rejectedTopicName -> Errors.NONE) + }, "Failed to delete topics after having been throttled") + } + } + + @Test + def testPermissiveDeleteTopicsRequest(): Unit = { + asPrincipal(UnboundedPrincipal) { + createTopics(TopicsWith30Partitions, StrictCreateTopicsRequestVersion) + } + + asPrincipal(ThrottledPrincipal) { + // Delete two topics worth of 30 partitions each. As we use a permissive quota, we + // expect both topics to be deleted. + // Theoretically, the throttle time should be below or equal to: + // -(-40) / 2 = 20s + val (throttleTimeMs, errors) = deleteTopics(TopicsWith30Partitions, PermissiveDeleteTopicsRequestVersion) + assertThrottleTime(20000, throttleTimeMs) + assertEquals(Map(Topic1 -> Errors.NONE, Topic2 -> Errors.NONE), errors) + } + } + + @Test + def testUnboundedDeleteTopicsRequest(): Unit = { + asPrincipal(UnboundedPrincipal) { + createTopics(TopicsWith30Partitions, StrictCreateTopicsRequestVersion) + + // Delete two topics worth of 30 partitions each. As we use an user without quota, we + // expect both topics to be deleted. The throttle time should be equal to 0. 
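+      // No mutation quota is configured for UnboundedPrincipal, so the quota manager hands out an
+      // unbounded quota and these operations are never throttled.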
+ val (throttleTimeMs, errors) = deleteTopics(TopicsWith30Partitions, StrictDeleteTopicsRequestVersion) + assertEquals(0, throttleTimeMs) + assertEquals(Map(Topic1 -> Errors.NONE, Topic2 -> Errors.NONE), errors) + } + } + + @Test + def testStrictCreatePartitionsRequest(): Unit = { + asPrincipal(UnboundedPrincipal) { + createTopics(TopicsWithOnePartition, StrictCreatePartitionsRequestVersion) + } + + asPrincipal(ThrottledPrincipal) { + // Add 30 partitions to each topic. As we use a strict quota, we + // expect the first topic to be extended and the second to be rejected. + // Theoretically, the throttle time should be below or equal to: + // -(-10) / 2 = 5s + val (throttleTimeMs1, errors1) = createPartitions(TopicsWith31Partitions, StrictCreatePartitionsRequestVersion) + assertThrottleTime(5000, throttleTimeMs1) + // Ordering is not guaranteed so we only check the errors + assertEquals(Set(Errors.NONE, Errors.THROTTLING_QUOTA_EXCEEDED), errors1.values.toSet) + + // Retry the rejected topic. It should succeed after the throttling delay is passed and the + // throttle time should be zero. + val rejectedTopicName = errors1.filter(_._2 == Errors.THROTTLING_QUOTA_EXCEEDED).keys.head + val rejectedTopicSpec = TopicsWith30Partitions.filter(_._1 == rejectedTopicName) + TestUtils.waitUntilTrue(() => { + val (throttleTimeMs2, errors2) = createPartitions(rejectedTopicSpec, StrictCreatePartitionsRequestVersion) + throttleTimeMs2 == 0 && errors2 == Map(rejectedTopicName -> Errors.NONE) + }, "Failed to create partitions after having been throttled") + } + } + + @Test + def testPermissiveCreatePartitionsRequest(): Unit = { + asPrincipal(UnboundedPrincipal) { + createTopics(TopicsWithOnePartition, StrictCreatePartitionsRequestVersion) + } + + asPrincipal(ThrottledPrincipal) { + // Create two topics worth of 30 partitions each. As we use a permissive quota, we + // expect both topics to be created. + // Theoretically, the throttle time should be below or equal to: + // -(-40) / 2 = 20s + val (throttleTimeMs, errors) = createPartitions(TopicsWith31Partitions, PermissiveCreatePartitionsRequestVersion) + assertThrottleTime(20000, throttleTimeMs) + assertEquals(Map(Topic1 -> Errors.NONE, Topic2 -> Errors.NONE), errors) + } + } + + @Test + def testUnboundedCreatePartitionsRequest(): Unit = { + asPrincipal(UnboundedPrincipal) { + createTopics(TopicsWithOnePartition, StrictCreatePartitionsRequestVersion) + + // Create two topics worth of 30 partitions each. As we use an user without quota, we + // expect both topics to be created. The throttle time should be equal to 0. 
+ val (throttleTimeMs, errors) = createPartitions(TopicsWith31Partitions, StrictCreatePartitionsRequestVersion) + assertEquals(0, throttleTimeMs) + assertEquals(Map(Topic1 -> Errors.NONE, Topic2 -> Errors.NONE), errors) + } + } + + private def assertThrottleTime(max: Int, actual: Int): Unit = { + assertTrue( + (actual >= 0) && (actual <= max), + s"Expected a throttle time between 0 and $max but got $actual") + } + + private def createTopics(topics: Map[String, Int], version: Short): (Int, Map[String, Errors]) = { + val data = new CreateTopicsRequestData() + topics.foreach { case (topic, numPartitions) => + data.topics.add(new CreatableTopic() + .setName(topic).setNumPartitions(numPartitions).setReplicationFactor(1)) + } + val request = new CreateTopicsRequest.Builder(data).build(version) + val response = connectAndReceive[CreateTopicsResponse](request) + response.data.throttleTimeMs -> response.data.topics.asScala + .map(topic => topic.name -> Errors.forCode(topic.errorCode)).toMap + } + + private def deleteTopics(topics: Map[String, Int], version: Short): (Int, Map[String, Errors]) = { + val data = new DeleteTopicsRequestData() + .setTimeoutMs(60000) + .setTopicNames(topics.keys.toSeq.asJava) + val request = new DeleteTopicsRequest.Builder(data).build(version) + val response = connectAndReceive[DeleteTopicsResponse](request) + response.data.throttleTimeMs -> response.data.responses.asScala + .map(topic => topic.name -> Errors.forCode(topic.errorCode)).toMap + } + + private def createPartitions(topics: Map[String, Int], version: Short): (Int, Map[String, Errors]) = { + val data = new CreatePartitionsRequestData().setTimeoutMs(60000) + topics.foreach { case (topic, numPartitions) => + data.topics.add(new CreatePartitionsTopic() + .setName(topic).setCount(numPartitions).setAssignments(null)) + } + val request = new CreatePartitionsRequest.Builder(data).build(version) + val response = connectAndReceive[CreatePartitionsResponse](request) + response.data.throttleTimeMs -> response.data.results.asScala + .map(topic => topic.name -> Errors.forCode(topic.errorCode)).toMap + } + + private def defineUserQuota(user: String, quota: Option[Double]): Unit = { + val entity = new ClientQuotaEntity(Map(ClientQuotaEntity.USER -> user).asJava) + val quotas = Map(QuotaConfigs.CONTROLLER_MUTATION_RATE_OVERRIDE_CONFIG -> quota) + + try alterClientQuotas(Map(entity -> quotas))(entity).get(10, TimeUnit.SECONDS) catch { + case e: ExecutionException => throw e.getCause + } + } + + private def waitUserQuota(user: String, expectedQuota: Double): Unit = { + val quotaManager = servers.head.quotaManagers.controllerMutation + var actualQuota = Double.MinValue + + TestUtils.waitUntilTrue(() => { + actualQuota = quotaManager.quota(user, "").bound() + expectedQuota == actualQuota + }, s"Quota of $user is not $expectedQuota but $actualQuota") + } + + private def quotaMetric(user: String): Option[KafkaMetric] = { + val metrics = servers.head.metrics + val metricName = metrics.metricName( + "tokens", + QuotaType.ControllerMutation.toString, + "Tracking remaining tokens in the token bucket per user/client-id", + Map(DefaultTags.User -> user, DefaultTags.ClientId -> "").asJava) + Option(servers.head.metrics.metric(metricName)) + } + + private def waitQuotaMetric(user: String, expectedQuota: Double): Unit = { + TestUtils.retry(JTestUtils.DEFAULT_MAX_WAIT_MS) { + quotaMetric(user) match { + case Some(metric) => + val config = metric.config() + assertEquals(expectedQuota, config.quota().bound(), 0.1) + 
assertEquals(ControllerQuotaSamples, config.samples()) + assertEquals(ControllerQuotaWindowSizeSeconds * 1000, config.timeWindowMs()) + + case None => + fail(s"Quota metric of $user is not defined") + } + } + } + + private def alterClientQuotas(request: Map[ClientQuotaEntity, Map[String, Option[Double]]]): Map[ClientQuotaEntity, KafkaFutureImpl[Void]] = { + val entries = request.map { case (entity, alter) => + val ops = alter.map { case (key, value) => + new ClientQuotaAlteration.Op(key, value.map(Double.box).orNull) + }.asJavaCollection + new ClientQuotaAlteration(entity, ops) + } + + val response = request.map(e => e._1 -> new KafkaFutureImpl[Void]).asJava + sendAlterClientQuotasRequest(entries).complete(response) + val result = response.asScala + assertEquals(request.size, result.size) + request.foreach(e => assertTrue(result.get(e._1).isDefined)) + result.toMap + } + + private def sendAlterClientQuotasRequest(entries: Iterable[ClientQuotaAlteration]): AlterClientQuotasResponse = { + val request = new AlterClientQuotasRequest.Builder(entries.asJavaCollection, false).build() + connectAndReceive[AlterClientQuotasResponse](request, destination = controllerSocketServer) + } +} diff --git a/core/src/test/scala/unit/kafka/server/CreateTopicsRequestTest.scala b/core/src/test/scala/unit/kafka/server/CreateTopicsRequestTest.scala new file mode 100644 index 0000000..0f72ee2 --- /dev/null +++ b/core/src/test/scala/unit/kafka/server/CreateTopicsRequestTest.scala @@ -0,0 +1,181 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package kafka.server + +import kafka.utils._ +import org.apache.kafka.common.Uuid +import org.apache.kafka.common.message.CreateTopicsRequestData +import org.apache.kafka.common.message.CreateTopicsRequestData.CreatableTopicCollection +import org.apache.kafka.common.protocol.ApiKeys +import org.apache.kafka.common.protocol.Errors +import org.apache.kafka.common.requests.CreateTopicsRequest +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.Test + +import scala.jdk.CollectionConverters._ + +class CreateTopicsRequestTest extends AbstractCreateTopicsRequestTest { + + @Test + def testValidCreateTopicsRequests(): Unit = { + // Generated assignments + validateValidCreateTopicsRequests(topicsReq(Seq(topicReq("topic1")))) + validateValidCreateTopicsRequests(topicsReq(Seq(topicReq("topic2", replicationFactor = 3)))) + validateValidCreateTopicsRequests(topicsReq(Seq(topicReq("topic3", + numPartitions = 5, replicationFactor = 2, config = Map("min.insync.replicas" -> "2"))))) + // Manual assignments + validateValidCreateTopicsRequests(topicsReq(Seq(topicReq("topic4", assignment = Map(0 -> List(0)))))) + validateValidCreateTopicsRequests(topicsReq(Seq(topicReq("topic5", + assignment = Map(0 -> List(0, 1), 1 -> List(1, 0), 2 -> List(1, 2)), + config = Map("min.insync.replicas" -> "2"))))) + // Mixed + validateValidCreateTopicsRequests(topicsReq(Seq(topicReq("topic6"), + topicReq("topic7", numPartitions = 5, replicationFactor = 2), + topicReq("topic8", assignment = Map(0 -> List(0, 1), 1 -> List(1, 0), 2 -> List(1, 2)))))) + validateValidCreateTopicsRequests(topicsReq(Seq(topicReq("topic9"), + topicReq("topic10", numPartitions = 5, replicationFactor = 2), + topicReq("topic11", assignment = Map(0 -> List(0, 1), 1 -> List(1, 0), 2 -> List(1, 2)))), + validateOnly = true)) + // Defaults + validateValidCreateTopicsRequests(topicsReq(Seq( + topicReq("topic12", replicationFactor = -1, numPartitions = -1)))) + validateValidCreateTopicsRequests(topicsReq(Seq( + topicReq("topic13", replicationFactor = 2, numPartitions = -1)))) + validateValidCreateTopicsRequests(topicsReq(Seq( + topicReq("topic14", replicationFactor = -1, numPartitions = 2)))) + } + + @Test + def testErrorCreateTopicsRequests(): Unit = { + val existingTopic = "existing-topic" + createTopic(existingTopic, 1, 1) + // Basic + validateErrorCreateTopicsRequests(topicsReq(Seq(topicReq(existingTopic))), + Map(existingTopic -> error(Errors.TOPIC_ALREADY_EXISTS, Some("Topic 'existing-topic' already exists.")))) + validateErrorCreateTopicsRequests(topicsReq(Seq(topicReq("error-partitions", numPartitions = -2))), + Map("error-partitions" -> error(Errors.INVALID_PARTITIONS)), checkErrorMessage = false) + validateErrorCreateTopicsRequests(topicsReq(Seq(topicReq("error-replication", + replicationFactor = brokerCount + 1))), + Map("error-replication" -> error(Errors.INVALID_REPLICATION_FACTOR)), checkErrorMessage = false) + validateErrorCreateTopicsRequests(topicsReq(Seq(topicReq("error-config", + config=Map("not.a.property" -> "error")))), + Map("error-config" -> error(Errors.INVALID_CONFIG)), checkErrorMessage = false) + validateErrorCreateTopicsRequests(topicsReq(Seq(topicReq("error-config-value", + config=Map("message.format.version" -> "invalid-value")))), + Map("error-config-value" -> error(Errors.INVALID_CONFIG)), checkErrorMessage = false) + validateErrorCreateTopicsRequests(topicsReq(Seq(topicReq("error-assignment", + assignment=Map(0 -> List(0, 1), 1 -> List(0))))), + Map("error-assignment" -> 
error(Errors.INVALID_REPLICA_ASSIGNMENT)), checkErrorMessage = false) + + // Partial + validateErrorCreateTopicsRequests(topicsReq(Seq( + topicReq(existingTopic), + topicReq("partial-partitions", numPartitions = -2), + topicReq("partial-replication", replicationFactor=brokerCount + 1), + topicReq("partial-assignment", assignment=Map(0 -> List(0, 1), 1 -> List(0))), + topicReq("partial-none"))), + Map( + existingTopic -> error(Errors.TOPIC_ALREADY_EXISTS), + "partial-partitions" -> error(Errors.INVALID_PARTITIONS), + "partial-replication" -> error(Errors.INVALID_REPLICATION_FACTOR), + "partial-assignment" -> error(Errors.INVALID_REPLICA_ASSIGNMENT), + "partial-none" -> error(Errors.NONE) + ), checkErrorMessage = false + ) + validateTopicExists("partial-none") + + // Timeout + // We don't expect a request to ever complete within 1ms. A timeout of 1 ms allows us to test the purgatory timeout logic. + validateErrorCreateTopicsRequests(topicsReq(Seq( + topicReq("error-timeout", numPartitions = 10, replicationFactor = 3)), timeout = 1), + Map("error-timeout" -> error(Errors.REQUEST_TIMED_OUT)), checkErrorMessage = false) + validateErrorCreateTopicsRequests(topicsReq(Seq( + topicReq("error-timeout-zero", numPartitions = 10, replicationFactor = 3)), timeout = 0), + Map("error-timeout-zero" -> error(Errors.REQUEST_TIMED_OUT)), checkErrorMessage = false) + // Negative timeouts are treated the same as 0 + validateErrorCreateTopicsRequests(topicsReq(Seq( + topicReq("error-timeout-negative", numPartitions = 10, replicationFactor = 3)), timeout = -1), + Map("error-timeout-negative" -> error(Errors.REQUEST_TIMED_OUT)), checkErrorMessage = false) + // The topics should still get created eventually + TestUtils.waitForPartitionMetadata(servers, "error-timeout", 0) + TestUtils.waitForPartitionMetadata(servers, "error-timeout-zero", 0) + TestUtils.waitForPartitionMetadata(servers, "error-timeout-negative", 0) + validateTopicExists("error-timeout") + validateTopicExists("error-timeout-zero") + validateTopicExists("error-timeout-negative") + } + + @Test + def testInvalidCreateTopicsRequests(): Unit = { + // Partitions/ReplicationFactor and ReplicaAssignment + validateErrorCreateTopicsRequests(topicsReq(Seq( + topicReq("bad-args-topic", numPartitions = 10, replicationFactor = 3, + assignment = Map(0 -> List(0))))), + Map("bad-args-topic" -> error(Errors.INVALID_REQUEST)), checkErrorMessage = false) + + validateErrorCreateTopicsRequests(topicsReq(Seq( + topicReq("bad-args-topic", numPartitions = 10, replicationFactor = 3, + assignment = Map(0 -> List(0)))), validateOnly = true), + Map("bad-args-topic" -> error(Errors.INVALID_REQUEST)), checkErrorMessage = false) + } + + @Test + def testNotController(): Unit = { + val req = topicsReq(Seq(topicReq("topic1"))) + val response = sendCreateTopicRequest(req, notControllerSocketServer) + assertEquals(1, response.errorCounts().get(Errors.NOT_CONTROLLER)) + } + + @Test + def testCreateTopicsRequestVersions(): Unit = { + for (version <- ApiKeys.CREATE_TOPICS.oldestVersion to ApiKeys.CREATE_TOPICS.latestVersion) { + val topic = s"topic_$version" + val data = new CreateTopicsRequestData() + data.setTimeoutMs(10000) + data.setValidateOnly(false) + data.setTopics(new CreatableTopicCollection(List( + topicReq(topic, numPartitions = 1, replicationFactor = 1, + config = Map("min.insync.replicas" -> "2")) + ).asJava.iterator())) + + val request = new CreateTopicsRequest.Builder(data).build(version.asInstanceOf[Short]) + val response = sendCreateTopicRequest(request) + + val 
topicResponse = response.data.topics.find(topic) + assertNotNull(topicResponse) + assertEquals(topic, topicResponse.name) + assertEquals(Errors.NONE.code, topicResponse.errorCode) + if (version >= 5) { + assertEquals(1, topicResponse.numPartitions) + assertEquals(1, topicResponse.replicationFactor) + val config = topicResponse.configs().asScala.find(_.name == "min.insync.replicas") + assertTrue(config.isDefined) + assertEquals("2", config.get.value) + } else { + assertEquals(-1, topicResponse.numPartitions) + assertEquals(-1, topicResponse.replicationFactor) + assertTrue(topicResponse.configs.isEmpty) + } + + if (version >= 7) + assertNotEquals(Uuid.ZERO_UUID, topicResponse.topicId()) + else + assertEquals(Uuid.ZERO_UUID, topicResponse.topicId()) + } + } +} diff --git a/core/src/test/scala/unit/kafka/server/CreateTopicsRequestWithForwardingTest.scala b/core/src/test/scala/unit/kafka/server/CreateTopicsRequestWithForwardingTest.scala new file mode 100644 index 0000000..f55aa34 --- /dev/null +++ b/core/src/test/scala/unit/kafka/server/CreateTopicsRequestWithForwardingTest.scala @@ -0,0 +1,37 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.server + +import org.apache.kafka.common.protocol.Errors +import org.junit.jupiter.api.Assertions.assertEquals +import org.junit.jupiter.api.Test + +import scala.jdk.CollectionConverters._ + +class CreateTopicsRequestWithForwardingTest extends AbstractCreateTopicsRequestTest { + + override def enableForwarding: Boolean = true + + @Test + def testForwardToController(): Unit = { + val req = topicsReq(Seq(topicReq("topic1"))) + val response = sendCreateTopicRequest(req, notControllerSocketServer) + // With forwarding enabled, request could be forwarded to the active controller. + assertEquals(Map(Errors.NONE -> 1), response.errorCounts().asScala) + } +} diff --git a/core/src/test/scala/unit/kafka/server/CreateTopicsRequestWithPolicyTest.scala b/core/src/test/scala/unit/kafka/server/CreateTopicsRequestWithPolicyTest.scala new file mode 100644 index 0000000..b5456c6 --- /dev/null +++ b/core/src/test/scala/unit/kafka/server/CreateTopicsRequestWithPolicyTest.scala @@ -0,0 +1,161 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.server + +import java.util +import java.util.Properties + +import kafka.log.LogConfig +import org.apache.kafka.common.errors.PolicyViolationException +import org.apache.kafka.common.protocol.Errors +import org.apache.kafka.server.policy.CreateTopicPolicy +import org.apache.kafka.server.policy.CreateTopicPolicy.RequestMetadata +import org.junit.jupiter.api.Test + +import scala.jdk.CollectionConverters._ + +class CreateTopicsRequestWithPolicyTest extends AbstractCreateTopicsRequestTest { + import CreateTopicsRequestWithPolicyTest._ + + override def brokerPropertyOverrides(properties: Properties): Unit = { + super.brokerPropertyOverrides(properties) + properties.put(KafkaConfig.CreateTopicPolicyClassNameProp, classOf[Policy].getName) + } + + @Test + def testValidCreateTopicsRequests(): Unit = { + validateValidCreateTopicsRequests(topicsReq(Seq(topicReq("topic1", + numPartitions = 5)))) + + validateValidCreateTopicsRequests(topicsReq(Seq(topicReq("topic2", + numPartitions = 5, replicationFactor = 3)), + validateOnly = true)) + + validateValidCreateTopicsRequests(topicsReq(Seq(topicReq("topic3", + numPartitions = 11, replicationFactor = 2, + config = Map(LogConfig.RetentionMsProp -> 4999.toString))), + validateOnly = true)) + + validateValidCreateTopicsRequests(topicsReq(Seq(topicReq("topic4", + assignment = Map(0 -> List(1, 0), 1 -> List(0, 1)))))) + } + + @Test + def testErrorCreateTopicsRequests(): Unit = { + val existingTopic = "existing-topic" + createTopic(existingTopic, 1, 1) + + // Policy violations + validateErrorCreateTopicsRequests(topicsReq(Seq(topicReq("policy-topic1", + numPartitions = 4, replicationFactor = 1))), + Map("policy-topic1" -> error(Errors.POLICY_VIOLATION, Some("Topics should have at least 5 partitions, received 4")))) + + validateErrorCreateTopicsRequests(topicsReq(Seq(topicReq("policy-topic2", + numPartitions = 4, replicationFactor = 3)), validateOnly = true), + Map("policy-topic2" -> error(Errors.POLICY_VIOLATION, Some("Topics should have at least 5 partitions, received 4")))) + + validateErrorCreateTopicsRequests(topicsReq(Seq(topicReq("policy-topic3", + numPartitions = 11, replicationFactor = 2, + config = Map(LogConfig.RetentionMsProp -> 5001.toString))), validateOnly = true), + Map("policy-topic3" -> error(Errors.POLICY_VIOLATION, + Some("RetentionMs should be less than 5000ms if replicationFactor > 5")))) + + validateErrorCreateTopicsRequests(topicsReq(Seq(topicReq("policy-topic4", + numPartitions = 11, replicationFactor = 3, + config = Map(LogConfig.RetentionMsProp -> 5001.toString))), validateOnly = true), + Map("policy-topic4" -> error(Errors.POLICY_VIOLATION, + Some("RetentionMs should be less than 5000ms if replicationFactor > 5")))) + + validateErrorCreateTopicsRequests(topicsReq(Seq(topicReq("policy-topic5", + assignment = Map(0 -> List(1), 1 -> List(0)), + config = Map(LogConfig.RetentionMsProp -> 5001.toString))), validateOnly = true), + Map("policy-topic5" -> error(Errors.POLICY_VIOLATION, + Some("Topic partitions should have at least 2 partitions, received 1 for partition 0")))) + + // Check that basic errors still 
work + validateErrorCreateTopicsRequests(topicsReq(Seq(topicReq(existingTopic, + numPartitions = 5, replicationFactor = 1))), + Map(existingTopic -> error(Errors.TOPIC_ALREADY_EXISTS, + Some("Topic 'existing-topic' already exists.")))) + + validateErrorCreateTopicsRequests(topicsReq(Seq(topicReq("error-replication", + numPartitions = 10, replicationFactor = brokerCount + 1)), validateOnly = true), + Map("error-replication" -> error(Errors.INVALID_REPLICATION_FACTOR, + Some("Replication factor: 4 larger than available brokers: 3.")))) + + validateErrorCreateTopicsRequests(topicsReq(Seq(topicReq("error-replication2", + numPartitions = 10, replicationFactor = -2)), validateOnly = true), + Map("error-replication2" -> error(Errors.INVALID_REPLICATION_FACTOR, + Some("Replication factor must be larger than 0.")))) + + validateErrorCreateTopicsRequests(topicsReq(Seq(topicReq("error-partitions", + numPartitions = -2, replicationFactor = 1)), validateOnly = true), + Map("error-partitions" -> error(Errors.INVALID_PARTITIONS, + Some("Number of partitions must be larger than 0.")))) + } + +} + +object CreateTopicsRequestWithPolicyTest { + + class Policy extends CreateTopicPolicy { + + var configs: Map[String, _] = _ + var closed = false + + def configure(configs: util.Map[String, _]): Unit = { + this.configs = configs.asScala.toMap + } + + def validate(requestMetadata: RequestMetadata): Unit = { + require(!closed, "Policy should not be closed") + require(!configs.isEmpty, "configure should have been called with non empty configs") + + import requestMetadata._ + if (numPartitions != null || replicationFactor != null) { + require(numPartitions != null, s"numPartitions should not be null, but it is $numPartitions") + require(replicationFactor != null, s"replicationFactor should not be null, but it is $replicationFactor") + require(replicasAssignments == null, s"replicaAssigments should be null, but it is $replicasAssignments") + + if (numPartitions < 5) + throw new PolicyViolationException(s"Topics should have at least 5 partitions, received $numPartitions") + + if (numPartitions > 10) { + if (requestMetadata.configs.asScala.get(LogConfig.RetentionMsProp).fold(true)(_.toInt > 5000)) + throw new PolicyViolationException("RetentionMs should be less than 5000ms if replicationFactor > 5") + } else + require(requestMetadata.configs.isEmpty, s"Topic configs should be empty, but it is ${requestMetadata.configs}") + + } else { + require(numPartitions == null, s"numPartitions should be null, but it is $numPartitions") + require(replicationFactor == null, s"replicationFactor should be null, but it is $replicationFactor") + require(replicasAssignments != null, s"replicaAssigments should not be null, but it is $replicasAssignments") + + replicasAssignments.asScala.toSeq.sortBy { case (tp, _) => tp }.foreach { case (partitionId, assignment) => + if (assignment.size < 2) + throw new PolicyViolationException("Topic partitions should have at least 2 partitions, received " + + s"${assignment.size} for partition $partitionId") + } + } + + } + + def close(): Unit = closed = true + + } +} diff --git a/core/src/test/scala/unit/kafka/server/DelayedOperationTest.scala b/core/src/test/scala/unit/kafka/server/DelayedOperationTest.scala new file mode 100644 index 0000000..5df53c8 --- /dev/null +++ b/core/src/test/scala/unit/kafka/server/DelayedOperationTest.scala @@ -0,0 +1,404 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.server + +import java.util.Random +import java.util.concurrent._ +import java.util.concurrent.atomic.AtomicInteger +import java.util.concurrent.locks.ReentrantLock + +import kafka.utils.CoreUtils.inLock +import kafka.utils.TestUtils +import org.apache.kafka.common.utils.Time +import org.junit.jupiter.api.{AfterEach, BeforeEach, Test} +import org.junit.jupiter.api.Assertions._ + +import scala.jdk.CollectionConverters._ + +class DelayedOperationTest { + + var purgatory: DelayedOperationPurgatory[DelayedOperation] = null + var executorService: ExecutorService = null + + @BeforeEach + def setUp(): Unit = { + purgatory = DelayedOperationPurgatory[DelayedOperation](purgatoryName = "mock") + } + + @AfterEach + def tearDown(): Unit = { + purgatory.shutdown() + if (executorService != null) + executorService.shutdown() + } + + @Test + def testLockInTryCompleteElseWatch(): Unit = { + val op = new DelayedOperation(100000L) { + override def onExpiration(): Unit = {} + override def onComplete(): Unit = {} + override def tryComplete(): Boolean = { + assertTrue(lock.asInstanceOf[ReentrantLock].isHeldByCurrentThread) + false + } + override def safeTryComplete(): Boolean = { + fail("tryCompleteElseWatch should not use safeTryComplete") + super.safeTryComplete() + } + } + purgatory.tryCompleteElseWatch(op, Seq("key")) + } + + @Test + def testSafeTryCompleteOrElse(): Unit = { + def op(shouldComplete: Boolean) = new DelayedOperation(100000L) { + override def onExpiration(): Unit = {} + override def onComplete(): Unit = {} + override def tryComplete(): Boolean = { + assertTrue(lock.asInstanceOf[ReentrantLock].isHeldByCurrentThread) + shouldComplete + } + } + var pass = false + assertFalse(op(false).safeTryCompleteOrElse { + pass = true + }) + assertTrue(pass) + assertTrue(op(true).safeTryCompleteOrElse { + fail("this method should NOT be executed") + }) + } + + @Test + def testRequestSatisfaction(): Unit = { + val r1 = new MockDelayedOperation(100000L) + val r2 = new MockDelayedOperation(100000L) + assertEquals(0, purgatory.checkAndComplete("test1"), "With no waiting requests, nothing should be satisfied") + assertFalse(purgatory.tryCompleteElseWatch(r1, Array("test1")), "r1 not satisfied and hence watched") + assertEquals(0, purgatory.checkAndComplete("test1"), "Still nothing satisfied") + assertFalse(purgatory.tryCompleteElseWatch(r2, Array("test2")), "r2 not satisfied and hence watched") + assertEquals(0, purgatory.checkAndComplete("test2"), "Still nothing satisfied") + r1.completable = true + assertEquals(1, purgatory.checkAndComplete("test1"), "r1 satisfied") + assertEquals(0, purgatory.checkAndComplete("test1"), "Nothing satisfied") + r2.completable = true + assertEquals(1, purgatory.checkAndComplete("test2"), "r2 satisfied") + assertEquals(0, purgatory.checkAndComplete("test2"), "Nothing satisfied") 
+ } + + @Test + def testRequestExpiry(): Unit = { + val expiration = 20L + val start = Time.SYSTEM.hiResClockMs + val r1 = new MockDelayedOperation(expiration) + val r2 = new MockDelayedOperation(200000L) + assertFalse(purgatory.tryCompleteElseWatch(r1, Array("test1")), "r1 not satisfied and hence watched") + assertFalse(purgatory.tryCompleteElseWatch(r2, Array("test2")), "r2 not satisfied and hence watched") + r1.awaitExpiration() + val elapsed = Time.SYSTEM.hiResClockMs - start + assertTrue(r1.isCompleted, "r1 completed due to expiration") + assertFalse(r2.isCompleted, "r2 hasn't completed") + assertTrue(elapsed >= expiration, s"Time for expiration $elapsed should at least $expiration") + } + + @Test + def testDelayedFuture(): Unit = { + val purgatoryName = "testDelayedFuture" + val purgatory = new DelayedFuturePurgatory(purgatoryName, brokerId = 0) + val result = new AtomicInteger() + + def hasExecutorThread: Boolean = Thread.getAllStackTraces.keySet.asScala.map(_.getName) + .exists(_.contains(s"DelayedExecutor-$purgatoryName")) + def updateResult(futures: List[CompletableFuture[Integer]]): Unit = + result.set(futures.filterNot(_.isCompletedExceptionally).map(_.get.intValue).sum) + + assertFalse(hasExecutorThread, "Unnecessary thread created") + + // Two completed futures: callback should be executed immediately on the same thread + val futures1 = List(CompletableFuture.completedFuture(10.asInstanceOf[Integer]), + CompletableFuture.completedFuture(11.asInstanceOf[Integer])) + val r1 = purgatory.tryCompleteElseWatch[Integer](100000L, futures1, () => updateResult(futures1)) + assertTrue(r1.isCompleted, "r1 not completed") + assertEquals(21, result.get()) + assertFalse(hasExecutorThread, "Unnecessary thread created") + + // Two delayed futures: callback should wait for both to complete + result.set(-1) + val futures2 = List(new CompletableFuture[Integer], new CompletableFuture[Integer]) + val r2 = purgatory.tryCompleteElseWatch[Integer](100000L, futures2, () => updateResult(futures2)) + assertFalse(r2.isCompleted, "r2 should be incomplete") + futures2.head.complete(20) + assertFalse(r2.isCompleted) + assertEquals(-1, result.get()) + futures2(1).complete(21) + TestUtils.waitUntilTrue(() => r2.isCompleted, "r2 not completed") + TestUtils.waitUntilTrue(() => result.get == 41, "callback not invoked") + assertTrue(hasExecutorThread, "Thread not created for executing delayed task") + + // One immediate and one delayed future: callback should wait for delayed task to complete + result.set(-1) + val futures3 = List(new CompletableFuture[Integer], CompletableFuture.completedFuture(31.asInstanceOf[Integer])) + val r3 = purgatory.tryCompleteElseWatch[Integer](100000L, futures3, () => updateResult(futures3)) + assertFalse(r3.isCompleted, "r3 should be incomplete") + assertEquals(-1, result.get()) + futures3.head.complete(30) + TestUtils.waitUntilTrue(() => r3.isCompleted, "r3 not completed") + TestUtils.waitUntilTrue(() => result.get == 61, "callback not invoked") + + + // One future doesn't complete within timeout. Should expire and invoke callback after timeout. 
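+    // The expired future is completed exceptionally with a TimeoutException, so updateResult only
+    // sums the future that completed successfully (40).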
+ result.set(-1) + val start = Time.SYSTEM.hiResClockMs + val expirationMs = 2000L + val futures4 = List(new CompletableFuture[Integer], new CompletableFuture[Integer]) + val r4 = purgatory.tryCompleteElseWatch[Integer](expirationMs, futures4, () => updateResult(futures4)) + futures4.head.complete(40) + TestUtils.waitUntilTrue(() => futures4(1).isDone, "r4 futures not expired") + assertTrue(r4.isCompleted, "r4 not completed after timeout") + val elapsed = Time.SYSTEM.hiResClockMs - start + assertTrue(elapsed >= expirationMs, s"Time for expiration $elapsed should be at least $expirationMs") + assertEquals(40, futures4.head.get) + assertEquals(classOf[org.apache.kafka.common.errors.TimeoutException], + assertThrows(classOf[ExecutionException], () => futures4(1).get).getCause.getClass) + assertEquals(40, result.get()) + } + + @Test + def testRequestPurge(): Unit = { + val r1 = new MockDelayedOperation(100000L) + val r2 = new MockDelayedOperation(100000L) + val r3 = new MockDelayedOperation(100000L) + purgatory.tryCompleteElseWatch(r1, Array("test1")) + purgatory.tryCompleteElseWatch(r2, Array("test1", "test2")) + purgatory.tryCompleteElseWatch(r3, Array("test1", "test2", "test3")) + + assertEquals(3, purgatory.numDelayed, "Purgatory should have 3 total delayed operations") + assertEquals(6, purgatory.watched, "Purgatory should have 6 watched elements") + + // complete the operations; each should immediately be purged from the delayed operations + r2.completable = true + r2.tryComplete() + assertEquals(2, purgatory.numDelayed, "Purgatory should have 2 total delayed operations instead of " + purgatory.numDelayed) + + r3.completable = true + r3.tryComplete() + assertEquals(1, purgatory.numDelayed, "Purgatory should have 1 total delayed operation instead of " + purgatory.numDelayed) + + // checking a watch should purge the watch list + purgatory.checkAndComplete("test1") + assertEquals(4, purgatory.watched, "Purgatory should have 4 watched elements instead of " + purgatory.watched) + + purgatory.checkAndComplete("test2") + assertEquals(2, purgatory.watched, "Purgatory should have 2 watched elements instead of " + purgatory.watched) + + purgatory.checkAndComplete("test3") + assertEquals(1, purgatory.watched, "Purgatory should have 1 watched element instead of " + purgatory.watched) + } + + @Test + def shouldCancelForKeyReturningCancelledOperations(): Unit = { + purgatory.tryCompleteElseWatch(new MockDelayedOperation(10000L), Seq("key")) + purgatory.tryCompleteElseWatch(new MockDelayedOperation(10000L), Seq("key")) + purgatory.tryCompleteElseWatch(new MockDelayedOperation(10000L), Seq("key2")) + + val cancelledOperations = purgatory.cancelForKey("key") + assertEquals(2, cancelledOperations.size) + assertEquals(1, purgatory.numDelayed) + assertEquals(1, purgatory.watched) + } + + @Test + def shouldReturnNilOperationsOnCancelForKeyWhenKeyDoesntExist(): Unit = { + val cancelledOperations = purgatory.cancelForKey("key") + assertEquals(Nil, cancelledOperations) + } + + /** + * Test `tryComplete` with multiple threads to verify that there are no timing windows + * in which completion is skipped even when the thread that makes the operation completable + * cannot acquire the operation lock. Since it is difficult to test all scenarios, + * this test uses random delays with a large number of threads.
+ */ + @Test + def testTryCompleteWithMultipleThreads(): Unit = { + val executor = Executors.newScheduledThreadPool(20) + this.executorService = executor + val random = new Random + val maxDelayMs = 10 + val completionAttempts = 20 + + class TestDelayOperation(index: Int) extends MockDelayedOperation(10000L) { + val key = s"key$index" + val completionAttemptsRemaining = new AtomicInteger(completionAttempts) + + override def tryComplete(): Boolean = { + val shouldComplete = completable + Thread.sleep(random.nextInt(maxDelayMs)) + if (shouldComplete) + forceComplete() + else + false + } + } + val ops = (0 until 100).map { index => + val op = new TestDelayOperation(index) + purgatory.tryCompleteElseWatch(op, Seq(op.key)) + op + } + + def scheduleTryComplete(op: TestDelayOperation, delayMs: Long): Future[_] = { + executor.schedule(new Runnable { + override def run(): Unit = { + if (op.completionAttemptsRemaining.decrementAndGet() == 0) + op.completable = true + purgatory.checkAndComplete(op.key) + } + }, delayMs, TimeUnit.MILLISECONDS) + } + + (1 to completionAttempts).flatMap { _ => + ops.map { op => scheduleTryComplete(op, random.nextInt(maxDelayMs)) } + }.foreach { future => future.get } + + ops.foreach { op => assertTrue(op.isCompleted, "Operation should have completed") } + } + + def verifyDelayedOperationLock(mockDelayedOperation: => MockDelayedOperation, mismatchedLocks: Boolean): Unit = { + val key = "key" + executorService = Executors.newSingleThreadExecutor + def createDelayedOperations(count: Int): Seq[MockDelayedOperation] = { + (1 to count).map { _ => + val op = mockDelayedOperation + purgatory.tryCompleteElseWatch(op, Seq(key)) + assertFalse(op.isCompleted, "Not completable") + op + } + } + + def createCompletableOperations(count: Int): Seq[MockDelayedOperation] = { + (1 to count).map { _ => + val op = mockDelayedOperation + op.completable = true + op + } + } + + def checkAndComplete(completableOps: Seq[MockDelayedOperation], expectedComplete: Seq[MockDelayedOperation]): Unit = { + completableOps.foreach(op => op.completable = true) + val completed = purgatory.checkAndComplete(key) + assertEquals(expectedComplete.size, completed) + expectedComplete.foreach(op => assertTrue(op.isCompleted, "Should have completed")) + val expectedNotComplete = completableOps.toSet -- expectedComplete + expectedNotComplete.foreach(op => assertFalse(op.isCompleted, "Should not have completed")) + } + + // If locks are free all completable operations should complete + var ops = createDelayedOperations(2) + checkAndComplete(ops, ops) + + // Lock held by current thread, completable operations should complete + ops = createDelayedOperations(2) + inLock(ops(1).lock) { + checkAndComplete(ops, ops) + } + + // Lock held by another thread, should not block, only operations that can be + // locked without blocking on the current thread should complete + ops = createDelayedOperations(2) + runOnAnotherThread(ops(0).lock.lock(), true) + try { + checkAndComplete(ops, Seq(ops(1))) + } finally { + runOnAnotherThread(ops(0).lock.unlock(), true) + checkAndComplete(Seq(ops(0)), Seq(ops(0))) + } + + // Lock acquired by response callback held by another thread, should not block + // if the response lock is used as operation lock, only operations + // that can be locked without blocking on the current thread should complete + ops = createDelayedOperations(2) + ops(0).responseLockOpt.foreach { lock => + runOnAnotherThread(lock.lock(), true) + try { + try { + checkAndComplete(ops, Seq(ops(1))) + assertFalse(mismatchedLocks, 
"Should have failed with mismatched locks") + } catch { + case e: IllegalStateException => + assertTrue(mismatchedLocks, "Should not have failed with valid locks") + } + } finally { + runOnAnotherThread(lock.unlock(), true) + checkAndComplete(Seq(ops(0)), Seq(ops(0))) + } + } + + // Immediately completable operations should complete without locking + ops = createCompletableOperations(2) + ops.foreach { op => + assertTrue(purgatory.tryCompleteElseWatch(op, Seq(key)), "Should have completed") + assertTrue(op.isCompleted, "Should have completed") + } + } + + private def runOnAnotherThread(fun: => Unit, shouldComplete: Boolean): Future[_] = { + val future = executorService.submit(new Runnable { + def run() = fun + }) + if (shouldComplete) + future.get() + else + assertFalse(future.isDone, "Should not have completed") + future + } + + class MockDelayedOperation(delayMs: Long, + lockOpt: Option[ReentrantLock] = None, + val responseLockOpt: Option[ReentrantLock] = None) + extends DelayedOperation(delayMs, lockOpt) { + var completable = false + + def awaitExpiration(): Unit = { + synchronized { + wait() + } + } + + override def tryComplete() = { + if (completable) + forceComplete() + else + false + } + + override def onExpiration(): Unit = { + + } + + override def onComplete(): Unit = { + responseLockOpt.foreach { lock => + if (!lock.tryLock()) + throw new IllegalStateException("Response callback lock could not be acquired in callback") + } + synchronized { + notify() + } + } + } + +} diff --git a/core/src/test/scala/unit/kafka/server/DelegationTokenRequestsOnPlainTextTest.scala b/core/src/test/scala/unit/kafka/server/DelegationTokenRequestsOnPlainTextTest.scala new file mode 100644 index 0000000..1d355ce --- /dev/null +++ b/core/src/test/scala/unit/kafka/server/DelegationTokenRequestsOnPlainTextTest.scala @@ -0,0 +1,72 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package kafka.server + +import java.util + +import kafka.utils.TestUtils +import org.apache.kafka.clients.admin.{Admin, AdminClientConfig} +import org.apache.kafka.common.errors.UnsupportedByAuthenticationException +import org.junit.jupiter.api.{AfterEach, BeforeEach, Test, TestInfo} +import org.junit.jupiter.api.Assertions.assertThrows + +import scala.concurrent.ExecutionException + +class DelegationTokenRequestsOnPlainTextTest extends BaseRequestTest { + var adminClient: Admin = null + + override def brokerCount = 1 + + @BeforeEach + override def setUp(testInfo: TestInfo): Unit = { + super.setUp(testInfo) + } + + def createAdminConfig: util.Map[String, Object] = { + val config = new util.HashMap[String, Object] + config.put(AdminClientConfig.BOOTSTRAP_SERVERS_CONFIG, brokerList) + val securityProps: util.Map[Object, Object] = + TestUtils.adminClientSecurityConfigs(securityProtocol, trustStoreFile, clientSaslProperties) + securityProps.forEach { (key, value) => config.put(key.asInstanceOf[String], value) } + config + } + + @Test + def testDelegationTokenRequests(): Unit = { + adminClient = Admin.create(createAdminConfig) + + val createResult = adminClient.createDelegationToken() + assertThrows(classOf[ExecutionException], () => createResult.delegationToken().get()).getCause.isInstanceOf[UnsupportedByAuthenticationException] + + val describeResult = adminClient.describeDelegationToken() + assertThrows(classOf[ExecutionException], () => describeResult.delegationTokens().get()).getCause.isInstanceOf[UnsupportedByAuthenticationException] + + val renewResult = adminClient.renewDelegationToken("".getBytes()) + assertThrows(classOf[ExecutionException], () => renewResult.expiryTimestamp().get()).getCause.isInstanceOf[UnsupportedByAuthenticationException] + + val expireResult = adminClient.expireDelegationToken("".getBytes()) + assertThrows(classOf[ExecutionException], () => expireResult.expiryTimestamp().get()).getCause.isInstanceOf[UnsupportedByAuthenticationException] + } + + + @AfterEach + override def tearDown(): Unit = { + if (adminClient != null) + adminClient.close() + super.tearDown() + } +} diff --git a/core/src/test/scala/unit/kafka/server/DelegationTokenRequestsTest.scala b/core/src/test/scala/unit/kafka/server/DelegationTokenRequestsTest.scala new file mode 100644 index 0000000..4f304bd --- /dev/null +++ b/core/src/test/scala/unit/kafka/server/DelegationTokenRequestsTest.scala @@ -0,0 +1,133 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package kafka.server + +import kafka.api.{KafkaSasl, SaslSetup} +import kafka.utils.{JaasTestUtils, TestUtils} +import org.apache.kafka.clients.admin.{Admin, AdminClientConfig, CreateDelegationTokenOptions, DescribeDelegationTokenOptions} +import org.apache.kafka.common.errors.InvalidPrincipalTypeException +import org.apache.kafka.common.security.auth.SecurityProtocol +import org.apache.kafka.common.utils.SecurityUtils +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.{AfterEach, BeforeEach, Test, TestInfo} + +import java.util +import scala.concurrent.ExecutionException +import scala.jdk.CollectionConverters._ + +class DelegationTokenRequestsTest extends BaseRequestTest with SaslSetup { + override protected def securityProtocol = SecurityProtocol.SASL_PLAINTEXT + private val kafkaClientSaslMechanism = "PLAIN" + private val kafkaServerSaslMechanisms = List("PLAIN") + protected override val serverSaslProperties = Some(kafkaServerSaslProperties(kafkaServerSaslMechanisms, kafkaClientSaslMechanism)) + protected override val clientSaslProperties = Some(kafkaClientSaslProperties(kafkaClientSaslMechanism)) + var adminClient: Admin = null + + override def brokerCount = 1 + + @BeforeEach + override def setUp(testInfo: TestInfo): Unit = { + startSasl(jaasSections(kafkaServerSaslMechanisms, Some(kafkaClientSaslMechanism), KafkaSasl, JaasTestUtils.KafkaServerContextName)) + super.setUp(testInfo) + } + + override def generateConfigs = { + val props = TestUtils.createBrokerConfigs(brokerCount, zkConnect, + enableControlledShutdown = false, + interBrokerSecurityProtocol = Some(securityProtocol), + trustStoreFile = trustStoreFile, saslProperties = serverSaslProperties, enableToken = true) + props.foreach(brokerPropertyOverrides) + props.map(KafkaConfig.fromProps) + } + + private def createAdminConfig: util.Map[String, Object] = { + val config = new util.HashMap[String, Object] + config.put(AdminClientConfig.BOOTSTRAP_SERVERS_CONFIG, brokerList) + val securityProps: util.Map[Object, Object] = + TestUtils.adminClientSecurityConfigs(securityProtocol, trustStoreFile, clientSaslProperties) + securityProps.forEach { (key, value) => config.put(key.asInstanceOf[String], value) } + config + } + + @Test + def testDelegationTokenRequests(): Unit = { + adminClient = Admin.create(createAdminConfig) + + // create token1 with renewer1 + val renewer1 = List(SecurityUtils.parseKafkaPrincipal("User:renewer1")).asJava + val createResult1 = adminClient.createDelegationToken(new CreateDelegationTokenOptions().renewers(renewer1)) + val tokenCreated = createResult1.delegationToken().get() + + //test describe token + var tokens = adminClient.describeDelegationToken().delegationTokens().get() + assertEquals(1, tokens.size()) + var token1 = tokens.get(0) + assertEquals(token1, tokenCreated) + + // create token2 with renewer2 + val renewer2 = List(SecurityUtils.parseKafkaPrincipal("User:renewer2")).asJava + val createResult2 = adminClient.createDelegationToken(new CreateDelegationTokenOptions().renewers(renewer2)) + val token2 = createResult2.delegationToken().get() + + //get all tokens + tokens = adminClient.describeDelegationToken().delegationTokens().get() + assertTrue(tokens.size() == 2) + assertEquals(Set(token1, token2), tokens.asScala.toSet) + + //get tokens for renewer2 + tokens = adminClient.describeDelegationToken(new DescribeDelegationTokenOptions().owners(renewer2)).delegationTokens().get() + assertTrue(tokens.size() == 1) + assertEquals(Set(token2), tokens.asScala.toSet) + + //test renewing 
tokens + val renewResult = adminClient.renewDelegationToken(token1.hmac()) + var expiryTimestamp = renewResult.expiryTimestamp().get() + + val describeResult = adminClient.describeDelegationToken() + val tokenId = token1.tokenInfo().tokenId() + token1 = describeResult.delegationTokens().get().asScala.filter(dt => dt.tokenInfo().tokenId() == tokenId).head + assertEquals(expiryTimestamp, token1.tokenInfo().expiryTimestamp()) + + //test expire tokens + val expireResult1 = adminClient.expireDelegationToken(token1.hmac()) + expiryTimestamp = expireResult1.expiryTimestamp().get() + + val expireResult2 = adminClient.expireDelegationToken(token2.hmac()) + expiryTimestamp = expireResult2.expiryTimestamp().get() + + tokens = adminClient.describeDelegationToken().delegationTokens().get() + assertTrue(tokens.size == 0) + + //create token with invalid principal type + val renewer3 = List(SecurityUtils.parseKafkaPrincipal("Group:Renewer3")).asJava + val createResult3 = adminClient.createDelegationToken(new CreateDelegationTokenOptions().renewers(renewer3)) + assertThrows(classOf[ExecutionException], () => createResult3.delegationToken().get()).getCause.isInstanceOf[InvalidPrincipalTypeException] + + // try describing tokens for unknown owner + val unknownOwner = List(SecurityUtils.parseKafkaPrincipal("User:Unknown")).asJava + tokens = adminClient.describeDelegationToken(new DescribeDelegationTokenOptions().owners(unknownOwner)).delegationTokens().get() + assertTrue(tokens.isEmpty) + } + + @AfterEach + override def tearDown(): Unit = { + if (adminClient != null) + adminClient.close() + super.tearDown() + closeSasl() + } +} diff --git a/core/src/test/scala/unit/kafka/server/DelegationTokenRequestsWithDisableTokenFeatureTest.scala b/core/src/test/scala/unit/kafka/server/DelegationTokenRequestsWithDisableTokenFeatureTest.scala new file mode 100644 index 0000000..e1e5c30 --- /dev/null +++ b/core/src/test/scala/unit/kafka/server/DelegationTokenRequestsWithDisableTokenFeatureTest.scala @@ -0,0 +1,79 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package kafka.server + +import kafka.api.{KafkaSasl, SaslSetup} +import kafka.utils.{JaasTestUtils, TestUtils} +import org.apache.kafka.clients.admin.{Admin, AdminClientConfig} +import org.apache.kafka.common.errors.DelegationTokenDisabledException +import org.apache.kafka.common.security.auth.SecurityProtocol +import org.junit.jupiter.api.Assertions.assertThrows +import org.junit.jupiter.api.{AfterEach, BeforeEach, Test, TestInfo} + +import java.util +import scala.concurrent.ExecutionException + +class DelegationTokenRequestsWithDisableTokenFeatureTest extends BaseRequestTest with SaslSetup { + override protected def securityProtocol = SecurityProtocol.SASL_PLAINTEXT + private val kafkaClientSaslMechanism = "PLAIN" + private val kafkaServerSaslMechanisms = List("PLAIN") + protected override val serverSaslProperties = Some(kafkaServerSaslProperties(kafkaServerSaslMechanisms, kafkaClientSaslMechanism)) + protected override val clientSaslProperties = Some(kafkaClientSaslProperties(kafkaClientSaslMechanism)) + var adminClient: Admin = null + + override def brokerCount = 1 + + @BeforeEach + override def setUp(testInfo: TestInfo): Unit = { + startSasl(jaasSections(kafkaServerSaslMechanisms, Some(kafkaClientSaslMechanism), KafkaSasl, JaasTestUtils.KafkaServerContextName)) + super.setUp(testInfo) + } + + def createAdminConfig: util.Map[String, Object] = { + val config = new util.HashMap[String, Object] + config.put(AdminClientConfig.BOOTSTRAP_SERVERS_CONFIG, brokerList) + val securityProps: util.Map[Object, Object] = + TestUtils.adminClientSecurityConfigs(securityProtocol, trustStoreFile, clientSaslProperties) + securityProps.forEach { (key, value) => config.put(key.asInstanceOf[String], value) } + config + } + + @Test + def testDelegationTokenRequests(): Unit = { + adminClient = Admin.create(createAdminConfig) + + val createResult = adminClient.createDelegationToken() + assertThrows(classOf[ExecutionException], () => createResult.delegationToken().get()).getCause.isInstanceOf[DelegationTokenDisabledException] + + val describeResult = adminClient.describeDelegationToken() + assertThrows(classOf[ExecutionException], () => describeResult.delegationTokens().get()).getCause.isInstanceOf[DelegationTokenDisabledException] + + val renewResult = adminClient.renewDelegationToken("".getBytes()) + assertThrows(classOf[ExecutionException], () => renewResult.expiryTimestamp().get()).getCause.isInstanceOf[DelegationTokenDisabledException] + + val expireResult = adminClient.expireDelegationToken("".getBytes()) + assertThrows(classOf[ExecutionException], () => expireResult.expiryTimestamp().get()).getCause.isInstanceOf[DelegationTokenDisabledException] + } + + @AfterEach + override def tearDown(): Unit = { + if (adminClient != null) + adminClient.close() + super.tearDown() + closeSasl() + } +} diff --git a/core/src/test/scala/unit/kafka/server/DeleteTopicsRequestTest.scala b/core/src/test/scala/unit/kafka/server/DeleteTopicsRequestTest.scala new file mode 100644 index 0000000..a176121 --- /dev/null +++ b/core/src/test/scala/unit/kafka/server/DeleteTopicsRequestTest.scala @@ -0,0 +1,192 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.server + +import java.util.{Arrays, Collections} + +import kafka.network.SocketServer +import kafka.utils._ +import org.apache.kafka.common.Uuid +import org.apache.kafka.common.message.DeleteTopicsRequestData +import org.apache.kafka.common.message.DeleteTopicsRequestData.DeleteTopicState +import org.apache.kafka.common.protocol.Errors +import org.apache.kafka.common.requests.{DeleteTopicsRequest, DeleteTopicsResponse, MetadataRequest, MetadataResponse} +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.Test + +import scala.jdk.CollectionConverters._ + +class DeleteTopicsRequestTest extends BaseRequestTest { + + @Test + def testValidDeleteTopicRequests(): Unit = { + val timeout = 10000 + // Single topic + createTopic("topic-1", 1, 1) + validateValidDeleteTopicRequests(new DeleteTopicsRequest.Builder( + new DeleteTopicsRequestData() + .setTopicNames(Arrays.asList("topic-1")) + .setTimeoutMs(timeout)).build()) + // Multi topic + createTopic("topic-3", 5, 2) + createTopic("topic-4", 1, 2) + validateValidDeleteTopicRequests(new DeleteTopicsRequest.Builder( + new DeleteTopicsRequestData() + .setTopicNames(Arrays.asList("topic-3", "topic-4")) + .setTimeoutMs(timeout)).build()) + + // Topic Ids + createTopic("topic-7", 3, 2) + createTopic("topic-6", 1, 2) + val ids = getTopicIds() + validateValidDeleteTopicRequestsWithIds(new DeleteTopicsRequest.Builder( + new DeleteTopicsRequestData() + .setTopics(Arrays.asList(new DeleteTopicState().setTopicId(ids("topic-7")), + new DeleteTopicState().setTopicId(ids("topic-6")) + ) + ).setTimeoutMs(timeout)).build()) + } + + private def validateValidDeleteTopicRequests(request: DeleteTopicsRequest): Unit = { + val response = sendDeleteTopicsRequest(request) + val error = response.errorCounts.asScala.find(_._1 != Errors.NONE) + assertTrue(error.isEmpty, s"There should be no errors, found ${response.data.responses.asScala}") + request.data.topicNames.forEach { topic => + validateTopicIsDeleted(topic) + } + } + + private def validateValidDeleteTopicRequestsWithIds(request: DeleteTopicsRequest): Unit = { + val response = sendDeleteTopicsRequest(request) + val error = response.errorCounts.asScala.find(_._1 != Errors.NONE) + assertTrue(error.isEmpty, s"There should be no errors, found ${response.data.responses.asScala}") + response.data.responses.forEach { response => + validateTopicIsDeleted(response.name()) + } + } + + @Test + def testErrorDeleteTopicRequests(): Unit = { + val timeout = 30000 + val timeoutTopic = "invalid-timeout" + + // Basic + validateErrorDeleteTopicRequests(new DeleteTopicsRequest.Builder( + new DeleteTopicsRequestData() + .setTopicNames(Arrays.asList("invalid-topic")) + .setTimeoutMs(timeout)).build(), + Map("invalid-topic" -> Errors.UNKNOWN_TOPIC_OR_PARTITION)) + + // Partial + createTopic("partial-topic-1", 1, 1) + validateErrorDeleteTopicRequests(new DeleteTopicsRequest.Builder( + new DeleteTopicsRequestData() + .setTopicNames(Arrays.asList("partial-topic-1", "partial-invalid-topic")) + .setTimeoutMs(timeout)).build(), + Map( + "partial-topic-1" -> Errors.NONE, + 
"partial-invalid-topic" -> Errors.UNKNOWN_TOPIC_OR_PARTITION + ) + ) + + // Topic IDs + createTopic("topic-id-1", 1, 1) + val validId = getTopicIds()("topic-id-1") + val invalidId = Uuid.randomUuid + validateErrorDeleteTopicRequestsWithIds(new DeleteTopicsRequest.Builder( + new DeleteTopicsRequestData() + .setTopics(Arrays.asList(new DeleteTopicState().setTopicId(invalidId), + new DeleteTopicState().setTopicId(validId))) + .setTimeoutMs(timeout)).build(), + Map( + invalidId -> Errors.UNKNOWN_TOPIC_ID, + validId -> Errors.NONE + ) + ) + + // Timeout + createTopic(timeoutTopic, 5, 2) + // Must be a 0ms timeout to avoid transient test failures. Even a timeout of 1ms has succeeded in the past. + validateErrorDeleteTopicRequests(new DeleteTopicsRequest.Builder( + new DeleteTopicsRequestData() + .setTopicNames(Arrays.asList(timeoutTopic)) + .setTimeoutMs(0)).build(), + Map(timeoutTopic -> Errors.REQUEST_TIMED_OUT)) + // The topic should still get deleted eventually + TestUtils.waitUntilTrue(() => !servers.head.metadataCache.contains(timeoutTopic), s"Topic $timeoutTopic is never deleted") + validateTopicIsDeleted(timeoutTopic) + } + + private def validateErrorDeleteTopicRequests(request: DeleteTopicsRequest, expectedResponse: Map[String, Errors]): Unit = { + val response = sendDeleteTopicsRequest(request) + val errors = response.data.responses + + val errorCount = response.errorCounts().asScala.foldLeft(0)(_+_._2) + assertEquals(expectedResponse.size, errorCount, "The response size should match") + + expectedResponse.foreach { case (topic, expectedError) => + assertEquals(expectedResponse(topic).code, errors.find(topic).errorCode, "The response error should match") + // If no error validate the topic was deleted + if (expectedError == Errors.NONE) { + validateTopicIsDeleted(topic) + } + } + } + + private def validateErrorDeleteTopicRequestsWithIds(request: DeleteTopicsRequest, expectedResponse: Map[Uuid, Errors]): Unit = { + val response = sendDeleteTopicsRequest(request) + val responses = response.data.responses + val errors = responses.asScala.map(result => result.topicId() -> result.errorCode()).toMap + val names = responses.asScala.map(result => result.topicId() -> result.name()).toMap + + val errorCount = response.errorCounts().asScala.foldLeft(0)(_+_._2) + assertEquals(expectedResponse.size, errorCount, "The response size should match") + + expectedResponse.foreach { case (topic, expectedError) => + assertEquals(expectedResponse(topic).code, errors(topic), "The response error should match") + // If no error validate the topic was deleted + if (expectedError == Errors.NONE) { + validateTopicIsDeleted(names(topic)) + } + } + } + + @Test + def testNotController(): Unit = { + val request = new DeleteTopicsRequest.Builder( + new DeleteTopicsRequestData() + .setTopicNames(Collections.singletonList("not-controller")) + .setTimeoutMs(1000)).build() + val response = sendDeleteTopicsRequest(request, notControllerSocketServer) + + val error = response.data.responses.find("not-controller").errorCode() + assertEquals(Errors.NOT_CONTROLLER.code, error, "Expected controller error when routed incorrectly") + } + + private def validateTopicIsDeleted(topic: String): Unit = { + val metadata = connectAndReceive[MetadataResponse](new MetadataRequest.Builder( + List(topic).asJava, true).build).topicMetadata.asScala + TestUtils.waitUntilTrue (() => !metadata.exists(p => p.topic.equals(topic) && p.error == Errors.NONE), + s"The topic $topic should not exist") + } + + private def sendDeleteTopicsRequest(request: 
DeleteTopicsRequest, socketServer: SocketServer = controllerSocketServer): DeleteTopicsResponse = { + connectAndReceive[DeleteTopicsResponse](request, destination = socketServer) + } + +} diff --git a/core/src/test/scala/unit/kafka/server/DeleteTopicsRequestWithDeletionDisabledTest.scala b/core/src/test/scala/unit/kafka/server/DeleteTopicsRequestWithDeletionDisabledTest.scala new file mode 100644 index 0000000..ad44f55 --- /dev/null +++ b/core/src/test/scala/unit/kafka/server/DeleteTopicsRequestWithDeletionDisabledTest.scala @@ -0,0 +1,64 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.server + +import java.util.Collections + +import kafka.utils._ +import org.apache.kafka.common.message.DeleteTopicsRequestData +import org.apache.kafka.common.protocol.Errors +import org.apache.kafka.common.requests.{DeleteTopicsRequest, DeleteTopicsResponse} +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.Test + +class DeleteTopicsRequestWithDeletionDisabledTest extends BaseRequestTest { + + override def brokerCount: Int = 1 + + override def generateConfigs = { + val props = TestUtils.createBrokerConfigs(brokerCount, zkConnect, + enableControlledShutdown = false, enableDeleteTopic = false, + interBrokerSecurityProtocol = Some(securityProtocol), + trustStoreFile = trustStoreFile, saslProperties = serverSaslProperties, logDirCount = logDirCount) + props.foreach(brokerPropertyOverrides) + props.map(KafkaConfig.fromProps) + } + + @Test + def testDeleteTopicsRequest(): Unit = { + val topic = "topic-1" + val request = new DeleteTopicsRequest.Builder( + new DeleteTopicsRequestData() + .setTopicNames(Collections.singletonList(topic)) + .setTimeoutMs(1000)).build() + val response = sendDeleteTopicsRequest(request) + assertEquals(Errors.TOPIC_DELETION_DISABLED.code, response.data.responses.find(topic).errorCode) + + val v2request = new DeleteTopicsRequest.Builder( + new DeleteTopicsRequestData() + .setTopicNames(Collections.singletonList(topic)) + .setTimeoutMs(1000)).build(2) + val v2response = sendDeleteTopicsRequest(v2request) + assertEquals(Errors.INVALID_REQUEST.code, v2response.data.responses.find(topic).errorCode) + } + + private def sendDeleteTopicsRequest(request: DeleteTopicsRequest): DeleteTopicsResponse = { + connectAndReceive[DeleteTopicsResponse](request, destination = controllerSocketServer) + } + +} diff --git a/core/src/test/scala/unit/kafka/server/DescribeClusterRequestTest.scala b/core/src/test/scala/unit/kafka/server/DescribeClusterRequestTest.scala new file mode 100644 index 0000000..222ff2d --- /dev/null +++ b/core/src/test/scala/unit/kafka/server/DescribeClusterRequestTest.scala @@ -0,0 +1,93 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license
agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.server + +import java.lang.{Byte => JByte} +import java.util.Properties + +import kafka.network.SocketServer +import kafka.security.authorizer.AclEntry +import org.apache.kafka.common.message.{DescribeClusterRequestData, DescribeClusterResponseData} +import org.apache.kafka.common.protocol.ApiKeys +import org.apache.kafka.common.requests.{DescribeClusterRequest, DescribeClusterResponse} +import org.apache.kafka.common.resource.ResourceType +import org.apache.kafka.common.utils.Utils +import org.junit.jupiter.api.Assertions.assertEquals +import org.junit.jupiter.api.{BeforeEach, Test, TestInfo} + +import scala.jdk.CollectionConverters._ + +class DescribeClusterRequestTest extends BaseRequestTest { + + override def brokerPropertyOverrides(properties: Properties): Unit = { + properties.setProperty(KafkaConfig.OffsetsTopicPartitionsProp, "1") + properties.setProperty(KafkaConfig.DefaultReplicationFactorProp, "2") + properties.setProperty(KafkaConfig.RackProp, s"rack/${properties.getProperty(KafkaConfig.BrokerIdProp)}") + } + + @BeforeEach + override def setUp(testInfo: TestInfo): Unit = { + doSetup(testInfo, createOffsetsTopic = false) + } + + @Test + def testDescribeClusterRequestIncludingClusterAuthorizedOperations(): Unit = { + testDescribeClusterRequest(true) + } + + @Test + def testDescribeClusterRequestExcludingClusterAuthorizedOperations(): Unit = { + testDescribeClusterRequest(false) + } + + def testDescribeClusterRequest(includeClusterAuthorizedOperations: Boolean): Unit = { + val expectedBrokers = servers.map { server => + new DescribeClusterResponseData.DescribeClusterBroker() + .setBrokerId(server.config.brokerId) + .setHost("localhost") + .setPort(server.socketServer.boundPort(listenerName)) + .setRack(server.config.rack.orNull) + }.toSet + val expectedControllerId = servers.filter(_.kafkaController.isActive).last.config.brokerId + val expectedClusterId = servers.last.clusterId + + val expectedClusterAuthorizedOperations = if (includeClusterAuthorizedOperations) { + Utils.to32BitField( + AclEntry.supportedOperations(ResourceType.CLUSTER) + .map(_.code.asInstanceOf[JByte]).asJava) + } else { + Int.MinValue + } + + for (version <- ApiKeys.DESCRIBE_CLUSTER.oldestVersion to ApiKeys.DESCRIBE_CLUSTER.latestVersion) { + val describeClusterRequest = new DescribeClusterRequest.Builder(new DescribeClusterRequestData() + .setIncludeClusterAuthorizedOperations(includeClusterAuthorizedOperations)) + .build(version.toShort) + val describeClusterResponse = sendDescribeClusterRequest(describeClusterRequest) + + assertEquals(expectedControllerId, describeClusterResponse.data.controllerId) + assertEquals(expectedClusterId, describeClusterResponse.data.clusterId) + assertEquals(expectedClusterAuthorizedOperations, describeClusterResponse.data.clusterAuthorizedOperations) +
assertEquals(expectedBrokers, describeClusterResponse.data.brokers.asScala.toSet) + } + } + + private def sendDescribeClusterRequest(request: DescribeClusterRequest, destination: Option[SocketServer] = None): DescribeClusterResponse = { + connectAndReceive[DescribeClusterResponse](request, destination = destination.getOrElse(anySocketServer)) + } +} diff --git a/core/src/test/scala/unit/kafka/server/DescribeLogDirsRequestTest.scala b/core/src/test/scala/unit/kafka/server/DescribeLogDirsRequestTest.scala new file mode 100644 index 0000000..9ab3f86 --- /dev/null +++ b/core/src/test/scala/unit/kafka/server/DescribeLogDirsRequestTest.scala @@ -0,0 +1,76 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.server + +import java.io.File + +import kafka.utils._ +import org.apache.kafka.common.TopicPartition +import org.apache.kafka.common.message.DescribeLogDirsRequestData +import org.apache.kafka.common.protocol.Errors +import org.apache.kafka.common.requests._ +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.Test + +import scala.jdk.CollectionConverters._ +
+class DescribeLogDirsRequestTest extends BaseRequestTest { + override val logDirCount = 2 + override val brokerCount: Int = 1 + + val topic = "topic" + val partitionNum = 2 + val tp0 = new TopicPartition(topic, 0) + val tp1 = new TopicPartition(topic, 1) + + @Test + def testDescribeLogDirsRequest(): Unit = { + val onlineDir = new File(servers.head.config.logDirs.head).getAbsolutePath + val offlineDir = new File(servers.head.config.logDirs.tail.head).getAbsolutePath + servers.head.replicaManager.handleLogDirFailure(offlineDir) + createTopic(topic, partitionNum, 1) + TestUtils.generateAndProduceMessages(servers, topic, 10) + + val request = new DescribeLogDirsRequest.Builder(new DescribeLogDirsRequestData().setTopics(null)).build() + val response = connectAndReceive[DescribeLogDirsResponse](request, destination = controllerSocketServer) + + assertEquals(logDirCount, response.data.results.size) + val offlineResult = response.data.results.asScala.find(logDirResult => logDirResult.logDir == offlineDir).get + assertEquals(Errors.KAFKA_STORAGE_ERROR.code, offlineResult.errorCode) + assertEquals(0, offlineResult.topics.asScala.map(t => t.partitions().size()).sum) + + val onlineResult = response.data.results.asScala.find(logDirResult => logDirResult.logDir == onlineDir).get + assertEquals(Errors.NONE.code, onlineResult.errorCode) + val onlinePartitionsMap = onlineResult.topics.asScala.flatMap { topic => + topic.partitions().asScala.map { partitionResult => + new TopicPartition(topic.name, partitionResult.partitionIndex) -> partitionResult + } + }.toMap + val replicaInfo0 = onlinePartitionsMap(tp0) + val replicaInfo1 = onlinePartitionsMap(tp1) + val log0 =
servers.head.logManager.getLog(tp0).get + val log1 = servers.head.logManager.getLog(tp1).get + assertEquals(log0.size, replicaInfo0.partitionSize) + assertEquals(log1.size, replicaInfo1.partitionSize) + val logEndOffset = servers.head.logManager.getLog(tp0).get.logEndOffset + assertTrue(logEndOffset > 0, s"LogEndOffset '$logEndOffset' should be > 0") + assertEquals(servers.head.replicaManager.getLogEndOffsetLag(tp0, log0.logEndOffset, false), replicaInfo0.offsetLag) + assertEquals(servers.head.replicaManager.getLogEndOffsetLag(tp1, log1.logEndOffset, false), replicaInfo1.offsetLag) + } + +} diff --git a/core/src/test/scala/unit/kafka/server/DescribeQuorumRequestTest.scala b/core/src/test/scala/unit/kafka/server/DescribeQuorumRequestTest.scala new file mode 100644 index 0000000..55b9fe9 --- /dev/null +++ b/core/src/test/scala/unit/kafka/server/DescribeQuorumRequestTest.scala @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package kafka.server + +import java.io.IOException + +import kafka.test.ClusterInstance +import kafka.test.annotation.{ClusterTest, ClusterTestDefaults, Type} +import kafka.test.junit.ClusterTestExtensions +import kafka.utils.NotNothing +import org.apache.kafka.common.protocol.{ApiKeys, Errors} +import org.apache.kafka.common.requests.DescribeQuorumRequest.singletonRequest +import org.apache.kafka.common.requests.{AbstractRequest, AbstractResponse, ApiVersionsRequest, ApiVersionsResponse, DescribeQuorumRequest, DescribeQuorumResponse} +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.Tag +import org.junit.jupiter.api.extension.ExtendWith + +import scala.jdk.CollectionConverters._ +import scala.reflect.ClassTag + +@ExtendWith(value = Array(classOf[ClusterTestExtensions])) +@ClusterTestDefaults(clusterType = Type.KRAFT) +@Tag("integration") +class DescribeQuorumRequestTest(cluster: ClusterInstance) { + + @ClusterTest(clusterType = Type.ZK) + def testDescribeQuorumNotSupportedByZkBrokers(): Unit = { + val apiRequest = new ApiVersionsRequest.Builder().build() + val apiResponse = connectAndReceive[ApiVersionsResponse](apiRequest) + assertNull(apiResponse.apiVersion(ApiKeys.DESCRIBE_QUORUM.id)) + + val describeQuorumRequest = new DescribeQuorumRequest.Builder( + singletonRequest(KafkaRaftServer.MetadataPartition) + ).build() + + assertThrows(classOf[IOException], () => { + connectAndReceive[DescribeQuorumResponse](describeQuorumRequest) + }) + } + + @ClusterTest + def testDescribeQuorum(): Unit = { + val request = new DescribeQuorumRequest.Builder( + singletonRequest(KafkaRaftServer.MetadataPartition) + ).build() + + val response = connectAndReceive[DescribeQuorumResponse](request) + + assertEquals(Errors.NONE, Errors.forCode(response.data.errorCode)) + assertEquals(1, 
response.data.topics.size) + + val topicData = response.data.topics.get(0) + assertEquals(KafkaRaftServer.MetadataTopic, topicData.topicName) + assertEquals(1, topicData.partitions.size) + + val partitionData = topicData.partitions.get(0) + assertEquals(KafkaRaftServer.MetadataPartition.partition, partitionData.partitionIndex) + assertEquals(Errors.NONE, Errors.forCode(partitionData.errorCode)) + assertTrue(partitionData.leaderEpoch > 0) + + val leaderId = partitionData.leaderId + assertTrue(leaderId > 0) + + val leaderState = partitionData.currentVoters.asScala.find(_.replicaId == leaderId) + .getOrElse(throw new AssertionError("Failed to find leader among current voter states")) + assertTrue(leaderState.logEndOffset > 0) + } + + private def connectAndReceive[T <: AbstractResponse]( + request: AbstractRequest + )( + implicit classTag: ClassTag[T], nn: NotNothing[T] + ): T = { + IntegrationTestUtils.connectAndReceive( + request, + cluster.brokerSocketServers().asScala.head, + cluster.clientListener() + ) + } + +} diff --git a/core/src/test/scala/unit/kafka/server/DescribeUserScramCredentialsRequestNotAuthorizedTest.scala b/core/src/test/scala/unit/kafka/server/DescribeUserScramCredentialsRequestNotAuthorizedTest.scala new file mode 100644 index 0000000..15f4ab4 --- /dev/null +++ b/core/src/test/scala/unit/kafka/server/DescribeUserScramCredentialsRequestNotAuthorizedTest.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package kafka.server + +import kafka.network.SocketServer +import org.apache.kafka.common.message.DescribeUserScramCredentialsRequestData +import org.apache.kafka.common.protocol.Errors +import org.apache.kafka.common.requests.{DescribeUserScramCredentialsRequest, DescribeUserScramCredentialsResponse} +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.Test + +import java.util.Properties + +/** + * see DescribeUserScramCredentialsRequestTest + */ +class DescribeUserScramCredentialsRequestNotAuthorizedTest extends BaseRequestTest { + override def brokerPropertyOverrides(properties: Properties): Unit = { + properties.put(KafkaConfig.ControlledShutdownEnableProp, "false") + properties.put(KafkaConfig.AuthorizerClassNameProp, classOf[DescribeCredentialsTest.TestAuthorizer].getName) + properties.put(KafkaConfig.PrincipalBuilderClassProp, classOf[DescribeCredentialsTest.TestPrincipalBuilderReturningUnauthorized].getName) + } + + @Test + def testDescribeNotAuthorized(): Unit = { + val request = new DescribeUserScramCredentialsRequest.Builder( + new DescribeUserScramCredentialsRequestData()).build() + val response = sendDescribeUserScramCredentialsRequest(request) + + val error = response.data.errorCode + assertEquals(Errors.CLUSTER_AUTHORIZATION_FAILED.code, error, "Expected not authorized error") + } + + private def sendDescribeUserScramCredentialsRequest(request: DescribeUserScramCredentialsRequest, socketServer: SocketServer = controllerSocketServer): DescribeUserScramCredentialsResponse = { + connectAndReceive[DescribeUserScramCredentialsResponse](request, destination = socketServer) + } +} diff --git a/core/src/test/scala/unit/kafka/server/DescribeUserScramCredentialsRequestTest.scala b/core/src/test/scala/unit/kafka/server/DescribeUserScramCredentialsRequestTest.scala new file mode 100644 index 0000000..012f833 --- /dev/null +++ b/core/src/test/scala/unit/kafka/server/DescribeUserScramCredentialsRequestTest.scala @@ -0,0 +1,140 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package kafka.server + +import java.util +import java.util.Properties +import kafka.network.SocketServer +import kafka.security.authorizer.AclAuthorizer +import org.apache.kafka.common.message.{DescribeUserScramCredentialsRequestData, DescribeUserScramCredentialsResponseData} +import org.apache.kafka.common.message.DescribeUserScramCredentialsRequestData.UserName +import org.apache.kafka.common.protocol.{ApiKeys, Errors} +import org.apache.kafka.common.requests.{DescribeUserScramCredentialsRequest, DescribeUserScramCredentialsResponse} +import org.apache.kafka.common.security.auth.{AuthenticationContext, KafkaPrincipal} +import org.apache.kafka.common.security.authenticator.DefaultKafkaPrincipalBuilder +import org.apache.kafka.server.authorizer.{Action, AuthorizableRequestContext, AuthorizationResult} +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.Test + +import scala.jdk.CollectionConverters._ + +/** + * Test DescribeUserScramCredentialsRequest/Response API for the cases where no credentials exist + * or failure is expected due to lack of authorization, sending the request to a non-controller broker, or some other issue. + * Testing the API for the case where there are actually credentials to describe is performed elsewhere. + */ +class DescribeUserScramCredentialsRequestTest extends BaseRequestTest { + + override def brokerPropertyOverrides(properties: Properties): Unit = { + properties.put(KafkaConfig.ControlledShutdownEnableProp, "false") + properties.put(KafkaConfig.AuthorizerClassNameProp, classOf[DescribeCredentialsTest.TestAuthorizer].getName) + properties.put(KafkaConfig.PrincipalBuilderClassProp, classOf[DescribeCredentialsTest.TestPrincipalBuilderReturningAuthorized].getName) + } + + @Test + def testDescribeNothing(): Unit = { + val request = new DescribeUserScramCredentialsRequest.Builder( + new DescribeUserScramCredentialsRequestData()).build() + val response = sendDescribeUserScramCredentialsRequest(request) + + val error = response.data.errorCode + assertEquals(Errors.NONE.code, error, "Expected no error when describing everything and there are no credentials") + assertEquals(0, response.data.results.size, "Expected no credentials when describing everything and there are no credentials") + } + + @Test + def testDescribeWithNull(): Unit = { + val request = new DescribeUserScramCredentialsRequest.Builder( + new DescribeUserScramCredentialsRequestData().setUsers(null)).build() + val response = sendDescribeUserScramCredentialsRequest(request) + + val error = response.data.errorCode + assertEquals(Errors.NONE.code, error, "Expected no error when describing everything and there are no credentials") + assertEquals(0, response.data.results.size, "Expected no credentials when describing everything and there are no credentials") + } + + @Test + def testDescribeNotController(): Unit = { + val request = new DescribeUserScramCredentialsRequest.Builder( + new DescribeUserScramCredentialsRequestData()).build() + val response = sendDescribeUserScramCredentialsRequest(request, notControllerSocketServer) + + val error = response.data.errorCode + assertEquals(Errors.NONE.code, error, "Did not expect controller error when routed to non-controller") + } + + @Test + def testDescribeSameUserTwice(): Unit = { + val user = "user1" + val userName = new UserName().setName(user) + val request = new DescribeUserScramCredentialsRequest.Builder( + new DescribeUserScramCredentialsRequestData().setUsers(List(userName, userName).asJava)).build() + val response = 
sendDescribeUserScramCredentialsRequest(request) + + assertEquals(Errors.NONE.code, response.data.errorCode, "Expected no top-level error") + assertEquals(1, response.data.results.size) + val result: DescribeUserScramCredentialsResponseData.DescribeUserScramCredentialsResult = response.data.results.get(0) + assertEquals(Errors.DUPLICATE_RESOURCE.code, result.errorCode, s"Expected duplicate resource error for $user") + assertEquals(s"Cannot describe SCRAM credentials for the same user twice in a single request: $user", result.errorMessage) + } + + @Test + def testUnknownUser(): Unit = { + val unknownUser = "unknownUser" + val request = new DescribeUserScramCredentialsRequest.Builder( + new DescribeUserScramCredentialsRequestData().setUsers(List(new UserName().setName(unknownUser)).asJava)).build() + val response = sendDescribeUserScramCredentialsRequest(request) + + assertEquals(Errors.NONE.code, response.data.errorCode, "Expected no top-level error") + assertEquals(1, response.data.results.size) + val result: DescribeUserScramCredentialsResponseData.DescribeUserScramCredentialsResult = response.data.results.get(0) + assertEquals(Errors.RESOURCE_NOT_FOUND.code, result.errorCode, s"Expected resource not found error for $unknownUser") + assertEquals(s"Attempt to describe a user credential that does not exist: $unknownUser", result.errorMessage) + } + + private def sendDescribeUserScramCredentialsRequest(request: DescribeUserScramCredentialsRequest, socketServer: SocketServer = controllerSocketServer): DescribeUserScramCredentialsResponse = { + connectAndReceive[DescribeUserScramCredentialsResponse](request, destination = socketServer) + } +} + +object DescribeCredentialsTest { + val UnauthorizedPrincipal = new KafkaPrincipal(KafkaPrincipal.USER_TYPE, "Unauthorized") + val AuthorizedPrincipal = KafkaPrincipal.ANONYMOUS + + class TestAuthorizer extends AclAuthorizer { + override def authorize(requestContext: AuthorizableRequestContext, actions: util.List[Action]): util.List[AuthorizationResult] = { + actions.asScala.map { _ => + if (requestContext.requestType == ApiKeys.DESCRIBE_USER_SCRAM_CREDENTIALS.id && requestContext.principal == UnauthorizedPrincipal) + AuthorizationResult.DENIED + else + AuthorizationResult.ALLOWED + }.asJava + } + } + + class TestPrincipalBuilderReturningAuthorized extends DefaultKafkaPrincipalBuilder(null, null) { + override def build(context: AuthenticationContext): KafkaPrincipal = { + AuthorizedPrincipal + } + } + + class TestPrincipalBuilderReturningUnauthorized extends DefaultKafkaPrincipalBuilder(null, null) { + override def build(context: AuthenticationContext): KafkaPrincipal = { + UnauthorizedPrincipal + } + } +} diff --git a/core/src/test/scala/unit/kafka/server/DynamicBrokerConfigTest.scala b/core/src/test/scala/unit/kafka/server/DynamicBrokerConfigTest.scala new file mode 100755 index 0000000..b940bc9 --- /dev/null +++ b/core/src/test/scala/unit/kafka/server/DynamicBrokerConfigTest.scala @@ -0,0 +1,547 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.server + +import java.{lang, util} +import java.util.Properties +import java.util.concurrent.CompletionStage +import java.util.concurrent.atomic.AtomicReference + +import kafka.controller.KafkaController +import kafka.log.{LogConfig, LogManager} +import kafka.network.SocketServer +import kafka.utils.{KafkaScheduler, TestUtils} +import kafka.zk.KafkaZkClient +import org.apache.kafka.common.{Endpoint, Reconfigurable} +import org.apache.kafka.common.acl.{AclBinding, AclBindingFilter} +import org.apache.kafka.common.config.types.Password +import org.apache.kafka.common.config.{ConfigException, SslConfigs} +import org.apache.kafka.server.authorizer._ +import org.easymock.EasyMock +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.Test +import org.mockito.{ArgumentMatchers, Mockito} + +import scala.annotation.nowarn +import scala.jdk.CollectionConverters._ +import scala.collection.Set + +class DynamicBrokerConfigTest { + + @Test + def testConfigUpdate(): Unit = { + val props = TestUtils.createBrokerConfig(0, TestUtils.MockZkConnect, port = 8181) + val oldKeystore = "oldKs.jks" + props.put(SslConfigs.SSL_KEYSTORE_LOCATION_CONFIG, oldKeystore) + val config = KafkaConfig(props) + val dynamicConfig = config.dynamicConfig + dynamicConfig.initialize(None) + + assertEquals(config, dynamicConfig.currentKafkaConfig) + assertEquals(oldKeystore, config.values.get(SslConfigs.SSL_KEYSTORE_LOCATION_CONFIG)) + assertEquals(oldKeystore, + config.valuesFromThisConfigWithPrefixOverride("listener.name.external.").get(SslConfigs.SSL_KEYSTORE_LOCATION_CONFIG)) + assertEquals(oldKeystore, config.originalsFromThisConfig.get(SslConfigs.SSL_KEYSTORE_LOCATION_CONFIG)) + + (1 to 2).foreach { i => + val props1 = new Properties + val newKeystore = s"ks$i.jks" + props1.put(s"listener.name.external.${SslConfigs.SSL_KEYSTORE_LOCATION_CONFIG}", newKeystore) + dynamicConfig.updateBrokerConfig(0, props1) + assertNotSame(config, dynamicConfig.currentKafkaConfig) + + assertEquals(newKeystore, + config.valuesWithPrefixOverride("listener.name.external.").get(SslConfigs.SSL_KEYSTORE_LOCATION_CONFIG)) + assertEquals(newKeystore, + config.originalsWithPrefix("listener.name.external.").get(SslConfigs.SSL_KEYSTORE_LOCATION_CONFIG)) + assertEquals(newKeystore, + config.valuesWithPrefixOverride("listener.name.external.").get(SslConfigs.SSL_KEYSTORE_LOCATION_CONFIG)) + assertEquals(newKeystore, + config.originalsWithPrefix("listener.name.external.").get(SslConfigs.SSL_KEYSTORE_LOCATION_CONFIG)) + + assertEquals(oldKeystore, config.getString(KafkaConfig.SslKeystoreLocationProp)) + assertEquals(oldKeystore, config.originals.get(SslConfigs.SSL_KEYSTORE_LOCATION_CONFIG)) + assertEquals(oldKeystore, config.values.get(SslConfigs.SSL_KEYSTORE_LOCATION_CONFIG)) + assertEquals(oldKeystore, config.originalsStrings.get(SslConfigs.SSL_KEYSTORE_LOCATION_CONFIG)) + + assertEquals(oldKeystore, + config.valuesFromThisConfigWithPrefixOverride("listener.name.external.").get(SslConfigs.SSL_KEYSTORE_LOCATION_CONFIG)) + assertEquals(oldKeystore, 
config.originalsFromThisConfig.get(SslConfigs.SSL_KEYSTORE_LOCATION_CONFIG)) + assertEquals(oldKeystore, config.valuesFromThisConfig.get(SslConfigs.SSL_KEYSTORE_LOCATION_CONFIG)) + assertEquals(oldKeystore, config.originalsFromThisConfig.get(SslConfigs.SSL_KEYSTORE_LOCATION_CONFIG)) + assertEquals(oldKeystore, config.valuesFromThisConfig.get(SslConfigs.SSL_KEYSTORE_LOCATION_CONFIG)) + } + } + + @Test + def testEnableDefaultUncleanLeaderElection(): Unit = { + val origProps = TestUtils.createBrokerConfig(0, TestUtils.MockZkConnect, port = 8181) + origProps.put(KafkaConfig.UncleanLeaderElectionEnableProp, "false") + + val config = KafkaConfig(origProps) + val serverMock = Mockito.mock(classOf[KafkaServer]) + val controllerMock = Mockito.mock(classOf[KafkaController]) + val logManagerMock = Mockito.mock(classOf[LogManager]) + + Mockito.when(serverMock.config).thenReturn(config) + Mockito.when(serverMock.kafkaController).thenReturn(controllerMock) + Mockito.when(serverMock.logManager).thenReturn(logManagerMock) + Mockito.when(logManagerMock.allLogs).thenReturn(Iterable.empty) + + val currentDefaultLogConfig = new AtomicReference(LogConfig()) + Mockito.when(logManagerMock.currentDefaultConfig).thenAnswer(_ => currentDefaultLogConfig.get()) + Mockito.when(logManagerMock.reconfigureDefaultLogConfig(ArgumentMatchers.any(classOf[LogConfig]))) + .thenAnswer(invocation => currentDefaultLogConfig.set(invocation.getArgument(0))) + + config.dynamicConfig.initialize(None) + config.dynamicConfig.addBrokerReconfigurable(new DynamicLogConfig(logManagerMock, serverMock)) + + val props = new Properties() + + props.put(KafkaConfig.UncleanLeaderElectionEnableProp, "true") + config.dynamicConfig.updateDefaultConfig(props) + assertTrue(config.uncleanLeaderElectionEnable) + Mockito.verify(controllerMock).enableDefaultUncleanLeaderElection() + } + + @Test + def testUpdateDynamicThreadPool(): Unit = { + val origProps = TestUtils.createBrokerConfig(0, TestUtils.MockZkConnect, port = 8181) + origProps.put(KafkaConfig.NumIoThreadsProp, "4") + origProps.put(KafkaConfig.NumNetworkThreadsProp, "2") + origProps.put(KafkaConfig.NumReplicaFetchersProp, "1") + origProps.put(KafkaConfig.NumRecoveryThreadsPerDataDirProp, "1") + origProps.put(KafkaConfig.BackgroundThreadsProp, "3") + + val config = KafkaConfig(origProps) + val serverMock = Mockito.mock(classOf[KafkaBroker]) + val handlerPoolMock = Mockito.mock(classOf[KafkaRequestHandlerPool]) + val socketServerMock = Mockito.mock(classOf[SocketServer]) + val replicaManagerMock = Mockito.mock(classOf[ReplicaManager]) + val logManagerMock = Mockito.mock(classOf[LogManager]) + val schedulerMock = Mockito.mock(classOf[KafkaScheduler]) + + Mockito.when(serverMock.config).thenReturn(config) + Mockito.when(serverMock.dataPlaneRequestHandlerPool).thenReturn(handlerPoolMock) + Mockito.when(serverMock.socketServer).thenReturn(socketServerMock) + Mockito.when(serverMock.replicaManager).thenReturn(replicaManagerMock) + Mockito.when(serverMock.logManager).thenReturn(logManagerMock) + Mockito.when(serverMock.kafkaScheduler).thenReturn(schedulerMock) + + config.dynamicConfig.initialize(None) + config.dynamicConfig.addBrokerReconfigurable(new DynamicThreadPool(serverMock)) + + val props = new Properties() + + props.put(KafkaConfig.NumIoThreadsProp, "8") + config.dynamicConfig.updateDefaultConfig(props) + assertEquals(8, config.numIoThreads) + Mockito.verify(handlerPoolMock).resizeThreadPool(newSize = 8) + + props.put(KafkaConfig.NumNetworkThreadsProp, "4") + 
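// Resizing num.network.threads should propagate to the SocketServer data plane (from the original 2 threads to 4). +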
config.dynamicConfig.updateDefaultConfig(props) + assertEquals(4, config.numNetworkThreads) + Mockito.verify(socketServerMock).resizeThreadPool(oldNumNetworkThreads = 2, newNumNetworkThreads = 4) + + props.put(KafkaConfig.NumReplicaFetchersProp, "2") + config.dynamicConfig.updateDefaultConfig(props) + assertEquals(2, config.numReplicaFetchers) + Mockito.verify(replicaManagerMock).resizeFetcherThreadPool(newSize = 2) + + props.put(KafkaConfig.NumRecoveryThreadsPerDataDirProp, "2") + config.dynamicConfig.updateDefaultConfig(props) + assertEquals(2, config.numRecoveryThreadsPerDataDir) + Mockito.verify(logManagerMock).resizeRecoveryThreadPool(newSize = 2) + + props.put(KafkaConfig.BackgroundThreadsProp, "6") + config.dynamicConfig.updateDefaultConfig(props) + assertEquals(6, config.backgroundThreads) + Mockito.verify(schedulerMock).resizeThreadPool(newSize = 6) + + Mockito.verifyNoMoreInteractions( + handlerPoolMock, + socketServerMock, + replicaManagerMock, + logManagerMock, + schedulerMock + ) + } + + @nowarn("cat=deprecation") + @Test + def testConfigUpdateWithSomeInvalidConfigs(): Unit = { + val origProps = TestUtils.createBrokerConfig(0, TestUtils.MockZkConnect, port = 8181) + origProps.put(SslConfigs.SSL_KEYSTORE_TYPE_CONFIG, "JKS") + val config = KafkaConfig(origProps) + config.dynamicConfig.initialize(None) + + val validProps = Map(s"listener.name.external.${SslConfigs.SSL_KEYSTORE_LOCATION_CONFIG}" -> "ks.p12") + + val securityPropsWithoutListenerPrefix = Map(SslConfigs.SSL_KEYSTORE_TYPE_CONFIG -> "PKCS12") + verifyConfigUpdateWithInvalidConfig(config, origProps, validProps, securityPropsWithoutListenerPrefix) + val nonDynamicProps = Map(KafkaConfig.ZkConnectProp -> "somehost:2181") + verifyConfigUpdateWithInvalidConfig(config, origProps, validProps, nonDynamicProps) + + // Test update of configs with invalid type + val invalidProps = Map(KafkaConfig.LogCleanerThreadsProp -> "invalid") + verifyConfigUpdateWithInvalidConfig(config, origProps, validProps, invalidProps) + + val excludedTopicConfig = Map(KafkaConfig.LogMessageFormatVersionProp -> "0.10.2") + verifyConfigUpdateWithInvalidConfig(config, origProps, validProps, excludedTopicConfig) + } + + @Test + def testConfigUpdateWithReconfigurableValidationFailure(): Unit = { + val origProps = TestUtils.createBrokerConfig(0, TestUtils.MockZkConnect, port = 8181) + origProps.put(KafkaConfig.LogCleanerDedupeBufferSizeProp, "100000000") + val config = KafkaConfig(origProps) + config.dynamicConfig.initialize(None) + + val validProps = Map.empty[String, String] + val invalidProps = Map(KafkaConfig.LogCleanerThreadsProp -> "20") + + def validateLogCleanerConfig(configs: util.Map[String, _]): Unit = { + val cleanerThreads = configs.get(KafkaConfig.LogCleanerThreadsProp).toString.toInt + if (cleanerThreads <=0 || cleanerThreads >= 5) + throw new ConfigException(s"Invalid cleaner threads $cleanerThreads") + } + val reconfigurable = new Reconfigurable { + override def configure(configs: util.Map[String, _]): Unit = {} + override def reconfigurableConfigs(): util.Set[String] = Set(KafkaConfig.LogCleanerThreadsProp).asJava + override def validateReconfiguration(configs: util.Map[String, _]): Unit = validateLogCleanerConfig(configs) + override def reconfigure(configs: util.Map[String, _]): Unit = {} + } + config.dynamicConfig.addReconfigurable(reconfigurable) + verifyConfigUpdateWithInvalidConfig(config, origProps, validProps, invalidProps) + config.dynamicConfig.removeReconfigurable(reconfigurable) + + val brokerReconfigurable = new 
BrokerReconfigurable { + override def reconfigurableConfigs: collection.Set[String] = Set(KafkaConfig.LogCleanerThreadsProp) + override def validateReconfiguration(newConfig: KafkaConfig): Unit = validateLogCleanerConfig(newConfig.originals) + override def reconfigure(oldConfig: KafkaConfig, newConfig: KafkaConfig): Unit = {} + } + config.dynamicConfig.addBrokerReconfigurable(brokerReconfigurable) + verifyConfigUpdateWithInvalidConfig(config, origProps, validProps, invalidProps) + } + + @Test + def testReconfigurableValidation(): Unit = { + val origProps = TestUtils.createBrokerConfig(0, TestUtils.MockZkConnect, port = 8181) + val config = KafkaConfig(origProps) + val invalidReconfigurableProps = Set(KafkaConfig.LogCleanerThreadsProp, KafkaConfig.BrokerIdProp, "some.prop") + val validReconfigurableProps = Set(KafkaConfig.LogCleanerThreadsProp, KafkaConfig.LogCleanerDedupeBufferSizeProp, "some.prop") + + def createReconfigurable(configs: Set[String]) = new Reconfigurable { + override def configure(configs: util.Map[String, _]): Unit = {} + override def reconfigurableConfigs(): util.Set[String] = configs.asJava + override def validateReconfiguration(configs: util.Map[String, _]): Unit = {} + override def reconfigure(configs: util.Map[String, _]): Unit = {} + } + assertThrows(classOf[IllegalArgumentException], () => config.dynamicConfig.addReconfigurable(createReconfigurable(invalidReconfigurableProps))) + config.dynamicConfig.addReconfigurable(createReconfigurable(validReconfigurableProps)) + + def createBrokerReconfigurable(configs: Set[String]) = new BrokerReconfigurable { + override def reconfigurableConfigs: collection.Set[String] = configs + override def validateReconfiguration(newConfig: KafkaConfig): Unit = {} + override def reconfigure(oldConfig: KafkaConfig, newConfig: KafkaConfig): Unit = {} + } + assertThrows(classOf[IllegalArgumentException], () => config.dynamicConfig.addBrokerReconfigurable(createBrokerReconfigurable(invalidReconfigurableProps))) + config.dynamicConfig.addBrokerReconfigurable(createBrokerReconfigurable(validReconfigurableProps)) + } + + @Test + def testSecurityConfigs(): Unit = { + def verifyUpdate(name: String, value: Object): Unit = { + verifyConfigUpdate(name, value, perBrokerConfig = true, expectFailure = true) + verifyConfigUpdate(s"listener.name.external.$name", value, perBrokerConfig = true, expectFailure = false) + verifyConfigUpdate(name, value, perBrokerConfig = false, expectFailure = true) + verifyConfigUpdate(s"listener.name.external.$name", value, perBrokerConfig = false, expectFailure = true) + } + + verifyUpdate(SslConfigs.SSL_KEYSTORE_LOCATION_CONFIG, "ks.jks") + verifyUpdate(SslConfigs.SSL_KEYSTORE_TYPE_CONFIG, "JKS") + verifyUpdate(SslConfigs.SSL_KEYSTORE_PASSWORD_CONFIG, "password") + verifyUpdate(SslConfigs.SSL_KEY_PASSWORD_CONFIG, "password") + } + + @Test + def testConnectionQuota(): Unit = { + verifyConfigUpdate(KafkaConfig.MaxConnectionsPerIpProp, "100", perBrokerConfig = true, expectFailure = false) + verifyConfigUpdate(KafkaConfig.MaxConnectionsPerIpProp, "100", perBrokerConfig = false, expectFailure = false) + //MaxConnectionsPerIpProp can be set to zero only if MaxConnectionsPerIpOverridesProp property is set + verifyConfigUpdate(KafkaConfig.MaxConnectionsPerIpProp, "0", perBrokerConfig = false, expectFailure = true) + + verifyConfigUpdate(KafkaConfig.MaxConnectionsPerIpOverridesProp, "hostName1:100,hostName2:0", perBrokerConfig = true, + expectFailure = false) + verifyConfigUpdate(KafkaConfig.MaxConnectionsPerIpOverridesProp, 
"hostName1:100,hostName2:0", perBrokerConfig = false, + expectFailure = false) + //test invalid address + verifyConfigUpdate(KafkaConfig.MaxConnectionsPerIpOverridesProp, "hostName#:100", perBrokerConfig = true, + expectFailure = true) + + verifyConfigUpdate(KafkaConfig.MaxConnectionsProp, "100", perBrokerConfig = true, expectFailure = false) + verifyConfigUpdate(KafkaConfig.MaxConnectionsProp, "100", perBrokerConfig = false, expectFailure = false) + val listenerMaxConnectionsProp = s"listener.name.external.${KafkaConfig.MaxConnectionsProp}" + verifyConfigUpdate(listenerMaxConnectionsProp, "10", perBrokerConfig = true, expectFailure = false) + verifyConfigUpdate(listenerMaxConnectionsProp, "10", perBrokerConfig = false, expectFailure = false) + } + + @Test + def testConnectionRateQuota(): Unit = { + verifyConfigUpdate(KafkaConfig.MaxConnectionCreationRateProp, "110", perBrokerConfig = true, expectFailure = false) + verifyConfigUpdate(KafkaConfig.MaxConnectionCreationRateProp, "120", perBrokerConfig = false, expectFailure = false) + val listenerMaxConnectionsProp = s"listener.name.external.${KafkaConfig.MaxConnectionCreationRateProp}" + verifyConfigUpdate(listenerMaxConnectionsProp, "20", perBrokerConfig = true, expectFailure = false) + verifyConfigUpdate(listenerMaxConnectionsProp, "30", perBrokerConfig = false, expectFailure = false) + } + + private def verifyConfigUpdate(name: String, value: Object, perBrokerConfig: Boolean, expectFailure: Boolean): Unit = { + val configProps = TestUtils.createBrokerConfig(0, TestUtils.MockZkConnect, port = 8181) + configProps.put(KafkaConfig.PasswordEncoderSecretProp, "broker.secret") + val config = KafkaConfig(configProps) + config.dynamicConfig.initialize(None) + + val props = new Properties + props.put(name, value) + val oldValue = config.originals.get(name) + + def updateConfig() = { + if (perBrokerConfig) + config.dynamicConfig.updateBrokerConfig(0, config.dynamicConfig.toPersistentProps(props, perBrokerConfig)) + else + config.dynamicConfig.updateDefaultConfig(props) + } + if (!expectFailure) { + config.dynamicConfig.validate(props, perBrokerConfig) + updateConfig() + assertEquals(value, config.originals.get(name)) + } else { + assertThrows(classOf[Exception], () => config.dynamicConfig.validate(props, perBrokerConfig)) + updateConfig() + assertEquals(oldValue, config.originals.get(name)) + } + } + + private def verifyConfigUpdateWithInvalidConfig(config: KafkaConfig, + origProps: Properties, + validProps: Map[String, String], + invalidProps: Map[String, String]): Unit = { + val props = new Properties + validProps.foreach { case (k, v) => props.put(k, v) } + invalidProps.foreach { case (k, v) => props.put(k, v) } + + // DynamicBrokerConfig#validate is used by AdminClient to validate the configs provided in + // in an AlterConfigs request. Validation should fail with an exception if any of the configs are invalid. + assertThrows(classOf[ConfigException], () => config.dynamicConfig.validate(props, perBrokerConfig = true)) + + // DynamicBrokerConfig#updateBrokerConfig is used to update configs from ZooKeeper during + // startup and when configs are updated in ZK. Update should apply valid configs and ignore + // invalid ones. 
+ config.dynamicConfig.updateBrokerConfig(0, props) + validProps.foreach { case (name, value) => assertEquals(value, config.originals.get(name)) } + invalidProps.keySet.foreach { name => + assertEquals(origProps.get(name), config.originals.get(name)) + } + } + + @Test + def testPasswordConfigEncryption(): Unit = { + val props = TestUtils.createBrokerConfig(0, TestUtils.MockZkConnect, port = 8181) + val configWithoutSecret = KafkaConfig(props) + props.put(KafkaConfig.PasswordEncoderSecretProp, "config-encoder-secret") + val configWithSecret = KafkaConfig(props) + val dynamicProps = new Properties + dynamicProps.put(KafkaConfig.SaslJaasConfigProp, "myLoginModule required;") + + try { + configWithoutSecret.dynamicConfig.toPersistentProps(dynamicProps, perBrokerConfig = true) + } catch { + case e: ConfigException => // expected exception + } + val persistedProps = configWithSecret.dynamicConfig.toPersistentProps(dynamicProps, perBrokerConfig = true) + assertFalse(persistedProps.getProperty(KafkaConfig.SaslJaasConfigProp).contains("myLoginModule"), + "Password not encoded") + val decodedProps = configWithSecret.dynamicConfig.fromPersistentProps(persistedProps, perBrokerConfig = true) + assertEquals("myLoginModule required;", decodedProps.getProperty(KafkaConfig.SaslJaasConfigProp)) + } + + @Test + def testPasswordConfigEncoderSecretChange(): Unit = { + val props = TestUtils.createBrokerConfig(0, TestUtils.MockZkConnect, port = 8181) + props.put(KafkaConfig.SaslJaasConfigProp, "staticLoginModule required;") + props.put(KafkaConfig.PasswordEncoderSecretProp, "config-encoder-secret") + val config = KafkaConfig(props) + config.dynamicConfig.initialize(None) + val dynamicProps = new Properties + dynamicProps.put(KafkaConfig.SaslJaasConfigProp, "dynamicLoginModule required;") + + val persistedProps = config.dynamicConfig.toPersistentProps(dynamicProps, perBrokerConfig = true) + assertFalse(persistedProps.getProperty(KafkaConfig.SaslJaasConfigProp).contains("LoginModule"), + "Password not encoded") + config.dynamicConfig.updateBrokerConfig(0, persistedProps) + assertEquals("dynamicLoginModule required;", config.values.get(KafkaConfig.SaslJaasConfigProp).asInstanceOf[Password].value) + + // New config with same secret should use the dynamic password config + val newConfigWithSameSecret = KafkaConfig(props) + newConfigWithSameSecret.dynamicConfig.initialize(None) + newConfigWithSameSecret.dynamicConfig.updateBrokerConfig(0, persistedProps) + assertEquals("dynamicLoginModule required;", newConfigWithSameSecret.values.get(KafkaConfig.SaslJaasConfigProp).asInstanceOf[Password].value) + + // New config with new secret should use the dynamic password config if new and old secrets are configured in KafkaConfig + props.put(KafkaConfig.PasswordEncoderSecretProp, "new-encoder-secret") + props.put(KafkaConfig.PasswordEncoderOldSecretProp, "config-encoder-secret") + val newConfigWithNewAndOldSecret = KafkaConfig(props) + newConfigWithNewAndOldSecret.dynamicConfig.updateBrokerConfig(0, persistedProps) + assertEquals("dynamicLoginModule required;", newConfigWithNewAndOldSecret.values.get(KafkaConfig.SaslJaasConfigProp).asInstanceOf[Password].value) + + // New config with new secret alone should revert to static password config since dynamic config cannot be decoded + props.put(KafkaConfig.PasswordEncoderSecretProp, "another-new-encoder-secret") + val newConfigWithNewSecret = KafkaConfig(props) + newConfigWithNewSecret.dynamicConfig.updateBrokerConfig(0, persistedProps) + assertEquals("staticLoginModule required;",
newConfigWithNewSecret.values.get(KafkaConfig.SaslJaasConfigProp).asInstanceOf[Password].value) + } + + @Test + def testDynamicListenerConfig(): Unit = { + val props = TestUtils.createBrokerConfig(0, TestUtils.MockZkConnect, port = 9092) + val oldConfig = KafkaConfig.fromProps(props) + val kafkaServer: KafkaServer = EasyMock.createMock(classOf[kafka.server.KafkaServer]) + EasyMock.expect(kafkaServer.config).andReturn(oldConfig).anyTimes() + EasyMock.replay(kafkaServer) + + props.put(KafkaConfig.ListenersProp, "PLAINTEXT://hostname:9092,SASL_PLAINTEXT://hostname:9093") + new DynamicListenerConfig(kafkaServer).validateReconfiguration(KafkaConfig(props)) + + // it is illegal to update non-reconfiguable configs of existent listeners + props.put("listener.name.plaintext.you.should.not.pass", "failure") + val dynamicListenerConfig = new DynamicListenerConfig(kafkaServer) + assertThrows(classOf[ConfigException], () => dynamicListenerConfig.validateReconfiguration(KafkaConfig(props))) + } + + @Test + def testAuthorizerConfig(): Unit = { + val props = TestUtils.createBrokerConfig(0, TestUtils.MockZkConnect, port = 9092) + val oldConfig = KafkaConfig.fromProps(props) + oldConfig.dynamicConfig.initialize(None) + + val kafkaServer: KafkaServer = EasyMock.createMock(classOf[kafka.server.KafkaServer]) + + class TestAuthorizer extends Authorizer with Reconfigurable { + @volatile var superUsers = "" + override def start(serverInfo: AuthorizerServerInfo): util.Map[Endpoint, _ <: CompletionStage[Void]] = Map.empty.asJava + override def authorize(requestContext: AuthorizableRequestContext, actions: util.List[Action]): util.List[AuthorizationResult] = null + override def createAcls(requestContext: AuthorizableRequestContext, aclBindings: util.List[AclBinding]): util.List[_ <: CompletionStage[AclCreateResult]] = null + override def deleteAcls(requestContext: AuthorizableRequestContext, aclBindingFilters: util.List[AclBindingFilter]): util.List[_ <: CompletionStage[AclDeleteResult]] = null + override def acls(filter: AclBindingFilter): lang.Iterable[AclBinding] = null + override def close(): Unit = {} + override def configure(configs: util.Map[String, _]): Unit = {} + override def reconfigurableConfigs(): util.Set[String] = Set("super.users").asJava + override def validateReconfiguration(configs: util.Map[String, _]): Unit = {} + override def reconfigure(configs: util.Map[String, _]): Unit = { + superUsers = configs.get("super.users").toString + } + } + + val authorizer = new TestAuthorizer + EasyMock.expect(kafkaServer.config).andReturn(oldConfig).anyTimes() + EasyMock.expect(kafkaServer.authorizer).andReturn(Some(authorizer)).anyTimes() + EasyMock.replay(kafkaServer) + // We are only testing authorizer reconfiguration, ignore any exceptions due to incomplete mock + assertThrows(classOf[Throwable], () => kafkaServer.config.dynamicConfig.addReconfigurables(kafkaServer)) + props.put("super.users", "User:admin") + kafkaServer.config.dynamicConfig.updateBrokerConfig(0, props) + assertEquals("User:admin", authorizer.superUsers) + } + + @Test + def testSynonyms(): Unit = { + assertEquals(List("listener.name.secure.ssl.keystore.type", "ssl.keystore.type"), + DynamicBrokerConfig.brokerConfigSynonyms("listener.name.secure.ssl.keystore.type", matchListenerOverride = true)) + assertEquals(List("listener.name.sasl_ssl.plain.sasl.jaas.config", "sasl.jaas.config"), + DynamicBrokerConfig.brokerConfigSynonyms("listener.name.sasl_ssl.plain.sasl.jaas.config", matchListenerOverride = true)) + assertEquals(List("some.config"), + 
DynamicBrokerConfig.brokerConfigSynonyms("some.config", matchListenerOverride = true)) + assertEquals(List(KafkaConfig.LogRollTimeMillisProp, KafkaConfig.LogRollTimeHoursProp), + DynamicBrokerConfig.brokerConfigSynonyms(KafkaConfig.LogRollTimeMillisProp, matchListenerOverride = true)) + } + + @Test + def testDynamicConfigInitializationWithoutConfigsInZK(): Unit = { + val zkClient: KafkaZkClient = EasyMock.createMock(classOf[KafkaZkClient]) + EasyMock.expect(zkClient.getEntityConfigs(EasyMock.anyString(), EasyMock.anyString())).andReturn(new java.util.Properties()).anyTimes() + EasyMock.replay(zkClient) + + val oldConfig = KafkaConfig.fromProps(TestUtils.createBrokerConfig(0, TestUtils.MockZkConnect, port = 9092)) + val dynamicBrokerConfig = new DynamicBrokerConfig(oldConfig) + dynamicBrokerConfig.initialize(Some(zkClient)) + dynamicBrokerConfig.addBrokerReconfigurable(new TestDynamicThreadPool) + + val newprops = new Properties() + newprops.put(KafkaConfig.NumIoThreadsProp, "10") + newprops.put(KafkaConfig.BackgroundThreadsProp, "100") + dynamicBrokerConfig.updateBrokerConfig(0, newprops) + } + + @Test + def testImproperConfigsAreRemoved(): Unit = { + val props = TestUtils.createBrokerConfig(0, TestUtils.MockZkConnect) + val config = KafkaConfig(props) + config.dynamicConfig.initialize(None) + + assertEquals(Defaults.MaxConnections, config.maxConnections) + assertEquals(Defaults.MessageMaxBytes, config.messageMaxBytes) + + var newProps = new Properties() + newProps.put(KafkaConfig.MaxConnectionsProp, "9999") + newProps.put(KafkaConfig.MessageMaxBytesProp, "2222") + + config.dynamicConfig.updateDefaultConfig(newProps) + assertEquals(9999, config.maxConnections) + assertEquals(2222, config.messageMaxBytes) + + newProps = new Properties() + newProps.put(KafkaConfig.MaxConnectionsProp, "INVALID_INT") + newProps.put(KafkaConfig.MessageMaxBytesProp, "1111") + + config.dynamicConfig.updateDefaultConfig(newProps) + // Invalid value should be skipped and reassigned as default value + assertEquals(Defaults.MaxConnections, config.maxConnections) + // Even if One property is invalid, the below should get correctly updated. + assertEquals(1111, config.messageMaxBytes) + } +} + +class TestDynamicThreadPool() extends BrokerReconfigurable { + + override def reconfigurableConfigs: Set[String] = { + DynamicThreadPool.ReconfigurableConfigs + } + + override def reconfigure(oldConfig: KafkaConfig, newConfig: KafkaConfig): Unit = { + assertEquals(Defaults.NumIoThreads, oldConfig.numIoThreads) + assertEquals(Defaults.BackgroundThreads, oldConfig.backgroundThreads) + + assertEquals(10, newConfig.numIoThreads) + assertEquals(100, newConfig.backgroundThreads) + } + + override def validateReconfiguration(newConfig: KafkaConfig): Unit = { + assertEquals(10, newConfig.numIoThreads) + assertEquals(100, newConfig.backgroundThreads) + } +} diff --git a/core/src/test/scala/unit/kafka/server/DynamicConfigChangeTest.scala b/core/src/test/scala/unit/kafka/server/DynamicConfigChangeTest.scala new file mode 100644 index 0000000..4b39824 --- /dev/null +++ b/core/src/test/scala/unit/kafka/server/DynamicConfigChangeTest.scala @@ -0,0 +1,418 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package kafka.server + +import kafka.api.KAFKA_3_0_IV1 + +import java.net.InetAddress +import java.nio.charset.StandardCharsets +import java.util.Properties +import java.util.concurrent.ExecutionException +import kafka.integration.KafkaServerTestHarness +import kafka.log.LogConfig._ +import kafka.utils._ +import kafka.server.Constants._ +import kafka.zk.ConfigEntityChangeNotificationZNode +import org.apache.kafka.clients.CommonClientConfigs +import org.apache.kafka.clients.admin.{Admin, AlterConfigOp, ConfigEntry} +import org.apache.kafka.common.TopicPartition +import org.apache.kafka.common.config.ConfigResource +import org.apache.kafka.common.config.internals.QuotaConfigs +import org.apache.kafka.common.errors.UnknownTopicOrPartitionException +import org.apache.kafka.common.metrics.Quota +import org.apache.kafka.common.record.{CompressionType, RecordVersion} +import org.easymock.EasyMock +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.Test + +import scala.annotation.nowarn +import scala.collection.{Map, Seq} +import scala.jdk.CollectionConverters._ + +class DynamicConfigChangeTest extends KafkaServerTestHarness { + def generateConfigs = List(KafkaConfig.fromProps(TestUtils.createBrokerConfig(0, zkConnect))) + + @Test + def testConfigChange(): Unit = { + assertTrue(this.servers.head.dynamicConfigHandlers.contains(ConfigType.Topic), + "Should contain a ConfigHandler for topics") + val oldVal: java.lang.Long = 100000L + val newVal: java.lang.Long = 200000L + val tp = new TopicPartition("test", 0) + val logProps = new Properties() + logProps.put(FlushMessagesProp, oldVal.toString) + createTopic(tp.topic, 1, 1, logProps) + TestUtils.retry(10000) { + val logOpt = this.servers.head.logManager.getLog(tp) + assertTrue(logOpt.isDefined) + assertEquals(oldVal, logOpt.get.config.flushInterval) + } + logProps.put(FlushMessagesProp, newVal.toString) + adminZkClient.changeTopicConfig(tp.topic, logProps) + TestUtils.retry(10000) { + assertEquals(newVal, this.servers.head.logManager.getLog(tp).get.config.flushInterval) + } + } + + @Test + def testDynamicTopicConfigChange(): Unit = { + val tp = new TopicPartition("test", 0) + val oldSegmentSize = 1000 + val logProps = new Properties() + logProps.put(SegmentBytesProp, oldSegmentSize.toString) + createTopic(tp.topic, 1, 1, logProps) + TestUtils.retry(10000) { + val logOpt = this.servers.head.logManager.getLog(tp) + assertTrue(logOpt.isDefined) + assertEquals(oldSegmentSize, logOpt.get.config.segmentSize) + } + + val log = servers.head.logManager.getLog(tp).get + + val newSegmentSize = 2000 + logProps.put(SegmentBytesProp, newSegmentSize.toString) + adminZkClient.changeTopicConfig(tp.topic, logProps) + TestUtils.retry(10000) { + assertEquals(newSegmentSize, log.config.segmentSize) + } + + (1 to 50).foreach(i => TestUtils.produceMessage(servers, tp.topic, i.toString)) + // Verify that the new config is used for all segments + assertTrue(log.logSegments.forall(_.size > 1000), "Log segment size change not applied") + } + + @nowarn("cat=deprecation") + @Test + def testMessageFormatVersionChange(): Unit = { + val tp = 
new TopicPartition("test", 0) + val logProps = new Properties() + logProps.put(MessageFormatVersionProp, "0.10.2") + createTopic(tp.topic, 1, 1, logProps) + val server = servers.head + TestUtils.waitUntilTrue(() => server.logManager.getLog(tp).isDefined, + "Topic metadata propagation failed") + val log = server.logManager.getLog(tp).get + // message format version should always be 3.0 if inter-broker protocol is 3.0 or higher + assertEquals(KAFKA_3_0_IV1, log.config.messageFormatVersion) + assertEquals(RecordVersion.V2, log.config.recordVersion) + + val compressionType = CompressionType.LZ4.name + logProps.put(MessageFormatVersionProp, "0.11.0") + // set compression type so that we can detect when the config change has propagated + logProps.put(CompressionTypeProp, compressionType) + adminZkClient.changeTopicConfig(tp.topic, logProps) + TestUtils.waitUntilTrue(() => + server.logManager.getLog(tp).get.config.compressionType == compressionType, + "Topic config change propagation failed") + assertEquals(KAFKA_3_0_IV1, log.config.messageFormatVersion) + assertEquals(RecordVersion.V2, log.config.recordVersion) + } + + private def testQuotaConfigChange(user: String, clientId: String, rootEntityType: String, configEntityName: String): Unit = { + assertTrue(this.servers.head.dynamicConfigHandlers.contains(rootEntityType), "Should contain a ConfigHandler for " + rootEntityType) + val props = new Properties() + props.put(QuotaConfigs.PRODUCER_BYTE_RATE_OVERRIDE_CONFIG, "1000") + props.put(QuotaConfigs.CONSUMER_BYTE_RATE_OVERRIDE_CONFIG, "2000") + + val quotaManagers = servers.head.dataPlaneRequestProcessor.quotas + rootEntityType match { + case ConfigType.Client => adminZkClient.changeClientIdConfig(configEntityName, props) + case _ => adminZkClient.changeUserOrUserClientIdConfig(configEntityName, props) + } + + TestUtils.retry(10000) { + val overrideProducerQuota = quotaManagers.produce.quota(user, clientId) + val overrideConsumerQuota = quotaManagers.fetch.quota(user, clientId) + + assertEquals(Quota.upperBound(1000), + overrideProducerQuota, s"User $user clientId $clientId must have overridden producer quota of 1000") + assertEquals(Quota.upperBound(2000), + overrideConsumerQuota, s"User $user clientId $clientId must have overridden consumer quota of 2000") + } + + val defaultProducerQuota = Long.MaxValue.asInstanceOf[Double] + val defaultConsumerQuota = Long.MaxValue.asInstanceOf[Double] + + val emptyProps = new Properties() + rootEntityType match { + case ConfigType.Client => adminZkClient.changeClientIdConfig(configEntityName, emptyProps) + case _ => adminZkClient.changeUserOrUserClientIdConfig(configEntityName, emptyProps) + } + TestUtils.retry(10000) { + val producerQuota = quotaManagers.produce.quota(user, clientId) + val consumerQuota = quotaManagers.fetch.quota(user, clientId) + + assertEquals(Quota.upperBound(defaultProducerQuota), + producerQuota, s"User $user clientId $clientId must have reset producer quota to " + defaultProducerQuota) + assertEquals(Quota.upperBound(defaultConsumerQuota), + consumerQuota, s"User $user clientId $clientId must have reset consumer quota to " + defaultConsumerQuota) + } + } + + @Test + def testClientIdQuotaConfigChange(): Unit = { + testQuotaConfigChange("ANONYMOUS", "testClient", ConfigType.Client, "testClient") + } + + @Test + def testUserQuotaConfigChange(): Unit = { + testQuotaConfigChange("ANONYMOUS", "testClient", ConfigType.User, "ANONYMOUS") + } + + @Test + def testUserClientIdQuotaChange(): Unit = { + testQuotaConfigChange("ANONYMOUS", 
"testClient", ConfigType.User, "ANONYMOUS/clients/testClient") + } + + @Test + def testDefaultClientIdQuotaConfigChange(): Unit = { + testQuotaConfigChange("ANONYMOUS", "testClient", ConfigType.Client, "") + } + + @Test + def testDefaultUserQuotaConfigChange(): Unit = { + testQuotaConfigChange("ANONYMOUS", "testClient", ConfigType.User, "") + } + + @Test + def testDefaultUserClientIdQuotaConfigChange(): Unit = { + testQuotaConfigChange("ANONYMOUS", "testClient", ConfigType.User, "/clients/") + } + + @Test + def testQuotaInitialization(): Unit = { + val server = servers.head + val clientIdProps = new Properties() + server.shutdown() + clientIdProps.put(QuotaConfigs.PRODUCER_BYTE_RATE_OVERRIDE_CONFIG, "1000") + clientIdProps.put(QuotaConfigs.CONSUMER_BYTE_RATE_OVERRIDE_CONFIG, "2000") + val userProps = new Properties() + userProps.put(QuotaConfigs.PRODUCER_BYTE_RATE_OVERRIDE_CONFIG, "10000") + userProps.put(QuotaConfigs.CONSUMER_BYTE_RATE_OVERRIDE_CONFIG, "20000") + val userClientIdProps = new Properties() + userClientIdProps.put(QuotaConfigs.PRODUCER_BYTE_RATE_OVERRIDE_CONFIG, "100000") + userClientIdProps.put(QuotaConfigs.CONSUMER_BYTE_RATE_OVERRIDE_CONFIG, "200000") + + adminZkClient.changeClientIdConfig("overriddenClientId", clientIdProps) + adminZkClient.changeUserOrUserClientIdConfig("overriddenUser", userProps) + adminZkClient.changeUserOrUserClientIdConfig("ANONYMOUS/clients/overriddenUserClientId", userClientIdProps) + + // Remove config change znodes to force quota initialization only through loading of user/client quotas + zkClient.getChildren(ConfigEntityChangeNotificationZNode.path).foreach { p => zkClient.deletePath(ConfigEntityChangeNotificationZNode.path + "/" + p) } + server.startup() + val quotaManagers = server.dataPlaneRequestProcessor.quotas + + assertEquals(Quota.upperBound(1000), quotaManagers.produce.quota("someuser", "overriddenClientId")) + assertEquals(Quota.upperBound(2000), quotaManagers.fetch.quota("someuser", "overriddenClientId")) + assertEquals(Quota.upperBound(10000), quotaManagers.produce.quota("overriddenUser", "someclientId")) + assertEquals(Quota.upperBound(20000), quotaManagers.fetch.quota("overriddenUser", "someclientId")) + assertEquals(Quota.upperBound(100000), quotaManagers.produce.quota("ANONYMOUS", "overriddenUserClientId")) + assertEquals(Quota.upperBound(200000), quotaManagers.fetch.quota("ANONYMOUS", "overriddenUserClientId")) + } + + @Test + def testIpHandlerUnresolvableAddress(): Unit = { + val configHandler = new IpConfigHandler(null) + val props: Properties = new Properties() + props.put(QuotaConfigs.IP_CONNECTION_RATE_OVERRIDE_CONFIG, "1") + + assertThrows(classOf[IllegalArgumentException], () => configHandler.processConfigChanges("illegal-hostname", props)) + } + + @Test + def testIpQuotaInitialization(): Unit = { + val server = servers.head + val ipOverrideProps = new Properties() + ipOverrideProps.put(QuotaConfigs.IP_CONNECTION_RATE_OVERRIDE_CONFIG, "10") + val ipDefaultProps = new Properties() + ipDefaultProps.put(QuotaConfigs.IP_CONNECTION_RATE_OVERRIDE_CONFIG, "20") + server.shutdown() + + adminZkClient.changeIpConfig(ConfigEntityName.Default, ipDefaultProps) + adminZkClient.changeIpConfig("1.2.3.4", ipOverrideProps) + + // Remove config change znodes to force quota initialization only through loading of ip quotas + zkClient.getChildren(ConfigEntityChangeNotificationZNode.path).foreach { p => + zkClient.deletePath(ConfigEntityChangeNotificationZNode.path + "/" + p) + } + server.startup() + + val connectionQuotas = 
server.socketServer.connectionQuotas + assertEquals(10L, connectionQuotas.connectionRateForIp(InetAddress.getByName("1.2.3.4"))) + assertEquals(20L, connectionQuotas.connectionRateForIp(InetAddress.getByName("2.4.6.8"))) + } + + @Test + def testIpQuotaConfigChange(): Unit = { + val ipOverrideProps = new Properties() + ipOverrideProps.put(QuotaConfigs.IP_CONNECTION_RATE_OVERRIDE_CONFIG, "10") + val ipDefaultProps = new Properties() + ipDefaultProps.put(QuotaConfigs.IP_CONNECTION_RATE_OVERRIDE_CONFIG, "20") + + val overrideQuotaIp = InetAddress.getByName("1.2.3.4") + val defaultQuotaIp = InetAddress.getByName("2.3.4.5") + adminZkClient.changeIpConfig(ConfigEntityName.Default, ipDefaultProps) + adminZkClient.changeIpConfig(overrideQuotaIp.getHostAddress, ipOverrideProps) + + val connectionQuotas = servers.head.socketServer.connectionQuotas + + def verifyConnectionQuota(ip: InetAddress, expectedQuota: Integer) = { + TestUtils.retry(10000) { + val quota = connectionQuotas.connectionRateForIp(ip) + assertEquals(expectedQuota, quota, s"Unexpected quota for IP $ip") + } + } + + verifyConnectionQuota(overrideQuotaIp, 10) + verifyConnectionQuota(defaultQuotaIp, 20) + + val emptyProps = new Properties() + adminZkClient.changeIpConfig(overrideQuotaIp.getHostAddress, emptyProps) + verifyConnectionQuota(overrideQuotaIp, 20) + + adminZkClient.changeIpConfig(ConfigEntityName.Default, emptyProps) + verifyConnectionQuota(overrideQuotaIp, QuotaConfigs.IP_CONNECTION_RATE_DEFAULT) + } + + @Test + def testConfigChangeOnNonExistingTopic(): Unit = { + val topic = TestUtils.tempTopic() + val logProps = new Properties() + logProps.put(FlushMessagesProp, 10000: java.lang.Integer) + assertThrows(classOf[UnknownTopicOrPartitionException], () => adminZkClient.changeTopicConfig(topic, logProps)) + } + + @Test + def testConfigChangeOnNonExistingTopicWithAdminClient(): Unit = { + val topic = TestUtils.tempTopic() + val admin = createAdminClient() + try { + val resource = new ConfigResource(ConfigResource.Type.TOPIC, topic) + val op = new AlterConfigOp(new ConfigEntry(FlushMessagesProp, "10000"), AlterConfigOp.OpType.SET) + admin.incrementalAlterConfigs(Map(resource -> List(op).asJavaCollection).asJava).all.get + fail("Should fail with UnknownTopicOrPartitionException for topic doesn't exist") + } catch { + case e: ExecutionException => + assertTrue(e.getCause.isInstanceOf[UnknownTopicOrPartitionException]) + } finally { + admin.close() + } + } + + @Test + def testProcessNotification(): Unit = { + val props = new Properties() + props.put("a.b", "10") + + // Create a mock ConfigHandler to record config changes it is asked to process + val entityArgument = EasyMock.newCapture[String] + val propertiesArgument = EasyMock.newCapture[Properties] + val handler: ConfigHandler = EasyMock.createNiceMock(classOf[ConfigHandler]) + handler.processConfigChanges( + EasyMock.and(EasyMock.capture(entityArgument), EasyMock.isA(classOf[String])), + EasyMock.and(EasyMock.capture(propertiesArgument), EasyMock.isA(classOf[Properties]))) + EasyMock.expectLastCall().once() + EasyMock.replay(handler) + + val configManager = new DynamicConfigManager(zkClient, Map(ConfigType.Topic -> handler)) + // Notifications created using the old TopicConfigManager are ignored. + configManager.ConfigChangedNotificationHandler.processNotification("not json".getBytes(StandardCharsets.UTF_8)) + + // Incorrect Map. 
No version + var jsonMap: Map[String, Any] = Map("v" -> 1, "x" -> 2) + + assertThrows(classOf[Throwable], () => configManager.ConfigChangedNotificationHandler.processNotification(Json.encodeAsBytes(jsonMap.asJava))) + // Version is provided. EntityType is incorrect + jsonMap = Map("version" -> 1, "entity_type" -> "garbage", "entity_name" -> "x") + assertThrows(classOf[Throwable], () => configManager.ConfigChangedNotificationHandler.processNotification(Json.encodeAsBytes(jsonMap.asJava))) + + // EntityName isn't provided + jsonMap = Map("version" -> 1, "entity_type" -> ConfigType.Topic) + assertThrows(classOf[Throwable], () => configManager.ConfigChangedNotificationHandler.processNotification(Json.encodeAsBytes(jsonMap.asJava))) + + // Everything is provided + jsonMap = Map("version" -> 1, "entity_type" -> ConfigType.Topic, "entity_name" -> "x") + configManager.ConfigChangedNotificationHandler.processNotification(Json.encodeAsBytes(jsonMap.asJava)) + + // Verify that processConfigChanges was only called once + EasyMock.verify(handler) + } + + @Test + def shouldParseReplicationQuotaProperties(): Unit = { + val configHandler: TopicConfigHandler = new TopicConfigHandler(null, null, null, null) + val props: Properties = new Properties() + + //Given + props.put(LeaderReplicationThrottledReplicasProp, "0:101,0:102,1:101,1:102") + + //When/Then + assertEquals(Seq(0,1), configHandler.parseThrottledPartitions(props, 102, LeaderReplicationThrottledReplicasProp)) + assertEquals(Seq(), configHandler.parseThrottledPartitions(props, 103, LeaderReplicationThrottledReplicasProp)) + } + + @Test + def shouldParseWildcardReplicationQuotaProperties(): Unit = { + val configHandler: TopicConfigHandler = new TopicConfigHandler(null, null, null, null) + val props: Properties = new Properties() + + //Given + props.put(LeaderReplicationThrottledReplicasProp, "*") + + //When + val result = configHandler.parseThrottledPartitions(props, 102, LeaderReplicationThrottledReplicasProp) + + //Then + assertEquals(AllReplicas, result) + } + + @Test + def shouldParseReplicationQuotaReset(): Unit = { + val configHandler: TopicConfigHandler = new TopicConfigHandler(null, null, null, null) + val props: Properties = new Properties() + + //Given + props.put(FollowerReplicationThrottledReplicasProp, "") + + //When + val result = configHandler.parseThrottledPartitions(props, 102, FollowerReplicationThrottledReplicasProp) + + //Then + assertEquals(Seq(), result) + } + + @Test + def shouldParseRegardlessOfWhitespaceAroundValues(): Unit = { + val configHandler: TopicConfigHandler = new TopicConfigHandler(null, null, null, null) + assertEquals(AllReplicas, parse(configHandler, "* ")) + assertEquals(Seq(), parse(configHandler, " ")) + assertEquals(Seq(6), parse(configHandler, "6:102")) + assertEquals(Seq(6), parse(configHandler, "6:102 ")) + assertEquals(Seq(6), parse(configHandler, " 6:102")) + } + + def parse(configHandler: TopicConfigHandler, value: String): Seq[Int] = { + configHandler.parseThrottledPartitions(CoreUtils.propsWith(LeaderReplicationThrottledReplicasProp, value), 102, LeaderReplicationThrottledReplicasProp) + } + + private def createAdminClient(): Admin = { + val props = new Properties() + props.put(CommonClientConfigs.BOOTSTRAP_SERVERS_CONFIG, brokerList) + Admin.create(props) + } + +} diff --git a/core/src/test/scala/unit/kafka/server/DynamicConfigTest.scala b/core/src/test/scala/unit/kafka/server/DynamicConfigTest.scala new file mode 100644 index 0000000..1fb9f33 --- /dev/null +++ 
b/core/src/test/scala/unit/kafka/server/DynamicConfigTest.scala @@ -0,0 +1,72 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package kafka.server + +import kafka.admin.AdminOperationException +import kafka.utils.CoreUtils._ +import kafka.server.QuorumTestHarness +import org.apache.kafka.common.config._ +import org.apache.kafka.common.config.internals.QuotaConfigs +import org.junit.jupiter.api.Assertions.assertThrows +import org.junit.jupiter.api.Test + +class DynamicConfigTest extends QuorumTestHarness { + private final val nonExistentConfig: String = "some.config.that.does.not.exist" + private final val someValue: String = "some interesting value" + + @Test + def shouldFailWhenChangingClientIdUnknownConfig(): Unit = { + assertThrows(classOf[IllegalArgumentException], () => adminZkClient.changeClientIdConfig("ClientId", + propsWith(nonExistentConfig, someValue))) + } + + @Test + def shouldFailWhenChangingUserUnknownConfig(): Unit = { + assertThrows(classOf[IllegalArgumentException], () => adminZkClient.changeUserOrUserClientIdConfig("UserId", + propsWith(nonExistentConfig, someValue))) + } + + @Test + def shouldFailLeaderConfigsWithInvalidValues(): Unit = { + assertThrows(classOf[ConfigException], () => adminZkClient.changeBrokerConfig(Seq(0), + propsWith(DynamicConfig.Broker.LeaderReplicationThrottledRateProp, "-100"))) + } + + @Test + def shouldFailFollowerConfigsWithInvalidValues(): Unit = { + assertThrows(classOf[ConfigException], () => adminZkClient.changeBrokerConfig(Seq(0), + propsWith(DynamicConfig.Broker.FollowerReplicationThrottledRateProp, "-100"))) + } + + @Test + def shouldFailIpConfigsWithInvalidValues(): Unit = { + assertThrows(classOf[ConfigException], () => adminZkClient.changeIpConfig("1.2.3.4", + propsWith(QuotaConfigs.IP_CONNECTION_RATE_OVERRIDE_CONFIG, "-1"))) + } + + @Test + def shouldFailIpConfigsWithInvalidIpv4Entity(): Unit = { + assertThrows(classOf[AdminOperationException], () => adminZkClient.changeIpConfig("1,1.1.1", + propsWith(QuotaConfigs.IP_CONNECTION_RATE_OVERRIDE_CONFIG, "2"))) + } + + @Test + def shouldFailIpConfigsWithBadHost(): Unit = { + assertThrows(classOf[AdminOperationException], () => adminZkClient.changeIpConfig("ip", + propsWith(QuotaConfigs.IP_CONNECTION_RATE_OVERRIDE_CONFIG, "2"))) + } +} diff --git a/core/src/test/scala/unit/kafka/server/EdgeCaseRequestTest.scala b/core/src/test/scala/unit/kafka/server/EdgeCaseRequestTest.scala new file mode 100755 index 0000000..4d31e8e --- /dev/null +++ b/core/src/test/scala/unit/kafka/server/EdgeCaseRequestTest.scala @@ -0,0 +1,194 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.server + +import java.io.{DataInputStream, DataOutputStream} +import java.net.Socket +import java.nio.ByteBuffer +import java.util.Collections + +import kafka.integration.KafkaServerTestHarness +import kafka.network.SocketServer +import kafka.utils._ +import org.apache.kafka.common.message.ProduceRequestData +import org.apache.kafka.common.network.ListenerName +import org.apache.kafka.common.protocol.types.Type +import org.apache.kafka.common.protocol.{ApiKeys, Errors} +import org.apache.kafka.common.record.{CompressionType, MemoryRecords, SimpleRecord} +import org.apache.kafka.common.requests.{ProduceResponse, ResponseHeader} +import org.apache.kafka.common.security.auth.SecurityProtocol +import org.apache.kafka.common.utils.ByteUtils +import org.apache.kafka.common.{TopicPartition, requests} +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.Test + +import scala.jdk.CollectionConverters._ + +class EdgeCaseRequestTest extends KafkaServerTestHarness { + + def generateConfigs = { + val props = TestUtils.createBrokerConfig(1, zkConnect) + props.setProperty(KafkaConfig.AutoCreateTopicsEnableProp, "false") + List(KafkaConfig.fromProps(props)) + } + + private def socketServer = servers.head.socketServer + + private def connect(s: SocketServer = socketServer, protocol: SecurityProtocol = SecurityProtocol.PLAINTEXT): Socket = { + new Socket("localhost", s.boundPort(ListenerName.forSecurityProtocol(protocol))) + } + + private def sendRequest(socket: Socket, request: Array[Byte], id: Option[Short] = None): Unit = { + val outgoing = new DataOutputStream(socket.getOutputStream) + id match { + case Some(id) => + outgoing.writeInt(request.length + 2) + outgoing.writeShort(id) + case None => + outgoing.writeInt(request.length) + } + outgoing.write(request) + outgoing.flush() + } + + private def receiveResponse(socket: Socket): Array[Byte] = { + val incoming = new DataInputStream(socket.getInputStream) + val len = incoming.readInt() + val response = new Array[Byte](len) + incoming.readFully(response) + response + } + + private def requestAndReceive(request: Array[Byte], id: Option[Short] = None): Array[Byte] = { + val plainSocket = connect() + try { + sendRequest(plainSocket, request, id) + receiveResponse(plainSocket) + } finally { + plainSocket.close() + } + } + + // Custom header serialization so that protocol assumptions are not forced + def requestHeaderBytes(apiKey: Short, apiVersion: Short, clientId: String = "", correlationId: Int = -1): Array[Byte] = { + // Check for flex versions, some tests here verify that an invalid apiKey is detected properly, so if -1 is used, + // assume the request is not using flex versions. 
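+ // The bytes written below follow the request header layout: api key (int16), api version (int16), + // correlation id (int32), a nullable client-id string and, for flexible header versions (v2+), an empty + // tagged-field section encoded as an unsigned varint.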
+ val flexVersion = if (apiKey >= 0) ApiKeys.forId(apiKey).requestHeaderVersion(apiVersion) >= 2 else false + val size = { + 2 /* apiKey */ + + 2 /* version id */ + + 4 /* correlation id */ + + Type.NULLABLE_STRING.sizeOf(clientId) /* client id */ + + (if (flexVersion) ByteUtils.sizeOfUnsignedVarint(0) else 0) /* Empty tagged fields for flexible versions */ + } + + val buffer = ByteBuffer.allocate(size) + buffer.putShort(apiKey) + buffer.putShort(apiVersion) + buffer.putInt(correlationId) + Type.NULLABLE_STRING.write(buffer, clientId) + if (flexVersion) ByteUtils.writeUnsignedVarint(0, buffer) + buffer.array() + } + + private def verifyDisconnect(request: Array[Byte]): Unit = { + val plainSocket = connect() + try { + sendRequest(plainSocket, request) + assertEquals(-1, plainSocket.getInputStream.read(), "The server should disconnect") + } finally { + plainSocket.close() + } + } + + @Test + def testProduceRequestWithNullClientId(): Unit = { + val topic = "topic" + val topicPartition = new TopicPartition(topic, 0) + val correlationId = -1 + createTopic(topic, numPartitions = 1, replicationFactor = 1) + + val version = ApiKeys.PRODUCE.latestVersion: Short + val (serializedBytes, responseHeaderVersion) = { + val headerBytes = requestHeaderBytes(ApiKeys.PRODUCE.id, version, "", correlationId) + val request = requests.ProduceRequest.forCurrentMagic(new ProduceRequestData() + .setTopicData(new ProduceRequestData.TopicProduceDataCollection( + Collections.singletonList(new ProduceRequestData.TopicProduceData() + .setName(topicPartition.topic()).setPartitionData(Collections.singletonList( + new ProduceRequestData.PartitionProduceData() + .setIndex(topicPartition.partition()) + .setRecords(MemoryRecords.withRecords(CompressionType.NONE, new SimpleRecord("message".getBytes)))))) + .iterator)) + .setAcks(1.toShort) + .setTimeoutMs(10000) + .setTransactionalId(null)) + .build() + val bodyBytes = request.serialize + val byteBuffer = ByteBuffer.allocate(headerBytes.length + bodyBytes.remaining()) + byteBuffer.put(headerBytes) + byteBuffer.put(bodyBytes) + (byteBuffer.array(), request.apiKey.responseHeaderVersion(version)) + } + + val response = requestAndReceive(serializedBytes) + + val responseBuffer = ByteBuffer.wrap(response) + val responseHeader = ResponseHeader.parse(responseBuffer, responseHeaderVersion) + val produceResponse = ProduceResponse.parse(responseBuffer, version) + + assertEquals(0, responseBuffer.remaining, "The response should parse completely") + assertEquals(correlationId, responseHeader.correlationId, "The correlationId should match request") + assertEquals(1, produceResponse.data.responses.size, "One topic response should be returned") + val topicProduceResponse = produceResponse.data.responses.asScala.head + assertEquals(1, topicProduceResponse.partitionResponses.size, "One partition response should be returned") + val partitionProduceResponse = topicProduceResponse.partitionResponses.asScala.head + assertNotNull(partitionProduceResponse) + assertEquals(Errors.NONE, Errors.forCode(partitionProduceResponse.errorCode), "There should be no error") + } + + @Test + def testHeaderOnlyRequest(): Unit = { + verifyDisconnect(requestHeaderBytes(ApiKeys.PRODUCE.id, 1)) + } + + @Test + def testInvalidApiKeyRequest(): Unit = { + verifyDisconnect(requestHeaderBytes(-1, 0)) + } + + @Test + def testInvalidApiVersionRequest(): Unit = { + verifyDisconnect(requestHeaderBytes(ApiKeys.PRODUCE.id, -1)) + } + + @Test + def testMalformedHeaderRequest(): Unit = { + val serializedBytes = {
+ // Only send apiKey and apiVersion + val buffer = ByteBuffer.allocate( + 2 /* apiKey */ + + 2 /* apiVersion */ + ) + buffer.putShort(ApiKeys.PRODUCE.id) + buffer.putShort(1) + buffer.array() + } + + verifyDisconnect(serializedBytes) + } +} diff --git a/core/src/test/scala/unit/kafka/server/FetchRequestDownConversionConfigTest.scala b/core/src/test/scala/unit/kafka/server/FetchRequestDownConversionConfigTest.scala new file mode 100644 index 0000000..07a20f9 --- /dev/null +++ b/core/src/test/scala/unit/kafka/server/FetchRequestDownConversionConfigTest.scala @@ -0,0 +1,195 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package kafka.server + +import java.util +import java.util.{Optional, Properties} +import kafka.log.LogConfig +import kafka.utils.TestUtils +import org.apache.kafka.clients.producer.{KafkaProducer, ProducerRecord} +import org.apache.kafka.common.{TopicPartition, Uuid} +import org.apache.kafka.common.protocol.{ApiKeys, Errors} +import org.apache.kafka.common.requests.{FetchRequest, FetchResponse} +import org.apache.kafka.common.serialization.StringSerializer +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.{AfterEach, BeforeEach, Test, TestInfo} + +import scala.jdk.CollectionConverters._ + +class FetchRequestDownConversionConfigTest extends BaseRequestTest { + private var producer: KafkaProducer[String, String] = null + override def brokerCount: Int = 1 + + @BeforeEach + override def setUp(testInfo: TestInfo): Unit = { + super.setUp(testInfo) + initProducer() + } + + @AfterEach + override def tearDown(): Unit = { + if (producer != null) + producer.close() + super.tearDown() + } + + override protected def brokerPropertyOverrides(properties: Properties): Unit = { + super.brokerPropertyOverrides(properties) + properties.put(KafkaConfig.LogMessageDownConversionEnableProp, "false") + } + + private def initProducer(): Unit = { + producer = TestUtils.createProducer(TestUtils.getBrokerListStrFromServers(servers), + keySerializer = new StringSerializer, valueSerializer = new StringSerializer) + } + + private def createTopics(numTopics: Int, numPartitions: Int, + configs: Map[String, String] = Map.empty, topicSuffixStart: Int = 0): Map[TopicPartition, Int] = { + val topics = (0 until numTopics).map(t => s"topic${t + topicSuffixStart}") + val topicConfig = new Properties + topicConfig.setProperty(LogConfig.MinInSyncReplicasProp, 1.toString) + configs.foreach { case (k, v) => topicConfig.setProperty(k, v) } + topics.flatMap { topic => + val partitionToLeader = createTopic(topic, numPartitions = numPartitions, replicationFactor = 1, + topicConfig = topicConfig) + partitionToLeader.map { case (partition, leader) => new TopicPartition(topic, partition) -> leader } + }.toMap + } + + private def 
createPartitionMap(maxPartitionBytes: Int, topicPartitions: Seq[TopicPartition], + topicIds: Map[String, Uuid], + offsetMap: Map[TopicPartition, Long] = Map.empty): util.LinkedHashMap[TopicPartition, FetchRequest.PartitionData] = { + val partitionMap = new util.LinkedHashMap[TopicPartition, FetchRequest.PartitionData] + topicPartitions.foreach { tp => + partitionMap.put(tp, new FetchRequest.PartitionData(topicIds.getOrElse(tp.topic, Uuid.ZERO_UUID), offsetMap.getOrElse(tp, 0), 0L, + maxPartitionBytes, Optional.empty())) + } + partitionMap + } + + private def sendFetchRequest(leaderId: Int, request: FetchRequest): FetchResponse = { + connectAndReceive[FetchResponse](request, destination = brokerSocketServer(leaderId)) + } + + /** + * Tests that fetch request that require down-conversion returns with an error response when down-conversion is disabled on broker. + */ + @Test + def testV1FetchWithDownConversionDisabled(): Unit = { + val topicMap = createTopics(numTopics = 5, numPartitions = 1) + val topicPartitions = topicMap.keySet.toSeq + val topicIds = servers.head.kafkaController.controllerContext.topicIds + val topicNames = topicIds.map(_.swap) + topicPartitions.foreach(tp => producer.send(new ProducerRecord(tp.topic(), "key", "value")).get()) + val fetchRequest = FetchRequest.Builder.forConsumer(1, Int.MaxValue, 0, createPartitionMap(1024, + topicPartitions, topicIds.toMap)).build(1) + val fetchResponse = sendFetchRequest(topicMap.head._2, fetchRequest) + val fetchResponseData = fetchResponse.responseData(topicNames.asJava, 1) + topicPartitions.foreach(tp => assertEquals(Errors.UNSUPPORTED_VERSION, Errors.forCode(fetchResponseData.get(tp).errorCode))) + } + + /** + * Tests that "message.downconversion.enable" has no effect when down-conversion is not required. + */ + @Test + def testLatestFetchWithDownConversionDisabled(): Unit = { + val topicMap = createTopics(numTopics = 5, numPartitions = 1) + val topicPartitions = topicMap.keySet.toSeq + val topicIds = servers.head.kafkaController.controllerContext.topicIds + val topicNames = topicIds.map(_.swap) + topicPartitions.foreach(tp => producer.send(new ProducerRecord(tp.topic(), "key", "value")).get()) + val fetchRequest = FetchRequest.Builder.forConsumer(ApiKeys.FETCH.latestVersion, Int.MaxValue, 0, createPartitionMap(1024, + topicPartitions, topicIds.toMap)).build() + val fetchResponse = sendFetchRequest(topicMap.head._2, fetchRequest) + val fetchResponseData = fetchResponse.responseData(topicNames.asJava, ApiKeys.FETCH.latestVersion) + topicPartitions.foreach(tp => assertEquals(Errors.NONE, Errors.forCode(fetchResponseData.get(tp).errorCode))) + } + + /** + * Tests that "message.downconversion.enable" has no effect when down-conversion is not required on last version before topic IDs. 
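+ * Fetch version 12 is the final version that addresses topics by name rather than by topic ID.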
+ */ + @Test + def testV12WithDownConversionDisabled(): Unit = { + val topicMap = createTopics(numTopics = 5, numPartitions = 1) + val topicPartitions = topicMap.keySet.toSeq + val topicIds = servers.head.kafkaController.controllerContext.topicIds + val topicNames = topicIds.map(_.swap) + topicPartitions.foreach(tp => producer.send(new ProducerRecord(tp.topic(), "key", "value")).get()) + val fetchRequest = FetchRequest.Builder.forConsumer(ApiKeys.FETCH.latestVersion, Int.MaxValue, 0, createPartitionMap(1024, + topicPartitions, topicIds.toMap)).build(12) + val fetchResponse = sendFetchRequest(topicMap.head._2, fetchRequest) + val fetchResponseData = fetchResponse.responseData(topicNames.asJava, 12) + topicPartitions.foreach(tp => assertEquals(Errors.NONE, Errors.forCode(fetchResponseData.get(tp).errorCode))) + } + + /** + * Tests that "message.downconversion.enable" can be set at topic level, and its configuration is obeyed for client + * fetch requests. + */ + @Test + def testV1FetchWithTopicLevelOverrides(): Unit = { + // create topics with default down-conversion configuration (i.e. conversion disabled) + val conversionDisabledTopicsMap = createTopics(numTopics = 5, numPartitions = 1, topicSuffixStart = 0) + val conversionDisabledTopicPartitions = conversionDisabledTopicsMap.keySet.toSeq + + // create topics with down-conversion configuration enabled + val topicConfig = Map(LogConfig.MessageDownConversionEnableProp -> "true") + val conversionEnabledTopicsMap = createTopics(numTopics = 5, numPartitions = 1, topicConfig, topicSuffixStart = 5) + val conversionEnabledTopicPartitions = conversionEnabledTopicsMap.keySet.toSeq + + val allTopics = conversionDisabledTopicPartitions ++ conversionEnabledTopicPartitions + val leaderId = conversionDisabledTopicsMap.head._2 + val topicIds = servers.head.kafkaController.controllerContext.topicIds + val topicNames = topicIds.map(_.swap) + + allTopics.foreach(tp => producer.send(new ProducerRecord(tp.topic(), "key", "value")).get()) + val fetchRequest = FetchRequest.Builder.forConsumer(1, Int.MaxValue, 0, createPartitionMap(1024, + allTopics, topicIds.toMap)).build(1) + val fetchResponse = sendFetchRequest(leaderId, fetchRequest) + + val fetchResponseData = fetchResponse.responseData(topicNames.asJava, 1) + conversionDisabledTopicPartitions.foreach(tp => assertEquals(Errors.UNSUPPORTED_VERSION, Errors.forCode(fetchResponseData.get(tp).errorCode))) + conversionEnabledTopicPartitions.foreach(tp => assertEquals(Errors.NONE, Errors.forCode(fetchResponseData.get(tp).errorCode))) + } + + /** + * Tests that "message.downconversion.enable" has no effect on fetch requests from replicas. + */ + @Test + def testV1FetchFromReplica(): Unit = { + // create topics with default down-conversion configuration (i.e. 
conversion disabled) + val conversionDisabledTopicsMap = createTopics(numTopics = 5, numPartitions = 1, topicSuffixStart = 0) + val conversionDisabledTopicPartitions = conversionDisabledTopicsMap.keySet.toSeq + + // create topics with down-conversion configuration enabled + val topicConfig = Map(LogConfig.MessageDownConversionEnableProp -> "true") + val conversionEnabledTopicsMap = createTopics(numTopics = 5, numPartitions = 1, topicConfig, topicSuffixStart = 5) + val conversionEnabledTopicPartitions = conversionEnabledTopicsMap.keySet.toSeq + + val allTopicPartitions = conversionDisabledTopicPartitions ++ conversionEnabledTopicPartitions + val topicIds = servers.head.kafkaController.controllerContext.topicIds + val topicNames = topicIds.map(_.swap) + val leaderId = conversionDisabledTopicsMap.head._2 + + allTopicPartitions.foreach(tp => producer.send(new ProducerRecord(tp.topic, "key", "value")).get()) + val fetchRequest = FetchRequest.Builder.forReplica(1, 1, Int.MaxValue, 0, + createPartitionMap(1024, allTopicPartitions, topicIds.toMap)).build() + val fetchResponse = sendFetchRequest(leaderId, fetchRequest) + val fetchResponseData = fetchResponse.responseData(topicNames.asJava, 1) + allTopicPartitions.foreach(tp => assertEquals(Errors.NONE, Errors.forCode(fetchResponseData.get(tp).errorCode))) + } +} diff --git a/core/src/test/scala/unit/kafka/server/FetchRequestMaxBytesTest.scala b/core/src/test/scala/unit/kafka/server/FetchRequestMaxBytesTest.scala new file mode 100644 index 0000000..5bf43b3 --- /dev/null +++ b/core/src/test/scala/unit/kafka/server/FetchRequestMaxBytesTest.scala @@ -0,0 +1,133 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.server + +import kafka.log.LogConfig +import kafka.utils.TestUtils +import org.apache.kafka.clients.producer.{KafkaProducer, ProducerRecord} +import org.apache.kafka.common.{TopicPartition, Uuid} +import org.apache.kafka.common.requests.FetchRequest.PartitionData +import org.apache.kafka.common.requests.{FetchRequest, FetchResponse} +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.{AfterEach, BeforeEach, Test, TestInfo} + +import java.util.{Optional, Properties} +import scala.jdk.CollectionConverters._ + +/** + * This test verifies that the KIP-541 broker-level FetchMaxBytes configuration is honored. 
+ */ +class FetchRequestMaxBytesTest extends BaseRequestTest { + override def brokerCount: Int = 1 + + private var producer: KafkaProducer[Array[Byte], Array[Byte]] = null + private val testTopic = "testTopic" + private val testTopicPartition = new TopicPartition(testTopic, 0) + private val messages = IndexedSeq( + multiByteArray(1), + multiByteArray(500), + multiByteArray(1040), + multiByteArray(500), + multiByteArray(50)) + + private def multiByteArray(length: Int): Array[Byte] = { + val array = new Array[Byte](length) + array.indices.foreach(i => array(i) = (i % 5).toByte) + array + } + + private def oneByteArray(value: Byte): Array[Byte] = { + val array = new Array[Byte](1) + array(0) = value + array + } + + @BeforeEach + override def setUp(testInfo: TestInfo): Unit = { + super.setUp(testInfo) + producer = TestUtils.createProducer(TestUtils.getBrokerListStrFromServers(servers)) + } + + @AfterEach + override def tearDown(): Unit = { + if (producer != null) + producer.close() + super.tearDown() + } + + override protected def brokerPropertyOverrides(properties: Properties): Unit = { + super.brokerPropertyOverrides(properties) + properties.put(KafkaConfig.FetchMaxBytes, "1024") + } + + private def createTopics(): Unit = { + val topicConfig = new Properties + topicConfig.setProperty(LogConfig.MinInSyncReplicasProp, 1.toString) + createTopic(testTopic, + numPartitions = 1, + replicationFactor = 1, + topicConfig = topicConfig) + // Produce several messages as single batches. + messages.indices.foreach(i => { + val record = new ProducerRecord(testTopic, 0, oneByteArray(i.toByte), messages(i)) + val future = producer.send(record) + producer.flush() + future.get() + }) + } + + private def sendFetchRequest(leaderId: Int, request: FetchRequest): FetchResponse = { + connectAndReceive[FetchResponse](request, destination = brokerSocketServer(leaderId)) + } + + /** + * Tests that each of our fetch requests respects FetchMaxBytes. + * + * Note that when a single batch is larger than FetchMaxBytes, it will be + * returned in full even if this is larger than FetchMaxBytes. See KIP-74. 
+ */ + @Test + def testConsumeMultipleRecords(): Unit = { + createTopics() + + expectNextRecords(IndexedSeq(messages(0), messages(1)), 0) + expectNextRecords(IndexedSeq(messages(2)), 2) + expectNextRecords(IndexedSeq(messages(3), messages(4)), 3) + } + + private def expectNextRecords(expected: IndexedSeq[Array[Byte]], + fetchOffset: Long): Unit = { + val response = sendFetchRequest(0, + FetchRequest.Builder.forConsumer(3, Int.MaxValue, 0, + Map(testTopicPartition -> + new PartitionData(Uuid.ZERO_UUID, fetchOffset, 0, Integer.MAX_VALUE, Optional.empty())).asJava).build(3)) + val records = FetchResponse.recordsOrFail(response.responseData(getTopicNames().asJava, 3).get(testTopicPartition)).records() + assertNotNull(records) + val recordsList = records.asScala.toList + assertEquals(expected.size, recordsList.size) + recordsList.zipWithIndex.foreach { + case (record, i) => { + val buffer = record.value().duplicate() + val array = new Array[Byte](buffer.remaining()) + buffer.get(array) + assertArrayEquals(expected(i), + array, s"expectNextRecords unexpected element ${i}") + } + } + } +} diff --git a/core/src/test/scala/unit/kafka/server/FetchRequestTest.scala b/core/src/test/scala/unit/kafka/server/FetchRequestTest.scala new file mode 100644 index 0000000..82c990d --- /dev/null +++ b/core/src/test/scala/unit/kafka/server/FetchRequestTest.scala @@ -0,0 +1,781 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package kafka.server + +import kafka.log.LogConfig +import kafka.message.{GZIPCompressionCodec, ProducerCompressionCodec, ZStdCompressionCodec} +import kafka.utils.TestUtils +import org.apache.kafka.clients.producer.ProducerRecord +import org.apache.kafka.common.protocol.{ApiKeys, Errors} +import org.apache.kafka.common.record.RecordBatch +import org.apache.kafka.common.requests.{FetchRequest, FetchResponse, FetchMetadata => JFetchMetadata} +import org.apache.kafka.common.serialization.{ByteArraySerializer, StringSerializer} +import org.apache.kafka.common.{IsolationLevel, TopicIdPartition, TopicPartition, Uuid} +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.Test + +import java.io.DataInputStream +import java.util +import java.util.Optional +import scala.collection.Seq +import scala.jdk.CollectionConverters._ +import scala.util.Random + +/** + * Subclasses of `BaseConsumerTest` exercise the consumer and fetch request/response. This class + * complements those classes with tests that require lower-level access to the protocol. 
+ */ +class FetchRequestTest extends BaseFetchRequestTest { + + @Test + def testBrokerRespectsPartitionsOrderAndSizeLimits(): Unit = { + initProducer() + + val messagesPerPartition = 9 + val maxResponseBytes = 800 + val maxPartitionBytes = 190 + + def createFetchRequest(topicPartitions: Seq[TopicPartition], offsetMap: Map[TopicPartition, Long] = Map.empty, + version: Short = ApiKeys.FETCH.latestVersion()): FetchRequest = + this.createFetchRequest(maxResponseBytes, maxPartitionBytes, topicPartitions, offsetMap, version) + + val topicPartitionToLeader = createTopics(numTopics = 5, numPartitions = 6) + val random = new Random(0) + val topicPartitions = topicPartitionToLeader.keySet + val topicIds = getTopicIds().asJava + val topicNames = topicIds.asScala.map(_.swap).asJava + produceData(topicPartitions, messagesPerPartition) + + val leaderId = servers.head.config.brokerId + val partitionsForLeader = topicPartitionToLeader.toVector.collect { + case (tp, partitionLeaderId) if partitionLeaderId == leaderId => tp + } + + val partitionsWithLargeMessages = partitionsForLeader.takeRight(2) + val partitionWithLargeMessage1 = partitionsWithLargeMessages.head + val partitionWithLargeMessage2 = partitionsWithLargeMessages(1) + producer.send(new ProducerRecord(partitionWithLargeMessage1.topic, partitionWithLargeMessage1.partition, + "larger than partition limit", new String(new Array[Byte](maxPartitionBytes + 1)))).get + producer.send(new ProducerRecord(partitionWithLargeMessage2.topic, partitionWithLargeMessage2.partition, + "larger than response limit", new String(new Array[Byte](maxResponseBytes + 1)))).get + + val partitionsWithoutLargeMessages = partitionsForLeader.filterNot(partitionsWithLargeMessages.contains) + + // 1. Partitions with large messages at the end + val shuffledTopicPartitions1 = random.shuffle(partitionsWithoutLargeMessages) ++ partitionsWithLargeMessages + val fetchRequest1 = createFetchRequest(shuffledTopicPartitions1) + val fetchResponse1 = sendFetchRequest(leaderId, fetchRequest1) + checkFetchResponse(shuffledTopicPartitions1, fetchResponse1, maxPartitionBytes, maxResponseBytes, messagesPerPartition) + val fetchRequest1V12 = createFetchRequest(shuffledTopicPartitions1, version = 12) + val fetchResponse1V12 = sendFetchRequest(leaderId, fetchRequest1V12) + checkFetchResponse(shuffledTopicPartitions1, fetchResponse1V12, maxPartitionBytes, maxResponseBytes, messagesPerPartition, 12) + + // 2. Same as 1, but shuffled again + val shuffledTopicPartitions2 = random.shuffle(partitionsWithoutLargeMessages) ++ partitionsWithLargeMessages + val fetchRequest2 = createFetchRequest(shuffledTopicPartitions2) + val fetchResponse2 = sendFetchRequest(leaderId, fetchRequest2) + checkFetchResponse(shuffledTopicPartitions2, fetchResponse2, maxPartitionBytes, maxResponseBytes, messagesPerPartition) + val fetchRequest2V12 = createFetchRequest(shuffledTopicPartitions2, version = 12) + val fetchResponse2V12 = sendFetchRequest(leaderId, fetchRequest2V12) + checkFetchResponse(shuffledTopicPartitions2, fetchResponse2V12, maxPartitionBytes, maxResponseBytes, messagesPerPartition, 12) + + // 3. 
Partition with message larger than the partition limit at the start of the list + val shuffledTopicPartitions3 = Seq(partitionWithLargeMessage1, partitionWithLargeMessage2) ++ + random.shuffle(partitionsWithoutLargeMessages) + val fetchRequest3 = createFetchRequest(shuffledTopicPartitions3, Map(partitionWithLargeMessage1 -> messagesPerPartition)) + val fetchResponse3 = sendFetchRequest(leaderId, fetchRequest3) + val fetchRequest3V12 = createFetchRequest(shuffledTopicPartitions3, Map(partitionWithLargeMessage1 -> messagesPerPartition), 12) + val fetchResponse3V12 = sendFetchRequest(leaderId, fetchRequest3V12) + def evaluateResponse3(response: FetchResponse, version: Short = ApiKeys.FETCH.latestVersion()) = { + val responseData = response.responseData(topicNames, version) + assertEquals(shuffledTopicPartitions3, responseData.keySet.asScala.toSeq) + val responseSize = responseData.asScala.values.map { partitionData => + records(partitionData).map(_.sizeInBytes).sum + }.sum + assertTrue(responseSize <= maxResponseBytes) + val partitionData = responseData.get(partitionWithLargeMessage1) + assertEquals(Errors.NONE.code, partitionData.errorCode) + assertTrue(partitionData.highWatermark > 0) + val size3 = records(partitionData).map(_.sizeInBytes).sum + assertTrue(size3 <= maxResponseBytes, s"Expected $size3 to be smaller than $maxResponseBytes") + assertTrue(size3 > maxPartitionBytes, s"Expected $size3 to be larger than $maxPartitionBytes") + assertTrue(maxPartitionBytes < partitionData.records.sizeInBytes) + } + evaluateResponse3(fetchResponse3) + evaluateResponse3(fetchResponse3V12, 12) + + // 4. Partition with message larger than the response limit at the start of the list + val shuffledTopicPartitions4 = Seq(partitionWithLargeMessage2, partitionWithLargeMessage1) ++ + random.shuffle(partitionsWithoutLargeMessages) + val fetchRequest4 = createFetchRequest(shuffledTopicPartitions4, Map(partitionWithLargeMessage2 -> messagesPerPartition)) + val fetchResponse4 = sendFetchRequest(leaderId, fetchRequest4) + val fetchRequest4V12 = createFetchRequest(shuffledTopicPartitions4, Map(partitionWithLargeMessage2 -> messagesPerPartition), 12) + val fetchResponse4V12 = sendFetchRequest(leaderId, fetchRequest4V12) + def evaluateResponse4(response: FetchResponse, version: Short = ApiKeys.FETCH.latestVersion()) = { + val responseData = response.responseData(topicNames, version) + assertEquals(shuffledTopicPartitions4, responseData.keySet.asScala.toSeq) + val nonEmptyPartitions = responseData.asScala.toSeq.collect { + case (tp, partitionData) if records(partitionData).map(_.sizeInBytes).sum > 0 => tp + } + assertEquals(Seq(partitionWithLargeMessage2), nonEmptyPartitions) + val partitionData = responseData.get(partitionWithLargeMessage2) + assertEquals(Errors.NONE.code, partitionData.errorCode) + assertTrue(partitionData.highWatermark > 0) + val size4 = records(partitionData).map(_.sizeInBytes).sum + assertTrue(size4 > maxResponseBytes, s"Expected $size4 to be larger than $maxResponseBytes") + assertTrue(maxResponseBytes < partitionData.records.sizeInBytes) + } + evaluateResponse4(fetchResponse4) + evaluateResponse4(fetchResponse4V12, 12) + } + + @Test + def testFetchRequestV4WithReadCommitted(): Unit = { + initProducer() + val maxPartitionBytes = 200 + val (topicPartition, leaderId) = createTopics(numTopics = 1, numPartitions = 1).head + val topicIds = getTopicIds().asJava + val topicNames = topicIds.asScala.map(_.swap).asJava + producer.send(new ProducerRecord(topicPartition.topic, topicPartition.partition, + 
"key", new String(new Array[Byte](maxPartitionBytes + 1)))).get + val fetchRequest = FetchRequest.Builder.forConsumer(4, Int.MaxValue, 0, createPartitionMap(maxPartitionBytes, + Seq(topicPartition))).isolationLevel(IsolationLevel.READ_COMMITTED).build(4) + val fetchResponse = sendFetchRequest(leaderId, fetchRequest) + val partitionData = fetchResponse.responseData(topicNames, 4).get(topicPartition) + assertEquals(Errors.NONE.code, partitionData.errorCode) + assertTrue(partitionData.lastStableOffset > 0) + assertTrue(records(partitionData).map(_.sizeInBytes).sum > 0) + } + + @Test + def testFetchRequestToNonReplica(): Unit = { + val topic = "topic" + val partition = 0 + val topicPartition = new TopicPartition(topic, partition) + + // Create a single-partition topic and find a broker which is not the leader + val partitionToLeader = TestUtils.createTopic(zkClient, topic, numPartitions = 1, 1, servers) + val topicIds = getTopicIds().asJava + val topicNames = topicIds.asScala.map(_.swap).asJava + val leader = partitionToLeader(partition) + val nonReplicaOpt = servers.find(_.config.brokerId != leader) + assertTrue(nonReplicaOpt.isDefined) + val nonReplicaId = nonReplicaOpt.get.config.brokerId + + // Send the fetch request to the non-replica and verify the error code + val fetchRequest = FetchRequest.Builder.forConsumer(ApiKeys.FETCH.latestVersion, Int.MaxValue, 0, createPartitionMap(1024, + Seq(topicPartition))).build() + val fetchResponse = sendFetchRequest(nonReplicaId, fetchRequest) + val partitionData = fetchResponse.responseData(topicNames, ApiKeys.FETCH.latestVersion).get(topicPartition) + assertEquals(Errors.NOT_LEADER_OR_FOLLOWER.code, partitionData.errorCode) + + // Repeat with request that does not use topic IDs + val oldFetchRequest = FetchRequest.Builder.forConsumer(12, Int.MaxValue, 0, createPartitionMap(1024, + Seq(topicPartition))).build() + val oldFetchResponse = sendFetchRequest(nonReplicaId, oldFetchRequest) + val oldPartitionData = oldFetchResponse.responseData(topicNames, 12).get(topicPartition) + assertEquals(Errors.NOT_LEADER_OR_FOLLOWER.code, oldPartitionData.errorCode) + } + + @Test + def testLastFetchedEpochValidation(): Unit = { + checkLastFetchedEpochValidation(ApiKeys.FETCH.latestVersion()) + } + + @Test + def testLastFetchedEpochValidationV12(): Unit = { + checkLastFetchedEpochValidation(12) + } + + private def checkLastFetchedEpochValidation(version: Short): Unit = { + val topic = "topic" + val topicPartition = new TopicPartition(topic, 0) + val partitionToLeader = TestUtils.createTopic(zkClient, topic, numPartitions = 1, replicationFactor = 3, servers) + val firstLeaderId = partitionToLeader(topicPartition.partition) + val firstLeaderEpoch = TestUtils.findLeaderEpoch(firstLeaderId, topicPartition, servers) + + initProducer() + + // Write some data in epoch 0 + val firstEpochResponses = produceData(Seq(topicPartition), 100) + val firstEpochEndOffset = firstEpochResponses.lastOption.get.offset + 1 + // Force a leader change + killBroker(firstLeaderId) + // Write some more data in epoch 1 + val secondLeaderId = TestUtils.awaitLeaderChange(servers, topicPartition, firstLeaderId) + val secondLeaderEpoch = TestUtils.findLeaderEpoch(secondLeaderId, topicPartition, servers) + val secondEpochResponses = produceData(Seq(topicPartition), 100) + val secondEpochEndOffset = secondEpochResponses.lastOption.get.offset + 1 + val topicIds = getTopicIds().asJava + val topicNames = topicIds.asScala.map(_.swap).asJava + + // Build a fetch request in the middle of the second epoch, but 
with the first epoch + val fetchOffset = secondEpochEndOffset + (secondEpochEndOffset - firstEpochEndOffset) / 2 + val partitionMap = new util.LinkedHashMap[TopicPartition, FetchRequest.PartitionData] + partitionMap.put(topicPartition, + new FetchRequest.PartitionData(topicIds.getOrDefault(topic, Uuid.ZERO_UUID), fetchOffset, 0L, 1024, + Optional.of(secondLeaderEpoch), Optional.of(firstLeaderEpoch))) + val fetchRequest = FetchRequest.Builder.forConsumer(version, 0, 1, partitionMap).build() + + // Validate the expected truncation + val fetchResponse = sendFetchRequest(secondLeaderId, fetchRequest) + val partitionData = fetchResponse.responseData(topicNames, version).get(topicPartition) + assertEquals(Errors.NONE.code, partitionData.errorCode) + assertEquals(0L, FetchResponse.recordsSize(partitionData)) + assertTrue(FetchResponse.isDivergingEpoch(partitionData)) + + val divergingEpoch = partitionData.divergingEpoch + assertEquals(firstLeaderEpoch, divergingEpoch.epoch) + assertEquals(firstEpochEndOffset, divergingEpoch.endOffset) + } + + @Test + def testCurrentEpochValidation(): Unit = { + checkCurrentEpochValidation(ApiKeys.FETCH.latestVersion()) + } + + @Test + def testCurrentEpochValidationV12(): Unit = { + checkCurrentEpochValidation(12) + } + + private def checkCurrentEpochValidation(version: Short): Unit = { + val topic = "topic" + val topicPartition = new TopicPartition(topic, 0) + val partitionToLeader = TestUtils.createTopic(zkClient, topic, numPartitions = 1, replicationFactor = 3, servers) + val firstLeaderId = partitionToLeader(topicPartition.partition) + + def assertResponseErrorForEpoch(error: Errors, brokerId: Int, leaderEpoch: Optional[Integer]): Unit = { + val partitionMap = new util.LinkedHashMap[TopicPartition, FetchRequest.PartitionData] + val topicIds = getTopicIds().asJava + val topicNames = topicIds.asScala.map(_.swap).asJava + partitionMap.put(topicPartition, + new FetchRequest.PartitionData(topicIds.getOrDefault(topic, Uuid.ZERO_UUID), 0L, 0L, 1024, leaderEpoch)) + val fetchRequest = FetchRequest.Builder.forConsumer(version, 0, 1, partitionMap).build() + val fetchResponse = sendFetchRequest(brokerId, fetchRequest) + val partitionData = fetchResponse.responseData(topicNames, version).get(topicPartition) + assertEquals(error.code, partitionData.errorCode) + } + + // We need a leader change in order to check epoch fencing since the first epoch is 0 and + // -1 is treated as having no epoch at all + killBroker(firstLeaderId) + + // Check leader error codes + val secondLeaderId = TestUtils.awaitLeaderChange(servers, topicPartition, firstLeaderId) + val secondLeaderEpoch = TestUtils.findLeaderEpoch(secondLeaderId, topicPartition, servers) + assertResponseErrorForEpoch(Errors.NONE, secondLeaderId, Optional.empty()) + assertResponseErrorForEpoch(Errors.NONE, secondLeaderId, Optional.of(secondLeaderEpoch)) + assertResponseErrorForEpoch(Errors.FENCED_LEADER_EPOCH, secondLeaderId, Optional.of(secondLeaderEpoch - 1)) + assertResponseErrorForEpoch(Errors.UNKNOWN_LEADER_EPOCH, secondLeaderId, Optional.of(secondLeaderEpoch + 1)) + + // Check follower error codes + val followerId = TestUtils.findFollowerId(topicPartition, servers) + assertResponseErrorForEpoch(Errors.NONE, followerId, Optional.empty()) + assertResponseErrorForEpoch(Errors.NONE, followerId, Optional.of(secondLeaderEpoch)) + assertResponseErrorForEpoch(Errors.UNKNOWN_LEADER_EPOCH, followerId, Optional.of(secondLeaderEpoch + 1)) + assertResponseErrorForEpoch(Errors.FENCED_LEADER_EPOCH, followerId, 
Optional.of(secondLeaderEpoch - 1)) + } + + @Test + def testEpochValidationWithinFetchSession(): Unit = { + checkEpochValidationWithinFetchSession(ApiKeys.FETCH.latestVersion()) + } + + @Test + def testEpochValidationWithinFetchSessionV12(): Unit = { + checkEpochValidationWithinFetchSession(12) + } + + private def checkEpochValidationWithinFetchSession(version: Short): Unit = { + val topic = "topic" + val topicPartition = new TopicPartition(topic, 0) + val partitionToLeader = TestUtils.createTopic(zkClient, topic, numPartitions = 1, replicationFactor = 3, servers) + val firstLeaderId = partitionToLeader(topicPartition.partition) + + // We need a leader change in order to check epoch fencing since the first epoch is 0 and + // -1 is treated as having no epoch at all + killBroker(firstLeaderId) + + val secondLeaderId = TestUtils.awaitLeaderChange(servers, topicPartition, firstLeaderId) + val secondLeaderEpoch = TestUtils.findLeaderEpoch(secondLeaderId, topicPartition, servers) + verifyFetchSessionErrors(topicPartition, secondLeaderEpoch, secondLeaderId, version) + + val followerId = TestUtils.findFollowerId(topicPartition, servers) + verifyFetchSessionErrors(topicPartition, secondLeaderEpoch, followerId, version) + } + + private def verifyFetchSessionErrors(topicPartition: TopicPartition, + leaderEpoch: Int, + destinationBrokerId: Int, + version: Short): Unit = { + val partitionMap = new util.LinkedHashMap[TopicPartition, FetchRequest.PartitionData] + val topicIds = getTopicIds().asJava + val topicNames = topicIds.asScala.map(_.swap).asJava + partitionMap.put(topicPartition, new FetchRequest.PartitionData(topicIds.getOrDefault(topicPartition.topic, Uuid.ZERO_UUID), + 0L, 0L, 1024, Optional.of(leaderEpoch))) + val fetchRequest = FetchRequest.Builder.forConsumer(version, 0, 1, partitionMap) + .metadata(JFetchMetadata.INITIAL) + .build() + val fetchResponse = sendFetchRequest(destinationBrokerId, fetchRequest) + val sessionId = fetchResponse.sessionId + + def assertResponseErrorForEpoch(expectedError: Errors, + sessionFetchEpoch: Int, + leaderEpoch: Optional[Integer]): Unit = { + val partitionMap = new util.LinkedHashMap[TopicPartition, FetchRequest.PartitionData] + partitionMap.put(topicPartition, new FetchRequest.PartitionData(topicIds.getOrDefault(topicPartition.topic, Uuid.ZERO_UUID), 0L, 0L, 1024, leaderEpoch)) + val fetchRequest = FetchRequest.Builder.forConsumer(version, 0, 1, partitionMap) + .metadata(new JFetchMetadata(sessionId, sessionFetchEpoch)) + .build() + val fetchResponse = sendFetchRequest(destinationBrokerId, fetchRequest) + val partitionData = fetchResponse.responseData(topicNames, version).get(topicPartition) + assertEquals(expectedError.code, partitionData.errorCode) + } + + // We only check errors because we do not expect the partition in the response otherwise + assertResponseErrorForEpoch(Errors.FENCED_LEADER_EPOCH, 1, Optional.of(leaderEpoch - 1)) + assertResponseErrorForEpoch(Errors.UNKNOWN_LEADER_EPOCH, 2, Optional.of(leaderEpoch + 1)) + } + + /** + * Tests that down-conversions don't leak memory. Large down conversions are triggered + * in the server. The client closes its connection after reading partial data when the + * channel is muted in the server. If buffers are not released this will result in OOM. 
+ */ + @Test + def testDownConversionWithConnectionFailure(): Unit = { + val (topicPartition, leaderId) = createTopics(numTopics = 1, numPartitions = 1).head + val topicIds = getTopicIds().asJava + val topicNames = topicIds.asScala.map(_.swap).asJava + + val msgValueLen = 100 * 1000 + val batchSize = 4 * msgValueLen + val producer = TestUtils.createProducer(TestUtils.getBrokerListStrFromServers(servers), + lingerMs = Int.MaxValue, + deliveryTimeoutMs = Int.MaxValue, + batchSize = batchSize, + keySerializer = new StringSerializer, + valueSerializer = new ByteArraySerializer) + val bytes = new Array[Byte](msgValueLen) + val futures = try { + (0 to 1000).map { _ => + producer.send(new ProducerRecord(topicPartition.topic, topicPartition.partition, "key", bytes)) + } + } finally { + producer.close() + } + // Check futures to ensure sends succeeded, but do this after close since the last + // batch is not complete, but sent when the producer is closed + futures.foreach(_.get) + + def fetch(version: Short, maxPartitionBytes: Int, closeAfterPartialResponse: Boolean): Option[FetchResponse] = { + val fetchRequest = FetchRequest.Builder.forConsumer(version, Int.MaxValue, 0, createPartitionMap(maxPartitionBytes, + Seq(topicPartition))).build(version) + + val socket = connect(brokerSocketServer(leaderId)) + try { + send(fetchRequest, socket) + if (closeAfterPartialResponse) { + // read some data to ensure broker has muted this channel and then close socket + val size = new DataInputStream(socket.getInputStream).readInt() + // Check that we have received almost `maxPartitionBytes` (minus a tolerance) since in + // the case of OOM, the size will be significantly smaller. We can't check for exactly + // maxPartitionBytes since we use approx message sizes that include only the message value. + assertTrue(size > maxPartitionBytes - batchSize, + s"Fetch size too small $size, broker may have run out of memory") + None + } else { + Some(receive[FetchResponse](socket, ApiKeys.FETCH, version)) + } + } finally { + socket.close() + } + } + + val version = 1.toShort + (0 to 15).foreach(_ => fetch(version, maxPartitionBytes = msgValueLen * 1000, closeAfterPartialResponse = true)) + + val response = fetch(version, maxPartitionBytes = batchSize, closeAfterPartialResponse = false) + val fetchResponse = response.getOrElse(throw new IllegalStateException("No fetch response")) + val partitionData = fetchResponse.responseData(topicNames, version).get(topicPartition) + assertEquals(Errors.NONE.code, partitionData.errorCode) + val batches = FetchResponse.recordsOrFail(partitionData).batches.asScala.toBuffer + assertEquals(3, batches.size) // size is 3 (not 4) since maxPartitionBytes=msgValueSize*4, excluding key and headers + } + + /** + * Ensure that we respect the fetch offset when returning records that were converted from an uncompressed v2 + * record batch to multiple v0/v1 record batches with size 1. If the fetch offset points to inside the record batch, + * some records have to be dropped during the conversion. 
+ */ + @Test + def testDownConversionFromBatchedToUnbatchedRespectsOffset(): Unit = { + // Increase linger so that we have control over the batches created + producer = TestUtils.createProducer(TestUtils.getBrokerListStrFromServers(servers), + retries = 5, + keySerializer = new StringSerializer, + valueSerializer = new StringSerializer, + lingerMs = 30 * 1000, + deliveryTimeoutMs = 60 * 1000) + + val (topicPartition, leaderId) = createTopics(numTopics = 1, numPartitions = 1).head + val topic = topicPartition.topic + val topicIds = getTopicIds().asJava + val topicNames = topicIds.asScala.map(_.swap).asJava + + val firstBatchFutures = (0 until 10).map(i => producer.send(new ProducerRecord(topic, s"key-$i", s"value-$i"))) + producer.flush() + val secondBatchFutures = (10 until 25).map(i => producer.send(new ProducerRecord(topic, s"key-$i", s"value-$i"))) + producer.flush() + + firstBatchFutures.foreach(_.get) + secondBatchFutures.foreach(_.get) + + def check(fetchOffset: Long, requestVersion: Short, expectedOffset: Long, expectedNumBatches: Int, expectedMagic: Byte): Unit = { + var batchesReceived = 0 + var currentFetchOffset = fetchOffset + var currentExpectedOffset = expectedOffset + + // With KIP-283, we might not receive all batches in a single fetch request so loop through till we have consumed + // all batches we are interested in. + while (batchesReceived < expectedNumBatches) { + val fetchRequest = FetchRequest.Builder.forConsumer(requestVersion, Int.MaxValue, 0, createPartitionMap(Int.MaxValue, + Seq(topicPartition), Map(topicPartition -> currentFetchOffset))).build(requestVersion) + val fetchResponse = sendFetchRequest(leaderId, fetchRequest) + + // validate response + val partitionData = fetchResponse.responseData(topicNames, requestVersion).get(topicPartition) + assertEquals(Errors.NONE.code, partitionData.errorCode) + assertTrue(partitionData.highWatermark > 0) + val batches = FetchResponse.recordsOrFail(partitionData).batches.asScala.toBuffer + val batch = batches.head + assertEquals(expectedMagic, batch.magic) + assertEquals(currentExpectedOffset, batch.baseOffset) + + currentFetchOffset = batches.last.lastOffset + 1 + currentExpectedOffset += (batches.last.lastOffset - batches.head.baseOffset + 1) + batchesReceived += batches.size + } + + assertEquals(expectedNumBatches, batchesReceived) + } + + // down conversion to message format 0, batches of 1 message are returned so we receive the exact offset we requested + check(fetchOffset = 3, expectedOffset = 3, requestVersion = 1, expectedNumBatches = 22, + expectedMagic = RecordBatch.MAGIC_VALUE_V0) + check(fetchOffset = 15, expectedOffset = 15, requestVersion = 1, expectedNumBatches = 10, + expectedMagic = RecordBatch.MAGIC_VALUE_V0) + + // down conversion to message format 1, batches of 1 message are returned so we receive the exact offset we requested + check(fetchOffset = 3, expectedOffset = 3, requestVersion = 3, expectedNumBatches = 22, + expectedMagic = RecordBatch.MAGIC_VALUE_V1) + check(fetchOffset = 15, expectedOffset = 15, requestVersion = 3, expectedNumBatches = 10, + expectedMagic = RecordBatch.MAGIC_VALUE_V1) + + // no down conversion, we receive a single batch so the received offset won't necessarily be the same + check(fetchOffset = 3, expectedOffset = 0, requestVersion = 4, expectedNumBatches = 2, + expectedMagic = RecordBatch.MAGIC_VALUE_V2) + check(fetchOffset = 15, expectedOffset = 10, requestVersion = 4, expectedNumBatches = 1, + expectedMagic = RecordBatch.MAGIC_VALUE_V2) + + // no down conversion, we receive 
a single batch and the exact offset we requested because it happens to be the + // offset of the first record in the batch + check(fetchOffset = 10, expectedOffset = 10, requestVersion = 4, expectedNumBatches = 1, + expectedMagic = RecordBatch.MAGIC_VALUE_V2) + } + + /** + * Test that when an incremental fetch session contains partitions with an error, + * those partitions are returned in all incremental fetch requests. + * This tests using FetchRequests that don't use topic IDs + */ + @Test + def testCreateIncrementalFetchWithPartitionsInErrorV12(): Unit = { + def createFetchRequest(topicPartitions: Seq[TopicPartition], + metadata: JFetchMetadata, + toForget: Seq[TopicIdPartition]): FetchRequest = + FetchRequest.Builder.forConsumer(12, Int.MaxValue, 0, + createPartitionMap(Integer.MAX_VALUE, topicPartitions, Map.empty)) + .removed(toForget.asJava) + .metadata(metadata) + .build() + val foo0 = new TopicPartition("foo", 0) + val foo1 = new TopicPartition("foo", 1) + // topicNames can be empty because we are using old requests + val topicNames = Map[Uuid, String]().asJava + createTopic("foo", Map(0 -> List(0, 1), 1 -> List(0, 2))) + val bar0 = new TopicPartition("bar", 0) + val req1 = createFetchRequest(List(foo0, foo1, bar0), JFetchMetadata.INITIAL, Nil) + val resp1 = sendFetchRequest(0, req1) + assertEquals(Errors.NONE, resp1.error()) + assertTrue(resp1.sessionId() > 0, "Expected the broker to create a new incremental fetch session") + debug(s"Test created an incremental fetch session ${resp1.sessionId}") + val responseData1 = resp1.responseData(topicNames, 12) + assertTrue(responseData1.containsKey(foo0)) + assertTrue(responseData1.containsKey(foo1)) + assertTrue(responseData1.containsKey(bar0)) + assertEquals(Errors.NONE.code, responseData1.get(foo0).errorCode) + assertEquals(Errors.NONE.code, responseData1.get(foo1).errorCode) + assertEquals(Errors.UNKNOWN_TOPIC_OR_PARTITION.code, responseData1.get(bar0).errorCode) + val req2 = createFetchRequest(Nil, new JFetchMetadata(resp1.sessionId(), 1), Nil) + val resp2 = sendFetchRequest(0, req2) + assertEquals(Errors.NONE, resp2.error()) + assertEquals(resp1.sessionId(), + resp2.sessionId(), "Expected the broker to continue the incremental fetch session") + val responseData2 = resp2.responseData(topicNames, 12) + assertFalse(responseData2.containsKey(foo0)) + assertFalse(responseData2.containsKey(foo1)) + assertTrue(responseData2.containsKey(bar0)) + assertEquals(Errors.UNKNOWN_TOPIC_OR_PARTITION.code, responseData2.get(bar0).errorCode) + createTopic("bar", Map(0 -> List(0, 1))) + val req3 = createFetchRequest(Nil, new JFetchMetadata(resp1.sessionId(), 2), Nil) + val resp3 = sendFetchRequest(0, req3) + assertEquals(Errors.NONE, resp3.error()) + val responseData3 = resp3.responseData(topicNames, 12) + assertFalse(responseData3.containsKey(foo0)) + assertFalse(responseData3.containsKey(foo1)) + assertTrue(responseData3.containsKey(bar0)) + assertEquals(Errors.NONE.code, responseData3.get(bar0).errorCode) + val req4 = createFetchRequest(Nil, new JFetchMetadata(resp1.sessionId(), 3), Nil) + val resp4 = sendFetchRequest(0, req4) + assertEquals(Errors.NONE, resp4.error()) + val responseData4 = resp4.responseData(topicNames, 12) + assertFalse(responseData4.containsKey(foo0)) + assertFalse(responseData4.containsKey(foo1)) + assertFalse(responseData4.containsKey(bar0)) + } + + /** + * Test that when a Fetch Request receives an unknown topic ID, it returns a top level error. 
+ */ + @Test + def testFetchWithPartitionsWithIdError(): Unit = { + def createFetchRequest(fetchData: util.LinkedHashMap[TopicPartition, FetchRequest.PartitionData], + metadata: JFetchMetadata, + toForget: Seq[TopicIdPartition]): FetchRequest = { + FetchRequest.Builder.forConsumer(ApiKeys.FETCH.latestVersion(), Int.MaxValue, 0, fetchData) + .removed(toForget.asJava) + .metadata(metadata) + .build() + } + + val foo0 = new TopicPartition("foo", 0) + val foo1 = new TopicPartition("foo", 1) + createTopic("foo", Map(0 -> List(0, 1), 1 -> List(0, 2))) + val topicIds = getTopicIds() + val topicIdsWithUnknown = topicIds ++ Map("bar" -> Uuid.randomUuid()) + val bar0 = new TopicPartition("bar", 0) + + def createPartitionMap(maxPartitionBytes: Int, topicPartitions: Seq[TopicPartition], + offsetMap: Map[TopicPartition, Long]): util.LinkedHashMap[TopicPartition, FetchRequest.PartitionData] = { + val partitionMap = new util.LinkedHashMap[TopicPartition, FetchRequest.PartitionData] + topicPartitions.foreach { tp => + partitionMap.put(tp, + new FetchRequest.PartitionData(topicIdsWithUnknown.getOrElse(tp.topic, Uuid.ZERO_UUID), offsetMap.getOrElse(tp, 0), + 0L, maxPartitionBytes, Optional.empty())) + } + partitionMap + } + + val req1 = createFetchRequest( createPartitionMap(Integer.MAX_VALUE, List(foo0, foo1, bar0), Map.empty), JFetchMetadata.INITIAL, Nil) + val resp1 = sendFetchRequest(0, req1) + assertEquals(Errors.NONE, resp1.error()) + val topicNames1 = topicIdsWithUnknown.map(_.swap).asJava + val responseData1 = resp1.responseData(topicNames1, ApiKeys.FETCH.latestVersion()) + assertTrue(responseData1.containsKey(foo0)) + assertTrue(responseData1.containsKey(foo1)) + assertTrue(responseData1.containsKey(bar0)) + assertEquals(Errors.NONE.code, responseData1.get(foo0).errorCode) + assertEquals(Errors.NONE.code, responseData1.get(foo1).errorCode) + assertEquals(Errors.UNKNOWN_TOPIC_ID.code, responseData1.get(bar0).errorCode) + } + + @Test + def testZStdCompressedTopic(): Unit = { + // ZSTD compressed topic + val topicConfig = Map(LogConfig.CompressionTypeProp -> ZStdCompressionCodec.name) + val (topicPartition, leaderId) = createTopics(numTopics = 1, numPartitions = 1, configs = topicConfig).head + val topicIds = getTopicIds().asJava + val topicNames = topicIds.asScala.map(_.swap).asJava + + // Produce messages (v2) + producer = TestUtils.createProducer(TestUtils.getBrokerListStrFromServers(servers), + keySerializer = new StringSerializer, + valueSerializer = new StringSerializer) + producer.send(new ProducerRecord(topicPartition.topic, topicPartition.partition, + "key1", "value1")).get + producer.send(new ProducerRecord(topicPartition.topic, topicPartition.partition, + "key2", "value2")).get + producer.send(new ProducerRecord(topicPartition.topic, topicPartition.partition, + "key3", "value3")).get + producer.close() + + // fetch request with version below v10: UNSUPPORTED_COMPRESSION_TYPE error occurs + val req0 = new FetchRequest.Builder(0, 9, -1, Int.MaxValue, 0, + createPartitionMap(300, Seq(topicPartition), Map.empty)) + .setMaxBytes(800).build() + + val res0 = sendFetchRequest(leaderId, req0) + val data0 = res0.responseData(topicNames, 9).get(topicPartition) + assertEquals(Errors.UNSUPPORTED_COMPRESSION_TYPE.code, data0.errorCode) + + // fetch request with version 10: works fine! 
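+    // ZStandard support was added in KIP-110: fetch v10 is the first version permitted to return zstd-compressed batches, which is why the v9 request above fails while v10 and later succeed.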
+ val req1= new FetchRequest.Builder(0, 10, -1, Int.MaxValue, 0, + createPartitionMap(300, Seq(topicPartition), Map.empty)) + .setMaxBytes(800).build() + val res1 = sendFetchRequest(leaderId, req1) + val data1 = res1.responseData(topicNames, 10).get(topicPartition) + assertEquals(Errors.NONE.code, data1.errorCode) + assertEquals(3, records(data1).size) + + val req2 = new FetchRequest.Builder(ApiKeys.FETCH.latestVersion(), ApiKeys.FETCH.latestVersion(), -1, Int.MaxValue, 0, + createPartitionMap(300, Seq(topicPartition), Map.empty)) + .setMaxBytes(800).build() + val res2 = sendFetchRequest(leaderId, req2) + val data2 = res2.responseData(topicNames, ApiKeys.FETCH.latestVersion()).get(topicPartition) + assertEquals(Errors.NONE.code, data2.errorCode) + assertEquals(3, records(data2).size) + } + + @Test + def testZStdCompressedRecords(): Unit = { + // Producer compressed topic + val topicConfig = Map(LogConfig.CompressionTypeProp -> ProducerCompressionCodec.name) + val (topicPartition, leaderId) = createTopics(numTopics = 1, numPartitions = 1, configs = topicConfig).head + val topicIds = getTopicIds().asJava + val topicNames = topicIds.asScala.map(_.swap).asJava + + // Produce GZIP compressed messages (v2) + val producer1 = TestUtils.createProducer(TestUtils.getBrokerListStrFromServers(servers), + compressionType = GZIPCompressionCodec.name, + keySerializer = new StringSerializer, + valueSerializer = new StringSerializer) + producer1.send(new ProducerRecord(topicPartition.topic, topicPartition.partition, + "key1", "value1")).get + producer1.close() + // Produce ZSTD compressed messages (v2) + val producer2 = TestUtils.createProducer(TestUtils.getBrokerListStrFromServers(servers), + compressionType = ZStdCompressionCodec.name, + keySerializer = new StringSerializer, + valueSerializer = new StringSerializer) + producer2.send(new ProducerRecord(topicPartition.topic, topicPartition.partition, + "key2", "value2")).get + producer2.send(new ProducerRecord(topicPartition.topic, topicPartition.partition, + "key3", "value3")).get + producer2.close() + + // fetch request with fetch version v1 (magic 0): + // gzip compressed record is returned with down-conversion. + // zstd compressed record raises UNSUPPORTED_COMPRESSION_TYPE error. + val req0 = new FetchRequest.Builder(0, 1, -1, Int.MaxValue, 0, + createPartitionMap(300, Seq(topicPartition), Map.empty)) + .setMaxBytes(800) + .build() + + val res0 = sendFetchRequest(leaderId, req0) + val data0 = res0.responseData(topicNames, 1).get(topicPartition) + assertEquals(Errors.NONE.code, data0.errorCode) + assertEquals(1, records(data0).size) + + val req1 = new FetchRequest.Builder(0, 1, -1, Int.MaxValue, 0, + createPartitionMap(300, Seq(topicPartition), Map(topicPartition -> 1L))) + .setMaxBytes(800).build() + + val res1 = sendFetchRequest(leaderId, req1) + val data1 = res1.responseData(topicNames, 1).get(topicPartition) + assertEquals(Errors.UNSUPPORTED_COMPRESSION_TYPE.code, data1.errorCode) + + // fetch request with fetch version v3 (magic 1): + // gzip compressed record is returned with down-conversion. + // zstd compressed record raises UNSUPPORTED_COMPRESSION_TYPE error. 
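+    // Older message formats (magic 0/1) define no ZStandard codec, so the broker cannot down-convert the zstd batches for these consumers and reports an error for them instead of returning converted records.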
+ val req2 = new FetchRequest.Builder(2, 3, -1, Int.MaxValue, 0, + createPartitionMap(300, Seq(topicPartition), Map.empty)) + .setMaxBytes(800).build() + + val res2 = sendFetchRequest(leaderId, req2) + val data2 = res2.responseData(topicNames, 3).get(topicPartition) + assertEquals(Errors.NONE.code, data2.errorCode) + assertEquals(1, records(data2).size) + + val req3 = new FetchRequest.Builder(0, 1, -1, Int.MaxValue, 0, + createPartitionMap(300, Seq(topicPartition), Map(topicPartition -> 1L))) + .setMaxBytes(800).build() + + val res3 = sendFetchRequest(leaderId, req3) + val data3 = res3.responseData(topicNames, 1).get(topicPartition) + assertEquals(Errors.UNSUPPORTED_COMPRESSION_TYPE.code, data3.errorCode) + + // fetch request with version 10: works fine! + val req4 = new FetchRequest.Builder(0, 10, -1, Int.MaxValue, 0, + createPartitionMap(300, Seq(topicPartition), Map.empty)) + .setMaxBytes(800).build() + val res4 = sendFetchRequest(leaderId, req4) + val data4 = res4.responseData(topicNames, 10).get(topicPartition) + assertEquals(Errors.NONE.code, data4.errorCode) + assertEquals(3, records(data4).size) + + val req5 = new FetchRequest.Builder(0, ApiKeys.FETCH.latestVersion(), -1, Int.MaxValue, 0, + createPartitionMap(300, Seq(topicPartition), Map.empty)) + .setMaxBytes(800).build() + val res5 = sendFetchRequest(leaderId, req5) + val data5 = res5.responseData(topicNames, ApiKeys.FETCH.latestVersion()).get(topicPartition) + assertEquals(Errors.NONE.code, data5.errorCode) + assertEquals(3, records(data5).size) + } + + private def checkFetchResponse(expectedPartitions: Seq[TopicPartition], fetchResponse: FetchResponse, + maxPartitionBytes: Int, maxResponseBytes: Int, numMessagesPerPartition: Int, + responseVersion: Short = ApiKeys.FETCH.latestVersion()): Unit = { + val topicNames = getTopicIds().map(_.swap).asJava + val responseData = fetchResponse.responseData(topicNames, responseVersion) + assertEquals(expectedPartitions, responseData.keySet.asScala.toSeq) + var emptyResponseSeen = false + var responseSize = 0 + var responseBufferSize = 0 + + expectedPartitions.foreach { tp => + val partitionData = responseData.get(tp) + assertEquals(Errors.NONE.code, partitionData.errorCode) + assertTrue(partitionData.highWatermark > 0) + + val records = FetchResponse.recordsOrFail(partitionData) + responseBufferSize += records.sizeInBytes + + val batches = records.batches.asScala.toBuffer + assertTrue(batches.size < numMessagesPerPartition) + val batchesSize = batches.map(_.sizeInBytes).sum + responseSize += batchesSize + if (batchesSize == 0 && !emptyResponseSeen) { + assertEquals(0, records.sizeInBytes) + emptyResponseSeen = true + } + else if (batchesSize != 0 && !emptyResponseSeen) { + assertTrue(batchesSize <= maxPartitionBytes) + assertEquals(maxPartitionBytes, records.sizeInBytes) + } + else if (batchesSize != 0 && emptyResponseSeen) + fail(s"Expected partition with size 0, but found $tp with size $batchesSize") + else if (records.sizeInBytes != 0 && emptyResponseSeen) + fail(s"Expected partition buffer with size 0, but found $tp with size ${records.sizeInBytes}") + } + + assertEquals(maxResponseBytes - maxResponseBytes % maxPartitionBytes, responseBufferSize) + assertTrue(responseSize <= maxResponseBytes) + } + +} diff --git a/core/src/test/scala/unit/kafka/server/FetchRequestWithLegacyMessageFormatTest.scala b/core/src/test/scala/unit/kafka/server/FetchRequestWithLegacyMessageFormatTest.scala new file mode 100644 index 0000000..2f78b9d --- /dev/null +++ 
b/core/src/test/scala/unit/kafka/server/FetchRequestWithLegacyMessageFormatTest.scala @@ -0,0 +1,70 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package kafka.server + +import kafka.api.KAFKA_0_10_2_IV0 +import kafka.log.LogConfig +import org.apache.kafka.clients.producer.ProducerRecord +import org.apache.kafka.common.protocol.Errors +import org.apache.kafka.common.requests.{FetchRequest, FetchResponse} +import org.junit.jupiter.api.Assertions.{assertEquals, assertTrue} +import org.junit.jupiter.api.Test + +import java.util.Properties + +import scala.annotation.nowarn +import scala.collection.Seq +import scala.jdk.CollectionConverters._ + +class FetchRequestWithLegacyMessageFormatTest extends BaseFetchRequestTest { + + override def brokerPropertyOverrides(properties: Properties): Unit = { + super.brokerPropertyOverrides(properties) + // legacy message formats are only supported with IBP < 3.0 + properties.put(KafkaConfig.InterBrokerProtocolVersionProp, "2.8") + } + + /** + * Fetch request v2 (pre KIP-74) respected `maxPartitionBytes` even if no message could be returned + * due to a message that was larger than `maxPartitionBytes`. 
+ */ + @nowarn("cat=deprecation") + @Test + def testFetchRequestV2WithOversizedMessage(): Unit = { + initProducer() + val maxPartitionBytes = 200 + // Fetch v2 down-converts if the message format is >= 0.11 and we want to avoid + // that as it affects the size of the returned buffer + val topicConfig = Map(LogConfig.MessageFormatVersionProp -> KAFKA_0_10_2_IV0.version) + val (topicPartition, leaderId) = createTopics(numTopics = 1, numPartitions = 1, topicConfig).head + val topicIds = getTopicIds().asJava + val topicNames = topicIds.asScala.map(_.swap).asJava + producer.send(new ProducerRecord(topicPartition.topic, topicPartition.partition, + "key", new String(new Array[Byte](maxPartitionBytes + 1)))).get + val fetchVersion: Short = 2 + val fetchRequest = FetchRequest.Builder.forConsumer(fetchVersion, Int.MaxValue, 0, + createPartitionMap(maxPartitionBytes, Seq(topicPartition))).build(fetchVersion) + val fetchResponse = sendFetchRequest(leaderId, fetchRequest) + val partitionData = fetchResponse.responseData(topicNames, fetchVersion).get(topicPartition) + assertEquals(Errors.NONE.code, partitionData.errorCode) + + assertTrue(partitionData.highWatermark > 0) + assertEquals(maxPartitionBytes, FetchResponse.recordsSize(partitionData)) + assertEquals(0, records(partitionData).map(_.sizeInBytes).sum) + } + +} diff --git a/core/src/test/scala/unit/kafka/server/FetchSessionTest.scala b/core/src/test/scala/unit/kafka/server/FetchSessionTest.scala new file mode 100755 index 0000000..538a061 --- /dev/null +++ b/core/src/test/scala/unit/kafka/server/FetchSessionTest.scala @@ -0,0 +1,1946 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+*/ +package kafka.server + +import kafka.utils.MockTime +import org.apache.kafka.common.{TopicIdPartition, TopicPartition, Uuid} +import org.apache.kafka.common.message.FetchResponseData +import org.apache.kafka.common.protocol.{ApiKeys, Errors} +import org.apache.kafka.common.record.CompressionType +import org.apache.kafka.common.record.MemoryRecords +import org.apache.kafka.common.record.SimpleRecord +import org.apache.kafka.common.requests.FetchMetadata.{FINAL_EPOCH, INVALID_SESSION_ID} +import org.apache.kafka.common.requests.{FetchRequest, FetchResponse, FetchMetadata => JFetchMetadata} +import org.apache.kafka.common.utils.Utils +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.{Test, Timeout} +import org.junit.jupiter.params.ParameterizedTest +import org.junit.jupiter.params.provider.{Arguments, MethodSource, ValueSource} + +import scala.jdk.CollectionConverters._ +import java.util +import java.util.{Collections, Optional} + +import scala.collection.mutable.ArrayBuffer + +@Timeout(120) +class FetchSessionTest { + + @Test + def testNewSessionId(): Unit = { + val cache = new FetchSessionCache(3, 100) + for (_ <- 0 to 10000) { + val id = cache.newSessionId() + assertTrue(id > 0) + } + } + + def assertCacheContains(cache: FetchSessionCache, sessionIds: Int*) = { + var i = 0 + for (sessionId <- sessionIds) { + i = i + 1 + assertTrue(cache.get(sessionId).isDefined, + "Missing session " + i + " out of " + sessionIds.size + "(" + sessionId + ")") + } + assertEquals(sessionIds.size, cache.size) + } + + private def dummyCreate(size: Int): FetchSession.CACHE_MAP = { + val cacheMap = new FetchSession.CACHE_MAP(size) + for (i <- 0 until size) { + cacheMap.add(new CachedPartition("test", Uuid.randomUuid(), i)) + } + cacheMap + } + + @Test + def testSessionCache(): Unit = { + val cache = new FetchSessionCache(3, 100) + assertEquals(0, cache.size) + val id1 = cache.maybeCreateSession(0, false, 10, true, () => dummyCreate(10)) + val id2 = cache.maybeCreateSession(10, false, 20, true, () => dummyCreate(20)) + val id3 = cache.maybeCreateSession(20, false, 30, true, () => dummyCreate(30)) + assertEquals(INVALID_SESSION_ID, cache.maybeCreateSession(30, false, 40, true, () => dummyCreate(40))) + assertEquals(INVALID_SESSION_ID, cache.maybeCreateSession(40, false, 5, true, () => dummyCreate(5))) + assertCacheContains(cache, id1, id2, id3) + cache.touch(cache.get(id1).get, 200) + val id4 = cache.maybeCreateSession(210, false, 11, true, () => dummyCreate(11)) + assertCacheContains(cache, id1, id3, id4) + cache.touch(cache.get(id1).get, 400) + cache.touch(cache.get(id3).get, 390) + cache.touch(cache.get(id4).get, 400) + val id5 = cache.maybeCreateSession(410, false, 50, true, () => dummyCreate(50)) + assertCacheContains(cache, id3, id4, id5) + assertEquals(INVALID_SESSION_ID, cache.maybeCreateSession(410, false, 5, true, () => dummyCreate(5))) + val id6 = cache.maybeCreateSession(410, true, 5, true, () => dummyCreate(5)) + assertCacheContains(cache, id3, id5, id6) + } + + @Test + def testResizeCachedSessions(): Unit = { + val cache = new FetchSessionCache(2, 100) + assertEquals(0, cache.totalPartitions) + assertEquals(0, cache.size) + assertEquals(0, cache.evictionsMeter.count) + val id1 = cache.maybeCreateSession(0, false, 2, true, () => dummyCreate(2)) + assertTrue(id1 > 0) + assertCacheContains(cache, id1) + val session1 = cache.get(id1).get + assertEquals(2, session1.size) + assertEquals(2, cache.totalPartitions) + assertEquals(1, cache.size) + assertEquals(0, 
cache.evictionsMeter.count) + val id2 = cache.maybeCreateSession(0, false, 4, true, () => dummyCreate(4)) + val session2 = cache.get(id2).get + assertTrue(id2 > 0) + assertCacheContains(cache, id1, id2) + assertEquals(6, cache.totalPartitions) + assertEquals(2, cache.size) + assertEquals(0, cache.evictionsMeter.count) + cache.touch(session1, 200) + cache.touch(session2, 200) + val id3 = cache.maybeCreateSession(200, false, 5, true, () => dummyCreate(5)) + assertTrue(id3 > 0) + assertCacheContains(cache, id2, id3) + assertEquals(9, cache.totalPartitions) + assertEquals(2, cache.size) + assertEquals(1, cache.evictionsMeter.count) + cache.remove(id3) + assertCacheContains(cache, id2) + assertEquals(1, cache.size) + assertEquals(1, cache.evictionsMeter.count) + assertEquals(4, cache.totalPartitions) + val iter = session2.partitionMap.iterator + iter.next() + iter.remove() + assertEquals(3, session2.size) + assertEquals(4, session2.cachedSize) + cache.touch(session2, session2.lastUsedMs) + assertEquals(3, cache.totalPartitions) + } + + private val EMPTY_PART_LIST = Collections.unmodifiableList(new util.ArrayList[TopicIdPartition]()) + + def createRequest(metadata: JFetchMetadata, + fetchData: util.Map[TopicPartition, FetchRequest.PartitionData], + toForget: util.List[TopicIdPartition], isFromFollower: Boolean, + version: Short = ApiKeys.FETCH.latestVersion): FetchRequest = { + new FetchRequest.Builder(version, version, if (isFromFollower) 1 else FetchRequest.CONSUMER_REPLICA_ID, + 0, 0, fetchData).metadata(metadata).removed(toForget).build + } + + def createRequestWithoutTopicIds(metadata: JFetchMetadata, + fetchData: util.Map[TopicPartition, FetchRequest.PartitionData], + toForget: util.List[TopicIdPartition], isFromFollower: Boolean): FetchRequest = { + new FetchRequest.Builder(12, 12, if (isFromFollower) 1 else FetchRequest.CONSUMER_REPLICA_ID, + 0, 0, fetchData).metadata(metadata).removed(toForget).build + } + + @Test + def testCachedLeaderEpoch(): Unit = { + val time = new MockTime() + val cache = new FetchSessionCache(10, 1000) + val fetchManager = new FetchManager(time, cache) + + val topicIds = Map("foo" -> Uuid.randomUuid(), "bar" -> Uuid.randomUuid()).asJava + val tp0 = new TopicIdPartition(topicIds.get("foo"), new TopicPartition("foo", 0)) + val tp1 = new TopicIdPartition(topicIds.get("foo"), new TopicPartition("foo", 1)) + val tp2 = new TopicIdPartition(topicIds.get("bar"), new TopicPartition("bar", 1)) + val topicNames = topicIds.asScala.map(_.swap).asJava + + def cachedLeaderEpochs(context: FetchContext): Map[TopicIdPartition, Optional[Integer]] = { + val mapBuilder = Map.newBuilder[TopicIdPartition, Optional[Integer]] + context.foreachPartition((tp, data) => mapBuilder += tp -> data.currentLeaderEpoch) + mapBuilder.result() + } + + val requestData1 = new util.LinkedHashMap[TopicPartition, FetchRequest.PartitionData] + requestData1.put(tp0.topicPartition, new FetchRequest.PartitionData(tp0.topicId, 0, 0, 100, Optional.empty())) + requestData1.put(tp1.topicPartition, new FetchRequest.PartitionData(tp1.topicId, 10, 0, 100, Optional.of(1))) + requestData1.put(tp2.topicPartition, new FetchRequest.PartitionData(tp2.topicId, 10, 0, 100, Optional.of(2))) + + val request1 = createRequest(JFetchMetadata.INITIAL, requestData1, EMPTY_PART_LIST, false) + val context1 = fetchManager.newContext( + request1.version, + request1.metadata, + request1.isFromFollower, + request1.fetchData(topicNames), + request1.forgottenTopics(topicNames), + topicNames + ) + val epochs1 = 
cachedLeaderEpochs(context1)
+    assertEquals(Optional.empty(), epochs1(tp0))
+    assertEquals(Optional.of(1), epochs1(tp1))
+    assertEquals(Optional.of(2), epochs1(tp2))
+
+    val response = new util.LinkedHashMap[TopicIdPartition, FetchResponseData.PartitionData]
+    response.put(tp0, new FetchResponseData.PartitionData()
+      .setPartitionIndex(tp0.partition)
+      .setHighWatermark(100)
+      .setLastStableOffset(100)
+      .setLogStartOffset(100))
+    response.put(tp1, new FetchResponseData.PartitionData()
+      .setPartitionIndex(tp1.partition)
+      .setHighWatermark(10)
+      .setLastStableOffset(10)
+      .setLogStartOffset(10))
+    response.put(tp2, new FetchResponseData.PartitionData()
+      .setPartitionIndex(tp2.partition)
+      .setHighWatermark(5)
+      .setLastStableOffset(5)
+      .setLogStartOffset(5))
+
+    val sessionId = context1.updateAndGenerateResponseData(response).sessionId()
+
+    // With no changes, the cached epochs should remain the same
+    val requestData2 = new util.LinkedHashMap[TopicPartition, FetchRequest.PartitionData]
+    val request2 = createRequest(new JFetchMetadata(sessionId, 1), requestData2, EMPTY_PART_LIST, false)
+    val context2 = fetchManager.newContext(
+      request2.version,
+      request2.metadata,
+      request2.isFromFollower,
+      request2.fetchData(topicNames),
+      request2.forgottenTopics(topicNames),
+      topicNames
+    )
+    val epochs2 = cachedLeaderEpochs(context2)
+    assertEquals(Optional.empty(), epochs2(tp0))
+    assertEquals(Optional.of(1), epochs2(tp1))
+    assertEquals(Optional.of(2), epochs2(tp2))
+    context2.updateAndGenerateResponseData(response).sessionId()
+
+    // Now verify we can change the leader epoch and the context is updated
+    val requestData3 = new util.LinkedHashMap[TopicPartition, FetchRequest.PartitionData]
+    requestData3.put(tp0.topicPartition, new FetchRequest.PartitionData(tp0.topicId, 0, 0, 100, Optional.of(6)))
+    requestData3.put(tp1.topicPartition, new FetchRequest.PartitionData(tp1.topicId, 10, 0, 100, Optional.empty()))
+    requestData3.put(tp2.topicPartition, new FetchRequest.PartitionData(tp2.topicId, 10, 0, 100, Optional.of(3)))
+
+    val request3 = createRequest(new JFetchMetadata(sessionId, 2), requestData3, EMPTY_PART_LIST, false)
+    val context3 = fetchManager.newContext(
+      request3.version,
+      request3.metadata,
+      request3.isFromFollower,
+      request3.fetchData(topicNames),
+      request3.forgottenTopics(topicNames),
+      topicNames
+    )
+    val epochs3 = cachedLeaderEpochs(context3)
+    assertEquals(Optional.of(6), epochs3(tp0))
+    assertEquals(Optional.empty(), epochs3(tp1))
+    assertEquals(Optional.of(3), epochs3(tp2))
+  }
+
+  @Test
+  def testLastFetchedEpoch(): Unit = {
+    val time = new MockTime()
+    val cache = new FetchSessionCache(10, 1000)
+    val fetchManager = new FetchManager(time, cache)
+
+    val topicIds = Map("foo" -> Uuid.randomUuid(), "bar" -> Uuid.randomUuid()).asJava
+    val topicNames = topicIds.asScala.map(_.swap).asJava
+    val tp0 = new TopicIdPartition(topicIds.get("foo"), new TopicPartition("foo", 0))
+    val tp1 = new TopicIdPartition(topicIds.get("foo"), new TopicPartition("foo", 1))
+    val tp2 = new TopicIdPartition(topicIds.get("bar"), new TopicPartition("bar", 1))
+
+    def cachedLeaderEpochs(context: FetchContext): Map[TopicIdPartition, Optional[Integer]] = {
+      val mapBuilder = Map.newBuilder[TopicIdPartition, Optional[Integer]]
+      context.foreachPartition((tp, data) => mapBuilder += tp -> data.currentLeaderEpoch)
+      mapBuilder.result()
+    }
+
+    def cachedLastFetchedEpochs(context: FetchContext): Map[TopicIdPartition, Optional[Integer]] = {
+      val mapBuilder = Map.newBuilder[TopicIdPartition,
Optional[Integer]] + context.foreachPartition((tp, data) => mapBuilder += tp -> data.lastFetchedEpoch) + mapBuilder.result() + } + + val requestData1 = new util.LinkedHashMap[TopicPartition, FetchRequest.PartitionData] + requestData1.put(tp0.topicPartition, new FetchRequest.PartitionData(tp0.topicId, 0, 0, 100, Optional.empty[Integer], Optional.empty[Integer])) + requestData1.put(tp1.topicPartition, new FetchRequest.PartitionData(tp1.topicId, 10, 0, 100, Optional.of(1), Optional.empty[Integer])) + requestData1.put(tp2.topicPartition, new FetchRequest.PartitionData(tp2.topicId, 10, 0, 100, Optional.of(2), Optional.of(1))) + + val request1 = createRequest(JFetchMetadata.INITIAL, requestData1, EMPTY_PART_LIST, false) + val context1 = fetchManager.newContext( + request1.version, + request1.metadata, + request1.isFromFollower, + request1.fetchData(topicNames), + request1.forgottenTopics(topicNames), + topicNames + ) + assertEquals(Map(tp0 -> Optional.empty, tp1 -> Optional.of(1), tp2 -> Optional.of(2)), + cachedLeaderEpochs(context1)) + assertEquals(Map(tp0 -> Optional.empty, tp1 -> Optional.empty, tp2 -> Optional.of(1)), + cachedLastFetchedEpochs(context1)) + + val response = new util.LinkedHashMap[TopicIdPartition, FetchResponseData.PartitionData] + response.put(tp0, new FetchResponseData.PartitionData() + .setPartitionIndex(tp0.partition) + .setHighWatermark(100) + .setLastStableOffset(100) + .setLogStartOffset(100)) + response.put(tp1, new FetchResponseData.PartitionData() + .setPartitionIndex(tp1.partition) + .setHighWatermark(10) + .setLastStableOffset(10) + .setLogStartOffset(10)) + response.put(tp2, new FetchResponseData.PartitionData() + .setPartitionIndex(tp2.partition) + .setHighWatermark(5) + .setLastStableOffset(5) + .setLogStartOffset(5)) + + val sessionId = context1.updateAndGenerateResponseData(response).sessionId() + + // With no changes, the cached epochs should remain the same + val requestData2 = new util.LinkedHashMap[TopicPartition, FetchRequest.PartitionData] + val request2 = createRequest(new JFetchMetadata(sessionId, 1), requestData2, EMPTY_PART_LIST, false) + val context2 = fetchManager.newContext( + request2.version, + request2.metadata, + request2.isFromFollower, + request2.fetchData(topicNames), + request2.forgottenTopics(topicNames), + topicNames + ) + assertEquals(Map(tp0 -> Optional.empty, tp1 -> Optional.of(1), tp2 -> Optional.of(2)), cachedLeaderEpochs(context2)) + assertEquals(Map(tp0 -> Optional.empty, tp1 -> Optional.empty, tp2 -> Optional.of(1)), + cachedLastFetchedEpochs(context2)) + context2.updateAndGenerateResponseData(response).sessionId() + + // Now verify we can change the leader epoch and the context is updated + val requestData3 = new util.LinkedHashMap[TopicPartition, FetchRequest.PartitionData] + requestData3.put(tp0.topicPartition, new FetchRequest.PartitionData(tp0.topicId, 0, 0, 100, Optional.of(6), Optional.of(5))) + requestData3.put(tp1.topicPartition, new FetchRequest.PartitionData(tp1.topicId, 10, 0, 100, Optional.empty[Integer], Optional.empty[Integer])) + requestData3.put(tp2.topicPartition, new FetchRequest.PartitionData(tp2.topicId, 10, 0, 100, Optional.of(3), Optional.of(3))) + + val request3 = createRequest(new JFetchMetadata(sessionId, 2), requestData3, EMPTY_PART_LIST, false) + val context3 = fetchManager.newContext( + request3.version, + request3.metadata, + request3.isFromFollower, + request3.fetchData(topicNames), + request3.forgottenTopics(topicNames), + topicNames + ) + assertEquals(Map(tp0 -> Optional.of(6), tp1 -> 
Optional.empty, tp2 -> Optional.of(3)), + cachedLeaderEpochs(context3)) + assertEquals(Map(tp0 -> Optional.of(5), tp1 -> Optional.empty, tp2 -> Optional.of(3)), + cachedLastFetchedEpochs(context2)) + } + + @Test + def testFetchRequests(): Unit = { + val time = new MockTime() + val cache = new FetchSessionCache(10, 1000) + val fetchManager = new FetchManager(time, cache) + val topicNames = Map(Uuid.randomUuid() -> "foo", Uuid.randomUuid() -> "bar").asJava + val topicIds = topicNames.asScala.map(_.swap).asJava + val tp0 = new TopicIdPartition(topicIds.get("foo"), new TopicPartition("foo", 0)) + val tp1 = new TopicIdPartition(topicIds.get("foo"), new TopicPartition("foo", 1)) + val tp2 = new TopicIdPartition(topicIds.get("bar"), new TopicPartition("bar", 0)) + val tp3 = new TopicIdPartition(topicIds.get("bar"), new TopicPartition("bar", 1)) + + // Verify that SESSIONLESS requests get a SessionlessFetchContext + val request = createRequest(JFetchMetadata.LEGACY, new util.HashMap[TopicPartition, FetchRequest.PartitionData](), EMPTY_PART_LIST, true) + val context = fetchManager.newContext( + request.version, + request.metadata, + request.isFromFollower, + request.fetchData(topicNames), + request.forgottenTopics(topicNames), + topicNames + ) + assertEquals(classOf[SessionlessFetchContext], context.getClass) + + // Create a new fetch session with a FULL fetch request + val reqData2 = new util.LinkedHashMap[TopicPartition, FetchRequest.PartitionData] + reqData2.put(tp0.topicPartition, new FetchRequest.PartitionData(tp0.topicId, 0, 0, 100, + Optional.empty())) + reqData2.put(tp1.topicPartition, new FetchRequest.PartitionData(tp1.topicId, 10, 0, 100, + Optional.empty())) + val request2 = createRequest(JFetchMetadata.INITIAL, reqData2, EMPTY_PART_LIST, false) + val context2 = fetchManager.newContext( + request2.version, + request2.metadata, + request2.isFromFollower, + request2.fetchData(topicNames), + request2.forgottenTopics(topicNames), + topicNames + ) + assertEquals(classOf[FullFetchContext], context2.getClass) + val reqData2Iter = reqData2.entrySet().iterator() + context2.foreachPartition((topicIdPart, data) => { + val entry = reqData2Iter.next() + assertEquals(entry.getKey, topicIdPart.topicPartition) + assertEquals(topicIds.get(entry.getKey.topic), topicIdPart.topicId) + assertEquals(entry.getValue, data) + }) + assertEquals(0, context2.getFetchOffset(tp0).get) + assertEquals(10, context2.getFetchOffset(tp1).get) + val respData2 = new util.LinkedHashMap[TopicIdPartition, FetchResponseData.PartitionData] + respData2.put(tp0, + new FetchResponseData.PartitionData() + .setPartitionIndex(0) + .setHighWatermark(100) + .setLastStableOffset(100) + .setLogStartOffset(100)) + respData2.put(tp1, + new FetchResponseData.PartitionData() + .setPartitionIndex(1) + .setHighWatermark(10) + .setLastStableOffset(10) + .setLogStartOffset(10)) + val resp2 = context2.updateAndGenerateResponseData(respData2) + assertEquals(Errors.NONE, resp2.error()) + assertTrue(resp2.sessionId() != INVALID_SESSION_ID) + assertEquals(respData2.asScala.map { case (tp, data) => (tp.topicPartition, data)}.toMap.asJava, resp2.responseData(topicNames, request2.version)) + + // Test trying to create a new session with an invalid epoch + val request3 = createRequest(new JFetchMetadata(resp2.sessionId(), 5), reqData2, EMPTY_PART_LIST, false) + val context3 = fetchManager.newContext( + request3.version, + request3.metadata, + request3.isFromFollower, + request3.fetchData(topicNames), + request3.forgottenTopics(topicNames), + topicNames + ) 
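+    // The session epoch must advance by exactly one on each follow-up request; presenting epoch 5 here (instead of 1) should yield a SessionErrorContext.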
+ assertEquals(classOf[SessionErrorContext], context3.getClass) + assertEquals(Errors.INVALID_FETCH_SESSION_EPOCH, + context3.updateAndGenerateResponseData(respData2).error()) + + // Test trying to create a new session with a non-existent session id + val request4 = createRequest(new JFetchMetadata(resp2.sessionId() + 1, 1), reqData2, EMPTY_PART_LIST, false) + val context4 = fetchManager.newContext( + request4.version, + request4.metadata, + request4.isFromFollower, + request4.fetchData(topicNames), + request4.forgottenTopics(topicNames), + topicNames + ) + assertEquals(Errors.FETCH_SESSION_ID_NOT_FOUND, + context4.updateAndGenerateResponseData(respData2).error()) + + // Continue the first fetch session we created. + val reqData5 = new util.LinkedHashMap[TopicPartition, FetchRequest.PartitionData] + val request5 = createRequest( new JFetchMetadata(resp2.sessionId(), 1), reqData5, EMPTY_PART_LIST, false) + val context5 = fetchManager.newContext( + request5.version, + request5.metadata, + request5.isFromFollower, + request5.fetchData(topicNames), + request5.forgottenTopics(topicNames), + topicNames + ) + assertEquals(classOf[IncrementalFetchContext], context5.getClass) + val reqData5Iter = reqData2.entrySet().iterator() + context5.foreachPartition((topicIdPart, data) => { + val entry = reqData5Iter.next() + assertEquals(entry.getKey, topicIdPart.topicPartition) + assertEquals(topicIds.get(entry.getKey.topic()), topicIdPart.topicId) + assertEquals(entry.getValue, data) + }) + assertEquals(10, context5.getFetchOffset(tp1).get) + val resp5 = context5.updateAndGenerateResponseData(respData2) + assertEquals(Errors.NONE, resp5.error()) + assertEquals(resp2.sessionId(), resp5.sessionId()) + assertEquals(0, resp5.responseData(topicNames, request5.version).size()) + + // Test setting an invalid fetch session epoch. + val request6 = createRequest( new JFetchMetadata(resp2.sessionId(), 5), reqData2, EMPTY_PART_LIST, false) + val context6 = fetchManager.newContext( + request6.version, + request6.metadata, + request6.isFromFollower, + request6.fetchData(topicNames), + request6.forgottenTopics(topicNames), + topicNames + ) + assertEquals(classOf[SessionErrorContext], context6.getClass) + assertEquals(Errors.INVALID_FETCH_SESSION_EPOCH, + context6.updateAndGenerateResponseData(respData2).error()) + + // Test generating a throttled response for the incremental fetch session + val reqData7 = new util.LinkedHashMap[TopicPartition, FetchRequest.PartitionData] + val request7 = createRequest( new JFetchMetadata(resp2.sessionId(), 2), reqData7, EMPTY_PART_LIST, false) + val context7 = fetchManager.newContext( + request7.version, + request7.metadata, + request7.isFromFollower, + request7.fetchData(topicNames), + request7.forgottenTopics(topicNames), + topicNames + ) + val resp7 = context7.getThrottledResponse(100) + assertEquals(Errors.NONE, resp7.error()) + assertEquals(resp2.sessionId(), resp7.sessionId()) + assertEquals(100, resp7.throttleTimeMs()) + + // Close the incremental fetch session. 
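+    // A request sent with FINAL_EPOCH is handled by a SessionlessFetchContext and removes the cached session (cache.size drops to 0 below).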
+ val prevSessionId = resp5.sessionId + var nextSessionId = prevSessionId + do { + val reqData8 = new util.LinkedHashMap[TopicPartition, FetchRequest.PartitionData] + reqData8.put(tp2.topicPartition, new FetchRequest.PartitionData(tp2.topicId, 0, 0, 100, + Optional.empty())) + reqData8.put(tp3.topicPartition, new FetchRequest.PartitionData(tp3.topicId, 10, 0, 100, + Optional.empty())) + val request8 = createRequest(new JFetchMetadata(prevSessionId, FINAL_EPOCH), reqData8, EMPTY_PART_LIST, false) + val context8 = fetchManager.newContext( + request8.version, + request8.metadata, + request8.isFromFollower, + request8.fetchData(topicNames), + request8.forgottenTopics(topicNames), + topicNames + ) + assertEquals(classOf[SessionlessFetchContext], context8.getClass) + assertEquals(0, cache.size) + val respData8 = new util.LinkedHashMap[TopicIdPartition, FetchResponseData.PartitionData] + respData8.put(tp2, + new FetchResponseData.PartitionData() + .setPartitionIndex(0) + .setHighWatermark(100) + .setLastStableOffset(100) + .setLogStartOffset(100)) + respData8.put(tp3, + new FetchResponseData.PartitionData() + .setPartitionIndex(1) + .setHighWatermark(100) + .setLastStableOffset(100) + .setLogStartOffset(100)) + val resp8 = context8.updateAndGenerateResponseData(respData8) + assertEquals(Errors.NONE, resp8.error) + nextSessionId = resp8.sessionId + } while (nextSessionId == prevSessionId) + } + + @ParameterizedTest + @ValueSource(booleans = Array(true, false)) + def testIncrementalFetchSession(usesTopicIds: Boolean): Unit = { + val time = new MockTime() + val cache = new FetchSessionCache(10, 1000) + val fetchManager = new FetchManager(time, cache) + val topicNames = if (usesTopicIds) Map(Uuid.randomUuid() -> "foo", Uuid.randomUuid() -> "bar").asJava else Map[Uuid, String]().asJava + val topicIds = topicNames.asScala.map(_.swap).asJava + val version = if (usesTopicIds) ApiKeys.FETCH.latestVersion else 12.toShort + val fooId = topicIds.getOrDefault("foo", Uuid.ZERO_UUID) + val barId = topicIds.getOrDefault("bar", Uuid.ZERO_UUID) + val tp0 = new TopicIdPartition(fooId, new TopicPartition("foo", 0)) + val tp1 = new TopicIdPartition(fooId, new TopicPartition("foo", 1)) + val tp2 = new TopicIdPartition(barId, new TopicPartition("bar", 0)) + + // Create a new fetch session with foo-0 and foo-1 + val reqData1 = new util.LinkedHashMap[TopicPartition, FetchRequest.PartitionData] + reqData1.put(tp0.topicPartition, new FetchRequest.PartitionData(fooId,0, 0, 100, + Optional.empty())) + reqData1.put(tp1.topicPartition, new FetchRequest.PartitionData(fooId, 10, 0, 100, + Optional.empty())) + val request1 = createRequest(JFetchMetadata.INITIAL, reqData1, EMPTY_PART_LIST, false, version) + val context1 = fetchManager.newContext( + request1.version, + request1.metadata, + request1.isFromFollower, + request1.fetchData(topicNames), + request1.forgottenTopics(topicNames), + topicNames + ) + assertEquals(classOf[FullFetchContext], context1.getClass) + val respData1 = new util.LinkedHashMap[TopicIdPartition, FetchResponseData.PartitionData] + respData1.put(tp0, new FetchResponseData.PartitionData() + .setPartitionIndex(0) + .setHighWatermark(100) + .setLastStableOffset(100) + .setLogStartOffset(100)) + respData1.put(tp1, new FetchResponseData.PartitionData() + .setPartitionIndex(1) + .setHighWatermark(10) + .setLastStableOffset(10) + .setLogStartOffset(10)) + val resp1 = context1.updateAndGenerateResponseData(respData1) + assertEquals(Errors.NONE, resp1.error()) + assertTrue(resp1.sessionId() != INVALID_SESSION_ID) + 
assertEquals(2, resp1.responseData(topicNames, request1.version).size()) + + // Create an incremental fetch request that removes foo-0 and adds bar-0 + val reqData2 = new util.LinkedHashMap[TopicPartition, FetchRequest.PartitionData] + reqData2.put(tp2.topicPartition, new FetchRequest.PartitionData(barId,15, 0, 0, + Optional.empty())) + val removed2 = new util.ArrayList[TopicIdPartition] + removed2.add(tp0) + val request2 = createRequest(new JFetchMetadata(resp1.sessionId(), 1), reqData2, removed2, false, version) + val context2 = fetchManager.newContext( + request2.version, + request2.metadata, + request2.isFromFollower, + request2.fetchData(topicNames), + request2.forgottenTopics(topicNames), + topicNames + ) + assertEquals(classOf[IncrementalFetchContext], context2.getClass) + val parts2 = Set(tp1, tp2) + val reqData2Iter = parts2.iterator + context2.foreachPartition((topicIdPart, _) => { + assertEquals(reqData2Iter.next(), topicIdPart) + }) + assertEquals(None, context2.getFetchOffset(tp0)) + assertEquals(10, context2.getFetchOffset(tp1).get) + assertEquals(15, context2.getFetchOffset(tp2).get) + assertEquals(None, context2.getFetchOffset(new TopicIdPartition(barId, new TopicPartition("bar", 2)))) + val respData2 = new util.LinkedHashMap[TopicIdPartition, FetchResponseData.PartitionData] + respData2.put(tp1, new FetchResponseData.PartitionData() + .setPartitionIndex(1) + .setHighWatermark(10) + .setLastStableOffset(10) + .setLogStartOffset(10)) + respData2.put(tp2, new FetchResponseData.PartitionData() + .setPartitionIndex(0) + .setHighWatermark(10) + .setLastStableOffset(10) + .setLogStartOffset(10)) + val resp2 = context2.updateAndGenerateResponseData(respData2) + assertEquals(Errors.NONE, resp2.error) + assertEquals(1, resp2.responseData(topicNames, request2.version).size) + assertTrue(resp2.sessionId > 0) + } + + // This test simulates a request without IDs sent to a broker with IDs. + @Test + def testFetchSessionWithUnknownIdOldRequestVersion(): Unit = { + val time = new MockTime() + val cache = new FetchSessionCache(10, 1000) + val fetchManager = new FetchManager(time, cache) + val topicNames = Map(Uuid.randomUuid() -> "foo", Uuid.randomUuid() -> "bar").asJava + val topicIds = topicNames.asScala.map(_.swap).asJava + val tp0 = new TopicIdPartition(topicIds.get("foo"), new TopicPartition("foo", 0)) + val tp1 = new TopicIdPartition(topicIds.get("foo"), new TopicPartition("foo", 1)) + + // Create a new fetch session with foo-0 and foo-1 + val reqData1 = new util.LinkedHashMap[TopicPartition, FetchRequest.PartitionData] + reqData1.put(tp0.topicPartition, new FetchRequest.PartitionData(tp0.topicId, 0, 0, 100, + Optional.empty())) + reqData1.put(tp1.topicPartition, new FetchRequest.PartitionData(Uuid.ZERO_UUID, 10, 0, 100, + Optional.empty())) + val request1 = createRequestWithoutTopicIds(JFetchMetadata.INITIAL, reqData1, EMPTY_PART_LIST, false) + // Simulate unknown topic ID for foo. + val topicNamesOnlyBar = Collections.singletonMap(topicIds.get("bar"), "bar") + // We should not throw error since we have an older request version. 
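+    // Fetch versions below 13 identify partitions by topic name, so the missing name mapping for foo's ID is harmless here.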
+    val context1 = fetchManager.newContext(
+      request1.version,
+      request1.metadata,
+      request1.isFromFollower,
+      request1.fetchData(topicNamesOnlyBar),
+      request1.forgottenTopics(topicNamesOnlyBar),
+      topicNamesOnlyBar
+    )
+    assertEquals(classOf[FullFetchContext], context1.getClass)
+    val respData1 = new util.LinkedHashMap[TopicIdPartition, FetchResponseData.PartitionData]
+    respData1.put(tp0, new FetchResponseData.PartitionData()
+      .setPartitionIndex(0)
+      .setHighWatermark(100)
+      .setLastStableOffset(100)
+      .setLogStartOffset(100))
+    respData1.put(tp1, new FetchResponseData.PartitionData()
+      .setPartitionIndex(1)
+      .setHighWatermark(10)
+      .setLastStableOffset(10)
+      .setLogStartOffset(10))
+    val resp1 = context1.updateAndGenerateResponseData(respData1)
+    // Since we are ignoring IDs, we should have no errors.
+    assertEquals(Errors.NONE, resp1.error())
+    assertTrue(resp1.sessionId() != INVALID_SESSION_ID)
+    assertEquals(2, resp1.responseData(topicNames, request1.version).size)
+    resp1.responseData(topicNames, request1.version).forEach( (_, resp) => assertEquals(Errors.NONE.code, resp.errorCode))
+  }
+
+  @Test
+  def testFetchSessionWithUnknownId(): Unit = {
+    val time = new MockTime()
+    val cache = new FetchSessionCache(10, 1000)
+    val fetchManager = new FetchManager(time, cache)
+    val fooId = Uuid.randomUuid()
+    val barId = Uuid.randomUuid()
+    val zarId = Uuid.randomUuid()
+    val topicNames = Map(fooId -> "foo", barId -> "bar", zarId -> "zar").asJava
+    val foo0 = new TopicIdPartition(fooId, new TopicPartition("foo", 0))
+    val foo1 = new TopicIdPartition(fooId, new TopicPartition("foo", 1))
+    val zar0 = new TopicIdPartition(zarId, new TopicPartition("zar", 0))
+    val emptyFoo0 = new TopicIdPartition(fooId, new TopicPartition(null, 0))
+    val emptyFoo1 = new TopicIdPartition(fooId, new TopicPartition(null, 1))
+    val emptyZar0 = new TopicIdPartition(zarId, new TopicPartition(null, 0))
+
+    // Create a new fetch session with foo-0, foo-1, and zar-0
+    val reqData1 = new util.LinkedHashMap[TopicPartition, FetchRequest.PartitionData]
+    reqData1.put(foo0.topicPartition, new FetchRequest.PartitionData(foo0.topicId, 0, 0, 100,
+      Optional.empty()))
+    reqData1.put(foo1.topicPartition, new FetchRequest.PartitionData(foo1.topicId, 10, 0, 100,
+      Optional.empty()))
+    reqData1.put(zar0.topicPartition, new FetchRequest.PartitionData(zar0.topicId, 10, 0, 100,
+      Optional.empty()))
+    val request1 = createRequest(JFetchMetadata.INITIAL, reqData1, EMPTY_PART_LIST, false)
+    // Simulate unknown topic IDs for foo and zar: only bar is known to the broker.
+    val topicNamesOnlyBar = Collections.singletonMap(barId, "bar")
+    // This request uses the latest version, so the unresolved partitions are kept in the session and reported with UNKNOWN_TOPIC_ID rather than being dropped.
+ val context1 = fetchManager.newContext( + request1.version, + request1.metadata, + request1.isFromFollower, + request1.fetchData(topicNamesOnlyBar), + request1.forgottenTopics(topicNamesOnlyBar), + topicNamesOnlyBar + ) + assertEquals(classOf[FullFetchContext], context1.getClass) + assertPartitionsOrder(context1, Seq(emptyFoo0, emptyFoo1, emptyZar0)) + val respData1 = new util.LinkedHashMap[TopicIdPartition, FetchResponseData.PartitionData] + respData1.put(emptyFoo0, new FetchResponseData.PartitionData() + .setPartitionIndex(0) + .setErrorCode(Errors.UNKNOWN_TOPIC_ID.code)) + respData1.put(emptyFoo1, new FetchResponseData.PartitionData() + .setPartitionIndex(1) + .setErrorCode(Errors.UNKNOWN_TOPIC_ID.code)) + respData1.put(emptyZar0, new FetchResponseData.PartitionData() + .setPartitionIndex(1) + .setErrorCode(Errors.UNKNOWN_TOPIC_ID.code)) + val resp1 = context1.updateAndGenerateResponseData(respData1) + // On the latest request version, we should have unknown topic ID errors. + assertEquals(Errors.NONE, resp1.error()) + assertTrue(resp1.sessionId() != INVALID_SESSION_ID) + assertEquals( + Map( + foo0.topicPartition -> Errors.UNKNOWN_TOPIC_ID.code, + foo1.topicPartition -> Errors.UNKNOWN_TOPIC_ID.code, + zar0.topicPartition() -> Errors.UNKNOWN_TOPIC_ID.code + ), + resp1.responseData(topicNames, request1.version).asScala.map { case (tp, resp) => + tp -> resp.errorCode + } + ) + + // Create an incremental request where we resolve the partitions + val reqData2 = new util.LinkedHashMap[TopicPartition, FetchRequest.PartitionData] + val request2 = createRequest(new JFetchMetadata(resp1.sessionId(), 1), reqData2, EMPTY_PART_LIST, false) + val topicNamesNoZar = Map(fooId -> "foo", barId -> "bar").asJava + val context2 = fetchManager.newContext( + request2.version, + request2.metadata, + request2.isFromFollower, + request2.fetchData(topicNamesNoZar), + request2.forgottenTopics(topicNamesNoZar), + topicNamesNoZar + ) + assertEquals(classOf[IncrementalFetchContext], context2.getClass) + // Topic names in the session but not in the request are lazily resolved via foreachPartition. Resolve foo topic IDs here. + assertPartitionsOrder(context2, Seq(foo0, foo1, emptyZar0)) + val respData2 = new util.LinkedHashMap[TopicIdPartition, FetchResponseData.PartitionData] + respData2.put(foo0, new FetchResponseData.PartitionData() + .setPartitionIndex(0) + .setHighWatermark(100) + .setLastStableOffset(100) + .setLogStartOffset(100)) + respData2.put(foo1, new FetchResponseData.PartitionData() + .setPartitionIndex(1) + .setHighWatermark(10) + .setLastStableOffset(10) + .setLogStartOffset(10)) + respData2.put(emptyZar0, new FetchResponseData.PartitionData() + .setPartitionIndex(1) + .setErrorCode(Errors.UNKNOWN_TOPIC_ID.code)) + val resp2 = context2.updateAndGenerateResponseData(respData2) + // Since we are ignoring IDs, we should have no errors. 
+ assertEquals(Errors.NONE, resp2.error()) + assertTrue(resp2.sessionId() != INVALID_SESSION_ID) + assertEquals(3, resp2.responseData(topicNames, request2.version).size) + assertEquals( + Map( + foo0.topicPartition -> Errors.NONE.code, + foo1.topicPartition -> Errors.NONE.code, + zar0.topicPartition -> Errors.UNKNOWN_TOPIC_ID.code + ), + resp2.responseData(topicNames, request2.version).asScala.map { case (tp, resp) => + tp -> resp.errorCode + } + ) + } + + @Test + def testIncrementalFetchSessionWithIdsWhenSessionDoesNotUseIds() : Unit = { + val time = new MockTime() + val cache = new FetchSessionCache(10, 1000) + val fetchManager = new FetchManager(time, cache) + val topicNames = new util.HashMap[Uuid, String]() + val foo0 = new TopicIdPartition(Uuid.ZERO_UUID, new TopicPartition("foo", 0)) + + // Create a new fetch session with foo-0 + val reqData1 = new util.LinkedHashMap[TopicPartition, FetchRequest.PartitionData] + reqData1.put(foo0.topicPartition, new FetchRequest.PartitionData(Uuid.ZERO_UUID, 0, 0, 100, + Optional.empty())) + val request1 = createRequestWithoutTopicIds(JFetchMetadata.INITIAL, reqData1, EMPTY_PART_LIST, false) + // Start a fetch session using a request version that does not use topic IDs. + val context1 = fetchManager.newContext( + request1.version, + request1.metadata, + request1.isFromFollower, + request1.fetchData(topicNames), + request1.forgottenTopics(topicNames), + topicNames + ) + assertEquals(classOf[FullFetchContext], context1.getClass) + val respData1 = new util.LinkedHashMap[TopicIdPartition, FetchResponseData.PartitionData] + respData1.put(foo0, new FetchResponseData.PartitionData() + .setPartitionIndex(0) + .setHighWatermark(100) + .setLastStableOffset(100) + .setLogStartOffset(100)) + val resp1 = context1.updateAndGenerateResponseData(respData1) + assertEquals(Errors.NONE, resp1.error()) + assertTrue(resp1.sessionId() != INVALID_SESSION_ID) + + // Create an incremental fetch request as though no topics changed. However, send a v13 request. + // Also simulate the topic ID found on the server. 
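+    // A session created without topic IDs cannot be continued by a request version that uses them; the broker reports FETCH_SESSION_TOPIC_ID_ERROR, as asserted below.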
+ val fooId = Uuid.randomUuid() + topicNames.put(fooId, "foo") + val reqData2 = new util.LinkedHashMap[TopicPartition, FetchRequest.PartitionData] + val request2 = createRequest(new JFetchMetadata(resp1.sessionId(), 1), reqData2, EMPTY_PART_LIST, false) + val context2 = fetchManager.newContext( + request2.version, + request2.metadata, + request2.isFromFollower, + request2.fetchData(topicNames), + request2.forgottenTopics(topicNames), + topicNames + ) + + assertEquals(classOf[SessionErrorContext], context2.getClass) + val respData2 = new util.LinkedHashMap[TopicIdPartition, FetchResponseData.PartitionData] + assertEquals(Errors.FETCH_SESSION_TOPIC_ID_ERROR, + context2.updateAndGenerateResponseData(respData2).error()) + } + + @Test + def testIncrementalFetchSessionWithoutIdsWhenSessionUsesIds() : Unit = { + val time = new MockTime() + val cache = new FetchSessionCache(10, 1000) + val fetchManager = new FetchManager(time, cache) + val fooId = Uuid.randomUuid() + val topicNames = new util.HashMap[Uuid, String]() + topicNames.put(fooId, "foo") + val foo0 = new TopicIdPartition(fooId, new TopicPartition("foo", 0)) + + // Create a new fetch session with foo-0 + val reqData1 = new util.LinkedHashMap[TopicPartition, FetchRequest.PartitionData] + reqData1.put(foo0.topicPartition, new FetchRequest.PartitionData(fooId,0, 0, 100, + Optional.empty())) + val request1 = createRequest(JFetchMetadata.INITIAL, reqData1, EMPTY_PART_LIST, false) + // Start a fetch session using a request version that uses topic IDs. + val context1 = fetchManager.newContext( + request1.version, + request1.metadata, + request1.isFromFollower, + request1.fetchData(topicNames), + request1.forgottenTopics(topicNames), + topicNames + ) + assertEquals(classOf[FullFetchContext], context1.getClass) + val respData1 = new util.LinkedHashMap[TopicIdPartition, FetchResponseData.PartitionData] + respData1.put(foo0, new FetchResponseData.PartitionData() + .setPartitionIndex(0) + .setHighWatermark(100) + .setLastStableOffset(100) + .setLogStartOffset(100)) + val resp1 = context1.updateAndGenerateResponseData(respData1) + assertEquals(Errors.NONE, resp1.error()) + assertTrue(resp1.sessionId() != INVALID_SESSION_ID) + + // Create an incremental fetch request as though no topics changed. However, send a v12 request. + // Also simulate the topic ID not found on the server + topicNames.remove(fooId) + val reqData2 = new util.LinkedHashMap[TopicPartition, FetchRequest.PartitionData] + val request2 = createRequestWithoutTopicIds(new JFetchMetadata(resp1.sessionId(), 1), reqData2, EMPTY_PART_LIST, false) + val context2 = fetchManager.newContext( + request2.version, + request2.metadata, + request2.isFromFollower, + request2.fetchData(topicNames), + request2.forgottenTopics(topicNames), + topicNames + ) + + assertEquals(classOf[SessionErrorContext], context2.getClass) + val respData2 = new util.LinkedHashMap[TopicIdPartition, FetchResponseData.PartitionData] + assertEquals(Errors.FETCH_SESSION_TOPIC_ID_ERROR, + context2.updateAndGenerateResponseData(respData2).error()) + } + + // This test simulates a session where the topic ID changes broker side (the one handling the request) in both the metadata cache and the log + // -- as though the topic is deleted and recreated. 
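+  // The cached session still references the old ID, so the affected partition is answered with INCONSISTENT_TOPIC_ID while the session itself stays usable.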
+ @Test + def testFetchSessionUpdateTopicIdsBrokerSide(): Unit = { + val time = new MockTime() + val cache = new FetchSessionCache(10, 1000) + val fetchManager = new FetchManager(time, cache) + val topicNames = Map(Uuid.randomUuid() -> "foo", Uuid.randomUuid() -> "bar").asJava + val topicIds = topicNames.asScala.map(_.swap).asJava + val tp0 = new TopicIdPartition(topicIds.get("foo"), new TopicPartition("foo", 0)) + val tp1 = new TopicIdPartition(topicIds.get("bar"), new TopicPartition("bar", 1)) + + // Create a new fetch session with foo-0 and bar-1 + val reqData1 = new util.LinkedHashMap[TopicPartition, FetchRequest.PartitionData] + reqData1.put(tp0.topicPartition, new FetchRequest.PartitionData(tp0.topicId, 0, 0, 100, + Optional.empty())) + reqData1.put(tp1.topicPartition, new FetchRequest.PartitionData(tp1.topicId, 10, 0, 100, + Optional.empty())) + val request1 = createRequest(JFetchMetadata.INITIAL, reqData1, EMPTY_PART_LIST, false) + // Start a fetch session. Simulate unknown partition foo-0. + val context1 = fetchManager.newContext( + request1.version, + request1.metadata, + request1.isFromFollower, + request1.fetchData(topicNames), + request1.forgottenTopics(topicNames), + topicNames + ) + assertEquals(classOf[FullFetchContext], context1.getClass) + val respData1 = new util.LinkedHashMap[TopicIdPartition, FetchResponseData.PartitionData] + respData1.put(tp1, new FetchResponseData.PartitionData() + .setPartitionIndex(1) + .setHighWatermark(10) + .setLastStableOffset(10) + .setLogStartOffset(10)) + respData1.put(tp0, new FetchResponseData.PartitionData() + .setPartitionIndex(0) + .setHighWatermark(-1) + .setLastStableOffset(-1) + .setLogStartOffset(-1) + .setErrorCode(Errors.UNKNOWN_TOPIC_OR_PARTITION.code)) + val resp1 = context1.updateAndGenerateResponseData(respData1) + assertEquals(Errors.NONE, resp1.error()) + assertTrue(resp1.sessionId() != INVALID_SESSION_ID) + assertEquals(2, resp1.responseData(topicNames, request1.version).size) + + // Create an incremental fetch request as though no topics changed. + val reqData2 = new util.LinkedHashMap[TopicPartition, FetchRequest.PartitionData] + val request2 = createRequest(new JFetchMetadata(resp1.sessionId(), 1), reqData2, EMPTY_PART_LIST, false) + // Simulate ID changing on server. + val topicNamesFooChanged = Map(topicIds.get("bar") -> "bar", Uuid.randomUuid() -> "foo").asJava + val context2 = fetchManager.newContext( + request2.version, + request2.metadata, + request2.isFromFollower, + request2.fetchData(topicNamesFooChanged), + request2.forgottenTopics(topicNamesFooChanged), + topicNamesFooChanged + ) + assertEquals(classOf[IncrementalFetchContext], context2.getClass) + val respData2 = new util.LinkedHashMap[TopicIdPartition, FetchResponseData.PartitionData] + // Likely if the topic ID is different in the broker, it will be different in the log. Simulate the log check finding an inconsistent ID. 
+ respData2.put(tp0, new FetchResponseData.PartitionData() + .setPartitionIndex(0) + .setHighWatermark(-1) + .setLastStableOffset(-1) + .setLogStartOffset(-1) + .setErrorCode(Errors.INCONSISTENT_TOPIC_ID.code)) + val resp2 = context2.updateAndGenerateResponseData(respData2) + + assertEquals(Errors.NONE, resp2.error) + assertTrue(resp2.sessionId > 0) + val responseData2 = resp2.responseData(topicNames, request2.version) + // We should have the inconsistent topic ID error on the partition + assertEquals(Errors.INCONSISTENT_TOPIC_ID.code, responseData2.get(tp0.topicPartition).errorCode) + } + + private def noErrorResponse: FetchResponseData.PartitionData = { + new FetchResponseData.PartitionData() + .setPartitionIndex(1) + .setHighWatermark(10) + .setLastStableOffset(10) + .setLogStartOffset(10) + } + + private def errorResponse(errorCode: Short): FetchResponseData.PartitionData = { + new FetchResponseData.PartitionData() + .setPartitionIndex(0) + .setHighWatermark(-1) + .setLastStableOffset(-1) + .setLogStartOffset(-1) + .setErrorCode(errorCode) + } + + @Test + def testResolveUnknownPartitions(): Unit = { + val time = new MockTime() + val cache = new FetchSessionCache(10, 1000) + val fetchManager = new FetchManager(time, cache) + + def newContext( + metadata: JFetchMetadata, + partitions: Seq[TopicIdPartition], + topicNames: Map[Uuid, String] // Topic ID to name mapping known by the broker. + ): FetchContext = { + val data = new util.LinkedHashMap[TopicPartition, FetchRequest.PartitionData] + partitions.foreach { topicIdPartition => + data.put( + topicIdPartition.topicPartition, + new FetchRequest.PartitionData(topicIdPartition.topicId, 0, 0, 100, Optional.empty()) + ) + } + + val fetchRequest = createRequest(metadata, data, EMPTY_PART_LIST, false) + + fetchManager.newContext( + fetchRequest.version, + fetchRequest.metadata, + fetchRequest.isFromFollower, + fetchRequest.fetchData(topicNames.asJava), + fetchRequest.forgottenTopics(topicNames.asJava), + topicNames.asJava + ) + } + + def updateAndGenerateResponseData( + context: FetchContext + ): Int = { + val data = new util.LinkedHashMap[TopicIdPartition, FetchResponseData.PartitionData] + context.foreachPartition { (topicIdPartition, _) => + data.put( + topicIdPartition, + if (topicIdPartition.topic == null) + errorResponse(Errors.UNKNOWN_TOPIC_ID.code) + else + noErrorResponse + ) + } + context.updateAndGenerateResponseData(data).sessionId + } + + val foo = new TopicIdPartition(Uuid.randomUuid(), new TopicPartition("foo", 0)) + val bar = new TopicIdPartition(Uuid.randomUuid(), new TopicPartition("bar", 0)) + val zar = new TopicIdPartition(Uuid.randomUuid(), new TopicPartition("zar", 0)) + + val fooUnresolved = new TopicIdPartition(foo.topicId, new TopicPartition(null, foo.partition)) + val barUnresolved = new TopicIdPartition(bar.topicId, new TopicPartition(null, bar.partition)) + val zarUnresolved = new TopicIdPartition(zar.topicId, new TopicPartition(null, zar.partition)) + + // The metadata cache does not know about the topic. + val context1 = newContext( + JFetchMetadata.INITIAL, + Seq(foo, bar, zar), + Map.empty[Uuid, String] + ) + + // So the context contains unresolved partitions. + assertEquals(classOf[FullFetchContext], context1.getClass) + assertPartitionsOrder(context1, Seq(fooUnresolved, barUnresolved, zarUnresolved)) + + // The response is sent back to create the session. + val sessionId = updateAndGenerateResponseData(context1) + + // The metadata cache only knows about foo. 
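+    // Partitions whose IDs could not be resolved stay in the session keyed by ID and are re-resolved on each later request as the broker learns the names.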
+ val context2 = newContext( + new JFetchMetadata(sessionId, 1), + Seq.empty, + Map(foo.topicId -> foo.topic) + ) + + // So foo is resolved but not the others. + assertEquals(classOf[IncrementalFetchContext], context2.getClass) + assertPartitionsOrder(context2, Seq(foo, barUnresolved, zarUnresolved)) + + updateAndGenerateResponseData(context2) + + // The metadata cache knows about foo and bar. + val context3 = newContext( + new JFetchMetadata(sessionId, 2), + Seq(bar), + Map(foo.topicId -> foo.topic, bar.topicId -> bar.topic) + ) + + // So foo and bar are resolved. + assertEquals(classOf[IncrementalFetchContext], context3.getClass) + assertPartitionsOrder(context3, Seq(foo, bar, zarUnresolved)) + + updateAndGenerateResponseData(context3) + + // The metadata cache knows about all topics. + val context4 = newContext( + new JFetchMetadata(sessionId, 3), + Seq.empty, + Map(foo.topicId -> foo.topic, bar.topicId -> bar.topic, zar.topicId -> zar.topic) + ) + + // So all topics are resolved. + assertEquals(classOf[IncrementalFetchContext], context4.getClass) + assertPartitionsOrder(context4, Seq(foo, bar, zar)) + + updateAndGenerateResponseData(context4) + + // The metadata cache does not know about the topics anymore (e.g. deleted). + val context5 = newContext( + new JFetchMetadata(sessionId, 4), + Seq.empty, + Map.empty + ) + + // All topics remain resolved. + assertEquals(classOf[IncrementalFetchContext], context5.getClass) + assertPartitionsOrder(context4, Seq(foo, bar, zar)) + } + + // This test simulates trying to forget a topic partition with all possible topic ID usages for both requests. + @ParameterizedTest + @MethodSource(Array("idUsageCombinations")) + def testToForgetPartitions(fooStartsResolved: Boolean, fooEndsResolved: Boolean): Unit = { + val time = new MockTime() + val cache = new FetchSessionCache(10, 1000) + val fetchManager = new FetchManager(time, cache) + + def newContext( + metadata: JFetchMetadata, + partitions: Seq[TopicIdPartition], + toForget: Seq[TopicIdPartition], + topicNames: Map[Uuid, String] // Topic ID to name mapping known by the broker. 
+ ): FetchContext = { + val data = new util.LinkedHashMap[TopicPartition, FetchRequest.PartitionData] + partitions.foreach { topicIdPartition => + data.put( + topicIdPartition.topicPartition, + new FetchRequest.PartitionData(topicIdPartition.topicId, 0, 0, 100, Optional.empty()) + ) + } + + val fetchRequest = createRequest(metadata, data, toForget.toList.asJava, false) + + fetchManager.newContext( + fetchRequest.version, + fetchRequest.metadata, + fetchRequest.isFromFollower, + fetchRequest.fetchData(topicNames.asJava), + fetchRequest.forgottenTopics(topicNames.asJava), + topicNames.asJava + ) + } + + def updateAndGenerateResponseData( + context: FetchContext + ): Int = { + val data = new util.LinkedHashMap[TopicIdPartition, FetchResponseData.PartitionData] + context.foreachPartition { (topicIdPartition, _) => + data.put( + topicIdPartition, + if (topicIdPartition.topic == null) + errorResponse(Errors.UNKNOWN_TOPIC_ID.code) + else + noErrorResponse + ) + } + context.updateAndGenerateResponseData(data).sessionId + } + + val foo = new TopicIdPartition(Uuid.randomUuid(), new TopicPartition("foo", 0)) + val bar = new TopicIdPartition(Uuid.randomUuid(), new TopicPartition("bar", 0)) + + val fooUnresolved = new TopicIdPartition(foo.topicId, new TopicPartition(null, foo.partition)) + val barUnresolved = new TopicIdPartition(bar.topicId, new TopicPartition(null, bar.partition)) + + // Create a new context where foo's resolution depends on fooStartsResolved and bar is unresolved. + val context1Names = if (fooStartsResolved) Map(foo.topicId -> foo.topic) else Map.empty[Uuid, String] + val fooContext1 = if (fooStartsResolved) foo else fooUnresolved + val context1 = newContext( + JFetchMetadata.INITIAL, + Seq(fooContext1, bar), + Seq.empty, + context1Names + ) + + // So the context contains unresolved bar and a resolved foo iff fooStartsResolved + assertEquals(classOf[FullFetchContext], context1.getClass) + assertPartitionsOrder(context1, Seq(fooContext1, barUnresolved)) + + // The response is sent back to create the session. + val sessionId = updateAndGenerateResponseData(context1) + + // Forget foo, but keep bar. Foo's resolution depends on fooEndsResolved and bar stays unresolved. + val context2Names = if (fooEndsResolved) Map(foo.topicId -> foo.topic) else Map.empty[Uuid, String] + val fooContext2 = if (fooEndsResolved) foo else fooUnresolved + val context2 = newContext( + new JFetchMetadata(sessionId, 1), + Seq.empty, + Seq(fooContext2), + context2Names + ) + + // So foo is removed but not the others. + assertEquals(classOf[IncrementalFetchContext], context2.getClass) + assertPartitionsOrder(context2, Seq(barUnresolved)) + + updateAndGenerateResponseData(context2) + + // Now remove bar + val context3 = newContext( + new JFetchMetadata(sessionId, 2), + Seq.empty, + Seq(bar), + Map.empty[Uuid, String] + ) + + // Context is sessionless since it is empty. + assertEquals(classOf[SessionlessFetchContext], context3.getClass) + assertPartitionsOrder(context3, Seq()) + } + + @Test + def testUpdateAndGenerateResponseData(): Unit = { + val time = new MockTime() + val cache = new FetchSessionCache(10, 1000) + val fetchManager = new FetchManager(time, cache) + + def newContext( + metadata: JFetchMetadata, + partitions: Seq[TopicIdPartition], + topicNames: Map[Uuid, String] // Topic ID to name mapping known by the broker. 
+ ): FetchContext = { + val data = new util.LinkedHashMap[TopicPartition, FetchRequest.PartitionData] + partitions.foreach { topicIdPartition => + data.put( + topicIdPartition.topicPartition, + new FetchRequest.PartitionData(topicIdPartition.topicId, 0, 0, 100, Optional.empty()) + ) + } + + val fetchRequest = createRequest(metadata, data, EMPTY_PART_LIST, false) + + fetchManager.newContext( + fetchRequest.version, + fetchRequest.metadata, + fetchRequest.isFromFollower, + fetchRequest.fetchData(topicNames.asJava), + fetchRequest.forgottenTopics(topicNames.asJava), + topicNames.asJava + ) + } + + // Give both topics errors so they will stay in the session. + def updateAndGenerateResponseData( + context: FetchContext + ): FetchResponse = { + val data = new util.LinkedHashMap[TopicIdPartition, FetchResponseData.PartitionData] + context.foreachPartition { (topicIdPartition, _) => + data.put( + topicIdPartition, + if (topicIdPartition.topic == null) + errorResponse(Errors.UNKNOWN_TOPIC_ID.code) + else + errorResponse(Errors.UNKNOWN_TOPIC_OR_PARTITION.code) + ) + } + context.updateAndGenerateResponseData(data) + } + + val foo = new TopicIdPartition(Uuid.randomUuid(), new TopicPartition("foo", 0)) + val bar = new TopicIdPartition(Uuid.randomUuid(), new TopicPartition("bar", 0)) + + // Foo will always be resolved and bar will always not be resolved on the receiving broker. + val receivingBrokerTopicNames = Map(foo.topicId -> foo.topic) + // The sender will know both topics' id to name mappings. + val sendingTopicNames = Map(foo.topicId -> foo.topic, bar.topicId -> bar.topic) + + def checkResponseData(response: FetchResponse): Unit = { + assertEquals( + Map( + foo.topicPartition -> Errors.UNKNOWN_TOPIC_OR_PARTITION.code, + bar.topicPartition -> Errors.UNKNOWN_TOPIC_ID.code, + ), + response.responseData(sendingTopicNames.asJava, ApiKeys.FETCH.latestVersion).asScala.map { case (tp, resp) => + tp -> resp.errorCode + } + ) + } + + // Start with a sessionless context. + val context1 = newContext( + JFetchMetadata.LEGACY, + Seq(foo, bar), + receivingBrokerTopicNames + ) + assertEquals(classOf[SessionlessFetchContext], context1.getClass) + // Check the response can be read as expected. + checkResponseData(updateAndGenerateResponseData(context1)) + + // Now create a full context. + val context2 = newContext( + JFetchMetadata.INITIAL, + Seq(foo, bar), + receivingBrokerTopicNames + ) + assertEquals(classOf[FullFetchContext], context2.getClass) + // We want to get the session ID to build more contexts in this session. + val response2 = updateAndGenerateResponseData(context2) + val sessionId = response2.sessionId + checkResponseData(response2) + + // Now create an incremental context. We re-add foo as though the partition data is updated. In a real broker, the data would update. + val context3 = newContext( + new JFetchMetadata(sessionId, 1), + Seq.empty, + receivingBrokerTopicNames + ) + assertEquals(classOf[IncrementalFetchContext], context3.getClass) + checkResponseData(updateAndGenerateResponseData(context3)) + + // Finally create an error context by using the same epoch + val context4 = newContext( + new JFetchMetadata(sessionId, 1), + Seq.empty, + receivingBrokerTopicNames + ) + assertEquals(classOf[SessionErrorContext], context4.getClass) + // The response should be empty. 
+    assertEquals(Collections.emptyList, updateAndGenerateResponseData(context4).data.responses)
+  }
+
+  @Test
+  def testFetchSessionExpiration(): Unit = {
+    val time = new MockTime()
+    // set maximum entries to 2 to allow for eviction later
+    val cache = new FetchSessionCache(2, 1000)
+    val fetchManager = new FetchManager(time, cache)
+    val fooId = Uuid.randomUuid()
+    val topicNames = Map(fooId -> "foo").asJava
+    val foo0 = new TopicIdPartition(fooId, new TopicPartition("foo", 0))
+    val foo1 = new TopicIdPartition(fooId, new TopicPartition("foo", 1))
+
+    // Create a new fetch session, session 1
+    val session1req = new util.LinkedHashMap[TopicPartition, FetchRequest.PartitionData]
+    session1req.put(foo0.topicPartition, new FetchRequest.PartitionData(fooId, 0, 0, 100,
+      Optional.empty()))
+    session1req.put(foo1.topicPartition, new FetchRequest.PartitionData(fooId, 10, 0, 100,
+      Optional.empty()))
+    val session1request1 = createRequest(JFetchMetadata.INITIAL, session1req, EMPTY_PART_LIST, false)
+    val session1context1 = fetchManager.newContext(
+      session1request1.version,
+      session1request1.metadata,
+      session1request1.isFromFollower,
+      session1request1.fetchData(topicNames),
+      session1request1.forgottenTopics(topicNames),
+      topicNames
+    )
+    assertEquals(classOf[FullFetchContext], session1context1.getClass)
+    val respData1 = new util.LinkedHashMap[TopicIdPartition, FetchResponseData.PartitionData]
+    respData1.put(foo0, new FetchResponseData.PartitionData()
+      .setPartitionIndex(0)
+      .setHighWatermark(100)
+      .setLastStableOffset(100)
+      .setLogStartOffset(100))
+    respData1.put(foo1, new FetchResponseData.PartitionData()
+      .setPartitionIndex(1)
+      .setHighWatermark(10)
+      .setLastStableOffset(10)
+      .setLogStartOffset(10))
+    val session1resp = session1context1.updateAndGenerateResponseData(respData1)
+    assertEquals(Errors.NONE, session1resp.error())
+    assertTrue(session1resp.sessionId() != INVALID_SESSION_ID)
+    assertEquals(2, session1resp.responseData(topicNames, session1request1.version).size)
+
+    // check that the session was entered into the cache
+    assertTrue(cache.get(session1resp.sessionId()).isDefined)
+    time.sleep(500)
+
+    // Create a second new fetch session
+    val session2req = new util.LinkedHashMap[TopicPartition, FetchRequest.PartitionData]
+    session2req.put(foo0.topicPartition, new FetchRequest.PartitionData(fooId, 0, 0, 100,
+      Optional.empty()))
+    session2req.put(foo1.topicPartition, new FetchRequest.PartitionData(fooId, 10, 0, 100,
+      Optional.empty()))
+    val session2request1 = createRequest(JFetchMetadata.INITIAL, session2req, EMPTY_PART_LIST, false)
+    val session2context = fetchManager.newContext(
+      session2request1.version,
+      session2request1.metadata,
+      session2request1.isFromFollower,
+      session2request1.fetchData(topicNames),
+      session2request1.forgottenTopics(topicNames),
+      topicNames
+    )
+    assertEquals(classOf[FullFetchContext], session2context.getClass)
+    val session2RespData = new util.LinkedHashMap[TopicIdPartition, FetchResponseData.PartitionData]
+    session2RespData.put(foo0,
+      new FetchResponseData.PartitionData()
+        .setPartitionIndex(0)
+        .setHighWatermark(100)
+        .setLastStableOffset(100)
+        .setLogStartOffset(100))
+    session2RespData.put(foo1, new FetchResponseData.PartitionData()
+      .setPartitionIndex(1)
+      .setHighWatermark(10)
+      .setLastStableOffset(10)
+      .setLogStartOffset(10))
+    val session2resp = session2context.updateAndGenerateResponseData(session2RespData)
+    assertEquals(Errors.NONE, session2resp.error())
+    assertTrue(session2resp.sessionId() != INVALID_SESSION_ID)
+
assertEquals(2, session2resp.responseData(topicNames, session2request1.version()).size()) + + // both newly created entries are present in cache + assertTrue(cache.get(session1resp.sessionId()).isDefined) + assertTrue(cache.get(session2resp.sessionId()).isDefined) + time.sleep(500) + + // Create an incremental fetch request for session 1 + val session1request2 = createRequest( + new JFetchMetadata(session1resp.sessionId(), 1), + new util.LinkedHashMap[TopicPartition, FetchRequest.PartitionData], + new util.ArrayList[TopicIdPartition], false) + val context1v2 = fetchManager.newContext( + session1request2.version, + session1request2.metadata, + session1request2.isFromFollower, + session1request2.fetchData(topicNames), + session1request2.forgottenTopics(topicNames), + topicNames + ) + assertEquals(classOf[IncrementalFetchContext], context1v2.getClass) + + // total sleep time will now be large enough that fetch session 1 will be evicted if not correctly touched + time.sleep(501) + + // create one final session to test that the least recently used entry is evicted + // the second session should be evicted because the first session was incrementally fetched + // more recently than the second session was created + val session3req = new util.LinkedHashMap[TopicPartition, FetchRequest.PartitionData] + session3req.put(foo0.topicPartition, new FetchRequest.PartitionData(fooId, 0, 0, 100, + Optional.empty())) + session3req.put(foo1.topicPartition, new FetchRequest.PartitionData(fooId,0, 0, 100, + Optional.empty())) + val session3request1 = createRequest(JFetchMetadata.INITIAL, session3req, EMPTY_PART_LIST, false) + val session3context = fetchManager.newContext( + session3request1.version, + session3request1.metadata, + session3request1.isFromFollower, + session3request1.fetchData(topicNames), + session3request1.forgottenTopics(topicNames), + topicNames + ) + assertEquals(classOf[FullFetchContext], session3context.getClass) + val respData3 = new util.LinkedHashMap[TopicIdPartition, FetchResponseData.PartitionData] + respData3.put(new TopicIdPartition(fooId, new TopicPartition("foo", 0)), new FetchResponseData.PartitionData() + .setPartitionIndex(0) + .setHighWatermark(100) + .setLastStableOffset(100) + .setLogStartOffset(100)) + respData3.put(new TopicIdPartition(fooId, new TopicPartition("foo", 1)), + new FetchResponseData.PartitionData() + .setPartitionIndex(1) + .setHighWatermark(10) + .setLastStableOffset(10) + .setLogStartOffset(10)) + val session3resp = session3context.updateAndGenerateResponseData(respData3) + assertEquals(Errors.NONE, session3resp.error()) + assertTrue(session3resp.sessionId() != INVALID_SESSION_ID) + assertEquals(2, session3resp.responseData(topicNames, session3request1.version).size) + + assertTrue(cache.get(session1resp.sessionId()).isDefined) + assertFalse(cache.get(session2resp.sessionId()).isDefined, "session 2 should have been evicted by latest session, as session 1 was used more recently") + assertTrue(cache.get(session3resp.sessionId()).isDefined) + } + + @Test + def testPrivilegedSessionHandling(): Unit = { + val time = new MockTime() + // set maximum entries to 2 to allow for eviction later + val cache = new FetchSessionCache(2, 1000) + val fetchManager = new FetchManager(time, cache) + val fooId = Uuid.randomUuid() + val topicNames = Map(fooId -> "foo").asJava + val foo0 = new TopicIdPartition(fooId, new TopicPartition("foo", 0)) + val foo1 = new TopicIdPartition(fooId, new TopicPartition("foo", 1)) + + // Create a new fetch session, session 1 + val session1req = 
new util.LinkedHashMap[TopicPartition, FetchRequest.PartitionData] + session1req.put(foo0.topicPartition, new FetchRequest.PartitionData(fooId, 0, 0, 100, + Optional.empty())) + session1req.put(foo1.topicPartition, new FetchRequest.PartitionData(fooId, 10, 0, 100, + Optional.empty())) + val session1request = createRequest(JFetchMetadata.INITIAL, session1req, EMPTY_PART_LIST, true) + val session1context = fetchManager.newContext( + session1request.version, + session1request.metadata, + session1request.isFromFollower, + session1request.fetchData(topicNames), + session1request.forgottenTopics(topicNames), + topicNames + ) + assertEquals(classOf[FullFetchContext], session1context.getClass) + val respData1 = new util.LinkedHashMap[TopicIdPartition, FetchResponseData.PartitionData] + respData1.put(foo0, new FetchResponseData.PartitionData() + .setPartitionIndex(0) + .setHighWatermark(100) + .setLastStableOffset(100) + .setLogStartOffset(100)) + respData1.put(foo1, new FetchResponseData.PartitionData() + .setPartitionIndex(1) + .setHighWatermark(10) + .setLastStableOffset(10) + .setLogStartOffset(10)) + val session1resp = session1context.updateAndGenerateResponseData(respData1) + assertEquals(Errors.NONE, session1resp.error()) + assertTrue(session1resp.sessionId() != INVALID_SESSION_ID) + assertEquals(2, session1resp.responseData(topicNames, session1request.version).size) + assertEquals(1, cache.size) + + // move time forward to age session 1 a little compared to session 2 + time.sleep(500) + + // Create a second new fetch session, unprivileged + val session2req = new util.LinkedHashMap[TopicPartition, FetchRequest.PartitionData] + session2req.put(foo0.topicPartition, new FetchRequest.PartitionData(fooId, 0, 0, 100, + Optional.empty())) + session2req.put(foo1.topicPartition, new FetchRequest.PartitionData(fooId, 10, 0, 100, + Optional.empty())) + val session2request = createRequest(JFetchMetadata.INITIAL, session1req, EMPTY_PART_LIST, false) + val session2context = fetchManager.newContext( + session2request.version, + session2request.metadata, + session2request.isFromFollower, + session2request.fetchData(topicNames), + session2request.forgottenTopics(topicNames), + topicNames + ) + assertEquals(classOf[FullFetchContext], session2context.getClass) + val session2RespData = new util.LinkedHashMap[TopicIdPartition, FetchResponseData.PartitionData] + session2RespData.put(foo0, + new FetchResponseData.PartitionData() + .setPartitionIndex(0) + .setHighWatermark(100) + .setLastStableOffset(100) + .setLogStartOffset(100)) + session2RespData.put(foo1, + new FetchResponseData.PartitionData() + .setPartitionIndex(1) + .setHighWatermark(10) + .setLastStableOffset(10) + .setLogStartOffset(10)) + val session2resp = session2context.updateAndGenerateResponseData(session2RespData) + assertEquals(Errors.NONE, session2resp.error()) + assertTrue(session2resp.sessionId() != INVALID_SESSION_ID) + assertEquals(2, session2resp.responseData(topicNames, session2request.version).size) + + // both newly created entries are present in cache + assertTrue(cache.get(session1resp.sessionId()).isDefined) + assertTrue(cache.get(session2resp.sessionId()).isDefined) + assertEquals(2, cache.size) + time.sleep(500) + + // create a session to test session1 privileges mean that session 1 is retained and session 2 is evicted + val session3req = new util.LinkedHashMap[TopicPartition, FetchRequest.PartitionData] + session3req.put(foo0.topicPartition, new FetchRequest.PartitionData(fooId, 0, 0, 100, + Optional.empty())) + 
session3req.put(foo1.topicPartition, new FetchRequest.PartitionData(fooId, 0, 0, 100, + Optional.empty())) + val session3request = createRequest(JFetchMetadata.INITIAL, session3req, EMPTY_PART_LIST, true) + val session3context = fetchManager.newContext( + session3request.version, + session3request.metadata, + session3request.isFromFollower, + session3request.fetchData(topicNames), + session3request.forgottenTopics(topicNames), + topicNames + ) + assertEquals(classOf[FullFetchContext], session3context.getClass) + val respData3 = new util.LinkedHashMap[TopicIdPartition, FetchResponseData.PartitionData] + respData3.put(foo0, + new FetchResponseData.PartitionData() + .setPartitionIndex(0) + .setHighWatermark(100) + .setLastStableOffset(100) + .setLogStartOffset(100)) + respData3.put(foo1, + new FetchResponseData.PartitionData() + .setPartitionIndex(1) + .setHighWatermark(10) + .setLastStableOffset(10) + .setLogStartOffset(10)) + val session3resp = session3context.updateAndGenerateResponseData(respData3) + assertEquals(Errors.NONE, session3resp.error()) + assertTrue(session3resp.sessionId() != INVALID_SESSION_ID) + assertEquals(2, session3resp.responseData(topicNames, session3request.version).size) + + assertTrue(cache.get(session1resp.sessionId()).isDefined) + // even though session 2 is more recent than session 1, and has not reached expiry time, it is less + // privileged than sessions 1 and 3, and thus session 3 should be entered and session 2 evicted. + assertFalse(cache.get(session2resp.sessionId()).isDefined, "session 2 should have been evicted by session 3") + assertTrue(cache.get(session3resp.sessionId()).isDefined) + assertEquals(2, cache.size) + + time.sleep(501) + + // create a final session to test whether session 1 can be evicted due to age even though it is privileged + val session4req = new util.LinkedHashMap[TopicPartition, FetchRequest.PartitionData] + session4req.put(foo0.topicPartition, new FetchRequest.PartitionData(fooId, 0, 0, 100, + Optional.empty())) + session4req.put(foo1.topicPartition, new FetchRequest.PartitionData(fooId, 0, 0, 100, + Optional.empty())) + val session4request = createRequest(JFetchMetadata.INITIAL, session4req, EMPTY_PART_LIST, true) + val session4context = fetchManager.newContext( + session4request.version, + session4request.metadata, + session4request.isFromFollower, + session4request.fetchData(topicNames), + session4request.forgottenTopics(topicNames), + topicNames + ) + assertEquals(classOf[FullFetchContext], session4context.getClass) + val respData4 = new util.LinkedHashMap[TopicIdPartition, FetchResponseData.PartitionData] + respData4.put(foo0, + new FetchResponseData.PartitionData() + .setPartitionIndex(0) + .setHighWatermark(100) + .setLastStableOffset(100) + .setLogStartOffset(100)) + respData4.put(foo1, + new FetchResponseData.PartitionData() + .setPartitionIndex(1) + .setHighWatermark(10) + .setLastStableOffset(10) + .setLogStartOffset(10)) + val session4resp = session4context.updateAndGenerateResponseData(respData4) + assertEquals(Errors.NONE, session4resp.error()) + assertTrue(session4resp.sessionId() != INVALID_SESSION_ID) + assertEquals(2, session4resp.responseData(topicNames, session4request.version).size) + + assertFalse(cache.get(session1resp.sessionId()).isDefined, "session 1 should have been evicted by session 4 even though it is privileged as it has hit eviction time") + assertTrue(cache.get(session3resp.sessionId()).isDefined) + assertTrue(cache.get(session4resp.sessionId()).isDefined) + assertEquals(2, cache.size) + } + + @Test + def 
testZeroSizeFetchSession(): Unit = { + val time = new MockTime() + val cache = new FetchSessionCache(10, 1000) + val fetchManager = new FetchManager(time, cache) + val fooId = Uuid.randomUuid() + val topicNames = Map(fooId -> "foo").asJava + val foo0 = new TopicIdPartition(fooId, new TopicPartition("foo", 0)) + val foo1 = new TopicIdPartition(fooId, new TopicPartition("foo", 1)) + + // Create a new fetch session with foo-0 and foo-1 + val reqData1 = new util.LinkedHashMap[TopicPartition, FetchRequest.PartitionData] + reqData1.put(foo0.topicPartition, new FetchRequest.PartitionData(fooId, 0, 0, 100, + Optional.empty())) + reqData1.put(foo1.topicPartition, new FetchRequest.PartitionData(fooId, 10, 0, 100, + Optional.empty())) + val request1 = createRequest(JFetchMetadata.INITIAL, reqData1, EMPTY_PART_LIST, false) + val context1 = fetchManager.newContext( + request1.version, + request1.metadata, + request1.isFromFollower, + request1.fetchData(topicNames), + request1.forgottenTopics(topicNames), + topicNames + ) + assertEquals(classOf[FullFetchContext], context1.getClass) + val respData1 = new util.LinkedHashMap[TopicIdPartition, FetchResponseData.PartitionData] + respData1.put(foo0, new FetchResponseData.PartitionData() + .setPartitionIndex(0) + .setHighWatermark(100) + .setLastStableOffset(100) + .setLogStartOffset(100)) + respData1.put(foo1, new FetchResponseData.PartitionData() + .setPartitionIndex(1) + .setHighWatermark(10) + .setLastStableOffset(10) + .setLogStartOffset(10)) + val resp1 = context1.updateAndGenerateResponseData(respData1) + assertEquals(Errors.NONE, resp1.error) + assertTrue(resp1.sessionId() != INVALID_SESSION_ID) + assertEquals(2, resp1.responseData(topicNames, request1.version).size) + + // Create an incremental fetch request that removes foo-0 and foo-1 + // Verify that the previous fetch session was closed. 
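+ // (Forgetting every partition of the session should produce a SessionlessFetchContext below: the response carries INVALID_SESSION_ID and the cache entry is dropped, so cache.size goes back to 0.)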
+ val reqData2 = new util.LinkedHashMap[TopicPartition, FetchRequest.PartitionData] + val removed2 = new util.ArrayList[TopicIdPartition] + removed2.add(foo0) + removed2.add(foo1) + val request2 = createRequest( new JFetchMetadata(resp1.sessionId, 1), reqData2, removed2, false) + val context2 = fetchManager.newContext( + request2.version, + request2.metadata, + request2.isFromFollower, + request2.fetchData(topicNames), + request2.forgottenTopics(topicNames), + topicNames + ) + assertEquals(classOf[SessionlessFetchContext], context2.getClass) + val respData2 = new util.LinkedHashMap[TopicIdPartition, FetchResponseData.PartitionData] + val resp2 = context2.updateAndGenerateResponseData(respData2) + assertEquals(INVALID_SESSION_ID, resp2.sessionId) + assertTrue(resp2.responseData(topicNames, request2.version).isEmpty) + assertEquals(0, cache.size) + } + + @Test + def testDivergingEpoch(): Unit = { + val time = new MockTime() + val cache = new FetchSessionCache(10, 1000) + val fetchManager = new FetchManager(time, cache) + val topicNames = Map(Uuid.randomUuid() -> "foo", Uuid.randomUuid() -> "bar").asJava + val topicIds = topicNames.asScala.map(_.swap).asJava + val tp1 = new TopicIdPartition(topicIds.get("foo"), new TopicPartition("foo", 1)) + val tp2 = new TopicIdPartition(topicIds.get("bar"), new TopicPartition("bar", 2)) + + val reqData = new util.LinkedHashMap[TopicPartition, FetchRequest.PartitionData] + reqData.put(tp1.topicPartition, new FetchRequest.PartitionData(tp1.topicId, 100, 0, 1000, Optional.of(5), Optional.of(4))) + reqData.put(tp2.topicPartition, new FetchRequest.PartitionData(tp2.topicId, 100, 0, 1000, Optional.of(5), Optional.of(4))) + + // Full fetch context returns all partitions in the response + val request1 = createRequest(JFetchMetadata.INITIAL, reqData, EMPTY_PART_LIST, false) + val context1 = fetchManager.newContext( + request1.version, + request1.metadata, + request1.isFromFollower, + request1.fetchData(topicNames), + request1.forgottenTopics(topicNames), + topicNames + ) + assertEquals(classOf[FullFetchContext], context1.getClass) + val respData = new util.LinkedHashMap[TopicIdPartition, FetchResponseData.PartitionData] + respData.put(tp1, new FetchResponseData.PartitionData() + .setPartitionIndex(tp1.partition) + .setHighWatermark(105) + .setLastStableOffset(105) + .setLogStartOffset(0)) + val divergingEpoch = new FetchResponseData.EpochEndOffset().setEpoch(3).setEndOffset(90) + respData.put(tp2, new FetchResponseData.PartitionData() + .setPartitionIndex(tp2.partition) + .setHighWatermark(105) + .setLastStableOffset(105) + .setLogStartOffset(0) + .setDivergingEpoch(divergingEpoch)) + val resp1 = context1.updateAndGenerateResponseData(respData) + assertEquals(Errors.NONE, resp1.error) + assertNotEquals(INVALID_SESSION_ID, resp1.sessionId) + assertEquals(Utils.mkSet(tp1.topicPartition, tp2.topicPartition), resp1.responseData(topicNames, request1.version).keySet) + + // Incremental fetch context returns partitions with divergent epoch even if none + // of the other conditions for return are met. 
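+ // (respData is reused unchanged from the full fetch, so only tp2, whose response carries a diverging epoch, is expected back from the incremental context.)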
+ val request2 = createRequest(new JFetchMetadata(resp1.sessionId, 1), reqData, EMPTY_PART_LIST, false) + val context2 = fetchManager.newContext( + request2.version, + request2.metadata, + request2.isFromFollower, + request2.fetchData(topicNames), + request2.forgottenTopics(topicNames), + topicNames + ) + assertEquals(classOf[IncrementalFetchContext], context2.getClass) + val resp2 = context2.updateAndGenerateResponseData(respData) + assertEquals(Errors.NONE, resp2.error) + assertEquals(resp1.sessionId, resp2.sessionId) + assertEquals(Collections.singleton(tp2.topicPartition), resp2.responseData(topicNames, request2.version).keySet) + + // All partitions with divergent epoch should be returned. + respData.put(tp1, new FetchResponseData.PartitionData() + .setPartitionIndex(tp1.partition) + .setHighWatermark(105) + .setLastStableOffset(105) + .setLogStartOffset(0) + .setDivergingEpoch(divergingEpoch)) + val resp3 = context2.updateAndGenerateResponseData(respData) + assertEquals(Errors.NONE, resp3.error) + assertEquals(resp1.sessionId, resp3.sessionId) + assertEquals(Utils.mkSet(tp1.topicPartition, tp2.topicPartition), resp3.responseData(topicNames, request2.version).keySet) + + // Partitions that meet other conditions should be returned regardless of whether + // divergingEpoch is set or not. + respData.put(tp1, new FetchResponseData.PartitionData() + .setPartitionIndex(tp1.partition) + .setHighWatermark(110) + .setLastStableOffset(110) + .setLogStartOffset(0)) + val resp4 = context2.updateAndGenerateResponseData(respData) + assertEquals(Errors.NONE, resp4.error) + assertEquals(resp1.sessionId, resp4.sessionId) + assertEquals(Utils.mkSet(tp1.topicPartition, tp2.topicPartition), resp4.responseData(topicNames, request2.version).keySet) + } + + @Test + def testDeprioritizesPartitionsWithRecordsOnly(): Unit = { + val time = new MockTime() + val cache = new FetchSessionCache(10, 1000) + val fetchManager = new FetchManager(time, cache) + val topicIds = Map("foo" -> Uuid.randomUuid(), "bar" -> Uuid.randomUuid(), "zar" -> Uuid.randomUuid()).asJava + val topicNames = topicIds.asScala.map(_.swap).asJava + val tp1 = new TopicIdPartition(topicIds.get("foo"), new TopicPartition("foo", 1)) + val tp2 = new TopicIdPartition(topicIds.get("bar"), new TopicPartition("bar", 2)) + val tp3 = new TopicIdPartition(topicIds.get("zar"), new TopicPartition("zar", 3)) + + val reqData = new util.LinkedHashMap[TopicIdPartition, FetchRequest.PartitionData] + reqData.put(tp1, new FetchRequest.PartitionData(tp1.topicId, 100, 0, 1000, Optional.of(5), Optional.of(4))) + reqData.put(tp2, new FetchRequest.PartitionData(tp2.topicId, 100, 0, 1000, Optional.of(5), Optional.of(4))) + reqData.put(tp3, new FetchRequest.PartitionData(tp3.topicId, 100, 0, 1000, Optional.of(5), Optional.of(4))) + + // Full fetch context returns all partitions in the response + val context1 = fetchManager.newContext(ApiKeys.FETCH.latestVersion(), JFetchMetadata.INITIAL, false, + reqData, Collections.emptyList(), topicNames) + assertEquals(classOf[FullFetchContext], context1.getClass) + + val respData1 = new util.LinkedHashMap[TopicIdPartition, FetchResponseData.PartitionData] + respData1.put(tp1, new FetchResponseData.PartitionData() + .setPartitionIndex(tp1.topicPartition.partition) + .setHighWatermark(50) + .setLastStableOffset(50) + .setLogStartOffset(0)) + respData1.put(tp2, new FetchResponseData.PartitionData() + .setPartitionIndex(tp2.topicPartition.partition) + .setHighWatermark(50) + .setLastStableOffset(50) + .setLogStartOffset(0)) + 
respData1.put(tp3, new FetchResponseData.PartitionData() + .setPartitionIndex(tp3.topicPartition.partition) + .setHighWatermark(50) + .setLastStableOffset(50) + .setLogStartOffset(0)) + + val resp1 = context1.updateAndGenerateResponseData(respData1) + assertEquals(Errors.NONE, resp1.error) + assertNotEquals(INVALID_SESSION_ID, resp1.sessionId) + assertEquals(Utils.mkSet(tp1.topicPartition, tp2.topicPartition, tp3.topicPartition), resp1.responseData(topicNames, ApiKeys.FETCH.latestVersion()).keySet()) + + // Incremental fetch context returns partitions with changes but only deprioritizes + // the partitions with records + val context2 = fetchManager.newContext(ApiKeys.FETCH.latestVersion(), new JFetchMetadata(resp1.sessionId, 1), false, + reqData, Collections.emptyList(), topicNames) + assertEquals(classOf[IncrementalFetchContext], context2.getClass) + + // Partitions are ordered in the session as per last response + assertPartitionsOrder(context2, Seq(tp1, tp2, tp3)) + + // Response is empty + val respData2 = new util.LinkedHashMap[TopicIdPartition, FetchResponseData.PartitionData] + val resp2 = context2.updateAndGenerateResponseData(respData2) + assertEquals(Errors.NONE, resp2.error) + assertEquals(resp1.sessionId, resp2.sessionId) + assertEquals(Collections.emptySet(), resp2.responseData(topicNames, ApiKeys.FETCH.latestVersion()).keySet) + + // All partitions with changes should be returned. + val respData3 = new util.LinkedHashMap[TopicIdPartition, FetchResponseData.PartitionData] + respData3.put(tp1, new FetchResponseData.PartitionData() + .setPartitionIndex(tp1.topicPartition.partition) + .setHighWatermark(60) + .setLastStableOffset(50) + .setLogStartOffset(0)) + respData3.put(tp2, new FetchResponseData.PartitionData() + .setPartitionIndex(tp2.topicPartition.partition) + .setHighWatermark(60) + .setLastStableOffset(50) + .setLogStartOffset(0) + .setRecords(MemoryRecords.withRecords(CompressionType.NONE, + new SimpleRecord(100, null)))) + respData3.put(tp3, new FetchResponseData.PartitionData() + .setPartitionIndex(tp3.topicPartition.partition) + .setHighWatermark(50) + .setLastStableOffset(50) + .setLogStartOffset(0)) + val resp3 = context2.updateAndGenerateResponseData(respData3) + assertEquals(Errors.NONE, resp3.error) + assertEquals(resp1.sessionId, resp3.sessionId) + assertEquals(Utils.mkSet(tp1.topicPartition, tp2.topicPartition), resp3.responseData(topicNames, ApiKeys.FETCH.latestVersion()).keySet) + + // Only the partitions whose returned records in the last response + // were deprioritized + assertPartitionsOrder(context2, Seq(tp1, tp3, tp2)) + } + + @Test + def testCachedPartitionEqualsAndHashCode(): Unit = { + val topicId = Uuid.randomUuid() + val topicName = "topic" + val partition = 0 + + val cachedPartitionWithIdAndName = new CachedPartition(topicName, topicId, partition) + val cachedPartitionWithIdAndNoName = new CachedPartition(null, topicId, partition) + val cachedPartitionWithDifferentIdAndName = new CachedPartition(topicName, Uuid.randomUuid(), partition) + val cachedPartitionWithZeroIdAndName = new CachedPartition(topicName, Uuid.ZERO_UUID, partition) + val cachedPartitionWithZeroIdAndOtherName = new CachedPartition("otherTopic", Uuid.ZERO_UUID, partition) + + // CachedPartitions with valid topic IDs will compare topic ID and partition but not topic name. 
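+ // (Hence a partition with a name and one with a null name but the same ID compare equal, while a different ID makes them unequal, as asserted below.)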
+ assertEquals(cachedPartitionWithIdAndName, cachedPartitionWithIdAndNoName) + assertEquals(cachedPartitionWithIdAndName.hashCode, cachedPartitionWithIdAndNoName.hashCode) + + assertNotEquals(cachedPartitionWithIdAndName, cachedPartitionWithDifferentIdAndName) + assertNotEquals(cachedPartitionWithIdAndName.hashCode, cachedPartitionWithDifferentIdAndName.hashCode) + + assertNotEquals(cachedPartitionWithIdAndName, cachedPartitionWithZeroIdAndName) + assertNotEquals(cachedPartitionWithIdAndName.hashCode, cachedPartitionWithZeroIdAndName.hashCode) + + // CachedPartitions with a null name and valid IDs will act just like ones with valid names + assertEquals(cachedPartitionWithIdAndNoName, cachedPartitionWithIdAndName) + assertEquals(cachedPartitionWithIdAndNoName.hashCode, cachedPartitionWithIdAndName.hashCode) + + assertNotEquals(cachedPartitionWithIdAndNoName, cachedPartitionWithDifferentIdAndName) + assertNotEquals(cachedPartitionWithIdAndNoName.hashCode, cachedPartitionWithDifferentIdAndName.hashCode) + + assertNotEquals(cachedPartitionWithIdAndNoName, cachedPartitionWithZeroIdAndName) + assertNotEquals(cachedPartitionWithIdAndNoName.hashCode, cachedPartitionWithZeroIdAndName.hashCode) + + // CachedPartition with zero Uuids will compare topic name and partition. + assertNotEquals(cachedPartitionWithZeroIdAndName, cachedPartitionWithZeroIdAndOtherName) + assertNotEquals(cachedPartitionWithZeroIdAndName.hashCode, cachedPartitionWithZeroIdAndOtherName.hashCode) + + assertEquals(cachedPartitionWithZeroIdAndName, cachedPartitionWithZeroIdAndName) + assertEquals(cachedPartitionWithZeroIdAndName.hashCode, cachedPartitionWithZeroIdAndName.hashCode) + } + + @Test + def testMaybeResolveUnknownName(): Unit = { + val namedPartition = new CachedPartition("topic", Uuid.randomUuid(), 0) + val nullNamePartition1 = new CachedPartition(null, Uuid.randomUuid(), 0) + val nullNamePartition2 = new CachedPartition(null, Uuid.randomUuid(), 0) + + val topicNames = Map(namedPartition.topicId -> "foo", nullNamePartition1.topicId -> "bar").asJava + + // Since the name is not null, we should not change the topic name. + // We should never have a scenario where the same ID is used by two topic names, but this is used to test that we respect the null check. + namedPartition.maybeResolveUnknownName(topicNames) + assertEquals("topic", namedPartition.topic) + + // We will resolve this name as it is in the map and the current name is null. + nullNamePartition1.maybeResolveUnknownName(topicNames) + assertEquals("bar", nullNamePartition1.topic) + + // If the ID is not in the map, then we don't resolve the name. 
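+ // (nullNamePartition2's topic ID was never added to topicNames, so its topic stays null after the call.)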
+ nullNamePartition2.maybeResolveUnknownName(topicNames) + assertEquals(null, nullNamePartition2.topic) + } + + private def assertPartitionsOrder(context: FetchContext, partitions: Seq[TopicIdPartition]): Unit = { + val partitionsInContext = ArrayBuffer.empty[TopicIdPartition] + context.foreachPartition { (tp, _) => + partitionsInContext += tp + } + assertEquals(partitions, partitionsInContext.toSeq) + } +} + +object FetchSessionTest { + def idUsageCombinations: java.util.stream.Stream[Arguments] = { + val data = new java.util.ArrayList[Arguments]() + for (startsWithTopicIds <- Array(java.lang.Boolean.TRUE, java.lang.Boolean.FALSE)) + for (endsWithTopicIds <- Array(java.lang.Boolean.TRUE, java.lang.Boolean.FALSE)) + data.add(Arguments.of(startsWithTopicIds, endsWithTopicIds)) + data.stream() + } +} diff --git a/core/src/test/scala/unit/kafka/server/FinalizedFeatureCacheTest.scala b/core/src/test/scala/unit/kafka/server/FinalizedFeatureCacheTest.scala new file mode 100644 index 0000000..d0f4c0a --- /dev/null +++ b/core/src/test/scala/unit/kafka/server/FinalizedFeatureCacheTest.scala @@ -0,0 +1,114 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.server + +import org.apache.kafka.common.feature.{Features, FinalizedVersionRange, SupportedVersionRange} +import org.junit.jupiter.api.Assertions.{assertEquals, assertThrows, assertTrue} +import org.junit.jupiter.api.Test + +import scala.jdk.CollectionConverters._ + +class FinalizedFeatureCacheTest { + + @Test + def testEmpty(): Unit = { + assertTrue(new FinalizedFeatureCache(BrokerFeatures.createDefault()).get.isEmpty) + } + + @Test + def testUpdateOrThrowFailedDueToInvalidEpoch(): Unit = { + val supportedFeatures = Map[String, SupportedVersionRange]( + "feature_1" -> new SupportedVersionRange(1, 4)) + val brokerFeatures = BrokerFeatures.createDefault() + brokerFeatures.setSupportedFeatures(Features.supportedFeatures(supportedFeatures.asJava)) + + val features = Map[String, FinalizedVersionRange]( + "feature_1" -> new FinalizedVersionRange(1, 4)) + val finalizedFeatures = Features.finalizedFeatures(features.asJava) + + val cache = new FinalizedFeatureCache(brokerFeatures) + cache.updateOrThrow(finalizedFeatures, 10) + assertTrue(cache.get.isDefined) + assertEquals(finalizedFeatures, cache.get.get.features) + assertEquals(10, cache.get.get.epoch) + + assertThrows(classOf[FeatureCacheUpdateException], () => cache.updateOrThrow(finalizedFeatures, 9)) + + // Check that the failed updateOrThrow call did not make any mutations. 
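+ // The cache should still hold the features that were written at epoch 10.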
+ assertTrue(cache.get.isDefined) + assertEquals(finalizedFeatures, cache.get.get.features) + assertEquals(10, cache.get.get.epoch) + } + + @Test + def testUpdateOrThrowFailedDueToInvalidFeatures(): Unit = { + val supportedFeatures = + Map[String, SupportedVersionRange]("feature_1" -> new SupportedVersionRange(1, 1)) + val brokerFeatures = BrokerFeatures.createDefault() + brokerFeatures.setSupportedFeatures(Features.supportedFeatures(supportedFeatures.asJava)) + + val features = Map[String, FinalizedVersionRange]( + "feature_1" -> new FinalizedVersionRange(1, 2)) + val finalizedFeatures = Features.finalizedFeatures(features.asJava) + + val cache = new FinalizedFeatureCache(brokerFeatures) + assertThrows(classOf[FeatureCacheUpdateException], () => cache.updateOrThrow(finalizedFeatures, 12)) + + // Check that the failed updateOrThrow call did not make any mutations. + assertTrue(cache.isEmpty) + } + + @Test + def testUpdateOrThrowSuccess(): Unit = { + val supportedFeatures = + Map[String, SupportedVersionRange]("feature_1" -> new SupportedVersionRange(1, 4)) + val brokerFeatures = BrokerFeatures.createDefault() + brokerFeatures.setSupportedFeatures(Features.supportedFeatures(supportedFeatures.asJava)) + + val features = Map[String, FinalizedVersionRange]( + "feature_1" -> new FinalizedVersionRange(2, 3)) + val finalizedFeatures = Features.finalizedFeatures(features.asJava) + + val cache = new FinalizedFeatureCache(brokerFeatures) + cache.updateOrThrow(finalizedFeatures, 12) + assertTrue(cache.get.isDefined) + assertEquals(finalizedFeatures, cache.get.get.features) + assertEquals(12, cache.get.get.epoch) + } + + @Test + def testClear(): Unit = { + val supportedFeatures = + Map[String, SupportedVersionRange]("feature_1" -> new SupportedVersionRange(1, 4)) + val brokerFeatures = BrokerFeatures.createDefault() + brokerFeatures.setSupportedFeatures(Features.supportedFeatures(supportedFeatures.asJava)) + + val features = Map[String, FinalizedVersionRange]( + "feature_1" -> new FinalizedVersionRange(2, 3)) + val finalizedFeatures = Features.finalizedFeatures(features.asJava) + + val cache = new FinalizedFeatureCache(brokerFeatures) + cache.updateOrThrow(finalizedFeatures, 12) + assertTrue(cache.get.isDefined) + assertEquals(finalizedFeatures, cache.get.get.features) + assertEquals(12, cache.get.get.epoch) + + cache.clear() + assertTrue(cache.isEmpty) + } +} diff --git a/core/src/test/scala/unit/kafka/server/FinalizedFeatureChangeListenerTest.scala b/core/src/test/scala/unit/kafka/server/FinalizedFeatureChangeListenerTest.scala new file mode 100644 index 0000000..d59474e --- /dev/null +++ b/core/src/test/scala/unit/kafka/server/FinalizedFeatureChangeListenerTest.scala @@ -0,0 +1,270 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package kafka.server + +import java.util.concurrent.{CountDownLatch, TimeoutException} + +import kafka.server.QuorumTestHarness +import kafka.zk.{FeatureZNode, FeatureZNodeStatus, ZkVersion} +import kafka.utils.TestUtils +import org.apache.kafka.common.utils.Exit +import org.apache.kafka.common.feature.{Features, FinalizedVersionRange, SupportedVersionRange} +import org.apache.kafka.test.{TestUtils => JTestUtils} +import org.junit.jupiter.api.Assertions.{assertEquals, assertFalse, assertNotEquals, assertThrows, assertTrue} +import org.junit.jupiter.api.Test + +import scala.jdk.CollectionConverters._ + +class FinalizedFeatureChangeListenerTest extends QuorumTestHarness { + + private def createBrokerFeatures(): BrokerFeatures = { + val supportedFeaturesMap = Map[String, SupportedVersionRange]( + "feature_1" -> new SupportedVersionRange(1, 4), + "feature_2" -> new SupportedVersionRange(1, 3)) + val brokerFeatures = BrokerFeatures.createDefault() + brokerFeatures.setSupportedFeatures(Features.supportedFeatures(supportedFeaturesMap.asJava)) + brokerFeatures + } + + private def createFinalizedFeatures(): FinalizedFeaturesAndEpoch = { + val finalizedFeaturesMap = Map[String, FinalizedVersionRange]( + "feature_1" -> new FinalizedVersionRange(2, 3)) + val finalizedFeatures = Features.finalizedFeatures(finalizedFeaturesMap.asJava) + zkClient.createFeatureZNode(FeatureZNode(FeatureZNodeStatus.Enabled, finalizedFeatures)) + val (mayBeFeatureZNodeBytes, version) = zkClient.getDataAndVersion(FeatureZNode.path) + assertNotEquals(version, ZkVersion.UnknownVersion) + assertFalse(mayBeFeatureZNodeBytes.isEmpty) + FinalizedFeaturesAndEpoch(finalizedFeatures, version) + } + + private def createListener( + cache: FinalizedFeatureCache, + expectedCacheContent: Option[FinalizedFeaturesAndEpoch] + ): FinalizedFeatureChangeListener = { + val listener = new FinalizedFeatureChangeListener(cache, zkClient) + assertFalse(listener.isListenerInitiated) + assertTrue(cache.isEmpty) + listener.initOrThrow(15000) + assertTrue(listener.isListenerInitiated) + if (expectedCacheContent.isDefined) { + val mayBeNewCacheContent = cache.get + assertFalse(mayBeNewCacheContent.isEmpty) + val newCacheContent = mayBeNewCacheContent.get + assertEquals(expectedCacheContent.get.features, newCacheContent.features) + assertEquals(expectedCacheContent.get.epoch, newCacheContent.epoch) + } else { + val mayBeNewCacheContent = cache.get + assertTrue(mayBeNewCacheContent.isEmpty) + } + listener + } + + /** + * Tests that the listener can be initialized, and that it can listen to ZK notifications + * successfully from an "Enabled" FeatureZNode (the ZK data has no feature incompatibilities). + * Particularly the test checks if multiple notifications can be processed in ZK + * (i.e. whether the FeatureZNode watch can be re-established). 
+ */ + @Test + def testInitSuccessAndNotificationSuccess(): Unit = { + val initialFinalizedFeatures = createFinalizedFeatures() + val brokerFeatures = createBrokerFeatures() + val cache = new FinalizedFeatureCache(brokerFeatures) + val listener = createListener(cache, Some(initialFinalizedFeatures)) + + def updateAndCheckCache(finalizedFeatures: Features[FinalizedVersionRange]): Unit = { + zkClient.updateFeatureZNode(FeatureZNode(FeatureZNodeStatus.Enabled, finalizedFeatures)) + val (mayBeFeatureZNodeNewBytes, updatedVersion) = zkClient.getDataAndVersion(FeatureZNode.path) + assertNotEquals(updatedVersion, ZkVersion.UnknownVersion) + assertFalse(mayBeFeatureZNodeNewBytes.isEmpty) + assertTrue(updatedVersion > initialFinalizedFeatures.epoch) + + cache.waitUntilEpochOrThrow(updatedVersion, JTestUtils.DEFAULT_MAX_WAIT_MS) + assertEquals(FinalizedFeaturesAndEpoch(finalizedFeatures, updatedVersion), cache.get.get) + assertTrue(listener.isListenerInitiated) + } + + // Check if the write succeeds and a ZK notification is received that causes the feature cache + // to be populated. + updateAndCheckCache( + Features.finalizedFeatures( + Map[String, FinalizedVersionRange]( + "feature_1" -> new FinalizedVersionRange(2, 4)).asJava)) + // Check if second write succeeds and a ZK notification is again received that causes the cache + // to be populated. This check is needed to verify that the watch on the FeatureZNode was + // re-established after the notification was received due to the first write above. + updateAndCheckCache( + Features.finalizedFeatures( + Map[String, FinalizedVersionRange]( + "feature_1" -> new FinalizedVersionRange(2, 4), + "feature_2" -> new FinalizedVersionRange(1, 3)).asJava)) + } + + /** + * Tests that the listener can be initialized, and that it can process FeatureZNode deletion + * successfully. + */ + @Test + def testFeatureZNodeDeleteNotificationProcessing(): Unit = { + val brokerFeatures = createBrokerFeatures() + val cache = new FinalizedFeatureCache(brokerFeatures) + val initialFinalizedFeatures = createFinalizedFeatures() + val listener = createListener(cache, Some(initialFinalizedFeatures)) + + zkClient.deleteFeatureZNode() + val (mayBeFeatureZNodeDeletedBytes, deletedVersion) = zkClient.getDataAndVersion(FeatureZNode.path) + assertEquals(deletedVersion, ZkVersion.UnknownVersion) + assertTrue(mayBeFeatureZNodeDeletedBytes.isEmpty) + TestUtils.waitUntilTrue(() => { + cache.isEmpty + }, "Timed out waiting for FinalizedFeatureCache to become empty") + assertTrue(listener.isListenerInitiated) + } + + /** + * Tests that the listener can be initialized, and that it can process disabling of a FeatureZNode + * successfully. 
+ */ + @Test + def testFeatureZNodeDisablingNotificationProcessing(): Unit = { + val brokerFeatures = createBrokerFeatures() + val cache = new FinalizedFeatureCache(brokerFeatures) + val initialFinalizedFeatures = createFinalizedFeatures() + + val updatedFinalizedFeaturesMap = Map[String, FinalizedVersionRange]() + val updatedFinalizedFeatures = Features.finalizedFeatures(updatedFinalizedFeaturesMap.asJava) + zkClient.updateFeatureZNode(FeatureZNode(FeatureZNodeStatus.Disabled, updatedFinalizedFeatures)) + val (mayBeFeatureZNodeNewBytes, updatedVersion) = zkClient.getDataAndVersion(FeatureZNode.path) + assertNotEquals(updatedVersion, ZkVersion.UnknownVersion) + assertFalse(mayBeFeatureZNodeNewBytes.isEmpty) + assertTrue(updatedVersion > initialFinalizedFeatures.epoch) + assertTrue(cache.get.isEmpty) + } + + /** + * Tests that the wait operation on the cache fails (as expected) when an epoch can never be + * reached. Also tests that the wait operation on the cache succeeds when an epoch is expected to + * be reached. + */ + @Test + def testCacheUpdateWaitFailsForUnreachableVersion(): Unit = { + val initialFinalizedFeatures = createFinalizedFeatures() + val cache = new FinalizedFeatureCache(createBrokerFeatures()) + val listener = createListener(cache, Some(initialFinalizedFeatures)) + + assertThrows(classOf[TimeoutException], () => cache.waitUntilEpochOrThrow(initialFinalizedFeatures.epoch + 1, JTestUtils.DEFAULT_MAX_WAIT_MS)) + + val updatedFinalizedFeaturesMap = Map[String, FinalizedVersionRange]() + val updatedFinalizedFeatures = Features.finalizedFeatures(updatedFinalizedFeaturesMap.asJava) + zkClient.updateFeatureZNode(FeatureZNode(FeatureZNodeStatus.Disabled, updatedFinalizedFeatures)) + val (mayBeFeatureZNodeNewBytes, updatedVersion) = zkClient.getDataAndVersion(FeatureZNode.path) + assertNotEquals(updatedVersion, ZkVersion.UnknownVersion) + assertFalse(mayBeFeatureZNodeNewBytes.isEmpty) + assertTrue(updatedVersion > initialFinalizedFeatures.epoch) + + assertThrows(classOf[TimeoutException], () => cache.waitUntilEpochOrThrow(updatedVersion, JTestUtils.DEFAULT_MAX_WAIT_MS)) + assertTrue(cache.get.isEmpty) + assertTrue(listener.isListenerInitiated) + } + + /** + * Tests that the listener initialization fails when it picks up a feature incompatibility from + * ZK from an "Enabled" FeatureZNode. 
+ */ + @Test + def testInitFailureDueToFeatureIncompatibility(): Unit = { + val brokerFeatures = createBrokerFeatures() + val cache = new FinalizedFeatureCache(brokerFeatures) + + val incompatibleFinalizedFeaturesMap = Map[String, FinalizedVersionRange]( + "feature_1" -> new FinalizedVersionRange(2, 5)) + val incompatibleFinalizedFeatures = Features.finalizedFeatures(incompatibleFinalizedFeaturesMap.asJava) + zkClient.createFeatureZNode(FeatureZNode(FeatureZNodeStatus.Enabled, incompatibleFinalizedFeatures)) + val (mayBeFeatureZNodeBytes, initialVersion) = zkClient.getDataAndVersion(FeatureZNode.path) + assertNotEquals(initialVersion, ZkVersion.UnknownVersion) + assertFalse(mayBeFeatureZNodeBytes.isEmpty) + + val exitLatch = new CountDownLatch(1) + Exit.setExitProcedure((_, _) => exitLatch.countDown()) + try { + val listener = new FinalizedFeatureChangeListener(cache, zkClient) + assertFalse(listener.isListenerInitiated) + assertTrue(cache.isEmpty) + assertThrows(classOf[TimeoutException], () => listener.initOrThrow(5000)) + exitLatch.await() + assertFalse(listener.isListenerInitiated) + assertTrue(listener.isListenerDead) + assertTrue(cache.isEmpty) + } finally { + Exit.resetExitProcedure() + } + } + + /** + * Tests that the listener initialization fails when invalid wait time (<= 0) is provided as input. + */ + @Test + def testInitFailureDueToInvalidWaitTime(): Unit = { + val brokerFeatures = createBrokerFeatures() + val cache = new FinalizedFeatureCache(brokerFeatures) + val listener = new FinalizedFeatureChangeListener(cache, zkClient) + assertThrows(classOf[IllegalArgumentException], () => listener.initOrThrow(0)) + assertThrows(classOf[IllegalArgumentException], () => listener.initOrThrow(-1)) + } + + /** + * Tests that after successful initialization, the listener fails when it picks up a feature + * incompatibility from ZK. + */ + @Test + def testNotificationFailureDueToFeatureIncompatibility(): Unit = { + val brokerFeatures = createBrokerFeatures() + val cache = new FinalizedFeatureCache(brokerFeatures) + val initialFinalizedFeatures = createFinalizedFeatures() + val listener = createListener(cache, Some(initialFinalizedFeatures)) + + val exitLatch = new CountDownLatch(1) + Exit.setExitProcedure((_, _) => exitLatch.countDown()) + val incompatibleFinalizedFeaturesMap = Map[String, FinalizedVersionRange]( + "feature_1" -> new FinalizedVersionRange( + brokerFeatures.supportedFeatures.get("feature_1").min(), + (brokerFeatures.supportedFeatures.get("feature_1").max() + 1).asInstanceOf[Short])) + val incompatibleFinalizedFeatures = Features.finalizedFeatures(incompatibleFinalizedFeaturesMap.asJava) + zkClient.updateFeatureZNode(FeatureZNode(FeatureZNodeStatus.Enabled, incompatibleFinalizedFeatures)) + val (mayBeFeatureZNodeIncompatibleBytes, updatedVersion) = zkClient.getDataAndVersion(FeatureZNode.path) + assertNotEquals(updatedVersion, ZkVersion.UnknownVersion) + assertFalse(mayBeFeatureZNodeIncompatibleBytes.isEmpty) + + try { + TestUtils.waitUntilTrue(() => { + // Make sure the custom exit procedure (defined above) was called. + exitLatch.getCount == 0 && + // Make sure the listener is no longer initiated (because, it is dead). + !listener.isListenerInitiated && + // Make sure the listener dies after hitting an exception when processing incompatible + // features read from ZK. + listener.isListenerDead && + // Make sure the cache contents are as expected, and, the incompatible features were not + // applied. 
+ cache.get.get.equals(initialFinalizedFeatures) + }, "Timed out waiting for listener death and FinalizedFeatureCache to be updated") + } finally { + Exit.resetExitProcedure() + } + } +} diff --git a/core/src/test/scala/unit/kafka/server/ForwardingManagerTest.scala b/core/src/test/scala/unit/kafka/server/ForwardingManagerTest.scala new file mode 100644 index 0000000..d0fc30f --- /dev/null +++ b/core/src/test/scala/unit/kafka/server/ForwardingManagerTest.scala @@ -0,0 +1,254 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package kafka.server + +import java.net.InetAddress +import java.nio.ByteBuffer +import java.util.Optional +import java.util.concurrent.atomic.AtomicReference + +import kafka.network +import kafka.network.RequestChannel +import kafka.utils.MockTime +import org.apache.kafka.clients.{MockClient, NodeApiVersions} +import org.apache.kafka.clients.MockClient.RequestMatcher +import org.apache.kafka.common.Node +import org.apache.kafka.common.config.{ConfigResource, TopicConfig} +import org.apache.kafka.common.memory.MemoryPool +import org.apache.kafka.common.message.ApiMessageType.ListenerType +import org.apache.kafka.common.message.{AlterConfigsResponseData, ApiVersionsResponseData} +import org.apache.kafka.common.network.{ClientInformation, ListenerName} +import org.apache.kafka.common.protocol.{ApiKeys, Errors} +import org.apache.kafka.common.requests.{AbstractRequest, AbstractResponse, AlterConfigsRequest, AlterConfigsResponse, EnvelopeRequest, EnvelopeResponse, RequestContext, RequestHeader, RequestTestUtils} +import org.apache.kafka.common.security.auth.{KafkaPrincipal, SecurityProtocol} +import org.apache.kafka.common.security.authenticator.DefaultKafkaPrincipalBuilder +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.Test +import org.mockito.Mockito + +import scala.jdk.CollectionConverters._ + +class ForwardingManagerTest { + private val time = new MockTime() + private val client = new MockClient(time) + private val controllerNodeProvider = Mockito.mock(classOf[ControllerNodeProvider]) + private val brokerToController = new MockBrokerToControllerChannelManager( + client, time, controllerNodeProvider, controllerApiVersions) + private val forwardingManager = new ForwardingManagerImpl(brokerToController) + private val principalBuilder = new DefaultKafkaPrincipalBuilder(null, null) + + private def controllerApiVersions: NodeApiVersions = { + // The Envelope API is not yet included in the standard set of APIs + val envelopeApiVersion = new ApiVersionsResponseData.ApiVersion() + .setApiKey(ApiKeys.ENVELOPE.id) + .setMinVersion(ApiKeys.ENVELOPE.oldestVersion) + .setMaxVersion(ApiKeys.ENVELOPE.latestVersion) + NodeApiVersions.create(List(envelopeApiVersion).asJava) + } + + @Test + def 
testResponseCorrelationIdMismatch(): Unit = { + val requestCorrelationId = 27 + val clientPrincipal = new KafkaPrincipal(KafkaPrincipal.USER_TYPE, "client") + val (requestHeader, requestBuffer) = buildRequest(testAlterConfigRequest, requestCorrelationId) + val request = buildRequest(requestHeader, requestBuffer, clientPrincipal) + + val responseBody = new AlterConfigsResponse(new AlterConfigsResponseData()) + val responseBuffer = RequestTestUtils.serializeResponseWithHeader(responseBody, requestHeader.apiVersion, + requestCorrelationId + 1) + + Mockito.when(controllerNodeProvider.get()).thenReturn(Some(new Node(0, "host", 1234))) + val isEnvelopeRequest: RequestMatcher = request => request.isInstanceOf[EnvelopeRequest] + client.prepareResponse(isEnvelopeRequest, new EnvelopeResponse(responseBuffer, Errors.NONE)); + + val responseOpt = new AtomicReference[Option[AbstractResponse]]() + forwardingManager.forwardRequest(request, responseOpt.set) + brokerToController.poll() + assertTrue(Option(responseOpt.get).isDefined) + + val response = responseOpt.get.get + assertEquals(Map(Errors.UNKNOWN_SERVER_ERROR -> 1).asJava, response.errorCounts()) + } + + @Test + def testUnsupportedVersions(): Unit = { + val requestCorrelationId = 27 + val clientPrincipal = new KafkaPrincipal(KafkaPrincipal.USER_TYPE, "client") + val (requestHeader, requestBuffer) = buildRequest(testAlterConfigRequest, requestCorrelationId) + val request = buildRequest(requestHeader, requestBuffer, clientPrincipal) + + val responseBody = new AlterConfigsResponse(new AlterConfigsResponseData()) + val responseBuffer = RequestTestUtils.serializeResponseWithHeader(responseBody, + requestHeader.apiVersion, requestCorrelationId) + + Mockito.when(controllerNodeProvider.get()).thenReturn(Some(new Node(0, "host", 1234))) + val isEnvelopeRequest: RequestMatcher = request => request.isInstanceOf[EnvelopeRequest] + client.prepareResponse(isEnvelopeRequest, new EnvelopeResponse(responseBuffer, Errors.UNSUPPORTED_VERSION)); + + val responseOpt = new AtomicReference[Option[AbstractResponse]]() + forwardingManager.forwardRequest(request, responseOpt.set) + brokerToController.poll() + assertEquals(None, responseOpt.get) + } + + @Test + def testForwardingTimeoutWaitingForControllerDiscovery(): Unit = { + val requestCorrelationId = 27 + val clientPrincipal = new KafkaPrincipal(KafkaPrincipal.USER_TYPE, "client") + val (requestHeader, requestBuffer) = buildRequest(testAlterConfigRequest, requestCorrelationId) + val request = buildRequest(requestHeader, requestBuffer, clientPrincipal) + + Mockito.when(controllerNodeProvider.get()).thenReturn(None) + + val response = new AtomicReference[AbstractResponse]() + forwardingManager.forwardRequest(request, res => res.foreach(response.set)) + brokerToController.poll() + assertNull(response.get) + + // The controller is not discovered before reaching the retry timeout. + // The request should fail with a timeout error. 
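+ // (Advancing MockTime past retryTimeoutMs lets the next poll() fail the pending request with REQUEST_TIMED_OUT, as asserted below.)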
+ time.sleep(brokerToController.retryTimeoutMs) + brokerToController.poll() + assertNotNull(response.get) + + val alterConfigResponse = response.get.asInstanceOf[AlterConfigsResponse] + assertEquals(Map(Errors.REQUEST_TIMED_OUT -> 1).asJava, alterConfigResponse.errorCounts) + } + + @Test + def testForwardingTimeoutAfterRetry(): Unit = { + val requestCorrelationId = 27 + val clientPrincipal = new KafkaPrincipal(KafkaPrincipal.USER_TYPE, "client") + val (requestHeader, requestBuffer) = buildRequest(testAlterConfigRequest, requestCorrelationId) + val request = buildRequest(requestHeader, requestBuffer, clientPrincipal) + + Mockito.when(controllerNodeProvider.get()).thenReturn(Some(new Node(0, "host", 1234))) + + val response = new AtomicReference[AbstractResponse]() + forwardingManager.forwardRequest(request, res => res.foreach(response.set)) + brokerToController.poll() + assertNull(response.get) + + // After reaching the retry timeout, we get a disconnect. Instead of retrying, + // we should fail the request with a timeout error. + time.sleep(brokerToController.retryTimeoutMs) + client.respond(testAlterConfigRequest.getErrorResponse(0, Errors.UNKNOWN_SERVER_ERROR.exception), true) + brokerToController.poll() + brokerToController.poll() + assertNotNull(response.get) + + val alterConfigResponse = response.get.asInstanceOf[AlterConfigsResponse] + assertEquals(Map(Errors.REQUEST_TIMED_OUT -> 1).asJava, alterConfigResponse.errorCounts) + } + + @Test + def testUnsupportedVersionFromNetworkClient(): Unit = { + val requestCorrelationId = 27 + val clientPrincipal = new KafkaPrincipal(KafkaPrincipal.USER_TYPE, "client") + val (requestHeader, requestBuffer) = buildRequest(testAlterConfigRequest, requestCorrelationId) + val request = buildRequest(requestHeader, requestBuffer, clientPrincipal) + + val controllerNode = new Node(0, "host", 1234) + Mockito.when(controllerNodeProvider.get()).thenReturn(Some(controllerNode)) + + client.prepareUnsupportedVersionResponse(req => req.apiKey == requestHeader.apiKey) + + val response = new AtomicReference[AbstractResponse]() + forwardingManager.forwardRequest(request, res => res.foreach(response.set)) + brokerToController.poll() + assertNotNull(response.get) + + val alterConfigResponse = response.get.asInstanceOf[AlterConfigsResponse] + assertEquals(Map(Errors.UNKNOWN_SERVER_ERROR -> 1).asJava, alterConfigResponse.errorCounts) + } + + @Test + def testFailedAuthentication(): Unit = { + val requestCorrelationId = 27 + val clientPrincipal = new KafkaPrincipal(KafkaPrincipal.USER_TYPE, "client") + val (requestHeader, requestBuffer) = buildRequest(testAlterConfigRequest, requestCorrelationId) + val request = buildRequest(requestHeader, requestBuffer, clientPrincipal) + + val controllerNode = new Node(0, "host", 1234) + Mockito.when(controllerNodeProvider.get()).thenReturn(Some(controllerNode)) + + client.createPendingAuthenticationError(controllerNode, 50) + + val response = new AtomicReference[AbstractResponse]() + forwardingManager.forwardRequest(request, res => res.foreach(response.set)) + brokerToController.poll() + assertNotNull(response.get) + + val alterConfigResponse = response.get.asInstanceOf[AlterConfigsResponse] + assertEquals(Map(Errors.UNKNOWN_SERVER_ERROR -> 1).asJava, alterConfigResponse.errorCounts) + } + + private def buildRequest( + body: AbstractRequest, + correlationId: Int + ): (RequestHeader, ByteBuffer) = { + val header = new RequestHeader( + body.apiKey, + body.version, + "clientId", + correlationId + ) + val buffer = 
body.serializeWithHeader(header) + + // Fast-forward buffer to start of the request as `RequestChannel.Request` expects + RequestHeader.parse(buffer) + + (header, buffer) + } + + private def buildRequest( + requestHeader: RequestHeader, + requestBuffer: ByteBuffer, + principal: KafkaPrincipal + ): RequestChannel.Request = { + val requestContext = new RequestContext( + requestHeader, + "1", + InetAddress.getLocalHost, + principal, + new ListenerName("client"), + SecurityProtocol.SASL_PLAINTEXT, + ClientInformation.EMPTY, + false, + Optional.of(principalBuilder) + ) + + new network.RequestChannel.Request( + processor = 1, + context = requestContext, + startTimeNanos = time.nanoseconds(), + memoryPool = MemoryPool.NONE, + buffer = requestBuffer, + metrics = new RequestChannel.Metrics(ListenerType.CONTROLLER), + envelope = None + ) + } + + private def testAlterConfigRequest: AlterConfigsRequest = { + val configResource = new ConfigResource(ConfigResource.Type.TOPIC, "foo") + val configs = List(new AlterConfigsRequest.ConfigEntry(TopicConfig.MIN_IN_SYNC_REPLICAS_CONFIG, "1")).asJava + new AlterConfigsRequest.Builder(Map( + configResource -> new AlterConfigsRequest.Config(configs) + ).asJava, false).build() + } + +} diff --git a/core/src/test/scala/unit/kafka/server/HighwatermarkPersistenceTest.scala b/core/src/test/scala/unit/kafka/server/HighwatermarkPersistenceTest.scala new file mode 100755 index 0000000..221fd9a --- /dev/null +++ b/core/src/test/scala/unit/kafka/server/HighwatermarkPersistenceTest.scala @@ -0,0 +1,195 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+*/ +package kafka.server + +import kafka.log._ +import java.io.File + +import org.apache.kafka.common.metrics.Metrics +import org.apache.kafka.common.utils.Utils +import org.junit.jupiter.api._ +import org.junit.jupiter.api.Assertions._ +import kafka.utils.{KafkaScheduler, MockTime, TestUtils} + +import kafka.cluster.Partition +import kafka.server.metadata.MockConfigRepository +import org.apache.kafka.common.TopicPartition +import org.apache.kafka.common.record.SimpleRecord + +class HighwatermarkPersistenceTest { + + val configs = TestUtils.createBrokerConfigs(2, TestUtils.MockZkConnect).map(KafkaConfig.fromProps) + val topic = "foo" + val configRepository = new MockConfigRepository() + val logManagers = configs map { config => + TestUtils.createLogManager( + logDirs = config.logDirs.map(new File(_)), + cleanerConfig = CleanerConfig()) + } + + val logDirFailureChannels = configs map { config => + new LogDirFailureChannel(config.logDirs.size) + } + + val alterIsrManager = TestUtils.createAlterIsrManager() + + @AfterEach + def teardown(): Unit = { + for (manager <- logManagers; dir <- manager.liveLogDirs) + Utils.delete(dir) + } + + @Test + def testHighWatermarkPersistenceSinglePartition(): Unit = { + // create kafka scheduler + val scheduler = new KafkaScheduler(2) + scheduler.startup() + val metrics = new Metrics + val time = new MockTime + val quotaManager = QuotaFactory.instantiate(configs.head, metrics, time, "") + // create replica manager + val replicaManager = new ReplicaManager( + metrics = metrics, + config = configs.head, + time = time, + scheduler = scheduler, + logManager = logManagers.head, + quotaManagers = quotaManager, + metadataCache = MetadataCache.zkMetadataCache(configs.head.brokerId), + logDirFailureChannel = logDirFailureChannels.head, + alterIsrManager = alterIsrManager) + replicaManager.startup() + try { + replicaManager.checkpointHighWatermarks() + var fooPartition0Hw = hwmFor(replicaManager, topic, 0) + assertEquals(0L, fooPartition0Hw) + val tp0 = new TopicPartition(topic, 0) + val partition0 = replicaManager.createPartition(tp0) + // create leader and follower replicas + val log0 = logManagers.head.getOrCreateLog(new TopicPartition(topic, 0), topicId = None) + partition0.setLog(log0, isFutureLog = false) + + partition0.updateAssignmentAndIsr( + assignment = Seq(configs.head.brokerId, configs.last.brokerId), + isr = Set(configs.head.brokerId), + addingReplicas = Seq.empty, + removingReplicas = Seq.empty + ) + + replicaManager.checkpointHighWatermarks() + fooPartition0Hw = hwmFor(replicaManager, topic, 0) + assertEquals(log0.highWatermark, fooPartition0Hw) + // set the high watermark for local replica + partition0.localLogOrException.updateHighWatermark(5L) + replicaManager.checkpointHighWatermarks() + fooPartition0Hw = hwmFor(replicaManager, topic, 0) + assertEquals(log0.highWatermark, fooPartition0Hw) + } finally { + // shutdown the replica manager upon test completion + replicaManager.shutdown(false) + quotaManager.shutdown() + metrics.close() + scheduler.shutdown() + } + } + + @Test + def testHighWatermarkPersistenceMultiplePartitions(): Unit = { + val topic1 = "foo1" + val topic2 = "foo2" + // create kafka scheduler + val scheduler = new KafkaScheduler(2) + scheduler.startup() + val metrics = new Metrics + val time = new MockTime + val quotaManager = QuotaFactory.instantiate(configs.head, metrics, time, "") + // create replica manager + val replicaManager = new ReplicaManager( + metrics = metrics, + config = configs.head, + time = time, + scheduler = 
scheduler, + logManager = logManagers.head, + quotaManagers = quotaManager, + metadataCache = MetadataCache.zkMetadataCache(configs.head.brokerId), + logDirFailureChannel = logDirFailureChannels.head, + alterIsrManager = alterIsrManager) + replicaManager.startup() + try { + replicaManager.checkpointHighWatermarks() + var topic1Partition0Hw = hwmFor(replicaManager, topic1, 0) + assertEquals(0L, topic1Partition0Hw) + val t1p0 = new TopicPartition(topic1, 0) + val topic1Partition0 = replicaManager.createPartition(t1p0) + // create leader log + val topic1Log0 = logManagers.head.getOrCreateLog(t1p0, topicId = None) + // create a local replica for topic1 + topic1Partition0.setLog(topic1Log0, isFutureLog = false) + replicaManager.checkpointHighWatermarks() + topic1Partition0Hw = hwmFor(replicaManager, topic1, 0) + assertEquals(topic1Log0.highWatermark, topic1Partition0Hw) + // set the high watermark for local replica + append(topic1Partition0, count = 5) + topic1Partition0.localLogOrException.updateHighWatermark(5L) + replicaManager.checkpointHighWatermarks() + topic1Partition0Hw = hwmFor(replicaManager, topic1, 0) + assertEquals(5L, topic1Log0.highWatermark) + assertEquals(5L, topic1Partition0Hw) + // add another partition and set highwatermark + val t2p0 = new TopicPartition(topic2, 0) + val topic2Partition0 = replicaManager.createPartition(t2p0) + // create leader log + val topic2Log0 = logManagers.head.getOrCreateLog(t2p0, topicId = None) + // create a local replica for topic2 + topic2Partition0.setLog(topic2Log0, isFutureLog = false) + replicaManager.checkpointHighWatermarks() + var topic2Partition0Hw = hwmFor(replicaManager, topic2, 0) + assertEquals(topic2Log0.highWatermark, topic2Partition0Hw) + // set the highwatermark for local replica + append(topic2Partition0, count = 15) + topic2Partition0.localLogOrException.updateHighWatermark(15L) + assertEquals(15L, topic2Log0.highWatermark) + // change the highwatermark for topic1 + append(topic1Partition0, count = 5) + topic1Partition0.localLogOrException.updateHighWatermark(10L) + assertEquals(10L, topic1Log0.highWatermark) + replicaManager.checkpointHighWatermarks() + // verify checkpointed hw for topic 2 + topic2Partition0Hw = hwmFor(replicaManager, topic2, 0) + assertEquals(15L, topic2Partition0Hw) + // verify checkpointed hw for topic 1 + topic1Partition0Hw = hwmFor(replicaManager, topic1, 0) + assertEquals(10L, topic1Partition0Hw) + } finally { + // shutdown the replica manager upon test completion + replicaManager.shutdown(false) + quotaManager.shutdown() + metrics.close() + scheduler.shutdown() + } + } + + private def append(partition: Partition, count: Int): Unit = { + val records = TestUtils.records((0 to count).map(i => new SimpleRecord(s"$i".getBytes))) + partition.localLogOrException.appendAsLeader(records, leaderEpoch = 0) + } + + private def hwmFor(replicaManager: ReplicaManager, topic: String, partition: Int): Long = { + replicaManager.highWatermarkCheckpoints(new File(replicaManager.config.logDirs.head).getAbsolutePath).read().getOrElse( + new TopicPartition(topic, partition), 0L) + } +} diff --git a/core/src/test/scala/unit/kafka/server/IsrExpirationTest.scala b/core/src/test/scala/unit/kafka/server/IsrExpirationTest.scala new file mode 100644 index 0000000..57e709c --- /dev/null +++ b/core/src/test/scala/unit/kafka/server/IsrExpirationTest.scala @@ -0,0 +1,258 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. +*/ +package kafka.server + +import java.io.File +import java.util.Properties + +import kafka.cluster.Partition +import kafka.log.{UnifiedLog, LogManager} +import kafka.server.QuotaFactory.QuotaManagers +import kafka.utils.TestUtils.MockAlterIsrManager +import kafka.utils._ +import org.apache.kafka.common.TopicPartition +import org.apache.kafka.common.metrics.Metrics +import org.apache.kafka.common.utils.Time +import org.easymock.EasyMock +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.{AfterEach, BeforeEach, Test} + +import scala.collection.Seq +import scala.collection.mutable.{HashMap, Map} + +class IsrExpirationTest { + + var topicPartitionIsr: Map[(String, Int), Seq[Int]] = new HashMap[(String, Int), Seq[Int]]() + val replicaLagTimeMaxMs = 100L + val replicaFetchWaitMaxMs = 100 + val leaderLogEndOffset = 20 + val leaderLogHighWatermark = 20L + + val overridingProps = new Properties() + overridingProps.put(KafkaConfig.ReplicaLagTimeMaxMsProp, replicaLagTimeMaxMs.toString) + overridingProps.put(KafkaConfig.ReplicaFetchWaitMaxMsProp, replicaFetchWaitMaxMs.toString) + val configs = TestUtils.createBrokerConfigs(2, TestUtils.MockZkConnect).map(KafkaConfig.fromProps(_, overridingProps)) + val topic = "foo" + + val time = new MockTime + val metrics = new Metrics + + var quotaManager: QuotaManagers = null + var replicaManager: ReplicaManager = null + + var alterIsrManager: MockAlterIsrManager = _ + + @BeforeEach + def setUp(): Unit = { + val logManager: LogManager = EasyMock.createMock(classOf[LogManager]) + EasyMock.expect(logManager.liveLogDirs).andReturn(Array.empty[File]).anyTimes() + EasyMock.replay(logManager) + + alterIsrManager = TestUtils.createAlterIsrManager() + quotaManager = QuotaFactory.instantiate(configs.head, metrics, time, "") + replicaManager = new ReplicaManager( + metrics = metrics, + config = configs.head, + time = time, + scheduler = null, + logManager = logManager, + quotaManagers = quotaManager, + metadataCache = MetadataCache.zkMetadataCache(configs.head.brokerId), + logDirFailureChannel = new LogDirFailureChannel(configs.head.logDirs.size), + alterIsrManager = alterIsrManager) + } + + @AfterEach + def tearDown(): Unit = { + Option(replicaManager).foreach(_.shutdown(false)) + Option(quotaManager).foreach(_.shutdown()) + metrics.close() + } + + /* + * Test the case where a follower is caught up but stops making requests to the leader. 
Once beyond the configured time limit, it should fall out of the ISR
+   */
+  @Test
+  def testIsrExpirationForStuckFollowers(): Unit = {
+    val log = logMock
+
+    // create one partition and all replicas
+    val partition0 = getPartitionWithAllReplicasInIsr(topic, 0, time, configs.head, log)
+    assertEquals(configs.map(_.brokerId).toSet, partition0.inSyncReplicaIds, "All replicas should be in ISR")
+
+    // let the follower catch up to the Leader logEndOffset - 1
+    for (replica <- partition0.remoteReplicas)
+      replica.updateFetchState(
+        followerFetchOffsetMetadata = LogOffsetMetadata(leaderLogEndOffset - 1),
+        followerStartOffset = 0L,
+        followerFetchTimeMs = time.milliseconds,
+        leaderEndOffset = leaderLogEndOffset)
+    var partition0OSR = partition0.getOutOfSyncReplicas(configs.head.replicaLagTimeMaxMs)
+    assertEquals(Set.empty[Int], partition0OSR, "No replica should be out of sync")
+
+    // let some time pass
+    time.sleep(150)
+
+    // now the follower hasn't pulled any data for more than replicaLagTimeMaxMs, so it is stuck
+    partition0OSR = partition0.getOutOfSyncReplicas(configs.head.replicaLagTimeMaxMs)
+    assertEquals(Set(configs.last.brokerId), partition0OSR, "Replica 1 should be out of sync")
+    EasyMock.verify(log)
+  }
+
+  /*
+   * Test the case where a follower never makes a fetch request. It should fall out of the ISR because it will be declared stuck
+   */
+  @Test
+  def testIsrExpirationIfNoFetchRequestMade(): Unit = {
+    val log = logMock
+
+    // create one partition and all replicas
+    val partition0 = getPartitionWithAllReplicasInIsr(topic, 0, time, configs.head, log)
+    assertEquals(configs.map(_.brokerId).toSet, partition0.inSyncReplicaIds, "All replicas should be in ISR")
+
+    // Let enough time pass for the replica to be considered stuck
+    time.sleep(150)
+
+    val partition0OSR = partition0.getOutOfSyncReplicas(configs.head.replicaLagTimeMaxMs)
+    assertEquals(Set(configs.last.brokerId), partition0OSR, "Replica 1 should be out of sync")
+    EasyMock.verify(log)
+  }
+
+  /*
+   * Test the case where a follower continually makes fetch requests but is unable to catch up. It should fall out of the ISR.
+   * However, any time it makes a fetch request that reaches the LogEndOffset it should be back in the ISR
+   */
+  @Test
+  def testIsrExpirationForSlowFollowers(): Unit = {
+    // create leader replica
+    val log = logMock
+    // add one partition
+    val partition0 = getPartitionWithAllReplicasInIsr(topic, 0, time, configs.head, log)
+    assertEquals(configs.map(_.brokerId).toSet, partition0.inSyncReplicaIds, "All replicas should be in ISR")
+    // Make the remote replica not read to the end of the log. It should not be out of sync for at least 100 ms
+    for (replica <- partition0.remoteReplicas)
+      replica.updateFetchState(
+        followerFetchOffsetMetadata = LogOffsetMetadata(leaderLogEndOffset - 2),
+        followerStartOffset = 0L,
+        followerFetchTimeMs = time.milliseconds,
+        leaderEndOffset = leaderLogEndOffset)
+
+    // Simulate 2 fetch requests spanning more than 100 ms which do not read to the end of the log.
+    // The replicas will no longer be in the ISR. We do 2 fetches because we want to simulate the case where the replica is lagging but is not stuck
+    var partition0OSR = partition0.getOutOfSyncReplicas(configs.head.replicaLagTimeMaxMs)
+    assertEquals(Set.empty[Int], partition0OSR, "No replica should be out of sync")
+
+    time.sleep(75)
+
+    partition0.remoteReplicas.foreach { r =>
+      r.updateFetchState(
+        followerFetchOffsetMetadata = LogOffsetMetadata(leaderLogEndOffset - 1),
+        followerStartOffset = 0L,
+        followerFetchTimeMs = time.milliseconds,
+        leaderEndOffset = leaderLogEndOffset)
+    }
+    partition0OSR = partition0.getOutOfSyncReplicas(configs.head.replicaLagTimeMaxMs)
+    assertEquals(Set.empty[Int], partition0OSR, "No replica should be out of sync")
+
+    time.sleep(75)
+
+    // The replicas will no longer be in the ISR
+    partition0OSR = partition0.getOutOfSyncReplicas(configs.head.replicaLagTimeMaxMs)
+    assertEquals(Set(configs.last.brokerId), partition0OSR, "Replica 1 should be out of sync")
+
+    // Now actually make a fetch to the end of the log. The replicas should be back in the ISR
+    partition0.remoteReplicas.foreach { r =>
+      r.updateFetchState(
+        followerFetchOffsetMetadata = LogOffsetMetadata(leaderLogEndOffset),
+        followerStartOffset = 0L,
+        followerFetchTimeMs = time.milliseconds,
+        leaderEndOffset = leaderLogEndOffset)
+    }
+    partition0OSR = partition0.getOutOfSyncReplicas(configs.head.replicaLagTimeMaxMs)
+    assertEquals(Set.empty[Int], partition0OSR, "No replica should be out of sync")
+
+    EasyMock.verify(log)
+  }
+
+  /*
+   * Test the case where a follower has already caught up to the same log end offset as the leader. This follower should not be considered out of sync
+   */
+  @Test
+  def testIsrExpirationForCaughtUpFollowers(): Unit = {
+    val log = logMock
+
+    // create one partition and all replicas
+    val partition0 = getPartitionWithAllReplicasInIsr(topic, 0, time, configs.head, log)
+    assertEquals(configs.map(_.brokerId).toSet, partition0.inSyncReplicaIds, "All replicas should be in ISR")
+
+    // let the follower catch up to the Leader logEndOffset
+    for (replica <- partition0.remoteReplicas)
+      replica.updateFetchState(
+        followerFetchOffsetMetadata = LogOffsetMetadata(leaderLogEndOffset),
+        followerStartOffset = 0L,
+        followerFetchTimeMs = time.milliseconds,
+        leaderEndOffset = leaderLogEndOffset)
+
+    var partition0OSR = partition0.getOutOfSyncReplicas(configs.head.replicaLagTimeMaxMs)
+    assertEquals(Set.empty[Int], partition0OSR, "No replica should be out of sync")
+
+    // let some time pass
+    time.sleep(150)
+
+    // even though the follower hasn't pulled any data for more than replicaLagTimeMaxMs, it has already caught up, so it is not out of sync
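+    // (the lag-time check only applies to followers whose log end offset trails the leader's, so a replica
+    // that is already at the leader's log end offset remains in the ISR regardless of how much time passes)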
+ partition0OSR = partition0.getOutOfSyncReplicas(configs.head.replicaLagTimeMaxMs) + assertEquals(Set.empty[Int], partition0OSR, "No replica should be out of sync") + EasyMock.verify(log) + } + + private def getPartitionWithAllReplicasInIsr(topic: String, partitionId: Int, time: Time, config: KafkaConfig, + localLog: UnifiedLog): Partition = { + val leaderId = config.brokerId + val tp = new TopicPartition(topic, partitionId) + val partition = replicaManager.createPartition(tp) + partition.setLog(localLog, isFutureLog = false) + + partition.updateAssignmentAndIsr( + assignment = configs.map(_.brokerId), + isr = configs.map(_.brokerId).toSet, + addingReplicas = Seq.empty, + removingReplicas = Seq.empty + ) + + // set lastCaughtUpTime to current time + for (replica <- partition.remoteReplicas) + replica.updateFetchState( + followerFetchOffsetMetadata = LogOffsetMetadata(0L), + followerStartOffset = 0L, + followerFetchTimeMs= time.milliseconds, + leaderEndOffset = 0L) + + // set the leader and its hw and the hw update time + partition.leaderReplicaIdOpt = Some(leaderId) + partition + } + + private def logMock: UnifiedLog = { + val log: UnifiedLog = EasyMock.createMock(classOf[UnifiedLog]) + EasyMock.expect(log.dir).andReturn(TestUtils.tempDir()).anyTimes() + EasyMock.expect(log.logEndOffsetMetadata).andReturn(LogOffsetMetadata(leaderLogEndOffset)).anyTimes() + EasyMock.expect(log.logEndOffset).andReturn(leaderLogEndOffset).anyTimes() + EasyMock.expect(log.highWatermark).andReturn(leaderLogHighWatermark).anyTimes() + EasyMock.replay(log) + log + } +} diff --git a/core/src/test/scala/unit/kafka/server/KRaftMetadataTest.scala.tmp b/core/src/test/scala/unit/kafka/server/KRaftMetadataTest.scala.tmp new file mode 100644 index 0000000..8121234 --- /dev/null +++ b/core/src/test/scala/unit/kafka/server/KRaftMetadataTest.scala.tmp @@ -0,0 +1,161 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package kafka.server + +import kafka.test.ClusterInstance +import kafka.test.annotation.{ClusterTest, ClusterTestDefaults, Type} +import kafka.test.junit.ClusterTestExtensions +import kafka.test.junit.RaftClusterInvocationContext.RaftClusterInstance +import kafka.utils.TestUtils +import org.apache.kafka.common.feature.{Features, SupportedVersionRange} +import org.apache.kafka.common.message.BrokerRegistrationRequestData.{Listener, ListenerCollection} +import org.apache.kafka.common.metadata.MetadataRecordType.{FEATURE_LEVEL_RECORD, REGISTER_BROKER_RECORD} +import org.apache.kafka.common.metadata.{FeatureLevelRecord, RegisterBrokerRecord} +import org.apache.kafka.common.utils.Utils +import org.apache.kafka.controller.QuorumController +import org.apache.kafka.metadata.{FeatureMapAndEpoch, VersionRange} +import org.apache.kafka.server.common.ApiMessageAndVersion +import org.junit.jupiter.api.Tag +import org.junit.jupiter.api.extension.ExtendWith + +import scala.jdk.CollectionConverters._ + +@ExtendWith(value = Array(classOf[ClusterTestExtensions])) +@ClusterTestDefaults(clusterType = Type.KRAFT, brokers = 3, controllers = 3) +@Tag("integration") +class KRaftMetadataTest(cluster: ClusterInstance) { + + var brokerServers: Seq[BrokerServer] = _ + var controllerServers: Seq[ControllerServer] = _ + var activeController: ControllerServer = _ + var epoch: Int = _ + + def updateFinalizedVersion(apiVersionAndMessages: List[ApiMessageAndVersion]): Unit = { + val offset = updateMetadata(apiVersionAndMessages) + brokerServers.foreach(s => { + s.featureCache.waitUntilEpochOrThrow(offset, s.config.zkConnectionTimeoutMs) + }) + TestUtils.waitUntilTrue( + () => try { + activeController.controller.finalizedFeatures().get() // .map().features() + true + } catch { + case _: Throwable => false + }, + "Controller did not get broker updates" + ) + } + + def updateSupportedVersion(features: Features[SupportedVersionRange], + targetServers: Seq[BrokerServer]): Unit = { + targetServers.foreach(brokerServer => { + TestUtils.waitUntilTrue(() => brokerServer.lifecycleManager.brokerEpoch != -1, "broker registration failed") + brokerServer.brokerFeatures.setSupportedFeatures(features) + updateMetadata(List(toApiMessageAndVersion(features, brokerServer))) + }) + + val brokerRegistrations = activeController.controller.asInstanceOf[QuorumController].brokerRegistrations() + brokerRegistrations.asScala.foreach { case (_, value) => + TestUtils.waitUntilTrue( + () => value.supportedFeatures().asScala == toVersionRanges(features), + "Controller did not get broker updates" + ) + } + } + + def toVersionRanges(features: Features[SupportedVersionRange]): Map[String, VersionRange] = { + features.features().asScala.map { case (key, value) => + (key, new VersionRange(value.min(), value.max())) + }.toMap + } + + def toApiMessageAndVersion(features: Features[SupportedVersionRange], + brokerServer: BrokerServer): ApiMessageAndVersion = { + val networkListeners = new ListenerCollection() + brokerServer.config.advertisedListeners.foreach { ep => + networkListeners.add(new Listener(). + setHost(ep.host). + setName(ep.listenerName.value()). + setPort(ep.port). 
+ setSecurityProtocol(ep.securityProtocol.id)) + } + + val featureCollection = new RegisterBrokerRecord.BrokerFeatureCollection() + features.features().asScala.foreach{ feature => + featureCollection.add(new RegisterBrokerRecord.BrokerFeature() + .setName(feature._1) + .setMinSupportedVersion(feature._2.min()) + .setMaxSupportedVersion(feature._2.max())) + } + new ApiMessageAndVersion( + new RegisterBrokerRecord() + .setBrokerId(brokerServer.config.nodeId) + .setEndPoints(new RegisterBrokerRecord.BrokerEndpointCollection()) + .setBrokerEpoch(brokerServer.lifecycleManager.brokerEpoch) + .setFeatures(featureCollection), + REGISTER_BROKER_RECORD.highestSupportedVersion() + ) + } + + def updateMetadata(apiVersionAndMessages: List[ApiMessageAndVersion]): Long = { + // Append to controller + val offset = activeController.controller.asInstanceOf[QuorumController].updateMetadata(apiVersionAndMessages.asJava) + // Wait raft response + offset.get() + } + + def getFeatureMetadataData(): FeatureMapAndEpoch = + activeController.controller.finalizedFeatures().get() + + @ClusterTest + def testUpdateFinalizedVersion(): Unit = { + val raftCluster = cluster.asInstanceOf[RaftClusterInstance] + activeController = raftCluster.activeController() + brokerServers = raftCluster.brokerServers().asScala.toSeq + controllerServers = raftCluster.controllerServers().asScala.toSeq + epoch = activeController.controller.curClaimEpoch() + + updateFinalizedVersion( + List( + new ApiMessageAndVersion( + new FeatureLevelRecord() + .setName("feature") + .setMinFeatureLevel(1) + .setMaxFeatureLevel(2), + FEATURE_LEVEL_RECORD.highestSupportedVersion() + ) + ) + ) + + println(getFeatureMetadataData()) + } + + @ClusterTest + def testUpdateSupportedVersion(): Unit = { + val raftCluster = cluster.asInstanceOf[RaftClusterInstance] + activeController = raftCluster.activeController() + brokerServers = raftCluster.brokerServers().asScala.toSeq + controllerServers = raftCluster.controllerServers().asScala.toSeq + epoch = activeController.controller.curClaimEpoch() + + updateSupportedVersion( + Features.supportedFeatures(Utils.mkMap(Utils.mkEntry("feature_1", new SupportedVersionRange(1, 3)))), + brokerServers + ) + } + +} diff --git a/core/src/test/scala/unit/kafka/server/KafkaApisTest.scala b/core/src/test/scala/unit/kafka/server/KafkaApisTest.scala new file mode 100644 index 0000000..98c2666 --- /dev/null +++ b/core/src/test/scala/unit/kafka/server/KafkaApisTest.scala @@ -0,0 +1,4218 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package kafka.server + +import java.net.InetAddress +import java.nio.charset.StandardCharsets +import java.util +import java.util.Arrays.asList +import java.util.concurrent.TimeUnit +import java.util.{Collections, Optional, Properties, Random} + +import kafka.api.{ApiVersion, KAFKA_0_10_2_IV0, KAFKA_2_2_IV1, LeaderAndIsr} +import kafka.cluster.{Broker, Partition} +import kafka.controller.{ControllerContext, KafkaController} +import kafka.coordinator.group.GroupCoordinatorConcurrencyTest.{JoinGroupCallback, SyncGroupCallback} +import kafka.coordinator.group._ +import kafka.coordinator.transaction.{InitProducerIdResult, TransactionCoordinator} +import kafka.log.AppendOrigin +import kafka.network.RequestChannel +import kafka.server.QuotaFactory.QuotaManagers +import kafka.server.metadata.{ConfigRepository, KRaftMetadataCache, MockConfigRepository, ZkMetadataCache} +import kafka.utils.{MockTime, TestUtils} +import kafka.zk.KafkaZkClient +import org.apache.kafka.clients.admin.AlterConfigOp.OpType +import org.apache.kafka.clients.admin.{AlterConfigOp, ConfigEntry} +import org.apache.kafka.common.acl.AclOperation +import org.apache.kafka.common.config.ConfigResource +import org.apache.kafka.common.errors.UnsupportedVersionException +import org.apache.kafka.common.internals.{KafkaFutureImpl, Topic} +import org.apache.kafka.common.memory.MemoryPool +import org.apache.kafka.common.message.ApiMessageType.ListenerType +import org.apache.kafka.common.message.CreateTopicsRequestData.{CreatableTopic, CreatableTopicCollection} +import org.apache.kafka.common.message.DescribeConfigsResponseData.DescribeConfigsResult +import org.apache.kafka.common.message.JoinGroupRequestData.JoinGroupRequestProtocol +import org.apache.kafka.common.message.LeaveGroupRequestData.MemberIdentity +import org.apache.kafka.common.message.ListOffsetsRequestData.{ListOffsetsPartition, ListOffsetsTopic} +import org.apache.kafka.common.message.MetadataResponseData.MetadataResponseTopic +import org.apache.kafka.common.message.OffsetDeleteRequestData.{OffsetDeleteRequestPartition, OffsetDeleteRequestTopic, OffsetDeleteRequestTopicCollection} +import org.apache.kafka.common.message.StopReplicaRequestData.{StopReplicaPartitionState, StopReplicaTopicState} +import org.apache.kafka.common.message.UpdateMetadataRequestData.{UpdateMetadataBroker, UpdateMetadataEndpoint, UpdateMetadataPartitionState} +import org.apache.kafka.common.message._ +import org.apache.kafka.common.metrics.Metrics +import org.apache.kafka.common.network.{ClientInformation, ListenerName} +import org.apache.kafka.common.protocol.{ApiKeys, Errors} +import org.apache.kafka.common.quota.{ClientQuotaAlteration, ClientQuotaEntity} +import org.apache.kafka.common.record.FileRecords.TimestampAndOffset +import org.apache.kafka.common.record._ +import org.apache.kafka.common.replica.ClientMetadata +import org.apache.kafka.common.requests.FindCoordinatorRequest.CoordinatorType +import org.apache.kafka.common.requests.MetadataResponse.TopicMetadata +import org.apache.kafka.common.requests.ProduceResponse.PartitionResponse +import org.apache.kafka.common.requests.WriteTxnMarkersRequest.TxnMarkerEntry +import org.apache.kafka.common.requests.{FetchMetadata => JFetchMetadata, _} +import org.apache.kafka.common.resource.{PatternType, Resource, ResourcePattern, ResourceType} +import org.apache.kafka.common.security.auth.{KafkaPrincipal, KafkaPrincipalSerde, SecurityProtocol} +import org.apache.kafka.common.utils.{ProducerIdAndEpoch, SecurityUtils, Utils} +import 
org.apache.kafka.common.{ElectionType, IsolationLevel, Node, TopicIdPartition, TopicPartition, Uuid} +import org.apache.kafka.server.authorizer.{Action, AuthorizationResult, Authorizer} +import org.easymock.EasyMock._ +import org.easymock.{Capture, EasyMock, IAnswer} +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.{AfterEach, Test} +import org.junit.jupiter.params.ParameterizedTest +import org.junit.jupiter.params.provider.ValueSource +import org.mockito.{ArgumentMatchers, Mockito} + +import scala.collection.{Map, Seq, mutable} +import scala.jdk.CollectionConverters._ +import java.util.Arrays + +class KafkaApisTest { + + private val requestChannel: RequestChannel = EasyMock.createNiceMock(classOf[RequestChannel]) + private val requestChannelMetrics: RequestChannel.Metrics = EasyMock.createNiceMock(classOf[RequestChannel.Metrics]) + private val replicaManager: ReplicaManager = EasyMock.createNiceMock(classOf[ReplicaManager]) + private val groupCoordinator: GroupCoordinator = EasyMock.createNiceMock(classOf[GroupCoordinator]) + private val adminManager: ZkAdminManager = EasyMock.createNiceMock(classOf[ZkAdminManager]) + private val txnCoordinator: TransactionCoordinator = EasyMock.createNiceMock(classOf[TransactionCoordinator]) + private val controller: KafkaController = EasyMock.createNiceMock(classOf[KafkaController]) + private val forwardingManager: ForwardingManager = EasyMock.createNiceMock(classOf[ForwardingManager]) + private val autoTopicCreationManager: AutoTopicCreationManager = EasyMock.createNiceMock(classOf[AutoTopicCreationManager]) + + private val kafkaPrincipalSerde = new KafkaPrincipalSerde { + override def serialize(principal: KafkaPrincipal): Array[Byte] = Utils.utf8(principal.toString) + override def deserialize(bytes: Array[Byte]): KafkaPrincipal = SecurityUtils.parseKafkaPrincipal(Utils.utf8(bytes)) + } + private val zkClient: KafkaZkClient = EasyMock.createNiceMock(classOf[KafkaZkClient]) + private val metrics = new Metrics() + private val brokerId = 1 + private var metadataCache: MetadataCache = MetadataCache.zkMetadataCache(brokerId) + private val clientQuotaManager: ClientQuotaManager = EasyMock.createNiceMock(classOf[ClientQuotaManager]) + private val clientRequestQuotaManager: ClientRequestQuotaManager = EasyMock.createNiceMock(classOf[ClientRequestQuotaManager]) + private val clientControllerQuotaManager: ControllerMutationQuotaManager = EasyMock.createNiceMock(classOf[ControllerMutationQuotaManager]) + private val replicaQuotaManager: ReplicationQuotaManager = EasyMock.createNiceMock(classOf[ReplicationQuotaManager]) + private val quotas = QuotaManagers(clientQuotaManager, clientQuotaManager, clientRequestQuotaManager, + clientControllerQuotaManager, replicaQuotaManager, replicaQuotaManager, replicaQuotaManager, None) + private val fetchManager: FetchManager = EasyMock.createNiceMock(classOf[FetchManager]) + private val brokerTopicStats = new BrokerTopicStats + private val clusterId = "clusterId" + private val time = new MockTime + private val clientId = "" + + @AfterEach + def tearDown(): Unit = { + quotas.shutdown() + TestUtils.clearYammerMetrics() + metrics.close() + } + + def createKafkaApis(interBrokerProtocolVersion: ApiVersion = ApiVersion.latestVersion, + authorizer: Option[Authorizer] = None, + enableForwarding: Boolean = false, + configRepository: ConfigRepository = new MockConfigRepository(), + raftSupport: Boolean = false, + overrideProperties: Map[String, String] = Map.empty): KafkaApis = { + + val properties = if 
(raftSupport) { + val properties = TestUtils.createBrokerConfig(brokerId, "") + properties.put(KafkaConfig.NodeIdProp, brokerId.toString) + properties.put(KafkaConfig.ProcessRolesProp, "broker") + val voterId = (brokerId + 1) + properties.put(KafkaConfig.QuorumVotersProp, s"$voterId@localhost:9093") + properties.put(KafkaConfig.ControllerListenerNamesProp, "SSL") + properties + } else { + TestUtils.createBrokerConfig(brokerId, "zk") + } + overrideProperties.foreach( p => properties.put(p._1, p._2)) + TestUtils.setIbpAndMessageFormatVersions(properties, interBrokerProtocolVersion) + val config = new KafkaConfig(properties) + + val forwardingManagerOpt = if (enableForwarding) + Some(this.forwardingManager) + else + None + + val metadataSupport = if (raftSupport) { + // it will be up to the test to replace the default ZkMetadataCache implementation + // with a KRaftMetadataCache instance + metadataCache match { + case cache: KRaftMetadataCache => RaftSupport(forwardingManager, cache) + case _ => throw new IllegalStateException("Test must set an instance of KRaftMetadataCache") + } + } else { + metadataCache match { + case zkMetadataCache: ZkMetadataCache => + ZkSupport(adminManager, controller, zkClient, forwardingManagerOpt, zkMetadataCache) + case _ => throw new IllegalStateException("Test must set an instance of ZkMetadataCache") + } + } + + val listenerType = if (raftSupport) ListenerType.BROKER else ListenerType.ZK_BROKER + val enabledApis = if (enableForwarding) { + ApiKeys.apisForListener(listenerType).asScala ++ Set(ApiKeys.ENVELOPE) + } else { + ApiKeys.apisForListener(listenerType).asScala.toSet + } + val apiVersionManager = new SimpleApiVersionManager(listenerType, enabledApis) + + new KafkaApis( + metadataSupport = metadataSupport, + requestChannel = requestChannel, + replicaManager = replicaManager, + groupCoordinator = groupCoordinator, + txnCoordinator = txnCoordinator, + autoTopicCreationManager = autoTopicCreationManager, + brokerId = brokerId, + config = config, + configRepository = configRepository, + metadataCache = metadataCache, + metrics = metrics, + authorizer = authorizer, + quotas = quotas, + fetchManager = fetchManager, + brokerTopicStats = brokerTopicStats, + clusterId = clusterId, + time = time, + tokenManager = null, + apiVersionManager = apiVersionManager) + } + + @Test + def testDescribeConfigsWithAuthorizer(): Unit = { + val authorizer: Authorizer = EasyMock.niceMock(classOf[Authorizer]) + + val operation = AclOperation.DESCRIBE_CONFIGS + val resourceType = ResourceType.TOPIC + val resourceName = "topic-1" + val requestHeader = new RequestHeader(ApiKeys.DESCRIBE_CONFIGS, ApiKeys.DESCRIBE_CONFIGS.latestVersion, + clientId, 0) + + val expectedActions = Seq( + new Action(operation, new ResourcePattern(resourceType, resourceName, PatternType.LITERAL), + 1, true, true) + ) + + // Verify that authorize is only called once + EasyMock.expect(authorizer.authorize(anyObject[RequestContext], EasyMock.eq(expectedActions.asJava))) + .andReturn(Seq(AuthorizationResult.ALLOWED).asJava) + .once() + + val configRepository: ConfigRepository = EasyMock.strictMock(classOf[ConfigRepository]) + val topicConfigs = new Properties() + val propName = "min.insync.replicas" + val propValue = "3" + topicConfigs.put(propName, propValue) + EasyMock.expect(configRepository.topicConfig(resourceName)).andReturn(topicConfigs) + + metadataCache = + EasyMock.partialMockBuilder(classOf[ZkMetadataCache]) + .withConstructor(classOf[Int]) + .withArgs(Int.box(brokerId)) // Need to box it for Scala 
2.12 and before + .addMockedMethod("contains", classOf[String]) + .createMock() + + expect(metadataCache.contains(resourceName)).andReturn(true) + + val describeConfigsRequest = new DescribeConfigsRequest.Builder(new DescribeConfigsRequestData() + .setIncludeSynonyms(true) + .setResources(List(new DescribeConfigsRequestData.DescribeConfigsResource() + .setResourceName(resourceName) + .setResourceType(ConfigResource.Type.TOPIC.id)).asJava)) + .build(requestHeader.apiVersion) + val request = buildRequest(describeConfigsRequest, + requestHeader = Option(requestHeader)) + val capturedResponse = expectNoThrottling(request) + + EasyMock.replay(metadataCache, replicaManager, clientRequestQuotaManager, requestChannel, + authorizer, configRepository, adminManager) + createKafkaApis(authorizer = Some(authorizer), configRepository = configRepository) + .handleDescribeConfigsRequest(request) + + verify(authorizer, replicaManager) + + val response = capturedResponse.getValue.asInstanceOf[DescribeConfigsResponse] + val results = response.data().results() + assertEquals(1, results.size()) + val describeConfigsResult: DescribeConfigsResult = results.get(0) + assertEquals(ConfigResource.Type.TOPIC.id, describeConfigsResult.resourceType()) + assertEquals(resourceName, describeConfigsResult.resourceName()) + val configs = describeConfigsResult.configs().asScala.filter(_.name() == propName) + assertEquals(1, configs.length) + val describeConfigsResponseData = configs.head + assertEquals(propName, describeConfigsResponseData.name()) + assertEquals(propValue, describeConfigsResponseData.value()) + } + + @Test + def testEnvelopeRequestHandlingAsController(): Unit = { + testEnvelopeRequestWithAlterConfig( + alterConfigHandler = () => ApiError.NONE, + expectedError = Errors.NONE + ) + } + + @Test + def testEnvelopeRequestWithAlterConfigUnhandledError(): Unit = { + testEnvelopeRequestWithAlterConfig( + alterConfigHandler = () => throw new IllegalStateException(), + expectedError = Errors.UNKNOWN_SERVER_ERROR + ) + } + + private def testEnvelopeRequestWithAlterConfig( + alterConfigHandler: () => ApiError, + expectedError: Errors + ): Unit = { + val authorizer: Authorizer = EasyMock.niceMock(classOf[Authorizer]) + + authorizeResource(authorizer, AclOperation.CLUSTER_ACTION, ResourceType.CLUSTER, Resource.CLUSTER_NAME, AuthorizationResult.ALLOWED) + + val operation = AclOperation.ALTER_CONFIGS + val resourceName = "topic-1" + val requestHeader = new RequestHeader(ApiKeys.ALTER_CONFIGS, ApiKeys.ALTER_CONFIGS.latestVersion, + clientId, 0) + + EasyMock.expect(controller.isActive).andReturn(true) + + authorizeResource(authorizer, operation, ResourceType.TOPIC, resourceName, AuthorizationResult.ALLOWED) + + val configResource = new ConfigResource(ConfigResource.Type.TOPIC, resourceName) + EasyMock.expect(adminManager.alterConfigs(anyObject(), EasyMock.eq(false))) + .andAnswer(() => { + Map(configResource -> alterConfigHandler.apply()) + }) + + val configs = Map( + configResource -> new AlterConfigsRequest.Config( + Seq(new AlterConfigsRequest.ConfigEntry("foo", "bar")).asJava)) + val alterConfigsRequest = new AlterConfigsRequest.Builder(configs.asJava, false).build(requestHeader.apiVersion) + + val request = TestUtils.buildRequestWithEnvelope( + alterConfigsRequest, kafkaPrincipalSerde, requestChannelMetrics, time.nanoseconds()) + + val capturedResponse = EasyMock.newCapture[AbstractResponse]() + val capturedRequest = EasyMock.newCapture[RequestChannel.Request]() + + EasyMock.expect(requestChannel.sendResponse( + 
EasyMock.capture(capturedRequest), + EasyMock.capture(capturedResponse), + EasyMock.anyObject() + )) + + EasyMock.replay(replicaManager, clientRequestQuotaManager, requestChannel, authorizer, + adminManager, controller) + createKafkaApis(authorizer = Some(authorizer), enableForwarding = true).handle(request, RequestLocal.withThreadConfinedCaching) + + assertEquals(Some(request), capturedRequest.getValue.envelope) + val innerResponse = capturedResponse.getValue.asInstanceOf[AlterConfigsResponse] + val responseMap = innerResponse.data.responses().asScala.map { resourceResponse => + resourceResponse.resourceName() -> Errors.forCode(resourceResponse.errorCode) + }.toMap + + assertEquals(Map(resourceName -> expectedError), responseMap) + + verify(authorizer, controller, adminManager) + } + + @Test + def testInvalidEnvelopeRequestWithNonForwardableAPI(): Unit = { + val requestHeader = new RequestHeader(ApiKeys.LEAVE_GROUP, ApiKeys.LEAVE_GROUP.latestVersion, + clientId, 0) + val leaveGroupRequest = new LeaveGroupRequest.Builder("group", + Collections.singletonList(new MemberIdentity())).build(requestHeader.apiVersion) + + EasyMock.expect(controller.isActive).andReturn(true) + + val request = TestUtils.buildRequestWithEnvelope( + leaveGroupRequest, kafkaPrincipalSerde, requestChannelMetrics, time.nanoseconds()) + + val capturedResponse = expectNoThrottling(request) + + EasyMock.replay(replicaManager, clientRequestQuotaManager, requestChannel, controller) + createKafkaApis(enableForwarding = true).handle(request, RequestLocal.withThreadConfinedCaching) + + val response = capturedResponse.getValue.asInstanceOf[EnvelopeResponse] + assertEquals(Errors.INVALID_REQUEST, response.error()) + } + + @Test + def testEnvelopeRequestWithNotFromPrivilegedListener(): Unit = { + testInvalidEnvelopeRequest(Errors.NONE, fromPrivilegedListener = false, + shouldCloseConnection = true) + } + + @Test + def testEnvelopeRequestNotAuthorized(): Unit = { + testInvalidEnvelopeRequest(Errors.CLUSTER_AUTHORIZATION_FAILED, + performAuthorize = true, authorizeResult = AuthorizationResult.DENIED) + } + + @Test + def testEnvelopeRequestNotControllerHandling(): Unit = { + testInvalidEnvelopeRequest(Errors.NOT_CONTROLLER, performAuthorize = true, isActiveController = false) + } + + private def testInvalidEnvelopeRequest(expectedError: Errors, + fromPrivilegedListener: Boolean = true, + shouldCloseConnection: Boolean = false, + performAuthorize: Boolean = false, + authorizeResult: AuthorizationResult = AuthorizationResult.ALLOWED, + isActiveController: Boolean = true): Unit = { + val authorizer: Authorizer = EasyMock.niceMock(classOf[Authorizer]) + + if (performAuthorize) { + authorizeResource(authorizer, AclOperation.CLUSTER_ACTION, ResourceType.CLUSTER, Resource.CLUSTER_NAME, authorizeResult) + } + + val resourceName = "topic-1" + val requestHeader = new RequestHeader(ApiKeys.ALTER_CONFIGS, ApiKeys.ALTER_CONFIGS.latestVersion, + clientId, 0) + + EasyMock.expect(controller.isActive).andReturn(isActiveController) + + val configResource = new ConfigResource(ConfigResource.Type.TOPIC, resourceName) + + val configs = Map( + configResource -> new AlterConfigsRequest.Config( + Seq(new AlterConfigsRequest.ConfigEntry("foo", "bar")).asJava)) + val alterConfigsRequest = new AlterConfigsRequest.Builder(configs.asJava, false) + .build(requestHeader.apiVersion) + + val request = TestUtils.buildRequestWithEnvelope( + alterConfigsRequest, kafkaPrincipalSerde, requestChannelMetrics, time.nanoseconds(), fromPrivilegedListener) + + val 
capturedResponse = EasyMock.newCapture[AbstractResponse]() + if (shouldCloseConnection) { + EasyMock.expect(requestChannel.closeConnection( + EasyMock.eq(request), + EasyMock.eq(java.util.Collections.emptyMap()) + )) + } else { + EasyMock.expect(requestChannel.sendResponse( + EasyMock.eq(request), + EasyMock.capture(capturedResponse), + EasyMock.eq(None) + )) + } + + EasyMock.replay(replicaManager, clientRequestQuotaManager, requestChannel, authorizer, + adminManager, controller) + createKafkaApis(authorizer = Some(authorizer), enableForwarding = true).handle(request, RequestLocal.withThreadConfinedCaching) + + if (!shouldCloseConnection) { + val response = capturedResponse.getValue.asInstanceOf[EnvelopeResponse] + assertEquals(expectedError, response.error) + } + + verify(authorizer, adminManager, requestChannel) + } + + @Test + def testAlterConfigsWithAuthorizer(): Unit = { + val authorizer: Authorizer = EasyMock.niceMock(classOf[Authorizer]) + + val authorizedTopic = "authorized-topic" + val unauthorizedTopic = "unauthorized-topic" + val (authorizedResource, unauthorizedResource) = + createConfigsWithAuthorization(authorizer, authorizedTopic, unauthorizedTopic) + + val configs = Map( + authorizedResource -> new AlterConfigsRequest.Config( + Seq(new AlterConfigsRequest.ConfigEntry("foo", "bar")).asJava), + unauthorizedResource -> new AlterConfigsRequest.Config( + Seq(new AlterConfigsRequest.ConfigEntry("foo-1", "bar-1")).asJava) + ) + + val topicHeader = new RequestHeader(ApiKeys.ALTER_CONFIGS, ApiKeys.ALTER_CONFIGS.latestVersion, + clientId, 0) + + val alterConfigsRequest = new AlterConfigsRequest.Builder(configs.asJava, false) + .build(topicHeader.apiVersion) + val request = buildRequest(alterConfigsRequest) + + EasyMock.expect(controller.isActive).andReturn(false) + + val capturedResponse = expectNoThrottling(request) + + EasyMock.expect(adminManager.alterConfigs(anyObject(), EasyMock.eq(false))) + .andReturn(Map(authorizedResource -> ApiError.NONE)) + + EasyMock.replay(replicaManager, clientRequestQuotaManager, requestChannel, authorizer, + adminManager, controller) + + createKafkaApis(authorizer = Some(authorizer)).handleAlterConfigsRequest(request) + + verifyAlterConfigResult(capturedResponse, Map(authorizedTopic -> Errors.NONE, + unauthorizedTopic -> Errors.TOPIC_AUTHORIZATION_FAILED)) + + verify(authorizer, adminManager) + } + + @Test + def testAlterConfigsWithForwarding(): Unit = { + val requestBuilder = new AlterConfigsRequest.Builder(Collections.emptyMap(), false) + testForwardableApi(ApiKeys.ALTER_CONFIGS, requestBuilder) + } + + @Test + def testElectLeadersForwarding(): Unit = { + val requestBuilder = new ElectLeadersRequest.Builder(ElectionType.PREFERRED, null, 30000) + testKraftForwarding(ApiKeys.ELECT_LEADERS, requestBuilder) + } + + @Test + def testDescribeQuorumNotAllowedForZkClusters(): Unit = { + val requestData = DescribeQuorumRequest.singletonRequest(KafkaRaftServer.MetadataPartition) + val requestBuilder = new DescribeQuorumRequest.Builder(requestData) + val request = buildRequest(requestBuilder.build()) + + val capturedResponse = expectNoThrottling(request) + EasyMock.replay(replicaManager, clientRequestQuotaManager, requestChannel, adminManager, controller) + createKafkaApis(enableForwarding = true).handle(request, RequestLocal.withThreadConfinedCaching) + + val response = capturedResponse.getValue.asInstanceOf[DescribeQuorumResponse] + assertEquals(Errors.UNKNOWN_SERVER_ERROR, Errors.forCode(response.data.errorCode)) + } + + @Test + def 
testDescribeQuorumForwardedForKRaftClusters(): Unit = { + val requestData = DescribeQuorumRequest.singletonRequest(KafkaRaftServer.MetadataPartition) + val requestBuilder = new DescribeQuorumRequest.Builder(requestData) + metadataCache = MetadataCache.kRaftMetadataCache(brokerId) + + testForwardableApi( + createKafkaApis(raftSupport = true), + ApiKeys.DESCRIBE_QUORUM, + requestBuilder + ) + } + + private def testKraftForwarding( + apiKey: ApiKeys, + requestBuilder: AbstractRequest.Builder[_ <: AbstractRequest] + ): Unit = { + metadataCache = MetadataCache.kRaftMetadataCache(brokerId) + testForwardableApi( + createKafkaApis(enableForwarding = true, raftSupport = true), + apiKey, + requestBuilder + ) + } + + private def testForwardableApi(apiKey: ApiKeys, requestBuilder: AbstractRequest.Builder[_ <: AbstractRequest]): Unit = { + testForwardableApi( + createKafkaApis(enableForwarding = true), + apiKey, + requestBuilder + ) + } + + private def testForwardableApi( + kafkaApis: KafkaApis, + apiKey: ApiKeys, + requestBuilder: AbstractRequest.Builder[_ <: AbstractRequest] + ): Unit = { + val topicHeader = new RequestHeader(apiKey, apiKey.latestVersion, + clientId, 0) + + val apiRequest = requestBuilder.build(topicHeader.apiVersion) + val request = buildRequest(apiRequest) + + if (kafkaApis.metadataSupport.isInstanceOf[ZkSupport]) { + // The controller check only makes sense for ZK clusters. For KRaft, + // controller requests are handled on a separate listener, so there + // is no choice but to forward them. + EasyMock.expect(controller.isActive).andReturn(false) + } + + val capturedResponse = expectNoThrottling(request) + val forwardCallback: Capture[Option[AbstractResponse] => Unit] = EasyMock.newCapture() + + EasyMock.expect(forwardingManager.forwardRequest( + EasyMock.eq(request), + EasyMock.capture(forwardCallback) + )).once() + + EasyMock.replay(replicaManager, clientRequestQuotaManager, requestChannel, controller, forwardingManager) + + kafkaApis.handle(request, RequestLocal.withThreadConfinedCaching) + assertNotNull(request.buffer, "The buffer was unexpectedly deallocated after " + + s"`handle` returned (is $apiKey marked as forwardable in `ApiKeys`?)") + + val expectedResponse = apiRequest.getErrorResponse(Errors.NOT_CONTROLLER.exception) + assertTrue(forwardCallback.hasCaptured) + forwardCallback.getValue.apply(Some(expectedResponse)) + + assertTrue(capturedResponse.hasCaptured) + assertEquals(expectedResponse, capturedResponse.getValue) + + EasyMock.verify(controller, requestChannel, forwardingManager) + } + + private def authorizeResource(authorizer: Authorizer, + operation: AclOperation, + resourceType: ResourceType, + resourceName: String, + result: AuthorizationResult, + logIfAllowed: Boolean = true, + logIfDenied: Boolean = true): Unit = { + val expectedAuthorizedAction = if (operation == AclOperation.CLUSTER_ACTION) + new Action(operation, + new ResourcePattern(ResourceType.CLUSTER, Resource.CLUSTER_NAME, PatternType.LITERAL), + 1, logIfAllowed, logIfDenied) + else + new Action(operation, + new ResourcePattern(resourceType, resourceName, PatternType.LITERAL), + 1, logIfAllowed, logIfDenied) + + EasyMock.expect(authorizer.authorize(anyObject[RequestContext], EasyMock.eq(Seq(expectedAuthorizedAction).asJava))) + .andReturn(Seq(result).asJava) + .once() + } + + private def verifyAlterConfigResult(capturedResponse: Capture[AbstractResponse], + expectedResults: Map[String, Errors]): Unit = { + val response = capturedResponse.getValue.asInstanceOf[AlterConfigsResponse] + val responseMap 
= response.data.responses().asScala.map { resourceResponse => + resourceResponse.resourceName() -> Errors.forCode(resourceResponse.errorCode) + }.toMap + + assertEquals(expectedResults, responseMap) + } + + private def createConfigsWithAuthorization(authorizer: Authorizer, + authorizedTopic: String, + unauthorizedTopic: String): (ConfigResource, ConfigResource) = { + val authorizedResource = new ConfigResource(ConfigResource.Type.TOPIC, authorizedTopic) + + val unauthorizedResource = new ConfigResource(ConfigResource.Type.TOPIC, unauthorizedTopic) + + createTopicAuthorization(authorizer, AclOperation.ALTER_CONFIGS, authorizedTopic, unauthorizedTopic) + (authorizedResource, unauthorizedResource) + } + + @Test + def testIncrementalAlterConfigsWithAuthorizer(): Unit = { + val authorizer: Authorizer = EasyMock.niceMock(classOf[Authorizer]) + + val authorizedTopic = "authorized-topic" + val unauthorizedTopic = "unauthorized-topic" + val (authorizedResource, unauthorizedResource) = + createConfigsWithAuthorization(authorizer, authorizedTopic, unauthorizedTopic) + + val requestHeader = new RequestHeader(ApiKeys.INCREMENTAL_ALTER_CONFIGS, ApiKeys.INCREMENTAL_ALTER_CONFIGS.latestVersion, clientId, 0) + + val incrementalAlterConfigsRequest = getIncrementalAlterConfigRequestBuilder(Seq(authorizedResource, unauthorizedResource)) + .build(requestHeader.apiVersion) + val request = buildRequest(incrementalAlterConfigsRequest, + fromPrivilegedListener = true, requestHeader = Option(requestHeader)) + + EasyMock.expect(controller.isActive).andReturn(true) + + val capturedResponse = expectNoThrottling(request) + + EasyMock.expect(adminManager.incrementalAlterConfigs(anyObject(), EasyMock.eq(false))) + .andReturn(Map(authorizedResource -> ApiError.NONE)) + + EasyMock.replay(replicaManager, clientRequestQuotaManager, requestChannel, authorizer, + adminManager, controller) + createKafkaApis(authorizer = Some(authorizer)).handleIncrementalAlterConfigsRequest(request) + + verifyIncrementalAlterConfigResult(capturedResponse, Map( + authorizedTopic -> Errors.NONE, + unauthorizedTopic -> Errors.TOPIC_AUTHORIZATION_FAILED + )) + + verify(authorizer, adminManager) + } + + @Test + def testIncrementalAlterConfigsWithForwarding(): Unit = { + val requestBuilder = new IncrementalAlterConfigsRequest.Builder( + new IncrementalAlterConfigsRequestData()) + testForwardableApi(ApiKeys.INCREMENTAL_ALTER_CONFIGS, requestBuilder) + } + + private def getIncrementalAlterConfigRequestBuilder(configResources: Seq[ConfigResource]): IncrementalAlterConfigsRequest.Builder = { + val resourceMap = configResources.map(configResource => { + configResource -> Set( + new AlterConfigOp(new ConfigEntry("foo", "bar"), + OpType.forId(configResource.`type`.id))).asJavaCollection + }).toMap.asJava + + new IncrementalAlterConfigsRequest.Builder(resourceMap, false) + } + + private def verifyIncrementalAlterConfigResult(capturedResponse: Capture[AbstractResponse], + expectedResults: Map[String, Errors]): Unit = { + val response = capturedResponse.getValue.asInstanceOf[IncrementalAlterConfigsResponse] + val responseMap = response.data.responses().asScala.map { resourceResponse => + resourceResponse.resourceName() -> Errors.forCode(resourceResponse.errorCode) + }.toMap + assertEquals(expectedResults, responseMap) + } + + @Test + def testAlterClientQuotasWithAuthorizer(): Unit = { + val authorizer: Authorizer = EasyMock.niceMock(classOf[Authorizer]) + + authorizeResource(authorizer, AclOperation.ALTER_CONFIGS, ResourceType.CLUSTER, + Resource.CLUSTER_NAME, 
AuthorizationResult.DENIED) + + val quotaEntity = new ClientQuotaEntity(Collections.singletonMap(ClientQuotaEntity.USER, "user")) + val quotas = Seq(new ClientQuotaAlteration(quotaEntity, Seq.empty.asJavaCollection)) + + val requestHeader = new RequestHeader(ApiKeys.ALTER_CLIENT_QUOTAS, ApiKeys.ALTER_CLIENT_QUOTAS.latestVersion, clientId, 0) + + val alterClientQuotasRequest = new AlterClientQuotasRequest.Builder(quotas.asJavaCollection, false) + .build(requestHeader.apiVersion) + val request = buildRequest(alterClientQuotasRequest, + fromPrivilegedListener = true, requestHeader = Option(requestHeader)) + + EasyMock.expect(controller.isActive).andReturn(true) + + val capturedResponse = expectNoThrottling(request) + + EasyMock.replay(replicaManager, clientRequestQuotaManager, requestChannel, authorizer, + adminManager, controller) + createKafkaApis(authorizer = Some(authorizer)).handleAlterClientQuotasRequest(request) + + verifyAlterClientQuotaResult(capturedResponse, Map(quotaEntity -> Errors.CLUSTER_AUTHORIZATION_FAILED)) + + verify(authorizer, adminManager) + } + + @Test + def testAlterClientQuotasWithForwarding(): Unit = { + val requestBuilder = new AlterClientQuotasRequest.Builder(List.empty.asJava, false) + testForwardableApi(ApiKeys.ALTER_CLIENT_QUOTAS, requestBuilder) + } + + private def verifyAlterClientQuotaResult(capturedResponse: Capture[AbstractResponse], + expected: Map[ClientQuotaEntity, Errors]): Unit = { + val response = capturedResponse.getValue.asInstanceOf[AlterClientQuotasResponse] + val futures = expected.keys.map(quotaEntity => quotaEntity -> new KafkaFutureImpl[Void]()).toMap + response.complete(futures.asJava) + futures.foreach { + case (entity, future) => + future.whenComplete((_, thrown) => + assertEquals(thrown, expected(entity).exception()) + ).isDone + } + } + + @Test + def testCreateTopicsWithAuthorizer(): Unit = { + val authorizer: Authorizer = EasyMock.niceMock(classOf[Authorizer]) + + val authorizedTopic = "authorized-topic" + val unauthorizedTopic = "unauthorized-topic" + + authorizeResource(authorizer, AclOperation.CREATE, ResourceType.CLUSTER, + Resource.CLUSTER_NAME, AuthorizationResult.DENIED, logIfDenied = false) + + createCombinedTopicAuthorization(authorizer, AclOperation.CREATE, + authorizedTopic, unauthorizedTopic) + + createCombinedTopicAuthorization(authorizer, AclOperation.DESCRIBE_CONFIGS, + authorizedTopic, unauthorizedTopic, logIfDenied = false) + + val requestHeader = new RequestHeader(ApiKeys.CREATE_TOPICS, ApiKeys.CREATE_TOPICS.latestVersion, clientId, 0) + + EasyMock.expect(controller.isActive).andReturn(true) + + val topics = new CreateTopicsRequestData.CreatableTopicCollection(2) + val topicToCreate = new CreateTopicsRequestData.CreatableTopic() + .setName(authorizedTopic) + topics.add(topicToCreate) + + val topicToFilter = new CreateTopicsRequestData.CreatableTopic() + .setName(unauthorizedTopic) + topics.add(topicToFilter) + + val timeout = 10 + val createTopicsRequest = new CreateTopicsRequest.Builder( + new CreateTopicsRequestData() + .setTimeoutMs(timeout) + .setValidateOnly(false) + .setTopics(topics)) + .build(requestHeader.apiVersion) + val request = buildRequest(createTopicsRequest, + fromPrivilegedListener = true, requestHeader = Option(requestHeader)) + + val capturedResponse = expectNoThrottling(request) + + EasyMock.expect(clientControllerQuotaManager.newQuotaFor( + EasyMock.eq(request), EasyMock.eq(6))).andReturn(UnboundedControllerMutationQuota) + + val capturedCallback = EasyMock.newCapture[Map[String, ApiError] => 
Unit]() + + EasyMock.expect(adminManager.createTopics( + EasyMock.eq(timeout), + EasyMock.eq(false), + EasyMock.eq(Map(authorizedTopic -> topicToCreate)), + anyObject(), + EasyMock.eq(UnboundedControllerMutationQuota), + EasyMock.capture(capturedCallback))) + + EasyMock.replay(replicaManager, clientRequestQuotaManager, clientControllerQuotaManager, + requestChannel, authorizer, adminManager, controller) + + createKafkaApis(authorizer = Some(authorizer)).handleCreateTopicsRequest(request) + + capturedCallback.getValue.apply(Map(authorizedTopic -> ApiError.NONE)) + + verifyCreateTopicsResult(createTopicsRequest, + capturedResponse, Map(authorizedTopic -> Errors.NONE, + unauthorizedTopic -> Errors.TOPIC_AUTHORIZATION_FAILED)) + + verify(authorizer, adminManager, clientControllerQuotaManager) + } + + @Test + def testCreateTopicsWithForwarding(): Unit = { + val requestBuilder = new CreateTopicsRequest.Builder( + new CreateTopicsRequestData().setTopics( + new CreatableTopicCollection(Collections.singleton( + new CreatableTopic().setName("topic").setNumPartitions(1). + setReplicationFactor(1.toShort)).iterator()))) + testForwardableApi(ApiKeys.CREATE_TOPICS, requestBuilder) + } + + private def createTopicAuthorization(authorizer: Authorizer, + operation: AclOperation, + authorizedTopic: String, + unauthorizedTopic: String, + logIfAllowed: Boolean = true, + logIfDenied: Boolean = true): Unit = { + authorizeResource(authorizer, operation, ResourceType.TOPIC, + authorizedTopic, AuthorizationResult.ALLOWED, logIfAllowed, logIfDenied) + authorizeResource(authorizer, operation, ResourceType.TOPIC, + unauthorizedTopic, AuthorizationResult.DENIED, logIfAllowed, logIfDenied) + } + + private def createCombinedTopicAuthorization(authorizer: Authorizer, + operation: AclOperation, + authorizedTopic: String, + unauthorizedTopic: String, + logIfAllowed: Boolean = true, + logIfDenied: Boolean = true): Unit = { + val expectedAuthorizedActions = Seq( + new Action(operation, + new ResourcePattern(ResourceType.TOPIC, authorizedTopic, PatternType.LITERAL), + 1, logIfAllowed, logIfDenied), + new Action(operation, + new ResourcePattern(ResourceType.TOPIC, unauthorizedTopic, PatternType.LITERAL), + 1, logIfAllowed, logIfDenied)) + + EasyMock.expect(authorizer.authorize( + anyObject[RequestContext], AuthHelperTest.matchSameElements(expectedAuthorizedActions.asJava) + )).andAnswer { () => + val actions = EasyMock.getCurrentArguments.apply(1).asInstanceOf[util.List[Action]].asScala + actions.map { action => + if (action.resourcePattern().name().equals(authorizedTopic)) + AuthorizationResult.ALLOWED + else + AuthorizationResult.DENIED + }.asJava + }.once() + } + + private def verifyCreateTopicsResult(createTopicsRequest: CreateTopicsRequest, + capturedResponse: Capture[AbstractResponse], + expectedResults: Map[String, Errors]): Unit = { + val response = capturedResponse.getValue.asInstanceOf[CreateTopicsResponse] + val responseMap = response.data.topics().asScala.map { topicResponse => + topicResponse.name() -> Errors.forCode(topicResponse.errorCode) + }.toMap + + assertEquals(expectedResults, responseMap) + } + + @Test + def testCreateAclWithForwarding(): Unit = { + val requestBuilder = new CreateAclsRequest.Builder(new CreateAclsRequestData()) + testForwardableApi(ApiKeys.CREATE_ACLS, requestBuilder) + } + + @Test + def testDeleteAclWithForwarding(): Unit = { + val requestBuilder = new DeleteAclsRequest.Builder(new DeleteAclsRequestData()) + testForwardableApi(ApiKeys.DELETE_ACLS, requestBuilder) + } + + @Test + def 
testCreateDelegationTokenWithForwarding(): Unit = { + val requestBuilder = new CreateDelegationTokenRequest.Builder(new CreateDelegationTokenRequestData()) + testForwardableApi(ApiKeys.CREATE_DELEGATION_TOKEN, requestBuilder) + } + + @Test + def testRenewDelegationTokenWithForwarding(): Unit = { + val requestBuilder = new RenewDelegationTokenRequest.Builder(new RenewDelegationTokenRequestData()) + testForwardableApi(ApiKeys.RENEW_DELEGATION_TOKEN, requestBuilder) + } + + @Test + def testExpireDelegationTokenWithForwarding(): Unit = { + val requestBuilder = new ExpireDelegationTokenRequest.Builder(new ExpireDelegationTokenRequestData()) + testForwardableApi(ApiKeys.EXPIRE_DELEGATION_TOKEN, requestBuilder) + } + + @Test + def testAlterPartitionReassignmentsWithForwarding(): Unit = { + val requestBuilder = new AlterPartitionReassignmentsRequest.Builder(new AlterPartitionReassignmentsRequestData()) + testForwardableApi(ApiKeys.ALTER_PARTITION_REASSIGNMENTS, requestBuilder) + } + + @Test + def testCreatePartitionsWithForwarding(): Unit = { + val requestBuilder = new CreatePartitionsRequest.Builder(new CreatePartitionsRequestData()) + testForwardableApi(ApiKeys.CREATE_PARTITIONS, requestBuilder) + } + + @Test + def testDeleteTopicsWithForwarding(): Unit = { + val requestBuilder = new DeleteTopicsRequest.Builder(new DeleteTopicsRequestData()) + testForwardableApi(ApiKeys.DELETE_TOPICS, requestBuilder) + } + + @Test + def testUpdateFeaturesWithForwarding(): Unit = { + val requestBuilder = new UpdateFeaturesRequest.Builder(new UpdateFeaturesRequestData()) + testForwardableApi(ApiKeys.UPDATE_FEATURES, requestBuilder) + } + + @Test + def testAlterScramWithForwarding(): Unit = { + val requestBuilder = new AlterUserScramCredentialsRequest.Builder(new AlterUserScramCredentialsRequestData()) + testForwardableApi(ApiKeys.ALTER_USER_SCRAM_CREDENTIALS, requestBuilder) + } + + @Test + def testFindCoordinatorAutoTopicCreationForOffsetTopic(): Unit = { + testFindCoordinatorWithTopicCreation(CoordinatorType.GROUP) + } + + @Test + def testFindCoordinatorAutoTopicCreationForTxnTopic(): Unit = { + testFindCoordinatorWithTopicCreation(CoordinatorType.TRANSACTION) + } + + @Test + def testFindCoordinatorNotEnoughBrokersForOffsetTopic(): Unit = { + testFindCoordinatorWithTopicCreation(CoordinatorType.GROUP, hasEnoughLiveBrokers = false) + } + + @Test + def testFindCoordinatorNotEnoughBrokersForTxnTopic(): Unit = { + testFindCoordinatorWithTopicCreation(CoordinatorType.TRANSACTION, hasEnoughLiveBrokers = false) + } + + @Test + def testOldFindCoordinatorAutoTopicCreationForOffsetTopic(): Unit = { + testFindCoordinatorWithTopicCreation(CoordinatorType.GROUP, version = 3) + } + + @Test + def testOldFindCoordinatorAutoTopicCreationForTxnTopic(): Unit = { + testFindCoordinatorWithTopicCreation(CoordinatorType.TRANSACTION, version = 3) + } + + @Test + def testOldFindCoordinatorNotEnoughBrokersForOffsetTopic(): Unit = { + testFindCoordinatorWithTopicCreation(CoordinatorType.GROUP, hasEnoughLiveBrokers = false, version = 3) + } + + @Test + def testOldFindCoordinatorNotEnoughBrokersForTxnTopic(): Unit = { + testFindCoordinatorWithTopicCreation(CoordinatorType.TRANSACTION, hasEnoughLiveBrokers = false, version = 3) + } + + private def testFindCoordinatorWithTopicCreation(coordinatorType: CoordinatorType, + hasEnoughLiveBrokers: Boolean = true, + version: Short = ApiKeys.FIND_COORDINATOR.latestVersion): Unit = { + val authorizer: Authorizer = EasyMock.niceMock(classOf[Authorizer]) + + val requestHeader = new 
RequestHeader(ApiKeys.FIND_COORDINATOR, version, clientId, 0) + + val numBrokersNeeded = 3 + + setupBrokerMetadata(hasEnoughLiveBrokers, numBrokersNeeded) + + val requestTimeout = 10 + val topicConfigOverride = mutable.Map.empty[String, String] + topicConfigOverride.put(KafkaConfig.RequestTimeoutMsProp, requestTimeout.toString) + + val groupId = "group" + val topicName = + coordinatorType match { + case CoordinatorType.GROUP => + topicConfigOverride.put(KafkaConfig.OffsetsTopicPartitionsProp, numBrokersNeeded.toString) + topicConfigOverride.put(KafkaConfig.OffsetsTopicReplicationFactorProp, numBrokersNeeded.toString) + EasyMock.expect(groupCoordinator.offsetsTopicConfigs).andReturn(new Properties) + authorizeResource(authorizer, AclOperation.DESCRIBE, ResourceType.GROUP, + groupId, AuthorizationResult.ALLOWED) + Topic.GROUP_METADATA_TOPIC_NAME + case CoordinatorType.TRANSACTION => + topicConfigOverride.put(KafkaConfig.TransactionsTopicPartitionsProp, numBrokersNeeded.toString) + topicConfigOverride.put(KafkaConfig.TransactionsTopicReplicationFactorProp, numBrokersNeeded.toString) + EasyMock.expect(txnCoordinator.transactionTopicConfigs).andReturn(new Properties) + authorizeResource(authorizer, AclOperation.DESCRIBE, ResourceType.TRANSACTIONAL_ID, + groupId, AuthorizationResult.ALLOWED) + Topic.TRANSACTION_STATE_TOPIC_NAME + case _ => + throw new IllegalStateException(s"Unknown coordinator type $coordinatorType") + } + + val findCoordinatorRequestBuilder = if (version >= 4) { + new FindCoordinatorRequest.Builder( + new FindCoordinatorRequestData() + .setKeyType(coordinatorType.id()) + .setCoordinatorKeys(Arrays.asList(groupId))) + } else { + new FindCoordinatorRequest.Builder( + new FindCoordinatorRequestData() + .setKeyType(coordinatorType.id()) + .setKey(groupId)) + } + val request = buildRequest(findCoordinatorRequestBuilder.build(requestHeader.apiVersion)) + + val capturedResponse = expectNoThrottling(request) + + val capturedRequest = verifyTopicCreation(topicName, true, true, request) + + EasyMock.replay(replicaManager, clientRequestQuotaManager, requestChannel, authorizer, + autoTopicCreationManager, forwardingManager, controller, clientControllerQuotaManager, groupCoordinator, txnCoordinator) + + createKafkaApis(authorizer = Some(authorizer), + overrideProperties = topicConfigOverride).handleFindCoordinatorRequest(request) + + val response = capturedResponse.getValue.asInstanceOf[FindCoordinatorResponse] + if (version >= 4) { + assertEquals(Errors.COORDINATOR_NOT_AVAILABLE.code, response.data.coordinators.get(0).errorCode) + assertEquals(groupId, response.data.coordinators.get(0).key) + } else { + assertEquals(Errors.COORDINATOR_NOT_AVAILABLE.code, response.data.errorCode) + } + + assertTrue(capturedRequest.getValue.isEmpty) + + verify(authorizer, autoTopicCreationManager) + } + + @Test + def testMetadataAutoTopicCreationForOffsetTopic(): Unit = { + testMetadataAutoTopicCreation(Topic.GROUP_METADATA_TOPIC_NAME, enableAutoTopicCreation = true, + expectedError = Errors.UNKNOWN_TOPIC_OR_PARTITION) + } + + @Test + def testMetadataAutoTopicCreationForTxnTopic(): Unit = { + testMetadataAutoTopicCreation(Topic.TRANSACTION_STATE_TOPIC_NAME, enableAutoTopicCreation = true, + expectedError = Errors.UNKNOWN_TOPIC_OR_PARTITION) + } + + @Test + def testMetadataAutoTopicCreationForNonInternalTopic(): Unit = { + testMetadataAutoTopicCreation("topic", enableAutoTopicCreation = true, + expectedError = Errors.UNKNOWN_TOPIC_OR_PARTITION) + } + + @Test + def 
testMetadataAutoTopicCreationDisabledForOffsetTopic(): Unit = { + testMetadataAutoTopicCreation(Topic.GROUP_METADATA_TOPIC_NAME, enableAutoTopicCreation = false, + expectedError = Errors.UNKNOWN_TOPIC_OR_PARTITION) + } + + @Test + def testMetadataAutoTopicCreationDisabledForTxnTopic(): Unit = { + testMetadataAutoTopicCreation(Topic.TRANSACTION_STATE_TOPIC_NAME, enableAutoTopicCreation = false, + expectedError = Errors.UNKNOWN_TOPIC_OR_PARTITION) + } + + @Test + def testMetadataAutoTopicCreationDisabledForNonInternalTopic(): Unit = { + testMetadataAutoTopicCreation("topic", enableAutoTopicCreation = false, + expectedError = Errors.UNKNOWN_TOPIC_OR_PARTITION) + } + + @Test + def testMetadataAutoCreationDisabledForNonInternal(): Unit = { + testMetadataAutoTopicCreation("topic", enableAutoTopicCreation = true, + expectedError = Errors.UNKNOWN_TOPIC_OR_PARTITION) + } + + private def testMetadataAutoTopicCreation(topicName: String, + enableAutoTopicCreation: Boolean, + expectedError: Errors): Unit = { + val authorizer: Authorizer = EasyMock.niceMock(classOf[Authorizer]) + + val requestHeader = new RequestHeader(ApiKeys.METADATA, ApiKeys.METADATA.latestVersion, + clientId, 0) + + val numBrokersNeeded = 3 + addTopicToMetadataCache("some-topic", 1, 3) + + authorizeResource(authorizer, AclOperation.DESCRIBE, ResourceType.TOPIC, + topicName, AuthorizationResult.ALLOWED) + + if (enableAutoTopicCreation) + authorizeResource(authorizer, AclOperation.CREATE, ResourceType.CLUSTER, + Resource.CLUSTER_NAME, AuthorizationResult.ALLOWED, logIfDenied = false) + + val topicConfigOverride = mutable.Map.empty[String, String] + val isInternal = + topicName match { + case Topic.GROUP_METADATA_TOPIC_NAME => + topicConfigOverride.put(KafkaConfig.OffsetsTopicPartitionsProp, numBrokersNeeded.toString) + topicConfigOverride.put(KafkaConfig.OffsetsTopicReplicationFactorProp, numBrokersNeeded.toString) + EasyMock.expect(groupCoordinator.offsetsTopicConfigs).andReturn(new Properties) + true + + case Topic.TRANSACTION_STATE_TOPIC_NAME => + topicConfigOverride.put(KafkaConfig.TransactionsTopicPartitionsProp, numBrokersNeeded.toString) + topicConfigOverride.put(KafkaConfig.TransactionsTopicReplicationFactorProp, numBrokersNeeded.toString) + EasyMock.expect(txnCoordinator.transactionTopicConfigs).andReturn(new Properties) + true + case _ => + topicConfigOverride.put(KafkaConfig.NumPartitionsProp, numBrokersNeeded.toString) + topicConfigOverride.put(KafkaConfig.DefaultReplicationFactorProp, numBrokersNeeded.toString) + false + } + + val metadataRequest = new MetadataRequest.Builder( + List(topicName).asJava, enableAutoTopicCreation + ).build(requestHeader.apiVersion) + val request = buildRequest(metadataRequest) + + val capturedResponse = expectNoThrottling(request) + + val capturedRequest = verifyTopicCreation(topicName, enableAutoTopicCreation, isInternal, request) + + EasyMock.replay(replicaManager, clientRequestQuotaManager, requestChannel, authorizer, + autoTopicCreationManager, forwardingManager, clientControllerQuotaManager, groupCoordinator, txnCoordinator) + + createKafkaApis(authorizer = Some(authorizer), enableForwarding = enableAutoTopicCreation, + overrideProperties = topicConfigOverride).handleTopicMetadataRequest(request) + + val response = capturedResponse.getValue.asInstanceOf[MetadataResponse] + val expectedMetadataResponse = util.Collections.singletonList(new TopicMetadata( + expectedError, + topicName, + isInternal, + util.Collections.emptyList() + )) + + assertEquals(expectedMetadataResponse, 
response.topicMetadata()) + + if (enableAutoTopicCreation) { + assertTrue(capturedRequest.getValue.isDefined) + assertEquals(request.context, capturedRequest.getValue.get) + } + + verify(authorizer, autoTopicCreationManager) + } + + private def verifyTopicCreation(topicName: String, + enableAutoTopicCreation: Boolean, + isInternal: Boolean, + request: RequestChannel.Request): Capture[Option[RequestContext]] = { + val capturedRequest = EasyMock.newCapture[Option[RequestContext]]() + if (enableAutoTopicCreation) { + EasyMock.expect(clientControllerQuotaManager.newPermissiveQuotaFor(EasyMock.eq(request))) + .andReturn(UnboundedControllerMutationQuota) + + EasyMock.expect(autoTopicCreationManager.createTopics( + EasyMock.eq(Set(topicName)), + EasyMock.eq(UnboundedControllerMutationQuota), + EasyMock.capture(capturedRequest))).andReturn( + Seq(new MetadataResponseTopic() + .setErrorCode(Errors.UNKNOWN_TOPIC_OR_PARTITION.code()) + .setIsInternal(isInternal) + .setName(topicName)) + ).once() + } + capturedRequest + } + + private def setupBrokerMetadata(hasEnoughLiveBrokers: Boolean, numBrokersNeeded: Int): Unit = { + addTopicToMetadataCache("some-topic", 1, + if (hasEnoughLiveBrokers) + numBrokersNeeded + else + numBrokersNeeded - 1) + } + + @Test + def testInvalidMetadataRequestReturnsError(): Unit = { + // Construct invalid MetadataRequestTopics. We will try each one separately and ensure the error is thrown. + val topics = List(new MetadataRequestData.MetadataRequestTopic().setName(null).setTopicId(Uuid.randomUuid()), + new MetadataRequestData.MetadataRequestTopic().setName(null), + new MetadataRequestData.MetadataRequestTopic().setTopicId(Uuid.randomUuid()), + new MetadataRequestData.MetadataRequestTopic().setName("topic1").setTopicId(Uuid.randomUuid())) + + EasyMock.replay(replicaManager, clientRequestQuotaManager, + autoTopicCreationManager, forwardingManager, clientControllerQuotaManager, groupCoordinator, txnCoordinator) + + // if version is 10 or 11, the invalid topic metadata should return an error + val invalidVersions = Set(10, 11) + invalidVersions.foreach( version => + topics.foreach(topic => { + val metadataRequestData = new MetadataRequestData().setTopics(Collections.singletonList(topic)) + val request = buildRequest(new MetadataRequest(metadataRequestData, version.toShort)) + val kafkaApis = createKafkaApis() + + val capturedResponse = EasyMock.newCapture[AbstractResponse]() + EasyMock.expect(requestChannel.sendResponse( + EasyMock.eq(request), + EasyMock.capture(capturedResponse), + EasyMock.anyObject() + )) + + EasyMock.replay(requestChannel) + kafkaApis.handle(request, RequestLocal.withThreadConfinedCaching) + + val response = capturedResponse.getValue.asInstanceOf[MetadataResponse] + assertEquals(1, response.topicMetadata.size) + assertEquals(1, response.errorCounts.get(Errors.INVALID_REQUEST)) + response.data.topics.forEach(topic => assertNotEquals(null, topic.name)) + reset(requestChannel) + }) + ) + } + + @Test + def testOffsetCommitWithInvalidPartition(): Unit = { + val topic = "topic" + addTopicToMetadataCache(topic, numPartitions = 1) + + def checkInvalidPartition(invalidPartitionId: Int): Unit = { + EasyMock.reset(replicaManager, clientRequestQuotaManager, requestChannel) + + val offsetCommitRequest = new OffsetCommitRequest.Builder( + new OffsetCommitRequestData() + .setGroupId("groupId") + .setTopics(Collections.singletonList( + new OffsetCommitRequestData.OffsetCommitRequestTopic() + .setName(topic) + .setPartitions(Collections.singletonList( + new 
OffsetCommitRequestData.OffsetCommitRequestPartition() + .setPartitionIndex(invalidPartitionId) + .setCommittedOffset(15) + .setCommittedLeaderEpoch(RecordBatch.NO_PARTITION_LEADER_EPOCH) + .setCommittedMetadata("")) + ) + ))).build() + + val request = buildRequest(offsetCommitRequest) + val capturedResponse = expectNoThrottling(request) + EasyMock.replay(replicaManager, clientRequestQuotaManager, requestChannel) + createKafkaApis().handleOffsetCommitRequest(request, RequestLocal.withThreadConfinedCaching) + + val response = capturedResponse.getValue.asInstanceOf[OffsetCommitResponse] + assertEquals(Errors.UNKNOWN_TOPIC_OR_PARTITION, + Errors.forCode(response.data.topics().get(0).partitions().get(0).errorCode)) + } + + checkInvalidPartition(-1) + checkInvalidPartition(1) // topic has only one partition + } + + @Test + def testTxnOffsetCommitWithInvalidPartition(): Unit = { + val topic = "topic" + addTopicToMetadataCache(topic, numPartitions = 1) + + def checkInvalidPartition(invalidPartitionId: Int): Unit = { + EasyMock.reset(replicaManager, clientRequestQuotaManager, requestChannel) + + val invalidTopicPartition = new TopicPartition(topic, invalidPartitionId) + val partitionOffsetCommitData = new TxnOffsetCommitRequest.CommittedOffset(15L, "", Optional.empty()) + val offsetCommitRequest = new TxnOffsetCommitRequest.Builder( + "txnId", + "groupId", + 15L, + 0.toShort, + Map(invalidTopicPartition -> partitionOffsetCommitData).asJava, + ).build() + val request = buildRequest(offsetCommitRequest) + + val capturedResponse = expectNoThrottling(request) + EasyMock.replay(replicaManager, clientRequestQuotaManager, requestChannel) + createKafkaApis().handleTxnOffsetCommitRequest(request, RequestLocal.withThreadConfinedCaching) + + val response = capturedResponse.getValue.asInstanceOf[TxnOffsetCommitResponse] + assertEquals(Errors.UNKNOWN_TOPIC_OR_PARTITION, response.errors().get(invalidTopicPartition)) + } + + checkInvalidPartition(-1) + checkInvalidPartition(1) // topic has only one partition + } + + @Test + def shouldReplaceCoordinatorNotAvailableWithLoadInProcessInTxnOffsetCommitWithOlderClient(): Unit = { + val topic = "topic" + addTopicToMetadataCache(topic, numPartitions = 2) + + for (version <- ApiKeys.TXN_OFFSET_COMMIT.oldestVersion to ApiKeys.TXN_OFFSET_COMMIT.latestVersion) { + EasyMock.reset(replicaManager, clientRequestQuotaManager, requestChannel, groupCoordinator) + + val topicPartition = new TopicPartition(topic, 1) + val capturedResponse: Capture[AbstractResponse] = EasyMock.newCapture() + val responseCallback: Capture[Map[TopicPartition, Errors] => Unit] = EasyMock.newCapture() + + val partitionOffsetCommitData = new TxnOffsetCommitRequest.CommittedOffset(15L, "", Optional.empty()) + val groupId = "groupId" + + val producerId = 15L + val epoch = 0.toShort + + val offsetCommitRequest = new TxnOffsetCommitRequest.Builder( + "txnId", + groupId, + producerId, + epoch, + Map(topicPartition -> partitionOffsetCommitData).asJava, + ).build(version.toShort) + val request = buildRequest(offsetCommitRequest) + + val requestLocal = RequestLocal.withThreadConfinedCaching + EasyMock.expect(groupCoordinator.handleTxnCommitOffsets( + EasyMock.eq(groupId), + EasyMock.eq(producerId), + EasyMock.eq(epoch), + EasyMock.anyString(), + EasyMock.eq(Option.empty), + EasyMock.anyInt(), + EasyMock.anyObject(), + EasyMock.capture(responseCallback), + EasyMock.eq(requestLocal) + )).andAnswer( + () => responseCallback.getValue.apply(Map(topicPartition -> Errors.COORDINATOR_LOAD_IN_PROGRESS))) + + 
EasyMock.expect(requestChannel.sendResponse( + EasyMock.eq(request), + EasyMock.capture(capturedResponse), + EasyMock.eq(None) + )) + + EasyMock.replay(replicaManager, clientRequestQuotaManager, requestChannel, groupCoordinator) + + createKafkaApis().handleTxnOffsetCommitRequest(request, requestLocal) + + val response = capturedResponse.getValue.asInstanceOf[TxnOffsetCommitResponse] + + if (version < 2) { + assertEquals(Errors.COORDINATOR_NOT_AVAILABLE, response.errors().get(topicPartition)) + } else { + assertEquals(Errors.COORDINATOR_LOAD_IN_PROGRESS, response.errors().get(topicPartition)) + } + } + } + + @Test + def shouldReplaceProducerFencedWithInvalidProducerEpochInInitProducerIdWithOlderClient(): Unit = { + val topic = "topic" + addTopicToMetadataCache(topic, numPartitions = 2) + + for (version <- ApiKeys.INIT_PRODUCER_ID.oldestVersion to ApiKeys.INIT_PRODUCER_ID.latestVersion) { + + EasyMock.reset(replicaManager, clientRequestQuotaManager, requestChannel, txnCoordinator) + + val capturedResponse: Capture[AbstractResponse] = EasyMock.newCapture() + val responseCallback: Capture[InitProducerIdResult => Unit] = EasyMock.newCapture() + + val transactionalId = "txnId" + val producerId = if (version < 3) + RecordBatch.NO_PRODUCER_ID + else + 15 + + val epoch = if (version < 3) + RecordBatch.NO_PRODUCER_EPOCH + else + 0.toShort + + val txnTimeoutMs = TimeUnit.MINUTES.toMillis(15).toInt + + val initProducerIdRequest = new InitProducerIdRequest.Builder( + new InitProducerIdRequestData() + .setTransactionalId(transactionalId) + .setTransactionTimeoutMs(txnTimeoutMs) + .setProducerId(producerId) + .setProducerEpoch(epoch) + ).build(version.toShort) + + val request = buildRequest(initProducerIdRequest) + + val expectedProducerIdAndEpoch = if (version < 3) + Option.empty + else + Option(new ProducerIdAndEpoch(producerId, epoch)) + + val requestLocal = RequestLocal.withThreadConfinedCaching + EasyMock.expect(txnCoordinator.handleInitProducerId( + EasyMock.eq(transactionalId), + EasyMock.eq(txnTimeoutMs), + EasyMock.eq(expectedProducerIdAndEpoch), + EasyMock.capture(responseCallback), + EasyMock.eq(requestLocal) + )).andAnswer( + () => responseCallback.getValue.apply(InitProducerIdResult(producerId, epoch, Errors.PRODUCER_FENCED))) + + EasyMock.expect(requestChannel.sendResponse( + EasyMock.eq(request), + EasyMock.capture(capturedResponse), + EasyMock.eq(None) + )) + + EasyMock.replay(replicaManager, clientRequestQuotaManager, requestChannel, txnCoordinator) + + createKafkaApis().handleInitProducerIdRequest(request, requestLocal) + + val response = capturedResponse.getValue.asInstanceOf[InitProducerIdResponse] + + if (version < 4) { + assertEquals(Errors.INVALID_PRODUCER_EPOCH.code, response.data.errorCode) + } else { + assertEquals(Errors.PRODUCER_FENCED.code, response.data.errorCode) + } + } + } + + @Test + def shouldReplaceProducerFencedWithInvalidProducerEpochInAddOffsetToTxnWithOlderClient(): Unit = { + val topic = "topic" + addTopicToMetadataCache(topic, numPartitions = 2) + + for (version <- ApiKeys.ADD_OFFSETS_TO_TXN.oldestVersion to ApiKeys.ADD_OFFSETS_TO_TXN.latestVersion) { + + EasyMock.reset(replicaManager, clientRequestQuotaManager, requestChannel, groupCoordinator, txnCoordinator) + + val capturedResponse: Capture[AbstractResponse] = EasyMock.newCapture() + val responseCallback: Capture[Errors => Unit] = EasyMock.newCapture() + + val groupId = "groupId" + val transactionalId = "txnId" + val producerId = 15L + val epoch = 0.toShort + + val addOffsetsToTxnRequest = new 
AddOffsetsToTxnRequest.Builder( + new AddOffsetsToTxnRequestData() + .setGroupId(groupId) + .setTransactionalId(transactionalId) + .setProducerId(producerId) + .setProducerEpoch(epoch) + ).build(version.toShort) + val request = buildRequest(addOffsetsToTxnRequest) + + val partition = 1 + EasyMock.expect(groupCoordinator.partitionFor( + EasyMock.eq(groupId) + )).andReturn(partition) + + val requestLocal = RequestLocal.withThreadConfinedCaching + EasyMock.expect(txnCoordinator.handleAddPartitionsToTransaction( + EasyMock.eq(transactionalId), + EasyMock.eq(producerId), + EasyMock.eq(epoch), + EasyMock.eq(Set(new TopicPartition(Topic.GROUP_METADATA_TOPIC_NAME, partition))), + EasyMock.capture(responseCallback), + EasyMock.eq(requestLocal) + )).andAnswer( + () => responseCallback.getValue.apply(Errors.PRODUCER_FENCED)) + + EasyMock.expect(requestChannel.sendResponse( + EasyMock.eq(request), + EasyMock.capture(capturedResponse), + EasyMock.eq(None) + )) + + EasyMock.replay(replicaManager, clientRequestQuotaManager, requestChannel, txnCoordinator, groupCoordinator) + + createKafkaApis().handleAddOffsetsToTxnRequest(request, requestLocal) + + val response = capturedResponse.getValue.asInstanceOf[AddOffsetsToTxnResponse] + + if (version < 2) { + assertEquals(Errors.INVALID_PRODUCER_EPOCH.code, response.data.errorCode) + } else { + assertEquals(Errors.PRODUCER_FENCED.code, response.data.errorCode) + } + } + } + + @Test + def shouldReplaceProducerFencedWithInvalidProducerEpochInAddPartitionToTxnWithOlderClient(): Unit = { + val topic = "topic" + addTopicToMetadataCache(topic, numPartitions = 2) + + for (version <- ApiKeys.ADD_PARTITIONS_TO_TXN.oldestVersion to ApiKeys.ADD_PARTITIONS_TO_TXN.latestVersion) { + + EasyMock.reset(replicaManager, clientRequestQuotaManager, requestChannel, txnCoordinator) + + val capturedResponse: Capture[AbstractResponse] = EasyMock.newCapture() + val responseCallback: Capture[Errors => Unit] = EasyMock.newCapture() + + val transactionalId = "txnId" + val producerId = 15L + val epoch = 0.toShort + + val partition = 1 + val topicPartition = new TopicPartition(topic, partition) + + val addPartitionsToTxnRequest = new AddPartitionsToTxnRequest.Builder( + transactionalId, + producerId, + epoch, + Collections.singletonList(topicPartition) + ).build(version.toShort) + val request = buildRequest(addPartitionsToTxnRequest) + + val requestLocal = RequestLocal.withThreadConfinedCaching + EasyMock.expect(txnCoordinator.handleAddPartitionsToTransaction( + EasyMock.eq(transactionalId), + EasyMock.eq(producerId), + EasyMock.eq(epoch), + EasyMock.eq(Set(topicPartition)), + EasyMock.capture(responseCallback), + EasyMock.eq(requestLocal) + )).andAnswer( + () => responseCallback.getValue.apply(Errors.PRODUCER_FENCED)) + + EasyMock.expect(requestChannel.sendResponse( + EasyMock.eq(request), + EasyMock.capture(capturedResponse), + EasyMock.eq(None) + )) + + EasyMock.replay(replicaManager, clientRequestQuotaManager, requestChannel, txnCoordinator) + + createKafkaApis().handleAddPartitionToTxnRequest(request, requestLocal) + + val response = capturedResponse.getValue.asInstanceOf[AddPartitionsToTxnResponse] + + if (version < 2) { + assertEquals(Collections.singletonMap(topicPartition, Errors.INVALID_PRODUCER_EPOCH), response.errors()) + } else { + assertEquals(Collections.singletonMap(topicPartition, Errors.PRODUCER_FENCED), response.errors()) + } + } + } + + @Test + def shouldReplaceProducerFencedWithInvalidProducerEpochInEndTxnWithOlderClient(): Unit = { + val topic = "topic" + 
addTopicToMetadataCache(topic, numPartitions = 2) + + for (version <- ApiKeys.END_TXN.oldestVersion to ApiKeys.END_TXN.latestVersion) { + EasyMock.reset(replicaManager, clientRequestQuotaManager, requestChannel, txnCoordinator) + + val capturedResponse: Capture[AbstractResponse] = EasyMock.newCapture() + val responseCallback: Capture[Errors => Unit] = EasyMock.newCapture() + + val transactionalId = "txnId" + val producerId = 15L + val epoch = 0.toShort + + val endTxnRequest = new EndTxnRequest.Builder( + new EndTxnRequestData() + .setTransactionalId(transactionalId) + .setProducerId(producerId) + .setProducerEpoch(epoch) + .setCommitted(true) + ).build(version.toShort) + val request = buildRequest(endTxnRequest) + + val requestLocal = RequestLocal.withThreadConfinedCaching + EasyMock.expect(txnCoordinator.handleEndTransaction( + EasyMock.eq(transactionalId), + EasyMock.eq(producerId), + EasyMock.eq(epoch), + EasyMock.eq(TransactionResult.COMMIT), + EasyMock.capture(responseCallback), + EasyMock.eq(requestLocal) + )).andAnswer( + () => responseCallback.getValue.apply(Errors.PRODUCER_FENCED)) + + EasyMock.expect(requestChannel.sendResponse( + EasyMock.eq(request), + EasyMock.capture(capturedResponse), + EasyMock.eq(None) + )) + + EasyMock.replay(replicaManager, clientRequestQuotaManager, requestChannel, txnCoordinator) + createKafkaApis().handleEndTxnRequest(request, requestLocal) + + val response = capturedResponse.getValue.asInstanceOf[EndTxnResponse] + + if (version < 2) { + assertEquals(Errors.INVALID_PRODUCER_EPOCH.code, response.data.errorCode) + } else { + assertEquals(Errors.PRODUCER_FENCED.code, response.data.errorCode) + } + } + } + + @Test + def shouldReplaceProducerFencedWithInvalidProducerEpochInProduceResponse(): Unit = { + val topic = "topic" + addTopicToMetadataCache(topic, numPartitions = 2) + + for (version <- ApiKeys.PRODUCE.oldestVersion to ApiKeys.PRODUCE.latestVersion) { + + EasyMock.reset(replicaManager, clientQuotaManager, clientRequestQuotaManager, requestChannel, txnCoordinator) + + val responseCallback: Capture[Map[TopicPartition, PartitionResponse] => Unit] = EasyMock.newCapture() + + val tp = new TopicPartition("topic", 0) + + val produceRequest = ProduceRequest.forCurrentMagic(new ProduceRequestData() + .setTopicData(new ProduceRequestData.TopicProduceDataCollection( + Collections.singletonList(new ProduceRequestData.TopicProduceData() + .setName(tp.topic).setPartitionData(Collections.singletonList( + new ProduceRequestData.PartitionProduceData() + .setIndex(tp.partition) + .setRecords(MemoryRecords.withRecords(CompressionType.NONE, new SimpleRecord("test".getBytes)))))) + .iterator)) + .setAcks(1.toShort) + .setTimeoutMs(5000)) + .build(version.toShort) + val request = buildRequest(produceRequest) + + EasyMock.expect(replicaManager.appendRecords(EasyMock.anyLong(), + EasyMock.anyShort(), + EasyMock.eq(false), + EasyMock.eq(AppendOrigin.Client), + EasyMock.anyObject(), + EasyMock.capture(responseCallback), + EasyMock.anyObject(), + EasyMock.anyObject(), + EasyMock.anyObject()) + ).andAnswer(() => responseCallback.getValue.apply(Map(tp -> new PartitionResponse(Errors.INVALID_PRODUCER_EPOCH)))) + + val capturedResponse = expectNoThrottling(request) + EasyMock.expect(clientQuotaManager.maybeRecordAndGetThrottleTimeMs( + anyObject[RequestChannel.Request](), anyDouble, anyLong)).andReturn(0) + + EasyMock.replay(replicaManager, clientQuotaManager, clientRequestQuotaManager, requestChannel, txnCoordinator) + + createKafkaApis().handleProduceRequest(request, 
RequestLocal.withThreadConfinedCaching)
+
+      val response = capturedResponse.getValue.asInstanceOf[ProduceResponse]
+
+      assertEquals(1, response.data.responses.size)
+      val topicProduceResponse = response.data.responses.asScala.head
+      assertEquals(1, topicProduceResponse.partitionResponses.size)
+      val partitionProduceResponse = topicProduceResponse.partitionResponses.asScala.head
+      assertEquals(Errors.INVALID_PRODUCER_EPOCH, Errors.forCode(partitionProduceResponse.errorCode))
+    }
+  }
+
+  @Test
+  def testAddPartitionsToTxnWithInvalidPartition(): Unit = {
+    val topic = "topic"
+    addTopicToMetadataCache(topic, numPartitions = 1)
+
+    def checkInvalidPartition(invalidPartitionId: Int): Unit = {
+      EasyMock.reset(replicaManager, clientRequestQuotaManager, requestChannel)
+
+      val invalidTopicPartition = new TopicPartition(topic, invalidPartitionId)
+      val addPartitionsToTxnRequest = new AddPartitionsToTxnRequest.Builder(
+        "txnlId", 15L, 0.toShort, List(invalidTopicPartition).asJava
+      ).build()
+      val request = buildRequest(addPartitionsToTxnRequest)
+
+      val capturedResponse = expectNoThrottling(request)
+      EasyMock.replay(replicaManager, clientRequestQuotaManager, requestChannel)
+      createKafkaApis().handleAddPartitionToTxnRequest(request, RequestLocal.withThreadConfinedCaching)
+
+      val response = capturedResponse.getValue.asInstanceOf[AddPartitionsToTxnResponse]
+      assertEquals(Errors.UNKNOWN_TOPIC_OR_PARTITION, response.errors().get(invalidTopicPartition))
+    }
+
+    checkInvalidPartition(-1)
+    checkInvalidPartition(1) // topic has only one partition
+  }
+
+  @Test
+  def shouldThrowUnsupportedVersionExceptionOnHandleAddOffsetToTxnRequestWhenInterBrokerProtocolNotSupported(): Unit = {
+    assertThrows(classOf[UnsupportedVersionException],
+      () => createKafkaApis(KAFKA_0_10_2_IV0).handleAddOffsetsToTxnRequest(null, RequestLocal.withThreadConfinedCaching))
+  }
+
+  @Test
+  def shouldThrowUnsupportedVersionExceptionOnHandleAddPartitionsToTxnRequestWhenInterBrokerProtocolNotSupported(): Unit = {
+    assertThrows(classOf[UnsupportedVersionException],
+      () => createKafkaApis(KAFKA_0_10_2_IV0).handleAddPartitionToTxnRequest(null, RequestLocal.withThreadConfinedCaching))
+  }
+
+  @Test
+  def shouldThrowUnsupportedVersionExceptionOnHandleTxnOffsetCommitRequestWhenInterBrokerProtocolNotSupported(): Unit = {
+    assertThrows(classOf[UnsupportedVersionException],
+      () => createKafkaApis(KAFKA_0_10_2_IV0).handleTxnOffsetCommitRequest(null, RequestLocal.withThreadConfinedCaching))
+  }
+
+  @Test
+  def shouldThrowUnsupportedVersionExceptionOnHandleEndTxnRequestWhenInterBrokerProtocolNotSupported(): Unit = {
+    assertThrows(classOf[UnsupportedVersionException],
+      () => createKafkaApis(KAFKA_0_10_2_IV0).handleEndTxnRequest(null, RequestLocal.withThreadConfinedCaching))
+  }
+
+  @Test
+  def shouldThrowUnsupportedVersionExceptionOnHandleWriteTxnMarkersRequestWhenInterBrokerProtocolNotSupported(): Unit = {
+    assertThrows(classOf[UnsupportedVersionException],
+      () => createKafkaApis(KAFKA_0_10_2_IV0).handleWriteTxnMarkersRequest(null, RequestLocal.withThreadConfinedCaching))
+  }
+
+  @Test
+  def shouldRespondWithUnsupportedForMessageFormatOnHandleWriteTxnMarkersWhenMagicLowerThanRequired(): Unit = {
+    val topicPartition = new TopicPartition("t", 0)
+    val (writeTxnMarkersRequest, request) = createWriteTxnMarkersRequest(asList(topicPartition))
+    val expectedErrors = Map(topicPartition -> Errors.UNSUPPORTED_FOR_MESSAGE_FORMAT).asJava
+    val capturedResponse: Capture[AbstractResponse] = EasyMock.newCapture()
+
+
EasyMock.expect(replicaManager.getMagic(topicPartition)) + .andReturn(Some(RecordBatch.MAGIC_VALUE_V1)) + EasyMock.expect(requestChannel.sendResponse( + EasyMock.eq(request), + EasyMock.capture(capturedResponse), + EasyMock.eq(None) + )) + EasyMock.replay(replicaManager, replicaQuotaManager, requestChannel) + + createKafkaApis().handleWriteTxnMarkersRequest(request, RequestLocal.withThreadConfinedCaching) + + val markersResponse = capturedResponse.getValue.asInstanceOf[WriteTxnMarkersResponse] + assertEquals(expectedErrors, markersResponse.errorsByProducerId.get(1L)) + } + + @Test + def shouldRespondWithUnknownTopicWhenPartitionIsNotHosted(): Unit = { + val topicPartition = new TopicPartition("t", 0) + val (writeTxnMarkersRequest, request) = createWriteTxnMarkersRequest(asList(topicPartition)) + val expectedErrors = Map(topicPartition -> Errors.UNKNOWN_TOPIC_OR_PARTITION).asJava + val capturedResponse: Capture[AbstractResponse] = EasyMock.newCapture() + + EasyMock.expect(replicaManager.getMagic(topicPartition)) + .andReturn(None) + EasyMock.expect(requestChannel.sendResponse( + EasyMock.eq(request), + EasyMock.capture(capturedResponse), + EasyMock.eq(None) + )) + EasyMock.replay(replicaManager, replicaQuotaManager, requestChannel) + + createKafkaApis().handleWriteTxnMarkersRequest(request, RequestLocal.withThreadConfinedCaching) + + val markersResponse = capturedResponse.getValue.asInstanceOf[WriteTxnMarkersResponse] + assertEquals(expectedErrors, markersResponse.errorsByProducerId.get(1L)) + } + + @Test + def shouldRespondWithUnsupportedMessageFormatForBadPartitionAndNoErrorsForGoodPartition(): Unit = { + val tp1 = new TopicPartition("t", 0) + val tp2 = new TopicPartition("t1", 0) + val (writeTxnMarkersRequest, request) = createWriteTxnMarkersRequest(asList(tp1, tp2)) + val expectedErrors = Map(tp1 -> Errors.UNSUPPORTED_FOR_MESSAGE_FORMAT, tp2 -> Errors.NONE).asJava + + val capturedResponse: Capture[AbstractResponse] = EasyMock.newCapture() + val responseCallback: Capture[Map[TopicPartition, PartitionResponse] => Unit] = EasyMock.newCapture() + + EasyMock.expect(replicaManager.getMagic(tp1)) + .andReturn(Some(RecordBatch.MAGIC_VALUE_V1)) + EasyMock.expect(replicaManager.getMagic(tp2)) + .andReturn(Some(RecordBatch.MAGIC_VALUE_V2)) + + val requestLocal = RequestLocal.withThreadConfinedCaching + EasyMock.expect(replicaManager.appendRecords(EasyMock.anyLong(), + EasyMock.anyShort(), + EasyMock.eq(true), + EasyMock.eq(AppendOrigin.Coordinator), + EasyMock.anyObject(), + EasyMock.capture(responseCallback), + EasyMock.anyObject(), + EasyMock.anyObject(), + EasyMock.eq(requestLocal)) + ).andAnswer(() => responseCallback.getValue.apply(Map(tp2 -> new PartitionResponse(Errors.NONE)))) + + EasyMock.expect(requestChannel.sendResponse( + EasyMock.eq(request), + EasyMock.capture(capturedResponse), + EasyMock.eq(None) + )) + EasyMock.replay(replicaManager, replicaQuotaManager, requestChannel) + + createKafkaApis().handleWriteTxnMarkersRequest(request, requestLocal) + + val markersResponse = capturedResponse.getValue.asInstanceOf[WriteTxnMarkersResponse] + assertEquals(expectedErrors, markersResponse.errorsByProducerId.get(1L)) + EasyMock.verify(replicaManager) + } + + @Test + def shouldResignCoordinatorsIfStopReplicaReceivedWithDeleteFlagAndLeaderEpoch(): Unit = { + shouldResignCoordinatorsIfStopReplicaReceivedWithDeleteFlag( + LeaderAndIsr.initialLeaderEpoch + 2, deletePartition = true) + } + + @Test + def shouldResignCoordinatorsIfStopReplicaReceivedWithDeleteFlagAndDeleteSentinel(): Unit = { + 
shouldResignCoordinatorsIfStopReplicaReceivedWithDeleteFlag( + LeaderAndIsr.EpochDuringDelete, deletePartition = true) + } + + @Test + def shouldResignCoordinatorsIfStopReplicaReceivedWithDeleteFlagAndNoEpochSentinel(): Unit = { + shouldResignCoordinatorsIfStopReplicaReceivedWithDeleteFlag( + LeaderAndIsr.NoEpoch, deletePartition = true) + } + + @Test + def shouldNotResignCoordinatorsIfStopReplicaReceivedWithoutDeleteFlag(): Unit = { + shouldResignCoordinatorsIfStopReplicaReceivedWithDeleteFlag( + LeaderAndIsr.initialLeaderEpoch + 2, deletePartition = false) + } + + def shouldResignCoordinatorsIfStopReplicaReceivedWithDeleteFlag(leaderEpoch: Int, + deletePartition: Boolean): Unit = { + val controllerId = 0 + val controllerEpoch = 5 + val brokerEpoch = 230498320L + + val fooPartition = new TopicPartition("foo", 0) + val groupMetadataPartition = new TopicPartition(Topic.GROUP_METADATA_TOPIC_NAME, 0) + val txnStatePartition = new TopicPartition(Topic.TRANSACTION_STATE_TOPIC_NAME, 0) + + val topicStates = Seq( + new StopReplicaTopicState() + .setTopicName(groupMetadataPartition.topic) + .setPartitionStates(Seq(new StopReplicaPartitionState() + .setPartitionIndex(groupMetadataPartition.partition) + .setLeaderEpoch(leaderEpoch) + .setDeletePartition(deletePartition)).asJava), + new StopReplicaTopicState() + .setTopicName(txnStatePartition.topic) + .setPartitionStates(Seq(new StopReplicaPartitionState() + .setPartitionIndex(txnStatePartition.partition) + .setLeaderEpoch(leaderEpoch) + .setDeletePartition(deletePartition)).asJava), + new StopReplicaTopicState() + .setTopicName(fooPartition.topic) + .setPartitionStates(Seq(new StopReplicaPartitionState() + .setPartitionIndex(fooPartition.partition) + .setLeaderEpoch(leaderEpoch) + .setDeletePartition(deletePartition)).asJava) + ).asJava + + val stopReplicaRequest = new StopReplicaRequest.Builder( + ApiKeys.STOP_REPLICA.latestVersion, + controllerId, + controllerEpoch, + brokerEpoch, + false, + topicStates + ).build() + val request = buildRequest(stopReplicaRequest) + + EasyMock.expect(replicaManager.stopReplicas( + EasyMock.eq(request.context.correlationId), + EasyMock.eq(controllerId), + EasyMock.eq(controllerEpoch), + EasyMock.eq(stopReplicaRequest.partitionStates().asScala) + )).andReturn( + (mutable.Map( + groupMetadataPartition -> Errors.NONE, + txnStatePartition -> Errors.NONE, + fooPartition -> Errors.NONE + ), Errors.NONE) + ) + EasyMock.expect(controller.brokerEpoch).andStubReturn(brokerEpoch) + + if (deletePartition) { + if (leaderEpoch >= 0) { + txnCoordinator.onResignation(txnStatePartition.partition, Some(leaderEpoch)) + } else { + txnCoordinator.onResignation(txnStatePartition.partition, None) + } + EasyMock.expectLastCall() + } + + if (deletePartition) { + if (leaderEpoch >= 0) { + groupCoordinator.onResignation(groupMetadataPartition.partition, Some(leaderEpoch)) + } else { + groupCoordinator.onResignation(groupMetadataPartition.partition, None) + } + EasyMock.expectLastCall() + } + + EasyMock.replay(controller, replicaManager, txnCoordinator, groupCoordinator) + + createKafkaApis().handleStopReplicaRequest(request) + + EasyMock.verify(txnCoordinator, groupCoordinator) + } + + @Test + def shouldRespondWithUnknownTopicOrPartitionForBadPartitionAndNoErrorsForGoodPartition(): Unit = { + val tp1 = new TopicPartition("t", 0) + val tp2 = new TopicPartition("t1", 0) + val (writeTxnMarkersRequest, request) = createWriteTxnMarkersRequest(asList(tp1, tp2)) + val expectedErrors = Map(tp1 -> Errors.UNKNOWN_TOPIC_OR_PARTITION, tp2 -> 
Errors.NONE).asJava + + val capturedResponse: Capture[AbstractResponse] = EasyMock.newCapture() + val responseCallback: Capture[Map[TopicPartition, PartitionResponse] => Unit] = EasyMock.newCapture() + + EasyMock.expect(replicaManager.getMagic(tp1)) + .andReturn(None) + EasyMock.expect(replicaManager.getMagic(tp2)) + .andReturn(Some(RecordBatch.MAGIC_VALUE_V2)) + + val requestLocal = RequestLocal.withThreadConfinedCaching + EasyMock.expect(replicaManager.appendRecords(EasyMock.anyLong(), + EasyMock.anyShort(), + EasyMock.eq(true), + EasyMock.eq(AppendOrigin.Coordinator), + EasyMock.anyObject(), + EasyMock.capture(responseCallback), + EasyMock.anyObject(), + EasyMock.anyObject(), + EasyMock.eq(requestLocal)) + ).andAnswer(() => responseCallback.getValue.apply(Map(tp2 -> new PartitionResponse(Errors.NONE)))) + + EasyMock.expect(requestChannel.sendResponse( + EasyMock.eq(request), + EasyMock.capture(capturedResponse), + EasyMock.eq(None) + )) + EasyMock.replay(replicaManager, replicaQuotaManager, requestChannel) + + createKafkaApis().handleWriteTxnMarkersRequest(request, requestLocal) + + val markersResponse = capturedResponse.getValue.asInstanceOf[WriteTxnMarkersResponse] + assertEquals(expectedErrors, markersResponse.errorsByProducerId.get(1L)) + EasyMock.verify(replicaManager) + } + + @Test + def shouldAppendToLogOnWriteTxnMarkersWhenCorrectMagicVersion(): Unit = { + val topicPartition = new TopicPartition("t", 0) + val request = createWriteTxnMarkersRequest(asList(topicPartition))._2 + EasyMock.expect(replicaManager.getMagic(topicPartition)) + .andReturn(Some(RecordBatch.MAGIC_VALUE_V2)) + + val requestLocal = RequestLocal.withThreadConfinedCaching + EasyMock.expect(replicaManager.appendRecords(EasyMock.anyLong(), + EasyMock.anyShort(), + EasyMock.eq(true), + EasyMock.eq(AppendOrigin.Coordinator), + EasyMock.anyObject(), + EasyMock.anyObject(), + EasyMock.anyObject(), + EasyMock.anyObject(), + EasyMock.eq(requestLocal))) + + EasyMock.replay(replicaManager) + + createKafkaApis().handleWriteTxnMarkersRequest(request, requestLocal) + EasyMock.verify(replicaManager) + } + + @Test + def testLeaderReplicaIfLocalRaisesFencedLeaderEpoch(): Unit = { + testListOffsetFailedGetLeaderReplica(Errors.FENCED_LEADER_EPOCH) + } + + @Test + def testLeaderReplicaIfLocalRaisesUnknownLeaderEpoch(): Unit = { + testListOffsetFailedGetLeaderReplica(Errors.UNKNOWN_LEADER_EPOCH) + } + + @Test + def testLeaderReplicaIfLocalRaisesNotLeaderOrFollower(): Unit = { + testListOffsetFailedGetLeaderReplica(Errors.NOT_LEADER_OR_FOLLOWER) + } + + @Test + def testLeaderReplicaIfLocalRaisesUnknownTopicOrPartition(): Unit = { + testListOffsetFailedGetLeaderReplica(Errors.UNKNOWN_TOPIC_OR_PARTITION) + } + + @Test + def testDescribeGroups(): Unit = { + val groupId = "groupId" + val random = new Random() + val metadata = new Array[Byte](10) + random.nextBytes(metadata) + val assignment = new Array[Byte](10) + random.nextBytes(assignment) + + val memberSummary = MemberSummary("memberid", Some("instanceid"), "clientid", "clienthost", metadata, assignment) + val groupSummary = GroupSummary("Stable", "consumer", "roundrobin", List(memberSummary)) + + EasyMock.reset(groupCoordinator, replicaManager, clientRequestQuotaManager, requestChannel) + + val describeGroupsRequest = new DescribeGroupsRequest.Builder( + new DescribeGroupsRequestData().setGroups(List(groupId).asJava) + ).build() + val request = buildRequest(describeGroupsRequest) + + val capturedResponse = expectNoThrottling(request) + 
EasyMock.expect(groupCoordinator.handleDescribeGroup(EasyMock.eq(groupId))) + .andReturn((Errors.NONE, groupSummary)) + EasyMock.replay(groupCoordinator, replicaManager, clientRequestQuotaManager, requestChannel) + + createKafkaApis().handleDescribeGroupRequest(request) + + val response = capturedResponse.getValue.asInstanceOf[DescribeGroupsResponse] + + val group = response.data.groups().get(0) + assertEquals(Errors.NONE, Errors.forCode(group.errorCode)) + assertEquals(groupId, group.groupId()) + assertEquals(groupSummary.state, group.groupState()) + assertEquals(groupSummary.protocolType, group.protocolType()) + assertEquals(groupSummary.protocol, group.protocolData()) + assertEquals(groupSummary.members.size, group.members().size()) + + val member = group.members().get(0) + assertEquals(memberSummary.memberId, member.memberId()) + assertEquals(memberSummary.groupInstanceId.orNull, member.groupInstanceId()) + assertEquals(memberSummary.clientId, member.clientId()) + assertEquals(memberSummary.clientHost, member.clientHost()) + assertArrayEquals(memberSummary.metadata, member.memberMetadata()) + assertArrayEquals(memberSummary.assignment, member.memberAssignment()) + } + + @Test + def testOffsetDelete(): Unit = { + val group = "groupId" + addTopicToMetadataCache("topic-1", numPartitions = 2) + addTopicToMetadataCache("topic-2", numPartitions = 2) + + EasyMock.reset(groupCoordinator, replicaManager, clientRequestQuotaManager, requestChannel) + + val topics = new OffsetDeleteRequestTopicCollection() + topics.add(new OffsetDeleteRequestTopic() + .setName("topic-1") + .setPartitions(Seq( + new OffsetDeleteRequestPartition().setPartitionIndex(0), + new OffsetDeleteRequestPartition().setPartitionIndex(1)).asJava)) + topics.add(new OffsetDeleteRequestTopic() + .setName("topic-2") + .setPartitions(Seq( + new OffsetDeleteRequestPartition().setPartitionIndex(0), + new OffsetDeleteRequestPartition().setPartitionIndex(1)).asJava)) + + val offsetDeleteRequest = new OffsetDeleteRequest.Builder( + new OffsetDeleteRequestData() + .setGroupId(group) + .setTopics(topics) + ).build() + val request = buildRequest(offsetDeleteRequest) + + val requestLocal = RequestLocal.withThreadConfinedCaching + val capturedResponse = expectNoThrottling(request) + EasyMock.expect(groupCoordinator.handleDeleteOffsets( + EasyMock.eq(group), + EasyMock.eq(Seq( + new TopicPartition("topic-1", 0), + new TopicPartition("topic-1", 1), + new TopicPartition("topic-2", 0), + new TopicPartition("topic-2", 1) + )), + EasyMock.eq(requestLocal) + )).andReturn((Errors.NONE, Map( + new TopicPartition("topic-1", 0) -> Errors.NONE, + new TopicPartition("topic-1", 1) -> Errors.NONE, + new TopicPartition("topic-2", 0) -> Errors.NONE, + new TopicPartition("topic-2", 1) -> Errors.NONE, + ))) + + EasyMock.replay(groupCoordinator, replicaManager, clientRequestQuotaManager, requestChannel) + + createKafkaApis().handleOffsetDeleteRequest(request, requestLocal) + + val response = capturedResponse.getValue.asInstanceOf[OffsetDeleteResponse] + + def errorForPartition(topic: String, partition: Int): Errors = { + Errors.forCode(response.data.topics.find(topic).partitions.find(partition).errorCode) + } + + assertEquals(2, response.data.topics.size) + assertEquals(Errors.NONE, errorForPartition("topic-1", 0)) + assertEquals(Errors.NONE, errorForPartition("topic-1", 1)) + assertEquals(Errors.NONE, errorForPartition("topic-2", 0)) + assertEquals(Errors.NONE, errorForPartition("topic-2", 1)) + } + + @Test + def testOffsetDeleteWithInvalidPartition(): Unit = { 
+ val group = "groupId" + val topic = "topic" + addTopicToMetadataCache(topic, numPartitions = 1) + + def checkInvalidPartition(invalidPartitionId: Int): Unit = { + EasyMock.reset(groupCoordinator, replicaManager, clientRequestQuotaManager, requestChannel) + + val topics = new OffsetDeleteRequestTopicCollection() + topics.add(new OffsetDeleteRequestTopic() + .setName(topic) + .setPartitions(Collections.singletonList( + new OffsetDeleteRequestPartition().setPartitionIndex(invalidPartitionId)))) + val offsetDeleteRequest = new OffsetDeleteRequest.Builder( + new OffsetDeleteRequestData() + .setGroupId(group) + .setTopics(topics) + ).build() + val request = buildRequest(offsetDeleteRequest) + val capturedResponse = expectNoThrottling(request) + + val requestLocal = RequestLocal.withThreadConfinedCaching + EasyMock.expect(groupCoordinator.handleDeleteOffsets(EasyMock.eq(group), EasyMock.eq(Seq.empty), + EasyMock.eq(requestLocal))).andReturn((Errors.NONE, Map.empty)) + EasyMock.replay(groupCoordinator, replicaManager, clientRequestQuotaManager, requestChannel) + + createKafkaApis().handleOffsetDeleteRequest(request, requestLocal) + + val response = capturedResponse.getValue.asInstanceOf[OffsetDeleteResponse] + + assertEquals(Errors.UNKNOWN_TOPIC_OR_PARTITION, + Errors.forCode(response.data.topics.find(topic).partitions.find(invalidPartitionId).errorCode)) + } + + checkInvalidPartition(-1) + checkInvalidPartition(1) // topic has only one partition + } + + @Test + def testOffsetDeleteWithInvalidGroup(): Unit = { + val group = "groupId" + + EasyMock.reset(groupCoordinator, replicaManager, clientRequestQuotaManager, requestChannel) + + val offsetDeleteRequest = new OffsetDeleteRequest.Builder( + new OffsetDeleteRequestData() + .setGroupId(group) + ).build() + val request = buildRequest(offsetDeleteRequest) + + val capturedResponse = expectNoThrottling(request) + val requestLocal = RequestLocal.withThreadConfinedCaching + EasyMock.expect(groupCoordinator.handleDeleteOffsets(EasyMock.eq(group), EasyMock.eq(Seq.empty), + EasyMock.eq(requestLocal))).andReturn((Errors.GROUP_ID_NOT_FOUND, Map.empty)) + EasyMock.replay(groupCoordinator, replicaManager, clientRequestQuotaManager, requestChannel) + + createKafkaApis().handleOffsetDeleteRequest(request, requestLocal) + + val response = capturedResponse.getValue.asInstanceOf[OffsetDeleteResponse] + + assertEquals(Errors.GROUP_ID_NOT_FOUND, Errors.forCode(response.data.errorCode)) + } + + private def testListOffsetFailedGetLeaderReplica(error: Errors): Unit = { + val tp = new TopicPartition("foo", 0) + val isolationLevel = IsolationLevel.READ_UNCOMMITTED + val currentLeaderEpoch = Optional.of[Integer](15) + + EasyMock.expect(replicaManager.fetchOffsetForTimestamp( + EasyMock.eq(tp), + EasyMock.eq(ListOffsetsRequest.EARLIEST_TIMESTAMP), + EasyMock.eq(Some(isolationLevel)), + EasyMock.eq(currentLeaderEpoch), + fetchOnlyFromLeader = EasyMock.eq(true)) + ).andThrow(error.exception) + + val targetTimes = List(new ListOffsetsTopic() + .setName(tp.topic) + .setPartitions(List(new ListOffsetsPartition() + .setPartitionIndex(tp.partition) + .setTimestamp(ListOffsetsRequest.EARLIEST_TIMESTAMP) + .setCurrentLeaderEpoch(currentLeaderEpoch.get)).asJava)).asJava + val listOffsetRequest = ListOffsetsRequest.Builder.forConsumer(true, isolationLevel, false) + .setTargetTimes(targetTimes).build() + val request = buildRequest(listOffsetRequest) + val capturedResponse = expectNoThrottling(request) + + EasyMock.replay(replicaManager, clientRequestQuotaManager, requestChannel) + 
createKafkaApis().handleListOffsetRequest(request)
+
+    val response = capturedResponse.getValue.asInstanceOf[ListOffsetsResponse]
+    val partitionDataOptional = response.topics.asScala.find(_.name == tp.topic).get
+      .partitions.asScala.find(_.partitionIndex == tp.partition)
+    assertTrue(partitionDataOptional.isDefined)
+
+    val partitionData = partitionDataOptional.get
+    assertEquals(error.code, partitionData.errorCode)
+    assertEquals(ListOffsetsResponse.UNKNOWN_OFFSET, partitionData.offset)
+    assertEquals(ListOffsetsResponse.UNKNOWN_TIMESTAMP, partitionData.timestamp)
+  }
+
+  @Test
+  def testReadUncommittedConsumerListOffsetLatest(): Unit = {
+    testConsumerListOffsetLatest(IsolationLevel.READ_UNCOMMITTED)
+  }
+
+  @Test
+  def testReadCommittedConsumerListOffsetLatest(): Unit = {
+    testConsumerListOffsetLatest(IsolationLevel.READ_COMMITTED)
+  }
+
+  /**
+   * Verifies that the metadata response is correct if the broker listeners are inconsistent (i.e. one broker has
+   * more listeners than another) and the request is sent on the listener that exists in both brokers.
+   */
+  @Test
+  def testMetadataRequestOnSharedListenerWithInconsistentListenersAcrossBrokers(): Unit = {
+    val (plaintextListener, _) = updateMetadataCacheWithInconsistentListeners()
+    val response = sendMetadataRequestWithInconsistentListeners(plaintextListener)
+    assertEquals(Set(0, 1), response.brokers.asScala.map(_.id).toSet)
+  }
+
+  /**
+   * Verifies that the metadata response is correct if the broker listeners are inconsistent (i.e. one broker has
+   * more listeners than another) and the request is sent on the listener that exists in one broker.
+   */
+  @Test
+  def testMetadataRequestOnDistinctListenerWithInconsistentListenersAcrossBrokers(): Unit = {
+    val (_, anotherListener) = updateMetadataCacheWithInconsistentListeners()
+    val response = sendMetadataRequestWithInconsistentListeners(anotherListener)
+    assertEquals(Set(0), response.brokers.asScala.map(_.id).toSet)
+  }
+
+
+  /**
+   * A metadata request that fetches all topics should not result in either of the following:
+   * 1) Auto topic creation
+   * 2) UNKNOWN_TOPIC_OR_PARTITION
+   *
+   * This tests the case where a topic is deleted from the MetadataCache right after
+   * authorization but before the subsequent MetadataCache lookup.
+ */ + @Test + def getAllTopicMetadataShouldNotCreateTopicOrReturnUnknownTopicPartition(): Unit = { + // Setup: authorizer authorizes 2 topics, but one got deleted in metadata cache + metadataCache = + EasyMock.partialMockBuilder(classOf[ZkMetadataCache]) + .withConstructor(classOf[Int]) + .withArgs(Int.box(brokerId)) // Need to box it for Scala 2.12 and before + .addMockedMethod("getAllTopics") + .addMockedMethod("getTopicMetadata") + .createMock() + + // 2 topics returned for authorization in during handle + val topicsReturnedFromMetadataCacheForAuthorization = Set("remaining-topic", "later-deleted-topic") + expect(metadataCache.getAllTopics()).andReturn(topicsReturnedFromMetadataCacheForAuthorization).once() + // 1 topic is deleted from metadata right at the time between authorization and the next getTopicMetadata() call + expect(metadataCache.getTopicMetadata( + EasyMock.eq(topicsReturnedFromMetadataCacheForAuthorization), + anyObject[ListenerName], + anyBoolean, + anyBoolean + )).andStubReturn(Seq( + new MetadataResponseTopic() + .setErrorCode(Errors.NONE.code) + .setName("remaining-topic") + .setIsInternal(false) + )) + + EasyMock.replay(metadataCache) + + var createTopicIsCalled: Boolean = false; + // Specific mock on zkClient for this use case + // Expect it's never called to do auto topic creation + expect(zkClient.setOrCreateEntityConfigs( + EasyMock.eq(ConfigType.Topic), + EasyMock.anyString, + EasyMock.anyObject[Properties] + )).andStubAnswer(() => { + createTopicIsCalled = true + }) + // No need to use + expect(zkClient.getAllBrokersInCluster) + .andStubReturn(Seq(new Broker( + brokerId, "localhost", 9902, + ListenerName.forSecurityProtocol(SecurityProtocol.PLAINTEXT), SecurityProtocol.PLAINTEXT + ))) + + EasyMock.replay(zkClient) + + val (requestListener, _) = updateMetadataCacheWithInconsistentListeners() + val response = sendMetadataRequestWithInconsistentListeners(requestListener) + + assertFalse(createTopicIsCalled) + val responseTopics = response.topicMetadata().asScala.map { metadata => metadata.topic() } + assertEquals(List("remaining-topic"), responseTopics) + assertTrue(response.topicsByError(Errors.UNKNOWN_TOPIC_OR_PARTITION).isEmpty) + } + + @Test + def testUnauthorizedTopicMetadataRequest(): Unit = { + // 1. Set up broker information + val plaintextListener = ListenerName.forSecurityProtocol(SecurityProtocol.PLAINTEXT) + val broker = new UpdateMetadataBroker() + .setId(0) + .setRack("rack") + .setEndpoints(Seq( + new UpdateMetadataEndpoint() + .setHost("broker0") + .setPort(9092) + .setSecurityProtocol(SecurityProtocol.PLAINTEXT.id) + .setListener(plaintextListener.value) + ).asJava) + + // 2. 
Set up authorizer + val authorizer: Authorizer = EasyMock.niceMock(classOf[Authorizer]) + val unauthorizedTopic = "unauthorized-topic" + val authorizedTopic = "authorized-topic" + + val expectedActions = Seq( + new Action(AclOperation.DESCRIBE, new ResourcePattern(ResourceType.TOPIC, unauthorizedTopic, PatternType.LITERAL), 1, true, true), + new Action(AclOperation.DESCRIBE, new ResourcePattern(ResourceType.TOPIC, authorizedTopic, PatternType.LITERAL), 1, true, true) + ) + + // Here we need to use AuthHelperTest.matchSameElements instead of EasyMock.eq since the order of the request is unknown + EasyMock.expect(authorizer.authorize(anyObject[RequestContext], AuthHelperTest.matchSameElements(expectedActions.asJava))) + .andAnswer { () => + val actions = EasyMock.getCurrentArguments.apply(1).asInstanceOf[util.List[Action]].asScala + actions.map { action => + if (action.resourcePattern().name().equals(authorizedTopic)) + AuthorizationResult.ALLOWED + else + AuthorizationResult.DENIED + }.asJava + }.times(2) + + // 3. Set up MetadataCache + val authorizedTopicId = Uuid.randomUuid() + val unauthorizedTopicId = Uuid.randomUuid(); + + val topicIds = new util.HashMap[String, Uuid]() + topicIds.put(authorizedTopic, authorizedTopicId) + topicIds.put(unauthorizedTopic, unauthorizedTopicId) + + def createDummyPartitionStates(topic: String) = { + new UpdateMetadataPartitionState() + .setTopicName(topic) + .setPartitionIndex(0) + .setControllerEpoch(0) + .setLeader(0) + .setLeaderEpoch(0) + .setReplicas(Collections.singletonList(0)) + .setZkVersion(0) + .setIsr(Collections.singletonList(0)) + } + + // Send UpdateMetadataReq to update MetadataCache + val partitionStates = Seq(unauthorizedTopic, authorizedTopic).map(createDummyPartitionStates) + + val updateMetadataRequest = new UpdateMetadataRequest.Builder(ApiKeys.UPDATE_METADATA.latestVersion, 0, + 0, 0, partitionStates.asJava, Seq(broker).asJava, topicIds).build() + metadataCache.asInstanceOf[ZkMetadataCache].updateMetadata(correlationId = 0, updateMetadataRequest) + + // 4. Send TopicMetadataReq using topicId + val metadataReqByTopicId = new MetadataRequest.Builder(util.Arrays.asList(authorizedTopicId, unauthorizedTopicId)).build() + val repByTopicId = buildRequest(metadataReqByTopicId, plaintextListener) + val capturedMetadataByTopicIdResp = expectNoThrottling(repByTopicId) + EasyMock.replay(clientRequestQuotaManager, requestChannel, authorizer) + + createKafkaApis(authorizer = Some(authorizer)).handleTopicMetadataRequest(repByTopicId) + val metadataByTopicIdResp = capturedMetadataByTopicIdResp.getValue.asInstanceOf[MetadataResponse] + + val metadataByTopicId = metadataByTopicIdResp.data().topics().asScala.groupBy(_.topicId()).map(kv => (kv._1, kv._2.head)) + + metadataByTopicId.foreach{ case (topicId, metadataResponseTopic) => + if (topicId == unauthorizedTopicId) { + // Return an TOPIC_AUTHORIZATION_FAILED on unauthorized error regardless of leaking the existence of topic id + assertEquals(Errors.TOPIC_AUTHORIZATION_FAILED.code(), metadataResponseTopic.errorCode()) + // Do not return topic information on unauthorized error + assertNull(metadataResponseTopic.name()) + } else { + assertEquals(Errors.NONE.code(), metadataResponseTopic.errorCode()) + assertEquals(authorizedTopic, metadataResponseTopic.name()) + } + } + + // 4. 
Send TopicMetadataReq using topic name + EasyMock.reset(clientRequestQuotaManager, requestChannel) + val metadataReqByTopicName = new MetadataRequest.Builder(util.Arrays.asList(authorizedTopic, unauthorizedTopic), false).build() + val repByTopicName = buildRequest(metadataReqByTopicName, plaintextListener) + val capturedMetadataByTopicNameResp = expectNoThrottling(repByTopicName) + EasyMock.replay(clientRequestQuotaManager, requestChannel) + + createKafkaApis(authorizer = Some(authorizer)).handleTopicMetadataRequest(repByTopicName) + val metadataByTopicNameResp = capturedMetadataByTopicNameResp.getValue.asInstanceOf[MetadataResponse] + + val metadataByTopicName = metadataByTopicNameResp.data().topics().asScala.groupBy(_.name()).map(kv => (kv._1, kv._2.head)) + + metadataByTopicName.foreach{ case (topicName, metadataResponseTopic) => + if (topicName == unauthorizedTopic) { + assertEquals(Errors.TOPIC_AUTHORIZATION_FAILED.code(), metadataResponseTopic.errorCode()) + // Do not return topic Id on unauthorized error + assertEquals(Uuid.ZERO_UUID, metadataResponseTopic.topicId()) + } else { + assertEquals(Errors.NONE.code(), metadataResponseTopic.errorCode()) + assertEquals(authorizedTopicId, metadataResponseTopic.topicId()) + } + } + } + + /** + * Verifies that sending a fetch request with version 9 works correctly when + * ReplicaManager.getLogConfig returns None. + */ + @Test + def testFetchRequestV9WithNoLogConfig(): Unit = { + val tidp = new TopicIdPartition(Uuid.ZERO_UUID, new TopicPartition("foo", 0)) + val tp = tidp.topicPartition + addTopicToMetadataCache(tp.topic, numPartitions = 1) + val hw = 3 + val timestamp = 1000 + + expect(replicaManager.getLogConfig(EasyMock.eq(tp))).andReturn(None) + + replicaManager.fetchMessages(anyLong, anyInt, anyInt, anyInt, anyBoolean, + anyObject[Seq[(TopicIdPartition, FetchRequest.PartitionData)]], anyObject[ReplicaQuota], + anyObject[Seq[(TopicIdPartition, FetchPartitionData)] => Unit](), anyObject[IsolationLevel], + anyObject[Option[ClientMetadata]]) + expectLastCall[Unit].andAnswer(new IAnswer[Unit] { + def answer: Unit = { + val callback = getCurrentArguments.apply(7) + .asInstanceOf[Seq[(TopicIdPartition, FetchPartitionData)] => Unit] + val records = MemoryRecords.withRecords(CompressionType.NONE, + new SimpleRecord(timestamp, "foo".getBytes(StandardCharsets.UTF_8))) + callback(Seq(tidp -> FetchPartitionData(Errors.NONE, hw, 0, records, + None, None, None, Option.empty, isReassignmentFetch = false))) + } + }) + + val fetchData = Map(tidp -> new FetchRequest.PartitionData(Uuid.ZERO_UUID, 0, 0, 1000, + Optional.empty())).asJava + val fetchDataBuilder = Map(tp -> new FetchRequest.PartitionData(Uuid.ZERO_UUID, 0, 0, 1000, + Optional.empty())).asJava + val fetchMetadata = new JFetchMetadata(0, 0) + val fetchContext = new FullFetchContext(time, new FetchSessionCache(1000, 100), + fetchMetadata, fetchData, false, false) + expect(fetchManager.newContext( + anyObject[Short], + anyObject[JFetchMetadata], + anyObject[Boolean], + anyObject[util.Map[TopicIdPartition, FetchRequest.PartitionData]], + anyObject[util.List[TopicIdPartition]], + anyObject[util.Map[Uuid, String]])).andReturn(fetchContext) + + EasyMock.expect(clientQuotaManager.maybeRecordAndGetThrottleTimeMs( + anyObject[RequestChannel.Request](), anyDouble, anyLong)).andReturn(0) + + val fetchRequest = new FetchRequest.Builder(9, 9, -1, 100, 0, fetchDataBuilder) + .build() + val request = buildRequest(fetchRequest) + val capturedResponse = expectNoThrottling(request) + + 
EasyMock.replay(replicaManager, clientQuotaManager, clientRequestQuotaManager, requestChannel, fetchManager) + createKafkaApis().handleFetchRequest(request) + + val response = capturedResponse.getValue.asInstanceOf[FetchResponse] + val responseData = response.responseData(metadataCache.topicIdsToNames(), 9) + assertTrue(responseData.containsKey(tp)) + + val partitionData = responseData.get(tp) + assertEquals(Errors.NONE.code, partitionData.errorCode) + assertEquals(hw, partitionData.highWatermark) + assertEquals(-1, partitionData.lastStableOffset) + assertEquals(0, partitionData.logStartOffset) + assertEquals(timestamp, FetchResponse.recordsOrFail(partitionData).batches.iterator.next.maxTimestamp) + assertNull(partitionData.abortedTransactions) + } + + /** + * Verifies that partitions with unknown topic ID errors are added to the erroneous set and there is not an attempt to fetch them. + */ + @ParameterizedTest + @ValueSource(ints = Array(-1, 0)) + def testFetchRequestErroneousPartitions(replicaId: Int): Unit = { + val foo = new TopicIdPartition(Uuid.randomUuid(), new TopicPartition("foo", 0)) + val unresolvedFoo = new TopicIdPartition(foo.topicId, new TopicPartition(null, foo.partition)) + + addTopicToMetadataCache(foo.topic, 1, topicId = foo.topicId) + + // We will never return a logConfig when the topic name is null. This is ok since we won't have any records to convert. + expect(replicaManager.getLogConfig(EasyMock.eq(unresolvedFoo.topicPartition))).andReturn(None) + + // Simulate unknown topic ID in the context + val fetchData = Map(new TopicIdPartition(foo.topicId, new TopicPartition(null, foo.partition)) -> + new FetchRequest.PartitionData(foo.topicId, 0, 0, 1000, Optional.empty())).asJava + val fetchDataBuilder = Map(foo.topicPartition -> new FetchRequest.PartitionData(foo.topicId, 0, 0, 1000, + Optional.empty())).asJava + val fetchMetadata = new JFetchMetadata(0, 0) + val fetchContext = new FullFetchContext(time, new FetchSessionCache(1000, 100), + fetchMetadata, fetchData, true, replicaId >= 0) + // We expect to have the resolved partition, but we will simulate an unknown one with the fetchContext we return. + expect(fetchManager.newContext( + ApiKeys.FETCH.latestVersion, + fetchMetadata, + replicaId >= 0, + Collections.singletonMap(foo, new FetchRequest.PartitionData(foo.topicId, 0, 0, 1000, Optional.empty())), + Collections.emptyList[TopicIdPartition], + metadataCache.topicIdsToNames()) + ).andReturn(fetchContext) + + EasyMock.expect(clientQuotaManager.maybeRecordAndGetThrottleTimeMs( + anyObject[RequestChannel.Request](), anyDouble, anyLong)).andReturn(0) + + // If replicaId is -1 we will build a consumer request. Any non-negative replicaId will build a follower request. 
+ val fetchRequest = new FetchRequest.Builder(ApiKeys.FETCH.latestVersion, ApiKeys.FETCH.latestVersion, + replicaId, 100, 0, fetchDataBuilder).metadata(fetchMetadata).build() + val request = buildRequest(fetchRequest) + val capturedResponse = expectNoThrottling(request) + + EasyMock.replay(replicaManager, clientQuotaManager, clientRequestQuotaManager, requestChannel, fetchManager) + createKafkaApis().handleFetchRequest(request) + + val response = capturedResponse.getValue.asInstanceOf[FetchResponse] + val responseData = response.responseData(metadataCache.topicIdsToNames(), ApiKeys.FETCH.latestVersion) + assertTrue(responseData.containsKey(foo.topicPartition)) + + val partitionData = responseData.get(foo.topicPartition) + assertEquals(Errors.UNKNOWN_TOPIC_ID.code, partitionData.errorCode) + assertEquals(-1, partitionData.highWatermark) + assertEquals(-1, partitionData.lastStableOffset) + assertEquals(-1, partitionData.logStartOffset) + assertEquals(MemoryRecords.EMPTY, FetchResponse.recordsOrFail(partitionData)) + } + + @Test + def testJoinGroupProtocolsOrder(): Unit = { + val protocols = List( + ("first", "first".getBytes()), + ("second", "second".getBytes()) + ) + + val groupId = "group" + val memberId = "member1" + val protocolType = "consumer" + val rebalanceTimeoutMs = 10 + val sessionTimeoutMs = 5 + val capturedProtocols = EasyMock.newCapture[List[(String, Array[Byte])]]() + + EasyMock.expect(groupCoordinator.handleJoinGroup( + EasyMock.eq(groupId), + EasyMock.eq(memberId), + EasyMock.eq(None), + EasyMock.eq(true), + EasyMock.eq(clientId), + EasyMock.eq(InetAddress.getLocalHost.toString), + EasyMock.eq(rebalanceTimeoutMs), + EasyMock.eq(sessionTimeoutMs), + EasyMock.eq(protocolType), + EasyMock.capture(capturedProtocols), + anyObject(), + anyObject() + )) + + EasyMock.replay(groupCoordinator) + + createKafkaApis().handleJoinGroupRequest( + buildRequest( + new JoinGroupRequest.Builder( + new JoinGroupRequestData() + .setGroupId(groupId) + .setMemberId(memberId) + .setProtocolType(protocolType) + .setRebalanceTimeoutMs(rebalanceTimeoutMs) + .setSessionTimeoutMs(sessionTimeoutMs) + .setProtocols(new JoinGroupRequestData.JoinGroupRequestProtocolCollection( + protocols.map { case (name, protocol) => new JoinGroupRequestProtocol() + .setName(name).setMetadata(protocol) + }.iterator.asJava)) + ).build() + ), + RequestLocal.withThreadConfinedCaching) + + EasyMock.verify(groupCoordinator) + + val capturedProtocolsList = capturedProtocols.getValue + assertEquals(protocols.size, capturedProtocolsList.size) + protocols.zip(capturedProtocolsList).foreach { case ((expectedName, expectedBytes), (name, bytes)) => + assertEquals(expectedName, name) + assertArrayEquals(expectedBytes, bytes) + } + } + + @Test + def testJoinGroupWhenAnErrorOccurs(): Unit = { + for (version <- ApiKeys.JOIN_GROUP.oldestVersion to ApiKeys.JOIN_GROUP.latestVersion) { + testJoinGroupWhenAnErrorOccurs(version.asInstanceOf[Short]) + } + } + + def testJoinGroupWhenAnErrorOccurs(version: Short): Unit = { + EasyMock.reset(groupCoordinator, clientRequestQuotaManager, requestChannel, replicaManager) + + val groupId = "group" + val memberId = "member1" + val protocolType = "consumer" + val rebalanceTimeoutMs = 10 + val sessionTimeoutMs = 5 + + val capturedCallback = EasyMock.newCapture[JoinGroupCallback]() + + EasyMock.expect(groupCoordinator.handleJoinGroup( + EasyMock.eq(groupId), + EasyMock.eq(memberId), + EasyMock.eq(None), + EasyMock.eq(if (version >= 4) true else false), + EasyMock.eq(clientId), + 
EasyMock.eq(InetAddress.getLocalHost.toString), + EasyMock.eq(if (version >= 1) rebalanceTimeoutMs else sessionTimeoutMs), + EasyMock.eq(sessionTimeoutMs), + EasyMock.eq(protocolType), + EasyMock.eq(List.empty), + EasyMock.capture(capturedCallback), + EasyMock.anyObject() + )) + + val joinGroupRequest = new JoinGroupRequest.Builder( + new JoinGroupRequestData() + .setGroupId(groupId) + .setMemberId(memberId) + .setProtocolType(protocolType) + .setRebalanceTimeoutMs(rebalanceTimeoutMs) + .setSessionTimeoutMs(sessionTimeoutMs) + ).build(version) + + val requestChannelRequest = buildRequest(joinGroupRequest) + val capturedResponse = expectNoThrottling(requestChannelRequest) + + EasyMock.replay(groupCoordinator, clientRequestQuotaManager, requestChannel, replicaManager) + createKafkaApis().handleJoinGroupRequest(requestChannelRequest, RequestLocal.withThreadConfinedCaching) + + EasyMock.verify(groupCoordinator) + + capturedCallback.getValue.apply(JoinGroupResult(memberId, Errors.INCONSISTENT_GROUP_PROTOCOL)) + + val response = capturedResponse.getValue.asInstanceOf[JoinGroupResponse] + + assertEquals(Errors.INCONSISTENT_GROUP_PROTOCOL, response.error) + assertEquals(0, response.data.members.size) + assertEquals(memberId, response.data.memberId) + assertEquals(GroupCoordinator.NoGeneration, response.data.generationId) + assertEquals(GroupCoordinator.NoLeader, response.data.leader) + assertNull(response.data.protocolType) + + if (version >= 7) { + assertNull(response.data.protocolName) + } else { + assertEquals(GroupCoordinator.NoProtocol, response.data.protocolName) + } + + EasyMock.verify(clientRequestQuotaManager, requestChannel) + } + + @Test + def testJoinGroupProtocolType(): Unit = { + for (version <- ApiKeys.JOIN_GROUP.oldestVersion to ApiKeys.JOIN_GROUP.latestVersion) { + testJoinGroupProtocolType(version.asInstanceOf[Short]) + } + } + + def testJoinGroupProtocolType(version: Short): Unit = { + EasyMock.reset(groupCoordinator, clientRequestQuotaManager, requestChannel, replicaManager) + + val groupId = "group" + val memberId = "member1" + val protocolType = "consumer" + val protocolName = "range" + val rebalanceTimeoutMs = 10 + val sessionTimeoutMs = 5 + + val capturedCallback = EasyMock.newCapture[JoinGroupCallback]() + + EasyMock.expect(groupCoordinator.handleJoinGroup( + EasyMock.eq(groupId), + EasyMock.eq(memberId), + EasyMock.eq(None), + EasyMock.eq(if (version >= 4) true else false), + EasyMock.eq(clientId), + EasyMock.eq(InetAddress.getLocalHost.toString), + EasyMock.eq(if (version >= 1) rebalanceTimeoutMs else sessionTimeoutMs), + EasyMock.eq(sessionTimeoutMs), + EasyMock.eq(protocolType), + EasyMock.eq(List.empty), + EasyMock.capture(capturedCallback), + EasyMock.anyObject() + )) + + val joinGroupRequest = new JoinGroupRequest.Builder( + new JoinGroupRequestData() + .setGroupId(groupId) + .setMemberId(memberId) + .setProtocolType(protocolType) + .setRebalanceTimeoutMs(rebalanceTimeoutMs) + .setSessionTimeoutMs(sessionTimeoutMs) + ).build(version) + + val requestChannelRequest = buildRequest(joinGroupRequest) + val capturedResponse = expectNoThrottling(requestChannelRequest) + + EasyMock.replay(groupCoordinator, clientRequestQuotaManager, requestChannel, replicaManager) + createKafkaApis().handleJoinGroupRequest(requestChannelRequest, RequestLocal.withThreadConfinedCaching) + + EasyMock.verify(groupCoordinator) + + capturedCallback.getValue.apply(JoinGroupResult( + members = List.empty, + memberId = memberId, + generationId = 0, + protocolType = Some(protocolType), + protocolName 
= Some(protocolName), + leaderId = memberId, + error = Errors.NONE + )) + + val response = capturedResponse.getValue.asInstanceOf[JoinGroupResponse] + + assertEquals(Errors.NONE, response.error) + assertEquals(0, response.data.members.size) + assertEquals(memberId, response.data.memberId) + assertEquals(0, response.data.generationId) + assertEquals(memberId, response.data.leader) + assertEquals(protocolName, response.data.protocolName) + assertEquals(protocolType, response.data.protocolType) + + EasyMock.verify(clientRequestQuotaManager, requestChannel) + } + + @Test + def testSyncGroupProtocolTypeAndName(): Unit = { + for (version <- ApiKeys.SYNC_GROUP.oldestVersion to ApiKeys.SYNC_GROUP.latestVersion) { + testSyncGroupProtocolTypeAndName(version.asInstanceOf[Short]) + } + } + + def testSyncGroupProtocolTypeAndName(version: Short): Unit = { + EasyMock.reset(groupCoordinator, clientRequestQuotaManager, requestChannel, replicaManager) + + val groupId = "group" + val memberId = "member1" + val protocolType = "consumer" + val protocolName = "range" + + val capturedCallback = EasyMock.newCapture[SyncGroupCallback]() + + val requestLocal = RequestLocal.withThreadConfinedCaching + EasyMock.expect(groupCoordinator.handleSyncGroup( + EasyMock.eq(groupId), + EasyMock.eq(0), + EasyMock.eq(memberId), + EasyMock.eq(if (version >= 5) Some(protocolType) else None), + EasyMock.eq(if (version >= 5) Some(protocolName) else None), + EasyMock.eq(None), + EasyMock.eq(Map.empty), + EasyMock.capture(capturedCallback), + EasyMock.eq(requestLocal) + )) + + val syncGroupRequest = new SyncGroupRequest.Builder( + new SyncGroupRequestData() + .setGroupId(groupId) + .setGenerationId(0) + .setMemberId(memberId) + .setProtocolType(protocolType) + .setProtocolName(protocolName) + ).build(version) + + val requestChannelRequest = buildRequest(syncGroupRequest) + val capturedResponse = expectNoThrottling(requestChannelRequest) + + EasyMock.replay(groupCoordinator, clientRequestQuotaManager, requestChannel, replicaManager) + createKafkaApis().handleSyncGroupRequest(requestChannelRequest, requestLocal) + + EasyMock.verify(groupCoordinator) + + capturedCallback.getValue.apply(SyncGroupResult( + protocolType = Some(protocolType), + protocolName = Some(protocolName), + memberAssignment = Array.empty, + error = Errors.NONE + )) + + val response = capturedResponse.getValue.asInstanceOf[SyncGroupResponse] + + assertEquals(Errors.NONE, response.error) + assertArrayEquals(Array.empty[Byte], response.data.assignment) + assertEquals(protocolType, response.data.protocolType) + + EasyMock.verify(clientRequestQuotaManager, requestChannel) + } + + @Test + def testSyncGroupProtocolTypeAndNameAreMandatorySinceV5(): Unit = { + for (version <- ApiKeys.SYNC_GROUP.oldestVersion to ApiKeys.SYNC_GROUP.latestVersion) { + testSyncGroupProtocolTypeAndNameAreMandatorySinceV5(version.asInstanceOf[Short]) + } + } + + def testSyncGroupProtocolTypeAndNameAreMandatorySinceV5(version: Short): Unit = { + EasyMock.reset(groupCoordinator, clientRequestQuotaManager, requestChannel, replicaManager) + + val groupId = "group" + val memberId = "member1" + val protocolType = "consumer" + val protocolName = "range" + + val capturedCallback = EasyMock.newCapture[SyncGroupCallback]() + + val requestLocal = RequestLocal.withThreadConfinedCaching + if (version < 5) { + EasyMock.expect(groupCoordinator.handleSyncGroup( + EasyMock.eq(groupId), + EasyMock.eq(0), + EasyMock.eq(memberId), + EasyMock.eq(None), + EasyMock.eq(None), + EasyMock.eq(None), + EasyMock.eq(Map.empty), 
+ EasyMock.capture(capturedCallback), + EasyMock.eq(requestLocal) + )) + } + + val syncGroupRequest = new SyncGroupRequest.Builder( + new SyncGroupRequestData() + .setGroupId(groupId) + .setGenerationId(0) + .setMemberId(memberId) + ).build(version) + + val requestChannelRequest = buildRequest(syncGroupRequest) + val capturedResponse = expectNoThrottling(requestChannelRequest) + + EasyMock.replay(groupCoordinator, clientRequestQuotaManager, requestChannel, replicaManager) + createKafkaApis().handleSyncGroupRequest(requestChannelRequest, requestLocal) + + EasyMock.verify(groupCoordinator) + + if (version < 5) { + capturedCallback.getValue.apply(SyncGroupResult( + protocolType = Some(protocolType), + protocolName = Some(protocolName), + memberAssignment = Array.empty, + error = Errors.NONE + )) + } + + val response = capturedResponse.getValue.asInstanceOf[SyncGroupResponse] + + if (version < 5) { + assertEquals(Errors.NONE, response.error) + } else { + assertEquals(Errors.INCONSISTENT_GROUP_PROTOCOL, response.error) + } + + EasyMock.verify(clientRequestQuotaManager, requestChannel) + } + + @Test + def rejectJoinGroupRequestWhenStaticMembershipNotSupported(): Unit = { + val joinGroupRequest = new JoinGroupRequest.Builder( + new JoinGroupRequestData() + .setGroupId("test") + .setMemberId("test") + .setGroupInstanceId("instanceId") + .setProtocolType("consumer") + .setProtocols(new JoinGroupRequestData.JoinGroupRequestProtocolCollection) + ).build() + + val requestChannelRequest = buildRequest(joinGroupRequest) + val capturedResponse = expectNoThrottling(requestChannelRequest) + + EasyMock.replay(clientRequestQuotaManager, requestChannel) + createKafkaApis(KAFKA_2_2_IV1).handleJoinGroupRequest(requestChannelRequest, RequestLocal.withThreadConfinedCaching) + + val response = capturedResponse.getValue.asInstanceOf[JoinGroupResponse] + assertEquals(Errors.UNSUPPORTED_VERSION, response.error()) + EasyMock.replay(groupCoordinator) + } + + @Test + def rejectSyncGroupRequestWhenStaticMembershipNotSupported(): Unit = { + val syncGroupRequest = new SyncGroupRequest.Builder( + new SyncGroupRequestData() + .setGroupId("test") + .setMemberId("test") + .setGroupInstanceId("instanceId") + .setGenerationId(1) + ).build() + + val requestChannelRequest = buildRequest(syncGroupRequest) + val capturedResponse = expectNoThrottling(requestChannelRequest) + + EasyMock.replay(clientRequestQuotaManager, requestChannel) + createKafkaApis(KAFKA_2_2_IV1).handleSyncGroupRequest(requestChannelRequest, RequestLocal.withThreadConfinedCaching) + + val response = capturedResponse.getValue.asInstanceOf[SyncGroupResponse] + assertEquals(Errors.UNSUPPORTED_VERSION, response.error) + EasyMock.replay(groupCoordinator) + } + + @Test + def rejectHeartbeatRequestWhenStaticMembershipNotSupported(): Unit = { + val heartbeatRequest = new HeartbeatRequest.Builder( + new HeartbeatRequestData() + .setGroupId("test") + .setMemberId("test") + .setGroupInstanceId("instanceId") + .setGenerationId(1) + ).build() + val requestChannelRequest = buildRequest(heartbeatRequest) + val capturedResponse = expectNoThrottling(requestChannelRequest) + + EasyMock.replay(clientRequestQuotaManager, requestChannel) + createKafkaApis(KAFKA_2_2_IV1).handleHeartbeatRequest(requestChannelRequest) + + val response = capturedResponse.getValue.asInstanceOf[HeartbeatResponse] + assertEquals(Errors.UNSUPPORTED_VERSION, response.error()) + EasyMock.replay(groupCoordinator) + } + + @Test + def rejectOffsetCommitRequestWhenStaticMembershipNotSupported(): Unit = { + val 
offsetCommitRequest = new OffsetCommitRequest.Builder( + new OffsetCommitRequestData() + .setGroupId("test") + .setMemberId("test") + .setGroupInstanceId("instanceId") + .setGenerationId(100) + .setTopics(Collections.singletonList( + new OffsetCommitRequestData.OffsetCommitRequestTopic() + .setName("test") + .setPartitions(Collections.singletonList( + new OffsetCommitRequestData.OffsetCommitRequestPartition() + .setPartitionIndex(0) + .setCommittedOffset(100) + .setCommittedLeaderEpoch(RecordBatch.NO_PARTITION_LEADER_EPOCH) + .setCommittedMetadata("") + )) + )) + ).build() + + val requestChannelRequest = buildRequest(offsetCommitRequest) + val capturedResponse = expectNoThrottling(requestChannelRequest) + + EasyMock.replay(clientRequestQuotaManager, requestChannel) + createKafkaApis(KAFKA_2_2_IV1).handleOffsetCommitRequest(requestChannelRequest, RequestLocal.withThreadConfinedCaching) + + val expectedTopicErrors = Collections.singletonList( + new OffsetCommitResponseData.OffsetCommitResponseTopic() + .setName("test") + .setPartitions(Collections.singletonList( + new OffsetCommitResponseData.OffsetCommitResponsePartition() + .setPartitionIndex(0) + .setErrorCode(Errors.UNSUPPORTED_VERSION.code) + )) + ) + val response = capturedResponse.getValue.asInstanceOf[OffsetCommitResponse] + assertEquals(expectedTopicErrors, response.data.topics()) + EasyMock.replay(groupCoordinator) + } + + @Test + def testMultipleLeaveGroup(): Unit = { + val groupId = "groupId" + + val leaveMemberList = List( + new MemberIdentity() + .setMemberId("member-1") + .setGroupInstanceId("instance-1"), + new MemberIdentity() + .setMemberId("member-2") + .setGroupInstanceId("instance-2") + ) + + EasyMock.expect(groupCoordinator.handleLeaveGroup( + EasyMock.eq(groupId), + EasyMock.eq(leaveMemberList), + anyObject() + )) + + val leaveRequest = buildRequest( + new LeaveGroupRequest.Builder( + groupId, + leaveMemberList.asJava + ).build() + ) + + createKafkaApis().handleLeaveGroupRequest(leaveRequest) + + EasyMock.replay(groupCoordinator) + } + + @Test + def testSingleLeaveGroup(): Unit = { + val groupId = "groupId" + val memberId = "member" + + val singleLeaveMember = List( + new MemberIdentity() + .setMemberId(memberId) + ) + + EasyMock.expect(groupCoordinator.handleLeaveGroup( + EasyMock.eq(groupId), + EasyMock.eq(singleLeaveMember), + anyObject() + )) + + val leaveRequest = buildRequest( + new LeaveGroupRequest.Builder( + groupId, + singleLeaveMember.asJava + ).build() + ) + + createKafkaApis().handleLeaveGroupRequest(leaveRequest) + + EasyMock.replay(groupCoordinator) + } + + @Test + def testReassignmentAndReplicationBytesOutRateWhenReassigning(): Unit = { + assertReassignmentAndReplicationBytesOutPerSec(true) + } + + @Test + def testReassignmentAndReplicationBytesOutRateWhenNotReassigning(): Unit = { + assertReassignmentAndReplicationBytesOutPerSec(false) + } + + private def assertReassignmentAndReplicationBytesOutPerSec(isReassigning: Boolean): Unit = { + val leaderEpoch = 0 + val tp0 = new TopicPartition("tp", 0) + val topicId = Uuid.randomUuid() + val tidp0 = new TopicIdPartition(topicId, tp0) + + setupBasicMetadataCache(tp0.topic, numPartitions = 1, 1, topicId) + val hw = 3 + + val fetchDataBuilder = Collections.singletonMap(tp0, new FetchRequest.PartitionData(Uuid.ZERO_UUID, 0, 0, Int.MaxValue, Optional.of(leaderEpoch))) + val fetchData = Collections.singletonMap(tidp0, new FetchRequest.PartitionData(Uuid.ZERO_UUID, 0, 0, Int.MaxValue, Optional.of(leaderEpoch))) + val fetchFromFollower = buildRequest(new 
FetchRequest.Builder( + ApiKeys.FETCH.oldestVersion(), ApiKeys.FETCH.latestVersion(), 1, 1000, 0, fetchDataBuilder).build()) + + val records = MemoryRecords.withRecords(CompressionType.NONE, + new SimpleRecord(1000, "foo".getBytes(StandardCharsets.UTF_8))) + replicaManager.fetchMessages(anyLong, anyInt, anyInt, anyInt, anyBoolean, + anyObject[Seq[(TopicIdPartition, FetchRequest.PartitionData)]], anyObject[ReplicaQuota], + anyObject[Seq[(TopicIdPartition, FetchPartitionData)] => Unit](), anyObject[IsolationLevel], + anyObject[Option[ClientMetadata]]) + expectLastCall[Unit].andAnswer(new IAnswer[Unit] { + def answer: Unit = { + val callback = getCurrentArguments.apply(7).asInstanceOf[Seq[(TopicIdPartition, FetchPartitionData)] => Unit] + callback(Seq(tidp0 -> FetchPartitionData(Errors.NONE, hw, 0, records, + None, None, None, Option.empty, isReassignmentFetch = isReassigning))) + } + }) + + val fetchMetadata = new JFetchMetadata(0, 0) + val fetchContext = new FullFetchContext(time, new FetchSessionCache(1000, 100), + fetchMetadata, fetchData, true, true) + expect(fetchManager.newContext( + anyObject[Short], + anyObject[JFetchMetadata], + anyObject[Boolean], + anyObject[util.Map[TopicIdPartition, FetchRequest.PartitionData]], + anyObject[util.List[TopicIdPartition]], + anyObject[util.Map[Uuid, String]])).andReturn(fetchContext) + + expect(replicaQuotaManager.record(anyLong())) + expect(replicaManager.getLogConfig(EasyMock.eq(tp0))).andReturn(None) + + val partition: Partition = createNiceMock(classOf[Partition]) + expect(replicaManager.isAddingReplica(anyObject(), anyInt())).andReturn(isReassigning) + + replay(replicaManager, fetchManager, clientQuotaManager, requestChannel, replicaQuotaManager, partition) + + createKafkaApis().handle(fetchFromFollower, RequestLocal.withThreadConfinedCaching) + + if (isReassigning) + assertEquals(records.sizeInBytes(), brokerTopicStats.allTopicsStats.reassignmentBytesOutPerSec.get.count()) + else + assertEquals(0, brokerTopicStats.allTopicsStats.reassignmentBytesOutPerSec.get.count()) + assertEquals(records.sizeInBytes(), brokerTopicStats.allTopicsStats.replicationBytesOutRate.get.count()) + + } + + @Test + def rejectInitProducerIdWhenIdButNotEpochProvided(): Unit = { + val initProducerIdRequest = new InitProducerIdRequest.Builder( + new InitProducerIdRequestData() + .setTransactionalId("known") + .setTransactionTimeoutMs(TimeUnit.MINUTES.toMillis(15).toInt) + .setProducerId(10) + .setProducerEpoch(RecordBatch.NO_PRODUCER_EPOCH) + ).build() + + val requestChannelRequest = buildRequest(initProducerIdRequest) + val capturedResponse = expectNoThrottling(requestChannelRequest) + + EasyMock.replay(clientRequestQuotaManager, requestChannel) + createKafkaApis(KAFKA_2_2_IV1).handleInitProducerIdRequest(requestChannelRequest, RequestLocal.withThreadConfinedCaching) + + val response = capturedResponse.getValue.asInstanceOf[InitProducerIdResponse] + assertEquals(Errors.INVALID_REQUEST, response.error) + } + + @Test + def rejectInitProducerIdWhenEpochButNotIdProvided(): Unit = { + val initProducerIdRequest = new InitProducerIdRequest.Builder( + new InitProducerIdRequestData() + .setTransactionalId("known") + .setTransactionTimeoutMs(TimeUnit.MINUTES.toMillis(15).toInt) + .setProducerId(RecordBatch.NO_PRODUCER_ID) + .setProducerEpoch(2) + ).build() + val requestChannelRequest = buildRequest(initProducerIdRequest) + val capturedResponse = expectNoThrottling(requestChannelRequest) + + EasyMock.replay(clientRequestQuotaManager, requestChannel) + 
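+    // A producer epoch supplied without a matching producer id is not a valid combination, so
+    // the request should be rejected with INVALID_REQUEST.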
createKafkaApis(KAFKA_2_2_IV1).handleInitProducerIdRequest(requestChannelRequest, RequestLocal.withThreadConfinedCaching) + + val response = capturedResponse.getValue.asInstanceOf[InitProducerIdResponse] + assertEquals(Errors.INVALID_REQUEST, response.error) + } + + @Test + def testUpdateMetadataRequestWithCurrentBrokerEpoch(): Unit = { + val currentBrokerEpoch = 1239875L + testUpdateMetadataRequest(currentBrokerEpoch, currentBrokerEpoch, Errors.NONE) + } + + @Test + def testUpdateMetadataRequestWithNewerBrokerEpochIsValid(): Unit = { + val currentBrokerEpoch = 1239875L + testUpdateMetadataRequest(currentBrokerEpoch, currentBrokerEpoch + 1, Errors.NONE) + } + + @Test + def testUpdateMetadataRequestWithStaleBrokerEpochIsRejected(): Unit = { + val currentBrokerEpoch = 1239875L + testUpdateMetadataRequest(currentBrokerEpoch, currentBrokerEpoch - 1, Errors.STALE_BROKER_EPOCH) + } + + def testUpdateMetadataRequest(currentBrokerEpoch: Long, brokerEpochInRequest: Long, expectedError: Errors): Unit = { + val updateMetadataRequest = createBasicMetadataRequest("topicA", 1, brokerEpochInRequest, 1) + val request = buildRequest(updateMetadataRequest) + + val capturedResponse: Capture[AbstractResponse] = EasyMock.newCapture() + + EasyMock.expect(controller.brokerEpoch).andStubReturn(currentBrokerEpoch) + EasyMock.expect(replicaManager.maybeUpdateMetadataCache( + EasyMock.eq(request.context.correlationId), + EasyMock.anyObject() + )).andStubReturn( + Seq() + ) + + EasyMock.expect(requestChannel.sendResponse( + EasyMock.eq(request), + EasyMock.capture(capturedResponse), + EasyMock.eq(None) + )) + EasyMock.replay(replicaManager, controller, requestChannel) + + createKafkaApis().handleUpdateMetadataRequest(request, RequestLocal.withThreadConfinedCaching) + val updateMetadataResponse = capturedResponse.getValue.asInstanceOf[UpdateMetadataResponse] + assertEquals(expectedError, updateMetadataResponse.error()) + EasyMock.verify(replicaManager) + } + + @Test + def testLeaderAndIsrRequestWithCurrentBrokerEpoch(): Unit = { + val currentBrokerEpoch = 1239875L + testLeaderAndIsrRequest(currentBrokerEpoch, currentBrokerEpoch, Errors.NONE) + } + + @Test + def testLeaderAndIsrRequestWithNewerBrokerEpochIsValid(): Unit = { + val currentBrokerEpoch = 1239875L + testLeaderAndIsrRequest(currentBrokerEpoch, currentBrokerEpoch + 1, Errors.NONE) + } + + @Test + def testLeaderAndIsrRequestWithStaleBrokerEpochIsRejected(): Unit = { + val currentBrokerEpoch = 1239875L + testLeaderAndIsrRequest(currentBrokerEpoch, currentBrokerEpoch - 1, Errors.STALE_BROKER_EPOCH) + } + + def testLeaderAndIsrRequest(currentBrokerEpoch: Long, brokerEpochInRequest: Long, expectedError: Errors): Unit = { + val controllerId = 2 + val controllerEpoch = 6 + val capturedResponse: Capture[AbstractResponse] = EasyMock.newCapture() + val partitionStates = Seq( + new LeaderAndIsrRequestData.LeaderAndIsrPartitionState() + .setTopicName("topicW") + .setPartitionIndex(1) + .setControllerEpoch(1) + .setLeader(0) + .setLeaderEpoch(1) + .setIsr(asList(0, 1)) + .setZkVersion(2) + .setReplicas(asList(0, 1, 2)) + .setIsNew(false) + ).asJava + val leaderAndIsrRequest = new LeaderAndIsrRequest.Builder( + ApiKeys.LEADER_AND_ISR.latestVersion, + controllerId, + controllerEpoch, + brokerEpochInRequest, + partitionStates, + Collections.singletonMap("topicW", Uuid.randomUuid()), + asList(new Node(0, "host0", 9090), new Node(1, "host1", 9091)) + ).build() + val request = buildRequest(leaderAndIsrRequest) + val response = new LeaderAndIsrResponse(new 
LeaderAndIsrResponseData() + .setErrorCode(Errors.NONE.code) + .setPartitionErrors(asList()), leaderAndIsrRequest.version()) + + EasyMock.expect(controller.brokerEpoch).andStubReturn(currentBrokerEpoch) + EasyMock.expect(replicaManager.becomeLeaderOrFollower( + EasyMock.eq(request.context.correlationId), + EasyMock.anyObject(), + EasyMock.anyObject() + )).andStubReturn( + response + ) + + EasyMock.expect(requestChannel.sendResponse( + EasyMock.eq(request), + EasyMock.capture(capturedResponse), + EasyMock.eq(None) + )) + EasyMock.replay(replicaManager, controller, requestChannel) + + createKafkaApis().handleLeaderAndIsrRequest(request) + val leaderAndIsrResponse = capturedResponse.getValue.asInstanceOf[LeaderAndIsrResponse] + assertEquals(expectedError, leaderAndIsrResponse.error()) + EasyMock.verify(replicaManager) + } + + @Test + def testStopReplicaRequestWithCurrentBrokerEpoch(): Unit = { + val currentBrokerEpoch = 1239875L + testStopReplicaRequest(currentBrokerEpoch, currentBrokerEpoch, Errors.NONE) + } + + @Test + def testStopReplicaRequestWithNewerBrokerEpochIsValid(): Unit = { + val currentBrokerEpoch = 1239875L + testStopReplicaRequest(currentBrokerEpoch, currentBrokerEpoch + 1, Errors.NONE) + } + + @Test + def testStopReplicaRequestWithStaleBrokerEpochIsRejected(): Unit = { + val currentBrokerEpoch = 1239875L + testStopReplicaRequest(currentBrokerEpoch, currentBrokerEpoch - 1, Errors.STALE_BROKER_EPOCH) + } + + def testStopReplicaRequest(currentBrokerEpoch: Long, brokerEpochInRequest: Long, expectedError: Errors): Unit = { + val controllerId = 0 + val controllerEpoch = 5 + val capturedResponse: Capture[AbstractResponse] = EasyMock.newCapture() + val fooPartition = new TopicPartition("foo", 0) + val topicStates = Seq( + new StopReplicaTopicState() + .setTopicName(fooPartition.topic) + .setPartitionStates(Seq(new StopReplicaPartitionState() + .setPartitionIndex(fooPartition.partition) + .setLeaderEpoch(1) + .setDeletePartition(false)).asJava) + ).asJava + val stopReplicaRequest = new StopReplicaRequest.Builder( + ApiKeys.STOP_REPLICA.latestVersion, + controllerId, + controllerEpoch, + brokerEpochInRequest, + false, + topicStates + ).build() + val request = buildRequest(stopReplicaRequest) + + EasyMock.expect(controller.brokerEpoch).andStubReturn(currentBrokerEpoch) + EasyMock.expect(replicaManager.stopReplicas( + EasyMock.eq(request.context.correlationId), + EasyMock.eq(controllerId), + EasyMock.eq(controllerEpoch), + EasyMock.eq(stopReplicaRequest.partitionStates().asScala) + )).andStubReturn( + (mutable.Map( + fooPartition -> Errors.NONE + ), Errors.NONE) + ) + EasyMock.expect(requestChannel.sendResponse( + EasyMock.eq(request), + EasyMock.capture(capturedResponse), + EasyMock.eq(None) + )) + + EasyMock.replay(controller, replicaManager, requestChannel) + + createKafkaApis().handleStopReplicaRequest(request) + val stopReplicaResponse = capturedResponse.getValue.asInstanceOf[StopReplicaResponse] + assertEquals(expectedError, stopReplicaResponse.error()) + EasyMock.verify(replicaManager) + } + + @Test + def testListGroupsRequest(): Unit = { + val overviews = List( + GroupOverview("group1", "protocol1", "Stable"), + GroupOverview("group2", "qwerty", "Empty") + ) + val response = listGroupRequest(None, overviews) + assertEquals(2, response.data.groups.size) + assertEquals("Stable", response.data.groups.get(0).groupState) + assertEquals("Empty", response.data.groups.get(1).groupState) + } + + @Test + def testListGroupsRequestWithState(): Unit = { + val overviews = List( + 
GroupOverview("group1", "protocol1", "Stable") + ) + val response = listGroupRequest(Some("Stable"), overviews) + assertEquals(1, response.data.groups.size) + assertEquals("Stable", response.data.groups.get(0).groupState) + } + + private def listGroupRequest(state: Option[String], overviews: List[GroupOverview]): ListGroupsResponse = { + EasyMock.reset(groupCoordinator, clientRequestQuotaManager, requestChannel) + + val data = new ListGroupsRequestData() + if (state.isDefined) + data.setStatesFilter(Collections.singletonList(state.get)) + val listGroupsRequest = new ListGroupsRequest.Builder(data).build() + val requestChannelRequest = buildRequest(listGroupsRequest) + + val capturedResponse = expectNoThrottling(requestChannelRequest) + val expectedStates: Set[String] = if (state.isDefined) Set(state.get) else Set() + EasyMock.expect(groupCoordinator.handleListGroups(expectedStates)) + .andReturn((Errors.NONE, overviews)) + EasyMock.replay(groupCoordinator, clientRequestQuotaManager, requestChannel) + + createKafkaApis().handleListGroupsRequest(requestChannelRequest) + + val response = capturedResponse.getValue.asInstanceOf[ListGroupsResponse] + assertEquals(Errors.NONE.code, response.data.errorCode) + response + } + + @Test + def testDescribeClusterRequest(): Unit = { + val plaintextListener = ListenerName.forSecurityProtocol(SecurityProtocol.PLAINTEXT) + val brokers = Seq( + new UpdateMetadataBroker() + .setId(0) + .setRack("rack") + .setEndpoints(Seq( + new UpdateMetadataEndpoint() + .setHost("broker0") + .setPort(9092) + .setSecurityProtocol(SecurityProtocol.PLAINTEXT.id) + .setListener(plaintextListener.value) + ).asJava), + new UpdateMetadataBroker() + .setId(1) + .setRack("rack") + .setEndpoints(Seq( + new UpdateMetadataEndpoint() + .setHost("broker1") + .setPort(9092) + .setSecurityProtocol(SecurityProtocol.PLAINTEXT.id) + .setListener(plaintextListener.value)).asJava) + ) + val updateMetadataRequest = new UpdateMetadataRequest.Builder(ApiKeys.UPDATE_METADATA.latestVersion, 0, + 0, 0, Seq.empty[UpdateMetadataPartitionState].asJava, brokers.asJava, Collections.emptyMap()).build() + MetadataCacheTest.updateCache(metadataCache, updateMetadataRequest) + + val describeClusterRequest = new DescribeClusterRequest.Builder(new DescribeClusterRequestData() + .setIncludeClusterAuthorizedOperations(true)).build() + + val request = buildRequest(describeClusterRequest, plaintextListener) + val capturedResponse = expectNoThrottling(request) + + EasyMock.replay(clientRequestQuotaManager, requestChannel) + createKafkaApis().handleDescribeCluster(request) + + val describeClusterResponse = capturedResponse.getValue.asInstanceOf[DescribeClusterResponse] + + assertEquals(metadataCache.getControllerId.get, describeClusterResponse.data.controllerId) + assertEquals(clusterId, describeClusterResponse.data.clusterId) + assertEquals(8096, describeClusterResponse.data.clusterAuthorizedOperations) + assertEquals(metadataCache.getAliveBrokerNodes(plaintextListener).toSet, + describeClusterResponse.nodes.asScala.values.toSet) + } + + /** + * Return pair of listener names in the metadataCache: PLAINTEXT and LISTENER2 respectively. 
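+   * Broker 0 is registered with both listeners while broker 1 only exposes PLAINTEXT, which is
+   * what makes the listener configuration inconsistent across brokers.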
+ */ + private def updateMetadataCacheWithInconsistentListeners(): (ListenerName, ListenerName) = { + val plaintextListener = ListenerName.forSecurityProtocol(SecurityProtocol.PLAINTEXT) + val anotherListener = new ListenerName("LISTENER2") + val brokers = Seq( + new UpdateMetadataBroker() + .setId(0) + .setRack("rack") + .setEndpoints(Seq( + new UpdateMetadataEndpoint() + .setHost("broker0") + .setPort(9092) + .setSecurityProtocol(SecurityProtocol.PLAINTEXT.id) + .setListener(plaintextListener.value), + new UpdateMetadataEndpoint() + .setHost("broker0") + .setPort(9093) + .setSecurityProtocol(SecurityProtocol.PLAINTEXT.id) + .setListener(anotherListener.value) + ).asJava), + new UpdateMetadataBroker() + .setId(1) + .setRack("rack") + .setEndpoints(Seq( + new UpdateMetadataEndpoint() + .setHost("broker1") + .setPort(9092) + .setSecurityProtocol(SecurityProtocol.PLAINTEXT.id) + .setListener(plaintextListener.value)).asJava) + ) + val updateMetadataRequest = new UpdateMetadataRequest.Builder(ApiKeys.UPDATE_METADATA.latestVersion, 0, + 0, 0, Seq.empty[UpdateMetadataPartitionState].asJava, brokers.asJava, Collections.emptyMap()).build() + MetadataCacheTest.updateCache(metadataCache, updateMetadataRequest) + (plaintextListener, anotherListener) + } + + private def sendMetadataRequestWithInconsistentListeners(requestListener: ListenerName): MetadataResponse = { + val metadataRequest = MetadataRequest.Builder.allTopics.build() + val requestChannelRequest = buildRequest(metadataRequest, requestListener) + val capturedResponse = expectNoThrottling(requestChannelRequest) + EasyMock.replay(clientRequestQuotaManager, requestChannel) + + createKafkaApis().handleTopicMetadataRequest(requestChannelRequest) + + capturedResponse.getValue.asInstanceOf[MetadataResponse] + } + + private def testConsumerListOffsetLatest(isolationLevel: IsolationLevel): Unit = { + val tp = new TopicPartition("foo", 0) + val latestOffset = 15L + val currentLeaderEpoch = Optional.empty[Integer]() + + EasyMock.expect(replicaManager.fetchOffsetForTimestamp( + EasyMock.eq(tp), + EasyMock.eq(ListOffsetsRequest.LATEST_TIMESTAMP), + EasyMock.eq(Some(isolationLevel)), + EasyMock.eq(currentLeaderEpoch), + fetchOnlyFromLeader = EasyMock.eq(true)) + ).andReturn(Some(new TimestampAndOffset(ListOffsetsResponse.UNKNOWN_TIMESTAMP, latestOffset, currentLeaderEpoch))) + + val targetTimes = List(new ListOffsetsTopic() + .setName(tp.topic) + .setPartitions(List(new ListOffsetsPartition() + .setPartitionIndex(tp.partition) + .setTimestamp(ListOffsetsRequest.LATEST_TIMESTAMP)).asJava)).asJava + val listOffsetRequest = ListOffsetsRequest.Builder.forConsumer(true, isolationLevel, false) + .setTargetTimes(targetTimes).build() + val request = buildRequest(listOffsetRequest) + val capturedResponse = expectNoThrottling(request) + + EasyMock.replay(replicaManager, clientRequestQuotaManager, requestChannel) + createKafkaApis().handleListOffsetRequest(request) + + val response = capturedResponse.getValue.asInstanceOf[ListOffsetsResponse] + val partitionDataOptional = response.topics.asScala.find(_.name == tp.topic).get + .partitions.asScala.find(_.partitionIndex == tp.partition) + assertTrue(partitionDataOptional.isDefined) + + val partitionData = partitionDataOptional.get + assertEquals(Errors.NONE.code, partitionData.errorCode) + assertEquals(latestOffset, partitionData.offset) + assertEquals(ListOffsetsResponse.UNKNOWN_TIMESTAMP, partitionData.timestamp) + } + + private def createWriteTxnMarkersRequest(partitions: util.List[TopicPartition]) = { + val 
writeTxnMarkersRequest = new WriteTxnMarkersRequest.Builder(ApiKeys.WRITE_TXN_MARKERS.latestVersion(), + asList(new TxnMarkerEntry(1, 1.toShort, 0, TransactionResult.COMMIT, partitions))).build() + (writeTxnMarkersRequest, buildRequest(writeTxnMarkersRequest)) + } + + private def buildRequest(request: AbstractRequest, + listenerName: ListenerName = ListenerName.forSecurityProtocol(SecurityProtocol.PLAINTEXT), + fromPrivilegedListener: Boolean = false, + requestHeader: Option[RequestHeader] = None): RequestChannel.Request = { + val buffer = request.serializeWithHeader( + requestHeader.getOrElse(new RequestHeader(request.apiKey, request.version, clientId, 0))) + + // read the header from the buffer first so that the body can be read next from the Request constructor + val header = RequestHeader.parse(buffer) + val context = new RequestContext(header, "1", InetAddress.getLocalHost, KafkaPrincipal.ANONYMOUS, + listenerName, SecurityProtocol.PLAINTEXT, ClientInformation.EMPTY, fromPrivilegedListener, + Optional.of(kafkaPrincipalSerde)) + new RequestChannel.Request(processor = 1, context = context, startTimeNanos = 0, MemoryPool.NONE, buffer, + requestChannelMetrics, envelope = None) + } + + private def expectNoThrottling(request: RequestChannel.Request): Capture[AbstractResponse] = { + EasyMock.expect(clientRequestQuotaManager.maybeRecordAndGetThrottleTimeMs(EasyMock.anyObject[RequestChannel.Request](), + EasyMock.anyObject[Long])).andReturn(0) + + EasyMock.expect(clientRequestQuotaManager.throttle( + EasyMock.eq(request), + EasyMock.anyObject[ThrottleCallback](), + EasyMock.eq(0))) + + val capturedResponse = EasyMock.newCapture[AbstractResponse]() + EasyMock.expect(requestChannel.sendResponse( + EasyMock.eq(request), + EasyMock.capture(capturedResponse), + EasyMock.anyObject() + )) + + capturedResponse + } + + private def createBasicMetadataRequest(topic: String, + numPartitions: Int, + brokerEpoch: Long, + numBrokers: Int, + topicId: Uuid = Uuid.ZERO_UUID): UpdateMetadataRequest = { + val replicas = List(0.asInstanceOf[Integer]).asJava + + def createPartitionState(partition: Int) = new UpdateMetadataPartitionState() + .setTopicName(topic) + .setPartitionIndex(partition) + .setControllerEpoch(1) + .setLeader(0) + .setLeaderEpoch(1) + .setReplicas(replicas) + .setZkVersion(0) + .setIsr(replicas) + + val plaintextListener = ListenerName.forSecurityProtocol(SecurityProtocol.PLAINTEXT) + val partitionStates = (0 until numPartitions).map(createPartitionState) + val liveBrokers = (0 until numBrokers).map( + brokerId => createMetadataBroker(brokerId, plaintextListener)) + new UpdateMetadataRequest.Builder(ApiKeys.UPDATE_METADATA.latestVersion, 0, + 0, brokerEpoch, partitionStates.asJava, liveBrokers.asJava, Collections.singletonMap(topic, topicId)).build() + } + + private def setupBasicMetadataCache(topic: String, numPartitions: Int, numBrokers: Int, topicId: Uuid): Unit = { + val updateMetadataRequest = createBasicMetadataRequest(topic, numPartitions, 0, numBrokers, topicId) + MetadataCacheTest.updateCache(metadataCache, updateMetadataRequest) + } + + private def addTopicToMetadataCache(topic: String, numPartitions: Int, numBrokers: Int = 1, topicId: Uuid = Uuid.ZERO_UUID): Unit = { + val updateMetadataRequest = createBasicMetadataRequest(topic, numPartitions, 0, numBrokers, topicId) + MetadataCacheTest.updateCache(metadataCache, updateMetadataRequest) + } + + private def createMetadataBroker(brokerId: Int, + listener: ListenerName): UpdateMetadataBroker = { + new UpdateMetadataBroker() + 
.setId(brokerId) + .setRack("rack") + .setEndpoints(Seq(new UpdateMetadataEndpoint() + .setHost("broker" + brokerId) + .setPort(9092) + .setSecurityProtocol(SecurityProtocol.PLAINTEXT.id) + .setListener(listener.value)).asJava) + } + + @Test + def testAlterReplicaLogDirs(): Unit = { + val data = new AlterReplicaLogDirsRequestData() + val dir = new AlterReplicaLogDirsRequestData.AlterReplicaLogDir() + .setPath("/foo") + dir.topics().add(new AlterReplicaLogDirsRequestData.AlterReplicaLogDirTopic().setName("t0").setPartitions(asList(0, 1, 2))) + data.dirs().add(dir) + val alterReplicaLogDirsRequest = new AlterReplicaLogDirsRequest.Builder( + data + ).build() + val request = buildRequest(alterReplicaLogDirsRequest) + + EasyMock.reset(replicaManager, clientRequestQuotaManager, requestChannel) + + val capturedResponse = expectNoThrottling(request) + val t0p0 = new TopicPartition("t0", 0) + val t0p1 = new TopicPartition("t0", 1) + val t0p2 = new TopicPartition("t0", 2) + val partitionResults = Map( + t0p0 -> Errors.NONE, + t0p1 -> Errors.LOG_DIR_NOT_FOUND, + t0p2 -> Errors.INVALID_TOPIC_EXCEPTION) + EasyMock.expect(replicaManager.alterReplicaLogDirs(EasyMock.eq(Map( + t0p0 -> "/foo", + t0p1 -> "/foo", + t0p2 -> "/foo")))) + .andReturn(partitionResults) + EasyMock.replay(replicaManager, clientQuotaManager, clientRequestQuotaManager, requestChannel) + + createKafkaApis().handleAlterReplicaLogDirsRequest(request) + + val response = capturedResponse.getValue.asInstanceOf[AlterReplicaLogDirsResponse] + assertEquals(partitionResults, response.data.results.asScala.flatMap { tr => + tr.partitions().asScala.map { pr => + new TopicPartition(tr.topicName, pr.partitionIndex) -> Errors.forCode(pr.errorCode) + } + }.toMap) + assertEquals(Map(Errors.NONE -> 1, + Errors.LOG_DIR_NOT_FOUND -> 1, + Errors.INVALID_TOPIC_EXCEPTION -> 1).asJava, response.errorCounts) + } + + @Test + def testSizeOfThrottledPartitions(): Unit = { + val topicNames = new util.HashMap[Uuid, String] + val topicIds = new util.HashMap[String, Uuid]() + def fetchResponse(data: Map[TopicIdPartition, String]): FetchResponse = { + val responseData = new util.LinkedHashMap[TopicIdPartition, FetchResponseData.PartitionData]( + data.map { case (tp, raw) => + tp -> new FetchResponseData.PartitionData() + .setPartitionIndex(tp.topicPartition.partition) + .setHighWatermark(105) + .setLastStableOffset(105) + .setLogStartOffset(0) + .setRecords(MemoryRecords.withRecords(CompressionType.NONE, new SimpleRecord(100, raw.getBytes(StandardCharsets.UTF_8)))) + }.toMap.asJava) + + data.foreach{case (tp, _) => + topicIds.put(tp.topicPartition.topic, tp.topicId) + topicNames.put(tp.topicId, tp.topicPartition.topic) + } + FetchResponse.of(Errors.NONE, 100, 100, responseData) + } + + val throttledPartition = new TopicIdPartition(Uuid.randomUuid(), new TopicPartition("throttledData", 0)) + val throttledData = Map(throttledPartition -> "throttledData") + val expectedSize = FetchResponse.sizeOf(FetchResponseData.HIGHEST_SUPPORTED_VERSION, + fetchResponse(throttledData).responseData(topicNames, FetchResponseData.HIGHEST_SUPPORTED_VERSION).entrySet.asScala.map( entry => + (new TopicIdPartition(Uuid.ZERO_UUID, entry.getKey), entry.getValue)).toMap.asJava.entrySet.iterator) + + val response = fetchResponse(throttledData ++ Map(new TopicIdPartition(Uuid.randomUuid(), new TopicPartition("nonThrottledData", 0)) -> "nonThrottledData")) + + val quota = Mockito.mock(classOf[ReplicationQuotaManager]) + 
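+    // Only `throttledPartition` is reported as throttled below, so sizeOfThrottledPartitions
+    // should count just that partition's data and ignore the non-throttled one.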
Mockito.when(quota.isThrottled(ArgumentMatchers.any(classOf[TopicPartition]))) + .thenAnswer(invocation => throttledPartition.topicPartition == invocation.getArgument(0).asInstanceOf[TopicPartition]) + + assertEquals(expectedSize, KafkaApis.sizeOfThrottledPartitions(FetchResponseData.HIGHEST_SUPPORTED_VERSION, response, quota)) + } + + @Test + def testDescribeProducers(): Unit = { + val tp1 = new TopicPartition("foo", 0) + val tp2 = new TopicPartition("bar", 3) + val tp3 = new TopicPartition("baz", 1) + val tp4 = new TopicPartition("invalid;topic", 1) + + val authorizer: Authorizer = EasyMock.niceMock(classOf[Authorizer]) + val data = new DescribeProducersRequestData().setTopics(List( + new DescribeProducersRequestData.TopicRequest() + .setName(tp1.topic) + .setPartitionIndexes(List(Int.box(tp1.partition)).asJava), + new DescribeProducersRequestData.TopicRequest() + .setName(tp2.topic) + .setPartitionIndexes(List(Int.box(tp2.partition)).asJava), + new DescribeProducersRequestData.TopicRequest() + .setName(tp3.topic) + .setPartitionIndexes(List(Int.box(tp3.partition)).asJava), + new DescribeProducersRequestData.TopicRequest() + .setName(tp4.topic) + .setPartitionIndexes(List(Int.box(tp4.partition)).asJava) + ).asJava) + + def buildExpectedActions(topic: String): util.List[Action] = { + val pattern = new ResourcePattern(ResourceType.TOPIC, topic, PatternType.LITERAL) + val action = new Action(AclOperation.READ, pattern, 1, true, true) + Collections.singletonList(action) + } + + // Topic `foo` is authorized and present in the metadata + addTopicToMetadataCache(tp1.topic, 4) // We will only access the first topic + EasyMock.expect(authorizer.authorize(anyObject[RequestContext], EasyMock.eq(buildExpectedActions(tp1.topic)))) + .andReturn(Seq(AuthorizationResult.ALLOWED).asJava) + .once() + + // Topic `bar` is not authorized + EasyMock.expect(authorizer.authorize(anyObject[RequestContext], EasyMock.eq(buildExpectedActions(tp2.topic)))) + .andReturn(Seq(AuthorizationResult.DENIED).asJava) + .once() + + // Topic `baz` is authorized, but not present in the metadata + EasyMock.expect(authorizer.authorize(anyObject[RequestContext], EasyMock.eq(buildExpectedActions(tp3.topic)))) + .andReturn(Seq(AuthorizationResult.ALLOWED).asJava) + .once() + + EasyMock.expect(replicaManager.activeProducerState(tp1)) + .andReturn(new DescribeProducersResponseData.PartitionResponse() + .setErrorCode(Errors.NONE.code) + .setPartitionIndex(tp1.partition) + .setActiveProducers(List( + new DescribeProducersResponseData.ProducerState() + .setProducerId(12345L) + .setProducerEpoch(15) + .setLastSequence(100) + .setLastTimestamp(time.milliseconds()) + .setCurrentTxnStartOffset(-1) + .setCoordinatorEpoch(200) + ).asJava)) + + val describeProducersRequest = new DescribeProducersRequest.Builder(data).build() + val request = buildRequest(describeProducersRequest) + val capturedResponse = expectNoThrottling(request) + + EasyMock.replay(replicaManager, clientRequestQuotaManager, requestChannel, txnCoordinator, authorizer) + createKafkaApis(authorizer = Some(authorizer)).handleDescribeProducersRequest(request) + + val response = capturedResponse.getValue.asInstanceOf[DescribeProducersResponse] + assertEquals(Set("foo", "bar", "baz", "invalid;topic"), response.data.topics.asScala.map(_.name).toSet) + + def assertPartitionError( + topicPartition: TopicPartition, + error: Errors + ): DescribeProducersResponseData.PartitionResponse = { + val topicData = response.data.topics.asScala.find(_.name == topicPartition.topic).get + val 
partitionData = topicData.partitions.asScala.find(_.partitionIndex == topicPartition.partition).get + assertEquals(error, Errors.forCode(partitionData.errorCode)) + partitionData + } + + val fooPartition = assertPartitionError(tp1, Errors.NONE) + assertEquals(Errors.NONE, Errors.forCode(fooPartition.errorCode)) + assertEquals(1, fooPartition.activeProducers.size) + val fooProducer = fooPartition.activeProducers.get(0) + assertEquals(12345L, fooProducer.producerId) + assertEquals(15, fooProducer.producerEpoch) + assertEquals(100, fooProducer.lastSequence) + assertEquals(time.milliseconds(), fooProducer.lastTimestamp) + assertEquals(-1, fooProducer.currentTxnStartOffset) + assertEquals(200, fooProducer.coordinatorEpoch) + + assertPartitionError(tp2, Errors.TOPIC_AUTHORIZATION_FAILED) + assertPartitionError(tp3, Errors.UNKNOWN_TOPIC_OR_PARTITION) + assertPartitionError(tp4, Errors.INVALID_TOPIC_EXCEPTION) + } + + @Test + def testDescribeTransactions(): Unit = { + val authorizer: Authorizer = EasyMock.niceMock(classOf[Authorizer]) + val data = new DescribeTransactionsRequestData() + .setTransactionalIds(List("foo", "bar").asJava) + val describeTransactionsRequest = new DescribeTransactionsRequest.Builder(data).build() + val request = buildRequest(describeTransactionsRequest) + val capturedResponse = expectNoThrottling(request) + + def buildExpectedActions(transactionalId: String): util.List[Action] = { + val pattern = new ResourcePattern(ResourceType.TRANSACTIONAL_ID, transactionalId, PatternType.LITERAL) + val action = new Action(AclOperation.DESCRIBE, pattern, 1, true, true) + Collections.singletonList(action) + } + + EasyMock.expect(txnCoordinator.handleDescribeTransactions("foo")) + .andReturn(new DescribeTransactionsResponseData.TransactionState() + .setErrorCode(Errors.NONE.code) + .setTransactionalId("foo") + .setProducerId(12345L) + .setProducerEpoch(15) + .setTransactionStartTimeMs(time.milliseconds()) + .setTransactionState("CompleteCommit") + .setTransactionTimeoutMs(10000)) + + EasyMock.expect(authorizer.authorize(anyObject[RequestContext], EasyMock.eq(buildExpectedActions("foo")))) + .andReturn(Seq(AuthorizationResult.ALLOWED).asJava) + .once() + + EasyMock.expect(authorizer.authorize(anyObject[RequestContext], EasyMock.eq(buildExpectedActions("bar")))) + .andReturn(Seq(AuthorizationResult.DENIED).asJava) + .once() + + EasyMock.replay(replicaManager, clientRequestQuotaManager, requestChannel, txnCoordinator, authorizer) + createKafkaApis(authorizer = Some(authorizer)).handleDescribeTransactionsRequest(request) + + val response = capturedResponse.getValue.asInstanceOf[DescribeTransactionsResponse] + assertEquals(2, response.data.transactionStates.size) + + val fooState = response.data.transactionStates.asScala.find(_.transactionalId == "foo").get + assertEquals(Errors.NONE.code, fooState.errorCode) + assertEquals(12345L, fooState.producerId) + assertEquals(15, fooState.producerEpoch) + assertEquals(time.milliseconds(), fooState.transactionStartTimeMs) + assertEquals("CompleteCommit", fooState.transactionState) + assertEquals(10000, fooState.transactionTimeoutMs) + assertEquals(List.empty, fooState.topics.asScala.toList) + + val barState = response.data.transactionStates.asScala.find(_.transactionalId == "bar").get + assertEquals(Errors.TRANSACTIONAL_ID_AUTHORIZATION_FAILED.code, barState.errorCode) + } + + @Test + def testDescribeTransactionsFiltersUnauthorizedTopics(): Unit = { + val authorizer: Authorizer = EasyMock.niceMock(classOf[Authorizer]) + val transactionalId = "foo" + 
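+    // The transactional id itself is authorized below; only one of the two topics recorded in
+    // its transaction state is, so the other topic should be filtered from the response.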
val data = new DescribeTransactionsRequestData() + .setTransactionalIds(List(transactionalId).asJava) + val describeTransactionsRequest = new DescribeTransactionsRequest.Builder(data).build() + val request = buildRequest(describeTransactionsRequest) + val capturedResponse = expectNoThrottling(request) + + def expectDescribe( + resourceType: ResourceType, + transactionalId: String, + result: AuthorizationResult + ): Unit = { + val pattern = new ResourcePattern(resourceType, transactionalId, PatternType.LITERAL) + val action = new Action(AclOperation.DESCRIBE, pattern, 1, true, true) + val actions = Collections.singletonList(action) + + EasyMock.expect(authorizer.authorize(anyObject[RequestContext], EasyMock.eq(actions))) + .andReturn(Seq(result).asJava) + .once() + } + + // Principal is authorized to one of the two topics. The second topic should be + // filtered from the result. + expectDescribe(ResourceType.TRANSACTIONAL_ID, transactionalId, AuthorizationResult.ALLOWED) + expectDescribe(ResourceType.TOPIC, "foo", AuthorizationResult.ALLOWED) + expectDescribe(ResourceType.TOPIC, "bar", AuthorizationResult.DENIED) + + def mkTopicData( + topic: String, + partitions: Seq[Int] + ): DescribeTransactionsResponseData.TopicData = { + new DescribeTransactionsResponseData.TopicData() + .setTopic(topic) + .setPartitions(partitions.map(Int.box).asJava) + } + + val describeTransactionsResponse = new DescribeTransactionsResponseData.TransactionState() + .setErrorCode(Errors.NONE.code) + .setTransactionalId(transactionalId) + .setProducerId(12345L) + .setProducerEpoch(15) + .setTransactionStartTimeMs(time.milliseconds()) + .setTransactionState("Ongoing") + .setTransactionTimeoutMs(10000) + + describeTransactionsResponse.topics.add(mkTopicData(topic = "foo", Seq(1, 2))) + describeTransactionsResponse.topics.add(mkTopicData(topic = "bar", Seq(3, 4))) + + EasyMock.expect(txnCoordinator.handleDescribeTransactions("foo")) + .andReturn(describeTransactionsResponse) + + EasyMock.replay(replicaManager, clientRequestQuotaManager, requestChannel, txnCoordinator, authorizer) + createKafkaApis(authorizer = Some(authorizer)).handleDescribeTransactionsRequest(request) + + val response = capturedResponse.getValue.asInstanceOf[DescribeTransactionsResponse] + assertEquals(1, response.data.transactionStates.size) + + val fooState = response.data.transactionStates.asScala.find(_.transactionalId == "foo").get + assertEquals(Errors.NONE.code, fooState.errorCode) + assertEquals(12345L, fooState.producerId) + assertEquals(15, fooState.producerEpoch) + assertEquals(time.milliseconds(), fooState.transactionStartTimeMs) + assertEquals("Ongoing", fooState.transactionState) + assertEquals(10000, fooState.transactionTimeoutMs) + assertEquals(List(mkTopicData(topic = "foo", Seq(1, 2))), fooState.topics.asScala.toList) + } + + @Test + def testListTransactionsErrorResponse(): Unit = { + val data = new ListTransactionsRequestData() + val listTransactionsRequest = new ListTransactionsRequest.Builder(data).build() + val request = buildRequest(listTransactionsRequest) + val capturedResponse = expectNoThrottling(request) + + EasyMock.expect(txnCoordinator.handleListTransactions(Set.empty[Long], Set.empty[String])) + .andReturn(new ListTransactionsResponseData() + .setErrorCode(Errors.COORDINATOR_LOAD_IN_PROGRESS.code)) + + EasyMock.replay(replicaManager, clientRequestQuotaManager, requestChannel, txnCoordinator) + createKafkaApis().handleListTransactionsRequest(request) + + val response = 
capturedResponse.getValue.asInstanceOf[ListTransactionsResponse] + assertEquals(0, response.data.transactionStates.size) + assertEquals(Errors.COORDINATOR_LOAD_IN_PROGRESS, Errors.forCode(response.data.errorCode)) + } + + @Test + def testListTransactionsAuthorization(): Unit = { + val authorizer: Authorizer = EasyMock.niceMock(classOf[Authorizer]) + val data = new ListTransactionsRequestData() + val listTransactionsRequest = new ListTransactionsRequest.Builder(data).build() + val request = buildRequest(listTransactionsRequest) + val capturedResponse = expectNoThrottling(request) + + val transactionStates = new util.ArrayList[ListTransactionsResponseData.TransactionState]() + transactionStates.add(new ListTransactionsResponseData.TransactionState() + .setTransactionalId("foo") + .setProducerId(12345L) + .setTransactionState("Ongoing")) + transactionStates.add(new ListTransactionsResponseData.TransactionState() + .setTransactionalId("bar") + .setProducerId(98765) + .setTransactionState("PrepareAbort")) + + EasyMock.expect(txnCoordinator.handleListTransactions(Set.empty[Long], Set.empty[String])) + .andReturn(new ListTransactionsResponseData() + .setErrorCode(Errors.NONE.code) + .setTransactionStates(transactionStates)) + + def buildExpectedActions(transactionalId: String): util.List[Action] = { + val pattern = new ResourcePattern(ResourceType.TRANSACTIONAL_ID, transactionalId, PatternType.LITERAL) + val action = new Action(AclOperation.DESCRIBE, pattern, 1, true, true) + Collections.singletonList(action) + } + + EasyMock.expect(authorizer.authorize(anyObject[RequestContext], EasyMock.eq(buildExpectedActions("foo")))) + .andReturn(Seq(AuthorizationResult.ALLOWED).asJava) + .once() + + EasyMock.expect(authorizer.authorize(anyObject[RequestContext], EasyMock.eq(buildExpectedActions("bar")))) + .andReturn(Seq(AuthorizationResult.DENIED).asJava) + .once() + + EasyMock.replay(replicaManager, clientRequestQuotaManager, requestChannel, txnCoordinator, authorizer) + createKafkaApis(authorizer = Some(authorizer)).handleListTransactionsRequest(request) + + val response = capturedResponse.getValue.asInstanceOf[ListTransactionsResponse] + assertEquals(1, response.data.transactionStates.size()) + val transactionState = response.data.transactionStates.get(0) + assertEquals("foo", transactionState.transactionalId) + assertEquals(12345L, transactionState.producerId) + assertEquals("Ongoing", transactionState.transactionState) + } + + @Test + def testDeleteTopicsByIdAuthorization(): Unit = { + val authorizer: Authorizer = EasyMock.niceMock(classOf[Authorizer]) + val controllerContext: ControllerContext = EasyMock.mock(classOf[ControllerContext]) + + EasyMock.expect(clientControllerQuotaManager.newQuotaFor( + EasyMock.anyObject(classOf[RequestChannel.Request]), + EasyMock.anyShort() + )).andReturn(UnboundedControllerMutationQuota) + EasyMock.expect(controller.isActive).andReturn(true) + EasyMock.expect(controller.controllerContext).andStubReturn(controllerContext) + + // Try to delete three topics: + // 1. One without describe permission + // 2. One without delete permission + // 3. 
One which is authorized, but doesn't exist + + expectTopicAuthorization(authorizer, AclOperation.DESCRIBE, Map( + "foo" -> AuthorizationResult.DENIED, + "bar" -> AuthorizationResult.ALLOWED + )) + + expectTopicAuthorization(authorizer, AclOperation.DELETE, Map( + "foo" -> AuthorizationResult.DENIED, + "bar" -> AuthorizationResult.DENIED + )) + + val topicIdsMap = Map( + Uuid.randomUuid() -> Some("foo"), + Uuid.randomUuid() -> Some("bar"), + Uuid.randomUuid() -> None + ) + + topicIdsMap.foreach { case (topicId, topicNameOpt) => + EasyMock.expect(controllerContext.topicName(topicId)).andReturn(topicNameOpt) + } + + val topicDatas = topicIdsMap.keys.map { topicId => + new DeleteTopicsRequestData.DeleteTopicState().setTopicId(topicId) + }.toList + val deleteRequest = new DeleteTopicsRequest.Builder(new DeleteTopicsRequestData() + .setTopics(topicDatas.asJava)) + .build(ApiKeys.DELETE_TOPICS.latestVersion) + + val request = buildRequest(deleteRequest) + val capturedResponse = expectNoThrottling(request) + + EasyMock.replay(replicaManager, clientRequestQuotaManager, clientControllerQuotaManager, + requestChannel, txnCoordinator, controller, controllerContext, authorizer) + createKafkaApis(authorizer = Some(authorizer)).handleDeleteTopicsRequest(request) + + val deleteResponse = capturedResponse.getValue.asInstanceOf[DeleteTopicsResponse] + + topicIdsMap.foreach { case (topicId, nameOpt) => + val response = deleteResponse.data.responses.asScala.find(_.topicId == topicId).get + nameOpt match { + case Some("foo") => + assertNull(response.name) + assertEquals(Errors.TOPIC_AUTHORIZATION_FAILED, Errors.forCode(response.errorCode)) + case Some("bar") => + assertEquals("bar", response.name) + assertEquals(Errors.TOPIC_AUTHORIZATION_FAILED, Errors.forCode(response.errorCode)) + case None => + assertNull(response.name) + assertEquals(Errors.UNKNOWN_TOPIC_ID, Errors.forCode(response.errorCode)) + case _ => + fail("Unexpected topic id/name mapping") + } + } + } + + @ParameterizedTest + @ValueSource(booleans = Array(true, false)) + def testDeleteTopicsByNameAuthorization(usePrimitiveTopicNameArray: Boolean): Unit = { + val authorizer: Authorizer = EasyMock.niceMock(classOf[Authorizer]) + + EasyMock.expect(clientControllerQuotaManager.newQuotaFor( + EasyMock.anyObject(classOf[RequestChannel.Request]), + EasyMock.anyShort() + )).andReturn(UnboundedControllerMutationQuota) + EasyMock.expect(controller.isActive).andReturn(true) + + // Try to delete three topics: + // 1. One without describe permission + // 2. One without delete permission + // 3. 
One which is authorized, but doesn't exist + + expectTopicAuthorization(authorizer, AclOperation.DESCRIBE, Map( + "foo" -> AuthorizationResult.DENIED, + "bar" -> AuthorizationResult.ALLOWED, + "baz" -> AuthorizationResult.ALLOWED + )) + + expectTopicAuthorization(authorizer, AclOperation.DELETE, Map( + "foo" -> AuthorizationResult.DENIED, + "bar" -> AuthorizationResult.DENIED, + "baz" -> AuthorizationResult.ALLOWED + )) + + val deleteRequest = if (usePrimitiveTopicNameArray) { + new DeleteTopicsRequest.Builder(new DeleteTopicsRequestData() + .setTopicNames(List("foo", "bar", "baz").asJava)) + .build(5.toShort) + } else { + val topicDatas = List( + new DeleteTopicsRequestData.DeleteTopicState().setName("foo"), + new DeleteTopicsRequestData.DeleteTopicState().setName("bar"), + new DeleteTopicsRequestData.DeleteTopicState().setName("baz") + ) + new DeleteTopicsRequest.Builder(new DeleteTopicsRequestData() + .setTopics(topicDatas.asJava)) + .build(ApiKeys.DELETE_TOPICS.latestVersion) + } + + val request = buildRequest(deleteRequest) + val capturedResponse = expectNoThrottling(request) + + EasyMock.replay(replicaManager, clientRequestQuotaManager, clientControllerQuotaManager, + requestChannel, txnCoordinator, controller, authorizer) + createKafkaApis(authorizer = Some(authorizer)).handleDeleteTopicsRequest(request) + + val deleteResponse = capturedResponse.getValue.asInstanceOf[DeleteTopicsResponse] + + def lookupErrorCode(topic: String): Option[Errors] = { + Option(deleteResponse.data.responses().find(topic)) + .map(result => Errors.forCode(result.errorCode)) + } + + assertEquals(Some(Errors.TOPIC_AUTHORIZATION_FAILED), lookupErrorCode("foo")) + assertEquals(Some(Errors.TOPIC_AUTHORIZATION_FAILED), lookupErrorCode("bar")) + assertEquals(Some(Errors.UNKNOWN_TOPIC_OR_PARTITION), lookupErrorCode("baz")) + } + + def expectTopicAuthorization( + authorizer: Authorizer, + aclOperation: AclOperation, + topicResults: Map[String, AuthorizationResult] + ): Unit = { + val expectedActions = topicResults.keys.map { topic => + val pattern = new ResourcePattern(ResourceType.TOPIC, topic, PatternType.LITERAL) + topic -> new Action(aclOperation, pattern, 1, true, true) + }.toMap + + val actionsCapture: Capture[util.List[Action]] = EasyMock.newCapture() + EasyMock.expect(authorizer.authorize(anyObject[RequestContext], EasyMock.capture(actionsCapture))) + .andAnswer(() => { + actionsCapture.getValue.asScala.map { action => + val topic = action.resourcePattern.name + assertEquals(expectedActions(topic), action) + topicResults(topic) + }.asJava + }) + .once() + } + + private def createMockRequest(): RequestChannel.Request = { + val request: RequestChannel.Request = EasyMock.createNiceMock(classOf[RequestChannel.Request]) + val requestHeader: RequestHeader = EasyMock.createNiceMock(classOf[RequestHeader]) + expect(request.header).andReturn(requestHeader).anyTimes() + expect(requestHeader.apiKey()).andReturn(ApiKeys.values().head).anyTimes() + EasyMock.replay(request, requestHeader) + request + } + + private def verifyShouldNeverHandleErrorMessage(handler: RequestChannel.Request => Unit): Unit = { + val request = createMockRequest() + val e = assertThrows(classOf[UnsupportedVersionException], () => handler(request)) + assertEquals(KafkaApis.shouldNeverReceive(request).getMessage, e.getMessage) + } + + private def verifyShouldAlwaysForwardErrorMessage(handler: RequestChannel.Request => Unit): Unit = { + val request = createMockRequest() + val e = assertThrows(classOf[UnsupportedVersionException], () => 
handler(request)) + assertEquals(KafkaApis.shouldAlwaysForward(request).getMessage, e.getMessage) + } + + @Test + def testRaftShouldNeverHandleLeaderAndIsrRequest(): Unit = { + metadataCache = MetadataCache.kRaftMetadataCache(brokerId) + verifyShouldNeverHandleErrorMessage(createKafkaApis(raftSupport = true).handleLeaderAndIsrRequest) + } + + @Test + def testRaftShouldNeverHandleStopReplicaRequest(): Unit = { + metadataCache = MetadataCache.kRaftMetadataCache(brokerId) + verifyShouldNeverHandleErrorMessage(createKafkaApis(raftSupport = true).handleStopReplicaRequest) + } + + @Test + def testRaftShouldNeverHandleUpdateMetadataRequest(): Unit = { + metadataCache = MetadataCache.kRaftMetadataCache(brokerId) + verifyShouldNeverHandleErrorMessage(createKafkaApis(raftSupport = true).handleUpdateMetadataRequest(_, RequestLocal.withThreadConfinedCaching)) + } + + @Test + def testRaftShouldNeverHandleControlledShutdownRequest(): Unit = { + metadataCache = MetadataCache.kRaftMetadataCache(brokerId) + verifyShouldNeverHandleErrorMessage(createKafkaApis(raftSupport = true).handleControlledShutdownRequest) + } + + @Test + def testRaftShouldNeverHandleAlterIsrRequest(): Unit = { + metadataCache = MetadataCache.kRaftMetadataCache(brokerId) + verifyShouldNeverHandleErrorMessage(createKafkaApis(raftSupport = true).handleAlterIsrRequest) + } + + @Test + def testRaftShouldNeverHandleEnvelope(): Unit = { + metadataCache = MetadataCache.kRaftMetadataCache(brokerId) + verifyShouldNeverHandleErrorMessage(createKafkaApis(raftSupport = true).handleEnvelope(_, RequestLocal.withThreadConfinedCaching)) + } + + @Test + def testRaftShouldAlwaysForwardCreateTopicsRequest(): Unit = { + metadataCache = MetadataCache.kRaftMetadataCache(brokerId) + verifyShouldAlwaysForwardErrorMessage(createKafkaApis(raftSupport = true).handleCreateTopicsRequest) + } + + @Test + def testRaftShouldAlwaysForwardCreatePartitionsRequest(): Unit = { + metadataCache = MetadataCache.kRaftMetadataCache(brokerId) + verifyShouldAlwaysForwardErrorMessage(createKafkaApis(raftSupport = true).handleCreatePartitionsRequest) + } + + @Test + def testRaftShouldAlwaysForwardDeleteTopicsRequest(): Unit = { + metadataCache = MetadataCache.kRaftMetadataCache(brokerId) + verifyShouldAlwaysForwardErrorMessage(createKafkaApis(raftSupport = true).handleDeleteTopicsRequest) + } + + @Test + def testRaftShouldAlwaysForwardCreateAcls(): Unit = { + metadataCache = MetadataCache.kRaftMetadataCache(brokerId) + verifyShouldAlwaysForwardErrorMessage(createKafkaApis(raftSupport = true).handleCreateAcls) + } + + @Test + def testRaftShouldAlwaysForwardDeleteAcls(): Unit = { + metadataCache = MetadataCache.kRaftMetadataCache(brokerId) + verifyShouldAlwaysForwardErrorMessage(createKafkaApis(raftSupport = true).handleDeleteAcls) + } + + @Test + def testRaftShouldAlwaysForwardAlterConfigsRequest(): Unit = { + metadataCache = MetadataCache.kRaftMetadataCache(brokerId) + verifyShouldAlwaysForwardErrorMessage(createKafkaApis(raftSupport = true).handleAlterConfigsRequest) + } + + @Test + def testRaftShouldAlwaysForwardAlterPartitionReassignmentsRequest(): Unit = { + metadataCache = MetadataCache.kRaftMetadataCache(brokerId) + verifyShouldAlwaysForwardErrorMessage(createKafkaApis(raftSupport = true).handleAlterPartitionReassignmentsRequest) + } + + @Test + def testRaftShouldAlwaysForwardIncrementalAlterConfigsRequest(): Unit = { + metadataCache = MetadataCache.kRaftMetadataCache(brokerId) + verifyShouldAlwaysForwardErrorMessage(createKafkaApis(raftSupport = 
true).handleIncrementalAlterConfigsRequest) + } + + @Test + def testRaftShouldAlwaysForwardCreateTokenRequest(): Unit = { + metadataCache = MetadataCache.kRaftMetadataCache(brokerId) + verifyShouldAlwaysForwardErrorMessage(createKafkaApis(raftSupport = true).handleCreateTokenRequest) + } + + @Test + def testRaftShouldAlwaysForwardRenewTokenRequest(): Unit = { + metadataCache = MetadataCache.kRaftMetadataCache(brokerId) + verifyShouldAlwaysForwardErrorMessage(createKafkaApis(raftSupport = true).handleRenewTokenRequest) + } + + @Test + def testRaftShouldAlwaysForwardExpireTokenRequest(): Unit = { + metadataCache = MetadataCache.kRaftMetadataCache(brokerId) + verifyShouldAlwaysForwardErrorMessage(createKafkaApis(raftSupport = true).handleExpireTokenRequest) + } + + @Test + def testRaftShouldAlwaysForwardAlterClientQuotasRequest(): Unit = { + metadataCache = MetadataCache.kRaftMetadataCache(brokerId) + verifyShouldAlwaysForwardErrorMessage(createKafkaApis(raftSupport = true).handleAlterClientQuotasRequest) + } + + @Test + def testRaftShouldAlwaysForwardAlterUserScramCredentialsRequest(): Unit = { + metadataCache = MetadataCache.kRaftMetadataCache(brokerId) + verifyShouldAlwaysForwardErrorMessage(createKafkaApis(raftSupport = true).handleAlterUserScramCredentialsRequest) + } + + @Test + def testRaftShouldAlwaysForwardUpdateFeatures(): Unit = { + metadataCache = MetadataCache.kRaftMetadataCache(brokerId) + verifyShouldAlwaysForwardErrorMessage(createKafkaApis(raftSupport = true).handleUpdateFeatures) + } + + @Test + def testRaftShouldAlwaysForwardElectLeaders(): Unit = { + metadataCache = MetadataCache.kRaftMetadataCache(brokerId) + verifyShouldAlwaysForwardErrorMessage(createKafkaApis(raftSupport = true).handleElectLeaders) + } + + @Test + def testRaftShouldAlwaysForwardListPartitionReassignments(): Unit = { + metadataCache = MetadataCache.kRaftMetadataCache(brokerId) + verifyShouldAlwaysForwardErrorMessage(createKafkaApis(raftSupport = true).handleListPartitionReassignmentsRequest) + } +} diff --git a/core/src/test/scala/unit/kafka/server/KafkaConfigTest.scala b/core/src/test/scala/unit/kafka/server/KafkaConfigTest.scala new file mode 100755 index 0000000..2ce336e --- /dev/null +++ b/core/src/test/scala/unit/kafka/server/KafkaConfigTest.scala @@ -0,0 +1,1511 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package kafka.server + +import kafka.api.{ApiVersion, KAFKA_0_8_2, KAFKA_3_0_IV1} +import kafka.cluster.EndPoint +import kafka.log.LogConfig +import kafka.message._ +import kafka.utils.TestUtils.assertBadConfigContainingMessage +import kafka.utils.{CoreUtils, TestUtils} +import org.apache.kafka.common.config.{ConfigException, TopicConfig} +import org.apache.kafka.common.metrics.Sensor +import org.apache.kafka.common.network.ListenerName +import org.apache.kafka.common.record.Records +import org.apache.kafka.common.security.auth.SecurityProtocol +import org.apache.kafka.raft.RaftConfig +import org.apache.kafka.raft.RaftConfig.{AddressSpec, InetAddressSpec, UNKNOWN_ADDRESS_SPEC_INSTANCE} +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.Test + +import java.net.InetSocketAddress +import java.util +import java.util.{Collections, Properties} +import org.apache.kafka.common.Node +import org.apache.kafka.server.log.remote.storage.RemoteLogManagerConfig +import org.junit.jupiter.api.function.Executable + +import scala.annotation.nowarn +import scala.jdk.CollectionConverters._ + +class KafkaConfigTest { + + @Test + def testLogRetentionTimeHoursProvided(): Unit = { + val props = TestUtils.createBrokerConfig(0, TestUtils.MockZkConnect, port = 8181) + props.put(KafkaConfig.LogRetentionTimeHoursProp, "1") + + val cfg = KafkaConfig.fromProps(props) + assertEquals(60L * 60L * 1000L, cfg.logRetentionTimeMillis) + } + + @Test + def testLogRetentionTimeMinutesProvided(): Unit = { + val props = TestUtils.createBrokerConfig(0, TestUtils.MockZkConnect, port = 8181) + props.put(KafkaConfig.LogRetentionTimeMinutesProp, "30") + + val cfg = KafkaConfig.fromProps(props) + assertEquals(30 * 60L * 1000L, cfg.logRetentionTimeMillis) + } + + @Test + def testLogRetentionTimeMsProvided(): Unit = { + val props = TestUtils.createBrokerConfig(0, TestUtils.MockZkConnect, port = 8181) + props.put(KafkaConfig.LogRetentionTimeMillisProp, "1800000") + + val cfg = KafkaConfig.fromProps(props) + assertEquals(30 * 60L * 1000L, cfg.logRetentionTimeMillis) + } + + @Test + def testLogRetentionTimeNoConfigProvided(): Unit = { + val props = TestUtils.createBrokerConfig(0, TestUtils.MockZkConnect, port = 8181) + + val cfg = KafkaConfig.fromProps(props) + assertEquals(24 * 7 * 60L * 60L * 1000L, cfg.logRetentionTimeMillis) + } + + @Test + def testLogRetentionTimeBothMinutesAndHoursProvided(): Unit = { + val props = TestUtils.createBrokerConfig(0, TestUtils.MockZkConnect, port = 8181) + props.put(KafkaConfig.LogRetentionTimeMinutesProp, "30") + props.put(KafkaConfig.LogRetentionTimeHoursProp, "1") + + val cfg = KafkaConfig.fromProps(props) + assertEquals( 30 * 60L * 1000L, cfg.logRetentionTimeMillis) + } + + @Test + def testLogRetentionTimeBothMinutesAndMsProvided(): Unit = { + val props = TestUtils.createBrokerConfig(0, TestUtils.MockZkConnect, port = 8181) + props.put(KafkaConfig.LogRetentionTimeMillisProp, "1800000") + props.put(KafkaConfig.LogRetentionTimeMinutesProp, "10") + + val cfg = KafkaConfig.fromProps(props) + assertEquals( 30 * 60L * 1000L, cfg.logRetentionTimeMillis) + } + + @Test + def testLogRetentionUnlimited(): Unit = { + val props1 = TestUtils.createBrokerConfig(0,TestUtils.MockZkConnect, port = 8181) + val props2 = TestUtils.createBrokerConfig(0,TestUtils.MockZkConnect, port = 8181) + val props3 = TestUtils.createBrokerConfig(0,TestUtils.MockZkConnect, port = 8181) + val props4 = TestUtils.createBrokerConfig(0,TestUtils.MockZkConnect, port = 8181) + val props5 = 
TestUtils.createBrokerConfig(0,TestUtils.MockZkConnect, port = 8181) + + props1.put("log.retention.ms", "-1") + props2.put("log.retention.minutes", "-1") + props3.put("log.retention.hours", "-1") + + val cfg1 = KafkaConfig.fromProps(props1) + val cfg2 = KafkaConfig.fromProps(props2) + val cfg3 = KafkaConfig.fromProps(props3) + assertEquals(-1, cfg1.logRetentionTimeMillis, "Should be -1") + assertEquals(-1, cfg2.logRetentionTimeMillis, "Should be -1") + assertEquals(-1, cfg3.logRetentionTimeMillis, "Should be -1") + + props4.put("log.retention.ms", "-1") + props4.put("log.retention.minutes", "30") + + val cfg4 = KafkaConfig.fromProps(props4) + assertEquals(-1, cfg4.logRetentionTimeMillis, "Should be -1") + + props5.put("log.retention.ms", "0") + + assertThrows(classOf[IllegalArgumentException], () => KafkaConfig.fromProps(props5)) + } + + @Test + def testLogRetentionValid(): Unit = { + val props1 = TestUtils.createBrokerConfig(0, TestUtils.MockZkConnect, port = 8181) + val props2 = TestUtils.createBrokerConfig(0, TestUtils.MockZkConnect, port = 8181) + val props3 = TestUtils.createBrokerConfig(0, TestUtils.MockZkConnect, port = 8181) + + props1.put("log.retention.ms", "0") + props2.put("log.retention.minutes", "0") + props3.put("log.retention.hours", "0") + + assertThrows(classOf[IllegalArgumentException], () => KafkaConfig.fromProps(props1)) + assertThrows(classOf[IllegalArgumentException], () => KafkaConfig.fromProps(props2)) + assertThrows(classOf[IllegalArgumentException], () => KafkaConfig.fromProps(props3)) + + } + + @Test + def testAdvertiseDefaults(): Unit = { + val port = 9999 + val hostName = "fake-host" + val props = new Properties() + props.put(KafkaConfig.BrokerIdProp, "1") + props.put(KafkaConfig.ZkConnectProp, "localhost:2181") + props.put(KafkaConfig.ListenersProp, s"PLAINTEXT://$hostName:$port") + val serverConfig = KafkaConfig.fromProps(props) + + val endpoints = serverConfig.effectiveAdvertisedListeners + assertEquals(1, endpoints.size) + val endpoint = endpoints.find(_.securityProtocol == SecurityProtocol.PLAINTEXT).get + assertEquals(endpoint.host, hostName) + assertEquals(endpoint.port, port) + } + + @Test + def testAdvertiseConfigured(): Unit = { + val advertisedHostName = "routable-host" + val advertisedPort = 1234 + + val props = TestUtils.createBrokerConfig(0, TestUtils.MockZkConnect) + props.put(KafkaConfig.AdvertisedListenersProp, s"PLAINTEXT://$advertisedHostName:$advertisedPort") + + val serverConfig = KafkaConfig.fromProps(props) + val endpoints = serverConfig.effectiveAdvertisedListeners + val endpoint = endpoints.find(_.securityProtocol == SecurityProtocol.PLAINTEXT).get + + assertEquals(endpoint.host, advertisedHostName) + assertEquals(endpoint.port, advertisedPort) + } + + @Test + def testDuplicateListeners(): Unit = { + val props = new Properties() + props.put(KafkaConfig.BrokerIdProp, "1") + props.put(KafkaConfig.ZkConnectProp, "localhost:2181") + + // listeners with duplicate port + props.put(KafkaConfig.ListenersProp, "PLAINTEXT://localhost:9091,SSL://localhost:9091") + assertBadConfigContainingMessage(props, "Each listener must have a different port") + + // listeners with duplicate name + props.put(KafkaConfig.ListenersProp, "PLAINTEXT://localhost:9091,PLAINTEXT://localhost:9092") + assertBadConfigContainingMessage(props, "Each listener must have a different name") + + // advertised listeners can have duplicate ports + props.put(KafkaConfig.ListenerSecurityProtocolMapProp, "HOST:SASL_SSL,LB:SASL_SSL") + 
props.put(KafkaConfig.InterBrokerListenerNameProp, "HOST") + props.put(KafkaConfig.ListenersProp, "HOST://localhost:9091,LB://localhost:9092") + props.put(KafkaConfig.AdvertisedListenersProp, "HOST://localhost:9091,LB://localhost:9091") + KafkaConfig.fromProps(props) + + // but not duplicate names + props.put(KafkaConfig.AdvertisedListenersProp, "HOST://localhost:9091,HOST://localhost:9091") + assertBadConfigContainingMessage(props, "Each listener must have a different name") + } + + @Test + def testControlPlaneListenerName(): Unit = { + val props = TestUtils.createBrokerConfig(0, TestUtils.MockZkConnect) + props.put("listeners", "PLAINTEXT://localhost:0,CONTROLLER://localhost:5000") + props.put("listener.security.protocol.map", "PLAINTEXT:PLAINTEXT,CONTROLLER:SSL") + props.put("control.plane.listener.name", "CONTROLLER") + KafkaConfig.fromProps(props) + + val serverConfig = KafkaConfig.fromProps(props) + val controlEndpoint = serverConfig.controlPlaneListener.get + assertEquals("localhost", controlEndpoint.host) + assertEquals(5000, controlEndpoint.port) + assertEquals(SecurityProtocol.SSL, controlEndpoint.securityProtocol) + + //advertised listener should contain control-plane listener + val advertisedEndpoints = serverConfig.effectiveAdvertisedListeners + assertTrue(advertisedEndpoints.exists { endpoint => + endpoint.securityProtocol == controlEndpoint.securityProtocol && endpoint.listenerName.value().equals(controlEndpoint.listenerName.value()) + }) + + // interBrokerListener name should be different from control-plane listener name + val interBrokerListenerName = serverConfig.interBrokerListenerName + assertFalse(interBrokerListenerName.value().equals(controlEndpoint.listenerName.value())) + } + + @Test + def testControllerListenerNames(): Unit = { + val props = new Properties() + props.put(KafkaConfig.ProcessRolesProp, "broker,controller") + props.put(KafkaConfig.ListenersProp, "PLAINTEXT://localhost:0,CONTROLLER://localhost:5000") + props.put(KafkaConfig.ControllerListenerNamesProp, "CONTROLLER") + props.put(KafkaConfig.NodeIdProp, "2") + props.put(KafkaConfig.QuorumVotersProp, "2@localhost:5000") + props.put(KafkaConfig.ListenerSecurityProtocolMapProp, "PLAINTEXT:PLAINTEXT,CONTROLLER:SASL_SSL") + + val serverConfig = KafkaConfig.fromProps(props) + val controllerEndpoints = serverConfig.controllerListeners + assertEquals(1, controllerEndpoints.size) + val controllerEndpoint = controllerEndpoints.iterator.next() + assertEquals("localhost", controllerEndpoint.host) + assertEquals(5000, controllerEndpoint.port) + assertEquals(SecurityProtocol.SASL_SSL, controllerEndpoint.securityProtocol) + } + + @Test + def testControlPlaneListenerNameNotAllowedWithKRaft(): Unit = { + val props = new Properties() + props.put(KafkaConfig.ProcessRolesProp, "broker,controller") + props.put(KafkaConfig.ListenersProp, "PLAINTEXT://localhost:9092,SSL://localhost:9093") + props.put(KafkaConfig.ControllerListenerNamesProp, "SSL") + props.put(KafkaConfig.NodeIdProp, "2") + props.put(KafkaConfig.QuorumVotersProp, "2@localhost:9093") + props.put(KafkaConfig.ControlPlaneListenerNameProp, "SSL") + + assertFalse(isValidKafkaConfig(props)) + assertBadConfigContainingMessage(props, "control.plane.listener.name is not supported in KRaft mode.") + + props.remove(KafkaConfig.ControlPlaneListenerNameProp) + KafkaConfig.fromProps(props) + } + + @Test + def testControllerListenerDefinedForKRaftController(): Unit = { + val props = new Properties() + props.put(KafkaConfig.ProcessRolesProp, "controller") + 
props.put(KafkaConfig.ListenersProp, "SSL://localhost:9093") + props.put(KafkaConfig.NodeIdProp, "2") + props.put(KafkaConfig.QuorumVotersProp, "2@localhost:9093") + + assertBadConfigContainingMessage(props, "The listeners config must only contain KRaft controller listeners from controller.listener.names when process.roles=controller") + + props.put(KafkaConfig.ControllerListenerNamesProp, "SSL") + KafkaConfig.fromProps(props) + + // confirm that redirecting via listener.security.protocol.map is acceptable + props.put(KafkaConfig.ListenerSecurityProtocolMapProp, "CONTROLLER:SSL") + props.put(KafkaConfig.ListenersProp, "CONTROLLER://localhost:9093") + props.put(KafkaConfig.ControllerListenerNamesProp, "CONTROLLER") + KafkaConfig.fromProps(props) + } + + @Test + def testControllerListenerDefinedForKRaftBroker(): Unit = { + val props = new Properties() + props.put(KafkaConfig.ProcessRolesProp, "broker") + props.put(KafkaConfig.NodeIdProp, "1") + props.put(KafkaConfig.QuorumVotersProp, "2@localhost:9093") + + assertFalse(isValidKafkaConfig(props)) + assertBadConfigContainingMessage(props, "controller.listener.names must contain at least one value when running KRaft with just the broker role") + + props.put(KafkaConfig.ControllerListenerNamesProp, "SSL") + KafkaConfig.fromProps(props) + + // confirm that redirecting via listener.security.protocol.map is acceptable + props.put(KafkaConfig.ListenerSecurityProtocolMapProp, "PLAINTEXT:PLAINTEXT,CONTROLLER:SSL") + props.put(KafkaConfig.ControllerListenerNamesProp, "CONTROLLER") + KafkaConfig.fromProps(props) + } + + @Test + def testPortInQuorumVotersNotRequiredToMatchFirstControllerListenerPortForThisKRaftController(): Unit = { + val props = new Properties() + props.put(KafkaConfig.ProcessRolesProp, "controller,broker") + props.put(KafkaConfig.ListenersProp, "PLAINTEXT://localhost:9092,SSL://localhost:9093,SASL_SSL://localhost:9094") + props.put(KafkaConfig.NodeIdProp, "2") + props.put(KafkaConfig.QuorumVotersProp, "2@localhost:9093,3@anotherhost:9094") + props.put(KafkaConfig.ControllerListenerNamesProp, "SSL,SASL_SSL") + KafkaConfig.fromProps(props) + + // change each of the 4 ports to port 5555 -- should pass in all circumstances since we can't validate the + // controller.quorum.voters ports (which are the ports that clients use and are semantically "advertised" ports + // even though the controller configuration doesn't list them in advertised.listeners) against the + // listener ports (which are semantically different than the ports that clients use). 
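+ // (each controller.quorum.voters entry has the form nodeId@host:port, e.g. 2@localhost:9093) + // the four variants below substitute port 5555 into the SSL listener, the SASL_SSL listener, and then each of the two voter entries in turn; KafkaConfig.fromProps must accept every variant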
+ props.put(KafkaConfig.ListenersProp, "PLAINTEXT://localhost:9092,SSL://localhost:5555,SASL_SSL://localhost:9094") + KafkaConfig.fromProps(props) + props.put(KafkaConfig.ListenersProp, "PLAINTEXT://localhost:9092,SSL://localhost:9093,SASL_SSL://localhost:5555") + KafkaConfig.fromProps(props) + props.put(KafkaConfig.ListenersProp, "PLAINTEXT://localhost:9092,SSL://localhost:9093,SASL_SSL://localhost:9094") // reset to original value + props.put(KafkaConfig.QuorumVotersProp, "2@localhost:5555,3@anotherhost:9094") + KafkaConfig.fromProps(props) + props.put(KafkaConfig.QuorumVotersProp, "2@localhost:9093,3@anotherhost:5555") + KafkaConfig.fromProps(props) + } + + @Test + def testSeparateControllerListenerDefinedForKRaftBrokerController(): Unit = { + val props = new Properties() + props.put(KafkaConfig.ProcessRolesProp, "broker,controller") + props.put(KafkaConfig.ListenersProp, "SSL://localhost:9093") + props.put(KafkaConfig.ControllerListenerNamesProp, "SSL") + props.put(KafkaConfig.NodeIdProp, "2") + props.put(KafkaConfig.QuorumVotersProp, "2@localhost:9093") + + assertFalse(isValidKafkaConfig(props)) + assertBadConfigContainingMessage(props, "There must be at least one advertised listener. Perhaps all listeners appear in controller.listener.names?") + + props.put(KafkaConfig.ListenersProp, "PLAINTEXT://localhost:9092,SSL://localhost:9093") + KafkaConfig.fromProps(props) + + // confirm that redirecting via listener.security.protocol.map is acceptable + props.put(KafkaConfig.ListenersProp, "PLAINTEXT://localhost:9092,CONTROLLER://localhost:9093") + props.put(KafkaConfig.ListenerSecurityProtocolMapProp, "PLAINTEXT:PLAINTEXT,CONTROLLER:SSL") + props.put(KafkaConfig.ControllerListenerNamesProp, "CONTROLLER") + KafkaConfig.fromProps(props) + } + + @Test + def testControllerListenerNameMapsToPlaintextByDefaultForKRaft(): Unit = { + val props = new Properties() + props.put(KafkaConfig.ProcessRolesProp, "broker") + props.put(KafkaConfig.ControllerListenerNamesProp, "CONTROLLER") + props.put(KafkaConfig.NodeIdProp, "1") + props.put(KafkaConfig.QuorumVotersProp, "2@localhost:9093") + val controllerListenerName = new ListenerName("CONTROLLER") + assertEquals(Some(SecurityProtocol.PLAINTEXT), + KafkaConfig.fromProps(props).effectiveListenerSecurityProtocolMap.get(controllerListenerName)) + // ensure we don't map it to PLAINTEXT when there is a SSL or SASL controller listener + props.put(KafkaConfig.ControllerListenerNamesProp, "CONTROLLER,SSL") + val controllerNotFoundInMapMessage = "Controller listener with name CONTROLLER defined in controller.listener.names not found in listener.security.protocol.map" + assertBadConfigContainingMessage(props, controllerNotFoundInMapMessage) + // ensure we don't map it to PLAINTEXT when there is a SSL or SASL listener + props.put(KafkaConfig.ControllerListenerNamesProp, "CONTROLLER") + props.put(KafkaConfig.ListenersProp, "SSL://localhost:9092") + assertBadConfigContainingMessage(props, controllerNotFoundInMapMessage) + props.remove(KafkaConfig.ListenersProp) + // ensure we don't map it to PLAINTEXT when it is explicitly mapped otherwise + props.put(KafkaConfig.ListenerSecurityProtocolMapProp, "PLAINTEXT:PLAINTEXT,CONTROLLER:SSL") + assertEquals(Some(SecurityProtocol.SSL), + KafkaConfig.fromProps(props).effectiveListenerSecurityProtocolMap.get(controllerListenerName)) + // ensure we don't map it to PLAINTEXT when anything is explicitly given + // (i.e. 
it is only part of the default value, even with KRaft) + props.put(KafkaConfig.ListenerSecurityProtocolMapProp, "PLAINTEXT:PLAINTEXT") + assertBadConfigContainingMessage(props, controllerNotFoundInMapMessage) + // ensure we can map it to a non-PLAINTEXT security protocol by default (i.e. when nothing is given) + props.remove(KafkaConfig.ListenerSecurityProtocolMapProp) + props.put(KafkaConfig.ControllerListenerNamesProp, "SSL") + assertEquals(Some(SecurityProtocol.SSL), + KafkaConfig.fromProps(props).effectiveListenerSecurityProtocolMap.get(new ListenerName("SSL"))) + } + + @Test + def testMultipleControllerListenerNamesMapToPlaintextByDefaultForKRaft(): Unit = { + val props = new Properties() + props.put(KafkaConfig.ProcessRolesProp, "controller") + props.put(KafkaConfig.ListenersProp, "CONTROLLER1://localhost:9092,CONTROLLER2://localhost:9093") + props.put(KafkaConfig.ControllerListenerNamesProp, "CONTROLLER1,CONTROLLER2") + props.put(KafkaConfig.NodeIdProp, "1") + props.put(KafkaConfig.QuorumVotersProp, "1@localhost:9092") + assertEquals(Some(SecurityProtocol.PLAINTEXT), + KafkaConfig.fromProps(props).effectiveListenerSecurityProtocolMap.get(new ListenerName("CONTROLLER1"))) + assertEquals(Some(SecurityProtocol.PLAINTEXT), + KafkaConfig.fromProps(props).effectiveListenerSecurityProtocolMap.get(new ListenerName("CONTROLLER2"))) + } + + @Test + def testControllerListenerNameDoesNotMapToPlaintextByDefaultForNonKRaft(): Unit = { + val props = new Properties() + props.put(KafkaConfig.BrokerIdProp, "1") + props.put(KafkaConfig.ZkConnectProp, "localhost:2181") + props.put(KafkaConfig.ListenersProp, "CONTROLLER://localhost:9092") + assertBadConfigContainingMessage(props, + "Error creating broker listeners from 'CONTROLLER://localhost:9092': No security protocol defined for listener CONTROLLER") + // Valid now + props.put(KafkaConfig.ListenersProp, "PLAINTEXT://localhost:9092") + assertEquals(None, KafkaConfig.fromProps(props).effectiveListenerSecurityProtocolMap.get(new ListenerName("CONTROLLER"))) + } + + @Test + def testBadListenerProtocol(): Unit = { + val props = new Properties() + props.put(KafkaConfig.BrokerIdProp, "1") + props.put(KafkaConfig.ZkConnectProp, "localhost:2181") + props.put(KafkaConfig.ListenersProp, "BAD://localhost:9091") + + assertFalse(isValidKafkaConfig(props)) + } + + @Test + def testListenerNamesWithAdvertisedListenerUnset(): Unit = { + val props = new Properties() + props.put(KafkaConfig.BrokerIdProp, "1") + props.put(KafkaConfig.ZkConnectProp, "localhost:2181") + + props.put(KafkaConfig.ListenersProp, "CLIENT://localhost:9091,REPLICATION://localhost:9092,INTERNAL://localhost:9093") + props.put(KafkaConfig.ListenerSecurityProtocolMapProp, "CLIENT:SSL,REPLICATION:SSL,INTERNAL:PLAINTEXT") + props.put(KafkaConfig.InterBrokerListenerNameProp, "REPLICATION") + val config = KafkaConfig.fromProps(props) + val expectedListeners = Seq( + EndPoint("localhost", 9091, new ListenerName("CLIENT"), SecurityProtocol.SSL), + EndPoint("localhost", 9092, new ListenerName("REPLICATION"), SecurityProtocol.SSL), + EndPoint("localhost", 9093, new ListenerName("INTERNAL"), SecurityProtocol.PLAINTEXT)) + assertEquals(expectedListeners, config.listeners) + assertEquals(expectedListeners, config.effectiveAdvertisedListeners) + val expectedSecurityProtocolMap = Map( + new ListenerName("CLIENT") -> SecurityProtocol.SSL, + new ListenerName("REPLICATION") -> SecurityProtocol.SSL, + new ListenerName("INTERNAL") -> SecurityProtocol.PLAINTEXT + ) + assertEquals(expectedSecurityProtocolMap, 
config.effectiveListenerSecurityProtocolMap) + } + + @Test + def testListenerAndAdvertisedListenerNames(): Unit = { + val props = new Properties() + props.put(KafkaConfig.BrokerIdProp, "1") + props.put(KafkaConfig.ZkConnectProp, "localhost:2181") + + props.put(KafkaConfig.ListenersProp, "EXTERNAL://localhost:9091,INTERNAL://localhost:9093") + props.put(KafkaConfig.AdvertisedListenersProp, "EXTERNAL://lb1.example.com:9000,INTERNAL://host1:9093") + props.put(KafkaConfig.ListenerSecurityProtocolMapProp, "EXTERNAL:SSL,INTERNAL:PLAINTEXT") + props.put(KafkaConfig.InterBrokerListenerNameProp, "INTERNAL") + val config = KafkaConfig.fromProps(props) + + val expectedListeners = Seq( + EndPoint("localhost", 9091, new ListenerName("EXTERNAL"), SecurityProtocol.SSL), + EndPoint("localhost", 9093, new ListenerName("INTERNAL"), SecurityProtocol.PLAINTEXT) + ) + assertEquals(expectedListeners, config.listeners) + + val expectedAdvertisedListeners = Seq( + EndPoint("lb1.example.com", 9000, new ListenerName("EXTERNAL"), SecurityProtocol.SSL), + EndPoint("host1", 9093, new ListenerName("INTERNAL"), SecurityProtocol.PLAINTEXT) + ) + assertEquals(expectedAdvertisedListeners, config.effectiveAdvertisedListeners) + + val expectedSecurityProtocolMap = Map( + new ListenerName("EXTERNAL") -> SecurityProtocol.SSL, + new ListenerName("INTERNAL") -> SecurityProtocol.PLAINTEXT + ) + assertEquals(expectedSecurityProtocolMap, config.effectiveListenerSecurityProtocolMap) + } + + @Test + def testListenerNameMissingFromListenerSecurityProtocolMap(): Unit = { + val props = new Properties() + props.put(KafkaConfig.BrokerIdProp, "1") + props.put(KafkaConfig.ZkConnectProp, "localhost:2181") + + props.put(KafkaConfig.ListenersProp, "SSL://localhost:9091,REPLICATION://localhost:9092") + props.put(KafkaConfig.InterBrokerListenerNameProp, "SSL") + assertFalse(isValidKafkaConfig(props)) + } + + @Test + def testInterBrokerListenerNameMissingFromListenerSecurityProtocolMap(): Unit = { + val props = new Properties() + props.put(KafkaConfig.BrokerIdProp, "1") + props.put(KafkaConfig.ZkConnectProp, "localhost:2181") + + props.put(KafkaConfig.ListenersProp, "SSL://localhost:9091") + props.put(KafkaConfig.InterBrokerListenerNameProp, "REPLICATION") + assertFalse(isValidKafkaConfig(props)) + } + + @Test + def testInterBrokerListenerNameAndSecurityProtocolSet(): Unit = { + val props = new Properties() + props.put(KafkaConfig.BrokerIdProp, "1") + props.put(KafkaConfig.ZkConnectProp, "localhost:2181") + + props.put(KafkaConfig.ListenersProp, "SSL://localhost:9091") + props.put(KafkaConfig.InterBrokerListenerNameProp, "SSL") + props.put(KafkaConfig.InterBrokerSecurityProtocolProp, "SSL") + assertFalse(isValidKafkaConfig(props)) + } + + @Test + def testCaseInsensitiveListenerProtocol(): Unit = { + val props = new Properties() + props.put(KafkaConfig.BrokerIdProp, "1") + props.put(KafkaConfig.ZkConnectProp, "localhost:2181") + props.put(KafkaConfig.ListenersProp, "plaintext://localhost:9091,SsL://localhost:9092") + val config = KafkaConfig.fromProps(props) + assertEquals(Some("SSL://localhost:9092"), config.listeners.find(_.listenerName.value == "SSL").map(_.connectionString)) + assertEquals(Some("PLAINTEXT://localhost:9091"), config.listeners.find(_.listenerName.value == "PLAINTEXT").map(_.connectionString)) + } + + private def listenerListToEndPoints(listenerList: String, + securityProtocolMap: collection.Map[ListenerName, SecurityProtocol] = EndPoint.DefaultSecurityProtocolMap) = + CoreUtils.listenerListToEndPoints(listenerList, 
securityProtocolMap) + + @Test + def testListenerDefaults(): Unit = { + val props = new Properties() + props.put(KafkaConfig.BrokerIdProp, "1") + props.put(KafkaConfig.ZkConnectProp, "localhost:2181") + + // configuration with no listeners + val conf = KafkaConfig.fromProps(props) + assertEquals(listenerListToEndPoints("PLAINTEXT://:9092"), conf.listeners) + assertNull(conf.listeners.find(_.securityProtocol == SecurityProtocol.PLAINTEXT).get.host) + assertEquals(conf.effectiveAdvertisedListeners, listenerListToEndPoints("PLAINTEXT://:9092")) + } + + @nowarn("cat=deprecation") + @Test + def testVersionConfiguration(): Unit = { + val props = new Properties() + props.put(KafkaConfig.BrokerIdProp, "1") + props.put(KafkaConfig.ZkConnectProp, "localhost:2181") + val conf = KafkaConfig.fromProps(props) + assertEquals(ApiVersion.latestVersion, conf.interBrokerProtocolVersion) + + props.put(KafkaConfig.InterBrokerProtocolVersionProp, "0.8.2.0") + // We need to set the message format version to make the configuration valid. + props.put(KafkaConfig.LogMessageFormatVersionProp, "0.8.2.0") + val conf2 = KafkaConfig.fromProps(props) + assertEquals(KAFKA_0_8_2, conf2.interBrokerProtocolVersion) + + // check that 0.8.2.0 is the same as 0.8.2.1 + props.put(KafkaConfig.InterBrokerProtocolVersionProp, "0.8.2.1") + // We need to set the message format version to make the configuration valid + props.put(KafkaConfig.LogMessageFormatVersionProp, "0.8.2.1") + val conf3 = KafkaConfig.fromProps(props) + assertEquals(KAFKA_0_8_2, conf3.interBrokerProtocolVersion) + + //check that latest is newer than 0.8.2 + assertTrue(ApiVersion.latestVersion >= conf3.interBrokerProtocolVersion) + } + + private def isValidKafkaConfig(props: Properties): Boolean = { + try { + KafkaConfig.fromProps(props) + true + } catch { + case _: IllegalArgumentException | _: ConfigException => false + } + } + + @Test + def testUncleanLeaderElectionDefault(): Unit = { + val props = TestUtils.createBrokerConfig(0, TestUtils.MockZkConnect, port = 8181) + val serverConfig = KafkaConfig.fromProps(props) + + assertEquals(serverConfig.uncleanLeaderElectionEnable, false) + } + + @Test + def testUncleanElectionDisabled(): Unit = { + val props = TestUtils.createBrokerConfig(0, TestUtils.MockZkConnect, port = 8181) + props.put(KafkaConfig.UncleanLeaderElectionEnableProp, String.valueOf(false)) + val serverConfig = KafkaConfig.fromProps(props) + + assertEquals(serverConfig.uncleanLeaderElectionEnable, false) + } + + @Test + def testUncleanElectionEnabled(): Unit = { + val props = TestUtils.createBrokerConfig(0, TestUtils.MockZkConnect, port = 8181) + props.put(KafkaConfig.UncleanLeaderElectionEnableProp, String.valueOf(true)) + val serverConfig = KafkaConfig.fromProps(props) + + assertEquals(serverConfig.uncleanLeaderElectionEnable, true) + } + + @Test + def testUncleanElectionInvalid(): Unit = { + val props = TestUtils.createBrokerConfig(0, TestUtils.MockZkConnect, port = 8181) + props.put(KafkaConfig.UncleanLeaderElectionEnableProp, "invalid") + + assertThrows(classOf[ConfigException], () => KafkaConfig.fromProps(props)) + } + + @Test + def testLogRollTimeMsProvided(): Unit = { + val props = TestUtils.createBrokerConfig(0, TestUtils.MockZkConnect, port = 8181) + props.put(KafkaConfig.LogRollTimeMillisProp, "1800000") + + val cfg = KafkaConfig.fromProps(props) + assertEquals(30 * 60L * 1000L, cfg.logRollTimeMillis) + } + + @Test + def testLogRollTimeBothMsAndHoursProvided(): Unit = { + val props = TestUtils.createBrokerConfig(0, TestUtils.MockZkConnect, 
port = 8181) + props.put(KafkaConfig.LogRollTimeMillisProp, "1800000") + props.put(KafkaConfig.LogRollTimeHoursProp, "1") + + val cfg = KafkaConfig.fromProps(props) + assertEquals( 30 * 60L * 1000L, cfg.logRollTimeMillis) + } + + @Test + def testLogRollTimeNoConfigProvided(): Unit = { + val props = TestUtils.createBrokerConfig(0, TestUtils.MockZkConnect, port = 8181) + + val cfg = KafkaConfig.fromProps(props) + assertEquals(24 * 7 * 60L * 60L * 1000L, cfg.logRollTimeMillis ) + } + + @Test + def testDefaultCompressionType(): Unit = { + val props = TestUtils.createBrokerConfig(0, TestUtils.MockZkConnect, port = 8181) + val serverConfig = KafkaConfig.fromProps(props) + + assertEquals(serverConfig.compressionType, "producer") + } + + @Test + def testValidCompressionType(): Unit = { + val props = TestUtils.createBrokerConfig(0, TestUtils.MockZkConnect, port = 8181) + props.put("compression.type", "gzip") + val serverConfig = KafkaConfig.fromProps(props) + + assertEquals(serverConfig.compressionType, "gzip") + } + + @Test + def testInvalidCompressionType(): Unit = { + val props = TestUtils.createBrokerConfig(0, TestUtils.MockZkConnect, port = 8181) + props.put(KafkaConfig.CompressionTypeProp, "abc") + assertThrows(classOf[IllegalArgumentException], () => KafkaConfig.fromProps(props)) + } + + @Test + def testInvalidInterBrokerSecurityProtocol(): Unit = { + val props = TestUtils.createBrokerConfig(0, TestUtils.MockZkConnect, port = 8181) + props.put(KafkaConfig.ListenersProp, "SSL://localhost:0") + props.put(KafkaConfig.InterBrokerSecurityProtocolProp, SecurityProtocol.PLAINTEXT.toString) + assertThrows(classOf[IllegalArgumentException], () => KafkaConfig.fromProps(props)) + } + + @Test + def testEqualAdvertisedListenersProtocol(): Unit = { + val props = TestUtils.createBrokerConfig(0, TestUtils.MockZkConnect, port = 8181) + props.put(KafkaConfig.ListenersProp, "PLAINTEXT://localhost:9092,SSL://localhost:9093") + props.put(KafkaConfig.AdvertisedListenersProp, "PLAINTEXT://localhost:9092,SSL://localhost:9093") + KafkaConfig.fromProps(props) + } + + @Test + def testInvalidAdvertisedListenersProtocol(): Unit = { + val props = TestUtils.createBrokerConfig(0, TestUtils.MockZkConnect, port = 8181) + props.put(KafkaConfig.ListenersProp, "TRACE://localhost:9091,SSL://localhost:9093") + props.put(KafkaConfig.AdvertisedListenersProp, "PLAINTEXT://localhost:9092") + assertBadConfigContainingMessage(props, "No security protocol defined for listener TRACE") + + props.put(KafkaConfig.ListenerSecurityProtocolMapProp, "PLAINTEXT:PLAINTEXT,TRACE:PLAINTEXT,SSL:SSL") + assertBadConfigContainingMessage(props, "advertised.listeners listener names must be equal to or a subset of the ones defined in listeners") + } + + @nowarn("cat=deprecation") + @Test + def testInterBrokerVersionMessageFormatCompatibility(): Unit = { + def buildConfig(interBrokerProtocol: ApiVersion, messageFormat: ApiVersion): KafkaConfig = { + val props = TestUtils.createBrokerConfig(0, TestUtils.MockZkConnect, port = 8181) + props.put(KafkaConfig.InterBrokerProtocolVersionProp, interBrokerProtocol.version) + props.put(KafkaConfig.LogMessageFormatVersionProp, messageFormat.version) + KafkaConfig.fromProps(props) + } + + ApiVersion.allVersions.foreach { interBrokerVersion => + ApiVersion.allVersions.foreach { messageFormatVersion => + if (interBrokerVersion.recordVersion.value >= messageFormatVersion.recordVersion.value) { + val config = buildConfig(interBrokerVersion, messageFormatVersion) + assertEquals(interBrokerVersion, 
config.interBrokerProtocolVersion) + if (interBrokerVersion >= KAFKA_3_0_IV1) + assertEquals(KAFKA_3_0_IV1, config.logMessageFormatVersion) + else + assertEquals(messageFormatVersion, config.logMessageFormatVersion) + } else { + assertThrows(classOf[IllegalArgumentException], () => buildConfig(interBrokerVersion, messageFormatVersion)) + } + } + } + } + + @Test + def testFromPropsInvalid(): Unit = { + def baseProperties: Properties = { + val validRequiredProperties = new Properties() + validRequiredProperties.put(KafkaConfig.ZkConnectProp, "127.0.0.1:2181") + validRequiredProperties + } + // to ensure a basis is valid - bootstraps all needed validation + KafkaConfig.fromProps(baseProperties) + + KafkaConfig.configNames.foreach { name => + name match { + case KafkaConfig.ZkConnectProp => // ignore string + case KafkaConfig.ZkSessionTimeoutMsProp => assertPropertyInvalid(baseProperties, name, "not_a_number") + case KafkaConfig.ZkConnectionTimeoutMsProp => assertPropertyInvalid(baseProperties, name, "not_a_number") + case KafkaConfig.ZkSyncTimeMsProp => assertPropertyInvalid(baseProperties, name, "not_a_number") + case KafkaConfig.ZkEnableSecureAclsProp => assertPropertyInvalid(baseProperties, name, "not_a_boolean") + case KafkaConfig.ZkMaxInFlightRequestsProp => assertPropertyInvalid(baseProperties, name, "not_a_number", "0") + case KafkaConfig.ZkSslClientEnableProp => assertPropertyInvalid(baseProperties, name, "not_a_boolean") + case KafkaConfig.ZkClientCnxnSocketProp => //ignore string + case KafkaConfig.ZkSslKeyStoreLocationProp => //ignore string + case KafkaConfig.ZkSslKeyStorePasswordProp => //ignore string + case KafkaConfig.ZkSslKeyStoreTypeProp => //ignore string + case KafkaConfig.ZkSslTrustStoreLocationProp => //ignore string + case KafkaConfig.ZkSslTrustStorePasswordProp => //ignore string + case KafkaConfig.ZkSslTrustStoreTypeProp => //ignore string + case KafkaConfig.ZkSslProtocolProp => //ignore string + case KafkaConfig.ZkSslEnabledProtocolsProp => //ignore string + case KafkaConfig.ZkSslCipherSuitesProp => //ignore string + case KafkaConfig.ZkSslEndpointIdentificationAlgorithmProp => //ignore string + case KafkaConfig.ZkSslCrlEnableProp => assertPropertyInvalid(baseProperties, name, "not_a_boolean") + case KafkaConfig.ZkSslOcspEnableProp => assertPropertyInvalid(baseProperties, name, "not_a_boolean") + + case KafkaConfig.BrokerIdProp => assertPropertyInvalid(baseProperties, name, "not_a_number") + case KafkaConfig.NumNetworkThreadsProp => assertPropertyInvalid(baseProperties, name, "not_a_number", "0") + case KafkaConfig.NumIoThreadsProp => assertPropertyInvalid(baseProperties, name, "not_a_number", "0") + case KafkaConfig.BackgroundThreadsProp => assertPropertyInvalid(baseProperties, name, "not_a_number", "0") + case KafkaConfig.QueuedMaxRequestsProp => assertPropertyInvalid(baseProperties, name, "not_a_number", "0") + case KafkaConfig.NumReplicaAlterLogDirsThreadsProp => assertPropertyInvalid(baseProperties, name, "not_a_number") + case KafkaConfig.QueuedMaxBytesProp => assertPropertyInvalid(baseProperties, name, "not_a_number") + case KafkaConfig.RequestTimeoutMsProp => assertPropertyInvalid(baseProperties, name, "not_a_number") + case KafkaConfig.ConnectionSetupTimeoutMsProp => assertPropertyInvalid(baseProperties, name, "not_a_number") + case KafkaConfig.ConnectionSetupTimeoutMaxMsProp => assertPropertyInvalid(baseProperties, name, "not_a_number") + + // KRaft mode configs + case KafkaConfig.ProcessRolesProp => // ignore + case 
KafkaConfig.InitialBrokerRegistrationTimeoutMsProp => assertPropertyInvalid(baseProperties, name, "not_a_number") + case KafkaConfig.BrokerHeartbeatIntervalMsProp => assertPropertyInvalid(baseProperties, name, "not_a_number") + case KafkaConfig.BrokerSessionTimeoutMsProp => assertPropertyInvalid(baseProperties, name, "not_a_number") + case KafkaConfig.NodeIdProp => assertPropertyInvalid(baseProperties, name, "not_a_number") + case KafkaConfig.MetadataLogDirProp => // ignore string + case KafkaConfig.MetadataLogSegmentBytesProp => assertPropertyInvalid(baseProperties, name, "not_a_number") + case KafkaConfig.MetadataLogSegmentMillisProp => assertPropertyInvalid(baseProperties, name, "not_a_number") + case KafkaConfig.MetadataMaxRetentionBytesProp => assertPropertyInvalid(baseProperties, name, "not_a_number") + case KafkaConfig.MetadataMaxRetentionMillisProp => assertPropertyInvalid(baseProperties, name, "not_a_number") + case KafkaConfig.ControllerListenerNamesProp => // ignore string + + case KafkaConfig.AuthorizerClassNameProp => //ignore string + case KafkaConfig.CreateTopicPolicyClassNameProp => //ignore string + + case KafkaConfig.SocketSendBufferBytesProp => assertPropertyInvalid(baseProperties, name, "not_a_number") + case KafkaConfig.SocketReceiveBufferBytesProp => assertPropertyInvalid(baseProperties, name, "not_a_number") + case KafkaConfig.MaxConnectionsPerIpOverridesProp => + assertPropertyInvalid(baseProperties, name, "127.0.0.1:not_a_number") + case KafkaConfig.ConnectionsMaxIdleMsProp => assertPropertyInvalid(baseProperties, name, "not_a_number") + case KafkaConfig.FailedAuthenticationDelayMsProp => assertPropertyInvalid(baseProperties, name, "not_a_number", "-1") + + case KafkaConfig.NumPartitionsProp => assertPropertyInvalid(baseProperties, name, "not_a_number", "0") + case KafkaConfig.LogDirsProp => // ignore string + case KafkaConfig.LogDirProp => // ignore string + case KafkaConfig.LogSegmentBytesProp => assertPropertyInvalid(baseProperties, name, "not_a_number", Records.LOG_OVERHEAD - 1) + + case KafkaConfig.LogRollTimeMillisProp => assertPropertyInvalid(baseProperties, name, "not_a_number", "0") + case KafkaConfig.LogRollTimeHoursProp => assertPropertyInvalid(baseProperties, name, "not_a_number", "0") + + case KafkaConfig.LogRetentionTimeMillisProp => assertPropertyInvalid(baseProperties, name, "not_a_number", "0") + case KafkaConfig.LogRetentionTimeMinutesProp => assertPropertyInvalid(baseProperties, name, "not_a_number", "0") + case KafkaConfig.LogRetentionTimeHoursProp => assertPropertyInvalid(baseProperties, name, "not_a_number", "0") + + case KafkaConfig.LogRetentionBytesProp => assertPropertyInvalid(baseProperties, name, "not_a_number") + case KafkaConfig.LogCleanupIntervalMsProp => assertPropertyInvalid(baseProperties, name, "not_a_number", "0") + case KafkaConfig.LogCleanupPolicyProp => assertPropertyInvalid(baseProperties, name, "unknown_policy", "0") + case KafkaConfig.LogCleanerIoMaxBytesPerSecondProp => assertPropertyInvalid(baseProperties, name, "not_a_number") + case KafkaConfig.LogCleanerDedupeBufferSizeProp => assertPropertyInvalid(baseProperties, name, "not_a_number", "1024") + case KafkaConfig.LogCleanerDedupeBufferLoadFactorProp => assertPropertyInvalid(baseProperties, name, "not_a_number") + case KafkaConfig.LogCleanerEnableProp => assertPropertyInvalid(baseProperties, name, "not_a_boolean") + case KafkaConfig.LogCleanerDeleteRetentionMsProp => assertPropertyInvalid(baseProperties, name, "not_a_number") + case 
KafkaConfig.LogCleanerMinCompactionLagMsProp => assertPropertyInvalid(baseProperties, name, "not_a_number") + case KafkaConfig.LogCleanerMaxCompactionLagMsProp => assertPropertyInvalid(baseProperties, name, "not_a_number") + case KafkaConfig.LogCleanerMinCleanRatioProp => assertPropertyInvalid(baseProperties, name, "not_a_number") + case KafkaConfig.LogIndexSizeMaxBytesProp => assertPropertyInvalid(baseProperties, name, "not_a_number", "3") + case KafkaConfig.LogFlushIntervalMessagesProp => assertPropertyInvalid(baseProperties, name, "not_a_number", "0") + case KafkaConfig.LogFlushSchedulerIntervalMsProp => assertPropertyInvalid(baseProperties, name, "not_a_number") + case KafkaConfig.LogFlushIntervalMsProp => assertPropertyInvalid(baseProperties, name, "not_a_number") + case KafkaConfig.LogMessageTimestampDifferenceMaxMsProp => assertPropertyInvalid(baseProperties, name, "not_a_number") + case KafkaConfig.LogFlushStartOffsetCheckpointIntervalMsProp => assertPropertyInvalid(baseProperties, name, "not_a_number") + case KafkaConfig.NumRecoveryThreadsPerDataDirProp => assertPropertyInvalid(baseProperties, name, "not_a_number", "0") + case KafkaConfig.AutoCreateTopicsEnableProp => assertPropertyInvalid(baseProperties, name, "not_a_boolean", "0") + case KafkaConfig.MinInSyncReplicasProp => assertPropertyInvalid(baseProperties, name, "not_a_number", "0") + case KafkaConfig.ControllerSocketTimeoutMsProp => assertPropertyInvalid(baseProperties, name, "not_a_number") + case KafkaConfig.DefaultReplicationFactorProp => assertPropertyInvalid(baseProperties, name, "not_a_number") + case KafkaConfig.ReplicaLagTimeMaxMsProp => assertPropertyInvalid(baseProperties, name, "not_a_number") + case KafkaConfig.ReplicaSocketTimeoutMsProp => assertPropertyInvalid(baseProperties, name, "not_a_number", "-2") + case KafkaConfig.ReplicaSocketReceiveBufferBytesProp => assertPropertyInvalid(baseProperties, name, "not_a_number") + case KafkaConfig.ReplicaFetchMaxBytesProp => assertPropertyInvalid(baseProperties, name, "not_a_number") + case KafkaConfig.ReplicaFetchWaitMaxMsProp => assertPropertyInvalid(baseProperties, name, "not_a_number") + case KafkaConfig.ReplicaFetchMinBytesProp => assertPropertyInvalid(baseProperties, name, "not_a_number") + case KafkaConfig.ReplicaFetchResponseMaxBytesProp => assertPropertyInvalid(baseProperties, name, "not_a_number") + case KafkaConfig.ReplicaSelectorClassProp => // Ignore string + case KafkaConfig.NumReplicaFetchersProp => assertPropertyInvalid(baseProperties, name, "not_a_number") + case KafkaConfig.ReplicaHighWatermarkCheckpointIntervalMsProp => assertPropertyInvalid(baseProperties, name, "not_a_number") + case KafkaConfig.FetchPurgatoryPurgeIntervalRequestsProp => assertPropertyInvalid(baseProperties, name, "not_a_number") + case KafkaConfig.ProducerPurgatoryPurgeIntervalRequestsProp => assertPropertyInvalid(baseProperties, name, "not_a_number") + case KafkaConfig.DeleteRecordsPurgatoryPurgeIntervalRequestsProp => assertPropertyInvalid(baseProperties, name, "not_a_number") + case KafkaConfig.AutoLeaderRebalanceEnableProp => assertPropertyInvalid(baseProperties, name, "not_a_boolean", "0") + case KafkaConfig.LeaderImbalancePerBrokerPercentageProp => assertPropertyInvalid(baseProperties, name, "not_a_number") + case KafkaConfig.LeaderImbalanceCheckIntervalSecondsProp => assertPropertyInvalid(baseProperties, name, "not_a_number") + case KafkaConfig.UncleanLeaderElectionEnableProp => assertPropertyInvalid(baseProperties, name, "not_a_boolean", "0") + case 
KafkaConfig.ControlledShutdownMaxRetriesProp => assertPropertyInvalid(baseProperties, name, "not_a_number") + case KafkaConfig.ControlledShutdownRetryBackoffMsProp => assertPropertyInvalid(baseProperties, name, "not_a_number") + case KafkaConfig.ControlledShutdownEnableProp => assertPropertyInvalid(baseProperties, name, "not_a_boolean", "0") + case KafkaConfig.GroupMinSessionTimeoutMsProp => assertPropertyInvalid(baseProperties, name, "not_a_number") + case KafkaConfig.GroupMaxSessionTimeoutMsProp => assertPropertyInvalid(baseProperties, name, "not_a_number") + case KafkaConfig.GroupInitialRebalanceDelayMsProp => assertPropertyInvalid(baseProperties, name, "not_a_number") + case KafkaConfig.GroupMaxSizeProp => assertPropertyInvalid(baseProperties, name, "not_a_number", "0", "-1") + case KafkaConfig.OffsetMetadataMaxSizeProp => assertPropertyInvalid(baseProperties, name, "not_a_number") + case KafkaConfig.OffsetsLoadBufferSizeProp => assertPropertyInvalid(baseProperties, name, "not_a_number", "0") + case KafkaConfig.OffsetsTopicReplicationFactorProp => assertPropertyInvalid(baseProperties, name, "not_a_number", "0") + case KafkaConfig.OffsetsTopicPartitionsProp => assertPropertyInvalid(baseProperties, name, "not_a_number", "0") + case KafkaConfig.OffsetsTopicSegmentBytesProp => assertPropertyInvalid(baseProperties, name, "not_a_number", "0") + case KafkaConfig.OffsetsTopicCompressionCodecProp => assertPropertyInvalid(baseProperties, name, "not_a_number", "-1") + case KafkaConfig.OffsetsRetentionMinutesProp => assertPropertyInvalid(baseProperties, name, "not_a_number", "0") + case KafkaConfig.OffsetsRetentionCheckIntervalMsProp => assertPropertyInvalid(baseProperties, name, "not_a_number", "0") + case KafkaConfig.OffsetCommitTimeoutMsProp => assertPropertyInvalid(baseProperties, name, "not_a_number", "0") + case KafkaConfig.OffsetCommitRequiredAcksProp => assertPropertyInvalid(baseProperties, name, "not_a_number", "-2") + case KafkaConfig.TransactionalIdExpirationMsProp => assertPropertyInvalid(baseProperties, name, "not_a_number", "0", "-2") + case KafkaConfig.TransactionsMaxTimeoutMsProp => assertPropertyInvalid(baseProperties, name, "not_a_number", "0", "-2") + case KafkaConfig.TransactionsTopicMinISRProp => assertPropertyInvalid(baseProperties, name, "not_a_number", "0", "-2") + case KafkaConfig.TransactionsLoadBufferSizeProp => assertPropertyInvalid(baseProperties, name, "not_a_number", "0", "-2") + case KafkaConfig.TransactionsTopicPartitionsProp => assertPropertyInvalid(baseProperties, name, "not_a_number", "0", "-2") + case KafkaConfig.TransactionsTopicSegmentBytesProp => assertPropertyInvalid(baseProperties, name, "not_a_number", "0", "-2") + case KafkaConfig.TransactionsTopicReplicationFactorProp => assertPropertyInvalid(baseProperties, name, "not_a_number", "0", "-2") + case KafkaConfig.NumQuotaSamplesProp => assertPropertyInvalid(baseProperties, name, "not_a_number", "0") + case KafkaConfig.QuotaWindowSizeSecondsProp => assertPropertyInvalid(baseProperties, name, "not_a_number", "0") + case KafkaConfig.DeleteTopicEnableProp => assertPropertyInvalid(baseProperties, name, "not_a_boolean", "0") + + case KafkaConfig.MetricNumSamplesProp => assertPropertyInvalid(baseProperties, name, "not_a_number", "-1", "0") + case KafkaConfig.MetricSampleWindowMsProp => assertPropertyInvalid(baseProperties, name, "not_a_number", "-1", "0") + case KafkaConfig.MetricReporterClassesProp => // ignore string + case KafkaConfig.MetricRecordingLevelProp => // ignore string + case KafkaConfig.RackProp => 
// ignore string + //SSL Configs + case KafkaConfig.PrincipalBuilderClassProp => + case KafkaConfig.ConnectionsMaxReauthMsProp => + case KafkaConfig.SslProtocolProp => // ignore string + case KafkaConfig.SslProviderProp => // ignore string + case KafkaConfig.SslEnabledProtocolsProp => + case KafkaConfig.SslKeystoreTypeProp => // ignore string + case KafkaConfig.SslKeystoreLocationProp => // ignore string + case KafkaConfig.SslKeystorePasswordProp => // ignore string + case KafkaConfig.SslKeyPasswordProp => // ignore string + case KafkaConfig.SslKeystoreCertificateChainProp => // ignore string + case KafkaConfig.SslKeystoreKeyProp => // ignore string + case KafkaConfig.SslTruststoreTypeProp => // ignore string + case KafkaConfig.SslTruststorePasswordProp => // ignore string + case KafkaConfig.SslTruststoreLocationProp => // ignore string + case KafkaConfig.SslTruststoreCertificatesProp => // ignore string + case KafkaConfig.SslKeyManagerAlgorithmProp => + case KafkaConfig.SslTrustManagerAlgorithmProp => + case KafkaConfig.SslClientAuthProp => // ignore string + case KafkaConfig.SslEndpointIdentificationAlgorithmProp => // ignore string + case KafkaConfig.SslSecureRandomImplementationProp => // ignore string + case KafkaConfig.SslCipherSuitesProp => // ignore string + case KafkaConfig.SslPrincipalMappingRulesProp => // ignore string + + //Sasl Configs + case KafkaConfig.SaslMechanismControllerProtocolProp => // ignore + case KafkaConfig.SaslMechanismInterBrokerProtocolProp => // ignore + case KafkaConfig.SaslEnabledMechanismsProp => + case KafkaConfig.SaslClientCallbackHandlerClassProp => + case KafkaConfig.SaslServerCallbackHandlerClassProp => + case KafkaConfig.SaslLoginClassProp => + case KafkaConfig.SaslLoginCallbackHandlerClassProp => + case KafkaConfig.SaslKerberosServiceNameProp => // ignore string + case KafkaConfig.SaslKerberosKinitCmdProp => + case KafkaConfig.SaslKerberosTicketRenewWindowFactorProp => + case KafkaConfig.SaslKerberosTicketRenewJitterProp => + case KafkaConfig.SaslKerberosMinTimeBeforeReloginProp => + case KafkaConfig.SaslKerberosPrincipalToLocalRulesProp => // ignore string + case KafkaConfig.SaslJaasConfigProp => + case KafkaConfig.SaslLoginRefreshWindowFactorProp => + case KafkaConfig.SaslLoginRefreshWindowJitterProp => + case KafkaConfig.SaslLoginRefreshMinPeriodSecondsProp => + case KafkaConfig.SaslLoginRefreshBufferSecondsProp => + case KafkaConfig.SaslLoginConnectTimeoutMsProp => + case KafkaConfig.SaslLoginReadTimeoutMsProp => + case KafkaConfig.SaslLoginRetryBackoffMaxMsProp => + case KafkaConfig.SaslLoginRetryBackoffMsProp => + case KafkaConfig.SaslOAuthBearerScopeClaimNameProp => + case KafkaConfig.SaslOAuthBearerSubClaimNameProp => + case KafkaConfig.SaslOAuthBearerTokenEndpointUrlProp => + case KafkaConfig.SaslOAuthBearerJwksEndpointUrlProp => + case KafkaConfig.SaslOAuthBearerJwksEndpointRefreshMsProp => + case KafkaConfig.SaslOAuthBearerJwksEndpointRetryBackoffMaxMsProp => + case KafkaConfig.SaslOAuthBearerJwksEndpointRetryBackoffMsProp => + case KafkaConfig.SaslOAuthBearerClockSkewSecondsProp => + case KafkaConfig.SaslOAuthBearerExpectedAudienceProp => + case KafkaConfig.SaslOAuthBearerExpectedIssuerProp => + + // Security config + case KafkaConfig.securityProviderClassProp => + + // Password encoder configs + case KafkaConfig.PasswordEncoderSecretProp => + case KafkaConfig.PasswordEncoderOldSecretProp => + case KafkaConfig.PasswordEncoderKeyFactoryAlgorithmProp => + case KafkaConfig.PasswordEncoderCipherAlgorithmProp => + case 
KafkaConfig.PasswordEncoderKeyLengthProp => assertPropertyInvalid(baseProperties, name, "not_a_number", "-1", "0") + case KafkaConfig.PasswordEncoderIterationsProp => assertPropertyInvalid(baseProperties, name, "not_a_number", "-1", "0") + + //delegation token configs + case KafkaConfig.DelegationTokenSecretKeyAliasProp => // ignore + case KafkaConfig.DelegationTokenSecretKeyProp => // ignore + case KafkaConfig.DelegationTokenMaxLifeTimeProp => assertPropertyInvalid(baseProperties, name, "not_a_number", "0") + case KafkaConfig.DelegationTokenExpiryTimeMsProp => assertPropertyInvalid(baseProperties, name, "not_a_number", "0") + case KafkaConfig.DelegationTokenExpiryCheckIntervalMsProp => assertPropertyInvalid(baseProperties, name, "not_a_number", "0") + + //Kafka Yammer metrics reporter configs + case KafkaConfig.KafkaMetricsReporterClassesProp => // ignore + case KafkaConfig.KafkaMetricsPollingIntervalSecondsProp => //ignore + + // Raft Quorum Configs + case RaftConfig.QUORUM_VOTERS_CONFIG => // ignore string + case RaftConfig.QUORUM_ELECTION_TIMEOUT_MS_CONFIG => assertPropertyInvalid(baseProperties, name, "not_a_number") + case RaftConfig.QUORUM_FETCH_TIMEOUT_MS_CONFIG => assertPropertyInvalid(baseProperties, name, "not_a_number") + case RaftConfig.QUORUM_ELECTION_BACKOFF_MAX_MS_CONFIG => assertPropertyInvalid(baseProperties, name, "not_a_number") + case RaftConfig.QUORUM_LINGER_MS_CONFIG => assertPropertyInvalid(baseProperties, name, "not_a_number") + case RaftConfig.QUORUM_REQUEST_TIMEOUT_MS_CONFIG => assertPropertyInvalid(baseProperties, name, "not_a_number") + case RaftConfig.QUORUM_RETRY_BACKOFF_MS_CONFIG => assertPropertyInvalid(baseProperties, name, "not_a_number") + + // Remote Log Manager Configs + case RemoteLogManagerConfig.REMOTE_LOG_STORAGE_SYSTEM_ENABLE_PROP => assertPropertyInvalid(baseProperties, name, "not_a_boolean") + case RemoteLogManagerConfig.REMOTE_STORAGE_MANAGER_CLASS_NAME_PROP => // ignore string + case RemoteLogManagerConfig.REMOTE_STORAGE_MANAGER_CLASS_PATH_PROP => // ignore string + case RemoteLogManagerConfig.REMOTE_STORAGE_MANAGER_CONFIG_PREFIX_PROP => // ignore string + case RemoteLogManagerConfig.REMOTE_LOG_METADATA_MANAGER_CLASS_NAME_PROP => // ignore string + case RemoteLogManagerConfig.REMOTE_LOG_METADATA_MANAGER_CLASS_PATH_PROP => // ignore string + case RemoteLogManagerConfig.REMOTE_LOG_METADATA_MANAGER_CONFIG_PREFIX_PROP => // ignore string + case RemoteLogManagerConfig.REMOTE_LOG_METADATA_MANAGER_LISTENER_NAME_PROP => // ignore string + case RemoteLogManagerConfig.REMOTE_LOG_INDEX_FILE_CACHE_TOTAL_SIZE_BYTES_PROP => assertPropertyInvalid(baseProperties, name, "not_a_number", 0, -1) + case RemoteLogManagerConfig.REMOTE_LOG_MANAGER_THREAD_POOL_SIZE_PROP => assertPropertyInvalid(baseProperties, name, "not_a_number", 0, -1) + case RemoteLogManagerConfig.REMOTE_LOG_MANAGER_TASK_INTERVAL_MS_PROP => assertPropertyInvalid(baseProperties, name, "not_a_number", 0, -1) + case RemoteLogManagerConfig.REMOTE_LOG_MANAGER_TASK_RETRY_BACK_OFF_MS_PROP => assertPropertyInvalid(baseProperties, name, "not_a_number", 0, -1) + case RemoteLogManagerConfig.REMOTE_LOG_MANAGER_TASK_RETRY_BACK_OFF_MAX_MS_PROP => assertPropertyInvalid(baseProperties, name, "not_a_number", 0, -1) + case RemoteLogManagerConfig.REMOTE_LOG_MANAGER_TASK_RETRY_JITTER_PROP => assertPropertyInvalid(baseProperties, name, "not_a_number", -1, 0.51) + case RemoteLogManagerConfig.REMOTE_LOG_READER_THREADS_PROP => assertPropertyInvalid(baseProperties, name, "not_a_number", 0, -1) + case 
RemoteLogManagerConfig.REMOTE_LOG_READER_MAX_PENDING_TASKS_PROP => assertPropertyInvalid(baseProperties, name, "not_a_number", 0, -1) + + case TopicConfig.LOCAL_LOG_RETENTION_MS_CONFIG => assertPropertyInvalid(baseProperties, name, "not_a_number", 0, -2) + case TopicConfig.LOCAL_LOG_RETENTION_BYTES_CONFIG => assertPropertyInvalid(baseProperties, name, "not_a_number", 0, -2) + + case _ => assertPropertyInvalid(baseProperties, name, "not_a_number", "-1") + } + } + } + + @nowarn("cat=deprecation") + @Test + def testDynamicLogConfigs(): Unit = { + def baseProperties: Properties = { + val validRequiredProperties = new Properties() + validRequiredProperties.put(KafkaConfig.ZkConnectProp, "127.0.0.1:2181") + validRequiredProperties + } + + val props = baseProperties + val config = KafkaConfig.fromProps(props) + + def assertDynamic(property: String, value: Any, accessor: () => Any): Unit = { + val initial = accessor() + props.put(property, value.toString) + config.updateCurrentConfig(new KafkaConfig(props)) + assertNotEquals(initial, accessor()) + } + + // Test dynamic log config values can be correctly passed through via KafkaConfig to LogConfig + // Every log config prop must be explicitly accounted for here. + // A value other than the default value for this config should be set to ensure that we can check whether + // the value is dynamically updatable. + LogConfig.TopicConfigSynonyms.foreach { case (logConfig, kafkaConfigProp) => + logConfig match { + case LogConfig.CleanupPolicyProp => + assertDynamic(kafkaConfigProp, Defaults.Compact, () => config.logCleanupPolicy) + case LogConfig.CompressionTypeProp => + assertDynamic(kafkaConfigProp, "lz4", () => config.compressionType) + case LogConfig.SegmentBytesProp => + assertDynamic(kafkaConfigProp, 10000, () => config.logSegmentBytes) + case LogConfig.SegmentMsProp => + assertDynamic(kafkaConfigProp, 10001L, () => config.logRollTimeMillis) + case LogConfig.DeleteRetentionMsProp => + assertDynamic(kafkaConfigProp, 10002L, () => config.logCleanerDeleteRetentionMs) + case LogConfig.FileDeleteDelayMsProp => + assertDynamic(kafkaConfigProp, 10003L, () => config.logDeleteDelayMs) + case LogConfig.FlushMessagesProp => + assertDynamic(kafkaConfigProp, 10004L, () => config.logFlushIntervalMessages) + case LogConfig.FlushMsProp => + assertDynamic(kafkaConfigProp, 10005L, () => config.logFlushIntervalMs) + case LogConfig.MaxCompactionLagMsProp => + assertDynamic(kafkaConfigProp, 10006L, () => config.logCleanerMaxCompactionLagMs) + case LogConfig.IndexIntervalBytesProp => + assertDynamic(kafkaConfigProp, 10007, () => config.logIndexIntervalBytes) + case LogConfig.MaxMessageBytesProp => + assertDynamic(kafkaConfigProp, 10008, () => config.messageMaxBytes) + case LogConfig.MessageDownConversionEnableProp => + assertDynamic(kafkaConfigProp, false, () => config.logMessageDownConversionEnable) + case LogConfig.MessageTimestampDifferenceMaxMsProp => + assertDynamic(kafkaConfigProp, 10009, () => config.logMessageTimestampDifferenceMaxMs) + case LogConfig.MessageTimestampTypeProp => + assertDynamic(kafkaConfigProp, "LogAppendTime", () => config.logMessageTimestampType.name) + case LogConfig.MinCleanableDirtyRatioProp => + assertDynamic(kafkaConfigProp, 0.01, () => config.logCleanerMinCleanRatio) + case LogConfig.MinCompactionLagMsProp => + assertDynamic(kafkaConfigProp, 10010L, () => config.logCleanerMinCompactionLagMs) + case LogConfig.MinInSyncReplicasProp => + assertDynamic(kafkaConfigProp, 4, () => config.minInSyncReplicas) + case LogConfig.PreAllocateEnableProp 
=> + assertDynamic(kafkaConfigProp, true, () => config.logPreAllocateEnable) + case LogConfig.RetentionBytesProp => + assertDynamic(kafkaConfigProp, 10011L, () => config.logRetentionBytes) + case LogConfig.RetentionMsProp => + assertDynamic(kafkaConfigProp, 10012L, () => config.logRetentionTimeMillis) + case LogConfig.SegmentIndexBytesProp => + assertDynamic(kafkaConfigProp, 10013, () => config.logIndexSizeMaxBytes) + case LogConfig.SegmentJitterMsProp => + assertDynamic(kafkaConfigProp, 10014L, () => config.logRollTimeJitterMillis) + case LogConfig.UncleanLeaderElectionEnableProp => + assertDynamic(kafkaConfigProp, true, () => config.uncleanLeaderElectionEnable) + case LogConfig.MessageFormatVersionProp => + // not dynamically updatable + case LogConfig.FollowerReplicationThrottledReplicasProp => + // topic only config + case LogConfig.LeaderReplicationThrottledReplicasProp => + // topic only config + case prop => + fail(prop + " must be explicitly checked for dynamic updatability. Note that LogConfig(s) require that KafkaConfig value lookups are dynamic and not static values.") + } + } + } + + @Test + def testSpecificProperties(): Unit = { + val defaults = new Properties() + defaults.put(KafkaConfig.ZkConnectProp, "127.0.0.1:2181") + // For ZkConnectionTimeoutMs + defaults.put(KafkaConfig.ZkSessionTimeoutMsProp, "1234") + defaults.put(KafkaConfig.BrokerIdGenerationEnableProp, "false") + defaults.put(KafkaConfig.MaxReservedBrokerIdProp, "1") + defaults.put(KafkaConfig.BrokerIdProp, "1") + defaults.put(KafkaConfig.ListenersProp, "PLAINTEXT://127.0.0.1:1122") + defaults.put(KafkaConfig.MaxConnectionsPerIpOverridesProp, "127.0.0.1:2, 127.0.0.2:3") + defaults.put(KafkaConfig.LogDirProp, "/tmp1,/tmp2") + defaults.put(KafkaConfig.LogRollTimeHoursProp, "12") + defaults.put(KafkaConfig.LogRollTimeJitterHoursProp, "11") + defaults.put(KafkaConfig.LogRetentionTimeHoursProp, "10") + //For LogFlushIntervalMsProp + defaults.put(KafkaConfig.LogFlushSchedulerIntervalMsProp, "123") + defaults.put(KafkaConfig.OffsetsTopicCompressionCodecProp, SnappyCompressionCodec.codec.toString) + // For MetricRecordingLevelProp + defaults.put(KafkaConfig.MetricRecordingLevelProp, Sensor.RecordingLevel.DEBUG.toString) + + val config = KafkaConfig.fromProps(defaults) + assertEquals("127.0.0.1:2181", config.zkConnect) + assertEquals(1234, config.zkConnectionTimeoutMs) + assertEquals(false, config.brokerIdGenerationEnable) + assertEquals(1, config.maxReservedBrokerId) + assertEquals(1, config.brokerId) + assertEquals(Seq("PLAINTEXT://127.0.0.1:1122"), config.effectiveAdvertisedListeners.map(_.connectionString)) + assertEquals(Map("127.0.0.1" -> 2, "127.0.0.2" -> 3), config.maxConnectionsPerIpOverrides) + assertEquals(List("/tmp1", "/tmp2"), config.logDirs) + assertEquals(12 * 60L * 1000L * 60, config.logRollTimeMillis) + assertEquals(11 * 60L * 1000L * 60, config.logRollTimeJitterMillis) + assertEquals(10 * 60L * 1000L * 60, config.logRetentionTimeMillis) + assertEquals(123L, config.logFlushIntervalMs) + assertEquals(SnappyCompressionCodec, config.offsetsTopicCompressionCodec) + assertEquals(Sensor.RecordingLevel.DEBUG.toString, config.metricRecordingLevel) + assertEquals(false, config.tokenAuthEnabled) + assertEquals(7 * 24 * 60L * 60L * 1000L, config.delegationTokenMaxLifeMs) + assertEquals(24 * 60L * 60L * 1000L, config.delegationTokenExpiryTimeMs) + assertEquals(1 * 60L * 1000L * 60, config.delegationTokenExpiryCheckIntervalMs) + + defaults.put(KafkaConfig.DelegationTokenSecretKeyProp, "1234567890") + val config1 = 
KafkaConfig.fromProps(defaults) + assertEquals(true, config1.tokenAuthEnabled) + } + + @Test + def testNonroutableAdvertisedListeners(): Unit = { + val props = new Properties() + props.put(KafkaConfig.ZkConnectProp, "127.0.0.1:2181") + props.put(KafkaConfig.ListenersProp, "PLAINTEXT://0.0.0.0:9092") + assertFalse(isValidKafkaConfig(props)) + } + + @Test + def testMaxConnectionsPerIpProp(): Unit = { + val props = TestUtils.createBrokerConfig(0, TestUtils.MockZkConnect, port = 8181) + props.put(KafkaConfig.MaxConnectionsPerIpProp, "0") + assertFalse(isValidKafkaConfig(props)) + props.put(KafkaConfig.MaxConnectionsPerIpOverridesProp, "127.0.0.1:100") + KafkaConfig.fromProps(props) + props.put(KafkaConfig.MaxConnectionsPerIpOverridesProp, "127.0.0.0#:100") + assertFalse(isValidKafkaConfig(props)) + } + + private def assertPropertyInvalid(validRequiredProps: => Properties, name: String, values: Any*): Unit = { + values.foreach { value => + val props = validRequiredProps + props.setProperty(name, value.toString) + + val buildConfig: Executable = () => KafkaConfig.fromProps(props) + assertThrows(classOf[Exception], buildConfig, + s"Expected exception for property `$name` with invalid value `$value` was not thrown") + } + } + + @Test + def testDistinctControllerAndAdvertisedListenersAllowedForKRaftBroker(): Unit = { + val props = new Properties() + props.put(KafkaConfig.ProcessRolesProp, "broker") + props.put(KafkaConfig.ListenersProp, "PLAINTEXT://A:9092,SSL://B:9093,SASL_SSL://C:9094") + props.put(KafkaConfig.AdvertisedListenersProp, "PLAINTEXT://A:9092,SSL://B:9093") // explicitly setting it in KRaft + props.put(KafkaConfig.ControllerListenerNamesProp, "SASL_SSL") + props.put(KafkaConfig.NodeIdProp, "2") + props.put(KafkaConfig.QuorumVotersProp, "3@localhost:9094") + + // invalid due to extra listener also appearing in controller listeners + assertBadConfigContainingMessage(props, + "controller.listener.names must not contain a value appearing in the 'listeners' configuration when running KRaft with just the broker role") + + // Valid now + props.put(KafkaConfig.ListenersProp, "PLAINTEXT://A:9092,SSL://B:9093") + KafkaConfig.fromProps(props) + + // Also valid if we let advertised listeners be derived from listeners/controller.listener.names + // since listeners and advertised.listeners are explicitly identical at this point + props.remove(KafkaConfig.AdvertisedListenersProp) + KafkaConfig.fromProps(props) + } + + @Test + def testControllerListenersCannotBeAdvertisedForKRaftBroker(): Unit = { + val props = new Properties() + props.put(KafkaConfig.ProcessRolesProp, "broker,controller") + val listeners = "PLAINTEXT://A:9092,SSL://B:9093,SASL_SSL://C:9094" + props.put(KafkaConfig.ListenersProp, listeners) + props.put(KafkaConfig.AdvertisedListenersProp, listeners) // explicitly setting it in KRaft + props.put(KafkaConfig.InterBrokerListenerNameProp, "SASL_SSL") + props.put(KafkaConfig.ControllerListenerNamesProp, "PLAINTEXT,SSL") + props.put(KafkaConfig.NodeIdProp, "2") + props.put(KafkaConfig.QuorumVotersProp, "2@localhost:9092") + assertBadConfigContainingMessage(props, + "The advertised.listeners config must not contain KRaft controller listeners from controller.listener.names when process.roles contains the broker role") + + // Valid now + props.put(KafkaConfig.AdvertisedListenersProp, "SASL_SSL://C:9094") + KafkaConfig.fromProps(props) + + // Also valid if we allow advertised listeners to derive from listeners/controller.listener.names + props.remove(KafkaConfig.AdvertisedListenersProp) + 
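For reference, the valid broker-only arrangement that these KRaft listener checks converge on can be restated as a standalone sketch; the object name is ours and the listener names and ports are only examples, while the property keys come straight from KafkaConfig:

    import java.util.Properties
    import kafka.server.KafkaConfig

    // Illustrative only: the controller listener name (SASL_SSL) appears in
    // controller.listener.names but never in listeners or advertised.listeners,
    // which is what the broker-role validation exercised here enforces.
    object KRaftListenerConfigExample extends App {
      val props = new Properties()
      props.put(KafkaConfig.ProcessRolesProp, "broker")
      props.put(KafkaConfig.NodeIdProp, "2")
      props.put(KafkaConfig.QuorumVotersProp, "3@localhost:9094")
      props.put(KafkaConfig.ControllerListenerNamesProp, "SASL_SSL")
      props.put(KafkaConfig.ListenersProp, "PLAINTEXT://A:9092,SSL://B:9093")
      props.put(KafkaConfig.AdvertisedListenersProp, "PLAINTEXT://A:9092,SSL://B:9093")
      // fromProps fails fast if a controller listener leaks into listeners
      // or advertised.listeners for a broker-only node.
      val config = KafkaConfig.fromProps(props)
      println(config.effectiveAdvertisedListeners.map(_.connectionString))
    }
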
KafkaConfig.fromProps(props) + } + + @Test + def testAdvertisedListenersDisallowedForKRaftControllerOnlyRole(): Unit = { + val props = new Properties() + props.put(KafkaConfig.ProcessRolesProp, "controller") + val listeners = "PLAINTEXT://A:9092,SSL://B:9093,SASL_SSL://C:9094" + props.put(KafkaConfig.ListenersProp, listeners) + props.put(KafkaConfig.AdvertisedListenersProp, listeners) // explicitly setting it in KRaft + props.put(KafkaConfig.ControllerListenerNamesProp, "PLAINTEXT,SSL") + props.put(KafkaConfig.NodeIdProp, "2") + props.put(KafkaConfig.QuorumVotersProp, "2@localhost:9092") + val expectedExceptionContainsTextSuffix = " config must only contain KRaft controller listeners from controller.listener.names when process.roles=controller" + assertBadConfigContainingMessage(props, "The advertised.listeners" + expectedExceptionContainsTextSuffix) + + // Still invalid due to extra listener if we set advertised listeners explicitly to be correct + val correctListeners = "PLAINTEXT://A:9092,SSL://B:9093" + props.put(KafkaConfig.AdvertisedListenersProp, correctListeners) + assertBadConfigContainingMessage(props, "The advertised.listeners" + expectedExceptionContainsTextSuffix) + + // Still invalid due to extra listener if we allow advertised listeners to derive from listeners/controller.listener.names + props.remove(KafkaConfig.AdvertisedListenersProp) + assertBadConfigContainingMessage(props, "The listeners" + expectedExceptionContainsTextSuffix) + + // Valid now + props.put(KafkaConfig.ListenersProp, correctListeners) + KafkaConfig.fromProps(props) + } + + @Test + def testControllerQuorumVoterStringsToNodes(): Unit = { + assertThrows(classOf[ConfigException], () => RaftConfig.quorumVoterStringsToNodes(Collections.singletonList(""))) + assertEquals(Seq(new Node(3000, "example.com", 9093)), + RaftConfig.quorumVoterStringsToNodes(util.Arrays.asList("3000@example.com:9093")).asScala.toSeq) + assertEquals(Seq(new Node(3000, "example.com", 9093), + new Node(3001, "example.com", 9094)), + RaftConfig.quorumVoterStringsToNodes(util.Arrays.asList("3000@example.com:9093","3001@example.com:9094")).asScala.toSeq) + } + + @Test + def testInvalidQuorumVoterConfig(): Unit = { + assertInvalidQuorumVoters("1") + assertInvalidQuorumVoters("1@") + assertInvalidQuorumVoters("1:") + assertInvalidQuorumVoters("blah@") + assertInvalidQuorumVoters("1@kafka1") + assertInvalidQuorumVoters("1@kafka1:9092,") + assertInvalidQuorumVoters("1@kafka1:9092,") + assertInvalidQuorumVoters("1@kafka1:9092,2") + assertInvalidQuorumVoters("1@kafka1:9092,2@") + assertInvalidQuorumVoters("1@kafka1:9092,2@blah") + assertInvalidQuorumVoters("1@kafka1:9092,2@blah,") + assertInvalidQuorumVoters("1@kafka1:9092:1@kafka2:9092") + } + + private def assertInvalidQuorumVoters(value: String): Unit = { + val props = TestUtils.createBrokerConfig(0, TestUtils.MockZkConnect) + props.put(KafkaConfig.QuorumVotersProp, value) + assertThrows(classOf[ConfigException], () => KafkaConfig.fromProps(props)) + } + + @Test + def testValidQuorumVotersConfig(): Unit = { + val expected = new util.HashMap[Integer, AddressSpec]() + assertValidQuorumVoters("", expected) + + expected.put(1, new InetAddressSpec(new InetSocketAddress("127.0.0.1", 9092))) + assertValidQuorumVoters("1@127.0.0.1:9092", expected) + + expected.clear() + expected.put(1, UNKNOWN_ADDRESS_SPEC_INSTANCE) + assertValidQuorumVoters("1@0.0.0.0:0", expected) + + expected.clear() + expected.put(1, new InetAddressSpec(new InetSocketAddress("kafka1", 9092))) + expected.put(2, new 
InetAddressSpec(new InetSocketAddress("kafka2", 9092))) + expected.put(3, new InetAddressSpec(new InetSocketAddress("kafka3", 9092))) + assertValidQuorumVoters("1@kafka1:9092,2@kafka2:9092,3@kafka3:9092", expected) + } + + private def assertValidQuorumVoters(value: String, expectedVoters: util.Map[Integer, AddressSpec]): Unit = { + val props = TestUtils.createBrokerConfig(0, TestUtils.MockZkConnect) + props.put(KafkaConfig.QuorumVotersProp, value) + val raftConfig = new RaftConfig(KafkaConfig.fromProps(props)) + assertEquals(expectedVoters, raftConfig.quorumVoterConnections()) + } + + @Test + def testAcceptsLargeNodeIdForRaftBasedCase(): Unit = { + // Generation of Broker IDs is not supported when using Raft-based controller quorums, + // so pick a broker ID greater than reserved.broker.max.id, which defaults to 1000, + // and make sure it is allowed despite broker.id.generation.enable=true (true is the default) + val largeBrokerId = 2000 + val props = new Properties() + props.put(KafkaConfig.ProcessRolesProp, "broker") + props.put(KafkaConfig.ListenersProp, "PLAINTEXT://localhost:9092") + props.put(KafkaConfig.ControllerListenerNamesProp, "SSL") + props.put(KafkaConfig.QuorumVotersProp, "2@localhost:9093") + props.put(KafkaConfig.NodeIdProp, largeBrokerId.toString) + KafkaConfig.fromProps(props) + } + + @Test + def testRejectsNegativeNodeIdForRaftBasedBrokerCaseWithAutoGenEnabled(): Unit = { + // -1 is the default for both node.id and broker.id + val props = new Properties() + props.put(KafkaConfig.ProcessRolesProp, "broker") + assertFalse(isValidKafkaConfig(props)) + } + + @Test + def testRejectsNegativeNodeIdForRaftBasedControllerCaseWithAutoGenEnabled(): Unit = { + // -1 is the default for both node.id and broker.id + val props = new Properties() + props.put(KafkaConfig.ProcessRolesProp, "controller") + assertFalse(isValidKafkaConfig(props)) + } + + @Test + def testRejectsNegativeNodeIdForRaftBasedCaseWithAutoGenDisabled(): Unit = { + // -1 is the default for both node.id and broker.id + val props = new Properties() + props.put(KafkaConfig.ProcessRolesProp, "broker") + props.put(KafkaConfig.BrokerIdGenerationEnableProp, "false") + props.put(KafkaConfig.QuorumVotersProp, "2@localhost:9093") + assertFalse(isValidKafkaConfig(props)) + } + + @Test + def testRejectsLargeNodeIdForZkBasedCaseWithAutoGenEnabled(): Unit = { + // Generation of Broker IDs is supported when using ZooKeeper-based controllers, + // so pick a broker ID greater than reserved.broker.max.id, which defaults to 1000, + // and make sure it is not allowed with broker.id.generation.enable=true (true is the default) + val largeBrokerId = 2000 + val props = TestUtils.createBrokerConfig(largeBrokerId, TestUtils.MockZkConnect, port = TestUtils.MockZkPort) + val listeners = "PLAINTEXT://A:9092,SSL://B:9093,SASL_SSL://C:9094" + props.put(KafkaConfig.ListenersProp, listeners) + props.put(KafkaConfig.AdvertisedListenersProp, listeners) + assertFalse(isValidKafkaConfig(props)) + } + + @Test + def testAcceptsNegativeOneNodeIdForZkBasedCaseWithAutoGenEnabled(): Unit = { + // -1 is the default for both node.id and broker.id; it implies "auto-generate" and should succeed + val props = TestUtils.createBrokerConfig(-1, TestUtils.MockZkConnect, port = TestUtils.MockZkPort) + val listeners = "PLAINTEXT://A:9092,SSL://B:9093,SASL_SSL://C:9094" + props.put(KafkaConfig.ListenersProp, listeners) + props.put(KafkaConfig.AdvertisedListenersProp, listeners) + KafkaConfig.fromProps(props) + } + + @Test + def 
testRejectsNegativeTwoNodeIdForZkBasedCaseWithAutoGenEnabled(): Unit = { + // -1 implies "auto-generate" and should succeed, but -2 does not and should fail + val negativeTwoNodeId = -2 + val props = TestUtils.createBrokerConfig(negativeTwoNodeId, TestUtils.MockZkConnect, port = TestUtils.MockZkPort) + val listeners = "PLAINTEXT://A:9092,SSL://B:9093,SASL_SSL://C:9094" + props.put(KafkaConfig.ListenersProp, listeners) + props.put(KafkaConfig.AdvertisedListenersProp, listeners) + props.put(KafkaConfig.NodeIdProp, negativeTwoNodeId.toString) + props.put(KafkaConfig.BrokerIdProp, negativeTwoNodeId.toString) + assertFalse(isValidKafkaConfig(props)) + } + + @Test + def testAcceptsLargeNodeIdForZkBasedCaseWithAutoGenDisabled(): Unit = { + // Ensure a broker ID greater than reserved.broker.max.id, which defaults to 1000, + // is allowed with broker.id.generation.enable=false + val largeBrokerId = 2000 + val props = TestUtils.createBrokerConfig(largeBrokerId, TestUtils.MockZkConnect, port = TestUtils.MockZkPort) + val listeners = "PLAINTEXT://A:9092,SSL://B:9093,SASL_SSL://C:9094" + props.put(KafkaConfig.ListenersProp, listeners) + props.put(KafkaConfig.AdvertisedListenersProp, listeners) + props.put(KafkaConfig.BrokerIdGenerationEnableProp, "false") + KafkaConfig.fromProps(props) + } + + @Test + def testRejectsNegativeNodeIdForZkBasedCaseWithAutoGenDisabled(): Unit = { + // -1 is the default for both node.id and broker.id + val props = TestUtils.createBrokerConfig(-1, TestUtils.MockZkConnect, port = TestUtils.MockZkPort) + val listeners = "PLAINTEXT://A:9092,SSL://B:9093,SASL_SSL://C:9094" + props.put(KafkaConfig.ListenersProp, listeners) + props.put(KafkaConfig.BrokerIdGenerationEnableProp, "false") + assertFalse(isValidKafkaConfig(props)) + } + + @Test + def testZookeeperConnectRequiredIfEmptyProcessRoles(): Unit = { + val props = new Properties() + props.put(KafkaConfig.ProcessRolesProp, "") + props.put(KafkaConfig.ListenersProp, "PLAINTEXT://127.0.0.1:9092") + assertFalse(isValidKafkaConfig(props)) + } + + @Test + def testZookeeperConnectNotRequiredIfNonEmptyProcessRoles(): Unit = { + val props = new Properties() + props.put(KafkaConfig.ProcessRolesProp, "broker") + props.put(KafkaConfig.ListenersProp, "PLAINTEXT://127.0.0.1:9092") + props.put(KafkaConfig.ControllerListenerNamesProp, "SSL") + props.put(KafkaConfig.NodeIdProp, "1") + props.put(KafkaConfig.QuorumVotersProp, "2@localhost:9093") + KafkaConfig.fromProps(props) + } + + @Test + def testCustomMetadataLogDir(): Unit = { + val metadataDir = "/path/to/metadata/dir" + val dataDir = "/path/to/data/dir" + + val props = new Properties() + props.put(KafkaConfig.ProcessRolesProp, "broker") + props.put(KafkaConfig.ControllerListenerNamesProp, "SSL") + props.put(KafkaConfig.MetadataLogDirProp, metadataDir) + props.put(KafkaConfig.LogDirProp, dataDir) + props.put(KafkaConfig.NodeIdProp, "1") + props.put(KafkaConfig.QuorumVotersProp, "2@localhost:9093") + KafkaConfig.fromProps(props) + + val config = KafkaConfig.fromProps(props) + assertEquals(metadataDir, config.metadataLogDir) + assertEquals(Seq(dataDir), config.logDirs) + } + + @Test + def testDefaultMetadataLogDir(): Unit = { + val dataDir1 = "/path/to/data/dir/1" + val dataDir2 = "/path/to/data/dir/2" + + val props = new Properties() + props.put(KafkaConfig.ProcessRolesProp, "broker") + props.put(KafkaConfig.ControllerListenerNamesProp, "SSL") + props.put(KafkaConfig.LogDirProp, s"$dataDir1,$dataDir2") + props.put(KafkaConfig.NodeIdProp, "1") + props.put(KafkaConfig.QuorumVotersProp, 
"2@localhost:9093") + KafkaConfig.fromProps(props) + + val config = KafkaConfig.fromProps(props) + assertEquals(dataDir1, config.metadataLogDir) + assertEquals(Seq(dataDir1, dataDir2), config.logDirs) + } + + @Test + def testPopulateSynonymsOnEmptyMap(): Unit = { + assertEquals(Collections.emptyMap(), KafkaConfig.populateSynonyms(Collections.emptyMap())) + } + + @Test + def testPopulateSynonymsOnMapWithoutNodeId(): Unit = { + val input = new util.HashMap[String, String]() + input.put(KafkaConfig.BrokerIdProp, "4") + val expectedOutput = new util.HashMap[String, String]() + expectedOutput.put(KafkaConfig.BrokerIdProp, "4") + expectedOutput.put(KafkaConfig.NodeIdProp, "4") + assertEquals(expectedOutput, KafkaConfig.populateSynonyms(input)) + } + + @Test + def testPopulateSynonymsOnMapWithoutBrokerId(): Unit = { + val input = new util.HashMap[String, String]() + input.put(KafkaConfig.NodeIdProp, "4") + val expectedOutput = new util.HashMap[String, String]() + expectedOutput.put(KafkaConfig.BrokerIdProp, "4") + expectedOutput.put(KafkaConfig.NodeIdProp, "4") + assertEquals(expectedOutput, KafkaConfig.populateSynonyms(input)) + } + + @Test + def testNodeIdMustNotBeDifferentThanBrokerId(): Unit = { + val props = new Properties() + props.setProperty(KafkaConfig.BrokerIdProp, "1") + props.setProperty(KafkaConfig.NodeIdProp, "2") + assertEquals("You must set `node.id` to the same value as `broker.id`.", + assertThrows(classOf[ConfigException], () => KafkaConfig.fromProps(props)).getMessage()) + } + + @Test + def testNodeIdOrBrokerIdMustBeSetWithKraft(): Unit = { + val props = new Properties() + props.setProperty(KafkaConfig.ProcessRolesProp, "broker") + props.setProperty(KafkaConfig.QuorumVotersProp, "2@localhost:9093") + assertEquals("Missing configuration `node.id` which is required when `process.roles` " + + "is defined (i.e. 
when running in KRaft mode).", + assertThrows(classOf[ConfigException], () => KafkaConfig.fromProps(props)).getMessage()) + } + + @Test + def testNodeIdIsInferredByBrokerIdWithKraft(): Unit = { + val props = new Properties() + props.setProperty(KafkaConfig.ProcessRolesProp, "broker") + props.put(KafkaConfig.ControllerListenerNamesProp, "SSL") + props.setProperty(KafkaConfig.BrokerIdProp, "3") + props.setProperty(KafkaConfig.QuorumVotersProp, "2@localhost:9093") + val config = KafkaConfig.fromProps(props) + assertEquals(3, config.brokerId) + assertEquals(3, config.nodeId) + val originals = config.originals() + assertEquals("3", originals.get(KafkaConfig.BrokerIdProp)) + assertEquals("3", originals.get(KafkaConfig.NodeIdProp)) + } + + @Test + def testBrokerIdIsInferredByNodeIdWithKraft(): Unit = { + val props = new Properties() + props.setProperty(KafkaConfig.ProcessRolesProp, "broker") + props.put(KafkaConfig.ControllerListenerNamesProp, "SSL") + props.setProperty(KafkaConfig.NodeIdProp, "3") + props.setProperty(KafkaConfig.QuorumVotersProp, "1@localhost:9093") + val config = KafkaConfig.fromProps(props) + assertEquals(3, config.brokerId) + assertEquals(3, config.nodeId) + val originals = config.originals() + assertEquals("3", originals.get(KafkaConfig.BrokerIdProp)) + assertEquals("3", originals.get(KafkaConfig.NodeIdProp)) + } + + @Test + def testSaslJwksEndpointRetryDefaults(): Unit = { + val props = new Properties() + props.put(KafkaConfig.ZkConnectProp, "localhost:2181") + val config = KafkaConfig.fromProps(props) + assertNotNull(config.getLong(KafkaConfig.SaslOAuthBearerJwksEndpointRetryBackoffMsProp)) + assertNotNull(config.getLong(KafkaConfig.SaslOAuthBearerJwksEndpointRetryBackoffMaxMsProp)) + } +} diff --git a/core/src/test/scala/unit/kafka/server/KafkaMetricReporterClusterIdTest.scala b/core/src/test/scala/unit/kafka/server/KafkaMetricReporterClusterIdTest.scala new file mode 100755 index 0000000..26f6a12 --- /dev/null +++ b/core/src/test/scala/unit/kafka/server/KafkaMetricReporterClusterIdTest.scala @@ -0,0 +1,118 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package kafka.server + +import java.util.concurrent.atomic.AtomicReference + +import kafka.metrics.KafkaMetricsReporter +import kafka.utils.{CoreUtils, TestUtils, VerifiableProperties} +import kafka.server.QuorumTestHarness +import org.apache.kafka.common.{ClusterResource, ClusterResourceListener} +import org.apache.kafka.test.MockMetricsReporter +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.{AfterEach, BeforeEach, Test, TestInfo} +import org.apache.kafka.test.TestUtils.isValidClusterId + +object KafkaMetricReporterClusterIdTest { + val setupError = new AtomicReference[String]("") + + class MockKafkaMetricsReporter extends KafkaMetricsReporter with ClusterResourceListener { + + override def onUpdate(clusterMetadata: ClusterResource): Unit = { + MockKafkaMetricsReporter.CLUSTER_META.set(clusterMetadata) + } + + override def init(props: VerifiableProperties): Unit = { + } + } + + object MockKafkaMetricsReporter { + val CLUSTER_META = new AtomicReference[ClusterResource] + } + + object MockBrokerMetricsReporter { + val CLUSTER_META: AtomicReference[ClusterResource] = new AtomicReference[ClusterResource] + } + + class MockBrokerMetricsReporter extends MockMetricsReporter with ClusterResourceListener { + + override def onUpdate(clusterMetadata: ClusterResource): Unit = { + MockBrokerMetricsReporter.CLUSTER_META.set(clusterMetadata) + } + + override def configure(configs: java.util.Map[String, _]): Unit = { + // Check that the configuration passed to the MetricsReporter includes the broker id as an Integer. + // This is a regression test for KAFKA-4756. + // + // Because this code is run during the test setUp phase, if we throw an exception here, + // it just results in the test itself being declared "not found" rather than failing. + // So we track an error message which we will check later in the test body. 
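A reporter outside the test harness that needs the same information would read the key the comment above refers to in its own configure() call; the class below is a hypothetical sketch (not part of this patch) that mirrors the contract the test verifies, namely that broker.id arrives as a String:

    import java.util
    import org.apache.kafka.common.metrics.{KafkaMetric, MetricsReporter}

    // Hypothetical reporter: broker.id (KafkaConfig.BrokerIdProp) is delivered
    // to configure() as a String, so the reporter must parse it itself.
    class BrokerIdAwareReporter extends MetricsReporter {
      @volatile private var brokerId: Option[Int] = None

      def configure(configs: util.Map[String, _]): Unit = {
        brokerId = Option(configs.get("broker.id")).map(_.toString.toInt)
      }

      def configuredBrokerId: Option[Int] = brokerId

      def init(metrics: util.List[KafkaMetric]): Unit = {}
      def metricChange(metric: KafkaMetric): Unit = {}
      def metricRemoval(metric: KafkaMetric): Unit = {}
      def close(): Unit = {}
    }
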
+ val brokerId = configs.get(KafkaConfig.BrokerIdProp) + if (brokerId == null) + setupError.compareAndSet("", "No value was set for the broker id.") + else if (!brokerId.isInstanceOf[String]) + setupError.compareAndSet("", "The value set for the broker id was not a string.") + try + Integer.parseInt(brokerId.asInstanceOf[String]) + catch { + case e: Exception => setupError.compareAndSet("", "Error parsing broker id " + e.toString) + } + } + } +} + +class KafkaMetricReporterClusterIdTest extends QuorumTestHarness { + var server: KafkaServer = null + var config: KafkaConfig = null + + @BeforeEach + override def setUp(testInfo: TestInfo): Unit = { + super.setUp(testInfo) + val props = TestUtils.createBrokerConfig(1, zkConnect) + props.setProperty(KafkaConfig.KafkaMetricsReporterClassesProp, "kafka.server.KafkaMetricReporterClusterIdTest$MockKafkaMetricsReporter") + props.setProperty(KafkaConfig.MetricReporterClassesProp, "kafka.server.KafkaMetricReporterClusterIdTest$MockBrokerMetricsReporter") + props.setProperty(KafkaConfig.BrokerIdGenerationEnableProp, "true") + props.setProperty(KafkaConfig.BrokerIdProp, "-1") + config = KafkaConfig.fromProps(props) + server = new KafkaServer(config, threadNamePrefix = Option(this.getClass.getName)) + server.startup() + } + + @Test + def testClusterIdPresent(): Unit = { + assertEquals("", KafkaMetricReporterClusterIdTest.setupError.get()) + + assertNotNull(KafkaMetricReporterClusterIdTest.MockKafkaMetricsReporter.CLUSTER_META) + isValidClusterId(KafkaMetricReporterClusterIdTest.MockKafkaMetricsReporter.CLUSTER_META.get().clusterId()) + + assertNotNull(KafkaMetricReporterClusterIdTest.MockBrokerMetricsReporter.CLUSTER_META) + isValidClusterId(KafkaMetricReporterClusterIdTest.MockBrokerMetricsReporter.CLUSTER_META.get().clusterId()) + + assertEquals(KafkaMetricReporterClusterIdTest.MockKafkaMetricsReporter.CLUSTER_META.get().clusterId(), + KafkaMetricReporterClusterIdTest.MockBrokerMetricsReporter.CLUSTER_META.get().clusterId()) + + server.shutdown() + TestUtils.assertNoNonDaemonThreads(this.getClass.getName) + } + + @AfterEach + override def tearDown(): Unit = { + server.shutdown() + CoreUtils.delete(config.logDirs) + super.tearDown() + } +} diff --git a/core/src/test/scala/unit/kafka/server/KafkaMetricReporterExceptionHandlingTest.scala b/core/src/test/scala/unit/kafka/server/KafkaMetricReporterExceptionHandlingTest.scala new file mode 100644 index 0000000..c84a91b --- /dev/null +++ b/core/src/test/scala/unit/kafka/server/KafkaMetricReporterExceptionHandlingTest.scala @@ -0,0 +1,119 @@ +/** + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ **/ + +package kafka.server + +import java.net.Socket +import java.util.{Collections, Properties} + +import kafka.utils.TestUtils +import org.apache.kafka.common.config.internals.QuotaConfigs +import org.apache.kafka.common.network.ListenerName +import org.apache.kafka.common.requests.{ListGroupsRequest, ListGroupsResponse} +import org.apache.kafka.common.metrics.MetricsReporter +import org.apache.kafka.common.metrics.KafkaMetric +import org.apache.kafka.common.security.auth.SecurityProtocol +import org.apache.kafka.common.protocol.Errors +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.{AfterEach, BeforeEach, Test, TestInfo} +import java.util.concurrent.atomic.AtomicInteger + +import org.apache.kafka.common.message.ListGroupsRequestData + +/* + * this test checks that a reporter that throws an exception will not affect other reporters + * and will not affect the broker's message handling + */ +class KafkaMetricReporterExceptionHandlingTest extends BaseRequestTest { + + override def brokerCount: Int = 1 + + override def brokerPropertyOverrides(properties: Properties): Unit = { + properties.put(KafkaConfig.MetricReporterClassesProp, classOf[KafkaMetricReporterExceptionHandlingTest.BadReporter].getName + "," + classOf[KafkaMetricReporterExceptionHandlingTest.GoodReporter].getName) + } + + @BeforeEach + override def setUp(testInfo: TestInfo): Unit = { + super.setUp(testInfo) + + // need a quota prop to register a "throttle-time" metrics after server startup + val quotaProps = new Properties() + quotaProps.put(QuotaConfigs.REQUEST_PERCENTAGE_OVERRIDE_CONFIG, "0.1") + adminZkClient.changeClientIdConfig("", quotaProps) + } + + @AfterEach + override def tearDown(): Unit = { + KafkaMetricReporterExceptionHandlingTest.goodReporterRegistered.set(0) + KafkaMetricReporterExceptionHandlingTest.badReporterRegistered.set(0) + + super.tearDown() + } + + @Test + def testBothReportersAreInvoked(): Unit = { + val port = anySocketServer.boundPort(ListenerName.forSecurityProtocol(SecurityProtocol.PLAINTEXT)) + val socket = new Socket("localhost", port) + socket.setSoTimeout(10000) + + try { + TestUtils.retry(10000) { + val listGroupsRequest = new ListGroupsRequest.Builder(new ListGroupsRequestData).build() + val listGroupsResponse = sendAndReceive[ListGroupsResponse](listGroupsRequest, socket) + val errors = listGroupsResponse.errorCounts() + assertEquals(Collections.singletonMap(Errors.NONE, 1), errors) + assertEquals(KafkaMetricReporterExceptionHandlingTest.goodReporterRegistered.get, KafkaMetricReporterExceptionHandlingTest.badReporterRegistered.get) + assertTrue(KafkaMetricReporterExceptionHandlingTest.goodReporterRegistered.get > 0) + } + } finally { + socket.close() + } + } +} + +object KafkaMetricReporterExceptionHandlingTest { + var goodReporterRegistered = new AtomicInteger + var badReporterRegistered = new AtomicInteger + + class GoodReporter extends MetricsReporter { + + def configure(configs: java.util.Map[String, _]): Unit = { + } + + def init(metrics: java.util.List[KafkaMetric]): Unit = { + } + + def metricChange(metric: KafkaMetric): Unit = { + if (metric.metricName.group == "Request") { + goodReporterRegistered.incrementAndGet + } + } + + def metricRemoval(metric: KafkaMetric): Unit = { + } + + def close(): Unit = { + } + } + + class BadReporter extends GoodReporter { + + override def metricChange(metric: KafkaMetric): Unit = { + if (metric.metricName.group == "Request") { + badReporterRegistered.incrementAndGet + throw new RuntimeException(metric.metricName.toString) 
+ } + } + } +} diff --git a/core/src/test/scala/unit/kafka/server/KafkaMetricsReporterTest.scala b/core/src/test/scala/unit/kafka/server/KafkaMetricsReporterTest.scala new file mode 100644 index 0000000..7e5d791 --- /dev/null +++ b/core/src/test/scala/unit/kafka/server/KafkaMetricsReporterTest.scala @@ -0,0 +1,95 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package kafka.server + +import java.util + +import java.util.concurrent.atomic.AtomicReference + +import kafka.utils.{CoreUtils, TestUtils} +import kafka.server.QuorumTestHarness +import org.apache.kafka.common.metrics.{KafkaMetric, MetricsContext, MetricsReporter} +import org.junit.jupiter.api.Assertions.{assertEquals} +import org.junit.jupiter.api.{AfterEach, BeforeEach, Test, TestInfo} +import org.junit.jupiter.api.Assertions._ + + +object KafkaMetricsReporterTest { + val setupError = new AtomicReference[String]("") + + class MockMetricsReporter extends MetricsReporter { + def init(metrics: util.List[KafkaMetric]): Unit = {} + + def metricChange(metric: KafkaMetric): Unit = {} + + def metricRemoval(metric: KafkaMetric): Unit = {} + + override def close(): Unit = {} + + override def contextChange(metricsContext: MetricsContext): Unit = { + //read jmxPrefix + + MockMetricsReporter.JMXPREFIX.set(metricsContext.contextLabels().get("_namespace").toString) + MockMetricsReporter.CLUSTERID.set(metricsContext.contextLabels().get("kafka.cluster.id").toString) + MockMetricsReporter.BROKERID.set(metricsContext.contextLabels().get("kafka.broker.id").toString) + } + + override def configure(configs: util.Map[String, _]): Unit = {} + + } + + object MockMetricsReporter { + val JMXPREFIX: AtomicReference[String] = new AtomicReference[String] + val BROKERID : AtomicReference[String] = new AtomicReference[String] + val CLUSTERID : AtomicReference[String] = new AtomicReference[String] + } +} + +class KafkaMetricsReporterTest extends QuorumTestHarness { + var server: KafkaServer = null + var config: KafkaConfig = null + + @BeforeEach + override def setUp(testInfo: TestInfo): Unit = { + super.setUp(testInfo) + val props = TestUtils.createBrokerConfig(1, zkConnect) + props.setProperty(KafkaConfig.MetricReporterClassesProp, "kafka.server.KafkaMetricsReporterTest$MockMetricsReporter") + props.setProperty(KafkaConfig.BrokerIdGenerationEnableProp, "true") + props.setProperty(KafkaConfig.BrokerIdProp, "-1") + config = KafkaConfig.fromProps(props) + server = new KafkaServer(config, threadNamePrefix = Option(this.getClass.getName)) + server.startup() + } + + @Test + def testMetricsContextNamespacePresent(): Unit = { + assertNotNull(KafkaMetricsReporterTest.MockMetricsReporter.CLUSTERID) + assertNotNull(KafkaMetricsReporterTest.MockMetricsReporter.BROKERID) + 
assertNotNull(KafkaMetricsReporterTest.MockMetricsReporter.JMXPREFIX) + assertEquals("kafka.server", KafkaMetricsReporterTest.MockMetricsReporter.JMXPREFIX.get()) + + server.shutdown() + TestUtils.assertNoNonDaemonThreads(this.getClass.getName) + } + + @AfterEach + override def tearDown(): Unit = { + server.shutdown() + CoreUtils.delete(config.logDirs) + super.tearDown() + } +} diff --git a/core/src/test/scala/unit/kafka/server/KafkaRaftServerTest.scala b/core/src/test/scala/unit/kafka/server/KafkaRaftServerTest.scala new file mode 100644 index 0000000..82ad542 --- /dev/null +++ b/core/src/test/scala/unit/kafka/server/KafkaRaftServerTest.scala @@ -0,0 +1,217 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package kafka.server + +import java.io.File +import java.nio.file.Files +import java.util.Properties +import kafka.common.{InconsistentBrokerMetadataException, InconsistentNodeIdException, KafkaException} +import kafka.log.UnifiedLog +import org.apache.kafka.common.Uuid +import org.apache.kafka.common.utils.Utils +import org.apache.kafka.test.TestUtils +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.Test + +class KafkaRaftServerTest { + private val clusterIdBase64 = "H3KKO4NTRPaCWtEmm3vW7A" + + @Test + def testSuccessfulLoadMetaProperties(): Unit = { + val clusterId = clusterIdBase64 + val nodeId = 0 + val metaProperties = MetaProperties(clusterId, nodeId) + + val configProperties = new Properties + configProperties.put(KafkaConfig.ProcessRolesProp, "broker,controller") + configProperties.put(KafkaConfig.NodeIdProp, nodeId.toString) + configProperties.put(KafkaConfig.ListenersProp, "PLAINTEXT://127.0.0.1:9092,SSL://127.0.0.1:9093") + configProperties.put(KafkaConfig.QuorumVotersProp, s"$nodeId@localhost:9093") + configProperties.put(KafkaConfig.ControllerListenerNamesProp, "SSL") + + val (loadedMetaProperties, offlineDirs) = + invokeLoadMetaProperties(metaProperties, configProperties) + + assertEquals(metaProperties, loadedMetaProperties) + assertEquals(Seq.empty, offlineDirs) + } + + @Test + def testLoadMetaPropertiesWithInconsistentNodeId(): Unit = { + val clusterId = clusterIdBase64 + val metaNodeId = 1 + val configNodeId = 0 + + val metaProperties = MetaProperties(clusterId, metaNodeId) + val configProperties = new Properties + + configProperties.put(KafkaConfig.ProcessRolesProp, "controller") + configProperties.put(KafkaConfig.NodeIdProp, configNodeId.toString) + configProperties.put(KafkaConfig.QuorumVotersProp, s"$configNodeId@localhost:9092") + configProperties.put(KafkaConfig.ControllerListenerNamesProp, "PLAINTEXT") + + assertThrows(classOf[InconsistentNodeIdException], () => + invokeLoadMetaProperties(metaProperties, configProperties)) + } + + private def invokeLoadMetaProperties( + metaProperties: 
MetaProperties, + configProperties: Properties + ): (MetaProperties, collection.Seq[String]) = { + val tempLogDir = TestUtils.tempDirectory() + try { + writeMetaProperties(tempLogDir, metaProperties) + + configProperties.put(KafkaConfig.LogDirProp, tempLogDir.getAbsolutePath) + val config = KafkaConfig.fromProps(configProperties) + KafkaRaftServer.initializeLogDirs(config) + } finally { + Utils.delete(tempLogDir) + } + } + + private def writeMetaProperties( + logDir: File, + metaProperties: MetaProperties + ): Unit = { + val metaPropertiesFile = new File(logDir.getAbsolutePath, "meta.properties") + val checkpoint = new BrokerMetadataCheckpoint(metaPropertiesFile) + checkpoint.write(metaProperties.toProperties) + } + + @Test + def testStartupFailsIfMetaPropertiesMissingInSomeLogDir(): Unit = { + val clusterId = clusterIdBase64 + val nodeId = 1 + + // One log dir is online and has properly formatted `meta.properties`. + // The other is online, but has no `meta.properties`. + val logDir1 = TestUtils.tempDirectory() + val logDir2 = TestUtils.tempDirectory() + writeMetaProperties(logDir1, MetaProperties(clusterId, nodeId)) + + val configProperties = new Properties + configProperties.put(KafkaConfig.ProcessRolesProp, "broker") + configProperties.put(KafkaConfig.NodeIdProp, nodeId.toString) + configProperties.put(KafkaConfig.QuorumVotersProp, s"${(nodeId + 1)}@localhost:9092") + configProperties.put(KafkaConfig.LogDirProp, Seq(logDir1, logDir2).map(_.getAbsolutePath).mkString(",")) + configProperties.put(KafkaConfig.ControllerListenerNamesProp, "SSL") + val config = KafkaConfig.fromProps(configProperties) + + assertThrows(classOf[KafkaException], () => KafkaRaftServer.initializeLogDirs(config)) + } + + @Test + def testStartupFailsIfMetaLogDirIsOffline(): Unit = { + val clusterId = clusterIdBase64 + val nodeId = 1 + + // One log dir is online and has properly formatted `meta.properties` + val validDir = TestUtils.tempDirectory() + writeMetaProperties(validDir, MetaProperties(clusterId, nodeId)) + + // Use a regular file as an invalid log dir to trigger an IO error + val invalidDir = TestUtils.tempFile("blah") + val configProperties = new Properties + configProperties.put(KafkaConfig.ProcessRolesProp, "broker") + configProperties.put(KafkaConfig.QuorumVotersProp, s"${(nodeId + 1)}@localhost:9092") + configProperties.put(KafkaConfig.NodeIdProp, nodeId.toString) + configProperties.put(KafkaConfig.MetadataLogDirProp, invalidDir.getAbsolutePath) + configProperties.put(KafkaConfig.LogDirProp, validDir.getAbsolutePath) + configProperties.put(KafkaConfig.ControllerListenerNamesProp, "SSL") + val config = KafkaConfig.fromProps(configProperties) + + assertThrows(classOf[KafkaException], () => KafkaRaftServer.initializeLogDirs(config)) + } + + @Test + def testStartupDoesNotFailIfDataDirIsOffline(): Unit = { + val clusterId = clusterIdBase64 + val nodeId = 1 + + // One log dir is online and has properly formatted `meta.properties` + val validDir = TestUtils.tempDirectory() + writeMetaProperties(validDir, MetaProperties(clusterId, nodeId)) + + // Use a regular file as an invalid log dir to trigger an IO error + val invalidDir = TestUtils.tempFile("blah") + val configProperties = new Properties + configProperties.put(KafkaConfig.ProcessRolesProp, "broker") + configProperties.put(KafkaConfig.NodeIdProp, nodeId.toString) + configProperties.put(KafkaConfig.QuorumVotersProp, s"${(nodeId + 1)}@localhost:9092") + configProperties.put(KafkaConfig.MetadataLogDirProp, validDir.getAbsolutePath) + 
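Outside the test helpers, the meta.properties bookkeeping these cases exercise reduces to the sketch below (paths are placeholders); initializeLogDirs then requires that every readable log dir, including metadata.log.dir, carries a meta.properties whose cluster.id agrees across directories and whose node.id matches the server config:

    import java.io.File
    import kafka.server.{BrokerMetadataCheckpoint, MetaProperties}

    // Sketch: write the per-directory metadata file that KafkaRaftServer
    // validates on startup.
    val metaProps = MetaProperties(clusterId = "H3KKO4NTRPaCWtEmm3vW7A", nodeId = 1)
    val logDir = new File("/path/to/log/dir")
    new BrokerMetadataCheckpoint(new File(logDir, "meta.properties")).write(metaProps.toProperties)
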
configProperties.put(KafkaConfig.LogDirProp, invalidDir.getAbsolutePath) + configProperties.put(KafkaConfig.ControllerListenerNamesProp, "SSL") + val config = KafkaConfig.fromProps(configProperties) + + val (loadedProperties, offlineDirs) = KafkaRaftServer.initializeLogDirs(config) + assertEquals(nodeId, loadedProperties.nodeId) + assertEquals(Seq(invalidDir.getAbsolutePath), offlineDirs) + } + + @Test + def testStartupFailsIfUnexpectedMetadataDir(): Unit = { + val nodeId = 1 + val clusterId = clusterIdBase64 + + // Create two directories with valid `meta.properties` + val metadataDir = TestUtils.tempDirectory() + val dataDir = TestUtils.tempDirectory() + + Seq(metadataDir, dataDir).foreach { dir => + writeMetaProperties(dir, MetaProperties(clusterId, nodeId)) + } + + // Create the metadata dir in the data directory + Files.createDirectory(new File(dataDir, UnifiedLog.logDirName(KafkaRaftServer.MetadataPartition)).toPath) + + val configProperties = new Properties + configProperties.put(KafkaConfig.ProcessRolesProp, "broker") + configProperties.put(KafkaConfig.NodeIdProp, nodeId.toString) + configProperties.put(KafkaConfig.QuorumVotersProp, s"${(nodeId + 1)}@localhost:9092") + configProperties.put(KafkaConfig.MetadataLogDirProp, metadataDir.getAbsolutePath) + configProperties.put(KafkaConfig.LogDirProp, dataDir.getAbsolutePath) + configProperties.put(KafkaConfig.ControllerListenerNamesProp, "SSL") + val config = KafkaConfig.fromProps(configProperties) + + assertThrows(classOf[KafkaException], () => KafkaRaftServer.initializeLogDirs(config)) + } + + @Test + def testLoadPropertiesWithInconsistentClusterIds(): Unit = { + val nodeId = 1 + val logDir1 = TestUtils.tempDirectory() + val logDir2 = TestUtils.tempDirectory() + + // Create a random clusterId in each log dir + Seq(logDir1, logDir2).foreach { dir => + writeMetaProperties(dir, MetaProperties(clusterId = Uuid.randomUuid().toString, nodeId)) + } + + val configProperties = new Properties + configProperties.put(KafkaConfig.ProcessRolesProp, "broker") + configProperties.put(KafkaConfig.QuorumVotersProp, s"${(nodeId + 1)}@localhost:9092") + configProperties.put(KafkaConfig.NodeIdProp, nodeId.toString) + configProperties.put(KafkaConfig.LogDirProp, Seq(logDir1, logDir2).map(_.getAbsolutePath).mkString(",")) + configProperties.put(KafkaConfig.ControllerListenerNamesProp, "SSL") + val config = KafkaConfig.fromProps(configProperties) + + assertThrows(classOf[InconsistentBrokerMetadataException], + () => KafkaRaftServer.initializeLogDirs(config)) + } + +} diff --git a/core/src/test/scala/unit/kafka/server/KafkaServerTest.scala b/core/src/test/scala/unit/kafka/server/KafkaServerTest.scala new file mode 100755 index 0000000..6ab6930 --- /dev/null +++ b/core/src/test/scala/unit/kafka/server/KafkaServerTest.scala @@ -0,0 +1,138 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.server + +import kafka.api.ApiVersion +import kafka.utils.TestUtils +import kafka.server.QuorumTestHarness +import org.junit.jupiter.api.Assertions.{assertEquals, assertNull, assertThrows, fail} +import org.junit.jupiter.api.Test + +import java.util.Properties + +class KafkaServerTest extends QuorumTestHarness { + + @Test + def testAlreadyRegisteredAdvertisedListeners(): Unit = { + //start a server with a advertised listener + val server1 = createServer(1, "myhost", TestUtils.RandomPort) + + //start a server with same advertised listener + assertThrows(classOf[IllegalArgumentException], () => createServer(2, "myhost", TestUtils.boundPort(server1))) + + //start a server with same host but with different port + val server2 = createServer(2, "myhost", TestUtils.RandomPort) + + TestUtils.shutdownServers(Seq(server1, server2)) + } + + @Test + def testCreatesProperZkTlsConfigWhenDisabled(): Unit = { + val props = new Properties + props.put(KafkaConfig.ZkConnectProp, zkConnect) // required, otherwise we would leave it out + props.put(KafkaConfig.ZkSslClientEnableProp, "false") + val zkClientConfig = KafkaServer.zkClientConfigFromKafkaConfig(KafkaConfig.fromProps(props)) + KafkaConfig.ZkSslConfigToSystemPropertyMap.keys.foreach { propName => + assertNull(zkClientConfig.getProperty(propName)) + } + } + + @Test + def testCreatesProperZkTlsConfigWithTrueValues(): Unit = { + val props = new Properties + props.put(KafkaConfig.ZkConnectProp, zkConnect) // required, otherwise we would leave it out + // should get correct config for all properties if TLS is enabled + val someValue = "some_value" + def kafkaConfigValueToSet(kafkaProp: String) : String = kafkaProp match { + case KafkaConfig.ZkSslClientEnableProp | KafkaConfig.ZkSslCrlEnableProp | KafkaConfig.ZkSslOcspEnableProp => "true" + case KafkaConfig.ZkSslEndpointIdentificationAlgorithmProp => "HTTPS" + case _ => someValue + } + KafkaConfig.ZkSslConfigToSystemPropertyMap.keys.foreach(kafkaProp => props.put(kafkaProp, kafkaConfigValueToSet(kafkaProp))) + val zkClientConfig = KafkaServer.zkClientConfigFromKafkaConfig(KafkaConfig.fromProps(props)) + // now check to make sure the values were set correctly + def zkClientValueToExpect(kafkaProp: String) : String = kafkaProp match { + case KafkaConfig.ZkSslClientEnableProp | KafkaConfig.ZkSslCrlEnableProp | KafkaConfig.ZkSslOcspEnableProp => "true" + case KafkaConfig.ZkSslEndpointIdentificationAlgorithmProp => "true" + case _ => someValue + } + KafkaConfig.ZkSslConfigToSystemPropertyMap.keys.foreach(kafkaProp => + assertEquals(zkClientValueToExpect(kafkaProp), zkClientConfig.getProperty(KafkaConfig.ZkSslConfigToSystemPropertyMap(kafkaProp)))) + } + + @Test + def testCreatesProperZkTlsConfigWithFalseAndListValues(): Unit = { + val props = new Properties + props.put(KafkaConfig.ZkConnectProp, zkConnect) // required, otherwise we would leave it out + // should get correct config for all properties if TLS is enabled + val someValue = "some_value" + def kafkaConfigValueToSet(kafkaProp: String) : String = kafkaProp match { + case KafkaConfig.ZkSslClientEnableProp => "true" + case KafkaConfig.ZkSslCrlEnableProp | KafkaConfig.ZkSslOcspEnableProp => "false" + case KafkaConfig.ZkSslEndpointIdentificationAlgorithmProp => "" + case KafkaConfig.ZkSslEnabledProtocolsProp | KafkaConfig.ZkSslCipherSuitesProp => "A,B" + case _ => someValue + } + 
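The mapping these ZooKeeper TLS tests verify can also be driven directly; a minimal sketch (the connect string is a placeholder) showing how a broker-side zookeeper.ssl.* key is looked up under its ZooKeeper client system-property name:

    import java.util.Properties
    import kafka.server.{KafkaConfig, KafkaServer}

    // Sketch: build a KafkaConfig, derive the ZooKeeper client config, and
    // look up one zookeeper.ssl.* setting under the system-property name
    // recorded in ZkSslConfigToSystemPropertyMap.
    val props = new Properties()
    props.put(KafkaConfig.ZkConnectProp, "localhost:2181")
    props.put(KafkaConfig.ZkSslClientEnableProp, "false")
    val zkClientConfig = KafkaServer.zkClientConfigFromKafkaConfig(KafkaConfig.fromProps(props))
    val clientSecureKey = KafkaConfig.ZkSslConfigToSystemPropertyMap(KafkaConfig.ZkSslClientEnableProp)
    println(s"$clientSecureKey -> ${zkClientConfig.getProperty(clientSecureKey)}")
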
KafkaConfig.ZkSslConfigToSystemPropertyMap.keys.foreach(kafkaProp => props.put(kafkaProp, kafkaConfigValueToSet(kafkaProp))) + val zkClientConfig = KafkaServer.zkClientConfigFromKafkaConfig(KafkaConfig.fromProps(props)) + // now check to make sure the values were set correctly + def zkClientValueToExpect(kafkaProp: String) : String = kafkaProp match { + case KafkaConfig.ZkSslClientEnableProp => "true" + case KafkaConfig.ZkSslCrlEnableProp | KafkaConfig.ZkSslOcspEnableProp => "false" + case KafkaConfig.ZkSslEndpointIdentificationAlgorithmProp => "false" + case KafkaConfig.ZkSslEnabledProtocolsProp | KafkaConfig.ZkSslCipherSuitesProp => "A,B" + case _ => someValue + } + KafkaConfig.ZkSslConfigToSystemPropertyMap.keys.foreach(kafkaProp => + assertEquals(zkClientValueToExpect(kafkaProp), zkClientConfig.getProperty(KafkaConfig.ZkSslConfigToSystemPropertyMap(kafkaProp)))) + } + + @Test + def testZkIsrManager(): Unit = { + val props = TestUtils.createBrokerConfigs(1, zkConnect).head + props.put(KafkaConfig.InterBrokerProtocolVersionProp, "2.7-IV1") + + val server = TestUtils.createServer(KafkaConfig.fromProps(props)) + server.replicaManager.alterIsrManager match { + case _: ZkIsrManager => + case _ => fail("Should use ZK for ISR manager in versions before 2.7-IV2") + } + server.shutdown() + } + + @Test + def testAlterIsrManager(): Unit = { + val props = TestUtils.createBrokerConfigs(1, zkConnect).head + props.put(KafkaConfig.InterBrokerProtocolVersionProp, ApiVersion.latestVersion.toString) + + val server = TestUtils.createServer(KafkaConfig.fromProps(props)) + server.replicaManager.alterIsrManager match { + case _: DefaultAlterIsrManager => + case _ => fail("Should use AlterIsr for ISR manager in versions after 2.7-IV2") + } + server.shutdown() + } + + def createServer(nodeId: Int, hostName: String, port: Int): KafkaServer = { + val props = TestUtils.createBrokerConfig(nodeId, zkConnect) + props.put(KafkaConfig.AdvertisedListenersProp, s"PLAINTEXT://$hostName:$port") + val kafkaConfig = KafkaConfig.fromProps(props) + TestUtils.createServer(kafkaConfig) + } + +} diff --git a/core/src/test/scala/unit/kafka/server/LeaderElectionTest.scala b/core/src/test/scala/unit/kafka/server/LeaderElectionTest.scala new file mode 100755 index 0000000..a1fb7cd --- /dev/null +++ b/core/src/test/scala/unit/kafka/server/LeaderElectionTest.scala @@ -0,0 +1,179 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package kafka.server + +import java.util.Collections + +import org.apache.kafka.common.{TopicPartition, Uuid} + +import scala.jdk.CollectionConverters._ +import kafka.api.LeaderAndIsr +import org.apache.kafka.common.requests._ +import org.junit.jupiter.api.Assertions._ +import kafka.utils.TestUtils +import kafka.cluster.Broker +import kafka.controller.{ControllerChannelManager, ControllerContext, StateChangeLogger} +import kafka.utils.TestUtils._ +import kafka.server.QuorumTestHarness +import org.apache.kafka.common.message.LeaderAndIsrRequestData.LeaderAndIsrPartitionState +import org.apache.kafka.common.metrics.Metrics +import org.apache.kafka.common.network.ListenerName +import org.apache.kafka.common.protocol.{ApiKeys, Errors} +import org.apache.kafka.common.security.auth.SecurityProtocol +import org.apache.kafka.common.utils.Time +import org.junit.jupiter.api.{AfterEach, BeforeEach, Test, TestInfo} + +class LeaderElectionTest extends QuorumTestHarness { + val brokerId1 = 0 + val brokerId2 = 1 + + var servers: Seq[KafkaServer] = Seq.empty[KafkaServer] + + var staleControllerEpochDetected = false + + @BeforeEach + override def setUp(testInfo: TestInfo): Unit = { + super.setUp(testInfo) + + val configProps1 = TestUtils.createBrokerConfig(brokerId1, zkConnect, enableControlledShutdown = false) + val configProps2 = TestUtils.createBrokerConfig(brokerId2, zkConnect, enableControlledShutdown = false) + + configProps1.put("unclean.leader.election.enable", "true") + configProps2.put("unclean.leader.election.enable", "true") + + // start both servers + val server1 = TestUtils.createServer(KafkaConfig.fromProps(configProps1)) + val server2 = TestUtils.createServer(KafkaConfig.fromProps(configProps2)) + servers ++= List(server1, server2) + } + + @AfterEach + override def tearDown(): Unit = { + TestUtils.shutdownServers(servers) + super.tearDown() + } + + @Test + def testLeaderElectionAndEpoch(): Unit = { + // start 2 brokers + val topic = "new-topic" + val partitionId = 0 + + TestUtils.waitUntilBrokerMetadataIsPropagated(servers) + + // create topic with 1 partition, 2 replicas, one on each broker + val leader1 = createTopic(zkClient, topic, partitionReplicaAssignment = Map(0 -> Seq(0, 1)), servers = servers)(0) + + val leaderEpoch1 = zkClient.getEpochForPartition(new TopicPartition(topic, partitionId)).get + assertTrue(leader1 == 0, "Leader should be broker 0") + assertEquals(0, leaderEpoch1, "First epoch value should be 0") + + // kill the server hosting the preferred replica/initial leader + servers.head.shutdown() + // check if leader moves to the other server + val leader2 = waitUntilLeaderIsElectedOrChanged(zkClient, topic, partitionId, oldLeaderOpt = Some(leader1)) + val leaderEpoch2 = zkClient.getEpochForPartition(new TopicPartition(topic, partitionId)).get + assertEquals(1, leader2, "Leader must move to broker 1") + // new leaderEpoch will be leaderEpoch1+2, one increment during ReplicaStateMachine.startup()-> handleStateChanges + // for offline replica and one increment during PartitionStateMachine.triggerOnlinePartitionStateChange() + assertEquals(leaderEpoch1 + 2 , leaderEpoch2, "Second epoch value should be %d".format(leaderEpoch1 + 2)) + + servers.head.startup() + //make sure second server joins the ISR + TestUtils.waitUntilTrue(() => { + servers.last.metadataCache.getPartitionInfo(topic, partitionId).exists(_.isr.size == 2) + }, "Inconsistent metadata after second broker startup") + + servers.last.shutdown() + + Thread.sleep(zookeeper.tickTime) + val leader3 = 
waitUntilLeaderIsElectedOrChanged(zkClient, topic, partitionId, oldLeaderOpt = Some(leader2)) + val leaderEpoch3 = zkClient.getEpochForPartition(new TopicPartition(topic, partitionId)).get + assertEquals(0, leader3, "Leader must return to 0") + assertEquals(leaderEpoch2 + 2 , leaderEpoch3, "Second epoch value should be %d".format(leaderEpoch2 + 2)) + } + + @Test + def testLeaderElectionWithStaleControllerEpoch(): Unit = { + // start 2 brokers + val topic = "new-topic" + val partitionId = 0 + + // create topic with 1 partition, 2 replicas, one on each broker + val leader1 = createTopic(zkClient, topic, partitionReplicaAssignment = Map(0 -> Seq(0, 1)), servers = servers)(0) + + val leaderEpoch1 = zkClient.getEpochForPartition(new TopicPartition(topic, partitionId)).get + debug("leader Epoch: " + leaderEpoch1) + debug("Leader is elected to be: %s".format(leader1)) + // NOTE: this is to avoid transient test failures + assertTrue(leader1 == 0 || leader1 == 1, "Leader could be broker 0 or broker 1") + assertEquals(0, leaderEpoch1, "First epoch value should be 0") + + // start another controller + val controllerId = 2 + + val controllerConfig = KafkaConfig.fromProps(TestUtils.createBrokerConfig(controllerId, zkConnect)) + val securityProtocol = SecurityProtocol.PLAINTEXT + val listenerName = ListenerName.forSecurityProtocol(securityProtocol) + val brokerAndEpochs = servers.map(s => + (new Broker(s.config.brokerId, "localhost", TestUtils.boundPort(s), listenerName, securityProtocol), + s.kafkaController.brokerEpoch)).toMap + val nodes = brokerAndEpochs.keys.map(_.node(listenerName)) + + val controllerContext = new ControllerContext + controllerContext.setLiveBrokers(brokerAndEpochs) + val metrics = new Metrics + val controllerChannelManager = new ControllerChannelManager(controllerContext, controllerConfig, Time.SYSTEM, + metrics, new StateChangeLogger(controllerId, inControllerContext = true, None)) + controllerChannelManager.startup() + try { + val staleControllerEpoch = 0 + val partitionStates = Seq( + new LeaderAndIsrPartitionState() + .setTopicName(topic) + .setPartitionIndex(partitionId) + .setControllerEpoch(2) + .setLeader(brokerId2) + .setLeaderEpoch(LeaderAndIsr.initialLeaderEpoch) + .setIsr(Seq(brokerId1, brokerId2).map(Integer.valueOf).asJava) + .setZkVersion(LeaderAndIsr.initialZKVersion) + .setReplicas(Seq(0, 1).map(Integer.valueOf).asJava) + .setIsNew(false) + ) + val requestBuilder = new LeaderAndIsrRequest.Builder( + ApiKeys.LEADER_AND_ISR.latestVersion, controllerId, staleControllerEpoch, + servers(brokerId2).kafkaController.brokerEpoch, partitionStates.asJava, + Collections.singletonMap(topic, Uuid.randomUuid()), nodes.toSet.asJava) + + controllerChannelManager.sendRequest(brokerId2, requestBuilder, staleControllerEpochCallback) + TestUtils.waitUntilTrue(() => staleControllerEpochDetected, "Controller epoch should be stale") + assertTrue(staleControllerEpochDetected, "Stale controller epoch not detected by the broker") + } finally { + controllerChannelManager.shutdown() + metrics.close() + } + } + + private def staleControllerEpochCallback(response: AbstractResponse): Unit = { + val leaderAndIsrResponse = response.asInstanceOf[LeaderAndIsrResponse] + staleControllerEpochDetected = leaderAndIsrResponse.error match { + case Errors.STALE_CONTROLLER_EPOCH => true + case _ => false + } + } +} diff --git a/core/src/test/scala/unit/kafka/server/ListOffsetsRequestTest.scala b/core/src/test/scala/unit/kafka/server/ListOffsetsRequestTest.scala new file mode 100644 index 0000000..1988ad6 
--- /dev/null +++ b/core/src/test/scala/unit/kafka/server/ListOffsetsRequestTest.scala @@ -0,0 +1,248 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package kafka.server + +import kafka.utils.TestUtils +import org.apache.kafka.common.message.ListOffsetsRequestData.{ListOffsetsPartition, ListOffsetsTopic} +import org.apache.kafka.common.message.ListOffsetsResponseData.ListOffsetsPartitionResponse +import org.apache.kafka.common.protocol.{ApiKeys, Errors} +import org.apache.kafka.common.requests.{ListOffsetsRequest, ListOffsetsResponse} +import org.apache.kafka.common.{IsolationLevel, TopicPartition} +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.Test + +import java.util.Optional +import scala.jdk.CollectionConverters._ + +class ListOffsetsRequestTest extends BaseRequestTest { + + val topic = "topic" + val partition = new TopicPartition(topic, 0) + + @Test + def testListOffsetsErrorCodes(): Unit = { + val targetTimes = List(new ListOffsetsTopic() + .setName(topic) + .setPartitions(List(new ListOffsetsPartition() + .setPartitionIndex(partition.partition) + .setTimestamp(ListOffsetsRequest.EARLIEST_TIMESTAMP) + .setCurrentLeaderEpoch(0)).asJava)).asJava + + val consumerRequest = ListOffsetsRequest.Builder + .forConsumer(false, IsolationLevel.READ_UNCOMMITTED, false) + .setTargetTimes(targetTimes) + .build() + + val replicaRequest = ListOffsetsRequest.Builder + .forReplica(ApiKeys.LIST_OFFSETS.latestVersion, servers.head.config.brokerId) + .setTargetTimes(targetTimes) + .build() + + val debugReplicaRequest = ListOffsetsRequest.Builder + .forReplica(ApiKeys.LIST_OFFSETS.latestVersion, ListOffsetsRequest.DEBUGGING_REPLICA_ID) + .setTargetTimes(targetTimes) + .build() + + // Unknown topic + val randomBrokerId = servers.head.config.brokerId + assertResponseError(Errors.UNKNOWN_TOPIC_OR_PARTITION, randomBrokerId, consumerRequest) + assertResponseError(Errors.UNKNOWN_TOPIC_OR_PARTITION, randomBrokerId, replicaRequest) + assertResponseError(Errors.UNKNOWN_TOPIC_OR_PARTITION, randomBrokerId, debugReplicaRequest) + + val partitionToLeader = TestUtils.createTopic(zkClient, topic, numPartitions = 1, replicationFactor = 2, servers) + val replicas = zkClient.getReplicasForPartition(partition).toSet + val leader = partitionToLeader(partition.partition) + val follower = replicas.find(_ != leader).get + val nonReplica = servers.map(_.config.brokerId).find(!replicas.contains(_)).get + + // Follower + assertResponseError(Errors.NOT_LEADER_OR_FOLLOWER, follower, consumerRequest) + assertResponseError(Errors.NOT_LEADER_OR_FOLLOWER, follower, replicaRequest) + assertResponseError(Errors.NONE, follower, debugReplicaRequest) + + // Non-replica + assertResponseError(Errors.NOT_LEADER_OR_FOLLOWER, nonReplica, consumerRequest) + 
assertResponseError(Errors.NOT_LEADER_OR_FOLLOWER, nonReplica, replicaRequest) + assertResponseError(Errors.NOT_LEADER_OR_FOLLOWER, nonReplica, debugReplicaRequest) + } + + @Test + def testListOffsetsMaxTimeStampOldestVersion(): Unit = { + val consumerRequestBuilder = ListOffsetsRequest.Builder + .forConsumer(false, IsolationLevel.READ_UNCOMMITTED, false) + + val maxTimestampRequestBuilder = ListOffsetsRequest.Builder + .forConsumer(false, IsolationLevel.READ_UNCOMMITTED, true) + + assertEquals(0.toShort, consumerRequestBuilder.oldestAllowedVersion()) + assertEquals(7.toShort, maxTimestampRequestBuilder.oldestAllowedVersion()) + } + + def assertResponseErrorForEpoch(error: Errors, brokerId: Int, currentLeaderEpoch: Optional[Integer]): Unit = { + val listOffsetPartition = new ListOffsetsPartition() + .setPartitionIndex(partition.partition) + .setTimestamp(ListOffsetsRequest.EARLIEST_TIMESTAMP) + if (currentLeaderEpoch.isPresent) + listOffsetPartition.setCurrentLeaderEpoch(currentLeaderEpoch.get) + val targetTimes = List(new ListOffsetsTopic() + .setName(topic) + .setPartitions(List(listOffsetPartition).asJava)).asJava + val request = ListOffsetsRequest.Builder + .forConsumer(false, IsolationLevel.READ_UNCOMMITTED, false) + .setTargetTimes(targetTimes) + .build() + assertResponseError(error, brokerId, request) + } + + @Test + def testCurrentEpochValidation(): Unit = { + val topic = "topic" + val topicPartition = new TopicPartition(topic, 0) + val partitionToLeader = TestUtils.createTopic(zkClient, topic, numPartitions = 1, replicationFactor = 3, servers) + val firstLeaderId = partitionToLeader(topicPartition.partition) + + // We need a leader change in order to check epoch fencing since the first epoch is 0 and + // -1 is treated as having no epoch at all + killBroker(firstLeaderId) + + // Check leader error codes + val secondLeaderId = TestUtils.awaitLeaderChange(servers, topicPartition, firstLeaderId) + val secondLeaderEpoch = TestUtils.findLeaderEpoch(secondLeaderId, topicPartition, servers) + assertResponseErrorForEpoch(Errors.NONE, secondLeaderId, Optional.empty()) + assertResponseErrorForEpoch(Errors.NONE, secondLeaderId, Optional.of(secondLeaderEpoch)) + assertResponseErrorForEpoch(Errors.FENCED_LEADER_EPOCH, secondLeaderId, Optional.of(secondLeaderEpoch - 1)) + assertResponseErrorForEpoch(Errors.UNKNOWN_LEADER_EPOCH, secondLeaderId, Optional.of(secondLeaderEpoch + 1)) + + // Check follower error codes + val followerId = TestUtils.findFollowerId(topicPartition, servers) + assertResponseErrorForEpoch(Errors.NOT_LEADER_OR_FOLLOWER, followerId, Optional.empty()) + assertResponseErrorForEpoch(Errors.NOT_LEADER_OR_FOLLOWER, followerId, Optional.of(secondLeaderEpoch)) + assertResponseErrorForEpoch(Errors.UNKNOWN_LEADER_EPOCH, followerId, Optional.of(secondLeaderEpoch + 1)) + assertResponseErrorForEpoch(Errors.FENCED_LEADER_EPOCH, followerId, Optional.of(secondLeaderEpoch - 1)) + } + + private[this] def sendRequest(serverId: Int, + timestamp: Long, + version: Short): ListOffsetsPartitionResponse = { + val targetTimes = List(new ListOffsetsTopic() + .setName(topic) + .setPartitions(List(new ListOffsetsPartition() + .setPartitionIndex(partition.partition) + .setTimestamp(timestamp)).asJava)).asJava + + val builder = ListOffsetsRequest.Builder + .forConsumer(false, IsolationLevel.READ_UNCOMMITTED, false) + .setTargetTimes(targetTimes) + + val request = if (version == -1) builder.build() else builder.build(version) + + sendRequest(serverId, request).topics.asScala.find(_.name == topic).get + 
.partitions.asScala.find(_.partitionIndex == partition.partition).get
+  }
+
+  // a version of -1 indicates "latest version"
+  private[this] def fetchOffsetAndEpoch(serverId: Int,
+                                        timestamp: Long,
+                                        version: Short): (Long, Int) = {
+    val partitionData = sendRequest(serverId, timestamp, version)
+
+    if (version == 0) {
+      if (partitionData.oldStyleOffsets().isEmpty)
+        (-1, partitionData.leaderEpoch)
+      else
+        (partitionData.oldStyleOffsets().asScala.head, partitionData.leaderEpoch)
+    } else
+      (partitionData.offset, partitionData.leaderEpoch)
+  }
+
+  @Test
+  def testResponseIncludesLeaderEpoch(): Unit = {
+    val partitionToLeader = TestUtils.createTopic(zkClient, topic, numPartitions = 1, replicationFactor = 3, servers)
+    val firstLeaderId = partitionToLeader(partition.partition)
+
+    TestUtils.generateAndProduceMessages(servers, topic, 9)
+    TestUtils.produceMessage(servers, topic, "test-10", System.currentTimeMillis() + 10L)
+
+    assertEquals((0L, 0), fetchOffsetAndEpoch(firstLeaderId, 0L, -1))
+    assertEquals((0L, 0), fetchOffsetAndEpoch(firstLeaderId, ListOffsetsRequest.EARLIEST_TIMESTAMP, -1))
+    assertEquals((10L, 0), fetchOffsetAndEpoch(firstLeaderId, ListOffsetsRequest.LATEST_TIMESTAMP, -1))
+    assertEquals((9L, 0), fetchOffsetAndEpoch(firstLeaderId, ListOffsetsRequest.MAX_TIMESTAMP, -1))
+
+    // Kill the first leader so that we can verify the epoch change when fetching the latest offset
+    killBroker(firstLeaderId)
+    val secondLeaderId = TestUtils.awaitLeaderChange(servers, partition, firstLeaderId)
+    // make sure the high watermark of the new leader has caught up
+    TestUtils.waitUntilTrue(() => sendRequest(secondLeaderId, 0L, -1).errorCode() != Errors.OFFSET_NOT_AVAILABLE.code(),
+      "the high watermark of the second leader did not catch up")
+    val secondLeaderEpoch = TestUtils.findLeaderEpoch(secondLeaderId, partition, servers)
+
+    // No changes to written data
+    assertEquals((0L, 0), fetchOffsetAndEpoch(secondLeaderId, 0L, -1))
+    assertEquals((0L, 0), fetchOffsetAndEpoch(secondLeaderId, ListOffsetsRequest.EARLIEST_TIMESTAMP, -1))
+
+    // The latest offset reflects the updated epoch
+    assertEquals((10L, secondLeaderEpoch), fetchOffsetAndEpoch(secondLeaderId, ListOffsetsRequest.LATEST_TIMESTAMP, -1))
+    assertEquals((9L, secondLeaderEpoch), fetchOffsetAndEpoch(secondLeaderId, ListOffsetsRequest.MAX_TIMESTAMP, -1))
+  }
+
+  @Test
+  def testResponseDefaultOffsetAndLeaderEpochForAllVersions(): Unit = {
+    val partitionToLeader = TestUtils.createTopic(zkClient, topic, numPartitions = 1, replicationFactor = 3, servers)
+    val firstLeaderId = partitionToLeader(partition.partition)
+
+    TestUtils.generateAndProduceMessages(servers, topic, 9)
+    TestUtils.produceMessage(servers, topic, "test-10", System.currentTimeMillis() + 10L)
+
+    for (version <- ApiKeys.LIST_OFFSETS.oldestVersion to ApiKeys.LIST_OFFSETS.latestVersion) {
+      if (version == 0) {
+        assertEquals((-1L, -1), fetchOffsetAndEpoch(firstLeaderId, 0L, version.toShort))
+        assertEquals((0L, -1), fetchOffsetAndEpoch(firstLeaderId, ListOffsetsRequest.EARLIEST_TIMESTAMP, version.toShort))
+        assertEquals((10L, -1), fetchOffsetAndEpoch(firstLeaderId, ListOffsetsRequest.LATEST_TIMESTAMP, version.toShort))
+      } else if (version >= 1 && version <= 3) {
+        assertEquals((0L, -1), fetchOffsetAndEpoch(firstLeaderId, 0L, version.toShort))
+        assertEquals((0L, -1), fetchOffsetAndEpoch(firstLeaderId, ListOffsetsRequest.EARLIEST_TIMESTAMP, 
version.toShort)) + assertEquals((10L, -1), fetchOffsetAndEpoch(firstLeaderId, ListOffsetsRequest.LATEST_TIMESTAMP, version.toShort)) + } else if (version >= 4 && version <= 6) { + assertEquals((0L, 0), fetchOffsetAndEpoch(firstLeaderId, 0L, version.toShort)) + assertEquals((0L, 0), fetchOffsetAndEpoch(firstLeaderId, ListOffsetsRequest.EARLIEST_TIMESTAMP, version.toShort)) + assertEquals((10L, 0), fetchOffsetAndEpoch(firstLeaderId, ListOffsetsRequest.LATEST_TIMESTAMP, version.toShort)) + } else if (version >= 7) { + assertEquals((0L, 0), fetchOffsetAndEpoch(firstLeaderId, 0L, version.toShort)) + assertEquals((0L, 0), fetchOffsetAndEpoch(firstLeaderId, ListOffsetsRequest.EARLIEST_TIMESTAMP, version.toShort)) + assertEquals((10L, 0), fetchOffsetAndEpoch(firstLeaderId, ListOffsetsRequest.LATEST_TIMESTAMP, version.toShort)) + assertEquals((9L, 0), fetchOffsetAndEpoch(firstLeaderId, ListOffsetsRequest.MAX_TIMESTAMP, version.toShort)) + } + } + } + + private def assertResponseError(error: Errors, brokerId: Int, request: ListOffsetsRequest): Unit = { + val response = sendRequest(brokerId, request) + assertEquals(request.topics.size, response.topics.size) + response.topics.asScala.foreach { topic => + topic.partitions.asScala.foreach { partition => + assertEquals(error.code, partition.errorCode) + } + } + } + + private def sendRequest(leaderId: Int, request: ListOffsetsRequest): ListOffsetsResponse = { + connectAndReceive[ListOffsetsResponse](request, destination = brokerSocketServer(leaderId)) + } +} diff --git a/core/src/test/scala/unit/kafka/server/LogDirFailureTest.scala b/core/src/test/scala/unit/kafka/server/LogDirFailureTest.scala new file mode 100644 index 0000000..1025d7a --- /dev/null +++ b/core/src/test/scala/unit/kafka/server/LogDirFailureTest.scala @@ -0,0 +1,205 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package kafka.server + +import java.io.File +import java.util.Collections +import java.util.concurrent.{ExecutionException, TimeUnit} +import kafka.api.IntegrationTestHarness +import kafka.controller.{OfflineReplica, PartitionAndReplica} +import kafka.utils.TestUtils.{Checkpoint, LogDirFailureType, Roll} +import kafka.utils.{CoreUtils, Exit, TestUtils} +import org.apache.kafka.clients.consumer.KafkaConsumer +import org.apache.kafka.clients.producer.{ProducerConfig, ProducerRecord} +import org.apache.kafka.common.TopicPartition +import org.apache.kafka.common.errors.{KafkaStorageException, NotLeaderOrFollowerException} +import org.apache.kafka.common.utils.Utils +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.{BeforeEach, Test, TestInfo} + +import scala.annotation.nowarn +import scala.jdk.CollectionConverters._ + +/** + * Test whether clients can produce and consume when there is log directory failure + */ +class LogDirFailureTest extends IntegrationTestHarness { + + val producerCount: Int = 1 + val consumerCount: Int = 1 + val brokerCount: Int = 2 + private val topic = "topic" + private val partitionNum = 12 + override val logDirCount = 3 + + this.serverConfig.setProperty(KafkaConfig.ReplicaHighWatermarkCheckpointIntervalMsProp, "60000") + this.serverConfig.setProperty(KafkaConfig.NumReplicaFetchersProp, "1") + + @BeforeEach + override def setUp(testInfo: TestInfo): Unit = { + super.setUp(testInfo) + createTopic(topic, partitionNum, brokerCount) + } + + @Test + def testProduceErrorFromFailureOnLogRoll(): Unit = { + testProduceErrorsFromLogDirFailureOnLeader(Roll) + } + + @Test + def testIOExceptionDuringLogRoll(): Unit = { + testProduceAfterLogDirFailureOnLeader(Roll) + } + + // Broker should halt on any log directory failure if inter-broker protocol < 1.0 + @nowarn("cat=deprecation") + @Test + def brokerWithOldInterBrokerProtocolShouldHaltOnLogDirFailure(): Unit = { + @volatile var statusCodeOption: Option[Int] = None + Exit.setHaltProcedure { (statusCode, _) => + statusCodeOption = Some(statusCode) + throw new IllegalArgumentException + } + + var server: KafkaServer = null + try { + val props = TestUtils.createBrokerConfig(brokerCount, zkConnect, logDirCount = 3) + props.put(KafkaConfig.InterBrokerProtocolVersionProp, "0.11.0") + props.put(KafkaConfig.LogMessageFormatVersionProp, "0.11.0") + val kafkaConfig = KafkaConfig.fromProps(props) + val logDir = new File(kafkaConfig.logDirs.head) + // Make log directory of the partition on the leader broker inaccessible by replacing it with a file + CoreUtils.swallow(Utils.delete(logDir), this) + logDir.createNewFile() + assertTrue(logDir.isFile) + + server = TestUtils.createServer(kafkaConfig) + TestUtils.waitUntilTrue(() => statusCodeOption.contains(1), "timed out waiting for broker to halt") + } finally { + Exit.resetHaltProcedure() + if (server != null) + TestUtils.shutdownServers(List(server)) + } + } + + @Test + def testProduceErrorFromFailureOnCheckpoint(): Unit = { + testProduceErrorsFromLogDirFailureOnLeader(Checkpoint) + } + + @Test + def testIOExceptionDuringCheckpoint(): Unit = { + testProduceAfterLogDirFailureOnLeader(Checkpoint) + } + + @Test + def testReplicaFetcherThreadAfterLogDirFailureOnFollower(): Unit = { + this.producerConfig.setProperty(ProducerConfig.RETRIES_CONFIG, "0") + val producer = createProducer() + val partition = new TopicPartition(topic, 0) + + val partitionInfo = producer.partitionsFor(topic).asScala.find(_.partition() == 0).get + val leaderServerId = partitionInfo.leader().id() + 
val leaderServer = servers.find(_.config.brokerId == leaderServerId).get + val followerServerId = partitionInfo.replicas().map(_.id()).find(_ != leaderServerId).get + val followerServer = servers.find(_.config.brokerId == followerServerId).get + + followerServer.replicaManager.markPartitionOffline(partition) + // Send a message to another partition whose leader is the same as partition 0 + // so that ReplicaFetcherThread on the follower will get response from leader immediately + val anotherPartitionWithTheSameLeader = (1 until partitionNum).find { i => + leaderServer.replicaManager.onlinePartition(new TopicPartition(topic, i)) + .flatMap(_.leaderLogIfLocal).isDefined + }.get + val record = new ProducerRecord[Array[Byte], Array[Byte]](topic, anotherPartitionWithTheSameLeader, topic.getBytes, "message".getBytes) + // When producer.send(...).get returns, it is guaranteed that ReplicaFetcherThread on the follower + // has fetched from the leader and attempts to append to the offline replica. + producer.send(record).get + + assertEquals(brokerCount, leaderServer.replicaManager.onlinePartition(new TopicPartition(topic, anotherPartitionWithTheSameLeader)) + .get.inSyncReplicaIds.size) + followerServer.replicaManager.replicaFetcherManager.fetcherThreadMap.values.foreach { thread => + assertFalse(thread.isShutdownComplete, "ReplicaFetcherThread should still be working if its partition count > 0") + } + } + + def testProduceErrorsFromLogDirFailureOnLeader(failureType: LogDirFailureType): Unit = { + // Disable retries to allow exception to bubble up for validation + this.producerConfig.setProperty(ProducerConfig.RETRIES_CONFIG, "0") + val producer = createProducer() + + val partition = new TopicPartition(topic, 0) + val record = new ProducerRecord(topic, 0, s"key".getBytes, s"value".getBytes) + + val leaderServerId = producer.partitionsFor(topic).asScala.find(_.partition() == 0).get.leader().id() + val leaderServer = servers.find(_.config.brokerId == leaderServerId).get + + TestUtils.causeLogDirFailure(failureType, leaderServer, partition) + + // send() should fail due to either KafkaStorageException or NotLeaderOrFollowerException + val e = assertThrows(classOf[ExecutionException], () => producer.send(record).get(6000, TimeUnit.MILLISECONDS)) + assertTrue(e.getCause.isInstanceOf[KafkaStorageException] || + // This may happen if ProduceRequest version <= 3 + e.getCause.isInstanceOf[NotLeaderOrFollowerException]) + } + + def testProduceAfterLogDirFailureOnLeader(failureType: LogDirFailureType): Unit = { + val consumer = createConsumer() + subscribeAndWaitForAssignment(topic, consumer) + + val producer = createProducer() + + val partition = new TopicPartition(topic, 0) + val record = new ProducerRecord(topic, 0, s"key".getBytes, s"value".getBytes) + + val leaderServerId = producer.partitionsFor(topic).asScala.find(_.partition() == 0).get.leader().id() + val leaderServer = servers.find(_.config.brokerId == leaderServerId).get + + // The first send() should succeed + producer.send(record).get() + TestUtils.consumeRecords(consumer, 1) + + TestUtils.causeLogDirFailure(failureType, leaderServer, partition) + + TestUtils.waitUntilTrue(() => { + // ProduceResponse may contain KafkaStorageException and trigger metadata update + producer.send(record) + producer.partitionsFor(topic).asScala.find(_.partition() == 0).get.leader().id() != leaderServerId + }, "Expected new leader for the partition", 6000L) + + // Block on send to ensure that new leader accepts a message. 
+ producer.send(record).get(6000L, TimeUnit.MILLISECONDS) + + // Consumer should receive some messages + TestUtils.pollUntilAtLeastNumRecords(consumer, 1) + + // There should be no remaining LogDirEventNotification znode + assertTrue(zkClient.getAllLogDirEventNotifications.isEmpty) + + // The controller should have marked the replica on the original leader as offline + val controllerServer = servers.find(_.kafkaController.isActive).get + val offlineReplicas = controllerServer.kafkaController.controllerContext.replicasInState(topic, OfflineReplica) + assertTrue(offlineReplicas.contains(PartitionAndReplica(new TopicPartition(topic, 0), leaderServerId))) + } + + + private def subscribeAndWaitForAssignment(topic: String, consumer: KafkaConsumer[Array[Byte], Array[Byte]]): Unit = { + consumer.subscribe(Collections.singletonList(topic)) + TestUtils.pollUntilTrue(consumer, () => !consumer.assignment.isEmpty, "Expected non-empty assignment") + } + +} diff --git a/core/src/test/scala/unit/kafka/server/LogOffsetTest.scala b/core/src/test/scala/unit/kafka/server/LogOffsetTest.scala new file mode 100755 index 0000000..be62211 --- /dev/null +++ b/core/src/test/scala/unit/kafka/server/LogOffsetTest.scala @@ -0,0 +1,325 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package kafka.server + +import kafka.log.{ClientRecordDeletion, LogSegment, UnifiedLog} +import kafka.utils.{MockTime, TestUtils} +import org.apache.kafka.common.message.ListOffsetsRequestData.{ListOffsetsPartition, ListOffsetsTopic} +import org.apache.kafka.common.message.ListOffsetsResponseData.{ListOffsetsPartitionResponse, ListOffsetsTopicResponse} +import org.apache.kafka.common.protocol.{ApiKeys, Errors} +import org.apache.kafka.common.requests.{FetchRequest, FetchResponse, ListOffsetsRequest, ListOffsetsResponse} +import org.apache.kafka.common.{IsolationLevel, TopicPartition} +import org.easymock.{EasyMock, IAnswer} +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.Test + +import java.io.File +import java.util.concurrent.atomic.AtomicInteger +import java.util.{Optional, Properties, Random} +import scala.collection.mutable.Buffer +import scala.jdk.CollectionConverters._ + +class LogOffsetTest extends BaseRequestTest { + + private lazy val time = new MockTime + + override def brokerCount = 1 + + protected override def brokerTime(brokerId: Int) = time + + protected override def brokerPropertyOverrides(props: Properties): Unit = { + props.put("log.flush.interval.messages", "1") + props.put("num.partitions", "20") + props.put("log.retention.hours", "10") + props.put("log.retention.check.interval.ms", (5 * 1000 * 60).toString) + props.put("log.segment.bytes", "140") + } + + @deprecated("ListOffsetsRequest V0", since = "") + @Test + def testGetOffsetsForUnknownTopic(): Unit = { + val topicPartition = new TopicPartition("foo", 0) + val request = ListOffsetsRequest.Builder.forConsumer(false, IsolationLevel.READ_UNCOMMITTED, false) + .setTargetTimes(buildTargetTimes(topicPartition, ListOffsetsRequest.LATEST_TIMESTAMP, 10).asJava).build(0) + val response = sendListOffsetsRequest(request) + assertEquals(Errors.UNKNOWN_TOPIC_OR_PARTITION.code, findPartition(response.topics.asScala, topicPartition).errorCode) + } + + @deprecated("ListOffsetsRequest V0", since = "") + @Test + def testGetOffsetsAfterDeleteRecords(): Unit = { + val topic = "kafka-" + val topicPartition = new TopicPartition(topic, 0) + val log = createTopicAndGetLog(topic, topicPartition) + + for (_ <- 0 until 20) + log.appendAsLeader(TestUtils.singletonRecords(value = Integer.toString(42).getBytes()), leaderEpoch = 0) + log.flush() + + log.updateHighWatermark(log.logEndOffset) + log.maybeIncrementLogStartOffset(3, ClientRecordDeletion) + log.deleteOldSegments() + + val offsets = log.legacyFetchOffsetsBefore(ListOffsetsRequest.LATEST_TIMESTAMP, 15) + assertEquals(Seq(20L, 18L, 16L, 14L, 12L, 10L, 8L, 6L, 4L, 3L), offsets) + + TestUtils.waitUntilTrue(() => TestUtils.isLeaderLocalOnBroker(topic, topicPartition.partition, server), + "Leader should be elected") + val request = ListOffsetsRequest.Builder.forReplica(0, 0) + .setTargetTimes(buildTargetTimes(topicPartition, ListOffsetsRequest.LATEST_TIMESTAMP, 15).asJava).build() + val consumerOffsets = findPartition(sendListOffsetsRequest(request).topics.asScala, topicPartition).oldStyleOffsets.asScala + assertEquals(Seq(20L, 18L, 16L, 14L, 12L, 10L, 8L, 6L, 4L, 3L), consumerOffsets) + } + + @Test + def testFetchOffsetByTimestampForMaxTimestampAfterTruncate(): Unit = { + val topic = "kafka-" + val topicPartition = new TopicPartition(topic, 0) + val log = createTopicAndGetLog(topic, topicPartition) + + for (timestamp <- 0 until 20) + log.appendAsLeader(TestUtils.singletonRecords(value = Integer.toString(42).getBytes(), timestamp = timestamp.toLong), 
leaderEpoch = 0) + log.flush() + + log.updateHighWatermark(log.logEndOffset) + + val firstOffset = log.fetchOffsetByTimestamp(ListOffsetsRequest.MAX_TIMESTAMP) + assertEquals(19L, firstOffset.get.offset) + assertEquals(19L, firstOffset.get.timestamp) + + log.truncateTo(0) + + val secondOffset = log.fetchOffsetByTimestamp(ListOffsetsRequest.MAX_TIMESTAMP) + assertEquals(0L, secondOffset.get.offset) + assertEquals(-1L, secondOffset.get.timestamp) + } + + @Test + def testFetchOffsetByTimestampForMaxTimestampWithUnorderedTimestamps(): Unit = { + val topic = "kafka-" + val topicPartition = new TopicPartition(topic, 0) + val log = createTopicAndGetLog(topic, topicPartition) + + for (timestamp <- List(0L, 1L, 2L, 3L, 4L, 6L, 5L)) + log.appendAsLeader(TestUtils.singletonRecords(value = Integer.toString(42).getBytes(), timestamp = timestamp), leaderEpoch = 0) + log.flush() + + log.updateHighWatermark(log.logEndOffset) + + val maxTimestampOffset = log.fetchOffsetByTimestamp(ListOffsetsRequest.MAX_TIMESTAMP) + assertEquals(7L, log.logEndOffset) + assertEquals(5L, maxTimestampOffset.get.offset) + assertEquals(6L, maxTimestampOffset.get.timestamp) + } + + @Test + def testGetOffsetsBeforeLatestTime(): Unit = { + val topic = "kafka-" + val topicPartition = new TopicPartition(topic, 0) + val log = createTopicAndGetLog(topic, topicPartition) + + val topicIds = getTopicIds().asJava + val topicNames = topicIds.asScala.map(_.swap).asJava + val topicId = topicIds.get(topic) + + for (_ <- 0 until 20) + log.appendAsLeader(TestUtils.singletonRecords(value = Integer.toString(42).getBytes()), leaderEpoch = 0) + log.flush() + + val offsets = log.legacyFetchOffsetsBefore(ListOffsetsRequest.LATEST_TIMESTAMP, 15) + assertEquals(Seq(20L, 18L, 16L, 14L, 12L, 10L, 8L, 6L, 4L, 2L, 0L), offsets) + + TestUtils.waitUntilTrue(() => TestUtils.isLeaderLocalOnBroker(topic, 0, server), + "Leader should be elected") + val request = ListOffsetsRequest.Builder.forReplica(0, 0) + .setTargetTimes(buildTargetTimes(topicPartition, ListOffsetsRequest.LATEST_TIMESTAMP, 15).asJava).build() + val consumerOffsets = findPartition(sendListOffsetsRequest(request).topics.asScala, topicPartition).oldStyleOffsets.asScala + assertEquals(Seq(20L, 18L, 16L, 14L, 12L, 10L, 8L, 6L, 4L, 2L, 0L), consumerOffsets) + + // try to fetch using latest offset + val fetchRequest = FetchRequest.Builder.forConsumer(ApiKeys.FETCH.latestVersion, 0, 1, + Map(topicPartition -> new FetchRequest.PartitionData(topicId, consumerOffsets.head, FetchRequest.INVALID_LOG_START_OFFSET, + 300 * 1024, Optional.empty())).asJava).build() + val fetchResponse = sendFetchRequest(fetchRequest) + assertFalse(FetchResponse.recordsOrFail(fetchResponse.responseData(topicNames, ApiKeys.FETCH.latestVersion).get(topicPartition)).batches.iterator.hasNext) + } + + @Test + def testEmptyLogsGetOffsets(): Unit = { + val random = new Random + val topic = "kafka-" + val topicPartition = new TopicPartition(topic, random.nextInt(10)) + val topicPartitionPath = s"${TestUtils.tempDir().getAbsolutePath}/$topic-${topicPartition.partition}" + val topicLogDir = new File(topicPartitionPath) + topicLogDir.mkdir() + + createTopic(topic, numPartitions = 1, replicationFactor = 1) + + var offsetChanged = false + for (_ <- 1 to 14) { + val topicPartition = new TopicPartition(topic, 0) + val request = ListOffsetsRequest.Builder.forReplica(0, 0) + .setTargetTimes(buildTargetTimes(topicPartition, ListOffsetsRequest.EARLIEST_TIMESTAMP, 1).asJava).build() + val consumerOffsets = 
findPartition(sendListOffsetsRequest(request).topics.asScala, topicPartition).oldStyleOffsets.asScala + if (consumerOffsets.head == 1) + offsetChanged = true + } + assertFalse(offsetChanged) + } + + @Test + def testFetchOffsetByTimestampForMaxTimestampWithEmptyLog(): Unit = { + val topic = "kafka-" + val topicPartition = new TopicPartition(topic, 0) + val log = createTopicAndGetLog(topic, topicPartition) + + log.updateHighWatermark(log.logEndOffset) + + val maxTimestampOffset = log.fetchOffsetByTimestamp(ListOffsetsRequest.MAX_TIMESTAMP) + assertEquals(0L, log.logEndOffset) + assertEquals(0L, maxTimestampOffset.get.offset) + assertEquals(-1L, maxTimestampOffset.get.timestamp) + } + + @deprecated("legacyFetchOffsetsBefore", since = "") + @Test + def testGetOffsetsBeforeNow(): Unit = { + val random = new Random + val topic = "kafka-" + val topicPartition = new TopicPartition(topic, random.nextInt(3)) + + createTopic(topic, 3, 1) + + val logManager = server.getLogManager + val log = logManager.getOrCreateLog(topicPartition, topicId = None) + + for (_ <- 0 until 20) + log.appendAsLeader(TestUtils.singletonRecords(value = Integer.toString(42).getBytes()), leaderEpoch = 0) + log.flush() + + val now = time.milliseconds + 30000 // pretend it is the future to avoid race conditions with the fs + + val offsets = log.legacyFetchOffsetsBefore(now, 15) + assertEquals(Seq(20L, 18L, 16L, 14L, 12L, 10L, 8L, 6L, 4L, 2L, 0L), offsets) + + TestUtils.waitUntilTrue(() => TestUtils.isLeaderLocalOnBroker(topic, topicPartition.partition, server), + "Leader should be elected") + val request = ListOffsetsRequest.Builder.forReplica(0, 0) + .setTargetTimes(buildTargetTimes(topicPartition, now, 15).asJava).build() + val consumerOffsets = findPartition(sendListOffsetsRequest(request).topics.asScala, topicPartition).oldStyleOffsets.asScala + assertEquals(Seq(20L, 18L, 16L, 14L, 12L, 10L, 8L, 6L, 4L, 2L, 0L), consumerOffsets) + } + + @deprecated("legacyFetchOffsetsBefore", since = "") + @Test + def testGetOffsetsBeforeEarliestTime(): Unit = { + val random = new Random + val topic = "kafka-" + val topicPartition = new TopicPartition(topic, random.nextInt(3)) + + createTopic(topic, 3, 1) + + val logManager = server.getLogManager + val log = logManager.getOrCreateLog(topicPartition, topicId = None) + for (_ <- 0 until 20) + log.appendAsLeader(TestUtils.singletonRecords(value = Integer.toString(42).getBytes()), leaderEpoch = 0) + log.flush() + + val offsets = log.legacyFetchOffsetsBefore(ListOffsetsRequest.EARLIEST_TIMESTAMP, 10) + + assertEquals(Seq(0L), offsets) + + TestUtils.waitUntilTrue(() => TestUtils.isLeaderLocalOnBroker(topic, topicPartition.partition, server), + "Leader should be elected") + val request = ListOffsetsRequest.Builder.forReplica(0, 0) + .setTargetTimes(buildTargetTimes(topicPartition, ListOffsetsRequest.EARLIEST_TIMESTAMP, 10).asJava).build() + val consumerOffsets = findPartition(sendListOffsetsRequest(request).topics.asScala, topicPartition).oldStyleOffsets.asScala + assertEquals(Seq(0L), consumerOffsets) + } + + /* We test that `fetchOffsetsBefore` works correctly if `LogSegment.size` changes after each invocation (simulating + * a race condition) */ + @Test + def testFetchOffsetsBeforeWithChangingSegmentSize(): Unit = { + val log: UnifiedLog = EasyMock.niceMock(classOf[UnifiedLog]) + val logSegment: LogSegment = EasyMock.niceMock(classOf[LogSegment]) + EasyMock.expect(logSegment.size).andStubAnswer(new IAnswer[Int] { + private val value = new AtomicInteger(0) + def answer: Int = 
value.getAndIncrement() + }) + EasyMock.replay(logSegment) + val logSegments = Seq(logSegment) + EasyMock.expect(log.logSegments).andStubReturn(logSegments) + EasyMock.replay(log) + log.legacyFetchOffsetsBefore(System.currentTimeMillis, 100) + } + + /* We test that `fetchOffsetsBefore` works correctly if `Log.logSegments` content and size are + * different (simulating a race condition) */ + @Test + def testFetchOffsetsBeforeWithChangingSegments(): Unit = { + val log: UnifiedLog = EasyMock.niceMock(classOf[UnifiedLog]) + val logSegment: LogSegment = EasyMock.niceMock(classOf[LogSegment]) + EasyMock.expect(log.logSegments).andStubAnswer { + new IAnswer[Iterable[LogSegment]] { + def answer = new Iterable[LogSegment] { + override def size = 2 + def iterator = Seq(logSegment).iterator + } + } + } + EasyMock.replay(logSegment) + EasyMock.replay(log) + log.legacyFetchOffsetsBefore(System.currentTimeMillis, 100) + } + + private def server: KafkaServer = servers.head + + private def sendListOffsetsRequest(request: ListOffsetsRequest): ListOffsetsResponse = { + connectAndReceive[ListOffsetsResponse](request) + } + + private def sendFetchRequest(request: FetchRequest): FetchResponse = { + connectAndReceive[FetchResponse](request) + } + + private def buildTargetTimes(tp: TopicPartition, timestamp: Long, maxNumOffsets: Int): List[ListOffsetsTopic] = { + List(new ListOffsetsTopic() + .setName(tp.topic) + .setPartitions(List(new ListOffsetsPartition() + .setPartitionIndex(tp.partition) + .setTimestamp(timestamp) + .setMaxNumOffsets(maxNumOffsets)).asJava) + ) + } + + private def findPartition(topics: Buffer[ListOffsetsTopicResponse], tp: TopicPartition): ListOffsetsPartitionResponse = { + topics.find(_.name == tp.topic).get + .partitions.asScala.find(_.partitionIndex == tp.partition).get + } + + private def createTopicAndGetLog(topic: String, topicPartition: TopicPartition): UnifiedLog = { + createTopic(topic, 1, 1) + + val logManager = server.getLogManager + TestUtils.waitUntilTrue(() => logManager.getLog(topicPartition).isDefined, + "Log for partition [topic,0] should be created") + logManager.getLog(topicPartition).get + } + +} diff --git a/core/src/test/scala/unit/kafka/server/LogRecoveryTest.scala b/core/src/test/scala/unit/kafka/server/LogRecoveryTest.scala new file mode 100755 index 0000000..30fc72d --- /dev/null +++ b/core/src/test/scala/unit/kafka/server/LogRecoveryTest.scala @@ -0,0 +1,246 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+*/ +package kafka.server + +import java.util.Properties + +import scala.collection.Seq + +import kafka.utils.TestUtils +import TestUtils._ +import kafka.server.QuorumTestHarness +import java.io.File + +import kafka.server.checkpoints.OffsetCheckpointFile +import org.apache.kafka.clients.producer.{KafkaProducer, ProducerRecord} +import org.apache.kafka.common.TopicPartition +import org.apache.kafka.common.serialization.{IntegerSerializer, StringSerializer} +import org.junit.jupiter.api.{AfterEach, BeforeEach, Test, TestInfo} +import org.junit.jupiter.api.Assertions._ + +class LogRecoveryTest extends QuorumTestHarness { + + val replicaLagTimeMaxMs = 5000L + val replicaLagMaxMessages = 10L + val replicaFetchWaitMaxMs = 1000 + val replicaFetchMinBytes = 20 + + val overridingProps = new Properties() + overridingProps.put(KafkaConfig.ReplicaLagTimeMaxMsProp, replicaLagTimeMaxMs.toString) + overridingProps.put(KafkaConfig.ReplicaFetchWaitMaxMsProp, replicaFetchWaitMaxMs.toString) + overridingProps.put(KafkaConfig.ReplicaFetchMinBytesProp, replicaFetchMinBytes.toString) + + var configs: Seq[KafkaConfig] = null + val topic = "new-topic" + val partitionId = 0 + val topicPartition = new TopicPartition(topic, partitionId) + + var server1: KafkaServer = null + var server2: KafkaServer = null + + def configProps1 = configs.head + def configProps2 = configs.last + + val message = "hello" + + var producer: KafkaProducer[Integer, String] = null + def hwFile1 = new OffsetCheckpointFile(new File(configProps1.logDirs.head, ReplicaManager.HighWatermarkFilename)) + def hwFile2 = new OffsetCheckpointFile(new File(configProps2.logDirs.head, ReplicaManager.HighWatermarkFilename)) + var servers = Seq.empty[KafkaServer] + + // Some tests restart the brokers then produce more data. 
But since test brokers use random ports, we need + // to use a new producer that knows the new ports + def updateProducer() = { + if (producer != null) + producer.close() + producer = TestUtils.createProducer( + TestUtils.getBrokerListStrFromServers(servers), + keySerializer = new IntegerSerializer, + valueSerializer = new StringSerializer + ) + } + + @BeforeEach + override def setUp(testInfo: TestInfo): Unit = { + super.setUp(testInfo) + + configs = TestUtils.createBrokerConfigs(2, zkConnect, enableControlledShutdown = false).map(KafkaConfig.fromProps(_, overridingProps)) + + // start both servers + server1 = TestUtils.createServer(configProps1) + server2 = TestUtils.createServer(configProps2) + servers = List(server1, server2) + + // create topic with 1 partition, 2 replicas, one on each broker + createTopic(zkClient, topic, partitionReplicaAssignment = Map(0 -> Seq(0,1)), servers = servers) + + // create the producer + updateProducer() + } + + @AfterEach + override def tearDown(): Unit = { + producer.close() + TestUtils.shutdownServers(servers) + super.tearDown() + } + + @Test + def testHWCheckpointNoFailuresSingleLogSegment(): Unit = { + val numMessages = 2L + sendMessages(numMessages.toInt) + + // give some time for the follower 1 to record leader HW + TestUtils.waitUntilTrue(() => + server2.replicaManager.localLogOrException(topicPartition).highWatermark == numMessages, + "Failed to update high watermark for follower after timeout") + + servers.foreach(_.replicaManager.checkpointHighWatermarks()) + val leaderHW = hwFile1.read().getOrElse(topicPartition, 0L) + assertEquals(numMessages, leaderHW) + val followerHW = hwFile2.read().getOrElse(topicPartition, 0L) + assertEquals(numMessages, followerHW) + } + + @Test + def testHWCheckpointWithFailuresSingleLogSegment(): Unit = { + var leader = waitUntilLeaderIsElectedOrChanged(zkClient, topic, partitionId) + + assertEquals(0L, hwFile1.read().getOrElse(topicPartition, 0L)) + + sendMessages(1) + Thread.sleep(1000) + var hw = 1L + + // kill the server hosting the preferred replica + server1.shutdown() + assertEquals(hw, hwFile1.read().getOrElse(topicPartition, 0L)) + + // check if leader moves to the other server + leader = waitUntilLeaderIsElectedOrChanged(zkClient, topic, partitionId, oldLeaderOpt = Some(leader)) + assertEquals(1, leader, "Leader must move to broker 1") + + // bring the preferred replica back + server1.startup() + // Update producer with new server settings + updateProducer() + + leader = waitUntilLeaderIsElectedOrChanged(zkClient, topic, partitionId) + assertTrue(leader == 0 || leader == 1, + "Leader must remain on broker 1, in case of ZooKeeper session expiration it can move to broker 0") + + assertEquals(hw, hwFile1.read().getOrElse(topicPartition, 0L)) + /** We plan to shutdown server2 and transfer the leadership to server1. + * With unclean leader election turned off, a prerequisite for the successful leadership transition + * is that server1 has caught up on the topicPartition, and has joined the ISR. 
+     * In the line below, we wait until the condition is met before shutting down server2
+     */
+    waitUntilTrue(() => server2.replicaManager.onlinePartition(topicPartition).get.inSyncReplicaIds.size == 2,
+      "Server 1 is not able to join the ISR after restart")
+
+
+    // since server 2 was never shut down, the hw value of 1 has probably not been checkpointed to disk yet
+    server2.shutdown()
+    assertEquals(hw, hwFile2.read().getOrElse(topicPartition, 0L))
+
+    server2.startup()
+    updateProducer()
+    leader = waitUntilLeaderIsElectedOrChanged(zkClient, topic, partitionId, oldLeaderOpt = Some(leader))
+    assertTrue(leader == 0 || leader == 1,
+      "Leader must remain on broker 0, in case of ZooKeeper session expiration it can move to broker 1")
+
+    sendMessages(1)
+    hw += 1
+
+    // give some time for follower 1 to record the leader HW of 2
+    TestUtils.waitUntilTrue(() =>
+      server2.replicaManager.localLogOrException(topicPartition).highWatermark == hw,
+      "Failed to update high watermark for follower after timeout")
+    // shutdown the servers to allow the hw to be checkpointed
+    servers.foreach(_.shutdown())
+    assertEquals(hw, hwFile1.read().getOrElse(topicPartition, 0L))
+    assertEquals(hw, hwFile2.read().getOrElse(topicPartition, 0L))
+  }
+
+  @Test
+  def testHWCheckpointNoFailuresMultipleLogSegments(): Unit = {
+    sendMessages(20)
+    val hw = 20L
+    // give some time for follower 1 to record the leader HW of 20
+    TestUtils.waitUntilTrue(() =>
+      server2.replicaManager.localLogOrException(topicPartition).highWatermark == hw,
+      "Failed to update high watermark for follower after timeout")
+    // shutdown the servers to allow the hw to be checkpointed
+    servers.foreach(_.shutdown())
+    val leaderHW = hwFile1.read().getOrElse(topicPartition, 0L)
+    assertEquals(hw, leaderHW)
+    val followerHW = hwFile2.read().getOrElse(topicPartition, 0L)
+    assertEquals(hw, followerHW)
+  }
+
+  @Test
+  def testHWCheckpointWithFailuresMultipleLogSegments(): Unit = {
+    var leader = waitUntilLeaderIsElectedOrChanged(zkClient, topic, partitionId)
+
+    sendMessages(2)
+    var hw = 2L
+
+    // allow some time for the follower to get the leader HW
+    TestUtils.waitUntilTrue(() =>
+      server2.replicaManager.localLogOrException(topicPartition).highWatermark == hw,
+      "Failed to update high watermark for follower after timeout")
+    // kill the server hosting the preferred replica
+    server1.shutdown()
+    server2.shutdown()
+    assertEquals(hw, hwFile1.read().getOrElse(topicPartition, 0L))
+    assertEquals(hw, hwFile2.read().getOrElse(topicPartition, 0L))
+
+    server2.startup()
+    updateProducer()
+    // check if leader moves to the other server
+    leader = waitUntilLeaderIsElectedOrChanged(zkClient, topic, partitionId, oldLeaderOpt = Some(leader))
+    assertEquals(1, leader, "Leader must move to broker 1")
+
+    assertEquals(hw, hwFile1.read().getOrElse(topicPartition, 0L))
+
+    // bring the preferred replica back
+    server1.startup()
+    updateProducer()
+
+    assertEquals(hw, hwFile1.read().getOrElse(topicPartition, 0L))
+    assertEquals(hw, hwFile2.read().getOrElse(topicPartition, 0L))
+
+    sendMessages(2)
+    hw += 2
+
+    // allow some time for the follower to create the replica
+    TestUtils.waitUntilTrue(() => server1.replicaManager.localLog(topicPartition).nonEmpty,
+      "Failed to create replica in follower after timeout")
+    // allow some time for the follower to get the leader HW
+    TestUtils.waitUntilTrue(() =>
+      server1.replicaManager.localLogOrException(topicPartition).highWatermark == hw,
+      "Failed to update high watermark for follower after timeout")
+    // shutdown the servers to allow the 
hw to be checkpointed + servers.foreach(_.shutdown()) + assertEquals(hw, hwFile1.read().getOrElse(topicPartition, 0L)) + assertEquals(hw, hwFile2.read().getOrElse(topicPartition, 0L)) + } + + private def sendMessages(n: Int): Unit = { + (0 until n).map(_ => producer.send(new ProducerRecord(topic, 0, message))).foreach(_.get) + } +} diff --git a/core/src/test/scala/unit/kafka/server/MetadataCacheTest.scala b/core/src/test/scala/unit/kafka/server/MetadataCacheTest.scala new file mode 100644 index 0000000..369d2f8 --- /dev/null +++ b/core/src/test/scala/unit/kafka/server/MetadataCacheTest.scala @@ -0,0 +1,636 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package kafka.server + +import org.apache.kafka.common.{Node, TopicPartition, Uuid} + +import java.util +import util.Arrays.asList +import org.apache.kafka.common.message.UpdateMetadataRequestData.{UpdateMetadataBroker, UpdateMetadataEndpoint, UpdateMetadataPartitionState, UpdateMetadataTopicState} +import org.apache.kafka.common.network.ListenerName +import org.apache.kafka.common.protocol.{ApiKeys, ApiMessage, Errors} +import org.apache.kafka.common.record.RecordBatch +import org.apache.kafka.common.requests.UpdateMetadataRequest +import org.apache.kafka.common.security.auth.SecurityProtocol +import org.apache.kafka.raft.{OffsetAndEpoch => RaftOffsetAndEpoch} +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.params.ParameterizedTest +import org.junit.jupiter.params.provider.MethodSource + +import java.util.Collections +import kafka.api.LeaderAndIsr +import kafka.server.metadata.{KRaftMetadataCache, ZkMetadataCache} +import org.apache.kafka.common.metadata.{PartitionRecord, RegisterBrokerRecord, RemoveTopicRecord, TopicRecord} +import org.apache.kafka.common.metadata.RegisterBrokerRecord.{BrokerEndpoint, BrokerEndpointCollection} +import org.apache.kafka.image.{ClusterImage, MetadataDelta, MetadataImage} + +import scala.collection.{Seq, mutable} +import scala.jdk.CollectionConverters._ + +object MetadataCacheTest { + def zkCacheProvider(): util.stream.Stream[MetadataCache] = + util.stream.Stream.of[MetadataCache]( + MetadataCache.zkMetadataCache(1) + ) + + def cacheProvider(): util.stream.Stream[MetadataCache] = + util.stream.Stream.of[MetadataCache]( + MetadataCache.zkMetadataCache(1), + MetadataCache.kRaftMetadataCache(1) + ) + + def updateCache(cache: MetadataCache, request: UpdateMetadataRequest): Unit = { + cache match { + case c: ZkMetadataCache => c.updateMetadata(0, request) + case c: KRaftMetadataCache => { + // UpdateMetadataRequest always contains a full list of brokers, but may contain + // a partial list of partitions. Therefore, base our delta off a partial image that + // contains no brokers, but which contains the previous partitions. 
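+        // The request is translated into the equivalent KRaft metadata records below: each live
+        // broker becomes a RegisterBrokerRecord, and each partition state becomes a TopicRecord
+        // plus a PartitionRecord (or a RemoveTopicRecord when the partition leader indicates the
+        // topic is being deleted); the records are replayed into the delta and applied to the cache.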
+ val image = c.currentImage() + val partialImage = new MetadataImage( + new RaftOffsetAndEpoch(100, 10), + image.features(), ClusterImage.EMPTY, + image.topics(), image.configs(), image.clientQuotas()) + val delta = new MetadataDelta(partialImage) + + def toRecord(broker: UpdateMetadataBroker): RegisterBrokerRecord = { + val endpoints = new BrokerEndpointCollection() + broker.endpoints().forEach { e => + endpoints.add(new BrokerEndpoint(). + setName(e.listener()). + setHost(e.host()). + setPort(e.port()). + setSecurityProtocol(e.securityProtocol())) + } + val prevBroker = Option(image.cluster().broker(broker.id())) + // UpdateMetadataRequest doesn't contain all the broker registration fields, so get + // them from the previous registration if available. + val (epoch, incarnationId, fenced) = prevBroker match { + case None => (0L, Uuid.ZERO_UUID, false) + case Some(b) => (b.epoch(), b.incarnationId(), b.fenced()) + } + new RegisterBrokerRecord(). + setBrokerId(broker.id()). + setBrokerEpoch(epoch). + setIncarnationId(incarnationId). + setEndPoints(endpoints). + setRack(broker.rack()). + setFenced(fenced) + } + request.liveBrokers().iterator().asScala.foreach { brokerInfo => + delta.replay(100, 10, toRecord(brokerInfo)) + } + + def toRecords(topic: UpdateMetadataTopicState): Seq[ApiMessage] = { + val results = new mutable.ArrayBuffer[ApiMessage]() + results += new TopicRecord().setName(topic.topicName()).setTopicId(topic.topicId()) + topic.partitionStates().forEach { partition => + if (partition.leader() == LeaderAndIsr.LeaderDuringDelete) { + results += new RemoveTopicRecord().setTopicId(topic.topicId()) + } else { + results += new PartitionRecord(). + setPartitionId(partition.partitionIndex()). + setTopicId(topic.topicId()). + setReplicas(partition.replicas()). + setIsr(partition.isr()). + setRemovingReplicas(Collections.emptyList()). + setAddingReplicas(Collections.emptyList()). + setLeader(partition.leader()). + setLeaderEpoch(partition.leaderEpoch()). 
+ setPartitionEpoch(partition.zkVersion()) + } + } + results + } + request.topicStates().forEach { topic => + toRecords(topic).foreach(delta.replay(100, 10, _)) + } + c.setImage(delta.apply()) + } + case _ => throw new RuntimeException("Unsupported cache type") + } + } +} + +class MetadataCacheTest { + val brokerEpoch = 0L + + @ParameterizedTest + @MethodSource(Array("cacheProvider")) + def getTopicMetadataNonExistingTopics(cache: MetadataCache): Unit = { + val topic = "topic" + val topicMetadata = cache.getTopicMetadata(Set(topic), ListenerName.forSecurityProtocol(SecurityProtocol.PLAINTEXT)) + assertTrue(topicMetadata.isEmpty) + } + + @ParameterizedTest + @MethodSource(Array("cacheProvider")) + def getTopicMetadata(cache: MetadataCache): Unit = { + val topic0 = "topic-0" + val topic1 = "topic-1" + + val zkVersion = 3 + val controllerId = 2 + val controllerEpoch = 1 + + def endpoints(brokerId: Int): Seq[UpdateMetadataEndpoint] = { + val host = s"foo-$brokerId" + Seq( + new UpdateMetadataEndpoint() + .setHost(host) + .setPort(9092) + .setSecurityProtocol(SecurityProtocol.PLAINTEXT.id) + .setListener(ListenerName.forSecurityProtocol(SecurityProtocol.PLAINTEXT).value), + new UpdateMetadataEndpoint() + .setHost(host) + .setPort(9093) + .setSecurityProtocol(SecurityProtocol.SSL.id) + .setListener(ListenerName.forSecurityProtocol(SecurityProtocol.SSL).value) + ) + } + + val brokers = (0 to 4).map { brokerId => + new UpdateMetadataBroker() + .setId(brokerId) + .setEndpoints(endpoints(brokerId).asJava) + .setRack("rack1") + } + + val partitionStates = Seq( + new UpdateMetadataPartitionState() + .setTopicName(topic0) + .setPartitionIndex(0) + .setControllerEpoch(controllerEpoch) + .setLeader(0) + .setLeaderEpoch(0) + .setIsr(asList(0, 1, 3)) + .setZkVersion(zkVersion) + .setReplicas(asList(0, 1, 3)), + new UpdateMetadataPartitionState() + .setTopicName(topic0) + .setPartitionIndex(1) + .setControllerEpoch(controllerEpoch) + .setLeader(1) + .setLeaderEpoch(1) + .setIsr(asList(1, 0)) + .setZkVersion(zkVersion) + .setReplicas(asList(1, 2, 0, 4)), + new UpdateMetadataPartitionState() + .setTopicName(topic1) + .setPartitionIndex(0) + .setControllerEpoch(controllerEpoch) + .setLeader(2) + .setLeaderEpoch(2) + .setIsr(asList(2, 1)) + .setZkVersion(zkVersion) + .setReplicas(asList(2, 1, 3))) + + val topicIds = new util.HashMap[String, Uuid]() + topicIds.put(topic0, Uuid.randomUuid()) + topicIds.put(topic1, Uuid.randomUuid()) + + val version = ApiKeys.UPDATE_METADATA.latestVersion + val updateMetadataRequest = new UpdateMetadataRequest.Builder(version, controllerId, controllerEpoch, brokerEpoch, + partitionStates.asJava, brokers.asJava, topicIds).build() + MetadataCacheTest.updateCache(cache, updateMetadataRequest) + + for (securityProtocol <- Seq(SecurityProtocol.PLAINTEXT, SecurityProtocol.SSL)) { + val listenerName = ListenerName.forSecurityProtocol(securityProtocol) + + def checkTopicMetadata(topic: String): Unit = { + val topicMetadatas = cache.getTopicMetadata(Set(topic), listenerName) + assertEquals(1, topicMetadatas.size) + + val topicMetadata = topicMetadatas.head + assertEquals(Errors.NONE.code, topicMetadata.errorCode) + assertEquals(topic, topicMetadata.name) + assertEquals(topicIds.get(topic), topicMetadata.topicId()) + + val topicPartitionStates = partitionStates.filter { ps => ps.topicName == topic } + val partitionMetadatas = topicMetadata.partitions.asScala.sortBy(_.partitionIndex) + assertEquals(topicPartitionStates.size, partitionMetadatas.size, s"Unexpected partition count for topic 
$topic") + + partitionMetadatas.zipWithIndex.foreach { case (partitionMetadata, partitionId) => + assertEquals(Errors.NONE.code, partitionMetadata.errorCode) + assertEquals(partitionId, partitionMetadata.partitionIndex) + val partitionState = topicPartitionStates.find(_.partitionIndex == partitionId).getOrElse( + fail(s"Unable to find partition state for partition $partitionId")) + assertEquals(partitionState.leader, partitionMetadata.leaderId) + assertEquals(partitionState.leaderEpoch, partitionMetadata.leaderEpoch) + assertEquals(partitionState.isr, partitionMetadata.isrNodes) + assertEquals(partitionState.replicas, partitionMetadata.replicaNodes) + } + } + + checkTopicMetadata(topic0) + checkTopicMetadata(topic1) + } + + } + + @ParameterizedTest + @MethodSource(Array("cacheProvider")) + def getTopicMetadataPartitionLeaderNotAvailable(cache: MetadataCache): Unit = { + val securityProtocol = SecurityProtocol.PLAINTEXT + val listenerName = ListenerName.forSecurityProtocol(securityProtocol) + val brokers = Seq(new UpdateMetadataBroker() + .setId(0) + .setEndpoints(Seq(new UpdateMetadataEndpoint() + .setHost("foo") + .setPort(9092) + .setSecurityProtocol(securityProtocol.id) + .setListener(listenerName.value)).asJava)) + val metadataCacheBrokerId = 0 + // leader is not available. expect LEADER_NOT_AVAILABLE for any metadata version. + verifyTopicMetadataPartitionLeaderOrEndpointNotAvailable(cache, metadataCacheBrokerId, brokers, listenerName, + leader = 1, Errors.LEADER_NOT_AVAILABLE, errorUnavailableListeners = false) + verifyTopicMetadataPartitionLeaderOrEndpointNotAvailable(cache, metadataCacheBrokerId, brokers, listenerName, + leader = 1, Errors.LEADER_NOT_AVAILABLE, errorUnavailableListeners = true) + } + + @ParameterizedTest + @MethodSource(Array("cacheProvider")) + def getTopicMetadataPartitionListenerNotAvailableOnLeader(cache: MetadataCache): Unit = { + // when listener name is not present in the metadata cache for a broker, getTopicMetadata should + // return LEADER_NOT_AVAILABLE or LISTENER_NOT_FOUND errors for old and new versions respectively. + val plaintextListenerName = ListenerName.forSecurityProtocol(SecurityProtocol.PLAINTEXT) + val sslListenerName = ListenerName.forSecurityProtocol(SecurityProtocol.SSL) + val broker0Endpoints = Seq( + new UpdateMetadataEndpoint() + .setHost("host0") + .setPort(9092) + .setSecurityProtocol(SecurityProtocol.PLAINTEXT.id) + .setListener(plaintextListenerName.value), + new UpdateMetadataEndpoint() + .setHost("host0") + .setPort(9093) + .setSecurityProtocol(SecurityProtocol.SSL.id) + .setListener(sslListenerName.value)) + val broker1Endpoints = Seq(new UpdateMetadataEndpoint() + .setHost("host1") + .setPort(9092) + .setSecurityProtocol(SecurityProtocol.PLAINTEXT.id) + .setListener(plaintextListenerName.value)) + val brokers = Seq( + new UpdateMetadataBroker() + .setId(0) + .setEndpoints(broker0Endpoints.asJava), + new UpdateMetadataBroker() + .setId(1) + .setEndpoints(broker1Endpoints.asJava)) + val metadataCacheBrokerId = 0 + // leader available in cache but listener name not present. expect LISTENER_NOT_FOUND error for new metadata version + verifyTopicMetadataPartitionLeaderOrEndpointNotAvailable(cache, metadataCacheBrokerId, brokers, sslListenerName, + leader = 1, Errors.LISTENER_NOT_FOUND, errorUnavailableListeners = true) + // leader available in cache but listener name not present. 
expect LEADER_NOT_AVAILABLE error for old metadata version + verifyTopicMetadataPartitionLeaderOrEndpointNotAvailable(cache, metadataCacheBrokerId, brokers, sslListenerName, + leader = 1, Errors.LEADER_NOT_AVAILABLE, errorUnavailableListeners = false) + } + + private def verifyTopicMetadataPartitionLeaderOrEndpointNotAvailable(cache: MetadataCache, + metadataCacheBrokerId: Int, + brokers: Seq[UpdateMetadataBroker], + listenerName: ListenerName, + leader: Int, + expectedError: Errors, + errorUnavailableListeners: Boolean): Unit = { + val topic = "topic" + + val zkVersion = 3 + val controllerId = 2 + val controllerEpoch = 1 + + val leaderEpoch = 1 + val partitionStates = Seq(new UpdateMetadataPartitionState() + .setTopicName(topic) + .setPartitionIndex(0) + .setControllerEpoch(controllerEpoch) + .setLeader(leader) + .setLeaderEpoch(leaderEpoch) + .setIsr(asList(0)) + .setZkVersion(zkVersion) + .setReplicas(asList(0))) + + val version = ApiKeys.UPDATE_METADATA.latestVersion + val updateMetadataRequest = new UpdateMetadataRequest.Builder(version, controllerId, controllerEpoch, brokerEpoch, + partitionStates.asJava, brokers.asJava, util.Collections.emptyMap()).build() + MetadataCacheTest.updateCache(cache, updateMetadataRequest) + + val topicMetadatas = cache.getTopicMetadata(Set(topic), listenerName, errorUnavailableListeners = errorUnavailableListeners) + assertEquals(1, topicMetadatas.size) + + val topicMetadata = topicMetadatas.head + assertEquals(Errors.NONE.code, topicMetadata.errorCode) + + val partitionMetadatas = topicMetadata.partitions + assertEquals(1, partitionMetadatas.size) + + val partitionMetadata = partitionMetadatas.get(0) + assertEquals(0, partitionMetadata.partitionIndex) + assertEquals(expectedError.code, partitionMetadata.errorCode) + assertFalse(partitionMetadata.isrNodes.isEmpty) + assertEquals(List(0), partitionMetadata.replicaNodes.asScala) + } + + @ParameterizedTest + @MethodSource(Array("cacheProvider")) + def getTopicMetadataReplicaNotAvailable(cache: MetadataCache): Unit = { + val topic = "topic" + + val zkVersion = 3 + val controllerId = 2 + val controllerEpoch = 1 + val securityProtocol = SecurityProtocol.PLAINTEXT + val listenerName = ListenerName.forSecurityProtocol(securityProtocol) + val brokers = Seq(new UpdateMetadataBroker() + .setId(0) + .setEndpoints(Seq(new UpdateMetadataEndpoint() + .setHost("foo") + .setPort(9092) + .setSecurityProtocol(securityProtocol.id) + .setListener(listenerName.value)).asJava)) + + // replica 1 is not available + val leader = 0 + val leaderEpoch = 0 + val replicas = asList[Integer](0, 1) + val isr = asList[Integer](0) + + val partitionStates = Seq( + new UpdateMetadataPartitionState() + .setTopicName(topic) + .setPartitionIndex(0) + .setControllerEpoch(controllerEpoch) + .setLeader(leader) + .setLeaderEpoch(leaderEpoch) + .setIsr(isr) + .setZkVersion(zkVersion) + .setReplicas(replicas)) + + val version = ApiKeys.UPDATE_METADATA.latestVersion + val updateMetadataRequest = new UpdateMetadataRequest.Builder(version, controllerId, controllerEpoch, brokerEpoch, + partitionStates.asJava, brokers.asJava, util.Collections.emptyMap()).build() + MetadataCacheTest.updateCache(cache, updateMetadataRequest) + + // Validate errorUnavailableEndpoints = false + val topicMetadatas = cache.getTopicMetadata(Set(topic), listenerName, errorUnavailableEndpoints = false) + assertEquals(1, topicMetadatas.size) + + val topicMetadata = topicMetadatas.head + assertEquals(Errors.NONE.code(), topicMetadata.errorCode) + + val partitionMetadatas = 
topicMetadata.partitions + assertEquals(1, partitionMetadatas.size) + + val partitionMetadata = partitionMetadatas.get(0) + assertEquals(0, partitionMetadata.partitionIndex) + assertEquals(Errors.NONE.code, partitionMetadata.errorCode) + assertEquals(Set(0, 1), partitionMetadata.replicaNodes.asScala.toSet) + assertEquals(Set(0), partitionMetadata.isrNodes.asScala.toSet) + + // Validate errorUnavailableEndpoints = true + val topicMetadatasWithError = cache.getTopicMetadata(Set(topic), listenerName, errorUnavailableEndpoints = true) + assertEquals(1, topicMetadatasWithError.size) + + val topicMetadataWithError = topicMetadatasWithError.head + assertEquals(Errors.NONE.code, topicMetadataWithError.errorCode) + + val partitionMetadatasWithError = topicMetadataWithError.partitions() + assertEquals(1, partitionMetadatasWithError.size) + + val partitionMetadataWithError = partitionMetadatasWithError.get(0) + assertEquals(0, partitionMetadataWithError.partitionIndex) + assertEquals(Errors.REPLICA_NOT_AVAILABLE.code, partitionMetadataWithError.errorCode) + assertEquals(Set(0), partitionMetadataWithError.replicaNodes.asScala.toSet) + assertEquals(Set(0), partitionMetadataWithError.isrNodes.asScala.toSet) + } + + @ParameterizedTest + @MethodSource(Array("cacheProvider")) + def getTopicMetadataIsrNotAvailable(cache: MetadataCache): Unit = { + val topic = "topic" + + val zkVersion = 3 + val controllerId = 2 + val controllerEpoch = 1 + val securityProtocol = SecurityProtocol.PLAINTEXT + val listenerName = ListenerName.forSecurityProtocol(securityProtocol) + val brokers = Seq(new UpdateMetadataBroker() + .setId(0) + .setRack("rack1") + .setEndpoints(Seq(new UpdateMetadataEndpoint() + .setHost("foo") + .setPort(9092) + .setSecurityProtocol(securityProtocol.id) + .setListener(listenerName.value)).asJava)) + + // replica 1 is not available + val leader = 0 + val leaderEpoch = 0 + val replicas = asList[Integer](0) + val isr = asList[Integer](0, 1) + + val partitionStates = Seq(new UpdateMetadataPartitionState() + .setTopicName(topic) + .setPartitionIndex(0) + .setControllerEpoch(controllerEpoch) + .setLeader(leader) + .setLeaderEpoch(leaderEpoch) + .setIsr(isr) + .setZkVersion(zkVersion) + .setReplicas(replicas)) + + val version = ApiKeys.UPDATE_METADATA.latestVersion + val updateMetadataRequest = new UpdateMetadataRequest.Builder(version, controllerId, controllerEpoch, brokerEpoch, + partitionStates.asJava, brokers.asJava, util.Collections.emptyMap()).build() + MetadataCacheTest.updateCache(cache, updateMetadataRequest) + + // Validate errorUnavailableEndpoints = false + val topicMetadatas = cache.getTopicMetadata(Set(topic), listenerName, errorUnavailableEndpoints = false) + assertEquals(1, topicMetadatas.size) + + val topicMetadata = topicMetadatas.head + assertEquals(Errors.NONE.code(), topicMetadata.errorCode) + + val partitionMetadatas = topicMetadata.partitions + assertEquals(1, partitionMetadatas.size) + + val partitionMetadata = partitionMetadatas.get(0) + assertEquals(0, partitionMetadata.partitionIndex) + assertEquals(Errors.NONE.code, partitionMetadata.errorCode) + assertEquals(Set(0), partitionMetadata.replicaNodes.asScala.toSet) + assertEquals(Set(0, 1), partitionMetadata.isrNodes.asScala.toSet) + + // Validate errorUnavailableEndpoints = true + val topicMetadatasWithError = cache.getTopicMetadata(Set(topic), listenerName, errorUnavailableEndpoints = true) + assertEquals(1, topicMetadatasWithError.size) + + val topicMetadataWithError = topicMetadatasWithError.head + 
assertEquals(Errors.NONE.code, topicMetadataWithError.errorCode) + + val partitionMetadatasWithError = topicMetadataWithError.partitions + assertEquals(1, partitionMetadatasWithError.size) + + val partitionMetadataWithError = partitionMetadatasWithError.get(0) + assertEquals(0, partitionMetadataWithError.partitionIndex) + assertEquals(Errors.REPLICA_NOT_AVAILABLE.code, partitionMetadataWithError.errorCode) + assertEquals(Set(0), partitionMetadataWithError.replicaNodes.asScala.toSet) + assertEquals(Set(0), partitionMetadataWithError.isrNodes.asScala.toSet) + } + + @ParameterizedTest + @MethodSource(Array("cacheProvider")) + def getTopicMetadataWithNonSupportedSecurityProtocol(cache: MetadataCache): Unit = { + val topic = "topic" + val securityProtocol = SecurityProtocol.PLAINTEXT + val brokers = Seq(new UpdateMetadataBroker() + .setId(0) + .setRack("") + .setEndpoints(Seq(new UpdateMetadataEndpoint() + .setHost("foo") + .setPort(9092) + .setSecurityProtocol(securityProtocol.id) + .setListener(ListenerName.forSecurityProtocol(securityProtocol).value)).asJava)) + val controllerEpoch = 1 + val leader = 0 + val leaderEpoch = 0 + val replicas = asList[Integer](0) + val isr = asList[Integer](0, 1) + val partitionStates = Seq(new UpdateMetadataPartitionState() + .setTopicName(topic) + .setPartitionIndex(0) + .setControllerEpoch(controllerEpoch) + .setLeader(leader) + .setLeaderEpoch(leaderEpoch) + .setIsr(isr) + .setZkVersion(3) + .setReplicas(replicas)) + val version = ApiKeys.UPDATE_METADATA.latestVersion + val updateMetadataRequest = new UpdateMetadataRequest.Builder(version, 2, controllerEpoch, brokerEpoch, partitionStates.asJava, + brokers.asJava, util.Collections.emptyMap()).build() + MetadataCacheTest.updateCache(cache, updateMetadataRequest) + + val topicMetadata = cache.getTopicMetadata(Set(topic), ListenerName.forSecurityProtocol(SecurityProtocol.SSL)) + assertEquals(1, topicMetadata.size) + assertEquals(1, topicMetadata.head.partitions.size) + assertEquals(RecordBatch.NO_PARTITION_LEADER_EPOCH, topicMetadata.head.partitions.get(0).leaderId) + } + + @ParameterizedTest + @MethodSource(Array("cacheProvider")) + def getAliveBrokersShouldNotBeMutatedByUpdateCache(cache: MetadataCache): Unit = { + val topic = "topic" + + def updateCache(brokerIds: Seq[Int]): Unit = { + val brokers = brokerIds.map { brokerId => + val securityProtocol = SecurityProtocol.PLAINTEXT + new UpdateMetadataBroker() + .setId(brokerId) + .setRack("") + .setEndpoints(Seq(new UpdateMetadataEndpoint() + .setHost("foo") + .setPort(9092) + .setSecurityProtocol(securityProtocol.id) + .setListener(ListenerName.forSecurityProtocol(securityProtocol).value)).asJava) + } + val controllerEpoch = 1 + val leader = 0 + val leaderEpoch = 0 + val replicas = asList[Integer](0) + val isr = asList[Integer](0, 1) + val partitionStates = Seq(new UpdateMetadataPartitionState() + .setTopicName(topic) + .setPartitionIndex(0) + .setControllerEpoch(controllerEpoch) + .setLeader(leader) + .setLeaderEpoch(leaderEpoch) + .setIsr(isr) + .setZkVersion(3) + .setReplicas(replicas)) + val version = ApiKeys.UPDATE_METADATA.latestVersion + val updateMetadataRequest = new UpdateMetadataRequest.Builder(version, 2, controllerEpoch, brokerEpoch, partitionStates.asJava, + brokers.asJava, util.Collections.emptyMap()).build() + MetadataCacheTest.updateCache(cache, updateMetadataRequest) + } + + val initialBrokerIds = (0 to 2) + updateCache(initialBrokerIds) + val aliveBrokersFromCache = cache.getAliveBrokers() + // This should not change `aliveBrokersFromCache` + 
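+    // i.e. getAliveBrokers should return a snapshot, so the broker list captured above must not be
+    // affected by the second updateCache call below.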
updateCache((0 to 3)) + assertEquals(initialBrokerIds.toSet, aliveBrokersFromCache.map(_.id).toSet) + } + + // This test runs only for the ZK cache, because KRaft mode doesn't support offline + // replicas yet. TODO: implement KAFKA-13005. + @ParameterizedTest + @MethodSource(Array("zkCacheProvider")) + def testGetClusterMetadataWithOfflineReplicas(cache: MetadataCache): Unit = { + val topic = "topic" + val topicPartition = new TopicPartition(topic, 0) + val securityProtocol = SecurityProtocol.PLAINTEXT + val listenerName = ListenerName.forSecurityProtocol(securityProtocol) + + val brokers = Seq( + new UpdateMetadataBroker() + .setId(0) + .setRack("") + .setEndpoints(Seq(new UpdateMetadataEndpoint() + .setHost("foo") + .setPort(9092) + .setSecurityProtocol(securityProtocol.id) + .setListener(listenerName.value)).asJava), + new UpdateMetadataBroker() + .setId(1) + .setEndpoints(Seq.empty.asJava) + ) + val controllerEpoch = 1 + val leader = 1 + val leaderEpoch = 0 + val replicas = asList[Integer](0, 1) + val isr = asList[Integer](0, 1) + val offline = asList[Integer](1) + val partitionStates = Seq(new UpdateMetadataPartitionState() + .setTopicName(topic) + .setPartitionIndex(topicPartition.partition) + .setControllerEpoch(controllerEpoch) + .setLeader(leader) + .setLeaderEpoch(leaderEpoch) + .setIsr(isr) + .setZkVersion(3) + .setReplicas(replicas) + .setOfflineReplicas(offline)) + val version = ApiKeys.UPDATE_METADATA.latestVersion + val updateMetadataRequest = new UpdateMetadataRequest.Builder(version, 2, controllerEpoch, brokerEpoch, partitionStates.asJava, + brokers.asJava, Collections.emptyMap()).build() + MetadataCacheTest.updateCache(cache, updateMetadataRequest) + + val expectedNode0 = new Node(0, "foo", 9092) + val expectedNode1 = new Node(1, "", -1) + + val cluster = cache.getClusterMetadata("clusterId", listenerName) + assertEquals(expectedNode0, cluster.nodeById(0)) + assertNull(cluster.nodeById(1)) + assertEquals(expectedNode1, cluster.leaderFor(topicPartition)) + + val partitionInfo = cluster.partition(topicPartition) + assertEquals(expectedNode1, partitionInfo.leader) + assertEquals(Seq(expectedNode0, expectedNode1), partitionInfo.replicas.toSeq) + assertEquals(Seq(expectedNode0, expectedNode1), partitionInfo.inSyncReplicas.toSeq) + assertEquals(Seq(expectedNode1), partitionInfo.offlineReplicas.toSeq) + } +} diff --git a/core/src/test/scala/unit/kafka/server/MetadataRequestTest.scala b/core/src/test/scala/unit/kafka/server/MetadataRequestTest.scala new file mode 100644 index 0000000..d91d58e --- /dev/null +++ b/core/src/test/scala/unit/kafka/server/MetadataRequestTest.scala @@ -0,0 +1,377 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package kafka.server + +import java.util.Optional + +import kafka.utils.TestUtils +import org.apache.kafka.common.Uuid +import org.apache.kafka.common.errors.UnsupportedVersionException +import org.apache.kafka.common.internals.Topic +import org.apache.kafka.common.protocol.Errors +import org.apache.kafka.common.requests.{MetadataRequest, MetadataResponse} +import org.apache.kafka.metadata.BrokerState +import org.apache.kafka.test.TestUtils.isValidClusterId +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.{BeforeEach, Test, TestInfo} + +import scala.collection.Seq +import scala.jdk.CollectionConverters._ + +class MetadataRequestTest extends AbstractMetadataRequestTest { + + @BeforeEach + override def setUp(testInfo: TestInfo): Unit = { + doSetup(testInfo, createOffsetsTopic = false) + } + + @Test + def testClusterIdWithRequestVersion1(): Unit = { + val v1MetadataResponse = sendMetadataRequest(MetadataRequest.Builder.allTopics.build(1.toShort)) + val v1ClusterId = v1MetadataResponse.clusterId + assertNull(v1ClusterId, s"v1 clusterId should be null") + } + + @Test + def testClusterIdIsValid(): Unit = { + val metadataResponse = sendMetadataRequest(MetadataRequest.Builder.allTopics.build(2.toShort)) + isValidClusterId(metadataResponse.clusterId) + } + + @Test + def testControllerId(): Unit = { + val controllerServer = servers.find(_.kafkaController.isActive).get + val controllerId = controllerServer.config.brokerId + val metadataResponse = sendMetadataRequest(MetadataRequest.Builder.allTopics.build(1.toShort)) + + assertEquals(controllerId, + metadataResponse.controller.id, "Controller id should match the active controller") + + // Fail over the controller + controllerServer.shutdown() + controllerServer.startup() + + val controllerServer2 = servers.find(_.kafkaController.isActive).get + val controllerId2 = controllerServer2.config.brokerId + assertNotEquals(controllerId, controllerId2, "Controller id should switch to a new broker") + TestUtils.waitUntilTrue(() => { + val metadataResponse2 = sendMetadataRequest(MetadataRequest.Builder.allTopics.build(1.toShort)) + metadataResponse2.controller != null && controllerServer2.dataPlaneRequestProcessor.brokerId == metadataResponse2.controller.id + }, "Controller id should match the active controller after failover", 5000) + } + + @Test + def testRack(): Unit = { + val metadataResponse = sendMetadataRequest(MetadataRequest.Builder.allTopics.build(1.toShort)) + // Validate rack matches what's set in generateConfigs() above + metadataResponse.brokers.forEach { broker => + assertEquals(s"rack/${broker.id}", broker.rack, "Rack information should match config") + } + } + + @Test + def testIsInternal(): Unit = { + val internalTopic = Topic.GROUP_METADATA_TOPIC_NAME + val notInternalTopic = "notInternal" + // create the topics + createTopic(internalTopic, 3, 2) + createTopic(notInternalTopic, 3, 2) + + val metadataResponse = sendMetadataRequest(MetadataRequest.Builder.allTopics.build(1.toShort)) + assertTrue(metadataResponse.errors.isEmpty, "Response should have no errors") + + val topicMetadata = metadataResponse.topicMetadata.asScala + val internalTopicMetadata = topicMetadata.find(_.topic == internalTopic).get + val notInternalTopicMetadata = topicMetadata.find(_.topic == notInternalTopic).get + + assertTrue(internalTopicMetadata.isInternal, "internalTopic should show isInternal") + assertFalse(notInternalTopicMetadata.isInternal, "notInternalTopic topic not should show isInternal") + + 
assertEquals(Set(internalTopic).asJava, metadataResponse.buildCluster().internalTopics) + } + + @Test + def testNoTopicsRequest(): Unit = { + // create some topics + createTopic("t1", 3, 2) + createTopic("t2", 3, 2) + + // v0, Doesn't support a "no topics" request + // v1, Empty list represents "no topics" + val metadataResponse = sendMetadataRequest(new MetadataRequest.Builder(List[String]().asJava, true, 1.toShort).build) + assertTrue(metadataResponse.errors.isEmpty, "Response should have no errors") + assertTrue(metadataResponse.topicMetadata.isEmpty, "Response should have no topics") + } + + @Test + def testAutoTopicCreation(): Unit = { + val topic1 = "t1" + val topic2 = "t2" + val topic3 = "t3" + val topic4 = "t4" + val topic5 = "t5" + createTopic(topic1) + + val response1 = sendMetadataRequest(new MetadataRequest.Builder(Seq(topic1, topic2).asJava, true).build()) + assertNull(response1.errors.get(topic1)) + checkAutoCreatedTopic(topic2, response1) + + // The default behavior in old versions of the metadata API is to allow topic creation, so + // protocol downgrades should happen gracefully when auto-creation is explicitly requested. + val response2 = sendMetadataRequest(new MetadataRequest.Builder(Seq(topic3).asJava, true).build(1)) + checkAutoCreatedTopic(topic3, response2) + + // V3 doesn't support a configurable allowAutoTopicCreation, so disabling auto-creation is not supported + assertThrows(classOf[UnsupportedVersionException], () => sendMetadataRequest(new MetadataRequest(requestData(List(topic4), false), 3.toShort))) + + // V4 and higher support a configurable allowAutoTopicCreation + val response3 = sendMetadataRequest(new MetadataRequest.Builder(Seq(topic4, topic5).asJava, false, 4.toShort).build) + assertEquals(Errors.UNKNOWN_TOPIC_OR_PARTITION, response3.errors.get(topic4)) + assertEquals(Errors.UNKNOWN_TOPIC_OR_PARTITION, response3.errors.get(topic5)) + assertEquals(None, zkClient.getTopicPartitionCount(topic5)) + } + + @Test + def testAutoCreateTopicWithInvalidReplicationFactor(): Unit = { + // Shutdown all but one broker so that the number of brokers is less than the default replication factor + servers.tail.foreach(_.shutdown()) + servers.tail.foreach(_.awaitShutdown()) + + val topic1 = "testAutoCreateTopic" + val response1 = sendMetadataRequest(new MetadataRequest.Builder(Seq(topic1).asJava, true).build) + assertEquals(1, response1.topicMetadata.size) + val topicMetadata = response1.topicMetadata.asScala.head + assertEquals(Errors.INVALID_REPLICATION_FACTOR, topicMetadata.error) + assertEquals(topic1, topicMetadata.topic) + assertEquals(0, topicMetadata.partitionMetadata.size) + } + + @Test + def testAutoCreateOfCollidingTopics(): Unit = { + val topic1 = "testAutoCreate.Topic" + val topic2 = "testAutoCreate_Topic" + val response1 = sendMetadataRequest(new MetadataRequest.Builder(Seq(topic1, topic2).asJava, true).build) + assertEquals(2, response1.topicMetadata.size) + + val responseMap = response1.topicMetadata.asScala.map(metadata => (metadata.topic(), metadata.error)).toMap + + assertEquals(Set(topic1, topic2), responseMap.keySet) + // The topic creation will be delayed, and the name collision error will be swallowed. 
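+    // Topic names that differ only in '.' vs '_' collide, so exactly one of the two topics is
+    // auto-created (reporting LEADER_NOT_AVAILABLE until its leader is elected) and the other is
+    // rejected with INVALID_TOPIC_EXCEPTION.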
+ assertEquals(Set(Errors.LEADER_NOT_AVAILABLE, Errors.INVALID_TOPIC_EXCEPTION), responseMap.values.toSet) + + val topicCreated = responseMap.head._1 + TestUtils.waitUntilLeaderIsElectedOrChanged(zkClient, topicCreated, 0) + TestUtils.waitForPartitionMetadata(servers, topicCreated, 0) + + // retry the metadata for the first auto created topic + val response2 = sendMetadataRequest(new MetadataRequest.Builder(Seq(topicCreated).asJava, true).build) + val topicMetadata1 = response2.topicMetadata.asScala.head + assertEquals(Errors.NONE, topicMetadata1.error) + assertEquals(Seq(Errors.NONE), topicMetadata1.partitionMetadata.asScala.map(_.error)) + assertEquals(1, topicMetadata1.partitionMetadata.size) + val partitionMetadata = topicMetadata1.partitionMetadata.asScala.head + assertEquals(0, partitionMetadata.partition) + assertEquals(2, partitionMetadata.replicaIds.size) + assertTrue(partitionMetadata.leaderId.isPresent) + assertTrue(partitionMetadata.leaderId.get >= 0) + } + + @Test + def testAllTopicsRequest(): Unit = { + // create some topics + createTopic("t1", 3, 2) + createTopic("t2", 3, 2) + + // v0, Empty list represents all topics + val metadataResponseV0 = sendMetadataRequest(new MetadataRequest(requestData(List(), true), 0.toShort)) + assertTrue(metadataResponseV0.errors.isEmpty, "V0 Response should have no errors") + assertEquals(2, metadataResponseV0.topicMetadata.size(), "V0 Response should have 2 (all) topics") + + // v1, Null represents all topics + val metadataResponseV1 = sendMetadataRequest(MetadataRequest.Builder.allTopics.build(1.toShort)) + assertTrue(metadataResponseV1.errors.isEmpty, "V1 Response should have no errors") + assertEquals(2, metadataResponseV1.topicMetadata.size(), "V1 Response should have 2 (all) topics") + } + + @Test + def testTopicIdsInResponse(): Unit = { + val replicaAssignment = Map(0 -> Seq(1, 2, 0), 1 -> Seq(2, 0, 1)) + val topic1 = "topic1" + val topic2 = "topic2" + createTopic(topic1, replicaAssignment) + createTopic(topic2, replicaAssignment) + + // if version < 9, return ZERO_UUID in MetadataResponse + val resp1 = sendMetadataRequest(new MetadataRequest.Builder(Seq(topic1, topic2).asJava, true, 0, 9).build(), Some(controllerSocketServer)) + assertEquals(2, resp1.topicMetadata.size) + resp1.topicMetadata.forEach { topicMetadata => + assertEquals(Errors.NONE, topicMetadata.error) + assertEquals(Uuid.ZERO_UUID, topicMetadata.topicId()) + } + + // from version 10, UUID will be included in MetadataResponse + val resp2 = sendMetadataRequest(new MetadataRequest.Builder(Seq(topic1, topic2).asJava, true, 10, 10).build(), Some(notControllerSocketServer)) + assertEquals(2, resp2.topicMetadata.size) + resp2.topicMetadata.forEach { topicMetadata => + assertEquals(Errors.NONE, topicMetadata.error) + assertNotEquals(Uuid.ZERO_UUID, topicMetadata.topicId()) + assertNotNull(topicMetadata.topicId()) + } + } + + /** + * Preferred replica should be the first item in the replicas list + */ + @Test + def testPreferredReplica(): Unit = { + val replicaAssignment = Map(0 -> Seq(1, 2, 0), 1 -> Seq(2, 0, 1)) + createTopic("t1", replicaAssignment) + // Call controller and one different broker to ensure that metadata propagation works correctly + val responses = Seq( + sendMetadataRequest(new MetadataRequest.Builder(Seq("t1").asJava, true).build(), Some(controllerSocketServer)), + sendMetadataRequest(new MetadataRequest.Builder(Seq("t1").asJava, true).build(), Some(notControllerSocketServer)) + ) + responses.foreach { response => + assertEquals(1, response.topicMetadata.size) 
+ val topicMetadata = response.topicMetadata.iterator.next() + assertEquals(Errors.NONE, topicMetadata.error) + assertEquals("t1", topicMetadata.topic) + assertEquals(Set(0, 1), topicMetadata.partitionMetadata.asScala.map(_.partition).toSet) + topicMetadata.partitionMetadata.forEach { partitionMetadata => + val assignment = replicaAssignment(partitionMetadata.partition) + assertEquals(assignment, partitionMetadata.replicaIds.asScala) + assertEquals(assignment, partitionMetadata.inSyncReplicaIds.asScala) + assertEquals(Optional.of(assignment.head), partitionMetadata.leaderId) + } + } + } + + @Test + def testReplicaDownResponse(): Unit = { + val replicaDownTopic = "replicaDown" + val replicaCount = 3 + + // create a topic with 3 replicas + createTopic(replicaDownTopic, 1, replicaCount) + + // Kill a replica node that is not the leader + val metadataResponse = sendMetadataRequest(new MetadataRequest.Builder(List(replicaDownTopic).asJava, true).build()) + val partitionMetadata = metadataResponse.topicMetadata.asScala.head.partitionMetadata.asScala.head + val downNode = servers.find { server => + val serverId = server.dataPlaneRequestProcessor.brokerId + val leaderId = partitionMetadata.leaderId + val replicaIds = partitionMetadata.replicaIds.asScala + leaderId.isPresent && leaderId.get() != serverId && replicaIds.contains(serverId) + }.get + downNode.shutdown() + + TestUtils.waitUntilTrue(() => { + val response = sendMetadataRequest(new MetadataRequest.Builder(List(replicaDownTopic).asJava, true).build()) + !response.brokers.asScala.exists(_.id == downNode.dataPlaneRequestProcessor.brokerId) + }, "Replica was not found down", 5000) + + // Validate version 0 still filters unavailable replicas and contains error + val v0MetadataResponse = sendMetadataRequest(new MetadataRequest(requestData(List(replicaDownTopic), true), 0.toShort)) + val v0BrokerIds = v0MetadataResponse.brokers().asScala.map(_.id).toSeq + assertTrue(v0MetadataResponse.errors.isEmpty, "Response should have no errors") + assertFalse(v0BrokerIds.contains(downNode.config.brokerId), s"The downed broker should not be in the brokers list") + assertTrue(v0MetadataResponse.topicMetadata.size == 1, "Response should have one topic") + val v0PartitionMetadata = v0MetadataResponse.topicMetadata.asScala.head.partitionMetadata.asScala.head + assertTrue(v0PartitionMetadata.error == Errors.REPLICA_NOT_AVAILABLE, "PartitionMetadata should have an error") + assertTrue(v0PartitionMetadata.replicaIds.size == replicaCount - 1, s"Response should have ${replicaCount - 1} replicas") + + // Validate version 1 returns unavailable replicas with no error + val v1MetadataResponse = sendMetadataRequest(new MetadataRequest.Builder(List(replicaDownTopic).asJava, true).build(1)) + val v1BrokerIds = v1MetadataResponse.brokers().asScala.map(_.id).toSeq + assertTrue(v1MetadataResponse.errors.isEmpty, "Response should have no errors") + assertFalse(v1BrokerIds.contains(downNode.config.brokerId), s"The downed broker should not be in the brokers list") + assertEquals(1, v1MetadataResponse.topicMetadata.size, "Response should have one topic") + val v1PartitionMetadata = v1MetadataResponse.topicMetadata.asScala.head.partitionMetadata.asScala.head + assertEquals(Errors.NONE, v1PartitionMetadata.error, "PartitionMetadata should have no errors") + assertEquals(replicaCount, v1PartitionMetadata.replicaIds.size, s"Response should have $replicaCount replicas") + } + + @Test + def testIsrAfterBrokerShutDownAndJoinsBack(): Unit = { + def checkIsr(servers: Seq[KafkaServer], 
topic: String): Unit = { + val activeBrokers = servers.filter(_.brokerState != BrokerState.NOT_RUNNING) + val expectedIsr = activeBrokers.map(_.config.brokerId).toSet + + // Assert that topic metadata at new brokers is updated correctly + activeBrokers.foreach { broker => + var actualIsr = Set.empty[Int] + TestUtils.waitUntilTrue(() => { + val metadataResponse = sendMetadataRequest(new MetadataRequest.Builder(Seq(topic).asJava, false).build, + Some(brokerSocketServer(broker.config.brokerId))) + val firstPartitionMetadata = metadataResponse.topicMetadata.asScala.headOption.flatMap(_.partitionMetadata.asScala.headOption) + actualIsr = firstPartitionMetadata.map { partitionMetadata => + partitionMetadata.inSyncReplicaIds.asScala.map(Int.unbox).toSet + }.getOrElse(Set.empty) + expectedIsr == actualIsr + }, s"Topic metadata not updated correctly in broker $broker\n" + + s"Expected ISR: $expectedIsr \n" + + s"Actual ISR : $actualIsr") + } + } + + val topic = "isr-after-broker-shutdown" + val replicaCount = 3 + createTopic(topic, 1, replicaCount) + + servers.last.shutdown() + servers.last.awaitShutdown() + servers.last.startup() + + checkIsr(servers, topic) + } + + @Test + def testAliveBrokersWithNoTopics(): Unit = { + def checkMetadata(servers: Seq[KafkaServer], expectedBrokersCount: Int): Unit = { + var controllerMetadataResponse: Option[MetadataResponse] = None + TestUtils.waitUntilTrue(() => { + val metadataResponse = sendMetadataRequest(MetadataRequest.Builder.allTopics.build, + Some(controllerSocketServer)) + controllerMetadataResponse = Some(metadataResponse) + metadataResponse.brokers.size == expectedBrokersCount + }, s"Expected $expectedBrokersCount brokers, but there are ${controllerMetadataResponse.get.brokers.size} " + + "according to the Controller") + + val brokersInController = controllerMetadataResponse.get.brokers.asScala.toSeq.sortBy(_.id) + + // Assert that metadata is propagated correctly + servers.filter(_.brokerState != BrokerState.NOT_RUNNING).foreach { broker => + TestUtils.waitUntilTrue(() => { + val metadataResponse = sendMetadataRequest(MetadataRequest.Builder.allTopics.build, + Some(brokerSocketServer(broker.config.brokerId))) + val brokers = metadataResponse.brokers.asScala.toSeq.sortBy(_.id) + val topicMetadata = metadataResponse.topicMetadata.asScala.toSeq.sortBy(_.topic) + brokersInController == brokers && metadataResponse.topicMetadata.asScala.toSeq.sortBy(_.topic) == topicMetadata + }, s"Topic metadata not updated correctly") + } + } + + val serverToShutdown = servers.filterNot(_.kafkaController.isActive).last + serverToShutdown.shutdown() + serverToShutdown.awaitShutdown() + checkMetadata(servers, servers.size - 1) + + serverToShutdown.startup() + checkMetadata(servers, servers.size) + } +} diff --git a/core/src/test/scala/unit/kafka/server/MetadataRequestWithForwardingTest.scala b/core/src/test/scala/unit/kafka/server/MetadataRequestWithForwardingTest.scala new file mode 100644 index 0000000..3580e2b --- /dev/null +++ b/core/src/test/scala/unit/kafka/server/MetadataRequestWithForwardingTest.scala @@ -0,0 +1,111 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.server + +import kafka.utils.TestUtils +import org.apache.kafka.common.errors.UnsupportedVersionException +import org.apache.kafka.common.protocol.Errors +import org.apache.kafka.common.requests.MetadataRequest +import org.junit.jupiter.api.Assertions.{assertEquals, assertNull, assertThrows, assertTrue} +import org.junit.jupiter.api.{BeforeEach, Test, TestInfo} + +import scala.collection.Seq +import scala.jdk.CollectionConverters._ + +class MetadataRequestWithForwardingTest extends AbstractMetadataRequestTest { + + @BeforeEach + override def setUp(testInfo: TestInfo): Unit = { + doSetup(testInfo, createOffsetsTopic = false) + } + + override def enableForwarding: Boolean = true + + @Test + def testAutoTopicCreation(): Unit = { + val topic1 = "t1" + val topic2 = "t2" + val topic3 = "t3" + val topic4 = "t4" + val topic5 = "t5" + createTopic(topic1) + + val response1 = sendMetadataRequest(new MetadataRequest.Builder(Seq(topic1, topic2).asJava, true).build()) + assertNull(response1.errors.get(topic1)) + checkAutoCreatedTopic(topic2, response1) + + // The default behavior in old versions of the metadata API is to allow topic creation, so + // protocol downgrades should happen gracefully when auto-creation is explicitly requested. + val response2 = sendMetadataRequest(new MetadataRequest.Builder(Seq(topic3).asJava, true).build(1)) + checkAutoCreatedTopic(topic3, response2) + + // V3 doesn't support a configurable allowAutoTopicCreation, so disabling auto-creation is not supported + assertThrows(classOf[UnsupportedVersionException], () => sendMetadataRequest(new MetadataRequest(requestData(List(topic4), false), 3.toShort))) + + // V4 and higher support a configurable allowAutoTopicCreation + val response3 = sendMetadataRequest(new MetadataRequest.Builder(Seq(topic4, topic5).asJava, false, 4.toShort).build) + assertEquals(Errors.UNKNOWN_TOPIC_OR_PARTITION, response3.errors.get(topic4)) + assertEquals(Errors.UNKNOWN_TOPIC_OR_PARTITION, response3.errors.get(topic5)) + assertEquals(None, zkClient.getTopicPartitionCount(topic5)) + } + + @Test + def testAutoCreateTopicWithInvalidReplicationFactor(): Unit = { + // Shutdown all but one broker so that the number of brokers is less than the default replication factor + servers.tail.foreach(_.shutdown()) + servers.tail.foreach(_.awaitShutdown()) + + val topic1 = "testAutoCreateTopic" + val response1 = sendMetadataRequest(new MetadataRequest.Builder(Seq(topic1).asJava, true).build) + assertEquals(1, response1.topicMetadata.size) + val topicMetadata = response1.topicMetadata.asScala.head + assertEquals(Errors.INVALID_REPLICATION_FACTOR, topicMetadata.error) + assertEquals(topic1, topicMetadata.topic) + assertEquals(0, topicMetadata.partitionMetadata.size) + } + + @Test + def testAutoCreateOfCollidingTopics(): Unit = { + val topic1 = "testAutoCreate.Topic" + val topic2 = "testAutoCreate_Topic" + val response1 = sendMetadataRequest(new MetadataRequest.Builder(Seq(topic1, topic2).asJava, true).build) + assertEquals(2, response1.topicMetadata.size) + + val responseMap = response1.topicMetadata.asScala.map(metadata => 
(metadata.topic(), metadata.error)).toMap + + assertEquals(Set(topic1, topic2), responseMap.keySet) + // The topic creation will be delayed, and the name collision error will be swallowed. + assertEquals(Set(Errors.LEADER_NOT_AVAILABLE, Errors.INVALID_TOPIC_EXCEPTION), responseMap.values.toSet) + + val topicCreated = responseMap.head._1 + TestUtils.waitUntilLeaderIsElectedOrChanged(zkClient, topicCreated, 0) + TestUtils.waitForPartitionMetadata(servers, topicCreated, 0) + + // retry the metadata for the first auto created topic + val response2 = sendMetadataRequest(new MetadataRequest.Builder(Seq(topicCreated).asJava, true).build) + val topicMetadata1 = response2.topicMetadata.asScala.head + assertEquals(Errors.NONE, topicMetadata1.error) + assertEquals(Seq(Errors.NONE), topicMetadata1.partitionMetadata.asScala.map(_.error)) + assertEquals(1, topicMetadata1.partitionMetadata.size) + val partitionMetadata = topicMetadata1.partitionMetadata.asScala.head + assertEquals(0, partitionMetadata.partition) + assertEquals(2, partitionMetadata.replicaIds.size) + assertTrue(partitionMetadata.leaderId.isPresent) + assertTrue(partitionMetadata.leaderId.get >= 0) + } +} diff --git a/core/src/test/scala/unit/kafka/server/MockBrokerToControllerChannelManager.scala b/core/src/test/scala/unit/kafka/server/MockBrokerToControllerChannelManager.scala new file mode 100644 index 0000000..febd06f --- /dev/null +++ b/core/src/test/scala/unit/kafka/server/MockBrokerToControllerChannelManager.scala @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package kafka.server + +import kafka.utils.MockTime +import org.apache.kafka.clients.{ClientResponse, MockClient, NodeApiVersions} +import org.apache.kafka.common.protocol.Errors +import org.apache.kafka.common.requests.AbstractRequest + +class MockBrokerToControllerChannelManager( + val client: MockClient, + time: MockTime, + controllerNodeProvider: ControllerNodeProvider, + controllerApiVersions: NodeApiVersions = NodeApiVersions.create(), + val retryTimeoutMs: Int = 60000, + val requestTimeoutMs: Int = 30000 +) extends BrokerToControllerChannelManager { + private val unsentQueue = new java.util.ArrayDeque[BrokerToControllerQueueItem]() + + client.setNodeApiVersions(controllerApiVersions) + + override def start(): Unit = {} + + override def shutdown(): Unit = {} + + override def sendRequest( + request: AbstractRequest.Builder[_ <: AbstractRequest], + callback: ControllerRequestCompletionHandler + ): Unit = { + unsentQueue.add(BrokerToControllerQueueItem( + createdTimeMs = time.milliseconds(), + request = request, + callback = callback + )) + } + + override def controllerApiVersions(): Option[NodeApiVersions] = { + Some(controllerApiVersions) + } + + private[server] def handleResponse(request: BrokerToControllerQueueItem)(response: ClientResponse): Unit = { + if (response.authenticationException != null || response.versionMismatch != null) { + request.callback.onComplete(response) + } else if (response.wasDisconnected() || response.responseBody.errorCounts.containsKey(Errors.NOT_CONTROLLER)) { + unsentQueue.addFirst(request) + } else { + request.callback.onComplete(response) + } + } + + def poll(): Unit = { + val unsentIterator = unsentQueue.iterator() + var canSend = true + + while (canSend && unsentIterator.hasNext) { + val queueItem = unsentIterator.next() + val elapsedTimeMs = time.milliseconds() - queueItem.createdTimeMs + if (elapsedTimeMs >= retryTimeoutMs) { + queueItem.callback.onTimeout() + unsentIterator.remove() + } else { + controllerNodeProvider.get() match { + case Some(controller) if client.ready(controller, time.milliseconds()) => + val clientRequest = client.newClientRequest( + controller.idString, + queueItem.request, + queueItem.createdTimeMs, + true, // we expect response, + requestTimeoutMs, + handleResponse(queueItem) + ) + client.send(clientRequest, time.milliseconds()) + unsentIterator.remove() + + case _ => canSend = false + } + } + } + + client.poll(0L, time.milliseconds()) + } + +} diff --git a/core/src/test/scala/unit/kafka/server/OffsetFetchRequestTest.scala b/core/src/test/scala/unit/kafka/server/OffsetFetchRequestTest.scala new file mode 100644 index 0000000..477f3eb --- /dev/null +++ b/core/src/test/scala/unit/kafka/server/OffsetFetchRequestTest.scala @@ -0,0 +1,232 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.server + +import kafka.utils.TestUtils +import org.apache.kafka.clients.consumer.{ConsumerConfig, OffsetAndMetadata} +import org.apache.kafka.common.TopicPartition +import org.apache.kafka.common.message.OffsetFetchRequestData.{OffsetFetchRequestGroup, OffsetFetchRequestTopics} +import org.apache.kafka.common.protocol.{ApiKeys, Errors} +import org.apache.kafka.common.requests.OffsetFetchResponse.PartitionData +import org.apache.kafka.common.requests.{AbstractResponse, OffsetFetchRequest, OffsetFetchResponse} +import org.junit.jupiter.api.Assertions.{assertEquals, assertTrue} +import org.junit.jupiter.api.{BeforeEach, Test, TestInfo} +import java.util +import java.util.Collections.singletonList + +import scala.jdk.CollectionConverters._ +import java.util.{Optional, Properties} + +class OffsetFetchRequestTest extends BaseRequestTest { + + override def brokerCount: Int = 1 + + val brokerId: Integer = 0 + val offset = 15L + val leaderEpoch: Optional[Integer] = Optional.of(3) + val metadata = "metadata" + val topic = "topic" + val groupId = "groupId" + val groups: Seq[String] = (1 to 5).map(i => s"group$i") + val topics: Seq[String] = (1 to 3).map(i => s"topic$i") + val topic1List = singletonList(new TopicPartition(topics(0), 0)) + val topic1And2List = util.Arrays.asList( + new TopicPartition(topics(0), 0), + new TopicPartition(topics(1), 0), + new TopicPartition(topics(1), 1)) + val allTopicsList = util.Arrays.asList( + new TopicPartition(topics(0), 0), + new TopicPartition(topics(1), 0), + new TopicPartition(topics(1), 1), + new TopicPartition(topics(2), 0), + new TopicPartition(topics(2), 1), + new TopicPartition(topics(2), 2)) + val groupToPartitionMap: util.Map[String, util.List[TopicPartition]] = + new util.HashMap[String, util.List[TopicPartition]]() + groupToPartitionMap.put(groups(0), topic1List) + groupToPartitionMap.put(groups(1), topic1And2List) + groupToPartitionMap.put(groups(2), allTopicsList) + groupToPartitionMap.put(groups(3), null) + groupToPartitionMap.put(groups(4), null) + + override def brokerPropertyOverrides(properties: Properties): Unit = { + properties.put(KafkaConfig.BrokerIdProp, brokerId.toString) + properties.put(KafkaConfig.OffsetsTopicPartitionsProp, "1") + properties.put(KafkaConfig.OffsetsTopicReplicationFactorProp, "1") + properties.put(KafkaConfig.TransactionsTopicPartitionsProp, "1") + properties.put(KafkaConfig.TransactionsTopicReplicationFactorProp, "1") + properties.put(KafkaConfig.TransactionsTopicMinISRProp, "1") + } + + @BeforeEach + override def setUp(testInfo: TestInfo): Unit = { + doSetup(testInfo, createOffsetsTopic = false) + + TestUtils.createOffsetsTopic(zkClient, servers) + } + + @Test + def testOffsetFetchRequestSingleGroup(): Unit = { + createTopic(topic) + + val tpList = singletonList(new TopicPartition(topic, 0)) + consumerConfig.setProperty(ConsumerConfig.GROUP_ID_CONFIG, groupId) + commitOffsets(tpList) + + // testing from version 1 onward since version 0 read offsets from ZK + for (version <- 1 to ApiKeys.OFFSET_FETCH.latestVersion()) { + if (version < 8) { + val request = + if (version < 7) { + new OffsetFetchRequest.Builder( + groupId, false, tpList, false) + .build(version.asInstanceOf[Short]) + } else { + new OffsetFetchRequest.Builder( + groupId, false, tpList, true) + .build(version.asInstanceOf[Short]) + } + val response = connectAndReceive[OffsetFetchResponse](request) + val topicData = 
response.data().topics().get(0) + val partitionData = topicData.partitions().get(0) + if (version < 3) { + assertEquals(AbstractResponse.DEFAULT_THROTTLE_TIME, response.throttleTimeMs()) + } + verifySingleGroupResponse(version.asInstanceOf[Short], + response.error().code(), partitionData.errorCode(), topicData.name(), + partitionData.partitionIndex(), partitionData.committedOffset(), + partitionData.committedLeaderEpoch(), partitionData.metadata()) + } else { + val request = new OffsetFetchRequest.Builder( + Map(groupId -> tpList).asJava, false, false) + .build(version.asInstanceOf[Short]) + val response = connectAndReceive[OffsetFetchResponse](request) + val groupData = response.data().groups().get(0) + val topicData = groupData.topics().get(0) + val partitionData = topicData.partitions().get(0) + verifySingleGroupResponse(version.asInstanceOf[Short], + groupData.errorCode(), partitionData.errorCode(), topicData.name(), + partitionData.partitionIndex(), partitionData.committedOffset(), + partitionData.committedLeaderEpoch(), partitionData.metadata()) + } + } + } + + @Test + def testOffsetFetchRequestWithMultipleGroups(): Unit = { + createTopic(topics(0)) + createTopic(topics(1), numPartitions = 2) + createTopic(topics(2), numPartitions = 3) + + // create 5 consumers to commit offsets so we can fetch them later + val partitionMap = groupToPartitionMap.asScala.map(e => (e._1, Option(e._2).getOrElse(allTopicsList))) + groups.foreach { groupId => + consumerConfig.setProperty(ConsumerConfig.GROUP_ID_CONFIG, groupId) + commitOffsets(partitionMap(groupId)) + } + + for (version <- 8 to ApiKeys.OFFSET_FETCH.latestVersion()) { + val request = new OffsetFetchRequest.Builder(groupToPartitionMap, false, false) + .build(version.asInstanceOf[Short]) + val response = connectAndReceive[OffsetFetchResponse](request) + response.data.groups.asScala.map(_.groupId).foreach( groupId => + verifyResponse(response.groupLevelError(groupId), response.partitionDataMap(groupId), partitionMap(groupId)) + ) + } + } + + @Test + def testOffsetFetchRequestWithMultipleGroupsWithOneGroupRepeating(): Unit = { + createTopic(topics(0)) + createTopic(topics(1), numPartitions = 2) + createTopic(topics(2), numPartitions = 3) + + // create 5 consumers to commit offsets so we can fetch them later + val partitionMap = groupToPartitionMap.asScala.map(e => (e._1, Option(e._2).getOrElse(allTopicsList))) + groups.foreach { groupId => + consumerConfig.setProperty(ConsumerConfig.GROUP_ID_CONFIG, groupId) + commitOffsets(partitionMap(groupId)) + } + + for (version <- 8 to ApiKeys.OFFSET_FETCH.latestVersion()) { + val request = new OffsetFetchRequest.Builder(groupToPartitionMap, false, false) + .build(version.asInstanceOf[Short]) + val requestGroups = request.data().groups() + requestGroups.add( + // add the same group as before with different topic partitions + new OffsetFetchRequestGroup() + .setGroupId(groups(2)) + .setTopics(singletonList( + new OffsetFetchRequestTopics() + .setName(topics(0)) + .setPartitionIndexes(singletonList(0))))) + request.data().setGroups(requestGroups) + val response = connectAndReceive[OffsetFetchResponse](request) + response.data.groups.asScala.map(_.groupId).foreach( groupId => + if (groupId == "group3") // verify that the response gives back the latest changed topic partition list + verifyResponse(response.groupLevelError(groupId), response.partitionDataMap(groupId), topic1List) + else + verifyResponse(response.groupLevelError(groupId), response.partitionDataMap(groupId), partitionMap(groupId)) + ) + } + } + 
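+  // Checks the group- and partition-level error codes plus the committed offset, metadata and
+  // (for versions >= 5) the committed leader epoch of the single committed partition.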
+ private def verifySingleGroupResponse(version: Short, + responseError: Short, + partitionError: Short, + topicName: String, + partitionIndex: Integer, + committedOffset: Long, + committedLeaderEpoch: Integer, + partitionMetadata: String): Unit = { + assertEquals(Errors.NONE.code(), responseError) + assertEquals(topic, topicName) + assertEquals(0, partitionIndex) + assertEquals(offset, committedOffset) + if (version >= 5) { + assertEquals(leaderEpoch.get(), committedLeaderEpoch) + } + assertEquals(metadata, partitionMetadata) + assertEquals(Errors.NONE.code(), partitionError) + } + + private def verifyPartitionData(partitionData: OffsetFetchResponse.PartitionData): Unit = { + assertTrue(!partitionData.hasError) + assertEquals(offset, partitionData.offset) + assertEquals(metadata, partitionData.metadata) + assertEquals(leaderEpoch.get(), partitionData.leaderEpoch.get()) + } + + private def verifyResponse(groupLevelResponse: Errors, + partitionData: util.Map[TopicPartition, PartitionData], + topicList: util.List[TopicPartition]): Unit = { + assertEquals(Errors.NONE, groupLevelResponse) + assertTrue(partitionData.size() == topicList.size()) + topicList.forEach(t => verifyPartitionData(partitionData.get(t))) + } + + private def commitOffsets(tpList: util.List[TopicPartition]): Unit = { + val consumer = createConsumer() + consumer.assign(tpList) + val offsets = tpList.asScala.map{ + tp => (tp, new OffsetAndMetadata(offset, leaderEpoch, metadata)) + }.toMap.asJava + consumer.commitSync(offsets) + consumer.close() + } +} diff --git a/core/src/test/scala/unit/kafka/server/OffsetsForLeaderEpochRequestTest.scala b/core/src/test/scala/unit/kafka/server/OffsetsForLeaderEpochRequestTest.scala new file mode 100644 index 0000000..12d2fb3 --- /dev/null +++ b/core/src/test/scala/unit/kafka/server/OffsetsForLeaderEpochRequestTest.scala @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package kafka.server + +import java.util.Optional + +import kafka.utils.TestUtils +import org.apache.kafka.common.TopicPartition +import org.apache.kafka.common.message.OffsetForLeaderEpochRequestData.OffsetForLeaderPartition +import org.apache.kafka.common.message.OffsetForLeaderEpochRequestData.OffsetForLeaderTopic +import org.apache.kafka.common.message.OffsetForLeaderEpochRequestData.OffsetForLeaderTopicCollection +import org.apache.kafka.common.protocol.{ApiKeys, Errors} +import org.apache.kafka.common.record.RecordBatch +import org.apache.kafka.common.requests.{OffsetsForLeaderEpochRequest, OffsetsForLeaderEpochResponse} +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.Test + +import scala.jdk.CollectionConverters._ + +class OffsetsForLeaderEpochRequestTest extends BaseRequestTest { + + @Test + def testOffsetsForLeaderEpochErrorCodes(): Unit = { + val topic = "topic" + val partition = new TopicPartition(topic, 0) + val epochs = offsetForLeaderTopicCollectionFor(partition, 0, RecordBatch.NO_PARTITION_LEADER_EPOCH) + + val request = OffsetsForLeaderEpochRequest.Builder.forFollower( + ApiKeys.OFFSET_FOR_LEADER_EPOCH.latestVersion, epochs, 1).build() + + // Unknown topic + val randomBrokerId = servers.head.config.brokerId + assertResponseError(Errors.UNKNOWN_TOPIC_OR_PARTITION, randomBrokerId, request) + + val partitionToLeader = TestUtils.createTopic(zkClient, topic, numPartitions = 1, replicationFactor = 2, servers) + val replicas = zkClient.getReplicasForPartition(partition).toSet + val leader = partitionToLeader(partition.partition) + val follower = replicas.find(_ != leader).get + val nonReplica = servers.map(_.config.brokerId).find(!replicas.contains(_)).get + + assertResponseError(Errors.NOT_LEADER_OR_FOLLOWER, follower, request) + assertResponseError(Errors.NOT_LEADER_OR_FOLLOWER, nonReplica, request) + } + + @Test + def testCurrentEpochValidation(): Unit = { + val topic = "topic" + val topicPartition = new TopicPartition(topic, 0) + val partitionToLeader = TestUtils.createTopic(zkClient, topic, numPartitions = 1, replicationFactor = 3, servers) + val firstLeaderId = partitionToLeader(topicPartition.partition) + + def assertResponseErrorForEpoch(error: Errors, brokerId: Int, currentLeaderEpoch: Optional[Integer]): Unit = { + val epochs = offsetForLeaderTopicCollectionFor(topicPartition, 0, + currentLeaderEpoch.orElse(RecordBatch.NO_PARTITION_LEADER_EPOCH)) + val request = OffsetsForLeaderEpochRequest.Builder.forFollower( + ApiKeys.OFFSET_FOR_LEADER_EPOCH.latestVersion, epochs, 1).build() + assertResponseError(error, brokerId, request) + } + + // We need a leader change in order to check epoch fencing since the first epoch is 0 and + // -1 is treated as having no epoch at all + killBroker(firstLeaderId) + + // Check leader error codes + val secondLeaderId = TestUtils.awaitLeaderChange(servers, topicPartition, firstLeaderId) + val secondLeaderEpoch = TestUtils.findLeaderEpoch(secondLeaderId, topicPartition, servers) + assertResponseErrorForEpoch(Errors.NONE, secondLeaderId, Optional.empty()) + assertResponseErrorForEpoch(Errors.NONE, secondLeaderId, Optional.of(secondLeaderEpoch)) + assertResponseErrorForEpoch(Errors.FENCED_LEADER_EPOCH, secondLeaderId, Optional.of(secondLeaderEpoch - 1)) + assertResponseErrorForEpoch(Errors.UNKNOWN_LEADER_EPOCH, secondLeaderId, Optional.of(secondLeaderEpoch + 1)) + + // Check follower error codes + val followerId = TestUtils.findFollowerId(topicPartition, servers) + 
assertResponseErrorForEpoch(Errors.NOT_LEADER_OR_FOLLOWER, followerId, Optional.empty()) + assertResponseErrorForEpoch(Errors.NOT_LEADER_OR_FOLLOWER, followerId, Optional.of(secondLeaderEpoch)) + assertResponseErrorForEpoch(Errors.UNKNOWN_LEADER_EPOCH, followerId, Optional.of(secondLeaderEpoch + 1)) + assertResponseErrorForEpoch(Errors.FENCED_LEADER_EPOCH, followerId, Optional.of(secondLeaderEpoch - 1)) + } + + private def offsetForLeaderTopicCollectionFor( + topicPartition: TopicPartition, + leaderEpoch: Int, + currentLeaderEpoch: Int + ): OffsetForLeaderTopicCollection = { + new OffsetForLeaderTopicCollection(List( + new OffsetForLeaderTopic() + .setTopic(topicPartition.topic) + .setPartitions(List( + new OffsetForLeaderPartition() + .setPartition(topicPartition.partition) + .setLeaderEpoch(leaderEpoch) + .setCurrentLeaderEpoch(currentLeaderEpoch) + ).asJava)).iterator.asJava) + } + + private def assertResponseError(error: Errors, brokerId: Int, request: OffsetsForLeaderEpochRequest): Unit = { + val response = sendRequest(brokerId, request) + assertEquals(request.data.topics.size, response.data.topics.size) + response.data.topics.asScala.foreach { offsetForLeaderTopic => + assertEquals(request.data.topics.find(offsetForLeaderTopic.topic).partitions.size, + offsetForLeaderTopic.partitions.size) + offsetForLeaderTopic.partitions.asScala.foreach { offsetForLeaderPartition => + assertEquals(error.code(), offsetForLeaderPartition.errorCode()) + } + } + } + + private def sendRequest(brokerId: Int, request: OffsetsForLeaderEpochRequest): OffsetsForLeaderEpochResponse = { + connectAndReceive[OffsetsForLeaderEpochResponse](request, destination = brokerSocketServer(brokerId)) + } + +} diff --git a/core/src/test/scala/unit/kafka/server/ProduceRequestTest.scala b/core/src/test/scala/unit/kafka/server/ProduceRequestTest.scala new file mode 100644 index 0000000..7d3ded5 --- /dev/null +++ b/core/src/test/scala/unit/kafka/server/ProduceRequestTest.scala @@ -0,0 +1,253 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.server + +import java.nio.ByteBuffer +import java.util.{Collections, Properties} + +import kafka.log.LogConfig +import kafka.message.ZStdCompressionCodec +import kafka.metrics.KafkaYammerMetrics +import kafka.utils.TestUtils +import org.apache.kafka.common.TopicPartition +import org.apache.kafka.common.message.ProduceRequestData +import org.apache.kafka.common.protocol.Errors +import org.apache.kafka.common.record._ +import org.apache.kafka.common.requests.{ProduceRequest, ProduceResponse} +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.Test + +import scala.jdk.CollectionConverters._ + +/** + * Subclasses of `BaseProduceSendRequestTest` exercise the producer and produce request/response. 
This class + * complements those classes with tests that require lower-level access to the protocol. + */ +class ProduceRequestTest extends BaseRequestTest { + + val metricsKeySet = KafkaYammerMetrics.defaultRegistry.allMetrics.keySet.asScala + + @Test + def testSimpleProduceRequest(): Unit = { + val (partition, leader) = createTopicAndFindPartitionWithLeader("topic") + + def sendAndCheck(memoryRecords: MemoryRecords, expectedOffset: Long): Unit = { + val topicPartition = new TopicPartition("topic", partition) + val produceResponse = sendProduceRequest(leader, + ProduceRequest.forCurrentMagic(new ProduceRequestData() + .setTopicData(new ProduceRequestData.TopicProduceDataCollection(Collections.singletonList( + new ProduceRequestData.TopicProduceData() + .setName(topicPartition.topic()) + .setPartitionData(Collections.singletonList(new ProduceRequestData.PartitionProduceData() + .setIndex(topicPartition.partition()) + .setRecords(memoryRecords)))).iterator)) + .setAcks((-1).toShort) + .setTimeoutMs(3000) + .setTransactionalId(null)).build()) + assertEquals(1, produceResponse.data.responses.size) + val topicProduceResponse = produceResponse.data.responses.asScala.head + assertEquals(1, topicProduceResponse.partitionResponses.size) + val partitionProduceResponse = topicProduceResponse.partitionResponses.asScala.head + val tp = new TopicPartition(topicProduceResponse.name, partitionProduceResponse.index) + assertEquals(topicPartition, tp) + assertEquals(Errors.NONE, Errors.forCode(partitionProduceResponse.errorCode)) + assertEquals(expectedOffset, partitionProduceResponse.baseOffset) + assertEquals(-1, partitionProduceResponse.logAppendTimeMs) + assertTrue(partitionProduceResponse.recordErrors.isEmpty) + } + + sendAndCheck(MemoryRecords.withRecords(CompressionType.NONE, + new SimpleRecord(System.currentTimeMillis(), "key".getBytes, "value".getBytes)), 0) + + sendAndCheck(MemoryRecords.withRecords(CompressionType.GZIP, + new SimpleRecord(System.currentTimeMillis(), "key1".getBytes, "value1".getBytes), + new SimpleRecord(System.currentTimeMillis(), "key2".getBytes, "value2".getBytes)), 1) + } + + @Test + def testProduceWithInvalidTimestamp(): Unit = { + val topic = "topic" + val partition = 0 + val topicConfig = new Properties + topicConfig.setProperty(LogConfig.MessageTimestampDifferenceMaxMsProp, "1000") + val partitionToLeader = TestUtils.createTopic(zkClient, topic, 1, 1, servers, topicConfig) + val leader = partitionToLeader(partition) + + def createRecords(magicValue: Byte, timestamp: Long, codec: CompressionType): MemoryRecords = { + val buf = ByteBuffer.allocate(512) + val builder = MemoryRecords.builder(buf, magicValue, codec, TimestampType.CREATE_TIME, 0L) + builder.appendWithOffset(0, timestamp, null, "hello".getBytes) + builder.appendWithOffset(1, timestamp, null, "there".getBytes) + builder.appendWithOffset(2, timestamp, null, "beautiful".getBytes) + builder.build() + } + + val records = createRecords(RecordBatch.MAGIC_VALUE_V2, System.currentTimeMillis() - 1001L, CompressionType.GZIP) + val topicPartition = new TopicPartition("topic", partition) + val produceResponse = sendProduceRequest(leader, ProduceRequest.forCurrentMagic(new ProduceRequestData() + .setTopicData(new ProduceRequestData.TopicProduceDataCollection(Collections.singletonList( + new ProduceRequestData.TopicProduceData() + .setName(topicPartition.topic()) + .setPartitionData(Collections.singletonList(new ProduceRequestData.PartitionProduceData() + .setIndex(topicPartition.partition()) + 
.setRecords(records)))).iterator)) + .setAcks((-1).toShort) + .setTimeoutMs(3000) + .setTransactionalId(null)).build()) + + assertEquals(1, produceResponse.data.responses.size) + val topicProduceResponse = produceResponse.data.responses.asScala.head + assertEquals(1, topicProduceResponse.partitionResponses.size) + val partitionProduceResponse = topicProduceResponse.partitionResponses.asScala.head + val tp = new TopicPartition(topicProduceResponse.name, partitionProduceResponse.index) + assertEquals(topicPartition, tp) + assertEquals(Errors.INVALID_TIMESTAMP, Errors.forCode(partitionProduceResponse.errorCode)) + // there are 3 records with InvalidTimestampException created from inner function createRecords + assertEquals(3, partitionProduceResponse.recordErrors.size) + val recordErrors = partitionProduceResponse.recordErrors.asScala + recordErrors.indices.foreach(i => assertEquals(i, recordErrors(i).batchIndex)) + recordErrors.foreach(recordError => assertNotNull(recordError.batchIndexErrorMessage)) + assertEquals("One or more records have been rejected due to invalid timestamp", partitionProduceResponse.errorMessage) + } + + @Test + def testProduceToNonReplica(): Unit = { + val topic = "topic" + val partition = 0 + + // Create a single-partition topic and find a broker which is not the leader + val partitionToLeader = TestUtils.createTopic(zkClient, topic, numPartitions = 1, 1, servers) + val leader = partitionToLeader(partition) + val nonReplicaOpt = servers.find(_.config.brokerId != leader) + assertTrue(nonReplicaOpt.isDefined) + val nonReplicaId = nonReplicaOpt.get.config.brokerId + + // Send the produce request to the non-replica + val records = MemoryRecords.withRecords(CompressionType.NONE, new SimpleRecord("key".getBytes, "value".getBytes)) + val topicPartition = new TopicPartition("topic", partition) + val produceRequest = ProduceRequest.forCurrentMagic(new ProduceRequestData() + .setTopicData(new ProduceRequestData.TopicProduceDataCollection(Collections.singletonList( + new ProduceRequestData.TopicProduceData() + .setName(topicPartition.topic()) + .setPartitionData(Collections.singletonList(new ProduceRequestData.PartitionProduceData() + .setIndex(topicPartition.partition()) + .setRecords(records)))).iterator)) + .setAcks((-1).toShort) + .setTimeoutMs(3000) + .setTransactionalId(null)).build() + + val produceResponse = sendProduceRequest(nonReplicaId, produceRequest) + assertEquals(1, produceResponse.data.responses.size) + val topicProduceResponse = produceResponse.data.responses.asScala.head + assertEquals(1, topicProduceResponse.partitionResponses.size) + val partitionProduceResponse = topicProduceResponse.partitionResponses.asScala.head + assertEquals(Errors.NOT_LEADER_OR_FOLLOWER, Errors.forCode(partitionProduceResponse.errorCode)) + } + + /* returns a pair of partition id and leader id */ + private def createTopicAndFindPartitionWithLeader(topic: String): (Int, Int) = { + val partitionToLeader = TestUtils.createTopic(zkClient, topic, 3, 2, servers) + partitionToLeader.collectFirst { + case (partition, leader) if leader != -1 => (partition, leader) + }.getOrElse(throw new AssertionError(s"No leader elected for topic $topic")) + } + + @Test + def testCorruptLz4ProduceRequest(): Unit = { + val (partition, leader) = createTopicAndFindPartitionWithLeader("topic") + val timestamp = 1000000 + val memoryRecords = MemoryRecords.withRecords(CompressionType.LZ4, + new SimpleRecord(timestamp, "key".getBytes, "value".getBytes)) + // Change the lz4 checksum value (not the kafka record crc) 
so that it doesn't match the contents + val lz4ChecksumOffset = 6 + memoryRecords.buffer.array.update(DefaultRecordBatch.RECORD_BATCH_OVERHEAD + lz4ChecksumOffset, 0) + val topicPartition = new TopicPartition("topic", partition) + val produceResponse = sendProduceRequest(leader, ProduceRequest.forCurrentMagic(new ProduceRequestData() + .setTopicData(new ProduceRequestData.TopicProduceDataCollection(Collections.singletonList( + new ProduceRequestData.TopicProduceData() + .setName(topicPartition.topic()) + .setPartitionData(Collections.singletonList(new ProduceRequestData.PartitionProduceData() + .setIndex(topicPartition.partition()) + .setRecords(memoryRecords)))).iterator)) + .setAcks((-1).toShort) + .setTimeoutMs(3000) + .setTransactionalId(null)).build()) + + assertEquals(1, produceResponse.data.responses.size) + val topicProduceResponse = produceResponse.data.responses.asScala.head + assertEquals(1, topicProduceResponse.partitionResponses.size) + val partitionProduceResponse = topicProduceResponse.partitionResponses.asScala.head + val tp = new TopicPartition(topicProduceResponse.name, partitionProduceResponse.index) + assertEquals(topicPartition, tp) + assertEquals(Errors.CORRUPT_MESSAGE, Errors.forCode(partitionProduceResponse.errorCode)) + assertEquals(-1, partitionProduceResponse.baseOffset) + assertEquals(-1, partitionProduceResponse.logAppendTimeMs) + assertEquals(metricsKeySet.count(_.getMBeanName.endsWith(s"${BrokerTopicStats.InvalidMessageCrcRecordsPerSec}")), 1) + assertTrue(TestUtils.meterCount(s"${BrokerTopicStats.InvalidMessageCrcRecordsPerSec}") > 0) + } + + @Test + def testZSTDProduceRequest(): Unit = { + val topic = "topic" + val partition = 0 + + // Create a single-partition topic compressed with ZSTD + val topicConfig = new Properties + topicConfig.setProperty(LogConfig.CompressionTypeProp, ZStdCompressionCodec.name) + val partitionToLeader = TestUtils.createTopic(zkClient, topic, 1, 1, servers, topicConfig) + val leader = partitionToLeader(partition) + val memoryRecords = MemoryRecords.withRecords(CompressionType.ZSTD, + new SimpleRecord(System.currentTimeMillis(), "key".getBytes, "value".getBytes)) + val topicPartition = new TopicPartition("topic", partition) + val partitionRecords = new ProduceRequestData() + .setTopicData(new ProduceRequestData.TopicProduceDataCollection(Collections.singletonList( + new ProduceRequestData.TopicProduceData() + .setName("topic").setPartitionData(Collections.singletonList( + new ProduceRequestData.PartitionProduceData() + .setIndex(partition) + .setRecords(memoryRecords)))) + .iterator)) + .setAcks((-1).toShort) + .setTimeoutMs(3000) + .setTransactionalId(null) + + // produce request with v7: works fine! + val produceResponse1 = sendProduceRequest(leader, new ProduceRequest.Builder(7, 7, partitionRecords).build()) + + val topicProduceResponse1 = produceResponse1.data.responses.asScala.head + val partitionProduceResponse1 = topicProduceResponse1.partitionResponses.asScala.head + val tp1 = new TopicPartition(topicProduceResponse1.name, partitionProduceResponse1.index) + assertEquals(topicPartition, tp1) + assertEquals(Errors.NONE, Errors.forCode(partitionProduceResponse1.errorCode)) + assertEquals(0, partitionProduceResponse1.baseOffset) + assertEquals(-1, partitionProduceResponse1.logAppendTimeMs) + + // produce request with v3: returns Errors.UNSUPPORTED_COMPRESSION_TYPE. 
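+    // (ZStandard compression is only supported from produce request version 7 onwards, so the
+    // broker is expected to reject the older version; buildUnsafe bypasses the builder's own
+    // version validation so that a v3 request carrying ZSTD records can be sent at all.)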
+ val produceResponse2 = sendProduceRequest(leader, new ProduceRequest.Builder(3, 3, partitionRecords).buildUnsafe(3)) + val topicProduceResponse2 = produceResponse2.data.responses.asScala.head + val partitionProduceResponse2 = topicProduceResponse2.partitionResponses.asScala.head + val tp2 = new TopicPartition(topicProduceResponse2.name, partitionProduceResponse2.index) + assertEquals(topicPartition, tp2) + assertEquals(Errors.UNSUPPORTED_COMPRESSION_TYPE, Errors.forCode(partitionProduceResponse2.errorCode)) + } + + private def sendProduceRequest(leaderId: Int, request: ProduceRequest): ProduceResponse = { + connectAndReceive[ProduceResponse](request, destination = brokerSocketServer(leaderId)) + } + +} diff --git a/core/src/test/scala/unit/kafka/server/ReplicaAlterLogDirsThreadTest.scala b/core/src/test/scala/unit/kafka/server/ReplicaAlterLogDirsThreadTest.scala new file mode 100644 index 0000000..80d69ec --- /dev/null +++ b/core/src/test/scala/unit/kafka/server/ReplicaAlterLogDirsThreadTest.scala @@ -0,0 +1,972 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package kafka.server + +import java.util.{Collections, Optional} +import kafka.api.Request +import kafka.cluster.{BrokerEndPoint, Partition} +import kafka.log.{LogManager, UnifiedLog} +import kafka.server.AbstractFetcherThread.ResultWithPartitions +import kafka.server.QuotaFactory.UnboundedQuota +import kafka.server.metadata.ZkMetadataCache +import kafka.utils.{DelayedItem, TestUtils} +import org.apache.kafka.common.errors.KafkaStorageException +import org.apache.kafka.common.message.OffsetForLeaderEpochRequestData.OffsetForLeaderPartition +import org.apache.kafka.common.message.OffsetForLeaderEpochResponseData.EpochEndOffset +import org.apache.kafka.common.message.UpdateMetadataRequestData +import org.apache.kafka.common.protocol.{ApiKeys, Errors} +import org.apache.kafka.common.record.MemoryRecords +import org.apache.kafka.common.requests.{FetchRequest, UpdateMetadataRequest} +import org.apache.kafka.common.{IsolationLevel, TopicIdPartition, TopicPartition, Uuid} +import org.easymock.EasyMock._ +import org.easymock.{Capture, CaptureType, EasyMock, IExpectationSetters} +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.Test +import org.mockito.Mockito.{doNothing, when} +import org.mockito.{ArgumentCaptor, ArgumentMatchers, Mockito} + +import scala.collection.{Map, Seq} +import scala.jdk.CollectionConverters._ + +class ReplicaAlterLogDirsThreadTest { + + private val t1p0 = new TopicPartition("topic1", 0) + private val t1p1 = new TopicPartition("topic1", 1) + private val topicId = Uuid.randomUuid() + private val topicIds = collection.immutable.Map("topic1" -> topicId) + private val topicNames = collection.immutable.Map(topicId -> "topic1") + private val tid1p0 = new TopicIdPartition(topicId, t1p0) + private val failedPartitions = new FailedPartitions + + private val partitionStates = List(new UpdateMetadataRequestData.UpdateMetadataPartitionState() + .setTopicName("topic1") + .setPartitionIndex(0) + .setControllerEpoch(0) + .setLeader(0) + .setLeaderEpoch(0)).asJava + + private val updateMetadataRequest = new UpdateMetadataRequest.Builder(ApiKeys.UPDATE_METADATA.latestVersion(), + 0, 0, 0, partitionStates, Collections.emptyList(), topicIds.asJava).build() + // TODO: support raft code? 
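+  // The tests below run against a ZK-backed metadata cache seeded with the single
+  // UpdateMetadataRequest built above (topic1 with topic id `topicId`).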
+ private val metadataCache = new ZkMetadataCache(0) + metadataCache.updateMetadata(0, updateMetadataRequest) + + private def initialFetchState(fetchOffset: Long, leaderEpoch: Int = 1): InitialFetchState = { + InitialFetchState(topicId = Some(topicId), leader = new BrokerEndPoint(0, "localhost", 9092), + initOffset = fetchOffset, currentLeaderEpoch = leaderEpoch) + } + + @Test + def shouldNotAddPartitionIfFutureLogIsNotDefined(): Unit = { + val brokerId = 1 + val config = KafkaConfig.fromProps(TestUtils.createBrokerConfig(brokerId, "localhost:1234")) + + val replicaManager = Mockito.mock(classOf[ReplicaManager]) + val quotaManager = Mockito.mock(classOf[ReplicationQuotaManager]) + + when(replicaManager.futureLogExists(t1p0)).thenReturn(false) + + val endPoint = new BrokerEndPoint(0, "localhost", 1000) + val thread = new ReplicaAlterLogDirsThread( + "alter-logs-dirs-thread", + sourceBroker = endPoint, + brokerConfig = config, + failedPartitions = failedPartitions, + replicaMgr = replicaManager, + quota = quotaManager, + brokerTopicStats = new BrokerTopicStats) + + val addedPartitions = thread.addPartitions(Map(t1p0 -> initialFetchState(0L))) + assertEquals(Set.empty, addedPartitions) + assertEquals(0, thread.partitionCount) + assertEquals(None, thread.fetchState(t1p0)) + } + + @Test + def shouldUpdateLeaderEpochAfterFencedEpochError(): Unit = { + val brokerId = 1 + val partitionId = 0 + val config = KafkaConfig.fromProps(TestUtils.createBrokerConfig(brokerId, "localhost:1234")) + + val partition = Mockito.mock(classOf[Partition]) + val replicaManager = Mockito.mock(classOf[ReplicaManager]) + val quotaManager = Mockito.mock(classOf[ReplicationQuotaManager]) + val futureLog = Mockito.mock(classOf[UnifiedLog]) + + val leaderEpoch = 5 + val logEndOffset = 0 + + when(partition.partitionId).thenReturn(partitionId) + when(replicaManager.metadataCache).thenReturn(metadataCache) + when(replicaManager.futureLocalLogOrException(t1p0)).thenReturn(futureLog) + when(replicaManager.futureLogExists(t1p0)).thenReturn(true) + when(replicaManager.onlinePartition(t1p0)).thenReturn(Some(partition)) + when(replicaManager.getPartitionOrException(t1p0)).thenReturn(partition) + + when(quotaManager.isQuotaExceeded).thenReturn(false) + + when(partition.lastOffsetForLeaderEpoch(Optional.empty(), leaderEpoch, fetchOnlyFromLeader = false)) + .thenReturn(new EpochEndOffset() + .setPartition(partitionId) + .setErrorCode(Errors.NONE.code) + .setLeaderEpoch(leaderEpoch) + .setEndOffset(logEndOffset)) + when(partition.futureLocalLogOrException).thenReturn(futureLog) + doNothing().when(partition).truncateTo(offset = 0, isFuture = true) + when(partition.maybeReplaceCurrentWithFutureReplica()).thenReturn(true) + + when(futureLog.logStartOffset).thenReturn(0L) + when(futureLog.logEndOffset).thenReturn(0L) + when(futureLog.latestEpoch).thenReturn(None) + + val fencedRequestData = new FetchRequest.PartitionData(topicId, 0L, 0L, + config.replicaFetchMaxBytes, Optional.of(leaderEpoch - 1)) + val fencedResponseData = FetchPartitionData( + error = Errors.FENCED_LEADER_EPOCH, + highWatermark = -1, + logStartOffset = -1, + records = MemoryRecords.EMPTY, + divergingEpoch = None, + lastStableOffset = None, + abortedTransactions = None, + preferredReadReplica = None, + isReassignmentFetch = false) + mockFetchFromCurrentLog(tid1p0, fencedRequestData, config, replicaManager, fencedResponseData) + + val endPoint = new BrokerEndPoint(0, "localhost", 1000) + val thread = new ReplicaAlterLogDirsThread( + "alter-logs-dirs-thread", + sourceBroker 
= endPoint, + brokerConfig = config, + failedPartitions = failedPartitions, + replicaMgr = replicaManager, + quota = quotaManager, + brokerTopicStats = new BrokerTopicStats) + + // Initially we add the partition with an older epoch which results in an error + thread.addPartitions(Map(t1p0 -> initialFetchState(fetchOffset = 0L, leaderEpoch - 1))) + assertTrue(thread.fetchState(t1p0).isDefined) + assertEquals(1, thread.partitionCount) + + thread.doWork() + + assertTrue(failedPartitions.contains(t1p0)) + assertEquals(None, thread.fetchState(t1p0)) + assertEquals(0, thread.partitionCount) + + // Next we update the epoch and assert that we can continue + thread.addPartitions(Map(t1p0 -> initialFetchState(fetchOffset = 0L, leaderEpoch))) + assertEquals(Some(leaderEpoch), thread.fetchState(t1p0).map(_.currentLeaderEpoch)) + assertEquals(1, thread.partitionCount) + + val requestData = new FetchRequest.PartitionData(topicId, 0L, 0L, + config.replicaFetchMaxBytes, Optional.of(leaderEpoch)) + val responseData = FetchPartitionData( + error = Errors.NONE, + highWatermark = 0L, + logStartOffset = 0L, + records = MemoryRecords.EMPTY, + divergingEpoch = None, + lastStableOffset = None, + abortedTransactions = None, + preferredReadReplica = None, + isReassignmentFetch = false) + mockFetchFromCurrentLog(tid1p0, requestData, config, replicaManager, responseData) + + thread.doWork() + + assertFalse(failedPartitions.contains(t1p0)) + assertEquals(None, thread.fetchState(t1p0)) + assertEquals(0, thread.partitionCount) + } + + @Test + def shouldReplaceCurrentLogDirWhenCaughtUp(): Unit = { + val brokerId = 1 + val partitionId = 0 + val config = KafkaConfig.fromProps(TestUtils.createBrokerConfig(brokerId, "localhost:1234")) + + val partition = Mockito.mock(classOf[Partition]) + val replicaManager = Mockito.mock(classOf[ReplicaManager]) + val quotaManager = Mockito.mock(classOf[ReplicationQuotaManager]) + val futureLog = Mockito.mock(classOf[UnifiedLog]) + + val leaderEpoch = 5 + val logEndOffset = 0 + + when(partition.partitionId).thenReturn(partitionId) + when(replicaManager.metadataCache).thenReturn(metadataCache) + when(replicaManager.futureLocalLogOrException(t1p0)).thenReturn(futureLog) + when(replicaManager.futureLogExists(t1p0)).thenReturn(true) + when(replicaManager.onlinePartition(t1p0)).thenReturn(Some(partition)) + when(replicaManager.getPartitionOrException(t1p0)).thenReturn(partition) + + when(quotaManager.isQuotaExceeded).thenReturn(false) + + when(partition.lastOffsetForLeaderEpoch(Optional.empty(), leaderEpoch, fetchOnlyFromLeader = false)) + .thenReturn(new EpochEndOffset() + .setPartition(partitionId) + .setErrorCode(Errors.NONE.code) + .setLeaderEpoch(leaderEpoch) + .setEndOffset(logEndOffset)) + when(partition.futureLocalLogOrException).thenReturn(futureLog) + doNothing().when(partition).truncateTo(offset = 0, isFuture = true) + when(partition.maybeReplaceCurrentWithFutureReplica()).thenReturn(true) + + when(futureLog.logStartOffset).thenReturn(0L) + when(futureLog.logEndOffset).thenReturn(0L) + when(futureLog.latestEpoch).thenReturn(None) + + val requestData = new FetchRequest.PartitionData(topicId, 0L, 0L, + config.replicaFetchMaxBytes, Optional.of(leaderEpoch)) + val responseData = FetchPartitionData( + error = Errors.NONE, + highWatermark = 0L, + logStartOffset = 0L, + records = MemoryRecords.EMPTY, + divergingEpoch = None, + lastStableOffset = None, + abortedTransactions = None, + preferredReadReplica = None, + isReassignmentFetch = false) + mockFetchFromCurrentLog(tid1p0, requestData, 
config, replicaManager, responseData) + + val endPoint = new BrokerEndPoint(0, "localhost", 1000) + val thread = new ReplicaAlterLogDirsThread( + "alter-logs-dirs-thread", + sourceBroker = endPoint, + brokerConfig = config, + failedPartitions = failedPartitions, + replicaMgr = replicaManager, + quota = quotaManager, + brokerTopicStats = new BrokerTopicStats) + + thread.addPartitions(Map(t1p0 -> initialFetchState(fetchOffset = 0L, leaderEpoch))) + assertTrue(thread.fetchState(t1p0).isDefined) + assertEquals(1, thread.partitionCount) + + thread.doWork() + + assertEquals(None, thread.fetchState(t1p0)) + assertEquals(0, thread.partitionCount) + } + + private def mockFetchFromCurrentLog(topicIdPartition: TopicIdPartition, + requestData: FetchRequest.PartitionData, + config: KafkaConfig, + replicaManager: ReplicaManager, + responseData: FetchPartitionData): Unit = { + val callbackCaptor: ArgumentCaptor[Seq[(TopicIdPartition, FetchPartitionData)] => Unit] = + ArgumentCaptor.forClass(classOf[Seq[(TopicIdPartition, FetchPartitionData)] => Unit]) + when(replicaManager.fetchMessages( + timeout = ArgumentMatchers.eq(0L), + replicaId = ArgumentMatchers.eq(Request.FutureLocalReplicaId), + fetchMinBytes = ArgumentMatchers.eq(0), + fetchMaxBytes = ArgumentMatchers.eq(config.replicaFetchResponseMaxBytes), + hardMaxBytesLimit = ArgumentMatchers.eq(false), + fetchInfos = ArgumentMatchers.eq(Seq(topicIdPartition -> requestData)), + quota = ArgumentMatchers.eq(UnboundedQuota), + responseCallback = callbackCaptor.capture(), + isolationLevel = ArgumentMatchers.eq(IsolationLevel.READ_UNCOMMITTED), + clientMetadata = ArgumentMatchers.eq(None) + )).thenAnswer(_ => { + callbackCaptor.getValue.apply(Seq((topicIdPartition, responseData))) + }) + } + + @Test + def issuesEpochRequestFromLocalReplica(): Unit = { + val config = KafkaConfig.fromProps(TestUtils.createBrokerConfig(1, "localhost:1234")) + + //Setup all dependencies + + val partitionT1p0: Partition = createMock(classOf[Partition]) + val partitionT1p1: Partition = createMock(classOf[Partition]) + val replicaManager: ReplicaManager = createMock(classOf[ReplicaManager]) + + val partitionT1p0Id = 0 + val partitionT1p1Id = 1 + val leaderEpochT1p0 = 2 + val leaderEpochT1p1 = 5 + val leoT1p0 = 13 + val leoT1p1 = 232 + + //Stubs + expect(partitionT1p0.partitionId).andStubReturn(partitionT1p0Id) + expect(partitionT1p0.partitionId).andStubReturn(partitionT1p1Id) + + expect(replicaManager.getPartitionOrException(t1p0)) + .andStubReturn(partitionT1p0) + expect(partitionT1p0.lastOffsetForLeaderEpoch(Optional.empty(), leaderEpochT1p0, fetchOnlyFromLeader = false)) + .andReturn(new EpochEndOffset() + .setPartition(partitionT1p0Id) + .setErrorCode(Errors.NONE.code) + .setLeaderEpoch(leaderEpochT1p0) + .setEndOffset(leoT1p0)) + .anyTimes() + + expect(replicaManager.getPartitionOrException(t1p1)) + .andStubReturn(partitionT1p1) + expect(partitionT1p1.lastOffsetForLeaderEpoch(Optional.empty(), leaderEpochT1p1, fetchOnlyFromLeader = false)) + .andReturn(new EpochEndOffset() + .setPartition(partitionT1p1Id) + .setErrorCode(Errors.NONE.code) + .setLeaderEpoch(leaderEpochT1p1) + .setEndOffset(leoT1p1)) + .anyTimes() + + replay(partitionT1p0, partitionT1p1, replicaManager) + + val endPoint = new BrokerEndPoint(0, "localhost", 1000) + val thread = new ReplicaAlterLogDirsThread( + "alter-logs-dirs-thread-test1", + sourceBroker = endPoint, + brokerConfig = config, + failedPartitions = failedPartitions, + replicaMgr = replicaManager, + quota = null, + brokerTopicStats = null) + + val 
result = thread.fetchEpochEndOffsets(Map( + t1p0 -> new OffsetForLeaderPartition() + .setPartition(t1p0.partition) + .setLeaderEpoch(leaderEpochT1p0), + t1p1 -> new OffsetForLeaderPartition() + .setPartition(t1p1.partition) + .setLeaderEpoch(leaderEpochT1p1))) + + val expected = Map( + t1p0 -> new EpochEndOffset() + .setPartition(t1p0.partition) + .setErrorCode(Errors.NONE.code) + .setLeaderEpoch(leaderEpochT1p0) + .setEndOffset(leoT1p0), + t1p1 -> new EpochEndOffset() + .setPartition(t1p1.partition) + .setErrorCode(Errors.NONE.code) + .setLeaderEpoch(leaderEpochT1p1) + .setEndOffset(leoT1p1) + ) + + assertEquals(expected, result, "results from leader epoch request should have offset from local replica") + } + + @Test + def fetchEpochsFromLeaderShouldHandleExceptionFromGetLocalReplica(): Unit = { + val config = KafkaConfig.fromProps(TestUtils.createBrokerConfig(1, "localhost:1234")) + + //Setup all dependencies + val partitionT1p0: Partition = createMock(classOf[Partition]) + val replicaManager: ReplicaManager = createMock(classOf[ReplicaManager]) + + val partitionId = 0 + val leaderEpoch = 2 + val leo = 13 + + //Stubs + expect(partitionT1p0.partitionId).andStubReturn(partitionId) + + expect(replicaManager.getPartitionOrException(t1p0)) + .andStubReturn(partitionT1p0) + expect(partitionT1p0.lastOffsetForLeaderEpoch(Optional.empty(), leaderEpoch, fetchOnlyFromLeader = false)) + .andReturn(new EpochEndOffset() + .setPartition(partitionId) + .setErrorCode(Errors.NONE.code) + .setLeaderEpoch(leaderEpoch) + .setEndOffset(leo)) + .anyTimes() + + expect(replicaManager.getPartitionOrException(t1p1)) + .andThrow(new KafkaStorageException).once() + + replay(partitionT1p0, replicaManager) + + val endPoint = new BrokerEndPoint(0, "localhost", 1000) + val thread = new ReplicaAlterLogDirsThread( + "alter-logs-dirs-thread-test1", + sourceBroker = endPoint, + brokerConfig = config, + failedPartitions = failedPartitions, + replicaMgr = replicaManager, + quota = null, + brokerTopicStats = null) + + val result = thread.fetchEpochEndOffsets(Map( + t1p0 -> new OffsetForLeaderPartition() + .setPartition(t1p0.partition) + .setLeaderEpoch(leaderEpoch), + t1p1 -> new OffsetForLeaderPartition() + .setPartition(t1p1.partition) + .setLeaderEpoch(leaderEpoch))) + + val expected = Map( + t1p0 -> new EpochEndOffset() + .setPartition(t1p0.partition) + .setErrorCode(Errors.NONE.code) + .setLeaderEpoch(leaderEpoch) + .setEndOffset(leo), + t1p1 -> new EpochEndOffset() + .setPartition(t1p1.partition) + .setErrorCode(Errors.KAFKA_STORAGE_ERROR.code) + ) + + assertEquals(expected, result) + } + + @Test + def shouldTruncateToReplicaOffset(): Unit = { + + //Create a capture to track what partitions/offsets are truncated + val truncateCaptureT1p0: Capture[Long] = newCapture(CaptureType.ALL) + val truncateCaptureT1p1: Capture[Long] = newCapture(CaptureType.ALL) + + // Setup all the dependencies + val config = KafkaConfig.fromProps(TestUtils.createBrokerConfig(1, "localhost:1234")) + val quotaManager: ReplicationQuotaManager = createNiceMock(classOf[ReplicationQuotaManager]) + val logManager: LogManager = createMock(classOf[LogManager]) + val logT1p0: UnifiedLog = createNiceMock(classOf[UnifiedLog]) + val logT1p1: UnifiedLog = createNiceMock(classOf[UnifiedLog]) + // one future replica mock because our mocking methods return same values for both future replicas + val futureLogT1p0: UnifiedLog = createNiceMock(classOf[UnifiedLog]) + val futureLogT1p1: UnifiedLog = createNiceMock(classOf[UnifiedLog]) + val partitionT1p0: Partition = 
createMock(classOf[Partition]) + val partitionT1p1: Partition = createMock(classOf[Partition]) + val replicaManager: ReplicaManager = createMock(classOf[ReplicaManager]) + val responseCallback: Capture[Seq[(TopicIdPartition, FetchPartitionData)] => Unit] = EasyMock.newCapture() + + val partitionT1p0Id = 0 + val partitionT1p1Id = 1 + val leaderEpoch = 2 + val futureReplicaLEO = 191 + val replicaT1p0LEO = 190 + val replicaT1p1LEO = 192 + + //Stubs + expect(partitionT1p0.partitionId).andStubReturn(partitionT1p0Id) + expect(partitionT1p1.partitionId).andStubReturn(partitionT1p1Id) + + expect(replicaManager.metadataCache).andStubReturn(metadataCache) + expect(replicaManager.getPartitionOrException(t1p0)) + .andStubReturn(partitionT1p0) + expect(replicaManager.getPartitionOrException(t1p1)) + .andStubReturn(partitionT1p1) + expect(replicaManager.futureLocalLogOrException(t1p0)).andStubReturn(futureLogT1p0) + expect(replicaManager.futureLogExists(t1p0)).andStubReturn(true) + expect(replicaManager.futureLocalLogOrException(t1p1)).andStubReturn(futureLogT1p1) + expect(replicaManager.futureLogExists(t1p1)).andStubReturn(true) + expect(partitionT1p0.truncateTo(capture(truncateCaptureT1p0), anyBoolean())).anyTimes() + expect(partitionT1p1.truncateTo(capture(truncateCaptureT1p1), anyBoolean())).anyTimes() + + expect(futureLogT1p0.logEndOffset).andReturn(futureReplicaLEO).anyTimes() + expect(futureLogT1p1.logEndOffset).andReturn(futureReplicaLEO).anyTimes() + + expect(futureLogT1p0.latestEpoch).andReturn(Some(leaderEpoch)).anyTimes() + expect(futureLogT1p0.endOffsetForEpoch(leaderEpoch)).andReturn( + Some(OffsetAndEpoch(futureReplicaLEO, leaderEpoch))).anyTimes() + expect(partitionT1p0.lastOffsetForLeaderEpoch(Optional.of(1), leaderEpoch, fetchOnlyFromLeader = false)) + .andReturn(new EpochEndOffset() + .setPartition(partitionT1p0Id) + .setErrorCode(Errors.NONE.code) + .setLeaderEpoch(leaderEpoch) + .setEndOffset(replicaT1p0LEO)) + .anyTimes() + + expect(futureLogT1p1.latestEpoch).andReturn(Some(leaderEpoch)).anyTimes() + expect(futureLogT1p1.endOffsetForEpoch(leaderEpoch)).andReturn( + Some(OffsetAndEpoch(futureReplicaLEO, leaderEpoch))).anyTimes() + expect(partitionT1p1.lastOffsetForLeaderEpoch(Optional.of(1), leaderEpoch, fetchOnlyFromLeader = false)) + .andReturn(new EpochEndOffset() + .setPartition(partitionT1p1Id) + .setErrorCode(Errors.NONE.code) + .setLeaderEpoch(leaderEpoch) + .setEndOffset(replicaT1p1LEO)) + .anyTimes() + + expect(replicaManager.logManager).andReturn(logManager).anyTimes() + stubWithFetchMessages(logT1p0, logT1p1, futureLogT1p0, partitionT1p0, replicaManager, responseCallback) + + replay(replicaManager, logManager, quotaManager, partitionT1p0, partitionT1p1, logT1p0, logT1p1, futureLogT1p0, futureLogT1p1) + + //Create the thread + val endPoint = new BrokerEndPoint(0, "localhost", 1000) + val thread = new ReplicaAlterLogDirsThread( + "alter-logs-dirs-thread-test1", + sourceBroker = endPoint, + brokerConfig = config, + failedPartitions = failedPartitions, + replicaMgr = replicaManager, + quota = quotaManager, + brokerTopicStats = null) + thread.addPartitions(Map(t1p0 -> initialFetchState(0L), t1p1 -> initialFetchState(0L))) + + //Run it + thread.doWork() + + //We should have truncated to the offsets in the response + assertEquals(replicaT1p0LEO, truncateCaptureT1p0.getValue) + assertEquals(futureReplicaLEO, truncateCaptureT1p1.getValue) + } + + @Test + def shouldTruncateToEndOffsetOfLargestCommonEpoch(): Unit = { + + //Create a capture to track what partitions/offsets are 
truncated + val truncateToCapture: Capture[Long] = newCapture(CaptureType.ALL) + + // Setup all the dependencies + val config = KafkaConfig.fromProps(TestUtils.createBrokerConfig(1, "localhost:1234")) + val quotaManager: ReplicationQuotaManager = createNiceMock(classOf[ReplicationQuotaManager]) + val logManager: LogManager = createMock(classOf[LogManager]) + val log: UnifiedLog = createNiceMock(classOf[UnifiedLog]) + // one future replica mock because our mocking methods return same values for both future replicas + val futureLog: UnifiedLog = createNiceMock(classOf[UnifiedLog]) + val partition: Partition = createMock(classOf[Partition]) + val replicaManager: ReplicaManager = createMock(classOf[ReplicaManager]) + val responseCallback: Capture[Seq[(TopicIdPartition, FetchPartitionData)] => Unit] = EasyMock.newCapture() + + val partitionId = 0 + val leaderEpoch = 5 + val futureReplicaLEO = 195 + val replicaLEO = 200 + val replicaEpochEndOffset = 190 + val futureReplicaEpochEndOffset = 191 + + //Stubs + expect(partition.partitionId).andStubReturn(partitionId) + + expect(replicaManager.metadataCache).andStubReturn(metadataCache) + expect(replicaManager.getPartitionOrException(t1p0)) + .andStubReturn(partition) + expect(replicaManager.futureLocalLogOrException(t1p0)).andStubReturn(futureLog) + expect(replicaManager.futureLogExists(t1p0)).andStubReturn(true) + + expect(partition.truncateTo(capture(truncateToCapture), EasyMock.eq(true))).anyTimes() + expect(futureLog.logEndOffset).andReturn(futureReplicaLEO).anyTimes() + expect(futureLog.latestEpoch).andReturn(Some(leaderEpoch)).once() + expect(futureLog.latestEpoch).andReturn(Some(leaderEpoch - 2)).times(3) + + // leader replica truncated and fetched new offsets with new leader epoch + expect(partition.lastOffsetForLeaderEpoch(Optional.of(1), leaderEpoch, fetchOnlyFromLeader = false)) + .andReturn(new EpochEndOffset() + .setPartition(partitionId) + .setErrorCode(Errors.NONE.code) + .setLeaderEpoch(leaderEpoch - 1) + .setEndOffset(replicaLEO)) + .anyTimes() + // but future replica does not know about this leader epoch, so returns a smaller leader epoch + expect(futureLog.endOffsetForEpoch(leaderEpoch - 1)).andReturn( + Some(OffsetAndEpoch(futureReplicaLEO, leaderEpoch - 2))).anyTimes() + // finally, the leader replica knows about the leader epoch and returns end offset + expect(partition.lastOffsetForLeaderEpoch(Optional.of(1), leaderEpoch - 2, fetchOnlyFromLeader = false)) + .andReturn(new EpochEndOffset() + .setPartition(partitionId) + .setErrorCode(Errors.NONE.code) + .setLeaderEpoch(leaderEpoch - 2) + .setEndOffset(replicaEpochEndOffset)) + .anyTimes() + expect(futureLog.endOffsetForEpoch(leaderEpoch - 2)).andReturn( + Some(OffsetAndEpoch(futureReplicaEpochEndOffset, leaderEpoch - 2))).anyTimes() + + expect(replicaManager.logManager).andReturn(logManager).anyTimes() + stubWithFetchMessages(log, null, futureLog, partition, replicaManager, responseCallback) + + replay(replicaManager, logManager, quotaManager, partition, log, futureLog) + + //Create the thread + val endPoint = new BrokerEndPoint(0, "localhost", 1000) + val thread = new ReplicaAlterLogDirsThread( + "alter-logs-dirs-thread-test1", + sourceBroker = endPoint, + brokerConfig = config, + failedPartitions = failedPartitions, + replicaMgr = replicaManager, + quota = quotaManager, + brokerTopicStats = null) + thread.addPartitions(Map(t1p0 -> initialFetchState(0L))) + + // First run will result in another offset for leader epoch request + thread.doWork() + // Second run should actually 
truncate + thread.doWork() + + //We should have truncated to the offsets in the response + assertTrue(truncateToCapture.getValues.asScala.contains(replicaEpochEndOffset), + "Expected offset " + replicaEpochEndOffset + " in captured truncation offsets " + truncateToCapture.getValues) + } + + @Test + def shouldTruncateToInitialFetchOffsetIfReplicaReturnsUndefinedOffset(): Unit = { + + //Create a capture to track what partitions/offsets are truncated + val truncated: Capture[Long] = newCapture(CaptureType.ALL) + + // Setup all the dependencies + val config = KafkaConfig.fromProps(TestUtils.createBrokerConfig(1, "localhost:1234")) + val quotaManager: ReplicationQuotaManager = createNiceMock(classOf[ReplicationQuotaManager]) + val logManager: LogManager = createMock(classOf[LogManager]) + val log: UnifiedLog = createNiceMock(classOf[UnifiedLog]) + val futureLog: UnifiedLog = createNiceMock(classOf[UnifiedLog]) + val partition: Partition = createMock(classOf[Partition]) + val replicaManager: ReplicaManager = createMock(classOf[ReplicaManager]) + val responseCallback: Capture[Seq[(TopicIdPartition, FetchPartitionData)] => Unit] = EasyMock.newCapture() + + val initialFetchOffset = 100 + + //Stubs + expect(replicaManager.getPartitionOrException(t1p0)) + .andStubReturn(partition) + expect(partition.truncateTo(capture(truncated), isFuture = EasyMock.eq(true))).anyTimes() + expect(replicaManager.metadataCache).andStubReturn(metadataCache) + expect(replicaManager.futureLocalLogOrException(t1p0)).andStubReturn(futureLog) + expect(replicaManager.futureLogExists(t1p0)).andStubReturn(true) + + expect(replicaManager.logManager).andReturn(logManager).anyTimes() + + // pretend this is a completely new future replica, with no leader epochs recorded + expect(futureLog.latestEpoch).andReturn(None).anyTimes() + + stubWithFetchMessages(log, null, futureLog, partition, replicaManager, responseCallback) + replay(replicaManager, logManager, quotaManager, partition, log, futureLog) + + //Create the thread + val endPoint = new BrokerEndPoint(0, "localhost", 1000) + val thread = new ReplicaAlterLogDirsThread( + "alter-logs-dirs-thread-test1", + sourceBroker = endPoint, + brokerConfig = config, + failedPartitions = failedPartitions, + replicaMgr = replicaManager, + quota = quotaManager, + brokerTopicStats = null) + thread.addPartitions(Map(t1p0 -> initialFetchState(initialFetchOffset))) + + //Run it + thread.doWork() + + //We should have truncated to initial fetch offset + assertEquals(initialFetchOffset, + truncated.getValue, "Expected future replica to truncate to initial fetch offset if replica returns UNDEFINED_EPOCH_OFFSET") + } + + @Test + def shouldPollIndefinitelyIfReplicaNotAvailable(): Unit = { + + //Create a capture to track what partitions/offsets are truncated + val truncated: Capture[Long] = newCapture(CaptureType.ALL) + + // Setup all the dependencies + val config = KafkaConfig.fromProps(TestUtils.createBrokerConfig(1, "localhost:1234")) + val quotaManager: ReplicationQuotaManager = createNiceMock(classOf[ReplicationQuotaManager]) + val logManager: LogManager = createMock(classOf[LogManager]) + val log: UnifiedLog = createNiceMock(classOf[UnifiedLog]) + val futureLog: UnifiedLog = createNiceMock(classOf[UnifiedLog]) + val partition: Partition = createMock(classOf[Partition]) + val replicaManager: ReplicaManager = createMock(classOf[ReplicaManager]) + val responseCallback: Capture[Seq[(TopicIdPartition, FetchPartitionData)] => Unit] = EasyMock.newCapture() + + val partitionId = 0 + val 
futureReplicaLeaderEpoch = 1 + val futureReplicaLEO = 290 + val replicaLEO = 300 + + //Stubs + expect(partition.partitionId).andStubReturn(partitionId) + + expect(replicaManager.metadataCache).andStubReturn(metadataCache) + expect(replicaManager.getPartitionOrException(t1p0)) + .andStubReturn(partition) + expect(partition.truncateTo(capture(truncated), isFuture = EasyMock.eq(true))).once() + + expect(replicaManager.futureLocalLogOrException(t1p0)).andStubReturn(futureLog) + expect(replicaManager.futureLogExists(t1p0)).andStubReturn(true) + expect(futureLog.logEndOffset).andReturn(futureReplicaLEO).anyTimes() + expect(futureLog.latestEpoch).andStubReturn(Some(futureReplicaLeaderEpoch)) + expect(futureLog.endOffsetForEpoch(futureReplicaLeaderEpoch)).andReturn( + Some(OffsetAndEpoch(futureReplicaLEO, futureReplicaLeaderEpoch))) + expect(replicaManager.localLog(t1p0)).andReturn(Some(log)).anyTimes() + + // this will cause fetchEpochsFromLeader return an error with undefined offset + expect(partition.lastOffsetForLeaderEpoch(Optional.of(1), futureReplicaLeaderEpoch, fetchOnlyFromLeader = false)) + .andReturn(new EpochEndOffset() + .setPartition(partitionId) + .setErrorCode(Errors.REPLICA_NOT_AVAILABLE.code)) + .times(3) + .andReturn(new EpochEndOffset() + .setPartition(partitionId) + .setErrorCode(Errors.NONE.code) + .setLeaderEpoch(futureReplicaLeaderEpoch) + .setEndOffset(replicaLEO)) + + expect(replicaManager.logManager).andReturn(logManager).anyTimes() + expect(replicaManager.fetchMessages( + EasyMock.anyLong(), + EasyMock.anyInt(), + EasyMock.anyInt(), + EasyMock.anyInt(), + EasyMock.anyObject(), + EasyMock.anyObject(), + EasyMock.anyObject(), + EasyMock.capture(responseCallback), + EasyMock.anyObject(), + EasyMock.anyObject()) + ).andAnswer(() => responseCallback.getValue.apply(Seq.empty[(TopicIdPartition, FetchPartitionData)])).anyTimes() + + replay(replicaManager, logManager, quotaManager, partition, log, futureLog) + + //Create the thread + val endPoint = new BrokerEndPoint(0, "localhost", 1000) + val thread = new ReplicaAlterLogDirsThread( + "alter-logs-dirs-thread-test1", + sourceBroker = endPoint, + brokerConfig = config, + failedPartitions = failedPartitions, + replicaMgr = replicaManager, + quota = quotaManager, + brokerTopicStats = null) + thread.addPartitions(Map(t1p0 -> initialFetchState(0L))) + + // Run thread 3 times (exactly number of times we mock exception for getReplicaOrException) + (0 to 2).foreach { _ => + thread.doWork() + } + + // Nothing happened since the replica was not available + assertEquals(0, truncated.getValues.size()) + + // Next time we loop, getReplicaOrException will return replica + thread.doWork() + + // Now the final call should have actually done a truncation (to offset futureReplicaLEO) + assertEquals(futureReplicaLEO, truncated.getValue) + } + + @Test + def shouldFetchLeaderEpochOnFirstFetchOnly(): Unit = { + + //Setup all dependencies + val config = KafkaConfig.fromProps(TestUtils.createBrokerConfig(1, "localhost:1234")) + val quotaManager: ReplicationQuotaManager = createNiceMock(classOf[ReplicationQuotaManager]) + val logManager: LogManager = createMock(classOf[LogManager]) + val log: UnifiedLog = createNiceMock(classOf[UnifiedLog]) + val futureLog: UnifiedLog = createNiceMock(classOf[UnifiedLog]) + val partition: Partition = createMock(classOf[Partition]) + val replicaManager: ReplicaManager = createMock(classOf[ReplicaManager]) + val responseCallback: Capture[Seq[(TopicIdPartition, FetchPartitionData)] => Unit] = EasyMock.newCapture() + + val 
partitionId = 0 + val leaderEpoch = 5 + val futureReplicaLEO = 190 + val replicaLEO = 213 + + expect(partition.partitionId).andStubReturn(partitionId) + + expect(replicaManager.metadataCache).andStubReturn(metadataCache) + expect(replicaManager.getPartitionOrException(t1p0)) + .andStubReturn(partition) + expect(partition.lastOffsetForLeaderEpoch(Optional.of(1), leaderEpoch, fetchOnlyFromLeader = false)) + .andReturn(new EpochEndOffset() + .setPartition(partitionId) + .setErrorCode(Errors.NONE.code) + .setLeaderEpoch(leaderEpoch) + .setEndOffset(replicaLEO)) + expect(partition.truncateTo(futureReplicaLEO, isFuture = true)).once() + + expect(replicaManager.futureLocalLogOrException(t1p0)).andStubReturn(futureLog) + expect(replicaManager.futureLogExists(t1p0)).andStubReturn(true) + expect(futureLog.latestEpoch).andStubReturn(Some(leaderEpoch)) + expect(futureLog.logEndOffset).andStubReturn(futureReplicaLEO) + expect(futureLog.endOffsetForEpoch(leaderEpoch)).andReturn( + Some(OffsetAndEpoch(futureReplicaLEO, leaderEpoch))) + expect(replicaManager.logManager).andReturn(logManager).anyTimes() + stubWithFetchMessages(log, null, futureLog, partition, replicaManager, responseCallback) + + replay(replicaManager, logManager, quotaManager, partition, log, futureLog) + + //Create the fetcher thread + val endPoint = new BrokerEndPoint(0, "localhost", 1000) + val thread = new ReplicaAlterLogDirsThread( + "alter-logs-dirs-thread-test1", + sourceBroker = endPoint, + brokerConfig = config, + failedPartitions = failedPartitions, + replicaMgr = replicaManager, + quota = quotaManager, + brokerTopicStats = null) + thread.addPartitions(Map(t1p0 -> initialFetchState(0L))) + + // loop few times + (0 to 3).foreach { _ => + thread.doWork() + } + + //Assert that truncate to is called exactly once (despite more loops) + verify(partition) + } + + @Test + def shouldFetchOneReplicaAtATime(): Unit = { + + //Setup all dependencies + val config = KafkaConfig.fromProps(TestUtils.createBrokerConfig(1, "localhost:1234")) + val quotaManager: ReplicationQuotaManager = createNiceMock(classOf[ReplicationQuotaManager]) + val logManager: LogManager = createMock(classOf[LogManager]) + val log: UnifiedLog = createNiceMock(classOf[UnifiedLog]) + val futureLog: UnifiedLog = createNiceMock(classOf[UnifiedLog]) + val partition: Partition = createMock(classOf[Partition]) + val replicaManager: ReplicaManager = createMock(classOf[ReplicaManager]) + + //Stubs + expect(replicaManager.logManager).andReturn(logManager).anyTimes() + expect(replicaManager.metadataCache).andStubReturn(metadataCache) + stub(log, null, futureLog, partition, replicaManager) + + replay(replicaManager, logManager, quotaManager, partition, log) + + //Create the fetcher thread + val endPoint = new BrokerEndPoint(0, "localhost", 1000) + val leaderEpoch = 1 + val thread = new ReplicaAlterLogDirsThread( + "alter-logs-dirs-thread-test1", + sourceBroker = endPoint, + brokerConfig = config, + failedPartitions = failedPartitions, + replicaMgr = replicaManager, + quota = quotaManager, + brokerTopicStats = null) + thread.addPartitions(Map( + t1p0 -> initialFetchState(0L, leaderEpoch), + t1p1 -> initialFetchState(0L, leaderEpoch))) + + val ResultWithPartitions(fetchRequestOpt, partitionsWithError) = thread.buildFetch(Map( + t1p0 -> PartitionFetchState(Some(topicId), 150, None, leaderEpoch, None, state = Fetching, lastFetchedEpoch = None), + t1p1 -> PartitionFetchState(Some(topicId), 160, None, leaderEpoch, None, state = Fetching, lastFetchedEpoch = None))) + + 
assertTrue(fetchRequestOpt.isDefined) + val fetchRequest = fetchRequestOpt.get.fetchRequest + assertFalse(fetchRequest.fetchData.isEmpty) + assertFalse(partitionsWithError.nonEmpty) + val request = fetchRequest.build() + assertEquals(0, request.minBytes) + val fetchInfos = request.fetchData(topicNames.asJava).asScala.toSeq + assertEquals(1, fetchInfos.length) + assertEquals(t1p0, fetchInfos.head._1.topicPartition, "Expected fetch request for first partition") + assertEquals(150, fetchInfos.head._2.fetchOffset) + } + + @Test + def shouldFetchNonDelayedAndNonTruncatingReplicas(): Unit = { + + //Setup all dependencies + val config = KafkaConfig.fromProps(TestUtils.createBrokerConfig(1, "localhost:1234")) + val quotaManager: ReplicationQuotaManager = createNiceMock(classOf[ReplicationQuotaManager]) + val logManager: LogManager = createMock(classOf[LogManager]) + val log: UnifiedLog = createNiceMock(classOf[UnifiedLog]) + val futureLog: UnifiedLog = createNiceMock(classOf[UnifiedLog]) + val partition: Partition = createMock(classOf[Partition]) + val replicaManager: ReplicaManager = createMock(classOf[ReplicaManager]) + + //Stubs + val startOffset = 123 + expect(futureLog.logStartOffset).andReturn(startOffset).anyTimes() + expect(replicaManager.logManager).andReturn(logManager).anyTimes() + expect(replicaManager.metadataCache).andStubReturn(metadataCache) + stub(log, null, futureLog, partition, replicaManager) + + replay(replicaManager, logManager, quotaManager, partition, log, futureLog) + + //Create the fetcher thread + val endPoint = new BrokerEndPoint(0, "localhost", 1000) + val leaderEpoch = 1 + val thread = new ReplicaAlterLogDirsThread( + "alter-logs-dirs-thread-test1", + sourceBroker = endPoint, + brokerConfig = config, + failedPartitions = failedPartitions, + replicaMgr = replicaManager, + quota = quotaManager, + brokerTopicStats = null) + thread.addPartitions(Map( + t1p0 -> initialFetchState(0L, leaderEpoch), + t1p1 -> initialFetchState(0L, leaderEpoch))) + + // one partition is ready and one is truncating + val ResultWithPartitions(fetchRequestOpt, partitionsWithError) = thread.buildFetch(Map( + t1p0 -> PartitionFetchState(Some(topicId), 150, None, leaderEpoch, state = Fetching, lastFetchedEpoch = None), + t1p1 -> PartitionFetchState(Some(topicId), 160, None, leaderEpoch, state = Truncating, lastFetchedEpoch = None))) + + assertTrue(fetchRequestOpt.isDefined) + val fetchRequest = fetchRequestOpt.get + assertFalse(fetchRequest.partitionData.isEmpty) + assertFalse(partitionsWithError.nonEmpty) + val fetchInfos = fetchRequest.fetchRequest.build().fetchData(topicNames.asJava).asScala.toSeq + assertEquals(1, fetchInfos.length) + assertEquals(t1p0, fetchInfos.head._1.topicPartition, "Expected fetch request for non-truncating partition") + assertEquals(150, fetchInfos.head._2.fetchOffset) + + // one partition is ready and one is delayed + val ResultWithPartitions(fetchRequest2Opt, partitionsWithError2) = thread.buildFetch(Map( + t1p0 -> PartitionFetchState(Some(topicId), 140, None, leaderEpoch, state = Fetching, lastFetchedEpoch = None), + t1p1 -> PartitionFetchState(Some(topicId), 160, None, leaderEpoch, delay = Some(new DelayedItem(5000)), state = Fetching, lastFetchedEpoch = None))) + + assertTrue(fetchRequest2Opt.isDefined) + val fetchRequest2 = fetchRequest2Opt.get + assertFalse(fetchRequest2.partitionData.isEmpty) + assertFalse(partitionsWithError2.nonEmpty) + val fetchInfos2 = fetchRequest2.fetchRequest.build().fetchData(topicNames.asJava).asScala.toSeq + assertEquals(1, 
fetchInfos2.length) + assertEquals(t1p0, fetchInfos2.head._1.topicPartition, "Expected fetch request for non-delayed partition") + assertEquals(140, fetchInfos2.head._2.fetchOffset) + + // both partitions are delayed + val ResultWithPartitions(fetchRequest3Opt, partitionsWithError3) = thread.buildFetch(Map( + t1p0 -> PartitionFetchState(Some(topicId), 140, None, leaderEpoch, delay = Some(new DelayedItem(5000)), state = Fetching, lastFetchedEpoch = None), + t1p1 -> PartitionFetchState(Some(topicId), 160, None, leaderEpoch, delay = Some(new DelayedItem(5000)), state = Fetching, lastFetchedEpoch = None))) + assertTrue(fetchRequest3Opt.isEmpty, "Expected no fetch requests since all partitions are delayed") + assertFalse(partitionsWithError3.nonEmpty) + } + + def stub(logT1p0: UnifiedLog, logT1p1: UnifiedLog, futureLog: UnifiedLog, partition: Partition, + replicaManager: ReplicaManager): IExpectationSetters[Option[Partition]] = { + expect(replicaManager.localLog(t1p0)).andReturn(Some(logT1p0)).anyTimes() + expect(replicaManager.localLogOrException(t1p0)).andReturn(logT1p0).anyTimes() + expect(replicaManager.futureLocalLogOrException(t1p0)).andReturn(futureLog).anyTimes() + expect(replicaManager.futureLogExists(t1p0)).andStubReturn(true) + expect(replicaManager.onlinePartition(t1p0)).andReturn(Some(partition)).anyTimes() + expect(replicaManager.localLog(t1p1)).andReturn(Some(logT1p1)).anyTimes() + expect(replicaManager.localLogOrException(t1p1)).andReturn(logT1p1).anyTimes() + expect(replicaManager.futureLocalLogOrException(t1p1)).andReturn(futureLog).anyTimes() + expect(replicaManager.futureLogExists(t1p1)).andStubReturn(true) + expect(replicaManager.onlinePartition(t1p1)).andReturn(Some(partition)).anyTimes() + } + + def stubWithFetchMessages(logT1p0: UnifiedLog, logT1p1: UnifiedLog, futureLog: UnifiedLog, partition: Partition, replicaManager: ReplicaManager, + responseCallback: Capture[Seq[(TopicIdPartition, FetchPartitionData)] => Unit]): IExpectationSetters[Unit] = { + stub(logT1p0, logT1p1, futureLog, partition, replicaManager) + expect(replicaManager.fetchMessages( + EasyMock.anyLong(), + EasyMock.anyInt(), + EasyMock.anyInt(), + EasyMock.anyInt(), + EasyMock.anyObject(), + EasyMock.anyObject(), + EasyMock.anyObject(), + EasyMock.capture(responseCallback), + EasyMock.anyObject(), + EasyMock.anyObject()) + ).andAnswer(() => responseCallback.getValue.apply(Seq.empty[(TopicIdPartition, FetchPartitionData)])).anyTimes() + } +} diff --git a/core/src/test/scala/unit/kafka/server/ReplicaFetchTest.scala b/core/src/test/scala/unit/kafka/server/ReplicaFetchTest.scala new file mode 100644 index 0000000..4a5ab60 --- /dev/null +++ b/core/src/test/scala/unit/kafka/server/ReplicaFetchTest.scala @@ -0,0 +1,81 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+*/ + +package kafka.server + +import scala.collection.Seq + +import org.junit.jupiter.api.{AfterEach, BeforeEach, Test, TestInfo} +import kafka.server.QuorumTestHarness +import kafka.utils.TestUtils +import TestUtils._ +import org.apache.kafka.clients.producer.ProducerRecord +import org.apache.kafka.common.TopicPartition +import org.apache.kafka.common.serialization.StringSerializer + +class ReplicaFetchTest extends QuorumTestHarness { + var brokers: Seq[KafkaServer] = null + val topic1 = "foo" + val topic2 = "bar" + + @BeforeEach + override def setUp(testInfo: TestInfo): Unit = { + super.setUp(testInfo) + val props = createBrokerConfigs(2, zkConnect) + brokers = props.map(KafkaConfig.fromProps).map(TestUtils.createServer(_)) + } + + @AfterEach + override def tearDown(): Unit = { + TestUtils.shutdownServers(brokers) + super.tearDown() + } + + @Test + def testReplicaFetcherThread(): Unit = { + val partition = 0 + val testMessageList1 = List("test1", "test2", "test3", "test4") + val testMessageList2 = List("test5", "test6", "test7", "test8") + + // create a topic and partition and await leadership + for (topic <- List(topic1,topic2)) { + createTopic(zkClient, topic, numPartitions = 1, replicationFactor = 2, servers = brokers) + } + + // send test messages to leader + val producer = TestUtils.createProducer(TestUtils.getBrokerListStrFromServers(brokers), + keySerializer = new StringSerializer, + valueSerializer = new StringSerializer) + val records = testMessageList1.map(m => new ProducerRecord(topic1, m, m)) ++ + testMessageList2.map(m => new ProducerRecord(topic2, m, m)) + records.map(producer.send).foreach(_.get) + producer.close() + + def logsMatch(): Boolean = { + var result = true + for (topic <- List(topic1, topic2)) { + val tp = new TopicPartition(topic, partition) + val expectedOffset = brokers.head.getLogManager.getLog(tp).get.logEndOffset + result = result && expectedOffset > 0 && brokers.forall { item => + expectedOffset == item.getLogManager.getLog(tp).get.logEndOffset + } + } + result + } + waitUntilTrue(logsMatch _, "Broker logs should be identical") + } +} diff --git a/core/src/test/scala/unit/kafka/server/ReplicaFetcherThreadTest.scala b/core/src/test/scala/unit/kafka/server/ReplicaFetcherThreadTest.scala new file mode 100644 index 0000000..bbd9330 --- /dev/null +++ b/core/src/test/scala/unit/kafka/server/ReplicaFetcherThreadTest.scala @@ -0,0 +1,1110 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package kafka.server + +import kafka.api.{ApiVersion, KAFKA_2_6_IV0} +import kafka.cluster.{BrokerEndPoint, Partition} +import kafka.log.{LogAppendInfo, LogManager, UnifiedLog} +import kafka.server.AbstractFetcherThread.ResultWithPartitions +import kafka.server.QuotaFactory.UnboundedQuota +import kafka.server.epoch.util.ReplicaFetcherMockBlockingSend +import kafka.server.metadata.ZkMetadataCache +import kafka.utils.TestUtils +import org.apache.kafka.common.{TopicIdPartition, TopicPartition, Uuid} +import org.apache.kafka.common.message.{FetchResponseData, UpdateMetadataRequestData} +import org.apache.kafka.common.message.OffsetForLeaderEpochRequestData.OffsetForLeaderPartition +import org.apache.kafka.common.message.OffsetForLeaderEpochResponseData.EpochEndOffset +import org.apache.kafka.common.metrics.Metrics +import org.apache.kafka.common.protocol.Errors._ +import org.apache.kafka.common.protocol.{ApiKeys, Errors} +import org.apache.kafka.common.record.{CompressionType, MemoryRecords, SimpleRecord} +import org.apache.kafka.common.requests.OffsetsForLeaderEpochResponse.{UNDEFINED_EPOCH, UNDEFINED_EPOCH_OFFSET} +import org.apache.kafka.common.requests.{FetchRequest, FetchResponse, UpdateMetadataRequest} +import org.apache.kafka.common.utils.SystemTime +import org.easymock.EasyMock._ +import org.easymock.{Capture, CaptureType} +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.{AfterEach, Test} + +import java.nio.charset.StandardCharsets +import java.util +import java.util.{Collections, Optional} +import scala.collection.{Map, mutable} +import scala.jdk.CollectionConverters._ + +class ReplicaFetcherThreadTest { + + private val t1p0 = new TopicPartition("topic1", 0) + private val t1p1 = new TopicPartition("topic1", 1) + private val t2p1 = new TopicPartition("topic2", 1) + + private val topicId1 = Uuid.randomUuid() + private val topicId2 = Uuid.randomUuid() + + private val topicIds = Map("topic1" -> topicId1, "topic2" -> topicId2) + + private val brokerEndPoint = new BrokerEndPoint(0, "localhost", 1000) + private val failedPartitions = new FailedPartitions + + private val partitionStates = List( + new UpdateMetadataRequestData.UpdateMetadataPartitionState() + .setTopicName("topic1") + .setPartitionIndex(0) + .setControllerEpoch(0) + .setLeader(0) + .setLeaderEpoch(0), + new UpdateMetadataRequestData.UpdateMetadataPartitionState() + .setTopicName("topic2") + .setPartitionIndex(0) + .setControllerEpoch(0) + .setLeader(0) + .setLeaderEpoch(0), + ).asJava + + private val updateMetadataRequest = new UpdateMetadataRequest.Builder(ApiKeys.UPDATE_METADATA.latestVersion(), + 0, 0, 0, partitionStates, Collections.emptyList(), topicIds.asJava).build() + // TODO: support raft code? 
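+  // The ZK-backed metadata cache below is seeded once with the UpdateMetadataRequest built above so
+  // that topic-name-to-topic-id resolution is available to the fetcher threads under test; only the
+  // ZooKeeper metadata cache is exercised here (see the TODO above regarding Raft support).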
+ private val metadataCache = new ZkMetadataCache(0) + metadataCache.updateMetadata(0, updateMetadataRequest) + + private def initialFetchState(topicId: Option[Uuid], fetchOffset: Long, leaderEpoch: Int = 1): InitialFetchState = { + InitialFetchState(topicId = topicId, leader = new BrokerEndPoint(0, "localhost", 9092), + initOffset = fetchOffset, currentLeaderEpoch = leaderEpoch) + } + + @AfterEach + def cleanup(): Unit = { + TestUtils.clearYammerMetrics() + } + + @Test + def shouldSendLatestRequestVersionsByDefault(): Unit = { + val props = TestUtils.createBrokerConfig(1, "localhost:1234") + val config = KafkaConfig.fromProps(props) + val replicaManager: ReplicaManager = mock(classOf[ReplicaManager]) + expect(replicaManager.brokerTopicStats).andReturn(mock(classOf[BrokerTopicStats])) + replay(replicaManager) + val thread = new ReplicaFetcherThread( + name = "bob", + fetcherId = 0, + sourceBroker = brokerEndPoint, + brokerConfig = config, + failedPartitions: FailedPartitions, + replicaMgr = replicaManager, + metrics = new Metrics(), + time = new SystemTime(), + quota = UnboundedQuota, + leaderEndpointBlockingSend = None) + assertEquals(ApiKeys.FETCH.latestVersion, thread.fetchRequestVersion) + assertEquals(ApiKeys.OFFSET_FOR_LEADER_EPOCH.latestVersion, thread.offsetForLeaderEpochRequestVersion) + assertEquals(ApiKeys.LIST_OFFSETS.latestVersion, thread.listOffsetRequestVersion) + } + + @Test + def testFetchLeaderEpochRequestIfLastEpochDefinedForSomePartitions(): Unit = { + val config = kafkaConfigNoTruncateOnFetch + + //Setup all dependencies + val quota: ReplicationQuotaManager = createNiceMock(classOf[ReplicationQuotaManager]) + val logManager: LogManager = createMock(classOf[LogManager]) + val replicaAlterLogDirsManager: ReplicaAlterLogDirsManager = createMock(classOf[ReplicaAlterLogDirsManager]) + val log: UnifiedLog = createNiceMock(classOf[UnifiedLog]) + val partition: Partition = createMock(classOf[Partition]) + val replicaManager: ReplicaManager = createMock(classOf[ReplicaManager]) + + val leaderEpoch = 5 + + //Stubs + expect(partition.localLogOrException).andReturn(log).anyTimes() + expect(log.logEndOffset).andReturn(0).anyTimes() + expect(log.highWatermark).andReturn(0).anyTimes() + expect(log.latestEpoch).andReturn(Some(leaderEpoch)).once() + expect(log.latestEpoch).andReturn(Some(leaderEpoch)).once() + expect(log.latestEpoch).andReturn(None).once() // t2p1 doesnt support epochs + expect(log.endOffsetForEpoch(leaderEpoch)).andReturn( + Some(OffsetAndEpoch(0, leaderEpoch))).anyTimes() + expect(replicaManager.metadataCache).andStubReturn(metadataCache) + expect(replicaManager.logManager).andReturn(logManager).anyTimes() + expect(replicaManager.replicaAlterLogDirsManager).andReturn(replicaAlterLogDirsManager).anyTimes() + expect(replicaManager.brokerTopicStats).andReturn(mock(classOf[BrokerTopicStats])) + stub(partition, replicaManager, log) + + //Expectations + expect(partition.truncateTo(anyLong(), anyBoolean())).times(3) + + replay(replicaManager, logManager, quota, partition, log) + + //Define the offsets for the OffsetsForLeaderEpochResponse + val offsets = Map( + t1p0 -> newOffsetForLeaderPartitionResult(t1p0, leaderEpoch, 1), + t1p1 -> newOffsetForLeaderPartitionResult(t1p1, leaderEpoch, 1)).asJava + + //Create the fetcher thread + val mockNetwork = new ReplicaFetcherMockBlockingSend(offsets, brokerEndPoint, new SystemTime()) + + val thread = new ReplicaFetcherThread("bob", 0, brokerEndPoint, config, failedPartitions, replicaManager, new Metrics(), new SystemTime(), quota, 
Some(mockNetwork)) + + // topic 1 supports epoch, t2 doesn't. + thread.addPartitions(Map( + t1p0 -> initialFetchState(Some(topicId1), 0L), + t1p1 -> initialFetchState(Some(topicId2), 0L), + t2p1 -> initialFetchState(Some(topicId2), 0L))) + + assertPartitionStates(thread, shouldBeReadyForFetch = false, shouldBeTruncatingLog = true, shouldBeDelayed = false) + //Loop 1 + thread.doWork() + assertEquals(1, mockNetwork.epochFetchCount) + assertEquals(1, mockNetwork.fetchCount) + + assertPartitionStates(thread, shouldBeReadyForFetch = true, shouldBeTruncatingLog = false, shouldBeDelayed = false) + + //Loop 2 we should not fetch epochs + thread.doWork() + assertEquals(1, mockNetwork.epochFetchCount) + assertEquals(2, mockNetwork.fetchCount) + + assertPartitionStates(thread, shouldBeReadyForFetch = true, shouldBeTruncatingLog = false, shouldBeDelayed = false) + + //Loop 3 we should not fetch epochs + thread.doWork() + assertEquals(1, mockNetwork.epochFetchCount) + assertEquals(3, mockNetwork.fetchCount) + + assertPartitionStates(thread, shouldBeReadyForFetch = true, shouldBeTruncatingLog = false, shouldBeDelayed = false) + + //Assert that truncate to is called exactly once (despite two loops) + verify(logManager) + } + + /** + * Assert that all partitions' states are as expected + * + */ + def assertPartitionStates(fetcher: AbstractFetcherThread, + shouldBeReadyForFetch: Boolean, + shouldBeTruncatingLog: Boolean, + shouldBeDelayed: Boolean): Unit = { + for (tp <- List(t1p0, t1p1, t2p1)) { + assertTrue(fetcher.fetchState(tp).isDefined) + val fetchState = fetcher.fetchState(tp).get + + assertEquals(shouldBeReadyForFetch, fetchState.isReadyForFetch, + s"Partition $tp should${if (!shouldBeReadyForFetch) " NOT" else ""} be ready for fetching") + + assertEquals(shouldBeTruncatingLog, fetchState.isTruncating, + s"Partition $tp should${if (!shouldBeTruncatingLog) " NOT" else ""} be truncating its log") + + assertEquals(shouldBeDelayed, fetchState.isDelayed, + s"Partition $tp should${if (!shouldBeDelayed) " NOT" else ""} be delayed") + } + } + + @Test + def shouldHandleExceptionFromBlockingSend(): Unit = { + val props = TestUtils.createBrokerConfig(1, "localhost:1234") + val config = KafkaConfig.fromProps(props) + val mockBlockingSend: BlockingSend = createMock(classOf[BlockingSend]) + + expect(mockBlockingSend.sendRequest(anyObject())).andThrow(new NullPointerException).once() + val replicaManager: ReplicaManager = mock(classOf[ReplicaManager]) + expect(replicaManager.brokerTopicStats).andReturn(mock(classOf[BrokerTopicStats])) + replay(mockBlockingSend, replicaManager) + + val thread = new ReplicaFetcherThread( + name = "bob", + fetcherId = 0, + sourceBroker = brokerEndPoint, + brokerConfig = config, + failedPartitions: FailedPartitions, + replicaMgr = replicaManager, + metrics = new Metrics(), + time = new SystemTime(), + quota = null, + leaderEndpointBlockingSend = Some(mockBlockingSend)) + + val result = thread.fetchEpochEndOffsets(Map( + t1p0 -> new OffsetForLeaderPartition() + .setPartition(t1p0.partition) + .setLeaderEpoch(0), + t1p1 -> new OffsetForLeaderPartition() + .setPartition(t1p1.partition) + .setLeaderEpoch(0))) + + val expected = Map( + t1p0 -> newOffsetForLeaderPartitionResult(t1p0, Errors.UNKNOWN_SERVER_ERROR, UNDEFINED_EPOCH, UNDEFINED_EPOCH_OFFSET), + t1p1 -> newOffsetForLeaderPartitionResult(t1p1, Errors.UNKNOWN_SERVER_ERROR, UNDEFINED_EPOCH, UNDEFINED_EPOCH_OFFSET) + ) + + assertEquals(expected, result, "results from leader epoch request should have undefined offset") + 
verify(mockBlockingSend) + } + + @Test + def shouldFetchLeaderEpochOnFirstFetchOnlyIfLeaderEpochKnownToBothIbp26(): Unit = { + verifyFetchLeaderEpochOnFirstFetch(KAFKA_2_6_IV0) + } + + @Test + def shouldNotFetchLeaderEpochOnFirstFetchWithTruncateOnFetch(): Unit = { + verifyFetchLeaderEpochOnFirstFetch(ApiVersion.latestVersion, epochFetchCount = 0) + } + + private def verifyFetchLeaderEpochOnFirstFetch(ibp: ApiVersion, epochFetchCount: Int = 1): Unit = { + val props = TestUtils.createBrokerConfig(1, "localhost:1234") + props.setProperty(KafkaConfig.InterBrokerProtocolVersionProp, ibp.version) + val config = KafkaConfig.fromProps(props) + + //Setup all dependencies + val logManager: LogManager = createMock(classOf[LogManager]) + val replicaAlterLogDirsManager: ReplicaAlterLogDirsManager = createMock(classOf[ReplicaAlterLogDirsManager]) + val log: UnifiedLog = createNiceMock(classOf[UnifiedLog]) + val partition: Partition = createMock(classOf[Partition]) + val replicaManager: ReplicaManager = createMock(classOf[ReplicaManager]) + + val leaderEpoch = 5 + + //Stubs + expect(partition.localLogOrException).andReturn(log).anyTimes() + expect(log.highWatermark).andReturn(0).anyTimes() + expect(log.latestEpoch).andReturn(Some(leaderEpoch)).anyTimes() + expect(log.endOffsetForEpoch(leaderEpoch)).andReturn( + Some(OffsetAndEpoch(0, leaderEpoch))).anyTimes() + expect(replicaManager.metadataCache).andStubReturn(metadataCache) + expect(replicaManager.logManager).andReturn(logManager).anyTimes() + expect(replicaManager.replicaAlterLogDirsManager).andReturn(replicaAlterLogDirsManager).anyTimes() + expect(replicaManager.brokerTopicStats).andReturn(mock(classOf[BrokerTopicStats])) + stub(partition, replicaManager, log) + + //Expectations + expect(partition.truncateTo(anyLong(), anyBoolean())).times(2) + + replay(replicaManager, logManager, partition, log) + + //Define the offsets for the OffsetsForLeaderEpochResponse + val offsets = Map( + t1p0 -> newOffsetForLeaderPartitionResult(t1p0, leaderEpoch, 1), + t1p1 -> newOffsetForLeaderPartitionResult(t1p1, leaderEpoch, 1)).asJava + + //Create the fetcher thread + val mockNetwork = new ReplicaFetcherMockBlockingSend(offsets, brokerEndPoint, new SystemTime()) + val thread = new ReplicaFetcherThread("bob", 0, brokerEndPoint, config, failedPartitions, replicaManager, + new Metrics, new SystemTime, UnboundedQuota, Some(mockNetwork)) + thread.addPartitions(Map(t1p0 -> initialFetchState(Some(topicId1), 0L), t1p1 -> initialFetchState(Some(topicId1), 0L))) + + //Loop 1 + thread.doWork() + assertEquals(epochFetchCount, mockNetwork.epochFetchCount) + assertEquals(1, mockNetwork.fetchCount) + + //Loop 2 we should not fetch epochs + thread.doWork() + assertEquals(epochFetchCount, mockNetwork.epochFetchCount) + assertEquals(2, mockNetwork.fetchCount) + + //Loop 3 we should not fetch epochs + thread.doWork() + assertEquals(epochFetchCount, mockNetwork.epochFetchCount) + assertEquals(3, mockNetwork.fetchCount) + + //Assert that truncate to is called exactly once (despite two loops) + verify(logManager) + } + + @Test + def shouldTruncateToOffsetSpecifiedInEpochOffsetResponse(): Unit = { + + //Create a capture to track what partitions/offsets are truncated + val truncateToCapture: Capture[Long] = newCapture(CaptureType.ALL) + + // Setup all the dependencies + val config = kafkaConfigNoTruncateOnFetch + val quota: ReplicationQuotaManager = createNiceMock(classOf[ReplicationQuotaManager]) + val logManager: LogManager = createMock(classOf[LogManager]) + val 
replicaAlterLogDirsManager: ReplicaAlterLogDirsManager = createMock(classOf[ReplicaAlterLogDirsManager]) + val log: UnifiedLog = createNiceMock(classOf[UnifiedLog]) + val partition: Partition = createMock(classOf[Partition]) + val replicaManager: ReplicaManager = createMock(classOf[ReplicaManager]) + + val leaderEpoch = 5 + val initialLEO = 200 + + //Stubs + expect(partition.truncateTo(capture(truncateToCapture), anyBoolean())).anyTimes() + expect(partition.localLogOrException).andReturn(log).anyTimes() + expect(log.highWatermark).andReturn(initialLEO - 1).anyTimes() + expect(log.latestEpoch).andReturn(Some(leaderEpoch)).anyTimes() + expect(log.endOffsetForEpoch(leaderEpoch)).andReturn( + Some(OffsetAndEpoch(initialLEO, leaderEpoch))).anyTimes() + expect(log.logEndOffset).andReturn(initialLEO).anyTimes() + expect(replicaManager.metadataCache).andStubReturn(metadataCache) + expect(replicaManager.localLogOrException(anyObject(classOf[TopicPartition]))).andReturn(log).anyTimes() + expect(replicaManager.logManager).andReturn(logManager).anyTimes() + expect(replicaManager.replicaAlterLogDirsManager).andReturn(replicaAlterLogDirsManager).anyTimes() + expect(replicaManager.brokerTopicStats).andReturn(mock(classOf[BrokerTopicStats])) + stub(partition, replicaManager, log) + + replay(replicaManager, logManager, quota, partition, log) + + //Define the offsets for the OffsetsForLeaderEpochResponse, these are used for truncation + val offsetsReply = Map( + t1p0 -> newOffsetForLeaderPartitionResult(t1p0, leaderEpoch, 156), + t2p1 -> newOffsetForLeaderPartitionResult(t2p1, leaderEpoch, 172)).asJava + + //Create the thread + val mockNetwork = new ReplicaFetcherMockBlockingSend(offsetsReply, brokerEndPoint, new SystemTime()) + val thread = new ReplicaFetcherThread("bob", 0, brokerEndPoint, config, failedPartitions, replicaManager, + new Metrics(), new SystemTime(), quota, Some(mockNetwork)) + thread.addPartitions(Map(t1p0 -> initialFetchState(Some(topicId1), 0L), t2p1 -> initialFetchState(Some(topicId2), 0L))) + + //Run it + thread.doWork() + + //We should have truncated to the offsets in the response + assertTrue(truncateToCapture.getValues.asScala.contains(156), + "Expected " + t1p0 + " to truncate to offset 156 (truncation offsets: " + truncateToCapture.getValues + ")") + assertTrue(truncateToCapture.getValues.asScala.contains(172), + "Expected " + t2p1 + " to truncate to offset 172 (truncation offsets: " + truncateToCapture.getValues + ")") + } + + @Test + def shouldTruncateToOffsetSpecifiedInEpochOffsetResponseIfFollowerHasNoMoreEpochs(): Unit = { + // Create a capture to track what partitions/offsets are truncated + val truncateToCapture: Capture[Long] = newCapture(CaptureType.ALL) + + // Setup all the dependencies + val config = kafkaConfigNoTruncateOnFetch + val quota: ReplicationQuotaManager = createNiceMock(classOf[ReplicationQuotaManager]) + val logManager: LogManager = createMock(classOf[LogManager]) + val replicaAlterLogDirsManager: ReplicaAlterLogDirsManager = createMock(classOf[ReplicaAlterLogDirsManager]) + val log: UnifiedLog = createNiceMock(classOf[UnifiedLog]) + val partition: Partition = createMock(classOf[Partition]) + val replicaManager: ReplicaManager = createMock(classOf[ReplicaManager]) + + val leaderEpochAtFollower = 5 + val leaderEpochAtLeader = 4 + val initialLEO = 200 + + //Stubs + expect(partition.truncateTo(capture(truncateToCapture), anyBoolean())).anyTimes() + expect(partition.localLogOrException).andReturn(log).anyTimes() + expect(log.highWatermark).andReturn(initialLEO 
- 3).anyTimes() + expect(log.latestEpoch).andReturn(Some(leaderEpochAtFollower)).anyTimes() + expect(log.endOffsetForEpoch(leaderEpochAtLeader)).andReturn(None).anyTimes() + expect(log.logEndOffset).andReturn(initialLEO).anyTimes() + expect(replicaManager.metadataCache).andStubReturn(metadataCache) + expect(replicaManager.localLogOrException(anyObject(classOf[TopicPartition]))).andReturn(log).anyTimes() + expect(replicaManager.logManager).andReturn(logManager).anyTimes() + expect(replicaManager.replicaAlterLogDirsManager).andReturn(replicaAlterLogDirsManager).anyTimes() + expect(replicaManager.brokerTopicStats).andReturn(mock(classOf[BrokerTopicStats])) + stub(partition, replicaManager, log) + + replay(replicaManager, logManager, quota, partition, log) + + //Define the offsets for the OffsetsForLeaderEpochResponse, these are used for truncation + val offsetsReply = Map( + t1p0 -> newOffsetForLeaderPartitionResult(t1p0, leaderEpochAtLeader, 156), + t2p1 -> newOffsetForLeaderPartitionResult(t2p1, leaderEpochAtLeader, 202)).asJava + + //Create the thread + val mockNetwork = new ReplicaFetcherMockBlockingSend(offsetsReply, brokerEndPoint, new SystemTime()) + val thread = new ReplicaFetcherThread("bob", 0, brokerEndPoint, config, failedPartitions, + replicaManager, new Metrics(), new SystemTime(), quota, Some(mockNetwork)) + thread.addPartitions(Map(t1p0 -> initialFetchState(Some(topicId1), 0L), t2p1 -> initialFetchState(Some(topicId2), 0L))) + + //Run it + thread.doWork() + + //We should have truncated to the offsets in the response + assertTrue(truncateToCapture.getValues.asScala.contains(156), + "Expected " + t1p0 + " to truncate to offset 156 (truncation offsets: " + truncateToCapture.getValues + ")") + assertTrue(truncateToCapture.getValues.asScala.contains(initialLEO), + "Expected " + t2p1 + " to truncate to offset " + initialLEO + + " (truncation offsets: " + truncateToCapture.getValues + ")") + } + + @Test + def shouldFetchLeaderEpochSecondTimeIfLeaderRepliesWithEpochNotKnownToFollower(): Unit = { + // Create a capture to track what partitions/offsets are truncated + val truncateToCapture: Capture[Long] = newCapture(CaptureType.ALL) + + val config = kafkaConfigNoTruncateOnFetch + + // Setup all dependencies + val quota: ReplicationQuotaManager = createNiceMock(classOf[ReplicationQuotaManager]) + val logManager: LogManager = createMock(classOf[LogManager]) + val replicaAlterLogDirsManager: ReplicaAlterLogDirsManager = createMock(classOf[ReplicaAlterLogDirsManager]) + val log: UnifiedLog = createNiceMock(classOf[UnifiedLog]) + val partition: Partition = createMock(classOf[Partition]) + val replicaManager: ReplicaManager = createMock(classOf[ReplicaManager]) + + val initialLEO = 200 + + // Stubs + expect(partition.truncateTo(capture(truncateToCapture), anyBoolean())).anyTimes() + expect(partition.localLogOrException).andReturn(log).anyTimes() + expect(log.highWatermark).andReturn(initialLEO - 2).anyTimes() + expect(log.latestEpoch).andReturn(Some(5)).anyTimes() + expect(log.endOffsetForEpoch(4)).andReturn( + Some(OffsetAndEpoch(120, 3))).anyTimes() + expect(log.endOffsetForEpoch(3)).andReturn( + Some(OffsetAndEpoch(120, 3))).anyTimes() + expect(log.logEndOffset).andReturn(initialLEO).anyTimes() + expect(replicaManager.metadataCache).andStubReturn(metadataCache) + expect(replicaManager.localLogOrException(anyObject(classOf[TopicPartition]))).andReturn(log).anyTimes() + expect(replicaManager.logManager).andReturn(logManager).anyTimes() + 
expect(replicaManager.replicaAlterLogDirsManager).andReturn(replicaAlterLogDirsManager).anyTimes() + expect(replicaManager.brokerTopicStats).andReturn(mock(classOf[BrokerTopicStats])) + stub(partition, replicaManager, log) + + replay(replicaManager, logManager, quota, partition, log) + + // Define the offsets for the OffsetsForLeaderEpochResponse + val offsets = Map( + t1p0 -> newOffsetForLeaderPartitionResult(t1p0, 4, 155), + t1p1 -> newOffsetForLeaderPartitionResult(t1p1, 4, 143)).asJava + + // Create the fetcher thread + val mockNetwork = new ReplicaFetcherMockBlockingSend(offsets, brokerEndPoint, new SystemTime()) + val thread = new ReplicaFetcherThread("bob", 0, brokerEndPoint, config, failedPartitions, replicaManager, new Metrics(), new SystemTime(), quota, Some(mockNetwork)) + thread.addPartitions(Map(t1p0 -> initialFetchState(Some(topicId1), 0L), t1p1 -> initialFetchState(Some(topicId1), 0L))) + + // Loop 1 -- both topic partitions will need to fetch another leader epoch + thread.doWork() + assertEquals(1, mockNetwork.epochFetchCount) + assertEquals(0, mockNetwork.fetchCount) + + // Loop 2 should do the second fetch for both topic partitions because the leader replied with + // epoch 4 while follower knows only about epoch 3 + val nextOffsets = Map( + t1p0 -> newOffsetForLeaderPartitionResult(t1p0, 3, 101), + t1p1 -> newOffsetForLeaderPartitionResult(t1p1, 3, 102)).asJava + mockNetwork.setOffsetsForNextResponse(nextOffsets) + + thread.doWork() + assertEquals(2, mockNetwork.epochFetchCount) + assertEquals(1, mockNetwork.fetchCount) + assertTrue(mockNetwork.lastUsedOffsetForLeaderEpochVersion >= 3, + "OffsetsForLeaderEpochRequest version.") + + //Loop 3 we should not fetch epochs + thread.doWork() + assertEquals(2, mockNetwork.epochFetchCount) + assertEquals(2, mockNetwork.fetchCount) + + + //We should have truncated to the offsets in the second response + assertTrue(truncateToCapture.getValues.asScala.contains(102), + "Expected " + t1p1 + " to truncate to offset 102 (truncation offsets: " + truncateToCapture.getValues + ")") + assertTrue(truncateToCapture.getValues.asScala.contains(101), + "Expected " + t1p0 + " to truncate to offset 101 (truncation offsets: " + truncateToCapture.getValues + ")") + } + + @Test + def shouldTruncateIfLeaderRepliesWithDivergingEpochNotKnownToFollower(): Unit = { + + // Create a capture to track what partitions/offsets are truncated + val truncateToCapture: Capture[Long] = newCapture(CaptureType.ALL) + + val config = KafkaConfig.fromProps(TestUtils.createBrokerConfig(1, "localhost:1234")) + + // Setup all dependencies + val quota: ReplicationQuotaManager = createNiceMock(classOf[ReplicationQuotaManager]) + val logManager: LogManager = createMock(classOf[LogManager]) + val replicaAlterLogDirsManager: ReplicaAlterLogDirsManager = createMock(classOf[ReplicaAlterLogDirsManager]) + val log: UnifiedLog = createNiceMock(classOf[UnifiedLog]) + val partition: Partition = createNiceMock(classOf[Partition]) + val replicaManager: ReplicaManager = createMock(classOf[ReplicaManager]) + + val initialLEO = 200 + var latestLogEpoch: Option[Int] = Some(5) + + // Stubs + expect(partition.truncateTo(capture(truncateToCapture), anyBoolean())).anyTimes() + expect(partition.localLogOrException).andReturn(log).anyTimes() + expect(log.highWatermark).andReturn(115).anyTimes() + expect(log.latestEpoch).andAnswer(() => latestLogEpoch).anyTimes() + expect(log.endOffsetForEpoch(4)).andReturn(Some(OffsetAndEpoch(149, 4))).anyTimes() + 
expect(log.endOffsetForEpoch(3)).andReturn(Some(OffsetAndEpoch(129, 2))).anyTimes() + expect(log.endOffsetForEpoch(2)).andReturn(Some(OffsetAndEpoch(119, 1))).anyTimes() + expect(log.logEndOffset).andReturn(initialLEO).anyTimes() + expect(replicaManager.metadataCache).andStubReturn(metadataCache) + expect(replicaManager.localLogOrException(anyObject(classOf[TopicPartition]))).andReturn(log).anyTimes() + expect(replicaManager.logManager).andReturn(logManager).anyTimes() + expect(replicaManager.replicaAlterLogDirsManager).andReturn(replicaAlterLogDirsManager).anyTimes() + expect(replicaManager.brokerTopicStats).andReturn(mock(classOf[BrokerTopicStats])) + stub(partition, replicaManager, log) + + replay(replicaManager, logManager, quota, partition, log) + + // Create the fetcher thread + val mockNetwork = new ReplicaFetcherMockBlockingSend(Collections.emptyMap(), brokerEndPoint, new SystemTime()) + val thread = new ReplicaFetcherThread("bob", 0, brokerEndPoint, config, failedPartitions, replicaManager, new Metrics(), new SystemTime(), quota, Some(mockNetwork)) { + override def processPartitionData(topicPartition: TopicPartition, fetchOffset: Long, partitionData: FetchData): Option[LogAppendInfo] = None + } + thread.addPartitions(Map(t1p0 -> initialFetchState(Some(topicId1), initialLEO), t1p1 -> initialFetchState(Some(topicId1), initialLEO))) + val partitions = Set(t1p0, t1p1) + + // Loop 1 -- both topic partitions skip epoch fetch and send fetch request since we can truncate + // later based on diverging epochs in fetch response. + thread.doWork() + assertEquals(0, mockNetwork.epochFetchCount) + assertEquals(1, mockNetwork.fetchCount) + partitions.foreach { tp => assertEquals(Fetching, thread.fetchState(tp).get.state) } + + def partitionData(partition: Int, divergingEpoch: FetchResponseData.EpochEndOffset): FetchResponseData.PartitionData = { + new FetchResponseData.PartitionData() + .setPartitionIndex(partition) + .setLastStableOffset(0) + .setLogStartOffset(0) + .setDivergingEpoch(divergingEpoch) + } + + // Loop 2 should truncate based on diverging epoch and continue to send fetch requests. + mockNetwork.setFetchPartitionDataForNextResponse(Map( + t1p0 -> partitionData(t1p0.partition, new FetchResponseData.EpochEndOffset().setEpoch(4).setEndOffset(140)), + t1p1 -> partitionData(t1p1.partition, new FetchResponseData.EpochEndOffset().setEpoch(4).setEndOffset(141)) + )) + mockNetwork.setIdsForNextResponse(topicIds) + latestLogEpoch = Some(4) + thread.doWork() + assertEquals(0, mockNetwork.epochFetchCount) + assertEquals(2, mockNetwork.fetchCount) + assertTrue(truncateToCapture.getValues.asScala.contains(140), + "Expected " + t1p0 + " to truncate to offset 140 (truncation offsets: " + truncateToCapture.getValues + ")") + assertTrue(truncateToCapture.getValues.asScala.contains(141), + "Expected " + t1p1 + " to truncate to offset 141 (truncation offsets: " + truncateToCapture.getValues + ")") + partitions.foreach { tp => assertEquals(Fetching, thread.fetchState(tp).get.state) } + + // Loop 3 should truncate because of diverging epoch. Offset truncation is not complete + // because divergent epoch is not known to follower. We truncate and stay in Fetching state. 
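+    // The stubs above return endOffsetForEpoch(3) = (129, 2) and endOffsetForEpoch(2) = (119, 1),
+    // so the truncation offsets asserted for loops 3 and 4 below (129, then 119) follow from the
+    // diverging epochs (3, then 2) carried in the mocked fetch responses.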
+ mockNetwork.setFetchPartitionDataForNextResponse(Map( + t1p0 -> partitionData(t1p0.partition, new FetchResponseData.EpochEndOffset().setEpoch(3).setEndOffset(130)), + t1p1 -> partitionData(t1p1.partition, new FetchResponseData.EpochEndOffset().setEpoch(3).setEndOffset(131)) + )) + mockNetwork.setIdsForNextResponse(topicIds) + thread.doWork() + assertEquals(0, mockNetwork.epochFetchCount) + assertEquals(3, mockNetwork.fetchCount) + assertTrue(truncateToCapture.getValues.asScala.contains(129), + "Expected to truncate to offset 129 (truncation offsets: " + truncateToCapture.getValues + ")") + partitions.foreach { tp => assertEquals(Fetching, thread.fetchState(tp).get.state) } + + // Loop 4 should truncate because of diverging epoch. Offset truncation is not complete + // because divergent epoch is not known to follower. Last fetched epoch cannot be determined + // from the log. We truncate and stay in Fetching state. + mockNetwork.setFetchPartitionDataForNextResponse(Map( + t1p0 -> partitionData(t1p0.partition, new FetchResponseData.EpochEndOffset().setEpoch(2).setEndOffset(120)), + t1p1 -> partitionData(t1p1.partition, new FetchResponseData.EpochEndOffset().setEpoch(2).setEndOffset(121)) + )) + mockNetwork.setIdsForNextResponse(topicIds) + latestLogEpoch = None + thread.doWork() + assertEquals(0, mockNetwork.epochFetchCount) + assertEquals(4, mockNetwork.fetchCount) + assertTrue(truncateToCapture.getValues.asScala.contains(119), + "Expected to truncate to offset 119 (truncation offsets: " + truncateToCapture.getValues + ")") + partitions.foreach { tp => assertEquals(Fetching, thread.fetchState(tp).get.state) } + } + + @Test + def shouldUseLeaderEndOffsetIfInterBrokerVersionBelow20(): Unit = { + + // Create a capture to track what partitions/offsets are truncated + val truncateToCapture: Capture[Long] = newCapture(CaptureType.ALL) + + val props = TestUtils.createBrokerConfig(1, "localhost:1234") + props.put(KafkaConfig.InterBrokerProtocolVersionProp, "0.11.0") + val config = KafkaConfig.fromProps(props) + + // Setup all dependencies + val quota: ReplicationQuotaManager = createNiceMock(classOf[ReplicationQuotaManager]) + val logManager: LogManager = createMock(classOf[LogManager]) + val replicaAlterLogDirsManager: ReplicaAlterLogDirsManager = createMock(classOf[ReplicaAlterLogDirsManager]) + val log: UnifiedLog = createNiceMock(classOf[UnifiedLog]) + val partition: Partition = createMock(classOf[Partition]) + val replicaManager: ReplicaManager = createMock(classOf[ReplicaManager]) + + val initialLEO = 200 + + // Stubs + expect(partition.truncateTo(capture(truncateToCapture), anyBoolean())).anyTimes() + expect(partition.localLogOrException).andReturn(log).anyTimes() + expect(log.highWatermark).andReturn(initialLEO - 2).anyTimes() + expect(log.latestEpoch).andReturn(Some(5)).anyTimes() + expect(log.endOffsetForEpoch(4)).andReturn( + Some(OffsetAndEpoch(120, 3))).anyTimes() + expect(log.endOffsetForEpoch(3)).andReturn( + Some(OffsetAndEpoch(120, 3))).anyTimes() + expect(log.logEndOffset).andReturn(initialLEO).anyTimes() + expect(replicaManager.metadataCache).andStubReturn(metadataCache) + expect(replicaManager.localLogOrException(anyObject(classOf[TopicPartition]))).andReturn(log).anyTimes() + expect(replicaManager.logManager).andReturn(logManager).anyTimes() + expect(replicaManager.replicaAlterLogDirsManager).andReturn(replicaAlterLogDirsManager).anyTimes() + expect(replicaManager.brokerTopicStats).andReturn(mock(classOf[BrokerTopicStats])) + stub(partition, replicaManager, log) + + 
replay(replicaManager, logManager, quota, partition, log) + + // Define the offsets for the OffsetsForLeaderEpochResponse with undefined epoch to simulate + // older protocol version + val offsets = Map( + t1p0 -> newOffsetForLeaderPartitionResult(t1p0, UNDEFINED_EPOCH, 155), + t1p1 -> newOffsetForLeaderPartitionResult(t1p1, UNDEFINED_EPOCH, 143)).asJava + + // Create the fetcher thread + val mockNetwork = new ReplicaFetcherMockBlockingSend(offsets, brokerEndPoint, new SystemTime()) + val thread = new ReplicaFetcherThread("bob", 0, brokerEndPoint, config, failedPartitions, replicaManager, new Metrics(), new SystemTime(), quota, Some(mockNetwork)) + thread.addPartitions(Map(t1p0 -> initialFetchState(Some(topicId1), 0L), t1p1 -> initialFetchState(Some(topicId1), 0L))) + + // Loop 1 -- both topic partitions will truncate to leader offset even though they don't know + // about leader epoch + thread.doWork() + assertEquals(1, mockNetwork.epochFetchCount) + assertEquals(1, mockNetwork.fetchCount) + assertEquals(0, mockNetwork.lastUsedOffsetForLeaderEpochVersion, "OffsetsForLeaderEpochRequest version.") + + //Loop 2 we should not fetch epochs + thread.doWork() + assertEquals(1, mockNetwork.epochFetchCount) + assertEquals(2, mockNetwork.fetchCount) + + //We should have truncated to the offsets in the first response + assertTrue(truncateToCapture.getValues.asScala.contains(155), + "Expected " + t1p0 + " to truncate to offset 155 (truncation offsets: " + truncateToCapture.getValues + ")") + assertTrue(truncateToCapture.getValues.asScala.contains(143), + "Expected " + t1p1 + " to truncate to offset 143 (truncation offsets: " + truncateToCapture.getValues + ")") + } + + @Test + def shouldTruncateToInitialFetchOffsetIfLeaderReturnsUndefinedOffset(): Unit = { + + //Create a capture to track what partitions/offsets are truncated + val truncated: Capture[Long] = newCapture(CaptureType.ALL) + + // Setup all the dependencies + val config = kafkaConfigNoTruncateOnFetch + val quota: ReplicationQuotaManager = createNiceMock(classOf[ReplicationQuotaManager]) + val logManager: LogManager = createMock(classOf[LogManager]) + val replicaAlterLogDirsManager: ReplicaAlterLogDirsManager = createMock(classOf[ReplicaAlterLogDirsManager]) + val log: UnifiedLog = createNiceMock(classOf[UnifiedLog]) + val partition: Partition = createMock(classOf[Partition]) + val replicaManager: ReplicaManager = createMock(classOf[ReplicaManager]) + + val initialFetchOffset = 100 + + //Stubs + expect(partition.truncateTo(capture(truncated), anyBoolean())).anyTimes() + expect(partition.localLogOrException).andReturn(log).anyTimes() + expect(log.highWatermark).andReturn(initialFetchOffset).anyTimes() + expect(log.latestEpoch).andReturn(Some(5)).times(2) + expect(replicaManager.metadataCache).andStubReturn(metadataCache) + expect(replicaManager.logManager).andReturn(logManager).anyTimes() + expect(replicaManager.replicaAlterLogDirsManager).andReturn(replicaAlterLogDirsManager).anyTimes() + expect(replicaManager.brokerTopicStats).andReturn(mock(classOf[BrokerTopicStats])) + stub(partition, replicaManager, log) + replay(replicaManager, logManager, quota, partition, log) + + //Define the offsets for the OffsetsForLeaderEpochResponse, these are used for truncation + val offsetsReply = Map( + t1p0 -> newOffsetForLeaderPartitionResult(t1p0, UNDEFINED_EPOCH, UNDEFINED_EPOCH_OFFSET)).asJava + + //Create the thread + val mockNetwork = new ReplicaFetcherMockBlockingSend(offsetsReply, brokerEndPoint, new SystemTime()) + val thread = new 
ReplicaFetcherThread("bob", 0, brokerEndPoint, config, failedPartitions, replicaManager, new Metrics(), new SystemTime(), quota, Some(mockNetwork)) + thread.addPartitions(Map(t1p0 -> initialFetchState(Some(topicId1), initialFetchOffset))) + + //Run it + thread.doWork() + + //We should have truncated to initial fetch offset + assertEquals(initialFetchOffset, truncated.getValue) + } + + @Test + def shouldPollIndefinitelyIfLeaderReturnsAnyException(): Unit = { + + //Create a capture to track what partitions/offsets are truncated + val truncated: Capture[Long] = newCapture(CaptureType.ALL) + + // Setup all the dependencies + val config = kafkaConfigNoTruncateOnFetch + val quota: ReplicationQuotaManager = createNiceMock(classOf[ReplicationQuotaManager]) + val logManager: LogManager = createMock(classOf[LogManager]) + val replicaAlterLogDirsManager: ReplicaAlterLogDirsManager = createMock(classOf[ReplicaAlterLogDirsManager]) + val log: UnifiedLog = createNiceMock(classOf[UnifiedLog]) + val partition: Partition = createMock(classOf[Partition]) + val replicaManager: ReplicaManager = createMock(classOf[ReplicaManager]) + + val leaderEpoch = 5 + val highWaterMark = 100 + val initialLeo = 300 + + //Stubs + expect(log.highWatermark).andReturn(highWaterMark).anyTimes() + expect(partition.truncateTo(capture(truncated), anyBoolean())).anyTimes() + expect(partition.localLogOrException).andReturn(log).anyTimes() + expect(log.latestEpoch).andReturn(Some(leaderEpoch)).anyTimes() + // this is for the last reply with EpochEndOffset(5, 156) + expect(log.endOffsetForEpoch(leaderEpoch)).andReturn( + Some(OffsetAndEpoch(initialLeo, leaderEpoch))).anyTimes() + expect(log.logEndOffset).andReturn(initialLeo).anyTimes() + expect(replicaManager.metadataCache).andStubReturn(metadataCache) + expect(replicaManager.localLogOrException(anyObject(classOf[TopicPartition]))).andReturn(log).anyTimes() + expect(replicaManager.logManager).andReturn(logManager).anyTimes() + expect(replicaManager.replicaAlterLogDirsManager).andReturn(replicaAlterLogDirsManager).anyTimes() + expect(replicaManager.brokerTopicStats).andReturn(mock(classOf[BrokerTopicStats])) + stub(partition, replicaManager, log) + replay(replicaManager, logManager, quota, partition, log) + + //Define the offsets for the OffsetsForLeaderEpochResponse, these are used for truncation + val offsetsReply = mutable.Map( + t1p0 -> newOffsetForLeaderPartitionResult(t1p0, NOT_LEADER_OR_FOLLOWER, UNDEFINED_EPOCH, UNDEFINED_EPOCH_OFFSET), + t1p1 -> newOffsetForLeaderPartitionResult(t1p1, UNKNOWN_SERVER_ERROR, UNDEFINED_EPOCH, UNDEFINED_EPOCH_OFFSET) + ).asJava + + //Create the thread + val mockNetwork = new ReplicaFetcherMockBlockingSend(offsetsReply, brokerEndPoint, new SystemTime()) + val thread = new ReplicaFetcherThread("bob", 0, brokerEndPoint, config, failedPartitions, replicaManager, new Metrics(), new SystemTime(), quota, Some(mockNetwork)) + thread.addPartitions(Map(t1p0 -> initialFetchState(Some(topicId1), 0L), t1p1 -> initialFetchState(Some(topicId1), 0L))) + + //Run thread 3 times + (0 to 3).foreach { _ => + thread.doWork() + } + + //Then should loop continuously while there is no leader + assertEquals(0, truncated.getValues.size()) + + //New leader elected and replies + offsetsReply.put(t1p0, newOffsetForLeaderPartitionResult(t1p0, leaderEpoch, 156)) + + thread.doWork() + + //Now the final call should have actually done a truncation (to offset 156) + assertEquals(156, truncated.getValue) + } + + @Test + def shouldMovePartitionsOutOfTruncatingLogState(): Unit = { + 
val config = kafkaConfigNoTruncateOnFetch + + //Setup all stubs + val quota: ReplicationQuotaManager = createNiceMock(classOf[ReplicationQuotaManager]) + val logManager: LogManager = createNiceMock(classOf[LogManager]) + val replicaAlterLogDirsManager: ReplicaAlterLogDirsManager = createMock(classOf[ReplicaAlterLogDirsManager]) + val log: UnifiedLog = createNiceMock(classOf[UnifiedLog]) + val partition: Partition = createMock(classOf[Partition]) + val replicaManager: ReplicaManager = createNiceMock(classOf[ReplicaManager]) + + val leaderEpoch = 4 + + //Stub return values + expect(partition.truncateTo(0L, false)).times(2) + expect(partition.localLogOrException).andReturn(log).anyTimes() + expect(log.highWatermark).andReturn(0).anyTimes() + expect(log.latestEpoch).andReturn(Some(leaderEpoch)).anyTimes() + expect(log.endOffsetForEpoch(leaderEpoch)).andReturn( + Some(OffsetAndEpoch(0, leaderEpoch))).anyTimes() + expect(replicaManager.metadataCache).andStubReturn(metadataCache) + expect(replicaManager.logManager).andReturn(logManager).anyTimes() + expect(replicaManager.replicaAlterLogDirsManager).andReturn(replicaAlterLogDirsManager).anyTimes() + stub(partition, replicaManager, log) + + replay(replicaManager, logManager, quota, partition, log) + + //Define the offsets for the OffsetsForLeaderEpochResponse + val offsetsReply = Map( + t1p0 -> newOffsetForLeaderPartitionResult(t1p0, leaderEpoch, 1), + t1p1 -> newOffsetForLeaderPartitionResult(t1p1, leaderEpoch, 1) + ).asJava + + //Create the fetcher thread + val mockNetwork = new ReplicaFetcherMockBlockingSend(offsetsReply, brokerEndPoint, new SystemTime()) + val thread = new ReplicaFetcherThread("bob", 0, brokerEndPoint, config, failedPartitions, replicaManager, new Metrics(), new SystemTime(), quota, Some(mockNetwork)) + + //When + thread.addPartitions(Map(t1p0 -> initialFetchState(Some(topicId1), 0L), t1p1 -> initialFetchState(Some(topicId1), 0L))) + + //Then all partitions should start in an TruncatingLog state + assertEquals(Option(Truncating), thread.fetchState(t1p0).map(_.state)) + assertEquals(Option(Truncating), thread.fetchState(t1p1).map(_.state)) + + //When + thread.doWork() + + //Then none should be TruncatingLog anymore + assertEquals(Option(Fetching), thread.fetchState(t1p0).map(_.state)) + assertEquals(Option(Fetching), thread.fetchState(t1p1).map(_.state)) + } + + @Test + def shouldFilterPartitionsMadeLeaderDuringLeaderEpochRequest(): Unit ={ + val config = kafkaConfigNoTruncateOnFetch + val truncateToCapture: Capture[Long] = newCapture(CaptureType.ALL) + val initialLEO = 100 + + //Setup all stubs + val quota: ReplicationQuotaManager = createNiceMock(classOf[ReplicationQuotaManager]) + val logManager: LogManager = createNiceMock(classOf[LogManager]) + val replicaAlterLogDirsManager: ReplicaAlterLogDirsManager = createMock(classOf[ReplicaAlterLogDirsManager]) + val log: UnifiedLog = createNiceMock(classOf[UnifiedLog]) + val partition: Partition = createMock(classOf[Partition]) + val replicaManager: ReplicaManager = createNiceMock(classOf[ReplicaManager]) + + //Stub return values + expect(partition.truncateTo(capture(truncateToCapture), anyBoolean())).once + expect(partition.localLogOrException).andReturn(log).anyTimes() + expect(log.highWatermark).andReturn(initialLEO - 2).anyTimes() + expect(log.latestEpoch).andReturn(Some(5)).anyTimes() + expect(log.endOffsetForEpoch(5)).andReturn(Some(OffsetAndEpoch(initialLEO, 5))).anyTimes() + expect(log.logEndOffset).andReturn(initialLEO).anyTimes() + 
expect(replicaManager.metadataCache).andStubReturn(metadataCache) + expect(replicaManager.localLogOrException(anyObject(classOf[TopicPartition]))).andReturn(log).anyTimes() + expect(replicaManager.logManager).andReturn(logManager).anyTimes() + expect(replicaManager.replicaAlterLogDirsManager).andReturn(replicaAlterLogDirsManager).anyTimes() + stub(partition, replicaManager, log) + + replay(replicaManager, logManager, quota, partition, log) + + //Define the offsets for the OffsetsForLeaderEpochResponse + val offsetsReply = Map( + t1p0 -> newOffsetForLeaderPartitionResult(t1p0, 5, 52), + t1p1 -> newOffsetForLeaderPartitionResult(t1p1, 5, 49) + ).asJava + + //Create the fetcher thread + val mockNetwork = new ReplicaFetcherMockBlockingSend(offsetsReply, brokerEndPoint, new SystemTime()) + val thread = new ReplicaFetcherThread("bob", 0, brokerEndPoint, config, failedPartitions, replicaManager, new Metrics(), + new SystemTime(), quota, Some(mockNetwork)) + + //When + thread.addPartitions(Map(t1p0 -> initialFetchState(Some(topicId1), 0L), t1p1 -> initialFetchState(Some(topicId1), 0L))) + + //When the epoch request is outstanding, remove one of the partitions to simulate a leader change. We do this via a callback passed to the mock thread + val partitionThatBecameLeader = t1p0 + mockNetwork.setEpochRequestCallback(() => { + thread.removePartitions(Set(partitionThatBecameLeader)) + }) + + //When + thread.doWork() + + //Then we should not have truncated the partition that became leader. Exactly one partition should be truncated. + assertEquals(49, truncateToCapture.getValue) + } + + @Test + def shouldCatchExceptionFromBlockingSendWhenShuttingDownReplicaFetcherThread(): Unit = { + val props = TestUtils.createBrokerConfig(1, "localhost:1234") + val config = KafkaConfig.fromProps(props) + val mockBlockingSend: BlockingSend = createMock(classOf[BlockingSend]) + + expect(mockBlockingSend.initiateClose()).andThrow(new IllegalArgumentException()).once() + expect(mockBlockingSend.close()).andThrow(new IllegalStateException()).once() + val replicaManager: ReplicaManager = mock(classOf[ReplicaManager]) + expect(replicaManager.brokerTopicStats).andReturn(mock(classOf[BrokerTopicStats])) + replay(mockBlockingSend, replicaManager) + + val thread = new ReplicaFetcherThread( + name = "bob", + fetcherId = 0, + sourceBroker = brokerEndPoint, + brokerConfig = config, + failedPartitions = failedPartitions, + replicaMgr = replicaManager, + metrics = new Metrics(), + time = new SystemTime(), + quota = null, + leaderEndpointBlockingSend = Some(mockBlockingSend)) + thread.start() + + // Verify that: + // 1) IllegalArgumentException thrown by BlockingSend#initiateClose() during `initiateShutdown` is not propagated + // 2) BlockingSend.close() is invoked even if BlockingSend#initiateClose() fails + // 3) IllegalStateException thrown by BlockingSend.close() during `awaitShutdown` is not propagated + thread.initiateShutdown() + thread.awaitShutdown() + verify(mockBlockingSend) + } + + @Test + def shouldUpdateReassignmentBytesInMetrics(): Unit = { + assertProcessPartitionDataWhen(isReassigning = true) + } + + @Test + def shouldNotUpdateReassignmentBytesInMetricsWhenNoReassignmentsInProgress(): Unit = { + assertProcessPartitionDataWhen(isReassigning = false) + } + + @Test + def testBuildFetch(): Unit = { + val tid1p0 = new TopicIdPartition(topicId1, t1p0) + val tid1p1 = new TopicIdPartition(topicId1, t1p1) + val tid2p1 = new TopicIdPartition(topicId2, t2p1) + + val props = TestUtils.createBrokerConfig(1, "localhost:1234") + val 
config = KafkaConfig.fromProps(props) + val replicaManager: ReplicaManager = mock(classOf[ReplicaManager]) + val mockBlockingSend: BlockingSend = createMock(classOf[BlockingSend]) + val replicaQuota: ReplicaQuota = createNiceMock(classOf[ReplicaQuota]) + val log: UnifiedLog = createNiceMock(classOf[UnifiedLog]) + + expect(replicaManager.brokerTopicStats).andReturn(mock(classOf[BrokerTopicStats])) + expect(replicaManager.localLogOrException(anyObject(classOf[TopicPartition]))).andReturn(log).anyTimes() + expect(replicaQuota.isThrottled(anyObject(classOf[TopicPartition]))).andReturn(false).anyTimes() + expect(log.logStartOffset).andReturn(0).anyTimes() + replay(log, replicaQuota, replicaManager) + + val thread = new ReplicaFetcherThread("bob", 0, brokerEndPoint, config, failedPartitions, + replicaManager, new Metrics(), new SystemTime(), replicaQuota, Some(mockBlockingSend)) + + val leaderEpoch = 1 + + val partitionMap = Map( + t1p0 -> PartitionFetchState(Some(topicId1), 150, None, leaderEpoch, None, state = Fetching, lastFetchedEpoch = None), + t1p1 -> PartitionFetchState(Some(topicId1), 155, None, leaderEpoch, None, state = Fetching, lastFetchedEpoch = None), + t2p1 -> PartitionFetchState(Some(topicId2), 160, None, leaderEpoch, None, state = Fetching, lastFetchedEpoch = None)) + + val ResultWithPartitions(fetchRequestOpt, _) = thread.buildFetch(partitionMap) + + assertTrue(fetchRequestOpt.isDefined) + val fetchRequestBuilder = fetchRequestOpt.get.fetchRequest + + val partitionDataMap = partitionMap.map { case (tp, state) => + (tp, new FetchRequest.PartitionData(state.topicId.get, state.fetchOffset, 0L, + config.replicaFetchMaxBytes, Optional.of(state.currentLeaderEpoch), Optional.empty())) + } + + assertEquals(partitionDataMap.asJava, fetchRequestBuilder.fetchData()) + assertEquals(0, fetchRequestBuilder.replaced().size) + assertEquals(0, fetchRequestBuilder.removed().size) + + val responseData = new util.LinkedHashMap[TopicIdPartition, FetchResponseData.PartitionData] + responseData.put(tid1p0, new FetchResponseData.PartitionData()) + responseData.put(tid1p1, new FetchResponseData.PartitionData()) + responseData.put(tid2p1, new FetchResponseData.PartitionData()) + val fetchResponse = FetchResponse.of(Errors.NONE, 0, 123, responseData) + + thread.fetchSessionHandler.handleResponse(fetchResponse, ApiKeys.FETCH.latestVersion()) + + // Remove t1p0, change the ID for t2p1, and keep t1p1 the same + val newTopicId = Uuid.randomUuid() + val partitionMap2 = Map( + t1p1 -> PartitionFetchState(Some(topicId1), 155, None, leaderEpoch, None, state = Fetching, lastFetchedEpoch = None), + t2p1 -> PartitionFetchState(Some(newTopicId), 160, None, leaderEpoch, None, state = Fetching, lastFetchedEpoch = None)) + val ResultWithPartitions(fetchRequestOpt2, _) = thread.buildFetch(partitionMap2) + + // Since t1p1 didn't change, we drop that one + val partitionDataMap2 = partitionMap2.drop(1).map { case (tp, state) => + (tp, new FetchRequest.PartitionData(state.topicId.get, state.fetchOffset, 0L, + config.replicaFetchMaxBytes, Optional.of(state.currentLeaderEpoch), Optional.empty())) + } + + assertTrue(fetchRequestOpt2.isDefined) + val fetchRequestBuilder2 = fetchRequestOpt2.get.fetchRequest + assertEquals(partitionDataMap2.asJava, fetchRequestBuilder2.fetchData()) + assertEquals(Collections.singletonList(tid2p1), fetchRequestBuilder2.replaced()) + assertEquals(Collections.singletonList(tid1p0), fetchRequestBuilder2.removed()) + } + + private def newOffsetForLeaderPartitionResult( + tp: TopicPartition, + 
leaderEpoch: Int, + endOffset: Long + ): EpochEndOffset = { + newOffsetForLeaderPartitionResult(tp, Errors.NONE, leaderEpoch, endOffset) + } + + private def newOffsetForLeaderPartitionResult( + tp: TopicPartition, + error: Errors, + leaderEpoch: Int, + endOffset: Long + ): EpochEndOffset = { + new EpochEndOffset() + .setPartition(tp.partition) + .setErrorCode(error.code) + .setLeaderEpoch(leaderEpoch) + .setEndOffset(endOffset) + } + + private def assertProcessPartitionDataWhen(isReassigning: Boolean): Unit = { + val props = TestUtils.createBrokerConfig(1, "localhost:1234") + val config = KafkaConfig.fromProps(props) + + val mockBlockingSend: BlockingSend = createNiceMock(classOf[BlockingSend]) + + val log: UnifiedLog = createNiceMock(classOf[UnifiedLog]) + + val partition: Partition = createNiceMock(classOf[Partition]) + expect(partition.localLogOrException).andReturn(log) + expect(partition.isReassigning).andReturn(isReassigning) + expect(partition.isAddingLocalReplica).andReturn(isReassigning) + + val replicaManager: ReplicaManager = createNiceMock(classOf[ReplicaManager]) + expect(replicaManager.getPartitionOrException(anyObject[TopicPartition])).andReturn(partition) + val brokerTopicStats = new BrokerTopicStats + expect(replicaManager.brokerTopicStats).andReturn(brokerTopicStats).anyTimes() + + val replicaQuota: ReplicaQuota = createNiceMock(classOf[ReplicaQuota]) + replay(mockBlockingSend, replicaManager, partition, log, replicaQuota) + + val thread = new ReplicaFetcherThread( + name = "bob", + fetcherId = 0, + sourceBroker = brokerEndPoint, + brokerConfig = config, + failedPartitions = failedPartitions, + replicaMgr = replicaManager, + metrics = new Metrics(), + time = new SystemTime(), + quota = replicaQuota, + leaderEndpointBlockingSend = Some(mockBlockingSend)) + + val records = MemoryRecords.withRecords(CompressionType.NONE, + new SimpleRecord(1000, "foo".getBytes(StandardCharsets.UTF_8))) + val partitionData: thread.FetchData = new FetchResponseData.PartitionData() + .setPartitionIndex(t1p0.partition) + .setLastStableOffset(0) + .setLogStartOffset(0) + .setRecords(records) + thread.processPartitionData(t1p0, 0, partitionData) + + if (isReassigning) + assertEquals(records.sizeInBytes(), brokerTopicStats.allTopicsStats.reassignmentBytesInPerSec.get.count()) + else + assertEquals(0, brokerTopicStats.allTopicsStats.reassignmentBytesInPerSec.get.count()) + + assertEquals(records.sizeInBytes(), brokerTopicStats.allTopicsStats.replicationBytesInRate.get.count()) + } + + def stub(partition: Partition, replicaManager: ReplicaManager, log: UnifiedLog): Unit = { + expect(replicaManager.localLogOrException(t1p0)).andReturn(log).anyTimes() + expect(replicaManager.getPartitionOrException(t1p0)).andReturn(partition).anyTimes() + expect(replicaManager.localLogOrException(t1p1)).andReturn(log).anyTimes() + expect(replicaManager.getPartitionOrException(t1p1)).andReturn(partition).anyTimes() + expect(replicaManager.localLogOrException(t2p1)).andReturn(log).anyTimes() + expect(replicaManager.getPartitionOrException(t2p1)).andReturn(partition).anyTimes() + } + + private def kafkaConfigNoTruncateOnFetch: KafkaConfig = { + val props = TestUtils.createBrokerConfig(1, "localhost:1234") + props.setProperty(KafkaConfig.InterBrokerProtocolVersionProp, KAFKA_2_6_IV0.version) + KafkaConfig.fromProps(props) + } +} diff --git a/core/src/test/scala/unit/kafka/server/ReplicaManagerConcurrencyTest.scala b/core/src/test/scala/unit/kafka/server/ReplicaManagerConcurrencyTest.scala new file mode 100644 index 
0000000..f0003f4 --- /dev/null +++ b/core/src/test/scala/unit/kafka/server/ReplicaManagerConcurrencyTest.scala @@ -0,0 +1,456 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package kafka.server + +import java.net.InetAddress +import java.util +import java.util.concurrent.{CompletableFuture, Executors, LinkedBlockingQueue, TimeUnit} +import java.util.{Optional, Properties} + +import kafka.api.LeaderAndIsr +import kafka.log.{AppendOrigin, LogConfig} +import kafka.server.metadata.MockConfigRepository +import kafka.utils.TestUtils.waitUntilTrue +import kafka.utils.{MockTime, ShutdownableThread, TestUtils} +import org.apache.kafka.common.metadata.{PartitionChangeRecord, PartitionRecord, TopicRecord} +import org.apache.kafka.common.metrics.Metrics +import org.apache.kafka.common.protocol.Errors +import org.apache.kafka.common.record.SimpleRecord +import org.apache.kafka.common.replica.ClientMetadata.DefaultClientMetadata +import org.apache.kafka.common.requests.{FetchRequest, ProduceResponse} +import org.apache.kafka.common.security.auth.KafkaPrincipal +import org.apache.kafka.common.utils.Time +import org.apache.kafka.common.{IsolationLevel, TopicIdPartition, TopicPartition, Uuid} +import org.apache.kafka.image.{MetadataDelta, MetadataImage} +import org.apache.kafka.metadata.PartitionRegistration +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.{AfterEach, Test} +import org.mockito.Mockito + +import scala.collection.mutable +import scala.jdk.CollectionConverters._ +import scala.util.Random + +class ReplicaManagerConcurrencyTest { + + private val time = new MockTime() + private val metrics = new Metrics() + private val executor = Executors.newScheduledThreadPool(8) + private val tasks = mutable.Buffer.empty[ShutdownableThread] + + private def submit(task: ShutdownableThread): Unit = { + tasks += task + executor.submit(task) + } + + @AfterEach + def cleanup(): Unit = { + tasks.foreach(_.shutdown()) + executor.shutdownNow() + executor.awaitTermination(5, TimeUnit.SECONDS) + metrics.close() + } + + @Test + def testIsrExpandAndShrinkWithConcurrentProduce(): Unit = { + val localId = 0 + val remoteId = 1 + val channel = new ControllerChannel + val replicaManager = buildReplicaManager(localId, channel) + + // Start with the remote replica out of the ISR + val initialPartitionRegistration = registration( + replicaIds = Seq(localId, remoteId), + isr = Seq(localId), + leader = localId + ) + + val topicModel = new TopicModel(Uuid.randomUuid(), "foo", Map(0 -> initialPartitionRegistration)) + val topicPartition = new TopicPartition(topicModel.name, 0) + val topicIdPartition = new TopicIdPartition(topicModel.topicId, topicPartition) + val controller = new ControllerModel(topicModel, channel, replicaManager) + + submit(new 
Clock(time)) + replicaManager.startup() + + submit(controller) + controller.initialize() + + waitUntilTrue(() => { + replicaManager.getPartition(topicPartition) match { + case HostedPartition.Online(partition) => partition.isLeader + case _ => false + } + }, "Timed out waiting for partition to initialize") + + val partition = replicaManager.getPartitionOrException(topicPartition) + + // Start several producers which are actively writing to the partition + (0 to 2).foreach { i => + submit(new ProducerModel( + clientId = s"producer-$i", + topicPartition, + replicaManager + )) + } + + // Start the remote replica fetcher and wait for it to join the ISR + val fetcher = new FetcherModel( + clientId = s"replica-$remoteId", + replicaId = remoteId, + topicIdPartition, + replicaManager + ) + + submit(fetcher) + waitUntilTrue(() => { + partition.inSyncReplicaIds == Set(localId, remoteId) + }, "Test timed out before ISR was expanded") + + // Stop the fetcher so that the replica is removed from the ISR + fetcher.shutdown() + waitUntilTrue(() => { + partition.inSyncReplicaIds == Set(localId) + }, "Test timed out before ISR was shrunk") + } + + private class Clock( + time: MockTime + ) extends ShutdownableThread(name = "clock", isInterruptible = false) { + override def doWork(): Unit = { + time.sleep(1) + } + } + + private def buildReplicaManager( + localId: Int, + channel: ControllerChannel + ): ReplicaManager = { + val logDir = TestUtils.tempDir() + + val props = new Properties + props.put(KafkaConfig.QuorumVotersProp, "100@localhost:12345") + props.put(KafkaConfig.ProcessRolesProp, "broker") + props.put(KafkaConfig.NodeIdProp, localId.toString) + props.put(KafkaConfig.ControllerListenerNamesProp, "SSL") + props.put(KafkaConfig.LogDirProp, logDir.getAbsolutePath) + props.put(KafkaConfig.ReplicaLagTimeMaxMsProp, 5000.toString) + + val config = new KafkaConfig(props, doLog = false) + + val logManager = TestUtils.createLogManager( + defaultConfig = new LogConfig(new Properties), + configRepository = new MockConfigRepository, + logDirs = Seq(logDir), + time = time + ) + + new ReplicaManager( + metrics = metrics, + config = config, + time = time, + scheduler = time.scheduler, + logManager = logManager, + quotaManagers = QuotaFactory.instantiate(config, metrics, time, ""), + metadataCache = MetadataCache.kRaftMetadataCache(config.brokerId), + logDirFailureChannel = new LogDirFailureChannel(config.logDirs.size), + alterIsrManager = new MockAlterIsrManager(channel) + ) { + override def createReplicaFetcherManager( + metrics: Metrics, + time: Time, + threadNamePrefix: Option[String], + quotaManager: ReplicationQuotaManager + ): ReplicaFetcherManager = { + Mockito.mock(classOf[ReplicaFetcherManager]) + } + } + } + + private class FetcherModel( + clientId: String, + replicaId: Int, + topicIdPartition: TopicIdPartition, + replicaManager: ReplicaManager + ) extends ShutdownableThread(name = clientId, isInterruptible = false) { + private val random = new Random() + + private val clientMetadata = new DefaultClientMetadata( + "", + clientId, + InetAddress.getLocalHost, + KafkaPrincipal.ANONYMOUS, + "PLAINTEXT" + ) + + private var fetchOffset = 0L + + override def doWork(): Unit = { + val partitionData = new FetchRequest.PartitionData( + topicIdPartition.topicId, + fetchOffset, + -1, + 65536, + Optional.empty(), + Optional.empty() + ) + + val future = new CompletableFuture[FetchPartitionData]() + def fetchCallback(results: collection.Seq[(TopicIdPartition, FetchPartitionData)]): Unit = { + try { + assertEquals(1, 
results.size) + val (topicIdPartition, result) = results.head + assertEquals(this.topicIdPartition, topicIdPartition) + assertEquals(Errors.NONE, result.error) + future.complete(result) + } catch { + case e: Throwable => future.completeExceptionally(e) + } + } + + replicaManager.fetchMessages( + timeout = random.nextInt(100), + replicaId = replicaId, + fetchMinBytes = 1, + fetchMaxBytes = 1024 * 1024, + hardMaxBytesLimit = false, + fetchInfos = Seq(topicIdPartition -> partitionData), + quota = QuotaFactory.UnboundedQuota, + responseCallback = fetchCallback, + isolationLevel = IsolationLevel.READ_UNCOMMITTED, + clientMetadata = Some(clientMetadata) + ) + + val fetchResult = future.get() + fetchResult.records.batches.forEach { batch => + fetchOffset = batch.lastOffset + 1 + } + } + } + + private class ProducerModel( + clientId: String, + topicPartition: TopicPartition, + replicaManager: ReplicaManager + ) extends ShutdownableThread(name = clientId, isInterruptible = false) { + private val random = new Random() + private var sequence = 0 + + override def doWork(): Unit = { + val numRecords = (random.nextInt() % 10) + 1 + + val records = (0 until numRecords).map { i => + new SimpleRecord(s"$clientId-${sequence + i}".getBytes) + } + + val future = new CompletableFuture[ProduceResponse.PartitionResponse]() + def produceCallback(results: collection.Map[TopicPartition, ProduceResponse.PartitionResponse]): Unit = { + try { + assertEquals(1, results.size) + val (topicPartition, result) = results.head + assertEquals(this.topicPartition, topicPartition) + assertEquals(Errors.NONE, result.error) + future.complete(result) + } catch { + case e: Throwable => future.completeExceptionally(e) + } + } + + replicaManager.appendRecords( + timeout = 30000, + requiredAcks = (-1).toShort, + internalTopicsAllowed = false, + origin = AppendOrigin.Client, + entriesPerPartition = collection.Map(topicPartition -> TestUtils.records(records)), + responseCallback = produceCallback + ) + + future.get() + sequence += numRecords + } + } + + sealed trait ControllerEvent + case object InitializeEvent extends ControllerEvent + case object ShutdownEvent extends ControllerEvent + case class AlterIsrEvent( + future: CompletableFuture[LeaderAndIsr], + topicPartition: TopicPartition, + leaderAndIsr: LeaderAndIsr + ) extends ControllerEvent + + private class ControllerChannel { + private val eventQueue = new LinkedBlockingQueue[ControllerEvent]() + + def poll(): ControllerEvent = { + eventQueue.take() + } + + def alterIsr( + topicPartition: TopicPartition, + leaderAndIsr: LeaderAndIsr + ): CompletableFuture[LeaderAndIsr] = { + val future = new CompletableFuture[LeaderAndIsr]() + eventQueue.offer(AlterIsrEvent(future, topicPartition, leaderAndIsr)) + future + } + + def initialize(): Unit = { + eventQueue.offer(InitializeEvent) + } + + def shutdown(): Unit = { + eventQueue.offer(ShutdownEvent) + } + } + + private class ControllerModel( + topic: TopicModel, + channel: ControllerChannel, + replicaManager: ReplicaManager + ) extends ShutdownableThread(name = "controller", isInterruptible = false) { + private var latestImage = MetadataImage.EMPTY + + def initialize(): Unit = { + channel.initialize() + } + + override def shutdown(): Unit = { + super.initiateShutdown() + channel.shutdown() + super.awaitShutdown() + } + + override def doWork(): Unit = { + channel.poll() match { + case InitializeEvent => + val delta = new MetadataDelta(latestImage) + topic.initialize(delta) + latestImage = delta.apply() + 
replicaManager.applyDelta(delta.topicsDelta, latestImage) + + case AlterIsrEvent(future, topicPartition, leaderAndIsr) => + val delta = new MetadataDelta(latestImage) + val updatedLeaderAndIsr = topic.alterIsr(topicPartition, leaderAndIsr, delta) + latestImage = delta.apply() + future.complete(updatedLeaderAndIsr) + replicaManager.applyDelta(delta.topicsDelta, latestImage) + + case ShutdownEvent => + } + } + } + + private class TopicModel( + val topicId: Uuid, + val name: String, + initialRegistrations: Map[Int, PartitionRegistration] + ) { + private val partitions: Map[Int, PartitionModel] = initialRegistrations.map { + case (partitionId, registration) => + partitionId -> new PartitionModel(this, partitionId, registration) + } + + def initialize(delta: MetadataDelta): Unit = { + delta.replay(new TopicRecord() + .setName(name) + .setTopicId(topicId) + ) + partitions.values.foreach(_.initialize(delta)) + } + + def alterIsr( + topicPartition: TopicPartition, + leaderAndIsr: LeaderAndIsr, + delta: MetadataDelta + ): LeaderAndIsr = { + val partitionModel = partitions.getOrElse(topicPartition.partition, + throw new IllegalStateException(s"Unexpected partition $topicPartition") + ) + partitionModel.alterIsr(leaderAndIsr, delta) + } + } + + private class PartitionModel( + val topic: TopicModel, + val partitionId: Int, + var registration: PartitionRegistration + ) { + def alterIsr( + leaderAndIsr: LeaderAndIsr, + delta: MetadataDelta + ): LeaderAndIsr = { + delta.replay(new PartitionChangeRecord() + .setTopicId(topic.topicId) + .setPartitionId(partitionId) + .setIsr(leaderAndIsr.isr.map(Int.box).asJava) + .setLeader(leaderAndIsr.leader) + ) + this.registration = delta.topicsDelta + .changedTopic(topic.topicId) + .partitionChanges + .get(partitionId) + + leaderAndIsr.withZkVersion(registration.partitionEpoch) + } + + private def toList(ints: Array[Int]): util.List[Integer] = { + ints.map(Int.box).toList.asJava + } + + def initialize(delta: MetadataDelta): Unit = { + delta.replay(new PartitionRecord() + .setTopicId(topic.topicId) + .setPartitionId(partitionId) + .setReplicas(toList(registration.replicas)) + .setIsr(toList(registration.isr)) + .setLeader(registration.leader) + .setLeaderEpoch(registration.leaderEpoch) + .setPartitionEpoch(registration.partitionEpoch) + ) + } + } + + private class MockAlterIsrManager(channel: ControllerChannel) extends AlterIsrManager { + override def submit( + topicPartition: TopicPartition, + leaderAndIsr: LeaderAndIsr, + controllerEpoch: Int + ): CompletableFuture[LeaderAndIsr] = { + channel.alterIsr(topicPartition, leaderAndIsr) + } + } + + private def registration( + replicaIds: Seq[Int], + isr: Seq[Int], + leader: Int, + leaderEpoch: Int = 0, + version: Int = 0 + ): PartitionRegistration = { + new PartitionRegistration( + replicaIds.toArray, + isr.toArray, + Array.empty[Int], + Array.empty[Int], + leader, + leaderEpoch, + version + ) + } + +} diff --git a/core/src/test/scala/unit/kafka/server/ReplicaManagerQuotasTest.scala b/core/src/test/scala/unit/kafka/server/ReplicaManagerQuotasTest.scala new file mode 100644 index 0000000..8be7810 --- /dev/null +++ b/core/src/test/scala/unit/kafka/server/ReplicaManagerQuotasTest.scala @@ -0,0 +1,287 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package kafka.server + +import java.io.File +import java.util.{Collections, Optional, Properties} + +import kafka.cluster.Partition +import kafka.log.{UnifiedLog, LogManager, LogOffsetSnapshot} +import kafka.utils._ +import org.apache.kafka.common.{TopicPartition, TopicIdPartition, Uuid} +import org.apache.kafka.common.metrics.Metrics +import org.apache.kafka.common.record.{CompressionType, MemoryRecords, SimpleRecord} +import org.apache.kafka.common.requests.FetchRequest.PartitionData +import org.easymock.EasyMock +import EasyMock._ +import kafka.server.QuotaFactory.QuotaManagers +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.{AfterEach, Test} + +import scala.jdk.CollectionConverters._ + +class ReplicaManagerQuotasTest { + val configs = TestUtils.createBrokerConfigs(2, TestUtils.MockZkConnect).map(KafkaConfig.fromProps(_, new Properties())) + val time = new MockTime + val metrics = new Metrics + val record = new SimpleRecord("some-data-in-a-message".getBytes()) + val topicPartition1 = new TopicPartition("test-topic", 1) + val topicPartition2 = new TopicPartition("test-topic", 2) + val topicId = Uuid.randomUuid() + val topicIds = Collections.singletonMap("test-topic", topicId) + val topicIdPartition1 = new TopicIdPartition(topicId, topicPartition1) + val topicIdPartition2 = new TopicIdPartition(topicId, topicPartition2) + val fetchInfo = Seq( + topicIdPartition1 -> new PartitionData(Uuid.ZERO_UUID, 0, 0, 100, Optional.empty()), + topicIdPartition2 -> new PartitionData(Uuid.ZERO_UUID, 0, 0, 100, Optional.empty())) + var quotaManager: QuotaManagers = _ + var replicaManager: ReplicaManager = _ + + @Test + def shouldExcludeSubsequentThrottledPartitions(): Unit = { + setUpMocks(fetchInfo) + val followerReplicaId = configs.last.brokerId + + val quota = mockQuota(1000000) + expect(quota.isQuotaExceeded).andReturn(false).once() + expect(quota.isQuotaExceeded).andReturn(true).once() + replay(quota) + + val fetch = replicaManager.readFromLocalLog( + replicaId = followerReplicaId, + fetchOnlyFromLeader = true, + fetchIsolation = FetchHighWatermark, + fetchMaxBytes = Int.MaxValue, + hardMaxBytesLimit = false, + readPartitionInfo = fetchInfo, + quota = quota, + clientMetadata = None) + assertEquals(1, fetch.find(_._1 == topicIdPartition1).get._2.info.records.batches.asScala.size, + "Given two partitions, with only one throttled, we should get the first") + + assertEquals(0, fetch.find(_._1 == topicIdPartition2).get._2.info.records.batches.asScala.size, + "But we shouldn't get the second") + } + + @Test + def shouldGetNoMessagesIfQuotasExceededOnSubsequentPartitions(): Unit = { + setUpMocks(fetchInfo) + val followerReplicaId = configs.last.brokerId + + val quota = mockQuota(1000000) + expect(quota.isQuotaExceeded).andReturn(true).once() + expect(quota.isQuotaExceeded).andReturn(true).once() + replay(quota) + + val fetch = replicaManager.readFromLocalLog( + replicaId = followerReplicaId, + 
fetchOnlyFromLeader = true, + fetchIsolation = FetchHighWatermark, + fetchMaxBytes = Int.MaxValue, + hardMaxBytesLimit = false, + readPartitionInfo = fetchInfo, + quota = quota, + clientMetadata = None) + assertEquals(0, fetch.find(_._1 == topicIdPartition1).get._2.info.records.batches.asScala.size, + "Given two partitions, with both throttled, we should get no messages") + assertEquals(0, fetch.find(_._1 == topicIdPartition2).get._2.info.records.batches.asScala.size, + "Given two partitions, with both throttled, we should get no messages") + } + + @Test + def shouldGetBothMessagesIfQuotasAllow(): Unit = { + setUpMocks(fetchInfo) + val followerReplicaId = configs.last.brokerId + + val quota = mockQuota(1000000) + expect(quota.isQuotaExceeded).andReturn(false).once() + expect(quota.isQuotaExceeded).andReturn(false).once() + replay(quota) + + val fetch = replicaManager.readFromLocalLog( + replicaId = followerReplicaId, + fetchOnlyFromLeader = true, + fetchIsolation = FetchHighWatermark, + fetchMaxBytes = Int.MaxValue, + hardMaxBytesLimit = false, + readPartitionInfo = fetchInfo, + quota = quota, + clientMetadata = None) + assertEquals(1, fetch.find(_._1 == topicIdPartition1).get._2.info.records.batches.asScala.size, + "Given two partitions, with both non-throttled, we should get both messages") + assertEquals(1, fetch.find(_._1 == topicIdPartition2).get._2.info.records.batches.asScala.size, + "Given two partitions, with both non-throttled, we should get both messages") + } + + @Test + def shouldIncludeInSyncThrottledReplicas(): Unit = { + setUpMocks(fetchInfo, bothReplicasInSync = true) + val followerReplicaId = configs.last.brokerId + + val quota = mockQuota(1000000) + expect(quota.isQuotaExceeded).andReturn(false).once() + expect(quota.isQuotaExceeded).andReturn(true).once() + replay(quota) + + val fetch = replicaManager.readFromLocalLog( + replicaId = followerReplicaId, + fetchOnlyFromLeader = true, + fetchIsolation = FetchHighWatermark, + fetchMaxBytes = Int.MaxValue, + hardMaxBytesLimit = false, + readPartitionInfo = fetchInfo, + quota = quota, + clientMetadata = None) + assertEquals(1, fetch.find(_._1 == topicIdPartition1).get._2.info.records.batches.asScala.size, + "Given two partitions, with only one throttled, we should get the first") + + assertEquals(1, fetch.find(_._1 == topicIdPartition2).get._2.info.records.batches.asScala.size, + "But we should get the second too since it's throttled but in sync") + } + + @Test + def testCompleteInDelayedFetchWithReplicaThrottling(): Unit = { + // Set up DelayedFetch where there is data to return to a follower replica, either in-sync or out of sync + def setupDelayedFetch(isReplicaInSync: Boolean): DelayedFetch = { + val endOffsetMetadata = LogOffsetMetadata(messageOffset = 100L, segmentBaseOffset = 0L, relativePositionInSegment = 500) + val partition: Partition = EasyMock.createMock(classOf[Partition]) + + val offsetSnapshot = LogOffsetSnapshot( + logStartOffset = 0L, + logEndOffset = endOffsetMetadata, + highWatermark = endOffsetMetadata, + lastStableOffset = endOffsetMetadata) + EasyMock.expect(partition.fetchOffsetSnapshot(Optional.empty(), fetchOnlyFromLeader = true)) + .andReturn(offsetSnapshot) + + val replicaManager: ReplicaManager = EasyMock.createMock(classOf[ReplicaManager]) + EasyMock.expect(replicaManager.getPartitionOrException(EasyMock.anyObject[TopicPartition])) + .andReturn(partition).anyTimes() + + EasyMock.expect(replicaManager.shouldLeaderThrottle(EasyMock.anyObject[ReplicaQuota], EasyMock.anyObject[Partition], 
EasyMock.anyObject[Int])) + .andReturn(!isReplicaInSync).anyTimes() + EasyMock.expect(partition.getReplica(1)).andReturn(None) + EasyMock.replay(replicaManager, partition) + + val tp = new TopicIdPartition(Uuid.randomUuid(), new TopicPartition("t1", 0)) + val fetchPartitionStatus = FetchPartitionStatus(LogOffsetMetadata(messageOffset = 50L, segmentBaseOffset = 0L, + relativePositionInSegment = 250), new PartitionData(Uuid.ZERO_UUID, 50, 0, 1, Optional.empty())) + val fetchMetadata = FetchMetadata(fetchMinBytes = 1, + fetchMaxBytes = 1000, + hardMaxBytesLimit = true, + fetchOnlyLeader = true, + fetchIsolation = FetchLogEnd, + isFromFollower = true, + replicaId = 1, + fetchPartitionStatus = List((tp, fetchPartitionStatus)) + ) + new DelayedFetch(delayMs = 600, fetchMetadata = fetchMetadata, replicaManager = replicaManager, + quota = null, clientMetadata = None, responseCallback = null) { + override def forceComplete(): Boolean = true + } + } + + assertTrue(setupDelayedFetch(isReplicaInSync = true).tryComplete(), "In sync replica should complete") + assertFalse(setupDelayedFetch(isReplicaInSync = false).tryComplete(), "Out of sync replica should not complete") + } + + def setUpMocks(fetchInfo: Seq[(TopicIdPartition, PartitionData)], record: SimpleRecord = this.record, + bothReplicasInSync: Boolean = false): Unit = { + val scheduler: KafkaScheduler = createNiceMock(classOf[KafkaScheduler]) + + //Create log which handles both a regular read and a 0 bytes read + val log: UnifiedLog = createNiceMock(classOf[UnifiedLog]) + expect(log.logStartOffset).andReturn(0L).anyTimes() + expect(log.logEndOffset).andReturn(20L).anyTimes() + expect(log.highWatermark).andReturn(5).anyTimes() + expect(log.lastStableOffset).andReturn(5).anyTimes() + expect(log.logEndOffsetMetadata).andReturn(LogOffsetMetadata(20L)).anyTimes() + expect(log.topicId).andReturn(Some(topicId)).anyTimes() + + //if we ask for len 1 return a message + expect(log.read(anyObject(), + maxLength = geq(1), + isolation = anyObject(), + minOneMessage = anyBoolean())).andReturn( + FetchDataInfo( + LogOffsetMetadata(0L, 0L, 0), + MemoryRecords.withRecords(CompressionType.NONE, record) + )).anyTimes() + + //if we ask for len = 0, return 0 messages + expect(log.read(anyObject(), + maxLength = EasyMock.eq(0), + isolation = anyObject(), + minOneMessage = anyBoolean())).andReturn( + FetchDataInfo( + LogOffsetMetadata(0L, 0L, 0), + MemoryRecords.EMPTY + )).anyTimes() + replay(log) + + //Create log manager + val logManager: LogManager = createMock(classOf[LogManager]) + + //Return the same log for each partition as it doesn't matter + expect(logManager.getLog(anyObject(), anyBoolean())).andReturn(Some(log)).anyTimes() + expect(logManager.liveLogDirs).andReturn(Array.empty[File]).anyTimes() + replay(logManager) + + val alterIsrManager: AlterIsrManager = createMock(classOf[AlterIsrManager]) + + val leaderBrokerId = configs.head.brokerId + quotaManager = QuotaFactory.instantiate(configs.head, metrics, time, "") + replicaManager = new ReplicaManager( + metrics = metrics, + config = configs.head, + time = time, + scheduler = scheduler, + logManager = logManager, + quotaManagers = quotaManager, + metadataCache = MetadataCache.zkMetadataCache(leaderBrokerId), + logDirFailureChannel = new LogDirFailureChannel(configs.head.logDirs.size), + alterIsrManager = alterIsrManager) + + //create the two replicas + for ((p, _) <- fetchInfo) { + val partition = replicaManager.createPartition(p.topicPartition) + log.updateHighWatermark(5) + partition.leaderReplicaIdOpt = 
Some(leaderBrokerId) + partition.setLog(log, isFutureLog = false) + + partition.updateAssignmentAndIsr( + assignment = Seq(leaderBrokerId, configs.last.brokerId), + isr = if (bothReplicasInSync) Set(leaderBrokerId, configs.last.brokerId) else Set(leaderBrokerId), + addingReplicas = Seq.empty, + removingReplicas = Seq.empty + ) + } + } + + @AfterEach + def tearDown(): Unit = { + Option(replicaManager).foreach(_.shutdown(false)) + Option(quotaManager).foreach(_.shutdown()) + metrics.close() + } + + def mockQuota(bound: Long): ReplicaQuota = { + val quota: ReplicaQuota = createMock(classOf[ReplicaQuota]) + expect(quota.isThrottled(anyObject())).andReturn(true).anyTimes() + quota + } +} diff --git a/core/src/test/scala/unit/kafka/server/ReplicaManagerTest.scala b/core/src/test/scala/unit/kafka/server/ReplicaManagerTest.scala new file mode 100644 index 0000000..2b1a397 --- /dev/null +++ b/core/src/test/scala/unit/kafka/server/ReplicaManagerTest.scala @@ -0,0 +1,3585 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package kafka.server + +import java.io.File +import java.net.InetAddress +import java.nio.file.Files +import java.util +import java.util.concurrent.atomic.AtomicReference +import java.util.concurrent.{CountDownLatch, TimeUnit} +import java.util.stream.IntStream +import java.util.{Collections, Optional, Properties} +import kafka.api._ +import kafka.cluster.{BrokerEndPoint, Partition} +import kafka.log._ +import kafka.server.QuotaFactory.{QuotaManagers, UnboundedQuota} +import kafka.server.checkpoints.{LazyOffsetCheckpoints, OffsetCheckpointFile} +import kafka.server.epoch.util.ReplicaFetcherMockBlockingSend +import kafka.utils.timer.MockTimer +import kafka.utils.{MockScheduler, MockTime, TestUtils} +import org.apache.kafka.common.message.FetchResponseData +import org.apache.kafka.common.message.LeaderAndIsrRequestData +import org.apache.kafka.common.message.LeaderAndIsrRequestData.LeaderAndIsrPartitionState +import org.apache.kafka.common.message.OffsetForLeaderEpochResponseData.EpochEndOffset +import org.apache.kafka.common.message.StopReplicaRequestData.StopReplicaPartitionState +import org.apache.kafka.common.metadata.{PartitionChangeRecord, PartitionRecord, RemoveTopicRecord, TopicRecord} +import org.apache.kafka.common.metrics.Metrics +import org.apache.kafka.common.network.ListenerName +import org.apache.kafka.common.protocol.{ApiKeys, Errors} +import org.apache.kafka.common.record._ +import org.apache.kafka.common.replica.ClientMetadata +import org.apache.kafka.common.replica.ClientMetadata.DefaultClientMetadata +import org.apache.kafka.common.requests.FetchRequest.PartitionData +import org.apache.kafka.common.requests.ProduceResponse.PartitionResponse +import org.apache.kafka.common.requests._ +import org.apache.kafka.common.security.auth.KafkaPrincipal +import org.apache.kafka.common.utils.{Time, Utils} +import org.apache.kafka.common.{IsolationLevel, Node, TopicPartition, TopicIdPartition, Uuid} +import org.apache.kafka.image.{ClientQuotasImage, ClusterImageTest, ConfigurationsImage, FeaturesImage, MetadataImage, TopicsDelta, TopicsImage} +import org.apache.kafka.raft.{OffsetAndEpoch => RaftOffsetAndEpoch} +import org.easymock.EasyMock +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.{AfterEach, BeforeEach, Test} +import org.junit.jupiter.params.ParameterizedTest +import org.junit.jupiter.params.provider.ValueSource +import org.mockito.invocation.InvocationOnMock +import org.mockito.stubbing.Answer +import org.mockito.{ArgumentMatchers, Mockito} + +import scala.collection.{Map, Seq, mutable} +import scala.jdk.CollectionConverters._ + +class ReplicaManagerTest { + + val topic = "test-topic" + val topicId = Uuid.randomUuid() + val topicIds = scala.Predef.Map("test-topic" -> topicId) + val topicNames = scala.Predef.Map(topicId -> "test-topic") + val time = new MockTime + val scheduler = new MockScheduler(time) + val metrics = new Metrics + var alterIsrManager: AlterIsrManager = _ + var config: KafkaConfig = _ + var quotaManager: QuotaManagers = _ + + // Constants defined for readability + val zkVersion = 0 + val correlationId = 0 + var controllerEpoch = 0 + val brokerEpoch = 0L + + @BeforeEach + def setUp(): Unit = { + val props = TestUtils.createBrokerConfig(1, TestUtils.MockZkConnect) + config = KafkaConfig.fromProps(props) + alterIsrManager = EasyMock.createMock(classOf[AlterIsrManager]) + quotaManager = QuotaFactory.instantiate(config, metrics, time, "") + } + + @AfterEach + def tearDown(): Unit = { + TestUtils.clearYammerMetrics() + 
Option(quotaManager).foreach(_.shutdown()) + metrics.close() + } + + @Test + def testHighWaterMarkDirectoryMapping(): Unit = { + val mockLogMgr = TestUtils.createLogManager(config.logDirs.map(new File(_))) + val rm = new ReplicaManager( + metrics = metrics, + config = config, + time = time, + scheduler = new MockScheduler(time), + logManager = mockLogMgr, + quotaManagers = quotaManager, + metadataCache = MetadataCache.zkMetadataCache(config.brokerId), + logDirFailureChannel = new LogDirFailureChannel(config.logDirs.size), + alterIsrManager = alterIsrManager) + try { + val partition = rm.createPartition(new TopicPartition(topic, 1)) + partition.createLogIfNotExists(isNew = false, isFutureReplica = false, + new LazyOffsetCheckpoints(rm.highWatermarkCheckpoints), None) + rm.checkpointHighWatermarks() + } finally { + // shutdown the replica manager upon test completion + rm.shutdown(false) + } + } + + @Test + def testHighwaterMarkRelativeDirectoryMapping(): Unit = { + val props = TestUtils.createBrokerConfig(1, TestUtils.MockZkConnect) + props.put("log.dir", TestUtils.tempRelativeDir("data").getAbsolutePath) + val config = KafkaConfig.fromProps(props) + val mockLogMgr = TestUtils.createLogManager(config.logDirs.map(new File(_))) + val rm = new ReplicaManager( + metrics = metrics, + config = config, + time = time, + scheduler = new MockScheduler(time), + logManager = mockLogMgr, + quotaManagers = quotaManager, + metadataCache = MetadataCache.zkMetadataCache(config.brokerId), + logDirFailureChannel = new LogDirFailureChannel(config.logDirs.size), + alterIsrManager = alterIsrManager) + try { + val partition = rm.createPartition(new TopicPartition(topic, 1)) + partition.createLogIfNotExists(isNew = false, isFutureReplica = false, + new LazyOffsetCheckpoints(rm.highWatermarkCheckpoints), None) + rm.checkpointHighWatermarks() + } finally { + // shutdown the replica manager upon test completion + rm.shutdown(checkpointHW = false) + } + } + + @Test + def testIllegalRequiredAcks(): Unit = { + val mockLogMgr = TestUtils.createLogManager(config.logDirs.map(new File(_))) + val rm = new ReplicaManager( + metrics = metrics, + config = config, + time = time, + scheduler = new MockScheduler(time), + logManager = mockLogMgr, + quotaManagers = quotaManager, + metadataCache = MetadataCache.zkMetadataCache(config.brokerId), + logDirFailureChannel = new LogDirFailureChannel(config.logDirs.size), + alterIsrManager = alterIsrManager, + threadNamePrefix = Option(this.getClass.getName)) + try { + def callback(responseStatus: Map[TopicPartition, PartitionResponse]) = { + assert(responseStatus.values.head.error == Errors.INVALID_REQUIRED_ACKS) + } + rm.appendRecords( + timeout = 0, + requiredAcks = 3, + internalTopicsAllowed = false, + origin = AppendOrigin.Client, + entriesPerPartition = Map(new TopicPartition("test1", 0) -> MemoryRecords.withRecords(CompressionType.NONE, + new SimpleRecord("first message".getBytes))), + responseCallback = callback) + } finally { + rm.shutdown(checkpointHW = false) + } + + TestUtils.assertNoNonDaemonThreads(this.getClass.getName) + } + + private def mockGetAliveBrokerFunctions(cache: MetadataCache, aliveBrokers: Seq[Node]): Unit = { + Mockito.when(cache.hasAliveBroker(ArgumentMatchers.anyInt())).thenAnswer(new Answer[Boolean]() { + override def answer(invocation: InvocationOnMock): Boolean = { + aliveBrokers.map(_.id()).contains(invocation.getArguments()(0).asInstanceOf[Int]) + } + }) + Mockito.when(cache.getAliveBrokerNode(ArgumentMatchers.anyInt(), 
ArgumentMatchers.any[ListenerName])). + thenAnswer(new Answer[Option[Node]]() { + override def answer(invocation: InvocationOnMock): Option[Node] = { + aliveBrokers.find(node => node.id == invocation.getArguments()(0).asInstanceOf[Integer]) + } + }) + Mockito.when(cache.getAliveBrokerNodes(ArgumentMatchers.any[ListenerName])).thenReturn(aliveBrokers) + } + + @Test + def testClearPurgatoryOnBecomingFollower(): Unit = { + val props = TestUtils.createBrokerConfig(0, TestUtils.MockZkConnect) + props.put("log.dir", TestUtils.tempRelativeDir("data").getAbsolutePath) + val config = KafkaConfig.fromProps(props) + val logProps = new Properties() + val mockLogMgr = TestUtils.createLogManager(config.logDirs.map(new File(_)), LogConfig(logProps)) + val aliveBrokers = Seq(new Node(0, "host0", 0), new Node(1, "host1", 1)) + val metadataCache: MetadataCache = Mockito.mock(classOf[MetadataCache]) + mockGetAliveBrokerFunctions(metadataCache, aliveBrokers) + val rm = new ReplicaManager( + metrics = metrics, + config = config, + time = time, + scheduler = new MockScheduler(time), + logManager = mockLogMgr, + quotaManagers = quotaManager, + metadataCache = metadataCache, + logDirFailureChannel = new LogDirFailureChannel(config.logDirs.size), + alterIsrManager = alterIsrManager) + + try { + val brokerList = Seq[Integer](0, 1).asJava + val topicIds = Collections.singletonMap(topic, Uuid.randomUuid()) + + val partition = rm.createPartition(new TopicPartition(topic, 0)) + partition.createLogIfNotExists(isNew = false, isFutureReplica = false, + new LazyOffsetCheckpoints(rm.highWatermarkCheckpoints), None) + // Make this replica the leader. + val leaderAndIsrRequest1 = new LeaderAndIsrRequest.Builder(ApiKeys.LEADER_AND_ISR.latestVersion, 0, 0, brokerEpoch, + Seq(new LeaderAndIsrPartitionState() + .setTopicName(topic) + .setPartitionIndex(0) + .setControllerEpoch(0) + .setLeader(0) + .setLeaderEpoch(0) + .setIsr(brokerList) + .setZkVersion(0) + .setReplicas(brokerList) + .setIsNew(false)).asJava, + topicIds, + Set(new Node(0, "host1", 0), new Node(1, "host2", 1)).asJava).build() + rm.becomeLeaderOrFollower(0, leaderAndIsrRequest1, (_, _) => ()) + rm.getPartitionOrException(new TopicPartition(topic, 0)) + .localLogOrException + + val records = MemoryRecords.withRecords(CompressionType.NONE, new SimpleRecord("first message".getBytes())) + val appendResult = appendRecords(rm, new TopicPartition(topic, 0), records).onFire { response => + assertEquals(Errors.NOT_LEADER_OR_FOLLOWER, response.error) + } + + // Make this replica the follower + val leaderAndIsrRequest2 = new LeaderAndIsrRequest.Builder(ApiKeys.LEADER_AND_ISR.latestVersion, 0, 0, brokerEpoch, + Seq(new LeaderAndIsrPartitionState() + .setTopicName(topic) + .setPartitionIndex(0) + .setControllerEpoch(0) + .setLeader(1) + .setLeaderEpoch(1) + .setIsr(brokerList) + .setZkVersion(0) + .setReplicas(brokerList) + .setIsNew(false)).asJava, + topicIds, + Set(new Node(0, "host1", 0), new Node(1, "host2", 1)).asJava).build() + rm.becomeLeaderOrFollower(1, leaderAndIsrRequest2, (_, _) => ()) + + assertTrue(appendResult.isFired) + } finally { + rm.shutdown(checkpointHW = false) + } + } + + @Test + def testFencedErrorCausedByBecomeLeader(): Unit = { + testFencedErrorCausedByBecomeLeader(0) + testFencedErrorCausedByBecomeLeader(1) + testFencedErrorCausedByBecomeLeader(10) + } + + private[this] def testFencedErrorCausedByBecomeLeader(loopEpochChange: Int): Unit = { + val replicaManager = setupReplicaManagerWithMockedPurgatories(new MockTimer(time)) + try { + val brokerList = 
Seq[Integer](0, 1).asJava + val topicPartition = new TopicPartition(topic, 0) + replicaManager.createPartition(topicPartition) + .createLogIfNotExists(isNew = false, isFutureReplica = false, + new LazyOffsetCheckpoints(replicaManager.highWatermarkCheckpoints), None) + + def leaderAndIsrRequest(epoch: Int): LeaderAndIsrRequest = new LeaderAndIsrRequest.Builder(ApiKeys.LEADER_AND_ISR.latestVersion, 0, 0, brokerEpoch, + Seq(new LeaderAndIsrPartitionState() + .setTopicName(topic) + .setPartitionIndex(0) + .setControllerEpoch(0) + .setLeader(0) + .setLeaderEpoch(epoch) + .setIsr(brokerList) + .setZkVersion(0) + .setReplicas(brokerList) + .setIsNew(true)).asJava, + topicIds.asJava, + Set(new Node(0, "host1", 0), new Node(1, "host2", 1)).asJava).build() + + replicaManager.becomeLeaderOrFollower(0, leaderAndIsrRequest(0), (_, _) => ()) + val partition = replicaManager.getPartitionOrException(new TopicPartition(topic, 0)) + assertEquals(1, replicaManager.logManager.liveLogDirs.filterNot(_ == partition.log.get.dir.getParentFile).size) + + val previousReplicaFolder = partition.log.get.dir.getParentFile + // find the live and different folder + val newReplicaFolder = replicaManager.logManager.liveLogDirs.filterNot(_ == partition.log.get.dir.getParentFile).head + assertEquals(0, replicaManager.replicaAlterLogDirsManager.fetcherThreadMap.size) + replicaManager.alterReplicaLogDirs(Map(topicPartition -> newReplicaFolder.getAbsolutePath)) + // make sure the future log is created + replicaManager.futureLocalLogOrException(topicPartition) + assertEquals(1, replicaManager.replicaAlterLogDirsManager.fetcherThreadMap.size) + (1 to loopEpochChange).foreach(epoch => replicaManager.becomeLeaderOrFollower(0, leaderAndIsrRequest(epoch), (_, _) => ())) + // wait for the ReplicaAlterLogDirsThread to complete + TestUtils.waitUntilTrue(() => { + replicaManager.replicaAlterLogDirsManager.shutdownIdleFetcherThreads() + replicaManager.replicaAlterLogDirsManager.fetcherThreadMap.isEmpty + }, s"ReplicaAlterLogDirsThread should be gone") + + // the fenced error should be recoverable + assertEquals(0, replicaManager.replicaAlterLogDirsManager.failedPartitions.size) + // the replica change is completed after retrying + assertTrue(partition.futureLog.isEmpty) + assertEquals(newReplicaFolder.getAbsolutePath, partition.log.get.dir.getParent) + // change the replica folder again + val response = replicaManager.alterReplicaLogDirs(Map(topicPartition -> previousReplicaFolder.getAbsolutePath)) + assertNotEquals(0, response.size) + response.values.foreach(assertEquals(Errors.NONE, _)) + // should succeed to invoke ReplicaAlterLogDirsThread again + assertEquals(1, replicaManager.replicaAlterLogDirsManager.fetcherThreadMap.size) + } finally replicaManager.shutdown(checkpointHW = false) + } + + @Test + def testReceiveOutOfOrderSequenceExceptionWithLogStartOffset(): Unit = { + val timer = new MockTimer(time) + val replicaManager = setupReplicaManagerWithMockedPurgatories(timer) + + try { + val brokerList = Seq[Integer](0, 1).asJava + + val partition = replicaManager.createPartition(new TopicPartition(topic, 0)) + partition.createLogIfNotExists(isNew = false, isFutureReplica = false, + new LazyOffsetCheckpoints(replicaManager.highWatermarkCheckpoints), None) + + // Make this replica the leader. 
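+ // The LeaderAndIsrRequest built below elects broker 0 as leader at epoch 0 with both replicas (0 and 1) in the ISR.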
+ val leaderAndIsrRequest1 = new LeaderAndIsrRequest.Builder(ApiKeys.LEADER_AND_ISR.latestVersion, 0, 0, brokerEpoch, + Seq(new LeaderAndIsrPartitionState() + .setTopicName(topic) + .setPartitionIndex(0) + .setControllerEpoch(0) + .setLeader(0) + .setLeaderEpoch(0) + .setIsr(brokerList) + .setZkVersion(0) + .setReplicas(brokerList) + .setIsNew(true)).asJava, + Collections.singletonMap(topic, Uuid.randomUuid()), + Set(new Node(0, "host1", 0), new Node(1, "host2", 1)).asJava).build() + replicaManager.becomeLeaderOrFollower(0, leaderAndIsrRequest1, (_, _) => ()) + replicaManager.getPartitionOrException(new TopicPartition(topic, 0)) + .localLogOrException + + val producerId = 234L + val epoch = 5.toShort + + // write a few batches as part of a transaction + val numRecords = 3 + for (sequence <- 0 until numRecords) { + val records = MemoryRecords.withIdempotentRecords(CompressionType.NONE, producerId, epoch, sequence, + new SimpleRecord(s"message $sequence".getBytes)) + appendRecords(replicaManager, new TopicPartition(topic, 0), records).onFire { response => + assertEquals(Errors.NONE, response.error) + } + } + + assertEquals(0, partition.logStartOffset) + + // Append a record with an out of range sequence. We should get the OutOfOrderSequence error code with the log + // start offset set. + val outOfRangeSequence = numRecords + 10 + val record = MemoryRecords.withIdempotentRecords(CompressionType.NONE, producerId, epoch, outOfRangeSequence, + new SimpleRecord(s"message: $outOfRangeSequence".getBytes)) + appendRecords(replicaManager, new TopicPartition(topic, 0), record).onFire { response => + assertEquals(Errors.OUT_OF_ORDER_SEQUENCE_NUMBER, response.error) + assertEquals(0, response.logStartOffset) + } + + } finally { + replicaManager.shutdown(checkpointHW = false) + } + + } + + @Test + def testReadCommittedFetchLimitedAtLSO(): Unit = { + val timer = new MockTimer(time) + val replicaManager = setupReplicaManagerWithMockedPurgatories(timer) + + try { + val brokerList = Seq[Integer](0, 1).asJava + + val partition = replicaManager.createPartition(new TopicPartition(topic, 0)) + partition.createLogIfNotExists(isNew = false, isFutureReplica = false, + new LazyOffsetCheckpoints(replicaManager.highWatermarkCheckpoints), None) + + // Make this replica the leader. 
+ val leaderAndIsrRequest1 = new LeaderAndIsrRequest.Builder(ApiKeys.LEADER_AND_ISR.latestVersion, 0, 0, brokerEpoch, + Seq(new LeaderAndIsrPartitionState() + .setTopicName(topic) + .setPartitionIndex(0) + .setControllerEpoch(0) + .setLeader(0) + .setLeaderEpoch(0) + .setIsr(brokerList) + .setZkVersion(0) + .setReplicas(brokerList) + .setIsNew(true)).asJava, + topicIds.asJava, + Set(new Node(0, "host1", 0), new Node(1, "host2", 1)).asJava).build() + replicaManager.becomeLeaderOrFollower(0, leaderAndIsrRequest1, (_, _) => ()) + replicaManager.getPartitionOrException(new TopicPartition(topic, 0)) + .localLogOrException + + val producerId = 234L + val epoch = 5.toShort + + // write a few batches as part of a transaction + val numRecords = 3 + for (sequence <- 0 until numRecords) { + val records = MemoryRecords.withTransactionalRecords(CompressionType.NONE, producerId, epoch, sequence, + new SimpleRecord(s"message $sequence".getBytes)) + appendRecords(replicaManager, new TopicPartition(topic, 0), records).onFire { response => + assertEquals(Errors.NONE, response.error) + } + } + + // fetch as follower to advance the high watermark + fetchAsFollower(replicaManager, new TopicIdPartition(topicId, new TopicPartition(topic, 0)), + new PartitionData(Uuid.ZERO_UUID, numRecords, 0, 100000, Optional.empty()), + isolationLevel = IsolationLevel.READ_UNCOMMITTED) + + // fetch should return empty since LSO should be stuck at 0 + var consumerFetchResult = fetchAsConsumer(replicaManager, new TopicIdPartition(topicId, new TopicPartition(topic, 0)), + new PartitionData(Uuid.ZERO_UUID, 0, 0, 100000, Optional.empty()), + isolationLevel = IsolationLevel.READ_COMMITTED) + var fetchData = consumerFetchResult.assertFired + assertEquals(Errors.NONE, fetchData.error) + assertTrue(fetchData.records.batches.asScala.isEmpty) + assertEquals(Some(0), fetchData.lastStableOffset) + assertEquals(Some(List.empty[FetchResponseData.AbortedTransaction]), fetchData.abortedTransactions) + + // delayed fetch should timeout and return nothing + consumerFetchResult = fetchAsConsumer(replicaManager, new TopicIdPartition(topicId, new TopicPartition(topic, 0)), + new PartitionData(Uuid.ZERO_UUID, 0, 0, 100000, Optional.empty()), + isolationLevel = IsolationLevel.READ_COMMITTED, minBytes = 1000) + assertFalse(consumerFetchResult.isFired) + timer.advanceClock(1001) + + fetchData = consumerFetchResult.assertFired + assertEquals(Errors.NONE, fetchData.error) + assertTrue(fetchData.records.batches.asScala.isEmpty) + assertEquals(Some(0), fetchData.lastStableOffset) + assertEquals(Some(List.empty[FetchResponseData.AbortedTransaction]), fetchData.abortedTransactions) + + // now commit the transaction + val endTxnMarker = new EndTransactionMarker(ControlRecordType.COMMIT, 0) + val commitRecordBatch = MemoryRecords.withEndTransactionMarker(producerId, epoch, endTxnMarker) + appendRecords(replicaManager, new TopicPartition(topic, 0), commitRecordBatch, + origin = AppendOrigin.Coordinator) + .onFire { response => assertEquals(Errors.NONE, response.error) } + + // the LSO has advanced, but the appended commit marker has not been replicated, so + // none of the data from the transaction should be visible yet + consumerFetchResult = fetchAsConsumer(replicaManager, new TopicIdPartition(topicId, new TopicPartition(topic, 0)), + new PartitionData(Uuid.ZERO_UUID, 0, 0, 100000, Optional.empty()), + isolationLevel = IsolationLevel.READ_COMMITTED) + + fetchData = consumerFetchResult.assertFired + assertEquals(Errors.NONE, fetchData.error) + 
assertTrue(fetchData.records.batches.asScala.isEmpty) + + // fetch as follower to advance the high watermark + fetchAsFollower(replicaManager, new TopicIdPartition(topicId, new TopicPartition(topic, 0)), + new PartitionData(Uuid.ZERO_UUID, numRecords + 1, 0, 100000, Optional.empty()), + isolationLevel = IsolationLevel.READ_UNCOMMITTED) + + // now all of the records should be fetchable + consumerFetchResult = fetchAsConsumer(replicaManager, new TopicIdPartition(topicId, new TopicPartition(topic, 0)), + new PartitionData(Uuid.ZERO_UUID, 0, 0, 100000, Optional.empty()), + isolationLevel = IsolationLevel.READ_COMMITTED) + + fetchData = consumerFetchResult.assertFired + assertEquals(Errors.NONE, fetchData.error) + assertEquals(Some(numRecords + 1), fetchData.lastStableOffset) + assertEquals(Some(List.empty[FetchResponseData.AbortedTransaction]), fetchData.abortedTransactions) + assertEquals(numRecords + 1, fetchData.records.batches.asScala.size) + } finally { + replicaManager.shutdown(checkpointHW = false) + } + } + + @Test + def testDelayedFetchIncludesAbortedTransactions(): Unit = { + val timer = new MockTimer(time) + val replicaManager = setupReplicaManagerWithMockedPurgatories(timer) + + try { + val brokerList = Seq[Integer](0, 1).asJava + val partition = replicaManager.createPartition(new TopicPartition(topic, 0)) + partition.createLogIfNotExists(isNew = false, isFutureReplica = false, + new LazyOffsetCheckpoints(replicaManager.highWatermarkCheckpoints), None) + + // Make this replica the leader. + val leaderAndIsrRequest1 = new LeaderAndIsrRequest.Builder(ApiKeys.LEADER_AND_ISR.latestVersion, 0, 0, brokerEpoch, + Seq(new LeaderAndIsrPartitionState() + .setTopicName(topic) + .setPartitionIndex(0) + .setControllerEpoch(0) + .setLeader(0) + .setLeaderEpoch(0) + .setIsr(brokerList) + .setZkVersion(0) + .setReplicas(brokerList) + .setIsNew(true)).asJava, + topicIds.asJava, + Set(new Node(0, "host1", 0), new Node(1, "host2", 1)).asJava).build() + replicaManager.becomeLeaderOrFollower(0, leaderAndIsrRequest1, (_, _) => ()) + replicaManager.getPartitionOrException(new TopicPartition(topic, 0)) + .localLogOrException + + val producerId = 234L + val epoch = 5.toShort + + // write a few batches as part of a transaction + val numRecords = 3 + for (sequence <- 0 until numRecords) { + val records = MemoryRecords.withTransactionalRecords(CompressionType.NONE, producerId, epoch, sequence, + new SimpleRecord(s"message $sequence".getBytes)) + appendRecords(replicaManager, new TopicPartition(topic, 0), records).onFire { response => + assertEquals(Errors.NONE, response.error) + } + } + + // now abort the transaction + val endTxnMarker = new EndTransactionMarker(ControlRecordType.ABORT, 0) + val abortRecordBatch = MemoryRecords.withEndTransactionMarker(producerId, epoch, endTxnMarker) + appendRecords(replicaManager, new TopicPartition(topic, 0), abortRecordBatch, + origin = AppendOrigin.Coordinator) + .onFire { response => assertEquals(Errors.NONE, response.error) } + + // fetch as follower to advance the high watermark + fetchAsFollower(replicaManager, new TopicIdPartition(topicId, new TopicPartition(topic, 0)), + new PartitionData(Uuid.ZERO_UUID, numRecords + 1, 0, 100000, Optional.empty()), + isolationLevel = IsolationLevel.READ_UNCOMMITTED) + + // Set the minBytes in order to force this request to enter purgatory. When it returns, we should still + // see the newly aborted transaction.
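+ // The appended records total far less than minBytes = 10000, so the fetch below parks in purgatory and only completes once the mock timer is advanced past its timeout.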
+ val fetchResult = fetchAsConsumer(replicaManager, new TopicIdPartition(topicId, new TopicPartition(topic, 0)), + new PartitionData(Uuid.ZERO_UUID, 0, 0, 100000, Optional.empty()), + isolationLevel = IsolationLevel.READ_COMMITTED, minBytes = 10000) + assertFalse(fetchResult.isFired) + + timer.advanceClock(1001) + val fetchData = fetchResult.assertFired + + assertEquals(Errors.NONE, fetchData.error) + assertEquals(Some(numRecords + 1), fetchData.lastStableOffset) + assertEquals(numRecords + 1, fetchData.records.records.asScala.size) + assertTrue(fetchData.abortedTransactions.isDefined) + assertEquals(1, fetchData.abortedTransactions.get.size) + + val abortedTransaction = fetchData.abortedTransactions.get.head + assertEquals(0L, abortedTransaction.firstOffset) + assertEquals(producerId, abortedTransaction.producerId) + } finally { + replicaManager.shutdown(checkpointHW = false) + } + } + + @Test + def testFetchBeyondHighWatermark(): Unit = { + val rm = setupReplicaManagerWithMockedPurgatories(new MockTimer(time), aliveBrokerIds = Seq(0, 1, 2)) + try { + val brokerList = Seq[Integer](0, 1, 2).asJava + + val partition = rm.createPartition(new TopicPartition(topic, 0)) + partition.createLogIfNotExists(isNew = false, isFutureReplica = false, + new LazyOffsetCheckpoints(rm.highWatermarkCheckpoints), None) + + // Make this replica the leader. + val leaderAndIsrRequest1 = new LeaderAndIsrRequest.Builder(ApiKeys.LEADER_AND_ISR.latestVersion, 0, 0, brokerEpoch, + Seq(new LeaderAndIsrPartitionState() + .setTopicName(topic) + .setPartitionIndex(0) + .setControllerEpoch(0) + .setLeader(0) + .setLeaderEpoch(0) + .setIsr(brokerList) + .setZkVersion(0) + .setReplicas(brokerList) + .setIsNew(false)).asJava, + topicIds.asJava, + Set(new Node(0, "host1", 0), new Node(1, "host2", 1), new Node(2, "host2", 2)).asJava).build() + rm.becomeLeaderOrFollower(0, leaderAndIsrRequest1, (_, _) => ()) + rm.getPartitionOrException(new TopicPartition(topic, 0)) + .localLogOrException + + // Append a couple of messages. + for (i <- 1 to 2) { + val records = TestUtils.singletonRecords(s"message $i".getBytes) + appendRecords(rm, new TopicPartition(topic, 0), records).onFire { response => + assertEquals(Errors.NONE, response.error) + } + } + + // Followers are always allowed to fetch above the high watermark + val followerFetchResult = fetchAsFollower(rm, new TopicIdPartition(topicId, new TopicPartition(topic, 0)), + new PartitionData(Uuid.ZERO_UUID, 1, 0, 100000, Optional.empty())) + val followerFetchData = followerFetchResult.assertFired + assertEquals(Errors.NONE, followerFetchData.error, "Should not give an exception") + assertTrue(followerFetchData.records.batches.iterator.hasNext, "Should return some data") + + // Consumers are not allowed to consume above the high watermark. However, since the + // high watermark could be stale at the time of the request, we do not return an out of + // range error and instead return an empty record set. 
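+ // The consumer fetch below therefore completes with Errors.NONE and an empty record set rather than an OFFSET_OUT_OF_RANGE error.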
+ val consumerFetchResult = fetchAsConsumer(rm, new TopicIdPartition(topicId, new TopicPartition(topic, 0)), + new PartitionData(Uuid.ZERO_UUID, 1, 0, 100000, Optional.empty())) + val consumerFetchData = consumerFetchResult.assertFired + assertEquals(Errors.NONE, consumerFetchData.error, "Should not give an exception") + assertEquals(MemoryRecords.EMPTY, consumerFetchData.records, "Should return empty response") + } finally { + rm.shutdown(checkpointHW = false) + } + } + + @Test + def testFollowerStateNotUpdatedIfLogReadFails(): Unit = { + val maxFetchBytes = 1024 * 1024 + val aliveBrokersIds = Seq(0, 1) + val leaderEpoch = 5 + val replicaManager = setupReplicaManagerWithMockedPurgatories(new MockTimer(time), + brokerId = 0, aliveBrokersIds) + try { + val tp = new TopicPartition(topic, 0) + val tidp = new TopicIdPartition(topicId, tp) + val replicas = aliveBrokersIds.toList.map(Int.box).asJava + + // Broker 0 becomes leader of the partition + val leaderAndIsrPartitionState = new LeaderAndIsrPartitionState() + .setTopicName(topic) + .setPartitionIndex(0) + .setControllerEpoch(0) + .setLeader(0) + .setLeaderEpoch(leaderEpoch) + .setIsr(replicas) + .setZkVersion(0) + .setReplicas(replicas) + .setIsNew(true) + val leaderAndIsrRequest = new LeaderAndIsrRequest.Builder(ApiKeys.LEADER_AND_ISR.latestVersion, 0, 0, brokerEpoch, + Seq(leaderAndIsrPartitionState).asJava, + Collections.singletonMap(topic, topicId), + Set(new Node(0, "host1", 0), new Node(1, "host2", 1)).asJava).build() + val leaderAndIsrResponse = replicaManager.becomeLeaderOrFollower(0, leaderAndIsrRequest, (_, _) => ()) + assertEquals(Errors.NONE, leaderAndIsrResponse.error) + + // Follower replica state is initialized, but initial state is not known + assertTrue(replicaManager.onlinePartition(tp).isDefined) + val partition = replicaManager.onlinePartition(tp).get + + assertTrue(partition.getReplica(1).isDefined) + val followerReplica = partition.getReplica(1).get + assertEquals(-1L, followerReplica.logStartOffset) + assertEquals(-1L, followerReplica.logEndOffset) + + // Leader appends some data + for (i <- 1 to 5) { + appendRecords(replicaManager, tp, TestUtils.singletonRecords(s"message $i".getBytes)).onFire { response => + assertEquals(Errors.NONE, response.error) + } + } + + // We receive one valid request from the follower and replica state is updated + var successfulFetch: Option[FetchPartitionData] = None + def callback(response: Seq[(TopicIdPartition, FetchPartitionData)]): Unit = { + successfulFetch = response.headOption.filter(_._1 == tidp).map(_._2) + } + + val validFetchPartitionData = new FetchRequest.PartitionData(Uuid.ZERO_UUID, 0L, 0L, maxFetchBytes, + Optional.of(leaderEpoch)) + + replicaManager.fetchMessages( + timeout = 0L, + replicaId = 1, + fetchMinBytes = 1, + fetchMaxBytes = maxFetchBytes, + hardMaxBytesLimit = false, + fetchInfos = Seq(tidp -> validFetchPartitionData), + quota = UnboundedQuota, + isolationLevel = IsolationLevel.READ_UNCOMMITTED, + responseCallback = callback, + clientMetadata = None + ) + + assertTrue(successfulFetch.isDefined) + assertEquals(0L, followerReplica.logStartOffset) + assertEquals(0L, followerReplica.logEndOffset) + + + // Next we receive an invalid request with a higher fetch offset, but an old epoch. + // We expect that the replica state does not get updated. 
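+ // The fetch below asks for offset 3 but carries leaderEpoch - 1; the stale epoch is rejected, so the follower's logStartOffset and logEndOffset remain at 0.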
+ val invalidFetchPartitionData = new FetchRequest.PartitionData(Uuid.ZERO_UUID, 3L, 0L, maxFetchBytes, + Optional.of(leaderEpoch - 1)) + + replicaManager.fetchMessages( + timeout = 0L, + replicaId = 1, + fetchMinBytes = 1, + fetchMaxBytes = maxFetchBytes, + hardMaxBytesLimit = false, + fetchInfos = Seq(tidp -> invalidFetchPartitionData), + quota = UnboundedQuota, + isolationLevel = IsolationLevel.READ_UNCOMMITTED, + responseCallback = callback, + clientMetadata = None + ) + + assertTrue(successfulFetch.isDefined) + assertEquals(0L, followerReplica.logStartOffset) + assertEquals(0L, followerReplica.logEndOffset) + + // Next we receive an invalid request with a higher fetch offset, but a diverging epoch. + // We expect that the replica state does not get updated. + val divergingFetchPartitionData = new FetchRequest.PartitionData(tidp.topicId, 3L, 0L, maxFetchBytes, + Optional.of(leaderEpoch), Optional.of(leaderEpoch - 1)) + + replicaManager.fetchMessages( + timeout = 0L, + replicaId = 1, + fetchMinBytes = 1, + fetchMaxBytes = maxFetchBytes, + hardMaxBytesLimit = false, + fetchInfos = Seq(tidp -> divergingFetchPartitionData), + quota = UnboundedQuota, + isolationLevel = IsolationLevel.READ_UNCOMMITTED, + responseCallback = callback, + clientMetadata = None + ) + + assertTrue(successfulFetch.isDefined) + assertEquals(0L, followerReplica.logStartOffset) + assertEquals(0L, followerReplica.logEndOffset) + + } finally { + replicaManager.shutdown(checkpointHW = false) + } + } + + @Test + def testFetchMessagesWithInconsistentTopicId(): Unit = { + val maxFetchBytes = 1024 * 1024 + val aliveBrokersIds = Seq(0, 1) + val leaderEpoch = 5 + val replicaManager = setupReplicaManagerWithMockedPurgatories(new MockTimer(time), + brokerId = 0, aliveBrokersIds) + try { + val tp = new TopicPartition(topic, 0) + val tidp = new TopicIdPartition(topicId, tp) + val replicas = aliveBrokersIds.toList.map(Int.box).asJava + + // Broker 0 becomes leader of the partition + val leaderAndIsrPartitionState = new LeaderAndIsrPartitionState() + .setTopicName(topic) + .setPartitionIndex(0) + .setControllerEpoch(0) + .setLeader(0) + .setLeaderEpoch(leaderEpoch) + .setIsr(replicas) + .setZkVersion(0) + .setReplicas(replicas) + .setIsNew(true) + val leaderAndIsrRequest = new LeaderAndIsrRequest.Builder(ApiKeys.LEADER_AND_ISR.latestVersion, 0, 0, brokerEpoch, + Seq(leaderAndIsrPartitionState).asJava, + Collections.singletonMap(topic, topicId), + Set(new Node(0, "host1", 0), new Node(1, "host2", 1)).asJava).build() + val leaderAndIsrResponse = replicaManager.becomeLeaderOrFollower(0, leaderAndIsrRequest, (_, _) => ()) + assertEquals(Errors.NONE, leaderAndIsrResponse.error) + + assertEquals(Some(topicId), replicaManager.getPartitionOrException(tp).topicId) + + // We receive one valid request from the follower and replica state is updated + var successfulFetch: Seq[(TopicIdPartition, FetchPartitionData)] = Seq() + + val validFetchPartitionData = new FetchRequest.PartitionData(Uuid.ZERO_UUID, 0L, 0L, maxFetchBytes, + Optional.of(leaderEpoch)) + + // Fetch messages simulating a different ID than the one in the log. 
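+ // A freshly generated topic ID cannot match the ID recorded for the partition, so this fetch should return Errors.INCONSISTENT_TOPIC_ID.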
+ val inconsistentTidp = new TopicIdPartition(Uuid.randomUuid(), tidp.topicPartition) + def callback(response: Seq[(TopicIdPartition, FetchPartitionData)]): Unit = { + successfulFetch = response + } + replicaManager.fetchMessages( + timeout = 0L, + replicaId = 1, + fetchMinBytes = 1, + fetchMaxBytes = maxFetchBytes, + hardMaxBytesLimit = false, + fetchInfos = Seq(inconsistentTidp -> validFetchPartitionData), + quota = UnboundedQuota, + isolationLevel = IsolationLevel.READ_UNCOMMITTED, + responseCallback = callback, + clientMetadata = None + ) + val fetch1 = successfulFetch.headOption.filter(_._1 == inconsistentTidp).map(_._2) + assertTrue(fetch1.isDefined) + assertEquals(Errors.INCONSISTENT_TOPIC_ID, fetch1.get.error) + + // Simulate where the fetch request did not use topic IDs + // Fetch messages simulating an ID in the log. + // We should not see topic ID errors. + val zeroTidp = new TopicIdPartition(Uuid.ZERO_UUID, tidp.topicPartition) + replicaManager.fetchMessages( + timeout = 0L, + replicaId = 1, + fetchMinBytes = 1, + fetchMaxBytes = maxFetchBytes, + hardMaxBytesLimit = false, + fetchInfos = Seq(zeroTidp -> validFetchPartitionData), + quota = UnboundedQuota, + isolationLevel = IsolationLevel.READ_UNCOMMITTED, + responseCallback = callback, + clientMetadata = None + ) + val fetch2 = successfulFetch.headOption.filter(_._1 == zeroTidp).map(_._2) + assertTrue(fetch2.isDefined) + assertEquals(Errors.NONE, fetch2.get.error) + + // Next create a topic without a topic ID written in the log. + val tp2 = new TopicPartition("noIdTopic", 0) + val tidp2 = new TopicIdPartition(Uuid.randomUuid(), tp2) + + // Broker 0 becomes leader of the partition + val leaderAndIsrPartitionState2 = new LeaderAndIsrPartitionState() + .setTopicName("noIdTopic") + .setPartitionIndex(0) + .setControllerEpoch(0) + .setLeader(0) + .setLeaderEpoch(leaderEpoch) + .setIsr(replicas) + .setZkVersion(0) + .setReplicas(replicas) + .setIsNew(true) + val leaderAndIsrRequest2 = new LeaderAndIsrRequest.Builder(ApiKeys.LEADER_AND_ISR.latestVersion, 0, 0, brokerEpoch, + Seq(leaderAndIsrPartitionState2).asJava, + Collections.emptyMap(), + Set(new Node(0, "host1", 0), new Node(1, "host2", 1)).asJava).build() + val leaderAndIsrResponse2 = replicaManager.becomeLeaderOrFollower(0, leaderAndIsrRequest2, (_, _) => ()) + assertEquals(Errors.NONE, leaderAndIsrResponse2.error) + + assertEquals(None, replicaManager.getPartitionOrException(tp2).topicId) + + // Fetch messages simulating the request containing a topic ID. We should not have an error. + replicaManager.fetchMessages( + timeout = 0L, + replicaId = 1, + fetchMinBytes = 1, + fetchMaxBytes = maxFetchBytes, + hardMaxBytesLimit = false, + fetchInfos = Seq(tidp2 -> validFetchPartitionData), + quota = UnboundedQuota, + isolationLevel = IsolationLevel.READ_UNCOMMITTED, + responseCallback = callback, + clientMetadata = None + ) + val fetch3 = successfulFetch.headOption.filter(_._1 == tidp2).map(_._2) + assertTrue(fetch3.isDefined) + assertEquals(Errors.NONE, fetch3.get.error) + + // Fetch messages simulating the request not containing a topic ID. We should not have an error. 
+      val zeroTidp2 = new TopicIdPartition(Uuid.ZERO_UUID, tidp2.topicPartition)
+      replicaManager.fetchMessages(
+        timeout = 0L,
+        replicaId = 1,
+        fetchMinBytes = 1,
+        fetchMaxBytes = maxFetchBytes,
+        hardMaxBytesLimit = false,
+        fetchInfos = Seq(zeroTidp2 -> validFetchPartitionData),
+        quota = UnboundedQuota,
+        isolationLevel = IsolationLevel.READ_UNCOMMITTED,
+        responseCallback = callback,
+        clientMetadata = None
+      )
+      val fetch4 = successfulFetch.headOption.filter(_._1 == zeroTidp2).map(_._2)
+      assertTrue(fetch4.isDefined)
+      assertEquals(Errors.NONE, fetch4.get.error)
+
+    } finally {
+      replicaManager.shutdown(checkpointHW = false)
+    }
+  }
+
+  /**
+   * If a follower sends a fetch request for 2 partitions and it's no longer the follower for one of them, the other
+   * partition should not be affected.
+   */
+  @Test
+  def testFetchMessagesWhenNotFollowerForOnePartition(): Unit = {
+    val replicaManager = setupReplicaManagerWithMockedPurgatories(new MockTimer(time), aliveBrokerIds = Seq(0, 1, 2))
+
+    try {
+      // Create 2 partitions, assign replica 0 as the leader for both, and a different follower (1 and 2) for each
+      val tp0 = new TopicPartition(topic, 0)
+      val tp1 = new TopicPartition(topic, 1)
+      val topicId = Uuid.randomUuid()
+      val tidp0 = new TopicIdPartition(topicId, tp0)
+      val tidp1 = new TopicIdPartition(topicId, tp1)
+      val offsetCheckpoints = new LazyOffsetCheckpoints(replicaManager.highWatermarkCheckpoints)
+      replicaManager.createPartition(tp0).createLogIfNotExists(isNew = false, isFutureReplica = false, offsetCheckpoints, None)
+      replicaManager.createPartition(tp1).createLogIfNotExists(isNew = false, isFutureReplica = false, offsetCheckpoints, None)
+      val partition0Replicas = Seq[Integer](0, 1).asJava
+      val partition1Replicas = Seq[Integer](0, 2).asJava
+      val topicIds = Map(tp0.topic -> topicId, tp1.topic -> topicId).asJava
+      val leaderAndIsrRequest = new LeaderAndIsrRequest.Builder(ApiKeys.LEADER_AND_ISR.latestVersion, 0, 0, brokerEpoch,
+        Seq(
+          new LeaderAndIsrPartitionState()
+            .setTopicName(tp0.topic)
+            .setPartitionIndex(tp0.partition)
+            .setControllerEpoch(0)
+            .setLeader(0)
+            .setLeaderEpoch(0)
+            .setIsr(partition0Replicas)
+            .setZkVersion(0)
+            .setReplicas(partition0Replicas)
+            .setIsNew(true),
+          new LeaderAndIsrPartitionState()
+            .setTopicName(tp1.topic)
+            .setPartitionIndex(tp1.partition)
+            .setControllerEpoch(0)
+            .setLeader(0)
+            .setLeaderEpoch(0)
+            .setIsr(partition1Replicas)
+            .setZkVersion(0)
+            .setReplicas(partition1Replicas)
+            .setIsNew(true)
+        ).asJava,
+        topicIds,
+        Set(new Node(0, "host1", 0), new Node(1, "host2", 1)).asJava).build()
+      replicaManager.becomeLeaderOrFollower(0, leaderAndIsrRequest, (_, _) => ())
+
+      // Append a couple of messages.
+      for (i <- 1 to 2) {
+        appendRecords(replicaManager, tp0, TestUtils.singletonRecords(s"message $i".getBytes)).onFire { response =>
+          assertEquals(Errors.NONE, response.error)
+        }
+        appendRecords(replicaManager, tp1, TestUtils.singletonRecords(s"message $i".getBytes)).onFire { response =>
+          assertEquals(Errors.NONE, response.error)
+        }
+      }
+
+      def fetchCallback(responseStatus: Seq[(TopicIdPartition, FetchPartitionData)]) = {
+        val responseStatusMap = responseStatus.toMap
+        assertEquals(2, responseStatus.size)
+        assertEquals(Set(tidp0, tidp1), responseStatusMap.keySet)
+
+        val tp0Status = responseStatusMap.get(tidp0)
+        assertTrue(tp0Status.isDefined)
+        // the response contains the high watermark on the leader before it is updated based
+        // on this fetch request
+        assertEquals(0, tp0Status.get.highWatermark)
+        assertEquals(Some(0), tp0Status.get.lastStableOffset)
+        assertEquals(Errors.NONE, tp0Status.get.error)
+        assertTrue(tp0Status.get.records.batches.iterator.hasNext)
+
+        val tp1Status = responseStatusMap.get(tidp1)
+        assertTrue(tp1Status.isDefined)
+        assertEquals(0, tp1Status.get.highWatermark)
+        assertEquals(Some(0), tp1Status.get.lastStableOffset)
+        assertEquals(Errors.NONE, tp1Status.get.error)
+        assertFalse(tp1Status.get.records.batches.iterator.hasNext)
+      }
+
+      replicaManager.fetchMessages(
+        timeout = 1000,
+        replicaId = 1,
+        fetchMinBytes = 0,
+        fetchMaxBytes = Int.MaxValue,
+        hardMaxBytesLimit = false,
+        fetchInfos = Seq(
+          tidp0 -> new PartitionData(Uuid.ZERO_UUID, 1, 0, 100000, Optional.empty()),
+          tidp1 -> new PartitionData(Uuid.ZERO_UUID, 1, 0, 100000, Optional.empty())),
+        quota = UnboundedQuota,
+        responseCallback = fetchCallback,
+        isolationLevel = IsolationLevel.READ_UNCOMMITTED,
+        clientMetadata = None
+      )
+      val tp0Log = replicaManager.localLog(tp0)
+      assertTrue(tp0Log.isDefined)
+      assertEquals(1, tp0Log.get.highWatermark, "hw should be incremented")
+
+      val tp1Replica = replicaManager.localLog(tp1)
+      assertTrue(tp1Replica.isDefined)
+      assertEquals(0, tp1Replica.get.highWatermark, "hw should not be incremented")
+
+    } finally {
+      replicaManager.shutdown(checkpointHW = false)
+    }
+  }
+
+  @Test
+  def testBecomeFollowerWhenLeaderIsUnchangedButMissedLeaderUpdate(): Unit = {
+    verifyBecomeFollowerWhenLeaderIsUnchangedButMissedLeaderUpdate(new Properties, expectTruncation = false)
+  }
+
+  @Test
+  def testBecomeFollowerWhenLeaderIsUnchangedButMissedLeaderUpdateIbp26(): Unit = {
+    val extraProps = new Properties
+    extraProps.put(KafkaConfig.InterBrokerProtocolVersionProp, KAFKA_2_6_IV0.version)
+    verifyBecomeFollowerWhenLeaderIsUnchangedButMissedLeaderUpdate(extraProps, expectTruncation = true)
+  }
+
+  /**
+   * If a partition becomes a follower and the leader is unchanged, it should check for truncation
+   * if the epoch has increased by more than one (which suggests it has missed an update). For
+   * IBP version 2.7 onwards, we don't require this since we can truncate at any time based
+   * on diverging epochs returned in fetch responses.
+ */ + private def verifyBecomeFollowerWhenLeaderIsUnchangedButMissedLeaderUpdate(extraProps: Properties, + expectTruncation: Boolean): Unit = { + val topicPartition = 0 + val topicId = Uuid.randomUuid() + val followerBrokerId = 0 + val leaderBrokerId = 1 + val controllerId = 0 + val controllerEpoch = 0 + var leaderEpoch = 1 + val leaderEpochIncrement = 2 + val aliveBrokerIds = Seq[Integer](followerBrokerId, leaderBrokerId) + val countDownLatch = new CountDownLatch(1) + + // Prepare the mocked components for the test + val (replicaManager, mockLogMgr) = prepareReplicaManagerAndLogManager(new MockTimer(time), + topicPartition, leaderEpoch + leaderEpochIncrement, followerBrokerId, leaderBrokerId, countDownLatch, + expectTruncation = expectTruncation, localLogOffset = Some(10), extraProps = extraProps, topicId = Some(topicId)) + + try { + // Initialize partition state to follower, with leader = 1, leaderEpoch = 1 + val tp = new TopicPartition(topic, topicPartition) + val partition = replicaManager.createPartition(tp) + val offsetCheckpoints = new LazyOffsetCheckpoints(replicaManager.highWatermarkCheckpoints) + partition.createLogIfNotExists(isNew = false, isFutureReplica = false, offsetCheckpoints, None) + partition.makeFollower( + leaderAndIsrPartitionState(tp, leaderEpoch, leaderBrokerId, aliveBrokerIds), + offsetCheckpoints, + None) + + // Make local partition a follower - because epoch increased by more than 1, truncation should + // trigger even though leader does not change + leaderEpoch += leaderEpochIncrement + val leaderAndIsrRequest0 = new LeaderAndIsrRequest.Builder(ApiKeys.LEADER_AND_ISR.latestVersion, + controllerId, controllerEpoch, brokerEpoch, + Seq(leaderAndIsrPartitionState(tp, leaderEpoch, leaderBrokerId, aliveBrokerIds)).asJava, + Collections.singletonMap(topic, topicId), + Set(new Node(followerBrokerId, "host1", 0), + new Node(leaderBrokerId, "host2", 1)).asJava).build() + replicaManager.becomeLeaderOrFollower(correlationId, leaderAndIsrRequest0, + (_, followers) => assertEquals(followerBrokerId, followers.head.partitionId)) + assertTrue(countDownLatch.await(1000L, TimeUnit.MILLISECONDS)) + + // Truncation should have happened once + EasyMock.verify(mockLogMgr) + } finally { + replicaManager.shutdown() + } + + TestUtils.assertNoNonDaemonThreads(this.getClass.getName) + } + + @Test + def testReplicaSelector(): Unit = { + val topicPartition = 0 + val followerBrokerId = 0 + val leaderBrokerId = 1 + val leaderEpoch = 1 + val leaderEpochIncrement = 2 + val aliveBrokerIds = Seq[Integer](followerBrokerId, leaderBrokerId) + val countDownLatch = new CountDownLatch(1) + + // Prepare the mocked components for the test + val (replicaManager, _) = prepareReplicaManagerAndLogManager(new MockTimer(time), + topicPartition, leaderEpoch + leaderEpochIncrement, followerBrokerId, + leaderBrokerId, countDownLatch, expectTruncation = true) + + val tp = new TopicPartition(topic, topicPartition) + val partition = replicaManager.createPartition(tp) + + val offsetCheckpoints = new LazyOffsetCheckpoints(replicaManager.highWatermarkCheckpoints) + partition.createLogIfNotExists(isNew = false, isFutureReplica = false, offsetCheckpoints, None) + partition.makeLeader( + leaderAndIsrPartitionState(tp, leaderEpoch, leaderBrokerId, aliveBrokerIds), + offsetCheckpoints, + None) + + val metadata: ClientMetadata = new DefaultClientMetadata("rack-a", "client-id", + InetAddress.getByName("localhost"), KafkaPrincipal.ANONYMOUS, "default") + + // We expect to select the leader, which means we return None + val 
preferredReadReplica: Option[Int] = replicaManager.findPreferredReadReplica( + partition, metadata, Request.OrdinaryConsumerId, 1L, System.currentTimeMillis) + assertFalse(preferredReadReplica.isDefined) + } + + @Test + def testPreferredReplicaAsFollower(): Unit = { + val topicPartition = 0 + val topicId = Uuid.randomUuid() + val followerBrokerId = 0 + val leaderBrokerId = 1 + val leaderEpoch = 1 + val leaderEpochIncrement = 2 + val countDownLatch = new CountDownLatch(1) + + // Prepare the mocked components for the test + val (replicaManager, _) = prepareReplicaManagerAndLogManager(new MockTimer(time), + topicPartition, leaderEpoch + leaderEpochIncrement, followerBrokerId, + leaderBrokerId, countDownLatch, expectTruncation = true, topicId = Some(topicId)) + + try { + val brokerList = Seq[Integer](0, 1).asJava + + val tp0 = new TopicPartition(topic, 0) + val tidp0 = new TopicIdPartition(topicId, tp0) + + initializeLogAndTopicId(replicaManager, tp0, topicId) + + // Make this replica the follower + val leaderAndIsrRequest2 = new LeaderAndIsrRequest.Builder(ApiKeys.LEADER_AND_ISR.latestVersion, 0, 0, brokerEpoch, + Seq(new LeaderAndIsrPartitionState() + .setTopicName(topic) + .setPartitionIndex(0) + .setControllerEpoch(0) + .setLeader(1) + .setLeaderEpoch(1) + .setIsr(brokerList) + .setZkVersion(0) + .setReplicas(brokerList) + .setIsNew(false)).asJava, + Collections.singletonMap(topic, topicId), + Set(new Node(0, "host1", 0), new Node(1, "host2", 1)).asJava).build() + replicaManager.becomeLeaderOrFollower(1, leaderAndIsrRequest2, (_, _) => ()) + + val metadata: ClientMetadata = new DefaultClientMetadata("rack-a", "client-id", + InetAddress.getByName("localhost"), KafkaPrincipal.ANONYMOUS, "default") + + val consumerResult = fetchAsConsumer(replicaManager, tidp0, + new PartitionData(Uuid.ZERO_UUID, 0, 0, 100000, Optional.empty()), + clientMetadata = Some(metadata)) + + // Fetch from follower succeeds + assertTrue(consumerResult.isFired) + + // But only leader will compute preferred replica + assertTrue(consumerResult.assertFired.preferredReadReplica.isEmpty) + } finally { + replicaManager.shutdown() + } + + TestUtils.assertNoNonDaemonThreads(this.getClass.getName) + } + + @Test + def testPreferredReplicaAsLeader(): Unit = { + val topicPartition = 0 + val topicId = Uuid.randomUuid() + val followerBrokerId = 0 + val leaderBrokerId = 1 + val leaderEpoch = 1 + val leaderEpochIncrement = 2 + val countDownLatch = new CountDownLatch(1) + + // Prepare the mocked components for the test + val (replicaManager, _) = prepareReplicaManagerAndLogManager(new MockTimer(time), + topicPartition, leaderEpoch + leaderEpochIncrement, followerBrokerId, + leaderBrokerId, countDownLatch, expectTruncation = true, topicId = Some(topicId)) + + try { + val brokerList = Seq[Integer](0, 1).asJava + + val tp0 = new TopicPartition(topic, 0) + val tidp0 = new TopicIdPartition(topicId, tp0) + + initializeLogAndTopicId(replicaManager, tp0, topicId) + + // Make this replica the follower + val leaderAndIsrRequest2 = new LeaderAndIsrRequest.Builder(ApiKeys.LEADER_AND_ISR.latestVersion, 0, 0, brokerEpoch, + Seq(new LeaderAndIsrPartitionState() + .setTopicName(topic) + .setPartitionIndex(0) + .setControllerEpoch(0) + .setLeader(0) + .setLeaderEpoch(1) + .setIsr(brokerList) + .setZkVersion(0) + .setReplicas(brokerList) + .setIsNew(false)).asJava, + Collections.singletonMap(topic, topicId), + Set(new Node(0, "host1", 0), new Node(1, "host2", 1)).asJava).build() + replicaManager.becomeLeaderOrFollower(1, leaderAndIsrRequest2, (_, _) => 
()) + + val metadata: ClientMetadata = new DefaultClientMetadata("rack-a", "client-id", + InetAddress.getByName("localhost"), KafkaPrincipal.ANONYMOUS, "default") + + val consumerResult = fetchAsConsumer(replicaManager, tidp0, + new PartitionData(Uuid.ZERO_UUID, 0, 0, 100000, Optional.empty()), + clientMetadata = Some(metadata)) + + // Fetch from follower succeeds + assertTrue(consumerResult.isFired) + + // Returns a preferred replica (should just be the leader, which is None) + assertFalse(consumerResult.assertFired.preferredReadReplica.isDefined) + } finally { + replicaManager.shutdown() + } + + TestUtils.assertNoNonDaemonThreads(this.getClass.getName) + } + + @Test + def testFollowerFetchWithDefaultSelectorNoForcedHwPropagation(): Unit = { + val topicPartition = 0 + val followerBrokerId = 0 + val leaderBrokerId = 1 + val leaderEpoch = 1 + val leaderEpochIncrement = 2 + val countDownLatch = new CountDownLatch(1) + val timer = new MockTimer(time) + + // Prepare the mocked components for the test + val (replicaManager, _) = prepareReplicaManagerAndLogManager(timer, + topicPartition, leaderEpoch + leaderEpochIncrement, followerBrokerId, + leaderBrokerId, countDownLatch, expectTruncation = true, topicId = Some(topicId)) + + val brokerList = Seq[Integer](0, 1).asJava + + val tp0 = new TopicPartition(topic, 0) + val tidp0 = new TopicIdPartition(topicId, tp0) + + initializeLogAndTopicId(replicaManager, tp0, topicId) + + // Make this replica the follower + val leaderAndIsrRequest2 = new LeaderAndIsrRequest.Builder(ApiKeys.LEADER_AND_ISR.latestVersion, 0, 0, brokerEpoch, + Seq(new LeaderAndIsrPartitionState() + .setTopicName(topic) + .setPartitionIndex(0) + .setControllerEpoch(0) + .setLeader(0) + .setLeaderEpoch(1) + .setIsr(brokerList) + .setZkVersion(0) + .setReplicas(brokerList) + .setIsNew(false)).asJava, + Collections.singletonMap(topic, topicId), + Set(new Node(0, "host1", 0), new Node(1, "host2", 1)).asJava).build() + replicaManager.becomeLeaderOrFollower(1, leaderAndIsrRequest2, (_, _) => ()) + + val simpleRecords = Seq(new SimpleRecord("a".getBytes), new SimpleRecord("b".getBytes)) + val appendResult = appendRecords(replicaManager, tp0, + MemoryRecords.withRecords(CompressionType.NONE, simpleRecords.toSeq: _*), AppendOrigin.Client) + + // Increment the hw in the leader by fetching from the last offset + val fetchOffset = simpleRecords.size + var followerResult = fetchAsFollower(replicaManager, tidp0, + new PartitionData(Uuid.ZERO_UUID, fetchOffset, 0, 100000, Optional.empty()), + clientMetadata = None) + assertTrue(followerResult.isFired) + assertEquals(0, followerResult.assertFired.highWatermark) + + assertTrue(appendResult.isFired, "Expected producer request to be acked") + + // Fetch from the same offset, no new data is expected and hence the fetch request should + // go to the purgatory + followerResult = fetchAsFollower(replicaManager, tidp0, + new PartitionData(Uuid.ZERO_UUID, fetchOffset, 0, 100000, Optional.empty()), + clientMetadata = None, minBytes = 1000) + assertFalse(followerResult.isFired, "Request completed immediately unexpectedly") + + // Complete the request in the purgatory by advancing the clock + timer.advanceClock(1001) + assertTrue(followerResult.isFired) + + assertEquals(fetchOffset, followerResult.assertFired.highWatermark) + } + + @Test + def testUnknownReplicaSelector(): Unit = { + val topicPartition = 0 + val followerBrokerId = 0 + val leaderBrokerId = 1 + val leaderEpoch = 1 + val leaderEpochIncrement = 2 + val countDownLatch = new CountDownLatch(1) + + val 
props = new Properties() + props.put(KafkaConfig.ReplicaSelectorClassProp, "non-a-class") + assertThrows(classOf[ClassNotFoundException], () => prepareReplicaManagerAndLogManager(new MockTimer(time), + topicPartition, leaderEpoch + leaderEpochIncrement, followerBrokerId, + leaderBrokerId, countDownLatch, expectTruncation = true, extraProps = props)) + } + + // Due to some limitations to EasyMock, we need to create the log so that the Partition.topicId does not call + // LogManager.getLog with a default argument + // TODO: convert tests to using Mockito to avoid this issue. + private def initializeLogAndTopicId(replicaManager: ReplicaManager, topicPartition: TopicPartition, topicId: Uuid): Unit = { + val partition = replicaManager.createPartition(new TopicPartition(topic, 0)) + val log = replicaManager.logManager.getOrCreateLog(topicPartition, false, false, Some(topicId)) + partition.log = Some(log) + } + + @Test + def testDefaultReplicaSelector(): Unit = { + val topicPartition = 0 + val followerBrokerId = 0 + val leaderBrokerId = 1 + val leaderEpoch = 1 + val leaderEpochIncrement = 2 + val countDownLatch = new CountDownLatch(1) + + val (replicaManager, _) = prepareReplicaManagerAndLogManager(new MockTimer(time), + topicPartition, leaderEpoch + leaderEpochIncrement, followerBrokerId, + leaderBrokerId, countDownLatch, expectTruncation = true) + assertFalse(replicaManager.replicaSelectorOpt.isDefined) + } + + @Test + def testFetchFollowerNotAllowedForOlderClients(): Unit = { + val replicaManager = setupReplicaManagerWithMockedPurgatories(new MockTimer(time), aliveBrokerIds = Seq(0, 1)) + + try { + val tp0 = new TopicPartition(topic, 0) + val tidp0 = new TopicIdPartition(topicId, tp0) + val offsetCheckpoints = new LazyOffsetCheckpoints(replicaManager.highWatermarkCheckpoints) + replicaManager.createPartition(tp0).createLogIfNotExists(isNew = false, isFutureReplica = false, offsetCheckpoints, None) + val partition0Replicas = Seq[Integer](0, 1).asJava + val becomeFollowerRequest = new LeaderAndIsrRequest.Builder(ApiKeys.LEADER_AND_ISR.latestVersion, 0, 0, brokerEpoch, + Seq(new LeaderAndIsrPartitionState() + .setTopicName(tp0.topic) + .setPartitionIndex(tp0.partition) + .setControllerEpoch(0) + .setLeader(1) + .setLeaderEpoch(0) + .setIsr(partition0Replicas) + .setZkVersion(0) + .setReplicas(partition0Replicas) + .setIsNew(true)).asJava, + topicIds.asJava, + Set(new Node(0, "host1", 0), new Node(1, "host2", 1)).asJava).build() + replicaManager.becomeLeaderOrFollower(0, becomeFollowerRequest, (_, _) => ()) + + // Fetch from follower, with non-empty ClientMetadata (FetchRequest v11+) + val clientMetadata = new DefaultClientMetadata("", "", null, KafkaPrincipal.ANONYMOUS, "") + var partitionData = new FetchRequest.PartitionData(Uuid.ZERO_UUID, 0L, 0L, 100, + Optional.of(0)) + var fetchResult = sendConsumerFetch(replicaManager, tidp0, partitionData, Some(clientMetadata)) + assertNotNull(fetchResult.get) + assertEquals(Errors.NONE, fetchResult.get.error) + + // Fetch from follower, with empty ClientMetadata (which implies an older version) + partitionData = new FetchRequest.PartitionData(Uuid.ZERO_UUID, 0L, 0L, 100, + Optional.of(0)) + fetchResult = sendConsumerFetch(replicaManager, tidp0, partitionData, None) + assertNotNull(fetchResult.get) + assertEquals(Errors.NOT_LEADER_OR_FOLLOWER, fetchResult.get.error) + } finally { + replicaManager.shutdown() + } + + TestUtils.assertNoNonDaemonThreads(this.getClass.getName) + } + + @Test + def testFetchRequestRateMetrics(): Unit = { + val mockTimer = new 
MockTimer(time) + val replicaManager = setupReplicaManagerWithMockedPurgatories(mockTimer, aliveBrokerIds = Seq(0, 1)) + + val tp0 = new TopicPartition(topic, 0) + val tidp0 = new TopicIdPartition(topicId, tp0) + val offsetCheckpoints = new LazyOffsetCheckpoints(replicaManager.highWatermarkCheckpoints) + replicaManager.createPartition(tp0).createLogIfNotExists(isNew = false, isFutureReplica = false, offsetCheckpoints, None) + val partition0Replicas = Seq[Integer](0, 1).asJava + + val becomeLeaderRequest = new LeaderAndIsrRequest.Builder(ApiKeys.LEADER_AND_ISR.latestVersion, 0, 0, brokerEpoch, + Seq(new LeaderAndIsrPartitionState() + .setTopicName(tp0.topic) + .setPartitionIndex(tp0.partition) + .setControllerEpoch(0) + .setLeader(0) + .setLeaderEpoch(1) + .setIsr(partition0Replicas) + .setZkVersion(0) + .setReplicas(partition0Replicas) + .setIsNew(true)).asJava, + topicIds.asJava, + Set(new Node(0, "host1", 0), new Node(1, "host2", 1)).asJava).build() + replicaManager.becomeLeaderOrFollower(1, becomeLeaderRequest, (_, _) => ()) + + def assertMetricCount(expected: Int): Unit = { + assertEquals(expected, replicaManager.brokerTopicStats.allTopicsStats.totalFetchRequestRate.count) + assertEquals(expected, replicaManager.brokerTopicStats.topicStats(topic).totalFetchRequestRate.count) + } + + val partitionData = new FetchRequest.PartitionData(Uuid.ZERO_UUID, 0L, 0L, 100, + Optional.empty()) + + val nonPurgatoryFetchResult = sendConsumerFetch(replicaManager, tidp0, partitionData, None, timeout = 0) + assertNotNull(nonPurgatoryFetchResult.get) + assertEquals(Errors.NONE, nonPurgatoryFetchResult.get.error) + assertMetricCount(1) + + val purgatoryFetchResult = sendConsumerFetch(replicaManager, tidp0, partitionData, None, timeout = 10) + assertNull(purgatoryFetchResult.get) + mockTimer.advanceClock(11) + assertNotNull(purgatoryFetchResult.get) + assertEquals(Errors.NONE, purgatoryFetchResult.get.error) + assertMetricCount(2) + } + + @Test + def testBecomeFollowerWhileOldClientFetchInPurgatory(): Unit = { + val mockTimer = new MockTimer(time) + val replicaManager = setupReplicaManagerWithMockedPurgatories(mockTimer, aliveBrokerIds = Seq(0, 1)) + + try { + val tp0 = new TopicPartition(topic, 0) + val tidp0 = new TopicIdPartition(topicId, tp0) + val offsetCheckpoints = new LazyOffsetCheckpoints(replicaManager.highWatermarkCheckpoints) + replicaManager.createPartition(tp0).createLogIfNotExists(isNew = false, isFutureReplica = false, offsetCheckpoints, None) + val partition0Replicas = Seq[Integer](0, 1).asJava + + val becomeLeaderRequest = new LeaderAndIsrRequest.Builder(ApiKeys.LEADER_AND_ISR.latestVersion, 0, 0, brokerEpoch, + Seq(new LeaderAndIsrPartitionState() + .setTopicName(tp0.topic) + .setPartitionIndex(tp0.partition) + .setControllerEpoch(0) + .setLeader(0) + .setLeaderEpoch(1) + .setIsr(partition0Replicas) + .setZkVersion(0) + .setReplicas(partition0Replicas) + .setIsNew(true)).asJava, + topicIds.asJava, + Set(new Node(0, "host1", 0), new Node(1, "host2", 1)).asJava).build() + replicaManager.becomeLeaderOrFollower(1, becomeLeaderRequest, (_, _) => ()) + + val partitionData = new FetchRequest.PartitionData(Uuid.ZERO_UUID, 0L, 0L, 100, + Optional.empty()) + val fetchResult = sendConsumerFetch(replicaManager, tidp0, partitionData, None, timeout = 10) + assertNull(fetchResult.get) + + // Become a follower and ensure that the delayed fetch returns immediately + val becomeFollowerRequest = new LeaderAndIsrRequest.Builder(ApiKeys.LEADER_AND_ISR.latestVersion, 0, 0, brokerEpoch, + Seq(new 
LeaderAndIsrPartitionState() + .setTopicName(tp0.topic) + .setPartitionIndex(tp0.partition) + .setControllerEpoch(0) + .setLeader(1) + .setLeaderEpoch(2) + .setIsr(partition0Replicas) + .setZkVersion(0) + .setReplicas(partition0Replicas) + .setIsNew(true)).asJava, + topicIds.asJava, + Set(new Node(0, "host1", 0), new Node(1, "host2", 1)).asJava).build() + replicaManager.becomeLeaderOrFollower(0, becomeFollowerRequest, (_, _) => ()) + + assertNotNull(fetchResult.get) + assertEquals(Errors.NOT_LEADER_OR_FOLLOWER, fetchResult.get.error) + } finally { + replicaManager.shutdown() + } + + TestUtils.assertNoNonDaemonThreads(this.getClass.getName) + } + + @Test + def testBecomeFollowerWhileNewClientFetchInPurgatory(): Unit = { + val mockTimer = new MockTimer(time) + val replicaManager = setupReplicaManagerWithMockedPurgatories(mockTimer, aliveBrokerIds = Seq(0, 1)) + + try { + val tp0 = new TopicPartition(topic, 0) + val tidp0 = new TopicIdPartition(topicId, tp0) + val offsetCheckpoints = new LazyOffsetCheckpoints(replicaManager.highWatermarkCheckpoints) + replicaManager.createPartition(tp0).createLogIfNotExists(isNew = false, isFutureReplica = false, offsetCheckpoints, None) + val partition0Replicas = Seq[Integer](0, 1).asJava + + val becomeLeaderRequest = new LeaderAndIsrRequest.Builder(ApiKeys.LEADER_AND_ISR.latestVersion, 0, 0, brokerEpoch, + Seq(new LeaderAndIsrPartitionState() + .setTopicName(tp0.topic) + .setPartitionIndex(tp0.partition) + .setControllerEpoch(0) + .setLeader(0) + .setLeaderEpoch(1) + .setIsr(partition0Replicas) + .setZkVersion(0) + .setReplicas(partition0Replicas) + .setIsNew(true)).asJava, + topicIds.asJava, + Set(new Node(0, "host1", 0), new Node(1, "host2", 1)).asJava).build() + replicaManager.becomeLeaderOrFollower(1, becomeLeaderRequest, (_, _) => ()) + + val clientMetadata = new DefaultClientMetadata("", "", null, KafkaPrincipal.ANONYMOUS, "") + val partitionData = new FetchRequest.PartitionData(Uuid.ZERO_UUID, 0L, 0L, 100, + Optional.of(1)) + val fetchResult = sendConsumerFetch(replicaManager, tidp0, partitionData, Some(clientMetadata), timeout = 10) + assertNull(fetchResult.get) + + // Become a follower and ensure that the delayed fetch returns immediately + val becomeFollowerRequest = new LeaderAndIsrRequest.Builder(ApiKeys.LEADER_AND_ISR.latestVersion, 0, 0, brokerEpoch, + Seq(new LeaderAndIsrPartitionState() + .setTopicName(tp0.topic) + .setPartitionIndex(tp0.partition) + .setControllerEpoch(0) + .setLeader(1) + .setLeaderEpoch(2) + .setIsr(partition0Replicas) + .setZkVersion(0) + .setReplicas(partition0Replicas) + .setIsNew(true)).asJava, + topicIds.asJava, + Set(new Node(0, "host1", 0), new Node(1, "host2", 1)).asJava).build() + replicaManager.becomeLeaderOrFollower(0, becomeFollowerRequest, (_, _) => ()) + + assertNotNull(fetchResult.get) + assertEquals(Errors.FENCED_LEADER_EPOCH, fetchResult.get.error) + } finally { + replicaManager.shutdown() + } + + TestUtils.assertNoNonDaemonThreads(this.getClass.getName) + } + + @Test + def testFetchFromLeaderAlwaysAllowed(): Unit = { + val replicaManager = setupReplicaManagerWithMockedPurgatories(new MockTimer(time), aliveBrokerIds = Seq(0, 1)) + + val tp0 = new TopicPartition(topic, 0) + val tidp0 = new TopicIdPartition(topicId, tp0) + val offsetCheckpoints = new LazyOffsetCheckpoints(replicaManager.highWatermarkCheckpoints) + replicaManager.createPartition(tp0).createLogIfNotExists(isNew = false, isFutureReplica = false, offsetCheckpoints, None) + val partition0Replicas = Seq[Integer](0, 1).asJava + + val 
becomeLeaderRequest = new LeaderAndIsrRequest.Builder(ApiKeys.LEADER_AND_ISR.latestVersion, 0, 0, brokerEpoch, + Seq(new LeaderAndIsrPartitionState() + .setTopicName(tp0.topic) + .setPartitionIndex(tp0.partition) + .setControllerEpoch(0) + .setLeader(0) + .setLeaderEpoch(1) + .setIsr(partition0Replicas) + .setZkVersion(0) + .setReplicas(partition0Replicas) + .setIsNew(true)).asJava, + topicIds.asJava, + Set(new Node(0, "host1", 0), new Node(1, "host2", 1)).asJava).build() + replicaManager.becomeLeaderOrFollower(1, becomeLeaderRequest, (_, _) => ()) + + val clientMetadata = new DefaultClientMetadata("", "", null, KafkaPrincipal.ANONYMOUS, "") + var partitionData = new FetchRequest.PartitionData(Uuid.ZERO_UUID, 0L, 0L, 100, + Optional.of(1)) + var fetchResult = sendConsumerFetch(replicaManager, tidp0, partitionData, Some(clientMetadata)) + assertNotNull(fetchResult.get) + assertEquals(Errors.NONE, fetchResult.get.error) + + partitionData = new FetchRequest.PartitionData(Uuid.ZERO_UUID, 0L, 0L, 100, + Optional.empty()) + fetchResult = sendConsumerFetch(replicaManager, tidp0, partitionData, Some(clientMetadata)) + assertNotNull(fetchResult.get) + assertEquals(Errors.NONE, fetchResult.get.error) + } + + @Test + def testClearFetchPurgatoryOnStopReplica(): Unit = { + // As part of a reassignment, we may send StopReplica to the old leader. + // In this case, we should ensure that pending purgatory operations are cancelled + // immediately rather than sitting around to timeout. + + val mockTimer = new MockTimer(time) + val replicaManager = setupReplicaManagerWithMockedPurgatories(mockTimer, aliveBrokerIds = Seq(0, 1)) + + val tp0 = new TopicPartition(topic, 0) + val tidp0 = new TopicIdPartition(topicId, tp0) + val offsetCheckpoints = new LazyOffsetCheckpoints(replicaManager.highWatermarkCheckpoints) + replicaManager.createPartition(tp0).createLogIfNotExists(isNew = false, isFutureReplica = false, offsetCheckpoints, None) + val partition0Replicas = Seq[Integer](0, 1).asJava + + val becomeLeaderRequest = new LeaderAndIsrRequest.Builder(ApiKeys.LEADER_AND_ISR.latestVersion, 0, 0, brokerEpoch, + Seq(new LeaderAndIsrPartitionState() + .setTopicName(tp0.topic) + .setPartitionIndex(tp0.partition) + .setControllerEpoch(0) + .setLeader(0) + .setLeaderEpoch(1) + .setIsr(partition0Replicas) + .setZkVersion(0) + .setReplicas(partition0Replicas) + .setIsNew(true)).asJava, + topicIds.asJava, + Set(new Node(0, "host1", 0), new Node(1, "host2", 1)).asJava).build() + replicaManager.becomeLeaderOrFollower(1, becomeLeaderRequest, (_, _) => ()) + + val partitionData = new FetchRequest.PartitionData(Uuid.ZERO_UUID, 0L, 0L, 100, + Optional.of(1)) + val fetchResult = sendConsumerFetch(replicaManager, tidp0, partitionData, None, timeout = 10) + assertNull(fetchResult.get) + Mockito.when(replicaManager.metadataCache.contains(ArgumentMatchers.eq(tp0))).thenReturn(true) + + // We have a fetch in purgatory, now receive a stop replica request and + // assert that the fetch returns with a NOT_LEADER error + replicaManager.stopReplicas(2, 0, 0, + mutable.Map(tp0 -> new StopReplicaPartitionState() + .setPartitionIndex(tp0.partition) + .setDeletePartition(true) + .setLeaderEpoch(LeaderAndIsr.EpochDuringDelete))) + + assertNotNull(fetchResult.get) + assertEquals(Errors.NOT_LEADER_OR_FOLLOWER, fetchResult.get.error) + } + + @Test + def testClearProducePurgatoryOnStopReplica(): Unit = { + val mockTimer = new MockTimer(time) + val replicaManager = setupReplicaManagerWithMockedPurgatories(mockTimer, aliveBrokerIds = Seq(0, 1)) + + val 
tp0 = new TopicPartition(topic, 0) + val offsetCheckpoints = new LazyOffsetCheckpoints(replicaManager.highWatermarkCheckpoints) + replicaManager.createPartition(tp0).createLogIfNotExists(isNew = false, isFutureReplica = false, offsetCheckpoints, None) + val partition0Replicas = Seq[Integer](0, 1).asJava + + val becomeLeaderRequest = new LeaderAndIsrRequest.Builder(ApiKeys.LEADER_AND_ISR.latestVersion, 0, 0, brokerEpoch, + Seq(new LeaderAndIsrPartitionState() + .setTopicName(tp0.topic) + .setPartitionIndex(tp0.partition) + .setControllerEpoch(0) + .setLeader(0) + .setLeaderEpoch(1) + .setIsr(partition0Replicas) + .setZkVersion(0) + .setReplicas(partition0Replicas) + .setIsNew(true)).asJava, + topicIds.asJava, + Set(new Node(0, "host1", 0), new Node(1, "host2", 1)).asJava).build() + replicaManager.becomeLeaderOrFollower(1, becomeLeaderRequest, (_, _) => ()) + + val produceResult = sendProducerAppend(replicaManager, tp0, 3) + assertNull(produceResult.get) + + Mockito.when(replicaManager.metadataCache.contains(tp0)).thenReturn(true) + + replicaManager.stopReplicas(2, 0, 0, + mutable.Map(tp0 -> new StopReplicaPartitionState() + .setPartitionIndex(tp0.partition) + .setDeletePartition(true) + .setLeaderEpoch(LeaderAndIsr.EpochDuringDelete))) + + assertNotNull(produceResult.get) + assertEquals(Errors.NOT_LEADER_OR_FOLLOWER, produceResult.get.error) + } + + private def sendProducerAppend( + replicaManager: ReplicaManager, + topicPartition: TopicPartition, + numOfRecords: Int + ): AtomicReference[PartitionResponse] = { + val produceResult = new AtomicReference[PartitionResponse]() + def callback(response: Map[TopicPartition, PartitionResponse]): Unit = { + produceResult.set(response(topicPartition)) + } + + val records = MemoryRecords.withRecords( + CompressionType.NONE, + IntStream + .range(0, numOfRecords) + .mapToObj(i => new SimpleRecord(i.toString.getBytes)) + .toArray(Array.ofDim[SimpleRecord]): _* + ) + + replicaManager.appendRecords( + timeout = 10, + requiredAcks = -1, + internalTopicsAllowed = false, + origin = AppendOrigin.Client, + entriesPerPartition = Map(topicPartition -> records), + responseCallback = callback + ) + produceResult + } + + private def sendConsumerFetch(replicaManager: ReplicaManager, + topicIdPartition: TopicIdPartition, + partitionData: FetchRequest.PartitionData, + clientMetadataOpt: Option[ClientMetadata], + timeout: Long = 0L): AtomicReference[FetchPartitionData] = { + val fetchResult = new AtomicReference[FetchPartitionData]() + def callback(response: Seq[(TopicIdPartition, FetchPartitionData)]): Unit = { + fetchResult.set(response.toMap.apply(topicIdPartition)) + } + replicaManager.fetchMessages( + timeout = timeout, + replicaId = Request.OrdinaryConsumerId, + fetchMinBytes = 1, + fetchMaxBytes = 100, + hardMaxBytesLimit = false, + fetchInfos = Seq(topicIdPartition -> partitionData), + quota = UnboundedQuota, + isolationLevel = IsolationLevel.READ_UNCOMMITTED, + responseCallback = callback, + clientMetadata = clientMetadataOpt + ) + fetchResult + } + + /** + * This method assumes that the test using created ReplicaManager calls + * ReplicaManager.becomeLeaderOrFollower() once with LeaderAndIsrRequest containing + * 'leaderEpochInLeaderAndIsr' leader epoch for partition 'topicPartition'. 
+ */ + private def prepareReplicaManagerAndLogManager(timer: MockTimer, + topicPartition: Int, + leaderEpochInLeaderAndIsr: Int, + followerBrokerId: Int, + leaderBrokerId: Int, + countDownLatch: CountDownLatch, + expectTruncation: Boolean, + localLogOffset: Option[Long] = None, + offsetFromLeader: Long = 5, + leaderEpochFromLeader: Int = 3, + extraProps: Properties = new Properties(), + topicId: Option[Uuid] = None): (ReplicaManager, LogManager) = { + val props = TestUtils.createBrokerConfig(0, TestUtils.MockZkConnect) + props.put("log.dir", TestUtils.tempRelativeDir("data").getAbsolutePath) + props.asScala ++= extraProps.asScala + val config = KafkaConfig.fromProps(props) + val logConfig = LogConfig() + val logDir = new File(new File(config.logDirs.head), s"$topic-$topicPartition") + Files.createDirectories(logDir.toPath) + val mockScheduler = new MockScheduler(time) + val mockBrokerTopicStats = new BrokerTopicStats + val mockLogDirFailureChannel = new LogDirFailureChannel(config.logDirs.size) + val tp = new TopicPartition(topic, topicPartition) + val maxProducerIdExpirationMs = 30000 + val segments = new LogSegments(tp) + val leaderEpochCache = UnifiedLog.maybeCreateLeaderEpochCache(logDir, tp, mockLogDirFailureChannel, logConfig.recordVersion, "") + val producerStateManager = new ProducerStateManager(tp, logDir, maxProducerIdExpirationMs, time) + val offsets = LogLoader.load(LoadLogParams( + logDir, + tp, + logConfig, + mockScheduler, + time, + mockLogDirFailureChannel, + hadCleanShutdown = true, + segments, + 0L, + 0L, + maxProducerIdExpirationMs, + leaderEpochCache, + producerStateManager)) + val localLog = new LocalLog(logDir, logConfig, segments, offsets.recoveryPoint, + offsets.nextOffsetMetadata, mockScheduler, time, tp, mockLogDirFailureChannel) + val mockLog = new UnifiedLog( + logStartOffset = offsets.logStartOffset, + localLog = localLog, + brokerTopicStats = mockBrokerTopicStats, + producerIdExpirationCheckIntervalMs = 30000, + leaderEpochCache = leaderEpochCache, + producerStateManager = producerStateManager, + _topicId = topicId, + keepPartitionMetadataFile = true) { + + override def endOffsetForEpoch(leaderEpoch: Int): Option[OffsetAndEpoch] = { + assertEquals(leaderEpoch, leaderEpochFromLeader) + localLogOffset.map { logOffset => + Some(OffsetAndEpoch(logOffset, leaderEpochFromLeader)) + }.getOrElse(super.endOffsetForEpoch(leaderEpoch)) + } + + override def latestEpoch: Option[Int] = Some(leaderEpochFromLeader) + + override def logEndOffsetMetadata: LogOffsetMetadata = + localLogOffset.map(LogOffsetMetadata(_)).getOrElse(super.logEndOffsetMetadata) + + override def logEndOffset: Long = localLogOffset.getOrElse(super.logEndOffset) + } + + // Expect to call LogManager.truncateTo exactly once + val topicPartitionObj = new TopicPartition(topic, topicPartition) + val mockLogMgr: LogManager = EasyMock.createMock(classOf[LogManager]) + EasyMock.expect(mockLogMgr.liveLogDirs).andReturn(config.logDirs.map(new File(_).getAbsoluteFile)).anyTimes + EasyMock.expect(mockLogMgr.getOrCreateLog(EasyMock.eq(topicPartitionObj), + isNew = EasyMock.eq(false), isFuture = EasyMock.eq(false), EasyMock.anyObject())).andReturn(mockLog).anyTimes + if (expectTruncation) { + EasyMock.expect(mockLogMgr.truncateTo(Map(topicPartitionObj -> offsetFromLeader), + isFuture = false)).once + } + EasyMock.expect(mockLogMgr.initializingLog(topicPartitionObj)).anyTimes + EasyMock.expect(mockLogMgr.getLog(topicPartitionObj, isFuture = true)).andReturn(None) + + EasyMock.expect(mockLogMgr.finishedInitializingLog( 
+ EasyMock.eq(topicPartitionObj), EasyMock.anyObject())).anyTimes + + EasyMock.replay(mockLogMgr) + + val aliveBrokerIds = Seq[Integer](followerBrokerId, leaderBrokerId) + val aliveBrokers = aliveBrokerIds.map(brokerId => new Node(brokerId, s"host$brokerId", brokerId)) + + val metadataCache: MetadataCache = Mockito.mock(classOf[MetadataCache]) + mockGetAliveBrokerFunctions(metadataCache, aliveBrokers) + Mockito.when(metadataCache.getPartitionReplicaEndpoints( + ArgumentMatchers.any[TopicPartition], ArgumentMatchers.any[ListenerName])). + thenReturn(Map(leaderBrokerId -> new Node(leaderBrokerId, "host1", 9092, "rack-a"), + followerBrokerId -> new Node(followerBrokerId, "host2", 9092, "rack-b")).toMap) + + val mockProducePurgatory = new DelayedOperationPurgatory[DelayedProduce]( + purgatoryName = "Produce", timer, reaperEnabled = false) + val mockFetchPurgatory = new DelayedOperationPurgatory[DelayedFetch]( + purgatoryName = "Fetch", timer, reaperEnabled = false) + val mockDeleteRecordsPurgatory = new DelayedOperationPurgatory[DelayedDeleteRecords]( + purgatoryName = "DeleteRecords", timer, reaperEnabled = false) + val mockElectLeaderPurgatory = new DelayedOperationPurgatory[DelayedElectLeader]( + purgatoryName = "ElectLeader", timer, reaperEnabled = false) + + // Mock network client to show leader offset of 5 + val blockingSend = new ReplicaFetcherMockBlockingSend( + Map(topicPartitionObj -> new EpochEndOffset() + .setPartition(topicPartitionObj.partition) + .setErrorCode(Errors.NONE.code) + .setLeaderEpoch(leaderEpochFromLeader) + .setEndOffset(offsetFromLeader)).asJava, + BrokerEndPoint(1, "host1" ,1), time) + val replicaManager = new ReplicaManager( + metrics = metrics, + config = config, + time = time, + scheduler = mockScheduler, + logManager = mockLogMgr, + quotaManagers = quotaManager, + brokerTopicStats = mockBrokerTopicStats, + metadataCache = metadataCache, + logDirFailureChannel = mockLogDirFailureChannel, + alterIsrManager = alterIsrManager, + delayedProducePurgatoryParam = Some(mockProducePurgatory), + delayedFetchPurgatoryParam = Some(mockFetchPurgatory), + delayedDeleteRecordsPurgatoryParam = Some(mockDeleteRecordsPurgatory), + delayedElectLeaderPurgatoryParam = Some(mockElectLeaderPurgatory), + threadNamePrefix = Option(this.getClass.getName)) { + + override protected def createReplicaFetcherManager(metrics: Metrics, + time: Time, + threadNamePrefix: Option[String], + replicationQuotaManager: ReplicationQuotaManager): ReplicaFetcherManager = { + new ReplicaFetcherManager(config, this, metrics, time, threadNamePrefix, replicationQuotaManager) { + + override def createFetcherThread(fetcherId: Int, sourceBroker: BrokerEndPoint): ReplicaFetcherThread = { + new ReplicaFetcherThread(s"ReplicaFetcherThread-$fetcherId", fetcherId, + sourceBroker, config, failedPartitions, replicaManager, metrics, time, quotaManager.follower, Some(blockingSend)) { + + override def doWork() = { + // In case the thread starts before the partition is added by AbstractFetcherManager, + // add it here (it's a no-op if already added) + val initialOffset = InitialFetchState( + topicId = topicId, + leader = new BrokerEndPoint(0, "localhost", 9092), + initOffset = 0L, currentLeaderEpoch = leaderEpochInLeaderAndIsr) + addPartitions(Map(new TopicPartition(topic, topicPartition) -> initialOffset)) + super.doWork() + + // Shut the thread down after one iteration to avoid double-counting truncations + initiateShutdown() + countDownLatch.countDown() + } + } + } + } + } + } + + (replicaManager, mockLogMgr) + } + + 
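+  /**
+   * Builds a LeaderAndIsrPartitionState for `topic` with the given partition, leader, leader epoch
+   * and replica assignment, using the controllerEpoch and zkVersion values already in scope.
+   * Tests in this suite use it to construct LeaderAndIsrRequest payloads without repeating the
+   * builder boilerplate.
+   */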
private def leaderAndIsrPartitionState(topicPartition: TopicPartition, + leaderEpoch: Int, + leaderBrokerId: Int, + aliveBrokerIds: Seq[Integer], + isNew: Boolean = false): LeaderAndIsrPartitionState = { + new LeaderAndIsrPartitionState() + .setTopicName(topic) + .setPartitionIndex(topicPartition.partition) + .setControllerEpoch(controllerEpoch) + .setLeader(leaderBrokerId) + .setLeaderEpoch(leaderEpoch) + .setIsr(aliveBrokerIds.asJava) + .setZkVersion(zkVersion) + .setReplicas(aliveBrokerIds.asJava) + .setIsNew(isNew) + } + + private class CallbackResult[T] { + private var value: Option[T] = None + private var fun: Option[T => Unit] = None + + def assertFired: T = { + assertTrue(isFired, "Callback has not been fired") + value.get + } + + def isFired: Boolean = { + value.isDefined + } + + def fire(value: T): Unit = { + this.value = Some(value) + fun.foreach(f => f(value)) + } + + def onFire(fun: T => Unit): CallbackResult[T] = { + this.fun = Some(fun) + if (this.isFired) fire(value.get) + this + } + } + + private def appendRecords(replicaManager: ReplicaManager, + partition: TopicPartition, + records: MemoryRecords, + origin: AppendOrigin = AppendOrigin.Client, + requiredAcks: Short = -1): CallbackResult[PartitionResponse] = { + val result = new CallbackResult[PartitionResponse]() + def appendCallback(responses: Map[TopicPartition, PartitionResponse]): Unit = { + val response = responses.get(partition) + assertTrue(response.isDefined) + result.fire(response.get) + } + + replicaManager.appendRecords( + timeout = 1000, + requiredAcks = requiredAcks, + internalTopicsAllowed = false, + origin = origin, + entriesPerPartition = Map(partition -> records), + responseCallback = appendCallback) + + result + } + + private def fetchAsConsumer(replicaManager: ReplicaManager, + partition: TopicIdPartition, + partitionData: PartitionData, + minBytes: Int = 0, + isolationLevel: IsolationLevel = IsolationLevel.READ_UNCOMMITTED, + clientMetadata: Option[ClientMetadata] = None): CallbackResult[FetchPartitionData] = { + fetchMessages(replicaManager, replicaId = -1, partition, partitionData, minBytes, isolationLevel, clientMetadata) + } + + private def fetchAsFollower(replicaManager: ReplicaManager, + partition: TopicIdPartition, + partitionData: PartitionData, + minBytes: Int = 0, + isolationLevel: IsolationLevel = IsolationLevel.READ_UNCOMMITTED, + clientMetadata: Option[ClientMetadata] = None): CallbackResult[FetchPartitionData] = { + fetchMessages(replicaManager, replicaId = 1, partition, partitionData, minBytes, isolationLevel, clientMetadata) + } + + private def fetchMessages(replicaManager: ReplicaManager, + replicaId: Int, + partition: TopicIdPartition, + partitionData: PartitionData, + minBytes: Int, + isolationLevel: IsolationLevel, + clientMetadata: Option[ClientMetadata]): CallbackResult[FetchPartitionData] = { + val result = new CallbackResult[FetchPartitionData]() + def fetchCallback(responseStatus: Seq[(TopicIdPartition, FetchPartitionData)]) = { + assertEquals(1, responseStatus.size) + val (topicPartition, fetchData) = responseStatus.head + assertEquals(partition, topicPartition) + result.fire(fetchData) + } + + replicaManager.fetchMessages( + timeout = 1000, + replicaId = replicaId, + fetchMinBytes = minBytes, + fetchMaxBytes = Int.MaxValue, + hardMaxBytesLimit = false, + fetchInfos = Seq(partition -> partitionData), + quota = UnboundedQuota, + responseCallback = fetchCallback, + isolationLevel = isolationLevel, + clientMetadata = clientMetadata + ) + + result + } + + private def 
setupReplicaManagerWithMockedPurgatories( + timer: MockTimer, + brokerId: Int = 0, + aliveBrokerIds: Seq[Int] = Seq(0, 1), + propsModifier: Properties => Unit = _ => {}, + mockReplicaFetcherManager: Option[ReplicaFetcherManager] = None + ): ReplicaManager = { + val props = TestUtils.createBrokerConfig(brokerId, TestUtils.MockZkConnect) + props.put("log.dirs", TestUtils.tempRelativeDir("data").getAbsolutePath + "," + TestUtils.tempRelativeDir("data2").getAbsolutePath) + propsModifier.apply(props) + val config = KafkaConfig.fromProps(props) + val logProps = new Properties() + val mockLogMgr = TestUtils.createLogManager(config.logDirs.map(new File(_)), LogConfig(logProps)) + val aliveBrokers = aliveBrokerIds.map(brokerId => new Node(brokerId, s"host$brokerId", brokerId)) + + val metadataCache: MetadataCache = Mockito.mock(classOf[MetadataCache]) + Mockito.when(metadataCache.topicIdInfo()).thenReturn((topicIds.asJava, topicNames.asJava)) + Mockito.when(metadataCache.topicNamesToIds()).thenReturn(topicIds.asJava) + Mockito.when(metadataCache.topicIdsToNames()).thenReturn(topicNames.asJava) + mockGetAliveBrokerFunctions(metadataCache, aliveBrokers) + val mockProducePurgatory = new DelayedOperationPurgatory[DelayedProduce]( + purgatoryName = "Produce", timer, reaperEnabled = false) + val mockFetchPurgatory = new DelayedOperationPurgatory[DelayedFetch]( + purgatoryName = "Fetch", timer, reaperEnabled = false) + val mockDeleteRecordsPurgatory = new DelayedOperationPurgatory[DelayedDeleteRecords]( + purgatoryName = "DeleteRecords", timer, reaperEnabled = false) + val mockDelayedElectLeaderPurgatory = new DelayedOperationPurgatory[DelayedElectLeader]( + purgatoryName = "DelayedElectLeader", timer, reaperEnabled = false) + + new ReplicaManager( + metrics = metrics, + config = config, + time = time, + scheduler = scheduler, + logManager = mockLogMgr, + quotaManagers = quotaManager, + metadataCache = metadataCache, + logDirFailureChannel = new LogDirFailureChannel(config.logDirs.size), + alterIsrManager = alterIsrManager, + delayedProducePurgatoryParam = Some(mockProducePurgatory), + delayedFetchPurgatoryParam = Some(mockFetchPurgatory), + delayedDeleteRecordsPurgatoryParam = Some(mockDeleteRecordsPurgatory), + delayedElectLeaderPurgatoryParam = Some(mockDelayedElectLeaderPurgatory), + threadNamePrefix = Option(this.getClass.getName)) { + + override protected def createReplicaFetcherManager( + metrics: Metrics, + time: Time, + threadNamePrefix: Option[String], + quotaManager: ReplicationQuotaManager + ): ReplicaFetcherManager = { + mockReplicaFetcherManager.getOrElse { + super.createReplicaFetcherManager( + metrics, + time, + threadNamePrefix, + quotaManager + ) + } + } + } + } + + @Test + def testOldLeaderLosesMetricsWhenReassignPartitions(): Unit = { + val controllerEpoch = 0 + val leaderEpoch = 0 + val leaderEpochIncrement = 1 + val correlationId = 0 + val controllerId = 0 + val mockTopicStats1: BrokerTopicStats = EasyMock.mock(classOf[BrokerTopicStats]) + val (rm0, rm1) = prepareDifferentReplicaManagers(EasyMock.mock(classOf[BrokerTopicStats]), mockTopicStats1) + + EasyMock.expect(mockTopicStats1.removeOldLeaderMetrics(topic)).andVoid.once + EasyMock.replay(mockTopicStats1) + + try { + // make broker 0 the leader of partition 0 and + // make broker 1 the leader of partition 1 + val tp0 = new TopicPartition(topic, 0) + val tp1 = new TopicPartition(topic, 1) + val partition0Replicas = Seq[Integer](0, 1).asJava + val partition1Replicas = Seq[Integer](1, 0).asJava + val topicIds = Map(tp0.topic -> 
Uuid.randomUuid(), tp1.topic -> Uuid.randomUuid()).asJava + + val leaderAndIsrRequest1 = new LeaderAndIsrRequest.Builder(ApiKeys.LEADER_AND_ISR.latestVersion, + controllerId, 0, brokerEpoch, + Seq( + new LeaderAndIsrPartitionState() + .setTopicName(tp0.topic) + .setPartitionIndex(tp0.partition) + .setControllerEpoch(controllerEpoch) + .setLeader(0) + .setLeaderEpoch(leaderEpoch) + .setIsr(partition0Replicas) + .setZkVersion(0) + .setReplicas(partition0Replicas) + .setIsNew(true), + new LeaderAndIsrPartitionState() + .setTopicName(tp1.topic) + .setPartitionIndex(tp1.partition) + .setControllerEpoch(controllerEpoch) + .setLeader(1) + .setLeaderEpoch(leaderEpoch) + .setIsr(partition1Replicas) + .setZkVersion(0) + .setReplicas(partition1Replicas) + .setIsNew(true) + ).asJava, + topicIds, + Set(new Node(0, "host0", 0), new Node(1, "host1", 1)).asJava).build() + + rm0.becomeLeaderOrFollower(correlationId, leaderAndIsrRequest1, (_, _) => ()) + rm1.becomeLeaderOrFollower(correlationId, leaderAndIsrRequest1, (_, _) => ()) + + // make broker 0 the leader of partition 1 so broker 1 loses its leadership position + val leaderAndIsrRequest2 = new LeaderAndIsrRequest.Builder(ApiKeys.LEADER_AND_ISR.latestVersion, controllerId, + controllerEpoch, brokerEpoch, + Seq( + new LeaderAndIsrPartitionState() + .setTopicName(tp0.topic) + .setPartitionIndex(tp0.partition) + .setControllerEpoch(controllerEpoch) + .setLeader(0) + .setLeaderEpoch(leaderEpoch + leaderEpochIncrement) + .setIsr(partition0Replicas) + .setZkVersion(0) + .setReplicas(partition0Replicas) + .setIsNew(true), + new LeaderAndIsrPartitionState() + .setTopicName(tp1.topic) + .setPartitionIndex(tp1.partition) + .setControllerEpoch(controllerEpoch) + .setLeader(0) + .setLeaderEpoch(leaderEpoch + leaderEpochIncrement) + .setIsr(partition1Replicas) + .setZkVersion(0) + .setReplicas(partition1Replicas) + .setIsNew(true) + ).asJava, + topicIds, + Set(new Node(0, "host0", 0), new Node(1, "host1", 1)).asJava).build() + + rm0.becomeLeaderOrFollower(correlationId, leaderAndIsrRequest2, (_, _) => ()) + rm1.becomeLeaderOrFollower(correlationId, leaderAndIsrRequest2, (_, _) => ()) + } finally { + rm0.shutdown() + rm1.shutdown() + } + + // verify that broker 1 did remove its metrics when no longer being the leader of partition 1 + EasyMock.verify(mockTopicStats1) + } + + @Test + def testOldFollowerLosesMetricsWhenReassignPartitions(): Unit = { + val controllerEpoch = 0 + val leaderEpoch = 0 + val leaderEpochIncrement = 1 + val correlationId = 0 + val controllerId = 0 + val mockTopicStats1: BrokerTopicStats = EasyMock.mock(classOf[BrokerTopicStats]) + val (rm0, rm1) = prepareDifferentReplicaManagers(EasyMock.mock(classOf[BrokerTopicStats]), mockTopicStats1) + + EasyMock.expect(mockTopicStats1.removeOldLeaderMetrics(topic)).andVoid.once + EasyMock.expect(mockTopicStats1.removeOldFollowerMetrics(topic)).andVoid.once + EasyMock.replay(mockTopicStats1) + + try { + // make broker 0 the leader of partition 0 and + // make broker 1 the leader of partition 1 + val tp0 = new TopicPartition(topic, 0) + val tp1 = new TopicPartition(topic, 1) + val partition0Replicas = Seq[Integer](1, 0).asJava + val partition1Replicas = Seq[Integer](1, 0).asJava + val topicIds = Map(tp0.topic -> Uuid.randomUuid(), tp1.topic -> Uuid.randomUuid()).asJava + + val leaderAndIsrRequest1 = new LeaderAndIsrRequest.Builder(ApiKeys.LEADER_AND_ISR.latestVersion, + controllerId, 0, brokerEpoch, + Seq( + new LeaderAndIsrPartitionState() + .setTopicName(tp0.topic) + .setPartitionIndex(tp0.partition) + 
.setControllerEpoch(controllerEpoch) + .setLeader(1) + .setLeaderEpoch(leaderEpoch) + .setIsr(partition0Replicas) + .setZkVersion(0) + .setReplicas(partition0Replicas) + .setIsNew(true), + new LeaderAndIsrPartitionState() + .setTopicName(tp1.topic) + .setPartitionIndex(tp1.partition) + .setControllerEpoch(controllerEpoch) + .setLeader(1) + .setLeaderEpoch(leaderEpoch) + .setIsr(partition1Replicas) + .setZkVersion(0) + .setReplicas(partition1Replicas) + .setIsNew(true) + ).asJava, + topicIds, + Set(new Node(0, "host0", 0), new Node(1, "host1", 1)).asJava).build() + + rm0.becomeLeaderOrFollower(correlationId, leaderAndIsrRequest1, (_, _) => ()) + rm1.becomeLeaderOrFollower(correlationId, leaderAndIsrRequest1, (_, _) => ()) + + // make broker 0 the leader of partition 1 so broker 1 loses its leadership position + val leaderAndIsrRequest2 = new LeaderAndIsrRequest.Builder(ApiKeys.LEADER_AND_ISR.latestVersion, controllerId, + controllerEpoch, brokerEpoch, + Seq( + new LeaderAndIsrPartitionState() + .setTopicName(tp0.topic) + .setPartitionIndex(tp0.partition) + .setControllerEpoch(controllerEpoch) + .setLeader(0) + .setLeaderEpoch(leaderEpoch + leaderEpochIncrement) + .setIsr(partition0Replicas) + .setZkVersion(0) + .setReplicas(partition0Replicas) + .setIsNew(true), + new LeaderAndIsrPartitionState() + .setTopicName(tp1.topic) + .setPartitionIndex(tp1.partition) + .setControllerEpoch(controllerEpoch) + .setLeader(0) + .setLeaderEpoch(leaderEpoch + leaderEpochIncrement) + .setIsr(partition1Replicas) + .setZkVersion(0) + .setReplicas(partition1Replicas) + .setIsNew(true) + ).asJava, + topicIds, + Set(new Node(0, "host0", 0), new Node(1, "host1", 1)).asJava).build() + + rm0.becomeLeaderOrFollower(correlationId, leaderAndIsrRequest2, (_, _) => ()) + rm1.becomeLeaderOrFollower(correlationId, leaderAndIsrRequest2, (_, _) => ()) + } finally { + rm0.shutdown() + rm1.shutdown() + } + + // verify that broker 1 did remove its metrics when no longer being the leader of partition 1 + EasyMock.verify(mockTopicStats1) + } + + private def prepareDifferentReplicaManagers(brokerTopicStats1: BrokerTopicStats, + brokerTopicStats2: BrokerTopicStats): (ReplicaManager, ReplicaManager) = { + val props0 = TestUtils.createBrokerConfig(0, TestUtils.MockZkConnect) + val props1 = TestUtils.createBrokerConfig(1, TestUtils.MockZkConnect) + + props0.put("log0.dir", TestUtils.tempRelativeDir("data").getAbsolutePath) + props1.put("log1.dir", TestUtils.tempRelativeDir("data").getAbsolutePath) + + val config0 = KafkaConfig.fromProps(props0) + val config1 = KafkaConfig.fromProps(props1) + + val mockLogMgr0 = TestUtils.createLogManager(config0.logDirs.map(new File(_))) + val mockLogMgr1 = TestUtils.createLogManager(config1.logDirs.map(new File(_))) + + val metadataCache0: MetadataCache = Mockito.mock(classOf[MetadataCache]) + val metadataCache1: MetadataCache = Mockito.mock(classOf[MetadataCache]) + val aliveBrokers = Seq(new Node(0, "host0", 0), new Node(1, "host1", 1)) + mockGetAliveBrokerFunctions(metadataCache0, aliveBrokers) + mockGetAliveBrokerFunctions(metadataCache1, aliveBrokers) + + // each replica manager is for a broker + val rm0 = new ReplicaManager( + metrics = metrics, + config = config0, + time = time, + scheduler = new MockScheduler(time), + logManager = mockLogMgr0, + quotaManagers = quotaManager, + brokerTopicStats = brokerTopicStats1, + metadataCache = metadataCache0, + logDirFailureChannel = new LogDirFailureChannel(config0.logDirs.size), + alterIsrManager = alterIsrManager) + val rm1 = new ReplicaManager( + 
metrics = metrics, + config = config1, + time = time, + scheduler = new MockScheduler(time), + logManager = mockLogMgr1, + quotaManagers = quotaManager, + brokerTopicStats = brokerTopicStats2, + metadataCache = metadataCache1, + logDirFailureChannel = new LogDirFailureChannel(config1.logDirs.size), + alterIsrManager = alterIsrManager) + + (rm0, rm1) + } + + @Test + def testStopReplicaWithStaleControllerEpoch(): Unit = { + val mockTimer = new MockTimer(time) + val replicaManager = setupReplicaManagerWithMockedPurgatories(mockTimer, aliveBrokerIds = Seq(0, 1)) + + val tp0 = new TopicPartition(topic, 0) + val offsetCheckpoints = new LazyOffsetCheckpoints(replicaManager.highWatermarkCheckpoints) + replicaManager.createPartition(tp0).createLogIfNotExists(isNew = false, isFutureReplica = false, offsetCheckpoints, None) + + val becomeLeaderRequest = new LeaderAndIsrRequest.Builder(ApiKeys.LEADER_AND_ISR.latestVersion, 0, 10, brokerEpoch, + Seq(leaderAndIsrPartitionState(tp0, 1, 0, Seq(0, 1), true)).asJava, + Collections.singletonMap(topic, Uuid.randomUuid()), + Set(new Node(0, "host1", 0), new Node(1, "host2", 1)).asJava + ).build() + + replicaManager.becomeLeaderOrFollower(1, becomeLeaderRequest, (_, _) => ()) + + val partitionStates = Map(tp0 -> new StopReplicaPartitionState() + .setPartitionIndex(tp0.partition) + .setLeaderEpoch(1) + .setDeletePartition(false) + ) + + val (_, error) = replicaManager.stopReplicas(1, 0, 0, partitionStates) + assertEquals(Errors.STALE_CONTROLLER_EPOCH, error) + } + + @Test + def testStopReplicaWithOfflinePartition(): Unit = { + val mockTimer = new MockTimer(time) + val replicaManager = setupReplicaManagerWithMockedPurgatories(mockTimer, aliveBrokerIds = Seq(0, 1)) + + val tp0 = new TopicPartition(topic, 0) + val offsetCheckpoints = new LazyOffsetCheckpoints(replicaManager.highWatermarkCheckpoints) + replicaManager.createPartition(tp0).createLogIfNotExists(isNew = false, isFutureReplica = false, offsetCheckpoints, None) + + val becomeLeaderRequest = new LeaderAndIsrRequest.Builder(ApiKeys.LEADER_AND_ISR.latestVersion, 0, 0, brokerEpoch, + Seq(leaderAndIsrPartitionState(tp0, 1, 0, Seq(0, 1), true)).asJava, + Collections.singletonMap(topic, Uuid.randomUuid()), + Set(new Node(0, "host1", 0), new Node(1, "host2", 1)).asJava + ).build() + + replicaManager.becomeLeaderOrFollower(1, becomeLeaderRequest, (_, _) => ()) + replicaManager.markPartitionOffline(tp0) + + val partitionStates = Map(tp0 -> new StopReplicaPartitionState() + .setPartitionIndex(tp0.partition) + .setLeaderEpoch(1) + .setDeletePartition(false) + ) + + val (result, error) = replicaManager.stopReplicas(1, 0, 0, partitionStates) + assertEquals(Errors.NONE, error) + assertEquals(Map(tp0 -> Errors.KAFKA_STORAGE_ERROR), result) + } + + @Test + def testStopReplicaWithInexistentPartition(): Unit = { + testStopReplicaWithInexistentPartition(false, false) + } + + @Test + def testStopReplicaWithInexistentPartitionAndPartitionsDelete(): Unit = { + testStopReplicaWithInexistentPartition(true, false) + } + + @Test + def testStopReplicaWithInexistentPartitionAndPartitionsDeleteAndIOException(): Unit = { + testStopReplicaWithInexistentPartition(true, true) + } + + private def testStopReplicaWithInexistentPartition(deletePartitions: Boolean, throwIOException: Boolean): Unit = { + val mockTimer = new MockTimer(time) + val replicaManager = setupReplicaManagerWithMockedPurgatories(mockTimer, aliveBrokerIds = Seq(0, 1)) + + val tp0 = new TopicPartition(topic, 0) + val log = replicaManager.logManager.getOrCreateLog(tp0, 
true, topicId = None) + + if (throwIOException) { + // Delete the underlying directory to trigger an KafkaStorageException + val dir = log.dir.getParentFile + Utils.delete(dir) + dir.createNewFile() + } + + val partitionStates = Map(tp0 -> new StopReplicaPartitionState() + .setPartitionIndex(tp0.partition) + .setLeaderEpoch(1) + .setDeletePartition(deletePartitions) + ) + + val (result, error) = replicaManager.stopReplicas(1, 0, 0, partitionStates) + assertEquals(Errors.NONE, error) + + if (throwIOException && deletePartitions) { + assertEquals(Map(tp0 -> Errors.KAFKA_STORAGE_ERROR), result) + assertTrue(replicaManager.logManager.getLog(tp0).isEmpty) + } else if (deletePartitions) { + assertEquals(Map(tp0 -> Errors.NONE), result) + assertTrue(replicaManager.logManager.getLog(tp0).isEmpty) + } else { + assertEquals(Map(tp0 -> Errors.NONE), result) + assertTrue(replicaManager.logManager.getLog(tp0).isDefined) + } + } + + @Test + def testStopReplicaWithExistingPartitionAndNewerLeaderEpoch(): Unit = { + testStopReplicaWithExistingPartition(2, false, false, Errors.NONE) + } + + @Test + def testStopReplicaWithExistingPartitionAndOlderLeaderEpoch(): Unit = { + testStopReplicaWithExistingPartition(0, false, false, Errors.FENCED_LEADER_EPOCH) + } + + @Test + def testStopReplicaWithExistingPartitionAndEqualLeaderEpoch(): Unit = { + testStopReplicaWithExistingPartition(1, false, false, Errors.FENCED_LEADER_EPOCH) + } + + @Test + def testStopReplicaWithExistingPartitionAndDeleteSentinel(): Unit = { + testStopReplicaWithExistingPartition(LeaderAndIsr.EpochDuringDelete, false, false, Errors.NONE) + } + + @Test + def testStopReplicaWithExistingPartitionAndLeaderEpochNotProvided(): Unit = { + testStopReplicaWithExistingPartition(LeaderAndIsr.NoEpoch, false, false, Errors.NONE) + } + + @Test + def testStopReplicaWithDeletePartitionAndExistingPartitionAndNewerLeaderEpoch(): Unit = { + testStopReplicaWithExistingPartition(2, true, false, Errors.NONE) + } + + @Test + def testStopReplicaWithDeletePartitionAndExistingPartitionAndNewerLeaderEpochAndIOException(): Unit = { + testStopReplicaWithExistingPartition(2, true, true, Errors.KAFKA_STORAGE_ERROR) + } + + @Test + def testStopReplicaWithDeletePartitionAndExistingPartitionAndOlderLeaderEpoch(): Unit = { + testStopReplicaWithExistingPartition(0, true, false, Errors.FENCED_LEADER_EPOCH) + } + + @Test + def testStopReplicaWithDeletePartitionAndExistingPartitionAndEqualLeaderEpoch(): Unit = { + testStopReplicaWithExistingPartition(1, true, false, Errors.FENCED_LEADER_EPOCH) + } + + @Test + def testStopReplicaWithDeletePartitionAndExistingPartitionAndDeleteSentinel(): Unit = { + testStopReplicaWithExistingPartition(LeaderAndIsr.EpochDuringDelete, true, false, Errors.NONE) + } + + @Test + def testStopReplicaWithDeletePartitionAndExistingPartitionAndLeaderEpochNotProvided(): Unit = { + testStopReplicaWithExistingPartition(LeaderAndIsr.NoEpoch, true, false, Errors.NONE) + } + + private def testStopReplicaWithExistingPartition(leaderEpoch: Int, + deletePartition: Boolean, + throwIOException: Boolean, + expectedOutput: Errors): Unit = { + val mockTimer = new MockTimer(time) + val replicaManager = setupReplicaManagerWithMockedPurgatories(mockTimer, aliveBrokerIds = Seq(0, 1)) + + val tp0 = new TopicPartition(topic, 0) + val offsetCheckpoints = new LazyOffsetCheckpoints(replicaManager.highWatermarkCheckpoints) + val partition = replicaManager.createPartition(tp0) + partition.createLogIfNotExists(isNew = false, isFutureReplica = false, offsetCheckpoints, None) + + val 
logDirFailureChannel = new LogDirFailureChannel(replicaManager.config.logDirs.size) + val logDir = partition.log.get.parentDirFile + + def readRecoveryPointCheckpoint(): Map[TopicPartition, Long] = { + new OffsetCheckpointFile(new File(logDir, LogManager.RecoveryPointCheckpointFile), + logDirFailureChannel).read() + } + + def readLogStartOffsetCheckpoint(): Map[TopicPartition, Long] = { + new OffsetCheckpointFile(new File(logDir, LogManager.LogStartOffsetCheckpointFile), + logDirFailureChannel).read() + } + + val becomeLeaderRequest = new LeaderAndIsrRequest.Builder(ApiKeys.LEADER_AND_ISR.latestVersion, 0, 0, brokerEpoch, + Seq(leaderAndIsrPartitionState(tp0, 1, 0, Seq(0, 1), true)).asJava, + Collections.singletonMap(tp0.topic(), Uuid.randomUuid()), + Set(new Node(0, "host1", 0), new Node(1, "host2", 1)).asJava + ).build() + + replicaManager.becomeLeaderOrFollower(1, becomeLeaderRequest, (_, _) => ()) + + val batch = TestUtils.records(records = List( + new SimpleRecord(10, "k1".getBytes, "v1".getBytes), + new SimpleRecord(11, "k2".getBytes, "v2".getBytes))) + partition.appendRecordsToLeader(batch, AppendOrigin.Client, requiredAcks = 0, RequestLocal.withThreadConfinedCaching) + partition.log.get.updateHighWatermark(2L) + partition.log.get.maybeIncrementLogStartOffset(1L, LeaderOffsetIncremented) + replicaManager.logManager.checkpointLogRecoveryOffsets() + replicaManager.logManager.checkpointLogStartOffsets() + assertEquals(Some(1L), readRecoveryPointCheckpoint().get(tp0)) + assertEquals(Some(1L), readLogStartOffsetCheckpoint().get(tp0)) + + if (throwIOException) { + // Delete the underlying directory to trigger an KafkaStorageException + val dir = partition.log.get.dir + Utils.delete(dir) + dir.createNewFile() + } + + val partitionStates = Map(tp0 -> new StopReplicaPartitionState() + .setPartitionIndex(tp0.partition) + .setLeaderEpoch(leaderEpoch) + .setDeletePartition(deletePartition) + ) + + val (result, error) = replicaManager.stopReplicas(1, 0, 0, partitionStates) + assertEquals(Errors.NONE, error) + assertEquals(Map(tp0 -> expectedOutput), result) + + if (expectedOutput == Errors.NONE && deletePartition) { + assertEquals(HostedPartition.None, replicaManager.getPartition(tp0)) + assertFalse(readRecoveryPointCheckpoint().contains(tp0)) + assertFalse(readLogStartOffsetCheckpoint().contains(tp0)) + } + } + + @Test + def testReplicaNotAvailable(): Unit = { + + def createReplicaManager(): ReplicaManager = { + val props = TestUtils.createBrokerConfig(1, TestUtils.MockZkConnect) + val config = KafkaConfig.fromProps(props) + val mockLogMgr = TestUtils.createLogManager(config.logDirs.map(new File(_))) + new ReplicaManager( + metrics = metrics, + config = config, + time = time, + scheduler = new MockScheduler(time), + logManager = mockLogMgr, + quotaManagers = quotaManager, + metadataCache = MetadataCache.zkMetadataCache(config.brokerId), + logDirFailureChannel = new LogDirFailureChannel(config.logDirs.size), + alterIsrManager = alterIsrManager) { + override def getPartitionOrException(topicPartition: TopicPartition): Partition = { + throw Errors.NOT_LEADER_OR_FOLLOWER.exception() + } + } + } + + val replicaManager = createReplicaManager() + try { + val tp = new TopicPartition(topic, 0) + val dir = replicaManager.logManager.liveLogDirs.head.getAbsolutePath + val errors = replicaManager.alterReplicaLogDirs(Map(tp -> dir)) + assertEquals(Errors.REPLICA_NOT_AVAILABLE, errors(tp)) + } finally { + replicaManager.shutdown(false) + } + } + + @Test + def testPartitionMetadataFile(): Unit = { + val 
replicaManager = setupReplicaManagerWithMockedPurgatories(new MockTimer(time)) + try { + val brokerList = Seq[Integer](0, 1).asJava + val topicPartition = new TopicPartition(topic, 0) + val topicIds = Collections.singletonMap(topic, Uuid.randomUuid()) + val topicNames = topicIds.asScala.map(_.swap).asJava + + def leaderAndIsrRequest(epoch: Int, topicIds: java.util.Map[String, Uuid]): LeaderAndIsrRequest = + new LeaderAndIsrRequest.Builder(ApiKeys.LEADER_AND_ISR.latestVersion, 0, 0, brokerEpoch, + Seq(new LeaderAndIsrPartitionState() + .setTopicName(topic) + .setPartitionIndex(0) + .setControllerEpoch(0) + .setLeader(0) + .setLeaderEpoch(epoch) + .setIsr(brokerList) + .setZkVersion(0) + .setReplicas(brokerList) + .setIsNew(true)).asJava, + topicIds, + Set(new Node(0, "host1", 0), new Node(1, "host2", 1)).asJava).build() + + val response = replicaManager.becomeLeaderOrFollower(0, leaderAndIsrRequest(0, topicIds), (_, _) => ()) + assertEquals(Errors.NONE, response.partitionErrors(topicNames).get(topicPartition)) + assertFalse(replicaManager.localLog(topicPartition).isEmpty) + val id = topicIds.get(topicPartition.topic()) + val log = replicaManager.localLog(topicPartition).get + assertTrue(log.partitionMetadataFile.exists()) + val partitionMetadata = log.partitionMetadataFile.read() + + // Current version of PartitionMetadataFile is 0. + assertEquals(0, partitionMetadata.version) + assertEquals(id, partitionMetadata.topicId) + } finally replicaManager.shutdown(checkpointHW = false) + } + + @Test + def testPartitionMetadataFileCreatedWithExistingLog(): Unit = { + val replicaManager = setupReplicaManagerWithMockedPurgatories(new MockTimer(time)) + try { + val brokerList = Seq[Integer](0, 1).asJava + val topicPartition = new TopicPartition(topic, 0) + + replicaManager.logManager.getOrCreateLog(topicPartition, isNew = true, topicId = None) + + assertTrue(replicaManager.getLog(topicPartition).isDefined) + var log = replicaManager.getLog(topicPartition).get + assertEquals(None, log.topicId) + assertFalse(log.partitionMetadataFile.exists()) + + val topicIds = Collections.singletonMap(topic, Uuid.randomUuid()) + val topicNames = topicIds.asScala.map(_.swap).asJava + + def leaderAndIsrRequest(epoch: Int): LeaderAndIsrRequest = new LeaderAndIsrRequest.Builder(ApiKeys.LEADER_AND_ISR.latestVersion, 0, 0, brokerEpoch, + Seq(new LeaderAndIsrPartitionState() + .setTopicName(topic) + .setPartitionIndex(0) + .setControllerEpoch(0) + .setLeader(0) + .setLeaderEpoch(epoch) + .setIsr(brokerList) + .setZkVersion(0) + .setReplicas(brokerList) + .setIsNew(true)).asJava, + topicIds, + Set(new Node(0, "host1", 0), new Node(1, "host2", 1)).asJava).build() + + val response = replicaManager.becomeLeaderOrFollower(0, leaderAndIsrRequest(0), (_, _) => ()) + assertEquals(Errors.NONE, response.partitionErrors(topicNames).get(topicPartition)) + assertFalse(replicaManager.localLog(topicPartition).isEmpty) + val id = topicIds.get(topicPartition.topic()) + log = replicaManager.localLog(topicPartition).get + assertTrue(log.partitionMetadataFile.exists()) + val partitionMetadata = log.partitionMetadataFile.read() + + // Current version of PartitionMetadataFile is 0. 
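+        // Note (reviewer comment): a change to the on-disk partition.metadata format would presumably bump this version, so this assertion pins the current format.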
+ assertEquals(0, partitionMetadata.version) + assertEquals(id, partitionMetadata.topicId) + } finally replicaManager.shutdown(checkpointHW = false) + } + + @Test + def testPartitionMetadataFileCreatedAfterPreviousRequestWithoutIds(): Unit = { + val replicaManager = setupReplicaManagerWithMockedPurgatories(new MockTimer(time)) + try { + val brokerList = Seq[Integer](0, 1).asJava + val topicPartition = new TopicPartition(topic, 0) + val topicPartition2 = new TopicPartition(topic, 1) + + def leaderAndIsrRequest(topicIds: util.Map[String, Uuid], version: Short, partition: Int = 0, leaderEpoch: Int = 0): LeaderAndIsrRequest = + new LeaderAndIsrRequest.Builder(version, 0, 0, brokerEpoch, + Seq(new LeaderAndIsrPartitionState() + .setTopicName(topic) + .setPartitionIndex(partition) + .setControllerEpoch(0) + .setLeader(0) + .setLeaderEpoch(leaderEpoch) + .setIsr(brokerList) + .setZkVersion(0) + .setReplicas(brokerList) + .setIsNew(true)).asJava, + topicIds, + Set(new Node(0, "host1", 0), new Node(1, "host2", 1)).asJava).build() + + // Send a request without a topic ID so that we have a log without a topic ID associated to the partition. + val response = replicaManager.becomeLeaderOrFollower(0, leaderAndIsrRequest(Collections.emptyMap(), 4), (_, _) => ()) + assertEquals(Errors.NONE, response.partitionErrors(Collections.emptyMap()).get(topicPartition)) + assertTrue(replicaManager.localLog(topicPartition).isDefined) + val log = replicaManager.localLog(topicPartition).get + assertFalse(log.partitionMetadataFile.exists()) + assertTrue(log.topicId.isEmpty) + + val response2 = replicaManager.becomeLeaderOrFollower(0, leaderAndIsrRequest(topicIds.asJava, ApiKeys.LEADER_AND_ISR.latestVersion), (_, _) => ()) + assertEquals(Errors.NONE, response2.partitionErrors(topicNames.asJava).get(topicPartition)) + assertTrue(replicaManager.localLog(topicPartition).isDefined) + assertTrue(log.partitionMetadataFile.exists()) + assertTrue(log.topicId.isDefined) + assertEquals(topicId, log.topicId.get) + + // Repeat with partition 2, but in this case, update the leader epoch + // Send a request without a topic ID so that we have a log without a topic ID associated to the partition. 
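+      // Note (reviewer comment): version 4 of LeaderAndIsr predates topic IDs, so the request below cannot carry an ID for the log.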
+ val response3 = replicaManager.becomeLeaderOrFollower(0, leaderAndIsrRequest(Collections.emptyMap(), 4, 1), (_, _) => ()) + assertEquals(Errors.NONE, response3.partitionErrors(Collections.emptyMap()).get(topicPartition2)) + assertTrue(replicaManager.localLog(topicPartition2).isDefined) + val log2 = replicaManager.localLog(topicPartition2).get + assertFalse(log2.partitionMetadataFile.exists()) + assertTrue(log2.topicId.isEmpty) + + val response4 = replicaManager.becomeLeaderOrFollower(0, leaderAndIsrRequest(topicIds.asJava, ApiKeys.LEADER_AND_ISR.latestVersion, 1, 1), (_, _) => ()) + assertEquals(Errors.NONE, response4.partitionErrors(topicNames.asJava).get(topicPartition2)) + assertTrue(replicaManager.localLog(topicPartition2).isDefined) + assertTrue(log2.partitionMetadataFile.exists()) + assertTrue(log2.topicId.isDefined) + assertEquals(topicId, log2.topicId.get) + + assertEquals(topicId, log.partitionMetadataFile.read().topicId) + assertEquals(topicId, log2.partitionMetadataFile.read().topicId) + } finally replicaManager.shutdown(checkpointHW = false) + } + + @Test + def testInconsistentIdReturnsError(): Unit = { + val replicaManager = setupReplicaManagerWithMockedPurgatories(new MockTimer(time)) + try { + val brokerList = Seq[Integer](0, 1).asJava + val topicPartition = new TopicPartition(topic, 0) + val topicIds = Collections.singletonMap(topic, Uuid.randomUuid()) + val topicNames = topicIds.asScala.map(_.swap).asJava + + val invalidTopicIds = Collections.singletonMap(topic, Uuid.randomUuid()) + val invalidTopicNames = invalidTopicIds.asScala.map(_.swap).asJava + + def leaderAndIsrRequest(epoch: Int, topicIds: java.util.Map[String, Uuid]): LeaderAndIsrRequest = + new LeaderAndIsrRequest.Builder(ApiKeys.LEADER_AND_ISR.latestVersion, 0, 0, brokerEpoch, + Seq(new LeaderAndIsrPartitionState() + .setTopicName(topic) + .setPartitionIndex(0) + .setControllerEpoch(0) + .setLeader(0) + .setLeaderEpoch(epoch) + .setIsr(brokerList) + .setZkVersion(0) + .setReplicas(brokerList) + .setIsNew(true)).asJava, + topicIds, + Set(new Node(0, "host1", 0), new Node(1, "host2", 1)).asJava).build() + + val response = replicaManager.becomeLeaderOrFollower(0, leaderAndIsrRequest(0, topicIds), (_, _) => ()) + assertEquals(Errors.NONE, response.partitionErrors(topicNames).get(topicPartition)) + + val response2 = replicaManager.becomeLeaderOrFollower(0, leaderAndIsrRequest(1, topicIds), (_, _) => ()) + assertEquals(Errors.NONE, response2.partitionErrors(topicNames).get(topicPartition)) + + // Send request with inconsistent ID. 
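+      // The log already has a topic ID from the earlier requests, so a request carrying a different ID is expected to fail with INCONSISTENT_TOPIC_ID.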
+ val response3 = replicaManager.becomeLeaderOrFollower(0, leaderAndIsrRequest(1, invalidTopicIds), (_, _) => ()) + assertEquals(Errors.INCONSISTENT_TOPIC_ID, response3.partitionErrors(invalidTopicNames).get(topicPartition)) + + val response4 = replicaManager.becomeLeaderOrFollower(0, leaderAndIsrRequest(2, invalidTopicIds), (_, _) => ()) + assertEquals(Errors.INCONSISTENT_TOPIC_ID, response4.partitionErrors(invalidTopicNames).get(topicPartition)) + } finally replicaManager.shutdown(checkpointHW = false) + } + + @Test + def testPartitionMetadataFileNotCreated(): Unit = { + val replicaManager = setupReplicaManagerWithMockedPurgatories(new MockTimer(time)) + try { + val brokerList = Seq[Integer](0, 1).asJava + val topicPartition = new TopicPartition(topic, 0) + val topicPartitionFoo = new TopicPartition("foo", 0) + val topicPartitionFake = new TopicPartition("fakeTopic", 0) + val topicIds = Map(topic -> Uuid.ZERO_UUID, "foo" -> Uuid.randomUuid()).asJava + val topicNames = topicIds.asScala.map(_.swap).asJava + + def leaderAndIsrRequest(epoch: Int, name: String, version: Short): LeaderAndIsrRequest = LeaderAndIsrRequest.parse( + new LeaderAndIsrRequest.Builder(version, 0, 0, brokerEpoch, + Seq(new LeaderAndIsrPartitionState() + .setTopicName(name) + .setPartitionIndex(0) + .setControllerEpoch(0) + .setLeader(0) + .setLeaderEpoch(epoch) + .setIsr(brokerList) + .setZkVersion(0) + .setReplicas(brokerList) + .setIsNew(true)).asJava, + topicIds, + Set(new Node(0, "host1", 0), new Node(1, "host2", 1)).asJava).build().serialize(), version) + + // There is no file if the topic does not have an associated topic ID. + val response = replicaManager.becomeLeaderOrFollower(0, leaderAndIsrRequest(0, "fakeTopic", ApiKeys.LEADER_AND_ISR.latestVersion), (_, _) => ()) + assertTrue(replicaManager.localLog(topicPartitionFake).isDefined) + val log = replicaManager.localLog(topicPartitionFake).get + assertFalse(log.partitionMetadataFile.exists()) + assertEquals(Errors.NONE, response.partitionErrors(topicNames).get(topicPartition)) + + // There is no file if the topic has the default UUID. 
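+      // Uuid.ZERO_UUID is the sentinel for "no topic ID", so no partition metadata file should be created for this topic either.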
+ val response2 = replicaManager.becomeLeaderOrFollower(0, leaderAndIsrRequest(0, topic, ApiKeys.LEADER_AND_ISR.latestVersion), (_, _) => ()) + assertTrue(replicaManager.localLog(topicPartition).isDefined) + val log2 = replicaManager.localLog(topicPartition).get + assertFalse(log2.partitionMetadataFile.exists()) + assertEquals(Errors.NONE, response2.partitionErrors(topicNames).get(topicPartition)) + + // There is no file if the request an older version + val response3 = replicaManager.becomeLeaderOrFollower(0, leaderAndIsrRequest(0, "foo", 0), (_, _) => ()) + assertTrue(replicaManager.localLog(topicPartitionFoo).isDefined) + val log3 = replicaManager.localLog(topicPartitionFoo).get + assertFalse(log3.partitionMetadataFile.exists()) + assertEquals(Errors.NONE, response3.partitionErrors(topicNames).get(topicPartitionFoo)) + + // There is no file if the request is an older version + val response4 = replicaManager.becomeLeaderOrFollower(0, leaderAndIsrRequest(1, "foo", 4), (_, _) => ()) + assertTrue(replicaManager.localLog(topicPartitionFoo).isDefined) + val log4 = replicaManager.localLog(topicPartitionFoo).get + assertFalse(log4.partitionMetadataFile.exists()) + assertEquals(Errors.NONE, response4.partitionErrors(topicNames).get(topicPartitionFoo)) + } finally replicaManager.shutdown(checkpointHW = false) + } + + @ParameterizedTest + @ValueSource(booleans = Array(true, false)) + def testPartitionMarkedOfflineIfLogCantBeCreated(becomeLeader: Boolean): Unit = { + val dataDir = TestUtils.tempDir() + val topicPartition = new TopicPartition(topic, 0) + val replicaManager = setupReplicaManagerWithMockedPurgatories( + timer = new MockTimer(time), + propsModifier = props => props.put(KafkaConfig.LogDirsProp, dataDir.getAbsolutePath) + ) + + try { + // Delete the data directory to trigger a storage exception + Utils.delete(dataDir) + + val request = leaderAndIsrRequest( + topicId = Uuid.randomUuid(), + topicPartition = topicPartition, + replicas = Seq(0, 1), + leaderAndIsr = LeaderAndIsr(if (becomeLeader) 0 else 1, List(0, 1)), + isNew = true + ) + + replicaManager.becomeLeaderOrFollower(0, request, (_, _) => ()) + + assertEquals(HostedPartition.Offline, replicaManager.getPartition(topicPartition)) + } finally { + replicaManager.shutdown(checkpointHW = false) + } + } + + private def leaderAndIsrRequest( + topicId: Uuid, + topicPartition: TopicPartition, + replicas: Seq[Int], + leaderAndIsr: LeaderAndIsr, + isNew: Boolean = true, + brokerEpoch: Int = 0, + controllerId: Int = 0, + controllerEpoch: Int = 0, + version: Short = LeaderAndIsrRequestData.HIGHEST_SUPPORTED_VERSION + ): LeaderAndIsrRequest = { + val partitionState = new LeaderAndIsrPartitionState() + .setTopicName(topicPartition.topic) + .setPartitionIndex(topicPartition.partition) + .setControllerEpoch(controllerEpoch) + .setLeader(leaderAndIsr.leader) + .setLeaderEpoch(leaderAndIsr.leaderEpoch) + .setIsr(leaderAndIsr.isr.map(Int.box).asJava) + .setZkVersion(leaderAndIsr.zkVersion) + .setReplicas(replicas.map(Int.box).asJava) + .setIsNew(isNew) + + def mkNode(replicaId: Int): Node = { + new Node(replicaId, s"host-$replicaId", 9092) + } + + val nodes = Set(mkNode(controllerId)) ++ replicas.map(mkNode).toSet + + new LeaderAndIsrRequest.Builder( + version, + controllerId, + controllerEpoch, + brokerEpoch, + Seq(partitionState).asJava, + Map(topicPartition.topic -> topicId).asJava, + nodes.asJava + ).build() + } + + @Test + def testActiveProducerState(): Unit = { + val brokerId = 0 + val replicaManager = 
setupReplicaManagerWithMockedPurgatories(new MockTimer(time), brokerId) + try { + val fooPartition = new TopicPartition("foo", 0) + Mockito.when(replicaManager.metadataCache.contains(fooPartition)).thenReturn(false) + val fooProducerState = replicaManager.activeProducerState(fooPartition) + assertEquals(Errors.UNKNOWN_TOPIC_OR_PARTITION, Errors.forCode(fooProducerState.errorCode)) + + val oofPartition = new TopicPartition("oof", 0) + Mockito.when(replicaManager.metadataCache.contains(oofPartition)).thenReturn(true) + val oofProducerState = replicaManager.activeProducerState(oofPartition) + assertEquals(Errors.NOT_LEADER_OR_FOLLOWER, Errors.forCode(oofProducerState.errorCode)) + + // This API is supported by both leaders and followers + + val barPartition = new TopicPartition("bar", 0) + val barLeaderAndIsrRequest = leaderAndIsrRequest( + topicId = Uuid.randomUuid(), + topicPartition = barPartition, + replicas = Seq(brokerId), + leaderAndIsr = LeaderAndIsr(brokerId, List(brokerId)) + ) + replicaManager.becomeLeaderOrFollower(0, barLeaderAndIsrRequest, (_, _) => ()) + val barProducerState = replicaManager.activeProducerState(barPartition) + assertEquals(Errors.NONE, Errors.forCode(barProducerState.errorCode)) + + val otherBrokerId = 1 + val bazPartition = new TopicPartition("baz", 0) + val bazLeaderAndIsrRequest = leaderAndIsrRequest( + topicId = Uuid.randomUuid(), + topicPartition = bazPartition, + replicas = Seq(brokerId, otherBrokerId), + leaderAndIsr = LeaderAndIsr(otherBrokerId, List(brokerId, otherBrokerId)) + ) + replicaManager.becomeLeaderOrFollower(0, bazLeaderAndIsrRequest, (_, _) => ()) + val bazProducerState = replicaManager.activeProducerState(bazPartition) + assertEquals(Errors.NONE, Errors.forCode(bazProducerState.errorCode)) + } finally { + replicaManager.shutdown(checkpointHW = false) + } + } + + val FOO_UUID = Uuid.fromString("fFJBx0OmQG-UqeaT6YaSwA") + + val BAR_UUID = Uuid.fromString("vApAP6y7Qx23VOfKBzbOBQ") + + @Test + def testGetOrCreatePartition(): Unit = { + val brokerId = 0 + val replicaManager = setupReplicaManagerWithMockedPurgatories(new MockTimer(time), brokerId) + val foo0 = new TopicPartition("foo", 0) + val emptyDelta = new TopicsDelta(TopicsImage.EMPTY) + val (fooPart, fooNew) = replicaManager.getOrCreatePartition(foo0, emptyDelta, FOO_UUID).get + assertTrue(fooNew) + assertEquals(foo0, fooPart.topicPartition) + val (fooPart2, fooNew2) = replicaManager.getOrCreatePartition(foo0, emptyDelta, FOO_UUID).get + assertFalse(fooNew2) + assertTrue(fooPart eq fooPart2) + val bar1 = new TopicPartition("bar", 1) + replicaManager.markPartitionOffline(bar1) + assertEquals(None, replicaManager.getOrCreatePartition(bar1, emptyDelta, BAR_UUID)) + } + + @Test + def testDeltaFromLeaderToFollower(): Unit = { + val localId = 1 + val otherId = localId + 1 + val numOfRecords = 3 + val topicPartition = new TopicPartition("foo", 0) + val replicaManager = setupReplicaManagerWithMockedPurgatories(new MockTimer(time), localId) + + try { + // Make the local replica the leader + val leaderTopicsDelta = topicsCreateDelta(localId, true) + val leaderMetadataImage = imageFromTopics(leaderTopicsDelta.apply()) + val topicId = leaderMetadataImage.topics().topicsByName.get("foo").id + val topicIdPartition = new TopicIdPartition(topicId, topicPartition) + replicaManager.applyDelta(leaderTopicsDelta, leaderMetadataImage) + + // Check the state of that partition and fetcher + val HostedPartition.Online(leaderPartition) = replicaManager.getPartition(topicPartition) + 
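+      // The pattern match above also acts as an assertion: it throws a MatchError if the partition is not hosted online after applying the delta.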
assertTrue(leaderPartition.isLeader) + assertEquals(Set(localId, otherId), leaderPartition.inSyncReplicaIds) + assertEquals(0, leaderPartition.getLeaderEpoch) + + assertEquals(None, replicaManager.replicaFetcherManager.getFetcher(topicPartition)) + + // Send a produce request and advance the highwatermark + val leaderResponse = sendProducerAppend(replicaManager, topicPartition, numOfRecords) + fetchMessages( + replicaManager, + otherId, + topicIdPartition, + new PartitionData(Uuid.ZERO_UUID, numOfRecords, 0, Int.MaxValue, Optional.empty()), + Int.MaxValue, + IsolationLevel.READ_UNCOMMITTED, + None + ) + assertEquals(Errors.NONE, leaderResponse.get.error) + + // Change the local replica to follower + val followerTopicsDelta = topicsChangeDelta(leaderMetadataImage.topics(), localId, false) + val followerMetadataImage = imageFromTopics(followerTopicsDelta.apply()) + replicaManager.applyDelta(followerTopicsDelta, followerMetadataImage) + + // Append on a follower should fail + val followerResponse = sendProducerAppend(replicaManager, topicPartition, numOfRecords) + assertEquals(Errors.NOT_LEADER_OR_FOLLOWER, followerResponse.get.error) + + // Check the state of that partition and fetcher + val HostedPartition.Online(followerPartition) = replicaManager.getPartition(topicPartition) + assertFalse(followerPartition.isLeader) + assertEquals(1, followerPartition.getLeaderEpoch) + + val fetcher = replicaManager.replicaFetcherManager.getFetcher(topicPartition) + assertEquals(Some(BrokerEndPoint(otherId, "localhost", 9093)), fetcher.map(_.sourceBroker)) + } finally { + replicaManager.shutdown() + } + + TestUtils.assertNoNonDaemonThreads(this.getClass.getName) + } + + @Test + def testDeltaFromFollowerToLeader(): Unit = { + val localId = 1 + val otherId = localId + 1 + val numOfRecords = 3 + val topicPartition = new TopicPartition("foo", 0) + val replicaManager = setupReplicaManagerWithMockedPurgatories(new MockTimer(time), localId) + + try { + // Make the local replica the follower + val followerTopicsDelta = topicsCreateDelta(localId, false) + val followerMetadataImage = imageFromTopics(followerTopicsDelta.apply()) + replicaManager.applyDelta(followerTopicsDelta, followerMetadataImage) + + // Check the state of that partition and fetcher + val HostedPartition.Online(followerPartition) = replicaManager.getPartition(topicPartition) + assertFalse(followerPartition.isLeader) + assertEquals(0, followerPartition.getLeaderEpoch) + + val fetcher = replicaManager.replicaFetcherManager.getFetcher(topicPartition) + assertEquals(Some(BrokerEndPoint(otherId, "localhost", 9093)), fetcher.map(_.sourceBroker)) + + // Append on a follower should fail + val followerResponse = sendProducerAppend(replicaManager, topicPartition, numOfRecords) + assertEquals(Errors.NOT_LEADER_OR_FOLLOWER, followerResponse.get.error) + + // Change the local replica to leader + val leaderTopicsDelta = topicsChangeDelta(followerMetadataImage.topics(), localId, true) + val leaderMetadataImage = imageFromTopics(leaderTopicsDelta.apply()) + val topicId = leaderMetadataImage.topics().topicsByName.get("foo").id + val topicIdPartition = new TopicIdPartition(topicId, topicPartition) + replicaManager.applyDelta(leaderTopicsDelta, leaderMetadataImage) + + // Send a produce request and advance the highwatermark + val leaderResponse = sendProducerAppend(replicaManager, topicPartition, numOfRecords) + fetchMessages( + replicaManager, + otherId, + topicIdPartition, + new PartitionData(Uuid.ZERO_UUID, numOfRecords, 0, Int.MaxValue, Optional.empty()), + 
Int.MaxValue, + IsolationLevel.READ_UNCOMMITTED, + None + ) + assertEquals(Errors.NONE, leaderResponse.get.error) + + val HostedPartition.Online(leaderPartition) = replicaManager.getPartition(topicPartition) + assertTrue(leaderPartition.isLeader) + assertEquals(Set(localId, otherId), leaderPartition.inSyncReplicaIds) + assertEquals(1, leaderPartition.getLeaderEpoch) + + assertEquals(None, replicaManager.replicaFetcherManager.getFetcher(topicPartition)) + } finally { + replicaManager.shutdown() + } + + TestUtils.assertNoNonDaemonThreads(this.getClass.getName) + } + + @Test + def testDeltaFollowerWithNoChange(): Unit = { + val localId = 1 + val otherId = localId + 1 + val topicPartition = new TopicPartition("foo", 0) + val replicaManager = setupReplicaManagerWithMockedPurgatories(new MockTimer(time), localId) + + try { + // Make the local replica the follower + val followerTopicsDelta = topicsCreateDelta(localId, false) + val followerMetadataImage = imageFromTopics(followerTopicsDelta.apply()) + replicaManager.applyDelta(followerTopicsDelta, followerMetadataImage) + + // Check the state of that partition and fetcher + val HostedPartition.Online(followerPartition) = replicaManager.getPartition(topicPartition) + assertFalse(followerPartition.isLeader) + assertEquals(0, followerPartition.getLeaderEpoch) + + val fetcher = replicaManager.replicaFetcherManager.getFetcher(topicPartition) + assertEquals(Some(BrokerEndPoint(otherId, "localhost", 9093)), fetcher.map(_.sourceBroker)) + + // Apply the same delta again + replicaManager.applyDelta(followerTopicsDelta, followerMetadataImage) + + // Check that the state stays the same + val HostedPartition.Online(noChangePartition) = replicaManager.getPartition(topicPartition) + assertFalse(noChangePartition.isLeader) + assertEquals(0, noChangePartition.getLeaderEpoch) + + val noChangeFetcher = replicaManager.replicaFetcherManager.getFetcher(topicPartition) + assertEquals(Some(BrokerEndPoint(otherId, "localhost", 9093)), noChangeFetcher.map(_.sourceBroker)) + } finally { + replicaManager.shutdown() + } + + TestUtils.assertNoNonDaemonThreads(this.getClass.getName) + } + + @Test + def testDeltaFollowerToNotReplica(): Unit = { + val localId = 1 + val otherId = localId + 1 + val topicPartition = new TopicPartition("foo", 0) + val replicaManager = setupReplicaManagerWithMockedPurgatories(new MockTimer(time), localId) + + try { + // Make the local replica the follower + val followerTopicsDelta = topicsCreateDelta(localId, false) + val followerMetadataImage = imageFromTopics(followerTopicsDelta.apply()) + replicaManager.applyDelta(followerTopicsDelta, followerMetadataImage) + + // Check the state of that partition and fetcher + val HostedPartition.Online(followerPartition) = replicaManager.getPartition(topicPartition) + assertFalse(followerPartition.isLeader) + assertEquals(0, followerPartition.getLeaderEpoch) + + val fetcher = replicaManager.replicaFetcherManager.getFetcher(topicPartition) + assertEquals(Some(BrokerEndPoint(otherId, "localhost", 9093)), fetcher.map(_.sourceBroker)) + + // Apply changes that remove replica + val notReplicaTopicsDelta = topicsChangeDelta(followerMetadataImage.topics(), otherId, true) + val notReplicaMetadataImage = imageFromTopics(notReplicaTopicsDelta.apply()) + replicaManager.applyDelta(notReplicaTopicsDelta, notReplicaMetadataImage) + + // Check that the partition was removed + assertEquals(HostedPartition.None, replicaManager.getPartition(topicPartition)) + assertEquals(None, 
replicaManager.replicaFetcherManager.getFetcher(topicPartition)) + assertEquals(None, replicaManager.logManager.getLog(topicPartition)) + } finally { + replicaManager.shutdown() + } + + TestUtils.assertNoNonDaemonThreads(this.getClass.getName) + } + + @Test + def testDeltaFollowerRemovedTopic(): Unit = { + val localId = 1 + val otherId = localId + 1 + val topicPartition = new TopicPartition("foo", 0) + val replicaManager = setupReplicaManagerWithMockedPurgatories(new MockTimer(time), localId) + + try { + // Make the local replica the follower + val followerTopicsDelta = topicsCreateDelta(localId, false) + val followerMetadataImage = imageFromTopics(followerTopicsDelta.apply()) + replicaManager.applyDelta(followerTopicsDelta, followerMetadataImage) + + // Check the state of that partition and fetcher + val HostedPartition.Online(followerPartition) = replicaManager.getPartition(topicPartition) + assertFalse(followerPartition.isLeader) + assertEquals(0, followerPartition.getLeaderEpoch) + + val fetcher = replicaManager.replicaFetcherManager.getFetcher(topicPartition) + assertEquals(Some(BrokerEndPoint(otherId, "localhost", 9093)), fetcher.map(_.sourceBroker)) + + // Apply changes that remove topic and replica + val removeTopicsDelta = topicsDeleteDelta(followerMetadataImage.topics()) + val removeMetadataImage = imageFromTopics(removeTopicsDelta.apply()) + replicaManager.applyDelta(removeTopicsDelta, removeMetadataImage) + + // Check that the partition was removed + assertEquals(HostedPartition.None, replicaManager.getPartition(topicPartition)) + assertEquals(None, replicaManager.replicaFetcherManager.getFetcher(topicPartition)) + assertEquals(None, replicaManager.logManager.getLog(topicPartition)) + } finally { + replicaManager.shutdown() + } + + TestUtils.assertNoNonDaemonThreads(this.getClass.getName) + } + + @Test + def testDeltaLeaderToNotReplica(): Unit = { + val localId = 1 + val otherId = localId + 1 + val topicPartition = new TopicPartition("foo", 0) + val replicaManager = setupReplicaManagerWithMockedPurgatories(new MockTimer(time), localId) + + try { + // Make the local replica the follower + val leaderTopicsDelta = topicsCreateDelta(localId, true) + val leaderMetadataImage = imageFromTopics(leaderTopicsDelta.apply()) + replicaManager.applyDelta(leaderTopicsDelta, leaderMetadataImage) + + // Check the state of that partition and fetcher + val HostedPartition.Online(leaderPartition) = replicaManager.getPartition(topicPartition) + assertTrue(leaderPartition.isLeader) + assertEquals(Set(localId, otherId), leaderPartition.inSyncReplicaIds) + assertEquals(0, leaderPartition.getLeaderEpoch) + + assertEquals(None, replicaManager.replicaFetcherManager.getFetcher(topicPartition)) + + // Apply changes that remove replica + val notReplicaTopicsDelta = topicsChangeDelta(leaderMetadataImage.topics(), otherId, true) + val notReplicaMetadataImage = imageFromTopics(notReplicaTopicsDelta.apply()) + replicaManager.applyDelta(notReplicaTopicsDelta, notReplicaMetadataImage) + + // Check that the partition was removed + assertEquals(HostedPartition.None, replicaManager.getPartition(topicPartition)) + assertEquals(None, replicaManager.replicaFetcherManager.getFetcher(topicPartition)) + assertEquals(None, replicaManager.logManager.getLog(topicPartition)) + } finally { + replicaManager.shutdown() + } + + TestUtils.assertNoNonDaemonThreads(this.getClass.getName) + } + + @Test + def testDeltaLeaderToRemovedTopic(): Unit = { + val localId = 1 + val otherId = localId + 1 + val topicPartition = new 
TopicPartition("foo", 0) + val replicaManager = setupReplicaManagerWithMockedPurgatories(new MockTimer(time), localId) + + try { + // Make the local replica the follower + val leaderTopicsDelta = topicsCreateDelta(localId, true) + val leaderMetadataImage = imageFromTopics(leaderTopicsDelta.apply()) + replicaManager.applyDelta(leaderTopicsDelta, leaderMetadataImage) + + // Check the state of that partition and fetcher + val HostedPartition.Online(leaderPartition) = replicaManager.getPartition(topicPartition) + assertTrue(leaderPartition.isLeader) + assertEquals(Set(localId, otherId), leaderPartition.inSyncReplicaIds) + assertEquals(0, leaderPartition.getLeaderEpoch) + + assertEquals(None, replicaManager.replicaFetcherManager.getFetcher(topicPartition)) + + // Apply changes that remove topic and replica + val removeTopicsDelta = topicsDeleteDelta(leaderMetadataImage.topics()) + val removeMetadataImage = imageFromTopics(removeTopicsDelta.apply()) + replicaManager.applyDelta(removeTopicsDelta, removeMetadataImage) + + // Check that the partition was removed + assertEquals(HostedPartition.None, replicaManager.getPartition(topicPartition)) + assertEquals(None, replicaManager.replicaFetcherManager.getFetcher(topicPartition)) + assertEquals(None, replicaManager.logManager.getLog(topicPartition)) + } finally { + replicaManager.shutdown() + } + + TestUtils.assertNoNonDaemonThreads(this.getClass.getName) + } + + @Test + def testDeltaToFollowerCompletesProduce(): Unit = { + val localId = 1 + val otherId = localId + 1 + val numOfRecords = 3 + val topicPartition = new TopicPartition("foo", 0) + val replicaManager = setupReplicaManagerWithMockedPurgatories(new MockTimer(time), localId) + + try { + // Make the local replica the leader + val leaderTopicsDelta = topicsCreateDelta(localId, true) + val leaderMetadataImage = imageFromTopics(leaderTopicsDelta.apply()) + replicaManager.applyDelta(leaderTopicsDelta, leaderMetadataImage) + + // Check the state of that partition and fetcher + val HostedPartition.Online(leaderPartition) = replicaManager.getPartition(topicPartition) + assertTrue(leaderPartition.isLeader) + assertEquals(Set(localId, otherId), leaderPartition.inSyncReplicaIds) + assertEquals(0, leaderPartition.getLeaderEpoch) + + assertEquals(None, replicaManager.replicaFetcherManager.getFetcher(topicPartition)) + + // Send a produce request + val leaderResponse = sendProducerAppend(replicaManager, topicPartition, numOfRecords) + + // Change the local replica to follower + val followerTopicsDelta = topicsChangeDelta(leaderMetadataImage.topics(), localId, false) + val followerMetadataImage = imageFromTopics(followerTopicsDelta.apply()) + replicaManager.applyDelta(followerTopicsDelta, followerMetadataImage) + + // Check that the produce failed because it changed to follower before replicating + assertEquals(Errors.NOT_LEADER_OR_FOLLOWER, leaderResponse.get.error) + } finally { + replicaManager.shutdown() + } + + TestUtils.assertNoNonDaemonThreads(this.getClass.getName) + } + + @Test + def testDeltaToFollowerCompletesFetch(): Unit = { + val localId = 1 + val otherId = localId + 1 + val topicPartition = new TopicPartition("foo", 0) + val replicaManager = setupReplicaManagerWithMockedPurgatories(new MockTimer(time), localId) + + try { + // Make the local replica the leader + val leaderTopicsDelta = topicsCreateDelta(localId, true) + val leaderMetadataImage = imageFromTopics(leaderTopicsDelta.apply()) + val topicId = leaderMetadataImage.topics().topicsByName.get("foo").id + val topicIdPartition = new 
TopicIdPartition(topicId, topicPartition) + replicaManager.applyDelta(leaderTopicsDelta, leaderMetadataImage) + + // Check the state of that partition and fetcher + val HostedPartition.Online(leaderPartition) = replicaManager.getPartition(topicPartition) + assertTrue(leaderPartition.isLeader) + assertEquals(Set(localId, otherId), leaderPartition.inSyncReplicaIds) + assertEquals(0, leaderPartition.getLeaderEpoch) + + assertEquals(None, replicaManager.replicaFetcherManager.getFetcher(topicPartition)) + + // Send a fetch request + val fetchCallback = fetchMessages( + replicaManager, + otherId, + topicIdPartition, + new PartitionData(Uuid.ZERO_UUID, 0, 0, Int.MaxValue, Optional.empty()), + Int.MaxValue, + IsolationLevel.READ_UNCOMMITTED, + None + ) + + // Change the local replica to follower + val followerTopicsDelta = topicsChangeDelta(leaderMetadataImage.topics(), localId, false) + val followerMetadataImage = imageFromTopics(followerTopicsDelta.apply()) + replicaManager.applyDelta(followerTopicsDelta, followerMetadataImage) + + // Check that the produce failed because it changed to follower before replicating + assertEquals(Errors.NOT_LEADER_OR_FOLLOWER, fetchCallback.assertFired.error) + } finally { + replicaManager.shutdown() + } + + TestUtils.assertNoNonDaemonThreads(this.getClass.getName) + } + + @ParameterizedTest + @ValueSource(booleans = Array(true, false)) + def testDeltaToLeaderOrFollowerMarksPartitionOfflineIfLogCantBeCreated(isStartIdLeader: Boolean): Unit = { + val localId = 1 + val topicPartition = new TopicPartition("foo", 0) + val dataDir = TestUtils.tempDir() + val replicaManager = setupReplicaManagerWithMockedPurgatories( + timer = new MockTimer(time), + brokerId = localId, + propsModifier = props => props.put(KafkaConfig.LogDirsProp, dataDir.getAbsolutePath) + ) + + try { + // Delete the data directory to trigger a storage exception + Utils.delete(dataDir) + + // Make the local replica the leader + val topicsDelta = topicsCreateDelta(localId, isStartIdLeader) + val leaderMetadataImage = imageFromTopics(topicsDelta.apply()) + replicaManager.applyDelta(topicsDelta, leaderMetadataImage) + + assertEquals(HostedPartition.Offline, replicaManager.getPartition(topicPartition)) + } finally { + replicaManager.shutdown(checkpointHW = false) + } + } + + @Test + def testDeltaFollowerStopFetcherBeforeCreatingInitialFetchOffset(): Unit = { + val localId = 1 + val otherId = localId + 1 + val topicPartition = new TopicPartition("foo", 0) + + val mockReplicaFetcherManager = Mockito.mock(classOf[ReplicaFetcherManager]) + val replicaManager = setupReplicaManagerWithMockedPurgatories( + timer = new MockTimer(time), + brokerId = localId, + mockReplicaFetcherManager = Some(mockReplicaFetcherManager) + ) + + try { + // The first call to removeFetcherForPartitions should be ignored. 
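+      // Returning an empty map from the stub means no in-progress fetch state is reported for the removed partition on this first call.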
+ Mockito.when(mockReplicaFetcherManager.removeFetcherForPartitions( + Set(topicPartition)) + ).thenReturn(Map.empty[TopicPartition, PartitionFetchState]) + + // Make the local replica the follower + var followerTopicsDelta = topicsCreateDelta(localId, false) + var followerMetadataImage = imageFromTopics(followerTopicsDelta.apply()) + replicaManager.applyDelta(followerTopicsDelta, followerMetadataImage) + + // Check the state of that partition + val HostedPartition.Online(followerPartition) = replicaManager.getPartition(topicPartition) + assertFalse(followerPartition.isLeader) + assertEquals(0, followerPartition.getLeaderEpoch) + assertEquals(0, followerPartition.localLogOrException.logEndOffset) + + // Verify that addFetcherForPartitions was called with the correct + // init offset. + Mockito.verify(mockReplicaFetcherManager, Mockito.times(1)) + .addFetcherForPartitions( + Map(topicPartition -> InitialFetchState( + topicId = Some(FOO_UUID), + leader = BrokerEndPoint(otherId, "localhost", 9093), + currentLeaderEpoch = 0, + initOffset = 0 + )) + ) + + // The second call to removeFetcherForPartitions simulate the case + // where the fetcher write to the log before being shutdown. + Mockito.when(mockReplicaFetcherManager.removeFetcherForPartitions( + Set(topicPartition)) + ).thenAnswer { _ => + replicaManager.getPartition(topicPartition) match { + case HostedPartition.Online(partition) => + partition.appendRecordsToFollowerOrFutureReplica( + records = MemoryRecords.withRecords(CompressionType.NONE, 0, + new SimpleRecord("first message".getBytes)), + isFuture = false + ) + + case _ => + } + + Map.empty[TopicPartition, PartitionFetchState] + } + + // Apply changes that bumps the leader epoch. + followerTopicsDelta = topicsChangeDelta(followerMetadataImage.topics(), localId, false) + followerMetadataImage = imageFromTopics(followerTopicsDelta.apply()) + replicaManager.applyDelta(followerTopicsDelta, followerMetadataImage) + + assertFalse(followerPartition.isLeader) + assertEquals(1, followerPartition.getLeaderEpoch) + assertEquals(1, followerPartition.localLogOrException.logEndOffset) + + // Verify that addFetcherForPartitions was called with the correct + // init offset. 
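+      // The stubbed removeFetcherForPartitions appended one record before shutdown, so the log end offset is 1 and the new fetcher should resume from there with the bumped epoch.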
+ Mockito.verify(mockReplicaFetcherManager, Mockito.times(1)) + .addFetcherForPartitions( + Map(topicPartition -> InitialFetchState( + topicId = Some(FOO_UUID), + leader = BrokerEndPoint(otherId, "localhost", 9093), + currentLeaderEpoch = 1, + initOffset = 1 + )) + ) + } finally { + replicaManager.shutdown() + } + + TestUtils.assertNoNonDaemonThreads(this.getClass.getName) + } + + private def topicsCreateDelta(startId: Int, isStartIdLeader: Boolean): TopicsDelta = { + val leader = if (isStartIdLeader) startId else startId + 1 + val delta = new TopicsDelta(TopicsImage.EMPTY) + delta.replay(new TopicRecord().setName("foo").setTopicId(FOO_UUID)) + delta.replay( + new PartitionRecord() + .setPartitionId(0) + .setTopicId(FOO_UUID) + .setReplicas(util.Arrays.asList(startId, startId + 1)) + .setIsr(util.Arrays.asList(startId, startId + 1)) + .setRemovingReplicas(Collections.emptyList()) + .setAddingReplicas(Collections.emptyList()) + .setLeader(leader) + .setLeaderEpoch(0) + .setPartitionEpoch(0) + ) + + delta + } + + private def topicsChangeDelta(topicsImage: TopicsImage, startId: Int, isStartIdLeader: Boolean): TopicsDelta = { + val leader = if (isStartIdLeader) startId else startId + 1 + val delta = new TopicsDelta(topicsImage) + delta.replay( + new PartitionChangeRecord() + .setPartitionId(0) + .setTopicId(FOO_UUID) + .setReplicas(util.Arrays.asList(startId, startId + 1)) + .setIsr(util.Arrays.asList(startId, startId + 1)) + .setLeader(leader) + ) + delta + } + + private def topicsDeleteDelta(topicsImage: TopicsImage): TopicsDelta = { + val delta = new TopicsDelta(topicsImage) + delta.replay(new RemoveTopicRecord().setTopicId(FOO_UUID)) + + delta + } + + private def imageFromTopics(topicsImage: TopicsImage): MetadataImage = { + new MetadataImage( + new RaftOffsetAndEpoch(100, 10), + FeaturesImage.EMPTY, + ClusterImageTest.IMAGE1, + topicsImage, + ConfigurationsImage.EMPTY, + ClientQuotasImage.EMPTY + ) + } + + def assertFetcherHasTopicId[T <: AbstractFetcherThread](manager: AbstractFetcherManager[T], + tp: TopicPartition, + expectedTopicId: Option[Uuid]): Unit = { + val fetchState = manager.getFetcher(tp).flatMap(_.fetchState(tp)) + assertTrue(fetchState.isDefined) + assertEquals(expectedTopicId, fetchState.get.topicId) + } + + @ParameterizedTest + @ValueSource(booleans = Array(true, false)) + def testPartitionFetchStateUpdatesWithTopicIdChanges(startsWithTopicId: Boolean): Unit = { + val aliveBrokersIds = Seq(0, 1) + val replicaManager = setupReplicaManagerWithMockedPurgatories(new MockTimer(time), + brokerId = 0, aliveBrokersIds) + try { + val tp = new TopicPartition(topic, 0) + val leaderAndIsr = new LeaderAndIsr(1, 0, aliveBrokersIds.toList, 0) + + // This test either starts with a topic ID in the PartitionFetchState and removes it on the next request (startsWithTopicId) + // or does not start with a topic ID in the PartitionFetchState and adds one on the next request (!startsWithTopicId) + val startingId = if (startsWithTopicId) topicId else Uuid.ZERO_UUID + val startingIdOpt = if (startsWithTopicId) Some(topicId) else None + val leaderAndIsrRequest1 = leaderAndIsrRequest(startingId, tp, aliveBrokersIds, leaderAndIsr) + val leaderAndIsrResponse1 = replicaManager.becomeLeaderOrFollower(0, leaderAndIsrRequest1, (_, _) => ()) + assertEquals(Errors.NONE, leaderAndIsrResponse1.error) + + assertFetcherHasTopicId(replicaManager.replicaFetcherManager, tp, startingIdOpt) + + val endingId = if (!startsWithTopicId) topicId else Uuid.ZERO_UUID + val endingIdOpt = if (!startsWithTopicId) 
Some(topicId) else None + val leaderAndIsrRequest2 = leaderAndIsrRequest(endingId, tp, aliveBrokersIds, leaderAndIsr) + val leaderAndIsrResponse2 = replicaManager.becomeLeaderOrFollower(0, leaderAndIsrRequest2, (_, _) => ()) + assertEquals(Errors.NONE, leaderAndIsrResponse2.error) + + assertFetcherHasTopicId(replicaManager.replicaFetcherManager, tp, endingIdOpt) + } finally { + replicaManager.shutdown(checkpointHW = false) + } + } + + @ParameterizedTest + @ValueSource(booleans = Array(true, false)) + def testReplicaAlterLogDirsWithAndWithoutIds(usesTopicIds: Boolean): Unit = { + val version = if (usesTopicIds) LeaderAndIsrRequestData.HIGHEST_SUPPORTED_VERSION else 4.toShort + val topicId = if (usesTopicIds) this.topicId else Uuid.ZERO_UUID + val topicIdOpt = if (usesTopicIds) Some(topicId) else None + val replicaManager = setupReplicaManagerWithMockedPurgatories(new MockTimer(time)) + try { + val topicPartition = new TopicPartition(topic, 0) + val aliveBrokersIds = Seq(0, 1) + replicaManager.createPartition(topicPartition) + .createLogIfNotExists(isNew = false, isFutureReplica = false, + new LazyOffsetCheckpoints(replicaManager.highWatermarkCheckpoints), None) + val tp = new TopicPartition(topic, 0) + val leaderAndIsr = new LeaderAndIsr(0, 0, aliveBrokersIds.toList, 0) + + val leaderAndIsrRequest1 = leaderAndIsrRequest(topicId, tp, aliveBrokersIds, leaderAndIsr, version = version) + replicaManager.becomeLeaderOrFollower(0, leaderAndIsrRequest1, (_, _) => ()) + val partition = replicaManager.getPartitionOrException(tp) + assertEquals(1, replicaManager.logManager.liveLogDirs.filterNot(_ == partition.log.get.dir.getParentFile).size) + + // Append a couple of messages. + for (i <- 1 to 500) { + val records = TestUtils.singletonRecords(s"message $i".getBytes) + appendRecords(replicaManager, tp, records).onFire { response => + assertEquals(Errors.NONE, response.error) + } + } + + // Find the live and different folder. + val newReplicaFolder = replicaManager.logManager.liveLogDirs.filterNot(_ == partition.log.get.dir.getParentFile).head + assertEquals(0, replicaManager.replicaAlterLogDirsManager.fetcherThreadMap.size) + replicaManager.alterReplicaLogDirs(Map(topicPartition -> newReplicaFolder.getAbsolutePath)) + + assertFetcherHasTopicId(replicaManager.replicaAlterLogDirsManager, partition.topicPartition, topicIdOpt) + + // Make sure the future log is created. + replicaManager.futureLocalLogOrException(topicPartition) + assertEquals(1, replicaManager.replicaAlterLogDirsManager.fetcherThreadMap.size) + + // Wait for the ReplicaAlterLogDirsThread to complete. + TestUtils.waitUntilTrue(() => { + replicaManager.replicaAlterLogDirsManager.shutdownIdleFetcherThreads() + replicaManager.replicaAlterLogDirsManager.fetcherThreadMap.isEmpty + }, s"ReplicaAlterLogDirsThread should be gone") + } finally { + replicaManager.shutdown(checkpointHW = false) + } + } +} diff --git a/core/src/test/scala/unit/kafka/server/ReplicationQuotaManagerTest.scala b/core/src/test/scala/unit/kafka/server/ReplicationQuotaManagerTest.scala new file mode 100644 index 0000000..be07c41 --- /dev/null +++ b/core/src/test/scala/unit/kafka/server/ReplicationQuotaManagerTest.scala @@ -0,0 +1,124 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package kafka.server + +import java.util.Collections + +import kafka.server.QuotaType._ +import org.apache.kafka.common.TopicPartition +import org.apache.kafka.common.metrics.{MetricConfig, Metrics, Quota} +import org.apache.kafka.common.utils.MockTime +import org.junit.jupiter.api.Assertions.{assertEquals, assertFalse, assertTrue} +import org.junit.jupiter.api.{AfterEach, Test} + +import scala.jdk.CollectionConverters._ + +class ReplicationQuotaManagerTest { + private val time = new MockTime + private val metrics = new Metrics(new MetricConfig(), Collections.emptyList(), time) + + @AfterEach + def tearDown(): Unit = { + metrics.close() + } + + @Test + def shouldThrottleOnlyDefinedReplicas(): Unit = { + val quota = new ReplicationQuotaManager(ReplicationQuotaManagerConfig(), metrics, QuotaType.Fetch, time) + quota.markThrottled("topic1", Seq(1, 2, 3)) + + assertTrue(quota.isThrottled(tp1(1))) + assertTrue(quota.isThrottled(tp1(2))) + assertTrue(quota.isThrottled(tp1(3))) + assertFalse(quota.isThrottled(tp1(4))) + } + + @Test + def shouldExceedQuotaThenReturnBackBelowBoundAsTimePasses(): Unit = { + val quota = new ReplicationQuotaManager(ReplicationQuotaManagerConfig(numQuotaSamples = 10, quotaWindowSizeSeconds = 1), metrics, LeaderReplication, time) + + //Given + quota.updateQuota(new Quota(100, true)) + + //Quota should not be broken when we start + assertFalse(quota.isQuotaExceeded) + + //First window is fixed, so we'll skip it + time.sleep(1000) + + //When we record up to the quota value after half a window + time.sleep(500) + quota.record(1) + + //Then it should not break the quota + assertFalse(quota.isQuotaExceeded) + + //When we record half the quota (half way through the window), we still should not break + quota.record(149) //150B, 1.5s + assertFalse(quota.isQuotaExceeded) + + //Add a byte to push over quota + quota.record(1) //151B, 1.5s + + //Then it should break the quota + assertEquals(151 / 1.5, rate(metrics), 0) //151B, 1.5s + assertTrue(quota.isQuotaExceeded) + + //When we sleep for the remaining half the window + time.sleep(500) //151B, 2s + + //Then Our rate should have halved (i.e back down below the quota) + assertFalse(quota.isQuotaExceeded) + assertEquals(151d / 2, rate(metrics), 0.1) //151B, 2s + + //When we sleep for another half a window (now half way through second window) + time.sleep(500) + quota.record(99) //250B, 2.5s + + //Then the rate should be exceeded again + assertEquals(250 / 2.5, rate(metrics), 0) //250B, 2.5s + assertFalse(quota.isQuotaExceeded) + quota.record(1) + assertTrue(quota.isQuotaExceeded) + assertEquals(251 / 2.5, rate(metrics), 0) + + //Sleep for 2 more window + time.sleep(2 * 1000) //so now at 3.5s + assertFalse(quota.isQuotaExceeded) + assertEquals(251d / 4.5, rate(metrics), 0) + } + + def rate(metrics: Metrics): Double = { + val metricName = metrics.metricName("byte-rate", LeaderReplication.toString, "Tracking byte-rate for " + LeaderReplication) + val 
leaderThrottledRate = metrics.metrics.asScala(metricName).metricValue.asInstanceOf[Double] + leaderThrottledRate + } + + @Test + def shouldSupportWildcardThrottledReplicas(): Unit = { + val quota = new ReplicationQuotaManager(ReplicationQuotaManagerConfig(), metrics, LeaderReplication, time) + + //When + quota.markThrottled("MyTopic") + + //Then + assertTrue(quota.isThrottled(new TopicPartition("MyTopic", 0))) + assertFalse(quota.isThrottled(new TopicPartition("MyOtherTopic", 0))) + } + + private def tp1(id: Int): TopicPartition = new TopicPartition("topic1", id) +} diff --git a/core/src/test/scala/unit/kafka/server/ReplicationQuotasTest.scala b/core/src/test/scala/unit/kafka/server/ReplicationQuotasTest.scala new file mode 100644 index 0000000..4b7cae8 --- /dev/null +++ b/core/src/test/scala/unit/kafka/server/ReplicationQuotasTest.scala @@ -0,0 +1,241 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.server + +import java.util.Properties + +import kafka.log.LogConfig._ +import kafka.server.KafkaConfig.fromProps +import kafka.server.QuotaType._ +import kafka.utils.TestUtils._ +import kafka.utils.CoreUtils._ +import kafka.utils.TestUtils +import kafka.server.QuorumTestHarness +import org.apache.kafka.clients.producer.{KafkaProducer, ProducerRecord} +import org.apache.kafka.common.TopicPartition +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.{AfterEach, Test} + +import scala.jdk.CollectionConverters._ + +/** + * This is the main test which ensure Replication Quotas work correctly. + * + * The test will fail if the quota is < 1MB/s as 1MB is the default for replica.fetch.max.bytes. + * So with a throttle of 100KB/s, 1 fetch of 1 partition would fill 10s of quota. In turn causing + * the throttled broker to pause for > 10s + * + * Anything over 100MB/s tends to fail as this is the non-throttled replication rate + */ +class ReplicationQuotasTest extends QuorumTestHarness { + def percentError(percent: Int, value: Long): Long = Math.round(value * percent / 100.0) + + val msg100KB = new Array[Byte](100000) + var brokers: Seq[KafkaServer] = null + val topic = "topic1" + var producer: KafkaProducer[Array[Byte], Array[Byte]] = null + + @AfterEach + override def tearDown(): Unit = { + producer.close() + shutdownServers(brokers) + super.tearDown() + } + + @Test + def shouldBootstrapTwoBrokersWithLeaderThrottle(): Unit = { + shouldMatchQuotaReplicatingThroughAnAsymmetricTopology(true) + } + + @Test + def shouldBootstrapTwoBrokersWithFollowerThrottle(): Unit = { + shouldMatchQuotaReplicatingThroughAnAsymmetricTopology(false) + } + + def shouldMatchQuotaReplicatingThroughAnAsymmetricTopology(leaderThrottle: Boolean): Unit = { + /** + * In short we have 8 brokers, 2 are not-started. 
We assign replicas for the two non-started + * brokers, so when we start them we can monitor replication from the 6 to the 2. + * + * We also have two non-throttled partitions on two of the 6 brokers, just to make sure + * regular replication works as expected. + */ + + brokers = (100 to 105).map { id => createServer(fromProps(createBrokerConfig(id, zkConnect))) } + + //Given six partitions, led on nodes 0,1,2,3,4,5 but with followers on node 6,7 (not started yet) + //And two extra partitions 6,7, which we don't intend on throttling. + val assignment = Map( + 0 -> Seq(100, 106), //Throttled + 1 -> Seq(101, 106), //Throttled + 2 -> Seq(102, 106), //Throttled + 3 -> Seq(103, 107), //Throttled + 4 -> Seq(104, 107), //Throttled + 5 -> Seq(105, 107), //Throttled + 6 -> Seq(100, 106), //Not Throttled + 7 -> Seq(101, 107) //Not Throttled + ) + TestUtils.createTopic(zkClient, topic, assignment, brokers) + + val msg = msg100KB + val msgCount = 100 + val expectedDuration = 10 //Keep the test to N seconds + var throttle: Long = msgCount * msg.length / expectedDuration + if (!leaderThrottle) throttle = throttle * 3 //Follower throttle needs to replicate 3x as fast to get the same duration as there are three replicas to replicate for each of the two follower brokers + + //Set the throttle limit on all 8 brokers, but only assign throttled replicas to the six leaders, or two followers + (100 to 107).foreach { brokerId => + adminZkClient.changeBrokerConfig(Seq(brokerId), + propsWith( + (DynamicConfig.Broker.LeaderReplicationThrottledRateProp, throttle.toString), + (DynamicConfig.Broker.FollowerReplicationThrottledRateProp, throttle.toString) + )) + } + + //Either throttle the six leaders or the two followers + if (leaderThrottle) + adminZkClient.changeTopicConfig(topic, propsWith(LeaderReplicationThrottledReplicasProp, "0:100,1:101,2:102,3:103,4:104,5:105" )) + else + adminZkClient.changeTopicConfig(topic, propsWith(FollowerReplicationThrottledReplicasProp, "0:106,1:106,2:106,3:107,4:107,5:107")) + + //Add data equally to each partition + producer = createProducer(getBrokerListStrFromServers(brokers), acks = 1) + (0 until msgCount).foreach { _ => + (0 to 7).foreach { partition => + producer.send(new ProducerRecord(topic, partition, null, msg)) + } + } + + //Ensure data is fully written: broker 1 has partition 1, broker 2 has partition 2 etc + (0 to 5).foreach { id => waitForOffsetsToMatch(msgCount, id, 100 + id) } + //Check the non-throttled partitions too + waitForOffsetsToMatch(msgCount, 6, 100) + waitForOffsetsToMatch(msgCount, 7, 101) + + val start = System.currentTimeMillis() + + //When we create the 2 new, empty brokers + createBrokers(106 to 107) + + //Check that throttled config correctly migrated to the new brokers + (106 to 107).foreach { brokerId => + assertEquals(throttle, brokerFor(brokerId).quotaManagers.follower.upperBound) + } + if (!leaderThrottle) { + (0 to 2).foreach { partition => assertTrue(brokerFor(106).quotaManagers.follower.isThrottled(tp(partition))) } + (3 to 5).foreach { partition => assertTrue(brokerFor(107).quotaManagers.follower.isThrottled(tp(partition))) } + } + + //Wait for non-throttled partitions to replicate first + (6 to 7).foreach { id => waitForOffsetsToMatch(msgCount, id, 100 + id) } + val unthrottledTook = System.currentTimeMillis() - start + + //Wait for replicas 0,1,2,3,4,5 to fully replicated to broker 106,107 + (0 to 2).foreach { id => waitForOffsetsToMatch(msgCount, id, 106) } + (3 to 5).foreach { id => waitForOffsetsToMatch(msgCount, id, 107) } + + val 
throttledTook = System.currentTimeMillis() - start + + //Check the times for throttled/unthrottled are each side of what we expect + val throttledLowerBound = expectedDuration * 1000 * 0.9 + val throttledUpperBound = expectedDuration * 1000 * 3 + assertTrue(unthrottledTook < throttledLowerBound, s"Expected $unthrottledTook < $throttledLowerBound") + assertTrue(throttledTook > throttledLowerBound, s"Expected $throttledTook > $throttledLowerBound") + assertTrue(throttledTook < throttledUpperBound, s"Expected $throttledTook < $throttledUpperBound") + + // Check the rate metric matches what we expect. + // In a short test the brokers can be read unfairly, so assert against the average + val rateUpperBound = throttle * 1.1 + val rateLowerBound = throttle * 0.5 + val rate = if (leaderThrottle) avRate(LeaderReplication, 100 to 105) else avRate(FollowerReplication, 106 to 107) + assertTrue(rate < rateUpperBound, s"Expected ${rate} < $rateUpperBound") + assertTrue(rate > rateLowerBound, s"Expected ${rate} > $rateLowerBound") + } + + def tp(partition: Int): TopicPartition = new TopicPartition(topic, partition) + + @Test + def shouldThrottleOldSegments(): Unit = { + /** + * Simple test which ensures throttled replication works when the dataset spans many segments + */ + + //2 brokers with 1MB Segment Size & 1 partition + val config: Properties = createBrokerConfig(100, zkConnect) + config.put("log.segment.bytes", (1024 * 1024).toString) + brokers = Seq(createServer(fromProps(config))) + TestUtils.createTopic(zkClient, topic, Map(0 -> Seq(100, 101)), brokers) + + //Write 20MBs and throttle at 5MB/s + val msg = msg100KB + val msgCount: Int = 200 + val expectedDuration = 4 + val throttle: Long = msg.length * msgCount / expectedDuration + + //Set the throttle to only limit leader + adminZkClient.changeBrokerConfig(Seq(100), propsWith(DynamicConfig.Broker.LeaderReplicationThrottledRateProp, throttle.toString)) + adminZkClient.changeTopicConfig(topic, propsWith(LeaderReplicationThrottledReplicasProp, "0:100")) + + //Add data + addData(msgCount, msg) + + //Start the new broker (and hence start replicating) + debug("Starting new broker") + brokers = brokers :+ createServer(fromProps(createBrokerConfig(101, zkConnect))) + val start = System.currentTimeMillis() + + waitForOffsetsToMatch(msgCount, 0, 101) + + val throttledTook = System.currentTimeMillis() - start + + assertTrue(throttledTook > expectedDuration * 1000 * 0.9, + s"Throttled replication of ${throttledTook}ms should be > ${expectedDuration * 1000 * 0.9}ms") + assertTrue(throttledTook < expectedDuration * 1000 * 1.5, + s"Throttled replication of ${throttledTook}ms should be < ${expectedDuration * 1500}ms") + } + + def addData(msgCount: Int, msg: Array[Byte]): Unit = { + producer = createProducer(getBrokerListStrFromServers(brokers), acks = 0) + (0 until msgCount).map(_ => producer.send(new ProducerRecord(topic, msg))).foreach(_.get) + waitForOffsetsToMatch(msgCount, 0, 100) + } + + private def waitForOffsetsToMatch(offset: Int, partitionId: Int, brokerId: Int): Unit = { + waitUntilTrue(() => { + offset == brokerFor(brokerId).getLogManager.getLog(new TopicPartition(topic, partitionId)) + .map(_.logEndOffset).getOrElse(0) + }, s"Offsets did not match for partition $partitionId on broker $brokerId", 60000) + } + + private def brokerFor(id: Int): KafkaServer = brokers.filter(_.config.brokerId == id).head + + def createBrokers(brokerIds: Seq[Int]): Unit = { + brokerIds.foreach { id => + brokers = brokers :+ createServer(fromProps(createBrokerConfig(id, 
zkConnect))) + } + } + + private def avRate(replicationType: QuotaType, brokers: Seq[Int]): Double = { + brokers.map(brokerFor).map(measuredRate(_, replicationType)).sum / brokers.length + } + + private def measuredRate(broker: KafkaServer, repType: QuotaType): Double = { + val metricName = broker.metrics.metricName("byte-rate", repType.toString) + broker.metrics.metrics.asScala(metricName).metricValue.asInstanceOf[Double] + } +} diff --git a/core/src/test/scala/unit/kafka/server/RequestQuotaTest.scala b/core/src/test/scala/unit/kafka/server/RequestQuotaTest.scala new file mode 100644 index 0000000..ddbc987 --- /dev/null +++ b/core/src/test/scala/unit/kafka/server/RequestQuotaTest.scala @@ -0,0 +1,794 @@ +/** + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + **/ + +package kafka.server + +import java.net.InetAddress +import java.util +import java.util.concurrent.{Executors, Future, TimeUnit} +import java.util.{Collections, Optional, Properties} +import kafka.api.LeaderAndIsr +import kafka.log.LogConfig +import kafka.network.RequestChannel.Session +import kafka.security.authorizer.AclAuthorizer +import kafka.utils.TestUtils +import org.apache.kafka.common.acl._ +import org.apache.kafka.common.config.ConfigResource +import org.apache.kafka.common.message.CreatePartitionsRequestData.CreatePartitionsTopic +import org.apache.kafka.common.message.CreateTopicsRequestData.{CreatableTopic, CreatableTopicCollection} +import org.apache.kafka.common.message.JoinGroupRequestData.JoinGroupRequestProtocolCollection +import org.apache.kafka.common.message.LeaderAndIsrRequestData.LeaderAndIsrPartitionState +import org.apache.kafka.common.message.LeaveGroupRequestData.MemberIdentity +import org.apache.kafka.common.message.ListOffsetsRequestData.{ListOffsetsPartition, ListOffsetsTopic} +import org.apache.kafka.common.message.OffsetForLeaderEpochRequestData.OffsetForLeaderPartition +import org.apache.kafka.common.message.OffsetForLeaderEpochRequestData.OffsetForLeaderTopic +import org.apache.kafka.common.message.OffsetForLeaderEpochRequestData.OffsetForLeaderTopicCollection +import org.apache.kafka.common.message.StopReplicaRequestData.{StopReplicaPartitionState, StopReplicaTopicState} +import org.apache.kafka.common.message.UpdateMetadataRequestData.{UpdateMetadataBroker, UpdateMetadataEndpoint, UpdateMetadataPartitionState} +import org.apache.kafka.common.message.{AddOffsetsToTxnRequestData, _} +import org.apache.kafka.common.metrics.{KafkaMetric, Quota, Sensor} +import org.apache.kafka.common.network.ListenerName +import org.apache.kafka.common.protocol.ApiKeys +import org.apache.kafka.common.quota.ClientQuotaFilter +import org.apache.kafka.common.record._ +import org.apache.kafka.common.requests._ +import org.apache.kafka.common.resource.{PatternType, ResourceType => AdminResourceType} +import org.apache.kafka.common.security.auth._ +import org.apache.kafka.common.utils.{Sanitizer, SecurityUtils} +import org.apache.kafka.common._ +import org.apache.kafka.common.config.internals.QuotaConfigs +import 
org.apache.kafka.server.authorizer.{Action, AuthorizableRequestContext, AuthorizationResult} +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.{AfterEach, BeforeEach, Test, TestInfo} + +import scala.collection.mutable.ListBuffer +import scala.jdk.CollectionConverters._ + +class RequestQuotaTest extends BaseRequestTest { + + override def brokerCount: Int = 1 + + private val topic = "topic-1" + private val numPartitions = 1 + private val tp = new TopicPartition(topic, 0) + private val logDir = "logDir" + private val unthrottledClientId = "unthrottled-client" + private val smallQuotaProducerClientId = "small-quota-producer-client" + private val smallQuotaConsumerClientId = "small-quota-consumer-client" + private val brokerId: Integer = 0 + private var leaderNode: KafkaServer = _ + + // Run tests concurrently since a throttle could be up to 1 second because quota percentage allocated is very low + case class Task(apiKey: ApiKeys, future: Future[_]) + private val executor = Executors.newCachedThreadPool + private val tasks = new ListBuffer[Task] + + override def brokerPropertyOverrides(properties: Properties): Unit = { + properties.put(KafkaConfig.ControlledShutdownEnableProp, "false") + properties.put(KafkaConfig.OffsetsTopicReplicationFactorProp, "1") + properties.put(KafkaConfig.OffsetsTopicPartitionsProp, "1") + properties.put(KafkaConfig.GroupMinSessionTimeoutMsProp, "100") + properties.put(KafkaConfig.GroupInitialRebalanceDelayMsProp, "0") + properties.put(KafkaConfig.AuthorizerClassNameProp, classOf[RequestQuotaTest.TestAuthorizer].getName) + properties.put(KafkaConfig.PrincipalBuilderClassProp, classOf[RequestQuotaTest.TestPrincipalBuilder].getName) + } + + @BeforeEach + override def setUp(testInfo: TestInfo): Unit = { + RequestQuotaTest.principal = KafkaPrincipal.ANONYMOUS + super.setUp(testInfo) + + createTopic(topic, numPartitions) + leaderNode = servers.head + + // Change default client-id request quota to a small value and a single unthrottledClient with a large quota + val quotaProps = new Properties() + quotaProps.put(QuotaConfigs.REQUEST_PERCENTAGE_OVERRIDE_CONFIG, "0.01") + quotaProps.put(QuotaConfigs.PRODUCER_BYTE_RATE_OVERRIDE_CONFIG, "2000") + quotaProps.put(QuotaConfigs.CONSUMER_BYTE_RATE_OVERRIDE_CONFIG, "2000") + adminZkClient.changeClientIdConfig("", quotaProps) + quotaProps.put(QuotaConfigs.REQUEST_PERCENTAGE_OVERRIDE_CONFIG, "2000") + adminZkClient.changeClientIdConfig(Sanitizer.sanitize(unthrottledClientId), quotaProps) + + // Client ids with small producer and consumer (fetch) quotas. Quota values were picked so that both + // producer/consumer and request quotas are violated on the first produce/consume operation, and the delay due to + // producer/consumer quota violation will be longer than the delay due to request quota violation. 
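+    // (Illustrative reasoning, not an exact figure: even a single small produce or fetch moves many
+    // times more data than the 1 byte/s byte-rate quota configured below allows, so its throttle
+    // delay is large, while the overshoot against the 0.01% request-time quota is comparatively small.)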
+ quotaProps.put(QuotaConfigs.PRODUCER_BYTE_RATE_OVERRIDE_CONFIG, "1") + quotaProps.put(QuotaConfigs.REQUEST_PERCENTAGE_OVERRIDE_CONFIG, "0.01") + adminZkClient.changeClientIdConfig(Sanitizer.sanitize(smallQuotaProducerClientId), quotaProps) + quotaProps.put(QuotaConfigs.CONSUMER_BYTE_RATE_OVERRIDE_CONFIG, "1") + quotaProps.put(QuotaConfigs.REQUEST_PERCENTAGE_OVERRIDE_CONFIG, "0.01") + adminZkClient.changeClientIdConfig(Sanitizer.sanitize(smallQuotaConsumerClientId), quotaProps) + + TestUtils.retry(20000) { + val quotaManager = servers.head.dataPlaneRequestProcessor.quotas.request + assertEquals(Quota.upperBound(0.01), quotaManager.quota("some-user", "some-client"), s"Default request quota not set") + assertEquals(Quota.upperBound(2000), quotaManager.quota("some-user", unthrottledClientId), s"Request quota override not set") + val produceQuotaManager = servers.head.dataPlaneRequestProcessor.quotas.produce + assertEquals(Quota.upperBound(1), produceQuotaManager.quota("some-user", smallQuotaProducerClientId), s"Produce quota override not set") + val consumeQuotaManager = servers.head.dataPlaneRequestProcessor.quotas.fetch + assertEquals(Quota.upperBound(1), consumeQuotaManager.quota("some-user", smallQuotaConsumerClientId), s"Consume quota override not set") + } + } + + @AfterEach + override def tearDown(): Unit = { + try executor.shutdownNow() + finally super.tearDown() + } + + @Test + def testResponseThrottleTime(): Unit = { + for (apiKey <- RequestQuotaTest.ClientActions ++ RequestQuotaTest.ClusterActionsWithThrottle) + submitTest(apiKey, () => checkRequestThrottleTime(apiKey)) + + waitAndCheckResults() + } + + @Test + def testResponseThrottleTimeWhenBothProduceAndRequestQuotasViolated(): Unit = { + submitTest(ApiKeys.PRODUCE, () => checkSmallQuotaProducerRequestThrottleTime()) + waitAndCheckResults() + } + + @Test + def testResponseThrottleTimeWhenBothFetchAndRequestQuotasViolated(): Unit = { + submitTest(ApiKeys.FETCH, () => checkSmallQuotaConsumerRequestThrottleTime()) + waitAndCheckResults() + } + + @Test + def testUnthrottledClient(): Unit = { + for (apiKey <- RequestQuotaTest.ClientActions) { + submitTest(apiKey, () => checkUnthrottledClient(apiKey)) + } + + waitAndCheckResults() + } + + @Test + def testExemptRequestTime(): Unit = { + for (apiKey <- RequestQuotaTest.ClusterActions -- RequestQuotaTest.ClusterActionsWithThrottle) { + submitTest(apiKey, () => checkExemptRequestMetric(apiKey)) + } + + waitAndCheckResults() + } + + @Test + def testUnauthorizedThrottle(): Unit = { + RequestQuotaTest.principal = RequestQuotaTest.UnauthorizedPrincipal + + for (apiKey <- ApiKeys.zkBrokerApis.asScala) { + submitTest(apiKey, () => checkUnauthorizedRequestThrottle(apiKey)) + } + + waitAndCheckResults() + } + + def session(user: String): Session = Session(new KafkaPrincipal(KafkaPrincipal.USER_TYPE, user), null) + + private def throttleTimeMetricValue(clientId: String): Double = { + throttleTimeMetricValueForQuotaType(clientId, QuotaType.Request) + } + + private def throttleTimeMetricValueForQuotaType(clientId: String, quotaType: QuotaType): Double = { + val metricName = leaderNode.metrics.metricName("throttle-time", quotaType.toString, + "", "user", "", "client-id", clientId) + val sensor = leaderNode.quotaManagers.request.getOrCreateQuotaSensors(session("ANONYMOUS"), + clientId).throttleTimeSensor + metricValue(leaderNode.metrics.metrics.get(metricName), sensor) + } + + private def requestTimeMetricValue(clientId: String): Double = { + val metricName = 
leaderNode.metrics.metricName("request-time", QuotaType.Request.toString, + "", "user", "", "client-id", clientId) + val sensor = leaderNode.quotaManagers.request.getOrCreateQuotaSensors(session("ANONYMOUS"), + clientId).quotaSensor + metricValue(leaderNode.metrics.metrics.get(metricName), sensor) + } + + private def exemptRequestMetricValue: Double = { + val metricName = leaderNode.metrics.metricName("exempt-request-time", QuotaType.Request.toString, "") + metricValue(leaderNode.metrics.metrics.get(metricName), leaderNode.quotaManagers.request.exemptSensor) + } + + private def metricValue(metric: KafkaMetric, sensor: Sensor): Double = { + sensor.synchronized { + if (metric == null) -1.0 else metric.metricValue.asInstanceOf[Double] + } + } + + private def requestBuilder(apiKey: ApiKeys): AbstractRequest.Builder[_ <: AbstractRequest] = { + apiKey match { + case ApiKeys.PRODUCE => + requests.ProduceRequest.forCurrentMagic(new ProduceRequestData() + .setTopicData(new ProduceRequestData.TopicProduceDataCollection( + Collections.singletonList(new ProduceRequestData.TopicProduceData() + .setName(tp.topic()).setPartitionData(Collections.singletonList( + new ProduceRequestData.PartitionProduceData() + .setIndex(tp.partition()) + .setRecords(MemoryRecords.withRecords(CompressionType.NONE, new SimpleRecord("test".getBytes)))))) + .iterator)) + .setAcks(1.toShort) + .setTimeoutMs(5000)) + + case ApiKeys.FETCH => + val partitionMap = new util.LinkedHashMap[TopicPartition, FetchRequest.PartitionData] + partitionMap.put(tp, new FetchRequest.PartitionData(getTopicIds().getOrElse(tp.topic, Uuid.ZERO_UUID), 0, 0, 100, Optional.of(15))) + FetchRequest.Builder.forConsumer(ApiKeys.FETCH.latestVersion, 0, 0, partitionMap) + + case ApiKeys.METADATA => + new MetadataRequest.Builder(List(topic).asJava, true) + + case ApiKeys.LIST_OFFSETS => + val topic = new ListOffsetsTopic() + .setName(tp.topic) + .setPartitions(List(new ListOffsetsPartition() + .setPartitionIndex(tp.partition) + .setTimestamp(0L) + .setCurrentLeaderEpoch(15)).asJava) + ListOffsetsRequest.Builder.forConsumer(false, IsolationLevel.READ_UNCOMMITTED, false) + .setTargetTimes(List(topic).asJava) + + case ApiKeys.LEADER_AND_ISR => + new LeaderAndIsrRequest.Builder(ApiKeys.LEADER_AND_ISR.latestVersion, brokerId, Int.MaxValue, Long.MaxValue, + Seq(new LeaderAndIsrPartitionState() + .setTopicName(tp.topic) + .setPartitionIndex(tp.partition) + .setControllerEpoch(Int.MaxValue) + .setLeader(brokerId) + .setLeaderEpoch(Int.MaxValue) + .setIsr(List(brokerId).asJava) + .setZkVersion(2) + .setReplicas(Seq(brokerId).asJava) + .setIsNew(true)).asJava, + getTopicIds().asJava, + Set(new Node(brokerId, "localhost", 0)).asJava) + + case ApiKeys.STOP_REPLICA => + val topicStates = Seq( + new StopReplicaTopicState() + .setTopicName(tp.topic()) + .setPartitionStates(Seq(new StopReplicaPartitionState() + .setPartitionIndex(tp.partition()) + .setLeaderEpoch(LeaderAndIsr.initialLeaderEpoch + 2) + .setDeletePartition(true)).asJava) + ).asJava + new StopReplicaRequest.Builder(ApiKeys.STOP_REPLICA.latestVersion, brokerId, + Int.MaxValue, Long.MaxValue, false, topicStates) + + case ApiKeys.UPDATE_METADATA => + val partitionState = Seq(new UpdateMetadataPartitionState() + .setTopicName(tp.topic) + .setPartitionIndex(tp.partition) + .setControllerEpoch(Int.MaxValue) + .setLeader(brokerId) + .setLeaderEpoch(Int.MaxValue) + .setIsr(List(brokerId).asJava) + .setZkVersion(2) + .setReplicas(Seq(brokerId).asJava)).asJava + val securityProtocol = SecurityProtocol.PLAINTEXT + val 
brokers = Seq(new UpdateMetadataBroker() + .setId(brokerId) + .setEndpoints(Seq(new UpdateMetadataEndpoint() + .setHost("localhost") + .setPort(0) + .setSecurityProtocol(securityProtocol.id) + .setListener(ListenerName.forSecurityProtocol(securityProtocol).value)).asJava)).asJava + new UpdateMetadataRequest.Builder(ApiKeys.UPDATE_METADATA.latestVersion, brokerId, Int.MaxValue, Long.MaxValue, + partitionState, brokers, Collections.emptyMap()) + + case ApiKeys.CONTROLLED_SHUTDOWN => + new ControlledShutdownRequest.Builder( + new ControlledShutdownRequestData() + .setBrokerId(brokerId) + .setBrokerEpoch(Long.MaxValue), + ApiKeys.CONTROLLED_SHUTDOWN.latestVersion) + + case ApiKeys.OFFSET_COMMIT => + new OffsetCommitRequest.Builder( + new OffsetCommitRequestData() + .setGroupId("test-group") + .setGenerationId(1) + .setMemberId(JoinGroupRequest.UNKNOWN_MEMBER_ID) + .setTopics( + Collections.singletonList( + new OffsetCommitRequestData.OffsetCommitRequestTopic() + .setName(topic) + .setPartitions( + Collections.singletonList( + new OffsetCommitRequestData.OffsetCommitRequestPartition() + .setPartitionIndex(0) + .setCommittedLeaderEpoch(RecordBatch.NO_PARTITION_LEADER_EPOCH) + .setCommittedOffset(0) + .setCommittedMetadata("metadata") + ) + ) + ) + ) + ) + case ApiKeys.OFFSET_FETCH => + new OffsetFetchRequest.Builder("test-group", false, List(tp).asJava, false) + + case ApiKeys.FIND_COORDINATOR => + new FindCoordinatorRequest.Builder( + new FindCoordinatorRequestData() + .setKeyType(FindCoordinatorRequest.CoordinatorType.GROUP.id) + .setCoordinatorKeys(Collections.singletonList("test-group"))) + + case ApiKeys.JOIN_GROUP => + new JoinGroupRequest.Builder( + new JoinGroupRequestData() + .setGroupId("test-join-group") + .setSessionTimeoutMs(200) + .setMemberId(JoinGroupRequest.UNKNOWN_MEMBER_ID) + .setGroupInstanceId(null) + .setProtocolType("consumer") + .setProtocols( + new JoinGroupRequestProtocolCollection( + Collections.singletonList(new JoinGroupRequestData.JoinGroupRequestProtocol() + .setName("consumer-range") + .setMetadata("test".getBytes())).iterator() + ) + ) + .setRebalanceTimeoutMs(100) + ) + + case ApiKeys.HEARTBEAT => + new HeartbeatRequest.Builder( + new HeartbeatRequestData() + .setGroupId("test-group") + .setGenerationId(1) + .setMemberId(JoinGroupRequest.UNKNOWN_MEMBER_ID) + ) + + case ApiKeys.LEAVE_GROUP => + new LeaveGroupRequest.Builder( + "test-leave-group", + Collections.singletonList( + new MemberIdentity() + .setMemberId(JoinGroupRequest.UNKNOWN_MEMBER_ID)) + ) + + case ApiKeys.SYNC_GROUP => + new SyncGroupRequest.Builder( + new SyncGroupRequestData() + .setGroupId("test-sync-group") + .setGenerationId(1) + .setMemberId(JoinGroupRequest.UNKNOWN_MEMBER_ID) + .setAssignments(Collections.emptyList()) + ) + + case ApiKeys.DESCRIBE_GROUPS => + new DescribeGroupsRequest.Builder(new DescribeGroupsRequestData().setGroups(List("test-group").asJava)) + + case ApiKeys.LIST_GROUPS => + new ListGroupsRequest.Builder(new ListGroupsRequestData()) + + case ApiKeys.SASL_HANDSHAKE => + new SaslHandshakeRequest.Builder(new SaslHandshakeRequestData().setMechanism("PLAIN")) + + case ApiKeys.SASL_AUTHENTICATE => + new SaslAuthenticateRequest.Builder(new SaslAuthenticateRequestData().setAuthBytes(new Array[Byte](0))) + + case ApiKeys.API_VERSIONS => + new ApiVersionsRequest.Builder() + + case ApiKeys.CREATE_TOPICS => + new CreateTopicsRequest.Builder( + new CreateTopicsRequestData().setTopics( + new CreatableTopicCollection(Collections.singleton( + new 
CreatableTopic().setName("topic-2").setNumPartitions(1). + setReplicationFactor(1.toShort)).iterator()))) + + case ApiKeys.DELETE_TOPICS => + new DeleteTopicsRequest.Builder( + new DeleteTopicsRequestData() + .setTopicNames(Collections.singletonList("topic-2")) + .setTimeoutMs(5000)) + + case ApiKeys.DELETE_RECORDS => + new DeleteRecordsRequest.Builder( + new DeleteRecordsRequestData() + .setTimeoutMs(5000) + .setTopics(Collections.singletonList(new DeleteRecordsRequestData.DeleteRecordsTopic() + .setName(tp.topic()) + .setPartitions(Collections.singletonList(new DeleteRecordsRequestData.DeleteRecordsPartition() + .setPartitionIndex(tp.partition()) + .setOffset(0L)))))) + + case ApiKeys.INIT_PRODUCER_ID => + val requestData = new InitProducerIdRequestData() + .setTransactionalId("test-transactional-id") + .setTransactionTimeoutMs(5000) + new InitProducerIdRequest.Builder(requestData) + + case ApiKeys.OFFSET_FOR_LEADER_EPOCH => + val epochs = new OffsetForLeaderTopicCollection() + epochs.add(new OffsetForLeaderTopic() + .setTopic(tp.topic()) + .setPartitions(List(new OffsetForLeaderPartition() + .setPartition(tp.partition()) + .setLeaderEpoch(0) + .setCurrentLeaderEpoch(15)).asJava)) + OffsetsForLeaderEpochRequest.Builder.forConsumer(epochs) + + case ApiKeys.ADD_PARTITIONS_TO_TXN => + new AddPartitionsToTxnRequest.Builder("test-transactional-id", 1, 0, List(tp).asJava) + + case ApiKeys.ADD_OFFSETS_TO_TXN => + new AddOffsetsToTxnRequest.Builder(new AddOffsetsToTxnRequestData() + .setTransactionalId("test-transactional-id") + .setProducerId(1) + .setProducerEpoch(0) + .setGroupId("test-txn-group") + ) + + case ApiKeys.END_TXN => + new EndTxnRequest.Builder(new EndTxnRequestData() + .setTransactionalId("test-transactional-id") + .setProducerId(1) + .setProducerEpoch(0) + .setCommitted(false) + ) + + case ApiKeys.WRITE_TXN_MARKERS => + new WriteTxnMarkersRequest.Builder(ApiKeys.WRITE_TXN_MARKERS.latestVersion(), List.empty.asJava) + + case ApiKeys.TXN_OFFSET_COMMIT => + new TxnOffsetCommitRequest.Builder( + "test-transactional-id", + "test-txn-group", + 2, + 0, + Map.empty[TopicPartition, TxnOffsetCommitRequest.CommittedOffset].asJava + ) + + case ApiKeys.DESCRIBE_ACLS => + new DescribeAclsRequest.Builder(AclBindingFilter.ANY) + + case ApiKeys.CREATE_ACLS => + new CreateAclsRequest.Builder(new CreateAclsRequestData().setCreations(Collections.singletonList( + new CreateAclsRequestData.AclCreation() + .setResourceType(AdminResourceType.TOPIC.code) + .setResourceName("mytopic") + .setResourcePatternType(PatternType.LITERAL.code) + .setPrincipal("User:ANONYMOUS") + .setHost("*") + .setOperation(AclOperation.WRITE.code) + .setPermissionType(AclPermissionType.DENY.code)))) + case ApiKeys.DELETE_ACLS => + new DeleteAclsRequest.Builder(new DeleteAclsRequestData().setFilters(Collections.singletonList( + new DeleteAclsRequestData.DeleteAclsFilter() + .setResourceTypeFilter(AdminResourceType.TOPIC.code) + .setResourceNameFilter(null) + .setPatternTypeFilter(PatternType.LITERAL.code) + .setPrincipalFilter("User:ANONYMOUS") + .setHostFilter("*") + .setOperation(AclOperation.ANY.code) + .setPermissionType(AclPermissionType.DENY.code)))) + case ApiKeys.DESCRIBE_CONFIGS => + new DescribeConfigsRequest.Builder(new DescribeConfigsRequestData() + .setResources(Collections.singletonList(new DescribeConfigsRequestData.DescribeConfigsResource() + .setResourceType(ConfigResource.Type.TOPIC.id) + .setResourceName(tp.topic)))) + + case ApiKeys.ALTER_CONFIGS => + new AlterConfigsRequest.Builder( + 
Collections.singletonMap(new ConfigResource(ConfigResource.Type.TOPIC, tp.topic), + new AlterConfigsRequest.Config(Collections.singleton( + new AlterConfigsRequest.ConfigEntry(LogConfig.MaxMessageBytesProp, "1000000") + ))), true) + + case ApiKeys.ALTER_REPLICA_LOG_DIRS => + val dir = new AlterReplicaLogDirsRequestData.AlterReplicaLogDir() + .setPath(logDir) + dir.topics.add(new AlterReplicaLogDirsRequestData.AlterReplicaLogDirTopic() + .setName(tp.topic) + .setPartitions(Collections.singletonList(tp.partition))) + val data = new AlterReplicaLogDirsRequestData(); + data.dirs.add(dir) + new AlterReplicaLogDirsRequest.Builder(data) + + case ApiKeys.DESCRIBE_LOG_DIRS => + val data = new DescribeLogDirsRequestData() + data.topics.add(new DescribeLogDirsRequestData.DescribableLogDirTopic() + .setTopic(tp.topic) + .setPartitions(Collections.singletonList(tp.partition))) + new DescribeLogDirsRequest.Builder(data) + + case ApiKeys.CREATE_PARTITIONS => + val data = new CreatePartitionsRequestData() + .setTimeoutMs(0) + .setValidateOnly(false) + data.topics().add(new CreatePartitionsTopic().setName("topic-2").setCount(1)) + new CreatePartitionsRequest.Builder(data) + + case ApiKeys.CREATE_DELEGATION_TOKEN => + new CreateDelegationTokenRequest.Builder( + new CreateDelegationTokenRequestData() + .setRenewers(Collections.singletonList(new CreateDelegationTokenRequestData.CreatableRenewers() + .setPrincipalType("User") + .setPrincipalName("test"))) + .setMaxLifetimeMs(1000) + ) + + case ApiKeys.EXPIRE_DELEGATION_TOKEN => + new ExpireDelegationTokenRequest.Builder( + new ExpireDelegationTokenRequestData() + .setHmac("".getBytes) + .setExpiryTimePeriodMs(1000L)) + + case ApiKeys.DESCRIBE_DELEGATION_TOKEN => + new DescribeDelegationTokenRequest.Builder(Collections.singletonList(SecurityUtils.parseKafkaPrincipal("User:test"))) + + case ApiKeys.RENEW_DELEGATION_TOKEN => + new RenewDelegationTokenRequest.Builder( + new RenewDelegationTokenRequestData() + .setHmac("".getBytes) + .setRenewPeriodMs(1000L)) + + case ApiKeys.DELETE_GROUPS => + new DeleteGroupsRequest.Builder(new DeleteGroupsRequestData() + .setGroupsNames(Collections.singletonList("test-group"))) + + case ApiKeys.ELECT_LEADERS => + new ElectLeadersRequest.Builder( + ElectionType.PREFERRED, + Collections.singletonList(new TopicPartition("my_topic", 0)), + 0 + ) + + case ApiKeys.INCREMENTAL_ALTER_CONFIGS => + new IncrementalAlterConfigsRequest.Builder( + new IncrementalAlterConfigsRequestData()) + + case ApiKeys.ALTER_PARTITION_REASSIGNMENTS => + new AlterPartitionReassignmentsRequest.Builder( + new AlterPartitionReassignmentsRequestData() + ) + + case ApiKeys.LIST_PARTITION_REASSIGNMENTS => + new ListPartitionReassignmentsRequest.Builder( + new ListPartitionReassignmentsRequestData() + ) + + case ApiKeys.OFFSET_DELETE => + new OffsetDeleteRequest.Builder( + new OffsetDeleteRequestData() + .setGroupId("test-group") + .setTopics(new OffsetDeleteRequestData.OffsetDeleteRequestTopicCollection( + Collections.singletonList(new OffsetDeleteRequestData.OffsetDeleteRequestTopic() + .setName("test-topic") + .setPartitions(Collections.singletonList( + new OffsetDeleteRequestData.OffsetDeleteRequestPartition() + .setPartitionIndex(0)))).iterator()))) + + case ApiKeys.DESCRIBE_CLIENT_QUOTAS => + new DescribeClientQuotasRequest.Builder(ClientQuotaFilter.all()) + + case ApiKeys.ALTER_CLIENT_QUOTAS => + new AlterClientQuotasRequest.Builder(List.empty.asJava, false) + + case ApiKeys.DESCRIBE_USER_SCRAM_CREDENTIALS => + new 
DescribeUserScramCredentialsRequest.Builder(new DescribeUserScramCredentialsRequestData()) + + case ApiKeys.ALTER_USER_SCRAM_CREDENTIALS => + new AlterUserScramCredentialsRequest.Builder(new AlterUserScramCredentialsRequestData()) + + case ApiKeys.VOTE => + new VoteRequest.Builder(VoteRequest.singletonRequest(tp, 1, 2, 0, 10)) + + case ApiKeys.BEGIN_QUORUM_EPOCH => + new BeginQuorumEpochRequest.Builder(BeginQuorumEpochRequest.singletonRequest(tp, 2, 5)) + + case ApiKeys.END_QUORUM_EPOCH => + new EndQuorumEpochRequest.Builder(EndQuorumEpochRequest.singletonRequest( + tp, 10, 5, Collections.singletonList(3))) + + case ApiKeys.ALTER_ISR => + new AlterIsrRequest.Builder(new AlterIsrRequestData()) + + case ApiKeys.UPDATE_FEATURES => + new UpdateFeaturesRequest.Builder(new UpdateFeaturesRequestData()) + + case ApiKeys.ENVELOPE => + val requestHeader = new RequestHeader( + ApiKeys.ALTER_CLIENT_QUOTAS, + ApiKeys.ALTER_CLIENT_QUOTAS.latestVersion, + "client-id", + 0 + ) + val embedRequestData = new AlterClientQuotasRequest.Builder(List.empty.asJava, false).build() + .serializeWithHeader(requestHeader) + new EnvelopeRequest.Builder(embedRequestData, new Array[Byte](0), + InetAddress.getByName("192.168.1.1").getAddress) + + case ApiKeys.DESCRIBE_CLUSTER => + new DescribeClusterRequest.Builder(new DescribeClusterRequestData()) + + case ApiKeys.DESCRIBE_PRODUCERS => + new DescribeProducersRequest.Builder(new DescribeProducersRequestData() + .setTopics(List(new DescribeProducersRequestData.TopicRequest() + .setName("test-topic") + .setPartitionIndexes(List(1, 2, 3).map(Int.box).asJava)).asJava)) + + case ApiKeys.BROKER_REGISTRATION => + new BrokerRegistrationRequest.Builder(new BrokerRegistrationRequestData()) + + case ApiKeys.BROKER_HEARTBEAT => + new BrokerHeartbeatRequest.Builder(new BrokerHeartbeatRequestData()) + + case ApiKeys.UNREGISTER_BROKER => + new UnregisterBrokerRequest.Builder(new UnregisterBrokerRequestData()) + + case ApiKeys.DESCRIBE_TRANSACTIONS => + new DescribeTransactionsRequest.Builder(new DescribeTransactionsRequestData() + .setTransactionalIds(List("test-transactional-id").asJava)) + + case ApiKeys.LIST_TRANSACTIONS => + new ListTransactionsRequest.Builder(new ListTransactionsRequestData()) + + case ApiKeys.ALLOCATE_PRODUCER_IDS => + new AllocateProducerIdsRequest.Builder(new AllocateProducerIdsRequestData()) + case _ => + throw new IllegalArgumentException("Unsupported API key " + apiKey) + } + } + + case class Client(clientId: String, apiKey: ApiKeys) { + var correlationId: Int = 0 + def runUntil(until: AbstractResponse => Boolean): Boolean = { + val startMs = System.currentTimeMillis + var done = false + val socket = connect() + try { + while (!done && System.currentTimeMillis < startMs + 10000) { + correlationId += 1 + val request = requestBuilder(apiKey).build() + val response = sendAndReceive[AbstractResponse](request, socket, clientId, Some(correlationId)) + done = until.apply(response) + } + } finally { + socket.close() + } + done + } + + override def toString: String = { + val requestTime = requestTimeMetricValue(clientId) + val throttleTime = throttleTimeMetricValue(clientId) + val produceThrottleTime = throttleTimeMetricValueForQuotaType(clientId, QuotaType.Produce) + val consumeThrottleTime = throttleTimeMetricValueForQuotaType(clientId, QuotaType.Fetch) + s"Client $clientId apiKey $apiKey requests $correlationId requestTime $requestTime " + + s"throttleTime $throttleTime produceThrottleTime $produceThrottleTime consumeThrottleTime $consumeThrottleTime" + } + } + + 
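+  // Illustrative use of the Client helper above (the throttle checks below all follow this pattern):
+  // runUntil opens a connection and keeps rebuilding and resending the request for the given api key
+  // until the supplied predicate on the response holds, or ~10 seconds elapse, e.g.
+  //
+  //   val client = Client(clientId = "some-client", apiKey = ApiKeys.METADATA)
+  //   val throttled = client.runUntil(_.throttleTimeMs > 0)
+  //
+  // correlationId counts the requests sent and toString reports the relevant quota metrics, which the
+  // assertion messages below interpolate for diagnostics.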
private def submitTest(apiKey: ApiKeys, test: () => Unit): Unit = { + val future = executor.submit(new Runnable() { + def run(): Unit = { + test.apply() + } + }) + tasks += Task(apiKey, future) + } + + private def waitAndCheckResults(): Unit = { + for (task <- tasks) { + try { + task.future.get(15, TimeUnit.SECONDS) + } catch { + case e: Throwable => + error(s"Test failed for api-key ${task.apiKey} with exception $e") + throw e + } + } + } + + private def checkRequestThrottleTime(apiKey: ApiKeys): Unit = { + // Request until throttled using client-id with default small quota + val clientId = apiKey.toString + val client = Client(clientId, apiKey) + val throttled = client.runUntil(_.throttleTimeMs > 0) + + assertTrue(throttled, s"Response not throttled: $client") + assertTrue(throttleTimeMetricValue(clientId) > 0 , s"Throttle time metrics not updated: $client") + } + + private def checkSmallQuotaProducerRequestThrottleTime(): Unit = { + + // Request until throttled using client-id with default small producer quota + val smallQuotaProducerClient = Client(smallQuotaProducerClientId, ApiKeys.PRODUCE) + val throttled = smallQuotaProducerClient.runUntil(_.throttleTimeMs > 0) + + assertTrue(throttled, s"Response not throttled: $smallQuotaProducerClient") + assertTrue(throttleTimeMetricValueForQuotaType(smallQuotaProducerClientId, QuotaType.Produce) > 0, + s"Throttle time metrics for produce quota not updated: $smallQuotaProducerClient") + assertTrue(throttleTimeMetricValueForQuotaType(smallQuotaProducerClientId, QuotaType.Request).isNaN, + s"Throttle time metrics for request quota updated: $smallQuotaProducerClient") + } + + private def checkSmallQuotaConsumerRequestThrottleTime(): Unit = { + + // Request until throttled using client-id with default small consumer quota + val smallQuotaConsumerClient = Client(smallQuotaConsumerClientId, ApiKeys.FETCH) + val throttled = smallQuotaConsumerClient.runUntil(_.throttleTimeMs > 0) + + assertTrue(throttled, s"Response not throttled: $smallQuotaConsumerClientId") + assertTrue(throttleTimeMetricValueForQuotaType(smallQuotaConsumerClientId, QuotaType.Fetch) > 0, + s"Throttle time metrics for consumer quota not updated: $smallQuotaConsumerClient") + assertTrue(throttleTimeMetricValueForQuotaType(smallQuotaConsumerClientId, QuotaType.Request).isNaN, + s"Throttle time metrics for request quota updated: $smallQuotaConsumerClient") + } + + private def checkUnthrottledClient(apiKey: ApiKeys): Unit = { + + // Test that request from client with large quota is not throttled + val unthrottledClient = Client(unthrottledClientId, apiKey) + unthrottledClient.runUntil(_.throttleTimeMs <= 0.0) + assertEquals(1, unthrottledClient.correlationId) + assertTrue(throttleTimeMetricValue(unthrottledClientId).isNaN, s"Client should not have been throttled: $unthrottledClient") + } + + private def checkExemptRequestMetric(apiKey: ApiKeys): Unit = { + val exemptTarget = exemptRequestMetricValue + 0.02 + val clientId = apiKey.toString + val client = Client(clientId, apiKey) + val updated = client.runUntil(_ => exemptRequestMetricValue > exemptTarget) + + assertTrue(updated, s"Exempt-request-time metric not updated: $client") + assertTrue(throttleTimeMetricValue(clientId).isNaN, s"Client should not have been throttled: $client") + } + + private def checkUnauthorizedRequestThrottle(apiKey: ApiKeys): Unit = { + val clientId = "unauthorized-" + apiKey.toString + val client = Client(clientId, apiKey) + val throttled = client.runUntil(_ => throttleTimeMetricValue(clientId) > 0.0) + 
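+    // Note that the predicate ignores the response payload and instead polls the broker-side
+    // throttle-time metric: even though authorization fails for this principal, the request is still
+    // expected to be charged against the client's request quota.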
assertTrue(throttled, s"Unauthorized client should have been throttled: $client") + } +} + +object RequestQuotaTest { + val ClusterActions = ApiKeys.zkBrokerApis.asScala.filter(_.clusterAction).toSet + val ClusterActionsWithThrottle = Set(ApiKeys.ALLOCATE_PRODUCER_IDS) + val SaslActions = Set(ApiKeys.SASL_HANDSHAKE, ApiKeys.SASL_AUTHENTICATE) + val ClientActions = ApiKeys.zkBrokerApis.asScala.toSet -- ClusterActions -- SaslActions + + val UnauthorizedPrincipal = new KafkaPrincipal(KafkaPrincipal.USER_TYPE, "Unauthorized") + // Principal used for all client connections. This is modified by tests which + // check unauthorized code path + var principal = KafkaPrincipal.ANONYMOUS + class TestAuthorizer extends AclAuthorizer { + override def authorize(requestContext: AuthorizableRequestContext, actions: util.List[Action]): util.List[AuthorizationResult] = { + actions.asScala.map { _ => + if (requestContext.principal != UnauthorizedPrincipal) AuthorizationResult.ALLOWED else AuthorizationResult.DENIED + }.asJava + } + } + class TestPrincipalBuilder extends KafkaPrincipalBuilder with KafkaPrincipalSerde { + override def build(context: AuthenticationContext): KafkaPrincipal = { + principal + } + + override def serialize(principal: KafkaPrincipal): Array[Byte] = { + new Array[Byte](0) + } + + override def deserialize(bytes: Array[Byte]): KafkaPrincipal = { + principal + } + } +} diff --git a/core/src/test/scala/unit/kafka/server/SaslApiVersionsRequestTest.scala b/core/src/test/scala/unit/kafka/server/SaslApiVersionsRequestTest.scala new file mode 100644 index 0000000..2660948 --- /dev/null +++ b/core/src/test/scala/unit/kafka/server/SaslApiVersionsRequestTest.scala @@ -0,0 +1,109 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package kafka.server + +import java.net.Socket +import java.util.Collections + +import kafka.api.{KafkaSasl, SaslSetup} +import kafka.test.annotation.{ClusterTest, Type} +import kafka.test.junit.ClusterTestExtensions +import kafka.test.{ClusterConfig, ClusterInstance} +import kafka.utils.JaasTestUtils +import org.apache.kafka.common.message.SaslHandshakeRequestData +import org.apache.kafka.common.protocol.{ApiKeys, Errors} +import org.apache.kafka.common.requests.{ApiVersionsRequest, ApiVersionsResponse, SaslHandshakeRequest, SaslHandshakeResponse} +import org.apache.kafka.common.security.auth.SecurityProtocol +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.extension.ExtendWith +import org.junit.jupiter.api.{AfterEach, BeforeEach} + +import scala.jdk.CollectionConverters._ + + +@ExtendWith(value = Array(classOf[ClusterTestExtensions])) +class SaslApiVersionsRequestTest(cluster: ClusterInstance) extends AbstractApiVersionsRequestTest(cluster) { + + val kafkaClientSaslMechanism = "PLAIN" + val kafkaServerSaslMechanisms = List("PLAIN") + + private var sasl: SaslSetup = _ + + @BeforeEach + def setupSasl(config: ClusterConfig): Unit = { + sasl = new SaslSetup() {} + sasl.startSasl(sasl.jaasSections(kafkaServerSaslMechanisms, Some(kafkaClientSaslMechanism), KafkaSasl, JaasTestUtils.KafkaServerContextName)) + config.saslServerProperties().putAll(sasl.kafkaServerSaslProperties(kafkaServerSaslMechanisms, kafkaClientSaslMechanism)) + config.saslClientProperties().putAll(sasl.kafkaClientSaslProperties(kafkaClientSaslMechanism)) + super.brokerPropertyOverrides(config.serverProperties()) + } + + @ClusterTest(securityProtocol = SecurityProtocol.SASL_PLAINTEXT, clusterType = Type.ZK) + def testApiVersionsRequestBeforeSaslHandshakeRequest(): Unit = { + val socket = IntegrationTestUtils.connect(cluster.brokerSocketServers().asScala.head, cluster.clientListener()) + try { + val apiVersionsResponse = IntegrationTestUtils.sendAndReceive[ApiVersionsResponse]( + new ApiVersionsRequest.Builder().build(0), socket) + validateApiVersionsResponse(apiVersionsResponse) + sendSaslHandshakeRequestValidateResponse(socket) + } finally { + socket.close() + } + } + + @ClusterTest(securityProtocol = SecurityProtocol.SASL_PLAINTEXT, clusterType = Type.ZK) + def testApiVersionsRequestAfterSaslHandshakeRequest(): Unit = { + val socket = IntegrationTestUtils.connect(cluster.brokerSocketServers().asScala.head, cluster.clientListener()) + try { + sendSaslHandshakeRequestValidateResponse(socket) + val response = IntegrationTestUtils.sendAndReceive[ApiVersionsResponse]( + new ApiVersionsRequest.Builder().build(0), socket) + assertEquals(Errors.ILLEGAL_SASL_STATE.code, response.data.errorCode) + } finally { + socket.close() + } + } + + @ClusterTest(securityProtocol = SecurityProtocol.SASL_PLAINTEXT, clusterType = Type.ZK) + def testApiVersionsRequestWithUnsupportedVersion(): Unit = { + val socket = IntegrationTestUtils.connect(cluster.brokerSocketServers().asScala.head, cluster.clientListener()) + try { + val apiVersionsRequest = new ApiVersionsRequest.Builder().build(0) + val apiVersionsResponse = sendUnsupportedApiVersionRequest(apiVersionsRequest) + assertEquals(Errors.UNSUPPORTED_VERSION.code, apiVersionsResponse.data.errorCode) + val apiVersionsResponse2 = IntegrationTestUtils.sendAndReceive[ApiVersionsResponse]( + new ApiVersionsRequest.Builder().build(0), socket) + validateApiVersionsResponse(apiVersionsResponse2) + sendSaslHandshakeRequestValidateResponse(socket) + } finally { + socket.close() + 
} + } + + @AfterEach + def closeSasl(): Unit = { + sasl.closeSasl() + } + + private def sendSaslHandshakeRequestValidateResponse(socket: Socket): Unit = { + val request = new SaslHandshakeRequest(new SaslHandshakeRequestData().setMechanism("PLAIN"), + ApiKeys.SASL_HANDSHAKE.latestVersion) + val response = IntegrationTestUtils.sendAndReceive[SaslHandshakeResponse](request, socket) + assertEquals(Errors.NONE, response.error) + assertEquals(Collections.singletonList("PLAIN"), response.enabledMechanisms) + } +} diff --git a/core/src/test/scala/unit/kafka/server/ServerGenerateBrokerIdTest.scala b/core/src/test/scala/unit/kafka/server/ServerGenerateBrokerIdTest.scala new file mode 100755 index 0000000..096debe --- /dev/null +++ b/core/src/test/scala/unit/kafka/server/ServerGenerateBrokerIdTest.scala @@ -0,0 +1,193 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package kafka.server + +import java.util.Properties + +import scala.collection.Seq + +import kafka.server.QuorumTestHarness +import kafka.utils.TestUtils +import org.junit.jupiter.api.{AfterEach, BeforeEach, Test, TestInfo} +import org.junit.jupiter.api.Assertions._ +import java.io.File + +import org.apache.zookeeper.KeeperException.NodeExistsException + +class ServerGenerateBrokerIdTest extends QuorumTestHarness { + var props1: Properties = null + var config1: KafkaConfig = null + var props2: Properties = null + var config2: KafkaConfig = null + val brokerMetaPropsFile = "meta.properties" + var servers: Seq[KafkaServer] = Seq() + + @BeforeEach + override def setUp(testInfo: TestInfo): Unit = { + super.setUp(testInfo) + props1 = TestUtils.createBrokerConfig(-1, zkConnect) + config1 = KafkaConfig.fromProps(props1) + props2 = TestUtils.createBrokerConfig(0, zkConnect) + config2 = KafkaConfig.fromProps(props2) + } + + @AfterEach + override def tearDown(): Unit = { + TestUtils.shutdownServers(servers) + super.tearDown() + } + + @Test + def testAutoGenerateBrokerId(): Unit = { + var server1 = new KafkaServer(config1, threadNamePrefix = Option(this.getClass.getName)) + server1.startup() + server1.shutdown() + assertTrue(verifyBrokerMetadata(config1.logDirs, 1001)) + // restart the server check to see if it uses the brokerId generated previously + server1 = TestUtils.createServer(config1, threadNamePrefix = Option(this.getClass.getName)) + servers = Seq(server1) + assertEquals(server1.config.brokerId, 1001) + server1.shutdown() + TestUtils.assertNoNonDaemonThreads(this.getClass.getName) + } + + @Test + def testUserConfigAndGeneratedBrokerId(): Unit = { + // start the server with broker.id as part of config + val server1 = new KafkaServer(config1, threadNamePrefix = Option(this.getClass.getName)) + val server2 = new KafkaServer(config2, threadNamePrefix = Option(this.getClass.getName)) + val 
props3 = TestUtils.createBrokerConfig(-1, zkConnect) + val server3 = new KafkaServer(KafkaConfig.fromProps(props3), threadNamePrefix = Option(this.getClass.getName)) + server1.startup() + assertEquals(server1.config.brokerId, 1001) + server2.startup() + assertEquals(server2.config.brokerId, 0) + server3.startup() + assertEquals(server3.config.brokerId, 1002) + servers = Seq(server1, server2, server3) + servers.foreach(_.shutdown()) + assertTrue(verifyBrokerMetadata(server1.config.logDirs, 1001)) + assertTrue(verifyBrokerMetadata(server2.config.logDirs, 0)) + assertTrue(verifyBrokerMetadata(server3.config.logDirs, 1002)) + TestUtils.assertNoNonDaemonThreads(this.getClass.getName) + } + + @Test + def testDisableGeneratedBrokerId(): Unit = { + val props3 = TestUtils.createBrokerConfig(3, zkConnect) + props3.put(KafkaConfig.BrokerIdGenerationEnableProp, "false") + // Set reserve broker ids to cause collision and ensure disabling broker id generation ignores the setting + props3.put(KafkaConfig.MaxReservedBrokerIdProp, "0") + val config3 = KafkaConfig.fromProps(props3) + val server3 = TestUtils.createServer(config3, threadNamePrefix = Option(this.getClass.getName)) + servers = Seq(server3) + assertEquals(server3.config.brokerId, 3) + server3.shutdown() + assertTrue(verifyBrokerMetadata(server3.config.logDirs, 3)) + TestUtils.assertNoNonDaemonThreads(this.getClass.getName) + } + + @Test + def testMultipleLogDirsMetaProps(): Unit = { + // add multiple logDirs and check if the generate brokerId is stored in all of them + val logDirs = props1.getProperty("log.dir")+ "," + TestUtils.tempDir().getAbsolutePath + + "," + TestUtils.tempDir().getAbsolutePath + props1.setProperty("log.dir", logDirs) + config1 = KafkaConfig.fromProps(props1) + var server1 = new KafkaServer(config1, threadNamePrefix = Option(this.getClass.getName)) + server1.startup() + servers = Seq(server1) + server1.shutdown() + assertTrue(verifyBrokerMetadata(config1.logDirs, 1001)) + // addition to log.dirs after generation of a broker.id from zk should be copied over + val newLogDirs = props1.getProperty("log.dir") + "," + TestUtils.tempDir().getAbsolutePath + props1.setProperty("log.dir", newLogDirs) + config1 = KafkaConfig.fromProps(props1) + server1 = new KafkaServer(config1, threadNamePrefix = Option(this.getClass.getName)) + server1.startup() + servers = Seq(server1) + server1.shutdown() + assertTrue(verifyBrokerMetadata(config1.logDirs, 1001)) + TestUtils.assertNoNonDaemonThreads(this.getClass.getName) + } + + @Test + def testConsistentBrokerIdFromUserConfigAndMetaProps(): Unit = { + // check if configured brokerId and stored brokerId are equal or throw InconsistentBrokerException + var server1 = new KafkaServer(config1, threadNamePrefix = Option(this.getClass.getName)) //auto generate broker Id + server1.startup() + servers = Seq(server1) + server1.shutdown() + server1 = new KafkaServer(config2, threadNamePrefix = Option(this.getClass.getName)) // user specified broker id + try { + server1.startup() + } catch { + case _: kafka.common.InconsistentBrokerIdException => //success + } + server1.shutdown() + TestUtils.assertNoNonDaemonThreads(this.getClass.getName) + } + + @Test + def testBrokerMetadataOnIdCollision(): Unit = { + // Start a good server + val propsA = TestUtils.createBrokerConfig(1, zkConnect) + val configA = KafkaConfig.fromProps(propsA) + val serverA = TestUtils.createServer(configA, threadNamePrefix = Option(this.getClass.getName)) + + // Start a server that collides on the broker id + val propsB = 
TestUtils.createBrokerConfig(1, zkConnect) + val configB = KafkaConfig.fromProps(propsB) + val serverB = new KafkaServer(configB, threadNamePrefix = Option(this.getClass.getName)) + assertThrows(classOf[NodeExistsException], () => serverB.startup()) + servers = Seq(serverA) + + // verify no broker metadata was written + serverB.config.logDirs.foreach { logDir => + val brokerMetaFile = new File(logDir + File.separator + brokerMetaPropsFile) + assertFalse(brokerMetaFile.exists()) + } + + // adjust the broker config and start again + propsB.setProperty(KafkaConfig.BrokerIdProp, "2") + val newConfigB = KafkaConfig.fromProps(propsB) + val newServerB = TestUtils.createServer(newConfigB, threadNamePrefix = Option(this.getClass.getName)) + servers = Seq(serverA, newServerB) + + serverA.shutdown() + newServerB.shutdown() + + // verify correct broker metadata was written + assertTrue(verifyBrokerMetadata(serverA.config.logDirs, 1)) + assertTrue(verifyBrokerMetadata(newServerB.config.logDirs, 2)) + TestUtils.assertNoNonDaemonThreads(this.getClass.getName) + } + + def verifyBrokerMetadata(logDirs: Seq[String], brokerId: Int): Boolean = { + for (logDir <- logDirs) { + val brokerMetadataOpt = new BrokerMetadataCheckpoint( + new File(logDir + File.separator + brokerMetaPropsFile)).read() + brokerMetadataOpt match { + case Some(properties) => + val brokerMetadata = new RawMetaProperties(properties) + if (brokerMetadata.brokerId.exists(_ != brokerId)) return false + case _ => return false + } + } + true + } +} diff --git a/core/src/test/scala/unit/kafka/server/ServerGenerateClusterIdTest.scala b/core/src/test/scala/unit/kafka/server/ServerGenerateClusterIdTest.scala new file mode 100755 index 0000000..fd9b365 --- /dev/null +++ b/core/src/test/scala/unit/kafka/server/ServerGenerateClusterIdTest.scala @@ -0,0 +1,234 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package kafka.server + +import java.io.File + + +import scala.collection.Seq +import scala.concurrent._ +import scala.concurrent.duration._ +import ExecutionContext.Implicits._ + +import kafka.common.{InconsistentBrokerMetadataException, InconsistentClusterIdException} +import kafka.utils.TestUtils +import kafka.server.QuorumTestHarness + +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.{AfterEach, BeforeEach, Test, TestInfo} +import org.apache.kafka.test.TestUtils.isValidClusterId + + +class ServerGenerateClusterIdTest extends QuorumTestHarness { + var config1: KafkaConfig = null + var config2: KafkaConfig = null + var config3: KafkaConfig = null + var servers: Seq[KafkaServer] = Seq() + val brokerMetaPropsFile = "meta.properties" + + @BeforeEach + override def setUp(testInfo: TestInfo): Unit = { + super.setUp(testInfo) + config1 = KafkaConfig.fromProps(TestUtils.createBrokerConfig(1, zkConnect)) + config2 = KafkaConfig.fromProps(TestUtils.createBrokerConfig(2, zkConnect)) + config3 = KafkaConfig.fromProps(TestUtils.createBrokerConfig(3, zkConnect)) + } + + @AfterEach + override def tearDown(): Unit = { + TestUtils.shutdownServers(servers) + super.tearDown() + } + + + @Test + def testAutoGenerateClusterId(): Unit = { + // Make sure that the cluster id doesn't exist yet. + assertFalse(zkClient.getClusterId.isDefined) + + var server1 = TestUtils.createServer(config1, threadNamePrefix = Option(this.getClass.getName)) + servers = Seq(server1) + + // Validate the cluster id + val clusterIdOnFirstBoot = server1.clusterId + isValidClusterId(clusterIdOnFirstBoot) + + server1.shutdown() + + // Make sure that the cluster id is persistent. + assertTrue(zkClient.getClusterId.isDefined) + assertEquals(zkClient.getClusterId, Some(clusterIdOnFirstBoot)) + + // Restart the server and check that it uses the clusterId generated previously + server1 = TestUtils.createServer(config1, threadNamePrefix = Option(this.getClass.getName)) + servers = Seq(server1) + + val clusterIdOnSecondBoot = server1.clusterId + assertEquals(clusterIdOnFirstBoot, clusterIdOnSecondBoot) + + server1.shutdown() + + // Make sure that the cluster id is persistent after multiple reboots.
+ assertTrue(zkClient.getClusterId.isDefined) + assertEquals(zkClient.getClusterId, Some(clusterIdOnFirstBoot)) + + TestUtils.assertNoNonDaemonThreads(this.getClass.getName) + } + + @Test + def testAutoGenerateClusterIdForKafkaClusterSequential(): Unit = { + val server1 = TestUtils.createServer(config1, threadNamePrefix = Option(this.getClass.getName)) + val clusterIdFromServer1 = server1.clusterId + + val server2 = TestUtils.createServer(config2, threadNamePrefix = Option(this.getClass.getName)) + val clusterIdFromServer2 = server2.clusterId + + val server3 = TestUtils.createServer(config3, threadNamePrefix = Option(this.getClass.getName)) + val clusterIdFromServer3 = server3.clusterId + servers = Seq(server1, server2, server3) + + servers.foreach(_.shutdown()) + + isValidClusterId(clusterIdFromServer1) + assertEquals(clusterIdFromServer1, clusterIdFromServer2) + assertEquals(clusterIdFromServer2, clusterIdFromServer3) + + // Check again after reboot + server1.startup() + assertEquals(clusterIdFromServer1, server1.clusterId) + server2.startup() + assertEquals(clusterIdFromServer2, server2.clusterId) + server3.startup() + assertEquals(clusterIdFromServer3, server3.clusterId) + + servers.foreach(_.shutdown()) + + TestUtils.assertNoNonDaemonThreads(this.getClass.getName) + } + + @Test + def testAutoGenerateClusterIdForKafkaClusterParallel(): Unit = { + val firstBoot = Future.traverse(Seq(config1, config2, config3))(config => Future(TestUtils.createServer(config, threadNamePrefix = Option(this.getClass.getName)))) + servers = Await.result(firstBoot, 100 second) + val Seq(server1, server2, server3) = servers + + val clusterIdFromServer1 = server1.clusterId + val clusterIdFromServer2 = server2.clusterId + val clusterIdFromServer3 = server3.clusterId + + servers.foreach(_.shutdown()) + isValidClusterId(clusterIdFromServer1) + assertEquals(clusterIdFromServer1, clusterIdFromServer2) + assertEquals(clusterIdFromServer2, clusterIdFromServer3) + + // Check again after reboot + val secondBoot = Future.traverse(Seq(server1, server2, server3))(server => Future { + server.startup() + server + }) + servers = Await.result(secondBoot, 100 second) + servers.foreach(server => assertEquals(clusterIdFromServer1, server.clusterId)) + + servers.foreach(_.shutdown()) + + TestUtils.assertNoNonDaemonThreads(this.getClass.getName) + } + + @Test + def testConsistentClusterIdFromZookeeperAndFromMetaProps(): Unit = { + // Check at the first boot + val server = TestUtils.createServer(config1, threadNamePrefix = Option(this.getClass.getName)) + val clusterId = server.clusterId + + assertTrue(verifyBrokerMetadata(server.config.logDirs, clusterId)) + + server.shutdown() + + // Check again after reboot + server.startup() + + assertEquals(clusterId, server.clusterId) + assertTrue(verifyBrokerMetadata(server.config.logDirs, server.clusterId)) + + server.shutdown() + + TestUtils.assertNoNonDaemonThreads(this.getClass.getName) + } + + @Test + def testInconsistentClusterIdFromZookeeperAndFromMetaProps(): Unit = { + forgeBrokerMetadata(config1.logDirs, config1.brokerId, "aclusterid") + + val server = new KafkaServer(config1, threadNamePrefix = Option(this.getClass.getName)) + + // Startup fails + assertThrows(classOf[InconsistentClusterIdException], () => server.startup()) + + server.shutdown() + + TestUtils.assertNoNonDaemonThreads(this.getClass.getName) + } + + @Test + def testInconsistentBrokerMetadataBetweenMultipleLogDirs(): Unit = { + // Add multiple logDirs with different BrokerMetadata + val logDir1 = TestUtils.tempDir().getAbsolutePath + val logDir2 = TestUtils.tempDir().getAbsolutePath + val logDirs =
logDir1 + "," + logDir2 + + forgeBrokerMetadata(logDir1, 1, "ebwOKU-zSieInaFQh_qP4g") + forgeBrokerMetadata(logDir2, 1, "blaOKU-zSieInaFQh_qP4g") + + val props = TestUtils.createBrokerConfig(1, zkConnect) + props.setProperty("log.dir", logDirs) + val config = KafkaConfig.fromProps(props) + + val server = new KafkaServer(config, threadNamePrefix = Option(this.getClass.getName)) + + // Startup fails + assertThrows(classOf[InconsistentBrokerMetadataException], () => server.startup()) + + server.shutdown() + + TestUtils.assertNoNonDaemonThreads(this.getClass.getName) + } + + def forgeBrokerMetadata(logDirs: Seq[String], brokerId: Int, clusterId: String): Unit = { + for (logDir <- logDirs) { + forgeBrokerMetadata(logDir, brokerId, clusterId) + } + } + + def forgeBrokerMetadata(logDir: String, brokerId: Int, clusterId: String): Unit = { + val checkpoint = new BrokerMetadataCheckpoint( + new File(logDir + File.separator + brokerMetaPropsFile)) + checkpoint.write(ZkMetaProperties(clusterId, brokerId).toProperties) + } + + def verifyBrokerMetadata(logDirs: Seq[String], clusterId: String): Boolean = { + for (logDir <- logDirs) { + val brokerMetadataOpt = new BrokerMetadataCheckpoint( + new File(logDir + File.separator + brokerMetaPropsFile)).read() + brokerMetadataOpt match { + case Some(properties) => + val brokerMetadata = new RawMetaProperties(properties) + if (brokerMetadata.clusterId.exists(_ != clusterId)) return false + case _ => return false + } + } + true + } + +} diff --git a/core/src/test/scala/unit/kafka/server/ServerMetricsTest.scala b/core/src/test/scala/unit/kafka/server/ServerMetricsTest.scala new file mode 100755 index 0000000..6c7f148 --- /dev/null +++ b/core/src/test/scala/unit/kafka/server/ServerMetricsTest.scala @@ -0,0 +1,48 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package kafka.server + +import kafka.utils.TestUtils +import org.apache.kafka.common.metrics.Sensor +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.Test + +class ServerMetricsTest { + + @Test + def testMetricsConfig(): Unit = { + val recordingLevels = List(Sensor.RecordingLevel.DEBUG, Sensor.RecordingLevel.INFO) + val illegalNames = List("IllegalName", "") + val props = TestUtils.createBrokerConfig(0, "localhost:2818") + + for (recordingLevel <- recordingLevels) { + props.put(KafkaConfig.MetricRecordingLevelProp, recordingLevel.name) + val config = KafkaConfig.fromProps(props) + val metricConfig = Server.buildMetricsConfig(config) + assertEquals(recordingLevel, metricConfig.recordLevel) + } + + for (illegalName <- illegalNames) { + props.put(KafkaConfig.MetricRecordingLevelProp, illegalName) + val config = KafkaConfig.fromProps(props) + assertThrows(classOf[IllegalArgumentException], () => Server.buildMetricsConfig(config)) + } + + } + +} diff --git a/core/src/test/scala/unit/kafka/server/ServerShutdownTest.scala b/core/src/test/scala/unit/kafka/server/ServerShutdownTest.scala new file mode 100755 index 0000000..013d084 --- /dev/null +++ b/core/src/test/scala/unit/kafka/server/ServerShutdownTest.scala @@ -0,0 +1,258 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package kafka.server + +import kafka.server.QuorumTestHarness +import kafka.utils.{CoreUtils, TestUtils} +import kafka.utils.TestUtils._ + +import java.io.{DataInputStream, File} +import java.net.ServerSocket +import java.util.Collections +import java.util.concurrent.{Executors, TimeUnit} +import kafka.cluster.Broker +import kafka.controller.{ControllerChannelManager, ControllerContext, StateChangeLogger} +import kafka.log.LogManager +import kafka.zookeeper.ZooKeeperClientTimeoutException +import org.apache.kafka.clients.consumer.KafkaConsumer +import org.apache.kafka.clients.producer.{KafkaProducer, ProducerRecord} +import org.apache.kafka.common.Uuid +import org.apache.kafka.common.errors.KafkaStorageException +import org.apache.kafka.common.metrics.Metrics +import org.apache.kafka.common.network.ListenerName +import org.apache.kafka.common.protocol.ApiKeys +import org.apache.kafka.common.requests.LeaderAndIsrRequest +import org.apache.kafka.common.security.auth.SecurityProtocol +import org.apache.kafka.common.serialization.{IntegerDeserializer, IntegerSerializer, StringDeserializer, StringSerializer} +import org.apache.kafka.common.utils.Time +import org.apache.kafka.metadata.BrokerState +import org.junit.jupiter.api.{BeforeEach, Test, TestInfo, Timeout} +import org.junit.jupiter.api.Assertions._ + +import scala.jdk.CollectionConverters._ +import scala.reflect.ClassTag + +@Timeout(60) +class ServerShutdownTest extends QuorumTestHarness { + var config: KafkaConfig = null + val host = "localhost" + val topic = "test" + val sent1 = List("hello", "there") + val sent2 = List("more", "messages") + + @BeforeEach + override def setUp(testInfo: TestInfo): Unit = { + super.setUp(testInfo) + val props = TestUtils.createBrokerConfig(0, zkConnect) + config = KafkaConfig.fromProps(props) + } + + @Test + def testCleanShutdown(): Unit = { + + def createProducer(server: KafkaServer): KafkaProducer[Integer, String] = + TestUtils.createProducer( + TestUtils.getBrokerListStrFromServers(Seq(server)), + keySerializer = new IntegerSerializer, + valueSerializer = new StringSerializer + ) + + def createConsumer(server: KafkaServer): KafkaConsumer[Integer, String] = + TestUtils.createConsumer( + TestUtils.getBrokerListStrFromServers(Seq(server)), + securityProtocol = SecurityProtocol.PLAINTEXT, + keyDeserializer = new IntegerDeserializer, + valueDeserializer = new StringDeserializer + ) + + var server = new KafkaServer(config, threadNamePrefix = Option(this.getClass.getName)) + server.startup() + var producer = createProducer(server) + + // create topic + createTopic(zkClient, topic, servers = Seq(server)) + + // send some messages + sent1.map(value => producer.send(new ProducerRecord(topic, 0, value))).foreach(_.get) + + // do a clean shutdown and check that offset checkpoint file exists + server.shutdown() + for (logDir <- config.logDirs) { + val OffsetCheckpointFile = new File(logDir, LogManager.RecoveryPointCheckpointFile) + assertTrue(OffsetCheckpointFile.exists) + assertTrue(OffsetCheckpointFile.length() > 0) + } + producer.close() + + /* now restart the server and check that the written data is still readable and everything still works */ + server = new KafkaServer(config) + server.startup() + + // wait for the broker to receive the update metadata request after startup + TestUtils.waitForPartitionMetadata(Seq(server), topic, 0) + + producer = createProducer(server) + val consumer = createConsumer(server) + consumer.subscribe(Seq(topic).asJava) + + val consumerRecords = 
TestUtils.consumeRecords(consumer, sent1.size) + assertEquals(sent1, consumerRecords.map(_.value)) + + // send some more messages + sent2.map(value => producer.send(new ProducerRecord(topic, 0, value))).foreach(_.get) + + val consumerRecords2 = TestUtils.consumeRecords(consumer, sent2.size) + assertEquals(sent2, consumerRecords2.map(_.value)) + + consumer.close() + producer.close() + server.shutdown() + CoreUtils.delete(server.config.logDirs) + verifyNonDaemonThreadsStatus() + } + + @Test + def testCleanShutdownAfterFailedStartup(): Unit = { + val newProps = TestUtils.createBrokerConfig(0, zkConnect) + newProps.setProperty(KafkaConfig.ZkConnectionTimeoutMsProp, "50") + newProps.setProperty(KafkaConfig.ZkConnectProp, "some.invalid.hostname.foo.bar.local:65535") + val newConfig = KafkaConfig.fromProps(newProps) + verifyCleanShutdownAfterFailedStartup[ZooKeeperClientTimeoutException](newConfig) + } + + @Test + def testCleanShutdownAfterFailedStartupDueToCorruptLogs(): Unit = { + val server = new KafkaServer(config, threadNamePrefix = Option(this.getClass.getName)) + server.startup() + createTopic(zkClient, topic, servers = Seq(server)) + server.shutdown() + server.awaitShutdown() + config.logDirs.foreach { dirName => + val partitionDir = new File(dirName, s"$topic-0") + partitionDir.listFiles.foreach(f => TestUtils.appendNonsenseToFile(f, TestUtils.random.nextInt(1024) + 1)) + } + verifyCleanShutdownAfterFailedStartup[KafkaStorageException](config) + } + + @Test + def testCleanShutdownWithZkUnavailable(): Unit = { + val server = new KafkaServer(config, threadNamePrefix = Option(this.getClass.getName)) + server.startup() + shutdownZooKeeper() + server.shutdown() + server.awaitShutdown() + CoreUtils.delete(server.config.logDirs) + verifyNonDaemonThreadsStatus() + } + + private def verifyCleanShutdownAfterFailedStartup[E <: Exception](config: KafkaConfig)(implicit exceptionClassTag: ClassTag[E]): Unit = { + val server = new KafkaServer(config, threadNamePrefix = Option(this.getClass.getName)) + try { + server.startup() + fail("Expected KafkaServer setup to fail and throw exception") + } + catch { + // Try to clean up carefully without hanging even if the test fails. This means trying to accurately + // identify the correct exception, making sure the server was shutdown, and cleaning up if anything + // goes wrong so that awaitShutdown doesn't hang + case e: Exception => + assertTrue(exceptionClassTag.runtimeClass.isInstance(e), s"Unexpected exception $e") + assertEquals(BrokerState.NOT_RUNNING, server.brokerState) + } + finally { + if (server.brokerState != BrokerState.NOT_RUNNING) + server.shutdown() + server.awaitShutdown() + } + CoreUtils.delete(server.config.logDirs) + verifyNonDaemonThreadsStatus() + } + + private[this] def isNonDaemonKafkaThread(t: Thread): Boolean = { + !t.isDaemon && t.isAlive && t.getName.startsWith(this.getClass.getName) + } + + def verifyNonDaemonThreadsStatus(): Unit = { + assertEquals(0, Thread.getAllStackTraces.keySet.toArray + .map(_.asInstanceOf[Thread]) + .count(isNonDaemonKafkaThread)) + } + + @Test + def testConsecutiveShutdown(): Unit = { + val server = new KafkaServer(config) + server.startup() + server.shutdown() + server.awaitShutdown() + server.shutdown() + } + + // Verify that if controller is in the midst of processing a request, shutdown completes + // without waiting for request timeout. 
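A minimal, self-contained sketch (editorial note, not part of the patch) of the stub-peer pattern the test below relies on: a plain ServerSocket that accepts one connection and reads a single byte but never responds, so any request sent to it stays in flight. The object and variable names here are illustrative only.

    import java.io.DataInputStream
    import java.net.ServerSocket
    import java.util.concurrent.Executors

    object UnresponsivePeerSketch {
      def main(args: Array[String]): Unit = {
        val serverSocket = new ServerSocket(0) // bind to an ephemeral port
        val executor = Executors.newSingleThreadExecutor()
        executor.submit(new Runnable {
          override def run(): Unit = {
            val socket = serverSocket.accept()                    // accept exactly one connection
            new DataInputStream(socket.getInputStream).readByte() // read one byte, then never reply
          }
        })
        // A client that connects to serverSocket.getLocalPort and sends a request now has an
        // in-flight request that will never complete; shutdown logic can be timed against it.
        println(s"stub peer listening on port ${serverSocket.getLocalPort}")
      }
    }

The test that follows wires the same kind of stub socket to a ControllerChannelManager and asserts that shutdown() returns well before the 30s request timeout.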
+ @Test + def testControllerShutdownDuringSend(): Unit = { + val securityProtocol = SecurityProtocol.PLAINTEXT + val listenerName = ListenerName.forSecurityProtocol(securityProtocol) + + val controllerId = 2 + val metrics = new Metrics + val executor = Executors.newSingleThreadExecutor + var serverSocket: ServerSocket = null + var controllerChannelManager: ControllerChannelManager = null + + try { + // Set up a server to accept a connection and receive one byte from the first request. No response is sent. + serverSocket = new ServerSocket(0) + val receiveFuture = executor.submit(new Runnable { + override def run(): Unit = { + val socket = serverSocket.accept() + new DataInputStream(socket.getInputStream).readByte() + } + }) + + // Start a ControllerChannelManager + val brokerAndEpochs = Map((new Broker(1, "localhost", serverSocket.getLocalPort, listenerName, securityProtocol), 0L)) + val controllerConfig = KafkaConfig.fromProps(TestUtils.createBrokerConfig(controllerId, zkConnect)) + val controllerContext = new ControllerContext + controllerContext.setLiveBrokers(brokerAndEpochs) + controllerChannelManager = new ControllerChannelManager(controllerContext, controllerConfig, Time.SYSTEM, + metrics, new StateChangeLogger(controllerId, inControllerContext = true, None)) + controllerChannelManager.startup() + + // Initiate a sendRequest and wait until connection is established and one byte is received by the peer + val requestBuilder = new LeaderAndIsrRequest.Builder(ApiKeys.LEADER_AND_ISR.latestVersion, + controllerId, 1, 0L, Seq.empty.asJava, Collections.singletonMap(topic, Uuid.randomUuid()), + brokerAndEpochs.keys.map(_.node(listenerName)).toSet.asJava) + controllerChannelManager.sendRequest(1, requestBuilder) + receiveFuture.get(10, TimeUnit.SECONDS) + + // Shutdown controller. Request timeout is 30s, verify that shutdown completed well before that + val shutdownFuture = executor.submit(new Runnable { + override def run(): Unit = controllerChannelManager.shutdown() + }) + shutdownFuture.get(10, TimeUnit.SECONDS) + + } finally { + if (serverSocket != null) + serverSocket.close() + if (controllerChannelManager != null) + controllerChannelManager.shutdown() + executor.shutdownNow() + metrics.close() + } + } +} diff --git a/core/src/test/scala/unit/kafka/server/ServerStartupTest.scala b/core/src/test/scala/unit/kafka/server/ServerStartupTest.scala new file mode 100755 index 0000000..e80084d --- /dev/null +++ b/core/src/test/scala/unit/kafka/server/ServerStartupTest.scala @@ -0,0 +1,108 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package kafka.server + +import kafka.utils.TestUtils +import kafka.server.QuorumTestHarness +import org.apache.kafka.common.KafkaException +import org.apache.kafka.metadata.BrokerState +import org.apache.zookeeper.KeeperException.NodeExistsException +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.{AfterEach, Test} + +class ServerStartupTest extends QuorumTestHarness { + + private var server: KafkaServer = null + + @AfterEach + override def tearDown(): Unit = { + if (server != null) + TestUtils.shutdownServers(Seq(server)) + super.tearDown() + } + + @Test + def testBrokerCreatesZKChroot(): Unit = { + val brokerId = 0 + val zookeeperChroot = "/kafka-chroot-for-unittest" + val props = TestUtils.createBrokerConfig(brokerId, zkConnect) + val zooKeeperConnect = props.get("zookeeper.connect") + props.put("zookeeper.connect", zooKeeperConnect.toString + zookeeperChroot) + server = TestUtils.createServer(KafkaConfig.fromProps(props)) + + val pathExists = zkClient.pathExists(zookeeperChroot) + assertTrue(pathExists) + } + + @Test + def testConflictBrokerStartupWithSamePort(): Unit = { + // Create and start the first broker + val brokerId1 = 0 + val props1 = TestUtils.createBrokerConfig(brokerId1, zkConnect) + server = TestUtils.createServer(KafkaConfig.fromProps(props1)) + val port = TestUtils.boundPort(server) + + // Create a second broker with the same port + val brokerId2 = 1 + val props2 = TestUtils.createBrokerConfig(brokerId2, zkConnect, port = port) + assertThrows(classOf[KafkaException], () => TestUtils.createServer(KafkaConfig.fromProps(props2))) + } + + @Test + def testConflictBrokerRegistration(): Unit = { + // Try starting a broker with a conflicting broker id. + // This shouldn't affect the existing broker registration.
+ + val brokerId = 0 + val props1 = TestUtils.createBrokerConfig(brokerId, zkConnect) + server = TestUtils.createServer(KafkaConfig.fromProps(props1)) + val brokerRegistration = zkClient.getBroker(brokerId).getOrElse(fail("broker doesn't exist")) + + val props2 = TestUtils.createBrokerConfig(brokerId, zkConnect) + assertThrows(classOf[NodeExistsException], () => TestUtils.createServer(KafkaConfig.fromProps(props2))) + + // broker registration shouldn't change + assertEquals(brokerRegistration, zkClient.getBroker(brokerId).getOrElse(fail("broker doesn't exist"))) + } + + @Test + def testBrokerSelfAware(): Unit = { + val brokerId = 0 + val props = TestUtils.createBrokerConfig(brokerId, zkConnect) + server = TestUtils.createServer(KafkaConfig.fromProps(props)) + + TestUtils.waitUntilTrue(() => server.metadataCache.getAliveBrokers().nonEmpty, "Wait for cache to update") + assertEquals(1, server.metadataCache.getAliveBrokers().size) + assertEquals(brokerId, server.metadataCache.getAliveBrokers().head.id) + } + + @Test + def testBrokerStateRunningAfterZK(): Unit = { + val brokerId = 0 + + val props = TestUtils.createBrokerConfig(brokerId, zkConnect) + server = new KafkaServer(KafkaConfig.fromProps(props)) + + server.startup() + TestUtils.waitUntilTrue(() => server.brokerState == BrokerState.RUNNING, + "waiting for the broker state to become RUNNING") + val brokers = zkClient.getAllBrokersInCluster + assertEquals(1, brokers.size) + assertEquals(brokerId, brokers.head.id) + } +} diff --git a/core/src/test/scala/unit/kafka/server/ServerTest.scala b/core/src/test/scala/unit/kafka/server/ServerTest.scala new file mode 100644 index 0000000..d72ad2d --- /dev/null +++ b/core/src/test/scala/unit/kafka/server/ServerTest.scala @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ +package kafka.server + +import java.util.Properties + +import org.apache.kafka.common.Uuid +import org.apache.kafka.common.metrics.MetricsContext +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.Test + +import scala.jdk.CollectionConverters._ + +class ServerTest { + + @Test + def testCreateSelfManagedKafkaMetricsContext(): Unit = { + val nodeId = 0 + val clusterId = Uuid.randomUuid().toString + + val props = new Properties() + props.put(KafkaConfig.ProcessRolesProp, "broker") + props.put(KafkaConfig.NodeIdProp, nodeId.toString) + props.put(KafkaConfig.QuorumVotersProp, s"${(nodeId + 1)}@localhost:9093") + props.put(KafkaConfig.ControllerListenerNamesProp, "SSL") + val config = KafkaConfig.fromProps(props) + + val context = Server.createKafkaMetricsContext(config, clusterId) + assertEquals(Map( + MetricsContext.NAMESPACE -> Server.MetricsPrefix, + Server.ClusterIdLabel -> clusterId, + Server.NodeIdLabel -> nodeId.toString + ), context.contextLabels.asScala) + } + + @Test + def testCreateZkKafkaMetricsContext(): Unit = { + val brokerId = 0 + val clusterId = Uuid.randomUuid().toString + + val props = new Properties() + props.put(KafkaConfig.BrokerIdProp, brokerId.toString) + props.put(KafkaConfig.ZkConnectProp, "127.0.0.1:0") + val config = KafkaConfig.fromProps(props) + + val context = Server.createKafkaMetricsContext(config, clusterId) + assertEquals(Map( + MetricsContext.NAMESPACE -> Server.MetricsPrefix, + Server.ClusterIdLabel -> clusterId, + Server.BrokerIdLabel -> brokerId.toString + ), context.contextLabels.asScala) + } + +} diff --git a/core/src/test/scala/unit/kafka/server/StopReplicaRequestTest.scala b/core/src/test/scala/unit/kafka/server/StopReplicaRequestTest.scala new file mode 100644 index 0000000..ff246aa --- /dev/null +++ b/core/src/test/scala/unit/kafka/server/StopReplicaRequestTest.scala @@ -0,0 +1,78 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package kafka.server + +import kafka.api.LeaderAndIsr +import kafka.utils._ +import org.apache.kafka.common.TopicPartition +import org.apache.kafka.common.message.StopReplicaRequestData.{StopReplicaPartitionState, StopReplicaTopicState} +import org.apache.kafka.common.protocol.ApiKeys +import org.apache.kafka.common.protocol.Errors +import org.apache.kafka.common.requests._ +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.Test + +import scala.jdk.CollectionConverters._ +import scala.collection.Seq + +class StopReplicaRequestTest extends BaseRequestTest { + override val logDirCount = 2 + override val brokerCount: Int = 1 + + val topic = "topic" + val partitionNum = 2 + val tp0 = new TopicPartition(topic, 0) + val tp1 = new TopicPartition(topic, 1) + + @Test + def testStopReplicaRequest(): Unit = { + createTopic(topic, partitionNum, 1) + TestUtils.generateAndProduceMessages(servers, topic, 10) + + val server = servers.head + val offlineDir = server.logManager.getLog(tp1).get.dir.getParent + server.replicaManager.handleLogDirFailure(offlineDir, sendZkNotification = false) + + val topicStates = Seq( + new StopReplicaTopicState() + .setTopicName(tp0.topic()) + .setPartitionStates(Seq(new StopReplicaPartitionState() + .setPartitionIndex(tp0.partition()) + .setLeaderEpoch(LeaderAndIsr.initialLeaderEpoch + 2) + .setDeletePartition(true)).asJava), + new StopReplicaTopicState() + .setTopicName(tp1.topic()) + .setPartitionStates(Seq(new StopReplicaPartitionState() + .setPartitionIndex(tp1.partition()) + .setLeaderEpoch(LeaderAndIsr.initialLeaderEpoch + 2) + .setDeletePartition(true)).asJava) + ).asJava + + for (_ <- 1 to 2) { + val request1 = new StopReplicaRequest.Builder(ApiKeys.STOP_REPLICA.latestVersion, + server.config.brokerId, server.replicaManager.controllerEpoch, server.kafkaController.brokerEpoch, + false, topicStates).build() + val response1 = connectAndReceive[StopReplicaResponse](request1, destination = controllerSocketServer) + val partitionErrors1 = response1.partitionErrors.asScala + assertEquals(Some(Errors.NONE.code), + partitionErrors1.find(pe => pe.topicName == tp0.topic && pe.partitionIndex == tp0.partition).map(_.errorCode)) + assertEquals(Some(Errors.KAFKA_STORAGE_ERROR.code), + partitionErrors1.find(pe => pe.topicName == tp1.topic && pe.partitionIndex == tp1.partition).map(_.errorCode)) + } + } +} diff --git a/core/src/test/scala/unit/kafka/server/ThrottledChannelExpirationTest.scala b/core/src/test/scala/unit/kafka/server/ThrottledChannelExpirationTest.scala new file mode 100644 index 0000000..15ad22d --- /dev/null +++ b/core/src/test/scala/unit/kafka/server/ThrottledChannelExpirationTest.scala @@ -0,0 +1,102 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package kafka.server + + +import java.util.Collections +import java.util.concurrent.{DelayQueue, TimeUnit} + +import org.apache.kafka.common.metrics.MetricConfig +import org.apache.kafka.common.utils.MockTime +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.{BeforeEach, Test} + +class ThrottledChannelExpirationTest { + private val time = new MockTime + private var numCallbacksForStartThrottling: Int = 0 + private var numCallbacksForEndThrottling: Int = 0 + private val metrics = new org.apache.kafka.common.metrics.Metrics(new MetricConfig(), + Collections.emptyList(), + time) + private val callback = new ThrottleCallback { + override def startThrottling(): Unit = { + numCallbacksForStartThrottling += 1 + } + + override def endThrottling(): Unit = { + numCallbacksForEndThrottling += 1 + } + } + + @BeforeEach + def beforeMethod(): Unit = { + numCallbacksForStartThrottling = 0 + numCallbacksForEndThrottling = 0 + } + + @Test + def testCallbackInvocationAfterExpiration(): Unit = { + val clientMetrics = new ClientQuotaManager(ClientQuotaManagerConfig(), metrics, QuotaType.Produce, time, "") + + val delayQueue = new DelayQueue[ThrottledChannel]() + val reaper = new clientMetrics.ThrottledChannelReaper(delayQueue, "") + try { + // Add 4 elements to the queue out of order. Add 2 elements with the same expire timestamp. + val channel1 = new ThrottledChannel(time, 10, callback) + val channel2 = new ThrottledChannel(time, 30, callback) + val channel3 = new ThrottledChannel(time, 30, callback) + val channel4 = new ThrottledChannel(time, 20, callback) + delayQueue.add(channel1) + delayQueue.add(channel2) + delayQueue.add(channel3) + delayQueue.add(channel4) + assertEquals(4, numCallbacksForStartThrottling) + + for(itr <- 1 to 3) { + time.sleep(10) + reaper.doWork() + assertEquals(itr, numCallbacksForEndThrottling) + } + reaper.doWork() + assertEquals(4, numCallbacksForEndThrottling) + assertEquals(0, delayQueue.size()) + reaper.doWork() + assertEquals(4, numCallbacksForEndThrottling) + } finally { + clientMetrics.shutdown() + } + } + + @Test + def testThrottledChannelDelay(): Unit = { + val t1: ThrottledChannel = new ThrottledChannel(time, 10, callback) + val t2: ThrottledChannel = new ThrottledChannel(time, 20, callback) + val t3: ThrottledChannel = new ThrottledChannel(time, 20, callback) + assertEquals(10, t1.throttleTimeMs) + assertEquals(20, t2.throttleTimeMs) + assertEquals(20, t3.throttleTimeMs) + + for(itr <- 0 to 2) { + assertEquals(10 - 10*itr, t1.getDelay(TimeUnit.MILLISECONDS)) + assertEquals(20 - 10*itr, t2.getDelay(TimeUnit.MILLISECONDS)) + assertEquals(20 - 10*itr, t3.getDelay(TimeUnit.MILLISECONDS)) + time.sleep(10) + } + } + +} diff --git a/core/src/test/scala/unit/kafka/server/TopicIdWithOldInterBrokerProtocolTest.scala b/core/src/test/scala/unit/kafka/server/TopicIdWithOldInterBrokerProtocolTest.scala new file mode 100644 index 0000000..a49fcb0 --- /dev/null +++ b/core/src/test/scala/unit/kafka/server/TopicIdWithOldInterBrokerProtocolTest.scala @@ -0,0 +1,179 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.server + +import java.util.{Arrays, LinkedHashMap, Optional, Properties} + +import kafka.api.KAFKA_2_7_IV0 +import kafka.network.SocketServer +import kafka.utils.TestUtils +import org.apache.kafka.common.{TopicIdPartition, TopicPartition, Uuid} +import org.apache.kafka.common.message.DeleteTopicsRequestData +import org.apache.kafka.common.message.DeleteTopicsRequestData.DeleteTopicState +import org.apache.kafka.common.protocol.{ApiKeys, Errors} +import org.apache.kafka.common.requests.{DeleteTopicsRequest, DeleteTopicsResponse, FetchRequest, FetchResponse, MetadataRequest, MetadataResponse} +import org.junit.jupiter.api.Assertions.{assertEquals, assertTrue} +import org.junit.jupiter.api.{BeforeEach, Test, TestInfo} + +import scala.collection.Seq +import scala.jdk.CollectionConverters._ + +class TopicIdWithOldInterBrokerProtocolTest extends BaseRequestTest { + + override def brokerPropertyOverrides(properties: Properties): Unit = { + properties.setProperty(KafkaConfig.InterBrokerProtocolVersionProp, KAFKA_2_7_IV0.toString) + properties.setProperty(KafkaConfig.OffsetsTopicPartitionsProp, "1") + properties.setProperty(KafkaConfig.DefaultReplicationFactorProp, "2") + properties.setProperty(KafkaConfig.RackProp, s"rack/${properties.getProperty(KafkaConfig.BrokerIdProp)}") + } + + @BeforeEach + override def setUp(testInfo: TestInfo): Unit = { + doSetup(testInfo, createOffsetsTopic = false) + } + + @Test + def testMetadataTopicIdsWithOldIBP(): Unit = { + val replicaAssignment = Map(0 -> Seq(1, 2, 0), 1 -> Seq(2, 0, 1)) + val topic1 = "topic1" + createTopic(topic1, replicaAssignment) + + val resp = sendMetadataRequest(new MetadataRequest.Builder(Seq(topic1, topic1).asJava, true, 10, 10).build(), Some(notControllerSocketServer)) + assertEquals(1, resp.topicMetadata.size) + resp.topicMetadata.forEach { topicMetadata => + assertEquals(Errors.NONE, topicMetadata.error) + assertEquals(Uuid.ZERO_UUID, topicMetadata.topicId()) + } + } + + // This also simulates sending a fetch to a broker that is still in the process of updating. 
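An editorial aside, not part of the patch: the failure expected in the next test arises because fetch versions 13 and above identify partitions by topic ID alone, so a broker whose metadata was written under an old inter-broker protocol has no ID-to-name mapping and reports UNKNOWN_TOPIC_ID. A toy model of that lookup, with purely illustrative names:

    import java.util.UUID

    object TopicIdResolutionSketch {
      // A broker still on an old inter-broker protocol has an empty id -> name mapping.
      def resolve(topicIdToName: Map[UUID, String], requestedId: UUID): String =
        topicIdToName.get(requestedId) match {
          case Some(name) => s"resolved to $name"
          case None => "UNKNOWN_TOPIC_ID"
        }

      def main(args: Array[String]): Unit = {
        val id = UUID.randomUUID()
        println(resolve(Map.empty, id))           // UNKNOWN_TOPIC_ID: no topic-ID metadata yet
        println(resolve(Map(id -> "topic1"), id)) // resolved to topic1
      }
    }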
+ @Test + def testFetchTopicIdsWithOldIBPWrongFetchVersion(): Unit = { + val replicaAssignment = Map(0 -> Seq(1, 2, 0), 1 -> Seq(2, 0, 1)) + val topic1 = "topic1" + val tp0 = new TopicPartition("topic1", 0) + val maxResponseBytes = 800 + val maxPartitionBytes = 190 + val topicIds = Map("topic1" -> Uuid.randomUuid()) + val topicNames = topicIds.map(_.swap) + val tidp0 = new TopicIdPartition(topicIds(topic1), tp0) + + val leadersMap = createTopic(topic1, replicaAssignment) + val req = createFetchRequest(maxResponseBytes, maxPartitionBytes, Seq(tidp0), Map.empty, ApiKeys.FETCH.latestVersion()) + val resp = sendFetchRequest(leadersMap(0), req) + + val responseData = resp.responseData(topicNames.asJava, ApiKeys.FETCH.latestVersion()) + assertEquals(Errors.NONE.code, resp.error().code()) + assertEquals(1, responseData.size()) + assertEquals(Errors.UNKNOWN_TOPIC_ID.code, responseData.get(tp0).errorCode) + } + + @Test + def testFetchTopicIdsWithOldIBPCorrectFetchVersion(): Unit = { + val replicaAssignment = Map(0 -> Seq(1, 2, 0), 1 -> Seq(2, 0, 1)) + val topic1 = "topic1" + val tp0 = new TopicPartition("topic1", 0) + val maxResponseBytes = 800 + val maxPartitionBytes = 190 + val topicIds = Map("topic1" -> Uuid.randomUuid()) + val topicNames = topicIds.map(_.swap) + val tidp0 = new TopicIdPartition(topicIds(topic1), tp0) + + val leadersMap = createTopic(topic1, replicaAssignment) + val req = createFetchRequest(maxResponseBytes, maxPartitionBytes, Seq(tidp0), Map.empty, 12) + val resp = sendFetchRequest(leadersMap(0), req) + + assertEquals(Errors.NONE, resp.error()) + + val responseData = resp.responseData(topicNames.asJava, 12) + assertEquals(Errors.NONE.code, responseData.get(tp0).errorCode); + } + + @Test + def testDeleteTopicsWithOldIBP(): Unit = { + val timeout = 10000 + createTopic("topic-3", 5, 2) + createTopic("topic-4", 1, 2) + val request = new DeleteTopicsRequest.Builder( + new DeleteTopicsRequestData() + .setTopicNames(Arrays.asList("topic-3", "topic-4")) + .setTimeoutMs(timeout)).build() + val resp = sendDeleteTopicsRequest(request) + val error = resp.errorCounts.asScala.find(_._1 != Errors.NONE) + assertTrue(error.isEmpty, s"There should be no errors, found ${resp.data.responses.asScala}") + request.data.topicNames.forEach { topic => + validateTopicIsDeleted(topic) + } + resp.data.responses.forEach { response => + assertEquals(Uuid.ZERO_UUID, response.topicId()) + } + } + + @Test + def testDeleteTopicsWithOldIBPUsingIDs(): Unit = { + val timeout = 10000 + createTopic("topic-7", 3, 2) + createTopic("topic-6", 1, 2) + val ids = Map("topic-7" -> Uuid.randomUuid(), "topic-6" -> Uuid.randomUuid()) + val request = new DeleteTopicsRequest.Builder( + new DeleteTopicsRequestData() + .setTopics(Arrays.asList(new DeleteTopicState().setTopicId(ids("topic-7")), + new DeleteTopicState().setTopicId(ids("topic-6")) + )).setTimeoutMs(timeout)).build() + val response = sendDeleteTopicsRequest(request) + val error = response.errorCounts.asScala + assertEquals(2, error(Errors.UNKNOWN_TOPIC_ID)) + } + + private def sendMetadataRequest(request: MetadataRequest, destination: Option[SocketServer]): MetadataResponse = { + connectAndReceive[MetadataResponse](request, destination = destination.getOrElse(anySocketServer)) + } + + private def createFetchRequest(maxResponseBytes: Int, maxPartitionBytes: Int, topicPartitions: Seq[TopicIdPartition], + offsetMap: Map[TopicPartition, Long], + version: Short): FetchRequest = { + FetchRequest.Builder.forConsumer(version, Int.MaxValue, 0, 
createPartitionMap(maxPartitionBytes, topicPartitions, offsetMap)) + .setMaxBytes(maxResponseBytes).build() + } + + private def createPartitionMap(maxPartitionBytes: Int, topicPartitions: Seq[TopicIdPartition], + offsetMap: Map[TopicPartition, Long]): LinkedHashMap[TopicPartition, FetchRequest.PartitionData] = { + val partitionMap = new LinkedHashMap[TopicPartition, FetchRequest.PartitionData] + topicPartitions.foreach { tp => + partitionMap.put(tp.topicPartition, new FetchRequest.PartitionData(tp.topicId, offsetMap.getOrElse(tp.topicPartition, 0), 0L, maxPartitionBytes, + Optional.empty())) + } + partitionMap + } + + private def sendFetchRequest(leaderId: Int, request: FetchRequest): FetchResponse = { + connectAndReceive[FetchResponse](request, destination = brokerSocketServer(leaderId)) + } + + private def sendDeleteTopicsRequest(request: DeleteTopicsRequest, socketServer: SocketServer = controllerSocketServer): DeleteTopicsResponse = { + connectAndReceive[DeleteTopicsResponse](request, destination = socketServer) + } + + private def validateTopicIsDeleted(topic: String): Unit = { + val metadata = connectAndReceive[MetadataResponse](new MetadataRequest.Builder( + List(topic).asJava, true).build).topicMetadata.asScala + TestUtils.waitUntilTrue (() => !metadata.exists(p => p.topic.equals(topic) && p.error == Errors.NONE), + s"The topic $topic should not exist") + } + +} diff --git a/core/src/test/scala/unit/kafka/server/UpdateFeaturesTest.scala b/core/src/test/scala/unit/kafka/server/UpdateFeaturesTest.scala new file mode 100644 index 0000000..92ba042 --- /dev/null +++ b/core/src/test/scala/unit/kafka/server/UpdateFeaturesTest.scala @@ -0,0 +1,577 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package kafka.server + +import java.util.{Optional, Properties} +import java.util.concurrent.ExecutionException + +import kafka.api.KAFKA_2_7_IV0 +import kafka.utils.TestUtils +import kafka.zk.{FeatureZNode, FeatureZNodeStatus, ZkVersion} +import kafka.utils.TestUtils.waitUntilTrue +import org.apache.kafka.clients.admin.{Admin, FeatureUpdate, UpdateFeaturesOptions, UpdateFeaturesResult} +import org.apache.kafka.common.errors.InvalidRequestException +import org.apache.kafka.common.feature.FinalizedVersionRange +import org.apache.kafka.common.feature.{Features, SupportedVersionRange} +import org.apache.kafka.common.message.UpdateFeaturesRequestData +import org.apache.kafka.common.message.UpdateFeaturesRequestData.FeatureUpdateKeyCollection +import org.apache.kafka.common.protocol.Errors +import org.apache.kafka.common.requests.{UpdateFeaturesRequest, UpdateFeaturesResponse} +import org.apache.kafka.common.utils.Utils +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.Assertions.{assertEquals, assertFalse, assertNotEquals, assertNotNull, assertTrue, assertThrows} + +import scala.jdk.CollectionConverters._ +import scala.reflect.ClassTag +import scala.util.matching.Regex + +class UpdateFeaturesTest extends BaseRequestTest { + + override def brokerCount = 3 + + override def brokerPropertyOverrides(props: Properties): Unit = { + props.put(KafkaConfig.InterBrokerProtocolVersionProp, KAFKA_2_7_IV0.toString) + } + + private def defaultSupportedFeatures(): Features[SupportedVersionRange] = { + Features.supportedFeatures(Utils.mkMap(Utils.mkEntry("feature_1", new SupportedVersionRange(1, 3)))) + } + + private def defaultFinalizedFeatures(): Features[FinalizedVersionRange] = { + Features.finalizedFeatures(Utils.mkMap(Utils.mkEntry("feature_1", new FinalizedVersionRange(1, 2)))) + } + + private def updateSupportedFeatures( + features: Features[SupportedVersionRange], targetServers: Set[KafkaServer]): Unit = { + targetServers.foreach(s => { + s.brokerFeatures.setSupportedFeatures(features) + s.zkClient.updateBrokerInfo(s.createBrokerInfo) + }) + + // Wait until updates to all BrokerZNode supported features propagate to the controller. 
+ val brokerIds = targetServers.map(s => s.config.brokerId) + waitUntilTrue( + () => servers.exists(s => { + if (s.kafkaController.isActive) { + s.kafkaController.controllerContext.liveOrShuttingDownBrokers + .filter(b => brokerIds.contains(b.id)) + .forall(b => { + b.features.equals(features) + }) + } else { + false + } + }), + "Controller did not get broker updates") + } + + private def updateSupportedFeaturesInAllBrokers(features: Features[SupportedVersionRange]): Unit = { + updateSupportedFeatures(features, Set[KafkaServer]() ++ servers) + } + + private def updateFeatureZNode(features: Features[FinalizedVersionRange]): Int = { + val server = serverForId(0).get + val newNode = new FeatureZNode(FeatureZNodeStatus.Enabled, features) + val newVersion = server.zkClient.updateFeatureZNode(newNode) + servers.foreach(s => { + s.featureCache.waitUntilEpochOrThrow(newVersion, s.config.zkConnectionTimeoutMs) + }) + newVersion + } + + private def getFeatureZNode(): FeatureZNode = { + val (mayBeFeatureZNodeBytes, version) = serverForId(0).get.zkClient.getDataAndVersion(FeatureZNode.path) + assertNotEquals(version, ZkVersion.UnknownVersion) + FeatureZNode.decode(mayBeFeatureZNodeBytes.get) + } + + private def finalizedFeatures(features: java.util.Map[String, org.apache.kafka.clients.admin.FinalizedVersionRange]): Features[FinalizedVersionRange] = { + Features.finalizedFeatures(features.asScala.map { + case(name, versionRange) => + (name, new FinalizedVersionRange(versionRange.minVersionLevel(), versionRange.maxVersionLevel())) + }.asJava) + } + + private def supportedFeatures(features: java.util.Map[String, org.apache.kafka.clients.admin.SupportedVersionRange]): Features[SupportedVersionRange] = { + Features.supportedFeatures(features.asScala.map { + case(name, versionRange) => + (name, new SupportedVersionRange(versionRange.minVersion(), versionRange.maxVersion())) + }.asJava) + } + + private def checkFeatures(client: Admin, + expectedNode: FeatureZNode, + expectedFinalizedFeatures: Features[FinalizedVersionRange], + expectedFinalizedFeaturesEpoch: Long, + expectedSupportedFeatures: Features[SupportedVersionRange]): Unit = { + assertEquals(expectedNode, getFeatureZNode()) + val featureMetadata = client.describeFeatures.featureMetadata.get + assertEquals(expectedFinalizedFeatures, finalizedFeatures(featureMetadata.finalizedFeatures)) + assertEquals(expectedSupportedFeatures, supportedFeatures(featureMetadata.supportedFeatures)) + assertEquals(Optional.of(expectedFinalizedFeaturesEpoch), featureMetadata.finalizedFeaturesEpoch) + } + + private def checkException[ExceptionType <: Throwable](result: UpdateFeaturesResult, + featureExceptionMsgPatterns: Map[String, Regex]) + (implicit tag: ClassTag[ExceptionType]): Unit = { + featureExceptionMsgPatterns.foreach { + case (feature, exceptionMsgPattern) => + val exception = assertThrows(classOf[ExecutionException], () => result.values().get(feature).get()) + val cause = exception.getCause + assertNotNull(cause) + assertEquals(cause.getClass, tag.runtimeClass) + assertTrue(exceptionMsgPattern.findFirstIn(cause.getMessage).isDefined, + s"Received unexpected error message: ${cause.getMessage}") + } + } + + /** + * Tests whether an invalid feature update does not get processed on the server as expected, + * and raises the ExceptionType on the client side as expected. 
+ * + * @param feature the feature to be updated + * @param invalidUpdate the invalid feature update to be sent in the + * updateFeatures request to the server + * @param exceptionMsgPattern a pattern for the expected exception message + */ + private def testWithInvalidFeatureUpdate[ExceptionType <: Throwable](feature: String, + invalidUpdate: FeatureUpdate, + exceptionMsgPattern: Regex) + (implicit tag: ClassTag[ExceptionType]): Unit = { + TestUtils.waitUntilControllerElected(zkClient) + + updateSupportedFeaturesInAllBrokers(defaultSupportedFeatures()) + val versionBefore = updateFeatureZNode(defaultFinalizedFeatures()) + val adminClient = createAdminClient() + val nodeBefore = getFeatureZNode() + + val result = adminClient.updateFeatures(Utils.mkMap(Utils.mkEntry(feature, invalidUpdate)), new UpdateFeaturesOptions()) + + checkException[ExceptionType](result, Map(feature -> exceptionMsgPattern)) + checkFeatures( + adminClient, + nodeBefore, + defaultFinalizedFeatures(), + versionBefore, + defaultSupportedFeatures()) + } + + /** + * Tests that an UpdateFeatures request sent to a non-Controller node fails as expected. + */ + @Test + def testShouldFailRequestIfNotController(): Unit = { + TestUtils.waitUntilControllerElected(zkClient) + + updateSupportedFeaturesInAllBrokers(defaultSupportedFeatures()) + val versionBefore = updateFeatureZNode(defaultFinalizedFeatures()) + + val nodeBefore = getFeatureZNode() + val validUpdates = new FeatureUpdateKeyCollection() + val validUpdate = new UpdateFeaturesRequestData.FeatureUpdateKey(); + validUpdate.setFeature("feature_1"); + validUpdate.setMaxVersionLevel(defaultSupportedFeatures().get("feature_1").max()) + validUpdate.setAllowDowngrade(false) + validUpdates.add(validUpdate) + + val response = connectAndReceive[UpdateFeaturesResponse]( + new UpdateFeaturesRequest.Builder(new UpdateFeaturesRequestData().setFeatureUpdates(validUpdates)).build(), + notControllerSocketServer) + + assertEquals(Errors.NOT_CONTROLLER, Errors.forCode(response.data.errorCode)) + assertNotNull(response.data.errorMessage()) + assertEquals(0, response.data.results.size) + checkFeatures( + createAdminClient(), + nodeBefore, + defaultFinalizedFeatures(), + versionBefore, + defaultSupportedFeatures()) + } + + /** + * Tests that an UpdateFeatures request fails in the Controller, when, for a feature the + * allowDowngrade flag is not set during a downgrade request. + */ + @Test + def testShouldFailRequestWhenDowngradeFlagIsNotSetDuringDowngrade(): Unit = { + val targetMaxVersionLevel = (defaultFinalizedFeatures().get("feature_1").max() - 1).asInstanceOf[Short] + testWithInvalidFeatureUpdate[InvalidRequestException]( + "feature_1", + new FeatureUpdate(targetMaxVersionLevel,false), + ".*Can not downgrade finalized feature.*allowDowngrade.*".r) + } + + /** + * Tests that an UpdateFeatures request fails in the Controller, when, for a feature the downgrade + * is attempted to a max version level higher than the existing max version level. 
+ */ + @Test + def testShouldFailRequestWhenDowngradeToHigherVersionLevelIsAttempted(): Unit = { + val targetMaxVersionLevel = (defaultFinalizedFeatures().get("feature_1").max() + 1).asInstanceOf[Short] + testWithInvalidFeatureUpdate[InvalidRequestException]( + "feature_1", + new FeatureUpdate(targetMaxVersionLevel, true), + ".*When the allowDowngrade flag set in the request, the provided maxVersionLevel:3.*existing maxVersionLevel:2.*".r) + } + + /** + * Tests that an UpdateFeatures request fails in the Controller, when, a feature deletion is + * attempted without setting the allowDowngrade flag. + */ + @Test + def testShouldFailRequestInServerWhenDowngradeFlagIsNotSetDuringDeletion(): Unit = { + TestUtils.waitUntilControllerElected(zkClient) + + updateSupportedFeaturesInAllBrokers(defaultSupportedFeatures()) + val versionBefore = updateFeatureZNode(defaultFinalizedFeatures()) + + val adminClient = createAdminClient() + val nodeBefore = getFeatureZNode() + + val invalidUpdates + = new UpdateFeaturesRequestData.FeatureUpdateKeyCollection(); + val invalidUpdate = new UpdateFeaturesRequestData.FeatureUpdateKey(); + invalidUpdate.setFeature("feature_1") + invalidUpdate.setMaxVersionLevel(0) + invalidUpdate.setAllowDowngrade(false) + invalidUpdates.add(invalidUpdate); + val requestData = new UpdateFeaturesRequestData() + requestData.setFeatureUpdates(invalidUpdates); + + val response = connectAndReceive[UpdateFeaturesResponse]( + new UpdateFeaturesRequest.Builder(new UpdateFeaturesRequestData().setFeatureUpdates(invalidUpdates)).build(), + controllerSocketServer) + + assertEquals(1, response.data().results().size()) + val result = response.data.results.asScala.head + assertEquals("feature_1", result.feature) + assertEquals(Errors.INVALID_REQUEST, Errors.forCode(result.errorCode)) + assertNotNull(result.errorMessage) + assertFalse(result.errorMessage.isEmpty) + val exceptionMsgPattern = ".*Can not provide maxVersionLevel: 0 less than 1.*allowDowngrade.*".r + assertTrue(exceptionMsgPattern.findFirstIn(result.errorMessage).isDefined, result.errorMessage) + checkFeatures( + adminClient, + nodeBefore, + defaultFinalizedFeatures(), + versionBefore, + defaultSupportedFeatures()) + } + + /** + * Tests that an UpdateFeatures request fails in the Controller, when, a feature version level + * upgrade is attempted for a non-existing feature. + */ + @Test + def testShouldFailRequestDuringDeletionOfNonExistingFeature(): Unit = { + testWithInvalidFeatureUpdate[InvalidRequestException]( + "feature_non_existing", + new FeatureUpdate(3, true), + ".*Could not apply finalized feature update because the provided feature is not supported.*".r) + } + + /** + * Tests that an UpdateFeatures request fails in the Controller, when, a feature version level + * upgrade is attempted to a version level same as the existing max version level. 
+ */ + @Test + def testShouldFailRequestWhenUpgradingToSameVersionLevel(): Unit = { + val targetMaxVersionLevel = defaultFinalizedFeatures().get("feature_1").max() + testWithInvalidFeatureUpdate[InvalidRequestException]( + "feature_1", + new FeatureUpdate(targetMaxVersionLevel, false), + ".*Can not upgrade a finalized feature.*to the same value.*".r) + } + + private def testShouldFailRequestDuringBrokerMaxVersionLevelIncompatibility( + featureName: String, + supportedVersionRange: SupportedVersionRange, + initialFinalizedVersionRange: Option[FinalizedVersionRange] + ): Unit = { + TestUtils.waitUntilControllerElected(zkClient) + + val controller = servers.filter { server => server.kafkaController.isActive}.head + val nonControllerServers = servers.filter { server => !server.kafkaController.isActive} + // We setup the supported features on the broker such that 1/3 of the brokers does not + // support an expected feature version, while 2/3 brokers support the expected feature + // version. + val brokersWithVersionIncompatibility = Set[KafkaServer](nonControllerServers.head) + val versionCompatibleBrokers = Set[KafkaServer](nonControllerServers(1), controller) + + val supportedFeatures = Features.supportedFeatures(Utils.mkMap(Utils.mkEntry(featureName, supportedVersionRange))) + updateSupportedFeatures(supportedFeatures, versionCompatibleBrokers) + + val unsupportedMaxVersion = (supportedVersionRange.max() - 1).asInstanceOf[Short] + val supportedFeaturesWithVersionIncompatibility = Features.supportedFeatures( + Utils.mkMap( + Utils.mkEntry("feature_1", + new SupportedVersionRange( + supportedVersionRange.min(), + unsupportedMaxVersion)))) + updateSupportedFeatures(supportedFeaturesWithVersionIncompatibility, brokersWithVersionIncompatibility) + + val initialFinalizedFeatures = initialFinalizedVersionRange.map( + versionRange => Features.finalizedFeatures(Utils.mkMap(Utils.mkEntry(featureName, versionRange))) + ).getOrElse(Features.emptyFinalizedFeatures()) + val versionBefore = updateFeatureZNode(initialFinalizedFeatures) + + val invalidUpdate = new FeatureUpdate(supportedVersionRange.max(), false) + val nodeBefore = getFeatureZNode() + val adminClient = createAdminClient() + val result = adminClient.updateFeatures( + Utils.mkMap(Utils.mkEntry("feature_1", invalidUpdate)), + new UpdateFeaturesOptions()) + + checkException[InvalidRequestException](result, Map("feature_1" -> ".*brokers.*incompatible.*".r)) + checkFeatures( + adminClient, + nodeBefore, + initialFinalizedFeatures, + versionBefore, + supportedFeatures) + } + + /** + * Tests that an UpdateFeatures request fails in the Controller, when for an existing finalized + * feature, a version level upgrade introduces a version incompatibility with existing supported + * features. + */ + @Test + def testShouldFailRequestDuringBrokerMaxVersionLevelIncompatibilityForExistingFinalizedFeature(): Unit = { + val feature = "feature_1" + testShouldFailRequestDuringBrokerMaxVersionLevelIncompatibility( + feature, + defaultSupportedFeatures().get(feature), + Some(defaultFinalizedFeatures().get(feature))) + } + + /** + * Tests that an UpdateFeatures request fails in the Controller, when for a non-existing finalized + * feature, a version level upgrade introduces a version incompatibility with existing supported + * features. 
+ */ + @Test + def testShouldFailRequestDuringBrokerMaxVersionLevelIncompatibilityWithNoExistingFinalizedFeature(): Unit = { + val feature = "feature_1" + testShouldFailRequestDuringBrokerMaxVersionLevelIncompatibility( + feature, + defaultSupportedFeatures().get(feature), + Option.empty) + } + + /** + * Tests that an UpdateFeatures request succeeds in the Controller, when, there are no existing + * finalized features in FeatureZNode when the test starts. + */ + @Test + def testSuccessfulFeatureUpgradeAndWithNoExistingFinalizedFeatures(): Unit = { + TestUtils.waitUntilControllerElected(zkClient) + + val supportedFeatures = + Features.supportedFeatures( + Utils.mkMap( + Utils.mkEntry("feature_1", new SupportedVersionRange(1, 3)), + Utils.mkEntry("feature_2", new SupportedVersionRange(2, 5)))) + updateSupportedFeaturesInAllBrokers(supportedFeatures) + val versionBefore = updateFeatureZNode(Features.emptyFinalizedFeatures()) + + val targetFinalizedFeatures = Features.finalizedFeatures( + Utils.mkMap( + Utils.mkEntry("feature_1", new FinalizedVersionRange(1, 3)), + Utils.mkEntry("feature_2", new FinalizedVersionRange(2, 3)))) + val update1 = new FeatureUpdate(targetFinalizedFeatures.get("feature_1").max(), false) + val update2 = new FeatureUpdate(targetFinalizedFeatures.get("feature_2").max(), false) + + val adminClient = createAdminClient() + adminClient.updateFeatures( + Utils.mkMap(Utils.mkEntry("feature_1", update1), Utils.mkEntry("feature_2", update2)), + new UpdateFeaturesOptions() + ).all().get() + + checkFeatures( + adminClient, + new FeatureZNode(FeatureZNodeStatus.Enabled, targetFinalizedFeatures), + targetFinalizedFeatures, + versionBefore + 1, + supportedFeatures) + } + + /** + * Tests that an UpdateFeatures request succeeds in the Controller, when, the request contains + * both a valid feature version level upgrade as well as a downgrade request. 
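+   *
+   * Roughly, the client-side call shape exercised below is (a sketch; the version numbers shown
+   * are the ones this test uses):
+   * {{{
+   *   adminClient.updateFeatures(
+   *     Utils.mkMap(
+   *       Utils.mkEntry("feature_1", new FeatureUpdate(3, false)),  // plain upgrade
+   *       Utils.mkEntry("feature_2", new FeatureUpdate(3, true))),  // downgrade, explicitly allowed
+   *     new UpdateFeaturesOptions()).all().get()
+   * }}}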
+ */ + @Test + def testSuccessfulFeatureUpgradeAndDowngrade(): Unit = { + TestUtils.waitUntilControllerElected(zkClient) + + val supportedFeatures = Features.supportedFeatures( + Utils.mkMap( + Utils.mkEntry("feature_1", new SupportedVersionRange(1, 3)), + Utils.mkEntry("feature_2", new SupportedVersionRange(2, 5)))) + updateSupportedFeaturesInAllBrokers(supportedFeatures) + val initialFinalizedFeatures = Features.finalizedFeatures( + Utils.mkMap( + Utils.mkEntry("feature_1", new FinalizedVersionRange(1, 2)), + Utils.mkEntry("feature_2", new FinalizedVersionRange(2, 4)))) + val versionBefore = updateFeatureZNode(initialFinalizedFeatures) + + // Below we aim to do the following: + // - Valid upgrade of feature_1 maxVersionLevel from 2 to 3 + // - Valid downgrade of feature_2 maxVersionLevel from 4 to 3 + val targetFinalizedFeatures = Features.finalizedFeatures( + Utils.mkMap( + Utils.mkEntry("feature_1", new FinalizedVersionRange(1, 3)), + Utils.mkEntry("feature_2", new FinalizedVersionRange(2, 3)))) + val update1 = new FeatureUpdate(targetFinalizedFeatures.get("feature_1").max(), false) + val update2 = new FeatureUpdate(targetFinalizedFeatures.get("feature_2").max(), true) + + val adminClient = createAdminClient() + adminClient.updateFeatures( + Utils.mkMap(Utils.mkEntry("feature_1", update1), Utils.mkEntry("feature_2", update2)), + new UpdateFeaturesOptions() + ).all().get() + + checkFeatures( + adminClient, + new FeatureZNode(FeatureZNodeStatus.Enabled, targetFinalizedFeatures), + targetFinalizedFeatures, + versionBefore + 1, + supportedFeatures) + } + + /** + * Tests that an UpdateFeatures request succeeds partially in the Controller, when, the request + * contains a valid feature version level upgrade and an invalid version level downgrade. + * i.e. expect the upgrade operation to succeed, and the downgrade operation to fail. + */ + @Test + def testPartialSuccessDuringValidFeatureUpgradeAndInvalidDowngrade(): Unit = { + TestUtils.waitUntilControllerElected(zkClient) + + val supportedFeatures = Features.supportedFeatures( + Utils.mkMap( + Utils.mkEntry("feature_1", new SupportedVersionRange(1, 3)), + Utils.mkEntry("feature_2", new SupportedVersionRange(2, 5)))) + updateSupportedFeaturesInAllBrokers(supportedFeatures) + val initialFinalizedFeatures = Features.finalizedFeatures( + Utils.mkMap( + Utils.mkEntry("feature_1", new FinalizedVersionRange(1, 2)), + Utils.mkEntry("feature_2", new FinalizedVersionRange(2, 4)))) + val versionBefore = updateFeatureZNode(initialFinalizedFeatures) + + // Below we aim to do the following: + // - Valid upgrade of feature_1 maxVersionLevel from 2 to 3 + // - Invalid downgrade of feature_2 maxVersionLevel from 4 to 3 + // (because we intentionally do not set the allowDowngrade flag) + val targetFinalizedFeatures = Features.finalizedFeatures( + Utils.mkMap( + Utils.mkEntry("feature_1", new FinalizedVersionRange(1, 3)), + Utils.mkEntry("feature_2", new FinalizedVersionRange(2, 3)))) + val validUpdate = new FeatureUpdate(targetFinalizedFeatures.get("feature_1").max(), false) + val invalidUpdate = new FeatureUpdate(targetFinalizedFeatures.get("feature_2").max(), false) + + val adminClient = createAdminClient() + val result = adminClient.updateFeatures( + Utils.mkMap(Utils.mkEntry("feature_1", validUpdate), Utils.mkEntry("feature_2", invalidUpdate)), + new UpdateFeaturesOptions()) + + // Expect update for "feature_1" to have succeeded. + result.values().get("feature_1").get() + // Expect update for "feature_2" to have failed. 
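+    // (Sketch, mirroring testSuccessfulFeatureUpgradeAndDowngrade above: the same downgrade would
+    // have been accepted had the client opted in explicitly, i.e.
+    //   new FeatureUpdate(targetFinalizedFeatures.get("feature_2").max(), true)
+    // with the allowDowngrade flag set to true.)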
+ checkException[InvalidRequestException]( + result, Map("feature_2" -> ".*Can not downgrade finalized feature.*allowDowngrade.*".r)) + val expectedFeatures = Features.finalizedFeatures( + Utils.mkMap( + Utils.mkEntry("feature_1", targetFinalizedFeatures.get("feature_1")), + Utils.mkEntry("feature_2", initialFinalizedFeatures.get("feature_2")))) + checkFeatures( + adminClient, + FeatureZNode(FeatureZNodeStatus.Enabled, expectedFeatures), + expectedFeatures, + versionBefore + 1, + supportedFeatures) + } + + /** + * Tests that an UpdateFeatures request succeeds partially in the Controller, when, the request + * contains an invalid feature version level upgrade and a valid version level downgrade. + * i.e. expect the downgrade operation to succeed, and the upgrade operation to fail. + */ + @Test + def testPartialSuccessDuringInvalidFeatureUpgradeAndValidDowngrade(): Unit = { + TestUtils.waitUntilControllerElected(zkClient) + + val controller = servers.filter { server => server.kafkaController.isActive}.head + val nonControllerServers = servers.filter { server => !server.kafkaController.isActive} + // We setup the supported features on the broker such that 1/3 of the brokers does not + // support an expected feature version, while 2/3 brokers support the expected feature + // version. + val brokersWithVersionIncompatibility = Set[KafkaServer](nonControllerServers.head) + val versionCompatibleBrokers = Set[KafkaServer](nonControllerServers(1), controller) + + val supportedFeatures = Features.supportedFeatures( + Utils.mkMap( + Utils.mkEntry("feature_1", new SupportedVersionRange(1, 3)), + Utils.mkEntry("feature_2", new SupportedVersionRange(2, 5)))) + updateSupportedFeatures(supportedFeatures, versionCompatibleBrokers) + + val supportedFeaturesWithVersionIncompatibility = Features.supportedFeatures( + Utils.mkMap( + Utils.mkEntry("feature_1", new SupportedVersionRange(1, 2)), + Utils.mkEntry("feature_2", supportedFeatures.get("feature_2")))) + updateSupportedFeatures(supportedFeaturesWithVersionIncompatibility, brokersWithVersionIncompatibility) + + val initialFinalizedFeatures = Features.finalizedFeatures( + Utils.mkMap( + Utils.mkEntry("feature_1", new FinalizedVersionRange(1, 2)), + Utils.mkEntry("feature_2", new FinalizedVersionRange(2, 4)))) + val versionBefore = updateFeatureZNode(initialFinalizedFeatures) + + // Below we aim to do the following: + // - Invalid upgrade of feature_1 maxVersionLevel from 2 to 3 + // (because one of the brokers does not support the max version: 3) + // - Valid downgrade of feature_2 maxVersionLevel from 4 to 3 + val targetFinalizedFeatures = Features.finalizedFeatures( + Utils.mkMap( + Utils.mkEntry("feature_1", new FinalizedVersionRange(1, 3)), + Utils.mkEntry("feature_2", new FinalizedVersionRange(2, 3)))) + val invalidUpdate = new FeatureUpdate(targetFinalizedFeatures.get("feature_1").max(), false) + val validUpdate = new FeatureUpdate(targetFinalizedFeatures.get("feature_2").max(), true) + + val adminClient = createAdminClient() + val result = adminClient.updateFeatures( + Utils.mkMap(Utils.mkEntry("feature_1", invalidUpdate), Utils.mkEntry("feature_2", validUpdate)), + new UpdateFeaturesOptions()) + + // Expect update for "feature_2" to have succeeded. + result.values().get("feature_2").get() + // Expect update for "feature_1" to have failed. 
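+    // (Sketch of the compatibility rule exercised here, as suggested by the setup above: the
+    // controller only finalizes maxVersionLevel v for a feature if every live broker advertises a
+    // SupportedVersionRange whose max() is at least v; one broker advertises only [1, 2] for
+    // "feature_1", so the requested level 3 is rejected as incompatible.)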
+ checkException[InvalidRequestException](result, Map("feature_1" -> ".*brokers.*incompatible.*".r)) + val expectedFeatures = Features.finalizedFeatures( + Utils.mkMap( + Utils.mkEntry("feature_1", initialFinalizedFeatures.get("feature_1")), + Utils.mkEntry("feature_2", targetFinalizedFeatures.get("feature_2")))) + checkFeatures( + adminClient, + FeatureZNode(FeatureZNodeStatus.Enabled, expectedFeatures), + expectedFeatures, + versionBefore + 1, + supportedFeatures) + } +} diff --git a/core/src/test/scala/unit/kafka/server/ZkAdminManagerTest.scala b/core/src/test/scala/unit/kafka/server/ZkAdminManagerTest.scala new file mode 100644 index 0000000..3533133 --- /dev/null +++ b/core/src/test/scala/unit/kafka/server/ZkAdminManagerTest.scala @@ -0,0 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.server + +import kafka.zk.{AdminZkClient, KafkaZkClient} +import org.apache.kafka.common.metrics.Metrics +import org.easymock.EasyMock +import kafka.utils.TestUtils +import org.apache.kafka.common.config.ConfigResource +import org.apache.kafka.common.message.DescribeConfigsRequestData +import org.apache.kafka.common.message.DescribeConfigsResponseData +import org.apache.kafka.common.protocol.Errors +import org.junit.jupiter.api.{AfterEach, Test} +import org.junit.jupiter.api.Assertions.assertEquals +import org.junit.jupiter.api.Assertions.assertFalse +import org.junit.jupiter.api.Assertions.assertNotNull +import org.junit.jupiter.api.Assertions.assertNotEquals +import java.util.Properties + +import kafka.server.metadata.ZkConfigRepository + +import scala.jdk.CollectionConverters._ + +class ZkAdminManagerTest { + + private val zkClient: KafkaZkClient = EasyMock.createNiceMock(classOf[KafkaZkClient]) + private val metrics = new Metrics() + private val brokerId = 1 + private val topic = "topic-1" + private val metadataCache: MetadataCache = EasyMock.createNiceMock(classOf[MetadataCache]) + + @AfterEach + def tearDown(): Unit = { + metrics.close() + } + + def createConfigHelper(metadataCache: MetadataCache, zkClient: KafkaZkClient): ConfigHelper = { + val props = TestUtils.createBrokerConfig(brokerId, "zk") + new ConfigHelper(metadataCache, KafkaConfig.fromProps(props), new ZkConfigRepository(new AdminZkClient(zkClient))) + } + + @Test + def testDescribeConfigsWithNullConfigurationKeys(): Unit = { + EasyMock.expect(zkClient.getEntityConfigs(ConfigType.Topic, topic)).andReturn(TestUtils.createBrokerConfig(brokerId, "zk")) + EasyMock.expect(metadataCache.contains(topic)).andReturn(true) + + EasyMock.replay(zkClient, metadataCache) + + val resources = List(new DescribeConfigsRequestData.DescribeConfigsResource() + .setResourceName(topic) + .setResourceType(ConfigResource.Type.TOPIC.id) + .setConfigurationKeys(null)) + val configHelper 
= createConfigHelper(metadataCache, zkClient) + val results: List[DescribeConfigsResponseData.DescribeConfigsResult] = configHelper.describeConfigs(resources, true, true) + assertEquals(Errors.NONE.code, results.head.errorCode()) + assertFalse(results.head.configs().isEmpty, "Should return configs") + } + + @Test + def testDescribeConfigsWithEmptyConfigurationKeys(): Unit = { + EasyMock.expect(zkClient.getEntityConfigs(ConfigType.Topic, topic)).andReturn(TestUtils.createBrokerConfig(brokerId, "zk")) + EasyMock.expect(metadataCache.contains(topic)).andReturn(true) + + EasyMock.replay(zkClient, metadataCache) + + val resources = List(new DescribeConfigsRequestData.DescribeConfigsResource() + .setResourceName(topic) + .setResourceType(ConfigResource.Type.TOPIC.id)) + val configHelper = createConfigHelper(metadataCache, zkClient) + val results: List[DescribeConfigsResponseData.DescribeConfigsResult] = configHelper.describeConfigs(resources, true, true) + assertEquals(Errors.NONE.code, results.head.errorCode()) + assertFalse(results.head.configs().isEmpty, "Should return configs") + } + + @Test + def testDescribeConfigsWithConfigurationKeys(): Unit = { + EasyMock.expect(zkClient.getEntityConfigs(ConfigType.Topic, topic)).andReturn(TestUtils.createBrokerConfig(brokerId, "zk")) + EasyMock.expect(metadataCache.contains(topic)).andReturn(true) + + EasyMock.replay(zkClient, metadataCache) + + val resources = List(new DescribeConfigsRequestData.DescribeConfigsResource() + .setResourceName(topic) + .setResourceType(ConfigResource.Type.TOPIC.id) + .setConfigurationKeys(List("retention.ms", "retention.bytes", "segment.bytes").asJava) + ) + val configHelper = createConfigHelper(metadataCache, zkClient) + val results: List[DescribeConfigsResponseData.DescribeConfigsResult] = configHelper.describeConfigs(resources, true, true) + assertEquals(Errors.NONE.code, results.head.errorCode()) + val resultConfigKeys = results.head.configs().asScala.map(r => r.name()).toSet + assertEquals(Set("retention.ms", "retention.bytes", "segment.bytes"), resultConfigKeys) + } + + @Test + def testDescribeConfigsWithDocumentation(): Unit = { + EasyMock.expect(zkClient.getEntityConfigs(ConfigType.Topic, topic)).andReturn(new Properties) + EasyMock.expect(zkClient.getEntityConfigs(ConfigType.Broker, brokerId.toString)).andReturn(new Properties) + EasyMock.expect(metadataCache.contains(topic)).andReturn(true) + EasyMock.replay(zkClient, metadataCache) + + val configHelper = createConfigHelper(metadataCache, zkClient) + + val resources = List( + new DescribeConfigsRequestData.DescribeConfigsResource() + .setResourceName(topic) + .setResourceType(ConfigResource.Type.TOPIC.id), + new DescribeConfigsRequestData.DescribeConfigsResource() + .setResourceName(brokerId.toString) + .setResourceType(ConfigResource.Type.BROKER.id)) + + val results: List[DescribeConfigsResponseData.DescribeConfigsResult] = configHelper.describeConfigs(resources, true, true) + assertEquals(2, results.size) + results.foreach(r => { + assertEquals(Errors.NONE.code, r.errorCode) + assertFalse(r.configs.isEmpty, "Should return configs") + r.configs.forEach(c => { + assertNotNull(c.documentation, s"Config ${c.name} should have non null documentation") + assertNotEquals(s"Config ${c.name} should have non blank documentation", "", c.documentation.trim) + }) + }) + } +} diff --git a/core/src/test/scala/unit/kafka/server/checkpoints/LeaderEpochCheckpointFileWithFailureHandlerTest.scala 
b/core/src/test/scala/unit/kafka/server/checkpoints/LeaderEpochCheckpointFileWithFailureHandlerTest.scala new file mode 100644 index 0000000..5ac9202 --- /dev/null +++ b/core/src/test/scala/unit/kafka/server/checkpoints/LeaderEpochCheckpointFileWithFailureHandlerTest.scala @@ -0,0 +1,70 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package kafka.server.checkpoints + +import java.io.File + +import kafka.server.epoch.EpochEntry +import kafka.utils.Logging +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.Test + +class LeaderEpochCheckpointFileWithFailureHandlerTest extends Logging { + + @Test + def shouldPersistAndOverwriteAndReloadFile(): Unit ={ + val file = File.createTempFile("temp-checkpoint-file", System.nanoTime().toString) + file.deleteOnExit() + + val checkpoint = new LeaderEpochCheckpointFile(file) + + //Given + val epochs = Seq(EpochEntry(0, 1L), EpochEntry(1, 2L), EpochEntry(2, 3L)) + + //When + checkpoint.write(epochs) + + //Then + assertEquals(epochs, checkpoint.read()) + + //Given overwrite + val epochs2 = Seq(EpochEntry(3, 4L), EpochEntry(4, 5L)) + + //When + checkpoint.write(epochs2) + + //Then + assertEquals(epochs2, checkpoint.read()) + } + + @Test + def shouldRetainValuesEvenIfCheckpointIsRecreated(): Unit ={ + val file = File.createTempFile("temp-checkpoint-file", System.nanoTime().toString) + file.deleteOnExit() + + //Given a file with data in + val checkpoint = new LeaderEpochCheckpointFile(file) + val epochs = Seq(EpochEntry(0, 1L), EpochEntry(1, 2L), EpochEntry(2, 3L)) + checkpoint.write(epochs) + + //When we recreate + val checkpoint2 = new LeaderEpochCheckpointFile(file) + + //The data should still be there + assertEquals(epochs, checkpoint2.read()) + } +} diff --git a/core/src/test/scala/unit/kafka/server/checkpoints/OffsetCheckpointFileWithFailureHandlerTest.scala b/core/src/test/scala/unit/kafka/server/checkpoints/OffsetCheckpointFileWithFailureHandlerTest.scala new file mode 100644 index 0000000..4889c54 --- /dev/null +++ b/core/src/test/scala/unit/kafka/server/checkpoints/OffsetCheckpointFileWithFailureHandlerTest.scala @@ -0,0 +1,134 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package kafka.server.checkpoints + +import kafka.server.LogDirFailureChannel +import kafka.utils.{Logging, TestUtils} +import org.apache.kafka.common.TopicPartition +import org.apache.kafka.common.errors.KafkaStorageException +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.Test +import org.mockito.Mockito + +import scala.collection.Map + +class OffsetCheckpointFileWithFailureHandlerTest extends Logging { + + @Test + def shouldPersistAndOverwriteAndReloadFile(): Unit = { + + val checkpoint = new OffsetCheckpointFile(TestUtils.tempFile()) + + //Given + val offsets = Map(new TopicPartition("foo", 1) -> 5L, new TopicPartition("bar", 2) -> 10L) + + //When + checkpoint.write(offsets) + + //Then + assertEquals(offsets, checkpoint.read()) + + //Given overwrite + val offsets2 = Map(new TopicPartition("foo", 2) -> 15L, new TopicPartition("bar", 3) -> 20L) + + //When + checkpoint.write(offsets2) + + //Then + assertEquals(offsets2, checkpoint.read()) + } + + @Test + def shouldHandleMultipleLines(): Unit = { + + val checkpoint = new OffsetCheckpointFile(TestUtils.tempFile()) + + //Given + val offsets = Map( + new TopicPartition("foo", 1) -> 5L, new TopicPartition("bar", 6) -> 10L, + new TopicPartition("foo", 2) -> 5L, new TopicPartition("bar", 7) -> 10L, + new TopicPartition("foo", 3) -> 5L, new TopicPartition("bar", 8) -> 10L, + new TopicPartition("foo", 4) -> 5L, new TopicPartition("bar", 9) -> 10L, + new TopicPartition("foo", 5) -> 5L, new TopicPartition("bar", 10) -> 10L + ) + + //When + checkpoint.write(offsets) + + //Then + assertEquals(offsets, checkpoint.read()) + } + + @Test + def shouldReturnEmptyMapForEmptyFile(): Unit = { + + //When + val checkpoint = new OffsetCheckpointFile(TestUtils.tempFile()) + + //Then + assertEquals(Map(), checkpoint.read()) + + //When + checkpoint.write(Map()) + + //Then + assertEquals(Map(), checkpoint.read()) + } + + @Test + def shouldThrowIfVersionIsNotRecognised(): Unit = { + val file = TestUtils.tempFile() + val logDirFailureChannel = new LogDirFailureChannel(10) + val checkpointFile = new CheckpointFileWithFailureHandler(file, OffsetCheckpointFile.CurrentVersion + 1, + OffsetCheckpointFile.Formatter, logDirFailureChannel, file.getParent) + checkpointFile.write(Seq(new TopicPartition("foo", 5) -> 10L)) + assertThrows(classOf[KafkaStorageException], () => new OffsetCheckpointFile(checkpointFile.file, logDirFailureChannel).read()) + } + + @Test + def testLazyOffsetCheckpoint(): Unit = { + val logDir = "/tmp/kafka-logs" + val mockCheckpointFile = Mockito.mock(classOf[OffsetCheckpointFile]) + + val lazyCheckpoints = new LazyOffsetCheckpoints(Map(logDir -> mockCheckpointFile)) + Mockito.verify(mockCheckpointFile, Mockito.never()).read() + + val partition0 = new TopicPartition("foo", 0) + val partition1 = new TopicPartition("foo", 1) + val partition2 = new TopicPartition("foo", 2) + + Mockito.when(mockCheckpointFile.read()).thenReturn(Map( + partition0 -> 1000L, + partition1 -> 2000L + )) + + assertEquals(Some(1000L), lazyCheckpoints.fetch(logDir, partition0)) + assertEquals(Some(2000L), lazyCheckpoints.fetch(logDir, 
partition1)) + assertEquals(None, lazyCheckpoints.fetch(logDir, partition2)) + + Mockito.verify(mockCheckpointFile, Mockito.times(1)).read() + } + + @Test + def testLazyOffsetCheckpointFileInvalidLogDir(): Unit = { + val logDir = "/tmp/kafka-logs" + val mockCheckpointFile = Mockito.mock(classOf[OffsetCheckpointFile]) + val lazyCheckpoints = new LazyOffsetCheckpoints(Map(logDir -> mockCheckpointFile)) + assertThrows(classOf[IllegalArgumentException], () => lazyCheckpoints.fetch("/invalid/kafka-logs", new TopicPartition("foo", 0))) + } + +} diff --git a/core/src/test/scala/unit/kafka/server/epoch/EpochDrivenReplicationProtocolAcceptanceTest.scala b/core/src/test/scala/unit/kafka/server/epoch/EpochDrivenReplicationProtocolAcceptanceTest.scala new file mode 100644 index 0000000..72d2866 --- /dev/null +++ b/core/src/test/scala/unit/kafka/server/epoch/EpochDrivenReplicationProtocolAcceptanceTest.scala @@ -0,0 +1,472 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.server.epoch + +import java.io.{File, RandomAccessFile} +import java.util.Properties +import kafka.api.ApiVersion +import kafka.log.{UnifiedLog, LogLoader} +import kafka.server.KafkaConfig._ +import kafka.server.{KafkaConfig, KafkaServer} +import kafka.tools.DumpLogSegments +import kafka.utils.{CoreUtils, Logging, TestUtils} +import kafka.utils.TestUtils._ +import kafka.server.QuorumTestHarness +import org.apache.kafka.clients.consumer.{ConsumerConfig, KafkaConsumer} +import org.apache.kafka.clients.producer.{KafkaProducer, ProducerRecord} +import org.apache.kafka.common.TopicPartition +import org.apache.kafka.common.record.RecordBatch +import org.apache.kafka.common.serialization.ByteArrayDeserializer +import org.junit.jupiter.api.Assertions.{assertEquals, assertTrue} +import org.junit.jupiter.api.{AfterEach, BeforeEach, Test, TestInfo} + +import scala.jdk.CollectionConverters._ +import scala.collection.mutable.{ListBuffer => Buffer} +import scala.collection.Seq + +/** + * These tests were written to assert the addition of leader epochs to the replication protocol fix the problems + * described in KIP-101. + * + * https://cwiki.apache.org/confluence/display/KAFKA/KIP-101+-+Alter+Replication+Protocol+to+use+Leader+Epoch+rather+than+High+Watermark+for+Truncation + * + * A test which validates the end to end workflow is also included. 
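+ *
+ * In rough terms (background summary, not asserted by any single test): before KIP-101 a follower
+ * truncated its log to its high watermark when it came back, which after crashes or rapid leader
+ * changes could leave leader and follower with silently diverging logs, or with offsets appearing
+ * to go backwards. With leader epochs the follower instead asks the leader for the end offset of
+ * its latest epoch and truncates to that point, which is what these tests exercise end to end.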
+ */ +class EpochDrivenReplicationProtocolAcceptanceTest extends QuorumTestHarness with Logging { + + // Set this to KAFKA_0_11_0_IV1 to demonstrate the tests failing in the pre-KIP-101 case + val apiVersion = ApiVersion.latestVersion + val topic = "topic1" + val msg = new Array[Byte](1000) + val msgBigger = new Array[Byte](10000) + var brokers: Seq[KafkaServer] = null + var producer: KafkaProducer[Array[Byte], Array[Byte]] = null + var consumer: KafkaConsumer[Array[Byte], Array[Byte]] = null + + @BeforeEach + override def setUp(testInfo: TestInfo): Unit = { + super.setUp(testInfo) + } + + @AfterEach + override def tearDown(): Unit = { + producer.close() + TestUtils.shutdownServers(brokers) + super.tearDown() + } + + @Test + def shouldFollowLeaderEpochBasicWorkflow(): Unit = { + + //Given 2 brokers + brokers = (100 to 101).map(createBroker(_)) + + //A single partition topic with 2 replicas + TestUtils.createTopic(zkClient, topic, Map(0 -> Seq(100, 101)), brokers) + producer = createProducer + val tp = new TopicPartition(topic, 0) + + //When one record is written to the leader + producer.send(new ProducerRecord(topic, 0, null, msg)).get + + //The message should have epoch 0 stamped onto it in both leader and follower + assertEquals(0, latestRecord(leader).partitionLeaderEpoch) + assertEquals(0, latestRecord(follower).partitionLeaderEpoch) + + //Both leader and follower should have recorded Epoch 0 at Offset 0 + assertEquals(Buffer(EpochEntry(0, 0)), epochCache(leader).epochEntries) + assertEquals(Buffer(EpochEntry(0, 0)), epochCache(follower).epochEntries) + + //Bounce the follower + bounce(follower) + awaitISR(tp) + + //Nothing happens yet as we haven't sent any new messages. + assertEquals(Buffer(EpochEntry(0, 0), EpochEntry(1, 1)), epochCache(leader).epochEntries) + assertEquals(Buffer(EpochEntry(0, 0)), epochCache(follower).epochEntries) + + //Send a message + producer.send(new ProducerRecord(topic, 0, null, msg)).get + + //Epoch1 should now propagate to the follower with the written message + assertEquals(Buffer(EpochEntry(0, 0), EpochEntry(1, 1)), epochCache(leader).epochEntries) + assertEquals(Buffer(EpochEntry(0, 0), EpochEntry(1, 1)), epochCache(follower).epochEntries) + + //The new message should have epoch 1 stamped + assertEquals(1, latestRecord(leader).partitionLeaderEpoch()) + assertEquals(1, latestRecord(follower).partitionLeaderEpoch()) + + //Bounce the leader. Epoch -> 2 + bounce(leader) + awaitISR(tp) + + //Epochs 2 should be added to the leader, but not on the follower (yet), as there has been no replication. 
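+    // (For reference, a sketch of what the leader's leader-epoch-checkpoint file holds at this
+    // point, one "epoch startOffset" pair per entry, matching the cache contents asserted below:
+    //   0 0
+    //   1 1
+    //   2 2
+    // while the follower's checkpoint still ends at "1 1".)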
+ assertEquals(Buffer(EpochEntry(0, 0), EpochEntry(1, 1), EpochEntry(2, 2)), epochCache(leader).epochEntries) + assertEquals(Buffer(EpochEntry(0, 0), EpochEntry(1, 1)), epochCache(follower).epochEntries) + + //Send a message + producer.send(new ProducerRecord(topic, 0, null, msg)).get + + //This should case epoch 2 to propagate to the follower + assertEquals(2, latestRecord(leader).partitionLeaderEpoch()) + assertEquals(2, latestRecord(follower).partitionLeaderEpoch()) + + //The leader epoch files should now match on leader and follower + assertEquals(Buffer(EpochEntry(0, 0), EpochEntry(1, 1), EpochEntry(2, 2)), epochCache(leader).epochEntries) + assertEquals(Buffer(EpochEntry(0, 0), EpochEntry(1, 1), EpochEntry(2, 2)), epochCache(follower).epochEntries) + } + + @Test + def shouldNotAllowDivergentLogs(): Unit = { + //Given two brokers + brokers = (100 to 101).map { id => createServer(fromProps(createBrokerConfig(id, zkConnect))) } + val broker100 = brokers(0) + val broker101 = brokers(1) + + //A single partition topic with 2 replicas + TestUtils.createTopic(zkClient, topic, Map(0 -> Seq(100, 101)), brokers) + producer = createProducer + + //Write 10 messages (ensure they are not batched so we can truncate in the middle below) + (0 until 10).foreach { i => + producer.send(new ProducerRecord(topic, 0, s"$i".getBytes, msg)).get() + } + + //Stop the brokers (broker 101 first so that 100 is the leader) + broker101.shutdown() + broker100.shutdown() + + //Delete the clean shutdown file to simulate crash + new File(broker100.config.logDirs.head, LogLoader.CleanShutdownFile).delete() + + //Delete 5 messages from the leader's log on 100 + deleteMessagesFromLogFile(5 * msg.length, broker100, 0) + + //Restart broker 100 + broker100.startup() + + //Bounce the producer (this is required since the broker uses a random port) + producer.close() + producer = createProducer + + //Write ten additional messages + (11 until 20).map { i => + producer.send(new ProducerRecord(topic, 0, s"$i".getBytes, msg)) + }.foreach(_.get()) + + //Start broker 101 (we expect it to truncate to match broker 100's log) + broker101.startup() + + //Wait for replication to resync + waitForLogsToMatch(broker100, broker101) + + assertEquals(getLogFile(brokers(0), 0).length, getLogFile(brokers(1), 0).length, "Log files should match Broker0 vs Broker 1") + } + + //We can reproduce the pre-KIP-101 failure of this test by setting KafkaConfig.InterBrokerProtocolVersionProp = KAFKA_0_11_0_IV1 + @Test + def offsetsShouldNotGoBackwards(): Unit = { + + //Given two brokers + brokers = (100 to 101).map(createBroker(_)) + + //A single partition topic with 2 replicas + TestUtils.createTopic(zkClient, topic, Map(0 -> Seq(100, 101)), brokers) + producer = createBufferingProducer + + //Write 100 messages + (0 until 100).foreach { i => + producer.send(new ProducerRecord(topic, 0, null, msg)) + producer.flush() + } + + //Stop the brokers + brokers.foreach { b => b.shutdown() } + + //Delete the clean shutdown file to simulate crash + new File(brokers(0).config.logDirs(0), LogLoader.CleanShutdownFile).delete() + + //Delete half the messages from the log file + deleteMessagesFromLogFile(getLogFile(brokers(0), 0).length() / 2, brokers(0), 0) + + //Start broker 100 again + brokers(0).startup() + + //Bounce the producer (this is required, although I'm unsure as to why?) + producer.close() + producer = createBufferingProducer + + //Write two large batches of messages. 
This will ensure that the LeO of the follower's log aligns with the middle + //of the a compressed message set in the leader (which, when forwarded, will result in offsets going backwards) + (0 until 77).foreach { _ => + producer.send(new ProducerRecord(topic, 0, null, msg)) + } + producer.flush() + (0 until 77).foreach { _ => + producer.send(new ProducerRecord(topic, 0, null, msg)) + } + producer.flush() + + printSegments() + + //Start broker 101. When it comes up it should read a whole batch of messages from the leader. + //As the chronology is lost we would end up with non-monatonic offsets (pre kip-101) + brokers(1).startup() + + //Wait for replication to resync + waitForLogsToMatch(brokers(0), brokers(1)) + + printSegments() + + //Shut down broker 100, so we read from broker 101 which should have corrupted + brokers(0).shutdown() + + //Search to see if we have non-monotonic offsets in the log + startConsumer() + val records = TestUtils.pollUntilAtLeastNumRecords(consumer, 100) + var prevOffset = -1L + records.foreach { r => + assertTrue(r.offset > prevOffset, s"Offset $prevOffset came before ${r.offset} ") + prevOffset = r.offset + } + + //Are the files identical? + assertEquals(getLogFile(brokers(0), 0).length, getLogFile(brokers(1), 0).length, "Log files should match Broker0 vs Broker 1") + } + + /** + * Unlike the tests above, this test doesn't fail prior to the Leader Epoch Change. I was unable to find a deterministic + * method for recreating the fast leader change bug. + */ + @Test + def shouldSurviveFastLeaderChange(): Unit = { + val tp = new TopicPartition(topic, 0) + + //Given 2 brokers + brokers = (100 to 101).map(createBroker(_)) + + //A single partition topic with 2 replicas + TestUtils.createTopic(zkClient, topic, Map(0 -> Seq(100, 101)), brokers) + producer = createProducer + + //Kick off with a single record + producer.send(new ProducerRecord(topic, 0, null, msg)).get + var messagesWritten = 1 + + //Now invoke the fast leader change bug + (0 until 5).foreach { i => + val leaderId = zkClient.getLeaderForPartition(new TopicPartition(topic, 0)).get + val leader = brokers.filter(_.config.brokerId == leaderId)(0) + val follower = brokers.filter(_.config.brokerId != leaderId)(0) + + producer.send(new ProducerRecord(topic, 0, null, msg)).get + messagesWritten += 1 + + //As soon as it replicates, bounce the follower + bounce(follower) + + log(leader, follower) + awaitISR(tp) + + //Then bounce the leader + bounce(leader) + + log(leader, follower) + awaitISR(tp) + + //Ensure no data was lost + assertTrue(brokers.forall { broker => getLog(broker, 0).logEndOffset == messagesWritten }) + } + } + + @Test + def logsShouldNotDivergeOnUncleanLeaderElections(): Unit = { + + // Given two brokers, unclean leader election is enabled + brokers = (100 to 101).map(createBroker(_, enableUncleanLeaderElection = true)) + + // A single partition topic with 2 replicas, min.isr = 1 + TestUtils.createTopic(zkClient, topic, Map(0 -> Seq(100, 101)), brokers, + CoreUtils.propsWith((KafkaConfig.MinInSyncReplicasProp, "1"))) + + producer = TestUtils.createProducer(getBrokerListStrFromServers(brokers), acks = 1) + + // Write one message while both brokers are up + (0 until 1).foreach { i => + producer.send(new ProducerRecord(topic, 0, null, msg)) + producer.flush()} + + // Since we use producer with acks = 1, make sure that logs match for the first epoch + waitForLogsToMatch(brokers(0), brokers(1)) + + // shutdown broker 100 + brokers(0).shutdown() + + //Write 1 message + (0 until 1).foreach { i => + 
producer.send(new ProducerRecord(topic, 0, null, msg)) + producer.flush()} + + brokers(1).shutdown() + brokers(0).startup() + + //Bounce the producer (this is required, probably because the broker port changes on restart?) + producer.close() + producer = TestUtils.createProducer(getBrokerListStrFromServers(brokers), acks = 1) + + //Write 3 messages + (0 until 3).foreach { i => + producer.send(new ProducerRecord(topic, 0, null, msgBigger)) + producer.flush()} + + brokers(0).shutdown() + brokers(1).startup() + + //Bounce the producer (this is required, probably because the broker port changes on restart?) + producer.close() + producer = TestUtils.createProducer(getBrokerListStrFromServers(brokers), acks = 1) + + //Write 1 message + (0 until 1).foreach { i => + producer.send(new ProducerRecord(topic, 0, null, msg)) + producer.flush()} + + brokers(1).shutdown() + brokers(0).startup() + + //Bounce the producer (this is required, probably because the broker port changes on restart?) + producer.close() + producer = TestUtils.createProducer(getBrokerListStrFromServers(brokers), acks = 1) + + //Write 2 messages + (0 until 2).foreach { i => + producer.send(new ProducerRecord(topic, 0, null, msgBigger)) + producer.flush()} + + printSegments() + + brokers(1).startup() + + waitForLogsToMatch(brokers(0), brokers(1)) + printSegments() + + def crcSeq(broker: KafkaServer, partition: Int = 0): Seq[Long] = { + val batches = getLog(broker, partition).activeSegment.read(0, Integer.MAX_VALUE) + .records.batches().asScala.toSeq + batches.map(_.checksum) + } + assertTrue(crcSeq(brokers(0)) == crcSeq(brokers(1)), + s"Logs on Broker 100 and Broker 101 should match") + } + + private def log(leader: KafkaServer, follower: KafkaServer): Unit = { + info(s"Bounce complete for follower ${follower.config.brokerId}") + info(s"Leader: leo${leader.config.brokerId}: " + getLog(leader, 0).logEndOffset + " cache: " + epochCache(leader).epochEntries) + info(s"Follower: leo${follower.config.brokerId}: " + getLog(follower, 0).logEndOffset + " cache: " + epochCache(follower).epochEntries) + } + + private def waitForLogsToMatch(b1: KafkaServer, b2: KafkaServer, partition: Int = 0): Unit = { + TestUtils.waitUntilTrue(() => {getLog(b1, partition).logEndOffset == getLog(b2, partition).logEndOffset}, "Logs didn't match.") + } + + private def printSegments(): Unit = { + info("Broker0:") + DumpLogSegments.main(Seq("--files", getLogFile(brokers(0), 0).getCanonicalPath).toArray) + info("Broker1:") + DumpLogSegments.main(Seq("--files", getLogFile(brokers(1), 0).getCanonicalPath).toArray) + } + + private def startConsumer(): KafkaConsumer[Array[Byte], Array[Byte]] = { + val consumerConfig = new Properties() + consumerConfig.put(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, getBrokerListStrFromServers(brokers)) + consumerConfig.put(ConsumerConfig.FETCH_MAX_BYTES_CONFIG, String.valueOf(getLogFile(brokers(1), 0).length() * 2)) + consumerConfig.put(ConsumerConfig.MAX_PARTITION_FETCH_BYTES_CONFIG, String.valueOf(getLogFile(brokers(1), 0).length() * 2)) + consumer = new KafkaConsumer(consumerConfig, new ByteArrayDeserializer, new ByteArrayDeserializer) + consumer.assign(List(new TopicPartition(topic, 0)).asJava) + consumer.seek(new TopicPartition(topic, 0), 0) + consumer + } + + private def deleteMessagesFromLogFile(bytes: Long, broker: KafkaServer, partitionId: Int): Unit = { + val logFile = getLogFile(broker, partitionId) + val writable = new RandomAccessFile(logFile, "rwd") + writable.setLength(logFile.length() - bytes) + writable.close() + } + + 
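+  // Note on the helper below (intent inferred from offsetsShouldNotGoBackwards above): buffering
+  // roughly a thousand 1KB messages per batch and compressing them with snappy yields record
+  // batches that span many offsets, so after deleteMessagesFromLogFile truncates the leader's log,
+  // a follower's log end offset can land in the middle of a compressed batch, which is the
+  // pre-KIP-101 condition that test reproduces.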
private def createBufferingProducer: KafkaProducer[Array[Byte], Array[Byte]] = { + TestUtils.createProducer(getBrokerListStrFromServers(brokers), + acks = -1, + lingerMs = 10000, + batchSize = msg.length * 1000, + compressionType = "snappy") + } + + private def getLogFile(broker: KafkaServer, partition: Int): File = { + val log: UnifiedLog = getLog(broker, partition) + log.flush() + log.dir.listFiles.filter(_.getName.endsWith(".log"))(0) + } + + private def getLog(broker: KafkaServer, partition: Int): UnifiedLog = { + broker.logManager.getLog(new TopicPartition(topic, partition)).orNull + } + + private def bounce(follower: KafkaServer): Unit = { + follower.shutdown() + follower.startup() + producer.close() + producer = createProducer //TODO not sure why we need to recreate the producer, but it doesn't reconnect if we don't + } + + private def epochCache(broker: KafkaServer): LeaderEpochFileCache = getLog(broker, 0).leaderEpochCache.get + + private def latestRecord(leader: KafkaServer, offset: Int = -1, partition: Int = 0): RecordBatch = { + getLog(leader, partition).activeSegment.read(0, Integer.MAX_VALUE) + .records.batches().asScala.toSeq.last + } + + private def awaitISR(tp: TopicPartition): Unit = { + TestUtils.waitUntilTrue(() => { + leader.replicaManager.onlinePartition(tp).get.inSyncReplicaIds.size == 2 + }, "Timed out waiting for replicas to join ISR") + } + + private def createProducer: KafkaProducer[Array[Byte], Array[Byte]] = { + TestUtils.createProducer(getBrokerListStrFromServers(brokers), acks = -1) + } + + private def leader: KafkaServer = { + assertEquals(2, brokers.size) + val leaderId = zkClient.getLeaderForPartition(new TopicPartition(topic, 0)).get + brokers.filter(_.config.brokerId == leaderId).head + } + + private def follower: KafkaServer = { + assertEquals(2, brokers.size) + val leader = zkClient.getLeaderForPartition(new TopicPartition(topic, 0)).get + brokers.filter(_.config.brokerId != leader).head + } + + private def createBroker(id: Int, enableUncleanLeaderElection: Boolean = false): KafkaServer = { + val config = createBrokerConfig(id, zkConnect) + TestUtils.setIbpAndMessageFormatVersions(config, apiVersion) + config.setProperty(KafkaConfig.UncleanLeaderElectionEnableProp, enableUncleanLeaderElection.toString) + createServer(fromProps(config)) + } +} diff --git a/core/src/test/scala/unit/kafka/server/epoch/EpochDrivenReplicationProtocolAcceptanceWithIbp26Test.scala b/core/src/test/scala/unit/kafka/server/epoch/EpochDrivenReplicationProtocolAcceptanceWithIbp26Test.scala new file mode 100644 index 0000000..2ad4776 --- /dev/null +++ b/core/src/test/scala/unit/kafka/server/epoch/EpochDrivenReplicationProtocolAcceptanceWithIbp26Test.scala @@ -0,0 +1,29 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package kafka.server.epoch + +import kafka.api.KAFKA_2_6_IV0 + +/** + * With IBP 2.7 onwards, we truncate based on diverging epochs returned in fetch responses. + * EpochDrivenReplicationProtocolAcceptanceTest tests epochs with latest version. This test + * verifies that we handle older IBP versions with truncation on leader/follower change correctly. + */ +class EpochDrivenReplicationProtocolAcceptanceWithIbp26Test extends EpochDrivenReplicationProtocolAcceptanceTest { + override val apiVersion = KAFKA_2_6_IV0 +} diff --git a/core/src/test/scala/unit/kafka/server/epoch/LeaderEpochFileCacheTest.scala b/core/src/test/scala/unit/kafka/server/epoch/LeaderEpochFileCacheTest.scala new file mode 100644 index 0000000..1a4a82f --- /dev/null +++ b/core/src/test/scala/unit/kafka/server/epoch/LeaderEpochFileCacheTest.scala @@ -0,0 +1,581 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.server.epoch + +import java.io.File + +import scala.collection.Seq +import scala.collection.mutable.ListBuffer + +import kafka.server.checkpoints.{LeaderEpochCheckpoint, LeaderEpochCheckpointFile} +import kafka.utils.TestUtils +import org.apache.kafka.common.requests.OffsetsForLeaderEpochResponse.{UNDEFINED_EPOCH, UNDEFINED_EPOCH_OFFSET} +import org.apache.kafka.common.TopicPartition +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.Test + +/** + * Unit test for the LeaderEpochFileCache. 
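+ *
+ * A rough sketch of the contract these tests pin down (inferred from the assertions, not a spec):
+ * entries are (epoch, startOffset) pairs that stay monotonically increasing in both fields, and
+ * endOffsetFor returns the start offset of the next higher epoch, or the log end offset when the
+ * latest epoch is requested, e.g.
+ * {{{
+ *   cache.assign(epoch = 2, startOffset = 6)
+ *   cache.assign(epoch = 3, startOffset = 8)
+ *   cache.endOffsetFor(2, 10)  // logEndOffset = 10  => (2, 8)
+ *   cache.endOffsetFor(3, 10)  //                    => (3, 10)
+ * }}}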
+ */ +class LeaderEpochFileCacheTest { + val tp = new TopicPartition("TestTopic", 5) + private val checkpoint: LeaderEpochCheckpoint = new LeaderEpochCheckpoint { + private var epochs: Seq[EpochEntry] = Seq() + override def write(epochs: Iterable[EpochEntry]): Unit = this.epochs = epochs.toSeq + override def read(): Seq[EpochEntry] = this.epochs + } + private val cache = new LeaderEpochFileCache(tp, checkpoint) + + @Test + def testPreviousEpoch(): Unit = { + assertEquals(None, cache.previousEpoch) + + cache.assign(epoch = 2, startOffset = 10) + assertEquals(None, cache.previousEpoch) + + cache.assign(epoch = 4, startOffset = 15) + assertEquals(Some(2), cache.previousEpoch) + + cache.assign(epoch = 10, startOffset = 20) + assertEquals(Some(4), cache.previousEpoch) + + cache.truncateFromEnd(18) + assertEquals(Some(2), cache.previousEpoch) + } + + @Test + def shouldAddEpochAndMessageOffsetToCache() = { + //When + cache.assign(epoch = 2, startOffset = 10) + val logEndOffset = 11 + + //Then + assertEquals(Some(2), cache.latestEpoch) + assertEquals(EpochEntry(2, 10), cache.epochEntries(0)) + assertEquals((2, logEndOffset), cache.endOffsetFor(2, logEndOffset)) //should match logEndOffset + } + + @Test + def shouldReturnLogEndOffsetIfLatestEpochRequested() = { + //When just one epoch + cache.assign(epoch = 2, startOffset = 11) + cache.assign(epoch = 2, startOffset = 12) + val logEndOffset = 14 + + //Then + assertEquals((2, logEndOffset), cache.endOffsetFor(2, logEndOffset)) + } + + @Test + def shouldReturnUndefinedOffsetIfUndefinedEpochRequested() = { + val expectedEpochEndOffset = (UNDEFINED_EPOCH, UNDEFINED_EPOCH_OFFSET) + + // assign couple of epochs + cache.assign(epoch = 2, startOffset = 11) + cache.assign(epoch = 3, startOffset = 12) + + //When (say a bootstraping follower) sends request for UNDEFINED_EPOCH + val epochAndOffsetFor = cache.endOffsetFor(UNDEFINED_EPOCH, 0L) + + //Then + assertEquals(expectedEpochEndOffset, + epochAndOffsetFor, "Expected undefined epoch and offset if undefined epoch requested. 
Cache not empty.") + } + + @Test + def shouldNotOverwriteLogEndOffsetForALeaderEpochOnceItHasBeenAssigned() = { + //Given + val logEndOffset = 9 + + cache.assign(2, logEndOffset) + + //When called again later + cache.assign(2, 10) + + //Then the offset should NOT have been updated + assertEquals(logEndOffset, cache.epochEntries(0).startOffset) + assertEquals(ListBuffer(EpochEntry(2, 9)), cache.epochEntries) + } + + @Test + def shouldEnforceMonotonicallyIncreasingStartOffsets() = { + //Given + cache.assign(2, 9) + + //When update epoch new epoch but same offset + cache.assign(3, 9) + + //Then epoch should have been updated + assertEquals(ListBuffer(EpochEntry(3, 9)), cache.epochEntries) + } + + @Test + def shouldNotOverwriteOffsetForALeaderEpochOnceItHasBeenAssigned() = { + cache.assign(2, 6) + + //When called again later with a greater offset + cache.assign(2, 10) + + //Then later update should have been ignored + assertEquals(6, cache.epochEntries(0).startOffset) + } + + @Test + def shouldReturnUnsupportedIfNoEpochRecorded(): Unit = { + //Then + assertEquals((UNDEFINED_EPOCH, UNDEFINED_EPOCH_OFFSET), cache.endOffsetFor(0, 0L)) + } + + @Test + def shouldReturnUnsupportedIfNoEpochRecordedAndUndefinedEpochRequested(): Unit = { + //When (say a follower on older message format version) sends request for UNDEFINED_EPOCH + val offsetFor = cache.endOffsetFor(UNDEFINED_EPOCH, 73) + + //Then + assertEquals((UNDEFINED_EPOCH, UNDEFINED_EPOCH_OFFSET), + offsetFor, "Expected undefined epoch and offset if undefined epoch requested. Empty cache.") + } + + @Test + def shouldReturnFirstEpochIfRequestedEpochLessThanFirstEpoch(): Unit = { + cache.assign(epoch = 5, startOffset = 11) + cache.assign(epoch = 6, startOffset = 12) + cache.assign(epoch = 7, startOffset = 13) + + //When + val epochAndOffset = cache.endOffsetFor(4, 0L) + + //Then + assertEquals((4, 11), epochAndOffset) + } + + @Test + def shouldTruncateIfMatchingEpochButEarlierStartingOffset(): Unit = { + cache.assign(epoch = 5, startOffset = 11) + cache.assign(epoch = 6, startOffset = 12) + cache.assign(epoch = 7, startOffset = 13) + + // epoch 7 starts at an earlier offset + cache.assign(epoch = 7, startOffset = 12) + + assertEquals((5, 12), cache.endOffsetFor(5, 0L)) + assertEquals((5, 12), cache.endOffsetFor(6, 0L)) + } + + @Test + def shouldGetFirstOffsetOfSubsequentEpochWhenOffsetRequestedForPreviousEpoch() = { + //When several epochs + cache.assign(epoch = 1, startOffset = 11) + cache.assign(epoch = 1, startOffset = 12) + cache.assign(epoch = 2, startOffset = 13) + cache.assign(epoch = 2, startOffset = 14) + cache.assign(epoch = 3, startOffset = 15) + cache.assign(epoch = 3, startOffset = 16) + + //Then get the start offset of the next epoch + assertEquals((2, 15), cache.endOffsetFor(2, 17)) + } + + @Test + def shouldReturnNextAvailableEpochIfThereIsNoExactEpochForTheOneRequested(): Unit = { + //When + cache.assign(epoch = 0, startOffset = 10) + cache.assign(epoch = 2, startOffset = 13) + cache.assign(epoch = 4, startOffset = 17) + + //Then + assertEquals((0, 13), cache.endOffsetFor(1, 0L)) + assertEquals((2, 17), cache.endOffsetFor(2, 0L)) + assertEquals((2, 17), cache.endOffsetFor(3, 0L)) + } + + @Test + def shouldNotUpdateEpochAndStartOffsetIfItDidNotChange() = { + //When + cache.assign(epoch = 2, startOffset = 6) + cache.assign(epoch = 2, startOffset = 7) + + //Then + assertEquals(1, cache.epochEntries.size) + assertEquals(EpochEntry(2, 6), cache.epochEntries.toList(0)) + } + + @Test + def 
shouldReturnInvalidOffsetIfEpochIsRequestedWhichIsNotCurrentlyTracked(): Unit = { + //When + cache.assign(epoch = 2, startOffset = 100) + + //Then + assertEquals((UNDEFINED_EPOCH, UNDEFINED_EPOCH_OFFSET), cache.endOffsetFor(3, 100)) + } + + @Test + def shouldSupportEpochsThatDoNotStartFromZero(): Unit = { + //When + cache.assign(epoch = 2, startOffset = 6) + val logEndOffset = 7 + + //Then + assertEquals((2, logEndOffset), cache.endOffsetFor(2, logEndOffset)) + assertEquals(1, cache.epochEntries.size) + assertEquals(EpochEntry(2, 6), cache.epochEntries(0)) + } + + @Test + def shouldPersistEpochsBetweenInstances(): Unit = { + val checkpointPath = TestUtils.tempFile().getAbsolutePath + val checkpoint = new LeaderEpochCheckpointFile(new File(checkpointPath)) + + //Given + val cache = new LeaderEpochFileCache(tp, checkpoint) + cache.assign(epoch = 2, startOffset = 6) + + //When + val checkpoint2 = new LeaderEpochCheckpointFile(new File(checkpointPath)) + val cache2 = new LeaderEpochFileCache(tp, checkpoint2) + + //Then + assertEquals(1, cache2.epochEntries.size) + assertEquals(EpochEntry(2, 6), cache2.epochEntries.toList(0)) + } + + @Test + def shouldEnforceMonotonicallyIncreasingEpochs(): Unit = { + //Given + cache.assign(epoch = 1, startOffset = 5); + var logEndOffset = 6 + cache.assign(epoch = 2, startOffset = 6); + logEndOffset = 7 + + //When we update an epoch in the past with a different offset, the log has already reached + //an inconsistent state. Our options are either to raise an error, ignore the new append, + //or truncate the cached epochs to the point of conflict. We take this latter approach in + //order to guarantee that epochs and offsets in the cache increase monotonically, which makes + //the search logic simpler to reason about. + cache.assign(epoch = 1, startOffset = 7); + logEndOffset = 8 + + //Then later epochs will be removed + assertEquals(Some(1), cache.latestEpoch) + + //Then end offset for epoch 1 will have changed + assertEquals((1, 8), cache.endOffsetFor(1, logEndOffset)) + + //Then end offset for epoch 2 is now undefined + assertEquals((UNDEFINED_EPOCH, UNDEFINED_EPOCH_OFFSET), cache.endOffsetFor(2, logEndOffset)) + assertEquals(EpochEntry(1, 7), cache.epochEntries(0)) + } + + @Test + def shouldEnforceOffsetsIncreaseMonotonically() = { + //When epoch goes forward but offset goes backwards + cache.assign(epoch = 2, startOffset = 6) + cache.assign(epoch = 3, startOffset = 5) + + //The last assignment wins and the conflicting one is removed from the log + assertEquals(EpochEntry(3, 5), cache.epochEntries.toList(0)) + } + + @Test + def shouldIncreaseAndTrackEpochsAsLeadersChangeManyTimes(): Unit = { + var logEndOffset = 0L + + //Given + cache.assign(epoch = 0, startOffset = 0) //logEndOffset=0 + + //When + cache.assign(epoch = 1, startOffset = 0) //logEndOffset=0 + + //Then epoch should go up + assertEquals(Some(1), cache.latestEpoch) + //offset for 1 should still be 0 + assertEquals((1, 0), cache.endOffsetFor(1, logEndOffset)) + //offset for epoch 0 should still be 0 + assertEquals((0, 0), cache.endOffsetFor(0, logEndOffset)) + + //When we write 5 messages as epoch 1 + logEndOffset = 5L + + //Then end offset for epoch(1) should be logEndOffset => 5 + assertEquals((1, 5), cache.endOffsetFor(1, logEndOffset)) + //Epoch 0 should still be at offset 0 + assertEquals((0, 0), cache.endOffsetFor(0, logEndOffset)) + + //When + cache.assign(epoch = 2, startOffset = 5) //logEndOffset=5 + + logEndOffset = 10 //write another 5 messages + + //Then end offset for epoch(2) should 
be logEndOffset => 10 + assertEquals((2, 10), cache.endOffsetFor(2, logEndOffset)) + + //end offset for epoch(1) should be the start offset of epoch(2) => 5 + assertEquals((1, 5), cache.endOffsetFor(1, logEndOffset)) + + //epoch (0) should still be 0 + assertEquals((0, 0), cache.endOffsetFor(0, logEndOffset)) + } + + @Test + def shouldIncreaseAndTrackEpochsAsFollowerReceivesManyMessages(): Unit = { + //When Messages come in + cache.assign(epoch = 0, startOffset = 0); + var logEndOffset = 1 + cache.assign(epoch = 0, startOffset = 1); + logEndOffset = 2 + cache.assign(epoch = 0, startOffset = 2); + logEndOffset = 3 + + //Then epoch should stay, offsets should grow + assertEquals(Some(0), cache.latestEpoch) + assertEquals((0, logEndOffset), cache.endOffsetFor(0, logEndOffset)) + + //When messages arrive with greater epoch + cache.assign(epoch = 1, startOffset = 3); + logEndOffset = 4 + cache.assign(epoch = 1, startOffset = 4); + logEndOffset = 5 + cache.assign(epoch = 1, startOffset = 5); + logEndOffset = 6 + + assertEquals(Some(1), cache.latestEpoch) + assertEquals((1, logEndOffset), cache.endOffsetFor(1, logEndOffset)) + + //When + cache.assign(epoch = 2, startOffset = 6); + logEndOffset = 7 + cache.assign(epoch = 2, startOffset = 7); + logEndOffset = 8 + cache.assign(epoch = 2, startOffset = 8); + logEndOffset = 9 + + assertEquals(Some(2), cache.latestEpoch) + assertEquals((2, logEndOffset), cache.endOffsetFor(2, logEndOffset)) + + //Older epochs should return the start offset of the first message in the subsequent epoch. + assertEquals((0, 3), cache.endOffsetFor(0, logEndOffset)) + assertEquals((1, 6), cache.endOffsetFor(1, logEndOffset)) + } + + @Test + def shouldDropEntriesOnEpochBoundaryWhenRemovingLatestEntries(): Unit = { + //Given + cache.assign(epoch = 2, startOffset = 6) + cache.assign(epoch = 3, startOffset = 8) + cache.assign(epoch = 4, startOffset = 11) + + //When clear latest on epoch boundary + cache.truncateFromEnd(endOffset = 8) + + //Then should remove two latest epochs (remove is inclusive) + assertEquals(ListBuffer(EpochEntry(2, 6)), cache.epochEntries) + } + + @Test + def shouldPreserveResetOffsetOnClearEarliestIfOneExists(): Unit = { + //Given + cache.assign(epoch = 2, startOffset = 6) + cache.assign(epoch = 3, startOffset = 8) + cache.assign(epoch = 4, startOffset = 11) + + //When reset to offset ON epoch boundary + cache.truncateFromStart(startOffset = 8) + + //Then should preserve (3, 8) + assertEquals(ListBuffer(EpochEntry(3, 8), EpochEntry(4, 11)), cache.epochEntries) + } + + @Test + def shouldUpdateSavedOffsetWhenOffsetToClearToIsBetweenEpochs(): Unit = { + //Given + cache.assign(epoch = 2, startOffset = 6) + cache.assign(epoch = 3, startOffset = 8) + cache.assign(epoch = 4, startOffset = 11) + + //When reset to offset BETWEEN epoch boundaries + cache.truncateFromStart(startOffset = 9) + + //Then we should retain epoch 3, but update it's offset to 9 as 8 has been removed + assertEquals(ListBuffer(EpochEntry(3, 9), EpochEntry(4, 11)), cache.epochEntries) + } + + @Test + def shouldNotClearAnythingIfOffsetToEarly(): Unit = { + //Given + cache.assign(epoch = 2, startOffset = 6) + cache.assign(epoch = 3, startOffset = 8) + cache.assign(epoch = 4, startOffset = 11) + + //When reset to offset before first epoch offset + cache.truncateFromStart(startOffset = 1) + + //Then nothing should change + assertEquals(ListBuffer(EpochEntry(2, 6),EpochEntry(3, 8), EpochEntry(4, 11)), cache.epochEntries) + } + + @Test + def shouldNotClearAnythingIfOffsetToFirstOffset(): Unit = { + 
//Given + cache.assign(epoch = 2, startOffset = 6) + cache.assign(epoch = 3, startOffset = 8) + cache.assign(epoch = 4, startOffset = 11) + + //When reset to offset on earliest epoch boundary + cache.truncateFromStart(startOffset = 6) + + //Then nothing should change + assertEquals(ListBuffer(EpochEntry(2, 6),EpochEntry(3, 8), EpochEntry(4, 11)), cache.epochEntries) + } + + @Test + def shouldRetainLatestEpochOnClearAllEarliest(): Unit = { + //Given + cache.assign(epoch = 2, startOffset = 6) + cache.assign(epoch = 3, startOffset = 8) + cache.assign(epoch = 4, startOffset = 11) + + //When + cache.truncateFromStart(startOffset = 11) + + //Then retain the last + assertEquals(ListBuffer(EpochEntry(4, 11)), cache.epochEntries) + } + + @Test + def shouldUpdateOffsetBetweenEpochBoundariesOnClearEarliest(): Unit = { + //Given + cache.assign(epoch = 2, startOffset = 6) + cache.assign(epoch = 3, startOffset = 8) + cache.assign(epoch = 4, startOffset = 11) + + //When we clear from a position between offset 8 & offset 11 + cache.truncateFromStart(startOffset = 9) + + //Then we should update the middle epoch entry's offset + assertEquals(ListBuffer(EpochEntry(3, 9), EpochEntry(4, 11)), cache.epochEntries) + } + + @Test + def shouldUpdateOffsetBetweenEpochBoundariesOnClearEarliest2(): Unit = { + //Given + cache.assign(epoch = 0, startOffset = 0) + cache.assign(epoch = 1, startOffset = 7) + cache.assign(epoch = 2, startOffset = 10) + + //When we clear from a position between offset 0 & offset 7 + cache.truncateFromStart(startOffset = 5) + + //Then we should keep epoch 0 but update the offset appropriately + assertEquals(ListBuffer(EpochEntry(0,5), EpochEntry(1, 7), EpochEntry(2, 10)), cache.epochEntries) + } + + @Test + def shouldRetainLatestEpochOnClearAllEarliestAndUpdateItsOffset(): Unit = { + //Given + cache.assign(epoch = 2, startOffset = 6) + cache.assign(epoch = 3, startOffset = 8) + cache.assign(epoch = 4, startOffset = 11) + + //When reset to offset beyond last epoch + cache.truncateFromStart(startOffset = 15) + + //Then update the last + assertEquals(ListBuffer(EpochEntry(4, 15)), cache.epochEntries) + } + + @Test + def shouldDropEntriesBetweenEpochBoundaryWhenRemovingNewest(): Unit = { + //Given + cache.assign(epoch = 2, startOffset = 6) + cache.assign(epoch = 3, startOffset = 8) + cache.assign(epoch = 4, startOffset = 11) + + //When reset to offset BETWEEN epoch boundaries + cache.truncateFromEnd(endOffset = 9) + + //Then should keep the preceding epochs + assertEquals(Some(3), cache.latestEpoch) + assertEquals(ListBuffer(EpochEntry(2, 6), EpochEntry(3, 8)), cache.epochEntries) + } + + @Test + def shouldClearAllEntries(): Unit = { + //Given + cache.assign(epoch = 2, startOffset = 6) + cache.assign(epoch = 3, startOffset = 8) + cache.assign(epoch = 4, startOffset = 11) + + //When + cache.clearAndFlush() + + //Then + assertEquals(0, cache.epochEntries.size) + } + + @Test + def shouldNotResetEpochHistoryHeadIfUndefinedPassed(): Unit = { + //Given + cache.assign(epoch = 2, startOffset = 6) + cache.assign(epoch = 3, startOffset = 8) + cache.assign(epoch = 4, startOffset = 11) + + //When reset to offset on epoch boundary + cache.truncateFromStart(startOffset = UNDEFINED_EPOCH_OFFSET) + + //Then should do nothing + assertEquals(3, cache.epochEntries.size) + } + + @Test + def shouldNotResetEpochHistoryTailIfUndefinedPassed(): Unit = { + //Given + cache.assign(epoch = 2, startOffset = 6) + cache.assign(epoch = 3, startOffset = 8) + cache.assign(epoch = 4, startOffset = 11) + + //When reset to offset on 
epoch boundary + cache.truncateFromEnd(endOffset = UNDEFINED_EPOCH_OFFSET) + + //Then should do nothing + assertEquals(3, cache.epochEntries.size) + } + + @Test + def shouldFetchLatestEpochOfEmptyCache(): Unit = { + //Then + assertEquals(None, cache.latestEpoch) + } + + @Test + def shouldFetchEndOffsetOfEmptyCache(): Unit = { + //Then + assertEquals((UNDEFINED_EPOCH, UNDEFINED_EPOCH_OFFSET), cache.endOffsetFor(7, 0L)) + } + + @Test + def shouldClearEarliestOnEmptyCache(): Unit = { + //Then + cache.truncateFromStart(7) + } + + @Test + def shouldClearLatestOnEmptyCache(): Unit = { + //Then + cache.truncateFromEnd(7) + } +} diff --git a/core/src/test/scala/unit/kafka/server/epoch/LeaderEpochIntegrationTest.scala b/core/src/test/scala/unit/kafka/server/epoch/LeaderEpochIntegrationTest.scala new file mode 100644 index 0000000..2557d35 --- /dev/null +++ b/core/src/test/scala/unit/kafka/server/epoch/LeaderEpochIntegrationTest.scala @@ -0,0 +1,304 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package kafka.server.epoch + +import kafka.cluster.BrokerEndPoint +import kafka.server.KafkaConfig._ +import kafka.server.{BlockingSend, KafkaServer, ReplicaFetcherBlockingSend} +import kafka.utils.Implicits._ +import kafka.utils.TestUtils._ +import kafka.utils.{Logging, TestUtils} +import kafka.server.QuorumTestHarness +import org.apache.kafka.clients.producer.{KafkaProducer, ProducerRecord} +import org.apache.kafka.common.metrics.Metrics +import org.apache.kafka.common.protocol.Errors._ +import org.apache.kafka.common.serialization.StringSerializer +import org.apache.kafka.common.utils.{LogContext, SystemTime} +import org.apache.kafka.common.TopicPartition +import org.apache.kafka.common.message.OffsetForLeaderEpochRequestData.OffsetForLeaderPartition +import org.apache.kafka.common.message.OffsetForLeaderEpochRequestData.OffsetForLeaderTopic +import org.apache.kafka.common.message.OffsetForLeaderEpochRequestData.OffsetForLeaderTopicCollection +import org.apache.kafka.common.message.OffsetForLeaderEpochResponseData.EpochEndOffset +import org.apache.kafka.common.protocol.ApiKeys +import org.apache.kafka.common.requests.{OffsetsForLeaderEpochRequest, OffsetsForLeaderEpochResponse} +import org.apache.kafka.common.requests.OffsetsForLeaderEpochResponse.UNDEFINED_EPOCH_OFFSET +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.{AfterEach, Test} + +import scala.jdk.CollectionConverters._ +import scala.collection.Map +import scala.collection.mutable.ListBuffer + +class LeaderEpochIntegrationTest extends QuorumTestHarness with Logging { + var brokers: ListBuffer[KafkaServer] = ListBuffer() + val topic1 = "foo" + val topic2 = "bar" + val t1p0 = new TopicPartition(topic1, 0) + val t1p1 = new TopicPartition(topic1, 1) + val t1p2 = new 
TopicPartition(topic1, 2) + val t2p0 = new TopicPartition(topic2, 0) + val t2p2 = new TopicPartition(topic2, 2) + val tp = t1p0 + var producer: KafkaProducer[Array[Byte], Array[Byte]] = null + + @AfterEach + override def tearDown(): Unit = { + if (producer != null) + producer.close() + TestUtils.shutdownServers(brokers) + super.tearDown() + } + + @Test + def shouldAddCurrentLeaderEpochToMessagesAsTheyAreWrittenToLeader(): Unit = { + brokers ++= (0 to 1).map { id => createServer(fromProps(createBrokerConfig(id, zkConnect))) } + + // Given two topics with replication of a single partition + for (topic <- List(topic1, topic2)) { + createTopic(zkClient, topic, Map(0 -> Seq(0, 1)), servers = brokers) + } + + // When we send four messages + sendFourMessagesToEachTopic() + + //Then they should be stamped with Leader Epoch 0 + var expectedLeaderEpoch = 0 + waitUntilTrue(() => messagesHaveLeaderEpoch(brokers(0), expectedLeaderEpoch, 0), "Leader epoch should be 0") + + //Given we then bounce the leader + brokers(0).shutdown() + brokers(0).startup() + + //Then LeaderEpoch should now have changed from 0 -> 1 + expectedLeaderEpoch = 1 + waitForEpochChangeTo(topic1, 0, expectedLeaderEpoch) + waitForEpochChangeTo(topic2, 0, expectedLeaderEpoch) + + //Given we now send messages + sendFourMessagesToEachTopic() + + //The new messages should be stamped with LeaderEpoch = 1 + waitUntilTrue(() => messagesHaveLeaderEpoch(brokers(0), expectedLeaderEpoch, 4), "Leader epoch should be 1") + } + + @Test + def shouldSendLeaderEpochRequestAndGetAResponse(): Unit = { + + //3 brokers, put partition on 100/101 and then pretend to be 102 + brokers ++= (100 to 102).map { id => createServer(fromProps(createBrokerConfig(id, zkConnect))) } + + val assignment1 = Map(0 -> Seq(100), 1 -> Seq(101)) + TestUtils.createTopic(zkClient, topic1, assignment1, brokers) + + val assignment2 = Map(0 -> Seq(100)) + TestUtils.createTopic(zkClient, topic2, assignment2, brokers) + + //Send messages equally to the two partitions, then half as many to a third + producer = createProducer(getBrokerListStrFromServers(brokers), acks = -1) + (0 until 10).foreach { _ => + producer.send(new ProducerRecord(topic1, 0, null, "IHeartLogs".getBytes)) + } + (0 until 20).foreach { _ => + producer.send(new ProducerRecord(topic1, 1, null, "OhAreThey".getBytes)) + } + (0 until 30).foreach { _ => + producer.send(new ProducerRecord(topic2, 0, null, "IReallyDo".getBytes)) + } + producer.flush() + + val fetcher0 = new TestFetcherThread(sender(from = brokers(2), to = brokers(0))) + val epochsRequested = Map(t1p0 -> 0, t1p1 -> 0, t2p0 -> 0, t2p2 -> 0) + + //When + val offsetsForEpochs = fetcher0.leaderOffsetsFor(epochsRequested) + + //Then end offset should be correct + assertEquals(10, offsetsForEpochs(t1p0).endOffset) + assertEquals(30, offsetsForEpochs(t2p0).endOffset) + + //And should get no leader for partition error from t1p1 (as it's not on broker 0) + assertEquals(NOT_LEADER_OR_FOLLOWER.code, offsetsForEpochs(t1p1).errorCode) + assertEquals(UNDEFINED_EPOCH_OFFSET, offsetsForEpochs(t1p1).endOffset) + + //Repointing to broker 1 we should get the correct offset for t1p1 + val fetcher1 = new TestFetcherThread(sender(from = brokers(2), to = brokers(1))) + val offsetsForEpochs1 = fetcher1.leaderOffsetsFor(epochsRequested) + assertEquals(20, offsetsForEpochs1(t1p1).endOffset) + } + + @Test + def shouldIncreaseLeaderEpochBetweenLeaderRestarts(): Unit = { + //Setup: we are only interested in the single partition on broker 101 + brokers += 
createServer(fromProps(createBrokerConfig(100, zkConnect))) + assertEquals(100, TestUtils.waitUntilControllerElected(zkClient)) + + brokers += createServer(fromProps(createBrokerConfig(101, zkConnect))) + + def leo() = brokers(1).replicaManager.localLog(tp).get.logEndOffset + + TestUtils.createTopic(zkClient, tp.topic, Map(tp.partition -> Seq(101)), brokers) + producer = createProducer(getBrokerListStrFromServers(brokers), acks = -1) + + //1. Given a single message + producer.send(new ProducerRecord(tp.topic, tp.partition, null, "IHeartLogs".getBytes)).get + var fetcher = new TestFetcherThread(sender(brokers(0), brokers(1))) + + //Then epoch should be 0 and leo: 1 + var epochEndOffset = fetcher.leaderOffsetsFor(Map(tp -> 0))(tp) + assertEquals(0, epochEndOffset.leaderEpoch) + assertEquals(1, epochEndOffset.endOffset) + assertEquals(1, leo()) + + //2. When broker is bounced + brokers(1).shutdown() + brokers(1).startup() + + producer.send(new ProducerRecord(tp.topic, tp.partition, null, "IHeartLogs".getBytes)).get + fetcher = new TestFetcherThread(sender(brokers(0), brokers(1))) + + //Then epoch 0 should still be the start offset of epoch 1 + epochEndOffset = fetcher.leaderOffsetsFor(Map(tp -> 0))(tp) + assertEquals(1, epochEndOffset.endOffset) + assertEquals(0, epochEndOffset.leaderEpoch) + + //No data written in epoch 1 + epochEndOffset = fetcher.leaderOffsetsFor(Map(tp -> 1))(tp) + assertEquals(0, epochEndOffset.leaderEpoch) + assertEquals(1, epochEndOffset.endOffset) + + //Then epoch 2 should be the leo (NB: The leader epoch goes up in factors of 2 - + //This is because we have to first change leader to -1 and then change it again to the live replica) + //Note that the expected leader changes depend on the controller being on broker 100, which is not restarted + epochEndOffset = fetcher.leaderOffsetsFor(Map(tp -> 2))(tp) + assertEquals(2, epochEndOffset.leaderEpoch) + assertEquals(2, epochEndOffset.endOffset) + assertEquals(2, leo()) + + //3. When broker is bounced again + brokers(1).shutdown() + brokers(1).startup() + + producer.send(new ProducerRecord(tp.topic, tp.partition, null, "IHeartLogs".getBytes)).get + fetcher = new TestFetcherThread(sender(brokers(0), brokers(1))) + + //Then Epoch 0 should still map to offset 1 + assertEquals(1, fetcher.leaderOffsetsFor(Map(tp -> 0))(tp).endOffset()) + + //Then Epoch 2 should still map to offset 2 + assertEquals(2, fetcher.leaderOffsetsFor(Map(tp -> 2))(tp).endOffset()) + + //Then Epoch 4 should still map to offset 2 + assertEquals(3, fetcher.leaderOffsetsFor(Map(tp -> 4))(tp).endOffset()) + assertEquals(leo(), fetcher.leaderOffsetsFor(Map(tp -> 4))(tp).endOffset()) + + //Adding some extra assertions here to save test setup. + shouldSupportRequestsForEpochsNotOnTheLeader(fetcher) + } + + //Appended onto the previous test to save on setup cost. + def shouldSupportRequestsForEpochsNotOnTheLeader(fetcher: TestFetcherThread): Unit = { + /** + * Asking for an epoch not present on the leader should return the + * next matching epoch, unless there isn't any, which should return + * undefined. 
+ */ + + val epoch1 = Map(t1p0 -> 1) + assertEquals(1, fetcher.leaderOffsetsFor(epoch1)(t1p0).endOffset()) + + val epoch3 = Map(t1p0 -> 3) + assertEquals(2, fetcher.leaderOffsetsFor(epoch3)(t1p0).endOffset()) + + val epoch5 = Map(t1p0 -> 5) + assertEquals(-1, fetcher.leaderOffsetsFor(epoch5)(t1p0).endOffset()) + } + + private def sender(from: KafkaServer, to: KafkaServer): BlockingSend = { + val node = from.metadataCache.getAliveBrokerNode(to.config.brokerId, + from.config.interBrokerListenerName).get + val endPoint = new BrokerEndPoint(node.id(), node.host(), node.port()) + new ReplicaFetcherBlockingSend(endPoint, from.config, new Metrics(), new SystemTime(), 42, "TestFetcher", new LogContext()) + } + + private def waitForEpochChangeTo(topic: String, partition: Int, epoch: Int): Unit = { + TestUtils.waitUntilTrue(() => { + brokers(0).metadataCache.getPartitionInfo(topic, partition).exists(_.leaderEpoch == epoch) + }, "Epoch didn't change") + } + + private def messagesHaveLeaderEpoch(broker: KafkaServer, expectedLeaderEpoch: Int, minOffset: Int): Boolean = { + var result = true + for (topic <- List(topic1, topic2)) { + val tp = new TopicPartition(topic, 0) + val leo = broker.getLogManager.getLog(tp).get.logEndOffset + result = result && leo > 0 && brokers.forall { broker => + broker.getLogManager.getLog(tp).get.logSegments.iterator.forall { segment => + if (segment.read(minOffset, Integer.MAX_VALUE) == null) { + false + } else { + segment.read(minOffset, Integer.MAX_VALUE) + .records.batches().iterator().asScala.forall( + expectedLeaderEpoch == _.partitionLeaderEpoch() + ) + } + } + } + } + result + } + + private def sendFourMessagesToEachTopic() = { + val testMessageList1 = List("test1", "test2", "test3", "test4") + val testMessageList2 = List("test5", "test6", "test7", "test8") + val producer = TestUtils.createProducer(TestUtils.getBrokerListStrFromServers(brokers), + keySerializer = new StringSerializer, valueSerializer = new StringSerializer) + val records = + testMessageList1.map(m => new ProducerRecord(topic1, m, m)) ++ + testMessageList2.map(m => new ProducerRecord(topic2, m, m)) + records.map(producer.send).foreach(_.get) + producer.close() + } + + /** + * Simulates how the Replica Fetcher Thread requests leader offsets for epochs + */ + private[epoch] class TestFetcherThread(sender: BlockingSend) extends Logging { + + def leaderOffsetsFor(partitions: Map[TopicPartition, Int]): Map[TopicPartition, EpochEndOffset] = { + val topics = new OffsetForLeaderTopicCollection(partitions.size) + partitions.forKeyValue { (topicPartition, leaderEpoch) => + var topic = topics.find(topicPartition.topic) + if (topic == null) { + topic = new OffsetForLeaderTopic().setTopic(topicPartition.topic) + topics.add(topic) + } + topic.partitions.add(new OffsetForLeaderPartition() + .setPartition(topicPartition.partition) + .setLeaderEpoch(leaderEpoch)) + } + + val request = OffsetsForLeaderEpochRequest.Builder.forFollower( + ApiKeys.OFFSET_FOR_LEADER_EPOCH.latestVersion, topics, 1) + val response = sender.sendRequest(request) + response.responseBody.asInstanceOf[OffsetsForLeaderEpochResponse].data.topics.asScala.flatMap { topic => + topic.partitions.asScala.map { partition => + new TopicPartition(topic.topic, partition.partition) -> partition + } + }.toMap + } + } +} diff --git a/core/src/test/scala/unit/kafka/server/epoch/OffsetsForLeaderEpochTest.scala b/core/src/test/scala/unit/kafka/server/epoch/OffsetsForLeaderEpochTest.scala new file mode 100644 index 0000000..7851b53 --- /dev/null +++ 
b/core/src/test/scala/unit/kafka/server/epoch/OffsetsForLeaderEpochTest.scala @@ -0,0 +1,187 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package kafka.server.epoch + +import java.io.File + +import kafka.log.{UnifiedLog, LogManager} +import kafka.server.QuotaFactory.QuotaManagers +import kafka.server._ +import kafka.utils.{MockTime, TestUtils} +import org.apache.kafka.common.TopicPartition +import org.apache.kafka.common.message.OffsetForLeaderEpochRequestData.{OffsetForLeaderPartition, OffsetForLeaderTopic} +import org.apache.kafka.common.message.OffsetForLeaderEpochResponseData.{EpochEndOffset, OffsetForLeaderTopicResult} +import org.apache.kafka.common.metrics.Metrics +import org.apache.kafka.common.protocol.Errors +import org.apache.kafka.common.record.RecordBatch +import org.apache.kafka.common.requests.OffsetsForLeaderEpochResponse.{UNDEFINED_EPOCH, UNDEFINED_EPOCH_OFFSET} +import org.easymock.EasyMock._ +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.{AfterEach, BeforeEach, Test} + +import scala.jdk.CollectionConverters._ + +class OffsetsForLeaderEpochTest { + private val config = TestUtils.createBrokerConfigs(1, TestUtils.MockZkConnect).map(KafkaConfig.fromProps).head + private val time = new MockTime + private val metrics = new Metrics + private val alterIsrManager = TestUtils.createAlterIsrManager() + private val tp = new TopicPartition("topic", 1) + private var replicaManager: ReplicaManager = _ + private var quotaManager: QuotaManagers = _ + + @BeforeEach + def setUp(): Unit = { + quotaManager = QuotaFactory.instantiate(config, metrics, time, "") + } + + @Test + def shouldGetEpochsFromReplica(): Unit = { + //Given + val offsetAndEpoch = OffsetAndEpoch(42L, 5) + val epochRequested: Integer = 5 + val request = Seq(newOffsetForLeaderTopic(tp, RecordBatch.NO_PARTITION_LEADER_EPOCH, epochRequested)) + + //Stubs + val mockLog: UnifiedLog = createNiceMock(classOf[UnifiedLog]) + val logManager: LogManager = createNiceMock(classOf[LogManager]) + expect(mockLog.endOffsetForEpoch(epochRequested)).andReturn(Some(offsetAndEpoch)) + expect(logManager.liveLogDirs).andReturn(Array.empty[File]).anyTimes() + replay(mockLog, logManager) + + // create a replica manager with 1 partition that has 1 replica + replicaManager = new ReplicaManager( + metrics = metrics, + config = config, + time = time, + scheduler = null, + logManager = logManager, + quotaManagers = quotaManager, + metadataCache = MetadataCache.zkMetadataCache(config.brokerId), + logDirFailureChannel = new LogDirFailureChannel(config.logDirs.size), + alterIsrManager = alterIsrManager) + val partition = replicaManager.createPartition(tp) + partition.setLog(mockLog, isFutureLog = false) + partition.leaderReplicaIdOpt = Some(config.brokerId) + + //When + val 
response = replicaManager.lastOffsetForLeaderEpoch(request) + + //Then + assertEquals( + Seq(newOffsetForLeaderTopicResult(tp, Errors.NONE, offsetAndEpoch.leaderEpoch, offsetAndEpoch.offset)), + response) + } + + @Test + def shouldReturnNoLeaderForPartitionIfThrown(): Unit = { + val logManager: LogManager = createNiceMock(classOf[LogManager]) + expect(logManager.liveLogDirs).andReturn(Array.empty[File]).anyTimes() + replay(logManager) + + //create a replica manager with 1 partition that has 0 replica + replicaManager = new ReplicaManager( + metrics = metrics, + config = config, + time = time, + scheduler = null, + logManager = logManager, + quotaManagers = quotaManager, + metadataCache = MetadataCache.zkMetadataCache(config.brokerId), + logDirFailureChannel = new LogDirFailureChannel(config.logDirs.size), + alterIsrManager = alterIsrManager) + replicaManager.createPartition(tp) + + //Given + val epochRequested: Integer = 5 + val request = Seq(newOffsetForLeaderTopic(tp, RecordBatch.NO_PARTITION_LEADER_EPOCH, epochRequested)) + + //When + val response = replicaManager.lastOffsetForLeaderEpoch(request) + + //Then + assertEquals( + Seq(newOffsetForLeaderTopicResult(tp, Errors.NOT_LEADER_OR_FOLLOWER, UNDEFINED_EPOCH, UNDEFINED_EPOCH_OFFSET)), + response) + } + + @Test + def shouldReturnUnknownTopicOrPartitionIfThrown(): Unit = { + val logManager: LogManager = createNiceMock(classOf[LogManager]) + expect(logManager.liveLogDirs).andReturn(Array.empty[File]).anyTimes() + replay(logManager) + + //create a replica manager with 0 partition + replicaManager = new ReplicaManager( + metrics = metrics, + config = config, + time = time, + scheduler = null, + logManager = logManager, + quotaManagers = quotaManager, + metadataCache = MetadataCache.zkMetadataCache(config.brokerId), + logDirFailureChannel = new LogDirFailureChannel(config.logDirs.size), + alterIsrManager = alterIsrManager) + + //Given + val epochRequested: Integer = 5 + val request = Seq(newOffsetForLeaderTopic(tp, RecordBatch.NO_PARTITION_LEADER_EPOCH, epochRequested)) + + //When + val response = replicaManager.lastOffsetForLeaderEpoch(request) + + //Then + assertEquals( + Seq(newOffsetForLeaderTopicResult(tp, Errors.UNKNOWN_TOPIC_OR_PARTITION, UNDEFINED_EPOCH, UNDEFINED_EPOCH_OFFSET)), + response) + } + + @AfterEach + def tearDown(): Unit = { + Option(replicaManager).foreach(_.shutdown(checkpointHW = false)) + Option(quotaManager).foreach(_.shutdown()) + metrics.close() + } + + private def newOffsetForLeaderTopic( + tp: TopicPartition, + currentLeaderEpoch: Int, + leaderEpoch: Int + ): OffsetForLeaderTopic = { + new OffsetForLeaderTopic() + .setTopic(tp.topic) + .setPartitions(List(new OffsetForLeaderPartition() + .setPartition(tp.partition) + .setCurrentLeaderEpoch(currentLeaderEpoch) + .setLeaderEpoch(leaderEpoch)).asJava) + } + + private def newOffsetForLeaderTopicResult( + tp: TopicPartition, + error: Errors, + leaderEpoch: Int, + endOffset: Long + ): OffsetForLeaderTopicResult = { + new OffsetForLeaderTopicResult() + .setTopic(tp.topic) + .setPartitions(List(new EpochEndOffset() + .setPartition(tp.partition) + .setErrorCode(error.code) + .setLeaderEpoch(leaderEpoch) + .setEndOffset(endOffset)).asJava) + } +} diff --git a/core/src/test/scala/unit/kafka/server/epoch/util/ReplicaFetcherMockBlockingSend.scala b/core/src/test/scala/unit/kafka/server/epoch/util/ReplicaFetcherMockBlockingSend.scala new file mode 100644 index 0000000..8f3fcff --- /dev/null +++ 
b/core/src/test/scala/unit/kafka/server/epoch/util/ReplicaFetcherMockBlockingSend.scala @@ -0,0 +1,131 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package kafka.server.epoch.util + +import java.net.SocketTimeoutException +import java.util +import kafka.cluster.BrokerEndPoint +import kafka.server.BlockingSend +import org.apache.kafka.clients.{ClientRequest, ClientResponse, MockClient, NetworkClientUtils} +import org.apache.kafka.common.message.{FetchResponseData, OffsetForLeaderEpochResponseData} +import org.apache.kafka.common.message.OffsetForLeaderEpochResponseData.{EpochEndOffset, OffsetForLeaderTopicResult} +import org.apache.kafka.common.protocol.{ApiKeys, Errors} +import org.apache.kafka.common.requests.AbstractRequest.Builder +import org.apache.kafka.common.requests.{AbstractRequest, FetchResponse, OffsetsForLeaderEpochResponse, FetchMetadata => JFetchMetadata} +import org.apache.kafka.common.utils.{SystemTime, Time} +import org.apache.kafka.common.{Node, TopicIdPartition, TopicPartition, Uuid} + +import scala.collection.Map + +/** + * Stub network client used for testing the ReplicaFetcher, wraps the MockClient used for consumer testing + * + * The common case is that there is only one OFFSET_FOR_LEADER_EPOCH request/response. So, the + * response to OFFSET_FOR_LEADER_EPOCH is 'offsets' map. 
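+ * For illustration only (the value names below are placeholders rather than members of this
+ * class), a test might wire the stub up roughly as:
+ * {{{
+ * val mockSend = new ReplicaFetcherMockBlockingSend(firstRoundOffsets, sourceBrokerEndPoint, new SystemTime())
+ * // the first OFFSET_FOR_LEADER_EPOCH exchange is answered from firstRoundOffsets
+ * mockSend.setOffsetsForNextResponse(secondRoundOffsets)
+ * }}}
+ *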
If the test needs to set another round of + * OFFSET_FOR_LEADER_EPOCH with different offsets in response, it should update offsets using + * setOffsetsForNextResponse + */ +class ReplicaFetcherMockBlockingSend(offsets: java.util.Map[TopicPartition, EpochEndOffset], + sourceBroker: BrokerEndPoint, + time: Time) + extends BlockingSend { + + private val client = new MockClient(new SystemTime) + var fetchCount = 0 + var epochFetchCount = 0 + var lastUsedOffsetForLeaderEpochVersion = -1 + var callback: Option[() => Unit] = None + var currentOffsets: util.Map[TopicPartition, EpochEndOffset] = offsets + var fetchPartitionData: Map[TopicPartition, FetchResponseData.PartitionData] = Map.empty + var topicIds: Map[String, Uuid] = Map.empty + private val sourceNode = new Node(sourceBroker.id, sourceBroker.host, sourceBroker.port) + + def setEpochRequestCallback(postEpochFunction: () => Unit): Unit = { + callback = Some(postEpochFunction) + } + + def setOffsetsForNextResponse(newOffsets: util.Map[TopicPartition, EpochEndOffset]): Unit = { + currentOffsets = newOffsets + } + + def setFetchPartitionDataForNextResponse(partitionData: Map[TopicPartition, FetchResponseData.PartitionData]): Unit = { + fetchPartitionData = partitionData + } + + def setIdsForNextResponse(topicIds: Map[String, Uuid]): Unit = { + this.topicIds = topicIds + } + + override def sendRequest(requestBuilder: Builder[_ <: AbstractRequest]): ClientResponse = { + if (!NetworkClientUtils.awaitReady(client, sourceNode, time, 500)) + throw new SocketTimeoutException(s"Failed to connect within 500 ms") + + //Send the request to the mock client + val clientRequest = request(requestBuilder) + client.send(clientRequest, time.milliseconds()) + + //Create a suitable response based on the API key + val response = requestBuilder.apiKey() match { + case ApiKeys.OFFSET_FOR_LEADER_EPOCH => + callback.foreach(_.apply()) + epochFetchCount += 1 + lastUsedOffsetForLeaderEpochVersion = requestBuilder.latestAllowedVersion() + + val data = new OffsetForLeaderEpochResponseData() + currentOffsets.forEach((tp, offsetForLeaderPartition) => { + var topic = data.topics.find(tp.topic) + if (topic == null) { + topic = new OffsetForLeaderTopicResult() + .setTopic(tp.topic) + data.topics.add(topic) + } + topic.partitions.add(offsetForLeaderPartition) + }) + + new OffsetsForLeaderEpochResponse(data) + + case ApiKeys.FETCH => + fetchCount += 1 + val partitionData = new util.LinkedHashMap[TopicIdPartition, FetchResponseData.PartitionData] + fetchPartitionData.foreach { case (tp, data) => partitionData.put(new TopicIdPartition(topicIds.getOrElse(tp.topic(), Uuid.ZERO_UUID), tp), data) } + fetchPartitionData = Map.empty + topicIds = Map.empty + FetchResponse.of(Errors.NONE, 0, + if (partitionData.isEmpty) JFetchMetadata.INVALID_SESSION_ID else 1, + partitionData) + + case _ => + throw new UnsupportedOperationException + } + + //Use mock client to create the appropriate response object + client.respondFrom(response, sourceNode) + client.poll(30, time.milliseconds()).iterator().next() + } + + private def request(requestBuilder: Builder[_ <: AbstractRequest]): ClientRequest = { + client.newClientRequest( + sourceBroker.id.toString, + requestBuilder, + time.milliseconds(), + true) + } + + override def initiateClose(): Unit = {} + + override def close(): Unit = {} +} diff --git a/core/src/test/scala/unit/kafka/server/metadata/BrokerMetadataListenerTest.scala b/core/src/test/scala/unit/kafka/server/metadata/BrokerMetadataListenerTest.scala new file mode 100644 index 
0000000..84ed069 --- /dev/null +++ b/core/src/test/scala/unit/kafka/server/metadata/BrokerMetadataListenerTest.scala @@ -0,0 +1,233 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.server.metadata + +import java.util +import java.util.concurrent.atomic.AtomicReference +import java.util.{Collections, Optional} + +import org.apache.kafka.common.metadata.{PartitionChangeRecord, PartitionRecord, RegisterBrokerRecord, TopicRecord} +import org.apache.kafka.common.utils.Time +import org.apache.kafka.common.{Endpoint, Uuid} +import org.apache.kafka.image.{MetadataDelta, MetadataImage} +import org.apache.kafka.metadata.{BrokerRegistration, RecordTestUtils, VersionRange} +import org.apache.kafka.server.common.ApiMessageAndVersion +import org.junit.jupiter.api.Assertions.{assertEquals, assertTrue} +import org.junit.jupiter.api.Test + +import scala.jdk.CollectionConverters._ + +class BrokerMetadataListenerTest { + @Test + def testCreateAndClose(): Unit = { + val listener = new BrokerMetadataListener(0, Time.SYSTEM, None, 1000000L, + snapshotter = None) + listener.close() + } + + @Test + def testPublish(): Unit = { + val listener = new BrokerMetadataListener(0, Time.SYSTEM, None, 1000000L, + snapshotter = None) + try { + listener.handleCommit(RecordTestUtils.mockBatchReader(100L, + util.Arrays.asList(new ApiMessageAndVersion(new RegisterBrokerRecord(). + setBrokerId(0). + setBrokerEpoch(100L). + setFenced(false). + setRack(null). + setIncarnationId(Uuid.fromString("GFBwlTcpQUuLYQ2ig05CSg")), 0.toShort)))) + val imageRecords = listener.getImageRecords().get() + assertEquals(0, imageRecords.size()) + assertEquals(100L, listener.highestMetadataOffset) + listener.handleCommit(RecordTestUtils.mockBatchReader(200L, + util.Arrays.asList(new ApiMessageAndVersion(new RegisterBrokerRecord(). + setBrokerId(1). + setBrokerEpoch(200L). + setFenced(true). + setRack(null). 
+ setIncarnationId(Uuid.fromString("QkOQtNKVTYatADcaJ28xDg")), 0.toShort)))) + listener.startPublishing(new MetadataPublisher { + override def publish(delta: MetadataDelta, newImage: MetadataImage): Unit = { + assertEquals(200L, newImage.highestOffsetAndEpoch().offset) + assertEquals(new BrokerRegistration(0, 100L, + Uuid.fromString("GFBwlTcpQUuLYQ2ig05CSg"), Collections.emptyList[Endpoint](), + Collections.emptyMap[String, VersionRange](), Optional.empty[String](), false), + delta.clusterDelta().broker(0)) + assertEquals(new BrokerRegistration(1, 200L, + Uuid.fromString("QkOQtNKVTYatADcaJ28xDg"), Collections.emptyList[Endpoint](), + Collections.emptyMap[String, VersionRange](), Optional.empty[String](), true), + delta.clusterDelta().broker(1)) + } + }).get() + } finally { + listener.close() + } + } + + class MockMetadataSnapshotter extends MetadataSnapshotter { + var image = MetadataImage.EMPTY + val failure = new AtomicReference[Throwable](null) + var activeSnapshotOffset = -1L + var prevCommittedOffset = -1L + var prevCommittedEpoch = -1 + var prevLastContainedLogTime = -1L + + override def maybeStartSnapshot(lastContainedLogTime: Long, newImage: MetadataImage): Boolean = { + try { + if (activeSnapshotOffset == -1L) { + assertTrue(prevCommittedOffset <= newImage.highestOffsetAndEpoch().offset) + assertTrue(prevCommittedEpoch <= newImage.highestOffsetAndEpoch().epoch) + assertTrue(prevLastContainedLogTime <= lastContainedLogTime) + prevCommittedOffset = newImage.highestOffsetAndEpoch().offset + prevCommittedEpoch = newImage.highestOffsetAndEpoch().epoch + prevLastContainedLogTime = lastContainedLogTime + image = newImage + activeSnapshotOffset = newImage.highestOffsetAndEpoch().offset + true + } else { + false + } + } catch { + case t: Throwable => failure.compareAndSet(null, t) + } + } + } + + class MockMetadataPublisher extends MetadataPublisher { + var image = MetadataImage.EMPTY + + override def publish(delta: MetadataDelta, newImage: MetadataImage): Unit = { + image = newImage + } + } + + private val FOO_ID = Uuid.fromString("jj1G9utnTuCegi_gpnRgYw") + + private def generateManyRecords(listener: BrokerMetadataListener, + endOffset: Long): Unit = { + (0 to 10000).foreach { _ => + listener.handleCommit(RecordTestUtils.mockBatchReader(endOffset, + util.Arrays.asList(new ApiMessageAndVersion(new PartitionChangeRecord(). + setPartitionId(0). + setTopicId(FOO_ID). + setRemovingReplicas(Collections.singletonList(1)), 0.toShort), + new ApiMessageAndVersion(new PartitionChangeRecord(). + setPartitionId(0). + setTopicId(FOO_ID). 
+ setRemovingReplicas(Collections.emptyList()), 0.toShort)))) + } + listener.getImageRecords().get() + } + + @Test + def testHandleCommitsWithNoSnapshotterDefined(): Unit = { + val listener = new BrokerMetadataListener(0, Time.SYSTEM, None, 1000L, + snapshotter = None) + try { + val brokerIds = 0 to 3 + + registerBrokers(listener, brokerIds, endOffset = 100L) + createTopicWithOnePartition(listener, replicas = brokerIds, endOffset = 200L) + listener.getImageRecords().get() + assertEquals(200L, listener.highestMetadataOffset) + + generateManyRecords(listener, endOffset = 1000L) + assertEquals(1000L, listener.highestMetadataOffset) + } finally { + listener.close() + } + } + + @Test + def testCreateSnapshot(): Unit = { + val snapshotter = new MockMetadataSnapshotter() + val listener = new BrokerMetadataListener(0, Time.SYSTEM, None, 1000L, Some(snapshotter)) + try { + val brokerIds = 0 to 3 + + registerBrokers(listener, brokerIds, endOffset = 100L) + createTopicWithOnePartition(listener, replicas = brokerIds, endOffset = 200L) + listener.getImageRecords().get() + assertEquals(200L, listener.highestMetadataOffset) + + // Check that we generate at least one snapshot once we see enough records. + assertEquals(-1L, snapshotter.prevCommittedOffset) + generateManyRecords(listener, 1000L) + assertEquals(1000L, snapshotter.prevCommittedOffset) + assertEquals(1000L, snapshotter.activeSnapshotOffset) + snapshotter.activeSnapshotOffset = -1L + + // Test creating a new snapshot after publishing it. + val publisher = new MockMetadataPublisher() + listener.startPublishing(publisher).get() + generateManyRecords(listener, 2000L) + listener.getImageRecords().get() + assertEquals(2000L, snapshotter.activeSnapshotOffset) + assertEquals(2000L, snapshotter.prevCommittedOffset) + + // Test how we handle the snapshotter returning false. + generateManyRecords(listener, 3000L) + assertEquals(2000L, snapshotter.activeSnapshotOffset) + generateManyRecords(listener, 4000L) + assertEquals(2000L, snapshotter.activeSnapshotOffset) + snapshotter.activeSnapshotOffset = -1L + generateManyRecords(listener, 5000L) + assertEquals(5000L, snapshotter.activeSnapshotOffset) + assertEquals(null, snapshotter.failure.get()) + } finally { + listener.close() + } + } + + private def registerBrokers( + listener: BrokerMetadataListener, + brokerIds: Iterable[Int], + endOffset: Long + ): Unit = { + brokerIds.foreach { brokerId => + listener.handleCommit(RecordTestUtils.mockBatchReader(endOffset, + util.Arrays.asList(new ApiMessageAndVersion(new RegisterBrokerRecord(). + setBrokerId(brokerId). + setBrokerEpoch(100L). + setFenced(false). + setRack(null). + setIncarnationId(Uuid.fromString("GFBwlTcpQUuLYQ2ig05CS" + brokerId)), 0.toShort)))) + } + } + + private def createTopicWithOnePartition( + listener: BrokerMetadataListener, + replicas: Seq[Int], + endOffset: Long + ): Unit = { + listener.handleCommit(RecordTestUtils.mockBatchReader(endOffset, + util.Arrays.asList( + new ApiMessageAndVersion(new TopicRecord(). + setName("foo"). + setTopicId(FOO_ID), 0.toShort), + new ApiMessageAndVersion(new PartitionRecord(). + setPartitionId(0). + setTopicId(FOO_ID). + setIsr(replicas.map(Int.box).asJava). + setLeader(0). 
+ setReplicas(replicas.map(Int.box).asJava), 0.toShort))) + ) + } + +} diff --git a/core/src/test/scala/unit/kafka/server/metadata/BrokerMetadataPublisherTest.scala b/core/src/test/scala/unit/kafka/server/metadata/BrokerMetadataPublisherTest.scala new file mode 100644 index 0000000..a8c5002 --- /dev/null +++ b/core/src/test/scala/unit/kafka/server/metadata/BrokerMetadataPublisherTest.scala @@ -0,0 +1,145 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package unit.kafka.server.metadata + +import kafka.log.UnifiedLog +import kafka.server.metadata.BrokerMetadataPublisher +import org.apache.kafka.common.{TopicPartition, Uuid} +import org.apache.kafka.image.{MetadataImageTest, TopicImage, TopicsImage} +import org.apache.kafka.metadata.PartitionRegistration +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.Assertions.assertEquals + +import org.mockito.Mockito + +import scala.jdk.CollectionConverters._ + +class BrokerMetadataPublisherTest { + @Test + def testGetTopicDelta(): Unit = { + assert(BrokerMetadataPublisher.getTopicDelta( + "not-a-topic", + MetadataImageTest.IMAGE1, + MetadataImageTest.DELTA1).isEmpty, "Expected no delta for unknown topic") + + assert(BrokerMetadataPublisher.getTopicDelta( + "foo", + MetadataImageTest.IMAGE1, + MetadataImageTest.DELTA1).isEmpty, "Expected no delta for deleted topic") + + assert(BrokerMetadataPublisher.getTopicDelta( + "bar", + MetadataImageTest.IMAGE1, + MetadataImageTest.DELTA1).isDefined, "Expected to see delta for changed topic") + } + + @Test + def testFindStrayReplicas(): Unit = { + val brokerId = 0 + + // Topic has been deleted + val deletedTopic = "a" + val deletedTopicId = Uuid.randomUuid() + val deletedTopicPartition1 = new TopicPartition(deletedTopic, 0) + val deletedTopicLog1 = mockLog(deletedTopicId, deletedTopicPartition1) + val deletedTopicPartition2 = new TopicPartition(deletedTopic, 1) + val deletedTopicLog2 = mockLog(deletedTopicId, deletedTopicPartition2) + + // Topic was deleted and recreated + val recreatedTopic = "b" + val recreatedTopicPartition = new TopicPartition(recreatedTopic, 0) + val recreatedTopicLog = mockLog(Uuid.randomUuid(), recreatedTopicPartition) + val recreatedTopicImage = topicImage(Uuid.randomUuid(), recreatedTopic, Map( + recreatedTopicPartition.partition -> Seq(0, 1, 2) + )) + + // Topic exists, but some partitions were reassigned + val reassignedTopic = "c" + val reassignedTopicId = Uuid.randomUuid() + val reassignedTopicPartition = new TopicPartition(reassignedTopic, 0) + val reassignedTopicLog = mockLog(reassignedTopicId, reassignedTopicPartition) + val retainedTopicPartition = new TopicPartition(reassignedTopic, 1) + val retainedTopicLog = mockLog(reassignedTopicId, retainedTopicPartition) + + val reassignedTopicImage = 
topicImage(reassignedTopicId, reassignedTopic, Map( + reassignedTopicPartition.partition -> Seq(1, 2, 3), + retainedTopicPartition.partition -> Seq(0, 2, 3) + )) + + val logs = Seq( + deletedTopicLog1, + deletedTopicLog2, + recreatedTopicLog, + reassignedTopicLog, + retainedTopicLog + ) + + val image = topicsImage(Seq( + recreatedTopicImage, + reassignedTopicImage + )) + + val expectedStrayPartitions = Set( + deletedTopicPartition1, + deletedTopicPartition2, + recreatedTopicPartition, + reassignedTopicPartition + ) + + val strayPartitions = BrokerMetadataPublisher.findStrayPartitions(brokerId, image, logs).toSet + assertEquals(expectedStrayPartitions, strayPartitions) + } + + private def mockLog( + topicId: Uuid, + topicPartition: TopicPartition + ): UnifiedLog = { + val log = Mockito.mock(classOf[UnifiedLog]) + Mockito.when(log.topicId).thenReturn(Some(topicId)) + Mockito.when(log.topicPartition).thenReturn(topicPartition) + log + } + + private def topicImage( + topicId: Uuid, + topic: String, + partitions: Map[Int, Seq[Int]] + ): TopicImage = { + val partitionRegistrations = partitions.map { case (partitionId, replicas) => + Int.box(partitionId) -> new PartitionRegistration( + replicas.toArray, + replicas.toArray, + Array.empty[Int], + Array.empty[Int], + replicas.head, + 0, + 0 + ) + } + new TopicImage(topic, topicId, partitionRegistrations.asJava) + } + + private def topicsImage( + topics: Seq[TopicImage] + ): TopicsImage = { + val idsMap = topics.map(t => t.id -> t).toMap + val namesMap = topics.map(t => t.name -> t).toMap + new TopicsImage(idsMap.asJava, namesMap.asJava) + } + +} diff --git a/core/src/test/scala/unit/kafka/server/metadata/BrokerMetadataSnapshotterTest.scala b/core/src/test/scala/unit/kafka/server/metadata/BrokerMetadataSnapshotterTest.scala new file mode 100644 index 0000000..888fec5 --- /dev/null +++ b/core/src/test/scala/unit/kafka/server/metadata/BrokerMetadataSnapshotterTest.scala @@ -0,0 +1,107 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package kafka.server.metadata + +import java.nio.ByteBuffer +import java.util.Optional +import java.util.concurrent.{CompletableFuture, CountDownLatch} + +import org.apache.kafka.common.memory.MemoryPool +import org.apache.kafka.common.protocol.ByteBufferAccessor +import org.apache.kafka.common.record.{CompressionType, MemoryRecords} +import org.apache.kafka.common.utils.Time +import org.apache.kafka.image.{MetadataDelta, MetadataImage, MetadataImageTest} +import org.apache.kafka.metadata.MetadataRecordSerde +import org.apache.kafka.queue.EventQueue +import org.apache.kafka.raft.OffsetAndEpoch +import org.apache.kafka.server.common.ApiMessageAndVersion +import org.apache.kafka.snapshot.{MockRawSnapshotWriter, SnapshotWriter} +import org.junit.jupiter.api.Assertions.{assertEquals, assertFalse, assertTrue} +import org.junit.jupiter.api.Test + + +class BrokerMetadataSnapshotterTest { + @Test + def testCreateAndClose(): Unit = { + val snapshotter = new BrokerMetadataSnapshotter(0, Time.SYSTEM, None, + (_, _, _) => throw new RuntimeException("unimplemented")) + snapshotter.close() + } + + class MockSnapshotWriterBuilder extends SnapshotWriterBuilder { + var image = new CompletableFuture[MetadataImage] + + override def build(committedOffset: Long, + committedEpoch: Int, + lastContainedLogTime: Long): SnapshotWriter[ApiMessageAndVersion] = { + val offsetAndEpoch = new OffsetAndEpoch(committedOffset, committedEpoch) + SnapshotWriter.createWithHeader( + () => { + Optional.of( + new MockRawSnapshotWriter(offsetAndEpoch, consumeSnapshotBuffer(committedOffset, committedEpoch)) + ) + }, + 1024, + MemoryPool.NONE, + Time.SYSTEM, + lastContainedLogTime, + CompressionType.NONE, + MetadataRecordSerde.INSTANCE + ).get(); + } + + def consumeSnapshotBuffer(committedOffset: Long, committedEpoch: Int)(buffer: ByteBuffer): Unit = { + val delta = new MetadataDelta(MetadataImage.EMPTY) + val memoryRecords = MemoryRecords.readableRecords(buffer) + val batchIterator = memoryRecords.batchIterator() + while (batchIterator.hasNext) { + val batch = batchIterator.next() + if (!batch.isControlBatch()) { + batch.forEach(record => { + val recordBuffer = record.value().duplicate() + val messageAndVersion = MetadataRecordSerde.INSTANCE.read( + new ByteBufferAccessor(recordBuffer), recordBuffer.remaining()) + delta.replay(committedOffset, committedEpoch, messageAndVersion.message()) + }) + } + } + image.complete(delta.apply()) + } + } + + class BlockingEvent extends EventQueue.Event { + val latch = new CountDownLatch(1) + override def run(): Unit = latch.await() + } + + @Test + def testCreateSnapshot(): Unit = { + val writerBuilder = new MockSnapshotWriterBuilder() + val snapshotter = new BrokerMetadataSnapshotter(0, Time.SYSTEM, None, writerBuilder) + try { + val blockingEvent = new BlockingEvent() + snapshotter.eventQueue.append(blockingEvent) + assertTrue(snapshotter.maybeStartSnapshot(10000L, MetadataImageTest.IMAGE1)) + assertFalse(snapshotter.maybeStartSnapshot(11000L, MetadataImageTest.IMAGE2)) + blockingEvent.latch.countDown() + assertEquals(MetadataImageTest.IMAGE1, writerBuilder.image.get()) + } finally { + snapshotter.close() + } + } +} diff --git a/core/src/test/scala/unit/kafka/server/metadata/MockConfigRepositoryTest.scala b/core/src/test/scala/unit/kafka/server/metadata/MockConfigRepositoryTest.scala new file mode 100644 index 0000000..372372b --- /dev/null +++ b/core/src/test/scala/unit/kafka/server/metadata/MockConfigRepositoryTest.scala @@ -0,0 +1,54 @@ +/** + * Licensed to the Apache Software 
Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.server.metadata + +import java.util.Properties + +import org.junit.jupiter.api.Assertions.assertEquals +import org.junit.jupiter.api.Test + +class MockConfigRepositoryTest { + @Test + def testEmptyRepository(): Unit = { + val repository = new MockConfigRepository() + assertEquals(new Properties(), repository.brokerConfig(0)) + assertEquals(new Properties(), repository.topicConfig("foo")) + } + + @Test + def testSetTopicConfig(): Unit = { + val repository = new MockConfigRepository() + val topic0 = "topic0" + repository.setTopicConfig(topic0, "foo", null) + + val topic1 = "topic1" + repository.setTopicConfig(topic1, "foo", "bar") + val topicProperties = new Properties() + topicProperties.put("foo", "bar") + assertEquals(topicProperties, repository.topicConfig(topic1)) + + val topicProperties2 = new Properties() + topicProperties2.put("foo", "bar") + topicProperties2.put("foo2", "baz") + repository.setTopicConfig(topic1, "foo2", "baz") // add another prop + assertEquals(topicProperties2, repository.topicConfig(topic1)) // should get both props + + repository.setTopicConfig(topic1, "foo2", null) + assertEquals(topicProperties, repository.topicConfig(topic1)) + } +} diff --git a/core/src/test/scala/unit/kafka/server/metadata/ZkConfigRepositoryTest.scala b/core/src/test/scala/unit/kafka/server/metadata/ZkConfigRepositoryTest.scala new file mode 100644 index 0000000..f873775 --- /dev/null +++ b/core/src/test/scala/unit/kafka/server/metadata/ZkConfigRepositoryTest.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package kafka.server + +import java.util.Properties + +import kafka.server.metadata.ZkConfigRepository +import kafka.zk.KafkaZkClient +import org.apache.kafka.common.config.ConfigResource +import org.apache.kafka.common.config.ConfigResource.Type +import org.junit.jupiter.api.Assertions.{assertEquals, assertThrows} +import org.junit.jupiter.api.Test +import org.mockito.Mockito.{mock, when} + +class ZkConfigRepositoryTest { + + @Test + def testZkConfigRepository(): Unit = { + val zkClient: KafkaZkClient = mock(classOf[KafkaZkClient]) + val zkConfigRepository = ZkConfigRepository(zkClient) + val brokerId = 1 + val topic = "topic" + val brokerProps = new Properties() + brokerProps.put("a", "b") + val topicProps = new Properties() + topicProps.put("c", "d") + when(zkClient.getEntityConfigs(ConfigType.Broker, brokerId.toString)).thenReturn(brokerProps) + when(zkClient.getEntityConfigs(ConfigType.Topic, topic)).thenReturn(topicProps) + assertEquals(brokerProps, zkConfigRepository.brokerConfig(brokerId)) + assertEquals(topicProps, zkConfigRepository.topicConfig(topic)) + } + + @Test + def testUnsupportedTypes(): Unit = { + val zkClient: KafkaZkClient = mock(classOf[KafkaZkClient]) + val zkConfigRepository = ZkConfigRepository(zkClient) + Type.values().foreach(value => if (value != Type.BROKER && value != Type.TOPIC) + assertThrows(classOf[IllegalArgumentException], () => zkConfigRepository.config(new ConfigResource(value, value.toString)))) + } +} diff --git a/core/src/test/scala/unit/kafka/tools/ClusterToolTest.scala b/core/src/test/scala/unit/kafka/tools/ClusterToolTest.scala new file mode 100644 index 0000000..b98cd8e --- /dev/null +++ b/core/src/test/scala/unit/kafka/tools/ClusterToolTest.scala @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.tools + +import java.io.{ByteArrayOutputStream, PrintStream} +import org.apache.kafka.clients.admin.MockAdminClient +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.{Test, Timeout} + +@Timeout(value = 60) +class ClusterToolTest { + @Test + def testPrintClusterId(): Unit = { + val adminClient = new MockAdminClient.Builder(). + clusterId("QtNwvtfVQ3GEFpzOmDEE-w"). + build() + val stream = new ByteArrayOutputStream() + ClusterTool.clusterIdCommand(new PrintStream(stream), adminClient) + assertEquals( + s"""Cluster ID: QtNwvtfVQ3GEFpzOmDEE-w +""", stream.toString()) + } + + @Test + def testClusterTooOldToHaveId(): Unit = { + val adminClient = new MockAdminClient.Builder(). + clusterId(null). + build() + val stream = new ByteArrayOutputStream() + ClusterTool.clusterIdCommand(new PrintStream(stream), adminClient) + assertEquals( + s"""No cluster ID found. The Kafka version is probably too old. 
+""", stream.toString()) + } + + @Test + def testUnregisterBroker(): Unit = { + val adminClient = new MockAdminClient.Builder().numBrokers(3). + usingRaftController(true). + build() + val stream = new ByteArrayOutputStream() + ClusterTool.unregisterCommand(new PrintStream(stream), adminClient, 0) + assertEquals( + s"""Broker 0 is no longer registered. +""", stream.toString()) + } + + @Test + def testLegacyModeClusterCannotUnregisterBroker(): Unit = { + val adminClient = new MockAdminClient.Builder().numBrokers(3). + usingRaftController(false). + build() + val stream = new ByteArrayOutputStream() + ClusterTool.unregisterCommand(new PrintStream(stream), adminClient, 0) + assertEquals( + s"""The target cluster does not support the broker unregistration API. +""", stream.toString()) + } +} diff --git a/core/src/test/scala/unit/kafka/tools/ConsoleConsumerTest.scala b/core/src/test/scala/unit/kafka/tools/ConsoleConsumerTest.scala new file mode 100644 index 0000000..9a8b734 --- /dev/null +++ b/core/src/test/scala/unit/kafka/tools/ConsoleConsumerTest.scala @@ -0,0 +1,640 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package kafka.tools + +import java.io.{ByteArrayOutputStream, PrintStream} +import java.nio.file.Files +import java.util.{HashMap, Optional, Map => JMap} +import kafka.tools.ConsoleConsumer.ConsumerWrapper +import kafka.utils.{Exit, TestUtils} +import org.apache.kafka.clients.consumer.{ConsumerRecord, MockConsumer, OffsetResetStrategy} +import org.apache.kafka.common.{MessageFormatter, TopicPartition} +import org.apache.kafka.common.record.TimestampType +import org.apache.kafka.clients.consumer.ConsumerConfig +import org.apache.kafka.test.MockDeserializer +import org.mockito.Mockito._ +import org.mockito.ArgumentMatchers +import ArgumentMatchers._ +import org.apache.kafka.common.header.internals.RecordHeaders +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.{BeforeEach, Test} + +import scala.jdk.CollectionConverters._ + +class ConsoleConsumerTest { + + @BeforeEach + def setup(): Unit = { + ConsoleConsumer.messageCount = 0 + } + + @Test + def shouldResetUnConsumedOffsetsBeforeExit(): Unit = { + val topic = "test" + val maxMessages: Int = 123 + val totalMessages: Int = 700 + val startOffset: java.lang.Long = 0L + + val mockConsumer = new MockConsumer[Array[Byte], Array[Byte]](OffsetResetStrategy.EARLIEST) + val tp1 = new TopicPartition(topic, 0) + val tp2 = new TopicPartition(topic, 1) + + val consumer = new ConsumerWrapper(Some(topic), None, None, None, mockConsumer) + + mockConsumer.rebalance(List(tp1, tp2).asJava) + mockConsumer.updateBeginningOffsets(Map(tp1 -> startOffset, tp2 -> startOffset).asJava) + + 0 until totalMessages foreach { i => + // add all records, each partition should have half of `totalMessages` + mockConsumer.addRecord(new ConsumerRecord[Array[Byte], Array[Byte]](topic, i % 2, i / 2, "key".getBytes, "value".getBytes)) + } + + val formatter = mock(classOf[MessageFormatter]) + + ConsoleConsumer.process(maxMessages, formatter, consumer, System.out, skipMessageOnError = false) + assertEquals(totalMessages, mockConsumer.position(tp1) + mockConsumer.position(tp2)) + + consumer.resetUnconsumedOffsets() + assertEquals(maxMessages, mockConsumer.position(tp1) + mockConsumer.position(tp2)) + + verify(formatter, times(maxMessages)).writeTo(any(), any()) + } + + @Test + def shouldLimitReadsToMaxMessageLimit(): Unit = { + val consumer = mock(classOf[ConsumerWrapper]) + val formatter = mock(classOf[MessageFormatter]) + val record = new ConsumerRecord("foo", 1, 1, Array[Byte](), Array[Byte]()) + + val messageLimit: Int = 10 + when(consumer.receive()).thenReturn(record) + + ConsoleConsumer.process(messageLimit, formatter, consumer, System.out, true) + + verify(consumer, times(messageLimit)).receive() + verify(formatter, times(messageLimit)).writeTo(any(), any()) + + consumer.cleanup() + } + + @Test + def shouldStopWhenOutputCheckErrorFails(): Unit = { + val consumer = mock(classOf[ConsumerWrapper]) + val formatter = mock(classOf[MessageFormatter]) + val printStream = mock(classOf[PrintStream]) + + val record = new ConsumerRecord("foo", 1, 1, Array[Byte](), Array[Byte]()) + + when(consumer.receive()).thenReturn(record) + //Simulate an error on System.out after the first record has been printed + when(printStream.checkError()).thenReturn(true) + + ConsoleConsumer.process(-1, formatter, consumer, printStream, true) + + verify(formatter).writeTo(any(), ArgumentMatchers.eq(printStream)) + verify(consumer).receive() + verify(printStream).checkError() + + consumer.cleanup() + } + + @Test + def shouldParseValidConsumerValidConfig(): Unit = { + //Given + val 
args: Array[String] = Array( + "--bootstrap-server", "localhost:9092", + "--topic", "test", + "--from-beginning") + + //When + val config = new ConsoleConsumer.ConsumerConfig(args) + + //Then + assertEquals("localhost:9092", config.bootstrapServer) + assertEquals("test", config.topicArg) + assertEquals(true, config.fromBeginning) + } + + @Test + def shouldParseIncludeArgument(): Unit = { + //Given + val args: Array[String] = Array( + "--bootstrap-server", "localhost:9092", + "--include", "includeTest*", + "--from-beginning") + + //When + val config = new ConsoleConsumer.ConsumerConfig(args) + + //Then + assertEquals("localhost:9092", config.bootstrapServer) + assertEquals("includeTest*", config.includedTopicsArg) + assertEquals(true, config.fromBeginning) + } + + @Test + def shouldParseWhitelistArgument(): Unit = { + //Given + val args: Array[String] = Array( + "--bootstrap-server", "localhost:9092", + "--whitelist", "whitelistTest*", + "--from-beginning") + + //When + val config = new ConsoleConsumer.ConsumerConfig(args) + + //Then + assertEquals("localhost:9092", config.bootstrapServer) + assertEquals("whitelistTest*", config.includedTopicsArg) + assertEquals(true, config.fromBeginning) + } + + @Test + def shouldIgnoreWhitelistArgumentIfIncludeSpecified(): Unit = { + //Given + val args: Array[String] = Array( + "--bootstrap-server", "localhost:9092", + "--include", "includeTest*", + "--whitelist", "whitelistTest*", + "--from-beginning") + + //When + val config = new ConsoleConsumer.ConsumerConfig(args) + + //Then + assertEquals("localhost:9092", config.bootstrapServer) + assertEquals("includeTest*", config.includedTopicsArg) + assertEquals(true, config.fromBeginning) + } + + @Test + def shouldParseValidSimpleConsumerValidConfigWithNumericOffset(): Unit = { + //Given + val args: Array[String] = Array( + "--bootstrap-server", "localhost:9092", + "--topic", "test", + "--partition", "0", + "--offset", "3") + + //When + val config = new ConsoleConsumer.ConsumerConfig(args) + + //Then + assertEquals("localhost:9092", config.bootstrapServer) + assertEquals("test", config.topicArg) + assertEquals(0, config.partitionArg.get) + assertEquals(3, config.offsetArg) + assertEquals(false, config.fromBeginning) + + } + + @Test + def shouldExitOnUnrecognizedNewConsumerOption(): Unit = { + Exit.setExitProcedure((_, message) => throw new IllegalArgumentException(message.orNull)) + + //Given + val args: Array[String] = Array( + "--new-consumer", + "--bootstrap-server", "localhost:9092", + "--topic", "test", + "--from-beginning") + + try assertThrows(classOf[IllegalArgumentException], () => new ConsoleConsumer.ConsumerConfig(args)) + finally Exit.resetExitProcedure() + } + + @Test + def shouldParseValidSimpleConsumerValidConfigWithStringOffset(): Unit = { + //Given + val args: Array[String] = Array( + "--bootstrap-server", "localhost:9092", + "--topic", "test", + "--partition", "0", + "--offset", "LatEst", + "--property", "print.value=false") + + //When + val config = new ConsoleConsumer.ConsumerConfig(args) + + //Then + assertEquals("localhost:9092", config.bootstrapServer) + assertEquals("test", config.topicArg) + assertEquals(0, config.partitionArg.get) + assertEquals(-1, config.offsetArg) + assertEquals(false, config.fromBeginning) + assertEquals(false, config.formatter.asInstanceOf[DefaultMessageFormatter].printValue) + } + + @Test + def shouldParseValidConsumerConfigWithAutoOffsetResetLatest(): Unit = { + //Given + val args: Array[String] = Array( + "--bootstrap-server", "localhost:9092", + "--topic", 
"test", + "--consumer-property", "auto.offset.reset=latest") + + //When + val config = new ConsoleConsumer.ConsumerConfig(args) + val consumerProperties = ConsoleConsumer.consumerProps(config) + + //Then + assertEquals("localhost:9092", config.bootstrapServer) + assertEquals("test", config.topicArg) + assertEquals(false, config.fromBeginning) + assertEquals("latest", consumerProperties.getProperty(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG)) + } + + @Test + def shouldParseValidConsumerConfigWithAutoOffsetResetEarliest(): Unit = { + //Given + val args: Array[String] = Array( + "--bootstrap-server", "localhost:9092", + "--topic", "test", + "--consumer-property", "auto.offset.reset=earliest") + + //When + val config = new ConsoleConsumer.ConsumerConfig(args) + val consumerProperties = ConsoleConsumer.consumerProps(config) + + //Then + assertEquals("localhost:9092", config.bootstrapServer) + assertEquals("test", config.topicArg) + assertEquals(false, config.fromBeginning) + assertEquals("earliest", consumerProperties.getProperty(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG)) + } + + @Test + def shouldParseValidConsumerConfigWithAutoOffsetResetAndMatchingFromBeginning(): Unit = { + //Given + val args: Array[String] = Array( + "--bootstrap-server", "localhost:9092", + "--topic", "test", + "--consumer-property", "auto.offset.reset=earliest", + "--from-beginning") + + //When + val config = new ConsoleConsumer.ConsumerConfig(args) + val consumerProperties = ConsoleConsumer.consumerProps(config) + + //Then + assertEquals("localhost:9092", config.bootstrapServer) + assertEquals("test", config.topicArg) + assertEquals(true, config.fromBeginning) + assertEquals("earliest", consumerProperties.getProperty(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG)) + } + + @Test + def shouldParseValidConsumerConfigWithNoOffsetReset(): Unit = { + //Given + val args: Array[String] = Array( + "--bootstrap-server", "localhost:9092", + "--topic", "test") + + //When + val config = new ConsoleConsumer.ConsumerConfig(args) + val consumerProperties = ConsoleConsumer.consumerProps(config) + + //Then + assertEquals("localhost:9092", config.bootstrapServer) + assertEquals("test", config.topicArg) + assertEquals(false, config.fromBeginning) + assertEquals("latest", consumerProperties.getProperty(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG)) + } + + @Test + def shouldExitOnInvalidConfigWithAutoOffsetResetAndConflictingFromBeginning(): Unit = { + Exit.setExitProcedure((_, message) => throw new IllegalArgumentException(message.orNull)) + + //Given + val args: Array[String] = Array( + "--bootstrap-server", "localhost:9092", + "--topic", "test", + "--consumer-property", "auto.offset.reset=latest", + "--from-beginning") + try { + val config = new ConsoleConsumer.ConsumerConfig(args) + assertThrows(classOf[IllegalArgumentException], () => ConsoleConsumer.consumerProps(config)) + } + finally Exit.resetExitProcedure() + } + + @Test + def shouldParseConfigsFromFile(): Unit = { + val propsFile = TestUtils.tempFile() + val propsStream = Files.newOutputStream(propsFile.toPath) + propsStream.write("request.timeout.ms=1000\n".getBytes()) + propsStream.write("group.id=group1".getBytes()) + propsStream.close() + val args: Array[String] = Array( + "--bootstrap-server", "localhost:9092", + "--topic", "test", + "--consumer.config", propsFile.getAbsolutePath + ) + + val config = new ConsoleConsumer.ConsumerConfig(args) + + assertEquals("1000", config.consumerProps.getProperty("request.timeout.ms")) + assertEquals("group1", config.consumerProps.getProperty("group.id")) 
+ } + + @Test + def groupIdsProvidedInDifferentPlacesMustMatch(): Unit = { + Exit.setExitProcedure((_, message) => throw new IllegalArgumentException(message.orNull)) + + // different in all three places + var propsFile = TestUtils.tempFile() + var propsStream = Files.newOutputStream(propsFile.toPath) + propsStream.write("group.id=group-from-file".getBytes()) + propsStream.close() + var args: Array[String] = Array( + "--bootstrap-server", "localhost:9092", + "--topic", "test", + "--group", "group-from-arguments", + "--consumer-property", "group.id=group-from-properties", + "--consumer.config", propsFile.getAbsolutePath + ) + + assertThrows(classOf[IllegalArgumentException], () => new ConsoleConsumer.ConsumerConfig(args)) + + // the same in all three places + propsFile = TestUtils.tempFile() + propsStream = Files.newOutputStream(propsFile.toPath) + propsStream.write("group.id=test-group".getBytes()) + propsStream.close() + args = Array( + "--bootstrap-server", "localhost:9092", + "--topic", "test", + "--group", "test-group", + "--consumer-property", "group.id=test-group", + "--consumer.config", propsFile.getAbsolutePath + ) + + var config = new ConsoleConsumer.ConsumerConfig(args) + var props = ConsoleConsumer.consumerProps(config) + assertEquals("test-group", props.getProperty("group.id")) + + // different via --consumer-property and --consumer.config + propsFile = TestUtils.tempFile() + propsStream = Files.newOutputStream(propsFile.toPath) + propsStream.write("group.id=group-from-file".getBytes()) + propsStream.close() + args = Array( + "--bootstrap-server", "localhost:9092", + "--topic", "test", + "--consumer-property", "group.id=group-from-properties", + "--consumer.config", propsFile.getAbsolutePath + ) + + assertThrows(classOf[IllegalArgumentException], () => new ConsoleConsumer.ConsumerConfig(args)) + + // different via --consumer-property and --group + args = Array( + "--bootstrap-server", "localhost:9092", + "--topic", "test", + "--group", "group-from-arguments", + "--consumer-property", "group.id=group-from-properties" + ) + + assertThrows(classOf[IllegalArgumentException], () => new ConsoleConsumer.ConsumerConfig(args)) + + // different via --group and --consumer.config + propsFile = TestUtils.tempFile() + propsStream = Files.newOutputStream(propsFile.toPath) + propsStream.write("group.id=group-from-file".getBytes()) + propsStream.close() + args = Array( + "--bootstrap-server", "localhost:9092", + "--topic", "test", + "--group", "group-from-arguments", + "--consumer.config", propsFile.getAbsolutePath + ) + assertThrows(classOf[IllegalArgumentException], () => new ConsoleConsumer.ConsumerConfig(args)) + + // via --group only + args = Array( + "--bootstrap-server", "localhost:9092", + "--topic", "test", + "--group", "group-from-arguments" + ) + + config = new ConsoleConsumer.ConsumerConfig(args) + props = ConsoleConsumer.consumerProps(config) + assertEquals("group-from-arguments", props.getProperty("group.id")) + + Exit.resetExitProcedure() + } + + @Test + def testCustomPropertyShouldBePassedToConfigureMethod(): Unit = { + val args = Array( + "--bootstrap-server", "localhost:9092", + "--topic", "test", + "--property", "print.key=true", + "--property", "key.deserializer=org.apache.kafka.test.MockDeserializer", + "--property", "key.deserializer.my-props=abc" + ) + val config = new ConsoleConsumer.ConsumerConfig(args) + assertTrue(config.formatter.isInstanceOf[DefaultMessageFormatter]) + assertTrue(config.formatterArgs.containsKey("key.deserializer.my-props")) + val formatter = 
config.formatter.asInstanceOf[DefaultMessageFormatter] + assertTrue(formatter.keyDeserializer.get.isInstanceOf[MockDeserializer]) + assertEquals(1, formatter.keyDeserializer.get.asInstanceOf[MockDeserializer].configs.size) + assertEquals("abc", formatter.keyDeserializer.get.asInstanceOf[MockDeserializer].configs.get("my-props")) + assertTrue(formatter.keyDeserializer.get.asInstanceOf[MockDeserializer].isKey) + } + + @Test + def shouldParseGroupIdFromBeginningGivenTogether(): Unit = { + // Start from earliest + var args: Array[String] = Array( + "--bootstrap-server", "localhost:9092", + "--topic", "test", + "--group", "test-group", + "--from-beginning") + + var config = new ConsoleConsumer.ConsumerConfig(args) + assertEquals("localhost:9092", config.bootstrapServer) + assertEquals("test", config.topicArg) + assertEquals(-2, config.offsetArg) + assertEquals(true, config.fromBeginning) + + // Start from latest + args = Array( + "--bootstrap-server", "localhost:9092", + "--topic", "test", + "--group", "test-group" + ) + + config = new ConsoleConsumer.ConsumerConfig(args) + assertEquals("localhost:9092", config.bootstrapServer) + assertEquals("test", config.topicArg) + assertEquals(-1, config.offsetArg) + assertEquals(false, config.fromBeginning) + } + + @Test + def shouldExitOnGroupIdAndPartitionGivenTogether(): Unit = { + Exit.setExitProcedure((_, message) => throw new IllegalArgumentException(message.orNull)) + //Given + val args: Array[String] = Array( + "--bootstrap-server", "localhost:9092", + "--topic", "test", + "--group", "test-group", + "--partition", "0") + + try assertThrows(classOf[IllegalArgumentException], () => new ConsoleConsumer.ConsumerConfig(args)) + finally Exit.resetExitProcedure() + } + + @Test + def shouldExitOnOffsetWithoutPartition(): Unit = { + Exit.setExitProcedure((_, message) => throw new IllegalArgumentException(message.orNull)) + //Given + val args: Array[String] = Array( + "--bootstrap-server", "localhost:9092", + "--topic", "test", + "--offset", "10") + + try assertThrows(classOf[IllegalArgumentException], () => new ConsoleConsumer.ConsumerConfig(args)) + finally Exit.resetExitProcedure() + } + + @Test + def testDefaultMessageFormatter(): Unit = { + val record = new ConsumerRecord("topic", 0, 123, "key".getBytes, "value".getBytes) + val formatter = new DefaultMessageFormatter() + val configs: JMap[String, String] = new HashMap() + + formatter.configure(configs) + var out = new ByteArrayOutputStream() + formatter.writeTo(record, new PrintStream(out)) + assertEquals("value\n", out.toString) + + configs.put("print.key", "true") + formatter.configure(configs) + out = new ByteArrayOutputStream() + formatter.writeTo(record, new PrintStream(out)) + assertEquals("key\tvalue\n", out.toString) + + configs.put("print.partition", "true") + formatter.configure(configs) + out = new ByteArrayOutputStream() + formatter.writeTo(record, new PrintStream(out)) + assertEquals("Partition:0\tkey\tvalue\n", out.toString) + + configs.put("print.timestamp", "true") + formatter.configure(configs) + out = new ByteArrayOutputStream() + formatter.writeTo(record, new PrintStream(out)) + assertEquals("NO_TIMESTAMP\tPartition:0\tkey\tvalue\n", out.toString) + + configs.put("print.offset", "true") + formatter.configure(configs) + out = new ByteArrayOutputStream() + formatter.writeTo(record, new PrintStream(out)) + assertEquals("NO_TIMESTAMP\tPartition:0\tOffset:123\tkey\tvalue\n", out.toString) + + out = new ByteArrayOutputStream() + val record2 = new ConsumerRecord("topic", 0, 123, 123L, 
TimestampType.CREATE_TIME, -1, -1, "key".getBytes, "value".getBytes, + new RecordHeaders(), Optional.empty[Integer]) + formatter.writeTo(record2, new PrintStream(out)) + assertEquals("CreateTime:123\tPartition:0\tOffset:123\tkey\tvalue\n", out.toString) + formatter.close() + } + + @Test + def testNoOpMessageFormatter(): Unit = { + val record = new ConsumerRecord("topic", 0, 123, "key".getBytes, "value".getBytes) + val formatter = new NoOpMessageFormatter() + + formatter.configure(new HashMap()) + val out = new ByteArrayOutputStream() + formatter.writeTo(record, new PrintStream(out)) + assertEquals("", out.toString) + } + + @Test + def shouldExitIfNoTopicOrFilterSpecified(): Unit = { + Exit.setExitProcedure((_, message) => throw new IllegalArgumentException(message.orNull)) + //Given + val args: Array[String] = Array( + "--bootstrap-server", "localhost:9092") + + try assertThrows(classOf[IllegalArgumentException], () => new ConsoleConsumer.ConsumerConfig(args)) + finally Exit.resetExitProcedure() + } + + @Test + def shouldExitIfTopicAndIncludeSpecified(): Unit = { + Exit.setExitProcedure((_, message) => throw new IllegalArgumentException(message.orNull)) + //Given + val args: Array[String] = Array( + "--bootstrap-server", "localhost:9092", + "--topic", "test", + "--include", "includeTest*") + + try assertThrows(classOf[IllegalArgumentException], () => new ConsoleConsumer.ConsumerConfig(args)) + finally Exit.resetExitProcedure() + } + + @Test + def shouldExitIfTopicAndWhitelistSpecified(): Unit = { + Exit.setExitProcedure((_, message) => throw new IllegalArgumentException(message.orNull)) + //Given + val args: Array[String] = Array( + "--bootstrap-server", "localhost:9092", + "--topic", "test", + "--whitelist", "whitelistTest*") + + try assertThrows(classOf[IllegalArgumentException], () => new ConsoleConsumer.ConsumerConfig(args)) + finally Exit.resetExitProcedure() + } + + @Test + def testClientIdOverride(): Unit = { + //Given + val args: Array[String] = Array( + "--bootstrap-server", "localhost:9092", + "--topic", "test", + "--from-beginning", + "--consumer-property", "client.id=consumer-1") + + //When + val config = new ConsoleConsumer.ConsumerConfig(args) + val consumerProperties = ConsoleConsumer.consumerProps(config) + + //Then + assertEquals("consumer-1", consumerProperties.getProperty(ConsumerConfig.CLIENT_ID_CONFIG)) + } + + @Test + def testDefaultClientId(): Unit = { + //Given + val args: Array[String] = Array( + "--bootstrap-server", "localhost:9092", + "--topic", "test", + "--from-beginning") + + //When + val config = new ConsoleConsumer.ConsumerConfig(args) + val consumerProperties = ConsoleConsumer.consumerProps(config) + + //Then + assertEquals("console-consumer", consumerProperties.getProperty(ConsumerConfig.CLIENT_ID_CONFIG)) + } +} diff --git a/core/src/test/scala/unit/kafka/tools/ConsoleProducerTest.scala b/core/src/test/scala/unit/kafka/tools/ConsoleProducerTest.scala new file mode 100644 index 0000000..84aafa1 --- /dev/null +++ b/core/src/test/scala/unit/kafka/tools/ConsoleProducerTest.scala @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.tools + +import kafka.tools.ConsoleProducer.LineMessageReader +import kafka.utils.Exit +import org.apache.kafka.clients.producer.ProducerConfig +import org.junit.jupiter.api.Assertions.{assertEquals, assertThrows} +import org.junit.jupiter.api.Test + +import java.util + +class ConsoleProducerTest { + + val brokerListValidArgs: Array[String] = Array( + "--broker-list", + "localhost:1001,localhost:1002", + "--topic", + "t3", + "--property", + "parse.key=true", + "--property", + "key.separator=#" + ) + val bootstrapServerValidArgs: Array[String] = Array( + "--bootstrap-server", + "localhost:1003,localhost:1004", + "--topic", + "t3", + "--property", + "parse.key=true", + "--property", + "key.separator=#" + ) + val invalidArgs: Array[String] = Array( + "--t", // not a valid argument + "t3" + ) + val bootstrapServerOverride: Array[String] = Array( + "--broker-list", + "localhost:1001", + "--bootstrap-server", + "localhost:1002", + "--topic", + "t3", + ) + val clientIdOverride: Array[String] = Array( + "--broker-list", + "localhost:1001", + "--topic", + "t3", + "--producer-property", + "client.id=producer-1" + ) + + @Test + def testValidConfigsBrokerList(): Unit = { + val config = new ConsoleProducer.ProducerConfig(brokerListValidArgs) + val producerConfig = new ProducerConfig(ConsoleProducer.producerProps(config)) + assertEquals(util.Arrays.asList("localhost:1001", "localhost:1002"), + producerConfig.getList(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG)) + } + + @Test + def testValidConfigsBootstrapServer(): Unit = { + val config = new ConsoleProducer.ProducerConfig(bootstrapServerValidArgs) + val producerConfig = new ProducerConfig(ConsoleProducer.producerProps(config)) + assertEquals(util.Arrays.asList("localhost:1003", "localhost:1004"), + producerConfig.getList(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG)) + } + + @Test + def testInvalidConfigs(): Unit = { + Exit.setExitProcedure((_, message) => throw new IllegalArgumentException(message.orNull)) + try assertThrows(classOf[IllegalArgumentException], () => new ConsoleProducer.ProducerConfig(invalidArgs)) + finally Exit.resetExitProcedure() + } + + @Test + def testParseKeyProp(): Unit = { + val config = new ConsoleProducer.ProducerConfig(brokerListValidArgs) + val reader = Class.forName(config.readerClass).getDeclaredConstructor().newInstance().asInstanceOf[LineMessageReader] + reader.init(System.in,ConsoleProducer.getReaderProps(config)) + assert(reader.keySeparator == "#") + assert(reader.parseKey) + } + + @Test + def testBootstrapServerOverride(): Unit = { + val config = new ConsoleProducer.ProducerConfig(bootstrapServerOverride) + val producerConfig = new ProducerConfig(ConsoleProducer.producerProps(config)) + assertEquals(util.Arrays.asList("localhost:1002"), + producerConfig.getList(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG)) + } + + @Test + def testClientIdOverride(): Unit = { + val config = new ConsoleProducer.ProducerConfig(clientIdOverride) + val producerConfig = new ProducerConfig(ConsoleProducer.producerProps(config)) + assertEquals("producer-1", + producerConfig.getString(ProducerConfig.CLIENT_ID_CONFIG)) + 
} + + @Test + def testDefaultClientId(): Unit = { + val config = new ConsoleProducer.ProducerConfig(brokerListValidArgs) + val producerConfig = new ProducerConfig(ConsoleProducer.producerProps(config)) + assertEquals("console-producer", + producerConfig.getString(ProducerConfig.CLIENT_ID_CONFIG)) + } +} diff --git a/core/src/test/scala/unit/kafka/tools/ConsumerPerformanceTest.scala b/core/src/test/scala/unit/kafka/tools/ConsumerPerformanceTest.scala new file mode 100644 index 0000000..3cd3193 --- /dev/null +++ b/core/src/test/scala/unit/kafka/tools/ConsumerPerformanceTest.scala @@ -0,0 +1,164 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.tools + +import java.io.{ByteArrayOutputStream, File, PrintWriter} +import java.text.SimpleDateFormat +import kafka.utils.Exit +import org.apache.kafka.clients.consumer.ConsumerConfig +import org.junit.jupiter.api.Assertions.{assertEquals, assertThrows} +import org.junit.jupiter.api.Test + +class ConsumerPerformanceTest { + + private val outContent = new ByteArrayOutputStream() + private val dateFormat = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss:SSS") + + @Test + def testDetailedHeaderMatchBody(): Unit = { + testHeaderMatchContent(detailed = true, 2, + () => ConsumerPerformance.printConsumerProgress(1, 1024 * 1024, 0, 1, 0, 0, 1, dateFormat, 1L)) + } + + @Test + def testNonDetailedHeaderMatchBody(): Unit = { + testHeaderMatchContent(detailed = false, 2, () => println(s"${dateFormat.format(System.currentTimeMillis)}, " + + s"${dateFormat.format(System.currentTimeMillis)}, 1.0, 1.0, 1, 1.0, 1, 1, 1.1, 1.1")) + } + + @Test + def testConfigBrokerList(): Unit = { + //Given + val args: Array[String] = Array( + "--broker-list", "localhost:9092", + "--topic", "test", + "--messages", "10" + ) + + //When + val config = new ConsumerPerformance.ConsumerPerfConfig(args) + + //Then + assertEquals("localhost:9092", config.brokerHostsAndPorts) + assertEquals("test", config.topic) + assertEquals(10, config.numMessages) + } + + @Test + def testConfigBootStrapServer(): Unit = { + //Given + val args: Array[String] = Array( + "--bootstrap-server", "localhost:9092", + "--topic", "test", + "--messages", "10", + "--print-metrics" + ) + + //When + val config = new ConsumerPerformance.ConsumerPerfConfig(args) + + //Then + assertEquals("localhost:9092", config.brokerHostsAndPorts) + assertEquals("test", config.topic) + assertEquals(10, config.numMessages) + } + + @Test + def testBrokerListOverride(): Unit = { + //Given + val args: Array[String] = Array( + "--broker-list", "localhost:9094", + "--bootstrap-server", "localhost:9092", + "--topic", "test", + "--messages", "10" + ) + + //When + val config = new ConsumerPerformance.ConsumerPerfConfig(args) + + //Then + assertEquals("localhost:9092", 
config.brokerHostsAndPorts) + assertEquals("test", config.topic) + assertEquals(10, config.numMessages) + } + + @Test + def testConfigWithUnrecognizedOption(): Unit = { + Exit.setExitProcedure((_, message) => throw new IllegalArgumentException(message.orNull)) + //Given + val args: Array[String] = Array( + "--broker-list", "localhost:9092", + "--topic", "test", + "--messages", "10", + "--new-consumer" + ) + try assertThrows(classOf[IllegalArgumentException], () => new ConsumerPerformance.ConsumerPerfConfig(args)) + finally Exit.resetExitProcedure() + } + + @Test + def testClientIdOverride(): Unit = { + val consumerConfigFile = File.createTempFile("test_consumer_config",".conf") + consumerConfigFile.deleteOnExit() + new PrintWriter(consumerConfigFile.getPath) { write("client.id=consumer-1"); close() } + + //Given + val args: Array[String] = Array( + "--broker-list", "localhost:9092", + "--topic", "test", + "--messages", "10", + "--consumer.config", consumerConfigFile.getPath + ) + + //When + val config = new ConsumerPerformance.ConsumerPerfConfig(args) + + //Then + assertEquals("consumer-1", config.props.getProperty(ConsumerConfig.CLIENT_ID_CONFIG)) + } + + @Test + def testDefaultClientId(): Unit = { + //Given + val args: Array[String] = Array( + "--broker-list", "localhost:9092", + "--topic", "test", + "--messages", "10" + ) + + //When + val config = new ConsumerPerformance.ConsumerPerfConfig(args) + + //Then + assertEquals("perf-consumer-client", config.props.getProperty(ConsumerConfig.CLIENT_ID_CONFIG)) + } + + private def testHeaderMatchContent(detailed: Boolean, expectedOutputLineCount: Int, fun: () => Unit): Unit = { + Console.withOut(outContent) { + ConsumerPerformance.printHeader(detailed) + fun() + + val contents = outContent.toString.split("\n") + assertEquals(expectedOutputLineCount, contents.length) + val header = contents(0) + val body = contents(1) + + assertEquals(header.split(",").length, body.split(",").length) + } + } +} diff --git a/core/src/test/scala/unit/kafka/tools/DumpLogSegmentsTest.scala b/core/src/test/scala/unit/kafka/tools/DumpLogSegmentsTest.scala new file mode 100644 index 0000000..bd2aae8 --- /dev/null +++ b/core/src/test/scala/unit/kafka/tools/DumpLogSegmentsTest.scala @@ -0,0 +1,393 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package kafka.tools + +import java.io.{ByteArrayOutputStream, File, PrintWriter} +import java.nio.ByteBuffer +import java.util +import java.util.Properties + +import kafka.log.{AppendOrigin, UnifiedLog, LogConfig, LogManager, LogTestUtils} +import kafka.server.{BrokerTopicStats, FetchLogEnd, LogDirFailureChannel} +import kafka.tools.DumpLogSegments.TimeIndexDumpErrors +import kafka.utils.{MockTime, TestUtils} +import org.apache.kafka.common.Uuid +import org.apache.kafka.common.metadata.{PartitionChangeRecord, RegisterBrokerRecord, TopicRecord} +import org.apache.kafka.common.protocol.{ByteBufferAccessor, ObjectSerializationCache} +import org.apache.kafka.common.record.{CompressionType, ControlRecordType, EndTransactionMarker, MemoryRecords, RecordVersion, SimpleRecord} +import org.apache.kafka.common.utils.Utils +import org.apache.kafka.metadata.MetadataRecordSerde +import org.apache.kafka.server.common.ApiMessageAndVersion +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.{AfterEach, BeforeEach, Test} + +import scala.jdk.CollectionConverters._ +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer + +case class BatchInfo(records: Seq[SimpleRecord], hasKeys: Boolean, hasValues: Boolean) + +class DumpLogSegmentsTest { + + val tmpDir = TestUtils.tempDir() + val logDir = TestUtils.randomPartitionLogDir(tmpDir) + val segmentName = "00000000000000000000" + val logFilePath = s"$logDir/$segmentName.log" + val indexFilePath = s"$logDir/$segmentName.index" + val timeIndexFilePath = s"$logDir/$segmentName.timeindex" + val time = new MockTime(0, 0) + + val batches = new ArrayBuffer[BatchInfo] + var log: UnifiedLog = _ + + @BeforeEach + def setUp(): Unit = { + val props = new Properties + props.setProperty(LogConfig.IndexIntervalBytesProp, "128") + log = UnifiedLog(logDir, LogConfig(props), logStartOffset = 0L, recoveryPoint = 0L, scheduler = time.scheduler, + time = time, brokerTopicStats = new BrokerTopicStats, maxProducerIdExpirationMs = 60 * 60 * 1000, + producerIdExpirationCheckIntervalMs = LogManager.ProducerIdExpirationCheckIntervalMs, + logDirFailureChannel = new LogDirFailureChannel(10), topicId = None, keepPartitionMetadataFile = true) + } + + def addSimpleRecords(): Unit = { + val now = System.currentTimeMillis() + val firstBatchRecords = (0 until 10).map { i => new SimpleRecord(now + i * 2, s"message key $i".getBytes, s"message value $i".getBytes)} + batches += BatchInfo(firstBatchRecords, true, true) + val secondBatchRecords = (10 until 30).map { i => new SimpleRecord(now + i * 3, s"message key $i".getBytes, null)} + batches += BatchInfo(secondBatchRecords, true, false) + val thirdBatchRecords = (30 until 50).map { i => new SimpleRecord(now + i * 5, null, s"message value $i".getBytes)} + batches += BatchInfo(thirdBatchRecords, false, true) + val fourthBatchRecords = (50 until 60).map { i => new SimpleRecord(now + i * 7, null)} + batches += BatchInfo(fourthBatchRecords, false, false) + + batches.foreach { batchInfo => + log.appendAsLeader(MemoryRecords.withRecords(CompressionType.NONE, 0, batchInfo.records: _*), + leaderEpoch = 0) + } + // Flush, but don't close so that the indexes are not trimmed and contain some zero entries + log.flush() + } + + @AfterEach + def tearDown(): Unit = { + log.close() + Utils.delete(tmpDir) + } + + @Test + def testBatchAndRecordMetadataOutput(): Unit = { + log.appendAsLeader(MemoryRecords.withRecords(CompressionType.NONE, 0, + new SimpleRecord("a".getBytes), + new SimpleRecord("b".getBytes) + ), 
leaderEpoch = 0) + + log.appendAsLeader(MemoryRecords.withRecords(CompressionType.GZIP, 0, + new SimpleRecord(time.milliseconds(), "c".getBytes, "1".getBytes), + new SimpleRecord("d".getBytes) + ), leaderEpoch = 3) + + log.appendAsLeader(MemoryRecords.withRecords(CompressionType.NONE, 0, + new SimpleRecord("e".getBytes, null), + new SimpleRecord(null, "f".getBytes), + new SimpleRecord("g".getBytes) + ), leaderEpoch = 3) + + log.appendAsLeader(MemoryRecords.withIdempotentRecords(CompressionType.NONE, 29342342L, 15.toShort, 234123, + new SimpleRecord("h".getBytes) + ), leaderEpoch = 3) + + log.appendAsLeader(MemoryRecords.withTransactionalRecords(CompressionType.GZIP, 98323L, 99.toShort, 266, + new SimpleRecord("i".getBytes), + new SimpleRecord("j".getBytes) + ), leaderEpoch = 5) + + log.appendAsLeader(MemoryRecords.withEndTransactionMarker(98323L, 99.toShort, + new EndTransactionMarker(ControlRecordType.COMMIT, 100) + ), origin = AppendOrigin.Coordinator, leaderEpoch = 7) + + assertDumpLogRecordMetadata() + } + + @Test + def testPrintDataLog(): Unit = { + addSimpleRecords() + def verifyRecordsInOutput(checkKeysAndValues: Boolean, args: Array[String]): Unit = { + def isBatch(index: Int): Boolean = { + var i = 0 + batches.zipWithIndex.foreach { case (batch, batchIndex) => + if (i == index) + return true + + i += 1 + + batch.records.indices.foreach { recordIndex => + if (i == index) + return false + i += 1 + } + } + throw new AssertionError(s"No match for index $index") + } + + val output = runDumpLogSegments(args) + val lines = output.split("\n") + assertTrue(lines.length > 2, s"Data not printed: $output") + val totalRecords = batches.map(_.records.size).sum + var offset = 0 + val batchIterator = batches.iterator + var batch : BatchInfo = null; + (0 until totalRecords + batches.size).foreach { index => + val line = lines(lines.length - totalRecords - batches.size + index) + // The base offset of the batch is the offset of the first record in the batch, so we + // only increment the offset if it's not a batch + if (isBatch(index)) { + assertTrue(line.startsWith(s"baseOffset: $offset lastOffset: "), s"Not a valid batch-level message record: $line") + batch = batchIterator.next() + } else { + assertTrue(line.startsWith(s"${DumpLogSegments.RecordIndent} offset: $offset"), s"Not a valid message record: $line") + if (checkKeysAndValues) { + var suffix = "headerKeys: []" + if (batch.hasKeys) + suffix += s" key: message key $offset" + if (batch.hasValues) + suffix += s" payload: message value $offset" + assertTrue(line.endsWith(suffix), s"Message record missing key or value: $line") + } + offset += 1 + } + } + } + + def verifyNoRecordsInOutput(args: Array[String]): Unit = { + val output = runDumpLogSegments(args) + assertFalse(output.matches("(?s).*offset: [0-9]* isvalid.*"), s"Data should not have been printed: $output") + } + + // Verify that records are printed with --print-data-log even if --deep-iteration is not specified + verifyRecordsInOutput(true, Array("--print-data-log", "--files", logFilePath)) + // Verify that records are printed with --print-data-log if --deep-iteration is also specified + verifyRecordsInOutput(true, Array("--print-data-log", "--deep-iteration", "--files", logFilePath)) + // Verify that records are printed with --value-decoder even if --print-data-log is not specified + verifyRecordsInOutput(true, Array("--value-decoder-class", "kafka.serializer.StringDecoder", "--files", logFilePath)) + // Verify that records are printed with --key-decoder even if --print-data-log is 
not specified + verifyRecordsInOutput(true, Array("--key-decoder-class", "kafka.serializer.StringDecoder", "--files", logFilePath)) + // Verify that records are printed with --deep-iteration even if --print-data-log is not specified + verifyRecordsInOutput(false, Array("--deep-iteration", "--files", logFilePath)) + + // Verify that records are not printed by default + verifyNoRecordsInOutput(Array("--files", logFilePath)) + } + + @Test + def testDumpIndexMismatches(): Unit = { + addSimpleRecords() + val offsetMismatches = mutable.Map[String, List[(Long, Long)]]() + DumpLogSegments.dumpIndex(new File(indexFilePath), indexSanityOnly = false, verifyOnly = true, offsetMismatches, + Int.MaxValue) + assertEquals(Map.empty, offsetMismatches) + } + + @Test + def testDumpTimeIndexErrors(): Unit = { + addSimpleRecords() + val errors = new TimeIndexDumpErrors + DumpLogSegments.dumpTimeIndex(new File(timeIndexFilePath), indexSanityOnly = false, verifyOnly = true, errors, + Int.MaxValue) + assertEquals(Map.empty, errors.misMatchesForTimeIndexFilesMap) + assertEquals(Map.empty, errors.outOfOrderTimestamp) + assertEquals(Map.empty, errors.shallowOffsetNotFound) + } + + @Test + def testDumpMetadataRecords(): Unit = { + val mockTime = new MockTime + val logConfig = LogTestUtils.createLogConfig(segmentBytes = 1024 * 1024) + val log = LogTestUtils.createLog(logDir, logConfig, new BrokerTopicStats, mockTime.scheduler, mockTime) + + val metadataRecords = Seq( + new ApiMessageAndVersion( + new RegisterBrokerRecord().setBrokerId(0).setBrokerEpoch(10), 0.toShort), + new ApiMessageAndVersion( + new RegisterBrokerRecord().setBrokerId(1).setBrokerEpoch(20), 0.toShort), + new ApiMessageAndVersion( + new TopicRecord().setName("test-topic").setTopicId(Uuid.randomUuid()), 0.toShort), + new ApiMessageAndVersion( + new PartitionChangeRecord().setTopicId(Uuid.randomUuid()).setLeader(1). 
+ setPartitionId(0).setIsr(util.Arrays.asList(0, 1, 2)), 0.toShort) + ) + + val records: Array[SimpleRecord] = metadataRecords.map(message => { + val serde = new MetadataRecordSerde() + val cache = new ObjectSerializationCache + val size = serde.recordSize(message, cache) + val buf = ByteBuffer.allocate(size) + val writer = new ByteBufferAccessor(buf) + serde.write(message, cache, writer) + buf.flip() + new SimpleRecord(null, buf.array) + }).toArray + log.appendAsLeader(MemoryRecords.withRecords(CompressionType.NONE, records:_*), leaderEpoch = 1) + log.flush() + + var output = runDumpLogSegments(Array("--cluster-metadata-decoder", "false", "--files", logFilePath)) + assert(output.contains("TOPIC_RECORD")) + assert(output.contains("BROKER_RECORD")) + + output = runDumpLogSegments(Array("--cluster-metadata-decoder", "--skip-record-metadata", "false", "--files", logFilePath)) + assert(output.contains("TOPIC_RECORD")) + assert(output.contains("BROKER_RECORD")) + + // Bogus metadata record + val buf = ByteBuffer.allocate(4) + val writer = new ByteBufferAccessor(buf) + writer.writeUnsignedVarint(10000) + writer.writeUnsignedVarint(10000) + log.appendAsLeader(MemoryRecords.withRecords(CompressionType.NONE, new SimpleRecord(null, buf.array)), leaderEpoch = 2) + log.appendAsLeader(MemoryRecords.withRecords(CompressionType.NONE, records:_*), leaderEpoch = 2) + + output = runDumpLogSegments(Array("--cluster-metadata-decoder", "--skip-record-metadata", "false", "--files", logFilePath)) + assert(output.contains("TOPIC_RECORD")) + assert(output.contains("BROKER_RECORD")) + assert(output.contains("skipping")) + } + + @Test + def testDumpEmptyIndex(): Unit = { + val indexFile = new File(indexFilePath) + new PrintWriter(indexFile).close() + val expectOutput = s"$indexFile is empty.\n" + val outContent = new ByteArrayOutputStream() + Console.withOut(outContent) { + DumpLogSegments.dumpIndex(indexFile, indexSanityOnly = false, verifyOnly = true, + misMatchesForIndexFilesMap = mutable.Map[String, List[(Long, Long)]](), Int.MaxValue) + } + assertEquals(expectOutput, outContent.toString) + } + + private def runDumpLogSegments(args: Array[String]): String = { + val outContent = new ByteArrayOutputStream + Console.withOut(outContent) { + DumpLogSegments.main(args) + } + outContent.toString + } + + private def readBatchMetadata(lines: util.ListIterator[String]): Option[String] = { + while (lines.hasNext) { + val line = lines.next() + if (line.startsWith("|")) { + throw new IllegalStateException("Read unexpected record entry") + } else if (line.startsWith("baseOffset")) { + return Some(line) + } + } + None + } + + private def readBatchRecords(lines: util.ListIterator[String]): Seq[String] = { + val records = mutable.ArrayBuffer.empty[String] + while (lines.hasNext) { + val line = lines.next() + if (line.startsWith("|")) { + records += line.substring(1) + } else { + lines.previous() + return records.toSeq + } + } + records.toSeq + } + + private def parseMetadataFields(line: String): Map[String, String] = { + val fields = mutable.Map.empty[String, String] + val tokens = line.split("\\s+").map(_.trim()).filter(_.nonEmpty).iterator + + while (tokens.hasNext) { + val token = tokens.next() + if (!token.endsWith(":")) { + throw new IllegalStateException(s"Unexpected non-field token $token") + } + + val field = token.substring(0, token.length - 1) + if (!tokens.hasNext) { + throw new IllegalStateException(s"Failed to parse value for $field") + } + + val value = tokens.next() + fields += field -> value + } + + fields.toMap 
+ } + + private def assertDumpLogRecordMetadata(): Unit = { + val logReadInfo = log.read( + startOffset = 0, + maxLength = Int.MaxValue, + isolation = FetchLogEnd, + minOneMessage = true + ) + + val output = runDumpLogSegments(Array("--deep-iteration", "--files", logFilePath)) + val lines = util.Arrays.asList(output.split("\n"): _*).listIterator() + + for (batch <- logReadInfo.records.batches.asScala) { + val parsedBatchOpt = readBatchMetadata(lines) + assertTrue(parsedBatchOpt.isDefined) + + val parsedBatch = parseMetadataFields(parsedBatchOpt.get) + assertEquals(Some(batch.baseOffset), parsedBatch.get("baseOffset").map(_.toLong)) + assertEquals(Some(batch.lastOffset), parsedBatch.get("lastOffset").map(_.toLong)) + assertEquals(Option(batch.countOrNull), parsedBatch.get("count").map(_.toLong)) + assertEquals(Some(batch.partitionLeaderEpoch), parsedBatch.get("partitionLeaderEpoch").map(_.toInt)) + assertEquals(Some(batch.isTransactional), parsedBatch.get("isTransactional").map(_.toBoolean)) + assertEquals(Some(batch.isControlBatch), parsedBatch.get("isControl").map(_.toBoolean)) + assertEquals(Some(batch.producerId), parsedBatch.get("producerId").map(_.toLong)) + assertEquals(Some(batch.producerEpoch), parsedBatch.get("producerEpoch").map(_.toShort)) + assertEquals(Some(batch.baseSequence), parsedBatch.get("baseSequence").map(_.toInt)) + assertEquals(Some(batch.compressionType.name), parsedBatch.get("compresscodec")) + + val parsedRecordIter = readBatchRecords(lines).iterator + for (record <- batch.asScala) { + assertTrue(parsedRecordIter.hasNext) + val parsedRecord = parseMetadataFields(parsedRecordIter.next()) + assertEquals(Some(record.offset), parsedRecord.get("offset").map(_.toLong)) + assertEquals(Some(record.keySize), parsedRecord.get("keySize").map(_.toInt)) + assertEquals(Some(record.valueSize), parsedRecord.get("valueSize").map(_.toInt)) + assertEquals(Some(record.timestamp), parsedRecord.get(batch.timestampType.name).map(_.toLong)) + + if (batch.magic >= RecordVersion.V2.value) { + assertEquals(Some(record.sequence), parsedRecord.get("sequence").map(_.toInt)) + } + + // Batch fields should not be present in the record output + assertEquals(None, parsedRecord.get("baseOffset")) + assertEquals(None, parsedRecord.get("lastOffset")) + assertEquals(None, parsedRecord.get("partitionLeaderEpoch")) + assertEquals(None, parsedRecord.get("producerId")) + assertEquals(None, parsedRecord.get("producerEpoch")) + assertEquals(None, parsedRecord.get("baseSequence")) + assertEquals(None, parsedRecord.get("isTransactional")) + assertEquals(None, parsedRecord.get("isControl")) + assertEquals(None, parsedRecord.get("compresscodec")) + } + } + } + +} diff --git a/core/src/test/scala/unit/kafka/tools/MirrorMakerTest.scala b/core/src/test/scala/unit/kafka/tools/MirrorMakerTest.scala new file mode 100644 index 0000000..2a03ace --- /dev/null +++ b/core/src/test/scala/unit/kafka/tools/MirrorMakerTest.scala @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.tools + +import kafka.consumer.BaseConsumerRecord +import org.apache.kafka.common.record.{RecordBatch, TimestampType} + +import scala.jdk.CollectionConverters._ +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.Test + +import scala.annotation.nowarn + +@nowarn("cat=deprecation") +class MirrorMakerTest { + + @Test + def testDefaultMirrorMakerMessageHandler(): Unit = { + val now = 12345L + val consumerRecord = BaseConsumerRecord("topic", 0, 1L, now, TimestampType.CREATE_TIME, "key".getBytes, "value".getBytes) + + val result = MirrorMaker.defaultMirrorMakerMessageHandler.handle(consumerRecord) + assertEquals(1, result.size) + + val producerRecord = result.get(0) + assertEquals(now, producerRecord.timestamp) + assertEquals("topic", producerRecord.topic) + assertNull(producerRecord.partition) + assertEquals("key", new String(producerRecord.key)) + assertEquals("value", new String(producerRecord.value)) + } + + @Test + def testDefaultMirrorMakerMessageHandlerWithNoTimestampInSourceMessage(): Unit = { + val consumerRecord = BaseConsumerRecord("topic", 0, 1L, RecordBatch.NO_TIMESTAMP, TimestampType.CREATE_TIME, + "key".getBytes, "value".getBytes) + + val result = MirrorMaker.defaultMirrorMakerMessageHandler.handle(consumerRecord) + assertEquals(1, result.size) + + val producerRecord = result.get(0) + assertNull(producerRecord.timestamp) + assertEquals("topic", producerRecord.topic) + assertNull(producerRecord.partition) + assertEquals("key", new String(producerRecord.key)) + assertEquals("value", new String(producerRecord.value)) + } + + @Test + def testDefaultMirrorMakerMessageHandlerWithHeaders(): Unit = { + val now = 12345L + val consumerRecord = BaseConsumerRecord("topic", 0, 1L, now, TimestampType.CREATE_TIME, "key".getBytes, + "value".getBytes) + consumerRecord.headers.add("headerKey", "headerValue".getBytes) + val result = MirrorMaker.defaultMirrorMakerMessageHandler.handle(consumerRecord) + assertEquals(1, result.size) + + val producerRecord = result.get(0) + assertEquals(now, producerRecord.timestamp) + assertEquals("topic", producerRecord.topic) + assertNull(producerRecord.partition) + assertEquals("key", new String(producerRecord.key)) + assertEquals("value", new String(producerRecord.value)) + assertEquals("headerValue", new String(producerRecord.headers.lastHeader("headerKey").value)) + assertEquals(1, producerRecord.headers.asScala.size) + } +} diff --git a/core/src/test/scala/unit/kafka/tools/StorageToolTest.scala b/core/src/test/scala/unit/kafka/tools/StorageToolTest.scala new file mode 100644 index 0000000..0242c33 --- /dev/null +++ b/core/src/test/scala/unit/kafka/tools/StorageToolTest.scala @@ -0,0 +1,188 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.tools + +import java.io.{ByteArrayOutputStream, PrintStream} +import java.nio.charset.StandardCharsets +import java.nio.file.Files +import java.util +import java.util.Properties + +import kafka.server.{KafkaConfig, MetaProperties} +import kafka.utils.TestUtils +import org.apache.kafka.common.utils.Utils +import org.junit.jupiter.api.Assertions.{assertEquals, assertThrows} +import org.junit.jupiter.api.{Test, Timeout} + + +@Timeout(value = 40) +class StorageToolTest { + private def newSelfManagedProperties() = { + val properties = new Properties() + properties.setProperty(KafkaConfig.LogDirsProp, "/tmp/foo,/tmp/bar") + properties.setProperty(KafkaConfig.ProcessRolesProp, "controller") + properties.setProperty(KafkaConfig.NodeIdProp, "2") + properties.setProperty(KafkaConfig.QuorumVotersProp, s"2@localhost:9092") + properties.setProperty(KafkaConfig.ControllerListenerNamesProp, "PLAINTEXT") + properties + } + + @Test + def testConfigToLogDirectories(): Unit = { + val config = new KafkaConfig(newSelfManagedProperties()) + assertEquals(Seq("/tmp/bar", "/tmp/foo"), StorageTool.configToLogDirectories(config)) + } + + @Test + def testConfigToLogDirectoriesWithMetaLogDir(): Unit = { + val properties = newSelfManagedProperties() + properties.setProperty(KafkaConfig.MetadataLogDirProp, "/tmp/baz") + val config = new KafkaConfig(properties) + assertEquals(Seq("/tmp/bar", "/tmp/baz", "/tmp/foo"), + StorageTool.configToLogDirectories(config)) + } + + @Test + def testInfoCommandOnEmptyDirectory(): Unit = { + val stream = new ByteArrayOutputStream() + val tempDir = TestUtils.tempDir() + try { + assertEquals(1, StorageTool. + infoCommand(new PrintStream(stream), true, Seq(tempDir.toString))) + assertEquals(s"""Found log directory: + ${tempDir.toString} + +Found problem: + ${tempDir.toString} is not formatted. + +""", stream.toString()) + } finally Utils.delete(tempDir) + } + + @Test + def testInfoCommandOnMissingDirectory(): Unit = { + val stream = new ByteArrayOutputStream() + val tempDir = TestUtils.tempDir() + tempDir.delete() + try { + assertEquals(1, StorageTool. + infoCommand(new PrintStream(stream), true, Seq(tempDir.toString))) + assertEquals(s"""Found problem: + ${tempDir.toString} does not exist + +""", stream.toString()) + } finally Utils.delete(tempDir) + } + + @Test + def testInfoCommandOnDirectoryAsFile(): Unit = { + val stream = new ByteArrayOutputStream() + val tempFile = TestUtils.tempFile() + try { + assertEquals(1, StorageTool. + infoCommand(new PrintStream(stream), true, Seq(tempFile.toString))) + assertEquals(s"""Found problem: + ${tempFile.toString} is not a directory + +""", stream.toString()) + } finally tempFile.delete() + } + + @Test + def testInfoWithMismatchedLegacyKafkaConfig(): Unit = { + val stream = new ByteArrayOutputStream() + val tempDir = TestUtils.tempDir() + try { + Files.write(tempDir.toPath.resolve("meta.properties"), + String.join("\n", util.Arrays.asList( + "version=1", + "cluster.id=XcZZOzUqS4yHOjhMQB6JLQ")). + getBytes(StandardCharsets.UTF_8)) + assertEquals(1, StorageTool. 
+ infoCommand(new PrintStream(stream), false, Seq(tempDir.toString))) + assertEquals(s"""Found log directory: + ${tempDir.toString} + +Found metadata: {cluster.id=XcZZOzUqS4yHOjhMQB6JLQ, version=1} + +Found problem: + The kafka configuration file appears to be for a legacy cluster, but the directories are formatted for a cluster in KRaft mode. + +""", stream.toString()) + } finally Utils.delete(tempDir) + } + + @Test + def testInfoWithMismatchedSelfManagedKafkaConfig(): Unit = { + val stream = new ByteArrayOutputStream() + val tempDir = TestUtils.tempDir() + try { + Files.write(tempDir.toPath.resolve("meta.properties"), + String.join("\n", util.Arrays.asList( + "version=0", + "broker.id=1", + "cluster.id=26c36907-4158-4a35-919d-6534229f5241")). + getBytes(StandardCharsets.UTF_8)) + assertEquals(1, StorageTool. + infoCommand(new PrintStream(stream), true, Seq(tempDir.toString))) + assertEquals(s"""Found log directory: + ${tempDir.toString} + +Found metadata: {broker.id=1, cluster.id=26c36907-4158-4a35-919d-6534229f5241, version=0} + +Found problem: + The kafka configuration file appears to be for a cluster in KRaft mode, but the directories are formatted for legacy mode. + +""", stream.toString()) + } finally Utils.delete(tempDir) + } + + @Test + def testFormatEmptyDirectory(): Unit = { + val tempDir = TestUtils.tempDir() + try { + val metaProperties = MetaProperties( + clusterId = "XcZZOzUqS4yHOjhMQB6JLQ", nodeId = 2) + val stream = new ByteArrayOutputStream() + assertEquals(0, StorageTool. + formatCommand(new PrintStream(stream), Seq(tempDir.toString), metaProperties, false)) + assertEquals("Formatting %s%n".format(tempDir), stream.toString()) + + try assertEquals(1, StorageTool. + formatCommand(new PrintStream(new ByteArrayOutputStream()), Seq(tempDir.toString), metaProperties, false)) catch { + case e: TerseFailure => assertEquals(s"Log directory ${tempDir} is already " + + "formatted. Use --ignore-formatted to ignore this directory and format the " + + "others.", e.getMessage) + } + + val stream2 = new ByteArrayOutputStream() + assertEquals(0, StorageTool. + formatCommand(new PrintStream(stream2), Seq(tempDir.toString), metaProperties, true)) + assertEquals("All of the log directories are already formatted.%n".format(), stream2.toString()) + } finally Utils.delete(tempDir) + } + + @Test + def testFormatWithInvalidClusterId(): Unit = { + val config = new KafkaConfig(newSelfManagedProperties()) + assertEquals("Cluster ID string invalid does not appear to be a valid UUID: " + + "Input string `invalid` decoded as 5 bytes, which is not equal to the expected " + + "16 bytes of a base64-encoded UUID", assertThrows(classOf[TerseFailure], + () => StorageTool.buildMetadataProperties("invalid", config)).getMessage) + } +} diff --git a/core/src/test/scala/unit/kafka/utils/CommandLineUtilsTest.scala b/core/src/test/scala/unit/kafka/utils/CommandLineUtilsTest.scala new file mode 100644 index 0000000..8b528ed --- /dev/null +++ b/core/src/test/scala/unit/kafka/utils/CommandLineUtilsTest.scala @@ -0,0 +1,223 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.utils + +import java.util.Properties + +import joptsimple.{OptionParser, OptionSpec} +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.Test + +class CommandLineUtilsTest { + + + @Test + def testParseEmptyArg(): Unit = { + val argArray = Array("my.empty.property=") + + assertThrows(classOf[java.lang.IllegalArgumentException], () => CommandLineUtils.parseKeyValueArgs(argArray, acceptMissingValue = false)) + } + + @Test + def testParseEmptyArgWithNoDelimiter(): Unit = { + val argArray = Array("my.empty.property") + + assertThrows(classOf[java.lang.IllegalArgumentException], () => CommandLineUtils.parseKeyValueArgs(argArray, acceptMissingValue = false)) + } + + @Test + def testParseEmptyArgAsValid(): Unit = { + val argArray = Array("my.empty.property=", "my.empty.property1") + val props = CommandLineUtils.parseKeyValueArgs(argArray) + + assertEquals(props.getProperty("my.empty.property"), "", "Value of a key with missing value should be an empty string") + assertEquals(props.getProperty("my.empty.property1"), "", "Value of a key with missing value with no delimiter should be an empty string") + } + + @Test + def testParseSingleArg(): Unit = { + val argArray = Array("my.property=value") + val props = CommandLineUtils.parseKeyValueArgs(argArray) + + assertEquals(props.getProperty("my.property"), "value", "Value of a single property should be 'value' ") + } + + @Test + def testParseArgs(): Unit = { + val argArray = Array("first.property=first","second.property=second") + val props = CommandLineUtils.parseKeyValueArgs(argArray) + + assertEquals(props.getProperty("first.property"), "first", "Value of first property should be 'first'") + assertEquals(props.getProperty("second.property"), "second", "Value of second property should be 'second'") + } + + @Test + def testParseArgsWithMultipleDelimiters(): Unit = { + val argArray = Array("first.property==first", "second.property=second=", "third.property=thi=rd") + val props = CommandLineUtils.parseKeyValueArgs(argArray) + + assertEquals(props.getProperty("first.property"), "=first", "Value of first property should be '=first'") + assertEquals(props.getProperty("second.property"), "second=", "Value of second property should be 'second='") + assertEquals(props.getProperty("third.property"), "thi=rd", "Value of second property should be 'thi=rd'") + } + + val props = new Properties() + val parser = new OptionParser(false) + var stringOpt : OptionSpec[String] = _ + var intOpt : OptionSpec[java.lang.Integer] = _ + var stringOptOptionalArg : OptionSpec[String] = _ + var intOptOptionalArg : OptionSpec[java.lang.Integer] = _ + var stringOptOptionalArgNoDefault : OptionSpec[String] = _ + var intOptOptionalArgNoDefault : OptionSpec[java.lang.Integer] = _ + + def setUpOptions(): Unit = { + stringOpt = parser.accepts("str") + .withRequiredArg + .ofType(classOf[String]) + .defaultsTo("default-string") + intOpt = parser.accepts("int") + .withRequiredArg() + .ofType(classOf[java.lang.Integer]) + .defaultsTo(100) + stringOptOptionalArg = parser.accepts("str-opt") + .withOptionalArg + .ofType(classOf[String]) + 
.defaultsTo("default-string-2") + intOptOptionalArg = parser.accepts("int-opt") + .withOptionalArg + .ofType(classOf[java.lang.Integer]) + .defaultsTo(200) + stringOptOptionalArgNoDefault = parser.accepts("str-opt-nodef") + .withOptionalArg + .ofType(classOf[String]) + intOptOptionalArgNoDefault = parser.accepts("int-opt-nodef") + .withOptionalArg + .ofType(classOf[java.lang.Integer]) + } + + @Test + def testMaybeMergeOptionsOverwriteExisting(): Unit = { + setUpOptions() + + props.put("skey", "existing-string") + props.put("ikey", "300") + props.put("sokey", "existing-string-2") + props.put("iokey", "400") + props.put("sondkey", "existing-string-3") + props.put("iondkey", "500") + + val options = parser.parse( + "--str", "some-string", + "--int", "600", + "--str-opt", "some-string-2", + "--int-opt", "700", + "--str-opt-nodef", "some-string-3", + "--int-opt-nodef", "800" + ) + + CommandLineUtils.maybeMergeOptions(props, "skey", options, stringOpt) + CommandLineUtils.maybeMergeOptions(props, "ikey", options, intOpt) + CommandLineUtils.maybeMergeOptions(props, "sokey", options, stringOptOptionalArg) + CommandLineUtils.maybeMergeOptions(props, "iokey", options, intOptOptionalArg) + CommandLineUtils.maybeMergeOptions(props, "sondkey", options, stringOptOptionalArgNoDefault) + CommandLineUtils.maybeMergeOptions(props, "iondkey", options, intOptOptionalArgNoDefault) + + assertEquals("some-string", props.get("skey")) + assertEquals("600", props.get("ikey")) + assertEquals("some-string-2", props.get("sokey")) + assertEquals("700", props.get("iokey")) + assertEquals("some-string-3", props.get("sondkey")) + assertEquals("800", props.get("iondkey")) + } + + @Test + def testMaybeMergeOptionsDefaultOverwriteExisting(): Unit = { + setUpOptions() + + props.put("sokey", "existing-string") + props.put("iokey", "300") + props.put("sondkey", "existing-string-2") + props.put("iondkey", "400") + + val options = parser.parse( + "--str-opt", + "--int-opt", + "--str-opt-nodef", + "--int-opt-nodef" + ) + + CommandLineUtils.maybeMergeOptions(props, "sokey", options, stringOptOptionalArg) + CommandLineUtils.maybeMergeOptions(props, "iokey", options, intOptOptionalArg) + CommandLineUtils.maybeMergeOptions(props, "sondkey", options, stringOptOptionalArgNoDefault) + CommandLineUtils.maybeMergeOptions(props, "iondkey", options, intOptOptionalArgNoDefault) + + assertEquals("default-string-2", props.get("sokey")) + assertEquals("200", props.get("iokey")) + assertNull(props.get("sondkey")) + assertNull(props.get("iondkey")) + } + + @Test + def testMaybeMergeOptionsDefaultValueIfNotExist(): Unit = { + setUpOptions() + + val options = parser.parse() + + CommandLineUtils.maybeMergeOptions(props, "skey", options, stringOpt) + CommandLineUtils.maybeMergeOptions(props, "ikey", options, intOpt) + CommandLineUtils.maybeMergeOptions(props, "sokey", options, stringOptOptionalArg) + CommandLineUtils.maybeMergeOptions(props, "iokey", options, intOptOptionalArg) + CommandLineUtils.maybeMergeOptions(props, "sondkey", options, stringOptOptionalArgNoDefault) + CommandLineUtils.maybeMergeOptions(props, "iondkey", options, intOptOptionalArgNoDefault) + + assertEquals("default-string", props.get("skey")) + assertEquals("100", props.get("ikey")) + assertEquals("default-string-2", props.get("sokey")) + assertEquals("200", props.get("iokey")) + assertNull(props.get("sondkey")) + assertNull(props.get("iondkey")) + } + + @Test + def testMaybeMergeOptionsNotOverwriteExisting(): Unit = { + setUpOptions() + + props.put("skey", "existing-string") + 
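// No options are parsed on the command line below, so maybeMergeOptions must leave each of these pre-existing values untouched. + 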
props.put("ikey", "300") + props.put("sokey", "existing-string-2") + props.put("iokey", "400") + props.put("sondkey", "existing-string-3") + props.put("iondkey", "500") + + val options = parser.parse() + + CommandLineUtils.maybeMergeOptions(props, "skey", options, stringOpt) + CommandLineUtils.maybeMergeOptions(props, "ikey", options, intOpt) + CommandLineUtils.maybeMergeOptions(props, "sokey", options, stringOptOptionalArg) + CommandLineUtils.maybeMergeOptions(props, "iokey", options, intOptOptionalArg) + CommandLineUtils.maybeMergeOptions(props, "sondkey", options, stringOptOptionalArgNoDefault) + CommandLineUtils.maybeMergeOptions(props, "iondkey", options, intOptOptionalArgNoDefault) + + assertEquals("existing-string", props.get("skey")) + assertEquals("300", props.get("ikey")) + assertEquals("existing-string-2", props.get("sokey")) + assertEquals("400", props.get("iokey")) + assertEquals("existing-string-3", props.get("sondkey")) + assertEquals("500", props.get("iondkey")) + } +} diff --git a/core/src/test/scala/unit/kafka/utils/CoreUtilsTest.scala b/core/src/test/scala/unit/kafka/utils/CoreUtilsTest.scala new file mode 100755 index 0000000..7a86df8 --- /dev/null +++ b/core/src/test/scala/unit/kafka/utils/CoreUtilsTest.scala @@ -0,0 +1,259 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package kafka.utils + +import java.util.{Arrays, Base64, UUID} +import java.util.concurrent.{ConcurrentHashMap, Executors, TimeUnit} +import java.util.concurrent.atomic.AtomicInteger +import java.util.concurrent.locks.ReentrantLock +import java.nio.ByteBuffer +import java.util.regex.Pattern + +import org.junit.jupiter.api.Assertions._ +import kafka.utils.CoreUtils.inLock +import org.apache.kafka.common.KafkaException +import org.junit.jupiter.api.Test +import org.apache.kafka.common.utils.Utils +import org.slf4j.event.Level + +import scala.jdk.CollectionConverters._ +import scala.collection.mutable +import scala.concurrent.duration.Duration +import scala.concurrent.{Await, ExecutionContext, Future} + +class CoreUtilsTest extends Logging { + + val clusterIdPattern = Pattern.compile("[a-zA-Z0-9_\\-]+") + + @Test + def testSwallow(): Unit = { + CoreUtils.swallow(throw new KafkaException("test"), this, Level.INFO) + } + + @Test + def testTryAll(): Unit = { + case class TestException(key: String) extends Exception + + val recorded = mutable.Map.empty[String, Either[TestException, String]] + def recordingFunction(v: Either[TestException, String]): Unit = { + val key = v match { + case Right(key) => key + case Left(e) => e.key + } + recorded(key) = v + } + + CoreUtils.tryAll(Seq( + () => recordingFunction(Right("valid-0")), + () => recordingFunction(Left(new TestException("exception-1"))), + () => recordingFunction(Right("valid-2")), + () => recordingFunction(Left(new TestException("exception-3"))) + )) + var expected = Map( + "valid-0" -> Right("valid-0"), + "exception-1" -> Left(TestException("exception-1")), + "valid-2" -> Right("valid-2"), + "exception-3" -> Left(TestException("exception-3")) + ) + assertEquals(expected, recorded) + + recorded.clear() + CoreUtils.tryAll(Seq( + () => recordingFunction(Right("valid-0")), + () => recordingFunction(Right("valid-1")) + )) + expected = Map( + "valid-0" -> Right("valid-0"), + "valid-1" -> Right("valid-1") + ) + assertEquals(expected, recorded) + + recorded.clear() + CoreUtils.tryAll(Seq( + () => recordingFunction(Left(new TestException("exception-0"))), + () => recordingFunction(Left(new TestException("exception-1"))) + )) + expected = Map( + "exception-0" -> Left(TestException("exception-0")), + "exception-1" -> Left(TestException("exception-1")) + ) + assertEquals(expected, recorded) + } + + @Test + def testCircularIterator(): Unit = { + val l = List(1, 2) + val itl = CoreUtils.circularIterator(l) + assertEquals(1, itl.next()) + assertEquals(2, itl.next()) + assertEquals(1, itl.next()) + assertEquals(2, itl.next()) + assertFalse(itl.isEmpty) + + val s = Set(1, 2) + val its = CoreUtils.circularIterator(s) + assertEquals(1, its.next()) + assertEquals(2, its.next()) + assertEquals(1, its.next()) + assertEquals(2, its.next()) + assertEquals(1, its.next()) + } + + @Test + def testReadBytes(): Unit = { + for(testCase <- List("", "a", "abcd")) { + val bytes = testCase.getBytes + assertTrue(Arrays.equals(bytes, Utils.readBytes(ByteBuffer.wrap(bytes)))) + } + } + + @Test + def testAbs(): Unit = { + assertEquals(0, Utils.abs(Integer.MIN_VALUE)) + assertEquals(1, Utils.abs(-1)) + assertEquals(0, Utils.abs(0)) + assertEquals(1, Utils.abs(1)) + assertEquals(Integer.MAX_VALUE, Utils.abs(Integer.MAX_VALUE)) + } + + @Test + def testReplaceSuffix(): Unit = { + assertEquals("blah.foo.text", CoreUtils.replaceSuffix("blah.foo.txt", ".txt", ".text")) + assertEquals("blah.foo", CoreUtils.replaceSuffix("blah.foo.txt", ".txt", "")) + assertEquals("txt.txt", 
CoreUtils.replaceSuffix("txt.txt.txt", ".txt", "")) + assertEquals("foo.txt", CoreUtils.replaceSuffix("foo", "", ".txt")) + } + + @Test + def testReadInt(): Unit = { + val values = Array(0, 1, -1, Byte.MaxValue, Short.MaxValue, 2 * Short.MaxValue, Int.MaxValue/2, Int.MinValue/2, Int.MaxValue, Int.MinValue, Int.MaxValue) + val buffer = ByteBuffer.allocate(4 * values.size) + for(i <- 0 until values.length) { + buffer.putInt(i*4, values(i)) + assertEquals(values(i), CoreUtils.readInt(buffer.array, i*4), "Written value should match read value.") + } + } + + @Test + def testCsvList(): Unit = { + val emptyString:String = "" + val nullString:String = null + val emptyList = CoreUtils.parseCsvList(emptyString) + val emptyListFromNullString = CoreUtils.parseCsvList(nullString) + val emptyStringList = Seq.empty[String] + assertTrue(emptyList!=null) + assertTrue(emptyListFromNullString!=null) + assertTrue(emptyStringList.equals(emptyListFromNullString)) + assertTrue(emptyStringList.equals(emptyList)) + } + + @Test + def testCsvMap(): Unit = { + val emptyString: String = "" + val emptyMap = CoreUtils.parseCsvMap(emptyString) + val emptyStringMap = Map.empty[String, String] + assertTrue(emptyMap != null) + assertTrue(emptyStringMap.equals(emptyStringMap)) + + val kvPairsIpV6: String = "a:b:c:v,a:b:c:v" + val ipv6Map = CoreUtils.parseCsvMap(kvPairsIpV6) + for (m <- ipv6Map) { + assertTrue(m._1.equals("a:b:c")) + assertTrue(m._2.equals("v")) + } + + val singleEntry:String = "key:value" + val singleMap = CoreUtils.parseCsvMap(singleEntry) + val value = singleMap.getOrElse("key", 0) + assertTrue(value.equals("value")) + + val kvPairsIpV4: String = "192.168.2.1/30:allow, 192.168.2.1/30:allow" + val ipv4Map = CoreUtils.parseCsvMap(kvPairsIpV4) + for (m <- ipv4Map) { + assertTrue(m._1.equals("192.168.2.1/30")) + assertTrue(m._2.equals("allow")) + } + + val kvPairsSpaces: String = "key:value , key: value" + val spaceMap = CoreUtils.parseCsvMap(kvPairsSpaces) + for (m <- spaceMap) { + assertTrue(m._1.equals("key")) + assertTrue(m._2.equals("value")) + } + } + + @Test + def testInLock(): Unit = { + val lock = new ReentrantLock() + val result = inLock(lock) { + assertTrue(lock.isHeldByCurrentThread, "Should be in lock") + 1 + 1 + } + assertEquals(2, result) + assertFalse(lock.isLocked, "Should be unlocked") + } + + @Test + def testUrlSafeBase64EncodeUUID(): Unit = { + + // Test a UUID that has no + or / characters in base64 encoding [a149b4a3-06e1-4b49-a8cb-8a9c4a59fa46 ->(base64)-> oUm0owbhS0moy4qcSln6Rg==] + val clusterId1 = Base64.getUrlEncoder.withoutPadding.encodeToString(CoreUtils.getBytesFromUuid(UUID.fromString( + "a149b4a3-06e1-4b49-a8cb-8a9c4a59fa46"))) + assertEquals(clusterId1, "oUm0owbhS0moy4qcSln6Rg") + assertEquals(clusterId1.length, 22) + assertTrue(clusterIdPattern.matcher(clusterId1).matches()) + + // Test a UUID that has + or / characters in base64 encoding [d418ec02-277e-4853-81e6-afe30259daec ->(base64)-> 1BjsAid+SFOB5q/jAlna7A==] + val clusterId2 = Base64.getUrlEncoder.withoutPadding.encodeToString(CoreUtils.getBytesFromUuid(UUID.fromString( + "d418ec02-277e-4853-81e6-afe30259daec"))) + assertEquals(clusterId2, "1BjsAid-SFOB5q_jAlna7A") + assertEquals(clusterId2.length, 22) + assertTrue(clusterIdPattern.matcher(clusterId2).matches()) + } + + @Test + def testGenerateUuidAsBase64(): Unit = { + val clusterId = CoreUtils.generateUuidAsBase64() + assertEquals(clusterId.length, 22) + assertTrue(clusterIdPattern.matcher(clusterId).matches()) + } + + @Test + def testAtomicGetOrUpdate(): Unit = { + 
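// All of the futures below race on key 0: atomicGetOrUpdate must hand every caller the same AtomicInteger, so the final count equals the number of calls while the factory runs at most once per competing thread. + 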
val count = 1000 + val nThreads = 5 + val createdCount = new AtomicInteger + val map = new ConcurrentHashMap[Int, AtomicInteger]().asScala + implicit val executionContext = ExecutionContext.fromExecutorService(Executors.newFixedThreadPool(nThreads)) + try { + Await.result(Future.traverse(1 to count) { i => + Future { + CoreUtils.atomicGetOrUpdate(map, 0, { + createdCount.incrementAndGet + new AtomicInteger + }).incrementAndGet() + } + }, Duration(1, TimeUnit.MINUTES)) + assertEquals(count, map(0).get) + val created = createdCount.get + assertTrue(created > 0 && created <= nThreads, s"Too many creations $created") + } finally { + executionContext.shutdownNow() + } + } +} diff --git a/core/src/test/scala/unit/kafka/utils/JaasTestUtils.scala b/core/src/test/scala/unit/kafka/utils/JaasTestUtils.scala new file mode 100644 index 0000000..3d257ba --- /dev/null +++ b/core/src/test/scala/unit/kafka/utils/JaasTestUtils.scala @@ -0,0 +1,295 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package kafka.utils + +import java.io.{BufferedWriter, File, FileWriter} +import java.util.Properties + +import scala.collection.Seq +import kafka.server.KafkaConfig +import org.apache.kafka.clients.admin.ScramMechanism +import org.apache.kafka.common.utils.Java + +object JaasTestUtils { + + case class Krb5LoginModule(useKeyTab: Boolean, + storeKey: Boolean, + keyTab: String, + principal: String, + debug: Boolean, + serviceName: Option[String]) extends JaasModule { + + def name = + if (Java.isIbmJdk) + "com.ibm.security.auth.module.Krb5LoginModule" + else + "com.sun.security.auth.module.Krb5LoginModule" + + def entries: Map[String, String] = + if (Java.isIbmJdk) + Map( + "principal" -> principal, + "credsType" -> "both" + ) ++ (if (useKeyTab) Map("useKeytab" -> s"file:$keyTab") else Map.empty) + else + Map( + "useKeyTab" -> useKeyTab.toString, + "storeKey" -> storeKey.toString, + "keyTab" -> keyTab, + "principal" -> principal + ) ++ serviceName.map(s => Map("serviceName" -> s)).getOrElse(Map.empty) + } + + case class PlainLoginModule(username: String, + password: String, + debug: Boolean = false, + validUsers: Map[String, String] = Map.empty) extends JaasModule { + + def name = "org.apache.kafka.common.security.plain.PlainLoginModule" + + def entries: Map[String, String] = Map( + "username" -> username, + "password" -> password + ) ++ validUsers.map { case (user, pass) => s"user_$user" -> pass } + + } + + case class ZkDigestModule(debug: Boolean = false, + entries: Map[String, String] = Map.empty) extends JaasModule { + def name = "org.apache.zookeeper.server.auth.DigestLoginModule" + } + + case class ScramLoginModule(username: String, + password: String, + debug: Boolean = false, + tokenProps: Map[String, String] = Map.empty) extends JaasModule { + + def name = 
"org.apache.kafka.common.security.scram.ScramLoginModule" + + def entries: Map[String, String] = Map( + "username" -> username, + "password" -> password + ) ++ tokenProps.map { case (name, value) => name -> value } + } + + case class OAuthBearerLoginModule(username: String, + debug: Boolean = false) extends JaasModule { + + def name = "org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginModule" + + def entries: Map[String, String] = Map( + "unsecuredLoginStringClaim_sub" -> username + ) + + } + + sealed trait JaasModule { + def name: String + def debug: Boolean + def entries: Map[String, String] + + override def toString: String = { + s"""$name required + | debug=$debug + | ${entries.map { case (k, v) => s"""$k="$v"""" }.mkString("", "\n| ", ";")} + |""".stripMargin + } + } + + case class JaasSection(contextName: String, modules: Seq[JaasModule]) { + override def toString: String = { + s"""|$contextName { + | ${modules.mkString("\n ")} + |}; + |""".stripMargin + } + } + + private val ZkServerContextName = "Server" + private val ZkClientContextName = "Client" + private val ZkUserSuperPasswd = "adminpasswd" + private val ZkUser = "fpj" + private val ZkUserPassword = "fpjsecret" + + val KafkaServerContextName = "KafkaServer" + val KafkaServerPrincipalUnqualifiedName = "kafka" + private val KafkaServerPrincipal = KafkaServerPrincipalUnqualifiedName + "/localhost@EXAMPLE.COM" + val KafkaClientContextName = "KafkaClient" + val KafkaClientPrincipalUnqualifiedName = "client" + private val KafkaClientPrincipal = KafkaClientPrincipalUnqualifiedName + "@EXAMPLE.COM" + val KafkaClientPrincipalUnqualifiedName2 = "client2" + private val KafkaClientPrincipal2 = KafkaClientPrincipalUnqualifiedName2 + "@EXAMPLE.COM" + + val KafkaPlainUser = "plain-user" + private val KafkaPlainPassword = "plain-user-secret" + val KafkaPlainUser2 = "plain-user2" + val KafkaPlainPassword2 = "plain-user2-secret" + val KafkaPlainAdmin = "plain-admin" + private val KafkaPlainAdminPassword = "plain-admin-secret" + + val KafkaScramUser = "scram-user" + val KafkaScramPassword = "scram-user-secret" + val KafkaScramUser2 = "scram-user2" + val KafkaScramPassword2 = "scram-user2-secret" + val KafkaScramAdmin = "scram-admin" + val KafkaScramAdminPassword = "scram-admin-secret" + + val KafkaOAuthBearerUser = "oauthbearer-user" + val KafkaOAuthBearerUser2 = "oauthbearer-user2" + val KafkaOAuthBearerAdmin = "oauthbearer-admin" + + val serviceName = "kafka" + + def saslConfigs(saslProperties: Option[Properties]): Properties = { + val result = saslProperties.getOrElse(new Properties) + // IBM Kerberos module doesn't support the serviceName JAAS property, hence it needs to be + // passed as a Kafka property + if (Java.isIbmJdk && !result.contains(KafkaConfig.SaslKerberosServiceNameProp)) + result.put(KafkaConfig.SaslKerberosServiceNameProp, serviceName) + result + } + + def writeJaasContextsToFile(jaasSections: Seq[JaasSection]): File = { + val jaasFile = TestUtils.tempFile() + writeToFile(jaasFile, jaasSections) + jaasFile + } + + // Returns a SASL/SCRAM configuration using credentials for the given user and password + def scramClientLoginModule(mechanism: String, scramUser: String, scramPassword: String): String = { + if (ScramMechanism.fromMechanismName(mechanism) == ScramMechanism.UNKNOWN) { + throw new IllegalArgumentException("Unsupported SCRAM mechanism " + mechanism) + } + ScramLoginModule( + scramUser, + scramPassword + ).toString + } + + // Returns the dynamic configuration, using credentials for user #1 + def 
clientLoginModule(mechanism: String, keytabLocation: Option[File], serviceName: String = serviceName): String = + kafkaClientModule(mechanism, keytabLocation, KafkaClientPrincipal, KafkaPlainUser, KafkaPlainPassword, KafkaScramUser, KafkaScramPassword, KafkaOAuthBearerUser, serviceName).toString + + def tokenClientLoginModule(tokenId: String, password: String): String = { + ScramLoginModule( + tokenId, + password, + debug = false, + Map( + "tokenauth" -> "true" + )).toString + } + + def zkSections: Seq[JaasSection] = Seq( + JaasSection(ZkServerContextName, Seq(ZkDigestModule(debug = false, + Map("user_super" -> ZkUserSuperPasswd, s"user_$ZkUser" -> ZkUserPassword)))), + JaasSection(ZkClientContextName, Seq(ZkDigestModule(debug = false, + Map("username" -> ZkUser, "password" -> ZkUserPassword)))) + ) + + def kafkaServerSection(contextName: String, mechanisms: Seq[String], keytabLocation: Option[File]): JaasSection = { + val modules = mechanisms.map { + case "GSSAPI" => + Krb5LoginModule( + useKeyTab = true, + storeKey = true, + keyTab = keytabLocation.getOrElse(throw new IllegalArgumentException("Keytab location not specified for GSSAPI")).getAbsolutePath, + principal = KafkaServerPrincipal, + debug = true, + serviceName = Some(serviceName)) + case "PLAIN" => + PlainLoginModule( + KafkaPlainAdmin, + KafkaPlainAdminPassword, + debug = false, + Map( + KafkaPlainAdmin -> KafkaPlainAdminPassword, + KafkaPlainUser -> KafkaPlainPassword, + KafkaPlainUser2 -> KafkaPlainPassword2 + )) + case "OAUTHBEARER" => + OAuthBearerLoginModule(KafkaOAuthBearerAdmin) + case mechanism => { + if (ScramMechanism.fromMechanismName(mechanism) != ScramMechanism.UNKNOWN) { + ScramLoginModule( + KafkaScramAdmin, + KafkaScramAdminPassword, + debug = false) + } else { + throw new IllegalArgumentException("Unsupported server mechanism " + mechanism) + } + } + } + JaasSection(contextName, modules) + } + + // consider refactoring if more mechanisms are added + private def kafkaClientModule(mechanism: String, + keytabLocation: Option[File], clientPrincipal: String, + plainUser: String, plainPassword: String, + scramUser: String, scramPassword: String, + oauthBearerUser: String, serviceName: String = serviceName): JaasModule = { + mechanism match { + case "GSSAPI" => + Krb5LoginModule( + useKeyTab = true, + storeKey = true, + keyTab = keytabLocation.getOrElse(throw new IllegalArgumentException("Keytab location not specified for GSSAPI")).getAbsolutePath, + principal = clientPrincipal, + debug = true, + serviceName = Some(serviceName) + ) + case "PLAIN" => + PlainLoginModule( + plainUser, + plainPassword + ) + case "OAUTHBEARER" => + OAuthBearerLoginModule( + oauthBearerUser + ) + case mechanism => { + if (ScramMechanism.fromMechanismName(mechanism) != ScramMechanism.UNKNOWN) { + ScramLoginModule( + scramUser, + scramPassword + ) + } else { + throw new IllegalArgumentException("Unsupported client mechanism " + mechanism) + } + } + } + } + + /* + * Used for the static JAAS configuration and it uses the credentials for client#2 + */ + def kafkaClientSection(mechanism: Option[String], keytabLocation: Option[File]): JaasSection = { + JaasSection(KafkaClientContextName, mechanism.map(m => + kafkaClientModule(m, keytabLocation, KafkaClientPrincipal2, KafkaPlainUser2, KafkaPlainPassword2, KafkaScramUser2, KafkaScramPassword2, KafkaOAuthBearerUser2)).toSeq) + } + + private def jaasSectionsToString(jaasSections: Seq[JaasSection]): String = + jaasSections.mkString + + private def writeToFile(file: File, jaasSections: Seq[JaasSection]): 
Unit = { + val writer = new BufferedWriter(new FileWriter(file)) + try writer.write(jaasSectionsToString(jaasSections)) + finally writer.close() + } + +} diff --git a/core/src/test/scala/unit/kafka/utils/JsonTest.scala b/core/src/test/scala/unit/kafka/utils/JsonTest.scala new file mode 100644 index 0000000..7560d20 --- /dev/null +++ b/core/src/test/scala/unit/kafka/utils/JsonTest.scala @@ -0,0 +1,134 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package kafka.utils + +import java.nio.charset.StandardCharsets + +import com.fasterxml.jackson.annotation.JsonProperty +import com.fasterxml.jackson.core.{JsonParseException, JsonProcessingException} +import com.fasterxml.jackson.databind.JsonNode +import com.fasterxml.jackson.databind.node._ +import kafka.utils.JsonTest.TestObject +import kafka.utils.json.JsonValue +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.Test + +import scala.jdk.CollectionConverters._ +import scala.collection.Map + +object JsonTest { + case class TestObject(@JsonProperty("foo") foo: String, @JsonProperty("bar") bar: Int) +} + +class JsonTest { + + @Test + def testJsonParse(): Unit = { + val jnf = JsonNodeFactory.instance + + assertEquals(Some(JsonValue(new ObjectNode(jnf))), Json.parseFull("{}")) + assertEquals(Right(JsonValue(new ObjectNode(jnf))), Json.tryParseFull("{}")) + assertEquals(classOf[Left[JsonProcessingException, JsonValue]], Json.tryParseFull(null).getClass) + assertThrows(classOf[IllegalArgumentException], () => Json.tryParseBytes(null)) + + assertEquals(None, Json.parseFull("")) + assertEquals(classOf[Left[JsonProcessingException, JsonValue]], Json.tryParseFull("").getClass) + + assertEquals(None, Json.parseFull("""{"foo":"bar"s}""")) + val tryRes = Json.tryParseFull("""{"foo":"bar"s}""") + assertTrue(tryRes.isInstanceOf[Left[_, JsonValue]]) + + val objectNode = new ObjectNode( + jnf, + Map[String, JsonNode]("foo" -> new TextNode("bar"), "is_enabled" -> BooleanNode.TRUE).asJava + ) + assertEquals(Some(JsonValue(objectNode)), Json.parseFull("""{"foo":"bar", "is_enabled":true}""")) + assertEquals(Right(JsonValue(objectNode)), Json.tryParseFull("""{"foo":"bar", "is_enabled":true}""")) + + val arrayNode = new ArrayNode(jnf) + Vector(1, 2, 3).map(new IntNode(_)).foreach(arrayNode.add) + assertEquals(Some(JsonValue(arrayNode)), Json.parseFull("[1, 2, 3]")) + + // Test with encoder that properly escapes backslash and quotes + val map = Map("foo1" -> """bar1\,bar2""", "foo2" -> """\bar""").asJava + val encoded = Json.encodeAsString(map) + val decoded = Json.parseFull(encoded) + assertEquals(decoded, Json.parseFull("""{"foo1":"bar1\\,bar2", "foo2":"\\bar"}""")) + } + + @Test + def testEncodeAsString(): Unit = { + assertEquals("null", Json.encodeAsString(null)) + assertEquals("1", 
Json.encodeAsString(1)) + assertEquals("1", Json.encodeAsString(1L)) + assertEquals("1", Json.encodeAsString(1.toByte)) + assertEquals("1", Json.encodeAsString(1.toShort)) + assertEquals("1.0", Json.encodeAsString(1.0)) + assertEquals(""""str"""", Json.encodeAsString("str")) + assertEquals("true", Json.encodeAsString(true)) + assertEquals("false", Json.encodeAsString(false)) + assertEquals("[]", Json.encodeAsString(Seq().asJava)) + assertEquals("[null]", Json.encodeAsString(Seq(null).asJava)) + assertEquals("[1,2,3]", Json.encodeAsString(Seq(1,2,3).asJava)) + assertEquals("""[1,"2",[3],null]""", Json.encodeAsString(Seq(1,"2",Seq(3).asJava,null).asJava)) + assertEquals("{}", Json.encodeAsString(Map().asJava)) + assertEquals("""{"a":1,"b":2,"c":null}""", Json.encodeAsString(Map("a" -> 1, "b" -> 2, "c" -> null).asJava)) + assertEquals("""{"a":[1,2],"c":[3,4]}""", Json.encodeAsString(Map("a" -> Seq(1,2).asJava, "c" -> Seq(3,4).asJava).asJava)) + assertEquals("""{"a":[1,2],"b":[3,4],"c":null}""", Json.encodeAsString(Map("a" -> Seq(1,2).asJava, "b" -> Seq(3,4).asJava, "c" -> null).asJava)) + assertEquals(""""str1\\,str2"""", Json.encodeAsString("""str1\,str2""")) + assertEquals(""""\"quoted\""""", Json.encodeAsString(""""quoted"""")) + } + + @Test + def testEncodeAsBytes(): Unit = { + assertEquals("null", new String(Json.encodeAsBytes(null), StandardCharsets.UTF_8)) + assertEquals("1", new String(Json.encodeAsBytes(1), StandardCharsets.UTF_8)) + assertEquals("1", new String(Json.encodeAsBytes(1L), StandardCharsets.UTF_8)) + assertEquals("1", new String(Json.encodeAsBytes(1.toByte), StandardCharsets.UTF_8)) + assertEquals("1", new String(Json.encodeAsBytes(1.toShort), StandardCharsets.UTF_8)) + assertEquals("1.0", new String(Json.encodeAsBytes(1.0), StandardCharsets.UTF_8)) + assertEquals(""""str"""", new String(Json.encodeAsBytes("str"), StandardCharsets.UTF_8)) + assertEquals("true", new String(Json.encodeAsBytes(true), StandardCharsets.UTF_8)) + assertEquals("false", new String(Json.encodeAsBytes(false), StandardCharsets.UTF_8)) + assertEquals("[]", new String(Json.encodeAsBytes(Seq().asJava), StandardCharsets.UTF_8)) + assertEquals("[null]", new String(Json.encodeAsBytes(Seq(null).asJava), StandardCharsets.UTF_8)) + assertEquals("[1,2,3]", new String(Json.encodeAsBytes(Seq(1,2,3).asJava), StandardCharsets.UTF_8)) + assertEquals("""[1,"2",[3],null]""", new String(Json.encodeAsBytes(Seq(1,"2",Seq(3).asJava,null).asJava), StandardCharsets.UTF_8)) + assertEquals("{}", new String(Json.encodeAsBytes(Map().asJava), StandardCharsets.UTF_8)) + assertEquals("""{"a":1,"b":2,"c":null}""", new String(Json.encodeAsBytes(Map("a" -> 1, "b" -> 2, "c" -> null).asJava), StandardCharsets.UTF_8)) + assertEquals("""{"a":[1,2],"c":[3,4]}""", new String(Json.encodeAsBytes(Map("a" -> Seq(1,2).asJava, "c" -> Seq(3,4).asJava).asJava), StandardCharsets.UTF_8)) + assertEquals("""{"a":[1,2],"b":[3,4],"c":null}""", new String(Json.encodeAsBytes(Map("a" -> Seq(1,2).asJava, "b" -> Seq(3,4).asJava, "c" -> null).asJava), StandardCharsets.UTF_8)) + assertEquals(""""str1\\,str2"""", new String(Json.encodeAsBytes("""str1\,str2"""), StandardCharsets.UTF_8)) + assertEquals(""""\"quoted\""""", new String(Json.encodeAsBytes(""""quoted""""), StandardCharsets.UTF_8)) + } + + @Test + def testParseTo() = { + val foo = "baz" + val bar = 1 + + val result = Json.parseStringAs[TestObject](s"""{"foo": "$foo", "bar": $bar}""") + + assertEquals(Right(TestObject(foo, bar)), result) + } + + @Test + def testParseToWithInvalidJson() = { + val 
result = Json.parseStringAs[TestObject]("{invalid json}") + assertEquals(Left(classOf[JsonParseException]), result.left.map(_.getClass)) + } +} diff --git a/core/src/test/scala/unit/kafka/utils/LogCaptureAppender.scala b/core/src/test/scala/unit/kafka/utils/LogCaptureAppender.scala new file mode 100644 index 0000000..2d07145 --- /dev/null +++ b/core/src/test/scala/unit/kafka/utils/LogCaptureAppender.scala @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.utils + +import org.apache.log4j.{AppenderSkeleton, Level, Logger} +import org.apache.log4j.spi.LoggingEvent + +import scala.collection.mutable.ListBuffer + +class LogCaptureAppender extends AppenderSkeleton { + private val events: ListBuffer[LoggingEvent] = ListBuffer.empty + + override protected def append(event: LoggingEvent): Unit = { + events.synchronized { + events += event + } + } + + def getMessages: ListBuffer[LoggingEvent] = { + events.synchronized { + return events.clone() + } + } + + override def close(): Unit = { + events.synchronized { + events.clear() + } + } + + override def requiresLayout: Boolean = false +} + +object LogCaptureAppender { + def createAndRegister(): LogCaptureAppender = { + val logCaptureAppender: LogCaptureAppender = new LogCaptureAppender + Logger.getRootLogger.addAppender(logCaptureAppender) + logCaptureAppender + } + + def setClassLoggerLevel(clazz: Class[_], logLevel: Level): Level = { + val logger = Logger.getLogger(clazz) + val previousLevel = logger.getLevel + Logger.getLogger(clazz).setLevel(logLevel) + previousLevel + } + + def unregister(logCaptureAppender: LogCaptureAppender): Unit = { + Logger.getRootLogger.removeAppender(logCaptureAppender) + } +} diff --git a/core/src/test/scala/unit/kafka/utils/MockScheduler.scala b/core/src/test/scala/unit/kafka/utils/MockScheduler.scala new file mode 100644 index 0000000..de3dfd7 --- /dev/null +++ b/core/src/test/scala/unit/kafka/utils/MockScheduler.scala @@ -0,0 +1,149 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package kafka.utils + +import scala.collection.mutable.PriorityQueue +import java.util.concurrent.{Delayed, ScheduledFuture, TimeUnit} + +import org.apache.kafka.common.utils.Time + +/** + * A mock scheduler that executes tasks synchronously using a mock time instance. Tasks are executed synchronously when + * the time is advanced. This class is meant to be used in conjunction with MockTime. + * + * Example usage + * + * val time = new MockTime + * time.scheduler.schedule("a task", () => println("hello world: " + time.milliseconds), delay = 1000) + * time.sleep(1001) // this should cause our scheduled task to fire + * + * + * Incrementing the time to the exact next execution time of a task will result in that task executing (as if execution itself takes no time). + */ +class MockScheduler(val time: Time) extends Scheduler { + + /* a priority queue of tasks ordered by next execution time */ + private val tasks = new PriorityQueue[MockTask]() + + def isStarted = true + + def startup(): Unit = {} + + def shutdown(): Unit = { + var currTask: Option[MockTask] = None + do { + currTask = poll(_ => true) + currTask.foreach(_.fun()) + } while (currTask.nonEmpty) + } + + /** + * Check for any tasks that need to execute. Since this is a mock scheduler this check only occurs + * when this method is called and the execution happens synchronously in the calling thread. + * If you are using the scheduler associated with a MockTime instance this call will be triggered automatically. + */ + def tick(): Unit = { + val now = time.milliseconds + var currTask: Option[MockTask] = None + /* pop and execute the task with the lowest next execution time if ready */ + do { + currTask = poll(_.nextExecution <= now) + currTask.foreach { curr => + curr.fun() + /* if the task is periodic, reschedule it and re-enqueue */ + if(curr.periodic) { + curr.nextExecution += curr.period + add(curr) + } + } + } while (currTask.nonEmpty) + } + + def schedule(name: String, fun: () => Unit, delay: Long = 0, period: Long = -1, unit: TimeUnit = TimeUnit.MILLISECONDS): ScheduledFuture[Unit] = { + val task = MockTask(name, fun, time.milliseconds + delay, period = period, time=time) + add(task) + tick() + task + } + + def clear(): Unit = { + this synchronized { + tasks.clear() + } + } + + private def poll(predicate: MockTask => Boolean): Option[MockTask] = { + this synchronized { + if (tasks.nonEmpty && predicate.apply(tasks.head)) + Some(tasks.dequeue()) + else + None + } + } + + private def add(task: MockTask): Unit = { + this synchronized { + tasks += task + } + } +} + +case class MockTask(name: String, fun: () => Unit, var nextExecution: Long, period: Long, time: Time) extends ScheduledFuture[Unit] { + def periodic = period >= 0 + def compare(t: MockTask): Int = { + if(t.nextExecution == nextExecution) + 0 + else if (t.nextExecution < nextExecution) + -1 + else + 1 + } + + /** + * Not used, so not fully implemented + */ + def cancel(mayInterruptIfRunning: Boolean) : Boolean = { + false + } + + def get(): Unit = { + } + + def get(timeout: Long, unit: TimeUnit): Unit = { + } + + def isCancelled: Boolean = { + false + } + + def isDone: Boolean = { + false + } + + def getDelay(unit: TimeUnit): Long = { + this synchronized { + time.milliseconds - nextExecution + } + } + + def compareTo(o : Delayed) : Int = { + this.getDelay(TimeUnit.MILLISECONDS).compareTo(o.getDelay(TimeUnit.MILLISECONDS)) + } +} +object MockTask { + implicit def MockTaskOrdering: Ordering[MockTask] = (x, y) => x.compare(y) +} diff --git 
a/core/src/test/scala/unit/kafka/utils/MockTime.scala b/core/src/test/scala/unit/kafka/utils/MockTime.scala new file mode 100644 index 0000000..bf0e7bd --- /dev/null +++ b/core/src/test/scala/unit/kafka/utils/MockTime.scala @@ -0,0 +1,40 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.utils + +import org.apache.kafka.common.utils.{MockTime => JMockTime} + +/** + * A class used for unit testing things which depend on the Time interface. + * There are a couple of differences between this class and `org.apache.kafka.common.utils.MockTime`: + * + * 1. This has an associated scheduler instance for managing background tasks in a deterministic way. + * 2. This doesn't support the `auto-tick` functionality as it interacts badly with the current implementation of `MockScheduler`. + */ +class MockTime(currentTimeMs: Long, currentHiResTimeNs: Long) extends JMockTime(0, currentTimeMs, currentHiResTimeNs) { + + def this() = this(System.currentTimeMillis(), System.nanoTime()) + + val scheduler = new MockScheduler(this) + + override def sleep(ms: Long): Unit = { + super.sleep(ms) + scheduler.tick() + } + +} diff --git a/core/src/test/scala/unit/kafka/utils/PasswordEncoderTest.scala b/core/src/test/scala/unit/kafka/utils/PasswordEncoderTest.scala new file mode 100755 index 0000000..0a5d5ac --- /dev/null +++ b/core/src/test/scala/unit/kafka/utils/PasswordEncoderTest.scala @@ -0,0 +1,124 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package kafka.utils + + +import javax.crypto.SecretKeyFactory + +import kafka.server.Defaults +import org.apache.kafka.common.config.ConfigException +import org.apache.kafka.common.config.types.Password +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.Test + +class PasswordEncoderTest { + + @Test + def testEncodeDecode(): Unit = { + val encoder = new PasswordEncoder(new Password("password-encoder-secret"), + None, + Defaults.PasswordEncoderCipherAlgorithm, + Defaults.PasswordEncoderKeyLength, + Defaults.PasswordEncoderIterations) + val password = "test-password" + val encoded = encoder.encode(new Password(password)) + val encodedMap = CoreUtils.parseCsvMap(encoded) + assertEquals("4096", encodedMap(PasswordEncoder.IterationsProp)) + assertEquals("128", encodedMap(PasswordEncoder.KeyLengthProp)) + val defaultKeyFactoryAlgorithm = try { + SecretKeyFactory.getInstance("PBKDF2WithHmacSHA512") + "PBKDF2WithHmacSHA512" + } catch { + case _: Exception => "PBKDF2WithHmacSHA1" + } + assertEquals(defaultKeyFactoryAlgorithm, encodedMap(PasswordEncoder.KeyFactoryAlgorithmProp)) + assertEquals("AES/CBC/PKCS5Padding", encodedMap(PasswordEncoder.CipherAlgorithmProp)) + + verifyEncodedPassword(encoder, password, encoded) + } + + @Test + def testEncoderConfigChange(): Unit = { + val encoder = new PasswordEncoder(new Password("password-encoder-secret"), + Some("PBKDF2WithHmacSHA1"), + "DES/CBC/PKCS5Padding", + 64, + 1024) + val password = "test-password" + val encoded = encoder.encode(new Password(password)) + val encodedMap = CoreUtils.parseCsvMap(encoded) + assertEquals("1024", encodedMap(PasswordEncoder.IterationsProp)) + assertEquals("64", encodedMap(PasswordEncoder.KeyLengthProp)) + assertEquals("PBKDF2WithHmacSHA1", encodedMap(PasswordEncoder.KeyFactoryAlgorithmProp)) + assertEquals("DES/CBC/PKCS5Padding", encodedMap(PasswordEncoder.CipherAlgorithmProp)) + + // Test that decoding works even if PasswordEncoder algorithm, iterations etc. 
are altered + val decoder = new PasswordEncoder(new Password("password-encoder-secret"), + Some("PBKDF2WithHmacSHA1"), + "AES/CBC/PKCS5Padding", + 128, + 2048) + assertEquals(password, decoder.decode(encoded).value) + + // Test that decoding fails if secret is altered + val decoder2 = new PasswordEncoder(new Password("secret-2"), + Some("PBKDF2WithHmacSHA1"), + "AES/CBC/PKCS5Padding", + 128, + 1024) + assertThrows(classOf[ConfigException], () => decoder2.decode(encoded)) + } + + @Test + def testEncodeDecodeAlgorithms(): Unit = { + + def verifyEncodeDecode(keyFactoryAlg: Option[String], cipherAlg: String, keyLength: Int): Unit = { + val encoder = new PasswordEncoder(new Password("password-encoder-secret"), + keyFactoryAlg, + cipherAlg, + keyLength, + Defaults.PasswordEncoderIterations) + val password = "test-password" + val encoded = encoder.encode(new Password(password)) + verifyEncodedPassword(encoder, password, encoded) + } + + verifyEncodeDecode(keyFactoryAlg = None, "DES/CBC/PKCS5Padding", keyLength = 64) + verifyEncodeDecode(keyFactoryAlg = None, "DESede/CBC/PKCS5Padding", keyLength = 192) + verifyEncodeDecode(keyFactoryAlg = None, "AES/CBC/PKCS5Padding", keyLength = 128) + verifyEncodeDecode(keyFactoryAlg = None, "AES/CFB/PKCS5Padding", keyLength = 128) + verifyEncodeDecode(keyFactoryAlg = None, "AES/OFB/PKCS5Padding", keyLength = 128) + verifyEncodeDecode(keyFactoryAlg = Some("PBKDF2WithHmacSHA1"), Defaults.PasswordEncoderCipherAlgorithm, keyLength = 128) + verifyEncodeDecode(keyFactoryAlg = None, "AES/GCM/NoPadding", keyLength = 128) + verifyEncodeDecode(keyFactoryAlg = Some("PBKDF2WithHmacSHA256"), Defaults.PasswordEncoderCipherAlgorithm, keyLength = 128) + verifyEncodeDecode(keyFactoryAlg = Some("PBKDF2WithHmacSHA512"), Defaults.PasswordEncoderCipherAlgorithm, keyLength = 128) + } + + private def verifyEncodedPassword(encoder: PasswordEncoder, password: String, encoded: String): Unit = { + val encodedMap = CoreUtils.parseCsvMap(encoded) + assertEquals(password.length.toString, encodedMap(PasswordEncoder.PasswordLengthProp)) + assertNotNull(encoder.base64Decode(encodedMap("salt")), "Invalid salt") + assertNotNull(encoder.base64Decode(encodedMap(PasswordEncoder.InitializationVectorProp)), "Invalid encoding parameters") + assertNotNull(encoder.base64Decode(encodedMap(PasswordEncoder.EncyrptedPasswordProp)), "Invalid encoded password") + assertEquals(password, encoder.decode(encoded).value) + } +} diff --git a/core/src/test/scala/unit/kafka/utils/PoolTest.scala b/core/src/test/scala/unit/kafka/utils/PoolTest.scala new file mode 100644 index 0000000..4f88329 --- /dev/null +++ b/core/src/test/scala/unit/kafka/utils/PoolTest.scala @@ -0,0 +1,40 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package kafka.utils + +import org.junit.jupiter.api.Assertions.assertEquals +import org.junit.jupiter.api.Test + + +class PoolTest { + @Test + def testRemoveAll(): Unit = { + val pool = new Pool[Int, String] + pool.put(1, "1") + pool.put(2, "2") + pool.put(3, "3") + + assertEquals(3, pool.size) + + pool.removeAll(Seq(1, 2)) + assertEquals(1, pool.size) + assertEquals("3", pool.get(3)) + pool.removeAll(Seq(3)) + assertEquals(0, pool.size) + } +} diff --git a/core/src/test/scala/unit/kafka/utils/QuotaUtilsTest.scala b/core/src/test/scala/unit/kafka/utils/QuotaUtilsTest.scala new file mode 100755 index 0000000..949618d --- /dev/null +++ b/core/src/test/scala/unit/kafka/utils/QuotaUtilsTest.scala @@ -0,0 +1,134 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.utils + +import java.util.concurrent.TimeUnit + +import org.apache.kafka.common.MetricName +import org.apache.kafka.common.metrics.{KafkaMetric, MetricConfig, Quota, QuotaViolationException} +import org.apache.kafka.common.metrics.stats.{Rate, Value} + +import scala.jdk.CollectionConverters._ +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.Test + +class QuotaUtilsTest { + + private val time = new MockTime + private val numSamples = 10 + private val sampleWindowSec = 1 + private val maxThrottleTimeMs = 500 + private val metricName = new MetricName("test-metric", "groupA", "testA", Map.empty.asJava) + + @Test + def testThrottleTimeObservedRateEqualsQuota(): Unit = { + val numSamples = 10 + val observedValue = 16.5 + + assertEquals(0, throttleTime(observedValue, observedValue, numSamples)) + + // should be independent of window size + assertEquals(0, throttleTime(observedValue, observedValue, numSamples + 1)) + } + + @Test + def testThrottleTimeObservedRateBelowQuota(): Unit = { + val observedValue = 16.5 + val quota = 20.4 + assertTrue(throttleTime(observedValue, quota, numSamples) < 0) + + // should be independent of window size + assertTrue(throttleTime(observedValue, quota, numSamples + 1) < 0) + } + + @Test + def testThrottleTimeObservedRateAboveQuota(): Unit = { + val quota = 50.0 + val observedValue = 100.0 + assertEquals(2000, throttleTime(observedValue, quota, 3)) + } + + @Test + def testBoundedThrottleTimeObservedRateEqualsQuota(): Unit = { + val observedValue = 18.2 + assertEquals(0, boundedThrottleTime(observedValue, observedValue, numSamples, maxThrottleTimeMs)) + + // should be independent of window size + assertEquals(0, boundedThrottleTime(observedValue, observedValue, numSamples + 1, maxThrottleTimeMs)) + } + + @Test + def testBoundedThrottleTimeObservedRateBelowQuota(): Unit = { + val observedValue = 16.5 + val quota = 22.4 + + assertTrue(boundedThrottleTime(observedValue, quota, numSamples, maxThrottleTimeMs) < 0) + + // 
should be independent of window size + assertTrue(boundedThrottleTime(observedValue, quota, numSamples + 1, maxThrottleTimeMs) < 0) + } + + @Test + def testBoundedThrottleTimeObservedRateAboveQuotaBelowLimit(): Unit = { + val quota = 50.0 + val observedValue = 55.0 + assertEquals(100, boundedThrottleTime(observedValue, quota, 2, maxThrottleTimeMs)) + } + + @Test + def testBoundedThrottleTimeObservedRateAboveQuotaAboveLimit(): Unit = { + val quota = 50.0 + val observedValue = 100.0 + assertEquals(maxThrottleTimeMs, boundedThrottleTime(observedValue, quota, numSamples, maxThrottleTimeMs)) + } + + @Test + def testThrottleTimeThrowsExceptionIfProvidedNonRateMetric(): Unit = { + val testMetric = new KafkaMetric(new Object(), metricName, new Value(), new MetricConfig, time); + + assertThrows(classOf[IllegalArgumentException], () => QuotaUtils.throttleTime(new QuotaViolationException(testMetric, 10.0, 20.0), time.milliseconds)) + } + + @Test + def testBoundedThrottleTimeThrowsExceptionIfProvidedNonRateMetric(): Unit = { + val testMetric = new KafkaMetric(new Object(), metricName, new Value(), new MetricConfig, time); + + assertThrows(classOf[IllegalArgumentException], () => QuotaUtils.boundedThrottleTime(new QuotaViolationException(testMetric, 10.0, 20.0), + maxThrottleTimeMs, time.milliseconds)) + } + + // the `metric` passed into the returned QuotaViolationException will return windowSize = 'numSamples' - 1 + private def quotaViolationException(observedValue: Double, quota: Double, numSamples: Int): QuotaViolationException = { + val metricConfig = new MetricConfig() + .timeWindow(sampleWindowSec, TimeUnit.SECONDS) + .samples(numSamples) + .quota(new Quota(quota, true)) + val metric = new KafkaMetric(new Object(), metricName, new Rate(), metricConfig, time) + new QuotaViolationException(metric, observedValue, quota) + } + + private def throttleTime(observedValue: Double, quota: Double, numSamples: Int): Long = { + val e = quotaViolationException(observedValue, quota, numSamples) + QuotaUtils.throttleTime(e, time.milliseconds) + } + + private def boundedThrottleTime(observedValue: Double, quota: Double, numSamples: Int, maxThrottleTime: Long): Long = { + val e = quotaViolationException(observedValue, quota, numSamples) + QuotaUtils.boundedThrottleTime(e, maxThrottleTime, time.milliseconds) + } +} diff --git a/core/src/test/scala/unit/kafka/utils/ReplicationUtilsTest.scala b/core/src/test/scala/unit/kafka/utils/ReplicationUtilsTest.scala new file mode 100644 index 0000000..a610956 --- /dev/null +++ b/core/src/test/scala/unit/kafka/utils/ReplicationUtilsTest.scala @@ -0,0 +1,76 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package kafka.utils + +import kafka.api.LeaderAndIsr +import kafka.controller.LeaderIsrAndControllerEpoch +import kafka.server.QuorumTestHarness +import kafka.zk._ +import org.apache.kafka.common.TopicPartition +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.{BeforeEach, Test, TestInfo} + +class ReplicationUtilsTest extends QuorumTestHarness { + private val zkVersion = 1 + private val topic = "my-topic-test" + private val partition = 0 + private val leader = 1 + private val leaderEpoch = 1 + private val controllerEpoch = 1 + private val isr = List(1, 2) + + @BeforeEach + override def setUp(testInfo: TestInfo): Unit = { + super.setUp(testInfo) + zkClient.makeSurePersistentPathExists(TopicZNode.path(topic)) + val topicPartition = new TopicPartition(topic, partition) + val leaderAndIsr = LeaderAndIsr(leader, leaderEpoch, isr, 1) + val leaderIsrAndControllerEpoch = LeaderIsrAndControllerEpoch(leaderAndIsr, controllerEpoch) + zkClient.createTopicPartitionStatesRaw(Map(topicPartition -> leaderIsrAndControllerEpoch), ZkVersion.MatchAnyVersion) + } + + @Test + def testUpdateLeaderAndIsr(): Unit = { + zkClient.makeSurePersistentPathExists(IsrChangeNotificationZNode.path) + + val replicas = List(0, 1) + + // regular update + val newLeaderAndIsr1 = new LeaderAndIsr(leader, leaderEpoch, replicas, 0) + val (updateSucceeded1, newZkVersion1) = ReplicationUtils.updateLeaderAndIsr(zkClient, + new TopicPartition(topic, partition), newLeaderAndIsr1, controllerEpoch) + assertTrue(updateSucceeded1) + assertEquals(newZkVersion1, 1) + + // mismatched zkVersion with the same data + val newLeaderAndIsr2 = new LeaderAndIsr(leader, leaderEpoch, replicas, zkVersion + 1) + val (updateSucceeded2, newZkVersion2) = ReplicationUtils.updateLeaderAndIsr(zkClient, + new TopicPartition(topic, partition), newLeaderAndIsr2, controllerEpoch) + assertTrue(updateSucceeded2) + // returns true with existing zkVersion + assertEquals(newZkVersion2, 1) + + // mismatched zkVersion and leaderEpoch + val newLeaderAndIsr3 = new LeaderAndIsr(leader, leaderEpoch + 1, replicas, zkVersion + 1) + val (updateSucceeded3, newZkVersion3) = ReplicationUtils.updateLeaderAndIsr(zkClient, + new TopicPartition(topic, partition), newLeaderAndIsr3, controllerEpoch) + assertFalse(updateSucceeded3) + assertEquals(newZkVersion3, -1) + } + +} diff --git a/core/src/test/scala/unit/kafka/utils/SchedulerTest.scala b/core/src/test/scala/unit/kafka/utils/SchedulerTest.scala new file mode 100644 index 0000000..0f23519 --- /dev/null +++ b/core/src/test/scala/unit/kafka/utils/SchedulerTest.scala @@ -0,0 +1,186 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package kafka.utils + +import java.util.Properties +import java.util.concurrent.atomic._ +import java.util.concurrent.{CountDownLatch, Executors, TimeUnit} +import kafka.log.{LoadLogParams, LocalLog, UnifiedLog, LogConfig, LogLoader, LogManager, LogSegments, ProducerStateManager} +import kafka.server.{BrokerTopicStats, LogDirFailureChannel} +import kafka.utils.TestUtils.retry +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.{AfterEach, BeforeEach, Test, Timeout} + +class SchedulerTest { + + val scheduler = new KafkaScheduler(1) + val mockTime = new MockTime + val counter1 = new AtomicInteger(0) + val counter2 = new AtomicInteger(0) + + @BeforeEach + def setup(): Unit = { + scheduler.startup() + } + + @AfterEach + def teardown(): Unit = { + scheduler.shutdown() + } + + @Test + def testMockSchedulerNonPeriodicTask(): Unit = { + mockTime.scheduler.schedule("test1", counter1.getAndIncrement _, delay=1) + mockTime.scheduler.schedule("test2", counter2.getAndIncrement _, delay=100) + assertEquals(0, counter1.get, "Counter1 should not be incremented prior to task running.") + assertEquals(0, counter2.get, "Counter2 should not be incremented prior to task running.") + mockTime.sleep(1) + assertEquals(1, counter1.get, "Counter1 should be incremented") + assertEquals(0, counter2.get, "Counter2 should not be incremented") + mockTime.sleep(100000) + assertEquals(1, counter1.get, "More sleeping should not result in more incrementing on counter1.") + assertEquals(1, counter2.get, "Counter2 should now be incremented.") + } + + @Test + def testMockSchedulerPeriodicTask(): Unit = { + mockTime.scheduler.schedule("test1", counter1.getAndIncrement _, delay=1, period=1) + mockTime.scheduler.schedule("test2", counter2.getAndIncrement _, delay=100, period=100) + assertEquals(0, counter1.get, "Counter1 should not be incremented prior to task running.") + assertEquals(0, counter2.get, "Counter2 should not be incremented prior to task running.") + mockTime.sleep(1) + assertEquals(1, counter1.get, "Counter1 should be incremented") + assertEquals(0, counter2.get, "Counter2 should not be incremented") + mockTime.sleep(100) + assertEquals(101, counter1.get, "Counter1 should be incremented 101 times") + assertEquals(1, counter2.get, "Counter2 should be incremented exactly once") + } + + @Test + def testReentrantTaskInMockScheduler(): Unit = { + mockTime.scheduler.schedule("test1", () => mockTime.scheduler.schedule("test2", counter2.getAndIncrement _, delay=0), delay=1) + mockTime.sleep(1) + assertEquals(1, counter2.get) + } + + @Test + def testNonPeriodicTask(): Unit = { + scheduler.schedule("test", counter1.getAndIncrement _, delay = 0) + retry(30000) { + assertEquals(counter1.get, 1) + } + Thread.sleep(5) + assertEquals(1, counter1.get, "Should only run once") + } + + @Test + def testPeriodicTask(): Unit = { + scheduler.schedule("test", counter1.getAndIncrement _, delay = 0, period = 5) + retry(30000){ + assertTrue(counter1.get >= 20, "Should count to 20") + } + } + + @Test + def testRestart(): Unit = { + // schedule a task to increment a counter + mockTime.scheduler.schedule("test1", counter1.getAndIncrement _, delay=1) + mockTime.sleep(1) + assertEquals(1, counter1.get()) + + // restart the scheduler + mockTime.scheduler.shutdown() + mockTime.scheduler.startup() + + // schedule another task to increment the counter + mockTime.scheduler.schedule("test1", counter1.getAndIncrement _, delay=1) + mockTime.sleep(1) + assertEquals(2, counter1.get()) + } + + @Test + def testUnscheduleProducerTask(): 
Unit = { + val tmpDir = TestUtils.tempDir() + val logDir = TestUtils.randomPartitionLogDir(tmpDir) + val logConfig = LogConfig(new Properties()) + val brokerTopicStats = new BrokerTopicStats + val maxProducerIdExpirationMs = 60 * 60 * 1000 + val topicPartition = UnifiedLog.parseTopicPartitionName(logDir) + val logDirFailureChannel = new LogDirFailureChannel(10) + val segments = new LogSegments(topicPartition) + val leaderEpochCache = UnifiedLog.maybeCreateLeaderEpochCache(logDir, topicPartition, logDirFailureChannel, logConfig.recordVersion, "") + val producerStateManager = new ProducerStateManager(topicPartition, logDir, maxProducerIdExpirationMs, mockTime) + val offsets = LogLoader.load(LoadLogParams( + logDir, + topicPartition, + logConfig, + scheduler, + mockTime, + logDirFailureChannel, + hadCleanShutdown = true, + segments, + 0L, + 0L, + maxProducerIdExpirationMs, + leaderEpochCache, + producerStateManager)) + val localLog = new LocalLog(logDir, logConfig, segments, offsets.recoveryPoint, + offsets.nextOffsetMetadata, scheduler, mockTime, topicPartition, logDirFailureChannel) + val log = new UnifiedLog(logStartOffset = offsets.logStartOffset, + localLog = localLog, + brokerTopicStats, LogManager.ProducerIdExpirationCheckIntervalMs, + leaderEpochCache, producerStateManager, + _topicId = None, keepPartitionMetadataFile = true) + assertTrue(scheduler.taskRunning(log.producerExpireCheck)) + log.close() + assertFalse(scheduler.taskRunning(log.producerExpireCheck)) + } + + /** + * Verify that scheduler lock is not held when invoking task method, allowing new tasks to be scheduled + * when another is being executed. This is required to avoid deadlocks when: + * a) Thread1 executes a task which attempts to acquire LockA + * b) Thread2 holding LockA attempts to schedule a new task + */ + @Timeout(15) + @Test + def testMockSchedulerLocking(): Unit = { + val initLatch = new CountDownLatch(1) + val completionLatch = new CountDownLatch(2) + val taskLatches = List(new CountDownLatch(1), new CountDownLatch(1)) + def scheduledTask(taskLatch: CountDownLatch): Unit = { + initLatch.countDown() + assertTrue(taskLatch.await(30, TimeUnit.SECONDS), "Timed out waiting for latch") + completionLatch.countDown() + } + mockTime.scheduler.schedule("test1", () => scheduledTask(taskLatches.head), delay=1) + val tickExecutor = Executors.newSingleThreadScheduledExecutor() + try { + tickExecutor.scheduleWithFixedDelay(() => mockTime.sleep(1), 0, 1, TimeUnit.MILLISECONDS) + + // wait for first task to execute and then schedule the next task while the first one is running + assertTrue(initLatch.await(10, TimeUnit.SECONDS)) + mockTime.scheduler.schedule("test2", () => scheduledTask(taskLatches(1)), delay = 1) + + taskLatches.foreach(_.countDown()) + assertTrue(completionLatch.await(10, TimeUnit.SECONDS), "Tasks did not complete") + + } finally { + tickExecutor.shutdownNow() + } + } +} diff --git a/core/src/test/scala/unit/kafka/utils/ShutdownableThreadTest.scala b/core/src/test/scala/unit/kafka/utils/ShutdownableThreadTest.scala new file mode 100644 index 0000000..7cc1e8c --- /dev/null +++ b/core/src/test/scala/unit/kafka/utils/ShutdownableThreadTest.scala @@ -0,0 +1,54 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package kafka.utils + +import java.util.concurrent.{CountDownLatch, TimeUnit} + +import org.apache.kafka.common.internals.FatalExitError +import org.junit.jupiter.api.{AfterEach, Test} +import org.junit.jupiter.api.Assertions.{assertEquals, assertTrue} + +class ShutdownableThreadTest { + + @AfterEach + def tearDown(): Unit = Exit.resetExitProcedure() + + @Test + def testShutdownWhenCalledAfterThreadStart(): Unit = { + @volatile var statusCodeOption: Option[Int] = None + Exit.setExitProcedure { (statusCode, _) => + statusCodeOption = Some(statusCode) + // Sleep until interrupted to emulate the fact that `System.exit()` never returns + Thread.sleep(Long.MaxValue) + throw new AssertionError + } + val latch = new CountDownLatch(1) + val thread = new ShutdownableThread("shutdownable-thread-test") { + override def doWork(): Unit = { + latch.countDown() + throw new FatalExitError + } + } + thread.start() + assertTrue(latch.await(10, TimeUnit.SECONDS), "doWork was not invoked") + + thread.shutdown() + TestUtils.waitUntilTrue(() => statusCodeOption.isDefined, "Status code was not set by exit procedure") + assertEquals(1, statusCodeOption.get) + } + +} diff --git a/core/src/test/scala/unit/kafka/utils/TestUtils.scala b/core/src/test/scala/unit/kafka/utils/TestUtils.scala new file mode 100755 index 0000000..99d7639 --- /dev/null +++ b/core/src/test/scala/unit/kafka/utils/TestUtils.scala @@ -0,0 +1,2173 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package kafka.utils + +import java.io._ +import java.net.InetAddress +import java.nio._ +import java.nio.channels._ +import java.nio.charset.{Charset, StandardCharsets} +import java.nio.file.{Files, StandardOpenOption} +import java.security.cert.X509Certificate +import java.time.Duration +import java.util.concurrent.atomic.{AtomicBoolean, AtomicInteger} +import java.util.concurrent.{Callable, CompletableFuture, ExecutionException, Executors, TimeUnit} +import java.util.{Arrays, Collections, Optional, Properties} + +import com.yammer.metrics.core.Meter +import javax.net.ssl.X509TrustManager +import kafka.api._ +import kafka.cluster.{Broker, EndPoint, IsrChangeListener} +import kafka.controller.{ControllerEventManager, LeaderIsrAndControllerEpoch} +import kafka.log._ +import kafka.metrics.KafkaYammerMetrics +import kafka.network.RequestChannel +import kafka.server._ +import kafka.server.checkpoints.OffsetCheckpointFile +import kafka.server.metadata.{ConfigRepository, MockConfigRepository} +import kafka.utils.Implicits._ +import kafka.zk._ +import org.apache.kafka.clients.CommonClientConfigs +import org.apache.kafka.clients.admin.AlterConfigOp.OpType +import org.apache.kafka.clients.admin._ +import org.apache.kafka.clients.consumer._ +import org.apache.kafka.clients.consumer.internals.AbstractCoordinator +import org.apache.kafka.clients.producer.{KafkaProducer, ProducerConfig, ProducerRecord} +import org.apache.kafka.common.acl.{AccessControlEntry, AccessControlEntryFilter, AclBinding, AclBindingFilter} +import org.apache.kafka.common.config.{ConfigException, ConfigResource} +import org.apache.kafka.common.config.ConfigResource.Type.TOPIC +import org.apache.kafka.common.errors.{KafkaStorageException, OperationNotAttemptedException, TopicExistsException, UnknownTopicOrPartitionException} +import org.apache.kafka.common.header.Header +import org.apache.kafka.common.internals.Topic +import org.apache.kafka.common.memory.MemoryPool +import org.apache.kafka.common.message.UpdateMetadataRequestData.UpdateMetadataPartitionState +import org.apache.kafka.common.metrics.Metrics +import org.apache.kafka.common.network.{ClientInformation, ListenerName, Mode} +import org.apache.kafka.common.protocol.{ApiKeys, Errors} +import org.apache.kafka.common.quota.{ClientQuotaAlteration, ClientQuotaEntity} +import org.apache.kafka.common.record._ +import org.apache.kafka.common.requests.{AbstractRequest, EnvelopeRequest, RequestContext, RequestHeader} +import org.apache.kafka.common.resource.ResourcePattern +import org.apache.kafka.common.security.auth.{KafkaPrincipal, KafkaPrincipalSerde, SecurityProtocol} +import org.apache.kafka.common.serialization.{ByteArrayDeserializer, ByteArraySerializer, Deserializer, IntegerSerializer, Serializer} +import org.apache.kafka.common.utils.Utils._ +import org.apache.kafka.common.utils.{Time, Utils} +import org.apache.kafka.common.{KafkaFuture, TopicPartition} +import org.apache.kafka.controller.QuorumController +import org.apache.kafka.server.authorizer.{Authorizer => JAuthorizer} +import org.apache.kafka.test.{TestSslUtils, TestUtils => JTestUtils} +import org.apache.zookeeper.KeeperException.SessionExpiredException +import org.apache.zookeeper.ZooDefs._ +import org.apache.zookeeper.data.ACL +import org.junit.jupiter.api.Assertions._ +import org.mockito.Mockito + +import scala.annotation.nowarn +import scala.collection.mutable.{ArrayBuffer, ListBuffer} +import scala.collection.{Map, Seq, mutable} +import scala.concurrent.duration.FiniteDuration +import 
scala.concurrent.{Await, ExecutionContext, Future}
+import scala.jdk.CollectionConverters._
+import scala.util.{Failure, Success, Try}
+
+/**
+ * Utility functions to help with testing
+ */
+object TestUtils extends Logging {
+
+  val random = JTestUtils.RANDOM
+
+  /* 0 gives a random port; you can then retrieve the assigned port from the Socket object. */
+  val RandomPort = 0
+
+  /* Incorrect broker port which can be used by kafka clients in tests. This port should not be used
+     by any other service and hence we use a reserved port. */
+  val IncorrectBrokerPort = 225
+
+  /** Port to use for unit tests that mock/don't require a real ZK server. */
+  val MockZkPort = 1
+  /** ZooKeeper connection string to use for unit tests that mock/don't require a real ZK server. */
+  val MockZkConnect = "127.0.0.1:" + MockZkPort
+  // CN in SSL certificates - this is used for endpoint validation when enabled
+  val SslCertificateCn = "localhost"
+
+  private val transactionStatusKey = "transactionStatus"
+  private val committedValue : Array[Byte] = "committed".getBytes(StandardCharsets.UTF_8)
+  private val abortedValue : Array[Byte] = "aborted".getBytes(StandardCharsets.UTF_8)
+
+  sealed trait LogDirFailureType
+  case object Roll extends LogDirFailureType
+  case object Checkpoint extends LogDirFailureType
+
+  /**
+   * Create a temporary directory
+   */
+  def tempDir(): File = JTestUtils.tempDirectory()
+
+  def tempTopic(): String = "testTopic" + random.nextInt(1000000)
+
+  /**
+   * Create a temporary relative directory
+   */
+  def tempRelativeDir(parent: String): File = {
+    val parentFile = new File(parent)
+    parentFile.mkdirs()
+
+    JTestUtils.tempDirectory(parentFile.toPath, null)
+  }
+
+  /**
+   * Create a random log directory in the format <string>-<int> used for Kafka partition logs.
+   * It is the responsibility of the caller to set up a shutdown hook for deletion of the directory. 
+ */ + def randomPartitionLogDir(parentDir: File): File = { + val attempts = 1000 + val f = Iterator.continually(new File(parentDir, "kafka-" + random.nextInt(1000000))) + .take(attempts).find(_.mkdir()) + .getOrElse(sys.error(s"Failed to create directory after $attempts attempts")) + f.deleteOnExit() + f + } + + /** + * Create a temporary file + */ + def tempFile(): File = JTestUtils.tempFile() + + /** + * Create a temporary file and return an open file channel for this file + */ + def tempChannel(): FileChannel = + FileChannel.open(tempFile().toPath, StandardOpenOption.READ, StandardOpenOption.WRITE) + + /** + * Create a kafka server instance with appropriate test settings + * USING THIS IS A SIGN YOU ARE NOT WRITING A REAL UNIT TEST + * + * @param config The configuration of the server + */ + def createServer(config: KafkaConfig, time: Time = Time.SYSTEM): KafkaServer = { + createServer(config, time, None) + } + + def createServer(config: KafkaConfig, threadNamePrefix: Option[String]): KafkaServer = { + createServer(config, Time.SYSTEM, threadNamePrefix) + } + + def createServer(config: KafkaConfig, time: Time, threadNamePrefix: Option[String]): KafkaServer = { + createServer(config, time, threadNamePrefix, enableForwarding = false) + } + + def createServer(config: KafkaConfig, time: Time, threadNamePrefix: Option[String], enableForwarding: Boolean): KafkaServer = { + val server = new KafkaServer(config, time, threadNamePrefix, enableForwarding) + server.startup() + server + } + + def boundPort(broker: KafkaBroker, securityProtocol: SecurityProtocol = SecurityProtocol.PLAINTEXT): Int = + broker.boundPort(ListenerName.forSecurityProtocol(securityProtocol)) + + def createBrokerAndEpoch(id: Int, host: String, port: Int, securityProtocol: SecurityProtocol = SecurityProtocol.PLAINTEXT, + epoch: Long = 0): (Broker, Long) = { + (new Broker(id, host, port, ListenerName.forSecurityProtocol(securityProtocol), securityProtocol), epoch) + } + + /** + * Create a test config for the provided parameters. + * + * Note that if `interBrokerSecurityProtocol` is defined, the listener for the `SecurityProtocol` will be enabled. 
+ */ + def createBrokerConfigs( + numConfigs: Int, + zkConnect: String, + enableControlledShutdown: Boolean = true, + enableDeleteTopic: Boolean = true, + interBrokerSecurityProtocol: Option[SecurityProtocol] = None, + trustStoreFile: Option[File] = None, + saslProperties: Option[Properties] = None, + enablePlaintext: Boolean = true, + enableSsl: Boolean = false, + enableSaslPlaintext: Boolean = false, + enableSaslSsl: Boolean = false, + rackInfo: Map[Int, String] = Map(), + logDirCount: Int = 1, + enableToken: Boolean = false, + numPartitions: Int = 1, + defaultReplicationFactor: Short = 1, + startingIdNumber: Int = 0 + ): Seq[Properties] = { + val endingIdNumber = startingIdNumber + numConfigs - 1 + (startingIdNumber to endingIdNumber).map { node => + createBrokerConfig(node, zkConnect, enableControlledShutdown, enableDeleteTopic, RandomPort, + interBrokerSecurityProtocol, trustStoreFile, saslProperties, enablePlaintext = enablePlaintext, enableSsl = enableSsl, + enableSaslPlaintext = enableSaslPlaintext, enableSaslSsl = enableSaslSsl, rack = rackInfo.get(node), logDirCount = logDirCount, enableToken = enableToken, + numPartitions = numPartitions, defaultReplicationFactor = defaultReplicationFactor) + } + } + + def getBrokerListStrFromServers[B <: KafkaBroker]( + brokers: Seq[B], + protocol: SecurityProtocol = SecurityProtocol.PLAINTEXT): String = { + brokers.map { s => + val listener = s.config.effectiveAdvertisedListeners.find(_.securityProtocol == protocol).getOrElse( + sys.error(s"Could not find listener with security protocol $protocol")) + formatAddress(listener.host, boundPort(s, protocol)) + }.mkString(",") + } + + def bootstrapServers[B <: KafkaBroker](brokers: Seq[B], listenerName: ListenerName): String = { + brokers.map { s => + val listener = s.config.effectiveAdvertisedListeners.find(_.listenerName == listenerName).getOrElse( + sys.error(s"Could not find listener with name ${listenerName.value}")) + formatAddress(listener.host, s.boundPort(listenerName)) + }.mkString(",") + } + + /** + * Shutdown `servers` and delete their log directories. + */ + def shutdownServers[B <: KafkaBroker](brokers: Seq[B]): Unit = { + import ExecutionContext.Implicits._ + val future = Future.traverse(brokers) { s => + Future { + s.shutdown() + CoreUtils.delete(s.config.logDirs) + } + } + Await.result(future, FiniteDuration(5, TimeUnit.MINUTES)) + } + + /** + * Create a test config for the provided parameters. + * + * Note that if `interBrokerSecurityProtocol` is defined, the listener for the `SecurityProtocol` will be enabled. 
+ */ + def createBrokerConfig(nodeId: Int, + zkConnect: String, + enableControlledShutdown: Boolean = true, + enableDeleteTopic: Boolean = true, + port: Int = RandomPort, + interBrokerSecurityProtocol: Option[SecurityProtocol] = None, + trustStoreFile: Option[File] = None, + saslProperties: Option[Properties] = None, + enablePlaintext: Boolean = true, + enableSaslPlaintext: Boolean = false, + saslPlaintextPort: Int = RandomPort, + enableSsl: Boolean = false, + sslPort: Int = RandomPort, + enableSaslSsl: Boolean = false, + saslSslPort: Int = RandomPort, + rack: Option[String] = None, + logDirCount: Int = 1, + enableToken: Boolean = false, + numPartitions: Int = 1, + defaultReplicationFactor: Short = 1): Properties = { + def shouldEnable(protocol: SecurityProtocol) = interBrokerSecurityProtocol.fold(false)(_ == protocol) + + val protocolAndPorts = ArrayBuffer[(SecurityProtocol, Int)]() + if (enablePlaintext || shouldEnable(SecurityProtocol.PLAINTEXT)) + protocolAndPorts += SecurityProtocol.PLAINTEXT -> port + if (enableSsl || shouldEnable(SecurityProtocol.SSL)) + protocolAndPorts += SecurityProtocol.SSL -> sslPort + if (enableSaslPlaintext || shouldEnable(SecurityProtocol.SASL_PLAINTEXT)) + protocolAndPorts += SecurityProtocol.SASL_PLAINTEXT -> saslPlaintextPort + if (enableSaslSsl || shouldEnable(SecurityProtocol.SASL_SSL)) + protocolAndPorts += SecurityProtocol.SASL_SSL -> saslSslPort + + val listeners = protocolAndPorts.map { case (protocol, port) => + s"${protocol.name}://localhost:$port" + }.mkString(",") + + val props = new Properties + if (zkConnect == null) { + props.put(KafkaConfig.NodeIdProp, nodeId.toString) + props.put(KafkaConfig.BrokerIdProp, nodeId.toString) + props.put(KafkaConfig.AdvertisedListenersProp, listeners) + props.put(KafkaConfig.ListenersProp, listeners) + props.put(KafkaConfig.ControllerListenerNamesProp, "CONTROLLER") + props.put(KafkaConfig.ListenerSecurityProtocolMapProp, protocolAndPorts. + map(p => "%s:%s".format(p._1, p._1)).mkString(",") + ",CONTROLLER:PLAINTEXT") + } else { + if (nodeId >= 0) props.put(KafkaConfig.BrokerIdProp, nodeId.toString) + props.put(KafkaConfig.ListenersProp, listeners) + } + if (logDirCount > 1) { + val logDirs = (1 to logDirCount).toList.map(i => + // We would like to allow user to specify both relative path and absolute path as log directory for backward-compatibility reason + // We can verify this by using a mixture of relative path and absolute path as log directories in the test + if (i % 2 == 0) tempDir().getAbsolutePath else tempRelativeDir("data") + ).mkString(",") + props.put(KafkaConfig.LogDirsProp, logDirs) + } else { + props.put(KafkaConfig.LogDirProp, tempDir().getAbsolutePath) + } + if (zkConnect == null) { + props.put(KafkaConfig.ProcessRolesProp, "broker") + // Note: this is just a placeholder value for controller.quorum.voters. JUnit + // tests use random port assignment, so the controller ports are not known ahead of + // time. Therefore, we ignore controller.quorum.voters and use + // controllerQuorumVotersFuture instead. 
+ props.put(KafkaConfig.QuorumVotersProp, "1000@localhost:0") + } else { + props.put(KafkaConfig.ZkConnectProp, zkConnect) + props.put(KafkaConfig.ZkConnectionTimeoutMsProp, "10000") + } + props.put(KafkaConfig.ReplicaSocketTimeoutMsProp, "1500") + props.put(KafkaConfig.ControllerSocketTimeoutMsProp, "1500") + props.put(KafkaConfig.ControlledShutdownEnableProp, enableControlledShutdown.toString) + props.put(KafkaConfig.DeleteTopicEnableProp, enableDeleteTopic.toString) + props.put(KafkaConfig.LogDeleteDelayMsProp, "1000") + props.put(KafkaConfig.ControlledShutdownRetryBackoffMsProp, "100") + props.put(KafkaConfig.LogCleanerDedupeBufferSizeProp, "2097152") + props.put(KafkaConfig.LogMessageTimestampDifferenceMaxMsProp, Long.MaxValue.toString) + props.put(KafkaConfig.OffsetsTopicReplicationFactorProp, "1") + if (!props.containsKey(KafkaConfig.OffsetsTopicPartitionsProp)) + props.put(KafkaConfig.OffsetsTopicPartitionsProp, "5") + if (!props.containsKey(KafkaConfig.GroupInitialRebalanceDelayMsProp)) + props.put(KafkaConfig.GroupInitialRebalanceDelayMsProp, "0") + rack.foreach(props.put(KafkaConfig.RackProp, _)) + + if (protocolAndPorts.exists { case (protocol, _) => usesSslTransportLayer(protocol) }) + props ++= sslConfigs(Mode.SERVER, false, trustStoreFile, s"server$nodeId") + + if (protocolAndPorts.exists { case (protocol, _) => usesSaslAuthentication(protocol) }) + props ++= JaasTestUtils.saslConfigs(saslProperties) + + interBrokerSecurityProtocol.foreach { protocol => + props.put(KafkaConfig.InterBrokerSecurityProtocolProp, protocol.name) + } + + if (enableToken) + props.put(KafkaConfig.DelegationTokenSecretKeyProp, "secretkey") + + props.put(KafkaConfig.NumPartitionsProp, numPartitions.toString) + props.put(KafkaConfig.DefaultReplicationFactorProp, defaultReplicationFactor.toString) + + props + } + + @nowarn("cat=deprecation") + def setIbpAndMessageFormatVersions(config: Properties, version: ApiVersion): Unit = { + config.setProperty(KafkaConfig.InterBrokerProtocolVersionProp, version.version) + // for clarity, only set the log message format version if it's not ignored + if (!LogConfig.shouldIgnoreMessageFormatVersion(version)) + config.setProperty(KafkaConfig.LogMessageFormatVersionProp, version.version) + } + + def createAdminClient[B <: KafkaBroker]( + brokers: Seq[B], + adminConfig: Properties): Admin = { + val adminClientProperties = new Properties(adminConfig) + if (!adminClientProperties.containsKey(AdminClientConfig.BOOTSTRAP_SERVERS_CONFIG)) { + adminClientProperties.put(AdminClientConfig.BOOTSTRAP_SERVERS_CONFIG, + getBrokerListStrFromServers(brokers)) + } + Admin.create(adminClientProperties) + } + + def createTopicWithAdmin[B <: KafkaBroker]( + topic: String, + numPartitions: Int = 1, + replicationFactor: Int = 1, + brokers: Seq[B], + topicConfig: Properties = new Properties, + adminConfig: Properties = new Properties): scala.collection.immutable.Map[Int, Int] = { + val adminClient = createAdminClient(brokers, adminConfig) + try { + val configsMap = new java.util.HashMap[String, String]() + topicConfig.forEach((k, v) => configsMap.put(k.toString, v.toString)) + try { + adminClient.createTopics(Collections.singletonList(new NewTopic( + topic, numPartitions, replicationFactor.toShort).configs(configsMap))).all().get() + } catch { + case e: ExecutionException => if (!(e.getCause != null && + e.getCause.isInstanceOf[TopicExistsException] && + topicHasSameNumPartitionsAndReplicationFactor(adminClient, topic, numPartitions, replicationFactor))) { + throw e + } + } + } finally { + 
adminClient.close() + } + // wait until we've propagated all partitions metadata to all brokers + val allPartitionsMetadata = waitForAllPartitionsMetadata(brokers, topic, numPartitions) + + (0 until numPartitions).map { i => + i -> allPartitionsMetadata.get(new TopicPartition(topic, i)).map(_.leader()).getOrElse( + throw new IllegalStateException(s"Cannot get the partition leader for topic: $topic, partition: $i in server metadata cache")) + }.toMap + } + + def topicHasSameNumPartitionsAndReplicationFactor(adminClient: Admin, + topic: String, + numPartitions: Int, + replicationFactor: Int): Boolean = { + val describedTopics = adminClient.describeTopics(Collections. + singleton(topic)).allTopicNames().get() + val description = describedTopics.get(topic) + (description != null && + description.partitions().size() == numPartitions && + description.partitions().iterator().next().replicas().size() == replicationFactor) + } + + def createOffsetsTopicWithAdmin[B <: KafkaBroker]( + brokers: Seq[B], + adminConfig: Properties = new Properties) = { + val broker = brokers.head + createTopicWithAdmin(topic = Topic.GROUP_METADATA_TOPIC_NAME, + numPartitions = broker.config.getInt(KafkaConfig.OffsetsTopicPartitionsProp), + replicationFactor = broker.config.getShort(KafkaConfig.OffsetsTopicReplicationFactorProp).toInt, + brokers = brokers, + topicConfig = broker.groupCoordinator.offsetsTopicConfigs, + adminConfig = adminConfig) + } + + def deleteTopicWithAdmin[B <: KafkaBroker]( + topic: String, + brokers: Seq[B], + adminConfig: Properties = new Properties): Unit = { + val adminClient = createAdminClient(brokers, adminConfig) + try { + adminClient.deleteTopics(Collections.singletonList(topic)).all().get() + } catch { + case e: ExecutionException => if (e.getCause != null && + e.getCause.isInstanceOf[UnknownTopicOrPartitionException]) { + // ignore + } else { + throw e + } + } finally { + adminClient.close() + } + waitForAllPartitionsMetadata(brokers, topic, 0) + } + + /** + * Create a topic in ZooKeeper. + * Wait until the leader is elected and the metadata is propagated to all brokers. + * Return the leader for each partition. + */ + def createTopic(zkClient: KafkaZkClient, + topic: String, + numPartitions: Int = 1, + replicationFactor: Int = 1, + servers: Seq[KafkaServer], + topicConfig: Properties = new Properties): scala.collection.immutable.Map[Int, Int] = { + val adminZkClient = new AdminZkClient(zkClient) + // create topic + waitUntilTrue( () => { + var hasSessionExpirationException = false + try { + adminZkClient.createTopic(topic, numPartitions, replicationFactor, topicConfig) + } catch { + case _: SessionExpiredException => hasSessionExpirationException = true + case e: Throwable => throw e // let other exceptions propagate + } + !hasSessionExpirationException}, + s"Can't create topic $topic") + + // wait until we've propagated all partitions metadata to all servers + val allPartitionsMetadata = waitForAllPartitionsMetadata(servers, topic, numPartitions) + + (0 until numPartitions).map { i => + i -> allPartitionsMetadata.get(new TopicPartition(topic, i)).map(_.leader()).getOrElse( + throw new IllegalStateException(s"Cannot get the partition leader for topic: $topic, partition: $i in server metadata cache")) + }.toMap + } + + /** + * Create a topic in ZooKeeper using a customized replica assignment. + * Wait until the leader is elected and the metadata is propagated to all brokers. + * Return the leader for each partition. 
+ */ + def createTopic(zkClient: KafkaZkClient, + topic: String, + partitionReplicaAssignment: collection.Map[Int, Seq[Int]], + servers: Seq[KafkaServer]): scala.collection.immutable.Map[Int, Int] = { + createTopic(zkClient, topic, partitionReplicaAssignment, servers, new Properties()) + } + + /** + * Create a topic in ZooKeeper using a customized replica assignment. + * Wait until the leader is elected and the metadata is propagated to all brokers. + * Return the leader for each partition. + */ + def createTopic(zkClient: KafkaZkClient, + topic: String, + partitionReplicaAssignment: collection.Map[Int, Seq[Int]], + servers: Seq[KafkaServer], + topicConfig: Properties): scala.collection.immutable.Map[Int, Int] = { + val adminZkClient = new AdminZkClient(zkClient) + // create topic + waitUntilTrue( () => { + var hasSessionExpirationException = false + try { + adminZkClient.createTopicWithAssignment(topic, topicConfig, partitionReplicaAssignment) + } catch { + case _: SessionExpiredException => hasSessionExpirationException = true + case e: Throwable => throw e // let other exceptions propagate + } + !hasSessionExpirationException}, + s"Can't create topic $topic") + + // wait until we've propagated all partitions metadata to all servers + val allPartitionsMetadata = waitForAllPartitionsMetadata(servers, topic, partitionReplicaAssignment.size) + + partitionReplicaAssignment.keySet.map { i => + i -> allPartitionsMetadata.get(new TopicPartition(topic, i)).map(_.leader()).getOrElse( + throw new IllegalStateException(s"Cannot get the partition leader for topic: $topic, partition: $i in server metadata cache")) + }.toMap + } + + /** + * Create the consumer offsets/group metadata topic and wait until the leader is elected and metadata is propagated + * to all brokers. + */ + def createOffsetsTopic(zkClient: KafkaZkClient, servers: Seq[KafkaServer]): Unit = { + val server = servers.head + createTopic(zkClient, Topic.GROUP_METADATA_TOPIC_NAME, + server.config.getInt(KafkaConfig.OffsetsTopicPartitionsProp), + server.config.getShort(KafkaConfig.OffsetsTopicReplicationFactorProp).toInt, + servers, + server.groupCoordinator.offsetsTopicConfigs) + } + + /** + * Wrap a single record log buffer. 
+ */ + def singletonRecords(value: Array[Byte], + key: Array[Byte] = null, + codec: CompressionType = CompressionType.NONE, + timestamp: Long = RecordBatch.NO_TIMESTAMP, + magicValue: Byte = RecordBatch.CURRENT_MAGIC_VALUE): MemoryRecords = { + records(Seq(new SimpleRecord(timestamp, key, value)), magicValue = magicValue, codec = codec) + } + + def recordsWithValues(magicValue: Byte, + codec: CompressionType, + values: Array[Byte]*): MemoryRecords = { + records(values.map(value => new SimpleRecord(value)), magicValue, codec) + } + + def records(records: Iterable[SimpleRecord], + magicValue: Byte = RecordBatch.CURRENT_MAGIC_VALUE, + codec: CompressionType = CompressionType.NONE, + producerId: Long = RecordBatch.NO_PRODUCER_ID, + producerEpoch: Short = RecordBatch.NO_PRODUCER_EPOCH, + sequence: Int = RecordBatch.NO_SEQUENCE, + baseOffset: Long = 0L, + partitionLeaderEpoch: Int = RecordBatch.NO_PARTITION_LEADER_EPOCH): MemoryRecords = { + val buf = ByteBuffer.allocate(DefaultRecordBatch.sizeInBytes(records.asJava)) + val builder = MemoryRecords.builder(buf, magicValue, codec, TimestampType.CREATE_TIME, baseOffset, + System.currentTimeMillis, producerId, producerEpoch, sequence, false, partitionLeaderEpoch) + records.foreach(builder.append) + builder.build() + } + + /** + * Generate an array of random bytes + * + * @param numBytes The size of the array + */ + def randomBytes(numBytes: Int): Array[Byte] = JTestUtils.randomBytes(numBytes) + + /** + * Generate a random string of letters and digits of the given length + * + * @param len The length of the string + * @return The random string + */ + def randomString(len: Int): String = JTestUtils.randomString(len) + + /** + * Check that the buffer content from buffer.position() to buffer.limit() is equal + */ + def checkEquals(b1: ByteBuffer, b2: ByteBuffer): Unit = { + assertEquals(b1.limit() - b1.position(), b2.limit() - b2.position(), "Buffers should have equal length") + for(i <- 0 until b1.limit() - b1.position()) + assertEquals(b1.get(b1.position() + i), b2.get(b1.position() + i), "byte " + i + " byte not equal.") + } + + /** + * Throw an exception if an iterable has different length than expected + * + */ + def checkLength[T](s1: Iterator[T], expectedLength:Int): Unit = { + var n = 0 + while (s1.hasNext) { + n += 1 + s1.next() + } + assertEquals(expectedLength, n) + } + + /** + * Throw an exception if the two iterators are of differing lengths or contain + * different messages on their Nth element + */ + def checkEquals[T](s1: java.util.Iterator[T], s2: java.util.Iterator[T]): Unit = { + while(s1.hasNext && s2.hasNext) + assertEquals(s1.next, s2.next) + assertFalse(s1.hasNext, "Iterators have uneven length--first has more") + assertFalse(s2.hasNext, "Iterators have uneven length--second has more") + } + + def stackedIterator[T](s: Iterator[T]*): Iterator[T] = { + new Iterator[T] { + var cur: Iterator[T] = null + val topIterator = s.iterator + + def hasNext: Boolean = { + while (true) { + if (cur == null) { + if (topIterator.hasNext) + cur = topIterator.next() + else + return false + } + if (cur.hasNext) + return true + cur = null + } + // should never reach here + throw new RuntimeException("should not reach here") + } + + def next() : T = cur.next() + } + } + + /** + * Create a hexadecimal string for the given bytes + */ + def hexString(bytes: Array[Byte]): String = hexString(ByteBuffer.wrap(bytes)) + + /** + * Create a hexadecimal string for the given bytes + */ + def hexString(buffer: ByteBuffer): String = { + val builder = new 
StringBuilder("0x") + for(i <- 0 until buffer.limit()) + builder.append(String.format("%x", Integer.valueOf(buffer.get(buffer.position() + i)))) + builder.toString + } + + /** + * Returns security configuration options for broker or clients + * + * @param mode Client or server mode + * @param securityProtocol Security protocol which indicates if SASL or SSL or both configs are included + * @param trustStoreFile Trust store file must be provided for SSL and SASL_SSL + * @param certAlias Alias of certificate in SSL key store + * @param certCn CN for certificate + * @param saslProperties SASL configs if security protocol is SASL_SSL or SASL_PLAINTEXT + * @param tlsProtocol TLS version + * @param needsClientCert If not empty, a flag which indicates if client certificates are required. By default + * client certificates are generated only if securityProtocol is SSL (not for SASL_SSL). + */ + def securityConfigs(mode: Mode, + securityProtocol: SecurityProtocol, + trustStoreFile: Option[File], + certAlias: String, + certCn: String, + saslProperties: Option[Properties], + tlsProtocol: String = TestSslUtils.DEFAULT_TLS_PROTOCOL_FOR_TESTS, + needsClientCert: Option[Boolean] = None): Properties = { + val props = new Properties + if (usesSslTransportLayer(securityProtocol)) { + val addClientCert = needsClientCert.getOrElse(securityProtocol == SecurityProtocol.SSL) + props ++= sslConfigs(mode, addClientCert, trustStoreFile, certAlias, certCn, tlsProtocol) + } + + if (usesSaslAuthentication(securityProtocol)) + props ++= JaasTestUtils.saslConfigs(saslProperties) + props.put(CommonClientConfigs.SECURITY_PROTOCOL_CONFIG, securityProtocol.name) + props + } + + def producerSecurityConfigs(securityProtocol: SecurityProtocol, + trustStoreFile: Option[File], + saslProperties: Option[Properties]): Properties = + securityConfigs(Mode.CLIENT, securityProtocol, trustStoreFile, "producer", SslCertificateCn, saslProperties) + + /** + * Create a (new) producer with a few pre-configured properties. 
+ */ + def createProducer[K, V](brokerList: String, + acks: Int = -1, + maxBlockMs: Long = 60 * 1000L, + bufferSize: Long = 1024L * 1024L, + retries: Int = Int.MaxValue, + deliveryTimeoutMs: Int = 30 * 1000, + lingerMs: Int = 0, + batchSize: Int = 16384, + compressionType: String = "none", + requestTimeoutMs: Int = 20 * 1000, + securityProtocol: SecurityProtocol = SecurityProtocol.PLAINTEXT, + trustStoreFile: Option[File] = None, + saslProperties: Option[Properties] = None, + keySerializer: Serializer[K] = new ByteArraySerializer, + valueSerializer: Serializer[V] = new ByteArraySerializer, + enableIdempotence: Boolean = false): KafkaProducer[K, V] = { + val producerProps = new Properties + producerProps.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, brokerList) + producerProps.put(ProducerConfig.ACKS_CONFIG, acks.toString) + producerProps.put(ProducerConfig.MAX_BLOCK_MS_CONFIG, maxBlockMs.toString) + producerProps.put(ProducerConfig.BUFFER_MEMORY_CONFIG, bufferSize.toString) + producerProps.put(ProducerConfig.RETRIES_CONFIG, retries.toString) + producerProps.put(ProducerConfig.DELIVERY_TIMEOUT_MS_CONFIG, deliveryTimeoutMs.toString) + producerProps.put(ProducerConfig.REQUEST_TIMEOUT_MS_CONFIG, requestTimeoutMs.toString) + producerProps.put(ProducerConfig.LINGER_MS_CONFIG, lingerMs.toString) + producerProps.put(ProducerConfig.BATCH_SIZE_CONFIG, batchSize.toString) + producerProps.put(ProducerConfig.COMPRESSION_TYPE_CONFIG, compressionType) + producerProps.put(ProducerConfig.ENABLE_IDEMPOTENCE_CONFIG, enableIdempotence.toString) + producerProps ++= producerSecurityConfigs(securityProtocol, trustStoreFile, saslProperties) + new KafkaProducer[K, V](producerProps, keySerializer, valueSerializer) + } + + def usesSslTransportLayer(securityProtocol: SecurityProtocol): Boolean = securityProtocol match { + case SecurityProtocol.SSL | SecurityProtocol.SASL_SSL => true + case _ => false + } + + def usesSaslAuthentication(securityProtocol: SecurityProtocol): Boolean = securityProtocol match { + case SecurityProtocol.SASL_PLAINTEXT | SecurityProtocol.SASL_SSL => true + case _ => false + } + + def consumerSecurityConfigs(securityProtocol: SecurityProtocol, trustStoreFile: Option[File], saslProperties: Option[Properties]): Properties = + securityConfigs(Mode.CLIENT, securityProtocol, trustStoreFile, "consumer", SslCertificateCn, saslProperties) + + def adminClientSecurityConfigs(securityProtocol: SecurityProtocol, trustStoreFile: Option[File], saslProperties: Option[Properties]): Properties = + securityConfigs(Mode.CLIENT, securityProtocol, trustStoreFile, "admin-client", SslCertificateCn, saslProperties) + + /** + * Create a consumer with a few pre-configured properties. 
+ */ + def createConsumer[K, V](brokerList: String, + groupId: String = "group", + autoOffsetReset: String = "earliest", + enableAutoCommit: Boolean = true, + readCommitted: Boolean = false, + maxPollRecords: Int = 500, + securityProtocol: SecurityProtocol = SecurityProtocol.PLAINTEXT, + trustStoreFile: Option[File] = None, + saslProperties: Option[Properties] = None, + keyDeserializer: Deserializer[K] = new ByteArrayDeserializer, + valueDeserializer: Deserializer[V] = new ByteArrayDeserializer): KafkaConsumer[K, V] = { + val consumerProps = new Properties + consumerProps.put(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, brokerList) + consumerProps.put(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, autoOffsetReset) + consumerProps.put(ConsumerConfig.GROUP_ID_CONFIG, groupId) + consumerProps.put(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG, enableAutoCommit.toString) + consumerProps.put(ConsumerConfig.MAX_POLL_RECORDS_CONFIG, maxPollRecords.toString) + consumerProps.put(ConsumerConfig.ISOLATION_LEVEL_CONFIG, if (readCommitted) "read_committed" else "read_uncommitted") + consumerProps ++= consumerSecurityConfigs(securityProtocol, trustStoreFile, saslProperties) + new KafkaConsumer[K, V](consumerProps, keyDeserializer, valueDeserializer) + } + + def createBrokersInZk(zkClient: KafkaZkClient, ids: Seq[Int]): Seq[Broker] = + createBrokersInZk(ids.map(kafka.admin.BrokerMetadata(_, None)), zkClient) + + def createBrokersInZk(brokerMetadatas: Seq[kafka.admin.BrokerMetadata], zkClient: KafkaZkClient): Seq[Broker] = { + zkClient.makeSurePersistentPathExists(BrokerIdsZNode.path) + val brokers = brokerMetadatas.map { b => + val protocol = SecurityProtocol.PLAINTEXT + val listenerName = ListenerName.forSecurityProtocol(protocol) + Broker(b.id, Seq(EndPoint("localhost", 6667, listenerName, protocol)), b.rack) + } + brokers.foreach(b => zkClient.registerBroker(BrokerInfo(Broker(b.id, b.endPoints, rack = b.rack), + ApiVersion.latestVersion, jmxPort = -1))) + brokers + } + + def getMsgStrings(n: Int): Seq[String] = { + val buffer = new ListBuffer[String] + for (i <- 0 until n) + buffer += ("msg" + i) + buffer + } + + def makeLeaderForPartition(zkClient: KafkaZkClient, + topic: String, + leaderPerPartitionMap: scala.collection.immutable.Map[Int, Int], + controllerEpoch: Int): Unit = { + val newLeaderIsrAndControllerEpochs = leaderPerPartitionMap.map { case (partition, leader) => + val topicPartition = new TopicPartition(topic, partition) + val newLeaderAndIsr = zkClient.getTopicPartitionState(topicPartition) + .map(_.leaderAndIsr.newLeader(leader)) + .getOrElse(LeaderAndIsr(leader, List(leader))) + topicPartition -> LeaderIsrAndControllerEpoch(newLeaderAndIsr, controllerEpoch) + } + zkClient.setTopicPartitionStatesRaw(newLeaderIsrAndControllerEpochs, ZkVersion.MatchAnyVersion) + } + + /** + * If neither oldLeaderOpt nor newLeaderOpt is defined, wait until the leader of a partition is elected. + * If oldLeaderOpt is defined, it waits until the new leader is different from the old leader. + * If newLeaderOpt is defined, it waits until the new leader becomes the expected new leader. + * + * @return The new leader (note that negative values are used to indicate conditions like NoLeader and + * LeaderDuringDelete). + * @throws AssertionError if the expected condition is not true within the timeout. 
+ */ + def waitUntilLeaderIsElectedOrChanged(zkClient: KafkaZkClient, topic: String, partition: Int, timeoutMs: Long = 30000L, + oldLeaderOpt: Option[Int] = None, newLeaderOpt: Option[Int] = None): Int = { + require(!(oldLeaderOpt.isDefined && newLeaderOpt.isDefined), "Can't define both the old and the new leader") + val startTime = System.currentTimeMillis() + val topicPartition = new TopicPartition(topic, partition) + + trace(s"Waiting for leader to be elected or changed for partition $topicPartition, old leader is $oldLeaderOpt, " + + s"new leader is $newLeaderOpt") + + var leader: Option[Int] = None + var electedOrChangedLeader: Option[Int] = None + while (electedOrChangedLeader.isEmpty && System.currentTimeMillis() < startTime + timeoutMs) { + // check if leader is elected + leader = zkClient.getLeaderForPartition(topicPartition) + leader match { + case Some(l) => (newLeaderOpt, oldLeaderOpt) match { + case (Some(newLeader), _) if newLeader == l => + trace(s"Expected new leader $l is elected for partition $topicPartition") + electedOrChangedLeader = leader + case (_, Some(oldLeader)) if oldLeader != l => + trace(s"Leader for partition $topicPartition is changed from $oldLeader to $l") + electedOrChangedLeader = leader + case (None, None) => + trace(s"Leader $l is elected for partition $topicPartition") + electedOrChangedLeader = leader + case _ => + trace(s"Current leader for partition $topicPartition is $l") + } + case None => + trace(s"Leader for partition $topicPartition is not elected yet") + } + Thread.sleep(math.min(timeoutMs, 100L)) + } + electedOrChangedLeader.getOrElse { + val errorMessage = (newLeaderOpt, oldLeaderOpt) match { + case (Some(newLeader), _) => + s"Timing out after $timeoutMs ms since expected new leader $newLeader was not elected for partition $topicPartition, leader is $leader" + case (_, Some(oldLeader)) => + s"Timing out after $timeoutMs ms since a new leader that is different from $oldLeader was not elected for partition $topicPartition, " + + s"leader is $leader" + case _ => + s"Timing out after $timeoutMs ms since a leader was not elected for partition $topicPartition" + } + throw new AssertionError(errorMessage) + } + } + + /** + * Execute the given block. If it throws an assert error, retry. 
Repeat + * until no error is thrown or the time limit elapses + */ + def retry(maxWaitMs: Long)(block: => Unit): Unit = { + var wait = 1L + val startTime = System.currentTimeMillis() + while(true) { + try { + block + return + } catch { + case e: AssertionError => + val elapsed = System.currentTimeMillis - startTime + if (elapsed > maxWaitMs) { + throw e + } else { + info("Attempt failed, sleeping for " + wait + ", and then retrying.") + Thread.sleep(wait) + wait += math.min(wait, 1000) + } + } + } + } + + def pollUntilTrue(consumer: Consumer[_, _], + action: () => Boolean, + msg: => String, + waitTimeMs: Long = JTestUtils.DEFAULT_MAX_WAIT_MS): Unit = { + waitUntilTrue(() => { + consumer.poll(Duration.ofMillis(100)) + action() + }, msg = msg, pause = 0L, waitTimeMs = waitTimeMs) + } + + def pollRecordsUntilTrue[K, V](consumer: Consumer[K, V], + action: ConsumerRecords[K, V] => Boolean, + msg: => String, + waitTimeMs: Long = JTestUtils.DEFAULT_MAX_WAIT_MS): Unit = { + waitUntilTrue(() => { + val records = consumer.poll(Duration.ofMillis(100)) + action(records) + }, msg = msg, pause = 0L, waitTimeMs = waitTimeMs) + } + + def subscribeAndWaitForRecords(topic: String, + consumer: KafkaConsumer[Array[Byte], Array[Byte]], + waitTimeMs: Long = JTestUtils.DEFAULT_MAX_WAIT_MS): Unit = { + consumer.subscribe(Collections.singletonList(topic)) + pollRecordsUntilTrue( + consumer, + (records: ConsumerRecords[Array[Byte], Array[Byte]]) => !records.isEmpty, + "Expected records", + waitTimeMs) + } + + /** + * Wait for the presence of an optional value. + * + * @param func The function defining the optional value + * @param msg Error message in the case that the value never appears + * @param waitTimeMs Maximum time to wait + * @return The unwrapped value returned by the function + */ + def awaitValue[T](func: () => Option[T], msg: => String, waitTimeMs: Long = JTestUtils.DEFAULT_MAX_WAIT_MS): T = { + var value: Option[T] = None + waitUntilTrue(() => { + value = func() + value.isDefined + }, msg, waitTimeMs) + value.get + } + + /** + * Wait until the given condition is true or throw an exception if the given wait time elapses. + * + * @param condition condition to check + * @param msg error message + * @param waitTimeMs maximum time to wait and retest the condition before failing the test + * @param pause delay between condition checks + */ + def waitUntilTrue(condition: () => Boolean, msg: => String, + waitTimeMs: Long = JTestUtils.DEFAULT_MAX_WAIT_MS, pause: Long = 100L): Unit = { + val startTime = System.currentTimeMillis() + while (true) { + if (condition()) + return + if (System.currentTimeMillis() > startTime + waitTimeMs) + fail(msg) + Thread.sleep(waitTimeMs.min(pause)) + } + + // should never hit here + throw new RuntimeException("unexpected error") + } + + /** + * Invoke `compute` until `predicate` is true or `waitTime` elapses. + * + * Return the last `compute` result and a boolean indicating whether `predicate` succeeded for that value. + * + * This method is useful in cases where `waitUntilTrue` makes it awkward to provide good error messages. 
+ */ + def computeUntilTrue[T](compute: => T, waitTime: Long = JTestUtils.DEFAULT_MAX_WAIT_MS, pause: Long = 100L)( + predicate: T => Boolean): (T, Boolean) = { + val startTime = System.currentTimeMillis() + while (true) { + val result = compute + if (predicate(result)) + return result -> true + if (System.currentTimeMillis() > startTime + waitTime) + return result -> false + Thread.sleep(waitTime.min(pause)) + } + // should never hit here + throw new RuntimeException("unexpected error") + } + + /** + * Invoke `assertions` until no AssertionErrors are thrown or `waitTime` elapses. + * + * This method is useful in cases where there may be some expected delay in a particular test condition that is + * otherwise difficult to poll for. `computeUntilTrue` and `waitUntilTrue` should be preferred in cases where we can + * easily wait on a condition before evaluating the assertions. + */ + def tryUntilNoAssertionError(waitTime: Long = JTestUtils.DEFAULT_MAX_WAIT_MS, pause: Long = 100L)(assertions: => Unit) = { + val (error, success) = TestUtils.computeUntilTrue({ + try { + assertions + None + } catch { + case ae: AssertionError => Some(ae) + } + }, waitTime = waitTime, pause = pause)(_.isEmpty) + + if (!success) { + throw error.get + } + } + + def isLeaderLocalOnBroker(topic: String, partitionId: Int, broker: KafkaBroker): Boolean = { + broker.replicaManager.onlinePartition(new TopicPartition(topic, partitionId)).exists(_.leaderLogIfLocal.isDefined) + } + + def findLeaderEpoch(brokerId: Int, + topicPartition: TopicPartition, + brokers: Iterable[KafkaBroker]): Int = { + val leaderBroker = brokers.find(_.config.brokerId == brokerId) + val leaderPartition = leaderBroker.flatMap(_.replicaManager.onlinePartition(topicPartition)) + .getOrElse(throw new AssertionError(s"Failed to find expected replica on broker $brokerId")) + leaderPartition.getLeaderEpoch + } + + def findFollowerId(topicPartition: TopicPartition, + brokers: Iterable[KafkaBroker]): Int = { + val followerOpt = brokers.find { server => + server.replicaManager.onlinePartition(topicPartition) match { + case Some(partition) => !partition.leaderReplicaIdOpt.contains(server.config.brokerId) + case None => false + } + } + followerOpt + .map(_.config.brokerId) + .getOrElse(throw new AssertionError(s"Unable to locate follower for $topicPartition")) + } + + /** + * Wait until all brokers know about each other. + * + * @param brokers The Kafka brokers. + * @param timeout The amount of time waiting on this condition before assert to fail + */ + def waitUntilBrokerMetadataIsPropagated[B <: KafkaBroker]( + brokers: Seq[B], + timeout: Long = JTestUtils.DEFAULT_MAX_WAIT_MS): Unit = { + val expectedBrokerIds = brokers.map(_.config.brokerId).toSet + waitUntilTrue(() => brokers.forall(server => + expectedBrokerIds == server.dataPlaneRequestProcessor.metadataCache.getAliveBrokers().map(_.id).toSet + ), "Timed out waiting for broker metadata to propagate to all servers", timeout) + } + + /** + * Wait until the expected number of partitions is in the metadata cache in each broker. 
+ * + * @param brokers The list of brokers that the metadata should reach to + * @param topic The topic name + * @param expectedNumPartitions The expected number of partitions + * @return all partitions metadata + */ + def waitForAllPartitionsMetadata[B <: KafkaBroker]( + brokers: Seq[B], + topic: String, expectedNumPartitions: Int): Map[TopicPartition, UpdateMetadataPartitionState] = { + waitUntilTrue( + () => brokers.forall { broker => + if (expectedNumPartitions == 0) { + broker.metadataCache.numPartitions(topic) == None + } else { + broker.metadataCache.numPartitions(topic) == Some(expectedNumPartitions) + } + }, + s"Topic [$topic] metadata not propagated after 60000 ms", waitTimeMs = 60000L) + + // since the metadata is propagated, we should get the same metadata from each server + (0 until expectedNumPartitions).map { i => + new TopicPartition(topic, i) -> brokers.head.metadataCache.getPartitionInfo(topic, i).getOrElse( + throw new IllegalStateException(s"Cannot get topic: $topic, partition: $i in server metadata cache")) + }.toMap + } + + /** + * Wait until a valid leader is propagated to the metadata cache in each broker. + * It assumes that the leader propagated to each broker is the same. + * + * @param brokers The list of brokers that the metadata should reach to + * @param topic The topic name + * @param partition The partition Id + * @param timeout The amount of time waiting on this condition before assert to fail + * @return The metadata of the partition. + */ + def waitForPartitionMetadata[B <: KafkaBroker]( + brokers: Seq[B], topic: String, partition: Int, + timeout: Long = JTestUtils.DEFAULT_MAX_WAIT_MS): UpdateMetadataPartitionState = { + waitUntilTrue( + () => brokers.forall { broker => + broker.metadataCache.getPartitionInfo(topic, partition) match { + case Some(partitionState) => Request.isValidBrokerId(partitionState.leader) + case _ => false + } + }, + "Partition [%s,%d] metadata not propagated after %d ms".format(topic, partition, timeout), + waitTimeMs = timeout) + + brokers.head.metadataCache.getPartitionInfo(topic, partition).getOrElse( + throw new IllegalStateException(s"Cannot get topic: $topic, partition: $partition in server metadata cache")) + } + + def waitUntilControllerElected(zkClient: KafkaZkClient, timeout: Long = JTestUtils.DEFAULT_MAX_WAIT_MS): Int = { + val (controllerId, _) = computeUntilTrue(zkClient.getControllerId, waitTime = timeout)(_.isDefined) + controllerId.getOrElse(throw new AssertionError(s"Controller not elected after $timeout ms")) + } + + def awaitLeaderChange[B <: KafkaBroker]( + brokers: Seq[B], + tp: TopicPartition, + oldLeader: Int, + timeout: Long = JTestUtils.DEFAULT_MAX_WAIT_MS): Int = { + def newLeaderExists: Option[Int] = { + brokers.find { broker => + broker.config.brokerId != oldLeader && + broker.replicaManager.onlinePartition(tp).exists(_.leaderLogIfLocal.isDefined) + }.map(_.config.brokerId) + } + + waitUntilTrue(() => newLeaderExists.isDefined, + s"Did not observe leader change for partition $tp after $timeout ms", waitTimeMs = timeout) + + newLeaderExists.get + } + + def waitUntilLeaderIsKnown[B <: KafkaBroker]( + brokers: Seq[B], + tp: TopicPartition, + timeout: Long = JTestUtils.DEFAULT_MAX_WAIT_MS): Int = { + def leaderIfExists: Option[Int] = { + brokers.find { broker => + broker.replicaManager.onlinePartition(tp).exists(_.leaderLogIfLocal.isDefined) + }.map(_.config.brokerId) + } + + waitUntilTrue(() => leaderIfExists.isDefined, + s"Partition $tp leaders not made yet after $timeout ms", waitTimeMs = timeout) + + 
leaderIfExists.get + } + + def writeNonsenseToFile(fileName: File, position: Long, size: Int): Unit = { + val file = new RandomAccessFile(fileName, "rw") + file.seek(position) + for (_ <- 0 until size) + file.writeByte(random.nextInt(255)) + file.close() + } + + def appendNonsenseToFile(file: File, size: Int): Unit = { + val outputStream = Files.newOutputStream(file.toPath(), StandardOpenOption.APPEND) + try { + for (_ <- 0 until size) + outputStream.write(random.nextInt(255)) + } finally outputStream.close() + } + + def checkForPhantomInSyncReplicas(zkClient: KafkaZkClient, topic: String, partitionToBeReassigned: Int, assignedReplicas: Seq[Int]): Unit = { + val inSyncReplicas = zkClient.getInSyncReplicasForPartition(new TopicPartition(topic, partitionToBeReassigned)) + // in sync replicas should not have any replica that is not in the new assigned replicas + val phantomInSyncReplicas = inSyncReplicas.get.toSet -- assignedReplicas.toSet + assertTrue(phantomInSyncReplicas.isEmpty, + "All in sync replicas %s must be in the assigned replica list %s".format(inSyncReplicas, assignedReplicas)) + } + + def ensureNoUnderReplicatedPartitions(zkClient: KafkaZkClient, topic: String, partitionToBeReassigned: Int, assignedReplicas: Seq[Int], + servers: Seq[KafkaServer]): Unit = { + val topicPartition = new TopicPartition(topic, partitionToBeReassigned) + waitUntilTrue(() => { + val inSyncReplicas = zkClient.getInSyncReplicasForPartition(topicPartition) + inSyncReplicas.get.size == assignedReplicas.size + }, + "Reassigned partition [%s,%d] is under replicated".format(topic, partitionToBeReassigned)) + var leader: Option[Int] = None + waitUntilTrue(() => { + leader = zkClient.getLeaderForPartition(topicPartition) + leader.isDefined + }, + "Reassigned partition [%s,%d] is unavailable".format(topic, partitionToBeReassigned)) + waitUntilTrue(() => { + val leaderBroker = servers.filter(s => s.config.brokerId == leader.get).head + leaderBroker.replicaManager.underReplicatedPartitionCount == 0 + }, + "Reassigned partition [%s,%d] is under-replicated as reported by the leader %d".format(topic, partitionToBeReassigned, leader.get)) + } + + // Note: Call this method in the test itself, rather than the @AfterEach method. + // Because of the assert, if assertNoNonDaemonThreads fails, nothing after would be executed. 
+ def assertNoNonDaemonThreads(threadNamePrefix: String): Unit = { + val threadCount = Thread.getAllStackTraces.keySet.asScala.count { t => + !t.isDaemon && t.isAlive && t.getName.startsWith(threadNamePrefix) + } + assertEquals(0, threadCount) + } + + def allThreadStackTraces(): String = { + Thread.getAllStackTraces.asScala.map { case (thread, stackTrace) => + thread.getName + "\n\t" + stackTrace.toList.map(_.toString).mkString("\n\t") + }.mkString("\n") + } + + /** + * Create new LogManager instance with default configuration for testing + */ + def createLogManager(logDirs: Seq[File] = Seq.empty[File], + defaultConfig: LogConfig = LogConfig(), + configRepository: ConfigRepository = new MockConfigRepository, + cleanerConfig: CleanerConfig = CleanerConfig(enableCleaner = false), + time: MockTime = new MockTime(), + interBrokerProtocolVersion: ApiVersion = ApiVersion.latestVersion): LogManager = { + new LogManager(logDirs = logDirs.map(_.getAbsoluteFile), + initialOfflineDirs = Array.empty[File], + configRepository = configRepository, + initialDefaultConfig = defaultConfig, + cleanerConfig = cleanerConfig, + recoveryThreadsPerDataDir = 4, + flushCheckMs = 1000L, + flushRecoveryOffsetCheckpointMs = 10000L, + flushStartOffsetCheckpointMs = 10000L, + retentionCheckMs = 1000L, + maxPidExpirationMs = 60 * 60 * 1000, + scheduler = time.scheduler, + time = time, + brokerTopicStats = new BrokerTopicStats, + logDirFailureChannel = new LogDirFailureChannel(logDirs.size), + keepPartitionMetadataFile = true, + interBrokerProtocolVersion = interBrokerProtocolVersion) + } + + class MockAlterIsrManager extends AlterIsrManager { + val isrUpdates: mutable.Queue[AlterIsrItem] = new mutable.Queue[AlterIsrItem]() + val inFlight: AtomicBoolean = new AtomicBoolean(false) + + + override def submit( + topicPartition: TopicPartition, + leaderAndIsr: LeaderAndIsr, + controllerEpoch: Int + ): CompletableFuture[LeaderAndIsr]= { + val future = new CompletableFuture[LeaderAndIsr]() + if (inFlight.compareAndSet(false, true)) { + isrUpdates += AlterIsrItem(topicPartition, leaderAndIsr, future, controllerEpoch) + } else { + future.completeExceptionally(new OperationNotAttemptedException( + s"Failed to enqueue AlterIsr request for $topicPartition since there is already an inflight request")) + } + future + } + + def completeIsrUpdate(newZkVersion: Int): Unit = { + if (inFlight.compareAndSet(true, false)) { + val item = isrUpdates.dequeue() + item.future.complete(item.leaderAndIsr.withZkVersion(newZkVersion)) + } else { + fail("Expected an in-flight ISR update, but there was none") + } + } + + def failIsrUpdate(error: Errors): Unit = { + if (inFlight.compareAndSet(true, false)) { + val item = isrUpdates.dequeue() + item.future.completeExceptionally(error.exception) + } else { + fail("Expected an in-flight ISR update, but there was none") + } + } + } + + def createAlterIsrManager(): MockAlterIsrManager = { + new MockAlterIsrManager() + } + + class MockIsrChangeListener extends IsrChangeListener { + val expands: AtomicInteger = new AtomicInteger(0) + val shrinks: AtomicInteger = new AtomicInteger(0) + val failures: AtomicInteger = new AtomicInteger(0) + + override def markExpand(): Unit = expands.incrementAndGet() + + override def markShrink(): Unit = shrinks.incrementAndGet() + + override def markFailed(): Unit = failures.incrementAndGet() + + def reset(): Unit = { + expands.set(0) + shrinks.set(0) + failures.set(0) + } + } + + def createIsrChangeListener(): MockIsrChangeListener = { + new MockIsrChangeListener() + } + + def 
produceMessages[B <: KafkaBroker]( + brokers: Seq[B], + records: Seq[ProducerRecord[Array[Byte], Array[Byte]]], + acks: Int = -1): Unit = { + val producer = createProducer(TestUtils.getBrokerListStrFromServers(brokers), acks = acks) + try { + val futures = records.map(producer.send) + futures.foreach(_.get) + } finally { + producer.close() + } + + val topics = records.map(_.topic).distinct + debug(s"Sent ${records.size} messages for topics ${topics.mkString(",")}") + } + + def generateAndProduceMessages[B <: KafkaBroker]( + brokers: Seq[B], + topic: String, + numMessages: Int, + acks: Int = -1): Seq[String] = { + val values = (0 until numMessages).map(x => s"test-$x") + val intSerializer = new IntegerSerializer() + val records = values.zipWithIndex.map { case (v, i) => + new ProducerRecord(topic, intSerializer.serialize(topic, i), v.getBytes) + } + produceMessages(brokers, records, acks) + values + } + + def produceMessage[B <: KafkaBroker]( + brokers: Seq[B], + topic: String, + message: String, + timestamp: java.lang.Long = null, + deliveryTimeoutMs: Int = 30 * 1000, + requestTimeoutMs: Int = 20 * 1000): Unit = { + val producer = createProducer(TestUtils.getBrokerListStrFromServers(brokers), + deliveryTimeoutMs = deliveryTimeoutMs, requestTimeoutMs = requestTimeoutMs) + try { + producer.send(new ProducerRecord(topic, null, timestamp, topic.getBytes, message.getBytes)).get + } finally { + producer.close() + } + } + + def verifyTopicDeletion[B <: KafkaBroker]( + zkClient: KafkaZkClient, + topic: String, + numPartitions: Int, + brokers: Seq[B]): Unit = { + val topicPartitions = (0 until numPartitions).map(new TopicPartition(topic, _)) + if (zkClient != null) { + // wait until admin path for delete topic is deleted, signaling completion of topic deletion + waitUntilTrue(() => !zkClient.isTopicMarkedForDeletion(topic), + "Admin path /admin/delete_topics/%s path not deleted even after a replica is restarted".format(topic)) + waitUntilTrue(() => !zkClient.topicExists(topic), + "Topic path /brokers/topics/%s not deleted after /admin/delete_topics/%s path is deleted".format(topic, topic)) + } + // ensure that the topic-partition has been deleted from all brokers' replica managers + waitUntilTrue(() => + brokers.forall(broker => topicPartitions.forall(tp => broker.replicaManager.onlinePartition(tp).isEmpty)), + "Replica manager's should have deleted all of this topic's partitions") + // ensure that logs from all replicas are deleted if delete topic is marked successful in ZooKeeper + assertTrue(brokers.forall(broker => topicPartitions.forall(tp => broker.logManager.getLog(tp).isEmpty)), + "Replica logs not deleted after delete topic is complete") + // ensure that topic is removed from all cleaner offsets + waitUntilTrue(() => brokers.forall(broker => topicPartitions.forall { tp => + val checkpoints = broker.logManager.liveLogDirs.map { logDir => + new OffsetCheckpointFile(new File(logDir, "cleaner-offset-checkpoint")).read() + } + checkpoints.forall(checkpointsPerLogDir => !checkpointsPerLogDir.contains(tp)) + }), "Cleaner offset for deleted partition should have been removed") + waitUntilTrue(() => brokers.forall(broker => + broker.config.logDirs.forall { logDir => + topicPartitions.forall { tp => + !new File(logDir, tp.topic + "-" + tp.partition).exists() + } + } + ), "Failed to soft-delete the data to a delete directory") + waitUntilTrue(() => brokers.forall(broker => + broker.config.logDirs.forall { logDir => + topicPartitions.forall { tp => + !Arrays.asList(new 
File(logDir).list()).asScala.exists { partitionDirectoryName => + partitionDirectoryName.startsWith(tp.topic + "-" + tp.partition) && + partitionDirectoryName.endsWith(UnifiedLog.DeleteDirSuffix) + } + } + } + ), "Failed to hard-delete the delete directory") + } + + + def causeLogDirFailure(failureType: LogDirFailureType, leaderBroker: KafkaBroker, partition: TopicPartition): Unit = { + // Make log directory of the partition on the leader broker inaccessible by replacing it with a file + val localLog = leaderBroker.replicaManager.localLogOrException(partition) + val logDir = localLog.dir.getParentFile + CoreUtils.swallow(Utils.delete(logDir), this) + logDir.createNewFile() + assertTrue(logDir.isFile) + + if (failureType == Roll) { + assertThrows(classOf[KafkaStorageException], () => leaderBroker.replicaManager.getLog(partition).get.roll()) + } else if (failureType == Checkpoint) { + leaderBroker.replicaManager.checkpointHighWatermarks() + } + + // Wait for ReplicaHighWatermarkCheckpoint to happen so that the log directory of the topic will be offline + waitUntilTrue(() => !leaderBroker.logManager.isLogDirOnline(logDir.getAbsolutePath), "Expected log directory offline", 3000L) + assertTrue(leaderBroker.replicaManager.localLog(partition).isEmpty) + } + + /** + * Translate the given buffer into a string + * + * @param buffer The buffer to translate + * @param encoding The encoding to use in translating bytes to characters + */ + def readString(buffer: ByteBuffer, encoding: String = Charset.defaultCharset.toString): String = { + val bytes = new Array[Byte](buffer.remaining) + buffer.get(bytes) + new String(bytes, encoding) + } + + def copyOf(props: Properties): Properties = { + val copy = new Properties() + copy ++= props + copy + } + + def sslConfigs(mode: Mode, clientCert: Boolean, trustStoreFile: Option[File], certAlias: String, + certCn: String = SslCertificateCn, + tlsProtocol: String = TestSslUtils.DEFAULT_TLS_PROTOCOL_FOR_TESTS): Properties = { + val trustStore = trustStoreFile.getOrElse { + throw new Exception("SSL enabled but no trustStoreFile provided") + } + + val sslConfigs = new TestSslUtils.SslConfigsBuilder(mode) + .useClientCert(clientCert) + .createNewTrustStore(trustStore) + .certAlias(certAlias) + .cn(certCn) + .tlsProtocol(tlsProtocol) + .build() + + val sslProps = new Properties() + sslConfigs.forEach { (k, v) => sslProps.put(k, v) } + sslProps + } + + // a X509TrustManager to trust self-signed certs for unit tests. + def trustAllCerts: X509TrustManager = { + val trustManager = new X509TrustManager() { + override def getAcceptedIssuers: Array[X509Certificate] = { + null + } + override def checkClientTrusted(certs: Array[X509Certificate], authType: String): Unit = { + } + override def checkServerTrusted(certs: Array[X509Certificate], authType: String): Unit = { + } + } + trustManager + } + + def waitAndVerifyAcls(expected: Set[AccessControlEntry], + authorizer: JAuthorizer, + resource: ResourcePattern, + accessControlEntryFilter: AccessControlEntryFilter = AccessControlEntryFilter.ANY): Unit = { + val newLine = scala.util.Properties.lineSeparator + + val filter = new AclBindingFilter(resource.toFilter, accessControlEntryFilter) + waitUntilTrue(() => authorizer.acls(filter).asScala.map(_.entry).toSet == expected, + s"expected acls:${expected.mkString(newLine + "\t", newLine + "\t", newLine)}" + + s"but got:${authorizer.acls(filter).asScala.map(_.entry).mkString(newLine + "\t", newLine + "\t", newLine)}") + } + + /** + * Verifies that this ACL is the secure one. 
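+   * Concretely, mirroring the match below: an ALL ACL is considered secure only under the "sasl" scheme,
+   * and a READ ACL is considered secure only for non-sensitive paths under the "world" scheme.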
+ */ + def isAclSecure(acl: ACL, sensitive: Boolean): Boolean = { + debug(s"ACL $acl") + acl.getPerms match { + case Perms.READ => !sensitive && acl.getId.getScheme == "world" + case Perms.ALL => acl.getId.getScheme == "sasl" + case _ => false + } + } + + /** + * Verifies that the ACL corresponds to the unsecure one that + * provides ALL access to everyone (world). + */ + def isAclUnsecure(acl: ACL): Boolean = { + debug(s"ACL $acl") + acl.getPerms match { + case Perms.ALL => acl.getId.getScheme == "world" + case _ => false + } + } + + private def secureZkPaths(zkClient: KafkaZkClient): Seq[String] = { + def subPaths(path: String): Seq[String] = { + if (zkClient.pathExists(path)) + path +: zkClient.getChildren(path).map(c => path + "/" + c).flatMap(subPaths) + else + Seq.empty + } + val topLevelPaths = ZkData.SecureRootPaths ++ ZkData.SensitiveRootPaths + topLevelPaths.flatMap(subPaths) + } + + /** + * Verifies that all secure paths in ZK are created with the expected ACL. + */ + def verifySecureZkAcls(zkClient: KafkaZkClient, usersWithAccess: Int): Unit = { + secureZkPaths(zkClient).foreach(path => { + if (zkClient.pathExists(path)) { + val sensitive = ZkData.sensitivePath(path) + // usersWithAccess have ALL access to path. For paths that are + // not sensitive, world has READ access. + val aclCount = if (sensitive) usersWithAccess else usersWithAccess + 1 + val acls = zkClient.getAcl(path) + assertEquals(aclCount, acls.size, s"Invalid ACLs for $path $acls") + acls.foreach(acl => isAclSecure(acl, sensitive)) + } + }) + } + + /** + * Verifies that secure paths in ZK have no access control. This is + * the case when zookeeper.set.acl=false and no ACLs have been configured. + */ + def verifyUnsecureZkAcls(zkClient: KafkaZkClient): Unit = { + secureZkPaths(zkClient).foreach(path => { + if (zkClient.pathExists(path)) { + val acls = zkClient.getAcl(path) + assertEquals(1, acls.size, s"Invalid ACLs for $path $acls") + acls.foreach(isAclUnsecure) + } + }) + } + + /** + * To use this you pass in a sequence of functions that are your arrange/act/assert test on the SUT. + * They all run at the same time in the assertConcurrent method; the chances of triggering a multithreading code error, + * and thereby failing some assertion are greatly increased. + */ + def assertConcurrent(message: String, functions: Seq[() => Any], timeoutMs: Int): Unit = { + + def failWithTimeout(): Unit = { + fail(s"$message. 
Timed out, the concurrent functions took more than $timeoutMs milliseconds") + } + + val numThreads = functions.size + val threadPool = Executors.newFixedThreadPool(numThreads) + val exceptions = ArrayBuffer[Throwable]() + try { + val runnables = functions.map { function => + new Callable[Unit] { + override def call(): Unit = function() + } + }.asJava + val futures = threadPool.invokeAll(runnables, timeoutMs, TimeUnit.MILLISECONDS).asScala + futures.foreach { future => + if (future.isCancelled) + failWithTimeout() + else + try future.get() + catch { case e: Exception => + exceptions += e + } + } + } catch { + case _: InterruptedException => failWithTimeout() + case e: Throwable => exceptions += e + } finally { + threadPool.shutdownNow() + } + assertTrue(exceptions.isEmpty, s"$message failed with exception(s) $exceptions") + } + + def consumeTopicRecords[K, V, B <: KafkaBroker]( + brokers: Seq[B], + topic: String, + numMessages: Int, + groupId: String = "group", + securityProtocol: SecurityProtocol = SecurityProtocol.PLAINTEXT, + trustStoreFile: Option[File] = None, + waitTime: Long = JTestUtils.DEFAULT_MAX_WAIT_MS): Seq[ConsumerRecord[Array[Byte], Array[Byte]]] = { + val consumer = createConsumer(TestUtils.getBrokerListStrFromServers(brokers, securityProtocol), + groupId = groupId, + securityProtocol = securityProtocol, + trustStoreFile = trustStoreFile) + try { + consumer.subscribe(Collections.singleton(topic)) + consumeRecords(consumer, numMessages, waitTime) + } finally consumer.close() + } + + def pollUntilAtLeastNumRecords[K, V](consumer: Consumer[K, V], + numRecords: Int, + waitTimeMs: Long = JTestUtils.DEFAULT_MAX_WAIT_MS): Seq[ConsumerRecord[K, V]] = { + val records = new ArrayBuffer[ConsumerRecord[K, V]]() + def pollAction(polledRecords: ConsumerRecords[K, V]): Boolean = { + records ++= polledRecords.asScala + records.size >= numRecords + } + pollRecordsUntilTrue(consumer, pollAction, + waitTimeMs = waitTimeMs, + msg = s"Consumed ${records.size} records before timeout instead of the expected $numRecords records") + records + } + + def consumeRecords[K, V](consumer: Consumer[K, V], + numRecords: Int, + waitTimeMs: Long = JTestUtils.DEFAULT_MAX_WAIT_MS): Seq[ConsumerRecord[K, V]] = { + val records = pollUntilAtLeastNumRecords(consumer, numRecords, waitTimeMs) + assertEquals(numRecords, records.size, "Consumed more records than expected") + records + } + + /** + * Will consume all the records for the given consumer for the specified duration. If you want to drain all the + * remaining messages in the partitions the consumer is subscribed to, the duration should be set high enough so + * that the consumer has enough time to poll everything. This would be based on the number of expected messages left + * in the topic, and should not be too large (ie. more than a second) in our tests. + * + * @return All the records consumed by the consumer within the specified duration. 
+ */ + def consumeRecordsFor[K, V](consumer: KafkaConsumer[K, V], duration: Long = JTestUtils.DEFAULT_MAX_WAIT_MS): Seq[ConsumerRecord[K, V]] = { + val startTime = System.currentTimeMillis() + val records = new ArrayBuffer[ConsumerRecord[K, V]]() + waitUntilTrue(() => { + records ++= consumer.poll(Duration.ofMillis(50)).asScala + System.currentTimeMillis() - startTime > duration + }, s"The timeout $duration was greater than the maximum wait time.") + records + } + + def createTransactionalProducer[B <: KafkaBroker]( + transactionalId: String, + brokers: Seq[B], + batchSize: Int = 16384, + transactionTimeoutMs: Long = 60000, + maxBlockMs: Long = 60000, + deliveryTimeoutMs: Int = 120000, + requestTimeoutMs: Int = 30000, + maxInFlight: Int = 5): KafkaProducer[Array[Byte], Array[Byte]] = { + val props = new Properties() + props.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, getBrokerListStrFromServers(brokers)) + props.put(ProducerConfig.ACKS_CONFIG, "all") + props.put(ProducerConfig.BATCH_SIZE_CONFIG, batchSize.toString) + props.put(ProducerConfig.TRANSACTIONAL_ID_CONFIG, transactionalId) + props.put(ProducerConfig.ENABLE_IDEMPOTENCE_CONFIG, "true") + props.put(ProducerConfig.TRANSACTION_TIMEOUT_CONFIG, transactionTimeoutMs.toString) + props.put(ProducerConfig.MAX_BLOCK_MS_CONFIG, maxBlockMs.toString) + props.put(ProducerConfig.DELIVERY_TIMEOUT_MS_CONFIG, deliveryTimeoutMs.toString) + props.put(ProducerConfig.REQUEST_TIMEOUT_MS_CONFIG, requestTimeoutMs.toString) + props.put(ProducerConfig.MAX_IN_FLIGHT_REQUESTS_PER_CONNECTION, maxInFlight.toString) + new KafkaProducer[Array[Byte], Array[Byte]](props, new ByteArraySerializer, new ByteArraySerializer) + } + + // Seeds the given topic with records with keys and values in the range [0..numRecords) + def seedTopicWithNumberedRecords[B <: KafkaBroker]( + topic: String, + numRecords: Int, + brokers: Seq[B]): Unit = { + val props = new Properties() + props.put(ProducerConfig.ENABLE_IDEMPOTENCE_CONFIG, "true") + props.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, getBrokerListStrFromServers(brokers)) + val producer = new KafkaProducer[Array[Byte], Array[Byte]](props, new ByteArraySerializer, new ByteArraySerializer) + try { + for (i <- 0 until numRecords) { + producer.send(new ProducerRecord[Array[Byte], Array[Byte]](topic, asBytes(i.toString), asBytes(i.toString))) + } + producer.flush() + } finally { + producer.close() + } + } + + private def asString(bytes: Array[Byte]) = new String(bytes, StandardCharsets.UTF_8) + + private def asBytes(string: String) = string.getBytes(StandardCharsets.UTF_8) + + // Verifies that the record was intended to be committed by checking the headers for an expected transaction status + // If true, this will return the value as a string. It is expected that the record in question should have been created + // by the `producerRecordWithExpectedTransactionStatus` method. 
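+  // Illustrative pairing with the helpers defined in this file (topic, key, value and record names are placeholders):
+  //   producer.send(producerRecordWithExpectedTransactionStatus("topic", null, "key", "value", willBeCommitted = true))
+  //   ...
+  //   assertCommittedAndGetValue(consumedRecord)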
+ def assertCommittedAndGetValue(record: ConsumerRecord[Array[Byte], Array[Byte]]) : String = { + record.headers.headers(transactionStatusKey).asScala.headOption match { + case Some(header) => + assertEquals(asString(committedValue), asString(header.value), s"Got ${asString(header.value)} but expected the value to indicate " + + s"committed status.") + case None => + fail("expected the record header to include an expected transaction status, but received nothing.") + } + recordValueAsString(record) + } + + def recordValueAsString(record: ConsumerRecord[Array[Byte], Array[Byte]]) : String = { + asString(record.value) + } + + def producerRecordWithExpectedTransactionStatus(topic: String, partition: Integer, key: Array[Byte], value: Array[Byte], willBeCommitted: Boolean): ProducerRecord[Array[Byte], Array[Byte]] = { + val header = new Header {override def key() = transactionStatusKey + override def value() = if (willBeCommitted) + committedValue + else + abortedValue + } + new ProducerRecord[Array[Byte], Array[Byte]](topic, partition, key, value, Collections.singleton(header)) + } + + def producerRecordWithExpectedTransactionStatus(topic: String, partition: Integer, key: String, value: String, willBeCommitted: Boolean): ProducerRecord[Array[Byte], Array[Byte]] = { + producerRecordWithExpectedTransactionStatus(topic, partition, asBytes(key), asBytes(value), willBeCommitted) + } + + // Collect the current positions for all partition in the consumers current assignment. + def consumerPositions(consumer: KafkaConsumer[Array[Byte], Array[Byte]]) : Map[TopicPartition, OffsetAndMetadata] = { + val offsetsToCommit = new mutable.HashMap[TopicPartition, OffsetAndMetadata]() + consumer.assignment.forEach { topicPartition => + offsetsToCommit.put(topicPartition, new OffsetAndMetadata(consumer.position(topicPartition))) + } + offsetsToCommit.toMap + } + + def resetToCommittedPositions(consumer: KafkaConsumer[Array[Byte], Array[Byte]]): Unit = { + val committed = consumer.committed(consumer.assignment).asScala.filter(_._2 != null).map { case (k, v) => k -> v.offset } + + consumer.assignment.forEach { topicPartition => + if (committed.contains(topicPartition)) + consumer.seek(topicPartition, committed(topicPartition)) + else + consumer.seekToBeginning(Collections.singletonList(topicPartition)) + } + } + + def fetchEntityConfigWithAdmin[B <: KafkaBroker]( + configResource: ConfigResource, + brokers: Seq[B], + adminConfig: Properties = new Properties): Properties = { + val properties = new Properties() + val adminClient = createAdminClient(brokers, adminConfig) + try { + val result = adminClient.describeConfigs(Collections.singletonList(configResource)).all().get() + val config = result.get(configResource) + if (config != null) { + config.entries().forEach(e => properties.setProperty(e.name(), e.value())) + } + } finally { + adminClient.close() + } + properties + } + + def incrementalAlterConfigs[B <: KafkaBroker]( + servers: Seq[B], + adminClient: Admin, + props: Properties, + perBrokerConfig: Boolean, + opType: OpType = OpType.SET): AlterConfigsResult = { + val configEntries = props.asScala.map { case (k, v) => new AlterConfigOp(new ConfigEntry(k, v), opType) }.toList.asJavaCollection + val configs = if (perBrokerConfig) { + servers.map { server => + val resource = new ConfigResource(ConfigResource.Type.BROKER, server.config.brokerId.toString) + (resource, configEntries) + }.toMap.asJava + } else { + Map(new ConfigResource(ConfigResource.Type.BROKER, "") -> configEntries).asJava + } + 
adminClient.incrementalAlterConfigs(configs) + } + + def alterClientQuotas(adminClient: Admin, request: Map[ClientQuotaEntity, Map[String, Option[Double]]]): AlterClientQuotasResult = { + val entries = request.map { case (entity, alter) => + val ops = alter.map { case (key, value) => + new ClientQuotaAlteration.Op(key, value.map(Double.box).getOrElse(null)) + }.asJavaCollection + new ClientQuotaAlteration(entity, ops) + }.asJavaCollection + adminClient.alterClientQuotas(entries) + } + + def assertLeader(client: Admin, topicPartition: TopicPartition, expectedLeader: Int): Unit = { + waitForLeaderToBecome(client, topicPartition, Some(expectedLeader)) + } + + def assertNoLeader(client: Admin, topicPartition: TopicPartition): Unit = { + waitForLeaderToBecome(client, topicPartition, None) + } + + def waitForOnlineBroker(client: Admin, brokerId: Int): Unit = { + waitUntilTrue(() => { + val nodes = client.describeCluster().nodes().get() + nodes.asScala.exists(_.id == brokerId) + }, s"Timed out waiting for brokerId $brokerId to come online") + } + + /** + * Get the replica assignment for some topics. Topics which don't exist will be ignored. + */ + def getReplicaAssignmentForTopics[B <: KafkaBroker]( + topicNames: Seq[String], + brokers: Seq[B], + adminConfig: Properties = new Properties): Map[TopicPartition, Seq[Int]] = { + val adminClient = createAdminClient(brokers, adminConfig) + val results = new mutable.HashMap[TopicPartition, Seq[Int]] + try { + adminClient.describeTopics(topicNames.toList.asJava).topicNameValues().forEach { + case (topicName, future) => + try { + val description = future.get() + description.partitions().forEach { + case partition => + val topicPartition = new TopicPartition(topicName, partition.partition()) + results.put(topicPartition, partition.replicas().asScala.map(_.id)) + } + } catch { + case e: ExecutionException => if (e.getCause != null && + e.getCause.isInstanceOf[UnknownTopicOrPartitionException]) { + // ignore + } else { + throw e + } + } + } + } finally { + adminClient.close() + } + results + } + + def waitForLeaderToBecome( + client: Admin, + topicPartition: TopicPartition, + expectedLeaderOpt: Option[Int] + ): Unit = { + val topic = topicPartition.topic + val partitionId = topicPartition.partition + + def currentLeader: Try[Option[Int]] = Try { + val topicDescription = client.describeTopics(List(topic).asJava).allTopicNames.get.get(topic) + topicDescription.partitions.asScala + .find(_.partition == partitionId) + .flatMap(partitionState => Option(partitionState.leader)) + .map(_.id) + } + + val (lastLeaderCheck, isLeaderElected) = computeUntilTrue(currentLeader) { + case Success(leaderOpt) => leaderOpt == expectedLeaderOpt + case Failure(e: ExecutionException) if e.getCause.isInstanceOf[UnknownTopicOrPartitionException] => false + case Failure(e) => throw e + } + + assertTrue(isLeaderElected, s"Timed out waiting for leader to become $expectedLeaderOpt. 
" + + s"Last metadata lookup returned leader = ${lastLeaderCheck.getOrElse("unknown")}") + } + + def waitForBrokersOutOfIsr(client: Admin, partition: Set[TopicPartition], brokerIds: Set[Int]): Unit = { + waitUntilTrue( + () => { + val description = client.describeTopics(partition.map(_.topic).asJava).allTopicNames.get.asScala + val isr = description + .values + .flatMap(_.partitions.asScala.flatMap(_.isr.asScala)) + .map(_.id) + .toSet + + brokerIds.intersect(isr).isEmpty + }, + s"Expected brokers $brokerIds to no longer be in the ISR for $partition" + ) + } + + def waitForBrokersInIsr(client: Admin, partition: TopicPartition, brokerIds: Set[Int]): Unit = { + waitUntilTrue( + () => { + val description = client.describeTopics(Set(partition.topic).asJava).allTopicNames.get.asScala + val isr = description + .values + .flatMap(_.partitions.asScala.flatMap(_.isr.asScala)) + .map(_.id) + .toSet + + brokerIds.subsetOf(isr) + }, + s"Expected brokers $brokerIds to be in the ISR for $partition" + ) + } + + def waitForReplicasAssigned(client: Admin, partition: TopicPartition, brokerIds: Seq[Int]): Unit = { + waitUntilTrue( + () => { + val description = client.describeTopics(Set(partition.topic).asJava).allTopicNames.get.asScala + val replicas = description + .values + .flatMap(_.partitions.asScala.flatMap(_.replicas.asScala)) + .map(_.id) + .toSeq + + brokerIds == replicas + }, + s"Expected brokers $brokerIds to be the replicas for $partition" + ) + } + + /** + * Capture the console output during the execution of the provided function. + */ + def grabConsoleOutput(f: => Unit) : String = { + val out = new ByteArrayOutputStream + try scala.Console.withOut(out)(f) + finally scala.Console.out.flush() + out.toString + } + + /** + * Capture the console error during the execution of the provided function. + */ + def grabConsoleError(f: => Unit) : String = { + val err = new ByteArrayOutputStream + try scala.Console.withErr(err)(f) + finally scala.Console.err.flush() + err.toString + } + + /** + * Capture both the console output and console error during the execution of the provided function. 
+ */ + def grabConsoleOutputAndError(f: => Unit) : (String, String) = { + val out = new ByteArrayOutputStream + val err = new ByteArrayOutputStream + try scala.Console.withOut(out)(scala.Console.withErr(err)(f)) + finally { + scala.Console.out.flush() + scala.Console.err.flush() + } + (out.toString, err.toString) + } + + def assertFutureExceptionTypeEquals(future: KafkaFuture[_], clazz: Class[_ <: Throwable], + expectedErrorMessage: Option[String] = None): Unit = { + val cause = assertThrows(classOf[ExecutionException], () => future.get()).getCause + assertTrue(clazz.isInstance(cause), "Expected an exception of type " + clazz.getName + "; got type " + + cause.getClass.getName) + expectedErrorMessage.foreach(message => assertTrue(cause.getMessage.contains(message), s"Received error message : ${cause.getMessage}" + + s" does not contain expected error message : $message")) + } + + def assertBadConfigContainingMessage(props: Properties, expectedExceptionContainsText: String): Unit = { + try { + KafkaConfig.fromProps(props) + fail("Expected illegal configuration but instead it was legal") + } catch { + case caught @ (_: ConfigException | _: IllegalArgumentException) => + assertTrue(caught.getMessage.contains(expectedExceptionContainsText)) + } + } + + def totalMetricValue(broker: KafkaBroker, metricName: String): Long = { + totalMetricValue(broker.metrics, metricName) + } + + def totalMetricValue(metrics: Metrics, metricName: String): Long = { + val allMetrics = metrics.metrics + val total = allMetrics.values().asScala.filter(_.metricName().name() == metricName) + .foldLeft(0.0)((total, metric) => total + metric.metricValue.asInstanceOf[Double]) + total.toLong + } + + def meterCount(metricName: String): Long = { + KafkaYammerMetrics.defaultRegistry.allMetrics.asScala + .filter { case (k, _) => k.getMBeanName.endsWith(metricName) } + .values + .headOption + .getOrElse(fail(s"Unable to find metric $metricName")) + .asInstanceOf[Meter] + .count + } + + def clearYammerMetrics(): Unit = { + for (metricName <- KafkaYammerMetrics.defaultRegistry.allMetrics.keySet.asScala) + KafkaYammerMetrics.defaultRegistry.removeMetric(metricName) + } + + def stringifyTopicPartitions(partitions: Set[TopicPartition]): String = { + Json.encodeAsString(Map("partitions" -> + partitions.map(tp => Map("topic" -> tp.topic, "partition" -> tp.partition).asJava).asJava).asJava) + } + + def resource[R <: AutoCloseable, A](resource: R)(func: R => A): A = { + try { + func(resource) + } finally { + resource.close() + } + } + + /** + * Set broker replication quotas and enable throttling for a set of partitions. This + * will override any previous replication quotas, but will leave the throttling status + * of other partitions unaffected. + */ + def setReplicationThrottleForPartitions(admin: Admin, + brokerIds: Seq[Int], + partitions: Set[TopicPartition], + throttleBytes: Int): Unit = { + throttleAllBrokersReplication(admin, brokerIds, throttleBytes) + assignThrottledPartitionReplicas(admin, partitions.map(_ -> brokerIds).toMap) + } + + /** + * Remove a set of throttled partitions and reset the overall replication quota. + */ + def removeReplicationThrottleForPartitions(admin: Admin, + brokerIds: Seq[Int], + partitions: Set[TopicPartition]): Unit = { + removePartitionReplicaThrottles(admin, partitions) + resetBrokersThrottle(admin, brokerIds) + } + + /** + * Throttles all replication across the cluster. 
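+   * Illustrative call (broker ids and byte rate are placeholders): throttleAllBrokersReplication(admin, Seq(0, 1, 2), 1000000)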
+ * @param adminClient is the adminClient to use for making connection with the cluster + * @param brokerIds all broker ids in the cluster + * @param throttleBytes is the target throttle + */ + def throttleAllBrokersReplication(adminClient: Admin, brokerIds: Seq[Int], throttleBytes: Int): Unit = { + val throttleConfigs = Seq( + new AlterConfigOp(new ConfigEntry(DynamicConfig.Broker.LeaderReplicationThrottledRateProp, throttleBytes.toString), AlterConfigOp.OpType.SET), + new AlterConfigOp(new ConfigEntry(DynamicConfig.Broker.FollowerReplicationThrottledRateProp, throttleBytes.toString), AlterConfigOp.OpType.SET) + ).asJavaCollection + + adminClient.incrementalAlterConfigs( + brokerIds.map { brokerId => + new ConfigResource(ConfigResource.Type.BROKER, brokerId.toString) -> throttleConfigs + }.toMap.asJava + ).all().get() + } + + def resetBrokersThrottle(adminClient: Admin, brokerIds: Seq[Int]): Unit = + throttleAllBrokersReplication(adminClient, brokerIds, Int.MaxValue) + + def assignThrottledPartitionReplicas(adminClient: Admin, allReplicasByPartition: Map[TopicPartition, Seq[Int]]): Unit = { + val throttles = allReplicasByPartition.groupBy(_._1.topic()).map { + case (topic, replicasByPartition) => + new ConfigResource(TOPIC, topic) -> Seq( + new AlterConfigOp(new ConfigEntry(LogConfig.LeaderReplicationThrottledReplicasProp, formatReplicaThrottles(replicasByPartition)), AlterConfigOp.OpType.SET), + new AlterConfigOp(new ConfigEntry(LogConfig.FollowerReplicationThrottledReplicasProp, formatReplicaThrottles(replicasByPartition)), AlterConfigOp.OpType.SET) + ).asJavaCollection + } + adminClient.incrementalAlterConfigs(throttles.asJava).all().get() + } + + def removePartitionReplicaThrottles(adminClient: Admin, partitions: Set[TopicPartition]): Unit = { + val throttles = partitions.map { + tp => + new ConfigResource(TOPIC, tp.topic()) -> Seq( + new AlterConfigOp(new ConfigEntry(LogConfig.LeaderReplicationThrottledReplicasProp, ""), AlterConfigOp.OpType.DELETE), + new AlterConfigOp(new ConfigEntry(LogConfig.FollowerReplicationThrottledReplicasProp, ""), AlterConfigOp.OpType.DELETE) + ).asJavaCollection + }.toMap + adminClient.incrementalAlterConfigs(throttles.asJava).all().get() + } + + def formatReplicaThrottles(moves: Map[TopicPartition, Seq[Int]]): String = + moves.flatMap { case (tp, assignment) => + assignment.map(replicaId => s"${tp.partition}:$replicaId") + }.mkString(",") + + def waitForAllReassignmentsToComplete(adminClient: Admin, pause: Long = 100L): Unit = { + waitUntilTrue(() => adminClient.listPartitionReassignments().reassignments().get().isEmpty, + s"There still are ongoing reassignments", pause = pause) + } + + def addAndVerifyAcls(broker: KafkaBroker, acls: Set[AccessControlEntry], resource: ResourcePattern): Unit = { + val authorizer = broker.dataPlaneRequestProcessor.authorizer.get + val aclBindings = acls.map { acl => new AclBinding(resource, acl) } + authorizer.createAcls(null, aclBindings.toList.asJava).asScala + .map(_.toCompletableFuture.get) + .foreach { result => + result.exception.ifPresent { e => throw e } + } + val aclFilter = new AclBindingFilter(resource.toFilter, AccessControlEntryFilter.ANY) + waitAndVerifyAcls( + authorizer.acls(aclFilter).asScala.map(_.entry).toSet ++ acls, + authorizer, resource) + } + + def removeAndVerifyAcls(broker: KafkaBroker, acls: Set[AccessControlEntry], resource: ResourcePattern): Unit = { + val authorizer = broker.dataPlaneRequestProcessor.authorizer.get + val aclBindingFilters = acls.map { acl => new 
AclBindingFilter(resource.toFilter, acl.toFilter) } + authorizer.deleteAcls(null, aclBindingFilters.toList.asJava).asScala + .map(_.toCompletableFuture.get) + .foreach { result => + result.exception.ifPresent { e => throw e } + } + val aclFilter = new AclBindingFilter(resource.toFilter, AccessControlEntryFilter.ANY) + waitAndVerifyAcls( + authorizer.acls(aclFilter).asScala.map(_.entry).toSet -- acls, + authorizer, resource) + } + + def buildRequestWithEnvelope(request: AbstractRequest, + principalSerde: KafkaPrincipalSerde, + requestChannelMetrics: RequestChannel.Metrics, + startTimeNanos: Long, + fromPrivilegedListener: Boolean = true, + shouldSpyRequestContext: Boolean = false, + envelope: Option[RequestChannel.Request] = None + ): RequestChannel.Request = { + val clientId = "id" + val listenerName = ListenerName.forSecurityProtocol(SecurityProtocol.PLAINTEXT) + + val requestHeader = new RequestHeader(request.apiKey, request.version, clientId, 0) + val requestBuffer = request.serializeWithHeader(requestHeader) + + val envelopeHeader = new RequestHeader(ApiKeys.ENVELOPE, ApiKeys.ENVELOPE.latestVersion(), clientId, 0) + val envelopeBuffer = new EnvelopeRequest.Builder( + requestBuffer, + principalSerde.serialize(KafkaPrincipal.ANONYMOUS), + InetAddress.getLocalHost.getAddress + ).build().serializeWithHeader(envelopeHeader) + + RequestHeader.parse(envelopeBuffer) + + var requestContext = new RequestContext(envelopeHeader, "1", InetAddress.getLocalHost, + KafkaPrincipal.ANONYMOUS, listenerName, SecurityProtocol.PLAINTEXT, ClientInformation.EMPTY, + fromPrivilegedListener, Optional.of(principalSerde)) + + if (shouldSpyRequestContext) { + requestContext = Mockito.spy(requestContext) + } + + new RequestChannel.Request( + processor = 1, + context = requestContext, + startTimeNanos = startTimeNanos, + memoryPool = MemoryPool.NONE, + buffer = envelopeBuffer, + metrics = requestChannelMetrics, + envelope = envelope + ) + } + + def verifyNoUnexpectedThreads(context: String): Unit = { + // Threads which may cause transient failures in subsequent tests if not shutdown. + // These include threads which make connections to brokers and may cause issues + // when broker ports are reused (e.g. auto-create topics) as well as threads + // which reset static JAAS configuration. + val unexpectedThreadNames = Set( + ControllerEventManager.ControllerEventThreadName, + KafkaProducer.NETWORK_THREAD_PREFIX, + AdminClientUnitTestEnv.kafkaAdminClientNetworkThreadPrefix(), + AbstractCoordinator.HEARTBEAT_THREAD_PREFIX, + QuorumTestHarness.ZkClientEventThreadSuffix, + QuorumController.CONTROLLER_THREAD_SUFFIX + ) + + def unexpectedThreads: Set[String] = { + val allThreads = Thread.getAllStackTraces.keySet.asScala.map(thread => thread.getName) + allThreads.filter(t => unexpectedThreadNames.exists(s => t.contains(s))).toSet + } + + val (unexpected, _) = TestUtils.computeUntilTrue(unexpectedThreads)(_.isEmpty) + assertTrue(unexpected.isEmpty, + s"Found ${unexpected.size} unexpected threads during $context: " + + s"${unexpected.mkString("`", ",", "`")}") + } + +} diff --git a/core/src/test/scala/unit/kafka/utils/ThrottlerTest.scala b/core/src/test/scala/unit/kafka/utils/ThrottlerTest.scala new file mode 100755 index 0000000..1591cba --- /dev/null +++ b/core/src/test/scala/unit/kafka/utils/ThrottlerTest.scala @@ -0,0 +1,61 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.utils + +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.Assertions.{assertTrue, assertEquals} + + +class ThrottlerTest { + @Test + def testThrottleDesiredRate(): Unit = { + val throttleCheckIntervalMs = 100 + val desiredCountPerSec = 1000.0 + val desiredCountPerInterval = desiredCountPerSec * throttleCheckIntervalMs / 1000.0 + + val mockTime = new MockTime() + val throttler = new Throttler(desiredRatePerSec = desiredCountPerSec, + checkIntervalMs = throttleCheckIntervalMs, + time = mockTime) + + // Observe desiredCountPerInterval at t1 + val t1 = mockTime.milliseconds() + throttler.maybeThrottle(desiredCountPerInterval) + assertEquals(t1, mockTime.milliseconds()) + + // Observe desiredCountPerInterval at t1 + throttleCheckIntervalMs + 1, + mockTime.sleep(throttleCheckIntervalMs + 1) + throttler.maybeThrottle(desiredCountPerInterval) + val t2 = mockTime.milliseconds() + assertTrue(t2 >= t1 + 2 * throttleCheckIntervalMs) + + // Observe desiredCountPerInterval at t2 + throttler.maybeThrottle(desiredCountPerInterval) + assertEquals(t2, mockTime.milliseconds()) + + // Observe desiredCountPerInterval at t2 + throttleCheckIntervalMs + 1 + mockTime.sleep(throttleCheckIntervalMs + 1) + throttler.maybeThrottle(desiredCountPerInterval) + val t3 = mockTime.milliseconds() + assertTrue(t3 >= t2 + 2 * throttleCheckIntervalMs) + + val elapsedTimeMs = t3 - t1 + val actualCountPerSec = 4 * desiredCountPerInterval * 1000 / elapsedTimeMs + assertTrue(actualCountPerSec <= desiredCountPerSec) + } +} diff --git a/core/src/test/scala/unit/kafka/utils/TopicFilterTest.scala b/core/src/test/scala/unit/kafka/utils/TopicFilterTest.scala new file mode 100644 index 0000000..ea728dc --- /dev/null +++ b/core/src/test/scala/unit/kafka/utils/TopicFilterTest.scala @@ -0,0 +1,49 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package kafka.utils + +import org.apache.kafka.common.internals.Topic +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.Test + +class TopicFilterTest { + + @Test + def testIncludeLists(): Unit = { + + val topicFilter1 = IncludeList("yes1,yes2") + assertTrue(topicFilter1.isTopicAllowed("yes2", excludeInternalTopics = true)) + assertTrue(topicFilter1.isTopicAllowed("yes2", excludeInternalTopics = false)) + assertFalse(topicFilter1.isTopicAllowed("no1", excludeInternalTopics = true)) + assertFalse(topicFilter1.isTopicAllowed("no1", excludeInternalTopics = false)) + + val topicFilter2 = IncludeList(".+") + assertTrue(topicFilter2.isTopicAllowed("alltopics", excludeInternalTopics = true)) + assertFalse(topicFilter2.isTopicAllowed(Topic.GROUP_METADATA_TOPIC_NAME, excludeInternalTopics = true)) + assertTrue(topicFilter2.isTopicAllowed(Topic.GROUP_METADATA_TOPIC_NAME, excludeInternalTopics = false)) + + val topicFilter3 = IncludeList("included-topic.+") + assertTrue(topicFilter3.isTopicAllowed("included-topic1", excludeInternalTopics = true)) + assertFalse(topicFilter3.isTopicAllowed("no1", excludeInternalTopics = true)) + + val topicFilter4 = IncludeList("test-(?!bad\\b)[\\w]+") + assertTrue(topicFilter4.isTopicAllowed("test-good", excludeInternalTopics = true)) + assertFalse(topicFilter4.isTopicAllowed("test-bad", excludeInternalTopics = true)) + } + +} diff --git a/core/src/test/scala/unit/kafka/utils/json/JsonValueTest.scala b/core/src/test/scala/unit/kafka/utils/json/JsonValueTest.scala new file mode 100644 index 0000000..8194b29 --- /dev/null +++ b/core/src/test/scala/unit/kafka/utils/json/JsonValueTest.scala @@ -0,0 +1,212 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package kafka.utils.json + +import scala.collection.Seq + +import com.fasterxml.jackson.databind.{ObjectMapper, JsonMappingException} +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.Assertions._ + +import kafka.utils.Json + +class JsonValueTest { + + private val json = """ + |{ + | "boolean": false, + | "int": 1234, + | "long": 3000000000, + | "double": 16.244355, + | "string": "string", + | "number_as_string": "123", + | "array": [4.0, 11.1, 44.5], + | "object": { + | "a": true, + | "b": false + | }, + | "null": null + |} + """.stripMargin + + private def parse(s: String): JsonValue = + Json.parseFull(s).getOrElse(sys.error("Failed to parse json: " + s)) + + private def assertTo[T: DecodeJson](expected: T, jsonValue: JsonObject => JsonValue): Unit = { + val parsed = jsonValue(parse(json).asJsonObject) + assertEquals(Right(expected), parsed.toEither[T]) + assertEquals(expected, parsed.to[T]) + } + + private def assertToFails[T: DecodeJson](jsonValue: JsonObject => JsonValue): Unit = { + val parsed = jsonValue(parse(json).asJsonObject) + assertTrue(parsed.toEither[T].isLeft) + assertThrow[JsonMappingException](parsed.to[T]) + } + + def assertThrow[E <: Throwable : Manifest](body: => Unit): Unit = { + import scala.util.control.Exception._ + val klass = manifest[E].runtimeClass + catchingPromiscuously(klass).opt(body).foreach { _ => + fail("Expected `" + klass + "` to be thrown, but no exception was thrown") + } + } + + @Test + def testAsJsonObject(): Unit = { + val parsed = parse(json).asJsonObject + val obj = parsed("object") + assertEquals(obj, obj.asJsonObject) + assertThrow[JsonMappingException](parsed("array").asJsonObject) + } + + @Test + def testAsJsonObjectOption(): Unit = { + val parsed = parse(json).asJsonObject + assertTrue(parsed("object").asJsonObjectOption.isDefined) + assertEquals(None, parsed("array").asJsonObjectOption) + } + + @Test + def testAsJsonArray(): Unit = { + val parsed = parse(json).asJsonObject + val array = parsed("array") + assertEquals(array, array.asJsonArray) + assertThrow[JsonMappingException](parsed("object").asJsonArray) + } + + @Test + def testAsJsonArrayOption(): Unit = { + val parsed = parse(json).asJsonObject + assertTrue(parsed("array").asJsonArrayOption.isDefined) + assertEquals(None, parsed("object").asJsonArrayOption) + } + + @Test + def testJsonObjectGet(): Unit = { + val parsed = parse(json).asJsonObject + assertEquals(Some(parse("""{"a":true,"b":false}""")), parsed.get("object")) + assertEquals(None, parsed.get("aaaaa")) + } + + @Test + def testJsonObjectApply(): Unit = { + val parsed = parse(json).asJsonObject + assertEquals(parse("""{"a":true,"b":false}"""), parsed("object")) + assertThrow[JsonMappingException](parsed("aaaaaaaa")) + } + + @Test + def testJsonObjectIterator(): Unit = { + assertEquals( + Vector("a" -> parse("true"), "b" -> parse("false")), + parse(json).asJsonObject("object").asJsonObject.iterator.toVector + ) + } + + @Test + def testJsonArrayIterator(): Unit = { + assertEquals(Vector("4.0", "11.1", "44.5").map(parse), parse(json).asJsonObject("array").asJsonArray.iterator.toVector) + } + + @Test + def testJsonValueEquals(): Unit = { + + assertEquals(parse(json), parse(json)) + + assertEquals(parse("""{"blue": true, "red": false}"""), parse("""{"red": false, "blue": true}""")) + assertNotEquals(parse("""{"blue": true, "red": true}"""), parse("""{"red": false, "blue": true}""")) + + assertEquals(parse("""[1, 2, 3]"""), parse("""[1, 2, 3]""")) + assertNotEquals(parse("""[1, 2, 3]"""), parse("""[2, 1, 
3]""")) + + assertEquals(parse("1344"), parse("1344")) + assertNotEquals(parse("1344"), parse("144")) + + } + + @Test + def testJsonValueHashCode(): Unit = { + assertEquals(new ObjectMapper().readTree(json).hashCode, parse(json).hashCode) + } + + @Test + def testJsonValueToString(): Unit = { + val js = """{"boolean":false,"int":1234,"array":[4.0,11.1,44.5],"object":{"a":true,"b":false}}""" + assertEquals(js, parse(js).toString) + } + + @Test + def testDecodeBoolean(): Unit = { + assertTo[Boolean](false, _("boolean")) + assertToFails[Boolean](_("int")) + } + + @Test + def testDecodeString(): Unit = { + assertTo[String]("string", _("string")) + assertTo[String]("123", _("number_as_string")) + assertToFails[String](_("int")) + assertToFails[String](_("array")) + } + + @Test + def testDecodeInt(): Unit = { + assertTo[Int](1234, _("int")) + assertToFails[Int](_("long")) + } + + @Test + def testDecodeLong(): Unit = { + assertTo[Long](3000000000L, _("long")) + assertTo[Long](1234, _("int")) + assertToFails[Long](_("string")) + } + + @Test + def testDecodeDouble(): Unit = { + assertTo[Double](16.244355, _("double")) + assertTo[Double](1234.0, _("int")) + assertTo[Double](3000000000L, _("long")) + assertToFails[Double](_("string")) + } + + @Test + def testDecodeSeq(): Unit = { + assertTo[Seq[Double]](Seq(4.0, 11.1, 44.5), _("array")) + assertToFails[Seq[Double]](_("string")) + assertToFails[Seq[Double]](_("object")) + assertToFails[Seq[String]](_("array")) + } + + @Test + def testDecodeMap(): Unit = { + assertTo[Map[String, Boolean]](Map("a" -> true, "b" -> false), _("object")) + assertToFails[Map[String, Int]](_("object")) + assertToFails[Map[String, String]](_("object")) + assertToFails[Map[String, Double]](_("array")) + } + + @Test + def testDecodeOption(): Unit = { + assertTo[Option[Int]](None, _("null")) + assertTo[Option[Int]](Some(1234), _("int")) + assertToFails[Option[String]](_("int")) + } + +} diff --git a/core/src/test/scala/unit/kafka/utils/timer/MockTimer.scala b/core/src/test/scala/unit/kafka/utils/timer/MockTimer.scala new file mode 100644 index 0000000..819954a --- /dev/null +++ b/core/src/test/scala/unit/kafka/utils/timer/MockTimer.scala @@ -0,0 +1,69 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package kafka.utils.timer + +import kafka.utils.MockTime + +import scala.collection.mutable + +class MockTimer(val time: MockTime = new MockTime) extends Timer { + + private val taskQueue = mutable.PriorityQueue[TimerTaskEntry]()(Ordering[TimerTaskEntry].reverse) + + def add(timerTask: TimerTask): Unit = { + if (timerTask.delayMs <= 0) + timerTask.run() + else { + taskQueue synchronized { + taskQueue.enqueue(new TimerTaskEntry(timerTask, timerTask.delayMs + time.milliseconds)) + } + } + } + + def advanceClock(timeoutMs: Long): Boolean = { + time.sleep(timeoutMs) + + var executed = false + val now = time.milliseconds + + var hasMore = true + while (hasMore) { + hasMore = false + val head = taskQueue synchronized { + if (taskQueue.nonEmpty && now > taskQueue.head.expirationMs) { + val entry = Some(taskQueue.dequeue()) + hasMore = taskQueue.nonEmpty + entry + } else + None + } + head.foreach { taskEntry => + if (!taskEntry.cancelled) { + val task = taskEntry.timerTask + task.run() + executed = true + } + } + } + executed + } + + def size: Int = taskQueue.size + + override def shutdown(): Unit = {} + +} diff --git a/core/src/test/scala/unit/kafka/utils/timer/TimerTaskListTest.scala b/core/src/test/scala/unit/kafka/utils/timer/TimerTaskListTest.scala new file mode 100644 index 0000000..011041a --- /dev/null +++ b/core/src/test/scala/unit/kafka/utils/timer/TimerTaskListTest.scala @@ -0,0 +1,93 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package kafka.utils.timer + +import org.junit.jupiter.api.Assertions._ +import java.util.concurrent.atomic._ +import org.junit.jupiter.api.Test + +class TimerTaskListTest { + + private class TestTask(val delayMs: Long) extends TimerTask { + def run(): Unit = { } + } + + private def size(list: TimerTaskList): Int = { + var count = 0 + list.foreach(_ => count += 1) + count + } + + @Test + def testAll(): Unit = { + val sharedCounter = new AtomicInteger(0) + val list1 = new TimerTaskList(sharedCounter) + val list2 = new TimerTaskList(sharedCounter) + val list3 = new TimerTaskList(sharedCounter) + + val tasks = (1 to 10).map { i => + val task = new TestTask(0L) + list1.add(new TimerTaskEntry(task, 10L)) + assertEquals(i, sharedCounter.get) + task + } + + assertEquals(tasks.size, sharedCounter.get) + + // reinserting the existing tasks shouldn't change the task count + tasks.take(4).foreach { task => + val prevCount = sharedCounter.get + // new TimerTaskEntry(task) will remove the existing entry from the list + list2.add(new TimerTaskEntry(task, 10L)) + assertEquals(prevCount, sharedCounter.get) + } + assertEquals(10 - 4, size(list1)) + assertEquals(4, size(list2)) + + assertEquals(tasks.size, sharedCounter.get) + + // reinserting the existing tasks shouldn't change the task count + tasks.drop(4).foreach { task => + val prevCount = sharedCounter.get + // new TimerTaskEntry(task) will remove the existing entry from the list + list3.add(new TimerTaskEntry(task, 10L)) + assertEquals(prevCount, sharedCounter.get) + } + assertEquals(0, size(list1)) + assertEquals(4, size(list2)) + assertEquals(6, size(list3)) + + assertEquals(tasks.size, sharedCounter.get) + + // cancel tasks in lists + list1.foreach { _.cancel() } + assertEquals(0, size(list1)) + assertEquals(4, size(list2)) + assertEquals(6, size(list3)) + + list2.foreach { _.cancel() } + assertEquals(0, size(list1)) + assertEquals(0, size(list2)) + assertEquals(6, size(list3)) + + list3.foreach { _.cancel() } + assertEquals(0, size(list1)) + assertEquals(0, size(list2)) + assertEquals(0, size(list3)) + } + +} diff --git a/core/src/test/scala/unit/kafka/utils/timer/TimerTest.scala b/core/src/test/scala/unit/kafka/utils/timer/TimerTest.scala new file mode 100644 index 0000000..2c08627 --- /dev/null +++ b/core/src/test/scala/unit/kafka/utils/timer/TimerTest.scala @@ -0,0 +1,107 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package kafka.utils.timer + +import java.util.concurrent.{CountDownLatch, TimeUnit} + +import org.junit.jupiter.api.Assertions._ +import java.util.concurrent.atomic._ +import org.junit.jupiter.api.{Test, AfterEach, BeforeEach} + +import scala.collection.mutable.ArrayBuffer + +class TimerTest { + + private class TestTask(override val delayMs: Long, id: Int, latch: CountDownLatch, output: ArrayBuffer[Int]) extends TimerTask { + private[this] val completed = new AtomicBoolean(false) + def run(): Unit = { + if (completed.compareAndSet(false, true)) { + output.synchronized { output += id } + latch.countDown() + } + } + } + + private[this] var timer: Timer = null + + @BeforeEach + def setup(): Unit = { + timer = new SystemTimer("test", tickMs = 1, wheelSize = 3) + } + + @AfterEach + def teardown(): Unit = { + timer.shutdown() + } + + @Test + def testAlreadyExpiredTask(): Unit = { + val output = new ArrayBuffer[Int]() + + + val latches = (-5 until 0).map { i => + val latch = new CountDownLatch(1) + timer.add(new TestTask(i, i, latch, output)) + latch + } + + timer.advanceClock(0) + + latches.take(5).foreach { latch => + assertEquals(true, latch.await(3, TimeUnit.SECONDS), "already expired tasks should run immediately") + } + + assertEquals(Set(-5, -4, -3, -2, -1), output.toSet, "output of already expired tasks") + } + + @Test + def testTaskExpiration(): Unit = { + val output = new ArrayBuffer[Int]() + + val tasks = new ArrayBuffer[TestTask]() + val ids = new ArrayBuffer[Int]() + + val latches = + (0 until 5).map { i => + val latch = new CountDownLatch(1) + tasks += new TestTask(i, i, latch, output) + ids += i + latch + } ++ (10 until 100).map { i => + val latch = new CountDownLatch(2) + tasks += new TestTask(i, i, latch, output) + tasks += new TestTask(i, i, latch, output) + ids += i + ids += i + latch + } ++ (100 until 500).map { i => + val latch = new CountDownLatch(1) + tasks += new TestTask(i, i, latch, output) + ids += i + latch + } + + // randomly submit requests + tasks.foreach { task => timer.add(task) } + + while (timer.advanceClock(2000)) {} + + latches.foreach { latch => latch.await() } + + assertEquals(ids.sorted, output.toSeq, "output should match") + } +} diff --git a/core/src/test/scala/unit/kafka/zk/AdminZkClientTest.scala b/core/src/test/scala/unit/kafka/zk/AdminZkClientTest.scala new file mode 100644 index 0000000..5e9817b --- /dev/null +++ b/core/src/test/scala/unit/kafka/zk/AdminZkClientTest.scala @@ -0,0 +1,360 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package kafka.admin + +import java.util +import java.util.Properties + +import kafka.controller.ReplicaAssignment +import kafka.log._ +import kafka.server.DynamicConfig.Broker._ +import kafka.server.KafkaConfig._ +import kafka.server.{ConfigType, KafkaConfig, KafkaServer, QuorumTestHarness} +import kafka.utils.CoreUtils._ +import kafka.utils.TestUtils._ +import kafka.utils.{Logging, TestUtils} +import kafka.zk.{AdminZkClient, KafkaZkClient} +import org.apache.kafka.common.TopicPartition +import org.apache.kafka.common.config.TopicConfig +import org.apache.kafka.common.errors.{InvalidReplicaAssignmentException, InvalidTopicException, TopicExistsException} +import org.apache.kafka.common.metrics.Quota +import org.apache.kafka.test.{TestUtils => JTestUtils} +import org.easymock.EasyMock +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.{AfterEach, Test} + +import scala.jdk.CollectionConverters._ +import scala.collection.{Map, Seq, immutable} + +class AdminZkClientTest extends QuorumTestHarness with Logging with RackAwareTest { + + var servers: Seq[KafkaServer] = Seq() + + @AfterEach + override def tearDown(): Unit = { + TestUtils.shutdownServers(servers) + super.tearDown() + } + + @Test + def testManualReplicaAssignment(): Unit = { + val brokers = List(0, 1, 2, 3, 4) + TestUtils.createBrokersInZk(zkClient, brokers) + + val topicConfig = new Properties() + + // duplicate brokers + assertThrows(classOf[InvalidReplicaAssignmentException], () => adminZkClient.createTopicWithAssignment("test", topicConfig, Map(0->Seq(0,0)))) + + // inconsistent replication factor + assertThrows(classOf[InvalidReplicaAssignmentException], () => adminZkClient.createTopicWithAssignment("test", topicConfig, Map(0->Seq(0,1), 1->Seq(0)))) + + // partitions should be 0-based + assertThrows(classOf[InvalidReplicaAssignmentException], () => adminZkClient.createTopicWithAssignment("test", topicConfig, Map(1->Seq(1,2), 2->Seq(1,2)))) + + // partitions should be 0-based and consecutive + assertThrows(classOf[InvalidReplicaAssignmentException], () => adminZkClient.createTopicWithAssignment("test", topicConfig, Map(0->Seq(1,2), 0->Seq(1,2), 3->Seq(1,2)))) + + // partitions should be 0-based and consecutive + assertThrows(classOf[InvalidReplicaAssignmentException], () => adminZkClient.createTopicWithAssignment("test", topicConfig, Map(-1->Seq(1,2), 1->Seq(1,2), 2->Seq(1,2), 4->Seq(1,2)))) + + // good assignment + val assignment = Map(0 -> List(0, 1, 2), + 1 -> List(1, 2, 3)) + adminZkClient.createTopicWithAssignment("test", topicConfig, assignment) + val found = zkClient.getPartitionAssignmentForTopics(Set("test")) + assertEquals(assignment.map { case (k, v) => k -> ReplicaAssignment(v, List(), List()) }, found("test")) + } + + @Test + def testTopicCreationInZK(): Unit = { + val expectedReplicaAssignment = Map( + 0 -> List(0, 1, 2), + 1 -> List(1, 2, 3), + 2 -> List(2, 3, 4), + 3 -> List(3, 4, 0), + 4 -> List(4, 0, 1), + 5 -> List(0, 2, 3), + 6 -> List(1, 3, 4), + 7 -> List(2, 4, 0), + 8 -> List(3, 0, 1), + 9 -> List(4, 1, 2), + 10 -> List(1, 2, 3), + 11 -> List(1, 3, 4) + ) + val leaderForPartitionMap = immutable.Map( + 0 -> 0, + 1 -> 1, + 2 -> 2, + 3 -> 3, + 4 -> 4, + 5 -> 0, + 6 -> 1, + 7 -> 2, + 8 -> 3, + 9 -> 4, + 10 -> 1, + 11 -> 1 + ) + val topic = "test" + val topicConfig = new Properties() + TestUtils.createBrokersInZk(zkClient, List(0, 1, 2, 3, 4)) + // create the topic + adminZkClient.createTopicWithAssignment(topic, topicConfig, expectedReplicaAssignment) + // create leaders for all 
partitions + TestUtils.makeLeaderForPartition(zkClient, topic, leaderForPartitionMap, 1) + val actualReplicaMap = leaderForPartitionMap.keys.map(p => p -> zkClient.getReplicasForPartition(new TopicPartition(topic, p))).toMap + assertEquals(expectedReplicaAssignment.size, actualReplicaMap.size) + for(i <- 0 until actualReplicaMap.size) + assertEquals(expectedReplicaAssignment.get(i).get, actualReplicaMap(i)) + + // shouldn't be able to create a topic that already exists + assertThrows(classOf[TopicExistsException], () => adminZkClient.createTopicWithAssignment(topic, topicConfig, expectedReplicaAssignment)) + } + + @Test + def testTopicCreationWithCollision(): Unit = { + val topic = "test.topic" + val collidingTopic = "test_topic" + TestUtils.createBrokersInZk(zkClient, List(0, 1, 2, 3, 4)) + // create the topic + adminZkClient.createTopic(topic, 3, 1) + + // shouldn't be able to create a topic that collides + assertThrows(classOf[InvalidTopicException], () => adminZkClient.createTopic(collidingTopic, 3, 1)) + } + + @Test + def testMarkedDeletionTopicCreation(): Unit = { + val zkMock: KafkaZkClient = EasyMock.createNiceMock(classOf[KafkaZkClient]) + val topicPartition = new TopicPartition("test", 0) + val topic = topicPartition.topic + EasyMock.expect(zkMock.isTopicMarkedForDeletion(topic)).andReturn(true); + EasyMock.replay(zkMock) + val adminZkClient = new AdminZkClient(zkMock) + + assertThrows(classOf[TopicExistsException], () => adminZkClient.validateTopicCreate(topic, Map.empty, new Properties)) + } + + @Test + def testMockedConcurrentTopicCreation(): Unit = { + val topic = "test.topic" + + // simulate the ZK interactions that can happen when a topic is concurrently created by multiple processes + val zkMock: KafkaZkClient = EasyMock.createNiceMock(classOf[KafkaZkClient]) + EasyMock.expect(zkMock.topicExists(topic)).andReturn(false) + EasyMock.expect(zkMock.getAllTopicsInCluster(false)).andReturn(Set("some.topic", topic, "some.other.topic")) + EasyMock.replay(zkMock) + val adminZkClient = new AdminZkClient(zkMock) + + assertThrows(classOf[TopicExistsException], () => adminZkClient.validateTopicCreate(topic, Map.empty, new Properties)) + } + + @Test + def testConcurrentTopicCreation(): Unit = { + val topic = "test-concurrent-topic-creation" + TestUtils.createBrokersInZk(zkClient, List(0, 1, 2, 3, 4)) + val props = new Properties + props.setProperty(TopicConfig.MIN_IN_SYNC_REPLICAS_CONFIG, "2") + def createTopic(): Unit = { + try adminZkClient.createTopic(topic, 3, 1, props) + catch { case _: TopicExistsException => () } + val (_, partitionAssignment) = zkClient.getPartitionAssignmentForTopics(Set(topic)).head + assertEquals(3, partitionAssignment.size) + partitionAssignment.foreach { case (partition, partitionReplicaAssignment) => + assertEquals(1, partitionReplicaAssignment.replicas.size, s"Unexpected replication factor for $partition") + } + val savedProps = zkClient.getEntityConfigs(ConfigType.Topic, topic) + assertEquals(props, savedProps) + } + + TestUtils.assertConcurrent("Concurrent topic creation failed", Seq(() => createTopic(), () => createTopic()), + JTestUtils.DEFAULT_MAX_WAIT_MS.toInt) + } + + /** + * This test creates a topic with a few config overrides and checks that the configs are applied to the new topic + * then changes the config and checks that the new values take effect. 
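+   * It also checks that the leader quota manager's replication throttle only takes effect once the config
+   * change is applied dynamically, not at topic creation time.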
+ */ + @Test + def testTopicConfigChange(): Unit = { + val partitions = 3 + val topic = "my-topic" + val server = TestUtils.createServer(KafkaConfig.fromProps(TestUtils.createBrokerConfig(0, zkConnect))) + servers = Seq(server) + + def makeConfig(messageSize: Int, retentionMs: Long, throttledLeaders: String, throttledFollowers: String) = { + val props = new Properties() + props.setProperty(LogConfig.MaxMessageBytesProp, messageSize.toString) + props.setProperty(LogConfig.RetentionMsProp, retentionMs.toString) + props.setProperty(LogConfig.LeaderReplicationThrottledReplicasProp, throttledLeaders) + props.setProperty(LogConfig.FollowerReplicationThrottledReplicasProp, throttledFollowers) + props + } + + def checkConfig(messageSize: Int, retentionMs: Long, throttledLeaders: String, throttledFollowers: String, quotaManagerIsThrottled: Boolean): Unit = { + def checkList(actual: util.List[String], expected: String): Unit = { + assertNotNull(actual) + if (expected == "") + assertTrue(actual.isEmpty) + else + assertEquals(expected.split(",").toSeq, actual.asScala) + } + TestUtils.retry(10000) { + for (part <- 0 until partitions) { + val tp = new TopicPartition(topic, part) + val log = server.logManager.getLog(tp) + assertTrue(log.isDefined) + assertEquals(retentionMs, log.get.config.retentionMs) + assertEquals(messageSize, log.get.config.maxMessageSize) + checkList(log.get.config.LeaderReplicationThrottledReplicas, throttledLeaders) + checkList(log.get.config.FollowerReplicationThrottledReplicas, throttledFollowers) + assertEquals(quotaManagerIsThrottled, server.quotaManagers.leader.isThrottled(tp)) + } + } + } + + // create a topic with a few config overrides and check that they are applied + val maxMessageSize = 1024 + val retentionMs = 1000 * 1000 + adminZkClient.createTopic(topic, partitions, 1, makeConfig(maxMessageSize, retentionMs, "0:0,1:0,2:0", "0:1,1:1,2:1")) + + //Standard topic configs will be propagated at topic creation time, but the quota manager will not have been updated. 
+ checkConfig(maxMessageSize, retentionMs, "0:0,1:0,2:0", "0:1,1:1,2:1", false) + + //Update dynamically and all properties should be applied + adminZkClient.changeTopicConfig(topic, makeConfig(maxMessageSize, retentionMs, "0:0,1:0,2:0", "0:1,1:1,2:1")) + + checkConfig(maxMessageSize, retentionMs, "0:0,1:0,2:0", "0:1,1:1,2:1", true) + + // now double the config values for the topic and check that it is applied + val newConfig = makeConfig(2 * maxMessageSize, 2 * retentionMs, "*", "*") + adminZkClient.changeTopicConfig(topic, makeConfig(2 * maxMessageSize, 2 * retentionMs, "*", "*")) + checkConfig(2 * maxMessageSize, 2 * retentionMs, "*", "*", quotaManagerIsThrottled = true) + + // Verify that the same config can be read from ZK + val configInZk = adminZkClient.fetchEntityConfig(ConfigType.Topic, topic) + assertEquals(newConfig, configInZk) + + //Now delete the config + adminZkClient.changeTopicConfig(topic, new Properties) + checkConfig(Defaults.MaxMessageSize, Defaults.RetentionMs, "", "", quotaManagerIsThrottled = false) + + //Add config back + adminZkClient.changeTopicConfig(topic, makeConfig(maxMessageSize, retentionMs, "0:0,1:0,2:0", "0:1,1:1,2:1")) + checkConfig(maxMessageSize, retentionMs, "0:0,1:0,2:0", "0:1,1:1,2:1", quotaManagerIsThrottled = true) + + //Now ensure updating to "" removes the throttled replica list also + adminZkClient.changeTopicConfig(topic, propsWith((LogConfig.FollowerReplicationThrottledReplicasProp, ""), (LogConfig.LeaderReplicationThrottledReplicasProp, ""))) + checkConfig(Defaults.MaxMessageSize, Defaults.RetentionMs, "", "", quotaManagerIsThrottled = false) + } + + @Test + def shouldPropagateDynamicBrokerConfigs(): Unit = { + val brokerIds = Seq(0, 1, 2) + servers = createBrokerConfigs(3, zkConnect).map(fromProps).map(createServer(_)) + + def checkConfig(limit: Long): Unit = { + retry(10000) { + for (server <- servers) { + assertEquals(limit, server.quotaManagers.leader.upperBound, "Leader Quota Manager was not updated") + assertEquals(limit, server.quotaManagers.follower.upperBound, "Follower Quota Manager was not updated") + } + } + } + + val limit: Long = 1000000 + + // Set the limit & check it is applied to the log + adminZkClient.changeBrokerConfig(brokerIds, propsWith( + (LeaderReplicationThrottledRateProp, limit.toString), + (FollowerReplicationThrottledRateProp, limit.toString))) + checkConfig(limit) + + // Now double the config values for the topic and check that it is applied + val newLimit = 2 * limit + adminZkClient.changeBrokerConfig(brokerIds, propsWith( + (LeaderReplicationThrottledRateProp, newLimit.toString), + (FollowerReplicationThrottledRateProp, newLimit.toString))) + checkConfig(newLimit) + + // Verify that the same config can be read from ZK + for (brokerId <- brokerIds) { + val configInZk = adminZkClient.fetchEntityConfig(ConfigType.Broker, brokerId.toString) + assertEquals(newLimit, configInZk.getProperty(LeaderReplicationThrottledRateProp).toInt) + assertEquals(newLimit, configInZk.getProperty(FollowerReplicationThrottledRateProp).toInt) + } + + //Now delete the config + adminZkClient.changeBrokerConfig(brokerIds, new Properties) + checkConfig(DefaultReplicationThrottledRate) + } + + /** + * This test simulates a client config change in ZK whose notification has been purged. 
+ * Basically, it asserts that notifications are bootstrapped from ZK + */ + @Test + def testBootstrapClientIdConfig(): Unit = { + val clientId = "my-client" + val props = new Properties() + props.setProperty("producer_byte_rate", "1000") + props.setProperty("consumer_byte_rate", "2000") + + // Write config without notification to ZK. + zkClient.setOrCreateEntityConfigs(ConfigType.Client, clientId, props) + + val configInZk: Map[String, Properties] = adminZkClient.fetchAllEntityConfigs(ConfigType.Client) + assertEquals(1, configInZk.size, "Must have 1 overridden client config") + assertEquals(props, configInZk(clientId)) + + // Test that the existing clientId overrides are read + val server = TestUtils.createServer(KafkaConfig.fromProps(TestUtils.createBrokerConfig(0, zkConnect))) + servers = Seq(server) + assertEquals(new Quota(1000, true), server.dataPlaneRequestProcessor.quotas.produce.quota("ANONYMOUS", clientId)) + assertEquals(new Quota(2000, true), server.dataPlaneRequestProcessor.quotas.fetch.quota("ANONYMOUS", clientId)) + } + + @Test + def testGetBrokerMetadatas(): Unit = { + // broker 4 has no rack information + val brokerList = 0 to 5 + val rackInfo = Map(0 -> "rack1", 1 -> "rack2", 2 -> "rack2", 3 -> "rack1", 5 -> "rack3") + val brokerMetadatas = toBrokerMetadata(rackInfo, brokersWithoutRack = brokerList.filterNot(rackInfo.keySet)) + TestUtils.createBrokersInZk(brokerMetadatas, zkClient) + + val processedMetadatas1 = adminZkClient.getBrokerMetadatas(RackAwareMode.Disabled) + assertEquals(brokerList, processedMetadatas1.map(_.id)) + assertEquals(List.fill(brokerList.size)(None), processedMetadatas1.map(_.rack)) + + val processedMetadatas2 = adminZkClient.getBrokerMetadatas(RackAwareMode.Safe) + assertEquals(brokerList, processedMetadatas2.map(_.id)) + assertEquals(List.fill(brokerList.size)(None), processedMetadatas2.map(_.rack)) + + assertThrows(classOf[AdminOperationException], () => adminZkClient.getBrokerMetadatas(RackAwareMode.Enforced)) + + val partialList = List(0, 1, 2, 3, 5) + val processedMetadatas3 = adminZkClient.getBrokerMetadatas(RackAwareMode.Enforced, Some(partialList)) + assertEquals(partialList, processedMetadatas3.map(_.id)) + assertEquals(partialList.map(rackInfo), processedMetadatas3.flatMap(_.rack)) + + val numPartitions = 3 + adminZkClient.createTopic("foo", numPartitions, 2, rackAwareMode = RackAwareMode.Safe) + val assignment = zkClient.getReplicaAssignmentForTopics(Set("foo")) + assertEquals(numPartitions, assignment.size) + } +} diff --git a/core/src/test/scala/unit/kafka/zk/EmbeddedZookeeper.scala b/core/src/test/scala/unit/kafka/zk/EmbeddedZookeeper.scala new file mode 100755 index 0000000..28b592e --- /dev/null +++ b/core/src/test/scala/unit/kafka/zk/EmbeddedZookeeper.scala @@ -0,0 +1,69 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.zk + +import org.apache.zookeeper.server.ZooKeeperServer +import org.apache.zookeeper.server.NIOServerCnxnFactory +import kafka.utils.{CoreUtils, Logging, TestUtils} +import java.net.InetSocketAddress + +import org.apache.kafka.common.utils.Utils + +/** + * ZooKeeperServer wrapper that starts the server with temporary directories during construction and deletes + * the directories when `shutdown()` is called. + * + * This is an internal class and it's subject to change. We recommend that you implement your own simple wrapper + * if you need similar functionality. + */ +// This should be named EmbeddedZooKeeper for consistency with other classes, but since this is widely used by other +// projects (even though it's internal), we keep the name as it is until we have a publicly supported test library for +// others to use. +class EmbeddedZookeeper() extends Logging { + + val snapshotDir = TestUtils.tempDir() + val logDir = TestUtils.tempDir() + val tickTime = 800 // allow a maxSessionTimeout of 20 * 800ms = 16 secs + + System.setProperty("zookeeper.forceSync", "no") //disable fsync to ZK txn log in tests to avoid timeout + val zookeeper = new ZooKeeperServer(snapshotDir, logDir, tickTime) + val factory = new NIOServerCnxnFactory() + private val addr = new InetSocketAddress("127.0.0.1", TestUtils.RandomPort) + factory.configure(addr, 0) + factory.startup(zookeeper) + val port = zookeeper.getClientPort + + def shutdown(): Unit = { + // Also shuts down ZooKeeperServer + CoreUtils.swallow(factory.shutdown(), this) + + def isDown(): Boolean = { + try { + ZkFourLetterWords.sendStat("127.0.0.1", port, 3000) + false + } catch { case _: Throwable => true } + } + + Iterator.continually(isDown()).exists(identity) + CoreUtils.swallow(zookeeper.getZKDatabase().close(), this) + + Utils.delete(logDir) + Utils.delete(snapshotDir) + } + +} diff --git a/core/src/test/scala/unit/kafka/zk/KafkaZkClientTest.scala b/core/src/test/scala/unit/kafka/zk/KafkaZkClientTest.scala new file mode 100644 index 0000000..6be954d --- /dev/null +++ b/core/src/test/scala/unit/kafka/zk/KafkaZkClientTest.scala @@ -0,0 +1,1402 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+*/ +package kafka.zk + +import java.util.{Collections, Properties} +import java.nio.charset.StandardCharsets.UTF_8 +import java.util.concurrent.{CountDownLatch, TimeUnit} +import kafka.api.{ApiVersion, LeaderAndIsr} +import kafka.cluster.{Broker, EndPoint} +import kafka.log.LogConfig +import kafka.server.{ConfigType, KafkaConfig, QuorumTestHarness} +import kafka.utils.CoreUtils +import org.apache.kafka.common.{TopicPartition, Uuid} +import org.apache.kafka.common.network.ListenerName +import org.apache.kafka.common.security.auth.{KafkaPrincipal, SecurityProtocol} +import org.apache.kafka.common.security.token.delegation.TokenInformation +import org.apache.kafka.common.utils.{SecurityUtils, Time} +import org.apache.zookeeper.KeeperException.{Code, NoAuthException, NoNodeException, NodeExistsException} +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.{AfterEach, BeforeEach, Test, TestInfo} + +import scala.jdk.CollectionConverters._ +import scala.collection.mutable.ArrayBuffer +import scala.collection.{Seq, mutable} +import scala.util.Random +import kafka.controller.{LeaderIsrAndControllerEpoch, ReplicaAssignment} +import kafka.security.authorizer.AclEntry +import kafka.zk.KafkaZkClient.UpdateLeaderAndIsrResult +import kafka.zookeeper._ +import org.apache.kafka.common.acl.AclOperation.READ +import org.apache.kafka.common.acl.AclPermissionType.{ALLOW, DENY} +import org.apache.kafka.common.errors.ControllerMovedException +import org.apache.kafka.common.feature.{Features, SupportedVersionRange} +import org.apache.kafka.common.feature.Features._ +import org.apache.kafka.common.resource.ResourcePattern +import org.apache.kafka.common.resource.ResourceType.{GROUP, TOPIC} +import org.apache.kafka.common.security.JaasUtils +import org.apache.zookeeper.ZooDefs +import org.apache.zookeeper.client.ZKClientConfig +import org.apache.zookeeper.common.ZKConfig +import org.apache.zookeeper.data.Stat +import org.junit.jupiter.params.ParameterizedTest +import org.junit.jupiter.params.provider.ValueSource + +class KafkaZkClientTest extends QuorumTestHarness { + + private val group = "my-group" + private val topic1 = "topic1" + private val topic2 = "topic2" + private val topicIds = Map(topic1 -> Uuid.randomUuid(), topic2 -> Uuid.randomUuid()) + + val topicPartition10 = new TopicPartition(topic1, 0) + val topicPartition11 = new TopicPartition(topic1, 1) + val topicPartition20 = new TopicPartition(topic2, 0) + val topicPartitions10_11 = Seq(topicPartition10, topicPartition11) + val controllerEpochZkVersion = 0 + + var otherZkClient: KafkaZkClient = _ + var expiredSessionZkClient: ExpiredKafkaZkClient = _ + + @BeforeEach + override def setUp(testInfo: TestInfo): Unit = { + super.setUp(testInfo) + zkClient.createControllerEpochRaw(1) + otherZkClient = KafkaZkClient(zkConnect, zkAclsEnabled.getOrElse(JaasUtils.isZkSaslEnabled), zkSessionTimeout, + zkConnectionTimeout, zkMaxInFlightRequests, Time.SYSTEM, name = "KafkaZkClient", + zkClientConfig = new ZKClientConfig) + expiredSessionZkClient = ExpiredKafkaZkClient(zkConnect, zkAclsEnabled.getOrElse(JaasUtils.isZkSaslEnabled), + zkSessionTimeout, zkConnectionTimeout, zkMaxInFlightRequests, Time.SYSTEM) + } + + @AfterEach + override def tearDown(): Unit = { + if (otherZkClient != null) + otherZkClient.close() + zkClient.deletePath(ControllerEpochZNode.path) + if (expiredSessionZkClient != null) + expiredSessionZkClient.close() + super.tearDown() + } + + private val topicPartition = new TopicPartition("topic", 0) + + @Test + def 
testConnectionViaNettyClient(): Unit = { + // Confirm that we can explicitly set client connection configuration, which is necessary for TLS. + // TLS connectivity itself is tested in system tests rather than here to avoid having to add TLS support + // to kafka.zk.EmbeddedZookeeper + val clientConfig = new ZKClientConfig() + val propKey = KafkaConfig.ZkClientCnxnSocketProp + val propVal = "org.apache.zookeeper.ClientCnxnSocketNetty" + KafkaConfig.setZooKeeperClientProperty(clientConfig, propKey, propVal) + val client = KafkaZkClient(zkConnect, zkAclsEnabled.getOrElse(JaasUtils.isZkSaslEnabled), zkSessionTimeout, + zkConnectionTimeout, zkMaxInFlightRequests, Time.SYSTEM, name = "KafkaZkClient", zkClientConfig = clientConfig) + try { + assertEquals(Some(propVal), KafkaConfig.zooKeeperClientProperty(client.currentZooKeeper.getClientConfig, propKey)) + // For a sanity check, make sure a bad client connection socket class name generates an exception + val badClientConfig = new ZKClientConfig() + KafkaConfig.setZooKeeperClientProperty(badClientConfig, propKey, propVal + "BadClassName") + assertThrows(classOf[Exception], + () => KafkaZkClient(zkConnect, zkAclsEnabled.getOrElse(JaasUtils.isZkSaslEnabled), zkSessionTimeout, + zkConnectionTimeout, zkMaxInFlightRequests, Time.SYSTEM, name = "KafkaZkClientTest", zkClientConfig = badClientConfig)) + } finally { + client.close() + } + } + + @ParameterizedTest + @ValueSource(booleans = Array(true, false)) + def testChroot(createChrootIfNecessary: Boolean): Unit = { + val chroot = "/chroot" + val client = KafkaZkClient(zkConnect + chroot, zkAclsEnabled.getOrElse(JaasUtils.isZkSaslEnabled), zkSessionTimeout, + zkConnectionTimeout, zkMaxInFlightRequests, Time.SYSTEM, name = "KafkaZkClientTest", + zkClientConfig = new ZKClientConfig, createChrootIfNecessary = createChrootIfNecessary) + try { + client.createTopLevelPaths() + if (!createChrootIfNecessary) { + fail("We should not have been able to create top-level paths with a chroot when not explicitly creating the chroot path, but we were able to do so") + } + } catch { + case e: Exception => + if (createChrootIfNecessary) { + fail("We should have been able to create top-level paths with a chroot when explicitly creating the chroot path, but we failed to do so", + e) + } + } finally { + client.close() + } + } + + @Test + def testChrootExistsAndRootIsLocked(): Unit = { + // chroot is accessible + val root = "/testChrootExistsAndRootIsLocked" + val chroot = s"$root/chroot" + + zkClient.makeSurePersistentPathExists(chroot) + zkClient.setAcl(chroot, ZooDefs.Ids.OPEN_ACL_UNSAFE.asScala) + + // root is read-only + zkClient.setAcl(root, ZooDefs.Ids.READ_ACL_UNSAFE.asScala) + + // we should not be able to create node under chroot folder + assertThrows(classOf[NoAuthException], () => zkClient.makeSurePersistentPathExists(chroot)) + + // this client doesn't have create permission to the root and chroot, but the chroot already exists + // Expect that no exception thrown + val chrootClient = KafkaZkClient(zkConnect + chroot, zkAclsEnabled.getOrElse(JaasUtils.isZkSaslEnabled), zkSessionTimeout, + zkConnectionTimeout, zkMaxInFlightRequests, Time.SYSTEM, name = "KafkaZkClientTest", + zkClientConfig = new ZKClientConfig, createChrootIfNecessary = true) + chrootClient.close() + } + + @Test + def testSetAndGetConsumerOffset(): Unit = { + val offset = 123L + // None if no committed offsets + assertTrue(zkClient.getConsumerOffset(group, topicPartition).isEmpty) + // Set and retrieve an offset + 
zkClient.setOrCreateConsumerOffset(group, topicPartition, offset) + assertEquals(offset, zkClient.getConsumerOffset(group, topicPartition).get) + // Update an existing offset and retrieve it + zkClient.setOrCreateConsumerOffset(group, topicPartition, offset + 2L) + assertEquals(offset + 2L, zkClient.getConsumerOffset(group, topicPartition).get) + } + + @Test + def testGetConsumerOffsetNoData(): Unit = { + zkClient.createRecursive(ConsumerOffset.path(group, topicPartition.topic, topicPartition.partition)) + assertTrue(zkClient.getConsumerOffset(group, topicPartition).isEmpty) + } + + @Test + def testDeleteRecursive(): Unit = { + zkClient.deleteRecursive("/delete/does-not-exist") + + zkClient.createRecursive("/delete/some/random/path") + assertTrue(zkClient.pathExists("/delete/some/random/path")) + assertTrue(zkClient.deleteRecursive("/delete")) + assertFalse(zkClient.pathExists("/delete")) + + assertThrows(classOf[IllegalArgumentException], () => zkClient.deleteRecursive("delete-invalid-path")) + } + + @Test + def testDeleteRecursiveWithControllerEpochVersionCheck(): Unit = { + assertFalse(zkClient.deleteRecursive("/delete/does-not-exist", controllerEpochZkVersion)) + + zkClient.createRecursive("/delete/some/random/path") + assertTrue(zkClient.pathExists("/delete/some/random/path")) + assertThrows(classOf[ControllerMovedException], () => zkClient.deleteRecursive("/delete", controllerEpochZkVersion + 1)) + + assertTrue(zkClient.deleteRecursive("/delete", controllerEpochZkVersion)) + assertFalse(zkClient.pathExists("/delete")) + + assertThrows(classOf[IllegalArgumentException], () => zkClient.deleteRecursive( + "delete-invalid-path", controllerEpochZkVersion)) + } + + @Test + def testCreateRecursive(): Unit = { + zkClient.createRecursive("/create-newrootpath") + assertTrue(zkClient.pathExists("/create-newrootpath")) + + zkClient.createRecursive("/create/some/random/long/path") + assertTrue(zkClient.pathExists("/create/some/random/long/path")) + zkClient.createRecursive("/create/some/random/long/path", throwIfPathExists = false) // no errors if path already exists + + assertThrows(classOf[IllegalArgumentException], () => zkClient.createRecursive("create-invalid-path")) + } + + @Test + def testTopicAssignmentMethods(): Unit = { + assertTrue(zkClient.getAllTopicsInCluster().isEmpty) + + // test with non-existing topic + assertFalse(zkClient.topicExists(topic1)) + assertTrue(zkClient.getTopicPartitionCount(topic1).isEmpty) + assertTrue(zkClient.getPartitionAssignmentForTopics(Set(topic1)).isEmpty) + assertTrue(zkClient.getPartitionsForTopics(Set(topic1)).isEmpty) + assertTrue(zkClient.getReplicasForPartition(new TopicPartition(topic1, 2)).isEmpty) + + val assignment = Map( + new TopicPartition(topic1, 0) -> Seq(0, 1), + new TopicPartition(topic1, 1) -> Seq(0, 1), + new TopicPartition(topic1, 2) -> Seq(1, 2, 3) + ) + + // create a topic assignment + zkClient.createTopicAssignment(topic1, topicIds.get(topic1), assignment) + + assertTrue(zkClient.topicExists(topic1)) + + val expectedAssignment = assignment map { topicAssignment => + val partition = topicAssignment._1.partition + val assignment = topicAssignment._2 + partition -> ReplicaAssignment(assignment, List(), List()) + } + + assertEquals(assignment.size, zkClient.getTopicPartitionCount(topic1).get) + assertEquals(expectedAssignment, zkClient.getPartitionAssignmentForTopics(Set(topic1))(topic1)) + assertEquals(Set(0, 1, 2), zkClient.getPartitionsForTopics(Set(topic1))(topic1).toSet) + assertEquals(Set(1, 2, 3), 
zkClient.getReplicasForPartition(new TopicPartition(topic1, 2)).toSet) + + val updatedAssignment = assignment - new TopicPartition(topic1, 2) + + zkClient.setTopicAssignment(topic1, topicIds.get(topic1), updatedAssignment.map { + case (k, v) => k -> ReplicaAssignment(v, List(), List()) }) + assertEquals(updatedAssignment.size, zkClient.getTopicPartitionCount(topic1).get) + + // add second topic + val secondAssignment = Map( + new TopicPartition(topic2, 0) -> Seq(0, 1), + new TopicPartition(topic2, 1) -> Seq(0, 1) + ) + + zkClient.createTopicAssignment(topic2, topicIds.get(topic2), secondAssignment) + + assertEquals(Set(topic1, topic2), zkClient.getAllTopicsInCluster()) + } + + @Test + def testGetAllTopicsInClusterTriggersWatch(): Unit = { + zkClient.createTopLevelPaths() + val latch = registerChildChangeHandler(1) + + // Listing all the topics and register the watch + assertTrue(zkClient.getAllTopicsInCluster(true).isEmpty) + + // Verifies that listing all topics without registering the watch does + // not interfere with the previous registered watcher + assertTrue(zkClient.getAllTopicsInCluster(false).isEmpty) + + zkClient.createTopicAssignment(topic1, topicIds.get(topic1), Map.empty) + + assertTrue(latch.await(5, TimeUnit.SECONDS), + "Failed to receive watch notification") + + assertTrue(zkClient.topicExists(topic1)) + } + + @Test + def testGetAllTopicsInClusterDoesNotTriggerWatch(): Unit = { + zkClient.createTopLevelPaths() + val latch = registerChildChangeHandler(1) + + // Listing all the topics and don't register the watch + assertTrue(zkClient.getAllTopicsInCluster(false).isEmpty) + + zkClient.createTopicAssignment(topic1, topicIds.get(topic1), Map.empty) + + assertFalse(latch.await(100, TimeUnit.MILLISECONDS), + "Received watch notification") + + assertTrue(zkClient.topicExists(topic1)) + } + + private def registerChildChangeHandler(count: Int): CountDownLatch = { + val znodeChildChangeHandlerCountDownLatch = new CountDownLatch(1) + val znodeChildChangeHandler = new ZNodeChildChangeHandler { + override val path: String = TopicsZNode.path + + override def handleChildChange(): Unit = { + znodeChildChangeHandlerCountDownLatch.countDown() + } + } + zkClient.registerZNodeChildChangeHandler(znodeChildChangeHandler) + znodeChildChangeHandlerCountDownLatch + } + + @Test + def testGetDataAndVersion(): Unit = { + val path = "/testpath" + + // test with non-existing path + val (data0, version0) = zkClient.getDataAndVersion(path) + assertTrue(data0.isEmpty) + assertEquals(ZkVersion.UnknownVersion, version0) + + // create a test path + zkClient.createRecursive(path) + zkClient.conditionalUpdatePath(path, "version1".getBytes(UTF_8), 0) + + // test with existing path + val (data1, version1) = zkClient.getDataAndVersion(path) + assertEquals("version1", new String(data1.get, UTF_8)) + assertEquals(1, version1) + + zkClient.conditionalUpdatePath(path, "version2".getBytes(UTF_8), 1) + val (data2, version2) = zkClient.getDataAndVersion(path) + assertEquals("version2", new String(data2.get, UTF_8)) + assertEquals(2, version2) + } + + @Test + def testConditionalUpdatePath(): Unit = { + val path = "/testconditionalpath" + + // test with non-existing path + var statusAndVersion = zkClient.conditionalUpdatePath(path, "version0".getBytes(UTF_8), 0) + assertFalse(statusAndVersion._1) + assertEquals(ZkVersion.UnknownVersion, statusAndVersion._2) + + // create path + zkClient.createRecursive(path) + + // test with valid expected version + statusAndVersion = zkClient.conditionalUpdatePath(path, 
"version1".getBytes(UTF_8), 0) + assertTrue(statusAndVersion._1) + assertEquals(1, statusAndVersion._2) + + // test with invalid expected version + statusAndVersion = zkClient.conditionalUpdatePath(path, "version2".getBytes(UTF_8), 2) + assertFalse(statusAndVersion._1) + assertEquals(ZkVersion.UnknownVersion, statusAndVersion._2) + } + + @Test + def testCreateSequentialPersistentPath(): Unit = { + val path = "/testpath" + zkClient.createRecursive(path) + + var result = zkClient.createSequentialPersistentPath(path + "/sequence_", null) + assertEquals(s"$path/sequence_0000000000", result) + assertTrue(zkClient.pathExists(s"$path/sequence_0000000000")) + assertEquals(None, dataAsString(s"$path/sequence_0000000000")) + + result = zkClient.createSequentialPersistentPath(path + "/sequence_", "some value".getBytes(UTF_8)) + assertEquals(s"$path/sequence_0000000001", result) + assertTrue(zkClient.pathExists(s"$path/sequence_0000000001")) + assertEquals(Some("some value"), dataAsString(s"$path/sequence_0000000001")) + } + + @Test + def testPropagateIsrChanges(): Unit = { + zkClient.createRecursive("/isr_change_notification") + + zkClient.propagateIsrChanges(Set(new TopicPartition("topic-a", 0), new TopicPartition("topic-b", 0))) + var expectedPath = "/isr_change_notification/isr_change_0000000000" + assertTrue(zkClient.pathExists(expectedPath)) + assertEquals(Some("""{"version":1,"partitions":[{"topic":"topic-a","partition":0},{"topic":"topic-b","partition":0}]}"""), + dataAsString(expectedPath)) + + zkClient.propagateIsrChanges(Set(new TopicPartition("topic-b", 0))) + expectedPath = "/isr_change_notification/isr_change_0000000001" + assertTrue(zkClient.pathExists(expectedPath)) + assertEquals(Some("""{"version":1,"partitions":[{"topic":"topic-b","partition":0}]}"""), dataAsString(expectedPath)) + } + + @Test + def testIsrChangeNotificationGetters(): Unit = { + assertEquals(Seq.empty, zkClient.getAllIsrChangeNotifications, "Failed for non existing parent ZK node") + assertEquals(Seq.empty, zkClient.getPartitionsFromIsrChangeNotifications(Seq("0000000000")), "Failed for non existing parent ZK node") + + zkClient.createRecursive("/isr_change_notification") + + zkClient.propagateIsrChanges(Set(topicPartition10, topicPartition11)) + zkClient.propagateIsrChanges(Set(topicPartition10)) + + assertEquals(Set("0000000000", "0000000001"), zkClient.getAllIsrChangeNotifications.toSet) + + // A partition can have multiple notifications + assertEquals(Seq(topicPartition10, topicPartition11, topicPartition10), + zkClient.getPartitionsFromIsrChangeNotifications(Seq("0000000000", "0000000001"))) + } + + @Test + def testIsrChangeNotificationsDeletion(): Unit = { + // Should not fail even if parent node does not exist + zkClient.deleteIsrChangeNotifications(Seq("0000000000"), controllerEpochZkVersion) + + zkClient.createRecursive("/isr_change_notification") + + zkClient.propagateIsrChanges(Set(topicPartition10, topicPartition11)) + zkClient.propagateIsrChanges(Set(topicPartition10)) + zkClient.propagateIsrChanges(Set(topicPartition11)) + + // Should throw exception if the controllerEpochZkVersion does not match + assertThrows(classOf[ControllerMovedException], () => zkClient.deleteIsrChangeNotifications(Seq("0000000001"), controllerEpochZkVersion + 1)) + // Delete should not succeed + assertEquals(Set("0000000000", "0000000001", "0000000002"), zkClient.getAllIsrChangeNotifications.toSet) + + zkClient.deleteIsrChangeNotifications(Seq("0000000001"), controllerEpochZkVersion) + // Should not fail if called on a 
non-existent notification + zkClient.deleteIsrChangeNotifications(Seq("0000000001"), controllerEpochZkVersion) + + assertEquals(Set("0000000000", "0000000002"), zkClient.getAllIsrChangeNotifications.toSet) + zkClient.deleteIsrChangeNotifications(controllerEpochZkVersion) + assertEquals(Seq.empty, zkClient.getAllIsrChangeNotifications) + } + + @Test + def testPropagateLogDir(): Unit = { + zkClient.createRecursive("/log_dir_event_notification") + + val brokerId = 3 + + zkClient.propagateLogDirEvent(brokerId) + var expectedPath = "/log_dir_event_notification/log_dir_event_0000000000" + assertTrue(zkClient.pathExists(expectedPath)) + assertEquals(Some("""{"version":1,"broker":3,"event":1}"""), dataAsString(expectedPath)) + + zkClient.propagateLogDirEvent(brokerId) + expectedPath = "/log_dir_event_notification/log_dir_event_0000000001" + assertTrue(zkClient.pathExists(expectedPath)) + assertEquals(Some("""{"version":1,"broker":3,"event":1}"""), dataAsString(expectedPath)) + + val anotherBrokerId = 4 + zkClient.propagateLogDirEvent(anotherBrokerId) + expectedPath = "/log_dir_event_notification/log_dir_event_0000000002" + assertTrue(zkClient.pathExists(expectedPath)) + assertEquals(Some("""{"version":1,"broker":4,"event":1}"""), dataAsString(expectedPath)) + } + + @Test + def testLogDirGetters(): Unit = { + assertEquals(Seq.empty, + zkClient.getAllLogDirEventNotifications, "getAllLogDirEventNotifications failed for non existing parent ZK node") + assertEquals(Seq.empty, + zkClient.getBrokerIdsFromLogDirEvents(Seq("0000000000")), "getBrokerIdsFromLogDirEvents failed for non existing parent ZK node") + + zkClient.createRecursive("/log_dir_event_notification") + + val brokerId = 3 + zkClient.propagateLogDirEvent(brokerId) + + assertEquals(Seq(3), zkClient.getBrokerIdsFromLogDirEvents(Seq("0000000000"))) + + zkClient.propagateLogDirEvent(brokerId) + + val anotherBrokerId = 4 + zkClient.propagateLogDirEvent(anotherBrokerId) + + val notifications012 = Seq("0000000000", "0000000001", "0000000002") + assertEquals(notifications012.toSet, zkClient.getAllLogDirEventNotifications.toSet) + assertEquals(Seq(3, 3, 4), zkClient.getBrokerIdsFromLogDirEvents(notifications012)) + } + + @Test + def testLogDirEventNotificationsDeletion(): Unit = { + // Should not fail even if parent node does not exist + zkClient.deleteLogDirEventNotifications(Seq("0000000000", "0000000002"), controllerEpochZkVersion) + + zkClient.createRecursive("/log_dir_event_notification") + + val brokerId = 3 + val anotherBrokerId = 4 + + zkClient.propagateLogDirEvent(brokerId) + zkClient.propagateLogDirEvent(brokerId) + zkClient.propagateLogDirEvent(anotherBrokerId) + + assertThrows(classOf[ControllerMovedException], () => zkClient.deleteLogDirEventNotifications(Seq("0000000000", "0000000002"), controllerEpochZkVersion + 1)) + assertEquals(Seq("0000000000", "0000000001", "0000000002"), zkClient.getAllLogDirEventNotifications) + + zkClient.deleteLogDirEventNotifications(Seq("0000000000", "0000000002"), controllerEpochZkVersion) + + assertEquals(Seq("0000000001"), zkClient.getAllLogDirEventNotifications) + + zkClient.propagateLogDirEvent(anotherBrokerId) + + zkClient.deleteLogDirEventNotifications(controllerEpochZkVersion) + assertEquals(Seq.empty, zkClient.getAllLogDirEventNotifications) + } + + @Test + def testSetGetAndDeletePartitionReassignment(): Unit = { + zkClient.createRecursive(AdminZNode.path) + + assertEquals(Map.empty, zkClient.getPartitionReassignment) + + val reassignment = Map( + new TopicPartition("topic_a", 0) -> Seq(0, 1, 3), + 
new TopicPartition("topic_a", 1) -> Seq(2, 1, 3), + new TopicPartition("topic_b", 0) -> Seq(4, 5), + new TopicPartition("topic_c", 0) -> Seq(5, 3) + ) + + // Should throw ControllerMovedException if the controller epoch zkVersion does not match + assertThrows(classOf[ControllerMovedException], () => zkClient.setOrCreatePartitionReassignment(reassignment, controllerEpochZkVersion + 1)) + + zkClient.setOrCreatePartitionReassignment(reassignment, controllerEpochZkVersion) + assertEquals(reassignment, zkClient.getPartitionReassignment) + + val updatedReassignment = reassignment - new TopicPartition("topic_b", 0) + zkClient.setOrCreatePartitionReassignment(updatedReassignment, controllerEpochZkVersion) + assertEquals(updatedReassignment, zkClient.getPartitionReassignment) + + zkClient.deletePartitionReassignment(controllerEpochZkVersion) + assertEquals(Map.empty, zkClient.getPartitionReassignment) + + zkClient.createPartitionReassignment(reassignment) + assertEquals(reassignment, zkClient.getPartitionReassignment) + } + + @Test + def testGetDataAndStat(): Unit = { + val path = "/testpath" + + // test with non-existing path + val (data0, version0) = zkClient.getDataAndStat(path) + assertTrue(data0.isEmpty) + assertEquals(0, version0.getVersion) + + // create a test path + zkClient.createRecursive(path) + zkClient.conditionalUpdatePath(path, "version1".getBytes(UTF_8), 0) + + // test with existing path + val (data1, version1) = zkClient.getDataAndStat(path) + assertEquals("version1", new String(data1.get, UTF_8)) + assertEquals(1, version1.getVersion) + + zkClient.conditionalUpdatePath(path, "version2".getBytes(UTF_8), 1) + val (data2, version2) = zkClient.getDataAndStat(path) + assertEquals("version2", new String(data2.get, UTF_8)) + assertEquals(2, version2.getVersion) + } + + @Test + def testGetChildren(): Unit = { + val path = "/testpath" + + // test with non-existing path + assertTrue(zkClient.getChildren(path).isEmpty) + + // create child nodes + zkClient.createRecursive( "/testpath/child1") + zkClient.createRecursive( "/testpath/child2") + zkClient.createRecursive( "/testpath/child3") + + val children = zkClient.getChildren(path) + + assertEquals(3, children.size) + assertEquals(Set("child1","child2","child3"), children.toSet) + } + + @Test + def testAclManagementMethods(): Unit = { + ZkAclStore.stores.foreach(store => { + assertFalse(zkClient.pathExists(store.aclPath)) + assertFalse(zkClient.pathExists(store.changeStore.aclChangePath)) + AclEntry.ResourceTypes.foreach(resource => assertFalse(zkClient.pathExists(store.path(resource)))) + }) + + // create acl paths + zkClient.createAclPaths() + + ZkAclStore.stores.foreach(store => { + assertTrue(zkClient.pathExists(store.aclPath)) + assertTrue(zkClient.pathExists(store.changeStore.aclChangePath)) + AclEntry.ResourceTypes.foreach(resource => assertTrue(zkClient.pathExists(store.path(resource)))) + + val resource1 = new ResourcePattern(TOPIC, Uuid.randomUuid().toString, store.patternType) + val resource2 = new ResourcePattern(TOPIC, Uuid.randomUuid().toString, store.patternType) + + // try getting acls for non-existing resource + var versionedAcls = zkClient.getVersionedAclsForResource(resource1) + assertTrue(versionedAcls.acls.isEmpty) + assertEquals(ZkVersion.UnknownVersion, versionedAcls.zkVersion) + assertFalse(zkClient.resourceExists(resource1)) + + + val acl1 = AclEntry(new KafkaPrincipal(KafkaPrincipal.USER_TYPE, "alice"), DENY, "host1" , READ) + val acl2 = AclEntry(new KafkaPrincipal(KafkaPrincipal.USER_TYPE, "bob"), ALLOW, "*", READ) 
+ val acl3 = AclEntry(new KafkaPrincipal(KafkaPrincipal.USER_TYPE, "bob"), DENY, "host1", READ) + + // Conditional set should fail if path not created + assertFalse(zkClient.conditionalSetAclsForResource(resource1, Set(acl1, acl3), 0)._1) + + //create acls for resources + assertTrue(zkClient.createAclsForResourceIfNotExists(resource1, Set(acl1, acl2))._1) + assertTrue(zkClient.createAclsForResourceIfNotExists(resource2, Set(acl1, acl3))._1) + + // Create should fail if path already exists + assertFalse(zkClient.createAclsForResourceIfNotExists(resource2, Set(acl1, acl3))._1) + + versionedAcls = zkClient.getVersionedAclsForResource(resource1) + assertEquals(Set(acl1, acl2), versionedAcls.acls) + assertEquals(0, versionedAcls.zkVersion) + assertTrue(zkClient.resourceExists(resource1)) + + //update acls for resource + assertTrue(zkClient.conditionalSetAclsForResource(resource1, Set(acl1, acl3), 0)._1) + + versionedAcls = zkClient.getVersionedAclsForResource(resource1) + assertEquals(Set(acl1, acl3), versionedAcls.acls) + assertEquals(1, versionedAcls.zkVersion) + + //get resource Types + assertEquals(AclEntry.ResourceTypes.map(SecurityUtils.resourceTypeName), zkClient.getResourceTypes(store.patternType).toSet) + + //get resource name + val resourceNames = zkClient.getResourceNames(store.patternType, TOPIC) + assertEquals(2, resourceNames.size) + assertTrue(Set(resource1.name,resource2.name) == resourceNames.toSet) + + //delete resource + assertTrue(zkClient.deleteResource(resource1)) + assertFalse(zkClient.resourceExists(resource1)) + + //delete with invalid expected zk version + assertFalse(zkClient.conditionalDelete(resource2, 10)) + //delete with valid expected zk version + assertTrue(zkClient.conditionalDelete(resource2, 0)) + + zkClient.createAclChangeNotification(new ResourcePattern(GROUP, "resource1", store.patternType)) + zkClient.createAclChangeNotification(new ResourcePattern(TOPIC, "resource2", store.patternType)) + + assertEquals(2, zkClient.getChildren(store.changeStore.aclChangePath).size) + + zkClient.deleteAclChangeNotifications() + assertTrue(zkClient.getChildren(store.changeStore.aclChangePath).isEmpty) + }) + } + + @Test + def testDeletePath(): Unit = { + val path = "/a/b/c" + zkClient.createRecursive(path) + zkClient.deletePath(path) + assertFalse(zkClient.pathExists(path)) + + zkClient.createRecursive(path) + zkClient.deletePath("/a") + assertFalse(zkClient.pathExists(path)) + + zkClient.createRecursive(path) + zkClient.deletePath(path, recursiveDelete = false) + assertFalse(zkClient.pathExists(path)) + assertTrue(zkClient.pathExists("/a/b")) + } + + @Test + def testDeleteTopicZNode(): Unit = { + zkClient.deleteTopicZNode(topic1, controllerEpochZkVersion) + zkClient.createRecursive(TopicZNode.path(topic1)) + zkClient.deleteTopicZNode(topic1, controllerEpochZkVersion) + assertFalse(zkClient.pathExists(TopicZNode.path(topic1))) + } + + @Test + def testDeleteTopicPathMethods(): Unit = { + assertFalse(zkClient.isTopicMarkedForDeletion(topic1)) + assertTrue(zkClient.getTopicDeletions.isEmpty) + + zkClient.createDeleteTopicPath(topic1) + zkClient.createDeleteTopicPath(topic2) + + assertTrue(zkClient.isTopicMarkedForDeletion(topic1)) + assertEquals(Set(topic1, topic2), zkClient.getTopicDeletions.toSet) + + assertThrows(classOf[ControllerMovedException], () => zkClient.deleteTopicDeletions(Seq(topic1, topic2), controllerEpochZkVersion + 1)) + assertEquals(Set(topic1, topic2), zkClient.getTopicDeletions.toSet) + + zkClient.deleteTopicDeletions(Seq(topic1, topic2), 
controllerEpochZkVersion) + assertTrue(zkClient.getTopicDeletions.isEmpty) + } + + private def assertPathExistenceAndData(expectedPath: String, data: String): Unit = { + assertTrue(zkClient.pathExists(expectedPath)) + assertEquals(Some(data), dataAsString(expectedPath)) + } + + @Test + def testCreateTokenChangeNotification(): Unit = { + assertThrows(classOf[NoNodeException], () => zkClient.createTokenChangeNotification("delegationToken")) + zkClient.createDelegationTokenPaths() + + zkClient.createTokenChangeNotification("delegationToken") + assertPathExistenceAndData("/delegation_token/token_changes/token_change_0000000000", "delegationToken") + } + + @Test + def testEntityConfigManagementMethods(): Unit = { + assertTrue(zkClient.getEntityConfigs(ConfigType.Topic, topic1).isEmpty) + + zkClient.setOrCreateEntityConfigs(ConfigType.Topic, topic1, logProps) + assertEquals(logProps, zkClient.getEntityConfigs(ConfigType.Topic, topic1)) + + logProps.remove(LogConfig.CleanupPolicyProp) + zkClient.setOrCreateEntityConfigs(ConfigType.Topic, topic1, logProps) + assertEquals(logProps, zkClient.getEntityConfigs(ConfigType.Topic, topic1)) + + zkClient.setOrCreateEntityConfigs(ConfigType.Topic, topic2, logProps) + assertEquals(Set(topic1, topic2), zkClient.getAllEntitiesWithConfig(ConfigType.Topic).toSet) + + zkClient.deleteTopicConfigs(Seq(topic1, topic2), controllerEpochZkVersion) + assertTrue(zkClient.getEntityConfigs(ConfigType.Topic, topic1).isEmpty) + } + + @Test + def testCreateConfigChangeNotification(): Unit = { + assertFalse(zkClient.pathExists(ConfigEntityChangeNotificationZNode.path)) + + // The parent path is created if needed + zkClient.createConfigChangeNotification(ConfigEntityZNode.path(ConfigType.Topic, topic1)) + assertPathExistenceAndData( + "/config/changes/config_change_0000000000", + """{"version":2,"entity_path":"/config/topics/topic1"}""") + + // Creation does not fail if the parent path exists + zkClient.createConfigChangeNotification(ConfigEntityZNode.path(ConfigType.Topic, topic2)) + assertPathExistenceAndData( + "/config/changes/config_change_0000000001", + """{"version":2,"entity_path":"/config/topics/topic2"}""") + } + + private def createLogProps(bytesProp: Int): Properties = { + val logProps = new Properties() + logProps.put(LogConfig.SegmentBytesProp, bytesProp.toString) + logProps.put(LogConfig.SegmentIndexBytesProp, bytesProp.toString) + logProps.put(LogConfig.CleanupPolicyProp, LogConfig.Compact) + logProps + } + + private val logProps = createLogProps(1024) + + @Test + def testGetLogConfigs(): Unit = { + val emptyConfig = LogConfig(Collections.emptyMap()) + assertEquals((Map(topic1 -> emptyConfig), Map.empty), + zkClient.getLogConfigs(Set(topic1), Collections.emptyMap()), + "Non existent config, no defaults") + + val logProps2 = createLogProps(2048) + + zkClient.setOrCreateEntityConfigs(ConfigType.Topic, topic1, logProps) + assertEquals((Map(topic1 -> LogConfig(logProps), topic2 -> emptyConfig), Map.empty), + zkClient.getLogConfigs(Set(topic1, topic2), Collections.emptyMap()), + "One existing and one non-existent topic") + + zkClient.setOrCreateEntityConfigs(ConfigType.Topic, topic2, logProps2) + assertEquals((Map(topic1 -> LogConfig(logProps), topic2 -> LogConfig(logProps2)), Map.empty), + zkClient.getLogConfigs(Set(topic1, topic2), Collections.emptyMap()), + "Two existing topics") + + val logProps1WithMoreValues = createLogProps(1024) + logProps1WithMoreValues.put(LogConfig.SegmentJitterMsProp, "100") + logProps1WithMoreValues.put(LogConfig.SegmentBytesProp, 
"1024") + + assertEquals((Map(topic1 -> LogConfig(logProps1WithMoreValues)), Map.empty), + zkClient.getLogConfigs(Set(topic1), + Map[String, AnyRef](LogConfig.SegmentJitterMsProp -> "100", LogConfig.SegmentBytesProp -> "128").asJava), + "Config with defaults") + } + + private def createBrokerInfo(id: Int, host: String, port: Int, securityProtocol: SecurityProtocol, + rack: Option[String] = None, + features: Features[SupportedVersionRange] = emptySupportedFeatures): BrokerInfo = + BrokerInfo( + Broker( + id, + Seq(new EndPoint(host, port, ListenerName.forSecurityProtocol(securityProtocol), securityProtocol)), + rack = rack, + features = features), + ApiVersion.latestVersion, jmxPort = port + 10) + + @Test + def testRegisterBrokerInfo(): Unit = { + zkClient.createTopLevelPaths() + + val brokerInfo = createBrokerInfo( + 1, "test.host", 9999, SecurityProtocol.PLAINTEXT, + rack = None, + features = Features.supportedFeatures( + Map[String, SupportedVersionRange]( + "feature1" -> new SupportedVersionRange(1, 2)).asJava)) + val differentBrokerInfoWithSameId = createBrokerInfo( + 1, "test.host2", 9995, SecurityProtocol.SSL, + features = Features.supportedFeatures( + Map[String, SupportedVersionRange]( + "feature2" -> new SupportedVersionRange(4, 7)).asJava)) + + zkClient.registerBroker(brokerInfo) + assertEquals(Some(brokerInfo.broker), zkClient.getBroker(1)) + assertEquals(Some(brokerInfo.broker), otherZkClient.getBroker(1), "Other ZK clients can read broker info") + + // Node exists, owned by current session - no error, no update + zkClient.registerBroker(differentBrokerInfoWithSameId) + assertEquals(Some(brokerInfo.broker), zkClient.getBroker(1)) + + // Other client tries to register broker with same id causes failure, info is not changed in ZK + assertThrows(classOf[NodeExistsException], () => otherZkClient.registerBroker(differentBrokerInfoWithSameId)) + assertEquals(Some(brokerInfo.broker), zkClient.getBroker(1)) + } + + @Test + def testRetryRegisterBrokerInfo(): Unit = { + val brokerId = 5 + val brokerPort = 9999 + val brokerHost = "test.host" + val expiredBrokerInfo = createBrokerInfo(brokerId, brokerHost, brokerPort, SecurityProtocol.PLAINTEXT) + expiredSessionZkClient.createTopLevelPaths() + + // Register the broker, for the first time + expiredSessionZkClient.registerBroker(expiredBrokerInfo) + assertEquals(Some(expiredBrokerInfo.broker), expiredSessionZkClient.getBroker(brokerId)) + val originalCzxid = expiredSessionZkClient.getPathCzxid(BrokerIdZNode.path(brokerId)) + + // Here, the node exists already, when trying to register under a different session id, + // the node will be deleted and created again using the new session id. 
+ expiredSessionZkClient.registerBroker(expiredBrokerInfo) + + // The broker info should be the same, no error should be raised + assertEquals(Some(expiredBrokerInfo.broker), expiredSessionZkClient.getBroker(brokerId)) + val newCzxid = expiredSessionZkClient.getPathCzxid(BrokerIdZNode.path(brokerId)) + + assertNotEquals(originalCzxid, newCzxid, "The Czxid of original ephemeral znode should be different " + + "from the new ephemeral znode Czxid") + } + + @Test + def testGetBrokerMethods(): Unit = { + zkClient.createTopLevelPaths() + + assertEquals(Seq.empty,zkClient.getAllBrokersInCluster) + assertEquals(Seq.empty, zkClient.getSortedBrokerList) + assertEquals(None, zkClient.getBroker(0)) + + val brokerInfo0 = createBrokerInfo( + 0, "test.host0", 9998, SecurityProtocol.PLAINTEXT, + features = Features.supportedFeatures( + Map[String, SupportedVersionRange]( + "feature1" -> new SupportedVersionRange(1, 2)).asJava)) + val brokerInfo1 = createBrokerInfo( + 1, "test.host1", 9999, SecurityProtocol.SSL, + features = Features.supportedFeatures( + Map[String, SupportedVersionRange]( + "feature2" -> new SupportedVersionRange(3, 6)).asJava)) + + zkClient.registerBroker(brokerInfo1) + otherZkClient.registerBroker(brokerInfo0) + + assertEquals(Seq(0, 1), zkClient.getSortedBrokerList) + assertEquals( + Seq(brokerInfo0.broker, brokerInfo1.broker), + zkClient.getAllBrokersInCluster + ) + assertEquals(Some(brokerInfo0.broker), zkClient.getBroker(0)) + } + + @Test + def testUpdateBrokerInfo(): Unit = { + zkClient.createTopLevelPaths() + + // Updating info of a broker not existing in ZK fails + val originalBrokerInfo = createBrokerInfo(1, "test.host", 9999, SecurityProtocol.PLAINTEXT) + assertThrows(classOf[NoNodeException], () => zkClient.updateBrokerInfo(originalBrokerInfo)) + + zkClient.registerBroker(originalBrokerInfo) + + val updatedBrokerInfo = createBrokerInfo(1, "test.host2", 9995, SecurityProtocol.SSL) + zkClient.updateBrokerInfo(updatedBrokerInfo) + assertEquals(Some(updatedBrokerInfo.broker), zkClient.getBroker(1)) + + // Other ZK clients can update info + otherZkClient.updateBrokerInfo(originalBrokerInfo) + assertEquals(Some(originalBrokerInfo.broker), otherZkClient.getBroker(1)) + } + + private def statWithVersion(version: Int): Stat = { + val stat = new Stat(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) + stat.setVersion(version) + stat + } + + private def leaderIsrAndControllerEpochs(state: Int, zkVersion: Int): Map[TopicPartition, LeaderIsrAndControllerEpoch] = + Map( + topicPartition10 -> LeaderIsrAndControllerEpoch( + LeaderAndIsr(leader = 1, leaderEpoch = state, isr = List(2 + state, 3 + state), zkVersion = zkVersion), + controllerEpoch = 4), + topicPartition11 -> LeaderIsrAndControllerEpoch( + LeaderAndIsr(leader = 0, leaderEpoch = state + 1, isr = List(1 + state, 2 + state), zkVersion = zkVersion), + controllerEpoch = 4)) + + val initialLeaderIsrAndControllerEpochs: Map[TopicPartition, LeaderIsrAndControllerEpoch] = + leaderIsrAndControllerEpochs(0, 0) + + val initialLeaderIsrs: Map[TopicPartition, LeaderAndIsr] = + initialLeaderIsrAndControllerEpochs.map { case (k, v) => k -> v.leaderAndIsr } + + private def leaderIsrs(state: Int, zkVersion: Int): Map[TopicPartition, LeaderAndIsr] = + leaderIsrAndControllerEpochs(state, zkVersion).map { case (k, v) => k -> v.leaderAndIsr } + + private def checkUpdateLeaderAndIsrResult( + expectedSuccessfulPartitions: Map[TopicPartition, LeaderAndIsr], + expectedPartitionsToRetry: Seq[TopicPartition], + expectedFailedPartitions: Map[TopicPartition, (Class[_], 
String)], + actualUpdateLeaderAndIsrResult: UpdateLeaderAndIsrResult): Unit = { + val failedPartitionsExcerpt = mutable.Map.empty[TopicPartition, (Class[_], String)] + val successfulPartitions = mutable.Map.empty[TopicPartition, LeaderAndIsr] + + actualUpdateLeaderAndIsrResult.finishedPartitions.foreach { + case (partition, Left(e)) => failedPartitionsExcerpt += partition -> (e.getClass, e.getMessage) + case (partition, Right(leaderAndIsr)) => successfulPartitions += partition -> leaderAndIsr + } + + assertEquals(expectedFailedPartitions, + failedPartitionsExcerpt, "Permanently failed updates do not match expected") + assertEquals(expectedPartitionsToRetry, + actualUpdateLeaderAndIsrResult.partitionsToRetry, "Retriable updates (due to BADVERSION) do not match expected") + assertEquals(expectedSuccessfulPartitions, + successfulPartitions, "Successful updates do not match expected") + } + + @Test + def testTopicAssignments(): Unit = { + val topicId = Some(Uuid.randomUuid()) + assertEquals(0, zkClient.getPartitionAssignmentForTopics(Set(topicPartition.topic())).size) + zkClient.createTopicAssignment(topicPartition.topic(), topicId, + Map(topicPartition -> Seq())) + + val expectedAssignment = ReplicaAssignment(Seq(1,2,3), Seq(1), Seq(3)) + val response = zkClient.setTopicAssignmentRaw(topicPartition.topic(), topicId, + Map(topicPartition -> expectedAssignment), controllerEpochZkVersion) + assertEquals(Code.OK, response.resultCode) + + val topicPartitionAssignments = zkClient.getPartitionAssignmentForTopics(Set(topicPartition.topic())) + assertEquals(1, topicPartitionAssignments.size) + assertTrue(topicPartitionAssignments.contains(topicPartition.topic())) + val partitionAssignments = topicPartitionAssignments(topicPartition.topic()) + assertEquals(1, partitionAssignments.size) + assertTrue(partitionAssignments.contains(topicPartition.partition())) + val assignment = partitionAssignments(topicPartition.partition()) + assertEquals(expectedAssignment, assignment) + } + + @Test + def testUpdateLeaderAndIsr(): Unit = { + zkClient.createRecursive(TopicZNode.path(topic1)) + + // Non-existing topicPartitions + checkUpdateLeaderAndIsrResult( + Map.empty, + mutable.ArrayBuffer.empty, + Map( + topicPartition10 -> (classOf[NoNodeException], "KeeperErrorCode = NoNode for /brokers/topics/topic1/partitions/0/state"), + topicPartition11 -> (classOf[NoNodeException], "KeeperErrorCode = NoNode for /brokers/topics/topic1/partitions/1/state")), + zkClient.updateLeaderAndIsr(initialLeaderIsrs, controllerEpoch = 4, controllerEpochZkVersion)) + + zkClient.createTopicPartitionStatesRaw(initialLeaderIsrAndControllerEpochs, controllerEpochZkVersion) + + // Mismatch controller epoch zkVersion + assertThrows(classOf[ControllerMovedException], () => zkClient.updateLeaderAndIsr(initialLeaderIsrs, controllerEpoch = 4, controllerEpochZkVersion + 1)) + + // successful updates + checkUpdateLeaderAndIsrResult( + leaderIsrs(state = 1, zkVersion = 1), + mutable.ArrayBuffer.empty, + Map.empty, + zkClient.updateLeaderAndIsr(leaderIsrs(state = 1, zkVersion = 0),controllerEpoch = 4, controllerEpochZkVersion)) + + // Try to update with wrong ZK version + checkUpdateLeaderAndIsrResult( + Map.empty, + ArrayBuffer(topicPartition10, topicPartition11), + Map.empty, + zkClient.updateLeaderAndIsr(leaderIsrs(state = 1, zkVersion = 0),controllerEpoch = 4, controllerEpochZkVersion)) + + // Trigger successful, to be retried and failed partitions in same call + val mixedState = Map( + topicPartition10 -> LeaderAndIsr(leader = 1, leaderEpoch = 2, 
isr = List(4, 5), zkVersion = 1), + topicPartition11 -> LeaderAndIsr(leader = 0, leaderEpoch = 2, isr = List(3, 4), zkVersion = 0), + topicPartition20 -> LeaderAndIsr(leader = 0, leaderEpoch = 2, isr = List(3, 4), zkVersion = 0)) + + checkUpdateLeaderAndIsrResult( + leaderIsrs(state = 2, zkVersion = 2).filter { case (tp, _) => tp == topicPartition10 }, + ArrayBuffer(topicPartition11), + Map( + topicPartition20 -> (classOf[NoNodeException], "KeeperErrorCode = NoNode for /brokers/topics/topic2/partitions/0/state")), + zkClient.updateLeaderAndIsr(mixedState, controllerEpoch = 4, controllerEpochZkVersion)) + } + + private def checkGetDataResponse( + leaderIsrAndControllerEpochs: Map[TopicPartition,LeaderIsrAndControllerEpoch], + topicPartition: TopicPartition, + response: GetDataResponse): Unit = { + val zkVersion = leaderIsrAndControllerEpochs(topicPartition).leaderAndIsr.zkVersion + assertEquals(Code.OK, response.resultCode) + assertEquals(TopicPartitionStateZNode.path(topicPartition), response.path) + assertEquals(Some(topicPartition), response.ctx) + assertEquals( + Some(leaderIsrAndControllerEpochs(topicPartition)), + TopicPartitionStateZNode.decode(response.data, statWithVersion(zkVersion))) + } + + private def eraseMetadata(response: CreateResponse): CreateResponse = + response.copy(metadata = ResponseMetadata(0, 0)) + + @Test + def testGetTopicsAndPartitions(): Unit = { + assertTrue(zkClient.getAllTopicsInCluster().isEmpty) + assertTrue(zkClient.getAllPartitions.isEmpty) + + zkClient.createRecursive(TopicZNode.path(topic1)) + zkClient.createRecursive(TopicZNode.path(topic2)) + assertEquals(Set(topic1, topic2), zkClient.getAllTopicsInCluster()) + + assertTrue(zkClient.getAllPartitions.isEmpty) + + zkClient.createTopicPartitionStatesRaw(initialLeaderIsrAndControllerEpochs, controllerEpochZkVersion) + assertEquals(Set(topicPartition10, topicPartition11), zkClient.getAllPartitions) + } + + @Test + def testCreateAndGetTopicPartitionStatesRaw(): Unit = { + zkClient.createRecursive(TopicZNode.path(topic1)) + + // Mismatch controller epoch zkVersion + assertThrows(classOf[ControllerMovedException], () => zkClient.createTopicPartitionStatesRaw(initialLeaderIsrAndControllerEpochs, controllerEpochZkVersion + 1)) + + assertEquals( + Seq( + CreateResponse(Code.OK, TopicPartitionStateZNode.path(topicPartition10), Some(topicPartition10), + TopicPartitionStateZNode.path(topicPartition10), ResponseMetadata(0, 0)), + CreateResponse(Code.OK, TopicPartitionStateZNode.path(topicPartition11), Some(topicPartition11), + TopicPartitionStateZNode.path(topicPartition11), ResponseMetadata(0, 0))), + zkClient.createTopicPartitionStatesRaw(initialLeaderIsrAndControllerEpochs, controllerEpochZkVersion) + .map(eraseMetadata).toList) + + val getResponses = zkClient.getTopicPartitionStatesRaw(topicPartitions10_11) + assertEquals(2, getResponses.size) + topicPartitions10_11.zip(getResponses) foreach {case (tp, r) => checkGetDataResponse(initialLeaderIsrAndControllerEpochs, tp, r)} + + // Trying to create existing topicPartition states fails + assertEquals( + Seq( + CreateResponse(Code.NODEEXISTS, TopicPartitionStateZNode.path(topicPartition10), Some(topicPartition10), null, ResponseMetadata(0, 0)), + CreateResponse(Code.NODEEXISTS, TopicPartitionStateZNode.path(topicPartition11), Some(topicPartition11), null, ResponseMetadata(0, 0))), + zkClient.createTopicPartitionStatesRaw(initialLeaderIsrAndControllerEpochs, controllerEpochZkVersion).map(eraseMetadata).toList) + } + + @Test + def testSetTopicPartitionStatesRaw(): 
Unit = { + + def expectedSetDataResponses(topicPartitions: TopicPartition*)(resultCode: Code, stat: Stat) = + topicPartitions.map { topicPartition => + SetDataResponse(resultCode, TopicPartitionStateZNode.path(topicPartition), + Some(topicPartition), stat, ResponseMetadata(0, 0)) + } + + zkClient.createRecursive(TopicZNode.path(topic1)) + + // Trying to set non-existing topicPartition's data results in NONODE responses + assertEquals( + expectedSetDataResponses(topicPartition10, topicPartition11)(Code.NONODE, null), + zkClient.setTopicPartitionStatesRaw(initialLeaderIsrAndControllerEpochs, controllerEpochZkVersion).map { + _.copy(metadata = ResponseMetadata(0, 0))}.toList) + + zkClient.createTopicPartitionStatesRaw(initialLeaderIsrAndControllerEpochs, controllerEpochZkVersion) + + assertEquals( + expectedSetDataResponses(topicPartition10, topicPartition11)(Code.OK, statWithVersion(1)), + zkClient.setTopicPartitionStatesRaw(leaderIsrAndControllerEpochs(state = 1, zkVersion = 0), controllerEpochZkVersion).map { + eraseMetadataAndStat}.toList) + + // Mismatch controller epoch zkVersion + assertThrows(classOf[ControllerMovedException], () => zkClient.setTopicPartitionStatesRaw(leaderIsrAndControllerEpochs(state = 1, zkVersion = 0), controllerEpochZkVersion + 1)) + + val getResponses = zkClient.getTopicPartitionStatesRaw(topicPartitions10_11) + assertEquals(2, getResponses.size) + topicPartitions10_11.zip(getResponses) foreach {case (tp, r) => checkGetDataResponse(leaderIsrAndControllerEpochs(state = 1, zkVersion = 0), tp, r)} + + // Other ZK client can also write the state of a partition + assertEquals( + expectedSetDataResponses(topicPartition10, topicPartition11)(Code.OK, statWithVersion(2)), + otherZkClient.setTopicPartitionStatesRaw(leaderIsrAndControllerEpochs(state = 2, zkVersion = 1), controllerEpochZkVersion).map { + eraseMetadataAndStat}.toList) + } + + @Test + def testReassignPartitionsInProgress(): Unit = { + assertFalse(zkClient.reassignPartitionsInProgress) + zkClient.createRecursive(ReassignPartitionsZNode.path) + assertTrue(zkClient.reassignPartitionsInProgress) + } + + @Test + def testGetTopicPartitionStates(): Unit = { + assertEquals(None, zkClient.getTopicPartitionState(topicPartition10)) + assertEquals(None, zkClient.getLeaderForPartition(topicPartition10)) + + zkClient.createRecursive(TopicZNode.path(topic1)) + + zkClient.createTopicPartitionStatesRaw(initialLeaderIsrAndControllerEpochs, controllerEpochZkVersion) + assertEquals( + initialLeaderIsrAndControllerEpochs, + zkClient.getTopicPartitionStates(Seq(topicPartition10, topicPartition11)) + ) + + assertEquals( + Some(initialLeaderIsrAndControllerEpochs(topicPartition10)), + zkClient.getTopicPartitionState(topicPartition10) + ) + + assertEquals(Some(1), zkClient.getLeaderForPartition(topicPartition10)) + + val notExistingPartition = new TopicPartition(topic1, 2) + assertTrue(zkClient.getTopicPartitionStates(Seq(notExistingPartition)).isEmpty) + assertEquals( + Map(topicPartition10 -> initialLeaderIsrAndControllerEpochs(topicPartition10)), + zkClient.getTopicPartitionStates(Seq(topicPartition10, notExistingPartition)) + ) + + assertEquals(None, zkClient.getTopicPartitionState(notExistingPartition)) + assertEquals(None, zkClient.getLeaderForPartition(notExistingPartition)) + + } + + private def eraseMetadataAndStat(response: SetDataResponse): SetDataResponse = { + val stat = if (response.stat != null) statWithVersion(response.stat.getVersion) else null + response.copy(metadata = ResponseMetadata(0, 0), stat = stat) + } + + 
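+  // Note: the eraseMetadata and eraseMetadataAndStat helpers above reset ResponseMetadata to
+  // ResponseMetadata(0, 0) so that the assertions in this class compare only the semantically
+  // meaningful fields of each ZooKeeper response.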
@Test + def testControllerEpochMethods(): Unit = { + zkClient.deletePath(ControllerEpochZNode.path) + + assertEquals(None, zkClient.getControllerEpoch) + + assertEquals(SetDataResponse(Code.NONODE, ControllerEpochZNode.path, None, null, ResponseMetadata(0, 0)), + eraseMetadataAndStat(zkClient.setControllerEpochRaw(1, 0)), + "Setting non existing nodes should return NONODE results") + + assertEquals(CreateResponse(Code.OK, ControllerEpochZNode.path, None, ControllerEpochZNode.path, ResponseMetadata(0, 0)), + eraseMetadata(zkClient.createControllerEpochRaw(0)), + "Creating non existing nodes is OK") + assertEquals(0, zkClient.getControllerEpoch.get._1) + + assertEquals(CreateResponse(Code.NODEEXISTS, ControllerEpochZNode.path, None, null, ResponseMetadata(0, 0)), + eraseMetadata(zkClient.createControllerEpochRaw(0)), + "Attemt to create existing nodes should return NODEEXISTS") + + assertEquals(SetDataResponse(Code.OK, ControllerEpochZNode.path, None, statWithVersion(1), ResponseMetadata(0, 0)), + eraseMetadataAndStat(zkClient.setControllerEpochRaw(1, 0)), + "Updating existing nodes is OK") + assertEquals(1, zkClient.getControllerEpoch.get._1) + + assertEquals(SetDataResponse(Code.BADVERSION, ControllerEpochZNode.path, None, null, ResponseMetadata(0, 0)), + eraseMetadataAndStat(zkClient.setControllerEpochRaw(1, 0)), + "Updating with wrong ZK version returns BADVERSION") + } + + @Test + def testControllerManagementMethods(): Unit = { + // No controller + assertEquals(None, zkClient.getControllerId) + // Create controller + val (_, newEpochZkVersion) = zkClient.registerControllerAndIncrementControllerEpoch(controllerId = 1) + assertEquals(Some(1), zkClient.getControllerId) + zkClient.deleteController(newEpochZkVersion) + assertEquals(None, zkClient.getControllerId) + } + + @Test + def testZNodeChangeHandlerForDataChange(): Unit = { + val mockPath = "/foo" + + val znodeChangeHandlerCountDownLatch = new CountDownLatch(1) + val zNodeChangeHandler = new ZNodeChangeHandler { + override def handleCreation(): Unit = { + znodeChangeHandlerCountDownLatch.countDown() + } + + override val path: String = mockPath + } + + zkClient.registerZNodeChangeHandlerAndCheckExistence(zNodeChangeHandler) + zkClient.createRecursive(mockPath) + assertTrue(znodeChangeHandlerCountDownLatch.await(5, TimeUnit.SECONDS), "Failed to receive create notification") + } + + @Test + def testClusterIdMethods(): Unit = { + val clusterId = CoreUtils.generateUuidAsBase64() + + zkClient.createOrGetClusterId(clusterId) + assertEquals(clusterId, zkClient.getClusterId.getOrElse(fail("No cluster id found"))) + } + + @Test + def testBrokerSequenceIdMethods(): Unit = { + val sequenceId = zkClient.generateBrokerSequenceId() + assertEquals(sequenceId + 1, zkClient.generateBrokerSequenceId()) + } + + @Test + def testCreateTopLevelPaths(): Unit = { + zkClient.createTopLevelPaths() + + ZkData.PersistentZkPaths.foreach(path => assertTrue(zkClient.pathExists(path))) + } + + @Test + def testPreferredReplicaElectionMethods(): Unit = { + + assertTrue(zkClient.getPreferredReplicaElection.isEmpty) + + val electionPartitions = Set(new TopicPartition(topic1, 0), new TopicPartition(topic1, 1)) + + zkClient.createPreferredReplicaElection(electionPartitions) + assertEquals(electionPartitions, zkClient.getPreferredReplicaElection) + + assertThrows(classOf[NodeExistsException], () => zkClient.createPreferredReplicaElection(electionPartitions)) + + // Mismatch controller epoch zkVersion + assertThrows(classOf[ControllerMovedException], () => 
zkClient.deletePreferredReplicaElection(controllerEpochZkVersion + 1)) + assertEquals(electionPartitions, zkClient.getPreferredReplicaElection) + + zkClient.deletePreferredReplicaElection(controllerEpochZkVersion) + assertTrue(zkClient.getPreferredReplicaElection.isEmpty) + } + + private def dataAsString(path: String): Option[String] = { + val (data, _) = zkClient.getDataAndStat(path) + data.map(new String(_, UTF_8)) + } + + @Test + def testDelegationTokenMethods(): Unit = { + assertFalse(zkClient.pathExists(DelegationTokensZNode.path)) + assertFalse(zkClient.pathExists(DelegationTokenChangeNotificationZNode.path)) + + zkClient.createDelegationTokenPaths() + assertTrue(zkClient.pathExists(DelegationTokensZNode.path)) + assertTrue(zkClient.pathExists(DelegationTokenChangeNotificationZNode.path)) + + val tokenId = "token1" + val owner = SecurityUtils.parseKafkaPrincipal("User:owner1") + val renewers = List(SecurityUtils.parseKafkaPrincipal("User:renewer1"), SecurityUtils.parseKafkaPrincipal("User:renewer1")) + + val tokenInfo = new TokenInformation(tokenId, owner, renewers.asJava, + System.currentTimeMillis(), System.currentTimeMillis(), System.currentTimeMillis()) + val bytes = new Array[Byte](20) + Random.nextBytes(bytes) + val token = new org.apache.kafka.common.security.token.delegation.DelegationToken(tokenInfo, bytes) + + // test non-existent token + assertTrue(zkClient.getDelegationTokenInfo(tokenId).isEmpty) + assertFalse(zkClient.deleteDelegationToken(tokenId)) + + // create a token + zkClient.setOrCreateDelegationToken(token) + + //get created token + assertEquals(tokenInfo, zkClient.getDelegationTokenInfo(tokenId).get) + + //update expiryTime + tokenInfo.setExpiryTimestamp(System.currentTimeMillis()) + zkClient.setOrCreateDelegationToken(token) + + //test updated token + assertEquals(tokenInfo, zkClient.getDelegationTokenInfo(tokenId).get) + + //test deleting token + assertTrue(zkClient.deleteDelegationToken(tokenId)) + assertEquals(None, zkClient.getDelegationTokenInfo(tokenId)) + } + + @Test + def testConsumerOffsetPath(): Unit = { + def getConsumersOffsetsZkPath(consumerGroup: String, topic: String, partition: Int): String = { + s"/consumers/$consumerGroup/offsets/$topic/$partition" + } + + val consumerGroup = "test-group" + val topic = "test-topic" + val partition = 2 + + val expectedConsumerGroupOffsetsPath = getConsumersOffsetsZkPath(consumerGroup, topic, partition) + val actualConsumerGroupOffsetsPath = ConsumerOffset.path(consumerGroup, topic, partition) + + assertEquals(expectedConsumerGroupOffsetsPath, actualConsumerGroupOffsetsPath) + } + + @Test + def testAclMethods(): Unit = { + val mockPath = "/foo" + + assertThrows(classOf[NoNodeException], () => zkClient.getAcl(mockPath)) + + assertThrows(classOf[NoNodeException], () => zkClient.setAcl(mockPath, ZooDefs.Ids.OPEN_ACL_UNSAFE.asScala)) + + zkClient.createRecursive(mockPath) + + zkClient.setAcl(mockPath, ZooDefs.Ids.READ_ACL_UNSAFE.asScala) + + assertEquals(ZooDefs.Ids.READ_ACL_UNSAFE.asScala, zkClient.getAcl(mockPath)) + } + + @Test + def testJuteMaxBufffer(): Unit = { + + def assertJuteMaxBufferConfig(clientConfig: ZKClientConfig, expectedValue: String): Unit = { + val client = KafkaZkClient(zkConnect, zkAclsEnabled.getOrElse(JaasUtils.isZkSaslEnabled), zkSessionTimeout, + zkConnectionTimeout, zkMaxInFlightRequests, Time.SYSTEM, name = "KafkaZkClient", + zkClientConfig = clientConfig) + try assertEquals(expectedValue, client.currentZooKeeper.getClientConfig.getProperty(ZKConfig.JUTE_MAXBUFFER)) + finally 
client.close() + } + + // default case + assertEquals("4194304", zkClient.currentZooKeeper.getClientConfig.getProperty(ZKConfig.JUTE_MAXBUFFER)) + + // Value set directly on ZKClientConfig takes precedence over system property + System.setProperty(ZKConfig.JUTE_MAXBUFFER, (3000 * 1024).toString) + try { + val clientConfig1 = new ZKClientConfig + clientConfig1.setProperty(ZKConfig.JUTE_MAXBUFFER, (2000 * 1024).toString) + assertJuteMaxBufferConfig(clientConfig1, expectedValue = "2048000") + + // System property value is used if value is not set in ZKClientConfig + assertJuteMaxBufferConfig(new ZKClientConfig, expectedValue = "3072000") + + } finally System.clearProperty(ZKConfig.JUTE_MAXBUFFER) + } + + class ExpiredKafkaZkClient private (zooKeeperClient: ZooKeeperClient, isSecure: Boolean, time: Time) + extends KafkaZkClient(zooKeeperClient, isSecure, time) { + // Overwriting this method from the parent class to force the client to re-register the Broker. + override def shouldReCreateEphemeralZNode(ephemeralOwnerId: Long): Boolean = { + true + } + + def getPathCzxid(path: String): Long = { + val getDataRequest = GetDataRequest(path) + val getDataResponse = retryRequestUntilConnected(getDataRequest) + + getDataResponse.stat.getCzxid + } + } + + private object ExpiredKafkaZkClient { + def apply(connectString: String, + isSecure: Boolean, + sessionTimeoutMs: Int, + connectionTimeoutMs: Int, + maxInFlightRequests: Int, + time: Time, + metricGroup: String = "kafka.server", + metricType: String = "SessionExpireListener") = { + val zooKeeperClient = new ZooKeeperClient(connectString, sessionTimeoutMs, connectionTimeoutMs, maxInFlightRequests, + time, metricGroup, metricType, new ZKClientConfig, "ExpiredKafkaZkClient") + new ExpiredKafkaZkClient(zooKeeperClient, isSecure, time) + } + } +} diff --git a/core/src/test/scala/unit/kafka/zk/ReassignPartitionsZNodeTest.scala b/core/src/test/scala/unit/kafka/zk/ReassignPartitionsZNodeTest.scala new file mode 100644 index 0000000..c2b45d0 --- /dev/null +++ b/core/src/test/scala/unit/kafka/zk/ReassignPartitionsZNodeTest.scala @@ -0,0 +1,55 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package kafka.zk + +import java.nio.charset.StandardCharsets + +import com.fasterxml.jackson.core.JsonProcessingException +import org.apache.kafka.common.TopicPartition +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.Test + +class ReassignPartitionsZNodeTest { + + private val topic = "foo" + private val partition1 = 0 + private val replica1 = 1 + private val replica2 = 2 + + private val reassignPartitionData = Map(new TopicPartition(topic, partition1) -> Seq(replica1, replica2)) + private val reassignmentJson = """{"version":1,"partitions":[{"topic":"foo","partition":0,"replicas":[1,2]}]}""" + + @Test + def testEncode(): Unit = { + val encodedJsonString = new String(ReassignPartitionsZNode.encode(reassignPartitionData), StandardCharsets.UTF_8) + assertEquals(reassignmentJson, encodedJsonString) + } + + @Test + def testDecodeInvalidJson(): Unit = { + val result = ReassignPartitionsZNode.decode("invalid json".getBytes) + val exception = result.left.getOrElse(throw new AssertionError(s"decode should have failed, result $result")) + assertTrue(exception.isInstanceOf[JsonProcessingException]) + } + + @Test + def testDecodeValidJson(): Unit = { + val result = ReassignPartitionsZNode.decode(reassignmentJson.getBytes) + val replicas = result.map(assignmentMap => assignmentMap(new TopicPartition(topic, partition1))) + assertEquals(Right(Seq(replica1, replica2)), replicas) + } +} diff --git a/core/src/test/scala/unit/kafka/zk/ZkFourLetterWords.scala b/core/src/test/scala/unit/kafka/zk/ZkFourLetterWords.scala new file mode 100644 index 0000000..2930ac8 --- /dev/null +++ b/core/src/test/scala/unit/kafka/zk/ZkFourLetterWords.scala @@ -0,0 +1,47 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package kafka.zk + +import java.io.IOException +import java.net.{SocketTimeoutException, Socket, InetAddress, InetSocketAddress} + +/** + * ZooKeeper responds to a small set of commands. Each command is composed of four letters. You issue the commands to + * ZooKeeper via telnet or nc, at the client port. + * + * Three of the more interesting commands: "stat" gives some general information about the server and connected + * clients, while "srvr" and "cons" give extended details on server and connections respectively. 
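+ *
+ * For example, with a ZooKeeper server listening on the default client port, `echo stat | nc localhost 2181`
+ * should print the same server statistics that sendStat below requests programmatically.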
+ */ +object ZkFourLetterWords { + def sendStat(host: String, port: Int, timeout: Int): Unit = { + val hostAddress = + if (host != null) new InetSocketAddress(host, port) + else new InetSocketAddress(InetAddress.getByName(null), port) + val sock = new Socket() + try { + sock.connect(hostAddress, timeout) + val outStream = sock.getOutputStream + outStream.write("stat".getBytes) + outStream.flush() + } catch { + case e: SocketTimeoutException => throw new IOException("Exception while sending 4lw", e) + } finally { + sock.close + } + } +} diff --git a/core/src/test/scala/unit/kafka/zookeeper/ZooKeeperClientTest.scala b/core/src/test/scala/unit/kafka/zookeeper/ZooKeeperClientTest.scala new file mode 100644 index 0000000..5af2ba8 --- /dev/null +++ b/core/src/test/scala/unit/kafka/zookeeper/ZooKeeperClientTest.scala @@ -0,0 +1,724 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package kafka.zookeeper + +import java.nio.charset.StandardCharsets +import java.util.UUID +import java.util.concurrent.atomic.{AtomicBoolean, AtomicInteger} +import java.util.concurrent.{ArrayBlockingQueue, ConcurrentLinkedQueue, CountDownLatch, Executors, Semaphore, TimeUnit} + +import scala.collection.Seq +import com.yammer.metrics.core.{Gauge, Meter, MetricName} +import kafka.server.KafkaConfig +import kafka.metrics.KafkaYammerMetrics +import kafka.utils.TestUtils +import kafka.server.QuorumTestHarness +import org.apache.kafka.common.security.JaasUtils +import org.apache.kafka.common.utils.Time +import org.apache.zookeeper.KeeperException.{Code, NoNodeException} +import org.apache.zookeeper.Watcher.Event.{EventType, KeeperState} +import org.apache.zookeeper.ZooKeeper.States +import org.apache.zookeeper.client.ZKClientConfig +import org.apache.zookeeper.{CreateMode, WatchedEvent, ZooDefs} +import org.junit.jupiter.api.Assertions.{assertArrayEquals, assertEquals, assertFalse, assertThrows, assertTrue, fail} +import org.junit.jupiter.api.{AfterEach, BeforeEach, Test, TestInfo} + +import scala.jdk.CollectionConverters._ + +class ZooKeeperClientTest extends QuorumTestHarness { + private val mockPath = "/foo" + private val time = Time.SYSTEM + + private var zooKeeperClient: ZooKeeperClient = _ + + @BeforeEach + override def setUp(testInfo: TestInfo): Unit = { + TestUtils.verifyNoUnexpectedThreads("@BeforeEach") + cleanMetricsRegistry() + super.setUp(testInfo) + zooKeeperClient = newZooKeeperClient() + } + + @AfterEach + override def tearDown(): Unit = { + if (zooKeeperClient != null) + zooKeeperClient.close() + super.tearDown() + System.clearProperty(JaasUtils.JAVA_LOGIN_CONFIG_PARAM) + TestUtils.verifyNoUnexpectedThreads("@AfterEach") + } + + @Test + def testUnresolvableConnectString(): Unit = { + try { + newZooKeeperClient("some.invalid.hostname.foo.bar.local", 
connectionTimeoutMs = 10) + } catch { + case e: ZooKeeperClientTimeoutException => + assertEquals(Set.empty, runningZkSendThreads, "ZooKeeper client threads still running") + } + } + + private def runningZkSendThreads: collection.Set[String] = Thread.getAllStackTraces.keySet.asScala + .filter(_.isAlive) + .map(_.getName) + .filter(t => t.contains("SendThread()")) + + @Test + def testConnectionTimeout(): Unit = { + zookeeper.shutdown() + assertThrows(classOf[ZooKeeperClientTimeoutException], () => newZooKeeperClient( + connectionTimeoutMs = 10).close()) + } + + @Test + def testConnection(): Unit = { + val client = newZooKeeperClient() + try { + // Verify ZooKeeper event thread name. This is used in QuorumTestHarness to verify that tests have closed ZK clients + val threads = Thread.getAllStackTraces.keySet.asScala.map(_.getName) + assertTrue(threads.exists(_.contains(QuorumTestHarness.ZkClientEventThreadSuffix)), + s"ZooKeeperClient event thread not found, threads=$threads") + } finally { + client.close() + } + } + + @Test + def testConnectionViaNettyClient(): Unit = { + // Confirm that we can explicitly set client connection configuration, which is necessary for TLS. + // TLS connectivity itself is tested in system tests rather than here to avoid having to add TLS support + // to kafka.zk.EmbeddedZookeeper + val clientConfig = new ZKClientConfig() + val propKey = KafkaConfig.ZkClientCnxnSocketProp + val propVal = "org.apache.zookeeper.ClientCnxnSocketNetty" + KafkaConfig.setZooKeeperClientProperty(clientConfig, propKey, propVal) + val client = newZooKeeperClient(clientConfig = clientConfig) + try { + assertEquals(Some(propVal), KafkaConfig.zooKeeperClientProperty(client.clientConfig, propKey)) + // For a sanity check, make sure a bad client connection socket class name generates an exception + val badClientConfig = new ZKClientConfig() + KafkaConfig.setZooKeeperClientProperty(badClientConfig, propKey, propVal + "BadClassName") + assertThrows(classOf[Exception], () => newZooKeeperClient(clientConfig = badClientConfig)) + } finally { + client.close() + } + } + + @Test + def testDeleteNonExistentZNode(): Unit = { + val deleteResponse = zooKeeperClient.handleRequest(DeleteRequest(mockPath, -1)) + assertEquals(Code.NONODE, deleteResponse.resultCode, "Response code should be NONODE") + assertThrows(classOf[NoNodeException], () => deleteResponse.maybeThrow()) + } + + @Test + def testDeleteExistingZNode(): Unit = { + val createResponse = zooKeeperClient.handleRequest(CreateRequest(mockPath, Array.empty[Byte], + ZooDefs.Ids.OPEN_ACL_UNSAFE.asScala, CreateMode.PERSISTENT)) + assertEquals(Code.OK, createResponse.resultCode, "Response code for create should be OK") + val deleteResponse = zooKeeperClient.handleRequest(DeleteRequest(mockPath, -1)) + assertEquals(Code.OK, deleteResponse.resultCode, "Response code for delete should be OK") + } + + @Test + def testExistsNonExistentZNode(): Unit = { + val existsResponse = zooKeeperClient.handleRequest(ExistsRequest(mockPath)) + assertEquals(Code.NONODE, existsResponse.resultCode, "Response code should be NONODE") + } + + @Test + def testExistsExistingZNode(): Unit = { + val createResponse = zooKeeperClient.handleRequest(CreateRequest(mockPath, Array.empty[Byte], + ZooDefs.Ids.OPEN_ACL_UNSAFE.asScala, CreateMode.PERSISTENT)) + assertEquals(Code.OK, createResponse.resultCode, "Response code for create should be OK") + val existsResponse = zooKeeperClient.handleRequest(ExistsRequest(mockPath)) + assertEquals(Code.OK, existsResponse.resultCode, "Response code 
for exists should be OK") + } + + @Test + def testGetDataNonExistentZNode(): Unit = { + val getDataResponse = zooKeeperClient.handleRequest(GetDataRequest(mockPath)) + assertEquals(Code.NONODE, getDataResponse.resultCode, "Response code should be NONODE") + } + + @Test + def testGetDataExistingZNode(): Unit = { + val data = bytes + val createResponse = zooKeeperClient.handleRequest(CreateRequest(mockPath, data, ZooDefs.Ids.OPEN_ACL_UNSAFE.asScala, + CreateMode.PERSISTENT)) + assertEquals(Code.OK, createResponse.resultCode, "Response code for create should be OK") + val getDataResponse = zooKeeperClient.handleRequest(GetDataRequest(mockPath)) + assertEquals(Code.OK, getDataResponse.resultCode, "Response code for getData should be OK") + assertArrayEquals(data, getDataResponse.data, "Data for getData should match created znode data") + } + + @Test + def testSetDataNonExistentZNode(): Unit = { + val setDataResponse = zooKeeperClient.handleRequest(SetDataRequest(mockPath, Array.empty[Byte], -1)) + assertEquals(Code.NONODE, setDataResponse.resultCode, "Response code should be NONODE") + } + + @Test + def testSetDataExistingZNode(): Unit = { + val data = bytes + val createResponse = zooKeeperClient.handleRequest(CreateRequest(mockPath, Array.empty[Byte], + ZooDefs.Ids.OPEN_ACL_UNSAFE.asScala, CreateMode.PERSISTENT)) + assertEquals(Code.OK, createResponse.resultCode, "Response code for create should be OK") + val setDataResponse = zooKeeperClient.handleRequest(SetDataRequest(mockPath, data, -1)) + assertEquals(Code.OK, setDataResponse.resultCode, "Response code for setData should be OK") + val getDataResponse = zooKeeperClient.handleRequest(GetDataRequest(mockPath)) + assertEquals(Code.OK, getDataResponse.resultCode, "Response code for getData should be OK") + assertArrayEquals(data, getDataResponse.data, "Data for getData should match setData's data") + } + + @Test + def testGetAclNonExistentZNode(): Unit = { + val getAclResponse = zooKeeperClient.handleRequest(GetAclRequest(mockPath)) + assertEquals(Code.NONODE, getAclResponse.resultCode, "Response code should be NONODE") + } + + @Test + def testGetAclExistingZNode(): Unit = { + val createResponse = zooKeeperClient.handleRequest(CreateRequest(mockPath, Array.empty[Byte], ZooDefs.Ids.OPEN_ACL_UNSAFE.asScala, CreateMode.PERSISTENT)) + assertEquals(Code.OK, createResponse.resultCode, "Response code for create should be OK") + val getAclResponse = zooKeeperClient.handleRequest(GetAclRequest(mockPath)) + assertEquals(Code.OK, getAclResponse.resultCode, "Response code for getAcl should be OK") + assertEquals(ZooDefs.Ids.OPEN_ACL_UNSAFE.asScala, getAclResponse.acl, "ACL should be " + ZooDefs.Ids.OPEN_ACL_UNSAFE.asScala) + } + + @Test + def testSetAclNonExistentZNode(): Unit = { + val setAclResponse = zooKeeperClient.handleRequest(SetAclRequest(mockPath, ZooDefs.Ids.OPEN_ACL_UNSAFE.asScala, -1)) + assertEquals(Code.NONODE, setAclResponse.resultCode, "Response code should be NONODE") + } + + @Test + def testGetChildrenNonExistentZNode(): Unit = { + val getChildrenResponse = zooKeeperClient.handleRequest(GetChildrenRequest(mockPath, registerWatch = true)) + assertEquals(Code.NONODE, getChildrenResponse.resultCode, "Response code should be NONODE") + } + + @Test + def testGetChildrenExistingZNode(): Unit = { + val createResponse = zooKeeperClient.handleRequest(CreateRequest(mockPath, Array.empty[Byte], + ZooDefs.Ids.OPEN_ACL_UNSAFE.asScala, CreateMode.PERSISTENT)) + assertEquals(Code.OK, createResponse.resultCode, "Response code for create should be OK") 
+ val getChildrenResponse = zooKeeperClient.handleRequest(GetChildrenRequest(mockPath, registerWatch = true)) + assertEquals(Code.OK, getChildrenResponse.resultCode, "Response code for getChildren should be OK") + assertEquals(Seq.empty[String], getChildrenResponse.children, "getChildren should return no children") + } + + @Test + def testGetChildrenExistingZNodeWithChildren(): Unit = { + val child1 = "child1" + val child2 = "child2" + val child1Path = mockPath + "/" + child1 + val child2Path = mockPath + "/" + child2 + val createResponse = zooKeeperClient.handleRequest(CreateRequest(mockPath, Array.empty[Byte], + ZooDefs.Ids.OPEN_ACL_UNSAFE.asScala, CreateMode.PERSISTENT)) + assertEquals(Code.OK, createResponse.resultCode, "Response code for create should be OK") + val createResponseChild1 = zooKeeperClient.handleRequest(CreateRequest(child1Path, Array.empty[Byte], + ZooDefs.Ids.OPEN_ACL_UNSAFE.asScala, CreateMode.PERSISTENT)) + assertEquals(Code.OK, createResponseChild1.resultCode, "Response code for create child1 should be OK") + val createResponseChild2 = zooKeeperClient.handleRequest(CreateRequest(child2Path, Array.empty[Byte], + ZooDefs.Ids.OPEN_ACL_UNSAFE.asScala, CreateMode.PERSISTENT)) + assertEquals(Code.OK, createResponseChild2.resultCode, "Response code for create child2 should be OK") + + val getChildrenResponse = zooKeeperClient.handleRequest(GetChildrenRequest(mockPath, registerWatch = true)) + assertEquals(Code.OK, getChildrenResponse.resultCode, "Response code for getChildren should be OK") + assertEquals(Seq(child1, child2), getChildrenResponse.children.sorted, "getChildren should return two children") + } + + @Test + def testPipelinedGetData(): Unit = { + val createRequests = (1 to 3).map(x => CreateRequest("/" + x, (x * 2).toString.getBytes, ZooDefs.Ids.OPEN_ACL_UNSAFE.asScala, CreateMode.PERSISTENT)) + val createResponses = createRequests.map(zooKeeperClient.handleRequest) + createResponses.foreach(createResponse => assertEquals(Code.OK, createResponse.resultCode, "Response code for create should be OK")) + val getDataRequests = (1 to 3).map(x => GetDataRequest("/" + x)) + val getDataResponses = zooKeeperClient.handleRequests(getDataRequests) + getDataResponses.foreach(getDataResponse => assertEquals(Code.OK, getDataResponse.resultCode, + "Response code for getData should be OK")) + getDataResponses.zipWithIndex.foreach { case (getDataResponse, i) => + assertEquals(Code.OK, getDataResponse.resultCode, "Response code for getData should be OK") + assertEquals(((i + 1) * 2), Integer.valueOf(new String(getDataResponse.data)), "Data for getData should match") + } + } + + @Test + def testMixedPipeline(): Unit = { + val createResponse = zooKeeperClient.handleRequest(CreateRequest(mockPath, Array.empty[Byte], + ZooDefs.Ids.OPEN_ACL_UNSAFE.asScala, CreateMode.PERSISTENT)) + assertEquals(Code.OK, createResponse.resultCode, "Response code for create should be OK") + val getDataRequest = GetDataRequest(mockPath) + val setDataRequest = SetDataRequest("/nonexistent", Array.empty[Byte], -1) + val responses = zooKeeperClient.handleRequests(Seq(getDataRequest, setDataRequest)) + assertEquals(Code.OK, responses.head.resultCode, "Response code for getData should be OK") + assertArrayEquals(Array.empty[Byte], responses.head.asInstanceOf[GetDataResponse].data, "Data for getData should be empty") + assertEquals(Code.NONODE, responses.last.resultCode, "Response code for setData should be NONODE") + } + + @Test + def testZNodeChangeHandlerForCreation(): Unit = { + val 
znodeChangeHandlerCountDownLatch = new CountDownLatch(1) + val zNodeChangeHandler = new ZNodeChangeHandler { + override def handleCreation(): Unit = { + znodeChangeHandlerCountDownLatch.countDown() + } + override val path: String = mockPath + } + + zooKeeperClient.registerZNodeChangeHandler(zNodeChangeHandler) + val existsRequest = ExistsRequest(mockPath) + val createRequest = CreateRequest(mockPath, Array.empty[Byte], ZooDefs.Ids.OPEN_ACL_UNSAFE.asScala, CreateMode.PERSISTENT) + val responses = zooKeeperClient.handleRequests(Seq(existsRequest, createRequest)) + assertEquals(Code.NONODE, responses.head.resultCode, "Response code for exists should be NONODE") + assertEquals(Code.OK, responses.last.resultCode, "Response code for create should be OK") + assertTrue(znodeChangeHandlerCountDownLatch.await(5, TimeUnit.SECONDS), "Failed to receive create notification") + } + + @Test + def testZNodeChangeHandlerForDeletion(): Unit = { + val znodeChangeHandlerCountDownLatch = new CountDownLatch(1) + val zNodeChangeHandler = new ZNodeChangeHandler { + override def handleDeletion(): Unit = { + znodeChangeHandlerCountDownLatch.countDown() + } + override val path: String = mockPath + } + + zooKeeperClient.registerZNodeChangeHandler(zNodeChangeHandler) + val existsRequest = ExistsRequest(mockPath) + val createRequest = CreateRequest(mockPath, Array.empty[Byte], ZooDefs.Ids.OPEN_ACL_UNSAFE.asScala, CreateMode.PERSISTENT) + val responses = zooKeeperClient.handleRequests(Seq(createRequest, existsRequest)) + assertEquals(Code.OK, responses.last.resultCode, "Response code for create should be OK") + assertEquals(Code.OK, responses.head.resultCode, "Response code for exists should be OK") + val deleteResponse = zooKeeperClient.handleRequest(DeleteRequest(mockPath, -1)) + assertEquals(Code.OK, deleteResponse.resultCode, "Response code for delete should be OK") + assertTrue(znodeChangeHandlerCountDownLatch.await(5, TimeUnit.SECONDS), "Failed to receive delete notification") + } + + @Test + def testZNodeChangeHandlerForDataChange(): Unit = { + val znodeChangeHandlerCountDownLatch = new CountDownLatch(1) + val zNodeChangeHandler = new ZNodeChangeHandler { + override def handleDataChange(): Unit = { + znodeChangeHandlerCountDownLatch.countDown() + } + override val path: String = mockPath + } + + zooKeeperClient.registerZNodeChangeHandler(zNodeChangeHandler) + val existsRequest = ExistsRequest(mockPath) + val createRequest = CreateRequest(mockPath, Array.empty[Byte], ZooDefs.Ids.OPEN_ACL_UNSAFE.asScala, CreateMode.PERSISTENT) + val responses = zooKeeperClient.handleRequests(Seq(createRequest, existsRequest)) + assertEquals(Code.OK, responses.last.resultCode, "Response code for create should be OK") + assertEquals(Code.OK, responses.head.resultCode, "Response code for exists should be OK") + val setDataResponse = zooKeeperClient.handleRequest(SetDataRequest(mockPath, Array.empty[Byte], -1)) + assertEquals(Code.OK, setDataResponse.resultCode, "Response code for setData should be OK") + assertTrue(znodeChangeHandlerCountDownLatch.await(5, TimeUnit.SECONDS), "Failed to receive data change notification") + } + + @Test + def testBlockOnRequestCompletionFromStateChangeHandler(): Unit = { + // This tests the scenario exposed by KAFKA-6879 in which the expiration callback awaits + // completion of a request which is handled by another thread + + val latch = new CountDownLatch(1) + val stateChangeHandler = new StateChangeHandler { + override val name = this.getClass.getName + override def beforeInitializingSession(): Unit = { 
+ latch.await() + } + } + + zooKeeperClient.close() + zooKeeperClient = newZooKeeperClient() + zooKeeperClient.registerStateChangeHandler(stateChangeHandler) + + val requestThread = new Thread() { + override def run(): Unit = { + try + zooKeeperClient.handleRequest(CreateRequest(mockPath, Array.empty[Byte], + ZooDefs.Ids.OPEN_ACL_UNSAFE.asScala, CreateMode.PERSISTENT)) + finally + latch.countDown() + } + } + + val reinitializeThread = new Thread() { + override def run(): Unit = { + zooKeeperClient.forceReinitialize() + } + } + + reinitializeThread.start() + + // sleep briefly before starting the request thread so that the initialization + // thread is blocking on the latch + Thread.sleep(100) + requestThread.start() + + reinitializeThread.join() + requestThread.join() + } + + @Test + def testExceptionInBeforeInitializingSession(): Unit = { + val faultyHandler = new StateChangeHandler { + override val name = this.getClass.getName + override def beforeInitializingSession(): Unit = { + throw new RuntimeException() + } + } + + val goodCalls = new AtomicInteger(0) + val goodHandler = new StateChangeHandler { + override val name = this.getClass.getName + override def beforeInitializingSession(): Unit = { + goodCalls.incrementAndGet() + } + } + + zooKeeperClient.close() + zooKeeperClient = newZooKeeperClient() + zooKeeperClient.registerStateChangeHandler(faultyHandler) + zooKeeperClient.registerStateChangeHandler(goodHandler) + + zooKeeperClient.forceReinitialize() + + assertEquals(1, goodCalls.get) + + // Client should be usable even if the callback throws an error + val createResponse = zooKeeperClient.handleRequest(CreateRequest(mockPath, Array.empty[Byte], + ZooDefs.Ids.OPEN_ACL_UNSAFE.asScala, CreateMode.PERSISTENT)) + assertEquals(Code.OK, createResponse.resultCode, "Response code for create should be OK") + } + + @Test + def testZNodeChildChangeHandlerForChildChange(): Unit = { + val zNodeChildChangeHandlerCountDownLatch = new CountDownLatch(1) + val zNodeChildChangeHandler = new ZNodeChildChangeHandler { + override def handleChildChange(): Unit = { + zNodeChildChangeHandlerCountDownLatch.countDown() + } + override val path: String = mockPath + } + + val child1 = "child1" + val child1Path = mockPath + "/" + child1 + val createResponse = zooKeeperClient.handleRequest( + CreateRequest(mockPath, Array.empty[Byte], ZooDefs.Ids.OPEN_ACL_UNSAFE.asScala, CreateMode.PERSISTENT)) + assertEquals(Code.OK, createResponse.resultCode, "Response code for create should be OK") + zooKeeperClient.registerZNodeChildChangeHandler(zNodeChildChangeHandler) + val getChildrenResponse = zooKeeperClient.handleRequest(GetChildrenRequest(mockPath, registerWatch = true)) + assertEquals(Code.OK, getChildrenResponse.resultCode, "Response code for getChildren should be OK") + val createResponseChild1 = zooKeeperClient.handleRequest( + CreateRequest(child1Path, Array.empty[Byte], ZooDefs.Ids.OPEN_ACL_UNSAFE.asScala, CreateMode.PERSISTENT)) + assertEquals(Code.OK, createResponseChild1.resultCode, "Response code for create child1 should be OK") + assertTrue(zNodeChildChangeHandlerCountDownLatch.await(5, TimeUnit.SECONDS), + "Failed to receive child change notification") + } + + @Test + def testZNodeChildChangeHandlerForChildChangeNotTriggered(): Unit = { + val zNodeChildChangeHandlerCountDownLatch = new CountDownLatch(1) + val zNodeChildChangeHandler = new ZNodeChildChangeHandler { + override def handleChildChange(): Unit = { + zNodeChildChangeHandlerCountDownLatch.countDown() + } + override val path: String = mockPath + } 
+ + val child1 = "child1" + val child1Path = mockPath + "/" + child1 + val createResponse = zooKeeperClient.handleRequest( + CreateRequest(mockPath, Array.empty[Byte], ZooDefs.Ids.OPEN_ACL_UNSAFE.asScala, CreateMode.PERSISTENT)) + assertEquals(Code.OK, createResponse.resultCode, "Response code for create should be OK") + zooKeeperClient.registerZNodeChildChangeHandler(zNodeChildChangeHandler) + val getChildrenResponse = zooKeeperClient.handleRequest(GetChildrenRequest(mockPath, registerWatch = false)) + assertEquals(Code.OK, getChildrenResponse.resultCode, "Response code for getChildren should be OK") + val createResponseChild1 = zooKeeperClient.handleRequest( + CreateRequest(child1Path, Array.empty[Byte], ZooDefs.Ids.OPEN_ACL_UNSAFE.asScala, CreateMode.PERSISTENT)) + assertEquals(Code.OK, createResponseChild1.resultCode, "Response code for create child1 should be OK") + assertFalse(zNodeChildChangeHandlerCountDownLatch.await(100, TimeUnit.MILLISECONDS), + "Child change notification received") + } + + @Test + def testStateChangeHandlerForAuthFailure(): Unit = { + System.setProperty(JaasUtils.JAVA_LOGIN_CONFIG_PARAM, "no-such-file-exists.conf") + val stateChangeHandlerCountDownLatch = new CountDownLatch(1) + val stateChangeHandler = new StateChangeHandler { + override val name: String = this.getClass.getName + + override def onAuthFailure(): Unit = { + stateChangeHandlerCountDownLatch.countDown() + } + } + + val zooKeeperClient = newZooKeeperClient() + try { + zooKeeperClient.registerStateChangeHandler(stateChangeHandler) + zooKeeperClient.forceReinitialize() + + assertTrue(stateChangeHandlerCountDownLatch.await(5, TimeUnit.SECONDS), "Failed to receive auth failed notification") + } finally zooKeeperClient.close() + } + + @Test + def testConnectionLossRequestTermination(): Unit = { + val batchSize = 10 + val zooKeeperClient = newZooKeeperClient(maxInFlight = 2) + zookeeper.shutdown() + try { + val requests = (1 to batchSize).map(i => GetDataRequest(s"/$i")) + val countDownLatch = new CountDownLatch(1) + val running = new AtomicBoolean(true) + val unexpectedResponses = new ArrayBlockingQueue[GetDataResponse](batchSize) + val requestThread = new Thread { + override def run(): Unit = { + while (running.get()) { + val responses = zooKeeperClient.handleRequests(requests) + val suffix = responses.dropWhile(response => response.resultCode != Code.CONNECTIONLOSS) + if (!suffix.forall(response => response.resultCode == Code.CONNECTIONLOSS)) + responses.foreach(unexpectedResponses.add) + if (!unexpectedResponses.isEmpty || suffix.nonEmpty) + running.set(false) + } + countDownLatch.countDown() + } + } + requestThread.start() + val requestThreadTerminated = countDownLatch.await(30, TimeUnit.SECONDS) + if (!requestThreadTerminated) { + running.set(false) + requestThread.join(5000) + fail("Failed to receive a CONNECTIONLOSS response code after zookeeper has shutdown.") + } else if (!unexpectedResponses.isEmpty) { + fail(s"Received an unexpected non-CONNECTIONLOSS response code after a CONNECTIONLOSS response code from a single batch: $unexpectedResponses") + } + } finally zooKeeperClient.close() + } + + /** + * Tests that if session expiry notification is received while a thread is processing requests, + * session expiry is handled and the request thread completes with responses to all requests, + * even though some requests may fail due to session expiry or disconnection. 
+ * + * Sequence of events on different threads: + * Request thread: + * - Sends `maxInflightRequests` requests (these may complete before session is expired) + * Main thread: + * - Waits for at least one request to be processed (this should succeed) + * - Expires session by creating new client with same session id + * - Unblocks another `maxInflightRequests` requests before and after new client is closed (these may fail) + * ZooKeeperClient Event thread: + * - Delivers responses and session expiry (no ordering guarantee between these, both are processed asynchronously) + * Response executor thread: + * - Blocks subsequent sends by delaying response until session expiry is processed + * ZooKeeperClient Session Expiry Handler: + * - Unblocks subsequent sends + * Main thread: + * - Waits for all sends to complete. The requests sent after session expiry processing should succeed. + */ + @Test + def testSessionExpiry(): Unit = { + val maxInflightRequests = 2 + val responseExecutor = Executors.newSingleThreadExecutor + val sendSemaphore = new Semaphore(0) + val sendCompleteSemaphore = new Semaphore(0) + val sendSize = maxInflightRequests * 5 + @volatile var resultCodes: Seq[Code] = null + val stateChanges = new ConcurrentLinkedQueue[String]() + val zooKeeperClient = new ZooKeeperClient(zkConnect, zkSessionTimeout, zkConnectionTimeout, maxInflightRequests, + time, "testGroupType", "testGroupName", new ZKClientConfig, "ZooKeeperClientTest") { + override def send[Req <: AsyncRequest](request: Req)(processResponse: Req#Response => Unit): Unit = { + super.send(request)( response => { + responseExecutor.submit(new Runnable { + override def run(): Unit = { + sendCompleteSemaphore.release() + sendSemaphore.acquire() + processResponse(response) + } + }) + }) + } + } + try { + zooKeeperClient.registerStateChangeHandler(new StateChangeHandler { + override val name: String ="test-state-change-handler" + override def afterInitializingSession(): Unit = { + verifyHandlerThread() + stateChanges.add("afterInitializingSession") + } + override def beforeInitializingSession(): Unit = { + verifyHandlerThread() + stateChanges.add("beforeInitializingSession") + sendSemaphore.release(sendSize) // Resume remaining sends + } + private def verifyHandlerThread(): Unit = { + val threadName = Thread.currentThread.getName + assertTrue(threadName.startsWith(zooKeeperClient.reinitializeScheduler.threadNamePrefix), s"Unexpected thread + $threadName") + } + }) + + val requestThread = new Thread { + override def run(): Unit = { + val requests = (1 to sendSize).map(i => GetDataRequest(s"/$i")) + resultCodes = zooKeeperClient.handleRequests(requests).map(_.resultCode) + } + } + requestThread.start() + sendCompleteSemaphore.acquire() // Wait for request thread to start processing requests + + val anotherZkClient = createZooKeeperClientToTriggerSessionExpiry(zooKeeperClient.currentZooKeeper) + sendSemaphore.release(maxInflightRequests) // Resume a few more sends which may fail + anotherZkClient.close() + sendSemaphore.release(maxInflightRequests) // Resume a few more sends which may fail + + requestThread.join(10000) + if (requestThread.isAlive) { + requestThread.interrupt() + fail("Request thread did not complete") + } + assertEquals(Seq("beforeInitializingSession", "afterInitializingSession"), stateChanges.asScala.toSeq) + + assertEquals(resultCodes.size, sendSize) + val connectionLostCount = resultCodes.count(_ == Code.CONNECTIONLOSS) + assertTrue(connectionLostCount <= maxInflightRequests, s"Unexpected connection lost requests 
$resultCodes") + val expiredCount = resultCodes.count(_ == Code.SESSIONEXPIRED) + assertTrue(expiredCount <= maxInflightRequests, s"Unexpected session expired requests $resultCodes") + assertTrue(connectionLostCount + expiredCount > 0, s"No connection lost or expired requests $resultCodes") + assertEquals(Code.NONODE, resultCodes.head) + assertEquals(Code.NONODE, resultCodes.last) + assertTrue(resultCodes.forall(Set(Code.NONODE, Code.SESSIONEXPIRED, Code.CONNECTIONLOSS).contains), + s"Unexpected result code $resultCodes") + + } finally { + zooKeeperClient.close() + responseExecutor.shutdownNow() + } + assertFalse(zooKeeperClient.reinitializeScheduler.isStarted, "Expiry executor not shutdown") + } + + @Test + def testSessionExpiryDuringClose(): Unit = { + val semaphore = new Semaphore(0) + val closeExecutor = Executors.newSingleThreadExecutor + try { + zooKeeperClient.reinitializeScheduler.schedule("test", () => semaphore.acquireUninterruptibly(), + delay = 0, period = -1, TimeUnit.SECONDS) + zooKeeperClient.scheduleReinitialize("session-expired", "Session expired.", delayMs = 0L) + val closeFuture = closeExecutor.submit(new Runnable { + override def run(): Unit = { + zooKeeperClient.close() + } + }) + assertFalse(closeFuture.isDone, "Close completed without shutting down expiry scheduler gracefully") + assertTrue(zooKeeperClient.currentZooKeeper.getState.isAlive) // Client should be closed after expiry handler + semaphore.release() + closeFuture.get(10, TimeUnit.SECONDS) + assertFalse(zooKeeperClient.reinitializeScheduler.isStarted, "Expiry executor not shutdown") + } finally { + closeExecutor.shutdownNow() + } + } + + @Test + def testReinitializeAfterAuthFailure(): Unit = { + val sessionInitializedCountDownLatch = new CountDownLatch(1) + val changeHandler = new StateChangeHandler { + override val name = this.getClass.getName + override def beforeInitializingSession(): Unit = { + sessionInitializedCountDownLatch.countDown() + } + } + + zooKeeperClient.close() + @volatile var connectionStateOverride: Option[States] = None + zooKeeperClient = new ZooKeeperClient(zkConnect, zkSessionTimeout, zkConnectionTimeout, + zkMaxInFlightRequests, time, "testMetricGroup", "testMetricType", new ZKClientConfig, "ZooKeeperClientTest") { + override def connectionState: States = connectionStateOverride.getOrElse(super.connectionState) + } + zooKeeperClient.registerStateChangeHandler(changeHandler) + + connectionStateOverride = Some(States.CONNECTED) + zooKeeperClient.ZooKeeperClientWatcher.process(new WatchedEvent(EventType.None, KeeperState.AuthFailed, null)) + assertFalse(sessionInitializedCountDownLatch.await(10, TimeUnit.MILLISECONDS), "Unexpected session initialization when connection is alive") + + connectionStateOverride = Some(States.AUTH_FAILED) + zooKeeperClient.ZooKeeperClientWatcher.process(new WatchedEvent(EventType.None, KeeperState.AuthFailed, null)) + assertTrue(sessionInitializedCountDownLatch.await(5, TimeUnit.SECONDS), "Failed to receive session initializing notification") + } + + def isExpectedMetricName(metricName: MetricName, name: String): Boolean = + metricName.getName == name && metricName.getGroup == "testMetricGroup" && metricName.getType == "testMetricType" + + @Test + def testZooKeeperStateChangeRateMetrics(): Unit = { + def checkMeterCount(name: String, expected: Long): Unit = { + val meter = KafkaYammerMetrics.defaultRegistry.allMetrics.asScala.collectFirst { + case (metricName, meter: Meter) if isExpectedMetricName(metricName, name) => meter + }.getOrElse(sys.error(s"Unable to 
find meter with name $name")) + assertEquals(expected, meter.count, s"Unexpected meter count for $name") + } + + val expiresPerSecName = "ZooKeeperExpiresPerSec" + val disconnectsPerSecName = "ZooKeeperDisconnectsPerSec" + checkMeterCount(expiresPerSecName, 0) + checkMeterCount(disconnectsPerSecName, 0) + + zooKeeperClient.ZooKeeperClientWatcher.process(new WatchedEvent(EventType.None, KeeperState.Expired, null)) + checkMeterCount(expiresPerSecName, 1) + checkMeterCount(disconnectsPerSecName, 0) + + zooKeeperClient.ZooKeeperClientWatcher.process(new WatchedEvent(EventType.None, KeeperState.Disconnected, null)) + checkMeterCount(expiresPerSecName, 1) + checkMeterCount(disconnectsPerSecName, 1) + } + + @Test + def testZooKeeperSessionStateMetric(): Unit = { + def gaugeValue(name: String): Option[String] = { + KafkaYammerMetrics.defaultRegistry.allMetrics.asScala.collectFirst { + case (metricName, gauge: Gauge[_]) if isExpectedMetricName(metricName, name) => gauge.value.asInstanceOf[String] + } + } + + assertEquals(Some(States.CONNECTED.toString), gaugeValue("SessionState")) + assertEquals(States.CONNECTED, zooKeeperClient.connectionState) + + zooKeeperClient.close() + + assertEquals(None, gaugeValue("SessionState")) + assertEquals(States.CLOSED, zooKeeperClient.connectionState) + } + + private def newZooKeeperClient(connectionString: String = zkConnect, + connectionTimeoutMs: Int = zkConnectionTimeout, + maxInFlight: Int = zkMaxInFlightRequests, + clientConfig: ZKClientConfig = new ZKClientConfig) = + new ZooKeeperClient(connectionString, zkSessionTimeout, connectionTimeoutMs, maxInFlight, time, + "testMetricGroup", "testMetricType", clientConfig, "ZooKeeperClientTest") + + private def cleanMetricsRegistry(): Unit = { + val metrics = KafkaYammerMetrics.defaultRegistry + metrics.allMetrics.keySet.forEach(metrics.removeMetric) + } + + private def bytes = UUID.randomUUID().toString.getBytes(StandardCharsets.UTF_8) +} diff --git a/doap_Kafka.rdf b/doap_Kafka.rdf new file mode 100644 index 0000000..5eddf47 --- /dev/null +++ b/doap_Kafka.rdf @@ -0,0 +1,50 @@ + + + + + + 2014-04-12 + + Apache Kafka + + + Apache Kafka is a distributed, fault tolerant, publish-subscribe messaging. + A single Kafka broker can handle hundreds of megabytes of reads and writes per second from thousands of clients. Kafka is designed to allow a single cluster to serve as the central data backbone for a large organization. It can be elastically and transparently expanded without downtime. Data streams are partitioned and spread over a cluster of machines to allow data streams larger than the capability of any single machine and to allow clusters of co-ordinated consumers. Kafka has a modern cluster-centric design that offers strong durability and fault-tolerance guarantees. Messages are persisted on disk and replicated within the cluster to prevent data loss. Each broker can handle terabytes of messages without performance impact. + + + + Scala + + + + + + + + + + Jun Rao + + + + + diff --git a/docs/api.html b/docs/api.html new file mode 100644 index 0000000..7b74d04 --- /dev/null +++ b/docs/api.html @@ -0,0 +1,110 @@ + + + +
                diff --git a/docs/configuration.html b/docs/configuration.html new file mode 100644 index 0000000..0782c83 --- /dev/null +++ b/docs/configuration.html @@ -0,0 +1,271 @@ + + + + +
                diff --git a/docs/connect.html b/docs/connect.html new file mode 100644 index 0000000..07f8778 --- /dev/null +++ b/docs/connect.html @@ -0,0 +1,753 @@ + + + + +
                diff --git a/docs/design.html b/docs/design.html new file mode 100644 index 0000000..8b7bf92 --- /dev/null +++ b/docs/design.html @@ -0,0 +1,661 @@ + + + + +
                diff --git a/docs/documentation.html b/docs/documentation.html new file mode 100644 index 0000000..d13f691 --- /dev/null +++ b/docs/documentation.html @@ -0,0 +1,108 @@ + + + + + + + + +
+Documentation
+Kafka 3.1 Documentation
+Prior releases: 0.7.x, 0.8.0, 0.8.1.X, 0.8.2.X, 0.9.0.X, 0.10.0.X, 0.10.1.X, 0.10.2.X, 0.11.0.X, 1.0.X, 1.1.X, 2.0.X, 2.1.X, 2.2.X, 2.3.X, 2.4.X, 2.5.X, 2.6.X, 2.7.X, 2.8.X, 3.0.X.
+1. Getting Started
+    1.1 Introduction
+    1.2 Use Cases
+    1.3 Quick Start
+    1.4 Ecosystem
+    1.5 Upgrading From Previous Versions
+2. APIs
+3. Configuration
+4. Design
+5. Implementation
+6. Operations
+7. Security
+8. Kafka Connect
+9. Kafka Streams
+Kafka Streams is a client library for processing and analyzing data stored in Kafka. It builds upon important stream processing concepts such as properly distinguishing between event time and processing time, windowing support, exactly-once processing semantics and simple yet efficient management of application state.
+Kafka Streams has a low barrier to entry: you can quickly write and run a small-scale proof-of-concept on a single machine, and you only need to run additional instances of your application on multiple machines to scale up to high-volume production workloads. Kafka Streams transparently handles the load balancing of multiple instances of the same application by leveraging Kafka's parallelism model.
+Learn more about Kafka Streams in the Kafka Streams section of the documentation.

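The Kafka Streams description above is prose only. As a rough, illustrative sketch of the "low barrier to entry" point (not part of this patch; the application id, the localhost:9092 bootstrap address, and the topic names are assumed placeholders), the following minimal Scala program builds a Streams topology that copies every record from one topic to another:

import java.util.Properties

import org.apache.kafka.common.serialization.Serdes
import org.apache.kafka.streams.{KafkaStreams, StreamsBuilder, StreamsConfig}

object StreamsPassThroughExample extends App {
  // Assumed configuration values; adjust to the target environment.
  val props = new Properties()
  props.put(StreamsConfig.APPLICATION_ID_CONFIG, "streams-poc")
  props.put(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9092")
  props.put(StreamsConfig.DEFAULT_KEY_SERDE_CLASS_CONFIG, Serdes.String().getClass)
  props.put(StreamsConfig.DEFAULT_VALUE_SERDE_CLASS_CONFIG, Serdes.String().getClass)

  // Trivial topology: read every record from "input-topic" and forward it to "output-topic".
  val builder = new StreamsBuilder()
  builder.stream[String, String]("input-topic").to("output-topic")

  val streams = new KafkaStreams(builder.build(), props)
  streams.start()
  sys.addShutdownHook(streams.close())
}

Scaling out is then a matter of starting additional instances of the same program with the same application id; the Streams library redistributes the input topic's partitions across them.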
                + + + diff --git a/docs/documentation/index.html b/docs/documentation/index.html new file mode 100644 index 0000000..1d7507f --- /dev/null +++ b/docs/documentation/index.html @@ -0,0 +1,18 @@ + + + \ No newline at end of file diff --git a/docs/documentation/streams/architecture.html b/docs/documentation/streams/architecture.html new file mode 100644 index 0000000..ad7b323 --- /dev/null +++ b/docs/documentation/streams/architecture.html @@ -0,0 +1,19 @@ + + + + diff --git a/docs/documentation/streams/core-concepts.html b/docs/documentation/streams/core-concepts.html new file mode 100644 index 0000000..d699b79 --- /dev/null +++ b/docs/documentation/streams/core-concepts.html @@ -0,0 +1,19 @@ + + + + diff --git a/docs/documentation/streams/developer-guide/app-reset-tool.html b/docs/documentation/streams/developer-guide/app-reset-tool.html new file mode 100644 index 0000000..64a43aa --- /dev/null +++ b/docs/documentation/streams/developer-guide/app-reset-tool.html @@ -0,0 +1,19 @@ + + + + diff --git a/docs/documentation/streams/developer-guide/config-streams.html b/docs/documentation/streams/developer-guide/config-streams.html new file mode 100644 index 0000000..979f66d --- /dev/null +++ b/docs/documentation/streams/developer-guide/config-streams.html @@ -0,0 +1,19 @@ + + + + diff --git a/docs/documentation/streams/developer-guide/datatypes.html b/docs/documentation/streams/developer-guide/datatypes.html new file mode 100644 index 0000000..98dd3a1 --- /dev/null +++ b/docs/documentation/streams/developer-guide/datatypes.html @@ -0,0 +1,19 @@ + + + + diff --git a/docs/documentation/streams/developer-guide/dsl-api.html b/docs/documentation/streams/developer-guide/dsl-api.html new file mode 100644 index 0000000..1bbc06d --- /dev/null +++ b/docs/documentation/streams/developer-guide/dsl-api.html @@ -0,0 +1,19 @@ + + + + diff --git a/docs/documentation/streams/developer-guide/dsl-topology-naming.html b/docs/documentation/streams/developer-guide/dsl-topology-naming.html new file mode 100644 index 0000000..9f42a04 --- /dev/null +++ b/docs/documentation/streams/developer-guide/dsl-topology-naming.html @@ -0,0 +1,19 @@ + + + + diff --git a/docs/documentation/streams/developer-guide/index.html b/docs/documentation/streams/developer-guide/index.html new file mode 100644 index 0000000..3a61247 --- /dev/null +++ b/docs/documentation/streams/developer-guide/index.html @@ -0,0 +1,19 @@ + + + + diff --git a/docs/documentation/streams/developer-guide/interactive-queries.html b/docs/documentation/streams/developer-guide/interactive-queries.html new file mode 100644 index 0000000..0506012 --- /dev/null +++ b/docs/documentation/streams/developer-guide/interactive-queries.html @@ -0,0 +1,19 @@ + + + + diff --git a/docs/documentation/streams/developer-guide/manage-topics.html b/docs/documentation/streams/developer-guide/manage-topics.html new file mode 100644 index 0000000..f422554 --- /dev/null +++ b/docs/documentation/streams/developer-guide/manage-topics.html @@ -0,0 +1,19 @@ + + + + diff --git a/docs/documentation/streams/developer-guide/memory-mgmt.html b/docs/documentation/streams/developer-guide/memory-mgmt.html new file mode 100644 index 0000000..024e137 --- /dev/null +++ b/docs/documentation/streams/developer-guide/memory-mgmt.html @@ -0,0 +1,19 @@ + + + + diff --git a/docs/documentation/streams/developer-guide/processor-api.html b/docs/documentation/streams/developer-guide/processor-api.html new file mode 100644 index 0000000..9e9ab91 --- /dev/null +++ 
b/docs/documentation/streams/developer-guide/processor-api.html @@ -0,0 +1,19 @@ + + + + diff --git a/docs/documentation/streams/developer-guide/running-app.html b/docs/documentation/streams/developer-guide/running-app.html new file mode 100644 index 0000000..05d5f0b --- /dev/null +++ b/docs/documentation/streams/developer-guide/running-app.html @@ -0,0 +1,19 @@ + + + + diff --git a/docs/documentation/streams/developer-guide/security.html b/docs/documentation/streams/developer-guide/security.html new file mode 100644 index 0000000..5d6e5f0 --- /dev/null +++ b/docs/documentation/streams/developer-guide/security.html @@ -0,0 +1,19 @@ + + + + diff --git a/docs/documentation/streams/developer-guide/testing.html b/docs/documentation/streams/developer-guide/testing.html new file mode 100644 index 0000000..4753e66 --- /dev/null +++ b/docs/documentation/streams/developer-guide/testing.html @@ -0,0 +1,19 @@ + + + + diff --git a/docs/documentation/streams/developer-guide/write-streams.html b/docs/documentation/streams/developer-guide/write-streams.html new file mode 100644 index 0000000..976c6fe --- /dev/null +++ b/docs/documentation/streams/developer-guide/write-streams.html @@ -0,0 +1,19 @@ + + + + diff --git a/docs/documentation/streams/index.html b/docs/documentation/streams/index.html new file mode 100644 index 0000000..5ff3b3b --- /dev/null +++ b/docs/documentation/streams/index.html @@ -0,0 +1,19 @@ + + + + diff --git a/docs/documentation/streams/quickstart.html b/docs/documentation/streams/quickstart.html new file mode 100644 index 0000000..efb0234 --- /dev/null +++ b/docs/documentation/streams/quickstart.html @@ -0,0 +1,19 @@ + + + + diff --git a/docs/documentation/streams/tutorial.html b/docs/documentation/streams/tutorial.html new file mode 100644 index 0000000..e2cf401 --- /dev/null +++ b/docs/documentation/streams/tutorial.html @@ -0,0 +1,19 @@ + + + + diff --git a/docs/documentation/streams/upgrade-guide.html b/docs/documentation/streams/upgrade-guide.html new file mode 100644 index 0000000..b1b3200 --- /dev/null +++ b/docs/documentation/streams/upgrade-guide.html @@ -0,0 +1,19 @@ + + + + diff --git a/docs/ecosystem.html b/docs/ecosystem.html new file mode 100644 index 0000000..5fbcec5 --- /dev/null +++ b/docs/ecosystem.html @@ -0,0 +1,18 @@ + + +There are a plethora of tools that integrate with Kafka outside the main distribution. The ecosystem page lists many of these, including stream processing systems, Hadoop integration, monitoring, and deployment tools. 
diff --git a/docs/generated/admin_client_config.html b/docs/generated/admin_client_config.html new file mode 100644 index 0000000..e69de29 diff --git a/docs/generated/connect_config.html b/docs/generated/connect_config.html new file mode 100644 index 0000000..e69de29 diff --git a/docs/generated/connect_metrics.html b/docs/generated/connect_metrics.html new file mode 100644 index 0000000..e69de29 diff --git a/docs/generated/connect_predicates.html b/docs/generated/connect_predicates.html new file mode 100644 index 0000000..e69de29 diff --git a/docs/generated/connect_transforms.html b/docs/generated/connect_transforms.html new file mode 100644 index 0000000..e69de29 diff --git a/docs/generated/consumer_config.html b/docs/generated/consumer_config.html new file mode 100644 index 0000000..e69de29 diff --git a/docs/generated/consumer_metrics.html b/docs/generated/consumer_metrics.html new file mode 100644 index 0000000..e69de29 diff --git a/docs/generated/kafka_config.html b/docs/generated/kafka_config.html new file mode 100644 index 0000000..e69de29 diff --git a/docs/generated/producer_config.html b/docs/generated/producer_config.html new file mode 100644 index 0000000..e69de29 diff --git a/docs/generated/producer_metrics.html b/docs/generated/producer_metrics.html new file mode 100644 index 0000000..e69de29 diff --git a/docs/generated/protocol_api_keys.html b/docs/generated/protocol_api_keys.html new file mode 100644 index 0000000..e69de29 diff --git a/docs/generated/protocol_errors.html b/docs/generated/protocol_errors.html new file mode 100644 index 0000000..e69de29 diff --git a/docs/generated/protocol_messages.html b/docs/generated/protocol_messages.html new file mode 100644 index 0000000..e69de29 diff --git a/docs/generated/protocol_types.html b/docs/generated/protocol_types.html new file mode 100644 index 0000000..e69de29 diff --git a/docs/generated/sink_connector_config.html b/docs/generated/sink_connector_config.html new file mode 100644 index 0000000..e69de29 diff --git a/docs/generated/source_connector_config.html b/docs/generated/source_connector_config.html new file mode 100644 index 0000000..e69de29 diff --git a/docs/generated/streams_config.html b/docs/generated/streams_config.html new file mode 100644 index 0000000..e69de29 diff --git a/docs/generated/topic_config.html b/docs/generated/topic_config.html new file mode 100644 index 0000000..e69de29 diff --git a/docs/images/consumer-groups.png b/docs/images/consumer-groups.png new file mode 100644 index 0000000..16fe293 Binary files /dev/null and b/docs/images/consumer-groups.png differ diff --git a/docs/images/icons/NYT.jpg b/docs/images/icons/NYT.jpg new file mode 100644 index 0000000..f4a7e8f Binary files /dev/null and b/docs/images/icons/NYT.jpg differ diff --git a/docs/images/icons/architecture--white.png b/docs/images/icons/architecture--white.png new file mode 100644 index 0000000..98b1b03 Binary files /dev/null and b/docs/images/icons/architecture--white.png differ diff --git a/docs/images/icons/architecture.png b/docs/images/icons/architecture.png new file mode 100644 index 0000000..6f9fd40 Binary files /dev/null and b/docs/images/icons/architecture.png differ diff --git a/docs/images/icons/documentation--white.png b/docs/images/icons/documentation--white.png new file mode 100644 index 0000000..1e8fd97 Binary files /dev/null and b/docs/images/icons/documentation--white.png differ diff --git a/docs/images/icons/documentation.png b/docs/images/icons/documentation.png new file mode 100644 index 0000000..8d9da19 Binary files 
/dev/null and b/docs/images/icons/documentation.png differ diff --git a/docs/images/icons/line.png b/docs/images/icons/line.png new file mode 100755 index 0000000..4587d21 Binary files /dev/null and b/docs/images/icons/line.png differ diff --git a/docs/images/icons/new-york.png b/docs/images/icons/new-york.png new file mode 100755 index 0000000..42a4b0b Binary files /dev/null and b/docs/images/icons/new-york.png differ diff --git a/docs/images/icons/rabobank.png b/docs/images/icons/rabobank.png new file mode 100755 index 0000000..ddad710 Binary files /dev/null and b/docs/images/icons/rabobank.png differ diff --git a/docs/images/icons/tutorials--white.png b/docs/images/icons/tutorials--white.png new file mode 100644 index 0000000..97a0c04 Binary files /dev/null and b/docs/images/icons/tutorials--white.png differ diff --git a/docs/images/icons/tutorials.png b/docs/images/icons/tutorials.png new file mode 100644 index 0000000..983da6c Binary files /dev/null and b/docs/images/icons/tutorials.png differ diff --git a/docs/images/icons/zalando.png b/docs/images/icons/zalando.png new file mode 100755 index 0000000..719a7dc Binary files /dev/null and b/docs/images/icons/zalando.png differ diff --git a/docs/images/kafka-apis.png b/docs/images/kafka-apis.png new file mode 100644 index 0000000..db6053c Binary files /dev/null and b/docs/images/kafka-apis.png differ diff --git a/docs/images/kafka_log.png b/docs/images/kafka_log.png new file mode 100644 index 0000000..75abd96 Binary files /dev/null and b/docs/images/kafka_log.png differ diff --git a/docs/images/kafka_multidc.png b/docs/images/kafka_multidc.png new file mode 100644 index 0000000..7bc56f4 Binary files /dev/null and b/docs/images/kafka_multidc.png differ diff --git a/docs/images/kafka_multidc_complex.png b/docs/images/kafka_multidc_complex.png new file mode 100644 index 0000000..ab88deb Binary files /dev/null and b/docs/images/kafka_multidc_complex.png differ diff --git a/docs/images/log_anatomy.png b/docs/images/log_anatomy.png new file mode 100644 index 0000000..a649499 Binary files /dev/null and b/docs/images/log_anatomy.png differ diff --git a/docs/images/log_cleaner_anatomy.png b/docs/images/log_cleaner_anatomy.png new file mode 100644 index 0000000..fb425b0 Binary files /dev/null and b/docs/images/log_cleaner_anatomy.png differ diff --git a/docs/images/log_compaction.png b/docs/images/log_compaction.png new file mode 100644 index 0000000..4e4a833 Binary files /dev/null and b/docs/images/log_compaction.png differ diff --git a/docs/images/log_consumer.png b/docs/images/log_consumer.png new file mode 100644 index 0000000..fbc45f2 Binary files /dev/null and b/docs/images/log_consumer.png differ diff --git a/docs/images/mirror-maker.png b/docs/images/mirror-maker.png new file mode 100644 index 0000000..8f76b1f Binary files /dev/null and b/docs/images/mirror-maker.png differ diff --git a/docs/images/producer_consumer.png b/docs/images/producer_consumer.png new file mode 100644 index 0000000..4b10cc9 Binary files /dev/null and b/docs/images/producer_consumer.png differ diff --git a/docs/images/streams-architecture-overview.jpg b/docs/images/streams-architecture-overview.jpg new file mode 100644 index 0000000..9222079 Binary files /dev/null and b/docs/images/streams-architecture-overview.jpg differ diff --git a/docs/images/streams-architecture-states.jpg b/docs/images/streams-architecture-states.jpg new file mode 100644 index 0000000..fde12db Binary files /dev/null and b/docs/images/streams-architecture-states.jpg differ diff --git 
a/docs/images/streams-architecture-tasks.jpg b/docs/images/streams-architecture-tasks.jpg new file mode 100644 index 0000000..2e957f9 Binary files /dev/null and b/docs/images/streams-architecture-tasks.jpg differ diff --git a/docs/images/streams-architecture-threads.jpg b/docs/images/streams-architecture-threads.jpg new file mode 100644 index 0000000..d5f10db Binary files /dev/null and b/docs/images/streams-architecture-threads.jpg differ diff --git a/docs/images/streams-architecture-topology.jpg b/docs/images/streams-architecture-topology.jpg new file mode 100644 index 0000000..f42e8cd Binary files /dev/null and b/docs/images/streams-architecture-topology.jpg differ diff --git a/docs/images/streams-cache-and-commit-interval.png b/docs/images/streams-cache-and-commit-interval.png new file mode 100644 index 0000000..a663bc6 Binary files /dev/null and b/docs/images/streams-cache-and-commit-interval.png differ diff --git a/docs/images/streams-concepts-topology.jpg b/docs/images/streams-concepts-topology.jpg new file mode 100644 index 0000000..832f6d4 Binary files /dev/null and b/docs/images/streams-concepts-topology.jpg differ diff --git a/docs/images/streams-elastic-scaling-1.png b/docs/images/streams-elastic-scaling-1.png new file mode 100644 index 0000000..7823ac1 Binary files /dev/null and b/docs/images/streams-elastic-scaling-1.png differ diff --git a/docs/images/streams-elastic-scaling-2.png b/docs/images/streams-elastic-scaling-2.png new file mode 100644 index 0000000..374b5ff Binary files /dev/null and b/docs/images/streams-elastic-scaling-2.png differ diff --git a/docs/images/streams-elastic-scaling-3.png b/docs/images/streams-elastic-scaling-3.png new file mode 100644 index 0000000..0b4adaf Binary files /dev/null and b/docs/images/streams-elastic-scaling-3.png differ diff --git a/docs/images/streams-interactive-queries-01.png b/docs/images/streams-interactive-queries-01.png new file mode 100644 index 0000000..d5d5031 Binary files /dev/null and b/docs/images/streams-interactive-queries-01.png differ diff --git a/docs/images/streams-interactive-queries-02.png b/docs/images/streams-interactive-queries-02.png new file mode 100644 index 0000000..ea894b6 Binary files /dev/null and b/docs/images/streams-interactive-queries-02.png differ diff --git a/docs/images/streams-interactive-queries-03.png b/docs/images/streams-interactive-queries-03.png new file mode 100644 index 0000000..403e3ae Binary files /dev/null and b/docs/images/streams-interactive-queries-03.png differ diff --git a/docs/images/streams-interactive-queries-api-01.png b/docs/images/streams-interactive-queries-api-01.png new file mode 100644 index 0000000..2b4aaed Binary files /dev/null and b/docs/images/streams-interactive-queries-api-01.png differ diff --git a/docs/images/streams-interactive-queries-api-02.png b/docs/images/streams-interactive-queries-api-02.png new file mode 100644 index 0000000..e5e7527 Binary files /dev/null and b/docs/images/streams-interactive-queries-api-02.png differ diff --git a/docs/images/streams-session-windows-01.png b/docs/images/streams-session-windows-01.png new file mode 100644 index 0000000..2d711d8 Binary files /dev/null and b/docs/images/streams-session-windows-01.png differ diff --git a/docs/images/streams-session-windows-02.png b/docs/images/streams-session-windows-02.png new file mode 100644 index 0000000..6c0382f Binary files /dev/null and b/docs/images/streams-session-windows-02.png differ diff --git a/docs/images/streams-sliding-windows.png b/docs/images/streams-sliding-windows.png 
new file mode 100644 index 0000000..fa6d5c3 Binary files /dev/null and b/docs/images/streams-sliding-windows.png differ diff --git a/docs/images/streams-stateful_operations.png b/docs/images/streams-stateful_operations.png new file mode 100644 index 0000000..b0fe3de Binary files /dev/null and b/docs/images/streams-stateful_operations.png differ diff --git a/docs/images/streams-table-duality-01.png b/docs/images/streams-table-duality-01.png new file mode 100644 index 0000000..4fa4d1b Binary files /dev/null and b/docs/images/streams-table-duality-01.png differ diff --git a/docs/images/streams-table-duality-02.png b/docs/images/streams-table-duality-02.png new file mode 100644 index 0000000..4e805c1 Binary files /dev/null and b/docs/images/streams-table-duality-02.png differ diff --git a/docs/images/streams-table-duality-03.png b/docs/images/streams-table-duality-03.png new file mode 100644 index 0000000..b0b04f5 Binary files /dev/null and b/docs/images/streams-table-duality-03.png differ diff --git a/docs/images/streams-table-updates-01.png b/docs/images/streams-table-updates-01.png new file mode 100644 index 0000000..3a2c35e Binary files /dev/null and b/docs/images/streams-table-updates-01.png differ diff --git a/docs/images/streams-table-updates-02.png b/docs/images/streams-table-updates-02.png new file mode 100644 index 0000000..a0a5b1f Binary files /dev/null and b/docs/images/streams-table-updates-02.png differ diff --git a/docs/images/streams-time-windows-hopping.png b/docs/images/streams-time-windows-hopping.png new file mode 100644 index 0000000..5fcb9d2 Binary files /dev/null and b/docs/images/streams-time-windows-hopping.png differ diff --git a/docs/images/streams-time-windows-tumbling.png b/docs/images/streams-time-windows-tumbling.png new file mode 100644 index 0000000..571ab79 Binary files /dev/null and b/docs/images/streams-time-windows-tumbling.png differ diff --git a/docs/images/streams-welcome.png b/docs/images/streams-welcome.png new file mode 100644 index 0000000..63918c4 Binary files /dev/null and b/docs/images/streams-welcome.png differ diff --git a/docs/images/tracking_high_level.png b/docs/images/tracking_high_level.png new file mode 100644 index 0000000..b643230 Binary files /dev/null and b/docs/images/tracking_high_level.png differ diff --git a/docs/implementation.html b/docs/implementation.html new file mode 100644 index 0000000..5d75ccb --- /dev/null +++ b/docs/implementation.html @@ -0,0 +1,293 @@ + + + + +
                diff --git a/docs/introduction.html b/docs/introduction.html new file mode 100644 index 0000000..49de2fa --- /dev/null +++ b/docs/introduction.html @@ -0,0 +1,220 @@ + + + + + + +
                diff --git a/docs/js/templateData.js b/docs/js/templateData.js new file mode 100644 index 0000000..45e1eb2 --- /dev/null +++ b/docs/js/templateData.js @@ -0,0 +1,24 @@ +/* +Licensed to the Apache Software Foundation (ASF) under one or more +contributor license agreements. See the NOTICE file distributed with +this work for additional information regarding copyright ownership. +The ASF licenses this file to You under the Apache License, Version 2.0 +(the "License"); you may not use this file except in compliance with +the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +// Define variables for doc templates +var context={ + "version": "31", + "dotVersion": "3.1", + "fullDotVersion": "3.1.0", + "scalaVersion": "2.13" +}; diff --git a/docs/migration.html b/docs/migration.html new file mode 100644 index 0000000..95fc87f --- /dev/null +++ b/docs/migration.html @@ -0,0 +1,34 @@ + + + +

                Migrating from 0.7.x to 0.8

                + +0.8 is our first (and hopefully last) release with a non-backwards-compatible wire protocol, ZooKeeper layout, and on-disk data format. This was a chance for us to clean up a lot of cruft and start fresh. This means performing a no-downtime upgrade is more painful than normal—you cannot just swap in the new code in-place. + +

                Migration Steps

                + +
                  +
1. Set up a new cluster running 0.8. +
                2. Use the 0.7 to 0.8 migration tool to mirror data from the 0.7 cluster into the 0.8 cluster. +
                3. When the 0.8 cluster is fully caught up, redeploy all data consumers running the 0.8 client and reading from the 0.8 cluster. +
4. Finally, migrate all 0.7 producers to the 0.8 client, publishing data to the 0.8 cluster. +
                5. Decommission the 0.7 cluster. +
                6. Drink. +
                + + diff --git a/docs/ops.html b/docs/ops.html new file mode 100644 index 0000000..b5d3bd5 --- /dev/null +++ b/docs/ops.html @@ -0,0 +1,3120 @@ + + + Here is some information on actually running Kafka as a production system based on usage and experience at LinkedIn. Please send us any additional tips you know of. + +

                6.1 Basic Kafka Operations

+ + This section will review the most common operations you will perform on your Kafka cluster. All of the tools reviewed in this section are available under the bin/ directory of the Kafka distribution, and each tool will print details on all possible command-line options if it is run with no arguments. + +

                Adding and removing topics

                + + You have the option of either adding topics manually or having them be created automatically when data is first published to a non-existent topic. If topics are auto-created then you may want to tune the default topic configurations used for auto-created topics. +

                + Topics are added and modified using the topic tool: +

                  > bin/kafka-topics.sh --bootstrap-server broker_host:port --create --topic my_topic_name \
                +        --partitions 20 --replication-factor 3 --config x=y
                + The replication factor controls how many servers will replicate each message that is written. If you have a replication factor of 3 then up to 2 servers can fail before you will lose access to your data. We recommend you use a replication factor of 2 or 3 so that you can transparently bounce machines without interrupting data consumption. +

+ The partition count controls how many logs the topic will be sharded into. There are several impacts of the partition count. First, each partition must fit entirely on a single server. So if you have 20 partitions the full data set (and read and write load) will be handled by no more than 20 servers (not counting replicas). Finally, the partition count impacts the maximum parallelism of your consumers. This is discussed in greater detail in the concepts section. +

+ Each sharded partition log is placed into its own folder under the Kafka log directory. The name of such a folder consists of the topic name followed by a dash (-) and the partition id. Since a typical folder name cannot be over 255 characters long, there will be a limitation on the length of topic names. We assume the number of partitions will not ever be above 100,000. Therefore, topic names cannot be longer than 249 characters. This leaves just enough room in the folder name for a dash and a potentially 5 digit long partition id. +
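As an illustration, assuming the default log.dirs of /tmp/kafka-logs and that this broker hosts partitions 0, 1 and 2 of the topic created above, the log directory would contain folders like these:
  > ls /tmp/kafka-logs
  my_topic_name-0  my_topic_name-1  my_topic_name-2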

                + The configurations added on the command line override the default settings the server has for things like the length of time data should be retained. The complete set of per-topic configurations is documented here. + +
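For instance, the placeholder --config x=y above could be a real per-topic setting such as retention.ms; the value below (roughly seven days in milliseconds) is only illustrative:
  > bin/kafka-topics.sh --bootstrap-server broker_host:port --create --topic my_topic_name \
         --partitions 20 --replication-factor 3 --config retention.ms=604800000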

                Modifying topics

                + + You can change the configuration or partitioning of a topic using the same topic tool. +

                + To add partitions you can do +

                  > bin/kafka-topics.sh --bootstrap-server broker_host:port --alter --topic my_topic_name \
                +        --partitions 40
+ Be aware that one use case for partitions is to semantically partition data, and adding partitions doesn't change the partitioning of existing data, so this may disturb consumers if they rely on that partition. That is, if data is partitioned by hash(key) % number_of_partitions, then this partitioning will potentially be shuffled by adding partitions, but Kafka will not attempt to automatically redistribute data in any way. +
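As a concrete (hypothetical) illustration: a key whose hash is 10 maps to partition 10 % 4 = 2 while the topic has 4 partitions, but to partition 10 % 5 = 0 once a fifth partition is added, so new records for that key start landing on a different partition even though the previously written records stay where they are.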

                + To add configs: +

                  > bin/kafka-configs.sh --bootstrap-server broker_host:port --entity-type topics --entity-name my_topic_name --alter --add-config x=y
                + To remove a config: +
                  > bin/kafka-configs.sh --bootstrap-server broker_host:port --entity-type topics --entity-name my_topic_name --alter --delete-config x
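As a concrete instance of the add/remove pattern above, using a real per-topic setting (max.message.bytes; the value is only illustrative):
  > bin/kafka-configs.sh --bootstrap-server broker_host:port --entity-type topics --entity-name my_topic_name --alter --add-config max.message.bytes=2097152
  > bin/kafka-configs.sh --bootstrap-server broker_host:port --entity-type topics --entity-name my_topic_name --alter --delete-config max.message.bytes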
                + And finally deleting a topic: +
                  > bin/kafka-topics.sh --bootstrap-server broker_host:port --delete --topic my_topic_name
                +

                + Kafka does not currently support reducing the number of partitions for a topic. +

                + Instructions for changing the replication factor of a topic can be found here. + +

                Graceful shutdown

                + + The Kafka cluster will automatically detect any broker shutdown or failure and elect new leaders for the partitions on that machine. This will occur whether a server fails or it is brought down intentionally for maintenance or configuration changes. For the latter cases Kafka supports a more graceful mechanism for stopping a server than just killing it. + + When a server is stopped gracefully it has two optimizations it will take advantage of: +
                  +
                1. It will sync all its logs to disk to avoid needing to do any log recovery when it restarts (i.e. validating the checksum for all messages in the tail of the log). Log recovery takes time so this speeds up intentional restarts. +
                2. It will migrate any partitions the server is the leader for to other replicas prior to shutting down. This will make the leadership transfer faster and minimize the time each partition is unavailable to a few milliseconds. +
                + + Syncing the logs will happen automatically whenever the server is stopped other than by a hard kill, but the controlled leadership migration requires using a special setting: +
                      controlled.shutdown.enable=true
                + Note that controlled shutdown will only succeed if all the partitions hosted on the broker have replicas (i.e. the replication factor is greater than 1 and at least one of these replicas is alive). This is generally what you want since shutting down the last replica would make that topic partition unavailable. + +
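A minimal sketch of a rolling restart of a single broker, assuming the standard scripts shipped in bin/ and a server.properties (path illustrative) that already sets controlled.shutdown.enable=true:
  > bin/kafka-server-stop.sh                                       # sends SIGTERM, triggering controlled shutdown
  > bin/kafka-server-start.sh -daemon config/server.properties     # restart the broker in the background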

                Balancing leadership

                + + Whenever a broker stops or crashes, leadership for that broker's partitions transfers to other replicas. When the broker is restarted it will only be a follower for all its partitions, meaning it will not be used for client reads and writes. +

                + To avoid this imbalance, Kafka has a notion of preferred replicas. If the list of replicas for a partition is 1,5,9 then node 1 is preferred as the leader to either node 5 or 9 because it is earlier in the replica list. By default the Kafka cluster will try to restore leadership to the preferred replicas. This behaviour is configured with: + +

                      auto.leader.rebalance.enable=true
                + You can also set this to false, but you will then need to manually restore leadership to the restored replicas by running the command: +
  > bin/kafka-leader-election.sh --bootstrap-server broker_host:port --election-type preferred --all-topic-partitions
                + +

                Balancing Replicas Across Racks

                + The rack awareness feature spreads replicas of the same partition across different racks. This extends the guarantees Kafka provides for broker-failure to cover rack-failure, limiting the risk of data loss should all the brokers on a rack fail at once. The feature can also be applied to other broker groupings such as availability zones in EC2. +

                + You can specify that a broker belongs to a particular rack by adding a property to the broker config: +
                  broker.rack=my-rack-id
                + When a topic is created, modified or replicas are redistributed, the rack constraint will be honoured, ensuring replicas span as many racks as they can (a partition will span min(#racks, replication-factor) different racks). +

                + The algorithm used to assign replicas to brokers ensures that the number of leaders per broker will be constant, regardless of how brokers are distributed across racks. This ensures balanced throughput. +

                + However if racks are assigned different numbers of brokers, the assignment of replicas will not be even. Racks with fewer brokers will get more replicas, meaning they will use more storage and put more resources into replication. Hence it is sensible to configure an equal number of brokers per rack. + +
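As a sketch, assuming three brokers spread across three hypothetical racks, the per-broker configuration and a rack-aware topic creation might look as follows; with a replication factor of 3, each partition then spans min(3 racks, 3 replicas) = 3 racks:
      broker.rack=rack-a   (broker 1)
      broker.rack=rack-b   (broker 2)
      broker.rack=rack-c   (broker 3)
  > bin/kafka-topics.sh --bootstrap-server broker_host:port --create --topic rack-aware-topic --partitions 6 --replication-factor 3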

                Mirroring data between clusters & Geo-replication

                + +

                + Kafka administrators can define data flows that cross the boundaries of individual Kafka clusters, data centers, or geographical regions. Please refer to the section on Geo-Replication for further information. +

                + +

                Checking consumer position

+ Sometimes it's useful to see the position of your consumers. We have a tool that will show the position of all consumers in a consumer group as well as how far behind the end of the log they are. Running this tool on a consumer group named my-group consuming a topic named my-topic would look like this: +
                  > bin/kafka-consumer-groups.sh --bootstrap-server localhost:9092 --describe --group my-group
                +
                +  TOPIC                          PARTITION  CURRENT-OFFSET  LOG-END-OFFSET  LAG        CONSUMER-ID                                       HOST                           CLIENT-ID
                +  my-topic                       0          2               4               2          consumer-1-029af89c-873c-4751-a720-cefd41a669d6   /127.0.0.1                     consumer-1
                +  my-topic                       1          2               3               1          consumer-1-029af89c-873c-4751-a720-cefd41a669d6   /127.0.0.1                     consumer-1
                +  my-topic                       2          2               3               1          consumer-2-42c1abd4-e3b2-425d-a8bb-e1ea49b29bb2   /127.0.0.1                     consumer-2
                + +

                Managing Consumer Groups

                + + With the ConsumerGroupCommand tool, we can list, describe, or delete the consumer groups. The consumer group can be deleted manually, or automatically when the last committed offset for that group expires. Manual deletion works only if the group does not have any active members. + + For example, to list all consumer groups across all topics: + +
                  > bin/kafka-consumer-groups.sh --bootstrap-server localhost:9092 --list
                +
                +  test-consumer-group
                + + To view offsets, as mentioned earlier, we "describe" the consumer group like this: + +
                  > bin/kafka-consumer-groups.sh --bootstrap-server localhost:9092 --describe --group my-group
                +
                +  TOPIC           PARTITION  CURRENT-OFFSET  LOG-END-OFFSET  LAG             CONSUMER-ID                                    HOST            CLIENT-ID
                +  topic3          0          241019          395308          154289          consumer2-e76ea8c3-5d30-4299-9005-47eb41f3d3c4 /127.0.0.1      consumer2
                +  topic2          1          520678          803288          282610          consumer2-e76ea8c3-5d30-4299-9005-47eb41f3d3c4 /127.0.0.1      consumer2
                +  topic3          1          241018          398817          157799          consumer2-e76ea8c3-5d30-4299-9005-47eb41f3d3c4 /127.0.0.1      consumer2
                +  topic1          0          854144          855809          1665            consumer1-3fc8d6f1-581a-4472-bdf3-3515b4aee8c1 /127.0.0.1      consumer1
                +  topic2          0          460537          803290          342753          consumer1-3fc8d6f1-581a-4472-bdf3-3515b4aee8c1 /127.0.0.1      consumer1
                +  topic3          2          243655          398812          155157          consumer4-117fe4d3-c6c1-4178-8ee9-eb4a3954bee0 /127.0.0.1      consumer4
                + + There are a number of additional "describe" options that can be used to provide more detailed information about a consumer group: +
                  +
                • --members: This option provides the list of all active members in the consumer group. +
                        > bin/kafka-consumer-groups.sh --bootstrap-server localhost:9092 --describe --group my-group --members
                  +
                  +      CONSUMER-ID                                    HOST            CLIENT-ID       #PARTITIONS
                  +      consumer1-3fc8d6f1-581a-4472-bdf3-3515b4aee8c1 /127.0.0.1      consumer1       2
                  +      consumer4-117fe4d3-c6c1-4178-8ee9-eb4a3954bee0 /127.0.0.1      consumer4       1
                  +      consumer2-e76ea8c3-5d30-4299-9005-47eb41f3d3c4 /127.0.0.1      consumer2       3
                  +      consumer3-ecea43e4-1f01-479f-8349-f9130b75d8ee /127.0.0.1      consumer3       0
                  +
                • +
• --members --verbose: On top of the information reported by the "--members" option above, this option also provides the partitions assigned to each member. +
                        > bin/kafka-consumer-groups.sh --bootstrap-server localhost:9092 --describe --group my-group --members --verbose
                  +
                  +      CONSUMER-ID                                    HOST            CLIENT-ID       #PARTITIONS     ASSIGNMENT
                  +      consumer1-3fc8d6f1-581a-4472-bdf3-3515b4aee8c1 /127.0.0.1      consumer1       2               topic1(0), topic2(0)
                  +      consumer4-117fe4d3-c6c1-4178-8ee9-eb4a3954bee0 /127.0.0.1      consumer4       1               topic3(2)
                  +      consumer2-e76ea8c3-5d30-4299-9005-47eb41f3d3c4 /127.0.0.1      consumer2       3               topic2(1), topic3(0,1)
                  +      consumer3-ecea43e4-1f01-479f-8349-f9130b75d8ee /127.0.0.1      consumer3       0               -
                  +
                • +
                • --offsets: This is the default describe option and provides the same output as the "--describe" option.
                • +
                • --state: This option provides useful group-level information. +
                        > bin/kafka-consumer-groups.sh --bootstrap-server localhost:9092 --describe --group my-group --state
                  +
                  +      COORDINATOR (ID)          ASSIGNMENT-STRATEGY       STATE                #MEMBERS
                  +      localhost:9092 (0)        range                     Stable               4
                  +
                • +
                + + To manually delete one or multiple consumer groups, the "--delete" option can be used: +
                  > bin/kafka-consumer-groups.sh --bootstrap-server localhost:9092 --delete --group my-group --group my-other-group
                +
                +  Deletion of requested consumer groups ('my-group', 'my-other-group') was successful.
                + +
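A related sketch: newer releases also allow deleting only the offsets a group has committed for a particular topic (via --delete-offsets), rather than deleting the whole group; as with group deletion, the group must not be actively consuming that topic:
  > bin/kafka-consumer-groups.sh --bootstrap-server localhost:9092 --delete-offsets --group my-group --topic my-topic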

+ To reset the offsets of a consumer group, the "--reset-offsets" option can be used. + This option supports one consumer group at a time. It requires defining one of the following scopes: --all-topics or --topic. One scope must be selected, unless you use the '--from-file' scenario. Also, first make sure that the consumer instances are inactive. + See KIP-122 for more details. +

                + It has 3 execution options: +

                  +
                • + (default) to display which offsets to reset. +
                • +
                • + --execute : to execute --reset-offsets process. +
                • +
                • + --export : to export the results to a CSV format. +
                • +
                + +

+ --reset-offsets also has the following scenarios to choose from (at least one scenario must be selected): +

                  +
                • + --to-datetime <String: datetime> : Reset offsets to offsets from datetime. Format: 'YYYY-MM-DDTHH:mm:SS.sss' +
                • +
                • + --to-earliest : Reset offsets to earliest offset. +
                • +
                • + --to-latest : Reset offsets to latest offset. +
                • +
                • + --shift-by <Long: number-of-offsets> : Reset offsets shifting current offset by 'n', where 'n' can be positive or negative. +
                • +
                • + --from-file : Reset offsets to values defined in CSV file. +
                • +
                • + --to-current : Resets offsets to current offset. +
                • +
                • + --by-duration <String: duration> : Reset offsets to offset by duration from current timestamp. Format: 'PnDTnHnMnS' +
                • +
                • + --to-offset : Reset offsets to a specific offset. +
                • +
+ + Please note that out-of-range offsets will be adjusted to the available offset end. For example, if the offset end is at 10 and an offset shift of 15 is requested, then offset 10 will actually be selected. + +

                + For example, to reset offsets of a consumer group to the latest offset: + +

                  > bin/kafka-consumer-groups.sh --bootstrap-server localhost:9092 --reset-offsets --group consumergroup1 --topic topic1 --to-latest
                +
                +  TOPIC                          PARTITION  NEW-OFFSET
                +  topic1                         0          0
                + +
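Another sketch, with placeholder group and topic names: to rewind the group by 100 offsets and actually apply the change (without --execute the tool only prints the planned new offsets):
  > bin/kafka-consumer-groups.sh --bootstrap-server localhost:9092 --reset-offsets --group consumergroup1 --topic topic1 --shift-by -100 --execute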

                + + If you are using the old high-level consumer and storing the group metadata in ZooKeeper (i.e. offsets.storage=zookeeper), pass + --zookeeper instead of --bootstrap-server: + +

                  > bin/kafka-consumer-groups.sh --zookeeper localhost:2181 --list
                + +

                Expanding your cluster

+ + Adding servers to a Kafka cluster is easy: just assign them a unique broker id and start up Kafka on your new servers. However, these new servers will not automatically be assigned any data partitions, so unless partitions are moved to them they won't be doing any work until new topics are created. So usually when you add machines to your cluster you will want to migrate some existing data to these machines. +

+ The process of migrating data is manually initiated but fully automated. Under the covers what happens is that Kafka will add the new server as a follower of the partition it is migrating and allow it to fully replicate the existing data in that partition. When the new server has fully replicated the contents of this partition and joined the in-sync replica set, one of the existing replicas will delete its copy of the partition's data. +

                + The partition reassignment tool can be used to move partitions across brokers. An ideal partition distribution would ensure even data load and partition sizes across all brokers. The partition reassignment tool does not have the capability to automatically study the data distribution in a Kafka cluster and move partitions around to attain an even load distribution. As such, the admin has to figure out which topics or partitions should be moved around. +

                + The partition reassignment tool can run in 3 mutually exclusive modes: +

                  +
                • --generate: In this mode, given a list of topics and a list of brokers, the tool generates a candidate reassignment to move all partitions of the specified topics to the new brokers. This option merely provides a convenient way to generate a partition reassignment plan given a list of topics and target brokers.
                • +
• --execute: In this mode, the tool kicks off the reassignment of partitions based on the user-provided reassignment plan (using the --reassignment-json-file option). This can either be a custom reassignment plan hand-crafted by the admin or one generated by using the --generate option.
                • +
• --verify: In this mode, the tool verifies the status of the reassignment for all partitions listed during the last --execute. The status can be successfully completed, failed, or in progress.
                • +
                +
                Automatically migrating data to new machines
                + The partition reassignment tool can be used to move some topics off of the current set of brokers to the newly added brokers. This is typically useful while expanding an existing cluster since it is easier to move entire topics to the new set of brokers, than moving one partition at a time. When used to do this, the user should provide a list of topics that should be moved to the new set of brokers and a target list of new brokers. The tool then evenly distributes all partitions for the given list of topics across the new set of brokers. During this move, the replication factor of the topic is kept constant. Effectively the replicas for all partitions for the input list of topics are moved from the old set of brokers to the newly added brokers. +

                + For instance, the following example will move all partitions for topics foo1,foo2 to the new set of brokers 5,6. At the end of this move, all partitions for topics foo1 and foo2 will only exist on brokers 5,6. +

                + Since the tool accepts the input list of topics as a json file, you first need to identify the topics you want to move and create the json file as follows: +

                  > cat topics-to-move.json
                +  {"topics": [{"topic": "foo1"},
                +              {"topic": "foo2"}],
                +  "version":1
                +  }
                + Once the json file is ready, use the partition reassignment tool to generate a candidate assignment: +
                  > bin/kafka-reassign-partitions.sh --bootstrap-server localhost:9092 --topics-to-move-json-file topics-to-move.json --broker-list "5,6" --generate
                +  Current partition replica assignment
                +
                +  {"version":1,
                +  "partitions":[{"topic":"foo1","partition":2,"replicas":[1,2]},
                +                {"topic":"foo1","partition":0,"replicas":[3,4]},
                +                {"topic":"foo2","partition":2,"replicas":[1,2]},
                +                {"topic":"foo2","partition":0,"replicas":[3,4]},
                +                {"topic":"foo1","partition":1,"replicas":[2,3]},
                +                {"topic":"foo2","partition":1,"replicas":[2,3]}]
                +  }
                +
                +  Proposed partition reassignment configuration
                +
                +  {"version":1,
                +  "partitions":[{"topic":"foo1","partition":2,"replicas":[5,6]},
                +                {"topic":"foo1","partition":0,"replicas":[5,6]},
                +                {"topic":"foo2","partition":2,"replicas":[5,6]},
                +                {"topic":"foo2","partition":0,"replicas":[5,6]},
                +                {"topic":"foo1","partition":1,"replicas":[5,6]},
                +                {"topic":"foo2","partition":1,"replicas":[5,6]}]
                +  }
                +

+ The tool generates a candidate assignment that will move all partitions from topics foo1,foo2 to brokers 5,6. Note, however, that at this point the partition movement has not started; the tool merely tells you the current assignment and the proposed new assignment. The current assignment should be saved in case you want to roll back to it. The new assignment should be saved in a json file (e.g. expand-cluster-reassignment.json) to be input to the tool with the --execute option as follows: +

                  > bin/kafka-reassign-partitions.sh --bootstrap-server localhost:9092 --reassignment-json-file expand-cluster-reassignment.json --execute
                +  Current partition replica assignment
                +
                +  {"version":1,
                +  "partitions":[{"topic":"foo1","partition":2,"replicas":[1,2]},
                +                {"topic":"foo1","partition":0,"replicas":[3,4]},
                +                {"topic":"foo2","partition":2,"replicas":[1,2]},
                +                {"topic":"foo2","partition":0,"replicas":[3,4]},
                +                {"topic":"foo1","partition":1,"replicas":[2,3]},
                +                {"topic":"foo2","partition":1,"replicas":[2,3]}]
                +  }
                +
                +  Save this to use as the --reassignment-json-file option during rollback
                +  Successfully started reassignment of partitions
                +  {"version":1,
                +  "partitions":[{"topic":"foo1","partition":2,"replicas":[5,6]},
                +                {"topic":"foo1","partition":0,"replicas":[5,6]},
                +                {"topic":"foo2","partition":2,"replicas":[5,6]},
                +                {"topic":"foo2","partition":0,"replicas":[5,6]},
                +                {"topic":"foo1","partition":1,"replicas":[5,6]},
                +                {"topic":"foo2","partition":1,"replicas":[5,6]}]
                +  }
                +

                + Finally, the --verify option can be used with the tool to check the status of the partition reassignment. Note that the same expand-cluster-reassignment.json (used with the --execute option) should be used with the --verify option: +

                  > bin/kafka-reassign-partitions.sh --bootstrap-server localhost:9092 --reassignment-json-file expand-cluster-reassignment.json --verify
                +  Status of partition reassignment:
                +  Reassignment of partition [foo1,0] completed successfully
                +  Reassignment of partition [foo1,1] is in progress
                +  Reassignment of partition [foo1,2] is in progress
                +  Reassignment of partition [foo2,0] completed successfully
                +  Reassignment of partition [foo2,1] completed successfully
                +  Reassignment of partition [foo2,2] completed successfully
                + +
                Custom partition assignment and migration
+ The partition reassignment tool can also be used to selectively move replicas of a partition to a specific set of brokers. When used in this manner, it is assumed that the user knows the reassignment plan and does not require the tool to generate a candidate reassignment, effectively skipping the --generate step and moving straight to the --execute step. +

                + For instance, the following example moves partition 0 of topic foo1 to brokers 5,6 and partition 1 of topic foo2 to brokers 2,3: +

                + The first step is to hand craft the custom reassignment plan in a json file: +

                  > cat custom-reassignment.json
                +  {"version":1,"partitions":[{"topic":"foo1","partition":0,"replicas":[5,6]},{"topic":"foo2","partition":1,"replicas":[2,3]}]}
                + Then, use the json file with the --execute option to start the reassignment process: +
                  > bin/kafka-reassign-partitions.sh --bootstrap-server localhost:9092 --reassignment-json-file custom-reassignment.json --execute
                +  Current partition replica assignment
                +
                +  {"version":1,
                +  "partitions":[{"topic":"foo1","partition":0,"replicas":[1,2]},
                +                {"topic":"foo2","partition":1,"replicas":[3,4]}]
                +  }
                +
                +  Save this to use as the --reassignment-json-file option during rollback
                +  Successfully started reassignment of partitions
                +  {"version":1,
                +  "partitions":[{"topic":"foo1","partition":0,"replicas":[5,6]},
                +                {"topic":"foo2","partition":1,"replicas":[2,3]}]
                +  }
                +

                + The --verify option can be used with the tool to check the status of the partition reassignment. Note that the same custom-reassignment.json (used with the --execute option) should be used with the --verify option: +

                  > bin/kafka-reassign-partitions.sh --bootstrap-server localhost:9092 --reassignment-json-file custom-reassignment.json --verify
                +  Status of partition reassignment:
                +  Reassignment of partition [foo1,0] completed successfully
                +  Reassignment of partition [foo2,1] completed successfully
                + +

                Decommissioning brokers

                + The partition reassignment tool does not have the ability to automatically generate a reassignment plan for decommissioning brokers yet. As such, the admin has to come up with a reassignment plan to move the replica for all partitions hosted on the broker to be decommissioned, to the rest of the brokers. This can be relatively tedious as the reassignment needs to ensure that all the replicas are not moved from the decommissioned broker to only one other broker. To make this process effortless, we plan to add tooling support for decommissioning brokers in the future. + +
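In the meantime, one workable approach (a sketch, not an official procedure) is to reuse the --generate mode shown above: list the topics hosted on the departing broker in the topics JSON and pass a --broker-list containing only the brokers that will remain. Here broker 2 is the hypothetical broker being decommissioned from a cluster of brokers 1,2,3:
  > bin/kafka-reassign-partitions.sh --bootstrap-server localhost:9092 --topics-to-move-json-file topics-to-move.json --broker-list "1,3" --generate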

                Increasing replication factor

                + Increasing the replication factor of an existing partition is easy. Just specify the extra replicas in the custom reassignment json file and use it with the --execute option to increase the replication factor of the specified partitions. +

                + For instance, the following example increases the replication factor of partition 0 of topic foo from 1 to 3. Before increasing the replication factor, the partition's only replica existed on broker 5. As part of increasing the replication factor, we will add more replicas on brokers 6 and 7. +

                + The first step is to hand craft the custom reassignment plan in a json file: +

                  > cat increase-replication-factor.json
                +  {"version":1,
                +  "partitions":[{"topic":"foo","partition":0,"replicas":[5,6,7]}]}
                + Then, use the json file with the --execute option to start the reassignment process: +
                  > bin/kafka-reassign-partitions.sh --bootstrap-server localhost:9092 --reassignment-json-file increase-replication-factor.json --execute
                +  Current partition replica assignment
                +
                +  {"version":1,
                +  "partitions":[{"topic":"foo","partition":0,"replicas":[5]}]}
                +
                +  Save this to use as the --reassignment-json-file option during rollback
                +  Successfully started reassignment of partitions
                +  {"version":1,
                +  "partitions":[{"topic":"foo","partition":0,"replicas":[5,6,7]}]}
                +

                + The --verify option can be used with the tool to check the status of the partition reassignment. Note that the same increase-replication-factor.json (used with the --execute option) should be used with the --verify option: +

                  > bin/kafka-reassign-partitions.sh --bootstrap-server localhost:9092 --reassignment-json-file increase-replication-factor.json --verify
                +  Status of partition reassignment:
                +  Reassignment of partition [foo,0] completed successfully
                + You can also verify the increase in replication factor with the kafka-topics tool: +
                  > bin/kafka-topics.sh --bootstrap-server localhost:9092 --topic foo --describe
                +  Topic:foo	PartitionCount:1	ReplicationFactor:3	Configs:
                +    Topic: foo	Partition: 0	Leader: 5	Replicas: 5,6,7	Isr: 5,6,7
                + +

                Limiting Bandwidth Usage during Data Migration

                + Kafka lets you apply a throttle to replication traffic, setting an upper bound on the bandwidth used to move replicas from machine to machine. This is useful when rebalancing a cluster, bootstrapping a new broker or adding or removing brokers, as it limits the impact these data-intensive operations will have on users. +

+ There are two interfaces that can be used to engage a throttle. The simplest, and safest, is to apply a throttle when invoking kafka-reassign-partitions.sh, but kafka-configs.sh can also be used to view and alter the throttle values directly. +

+ So, for example, if you were to execute a rebalance with the below command, it would move partitions at no more than 50 MB/s. +
                $ bin/kafka-reassign-partitions.sh --bootstrap-server localhost:9092 --execute --reassignment-json-file bigger-cluster.json --throttle 50000000
                + When you execute this script you will see the throttle engage: +
                  The throttle limit was set to 50000000 B/s
                +  Successfully started reassignment of partitions.
                +

Should you wish to alter the throttle during a rebalance, say to increase the throughput so that it completes more quickly, you can do so by re-running the execute command, passing the same reassignment-json-file:

                +
                $ bin/kafka-reassign-partitions.sh --bootstrap-server localhost:9092  --execute --reassignment-json-file bigger-cluster.json --throttle 700000000
                +  There is an existing assignment running.
                +  The throttle limit was set to 700000000 B/s
                + +

Once the rebalance completes, the administrator can check the status of the rebalance using the --verify option. If the rebalance has completed, the --verify command also removes the throttle. It is important that administrators remove the throttle in a timely manner once rebalancing completes by running the command with the --verify option. Failure to do so could cause regular replication traffic to be throttled.

                +

                When the --verify option is executed, and the reassignment has completed, the script will confirm that the throttle was removed:

                + +
                  > bin/kafka-reassign-partitions.sh --bootstrap-server localhost:9092  --verify --reassignment-json-file bigger-cluster.json
                +  Status of partition reassignment:
                +  Reassignment of partition [my-topic,1] completed successfully
                +  Reassignment of partition [mytopic,0] completed successfully
                +  Throttle was removed.
                + +

The administrator can also validate the assigned configs using kafka-configs.sh. There are two pairs of throttle configurations used to manage the throttling process. The first pair refers to the throttle value itself. This is configured, at a broker level, using the dynamic properties:

                + +
                    leader.replication.throttled.rate
                +    follower.replication.throttled.rate
                + +

                Then there is the configuration pair of enumerated sets of throttled replicas:

                + +
                    leader.replication.throttled.replicas
                +    follower.replication.throttled.replicas
                + +

These are configured per topic.

                + +

                All four config values are automatically assigned by kafka-reassign-partitions.sh (discussed below).

                + +

                To view the throttle limit configuration:

                + +
                  > bin/kafka-configs.sh --describe --bootstrap-server localhost:9092 --entity-type brokers
                +  Configs for brokers '2' are leader.replication.throttled.rate=700000000,follower.replication.throttled.rate=700000000
                +  Configs for brokers '1' are leader.replication.throttled.rate=700000000,follower.replication.throttled.rate=700000000
                + +

This shows the throttle applied to both the leader and follower sides of the replication protocol. By default, both sides are assigned the same throttled throughput value.

                + +

                To view the list of throttled replicas:

                + +
                  > bin/kafka-configs.sh --describe --bootstrap-server localhost:9092 --entity-type topics
                +  Configs for topic 'my-topic' are leader.replication.throttled.replicas=1:102,0:101,
                +      follower.replication.throttled.replicas=1:101,0:102
                + +

                Here we see the leader throttle is applied to partition 1 on broker 102 and partition 0 on broker 101. Likewise the + follower throttle is applied to partition 1 on + broker 101 and partition 0 on broker 102.

                + +

By default, kafka-reassign-partitions.sh will apply the leader throttle to all replicas that exist before the rebalance, any one of which might be the leader. It will apply the follower throttle to all move destinations. So if there is a partition with replicas on brokers 101,102, being reassigned to 102,103, the leader throttle for that partition would be applied to 101,102 and a follower throttle would be applied to 103 only.
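For the example above, taking the partition id to be 0 (replicas on brokers 101,102 being reassigned to 102,103), the per-topic throttle configuration written by the tool would therefore look roughly like this:
      leader.replication.throttled.replicas=0:101,0:102
      follower.replication.throttled.replicas=0:103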

                + + +

                If required, you can also use the --alter switch on kafka-configs.sh to alter the throttle configurations manually. +
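For instance, a sketch of setting the broker-level throttle rates by hand (the broker id and rate are illustrative):
  > bin/kafka-configs.sh --bootstrap-server localhost:9092 --alter --entity-type brokers --entity-name 101 --add-config leader.replication.throttled.rate=50000000,follower.replication.throttled.rate=50000000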

                + +
                Safe usage of throttled replication
                + +

                Some care should be taken when using throttled replication. In particular:

                + +

                (1) Throttle Removal:

                + The throttle should be removed in a timely manner once reassignment completes (by running kafka-reassign-partitions.sh + --verify). + +

                (2) Ensuring Progress:

                +

                If the throttle is set too low, in comparison to the incoming write rate, it is possible for replication to not + make progress. This occurs when:

                +
                max(BytesInPerSec) > throttle
                +

                + Where BytesInPerSec is the metric that monitors the write throughput of producers into each broker.

                +

                The administrator can monitor whether replication is making progress, during the rebalance, using the metric:

                + +
                kafka.server:type=FetcherLagMetrics,name=ConsumerLag,clientId=([-.\w]+),topic=([-.\w]+),partition=([0-9]+)
                + +

The lag should constantly decrease during replication. If the metric does not decrease, the administrator should increase the throttle throughput as described above.
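One way to sample that MBean from the command line, assuming JMX is enabled on the broker (the JMX port and the clientId shown are illustrative), is the bundled JmxTool class:
  > bin/kafka-run-class.sh kafka.tools.JmxTool --jmx-url service:jmx:rmi:///jndi/rmi://localhost:9999/jmxrmi \
        --object-name 'kafka.server:type=FetcherLagMetrics,name=ConsumerLag,clientId=ReplicaFetcherThread-0-101,topic=my-topic,partition=0'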

                + + +

                Setting quotas

                Quota overrides and defaults may be configured at (user, client-id), user or client-id levels as described here. By default, clients receive an unlimited quota. It is possible to set custom quotas for each (user, client-id), user or client-id group.

                + Configure custom quota for (user=user1, client-id=clientA): +

                  > bin/kafka-configs.sh  --bootstrap-server localhost:9092 --alter --add-config 'producer_byte_rate=1024,consumer_byte_rate=2048,request_percentage=200' --entity-type users --entity-name user1 --entity-type clients --entity-name clientA
                +  Updated config for entity: user-principal 'user1', client-id 'clientA'.
                + + Configure custom quota for user=user1: +
                  > bin/kafka-configs.sh  --bootstrap-server localhost:9092 --alter --add-config 'producer_byte_rate=1024,consumer_byte_rate=2048,request_percentage=200' --entity-type users --entity-name user1
                +  Updated config for entity: user-principal 'user1'.
                + + Configure custom quota for client-id=clientA: +
                  > bin/kafka-configs.sh  --bootstrap-server localhost:9092 --alter --add-config 'producer_byte_rate=1024,consumer_byte_rate=2048,request_percentage=200' --entity-type clients --entity-name clientA
                +  Updated config for entity: client-id 'clientA'.
                + + It is possible to set default quotas for each (user, client-id), user or client-id group by specifying --entity-default option instead of --entity-name. +

                Configure default client-id quota for user=user1:

                  > bin/kafka-configs.sh  --bootstrap-server localhost:9092 --alter --add-config 'producer_byte_rate=1024,consumer_byte_rate=2048,request_percentage=200' --entity-type users --entity-name user1 --entity-type clients --entity-default
                +  Updated config for entity: user-principal 'user1', default client-id.
                + + Configure default quota for user: +
                  > bin/kafka-configs.sh  --bootstrap-server localhost:9092 --alter --add-config 'producer_byte_rate=1024,consumer_byte_rate=2048,request_percentage=200' --entity-type users --entity-default
                +  Updated config for entity: default user-principal.
                + + Configure default quota for client-id: +
                  > bin/kafka-configs.sh  --bootstrap-server localhost:9092 --alter --add-config 'producer_byte_rate=1024,consumer_byte_rate=2048,request_percentage=200' --entity-type clients --entity-default
                +  Updated config for entity: default client-id.
                + + Here's how to describe the quota for a given (user, client-id): +
                  > bin/kafka-configs.sh  --bootstrap-server localhost:9092 --describe --entity-type users --entity-name user1 --entity-type clients --entity-name clientA
                +  Configs for user-principal 'user1', client-id 'clientA' are producer_byte_rate=1024,consumer_byte_rate=2048,request_percentage=200
                + Describe quota for a given user: +
                  > bin/kafka-configs.sh  --bootstrap-server localhost:9092 --describe --entity-type users --entity-name user1
                +  Configs for user-principal 'user1' are producer_byte_rate=1024,consumer_byte_rate=2048,request_percentage=200
                + Describe quota for a given client-id: +
                  > bin/kafka-configs.sh  --bootstrap-server localhost:9092 --describe --entity-type clients --entity-name clientA
                +  Configs for client-id 'clientA' are producer_byte_rate=1024,consumer_byte_rate=2048,request_percentage=200
                + If entity name is not specified, all entities of the specified type are described. For example, describe all users: +
                  > bin/kafka-configs.sh  --bootstrap-server localhost:9092 --describe --entity-type users
                +  Configs for user-principal 'user1' are producer_byte_rate=1024,consumer_byte_rate=2048,request_percentage=200
                +  Configs for default user-principal are producer_byte_rate=1024,consumer_byte_rate=2048,request_percentage=200
                + Similarly for (user, client): +
                  > bin/kafka-configs.sh  --bootstrap-server localhost:9092 --describe --entity-type users --entity-type clients
                +  Configs for user-principal 'user1', default client-id are producer_byte_rate=1024,consumer_byte_rate=2048,request_percentage=200
                +  Configs for user-principal 'user1', client-id 'clientA' are producer_byte_rate=1024,consumer_byte_rate=2048,request_percentage=200
                + +

                6.2 Datacenters

                + + Some deployments will need to manage a data pipeline that spans multiple datacenters. Our recommended approach to this is to deploy a local Kafka cluster in each datacenter, with application instances in each datacenter interacting only with their local cluster and mirroring data between clusters (see the documentation on Geo-Replication for how to do this). +

                + This deployment pattern allows datacenters to act as independent entities and allows us to manage and tune inter-datacenter replication centrally. This allows each facility to stand alone and operate even if the inter-datacenter links are unavailable: when this occurs the mirroring falls behind until the link is restored at which time it catches up. +

                + For applications that need a global view of all data you can use mirroring to provide clusters which have aggregate data mirrored from the local clusters in all datacenters. These aggregate clusters are used for reads by applications that require the full data set. +

                This is not the only possible deployment pattern. It is possible to read from or write to a remote Kafka cluster over the WAN, though obviously this will add whatever latency is required to get to the cluster.

                + Kafka naturally batches data in both the producer and consumer so it can achieve high-throughput even over a high-latency connection. To allow this though it may be necessary to increase the TCP socket buffer sizes for the producer, consumer, and broker using the socket.send.buffer.bytes and socket.receive.buffer.bytes configurations. The appropriate way to set this is documented here. +
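
                As a hedged sketch, the corresponding broker-side settings in server.properties might be raised as follows (the values are illustrative):

                # Larger TCP socket buffers for high-latency inter-datacenter links (illustrative values)
                socket.send.buffer.bytes=1048576
                socket.receive.buffer.bytes=1048576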

                + It is generally not advisable to run a single Kafka cluster that spans multiple datacenters over a high-latency link. This will incur very high replication latency both for Kafka writes and ZooKeeper writes, and neither Kafka nor ZooKeeper will remain available in all locations if the network between locations is unavailable. + +

                6.3 Geo-Replication (Cross-Cluster Data Mirroring)

                + +

                Geo-Replication Overview

                + +

                + Kafka administrators can define data flows that cross the boundaries of individual Kafka clusters, data centers, or geo-regions. Such event streaming setups are often needed for organizational, technical, or legal requirements. Common scenarios include: +

                + +
                • Geo-replication
                • Disaster recovery
                • Feeding edge clusters into a central, aggregate cluster
                • Physical isolation of clusters (such as production vs. testing)
                • Cloud migration or hybrid cloud deployments
                • Legal and compliance requirements
                + +

                + Administrators can set up such inter-cluster data flows with Kafka's MirrorMaker (version 2), a tool to replicate data between different Kafka environments in a streaming manner. MirrorMaker is built on top of the Kafka Connect framework and supports features such as: +

                + +
                • Replicates topics (data plus configurations)
                • Replicates consumer groups including offsets to migrate applications between clusters
                • Replicates ACLs
                • Preserves partitioning
                • Automatically detects new topics and partitions
                • Provides a wide range of metrics, such as end-to-end replication latency across multiple data centers/clusters
                • Fault-tolerant and horizontally scalable operations
                + +

                + Note: Geo-replication with MirrorMaker replicates data across Kafka clusters. This inter-cluster replication is different from Kafka's intra-cluster replication, which replicates data within the same Kafka cluster. +

                + +

                What Are Replication Flows

                + +

                + With MirrorMaker, Kafka administrators can replicate topics, topic configurations, consumer groups and their offsets, and ACLs from one or more source Kafka clusters to one or more target Kafka clusters, i.e., across cluster environments. In a nutshell, MirrorMaker uses Connectors to consume from source clusters and produce to target clusters. +

                + +

                + These directional flows from source to target clusters are called replication flows. They are defined with the format {source_cluster}->{target_cluster} in the MirrorMaker configuration file as described later. Administrators can create complex replication topologies based on these flows. +

                + +

                + Here are some example patterns: +

                + +
                • Active/Active high availability deployments: A->B, B->A
                • Active/Passive or Active/Standby high availability deployments: A->B
                • Aggregation (e.g., from many clusters to one): A->K, B->K, C->K
                • Fan-out (e.g., from one to many clusters): K->A, K->B, K->C
                • Forwarding: A->B, B->C, C->D
                + +

                + By default, a flow replicates all topics and consumer groups. However, each replication flow can be configured independently. For instance, you can define that only specific topics or consumer groups are replicated from the source cluster to the target cluster. +

                + +

                + Here is a first example on how to configure data replication from a primary cluster to a secondary cluster (an active/passive setup): +

                + +
                # Basic settings
                +clusters = primary, secondary
                +primary.bootstrap.servers = broker3-primary:9092
                +secondary.bootstrap.servers = broker5-secondary:9092
                +
                +# Define replication flows
                +primary->secondary.enabled = true
                +primary->secondary.topics = foobar-topic, quux-.*
                +
                + + +

                Configuring Geo-Replication

                + +

                + The following sections describe how to configure and run a dedicated MirrorMaker cluster. If you want to run MirrorMaker within an existing Kafka Connect cluster or other supported deployment setups, please refer to KIP-382: MirrorMaker 2.0 and be aware that the names of configuration settings may vary between deployment modes. +

                + +

                Beyond what's covered in the following sections, further examples and information on configuration settings are available in KIP-382: MirrorMaker 2.0.

                Configuration File Syntax
                + +

                + The MirrorMaker configuration file is typically named connect-mirror-maker.properties. You can configure a variety of components in this file: +

                + +
                • MirrorMaker settings: global settings including cluster definitions (aliases), plus custom settings per replication flow
                • Kafka Connect and connector settings
                • Kafka producer, consumer, and admin client settings
                + +

                + Example: Define MirrorMaker settings (explained in more detail later). +

                + +
                # Global settings
                +clusters = us-west, us-east   # defines cluster aliases
                +us-west.bootstrap.servers = broker3-west:9092
                +us-east.bootstrap.servers = broker5-east:9092
                +
                +topics = .*   # all topics to be replicated by default
                +
                +# Specific replication flow settings (here: flow from us-west to us-east)
                +us-west->us-east.enabled = true
                +us-west->us-east.topics = foo.*, bar.*  # override the default above
                +
                + +

                + MirrorMaker is based on the Kafka Connect framework. Any Kafka Connect, source connector, and sink connector settings as described in the documentation chapter on Kafka Connect can be used directly in the MirrorMaker configuration, without having to change or prefix the name of the configuration setting. +

                + +

                + Example: Define custom Kafka Connect settings to be used by MirrorMaker. +

                + +
                # Setting Kafka Connect defaults for MirrorMaker
                +tasks.max = 5
                +
                + +

                + Most of the default Kafka Connect settings work well for MirrorMaker out-of-the-box, with the exception of tasks.max. In order to evenly distribute the workload across more than one MirrorMaker process, it is recommended to set tasks.max to at least 2 (preferably higher) depending on the available hardware resources and the total number of topic-partitions to be replicated. +

                + +

                + You can further customize MirrorMaker's Kafka Connect settings per source or target cluster (more precisely, you can specify Kafka Connect worker-level configuration settings "per connector"). Use the format of {cluster}.{config_name} in the MirrorMaker configuration file. +

                + +

                + Example: Define custom connector settings for the us-west cluster. +

                + +
                # us-west custom settings
                +us-west.offset.storage.topic = my-mirrormaker-offsets
                +
                + +

                + MirrorMaker internally uses the Kafka producer, consumer, and admin clients. Custom settings for these clients are often needed. To override the defaults, use the following format in the MirrorMaker configuration file: +

                + +
                • {source}.consumer.{consumer_config_name}
                • {target}.producer.{producer_config_name}
                • {source_or_target}.admin.{admin_config_name}
                + +

                + Example: Define custom producer, consumer, admin client settings. +

                + +
                # us-west cluster (from which to consume)
                +us-west.consumer.isolation.level = read_committed
                +us-west.admin.bootstrap.servers = broker57-primary:9092
                +
                +# us-east cluster (to which to produce)
                +us-east.producer.compression.type = gzip
                +us-east.producer.buffer.memory = 32768
                +us-east.admin.bootstrap.servers = broker8-secondary:9092
                +
                + +
                Creating and Enabling Replication Flows
                + +

                + To define a replication flow, you must first define the respective source and target Kafka clusters in the MirrorMaker configuration file. +

                + +
                • clusters (required): comma-separated list of Kafka cluster "aliases"
                • {clusterAlias}.bootstrap.servers (required): connection information for the specific cluster; comma-separated list of "bootstrap" Kafka brokers

                + Example: Define two cluster aliases primary and secondary, including their connection information. +

                + +
                clusters = primary, secondary
                +primary.bootstrap.servers = broker10-primary:9092,broker-11-primary:9092
                +secondary.bootstrap.servers = broker5-secondary:9092,broker6-secondary:9092
                +
                + +

                + Secondly, you must explicitly enable individual replication flows with {source}->{target}.enabled = true as needed. Remember that flows are directional: if you need two-way (bidirectional) replication, you must enable flows in both directions. +

                + +
                # Enable replication from primary to secondary
                +primary->secondary.enabled = true
                +
                + +

                + By default, a replication flow will replicate all but a few special topics and consumer groups from the source cluster to the target cluster, and automatically detect any newly created topics and groups. The names of replicated topics in the target cluster will be prefixed with the name of the source cluster (see section further below). For example, the topic foo in the source cluster us-west would be replicated to a topic named us-west.foo in the target cluster us-east. +

                + +

                + The subsequent sections explain how to customize this basic setup according to your needs. +

                + +
                Configuring Replication Flows
                + +

                +The configuration of a replication flow is a combination of top-level default settings (e.g., topics), on top of which flow-specific settings, if any, are applied (e.g., us-west->us-east.topics). To change the top-level defaults, add the respective top-level setting to the MirrorMaker configuration file. To override the defaults for a specific replication flow only, use the syntax format {source}->{target}.{config.name}. +

                + +

                + The most important settings are: +

                + +
                • topics: list of topics or a regular expression that defines which topics in the source cluster to replicate (default: topics = .*)
                • topics.exclude: list of topics or a regular expression to subsequently exclude topics that were matched by the topics setting (default: topics.exclude = .*[\-\.]internal, .*\.replica, __.*)
                • groups: list of consumer groups or a regular expression that defines which consumer groups in the source cluster to replicate (default: groups = .*)
                • groups.exclude: list of consumer groups or a regular expression to subsequently exclude consumer groups that were matched by the groups setting (default: groups.exclude = console-consumer-.*, connect-.*, __.*)
                • {source}->{target}.enabled: set to true to enable the replication flow (default: false)

                + Example: +

                + +
                # Custom top-level defaults that apply to all replication flows
                +topics = .*
                +groups = consumer-group1, consumer-group2
                +
                +# Don't forget to enable a flow!
                +us-west->us-east.enabled = true
                +
                +# Custom settings for specific replication flows
                +us-west->us-east.topics = foo.*
                +us-west->us-east.groups = bar.*
                +us-west->us-east.emit.heartbeats = false
                +
                + +

                + Additional configuration settings are supported, some of which are listed below. In most cases, you can leave these settings at their default values. See MirrorMakerConfig and MirrorConnectorConfig for further details. +

                + +
                • refresh.topics.enabled: whether to check for new topics in the source cluster periodically (default: true)
                • refresh.topics.interval.seconds: frequency of checking for new topics in the source cluster; lower values than the default may lead to performance degradation (default: 600, every ten minutes)
                • refresh.groups.enabled: whether to check for new consumer groups in the source cluster periodically (default: true)
                • refresh.groups.interval.seconds: frequency of checking for new consumer groups in the source cluster; lower values than the default may lead to performance degradation (default: 600, every ten minutes)
                • sync.topic.configs.enabled: whether to replicate topic configurations from the source cluster (default: true)
                • sync.topic.acls.enabled: whether to sync ACLs from the source cluster (default: true)
                • emit.heartbeats.enabled: whether to emit heartbeats periodically (default: true)
                • emit.heartbeats.interval.seconds: frequency at which heartbeats are emitted (default: 1, every second)
                • heartbeats.topic.replication.factor: replication factor of MirrorMaker's internal heartbeat topics (default: 3)
                • emit.checkpoints.enabled: whether to emit MirrorMaker's consumer offsets periodically (default: true)
                • emit.checkpoints.interval.seconds: frequency at which checkpoints are emitted (default: 60, every minute)
                • checkpoints.topic.replication.factor: replication factor of MirrorMaker's internal checkpoints topics (default: 3)
                • sync.group.offsets.enabled: whether to periodically write the translated offsets of replicated consumer groups (in the source cluster) to the __consumer_offsets topic in the target cluster, as long as no active consumers in that group are connected to the target cluster (default: false); a sketch of enabling this per flow follows this list
                • sync.group.offsets.interval.seconds: frequency at which consumer group offsets are synced (default: 60, every minute)
                • offset-syncs.topic.replication.factor: replication factor of MirrorMaker's internal offset-sync topics (default: 3)
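
                For instance, a minimal sketch that enables offset syncing for a single flow, reusing the cluster aliases from the earlier examples (the interval value is illustrative):

                # Enable periodic consumer group offset syncing for one flow (illustrative)
                us-west->us-east.sync.group.offsets.enabled = true
                us-west->us-east.sync.group.offsets.interval.seconds = 30
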
                Securing Replication Flows
                + +

                + MirrorMaker supports the same security settings as Kafka Connect, so please refer to the linked section for further information. +

                + +

                + Example: Encrypt communication between MirrorMaker and the us-east cluster. +

                + +
                us-east.security.protocol=SSL
                +us-east.ssl.truststore.location=/path/to/truststore.jks
                +us-east.ssl.truststore.password=my-secret-password
                +us-east.ssl.keystore.location=/path/to/keystore.jks
                +us-east.ssl.keystore.password=my-secret-password
                +us-east.ssl.key.password=my-secret-password
                +
                + +
                Custom Naming of Replicated Topics in Target Clusters
                + +

                + Replicated topics in a target cluster—sometimes called remote topics—are renamed according to a replication policy. MirrorMaker uses this policy to ensure that events (aka records, messages) from different clusters are not written to the same topic-partition. By default as per DefaultReplicationPolicy, the names of replicated topics in the target clusters have the format {source}.{source_topic_name}: +

                + +
                us-west         us-east
                +=========       =================
                +                bar-topic
                +foo-topic  -->  us-west.foo-topic
                +
                + +

                + You can customize the separator (default: .) with the replication.policy.separator setting: +

                + +
                # Defining a custom separator
                +us-west->us-east.replication.policy.separator = _
                +
                + +

                + If you need further control over how replicated topics are named, you can implement a custom ReplicationPolicy and override replication.policy.class (default is DefaultReplicationPolicy) in the MirrorMaker configuration. +
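
                As a hedged sketch (the class name is hypothetical, and such a class would need to be available on MirrorMaker's classpath):

                # Use a custom replication policy instead of DefaultReplicationPolicy (class name is illustrative)
                replication.policy.class = com.example.mirror.CustomReplicationPolicy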

                + +
                Preventing Configuration Conflicts
                + +

                + MirrorMaker processes share configuration via their target Kafka clusters. This behavior may cause conflicts when configurations differ among MirrorMaker processes that operate against the same target cluster. +

                + +

                + For example, the following two MirrorMaker processes would be racy: +

                + +
                # Configuration of process 1
                +A->B.enabled = true
                +A->B.topics = foo
                +
                +# Configuration of process 2
                +A->B.enabled = true
                +A->B.topics = bar
                +
                + +

                + In this case, the two processes will share configuration via cluster B, which causes a conflict. Depending on which of the two processes is the elected "leader", the result will be that either the topic foo or the topic bar is replicated, but not both. +

                + +

                It is therefore important to keep the MirrorMaker configuration consistent across replication flows to the same target cluster. This can be achieved, for example, through automation tooling or by using a single, shared MirrorMaker configuration file for your entire organization.

                + +
                Best Practice: Consume from Remote, Produce to Local
                + +

                To minimize latency ("producer lag"), it is recommended to locate MirrorMaker processes as close as possible to their target clusters, i.e., the clusters they produce data to. That's because Kafka producers typically struggle more with unreliable or high-latency network connections than Kafka consumers.

                + +
                First DC          Second DC
                +==========        =========================
                +primary --------- MirrorMaker --> secondary
                +(remote)                           (local)
                +
                + +

                +To run such a "consume from remote, produce to local" setup, run the MirrorMaker processes close to and preferably in the same location as the target clusters, and explicitly set these "local" clusters in the --clusters command line parameter (blank-separated list of cluster aliases): +

                + +
                # Run in secondary's data center, reading from the remote `primary` cluster
                +$ ./bin/connect-mirror-maker.sh connect-mirror-maker.properties --clusters secondary
                +
                The --clusters secondary option tells the MirrorMaker process that the given cluster(s) are nearby, and prevents it from replicating data or sending configuration to clusters at other, remote locations.

                Example: Active/Passive High Availability Deployment
                + +

                +The following example shows the basic settings to replicate topics from a primary to a secondary Kafka environment, but not from the secondary back to the primary. Please be aware that most production setups will need further configuration, such as security settings. +

                + +
                # Unidirectional flow (one-way) from primary to secondary cluster
                +primary.bootstrap.servers = broker1-primary:9092
                +secondary.bootstrap.servers = broker2-secondary:9092
                +
                +primary->secondary.enabled = true
                +secondary->primary.enabled = false
                +
                +primary->secondary.topics = foo.*  # only replicate some topics
                +
                + +
                Example: Active/Active High Availability Deployment
                + +

                + The following example shows the basic settings to replicate topics between two clusters in both ways. Please be aware that most production setups will need further configuration, such as security settings. +

                + +
                # Bidirectional flow (two-way) between us-west and us-east clusters
                +clusters = us-west, us-east
                +us-west.bootstrap.servers = broker1-west:9092,broker2-west:9092
                +us-east.bootstrap.servers = broker3-east:9092,broker4-east:9092
                +
                +us-west->us-east.enabled = true
                +us-east->us-west.enabled = true
                +
                + +

                + Note on preventing replication "loops" (where topics will be originally replicated from A to B, then the replicated topics will be replicated yet again from B to A, and so forth): As long as you define the above flows in the same MirrorMaker configuration file, you do not need to explicitly add topics.exclude settings to prevent replication loops between the two clusters. +

                + +
                Example: Multi-Cluster Geo-Replication
                + +

                + Let's put all the information from the previous sections together in a larger example. Imagine there are three data centers (west, east, north), with two Kafka clusters in each data center (e.g., west-1, west-2). The example in this section shows how to configure MirrorMaker (1) for Active/Active replication within each data center, as well as (2) for Cross Data Center Replication (XDCR). +

                + +

                + First, define the source and target clusters along with their replication flows in the configuration: +

                + +
                # Basic settings
                +clusters: west-1, west-2, east-1, east-2, north-1, north-2
                +west-1.bootstrap.servers = ...
                +west-2.bootstrap.servers = ...
                +east-1.bootstrap.servers = ...
                +east-2.bootstrap.servers = ...
                +north-1.bootstrap.servers = ...
                +north-2.bootstrap.servers = ...
                +
                +# Replication flows for Active/Active in West DC
                +west-1->west-2.enabled = true
                +west-2->west-1.enabled = true
                +
                +# Replication flows for Active/Active in East DC
                +east-1->east-2.enabled = true
                +east-2->east-1.enabled = true
                +
                +# Replication flows for Active/Active in North DC
                +north-1->north-2.enabled = true
                +north-2->north-1.enabled = true
                +
                +# Replication flows for XDCR via west-1, east-1, north-1
                +west-1->east-1.enabled  = true
                +west-1->north-1.enabled = true
                +east-1->west-1.enabled  = true
                +east-1->north-1.enabled = true
                +north-1->west-1.enabled = true
                +north-1->east-1.enabled = true
                +
                + +

                Then, in each data center, launch one or more MirrorMaker processes as follows:

                + +
                # In West DC:
                +$ ./bin/connect-mirror-maker.sh connect-mirror-maker.properties --clusters west-1 west-2
                +
                +# In East DC:
                +$ ./bin/connect-mirror-maker.sh connect-mirror-maker.properties --clusters east-1 east-2
                +
                +# In North DC:
                +$ ./bin/connect-mirror-maker.sh connect-mirror-maker.properties --clusters north-1 north-2
                +
                + +

                + With this configuration, records produced to any cluster will be replicated within the data center, as well as across to other data centers. By providing the --clusters parameter, we ensure that each MirrorMaker process produces data to nearby clusters only. +

                + +

                + Note: The --clusters parameter is, technically, not required here. MirrorMaker will work fine without it. However, throughput may suffer from "producer lag" between data centers, and you may incur unnecessary data transfer costs. +

                + +

                Starting Geo-Replication

                + +

                + You can run as few or as many MirrorMaker processes (think: nodes, servers) as needed. Because MirrorMaker is based on Kafka Connect, MirrorMaker processes that are configured to replicate the same Kafka clusters run in a distributed setup: They will find each other, share configuration (see section below), load balance their work, and so on. If, for example, you want to increase the throughput of replication flows, one option is to run additional MirrorMaker processes in parallel. +

                + +

                + To start a MirrorMaker process, run the command: +

                + +
                $ ./bin/connect-mirror-maker.sh connect-mirror-maker.properties
                +
                + +

                + After startup, it may take a few minutes until a MirrorMaker process first begins to replicate data. +

                + +

                + Optionally, as described previously, you can set the parameter --clusters to ensure that the MirrorMaker process produces data to nearby clusters only. +

                + +
                # Note: The cluster alias us-west must be defined in the configuration file
                +$ ./bin/connect-mirror-maker.sh connect-mirror-maker.properties \
                +            --clusters us-west
                +
                + +

                Note when testing replication of consumer groups: By default, MirrorMaker does not replicate consumer groups created by the kafka-console-consumer.sh tool, which you might use to test your MirrorMaker setup on the command line. If you do want to replicate these consumer groups as well, set the groups.exclude configuration accordingly (default: groups.exclude = console-consumer-.*, connect-.*, __.*). Remember to update the configuration again once you have completed your testing.
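
                For example, a testing-only sketch that stops excluding console consumer groups (revert it once testing is done):

                # Testing only: allow console consumer groups to be replicated
                groups.exclude = connect-.*, __.*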

                + +

                Stopping Geo-Replication

                + +

                + You can stop a running MirrorMaker process by sending a SIGTERM signal with the command: +

                + +
                $ kill <MirrorMaker pid>
                +
                + +

                Applying Configuration Changes

                + +

                + To make configuration changes take effect, the MirrorMaker process(es) must be restarted. +

                + +

                Monitoring Geo-Replication

                + +

                It is recommended to monitor MirrorMaker processes to ensure all defined replication flows are up and running correctly. MirrorMaker is built on the Connect framework and inherits all of Connect's metrics, such as source-record-poll-rate. In addition, MirrorMaker produces its own metrics under the kafka.connect.mirror metric group. Metrics are tagged with the following properties:

                + +
                • source: alias of source cluster (e.g., primary)
                • target: alias of target cluster (e.g., secondary)
                • topic: replicated topic on target cluster
                • partition: partition being replicated
                + +

                + Metrics are tracked for each replicated topic. The source cluster can be inferred from the topic name. For example, replicating topic1 from primary->secondary will yield metrics like: +

                + +
                • target=secondary
                • topic=primary.topic1
                • partition=1

                + The following metrics are emitted: +

                + +
                # MBean: kafka.connect.mirror:type=MirrorSourceConnector,target=([-.\w]+),topic=([-.\w]+),partition=([0-9]+)
                +
                +record-count            # number of records replicated source -> target
                +record-age-ms           # age of records when they are replicated
                +record-age-ms-min
                +record-age-ms-max
                +record-age-ms-avg
                +replication-latency-ms  # time it takes records to propagate source->target
                +replication-latency-ms-min
                +replication-latency-ms-max
                +replication-latency-ms-avg
                +byte-rate               # average number of bytes/sec in replicated records
                +
                +# MBean: kafka.connect.mirror:type=MirrorCheckpointConnector,source=([-.\w]+),target=([-.\w]+)
                +
                +checkpoint-latency-ms   # time it takes to replicate consumer offsets
                +checkpoint-latency-ms-min
                +checkpoint-latency-ms-max
                +checkpoint-latency-ms-avg
                +
                + +

                + These metrics do not differentiate between created-at and log-append timestamps. +

                + + +

                6.4 Multi-Tenancy

                + +

                Multi-Tenancy Overview

                + +

                As a highly scalable event streaming platform, Kafka is used by many users as their central nervous system, connecting in real-time a wide range of different systems and applications from various teams and lines of business. Such multi-tenant cluster environments command proper control and management to ensure the peaceful coexistence of these different needs. This section highlights features and best practices to set up such shared environments, which should help you operate clusters that meet SLAs/OLAs and that minimize potential collateral damage caused by "noisy neighbors".

                + +

                + Multi-tenancy is a many-sided subject, including but not limited to: +

                + +
                • Creating user spaces for tenants (sometimes called namespaces)
                • Configuring topics with data retention policies and more
                • Securing topics and clusters with encryption, authentication, and authorization
                • Isolating tenants with quotas and rate limits
                • Monitoring and metering
                • Inter-cluster data sharing (cf. geo-replication)
                + +

                Creating User Spaces (Namespaces) For Tenants With Topic Naming

                + +

                + Kafka administrators operating a multi-tenant cluster typically need to define user spaces for each tenant. For the purpose of this section, "user spaces" are a collection of topics, which are grouped together under the management of a single entity or user. +

                + +

                In Kafka, the main unit of data is the topic. Users can create and name each topic. They can also delete them, but it is not possible to rename a topic directly. Instead, to rename a topic, the user must create a new topic, move the messages from the original topic to the new one, and then delete the original. With this in mind, it is recommended to define logical spaces based on a hierarchical topic naming structure. This setup can then be combined with security features, such as prefixed ACLs, to isolate different spaces and tenants, while also minimizing the administrative overhead for securing the data in the cluster.

                + +

                + These logical user spaces can be grouped in different ways, and the concrete choice depends on how your organization prefers to use your Kafka clusters. The most common groupings are as follows. +

                + +

                + By team or organizational unit: Here, the team is the main aggregator. In an organization where teams are the main user of the Kafka infrastructure, this might be the best grouping. +

                + +

                + Example topic naming structure: +

                + +
                • <organization>.<team>.<dataset>.<event-name>
                  (e.g., "acme.infosec.telemetry.logins")
                + +

                + By project or product: Here, a team manages more than one project. Their credentials will be different for each project, so all the controls and settings will always be project related. +

                + +

                + Example topic naming structure: +

                + +
                • <project>.<product>.<event-name>
                  (e.g., "mobility.payments.suspicious")
                + +

                + Certain information should normally not be put in a topic name, such as information that is likely to change over time (e.g., the name of the intended consumer) or that is a technical detail or metadata that is available elsewhere (e.g., the topic's partition count and other configuration settings). +

                + +

                + To enforce a topic naming structure, several options are available: +

                + +
                • Use prefix ACLs (cf. KIP-290) to enforce a common prefix for topic names. For example, team A may only be permitted to create topics whose names start with payments.teamA..
                • Define a custom CreateTopicPolicy (cf. KIP-108 and the setting create.topic.policy.class.name) to enforce strict naming patterns. These policies provide the most flexibility and can cover complex patterns and rules to match an organization's needs.
                • Disable topic creation for normal users by denying it with an ACL, and then rely on an external process to create topics on behalf of users (e.g., scripting or your favorite automation toolkit).
                • It may also be useful to disable the Kafka feature to auto-create topics on demand by setting auto.create.topics.enable=false in the broker configuration. Note that you should not rely solely on this option.
                + + +

                Configuring Topics: Data Retention And More

                + +

                Kafka's configuration is very flexible due to its fine granularity, and it supports a plethora of per-topic configuration settings to help administrators set up multi-tenant clusters. For example, administrators often need to define data retention policies to control how much and/or for how long data will be stored in a topic, with settings such as retention.bytes (size) and retention.ms (time). This limits storage consumption within the cluster, and helps comply with legal requirements such as GDPR.
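
                For example, a hedged sketch of setting such a retention policy on a single topic with kafka-configs.sh (the topic name and values are illustrative):

                  > bin/kafka-configs.sh --bootstrap-server localhost:9092 --alter --entity-type topics --entity-name acme.infosec.telemetry.logins --add-config retention.ms=604800000,retention.bytes=1073741824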

                + +

                Securing Clusters and Topics: Authentication, Authorization, Encryption

                + +

                + Because the documentation has a dedicated chapter on security that applies to any Kafka deployment, this section focuses on additional considerations for multi-tenant environments. +

                + +

                +Security settings for Kafka fall into three main categories, which are similar to how administrators would secure other client-server data systems, like relational databases and traditional messaging systems. +

                + +
                1. Encryption of data transferred between Kafka brokers and Kafka clients, between brokers, between brokers and ZooKeeper nodes, and between brokers and other, optional tools.
                2. Authentication of connections from Kafka clients and applications to Kafka brokers, as well as connections from Kafka brokers to ZooKeeper nodes.
                3. Authorization of client operations such as creating, deleting, and altering the configuration of topics; writing events to or reading events from a topic; creating and deleting ACLs. Administrators can also define custom policies to put in place additional restrictions, such as a CreateTopicPolicy and AlterConfigPolicy (see KIP-108 and the settings create.topic.policy.class.name, alter.config.policy.class.name).
                + +

                + When securing a multi-tenant Kafka environment, the most common administrative task is the third category (authorization), i.e., managing the user/client permissions that grant or deny access to certain topics and thus to the data stored by users within a cluster. This task is performed predominantly through the setting of access control lists (ACLs). Here, administrators of multi-tenant environments in particular benefit from putting a hierarchical topic naming structure in place as described in a previous section, because they can conveniently control access to topics through prefixed ACLs (--resource-pattern-type Prefixed). This significantly minimizes the administrative overhead of securing topics in multi-tenant environments: administrators can make their own trade-offs between higher developer convenience (more lenient permissions, using fewer and broader ACLs) vs. tighter security (more stringent permissions, using more and narrower ACLs). +

                + +

                + In the following example, user Alice—a new member of ACME corporation's InfoSec team—is granted write permissions to all topics whose names start with "acme.infosec.", such as "acme.infosec.telemetry.logins" and "acme.infosec.syslogs.events". +

                + +
                # Grant permissions to user Alice
                +$ bin/kafka-acls.sh \
                +    --bootstrap-server broker1:9092 \
                +    --add --allow-principal User:Alice \
                +    --producer \
                +    --resource-pattern-type prefixed --topic acme.infosec.
                +
                + +

                + You can similarly use this approach to isolate different customers on the same shared cluster. +

                + +

                Isolating Tenants: Quotas, Rate Limiting, Throttling

                + +

                + Multi-tenant clusters should generally be configured with quotas, which protect against users (tenants) eating up too many cluster resources, such as when they attempt to write or read very high volumes of data, or create requests to brokers at an excessively high rate. This may cause network saturation, monopolize broker resources, and impact other clients—all of which you want to avoid in a shared environment. +

                + +

                + Client quotas: Kafka supports different types of (per-user principal) client quotas. Because a client's quotas apply irrespective of which topics the client is writing to or reading from, they are a convenient and effective tool to allocate resources in a multi-tenant cluster. Request rate quotas, for example, help to limit a user's impact on broker CPU usage by limiting the time a broker spends on the request handling path for that user, after which throttling kicks in. In many situations, isolating users with request rate quotas has a bigger impact in multi-tenant clusters than setting incoming/outgoing network bandwidth quotas, because excessive broker CPU usage for processing requests reduces the effective bandwidth the broker can serve. Furthermore, administrators can also define quotas on topic operations—such as create, delete, and alter—to prevent Kafka clusters from being overwhelmed by highly concurrent topic operations (see KIP-599 and the quota type controller_mutations_rate). +
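
                As a hedged sketch, such a quota on topic operations could be applied to a specific user with kafka-configs.sh (the principal and rate are illustrative):

                  > bin/kafka-configs.sh --bootstrap-server localhost:9092 --alter --add-config 'controller_mutations_rate=10' --entity-type users --entity-name tenantA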

                + +

                + Server quotas: Kafka also supports different types of broker-side quotas. For example, administrators can set a limit on the rate with which the broker accepts new connections, set the maximum number of connections per broker, or set the maximum number of connections allowed from a specific IP address. +
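
                A hedged sketch of such broker-side limits in server.properties (the values are illustrative):

                # Broker-side connection limits (illustrative values)
                max.connections=1000
                max.connections.per.ip=100
                max.connection.creation.rate=50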

                + +

                + For more information, please refer to the quota overview and how to set quotas. +

                + +

                Monitoring and Metering

                + +

                + Monitoring is a broader subject that is covered elsewhere in the documentation. Administrators of any Kafka environment, but especially multi-tenant ones, should set up monitoring according to these instructions. Kafka supports a wide range of metrics, such as the rate of failed authentication attempts, request latency, consumer lag, total number of consumer groups, metrics on the quotas described in the previous section, and many more. +

                + +

                + For example, monitoring can be configured to track the size of topic-partitions (with the JMX metric kafka.log.Log.Size.<TOPIC-NAME>), and thus the total size of data stored in a topic. You can then define alerts when tenants on shared clusters are getting close to using too much storage space. +

                + +

                Multi-Tenancy and Geo-Replication

                + +

                + Kafka lets you share data across different clusters, which may be located in different geographical regions, data centers, and so on. Apart from use cases such as disaster recovery, this functionality is useful when a multi-tenant setup requires inter-cluster data sharing. See the section Geo-Replication (Cross-Cluster Data Mirroring) for more information. +

                + +

                Further considerations

                + +

                + Data contracts: You may need to define data contracts between the producers and the consumers of data in a cluster, using event schemas. This ensures that events written to Kafka can always be read properly again, and prevents malformed or corrupt events being written. The best way to achieve this is to deploy a so-called schema registry alongside the cluster. (Kafka does not include a schema registry, but there are third-party implementations available.) A schema registry manages the event schemas and maps the schemas to topics, so that producers know which topics are accepting which types (schemas) of events, and consumers know how to read and parse events in a topic. Some registry implementations provide further functionality, such as schema evolution, storing a history of all schemas, and schema compatibility settings. +

                + + +

                6.5 Kafka Configuration

                + +

                Important Client Configurations

                The most important producer configurations are:

                • acks
                • compression
                • batch size

                The most important consumer configuration is the fetch size.

                + All configurations are documented in the configuration section. +

                +

                A Production Server Config

                + Here is an example production server configuration: +
                  # ZooKeeper
                +  zookeeper.connect=[list of ZooKeeper servers]
                +
                +  # Log configuration
                +  num.partitions=8
                +  default.replication.factor=3
                +  log.dir=[List of directories. Kafka should have its own dedicated disk(s) or SSD(s).]
                +
                +  # Other configurations
                +  broker.id=[An integer. Start with 0 and increment by 1 for each new broker.]
                +  listeners=[list of listeners]
                +  auto.create.topics.enable=false
                +  min.insync.replicas=2
                +  queued.max.requests=[number of concurrent requests]
                + + Our client configuration varies a fair amount between different use cases. + +

                6.6 Java Version

                Java 8 and Java 11 are supported. Java 11 performs significantly better if TLS is enabled, so it is highly recommended (it also includes a number of other performance improvements: G1GC, CRC32C, Compact Strings, Thread-Local Handshakes and more).

                From a security perspective, we recommend the latest released patch version as older freely available versions have disclosed security vulnerabilities.

                Typical arguments for running Kafka with OpenJDK-based Java implementations (including Oracle JDK) are:
                  -Xmx6g -Xms6g -XX:MetaspaceSize=96m -XX:+UseG1GC
                +  -XX:MaxGCPauseMillis=20 -XX:InitiatingHeapOccupancyPercent=35 -XX:G1HeapRegionSize=16M
                +  -XX:MinMetaspaceFreeRatio=50 -XX:MaxMetaspaceFreeRatio=80 -XX:+ExplicitGCInvokesConcurrent
                For reference, here are the stats for one of LinkedIn's busiest clusters (at peak) that uses said Java arguments:
                • 60 brokers
                • 50k partitions (replication factor 2)
                • 800k messages/sec in
                • 300 MB/sec inbound, 1 GB/sec+ outbound

                All of the brokers in that cluster have a 90% GC pause time of about 21ms with less than 1 young GC per second.

                6.7 Hardware and OS

                + We are using dual quad-core Intel Xeon machines with 24GB of memory. +

                + You need sufficient memory to buffer active readers and writers. You can do a back-of-the-envelope estimate of memory needs by assuming you want to be able to buffer for 30 seconds and compute your memory need as write_throughput*30. +
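
                For example, under this rule of thumb a broker sustaining roughly 50 MB/sec of writes would want on the order of 50 MB/sec * 30 s = 1.5 GB of memory available for the pagecache (the 50 MB/sec figure is purely illustrative).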

                The disk throughput is important. We have 8x7200 rpm SATA drives. In general disk throughput is the performance bottleneck, and more disks are better. Depending on how you configure flush behavior you may or may not benefit from more expensive disks (if you force flush often then higher RPM SAS drives may be better).

                OS

                + Kafka should run well on any unix system and has been tested on Linux and Solaris. +

                We have seen a few issues running on Windows and Windows is not currently a well-supported platform, though we would be happy to change that.

                + It is unlikely to require much OS-level tuning, but there are three potentially important OS-level configurations: +

• File descriptor limits: Kafka uses file descriptors for log segments and open connections. If a broker hosts many partitions, consider that the broker needs at least (number_of_partitions)*(partition_size/segment_size) file descriptors to track all log segments, in addition to the number of connections the broker makes. We recommend at least 100000 allowed file descriptors for the broker processes as a starting point. Note: The mmap() function adds an extra reference to the file associated with the file descriptor fildes which is not removed by a subsequent close() on that file descriptor. This reference is removed when there are no more mappings to the file.
• Max socket buffer size: can be increased to enable high-performance data transfer between data centers as described here.
• Maximum number of memory map areas a process may have (aka vm.max_map_count). See the Linux kernel documentation. You should keep an eye on this OS-level property when considering the maximum number of partitions a broker may have. By default, on a number of Linux systems, the value of vm.max_map_count is somewhere around 65535. Each log segment, allocated per partition, requires a pair of index/timeindex files, and each of these files consumes 1 map area. In other words, each log segment uses 2 map areas. Thus, each partition requires a minimum of 2 map areas as long as it hosts a single log segment. That is to say, creating 50000 partitions on a broker will result in the allocation of 100000 map areas and likely cause the broker to crash with OutOfMemoryError (Map failed) on a system with default vm.max_map_count. Keep in mind that the number of log segments per partition varies depending on the segment size, load intensity, and retention policy and, generally, tends to be more than one.
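A minimal sketch of checking and raising these limits on a typical Linux host; the values and file paths are illustrative, and persistent-configuration locations differ between distributions:

 > ulimit -n                    # current per-process file descriptor limit
 > sysctl vm.max_map_count      # current memory map area limit
 > # rough file descriptor estimate from the formula above, e.g.
 > # 3000 partitions * (20GB partition / 1GB segments) = 60000 descriptors for segments alone, plus connections
 > echo "* - nofile 100000" | sudo tee -a /etc/security/limits.conf
 > echo "vm.max_map_count=262144" | sudo tee -a /etc/sysctl.conf
 > sudo sysctl -p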


                Disks and Filesystem

We recommend using multiple drives to get good throughput and, to ensure good latency, not sharing the drives used for Kafka data with application logs or other OS filesystem activity. You can either RAID these drives together into a single volume or format and mount each drive as its own directory. Since Kafka has replication, the redundancy provided by RAID can also be provided at the application level. This choice has several tradeoffs.

If you configure multiple data directories, partitions will be assigned round-robin to the data directories. Each partition will reside entirely in one of the data directories. If data is not well balanced among partitions this can lead to load imbalance between disks.
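For example, multiple data directories are configured via the log.dirs broker property, with one entry per mount point (the paths shown are illustrative):

 # server.properties
 log.dirs=/mnt/kafka-disk1/data,/mnt/kafka-disk2/data,/mnt/kafka-disk3/data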

RAID can potentially do better at balancing load between disks (although it doesn't always seem to) because it balances load at a lower level. The primary downside of RAID is that it is usually a big performance hit for write throughput and reduces the available disk space.

Another potential benefit of RAID is the ability to tolerate disk failures. However, our experience has been that rebuilding the RAID array is so I/O intensive that it effectively disables the server, so this does not provide much real availability improvement.

                Application vs. OS Flush Management

Kafka always immediately writes all data to the filesystem and supports the ability to configure the flush policy that controls when data is forced out of the OS cache and onto disk using the flush (fsync) call. This flush policy can be configured to force data to disk after a period of time or after a certain number of messages has been written. There are several choices in this configuration.

Kafka must eventually call fsync to know that data was flushed. When recovering from a crash, for any log segment not known to be fsync'd Kafka will check the integrity of each message by checking its CRC and also rebuild the accompanying offset index file as part of the recovery process executed on startup.

Note that durability in Kafka does not require syncing data to disk, as a failed node will always recover from its replicas.

We recommend using the default flush settings, which disable application fsync entirely. This means relying on the background flush done by the OS and Kafka's own background flush. This provides the best of all worlds for most uses: no knobs to tune, great throughput and latency, and full recovery guarantees. We generally feel that the guarantees provided by replication are stronger than sync to local disk; however, the paranoid still may prefer having both, and application-level fsync policies are still supported.
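If you nevertheless want application-level flushing, the broker-level knobs look roughly like this; the values are illustrative, and both settings default to effectively unlimited, which is the recommended configuration:

 # server.properties -- only if application-level flushing is really desired
 log.flush.interval.messages=10000
 log.flush.interval.ms=1000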

The drawback of using application-level flush settings is that it is less efficient in its disk usage pattern (it gives the OS less leeway to re-order writes) and it can introduce latency, as fsync in most Linux filesystems blocks writes to the file, whereas the background flushing does much more granular page-level locking.

In general you don't need to do any low-level tuning of the filesystem, but in the next few sections we will go over some of this in case it is useful.

                Understanding Linux OS Flush Behavior

In Linux, data written to the filesystem is maintained in pagecache until it must be written out to disk (due to an application-level fsync or the OS's own flush policy). The flushing of data is done by a set of background threads called pdflush (or in post 2.6.32 kernels "flusher threads").

Pdflush has a configurable policy that controls how much dirty data can be maintained in cache and for how long before it must be written back to disk. This policy is described here. When pdflush cannot keep up with the rate of data being written it will eventually cause the writing process to block, incurring latency in the writes to slow down the accumulation of data.

You can see the current state of OS memory usage by doing:

                 > cat /proc/meminfo 
The meaning of these values is described in the link above.
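A sketch of inspecting the relevant flush-policy knobs and the amount of dirty data currently in pagecache (defaults vary by distribution and kernel version):

 > sysctl vm.dirty_background_ratio    # % of memory at which background writeback starts
 > sysctl vm.dirty_ratio               # % of memory at which writing processes are blocked and forced to flush
 > grep -E '^(Dirty|Writeback):' /proc/meminfo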

Using pagecache has several advantages over an in-process cache for storing data that will be written out to disk:

• The I/O scheduler will batch together consecutive small writes into bigger physical writes, which improves throughput.
• The I/O scheduler will attempt to re-sequence writes to minimize movement of the disk head, which improves throughput.
• It automatically uses all the free memory on the machine.

                Filesystem Selection

Kafka uses regular files on disk, and as such it has no hard dependency on a specific filesystem. The two filesystems which have the most usage, however, are EXT4 and XFS. Historically, EXT4 has had more usage, but recent improvements to the XFS filesystem have shown it to have better performance characteristics for Kafka's workload with no compromise in stability.

Comparison testing was performed on a cluster with significant message loads, using a variety of filesystem creation and mount options. The primary metric in Kafka that was monitored was the "Request Local Time", indicating the amount of time append operations were taking. XFS resulted in much better local times (160ms vs. 250ms+ for the best EXT4 configuration), as well as lower average wait times. The XFS performance also showed less variability in disk performance.

General Filesystem Notes
For any filesystem used for data directories on Linux systems, the following options are recommended at mount time (a sketch follows the list):

• noatime: This option disables updating of a file's atime (last access time) attribute when the file is read. This can eliminate a significant number of filesystem writes, especially in the case of bootstrapping consumers. Kafka does not rely on the atime attributes at all, so it is safe to disable this.
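A sketch of a corresponding /etc/fstab entry; the device, mount point and filesystem type are placeholders:

 # /etc/fstab
 /dev/sdb1   /mnt/kafka-disk1   xfs   defaults,noatime   0 0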
                XFS Notes
The XFS filesystem has a significant amount of auto-tuning in place, so it does not require changes to the default settings, either at filesystem creation time or at mount. The only tuning parameters worth considering are:

• largeio: This affects the preferred I/O size reported by the stat call. While this can allow for higher performance on larger disk writes, in practice it had minimal or no effect on performance.
• nobarrier: For underlying devices that have battery-backed cache, this option can provide a little more performance by disabling periodic write flushes. However, if the underlying device is well-behaved, it will report to the filesystem that it does not require flushes, and this option will have no effect.
                EXT4 Notes
EXT4 is a serviceable choice of filesystem for the Kafka data directories; however, getting the most performance out of it will require adjusting several mount options. In addition, these options are generally unsafe in a failure scenario and will result in much more data loss and corruption. For a single broker failure this is not much of a concern, as the disk can be wiped and the replicas rebuilt from the cluster. In a multiple-failure scenario, such as a power outage, this can mean underlying filesystem (and therefore data) corruption that is not easily recoverable. The following options can be adjusted (a combined sketch follows the list):

• data=writeback: Ext4 defaults to data=ordered which puts a strong order on some writes. Kafka does not require this ordering as it does very paranoid data recovery on all unflushed logs. This setting removes the ordering constraint and seems to significantly reduce latency.
• Disabling journaling: Journaling is a tradeoff: it makes reboots faster after server crashes but it introduces a great deal of additional locking which adds variance to write performance. Those who don't care about reboot time and want to reduce a major source of write latency spikes can turn off journaling entirely.
• commit=num_secs: This tunes the frequency with which ext4 commits to its metadata journal. Setting this to a lower value reduces the loss of unflushed data during a crash. Setting this to a higher value will improve throughput.
• nobh: This setting controls additional ordering guarantees when using data=writeback mode. This should be safe with Kafka as we do not depend on write ordering, and it improves throughput and latency.
• delalloc: Delayed allocation means that the filesystem avoids allocating any blocks until the physical write occurs. This allows ext4 to allocate a large extent instead of smaller pages and helps ensure the data is written sequentially. This feature is great for throughput. It does seem to involve some locking in the filesystem which adds a bit of latency variance.
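A combined sketch for an EXT4 data volume, assuming you accept the data-loss caveats above; the device and mount point are placeholders, and note that removing the journal entirely makes the data=writeback option moot:

 > # option 1: keep the journal but relax ordering via mount options (/etc/fstab entry)
 > # /dev/sdc1   /mnt/kafka-disk2   ext4   noatime,data=writeback,nobh,commit=30,delalloc   0 0
 > # option 2: turn the journal off entirely (run on an unmounted filesystem)
 > sudo tune2fs -O ^has_journal /dev/sdc1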

                6.8 Monitoring

Kafka uses Yammer Metrics for metrics reporting in the server. The Java clients use Kafka Metrics, a built-in metrics registry that minimizes transitive dependencies pulled into client applications. Both expose metrics via JMX and can be configured to report stats using pluggable stats reporters to hook up to your monitoring system.

All Kafka rate metrics have a corresponding cumulative count metric with suffix -total. For example, records-consumed-rate has a corresponding metric named records-consumed-total.

The easiest way to see the available metrics is to fire up jconsole and point it at a running Kafka client or server; this will allow browsing all metrics with JMX.
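For example, once remote JMX has been enabled as described in the next section, a single mbean can also be polled from the command line with the bundled JmxTool class; the class name and options can differ between Kafka versions, so treat this as a sketch:

 > bin/kafka-run-class.sh kafka.tools.JmxTool \
     --jmx-url service:jmx:rmi:///jndi/rmi://localhost:9999/jmxrmi \
     --object-name kafka.server:type=BrokerTopicMetrics,name=MessagesInPerSec \
     --reporting-interval 1000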

                Security Considerations for Remote Monitoring using JMX

Apache Kafka disables remote JMX by default. You can enable remote monitoring using JMX by setting the environment variable JMX_PORT for processes started using the CLI or standard Java system properties to enable remote JMX programmatically. You must enable security when enabling remote JMX in production scenarios to ensure that unauthorized users cannot monitor or control your broker or application as well as the platform on which these are running. Note that authentication is disabled for JMX by default in Kafka and security configs must be overridden for production deployments by setting the environment variable KAFKA_JMX_OPTS for processes started using the CLI or by setting appropriate Java system properties. See Monitoring and Management Using JMX Technology for details on securing JMX.
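A sketch of enabling remote JMX with authentication and SSL through these environment variables; the port, file paths and option values are placeholders:

 > export JMX_PORT=9999
 > export KAFKA_JMX_OPTS="-Dcom.sun.management.jmxremote \
     -Dcom.sun.management.jmxremote.authenticate=true \
     -Dcom.sun.management.jmxremote.ssl=true \
     -Dcom.sun.management.jmxremote.password.file=/path/to/jmxremote.password \
     -Dcom.sun.management.jmxremote.access.file=/path/to/jmxremote.access"
 > bin/kafka-server-start.sh config/server.properties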

We do graphing and alerting on the following metrics:

Description | Mbean name | Normal value
Message in rate | kafka.server:type=BrokerTopicMetrics,name=MessagesInPerSec
Byte in rate from clients | kafka.server:type=BrokerTopicMetrics,name=BytesInPerSec
Byte in rate from other brokers | kafka.server:type=BrokerTopicMetrics,name=ReplicationBytesInPerSec
Controller Request rate from Broker | kafka.controller:type=ControllerChannelManager,name=RequestRateAndQueueTimeMs,brokerId=([0-9]+) | The rate (requests per second) at which the ControllerChannelManager takes requests from the queue of the given broker, and the time it takes for a request to stay in this queue before it is taken from the queue.
Controller Event queue size | kafka.controller:type=ControllerEventManager,name=EventQueueSize | Size of the ControllerEventManager's queue.
Controller Event queue time | kafka.controller:type=ControllerEventManager,name=EventQueueTimeMs | Time it takes for any event (except the Idle event) to wait in the ControllerEventManager's queue before being processed
Request rate | kafka.network:type=RequestMetrics,name=RequestsPerSec,request={Produce|FetchConsumer|FetchFollower},version=([0-9]+)
Error rate | kafka.network:type=RequestMetrics,name=ErrorsPerSec,request=([-.\w]+),error=([-.\w]+) | Number of errors in responses counted per-request-type, per-error-code. If a response contains multiple errors, all are counted. error=NONE indicates successful responses.
Request size in bytes | kafka.network:type=RequestMetrics,name=RequestBytes,request=([-.\w]+) | Size of requests for each request type.
Temporary memory size in bytes | kafka.network:type=RequestMetrics,name=TemporaryMemoryBytes,request={Produce|Fetch} | Temporary memory used for message format conversions and decompression.
Message conversion time | kafka.network:type=RequestMetrics,name=MessageConversionsTimeMs,request={Produce|Fetch} | Time in milliseconds spent on message format conversions.
Message conversion rate | kafka.server:type=BrokerTopicMetrics,name={Produce|Fetch}MessageConversionsPerSec,topic=([-.\w]+) | Number of records which required message format conversion.
Request Queue Size | kafka.network:type=RequestChannel,name=RequestQueueSize | Size of the request queue.
Byte out rate to clients | kafka.server:type=BrokerTopicMetrics,name=BytesOutPerSec
Byte out rate to other brokers | kafka.server:type=BrokerTopicMetrics,name=ReplicationBytesOutPerSec
Message validation failure rate due to no key specified for compacted topic | kafka.server:type=BrokerTopicMetrics,name=NoKeyCompactedTopicRecordsPerSec
Message validation failure rate due to invalid magic number | kafka.server:type=BrokerTopicMetrics,name=InvalidMagicNumberRecordsPerSec
Message validation failure rate due to incorrect crc checksum | kafka.server:type=BrokerTopicMetrics,name=InvalidMessageCrcRecordsPerSec
Message validation failure rate due to non-continuous offset or sequence number in batch | kafka.server:type=BrokerTopicMetrics,name=InvalidOffsetOrSequenceRecordsPerSec
Log flush rate and time | kafka.log:type=LogFlushStats,name=LogFlushRateAndTimeMs
# of offline log directories | kafka.log:type=LogManager,name=OfflineLogDirectoryCount | 0
Leader election rate | kafka.controller:type=ControllerStats,name=LeaderElectionRateAndTimeMs | non-zero when there are broker failures
Unclean leader election rate | kafka.controller:type=ControllerStats,name=UncleanLeaderElectionsPerSec | 0
Is controller active on broker | kafka.controller:type=KafkaController,name=ActiveControllerCount | only one broker in the cluster should have 1
Pending topic deletes | kafka.controller:type=KafkaController,name=TopicsToDeleteCount
Pending replica deletes | kafka.controller:type=KafkaController,name=ReplicasToDeleteCount
Ineligible pending topic deletes | kafka.controller:type=KafkaController,name=TopicsIneligibleToDeleteCount
Ineligible pending replica deletes | kafka.controller:type=KafkaController,name=ReplicasIneligibleToDeleteCount
# of under replicated partitions (|ISR| < |all replicas|) | kafka.server:type=ReplicaManager,name=UnderReplicatedPartitions | 0
# of under minIsr partitions (|ISR| < min.insync.replicas) | kafka.server:type=ReplicaManager,name=UnderMinIsrPartitionCount | 0
# of at minIsr partitions (|ISR| = min.insync.replicas) | kafka.server:type=ReplicaManager,name=AtMinIsrPartitionCount | 0
Partition counts | kafka.server:type=ReplicaManager,name=PartitionCount | mostly even across brokers
Offline Replica counts | kafka.server:type=ReplicaManager,name=OfflineReplicaCount | 0
Leader replica counts | kafka.server:type=ReplicaManager,name=LeaderCount | mostly even across brokers
ISR shrink rate | kafka.server:type=ReplicaManager,name=IsrShrinksPerSec | If a broker goes down, ISR for some of the partitions will shrink. When that broker is up again, ISR will be expanded once the replicas are fully caught up. Other than that, the expected value for both ISR shrink rate and expansion rate is 0.
ISR expansion rate | kafka.server:type=ReplicaManager,name=IsrExpandsPerSec | See above
Failed ISR update rate | kafka.server:type=ReplicaManager,name=FailedIsrUpdatesPerSec | 0
Max lag in messages btw follower and leader replicas | kafka.server:type=ReplicaFetcherManager,name=MaxLag,clientId=Replica | lag should be proportional to the maximum batch size of a produce request.
Lag in messages per follower replica | kafka.server:type=FetcherLagMetrics,name=ConsumerLag,clientId=([-.\w]+),topic=([-.\w]+),partition=([0-9]+) | lag should be proportional to the maximum batch size of a produce request.
Requests waiting in the producer purgatory | kafka.server:type=DelayedOperationPurgatory,name=PurgatorySize,delayedOperation=Produce | non-zero if ack=-1 is used
Requests waiting in the fetch purgatory | kafka.server:type=DelayedOperationPurgatory,name=PurgatorySize,delayedOperation=Fetch | size depends on fetch.wait.max.ms in the consumer
Request total time | kafka.network:type=RequestMetrics,name=TotalTimeMs,request={Produce|FetchConsumer|FetchFollower} | broken into queue, local, remote and response send time
Time the request waits in the request queue | kafka.network:type=RequestMetrics,name=RequestQueueTimeMs,request={Produce|FetchConsumer|FetchFollower}
Time the request is processed at the leader | kafka.network:type=RequestMetrics,name=LocalTimeMs,request={Produce|FetchConsumer|FetchFollower}
Time the request waits for the follower | kafka.network:type=RequestMetrics,name=RemoteTimeMs,request={Produce|FetchConsumer|FetchFollower} | non-zero for produce requests when ack=-1
Time the request waits in the response queue | kafka.network:type=RequestMetrics,name=ResponseQueueTimeMs,request={Produce|FetchConsumer|FetchFollower}
Time to send the response | kafka.network:type=RequestMetrics,name=ResponseSendTimeMs,request={Produce|FetchConsumer|FetchFollower}
Number of messages the consumer lags behind the producer by. Published by the consumer, not broker. | kafka.consumer:type=consumer-fetch-manager-metrics,client-id={client-id} Attribute: records-lag-max
The average fraction of time the network processors are idle | kafka.network:type=SocketServer,name=NetworkProcessorAvgIdlePercent | between 0 and 1, ideally > 0.3
The number of connections disconnected on a processor due to a client not re-authenticating and then using the connection beyond its expiration time for anything other than re-authentication | kafka.server:type=socket-server-metrics,listener=[SASL_PLAINTEXT|SASL_SSL],networkProcessor=<#>,name=expired-connections-killed-count | ideally 0 when re-authentication is enabled, implying there are no longer any older, pre-2.2.0 clients connecting to this (listener, processor) combination
The total number of connections disconnected, across all processors, due to a client not re-authenticating and then using the connection beyond its expiration time for anything other than re-authentication | kafka.network:type=SocketServer,name=ExpiredConnectionsKilledCount | ideally 0 when re-authentication is enabled, implying there are no longer any older, pre-2.2.0 clients connecting to this broker
The average fraction of time the request handler threads are idle | kafka.server:type=KafkaRequestHandlerPool,name=RequestHandlerAvgIdlePercent | between 0 and 1, ideally > 0.3
Bandwidth quota metrics per (user, client-id), user or client-id | kafka.server:type={Produce|Fetch},user=([-.\w]+),client-id=([-.\w]+) | Two attributes. throttle-time indicates the amount of time in ms the client was throttled. Ideally = 0. byte-rate indicates the data produce/consume rate of the client in bytes/sec. For (user, client-id) quotas, both user and client-id are specified. If per-client-id quota is applied to the client, user is not specified. If per-user quota is applied, client-id is not specified.
Request quota metrics per (user, client-id), user or client-id | kafka.server:type=Request,user=([-.\w]+),client-id=([-.\w]+) | Two attributes. throttle-time indicates the amount of time in ms the client was throttled. Ideally = 0. request-time indicates the percentage of time spent in broker network and I/O threads to process requests from the client group. For (user, client-id) quotas, both user and client-id are specified. If per-client-id quota is applied to the client, user is not specified. If per-user quota is applied, client-id is not specified.
Requests exempt from throttling | kafka.server:type=Request | exempt-throttle-time indicates the percentage of time spent in broker network and I/O threads to process requests that are exempt from throttling.
ZooKeeper client request latency | kafka.server:type=ZooKeeperClientMetrics,name=ZooKeeperRequestLatencyMs | Latency in milliseconds for ZooKeeper requests from the broker.
ZooKeeper connection status | kafka.server:type=SessionExpireListener,name=SessionState | Connection status of the broker's ZooKeeper session, which may be one of Disconnected|SyncConnected|AuthFailed|ConnectedReadOnly|SaslAuthenticated|Expired.
Max time to load group metadata | kafka.server:type=group-coordinator-metrics,name=partition-load-time-max | maximum time, in milliseconds, it took to load offsets and group metadata from the consumer offset partitions loaded in the last 30 seconds (including time spent waiting for the loading task to be scheduled)
Avg time to load group metadata | kafka.server:type=group-coordinator-metrics,name=partition-load-time-avg | average time, in milliseconds, it took to load offsets and group metadata from the consumer offset partitions loaded in the last 30 seconds (including time spent waiting for the loading task to be scheduled)
Max time to load transaction metadata | kafka.server:type=transaction-coordinator-metrics,name=partition-load-time-max | maximum time, in milliseconds, it took to load transaction metadata from the consumer offset partitions loaded in the last 30 seconds (including time spent waiting for the loading task to be scheduled)
Avg time to load transaction metadata | kafka.server:type=transaction-coordinator-metrics,name=partition-load-time-avg | average time, in milliseconds, it took to load transaction metadata from the consumer offset partitions loaded in the last 30 seconds (including time spent waiting for the loading task to be scheduled)
Consumer Group Offset Count | kafka.server:type=GroupMetadataManager,name=NumOffsets | Total number of committed offsets for Consumer Groups
Consumer Group Count | kafka.server:type=GroupMetadataManager,name=NumGroups | Total number of Consumer Groups
Consumer Group Count, per State | kafka.server:type=GroupMetadataManager,name=NumGroups[PreparingRebalance,CompletingRebalance,Empty,Stable,Dead] | The number of Consumer Groups in each state: PreparingRebalance, CompletingRebalance, Empty, Stable, Dead
Number of reassigning partitions | kafka.server:type=ReplicaManager,name=ReassigningPartitions | The number of reassigning leader partitions on a broker.
Outgoing byte rate of reassignment traffic | kafka.server:type=BrokerTopicMetrics,name=ReassignmentBytesOutPerSec
Incoming byte rate of reassignment traffic | kafka.server:type=BrokerTopicMetrics,name=ReassignmentBytesInPerSec
Size of a partition on disk (in bytes) | kafka.log:type=Log,name=Size,topic=([-.\w]+),partition=([0-9]+) | The size of a partition on disk, measured in bytes.
Number of log segments in a partition | kafka.log:type=Log,name=NumLogSegments,topic=([-.\w]+),partition=([0-9]+) | The number of log segments in a partition.
First offset in a partition | kafka.log:type=Log,name=LogStartOffset,topic=([-.\w]+),partition=([0-9]+) | The first offset in a partition.
Last offset in a partition | kafka.log:type=Log,name=LogEndOffset,topic=([-.\w]+),partition=([0-9]+) | The last offset in a partition.

                Common monitoring metrics for producer/consumer/connect/streams

The following metrics are available on producer/consumer/connect/streams instances. For specific metrics, please see the following sections.

All of the metrics below are reported under the mbean kafka.[producer|consumer|connect]:type=[producer|consumer|connect]-metrics,client-id=([-.\w]+).

Metric/Attribute name | Description
connection-close-rate | Connections closed per second in the window.
connection-close-total | Total connections closed in the window.
connection-creation-rate | New connections established per second in the window.
connection-creation-total | Total new connections established in the window.
network-io-rate | The average number of network operations (reads or writes) on all connections per second.
network-io-total | The total number of network operations (reads or writes) on all connections.
outgoing-byte-rate | The average number of outgoing bytes sent per second to all servers.
outgoing-byte-total | The total number of outgoing bytes sent to all servers.
request-rate | The average number of requests sent per second.
request-total | The total number of requests sent.
request-size-avg | The average size of all requests in the window.
request-size-max | The maximum size of any request sent in the window.
incoming-byte-rate | Bytes/second read off all sockets.
incoming-byte-total | Total bytes read off all sockets.
response-rate | Responses received per second.
response-total | Total responses received.
select-rate | Number of times the I/O layer checked for new I/O to perform per second.
select-total | Total number of times the I/O layer checked for new I/O to perform.
io-wait-time-ns-avg | The average length of time the I/O thread spent waiting for a socket ready for reads or writes in nanoseconds.
io-wait-time-ns-total | The total time the I/O thread spent waiting in nanoseconds.
io-waittime-total | *Deprecated* The total time the I/O thread spent waiting in nanoseconds. Replacement is io-wait-time-ns-total.
io-wait-ratio | The fraction of time the I/O thread spent waiting.
io-time-ns-avg | The average length of time for I/O per select call in nanoseconds.
io-time-ns-total | The total time the I/O thread spent doing I/O in nanoseconds.
iotime-total | *Deprecated* The total time the I/O thread spent doing I/O in nanoseconds. Replacement is io-time-ns-total.
io-ratio | The fraction of time the I/O thread spent doing I/O.
connection-count | The current number of active connections.
successful-authentication-rate | Connections per second that were successfully authenticated using SASL or SSL.
successful-authentication-total | Total connections that were successfully authenticated using SASL or SSL.
failed-authentication-rate | Connections per second that failed authentication.
failed-authentication-total | Total connections that failed authentication.
successful-reauthentication-rate | Connections per second that were successfully re-authenticated using SASL.
successful-reauthentication-total | Total connections that were successfully re-authenticated using SASL.
reauthentication-latency-max | The maximum latency in ms observed due to re-authentication.
reauthentication-latency-avg | The average latency in ms observed due to re-authentication.
failed-reauthentication-rate | Connections per second that failed re-authentication.
failed-reauthentication-total | Total connections that failed re-authentication.
successful-authentication-no-reauth-total | Total connections that were successfully authenticated by older, pre-2.2.0 SASL clients that do not support re-authentication. May only be non-zero.

                Common Per-broker metrics for producer/consumer/connect/streams

The following metrics are available on producer/consumer/connect/streams instances. For specific metrics, please see the following sections.

All of the metrics below are reported under the mbean kafka.[producer|consumer|connect]:type=[consumer|producer|connect]-node-metrics,client-id=([-.\w]+),node-id=([0-9]+).

Metric/Attribute name | Description
outgoing-byte-rate | The average number of outgoing bytes sent per second for a node.
outgoing-byte-total | The total number of outgoing bytes sent for a node.
request-rate | The average number of requests sent per second for a node.
request-total | The total number of requests sent for a node.
request-size-avg | The average size of all requests in the window for a node.
request-size-max | The maximum size of any request sent in the window for a node.
incoming-byte-rate | The average number of bytes received per second for a node.
incoming-byte-total | The total number of bytes received for a node.
request-latency-avg | The average request latency in ms for a node.
request-latency-max | The maximum request latency in ms for a node.
response-rate | Responses received per second for a node.
response-total | Total responses received for a node.

                Producer monitoring

The following metrics are available on producer instances.

All of the metrics below are reported under the mbean kafka.producer:type=producer-metrics,client-id=([-.\w]+).

Metric/Attribute name | Description
waiting-threads | The number of user threads blocked waiting for buffer memory to enqueue their records.
buffer-total-bytes | The maximum amount of buffer memory the client can use (whether or not it is currently used).
buffer-available-bytes | The total amount of buffer memory that is not being used (either unallocated or in the free list).
bufferpool-wait-time | The fraction of time an appender waits for space allocation.
bufferpool-wait-time-total | *Deprecated* The total time an appender waits for space allocation in nanoseconds. Replacement is bufferpool-wait-time-ns-total.
bufferpool-wait-time-ns-total | The total time an appender waits for space allocation in nanoseconds.
                Producer Sender Metrics

Consumer monitoring

The following metrics are available on consumer instances.

All of the metrics below are reported under the mbean kafka.consumer:type=consumer-metrics,client-id=([-.\w]+).

Metric/Attribute name | Description
time-between-poll-avg | The average delay between invocations of poll().
time-between-poll-max | The max delay between invocations of poll().
last-poll-seconds-ago | The number of seconds since the last poll() invocation.
poll-idle-ratio-avg | The average fraction of time the consumer's poll() is idle as opposed to waiting for the user code to process records.
                Consumer Group Metrics
All of the metrics below are reported under the mbean kafka.consumer:type=consumer-coordinator-metrics,client-id=([-.\w]+).

Metric/Attribute name | Description
commit-latency-avg | The average time taken for a commit request
commit-latency-max | The max time taken for a commit request
commit-rate | The number of commit calls per second
commit-total | The total number of commit calls
assigned-partitions | The number of partitions currently assigned to this consumer
heartbeat-response-time-max | The max time taken to receive a response to a heartbeat request
heartbeat-rate | The average number of heartbeats per second
heartbeat-total | The total number of heartbeats
join-time-avg | The average time taken for a group rejoin
join-time-max | The max time taken for a group rejoin
join-rate | The number of group joins per second
join-total | The total number of group joins
sync-time-avg | The average time taken for a group sync
sync-time-max | The max time taken for a group sync
sync-rate | The number of group syncs per second
sync-total | The total number of group syncs
rebalance-latency-avg | The average time taken for a group rebalance
rebalance-latency-max | The max time taken for a group rebalance
rebalance-latency-total | The total time taken for group rebalances so far
rebalance-total | The total number of group rebalances participated in
rebalance-rate-per-hour | The number of group rebalances participated in per hour
failed-rebalance-total | The total number of failed group rebalances
failed-rebalance-rate-per-hour | The number of failed group rebalance events per hour
last-rebalance-seconds-ago | The number of seconds since the last rebalance event
last-heartbeat-seconds-ago | The number of seconds since the last controller heartbeat
partitions-revoked-latency-avg | The average time taken by the on-partitions-revoked rebalance listener callback
partitions-revoked-latency-max | The max time taken by the on-partitions-revoked rebalance listener callback
partitions-assigned-latency-avg | The average time taken by the on-partitions-assigned rebalance listener callback
partitions-assigned-latency-max | The max time taken by the on-partitions-assigned rebalance listener callback
partitions-lost-latency-avg | The average time taken by the on-partitions-lost rebalance listener callback
partitions-lost-latency-max | The max time taken by the on-partitions-lost rebalance listener callback
                Consumer Fetch Metrics

                Connect Monitoring

A Connect worker process contains all the producer and consumer metrics as well as metrics specific to Connect. The worker process itself has a number of metrics, while each connector and task have additional metrics.

                Streams Monitoring

A Kafka Streams instance contains all the producer and consumer metrics as well as additional metrics specific to Streams. By default Kafka Streams has metrics with three recording levels: info, debug, and trace.

Note that the metrics have a 4-layer hierarchy. At the top level there are client-level metrics for each started Kafka Streams client. Each client has stream threads, with their own metrics. Each stream thread has tasks, with their own metrics. Each task has a number of processor nodes, with their own metrics. Each task also has a number of state stores and record caches, all with their own metrics.

Use the following configuration option to specify which metrics you want collected:
                metrics.recording.level="info"
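For example, to also collect the debug-level task and state-store metrics described below, the level can be raised in the application's configuration (the same key is exposed programmatically as StreamsConfig.METRICS_RECORDING_LEVEL_CONFIG):

 # streams application properties
 metrics.recording.level=debug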
                Client Metrics
All of the following metrics have a recording level of info and are reported under the mbean kafka.streams:type=stream-metrics,client-id=([-.\w]+).

Metric/Attribute name | Description
version | The version of the Kafka Streams client.
commit-id | The version control commit ID of the Kafka Streams client.
application-id | The application ID of the Kafka Streams client.
topology-description | The description of the topology executed in the Kafka Streams client.
state | The state of the Kafka Streams client.
failed-stream-threads | The number of failed stream threads since the start of the Kafka Streams client.
                Thread Metrics
All of the following metrics have a recording level of info and are reported under the mbean kafka.streams:type=stream-thread-metrics,thread-id=([-.\w]+).

Metric/Attribute name | Description
commit-latency-avg | The average execution time in ms, for committing, across all running tasks of this thread.
commit-latency-max | The maximum execution time in ms, for committing, across all running tasks of this thread.
poll-latency-avg | The average execution time in ms, for consumer polling.
poll-latency-max | The maximum execution time in ms, for consumer polling.
process-latency-avg | The average execution time in ms, for processing.
process-latency-max | The maximum execution time in ms, for processing.
punctuate-latency-avg | The average execution time in ms, for punctuating.
punctuate-latency-max | The maximum execution time in ms, for punctuating.
commit-rate | The average number of commits per second.
commit-total | The total number of commit calls.
poll-rate | The average number of consumer poll calls per second.
poll-total | The total number of consumer poll calls.
process-rate | The average number of processed records per second.
process-total | The total number of processed records.
punctuate-rate | The average number of punctuate calls per second.
punctuate-total | The total number of punctuate calls.
task-created-rate | The average number of tasks created per second.
task-created-total | The total number of tasks created.
task-closed-rate | The average number of tasks closed per second.
task-closed-total | The total number of tasks closed.
                Task Metrics
All of the following metrics have a recording level of debug, except for the dropped-records-* and active-process-ratio metrics which have a recording level of info. They are reported under the mbean kafka.streams:type=stream-task-metrics,thread-id=([-.\w]+),task-id=([-.\w]+).

Metric/Attribute name | Description
process-latency-avg | The average execution time in ns, for processing.
process-latency-max | The maximum execution time in ns, for processing.
process-rate | The average number of processed records per second across all source processor nodes of this task.
process-total | The total number of processed records across all source processor nodes of this task.
commit-latency-avg | The average execution time in ns, for committing.
commit-latency-max | The maximum execution time in ns, for committing.
commit-rate | The average number of commit calls per second.
commit-total | The total number of commit calls.
record-lateness-avg | The average observed lateness of records (stream time - record timestamp).
record-lateness-max | The max observed lateness of records (stream time - record timestamp).
enforced-processing-rate | The average number of enforced processings per second.
enforced-processing-total | The total number of enforced processings.
dropped-records-rate | The average number of records dropped within this task.
dropped-records-total | The total number of records dropped within this task.
active-process-ratio | The fraction of time the stream thread spent on processing this task among all assigned active tasks.
                Processor Node Metrics
The following metrics are only available on certain types of nodes: the process-* metrics are only available for source processor nodes, the suppression-emit-* metrics are only available for suppression operation nodes, and the record-e2e-latency-* metrics are only available for source processor nodes and terminal nodes (nodes without successor nodes). All of the metrics have a recording level of debug, except for the record-e2e-latency-* metrics which have a recording level of info. They are reported under the mbean kafka.streams:type=stream-processor-node-metrics,thread-id=([-.\w]+),task-id=([-.\w]+),processor-node-id=([-.\w]+).

Metric/Attribute name | Description
process-rate | The average number of records processed by a source processor node per second.
process-total | The total number of records processed by a source processor node.
suppression-emit-rate | The rate at which records are emitted downstream from suppression operation nodes.
suppression-emit-total | The total number of records that have been emitted downstream from suppression operation nodes.
record-e2e-latency-avg | The average end-to-end latency of a record, measured by comparing the record timestamp with the system time when it has been fully processed by the node.
record-e2e-latency-max | The maximum end-to-end latency of a record, measured by comparing the record timestamp with the system time when it has been fully processed by the node.
record-e2e-latency-min | The minimum end-to-end latency of a record, measured by comparing the record timestamp with the system time when it has been fully processed by the node.
                State Store Metrics
All of the following metrics have a recording level of debug, except for the record-e2e-latency-* metrics which have a recording level of trace. Note that the store-scope value is specified in StoreSupplier#metricsScope() for users' customized state stores; for built-in state stores, currently we have:

• in-memory-state
• in-memory-lru-state
• in-memory-window-state
• in-memory-suppression (for suppression buffers)
• rocksdb-state (for RocksDB backed key-value store)
• rocksdb-window-state (for RocksDB backed window store)
• rocksdb-session-state (for RocksDB backed session store)

Metrics suppression-buffer-size-avg, suppression-buffer-size-max, suppression-buffer-count-avg, and suppression-buffer-count-max are only available for suppression buffers. All other metrics are not available for suppression buffers.
All of the metrics below are reported under the mbean kafka.streams:type=stream-state-metrics,thread-id=([-.\w]+),task-id=([-.\w]+),[store-scope]-id=([-.\w]+).

Metric/Attribute name | Description
put-latency-avg | The average put execution time in ns.
put-latency-max | The maximum put execution time in ns.
put-if-absent-latency-avg | The average put-if-absent execution time in ns.
put-if-absent-latency-max | The maximum put-if-absent execution time in ns.
get-latency-avg | The average get execution time in ns.
get-latency-max | The maximum get execution time in ns.
delete-latency-avg | The average delete execution time in ns.
delete-latency-max | The maximum delete execution time in ns.
put-all-latency-avg | The average put-all execution time in ns.
put-all-latency-max | The maximum put-all execution time in ns.
all-latency-avg | The average all operation execution time in ns.
all-latency-max | The maximum all operation execution time in ns.
range-latency-avg | The average range execution time in ns.
range-latency-max | The maximum range execution time in ns.
flush-latency-avg | The average flush execution time in ns.
flush-latency-max | The maximum flush execution time in ns.
restore-latency-avg | The average restore execution time in ns.
restore-latency-max | The maximum restore execution time in ns.
put-rate | The average put rate for this store.
put-if-absent-rate | The average put-if-absent rate for this store.
get-rate | The average get rate for this store.
                delete-rateThe average delete rate for this store.kafka.streams:type=stream-state-metrics,thread-id=([-.\w]+),task-id=([-.\w]+),[store-scope]-id=([-.\w]+)
                put-all-rateThe average put-all rate for this store.kafka.streams:type=stream-state-metrics,thread-id=([-.\w]+),task-id=([-.\w]+),[store-scope]-id=([-.\w]+)
                all-rateThe average all operation rate for this store.kafka.streams:type=stream-state-metrics,thread-id=([-.\w]+),task-id=([-.\w]+),[store-scope]-id=([-.\w]+)
                range-rateThe average range rate for this store.kafka.streams:type=stream-state-metrics,thread-id=([-.\w]+),task-id=([-.\w]+),[store-scope]-id=([-.\w]+)
                flush-rateThe average flush rate for this store.kafka.streams:type=stream-state-metrics,thread-id=([-.\w]+),task-id=([-.\w]+),[store-scope]-id=([-.\w]+)
                restore-rateThe average restore rate for this store.kafka.streams:type=stream-state-metrics,thread-id=([-.\w]+),task-id=([-.\w]+),[store-scope]-id=([-.\w]+)
                suppression-buffer-size-avgThe average total size, in bytes, of the buffered data over the sampling window.kafka.streams:type=stream-state-metrics,thread-id=([-.\w]+),task-id=([-.\w]+),in-memory-suppression-id=([-.\w]+)
                suppression-buffer-size-maxThe maximum total size, in bytes, of the buffered data over the sampling window.kafka.streams:type=stream-state-metrics,thread-id=([-.\w]+),task-id=([-.\w]+),in-memory-suppression-id=([-.\w]+)
                suppression-buffer-count-avgThe average number of records buffered over the sampling window.kafka.streams:type=stream-state-metrics,thread-id=([-.\w]+),task-id=([-.\w]+),in-memory-suppression-id=([-.\w]+)
                suppression-buffer-count-maxThe maximum number of records buffered over the sampling window.kafka.streams:type=stream-state-metrics,thread-id=([-.\w]+),task-id=([-.\w]+),in-memory-suppression-id=([-.\w]+)
                record-e2e-latency-avgThe average end-to-end latency of a record, measured by comparing the record timestamp with the system time when it has been fully processed by the node.kafka.streams:type=stream-state-metrics,thread-id=([-.\w]+),task-id=([-.\w]+),[store-scope]-id=([-.\w]+)
                record-e2e-latency-maxThe maximum end-to-end latency of a record, measured by comparing the record timestamp with the system time when it has been fully processed by the node.kafka.streams:type=stream-state-metrics,thread-id=([-.\w]+),task-id=([-.\w]+),[store-scope]-id=([-.\w]+)
                record-e2e-latency-minThe minimum end-to-end latency of a record, measured by comparing the record timestamp with the system time when it has been fully processed by the node.kafka.streams:type=stream-state-metrics,thread-id=([-.\w]+),task-id=([-.\w]+),[store-scope]-id=([-.\w]+)
                + +
                RocksDB Metrics
RocksDB metrics are grouped into statistics-based metrics and properties-based metrics. The former are recorded from statistics that a RocksDB state store collects, whereas the latter are recorded from properties that RocksDB exposes. Statistics collected by RocksDB provide cumulative measurements over time, e.g. bytes written to the state store. Properties exposed by RocksDB provide current measurements, e.g. the amount of memory currently used. Note that the store-scope for built-in RocksDB state stores is currently one of the following:
• rocksdb-state (for RocksDB backed key-value store)
• rocksdb-window-state (for RocksDB backed window store)
• rocksdb-session-state (for RocksDB backed session store)
RocksDB Statistics-based Metrics: All of the following statistics-based metrics have a recording level of debug because collecting statistics in RocksDB may have an impact on performance. Statistics-based metrics are collected every minute from the RocksDB state stores. If a state store consists of multiple RocksDB instances, as is the case for WindowStores and SessionStores, each metric reports an aggregation over the RocksDB instances of the state store.
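Because these metrics are only recorded at the debug level, the recording level has to be raised in the application configuration before they appear. A minimal sketch, assuming a hypothetical application id and broker address; metrics.recording.level is the standard configuration key for this setting.

import java.util.Properties;
import org.apache.kafka.streams.StreamsConfig;

public class DebugMetricsConfig {
    public static void main(String[] args) {
        Properties props = new Properties();
        props.put(StreamsConfig.APPLICATION_ID_CONFIG, "my-streams-app");    // hypothetical application id
        props.put(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9092"); // hypothetical broker address
        // Debug-level metrics (including the statistics-based RocksDB metrics below) are only
        // recorded when the recording level is raised from its default of "info".
        props.put("metrics.recording.level", "debug");
        // props would then be passed to new KafkaStreams(topology, props) as usual.
    }
}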
                Metric/Attribute nameDescriptionMbean name
                bytes-written-rateThe average number of bytes written per second to the RocksDB state store.kafka.streams:type=stream-state-metrics,thread-id=([-.\w]+),task-id=([-.\w]+),[store-scope]-id=([-.\w]+)
                bytes-written-totalThe total number of bytes written to the RocksDB state store.kafka.streams:type=stream-state-metrics,thread-id=([-.\w]+),task-id=([-.\w]+),[store-scope]-id=([-.\w]+)
                bytes-read-rateThe average number of bytes read per second from the RocksDB state store.kafka.streams:type=stream-state-metrics,thread-id=([-.\w]+),task-id=([-.\w]+),[store-scope]-id=([-.\w]+)
                bytes-read-totalThe total number of bytes read from the RocksDB state store.kafka.streams:type=stream-state-metrics,thread-id=([-.\w]+),task-id=([-.\w]+),[store-scope]-id=([-.\w]+)
                memtable-bytes-flushed-rateThe average number of bytes flushed per second from the memtable to disk.kafka.streams:type=stream-state-metrics,thread-id=([-.\w]+),task-id=([-.\w]+),[store-scope]-id=([-.\w]+)
                memtable-bytes-flushed-totalThe total number of bytes flushed from the memtable to disk.kafka.streams:type=stream-state-metrics,thread-id=([-.\w]+),task-id=([-.\w]+),[store-scope]-id=([-.\w]+)
                memtable-hit-ratioThe ratio of memtable hits relative to all lookups to the memtable.kafka.streams:type=stream-state-metrics,thread-id=([-.\w]+),task-id=([-.\w]+),[store-scope]-id=([-.\w]+)
                block-cache-data-hit-ratioThe ratio of block cache hits for data blocks relative to all lookups for data blocks to the block cache.kafka.streams:type=stream-state-metrics,thread-id=([-.\w]+),task-id=([-.\w]+),[store-scope]-id=([-.\w]+)
                block-cache-index-hit-ratioThe ratio of block cache hits for index blocks relative to all lookups for index blocks to the block cache.kafka.streams:type=stream-state-metrics,thread-id=([-.\w]+),task-id=([-.\w]+),[store-scope]-id=([-.\w]+)
                block-cache-filter-hit-ratioThe ratio of block cache hits for filter blocks relative to all lookups for filter blocks to the block cache.kafka.streams:type=stream-state-metrics,thread-id=([-.\w]+),task-id=([-.\w]+),[store-scope]-id=([-.\w]+)
                write-stall-duration-avgThe average duration of write stalls in ms.kafka.streams:type=stream-state-metrics,thread-id=([-.\w]+),task-id=([-.\w]+),[store-scope]-id=([-.\w]+)
                write-stall-duration-totalThe total duration of write stalls in ms.kafka.streams:type=stream-state-metrics,thread-id=([-.\w]+),task-id=([-.\w]+),[store-scope]-id=([-.\w]+)
                bytes-read-compaction-rateThe average number of bytes read per second during compaction.kafka.streams:type=stream-state-metrics,thread-id=([-.\w]+),task-id=([-.\w]+),[store-scope]-id=([-.\w]+)
                bytes-written-compaction-rateThe average number of bytes written per second during compaction.kafka.streams:type=stream-state-metrics,thread-id=([-.\w]+),task-id=([-.\w]+),[store-scope]-id=([-.\w]+)
number-open-filesThe number of currently open files.kafka.streams:type=stream-state-metrics,thread-id=([-.\w]+),task-id=([-.\w]+),[store-scope]-id=([-.\w]+)
number-file-errors-totalThe total number of file errors that have occurred.kafka.streams:type=stream-state-metrics,thread-id=([-.\w]+),task-id=([-.\w]+),[store-scope]-id=([-.\w]+)
RocksDB Properties-based Metrics: All of the following properties-based metrics have a recording level of info and are recorded when the metrics are accessed. If a state store consists of multiple RocksDB instances, as is the case for WindowStores and SessionStores, each metric reports the sum over all the RocksDB instances of the state store, except for the block cache metrics block-cache-*. The block cache metrics report the sum over all RocksDB instances if each instance uses its own block cache, and they report the recorded value from only one instance if a single block cache is shared among all instances.
                Metric/Attribute nameDescriptionMbean name
                num-immutable-mem-tableThe number of immutable memtables that have not yet been flushed.kafka.streams:type=stream-state-metrics,thread-id=([-.\w]+),task-id=([-.\w]+),[store-scope]-id=([-.\w]+)
                cur-size-active-mem-tableThe approximate size of the active memtable in bytes.kafka.streams:type=stream-state-metrics,thread-id=([-.\w]+),task-id=([-.\w]+),[store-scope]-id=([-.\w]+)
                cur-size-all-mem-tablesThe approximate size of active and unflushed immutable memtables in bytes.kafka.streams:type=stream-state-metrics,thread-id=([-.\w]+),task-id=([-.\w]+),[store-scope]-id=([-.\w]+)
                size-all-mem-tablesThe approximate size of active, unflushed immutable, and pinned immutable memtables in bytes.kafka.streams:type=stream-state-metrics,thread-id=([-.\w]+),task-id=([-.\w]+),[store-scope]-id=([-.\w]+)
                num-entries-active-mem-tableThe number of entries in the active memtable.kafka.streams:type=stream-state-metrics,thread-id=([-.\w]+),task-id=([-.\w]+),[store-scope]-id=([-.\w]+)
                num-entries-imm-mem-tablesThe number of entries in the unflushed immutable memtables.kafka.streams:type=stream-state-metrics,thread-id=([-.\w]+),task-id=([-.\w]+),[store-scope]-id=([-.\w]+)
                num-deletes-active-mem-tableThe number of delete entries in the active memtable.kafka.streams:type=stream-state-metrics,thread-id=([-.\w]+),task-id=([-.\w]+),[store-scope]-id=([-.\w]+)
                num-deletes-imm-mem-tablesThe number of delete entries in the unflushed immutable memtables.kafka.streams:type=stream-state-metrics,thread-id=([-.\w]+),task-id=([-.\w]+),[store-scope]-id=([-.\w]+)
                mem-table-flush-pendingThis metric reports 1 if a memtable flush is pending, otherwise it reports 0.kafka.streams:type=stream-state-metrics,thread-id=([-.\w]+),task-id=([-.\w]+),[store-scope]-id=([-.\w]+)
                num-running-flushesThe number of currently running flushes.kafka.streams:type=stream-state-metrics,thread-id=([-.\w]+),task-id=([-.\w]+),[store-scope]-id=([-.\w]+)
                compaction-pendingThis metric reports 1 if at least one compaction is pending, otherwise it reports 0.kafka.streams:type=stream-state-metrics,thread-id=([-.\w]+),task-id=([-.\w]+),[store-scope]-id=([-.\w]+)
                num-running-compactionsThe number of currently running compactions.kafka.streams:type=stream-state-metrics,thread-id=([-.\w]+),task-id=([-.\w]+),[store-scope]-id=([-.\w]+)
                estimate-pending-compaction-bytesThe estimated total number of bytes a compaction needs to rewrite on disk to get all levels down to under + target size (only valid for level compaction).kafka.streams:type=stream-state-metrics,thread-id=([-.\w]+),task-id=([-.\w]+),[store-scope]-id=([-.\w]+)
                total-sst-files-sizeThe total size in bytes of all SST files.kafka.streams:type=stream-state-metrics,thread-id=([-.\w]+),task-id=([-.\w]+),[store-scope]-id=([-.\w]+)
                live-sst-files-sizeThe total size in bytes of all SST files that belong to the latest LSM tree.kafka.streams:type=stream-state-metrics,thread-id=([-.\w]+),task-id=([-.\w]+),[store-scope]-id=([-.\w]+)
                num-live-versionsNumber of live versions of the LSM tree.kafka.streams:type=stream-state-metrics,thread-id=([-.\w]+),task-id=([-.\w]+),[store-scope]-id=([-.\w]+)
                block-cache-capacityThe capacity of the block cache in bytes.kafka.streams:type=stream-state-metrics,thread-id=([-.\w]+),task-id=([-.\w]+),[store-scope]-id=([-.\w]+)
                block-cache-usageThe memory size of the entries residing in block cache in bytes.kafka.streams:type=stream-state-metrics,thread-id=([-.\w]+),task-id=([-.\w]+),[store-scope]-id=([-.\w]+)
                block-cache-pinned-usageThe memory size for the entries being pinned in the block cache in bytes.kafka.streams:type=stream-state-metrics,thread-id=([-.\w]+),task-id=([-.\w]+),[store-scope]-id=([-.\w]+)
                estimate-num-keysThe estimated number of keys in the active and unflushed immutable memtables and storage.kafka.streams:type=stream-state-metrics,thread-id=([-.\w]+),task-id=([-.\w]+),[store-scope]-id=([-.\w]+)
                estimate-table-readers-memThe estimated memory in bytes used for reading SST tables, excluding memory used in block cache.kafka.streams:type=stream-state-metrics,thread-id=([-.\w]+),task-id=([-.\w]+),[store-scope]-id=([-.\w]+)
                background-errorsThe total number of background errors.kafka.streams:type=stream-state-metrics,thread-id=([-.\w]+),task-id=([-.\w]+),[store-scope]-id=([-.\w]+)
                + +
                Record Cache Metrics
All of the following metrics have a recording level of debug:
                Metric/Attribute nameDescriptionMbean name
                hit-ratio-avgThe average cache hit ratio defined as the ratio of cache read hits over the total cache read requests.kafka.streams:type=stream-record-cache-metrics,thread-id=([-.\w]+),task-id=([-.\w]+),record-cache-id=([-.\w]+)
hit-ratio-minThe minimum cache hit ratio.kafka.streams:type=stream-record-cache-metrics,thread-id=([-.\w]+),task-id=([-.\w]+),record-cache-id=([-.\w]+)
                hit-ratio-maxThe maximum cache hit ratio.kafka.streams:type=stream-record-cache-metrics,thread-id=([-.\w]+),task-id=([-.\w]+),record-cache-id=([-.\w]+)
                + +

                Others

We recommend monitoring GC time and other JVM stats, as well as various server stats such as CPU utilization, I/O service time, etc.

On the client side, we recommend monitoring the message/byte rate (global and per topic), request rate/size/time, and, on the consumer side, max lag in messages among all partitions and min fetch request rate. For a consumer to keep up, max lag needs to be less than a threshold and min fetch rate needs to be larger than 0.
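The same client-side numbers that are published over JMX are also available in-process. A minimal sketch of dumping fetch-related consumer metrics, assuming a hypothetical broker address and group id; the filter on the metric group name is illustrative.

import java.util.Map;
import java.util.Properties;
import org.apache.kafka.clients.consumer.ConsumerConfig;
import org.apache.kafka.clients.consumer.KafkaConsumer;
import org.apache.kafka.common.Metric;
import org.apache.kafka.common.MetricName;

public class ConsumerMetricsProbe {
    public static void main(String[] args) {
        Properties props = new Properties();
        props.put(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9092"); // hypothetical broker address
        props.put(ConsumerConfig.GROUP_ID_CONFIG, "metrics-probe");           // hypothetical group id
        props.put(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG,
                  "org.apache.kafka.common.serialization.StringDeserializer");
        props.put(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG,
                  "org.apache.kafka.common.serialization.StringDeserializer");

        try (KafkaConsumer<String, String> consumer = new KafkaConsumer<>(props)) {
            // Lag and fetch-rate metrics only become meaningful once the consumer is subscribed and polling;
            // this simply shows how the in-process metrics map is read.
            for (Map.Entry<MetricName, ? extends Metric> entry : consumer.metrics().entrySet()) {
                MetricName name = entry.getKey();
                if (name.group().contains("fetch")) {
                    System.out.println(name.name() + " = " + entry.getValue().metricValue());
                }
            }
        }
    }
}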

                6.9 ZooKeeper

                + +

                Stable version

The current stable branch is 3.5. Kafka is regularly updated to include the latest release in the 3.5 series.

                Operationalizing ZooKeeper

Operationally, we do the following for a healthy ZooKeeper installation:
• Redundancy in the physical/hardware/network layout: try not to put them all in the same rack, use decent (but don't go nuts) hardware, and try to keep redundant power and network paths, etc. A typical ZooKeeper ensemble has 5 or 7 servers, which tolerates 2 and 3 servers down, respectively. If you have a small deployment, then using 3 servers is acceptable, but keep in mind that you'll only be able to tolerate 1 server down in this case.
• I/O segregation: if you do a lot of write type traffic you'll almost definitely want the transaction logs on a dedicated disk group. Writes to the transaction log are synchronous (but batched for performance), and consequently, concurrent writes can significantly affect performance. ZooKeeper snapshots can be one such source of concurrent writes, and ideally should be written on a disk group separate from the transaction log. Snapshots are written to disk asynchronously, so it is typically ok to share with the operating system and message log files. You can configure a server to use a separate disk group with the dataLogDir parameter.
• Application segregation: Unless you really understand the application patterns of other apps that you want to install on the same box, it can be a good idea to run ZooKeeper in isolation (though this can be a balancing act with the capabilities of the hardware).
• Use care with virtualization: It can work, depending on your cluster layout, read/write patterns, and SLAs, but the tiny overheads introduced by the virtualization layer can add up and throw off ZooKeeper, as it can be very time sensitive.
• ZooKeeper configuration: It's Java, so make sure you give it 'enough' heap space (we usually run them with 3-5G, but that's mostly due to the data set size we have here). Unfortunately we don't have a good formula for it, but keep in mind that allowing for more ZooKeeper state means that snapshots can become large, and large snapshots affect recovery time. In fact, if the snapshot becomes too large (a few gigabytes), then you may need to increase the initLimit parameter to give enough time for servers to recover and join the ensemble.
• Monitoring: Both JMX and the four-letter-word (4lw) commands are very useful; they do overlap in some cases (and in those cases we prefer the four-letter commands, as they seem more predictable, or at the very least, they work better with the LI monitoring infrastructure).
• Don't overbuild the cluster: large clusters, especially in a write-heavy usage pattern, mean a lot of intra-cluster communication (quorums on the writes and subsequent cluster member updates), but don't underbuild it (and risk swamping the cluster). Having more servers adds to your read capacity.
Overall, we try to keep the ZooKeeper system as small as will handle the load (plus standard growth capacity planning) and as simple as possible. We try not to do anything fancy with the configuration or application layout as compared to the official release, and we keep it as self-contained as possible. For these reasons, we tend to skip the OS packaged versions, since they have a tendency to try to put things in the OS standard hierarchy, which can be 'messy', for want of a better way to word it.
                diff --git a/docs/protocol.html b/docs/protocol.html new file mode 100644 index 0000000..29811a2 --- /dev/null +++ b/docs/protocol.html @@ -0,0 +1,226 @@ + + + + +
                + +
                +

                Kafka protocol guide

                + +

This document covers the wire protocol implemented in Kafka. It is meant to give a readable guide to the protocol that covers the available requests, their binary format, and the proper way to make use of them to implement a client. This document assumes you understand the basic design and terminology described here.

                + + + +

                Preliminaries

                + +
                Network
                + +

                Kafka uses a binary protocol over TCP. The protocol defines all APIs as request response message pairs. All messages are size delimited and are made up of the following primitive types.

                + +

                The client initiates a socket connection and then writes a sequence of request messages and reads back the corresponding response message. No handshake is required on connection or disconnection. TCP is happier if you maintain persistent connections used for many requests to amortize the cost of the TCP handshake, but beyond this penalty connecting is pretty cheap.

                + +

                The client will likely need to maintain a connection to multiple brokers, as data is partitioned and the clients will need to talk to the server that has their data. However it should not generally be necessary to maintain multiple connections to a single broker from a single client instance (i.e. connection pooling).

                + +

                The server guarantees that on a single TCP connection, requests will be processed in the order they are sent and responses will return in that order as well. The broker's request processing allows only a single in-flight request per connection in order to guarantee this ordering. Note that clients can (and ideally should) use non-blocking IO to implement request pipelining and achieve higher throughput. i.e., clients can send requests even while awaiting responses for preceding requests since the outstanding requests will be buffered in the underlying OS socket buffer. All requests are initiated by the client, and result in a corresponding response message from the server except where noted.

                + +

                The server has a configurable maximum limit on request size and any request that exceeds this limit will result in the socket being disconnected.

                + +
                Partitioning and bootstrapping
                + +

                Kafka is a partitioned system so not all servers have the complete data set. Instead recall that topics are split into a pre-defined number of partitions, P, and each partition is replicated with some replication factor, N. Topic partitions themselves are just ordered "commit logs" numbered 0, 1, ..., P-1.

                + +

All systems of this nature have the question of how a particular piece of data is assigned to a particular partition. Kafka clients directly control this assignment; the brokers themselves enforce no particular semantics of which messages should be published to a particular partition. Rather, to publish messages the client directly addresses messages to a particular partition, and when fetching messages, fetches from a particular partition. If two clients want to use the same partitioning scheme they must use the same method to compute the mapping of key to partition.

                + +

These requests to publish or fetch data must be sent to the broker that is currently acting as the leader for a given partition. This condition is enforced by the broker, so a request for a particular partition to the wrong broker will result in the NotLeaderForPartition error code (described below).

                + +

                How can the client find out which topics exist, what partitions they have, and which brokers currently host those partitions so that it can direct its requests to the right hosts? This information is dynamic, so you can't just configure each client with some static mapping file. Instead all Kafka brokers can answer a metadata request that describes the current state of the cluster: what topics there are, which partitions those topics have, which broker is the leader for those partitions, and the host and port information for these brokers.

                + +

                In other words, the client needs to somehow find one broker and that broker will tell the client about all the other brokers that exist and what partitions they host. This first broker may itself go down so the best practice for a client implementation is to take a list of two or three URLs to bootstrap from. The user can then choose to use a load balancer or just statically configure two or three of their Kafka hosts in the clients.

                + +

The client does not need to keep polling to see if the cluster has changed; it can fetch metadata once when it is instantiated and cache that metadata until it receives an error indicating that the metadata is out of date. This error can come in two forms: (1) a socket error indicating the client cannot communicate with a particular broker, (2) an error code in the response to a request indicating that this broker no longer hosts the partition for which data was requested.

                +
1. Cycle through a list of "bootstrap" Kafka URLs until we find one we can connect to. Fetch cluster metadata.
2. Process fetch or produce requests, directing them to the appropriate broker based on the topic/partitions they send to or fetch from.
3. If we get an appropriate error, refresh the metadata and try again.
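For applications that sit on top of the existing Java client rather than implementing the protocol themselves, this bootstrap-then-discover behavior is visible through the Admin API. A minimal sketch; the broker host names are placeholders.

import java.util.Properties;
import org.apache.kafka.clients.admin.AdminClient;
import org.apache.kafka.clients.admin.AdminClientConfig;
import org.apache.kafka.common.Node;

public class BootstrapMetadataExample {
    public static void main(String[] args) throws Exception {
        Properties props = new Properties();
        // Two or three bootstrap URLs, as suggested above (hypothetical host names).
        props.put(AdminClientConfig.BOOTSTRAP_SERVERS_CONFIG, "broker1:9092,broker2:9092,broker3:9092");

        try (AdminClient admin = AdminClient.create(props)) {
            // The client connects to the first reachable bootstrap broker; the metadata
            // response tells it about every other broker currently in the cluster.
            for (Node node : admin.describeCluster().nodes().get()) {
                System.out.println("Broker " + node.id() + " at " + node.host() + ":" + node.port());
            }
        }
    }
}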
                + +
                Partitioning Strategies
                + +

                As mentioned above the assignment of messages to partitions is something the producing client controls. That said, how should this functionality be exposed to the end-user?

                + +

                Partitioning really serves two purposes in Kafka:

                +
1. It balances data and request load over brokers
2. It serves as a way to divvy up processing among consumer processes while allowing local state and preserving order within the partition. We call this semantic partitioning.
                + +

                For a given use case you may care about only one of these or both.

                + +

To accomplish load balancing, a simple approach would be for the client to just round robin requests over all brokers. Another alternative, in an environment where there are many more producers than brokers, would be to have each client choose a single partition at random and publish to that. This latter strategy will result in far fewer TCP connections.

                + +

                Semantic partitioning means using some key in the message to assign messages to partitions. For example if you were processing a click message stream you might want to partition the stream by the user id so that all data for a particular user would go to a single consumer. To accomplish this the client can take a key associated with the message and use some hash of this key to choose the partition to which to deliver the message.
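A minimal sketch of such key-based partition selection, assuming a non-null String key and a partition count obtained from the cluster metadata. Note that the official Java producer hashes the serialized key with murmur2, so a client in another language that wants identical placement must use that algorithm rather than the simplistic hash below.

// A simplistic stand-in for key-based partition selection; the hash function is illustrative only.
public final class SemanticPartitioner {

    public static int partitionFor(String key, int numPartitions) {
        // Mask off the sign bit so the result is always in [0, numPartitions).
        return (key.hashCode() & 0x7fffffff) % numPartitions;
    }

    public static void main(String[] args) {
        // All records keyed by the same user id land on the same partition.
        System.out.println(partitionFor("user-42", 8));
    }
}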

                + +
                Batching
                + +

                Our APIs encourage batching small things together for efficiency. We have found this is a very significant performance win. Both our API to send messages and our API to fetch messages always work with a sequence of messages not a single message to encourage this. A clever client can make use of this and support an "asynchronous" mode in which it batches together messages sent individually and sends them in larger clumps. We go even further with this and allow the batching across multiple topics and partitions, so a produce request may contain data to append to many partitions and a fetch request may pull data from many partitions all at once.

                + +

                The client implementer can choose to ignore this and send everything one at a time if they like.
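For users of the existing Java producer, this protocol-level batching surfaces as the linger.ms and batch.size settings: records handed over one at a time by the application are grouped into per-partition batches before a produce request goes out. A minimal sketch with a hypothetical broker address and topic name.

import java.util.Properties;
import org.apache.kafka.clients.producer.KafkaProducer;
import org.apache.kafka.clients.producer.ProducerConfig;
import org.apache.kafka.clients.producer.ProducerRecord;

public class BatchingProducerExample {
    public static void main(String[] args) {
        Properties props = new Properties();
        props.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9092"); // hypothetical broker address
        props.put(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG,
                  "org.apache.kafka.common.serialization.StringSerializer");
        props.put(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG,
                  "org.apache.kafka.common.serialization.StringSerializer");
        // Wait up to 10 ms and accumulate up to 64 KB per partition so that records sent
        // individually below are shipped to the broker in a small number of produce requests.
        props.put(ProducerConfig.LINGER_MS_CONFIG, "10");
        props.put(ProducerConfig.BATCH_SIZE_CONFIG, "65536");

        try (KafkaProducer<String, String> producer = new KafkaProducer<>(props)) {
            for (int i = 0; i < 1000; i++) {
                producer.send(new ProducerRecord<>("demo-topic", "key-" + i, "value-" + i)); // hypothetical topic
            }
        } // close() flushes any batch still in flight
    }
}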

                + +
                Compatibility
                + +

                Kafka has a "bidirectional" client compatibility policy. In other words, new clients can talk to old servers, and old clients can talk to new servers. This allows users to upgrade either clients or servers without experiencing any downtime. + +

                Since the Kafka protocol has changed over time, clients and servers need to agree on the schema of the message that they are sending over the wire. This is done through API versioning. + +

                Before each request is sent, the client sends the API key and the API version. These two 16-bit numbers, when taken together, uniquely identify the schema of the message to follow. + +

The intention is that clients will support a range of API versions. When communicating with a particular broker, a given client should use the highest API version supported by both and indicate this version in its requests.

                + +

                The server will reject requests with a version it does not support, and will always respond to the client with exactly the protocol format it expects based on the version it included in its request. The intended upgrade path is that new features would first be rolled out on the server (with the older clients not making use of them) and then as newer clients are deployed these new features would gradually be taken advantage of.

                + +

                Note that KIP-482 tagged fields can be added to a request without incrementing the version number. This offers an additional way of evolving the message schema without breaking compatibility. Tagged fields do not take up any space when the field is not set. Therefore, if a field is rarely used, it is more efficient to make it a tagged field than to put it in the mandatory schema. However, tagged fields are ignored by recipients that don't know about them, which could pose a challenge if this is not the behavior that the sender wants. In such cases, a version bump may be more appropriate. + +

                Retrieving Supported API versions
                +

In order to work against multiple broker versions, clients need to know what versions of various APIs a broker supports. The broker has exposed this information since 0.10.0.0, as described in KIP-35. Clients should use the supported API versions information to choose the highest API version supported by both client and broker. If no such version exists, an error should be reported to the user.

                +

                The following sequence may be used by a client to obtain supported API versions from a broker.

                +
1. Client sends ApiVersionsRequest to a broker after the connection has been established with the broker. If SSL is enabled, this happens after the SSL connection has been established.
2. On receiving ApiVersionsRequest, a broker returns its full list of supported ApiKeys and versions regardless of current authentication state (e.g., before SASL authentication on a SASL listener; do note that no Kafka protocol requests may take place on an SSL listener before the SSL handshake is finished). If this is considered to leak information about the broker version, a workaround is to use SSL with client authentication, which is performed at an earlier stage of the connection where the ApiVersionsRequest is not available. Also, note that broker versions older than 0.10.0.0 do not support this API and will either ignore the request or close the connection in response to the request.
3. If multiple versions of an API are supported by broker and client, clients are recommended to use the latest version supported by both the broker and itself.
4. Deprecation of a protocol version is done by marking an API version as deprecated in the protocol documentation.
5. Supported API versions obtained from a broker are only valid for the connection on which that information is obtained. In the event of disconnection, the client should obtain the information from the broker again, as the broker might have been upgraded/downgraded in the meantime.
                + +
                SASL Authentication Sequence
                +

                The following sequence is used for SASL authentication: +

1. Kafka ApiVersionsRequest may be sent by the client to obtain the version ranges of requests supported by the broker. This is optional.
2. Kafka SaslHandshakeRequest containing the SASL mechanism for authentication is sent by the client. If the requested mechanism is not enabled in the server, the server responds with the list of supported mechanisms and closes the client connection. If the mechanism is enabled in the server, the server sends a successful response and continues with SASL authentication.
3. The actual SASL authentication is now performed. If SaslHandshakeRequest version is v0, a series of SASL client and server tokens corresponding to the mechanism are sent as opaque packets without wrapping the messages with Kafka protocol headers. If SaslHandshakeRequest version is v1, the SaslAuthenticate request/response are used, where the actual SASL tokens are wrapped in the Kafka protocol. The error code in the final message from the broker will indicate if authentication succeeded or failed.
4. If authentication succeeds, subsequent packets are handled as Kafka API requests. Otherwise, the client connection is closed.

For interoperability with 0.9.0.x clients, the first packet received by the server is handled as a SASL/GSSAPI client token if it is not a valid Kafka request. SASL/GSSAPI authentication is performed starting with this packet, skipping the first two steps above.
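From the point of view of an application using the existing Java clients, the whole sequence above is handled by the client library; the application only selects the mechanism and supplies credentials. A minimal sketch for SASL/PLAIN over SSL, with a hypothetical listener address and hypothetical credentials.

import java.util.Properties;
import org.apache.kafka.clients.CommonClientConfigs;
import org.apache.kafka.clients.admin.AdminClient;
import org.apache.kafka.common.config.SaslConfigs;

public class SaslClientConfigExample {
    public static void main(String[] args) throws Exception {
        Properties props = new Properties();
        props.put(CommonClientConfigs.BOOTSTRAP_SERVERS_CONFIG, "broker1:9093"); // hypothetical SASL_SSL listener
        // The client library performs the ApiVersions / SaslHandshake / SaslAuthenticate
        // exchange described above; the application only supplies mechanism and credentials.
        props.put(CommonClientConfigs.SECURITY_PROTOCOL_CONFIG, "SASL_SSL");
        props.put(SaslConfigs.SASL_MECHANISM, "PLAIN");
        props.put(SaslConfigs.SASL_JAAS_CONFIG,
                  "org.apache.kafka.common.security.plain.PlainLoginModule required "
                  + "username=\"client\" password=\"client-secret\";"); // hypothetical credentials

        try (AdminClient admin = AdminClient.create(props)) {
            admin.describeCluster().nodes().get(); // any API request works once authentication has succeeded
        }
    }
}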

                + + +

                The Protocol

                + +
                Protocol Primitive Types
                + +

                The protocol is built out of the following primitive types.

                + + +
                Notes on reading the request format grammars
                + +

The BNFs below give an exact context free grammar for the request and response binary format. The BNF is intentionally not compact, in order to give human-readable names. As always in a BNF, a sequence of productions indicates concatenation. When there are multiple possible productions these are separated with '|' and may be enclosed in parentheses for grouping. The top-level definition is always given first and subsequent sub-parts are indented.

                + +
                Common Request and Response Structure
                + +

All requests and responses originate from the following grammar, which will be incrementally described through the rest of this document:

                + +
                RequestOrResponse => Size (RequestMessage | ResponseMessage)
                +  Size => int32
                + + + + +
                FieldDescription
                message_sizeThe message_size field gives the size of the subsequent request or response message in bytes. The client can read requests by first reading this 4 byte size as an integer N, and then reading and parsing the subsequent N bytes of the request.
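A minimal sketch of consuming one such size-delimited frame from a raw TCP connection, assuming the input stream wraps the socket; how the returned bytes are parsed depends on the API key and version of the request that was sent.

import java.io.DataInputStream;
import java.io.IOException;
import java.io.InputStream;

// Reads one size-delimited frame: a 4-byte big-endian message_size followed by that many bytes.
public final class FrameReader {

    public static byte[] readFrame(InputStream in) throws IOException {
        DataInputStream data = new DataInputStream(in);
        int size = data.readInt();   // the message_size field
        byte[] payload = new byte[size];
        data.readFully(payload);     // the next `size` bytes form the request or response message
        return payload;
    }
}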
                + +
                Record Batch
                +

                A description of the record batch format can be found here.

                + +

                Constants

                + +
                Error Codes
                +

We use numeric codes to indicate what problem occurred on the server. These can be translated by the client into exceptions, or whatever error handling mechanism is appropriate in the client language. Here is a table of the error codes currently in use:

                + + +
                Api Keys
                +

                The following are the numeric codes that the ApiKey in the request can take for each of the below request types.

                + + +

                The Messages

                + +

                This section gives details on each of the individual API Messages, their usage, their binary format, and the meaning of their fields.

                + + +

                Some Common Philosophical Questions

                + +

Some people have asked why we don't use HTTP. There are a number of reasons, the best being that client implementors can make use of some of the more advanced TCP features: the ability to multiplex requests, the ability to simultaneously poll many connections, etc. We have also found HTTP libraries in many languages to be surprisingly shabby.

                + +

Others have asked if maybe we shouldn't support many different protocols. Prior experience with this was that it makes it very hard to add and test new features if they have to be ported across many protocol implementations. Our feeling is that most users don't really see multiple protocols as a feature; they just want a good reliable client in the language of their choice.

                + +

                Another question is why we don't adopt XMPP, STOMP, AMQP or an existing protocol. The answer to this varies by protocol, but in general the problem is that the protocol does determine large parts of the implementation and we couldn't do what we are doing if we didn't have control over the protocol. Our belief is that it is possible to do better than existing messaging systems have in providing a truly distributed messaging system, and to do this we need to build something that works differently.

                + +

A final question is why we don't use a system like Protocol Buffers or Thrift to define our request messages. These packages excel at helping you manage lots and lots of serialized messages. However we have only a few messages. Support across languages is somewhat spotty (depending on the package). The mapping between the binary log format and the wire protocol is something we manage somewhat carefully, and this would not be possible with these systems. Finally, we prefer the style of versioning APIs explicitly and checking this to inferring new values as nulls, as it allows more nuanced control of compatibility.

                + + + + diff --git a/docs/quickstart.html b/docs/quickstart.html new file mode 100644 index 0000000..7f003c0 --- /dev/null +++ b/docs/quickstart.html @@ -0,0 +1,277 @@ + + + + + + +
                diff --git a/docs/security.html b/docs/security.html new file mode 100644 index 0000000..8ff9e6d --- /dev/null +++ b/docs/security.html @@ -0,0 +1,2122 @@ + + + + +
                diff --git a/docs/streams/architecture.html b/docs/streams/architecture.html new file mode 100644 index 0000000..a1773c5 --- /dev/null +++ b/docs/streams/architecture.html @@ -0,0 +1,193 @@ + + + + + + + + +
                + + +
                + + diff --git a/docs/streams/core-concepts.html b/docs/streams/core-concepts.html new file mode 100644 index 0000000..884b398 --- /dev/null +++ b/docs/streams/core-concepts.html @@ -0,0 +1,384 @@ + + + + + + + + +
                + + +
                + + diff --git a/docs/streams/developer-guide/app-reset-tool.html b/docs/streams/developer-guide/app-reset-tool.html new file mode 100644 index 0000000..597b662 --- /dev/null +++ b/docs/streams/developer-guide/app-reset-tool.html @@ -0,0 +1,202 @@ + + + + + + + + + + + diff --git a/docs/streams/developer-guide/config-streams.html b/docs/streams/developer-guide/config-streams.html new file mode 100644 index 0000000..dd9298d --- /dev/null +++ b/docs/streams/developer-guide/config-streams.html @@ -0,0 +1,1106 @@ + + + + + + + + + + + diff --git a/docs/streams/developer-guide/datatypes.html b/docs/streams/developer-guide/datatypes.html new file mode 100644 index 0000000..458b47b --- /dev/null +++ b/docs/streams/developer-guide/datatypes.html @@ -0,0 +1,235 @@ + + + + + + + + + + + diff --git a/docs/streams/developer-guide/dsl-api.html b/docs/streams/developer-guide/dsl-api.html new file mode 100644 index 0000000..ba97ffc --- /dev/null +++ b/docs/streams/developer-guide/dsl-api.html @@ -0,0 +1,3888 @@ + + + + + + + + + + + diff --git a/docs/streams/developer-guide/dsl-topology-naming.html b/docs/streams/developer-guide/dsl-topology-naming.html new file mode 100644 index 0000000..9e687f9 --- /dev/null +++ b/docs/streams/developer-guide/dsl-topology-naming.html @@ -0,0 +1,350 @@ + + + + + + + + + + + + + + + diff --git a/docs/streams/developer-guide/index.html b/docs/streams/developer-guide/index.html new file mode 100644 index 0000000..19f638e --- /dev/null +++ b/docs/streams/developer-guide/index.html @@ -0,0 +1,106 @@ + + + + + + + + +
                + + +
                + + diff --git a/docs/streams/developer-guide/interactive-queries.html b/docs/streams/developer-guide/interactive-queries.html new file mode 100644 index 0000000..45bfdb9 --- /dev/null +++ b/docs/streams/developer-guide/interactive-queries.html @@ -0,0 +1,502 @@ + + + + + + + + + + + diff --git a/docs/streams/developer-guide/manage-topics.html b/docs/streams/developer-guide/manage-topics.html new file mode 100644 index 0000000..d65e375 --- /dev/null +++ b/docs/streams/developer-guide/manage-topics.html @@ -0,0 +1,128 @@ + + + + + + + + + + + diff --git a/docs/streams/developer-guide/memory-mgmt.html b/docs/streams/developer-guide/memory-mgmt.html new file mode 100644 index 0000000..9a39ce1 --- /dev/null +++ b/docs/streams/developer-guide/memory-mgmt.html @@ -0,0 +1,278 @@ + + + + + + + + + + + diff --git a/docs/streams/developer-guide/processor-api.html b/docs/streams/developer-guide/processor-api.html new file mode 100644 index 0000000..90706e5 --- /dev/null +++ b/docs/streams/developer-guide/processor-api.html @@ -0,0 +1,554 @@ + + + + + + + + + + + diff --git a/docs/streams/developer-guide/running-app.html b/docs/streams/developer-guide/running-app.html new file mode 100644 index 0000000..ff3ed75 --- /dev/null +++ b/docs/streams/developer-guide/running-app.html @@ -0,0 +1,188 @@ + + + + + + + + + + + \ No newline at end of file diff --git a/docs/streams/developer-guide/security.html b/docs/streams/developer-guide/security.html new file mode 100644 index 0000000..63bc942 --- /dev/null +++ b/docs/streams/developer-guide/security.html @@ -0,0 +1,190 @@ + + + + + + + + + + + diff --git a/docs/streams/developer-guide/testing.html b/docs/streams/developer-guide/testing.html new file mode 100644 index 0000000..b5fadb1 --- /dev/null +++ b/docs/streams/developer-guide/testing.html @@ -0,0 +1,400 @@ + + + + + + + + + + + diff --git a/docs/streams/developer-guide/write-streams.html b/docs/streams/developer-guide/write-streams.html new file mode 100644 index 0000000..03bd163 --- /dev/null +++ b/docs/streams/developer-guide/write-streams.html @@ -0,0 +1,245 @@ + + + + + + + + + + + diff --git a/docs/streams/index.html b/docs/streams/index.html new file mode 100644 index 0000000..e38b389 --- /dev/null +++ b/docs/streams/index.html @@ -0,0 +1,364 @@ + + + + + + +
                + +
                + + +
                +
                +
                + + diff --git a/docs/streams/quickstart.html b/docs/streams/quickstart.html new file mode 100644 index 0000000..2cc48ef --- /dev/null +++ b/docs/streams/quickstart.html @@ -0,0 +1,355 @@ + + + + + +
                + + + +
                + + +
                + + diff --git a/docs/streams/tutorial.html b/docs/streams/tutorial.html new file mode 100644 index 0000000..a526de5 --- /dev/null +++ b/docs/streams/tutorial.html @@ -0,0 +1,608 @@ + + + + + +
                + + + +
                + + +
                + + diff --git a/docs/streams/upgrade-guide.html b/docs/streams/upgrade-guide.html new file mode 100644 index 0000000..febfc65 --- /dev/null +++ b/docs/streams/upgrade-guide.html @@ -0,0 +1,1214 @@ + + + + + + + + +
                + + +
                + + diff --git a/docs/toc.html b/docs/toc.html new file mode 100644 index 0000000..f5b1576 --- /dev/null +++ b/docs/toc.html @@ -0,0 +1,189 @@ + + + + + + +
                diff --git a/docs/upgrade.html b/docs/upgrade.html new file mode 100644 index 0000000..73ebe29 --- /dev/null +++ b/docs/upgrade.html @@ -0,0 +1,1883 @@ + + + + + + +
                diff --git a/docs/uses.html b/docs/uses.html new file mode 100644 index 0000000..51f16a8 --- /dev/null +++ b/docs/uses.html @@ -0,0 +1,81 @@ + + +

                Here is a description of a few of the popular use cases for Apache Kafka®. +For an overview of a number of these areas in action, see this blog post.

                + +

                Messaging

                + +Kafka works well as a replacement for a more traditional message broker. +Message brokers are used for a variety of reasons (to decouple processing from data producers, to buffer unprocessed messages, etc). +In comparison to most messaging systems Kafka has better throughput, built-in partitioning, replication, and fault-tolerance which makes it a good +solution for large scale message processing applications. +

                +In our experience messaging uses are often comparatively low-throughput, but may require low end-to-end latency and often depend on the strong +durability guarantees Kafka provides. +

                +In this domain Kafka is comparable to traditional messaging systems such as ActiveMQ or +RabbitMQ. + +

                Website Activity Tracking

                + +The original use case for Kafka was to be able to rebuild a user activity tracking pipeline as a set of real-time publish-subscribe feeds. +This means site activity (page views, searches, or other actions users may take) is published to central topics with one topic per activity type. +These feeds are available for subscription for a range of use cases including real-time processing, real-time monitoring, and loading into Hadoop or +offline data warehousing systems for offline processing and reporting. +

                +Activity tracking is often very high volume as many activity messages are generated for each user page view. + +

                Metrics

                + +Kafka is often used for operational monitoring data. +This involves aggregating statistics from distributed applications to produce centralized feeds of operational data. + +

                Log Aggregation

                + +Many people use Kafka as a replacement for a log aggregation solution. +Log aggregation typically collects physical log files off servers and puts them in a central place (a file server or HDFS perhaps) for processing. +Kafka abstracts away the details of files and gives a cleaner abstraction of log or event data as a stream of messages. +This allows for lower-latency processing and easier support for multiple data sources and distributed data consumption. + +In comparison to log-centric systems like Scribe or Flume, Kafka offers equally good performance, stronger durability guarantees due to replication, +and much lower end-to-end latency. + +

                Stream Processing

                + +Many users of Kafka process data in processing pipelines consisting of multiple stages, where raw input data is consumed from Kafka topics and then +aggregated, enriched, or otherwise transformed into new topics for further consumption or follow-up processing. +For example, a processing pipeline for recommending news articles might crawl article content from RSS feeds and publish it to an "articles" topic; +further processing might normalize or deduplicate this content and publish the cleansed article content to a new topic; +a final processing stage might attempt to recommend this content to users. +Such processing pipelines create graphs of real-time data flows based on the individual topics. +Starting in 0.10.0.0, a light-weight but powerful stream processing library called Kafka Streams +is available in Apache Kafka to perform such data processing as described above. +Apart from Kafka Streams, alternative open source stream processing tools include Apache Storm and +Apache Samza. + +
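A minimal sketch of one stage of such a pipeline using Kafka Streams, with hypothetical topic names ("articles", "articles-normalized"), application id, and broker address; the normalization step is a stand-in for real cleansing logic.

import java.util.Properties;
import org.apache.kafka.common.serialization.Serdes;
import org.apache.kafka.streams.KafkaStreams;
import org.apache.kafka.streams.StreamsBuilder;
import org.apache.kafka.streams.StreamsConfig;

public class ArticlePipelineStage {
    public static void main(String[] args) {
        Properties props = new Properties();
        props.put(StreamsConfig.APPLICATION_ID_CONFIG, "article-normalizer");   // hypothetical application id
        props.put(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9092");    // hypothetical broker address
        props.put(StreamsConfig.DEFAULT_KEY_SERDE_CLASS_CONFIG, Serdes.String().getClass());
        props.put(StreamsConfig.DEFAULT_VALUE_SERDE_CLASS_CONFIG, Serdes.String().getClass());

        // One stage of the pipeline sketched above: read raw articles, apply a trivial
        // normalization, and publish the cleansed content to a downstream topic.
        StreamsBuilder builder = new StreamsBuilder();
        builder.stream("articles")                                              // hypothetical input topic
               .mapValues(content -> content.toString().trim().toLowerCase())
               .to("articles-normalized");                                      // hypothetical output topic

        new KafkaStreams(builder.build(), props).start();
    }
}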

                Event Sourcing

                + +Event sourcing is a style of application design where state changes are logged as a +time-ordered sequence of records. Kafka's support for very large stored log data makes it an excellent backend for an application built in this style. + +

                Commit Log

                + +Kafka can serve as a kind of external commit-log for a distributed system. The log helps replicate data between nodes and acts as a re-syncing +mechanism for failed nodes to restore their data. +The log compaction feature in Kafka helps support this usage. +In this usage Kafka is similar to Apache BookKeeper project. diff --git a/examples/README b/examples/README new file mode 100644 index 0000000..3880011 --- /dev/null +++ b/examples/README @@ -0,0 +1,12 @@ +This directory contains examples of client code that uses kafka. + +To run the demo: + + 1. Start Zookeeper and the Kafka server + 2. For unlimited sync-producer-consumer run, `run bin/java-producer-consumer-demo.sh sync` + 3. For unlimited async-producer-consumer run, `run bin/java-producer-consumer-demo.sh` + 4. For exactly once demo run, `run bin/exactly-once-demo.sh 6 3 50000`, + this means we are starting 3 EOS instances with 6 topic partitions and 50000 pre-populated records. + 5. Some notes for exactly once demo: + 5.1. The Kafka server has to be on broker version 2.5 or higher. + 5.2. You could also use Intellij to run the example directly by configuring parameters as "Program arguments" diff --git a/examples/bin/exactly-once-demo.sh b/examples/bin/exactly-once-demo.sh new file mode 100755 index 0000000..e9faa42 --- /dev/null +++ b/examples/bin/exactly-once-demo.sh @@ -0,0 +1,23 @@ +#!/bin/bash +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +base_dir=$(dirname $0)/../.. + +if [ "x$KAFKA_HEAP_OPTS" = "x" ]; then + export KAFKA_HEAP_OPTS="-Xmx512M" +fi + +exec $base_dir/bin/kafka-run-class.sh kafka.examples.KafkaExactlyOnceDemo $@ diff --git a/examples/bin/java-producer-consumer-demo.sh b/examples/bin/java-producer-consumer-demo.sh new file mode 100755 index 0000000..fd25e59 --- /dev/null +++ b/examples/bin/java-producer-consumer-demo.sh @@ -0,0 +1,22 @@ +#!/bin/bash +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +base_dir=$(dirname $0)/../.. 
+ +if [ "x$KAFKA_HEAP_OPTS" = "x" ]; then + export KAFKA_HEAP_OPTS="-Xmx512M" +fi +exec $base_dir/bin/kafka-run-class.sh kafka.examples.KafkaConsumerProducerDemo $@ diff --git a/examples/src/main/java/kafka/examples/Consumer.java b/examples/src/main/java/kafka/examples/Consumer.java new file mode 100644 index 0000000..d748832 --- /dev/null +++ b/examples/src/main/java/kafka/examples/Consumer.java @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package kafka.examples; + +import kafka.utils.ShutdownableThread; +import org.apache.kafka.clients.consumer.ConsumerConfig; +import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.clients.consumer.ConsumerRecords; +import org.apache.kafka.clients.consumer.KafkaConsumer; + +import java.time.Duration; +import java.util.Collections; +import java.util.Optional; +import java.util.Properties; +import java.util.concurrent.CountDownLatch; + +public class Consumer extends ShutdownableThread { + private final KafkaConsumer consumer; + private final String topic; + private final String groupId; + private final int numMessageToConsume; + private int messageRemaining; + private final CountDownLatch latch; + + public Consumer(final String topic, + final String groupId, + final Optional instanceId, + final boolean readCommitted, + final int numMessageToConsume, + final CountDownLatch latch) { + super("KafkaConsumerExample", false); + this.groupId = groupId; + Properties props = new Properties(); + props.put(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, KafkaProperties.KAFKA_SERVER_URL + ":" + KafkaProperties.KAFKA_SERVER_PORT); + props.put(ConsumerConfig.GROUP_ID_CONFIG, groupId); + instanceId.ifPresent(id -> props.put(ConsumerConfig.GROUP_INSTANCE_ID_CONFIG, id)); + props.put(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG, "true"); + props.put(ConsumerConfig.AUTO_COMMIT_INTERVAL_MS_CONFIG, "1000"); + props.put(ConsumerConfig.SESSION_TIMEOUT_MS_CONFIG, "30000"); + props.put(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, "org.apache.kafka.common.serialization.IntegerDeserializer"); + props.put(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, "org.apache.kafka.common.serialization.StringDeserializer"); + if (readCommitted) { + props.put(ConsumerConfig.ISOLATION_LEVEL_CONFIG, "read_committed"); + } + props.put(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "earliest"); + + consumer = new KafkaConsumer<>(props); + this.topic = topic; + this.numMessageToConsume = numMessageToConsume; + this.messageRemaining = numMessageToConsume; + this.latch = latch; + } + + KafkaConsumer get() { + return consumer; + } + + @Override + public void doWork() { + consumer.subscribe(Collections.singletonList(this.topic)); + ConsumerRecords records = consumer.poll(Duration.ofSeconds(1)); + for (ConsumerRecord record 
: records) { + System.out.println(groupId + " received message : from partition " + record.partition() + ", (" + record.key() + ", " + record.value() + ") at offset " + record.offset()); + } + messageRemaining -= records.count(); + if (messageRemaining <= 0) { + System.out.println(groupId + " finished reading " + numMessageToConsume + " messages"); + latch.countDown(); + } + } + + @Override + public String name() { + return null; + } + + @Override + public boolean isInterruptible() { + return false; + } +} diff --git a/examples/src/main/java/kafka/examples/ExactlyOnceMessageProcessor.java b/examples/src/main/java/kafka/examples/ExactlyOnceMessageProcessor.java new file mode 100644 index 0000000..8f31b19 --- /dev/null +++ b/examples/src/main/java/kafka/examples/ExactlyOnceMessageProcessor.java @@ -0,0 +1,187 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package kafka.examples; + +import org.apache.kafka.clients.consumer.ConsumerRebalanceListener; +import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.clients.consumer.ConsumerRecords; +import org.apache.kafka.clients.consumer.KafkaConsumer; +import org.apache.kafka.clients.consumer.OffsetAndMetadata; +import org.apache.kafka.clients.producer.KafkaProducer; +import org.apache.kafka.clients.producer.ProducerRecord; +import org.apache.kafka.common.KafkaException; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.errors.FencedInstanceIdException; +import org.apache.kafka.common.errors.ProducerFencedException; + +import java.time.Duration; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.Optional; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.atomic.AtomicLong; + +/** + * A demo class for how to write a customized EOS app. It takes a consume-process-produce loop. + * Important configurations and APIs are commented. + */ +public class ExactlyOnceMessageProcessor extends Thread { + + private static final boolean READ_COMMITTED = true; + + private final String inputTopic; + private final String outputTopic; + private final String transactionalId; + private final String groupInstanceId; + + private final KafkaProducer producer; + private final KafkaConsumer consumer; + + private final CountDownLatch latch; + + public ExactlyOnceMessageProcessor(final String inputTopic, + final String outputTopic, + final int instanceIdx, + final CountDownLatch latch) { + this.inputTopic = inputTopic; + this.outputTopic = outputTopic; + this.transactionalId = "Processor-" + instanceIdx; + // It is recommended to have a relatively short txn timeout in order to clear pending offsets faster. 
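+        // Note: the transaction timeout requested here must not exceed the broker's transaction.max.timeout.ms (15 minutes by default), otherwise initTransactions() will be rejected by the broker.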
+        final int transactionTimeoutMs = 10000;
+        // A unique transactional.id must be provided in order to properly use EOS.
+        producer = new Producer(outputTopic, true, transactionalId, true, -1, transactionTimeoutMs, null).get();
+        // Consumer must be in read_committed mode, which means it won't be able to read uncommitted data.
+        // Consumer could optionally configure groupInstanceId to avoid unnecessary rebalances.
+        this.groupInstanceId = "Txn-consumer-" + instanceIdx;
+        consumer = new Consumer(inputTopic, "Eos-consumer",
+            Optional.of(groupInstanceId), READ_COMMITTED, -1, null).get();
+        this.latch = latch;
+    }
+
+    @Override
+    public void run() {
+        // Init transactions call should always happen first in order to clear zombie transactions from previous generation.
+        producer.initTransactions();
+
+        final AtomicLong messageRemaining = new AtomicLong(Long.MAX_VALUE);
+
+        consumer.subscribe(Collections.singleton(inputTopic), new ConsumerRebalanceListener() {
+            @Override
+            public void onPartitionsRevoked(Collection<TopicPartition> partitions) {
+                printWithTxnId("Revoked partition assignment to kick-off rebalancing: " + partitions);
+            }
+
+            @Override
+            public void onPartitionsAssigned(Collection<TopicPartition> partitions) {
+                printWithTxnId("Received partition assignment after rebalancing: " + partitions);
+                messageRemaining.set(messagesRemaining(consumer));
+            }
+        });
+
+        int messageProcessed = 0;
+        while (messageRemaining.get() > 0) {
+            try {
+                ConsumerRecords<Integer, String> records = consumer.poll(Duration.ofMillis(200));
+                if (records.count() > 0) {
+                    // Begin a new transaction session.
+                    producer.beginTransaction();
+                    for (ConsumerRecord<Integer, String> record : records) {
+                        // Process the record and send to downstream.
+                        ProducerRecord<Integer, String> customizedRecord = transform(record);
+                        producer.send(customizedRecord);
+                    }
+
+                    Map<TopicPartition, OffsetAndMetadata> offsets = consumerOffsets();
+
+                    // Checkpoint the progress by sending offsets to group coordinator broker.
+                    // Note that this API is only available for broker >= 2.5.
+                    producer.sendOffsetsToTransaction(offsets, consumer.groupMetadata());
+
+                    // Finish the transaction. All sent records should be visible for consumption now.
+                    producer.commitTransaction();
+                    messageProcessed += records.count();
+                }
+            } catch (ProducerFencedException e) {
+                throw new KafkaException(String.format("The transactional.id %s has been claimed by another process", transactionalId));
+            } catch (FencedInstanceIdException e) {
+                throw new KafkaException(String.format("The group.instance.id %s has been claimed by another process", groupInstanceId));
+            } catch (KafkaException e) {
+                // If we have not been fenced, try to abort the transaction and continue. This will raise immediately
+                // if the producer has hit a fatal error.
+                producer.abortTransaction();
+
+                // The consumer fetch position needs to be restored to the committed offset
+                // before the transaction started.
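+                // Without this reset, the records consumed in the aborted transaction would be skipped rather than reprocessed, breaking the exactly-once guarantee.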
+ resetToLastCommittedPositions(consumer); + } + + messageRemaining.set(messagesRemaining(consumer)); + printWithTxnId("Message remaining: " + messageRemaining); + } + + printWithTxnId("Finished processing " + messageProcessed + " records"); + latch.countDown(); + } + + private Map consumerOffsets() { + Map offsets = new HashMap<>(); + for (TopicPartition topicPartition : consumer.assignment()) { + offsets.put(topicPartition, new OffsetAndMetadata(consumer.position(topicPartition), null)); + } + return offsets; + } + + private void printWithTxnId(final String message) { + System.out.println(transactionalId + ": " + message); + } + + private ProducerRecord transform(final ConsumerRecord record) { + printWithTxnId("Transformed record (" + record.key() + "," + record.value() + ")"); + return new ProducerRecord<>(outputTopic, record.key() / 2, "Transformed_" + record.value()); + } + + private long messagesRemaining(final KafkaConsumer consumer) { + final Map fullEndOffsets = consumer.endOffsets(new ArrayList<>(consumer.assignment())); + // If we couldn't detect any end offset, that means we are still not able to fetch offsets. + if (fullEndOffsets.isEmpty()) { + return Long.MAX_VALUE; + } + + return consumer.assignment().stream().mapToLong(partition -> { + long currentPosition = consumer.position(partition); + printWithTxnId("Processing partition " + partition + " with full offsets " + fullEndOffsets); + if (fullEndOffsets.containsKey(partition)) { + return fullEndOffsets.get(partition) - currentPosition; + } + return 0; + }).sum(); + } + + private static void resetToLastCommittedPositions(KafkaConsumer consumer) { + final Map committed = consumer.committed(consumer.assignment()); + consumer.assignment().forEach(tp -> { + OffsetAndMetadata offsetAndMetadata = committed.get(tp); + if (offsetAndMetadata != null) + consumer.seek(tp, offsetAndMetadata.offset()); + else + consumer.seekToBeginning(Collections.singleton(tp)); + }); + } +} diff --git a/examples/src/main/java/kafka/examples/KafkaConsumerProducerDemo.java b/examples/src/main/java/kafka/examples/KafkaConsumerProducerDemo.java new file mode 100644 index 0000000..9fc911a --- /dev/null +++ b/examples/src/main/java/kafka/examples/KafkaConsumerProducerDemo.java @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package kafka.examples; + +import org.apache.kafka.common.errors.TimeoutException; + +import java.util.Optional; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; + +public class KafkaConsumerProducerDemo { + public static void main(String[] args) throws InterruptedException { + boolean isAsync = args.length == 0 || !args[0].trim().equalsIgnoreCase("sync"); + CountDownLatch latch = new CountDownLatch(2); + Producer producerThread = new Producer(KafkaProperties.TOPIC, isAsync, null, false, 10000, -1, latch); + producerThread.start(); + + Consumer consumerThread = new Consumer(KafkaProperties.TOPIC, "DemoConsumer", Optional.empty(), false, 10000, latch); + consumerThread.start(); + + if (!latch.await(5, TimeUnit.MINUTES)) { + throw new TimeoutException("Timeout after 5 minutes waiting for demo producer and consumer to finish"); + } + + consumerThread.shutdown(); + System.out.println("All finished!"); + } +} diff --git a/examples/src/main/java/kafka/examples/KafkaExactlyOnceDemo.java b/examples/src/main/java/kafka/examples/KafkaExactlyOnceDemo.java new file mode 100644 index 0000000..50a1ad1 --- /dev/null +++ b/examples/src/main/java/kafka/examples/KafkaExactlyOnceDemo.java @@ -0,0 +1,195 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package kafka.examples; + +import org.apache.kafka.clients.admin.Admin; +import org.apache.kafka.clients.admin.NewTopic; +import org.apache.kafka.clients.consumer.ConsumerConfig; +import org.apache.kafka.common.errors.TimeoutException; +import org.apache.kafka.common.errors.TopicExistsException; +import org.apache.kafka.common.errors.UnknownTopicOrPartitionException; + +import java.util.Arrays; +import java.util.List; +import java.util.Optional; +import java.util.Properties; +import java.util.Set; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; + +/** + * This exactly once demo driver takes 3 arguments: + * - partition: number of partitions for input/output topic + * - instances: number of instances + * - records: number of records + * An example argument list would be `6 3 50000`. + * + * If you are using Intellij, the above arguments should be put in the configuration's `Program Arguments`. + * Also recommended to set an output log file by `Edit Configuration -> Logs -> Save console + * output to file` to record all the log output together. + * + * The driver could be decomposed as following stages: + * + * 1. Cleanup any topic whose name conflicts with input and output topic, so that we have a clean-start. + * + * 2. Set up a producer in a separate thread to pre-populate a set of records with even number keys into + * the input topic. 
The driver will block for the record generation to finish, so the producer + * must be in synchronous sending mode. + * + * 3. Set up transactional instances in separate threads which does a consume-process-produce loop, + * tailing data from input topic (See {@link ExactlyOnceMessageProcessor}). Each EOS instance will + * drain all the records from either given partitions or auto assigned partitions by actively + * comparing log end offset with committed offset. Each record will be processed exactly once + * as dividing the key by 2, and extend the value message. The driver will block for all the record + * processing to finish. The transformed record shall be written to the output topic, with + * transactional guarantee. + * + * 4. Set up a read committed consumer in a separate thread to verify we have all records within + * the output topic, while the message ordering on partition level is maintained. + * The driver will block for the consumption of all committed records. + * + * From this demo, you could see that all the records from pre-population are processed exactly once, + * with strong partition level ordering guarantee. + * + * Note: please start the kafka broker and zookeeper in local first. The broker version must be >= 2.5 + * in order to run, otherwise the app could throw + * {@link org.apache.kafka.common.errors.UnsupportedVersionException}. + */ +public class KafkaExactlyOnceDemo { + + private static final String INPUT_TOPIC = "input-topic"; + private static final String OUTPUT_TOPIC = "output-topic"; + + public static void main(String[] args) throws InterruptedException, ExecutionException { + if (args.length != 3) { + throw new IllegalArgumentException("Should accept 3 parameters: " + + "[number of partitions], [number of instances], [number of records]"); + } + + int numPartitions = Integer.parseInt(args[0]); + int numInstances = Integer.parseInt(args[1]); + int numRecords = Integer.parseInt(args[2]); + + /* Stage 1: topic cleanup and recreation */ + recreateTopics(numPartitions); + + CountDownLatch prePopulateLatch = new CountDownLatch(1); + + /* Stage 2: pre-populate records */ + Producer producerThread = new Producer(INPUT_TOPIC, false, null, true, numRecords, -1, prePopulateLatch); + producerThread.start(); + + if (!prePopulateLatch.await(5, TimeUnit.MINUTES)) { + throw new TimeoutException("Timeout after 5 minutes waiting for data pre-population"); + } + + CountDownLatch transactionalCopyLatch = new CountDownLatch(numInstances); + + /* Stage 3: transactionally process all messages */ + for (int instanceIdx = 0; instanceIdx < numInstances; instanceIdx++) { + ExactlyOnceMessageProcessor messageProcessor = new ExactlyOnceMessageProcessor( + INPUT_TOPIC, OUTPUT_TOPIC, instanceIdx, transactionalCopyLatch); + messageProcessor.start(); + } + + if (!transactionalCopyLatch.await(5, TimeUnit.MINUTES)) { + throw new TimeoutException("Timeout after 5 minutes waiting for transactionally message copy"); + } + + CountDownLatch consumeLatch = new CountDownLatch(1); + + /* Stage 4: consume all processed messages to verify exactly once */ + Consumer consumerThread = new Consumer(OUTPUT_TOPIC, "Verify-consumer", Optional.empty(), true, numRecords, consumeLatch); + consumerThread.start(); + + if (!consumeLatch.await(5, TimeUnit.MINUTES)) { + throw new TimeoutException("Timeout after 5 minutes waiting for output data consumption"); + } + + consumerThread.shutdown(); + System.out.println("All finished!"); + } + + private static void recreateTopics(final int numPartitions) + throws 
ExecutionException, InterruptedException { + Properties props = new Properties(); + props.put(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, + KafkaProperties.KAFKA_SERVER_URL + ":" + KafkaProperties.KAFKA_SERVER_PORT); + + Admin adminClient = Admin.create(props); + + List topicsToDelete = Arrays.asList(INPUT_TOPIC, OUTPUT_TOPIC); + + deleteTopic(adminClient, topicsToDelete); + + // Check topic existence in a retry loop + while (true) { + System.out.println("Making sure the topics are deleted successfully: " + topicsToDelete); + + Set listedTopics = adminClient.listTopics().names().get(); + System.out.println("Current list of topics: " + listedTopics); + + boolean hasTopicInfo = false; + for (String listedTopic : listedTopics) { + if (topicsToDelete.contains(listedTopic)) { + hasTopicInfo = true; + break; + } + } + if (!hasTopicInfo) { + break; + } + Thread.sleep(1000); + } + + // Create topics in a retry loop + while (true) { + final short replicationFactor = 1; + final List newTopics = Arrays.asList( + new NewTopic(INPUT_TOPIC, numPartitions, replicationFactor), + new NewTopic(OUTPUT_TOPIC, numPartitions, replicationFactor)); + try { + adminClient.createTopics(newTopics).all().get(); + System.out.println("Created new topics: " + newTopics); + break; + } catch (ExecutionException e) { + if (!(e.getCause() instanceof TopicExistsException)) { + throw e; + } + System.out.println("Metadata of the old topics are not cleared yet..."); + + deleteTopic(adminClient, topicsToDelete); + + Thread.sleep(1000); + } + } + } + + private static void deleteTopic(final Admin adminClient, final List topicsToDelete) + throws InterruptedException, ExecutionException { + try { + adminClient.deleteTopics(topicsToDelete).all().get(); + } catch (ExecutionException e) { + if (!(e.getCause() instanceof UnknownTopicOrPartitionException)) { + throw e; + } + System.out.println("Encountered exception during topic deletion: " + e.getCause()); + } + System.out.println("Deleted old topics: " + topicsToDelete); + } +} diff --git a/examples/src/main/java/kafka/examples/KafkaProperties.java b/examples/src/main/java/kafka/examples/KafkaProperties.java new file mode 100644 index 0000000..e73c8d7 --- /dev/null +++ b/examples/src/main/java/kafka/examples/KafkaProperties.java @@ -0,0 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package kafka.examples; + +public class KafkaProperties { + public static final String TOPIC = "topic1"; + public static final String KAFKA_SERVER_URL = "localhost"; + public static final int KAFKA_SERVER_PORT = 9092; + + private KafkaProperties() {} +} diff --git a/examples/src/main/java/kafka/examples/Producer.java b/examples/src/main/java/kafka/examples/Producer.java new file mode 100644 index 0000000..d6b6dea --- /dev/null +++ b/examples/src/main/java/kafka/examples/Producer.java @@ -0,0 +1,130 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package kafka.examples; + +import org.apache.kafka.clients.producer.Callback; +import org.apache.kafka.clients.producer.KafkaProducer; +import org.apache.kafka.clients.producer.ProducerRecord; +import org.apache.kafka.clients.producer.RecordMetadata; +import org.apache.kafka.clients.producer.ProducerConfig; +import org.apache.kafka.common.serialization.IntegerSerializer; +import org.apache.kafka.common.serialization.StringSerializer; + +import java.util.Properties; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutionException; + +public class Producer extends Thread { + private final KafkaProducer producer; + private final String topic; + private final Boolean isAsync; + private int numRecords; + private final CountDownLatch latch; + + public Producer(final String topic, + final Boolean isAsync, + final String transactionalId, + final boolean enableIdempotency, + final int numRecords, + final int transactionTimeoutMs, + final CountDownLatch latch) { + Properties props = new Properties(); + props.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, KafkaProperties.KAFKA_SERVER_URL + ":" + KafkaProperties.KAFKA_SERVER_PORT); + props.put(ProducerConfig.CLIENT_ID_CONFIG, "DemoProducer"); + props.put(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG, IntegerSerializer.class.getName()); + props.put(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG, StringSerializer.class.getName()); + if (transactionTimeoutMs > 0) { + props.put(ProducerConfig.TRANSACTION_TIMEOUT_CONFIG, transactionTimeoutMs); + } + if (transactionalId != null) { + props.put(ProducerConfig.TRANSACTIONAL_ID_CONFIG, transactionalId); + } + props.put(ProducerConfig.ENABLE_IDEMPOTENCE_CONFIG, enableIdempotency); + + producer = new KafkaProducer<>(props); + this.topic = topic; + this.isAsync = isAsync; + this.numRecords = numRecords; + this.latch = latch; + } + + KafkaProducer get() { + return producer; + } + + @Override + public void run() { + int messageKey = 0; + int recordsSent = 0; + while (recordsSent < numRecords) { + String messageStr = "Message_" + messageKey; + long startTime = System.currentTimeMillis(); + if (isAsync) { // Send asynchronously + producer.send(new ProducerRecord<>(topic, + messageKey, + 
messageStr), new DemoCallBack(startTime, messageKey, messageStr)); + } else { // Send synchronously + try { + producer.send(new ProducerRecord<>(topic, + messageKey, + messageStr)).get(); + System.out.println("Sent message: (" + messageKey + ", " + messageStr + ")"); + } catch (InterruptedException | ExecutionException e) { + e.printStackTrace(); + } + } + messageKey += 2; + recordsSent += 1; + } + System.out.println("Producer sent " + numRecords + " records successfully"); + latch.countDown(); + } +} + +class DemoCallBack implements Callback { + + private final long startTime; + private final int key; + private final String message; + + public DemoCallBack(long startTime, int key, String message) { + this.startTime = startTime; + this.key = key; + this.message = message; + } + + /** + * A callback method the user can implement to provide asynchronous handling of request completion. This method will + * be called when the record sent to the server has been acknowledged. When exception is not null in the callback, + * metadata will contain the special -1 value for all fields except for topicPartition, which will be valid. + * + * @param metadata The metadata for the record that was sent (i.e. the partition and offset). An empty metadata + * with -1 value for all fields except for topicPartition will be returned if an error occurred. + * @param exception The exception thrown during processing of this record. Null if no error occurred. + */ + public void onCompletion(RecordMetadata metadata, Exception exception) { + long elapsedTime = System.currentTimeMillis() - startTime; + if (metadata != null) { + System.out.println( + "message(" + key + ", " + message + ") sent to partition(" + metadata.partition() + + "), " + + "offset(" + metadata.offset() + ") in " + elapsedTime + " ms"); + } else { + exception.printStackTrace(); + } + } +} diff --git a/generator/src/main/java/org/apache/kafka/message/ApiMessageTypeGenerator.java b/generator/src/main/java/org/apache/kafka/message/ApiMessageTypeGenerator.java new file mode 100644 index 0000000..408e1a7 --- /dev/null +++ b/generator/src/main/java/org/apache/kafka/message/ApiMessageTypeGenerator.java @@ -0,0 +1,410 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.message; + +import java.io.BufferedWriter; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.EnumMap; +import java.util.Iterator; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.TreeMap; +import java.util.stream.Collectors; + +public final class ApiMessageTypeGenerator implements TypeClassGenerator { + private final HeaderGenerator headerGenerator; + private final CodeBuffer buffer; + private final TreeMap apis; + private final EnumMap> apisByListener = new EnumMap<>(RequestListenerType.class); + + private static final class ApiData { + short apiKey; + MessageSpec requestSpec; + MessageSpec responseSpec; + + ApiData(short apiKey) { + this.apiKey = apiKey; + } + + String name() { + if (requestSpec != null) { + return MessageGenerator.stripSuffix(requestSpec.name(), + MessageGenerator.REQUEST_SUFFIX); + } else if (responseSpec != null) { + return MessageGenerator.stripSuffix(responseSpec.name(), + MessageGenerator.RESPONSE_SUFFIX); + } else { + throw new RuntimeException("Neither requestSpec nor responseSpec is defined " + + "for API key " + apiKey); + } + } + + String requestSchema() { + if (requestSpec == null) { + return "null"; + } else { + return String.format("%sData.SCHEMAS", requestSpec.name()); + } + } + + String responseSchema() { + if (responseSpec == null) { + return "null"; + } else { + return String.format("%sData.SCHEMAS", responseSpec.name()); + } + } + } + + public ApiMessageTypeGenerator(String packageName) { + this.headerGenerator = new HeaderGenerator(packageName); + this.apis = new TreeMap<>(); + this.buffer = new CodeBuffer(); + } + + @Override + public String outputName() { + return MessageGenerator.API_MESSAGE_TYPE_JAVA; + } + + @Override + public void registerMessageType(MessageSpec spec) { + switch (spec.type()) { + case REQUEST: { + short apiKey = spec.apiKey().get(); + ApiData data = apis.get(apiKey); + if (!apis.containsKey(apiKey)) { + data = new ApiData(apiKey); + apis.put(apiKey, data); + } + if (data.requestSpec != null) { + throw new RuntimeException("Found more than one request with " + + "API key " + spec.apiKey().get()); + } + data.requestSpec = spec; + + if (spec.listeners() != null) { + for (RequestListenerType listener : spec.listeners()) { + apisByListener.putIfAbsent(listener, new ArrayList<>()); + apisByListener.get(listener).add(data); + } + } + break; + } + case RESPONSE: { + short apiKey = spec.apiKey().get(); + ApiData data = apis.get(apiKey); + if (!apis.containsKey(apiKey)) { + data = new ApiData(apiKey); + apis.put(apiKey, data); + } + if (data.responseSpec != null) { + throw new RuntimeException("Found more than one response with " + + "API key " + spec.apiKey().get()); + } + data.responseSpec = spec; + break; + } + default: + // do nothing + break; + } + } + + @Override + public void generateAndWrite(BufferedWriter writer) throws IOException { + generate(); + write(writer); + } + + private void generate() { + buffer.printf("public enum ApiMessageType {%n"); + buffer.incrementIndent(); + generateEnumValues(); + buffer.printf("%n"); + generateInstanceVariables(); + buffer.printf("%n"); + generateEnumConstructor(); + buffer.printf("%n"); + generateFromApiKey(); + buffer.printf("%n"); + generateNewApiMessageMethod("request"); + buffer.printf("%n"); + generateNewApiMessageMethod("response"); + buffer.printf("%n"); + generateAccessor("lowestSupportedVersion", 
"short"); + buffer.printf("%n"); + generateAccessor("highestSupportedVersion", "short"); + buffer.printf("%n"); + generateAccessor("listeners", "EnumSet"); + buffer.printf("%n"); + generateAccessor("apiKey", "short"); + buffer.printf("%n"); + generateAccessor("requestSchemas", "Schema[]"); + buffer.printf("%n"); + generateAccessor("responseSchemas", "Schema[]"); + buffer.printf("%n"); + generateToString(); + buffer.printf("%n"); + generateHeaderVersion("request"); + buffer.printf("%n"); + generateHeaderVersion("response"); + buffer.printf("%n"); + generateListenerTypesEnum(); + buffer.printf("%n"); + buffer.decrementIndent(); + buffer.printf("}%n"); + headerGenerator.generate(); + } + + private String generateListenerTypeEnumSet(Collection values) { + if (values.isEmpty()) { + return "EnumSet.noneOf(ListenerType.class)"; + } + StringBuilder bldr = new StringBuilder("EnumSet.of("); + Iterator iter = values.iterator(); + while (iter.hasNext()) { + bldr.append("ListenerType."); + bldr.append(iter.next()); + if (iter.hasNext()) { + bldr.append(", "); + } + } + bldr.append(")"); + return bldr.toString(); + } + + private void generateEnumValues() { + int numProcessed = 0; + for (Map.Entry entry : apis.entrySet()) { + ApiData apiData = entry.getValue(); + String name = apiData.name(); + numProcessed++; + + final Collection listeners; + if (apiData.requestSpec.listeners() == null) { + listeners = Collections.emptyList(); + } else { + listeners = apiData.requestSpec.listeners().stream() + .map(RequestListenerType::name) + .collect(Collectors.toList()); + } + + buffer.printf("%s(\"%s\", (short) %d, %s, %s, (short) %d, (short) %d, %s)%s%n", + MessageGenerator.toSnakeCase(name).toUpperCase(Locale.ROOT), + MessageGenerator.capitalizeFirst(name), + entry.getKey(), + apiData.requestSchema(), + apiData.responseSchema(), + apiData.requestSpec.struct().versions().lowest(), + apiData.requestSpec.struct().versions().highest(), + generateListenerTypeEnumSet(listeners), + (numProcessed == apis.size()) ? 
";" : ","); + } + } + + private void generateInstanceVariables() { + buffer.printf("public final String name;%n"); + buffer.printf("private final short apiKey;%n"); + buffer.printf("private final Schema[] requestSchemas;%n"); + buffer.printf("private final Schema[] responseSchemas;%n"); + buffer.printf("private final short lowestSupportedVersion;%n"); + buffer.printf("private final short highestSupportedVersion;%n"); + buffer.printf("private final EnumSet listeners;%n"); + headerGenerator.addImport(MessageGenerator.SCHEMA_CLASS); + headerGenerator.addImport(MessageGenerator.ENUM_SET_CLASS); + } + + private void generateEnumConstructor() { + buffer.printf("ApiMessageType(String name, short apiKey, " + + "Schema[] requestSchemas, Schema[] responseSchemas, " + + "short lowestSupportedVersion, short highestSupportedVersion, " + + "EnumSet listeners) {%n"); + buffer.incrementIndent(); + buffer.printf("this.name = name;%n"); + buffer.printf("this.apiKey = apiKey;%n"); + buffer.printf("this.requestSchemas = requestSchemas;%n"); + buffer.printf("this.responseSchemas = responseSchemas;%n"); + buffer.printf("this.lowestSupportedVersion = lowestSupportedVersion;%n"); + buffer.printf("this.highestSupportedVersion = highestSupportedVersion;%n"); + buffer.printf("this.listeners = listeners;%n"); + buffer.decrementIndent(); + buffer.printf("}%n"); + } + + private void generateFromApiKey() { + buffer.printf("public static ApiMessageType fromApiKey(short apiKey) {%n"); + buffer.incrementIndent(); + buffer.printf("switch (apiKey) {%n"); + buffer.incrementIndent(); + for (Map.Entry entry : apis.entrySet()) { + ApiData apiData = entry.getValue(); + String name = apiData.name(); + buffer.printf("case %d:%n", entry.getKey()); + buffer.incrementIndent(); + buffer.printf("return %s;%n", MessageGenerator.toSnakeCase(name).toUpperCase(Locale.ROOT)); + buffer.decrementIndent(); + } + buffer.printf("default:%n"); + buffer.incrementIndent(); + headerGenerator.addImport(MessageGenerator.UNSUPPORTED_VERSION_EXCEPTION_CLASS); + buffer.printf("throw new UnsupportedVersionException(\"Unsupported API key \"" + + " + apiKey);%n"); + buffer.decrementIndent(); + buffer.decrementIndent(); + buffer.printf("}%n"); + buffer.decrementIndent(); + buffer.printf("}%n"); + } + + private void generateNewApiMessageMethod(String type) { + headerGenerator.addImport(MessageGenerator.API_MESSAGE_CLASS); + buffer.printf("public ApiMessage new%s() {%n", + MessageGenerator.capitalizeFirst(type)); + buffer.incrementIndent(); + buffer.printf("switch (apiKey) {%n"); + buffer.incrementIndent(); + for (Map.Entry entry : apis.entrySet()) { + buffer.printf("case %d:%n", entry.getKey()); + buffer.incrementIndent(); + buffer.printf("return new %s%sData();%n", + entry.getValue().name(), + MessageGenerator.capitalizeFirst(type)); + buffer.decrementIndent(); + } + buffer.printf("default:%n"); + buffer.incrementIndent(); + headerGenerator.addImport(MessageGenerator.UNSUPPORTED_VERSION_EXCEPTION_CLASS); + buffer.printf("throw new UnsupportedVersionException(\"Unsupported %s API key \"" + + " + apiKey);%n", type); + buffer.decrementIndent(); + buffer.decrementIndent(); + buffer.printf("}%n"); + buffer.decrementIndent(); + buffer.printf("}%n"); + } + + private void generateAccessor(String name, String type) { + buffer.printf("public %s %s() {%n", type, name); + buffer.incrementIndent(); + buffer.printf("return this.%s;%n", name); + buffer.decrementIndent(); + buffer.printf("}%n"); + } + + private void generateToString() { + buffer.printf("@Override%n"); + 
buffer.printf("public String toString() {%n"); + buffer.incrementIndent(); + buffer.printf("return this.name();%n"); + buffer.decrementIndent(); + buffer.printf("}%n"); + } + + private void generateHeaderVersion(String type) { + buffer.printf("public short %sHeaderVersion(short _version) {%n", type); + buffer.incrementIndent(); + buffer.printf("switch (apiKey) {%n"); + buffer.incrementIndent(); + for (Map.Entry entry : apis.entrySet()) { + short apiKey = entry.getKey(); + ApiData apiData = entry.getValue(); + String name = apiData.name(); + buffer.printf("case %d: // %s%n", apiKey, MessageGenerator.capitalizeFirst(name)); + buffer.incrementIndent(); + if (type.equals("response") && apiKey == 18) { + buffer.printf("// ApiVersionsResponse always includes a v0 header.%n"); + buffer.printf("// See KIP-511 for details.%n"); + buffer.printf("return (short) 0;%n"); + buffer.decrementIndent(); + continue; + } + if (type.equals("request") && apiKey == 7) { + buffer.printf("// Version 0 of ControlledShutdownRequest has a non-standard request header%n"); + buffer.printf("// which does not include clientId. Version 1 of ControlledShutdownRequest%n"); + buffer.printf("// and later use the standard request header.%n"); + buffer.printf("if (_version == 0) {%n"); + buffer.incrementIndent(); + buffer.printf("return (short) 0;%n"); + buffer.decrementIndent(); + buffer.printf("}%n"); + } + ApiData data = entry.getValue(); + MessageSpec spec = null; + if (type.equals("request")) { + spec = data.requestSpec; + } else if (type.equals("response")) { + spec = data.responseSpec; + } else { + throw new RuntimeException("Invalid type " + type + " for generateHeaderVersion"); + } + if (spec == null) { + throw new RuntimeException("failed to find " + type + " for API key " + apiKey); + } + VersionConditional.forVersions(spec.flexibleVersions(), + spec.validVersions()). + ifMember(__ -> { + if (type.equals("request")) { + buffer.printf("return (short) 2;%n"); + } else { + buffer.printf("return (short) 1;%n"); + } + }). + ifNotMember(__ -> { + if (type.equals("request")) { + buffer.printf("return (short) 1;%n"); + } else { + buffer.printf("return (short) 0;%n"); + } + }).generate(buffer); + buffer.decrementIndent(); + } + buffer.printf("default:%n"); + buffer.incrementIndent(); + headerGenerator.addImport(MessageGenerator.UNSUPPORTED_VERSION_EXCEPTION_CLASS); + buffer.printf("throw new UnsupportedVersionException(\"Unsupported API key \"" + + " + apiKey);%n"); + buffer.decrementIndent(); + buffer.decrementIndent(); + buffer.printf("}%n"); + buffer.decrementIndent(); + buffer.printf("}%n"); + } + + private void generateListenerTypesEnum() { + buffer.printf("public enum ListenerType {%n"); + buffer.incrementIndent(); + Iterator listenerIter = Arrays.stream(RequestListenerType.values()).iterator(); + while (listenerIter.hasNext()) { + RequestListenerType scope = listenerIter.next(); + buffer.printf("%s%s%n", scope.name(), listenerIter.hasNext() ? 
"," : ";"); + } + buffer.decrementIndent(); + buffer.printf("}%n"); + } + + private void write(BufferedWriter writer) throws IOException { + headerGenerator.buffer().write(writer); + buffer.write(writer); + } +} diff --git a/generator/src/main/java/org/apache/kafka/message/ClauseGenerator.java b/generator/src/main/java/org/apache/kafka/message/ClauseGenerator.java new file mode 100644 index 0000000..5ef55f2 --- /dev/null +++ b/generator/src/main/java/org/apache/kafka/message/ClauseGenerator.java @@ -0,0 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.message; + +/** + * Generates a clause. + */ +public interface ClauseGenerator { + void generate(Versions versions); +} diff --git a/generator/src/main/java/org/apache/kafka/message/CodeBuffer.java b/generator/src/main/java/org/apache/kafka/message/CodeBuffer.java new file mode 100644 index 0000000..77febc9 --- /dev/null +++ b/generator/src/main/java/org/apache/kafka/message/CodeBuffer.java @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.message; + +import java.io.IOException; +import java.io.Writer; +import java.util.ArrayList; + +public class CodeBuffer { + private final ArrayList lines; + private int indent; + + public CodeBuffer() { + this.lines = new ArrayList<>(); + this.indent = 0; + } + + public void incrementIndent() { + indent++; + } + + public void decrementIndent() { + indent--; + if (indent < 0) { + throw new RuntimeException("Indent < 0"); + } + } + + public void printf(String format, Object... 
args) { + lines.add(String.format(indentSpaces() + format, args)); + } + + public void write(Writer writer) throws IOException { + for (String line : lines) { + writer.write(line); + } + } + + public void write(CodeBuffer other) { + for (String line : lines) { + other.lines.add(other.indentSpaces() + line); + } + } + + private String indentSpaces() { + StringBuilder bld = new StringBuilder(); + for (int i = 0; i < indent; i++) { + bld.append(" "); + } + return bld.toString(); + } + + @Override + public boolean equals(Object other) { + if (!(other instanceof CodeBuffer)) { + return false; + } + CodeBuffer o = (CodeBuffer) other; + return lines.equals(o.lines); + } + + @Override + public int hashCode() { + return lines.hashCode(); + } +} diff --git a/generator/src/main/java/org/apache/kafka/message/EntityType.java b/generator/src/main/java/org/apache/kafka/message/EntityType.java new file mode 100644 index 0000000..225c987 --- /dev/null +++ b/generator/src/main/java/org/apache/kafka/message/EntityType.java @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.message; + +import com.fasterxml.jackson.annotation.JsonProperty; + +public enum EntityType { + @JsonProperty("unknown") + UNKNOWN(null), + + @JsonProperty("transactionalId") + TRANSACTIONAL_ID(FieldType.StringFieldType.INSTANCE), + + @JsonProperty("producerId") + PRODUCER_ID(FieldType.Int64FieldType.INSTANCE), + + @JsonProperty("groupId") + GROUP_ID(FieldType.StringFieldType.INSTANCE), + + @JsonProperty("topicName") + TOPIC_NAME(FieldType.StringFieldType.INSTANCE), + + @JsonProperty("brokerId") + BROKER_ID(FieldType.Int32FieldType.INSTANCE); + + private final FieldType baseType; + + EntityType(FieldType baseType) { + this.baseType = baseType; + } + + public void verifyTypeMatches(String fieldName, FieldType type) { + if (this == UNKNOWN) { + return; + } + if (type instanceof FieldType.ArrayType) { + FieldType.ArrayType arrayType = (FieldType.ArrayType) type; + verifyTypeMatches(fieldName, arrayType.elementType()); + } else { + if (!type.toString().equals(baseType.toString())) { + throw new RuntimeException("Field " + fieldName + " has entity type " + + name() + ", but field type " + type.toString() + ", which does " + + "not match."); + } + } + } +} diff --git a/generator/src/main/java/org/apache/kafka/message/FieldSpec.java b/generator/src/main/java/org/apache/kafka/message/FieldSpec.java new file mode 100644 index 0000000..d15b03c --- /dev/null +++ b/generator/src/main/java/org/apache/kafka/message/FieldSpec.java @@ -0,0 +1,636 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.message; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Base64; +import java.util.Collections; +import java.util.List; +import java.util.Objects; +import java.util.Optional; +import java.util.regex.Pattern; + +public final class FieldSpec { + private static final Pattern VALID_FIELD_NAMES = Pattern.compile("[A-Za-z]([A-Za-z0-9]*)"); + + private final String name; + + private final Versions versions; + + private final List fields; + + private final FieldType type; + + private final boolean mapKey; + + private final Versions nullableVersions; + + private final String fieldDefault; + + private final boolean ignorable; + + private final EntityType entityType; + + private final String about; + + private final Versions taggedVersions; + + private final Optional flexibleVersions; + + private final Optional tag; + + private final boolean zeroCopy; + + @JsonCreator + public FieldSpec(@JsonProperty("name") String name, + @JsonProperty("versions") String versions, + @JsonProperty("fields") List fields, + @JsonProperty("type") String type, + @JsonProperty("mapKey") boolean mapKey, + @JsonProperty("nullableVersions") String nullableVersions, + @JsonProperty("default") String fieldDefault, + @JsonProperty("ignorable") boolean ignorable, + @JsonProperty("entityType") EntityType entityType, + @JsonProperty("about") String about, + @JsonProperty("taggedVersions") String taggedVersions, + @JsonProperty("flexibleVersions") String flexibleVersions, + @JsonProperty("tag") Integer tag, + @JsonProperty("zeroCopy") boolean zeroCopy) { + this.name = Objects.requireNonNull(name); + if (!VALID_FIELD_NAMES.matcher(this.name).matches()) { + throw new RuntimeException("Invalid field name " + this.name); + } + this.taggedVersions = Versions.parse(taggedVersions, Versions.NONE); + // If versions is not set, but taggedVersions is, default to taggedVersions. + this.versions = Versions.parse(versions, this.taggedVersions.empty() ? + null : this.taggedVersions); + if (this.versions == null) { + throw new RuntimeException("You must specify the version of the " + + name + " structure."); + } + this.fields = Collections.unmodifiableList(fields == null ? + Collections.emptyList() : new ArrayList<>(fields)); + this.type = FieldType.parse(Objects.requireNonNull(type)); + this.mapKey = mapKey; + this.nullableVersions = Versions.parse(nullableVersions, Versions.NONE); + if (!this.nullableVersions.empty()) { + if (!this.type.canBeNullable()) { + throw new RuntimeException("Type " + this.type + " cannot be nullable."); + } + } + this.fieldDefault = fieldDefault == null ? "" : fieldDefault; + this.ignorable = ignorable; + this.entityType = (entityType == null) ? EntityType.UNKNOWN : entityType; + this.entityType.verifyTypeMatches(name, this.type); + + this.about = about == null ? 
"" : about; + if (!this.fields().isEmpty()) { + if (!this.type.isArray() && !this.type.isStruct()) { + throw new RuntimeException("Non-array or Struct field " + name + " cannot have fields"); + } + } + + if (flexibleVersions == null || flexibleVersions.isEmpty()) { + this.flexibleVersions = Optional.empty(); + } else { + this.flexibleVersions = Optional.of(Versions.parse(flexibleVersions, null)); + if (!(this.type.isString() || this.type.isBytes())) { + // For now, only allow flexibleVersions overrides for the string and bytes + // types. Overrides are only needed to keep compatibility with some old formats, + // so there isn't any need to support them for all types. + throw new RuntimeException("Invalid flexibleVersions override for " + name + + ". Only fields of type string or bytes can specify a flexibleVersions " + + "override."); + } + } + this.tag = Optional.ofNullable(tag); + if (this.tag.isPresent() && mapKey) { + throw new RuntimeException("Tagged fields cannot be used as keys."); + } + checkTagInvariants(); + + this.zeroCopy = zeroCopy; + if (this.zeroCopy && !this.type.isBytes()) { + throw new RuntimeException("Invalid zeroCopy value for " + name + + ". Only fields of type bytes can use zeroCopy flag."); + } + } + + private void checkTagInvariants() { + if (this.tag.isPresent()) { + if (this.tag.get() < 0) { + throw new RuntimeException("Field " + name + " specifies a tag of " + this.tag.get() + + ". Tags cannot be negative."); + } + if (this.taggedVersions.empty()) { + throw new RuntimeException("Field " + name + " specifies a tag of " + this.tag.get() + + ", but has no tagged versions. If a tag is specified, taggedVersions must " + + "be specified as well."); + } + Versions nullableTaggedVersions = this.nullableVersions.intersect(this.taggedVersions); + if (!(nullableTaggedVersions.empty() || nullableTaggedVersions.equals(this.taggedVersions))) { + throw new RuntimeException("Field " + name + " specifies nullableVersions " + + this.nullableVersions + " and taggedVersions " + this.taggedVersions + ". " + + "Either all tagged versions must be nullable, or none must be."); + } + if (this.taggedVersions.highest() < Short.MAX_VALUE) { + throw new RuntimeException("Field " + name + " specifies taggedVersions " + + this.taggedVersions + ", which is not open-ended. taggedVersions must " + + "be either none, or an open-ended range (that ends with a plus sign)."); + } + if (!this.taggedVersions.intersect(this.versions).equals(this.taggedVersions)) { + throw new RuntimeException("Field " + name + " specifies taggedVersions " + + this.taggedVersions + ", and versions " + this.versions + ". " + + "taggedVersions must be a subset of versions."); + } + } else if (!this.taggedVersions.empty()) { + throw new RuntimeException("Field " + name + " does not specify a tag, " + + "but specifies tagged versions of " + this.taggedVersions + ". 
" + + "Please specify a tag, or remove the taggedVersions."); + } + } + + @JsonProperty("name") + public String name() { + return name; + } + + String capitalizedCamelCaseName() { + return MessageGenerator.capitalizeFirst(name); + } + + String camelCaseName() { + return MessageGenerator.lowerCaseFirst(name); + } + + String snakeCaseName() { + return MessageGenerator.toSnakeCase(name); + } + + public Versions versions() { + return versions; + } + + @JsonProperty("versions") + public String versionsString() { + return versions.toString(); + } + + @JsonProperty("fields") + public List fields() { + return fields; + } + + @JsonProperty("type") + public String typeString() { + return type.toString(); + } + + public FieldType type() { + return type; + } + + @JsonProperty("mapKey") + public boolean mapKey() { + return mapKey; + } + + public Versions nullableVersions() { + return nullableVersions; + } + + @JsonProperty("nullableVersions") + public String nullableVersionsString() { + return nullableVersions.toString(); + } + + @JsonProperty("default") + public String defaultString() { + return fieldDefault; + } + + @JsonProperty("ignorable") + public boolean ignorable() { + return ignorable; + } + + @JsonProperty("entityType") + public EntityType entityType() { + return entityType; + } + + @JsonProperty("about") + public String about() { + return about; + } + + @JsonProperty("taggedVersions") + public String taggedVersionsString() { + return taggedVersions.toString(); + } + + public Versions taggedVersions() { + return taggedVersions; + } + + @JsonProperty("flexibleVersions") + public String flexibleVersionsString() { + return flexibleVersions.isPresent() ? flexibleVersions.get().toString() : null; + } + + public Optional flexibleVersions() { + return flexibleVersions; + } + + @JsonProperty("tag") + public Integer tagInteger() { + return tag.orElse(null); + } + + public Optional tag() { + return tag; + } + + @JsonProperty("zeroCopy") + public boolean zeroCopy() { + return zeroCopy; + } + + /** + * Get a string representation of the field default. + * + * @param headerGenerator The header generator in case we need to add imports. + * @param structRegistry The struct registry in case we need to look up structs. + * + * @return A string that can be used for the field default in the + * generated code. 
+ */ + String fieldDefault(HeaderGenerator headerGenerator, + StructRegistry structRegistry) { + if (type instanceof FieldType.BoolFieldType) { + if (fieldDefault.isEmpty()) { + return "false"; + } else if (fieldDefault.equalsIgnoreCase("true")) { + return "true"; + } else if (fieldDefault.equalsIgnoreCase("false")) { + return "false"; + } else { + throw new RuntimeException("Invalid default for boolean field " + + name + ": " + fieldDefault); + } + } else if ((type instanceof FieldType.Int8FieldType) || + (type instanceof FieldType.Int16FieldType) || + (type instanceof FieldType.Uint16FieldType) || + (type instanceof FieldType.Int32FieldType) || + (type instanceof FieldType.Int64FieldType)) { + int base = 10; + String defaultString = fieldDefault; + if (defaultString.startsWith("0x")) { + base = 16; + defaultString = defaultString.substring(2); + } + if (type instanceof FieldType.Int8FieldType) { + if (defaultString.isEmpty()) { + return "(byte) 0"; + } else { + try { + Byte.valueOf(defaultString, base); + } catch (NumberFormatException e) { + throw new RuntimeException("Invalid default for int8 field " + + name + ": " + defaultString, e); + } + return "(byte) " + fieldDefault; + } + } else if (type instanceof FieldType.Int16FieldType) { + if (defaultString.isEmpty()) { + return "(short) 0"; + } else { + try { + Short.valueOf(defaultString, base); + } catch (NumberFormatException e) { + throw new RuntimeException("Invalid default for int16 field " + + name + ": " + defaultString, e); + } + return "(short) " + fieldDefault; + } + } else if (type instanceof FieldType.Uint16FieldType) { + if (defaultString.isEmpty()) { + return "0"; + } else { + try { + int value = Integer.valueOf(defaultString, base); + if (value < 0 || value > 65535) { + throw new RuntimeException("Invalid default for uint16 field " + + name + ": out of range."); + } + } catch (NumberFormatException e) { + throw new RuntimeException("Invalid default for uint16 field " + + name + ": " + defaultString, e); + } + return fieldDefault; + } + } else if (type instanceof FieldType.Int32FieldType) { + if (defaultString.isEmpty()) { + return "0"; + } else { + try { + Integer.valueOf(defaultString, base); + } catch (NumberFormatException e) { + throw new RuntimeException("Invalid default for int32 field " + + name + ": " + defaultString, e); + } + return fieldDefault; + } + } else if (type instanceof FieldType.Int64FieldType) { + if (defaultString.isEmpty()) { + return "0L"; + } else { + try { + Long.valueOf(defaultString, base); + } catch (NumberFormatException e) { + throw new RuntimeException("Invalid default for int64 field " + + name + ": " + defaultString, e); + } + return fieldDefault + "L"; + } + } else { + throw new RuntimeException("Unsupported field type " + type); + } + } else if (type instanceof FieldType.UUIDFieldType) { + headerGenerator.addImport(MessageGenerator.UUID_CLASS); + if (fieldDefault.isEmpty()) { + return "Uuid.ZERO_UUID"; + } else { + try { + ByteBuffer uuidBytes = ByteBuffer.wrap(Base64.getUrlDecoder().decode(fieldDefault)); + uuidBytes.getLong(); + uuidBytes.getLong(); + } catch (IllegalArgumentException e) { + throw new RuntimeException("Invalid default for uuid field " + + name + ": " + fieldDefault, e); + } + headerGenerator.addImport(MessageGenerator.UUID_CLASS); + return "Uuid.fromString(\"" + fieldDefault + "\")"; + } + } else if (type instanceof FieldType.Float64FieldType) { + if (fieldDefault.isEmpty()) { + return "0.0"; + } else { + try { + Double.parseDouble(fieldDefault); + } catch 
(NumberFormatException e) { + throw new RuntimeException("Invalid default for float64 field " + + name + ": " + fieldDefault, e); + } + return "Double.parseDouble(\"" + fieldDefault + "\")"; + } + } else if (type instanceof FieldType.StringFieldType) { + if (fieldDefault.equals("null")) { + validateNullDefault(); + return "null"; + } else { + return "\"" + fieldDefault + "\""; + } + } else if (type.isBytes()) { + if (fieldDefault.equals("null")) { + validateNullDefault(); + return "null"; + } else if (!fieldDefault.isEmpty()) { + throw new RuntimeException("Invalid default for bytes field " + + name + ". The only valid default for a bytes field " + + "is empty or null."); + } + if (zeroCopy) { + headerGenerator.addImport(MessageGenerator.BYTE_UTILS_CLASS); + return "ByteUtils.EMPTY_BUF"; + } else { + headerGenerator.addImport(MessageGenerator.BYTES_CLASS); + return "Bytes.EMPTY"; + } + } else if (type.isRecords()) { + return "null"; + } else if (type.isStruct()) { + if (!fieldDefault.isEmpty()) { + throw new RuntimeException("Invalid default for struct field " + + name + ": custom defaults are not supported for struct fields."); + } + return "new " + type.toString() + "()"; + } else if (type.isArray()) { + if (fieldDefault.equals("null")) { + validateNullDefault(); + return "null"; + } else if (!fieldDefault.isEmpty()) { + throw new RuntimeException("Invalid default for array field " + + name + ". The only valid default for an array field " + + "is the empty array or null."); + } + return String.format("new %s(0)", + concreteJavaType(headerGenerator, structRegistry)); + } else { + throw new RuntimeException("Unsupported field type " + type); + } + } + + private void validateNullDefault() { + if (!(nullableVersions().contains(versions))) { + throw new RuntimeException("null cannot be the default for field " + + name + ", because not all versions of this field are " + + "nullable."); + } + } + + /** + * Get the abstract Java type of the field-- for example, List. + * + * @param headerGenerator The header generator in case we need to add imports. + * @param structRegistry The struct registry in case we need to look up structs. + * + * @return The abstract java type name. 
+ */ + String fieldAbstractJavaType(HeaderGenerator headerGenerator, + StructRegistry structRegistry) { + if (type instanceof FieldType.BoolFieldType) { + return "boolean"; + } else if (type instanceof FieldType.Int8FieldType) { + return "byte"; + } else if (type instanceof FieldType.Int16FieldType) { + return "short"; + } else if (type instanceof FieldType.Uint16FieldType) { + return "int"; + } else if (type instanceof FieldType.Int32FieldType) { + return "int"; + } else if (type instanceof FieldType.Int64FieldType) { + return "long"; + } else if (type instanceof FieldType.UUIDFieldType) { + headerGenerator.addImport(MessageGenerator.UUID_CLASS); + return "Uuid"; + } else if (type instanceof FieldType.Float64FieldType) { + return "double"; + } else if (type.isString()) { + return "String"; + } else if (type.isBytes()) { + if (zeroCopy) { + headerGenerator.addImport(MessageGenerator.BYTE_BUFFER_CLASS); + return "ByteBuffer"; + } else { + return "byte[]"; + } + } else if (type instanceof FieldType.RecordsFieldType) { + headerGenerator.addImport(MessageGenerator.BASE_RECORDS_CLASS); + return "BaseRecords"; + } else if (type.isStruct()) { + return MessageGenerator.capitalizeFirst(typeString()); + } else if (type.isArray()) { + FieldType.ArrayType arrayType = (FieldType.ArrayType) type; + if (structRegistry.isStructArrayWithKeys(this)) { + headerGenerator.addImport(MessageGenerator.IMPLICIT_LINKED_HASH_MULTI_COLLECTION_CLASS); + return collectionType(arrayType.elementType().toString()); + } else { + headerGenerator.addImport(MessageGenerator.LIST_CLASS); + return String.format("List<%s>", + arrayType.elementType().getBoxedJavaType(headerGenerator)); + } + } else { + throw new RuntimeException("Unknown field type " + type); + } + } + + /** + * Get the concrete Java type of the field-- for example, ArrayList. + * + * @param headerGenerator The header generator in case we need to add imports. + * @param structRegistry The struct registry in case we need to look up structs. + * + * @return The abstract java type name. + */ + String concreteJavaType(HeaderGenerator headerGenerator, + StructRegistry structRegistry) { + if (type.isArray()) { + FieldType.ArrayType arrayType = (FieldType.ArrayType) type; + if (structRegistry.isStructArrayWithKeys(this)) { + return collectionType(arrayType.elementType().toString()); + } else { + headerGenerator.addImport(MessageGenerator.ARRAYLIST_CLASS); + return String.format("ArrayList<%s>", + arrayType.elementType().getBoxedJavaType(headerGenerator)); + } + } else { + return fieldAbstractJavaType(headerGenerator, structRegistry); + } + } + + static String collectionType(String baseType) { + return baseType + "Collection"; + } + + /** + * Generate an if statement that checks if this field has a non-default value. + * + * @param headerGenerator The header generator in case we need to add imports. + * @param structRegistry The struct registry in case we need to look up structs. + * @param buffer The code buffer to write to. + * @param fieldPrefix The prefix to prepend before references to this field. + * @param nullableVersions The nullable versions to use for this field. This is + * mainly to let us choose to ignore the possibility of + * nulls sometimes (like when dealing with array entries + * that cannot be null). 
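The check printed by this method is only the opening "if (...) {"; callers such as generateNonIgnorableFieldCheck() below supply the body and the closing brace. Put together, the generated output looks roughly like the following for a non-ignorable int32 field whose default is 0 (a sketch only; the field name and the "_object." prefix are illustrative):

    // Sketch of generated output: refuse to write a non-default, non-ignorable
    // field at a version that does not support it.
    if (_object.throttleTimeMs != 0) {
        throw new UnsupportedVersionException(
            "Attempted to write a non-default throttleTimeMs at version " + _version);
    }

For other field kinds the condition changes accordingly: a null check when the default is null, "!flag" for a boolean whose default is true, and an equals() comparison for strings, structs, and UUIDs.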
+ */ + void generateNonDefaultValueCheck(HeaderGenerator headerGenerator, + StructRegistry structRegistry, + CodeBuffer buffer, + String fieldPrefix, + Versions nullableVersions) { + String fieldDefault = fieldDefault(headerGenerator, structRegistry); + if (type().isArray()) { + if (fieldDefault.equals("null")) { + buffer.printf("if (%s%s != null) {%n", fieldPrefix, camelCaseName()); + } else if (nullableVersions.empty()) { + buffer.printf("if (!%s%s.isEmpty()) {%n", fieldPrefix, camelCaseName()); + } else { + buffer.printf("if (%s%s == null || !%s%s.isEmpty()) {%n", + fieldPrefix, camelCaseName(), fieldPrefix, camelCaseName()); + } + } else if (type().isBytes()) { + if (fieldDefault.equals("null")) { + buffer.printf("if (%s%s != null) {%n", fieldPrefix, camelCaseName()); + } else if (nullableVersions.empty()) { + if (zeroCopy()) { + buffer.printf("if (%s%s.hasRemaining()) {%n", + fieldPrefix, camelCaseName()); + } else { + buffer.printf("if (%s%s.length != 0) {%n", + fieldPrefix, camelCaseName()); + } + } else { + if (zeroCopy()) { + buffer.printf("if (%s%s == null || %s%s.remaining() > 0) {%n", + fieldPrefix, camelCaseName(), fieldPrefix, camelCaseName()); + } else { + buffer.printf("if (%s%s == null || %s%s.length != 0) {%n", + fieldPrefix, camelCaseName(), fieldPrefix, camelCaseName()); + } + } + } else if (type().isString() || type().isStruct() || type() instanceof FieldType.UUIDFieldType) { + if (fieldDefault.equals("null")) { + buffer.printf("if (%s%s != null) {%n", fieldPrefix, camelCaseName()); + } else if (nullableVersions.empty()) { + buffer.printf("if (!%s%s.equals(%s)) {%n", + fieldPrefix, camelCaseName(), fieldDefault); + } else { + buffer.printf("if (%s%s == null || !%s%s.equals(%s)) {%n", + fieldPrefix, camelCaseName(), fieldPrefix, camelCaseName(), + fieldDefault); + } + } else if (type() instanceof FieldType.BoolFieldType) { + buffer.printf("if (%s%s%s) {%n", + fieldDefault.equals("true") ? "!" : "", + fieldPrefix, camelCaseName()); + } else { + buffer.printf("if (%s%s != %s) {%n", + fieldPrefix, camelCaseName(), fieldDefault); + } + } + + /** + * Generate an if statement that checks if this field is non-default and also + * non-ignorable. + * + * @param headerGenerator The header generator in case we need to add imports. + * @param structRegistry The struct registry in case we need to look up structs. + * @param fieldPrefix The prefix to prepend before references to this field. + * @param buffer The code buffer to write to. + */ + void generateNonIgnorableFieldCheck(HeaderGenerator headerGenerator, + StructRegistry structRegistry, + String fieldPrefix, + CodeBuffer buffer) { + generateNonDefaultValueCheck(headerGenerator, structRegistry, + buffer, fieldPrefix, nullableVersions()); + buffer.incrementIndent(); + headerGenerator.addImport(MessageGenerator.UNSUPPORTED_VERSION_EXCEPTION_CLASS); + buffer.printf("throw new UnsupportedVersionException(" + + "\"Attempted to write a non-default %s at version \" + _version);%n", + camelCaseName()); + buffer.decrementIndent(); + buffer.printf("}%n"); + } +} diff --git a/generator/src/main/java/org/apache/kafka/message/FieldType.java b/generator/src/main/java/org/apache/kafka/message/FieldType.java new file mode 100644 index 0000000..e0009c2 --- /dev/null +++ b/generator/src/main/java/org/apache/kafka/message/FieldType.java @@ -0,0 +1,487 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.message; + +import java.util.Optional; + +public interface FieldType { + String ARRAY_PREFIX = "[]"; + + final class BoolFieldType implements FieldType { + static final BoolFieldType INSTANCE = new BoolFieldType(); + private static final String NAME = "bool"; + + @Override + public String getBoxedJavaType(HeaderGenerator headerGenerator) { + return "Boolean"; + } + + @Override + public Optional fixedLength() { + return Optional.of(1); + } + + @Override + public String toString() { + return NAME; + } + } + + final class Int8FieldType implements FieldType { + static final Int8FieldType INSTANCE = new Int8FieldType(); + private static final String NAME = "int8"; + + @Override + public String getBoxedJavaType(HeaderGenerator headerGenerator) { + return "Byte"; + } + + @Override + public Optional fixedLength() { + return Optional.of(1); + } + + @Override + public String toString() { + return NAME; + } + } + + final class Int16FieldType implements FieldType { + static final Int16FieldType INSTANCE = new Int16FieldType(); + private static final String NAME = "int16"; + + @Override + public String getBoxedJavaType(HeaderGenerator headerGenerator) { + return "Short"; + } + + @Override + public Optional fixedLength() { + return Optional.of(2); + } + + @Override + public String toString() { + return NAME; + } + } + + final class Uint16FieldType implements FieldType { + static final Uint16FieldType INSTANCE = new Uint16FieldType(); + private static final String NAME = "uint16"; + + @Override + public String getBoxedJavaType(HeaderGenerator headerGenerator) { + return "Integer"; + } + + @Override + public Optional fixedLength() { + return Optional.of(2); + } + + @Override + public String toString() { + return NAME; + } + } + + final class Int32FieldType implements FieldType { + static final Int32FieldType INSTANCE = new Int32FieldType(); + private static final String NAME = "int32"; + + @Override + public String getBoxedJavaType(HeaderGenerator headerGenerator) { + return "Integer"; + } + + @Override + public Optional fixedLength() { + return Optional.of(4); + } + + @Override + public String toString() { + return NAME; + } + } + + final class Int64FieldType implements FieldType { + static final Int64FieldType INSTANCE = new Int64FieldType(); + private static final String NAME = "int64"; + + @Override + public String getBoxedJavaType(HeaderGenerator headerGenerator) { + return "Long"; + } + + @Override + public Optional fixedLength() { + return Optional.of(8); + } + + @Override + public String toString() { + return NAME; + } + } + + final class UUIDFieldType implements FieldType { + static final UUIDFieldType INSTANCE = new UUIDFieldType(); + private static final String NAME = "uuid"; + + @Override + public String getBoxedJavaType(HeaderGenerator headerGenerator) { + 
headerGenerator.addImport(MessageGenerator.UUID_CLASS); + return "Uuid"; + } + + @Override + public Optional fixedLength() { + return Optional.of(16); + } + + @Override + public String toString() { + return NAME; + } + } + + final class Float64FieldType implements FieldType { + static final Float64FieldType INSTANCE = new Float64FieldType(); + private static final String NAME = "float64"; + + @Override + public Optional fixedLength() { + return Optional.of(8); + } + + @Override + public String getBoxedJavaType(HeaderGenerator headerGenerator) { + return "Double"; + } + + @Override + public boolean isFloat() { + return true; + } + + @Override + public String toString() { + return NAME; + } + } + + final class StringFieldType implements FieldType { + static final StringFieldType INSTANCE = new StringFieldType(); + private static final String NAME = "string"; + + @Override + public String getBoxedJavaType(HeaderGenerator headerGenerator) { + return "String"; + } + + @Override + public boolean serializationIsDifferentInFlexibleVersions() { + return true; + } + + @Override + public boolean isString() { + return true; + } + + @Override + public boolean canBeNullable() { + return true; + } + + @Override + public String toString() { + return NAME; + } + } + + final class BytesFieldType implements FieldType { + static final BytesFieldType INSTANCE = new BytesFieldType(); + private static final String NAME = "bytes"; + + @Override + public String getBoxedJavaType(HeaderGenerator headerGenerator) { + headerGenerator.addImport(MessageGenerator.BYTE_BUFFER_CLASS); + return "ByteBuffer"; + } + + @Override + public boolean serializationIsDifferentInFlexibleVersions() { + return true; + } + + @Override + public boolean isBytes() { + return true; + } + + @Override + public boolean canBeNullable() { + return true; + } + + @Override + public String toString() { + return NAME; + } + } + + final class RecordsFieldType implements FieldType { + static final RecordsFieldType INSTANCE = new RecordsFieldType(); + private static final String NAME = "records"; + + @Override + public String getBoxedJavaType(HeaderGenerator headerGenerator) { + headerGenerator.addImport(MessageGenerator.BASE_RECORDS_CLASS); + return "BaseRecords"; + } + + @Override + public boolean serializationIsDifferentInFlexibleVersions() { + return true; + } + + @Override + public boolean isRecords() { + return true; + } + + @Override + public boolean canBeNullable() { + return true; + } + + @Override + public String toString() { + return NAME; + } + } + + final class StructType implements FieldType { + private final String type; + + StructType(String type) { + this.type = type; + } + + @Override + public String getBoxedJavaType(HeaderGenerator headerGenerator) { + return type; + } + + @Override + public boolean serializationIsDifferentInFlexibleVersions() { + return true; + } + + @Override + public boolean isStruct() { + return true; + } + + public String typeName() { + return type; + } + + @Override + public String toString() { + return type; + } + } + + final class ArrayType implements FieldType { + private final FieldType elementType; + + ArrayType(FieldType elementType) { + this.elementType = elementType; + } + + @Override + public boolean serializationIsDifferentInFlexibleVersions() { + return true; + } + + @Override + public String getBoxedJavaType(HeaderGenerator headerGenerator) { + throw new UnsupportedOperationException(); + } + + @Override + public boolean isArray() { + return true; + } + + @Override + public boolean isStructArray() { 
+ return elementType.isStruct(); + } + + @Override + public boolean canBeNullable() { + return true; + } + + public FieldType elementType() { + return elementType; + } + + public String elementName() { + return elementType.toString(); + } + + @Override + public String toString() { + return "[]" + elementType.toString(); + } + } + + static FieldType parse(String string) { + string = string.trim(); + switch (string) { + case BoolFieldType.NAME: + return BoolFieldType.INSTANCE; + case Int8FieldType.NAME: + return Int8FieldType.INSTANCE; + case Int16FieldType.NAME: + return Int16FieldType.INSTANCE; + case Uint16FieldType.NAME: + return Uint16FieldType.INSTANCE; + case Int32FieldType.NAME: + return Int32FieldType.INSTANCE; + case Int64FieldType.NAME: + return Int64FieldType.INSTANCE; + case UUIDFieldType.NAME: + return UUIDFieldType.INSTANCE; + case Float64FieldType.NAME: + return Float64FieldType.INSTANCE; + case StringFieldType.NAME: + return StringFieldType.INSTANCE; + case BytesFieldType.NAME: + return BytesFieldType.INSTANCE; + case RecordsFieldType.NAME: + return RecordsFieldType.INSTANCE; + default: + if (string.startsWith(ARRAY_PREFIX)) { + String elementTypeString = string.substring(ARRAY_PREFIX.length()); + if (elementTypeString.length() == 0) { + throw new RuntimeException("Can't parse array type " + string + + ". No element type found."); + } + FieldType elementType = parse(elementTypeString); + if (elementType.isArray()) { + throw new RuntimeException("Can't have an array of arrays. " + + "Use an array of structs containing an array instead."); + } + return new ArrayType(elementType); + } else if (MessageGenerator.firstIsCapitalized(string)) { + return new StructType(string); + } else { + throw new RuntimeException("Can't parse type " + string); + } + } + } + + String getBoxedJavaType(HeaderGenerator headerGenerator); + + /** + * Returns true if this is an array type. + */ + default boolean isArray() { + return false; + } + + /** + * Returns true if this is an array of structures. + */ + default boolean isStructArray() { + return false; + } + + /** + * Returns true if the serialization of this type is different in flexible versions. + */ + default boolean serializationIsDifferentInFlexibleVersions() { + return false; + } + + /** + * Returns true if this is a string type. + */ + default boolean isString() { + return false; + } + + /** + * Returns true if this is a bytes type. + */ + default boolean isBytes() { + return false; + } + + /** + * Returns true if this is a records type + */ + default boolean isRecords() { + return false; + } + + /** + * Returns true if this is a floating point type. + */ + default boolean isFloat() { + return false; + } + + /** + * Returns true if this is a struct type. + */ + default boolean isStruct() { + return false; + } + + /** + * Returns true if this field type is compatible with nullability. + */ + default boolean canBeNullable() { + return false; + } + + /** + * Gets the fixed length of the field, or None if the field is variable-length. + */ + default Optional fixedLength() { + return Optional.empty(); + } + + default boolean isVariableLength() { + return !fixedLength().isPresent(); + } + + /** + * Convert the field type to a JSON string. 
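As a quick illustration of the type strings this interface models (the struct name is hypothetical; the results follow the switch in parse() above):

    // Illustrative results of FieldType.parse().
    FieldType intType   = FieldType.parse("int32");        // Int32FieldType.INSTANCE
    FieldType idType    = FieldType.parse("uuid");         // UUIDFieldType.INSTANCE
    FieldType arrayType = FieldType.parse("[]FetchTopic"); // ArrayType wrapping StructType("FetchTopic")
    // FieldType.parse("[][]int32") throws: arrays of arrays are rejected.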
+ */ + String toString(); +} diff --git a/generator/src/main/java/org/apache/kafka/message/HeaderGenerator.java b/generator/src/main/java/org/apache/kafka/message/HeaderGenerator.java new file mode 100644 index 0000000..e1b9c12 --- /dev/null +++ b/generator/src/main/java/org/apache/kafka/message/HeaderGenerator.java @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.message; + +import java.util.Objects; +import java.util.TreeSet; + +/** + * The Kafka header generator. + */ +public final class HeaderGenerator { + private static final String[] HEADER = new String[] { + "/*", + " * Licensed to the Apache Software Foundation (ASF) under one or more", + " * contributor license agreements. See the NOTICE file distributed with", + " * this work for additional information regarding copyright ownership.", + " * The ASF licenses this file to You under the Apache License, Version 2.0", + " * (the \"License\"); you may not use this file except in compliance with", + " * the License. You may obtain a copy of the License at", + " *", + " * http://www.apache.org/licenses/LICENSE-2.0", + " *", + " * Unless required by applicable law or agreed to in writing, software", + " * distributed under the License is distributed on an \"AS IS\" BASIS,", + " * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.", + " * See the License for the specific language governing permissions and", + " * limitations under the License.", + " */", + "", + "// THIS CODE IS AUTOMATICALLY GENERATED. 
DO NOT EDIT.", + "" + }; + + + private final CodeBuffer buffer; + + private final TreeSet imports; + private final String packageName; + + private final TreeSet staticImports; + + public HeaderGenerator(String packageName) { + this.buffer = new CodeBuffer(); + this.imports = new TreeSet<>(); + this.packageName = packageName; + this.staticImports = new TreeSet<>(); + } + + public void addImport(String newImport) { + this.imports.add(newImport); + } + + public void addStaticImport(String newImport) { + this.staticImports.add(newImport); + } + + public void generate() { + Objects.requireNonNull(packageName); + for (int i = 0; i < HEADER.length; i++) { + buffer.printf("%s%n", HEADER[i]); + } + buffer.printf("package %s;%n", packageName); + buffer.printf("%n"); + for (String newImport : imports) { + buffer.printf("import %s;%n", newImport); + } + buffer.printf("%n"); + if (!staticImports.isEmpty()) { + for (String newImport : staticImports) { + buffer.printf("import static %s;%n", newImport); + } + buffer.printf("%n"); + } + } + + public CodeBuffer buffer() { + return buffer; + } +} diff --git a/generator/src/main/java/org/apache/kafka/message/IsNullConditional.java b/generator/src/main/java/org/apache/kafka/message/IsNullConditional.java new file mode 100644 index 0000000..88a3598 --- /dev/null +++ b/generator/src/main/java/org/apache/kafka/message/IsNullConditional.java @@ -0,0 +1,128 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.message; + +/** + * For versions of a field that are nullable, IsNullCondition creates a null check. 
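A usage sketch, assuming a CodeBuffer named buffer, a FieldSpec named field, a Versions value named curVersions, and illustrative printf payloads (the real generators pass their own serialization snippets):

    // Hypothetical builder invocation inside a generator method.
    IsNullConditional.forName("_object.clusterId")
        .nullableVersions(field.nullableVersions())
        .possibleVersions(curVersions)
        .ifNull(() -> buffer.printf("_writable.writeByte((byte) -1);%n"))
        .ifShouldNotBeNull(() -> buffer.printf("_writable.writeString(_object.clusterId);%n"))
        .generate(buffer);

If the nullable versions intersect the possible versions, this emits an if/else on _object.clusterId == null; otherwise only the non-null branch is generated (optionally wrapped in a block scope).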
+ */ +public final class IsNullConditional { + interface ConditionalGenerator { + String generate(String name, boolean negated); + } + + private static class PrimitiveConditionalGenerator implements ConditionalGenerator { + final static PrimitiveConditionalGenerator INSTANCE = new PrimitiveConditionalGenerator(); + + @Override + public String generate(String name, boolean negated) { + if (negated) { + return String.format("%s != null", name); + } else { + return String.format("%s == null", name); + } + } + } + + static IsNullConditional forName(String name) { + return new IsNullConditional(name); + } + + static IsNullConditional forField(FieldSpec field) { + IsNullConditional cond = new IsNullConditional(field.camelCaseName()); + cond.nullableVersions(field.nullableVersions()); + return cond; + } + + private final String name; + private Versions nullableVersions = Versions.ALL; + private Versions possibleVersions = Versions.ALL; + private Runnable ifNull = null; + private Runnable ifShouldNotBeNull = null; + private boolean alwaysEmitBlockScope = false; + private ConditionalGenerator conditionalGenerator = PrimitiveConditionalGenerator.INSTANCE; + + private IsNullConditional(String name) { + this.name = name; + } + + IsNullConditional nullableVersions(Versions nullableVersions) { + this.nullableVersions = nullableVersions; + return this; + } + + IsNullConditional possibleVersions(Versions possibleVersions) { + this.possibleVersions = possibleVersions; + return this; + } + + IsNullConditional ifNull(Runnable ifNull) { + this.ifNull = ifNull; + return this; + } + + IsNullConditional ifShouldNotBeNull(Runnable ifShouldNotBeNull) { + this.ifShouldNotBeNull = ifShouldNotBeNull; + return this; + } + + IsNullConditional alwaysEmitBlockScope(boolean alwaysEmitBlockScope) { + this.alwaysEmitBlockScope = alwaysEmitBlockScope; + return this; + } + + IsNullConditional conditionalGenerator(ConditionalGenerator conditionalGenerator) { + this.conditionalGenerator = conditionalGenerator; + return this; + } + + void generate(CodeBuffer buffer) { + if (nullableVersions.intersect(possibleVersions).empty()) { + if (ifShouldNotBeNull != null) { + if (alwaysEmitBlockScope) { + buffer.printf("{%n"); + buffer.incrementIndent(); + } + ifShouldNotBeNull.run(); + if (alwaysEmitBlockScope) { + buffer.decrementIndent(); + buffer.printf("}%n"); + } + } + } else { + if (ifNull != null) { + buffer.printf("if (%s) {%n", conditionalGenerator.generate(name, false)); + buffer.incrementIndent(); + ifNull.run(); + buffer.decrementIndent(); + if (ifShouldNotBeNull != null) { + buffer.printf("} else {%n"); + buffer.incrementIndent(); + ifShouldNotBeNull.run(); + buffer.decrementIndent(); + } + buffer.printf("}%n"); + } else if (ifShouldNotBeNull != null) { + buffer.printf("if (%s) {%n", conditionalGenerator.generate(name, true)); + buffer.incrementIndent(); + ifShouldNotBeNull.run(); + buffer.decrementIndent(); + buffer.printf("}%n"); + } + } + } +} diff --git a/generator/src/main/java/org/apache/kafka/message/JsonConverterGenerator.java b/generator/src/main/java/org/apache/kafka/message/JsonConverterGenerator.java new file mode 100644 index 0000000..2df8170 --- /dev/null +++ b/generator/src/main/java/org/apache/kafka/message/JsonConverterGenerator.java @@ -0,0 +1,442 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.message; + +import java.io.BufferedWriter; +import java.util.Iterator; + +/** + * Generates Kafka MessageData classes. + */ +public final class JsonConverterGenerator implements MessageClassGenerator { + private final static String SUFFIX = "JsonConverter"; + private final String packageName; + private final StructRegistry structRegistry; + private final HeaderGenerator headerGenerator; + private final CodeBuffer buffer; + + JsonConverterGenerator(String packageName) { + this.packageName = packageName; + this.structRegistry = new StructRegistry(); + this.headerGenerator = new HeaderGenerator(packageName); + this.buffer = new CodeBuffer(); + } + + @Override + public String outputName(MessageSpec spec) { + return spec.dataClassName() + SUFFIX; + } + + @Override + public void generateAndWrite(MessageSpec message, BufferedWriter writer) + throws Exception { + structRegistry.register(message); + headerGenerator.addStaticImport(String.format("%s.%s.*", + packageName, message.dataClassName())); + buffer.printf("public class %s {%n", + MessageGenerator.capitalizeFirst(outputName(message))); + buffer.incrementIndent(); + generateConverters(message.dataClassName(), message.struct(), + message.validVersions()); + for (Iterator iter = structRegistry.structs(); + iter.hasNext(); ) { + StructRegistry.StructInfo info = iter.next(); + buffer.printf("%n"); + buffer.printf("public static class %s {%n", + MessageGenerator.capitalizeFirst(info.spec().name() + SUFFIX)); + buffer.incrementIndent(); + generateConverters(MessageGenerator.capitalizeFirst(info.spec().name()), + info.spec(), info.parentVersions()); + buffer.decrementIndent(); + buffer.printf("}%n"); + } + buffer.decrementIndent(); + buffer.printf("}%n"); + headerGenerator.generate(); + headerGenerator.buffer().write(writer); + buffer.write(writer); + } + + private void generateConverters(String name, + StructSpec spec, + Versions parentVersions) { + generateRead(name, spec, parentVersions); + generateWrite(name, spec, parentVersions); + generateOverloadWrite(name); + } + + private void generateRead(String className, + StructSpec struct, + Versions parentVersions) { + headerGenerator.addImport(MessageGenerator.JSON_NODE_CLASS); + buffer.printf("public static %s read(JsonNode _node, short _version) {%n", + className); + buffer.incrementIndent(); + buffer.printf("%s _object = new %s();%n", className, className); + VersionConditional.forVersions(struct.versions(), parentVersions). + allowMembershipCheckAlwaysFalse(false). + ifNotMember(__ -> { + headerGenerator.addImport(MessageGenerator.UNSUPPORTED_VERSION_EXCEPTION_CLASS); + buffer.printf("throw new UnsupportedVersionException(\"Can't read " + + "version \" + _version + \" of %s\");%n", className); + }). 
+ generate(buffer); + Versions curVersions = parentVersions.intersect(struct.versions()); + for (FieldSpec field : struct.fields()) { + String sourceVariable = String.format("_%sNode", field.camelCaseName()); + buffer.printf("JsonNode %s = _node.get(\"%s\");%n", + sourceVariable, + field.camelCaseName()); + buffer.printf("if (%s == null) {%n", sourceVariable); + buffer.incrementIndent(); + Versions mandatoryVersions = field.versions().subtract(field.taggedVersions()); + VersionConditional.forVersions(mandatoryVersions, curVersions). + ifMember(__ -> { + buffer.printf("throw new RuntimeException(\"%s: unable to locate " + + "field \'%s\', which is mandatory in version \" + _version);%n", + className, field.camelCaseName()); + }). + ifNotMember(__ -> { + buffer.printf("_object.%s = %s;%n", field.camelCaseName(), + field.fieldDefault(headerGenerator, structRegistry)); + }). + generate(buffer); + buffer.decrementIndent(); + buffer.printf("} else {%n"); + buffer.incrementIndent(); + VersionConditional.forVersions(struct.versions(), curVersions). + ifMember(presentVersions -> { + generateTargetFromJson(new Target(field, + sourceVariable, + className, + input -> String.format("_object.%s = %s", field.camelCaseName(), input)), + curVersions); + }).ifNotMember(__ -> { + buffer.printf("throw new RuntimeException(\"%s: field \'%s\' is not " + + "supported in version \" + _version);%n", + className, field.camelCaseName()); + }).generate(buffer); + buffer.decrementIndent(); + buffer.printf("}%n"); + } + buffer.printf("return _object;%n"); + buffer.decrementIndent(); + buffer.printf("}%n"); + } + + private void generateTargetFromJson(Target target, Versions curVersions) { + if (target.field().type() instanceof FieldType.BoolFieldType) { + buffer.printf("if (!%s.isBoolean()) {%n", target.sourceVariable()); + buffer.incrementIndent(); + buffer.printf("throw new RuntimeException(\"%s expected Boolean type, " + + "but got \" + _node.getNodeType());%n", target.humanReadableName()); + buffer.decrementIndent(); + buffer.printf("}%n"); + buffer.printf("%s;%n", target.assignmentStatement( + target.sourceVariable() + ".asBoolean()")); + } else if (target.field().type() instanceof FieldType.Int8FieldType) { + headerGenerator.addImport(MessageGenerator.MESSAGE_UTIL_CLASS); + buffer.printf("%s;%n", target.assignmentStatement( + String.format("MessageUtil.jsonNodeToByte(%s, \"%s\")", + target.sourceVariable(), target.humanReadableName()))); + } else if (target.field().type() instanceof FieldType.Int16FieldType) { + headerGenerator.addImport(MessageGenerator.MESSAGE_UTIL_CLASS); + buffer.printf("%s;%n", target.assignmentStatement( + String.format("MessageUtil.jsonNodeToShort(%s, \"%s\")", + target.sourceVariable(), target.humanReadableName()))); + } else if (target.field().type() instanceof FieldType.Uint16FieldType) { + headerGenerator.addImport(MessageGenerator.MESSAGE_UTIL_CLASS); + buffer.printf("%s;%n", target.assignmentStatement( + String.format("MessageUtil.jsonNodeToUnsignedShort(%s, \"%s\")", + target.sourceVariable(), target.humanReadableName()))); + } else if (target.field().type() instanceof FieldType.Int32FieldType) { + headerGenerator.addImport(MessageGenerator.MESSAGE_UTIL_CLASS); + buffer.printf("%s;%n", target.assignmentStatement( + String.format("MessageUtil.jsonNodeToInt(%s, \"%s\")", + target.sourceVariable(), target.humanReadableName()))); + } else if (target.field().type() instanceof FieldType.Int64FieldType) { + headerGenerator.addImport(MessageGenerator.MESSAGE_UTIL_CLASS); + 
buffer.printf("%s;%n", target.assignmentStatement( + String.format("MessageUtil.jsonNodeToLong(%s, \"%s\")", + target.sourceVariable(), target.humanReadableName()))); + } else if (target.field().type() instanceof FieldType.UUIDFieldType) { + buffer.printf("if (!%s.isTextual()) {%n", target.sourceVariable()); + buffer.incrementIndent(); + buffer.printf("throw new RuntimeException(\"%s expected a JSON string " + + "type, but got \" + _node.getNodeType());%n", target.humanReadableName()); + buffer.decrementIndent(); + buffer.printf("}%n"); + headerGenerator.addImport(MessageGenerator.UUID_CLASS); + buffer.printf("%s;%n", target.assignmentStatement(String.format( + "Uuid.fromString(%s.asText())", target.sourceVariable()))); + } else if (target.field().type() instanceof FieldType.Float64FieldType) { + headerGenerator.addImport(MessageGenerator.MESSAGE_UTIL_CLASS); + buffer.printf("%s;%n", target.assignmentStatement( + String.format("MessageUtil.jsonNodeToDouble(%s, \"%s\")", + target.sourceVariable(), target.humanReadableName()))); + } else { + // Handle the variable length types. All of them are potentially + // nullable, so handle that here. + IsNullConditional.forName(target.sourceVariable()). + nullableVersions(target.field().nullableVersions()). + possibleVersions(curVersions). + conditionalGenerator((name, negated) -> + String.format("%s%s.isNull()", negated ? "!" : "", name)). + ifNull(() -> { + buffer.printf("%s;%n", target.assignmentStatement("null")); + }). + ifShouldNotBeNull(() -> { + generateVariableLengthTargetFromJson(target, curVersions); + }). + generate(buffer); + } + } + + private void generateVariableLengthTargetFromJson(Target target, Versions curVersions) { + if (target.field().type().isString()) { + buffer.printf("if (!%s.isTextual()) {%n", target.sourceVariable()); + buffer.incrementIndent(); + buffer.printf("throw new RuntimeException(\"%s expected a string " + + "type, but got \" + _node.getNodeType());%n", target.humanReadableName()); + buffer.decrementIndent(); + buffer.printf("}%n"); + buffer.printf("%s;%n", target.assignmentStatement( + String.format("%s.asText()", target.sourceVariable()))); + } else if (target.field().type().isBytes()) { + headerGenerator.addImport(MessageGenerator.MESSAGE_UTIL_CLASS); + if (target.field().zeroCopy()) { + headerGenerator.addImport(MessageGenerator.BYTE_BUFFER_CLASS); + buffer.printf("%s;%n", target.assignmentStatement( + String.format("ByteBuffer.wrap(MessageUtil.jsonNodeToBinary(%s, \"%s\"))", + target.sourceVariable(), target.humanReadableName()))); + } else { + buffer.printf("%s;%n", target.assignmentStatement( + String.format("MessageUtil.jsonNodeToBinary(%s, \"%s\")", + target.sourceVariable(), target.humanReadableName()))); + } + } else if (target.field().type().isRecords()) { + headerGenerator.addImport(MessageGenerator.MESSAGE_UTIL_CLASS); + headerGenerator.addImport(MessageGenerator.BYTE_BUFFER_CLASS); + headerGenerator.addImport(MessageGenerator.MEMORY_RECORDS_CLASS); + buffer.printf("%s;%n", target.assignmentStatement( + String.format("MemoryRecords.readableRecords(ByteBuffer.wrap(MessageUtil.jsonNodeToBinary(%s, \"%s\")))", + target.sourceVariable(), target.humanReadableName()))); + } else if (target.field().type().isArray()) { + buffer.printf("if (!%s.isArray()) {%n", target.sourceVariable()); + buffer.incrementIndent(); + buffer.printf("throw new RuntimeException(\"%s expected a JSON " + + "array, but got \" + _node.getNodeType());%n", target.humanReadableName()); + buffer.decrementIndent(); + buffer.printf("}%n"); 
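For orientation, the array handling in this branch produces generated read code of roughly the following shape (field, class, and collection names are hypothetical, and the nullable-element wrapping is omitted):

    // Sketch of generated JsonConverter read logic for a keyed struct-array field "topics".
    JsonNode _topicsNode = _node.get("topics");
    if (!_topicsNode.isArray()) {
        throw new RuntimeException("topics expected a JSON array, but got " + _node.getNodeType());
    }
    TopicCollection _collection = new TopicCollection(_topicsNode.size());
    _object.topics = _collection;
    for (JsonNode _element : _topicsNode) {
        _collection.add(TopicJsonConverter.read(_element, _version));
    }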
+ String type = target.field().concreteJavaType(headerGenerator, structRegistry); + buffer.printf("%s _collection = new %s(%s.size());%n", type, type, target.sourceVariable()); + buffer.printf("%s;%n", target.assignmentStatement("_collection")); + headerGenerator.addImport(MessageGenerator.JSON_NODE_CLASS); + buffer.printf("for (JsonNode _element : %s) {%n", target.sourceVariable()); + buffer.incrementIndent(); + generateTargetFromJson(target.arrayElementTarget( + input -> String.format("_collection.add(%s)", input)), + curVersions); + buffer.decrementIndent(); + buffer.printf("}%n"); + } else if (target.field().type().isStruct()) { + buffer.printf("%s;%n", target.assignmentStatement( + String.format("%s%s.read(%s, _version)", + target.field().type().toString(), SUFFIX, target.sourceVariable()))); + } else { + throw new RuntimeException("Unexpected type " + target.field().type()); + } + } + + private void generateOverloadWrite(String className) { + buffer.printf("public static JsonNode write(%s _object, short _version) {%n", + className); + buffer.incrementIndent(); + buffer.printf("return write(_object, _version, true);%n"); + buffer.decrementIndent(); + buffer.printf("}%n"); + } + + private void generateWrite(String className, + StructSpec struct, + Versions parentVersions) { + headerGenerator.addImport(MessageGenerator.JSON_NODE_CLASS); + buffer.printf("public static JsonNode write(%s _object, short _version, boolean _serializeRecords) {%n", + className); + buffer.incrementIndent(); + VersionConditional.forVersions(struct.versions(), parentVersions). + allowMembershipCheckAlwaysFalse(false). + ifNotMember(__ -> { + headerGenerator.addImport(MessageGenerator.UNSUPPORTED_VERSION_EXCEPTION_CLASS); + buffer.printf("throw new UnsupportedVersionException(\"Can't write " + + "version \" + _version + \" of %s\");%n", className); + }). + generate(buffer); + Versions curVersions = parentVersions.intersect(struct.versions()); + headerGenerator.addImport(MessageGenerator.OBJECT_NODE_CLASS); + headerGenerator.addImport(MessageGenerator.JSON_NODE_FACTORY_CLASS); + buffer.printf("ObjectNode _node = new ObjectNode(JsonNodeFactory.instance);%n"); + for (FieldSpec field : struct.fields()) { + Target target = new Target(field, + String.format("_object.%s", field.camelCaseName()), + field.camelCaseName(), + input -> String.format("_node.set(\"%s\", %s)", field.camelCaseName(), input)); + VersionConditional cond = VersionConditional.forVersions(field.versions(), curVersions). + ifMember(presentVersions -> { + VersionConditional.forVersions(field.taggedVersions(), presentVersions). + ifMember(presentAndTaggedVersions -> { + field.generateNonDefaultValueCheck(headerGenerator, + structRegistry, buffer, "_object.", field.nullableVersions()); + buffer.incrementIndent(); + if (field.defaultString().equals("null")) { + // If the default was null, and we already checked that this field was not + // the default, we can omit further null checks. + generateTargetToJson(target.nonNullableCopy(), presentAndTaggedVersions); + } else { + generateTargetToJson(target, presentAndTaggedVersions); + } + buffer.decrementIndent(); + buffer.printf("}%n"); + }). + ifNotMember(presentAndNotTaggedVersions -> { + generateTargetToJson(target, presentAndNotTaggedVersions); + }). 
+ generate(buffer); + }); + if (!field.ignorable()) { + cond.ifNotMember(__ -> { + field.generateNonIgnorableFieldCheck(headerGenerator, + structRegistry, "_object.", buffer); + }); + } + cond.generate(buffer); + } + buffer.printf("return _node;%n"); + buffer.decrementIndent(); + buffer.printf("}%n"); + } + + private void generateTargetToJson(Target target, Versions versions) { + if (target.field().type() instanceof FieldType.BoolFieldType) { + headerGenerator.addImport(MessageGenerator.BOOLEAN_NODE_CLASS); + buffer.printf("%s;%n", target.assignmentStatement( + String.format("BooleanNode.valueOf(%s)", target.sourceVariable()))); + } else if ((target.field().type() instanceof FieldType.Int8FieldType) || + (target.field().type() instanceof FieldType.Int16FieldType)) { + headerGenerator.addImport(MessageGenerator.SHORT_NODE_CLASS); + buffer.printf("%s;%n", target.assignmentStatement( + String.format("new ShortNode(%s)", target.sourceVariable()))); + } else if ((target.field().type() instanceof FieldType.Int32FieldType) || + (target.field().type() instanceof FieldType.Uint16FieldType)) { + headerGenerator.addImport(MessageGenerator.INT_NODE_CLASS); + buffer.printf("%s;%n", target.assignmentStatement( + String.format("new IntNode(%s)", target.sourceVariable()))); + } else if (target.field().type() instanceof FieldType.Int64FieldType) { + headerGenerator.addImport(MessageGenerator.LONG_NODE_CLASS); + buffer.printf("%s;%n", target.assignmentStatement( + String.format("new LongNode(%s)", target.sourceVariable()))); + } else if (target.field().type() instanceof FieldType.UUIDFieldType) { + headerGenerator.addImport(MessageGenerator.TEXT_NODE_CLASS); + buffer.printf("%s;%n", target.assignmentStatement( + String.format("new TextNode(%s.toString())", target.sourceVariable()))); + } else if (target.field().type() instanceof FieldType.Float64FieldType) { + headerGenerator.addImport(MessageGenerator.DOUBLE_NODE_CLASS); + buffer.printf("%s;%n", target.assignmentStatement( + String.format("new DoubleNode(%s)", target.sourceVariable()))); + } else { + // Handle the variable length types. All of them are potentially + // nullable, so handle that here. + IsNullConditional.forName(target.sourceVariable()). + nullableVersions(target.field().nullableVersions()). + possibleVersions(versions). + conditionalGenerator((name, negated) -> + String.format("%s %s= null", name, negated ? "!" : "=")). + ifNull(() -> { + headerGenerator.addImport(MessageGenerator.NULL_NODE_CLASS); + buffer.printf("%s;%n", target.assignmentStatement("NullNode.instance")); + }). + ifShouldNotBeNull(() -> { + generateVariableLengthTargetToJson(target, versions); + }). 
+ generate(buffer); + } + } + + private void generateVariableLengthTargetToJson(Target target, Versions versions) { + if (target.field().type().isString()) { + headerGenerator.addImport(MessageGenerator.TEXT_NODE_CLASS); + buffer.printf("%s;%n", target.assignmentStatement( + String.format("new TextNode(%s)", target.sourceVariable()))); + } else if (target.field().type().isBytes()) { + headerGenerator.addImport(MessageGenerator.BINARY_NODE_CLASS); + if (target.field().zeroCopy()) { + headerGenerator.addImport(MessageGenerator.MESSAGE_UTIL_CLASS); + buffer.printf("%s;%n", target.assignmentStatement( + String.format("new BinaryNode(MessageUtil.byteBufferToArray(%s))", + target.sourceVariable()))); + } else { + headerGenerator.addImport(MessageGenerator.ARRAYS_CLASS); + buffer.printf("%s;%n", target.assignmentStatement( + String.format("new BinaryNode(Arrays.copyOf(%s, %s.length))", + target.sourceVariable(), target.sourceVariable()))); + } + } else if (target.field().type().isRecords()) { + headerGenerator.addImport(MessageGenerator.BINARY_NODE_CLASS); + headerGenerator.addImport(MessageGenerator.INT_NODE_CLASS); + // KIP-673: When logging requests/responses, we do not serialize the record, instead we + // output its sizeInBytes, because outputting the bytes is not very useful and can be + // quite expensive. Otherwise, we will serialize the record. + buffer.printf("if (_serializeRecords) {%n"); + buffer.incrementIndent(); + buffer.printf("%s;%n", target.assignmentStatement("new BinaryNode(new byte[]{})")); + buffer.decrementIndent(); + buffer.printf("} else {%n"); + buffer.incrementIndent(); + buffer.printf("_node.set(\"%sSizeInBytes\", new IntNode(%s.sizeInBytes()));%n", + target.field().camelCaseName(), + target.sourceVariable()); + buffer.decrementIndent(); + buffer.printf("}%n"); + } else if (target.field().type().isArray()) { + headerGenerator.addImport(MessageGenerator.ARRAY_NODE_CLASS); + headerGenerator.addImport(MessageGenerator.JSON_NODE_FACTORY_CLASS); + FieldType.ArrayType arrayType = (FieldType.ArrayType) target.field().type(); + FieldType elementType = arrayType.elementType(); + String arrayInstanceName = String.format("_%sArray", + target.field().camelCaseName()); + buffer.printf("ArrayNode %s = new ArrayNode(JsonNodeFactory.instance);%n", + arrayInstanceName); + buffer.printf("for (%s _element : %s) {%n", + elementType.getBoxedJavaType(headerGenerator), target.sourceVariable()); + buffer.incrementIndent(); + generateTargetToJson(target.arrayElementTarget( + input -> String.format("%s.add(%s)", arrayInstanceName, input)), + versions); + buffer.decrementIndent(); + buffer.printf("}%n"); + buffer.printf("%s;%n", target.assignmentStatement(arrayInstanceName)); + } else if (target.field().type().isStruct()) { + buffer.printf("%s;%n", target.assignmentStatement( + String.format("%sJsonConverter.write(%s, _version, _serializeRecords)", + target.field().type().toString(), target.sourceVariable()))); + } else { + throw new RuntimeException("unknown type " + target.field().type()); + } + } + +} diff --git a/generator/src/main/java/org/apache/kafka/message/MessageClassGenerator.java b/generator/src/main/java/org/apache/kafka/message/MessageClassGenerator.java new file mode 100644 index 0000000..8311756 --- /dev/null +++ b/generator/src/main/java/org/apache/kafka/message/MessageClassGenerator.java @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.message; + +import java.io.BufferedWriter; + +public interface MessageClassGenerator { + /** + * The short name of the converter class we are generating. For example, + * FetchRequestDataJsonConverter.java. + */ + String outputName(MessageSpec spec); + + /** + * Generate the convertere, and then write it out. + * + * @param spec The message to generate a converter for. + * @param writer The writer to write out the state to. + */ + void generateAndWrite(MessageSpec spec, BufferedWriter writer) throws Exception; +} diff --git a/generator/src/main/java/org/apache/kafka/message/MessageDataGenerator.java b/generator/src/main/java/org/apache/kafka/message/MessageDataGenerator.java new file mode 100644 index 0000000..b9923ee --- /dev/null +++ b/generator/src/main/java/org/apache/kafka/message/MessageDataGenerator.java @@ -0,0 +1,1615 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.message; + +import java.io.BufferedWriter; +import java.util.HashSet; +import java.util.Iterator; +import java.util.LinkedHashSet; +import java.util.Optional; +import java.util.Set; +import java.util.TreeMap; +import java.util.stream.Collectors; + +/** + * Generates Kafka MessageData classes. 
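For orientation, the classes produced here follow roughly the outline below. Names are hypothetical, most generated members are elided, and the right-hand sides in the no-argument constructor are exactly the strings that FieldSpec.fieldDefault() returns:

    // Abbreviated sketch of a generated top-level *Data class.
    public class ExampleRequestData implements ApiMessage {
        int throttleTimeMs;                              // int32, spec default ""
        String clusterId;                                // nullable string, spec default "null"
        TopicCollection topics;                          // "[]Topic" array with mapKey fields
        private List<RawTaggedField> _unknownTaggedFields;

        public ExampleRequestData(Readable _readable, short _version) {
            read(_readable, _version);
        }

        public ExampleRequestData() {
            this.throttleTimeMs = 0;
            this.clusterId = null;
            this.topics = new TopicCollection(0);
        }

        @Override
        public short apiKey() {
            return 42;   // the message's apiKey, or -1 if none was specified
        }

        // lowestSupportedVersion(), highestSupportedVersion(), read(), write(),
        // message-size computation, equals(), hashCode(), duplicate(), toString(),
        // and the field accessors and mutators are generated after this point.
    }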
+ */ +public final class MessageDataGenerator implements MessageClassGenerator { + private final StructRegistry structRegistry; + private final HeaderGenerator headerGenerator; + private final SchemaGenerator schemaGenerator; + private final CodeBuffer buffer; + private Versions messageFlexibleVersions; + + MessageDataGenerator(String packageName) { + this.structRegistry = new StructRegistry(); + this.headerGenerator = new HeaderGenerator(packageName); + this.schemaGenerator = new SchemaGenerator(headerGenerator, structRegistry); + this.buffer = new CodeBuffer(); + } + + @Override + public String outputName(MessageSpec spec) { + return spec.dataClassName(); + } + + @Override + public void generateAndWrite(MessageSpec message, BufferedWriter writer) throws Exception { + generate(message); + write(writer); + } + + void generate(MessageSpec message) throws Exception { + if (message.struct().versions().contains(Short.MAX_VALUE)) { + throw new RuntimeException("Message " + message.name() + " does " + + "not specify a maximum version."); + } + structRegistry.register(message); + schemaGenerator.generateSchemas(message); + messageFlexibleVersions = message.flexibleVersions(); + generateClass(Optional.of(message), + message.dataClassName(), + message.struct(), + message.struct().versions()); + headerGenerator.generate(); + } + + void write(BufferedWriter writer) throws Exception { + headerGenerator.buffer().write(writer); + buffer.write(writer); + } + + private void generateClass(Optional topLevelMessageSpec, + String className, + StructSpec struct, + Versions parentVersions) throws Exception { + buffer.printf("%n"); + boolean isTopLevel = topLevelMessageSpec.isPresent(); + boolean isSetElement = struct.hasKeys(); // Check if the class is inside a set. + if (isTopLevel && isSetElement) { + throw new RuntimeException("Cannot set mapKey on top level fields."); + } + generateClassHeader(className, isTopLevel, isSetElement); + buffer.incrementIndent(); + generateFieldDeclarations(struct, isSetElement); + buffer.printf("%n"); + schemaGenerator.writeSchema(className, buffer); + generateClassConstructors(className, struct, isSetElement); + buffer.printf("%n"); + if (isTopLevel) { + generateShortAccessor("apiKey", topLevelMessageSpec.get().apiKey().orElse((short) -1)); + } + buffer.printf("%n"); + generateShortAccessor("lowestSupportedVersion", parentVersions.lowest()); + buffer.printf("%n"); + generateShortAccessor("highestSupportedVersion", parentVersions.highest()); + buffer.printf("%n"); + generateClassReader(className, struct, parentVersions); + buffer.printf("%n"); + generateClassWriter(className, struct, parentVersions); + buffer.printf("%n"); + generateClassMessageSize(className, struct, parentVersions); + if (isSetElement) { + buffer.printf("%n"); + generateClassEquals(className, struct, true); + } + buffer.printf("%n"); + generateClassEquals(className, struct, false); + buffer.printf("%n"); + generateClassHashCode(struct, isSetElement); + buffer.printf("%n"); + generateClassDuplicate(className, struct); + buffer.printf("%n"); + generateClassToString(className, struct); + generateFieldAccessors(struct, isSetElement); + buffer.printf("%n"); + generateUnknownTaggedFieldsAccessor(); + generateFieldMutators(struct, className, isSetElement); + + if (!isTopLevel) { + buffer.decrementIndent(); + buffer.printf("}%n"); + } + generateSubclasses(className, struct, parentVersions, isSetElement); + if (isTopLevel) { + for (Iterator iter = structRegistry.commonStructs(); iter.hasNext(); ) { + StructSpec 
commonStruct = iter.next(); + generateClass(Optional.empty(), + commonStruct.name(), + commonStruct, + commonStruct.versions()); + } + buffer.decrementIndent(); + buffer.printf("}%n"); + } + } + + private void generateClassHeader(String className, boolean isTopLevel, + boolean isSetElement) { + Set implementedInterfaces = new HashSet<>(); + if (isTopLevel) { + implementedInterfaces.add("ApiMessage"); + headerGenerator.addImport(MessageGenerator.API_MESSAGE_CLASS); + } else { + implementedInterfaces.add("Message"); + headerGenerator.addImport(MessageGenerator.MESSAGE_CLASS); + } + if (isSetElement) { + headerGenerator.addImport(MessageGenerator.IMPLICIT_LINKED_HASH_MULTI_COLLECTION_CLASS); + implementedInterfaces.add("ImplicitLinkedHashMultiCollection.Element"); + } + Set classModifiers = new LinkedHashSet<>(); + classModifiers.add("public"); + if (!isTopLevel) { + classModifiers.add("static"); + } + buffer.printf("%s class %s implements %s {%n", + String.join(" ", classModifiers), + className, + String.join(", ", implementedInterfaces)); + } + + private void generateSubclasses(String className, StructSpec struct, + Versions parentVersions, boolean isSetElement) throws Exception { + for (FieldSpec field : struct.fields()) { + if (field.type().isStructArray()) { + FieldType.ArrayType arrayType = (FieldType.ArrayType) field.type(); + if (!structRegistry.commonStructNames().contains(arrayType.elementName())) { + generateClass(Optional.empty(), + arrayType.elementType().toString(), + structRegistry.findStruct(field), + parentVersions.intersect(struct.versions())); + } + } else if (field.type().isStruct()) { + if (!structRegistry.commonStructNames().contains(field.typeString())) { + generateClass(Optional.empty(), + field.typeString(), + structRegistry.findStruct(field), + parentVersions.intersect(struct.versions())); + } + } + } + if (isSetElement) { + generateHashSet(className, struct); + } + } + + private void generateHashSet(String className, StructSpec struct) { + buffer.printf("%n"); + headerGenerator.addImport(MessageGenerator.IMPLICIT_LINKED_HASH_MULTI_COLLECTION_CLASS); + buffer.printf("public static class %s extends ImplicitLinkedHashMultiCollection<%s> {%n", + FieldSpec.collectionType(className), className); + buffer.incrementIndent(); + generateHashSetZeroArgConstructor(className); + buffer.printf("%n"); + generateHashSetSizeArgConstructor(className); + buffer.printf("%n"); + generateHashSetIteratorConstructor(className); + buffer.printf("%n"); + generateHashSetFindMethod(className, struct); + buffer.printf("%n"); + generateHashSetFindAllMethod(className, struct); + buffer.printf("%n"); + generateCollectionDuplicateMethod(className); + buffer.decrementIndent(); + buffer.printf("}%n"); + } + + private void generateHashSetZeroArgConstructor(String className) { + buffer.printf("public %s() {%n", FieldSpec.collectionType(className)); + buffer.incrementIndent(); + buffer.printf("super();%n"); + buffer.decrementIndent(); + buffer.printf("}%n"); + } + + private void generateHashSetSizeArgConstructor(String className) { + buffer.printf("public %s(int expectedNumElements) {%n", + FieldSpec.collectionType(className)); + buffer.incrementIndent(); + buffer.printf("super(expectedNumElements);%n"); + buffer.decrementIndent(); + buffer.printf("}%n"); + } + + private void generateHashSetIteratorConstructor(String className) { + headerGenerator.addImport(MessageGenerator.ITERATOR_CLASS); + buffer.printf("public %s(Iterator<%s> iterator) {%n", + FieldSpec.collectionType(className), className); + 
buffer.incrementIndent(); + buffer.printf("super(iterator);%n"); + buffer.decrementIndent(); + buffer.printf("}%n"); + } + + private void generateHashSetFindMethod(String className, StructSpec struct) { + headerGenerator.addImport(MessageGenerator.LIST_CLASS); + buffer.printf("public %s find(%s) {%n", className, + commaSeparatedHashSetFieldAndTypes(struct)); + buffer.incrementIndent(); + generateKeyElement(className, struct); + headerGenerator.addImport(MessageGenerator.IMPLICIT_LINKED_HASH_MULTI_COLLECTION_CLASS); + buffer.printf("return find(_key);%n"); + buffer.decrementIndent(); + buffer.printf("}%n"); + } + + private void generateHashSetFindAllMethod(String className, StructSpec struct) { + headerGenerator.addImport(MessageGenerator.LIST_CLASS); + buffer.printf("public List<%s> findAll(%s) {%n", className, + commaSeparatedHashSetFieldAndTypes(struct)); + buffer.incrementIndent(); + generateKeyElement(className, struct); + headerGenerator.addImport(MessageGenerator.IMPLICIT_LINKED_HASH_MULTI_COLLECTION_CLASS); + buffer.printf("return findAll(_key);%n"); + buffer.decrementIndent(); + buffer.printf("}%n"); + } + + private void generateKeyElement(String className, StructSpec struct) { + buffer.printf("%s _key = new %s();%n", className, className); + for (FieldSpec field : struct.fields()) { + if (field.mapKey()) { + buffer.printf("_key.set%s(%s);%n", + field.capitalizedCamelCaseName(), + field.camelCaseName()); + } + } + } + + private String commaSeparatedHashSetFieldAndTypes(StructSpec struct) { + return struct.fields().stream(). + filter(f -> f.mapKey()). + map(f -> String.format("%s %s", + f.concreteJavaType(headerGenerator, structRegistry), f.camelCaseName())). + collect(Collectors.joining(", ")); + } + + private void generateCollectionDuplicateMethod(String className) { + headerGenerator.addImport(MessageGenerator.LIST_CLASS); + buffer.printf("public %s duplicate() {%n", FieldSpec.collectionType(className)); + buffer.incrementIndent(); + buffer.printf("%s _duplicate = new %s(size());%n", + FieldSpec.collectionType(className), FieldSpec.collectionType(className)); + buffer.printf("for (%s _element : this) {%n", className); + buffer.incrementIndent(); + buffer.printf("_duplicate.add(_element.duplicate());%n"); + buffer.decrementIndent(); + buffer.printf("}%n"); + buffer.printf("return _duplicate;%n"); + buffer.decrementIndent(); + buffer.printf("}%n"); + } + + private void generateFieldDeclarations(StructSpec struct, boolean isSetElement) { + for (FieldSpec field : struct.fields()) { + generateFieldDeclaration(field); + } + headerGenerator.addImport(MessageGenerator.LIST_CLASS); + headerGenerator.addImport(MessageGenerator.RAW_TAGGED_FIELD_CLASS); + buffer.printf("private List _unknownTaggedFields;%n"); + if (isSetElement) { + buffer.printf("private int next;%n"); + buffer.printf("private int prev;%n"); + } + } + + private void generateFieldDeclaration(FieldSpec field) { + buffer.printf("%s %s;%n", + field.fieldAbstractJavaType(headerGenerator, structRegistry), + field.camelCaseName()); + } + + private void generateFieldAccessors(StructSpec struct, boolean isSetElement) { + for (FieldSpec field : struct.fields()) { + generateFieldAccessor(field); + } + if (isSetElement) { + buffer.printf("%n"); + buffer.printf("@Override%n"); + generateAccessor("int", "next", "next"); + + buffer.printf("%n"); + buffer.printf("@Override%n"); + generateAccessor("int", "prev", "prev"); + } + } + + private void generateUnknownTaggedFieldsAccessor() { + buffer.printf("@Override%n"); + 
headerGenerator.addImport(MessageGenerator.LIST_CLASS); + headerGenerator.addImport(MessageGenerator.RAW_TAGGED_FIELD_CLASS); + buffer.printf("public List unknownTaggedFields() {%n"); + buffer.incrementIndent(); + // Optimize _unknownTaggedFields by not creating a new list object + // unless we need it. + buffer.printf("if (_unknownTaggedFields == null) {%n"); + buffer.incrementIndent(); + headerGenerator.addImport(MessageGenerator.ARRAYLIST_CLASS); + buffer.printf("_unknownTaggedFields = new ArrayList<>(0);%n"); + buffer.decrementIndent(); + buffer.printf("}%n"); + buffer.printf("return _unknownTaggedFields;%n"); + buffer.decrementIndent(); + buffer.printf("}%n"); + + } + + private void generateFieldMutators(StructSpec struct, String className, + boolean isSetElement) { + for (FieldSpec field : struct.fields()) { + generateFieldMutator(className, field); + } + if (isSetElement) { + buffer.printf("%n"); + buffer.printf("@Override%n"); + generateSetter("int", "setNext", "next"); + + buffer.printf("%n"); + buffer.printf("@Override%n"); + generateSetter("int", "setPrev", "prev"); + } + } + + private void generateClassConstructors(String className, StructSpec struct, boolean isSetElement) { + headerGenerator.addImport(MessageGenerator.READABLE_CLASS); + buffer.printf("public %s(Readable _readable, short _version) {%n", className); + buffer.incrementIndent(); + buffer.printf("read(_readable, _version);%n"); + generateConstructorEpilogue(isSetElement); + buffer.decrementIndent(); + buffer.printf("}%n"); + buffer.printf("%n"); + + buffer.printf("public %s() {%n", className); + buffer.incrementIndent(); + for (FieldSpec field : struct.fields()) { + buffer.printf("this.%s = %s;%n", + field.camelCaseName(), + field.fieldDefault(headerGenerator, structRegistry)); + } + generateConstructorEpilogue(isSetElement); + buffer.decrementIndent(); + buffer.printf("}%n"); + } + + private void generateConstructorEpilogue(boolean isSetElement) { + if (isSetElement) { + headerGenerator.addImport(MessageGenerator.IMPLICIT_LINKED_HASH_COLLECTION_CLASS); + buffer.printf("this.prev = ImplicitLinkedHashCollection.INVALID_INDEX;%n"); + buffer.printf("this.next = ImplicitLinkedHashCollection.INVALID_INDEX;%n"); + } + } + + private void generateShortAccessor(String name, short val) { + buffer.printf("@Override%n"); + buffer.printf("public short %s() {%n", name); + buffer.incrementIndent(); + buffer.printf("return %d;%n", val); + buffer.decrementIndent(); + buffer.printf("}%n"); + } + + private void generateClassReader(String className, StructSpec struct, + Versions parentVersions) { + headerGenerator.addImport(MessageGenerator.READABLE_CLASS); + buffer.printf("@Override%n"); + buffer.printf("public void read(Readable _readable, short _version) {%n"); + buffer.incrementIndent(); + VersionConditional.forVersions(parentVersions, struct.versions()). + allowMembershipCheckAlwaysFalse(false). + ifNotMember(__ -> { + headerGenerator.addImport(MessageGenerator.UNSUPPORTED_VERSION_EXCEPTION_CLASS); + buffer.printf("throw new UnsupportedVersionException(\"Can't read " + + "version \" + _version + \" of %s\");%n", className); + }). 
+ generate(buffer); + Versions curVersions = parentVersions.intersect(struct.versions()); + for (FieldSpec field : struct.fields()) { + Versions fieldFlexibleVersions = fieldFlexibleVersions(field); + if (!field.taggedVersions().intersect(fieldFlexibleVersions).equals(field.taggedVersions())) { + throw new RuntimeException("Field " + field.name() + " specifies tagged " + + "versions " + field.taggedVersions() + " that are not a subset of the " + + "flexible versions " + fieldFlexibleVersions); + } + Versions mandatoryVersions = field.versions().subtract(field.taggedVersions()); + VersionConditional.forVersions(mandatoryVersions, curVersions). + alwaysEmitBlockScope(field.type().isVariableLength()). + ifNotMember(__ -> { + // If the field is not present, or is tagged, set it to its default here. + buffer.printf("this.%s = %s;%n", field.camelCaseName(), + field.fieldDefault(headerGenerator, structRegistry)); + }). + ifMember(presentAndUntaggedVersions -> { + if (field.type().isVariableLength() && !field.type().isStruct()) { + ClauseGenerator callGenerateVariableLengthReader = versions -> { + generateVariableLengthReader(fieldFlexibleVersions(field), + field.camelCaseName(), + field.type(), + versions, + field.nullableVersions(), + String.format("this.%s = ", field.camelCaseName()), + String.format(";%n"), + structRegistry.isStructArrayWithKeys(field), + field.zeroCopy()); + }; + // For arrays where the field type needs to be serialized differently in flexible + // versions, lift the flexible version check outside of the array. + // This may mean generating two separate 'for' loops-- one for flexible + // versions, and one for regular versions. + if (field.type().isArray() && + ((FieldType.ArrayType) field.type()).elementType(). + serializationIsDifferentInFlexibleVersions()) { + VersionConditional.forVersions(fieldFlexibleVersions(field), + presentAndUntaggedVersions). + ifMember(callGenerateVariableLengthReader). + ifNotMember(callGenerateVariableLengthReader). + generate(buffer); + } else { + callGenerateVariableLengthReader.generate(presentAndUntaggedVersions); + } + } else { + buffer.printf("this.%s = %s;%n", field.camelCaseName(), + primitiveReadExpression(field.type())); + } + }). + generate(buffer); + } + buffer.printf("this._unknownTaggedFields = null;%n"); + VersionConditional.forVersions(messageFlexibleVersions, curVersions). + ifMember(curFlexibleVersions -> { + buffer.printf("int _numTaggedFields = _readable.readUnsignedVarint();%n"); + buffer.printf("for (int _i = 0; _i < _numTaggedFields; _i++) {%n"); + buffer.incrementIndent(); + buffer.printf("int _tag = _readable.readUnsignedVarint();%n"); + buffer.printf("int _size = _readable.readUnsignedVarint();%n"); + buffer.printf("switch (_tag) {%n"); + buffer.incrementIndent(); + for (FieldSpec field : struct.fields()) { + Versions validTaggedVersions = field.versions().intersect(field.taggedVersions()); + if (!validTaggedVersions.empty()) { + if (!field.tag().isPresent()) { + throw new RuntimeException("Field " + field.name() + " has tagged versions, but no tag."); + } + buffer.printf("case %d: {%n", field.tag().get()); + buffer.incrementIndent(); + VersionConditional.forVersions(validTaggedVersions, curFlexibleVersions). + ifMember(presentAndTaggedVersions -> { + if (field.type().isVariableLength() && !field.type().isStruct()) { + // All tagged fields are serialized using the new-style + // flexible versions serialization. 
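+ // Illustrative sketch (not part of the generated output): on the wire, each tagged
+ // field handled here is laid out as
+ //   [unsigned varint tag][unsigned varint size][value in the flexible encoding]
+ // which is why the surrounding loop first reads the _tag and _size varints.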
+ generateVariableLengthReader(fieldFlexibleVersions(field), + field.camelCaseName(), + field.type(), + presentAndTaggedVersions, + field.nullableVersions(), + String.format("this.%s = ", field.camelCaseName()), + String.format(";%n"), + structRegistry.isStructArrayWithKeys(field), + field.zeroCopy()); + } else { + buffer.printf("this.%s = %s;%n", field.camelCaseName(), + primitiveReadExpression(field.type())); + } + buffer.printf("break;%n"); + }). + ifNotMember(__ -> { + buffer.printf("throw new RuntimeException(\"Tag %d is not " + + "valid for version \" + _version);%n", field.tag().get()); + }). + generate(buffer); + buffer.decrementIndent(); + buffer.printf("}%n"); + } + } + buffer.printf("default:%n"); + buffer.incrementIndent(); + buffer.printf("this._unknownTaggedFields = _readable.readUnknownTaggedField(this._unknownTaggedFields, _tag, _size);%n"); + buffer.printf("break;%n"); + buffer.decrementIndent(); + buffer.decrementIndent(); + buffer.printf("}%n"); + buffer.decrementIndent(); + buffer.printf("}%n"); + }). + generate(buffer); + buffer.decrementIndent(); + buffer.printf("}%n"); + } + + private String primitiveReadExpression(FieldType type) { + if (type instanceof FieldType.BoolFieldType) { + return "_readable.readByte() != 0"; + } else if (type instanceof FieldType.Int8FieldType) { + return "_readable.readByte()"; + } else if (type instanceof FieldType.Int16FieldType) { + return "_readable.readShort()"; + } else if (type instanceof FieldType.Uint16FieldType) { + return "_readable.readUnsignedShort()"; + } else if (type instanceof FieldType.Int32FieldType) { + return "_readable.readInt()"; + } else if (type instanceof FieldType.Int64FieldType) { + return "_readable.readLong()"; + } else if (type instanceof FieldType.UUIDFieldType) { + return "_readable.readUuid()"; + } else if (type instanceof FieldType.Float64FieldType) { + return "_readable.readDouble()"; + } else if (type.isStruct()) { + return String.format("new %s(_readable, _version)", type.toString()); + } else { + throw new RuntimeException("Unsupported field type " + type); + } + } + + private void generateVariableLengthReader(Versions fieldFlexibleVersions, + String name, + FieldType type, + Versions possibleVersions, + Versions nullableVersions, + String assignmentPrefix, + String assignmentSuffix, + boolean isStructArrayWithKeys, + boolean zeroCopy) { + String lengthVar = type.isArray() ? "arrayLength" : "length"; + buffer.printf("int %s;%n", lengthVar); + VersionConditional.forVersions(fieldFlexibleVersions, possibleVersions). + ifMember(__ -> { + buffer.printf("%s = _readable.readUnsignedVarint() - 1;%n", lengthVar); + }). + ifNotMember(__ -> { + if (type.isString()) { + buffer.printf("%s = _readable.readShort();%n", lengthVar); + } else if (type.isBytes() || type.isArray() || type.isRecords()) { + buffer.printf("%s = _readable.readInt();%n", lengthVar); + } else { + throw new RuntimeException("Can't handle variable length type " + type); + } + }). + generate(buffer); + buffer.printf("if (%s < 0) {%n", lengthVar); + buffer.incrementIndent(); + VersionConditional.forVersions(nullableVersions, possibleVersions). + ifNotMember(__ -> { + buffer.printf("throw new RuntimeException(\"non-nullable field %s " + + "was serialized as null\");%n", name); + }). + ifMember(__ -> { + buffer.printf("%snull%s", assignmentPrefix, assignmentSuffix); + }). 
+ generate(buffer); + buffer.decrementIndent(); + if (type.isString()) { + buffer.printf("} else if (%s > 0x7fff) {%n", lengthVar); + buffer.incrementIndent(); + buffer.printf("throw new RuntimeException(\"string field %s " + + "had invalid length \" + %s);%n", name, lengthVar); + buffer.decrementIndent(); + } + buffer.printf("} else {%n"); + buffer.incrementIndent(); + if (type.isString()) { + buffer.printf("%s_readable.readString(%s)%s", + assignmentPrefix, lengthVar, assignmentSuffix); + } else if (type.isBytes()) { + if (zeroCopy) { + buffer.printf("%s_readable.readByteBuffer(%s)%s", + assignmentPrefix, lengthVar, assignmentSuffix); + } else { + buffer.printf("byte[] newBytes = new byte[%s];%n", lengthVar); + buffer.printf("_readable.readArray(newBytes);%n"); + buffer.printf("%snewBytes%s", assignmentPrefix, assignmentSuffix); + } + } else if (type.isRecords()) { + buffer.printf("%s_readable.readRecords(%s)%s", + assignmentPrefix, lengthVar, assignmentSuffix); + } else if (type.isArray()) { + FieldType.ArrayType arrayType = (FieldType.ArrayType) type; + if (isStructArrayWithKeys) { + headerGenerator.addImport(MessageGenerator.IMPLICIT_LINKED_HASH_MULTI_COLLECTION_CLASS); + buffer.printf("%s newCollection = new %s(%s);%n", + FieldSpec.collectionType(arrayType.elementType().toString()), + FieldSpec.collectionType(arrayType.elementType().toString()), lengthVar); + } else { + headerGenerator.addImport(MessageGenerator.ARRAYLIST_CLASS); + String boxedArrayType = + arrayType.elementType().getBoxedJavaType(headerGenerator); + buffer.printf("ArrayList<%s> newCollection = new ArrayList<>(%s);%n", boxedArrayType, lengthVar); + } + buffer.printf("for (int i = 0; i < %s; i++) {%n", lengthVar); + buffer.incrementIndent(); + if (arrayType.elementType().isArray()) { + throw new RuntimeException("Nested arrays are not supported. " + + "Use an array of structures containing another array."); + } else if (arrayType.elementType().isBytes() || arrayType.elementType().isString()) { + generateVariableLengthReader(fieldFlexibleVersions, + name + " element", + arrayType.elementType(), + possibleVersions, + Versions.NONE, + "newCollection.add(", + String.format(");%n"), + false, + false); + } else { + buffer.printf("newCollection.add(%s);%n", + primitiveReadExpression(arrayType.elementType())); + } + buffer.decrementIndent(); + buffer.printf("}%n"); + buffer.printf("%snewCollection%s", assignmentPrefix, assignmentSuffix); + } else { + throw new RuntimeException("Can't handle variable length type " + type); + } + buffer.decrementIndent(); + buffer.printf("}%n"); + } + + private void generateClassWriter(String className, StructSpec struct, + Versions parentVersions) { + headerGenerator.addImport(MessageGenerator.WRITABLE_CLASS); + headerGenerator.addImport(MessageGenerator.OBJECT_SERIALIZATION_CACHE_CLASS); + buffer.printf("@Override%n"); + buffer.printf("public void write(Writable _writable, ObjectSerializationCache _cache, short _version) {%n"); + buffer.incrementIndent(); + VersionConditional.forVersions(struct.versions(), parentVersions). + allowMembershipCheckAlwaysFalse(false). + ifNotMember(__ -> { + headerGenerator.addImport(MessageGenerator.UNSUPPORTED_VERSION_EXCEPTION_CLASS); + buffer.printf("throw new UnsupportedVersionException(\"Can't write " + + "version \" + _version + \" of %s\");%n", className); + }). 
+ generate(buffer); + buffer.printf("int _numTaggedFields = 0;%n"); + Versions curVersions = parentVersions.intersect(struct.versions()); + TreeMap taggedFields = new TreeMap<>(); + for (FieldSpec field : struct.fields()) { + VersionConditional cond = VersionConditional.forVersions(field.versions(), curVersions). + ifMember(presentVersions -> { + VersionConditional.forVersions(field.taggedVersions(), presentVersions). + ifNotMember(presentAndUntaggedVersions -> { + if (field.type().isVariableLength() && !field.type().isStruct()) { + ClauseGenerator callGenerateVariableLengthWriter = versions -> { + generateVariableLengthWriter(fieldFlexibleVersions(field), + field.camelCaseName(), + field.type(), + versions, + field.nullableVersions(), + field.zeroCopy()); + }; + // For arrays where the field type needs to be serialized differently in flexible + // versions, lift the flexible version check outside of the array. + // This may mean generating two separate 'for' loops-- one for flexible + // versions, and one for regular versions. + if (field.type().isArray() && + ((FieldType.ArrayType) field.type()).elementType(). + serializationIsDifferentInFlexibleVersions()) { + VersionConditional.forVersions(fieldFlexibleVersions(field), + presentAndUntaggedVersions). + ifMember(callGenerateVariableLengthWriter). + ifNotMember(callGenerateVariableLengthWriter). + generate(buffer); + } else { + callGenerateVariableLengthWriter.generate(presentAndUntaggedVersions); + } + } else { + buffer.printf("%s;%n", + primitiveWriteExpression(field.type(), field.camelCaseName())); + } + }). + ifMember(__ -> { + field.generateNonDefaultValueCheck(headerGenerator, + structRegistry, buffer, "this.", field.nullableVersions()); + buffer.incrementIndent(); + buffer.printf("_numTaggedFields++;%n"); + buffer.decrementIndent(); + buffer.printf("}%n"); + if (taggedFields.put(field.tag().get(), field) != null) { + throw new RuntimeException("Field " + field.name() + " has tag " + + field.tag() + ", but another field already used that tag."); + } + }). + generate(buffer); + }); + if (!field.ignorable()) { + cond.ifNotMember(__ -> { + field.generateNonIgnorableFieldCheck(headerGenerator, + structRegistry, "this.", buffer); + }); + } + cond.generate(buffer); + } + headerGenerator.addImport(MessageGenerator.RAW_TAGGED_FIELD_WRITER_CLASS); + buffer.printf("RawTaggedFieldWriter _rawWriter = RawTaggedFieldWriter.forFields(_unknownTaggedFields);%n"); + buffer.printf("_numTaggedFields += _rawWriter.numFields();%n"); + VersionConditional.forVersions(messageFlexibleVersions, curVersions). + ifNotMember(__ -> { + generateCheckForUnsupportedNumTaggedFields("_numTaggedFields > 0"); + }). + ifMember(__ -> { + buffer.printf("_writable.writeUnsignedVarint(_numTaggedFields);%n"); + int prevTag = -1; + for (FieldSpec field : taggedFields.values()) { + if (prevTag + 1 != field.tag().get()) { + buffer.printf("_rawWriter.writeRawTags(_writable, %d);%n", field.tag().get()); + } + VersionConditional. + forVersions(field.versions(), field.taggedVersions().intersect(field.versions())). + allowMembershipCheckAlwaysFalse(false). + ifMember(presentAndTaggedVersions -> { + IsNullConditional cond = IsNullConditional.forName(field.camelCaseName()). + nullableVersions(field.nullableVersions()). + possibleVersions(presentAndTaggedVersions). + alwaysEmitBlockScope(true). 
+ ifShouldNotBeNull(() -> { + if (!field.defaultString().equals("null")) { + field.generateNonDefaultValueCheck(headerGenerator, + structRegistry, buffer, "this.", Versions.NONE); + buffer.incrementIndent(); + } + buffer.printf("_writable.writeUnsignedVarint(%d);%n", field.tag().get()); + if (field.type().isString()) { + buffer.printf("byte[] _stringBytes = _cache.getSerializedValue(this.%s);%n", + field.camelCaseName()); + headerGenerator.addImport(MessageGenerator.BYTE_UTILS_CLASS); + buffer.printf("_writable.writeUnsignedVarint(_stringBytes.length + " + + "ByteUtils.sizeOfUnsignedVarint(_stringBytes.length + 1));%n"); + buffer.printf("_writable.writeUnsignedVarint(_stringBytes.length + 1);%n"); + buffer.printf("_writable.writeByteArray(_stringBytes);%n"); + } else if (field.type().isBytes()) { + headerGenerator.addImport(MessageGenerator.BYTE_UTILS_CLASS); + buffer.printf("_writable.writeUnsignedVarint(this.%s.length + " + + "ByteUtils.sizeOfUnsignedVarint(this.%s.length + 1));%n", + field.camelCaseName(), field.camelCaseName()); + buffer.printf("_writable.writeUnsignedVarint(this.%s.length + 1);%n", + field.camelCaseName()); + buffer.printf("_writable.writeByteArray(this.%s);%n", + field.camelCaseName()); + } else if (field.type().isArray()) { + headerGenerator.addImport(MessageGenerator.BYTE_UTILS_CLASS); + buffer.printf("_writable.writeUnsignedVarint(_cache.getArraySizeInBytes(this.%s));%n", + field.camelCaseName()); + generateVariableLengthWriter(fieldFlexibleVersions(field), + field.camelCaseName(), + field.type(), + presentAndTaggedVersions, + Versions.NONE, + field.zeroCopy()); + } else if (field.type().isStruct()) { + buffer.printf("_writable.writeUnsignedVarint(this.%s.size(_cache, _version));%n", + field.camelCaseName()); + buffer.printf("%s;%n", + primitiveWriteExpression(field.type(), field.camelCaseName())); + } else if (field.type().isRecords()) { + throw new RuntimeException("Unsupported attempt to declare field `" + + field.name() + "` with `records` type as a tagged field."); + } else { + buffer.printf("_writable.writeUnsignedVarint(%d);%n", + field.type().fixedLength().get()); + buffer.printf("%s;%n", + primitiveWriteExpression(field.type(), field.camelCaseName())); + } + if (!field.defaultString().equals("null")) { + buffer.decrementIndent(); + buffer.printf("}%n"); + } + }); + if (!field.defaultString().equals("null")) { + cond.ifNull(() -> { + buffer.printf("_writable.writeUnsignedVarint(%d);%n", field.tag().get()); + buffer.printf("_writable.writeUnsignedVarint(1);%n"); + buffer.printf("_writable.writeUnsignedVarint(0);%n"); + }); + } + cond.generate(buffer); + }). + generate(buffer); + prevTag = field.tag().get(); + } + if (prevTag < Integer.MAX_VALUE) { + buffer.printf("_rawWriter.writeRawTags(_writable, Integer.MAX_VALUE);%n"); + } + }). + generate(buffer); + buffer.decrementIndent(); + buffer.printf("}%n"); + } + + private void generateCheckForUnsupportedNumTaggedFields(String conditional) { + buffer.printf("if (%s) {%n", conditional); + buffer.incrementIndent(); + headerGenerator.addImport(MessageGenerator.UNSUPPORTED_VERSION_EXCEPTION_CLASS); + buffer.printf("throw new UnsupportedVersionException(\"Tagged fields were set, " + + "but version \" + _version + \" of this message does not support them.\");%n"); + buffer.decrementIndent(); + buffer.printf("}%n"); + } + + private String primitiveWriteExpression(FieldType type, String name) { + if (type instanceof FieldType.BoolFieldType) { + return String.format("_writable.writeByte(%s ? 
(byte) 1 : (byte) 0)", name); + } else if (type instanceof FieldType.Int8FieldType) { + return String.format("_writable.writeByte(%s)", name); + } else if (type instanceof FieldType.Int16FieldType) { + return String.format("_writable.writeShort(%s)", name); + } else if (type instanceof FieldType.Uint16FieldType) { + return String.format("_writable.writeUnsignedShort(%s)", name); + } else if (type instanceof FieldType.Int32FieldType) { + return String.format("_writable.writeInt(%s)", name); + } else if (type instanceof FieldType.Int64FieldType) { + return String.format("_writable.writeLong(%s)", name); + } else if (type instanceof FieldType.UUIDFieldType) { + return String.format("_writable.writeUuid(%s)", name); + } else if (type instanceof FieldType.Float64FieldType) { + return String.format("_writable.writeDouble(%s)", name); + } else if (type instanceof FieldType.StructType) { + return String.format("%s.write(_writable, _cache, _version)", name); + } else { + throw new RuntimeException("Unsupported field type " + type); + } + } + + private void generateVariableLengthWriter(Versions fieldFlexibleVersions, + String name, + FieldType type, + Versions possibleVersions, + Versions nullableVersions, + boolean zeroCopy) { + IsNullConditional.forName(name). + possibleVersions(possibleVersions). + nullableVersions(nullableVersions). + alwaysEmitBlockScope(type.isString()). + ifNull(() -> { + VersionConditional.forVersions(nullableVersions, possibleVersions). + ifMember(presentVersions -> { + VersionConditional.forVersions(fieldFlexibleVersions, presentVersions). + ifMember(___ -> { + buffer.printf("_writable.writeUnsignedVarint(0);%n"); + }). + ifNotMember(___ -> { + if (type.isString()) { + buffer.printf("_writable.writeShort((short) -1);%n"); + } else { + buffer.printf("_writable.writeInt(-1);%n"); + } + }). + generate(buffer); + }). + ifNotMember(__ -> { + buffer.printf("throw new NullPointerException();%n"); + }). + generate(buffer); + }). + ifShouldNotBeNull(() -> { + final String lengthExpression; + if (type.isString()) { + buffer.printf("byte[] _stringBytes = _cache.getSerializedValue(%s);%n", + name); + lengthExpression = "_stringBytes.length"; + } else if (type.isBytes()) { + if (zeroCopy) { + lengthExpression = String.format("%s.remaining()", name); + } else { + lengthExpression = String.format("%s.length", name); + } + } else if (type.isRecords()) { + lengthExpression = String.format("%s.sizeInBytes()", name); + } else if (type.isArray()) { + lengthExpression = String.format("%s.size()", name); + } else { + throw new RuntimeException("Unhandled type " + type); + } + // Check whether we're dealing with a flexible version or not. In a flexible + // version, the length is serialized differently. + // + // Note: for arrays, each branch of the if contains the loop for writing out + // the elements. This allows us to lift the version check out of the loop. + // This is helpful for things like arrays of strings, where each element + // will be serialized differently based on whether the version is flexible. + VersionConditional.forVersions(fieldFlexibleVersions, possibleVersions). + ifMember(ifMemberVersions -> { + buffer.printf("_writable.writeUnsignedVarint(%s + 1);%n", lengthExpression); + }). + ifNotMember(ifNotMemberVersions -> { + if (type.isString()) { + buffer.printf("_writable.writeShort((short) %s);%n", lengthExpression); + } else { + buffer.printf("_writable.writeInt(%s);%n", lengthExpression); + } + }). 
+ generate(buffer); + if (type.isString()) { + buffer.printf("_writable.writeByteArray(_stringBytes);%n"); + } else if (type.isBytes()) { + if (zeroCopy) { + buffer.printf("_writable.writeByteBuffer(%s);%n", name); + } else { + buffer.printf("_writable.writeByteArray(%s);%n", name); + } + } else if (type.isRecords()) { + buffer.printf("_writable.writeRecords(%s);%n", name); + } else if (type.isArray()) { + FieldType.ArrayType arrayType = (FieldType.ArrayType) type; + FieldType elementType = arrayType.elementType(); + String elementName = String.format("%sElement", name); + buffer.printf("for (%s %s : %s) {%n", + elementType.getBoxedJavaType(headerGenerator), + elementName, + name); + buffer.incrementIndent(); + if (elementType.isArray()) { + throw new RuntimeException("Nested arrays are not supported. " + + "Use an array of structures containing another array."); + } else if (elementType.isBytes() || elementType.isString()) { + generateVariableLengthWriter(fieldFlexibleVersions, + elementName, + elementType, + possibleVersions, + Versions.NONE, + false); + } else { + buffer.printf("%s;%n", primitiveWriteExpression(elementType, elementName)); + } + buffer.decrementIndent(); + buffer.printf("}%n"); + } + }). + generate(buffer); + } + + private void generateClassMessageSize( + String className, + StructSpec struct, + Versions parentVersions + ) { + headerGenerator.addImport(MessageGenerator.OBJECT_SERIALIZATION_CACHE_CLASS); + headerGenerator.addImport(MessageGenerator.MESSAGE_SIZE_ACCUMULATOR_CLASS); + buffer.printf("@Override%n"); + buffer.printf("public void addSize(MessageSizeAccumulator _size, ObjectSerializationCache _cache, short _version) {%n"); + buffer.incrementIndent(); + buffer.printf("int _numTaggedFields = 0;%n"); + VersionConditional.forVersions(parentVersions, struct.versions()). + allowMembershipCheckAlwaysFalse(false). + ifNotMember(__ -> { + headerGenerator.addImport(MessageGenerator.UNSUPPORTED_VERSION_EXCEPTION_CLASS); + buffer.printf("throw new UnsupportedVersionException(\"Can't size " + + "version \" + _version + \" of %s\");%n", className); + }). + generate(buffer); + Versions curVersions = parentVersions.intersect(struct.versions()); + for (FieldSpec field : struct.fields()) { + VersionConditional.forVersions(field.versions(), curVersions). + ifMember(presentVersions -> { + VersionConditional.forVersions(field.taggedVersions(), presentVersions). + ifMember(presentAndTaggedVersions -> { + generateFieldSize(field, presentAndTaggedVersions, true); + }). + ifNotMember(presentAndUntaggedVersions -> { + generateFieldSize(field, presentAndUntaggedVersions, false); + }). + generate(buffer); + }).generate(buffer); + } + buffer.printf("if (_unknownTaggedFields != null) {%n"); + buffer.incrementIndent(); + buffer.printf("_numTaggedFields += _unknownTaggedFields.size();%n"); + buffer.printf("for (RawTaggedField _field : _unknownTaggedFields) {%n"); + buffer.incrementIndent(); + headerGenerator.addImport(MessageGenerator.BYTE_UTILS_CLASS); + buffer.printf("_size.addBytes(ByteUtils.sizeOfUnsignedVarint(_field.tag()));%n"); + buffer.printf("_size.addBytes(ByteUtils.sizeOfUnsignedVarint(_field.size()));%n"); + buffer.printf("_size.addBytes(_field.size());%n"); + buffer.decrementIndent(); + buffer.printf("}%n"); + buffer.decrementIndent(); + buffer.printf("}%n"); + VersionConditional.forVersions(messageFlexibleVersions, curVersions). + ifNotMember(__ -> { + generateCheckForUnsupportedNumTaggedFields("_numTaggedFields > 0"); + }). 
+ ifMember(__ -> { + headerGenerator.addImport(MessageGenerator.BYTE_UTILS_CLASS); + buffer.printf("_size.addBytes(ByteUtils.sizeOfUnsignedVarint(_numTaggedFields));%n"); + }). + generate(buffer); + buffer.decrementIndent(); + buffer.printf("}%n"); + } + + /** + * Generate the size calculator for a variable-length array element. + * Array elements cannot be null. + */ + private void generateVariableLengthArrayElementSize(Versions flexibleVersions, + String fieldName, + FieldType type, + Versions versions) { + if (type instanceof FieldType.StringFieldType) { + generateStringToBytes(fieldName); + VersionConditional.forVersions(flexibleVersions, versions). + ifNotMember(__ -> { + buffer.printf("_size.addBytes(_stringBytes.length + 2);%n"); + }). + ifMember(__ -> { + headerGenerator.addImport(MessageGenerator.BYTE_UTILS_CLASS); + buffer.printf("_size.addBytes(_stringBytes.length + " + + "ByteUtils.sizeOfUnsignedVarint(_stringBytes.length + 1));%n"); + }). + generate(buffer); + } else if (type instanceof FieldType.BytesFieldType) { + buffer.printf("_size.addBytes(%s.length);%n", fieldName); + VersionConditional.forVersions(flexibleVersions, versions). + ifNotMember(__ -> { + buffer.printf("_size.addBytes(4);%n"); + }). + ifMember(__ -> { + headerGenerator.addImport(MessageGenerator.BYTE_UTILS_CLASS); + buffer.printf("_size.addBytes(" + + "ByteUtils.sizeOfUnsignedVarint(%s.length + 1));%n", + fieldName); + }). + generate(buffer); + } else if (type instanceof FieldType.StructType) { + buffer.printf("%s.addSize(_size, _cache, _version);%n", fieldName); + } else { + throw new RuntimeException("Unsupported type " + type); + } + } + + private void generateFieldSize(FieldSpec field, + Versions possibleVersions, + boolean tagged) { + if (field.type().fixedLength().isPresent()) { + generateFixedLengthFieldSize(field, tagged); + } else { + generateVariableLengthFieldSize(field, possibleVersions, tagged); + } + } + + private void generateFixedLengthFieldSize(FieldSpec field, + boolean tagged) { + if (tagged) { + // Check to see that the field is not set to the default value. + // If it is, then we don't need to serialize it. + field.generateNonDefaultValueCheck(headerGenerator, structRegistry, buffer, + "this.", field.nullableVersions()); + buffer.incrementIndent(); + buffer.printf("_numTaggedFields++;%n"); + buffer.printf("_size.addBytes(%d);%n", + MessageGenerator.sizeOfUnsignedVarint(field.tag().get())); + // Account for the tagged field prefix length. + buffer.printf("_size.addBytes(%d);%n", + MessageGenerator.sizeOfUnsignedVarint(field.type().fixedLength().get())); + buffer.printf("_size.addBytes(%d);%n", field.type().fixedLength().get()); + buffer.decrementIndent(); + buffer.printf("}%n"); + } else { + buffer.printf("_size.addBytes(%d);%n", field.type().fixedLength().get()); + } + } + + private void generateVariableLengthFieldSize(FieldSpec field, + Versions possibleVersions, + boolean tagged) { + IsNullConditional.forField(field). + alwaysEmitBlockScope(true). + possibleVersions(possibleVersions). + nullableVersions(field.nullableVersions()). + ifNull(() -> { + if (!tagged || !field.defaultString().equals("null")) { + VersionConditional.forVersions(fieldFlexibleVersions(field), possibleVersions). 
+ ifMember(__ -> { + if (tagged) { + buffer.printf("_numTaggedFields++;%n"); + buffer.printf("_size.addBytes(%d);%n", + MessageGenerator.sizeOfUnsignedVarint(field.tag().get())); + buffer.printf("_size.addBytes(%d);%n", MessageGenerator.sizeOfUnsignedVarint( + MessageGenerator.sizeOfUnsignedVarint(0))); + } + buffer.printf("_size.addBytes(%d);%n", MessageGenerator.sizeOfUnsignedVarint(0)); + }). + ifNotMember(__ -> { + if (tagged) { + throw new RuntimeException("Tagged field " + field.name() + + " should not be present in non-flexible versions."); + } + if (field.type().isString()) { + buffer.printf("_size.addBytes(2);%n"); + } else { + buffer.printf("_size.addBytes(4);%n"); + } + }). + generate(buffer); + } + }). + ifShouldNotBeNull(() -> { + if (tagged) { + if (!field.defaultString().equals("null")) { + field.generateNonDefaultValueCheck(headerGenerator, + structRegistry, buffer, "this.", Versions.NONE); + buffer.incrementIndent(); + } + buffer.printf("_numTaggedFields++;%n"); + buffer.printf("_size.addBytes(%d);%n", + MessageGenerator.sizeOfUnsignedVarint(field.tag().get())); + } + if (field.type().isString()) { + generateStringToBytes(field.camelCaseName()); + VersionConditional.forVersions(fieldFlexibleVersions(field), possibleVersions). + ifMember(__ -> { + headerGenerator.addImport(MessageGenerator.BYTE_UTILS_CLASS); + if (tagged) { + buffer.printf("int _stringPrefixSize = " + + "ByteUtils.sizeOfUnsignedVarint(_stringBytes.length + 1);%n"); + buffer.printf("_size.addBytes(_stringBytes.length + _stringPrefixSize + " + + "ByteUtils.sizeOfUnsignedVarint(_stringPrefixSize + _stringBytes.length));%n"); + + } else { + buffer.printf("_size.addBytes(_stringBytes.length + " + + "ByteUtils.sizeOfUnsignedVarint(_stringBytes.length + 1));%n"); + } + }). + ifNotMember(__ -> { + if (tagged) { + throw new RuntimeException("Tagged field " + field.name() + + " should not be present in non-flexible versions."); + } + buffer.printf("_size.addBytes(_stringBytes.length + 2);%n"); + }). + generate(buffer); + } else if (field.type().isArray()) { + if (tagged) { + buffer.printf("int _sizeBeforeArray = _size.totalSize();%n"); + } + VersionConditional.forVersions(fieldFlexibleVersions(field), possibleVersions). + ifMember(__ -> { + headerGenerator.addImport(MessageGenerator.BYTE_UTILS_CLASS); + buffer.printf("_size.addBytes(ByteUtils.sizeOfUnsignedVarint(%s.size() + 1));%n", + field.camelCaseName()); + }). + ifNotMember(__ -> { + buffer.printf("_size.addBytes(4);%n"); + }). 
+ generate(buffer); + FieldType.ArrayType arrayType = (FieldType.ArrayType) field.type(); + FieldType elementType = arrayType.elementType(); + if (elementType.fixedLength().isPresent()) { + buffer.printf("_size.addBytes(%s.size() * %d);%n", + field.camelCaseName(), + elementType.fixedLength().get()); + } else if (elementType instanceof FieldType.ArrayType) { + throw new RuntimeException("Arrays of arrays are not supported " + + "(use a struct)."); + } else { + buffer.printf("for (%s %sElement : %s) {%n", + elementType.getBoxedJavaType(headerGenerator), + field.camelCaseName(), field.camelCaseName()); + buffer.incrementIndent(); + generateVariableLengthArrayElementSize(fieldFlexibleVersions(field), + String.format("%sElement", field.camelCaseName()), + elementType, + possibleVersions); + buffer.decrementIndent(); + buffer.printf("}%n"); + } + if (tagged) { + headerGenerator.addImport(MessageGenerator.BYTE_UTILS_CLASS); + buffer.printf("int _arraySize = _size.totalSize() - _sizeBeforeArray;%n"); + buffer.printf("_cache.setArraySizeInBytes(%s, _arraySize);%n", + field.camelCaseName()); + buffer.printf("_size.addBytes(ByteUtils.sizeOfUnsignedVarint(_arraySize));%n"); + } + } else if (field.type().isBytes()) { + if (tagged) { + buffer.printf("int _sizeBeforeBytes = _size.totalSize();%n"); + } + if (field.zeroCopy()) { + buffer.printf("_size.addZeroCopyBytes(%s.remaining());%n", field.camelCaseName()); + } else { + buffer.printf("_size.addBytes(%s.length);%n", field.camelCaseName()); + } + VersionConditional.forVersions(fieldFlexibleVersions(field), possibleVersions). + ifMember(__ -> { + headerGenerator.addImport(MessageGenerator.BYTE_UTILS_CLASS); + if (field.zeroCopy()) { + buffer.printf("_size.addBytes(" + + "ByteUtils.sizeOfUnsignedVarint(%s.remaining() + 1));%n", field.camelCaseName()); + } else { + buffer.printf("_size.addBytes(ByteUtils.sizeOfUnsignedVarint(%s.length + 1));%n", + field.camelCaseName()); + } + }). + ifNotMember(__ -> { + buffer.printf("_size.addBytes(4);%n"); + }). + generate(buffer); + if (tagged) { + headerGenerator.addImport(MessageGenerator.BYTE_UTILS_CLASS); + buffer.printf("int _bytesSize = _size.totalSize() - _sizeBeforeBytes;%n"); + buffer.printf("_size.addBytes(ByteUtils.sizeOfUnsignedVarint(_bytesSize));%n"); + } + } else if (field.type().isRecords()) { + buffer.printf("_size.addZeroCopyBytes(%s.sizeInBytes());%n", field.camelCaseName()); + VersionConditional.forVersions(fieldFlexibleVersions(field), possibleVersions). + ifMember(__ -> { + headerGenerator.addImport(MessageGenerator.BYTE_UTILS_CLASS); + buffer.printf("_size.addBytes(" + + "ByteUtils.sizeOfUnsignedVarint(%s.sizeInBytes() + 1));%n", field.camelCaseName()); + }). + ifNotMember(__ -> { + buffer.printf("_size.addBytes(4);%n"); + }). + generate(buffer); + } else if (field.type().isStruct()) { + buffer.printf("int _sizeBeforeStruct = _size.totalSize();%n", field.camelCaseName()); + buffer.printf("this.%s.addSize(_size, _cache, _version);%n", field.camelCaseName()); + buffer.printf("int _structSize = _size.totalSize() - _sizeBeforeStruct;%n", field.camelCaseName()); + + if (tagged) { + buffer.printf("_size.addBytes(ByteUtils.sizeOfUnsignedVarint(_structSize));%n"); + } + } else { + throw new RuntimeException("unhandled type " + field.type()); + } + if (tagged && !field.defaultString().equals("null")) { + buffer.decrementIndent(); + buffer.printf("}%n"); + } + }). 
+ generate(buffer); + } + + private void generateStringToBytes(String name) { + headerGenerator.addImport(MessageGenerator.STANDARD_CHARSETS); + buffer.printf("byte[] _stringBytes = %s.getBytes(StandardCharsets.UTF_8);%n", name); + buffer.printf("if (_stringBytes.length > 0x7fff) {%n"); + buffer.incrementIndent(); + buffer.printf("throw new RuntimeException(\"'%s' field is too long to " + + "be serialized\");%n", name); + buffer.decrementIndent(); + buffer.printf("}%n"); + buffer.printf("_cache.cacheSerializedValue(%s, _stringBytes);%n", name); + } + + private void generateClassEquals(String className, StructSpec struct, + boolean elementKeysAreEqual) { + buffer.printf("@Override%n"); + buffer.printf("public boolean %s(Object obj) {%n", + elementKeysAreEqual ? "elementKeysAreEqual" : "equals"); + buffer.incrementIndent(); + buffer.printf("if (!(obj instanceof %s)) return false;%n", className); + buffer.printf("%s other = (%s) obj;%n", className, className); + if (!struct.fields().isEmpty()) { + for (FieldSpec field : struct.fields()) { + if (!elementKeysAreEqual || field.mapKey()) { + generateFieldEquals(field); + } + } + } + if (elementKeysAreEqual) { + buffer.printf("return true;%n"); + } else { + headerGenerator.addImport(MessageGenerator.MESSAGE_UTIL_CLASS); + buffer.printf("return MessageUtil.compareRawTaggedFields(_unknownTaggedFields, " + + "other._unknownTaggedFields);%n"); + } + buffer.decrementIndent(); + buffer.printf("}%n"); + } + + private void generateFieldEquals(FieldSpec field) { + if (field.type() instanceof FieldType.UUIDFieldType) { + buffer.printf("if (!this.%s.equals(other.%s)) return false;%n", + field.camelCaseName(), field.camelCaseName()); + } else if (field.type().isString() || field.type().isArray() || field.type().isStruct()) { + buffer.printf("if (this.%s == null) {%n", field.camelCaseName()); + buffer.incrementIndent(); + buffer.printf("if (other.%s != null) return false;%n", field.camelCaseName()); + buffer.decrementIndent(); + buffer.printf("} else {%n"); + buffer.incrementIndent(); + buffer.printf("if (!this.%s.equals(other.%s)) return false;%n", + field.camelCaseName(), field.camelCaseName()); + buffer.decrementIndent(); + buffer.printf("}%n"); + } else if (field.type().isBytes()) { + if (field.zeroCopy()) { + headerGenerator.addImport(MessageGenerator.OBJECTS_CLASS); + buffer.printf("if (!Objects.equals(this.%s, other.%s)) return false;%n", + field.camelCaseName(), field.camelCaseName()); + } else { + // Arrays#equals handles nulls. 
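+ // (Zero-copy bytes fields are represented as a ByteBuffer and compared with
+ // Objects.equals above; plain bytes fields are byte[] and compared with
+ // Arrays.equals below, which also tolerates null on either side.)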
+ headerGenerator.addImport(MessageGenerator.ARRAYS_CLASS); + buffer.printf("if (!Arrays.equals(this.%s, other.%s)) return false;%n", + field.camelCaseName(), field.camelCaseName()); + } + } else if (field.type().isRecords()) { + headerGenerator.addImport(MessageGenerator.OBJECTS_CLASS); + buffer.printf("if (!Objects.equals(this.%s, other.%s)) return false;%n", + field.camelCaseName(), field.camelCaseName()); + } else { + buffer.printf("if (%s != other.%s) return false;%n", + field.camelCaseName(), field.camelCaseName()); + } + } + + private void generateClassHashCode(StructSpec struct, boolean onlyMapKeys) { + buffer.printf("@Override%n"); + buffer.printf("public int hashCode() {%n"); + buffer.incrementIndent(); + buffer.printf("int hashCode = 0;%n"); + for (FieldSpec field : struct.fields()) { + if ((!onlyMapKeys) || field.mapKey()) { + generateFieldHashCode(field); + } + } + buffer.printf("return hashCode;%n"); + buffer.decrementIndent(); + buffer.printf("}%n"); + } + + private void generateFieldHashCode(FieldSpec field) { + if (field.type() instanceof FieldType.BoolFieldType) { + buffer.printf("hashCode = 31 * hashCode + (%s ? 1231 : 1237);%n", + field.camelCaseName()); + } else if ((field.type() instanceof FieldType.Int8FieldType) || + (field.type() instanceof FieldType.Int16FieldType) || + (field.type() instanceof FieldType.Uint16FieldType) || + (field.type() instanceof FieldType.Int32FieldType)) { + buffer.printf("hashCode = 31 * hashCode + %s;%n", + field.camelCaseName()); + } else if (field.type() instanceof FieldType.Int64FieldType) { + buffer.printf("hashCode = 31 * hashCode + ((int) (%s >> 32) ^ (int) %s);%n", + field.camelCaseName(), field.camelCaseName()); + } else if (field.type() instanceof FieldType.UUIDFieldType) { + buffer.printf("hashCode = 31 * hashCode + %s.hashCode();%n", + field.camelCaseName()); + } else if (field.type() instanceof FieldType.Float64FieldType) { + buffer.printf("hashCode = 31 * hashCode + Double.hashCode(%s);%n", + field.camelCaseName(), field.camelCaseName()); + } else if (field.type().isBytes()) { + if (field.zeroCopy()) { + headerGenerator.addImport(MessageGenerator.OBJECTS_CLASS); + buffer.printf("hashCode = 31 * hashCode + Objects.hashCode(%s);%n", + field.camelCaseName()); + } else { + headerGenerator.addImport(MessageGenerator.ARRAYS_CLASS); + buffer.printf("hashCode = 31 * hashCode + Arrays.hashCode(%s);%n", + field.camelCaseName()); + } + } else if (field.type().isRecords()) { + headerGenerator.addImport(MessageGenerator.OBJECTS_CLASS); + buffer.printf("hashCode = 31 * hashCode + Objects.hashCode(%s);%n", + field.camelCaseName()); + } else if (field.type().isStruct() + || field.type().isArray() + || field.type().isString()) { + buffer.printf("hashCode = 31 * hashCode + (%s == null ? 
0 : %s.hashCode());%n", + field.camelCaseName(), field.camelCaseName()); + } else { + throw new RuntimeException("Unsupported field type " + field.type()); + } + } + + private void generateClassDuplicate(String className, StructSpec struct) { + buffer.printf("@Override%n"); + buffer.printf("public %s duplicate() {%n", className); + buffer.incrementIndent(); + buffer.printf("%s _duplicate = new %s();%n", className, className); + for (FieldSpec field : struct.fields()) { + generateFieldDuplicate(new Target(field, + field.camelCaseName(), + field.camelCaseName(), + input -> String.format("_duplicate.%s = %s", field.camelCaseName(), input))); + } + buffer.printf("return _duplicate;%n"); + buffer.decrementIndent(); + buffer.printf("}%n"); + } + + private void generateFieldDuplicate(Target target) { + FieldSpec field = target.field(); + if ((field.type() instanceof FieldType.BoolFieldType) || + (field.type() instanceof FieldType.Int8FieldType) || + (field.type() instanceof FieldType.Int16FieldType) || + (field.type() instanceof FieldType.Uint16FieldType) || + (field.type() instanceof FieldType.Int32FieldType) || + (field.type() instanceof FieldType.Int64FieldType) || + (field.type() instanceof FieldType.Float64FieldType) || + (field.type() instanceof FieldType.UUIDFieldType)) { + buffer.printf("%s;%n", target.assignmentStatement(target.sourceVariable())); + } else { + IsNullConditional cond = IsNullConditional.forName(target.sourceVariable()). + nullableVersions(target.field().nullableVersions()). + ifNull(() -> buffer.printf("%s;%n", target.assignmentStatement("null"))); + if (field.type().isBytes()) { + if (field.zeroCopy()) { + cond.ifShouldNotBeNull(() -> + buffer.printf("%s;%n", target.assignmentStatement( + String.format("%s.duplicate()", target.sourceVariable())))); + } else { + cond.ifShouldNotBeNull(() -> { + headerGenerator.addImport(MessageGenerator.MESSAGE_UTIL_CLASS); + buffer.printf("%s;%n", target.assignmentStatement( + String.format("MessageUtil.duplicate(%s)", + target.sourceVariable()))); + }); + } + } else if (field.type().isRecords()) { + cond.ifShouldNotBeNull(() -> { + headerGenerator.addImport(MessageGenerator.MEMORY_RECORDS_CLASS); + buffer.printf("%s;%n", target.assignmentStatement( + String.format("MemoryRecords.readableRecords(((MemoryRecords) %s).buffer().duplicate())", + target.sourceVariable()))); + }); + } else if (field.type().isStruct()) { + cond.ifShouldNotBeNull(() -> + buffer.printf("%s;%n", target.assignmentStatement( + String.format("%s.duplicate()", target.sourceVariable())))); + } else if (field.type().isString()) { + // Strings are immutable, so we don't need to duplicate them. 
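+ // The generated duplicate() therefore just copies the string reference.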
+ cond.ifShouldNotBeNull(() -> + buffer.printf("%s;%n", target.assignmentStatement( + target.sourceVariable()))); + } else if (field.type().isArray()) { + cond.ifShouldNotBeNull(() -> { + String newArrayName = + String.format("new%s", field.capitalizedCamelCaseName()); + String type = field.concreteJavaType(headerGenerator, structRegistry); + buffer.printf("%s %s = new %s(%s.size());%n", + type, newArrayName, type, target.sourceVariable()); + FieldType.ArrayType arrayType = (FieldType.ArrayType) field.type(); + buffer.printf("for (%s _element : %s) {%n", + arrayType.elementType().getBoxedJavaType(headerGenerator), + target.sourceVariable()); + buffer.incrementIndent(); + generateFieldDuplicate(target.arrayElementTarget(input -> + String.format("%s.add(%s)", newArrayName, input))); + buffer.decrementIndent(); + buffer.printf("}%n"); + buffer.printf("%s;%n", target.assignmentStatement( + String.format("new%s", field.capitalizedCamelCaseName()))); + }); + } else { + throw new RuntimeException("Unhandled field type " + field.type()); + } + cond.generate(buffer); + } + } + + private void generateClassToString(String className, StructSpec struct) { + buffer.printf("@Override%n"); + buffer.printf("public String toString() {%n"); + buffer.incrementIndent(); + buffer.printf("return \"%s(\"%n", className); + buffer.incrementIndent(); + String prefix = ""; + for (FieldSpec field : struct.fields()) { + generateFieldToString(prefix, field); + prefix = ", "; + } + buffer.printf("+ \")\";%n"); + buffer.decrementIndent(); + buffer.decrementIndent(); + buffer.printf("}%n"); + } + + private void generateFieldToString(String prefix, FieldSpec field) { + if (field.type() instanceof FieldType.BoolFieldType) { + buffer.printf("+ \"%s%s=\" + (%s ? \"true\" : \"false\")%n", prefix, field.camelCaseName(), field.camelCaseName()); + } else if ((field.type() instanceof FieldType.Int8FieldType) || + (field.type() instanceof FieldType.Int16FieldType) || + (field.type() instanceof FieldType.Uint16FieldType) || + (field.type() instanceof FieldType.Int32FieldType) || + (field.type() instanceof FieldType.Int64FieldType) || + (field.type() instanceof FieldType.Float64FieldType)) { + buffer.printf("+ \"%s%s=\" + %s%n", + prefix, field.camelCaseName(), field.camelCaseName()); + } else if (field.type().isString()) { + buffer.printf("+ \"%s%s=\" + ((%s == null) ? \"null\" : \"'\" + %s.toString() + \"'\")%n", + prefix, field.camelCaseName(), field.camelCaseName(), field.camelCaseName()); + } else if (field.type().isBytes()) { + if (field.zeroCopy()) { + buffer.printf("+ \"%s%s=\" + %s%n", + prefix, field.camelCaseName(), field.camelCaseName()); + } else { + headerGenerator.addImport(MessageGenerator.ARRAYS_CLASS); + buffer.printf("+ \"%s%s=\" + Arrays.toString(%s)%n", + prefix, field.camelCaseName(), field.camelCaseName()); + } + } else if (field.type().isRecords()) { + buffer.printf("+ \"%s%s=\" + %s%n", + prefix, field.camelCaseName(), field.camelCaseName()); + } else if (field.type() instanceof FieldType.UUIDFieldType || + field.type().isStruct()) { + buffer.printf("+ \"%s%s=\" + %s.toString()%n", + prefix, field.camelCaseName(), field.camelCaseName()); + } else if (field.type().isArray()) { + headerGenerator.addImport(MessageGenerator.MESSAGE_UTIL_CLASS); + if (field.nullableVersions().empty()) { + buffer.printf("+ \"%s%s=\" + MessageUtil.deepToString(%s.iterator())%n", + prefix, field.camelCaseName(), field.camelCaseName()); + } else { + buffer.printf("+ \"%s%s=\" + ((%s == null) ? 
\"null\" : " + + "MessageUtil.deepToString(%s.iterator()))%n", + prefix, field.camelCaseName(), field.camelCaseName(), field.camelCaseName()); + } + } else { + throw new RuntimeException("Unsupported field type " + field.type()); + } + } + + private void generateFieldAccessor(FieldSpec field) { + buffer.printf("%n"); + generateAccessor(field.fieldAbstractJavaType(headerGenerator, structRegistry), + field.camelCaseName(), + field.camelCaseName()); + } + + private void generateAccessor(String javaType, String functionName, String memberName) { + buffer.printf("public %s %s() {%n", javaType, functionName); + buffer.incrementIndent(); + buffer.printf("return this.%s;%n", memberName); + buffer.decrementIndent(); + buffer.printf("}%n"); + } + + private void generateFieldMutator(String className, FieldSpec field) { + buffer.printf("%n"); + buffer.printf("public %s set%s(%s v) {%n", + className, + field.capitalizedCamelCaseName(), + field.fieldAbstractJavaType(headerGenerator, structRegistry)); + buffer.incrementIndent(); + if (field.type() instanceof FieldType.Uint16FieldType) { + buffer.printf("if (v < 0 || v > 65535) {%n"); + buffer.incrementIndent(); + buffer.printf("throw new RuntimeException(\"Invalid value \" + v + " + + "\" for unsigned short field.\");%n"); + buffer.decrementIndent(); + buffer.printf("}%n"); + } + buffer.printf("this.%s = v;%n", field.camelCaseName()); + buffer.printf("return this;%n"); + buffer.decrementIndent(); + buffer.printf("}%n"); + } + + private void generateSetter(String javaType, String functionName, String memberName) { + buffer.printf("public void %s(%s v) {%n", functionName, javaType); + buffer.incrementIndent(); + buffer.printf("this.%s = v;%n", memberName); + buffer.decrementIndent(); + buffer.printf("}%n"); + } + + private Versions fieldFlexibleVersions(FieldSpec field) { + if (field.flexibleVersions().isPresent()) { + if (!messageFlexibleVersions.intersect(field.flexibleVersions().get()). + equals(field.flexibleVersions().get())) { + throw new RuntimeException("The flexible versions for field " + + field.name() + " are " + field.flexibleVersions().get() + + ", which are not a subset of the flexible versions for the " + + "message as a whole, which are " + messageFlexibleVersions); + } + return field.flexibleVersions().get(); + } else { + return messageFlexibleVersions; + } + } + +} diff --git a/generator/src/main/java/org/apache/kafka/message/MessageGenerator.java b/generator/src/main/java/org/apache/kafka/message/MessageGenerator.java new file mode 100644 index 0000000..cfbeae8 --- /dev/null +++ b/generator/src/main/java/org/apache/kafka/message/MessageGenerator.java @@ -0,0 +1,367 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.message; + +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.core.JsonParser; +import com.fasterxml.jackson.databind.DeserializationFeature; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.SerializationFeature; +import net.sourceforge.argparse4j.ArgumentParsers; +import net.sourceforge.argparse4j.inf.ArgumentParser; +import net.sourceforge.argparse4j.inf.Namespace; + +import java.io.BufferedWriter; +import java.nio.file.DirectoryStream; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Locale; + +import static net.sourceforge.argparse4j.impl.Arguments.store; + +/** + * The Kafka message generator. + */ +public final class MessageGenerator { + static final String JSON_SUFFIX = ".json"; + + static final String JSON_GLOB = "*" + JSON_SUFFIX; + + static final String JAVA_SUFFIX = ".java"; + + static final String API_MESSAGE_TYPE_JAVA = "ApiMessageType.java"; + + static final String API_SCOPE_JAVA = "ApiScope.java"; + + static final String METADATA_RECORD_TYPE_JAVA = "MetadataRecordType.java"; + + static final String METADATA_JSON_CONVERTERS_JAVA = "MetadataJsonConverters.java"; + + static final String API_MESSAGE_CLASS = "org.apache.kafka.common.protocol.ApiMessage"; + + static final String MESSAGE_CLASS = "org.apache.kafka.common.protocol.Message"; + + static final String MESSAGE_UTIL_CLASS = "org.apache.kafka.common.protocol.MessageUtil"; + + static final String READABLE_CLASS = "org.apache.kafka.common.protocol.Readable"; + + static final String WRITABLE_CLASS = "org.apache.kafka.common.protocol.Writable"; + + static final String ARRAYS_CLASS = "java.util.Arrays"; + + static final String OBJECTS_CLASS = "java.util.Objects"; + + static final String LIST_CLASS = "java.util.List"; + + static final String ARRAYLIST_CLASS = "java.util.ArrayList"; + + static final String IMPLICIT_LINKED_HASH_COLLECTION_CLASS = + "org.apache.kafka.common.utils.ImplicitLinkedHashCollection"; + + static final String IMPLICIT_LINKED_HASH_MULTI_COLLECTION_CLASS = + "org.apache.kafka.common.utils.ImplicitLinkedHashMultiCollection"; + + static final String UNSUPPORTED_VERSION_EXCEPTION_CLASS = + "org.apache.kafka.common.errors.UnsupportedVersionException"; + + static final String ITERATOR_CLASS = "java.util.Iterator"; + + static final String ENUM_SET_CLASS = "java.util.EnumSet"; + + static final String TYPE_CLASS = "org.apache.kafka.common.protocol.types.Type"; + + static final String FIELD_CLASS = "org.apache.kafka.common.protocol.types.Field"; + + static final String SCHEMA_CLASS = "org.apache.kafka.common.protocol.types.Schema"; + + static final String ARRAYOF_CLASS = "org.apache.kafka.common.protocol.types.ArrayOf"; + + static final String COMPACT_ARRAYOF_CLASS = "org.apache.kafka.common.protocol.types.CompactArrayOf"; + + static final String BYTES_CLASS = "org.apache.kafka.common.utils.Bytes"; + + static final String UUID_CLASS = "org.apache.kafka.common.Uuid"; + + static final String BASE_RECORDS_CLASS = "org.apache.kafka.common.record.BaseRecords"; + + static final String MEMORY_RECORDS_CLASS = "org.apache.kafka.common.record.MemoryRecords"; + + static final String REQUEST_SUFFIX = "Request"; + + static final String RESPONSE_SUFFIX = "Response"; + + static final String BYTE_UTILS_CLASS = "org.apache.kafka.common.utils.ByteUtils"; + + static 
final String STANDARD_CHARSETS = "java.nio.charset.StandardCharsets"; + + static final String TAGGED_FIELDS_SECTION_CLASS = "org.apache.kafka.common.protocol.types.Field.TaggedFieldsSection"; + + static final String OBJECT_SERIALIZATION_CACHE_CLASS = "org.apache.kafka.common.protocol.ObjectSerializationCache"; + + static final String MESSAGE_SIZE_ACCUMULATOR_CLASS = "org.apache.kafka.common.protocol.MessageSizeAccumulator"; + + static final String RAW_TAGGED_FIELD_CLASS = "org.apache.kafka.common.protocol.types.RawTaggedField"; + + static final String RAW_TAGGED_FIELD_WRITER_CLASS = "org.apache.kafka.common.protocol.types.RawTaggedFieldWriter"; + + static final String TREE_MAP_CLASS = "java.util.TreeMap"; + + static final String BYTE_BUFFER_CLASS = "java.nio.ByteBuffer"; + + static final String NAVIGABLE_MAP_CLASS = "java.util.NavigableMap"; + + static final String MAP_ENTRY_CLASS = "java.util.Map.Entry"; + + static final String JSON_NODE_CLASS = "com.fasterxml.jackson.databind.JsonNode"; + + static final String OBJECT_NODE_CLASS = "com.fasterxml.jackson.databind.node.ObjectNode"; + + static final String JSON_NODE_FACTORY_CLASS = "com.fasterxml.jackson.databind.node.JsonNodeFactory"; + + static final String BOOLEAN_NODE_CLASS = "com.fasterxml.jackson.databind.node.BooleanNode"; + + static final String SHORT_NODE_CLASS = "com.fasterxml.jackson.databind.node.ShortNode"; + + static final String INT_NODE_CLASS = "com.fasterxml.jackson.databind.node.IntNode"; + + static final String LONG_NODE_CLASS = "com.fasterxml.jackson.databind.node.LongNode"; + + static final String TEXT_NODE_CLASS = "com.fasterxml.jackson.databind.node.TextNode"; + + static final String BINARY_NODE_CLASS = "com.fasterxml.jackson.databind.node.BinaryNode"; + + static final String NULL_NODE_CLASS = "com.fasterxml.jackson.databind.node.NullNode"; + + static final String ARRAY_NODE_CLASS = "com.fasterxml.jackson.databind.node.ArrayNode"; + + static final String DOUBLE_NODE_CLASS = "com.fasterxml.jackson.databind.node.DoubleNode"; + + /** + * The Jackson serializer we use for JSON objects. 
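+     * It is configured to read the message spec files leniently: comments are allowed
+     * and a single value may stand in for a one-element array, while trailing tokens
+     * after a spec are rejected. Null and empty values are omitted when writing JSON.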
+     */
+    static final ObjectMapper JSON_SERDE;
+
+    static {
+        JSON_SERDE = new ObjectMapper();
+        JSON_SERDE.configure(SerializationFeature.FAIL_ON_EMPTY_BEANS, false);
+        JSON_SERDE.configure(DeserializationFeature.ACCEPT_SINGLE_VALUE_AS_ARRAY, true);
+        JSON_SERDE.configure(DeserializationFeature.FAIL_ON_TRAILING_TOKENS, true);
+        JSON_SERDE.configure(JsonParser.Feature.ALLOW_COMMENTS, true);
+        JSON_SERDE.setSerializationInclusion(JsonInclude.Include.NON_EMPTY);
+    }
+
+    private static List<TypeClassGenerator> createTypeClassGenerators(String packageName,
+                                                                      List<String> types) {
+        if (types == null) return Collections.emptyList();
+        List<TypeClassGenerator> generators = new ArrayList<>();
+        for (String type : types) {
+            switch (type) {
+                case "ApiMessageTypeGenerator":
+                    generators.add(new ApiMessageTypeGenerator(packageName));
+                    break;
+                case "MetadataRecordTypeGenerator":
+                    generators.add(new MetadataRecordTypeGenerator(packageName));
+                    break;
+                case "MetadataJsonConvertersGenerator":
+                    generators.add(new MetadataJsonConvertersGenerator(packageName));
+                    break;
+                default:
+                    throw new RuntimeException("Unknown type class generator type '" + type + "'");
+            }
+        }
+        return generators;
+    }
+
+    private static List<MessageClassGenerator> createMessageClassGenerators(String packageName,
+                                                                            List<String> types) {
+        if (types == null) return Collections.emptyList();
+        List<MessageClassGenerator> generators = new ArrayList<>();
+        for (String type : types) {
+            switch (type) {
+                case "MessageDataGenerator":
+                    generators.add(new MessageDataGenerator(packageName));
+                    break;
+                case "JsonConverterGenerator":
+                    generators.add(new JsonConverterGenerator(packageName));
+                    break;
+                default:
+                    throw new RuntimeException("Unknown message class generator type '" + type + "'");
+            }
+        }
+        return generators;
+    }
+
+    public static void processDirectories(String packageName,
+                                          String outputDir,
+                                          String inputDir,
+                                          List<String> typeClassGeneratorTypes,
+                                          List<String> messageClassGeneratorTypes) throws Exception {
+        Files.createDirectories(Paths.get(outputDir));
+        int numProcessed = 0;
+
+        List<TypeClassGenerator> typeClassGenerators =
+            createTypeClassGenerators(packageName, typeClassGeneratorTypes);
+        HashSet<String> outputFileNames = new HashSet<>();
+        try (DirectoryStream<Path> directoryStream = Files
+                .newDirectoryStream(Paths.get(inputDir), JSON_GLOB)) {
+            for (Path inputPath : directoryStream) {
+                try {
+                    MessageSpec spec = JSON_SERDE.
+                        readValue(inputPath.toFile(), MessageSpec.class);
+                    List<MessageClassGenerator> generators =
+                        createMessageClassGenerators(packageName, messageClassGeneratorTypes);
+                    for (MessageClassGenerator generator : generators) {
+                        String name = generator.outputName(spec) + JAVA_SUFFIX;
+                        outputFileNames.add(name);
+                        Path outputPath = Paths.get(outputDir, name);
+                        try (BufferedWriter writer = Files.newBufferedWriter(outputPath)) {
+                            generator.generateAndWrite(spec, writer);
+                        }
+                    }
+                    numProcessed++;
+                    typeClassGenerators.forEach(generator -> generator.registerMessageType(spec));
+                } catch (Exception e) {
+                    throw new RuntimeException("Exception while processing " + inputPath.toString(), e);
+                }
+            }
+        }
+        for (TypeClassGenerator typeClassGenerator : typeClassGenerators) {
+            outputFileNames.add(typeClassGenerator.outputName());
+            Path factoryOutputPath = Paths.get(outputDir, typeClassGenerator.outputName());
+            try (BufferedWriter writer = Files.newBufferedWriter(factoryOutputPath)) {
+                typeClassGenerator.generateAndWrite(writer);
+            }
+        }
+        try (DirectoryStream<Path> directoryStream = Files.
+ newDirectoryStream(Paths.get(outputDir))) { + for (Path outputPath : directoryStream) { + Path fileName = outputPath.getFileName(); + if (fileName != null) { + if (!outputFileNames.contains(fileName.toString())) { + Files.delete(outputPath); + } + } + } + } + System.out.printf("MessageGenerator: processed %d Kafka message JSON files(s).%n", numProcessed); + } + + static String capitalizeFirst(String string) { + if (string.isEmpty()) { + return string; + } + return string.substring(0, 1).toUpperCase(Locale.ENGLISH) + + string.substring(1); + } + + static String lowerCaseFirst(String string) { + if (string.isEmpty()) { + return string; + } + return string.substring(0, 1).toLowerCase(Locale.ENGLISH) + + string.substring(1); + } + + static boolean firstIsCapitalized(String string) { + if (string.isEmpty()) { + return false; + } + return Character.isUpperCase(string.charAt(0)); + } + + static String toSnakeCase(String string) { + StringBuilder bld = new StringBuilder(); + boolean prevWasCapitalized = true; + for (int i = 0; i < string.length(); i++) { + char c = string.charAt(i); + if (Character.isUpperCase(c)) { + if (!prevWasCapitalized) { + bld.append('_'); + } + bld.append(Character.toLowerCase(c)); + prevWasCapitalized = true; + } else { + bld.append(c); + prevWasCapitalized = false; + } + } + return bld.toString(); + } + + static String stripSuffix(String str, String suffix) { + if (str.endsWith(suffix)) { + return str.substring(0, str.length() - suffix.length()); + } else { + throw new RuntimeException("String " + str + " does not end with the " + + "expected suffix " + suffix); + } + } + + /** + * Return the number of bytes needed to encode an integer in unsigned variable-length format. + */ + static int sizeOfUnsignedVarint(int value) { + int bytes = 1; + while ((value & 0xffffff80) != 0L) { + bytes += 1; + value >>>= 7; + } + return bytes; + } + + public static void main(String[] args) throws Exception { + ArgumentParser parser = ArgumentParsers + .newArgumentParser("message-generator") + .defaultHelp(true) + .description("The Kafka message generator"); + parser.addArgument("--package", "-p") + .action(store()) + .required(true) + .metavar("PACKAGE") + .help("The java package to use in generated files."); + parser.addArgument("--output", "-o") + .action(store()) + .required(true) + .metavar("OUTPUT") + .help("The output directory to create."); + parser.addArgument("--input", "-i") + .action(store()) + .required(true) + .metavar("INPUT") + .help("The input directory to use."); + parser.addArgument("--typeclass-generators", "-t") + .nargs("+") + .action(store()) + .metavar("TYPECLASS_GENERATORS") + .help("The type class generators to use, if any."); + parser.addArgument("--message-class-generators", "-m") + .nargs("+") + .action(store()) + .metavar("MESSAGE_CLASS_GENERATORS") + .help("The message class generators to use."); + Namespace res = parser.parseArgsOrFail(args); + processDirectories(res.getString("package"), res.getString("output"), + res.getString("input"), res.getList("typeclass_generators"), + res.getList("message_class_generators")); + } +} diff --git a/generator/src/main/java/org/apache/kafka/message/MessageSpec.java b/generator/src/main/java/org/apache/kafka/message/MessageSpec.java new file mode 100644 index 0000000..82866be --- /dev/null +++ b/generator/src/main/java/org/apache/kafka/message/MessageSpec.java @@ -0,0 +1,140 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.message; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Objects; +import java.util.Optional; + +public final class MessageSpec { + private final StructSpec struct; + + private final Optional apiKey; + + private final MessageSpecType type; + + private final List commonStructs; + + private final Versions flexibleVersions; + + private final List listeners; + + @JsonCreator + public MessageSpec(@JsonProperty("name") String name, + @JsonProperty("validVersions") String validVersions, + @JsonProperty("fields") List fields, + @JsonProperty("apiKey") Short apiKey, + @JsonProperty("type") MessageSpecType type, + @JsonProperty("commonStructs") List commonStructs, + @JsonProperty("flexibleVersions") String flexibleVersions, + @JsonProperty("listeners") List listeners) { + this.struct = new StructSpec(name, validVersions, fields); + this.apiKey = apiKey == null ? Optional.empty() : Optional.of(apiKey); + this.type = Objects.requireNonNull(type); + this.commonStructs = commonStructs == null ? Collections.emptyList() : + Collections.unmodifiableList(new ArrayList<>(commonStructs)); + if (flexibleVersions == null) { + throw new RuntimeException("You must specify a value for flexibleVersions. " + + "Please use 0+ for all new messages."); + } + this.flexibleVersions = Versions.parse(flexibleVersions, Versions.NONE); + if ((!this.flexibleVersions().empty()) && + (this.flexibleVersions.highest() < Short.MAX_VALUE)) { + throw new RuntimeException("Field " + name + " specifies flexibleVersions " + + this.flexibleVersions + ", which is not open-ended. 
flexibleVersions must " + + "be either none, or an open-ended range (that ends with a plus sign)."); + } + + if (listeners != null && !listeners.isEmpty() && type != MessageSpecType.REQUEST) { + throw new RuntimeException("The `requestScope` property is only valid for " + + "messages with type `request`"); + } + this.listeners = listeners; + } + + public StructSpec struct() { + return struct; + } + + @JsonProperty("name") + public String name() { + return struct.name(); + } + + public Versions validVersions() { + return struct.versions(); + } + + @JsonProperty("validVersions") + public String validVersionsString() { + return struct.versionsString(); + } + + @JsonProperty("fields") + public List fields() { + return struct.fields(); + } + + @JsonProperty("apiKey") + public Optional apiKey() { + return apiKey; + } + + @JsonProperty("type") + public MessageSpecType type() { + return type; + } + + @JsonProperty("commonStructs") + public List commonStructs() { + return commonStructs; + } + + public Versions flexibleVersions() { + return flexibleVersions; + } + + @JsonProperty("flexibleVersions") + public String flexibleVersionsString() { + return flexibleVersions.toString(); + } + + @JsonProperty("listeners") + public List listeners() { + return listeners; + } + + public String dataClassName() { + switch (type) { + case HEADER: + case REQUEST: + case RESPONSE: + // We append the Data suffix to request/response/header classes to avoid + // collisions with existing objects. This can go away once the protocols + // have all been converted and we begin using the generated types directly. + return struct.name() + "Data"; + default: + return struct.name(); + } + } +} diff --git a/generator/src/main/java/org/apache/kafka/message/MessageSpecType.java b/generator/src/main/java/org/apache/kafka/message/MessageSpecType.java new file mode 100644 index 0000000..d1110d9 --- /dev/null +++ b/generator/src/main/java/org/apache/kafka/message/MessageSpecType.java @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.message; + +import com.fasterxml.jackson.annotation.JsonProperty; + +public enum MessageSpecType { + /** + * Kafka request RPCs. + */ + @JsonProperty("request") + REQUEST, + + /** + * Kafka response RPCs. + */ + @JsonProperty("response") + RESPONSE, + + /** + * Kafka RPC headers. + */ + @JsonProperty("header") + HEADER, + + /** + * KIP-631 controller records. + */ + @JsonProperty("metadata") + METADATA, + + /** + * Other message spec types. 
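+ * Unlike requests, responses, and headers, a data spec keeps the plain struct name
+ * as its generated class name; no "Data" suffix is appended.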
+ */ + @JsonProperty("data") + DATA; +} diff --git a/generator/src/main/java/org/apache/kafka/message/MetadataJsonConvertersGenerator.java b/generator/src/main/java/org/apache/kafka/message/MetadataJsonConvertersGenerator.java new file mode 100644 index 0000000..13321f0 --- /dev/null +++ b/generator/src/main/java/org/apache/kafka/message/MetadataJsonConvertersGenerator.java @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.message; + +import java.io.BufferedWriter; +import java.io.IOException; +import java.util.Map; +import java.util.TreeMap; + +public class MetadataJsonConvertersGenerator implements TypeClassGenerator { + private final HeaderGenerator headerGenerator; + private final CodeBuffer buffer; + private final TreeMap apis; + + public MetadataJsonConvertersGenerator(String packageName) { + this.headerGenerator = new HeaderGenerator(packageName); + this.apis = new TreeMap<>(); + this.buffer = new CodeBuffer(); + } + + @Override + public String outputName() { + return MessageGenerator.METADATA_JSON_CONVERTERS_JAVA; + } + + @Override + public void registerMessageType(MessageSpec spec) { + if (spec.type() == MessageSpecType.METADATA) { + short id = spec.apiKey().get(); + MessageSpec prevSpec = apis.put(id, spec); + if (prevSpec != null) { + throw new RuntimeException("Duplicate metadata record entry for type " + + id + ". Original claimant: " + prevSpec.name() + ". 
New " + + "claimant: " + spec.name()); + } + } + } + + @Override + public void generateAndWrite(BufferedWriter writer) throws IOException { + buffer.printf("public class MetadataJsonConverters {%n"); + buffer.incrementIndent(); + generateWriteJson(); + buffer.printf("%n"); + generateReadJson(); + buffer.printf("%n"); + buffer.decrementIndent(); + buffer.printf("}%n"); + headerGenerator.generate(); + + headerGenerator.buffer().write(writer); + buffer.write(writer); + } + + private void generateWriteJson() { + headerGenerator.addImport(MessageGenerator.JSON_NODE_CLASS); + headerGenerator.addImport(MessageGenerator.API_MESSAGE_CLASS); + + buffer.printf("public static JsonNode writeJson(ApiMessage apiMessage, short apiVersion) {%n"); + buffer.incrementIndent(); + buffer.printf("switch (apiMessage.apiKey()) {%n"); + buffer.incrementIndent(); + for (Map.Entry entry : apis.entrySet()) { + String apiMessageClassName = MessageGenerator.capitalizeFirst(entry.getValue().name()); + buffer.printf("case %d:%n", entry.getKey()); + buffer.incrementIndent(); + buffer.printf("return %sJsonConverter.write((%s) apiMessage, apiVersion);%n", apiMessageClassName, apiMessageClassName); + buffer.decrementIndent(); + } + buffer.printf("default:%n"); + buffer.incrementIndent(); + headerGenerator.addImport(MessageGenerator.UNSUPPORTED_VERSION_EXCEPTION_CLASS); + buffer.printf("throw new UnsupportedVersionException(\"Unknown metadata id \"" + + " + apiMessage.apiKey());%n"); + buffer.decrementIndent(); + buffer.decrementIndent(); + buffer.printf("}%n"); + buffer.decrementIndent(); + buffer.printf("}%n"); + } + + private void generateReadJson() { + headerGenerator.addImport(MessageGenerator.JSON_NODE_CLASS); + headerGenerator.addImport(MessageGenerator.API_MESSAGE_CLASS); + + buffer.printf("public static ApiMessage readJson(JsonNode json, short apiKey, short apiVersion) {%n"); + buffer.incrementIndent(); + buffer.printf("switch (apiKey) {%n"); + buffer.incrementIndent(); + for (Map.Entry entry : apis.entrySet()) { + String apiMessageClassName = MessageGenerator.capitalizeFirst(entry.getValue().name()); + buffer.printf("case %d:%n", entry.getKey()); + buffer.incrementIndent(); + buffer.printf("return %sJsonConverter.read(json, apiVersion);%n", apiMessageClassName); + buffer.decrementIndent(); + } + buffer.printf("default:%n"); + buffer.incrementIndent(); + headerGenerator.addImport(MessageGenerator.UNSUPPORTED_VERSION_EXCEPTION_CLASS); + buffer.printf("throw new UnsupportedVersionException(\"Unknown metadata id \"" + + " + apiKey);%n"); + buffer.decrementIndent(); + buffer.decrementIndent(); + buffer.printf("}%n"); + buffer.decrementIndent(); + buffer.printf("}%n"); + } +} diff --git a/generator/src/main/java/org/apache/kafka/message/MetadataRecordTypeGenerator.java b/generator/src/main/java/org/apache/kafka/message/MetadataRecordTypeGenerator.java new file mode 100644 index 0000000..cb3db0c --- /dev/null +++ b/generator/src/main/java/org/apache/kafka/message/MetadataRecordTypeGenerator.java @@ -0,0 +1,190 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.message; + +import java.io.BufferedWriter; +import java.io.IOException; +import java.util.Locale; +import java.util.Map; +import java.util.TreeMap; + +public final class MetadataRecordTypeGenerator implements TypeClassGenerator { + private final HeaderGenerator headerGenerator; + private final CodeBuffer buffer; + private final TreeMap apis; + + public MetadataRecordTypeGenerator(String packageName) { + this.headerGenerator = new HeaderGenerator(packageName); + this.apis = new TreeMap<>(); + this.buffer = new CodeBuffer(); + } + + @Override + public String outputName() { + return MessageGenerator.METADATA_RECORD_TYPE_JAVA; + } + + @Override + public void registerMessageType(MessageSpec spec) { + if (spec.type() == MessageSpecType.METADATA) { + short id = spec.apiKey().get(); + MessageSpec prevSpec = apis.put(id, spec); + if (prevSpec != null) { + throw new RuntimeException("Duplicate metadata record entry for type " + + id + ". Original claimant: " + prevSpec.name() + ". New " + + "claimant: " + spec.name()); + } + } + } + + @Override + public void generateAndWrite(BufferedWriter writer) throws IOException { + generate(); + write(writer); + } + + private void generate() { + buffer.printf("public enum MetadataRecordType {%n"); + buffer.incrementIndent(); + generateEnumValues(); + buffer.printf("%n"); + generateInstanceVariables(); + buffer.printf("%n"); + generateEnumConstructor(); + buffer.printf("%n"); + generateFromApiKey(); + buffer.printf("%n"); + generateNewMetadataRecord(); + buffer.printf("%n"); + generateAccessor("id", "short"); + buffer.printf("%n"); + generateAccessor("lowestSupportedVersion", "short"); + buffer.printf("%n"); + generateAccessor("highestSupportedVersion", "short"); + buffer.printf("%n"); + generateToString(); + buffer.decrementIndent(); + buffer.printf("}%n"); + headerGenerator.generate(); + } + + private void generateEnumValues() { + int numProcessed = 0; + for (Map.Entry entry : apis.entrySet()) { + MessageSpec spec = entry.getValue(); + String name = spec.name(); + numProcessed++; + buffer.printf("%s(\"%s\", (short) %d, (short) %d, (short) %d)%s%n", + MessageGenerator.toSnakeCase(name).toUpperCase(Locale.ROOT), + MessageGenerator.capitalizeFirst(name), + entry.getKey(), + entry.getValue().validVersions().lowest(), + entry.getValue().validVersions().highest(), + (numProcessed == apis.size()) ? 
";" : ","); + } + } + + private void generateInstanceVariables() { + buffer.printf("private final String name;%n"); + buffer.printf("private final short id;%n"); + buffer.printf("private final short lowestSupportedVersion;%n"); + buffer.printf("private final short highestSupportedVersion;%n"); + } + + private void generateEnumConstructor() { + buffer.printf("MetadataRecordType(String name, short id, short lowestSupportedVersion, short highestSupportedVersion) {%n"); + buffer.incrementIndent(); + buffer.printf("this.name = name;%n"); + buffer.printf("this.id = id;%n"); + buffer.printf("this.lowestSupportedVersion = lowestSupportedVersion;%n"); + buffer.printf("this.highestSupportedVersion = highestSupportedVersion;%n"); + buffer.decrementIndent(); + buffer.printf("}%n"); + } + + private void generateFromApiKey() { + buffer.printf("public static MetadataRecordType fromId(short id) {%n"); + buffer.incrementIndent(); + buffer.printf("switch (id) {%n"); + buffer.incrementIndent(); + for (Map.Entry entry : apis.entrySet()) { + buffer.printf("case %d:%n", entry.getKey()); + buffer.incrementIndent(); + buffer.printf("return %s;%n", MessageGenerator. + toSnakeCase(entry.getValue().name()).toUpperCase(Locale.ROOT)); + buffer.decrementIndent(); + } + buffer.printf("default:%n"); + buffer.incrementIndent(); + headerGenerator.addImport(MessageGenerator.UNSUPPORTED_VERSION_EXCEPTION_CLASS); + buffer.printf("throw new UnsupportedVersionException(\"Unknown metadata id \"" + + " + id);%n"); + buffer.decrementIndent(); + buffer.decrementIndent(); + buffer.printf("}%n"); + buffer.decrementIndent(); + buffer.printf("}%n"); + } + + private void generateNewMetadataRecord() { + headerGenerator.addImport(MessageGenerator.API_MESSAGE_CLASS); + buffer.printf("public ApiMessage newMetadataRecord() {%n"); + buffer.incrementIndent(); + buffer.printf("switch (id) {%n"); + buffer.incrementIndent(); + for (Map.Entry entry : apis.entrySet()) { + buffer.printf("case %d:%n", entry.getKey()); + buffer.incrementIndent(); + buffer.printf("return new %s();%n", + MessageGenerator.capitalizeFirst(entry.getValue().name())); + buffer.decrementIndent(); + } + buffer.printf("default:%n"); + buffer.incrementIndent(); + headerGenerator.addImport(MessageGenerator.UNSUPPORTED_VERSION_EXCEPTION_CLASS); + buffer.printf("throw new UnsupportedVersionException(\"Unknown metadata id \"" + + " + id);%n"); + buffer.decrementIndent(); + buffer.decrementIndent(); + buffer.printf("}%n"); + buffer.decrementIndent(); + buffer.printf("}%n"); + } + + private void generateAccessor(String name, String type) { + buffer.printf("public %s %s() {%n", type, name); + buffer.incrementIndent(); + buffer.printf("return this.%s;%n", name); + buffer.decrementIndent(); + buffer.printf("}%n"); + } + + private void generateToString() { + buffer.printf("@Override%n"); + buffer.printf("public String toString() {%n"); + buffer.incrementIndent(); + buffer.printf("return this.name();%n"); + buffer.decrementIndent(); + buffer.printf("}%n"); + } + + private void write(BufferedWriter writer) throws IOException { + headerGenerator.buffer().write(writer); + buffer.write(writer); + } +} diff --git a/generator/src/main/java/org/apache/kafka/message/RequestListenerType.java b/generator/src/main/java/org/apache/kafka/message/RequestListenerType.java new file mode 100644 index 0000000..cefd40d --- /dev/null +++ b/generator/src/main/java/org/apache/kafka/message/RequestListenerType.java @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * 
contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.message; + +import com.fasterxml.jackson.annotation.JsonProperty; + +public enum RequestListenerType { + @JsonProperty("zkBroker") + ZK_BROKER, + + @JsonProperty("broker") + BROKER, + + @JsonProperty("controller") + CONTROLLER; +} diff --git a/generator/src/main/java/org/apache/kafka/message/SchemaGenerator.java b/generator/src/main/java/org/apache/kafka/message/SchemaGenerator.java new file mode 100644 index 0000000..5ebd158 --- /dev/null +++ b/generator/src/main/java/org/apache/kafka/message/SchemaGenerator.java @@ -0,0 +1,371 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.message; + +import java.util.HashMap; +import java.util.Iterator; +import java.util.Map; +import java.util.TreeMap; + +/** + * Generates Schemas for Kafka MessageData classes. + */ +final class SchemaGenerator { + /** + * Schema information for a particular message. + */ + static class MessageInfo { + /** + * The versions of this message that we want to generate a schema for. + * This will be constrained by the valid versions for the parent objects. + * For example, if the parent message is valid for versions 0 and 1, + * we will only generate a version 0 and version 1 schema for child classes, + * even if their valid versions are "0+". + */ + private final Versions versions; + + /** + * Maps versions to schema declaration code. If the schema for a + * particular version is the same as that of a previous version, + * there will be no entry in the map for it. + */ + private final TreeMap schemaForVersion; + + MessageInfo(Versions versions) { + this.versions = versions; + this.schemaForVersion = new TreeMap<>(); + } + } + + /** + * The header file generator. This is shared with the MessageDataGenerator + * instance that owns this SchemaGenerator. + */ + private final HeaderGenerator headerGenerator; + + /** + * A registry with the structures we're generating. + */ + private final StructRegistry structRegistry; + + /** + * Maps message names to message information. 
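+ * The map is keyed by generated class name: the data class name for the top-level
+ * message, and the struct or element type name for common and nested structures.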
+ */ + private final Map messages; + + /** + * The versions that implement a KIP-482 flexible schema. + */ + private Versions messageFlexibleVersions; + + SchemaGenerator(HeaderGenerator headerGenerator, StructRegistry structRegistry) { + this.headerGenerator = headerGenerator; + this.structRegistry = structRegistry; + this.messages = new HashMap<>(); + } + + void generateSchemas(MessageSpec message) throws Exception { + this.messageFlexibleVersions = message.flexibleVersions(); + + // First generate schemas for common structures so that they are + // available when we generate the inline structures + for (Iterator iter = structRegistry.commonStructs(); iter.hasNext(); ) { + StructSpec struct = iter.next(); + generateSchemas(struct.name(), struct, message.struct().versions()); + } + + // Generate schemas for inline structures + generateSchemas(message.dataClassName(), message.struct(), + message.struct().versions()); + } + + void generateSchemas(String className, StructSpec struct, + Versions parentVersions) throws Exception { + Versions versions = parentVersions.intersect(struct.versions()); + MessageInfo messageInfo = messages.get(className); + if (messageInfo != null) { + return; + } + messageInfo = new MessageInfo(versions); + messages.put(className, messageInfo); + // Process the leaf classes first. + for (FieldSpec field : struct.fields()) { + if (field.type().isStructArray()) { + FieldType.ArrayType arrayType = (FieldType.ArrayType) field.type(); + generateSchemas(arrayType.elementType().toString(), structRegistry.findStruct(field), versions); + } else if (field.type().isStruct()) { + generateSchemas(field.type().toString(), structRegistry.findStruct(field), versions); + } + } + CodeBuffer prev = null; + for (short v = versions.lowest(); v <= versions.highest(); v++) { + CodeBuffer cur = new CodeBuffer(); + generateSchemaForVersion(struct, v, cur); + // If this schema version is different from the previous one, + // create a new map entry. + if (!cur.equals(prev)) { + messageInfo.schemaForVersion.put(v, cur); + } + prev = cur; + } + } + + private void generateSchemaForVersion(StructSpec struct, + short version, + CodeBuffer buffer) throws Exception { + // Find the last valid field index. + int lastValidIndex = struct.fields().size() - 1; + while (true) { + if (lastValidIndex < 0) { + break; + } + FieldSpec field = struct.fields().get(lastValidIndex); + if ((!field.taggedVersions().contains(version)) && + field.versions().contains(version)) { + break; + } + lastValidIndex--; + } + int finalLine = lastValidIndex; + if (messageFlexibleVersions.contains(version)) { + finalLine++; + } + + headerGenerator.addImport(MessageGenerator.SCHEMA_CLASS); + buffer.printf("new Schema(%n"); + buffer.incrementIndent(); + for (int i = 0; i <= lastValidIndex; i++) { + FieldSpec field = struct.fields().get(i); + if ((!field.versions().contains(version)) || + field.taggedVersions().contains(version)) { + continue; + } + Versions fieldFlexibleVersions = + field.flexibleVersions().orElse(messageFlexibleVersions); + headerGenerator.addImport(MessageGenerator.FIELD_CLASS); + buffer.printf("new Field(\"%s\", %s, \"%s\")%s%n", + field.snakeCaseName(), + fieldTypeToSchemaType(field, version, fieldFlexibleVersions), + field.about(), + i == finalLine ? 
"" : ","); + } + if (messageFlexibleVersions.contains(version)) { + generateTaggedFieldsSchemaForVersion(struct, version, buffer); + } + buffer.decrementIndent(); + buffer.printf(");%n"); + } + + private void generateTaggedFieldsSchemaForVersion(StructSpec struct, + short version, CodeBuffer buffer) throws Exception { + headerGenerator.addStaticImport(MessageGenerator.TAGGED_FIELDS_SECTION_CLASS); + + // Find the last valid tagged field index. + int lastValidIndex = struct.fields().size() - 1; + while (true) { + if (lastValidIndex < 0) { + break; + } + FieldSpec field = struct.fields().get(lastValidIndex); + if ((field.taggedVersions().contains(version)) && + field.versions().contains(version)) { + break; + } + lastValidIndex--; + } + + buffer.printf("TaggedFieldsSection.of(%n"); + buffer.incrementIndent(); + for (int i = 0; i <= lastValidIndex; i++) { + FieldSpec field = struct.fields().get(i); + if ((!field.versions().contains(version)) || + (!field.taggedVersions().contains(version))) { + continue; + } + headerGenerator.addImport(MessageGenerator.FIELD_CLASS); + Versions fieldFlexibleVersions = + field.flexibleVersions().orElse(messageFlexibleVersions); + buffer.printf("%d, new Field(\"%s\", %s, \"%s\")%s%n", + field.tag().get(), + field.snakeCaseName(), + fieldTypeToSchemaType(field, version, fieldFlexibleVersions), + field.about(), + i == lastValidIndex ? "" : ","); + } + buffer.decrementIndent(); + buffer.printf(")%n"); + } + + private String fieldTypeToSchemaType(FieldSpec field, + short version, + Versions fieldFlexibleVersions) { + return fieldTypeToSchemaType(field.type(), + field.nullableVersions().contains(version), + version, + fieldFlexibleVersions, + field.zeroCopy()); + } + + private String fieldTypeToSchemaType(FieldType type, + boolean nullable, + short version, + Versions fieldFlexibleVersions, + boolean zeroCopy) { + if (type instanceof FieldType.BoolFieldType) { + headerGenerator.addImport(MessageGenerator.TYPE_CLASS); + if (nullable) { + throw new RuntimeException("Type " + type + " cannot be nullable."); + } + return "Type.BOOLEAN"; + } else if (type instanceof FieldType.Int8FieldType) { + headerGenerator.addImport(MessageGenerator.TYPE_CLASS); + if (nullable) { + throw new RuntimeException("Type " + type + " cannot be nullable."); + } + return "Type.INT8"; + } else if (type instanceof FieldType.Int16FieldType) { + headerGenerator.addImport(MessageGenerator.TYPE_CLASS); + if (nullable) { + throw new RuntimeException("Type " + type + " cannot be nullable."); + } + return "Type.INT16"; + } else if (type instanceof FieldType.Uint16FieldType) { + headerGenerator.addImport(MessageGenerator.TYPE_CLASS); + if (nullable) { + throw new RuntimeException("Type " + type + " cannot be nullable."); + } + return "Type.UINT16"; + } else if (type instanceof FieldType.Int32FieldType) { + headerGenerator.addImport(MessageGenerator.TYPE_CLASS); + if (nullable) { + throw new RuntimeException("Type " + type + " cannot be nullable."); + } + return "Type.INT32"; + } else if (type instanceof FieldType.Int64FieldType) { + headerGenerator.addImport(MessageGenerator.TYPE_CLASS); + if (nullable) { + throw new RuntimeException("Type " + type + " cannot be nullable."); + } + return "Type.INT64"; + } else if (type instanceof FieldType.UUIDFieldType) { + headerGenerator.addImport(MessageGenerator.TYPE_CLASS); + if (nullable) { + throw new RuntimeException("Type " + type + " cannot be nullable."); + } + return "Type.UUID"; + } else if (type instanceof FieldType.Float64FieldType) { + 
headerGenerator.addImport(MessageGenerator.TYPE_CLASS); + if (nullable) { + throw new RuntimeException("Type " + type + " cannot be nullable."); + } + return "Type.FLOAT64"; + } else if (type instanceof FieldType.StringFieldType) { + headerGenerator.addImport(MessageGenerator.TYPE_CLASS); + if (fieldFlexibleVersions.contains(version)) { + return nullable ? "Type.COMPACT_NULLABLE_STRING" : "Type.COMPACT_STRING"; + } else { + return nullable ? "Type.NULLABLE_STRING" : "Type.STRING"; + } + } else if (type instanceof FieldType.BytesFieldType) { + headerGenerator.addImport(MessageGenerator.TYPE_CLASS); + if (fieldFlexibleVersions.contains(version)) { + return nullable ? "Type.COMPACT_NULLABLE_BYTES" : "Type.COMPACT_BYTES"; + } else { + return nullable ? "Type.NULLABLE_BYTES" : "Type.BYTES"; + } + } else if (type.isRecords()) { + headerGenerator.addImport(MessageGenerator.TYPE_CLASS); + if (fieldFlexibleVersions.contains(version)) { + return "Type.COMPACT_RECORDS"; + } else { + return "Type.RECORDS"; + } + } else if (type.isArray()) { + if (fieldFlexibleVersions.contains(version)) { + headerGenerator.addImport(MessageGenerator.COMPACT_ARRAYOF_CLASS); + FieldType.ArrayType arrayType = (FieldType.ArrayType) type; + String prefix = nullable ? "CompactArrayOf.nullable" : "new CompactArrayOf"; + return String.format("%s(%s)", prefix, + fieldTypeToSchemaType(arrayType.elementType(), false, version, fieldFlexibleVersions, false)); + + } else { + headerGenerator.addImport(MessageGenerator.ARRAYOF_CLASS); + FieldType.ArrayType arrayType = (FieldType.ArrayType) type; + String prefix = nullable ? "ArrayOf.nullable" : "new ArrayOf"; + return String.format("%s(%s)", prefix, + fieldTypeToSchemaType(arrayType.elementType(), false, version, fieldFlexibleVersions, false)); + } + } else if (type.isStruct()) { + if (nullable) { + throw new RuntimeException("Type " + type + " cannot be nullable."); + } + return String.format("%s.SCHEMA_%d", type.toString(), + floorVersion(type.toString(), version)); + } else { + throw new RuntimeException("Unsupported type " + type); + } + } + + /** + * Find the lowest schema version for a given class that is the same as the + * given version. + */ + private short floorVersion(String className, short v) { + MessageInfo message = messages.get(className); + return message.schemaForVersion.floorKey(v); + } + + /** + * Write the message schema to the provided buffer. + * + * @param className The class name. + * @param buffer The destination buffer. + */ + void writeSchema(String className, CodeBuffer buffer) throws Exception { + MessageInfo messageInfo = messages.get(className); + Versions versions = messageInfo.versions; + + for (short v = versions.lowest(); v <= versions.highest(); v++) { + CodeBuffer declaration = messageInfo.schemaForVersion.get(v); + if (declaration == null) { + buffer.printf("public static final Schema SCHEMA_%d = SCHEMA_%d;%n", v, v - 1); + } else { + buffer.printf("public static final Schema SCHEMA_%d =%n", v); + buffer.incrementIndent(); + declaration.write(buffer); + buffer.decrementIndent(); + } + buffer.printf("%n"); + } + buffer.printf("public static final Schema[] SCHEMAS = new Schema[] {%n"); + buffer.incrementIndent(); + for (short v = 0; v < versions.lowest(); v++) { + buffer.printf("null%s%n", (v == versions.highest()) ? "" : ","); + } + for (short v = versions.lowest(); v <= versions.highest(); v++) { + buffer.printf("SCHEMA_%d%s%n", v, (v == versions.highest()) ? 
"" : ","); + } + buffer.decrementIndent(); + buffer.printf("};%n"); + buffer.printf("%n"); + + buffer.printf("public static final short LOWEST_SUPPORTED_VERSION = %d;%n", versions.lowest()); + buffer.printf("public static final short HIGHEST_SUPPORTED_VERSION = %d;%n", versions.highest()); + buffer.printf("%n"); + } +} diff --git a/generator/src/main/java/org/apache/kafka/message/StructRegistry.java b/generator/src/main/java/org/apache/kafka/message/StructRegistry.java new file mode 100644 index 0000000..fef66fb --- /dev/null +++ b/generator/src/main/java/org/apache/kafka/message/StructRegistry.java @@ -0,0 +1,192 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.message; + +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.TreeMap; +import java.util.TreeSet; + +/** + * Contains structure data for Kafka MessageData classes. + */ +final class StructRegistry { + private final Map structs; + private final Set commonStructNames; + + static class StructInfo { + /** + * The specification for this structure. + */ + private final StructSpec spec; + + /** + * The versions which the parent(s) of this structure can have. If this is a + * top-level structure, this will be equal to the versions which the + * overall message can have. + */ + private final Versions parentVersions; + + StructInfo(StructSpec spec, Versions parentVersions) { + this.spec = spec; + this.parentVersions = parentVersions; + } + + public StructSpec spec() { + return spec; + } + + public Versions parentVersions() { + return parentVersions; + } + } + + StructRegistry() { + this.structs = new TreeMap<>(); + this.commonStructNames = new TreeSet<>(); + } + + /** + * Register all the structures contained a message spec. + */ + void register(MessageSpec message) throws Exception { + // Register common structures. + for (StructSpec struct : message.commonStructs()) { + if (!MessageGenerator.firstIsCapitalized(struct.name())) { + throw new RuntimeException("Can't process structure " + struct.name() + + ": the first letter of structure names must be capitalized."); + } + if (structs.containsKey(struct.name())) { + throw new RuntimeException("Common struct " + struct.name() + " was specified twice."); + } + structs.put(struct.name(), new StructInfo(struct, struct.versions())); + commonStructNames.add(struct.name()); + } + // Register inline structures. 
+ addStructSpecs(message.validVersions(), message.fields()); + } + + @SuppressWarnings("unchecked") + private void addStructSpecs(Versions parentVersions, List fields) { + for (FieldSpec field : fields) { + String typeName = null; + if (field.type().isStructArray()) { + FieldType.ArrayType arrayType = (FieldType.ArrayType) field.type(); + typeName = arrayType.elementName(); + } else if (field.type().isStruct()) { + FieldType.StructType structType = (FieldType.StructType) field.type(); + typeName = structType.typeName(); + } + if (typeName != null) { + if (commonStructNames.contains(typeName)) { + // If we're using a common structure, we can't specify its fields. + // The fields should be specified in the commonStructs area. + if (!field.fields().isEmpty()) { + throw new RuntimeException("Can't re-specify the common struct " + + typeName + " as an inline struct."); + } + } else if (structs.containsKey(typeName)) { + // Inline structures should only appear once. + throw new RuntimeException("Struct " + typeName + + " was specified twice."); + } else { + // Synthesize a StructSpec object out of the fields. + StructSpec spec = new StructSpec(typeName, + field.versions().toString(), + field.fields()); + structs.put(typeName, new StructInfo(spec, parentVersions)); + } + + addStructSpecs(parentVersions.intersect(field.versions()), field.fields()); + } + } + } + + /** + * Locate the struct corresponding to a field. + */ + @SuppressWarnings("unchecked") + StructSpec findStruct(FieldSpec field) { + String structFieldName; + if (field.type().isArray()) { + FieldType.ArrayType arrayType = (FieldType.ArrayType) field.type(); + structFieldName = arrayType.elementName(); + } else if (field.type().isStruct()) { + FieldType.StructType structType = (FieldType.StructType) field.type(); + structFieldName = structType.typeName(); + } else { + throw new RuntimeException("Field " + field.name() + + " cannot be treated as a structure."); + } + StructInfo structInfo = structs.get(structFieldName); + if (structInfo == null) { + throw new RuntimeException("Unable to locate a specification for the structure " + + structFieldName); + } + return structInfo.spec; + } + + /** + * Return true if the field is a struct array with keys. + */ + @SuppressWarnings("unchecked") + boolean isStructArrayWithKeys(FieldSpec field) { + if (!field.type().isArray()) { + return false; + } + FieldType.ArrayType arrayType = (FieldType.ArrayType) field.type(); + if (!arrayType.isStructArray()) { + return false; + } + StructInfo structInfo = structs.get(arrayType.elementName()); + if (structInfo == null) { + throw new RuntimeException("Unable to locate a specification for the structure " + + arrayType.elementName()); + } + return structInfo.spec.hasKeys(); + } + + Set commonStructNames() { + return commonStructNames; + } + + /** + * Returns an iterator that will step through all the common structures. 
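+ * Because the registry stores common struct names in a TreeSet, the structures are
+ * returned in sorted order by name.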
+ */ + Iterator commonStructs() { + return new Iterator() { + private final Iterator iter = commonStructNames.iterator(); + + @Override + public boolean hasNext() { + return iter.hasNext(); + } + + @Override + public StructSpec next() { + return structs.get(iter.next()).spec; + } + }; + } + + Iterator structs() { + return structs.values().iterator(); + } +} diff --git a/generator/src/main/java/org/apache/kafka/message/StructSpec.java b/generator/src/main/java/org/apache/kafka/message/StructSpec.java new file mode 100644 index 0000000..b4094da --- /dev/null +++ b/generator/src/main/java/org/apache/kafka/message/StructSpec.java @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.message; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Objects; + +public final class StructSpec { + private final String name; + + private final Versions versions; + + private final List fields; + + private final boolean hasKeys; + + @JsonCreator + public StructSpec(@JsonProperty("name") String name, + @JsonProperty("versions") String versions, + @JsonProperty("fields") List fields) { + this.name = Objects.requireNonNull(name); + this.versions = Versions.parse(versions, null); + if (this.versions == null) { + throw new RuntimeException("You must specify the version of the " + + name + " structure."); + } + ArrayList newFields = new ArrayList<>(); + if (fields != null) { + // Each field should have a unique tag ID (if the field has a tag ID). + HashSet tags = new HashSet<>(); + for (FieldSpec field : fields) { + if (field.tag().isPresent()) { + if (tags.contains(field.tag().get())) { + throw new RuntimeException("In " + name + ", field " + field.name() + + " has a duplicate tag ID " + field.tag().get() + ". All tags IDs " + + "must be unique."); + } + tags.add(field.tag().get()); + } + newFields.add(field); + } + // Tag IDs should be contiguous and start at 0. This optimizes space on the wire, + // since larger numbers take more space. + for (int i = 0; i < tags.size(); i++) { + if (!tags.contains(i)) { + throw new RuntimeException("In " + name + ", the tag IDs are not " + + "contiguous. 
Make use of tag " + i + " before using any " + + "higher tag IDs."); + } + } + } + this.fields = Collections.unmodifiableList(newFields); + this.hasKeys = this.fields.stream().anyMatch(f -> f.mapKey()); + } + + @JsonProperty + public String name() { + return name; + } + + public Versions versions() { + return versions; + } + + @JsonProperty + public String versionsString() { + return versions.toString(); + } + + @JsonProperty + public List fields() { + return fields; + } + + boolean hasKeys() { + return hasKeys; + } +} diff --git a/generator/src/main/java/org/apache/kafka/message/Target.java b/generator/src/main/java/org/apache/kafka/message/Target.java new file mode 100644 index 0000000..a43cf2c --- /dev/null +++ b/generator/src/main/java/org/apache/kafka/message/Target.java @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.message; + +import java.util.Collections; +import java.util.function.Function; + +public final class Target { + private final FieldSpec field; + private final String sourceVariable; + private final String humanReadableName; + private final Function assignmentStatementGenerator; + + Target(FieldSpec field, String sourceVariable, String humanReadableName, + Function assignmentStatementGenerator) { + this.field = field; + this.sourceVariable = sourceVariable; + this.humanReadableName = humanReadableName; + this.assignmentStatementGenerator = assignmentStatementGenerator; + } + + public String assignmentStatement(String rightHandSide) { + return assignmentStatementGenerator.apply(rightHandSide); + } + + public Target nonNullableCopy() { + FieldSpec nonNullableField = new FieldSpec(field.name(), + field.versionsString(), + field.fields(), + field.typeString(), + field.mapKey(), + Versions.NONE.toString(), + field.defaultString(), + field.ignorable(), + field.entityType(), + field.about(), + field.taggedVersionsString(), + field.flexibleVersionsString(), + field.tagInteger(), + field.zeroCopy()); + return new Target(nonNullableField, sourceVariable, humanReadableName, assignmentStatementGenerator); + } + + public Target arrayElementTarget(Function assignmentStatementGenerator) { + if (!field.type().isArray()) { + throw new RuntimeException("Field " + field + " is not an array."); + } + FieldType.ArrayType arrayType = (FieldType.ArrayType) field.type(); + FieldSpec elementField = new FieldSpec(field.name() + "Element", + field.versions().toString(), + Collections.emptyList(), + arrayType.elementType().toString(), + false, + Versions.NONE.toString(), + "", + false, + EntityType.UNKNOWN, + "", + Versions.NONE.toString(), + field.flexibleVersionsString(), + null, + field.zeroCopy()); + return new Target(elementField, "_element", humanReadableName + " element", + 
assignmentStatementGenerator); + } + + public FieldSpec field() { + return field; + } + + public String sourceVariable() { + return sourceVariable; + } + + public String humanReadableName() { + return humanReadableName; + } +} diff --git a/generator/src/main/java/org/apache/kafka/message/TypeClassGenerator.java b/generator/src/main/java/org/apache/kafka/message/TypeClassGenerator.java new file mode 100644 index 0000000..22a0597 --- /dev/null +++ b/generator/src/main/java/org/apache/kafka/message/TypeClassGenerator.java @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.message; + +import java.io.BufferedWriter; +import java.io.IOException; + +public interface TypeClassGenerator { + /** + * The short name of the type class file we are generating. For example, + * ApiMessageType.java. + */ + String outputName(); + + /** + * Registers a message spec with the generator. + * + * @param spec The spec to register. + */ + void registerMessageType(MessageSpec spec); + + /** + * Generate the type, and then write it out. + * + * @param writer The writer to write out the state to. + */ + void generateAndWrite(BufferedWriter writer) throws IOException; +} diff --git a/generator/src/main/java/org/apache/kafka/message/VersionConditional.java b/generator/src/main/java/org/apache/kafka/message/VersionConditional.java new file mode 100644 index 0000000..9a1abfa --- /dev/null +++ b/generator/src/main/java/org/apache/kafka/message/VersionConditional.java @@ -0,0 +1,220 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.message; + +/** + * Creates an if statement based on whether or not the current version + * falls within a given range. + */ +public final class VersionConditional { + /** + * Create a version conditional. + * + * @param containingVersions The versions for which the conditional is true. + * @param possibleVersions The range of possible versions. + * @return The version conditional. 
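+ * It emits no code until an ifMember and/or ifNotMember clause is attached and
+ * generate() is called.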
+ */ + static VersionConditional forVersions(Versions containingVersions, + Versions possibleVersions) { + return new VersionConditional(containingVersions, possibleVersions); + } + + private final Versions containingVersions; + private final Versions possibleVersions; + private ClauseGenerator ifMember = null; + private ClauseGenerator ifNotMember = null; + private boolean alwaysEmitBlockScope = false; + private boolean allowMembershipCheckAlwaysFalse = true; + + private VersionConditional(Versions containingVersions, Versions possibleVersions) { + this.containingVersions = containingVersions; + this.possibleVersions = possibleVersions; + } + + VersionConditional ifMember(ClauseGenerator ifMember) { + this.ifMember = ifMember; + return this; + } + + VersionConditional ifNotMember(ClauseGenerator ifNotMember) { + this.ifNotMember = ifNotMember; + return this; + } + + /** + * If this is set, we will always create a new block scope, even if there + * are no 'if' statements. This is useful for cases where we want to + * declare variables in the clauses without worrying if they conflict with + * other variables of the same name. + */ + VersionConditional alwaysEmitBlockScope(boolean alwaysEmitBlockScope) { + this.alwaysEmitBlockScope = alwaysEmitBlockScope; + return this; + } + + /** + * If this is set, VersionConditional#generate will throw an exception if + * the 'ifMember' clause is never used. This is useful as a sanity check + * in some cases where it doesn't make sense for the condition to always be + * false. For example, when generating a Message#write function, + * we might check that the version we're writing is supported. It wouldn't + * make sense for this check to always be false, since that would mean that + * no versions at all were supported. 
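+ * This flag only affects the always-false case; membership checks that are always
+ * true are always permitted.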
+ */ + VersionConditional allowMembershipCheckAlwaysFalse(boolean allowMembershipCheckAlwaysFalse) { + this.allowMembershipCheckAlwaysFalse = allowMembershipCheckAlwaysFalse; + return this; + } + + private void generateFullRangeCheck(Versions ifVersions, + Versions ifNotVersions, + CodeBuffer buffer) { + if (ifMember != null) { + buffer.printf("if ((_version >= %d) && (_version <= %d)) {%n", + containingVersions.lowest(), containingVersions.highest()); + buffer.incrementIndent(); + ifMember.generate(ifVersions); + buffer.decrementIndent(); + if (ifNotMember != null) { + buffer.printf("} else {%n"); + buffer.incrementIndent(); + ifNotMember.generate(ifNotVersions); + buffer.decrementIndent(); + } + buffer.printf("}%n"); + } else if (ifNotMember != null) { + buffer.printf("if ((_version < %d) || (_version > %d)) {%n", + containingVersions.lowest(), containingVersions.highest()); + buffer.incrementIndent(); + ifNotMember.generate(ifNotVersions); + buffer.decrementIndent(); + buffer.printf("}%n"); + } + } + + private void generateLowerRangeCheck(Versions ifVersions, + Versions ifNotVersions, + CodeBuffer buffer) { + if (ifMember != null) { + buffer.printf("if (_version >= %d) {%n", containingVersions.lowest()); + buffer.incrementIndent(); + ifMember.generate(ifVersions); + buffer.decrementIndent(); + if (ifNotMember != null) { + buffer.printf("} else {%n"); + buffer.incrementIndent(); + ifNotMember.generate(ifNotVersions); + buffer.decrementIndent(); + } + buffer.printf("}%n"); + } else if (ifNotMember != null) { + buffer.printf("if (_version < %d) {%n", containingVersions.lowest()); + buffer.incrementIndent(); + ifNotMember.generate(ifNotVersions); + buffer.decrementIndent(); + buffer.printf("}%n"); + } + } + + private void generateUpperRangeCheck(Versions ifVersions, + Versions ifNotVersions, + CodeBuffer buffer) { + if (ifMember != null) { + buffer.printf("if (_version <= %d) {%n", containingVersions.highest()); + buffer.incrementIndent(); + ifMember.generate(ifVersions); + buffer.decrementIndent(); + if (ifNotMember != null) { + buffer.printf("} else {%n"); + buffer.incrementIndent(); + ifNotMember.generate(ifNotVersions); + buffer.decrementIndent(); + } + buffer.printf("}%n"); + } else if (ifNotMember != null) { + buffer.printf("if (_version > %d) {%n", containingVersions.highest()); + buffer.incrementIndent(); + ifNotMember.generate(ifNotVersions); + buffer.decrementIndent(); + buffer.printf("}%n"); + } + } + + private void generateAlwaysTrueCheck(Versions ifVersions, CodeBuffer buffer) { + if (ifMember != null) { + if (alwaysEmitBlockScope) { + buffer.printf("{%n"); + buffer.incrementIndent(); + } + ifMember.generate(ifVersions); + if (alwaysEmitBlockScope) { + buffer.decrementIndent(); + buffer.printf("}%n"); + } + } + } + + private void generateAlwaysFalseCheck(Versions ifNotVersions, CodeBuffer buffer) { + if (!allowMembershipCheckAlwaysFalse) { + throw new RuntimeException("Version ranges " + containingVersions + + " and " + possibleVersions + " have no versions in common."); + } + if (ifNotMember != null) { + if (alwaysEmitBlockScope) { + buffer.printf("{%n"); + buffer.incrementIndent(); + } + ifNotMember.generate(ifNotVersions); + if (alwaysEmitBlockScope) { + buffer.decrementIndent(); + buffer.printf("}%n"); + } + } + } + + void generate(CodeBuffer buffer) { + Versions ifVersions = possibleVersions.intersect(containingVersions); + Versions ifNotVersions = possibleVersions.subtract(containingVersions); + // In the case where ifNotVersions would be two ranges rather than one, + // 
we just pass in the original possibleVersions instead. + // This is slightly less optimal, but allows us to avoid dealing with + // multiple ranges. + if (ifNotVersions == null) { + ifNotVersions = possibleVersions; + } + + if (possibleVersions.lowest() < containingVersions.lowest()) { + if (possibleVersions.highest() > containingVersions.highest()) { + generateFullRangeCheck(ifVersions, ifNotVersions, buffer); + } else if (possibleVersions.highest() >= containingVersions.lowest()) { + generateLowerRangeCheck(ifVersions, ifNotVersions, buffer); + } else { + generateAlwaysFalseCheck(ifNotVersions, buffer); + } + } else if (possibleVersions.highest() >= containingVersions.lowest() && + (possibleVersions.lowest() <= containingVersions.highest())) { + if (possibleVersions.highest() > containingVersions.highest()) { + generateUpperRangeCheck(ifVersions, ifNotVersions, buffer); + } else { + generateAlwaysTrueCheck(ifVersions, buffer); + } + } else { + generateAlwaysFalseCheck(ifNotVersions, buffer); + } + } +} diff --git a/generator/src/main/java/org/apache/kafka/message/Versions.java b/generator/src/main/java/org/apache/kafka/message/Versions.java new file mode 100644 index 0000000..649c065 --- /dev/null +++ b/generator/src/main/java/org/apache/kafka/message/Versions.java @@ -0,0 +1,199 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.message; + +import java.util.Objects; + +/** + * A version range. + * + * A range consists of two 16-bit numbers: the lowest version which is accepted, and the highest. + * Ranges are inclusive, meaning that both the lowest and the highest version are valid versions. + * The only exception to this is the NONE range, which contains no versions at all. + * + * Version ranges can be represented as strings. + * + * A single supported version V is represented as "V". + * A bounded range from A to B is represented as "A-B". + * All versions greater than A is represented as "A+". + * The NONE range is represented as an the string "none". 
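To make the string forms described above concrete, a short sketch using the public methods of this class (parse, toString, empty, intersect, subtract); the expected results in the comments follow the parsing and range-arithmetic rules defined below.

    import org.apache.kafka.message.Versions;

    public class VersionsExample {
        public static void main(String[] args) {
            Versions single  = Versions.parse("3", null);     // toString() -> "3"
            Versions bounded = Versions.parse("1-4", null);   // toString() -> "1-4"
            Versions open    = Versions.parse("2+", null);    // toString() -> "2+"
            Versions none    = Versions.parse("none", null);  // toString() -> "none"

            System.out.println(none.empty());             // true  (the NONE range has no versions)
            System.out.println(bounded.intersect(open));  // 2-4
            System.out.println(bounded.subtract(open));   // 1     (trim the upper part of 1-4)
            System.out.println(open.subtract(bounded));   // 5+    (trim the lower part of 2+)
            System.out.println(bounded.subtract(single)); // null  (the result, 1-2 and 4, would be two ranges)
        }
    }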
+ */ +public final class Versions { + private final short lowest; + private final short highest; + + public static Versions parse(String input, Versions defaultVersions) { + if (input == null) { + return defaultVersions; + } + String trimmedInput = input.trim(); + if (trimmedInput.length() == 0) { + return defaultVersions; + } + if (trimmedInput.equals(NONE_STRING)) { + return NONE; + } + if (trimmedInput.endsWith("+")) { + return new Versions(Short.parseShort( + trimmedInput.substring(0, trimmedInput.length() - 1)), + Short.MAX_VALUE); + } else { + int dashIndex = trimmedInput.indexOf("-"); + if (dashIndex < 0) { + short version = Short.parseShort(trimmedInput); + return new Versions(version, version); + } + return new Versions( + Short.parseShort(trimmedInput.substring(0, dashIndex)), + Short.parseShort(trimmedInput.substring(dashIndex + 1))); + } + } + + public static final Versions ALL = new Versions((short) 0, Short.MAX_VALUE); + + public static final Versions NONE = new Versions(); + + public static final String NONE_STRING = "none"; + + private Versions() { + this.lowest = 0; + this.highest = -1; + } + + public Versions(short lowest, short highest) { + if ((lowest < 0) || (highest < 0)) { + throw new RuntimeException("Invalid version range " + + lowest + " to " + highest); + } + this.lowest = lowest; + this.highest = highest; + } + + public short lowest() { + return lowest; + } + + public short highest() { + return highest; + } + + public boolean empty() { + return lowest > highest; + } + + @Override + public String toString() { + if (empty()) { + return NONE_STRING; + } else if (lowest == highest) { + return String.valueOf(lowest); + } else if (highest == Short.MAX_VALUE) { + return String.format("%d+", lowest); + } else { + return String.format("%d-%d", lowest, highest); + } + } + + /** + * Return the intersection of two version ranges. + * + * @param other The other version range. + * @return A new version range. + */ + public Versions intersect(Versions other) { + short newLowest = lowest > other.lowest ? lowest : other.lowest; + short newHighest = highest < other.highest ? highest : other.highest; + if (newLowest > newHighest) { + return Versions.NONE; + } + return new Versions(newLowest, newHighest); + } + + /** + * Return a new version range that trims some versions from this range, if possible. + * We can't trim any versions if the resulting range would be disjoint. + * + * Some examples: + * 1-4.trim(1-2) = 3-4 + * 3+.trim(4+) = 3 + * 4+.trim(3+) = none + * 1-5.trim(2-4) = null + * + * @param other The other version range. + * @return A new version range. + */ + public Versions subtract(Versions other) { + if (other.lowest() <= lowest) { + if (other.highest >= highest) { + // Case 1: other is a superset of this. Trim everything. + return Versions.NONE; + } else if (other.highest < lowest) { + // Case 2: other is a disjoint version range that is lower than this. Trim nothing. + return this; + } else { + // Case 3: trim some values from the beginning of this range. + // + // Note: it is safe to assume that other.highest() + 1 will not overflow. + // The reason is because if other.highest() were Short.MAX_VALUE, + // other.highest() < highest could not be true. + return new Versions((short) (other.highest() + 1), highest); + } + } else if (other.highest >= highest) { + int newHighest = other.lowest - 1; + if (newHighest < 0) { + // Case 4: other was NONE. Trim nothing. + return this; + } else if (newHighest < highest) { + // Case 5: trim some values from the end of this range. 
+ return new Versions(lowest, (short) newHighest); + } else { + // Case 6: other is a disjoint range that is higher than this. Trim nothing. + return this; + } + } else { + // Case 7: the difference between this and other would be two ranges, not one. + return null; + } + } + + public boolean contains(short version) { + return version >= lowest && version <= highest; + } + + public boolean contains(Versions other) { + if (other.empty()) { + return true; + } + return !((lowest > other.lowest) || (highest < other.highest)); + } + + @Override + public int hashCode() { + return Objects.hash(lowest, highest); + } + + @Override + public boolean equals(Object other) { + if (!(other instanceof Versions)) { + return false; + } + Versions otherVersions = (Versions) other; + return lowest == otherVersions.lowest && + highest == otherVersions.highest; + } +} diff --git a/generator/src/test/java/org/apache/kafka/message/CodeBufferTest.java b/generator/src/test/java/org/apache/kafka/message/CodeBufferTest.java new file mode 100644 index 0000000..6dd2e64 --- /dev/null +++ b/generator/src/test/java/org/apache/kafka/message/CodeBufferTest.java @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.message; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; + +import java.io.StringWriter; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.assertThrows; + +@Timeout(120) +public class CodeBufferTest { + + @Test + public void testWrite() throws Exception { + CodeBuffer buffer = new CodeBuffer(); + buffer.printf("public static void main(String[] args) throws Exception {%n"); + buffer.incrementIndent(); + buffer.printf("System.out.println(\"%s\");%n", "hello world"); + buffer.decrementIndent(); + buffer.printf("}%n"); + StringWriter stringWriter = new StringWriter(); + buffer.write(stringWriter); + assertEquals( + stringWriter.toString(), + String.format("public static void main(String[] args) throws Exception {%n") + + String.format(" System.out.println(\"hello world\");%n") + + String.format("}%n")); + } + + @Test + public void testEquals() { + CodeBuffer buffer1 = new CodeBuffer(); + CodeBuffer buffer2 = new CodeBuffer(); + assertEquals(buffer1, buffer2); + buffer1.printf("hello world"); + assertNotEquals(buffer1, buffer2); + buffer2.printf("hello world"); + assertEquals(buffer1, buffer2); + buffer1.printf("foo, bar, and baz"); + buffer2.printf("foo, bar, and baz"); + assertEquals(buffer1, buffer2); + } + + @Test + public void testIndentMustBeNonNegative() { + CodeBuffer buffer = new CodeBuffer(); + buffer.incrementIndent(); + buffer.decrementIndent(); + RuntimeException e = assertThrows(RuntimeException.class, buffer::decrementIndent); + assertTrue(e.getMessage().contains("Indent < 0")); + } +} diff --git a/generator/src/test/java/org/apache/kafka/message/EntityTypeTest.java b/generator/src/test/java/org/apache/kafka/message/EntityTypeTest.java new file mode 100644 index 0000000..0b4bc1f --- /dev/null +++ b/generator/src/test/java/org/apache/kafka/message/EntityTypeTest.java @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.message; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; + +import static org.junit.jupiter.api.Assertions.assertThrows; + +@Timeout(120) +public class EntityTypeTest { + + @Test + public void testUnknownEntityType() { + for (FieldType type : new FieldType[] { + FieldType.StringFieldType.INSTANCE, + FieldType.Int8FieldType.INSTANCE, + FieldType.Int16FieldType.INSTANCE, + FieldType.Int32FieldType.INSTANCE, + FieldType.Int64FieldType.INSTANCE, + new FieldType.ArrayType(FieldType.StringFieldType.INSTANCE)}) { + EntityType.UNKNOWN.verifyTypeMatches("unknown", type); + } + } + + @Test + public void testVerifyTypeMatches() { + EntityType.TRANSACTIONAL_ID.verifyTypeMatches("transactionalIdField", + FieldType.StringFieldType.INSTANCE); + EntityType.TRANSACTIONAL_ID.verifyTypeMatches("transactionalIdField", + new FieldType.ArrayType(FieldType.StringFieldType.INSTANCE)); + EntityType.PRODUCER_ID.verifyTypeMatches("producerIdField", + FieldType.Int64FieldType.INSTANCE); + EntityType.PRODUCER_ID.verifyTypeMatches("producerIdField", + new FieldType.ArrayType(FieldType.Int64FieldType.INSTANCE)); + EntityType.GROUP_ID.verifyTypeMatches("groupIdField", + FieldType.StringFieldType.INSTANCE); + EntityType.GROUP_ID.verifyTypeMatches("groupIdField", + new FieldType.ArrayType(FieldType.StringFieldType.INSTANCE)); + EntityType.TOPIC_NAME.verifyTypeMatches("topicNameField", + FieldType.StringFieldType.INSTANCE); + EntityType.TOPIC_NAME.verifyTypeMatches("topicNameField", + new FieldType.ArrayType(FieldType.StringFieldType.INSTANCE)); + EntityType.BROKER_ID.verifyTypeMatches("brokerIdField", + FieldType.Int32FieldType.INSTANCE); + EntityType.BROKER_ID.verifyTypeMatches("brokerIdField", + new FieldType.ArrayType(FieldType.Int32FieldType.INSTANCE)); + } + + private static void expectException(Runnable r) { + assertThrows(RuntimeException.class, r::run); + } + + @Test + public void testVerifyTypeMismatches() { + expectException(() -> EntityType.TRANSACTIONAL_ID. + verifyTypeMatches("transactionalIdField", FieldType.Int32FieldType.INSTANCE)); + expectException(() -> EntityType.PRODUCER_ID. + verifyTypeMatches("producerIdField", FieldType.StringFieldType.INSTANCE)); + expectException(() -> EntityType.GROUP_ID. + verifyTypeMatches("groupIdField", FieldType.Int8FieldType.INSTANCE)); + expectException(() -> EntityType.TOPIC_NAME. + verifyTypeMatches("topicNameField", + new FieldType.ArrayType(FieldType.Int64FieldType.INSTANCE))); + expectException(() -> EntityType.BROKER_ID. + verifyTypeMatches("brokerIdField", FieldType.Int64FieldType.INSTANCE)); + } +} diff --git a/generator/src/test/java/org/apache/kafka/message/IsNullConditionalTest.java b/generator/src/test/java/org/apache/kafka/message/IsNullConditionalTest.java new file mode 100644 index 0000000..24159c6 --- /dev/null +++ b/generator/src/test/java/org/apache/kafka/message/IsNullConditionalTest.java @@ -0,0 +1,120 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.message; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; + +@Timeout(120) +public class IsNullConditionalTest { + + @Test + public void testNullCheck() throws Exception { + CodeBuffer buffer = new CodeBuffer(); + IsNullConditional. + forName("foobar"). + nullableVersions(Versions.parse("2+", null)). + possibleVersions(Versions.parse("0+", null)). + ifNull(() -> { + buffer.printf("System.out.println(\"null\");%n"); + }). + generate(buffer); + VersionConditionalTest.claimEquals(buffer, + "if (foobar == null) {%n", + " System.out.println(\"null\");%n", + "}%n"); + } + + @Test + public void testAnotherNullCheck() throws Exception { + CodeBuffer buffer = new CodeBuffer(); + IsNullConditional. + forName("foobar"). + nullableVersions(Versions.parse("0+", null)). + possibleVersions(Versions.parse("2+", null)). + ifNull(() -> { + buffer.printf("System.out.println(\"null\");%n"); + }). + ifShouldNotBeNull(() -> { + buffer.printf("System.out.println(\"not null\");%n"); + }). + generate(buffer); + VersionConditionalTest.claimEquals(buffer, + "if (foobar == null) {%n", + " System.out.println(\"null\");%n", + "} else {%n", + " System.out.println(\"not null\");%n", + "}%n"); + } + + @Test + public void testNotNullCheck() throws Exception { + CodeBuffer buffer = new CodeBuffer(); + IsNullConditional. + forName("foobar"). + nullableVersions(Versions.parse("0+", null)). + possibleVersions(Versions.parse("2+", null)). + ifShouldNotBeNull(() -> { + buffer.printf("System.out.println(\"not null\");%n"); + }). + generate(buffer); + VersionConditionalTest.claimEquals(buffer, + "if (foobar != null) {%n", + " System.out.println(\"not null\");%n", + "}%n"); + } + + @Test + public void testNeverNull() throws Exception { + CodeBuffer buffer = new CodeBuffer(); + IsNullConditional. + forName("baz"). + nullableVersions(Versions.parse("0-2", null)). + possibleVersions(Versions.parse("3+", null)). + ifNull(() -> { + buffer.printf("System.out.println(\"null\");%n"); + }). + ifShouldNotBeNull(() -> { + buffer.printf("System.out.println(\"not null\");%n"); + }). + generate(buffer); + VersionConditionalTest.claimEquals(buffer, + "System.out.println(\"not null\");%n"); + } + + @Test + public void testNeverNullWithBlockScope() throws Exception { + CodeBuffer buffer = new CodeBuffer(); + IsNullConditional. + forName("baz"). + nullableVersions(Versions.parse("0-2", null)). + possibleVersions(Versions.parse("3+", null)). + ifNull(() -> { + buffer.printf("System.out.println(\"null\");%n"); + }). + ifShouldNotBeNull(() -> { + buffer.printf("System.out.println(\"not null\");%n"); + }). + alwaysEmitBlockScope(true). 
+ generate(buffer); + VersionConditionalTest.claimEquals(buffer, + "{%n", + " System.out.println(\"not null\");%n", + "}%n"); + } +} diff --git a/generator/src/test/java/org/apache/kafka/message/MessageDataGeneratorTest.java b/generator/src/test/java/org/apache/kafka/message/MessageDataGeneratorTest.java new file mode 100644 index 0000000..a51aacc --- /dev/null +++ b/generator/src/test/java/org/apache/kafka/message/MessageDataGeneratorTest.java @@ -0,0 +1,261 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.message; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; + +import java.util.Arrays; + +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; + +@Timeout(120) +public class MessageDataGeneratorTest { + + @Test + public void testNullDefaults() throws Exception { + MessageSpec testMessageSpec = MessageGenerator.JSON_SERDE.readValue(String.join("", Arrays.asList( + "{", + " \"type\": \"request\",", + " \"name\": \"FooBar\",", + " \"validVersions\": \"0-2\",", + " \"flexibleVersions\": \"none\",", + " \"fields\": [", + " { \"name\": \"field1\", \"type\": \"int32\", \"versions\": \"0+\" },", + " { \"name\": \"field2\", \"type\": \"[]TestStruct\", \"versions\": \"1+\", ", + " \"nullableVersions\": \"1+\", \"default\": \"null\", \"fields\": [", + " { \"name\": \"field1\", \"type\": \"int32\", \"versions\": \"0+\" }", + " ]},", + " { \"name\": \"field3\", \"type\": \"bytes\", \"versions\": \"2+\", ", + " \"nullableVersions\": \"2+\", \"default\": \"null\" }", + " ]", + "}")), MessageSpec.class); + new MessageDataGenerator("org.apache.kafka.common.message").generate(testMessageSpec); + } + + private void assertStringContains(String substring, String value) { + assertTrue(value.contains(substring), + "Expected string to contain '" + substring + "', but it was " + value); + } + + @Test + public void testInvalidNullDefaultForInt() throws Exception { + MessageSpec testMessageSpec = MessageGenerator.JSON_SERDE.readValue(String.join("", Arrays.asList( + "{", + " \"type\": \"request\",", + " \"name\": \"FooBar\",", + " \"validVersions\": \"0-2\",", + " \"flexibleVersions\": \"none\",", + " \"fields\": [", + " { \"name\": \"field1\", \"type\": \"int32\", \"versions\": \"0+\", \"default\": \"null\" }", + " ]", + "}")), MessageSpec.class); + assertStringContains("Invalid default for int32", + assertThrows(RuntimeException.class, () -> { + new MessageDataGenerator("org.apache.kafka.common.message").generate(testMessageSpec); + }).getMessage()); + } + + @Test + public void testInvalidNullDefaultForPotentiallyNonNullableArray() throws Exception { + MessageSpec testMessageSpec = 
MessageGenerator.JSON_SERDE.readValue(String.join("", Arrays.asList( + "{", + " \"type\": \"request\",", + " \"name\": \"FooBar\",", + " \"validVersions\": \"0-2\",", + " \"flexibleVersions\": \"none\",", + " \"fields\": [", + " { \"name\": \"field1\", \"type\": \"[]int32\", \"versions\": \"0+\", \"nullableVersions\": \"1+\", ", + " \"default\": \"null\" }", + " ]", + "}")), MessageSpec.class); + + assertStringContains("not all versions of this field are nullable", + assertThrows(RuntimeException.class, () -> { + new MessageDataGenerator("org.apache.kafka.common.message").generate(testMessageSpec); + }).getMessage()); + } + + /** + * Test attempting to create a field with an invalid name. The name is + * invalid because it starts with an underscore. + */ + @Test + public void testInvalidFieldName() { + assertStringContains("Invalid field name", + assertThrows(Throwable.class, () -> { + MessageGenerator.JSON_SERDE.readValue(String.join("", Arrays.asList( + "{", + " \"type\": \"request\",", + " \"name\": \"FooBar\",", + " \"validVersions\": \"0-2\",", + " \"flexibleVersions\": \"0+\",", + " \"fields\": [", + " { \"name\": \"_badName\", \"type\": \"[]int32\", \"versions\": \"0+\" }", + " ]", + "}")), MessageSpec.class); + }).getMessage()); + } + + @Test + public void testInvalidTagWithoutTaggedVersions() { + assertStringContains("If a tag is specified, taggedVersions must be specified as well.", + assertThrows(Throwable.class, () -> { + MessageGenerator.JSON_SERDE.readValue(String.join("", Arrays.asList( + "{", + " \"type\": \"request\",", + " \"name\": \"FooBar\",", + " \"validVersions\": \"0-2\",", + " \"flexibleVersions\": \"0+\",", + " \"fields\": [", + " { \"name\": \"field1\", \"type\": \"int32\", \"versions\": \"0+\", \"tag\": 0 }", + " ]", + "}")), MessageSpec.class); + fail("Expected the MessageSpec constructor to fail"); + }).getMessage()); + } + + @Test + public void testInvalidNegativeTag() { + assertStringContains("Tags cannot be negative", + assertThrows(Throwable.class, () -> { + MessageGenerator.JSON_SERDE.readValue(String.join("", Arrays.asList( + "{", + " \"type\": \"request\",", + " \"name\": \"FooBar\",", + " \"validVersions\": \"0-2\",", + " \"flexibleVersions\": \"0+\",", + " \"fields\": [", + " { \"name\": \"field1\", \"type\": \"int32\", \"versions\": \"0+\", ", + " \"tag\": -1, \"taggedVersions\": \"0+\" }", + " ]", + "}")), MessageSpec.class); + }).getMessage()); + } + + @Test + public void testInvalidFlexibleVersionsRange() { + assertStringContains("flexibleVersions must be either none, or an open-ended range", + assertThrows(Throwable.class, () -> { + MessageGenerator.JSON_SERDE.readValue(String.join("", Arrays.asList( + "{", + " \"type\": \"request\",", + " \"name\": \"FooBar\",", + " \"validVersions\": \"0-2\",", + " \"flexibleVersions\": \"0-2\",", + " \"fields\": [", + " { \"name\": \"field1\", \"type\": \"int32\", \"versions\": \"0+\" }", + " ]", + "}")), MessageSpec.class); + }).getMessage()); + } + + @Test + public void testInvalidSometimesNullableTaggedField() { + assertStringContains("Either all tagged versions must be nullable, or none must be", + assertThrows(Throwable.class, () -> { + MessageGenerator.JSON_SERDE.readValue(String.join("", Arrays.asList( + "{", + " \"type\": \"request\",", + " \"name\": \"FooBar\",", + " \"validVersions\": \"0-2\",", + " \"flexibleVersions\": \"0+\",", + " \"fields\": [", + " { \"name\": \"field1\", \"type\": \"string\", \"versions\": \"0+\", ", + " \"tag\": 0, \"taggedVersions\": \"0+\", \"nullableVersions\": \"1+\" }", 
+ " ]", + "}")), MessageSpec.class); + }).getMessage()); + } + + @Test + public void testInvalidTaggedVersionsNotASubetOfVersions() { + assertStringContains("taggedVersions must be a subset of versions", + assertThrows(Throwable.class, () -> { + MessageGenerator.JSON_SERDE.readValue(String.join("", Arrays.asList( + "{", + " \"type\": \"request\",", + " \"name\": \"FooBar\",", + " \"validVersions\": \"0-2\",", + " \"flexibleVersions\": \"0+\",", + " \"fields\": [", + " { \"name\": \"field1\", \"type\": \"string\", \"versions\": \"0-2\", ", + " \"tag\": 0, \"taggedVersions\": \"1+\" }", + " ]", + "}")), MessageSpec.class); + }).getMessage()); + } + + @Test + public void testInvalidTaggedVersionsWithoutTag() { + assertStringContains("Please specify a tag, or remove the taggedVersions", + assertThrows(Throwable.class, () -> { + MessageGenerator.JSON_SERDE.readValue(String.join("", Arrays.asList( + "{", + " \"type\": \"request\",", + " \"name\": \"FooBar\",", + " \"validVersions\": \"0-2\",", + " \"flexibleVersions\": \"0+\",", + " \"fields\": [", + " { \"name\": \"field1\", \"type\": \"string\", \"versions\": \"0+\", ", + " \"taggedVersions\": \"1+\" }", + " ]", + "}")), MessageSpec.class); + }).getMessage()); + } + + @Test + public void testInvalidTaggedVersionsRange() { + assertStringContains("taggedVersions must be either none, or an open-ended range", + assertThrows(Throwable.class, () -> { + MessageGenerator.JSON_SERDE.readValue(String.join("", Arrays.asList( + "{", + " \"type\": \"request\",", + " \"name\": \"FooBar\",", + " \"validVersions\": \"0-2\",", + " \"flexibleVersions\": \"0+\",", + " \"fields\": [", + " { \"name\": \"field1\", \"type\": \"string\", \"versions\": \"0+\", ", + " \"tag\": 0, \"taggedVersions\": \"1-2\" }", + " ]", + "}")), MessageSpec.class); + }).getMessage()); + } + + @Test + public void testDuplicateTags() { + assertStringContains("duplicate tag", + assertThrows(Throwable.class, () -> { + MessageGenerator.JSON_SERDE.readValue(String.join("", Arrays.asList( + "{", + " \"type\": \"request\",", + " \"name\": \"FooBar\",", + " \"validVersions\": \"0-2\",", + " \"flexibleVersions\": \"0+\",", + " \"fields\": [", + " { \"name\": \"field1\", \"type\": \"string\", \"versions\": \"0+\", ", + " \"tag\": 0, \"taggedVersions\": \"0+\" },", + " { \"name\": \"field2\", \"type\": \"int64\", \"versions\": \"0+\", ", + " \"tag\": 0, \"taggedVersions\": \"0+\" }", + " ]", + "}")), MessageSpec.class); + }).getMessage()); + } +} diff --git a/generator/src/test/java/org/apache/kafka/message/MessageGeneratorTest.java b/generator/src/test/java/org/apache/kafka/message/MessageGeneratorTest.java new file mode 100644 index 0000000..07766f2 --- /dev/null +++ b/generator/src/test/java/org/apache/kafka/message/MessageGeneratorTest.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.message; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; + +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.fail; + +@Timeout(120) +public class MessageGeneratorTest { + + @Test + public void testCapitalizeFirst() throws Exception { + assertEquals("", MessageGenerator.capitalizeFirst("")); + assertEquals("AbC", MessageGenerator.capitalizeFirst("abC")); + } + + @Test + public void testLowerCaseFirst() throws Exception { + assertEquals("", MessageGenerator.lowerCaseFirst("")); + assertEquals("fORTRAN", MessageGenerator.lowerCaseFirst("FORTRAN")); + assertEquals("java", MessageGenerator.lowerCaseFirst("java")); + } + + @Test + public void testFirstIsCapitalized() throws Exception { + assertFalse(MessageGenerator.firstIsCapitalized("")); + assertTrue(MessageGenerator.firstIsCapitalized("FORTRAN")); + assertFalse(MessageGenerator.firstIsCapitalized("java")); + } + + @Test + public void testToSnakeCase() throws Exception { + assertEquals("", MessageGenerator.toSnakeCase("")); + assertEquals("foo_bar_baz", MessageGenerator.toSnakeCase("FooBarBaz")); + assertEquals("foo_bar_baz", MessageGenerator.toSnakeCase("fooBarBaz")); + assertEquals("fortran", MessageGenerator.toSnakeCase("FORTRAN")); + } + + @Test + public void stripSuffixTest() throws Exception { + assertEquals("FooBa", MessageGenerator.stripSuffix("FooBar", "r")); + assertEquals("", MessageGenerator.stripSuffix("FooBar", "FooBar")); + assertEquals("Foo", MessageGenerator.stripSuffix("FooBar", "Bar")); + try { + MessageGenerator.stripSuffix("FooBar", "Baz"); + fail("expected exception"); + } catch (RuntimeException e) { + } + } +} diff --git a/generator/src/test/java/org/apache/kafka/message/StructRegistryTest.java b/generator/src/test/java/org/apache/kafka/message/StructRegistryTest.java new file mode 100644 index 0000000..478a72a --- /dev/null +++ b/generator/src/test/java/org/apache/kafka/message/StructRegistryTest.java @@ -0,0 +1,152 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.message; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; + +import java.util.Arrays; +import java.util.Collections; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; + +@Timeout(120) +public class StructRegistryTest { + + @Test + public void testCommonStructs() throws Exception { + MessageSpec testMessageSpec = MessageGenerator.JSON_SERDE.readValue(String.join("", Arrays.asList( + "{", + " \"type\": \"request\",", + " \"name\": \"LeaderAndIsrRequest\",", + " \"validVersions\": \"0-2\",", + " \"flexibleVersions\": \"0+\",", + " \"fields\": [", + " { \"name\": \"field1\", \"type\": \"int32\", \"versions\": \"0+\" },", + " { \"name\": \"field2\", \"type\": \"[]TestCommonStruct\", \"versions\": \"1+\" },", + " { \"name\": \"field3\", \"type\": \"[]TestInlineStruct\", \"versions\": \"0+\", ", + " \"fields\": [", + " { \"name\": \"inlineField1\", \"type\": \"int64\", \"versions\": \"0+\" }", + " ]}", + " ],", + " \"commonStructs\": [", + " { \"name\": \"TestCommonStruct\", \"versions\": \"0+\", \"fields\": [", + " { \"name\": \"commonField1\", \"type\": \"int64\", \"versions\": \"0+\" }", + " ]}", + " ]", + "}")), MessageSpec.class); + StructRegistry structRegistry = new StructRegistry(); + structRegistry.register(testMessageSpec); + assertEquals(structRegistry.commonStructNames(), Collections.singleton("TestCommonStruct")); + assertFalse(structRegistry.isStructArrayWithKeys(testMessageSpec.fields().get(1))); + assertFalse(structRegistry.isStructArrayWithKeys(testMessageSpec.fields().get(2))); + assertTrue(structRegistry.commonStructs().hasNext()); + assertEquals(structRegistry.commonStructs().next().name(), "TestCommonStruct"); + } + + @Test + public void testReSpecifiedCommonStructError() throws Exception { + MessageSpec testMessageSpec = MessageGenerator.JSON_SERDE.readValue(String.join("", Arrays.asList( + "{", + " \"type\": \"request\",", + " \"name\": \"LeaderAndIsrRequest\",", + " \"validVersions\": \"0-2\",", + " \"flexibleVersions\": \"0+\",", + " \"fields\": [", + " { \"name\": \"field1\", \"type\": \"int32\", \"versions\": \"0+\" },", + " { \"name\": \"field2\", \"type\": \"[]TestCommonStruct\", \"versions\": \"0+\", ", + " \"fields\": [", + " { \"name\": \"inlineField1\", \"type\": \"int64\", \"versions\": \"0+\" }", + " ]}", + " ],", + " \"commonStructs\": [", + " { \"name\": \"TestCommonStruct\", \"versions\": \"0+\", \"fields\": [", + " { \"name\": \"commonField1\", \"type\": \"int64\", \"versions\": \"0+\" }", + " ]}", + " ]", + "}")), MessageSpec.class); + StructRegistry structRegistry = new StructRegistry(); + try { + structRegistry.register(testMessageSpec); + fail("Expected StructRegistry#registry to fail"); + } catch (RuntimeException e) { + assertTrue(e.getMessage().contains("Can't re-specify the common struct TestCommonStruct " + + "as an inline struct.")); + } + } + + @Test + public void testDuplicateCommonStructError() throws Exception { + MessageSpec testMessageSpec = MessageGenerator.JSON_SERDE.readValue(String.join("", Arrays.asList( + "{", + " \"type\": \"request\",", + " \"name\": \"LeaderAndIsrRequest\",", + " \"validVersions\": \"0-2\",", + " \"flexibleVersions\": \"0+\",", + " \"fields\": [", + " { \"name\": \"field1\", \"type\": \"int32\", \"versions\": \"0+\" }", + " ],", + " \"commonStructs\": [", + " { \"name\": 
\"TestCommonStruct\", \"versions\": \"0+\", \"fields\": [", + " { \"name\": \"commonField1\", \"type\": \"int64\", \"versions\": \"0+\" }", + " ]},", + " { \"name\": \"TestCommonStruct\", \"versions\": \"0+\", \"fields\": [", + " { \"name\": \"commonField1\", \"type\": \"int64\", \"versions\": \"0+\" }", + " ]}", + " ]", + "}")), MessageSpec.class); + StructRegistry structRegistry = new StructRegistry(); + try { + structRegistry.register(testMessageSpec); + fail("Expected StructRegistry#registry to fail"); + } catch (RuntimeException e) { + assertTrue(e.getMessage().contains("Common struct TestCommonStruct was specified twice.")); + } + } + + @Test + public void testSingleStruct() throws Exception { + MessageSpec testMessageSpec = MessageGenerator.JSON_SERDE.readValue(String.join("", Arrays.asList( + "{", + " \"type\": \"request\",", + " \"name\": \"LeaderAndIsrRequest\",", + " \"validVersions\": \"0-2\",", + " \"flexibleVersions\": \"0+\",", + " \"fields\": [", + " { \"name\": \"field1\", \"type\": \"int32\", \"versions\": \"0+\" },", + " { \"name\": \"field2\", \"type\": \"TestInlineStruct\", \"versions\": \"0+\", ", + " \"fields\": [", + " { \"name\": \"inlineField1\", \"type\": \"int64\", \"versions\": \"0+\" }", + " ]}", + " ]", + "}")), MessageSpec.class); + StructRegistry structRegistry = new StructRegistry(); + structRegistry.register(testMessageSpec); + + FieldSpec field2 = testMessageSpec.fields().get(1); + assertTrue(field2.type().isStruct()); + assertEquals(field2.type().toString(), "TestInlineStruct"); + assertEquals(field2.name(), "field2"); + + assertEquals(structRegistry.findStruct(field2).name(), "TestInlineStruct"); + assertFalse(structRegistry.isStructArrayWithKeys(field2)); + } +} diff --git a/generator/src/test/java/org/apache/kafka/message/VersionConditionalTest.java b/generator/src/test/java/org/apache/kafka/message/VersionConditionalTest.java new file mode 100644 index 0000000..9d7cdce --- /dev/null +++ b/generator/src/test/java/org/apache/kafka/message/VersionConditionalTest.java @@ -0,0 +1,243 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.message; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; + +import java.io.StringWriter; + +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.assertEquals; + +@Timeout(120) +public class VersionConditionalTest { + + static void claimEquals(CodeBuffer buffer, String... 
lines) throws Exception { + StringWriter stringWriter = new StringWriter(); + buffer.write(stringWriter); + StringBuilder expectedStringBuilder = new StringBuilder(); + for (String line : lines) { + expectedStringBuilder.append(String.format(line)); + } + assertEquals(stringWriter.toString(), expectedStringBuilder.toString()); + } + + @Test + public void testAlwaysFalseConditional() throws Exception { + CodeBuffer buffer = new CodeBuffer(); + VersionConditional. + forVersions(Versions.parse("1-2", null), Versions.parse("3+", null)). + ifMember(__ -> { + buffer.printf("System.out.println(\"hello world\");%n"); + }). + ifNotMember(__ -> { + buffer.printf("System.out.println(\"foobar\");%n"); + }). + generate(buffer); + claimEquals(buffer, + "System.out.println(\"foobar\");%n"); + } + + @Test + public void testAnotherAlwaysFalseConditional() throws Exception { + CodeBuffer buffer = new CodeBuffer(); + VersionConditional. + forVersions(Versions.parse("3+", null), Versions.parse("1-2", null)). + ifMember(__ -> { + buffer.printf("System.out.println(\"hello world\");%n"); + }). + ifNotMember(__ -> { + buffer.printf("System.out.println(\"foobar\");%n"); + }). + generate(buffer); + claimEquals(buffer, + "System.out.println(\"foobar\");%n"); + } + + @Test + public void testAllowMembershipCheckAlwaysFalseFails() throws Exception { + try { + CodeBuffer buffer = new CodeBuffer(); + VersionConditional. + forVersions(Versions.parse("1-2", null), Versions.parse("3+", null)). + ifMember(__ -> { + buffer.printf("System.out.println(\"hello world\");%n"); + }). + ifNotMember(__ -> { + buffer.printf("System.out.println(\"foobar\");%n"); + }). + allowMembershipCheckAlwaysFalse(false). + generate(buffer); + } catch (RuntimeException e) { + assertTrue(e.getMessage().contains("no versions in common")); + } + } + + @Test + public void testAlwaysTrueConditional() throws Exception { + CodeBuffer buffer = new CodeBuffer(); + VersionConditional. + forVersions(Versions.parse("1-5", null), Versions.parse("2-4", null)). + ifMember(__ -> { + buffer.printf("System.out.println(\"hello world\");%n"); + }). + ifNotMember(__ -> { + buffer.printf("System.out.println(\"foobar\");%n"); + }). + allowMembershipCheckAlwaysFalse(false). + generate(buffer); + claimEquals(buffer, + "System.out.println(\"hello world\");%n"); + } + + @Test + public void testAlwaysTrueConditionalWithAlwaysEmitBlockScope() throws Exception { + CodeBuffer buffer = new CodeBuffer(); + VersionConditional. + forVersions(Versions.parse("1-5", null), Versions.parse("2-4", null)). + ifMember(__ -> { + buffer.printf("System.out.println(\"hello world\");%n"); + }). + ifNotMember(__ -> { + buffer.printf("System.out.println(\"foobar\");%n"); + }). + alwaysEmitBlockScope(true). + generate(buffer); + claimEquals(buffer, + "{%n", + " System.out.println(\"hello world\");%n", + "}%n"); + } + + @Test + public void testLowerRangeCheckWithElse() throws Exception { + CodeBuffer buffer = new CodeBuffer(); + VersionConditional. + forVersions(Versions.parse("1+", null), Versions.parse("0-100", null)). + ifMember(__ -> { + buffer.printf("System.out.println(\"hello world\");%n"); + }). + ifNotMember(__ -> { + buffer.printf("System.out.println(\"foobar\");%n"); + }). 
+ generate(buffer); + claimEquals(buffer, + "if (_version >= 1) {%n", + " System.out.println(\"hello world\");%n", + "} else {%n", + " System.out.println(\"foobar\");%n", + "}%n"); + } + + @Test + public void testLowerRangeCheckWithIfMember() throws Exception { + CodeBuffer buffer = new CodeBuffer(); + VersionConditional. + forVersions(Versions.parse("1+", null), Versions.parse("0-100", null)). + ifMember(__ -> { + buffer.printf("System.out.println(\"hello world\");%n"); + }). + generate(buffer); + claimEquals(buffer, + "if (_version >= 1) {%n", + " System.out.println(\"hello world\");%n", + "}%n"); + } + + @Test + public void testLowerRangeCheckWithIfNotMember() throws Exception { + CodeBuffer buffer = new CodeBuffer(); + VersionConditional. + forVersions(Versions.parse("1+", null), Versions.parse("0-100", null)). + ifNotMember(__ -> { + buffer.printf("System.out.println(\"hello world\");%n"); + }). + generate(buffer); + claimEquals(buffer, + "if (_version < 1) {%n", + " System.out.println(\"hello world\");%n", + "}%n"); + } + + @Test + public void testUpperRangeCheckWithElse() throws Exception { + CodeBuffer buffer = new CodeBuffer(); + VersionConditional. + forVersions(Versions.parse("0-10", null), Versions.parse("4+", null)). + ifMember(__ -> { + buffer.printf("System.out.println(\"hello world\");%n"); + }). + ifNotMember(__ -> { + buffer.printf("System.out.println(\"foobar\");%n"); + }). + generate(buffer); + claimEquals(buffer, + "if (_version <= 10) {%n", + " System.out.println(\"hello world\");%n", + "} else {%n", + " System.out.println(\"foobar\");%n", + "}%n"); + } + + @Test + public void testUpperRangeCheckWithIfMember() throws Exception { + CodeBuffer buffer = new CodeBuffer(); + VersionConditional. + forVersions(Versions.parse("0-10", null), Versions.parse("4+", null)). + ifMember(__ -> { + buffer.printf("System.out.println(\"hello world\");%n"); + }). + generate(buffer); + claimEquals(buffer, + "if (_version <= 10) {%n", + " System.out.println(\"hello world\");%n", + "}%n"); + } + + @Test + public void testUpperRangeCheckWithIfNotMember() throws Exception { + CodeBuffer buffer = new CodeBuffer(); + VersionConditional. + forVersions(Versions.parse("1+", null), Versions.parse("0-100", null)). + ifNotMember(__ -> { + buffer.printf("System.out.println(\"hello world\");%n"); + }). + generate(buffer); + claimEquals(buffer, + "if (_version < 1) {%n", + " System.out.println(\"hello world\");%n", + "}%n"); + } + + @Test + public void testFullRangeCheck() throws Exception { + CodeBuffer buffer = new CodeBuffer(); + VersionConditional. + forVersions(Versions.parse("5-10", null), Versions.parse("1+", null)). + ifMember(__ -> { + buffer.printf("System.out.println(\"hello world\");%n"); + }). + allowMembershipCheckAlwaysFalse(false). + generate(buffer); + claimEquals(buffer, + "if ((_version >= 5) && (_version <= 10)) {%n", + " System.out.println(\"hello world\");%n", + "}%n"); + } +} diff --git a/generator/src/test/java/org/apache/kafka/message/VersionsTest.java b/generator/src/test/java/org/apache/kafka/message/VersionsTest.java new file mode 100644 index 0000000..23e59af --- /dev/null +++ b/generator/src/test/java/org/apache/kafka/message/VersionsTest.java @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.message; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNull; + +@Timeout(120) +public class VersionsTest { + + private static Versions newVersions(int lower, int higher) { + if ((lower < Short.MIN_VALUE) || (lower > Short.MAX_VALUE)) { + throw new RuntimeException("lower bound out of range."); + } + if ((higher < Short.MIN_VALUE) || (higher > Short.MAX_VALUE)) { + throw new RuntimeException("higher bound out of range."); + } + return new Versions((short) lower, (short) higher); + } + + @Test + public void testVersionsParse() { + assertEquals(Versions.NONE, Versions.parse(null, Versions.NONE)); + assertEquals(Versions.ALL, Versions.parse(" ", Versions.ALL)); + assertEquals(Versions.ALL, Versions.parse("", Versions.ALL)); + assertEquals(newVersions(4, 5), Versions.parse(" 4-5 ", null)); + } + + @Test + public void testRoundTrips() { + testRoundTrip(Versions.ALL, "0+"); + testRoundTrip(newVersions(1, 3), "1-3"); + testRoundTrip(newVersions(2, 2), "2"); + testRoundTrip(newVersions(3, Short.MAX_VALUE), "3+"); + testRoundTrip(Versions.NONE, "none"); + } + + private void testRoundTrip(Versions versions, String string) { + assertEquals(string, versions.toString()); + assertEquals(versions, Versions.parse(versions.toString(), null)); + } + + @Test + public void testIntersections() { + assertEquals(newVersions(2, 3), newVersions(1, 3).intersect( + newVersions(2, 4))); + assertEquals(newVersions(3, 3), newVersions(0, Short.MAX_VALUE).intersect( + newVersions(3, 3))); + assertEquals(Versions.NONE, newVersions(9, Short.MAX_VALUE).intersect( + newVersions(2, 8))); + assertEquals(Versions.NONE, Versions.NONE.intersect(Versions.NONE)); + } + + @Test + public void testContains() { + assertTrue(newVersions(2, 3).contains((short) 3)); + assertTrue(newVersions(2, 3).contains((short) 2)); + assertFalse(newVersions(0, 1).contains((short) 2)); + assertTrue(newVersions(0, Short.MAX_VALUE).contains((short) 100)); + assertFalse(newVersions(2, Short.MAX_VALUE).contains((short) 0)); + assertTrue(newVersions(2, 3).contains(newVersions(2, 3))); + assertTrue(newVersions(2, 3).contains(newVersions(2, 2))); + assertFalse(newVersions(2, 3).contains(newVersions(2, 4))); + assertTrue(newVersions(2, 3).contains(Versions.NONE)); + assertTrue(Versions.ALL.contains(newVersions(1, 2))); + } + + @Test + public void testSubtract() { + assertEquals(Versions.NONE, Versions.NONE.subtract(Versions.NONE)); + assertEquals(newVersions(0, 0), + newVersions(0, 0).subtract(Versions.NONE)); + assertEquals(newVersions(1, 1), + newVersions(1, 2).subtract(newVersions(2, 2))); + assertEquals(newVersions(2, 2), + newVersions(1, 2).subtract(newVersions(1, 1))); + 
assertNull(newVersions(0, Short.MAX_VALUE).subtract(newVersions(1, 100))); + assertEquals(newVersions(10, 10), + newVersions(1, 10).subtract(newVersions(1, 9))); + assertEquals(newVersions(1, 1), + newVersions(1, 10).subtract(newVersions(2, 10))); + assertEquals(newVersions(2, 4), + newVersions(2, Short.MAX_VALUE).subtract(newVersions(5, Short.MAX_VALUE))); + assertEquals(newVersions(5, Short.MAX_VALUE), + newVersions(0, Short.MAX_VALUE).subtract(newVersions(0, 4))); + } +} diff --git a/gradle.properties b/gradle.properties new file mode 100644 index 0000000..85abd73 --- /dev/null +++ b/gradle.properties @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +group=org.apache.kafka +# NOTE: When you change this version number, you should also make sure to update +# the version numbers in +# - docs/js/templateData.js +# - tests/kafkatest/__init__.py +# - tests/kafkatest/version.py (variable DEV_VERSION) +# - kafka-merge-pr.py +version=3.1.0 +scalaVersion=2.13.6 +task=build +org.gradle.jvmargs=-Xmx2g -Xss4m -XX:+UseParallelGC +org.gradle.parallel=true diff --git a/gradle/dependencies.gradle b/gradle/dependencies.gradle new file mode 100644 index 0000000..5bb9d61 --- /dev/null +++ b/gradle/dependencies.gradle @@ -0,0 +1,202 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +ext { + versions = [:] + libs = [:] + + // Available if -PscalaVersion is used. This is useful when we want to support a Scala version that has + // a higher minimum Java requirement than Kafka. This was previously the case for Scala 2.12 and Java 7. 
+ availableScalaVersions = [ '2.12', '2.13' ] +} + +// Add Scala version +def defaultScala212Version = '2.12.14' +def defaultScala213Version = '2.13.6' +if (hasProperty('scalaVersion')) { + if (scalaVersion == '2.12') { + versions["scala"] = defaultScala212Version + } else if (scalaVersion == '2.13') { + versions["scala"] = defaultScala213Version + } else { + versions["scala"] = scalaVersion + } +} else { + versions["scala"] = defaultScala212Version +} + +/* Resolve base Scala version according to these patterns: + 1. generally available Scala versions (such as: 2.12.y and 2.13.z) corresponding base versions will be: 2.12 and 2.13 (respectively) + 2. pre-release Scala versions (i.e. milestone/rc, such as: 2.13.0-M5, 2.13.0-RC1, 2.14.0-M1, etc.) will have identical base versions; + rationale: pre-release Scala versions are not binary compatible with each other and that's the reason why libraries include the full + Scala release string in their name for pre-releases (see dependencies below with an artifact name suffix '_$versions.baseScala') +*/ +if ( !versions.scala.contains('-') ) { + versions["baseScala"] = versions.scala.substring(0, versions.scala.lastIndexOf(".")) +} else { + versions["baseScala"] = versions.scala +} + +versions += [ + activation: "1.1.1", + apacheda: "1.0.2", + apacheds: "2.0.0-M24", + argparse4j: "0.7.0", + bcpkix: "1.66", + checkstyle: "8.36.2", + commonsCli: "1.4", + dropwizardMetrics: "4.1.12.1", + gradle: "7.2", + grgit: "4.1.1", + httpclient: "4.5.13", + easymock: "4.3", + jackson: "2.12.3", + jacoco: "0.8.7", + javassist: "3.27.0-GA", + jetty: "9.4.43.v20210629", + jersey: "2.34", + jline: "3.12.1", + jmh: "1.32", + hamcrest: "2.2", + log4j: "1.2.17", + scalaLogging: "3.9.3", + jaxb: "2.3.0", + jaxrs: "2.1.1", + jfreechart: "1.0.0", + jopt: "5.0.4", + jose4j: "0.7.8", + junit: "5.7.1", + jqwik: "1.5.0", + kafka_0100: "0.10.0.1", + kafka_0101: "0.10.1.1", + kafka_0102: "0.10.2.2", + kafka_0110: "0.11.0.3", + kafka_10: "1.0.2", + kafka_11: "1.1.1", + kafka_20: "2.0.1", + kafka_21: "2.1.1", + kafka_22: "2.2.2", + kafka_23: "2.3.1", + kafka_24: "2.4.1", + kafka_25: "2.5.1", + kafka_26: "2.6.2", + kafka_27: "2.7.1", + kafka_28: "2.8.1", + lz4: "1.8.0", + mavenArtifact: "3.8.1", + metrics: "2.2.0", + mockito: "3.12.4", + netty: "4.1.68.Final", + powermock: "2.0.9", + reflections: "0.9.12", + rocksDB: "6.22.1.1", + scalaCollectionCompat: "2.4.4", + scalafmt: "2.7.5", + scalaJava8Compat : "1.0.0", + scoverage: "1.4.1", + slf4j: "1.7.30", + snappy: "1.1.8.4", + spotbugs: "4.2.2", + zinc: "1.3.5", + zookeeper: "3.6.3", + zstd: "1.5.0-4" +] +libs += [ + activation: "javax.activation:activation:$versions.activation", + apacheda: "org.apache.directory.api:api-all:$versions.apacheda", + apachedsCoreApi: "org.apache.directory.server:apacheds-core-api:$versions.apacheds", + apachedsInterceptorKerberos: "org.apache.directory.server:apacheds-interceptor-kerberos:$versions.apacheds", + apachedsProtocolShared: "org.apache.directory.server:apacheds-protocol-shared:$versions.apacheds", + apachedsProtocolKerberos: "org.apache.directory.server:apacheds-protocol-kerberos:$versions.apacheds", + apachedsProtocolLdap: "org.apache.directory.server:apacheds-protocol-ldap:$versions.apacheds", + apachedsLdifPartition: "org.apache.directory.server:apacheds-ldif-partition:$versions.apacheds", + apachedsMavibotPartition: "org.apache.directory.server:apacheds-mavibot-partition:$versions.apacheds", + apachedsJdbmPartition: "org.apache.directory.server:apacheds-jdbm-partition:$versions.apacheds", + 
argparse4j: "net.sourceforge.argparse4j:argparse4j:$versions.argparse4j", + bcpkix: "org.bouncycastle:bcpkix-jdk15on:$versions.bcpkix", + commonsCli: "commons-cli:commons-cli:$versions.commonsCli", + easymock: "org.easymock:easymock:$versions.easymock", + jacksonAnnotations: "com.fasterxml.jackson.core:jackson-annotations:$versions.jackson", + jacksonDatabind: "com.fasterxml.jackson.core:jackson-databind:$versions.jackson", + jacksonDataformatCsv: "com.fasterxml.jackson.dataformat:jackson-dataformat-csv:$versions.jackson", + jacksonModuleScala: "com.fasterxml.jackson.module:jackson-module-scala_$versions.baseScala:$versions.jackson", + jacksonJDK8Datatypes: "com.fasterxml.jackson.datatype:jackson-datatype-jdk8:$versions.jackson", + jacksonJaxrsJsonProvider: "com.fasterxml.jackson.jaxrs:jackson-jaxrs-json-provider:$versions.jackson", + jaxbApi: "javax.xml.bind:jaxb-api:$versions.jaxb", + jaxrsApi: "javax.ws.rs:javax.ws.rs-api:$versions.jaxrs", + javassist: "org.javassist:javassist:$versions.javassist", + jettyServer: "org.eclipse.jetty:jetty-server:$versions.jetty", + jettyClient: "org.eclipse.jetty:jetty-client:$versions.jetty", + jettyServlet: "org.eclipse.jetty:jetty-servlet:$versions.jetty", + jettyServlets: "org.eclipse.jetty:jetty-servlets:$versions.jetty", + jerseyContainerServlet: "org.glassfish.jersey.containers:jersey-container-servlet:$versions.jersey", + jerseyHk2: "org.glassfish.jersey.inject:jersey-hk2:$versions.jersey", + jline: "org.jline:jline:$versions.jline", + jmhCore: "org.openjdk.jmh:jmh-core:$versions.jmh", + jmhCoreBenchmarks: "org.openjdk.jmh:jmh-core-benchmarks:$versions.jmh", + jmhGeneratorAnnProcess: "org.openjdk.jmh:jmh-generator-annprocess:$versions.jmh", + joptSimple: "net.sf.jopt-simple:jopt-simple:$versions.jopt", + jose4j: "org.bitbucket.b_c:jose4j:$versions.jose4j", + junitJupiter: "org.junit.jupiter:junit-jupiter:$versions.junit", + junitJupiterApi: "org.junit.jupiter:junit-jupiter-api:$versions.junit", + junitVintageEngine: "org.junit.vintage:junit-vintage-engine:$versions.junit", + jqwik: "net.jqwik:jqwik:$versions.jqwik", + hamcrest: "org.hamcrest:hamcrest:$versions.hamcrest", + kafkaStreams_0100: "org.apache.kafka:kafka-streams:$versions.kafka_0100", + kafkaStreams_0101: "org.apache.kafka:kafka-streams:$versions.kafka_0101", + kafkaStreams_0102: "org.apache.kafka:kafka-streams:$versions.kafka_0102", + kafkaStreams_0110: "org.apache.kafka:kafka-streams:$versions.kafka_0110", + kafkaStreams_10: "org.apache.kafka:kafka-streams:$versions.kafka_10", + kafkaStreams_11: "org.apache.kafka:kafka-streams:$versions.kafka_11", + kafkaStreams_20: "org.apache.kafka:kafka-streams:$versions.kafka_20", + kafkaStreams_21: "org.apache.kafka:kafka-streams:$versions.kafka_21", + kafkaStreams_22: "org.apache.kafka:kafka-streams:$versions.kafka_22", + kafkaStreams_23: "org.apache.kafka:kafka-streams:$versions.kafka_23", + kafkaStreams_24: "org.apache.kafka:kafka-streams:$versions.kafka_24", + kafkaStreams_25: "org.apache.kafka:kafka-streams:$versions.kafka_25", + kafkaStreams_26: "org.apache.kafka:kafka-streams:$versions.kafka_26", + kafkaStreams_27: "org.apache.kafka:kafka-streams:$versions.kafka_27", + kafkaStreams_28: "org.apache.kafka:kafka-streams:$versions.kafka_28", + log4j: "log4j:log4j:$versions.log4j", + lz4: "org.lz4:lz4-java:$versions.lz4", + metrics: "com.yammer.metrics:metrics-core:$versions.metrics", + dropwizardMetrics: "io.dropwizard.metrics:metrics-core:$versions.dropwizardMetrics", + mockitoCore: "org.mockito:mockito-core:$versions.mockito", + 
mockitoInline: "org.mockito:mockito-inline:$versions.mockito", + mockitoJunitJupiter: "org.mockito:mockito-junit-jupiter:$versions.mockito", + nettyHandler: "io.netty:netty-handler:$versions.netty", + nettyTransportNativeEpoll: "io.netty:netty-transport-native-epoll:$versions.netty", + powermockJunit4: "org.powermock:powermock-module-junit4:$versions.powermock", + powermockEasymock: "org.powermock:powermock-api-easymock:$versions.powermock", + reflections: "org.reflections:reflections:$versions.reflections", + rocksDBJni: "org.rocksdb:rocksdbjni:$versions.rocksDB", + scalaCollectionCompat: "org.scala-lang.modules:scala-collection-compat_$versions.baseScala:$versions.scalaCollectionCompat", + scalaJava8Compat: "org.scala-lang.modules:scala-java8-compat_$versions.baseScala:$versions.scalaJava8Compat", + scalaLibrary: "org.scala-lang:scala-library:$versions.scala", + scalaLogging: "com.typesafe.scala-logging:scala-logging_$versions.baseScala:$versions.scalaLogging", + scalaReflect: "org.scala-lang:scala-reflect:$versions.scala", + slf4jApi: "org.slf4j:slf4j-api:$versions.slf4j", + slf4jlog4j: "org.slf4j:slf4j-log4j12:$versions.slf4j", + snappy: "org.xerial.snappy:snappy-java:$versions.snappy", + zookeeper: "org.apache.zookeeper:zookeeper:$versions.zookeeper", + jfreechart: "jfreechart:jfreechart:$versions.jfreechart", + mavenArtifact: "org.apache.maven:maven-artifact:$versions.mavenArtifact", + zstd: "com.github.luben:zstd-jni:$versions.zstd", + httpclient: "org.apache.httpcomponents:httpclient:$versions.httpclient" +] diff --git a/gradle/resources/rat-output-to-html.xsl b/gradle/resources/rat-output-to-html.xsl new file mode 100644 index 0000000..97ea7a1 --- /dev/null +++ b/gradle/resources/rat-output-to-html.xsl @@ -0,0 +1,206 @@ + + + + + + + + + + + + + + + + + + + + + + +

                [gradle/resources/rat-output-to-html.xsl body: the XSLT markup was stripped during rendering and only its text nodes survive. Recoverable content: the stylesheet turns the RAT XML report into an HTML "Rat Report" page, noting "This HTML version (yes, it is!) is generated from the RAT xml reports using Saxon9B. All the outputs required are displayed below, similar to the .txt version. This is obviously a work in progress; and a prettier, easier to read and manage version will be available soon". The page contains a snapshot summary table (counts of notes, binaries, archives, standards, Apache-licensed files, generated documents, and unknown licenses / files without a license, with notes that JavaDocs and generated files do not require license headers), sections listing unapproved licenses and archives, a legend for the per-file markers (AL = Apache License header, B = binary, A = compressed archive, N = notices/licenses), the headers of files without an AL header, the first few lines of each non-compliant file, and per-resource details (header type, license family, license approval, type).]
                diff --git a/gradle/spotbugs-exclude.xml b/gradle/spotbugs-exclude.xml new file mode 100644 index 0000000..878cd01 --- /dev/null +++ b/gradle/spotbugs-exclude.xml @@ -0,0 +1,485 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +- + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/gradle/wrapper/gradle-wrapper.properties b/gradle/wrapper/gradle-wrapper.properties new file mode 100644 index 0000000..a37233c --- /dev/null +++ b/gradle/wrapper/gradle-wrapper.properties @@ -0,0 +1,6 @@ +distributionBase=GRADLE_USER_HOME +distributionPath=wrapper/dists +distributionSha256Sum=a8da5b02437a60819cad23e10fc7e9cf32bcb57029d9cb277e26eeff76ce014b +distributionUrl=https\://services.gradle.org/distributions/gradle-7.2-all.zip +zipStoreBase=GRADLE_USER_HOME +zipStorePath=wrapper/dists diff --git a/gradlew b/gradlew new file mode 100755 index 0000000..412c296 --- /dev/null +++ b/gradlew @@ -0,0 +1,198 @@ +#!/usr/bin/env sh + +# +# Copyright 2015 the original author or authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +############################################################################## +## +## Gradle start up script for UN*X +## +############################################################################## + +# Attempt to set APP_HOME +# Resolve links: $0 may be a link +PRG="$0" +# Need this for relative symlinks. +while [ -h "$PRG" ] ; do + ls=`ls -ld "$PRG"` + link=`expr "$ls" : '.*-> \(.*\)$'` + if expr "$link" : '/.*' > /dev/null; then + PRG="$link" + else + PRG=`dirname "$PRG"`"/$link" + fi +done +SAVED="`pwd`" +cd "`dirname \"$PRG\"`/" >/dev/null +APP_HOME="`pwd -P`" +cd "$SAVED" >/dev/null + +APP_NAME="Gradle" +APP_BASE_NAME=`basename "$0"` + +# Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. +DEFAULT_JVM_OPTS='"-Xmx64m" "-Xms64m"' + +# Use the maximum available, or set MAX_FD != -1 to use that value. +MAX_FD="maximum" + +warn () { + echo "$*" +} + +die () { + echo + echo "$*" + echo + exit 1 +} + +# OS specific support (must be 'true' or 'false'). 
+cygwin=false +msys=false +darwin=false +nonstop=false +case "`uname`" in + CYGWIN* ) + cygwin=true + ;; + Darwin* ) + darwin=true + ;; + MSYS* | MINGW* ) + msys=true + ;; + NONSTOP* ) + nonstop=true + ;; +esac + + +# Loop in case we encounter an error. +for attempt in 1 2 3; do + if [ ! -e "$APP_HOME/gradle/wrapper/gradle-wrapper.jar" ]; then + if ! curl -s -S --retry 3 -L -o "$APP_HOME/gradle/wrapper/gradle-wrapper.jar" "https://raw.githubusercontent.com/gradle/gradle/v7.2.0/gradle/wrapper/gradle-wrapper.jar"; then + rm -f "$APP_HOME/gradle/wrapper/gradle-wrapper.jar" + # Pause for a bit before looping in case the server throttled us. + sleep 5 + continue + fi + fi +done + +CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar + + +# Determine the Java command to use to start the JVM. +if [ -n "$JAVA_HOME" ] ; then + if [ -x "$JAVA_HOME/jre/sh/java" ] ; then + # IBM's JDK on AIX uses strange locations for the executables + JAVACMD="$JAVA_HOME/jre/sh/java" + else + JAVACMD="$JAVA_HOME/bin/java" + fi + if [ ! -x "$JAVACMD" ] ; then + die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME + +Please set the JAVA_HOME variable in your environment to match the +location of your Java installation." + fi +else + JAVACMD="java" + which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. + +Please set the JAVA_HOME variable in your environment to match the +location of your Java installation." +fi + +# Increase the maximum file descriptors if we can. +if [ "$cygwin" = "false" -a "$darwin" = "false" -a "$nonstop" = "false" ] ; then + MAX_FD_LIMIT=`ulimit -H -n` + if [ $? -eq 0 ] ; then + if [ "$MAX_FD" = "maximum" -o "$MAX_FD" = "max" ] ; then + MAX_FD="$MAX_FD_LIMIT" + fi + ulimit -n $MAX_FD + if [ $? 
-ne 0 ] ; then + warn "Could not set maximum file descriptor limit: $MAX_FD" + fi + else + warn "Could not query maximum file descriptor limit: $MAX_FD_LIMIT" + fi +fi + +# For Darwin, add options to specify how the application appears in the dock +if $darwin; then + GRADLE_OPTS="$GRADLE_OPTS \"-Xdock:name=$APP_NAME\" \"-Xdock:icon=$APP_HOME/media/gradle.icns\"" +fi + +# For Cygwin or MSYS, switch paths to Windows format before running java +if [ "$cygwin" = "true" -o "$msys" = "true" ] ; then + APP_HOME=`cygpath --path --mixed "$APP_HOME"` + CLASSPATH=`cygpath --path --mixed "$CLASSPATH"` + + JAVACMD=`cygpath --unix "$JAVACMD"` + + # We build the pattern for arguments to be converted via cygpath + ROOTDIRSRAW=`find -L / -maxdepth 1 -mindepth 1 -type d 2>/dev/null` + SEP="" + for dir in $ROOTDIRSRAW ; do + ROOTDIRS="$ROOTDIRS$SEP$dir" + SEP="|" + done + OURCYGPATTERN="(^($ROOTDIRS))" + # Add a user-defined pattern to the cygpath arguments + if [ "$GRADLE_CYGPATTERN" != "" ] ; then + OURCYGPATTERN="$OURCYGPATTERN|($GRADLE_CYGPATTERN)" + fi + # Now convert the arguments - kludge to limit ourselves to /bin/sh + i=0 + for arg in "$@" ; do + CHECK=`echo "$arg"|egrep -c "$OURCYGPATTERN" -` + CHECK2=`echo "$arg"|egrep -c "^-"` ### Determine if an option + + if [ $CHECK -ne 0 ] && [ $CHECK2 -eq 0 ] ; then ### Added a condition + eval `echo args$i`=`cygpath --path --ignore --mixed "$arg"` + else + eval `echo args$i`="\"$arg\"" + fi + i=`expr $i + 1` + done + case $i in + 0) set -- ;; + 1) set -- "$args0" ;; + 2) set -- "$args0" "$args1" ;; + 3) set -- "$args0" "$args1" "$args2" ;; + 4) set -- "$args0" "$args1" "$args2" "$args3" ;; + 5) set -- "$args0" "$args1" "$args2" "$args3" "$args4" ;; + 6) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" ;; + 7) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" ;; + 8) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" ;; + 9) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" "$args8" ;; + esac +fi + +# Escape application args +save () { + for i do printf %s\\n "$i" | sed "s/'/'\\\\''/g;1s/^/'/;\$s/\$/' \\\\/" ; done + echo " " +} +APP_ARGS=`save "$@"` + +# Collect all arguments for the java command, following the shell quoting and substitution rules +eval set -- $DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS "\"-Dorg.gradle.appname=$APP_BASE_NAME\"" -classpath "\"$CLASSPATH\"" org.gradle.wrapper.GradleWrapperMain "$APP_ARGS" + +exec "$JAVACMD" "$@" diff --git a/gradlewAll b/gradlewAll new file mode 100755 index 0000000..30c2da3 --- /dev/null +++ b/gradlewAll @@ -0,0 +1,19 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# Convenient way to invoke a gradle command with all Scala versions supported +# by default +./gradlew "$@" -PscalaVersion=2.12 && ./gradlew "$@" -PscalaVersion=2.13 + diff --git a/jmh-benchmarks/README.md b/jmh-benchmarks/README.md new file mode 100644 index 0000000..50db058 --- /dev/null +++ b/jmh-benchmarks/README.md @@ -0,0 +1,130 @@ +### JMH-Benchmarks module + +This module contains benchmarks written using [JMH](https://openjdk.java.net/projects/code-tools/jmh/) from OpenJDK. +Writing correct micro-benchmarks in Java (or another JVM language) is difficult and there are many non-obvious pitfalls (many +due to compiler optimizations). JMH is a framework for running and analyzing benchmarks (micro or macro) written in Java (or +another JVM language). + +### Running benchmarks + +If you want to set specific JMH flags or only run certain benchmarks, passing arguments via +gradle tasks is cumbersome. This is simplified by the provided `jmh.sh` script. + +The default behavior is to run all benchmarks: + + ./jmh-benchmarks/jmh.sh + +Pass a pattern or name after the command to select the benchmarks: + + ./jmh-benchmarks/jmh.sh LRUCacheBenchmark + +Check which benchmarks match the provided pattern: + + ./jmh-benchmarks/jmh.sh -l LRUCacheBenchmark + +Run a specific test and override the number of forks, iterations and warm-up iterations to `2`: + + ./jmh-benchmarks/jmh.sh -f 2 -i 2 -wi 2 LRUCacheBenchmark + +Run a specific test with async and GC profilers on Linux and flame graph output: + + ./jmh-benchmarks/jmh.sh -prof gc -prof async:libPath=/path/to/libasyncProfiler.so\;output=flamegraph LRUCacheBenchmark + +The following sections cover async profiler and GC profilers in more detail. + +### Using JMH with async profiler + +It's good practice to check profiler output for microbenchmarks in order to verify that they represent the expected +application behavior and measure what you expect to measure. Some example pitfalls include the use of expensive mocks +or accidental inclusion of test setup code in the benchmarked code. JMH includes +[async-profiler](https://github.com/jvm-profiling-tools/async-profiler) integration that makes this easy: + + ./jmh-benchmarks/jmh.sh -prof async:libPath=/path/to/libasyncProfiler.so + +With flame graph output (the semicolon is escaped to ensure it is not treated as a command separator): + + ./jmh-benchmarks/jmh.sh -prof async:libPath=/path/to/libasyncProfiler.so\;output=flamegraph + +Simultaneous cpu, allocation and lock profiling with async profiler 2.0 and jfr output (the semicolon is +escaped to ensure it is not treated as a command separator): + + ./jmh-benchmarks/jmh.sh -prof async:libPath=/path/to/libasyncProfiler.so\;output=jfr\;alloc\;lock LRUCacheBenchmark + +A number of arguments can be passed to configure async profiler; run the following for a description: + + ./jmh-benchmarks/jmh.sh -prof async:help + +### Using JMH GC profiler + +It's good practice to run your benchmark with `-prof gc` to measure its allocation rate: + + ./jmh-benchmarks/jmh.sh -prof gc + +Of particular importance are the `norm` alloc rates, which measure the allocations per operation rather than allocations +per second (the latter can increase simply because you have made your code faster).
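+For example, to measure the per-operation allocation rate of a single benchmark, the GC profiler flag can be combined with a benchmark pattern (here `LRUCacheBenchmark`, as in the examples above): + + ./jmh-benchmarks/jmh.sh -prof gc LRUCacheBenchmark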
+ +### Running JMH outside of gradle + +The JMH benchmarks can be run outside of gradle as you would with any executable jar file: + + java -jar <kafka-repo-dir>/jmh-benchmarks/build/libs/kafka-jmh-benchmarks-*.jar -f2 LRUCacheBenchmark + +### Writing benchmarks + +For help in writing correct JMH tests, the best place to start is the [sample code](https://hg.openjdk.java.net/code-tools/jmh/file/tip/jmh-samples/src/main/java/org/openjdk/jmh/samples/) provided +by the JMH project. + +Typically, JMH is expected to run as a separate project in Maven. The jmh-benchmarks module uses +the [gradle shadow jar](https://github.com/johnrengelman/shadow) plugin to emulate this behavior, by creating the required +uber-jar file containing the benchmarking code and required JMH classes. + +JMH is highly configurable and users are encouraged to look through the samples for suggestions +on what options are available. A good tutorial for using JMH can be found [here](http://tutorials.jenkov.com/java-performance/jmh.html#return-value-from-benchmark-method). + +### Gradle Tasks + +If no benchmark mode is specified, the default mode, throughput, is used. It is assumed that users run +the gradle tasks with `./gradlew` from the root of the Kafka project. + +* `jmh-benchmarks:shadowJar` - creates the uber jar required to run the benchmarks. + +* `jmh-benchmarks:jmh` - runs the `clean` and `shadowJar` tasks followed by all the benchmarks. + +### JMH Options +Some common JMH options are: + +```text + + -e <regexp+> Benchmarks to exclude from the run. + + -f <int> How many times to fork a single benchmark. Use 0 to + disable forking altogether. Warning: disabling + forking may have detrimental impact on benchmark + and infrastructure reliability, you might want + to use different warmup mode instead. + + -i <int> Number of measurement iterations to do. Measurement + iterations are counted towards the benchmark score. + (default: 1 for SingleShotTime, and 5 for all other + modes) + + -l List the benchmarks that match a filter, and exit. + + -lprof List profilers, and exit. + + -o <filename> Redirect human-readable output to a given file. + + -prof <profiler> Use profilers to collect additional benchmark data. + Some profilers are not available on all JVMs and/or + all OSes. Please see the list of available profilers + with -lprof. + + -v <mode> Verbosity mode. Available modes are: [SILENT, NORMAL, + EXTRA] + + -wi <int> Number of warmup iterations to do. Warmup iterations + are not counted towards the benchmark score. (default: + 0 for SingleShotTime, and 5 for all other modes) +``` + +To view all options, run JMH with the -h flag. diff --git a/jmh-benchmarks/jmh.sh b/jmh-benchmarks/jmh.sh new file mode 100755 index 0000000..2f500bf --- /dev/null +++ b/jmh-benchmarks/jmh.sh @@ -0,0 +1,42 @@ +#!/usr/bin/env bash +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
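+ +# Usage: ./jmh-benchmarks/jmh.sh [JMH options] [benchmark pattern] +# Builds the jmh-benchmarks shadow jar with gradle, then forwards all arguments to the JMH runner.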
+ +base_dir=$(dirname $0) +jmh_project_name="jmh-benchmarks" + +if [ ${base_dir} == "." ]; then + gradlew_dir=".." +elif [ ${base_dir##./} == "${jmh_project_name}" ]; then + gradlew_dir="." +else + echo "JMH Benchmarks need to be run from the root of the kafka repository or the 'jmh-benchmarks' directory" + exit +fi + +gradleCmd="${gradlew_dir}/gradlew" +libDir="${base_dir}/build/libs" + +echo "running gradlew :jmh-benchmarks:clean :jmh-benchmarks:shadowJar in quiet mode" + +$gradleCmd -q :jmh-benchmarks:clean :jmh-benchmarks:shadowJar + +echo "gradle build done" + +echo "running JMH with args: $@" + +java -jar ${libDir}/kafka-jmh-benchmarks-*.jar "$@" + +echo "JMH benchmarks done" diff --git a/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/acl/AclAuthorizerBenchmark.java b/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/acl/AclAuthorizerBenchmark.java new file mode 100644 index 0000000..65aa2a1 --- /dev/null +++ b/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/acl/AclAuthorizerBenchmark.java @@ -0,0 +1,236 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.jmh.acl; + +import kafka.security.authorizer.AclAuthorizer; +import kafka.security.authorizer.AclAuthorizer.VersionedAcls; +import kafka.security.authorizer.AclEntry; +import org.apache.kafka.common.acl.AccessControlEntry; +import org.apache.kafka.common.acl.AclBindingFilter; +import org.apache.kafka.common.acl.AclOperation; +import org.apache.kafka.common.acl.AclPermissionType; +import org.apache.kafka.common.network.ClientInformation; +import org.apache.kafka.common.network.ListenerName; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.requests.RequestContext; +import org.apache.kafka.common.requests.RequestHeader; +import org.apache.kafka.common.resource.PatternType; +import org.apache.kafka.common.resource.ResourcePattern; +import org.apache.kafka.common.resource.ResourceType; +import org.apache.kafka.common.security.auth.KafkaPrincipal; +import org.apache.kafka.common.security.auth.SecurityProtocol; +import org.apache.kafka.server.authorizer.Action; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Level; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.OutputTimeUnit; +import org.openjdk.jmh.annotations.Param; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.TearDown; +import org.openjdk.jmh.annotations.Warmup; +import scala.collection.JavaConverters; + +import java.net.InetAddress; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Random; +import java.util.Set; +import java.util.UUID; +import java.util.concurrent.TimeUnit; + +@State(Scope.Benchmark) +@Fork(value = 1) +@Warmup(iterations = 5) +@Measurement(iterations = 15) +@BenchmarkMode(Mode.AverageTime) +@OutputTimeUnit(TimeUnit.MILLISECONDS) +public class AclAuthorizerBenchmark { + @Param({"10000", "50000", "200000"}) + private int resourceCount; + //no. of. rules per resource + @Param({"10", "50"}) + private int aclCount; + + @Param({"0", "20", "50", "90", "99", "99.9", "99.99", "100"}) + private double denyPercentage; + + private final int hostPreCount = 1000; + private final String resourceNamePrefix = "foo-bar35_resource-"; + private final AclAuthorizer aclAuthorizer = new AclAuthorizer(); + private final KafkaPrincipal principal = new KafkaPrincipal(KafkaPrincipal.USER_TYPE, "test-user"); + private List actions = new ArrayList<>(); + private RequestContext authorizeContext; + private RequestContext authorizeByResourceTypeContext; + private String authorizeByResourceTypeHostName = "127.0.0.2"; + + private HashMap aclToUpdate = new HashMap<>(); + + Random rand = new Random(System.currentTimeMillis()); + double eps = 1e-9; + + @Setup(Level.Trial) + public void setup() throws Exception { + prepareAclCache(); + prepareAclToUpdate(); + // By adding `-95` to the resource name prefix, we cause the `TreeMap.from/to` call to return + // most map entries. In such cases, we rely on the filtering based on `String.startsWith` + // to return the matching ACLs. Using a more efficient data structure (e.g. a prefix + // tree) should improve performance significantly). 
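+ // Build the single WRITE action used by testAuthorizer() and the two request contexts (which differ only in client address) used by the authorize benchmarks below.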
+ actions = Collections.singletonList(new Action(AclOperation.WRITE, + new ResourcePattern(ResourceType.TOPIC, resourceNamePrefix + 95, PatternType.LITERAL), + 1, true, true)); + authorizeContext = new RequestContext(new RequestHeader(ApiKeys.PRODUCE, Integer.valueOf(1).shortValue(), + "someclient", 1), "1", InetAddress.getByName("127.0.0.1"), principal, + ListenerName.normalised("listener"), SecurityProtocol.PLAINTEXT, ClientInformation.EMPTY, false); + authorizeByResourceTypeContext = new RequestContext(new RequestHeader(ApiKeys.PRODUCE, Integer.valueOf(1).shortValue(), + "someclient", 1), "1", InetAddress.getByName(authorizeByResourceTypeHostName), principal, + ListenerName.normalised("listener"), SecurityProtocol.PLAINTEXT, ClientInformation.EMPTY, false); + } + + private void prepareAclCache() { + Map> aclEntries = new HashMap<>(); + for (int resourceId = 0; resourceId < resourceCount; resourceId++) { + ResourcePattern resource = new ResourcePattern( + (resourceId % 10 == 0) ? ResourceType.GROUP : ResourceType.TOPIC, + resourceNamePrefix + resourceId, + (resourceId % 5 == 0) ? PatternType.PREFIXED : PatternType.LITERAL); + + Set entries = aclEntries.computeIfAbsent(resource, k -> new HashSet<>()); + + for (int aclId = 0; aclId < aclCount; aclId++) { + // The principal in the request context we are using + // is principal.toString without any suffix + String principalName = principal.toString() + (aclId == 0 ? "" : aclId); + AccessControlEntry allowAce = new AccessControlEntry( + principalName, "*", AclOperation.READ, AclPermissionType.ALLOW); + + entries.add(new AclEntry(allowAce)); + + if (shouldDeny()) { + // dominantly deny the resource + AccessControlEntry denyAce = new AccessControlEntry( + principalName, "*", AclOperation.READ, AclPermissionType.DENY); + entries.add(new AclEntry(denyAce)); + } + } + } + + ResourcePattern resourcePrefix = new ResourcePattern(ResourceType.TOPIC, resourceNamePrefix, + PatternType.PREFIXED); + Set entriesPrefix = aclEntries.computeIfAbsent(resourcePrefix, k -> new HashSet<>()); + for (int hostId = 0; hostId < hostPreCount; hostId++) { + AccessControlEntry allowAce = new AccessControlEntry(principal.toString(), "127.0.0." + hostId, + AclOperation.READ, AclPermissionType.ALLOW); + entriesPrefix.add(new AclEntry(allowAce)); + + if (shouldDeny()) { + // dominantly deny the resource + AccessControlEntry denyAce = new AccessControlEntry(principal.toString(), "127.0.0." + hostId, + AclOperation.READ, AclPermissionType.DENY); + entriesPrefix.add(new AclEntry(denyAce)); + } + } + + ResourcePattern resourceWildcard = new ResourcePattern(ResourceType.TOPIC, ResourcePattern.WILDCARD_RESOURCE, + PatternType.LITERAL); + Set entriesWildcard = aclEntries.computeIfAbsent(resourceWildcard, k -> new HashSet<>()); + // get dynamic entries number for wildcard acl + for (int hostId = 0; hostId < resourceCount / 10; hostId++) { + String hostName = "127.0.0" + hostId; + // AuthorizeByResourceType is optimizing the wildcard deny case. + // If we didn't skip the host, we would end up having a biased short runtime. 
+ if (hostName.equals(authorizeByResourceTypeHostName)) { + continue; + } + + AccessControlEntry allowAce = new AccessControlEntry(principal.toString(), hostName, + AclOperation.READ, AclPermissionType.ALLOW); + entriesWildcard.add(new AclEntry(allowAce)); + if (shouldDeny()) { + AccessControlEntry denyAce = new AccessControlEntry(principal.toString(), hostName, + AclOperation.READ, AclPermissionType.DENY); + entriesWildcard.add(new AclEntry(denyAce)); + } + } + + for (Map.Entry> entryMap : aclEntries.entrySet()) { + aclAuthorizer.updateCache(entryMap.getKey(), + new VersionedAcls(JavaConverters.asScalaSetConverter(entryMap.getValue()).asScala().toSet(), 1)); + } + } + + private void prepareAclToUpdate() { + scala.collection.mutable.Set entries = new scala.collection.mutable.HashSet<>(); + for (int i = 0; i < resourceCount; i++) { + scala.collection.immutable.Set immutable = new scala.collection.immutable.HashSet<>(); + for (int j = 0; j < aclCount; j++) { + entries.add(new AclEntry(new AccessControlEntry( + principal.toString(), "127.0.0" + j, AclOperation.WRITE, AclPermissionType.ALLOW))); + immutable = entries.toSet(); + } + aclToUpdate.put( + new ResourcePattern(ResourceType.TOPIC, randomResourceName(resourceNamePrefix), PatternType.LITERAL), + new AclAuthorizer.VersionedAcls(immutable, i)); + } + } + + private String randomResourceName(String prefix) { + return prefix + UUID.randomUUID().toString().substring(0, 5); + } + + private Boolean shouldDeny() { + return rand.nextDouble() * 100.0 - eps < denyPercentage; + } + + @TearDown(Level.Trial) + public void tearDown() { + aclAuthorizer.close(); + } + + @Benchmark + public void testAclsIterator() { + aclAuthorizer.acls(AclBindingFilter.ANY); + } + + @Benchmark + public void testAuthorizer() { + aclAuthorizer.authorize(authorizeContext, actions); + } + + @Benchmark + public void testAuthorizeByResourceType() { + aclAuthorizer.authorizeByResourceType(authorizeByResourceTypeContext, AclOperation.READ, ResourceType.TOPIC); + } + + @Benchmark + public void testUpdateCache() { + AclAuthorizer aclAuthorizer = new AclAuthorizer(); + for (Map.Entry e : aclToUpdate.entrySet()) { + aclAuthorizer.updateCache(e.getKey(), e.getValue()); + } + } +} diff --git a/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/admin/GetListOffsetsCallsBenchmark.java b/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/admin/GetListOffsetsCallsBenchmark.java new file mode 100644 index 0000000..8da09ed --- /dev/null +++ b/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/admin/GetListOffsetsCallsBenchmark.java @@ -0,0 +1,142 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.jmh.admin; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.concurrent.TimeUnit; + +import org.apache.kafka.clients.admin.AdminClientTestUtils; +import org.apache.kafka.clients.admin.AdminClientUnitTestEnv; +import org.apache.kafka.clients.admin.KafkaAdminClient; +import org.apache.kafka.clients.admin.ListOffsetsOptions; +import org.apache.kafka.clients.admin.ListOffsetsResult; +import org.apache.kafka.clients.admin.OffsetSpec; +import org.apache.kafka.clients.admin.internals.MetadataOperationContext; +import org.apache.kafka.common.Cluster; +import org.apache.kafka.common.Node; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.Uuid; +import org.apache.kafka.common.internals.KafkaFutureImpl; +import org.apache.kafka.common.message.MetadataResponseData; +import org.apache.kafka.common.message.MetadataResponseData.MetadataResponsePartition; +import org.apache.kafka.common.message.MetadataResponseData.MetadataResponseTopic; +import org.apache.kafka.common.requests.MetadataResponse; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Level; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.OutputTimeUnit; +import org.openjdk.jmh.annotations.Param; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.Warmup; + +@State(Scope.Benchmark) +@Fork(value = 1) +@Warmup(iterations = 5) +@Measurement(iterations = 15) +@BenchmarkMode(Mode.AverageTime) +@OutputTimeUnit(TimeUnit.MICROSECONDS) +public class GetListOffsetsCallsBenchmark { + @Param({"1", "10"}) + private int topicCount; + + @Param({"100", "1000", "10000"}) + private int partitionCount; + + private KafkaAdminClient admin; + private MetadataOperationContext context; + private final Map topicPartitionOffsets = new HashMap<>(); + private final Map> futures = new HashMap<>(); + private final int numNodes = 3; + + @Setup(Level.Trial) + public void setup() { + MetadataResponseData data = new MetadataResponseData(); + List mrTopicList = new ArrayList<>(); + Set topics = new HashSet<>(); + + for (int topicIndex = 0; topicIndex < topicCount; topicIndex++) { + Uuid topicId = Uuid.randomUuid(); + String topicName = "topic-" + topicIndex; + MetadataResponseTopic mrTopic = new MetadataResponseTopic() + .setTopicId(topicId) + .setName(topicName) + .setErrorCode((short) 0) + .setIsInternal(false); + + List mrPartitionList = new ArrayList<>(); + + for (int partition = 0; partition < partitionCount; partition++) { + TopicPartition tp = new TopicPartition(topicName, partition); + topics.add(tp.topic()); + futures.put(tp, new KafkaFutureImpl<>()); + topicPartitionOffsets.put(tp, OffsetSpec.latest()); + + MetadataResponsePartition mrPartition = new MetadataResponsePartition() + .setLeaderId(partition % numNodes) + .setPartitionIndex(partition) + .setIsrNodes(Arrays.asList(0, 1, 2)) + .setReplicaNodes(Arrays.asList(0, 1, 2)) + .setOfflineReplicas(Collections.emptyList()) + .setErrorCode((short) 0); + + mrPartitionList.add(mrPartition); + } + + mrTopic.setPartitions(mrPartitionList); + 
mrTopicList.add(mrTopic); + } + data.setTopics(new MetadataResponseData.MetadataResponseTopicCollection(mrTopicList.listIterator())); + + long deadline = 0L; + short version = 0; + context = new MetadataOperationContext<>(topics, new ListOffsetsOptions(), deadline, futures); + context.setResponse(Optional.of(new MetadataResponse(data, version))); + + AdminClientUnitTestEnv adminEnv = new AdminClientUnitTestEnv(mockCluster()); + admin = (KafkaAdminClient) adminEnv.adminClient(); + } + + @Benchmark + public Object testGetListOffsetsCalls() { + return AdminClientTestUtils.getListOffsetsCalls(admin, context, topicPartitionOffsets, futures); + } + + private Cluster mockCluster() { + final int controllerIndex = 0; + + HashMap nodes = new HashMap<>(); + for (int i = 0; i < numNodes; i++) + nodes.put(i, new Node(i, "localhost", 8121 + i)); + return new Cluster("mockClusterId", nodes.values(), + Collections.emptySet(), Collections.emptySet(), + Collections.emptySet(), nodes.get(controllerIndex)); + } +} diff --git a/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/cache/LRUCacheBenchmark.java b/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/cache/LRUCacheBenchmark.java new file mode 100644 index 0000000..5b2d004 --- /dev/null +++ b/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/cache/LRUCacheBenchmark.java @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.jmh.cache; + +import java.util.concurrent.TimeUnit; +import org.apache.kafka.common.cache.LRUCache; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.Level; +import org.openjdk.jmh.annotations.OutputTimeUnit; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.runner.Runner; +import org.openjdk.jmh.runner.RunnerException; +import org.openjdk.jmh.runner.options.Options; +import org.openjdk.jmh.runner.options.OptionsBuilder; + +/** + * This is a simple example of a JMH benchmark. 
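+ * It measures an org.apache.kafka.common.cache.LRUCache put followed by a get, cycling through 10,000 distinct keys.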
+ * + * The sample code provided by the JMH project is a great place to start learning how to write correct benchmarks: + * http://hg.openjdk.java.net/code-tools/jmh/file/tip/jmh-samples/src/main/java/org/openjdk/jmh/samples/ + */ +@State(Scope.Thread) +@OutputTimeUnit(TimeUnit.MILLISECONDS) +public class LRUCacheBenchmark { + + private static final int DISTINCT_KEYS = 10_000; + + private static final String KEY = "the_key_to_use"; + + private static final String VALUE = "the quick brown fox jumped over the lazy dog the olympics are about to start"; + + private final String[] keys = new String[DISTINCT_KEYS]; + + private final String[] values = new String[DISTINCT_KEYS]; + + private LRUCache lruCache; + + private long counter = 0; + + @Setup(Level.Trial) + public void setUp() { + for (int i = 0; i < DISTINCT_KEYS; ++i) { + keys[i] = KEY + i; + values[i] = VALUE + i; + } + lruCache = new LRUCache<>(100); + } + + @Benchmark + public String testCachePerformance() { + counter++; + int index = (int) (counter % DISTINCT_KEYS); + String hashkey = keys[index]; + lruCache.put(hashkey, values[index]); + return lruCache.get(hashkey); + } + + public static void main(String[] args) throws RunnerException { + Options opt = new OptionsBuilder() + .include(LRUCacheBenchmark.class.getSimpleName()) + .forks(2) + .build(); + + new Runner(opt).run(); + } + +} diff --git a/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/common/FetchRequestBenchmark.java b/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/common/FetchRequestBenchmark.java new file mode 100644 index 0000000..a428f91 --- /dev/null +++ b/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/common/FetchRequestBenchmark.java @@ -0,0 +1,139 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.jmh.common; + +import kafka.network.RequestConvertToJson; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.Uuid; +import org.apache.kafka.common.network.Send; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.requests.AbstractRequest; +import org.apache.kafka.common.requests.ByteBufferChannel; +import org.apache.kafka.common.requests.FetchRequest; +import org.apache.kafka.common.requests.RequestHeader; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Level; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.OutputTimeUnit; +import org.openjdk.jmh.annotations.Param; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.Warmup; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.HashMap; +import java.util.Map; +import java.util.Optional; +import java.util.concurrent.TimeUnit; + +@State(Scope.Benchmark) +@Fork(value = 1) +@Warmup(iterations = 5) +@Measurement(iterations = 15) +@BenchmarkMode(Mode.AverageTime) +@OutputTimeUnit(TimeUnit.NANOSECONDS) +public class FetchRequestBenchmark { + + @Param({"10", "500", "1000"}) + private int topicCount; + + @Param({"3", "10", "20"}) + private int partitionCount; + + Map fetchData; + + Map topicNames; + + RequestHeader header; + + FetchRequest consumerRequest; + + FetchRequest replicaRequest; + + ByteBuffer requestBuffer; + + @Setup(Level.Trial) + public void setup() { + this.fetchData = new HashMap<>(); + this.topicNames = new HashMap<>(); + for (int topicIdx = 0; topicIdx < topicCount; topicIdx++) { + String topic = Uuid.randomUuid().toString(); + Uuid id = Uuid.randomUuid(); + topicNames.put(id, topic); + for (int partitionId = 0; partitionId < partitionCount; partitionId++) { + FetchRequest.PartitionData partitionData = new FetchRequest.PartitionData( + id, 0, 0, 4096, Optional.empty()); + fetchData.put(new TopicPartition(topic, partitionId), partitionData); + } + } + + this.header = new RequestHeader(ApiKeys.FETCH, ApiKeys.FETCH.latestVersion(), "jmh-benchmark", 100); + this.consumerRequest = FetchRequest.Builder.forConsumer(ApiKeys.FETCH.latestVersion(), 0, 0, fetchData) + .build(ApiKeys.FETCH.latestVersion()); + this.replicaRequest = FetchRequest.Builder.forReplica(ApiKeys.FETCH.latestVersion(), 1, 0, 0, fetchData) + .build(ApiKeys.FETCH.latestVersion()); + this.requestBuffer = this.consumerRequest.serialize(); + + } + + @Benchmark + public short testFetchRequestFromBuffer() { + return AbstractRequest.parseRequest(ApiKeys.FETCH, ApiKeys.FETCH.latestVersion(), requestBuffer).request.version(); + } + + @Benchmark + public int testFetchRequestForConsumer() { + FetchRequest fetchRequest = FetchRequest.Builder.forConsumer(ApiKeys.FETCH.latestVersion(), 0, 0, fetchData) + .build(ApiKeys.FETCH.latestVersion()); + return fetchRequest.fetchData(topicNames).size(); + } + + @Benchmark + public int testFetchRequestForReplica() { + FetchRequest fetchRequest = FetchRequest.Builder.forReplica( + ApiKeys.FETCH.latestVersion(), 1, 0, 0, fetchData) + .build(ApiKeys.FETCH.latestVersion()); + return fetchRequest.fetchData(topicNames).size(); + } + + @Benchmark + public int testSerializeFetchRequestForConsumer() throws IOException { + Send 
send = consumerRequest.toSend(header); + ByteBufferChannel channel = new ByteBufferChannel(send.size()); + send.writeTo(channel); + return channel.buffer().limit(); + } + + @Benchmark + public int testSerializeFetchRequestForReplica() throws IOException { + Send send = replicaRequest.toSend(header); + ByteBufferChannel channel = new ByteBufferChannel(send.size()); + send.writeTo(channel); + return channel.buffer().limit(); + } + + @Benchmark + public String testRequestToJson() { + return RequestConvertToJson.request(consumerRequest).toString(); + } +} diff --git a/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/common/FetchResponseBenchmark.java b/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/common/FetchResponseBenchmark.java new file mode 100644 index 0000000..d8512bd --- /dev/null +++ b/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/common/FetchResponseBenchmark.java @@ -0,0 +1,127 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.jmh.common; + +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.TopicIdPartition; +import org.apache.kafka.common.Uuid; +import org.apache.kafka.common.message.FetchResponseData; +import org.apache.kafka.common.network.Send; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.Errors; +import org.apache.kafka.common.record.CompressionType; +import org.apache.kafka.common.record.MemoryRecords; +import org.apache.kafka.common.record.SimpleRecord; +import org.apache.kafka.common.requests.ByteBufferChannel; +import org.apache.kafka.common.requests.FetchResponse; +import org.apache.kafka.common.requests.ResponseHeader; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Level; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.OutputTimeUnit; +import org.openjdk.jmh.annotations.Param; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.Warmup; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.Map; +import java.util.UUID; +import java.util.concurrent.TimeUnit; + +@State(Scope.Benchmark) +@Fork(value = 1) +@Warmup(iterations = 5) +@Measurement(iterations = 15) +@BenchmarkMode(Mode.AverageTime) +@OutputTimeUnit(TimeUnit.NANOSECONDS) +public class FetchResponseBenchmark { + @Param({"10", "500", "1000"}) + private int topicCount; + + @Param({"3", "10", "20"}) + private int partitionCount; + + LinkedHashMap 
responseData; + + Map topicIds; + + Map topicNames; + + ResponseHeader header; + + FetchResponse fetchResponse; + + FetchResponseData fetchResponseData; + + @Setup(Level.Trial) + public void setup() { + MemoryRecords records = MemoryRecords.withRecords(CompressionType.NONE, + new SimpleRecord(1000, "key1".getBytes(StandardCharsets.UTF_8), "value1".getBytes(StandardCharsets.UTF_8)), + new SimpleRecord(1001, "key2".getBytes(StandardCharsets.UTF_8), "value2".getBytes(StandardCharsets.UTF_8)), + new SimpleRecord(1002, "key3".getBytes(StandardCharsets.UTF_8), "value3".getBytes(StandardCharsets.UTF_8))); + + this.responseData = new LinkedHashMap<>(); + this.topicIds = new HashMap<>(); + this.topicNames = new HashMap<>(); + for (int topicIdx = 0; topicIdx < topicCount; topicIdx++) { + String topic = UUID.randomUUID().toString(); + Uuid id = Uuid.randomUuid(); + topicIds.put(topic, id); + topicNames.put(id, topic); + for (int partitionId = 0; partitionId < partitionCount; partitionId++) { + FetchResponseData.PartitionData partitionData = new FetchResponseData.PartitionData() + .setPartitionIndex(partitionId) + .setLastStableOffset(0) + .setLogStartOffset(0) + .setRecords(records); + responseData.put(new TopicIdPartition(id, new TopicPartition(topic, partitionId)), partitionData); + } + } + + this.header = new ResponseHeader(100, ApiKeys.FETCH.responseHeaderVersion(ApiKeys.FETCH.latestVersion())); + this.fetchResponse = FetchResponse.of(Errors.NONE, 0, 0, responseData); + this.fetchResponseData = this.fetchResponse.data(); + } + + @Benchmark + public int testConstructFetchResponse() { + FetchResponse fetchResponse = FetchResponse.of(Errors.NONE, 0, 0, responseData); + return fetchResponse.data().responses().size(); + } + + @Benchmark + public int testPartitionMapFromData() { + return new FetchResponse(fetchResponseData).responseData(topicNames, ApiKeys.FETCH.latestVersion()).size(); + } + + @Benchmark + public int testSerializeFetchResponse() throws IOException { + Send send = fetchResponse.toSend(header, ApiKeys.FETCH.latestVersion()); + ByteBufferChannel channel = new ByteBufferChannel(send.size()); + send.writeTo(channel); + return channel.buffer().limit(); + } +} diff --git a/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/common/ImplicitLinkedHashCollectionBenchmark.java b/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/common/ImplicitLinkedHashCollectionBenchmark.java new file mode 100644 index 0000000..c79c7f8 --- /dev/null +++ b/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/common/ImplicitLinkedHashCollectionBenchmark.java @@ -0,0 +1,121 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.jmh.common; + +import org.apache.kafka.common.Uuid; +import org.apache.kafka.common.utils.ImplicitLinkedHashCollection; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Level; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.OutputTimeUnit; +import org.openjdk.jmh.annotations.Param; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.Warmup; + +import java.util.Comparator; +import java.util.concurrent.TimeUnit; + +@State(Scope.Benchmark) +@Fork(value = 1) +@Warmup(iterations = 3) +@Measurement(iterations = 6) +@BenchmarkMode(Mode.AverageTime) +@OutputTimeUnit(TimeUnit.MICROSECONDS) +public class ImplicitLinkedHashCollectionBenchmark { + public static class TestElement implements ImplicitLinkedHashCollection.Element { + private final String value; + private int next = ImplicitLinkedHashCollection.INVALID_INDEX; + private int prev = ImplicitLinkedHashCollection.INVALID_INDEX; + + public TestElement(String value) { + this.value = value; + } + + public String value() { + return value; + } + + @Override + public int prev() { + return this.prev; + } + + @Override + public void setPrev(int prev) { + this.prev = prev; + } + + @Override + public int next() { + return this.next; + } + + @Override + public void setNext(int next) { + this.next = next; + } + + @Override + public int hashCode() { + return value.hashCode(); + } + + @Override + public boolean equals(Object o) { + if (!(o instanceof TestElement)) return false; + TestElement other = (TestElement) o; + return value.equals(other.value); + } + } + + public static class TestElementComparator implements Comparator { + public static final TestElementComparator INSTANCE = new TestElementComparator(); + + @Override + public int compare(TestElement a, TestElement b) { + return a.value().compareTo(b.value()); + } + } + + @Param({"10000", "100000"}) + private int size; + + private ImplicitLinkedHashCollection coll; + + @Setup(Level.Trial) + public void setup() { + coll = new ImplicitLinkedHashCollection<>(); + for (int i = 0; i < size; i++) { + coll.add(new TestElement(Uuid.randomUuid().toString())); + } + } + + /** + * Test sorting the collection entries. + */ + @Benchmark + public ImplicitLinkedHashCollection testCollectionSort() { + coll.sort(TestElementComparator.INSTANCE); + return coll; + } +} diff --git a/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/common/ListOffsetRequestBenchmark.java b/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/common/ListOffsetRequestBenchmark.java new file mode 100644 index 0000000..e6fc2dc --- /dev/null +++ b/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/common/ListOffsetRequestBenchmark.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.jmh.common; + +import kafka.network.RequestConvertToJson; +import org.apache.kafka.common.IsolationLevel; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.message.ListOffsetsRequestData; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.requests.ListOffsetsRequest; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Level; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.OutputTimeUnit; +import org.openjdk.jmh.annotations.Param; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.Warmup; + +import java.util.HashMap; +import java.util.Map; +import java.util.UUID; +import java.util.concurrent.TimeUnit; + +@State(Scope.Benchmark) +@Fork(value = 1) +@Warmup(iterations = 5) +@Measurement(iterations = 15) +@BenchmarkMode(Mode.AverageTime) +@OutputTimeUnit(TimeUnit.NANOSECONDS) +public class ListOffsetRequestBenchmark { + @Param({"10", "500", "1000"}) + private int topicCount; + + @Param({"3", "10", "20"}) + private int partitionCount; + + Map offsetData; + + ListOffsetsRequest offsetRequest; + + @Setup(Level.Trial) + public void setup() { + this.offsetData = new HashMap<>(); + for (int topicIdx = 0; topicIdx < topicCount; topicIdx++) { + String topic = UUID.randomUUID().toString(); + for (int partitionId = 0; partitionId < partitionCount; partitionId++) { + ListOffsetsRequestData.ListOffsetsPartition data = new ListOffsetsRequestData.ListOffsetsPartition(); + this.offsetData.put(new TopicPartition(topic, partitionId), data); + } + } + + this.offsetRequest = ListOffsetsRequest.Builder.forConsumer(false, IsolationLevel.READ_UNCOMMITTED, false) + .build(ApiKeys.LIST_OFFSETS.latestVersion()); + } + + @Benchmark + public String testRequestToJson() { + return RequestConvertToJson.request(offsetRequest).toString(); + } +} diff --git a/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/common/ProduceRequestBenchmark.java b/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/common/ProduceRequestBenchmark.java new file mode 100644 index 0000000..405458f --- /dev/null +++ b/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/common/ProduceRequestBenchmark.java @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.jmh.common; + +import kafka.network.RequestConvertToJson; +import org.apache.kafka.common.message.ProduceRequestData; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.requests.ProduceRequest; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Level; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.OutputTimeUnit; +import org.openjdk.jmh.annotations.Param; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.Warmup; + +import java.util.concurrent.TimeUnit; + +@State(Scope.Benchmark) +@Fork(value = 1) +@Warmup(iterations = 5) +@Measurement(iterations = 15) +@BenchmarkMode(Mode.AverageTime) +@OutputTimeUnit(TimeUnit.NANOSECONDS) +public class ProduceRequestBenchmark { + @Param({"10", "500", "1000"}) + private int topicCount; + + @Param({"3", "10", "20"}) + private int partitionCount; + + ProduceRequest produceRequest; + + @Setup(Level.Trial) + public void setup() { + this.produceRequest = ProduceRequest.forCurrentMagic(new ProduceRequestData()) + .build(ApiKeys.PRODUCE.latestVersion()); + } + + @Benchmark + public String testRequestToJson() { + return RequestConvertToJson.request(produceRequest).toString(); + } +} diff --git a/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/common/TopicBenchmark.java b/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/common/TopicBenchmark.java new file mode 100644 index 0000000..fde239e --- /dev/null +++ b/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/common/TopicBenchmark.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.jmh.common; + +import org.apache.kafka.common.internals.Topic; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.OutputTimeUnit; +import org.openjdk.jmh.annotations.Param; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.Warmup; + +import java.util.concurrent.TimeUnit; + +@State(Scope.Benchmark) +@Fork(value = 1) +@Warmup(iterations = 5) +@Measurement(iterations = 15) +@BenchmarkMode(Mode.AverageTime) +@OutputTimeUnit(TimeUnit.NANOSECONDS) +public class TopicBenchmark { + + @State(Scope.Thread) + public static class BenchState { + @Param({"topic", "longer-topic-name", "very-long-topic-name.with_more_text"}) + public String topicName; + } + + @Benchmark + public BenchState testValidate(BenchState state) { + // validate doesn't return anything, so return `state` to prevent the JVM from optimising the whole call away + Topic.validate(state.topicName); + return state; + } +} diff --git a/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/consumer/SubscriptionStateBenchmark.java b/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/consumer/SubscriptionStateBenchmark.java new file mode 100644 index 0000000..d88a32b --- /dev/null +++ b/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/consumer/SubscriptionStateBenchmark.java @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.jmh.consumer; + +import org.apache.kafka.clients.Metadata; +import org.apache.kafka.clients.consumer.OffsetResetStrategy; +import org.apache.kafka.clients.consumer.internals.SubscriptionState; +import org.apache.kafka.common.Node; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.utils.LogContext; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Level; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.OutputTimeUnit; +import org.openjdk.jmh.annotations.Param; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.Warmup; + +import java.util.HashSet; +import java.util.Optional; +import java.util.Set; +import java.util.concurrent.TimeUnit; +import java.util.stream.IntStream; + +@State(Scope.Benchmark) +@Fork(value = 1) +@Warmup(iterations = 5) +@Measurement(iterations = 15) +@BenchmarkMode(Mode.AverageTime) +@OutputTimeUnit(TimeUnit.MILLISECONDS) +public class SubscriptionStateBenchmark { + @Param({"5000"}) + int topicCount; + + @Param({"50"}) + int partitionCount; + + SubscriptionState subscriptionState; + + @Setup(Level.Trial) + public void setup() { + Set assignment = new HashSet<>(topicCount * partitionCount); + IntStream.range(0, topicCount).forEach(topicId -> + IntStream.range(0, partitionCount).forEach(partitionId -> + assignment.add(new TopicPartition(String.format("topic-%04d", topicId), partitionId)) + ) + ); + subscriptionState = new SubscriptionState(new LogContext(), OffsetResetStrategy.EARLIEST); + subscriptionState.assignFromUser(assignment); + SubscriptionState.FetchPosition position = new SubscriptionState.FetchPosition( + 0L, + Optional.of(0), + new Metadata.LeaderAndEpoch(Optional.of(new Node(0, "host", 9092)), Optional.of(10)) + ); + assignment.forEach(topicPartition -> { + subscriptionState.seekUnvalidated(topicPartition, position); + subscriptionState.completeValidation(topicPartition); + }); + } + + @Benchmark + public boolean testHasAllFetchPositions() { + return subscriptionState.hasAllFetchPositions(); + } + + @Benchmark + public int testFetchablePartitions() { + return subscriptionState.fetchablePartitions(tp -> true).size(); + } + + @Benchmark + public int testPartitionsNeedingValidation() { + return subscriptionState.partitionsNeedingValidation(0L).size(); + } +} diff --git a/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/fetcher/ReplicaFetcherThreadBenchmark.java b/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/fetcher/ReplicaFetcherThreadBenchmark.java new file mode 100644 index 0000000..7f03788 --- /dev/null +++ b/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/fetcher/ReplicaFetcherThreadBenchmark.java @@ -0,0 +1,375 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.jmh.fetcher; + +import kafka.api.ApiVersion; +import kafka.api.ApiVersion$; +import kafka.cluster.BrokerEndPoint; +import kafka.cluster.DelayedOperations; +import kafka.cluster.IsrChangeListener; +import kafka.cluster.Partition; +import kafka.log.CleanerConfig; +import kafka.log.Defaults; +import kafka.log.LogAppendInfo; +import kafka.log.LogConfig; +import kafka.log.LogManager; +import kafka.server.AlterIsrManager; +import kafka.server.BrokerTopicStats; +import kafka.server.FailedPartitions; +import kafka.server.InitialFetchState; +import kafka.server.KafkaConfig; +import kafka.server.LogDirFailureChannel; +import kafka.server.MetadataCache; +import kafka.server.OffsetAndEpoch; +import kafka.server.OffsetTruncationState; +import kafka.server.QuotaFactory; +import kafka.server.ReplicaFetcherThread; +import kafka.server.ReplicaManager; +import kafka.server.ReplicaQuota; +import kafka.server.builders.LogManagerBuilder; +import kafka.server.builders.ReplicaManagerBuilder; +import kafka.server.checkpoints.OffsetCheckpoints; +import kafka.server.metadata.MockConfigRepository; +import kafka.server.metadata.ZkMetadataCache; +import kafka.utils.KafkaScheduler; +import kafka.utils.MockTime; +import kafka.utils.Pool; +import kafka.utils.TestUtils; +import kafka.zk.KafkaZkClient; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.TopicIdPartition; +import org.apache.kafka.common.Uuid; +import org.apache.kafka.common.message.FetchResponseData; +import org.apache.kafka.common.message.LeaderAndIsrRequestData; +import org.apache.kafka.common.message.OffsetForLeaderEpochRequestData.OffsetForLeaderPartition; +import org.apache.kafka.common.message.OffsetForLeaderEpochResponseData.EpochEndOffset; +import org.apache.kafka.common.message.UpdateMetadataRequestData; +import org.apache.kafka.common.metrics.Metrics; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.Errors; +import org.apache.kafka.common.record.BaseRecords; +import org.apache.kafka.common.record.RecordsSend; +import org.apache.kafka.common.requests.FetchRequest; +import org.apache.kafka.common.requests.FetchResponse; +import org.apache.kafka.common.requests.UpdateMetadataRequest; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.common.utils.Utils; +import org.mockito.Mockito; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Level; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.OutputTimeUnit; +import org.openjdk.jmh.annotations.Param; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.TearDown; +import org.openjdk.jmh.annotations.Warmup; + +import java.io.File; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import 
java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Properties; +import java.util.UUID; +import java.util.concurrent.TimeUnit; +import scala.Option; +import scala.collection.Iterator; +import scala.collection.Map; + +@State(Scope.Benchmark) +@Fork(value = 1) +@Warmup(iterations = 5) +@Measurement(iterations = 15) +@BenchmarkMode(Mode.AverageTime) +@OutputTimeUnit(TimeUnit.NANOSECONDS) + +public class ReplicaFetcherThreadBenchmark { + @Param({"100", "500", "1000", "5000"}) + private int partitionCount; + + private ReplicaFetcherBenchThread fetcher; + private LogManager logManager; + private File logDir = new File(System.getProperty("java.io.tmpdir"), UUID.randomUUID().toString()); + private KafkaScheduler scheduler = new KafkaScheduler(1, "scheduler", true); + private Pool<TopicPartition, Partition> pool = new Pool<TopicPartition, Partition>(Option.empty()); + private Metrics metrics = new Metrics(); + private ReplicaManager replicaManager; + private Option<Uuid> topicId = Option.apply(Uuid.randomUuid()); + + @Setup(Level.Trial) + public void setup() throws IOException { + if (!logDir.mkdir()) + throw new IOException("error creating test directory"); + + scheduler.startup(); + Properties props = new Properties(); + props.put("zookeeper.connect", "127.0.0.1:9999"); + KafkaConfig config = new KafkaConfig(props); + LogConfig logConfig = createLogConfig(); + + BrokerTopicStats brokerTopicStats = new BrokerTopicStats(); + LogDirFailureChannel logDirFailureChannel = Mockito.mock(LogDirFailureChannel.class); + List<File> logDirs = Collections.singletonList(logDir); + logManager = new LogManagerBuilder(). + setLogDirs(logDirs). + setInitialOfflineDirs(Collections.emptyList()). + setConfigRepository(new MockConfigRepository()). + setInitialDefaultConfig(logConfig). + setCleanerConfig(new CleanerConfig(0, 0, 0, 0, 0, 0.0, 0, false, "MD5")). + setRecoveryThreadsPerDataDir(1). + setFlushCheckMs(1000L). + setFlushRecoveryOffsetCheckpointMs(10000L). + setFlushStartOffsetCheckpointMs(10000L). + setRetentionCheckMs(1000L). + setMaxPidExpirationMs(60000). + setInterBrokerProtocolVersion(ApiVersion.latestVersion()). + setScheduler(scheduler). + setBrokerTopicStats(brokerTopicStats). + setLogDirFailureChannel(logDirFailureChannel). + setTime(Time.SYSTEM). + setKeepPartitionMetadataFile(true).
+ build(); + + LinkedHashMap<TopicIdPartition, FetchResponseData.PartitionData> initialFetched = new LinkedHashMap<>(); + HashMap<String, Uuid> topicIds = new HashMap<>(); + scala.collection.mutable.Map<TopicPartition, InitialFetchState> initialFetchStates = new scala.collection.mutable.HashMap<>(); + List<UpdateMetadataRequestData.UpdateMetadataPartitionState> updatePartitionState = new ArrayList<>(); + for (int i = 0; i < partitionCount; i++) { + TopicPartition tp = new TopicPartition("topic", i); + + List<Integer> replicas = Arrays.asList(0, 1, 2); + LeaderAndIsrRequestData.LeaderAndIsrPartitionState partitionState = new LeaderAndIsrRequestData.LeaderAndIsrPartitionState() + .setControllerEpoch(0) + .setLeader(0) + .setLeaderEpoch(0) + .setIsr(replicas) + .setZkVersion(1) + .setReplicas(replicas) + .setIsNew(true); + + IsrChangeListener isrChangeListener = Mockito.mock(IsrChangeListener.class); + OffsetCheckpoints offsetCheckpoints = Mockito.mock(OffsetCheckpoints.class); + Mockito.when(offsetCheckpoints.fetch(logDir.getAbsolutePath(), tp)).thenReturn(Option.apply(0L)); + AlterIsrManager isrChannelManager = Mockito.mock(AlterIsrManager.class); + Partition partition = new Partition(tp, 100, ApiVersion$.MODULE$.latestVersion(), + 0, Time.SYSTEM, isrChangeListener, new DelayedOperationsMock(tp), + Mockito.mock(MetadataCache.class), logManager, isrChannelManager); + + partition.makeFollower(partitionState, offsetCheckpoints, topicId); + pool.put(tp, partition); + initialFetchStates.put(tp, new InitialFetchState(topicId, new BrokerEndPoint(3, "host", 3000), 0, 0)); + BaseRecords fetched = new BaseRecords() { + @Override + public int sizeInBytes() { + return 0; + } + + @Override + public RecordsSend<? extends BaseRecords> toSend() { + return null; + } + }; + initialFetched.put(new TopicIdPartition(topicId.get(), tp), new FetchResponseData.PartitionData() + .setPartitionIndex(tp.partition()) + .setLastStableOffset(0) + .setLogStartOffset(0) + .setRecords(fetched)); + + updatePartitionState.add( + new UpdateMetadataRequestData.UpdateMetadataPartitionState() + .setTopicName("topic") + .setPartitionIndex(i) + .setControllerEpoch(0) + .setLeader(0) + .setLeaderEpoch(0) + .setIsr(replicas) + .setZkVersion(1) + .setReplicas(replicas)); + } + UpdateMetadataRequest updateMetadataRequest = new UpdateMetadataRequest.Builder(ApiKeys.UPDATE_METADATA.latestVersion(), + 0, 0, 0, updatePartitionState, Collections.emptyList(), topicIds).build(); + + // TODO: fix to support raft + ZkMetadataCache metadataCache = new ZkMetadataCache(0); + metadataCache.updateMetadata(0, updateMetadataRequest); + + replicaManager = new ReplicaManagerBuilder(). + setConfig(config). + setMetrics(metrics). + setTime(new MockTime()). + setZkClient(Mockito.mock(KafkaZkClient.class)). + setScheduler(scheduler). + setLogManager(logManager). + setQuotaManagers(Mockito.mock(QuotaFactory.QuotaManagers.class)). + setBrokerTopicStats(brokerTopicStats). + setMetadataCache(metadataCache). + setLogDirFailureChannel(new LogDirFailureChannel(logDirs.size())). + setAlterIsrManager(TestUtils.createAlterIsrManager()). + build(); + fetcher = new ReplicaFetcherBenchThread(config, replicaManager, pool); + fetcher.addPartitions(initialFetchStates); + // force a pass to move partitions to fetching state.
We do this in the setup phase + // so that we do not measure this time as part of the steady state work + fetcher.doWork(); + // handle response to engage the incremental fetch session handler + fetcher.fetchSessionHandler().handleResponse(FetchResponse.of(Errors.NONE, 0, 999, initialFetched), ApiKeys.FETCH.latestVersion()); + } + + @TearDown(Level.Trial) + public void tearDown() throws IOException { + metrics.close(); + replicaManager.shutdown(false); + logManager.shutdown(); + scheduler.shutdown(); + Utils.delete(logDir); + } + + @Benchmark + public long testFetcher() { + fetcher.doWork(); + return fetcher.fetcherStats().requestRate().count(); + } + + // avoid mocked DelayedOperations to avoid mocked class affecting benchmark results + private static class DelayedOperationsMock extends DelayedOperations { + DelayedOperationsMock(TopicPartition topicPartition) { + super(topicPartition, null, null, null); + } + + @Override + public int numDelayedDelete() { + return 0; + } + } + + private static LogConfig createLogConfig() { + Properties logProps = new Properties(); + logProps.put(LogConfig.SegmentMsProp(), Defaults.SegmentMs()); + logProps.put(LogConfig.SegmentBytesProp(), Defaults.SegmentSize()); + logProps.put(LogConfig.RetentionMsProp(), Defaults.RetentionMs()); + logProps.put(LogConfig.RetentionBytesProp(), Defaults.RetentionSize()); + logProps.put(LogConfig.SegmentJitterMsProp(), Defaults.SegmentJitterMs()); + logProps.put(LogConfig.CleanupPolicyProp(), Defaults.CleanupPolicy()); + logProps.put(LogConfig.MaxMessageBytesProp(), Defaults.MaxMessageSize()); + logProps.put(LogConfig.IndexIntervalBytesProp(), Defaults.IndexInterval()); + logProps.put(LogConfig.SegmentIndexBytesProp(), Defaults.MaxIndexSize()); + logProps.put(LogConfig.FileDeleteDelayMsProp(), Defaults.FileDeleteDelayMs()); + return LogConfig.apply(logProps, new scala.collection.immutable.HashSet<>()); + } + + + static class ReplicaFetcherBenchThread extends ReplicaFetcherThread { + private final Pool pool; + + ReplicaFetcherBenchThread(KafkaConfig config, + ReplicaManager replicaManager, + Pool partitions) { + super("name", + 3, + new BrokerEndPoint(3, "host", 3000), + config, + new FailedPartitions(), + replicaManager, + new Metrics(), + Time.SYSTEM, + new ReplicaQuota() { + @Override + public boolean isQuotaExceeded() { + return false; + } + + @Override + public void record(long value) { + } + + @Override + public boolean isThrottled(TopicPartition topicPartition) { + return false; + } + }, + Option.empty()); + + pool = partitions; + } + + @Override + public Option latestEpoch(TopicPartition topicPartition) { + return Option.apply(0); + } + + @Override + public long logStartOffset(TopicPartition topicPartition) { + return pool.get(topicPartition).localLogOrException().logStartOffset(); + } + + @Override + public long logEndOffset(TopicPartition topicPartition) { + return 0; + } + + @Override + public void truncate(TopicPartition tp, OffsetTruncationState offsetTruncationState) { + // pretend to truncate to move to Fetching state + } + + @Override + public Option endOffsetForEpoch(TopicPartition topicPartition, int epoch) { + return Option.apply(new OffsetAndEpoch(0, 0)); + } + + @Override + public Option processPartitionData(TopicPartition topicPartition, long fetchOffset, + FetchResponseData.PartitionData partitionData) { + return Option.empty(); + } + + @Override + public long fetchEarliestOffsetFromLeader(TopicPartition topicPartition, int currentLeaderEpoch) { + return 0; + } + + @Override + public Map 
<TopicPartition, EpochEndOffset> fetchEpochEndOffsets(Map<TopicPartition, OffsetForLeaderPartition> partitions) { + scala.collection.mutable.Map<TopicPartition, EpochEndOffset> endOffsets = new scala.collection.mutable.HashMap<>(); + Iterator<TopicPartition> iterator = partitions.keys().iterator(); + while (iterator.hasNext()) { + TopicPartition tp = iterator.next(); + endOffsets.put(tp, new EpochEndOffset() + .setPartition(tp.partition()) + .setErrorCode(Errors.NONE.code()) + .setLeaderEpoch(0) + .setEndOffset(100)); + } + return endOffsets; + } + + @Override + public Map<TopicPartition, FetchResponseData.PartitionData> fetchFromLeader(FetchRequest.Builder fetchRequest) { + return new scala.collection.mutable.HashMap<>(); + } + } +} diff --git a/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/fetchsession/FetchSessionBenchmark.java b/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/fetchsession/FetchSessionBenchmark.java new file mode 100644 index 0000000..26216b9 --- /dev/null +++ b/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/fetchsession/FetchSessionBenchmark.java @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.jmh.fetchsession; + +import org.apache.kafka.clients.FetchSessionHandler; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.TopicIdPartition; +import org.apache.kafka.common.Uuid; +import org.apache.kafka.common.message.FetchResponseData; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.Errors; +import org.apache.kafka.common.requests.FetchRequest; +import org.apache.kafka.common.requests.FetchResponse; +import org.apache.kafka.common.utils.LogContext; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Level; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.OutputTimeUnit; +import org.openjdk.jmh.annotations.Param; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.Warmup; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.Map; +import java.util.Optional; +import java.util.concurrent.TimeUnit; + +@State(Scope.Benchmark) +@Fork(value = 1) +@Warmup(iterations = 5) +@Measurement(iterations = 10) +@BenchmarkMode(Mode.AverageTime) +@OutputTimeUnit(TimeUnit.NANOSECONDS) +public class FetchSessionBenchmark { + private static final LogContext LOG_CONTEXT = new LogContext("[BenchFetchSessionHandler]="); + + @Param(value = {"10", "100", "1000"}) + private int partitionCount; + + @Param(value = {"0", "10", "100"}) + private int updatedPercentage; + + @Param(value = {"false", "true"}) + private boolean presize; + +
private LinkedHashMap<TopicPartition, FetchRequest.PartitionData> fetches; + private FetchSessionHandler handler; + private Map<String, Uuid> topicIds; + + @Setup(Level.Trial) + public void setUp() { + fetches = new LinkedHashMap<>(); + handler = new FetchSessionHandler(LOG_CONTEXT, 1); + topicIds = new HashMap<>(); + FetchSessionHandler.Builder builder = handler.newBuilder(); + + Uuid id = Uuid.randomUuid(); + topicIds.put("foo", id); + + LinkedHashMap<TopicIdPartition, FetchResponseData.PartitionData> respMap = new LinkedHashMap<>(); + for (int i = 0; i < partitionCount; i++) { + TopicPartition tp = new TopicPartition("foo", i); + FetchRequest.PartitionData partitionData = new FetchRequest.PartitionData(id, 0, 0, 200, Optional.empty()); + fetches.put(tp, partitionData); + builder.add(tp, partitionData); + respMap.put(new TopicIdPartition(id, tp), new FetchResponseData.PartitionData() + .setPartitionIndex(tp.partition()) + .setLastStableOffset(0) + .setLogStartOffset(0)); + } + builder.build(); + // build and handle an initial response so that the next fetch will be incremental + handler.handleResponse(FetchResponse.of(Errors.NONE, 0, 1, respMap), ApiKeys.FETCH.latestVersion()); + + int counter = 0; + for (TopicPartition topicPartition: new ArrayList<>(fetches.keySet())) { + if (updatedPercentage != 0 && counter % (100 / updatedPercentage) == 0) { + // reorder in fetch session, and update log start offset + fetches.remove(topicPartition); + fetches.put(topicPartition, new FetchRequest.PartitionData(Uuid.ZERO_UUID, 50, 40, 200, + Optional.empty())); + } + counter++; + } + } + + @Benchmark + @OutputTimeUnit(TimeUnit.NANOSECONDS) + public void incrementalFetchSessionBuild() { + FetchSessionHandler.Builder builder; + if (presize) + builder = handler.newBuilder(fetches.size(), true); + else + builder = handler.newBuilder(); + + // Should we keep lookup to mimic how adding really works? + for (Map.Entry<TopicPartition, FetchRequest.PartitionData> entry: fetches.entrySet()) { + TopicPartition topicPartition = entry.getKey(); + builder.add(topicPartition, entry.getValue()); + } + + builder.build(); + } +} diff --git a/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/metadata/MetadataRequestBenchmark.java b/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/metadata/MetadataRequestBenchmark.java new file mode 100644 index 0000000..83dd7eb --- /dev/null +++ b/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/metadata/MetadataRequestBenchmark.java @@ -0,0 +1,232 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package org.apache.kafka.jmh.metadata; + +import kafka.controller.KafkaController; +import kafka.coordinator.group.GroupCoordinator; +import kafka.coordinator.transaction.TransactionCoordinator; +import kafka.network.RequestChannel; +import kafka.network.RequestConvertToJson; +import kafka.server.AutoTopicCreationManager; +import kafka.server.BrokerTopicStats; +import kafka.server.ClientQuotaManager; +import kafka.server.ClientRequestQuotaManager; +import kafka.server.ControllerMutationQuotaManager; +import kafka.server.FetchManager; +import kafka.server.KafkaApis; +import kafka.server.KafkaConfig; +import kafka.server.KafkaConfig$; +import kafka.server.MetadataCache; +import kafka.server.QuotaFactory; +import kafka.server.ReplicaManager; +import kafka.server.ReplicationQuotaManager; +import kafka.server.SimpleApiVersionManager; +import kafka.server.ZkAdminManager; +import kafka.server.ZkSupport; +import kafka.server.builders.KafkaApisBuilder; +import kafka.server.metadata.MockConfigRepository; +import kafka.server.metadata.ZkMetadataCache; +import kafka.zk.KafkaZkClient; +import org.apache.kafka.common.memory.MemoryPool; +import org.apache.kafka.common.message.ApiMessageType; +import org.apache.kafka.common.message.UpdateMetadataRequestData.UpdateMetadataBroker; +import org.apache.kafka.common.message.UpdateMetadataRequestData.UpdateMetadataEndpoint; +import org.apache.kafka.common.message.UpdateMetadataRequestData.UpdateMetadataPartitionState; +import org.apache.kafka.common.metrics.Metrics; +import org.apache.kafka.common.network.ClientInformation; +import org.apache.kafka.common.network.ListenerName; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.requests.MetadataRequest; +import org.apache.kafka.common.requests.RequestContext; +import org.apache.kafka.common.requests.RequestHeader; +import org.apache.kafka.common.requests.UpdateMetadataRequest; +import org.apache.kafka.common.security.auth.KafkaPrincipal; +import org.apache.kafka.common.security.auth.SecurityProtocol; +import org.apache.kafka.common.utils.Time; +import org.mockito.Mockito; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Level; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.OutputTimeUnit; +import org.openjdk.jmh.annotations.Param; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.TearDown; +import org.openjdk.jmh.annotations.Warmup; +import scala.Option; + +import java.nio.ByteBuffer; +import java.util.Arrays; +import java.util.Collections; +import java.util.LinkedList; +import java.util.List; +import java.util.Optional; +import java.util.Properties; +import java.util.concurrent.TimeUnit; +import java.util.stream.IntStream; + +@State(Scope.Benchmark) +@Fork(value = 1) +@Warmup(iterations = 5) +@Measurement(iterations = 15) +@BenchmarkMode(Mode.AverageTime) +@OutputTimeUnit(TimeUnit.NANOSECONDS) + +public class MetadataRequestBenchmark { + @Param({"500", "1000", "5000"}) + private int topicCount; + @Param({"10", "20", "50"}) + private int partitionCount; + + private RequestChannel requestChannel = Mockito.mock(RequestChannel.class, Mockito.withSettings().stubOnly()); + private RequestChannel.Metrics requestChannelMetrics = 
Mockito.mock(RequestChannel.Metrics.class); + private ReplicaManager replicaManager = Mockito.mock(ReplicaManager.class); + private GroupCoordinator groupCoordinator = Mockito.mock(GroupCoordinator.class); + private ZkAdminManager adminManager = Mockito.mock(ZkAdminManager.class); + private TransactionCoordinator transactionCoordinator = Mockito.mock(TransactionCoordinator.class); + private KafkaController kafkaController = Mockito.mock(KafkaController.class); + private AutoTopicCreationManager autoTopicCreationManager = Mockito.mock(AutoTopicCreationManager.class); + private KafkaZkClient kafkaZkClient = Mockito.mock(KafkaZkClient.class); + private Metrics metrics = new Metrics(); + private int brokerId = 1; + private ZkMetadataCache metadataCache = MetadataCache.zkMetadataCache(brokerId); + private ClientQuotaManager clientQuotaManager = Mockito.mock(ClientQuotaManager.class); + private ClientRequestQuotaManager clientRequestQuotaManager = Mockito.mock(ClientRequestQuotaManager.class); + private ControllerMutationQuotaManager controllerMutationQuotaManager = Mockito.mock(ControllerMutationQuotaManager.class); + private ReplicationQuotaManager replicaQuotaManager = Mockito.mock(ReplicationQuotaManager.class); + private QuotaFactory.QuotaManagers quotaManagers = new QuotaFactory.QuotaManagers(clientQuotaManager, + clientQuotaManager, clientRequestQuotaManager, controllerMutationQuotaManager, replicaQuotaManager, + replicaQuotaManager, replicaQuotaManager, Option.empty()); + private FetchManager fetchManager = Mockito.mock(FetchManager.class); + private BrokerTopicStats brokerTopicStats = new BrokerTopicStats(); + private KafkaPrincipal principal = new KafkaPrincipal(KafkaPrincipal.USER_TYPE, "test-user"); + private KafkaApis kafkaApis; + private RequestChannel.Request allTopicMetadataRequest; + + @Setup(Level.Trial) + public void setup() { + initializeMetadataCache(); + kafkaApis = createKafkaApis(); + allTopicMetadataRequest = buildAllTopicMetadataRequest(); + } + + private void initializeMetadataCache() { + List liveBrokers = new LinkedList<>(); + List partitionStates = new LinkedList<>(); + + IntStream.range(0, 5).forEach(brokerId -> liveBrokers.add( + new UpdateMetadataBroker().setId(brokerId) + .setEndpoints(endpoints(brokerId)) + .setRack("rack1"))); + + IntStream.range(0, topicCount).forEach(topicId -> { + String topicName = "topic-" + topicId; + + IntStream.range(0, partitionCount).forEach(partitionId -> { + partitionStates.add( + new UpdateMetadataPartitionState().setTopicName(topicName) + .setPartitionIndex(partitionId) + .setControllerEpoch(1) + .setLeader(partitionCount % 5) + .setLeaderEpoch(0) + .setIsr(Arrays.asList(0, 1, 3)) + .setZkVersion(1) + .setReplicas(Arrays.asList(0, 1, 3))); + }); + }); + + UpdateMetadataRequest updateMetadataRequest = new UpdateMetadataRequest.Builder( + ApiKeys.UPDATE_METADATA.latestVersion(), + 1, 1, 1, + partitionStates, liveBrokers, Collections.emptyMap()).build(); + metadataCache.updateMetadata(100, updateMetadataRequest); + } + + private List endpoints(final int brokerId) { + return Collections.singletonList( + new UpdateMetadataEndpoint() + .setHost("host_" + brokerId) + .setPort(9092) + .setSecurityProtocol(SecurityProtocol.PLAINTEXT.id) + .setListener(ListenerName.forSecurityProtocol(SecurityProtocol.PLAINTEXT).value())); + } + + private KafkaApis createKafkaApis() { + Properties kafkaProps = new Properties(); + kafkaProps.put(KafkaConfig$.MODULE$.ZkConnectProp(), "zk"); + kafkaProps.put(KafkaConfig$.MODULE$.BrokerIdProp(), brokerId + 
""); + KafkaConfig config = new KafkaConfig(kafkaProps); + return new KafkaApisBuilder(). + setRequestChannel(requestChannel). + setMetadataSupport(new ZkSupport(adminManager, kafkaController, kafkaZkClient, Option.empty(), metadataCache)). + setReplicaManager(replicaManager). + setGroupCoordinator(groupCoordinator). + setTxnCoordinator(transactionCoordinator). + setAutoTopicCreationManager(autoTopicCreationManager). + setBrokerId(brokerId). + setConfig(config). + setConfigRepository(new MockConfigRepository()). + setMetadataCache(metadataCache). + setMetrics(metrics). + setAuthorizer(Optional.empty()). + setQuotas(quotaManagers). + setFetchManager(fetchManager). + setBrokerTopicStats(brokerTopicStats). + setClusterId("clusterId"). + setTime(Time.SYSTEM). + setTokenManager(null). + setApiVersionManager(new SimpleApiVersionManager(ApiMessageType.ListenerType.ZK_BROKER)). + build(); + } + + @TearDown(Level.Trial) + public void tearDown() { + kafkaApis.close(); + metrics.close(); + } + + private RequestChannel.Request buildAllTopicMetadataRequest() { + MetadataRequest metadataRequest = MetadataRequest.Builder.allTopics().build(); + RequestHeader header = new RequestHeader(metadataRequest.apiKey(), metadataRequest.version(), "", 0); + ByteBuffer bodyBuffer = metadataRequest.serialize(); + + RequestContext context = new RequestContext(header, "1", null, principal, + ListenerName.forSecurityProtocol(SecurityProtocol.PLAINTEXT), + SecurityProtocol.PLAINTEXT, ClientInformation.EMPTY, false); + return new RequestChannel.Request(1, context, 0, MemoryPool.NONE, bodyBuffer, requestChannelMetrics, Option.empty()); + } + + @Benchmark + public void testMetadataRequestForAllTopics() { + kafkaApis.handleTopicMetadataRequest(allTopicMetadataRequest); + } + + @Benchmark + public String testRequestToJson() { + return RequestConvertToJson.requestDesc(allTopicMetadataRequest.header(), allTopicMetadataRequest.requestLog(), allTopicMetadataRequest.isForwarded()).toString(); + } + + @Benchmark + public void testTopicIdInfo() { + metadataCache.topicIdInfo(); + } +} diff --git a/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/partition/PartitionMakeFollowerBenchmark.java b/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/partition/PartitionMakeFollowerBenchmark.java new file mode 100644 index 0000000..61a94c3 --- /dev/null +++ b/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/partition/PartitionMakeFollowerBenchmark.java @@ -0,0 +1,181 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.jmh.partition; + +import kafka.api.ApiVersion; +import kafka.api.ApiVersion$; +import kafka.cluster.DelayedOperations; +import kafka.cluster.IsrChangeListener; +import kafka.cluster.Partition; +import kafka.log.CleanerConfig; +import kafka.log.Defaults; +import kafka.log.LogConfig; +import kafka.log.LogManager; +import kafka.server.AlterIsrManager; +import kafka.server.BrokerTopicStats; +import kafka.server.LogDirFailureChannel; +import kafka.server.MetadataCache; +import kafka.server.builders.LogManagerBuilder; +import kafka.server.checkpoints.OffsetCheckpoints; +import kafka.server.metadata.MockConfigRepository; +import kafka.utils.KafkaScheduler; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.Uuid; +import org.apache.kafka.common.message.LeaderAndIsrRequestData; +import org.apache.kafka.common.record.CompressionType; +import org.apache.kafka.common.record.MemoryRecords; +import org.apache.kafka.common.record.SimpleRecord; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.common.utils.Utils; +import org.mockito.Mockito; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Level; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.OutputTimeUnit; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.TearDown; +import org.openjdk.jmh.annotations.Warmup; + +import java.io.File; +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Optional; +import java.util.Properties; +import java.util.UUID; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; +import scala.Option; +import scala.compat.java8.OptionConverters; + +@State(Scope.Benchmark) +@Fork(value = 1) +@Warmup(iterations = 5) +@Measurement(iterations = 15) +@BenchmarkMode(Mode.AverageTime) +@OutputTimeUnit(TimeUnit.NANOSECONDS) + +public class PartitionMakeFollowerBenchmark { + private LogManager logManager; + private File logDir = new File(System.getProperty("java.io.tmpdir"), UUID.randomUUID().toString()); + private KafkaScheduler scheduler = new KafkaScheduler(1, "scheduler", true); + private Partition partition; + private List replicas = Arrays.asList(0, 1, 2); + private OffsetCheckpoints offsetCheckpoints = Mockito.mock(OffsetCheckpoints.class); + private DelayedOperations delayedOperations = Mockito.mock(DelayedOperations.class); + private ExecutorService executorService = Executors.newSingleThreadExecutor(); + private Option topicId; + + @Setup(Level.Trial) + public void setup() throws IOException { + if (!logDir.mkdir()) + throw new IOException("error creating test directory"); + + scheduler.startup(); + LogConfig logConfig = createLogConfig(); + + BrokerTopicStats brokerTopicStats = new BrokerTopicStats(); + LogDirFailureChannel logDirFailureChannel = Mockito.mock(LogDirFailureChannel.class); + logManager = new LogManagerBuilder(). + setLogDirs(Collections.singletonList(logDir)). + setInitialOfflineDirs(Collections.emptyList()). + setConfigRepository(new MockConfigRepository()). + setInitialDefaultConfig(logConfig). 
+ setCleanerConfig(new CleanerConfig(0, 0, 0, 0, 0, 0.0, 0, false, "MD5")). + setRecoveryThreadsPerDataDir(1). + setFlushCheckMs(1000L). + setFlushRecoveryOffsetCheckpointMs(10000L). + setFlushStartOffsetCheckpointMs(10000L). + setRetentionCheckMs(1000L). + setMaxPidExpirationMs(60000). + setInterBrokerProtocolVersion(ApiVersion.latestVersion()). + setScheduler(scheduler). + setBrokerTopicStats(brokerTopicStats). + setLogDirFailureChannel(logDirFailureChannel). + setTime(Time.SYSTEM).setKeepPartitionMetadataFile(true). + build(); + + TopicPartition tp = new TopicPartition("topic", 0); + topicId = OptionConverters.toScala(Optional.of(Uuid.randomUuid())); + + Mockito.when(offsetCheckpoints.fetch(logDir.getAbsolutePath(), tp)).thenReturn(Option.apply(0L)); + IsrChangeListener isrChangeListener = Mockito.mock(IsrChangeListener.class); + AlterIsrManager alterIsrManager = Mockito.mock(AlterIsrManager.class); + partition = new Partition(tp, 100, + ApiVersion$.MODULE$.latestVersion(), 0, Time.SYSTEM, + isrChangeListener, delayedOperations, + Mockito.mock(MetadataCache.class), logManager, alterIsrManager); + partition.createLogIfNotExists(true, false, offsetCheckpoints, topicId); + executorService.submit((Runnable) () -> { + SimpleRecord[] simpleRecords = new SimpleRecord[] { + new SimpleRecord(1L, "foo".getBytes(StandardCharsets.UTF_8), "1".getBytes(StandardCharsets.UTF_8)), + new SimpleRecord(2L, "bar".getBytes(StandardCharsets.UTF_8), "2".getBytes(StandardCharsets.UTF_8)) + }; + int initialOffSet = 0; + while (true) { + MemoryRecords memoryRecords = MemoryRecords.withRecords(initialOffSet, CompressionType.NONE, 0, simpleRecords); + partition.appendRecordsToFollowerOrFutureReplica(memoryRecords, false); + initialOffSet = initialOffSet + 2; + } + }); + } + + @TearDown(Level.Trial) + public void tearDown() throws IOException { + executorService.shutdownNow(); + logManager.shutdown(); + scheduler.shutdown(); + Utils.delete(logDir); + } + + @Benchmark + public boolean testMakeFollower() { + LeaderAndIsrRequestData.LeaderAndIsrPartitionState partitionState = new LeaderAndIsrRequestData.LeaderAndIsrPartitionState() + .setControllerEpoch(0) + .setLeader(0) + .setLeaderEpoch(0) + .setIsr(replicas) + .setZkVersion(1) + .setReplicas(replicas) + .setIsNew(true); + return partition.makeFollower(partitionState, offsetCheckpoints, topicId); + } + + private static LogConfig createLogConfig() { + Properties logProps = new Properties(); + logProps.put(LogConfig.SegmentMsProp(), Defaults.SegmentMs()); + logProps.put(LogConfig.SegmentBytesProp(), Defaults.SegmentSize()); + logProps.put(LogConfig.RetentionMsProp(), Defaults.RetentionMs()); + logProps.put(LogConfig.RetentionBytesProp(), Defaults.RetentionSize()); + logProps.put(LogConfig.SegmentJitterMsProp(), Defaults.SegmentJitterMs()); + logProps.put(LogConfig.CleanupPolicyProp(), Defaults.CleanupPolicy()); + logProps.put(LogConfig.MaxMessageBytesProp(), Defaults.MaxMessageSize()); + logProps.put(LogConfig.IndexIntervalBytesProp(), Defaults.IndexInterval()); + logProps.put(LogConfig.SegmentIndexBytesProp(), Defaults.MaxIndexSize()); + logProps.put(LogConfig.FileDeleteDelayMsProp(), Defaults.FileDeleteDelayMs()); + return LogConfig.apply(logProps, new scala.collection.immutable.HashSet<>()); + } +} diff --git a/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/partition/UpdateFollowerFetchStateBenchmark.java b/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/partition/UpdateFollowerFetchStateBenchmark.java new file mode 100644 index 0000000..f416755 --- /dev/null 
+++ b/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/partition/UpdateFollowerFetchStateBenchmark.java @@ -0,0 +1,187 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.jmh.partition; + +import kafka.api.ApiVersion; +import kafka.api.ApiVersion$; +import kafka.cluster.DelayedOperations; +import kafka.cluster.IsrChangeListener; +import kafka.cluster.Partition; +import kafka.log.CleanerConfig; +import kafka.log.Defaults; +import kafka.log.LogConfig; +import kafka.log.LogManager; +import kafka.server.AlterIsrManager; +import kafka.server.BrokerTopicStats; +import kafka.server.LogDirFailureChannel; +import kafka.server.LogOffsetMetadata; +import kafka.server.MetadataCache; +import kafka.server.builders.LogManagerBuilder; +import kafka.server.checkpoints.OffsetCheckpoints; +import kafka.server.metadata.MockConfigRepository; +import kafka.utils.KafkaScheduler; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.Uuid; +import org.apache.kafka.common.message.LeaderAndIsrRequestData.LeaderAndIsrPartitionState; +import org.apache.kafka.common.utils.Time; +import org.mockito.Mockito; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Level; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.OutputTimeUnit; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.TearDown; +import org.openjdk.jmh.annotations.Warmup; + +import java.io.File; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Optional; +import java.util.Properties; +import java.util.UUID; +import java.util.concurrent.TimeUnit; +import scala.Option; +import scala.compat.java8.OptionConverters; + +@State(Scope.Benchmark) +@Fork(value = 1) +@Warmup(iterations = 5) +@Measurement(iterations = 15) +@BenchmarkMode(Mode.AverageTime) +@OutputTimeUnit(TimeUnit.NANOSECONDS) +public class UpdateFollowerFetchStateBenchmark { + private TopicPartition topicPartition = new TopicPartition(UUID.randomUUID().toString(), 0); + private Option topicId = OptionConverters.toScala(Optional.of(Uuid.randomUuid())); + private File logDir = new File(System.getProperty("java.io.tmpdir"), topicPartition.toString()); + private KafkaScheduler scheduler = new KafkaScheduler(1, "scheduler", true); + private BrokerTopicStats brokerTopicStats = new BrokerTopicStats(); + private LogDirFailureChannel logDirFailureChannel = Mockito.mock(LogDirFailureChannel.class); + private long nextOffset = 0; + private LogManager 
logManager; + private Partition partition; + + @Setup(Level.Trial) + public void setUp() { + scheduler.startup(); + LogConfig logConfig = createLogConfig(); + logManager = new LogManagerBuilder(). + setLogDirs(Collections.singletonList(logDir)). + setInitialOfflineDirs(Collections.emptyList()). + setConfigRepository(new MockConfigRepository()). + setInitialDefaultConfig(logConfig). + setCleanerConfig(new CleanerConfig(0, 0, 0, 0, 0, 0.0, 0, false, "MD5")). + setRecoveryThreadsPerDataDir(1). + setFlushCheckMs(1000L). + setFlushRecoveryOffsetCheckpointMs(10000L). + setFlushStartOffsetCheckpointMs(10000L). + setRetentionCheckMs(1000L). + setMaxPidExpirationMs(60000). + setInterBrokerProtocolVersion(ApiVersion.latestVersion()). + setScheduler(scheduler). + setBrokerTopicStats(brokerTopicStats). + setLogDirFailureChannel(logDirFailureChannel). + setTime(Time.SYSTEM). + setKeepPartitionMetadataFile(true). + build(); + OffsetCheckpoints offsetCheckpoints = Mockito.mock(OffsetCheckpoints.class); + Mockito.when(offsetCheckpoints.fetch(logDir.getAbsolutePath(), topicPartition)).thenReturn(Option.apply(0L)); + DelayedOperations delayedOperations = new DelayedOperationsMock(); + + // one leader, plus two followers + List replicas = new ArrayList<>(); + replicas.add(0); + replicas.add(1); + replicas.add(2); + LeaderAndIsrPartitionState partitionState = new LeaderAndIsrPartitionState() + .setControllerEpoch(0) + .setLeader(0) + .setLeaderEpoch(0) + .setIsr(replicas) + .setZkVersion(1) + .setReplicas(replicas) + .setIsNew(true); + IsrChangeListener isrChangeListener = Mockito.mock(IsrChangeListener.class); + AlterIsrManager alterIsrManager = Mockito.mock(AlterIsrManager.class); + partition = new Partition(topicPartition, 100, + ApiVersion$.MODULE$.latestVersion(), 0, Time.SYSTEM, + isrChangeListener, delayedOperations, + Mockito.mock(MetadataCache.class), logManager, alterIsrManager); + partition.makeLeader(partitionState, offsetCheckpoints, topicId); + } + + // avoid mocked DelayedOperations to avoid mocked class affecting benchmark results + private class DelayedOperationsMock extends DelayedOperations { + DelayedOperationsMock() { + super(topicPartition, null, null, null); + } + + @Override + public int numDelayedDelete() { + return 0; + } + } + + @TearDown(Level.Trial) + public void tearDown() { + logManager.shutdown(); + scheduler.shutdown(); + } + + private LogConfig createLogConfig() { + Properties logProps = new Properties(); + logProps.put(LogConfig.SegmentMsProp(), Defaults.SegmentMs()); + logProps.put(LogConfig.SegmentBytesProp(), Defaults.SegmentSize()); + logProps.put(LogConfig.RetentionMsProp(), Defaults.RetentionMs()); + logProps.put(LogConfig.RetentionBytesProp(), Defaults.RetentionSize()); + logProps.put(LogConfig.SegmentJitterMsProp(), Defaults.SegmentJitterMs()); + logProps.put(LogConfig.CleanupPolicyProp(), Defaults.CleanupPolicy()); + logProps.put(LogConfig.MaxMessageBytesProp(), Defaults.MaxMessageSize()); + logProps.put(LogConfig.IndexIntervalBytesProp(), Defaults.IndexInterval()); + logProps.put(LogConfig.SegmentIndexBytesProp(), Defaults.MaxIndexSize()); + logProps.put(LogConfig.FileDeleteDelayMsProp(), Defaults.FileDeleteDelayMs()); + return LogConfig.apply(logProps, new scala.collection.immutable.HashSet<>()); + } + + @Benchmark + @OutputTimeUnit(TimeUnit.NANOSECONDS) + public void updateFollowerFetchStateBench() { + // measure the impact of two follower fetches on the leader + partition.updateFollowerFetchState(1, new LogOffsetMetadata(nextOffset, nextOffset, 0), + 0, 1, 
nextOffset); + partition.updateFollowerFetchState(2, new LogOffsetMetadata(nextOffset, nextOffset, 0), + 0, 1, nextOffset); + nextOffset++; + } + + @Benchmark + @OutputTimeUnit(TimeUnit.NANOSECONDS) + public void updateFollowerFetchStateBenchNoChange() { + // measure the impact of two follower fetches on the leader when the follower didn't + // end up fetching anything + partition.updateFollowerFetchState(1, new LogOffsetMetadata(nextOffset, nextOffset, 0), + 0, 1, 100); + partition.updateFollowerFetchState(2, new LogOffsetMetadata(nextOffset, nextOffset, 0), + 0, 1, 100); + } +} diff --git a/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/producer/ProducerRecordBenchmark.java b/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/producer/ProducerRecordBenchmark.java new file mode 100644 index 0000000..83d5c2b --- /dev/null +++ b/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/producer/ProducerRecordBenchmark.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.jmh.producer; + +import org.apache.kafka.clients.producer.ProducerRecord; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.OutputTimeUnit; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.Warmup; + +import java.util.concurrent.TimeUnit; + +@State(Scope.Benchmark) +@Fork(value = 1) +@Warmup(iterations = 5) +@Measurement(iterations = 15) +@BenchmarkMode(Mode.AverageTime) +@OutputTimeUnit(TimeUnit.NANOSECONDS) +public class ProducerRecordBenchmark { + + @Benchmark + @OutputTimeUnit(TimeUnit.NANOSECONDS) + public ProducerRecord constructorBenchmark() { + return new ProducerRecord<>("topic", "value"); + } + +} diff --git a/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/producer/ProducerRequestBenchmark.java b/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/producer/ProducerRequestBenchmark.java new file mode 100644 index 0000000..22d4955 --- /dev/null +++ b/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/producer/ProducerRequestBenchmark.java @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.jmh.producer; + +import org.apache.kafka.common.message.ProduceRequestData; +import org.apache.kafka.common.protocol.Errors; +import org.apache.kafka.common.record.CompressionType; +import org.apache.kafka.common.record.MemoryRecords; +import org.apache.kafka.common.record.RecordBatch; +import org.apache.kafka.common.record.SimpleRecord; +import org.apache.kafka.common.requests.ProduceRequest; +import org.apache.kafka.common.requests.ProduceResponse; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.OutputTimeUnit; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.Warmup; + +import java.nio.charset.StandardCharsets; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +@State(Scope.Benchmark) +@Fork(value = 1) +@Warmup(iterations = 5) +@Measurement(iterations = 15) +@BenchmarkMode(Mode.AverageTime) +@OutputTimeUnit(TimeUnit.NANOSECONDS) +public class ProducerRequestBenchmark { + private static final int NUMBER_OF_PARTITIONS = 3; + private static final int NUMBER_OF_RECORDS = 3; + private static final List TOPIC_PRODUCE_DATA = Collections.singletonList(new ProduceRequestData.TopicProduceData() + .setName("tp") + .setPartitionData(IntStream.range(0, NUMBER_OF_PARTITIONS).mapToObj(partitionIndex -> new ProduceRequestData.PartitionProduceData() + .setIndex(partitionIndex) + .setRecords(MemoryRecords.withRecords(CompressionType.NONE, IntStream.range(0, NUMBER_OF_RECORDS) + .mapToObj(recordIndex -> new SimpleRecord(100, "hello0".getBytes(StandardCharsets.UTF_8))) + .collect(Collectors.toList()) + .toArray(new SimpleRecord[0])))) + .collect(Collectors.toList())) + ); + private static final ProduceRequestData PRODUCE_REQUEST_DATA = new ProduceRequestData() + .setTimeoutMs(100) + .setAcks((short) 1) + .setTopicData(new ProduceRequestData.TopicProduceDataCollection(TOPIC_PRODUCE_DATA.iterator())); + + private static ProduceRequest request() { + return ProduceRequest.forMagic(RecordBatch.CURRENT_MAGIC_VALUE, PRODUCE_REQUEST_DATA).build(); + } + + private static final ProduceRequest REQUEST = request(); + + @Benchmark + @OutputTimeUnit(TimeUnit.NANOSECONDS) + public ProduceRequest constructorProduceRequest() { + return request(); + } + + @Benchmark + @OutputTimeUnit(TimeUnit.NANOSECONDS) + public ProduceResponse constructorErrorResponse() { + return REQUEST.getErrorResponse(0, Errors.INVALID_REQUEST.exception()); + } + +} diff --git a/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/producer/ProducerResponseBenchmark.java b/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/producer/ProducerResponseBenchmark.java new file mode 100644 index 0000000..431880a --- /dev/null +++ 
b/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/producer/ProducerResponseBenchmark.java @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.jmh.producer; + +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.protocol.Errors; +import org.apache.kafka.common.requests.AbstractResponse; +import org.apache.kafka.common.requests.ProduceResponse; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.OutputTimeUnit; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.Warmup; + +import java.util.AbstractMap; +import java.util.Map; +import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +@State(Scope.Benchmark) +@Fork(value = 1) +@Warmup(iterations = 5) +@Measurement(iterations = 15) +@BenchmarkMode(Mode.AverageTime) +@OutputTimeUnit(TimeUnit.NANOSECONDS) +public class ProducerResponseBenchmark { + private static final int NUMBER_OF_PARTITIONS = 3; + private static final int NUMBER_OF_RECORDS = 3; + private static final Map PARTITION_RESPONSE_MAP = IntStream.range(0, NUMBER_OF_PARTITIONS) + .mapToObj(partitionIndex -> new AbstractMap.SimpleEntry<>( + new TopicPartition("tp", partitionIndex), + new ProduceResponse.PartitionResponse( + Errors.NONE, + 0, + 0, + 0, + IntStream.range(0, NUMBER_OF_RECORDS) + .mapToObj(ProduceResponse.RecordError::new) + .collect(Collectors.toList())) + )) + .collect(Collectors.toMap(AbstractMap.SimpleEntry::getKey, AbstractMap.SimpleEntry::getValue)); + + /** + * this method is still used by production so we benchmark it. + * see https://issues.apache.org/jira/browse/KAFKA-10730 + */ + @SuppressWarnings("deprecation") + private static ProduceResponse response() { + return new ProduceResponse(PARTITION_RESPONSE_MAP); + } + + private static final ProduceResponse RESPONSE = response(); + + @Benchmark + @OutputTimeUnit(TimeUnit.NANOSECONDS) + public AbstractResponse constructorProduceResponse() { + return response(); + } +} diff --git a/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/record/BaseRecordBatchBenchmark.java b/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/record/BaseRecordBatchBenchmark.java new file mode 100644 index 0000000..e9910da --- /dev/null +++ b/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/record/BaseRecordBatchBenchmark.java @@ -0,0 +1,149 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.jmh.record; + +import kafka.server.BrokerTopicStats; +import kafka.server.RequestLocal; +import org.apache.kafka.common.header.Header; +import org.apache.kafka.common.record.AbstractRecords; +import org.apache.kafka.common.record.CompressionType; +import org.apache.kafka.common.record.MemoryRecords; +import org.apache.kafka.common.record.MemoryRecordsBuilder; +import org.apache.kafka.common.record.Record; +import org.apache.kafka.common.record.RecordBatch; +import org.apache.kafka.common.record.TimestampType; +import org.openjdk.jmh.annotations.Param; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; + +import java.nio.ByteBuffer; +import java.util.Arrays; +import java.util.Random; +import java.util.stream.IntStream; + +import static org.apache.kafka.common.record.RecordBatch.CURRENT_MAGIC_VALUE; + +@State(Scope.Benchmark) +public abstract class BaseRecordBatchBenchmark { + private static final int MAX_HEADER_SIZE = 5; + private static final int HEADER_KEY_SIZE = 30; + + private final Random random = new Random(0); + + final int batchCount = 100; + + public enum Bytes { + RANDOM, ONES + } + + @Param(value = {"1", "2", "10", "50", "200", "500"}) + private int maxBatchSize = 200; + + @Param(value = {"1", "2"}) + byte messageVersion = CURRENT_MAGIC_VALUE; + + @Param(value = {"100", "1000", "10000", "100000"}) + private int messageSize = 1000; + + @Param(value = {"RANDOM", "ONES"}) + private Bytes bytes = Bytes.RANDOM; + + @Param(value = {"NO_CACHING", "CREATE"}) + private String bufferSupplierStr = "NO_CACHING"; + + // zero starting offset is much faster for v1 batches, but that will almost never happen + int startingOffset; + + // Used by measureSingleMessage + ByteBuffer singleBatchBuffer; + + // Used by measureVariableBatchSize + ByteBuffer[] batchBuffers; + RequestLocal requestLocal; + final BrokerTopicStats brokerTopicStats = new BrokerTopicStats(); + + @Setup + public void init() { + // For v0 batches a zero starting offset is much faster but that will almost never happen. + // For v2 batches we use starting offset = 0 as these batches are relative to the base + // offset and measureValidation will mutate these batches between iterations + startingOffset = messageVersion == 2 ? 
0 : 42; + + if (bufferSupplierStr.equals("NO_CACHING")) { + requestLocal = RequestLocal.NoCaching(); + } else if (bufferSupplierStr.equals("CREATE")) { + requestLocal = RequestLocal.withThreadConfinedCaching(); + } else { + throw new IllegalArgumentException("Unsupported buffer supplier " + bufferSupplierStr); + } + singleBatchBuffer = createBatch(1); + + batchBuffers = new ByteBuffer[batchCount]; + for (int i = 0; i < batchCount; ++i) { + int size = random.nextInt(maxBatchSize) + 1; + batchBuffers[i] = createBatch(size); + } + } + + private static Header[] createHeaders() { + char[] headerChars = new char[HEADER_KEY_SIZE]; + Arrays.fill(headerChars, 'a'); + String headerKey = new String(headerChars); + byte[] headerValue = new byte[0]; + return IntStream.range(0, MAX_HEADER_SIZE).mapToObj(index -> new Header() { + @Override + public String key() { + return headerKey; + } + + @Override + public byte[] value() { + return headerValue; + } + }).toArray(Header[]::new); + } + + abstract CompressionType compressionType(); + + private ByteBuffer createBatch(int batchSize) { + // Magic v1 does not support record headers + Header[] headers = messageVersion < RecordBatch.MAGIC_VALUE_V2 ? Record.EMPTY_HEADERS : createHeaders(); + byte[] value = new byte[messageSize]; + final ByteBuffer buf = ByteBuffer.allocate( + AbstractRecords.estimateSizeInBytesUpperBound(messageVersion, compressionType(), new byte[0], value, + headers) * batchSize + ); + + final MemoryRecordsBuilder builder = + MemoryRecords.builder(buf, messageVersion, compressionType(), TimestampType.CREATE_TIME, startingOffset); + + for (int i = 0; i < batchSize; ++i) { + switch (bytes) { + case ONES: + Arrays.fill(value, (byte) 1); + break; + case RANDOM: + random.nextBytes(value); + break; + } + + builder.append(0, null, value, headers); + } + return builder.build().buffer(); + } +} diff --git a/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/record/CompressedRecordBatchValidationBenchmark.java b/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/record/CompressedRecordBatchValidationBenchmark.java new file mode 100644 index 0000000..24ac53e --- /dev/null +++ b/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/record/CompressedRecordBatchValidationBenchmark.java @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
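The createBatch() helper above is the core of the record-batch benchmarks: it sizes a buffer with AbstractRecords.estimateSizeInBytesUpperBound, fills it through MemoryRecordsBuilder, and hands the resulting ByteBuffer to the subclasses. A stand-alone sketch of the same build-then-read flow, using only the public record APIs that appear in this file (RecordBatchSketch is a hypothetical class name, not part of the patch):

    import java.nio.ByteBuffer;
    import java.nio.charset.StandardCharsets;
    import org.apache.kafka.common.record.CompressionType;
    import org.apache.kafka.common.record.MemoryRecords;
    import org.apache.kafka.common.record.MemoryRecordsBuilder;
    import org.apache.kafka.common.record.Record;
    import org.apache.kafka.common.record.RecordBatch;
    import org.apache.kafka.common.record.TimestampType;

    public class RecordBatchSketch {
        public static void main(String[] args) {
            ByteBuffer buf = ByteBuffer.allocate(1024);
            // Same builder entry point used in createBatch(): magic, compression,
            // timestamp type and base offset.
            MemoryRecordsBuilder builder = MemoryRecords.builder(buf, RecordBatch.CURRENT_MAGIC_VALUE,
                    CompressionType.NONE, TimestampType.CREATE_TIME, 0L);
            builder.append(0L, null, "value".getBytes(StandardCharsets.UTF_8));
            MemoryRecords records = builder.build();
            // Read the batch back, much like the iteration benchmarks do.
            for (Record record : records.records()) {
                System.out.println("offset=" + record.offset() + " valueSize=" + record.valueSize());
            }
        }
    }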
+ */ +package org.apache.kafka.jmh.record; + +import kafka.api.ApiVersion; +import kafka.common.LongRef; +import kafka.log.AppendOrigin; +import kafka.log.LogValidator; +import kafka.message.CompressionCodec; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.record.CompressionType; +import org.apache.kafka.common.record.MemoryRecords; +import org.apache.kafka.common.record.TimestampType; +import org.apache.kafka.common.utils.Time; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Param; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.Warmup; +import org.openjdk.jmh.infra.Blackhole; + +@State(Scope.Benchmark) +@Fork(value = 1) +@Warmup(iterations = 5) +@Measurement(iterations = 15) +public class CompressedRecordBatchValidationBenchmark extends BaseRecordBatchBenchmark { + + @Param(value = {"LZ4", "SNAPPY", "GZIP", "ZSTD"}) + private CompressionType compressionType = CompressionType.LZ4; + + @Override + CompressionType compressionType() { + return compressionType; + } + + @Benchmark + public void measureValidateMessagesAndAssignOffsetsCompressed(Blackhole bh) { + MemoryRecords records = MemoryRecords.readableRecords(singleBatchBuffer.duplicate()); + LogValidator.validateMessagesAndAssignOffsetsCompressed(records, new TopicPartition("a", 0), + new LongRef(startingOffset), Time.SYSTEM, System.currentTimeMillis(), + CompressionCodec.getCompressionCodec(compressionType.id), + CompressionCodec.getCompressionCodec(compressionType.id), + false, messageVersion, TimestampType.CREATE_TIME, Long.MAX_VALUE, 0, + new AppendOrigin.Client$(), + ApiVersion.latestVersion(), + brokerTopicStats, + requestLocal); + } +} diff --git a/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/record/RecordBatchIterationBenchmark.java b/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/record/RecordBatchIterationBenchmark.java new file mode 100644 index 0000000..c331cd5 --- /dev/null +++ b/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/record/RecordBatchIterationBenchmark.java @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.jmh.record; + +import org.apache.kafka.common.record.CompressionType; +import org.apache.kafka.common.record.MemoryRecords; +import org.apache.kafka.common.record.MutableRecordBatch; +import org.apache.kafka.common.record.Record; +import org.apache.kafka.common.record.RecordBatch; +import org.apache.kafka.common.utils.CloseableIterator; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.OperationsPerInvocation; +import org.openjdk.jmh.annotations.Param; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.Warmup; +import org.openjdk.jmh.infra.Blackhole; + +@State(Scope.Benchmark) +@Fork(value = 1) +@Warmup(iterations = 5) +@Measurement(iterations = 15) +public class RecordBatchIterationBenchmark extends BaseRecordBatchBenchmark { + + @Param(value = {"LZ4", "SNAPPY", "GZIP", "ZSTD", "NONE"}) + private CompressionType compressionType = CompressionType.NONE; + + @Override + CompressionType compressionType() { + return compressionType; + } + + @Benchmark + public void measureIteratorForBatchWithSingleMessage(Blackhole bh) { + for (RecordBatch batch : MemoryRecords.readableRecords(singleBatchBuffer.duplicate()).batches()) { + try (CloseableIterator iterator = batch.streamingIterator(requestLocal.bufferSupplier())) { + while (iterator.hasNext()) + bh.consume(iterator.next()); + } + } + } + + @OperationsPerInvocation(value = batchCount) + @Fork(jvmArgsAppend = "-Xmx8g") + @Benchmark + public void measureStreamingIteratorForVariableBatchSize(Blackhole bh) { + for (int i = 0; i < batchCount; ++i) { + for (RecordBatch batch : MemoryRecords.readableRecords(batchBuffers[i].duplicate()).batches()) { + try (CloseableIterator iterator = batch.streamingIterator(requestLocal.bufferSupplier())) { + while (iterator.hasNext()) + bh.consume(iterator.next()); + } + } + } + } + + @OperationsPerInvocation(value = batchCount) + @Fork(jvmArgsAppend = "-Xmx8g") + @Benchmark + public void measureSkipIteratorForVariableBatchSize(Blackhole bh) { + for (int i = 0; i < batchCount; ++i) { + for (MutableRecordBatch batch : MemoryRecords.readableRecords(batchBuffers[i].duplicate()).batches()) { + try (CloseableIterator iterator = batch.skipKeyValueIterator(requestLocal.bufferSupplier())) { + while (iterator.hasNext()) + bh.consume(iterator.next()); + } + } + } + } +} diff --git a/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/record/UncompressedRecordBatchValidationBenchmark.java b/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/record/UncompressedRecordBatchValidationBenchmark.java new file mode 100644 index 0000000..001837e --- /dev/null +++ b/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/record/UncompressedRecordBatchValidationBenchmark.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
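The requestLocal passed into streamingIterator() and skipKeyValueIterator() above wraps a BufferSupplier, which is what the NO_CACHING / CREATE @Param in BaseRecordBatchBenchmark toggles. A small sketch of that underlying API, assuming the public org.apache.kafka.common.utils.BufferSupplier class (BufferSupplierSketch is a hypothetical name):

    import java.nio.ByteBuffer;
    import org.apache.kafka.common.utils.BufferSupplier;

    public class BufferSupplierSketch {
        public static void main(String[] args) {
            // Allocates a fresh buffer on every get(); nothing is reused.
            BufferSupplier noCaching = BufferSupplier.NO_CACHING;
            // Caches released buffers so repeated decompression can reuse them.
            BufferSupplier caching = BufferSupplier.create();

            ByteBuffer buffer = caching.get(1024);
            // ... hand the buffer to a decompression stream ...
            caching.release(buffer);  // return it for reuse by the next get()

            noCaching.close();
            caching.close();
        }
    }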
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.jmh.record; + +import kafka.common.LongRef; +import kafka.log.AppendOrigin; +import kafka.log.LogValidator; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.record.CompressionType; +import org.apache.kafka.common.record.MemoryRecords; +import org.apache.kafka.common.record.TimestampType; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.Warmup; +import org.openjdk.jmh.infra.Blackhole; + +@State(Scope.Benchmark) +@Fork(value = 1) +@Warmup(iterations = 5) +@Measurement(iterations = 15) +public class UncompressedRecordBatchValidationBenchmark extends BaseRecordBatchBenchmark { + + @Override + CompressionType compressionType() { + return CompressionType.NONE; + } + + @Benchmark + public void measureAssignOffsetsNonCompressed(Blackhole bh) { + MemoryRecords records = MemoryRecords.readableRecords(singleBatchBuffer.duplicate()); + LogValidator.assignOffsetsNonCompressed(records, new TopicPartition("a", 0), + new LongRef(startingOffset), System.currentTimeMillis(), false, + TimestampType.CREATE_TIME, Long.MAX_VALUE, 0, + new AppendOrigin.Client$(), messageVersion, brokerTopicStats); + } +} diff --git a/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/server/CheckpointBench.java b/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/server/CheckpointBench.java new file mode 100644 index 0000000..21a8086 --- /dev/null +++ b/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/server/CheckpointBench.java @@ -0,0 +1,175 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.jmh.server; + +import kafka.api.ApiVersion; +import kafka.cluster.Partition; +import kafka.log.CleanerConfig; +import kafka.log.LogConfig; +import kafka.log.LogManager; +import kafka.server.AlterIsrManager; +import kafka.server.BrokerTopicStats; +import kafka.server.KafkaConfig; +import kafka.server.LogDirFailureChannel; +import kafka.server.MetadataCache; +import kafka.server.QuotaFactory; +import kafka.server.ReplicaManager; +import kafka.server.builders.ReplicaManagerBuilder; +import kafka.server.checkpoints.OffsetCheckpoints; +import kafka.server.metadata.MockConfigRepository; +import kafka.utils.KafkaScheduler; +import kafka.utils.MockTime; +import kafka.utils.Scheduler; +import kafka.utils.TestUtils; +import org.apache.kafka.common.Uuid; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.metrics.Metrics; +import org.apache.kafka.common.utils.Utils; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Level; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.OutputTimeUnit; +import org.openjdk.jmh.annotations.Param; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.TearDown; +import org.openjdk.jmh.annotations.Threads; +import org.openjdk.jmh.annotations.Warmup; + +import java.io.File; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; + +import scala.collection.JavaConverters; +import scala.Option; + +@Warmup(iterations = 5) +@Measurement(iterations = 5) +@Fork(3) +@OutputTimeUnit(TimeUnit.MILLISECONDS) +@State(value = Scope.Benchmark) +public class CheckpointBench { + + @Param({"100", "1000", "2000"}) + public int numTopics; + + @Param({"3"}) + public int numPartitions; + + private final String topicName = "foo"; + + private Scheduler scheduler; + + private Metrics metrics; + + private MockTime time; + + private KafkaConfig brokerProperties; + + private ReplicaManager replicaManager; + private QuotaFactory.QuotaManagers quotaManagers; + private LogDirFailureChannel failureChannel; + private LogManager logManager; + private AlterIsrManager alterIsrManager; + + + @SuppressWarnings("deprecation") + @Setup(Level.Trial) + public void setup() { + this.scheduler = new KafkaScheduler(1, "scheduler-thread", true); + this.brokerProperties = KafkaConfig.fromProps(TestUtils.createBrokerConfig( + 0, TestUtils.MockZkConnect(), true, true, 9092, Option.empty(), Option.empty(), + Option.empty(), true, false, 0, false, 0, false, 0, Option.empty(), 1, true, 1, + (short) 1)); + this.metrics = new Metrics(); + this.time = new MockTime(); + this.failureChannel = new LogDirFailureChannel(brokerProperties.logDirs().size()); + final List files = + JavaConverters.seqAsJavaList(brokerProperties.logDirs()).stream().map(File::new).collect(Collectors.toList()); + this.logManager = TestUtils.createLogManager(JavaConverters.asScalaBuffer(files), + LogConfig.apply(), new MockConfigRepository(), CleanerConfig.apply(1, 4 * 1024 * 1024L, 0.9d, + 1024 * 1024, 32 * 1024 * 1024, + Double.MAX_VALUE, 15 * 1000, true, "MD5"), time, ApiVersion.latestVersion()); + scheduler.startup(); + final BrokerTopicStats brokerTopicStats = new BrokerTopicStats(); + final MetadataCache metadataCache = + MetadataCache.zkMetadataCache(this.brokerProperties.brokerId()); + 
this.quotaManagers = + QuotaFactory.instantiate(this.brokerProperties, + this.metrics, + this.time, ""); + + this.alterIsrManager = TestUtils.createAlterIsrManager(); + this.replicaManager = new ReplicaManagerBuilder(). + setConfig(brokerProperties). + setMetrics(metrics). + setTime(time). + setScheduler(scheduler). + setLogManager(logManager). + setQuotaManagers(quotaManagers). + setBrokerTopicStats(brokerTopicStats). + setMetadataCache(metadataCache). + setLogDirFailureChannel(failureChannel). + setAlterIsrManager(alterIsrManager). + build(); + replicaManager.startup(); + + List topicPartitions = new ArrayList<>(); + for (int topicNum = 0; topicNum < numTopics; topicNum++) { + final String topicName = this.topicName + "-" + topicNum; + for (int partitionNum = 0; partitionNum < numPartitions; partitionNum++) { + topicPartitions.add(new TopicPartition(topicName, partitionNum)); + } + } + + OffsetCheckpoints checkpoints = (logDir, topicPartition) -> Option.apply(0L); + for (TopicPartition topicPartition : topicPartitions) { + final Partition partition = this.replicaManager.createPartition(topicPartition); + partition.createLogIfNotExists(true, false, checkpoints, Option.apply(Uuid.randomUuid())); + } + + replicaManager.checkpointHighWatermarks(); + } + + @TearDown(Level.Trial) + public void tearDown() throws Exception { + this.replicaManager.shutdown(false); + this.metrics.close(); + this.scheduler.shutdown(); + this.quotaManagers.shutdown(); + for (File dir : JavaConverters.asJavaCollection(logManager.liveLogDirs())) { + Utils.delete(dir); + } + } + + + @Benchmark + @Threads(1) + public void measureCheckpointHighWatermarks() { + this.replicaManager.checkpointHighWatermarks(); + } + + @Benchmark + @Threads(1) + public void measureCheckpointLogStartOffsets() { + this.logManager.checkpointLogStartOffsets(); + } +} diff --git a/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/server/PartitionCreationBench.java b/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/server/PartitionCreationBench.java new file mode 100644 index 0000000..937ac86 --- /dev/null +++ b/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/server/PartitionCreationBench.java @@ -0,0 +1,232 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
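Because numTopics is a JMH @Param, the whole CheckpointBench trial above is repeated for 100, 1000 and 2000 topics. When iterating locally it can be useful to pin a single value from a programmatic runner; a hedged sketch (CheckpointBenchRunnerSketch is a hypothetical class, and the parameter name must match the @Param field):

    import org.openjdk.jmh.runner.Runner;
    import org.openjdk.jmh.runner.RunnerException;
    import org.openjdk.jmh.runner.options.Options;
    import org.openjdk.jmh.runner.options.OptionsBuilder;

    public class CheckpointBenchRunnerSketch {
        public static void main(String[] args) throws RunnerException {
            Options opts = new OptionsBuilder()
                    .include("CheckpointBench")
                    // Run only one point of the @Param matrix instead of {100, 1000, 2000}.
                    .param("numTopics", "100")
                    .forks(1)
                    .build();
            new Runner(opts).run();
        }
    }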
+ */ +package org.apache.kafka.jmh.server; + +import kafka.api.ApiVersion; +import kafka.cluster.Partition; +import kafka.log.CleanerConfig; +import kafka.log.Defaults; +import kafka.log.LogConfig; +import kafka.log.LogManager; +import kafka.server.AlterIsrManager; +import kafka.server.BrokerTopicStats; +import kafka.server.KafkaConfig; +import kafka.server.LogDirFailureChannel; +import kafka.server.QuotaFactory; +import kafka.server.ReplicaManager; +import kafka.server.builders.LogManagerBuilder; +import kafka.server.builders.ReplicaManagerBuilder; +import kafka.server.checkpoints.OffsetCheckpoints; +import kafka.server.metadata.ConfigRepository; +import kafka.server.metadata.MockConfigRepository; +import kafka.server.metadata.ZkMetadataCache; +import kafka.utils.KafkaScheduler; +import kafka.utils.Scheduler; +import kafka.utils.TestUtils; +import kafka.zk.KafkaZkClient; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.Uuid; +import org.apache.kafka.common.message.LeaderAndIsrRequestData; +import org.apache.kafka.common.metrics.Metrics; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.common.utils.Utils; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Level; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.OutputTimeUnit; +import org.openjdk.jmh.annotations.Param; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.TearDown; +import org.openjdk.jmh.annotations.Threads; +import org.openjdk.jmh.annotations.Warmup; + +import java.io.File; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Properties; +import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; +import scala.Option; +import scala.collection.JavaConverters; + +@Warmup(iterations = 5) +@Measurement(iterations = 5) +@Fork(3) +@BenchmarkMode(Mode.AverageTime) +@State(value = Scope.Benchmark) +public class PartitionCreationBench { + @Param({"false", "true"}) + public boolean useTopicIds; + + @Param({"2000"}) + public int numPartitions; + + private final String topicName = "foo"; + + private Option topicId; + private Scheduler scheduler; + private Metrics metrics; + private Time time; + private KafkaConfig brokerProperties; + + private ReplicaManager replicaManager; + private QuotaFactory.QuotaManagers quotaManagers; + private KafkaZkClient zkClient; + private LogDirFailureChannel failureChannel; + private LogManager logManager; + private AlterIsrManager alterIsrManager; + private List topicPartitions; + + @SuppressWarnings("deprecation") + @Setup(Level.Invocation) + public void setup() { + if (useTopicIds) + topicId = Option.apply(Uuid.randomUuid()); + else + topicId = Option.empty(); + + this.scheduler = new KafkaScheduler(1, "scheduler-thread", true); + this.brokerProperties = KafkaConfig.fromProps(TestUtils.createBrokerConfig( + 0, TestUtils.MockZkConnect(), true, true, 9092, Option.empty(), Option.empty(), + Option.empty(), true, false, 0, false, 0, false, 0, Option.empty(), 1, true, 1, + (short) 1)); + this.metrics = new Metrics(); + this.time = Time.SYSTEM; + this.failureChannel = new LogDirFailureChannel(brokerProperties.logDirs().size()); + final BrokerTopicStats brokerTopicStats = new 
BrokerTopicStats(); + final List files = + JavaConverters.seqAsJavaList(brokerProperties.logDirs()).stream().map(File::new).collect(Collectors.toList()); + CleanerConfig cleanerConfig = CleanerConfig.apply(1, + 4 * 1024 * 1024L, 0.9d, + 1024 * 1024, 32 * 1024 * 1024, + Double.MAX_VALUE, 15 * 1000, true, "MD5"); + + ConfigRepository configRepository = new MockConfigRepository(); + this.logManager = new LogManagerBuilder(). + setLogDirs(files). + setInitialOfflineDirs(Collections.emptyList()). + setConfigRepository(configRepository). + setInitialDefaultConfig(createLogConfig()). + setCleanerConfig(cleanerConfig). + setRecoveryThreadsPerDataDir(1). + setFlushCheckMs(1000L). + setFlushRecoveryOffsetCheckpointMs(10000L). + setFlushStartOffsetCheckpointMs(10000L). + setRetentionCheckMs(1000L). + setMaxPidExpirationMs(60000). + setInterBrokerProtocolVersion(ApiVersion.latestVersion()). + setScheduler(scheduler). + setBrokerTopicStats(brokerTopicStats). + setLogDirFailureChannel(failureChannel). + setTime(Time.SYSTEM). + setKeepPartitionMetadataFile(true). + build(); + scheduler.startup(); + this.quotaManagers = QuotaFactory.instantiate(this.brokerProperties, this.metrics, this.time, ""); + this.zkClient = new KafkaZkClient(null, false, Time.SYSTEM) { + @Override + public Properties getEntityConfigs(String rootEntityType, String sanitizedEntityName) { + return new Properties(); + } + }; + this.alterIsrManager = TestUtils.createAlterIsrManager(); + this.replicaManager = new ReplicaManagerBuilder(). + setConfig(brokerProperties). + setMetrics(metrics). + setTime(time). + setZkClient(zkClient). + setScheduler(scheduler). + setLogManager(logManager). + setQuotaManagers(quotaManagers). + setBrokerTopicStats(brokerTopicStats). + setMetadataCache(new ZkMetadataCache(this.brokerProperties.brokerId())). + setLogDirFailureChannel(failureChannel). + setAlterIsrManager(alterIsrManager). 
+ build(); + replicaManager.startup(); + replicaManager.checkpointHighWatermarks(); + } + + @TearDown(Level.Invocation) + public void tearDown() throws Exception { + this.replicaManager.shutdown(false); + logManager.shutdown(); + this.metrics.close(); + this.scheduler.shutdown(); + this.quotaManagers.shutdown(); + for (File dir : JavaConverters.asJavaCollection(logManager.liveLogDirs())) { + Utils.delete(dir); + } + this.zkClient.close(); + } + + private static LogConfig createLogConfig() { + Properties logProps = new Properties(); + logProps.put(LogConfig.SegmentMsProp(), Defaults.SegmentMs()); + logProps.put(LogConfig.SegmentBytesProp(), Defaults.SegmentSize()); + logProps.put(LogConfig.RetentionMsProp(), Defaults.RetentionMs()); + logProps.put(LogConfig.RetentionBytesProp(), Defaults.RetentionSize()); + logProps.put(LogConfig.SegmentJitterMsProp(), Defaults.SegmentJitterMs()); + logProps.put(LogConfig.CleanupPolicyProp(), Defaults.CleanupPolicy()); + logProps.put(LogConfig.MaxMessageBytesProp(), Defaults.MaxMessageSize()); + logProps.put(LogConfig.IndexIntervalBytesProp(), Defaults.IndexInterval()); + logProps.put(LogConfig.SegmentIndexBytesProp(), Defaults.MaxIndexSize()); + logProps.put(LogConfig.FileDeleteDelayMsProp(), Defaults.FileDeleteDelayMs()); + return LogConfig.apply(logProps, new scala.collection.immutable.HashSet<>()); + } + + @Benchmark + @Threads(1) + @OutputTimeUnit(TimeUnit.MILLISECONDS) + public void makeFollower() { + topicPartitions = new ArrayList<>(); + for (int partitionNum = 0; partitionNum < numPartitions; partitionNum++) { + topicPartitions.add(new TopicPartition(topicName, partitionNum)); + } + + List<Integer> replicas = new ArrayList<>(); + replicas.add(0); + replicas.add(1); + replicas.add(2); + + OffsetCheckpoints checkpoints = (logDir, topicPartition) -> Option.apply(0L); + for (TopicPartition topicPartition : topicPartitions) { + final Partition partition = this.replicaManager.createPartition(topicPartition); + List<Integer> inSync = new ArrayList<>(); + inSync.add(0); + inSync.add(1); + inSync.add(2); + + LeaderAndIsrRequestData.LeaderAndIsrPartitionState partitionState = new LeaderAndIsrRequestData.LeaderAndIsrPartitionState() + .setControllerEpoch(0) + .setLeader(0) + .setLeaderEpoch(0) + .setIsr(inSync) + .setZkVersion(1) + .setReplicas(replicas) + .setIsNew(true); + + partition.makeFollower(partitionState, checkpoints, topicId); + } + } +} diff --git a/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/timeline/TimelineHashMapBenchmark.java b/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/timeline/TimelineHashMapBenchmark.java new file mode 100644 index 0000000..ae6c56e --- /dev/null +++ b/jmh-benchmarks/src/main/java/org/apache/kafka/jmh/timeline/TimelineHashMapBenchmark.java @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
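Note the lifecycle difference between the two server benchmarks: CheckpointBench builds its ReplicaManager once per trial, while PartitionCreationBench uses Level.Invocation so broker state is rebuilt around every makeFollower() call and each measurement starts from freshly created log directories. A minimal illustration of what the JMH levels mean (SetupLevelSketch is a hypothetical class; Level.Invocation adds per-call overhead, so it is generally reserved for benchmarks whose invocations mutate state that cannot be cheaply reset):

    import java.util.ArrayList;
    import java.util.List;
    import org.openjdk.jmh.annotations.Benchmark;
    import org.openjdk.jmh.annotations.Level;
    import org.openjdk.jmh.annotations.Scope;
    import org.openjdk.jmh.annotations.Setup;
    import org.openjdk.jmh.annotations.State;
    import org.openjdk.jmh.annotations.TearDown;

    @State(Scope.Benchmark)
    public class SetupLevelSketch {
        private List<Integer> data;

        @Setup(Level.Trial)
        public void oncePerTrial() {
            // Runs once before all warmup/measurement iterations, like CheckpointBench.setup().
        }

        @Setup(Level.Invocation)
        public void beforeEachCall() {
            // Runs before every single benchmark method call, like PartitionCreationBench.setup().
            data = new ArrayList<>();
        }

        @TearDown(Level.Invocation)
        public void afterEachCall() {
            data = null;
        }

        @Benchmark
        public int measure() {
            data.add(1);
            return data.size();
        }
    }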
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.jmh.timeline; + +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.timeline.SnapshotRegistry; +import org.apache.kafka.timeline.TimelineHashMap; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.OutputTimeUnit; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.Warmup; + +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.TimeUnit; + +@State(Scope.Benchmark) +@Fork(value = 1) +@Warmup(iterations = 3) +@Measurement(iterations = 10) +@BenchmarkMode(Mode.AverageTime) +@OutputTimeUnit(TimeUnit.MILLISECONDS) + +public class TimelineHashMapBenchmark { + private final static int NUM_ENTRIES = 1_000_000; + + @Benchmark + public Map<Integer, String> testAddEntriesInHashMap() { + HashMap<Integer, String> map = new HashMap<>(NUM_ENTRIES); + for (int i = 0; i < NUM_ENTRIES; i++) { + int key = (int) (0xffffffff & ((i * 2862933555777941757L) + 3037000493L)); + map.put(key, String.valueOf(key)); + } + return map; + } + + @Benchmark + public Map<Integer, String> testAddEntriesInTimelineMap() { + SnapshotRegistry snapshotRegistry = new SnapshotRegistry(new LogContext()); + TimelineHashMap<Integer, String> map = + new TimelineHashMap<>(snapshotRegistry, NUM_ENTRIES); + for (int i = 0; i < NUM_ENTRIES; i++) { + int key = (int) (0xffffffff & ((i * 2862933555777941757L) + 3037000493L)); + map.put(key, String.valueOf(key)); + } + return map; + } + + @Benchmark + public Map<Integer, String> testAddEntriesWithSnapshots() { + SnapshotRegistry snapshotRegistry = new SnapshotRegistry(new LogContext()); + TimelineHashMap<Integer, String> map = + new TimelineHashMap<>(snapshotRegistry, NUM_ENTRIES); + long epoch = 0; + int j = 0; + for (int i = 0; i < NUM_ENTRIES; i++) { + int key = (int) (0xffffffff & ((i * 2862933555777941757L) + 3037000493L)); + if (j > 10 && key % 3 == 0) { + snapshotRegistry.deleteSnapshotsUpTo(epoch - 1000); + snapshotRegistry.getOrCreateSnapshot(epoch); + j = 0; + } else { + j++; + } + map.put(key, String.valueOf(key)); + epoch++; + } + return map; + } +} diff --git a/kafka-merge-pr.py b/kafka-merge-pr.py new file mode 100755 index 0000000..abe7cbd --- /dev/null +++ b/kafka-merge-pr.py @@ -0,0 +1,476 @@ +#!/usr/bin/env python + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Utility for creating well-formed pull request merges and pushing them to Apache.
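TimelineHashMapBenchmark above compares a plain HashMap against TimelineHashMap, whose writes are tracked against the snapshots registered in SnapshotRegistry. A small sketch of reading an older snapshot back after further mutation; it assumes TimelineHashMap exposes an epoch-taking get(key, epoch) overload alongside the constructor and registry calls used in the benchmark, so treat it as illustrative rather than definitive:

    import org.apache.kafka.common.utils.LogContext;
    import org.apache.kafka.timeline.SnapshotRegistry;
    import org.apache.kafka.timeline.TimelineHashMap;

    public class TimelineSnapshotSketch {
        public static void main(String[] args) {
            SnapshotRegistry registry = new SnapshotRegistry(new LogContext());
            TimelineHashMap<Integer, String> map = new TimelineHashMap<>(registry, 16);

            map.put(1, "first");
            registry.getOrCreateSnapshot(10);   // freeze the state as of epoch 10
            map.put(1, "second");               // mutate the latest state

            System.out.println(map.get(1));     // latest value: "second"
            System.out.println(map.get(1, 10)); // value as of the epoch-10 snapshot: "first"
        }
    }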
This script is a modified version +# of the one created by the Spark project (https://github.com/apache/spark/blob/master/dev/merge_spark_pr.py). +# +# Usage: ./kafka-merge-pr.py (see config env vars below) +# +# This utility assumes you already have local a kafka git folder and that you +# have added remotes corresponding to both: +# (i) the github apache kafka mirror and +# (ii) the apache kafka git repo. + +import json +import os +import re +import subprocess +import sys +import urllib2 + +try: + import jira.client + JIRA_IMPORTED = True +except ImportError: + JIRA_IMPORTED = False + +PROJECT_NAME = "kafka" + +CAPITALIZED_PROJECT_NAME = "kafka".upper() + +# Location of the local git repository +REPO_HOME = os.environ.get("%s_HOME" % CAPITALIZED_PROJECT_NAME, os.getcwd()) +# Remote name which points to the GitHub site +PR_REMOTE_NAME = os.environ.get("PR_REMOTE_NAME", "apache-github") +# Remote name where we want to push the changes to (GitHub by default, but Apache Git would work if GitHub is down) +PUSH_REMOTE_NAME = os.environ.get("PUSH_REMOTE_NAME", "apache-github") +# ASF JIRA username +JIRA_USERNAME = os.environ.get("JIRA_USERNAME", "") +# ASF JIRA password +JIRA_PASSWORD = os.environ.get("JIRA_PASSWORD", "") +# OAuth key used for issuing requests against the GitHub API. If this is not defined, then requests +# will be unauthenticated. You should only need to configure this if you find yourself regularly +# exceeding your IP's unauthenticated request rate limit. You can create an OAuth key at +# https://github.com/settings/tokens. This script only requires the "public_repo" scope. +GITHUB_OAUTH_KEY = os.environ.get("GITHUB_OAUTH_KEY") + +GITHUB_USER = os.environ.get("GITHUB_USER", "apache") +GITHUB_BASE = "https://github.com/%s/%s/pull" % (GITHUB_USER, PROJECT_NAME) +GITHUB_API_BASE = "https://api.github.com/repos/%s/%s" % (GITHUB_USER, PROJECT_NAME) +JIRA_BASE = "https://issues.apache.org/jira/browse" +JIRA_API_BASE = "https://issues.apache.org/jira" +# Prefix added to temporary branches +TEMP_BRANCH_PREFIX = "PR_TOOL" + +DEV_BRANCH_NAME = "trunk" + +DEFAULT_FIX_VERSION = os.environ.get("DEFAULT_FIX_VERSION", "3.1.0") + +def get_json(url): + try: + request = urllib2.Request(url) + if GITHUB_OAUTH_KEY: + request.add_header('Authorization', 'token %s' % GITHUB_OAUTH_KEY) + return json.load(urllib2.urlopen(request)) + except urllib2.HTTPError as e: + if "X-RateLimit-Remaining" in e.headers and e.headers["X-RateLimit-Remaining"] == '0': + print "Exceeded the GitHub API rate limit; see the instructions in " + \ + "kafka-merge-pr.py to configure an OAuth token for making authenticated " + \ + "GitHub requests." 
+ else: + print "Unable to fetch URL, exiting: %s" % url + sys.exit(-1) + + +def fail(msg): + print msg + clean_up() + sys.exit(-1) + + +def run_cmd(cmd): + print cmd + if isinstance(cmd, list): + return subprocess.check_output(cmd) + else: + return subprocess.check_output(cmd.split(" ")) + + +def continue_maybe(prompt): + result = raw_input("\n%s (y/n): " % prompt) + if result.lower() != "y": + fail("Okay, exiting") + +def clean_up(): + if original_head != get_current_branch(): + print "Restoring head pointer to %s" % original_head + run_cmd("git checkout %s" % original_head) + + branches = run_cmd("git branch").replace(" ", "").split("\n") + + for branch in filter(lambda x: x.startswith(TEMP_BRANCH_PREFIX), branches): + print "Deleting local branch %s" % branch + run_cmd("git branch -D %s" % branch) + +def get_current_branch(): + return run_cmd("git rev-parse --abbrev-ref HEAD").replace("\n", "") + +# merge the requested PR and return the merge hash +def merge_pr(pr_num, target_ref, title, body, pr_repo_desc): + pr_branch_name = "%s_MERGE_PR_%s" % (TEMP_BRANCH_PREFIX, pr_num) + target_branch_name = "%s_MERGE_PR_%s_%s" % (TEMP_BRANCH_PREFIX, pr_num, target_ref.upper()) + run_cmd("git fetch %s pull/%s/head:%s" % (PR_REMOTE_NAME, pr_num, pr_branch_name)) + run_cmd("git fetch %s %s:%s" % (PUSH_REMOTE_NAME, target_ref, target_branch_name)) + run_cmd("git checkout %s" % target_branch_name) + + had_conflicts = False + try: + run_cmd(['git', 'merge', pr_branch_name, '--squash']) + except Exception as e: + msg = "Error merging: %s\nWould you like to manually fix-up this merge?" % e + continue_maybe(msg) + msg = "Okay, please fix any conflicts and 'git add' conflicting files... Finished?" + continue_maybe(msg) + had_conflicts = True + + commit_authors = run_cmd(['git', 'log', 'HEAD..%s' % pr_branch_name, + '--pretty=format:%an <%ae>']).split("\n") + distinct_authors = sorted(set(commit_authors), + key=lambda x: commit_authors.count(x), reverse=True) + primary_author = raw_input( + "Enter primary author in the format of \"name <email>\" [%s]: " % + distinct_authors[0]) + if primary_author == "": + primary_author = distinct_authors[0] + + reviewers = raw_input( + "Enter reviewers in the format of \"name1 <email1>, name2 <email2>\": ").strip() + + run_cmd(['git', 'log', 'HEAD..%s' % pr_branch_name, '--pretty=format:%h [%an] %s']).split("\n") + + merge_message_flags = [] + + merge_message_flags += ["-m", title] + + if body is not None: + # Remove "Committer Checklist" section + checklist_index = body.find("### Committer Checklist") + if checklist_index != -1: + body = body[:checklist_index].rstrip() + # Remove @ symbols from the body to avoid triggering e-mails to people every time someone creates a + # public fork of the project.
+ body = body.replace("@", "") + merge_message_flags += ["-m", body] + + authors = "\n".join(["Author: %s" % a for a in distinct_authors]) + + merge_message_flags += ["-m", authors] + + if reviewers != "": + merge_message_flags += ["-m", "Reviewers: %s" % reviewers] + + if had_conflicts: + committer_name = run_cmd("git config --get user.name").strip() + committer_email = run_cmd("git config --get user.email").strip() + message = "This patch had conflicts when merged, resolved by\nCommitter: %s <%s>" % ( + committer_name, committer_email) + merge_message_flags += ["-m", message] + + # The string "Closes #%s" string is required for GitHub to correctly close the PR + close_line = "Closes #%s from %s" % (pr_num, pr_repo_desc) + merge_message_flags += ["-m", close_line] + + run_cmd(['git', 'commit', '--author="%s"' % primary_author] + merge_message_flags) + + continue_maybe("Merge complete (local ref %s). Push to %s?" % ( + target_branch_name, PUSH_REMOTE_NAME)) + + try: + run_cmd('git push %s %s:%s' % (PUSH_REMOTE_NAME, target_branch_name, target_ref)) + except Exception as e: + clean_up() + fail("Exception while pushing: %s" % e) + + merge_hash = run_cmd("git rev-parse %s" % target_branch_name)[:8] + clean_up() + print("Pull request #%s merged!" % pr_num) + print("Merge hash: %s" % merge_hash) + return merge_hash + + +def cherry_pick(pr_num, merge_hash, default_branch): + pick_ref = raw_input("Enter a branch name [%s]: " % default_branch) + if pick_ref == "": + pick_ref = default_branch + + pick_branch_name = "%s_PICK_PR_%s_%s" % (TEMP_BRANCH_PREFIX, pr_num, pick_ref.upper()) + + run_cmd("git fetch %s %s:%s" % (PUSH_REMOTE_NAME, pick_ref, pick_branch_name)) + run_cmd("git checkout %s" % pick_branch_name) + + try: + run_cmd("git cherry-pick -sx %s" % merge_hash) + except Exception as e: + msg = "Error cherry-picking: %s\nWould you like to manually fix-up this merge?" % e + continue_maybe(msg) + msg = "Okay, please fix any conflicts and finish the cherry-pick. Finished?" + continue_maybe(msg) + + continue_maybe("Pick complete (local ref %s). Push to %s?" % ( + pick_branch_name, PUSH_REMOTE_NAME)) + + try: + run_cmd('git push %s %s:%s' % (PUSH_REMOTE_NAME, pick_branch_name, pick_ref)) + except Exception as e: + clean_up() + fail("Exception while pushing: %s" % e) + + pick_hash = run_cmd("git rev-parse %s" % pick_branch_name)[:8] + clean_up() + + print("Pull request #%s picked into %s!" % (pr_num, pick_ref)) + print("Pick hash: %s" % pick_hash) + return pick_ref + + +def fix_version_from_branch(branch, versions): + # Note: Assumes this is a sorted (newest->oldest) list of un-released versions + if branch == DEV_BRANCH_NAME: + versions = filter(lambda x: x == DEFAULT_FIX_VERSION, versions) + if len(versions) > 0: + return versions[0] + else: + return None + else: + versions = filter(lambda x: x.startswith(branch), versions) + if len(versions) > 0: + return versions[-1] + else: + return None + + +def resolve_jira_issue(merge_branches, comment, default_jira_id=""): + asf_jira = jira.client.JIRA({'server': JIRA_API_BASE}, + basic_auth=(JIRA_USERNAME, JIRA_PASSWORD)) + + jira_id = raw_input("Enter a JIRA id [%s]: " % default_jira_id) + if jira_id == "": + jira_id = default_jira_id + + try: + issue = asf_jira.issue(jira_id) + except Exception as e: + fail("ASF JIRA could not find %s\n%s" % (jira_id, e)) + + cur_status = issue.fields.status.name + cur_summary = issue.fields.summary + cur_assignee = issue.fields.assignee + if cur_assignee is None: + cur_assignee = "NOT ASSIGNED!!!" 
+ else: + cur_assignee = cur_assignee.displayName + + if cur_status == "Resolved" or cur_status == "Closed": + fail("JIRA issue %s already has status '%s'" % (jira_id, cur_status)) + print ("=== JIRA %s ===" % jira_id) + print ("summary\t\t%s\nassignee\t%s\nstatus\t\t%s\nurl\t\t%s/%s\n" % ( + cur_summary, cur_assignee, cur_status, JIRA_BASE, jira_id)) + + versions = asf_jira.project_versions(CAPITALIZED_PROJECT_NAME) + versions = sorted(versions, key=lambda x: x.name, reverse=True) + versions = filter(lambda x: x.raw['released'] is False, versions) + + version_names = map(lambda x: x.name, versions) + default_fix_versions = map(lambda x: fix_version_from_branch(x, version_names), merge_branches) + default_fix_versions = filter(lambda x: x != None, default_fix_versions) + default_fix_versions = ",".join(default_fix_versions) + + fix_versions = raw_input("Enter comma-separated fix version(s) [%s]: " % default_fix_versions) + if fix_versions == "": + fix_versions = default_fix_versions + fix_versions = fix_versions.replace(" ", "").split(",") + + def get_version_json(version_str): + return filter(lambda v: v.name == version_str, versions)[0].raw + + jira_fix_versions = map(lambda v: get_version_json(v), fix_versions) + + resolve = filter(lambda a: a['name'] == "Resolve Issue", asf_jira.transitions(jira_id))[0] + resolution = filter(lambda r: r.raw['name'] == "Fixed", asf_jira.resolutions())[0] + asf_jira.transition_issue( + jira_id, resolve["id"], fixVersions = jira_fix_versions, + comment = comment, resolution = {'id': resolution.raw['id']}) + + print "Successfully resolved %s with fixVersions=%s!" % (jira_id, fix_versions) + + +def resolve_jira_issues(title, merge_branches, comment): + jira_ids = re.findall("%s-[0-9]{4,5}" % CAPITALIZED_PROJECT_NAME, title) + + if len(jira_ids) == 0: + resolve_jira_issue(merge_branches, comment) + for jira_id in jira_ids: + resolve_jira_issue(merge_branches, comment, jira_id) + + +def standardize_jira_ref(text): + """ + Standardize the jira reference commit message prefix to "PROJECT_NAME-XXX; Issue" + + >>> standardize_jira_ref("%s-5954; Top by key" % CAPITALIZED_PROJECT_NAME) + 'KAFKA-5954; Top by key' + >>> standardize_jira_ref("%s-5821; ParquetRelation2 CTAS should check if delete is successful" % PROJECT_NAME) + 'KAFKA-5821; ParquetRelation2 CTAS should check if delete is successful' + >>> standardize_jira_ref("%s-4123 [WIP] Show new dependencies added in pull requests" % PROJECT_NAME) + 'KAFKA-4123; [WIP] Show new dependencies added in pull requests' + >>> standardize_jira_ref("%s 5954: Top by key" % PROJECT_NAME) + 'KAFKA-5954; Top by key' + >>> standardize_jira_ref("%s-979 a LRU scheduler for load balancing in TaskSchedulerImpl" % PROJECT_NAME) + 'KAFKA-979; a LRU scheduler for load balancing in TaskSchedulerImpl' + >>> standardize_jira_ref("%s-1094 Support MiMa for reporting binary compatibility across versions." % CAPITALIZED_PROJECT_NAME) + 'KAFKA-1094; Support MiMa for reporting binary compatibility across versions.' + >>> standardize_jira_ref("[WIP] %s-1146; Vagrant support" % CAPITALIZED_PROJECT_NAME) + 'KAFKA-1146; [WIP] Vagrant support' + >>> standardize_jira_ref("%s-1032. If Yarn app fails before registering, app master stays aroun..." % PROJECT_NAME) + 'KAFKA-1032; If Yarn app fails before registering, app master stays aroun...' + >>> standardize_jira_ref("%s-6250 %s-6146 %s-5911: Types are now reserved words in DDL parser." 
% (PROJECT_NAME, PROJECT_NAME, CAPITALIZED_PROJECT_NAME)) + 'KAFKA-6250 KAFKA-6146 KAFKA-5911; Types are now reserved words in DDL parser.' + >>> standardize_jira_ref("Additional information for users building from source code") + 'Additional information for users building from source code' + """ + jira_refs = [] + components = [] + + # Extract JIRA ref(s): + pattern = re.compile(r'(%s[-\s]*[0-9]{3,6})+' % CAPITALIZED_PROJECT_NAME, re.IGNORECASE) + for ref in pattern.findall(text): + # Add brackets, replace spaces with a dash, & convert to uppercase + jira_refs.append(re.sub(r'\s+', '-', ref.upper())) + text = text.replace(ref, '') + + # Extract project name component(s): + # Look for alphanumeric chars, spaces, dashes, periods, and/or commas + pattern = re.compile(r'(\[[\w\s,-\.]+\])', re.IGNORECASE) + for component in pattern.findall(text): + components.append(component.upper()) + text = text.replace(component, '') + + # Cleanup any remaining symbols: + pattern = re.compile(r'^\W+(.*)', re.IGNORECASE) + if (pattern.search(text) is not None): + text = pattern.search(text).groups()[0] + + # Assemble full text (JIRA ref(s), module(s), remaining text) + jira_prefix = ' '.join(jira_refs).strip() + if jira_prefix: + jira_prefix = jira_prefix + "; " + clean_text = jira_prefix + ' '.join(components).strip() + " " + text.strip() + + # Replace multiple spaces with a single space, e.g. if no jira refs and/or components were included + clean_text = re.sub(r'\s+', ' ', clean_text.strip()) + + return clean_text + +def main(): + global original_head + + original_head = get_current_branch() + + branches = get_json("%s/branches" % GITHUB_API_BASE) + branch_names = filter(lambda x: x[0].isdigit(), [x['name'] for x in branches]) + # Assumes branch names can be sorted lexicographically + latest_branch = sorted(branch_names, reverse=True)[0] + + pr_num = raw_input("Which pull request would you like to merge? (e.g. 34): ") + pr = get_json("%s/pulls/%s" % (GITHUB_API_BASE, pr_num)) + pr_events = get_json("%s/issues/%s/events" % (GITHUB_API_BASE, pr_num)) + + url = pr["url"] + + pr_title = pr["title"] + commit_title = raw_input("Commit title [%s]: " % pr_title.encode("utf-8")).decode("utf-8") + if commit_title == "": + commit_title = pr_title + + # Decide whether to use the modified title or not + modified_title = standardize_jira_ref(commit_title) + if modified_title != commit_title: + print "I've re-written the title as follows to match the standard format:" + print "Original: %s" % commit_title + print "Modified: %s" % modified_title + result = raw_input("Would you like to use the modified title? (y/n): ") + if result.lower() == "y": + commit_title = modified_title + print "Using modified title:" + else: + print "Using original title:" + print commit_title + + body = pr["body"] + target_ref = pr["base"]["ref"] + user_login = pr["user"]["login"] + base_ref = pr["head"]["ref"] + pr_repo_desc = "%s/%s" % (user_login, base_ref) + + # Merged pull requests don't appear as merged in the GitHub API; + # Instead, they're closed by asfgit. 
+ merge_commits = \ + [e for e in pr_events if e["actor"]["login"] == "asfgit" and e["event"] == "closed"] + + if merge_commits: + merge_hash = merge_commits[0]["commit_id"] + message = get_json("%s/commits/%s" % (GITHUB_API_BASE, merge_hash))["commit"]["message"] + + print "Pull request %s has already been merged, assuming you want to backport" % pr_num + commit_is_downloaded = run_cmd(['git', 'rev-parse', '--quiet', '--verify', + "%s^{commit}" % merge_hash]).strip() != "" + if not commit_is_downloaded: + fail("Couldn't find any merge commit for #%s, you may need to update HEAD." % pr_num) + + print "Found commit %s:\n%s" % (merge_hash, message) + cherry_pick(pr_num, merge_hash, latest_branch) + sys.exit(0) + + if not bool(pr["mergeable"]): + msg = "Pull request %s is not mergeable in its current form.\n" % pr_num + \ + "Continue? (experts only!)" + continue_maybe(msg) + + print ("\n=== Pull Request #%s ===" % pr_num) + print ("PR title\t%s\nCommit title\t%s\nSource\t\t%s\nTarget\t\t%s\nURL\t\t%s" % ( + pr_title, commit_title, pr_repo_desc, target_ref, url)) + continue_maybe("Proceed with merging pull request #%s?" % pr_num) + + merged_refs = [target_ref] + + merge_hash = merge_pr(pr_num, target_ref, commit_title, body, pr_repo_desc) + + pick_prompt = "Would you like to pick %s into another branch?" % merge_hash + while raw_input("\n%s (y/n): " % pick_prompt).lower() == "y": + merged_refs = merged_refs + [cherry_pick(pr_num, merge_hash, latest_branch)] + + if JIRA_IMPORTED: + if JIRA_USERNAME and JIRA_PASSWORD: + continue_maybe("Would you like to update an associated JIRA?") + jira_comment = "Issue resolved by pull request %s\n[%s/%s]" % (pr_num, GITHUB_BASE, pr_num) + resolve_jira_issues(commit_title, merged_refs, jira_comment) + else: + print "JIRA_USERNAME and JIRA_PASSWORD not set" + print "Exiting without trying to close the associated JIRA." + else: + print "Could not find jira-python library. Run 'sudo pip install jira' to install." + print "Exiting without trying to close the associated JIRA." + +if __name__ == "__main__": + import doctest + (failure_count, test_count) = doctest.testmod() + if (failure_count): + exit(-1) + + main() diff --git a/licenses/CDDL+GPL-1.1 b/licenses/CDDL+GPL-1.1 new file mode 100644 index 0000000..4b156e6 --- /dev/null +++ b/licenses/CDDL+GPL-1.1 @@ -0,0 +1,760 @@ +COMMON DEVELOPMENT AND DISTRIBUTION LICENSE (CDDL) Version 1.1 + +1. Definitions. + + 1.1. "Contributor" means each individual or entity that creates or + contributes to the creation of Modifications. + + 1.2. "Contributor Version" means the combination of the Original + Software, prior Modifications used by a Contributor (if any), and + the Modifications made by that particular Contributor. + + 1.3. "Covered Software" means (a) the Original Software, or (b) + Modifications, or (c) the combination of files containing Original + Software with files containing Modifications, in each case including + portions thereof. + + 1.4. "Executable" means the Covered Software in any form other than + Source Code. + + 1.5. "Initial Developer" means the individual or entity that first + makes Original Software available under this License. + + 1.6. "Larger Work" means a work which combines Covered Software or + portions thereof with code not governed by the terms of this License. + + 1.7. "License" means this document. + + 1.8. 
"Licensable" means having the right to grant, to the maximum + extent possible, whether at the time of the initial grant or + subsequently acquired, any and all of the rights conveyed herein. + + 1.9. "Modifications" means the Source Code and Executable form of + any of the following: + + A. Any file that results from an addition to, deletion from or + modification of the contents of a file containing Original Software + or previous Modifications; + + B. Any new file that contains any part of the Original Software or + previous Modification; or + + C. Any new file that is contributed or otherwise made available + under the terms of this License. + + 1.10. "Original Software" means the Source Code and Executable form + of computer software code that is originally released under this + License. + + 1.11. "Patent Claims" means any patent claim(s), now owned or + hereafter acquired, including without limitation, method, process, + and apparatus claims, in any patent Licensable by grantor. + + 1.12. "Source Code" means (a) the common form of computer software + code in which modifications are made and (b) associated + documentation included in or with such code. + + 1.13. "You" (or "Your") means an individual or a legal entity + exercising rights under, and complying with all of the terms of, + this License. For legal entities, "You" includes any entity which + controls, is controlled by, or is under common control with You. For + purposes of this definition, "control" means (a) the power, direct + or indirect, to cause the direction or management of such entity, + whether by contract or otherwise, or (b) ownership of more than + fifty percent (50%) of the outstanding shares or beneficial + ownership of such entity. + +2. License Grants. + + 2.1. The Initial Developer Grant. + + Conditioned upon Your compliance with Section 3.1 below and subject + to third party intellectual property claims, the Initial Developer + hereby grants You a world-wide, royalty-free, non-exclusive license: + + (a) under intellectual property rights (other than patent or + trademark) Licensable by Initial Developer, to use, reproduce, + modify, display, perform, sublicense and distribute the Original + Software (or portions thereof), with or without Modifications, + and/or as part of a Larger Work; and + + (b) under Patent Claims infringed by the making, using or selling of + Original Software, to make, have made, use, practice, sell, and + offer for sale, and/or otherwise dispose of the Original Software + (or portions thereof). + + (c) The licenses granted in Sections 2.1(a) and (b) are effective on + the date Initial Developer first distributes or otherwise makes the + Original Software available to a third party under the terms of this + License. + + (d) Notwithstanding Section 2.1(b) above, no patent license is + granted: (1) for code that You delete from the Original Software, or + (2) for infringements caused by: (i) the modification of the + Original Software, or (ii) the combination of the Original Software + with other software or devices. + + 2.2. Contributor Grant. 
+ + Conditioned upon Your compliance with Section 3.1 below and subject + to third party intellectual property claims, each Contributor hereby + grants You a world-wide, royalty-free, non-exclusive license: + + (a) under intellectual property rights (other than patent or + trademark) Licensable by Contributor to use, reproduce, modify, + display, perform, sublicense and distribute the Modifications + created by such Contributor (or portions thereof), either on an + unmodified basis, with other Modifications, as Covered Software + and/or as part of a Larger Work; and + + (b) under Patent Claims infringed by the making, using, or selling + of Modifications made by that Contributor either alone and/or in + combination with its Contributor Version (or portions of such + combination), to make, use, sell, offer for sale, have made, and/or + otherwise dispose of: (1) Modifications made by that Contributor (or + portions thereof); and (2) the combination of Modifications made by + that Contributor with its Contributor Version (or portions of such + combination). + + (c) The licenses granted in Sections 2.2(a) and 2.2(b) are effective + on the date Contributor first distributes or otherwise makes the + Modifications available to a third party. + + (d) Notwithstanding Section 2.2(b) above, no patent license is + granted: (1) for any code that Contributor has deleted from the + Contributor Version; (2) for infringements caused by: (i) third + party modifications of Contributor Version, or (ii) the combination + of Modifications made by that Contributor with other software + (except as part of the Contributor Version) or other devices; or (3) + under Patent Claims infringed by Covered Software in the absence of + Modifications made by that Contributor. + +3. Distribution Obligations. + + 3.1. Availability of Source Code. + + Any Covered Software that You distribute or otherwise make available + in Executable form must also be made available in Source Code form + and that Source Code form must be distributed only under the terms + of this License. You must include a copy of this License with every + copy of the Source Code form of the Covered Software You distribute + or otherwise make available. You must inform recipients of any such + Covered Software in Executable form as to how they can obtain such + Covered Software in Source Code form in a reasonable manner on or + through a medium customarily used for software exchange. + + 3.2. Modifications. + + The Modifications that You create or to which You contribute are + governed by the terms of this License. You represent that You + believe Your Modifications are Your original creation(s) and/or You + have sufficient rights to grant the rights conveyed by this License. + + 3.3. Required Notices. + + You must include a notice in each of Your Modifications that + identifies You as the Contributor of the Modification. You may not + remove or alter any copyright, patent or trademark notices contained + within the Covered Software, or any notices of licensing or any + descriptive text giving attribution to any Contributor or the + Initial Developer. + + 3.4. Application of Additional Terms. + + You may not offer or impose any terms on any Covered Software in + Source Code form that alters or restricts the applicable version of + this License or the recipients' rights hereunder. You may choose to + offer, and to charge a fee for, warranty, support, indemnity or + liability obligations to one or more recipients of Covered Software. 
+ However, you may do so only on Your own behalf, and not on behalf of + the Initial Developer or any Contributor. You must make it + absolutely clear that any such warranty, support, indemnity or + liability obligation is offered by You alone, and You hereby agree + to indemnify the Initial Developer and every Contributor for any + liability incurred by the Initial Developer or such Contributor as a + result of warranty, support, indemnity or liability terms You offer. + + 3.5. Distribution of Executable Versions. + + You may distribute the Executable form of the Covered Software under + the terms of this License or under the terms of a license of Your + choice, which may contain terms different from this License, + provided that You are in compliance with the terms of this License + and that the license for the Executable form does not attempt to + limit or alter the recipient's rights in the Source Code form from + the rights set forth in this License. If You distribute the Covered + Software in Executable form under a different license, You must make + it absolutely clear that any terms which differ from this License + are offered by You alone, not by the Initial Developer or + Contributor. You hereby agree to indemnify the Initial Developer and + every Contributor for any liability incurred by the Initial + Developer or such Contributor as a result of any such terms You offer. + + 3.6. Larger Works. + + You may create a Larger Work by combining Covered Software with + other code not governed by the terms of this License and distribute + the Larger Work as a single product. In such a case, You must make + sure the requirements of this License are fulfilled for the Covered + Software. + +4. Versions of the License. + + 4.1. New Versions. + + Oracle is the initial license steward and may publish revised and/or + new versions of this License from time to time. Each version will be + given a distinguishing version number. Except as provided in Section + 4.3, no one other than the license steward has the right to modify + this License. + + 4.2. Effect of New Versions. + + You may always continue to use, distribute or otherwise make the + Covered Software available under the terms of the version of the + License under which You originally received the Covered Software. If + the Initial Developer includes a notice in the Original Software + prohibiting it from being distributed or otherwise made available + under any subsequent version of the License, You must distribute and + make the Covered Software available under the terms of the version + of the License under which You originally received the Covered + Software. Otherwise, You may also choose to use, distribute or + otherwise make the Covered Software available under the terms of any + subsequent version of the License published by the license steward. + + 4.3. Modified Versions. + + When You are an Initial Developer and You want to create a new + license for Your Original Software, You may create and use a + modified version of this License if You: (a) rename the license and + remove any references to the name of the license steward (except to + note that the license differs from this License); and (b) otherwise + make it clear that the license contains terms which differ from this + License. + +5. DISCLAIMER OF WARRANTY. 
+ + COVERED SOFTWARE IS PROVIDED UNDER THIS LICENSE ON AN "AS IS" BASIS, + WITHOUT WARRANTY OF ANY KIND, EITHER EXPRESSED OR IMPLIED, + INCLUDING, WITHOUT LIMITATION, WARRANTIES THAT THE COVERED SOFTWARE + IS FREE OF DEFECTS, MERCHANTABLE, FIT FOR A PARTICULAR PURPOSE OR + NON-INFRINGING. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF + THE COVERED SOFTWARE IS WITH YOU. SHOULD ANY COVERED SOFTWARE PROVE + DEFECTIVE IN ANY RESPECT, YOU (NOT THE INITIAL DEVELOPER OR ANY + OTHER CONTRIBUTOR) ASSUME THE COST OF ANY NECESSARY SERVICING, + REPAIR OR CORRECTION. THIS DISCLAIMER OF WARRANTY CONSTITUTES AN + ESSENTIAL PART OF THIS LICENSE. NO USE OF ANY COVERED SOFTWARE IS + AUTHORIZED HEREUNDER EXCEPT UNDER THIS DISCLAIMER. + +6. TERMINATION. + + 6.1. This License and the rights granted hereunder will terminate + automatically if You fail to comply with terms herein and fail to + cure such breach within 30 days of becoming aware of the breach. + Provisions which, by their nature, must remain in effect beyond the + termination of this License shall survive. + + 6.2. If You assert a patent infringement claim (excluding + declaratory judgment actions) against Initial Developer or a + Contributor (the Initial Developer or Contributor against whom You + assert such claim is referred to as "Participant") alleging that the + Participant Software (meaning the Contributor Version where the + Participant is a Contributor or the Original Software where the + Participant is the Initial Developer) directly or indirectly + infringes any patent, then any and all rights granted directly or + indirectly to You by such Participant, the Initial Developer (if the + Initial Developer is not the Participant) and all Contributors under + Sections 2.1 and/or 2.2 of this License shall, upon 60 days notice + from Participant terminate prospectively and automatically at the + expiration of such 60 day notice period, unless if within such 60 + day period You withdraw Your claim with respect to the Participant + Software against such Participant either unilaterally or pursuant to + a written agreement with Participant. + + 6.3. If You assert a patent infringement claim against Participant + alleging that the Participant Software directly or indirectly + infringes any patent where such claim is resolved (such as by + license or settlement) prior to the initiation of patent + infringement litigation, then the reasonable value of the licenses + granted by such Participant under Sections 2.1 or 2.2 shall be taken + into account in determining the amount or value of any payment or + license. + + 6.4. In the event of termination under Sections 6.1 or 6.2 above, + all end user licenses that have been validly granted by You or any + distributor hereunder prior to termination (excluding licenses + granted to You by any distributor) shall survive termination. + +7. LIMITATION OF LIABILITY. + + UNDER NO CIRCUMSTANCES AND UNDER NO LEGAL THEORY, WHETHER TORT + (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE, SHALL YOU, THE + INITIAL DEVELOPER, ANY OTHER CONTRIBUTOR, OR ANY DISTRIBUTOR OF + COVERED SOFTWARE, OR ANY SUPPLIER OF ANY OF SUCH PARTIES, BE LIABLE + TO ANY PERSON FOR ANY INDIRECT, SPECIAL, INCIDENTAL, OR + CONSEQUENTIAL DAMAGES OF ANY CHARACTER INCLUDING, WITHOUT + LIMITATION, DAMAGES FOR LOSS OF GOODWILL, WORK STOPPAGE, COMPUTER + FAILURE OR MALFUNCTION, OR ANY AND ALL OTHER COMMERCIAL DAMAGES OR + LOSSES, EVEN IF SUCH PARTY SHALL HAVE BEEN INFORMED OF THE + POSSIBILITY OF SUCH DAMAGES. 
THIS LIMITATION OF LIABILITY SHALL NOT + APPLY TO LIABILITY FOR DEATH OR PERSONAL INJURY RESULTING FROM SUCH + PARTY'S NEGLIGENCE TO THE EXTENT APPLICABLE LAW PROHIBITS SUCH + LIMITATION. SOME JURISDICTIONS DO NOT ALLOW THE EXCLUSION OR + LIMITATION OF INCIDENTAL OR CONSEQUENTIAL DAMAGES, SO THIS EXCLUSION + AND LIMITATION MAY NOT APPLY TO YOU. + +8. U.S. GOVERNMENT END USERS. + + The Covered Software is a "commercial item," as that term is defined + in 48 C.F.R. 2.101 (Oct. 1995), consisting of "commercial computer + software" (as that term is defined at 48 C.F.R. § + 252.227-7014(a)(1)) and "commercial computer software documentation" + as such terms are used in 48 C.F.R. 12.212 (Sept. 1995). Consistent + with 48 C.F.R. 12.212 and 48 C.F.R. 227.7202-1 through 227.7202-4 + (June 1995), all U.S. Government End Users acquire Covered Software + with only those rights set forth herein. This U.S. Government Rights + clause is in lieu of, and supersedes, any other FAR, DFAR, or other + clause or provision that addresses Government rights in computer + software under this License. + +9. MISCELLANEOUS. + + This License represents the complete agreement concerning subject + matter hereof. If any provision of this License is held to be + unenforceable, such provision shall be reformed only to the extent + necessary to make it enforceable. This License shall be governed by + the law of the jurisdiction specified in a notice contained within + the Original Software (except to the extent applicable law, if any, + provides otherwise), excluding such jurisdiction's conflict-of-law + provisions. Any litigation relating to this License shall be subject + to the jurisdiction of the courts located in the jurisdiction and + venue specified in a notice contained within the Original Software, + with the losing party responsible for costs, including, without + limitation, court costs and reasonable attorneys' fees and expenses. + The application of the United Nations Convention on Contracts for + the International Sale of Goods is expressly excluded. Any law or + regulation which provides that the language of a contract shall be + construed against the drafter shall not apply to this License. You + agree that You alone are responsible for compliance with the United + States export administration regulations (and the export control + laws and regulation of any other countries) when You use, distribute + or otherwise make available any Covered Software. + +10. RESPONSIBILITY FOR CLAIMS. + + As between Initial Developer and the Contributors, each party is + responsible for claims and damages arising, directly or indirectly, + out of its utilization of rights under this License and You agree to + work with Initial Developer and Contributors to distribute such + responsibility on an equitable basis. Nothing herein is intended or + shall be deemed to constitute any admission of liability. + +------------------------------------------------------------------------ + +NOTICE PURSUANT TO SECTION 9 OF THE COMMON DEVELOPMENT AND DISTRIBUTION +LICENSE (CDDL) + +The code released under the CDDL shall be governed by the laws of the +State of California (excluding conflict-of-law provisions). Any +litigation relating to this License shall be subject to the jurisdiction +of the Federal Courts of the Northern District of California and the +state courts of the State of California, with venue lying in Santa Clara +County, California. 
+ + + + The GNU General Public License (GPL) Version 2, June 1991 + +Copyright (C) 1989, 1991 Free Software Foundation, Inc. +51 Franklin Street, Fifth Floor +Boston, MA 02110-1335 +USA + +Everyone is permitted to copy and distribute verbatim copies +of this license document, but changing it is not allowed. + +Preamble + +The licenses for most software are designed to take away your freedom to +share and change it. By contrast, the GNU General Public License is +intended to guarantee your freedom to share and change free software--to +make sure the software is free for all its users. This General Public +License applies to most of the Free Software Foundation's software and +to any other program whose authors commit to using it. (Some other Free +Software Foundation software is covered by the GNU Library General +Public License instead.) You can apply it to your programs, too. + +When we speak of free software, we are referring to freedom, not price. +Our General Public Licenses are designed to make sure that you have the +freedom to distribute copies of free software (and charge for this +service if you wish), that you receive source code or can get it if you +want it, that you can change the software or use pieces of it in new +free programs; and that you know you can do these things. + +To protect your rights, we need to make restrictions that forbid anyone +to deny you these rights or to ask you to surrender the rights. These +restrictions translate to certain responsibilities for you if you +distribute copies of the software, or if you modify it. + +For example, if you distribute copies of such a program, whether gratis +or for a fee, you must give the recipients all the rights that you have. +You must make sure that they, too, receive or can get the source code. +And you must show them these terms so they know their rights. + +We protect your rights with two steps: (1) copyright the software, and +(2) offer you this license which gives you legal permission to copy, +distribute and/or modify the software. + +Also, for each author's protection and ours, we want to make certain +that everyone understands that there is no warranty for this free +software. If the software is modified by someone else and passed on, we +want its recipients to know that what they have is not the original, so +that any problems introduced by others will not reflect on the original +authors' reputations. + +Finally, any free program is threatened constantly by software patents. +We wish to avoid the danger that redistributors of a free program will +individually obtain patent licenses, in effect making the program +proprietary. To prevent this, we have made it clear that any patent must +be licensed for everyone's free use or not licensed at all. + +The precise terms and conditions for copying, distribution and +modification follow. + +TERMS AND CONDITIONS FOR COPYING, DISTRIBUTION AND MODIFICATION + +0. This License applies to any program or other work which contains a +notice placed by the copyright holder saying it may be distributed under +the terms of this General Public License. The "Program", below, refers +to any such program or work, and a "work based on the Program" means +either the Program or any derivative work under copyright law: that is +to say, a work containing the Program or a portion of it, either +verbatim or with modifications and/or translated into another language. +(Hereinafter, translation is included without limitation in the term +"modification".) 
Each licensee is addressed as "you". + +Activities other than copying, distribution and modification are not +covered by this License; they are outside its scope. The act of running +the Program is not restricted, and the output from the Program is +covered only if its contents constitute a work based on the Program +(independent of having been made by running the Program). Whether that +is true depends on what the Program does. + +1. You may copy and distribute verbatim copies of the Program's source +code as you receive it, in any medium, provided that you conspicuously +and appropriately publish on each copy an appropriate copyright notice +and disclaimer of warranty; keep intact all the notices that refer to +this License and to the absence of any warranty; and give any other +recipients of the Program a copy of this License along with the Program. + +You may charge a fee for the physical act of transferring a copy, and +you may at your option offer warranty protection in exchange for a fee. + +2. You may modify your copy or copies of the Program or any portion of +it, thus forming a work based on the Program, and copy and distribute +such modifications or work under the terms of Section 1 above, provided +that you also meet all of these conditions: + + a) You must cause the modified files to carry prominent notices + stating that you changed the files and the date of any change. + + b) You must cause any work that you distribute or publish, that in + whole or in part contains or is derived from the Program or any part + thereof, to be licensed as a whole at no charge to all third parties + under the terms of this License. + + c) If the modified program normally reads commands interactively + when run, you must cause it, when started running for such + interactive use in the most ordinary way, to print or display an + announcement including an appropriate copyright notice and a notice + that there is no warranty (or else, saying that you provide a + warranty) and that users may redistribute the program under these + conditions, and telling the user how to view a copy of this License. + (Exception: if the Program itself is interactive but does not + normally print such an announcement, your work based on the Program + is not required to print an announcement.) + +These requirements apply to the modified work as a whole. If +identifiable sections of that work are not derived from the Program, and +can be reasonably considered independent and separate works in +themselves, then this License, and its terms, do not apply to those +sections when you distribute them as separate works. But when you +distribute the same sections as part of a whole which is a work based on +the Program, the distribution of the whole must be on the terms of this +License, whose permissions for other licensees extend to the entire +whole, and thus to each and every part regardless of who wrote it. + +Thus, it is not the intent of this section to claim rights or contest +your rights to work written entirely by you; rather, the intent is to +exercise the right to control the distribution of derivative or +collective works based on the Program. + +In addition, mere aggregation of another work not based on the Program +with the Program (or with a work based on the Program) on a volume of a +storage or distribution medium does not bring the other work under the +scope of this License. + +3. 
You may copy and distribute the Program (or a work based on it, +under Section 2) in object code or executable form under the terms of +Sections 1 and 2 above provided that you also do one of the following: + + a) Accompany it with the complete corresponding machine-readable + source code, which must be distributed under the terms of Sections 1 + and 2 above on a medium customarily used for software interchange; or, + + b) Accompany it with a written offer, valid for at least three + years, to give any third party, for a charge no more than your cost + of physically performing source distribution, a complete + machine-readable copy of the corresponding source code, to be + distributed under the terms of Sections 1 and 2 above on a medium + customarily used for software interchange; or, + + c) Accompany it with the information you received as to the offer to + distribute corresponding source code. (This alternative is allowed + only for noncommercial distribution and only if you received the + program in object code or executable form with such an offer, in + accord with Subsection b above.) + +The source code for a work means the preferred form of the work for +making modifications to it. For an executable work, complete source code +means all the source code for all modules it contains, plus any +associated interface definition files, plus the scripts used to control +compilation and installation of the executable. However, as a special +exception, the source code distributed need not include anything that is +normally distributed (in either source or binary form) with the major +components (compiler, kernel, and so on) of the operating system on +which the executable runs, unless that component itself accompanies the +executable. + +If distribution of executable or object code is made by offering access +to copy from a designated place, then offering equivalent access to copy +the source code from the same place counts as distribution of the source +code, even though third parties are not compelled to copy the source +along with the object code. + +4. You may not copy, modify, sublicense, or distribute the Program +except as expressly provided under this License. Any attempt otherwise +to copy, modify, sublicense or distribute the Program is void, and will +automatically terminate your rights under this License. However, parties +who have received copies, or rights, from you under this License will +not have their licenses terminated so long as such parties remain in +full compliance. + +5. You are not required to accept this License, since you have not +signed it. However, nothing else grants you permission to modify or +distribute the Program or its derivative works. These actions are +prohibited by law if you do not accept this License. Therefore, by +modifying or distributing the Program (or any work based on the +Program), you indicate your acceptance of this License to do so, and all +its terms and conditions for copying, distributing or modifying the +Program or works based on it. + +6. Each time you redistribute the Program (or any work based on the +Program), the recipient automatically receives a license from the +original licensor to copy, distribute or modify the Program subject to +these terms and conditions. You may not impose any further restrictions +on the recipients' exercise of the rights granted herein. You are not +responsible for enforcing compliance by third parties to this License. + +7. 
If, as a consequence of a court judgment or allegation of patent +infringement or for any other reason (not limited to patent issues), +conditions are imposed on you (whether by court order, agreement or +otherwise) that contradict the conditions of this License, they do not +excuse you from the conditions of this License. If you cannot distribute +so as to satisfy simultaneously your obligations under this License and +any other pertinent obligations, then as a consequence you may not +distribute the Program at all. For example, if a patent license would +not permit royalty-free redistribution of the Program by all those who +receive copies directly or indirectly through you, then the only way you +could satisfy both it and this License would be to refrain entirely from +distribution of the Program. + +If any portion of this section is held invalid or unenforceable under +any particular circumstance, the balance of the section is intended to +apply and the section as a whole is intended to apply in other +circumstances. + +It is not the purpose of this section to induce you to infringe any +patents or other property right claims or to contest validity of any +such claims; this section has the sole purpose of protecting the +integrity of the free software distribution system, which is implemented +by public license practices. Many people have made generous +contributions to the wide range of software distributed through that +system in reliance on consistent application of that system; it is up to +the author/donor to decide if he or she is willing to distribute +software through any other system and a licensee cannot impose that choice. + +This section is intended to make thoroughly clear what is believed to be +a consequence of the rest of this License. + +8. If the distribution and/or use of the Program is restricted in +certain countries either by patents or by copyrighted interfaces, the +original copyright holder who places the Program under this License may +add an explicit geographical distribution limitation excluding those +countries, so that distribution is permitted only in or among countries +not thus excluded. In such case, this License incorporates the +limitation as if written in the body of this License. + +9. The Free Software Foundation may publish revised and/or new +versions of the General Public License from time to time. Such new +versions will be similar in spirit to the present version, but may +differ in detail to address new problems or concerns. + +Each version is given a distinguishing version number. If the Program +specifies a version number of this License which applies to it and "any +later version", you have the option of following the terms and +conditions either of that version or of any later version published by +the Free Software Foundation. If the Program does not specify a version +number of this License, you may choose any version ever published by the +Free Software Foundation. + +10. If you wish to incorporate parts of the Program into other free +programs whose distribution conditions are different, write to the +author to ask for permission. For software which is copyrighted by the +Free Software Foundation, write to the Free Software Foundation; we +sometimes make exceptions for this. Our decision will be guided by the +two goals of preserving the free status of all derivatives of our free +software and of promoting the sharing and reuse of software generally. + +NO WARRANTY + +11. 
BECAUSE THE PROGRAM IS LICENSED FREE OF CHARGE, THERE IS NO +WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY APPLICABLE LAW. +EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT HOLDERS AND/OR +OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY OF ANY KIND, +EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE. THE +ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM IS WITH +YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF ALL +NECESSARY SERVICING, REPAIR OR CORRECTION. + +12. IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN +WRITING WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MAY MODIFY +AND/OR REDISTRIBUTE THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR +DAMAGES, INCLUDING ANY GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL +DAMAGES ARISING OUT OF THE USE OR INABILITY TO USE THE PROGRAM +(INCLUDING BUT NOT LIMITED TO LOSS OF DATA OR DATA BEING RENDERED +INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD PARTIES OR A FAILURE OF +THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), EVEN IF SUCH HOLDER OR +OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES. + +END OF TERMS AND CONDITIONS + +How to Apply These Terms to Your New Programs + +If you develop a new program, and you want it to be of the greatest +possible use to the public, the best way to achieve this is to make it +free software which everyone can redistribute and change under these terms. + +To do so, attach the following notices to the program. It is safest to +attach them to the start of each source file to most effectively convey +the exclusion of warranty; and each file should have at least the +"copyright" line and a pointer to where the full notice is found. + + One line to give the program's name and a brief idea of what it does. + Copyright (C) + + This program is free software; you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation; either version 2 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, but + WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program; if not, write to the Free Software + Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1335 USA + +Also add information on how to contact you by electronic and paper mail. + +If the program is interactive, make it output a short notice like this +when it starts in an interactive mode: + + Gnomovision version 69, Copyright (C) year name of author + Gnomovision comes with ABSOLUTELY NO WARRANTY; for details type + `show w'. This is free software, and you are welcome to redistribute + it under certain conditions; type `show c' for details. + +The hypothetical commands `show w' and `show c' should show the +appropriate parts of the General Public License. Of course, the commands +you use may be called something other than `show w' and `show c'; they +could even be mouse-clicks or menu items--whatever suits your program. + +You should also get your employer (if you work as a programmer) or your +school, if any, to sign a "copyright disclaimer" for the program, if +necessary. 
Here is a sample; alter the names: + + Yoyodyne, Inc., hereby disclaims all copyright interest in the + program `Gnomovision' (which makes passes at compilers) written by + James Hacker. + + signature of Ty Coon, 1 April 1989 + Ty Coon, President of Vice + +This General Public License does not permit incorporating your program +into proprietary programs. If your program is a subroutine library, you +may consider it more useful to permit linking proprietary applications +with the library. If this is what you want to do, use the GNU Library +General Public License instead of this License. + +# + +Certain source files distributed by Oracle America, Inc. and/or its +affiliates are subject to the following clarification and special +exception to the GPLv2, based on the GNU Project exception for its +Classpath libraries, known as the GNU Classpath Exception, but only +where Oracle has expressly included in the particular source file's +header the words "Oracle designates this particular file as subject to +the "Classpath" exception as provided by Oracle in the LICENSE file +that accompanied this code." + +You should also note that Oracle includes multiple, independent +programs in this software package. Some of those programs are provided +under licenses deemed incompatible with the GPLv2 by the Free Software +Foundation and others. For example, the package includes programs +licensed under the Apache License, Version 2.0. Such programs are +licensed to you under their original licenses. + +Oracle facilitates your further distribution of this package by adding +the Classpath Exception to the necessary parts of its GPLv2 code, which +permits you to use that code in combination with other independent +modules not licensed under the GPLv2. However, note that this would +not permit you to commingle code under an incompatible license with +Oracle's GPLv2 licensed code by, for example, cutting and pasting such +code into a file also containing Oracle's GPLv2 licensed code and then +distributing the result. Additionally, if you were to remove the +Classpath Exception from any of the files to which it applies and +distribute the result, you would likely be required to license some or +all of the other code in that distribution under the GPLv2 as well, and +since the GPLv2 is incompatible with the license terms of some items +included in the distribution by Oracle, removing the Classpath +Exception could therefore effectively compromise your ability to +further distribute the package. + +Proceed with caution and we recommend that you obtain the advice of a +lawyer skilled in open source matters before removing the Classpath +Exception or making modifications to this package which may +subsequently be redistributed and/or involve the use of third party +software. + +CLASSPATH EXCEPTION +Linking this library statically or dynamically with other modules is +making a combined work based on this library. Thus, the terms and +conditions of the GNU General Public License version 2 cover the whole +combination. + +As a special exception, the copyright holders of this library give you +permission to link this library with independent modules to produce an +executable, regardless of the license terms of these independent +modules, and to copy and distribute the resulting executable under +terms of your choice, provided that you also meet, for each linked +independent module, the terms and conditions of the license of that +module. 
An independent module is a module which is not derived from or +based on this library. If you modify this library, you may extend this +exception to your version of the library, but you are not obligated to +do so. If you do not wish to do so, delete this exception statement +from your version. + diff --git a/licenses/DWTFYWTPL b/licenses/DWTFYWTPL new file mode 100644 index 0000000..5a8e332 --- /dev/null +++ b/licenses/DWTFYWTPL @@ -0,0 +1,14 @@ + DO WHAT THE FUCK YOU WANT TO PUBLIC LICENSE + Version 2, December 2004 + + Copyright (C) 2004 Sam Hocevar + + Everyone is permitted to copy and distribute verbatim or modified + copies of this license document, and changing it is allowed as long + as the name is changed. + + DO WHAT THE FUCK YOU WANT TO PUBLIC LICENSE + TERMS AND CONDITIONS FOR COPYING, DISTRIBUTION AND MODIFICATION + + 0. You just DO WHAT THE FUCK YOU WANT TO. + diff --git a/licenses/argparse-MIT b/licenses/argparse-MIT new file mode 100644 index 0000000..773b0df --- /dev/null +++ b/licenses/argparse-MIT @@ -0,0 +1,23 @@ +/* + * Copyright (C) 2011-2017 Tatsuhiro Tsujikawa + * + * Permission is hereby granted, free of charge, to any person + * obtaining a copy of this software and associated documentation + * files (the "Software"), to deal in the Software without + * restriction, including without limitation the rights to use, copy, + * modify, merge, publish, distribute, sublicense, and/or sell copies + * of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be + * included in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND + * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS + * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN + * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN + * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ diff --git a/licenses/eclipse-distribution-license-1.0 b/licenses/eclipse-distribution-license-1.0 new file mode 100644 index 0000000..5f06513 --- /dev/null +++ b/licenses/eclipse-distribution-license-1.0 @@ -0,0 +1,13 @@ +Eclipse Distribution License - v 1.0 + +Copyright (c) 2007, Eclipse Foundation, Inc. and its licensors. + +All rights reserved. + +Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. +* Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. +* Neither the name of the Eclipse Foundation, Inc. nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/licenses/eclipse-public-license-2.0 b/licenses/eclipse-public-license-2.0 new file mode 100644 index 0000000..c9f1425 --- /dev/null +++ b/licenses/eclipse-public-license-2.0 @@ -0,0 +1,87 @@ +Eclipse Public License - v 2.0 + +THE ACCOMPANYING PROGRAM IS PROVIDED UNDER THE TERMS OF THIS ECLIPSE PUBLIC LICENSE (“AGREEMENT”). ANY USE, REPRODUCTION OR DISTRIBUTION OF THE PROGRAM CONSTITUTES RECIPIENT'S ACCEPTANCE OF THIS AGREEMENT. +1. DEFINITIONS + +“Contribution” means: + + a) in the case of the initial Contributor, the initial content Distributed under this Agreement, and + b) in the case of each subsequent Contributor: + i) changes to the Program, and + ii) additions to the Program; + where such changes and/or additions to the Program originate from and are Distributed by that particular Contributor. A Contribution “originates” from a Contributor if it was added to the Program by such Contributor itself or anyone acting on such Contributor's behalf. Contributions do not include changes or additions to the Program that are not Modified Works. + +“Contributor” means any person or entity that Distributes the Program. + +“Licensed Patents” mean patent claims licensable by a Contributor which are necessarily infringed by the use or sale of its Contribution alone or when combined with the Program. + +“Program” means the Contributions Distributed in accordance with this Agreement. + +“Recipient” means anyone who receives the Program under this Agreement or any Secondary License (as applicable), including Contributors. + +“Derivative Works” shall mean any work, whether in Source Code or other form, that is based on (or derived from) the Program and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. + +“Modified Works” shall mean any work in Source Code or other form that results from an addition to, deletion from, or modification of the contents of the Program, including, for purposes of clarity any new file in Source Code form that contains any contents of the Program. Modified Works shall not include works that contain only declarations, interfaces, types, classes, structures, or files of the Program solely in each case in order to link to, bind by name, or subclass the Program or Modified Works thereof. + +“Distribute” means the acts of a) distributing or b) making available in any manner that enables the transfer of a copy. + +“Source Code” means the form of a Program preferred for making modifications, including but not limited to software source code, documentation source, and configuration files. + +“Secondary License” means either the GNU General Public License, Version 2.0, or any later versions of that license, including any exceptions or additional permissions as identified by the initial Contributor. +2. 
GRANT OF RIGHTS + + a) Subject to the terms of this Agreement, each Contributor hereby grants Recipient a non-exclusive, worldwide, royalty-free copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, Distribute and sublicense the Contribution of such Contributor, if any, and such Derivative Works. + b) Subject to the terms of this Agreement, each Contributor hereby grants Recipient a non-exclusive, worldwide, royalty-free patent license under Licensed Patents to make, use, sell, offer to sell, import and otherwise transfer the Contribution of such Contributor, if any, in Source Code or other form. This patent license shall apply to the combination of the Contribution and the Program if, at the time the Contribution is added by the Contributor, such addition of the Contribution causes such combination to be covered by the Licensed Patents. The patent license shall not apply to any other combinations which include the Contribution. No hardware per se is licensed hereunder. + c) Recipient understands that although each Contributor grants the licenses to its Contributions set forth herein, no assurances are provided by any Contributor that the Program does not infringe the patent or other intellectual property rights of any other entity. Each Contributor disclaims any liability to Recipient for claims brought by any other entity based on infringement of intellectual property rights or otherwise. As a condition to exercising the rights and licenses granted hereunder, each Recipient hereby assumes sole responsibility to secure any other intellectual property rights needed, if any. For example, if a third party patent license is required to allow Recipient to Distribute the Program, it is Recipient's responsibility to acquire that license before distributing the Program. + d) Each Contributor represents that to its knowledge it has sufficient copyright rights in its Contribution, if any, to grant the copyright license set forth in this Agreement. + e) Notwithstanding the terms of any Secondary License, no Contributor makes additional grants to any Recipient (other than those set forth in this Agreement) as a result of such Recipient's receipt of the Program under the terms of a Secondary License (if permitted under the terms of Section 3). + +3. 
REQUIREMENTS + +3.1 If a Contributor Distributes the Program in any form, then: + + a) the Program must also be made available as Source Code, in accordance with section 3.2, and the Contributor must accompany the Program with a statement that the Source Code for the Program is available under this Agreement, and informs Recipients how to obtain it in a reasonable manner on or through a medium customarily used for software exchange; and + b) the Contributor may Distribute the Program under a license different than this Agreement, provided that such license: + i) effectively disclaims on behalf of all other Contributors all warranties and conditions, express and implied, including warranties or conditions of title and non-infringement, and implied warranties or conditions of merchantability and fitness for a particular purpose; + ii) effectively excludes on behalf of all other Contributors all liability for damages, including direct, indirect, special, incidental and consequential damages, such as lost profits; + iii) does not attempt to limit or alter the recipients' rights in the Source Code under section 3.2; and + iv) requires any subsequent distribution of the Program by any party to be under a license that satisfies the requirements of this section 3. + +3.2 When the Program is Distributed as Source Code: + + a) it must be made available under this Agreement, or if the Program (i) is combined with other material in a separate file or files made available under a Secondary License, and (ii) the initial Contributor attached to the Source Code the notice described in Exhibit A of this Agreement, then the Program may be made available under the terms of such Secondary Licenses, and + b) a copy of this Agreement must be included with each copy of the Program. + +3.3 Contributors may not remove or alter any copyright, patent, trademark, attribution notices, disclaimers of warranty, or limitations of liability (‘notices’) contained within the Program from any copy of the Program which they Distribute, provided that Contributors may add their own appropriate notices. +4. COMMERCIAL DISTRIBUTION + +Commercial distributors of software may accept certain responsibilities with respect to end users, business partners and the like. While this license is intended to facilitate the commercial use of the Program, the Contributor who includes the Program in a commercial product offering should do so in a manner which does not create potential liability for other Contributors. Therefore, if a Contributor includes the Program in a commercial product offering, such Contributor (“Commercial Contributor”) hereby agrees to defend and indemnify every other Contributor (“Indemnified Contributor”) against any losses, damages and costs (collectively “Losses”) arising from claims, lawsuits and other legal actions brought by a third party against the Indemnified Contributor to the extent caused by the acts or omissions of such Commercial Contributor in connection with its distribution of the Program in a commercial product offering. The obligations in this section do not apply to any claims or Losses relating to any actual or alleged intellectual property infringement. In order to qualify, an Indemnified Contributor must: a) promptly notify the Commercial Contributor in writing of such claim, and b) allow the Commercial Contributor to control, and cooperate with the Commercial Contributor in, the defense and any related settlement negotiations. 
The Indemnified Contributor may participate in any such claim at its own expense. + +For example, a Contributor might include the Program in a commercial product offering, Product X. That Contributor is then a Commercial Contributor. If that Commercial Contributor then makes performance claims, or offers warranties related to Product X, those performance claims and warranties are such Commercial Contributor's responsibility alone. Under this section, the Commercial Contributor would have to defend claims against the other Contributors related to those performance claims and warranties, and if a court requires any other Contributor to pay any damages as a result, the Commercial Contributor must pay those damages. +5. NO WARRANTY + +EXCEPT AS EXPRESSLY SET FORTH IN THIS AGREEMENT, AND TO THE EXTENT PERMITTED BY APPLICABLE LAW, THE PROGRAM IS PROVIDED ON AN “AS IS” BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER EXPRESS OR IMPLIED INCLUDING, WITHOUT LIMITATION, ANY WARRANTIES OR CONDITIONS OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY OR FITNESS FOR A PARTICULAR PURPOSE. Each Recipient is solely responsible for determining the appropriateness of using and distributing the Program and assumes all risks associated with its exercise of rights under this Agreement, including but not limited to the risks and costs of program errors, compliance with applicable laws, damage to or loss of data, programs or equipment, and unavailability or interruption of operations. +6. DISCLAIMER OF LIABILITY + +EXCEPT AS EXPRESSLY SET FORTH IN THIS AGREEMENT, AND TO THE EXTENT PERMITTED BY APPLICABLE LAW, NEITHER RECIPIENT NOR ANY CONTRIBUTORS SHALL HAVE ANY LIABILITY FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING WITHOUT LIMITATION LOST PROFITS), HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OR DISTRIBUTION OF THE PROGRAM OR THE EXERCISE OF ANY RIGHTS GRANTED HEREUNDER, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGES. +7. GENERAL + +If any provision of this Agreement is invalid or unenforceable under applicable law, it shall not affect the validity or enforceability of the remainder of the terms of this Agreement, and without further action by the parties hereto, such provision shall be reformed to the minimum extent necessary to make such provision valid and enforceable. + +If Recipient institutes patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Program itself (excluding combinations of the Program with other software or hardware) infringes such Recipient's patent(s), then such Recipient's rights granted under Section 2(b) shall terminate as of the date such litigation is filed. + +All Recipient's rights under this Agreement shall terminate if it fails to comply with any of the material terms or conditions of this Agreement and does not cure such failure in a reasonable period of time after becoming aware of such noncompliance. If all Recipient's rights under this Agreement terminate, Recipient agrees to cease use and distribution of the Program as soon as reasonably practicable. However, Recipient's obligations under this Agreement and any licenses granted by Recipient relating to the Program shall continue and survive. 
+ +Everyone is permitted to copy and distribute copies of this Agreement, but in order to avoid inconsistency the Agreement is copyrighted and may only be modified in the following manner. The Agreement Steward reserves the right to publish new versions (including revisions) of this Agreement from time to time. No one other than the Agreement Steward has the right to modify this Agreement. The Eclipse Foundation is the initial Agreement Steward. The Eclipse Foundation may assign the responsibility to serve as the Agreement Steward to a suitable separate entity. Each new version of the Agreement will be given a distinguishing version number. The Program (including Contributions) may always be Distributed subject to the version of the Agreement under which it was received. In addition, after a new version of the Agreement is published, Contributor may elect to Distribute the Program (including its Contributions) under the new version. + +Except as expressly stated in Sections 2(a) and 2(b) above, Recipient receives no rights or licenses to the intellectual property of any Contributor under this Agreement, whether expressly, by implication, estoppel or otherwise. All rights in the Program not expressly granted under this Agreement are reserved. Nothing in this Agreement is intended to be enforceable by any entity that is not a Contributor or Recipient. No third-party beneficiary rights are created under this Agreement. +Exhibit A – Form of Secondary Licenses Notice + +“This Source Code may also be made available under the following Secondary Licenses when the conditions for such availability set forth in the Eclipse Public License, v. 2.0 are satisfied: {name license(s), version(s), and exceptions or additional permissions here}.” + + Simply including a copy of this Agreement, including this Exhibit A is not sufficient to license the Source Code under Secondary Licenses. + + If it is not possible or desirable to put the notice in a particular file, then You may include the notice in a location (such as a LICENSE file in a relevant directory) where a recipient would be likely to look for such a notice. + + You may add additional accurate notices of copyright ownership. + diff --git a/licenses/jline-BSD-3-clause b/licenses/jline-BSD-3-clause new file mode 100644 index 0000000..7e11b67 --- /dev/null +++ b/licenses/jline-BSD-3-clause @@ -0,0 +1,35 @@ +Copyright (c) 2002-2018, the original author or authors. +All rights reserved. + +https://opensource.org/licenses/BSD-3-Clause + +Redistribution and use in source and binary forms, with or +without modification, are permitted provided that the following +conditions are met: + +Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + +Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with +the distribution. + +Neither the name of JLine nor the names of its contributors +may be used to endorse or promote products derived from this +software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, +BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY +AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO +EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, +OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED +AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT +LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING +IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED +OF THE POSSIBILITY OF SUCH DAMAGE. + diff --git a/licenses/jopt-simple-MIT b/licenses/jopt-simple-MIT new file mode 100644 index 0000000..54b2732 --- /dev/null +++ b/licenses/jopt-simple-MIT @@ -0,0 +1,24 @@ +/* + The MIT License + + Copyright (c) 2004-2016 Paul R. Holser, Jr. + + Permission is hereby granted, free of charge, to any person obtaining + a copy of this software and associated documentation files (the + "Software"), to deal in the Software without restriction, including + without limitation the rights to use, copy, modify, merge, publish, + distribute, sublicense, and/or sell copies of the Software, and to + permit persons to whom the Software is furnished to do so, subject to + the following conditions: + + The above copyright notice and this permission notice shall be + included in all copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND + NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE + LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION + OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION + WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +*/ diff --git a/licenses/paranamer-BSD-3-clause b/licenses/paranamer-BSD-3-clause new file mode 100644 index 0000000..9eab879 --- /dev/null +++ b/licenses/paranamer-BSD-3-clause @@ -0,0 +1,29 @@ +[ ParaNamer used to be 'Pubic Domain', but since it includes a small piece of ASM it is now the same license as that: BSD ] + + Portions copyright (c) 2006-2018 Paul Hammant & ThoughtWorks Inc + Portions copyright (c) 2000-2007 INRIA, France Telecom + All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions + are met: + 1. Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + 2. Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + 3. Neither the name of the copyright holders nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE + LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF + THE POSSIBILITY OF SUCH DAMAGE. diff --git a/licenses/slf4j-MIT b/licenses/slf4j-MIT new file mode 100644 index 0000000..315bd49 --- /dev/null +++ b/licenses/slf4j-MIT @@ -0,0 +1,24 @@ +Copyright (c) 2004-2017 QOS.ch +All rights reserved. + +Permission is hereby granted, free of charge, to any person obtaining +a copy of this software and associated documentation files (the +"Software"), to deal in the Software without restriction, including +without limitation the rights to use, copy, modify, merge, publish, +distribute, sublicense, and/or sell copies of the Software, and to +permit persons to whom the Software is furnished to do so, subject to +the following conditions: + +The above copyright notice and this permission notice shall be +included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE +LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION +WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + + + diff --git a/licenses/zstd-jni-BSD-2-clause b/licenses/zstd-jni-BSD-2-clause new file mode 100644 index 0000000..66abb8a --- /dev/null +++ b/licenses/zstd-jni-BSD-2-clause @@ -0,0 +1,26 @@ +Zstd-jni: JNI bindings to Zstd Library + +Copyright (c) 2015-present, Luben Karavelov/ All rights reserved. + +BSD License + +Redistribution and use in source and binary forms, with or without modification, +are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, this + list of conditions and the following disclaimer in the documentation and/or + other materials provided with the distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR +ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON +ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
diff --git a/log4j-appender/src/main/java/org/apache/kafka/log4jappender/KafkaLog4jAppender.java b/log4j-appender/src/main/java/org/apache/kafka/log4jappender/KafkaLog4jAppender.java new file mode 100644 index 0000000..23272a2 --- /dev/null +++ b/log4j-appender/src/main/java/org/apache/kafka/log4jappender/KafkaLog4jAppender.java @@ -0,0 +1,377 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.log4jappender; + +import org.apache.kafka.clients.producer.KafkaProducer; +import org.apache.kafka.clients.producer.Producer; +import org.apache.kafka.clients.producer.ProducerRecord; +import org.apache.kafka.clients.producer.RecordMetadata; +import org.apache.kafka.common.config.ConfigException; +import org.apache.kafka.common.serialization.ByteArraySerializer; +import org.apache.log4j.AppenderSkeleton; +import org.apache.log4j.helpers.LogLog; +import org.apache.log4j.spi.LoggingEvent; + +import java.nio.charset.StandardCharsets; +import java.util.Date; +import java.util.Properties; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Future; + +import static org.apache.kafka.clients.CommonClientConfigs.BOOTSTRAP_SERVERS_CONFIG; +import static org.apache.kafka.clients.CommonClientConfigs.SECURITY_PROTOCOL_CONFIG; +import static org.apache.kafka.clients.producer.ProducerConfig.ACKS_CONFIG; +import static org.apache.kafka.clients.producer.ProducerConfig.BATCH_SIZE_CONFIG; +import static org.apache.kafka.clients.producer.ProducerConfig.COMPRESSION_TYPE_CONFIG; +import static org.apache.kafka.clients.producer.ProducerConfig.DELIVERY_TIMEOUT_MS_CONFIG; +import static org.apache.kafka.clients.producer.ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG; +import static org.apache.kafka.clients.producer.ProducerConfig.LINGER_MS_CONFIG; +import static org.apache.kafka.clients.producer.ProducerConfig.MAX_BLOCK_MS_CONFIG; +import static org.apache.kafka.clients.producer.ProducerConfig.RETRIES_CONFIG; +import static org.apache.kafka.clients.producer.ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG; +import static org.apache.kafka.common.config.SaslConfigs.SASL_JAAS_CONFIG; +import static org.apache.kafka.common.config.SaslConfigs.SASL_KERBEROS_SERVICE_NAME; +import static org.apache.kafka.common.config.SaslConfigs.SASL_MECHANISM; +import static org.apache.kafka.common.config.SslConfigs.SSL_ENGINE_FACTORY_CLASS_CONFIG; +import static org.apache.kafka.common.config.SslConfigs.SSL_KEYSTORE_LOCATION_CONFIG; +import static org.apache.kafka.common.config.SslConfigs.SSL_KEYSTORE_PASSWORD_CONFIG; +import static org.apache.kafka.common.config.SslConfigs.SSL_KEYSTORE_TYPE_CONFIG; +import static org.apache.kafka.common.config.SslConfigs.SSL_TRUSTSTORE_LOCATION_CONFIG; +import static 
org.apache.kafka.common.config.SslConfigs.SSL_TRUSTSTORE_PASSWORD_CONFIG; + +/** + * A log4j appender that produces log messages to Kafka + */ +public class KafkaLog4jAppender extends AppenderSkeleton { + + private String brokerList; + private String topic; + private String compressionType; + private String securityProtocol; + private String sslTruststoreLocation; + private String sslTruststorePassword; + private String sslKeystoreType; + private String sslKeystoreLocation; + private String sslKeystorePassword; + private String saslKerberosServiceName; + private String saslMechanism; + private String clientJaasConfPath; + private String clientJaasConf; + private String kerb5ConfPath; + private Integer maxBlockMs; + private String sslEngineFactoryClass; + + private int retries = Integer.MAX_VALUE; + private int requiredNumAcks = 1; + private int deliveryTimeoutMs = 120000; + private int lingerMs = 0; + private int batchSize = 16384; + private boolean ignoreExceptions = true; + private boolean syncSend; + private Producer producer; + + public Producer getProducer() { + return producer; + } + + public String getBrokerList() { + return brokerList; + } + + public void setBrokerList(String brokerList) { + this.brokerList = brokerList; + } + + public int getRequiredNumAcks() { + return requiredNumAcks; + } + + public void setRequiredNumAcks(int requiredNumAcks) { + this.requiredNumAcks = requiredNumAcks; + } + + public int getLingerMs() { + return lingerMs; + } + + public void setLingerMs(int lingerMs) { + this.lingerMs = lingerMs; + } + + public int getBatchSize() { + return batchSize; + } + + public void setBatchSize(int batchSize) { + this.batchSize = batchSize; + } + + public int getRetries() { + return retries; + } + + public void setRetries(int retries) { + this.retries = retries; + } + + public int getDeliveryTimeoutMs() { + return deliveryTimeoutMs; + } + + public void setDeliveryTimeoutMs(int deliveryTimeoutMs) { + this.deliveryTimeoutMs = deliveryTimeoutMs; + } + + public String getCompressionType() { + return compressionType; + } + + public void setCompressionType(String compressionType) { + this.compressionType = compressionType; + } + + public String getTopic() { + return topic; + } + + public void setTopic(String topic) { + this.topic = topic; + } + + public boolean getIgnoreExceptions() { + return ignoreExceptions; + } + + public void setIgnoreExceptions(boolean ignoreExceptions) { + this.ignoreExceptions = ignoreExceptions; + } + + public boolean getSyncSend() { + return syncSend; + } + + public void setSyncSend(boolean syncSend) { + this.syncSend = syncSend; + } + + public String getSslTruststorePassword() { + return sslTruststorePassword; + } + + public String getSslTruststoreLocation() { + return sslTruststoreLocation; + } + + public String getSecurityProtocol() { + return securityProtocol; + } + + public void setSecurityProtocol(String securityProtocol) { + this.securityProtocol = securityProtocol; + } + + public void setSslTruststoreLocation(String sslTruststoreLocation) { + this.sslTruststoreLocation = sslTruststoreLocation; + } + + public void setSslTruststorePassword(String sslTruststorePassword) { + this.sslTruststorePassword = sslTruststorePassword; + } + + public void setSslKeystorePassword(String sslKeystorePassword) { + this.sslKeystorePassword = sslKeystorePassword; + } + + public void setSslKeystoreType(String sslKeystoreType) { + this.sslKeystoreType = sslKeystoreType; + } + + public void setSslKeystoreLocation(String sslKeystoreLocation) { + 
this.sslKeystoreLocation = sslKeystoreLocation; + } + + public void setSaslKerberosServiceName(String saslKerberosServiceName) { + this.saslKerberosServiceName = saslKerberosServiceName; + } + + public void setClientJaasConfPath(String clientJaasConfPath) { + this.clientJaasConfPath = clientJaasConfPath; + } + + public void setKerb5ConfPath(String kerb5ConfPath) { + this.kerb5ConfPath = kerb5ConfPath; + } + + public String getSslKeystoreLocation() { + return sslKeystoreLocation; + } + + public String getSslKeystoreType() { + return sslKeystoreType; + } + + public String getSslKeystorePassword() { + return sslKeystorePassword; + } + + public String getSaslKerberosServiceName() { + return saslKerberosServiceName; + } + + public String getClientJaasConfPath() { + return clientJaasConfPath; + } + + public void setSaslMechanism(String saslMechanism) { + this.saslMechanism = saslMechanism; + } + + public String getSaslMechanism() { + return this.saslMechanism; + } + + public void setClientJaasConf(final String clientJaasConf) { + this.clientJaasConf = clientJaasConf; + } + + public String getClientJaasConf() { + return this.clientJaasConf; + } + + public String getKerb5ConfPath() { + return kerb5ConfPath; + } + + public int getMaxBlockMs() { + return maxBlockMs; + } + + public void setMaxBlockMs(int maxBlockMs) { + this.maxBlockMs = maxBlockMs; + } + + public String getSslEngineFactoryClass() { + return sslEngineFactoryClass; + } + + public void setSslEngineFactoryClass(String sslEngineFactoryClass) { + this.sslEngineFactoryClass = sslEngineFactoryClass; + } + + @Override + public void activateOptions() { + // check for config parameter validity + Properties props = new Properties(); + if (brokerList != null) + props.put(BOOTSTRAP_SERVERS_CONFIG, brokerList); + if (props.isEmpty()) + throw new ConfigException("The bootstrap servers property should be specified"); + if (topic == null) + throw new ConfigException("Topic must be specified by the Kafka log4j appender"); + if (compressionType != null) + props.put(COMPRESSION_TYPE_CONFIG, compressionType); + + props.put(ACKS_CONFIG, Integer.toString(requiredNumAcks)); + props.put(RETRIES_CONFIG, retries); + props.put(DELIVERY_TIMEOUT_MS_CONFIG, deliveryTimeoutMs); + props.put(LINGER_MS_CONFIG, lingerMs); + props.put(BATCH_SIZE_CONFIG, batchSize); + + if (securityProtocol != null) { + props.put(SECURITY_PROTOCOL_CONFIG, securityProtocol); + } + + if (securityProtocol != null && (securityProtocol.contains("SSL") || securityProtocol.contains("SASL"))) { + if (sslEngineFactoryClass != null) { + props.put(SSL_ENGINE_FACTORY_CLASS_CONFIG, sslEngineFactoryClass); + } + } + + if (securityProtocol != null && securityProtocol.contains("SSL") && sslTruststoreLocation != null && sslTruststorePassword != null) { + props.put(SSL_TRUSTSTORE_LOCATION_CONFIG, sslTruststoreLocation); + props.put(SSL_TRUSTSTORE_PASSWORD_CONFIG, sslTruststorePassword); + + if (sslKeystoreType != null && sslKeystoreLocation != null && + sslKeystorePassword != null) { + props.put(SSL_KEYSTORE_TYPE_CONFIG, sslKeystoreType); + props.put(SSL_KEYSTORE_LOCATION_CONFIG, sslKeystoreLocation); + props.put(SSL_KEYSTORE_PASSWORD_CONFIG, sslKeystorePassword); + } + } + + if (securityProtocol != null && securityProtocol.contains("SASL") && saslKerberosServiceName != null && clientJaasConfPath != null) { + props.put(SASL_KERBEROS_SERVICE_NAME, saslKerberosServiceName); + System.setProperty("java.security.auth.login.config", clientJaasConfPath); + } + if (kerb5ConfPath != null) { + 
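+ // Like the JAAS config path above, the krb5 config path is applied as a JVM-wide + // system property rather than as a per-producer setting.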
System.setProperty("java.security.krb5.conf", kerb5ConfPath); + } + if (saslMechanism != null) { + props.put(SASL_MECHANISM, saslMechanism); + } + if (clientJaasConf != null) { + props.put(SASL_JAAS_CONFIG, clientJaasConf); + } + if (maxBlockMs != null) { + props.put(MAX_BLOCK_MS_CONFIG, maxBlockMs); + } + + props.put(KEY_SERIALIZER_CLASS_CONFIG, ByteArraySerializer.class.getName()); + props.put(VALUE_SERIALIZER_CLASS_CONFIG, ByteArraySerializer.class.getName()); + this.producer = getKafkaProducer(props); + LogLog.debug("Kafka producer connected to " + brokerList); + LogLog.debug("Logging for topic: " + topic); + } + + protected Producer getKafkaProducer(Properties props) { + return new KafkaProducer<>(props); + } + + @Override + protected void append(LoggingEvent event) { + String message = subAppend(event); + LogLog.debug("[" + new Date(event.getTimeStamp()) + "]" + message); + Future response = producer.send( + new ProducerRecord<>(topic, message.getBytes(StandardCharsets.UTF_8))); + if (syncSend) { + try { + response.get(); + } catch (InterruptedException | ExecutionException ex) { + if (!ignoreExceptions) + throw new RuntimeException(ex); + LogLog.debug("Exception while getting response", ex); + } + } + } + + private String subAppend(LoggingEvent event) { + return (this.layout == null) ? event.getRenderedMessage() : this.layout.format(event); + } + + @Override + public void close() { + if (!this.closed) { + this.closed = true; + producer.close(); + } + } + + @Override + public boolean requiresLayout() { + return true; + } +} diff --git a/log4j-appender/src/test/java/org/apache/kafka/log4jappender/KafkaLog4jAppenderTest.java b/log4j-appender/src/test/java/org/apache/kafka/log4jappender/KafkaLog4jAppenderTest.java new file mode 100644 index 0000000..7ec5633 --- /dev/null +++ b/log4j-appender/src/test/java/org/apache/kafka/log4jappender/KafkaLog4jAppenderTest.java @@ -0,0 +1,217 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.log4jappender; + +import static org.hamcrest.CoreMatchers.hasItem; +import static org.hamcrest.CoreMatchers.not; +import static org.hamcrest.MatcherAssert.assertThat; +import org.apache.kafka.clients.producer.MockProducer; +import org.apache.kafka.clients.producer.RecordMetadata; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.config.ConfigException; +import org.apache.kafka.common.config.SaslConfigs; +import org.apache.log4j.Logger; +import org.apache.log4j.PropertyConfigurator; +import org.apache.log4j.helpers.LogLog; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.nio.charset.StandardCharsets; +import java.util.Properties; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.TimeoutException; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class KafkaLog4jAppenderTest { + + private Logger logger = Logger.getLogger(KafkaLog4jAppenderTest.class); + + @BeforeEach + public void setup() { + LogLog.setInternalDebugging(true); + } + + @Test + public void testKafkaLog4jConfigs() { + Properties hostMissingProps = new Properties(); + hostMissingProps.put("log4j.rootLogger", "INFO"); + hostMissingProps.put("log4j.appender.KAFKA", "org.apache.kafka.log4jappender.KafkaLog4jAppender"); + hostMissingProps.put("log4j.appender.KAFKA.layout", "org.apache.log4j.PatternLayout"); + hostMissingProps.put("log4j.appender.KAFKA.layout.ConversionPattern", "%-5p: %c - %m%n"); + hostMissingProps.put("log4j.appender.KAFKA.Topic", "test-topic"); + hostMissingProps.put("log4j.logger.kafka.log4j", "INFO, KAFKA"); + + assertThrows(ConfigException.class, () -> PropertyConfigurator.configure(hostMissingProps), "Missing properties exception was expected !"); + + Properties topicMissingProps = new Properties(); + topicMissingProps.put("log4j.rootLogger", "INFO"); + topicMissingProps.put("log4j.appender.KAFKA", "org.apache.kafka.log4jappender.KafkaLog4jAppender"); + topicMissingProps.put("log4j.appender.KAFKA.layout", "org.apache.log4j.PatternLayout"); + topicMissingProps.put("log4j.appender.KAFKA.layout.ConversionPattern", "%-5p: %c - %m%n"); + topicMissingProps.put("log4j.appender.KAFKA.brokerList", "127.0.0.1:9093"); + topicMissingProps.put("log4j.logger.kafka.log4j", "INFO, KAFKA"); + + assertThrows(ConfigException.class, () -> PropertyConfigurator.configure(topicMissingProps), "Missing properties exception was expected !"); + } + + @Test + public void testSetSaslMechanism() { + Properties props = getLog4jConfig(false); + props.put("log4j.appender.KAFKA.SaslMechanism", "PLAIN"); + PropertyConfigurator.configure(props); + + MockKafkaLog4jAppender mockKafkaLog4jAppender = getMockKafkaLog4jAppender(); + + assertEquals(mockKafkaLog4jAppender.getProducerProperties().getProperty(SaslConfigs.SASL_MECHANISM), "PLAIN"); + } + + @Test + public void testSaslMechanismNotSet() { + testProducerPropertyNotSet(SaslConfigs.SASL_MECHANISM); + } + + @Test + public void testSetJaasConfig() { + Properties props = getLog4jConfig(false); + props.put("log4j.appender.KAFKA.ClientJaasConf", "jaas-config"); + PropertyConfigurator.configure(props); + + MockKafkaLog4jAppender mockKafkaLog4jAppender = getMockKafkaLog4jAppender(); + 
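+ // The mock appender records the Properties passed to getKafkaProducer(), so the test + // can assert directly on the producer configuration derived from the log4j settings.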
assertEquals(mockKafkaLog4jAppender.getProducerProperties().getProperty(SaslConfigs.SASL_JAAS_CONFIG), "jaas-config"); + } + + @Test + public void testJaasConfigNotSet() { + testProducerPropertyNotSet(SaslConfigs.SASL_JAAS_CONFIG); + } + + private void testProducerPropertyNotSet(String name) { + PropertyConfigurator.configure(getLog4jConfig(false)); + MockKafkaLog4jAppender mockKafkaLog4jAppender = getMockKafkaLog4jAppender(); + assertThat(mockKafkaLog4jAppender.getProducerProperties().stringPropertyNames(), not(hasItem(name))); + } + + @Test + public void testLog4jAppends() { + PropertyConfigurator.configure(getLog4jConfig(false)); + + for (int i = 1; i <= 5; ++i) { + logger.error(getMessage(i)); + } + assertEquals(getMockKafkaLog4jAppender().getHistory().size(), 5); + } + + @Test + public void testSyncSendAndSimulateProducerFailShouldThrowException() { + Properties props = getLog4jConfig(true); + props.put("log4j.appender.KAFKA.IgnoreExceptions", "false"); + PropertyConfigurator.configure(props); + + MockKafkaLog4jAppender mockKafkaLog4jAppender = getMockKafkaLog4jAppender(); + replaceProducerWithMocked(mockKafkaLog4jAppender, false); + + assertThrows(RuntimeException.class, () -> logger.error(getMessage(0))); + } + + @Test + public void testSyncSendWithoutIgnoringExceptionsShouldNotThrowException() { + Properties props = getLog4jConfig(true); + props.put("log4j.appender.KAFKA.IgnoreExceptions", "false"); + PropertyConfigurator.configure(props); + + MockKafkaLog4jAppender mockKafkaLog4jAppender = getMockKafkaLog4jAppender(); + replaceProducerWithMocked(mockKafkaLog4jAppender, true); + + logger.error(getMessage(0)); + } + + @Test + public void testRealProducerConfigWithSyncSendShouldNotThrowException() { + Properties props = getLog4jConfigWithRealProducer(true); + PropertyConfigurator.configure(props); + + logger.error(getMessage(0)); + } + + @Test + public void testRealProducerConfigWithSyncSendAndNotIgnoringExceptionsShouldThrowException() { + Properties props = getLog4jConfigWithRealProducer(false); + PropertyConfigurator.configure(props); + + assertThrows(RuntimeException.class, () -> logger.error(getMessage(0))); + } + + private void replaceProducerWithMocked(MockKafkaLog4jAppender mockKafkaLog4jAppender, boolean success) { + @SuppressWarnings("unchecked") + MockProducer producer = mock(MockProducer.class); + CompletableFuture future = new CompletableFuture<>(); + if (success) + future.complete(new RecordMetadata(new TopicPartition("tp", 0), 0, 0, 0, 0, 0)); + else + future.completeExceptionally(new TimeoutException("simulated timeout")); + when(producer.send(any())).thenReturn(future); + // reconfiguring mock appender + mockKafkaLog4jAppender.setKafkaProducer(producer); + mockKafkaLog4jAppender.activateOptions(); + } + + private MockKafkaLog4jAppender getMockKafkaLog4jAppender() { + return (MockKafkaLog4jAppender) Logger.getRootLogger().getAppender("KAFKA"); + } + + private byte[] getMessage(int i) { + return ("test_" + i).getBytes(StandardCharsets.UTF_8); + } + + private Properties getLog4jConfigWithRealProducer(boolean ignoreExceptions) { + Properties props = new Properties(); + props.put("log4j.rootLogger", "INFO, KAFKA"); + props.put("log4j.appender.KAFKA", "org.apache.kafka.log4jappender.KafkaLog4jAppender"); + props.put("log4j.appender.KAFKA.layout", "org.apache.log4j.PatternLayout"); + props.put("log4j.appender.KAFKA.layout.ConversionPattern", "%-5p: %c - %m%n"); + props.put("log4j.appender.KAFKA.BrokerList", "127.0.0.2:9093"); + props.put("log4j.appender.KAFKA.Topic", 
"test-topic"); + props.put("log4j.appender.KAFKA.RequiredNumAcks", "1"); + props.put("log4j.appender.KAFKA.SyncSend", "true"); + // setting producer timeout (max.block.ms) to be low + props.put("log4j.appender.KAFKA.maxBlockMs", "10"); + // ignoring exceptions + props.put("log4j.appender.KAFKA.IgnoreExceptions", Boolean.toString(ignoreExceptions)); + props.put("log4j.logger.kafka.log4j", "INFO, KAFKA"); + return props; + } + + private Properties getLog4jConfig(boolean syncSend) { + Properties props = new Properties(); + props.put("log4j.rootLogger", "INFO, KAFKA"); + props.put("log4j.appender.KAFKA", "org.apache.kafka.log4jappender.MockKafkaLog4jAppender"); + props.put("log4j.appender.KAFKA.layout", "org.apache.log4j.PatternLayout"); + props.put("log4j.appender.KAFKA.layout.ConversionPattern", "%-5p: %c - %m%n"); + props.put("log4j.appender.KAFKA.BrokerList", "127.0.0.1:9093"); + props.put("log4j.appender.KAFKA.Topic", "test-topic"); + props.put("log4j.appender.KAFKA.RequiredNumAcks", "1"); + props.put("log4j.appender.KAFKA.SyncSend", Boolean.toString(syncSend)); + props.put("log4j.logger.kafka.log4j", "INFO, KAFKA"); + return props; + } +} + diff --git a/log4j-appender/src/test/java/org/apache/kafka/log4jappender/MockKafkaLog4jAppender.java b/log4j-appender/src/test/java/org/apache/kafka/log4jappender/MockKafkaLog4jAppender.java new file mode 100644 index 0000000..b699fa9 --- /dev/null +++ b/log4j-appender/src/test/java/org/apache/kafka/log4jappender/MockKafkaLog4jAppender.java @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.log4jappender; + +import org.apache.kafka.clients.producer.MockProducer; +import org.apache.kafka.clients.producer.Producer; +import org.apache.kafka.clients.producer.ProducerRecord; +import org.apache.kafka.test.MockSerializer; +import org.apache.log4j.spi.LoggingEvent; + +import java.util.List; +import java.util.Properties; + +public class MockKafkaLog4jAppender extends KafkaLog4jAppender { + private MockProducer mockProducer = + new MockProducer<>(false, new MockSerializer(), new MockSerializer()); + + private Properties producerProperties; + + @Override + protected Producer getKafkaProducer(Properties props) { + producerProperties = props; + return mockProducer; + } + + void setKafkaProducer(MockProducer producer) { + this.mockProducer = producer; + } + + @Override + protected void append(LoggingEvent event) { + if (super.getProducer() == null) { + activateOptions(); + } + super.append(event); + } + + List> getHistory() { + return mockProducer.history(); + } + + public Properties getProducerProperties() { + return producerProperties; + } +} diff --git a/metadata/src/main/java/org/apache/kafka/controller/BrokerControlState.java b/metadata/src/main/java/org/apache/kafka/controller/BrokerControlState.java new file mode 100644 index 0000000..dfcf8ce --- /dev/null +++ b/metadata/src/main/java/org/apache/kafka/controller/BrokerControlState.java @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.controller; + + +public enum BrokerControlState { + FENCED(true, false), + UNFENCED(false, false), + CONTROLLED_SHUTDOWN(false, false), + SHUTDOWN_NOW(true, true); + + private final boolean fenced; + private final boolean shouldShutDown; + + BrokerControlState(boolean fenced, boolean shouldShutDown) { + this.fenced = fenced; + this.shouldShutDown = shouldShutDown; + } + + public boolean fenced() { + return fenced; + } + + public boolean shouldShutDown() { + return shouldShutDown; + } + + public boolean inControlledShutdown() { + return this == CONTROLLED_SHUTDOWN; + } +} diff --git a/metadata/src/main/java/org/apache/kafka/controller/BrokerControlStates.java b/metadata/src/main/java/org/apache/kafka/controller/BrokerControlStates.java new file mode 100644 index 0000000..6605852 --- /dev/null +++ b/metadata/src/main/java/org/apache/kafka/controller/BrokerControlStates.java @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.controller; + +import java.util.Objects; + + +class BrokerControlStates { + private final BrokerControlState current; + private final BrokerControlState next; + + BrokerControlStates(BrokerControlState current, BrokerControlState next) { + this.current = current; + this.next = next; + } + + BrokerControlState current() { + return current; + } + + BrokerControlState next() { + return next; + } + + @Override + public int hashCode() { + return Objects.hash(current, next); + } + + @Override + public boolean equals(Object o) { + if (!(o instanceof BrokerControlStates)) return false; + BrokerControlStates other = (BrokerControlStates) o; + return other.current == current && other.next == next; + } + + @Override + public String toString() { + return "BrokerControlStates(current=" + current + ", next=" + next + ")"; + } +} diff --git a/metadata/src/main/java/org/apache/kafka/controller/BrokerHeartbeatManager.java b/metadata/src/main/java/org/apache/kafka/controller/BrokerHeartbeatManager.java new file mode 100644 index 0000000..b95f0d3 --- /dev/null +++ b/metadata/src/main/java/org/apache/kafka/controller/BrokerHeartbeatManager.java @@ -0,0 +1,603 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.controller; + +import org.apache.kafka.common.errors.InvalidReplicationFactorException; +import org.apache.kafka.common.message.BrokerHeartbeatRequestData; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.metadata.UsableBroker; +import org.slf4j.Logger; + +import java.util.Collection; +import java.util.Comparator; +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; +import java.util.NoSuchElementException; +import java.util.Optional; +import java.util.TreeSet; +import java.util.function.Function; +import java.util.function.Supplier; + +import static org.apache.kafka.controller.BrokerControlState.FENCED; +import static org.apache.kafka.controller.BrokerControlState.CONTROLLED_SHUTDOWN; +import static org.apache.kafka.controller.BrokerControlState.SHUTDOWN_NOW; +import static org.apache.kafka.controller.BrokerControlState.UNFENCED; + + +/** + * The BrokerHeartbeatManager manages all the soft state associated with broker heartbeats. 
+ * Soft state is state which does not appear in the metadata log. This state includes + * things like the last time each broker sent us a heartbeat, and whether the broker is + * trying to perform a controlled shutdown. + * + * Only the active controller has a BrokerHeartbeatManager, since only the active + * controller handles broker heartbeats. Standby controllers will create a heartbeat + * manager as part of the process of activating. This design minimizes the size of the + * metadata partition by excluding heartbeats from it. However, it does mean that after + * a controller failover, we may take some extra time to fence brokers, since the new + * active controller does not know when the last heartbeats were received from each. + */ +public class BrokerHeartbeatManager { + static class BrokerHeartbeatState { + /** + * The broker ID. + */ + private final int id; + + /** + * The last time we received a heartbeat from this broker, in monotonic nanoseconds. + * When this field is updated, we also may have to update the broker's position in + * the unfenced list. + */ + long lastContactNs; + + /** + * The last metadata offset which this broker reported. When this field is updated, + * we may also have to update the broker's position in the active set. + */ + long metadataOffset; + + /** + * The offset at which the broker should complete its controlled shutdown, or -1 + * if the broker is not performing a controlled shutdown. When this field is + * updated, we also have to update the broker's position in the shuttingDown set. + */ + private long controlledShutDownOffset; + + /** + * The previous entry in the unfenced list, or null if the broker is not in that list. + */ + private BrokerHeartbeatState prev; + + /** + * The next entry in the unfenced list, or null if the broker is not in that list. + */ + private BrokerHeartbeatState next; + + BrokerHeartbeatState(int id) { + this.id = id; + this.lastContactNs = 0; + this.prev = null; + this.next = null; + this.metadataOffset = -1; + this.controlledShutDownOffset = -1; + } + + /** + * Returns the broker ID. + */ + int id() { + return id; + } + + /** + * Returns true only if the broker is fenced. + */ + boolean fenced() { + return prev == null; + } + + /** + * Returns true only if the broker is in controlled shutdown state. + */ + boolean shuttingDown() { + return controlledShutDownOffset >= 0; + } + } + + static class MetadataOffsetComparator implements Comparator { + static final MetadataOffsetComparator INSTANCE = new MetadataOffsetComparator(); + + @Override + public int compare(BrokerHeartbeatState a, BrokerHeartbeatState b) { + if (a.metadataOffset < b.metadataOffset) { + return -1; + } else if (a.metadataOffset > b.metadataOffset) { + return 1; + } else if (a.id < b.id) { + return -1; + } else if (a.id > b.id) { + return 1; + } else { + return 0; + } + } + } + + static class BrokerHeartbeatStateList { + /** + * The head of the list of unfenced brokers. The list is sorted in ascending order + * of last contact time. + */ + private final BrokerHeartbeatState head; + + BrokerHeartbeatStateList() { + this.head = new BrokerHeartbeatState(-1); + head.prev = head; + head.next = head; + } + + /** + * Return the head of the list, or null if the list is empty. + */ + BrokerHeartbeatState first() { + BrokerHeartbeatState result = head.next; + return result == head ? null : result; + } + + /** + * Add the broker to the list. We start looking for a place to put it at the end + * of the list. 
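+ * Because lastContactNs comes from a monotonic clock and touch() re-inserts a broker + * with a fresh timestamp, new entries almost always belong at the tail, so the backwards + * scan typically stops after a single comparison.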
+ */ + void add(BrokerHeartbeatState broker) { + BrokerHeartbeatState cur = head.prev; + while (true) { + if (cur == head || cur.lastContactNs <= broker.lastContactNs) { + broker.next = cur.next; + cur.next.prev = broker; + broker.prev = cur; + cur.next = broker; + break; + } + cur = cur.prev; + } + } + + /** + * Remove a broker from the list. + */ + void remove(BrokerHeartbeatState broker) { + if (broker.next == null) { + throw new RuntimeException(broker + " is not in the list."); + } + broker.prev.next = broker.next; + broker.next.prev = broker.prev; + broker.prev = null; + broker.next = null; + } + + BrokerHeartbeatStateIterator iterator() { + return new BrokerHeartbeatStateIterator(head); + } + } + + static class BrokerHeartbeatStateIterator implements Iterator { + private final BrokerHeartbeatState head; + private BrokerHeartbeatState cur; + + BrokerHeartbeatStateIterator(BrokerHeartbeatState head) { + this.head = head; + this.cur = head; + } + + @Override + public boolean hasNext() { + return cur.next != head; + } + + @Override + public BrokerHeartbeatState next() { + if (!hasNext()) { + throw new NoSuchElementException(); + } + BrokerHeartbeatState result = cur.next; + cur = cur.next; + return result; + } + } + + private final Logger log; + + /** + * The Kafka clock object to use. + */ + private final Time time; + + /** + * The broker session timeout in nanoseconds. + */ + private final long sessionTimeoutNs; + + /** + * Maps broker IDs to heartbeat states. + */ + private final HashMap brokers; + + /** + * The list of unfenced brokers, sorted by last contact time. + */ + private final BrokerHeartbeatStateList unfenced; + + /** + * The set of active brokers. A broker is active if it is unfenced, and not shutting + * down. + */ + private final TreeSet active; + + BrokerHeartbeatManager(LogContext logContext, + Time time, + long sessionTimeoutNs) { + this.log = logContext.logger(BrokerHeartbeatManager.class); + this.time = time; + this.sessionTimeoutNs = sessionTimeoutNs; + this.brokers = new HashMap<>(); + this.unfenced = new BrokerHeartbeatStateList(); + this.active = new TreeSet<>(MetadataOffsetComparator.INSTANCE); + } + + // VisibleForTesting + Time time() { + return time; + } + + // VisibleForTesting + BrokerHeartbeatStateList unfenced() { + return unfenced; + } + + // VisibleForTesting + Collection brokers() { + return brokers.values(); + } + + /** + * Mark a broker as fenced. + * + * @param brokerId The ID of the broker to mark as fenced. + */ + void fence(int brokerId) { + BrokerHeartbeatState broker = brokers.get(brokerId); + if (broker != null) { + untrack(broker); + } + } + + /** + * Remove a broker. + * + * @param brokerId The ID of the broker to remove. + */ + void remove(int brokerId) { + BrokerHeartbeatState broker = brokers.remove(brokerId); + if (broker != null) { + untrack(broker); + } + } + + /** + * Stop tracking the broker in the unfenced list and active set, if it was tracked + * in either of these. + * + * @param broker The broker state to stop tracking. + */ + private void untrack(BrokerHeartbeatState broker) { + if (!broker.fenced()) { + unfenced.remove(broker); + if (!broker.shuttingDown()) { + active.remove(broker); + } + } + } + + /** + * Check if the given broker has a valid session. + * + * @param brokerId The broker ID to check. + * + * @return True if the given broker has a valid session. 
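+ * A session is valid only while the broker is unfenced and no more than + * sessionTimeoutNs has elapsed since its last heartbeat.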
+ */ + boolean hasValidSession(int brokerId) { + BrokerHeartbeatState broker = brokers.get(brokerId); + if (broker == null) return false; + return hasValidSession(broker); + } + + /** + * Check if the given broker has a valid session. + * + * @param broker The broker to check. + * + * @return True if the given broker has a valid session. + */ + private boolean hasValidSession(BrokerHeartbeatState broker) { + if (broker.fenced()) { + return false; + } else { + return broker.lastContactNs + sessionTimeoutNs >= time.nanoseconds(); + } + } + + /** + * Update broker state, including lastContactNs. + * + * @param brokerId The broker ID. + * @param fenced True only if the broker is currently fenced. + * @param metadataOffset The latest metadata offset of the broker. + */ + void touch(int brokerId, boolean fenced, long metadataOffset) { + BrokerHeartbeatState broker = brokers.get(brokerId); + if (broker == null) { + broker = new BrokerHeartbeatState(brokerId); + brokers.put(brokerId, broker); + } else { + // Remove the broker from the unfenced list and/or the active set. Its + // position in either of those data structures depends on values we are + // changing here. We will re-add it if necessary at the end of this function. + untrack(broker); + } + broker.lastContactNs = time.nanoseconds(); + broker.metadataOffset = metadataOffset; + if (fenced) { + // If a broker is fenced, it leaves controlled shutdown. On its next heartbeat, + // it will shut down immediately. + broker.controlledShutDownOffset = -1; + } else { + unfenced.add(broker); + if (!broker.shuttingDown()) { + active.add(broker); + } + } + } + + long lowestActiveOffset() { + Iterator iterator = active.iterator(); + if (!iterator.hasNext()) { + return Long.MAX_VALUE; + } + BrokerHeartbeatState first = iterator.next(); + return first.metadataOffset; + } + + /** + * Mark a broker as being in the controlled shutdown state. + * + * @param brokerId The broker id. + * @param controlledShutDownOffset The offset at which controlled shutdown will be complete. + */ + void updateControlledShutdownOffset(int brokerId, long controlledShutDownOffset) { + BrokerHeartbeatState broker = brokers.get(brokerId); + if (broker == null) { + throw new RuntimeException("Unable to locate broker " + brokerId); + } + if (broker.fenced()) { + throw new RuntimeException("Fenced brokers cannot enter controlled shutdown."); + } + active.remove(broker); + broker.controlledShutDownOffset = controlledShutDownOffset; + log.debug("Updated the controlled shutdown offset for broker {} to {}.", + brokerId, controlledShutDownOffset); + } + + /** + * Return the time in monotonic nanoseconds at which we should check if a broker + * session needs to be expired. + */ + long nextCheckTimeNs() { + BrokerHeartbeatState broker = unfenced.first(); + if (broker == null) { + return Long.MAX_VALUE; + } else { + return broker.lastContactNs + sessionTimeoutNs; + } + } + + /** + * Check if the oldest broker to have hearbeated has already violated the + * sessionTimeoutNs timeout and needs to be fenced. + * + * @return An Optional broker node id. + */ + Optional findOneStaleBroker() { + BrokerHeartbeatStateIterator iterator = unfenced.iterator(); + if (iterator.hasNext()) { + BrokerHeartbeatState broker = iterator.next(); + // The unfenced list is sorted on last contact time from each + // broker. If the first broker is not stale, then none is. 
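+ // For example (illustrative numbers): with a 9 second session timeout and brokers last + // heard from 10, 5 and 2 seconds ago, only the 10-second-old entry at the head of the + // list has an expired session, and it is the one reported for fencing.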
+ if (!hasValidSession(broker)) { + return Optional.of(broker.id); + } + } + return Optional.empty(); + } + + /** + * Place replicas on unfenced brokers. + * + * @param startPartition The partition ID to start with. + * @param numPartitions The number of partitions to place. + * @param numReplicas The number of replicas for each partition. + * @param idToRack A function mapping broker id to broker rack. + * @param placer The replica placer to use. + * + * @return A list of replica lists. + * + * @throws InvalidReplicationFactorException If too many replicas were requested. + */ + List> placeReplicas(int startPartition, + int numPartitions, + short numReplicas, + Function> idToRack, + ReplicaPlacer placer) { + Iterator iterator = new UsableBrokerIterator( + brokers.values().iterator(), idToRack); + return placer.place(startPartition, numPartitions, numReplicas, iterator); + } + + static class UsableBrokerIterator implements Iterator { + private final Iterator iterator; + private final Function> idToRack; + private UsableBroker next; + + UsableBrokerIterator(Iterator iterator, + Function> idToRack) { + this.iterator = iterator; + this.idToRack = idToRack; + this.next = null; + } + + @Override + public boolean hasNext() { + if (next != null) { + return true; + } + BrokerHeartbeatState result; + do { + if (!iterator.hasNext()) { + return false; + } + result = iterator.next(); + } while (result.shuttingDown()); + Optional rack = idToRack.apply(result.id()); + next = new UsableBroker(result.id(), rack, result.fenced()); + return true; + } + + @Override + public UsableBroker next() { + if (!hasNext()) { + throw new NoSuchElementException(); + } + UsableBroker result = next; + next = null; + return result; + } + } + + BrokerControlState currentBrokerState(BrokerHeartbeatState broker) { + if (broker.shuttingDown()) { + return CONTROLLED_SHUTDOWN; + } else if (broker.fenced()) { + return FENCED; + } else { + return UNFENCED; + } + } + + /** + * Calculate the next broker state for a broker that just sent a heartbeat request. + * + * @param brokerId The broker id. + * @param request The incoming heartbeat request. + * @param lastCommittedOffset The last committed offset of the quorum controller. + * @param hasLeaderships A callback which evaluates to true if the broker leads + * at least one partition. + * + * @return The current and next broker states. + */ + BrokerControlStates calculateNextBrokerState(int brokerId, + BrokerHeartbeatRequestData request, + long lastCommittedOffset, + Supplier hasLeaderships) { + BrokerHeartbeatState broker = brokers.getOrDefault(brokerId, + new BrokerHeartbeatState(brokerId)); + BrokerControlState currentState = currentBrokerState(broker); + switch (currentState) { + case FENCED: + if (request.wantShutDown()) { + log.info("Fenced broker {} has requested and been granted an immediate " + + "shutdown.", brokerId); + return new BrokerControlStates(currentState, SHUTDOWN_NOW); + } else if (!request.wantFence()) { + if (request.currentMetadataOffset() >= lastCommittedOffset) { + log.info("The request from broker {} to unfence has been granted " + + "because it has caught up with the last committed metadata " + + "offset {}.", brokerId, lastCommittedOffset); + return new BrokerControlStates(currentState, UNFENCED); + } else { + if (log.isDebugEnabled()) { + log.debug("The request from broker {} to unfence cannot yet " + + "be granted because it has not caught up with the last " + + "committed metadata offset {}. 
It is still at offset {}.", + brokerId, lastCommittedOffset, request.currentMetadataOffset()); + } + return new BrokerControlStates(currentState, FENCED); + } + } + return new BrokerControlStates(currentState, FENCED); + + case UNFENCED: + if (request.wantFence()) { + if (request.wantShutDown()) { + log.info("Unfenced broker {} has requested and been granted an " + + "immediate shutdown.", brokerId); + return new BrokerControlStates(currentState, SHUTDOWN_NOW); + } else { + log.info("Unfenced broker {} has requested and been granted " + + "fencing", brokerId); + return new BrokerControlStates(currentState, FENCED); + } + } else if (request.wantShutDown()) { + if (hasLeaderships.get()) { + log.info("Unfenced broker {} has requested and been granted a " + + "controlled shutdown.", brokerId); + return new BrokerControlStates(currentState, CONTROLLED_SHUTDOWN); + } else { + log.info("Unfenced broker {} has requested and been granted an " + + "immediate shutdown.", brokerId); + return new BrokerControlStates(currentState, SHUTDOWN_NOW); + } + } + return new BrokerControlStates(currentState, UNFENCED); + + case CONTROLLED_SHUTDOWN: + if (hasLeaderships.get()) { + log.debug("Broker {} is in controlled shutdown state, but can not " + + "shut down because more leaders still need to be moved.", brokerId); + return new BrokerControlStates(currentState, CONTROLLED_SHUTDOWN); + } + long lowestActiveOffset = lowestActiveOffset(); + if (broker.controlledShutDownOffset <= lowestActiveOffset) { + log.info("The request from broker {} to shut down has been granted " + + "since the lowest active offset {} is now greater than the " + + "broker's controlled shutdown offset {}.", brokerId, + lowestActiveOffset, broker.controlledShutDownOffset); + return new BrokerControlStates(currentState, SHUTDOWN_NOW); + } + log.debug("The request from broker {} to shut down can not yet be granted " + + "because the lowest active offset {} is not greater than the broker's " + + "shutdown offset {}.", brokerId, lowestActiveOffset, + broker.controlledShutDownOffset); + return new BrokerControlStates(currentState, CONTROLLED_SHUTDOWN); + + default: + return new BrokerControlStates(currentState, SHUTDOWN_NOW); + } + } +} diff --git a/metadata/src/main/java/org/apache/kafka/controller/BrokersToIsrs.java b/metadata/src/main/java/org/apache/kafka/controller/BrokersToIsrs.java new file mode 100644 index 0000000..aceb6dd --- /dev/null +++ b/metadata/src/main/java/org/apache/kafka/controller/BrokersToIsrs.java @@ -0,0 +1,346 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.controller; + +import org.apache.kafka.common.Uuid; +import org.apache.kafka.metadata.Replicas; +import org.apache.kafka.timeline.SnapshotRegistry; +import org.apache.kafka.timeline.TimelineHashMap; +import org.apache.kafka.timeline.TimelineInteger; + +import java.util.Arrays; +import java.util.Collections; +import java.util.Iterator; +import java.util.Map; +import java.util.Map.Entry; +import java.util.NoSuchElementException; +import java.util.Objects; + +import static org.apache.kafka.metadata.LeaderConstants.NO_LEADER; +import static org.apache.kafka.metadata.Replicas.NONE; + + +/** + * Associates brokers with their in-sync partitions. + * + * This is useful when we need to remove a broker from all the ISRs, or move all leaders + * away from a broker. + * + * We also track all the partitions that currently have no leader. + * + * The core data structure is a map from broker IDs to topic maps. Each topic map relates + * topic UUIDs to arrays of partition IDs. + * + * Each entry in the array has a high bit which indicates that the broker is the leader + * for the given partition, as well as 31 low bits which contain the partition id. This + * works because partition IDs cannot be negative. + */ +public class BrokersToIsrs { + private final static int LEADER_FLAG = 0x8000_0000; + + private final static int REPLICA_MASK = 0x7fff_ffff; + + static class TopicIdPartition { + private final Uuid topicId; + private final int partitionId; + + TopicIdPartition(Uuid topicId, int partitionId) { + this.topicId = topicId; + this.partitionId = partitionId; + } + + public Uuid topicId() { + return topicId; + } + + public int partitionId() { + return partitionId; + } + + @Override + public boolean equals(Object o) { + if (!(o instanceof TopicIdPartition)) return false; + TopicIdPartition other = (TopicIdPartition) o; + return other.topicId.equals(topicId) && other.partitionId == partitionId; + } + + @Override + public int hashCode() { + return Objects.hash(topicId, partitionId); + } + + @Override + public String toString() { + return topicId + ":" + partitionId; + } + } + + static class PartitionsOnReplicaIterator implements Iterator { + private final Iterator> iterator; + private final boolean leaderOnly; + private int offset = 0; + Uuid uuid = Uuid.ZERO_UUID; + int[] replicas = NONE; + private TopicIdPartition next = null; + + PartitionsOnReplicaIterator(Map topicMap, boolean leaderOnly) { + this.iterator = topicMap.entrySet().iterator(); + this.leaderOnly = leaderOnly; + } + + @Override + public boolean hasNext() { + if (next != null) return true; + while (true) { + if (offset >= replicas.length) { + if (!iterator.hasNext()) return false; + offset = 0; + Entry entry = iterator.next(); + uuid = entry.getKey(); + replicas = entry.getValue(); + } + int replica = replicas[offset++]; + if ((!leaderOnly) || (replica & LEADER_FLAG) != 0) { + next = new TopicIdPartition(uuid, replica & REPLICA_MASK); + return true; + } + } + } + + @Override + public TopicIdPartition next() { + if (!hasNext()) { + throw new NoSuchElementException(); + } + TopicIdPartition result = next; + next = null; + return result; + } + } + + private final SnapshotRegistry snapshotRegistry; + + /** + * A map of broker IDs to the partitions that the broker is in the ISR for. + * Partitions with no isr members appear in this map under id NO_LEADER. 
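+ * Each int in a partition array encodes the partition id in its low 31 bits, with the + * high bit set when this broker is the leader: for example, 0x80000005 means "leader of + * partition 5" and 0x00000005 means "non-leader ISR member of partition 5".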
+ */ + private final TimelineHashMap> isrMembers; + + private final TimelineInteger offlinePartitionCount; + + BrokersToIsrs(SnapshotRegistry snapshotRegistry) { + this.snapshotRegistry = snapshotRegistry; + this.isrMembers = new TimelineHashMap<>(snapshotRegistry, 0); + this.offlinePartitionCount = new TimelineInteger(snapshotRegistry); + } + + /** + * Update our records of a partition's ISR. + * + * @param topicId The topic ID of the partition. + * @param partitionId The partition ID of the partition. + * @param prevIsr The previous ISR, or null if the partition is new. + * @param nextIsr The new ISR, or null if the partition is being removed. + * @param prevLeader The previous leader, or NO_LEADER if the partition had no leader. + * @param nextLeader The new leader, or NO_LEADER if the partition now has no leader. + */ + void update(Uuid topicId, int partitionId, int[] prevIsr, int[] nextIsr, + int prevLeader, int nextLeader) { + int[] prev; + if (prevIsr == null) { + prev = NONE; + } else { + if (prevLeader == NO_LEADER) { + prev = Replicas.copyWith(prevIsr, NO_LEADER); + if (nextLeader != NO_LEADER) { + offlinePartitionCount.decrement(); + } + } else { + prev = Replicas.clone(prevIsr); + } + Arrays.sort(prev); + } + int[] next; + if (nextIsr == null) { + next = NONE; + } else { + if (nextLeader == NO_LEADER) { + next = Replicas.copyWith(nextIsr, NO_LEADER); + if (prevLeader != NO_LEADER) { + offlinePartitionCount.increment(); + } + } else { + next = Replicas.clone(nextIsr); + } + Arrays.sort(next); + } + int i = 0, j = 0; + while (true) { + if (i == prev.length) { + if (j == next.length) { + break; + } + int newReplica = next[j]; + add(newReplica, topicId, partitionId, newReplica == nextLeader); + j++; + } else if (j == next.length) { + int prevReplica = prev[i]; + remove(prevReplica, topicId, partitionId, prevReplica == prevLeader); + i++; + } else { + int prevReplica = prev[i]; + int newReplica = next[j]; + if (prevReplica < newReplica) { + remove(prevReplica, topicId, partitionId, prevReplica == prevLeader); + i++; + } else if (prevReplica > newReplica) { + add(newReplica, topicId, partitionId, newReplica == nextLeader); + j++; + } else { + boolean wasLeader = prevReplica == prevLeader; + boolean isLeader = prevReplica == nextLeader; + if (wasLeader != isLeader) { + change(prevReplica, topicId, partitionId, wasLeader, isLeader); + } + i++; + j++; + } + } + } + } + + void removeTopicEntryForBroker(Uuid topicId, int brokerId) { + Map topicMap = isrMembers.get(brokerId); + if (topicMap != null) { + if (brokerId == NO_LEADER) { + offlinePartitionCount.set(offlinePartitionCount.get() - topicMap.get(topicId).length); + } + topicMap.remove(topicId); + } + } + + private void add(int brokerId, Uuid topicId, int newPartition, boolean leader) { + if (leader) { + newPartition = newPartition | LEADER_FLAG; + } + TimelineHashMap topicMap = isrMembers.get(brokerId); + if (topicMap == null) { + topicMap = new TimelineHashMap<>(snapshotRegistry, 0); + isrMembers.put(brokerId, topicMap); + } + int[] partitions = topicMap.get(topicId); + int[] newPartitions; + if (partitions == null) { + newPartitions = new int[1]; + } else { + newPartitions = new int[partitions.length + 1]; + System.arraycopy(partitions, 0, newPartitions, 0, partitions.length); + } + newPartitions[newPartitions.length - 1] = newPartition; + topicMap.put(topicId, newPartitions); + } + + private void change(int brokerId, Uuid topicId, int partition, + boolean wasLeader, boolean isLeader) { + TimelineHashMap topicMap = 
isrMembers.get(brokerId); + if (topicMap == null) { + throw new RuntimeException("Broker " + brokerId + " has no isrMembers " + + "entry, so we can't change " + topicId + ":" + partition); + } + int[] partitions = topicMap.get(topicId); + if (partitions == null) { + throw new RuntimeException("Broker " + brokerId + " has no " + + "entry in isrMembers for topic " + topicId); + } + int[] newPartitions = new int[partitions.length]; + int target = wasLeader ? partition | LEADER_FLAG : partition; + for (int i = 0; i < partitions.length; i++) { + int cur = partitions[i]; + if (cur == target) { + newPartitions[i] = isLeader ? partition | LEADER_FLAG : partition; + } else { + newPartitions[i] = cur; + } + } + topicMap.put(topicId, newPartitions); + } + + private void remove(int brokerId, Uuid topicId, int removedPartition, boolean leader) { + if (leader) { + removedPartition = removedPartition | LEADER_FLAG; + } + TimelineHashMap topicMap = isrMembers.get(brokerId); + if (topicMap == null) { + throw new RuntimeException("Broker " + brokerId + " has no isrMembers " + + "entry, so we can't remove " + topicId + ":" + removedPartition); + } + int[] partitions = topicMap.get(topicId); + if (partitions == null) { + throw new RuntimeException("Broker " + brokerId + " has no " + + "entry in isrMembers for topic " + topicId); + } + if (partitions.length == 1) { + if (partitions[0] != removedPartition) { + throw new RuntimeException("Broker " + brokerId + " has no " + + "entry in isrMembers for " + topicId + ":" + removedPartition); + } + topicMap.remove(topicId); + if (topicMap.isEmpty()) { + isrMembers.remove(brokerId); + } + } else { + int[] newPartitions = new int[partitions.length - 1]; + int j = 0; + for (int i = 0; i < partitions.length; i++) { + int partition = partitions[i]; + if (partition != removedPartition) { + newPartitions[j++] = partition; + } + } + topicMap.put(topicId, newPartitions); + } + } + + PartitionsOnReplicaIterator iterator(int brokerId, boolean leadersOnly) { + Map topicMap = isrMembers.get(brokerId); + if (topicMap == null) { + topicMap = Collections.emptyMap(); + } + return new PartitionsOnReplicaIterator(topicMap, leadersOnly); + } + + PartitionsOnReplicaIterator partitionsWithNoLeader() { + return iterator(NO_LEADER, true); + } + + PartitionsOnReplicaIterator partitionsLedByBroker(int brokerId) { + return iterator(brokerId, true); + } + + PartitionsOnReplicaIterator partitionsWithBrokerInIsr(int brokerId) { + return iterator(brokerId, false); + } + + boolean hasLeaderships(int brokerId) { + return iterator(brokerId, true).hasNext(); + } + + int offlinePartitionCount() { + return offlinePartitionCount.get(); + } +} diff --git a/metadata/src/main/java/org/apache/kafka/controller/ClientQuotaControlManager.java b/metadata/src/main/java/org/apache/kafka/controller/ClientQuotaControlManager.java new file mode 100644 index 0000000..6e8198b --- /dev/null +++ b/metadata/src/main/java/org/apache/kafka/controller/ClientQuotaControlManager.java @@ -0,0 +1,334 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kafka.controller;
+
+import org.apache.kafka.common.config.ConfigDef;
+import org.apache.kafka.common.config.internals.QuotaConfigs;
+import org.apache.kafka.common.errors.InvalidRequestException;
+import org.apache.kafka.common.metadata.ClientQuotaRecord;
+import org.apache.kafka.common.metadata.ClientQuotaRecord.EntityData;
+import org.apache.kafka.common.protocol.Errors;
+import org.apache.kafka.common.quota.ClientQuotaAlteration;
+import org.apache.kafka.common.quota.ClientQuotaEntity;
+import org.apache.kafka.common.requests.ApiError;
+import org.apache.kafka.server.common.ApiMessageAndVersion;
+import org.apache.kafka.timeline.SnapshotRegistry;
+import org.apache.kafka.timeline.TimelineHashMap;
+
+import java.net.InetAddress;
+import java.net.UnknownHostException;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+import java.util.Map.Entry;
+import java.util.NoSuchElementException;
+import java.util.Objects;
+import java.util.function.Supplier;
+import java.util.stream.Collectors;
+
+import static org.apache.kafka.common.metadata.MetadataRecordType.CLIENT_QUOTA_RECORD;
+
+
+public class ClientQuotaControlManager {
+    private final SnapshotRegistry snapshotRegistry;
+
+    final TimelineHashMap<ClientQuotaEntity, TimelineHashMap<String, Double>> clientQuotaData;
+
+    ClientQuotaControlManager(SnapshotRegistry snapshotRegistry) {
+        this.snapshotRegistry = snapshotRegistry;
+        this.clientQuotaData = new TimelineHashMap<>(snapshotRegistry, 0);
+    }
+
+    /**
+     * Determine the result of applying a batch of client quota alterations. Note
+     * that this method does not change the contents of memory. It just generates a
+     * result that you can replay later, if you wish, using replay().
+     *
+     * @param quotaAlterations List of client quota alterations to evaluate
+     * @return The result.
+     */
+    ControllerResult<Map<ClientQuotaEntity, ApiError>> alterClientQuotas(
+            Collection<ClientQuotaAlteration> quotaAlterations) {
+        List<ApiMessageAndVersion> outputRecords = new ArrayList<>();
+        Map<ClientQuotaEntity, ApiError> outputResults = new HashMap<>();
+
+        quotaAlterations.forEach(quotaAlteration -> {
+            // Note that the values in this map may be null
+            Map<String, Double> alterations = new HashMap<>(quotaAlteration.ops().size());
+            quotaAlteration.ops().forEach(op -> {
+                if (alterations.containsKey(op.key())) {
+                    outputResults.put(quotaAlteration.entity(), ApiError.fromThrowable(
+                        new InvalidRequestException("Duplicate quota key " + op.key() +
+                            " not updating quota for this entity " + quotaAlteration.entity())));
+                } else {
+                    alterations.put(op.key(), op.value());
+                }
+            });
+            if (outputResults.containsKey(quotaAlteration.entity())) {
+                outputResults.put(quotaAlteration.entity(), ApiError.fromThrowable(
+                    new InvalidRequestException("Ignoring duplicate entity " + quotaAlteration.entity())));
+            } else {
+                alterClientQuotaEntity(quotaAlteration.entity(), alterations, outputRecords, outputResults);
+            }
+        });
+
+        return ControllerResult.atomicOf(outputRecords, outputResults);
+    }
+
+    /**
+     * Apply a quota record to the in-memory state.
+     *
+     * @param record A ClientQuotaRecord instance.
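+     *               A record with remove() set deletes the key from the entity's
+     *               quota map; deleting the last key drops the entity entirely.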
+ */ + public void replay(ClientQuotaRecord record) { + Map entityMap = new HashMap<>(2); + record.entity().forEach(entityData -> entityMap.put(entityData.entityType(), entityData.entityName())); + ClientQuotaEntity entity = new ClientQuotaEntity(entityMap); + TimelineHashMap quotas = clientQuotaData.get(entity); + if (quotas == null) { + quotas = new TimelineHashMap<>(snapshotRegistry, 0); + clientQuotaData.put(entity, quotas); + } + if (record.remove()) { + quotas.remove(record.key()); + if (quotas.size() == 0) { + clientQuotaData.remove(entity); + } + } else { + quotas.put(record.key(), record.value()); + } + } + + private void alterClientQuotaEntity( + ClientQuotaEntity entity, + Map newQuotaConfigs, + List outputRecords, + Map outputResults) { + + // Check entity types and sanitize the names + Map validatedEntityMap = new HashMap<>(3); + ApiError error = validateEntity(entity, validatedEntityMap); + if (error.isFailure()) { + outputResults.put(entity, error); + return; + } + + // Check the combination of entity types and get the config keys + Map configKeys = new HashMap<>(4); + error = configKeysForEntityType(validatedEntityMap, configKeys); + if (error.isFailure()) { + outputResults.put(entity, error); + return; + } + + // Don't share objects between different records + Supplier> recordEntitySupplier = () -> + validatedEntityMap.entrySet().stream().map(mapEntry -> new EntityData() + .setEntityType(mapEntry.getKey()) + .setEntityName(mapEntry.getValue())) + .collect(Collectors.toList()); + + List newRecords = new ArrayList<>(newQuotaConfigs.size()); + Map currentQuotas = clientQuotaData.containsKey(entity) ? + clientQuotaData.get(entity) : Collections.emptyMap(); + for (Map.Entry entry : newQuotaConfigs.entrySet()) { + String key = entry.getKey(); + Double newValue = entry.getValue(); + if (newValue == null) { + if (currentQuotas.containsKey(key)) { + // Null value indicates removal + newRecords.add(new ApiMessageAndVersion(new ClientQuotaRecord() + .setEntity(recordEntitySupplier.get()) + .setKey(key) + .setRemove(true), + CLIENT_QUOTA_RECORD.highestSupportedVersion())); + } + } else { + ApiError validationError = validateQuotaKeyValue(configKeys, key, newValue); + if (validationError.isFailure()) { + outputResults.put(entity, validationError); + return; + } else { + final Double currentValue = currentQuotas.get(key); + if (!Objects.equals(currentValue, newValue)) { + // Only record the new value if it has changed + newRecords.add(new ApiMessageAndVersion(new ClientQuotaRecord() + .setEntity(recordEntitySupplier.get()) + .setKey(key) + .setValue(newValue), + CLIENT_QUOTA_RECORD.highestSupportedVersion())); + } + } + } + } + + outputRecords.addAll(newRecords); + outputResults.put(entity, ApiError.NONE); + } + + private ApiError configKeysForEntityType(Map entity, Map output) { + // We only allow certain combinations of quota entity types. 
Which type is in use determines which config
+        // keys are valid
+        boolean hasUser = entity.containsKey(ClientQuotaEntity.USER);
+        boolean hasClientId = entity.containsKey(ClientQuotaEntity.CLIENT_ID);
+        boolean hasIp = entity.containsKey(ClientQuotaEntity.IP);
+
+        final Map<String, ConfigDef.ConfigKey> configKeys;
+        if (hasIp) {
+            if (hasUser || hasClientId) {
+                return new ApiError(Errors.INVALID_REQUEST, "Invalid quota entity combination, IP entity should " +
+                    "not be combined with User or ClientId");
+            } else {
+                if (isValidIpEntity(entity.get(ClientQuotaEntity.IP))) {
+                    configKeys = QuotaConfigs.ipConfigs().configKeys();
+                } else {
+                    return new ApiError(Errors.INVALID_REQUEST, entity.get(ClientQuotaEntity.IP) + " is not a valid IP or resolvable host.");
+                }
+            }
+        } else if (hasUser && hasClientId) {
+            configKeys = QuotaConfigs.userConfigs().configKeys();
+        } else if (hasUser) {
+            configKeys = QuotaConfigs.userConfigs().configKeys();
+        } else if (hasClientId) {
+            configKeys = QuotaConfigs.clientConfigs().configKeys();
+        } else {
+            return new ApiError(Errors.INVALID_REQUEST, "Invalid empty client quota entity");
+        }
+
+        output.putAll(configKeys);
+        return ApiError.NONE;
+    }
+
+    private ApiError validateQuotaKeyValue(Map<String, ConfigDef.ConfigKey> validKeys, String key, Double value) {
+        // TODO can this validation be shared with alter configs?
+        // Ensure we have an allowed quota key
+        ConfigDef.ConfigKey configKey = validKeys.get(key);
+        if (configKey == null) {
+            return new ApiError(Errors.INVALID_REQUEST, "Invalid configuration key " + key);
+        }
+
+        // Ensure the quota value is valid
+        switch (configKey.type()) {
+            case DOUBLE:
+                break;
+            case SHORT:
+            case INT:
+            case LONG:
+                Double epsilon = 1e-6;
+                Long longValue = Double.valueOf(value + epsilon).longValue();
+                if (Math.abs(longValue.doubleValue() - value) > epsilon) {
+                    return new ApiError(Errors.INVALID_REQUEST,
+                        "Configuration " + key + " must be a Long value");
+                }
+                break;
+            default:
+                return new ApiError(Errors.UNKNOWN_SERVER_ERROR,
+                    "Unexpected config type " + configKey.type() + " should be Long or Double");
+        }
+        return ApiError.NONE;
+    }
+
+    // TODO move this somewhere common?
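+    // A null IP entity name refers to the default IP quota entity and is accepted
+    // without a lookup; any other name must be a valid IP address or a host name
+    // that InetAddress can resolve.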
+ private boolean isValidIpEntity(String ip) { + if (Objects.nonNull(ip)) { + try { + InetAddress.getByName(ip); + return true; + } catch (UnknownHostException e) { + return false; + } + } else { + return true; + } + } + + private ApiError validateEntity(ClientQuotaEntity entity, Map validatedEntityMap) { + // Given a quota entity (which is a mapping of entity type to entity name), validate it's types + if (entity.entries().isEmpty()) { + return new ApiError(Errors.INVALID_REQUEST, "Invalid empty client quota entity"); + } + + for (Entry entityEntry : entity.entries().entrySet()) { + String entityType = entityEntry.getKey(); + String entityName = entityEntry.getValue(); + if (validatedEntityMap.containsKey(entityType)) { + return new ApiError(Errors.INVALID_REQUEST, "Invalid client quota entity, duplicate entity entry " + entityType); + } + if (Objects.equals(entityType, ClientQuotaEntity.USER)) { + validatedEntityMap.put(ClientQuotaEntity.USER, entityName); + } else if (Objects.equals(entityType, ClientQuotaEntity.CLIENT_ID)) { + validatedEntityMap.put(ClientQuotaEntity.CLIENT_ID, entityName); + } else if (Objects.equals(entityType, ClientQuotaEntity.IP)) { + validatedEntityMap.put(ClientQuotaEntity.IP, entityName); + } else { + return new ApiError(Errors.INVALID_REQUEST, "Unhandled client quota entity type: " + entityType); + } + + if (entityName != null && entityName.isEmpty()) { + return new ApiError(Errors.INVALID_REQUEST, "Empty " + entityType + " not supported"); + } + } + + return ApiError.NONE; + } + + class ClientQuotaControlIterator implements Iterator> { + private final long epoch; + private final Iterator>> iterator; + + ClientQuotaControlIterator(long epoch) { + this.epoch = epoch; + this.iterator = clientQuotaData.entrySet(epoch).iterator(); + } + + @Override + public boolean hasNext() { + return iterator.hasNext(); + } + + @Override + public List next() { + if (!hasNext()) throw new NoSuchElementException(); + Entry> entry = iterator.next(); + ClientQuotaEntity entity = entry.getKey(); + List records = new ArrayList<>(); + for (Entry quotaEntry : entry.getValue().entrySet(epoch)) { + ClientQuotaRecord record = new ClientQuotaRecord(); + for (Entry entityEntry : entity.entries().entrySet()) { + record.entity().add(new EntityData(). + setEntityType(entityEntry.getKey()). + setEntityName(entityEntry.getValue())); + } + record.setKey(quotaEntry.getKey()); + record.setValue(quotaEntry.getValue()); + record.setRemove(false); + records.add(new ApiMessageAndVersion(record, + CLIENT_QUOTA_RECORD.highestSupportedVersion())); + } + return records; + } + } + + ClientQuotaControlIterator iterator(long epoch) { + return new ClientQuotaControlIterator(epoch); + } +} diff --git a/metadata/src/main/java/org/apache/kafka/controller/ClusterControlManager.java b/metadata/src/main/java/org/apache/kafka/controller/ClusterControlManager.java new file mode 100644 index 0000000..5916cdc --- /dev/null +++ b/metadata/src/main/java/org/apache/kafka/controller/ClusterControlManager.java @@ -0,0 +1,445 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.controller; + +import org.apache.kafka.common.Endpoint; +import org.apache.kafka.common.errors.DuplicateBrokerRegistrationException; +import org.apache.kafka.common.errors.StaleBrokerEpochException; +import org.apache.kafka.common.errors.UnsupportedVersionException; +import org.apache.kafka.common.message.BrokerRegistrationRequestData; +import org.apache.kafka.common.metadata.FenceBrokerRecord; +import org.apache.kafka.common.metadata.RegisterBrokerRecord; +import org.apache.kafka.common.metadata.RegisterBrokerRecord.BrokerEndpoint; +import org.apache.kafka.common.metadata.RegisterBrokerRecord.BrokerEndpointCollection; +import org.apache.kafka.common.metadata.RegisterBrokerRecord.BrokerFeature; +import org.apache.kafka.common.metadata.RegisterBrokerRecord.BrokerFeatureCollection; +import org.apache.kafka.common.metadata.UnfenceBrokerRecord; +import org.apache.kafka.common.metadata.UnregisterBrokerRecord; +import org.apache.kafka.common.security.auth.SecurityProtocol; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.metadata.BrokerRegistration; +import org.apache.kafka.metadata.BrokerRegistrationReply; +import org.apache.kafka.metadata.FeatureMapAndEpoch; +import org.apache.kafka.metadata.VersionRange; +import org.apache.kafka.server.common.ApiMessageAndVersion; +import org.apache.kafka.timeline.SnapshotRegistry; +import org.apache.kafka.timeline.TimelineHashMap; +import org.slf4j.Logger; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; +import java.util.NoSuchElementException; +import java.util.Optional; +import java.util.Set; +import java.util.concurrent.CompletableFuture; +import java.util.stream.Collectors; + +import static org.apache.kafka.common.metadata.MetadataRecordType.REGISTER_BROKER_RECORD; + + +/** + * The ClusterControlManager manages all the hard state associated with the Kafka cluster. + * Hard state is state which appears in the metadata log, such as broker registrations, + * brokers being fenced or unfenced, and broker feature versions. + */ +public class ClusterControlManager { + class ReadyBrokersFuture { + private final CompletableFuture future; + private final int minBrokers; + + ReadyBrokersFuture(CompletableFuture future, int minBrokers) { + this.future = future; + this.minBrokers = minBrokers; + } + + boolean check() { + int numUnfenced = 0; + for (BrokerRegistration registration : brokerRegistrations.values()) { + if (!registration.fenced()) { + numUnfenced++; + } + if (numUnfenced >= minBrokers) { + return true; + } + } + return false; + } + } + + /** + * The SLF4J log context. + */ + private final LogContext logContext; + + /** + * The SLF4J log object. + */ + private final Logger log; + + /** + * The Kafka clock object to use. + */ + private final Time time; + + /** + * How long sessions should last, in nanoseconds. + */ + private final long sessionTimeoutNs; + + /** + * The replica placer to use. 
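+     * It is consulted by placeReplicas() when brokers must be chosen for newly
+     * created partitions, and it is given each registered broker's rack, if any.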
+ */ + private final ReplicaPlacer replicaPlacer; + + /** + * Maps broker IDs to broker registrations. + */ + private final TimelineHashMap brokerRegistrations; + + /** + * A reference to the controller's metrics registry. + */ + private final ControllerMetrics controllerMetrics; + + /** + * The broker heartbeat manager, or null if this controller is on standby. + */ + private BrokerHeartbeatManager heartbeatManager; + + /** + * A future which is completed as soon as we have the given number of brokers + * ready. + */ + private Optional readyBrokersFuture; + + ClusterControlManager(LogContext logContext, + Time time, + SnapshotRegistry snapshotRegistry, + long sessionTimeoutNs, + ReplicaPlacer replicaPlacer, + ControllerMetrics metrics) { + this.logContext = logContext; + this.log = logContext.logger(ClusterControlManager.class); + this.time = time; + this.sessionTimeoutNs = sessionTimeoutNs; + this.replicaPlacer = replicaPlacer; + this.brokerRegistrations = new TimelineHashMap<>(snapshotRegistry, 0); + this.heartbeatManager = null; + this.readyBrokersFuture = Optional.empty(); + this.controllerMetrics = metrics; + } + + /** + * Transition this ClusterControlManager to active. + */ + public void activate() { + heartbeatManager = new BrokerHeartbeatManager(logContext, time, sessionTimeoutNs); + for (BrokerRegistration registration : brokerRegistrations.values()) { + heartbeatManager.touch(registration.id(), registration.fenced(), -1); + } + } + + /** + * Transition this ClusterControlManager to standby. + */ + public void deactivate() { + heartbeatManager = null; + } + + Map brokerRegistrations() { + return brokerRegistrations; + } + + Set fencedBrokerIds() { + return brokerRegistrations.values() + .stream() + .filter(BrokerRegistration::fenced) + .map(BrokerRegistration::id) + .collect(Collectors.toSet()); + } + + /** + * Process an incoming broker registration request. + */ + public ControllerResult registerBroker( + BrokerRegistrationRequestData request, + long brokerEpoch, + FeatureMapAndEpoch finalizedFeatures) { + if (heartbeatManager == null) { + throw new RuntimeException("ClusterControlManager is not active."); + } + int brokerId = request.brokerId(); + BrokerRegistration existing = brokerRegistrations.get(brokerId); + if (existing != null) { + if (heartbeatManager.hasValidSession(brokerId)) { + if (!existing.incarnationId().equals(request.incarnationId())) { + throw new DuplicateBrokerRegistrationException("Another broker is " + + "registered with that broker id."); + } + } else { + if (!existing.incarnationId().equals(request.incarnationId())) { + // Remove any existing session for the old broker incarnation. + heartbeatManager.remove(brokerId); + existing = null; + } + } + } + + RegisterBrokerRecord record = new RegisterBrokerRecord().setBrokerId(brokerId). + setIncarnationId(request.incarnationId()). + setBrokerEpoch(brokerEpoch). + setRack(request.rack()); + for (BrokerRegistrationRequestData.Listener listener : request.listeners()) { + record.endPoints().add(new BrokerEndpoint(). + setHost(listener.host()). + setName(listener.name()). + setPort(listener.port()). 
+ setSecurityProtocol(listener.securityProtocol())); + } + for (BrokerRegistrationRequestData.Feature feature : request.features()) { + Optional finalized = finalizedFeatures.map().get(feature.name()); + if (finalized.isPresent()) { + if (!finalized.get().contains(new VersionRange(feature.minSupportedVersion(), + feature.maxSupportedVersion()))) { + throw new UnsupportedVersionException("Unable to register because " + + "the broker has an unsupported version of " + feature.name()); + } + } + record.features().add(new BrokerFeature(). + setName(feature.name()). + setMinSupportedVersion(feature.minSupportedVersion()). + setMaxSupportedVersion(feature.maxSupportedVersion())); + } + + if (existing == null) { + heartbeatManager.touch(brokerId, true, -1); + } else { + heartbeatManager.touch(brokerId, existing.fenced(), -1); + } + + List records = new ArrayList<>(); + records.add(new ApiMessageAndVersion(record, + REGISTER_BROKER_RECORD.highestSupportedVersion())); + return ControllerResult.of(records, new BrokerRegistrationReply(brokerEpoch)); + } + + public void replay(RegisterBrokerRecord record) { + int brokerId = record.brokerId(); + List listeners = new ArrayList<>(); + for (BrokerEndpoint endpoint : record.endPoints()) { + listeners.add(new Endpoint(endpoint.name(), + SecurityProtocol.forId(endpoint.securityProtocol()), + endpoint.host(), endpoint.port())); + } + Map features = new HashMap<>(); + for (BrokerFeature feature : record.features()) { + features.put(feature.name(), new VersionRange( + feature.minSupportedVersion(), feature.maxSupportedVersion())); + } + + // Update broker registrations. + BrokerRegistration prevRegistration = brokerRegistrations.put(brokerId, + new BrokerRegistration(brokerId, record.brokerEpoch(), + record.incarnationId(), listeners, features, + Optional.ofNullable(record.rack()), record.fenced())); + updateMetrics(prevRegistration, brokerRegistrations.get(brokerId)); + if (prevRegistration == null) { + log.info("Registered new broker: {}", record); + } else if (prevRegistration.incarnationId().equals(record.incarnationId())) { + log.info("Re-registered broker incarnation: {}", record); + } else { + log.info("Re-registered broker id {}: {}", brokerId, record); + } + } + + public void replay(UnregisterBrokerRecord record) { + int brokerId = record.brokerId(); + BrokerRegistration registration = brokerRegistrations.get(brokerId); + if (registration == null) { + throw new RuntimeException(String.format("Unable to replay %s: no broker " + + "registration found for that id", record.toString())); + } else if (registration.epoch() != record.brokerEpoch()) { + throw new RuntimeException(String.format("Unable to replay %s: no broker " + + "registration with that epoch found", record.toString())); + } else { + brokerRegistrations.remove(brokerId); + updateMetrics(registration, brokerRegistrations.get(brokerId)); + log.info("Unregistered broker: {}", record); + } + } + + public void replay(FenceBrokerRecord record) { + int brokerId = record.id(); + BrokerRegistration registration = brokerRegistrations.get(brokerId); + if (registration == null) { + throw new RuntimeException(String.format("Unable to replay %s: no broker " + + "registration found for that id", record.toString())); + } else if (registration.epoch() != record.epoch()) { + throw new RuntimeException(String.format("Unable to replay %s: no broker " + + "registration with that epoch found", record.toString())); + } else { + brokerRegistrations.put(brokerId, registration.cloneWithFencing(true)); + 
updateMetrics(registration, brokerRegistrations.get(brokerId)); + log.info("Fenced broker: {}", record); + } + } + + public void replay(UnfenceBrokerRecord record) { + int brokerId = record.id(); + BrokerRegistration registration = brokerRegistrations.get(brokerId); + if (registration == null) { + throw new RuntimeException(String.format("Unable to replay %s: no broker " + + "registration found for that id", record.toString())); + } else if (registration.epoch() != record.epoch()) { + throw new RuntimeException(String.format("Unable to replay %s: no broker " + + "registration with that epoch found", record.toString())); + } else { + brokerRegistrations.put(brokerId, registration.cloneWithFencing(false)); + updateMetrics(registration, brokerRegistrations.get(brokerId)); + log.info("Unfenced broker: {}", record); + } + if (readyBrokersFuture.isPresent()) { + if (readyBrokersFuture.get().check()) { + readyBrokersFuture.get().future.complete(null); + readyBrokersFuture = Optional.empty(); + } + } + } + + private void updateMetrics(BrokerRegistration prevRegistration, BrokerRegistration registration) { + if (registration == null) { + if (prevRegistration.fenced()) { + controllerMetrics.setFencedBrokerCount(controllerMetrics.fencedBrokerCount() - 1); + } else { + controllerMetrics.setActiveBrokerCount(controllerMetrics.activeBrokerCount() - 1); + } + } else if (prevRegistration == null) { + if (registration.fenced()) { + controllerMetrics.setFencedBrokerCount(controllerMetrics.fencedBrokerCount() + 1); + } else { + controllerMetrics.setActiveBrokerCount(controllerMetrics.activeBrokerCount() + 1); + } + } else { + if (prevRegistration.fenced() && !registration.fenced()) { + controllerMetrics.setFencedBrokerCount(controllerMetrics.fencedBrokerCount() - 1); + controllerMetrics.setActiveBrokerCount(controllerMetrics.activeBrokerCount() + 1); + } else if (!prevRegistration.fenced() && registration.fenced()) { + controllerMetrics.setFencedBrokerCount(controllerMetrics.fencedBrokerCount() + 1); + controllerMetrics.setActiveBrokerCount(controllerMetrics.activeBrokerCount() - 1); + } + } + } + + + public List> placeReplicas(int startPartition, + int numPartitions, + short numReplicas) { + if (heartbeatManager == null) { + throw new RuntimeException("ClusterControlManager is not active."); + } + return heartbeatManager.placeReplicas(startPartition, numPartitions, numReplicas, + id -> brokerRegistrations.get(id).rack(), replicaPlacer); + } + + public boolean unfenced(int brokerId) { + BrokerRegistration registration = brokerRegistrations.get(brokerId); + if (registration == null) return false; + return !registration.fenced(); + } + + BrokerHeartbeatManager heartbeatManager() { + if (heartbeatManager == null) { + throw new RuntimeException("ClusterControlManager is not active."); + } + return heartbeatManager; + } + + public void checkBrokerEpoch(int brokerId, long brokerEpoch) { + BrokerRegistration registration = brokerRegistrations.get(brokerId); + if (registration == null) { + throw new StaleBrokerEpochException("No broker registration found for " + + "broker id " + brokerId); + } + if (registration.epoch() != brokerEpoch) { + throw new StaleBrokerEpochException("Expected broker epoch " + + registration.epoch() + ", but got broker epoch " + brokerEpoch); + } + } + + public void addReadyBrokersFuture(CompletableFuture future, int minBrokers) { + readyBrokersFuture = Optional.of(new ReadyBrokersFuture(future, minBrokers)); + if (readyBrokersFuture.get().check()) { + 
readyBrokersFuture.get().future.complete(null); + readyBrokersFuture = Optional.empty(); + } + } + + class ClusterControlIterator implements Iterator> { + private final Iterator> iterator; + + ClusterControlIterator(long epoch) { + this.iterator = brokerRegistrations.entrySet(epoch).iterator(); + } + + @Override + public boolean hasNext() { + return iterator.hasNext(); + } + + @Override + public List next() { + if (!hasNext()) throw new NoSuchElementException(); + Entry entry = iterator.next(); + int brokerId = entry.getKey(); + BrokerRegistration registration = entry.getValue(); + BrokerEndpointCollection endpoints = new BrokerEndpointCollection(); + for (Entry endpointEntry : registration.listeners().entrySet()) { + endpoints.add(new BrokerEndpoint().setName(endpointEntry.getKey()). + setHost(endpointEntry.getValue().host()). + setPort(endpointEntry.getValue().port()). + setSecurityProtocol(endpointEntry.getValue().securityProtocol().id)); + } + BrokerFeatureCollection features = new BrokerFeatureCollection(); + for (Entry featureEntry : registration.supportedFeatures().entrySet()) { + features.add(new BrokerFeature().setName(featureEntry.getKey()). + setMaxSupportedVersion(featureEntry.getValue().max()). + setMinSupportedVersion(featureEntry.getValue().min())); + } + List batch = new ArrayList<>(); + batch.add(new ApiMessageAndVersion(new RegisterBrokerRecord(). + setBrokerId(brokerId). + setIncarnationId(registration.incarnationId()). + setBrokerEpoch(registration.epoch()). + setEndPoints(endpoints). + setFeatures(features). + setRack(registration.rack().orElse(null)). + setFenced(registration.fenced()), + REGISTER_BROKER_RECORD.highestSupportedVersion())); + return batch; + } + } + + ClusterControlIterator iterator(long epoch) { + return new ClusterControlIterator(epoch); + } +} diff --git a/metadata/src/main/java/org/apache/kafka/controller/ConfigurationControlManager.java b/metadata/src/main/java/org/apache/kafka/controller/ConfigurationControlManager.java new file mode 100644 index 0000000..83f1cbf --- /dev/null +++ b/metadata/src/main/java/org/apache/kafka/controller/ConfigurationControlManager.java @@ -0,0 +1,457 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.controller; + +import org.apache.kafka.clients.admin.AlterConfigOp.OpType; +import org.apache.kafka.common.config.ConfigDef.ConfigKey; +import org.apache.kafka.common.config.ConfigDef; +import org.apache.kafka.common.config.ConfigException; +import org.apache.kafka.common.config.ConfigResource.Type; +import org.apache.kafka.common.config.ConfigResource; +import org.apache.kafka.common.internals.Topic; +import org.apache.kafka.common.metadata.ConfigRecord; +import org.apache.kafka.common.protocol.Errors; +import org.apache.kafka.common.requests.ApiError; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.server.common.ApiMessageAndVersion; +import org.apache.kafka.server.policy.AlterConfigPolicy; +import org.apache.kafka.server.policy.AlterConfigPolicy.RequestMetadata; +import org.apache.kafka.timeline.SnapshotRegistry; +import org.apache.kafka.timeline.TimelineHashMap; +import org.slf4j.Logger; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; +import java.util.NoSuchElementException; +import java.util.Objects; +import java.util.Optional; + +import static org.apache.kafka.clients.admin.AlterConfigOp.OpType.APPEND; +import static org.apache.kafka.common.metadata.MetadataRecordType.CONFIG_RECORD; +import static org.apache.kafka.common.protocol.Errors.INVALID_CONFIG; + + +public class ConfigurationControlManager { + private final Logger log; + private final SnapshotRegistry snapshotRegistry; + private final Map configDefs; + private final Optional alterConfigPolicy; + private final ConfigurationValidator validator; + private final TimelineHashMap> configData; + + ConfigurationControlManager(LogContext logContext, + SnapshotRegistry snapshotRegistry, + Map configDefs, + Optional alterConfigPolicy, + ConfigurationValidator validator) { + this.log = logContext.logger(ConfigurationControlManager.class); + this.snapshotRegistry = snapshotRegistry; + this.configDefs = configDefs; + this.configData = new TimelineHashMap<>(snapshotRegistry, 0); + this.alterConfigPolicy = alterConfigPolicy; + this.validator = validator; + } + + /** + * Determine the result of applying a batch of incremental configuration changes. Note + * that this method does not change the contents of memory. It just generates a + * result, that you can replay later if you wish using replay(). + * + * Note that there can only be one result per ConfigResource. So if you try to modify + * several keys and one modification fails, the whole ConfigKey fails and nothing gets + * changed. + * + * @param configChanges Maps each resource to a map from config keys to + * operation data. + * @return The result. 
+ */ + ControllerResult> incrementalAlterConfigs( + Map>> configChanges) { + List outputRecords = new ArrayList<>(); + Map outputResults = new HashMap<>(); + for (Entry>> resourceEntry : + configChanges.entrySet()) { + incrementalAlterConfigResource(resourceEntry.getKey(), + resourceEntry.getValue(), + outputRecords, + outputResults); + } + return ControllerResult.atomicOf(outputRecords, outputResults); + } + + private void incrementalAlterConfigResource(ConfigResource configResource, + Map> keysToOps, + List outputRecords, + Map outputResults) { + ApiError error = checkConfigResource(configResource); + if (error.isFailure()) { + outputResults.put(configResource, error); + return; + } + List newRecords = new ArrayList<>(); + for (Entry> keysToOpsEntry : keysToOps.entrySet()) { + String key = keysToOpsEntry.getKey(); + String currentValue = null; + TimelineHashMap currentConfigs = configData.get(configResource); + if (currentConfigs != null) { + currentValue = currentConfigs.get(key); + } + String newValue = currentValue; + Entry opTypeAndNewValue = keysToOpsEntry.getValue(); + OpType opType = opTypeAndNewValue.getKey(); + String opValue = opTypeAndNewValue.getValue(); + switch (opType) { + case SET: + newValue = opValue; + break; + case DELETE: + newValue = null; + break; + case APPEND: + case SUBTRACT: + if (!isSplittable(configResource.type(), key)) { + outputResults.put(configResource, new ApiError( + INVALID_CONFIG, "Can't " + opType + " to " + + "key " + key + " because its type is not LIST.")); + return; + } + List newValueParts = getParts(newValue, key, configResource); + if (opType == APPEND) { + if (!newValueParts.contains(opValue)) { + newValueParts.add(opValue); + } + newValue = String.join(",", newValueParts); + } else if (newValueParts.remove(opValue)) { + newValue = String.join(",", newValueParts); + } + break; + } + if (!Objects.equals(currentValue, newValue)) { + newRecords.add(new ApiMessageAndVersion(new ConfigRecord(). + setResourceType(configResource.type().id()). + setResourceName(configResource.name()). + setName(key). + setValue(newValue), CONFIG_RECORD.highestSupportedVersion())); + } + } + error = validateAlterConfig(configResource, newRecords); + if (error.isFailure()) { + outputResults.put(configResource, error); + return; + } + outputRecords.addAll(newRecords); + outputResults.put(configResource, ApiError.NONE); + } + + private ApiError validateAlterConfig(ConfigResource configResource, + List newRecords) { + Map newConfigs = new HashMap<>(); + TimelineHashMap existingConfigs = configData.get(configResource); + if (existingConfigs != null) newConfigs.putAll(existingConfigs); + for (ApiMessageAndVersion newRecord : newRecords) { + ConfigRecord configRecord = (ConfigRecord) newRecord.message(); + if (configRecord.value() == null) { + newConfigs.remove(configRecord.name()); + } else { + newConfigs.put(configRecord.name(), configRecord.value()); + } + } + try { + validator.validate(configResource, newConfigs); + if (alterConfigPolicy.isPresent()) { + alterConfigPolicy.get().validate(new RequestMetadata(configResource, newConfigs)); + } + } catch (ConfigException e) { + return new ApiError(INVALID_CONFIG, e.getMessage()); + } catch (Throwable e) { + return ApiError.fromThrowable(e); + } + return ApiError.NONE; + } + + /** + * Determine the result of applying a batch of legacy configuration changes. Note + * that this method does not change the contents of memory. It just generates a + * result, that you can replay later if you wish using replay(). 
+ * + * @param newConfigs The new configurations to install for each resource. + * All existing configurations will be overwritten. + * @return The result. + */ + ControllerResult> legacyAlterConfigs( + Map> newConfigs) { + List outputRecords = new ArrayList<>(); + Map outputResults = new HashMap<>(); + for (Entry> resourceEntry : + newConfigs.entrySet()) { + legacyAlterConfigResource(resourceEntry.getKey(), + resourceEntry.getValue(), + outputRecords, + outputResults); + } + return ControllerResult.atomicOf(outputRecords, outputResults); + } + + private void legacyAlterConfigResource(ConfigResource configResource, + Map newConfigs, + List outputRecords, + Map outputResults) { + ApiError error = checkConfigResource(configResource); + if (error.isFailure()) { + outputResults.put(configResource, error); + return; + } + List newRecords = new ArrayList<>(); + Map currentConfigs = configData.get(configResource); + if (currentConfigs == null) { + currentConfigs = Collections.emptyMap(); + } + for (Entry entry : newConfigs.entrySet()) { + String key = entry.getKey(); + String newValue = entry.getValue(); + String currentValue = currentConfigs.get(key); + if (!Objects.equals(newValue, currentValue)) { + newRecords.add(new ApiMessageAndVersion(new ConfigRecord(). + setResourceType(configResource.type().id()). + setResourceName(configResource.name()). + setName(key). + setValue(newValue), CONFIG_RECORD.highestSupportedVersion())); + } + } + for (String key : currentConfigs.keySet()) { + if (!newConfigs.containsKey(key)) { + newRecords.add(new ApiMessageAndVersion(new ConfigRecord(). + setResourceType(configResource.type().id()). + setResourceName(configResource.name()). + setName(key). + setValue(null), CONFIG_RECORD.highestSupportedVersion())); + } + } + error = validateAlterConfig(configResource, newRecords); + if (error.isFailure()) { + outputResults.put(configResource, error); + return; + } + outputRecords.addAll(newRecords); + outputResults.put(configResource, ApiError.NONE); + } + + private List getParts(String value, String key, ConfigResource configResource) { + if (value == null) { + value = getConfigValueDefault(configResource.type(), key); + } + List parts = new ArrayList<>(); + if (value == null) { + return parts; + } + String[] splitValues = value.split(","); + for (String splitValue : splitValues) { + if (!splitValue.isEmpty()) { + parts.add(splitValue); + } + } + return parts; + } + + static ApiError checkConfigResource(ConfigResource configResource) { + switch (configResource.type()) { + case BROKER_LOGGER: + // We do not handle resources of type BROKER_LOGGER in + // ConfigurationControlManager, since they are not persisted to the + // metadata log. + // + // When using incrementalAlterConfigs, we handle changes to BROKER_LOGGER + // in ControllerApis.scala. When using the legacy alterConfigs, + // BROKER_LOGGER is not supported at all. + return new ApiError(Errors.INVALID_REQUEST, "Unsupported " + + "configuration resource type BROKER_LOGGER "); + case BROKER: + // Note: A Resource with type BROKER and an empty name represents a + // cluster configuration that applies to all brokers. 
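+                // For example, ConfigResource(BROKER, "") denotes the cluster-wide
+                // default configuration, while ConfigResource(BROKER, "1") applies
+                // only to the broker whose id is 1.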
+ if (!configResource.name().isEmpty()) { + try { + int brokerId = Integer.parseInt(configResource.name()); + if (brokerId < 0) { + return new ApiError(Errors.INVALID_REQUEST, "Illegal " + + "negative broker ID in BROKER resource."); + } + } catch (NumberFormatException e) { + return new ApiError(Errors.INVALID_REQUEST, "Illegal " + + "non-integral BROKER resource type name."); + } + } + return ApiError.NONE; + case TOPIC: + try { + Topic.validate(configResource.name()); + } catch (Exception e) { + return new ApiError(Errors.INVALID_REQUEST, "Illegal topic name."); + } + return ApiError.NONE; + case UNKNOWN: + return new ApiError(Errors.INVALID_REQUEST, "Unsupported configuration " + + "resource type UNKNOWN."); + default: + return new ApiError(Errors.INVALID_REQUEST, "Unsupported unexpected " + + "resource type"); + } + } + + boolean isSplittable(ConfigResource.Type type, String key) { + ConfigDef configDef = configDefs.get(type); + if (configDef == null) { + return false; + } + ConfigKey configKey = configDef.configKeys().get(key); + if (configKey == null) { + return false; + } + return configKey.type == ConfigDef.Type.LIST; + } + + String getConfigValueDefault(ConfigResource.Type type, String key) { + ConfigDef configDef = configDefs.get(type); + if (configDef == null) { + return null; + } + ConfigKey configKey = configDef.configKeys().get(key); + if (configKey == null || !configKey.hasDefault()) { + return null; + } + return ConfigDef.convertToString(configKey.defaultValue, configKey.type); + } + + /** + * Apply a configuration record to the in-memory state. + * + * @param record The ConfigRecord. + */ + public void replay(ConfigRecord record) { + Type type = Type.forId(record.resourceType()); + ConfigResource configResource = new ConfigResource(type, record.resourceName()); + TimelineHashMap configs = configData.get(configResource); + if (configs == null) { + configs = new TimelineHashMap<>(snapshotRegistry, 0); + configData.put(configResource, configs); + } + if (record.value() == null) { + configs.remove(record.name()); + } else { + configs.put(record.name(), record.value()); + } + if (configs.isEmpty()) { + configData.remove(configResource); + } + log.info("{}: set configuration {} to {}", configResource, record.name(), record.value()); + } + + // VisibleForTesting + Map getConfigs(ConfigResource configResource) { + Map map = configData.get(configResource); + if (map == null) { + return Collections.emptyMap(); + } else { + return Collections.unmodifiableMap(new HashMap<>(map)); + } + } + + public Map>> describeConfigs( + long lastCommittedOffset, Map> resources) { + Map>> results = new HashMap<>(); + for (Entry> resourceEntry : resources.entrySet()) { + ConfigResource resource = resourceEntry.getKey(); + ApiError error = checkConfigResource(resource); + if (error.isFailure()) { + results.put(resource, new ResultOrError<>(error)); + continue; + } + Map foundConfigs = new HashMap<>(); + TimelineHashMap configs = + configData.get(resource, lastCommittedOffset); + if (configs != null) { + Collection targetConfigs = resourceEntry.getValue(); + if (targetConfigs.isEmpty()) { + Iterator> iter = + configs.entrySet(lastCommittedOffset).iterator(); + while (iter.hasNext()) { + Entry entry = iter.next(); + foundConfigs.put(entry.getKey(), entry.getValue()); + } + } else { + for (String key : targetConfigs) { + String value = configs.get(key, lastCommittedOffset); + if (value != null) { + foundConfigs.put(key, value); + } + } + } + } + results.put(resource, new ResultOrError<>(foundConfigs)); 
+ } + return results; + } + + void deleteTopicConfigs(String name) { + configData.remove(new ConfigResource(Type.TOPIC, name)); + } + + boolean uncleanLeaderElectionEnabledForTopic(String name) { + return false; // TODO: support configuring unclean leader election. + } + + class ConfigurationControlIterator implements Iterator> { + private final long epoch; + private final Iterator>> iterator; + + ConfigurationControlIterator(long epoch) { + this.epoch = epoch; + this.iterator = configData.entrySet(epoch).iterator(); + } + + @Override + public boolean hasNext() { + return iterator.hasNext(); + } + + @Override + public List next() { + if (!hasNext()) throw new NoSuchElementException(); + List records = new ArrayList<>(); + Entry> entry = iterator.next(); + ConfigResource resource = entry.getKey(); + for (Entry configEntry : entry.getValue().entrySet(epoch)) { + records.add(new ApiMessageAndVersion(new ConfigRecord(). + setResourceName(resource.name()). + setResourceType(resource.type().id()). + setName(configEntry.getKey()). + setValue(configEntry.getValue()), CONFIG_RECORD.highestSupportedVersion())); + } + return records; + } + } + + ConfigurationControlIterator iterator(long epoch) { + return new ConfigurationControlIterator(epoch); + } +} diff --git a/metadata/src/main/java/org/apache/kafka/controller/ConfigurationValidator.java b/metadata/src/main/java/org/apache/kafka/controller/ConfigurationValidator.java new file mode 100644 index 0000000..b14580a --- /dev/null +++ b/metadata/src/main/java/org/apache/kafka/controller/ConfigurationValidator.java @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.controller; + +import org.apache.kafka.common.config.ConfigResource; + +import java.util.Map; + + +public interface ConfigurationValidator { + ConfigurationValidator NO_OP = (__, ___) -> { }; + + /** + * Throws an ApiException if a configuration is invalid for the given resource. + * + * @param resource The configuration resource. + * @param config The new configuration. + */ + void validate(ConfigResource resource, Map config); +} diff --git a/metadata/src/main/java/org/apache/kafka/controller/Controller.java b/metadata/src/main/java/org/apache/kafka/controller/Controller.java new file mode 100644 index 0000000..f06b108 --- /dev/null +++ b/metadata/src/main/java/org/apache/kafka/controller/Controller.java @@ -0,0 +1,279 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.controller; + +import org.apache.kafka.clients.admin.AlterConfigOp; +import org.apache.kafka.common.Uuid; +import org.apache.kafka.common.config.ConfigResource; +import org.apache.kafka.common.message.AllocateProducerIdsRequestData; +import org.apache.kafka.common.message.AllocateProducerIdsResponseData; +import org.apache.kafka.common.message.AlterIsrRequestData; +import org.apache.kafka.common.message.AlterIsrResponseData; +import org.apache.kafka.common.message.AlterPartitionReassignmentsRequestData; +import org.apache.kafka.common.message.AlterPartitionReassignmentsResponseData; +import org.apache.kafka.common.message.BrokerHeartbeatRequestData; +import org.apache.kafka.common.message.BrokerRegistrationRequestData; +import org.apache.kafka.common.message.CreatePartitionsRequestData.CreatePartitionsTopic; +import org.apache.kafka.common.message.CreatePartitionsResponseData.CreatePartitionsTopicResult; +import org.apache.kafka.common.message.CreateTopicsRequestData; +import org.apache.kafka.common.message.CreateTopicsResponseData; +import org.apache.kafka.common.message.ElectLeadersRequestData; +import org.apache.kafka.common.message.ElectLeadersResponseData; +import org.apache.kafka.common.message.ListPartitionReassignmentsRequestData; +import org.apache.kafka.common.message.ListPartitionReassignmentsResponseData; +import org.apache.kafka.common.quota.ClientQuotaAlteration; +import org.apache.kafka.common.quota.ClientQuotaEntity; +import org.apache.kafka.common.requests.ApiError; +import org.apache.kafka.metadata.BrokerHeartbeatReply; +import org.apache.kafka.metadata.BrokerRegistrationReply; +import org.apache.kafka.metadata.FeatureMapAndEpoch; + +import java.util.Collection; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CompletableFuture; + + +public interface Controller extends AutoCloseable { + /** + * Change partition ISRs. + * + * @param request The AlterIsrRequest data. + * + * @return A future yielding the response. + */ + CompletableFuture alterIsr(AlterIsrRequestData request); + + /** + * Create a batch of topics. + * + * @param request The CreateTopicsRequest data. + * + * @return A future yielding the response. + */ + CompletableFuture + createTopics(CreateTopicsRequestData request); + + /** + * Unregister a broker. + * + * @param brokerId The broker id to unregister. + * + * @return A future that is completed successfully when the broker is + * unregistered. + */ + CompletableFuture unregisterBroker(int brokerId); + + /** + * Find the ids for topic names. + * + * @param deadlineNs The time by which this operation needs to be complete, before + * we will complete this operation with a timeout. + * @param topicNames The topic names to resolve. + * @return A future yielding a map from topic name to id. + */ + CompletableFuture>> findTopicIds(long deadlineNs, + Collection topicNames); + + /** + * Find the names for topic ids. + * + * @param deadlineNs The time by which this operation needs to be complete, before + * we will complete this operation with a timeout. 
+ * @param topicIds The topic ids to resolve. + * @return A future yielding a map from topic id to name. + */ + CompletableFuture>> findTopicNames(long deadlineNs, + Collection topicIds); + + /** + * Delete a batch of topics. + * + * @param deadlineNs The time by which this operation needs to be complete, before + * we will complete this operation with a timeout. + * @param topicIds The IDs of the topics to delete. + * + * @return A future yielding the response. + */ + CompletableFuture> deleteTopics(long deadlineNs, + Collection topicIds); + + /** + * Describe the current configuration of various resources. + * + * @param resources A map from resources to the collection of config keys that we + * want to describe for each. If the collection is empty, then + * all configuration keys will be described. + * + * @return + */ + CompletableFuture>>> + describeConfigs(Map> resources); + + /** + * Elect new partition leaders. + * + * @param request The request. + * + * @return A future yielding the elect leaders response. + */ + CompletableFuture electLeaders(ElectLeadersRequestData request); + + /** + * Get the current finalized feature ranges for each feature. + * + * @return A future yielding the feature ranges. + */ + CompletableFuture finalizedFeatures(); + + /** + * Perform some incremental configuration changes. + * + * @param configChanges The changes. + * @param validateOnly True if we should validate the changes but not apply them. + * + * @return A future yielding a map from config resources to error results. + */ + CompletableFuture> incrementalAlterConfigs( + Map>> configChanges, + boolean validateOnly); + + /** + * Start or stop some partition reassignments. + * + * @param request The alter partition reassignments request. + * + * @return A future yielding the results. + */ + CompletableFuture + alterPartitionReassignments(AlterPartitionReassignmentsRequestData request); + + /** + * List ongoing partition reassignments. + * + * @param request The list partition reassignments request. + * + * @return A future yielding the results. + */ + CompletableFuture + listPartitionReassignments(ListPartitionReassignmentsRequestData request); + + /** + * Perform some configuration changes using the legacy API. + * + * @param newConfigs The new configuration maps to apply. + * @param validateOnly True if we should validate the changes but not apply them. + * + * @return A future yielding a map from config resources to error results. + */ + CompletableFuture> legacyAlterConfigs( + Map> newConfigs, boolean validateOnly); + + /** + * Process a heartbeat from a broker. + * + * @param request The broker heartbeat request. + * + * @return A future yielding the broker heartbeat reply. + */ + CompletableFuture processBrokerHeartbeat( + BrokerHeartbeatRequestData request); + + /** + * Attempt to register the given broker. + * + * @param request The registration request. + * + * @return A future yielding the broker registration reply. + */ + CompletableFuture registerBroker( + BrokerRegistrationRequestData request); + + /** + * Wait for the given number of brokers to be registered and unfenced. + * This is for testing. + * + * @param minBrokers The minimum number of brokers to wait for. + * @return A future which is completed when the given number of brokers + * is reached. 
+ */ + CompletableFuture waitForReadyBrokers(int minBrokers); + + /** + * Perform some client quota changes + * + * @param quotaAlterations The list of quotas to alter + * @param validateOnly True if we should validate the changes but not apply them. + * @return A future yielding a map of quota entities to error results. + */ + CompletableFuture> alterClientQuotas( + Collection quotaAlterations, boolean validateOnly + ); + + /** + * Allocate a block of producer IDs for transactional and idempotent producers + * @param request The allocate producer IDs request + * @return A future which yields a new producer ID block as a response + */ + CompletableFuture allocateProducerIds( + AllocateProducerIdsRequestData request + ); + + /** + * Begin writing a controller snapshot. If there was already an ongoing snapshot, it + * simply returns information about that snapshot rather than starting a new one. + * + * @return A future yielding the epoch of the snapshot. + */ + CompletableFuture beginWritingSnapshot(); + + /** + * Create partitions on certain topics. + * + * @param deadlineNs The time by which this operation needs to be complete, before + * we will complete this operation with a timeout. + * @param topics The list of topics to create partitions for. + * @return A future yielding per-topic results. + */ + CompletableFuture> + createPartitions(long deadlineNs, List topics); + + /** + * Begin shutting down, but don't block. You must still call close to clean up all + * resources. + */ + void beginShutdown(); + + /** + * If this controller is active, this is the non-negative controller epoch. + * Otherwise, this is -1. + */ + int curClaimEpoch(); + + /** + * Returns true if this controller is currently active. + */ + default boolean isActive() { + return curClaimEpoch() != -1; + } + + /** + * Blocks until we have shut down and freed all resources. + */ + void close() throws InterruptedException; +} diff --git a/metadata/src/main/java/org/apache/kafka/controller/ControllerMetrics.java b/metadata/src/main/java/org/apache/kafka/controller/ControllerMetrics.java new file mode 100644 index 0000000..fa03e05 --- /dev/null +++ b/metadata/src/main/java/org/apache/kafka/controller/ControllerMetrics.java @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.controller; + + +public interface ControllerMetrics extends AutoCloseable { + void setActive(boolean active); + + boolean active(); + + void updateEventQueueTime(long durationMs); + + void updateEventQueueProcessingTime(long durationMs); + + void setFencedBrokerCount(int brokerCount); + + int fencedBrokerCount(); + + void setActiveBrokerCount(int brokerCount); + + int activeBrokerCount(); + + void setGlobalTopicsCount(int topicCount); + + int globalTopicsCount(); + + void setGlobalPartitionCount(int partitionCount); + + int globalPartitionCount(); + + void setOfflinePartitionCount(int offlinePartitions); + + int offlinePartitionCount(); + + void setPreferredReplicaImbalanceCount(int replicaImbalances); + + int preferredReplicaImbalanceCount(); + + void close(); +} diff --git a/metadata/src/main/java/org/apache/kafka/controller/ControllerPurgatory.java b/metadata/src/main/java/org/apache/kafka/controller/ControllerPurgatory.java new file mode 100644 index 0000000..ee6c1d1 --- /dev/null +++ b/metadata/src/main/java/org/apache/kafka/controller/ControllerPurgatory.java @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.controller; + +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; +import java.util.Map.Entry; +import java.util.Optional; +import java.util.TreeMap; + +/** + * The purgatory which holds events that have been started, but not yet completed. + * We wait for the high water mark of the metadata log to advance before completing + * them. + */ +class ControllerPurgatory { + /** + * A map from log offsets to events. Each event will be completed once the log + * advances past its offset. + */ + private final TreeMap> pending = new TreeMap<>(); + + /** + * Complete some purgatory entries. + * + * @param offset The offset which the high water mark has advanced to. + */ + void completeUpTo(long offset) { + Iterator>> iter = pending.entrySet().iterator(); + while (iter.hasNext()) { + Entry> entry = iter.next(); + if (entry.getKey() > offset) { + break; + } + for (DeferredEvent event : entry.getValue()) { + event.complete(null); + } + iter.remove(); + } + } + + /** + * Fail all the pending purgatory entries. + * + * @param exception The exception to fail the entries with. + */ + void failAll(Exception exception) { + Iterator>> iter = pending.entrySet().iterator(); + while (iter.hasNext()) { + Entry> entry = iter.next(); + for (DeferredEvent event : entry.getValue()) { + event.complete(exception); + } + iter.remove(); + } + } + + /** + * Add a new purgatory event. + * + * @param offset The offset to add the new event at. + * @param event The new event. 
+ */ + void add(long offset, DeferredEvent event) { + if (!pending.isEmpty()) { + long lastKey = pending.lastKey(); + if (offset < lastKey) { + throw new RuntimeException("There is already a purgatory event with " + + "offset " + lastKey + ". We should not add one with an offset of " + + offset + " which " + "is lower than that."); + } + } + List events = pending.get(offset); + if (events == null) { + events = new ArrayList<>(); + pending.put(offset, events); + } + events.add(event); + } + + /** + * Get the offset of the highest pending event, or empty if there are no pending + * events. + */ + Optional highestPendingOffset() { + if (pending.isEmpty()) { + return Optional.empty(); + } else { + return Optional.of(pending.lastKey()); + } + } +} \ No newline at end of file diff --git a/metadata/src/main/java/org/apache/kafka/controller/ControllerResult.java b/metadata/src/main/java/org/apache/kafka/controller/ControllerResult.java new file mode 100644 index 0000000..d130de5 --- /dev/null +++ b/metadata/src/main/java/org/apache/kafka/controller/ControllerResult.java @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
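Illustrative sketch (not part of the patch): how the purgatory above parks DeferredEvents at log offsets and releases them as the high water mark advances. The class name is hypothetical, and the sketch assumes it lives in the org.apache.kafka.controller package, since ControllerPurgatory and DeferredEvent are package-private.

package org.apache.kafka.controller;

public class ControllerPurgatorySketch {
    public static void main(String[] args) {
        ControllerPurgatory purgatory = new ControllerPurgatory();

        // DeferredEvent has a single complete(Throwable) method, so lambdas work here.
        DeferredEvent eventAt100 = exception -> System.out.println("offset 100 done: " + exception);
        DeferredEvent eventAt200 = exception -> System.out.println("offset 200 done: " + exception);

        purgatory.add(100L, eventAt100);
        purgatory.add(200L, eventAt200);
        System.out.println(purgatory.highestPendingOffset());  // Optional[200]

        // Only entries at or below the new high water mark are completed (with a null exception).
        purgatory.completeUpTo(150L);

        // A controller that renounces leadership fails everything still pending.
        purgatory.failAll(new RuntimeException("no longer the active controller"));
    }
}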
+ */ + +package org.apache.kafka.controller; + +import org.apache.kafka.server.common.ApiMessageAndVersion; + +import java.util.Collections; +import java.util.List; +import java.util.Objects; +import java.util.stream.Collectors; + + +class ControllerResult { + private final List records; + private final T response; + private final boolean isAtomic; + + protected ControllerResult(List records, T response, boolean isAtomic) { + Objects.requireNonNull(records); + this.records = records; + this.response = response; + this.isAtomic = isAtomic; + } + + public List records() { + return records; + } + + public T response() { + return response; + } + + public boolean isAtomic() { + return isAtomic; + } + + @Override + public boolean equals(Object o) { + if (o == null || (!o.getClass().equals(getClass()))) { + return false; + } + ControllerResult other = (ControllerResult) o; + return records.equals(other.records) && + Objects.equals(response, other.response) && + Objects.equals(isAtomic, other.isAtomic); + } + + @Override + public int hashCode() { + return Objects.hash(records, response, isAtomic); + } + + @Override + public String toString() { + return String.format( + "ControllerResult(records=%s, response=%s, isAtomic=%s)", + String.join(",", records.stream().map(ApiMessageAndVersion::toString).collect(Collectors.toList())), + response, + isAtomic + ); + } + + public ControllerResult withoutRecords() { + return new ControllerResult<>(Collections.emptyList(), response, false); + } + + public static ControllerResult atomicOf(List records, T response) { + return new ControllerResult<>(records, response, true); + } + + public static ControllerResult of(List records, T response) { + return new ControllerResult<>(records, response, false); + } +} diff --git a/metadata/src/main/java/org/apache/kafka/controller/ControllerResultAndOffset.java b/metadata/src/main/java/org/apache/kafka/controller/ControllerResultAndOffset.java new file mode 100644 index 0000000..1b72565 --- /dev/null +++ b/metadata/src/main/java/org/apache/kafka/controller/ControllerResultAndOffset.java @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
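Illustrative sketch (not part of the patch): the difference between the atomic and non-atomic factories, and what withoutRecords() keeps. The class name is hypothetical, the sketch assumes the org.apache.kafka.controller package (ControllerResult is package-private), and the type parameters dropped by this rendering of the diff are assumed (List of ApiMessageAndVersion for the record list, a generic response type T).

package org.apache.kafka.controller;

import org.apache.kafka.common.Uuid;
import org.apache.kafka.common.metadata.TopicRecord;
import org.apache.kafka.server.common.ApiMessageAndVersion;

import java.util.Collections;
import java.util.List;

public class ControllerResultSketch {
    public static void main(String[] args) {
        List<ApiMessageAndVersion> records = Collections.singletonList(new ApiMessageAndVersion(
            new TopicRecord().setName("foo").setTopicId(Uuid.randomUuid()), (short) 0));

        // Records that must reach the metadata log as a single atomic batch.
        ControllerResult<String> atomic = ControllerResult.atomicOf(records, "created foo");
        // Records that may be split across several log batches.
        ControllerResult<String> nonAtomic = ControllerResult.of(records, "created foo");

        System.out.println(atomic.isAtomic());                  // true
        System.out.println(nonAtomic.isAtomic());                // false
        System.out.println(atomic.withoutRecords().records());   // []  (same response, no records)
    }
}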
+ */ + +package org.apache.kafka.controller; + +import org.apache.kafka.server.common.ApiMessageAndVersion; + +import java.util.Objects; +import java.util.stream.Collectors; + + +final class ControllerResultAndOffset extends ControllerResult { + private final long offset; + + private ControllerResultAndOffset(long offset, ControllerResult result) { + super(result.records(), result.response(), result.isAtomic()); + this.offset = offset; + } + + public long offset() { + return offset; + } + + @Override + public boolean equals(Object o) { + if (o == null || (!o.getClass().equals(getClass()))) { + return false; + } + ControllerResultAndOffset other = (ControllerResultAndOffset) o; + return records().equals(other.records()) && + response().equals(other.response()) && + isAtomic() == other.isAtomic() && + offset == other.offset; + } + + @Override + public int hashCode() { + return Objects.hash(records(), response(), isAtomic(), offset); + } + + @Override + public String toString() { + return String.format( + "ControllerResultAndOffset(records=%s, response=%s, isAtomic=%s, offset=%s)", + String.join(",", records().stream().map(ApiMessageAndVersion::toString).collect(Collectors.toList())), + response(), + isAtomic(), + offset + ); + } + + public static ControllerResultAndOffset of(long offset, ControllerResult result) { + return new ControllerResultAndOffset<>(offset, result); + } +} diff --git a/metadata/src/main/java/org/apache/kafka/controller/DeferredEvent.java b/metadata/src/main/java/org/apache/kafka/controller/DeferredEvent.java new file mode 100644 index 0000000..e1606f3 --- /dev/null +++ b/metadata/src/main/java/org/apache/kafka/controller/DeferredEvent.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.controller; + +/** + * Represents a deferred event in the controller purgatory. + */ +interface DeferredEvent { + /** + * Complete the event. + * + * @param exception null if the event should be completed successfully; the + * error otherwise. + */ + void complete(Throwable exception); +} diff --git a/metadata/src/main/java/org/apache/kafka/controller/FeatureControlManager.java b/metadata/src/main/java/org/apache/kafka/controller/FeatureControlManager.java new file mode 100644 index 0000000..ed7c98c --- /dev/null +++ b/metadata/src/main/java/org/apache/kafka/controller/FeatureControlManager.java @@ -0,0 +1,156 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.controller; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; +import java.util.Map.Entry; +import java.util.Map; +import java.util.NoSuchElementException; +import java.util.Set; +import java.util.TreeMap; + +import org.apache.kafka.common.metadata.FeatureLevelRecord; +import org.apache.kafka.common.protocol.Errors; +import org.apache.kafka.common.requests.ApiError; +import org.apache.kafka.server.common.ApiMessageAndVersion; +import org.apache.kafka.metadata.FeatureMap; +import org.apache.kafka.metadata.FeatureMapAndEpoch; +import org.apache.kafka.metadata.VersionRange; +import org.apache.kafka.timeline.SnapshotRegistry; +import org.apache.kafka.timeline.TimelineHashMap; + +import static org.apache.kafka.common.metadata.MetadataRecordType.FEATURE_LEVEL_RECORD; + + +public class FeatureControlManager { + /** + * An immutable map containing the features supported by this controller's software. + */ + private final Map supportedFeatures; + + /** + * Maps feature names to finalized version ranges. + */ + private final TimelineHashMap finalizedVersions; + + FeatureControlManager(Map supportedFeatures, + SnapshotRegistry snapshotRegistry) { + this.supportedFeatures = supportedFeatures; + this.finalizedVersions = new TimelineHashMap<>(snapshotRegistry, 0); + } + + ControllerResult> updateFeatures( + Map updates, Set downgradeables, + Map> brokerFeatures) { + TreeMap results = new TreeMap<>(); + List records = new ArrayList<>(); + for (Entry entry : updates.entrySet()) { + results.put(entry.getKey(), updateFeature(entry.getKey(), entry.getValue(), + downgradeables.contains(entry.getKey()), brokerFeatures, records)); + } + + return ControllerResult.atomicOf(records, results); + } + + private ApiError updateFeature(String featureName, + VersionRange newRange, + boolean downgradeable, + Map> brokerFeatures, + List records) { + if (newRange.min() <= 0) { + return new ApiError(Errors.INVALID_UPDATE_VERSION, + "The lower value for the new range cannot be less than 1."); + } + if (newRange.max() <= 0) { + return new ApiError(Errors.INVALID_UPDATE_VERSION, + "The upper value for the new range cannot be less than 1."); + } + VersionRange localRange = supportedFeatures.get(featureName); + if (localRange == null || !localRange.contains(newRange)) { + return new ApiError(Errors.INVALID_UPDATE_VERSION, + "The controller does not support the given feature range."); + } + for (Entry> brokerEntry : + brokerFeatures.entrySet()) { + VersionRange brokerRange = brokerEntry.getValue().get(featureName); + if (brokerRange == null || !brokerRange.contains(newRange)) { + return new ApiError(Errors.INVALID_UPDATE_VERSION, + "Broker " + brokerEntry.getKey() + " does not support the given " + + "feature range."); + } + } + VersionRange currentRange = finalizedVersions.get(featureName); + if (currentRange != null && currentRange.max() > newRange.max()) { + if (!downgradeable) { + return new ApiError(Errors.INVALID_UPDATE_VERSION, + "Can't downgrade the maximum version of this feature 
without " + + "setting downgradable to true."); + } + } + records.add(new ApiMessageAndVersion( + new FeatureLevelRecord().setName(featureName). + setMinFeatureLevel(newRange.min()).setMaxFeatureLevel(newRange.max()), + FEATURE_LEVEL_RECORD.highestSupportedVersion())); + return ApiError.NONE; + } + + FeatureMapAndEpoch finalizedFeatures(long lastCommittedOffset) { + Map features = new HashMap<>(); + for (Entry entry : finalizedVersions.entrySet(lastCommittedOffset)) { + features.put(entry.getKey(), entry.getValue()); + } + return new FeatureMapAndEpoch(new FeatureMap(features), lastCommittedOffset); + } + + public void replay(FeatureLevelRecord record) { + finalizedVersions.put(record.name(), + new VersionRange(record.minFeatureLevel(), record.maxFeatureLevel())); + } + + class FeatureControlIterator implements Iterator> { + private final Iterator> iterator; + + FeatureControlIterator(long epoch) { + this.iterator = finalizedVersions.entrySet(epoch).iterator(); + } + + @Override + public boolean hasNext() { + return iterator.hasNext(); + } + + @Override + public List next() { + if (!hasNext()) throw new NoSuchElementException(); + Entry entry = iterator.next(); + VersionRange versions = entry.getValue(); + return Collections.singletonList(new ApiMessageAndVersion(new FeatureLevelRecord(). + setName(entry.getKey()). + setMinFeatureLevel(versions.min()). + setMaxFeatureLevel(versions.max()), FEATURE_LEVEL_RECORD.highestSupportedVersion())); + } + } + + FeatureControlIterator iterator(long epoch) { + return new FeatureControlIterator(epoch); + } +} diff --git a/metadata/src/main/java/org/apache/kafka/controller/PartitionChangeBuilder.java b/metadata/src/main/java/org/apache/kafka/controller/PartitionChangeBuilder.java new file mode 100644 index 0000000..cf0f6bf --- /dev/null +++ b/metadata/src/main/java/org/apache/kafka/controller/PartitionChangeBuilder.java @@ -0,0 +1,266 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
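Illustrative sketch (not part of the patch): the validation that updateFeatures() applies above, where a finalized range is only accepted if the controller's own supported range and every registered broker's range contain it. The class name is hypothetical, the sketch assumes the org.apache.kafka.controller package, and the VersionRange(short, short) and SnapshotRegistry(LogContext) constructors as well as the type parameters dropped by this rendering are assumptions.

package org.apache.kafka.controller;

import org.apache.kafka.common.requests.ApiError;
import org.apache.kafka.common.utils.LogContext;
import org.apache.kafka.metadata.VersionRange;
import org.apache.kafka.timeline.SnapshotRegistry;

import java.util.Collections;
import java.util.Map;

public class FeatureControlSketch {
    public static void main(String[] args) {
        // This controller's software supports feature "foo" at levels 1-5.
        FeatureControlManager manager = new FeatureControlManager(
            Collections.singletonMap("foo", new VersionRange((short) 1, (short) 5)),
            new SnapshotRegistry(new LogContext()));

        // Try to finalize "foo" at 1-2. Broker 0 supports 1-3, so the new range is
        // contained everywhere and the update should be accepted.
        ControllerResult<Map<String, ApiError>> result = manager.updateFeatures(
            Collections.singletonMap("foo", new VersionRange((short) 1, (short) 2)),
            Collections.emptySet(),      // nothing marked as downgradeable
            Collections.singletonMap(0,
                Collections.singletonMap("foo", new VersionRange((short) 1, (short) 3))));

        System.out.println(result.response().get("foo"));   // NONE if the update was accepted
        System.out.println(result.records().size());        // 1: a single FeatureLevelRecord
    }
}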
+ */ + +package org.apache.kafka.controller; + +import org.apache.kafka.common.Uuid; +import org.apache.kafka.common.metadata.PartitionChangeRecord; +import org.apache.kafka.metadata.PartitionRegistration; +import org.apache.kafka.metadata.Replicas; +import org.apache.kafka.server.common.ApiMessageAndVersion; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Optional; +import java.util.function.Function; +import java.util.function.Supplier; + +import static org.apache.kafka.common.metadata.MetadataRecordType.PARTITION_CHANGE_RECORD; +import static org.apache.kafka.metadata.LeaderConstants.NO_LEADER; +import static org.apache.kafka.metadata.LeaderConstants.NO_LEADER_CHANGE; + +/** + * PartitionChangeBuilder handles changing partition registrations. + */ +public class PartitionChangeBuilder { + private static final Logger log = LoggerFactory.getLogger(PartitionChangeBuilder.class); + + public static boolean changeRecordIsNoOp(PartitionChangeRecord record) { + if (record.isr() != null) return false; + if (record.leader() != NO_LEADER_CHANGE) return false; + if (record.replicas() != null) return false; + if (record.removingReplicas() != null) return false; + if (record.addingReplicas() != null) return false; + return true; + } + + private final PartitionRegistration partition; + private final Uuid topicId; + private final int partitionId; + private final Function isAcceptableLeader; + private final Supplier uncleanElectionOk; + private List targetIsr; + private List targetReplicas; + private List targetRemoving; + private List targetAdding; + private boolean alwaysElectPreferredIfPossible; + + public PartitionChangeBuilder(PartitionRegistration partition, + Uuid topicId, + int partitionId, + Function isAcceptableLeader, + Supplier uncleanElectionOk) { + this.partition = partition; + this.topicId = topicId; + this.partitionId = partitionId; + this.isAcceptableLeader = isAcceptableLeader; + this.uncleanElectionOk = uncleanElectionOk; + this.targetIsr = Replicas.toList(partition.isr); + this.targetReplicas = Replicas.toList(partition.replicas); + this.targetRemoving = Replicas.toList(partition.removingReplicas); + this.targetAdding = Replicas.toList(partition.addingReplicas); + this.alwaysElectPreferredIfPossible = false; + } + + public PartitionChangeBuilder setTargetIsr(List targetIsr) { + this.targetIsr = targetIsr; + return this; + } + + public PartitionChangeBuilder setTargetReplicas(List targetReplicas) { + this.targetReplicas = targetReplicas; + return this; + } + + public PartitionChangeBuilder setAlwaysElectPreferredIfPossible(boolean alwaysElectPreferredIfPossible) { + this.alwaysElectPreferredIfPossible = alwaysElectPreferredIfPossible; + return this; + } + + public PartitionChangeBuilder setTargetRemoving(List targetRemoving) { + this.targetRemoving = targetRemoving; + return this; + } + + public PartitionChangeBuilder setTargetAdding(List targetAdding) { + this.targetAdding = targetAdding; + return this; + } + + boolean shouldTryElection() { + // If the new isr doesn't have the current leader, we need to try to elect a new + // one. Note: this also handles the case where the current leader is NO_LEADER, + // since that value cannot appear in targetIsr. + if (!targetIsr.contains(partition.leader)) return true; + + // Check if we want to try to get away from a non-preferred leader. 
+ if (alwaysElectPreferredIfPossible && !partition.hasPreferredLeader()) return true; + + return false; + } + + class BestLeader { + final int node; + final boolean unclean; + + BestLeader() { + for (int replica : targetReplicas) { + if (targetIsr.contains(replica) && isAcceptableLeader.apply(replica)) { + this.node = replica; + this.unclean = false; + return; + } + } + if (uncleanElectionOk.get()) { + for (int replica : targetReplicas) { + if (isAcceptableLeader.apply(replica)) { + this.node = replica; + this.unclean = true; + return; + } + } + } + this.node = NO_LEADER; + this.unclean = false; + } + } + + private void tryElection(PartitionChangeRecord record) { + BestLeader bestLeader = new BestLeader(); + if (bestLeader.node != partition.leader) { + log.debug("Setting new leader for topicId {}, partition {} to {}", topicId, partitionId, bestLeader.node); + record.setLeader(bestLeader.node); + if (bestLeader.unclean) { + // If the election was unclean, we have to forcibly set the ISR to just the + // new leader. This can result in data loss! + record.setIsr(Collections.singletonList(bestLeader.node)); + } + } else { + log.debug("Failed to find a new leader with current state: {}", this); + } + } + + /** + * Trigger a leader epoch bump if one is needed. + * + * We need to bump the leader epoch if: + * 1. The leader changed, or + * 2. The new ISR does not contain all the nodes that the old ISR did, or + * 3. The new replia list does not contain all the nodes that the old replia list did. + * + * Changes that do NOT fall in any of these categories will increase the partition epoch, but + * not the leader epoch. Note that if the leader epoch increases, the partition epoch will + * always increase as well; there is no case where the partition epoch increases more slowly + * than the leader epoch. + * + * If the PartitionChangeRecord sets the leader field to something other than + * NO_LEADER_CHANGE, a leader epoch bump will automatically occur. That takes care of + * case 1. In this function, we check for cases 2 and 3, and handle them by manually + * setting record.leader to the current leader. + */ + void triggerLeaderEpochBumpIfNeeded(PartitionChangeRecord record) { + if (record.leader() == NO_LEADER_CHANGE) { + if (!Replicas.contains(targetIsr, partition.isr) || + !Replicas.contains(targetReplicas, partition.replicas)) { + record.setLeader(partition.leader); + } + } + } + + private void completeReassignmentIfNeeded() { + // Check if there is a reassignment to complete. + if (targetRemoving.isEmpty() && targetAdding.isEmpty()) return; + + List newTargetIsr = targetIsr; + List newTargetReplicas = targetReplicas; + if (!targetRemoving.isEmpty()) { + newTargetIsr = new ArrayList<>(targetIsr.size()); + for (int replica : targetIsr) { + if (!targetRemoving.contains(replica)) { + newTargetIsr.add(replica); + } + } + if (newTargetIsr.isEmpty()) return; + newTargetReplicas = new ArrayList<>(targetReplicas.size()); + for (int replica : targetReplicas) { + if (!targetRemoving.contains(replica)) { + newTargetReplicas.add(replica); + } + } + if (newTargetReplicas.isEmpty()) return; + } + for (int replica : targetAdding) { + if (!newTargetIsr.contains(replica)) return; + } + targetIsr = newTargetIsr; + targetReplicas = newTargetReplicas; + targetRemoving = Collections.emptyList(); + targetAdding = Collections.emptyList(); + } + + public Optional build() { + PartitionChangeRecord record = new PartitionChangeRecord(). + setTopicId(topicId). 
+ setPartitionId(partitionId); + + completeReassignmentIfNeeded(); + + if (shouldTryElection()) { + tryElection(record); + } + + triggerLeaderEpochBumpIfNeeded(record); + + if (!targetIsr.isEmpty() && !targetIsr.equals(Replicas.toList(partition.isr))) { + record.setIsr(targetIsr); + } + if (!targetReplicas.isEmpty() && !targetReplicas.equals(Replicas.toList(partition.replicas))) { + record.setReplicas(targetReplicas); + } + if (!targetRemoving.equals(Replicas.toList(partition.removingReplicas))) { + record.setRemovingReplicas(targetRemoving); + } + if (!targetAdding.equals(Replicas.toList(partition.addingReplicas))) { + record.setAddingReplicas(targetAdding); + } + if (changeRecordIsNoOp(record)) { + return Optional.empty(); + } else { + return Optional.of(new ApiMessageAndVersion(record, + PARTITION_CHANGE_RECORD.highestSupportedVersion())); + } + } + + @Override + public String toString() { + return "PartitionChangeBuilder(" + + "partition=" + partition + + ", topicId=" + topicId + + ", partitionId=" + partitionId + + ", isAcceptableLeader=" + isAcceptableLeader + + ", uncleanElectionOk=" + uncleanElectionOk + + ", targetIsr=" + targetIsr + + ", targetReplicas=" + targetReplicas + + ", targetRemoving=" + targetRemoving + + ", targetAdding=" + targetAdding + + ", alwaysElectPreferredIfPossible=" + alwaysElectPreferredIfPossible + + ')'; + } +} diff --git a/metadata/src/main/java/org/apache/kafka/controller/PartitionReassignmentReplicas.java b/metadata/src/main/java/org/apache/kafka/controller/PartitionReassignmentReplicas.java new file mode 100644 index 0000000..96ae408 --- /dev/null +++ b/metadata/src/main/java/org/apache/kafka/controller/PartitionReassignmentReplicas.java @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
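Illustrative sketch (not part of the patch): a standalone rendition of the containment test behind triggerLeaderEpochBumpIfNeeded() above; the real code delegates to Replicas.contains(). Everything here is hypothetical except the rule it demonstrates.

import java.util.Arrays;
import java.util.HashSet;
import java.util.List;

public class LeaderEpochBumpSketch {
    // A leader epoch bump is needed when the new ISR (or replica list) no longer
    // contains every member of the old one.
    static boolean containsAll(List<Integer> newList, int[] oldList) {
        HashSet<Integer> newSet = new HashSet<>(newList);
        for (int node : oldList) {
            if (!newSet.contains(node)) return false;
        }
        return true;
    }

    public static void main(String[] args) {
        int[] oldIsr = {1, 2, 3};
        List<Integer> shrunkIsr = Arrays.asList(1, 3);          // broker 2 dropped out
        List<Integer> expandedIsr = Arrays.asList(1, 2, 3, 4);  // broker 4 joined

        System.out.println(!containsAll(shrunkIsr, oldIsr));    // true  -> leader epoch bump
        System.out.println(!containsAll(expandedIsr, oldIsr));  // false -> only the partition epoch changes
    }
}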
+ */ + +package org.apache.kafka.controller; + +import java.util.ArrayList; +import java.util.Set; +import java.util.List; +import java.util.Objects; +import java.util.TreeSet; + + +class PartitionReassignmentReplicas { + private final List removing; + private final List adding; + private final List merged; + + private static Set calculateDifference(List a, List b) { + Set result = new TreeSet<>(a); + result.removeAll(b); + return result; + } + + PartitionReassignmentReplicas(List currentReplicas, + List targetReplicas) { + Set removing = calculateDifference(currentReplicas, targetReplicas); + this.removing = new ArrayList<>(removing); + Set adding = calculateDifference(targetReplicas, currentReplicas); + this.adding = new ArrayList<>(adding); + this.merged = new ArrayList<>(targetReplicas); + this.merged.addAll(removing); + } + + List removing() { + return removing; + } + + List adding() { + return adding; + } + + List merged() { + return merged; + } + + @Override + public int hashCode() { + return Objects.hash(removing, adding, merged); + } + + @Override + public boolean equals(Object o) { + if (!(o instanceof PartitionReassignmentReplicas)) return false; + PartitionReassignmentReplicas other = (PartitionReassignmentReplicas) o; + return removing.equals(other.removing) && + adding.equals(other.adding) && + merged.equals(other.merged); + } + + @Override + public String toString() { + return "PartitionReassignmentReplicas(" + + "removing=" + removing + ", " + + "adding=" + adding + ", " + + "merged=" + merged + ")"; + } +} diff --git a/metadata/src/main/java/org/apache/kafka/controller/PartitionReassignmentRevert.java b/metadata/src/main/java/org/apache/kafka/controller/PartitionReassignmentRevert.java new file mode 100644 index 0000000..3cf6dbb --- /dev/null +++ b/metadata/src/main/java/org/apache/kafka/controller/PartitionReassignmentRevert.java @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.controller; + +import org.apache.kafka.common.errors.InvalidReplicaAssignmentException; +import org.apache.kafka.metadata.PartitionRegistration; +import org.apache.kafka.metadata.Replicas; + +import java.util.ArrayList; +import java.util.Set; +import java.util.List; +import java.util.Objects; + + +class PartitionReassignmentRevert { + private final List replicas; + private final List isr; + private final boolean unclean; + + PartitionReassignmentRevert(PartitionRegistration registration) { + // Figure out the replica list and ISR that we will have after reverting the + // reassignment. In general, we want to take out any replica that the reassignment + // was adding, but keep the ones the reassignment was removing. (But see the + // special case below.) 
+ Set adding = Replicas.toSet(registration.addingReplicas); + this.replicas = new ArrayList<>(registration.replicas.length); + this.isr = new ArrayList<>(registration.isr.length); + for (int i = 0; i < registration.isr.length; i++) { + int replica = registration.isr[i]; + if (!adding.contains(replica)) { + this.isr.add(replica); + } + } + for (int replica : registration.replicas) { + if (!adding.contains(replica)) { + this.replicas.add(replica); + } + } + if (isr.isEmpty()) { + // In the special case that all the replicas that are in the ISR are also + // contained in addingReplicas, we choose the first remaining replica and add + // it to the ISR. This is considered an unclean leader election. Therefore, + // calling code must check that unclean leader election is enabled before + // accepting the new ISR. + if (this.replicas.isEmpty()) { + // This should not be reachable, since it would require a partition + // starting with an empty replica set prior to the reassignment we are + // trying to revert. + throw new InvalidReplicaAssignmentException("Invalid replica " + + "assignment: addingReplicas contains all replicas."); + } + isr.add(replicas.get(0)); + this.unclean = true; + } else { + this.unclean = false; + } + } + + List replicas() { + return replicas; + } + + List isr() { + return isr; + } + + boolean unclean() { + return unclean; + } + + @Override + public int hashCode() { + return Objects.hash(replicas, isr); + } + + @Override + public boolean equals(Object o) { + if (!(o instanceof PartitionReassignmentRevert)) return false; + PartitionReassignmentRevert other = (PartitionReassignmentRevert) o; + return replicas.equals(other.replicas) && + isr.equals(other.isr); + } + + @Override + public String toString() { + return "PartitionReassignmentRevert(" + + "replicas=" + replicas + ", " + + "isr=" + isr + ")"; + } +} diff --git a/metadata/src/main/java/org/apache/kafka/controller/ProducerIdControlManager.java b/metadata/src/main/java/org/apache/kafka/controller/ProducerIdControlManager.java new file mode 100644 index 0000000..7291f93 --- /dev/null +++ b/metadata/src/main/java/org/apache/kafka/controller/ProducerIdControlManager.java @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
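Illustrative sketch (not part of the patch): what removing(), adding() and merged() yield for a concrete reassignment. The class name is hypothetical, the sketch assumes the org.apache.kafka.controller package (PartitionReassignmentReplicas is package-private), and the type parameters dropped by this rendering are assumed to be Integer.

package org.apache.kafka.controller;

import java.util.Arrays;

public class PartitionReassignmentReplicasSketch {
    public static void main(String[] args) {
        // Reassign a partition from brokers [1, 2, 3] to brokers [3, 4, 5].
        PartitionReassignmentReplicas replicas = new PartitionReassignmentReplicas(
            Arrays.asList(1, 2, 3), Arrays.asList(3, 4, 5));

        System.out.println(replicas.removing());  // [1, 2]           -- only in the current assignment
        System.out.println(replicas.adding());    // [4, 5]           -- only in the target assignment
        System.out.println(replicas.merged());    // [3, 4, 5, 1, 2]  -- target first, then the replicas being removed
    }
}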
+ */
+
+package org.apache.kafka.controller;
+
+import org.apache.kafka.common.errors.UnknownServerException;
+import org.apache.kafka.common.metadata.ProducerIdsRecord;
+import org.apache.kafka.server.common.ApiMessageAndVersion;
+import org.apache.kafka.server.common.ProducerIdsBlock;
+import org.apache.kafka.timeline.SnapshotRegistry;
+import org.apache.kafka.timeline.TimelineLong;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.Iterator;
+import java.util.List;
+
+
+public class ProducerIdControlManager {
+
+    private final ClusterControlManager clusterControlManager;
+    private final TimelineLong lastProducerId;
+
+    ProducerIdControlManager(ClusterControlManager clusterControlManager, SnapshotRegistry snapshotRegistry) {
+        this.clusterControlManager = clusterControlManager;
+        this.lastProducerId = new TimelineLong(snapshotRegistry);
+    }
+
+    ControllerResult<ProducerIdsBlock> generateNextProducerId(int brokerId, long brokerEpoch) {
+        clusterControlManager.checkBrokerEpoch(brokerId, brokerEpoch);
+
+        long producerId = lastProducerId.get();
+
+        if (producerId > Long.MAX_VALUE - ProducerIdsBlock.PRODUCER_ID_BLOCK_SIZE) {
+            throw new UnknownServerException("Exhausted all producerIds as the next block's end producerId " +
+                "would exceed the limit of the long type");
+        }
+
+        long nextProducerId = producerId + ProducerIdsBlock.PRODUCER_ID_BLOCK_SIZE;
+        ProducerIdsRecord record = new ProducerIdsRecord()
+            .setProducerIdsEnd(nextProducerId)
+            .setBrokerId(brokerId)
+            .setBrokerEpoch(brokerEpoch);
+        ProducerIdsBlock block = new ProducerIdsBlock(brokerId, producerId, ProducerIdsBlock.PRODUCER_ID_BLOCK_SIZE);
+        return ControllerResult.of(Collections.singletonList(new ApiMessageAndVersion(record, (short) 0)), block);
+    }
+
+    void replay(ProducerIdsRecord record) {
+        long currentProducerId = lastProducerId.get();
+        if (record.producerIdsEnd() <= currentProducerId) {
+            throw new RuntimeException("Producer ID from record is not monotonically increasing");
+        } else {
+            lastProducerId.set(record.producerIdsEnd());
+        }
+    }
+
+    Iterator<List<ApiMessageAndVersion>> iterator(long epoch) {
+        List<ApiMessageAndVersion> records = new ArrayList<>(1);
+
+        long producerId = lastProducerId.get(epoch);
+        if (producerId > 0) {
+            records.add(new ApiMessageAndVersion(
+                new ProducerIdsRecord()
+                    .setProducerIdsEnd(producerId)
+                    .setBrokerId(0)
+                    .setBrokerEpoch(0L),
+                (short) 0));
+        }
+        return Collections.singleton(records).iterator();
+    }
+}
diff --git a/metadata/src/main/java/org/apache/kafka/controller/QuorumController.java b/metadata/src/main/java/org/apache/kafka/controller/QuorumController.java
new file mode 100644
index 0000000..16b3ab3
--- /dev/null
+++ b/metadata/src/main/java/org/apache/kafka/controller/QuorumController.java
@@ -0,0 +1,1426 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +package org.apache.kafka.controller; + +import org.apache.kafka.clients.admin.AlterConfigOp.OpType; +import org.apache.kafka.common.Uuid; +import org.apache.kafka.common.config.ConfigDef; +import org.apache.kafka.common.config.ConfigResource; +import org.apache.kafka.common.errors.ApiException; +import org.apache.kafka.common.errors.NotControllerException; +import org.apache.kafka.common.errors.UnknownServerException; +import org.apache.kafka.common.message.AllocateProducerIdsRequestData; +import org.apache.kafka.common.message.AllocateProducerIdsResponseData; +import org.apache.kafka.common.message.AlterIsrRequestData; +import org.apache.kafka.common.message.AlterIsrResponseData; +import org.apache.kafka.common.message.AlterPartitionReassignmentsRequestData; +import org.apache.kafka.common.message.AlterPartitionReassignmentsResponseData; +import org.apache.kafka.common.message.BrokerHeartbeatRequestData; +import org.apache.kafka.common.message.BrokerRegistrationRequestData; +import org.apache.kafka.common.message.CreatePartitionsRequestData.CreatePartitionsTopic; +import org.apache.kafka.common.message.CreatePartitionsResponseData.CreatePartitionsTopicResult; +import org.apache.kafka.common.message.CreateTopicsRequestData; +import org.apache.kafka.common.message.CreateTopicsResponseData; +import org.apache.kafka.common.message.ElectLeadersRequestData; +import org.apache.kafka.common.message.ElectLeadersResponseData; +import org.apache.kafka.common.message.ListPartitionReassignmentsRequestData; +import org.apache.kafka.common.message.ListPartitionReassignmentsResponseData; +import org.apache.kafka.common.metadata.ConfigRecord; +import org.apache.kafka.common.metadata.ClientQuotaRecord; +import org.apache.kafka.common.metadata.FeatureLevelRecord; +import org.apache.kafka.common.metadata.FenceBrokerRecord; +import org.apache.kafka.common.metadata.MetadataRecordType; +import org.apache.kafka.common.metadata.PartitionChangeRecord; +import org.apache.kafka.common.metadata.PartitionRecord; +import org.apache.kafka.common.metadata.ProducerIdsRecord; +import org.apache.kafka.common.metadata.RegisterBrokerRecord; +import org.apache.kafka.common.metadata.RemoveTopicRecord; +import org.apache.kafka.common.metadata.TopicRecord; +import org.apache.kafka.common.metadata.UnfenceBrokerRecord; +import org.apache.kafka.common.metadata.UnregisterBrokerRecord; +import org.apache.kafka.common.protocol.ApiMessage; +import org.apache.kafka.common.quota.ClientQuotaAlteration; +import org.apache.kafka.common.quota.ClientQuotaEntity; +import org.apache.kafka.common.requests.ApiError; +import org.apache.kafka.common.utils.ExponentialBackoff; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.controller.SnapshotGenerator.Section; +import org.apache.kafka.server.common.ApiMessageAndVersion; +import org.apache.kafka.metadata.BrokerHeartbeatReply; +import org.apache.kafka.metadata.BrokerRegistrationReply; +import org.apache.kafka.metadata.FeatureMapAndEpoch; +import org.apache.kafka.metadata.VersionRange; +import org.apache.kafka.queue.EventQueue; +import org.apache.kafka.queue.EventQueue.EarliestDeadlineFunction; +import org.apache.kafka.queue.KafkaEventQueue; +import org.apache.kafka.raft.Batch; +import org.apache.kafka.raft.BatchReader; +import org.apache.kafka.raft.LeaderAndEpoch; +import org.apache.kafka.raft.OffsetAndEpoch; +import org.apache.kafka.raft.RaftClient; +import 
org.apache.kafka.server.policy.AlterConfigPolicy; +import org.apache.kafka.server.policy.CreateTopicPolicy; +import org.apache.kafka.snapshot.SnapshotReader; +import org.apache.kafka.snapshot.SnapshotWriter; +import org.apache.kafka.timeline.SnapshotRegistry; +import org.slf4j.Logger; + +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.List; +import java.util.Map.Entry; +import java.util.Map; +import java.util.Optional; +import java.util.OptionalInt; +import java.util.OptionalLong; +import java.util.Random; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.RejectedExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.function.Supplier; +import java.util.stream.Collectors; + +import static java.util.concurrent.TimeUnit.MICROSECONDS; +import static java.util.concurrent.TimeUnit.MILLISECONDS; +import static java.util.concurrent.TimeUnit.NANOSECONDS; + + +/** + * QuorumController implements the main logic of the KRaft (Kafka Raft Metadata) mode controller. + * + * The node which is the leader of the metadata log becomes the active controller. All + * other nodes remain in standby mode. Standby controllers cannot create new metadata log + * entries. They just replay the metadata log entries that the current active controller + * has created. + * + * The QuorumController is single-threaded. A single event handler thread performs most + * operations. This avoids the need for complex locking. + * + * The controller exposes an asynchronous, futures-based API to the world. This reflects + * the fact that the controller may have several operations in progress at any given + * point. The future associated with each operation will not be completed until the + * results of the operation have been made durable to the metadata log. + */ +public final class QuorumController implements Controller { + /** + * A builder class which creates the QuorumController. 
+ */ + static public class Builder { + private final int nodeId; + private Time time = Time.SYSTEM; + private String threadNamePrefix = null; + private LogContext logContext = null; + private Map configDefs = Collections.emptyMap(); + private RaftClient raftClient = null; + private Map supportedFeatures = Collections.emptyMap(); + private short defaultReplicationFactor = 3; + private int defaultNumPartitions = 1; + private ReplicaPlacer replicaPlacer = new StripedReplicaPlacer(new Random()); + private long snapshotMaxNewRecordBytes = Long.MAX_VALUE; + private long sessionTimeoutNs = NANOSECONDS.convert(18, TimeUnit.SECONDS); + private ControllerMetrics controllerMetrics = null; + private Optional createTopicPolicy = Optional.empty(); + private Optional alterConfigPolicy = Optional.empty(); + private ConfigurationValidator configurationValidator = ConfigurationValidator.NO_OP; + + public Builder(int nodeId) { + this.nodeId = nodeId; + } + + public Builder setTime(Time time) { + this.time = time; + return this; + } + + public Builder setThreadNamePrefix(String threadNamePrefix) { + this.threadNamePrefix = threadNamePrefix; + return this; + } + + public Builder setLogContext(LogContext logContext) { + this.logContext = logContext; + return this; + } + + public Builder setConfigDefs(Map configDefs) { + this.configDefs = configDefs; + return this; + } + + public Builder setRaftClient(RaftClient logManager) { + this.raftClient = logManager; + return this; + } + + public Builder setSupportedFeatures(Map supportedFeatures) { + this.supportedFeatures = supportedFeatures; + return this; + } + + public Builder setDefaultReplicationFactor(short defaultReplicationFactor) { + this.defaultReplicationFactor = defaultReplicationFactor; + return this; + } + + public Builder setDefaultNumPartitions(int defaultNumPartitions) { + this.defaultNumPartitions = defaultNumPartitions; + return this; + } + + public Builder setReplicaPlacer(ReplicaPlacer replicaPlacer) { + this.replicaPlacer = replicaPlacer; + return this; + } + + public Builder setSnapshotMaxNewRecordBytes(long value) { + this.snapshotMaxNewRecordBytes = value; + return this; + } + + public Builder setSessionTimeoutNs(long sessionTimeoutNs) { + this.sessionTimeoutNs = sessionTimeoutNs; + return this; + } + + public Builder setMetrics(ControllerMetrics controllerMetrics) { + this.controllerMetrics = controllerMetrics; + return this; + } + + public Builder setCreateTopicPolicy(Optional createTopicPolicy) { + this.createTopicPolicy = createTopicPolicy; + return this; + } + + public Builder setAlterConfigPolicy(Optional alterConfigPolicy) { + this.alterConfigPolicy = alterConfigPolicy; + return this; + } + + public Builder setConfigurationValidator(ConfigurationValidator configurationValidator) { + this.configurationValidator = configurationValidator; + return this; + } + + @SuppressWarnings("unchecked") + public QuorumController build() throws Exception { + if (raftClient == null) { + throw new RuntimeException("You must set a raft client."); + } + if (threadNamePrefix == null) { + threadNamePrefix = String.format("Node%d_", nodeId); + } + if (logContext == null) { + logContext = new LogContext(String.format("[Controller %d] ", nodeId)); + } + if (controllerMetrics == null) { + controllerMetrics = (ControllerMetrics) Class.forName( + "org.apache.kafka.controller.MockControllerMetrics").getConstructor().newInstance(); + } + KafkaEventQueue queue = null; + try { + queue = new KafkaEventQueue(time, logContext, threadNamePrefix + "QuorumController"); + 
return new QuorumController(logContext, nodeId, queue, time, configDefs, + raftClient, supportedFeatures, defaultReplicationFactor, + defaultNumPartitions, replicaPlacer, snapshotMaxNewRecordBytes, + sessionTimeoutNs, controllerMetrics, createTopicPolicy, + alterConfigPolicy, configurationValidator); + } catch (Exception e) { + Utils.closeQuietly(queue, "event queue"); + throw e; + } + } + } + + public static final String CONTROLLER_THREAD_SUFFIX = "QuorumControllerEventHandler"; + + private static final String ACTIVE_CONTROLLER_EXCEPTION_TEXT_PREFIX = + "The active controller appears to be node "; + + private NotControllerException newNotControllerException() { + OptionalInt latestController = raftClient.leaderAndEpoch().leaderId(); + if (latestController.isPresent()) { + return new NotControllerException(ACTIVE_CONTROLLER_EXCEPTION_TEXT_PREFIX + + latestController.getAsInt()); + } else { + return new NotControllerException("No controller appears to be active."); + } + } + + public static int exceptionToApparentController(NotControllerException e) { + if (e.getMessage().startsWith(ACTIVE_CONTROLLER_EXCEPTION_TEXT_PREFIX)) { + return Integer.parseInt(e.getMessage().substring( + ACTIVE_CONTROLLER_EXCEPTION_TEXT_PREFIX.length())); + } else { + return -1; + } + } + + private void handleEventEnd(String name, long startProcessingTimeNs) { + long endProcessingTime = time.nanoseconds(); + long deltaNs = endProcessingTime - startProcessingTimeNs; + log.debug("Processed {} in {} us", name, + MICROSECONDS.convert(deltaNs, NANOSECONDS)); + controllerMetrics.updateEventQueueProcessingTime(NANOSECONDS.toMillis(deltaNs)); + } + + private Throwable handleEventException(String name, + Optional startProcessingTimeNs, + Throwable exception) { + if (!startProcessingTimeNs.isPresent()) { + log.info("unable to start processing {} because of {}.", name, + exception.getClass().getSimpleName()); + if (exception instanceof ApiException) { + return exception; + } else { + return new UnknownServerException(exception); + } + } + long endProcessingTime = time.nanoseconds(); + long deltaNs = endProcessingTime - startProcessingTimeNs.get(); + long deltaUs = MICROSECONDS.convert(deltaNs, NANOSECONDS); + if (exception instanceof ApiException) { + log.info("{}: failed with {} in {} us", name, + exception.getClass().getSimpleName(), deltaUs); + return exception; + } + log.warn("{}: failed with unknown server exception {} at epoch {} in {} us. " + + "Reverting to last committed offset {}.", + this, exception.getClass().getSimpleName(), curClaimEpoch, deltaUs, + lastCommittedOffset, exception); + raftClient.resign(curClaimEpoch); + renounce(); + return new UnknownServerException(exception); + } + + /** + * A controller event for handling internal state changes, such as Raft inputs. 
+ */ + class ControlEvent implements EventQueue.Event { + private final String name; + private final Runnable handler; + private final long eventCreatedTimeNs = time.nanoseconds(); + private Optional startProcessingTimeNs = Optional.empty(); + + ControlEvent(String name, Runnable handler) { + this.name = name; + this.handler = handler; + } + + @Override + public void run() throws Exception { + long now = time.nanoseconds(); + controllerMetrics.updateEventQueueTime(NANOSECONDS.toMillis(now - eventCreatedTimeNs)); + startProcessingTimeNs = Optional.of(now); + log.debug("Executing {}.", this); + handler.run(); + handleEventEnd(this.toString(), startProcessingTimeNs.get()); + } + + @Override + public void handleException(Throwable exception) { + handleEventException(name, startProcessingTimeNs, exception); + } + + @Override + public String toString() { + return name; + } + } + + private void appendControlEvent(String name, Runnable handler) { + ControlEvent event = new ControlEvent(name, handler); + queue.append(event); + } + + private static final String GENERATE_SNAPSHOT = "generateSnapshot"; + + private static final int MAX_BATCHES_PER_GENERATE_CALL = 10; + + class SnapshotGeneratorManager implements Runnable { + private final ExponentialBackoff exponentialBackoff = new ExponentialBackoff(10, 2, 5000, 0); + private SnapshotGenerator generator = null; + + void createSnapshotGenerator(long committedOffset, int committedEpoch, long committedTimestamp) { + if (generator != null) { + throw new RuntimeException("Snapshot generator already exists."); + } + if (!snapshotRegistry.hasSnapshot(committedOffset)) { + throw new RuntimeException( + String.format( + "Cannot generate a snapshot at committed offset %s because it does not exists in the snapshot registry.", + committedOffset + ) + ); + } + Optional> writer = raftClient.createSnapshot( + committedOffset, + committedEpoch, + committedTimestamp + ); + if (writer.isPresent()) { + generator = new SnapshotGenerator( + logContext, + writer.get(), + MAX_BATCHES_PER_GENERATE_CALL, + exponentialBackoff, + Arrays.asList( + new Section("features", featureControl.iterator(committedOffset)), + new Section("cluster", clusterControl.iterator(committedOffset)), + new Section("replication", replicationControl.iterator(committedOffset)), + new Section("configuration", configurationControl.iterator(committedOffset)), + new Section("clientQuotas", clientQuotaControlManager.iterator(committedOffset)), + new Section("producerIds", producerIdControlManager.iterator(committedOffset)) + ) + ); + reschedule(0); + } else { + log.info( + "Skipping generation of snapshot for committed offset {} and epoch {} since it already exists", + committedOffset, + committedEpoch + ); + } + } + + void cancel() { + if (generator == null) return; + log.error("Cancelling snapshot {}", generator.lastContainedLogOffset()); + generator.writer().close(); + generator = null; + + // Delete every in-memory snapshot up to the committed offset. They are not needed since this + // snapshot generation was canceled. 
+ snapshotRegistry.deleteSnapshotsUpTo(lastCommittedOffset); + + queue.cancelDeferred(GENERATE_SNAPSHOT); + } + + void reschedule(long delayNs) { + ControlEvent event = new ControlEvent(GENERATE_SNAPSHOT, this); + queue.scheduleDeferred(event.name, + new EarliestDeadlineFunction(time.nanoseconds() + delayNs), event); + } + + @Override + public void run() { + if (generator == null) { + log.debug("No snapshot is in progress."); + return; + } + OptionalLong nextDelay; + try { + nextDelay = generator.generateBatches(); + } catch (Exception e) { + log.error("Error while generating snapshot {}", generator.lastContainedLogOffset(), e); + generator.writer().close(); + generator = null; + return; + } + if (!nextDelay.isPresent()) { + log.info("Finished generating snapshot {}.", generator.lastContainedLogOffset()); + generator.writer().close(); + generator = null; + + // Delete every in-memory snapshot up to the committed offset. They are not needed since this + // snapshot generation finished. + snapshotRegistry.deleteSnapshotsUpTo(lastCommittedOffset); + return; + } + reschedule(nextDelay.getAsLong()); + } + + OptionalLong snapshotLastOffsetFromLog() { + if (generator == null) { + return OptionalLong.empty(); + } + return OptionalLong.of(generator.lastContainedLogOffset()); + } + } + + /** + * A controller event that reads the committed internal state in order to expose it + * to an API. + */ + class ControllerReadEvent implements EventQueue.Event { + private final String name; + private final CompletableFuture future; + private final Supplier handler; + private final long eventCreatedTimeNs = time.nanoseconds(); + private Optional startProcessingTimeNs = Optional.empty(); + + ControllerReadEvent(String name, Supplier handler) { + this.name = name; + this.future = new CompletableFuture(); + this.handler = handler; + } + + CompletableFuture future() { + return future; + } + + @Override + public void run() throws Exception { + long now = time.nanoseconds(); + controllerMetrics.updateEventQueueTime(NANOSECONDS.toMillis(now - eventCreatedTimeNs)); + startProcessingTimeNs = Optional.of(now); + T value = handler.get(); + handleEventEnd(this.toString(), startProcessingTimeNs.get()); + future.complete(value); + } + + @Override + public void handleException(Throwable exception) { + future.completeExceptionally( + handleEventException(name, startProcessingTimeNs, exception)); + } + + @Override + public String toString() { + return name + "(" + System.identityHashCode(this) + ")"; + } + } + + // VisibleForTesting + ReplicationControlManager replicationControl() { + return replicationControl; + } + + // VisibleForTesting + CompletableFuture appendReadEvent(String name, Supplier handler) { + ControllerReadEvent event = new ControllerReadEvent(name, handler); + queue.append(event); + return event.future(); + } + + CompletableFuture appendReadEvent(String name, long deadlineNs, Supplier handler) { + ControllerReadEvent event = new ControllerReadEvent(name, handler); + queue.appendWithDeadline(deadlineNs, event); + return event.future(); + } + + interface ControllerWriteOperation { + /** + * Generate the metadata records needed to implement this controller write + * operation. In general, this operation should not modify the "hard state" of + * the controller. That modification will happen later on, when we replay the + * records generated by this function. + * + * There are cases where this function modifies the "soft state" of the + * controller. Mainly, this happens when we process cluster heartbeats. 
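Illustrative sketch (not part of the patch): a standalone analog of the ControllerReadEvent / appendReadEvent pattern above, using a plain single-thread executor in place of the KafkaEventQueue. The names and structure here are hypothetical; only the pattern (run the supplier on the one event-handler thread, surface the result or failure through a future) mirrors the code.

import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.function.Supplier;

public class ReadEventSketch {
    // Single event-handler thread, standing in for the controller's event queue.
    private static final ExecutorService EVENT_THREAD = Executors.newSingleThreadExecutor();

    // Analogous to appendReadEvent: run the supplier on the event thread and expose
    // the result (or failure) through a future, so callers never touch controller state directly.
    static <T> CompletableFuture<T> appendReadEvent(Supplier<T> handler) {
        CompletableFuture<T> future = new CompletableFuture<>();
        EVENT_THREAD.submit(() -> {
            try {
                future.complete(handler.get());
            } catch (Throwable t) {
                future.completeExceptionally(t);
            }
        });
        return future;
    }

    public static void main(String[] args) throws Exception {
        CompletableFuture<Integer> answer = appendReadEvent(() -> 42);
        System.out.println(answer.get());
        EVENT_THREAD.shutdown();
    }
}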
+ * + * This function also generates an RPC result. In general, if the RPC resulted in + * an error, the RPC result will be an error, and the generated record list will + * be empty. This would happen if we tried to create a topic with incorrect + * parameters, for example. Of course, partial errors are possible for batch + * operations. + * + * @return A result containing a list of records, and the RPC result. + */ + ControllerResult generateRecordsAndResult() throws Exception; + + /** + * Once we've passed the records to the Raft layer, we will invoke this function + * with the end offset at which those records were placed. If there were no + * records to write, we'll just pass the last write offset. + */ + default void processBatchEndOffset(long offset) {} + } + + /** + * A controller event that modifies the controller state. + */ + class ControllerWriteEvent implements EventQueue.Event, DeferredEvent { + private final String name; + private final CompletableFuture future; + private final ControllerWriteOperation op; + private final long eventCreatedTimeNs = time.nanoseconds(); + private Optional startProcessingTimeNs = Optional.empty(); + private ControllerResultAndOffset resultAndOffset; + + ControllerWriteEvent(String name, ControllerWriteOperation op) { + this.name = name; + this.future = new CompletableFuture(); + this.op = op; + this.resultAndOffset = null; + } + + CompletableFuture future() { + return future; + } + + @Override + public void run() throws Exception { + long now = time.nanoseconds(); + controllerMetrics.updateEventQueueTime(NANOSECONDS.toMillis(now - eventCreatedTimeNs)); + int controllerEpoch = curClaimEpoch; + if (controllerEpoch == -1) { + throw newNotControllerException(); + } + startProcessingTimeNs = Optional.of(now); + ControllerResult result = op.generateRecordsAndResult(); + if (result.records().isEmpty()) { + op.processBatchEndOffset(writeOffset); + // If the operation did not return any records, then it was actually just + // a read after all, and not a read + write. However, this read was done + // from the latest in-memory state, which might contain uncommitted data. + Optional maybeOffset = purgatory.highestPendingOffset(); + if (!maybeOffset.isPresent()) { + // If the purgatory is empty, there are no pending operations and no + // uncommitted state. We can return immediately. + resultAndOffset = ControllerResultAndOffset.of(-1, result); + log.debug("Completing read-only operation {} immediately because " + + "the purgatory is empty.", this); + complete(null); + return; + } + // If there are operations in the purgatory, we want to wait for the latest + // one to complete before returning our result to the user. + resultAndOffset = ControllerResultAndOffset.of(maybeOffset.get(), result); + log.debug("Read-only operation {} will be completed when the log " + + "reaches offset {}", this, resultAndOffset.offset()); + } else { + // If the operation returned a batch of records, those records need to be + // written before we can return our result to the user. Here, we hand off + // the batch of records to the raft client. They will be written out + // asynchronously. 
+ final long offset; + if (result.isAtomic()) { + offset = raftClient.scheduleAtomicAppend(controllerEpoch, result.records()); + } else { + offset = raftClient.scheduleAppend(controllerEpoch, result.records()); + } + op.processBatchEndOffset(offset); + writeOffset = offset; + resultAndOffset = ControllerResultAndOffset.of(offset, result); + for (ApiMessageAndVersion message : result.records()) { + replay(message.message(), Optional.empty(), offset); + } + snapshotRegistry.getOrCreateSnapshot(offset); + log.debug("Read-write operation {} will be completed when the log " + + "reaches offset {}.", this, resultAndOffset.offset()); + } + purgatory.add(resultAndOffset.offset(), this); + } + + @Override + public void handleException(Throwable exception) { + complete(exception); + } + + @Override + public void complete(Throwable exception) { + if (exception == null) { + handleEventEnd(this.toString(), startProcessingTimeNs.get()); + future.complete(resultAndOffset.response()); + } else { + future.completeExceptionally( + handleEventException(name, startProcessingTimeNs, exception)); + } + } + + @Override + public String toString() { + return name + "(" + System.identityHashCode(this) + ")"; + } + } + + private CompletableFuture appendWriteEvent(String name, + long deadlineNs, + ControllerWriteOperation op) { + ControllerWriteEvent event = new ControllerWriteEvent<>(name, op); + queue.appendWithDeadline(deadlineNs, event); + return event.future(); + } + + private CompletableFuture appendWriteEvent(String name, + ControllerWriteOperation op) { + ControllerWriteEvent event = new ControllerWriteEvent<>(name, op); + queue.append(event); + return event.future(); + } + + class QuorumMetaLogListener implements RaftClient.Listener { + + @Override + public void handleCommit(BatchReader reader) { + appendRaftEvent("handleCommit[baseOffset=" + reader.baseOffset() + "]", () -> { + try { + boolean isActiveController = curClaimEpoch != -1; + long processedRecordsSize = 0; + while (reader.hasNext()) { + Batch batch = reader.next(); + long offset = batch.lastOffset(); + int epoch = batch.epoch(); + List messages = batch.records(); + + if (isActiveController) { + // If the controller is active, the records were already replayed, + // so we don't need to do it here. + log.debug("Completing purgatory items up to offset {} and epoch {}.", offset, epoch); + + // Complete any events in the purgatory that were waiting for this offset. + purgatory.completeUpTo(offset); + + // Delete all the in-memory snapshots that we no longer need. + // If we are writing a new snapshot, then we need to keep that around; + // otherwise, we should delete up to the current committed offset. + snapshotRegistry.deleteSnapshotsUpTo( + snapshotGeneratorManager.snapshotLastOffsetFromLog().orElse(offset)); + + } else { + // If the controller is a standby, replay the records that were + // created by the active controller. 
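+                        // Keeping a standby's in-memory state current with each commit
+                        // means that, if this node later becomes the active controller,
+                        // it can take over directly from the last committed offset
+                        // rather than re-reading the entire metadata log.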
+ if (log.isDebugEnabled()) { + if (log.isTraceEnabled()) { + log.trace("Replaying commits from the active node up to " + + "offset {} and epoch {}: {}.", offset, epoch, messages.stream() + .map(ApiMessageAndVersion::toString) + .collect(Collectors.joining(", "))); + } else { + log.debug("Replaying commits from the active node up to " + + "offset {} and epoch {}.", offset, epoch); + } + } + for (ApiMessageAndVersion messageAndVersion : messages) { + replay(messageAndVersion.message(), Optional.empty(), offset); + } + } + + lastCommittedOffset = offset; + lastCommittedEpoch = epoch; + lastCommittedTimestamp = batch.appendTimestamp(); + processedRecordsSize += batch.sizeInBytes(); + } + + maybeGenerateSnapshot(processedRecordsSize); + } finally { + reader.close(); + } + }); + } + + @Override + public void handleSnapshot(SnapshotReader reader) { + appendRaftEvent(String.format("handleSnapshot[snapshotId=%s]", reader.snapshotId()), () -> { + try { + boolean isActiveController = curClaimEpoch != -1; + if (isActiveController) { + throw new IllegalStateException( + String.format( + "Asked to load snapshot (%s) when it is the active controller (%s)", + reader.snapshotId(), + curClaimEpoch + ) + ); + } + log.info("Starting to replay snapshot ({}), from last commit offset ({}) and epoch ({})", + reader.snapshotId(), lastCommittedOffset, lastCommittedEpoch); + + resetState(); + + while (reader.hasNext()) { + Batch batch = reader.next(); + long offset = batch.lastOffset(); + List messages = batch.records(); + + if (log.isDebugEnabled()) { + if (log.isTraceEnabled()) { + log.trace( + "Replaying snapshot ({}) batch with last offset of {}: {}", + reader.snapshotId(), + offset, + messages + .stream() + .map(ApiMessageAndVersion::toString) + .collect(Collectors.joining(", ")) + ); + } else { + log.debug( + "Replaying snapshot ({}) batch with last offset of {}", + reader.snapshotId(), + offset + ); + } + } + + for (ApiMessageAndVersion messageAndVersion : messages) { + replay(messageAndVersion.message(), Optional.of(reader.snapshotId()), offset); + } + } + + lastCommittedOffset = reader.lastContainedLogOffset(); + lastCommittedEpoch = reader.lastContainedLogEpoch(); + lastCommittedTimestamp = reader.lastContainedLogTimestamp(); + snapshotRegistry.getOrCreateSnapshot(lastCommittedOffset); + } finally { + reader.close(); + } + }); + } + + @Override + public void handleLeaderChange(LeaderAndEpoch newLeader) { + if (newLeader.isLeader(nodeId)) { + final int newEpoch = newLeader.epoch(); + appendRaftEvent("handleLeaderChange[" + newEpoch + "]", () -> { + int curEpoch = curClaimEpoch; + if (curEpoch != -1) { + throw new RuntimeException("Tried to claim controller epoch " + + newEpoch + ", but we never renounced controller epoch " + + curEpoch); + } + log.info( + "Becoming the active controller at epoch {}, committed offset {} and committed epoch {}.", + newEpoch, lastCommittedOffset, lastCommittedEpoch + ); + + curClaimEpoch = newEpoch; + controllerMetrics.setActive(true); + writeOffset = lastCommittedOffset; + clusterControl.activate(); + + // Before switching to active, create an in-memory snapshot at the last committed offset. This is + // required because the active controller assumes that there is always an in-memory snapshot at the + // last committed offset. 
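+                    // (If leadership is lost again, renounce() reverts to this snapshot,
+                    // discarding any in-memory changes that were never committed.)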
+ snapshotRegistry.getOrCreateSnapshot(lastCommittedOffset); + }); + } else if (curClaimEpoch != -1) { + appendRaftEvent("handleRenounce[" + curClaimEpoch + "]", () -> { + log.warn("Renouncing the leadership at oldEpoch {} due to a metadata " + + "log event. Reverting to last committed offset {}.", curClaimEpoch, + lastCommittedOffset); + renounce(); + }); + } + } + + @Override + public void beginShutdown() { + queue.beginShutdown("MetaLogManager.Listener"); + } + + private void appendRaftEvent(String name, Runnable runnable) { + appendControlEvent(name, () -> { + if (this != metaLogListener) { + log.debug("Ignoring {} raft event from an old registration", name); + } else { + runnable.run(); + } + }); + } + } + + private void renounce() { + curClaimEpoch = -1; + controllerMetrics.setActive(false); + purgatory.failAll(newNotControllerException()); + + if (snapshotRegistry.hasSnapshot(lastCommittedOffset)) { + snapshotRegistry.revertToSnapshot(lastCommittedOffset); + } else { + resetState(); + raftClient.unregister(metaLogListener); + metaLogListener = new QuorumMetaLogListener(); + raftClient.register(metaLogListener); + } + + writeOffset = -1; + clusterControl.deactivate(); + cancelMaybeFenceReplicas(); + } + + private void scheduleDeferredWriteEvent(String name, long deadlineNs, + ControllerWriteOperation op) { + ControllerWriteEvent event = new ControllerWriteEvent<>(name, op); + queue.scheduleDeferred(name, new EarliestDeadlineFunction(deadlineNs), event); + event.future.exceptionally(e -> { + if (e instanceof UnknownServerException && e.getCause() != null && + e.getCause() instanceof RejectedExecutionException) { + log.error("Cancelling deferred write event {} because the event queue " + + "is now closed.", name); + return null; + } else if (e instanceof NotControllerException) { + log.debug("Cancelling deferred write event {} because this controller " + + "is no longer active.", name); + return null; + } + log.error("Unexpected exception while executing deferred write event {}. 
" + + "Rescheduling for a minute from now.", name, e); + scheduleDeferredWriteEvent(name, + deadlineNs + NANOSECONDS.convert(1, TimeUnit.MINUTES), op); + return null; + }); + } + + static final String MAYBE_FENCE_REPLICAS = "maybeFenceReplicas"; + + private void rescheduleMaybeFenceStaleBrokers() { + long nextCheckTimeNs = clusterControl.heartbeatManager().nextCheckTimeNs(); + if (nextCheckTimeNs == Long.MAX_VALUE) { + cancelMaybeFenceReplicas(); + return; + } + scheduleDeferredWriteEvent(MAYBE_FENCE_REPLICAS, nextCheckTimeNs, () -> { + ControllerResult result = replicationControl.maybeFenceOneStaleBroker(); + // This following call ensures that if there are multiple brokers that + // are currently stale, then fencing for them is scheduled immediately + rescheduleMaybeFenceStaleBrokers(); + return result; + }); + } + + private void cancelMaybeFenceReplicas() { + queue.cancelDeferred(MAYBE_FENCE_REPLICAS); + } + + @SuppressWarnings("unchecked") + private void replay(ApiMessage message, Optional snapshotId, long offset) { + try { + MetadataRecordType type = MetadataRecordType.fromId(message.apiKey()); + switch (type) { + case REGISTER_BROKER_RECORD: + clusterControl.replay((RegisterBrokerRecord) message); + break; + case UNREGISTER_BROKER_RECORD: + clusterControl.replay((UnregisterBrokerRecord) message); + break; + case TOPIC_RECORD: + replicationControl.replay((TopicRecord) message); + break; + case PARTITION_RECORD: + replicationControl.replay((PartitionRecord) message); + break; + case CONFIG_RECORD: + configurationControl.replay((ConfigRecord) message); + break; + case PARTITION_CHANGE_RECORD: + replicationControl.replay((PartitionChangeRecord) message); + break; + case FENCE_BROKER_RECORD: + clusterControl.replay((FenceBrokerRecord) message); + break; + case UNFENCE_BROKER_RECORD: + clusterControl.replay((UnfenceBrokerRecord) message); + break; + case REMOVE_TOPIC_RECORD: + replicationControl.replay((RemoveTopicRecord) message); + break; + case FEATURE_LEVEL_RECORD: + featureControl.replay((FeatureLevelRecord) message); + break; + case CLIENT_QUOTA_RECORD: + clientQuotaControlManager.replay((ClientQuotaRecord) message); + break; + case PRODUCER_IDS_RECORD: + producerIdControlManager.replay((ProducerIdsRecord) message); + break; + default: + throw new RuntimeException("Unhandled record type " + type); + } + } catch (Exception e) { + if (snapshotId.isPresent()) { + log.error("Error replaying record {} from snapshot {} at last offset {}.", + message.toString(), snapshotId.get(), offset, e); + } else { + log.error("Error replaying record {} at last offset {}.", + message.toString(), offset, e); + } + } + } + + private void maybeGenerateSnapshot(long batchSizeInBytes) { + newBytesSinceLastSnapshot += batchSizeInBytes; + if (newBytesSinceLastSnapshot >= snapshotMaxNewRecordBytes && + snapshotGeneratorManager.generator == null + ) { + boolean isActiveController = curClaimEpoch != -1; + if (!isActiveController) { + // The active controller creates in-memory snapshot every time an uncommitted + // batch gets appended. The in-active controller can be more efficient and only + // create an in-memory snapshot when needed. 
+ snapshotRegistry.getOrCreateSnapshot(lastCommittedOffset); + } + + log.info("Generating a snapshot that includes (epoch={}, offset={}) after {} committed bytes since the last snapshot.", + lastCommittedEpoch, lastCommittedOffset, newBytesSinceLastSnapshot); + + snapshotGeneratorManager.createSnapshotGenerator(lastCommittedOffset, lastCommittedEpoch, lastCommittedTimestamp); + newBytesSinceLastSnapshot = 0; + } + } + + private void resetState() { + snapshotGeneratorManager.cancel(); + snapshotRegistry.reset(); + + newBytesSinceLastSnapshot = 0; + lastCommittedOffset = -1; + lastCommittedEpoch = -1; + lastCommittedTimestamp = -1; + } + + private final LogContext logContext; + + private final Logger log; + + /** + * The ID of this controller node. + */ + private final int nodeId; + + /** + * The single-threaded queue that processes all of our events. + * It also processes timeouts. + */ + private final KafkaEventQueue queue; + + /** + * The Kafka clock object to use. + */ + private final Time time; + + /** + * The controller metrics. + */ + private final ControllerMetrics controllerMetrics; + + /** + * A registry for snapshot data. This must be accessed only by the event queue thread. + */ + private final SnapshotRegistry snapshotRegistry; + + /** + * The purgatory which holds deferred operations which are waiting for the metadata + * log's high water mark to advance. This must be accessed only by the event queue thread. + */ + private final ControllerPurgatory purgatory; + + /** + * An object which stores the controller's dynamic configuration. + * This must be accessed only by the event queue thread. + */ + private final ConfigurationControlManager configurationControl; + + /** + * An object which stores the controller's dynamic client quotas. + * This must be accessed only by the event queue thread. + */ + private final ClientQuotaControlManager clientQuotaControlManager; + + /** + * An object which stores the controller's view of the cluster. + * This must be accessed only by the event queue thread. + */ + private final ClusterControlManager clusterControl; + + /** + * An object which stores the controller's view of the cluster features. + * This must be accessed only by the event queue thread. + */ + private final FeatureControlManager featureControl; + + /** + * An object which stores the controller's view of the latest producer ID + * that has been generated. This must be accessed only by the event queue thread. + */ + private final ProducerIdControlManager producerIdControlManager; + + /** + * An object which stores the controller's view of topics and partitions. + * This must be accessed only by the event queue thread. + */ + private final ReplicationControlManager replicationControl; + + /** + * Manages generating controller snapshots. + */ + private final SnapshotGeneratorManager snapshotGeneratorManager = new SnapshotGeneratorManager(); + + /** + * The interface that we use to mutate the Raft log. + */ + private final RaftClient raftClient; + + /** + * The interface that receives callbacks from the Raft log. These callbacks are + * invoked from the Raft thread(s), not from the controller thread. Control events + * from this callbacks need to compare against this value to verify that the event + * was not from a previous registration. + */ + private QuorumMetaLogListener metaLogListener; + + /** + * If this controller is active, this is the non-negative controller epoch. + * Otherwise, this is -1. 
This variable must be modified only from the controller + * thread, but it can be read from other threads. + */ + private volatile int curClaimEpoch; + + /** + * The last offset we have committed, or -1 if we have not committed any offsets. + */ + private long lastCommittedOffset = -1; + + /** + * The epoch of the last offset we have committed, or -1 if we have not committed any offsets. + */ + private int lastCommittedEpoch = -1; + + /** + * The timestamp in milliseconds of the last batch we have committed, or -1 if we have not commmitted any offset. + */ + private long lastCommittedTimestamp = -1; + + /** + * If we have called scheduleWrite, this is the last offset we got back from it. + */ + private long writeOffset; + + /** + * Maximum number of bytes processed through handling commits before generating a snapshot. + */ + private final long snapshotMaxNewRecordBytes; + + /** + * Number of bytes processed through handling commits since the last snapshot was generated. + */ + private long newBytesSinceLastSnapshot = 0; + + private QuorumController(LogContext logContext, + int nodeId, + KafkaEventQueue queue, + Time time, + Map configDefs, + RaftClient raftClient, + Map supportedFeatures, + short defaultReplicationFactor, + int defaultNumPartitions, + ReplicaPlacer replicaPlacer, + long snapshotMaxNewRecordBytes, + long sessionTimeoutNs, + ControllerMetrics controllerMetrics, + Optional createTopicPolicy, + Optional alterConfigPolicy, + ConfigurationValidator configurationValidator) { + this.logContext = logContext; + this.log = logContext.logger(QuorumController.class); + this.nodeId = nodeId; + this.queue = queue; + this.time = time; + this.controllerMetrics = controllerMetrics; + this.snapshotRegistry = new SnapshotRegistry(logContext); + this.purgatory = new ControllerPurgatory(); + this.configurationControl = new ConfigurationControlManager(logContext, + snapshotRegistry, configDefs, alterConfigPolicy, configurationValidator); + this.clientQuotaControlManager = new ClientQuotaControlManager(snapshotRegistry); + this.clusterControl = new ClusterControlManager(logContext, time, + snapshotRegistry, sessionTimeoutNs, replicaPlacer, controllerMetrics); + this.featureControl = new FeatureControlManager(supportedFeatures, snapshotRegistry); + this.producerIdControlManager = new ProducerIdControlManager(clusterControl, snapshotRegistry); + this.snapshotMaxNewRecordBytes = snapshotMaxNewRecordBytes; + this.replicationControl = new ReplicationControlManager(snapshotRegistry, + logContext, defaultReplicationFactor, defaultNumPartitions, + configurationControl, clusterControl, controllerMetrics, createTopicPolicy); + this.raftClient = raftClient; + this.metaLogListener = new QuorumMetaLogListener(); + this.curClaimEpoch = -1; + this.writeOffset = -1L; + + resetState(); + + this.raftClient.register(metaLogListener); + } + + @Override + public CompletableFuture alterIsr(AlterIsrRequestData request) { + if (request.topics().isEmpty()) { + return CompletableFuture.completedFuture(new AlterIsrResponseData()); + } + return appendWriteEvent("alterIsr", () -> + replicationControl.alterIsr(request)); + } + + @Override + public CompletableFuture + createTopics(CreateTopicsRequestData request) { + if (request.topics().isEmpty()) { + return CompletableFuture.completedFuture(new CreateTopicsResponseData()); + } + return appendWriteEvent("createTopics", + time.nanoseconds() + NANOSECONDS.convert(request.timeoutMs(), MILLISECONDS), + () -> replicationControl.createTopics(request)); + } + + @Override + public 
CompletableFuture unregisterBroker(int brokerId) { + return appendWriteEvent("unregisterBroker", + () -> replicationControl.unregisterBroker(brokerId)); + } + + @Override + public CompletableFuture>> findTopicIds(long deadlineNs, + Collection names) { + if (names.isEmpty()) return CompletableFuture.completedFuture(Collections.emptyMap()); + return appendReadEvent("findTopicIds", deadlineNs, + () -> replicationControl.findTopicIds(lastCommittedOffset, names)); + } + + @Override + public CompletableFuture>> findTopicNames(long deadlineNs, + Collection ids) { + if (ids.isEmpty()) return CompletableFuture.completedFuture(Collections.emptyMap()); + return appendReadEvent("findTopicNames", deadlineNs, + () -> replicationControl.findTopicNames(lastCommittedOffset, ids)); + } + + @Override + public CompletableFuture> deleteTopics(long deadlineNs, + Collection ids) { + if (ids.isEmpty()) return CompletableFuture.completedFuture(Collections.emptyMap()); + return appendWriteEvent("deleteTopics", deadlineNs, + () -> replicationControl.deleteTopics(ids)); + } + + @Override + public CompletableFuture>>> + describeConfigs(Map> resources) { + return appendReadEvent("describeConfigs", () -> + configurationControl.describeConfigs(lastCommittedOffset, resources)); + } + + @Override + public CompletableFuture + electLeaders(ElectLeadersRequestData request) { + // If topicPartitions is null, we will try to trigger a new leader election on + // all partitions (!). But if it's empty, there is nothing to do. + if (request.topicPartitions() != null && request.topicPartitions().isEmpty()) { + return CompletableFuture.completedFuture(new ElectLeadersResponseData()); + } + return appendWriteEvent("electLeaders", + time.nanoseconds() + NANOSECONDS.convert(request.timeoutMs(), MILLISECONDS), + () -> replicationControl.electLeaders(request)); + } + + @Override + public CompletableFuture finalizedFeatures() { + return appendReadEvent("getFinalizedFeatures", + () -> featureControl.finalizedFeatures(lastCommittedOffset)); + } + + @Override + public CompletableFuture> incrementalAlterConfigs( + Map>> configChanges, + boolean validateOnly) { + if (configChanges.isEmpty()) { + return CompletableFuture.completedFuture(Collections.emptyMap()); + } + return appendWriteEvent("incrementalAlterConfigs", () -> { + ControllerResult> result = + configurationControl.incrementalAlterConfigs(configChanges); + if (validateOnly) { + return result.withoutRecords(); + } else { + return result; + } + }); + } + + @Override + public CompletableFuture + alterPartitionReassignments(AlterPartitionReassignmentsRequestData request) { + if (request.topics().isEmpty()) { + return CompletableFuture.completedFuture(new AlterPartitionReassignmentsResponseData()); + } + return appendWriteEvent("alterPartitionReassignments", + time.nanoseconds() + NANOSECONDS.convert(request.timeoutMs(), MILLISECONDS), + () -> replicationControl.alterPartitionReassignments(request)); + } + + @Override + public CompletableFuture + listPartitionReassignments(ListPartitionReassignmentsRequestData request) { + if (request.topics() != null && request.topics().isEmpty()) { + return CompletableFuture.completedFuture( + new ListPartitionReassignmentsResponseData().setErrorMessage(null)); + } + return appendReadEvent("listPartitionReassignments", + time.nanoseconds() + NANOSECONDS.convert(request.timeoutMs(), MILLISECONDS), + () -> replicationControl.listPartitionReassignments(request.topics())); + } + + @Override + public CompletableFuture> legacyAlterConfigs( + Map> newConfigs, 
boolean validateOnly) { + if (newConfigs.isEmpty()) { + return CompletableFuture.completedFuture(Collections.emptyMap()); + } + return appendWriteEvent("legacyAlterConfigs", () -> { + ControllerResult> result = + configurationControl.legacyAlterConfigs(newConfigs); + if (validateOnly) { + return result.withoutRecords(); + } else { + return result; + } + }); + } + + @Override + public CompletableFuture + processBrokerHeartbeat(BrokerHeartbeatRequestData request) { + return appendWriteEvent("processBrokerHeartbeat", + new ControllerWriteOperation() { + private final int brokerId = request.brokerId(); + private boolean inControlledShutdown = false; + + @Override + public ControllerResult generateRecordsAndResult() { + ControllerResult result = replicationControl. + processBrokerHeartbeat(request, lastCommittedOffset); + inControlledShutdown = result.response().inControlledShutdown(); + rescheduleMaybeFenceStaleBrokers(); + return result; + } + + @Override + public void processBatchEndOffset(long offset) { + if (inControlledShutdown) { + clusterControl.heartbeatManager(). + updateControlledShutdownOffset(brokerId, offset); + } + } + }); + } + + @Override + public CompletableFuture + registerBroker(BrokerRegistrationRequestData request) { + return appendWriteEvent("registerBroker", () -> { + ControllerResult result = clusterControl. + registerBroker(request, writeOffset + 1, featureControl. + finalizedFeatures(Long.MAX_VALUE)); + rescheduleMaybeFenceStaleBrokers(); + return result; + }); + } + + @Override + public CompletableFuture> alterClientQuotas( + Collection quotaAlterations, boolean validateOnly) { + if (quotaAlterations.isEmpty()) { + return CompletableFuture.completedFuture(Collections.emptyMap()); + } + return appendWriteEvent("alterClientQuotas", () -> { + ControllerResult> result = + clientQuotaControlManager.alterClientQuotas(quotaAlterations); + if (validateOnly) { + return result.withoutRecords(); + } else { + return result; + } + }); + } + + @Override + public CompletableFuture allocateProducerIds( + AllocateProducerIdsRequestData request) { + return appendWriteEvent("allocateProducerIds", + () -> producerIdControlManager.generateNextProducerId(request.brokerId(), request.brokerEpoch())) + .thenApply(result -> new AllocateProducerIdsResponseData() + .setProducerIdStart(result.producerIdStart()) + .setProducerIdLen(result.producerIdLen())); + } + + @Override + public CompletableFuture> + createPartitions(long deadlineNs, List topics) { + if (topics.isEmpty()) { + return CompletableFuture.completedFuture(Collections.emptyList()); + } + return appendWriteEvent("createPartitions", deadlineNs, + () -> replicationControl.createPartitions(topics)); + } + + @Override + public CompletableFuture beginWritingSnapshot() { + CompletableFuture future = new CompletableFuture<>(); + appendControlEvent("beginWritingSnapshot", () -> { + if (snapshotGeneratorManager.generator == null) { + snapshotGeneratorManager.createSnapshotGenerator( + lastCommittedOffset, + lastCommittedEpoch, + lastCommittedTimestamp + ); + } + future.complete(snapshotGeneratorManager.generator.lastContainedLogOffset()); + }); + return future; + } + + @Override + public CompletableFuture waitForReadyBrokers(int minBrokers) { + final CompletableFuture future = new CompletableFuture<>(); + appendControlEvent("waitForReadyBrokers", () -> { + clusterControl.addReadyBrokersFuture(future, minBrokers); + }); + return future; + } + + @Override + public void beginShutdown() { + queue.beginShutdown("QuorumController#beginShutdown"); + 
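+        // This only initiates shutdown of the event queue; close() below waits for
+        // the queue thread to finish draining.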
} + + public int nodeId() { + return nodeId; + } + + @Override + public int curClaimEpoch() { + return curClaimEpoch; + } + + @Override + public void close() throws InterruptedException { + queue.close(); + controllerMetrics.close(); + } + + // VisibleForTesting + CountDownLatch pause() { + final CountDownLatch latch = new CountDownLatch(1); + appendControlEvent("pause", () -> { + try { + latch.await(); + } catch (InterruptedException e) { + log.info("Interrupted while waiting for unpause.", e); + } + }); + return latch; + } + + // VisibleForTesting + Time time() { + return time; + } +} diff --git a/metadata/src/main/java/org/apache/kafka/controller/QuorumControllerMetrics.java b/metadata/src/main/java/org/apache/kafka/controller/QuorumControllerMetrics.java new file mode 100644 index 0000000..9b3a4dd --- /dev/null +++ b/metadata/src/main/java/org/apache/kafka/controller/QuorumControllerMetrics.java @@ -0,0 +1,218 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.controller; + +import com.yammer.metrics.core.Gauge; +import com.yammer.metrics.core.Histogram; +import com.yammer.metrics.core.MetricName; +import com.yammer.metrics.core.MetricsRegistry; + +import java.util.Arrays; +import java.util.Objects; + +public final class QuorumControllerMetrics implements ControllerMetrics { + private final static MetricName ACTIVE_CONTROLLER_COUNT = getMetricName( + "KafkaController", "ActiveControllerCount"); + private final static MetricName EVENT_QUEUE_TIME_MS = getMetricName( + "ControllerEventManager", "EventQueueTimeMs"); + private final static MetricName EVENT_QUEUE_PROCESSING_TIME_MS = getMetricName( + "ControllerEventManager", "EventQueueProcessingTimeMs"); + private final static MetricName FENCED_BROKER_COUNT = getMetricName( + "KafkaController", "FencedBrokerCount"); + private final static MetricName ACTIVE_BROKER_COUNT = getMetricName( + "KafkaController", "ActiveBrokerCount"); + private final static MetricName GLOBAL_TOPIC_COUNT = getMetricName( + "KafkaController", "GlobalTopicCount"); + private final static MetricName GLOBAL_PARTITION_COUNT = getMetricName( + "KafkaController", "GlobalPartitionCount"); + private final static MetricName OFFLINE_PARTITION_COUNT = getMetricName( + "KafkaController", "OfflinePartitionsCount"); + private final static MetricName PREFERRED_REPLICA_IMBALANCE_COUNT = getMetricName( + "KafkaController", "PreferredReplicaImbalanceCount"); + + private final MetricsRegistry registry; + private volatile boolean active; + private volatile int fencedBrokerCount; + private volatile int activeBrokerCount; + private volatile int globalTopicCount; + private volatile int globalPartitionCount; + private volatile int offlinePartitionCount; + private volatile int preferredReplicaImbalanceCount; + 
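+    // Each gauge below simply reads one of the volatile fields above, so metric
+    // collection never blocks on the controller's single event-queue thread: the
+    // controller updates the fields, and the metrics reporter reads them.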
private final Gauge<Integer> activeControllerCount;
+    private final Gauge<Integer> fencedBrokerCountGauge;
+    private final Gauge<Integer> activeBrokerCountGauge;
+    private final Gauge<Integer> globalPartitionCountGauge;
+    private final Gauge<Integer> globalTopicCountGauge;
+    private final Gauge<Integer> offlinePartitionCountGauge;
+    private final Gauge<Integer> preferredReplicaImbalanceCountGauge;
+    private final Histogram eventQueueTime;
+    private final Histogram eventQueueProcessingTime;
+
+    public QuorumControllerMetrics(MetricsRegistry registry) {
+        this.registry = Objects.requireNonNull(registry);
+        this.active = false;
+        this.fencedBrokerCount = 0;
+        this.activeBrokerCount = 0;
+        this.globalTopicCount = 0;
+        this.globalPartitionCount = 0;
+        this.offlinePartitionCount = 0;
+        this.preferredReplicaImbalanceCount = 0;
+        this.activeControllerCount = registry.newGauge(ACTIVE_CONTROLLER_COUNT, new Gauge<Integer>() {
+            @Override
+            public Integer value() {
+                return active ? 1 : 0;
+            }
+        });
+        this.eventQueueTime = registry.newHistogram(EVENT_QUEUE_TIME_MS, true);
+        this.eventQueueProcessingTime = registry.newHistogram(EVENT_QUEUE_PROCESSING_TIME_MS, true);
+        this.fencedBrokerCountGauge = registry.newGauge(FENCED_BROKER_COUNT, new Gauge<Integer>() {
+            @Override
+            public Integer value() {
+                return fencedBrokerCount;
+            }
+        });
+        this.activeBrokerCountGauge = registry.newGauge(ACTIVE_BROKER_COUNT, new Gauge<Integer>() {
+            @Override
+            public Integer value() {
+                return activeBrokerCount;
+            }
+        });
+        this.globalTopicCountGauge = registry.newGauge(GLOBAL_TOPIC_COUNT, new Gauge<Integer>() {
+            @Override
+            public Integer value() {
+                return globalTopicCount;
+            }
+        });
+        this.globalPartitionCountGauge = registry.newGauge(GLOBAL_PARTITION_COUNT, new Gauge<Integer>() {
+            @Override
+            public Integer value() {
+                return globalPartitionCount;
+            }
+        });
+        this.offlinePartitionCountGauge = registry.newGauge(OFFLINE_PARTITION_COUNT, new Gauge<Integer>() {
+            @Override
+            public Integer value() {
+                return offlinePartitionCount;
+            }
+        });
+        this.preferredReplicaImbalanceCountGauge = registry.newGauge(PREFERRED_REPLICA_IMBALANCE_COUNT, new Gauge<Integer>() {
+            @Override
+            public Integer value() {
+                return preferredReplicaImbalanceCount;
+            }
+        });
+    }
+
+    @Override
+    public void setActive(boolean active) {
+        this.active = active;
+    }
+
+    @Override
+    public boolean active() {
+        return this.active;
+    }
+
+    @Override
+    public void updateEventQueueTime(long durationMs) {
+        eventQueueTime.update(durationMs);
+    }
+
+    @Override
+    public void updateEventQueueProcessingTime(long durationMs) {
+        eventQueueProcessingTime.update(durationMs);
+    }
+
+    @Override
+    public void setFencedBrokerCount(int brokerCount) {
+        this.fencedBrokerCount = brokerCount;
+    }
+
+    @Override
+    public int fencedBrokerCount() {
+        return this.fencedBrokerCount;
+    }
+
+    public void setActiveBrokerCount(int brokerCount) {
+        this.activeBrokerCount = brokerCount;
+    }
+
+    @Override
+    public int activeBrokerCount() {
+        return this.activeBrokerCount;
+    }
+
+    @Override
+    public void setGlobalTopicsCount(int topicCount) {
+        this.globalTopicCount = topicCount;
+    }
+
+    @Override
+    public int globalTopicsCount() {
+        return this.globalTopicCount;
+    }
+
+    @Override
+    public void setGlobalPartitionCount(int partitionCount) {
+        this.globalPartitionCount = partitionCount;
+    }
+
+    @Override
+    public int globalPartitionCount() {
+        return this.globalPartitionCount;
+    }
+
+    @Override
+    public void setOfflinePartitionCount(int offlinePartitions) {
+        this.offlinePartitionCount = offlinePartitions;
+    }
+
+    @Override
+    public int offlinePartitionCount() {
+        return this.offlinePartitionCount;
+    }
+
+    @Override
+    
public void setPreferredReplicaImbalanceCount(int replicaImbalances) { + this.preferredReplicaImbalanceCount = replicaImbalances; + } + + @Override + public int preferredReplicaImbalanceCount() { + return this.preferredReplicaImbalanceCount; + } + + @Override + public void close() { + Arrays.asList( + ACTIVE_CONTROLLER_COUNT, + EVENT_QUEUE_TIME_MS, + EVENT_QUEUE_PROCESSING_TIME_MS, + GLOBAL_TOPIC_COUNT, + GLOBAL_PARTITION_COUNT, + OFFLINE_PARTITION_COUNT, + PREFERRED_REPLICA_IMBALANCE_COUNT).forEach(this.registry::removeMetric); + } + + private static MetricName getMetricName(String type, String name) { + final String group = "kafka.controller"; + final StringBuilder mbeanNameBuilder = new StringBuilder(); + mbeanNameBuilder.append(group).append(":type=").append(type).append(",name=").append(name); + return new MetricName(group, type, name, null, mbeanNameBuilder.toString()); + } +} diff --git a/metadata/src/main/java/org/apache/kafka/controller/ReplicaPlacer.java b/metadata/src/main/java/org/apache/kafka/controller/ReplicaPlacer.java new file mode 100644 index 0000000..9a705f4 --- /dev/null +++ b/metadata/src/main/java/org/apache/kafka/controller/ReplicaPlacer.java @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.controller; + +import java.util.Iterator; +import java.util.List; +import org.apache.kafka.common.annotation.InterfaceStability; +import org.apache.kafka.common.errors.InvalidReplicationFactorException; +import org.apache.kafka.metadata.UsableBroker; + + +/** + * The interface which a Kafka replica placement policy must implement. + */ +@InterfaceStability.Unstable +interface ReplicaPlacer { + /** + * Create a new replica placement. + * + * @param startPartition The partition ID to start with. + * @param numPartitions The number of partitions to create placements for. + * @param numReplicas The number of replicas to create for each partitions. + * Must be positive. + * @param iterator An iterator that yields all the usable brokers. + * + * @return A list of replica lists. + * + * @throws InvalidReplicationFactorException If too many replicas were requested. + */ + List> place(int startPartition, + int numPartitions, + short numReplicas, + Iterator iterator) + throws InvalidReplicationFactorException; +} diff --git a/metadata/src/main/java/org/apache/kafka/controller/ReplicationControlManager.java b/metadata/src/main/java/org/apache/kafka/controller/ReplicationControlManager.java new file mode 100644 index 0000000..5462dea --- /dev/null +++ b/metadata/src/main/java/org/apache/kafka/controller/ReplicationControlManager.java @@ -0,0 +1,1450 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.controller; + +import org.apache.kafka.clients.admin.AlterConfigOp.OpType; +import org.apache.kafka.common.ElectionType; +import org.apache.kafka.common.Uuid; +import org.apache.kafka.common.config.ConfigResource; +import org.apache.kafka.common.errors.ApiException; +import org.apache.kafka.common.errors.BrokerIdNotRegisteredException; +import org.apache.kafka.common.errors.InvalidPartitionsException; +import org.apache.kafka.common.errors.InvalidReplicaAssignmentException; +import org.apache.kafka.common.errors.InvalidReplicationFactorException; +import org.apache.kafka.common.errors.InvalidRequestException; +import org.apache.kafka.common.errors.InvalidTopicException; +import org.apache.kafka.common.errors.NoReassignmentInProgressException; +import org.apache.kafka.common.errors.PolicyViolationException; +import org.apache.kafka.common.errors.UnknownServerException; +import org.apache.kafka.common.errors.UnknownTopicIdException; +import org.apache.kafka.common.errors.UnknownTopicOrPartitionException; +import org.apache.kafka.common.internals.Topic; +import org.apache.kafka.common.message.AlterIsrRequestData; +import org.apache.kafka.common.message.AlterIsrResponseData; +import org.apache.kafka.common.message.AlterPartitionReassignmentsRequestData; +import org.apache.kafka.common.message.AlterPartitionReassignmentsRequestData.ReassignablePartition; +import org.apache.kafka.common.message.AlterPartitionReassignmentsRequestData.ReassignableTopic; +import org.apache.kafka.common.message.AlterPartitionReassignmentsResponseData; +import org.apache.kafka.common.message.AlterPartitionReassignmentsResponseData.ReassignablePartitionResponse; +import org.apache.kafka.common.message.AlterPartitionReassignmentsResponseData.ReassignableTopicResponse; +import org.apache.kafka.common.message.BrokerHeartbeatRequestData; +import org.apache.kafka.common.message.CreatePartitionsRequestData.CreatePartitionsAssignment; +import org.apache.kafka.common.message.CreatePartitionsRequestData.CreatePartitionsTopic; +import org.apache.kafka.common.message.CreatePartitionsResponseData.CreatePartitionsTopicResult; +import org.apache.kafka.common.message.CreateTopicsRequestData; +import org.apache.kafka.common.message.CreateTopicsRequestData.CreatableReplicaAssignment; +import org.apache.kafka.common.message.CreateTopicsRequestData.CreatableTopic; +import org.apache.kafka.common.message.CreateTopicsRequestData.CreatableTopicCollection; +import org.apache.kafka.common.message.CreateTopicsResponseData; +import org.apache.kafka.common.message.CreateTopicsResponseData.CreatableTopicResult; +import org.apache.kafka.common.message.ElectLeadersRequestData; +import org.apache.kafka.common.message.ElectLeadersRequestData.TopicPartitions; +import org.apache.kafka.common.message.ElectLeadersResponseData; +import 
org.apache.kafka.common.message.ElectLeadersResponseData.PartitionResult; +import org.apache.kafka.common.message.ElectLeadersResponseData.ReplicaElectionResult; +import org.apache.kafka.common.message.ListPartitionReassignmentsRequestData.ListPartitionReassignmentsTopics; +import org.apache.kafka.common.message.ListPartitionReassignmentsResponseData; +import org.apache.kafka.common.message.ListPartitionReassignmentsResponseData.OngoingPartitionReassignment; +import org.apache.kafka.common.message.ListPartitionReassignmentsResponseData.OngoingTopicReassignment; +import org.apache.kafka.common.metadata.FenceBrokerRecord; +import org.apache.kafka.common.metadata.PartitionChangeRecord; +import org.apache.kafka.common.metadata.PartitionRecord; +import org.apache.kafka.common.metadata.RemoveTopicRecord; +import org.apache.kafka.common.metadata.TopicRecord; +import org.apache.kafka.common.metadata.UnfenceBrokerRecord; +import org.apache.kafka.common.metadata.UnregisterBrokerRecord; +import org.apache.kafka.common.protocol.Errors; +import org.apache.kafka.common.requests.ApiError; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.controller.BrokersToIsrs.TopicIdPartition; +import org.apache.kafka.server.common.ApiMessageAndVersion; +import org.apache.kafka.metadata.BrokerHeartbeatReply; +import org.apache.kafka.metadata.BrokerRegistration; +import org.apache.kafka.metadata.PartitionRegistration; +import org.apache.kafka.metadata.Replicas; +import org.apache.kafka.server.policy.CreateTopicPolicy; +import org.apache.kafka.timeline.SnapshotRegistry; +import org.apache.kafka.timeline.TimelineHashMap; +import org.apache.kafka.timeline.TimelineInteger; +import org.slf4j.Logger; + +import java.util.AbstractMap.SimpleImmutableEntry; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.Optional; +import java.util.function.Function; +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; +import java.util.ListIterator; +import java.util.Map; +import java.util.Map.Entry; +import java.util.NoSuchElementException; +import java.util.OptionalInt; +import java.util.function.Supplier; +import java.util.stream.Collectors; + +import static org.apache.kafka.clients.admin.AlterConfigOp.OpType.SET; +import static org.apache.kafka.common.config.ConfigResource.Type.TOPIC; +import static org.apache.kafka.common.metadata.MetadataRecordType.FENCE_BROKER_RECORD; +import static org.apache.kafka.common.metadata.MetadataRecordType.PARTITION_RECORD; +import static org.apache.kafka.common.metadata.MetadataRecordType.REMOVE_TOPIC_RECORD; +import static org.apache.kafka.common.metadata.MetadataRecordType.TOPIC_RECORD; +import static org.apache.kafka.common.metadata.MetadataRecordType.UNFENCE_BROKER_RECORD; +import static org.apache.kafka.common.metadata.MetadataRecordType.UNREGISTER_BROKER_RECORD; +import static org.apache.kafka.common.protocol.Errors.FENCED_LEADER_EPOCH; +import static org.apache.kafka.common.protocol.Errors.INVALID_REQUEST; +import static org.apache.kafka.common.protocol.Errors.INVALID_UPDATE_VERSION; +import static org.apache.kafka.common.protocol.Errors.NO_REASSIGNMENT_IN_PROGRESS; +import static org.apache.kafka.common.protocol.Errors.UNKNOWN_TOPIC_ID; +import static org.apache.kafka.common.protocol.Errors.UNKNOWN_TOPIC_OR_PARTITION; +import static org.apache.kafka.metadata.LeaderConstants.NO_LEADER; +import static org.apache.kafka.metadata.LeaderConstants.NO_LEADER_CHANGE; + + +/** + * The 
ReplicationControlManager is the part of the controller which deals with topics + * and partitions. It is responsible for managing the in-sync replica set and leader + * of each partition, as well as administrative tasks like creating or deleting topics. + */ +public class ReplicationControlManager { + + static class TopicControlInfo { + private final String name; + private final Uuid id; + private final TimelineHashMap parts; + + TopicControlInfo(String name, SnapshotRegistry snapshotRegistry, Uuid id) { + this.name = name; + this.id = id; + this.parts = new TimelineHashMap<>(snapshotRegistry, 0); + } + + public String name() { + return name; + } + + public Uuid topicId() { + return id; + } + } + + private final SnapshotRegistry snapshotRegistry; + private final Logger log; + + /** + * The KIP-464 default replication factor that is used if a CreateTopics request does + * not specify one. + */ + private final short defaultReplicationFactor; + + /** + * The KIP-464 default number of partitions that is used if a CreateTopics request does + * not specify a number of partitions. + */ + private final int defaultNumPartitions; + + /** + * A count of the total number of partitions in the cluster. + */ + private final TimelineInteger globalPartitionCount; + + /** + * A count of the number of partitions that do not have their first replica as a leader. + */ + private final TimelineInteger preferredReplicaImbalanceCount; + + /** + * A reference to the controller's configuration control manager. + */ + private final ConfigurationControlManager configurationControl; + + /** + * A reference to the controller's cluster control manager. + */ + private final ClusterControlManager clusterControl; + + /** + * A reference to the controller's metrics registry. + */ + private final ControllerMetrics controllerMetrics; + + /** + * The policy to use to validate that topic assignments are valid, if one is present. + */ + private final Optional createTopicPolicy; + + /** + * Maps topic names to topic UUIDs. + */ + private final TimelineHashMap topicsByName; + + /** + * Maps topic UUIDs to structures containing topic information, including partitions. + */ + private final TimelineHashMap topics; + + /** + * A map of broker IDs to the partitions that the broker is in the ISR for. + */ + private final BrokersToIsrs brokersToIsrs; + + /** + * A map from topic IDs to the partitions in the topic which are reassigning. 
+ */ + private final TimelineHashMap reassigningTopics; + + ReplicationControlManager(SnapshotRegistry snapshotRegistry, + LogContext logContext, + short defaultReplicationFactor, + int defaultNumPartitions, + ConfigurationControlManager configurationControl, + ClusterControlManager clusterControl, + ControllerMetrics controllerMetrics, + Optional createTopicPolicy) { + this.snapshotRegistry = snapshotRegistry; + this.log = logContext.logger(ReplicationControlManager.class); + this.defaultReplicationFactor = defaultReplicationFactor; + this.defaultNumPartitions = defaultNumPartitions; + this.configurationControl = configurationControl; + this.controllerMetrics = controllerMetrics; + this.createTopicPolicy = createTopicPolicy; + this.clusterControl = clusterControl; + this.globalPartitionCount = new TimelineInteger(snapshotRegistry); + this.preferredReplicaImbalanceCount = new TimelineInteger(snapshotRegistry); + this.topicsByName = new TimelineHashMap<>(snapshotRegistry, 0); + this.topics = new TimelineHashMap<>(snapshotRegistry, 0); + this.brokersToIsrs = new BrokersToIsrs(snapshotRegistry); + this.reassigningTopics = new TimelineHashMap<>(snapshotRegistry, 0); + } + + public void replay(TopicRecord record) { + topicsByName.put(record.name(), record.topicId()); + topics.put(record.topicId(), + new TopicControlInfo(record.name(), snapshotRegistry, record.topicId())); + controllerMetrics.setGlobalTopicsCount(topics.size()); + log.info("Created topic {} with topic ID {}.", record.name(), record.topicId()); + } + + public void replay(PartitionRecord record) { + TopicControlInfo topicInfo = topics.get(record.topicId()); + if (topicInfo == null) { + throw new RuntimeException("Tried to create partition " + record.topicId() + + ":" + record.partitionId() + ", but no topic with that ID was found."); + } + PartitionRegistration newPartInfo = new PartitionRegistration(record); + PartitionRegistration prevPartInfo = topicInfo.parts.get(record.partitionId()); + String description = topicInfo.name + "-" + record.partitionId() + + " with topic ID " + record.topicId(); + if (prevPartInfo == null) { + log.info("Created partition {} and {}.", description, newPartInfo); + topicInfo.parts.put(record.partitionId(), newPartInfo); + brokersToIsrs.update(record.topicId(), record.partitionId(), null, + newPartInfo.isr, NO_LEADER, newPartInfo.leader); + globalPartitionCount.increment(); + controllerMetrics.setGlobalPartitionCount(globalPartitionCount.get()); + updateReassigningTopicsIfNeeded(record.topicId(), record.partitionId(), + false, newPartInfo.isReassigning()); + } else if (!newPartInfo.equals(prevPartInfo)) { + newPartInfo.maybeLogPartitionChange(log, description, prevPartInfo); + topicInfo.parts.put(record.partitionId(), newPartInfo); + brokersToIsrs.update(record.topicId(), record.partitionId(), prevPartInfo.isr, + newPartInfo.isr, prevPartInfo.leader, newPartInfo.leader); + updateReassigningTopicsIfNeeded(record.topicId(), record.partitionId(), + prevPartInfo.isReassigning(), newPartInfo.isReassigning()); + } + if (newPartInfo.leader != newPartInfo.preferredReplica()) { + preferredReplicaImbalanceCount.increment(); + } + controllerMetrics.setOfflinePartitionCount(brokersToIsrs.offlinePartitionCount()); + controllerMetrics.setPreferredReplicaImbalanceCount(preferredReplicaImbalanceCount.get()); + } + + private void updateReassigningTopicsIfNeeded(Uuid topicId, int partitionId, + boolean wasReassigning, boolean isReassigning) { + if (!wasReassigning) { + if (isReassigning) { + int[] prevReassigningParts 
= reassigningTopics.getOrDefault(topicId, Replicas.NONE); + reassigningTopics.put(topicId, Replicas.copyWith(prevReassigningParts, partitionId)); + } + } else if (!isReassigning) { + int[] prevReassigningParts = reassigningTopics.getOrDefault(topicId, Replicas.NONE); + int[] newReassigningParts = Replicas.copyWithout(prevReassigningParts, partitionId); + if (newReassigningParts.length == 0) { + reassigningTopics.remove(topicId); + } else { + reassigningTopics.put(topicId, newReassigningParts); + } + } + } + + public void replay(PartitionChangeRecord record) { + TopicControlInfo topicInfo = topics.get(record.topicId()); + if (topicInfo == null) { + throw new RuntimeException("Tried to create partition " + record.topicId() + + ":" + record.partitionId() + ", but no topic with that ID was found."); + } + PartitionRegistration prevPartitionInfo = topicInfo.parts.get(record.partitionId()); + if (prevPartitionInfo == null) { + throw new RuntimeException("Tried to create partition " + record.topicId() + + ":" + record.partitionId() + ", but no partition with that id was found."); + } + PartitionRegistration newPartitionInfo = prevPartitionInfo.merge(record); + updateReassigningTopicsIfNeeded(record.topicId(), record.partitionId(), + prevPartitionInfo.isReassigning(), newPartitionInfo.isReassigning()); + topicInfo.parts.put(record.partitionId(), newPartitionInfo); + brokersToIsrs.update(record.topicId(), record.partitionId(), + prevPartitionInfo.isr, newPartitionInfo.isr, prevPartitionInfo.leader, + newPartitionInfo.leader); + String topicPart = topicInfo.name + "-" + record.partitionId() + " with topic ID " + + record.topicId(); + newPartitionInfo.maybeLogPartitionChange(log, topicPart, prevPartitionInfo); + if (!newPartitionInfo.hasPreferredLeader() && prevPartitionInfo.hasPreferredLeader()) { + preferredReplicaImbalanceCount.increment(); + } + controllerMetrics.setOfflinePartitionCount(brokersToIsrs.offlinePartitionCount()); + controllerMetrics.setPreferredReplicaImbalanceCount(preferredReplicaImbalanceCount.get()); + if (record.removingReplicas() != null || record.addingReplicas() != null) { + log.info("Replayed partition assignment change {} for topic {}", record, topicInfo.name); + } else if (log.isTraceEnabled()) { + log.trace("Replayed partition change {} for topic {}", record, topicInfo.name); + } + } + + public void replay(RemoveTopicRecord record) { + // Remove this topic from the topics map and the topicsByName map. + TopicControlInfo topic = topics.remove(record.topicId()); + if (topic == null) { + throw new UnknownTopicIdException("Can't find topic with ID " + record.topicId() + + " to remove."); + } + topicsByName.remove(topic.name); + reassigningTopics.remove(record.topicId()); + + // Delete the configurations associated with this topic. + configurationControl.deleteTopicConfigs(topic.name); + + // Remove the entries for this topic in brokersToIsrs. 
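+        // While doing so, keep the partition count, the preferred-replica-imbalance
+        // count, and the corresponding controller metrics consistent with the removal.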
+ for (PartitionRegistration partition : topic.parts.values()) { + for (int i = 0; i < partition.isr.length; i++) { + brokersToIsrs.removeTopicEntryForBroker(topic.id, partition.isr[i]); + } + if (partition.leader != partition.preferredReplica()) { + preferredReplicaImbalanceCount.decrement(); + } + globalPartitionCount.decrement(); + } + brokersToIsrs.removeTopicEntryForBroker(topic.id, NO_LEADER); + + controllerMetrics.setGlobalTopicsCount(topics.size()); + controllerMetrics.setGlobalPartitionCount(globalPartitionCount.get()); + controllerMetrics.setOfflinePartitionCount(brokersToIsrs.offlinePartitionCount()); + controllerMetrics.setPreferredReplicaImbalanceCount(preferredReplicaImbalanceCount.get()); + log.info("Removed topic {} with ID {}.", topic.name, record.topicId()); + } + + ControllerResult + createTopics(CreateTopicsRequestData request) { + Map topicErrors = new HashMap<>(); + List records = new ArrayList<>(); + + // Check the topic names. + validateNewTopicNames(topicErrors, request.topics()); + + // Identify topics that already exist and mark them with the appropriate error + request.topics().stream().filter(creatableTopic -> topicsByName.containsKey(creatableTopic.name())) + .forEach(t -> topicErrors.put(t.name(), new ApiError(Errors.TOPIC_ALREADY_EXISTS, + "Topic '" + t.name() + "' already exists."))); + + // Verify that the configurations for the new topics are OK, and figure out what + // ConfigRecords should be created. + Map>> configChanges = + computeConfigChanges(topicErrors, request.topics()); + ControllerResult> configResult = + configurationControl.incrementalAlterConfigs(configChanges); + for (Entry entry : configResult.response().entrySet()) { + if (entry.getValue().isFailure()) { + topicErrors.put(entry.getKey().name(), entry.getValue()); + } + } + records.addAll(configResult.records()); + + // Try to create whatever topics are needed. + Map successes = new HashMap<>(); + for (CreatableTopic topic : request.topics()) { + if (topicErrors.containsKey(topic.name())) continue; + ApiError error; + try { + error = createTopic(topic, records, successes); + } catch (ApiException e) { + error = ApiError.fromThrowable(e); + } + if (error.isFailure()) { + topicErrors.put(topic.name(), error); + } + } + + // Create responses for all topics. + CreateTopicsResponseData data = new CreateTopicsResponseData(); + StringBuilder resultsBuilder = new StringBuilder(); + String resultsPrefix = ""; + for (CreatableTopic topic : request.topics()) { + ApiError error = topicErrors.get(topic.name()); + if (error != null) { + data.topics().add(new CreatableTopicResult(). + setName(topic.name()). + setErrorCode(error.error().code()). + setErrorMessage(error.message())); + resultsBuilder.append(resultsPrefix).append(topic).append(": "). + append(error.error()).append(" (").append(error.message()).append(")"); + resultsPrefix = ", "; + continue; + } + CreatableTopicResult result = successes.get(topic.name()); + data.topics().add(result); + resultsBuilder.append(resultsPrefix).append(topic).append(": "). 
+ append("SUCCESS"); + resultsPrefix = ", "; + } + if (request.validateOnly()) { + log.info("Validate-only CreateTopics result(s): {}", resultsBuilder.toString()); + return ControllerResult.atomicOf(Collections.emptyList(), data); + } else { + log.info("CreateTopics result(s): {}", resultsBuilder.toString()); + return ControllerResult.atomicOf(records, data); + } + } + + private ApiError createTopic(CreatableTopic topic, + List records, + Map successes) { + Map newParts = new HashMap<>(); + if (!topic.assignments().isEmpty()) { + if (topic.replicationFactor() != -1) { + return new ApiError(INVALID_REQUEST, + "A manual partition assignment was specified, but replication " + + "factor was not set to -1."); + } + if (topic.numPartitions() != -1) { + return new ApiError(INVALID_REQUEST, + "A manual partition assignment was specified, but numPartitions " + + "was not set to -1."); + } + OptionalInt replicationFactor = OptionalInt.empty(); + for (CreatableReplicaAssignment assignment : topic.assignments()) { + if (newParts.containsKey(assignment.partitionIndex())) { + return new ApiError(Errors.INVALID_REPLICA_ASSIGNMENT, + "Found multiple manual partition assignments for partition " + + assignment.partitionIndex()); + } + validateManualPartitionAssignment(assignment.brokerIds(), replicationFactor); + replicationFactor = OptionalInt.of(assignment.brokerIds().size()); + List isr = assignment.brokerIds().stream(). + filter(clusterControl::unfenced).collect(Collectors.toList()); + if (isr.isEmpty()) { + return new ApiError(Errors.INVALID_REPLICA_ASSIGNMENT, + "All brokers specified in the manual partition assignment for " + + "partition " + assignment.partitionIndex() + " are fenced."); + } + newParts.put(assignment.partitionIndex(), new PartitionRegistration( + Replicas.toArray(assignment.brokerIds()), Replicas.toArray(isr), + Replicas.NONE, Replicas.NONE, isr.get(0), 0, 0)); + } + ApiError error = maybeCheckCreateTopicPolicy(() -> { + Map> assignments = new HashMap<>(); + newParts.entrySet().forEach(e -> assignments.put(e.getKey(), + Replicas.toList(e.getValue().replicas))); + Map configs = new HashMap<>(); + topic.configs().forEach(config -> configs.put(config.name(), config.value())); + return new CreateTopicPolicy.RequestMetadata( + topic.name(), null, null, assignments, configs); + }); + if (error.isFailure()) return error; + } else if (topic.replicationFactor() < -1 || topic.replicationFactor() == 0) { + return new ApiError(Errors.INVALID_REPLICATION_FACTOR, + "Replication factor must be larger than 0, or -1 to use the default value."); + } else if (topic.numPartitions() < -1 || topic.numPartitions() == 0) { + return new ApiError(Errors.INVALID_PARTITIONS, + "Number of partitions was set to an invalid non-positive value."); + } else { + int numPartitions = topic.numPartitions() == -1 ? + defaultNumPartitions : topic.numPartitions(); + short replicationFactor = topic.replicationFactor() == -1 ? + defaultReplicationFactor : topic.replicationFactor(); + try { + List> replicas = clusterControl. 
+ placeReplicas(0, numPartitions, replicationFactor); + for (int partitionId = 0; partitionId < replicas.size(); partitionId++) { + int[] r = Replicas.toArray(replicas.get(partitionId)); + newParts.put(partitionId, + new PartitionRegistration(r, r, Replicas.NONE, Replicas.NONE, r[0], 0, 0)); + } + } catch (InvalidReplicationFactorException e) { + return new ApiError(Errors.INVALID_REPLICATION_FACTOR, + "Unable to replicate the partition " + replicationFactor + + " time(s): " + e.getMessage()); + } + ApiError error = maybeCheckCreateTopicPolicy(() -> { + Map configs = new HashMap<>(); + topic.configs().forEach(config -> configs.put(config.name(), config.value())); + return new CreateTopicPolicy.RequestMetadata( + topic.name(), numPartitions, replicationFactor, null, configs); + }); + if (error.isFailure()) return error; + } + Uuid topicId = Uuid.randomUuid(); + successes.put(topic.name(), new CreatableTopicResult(). + setName(topic.name()). + setTopicId(topicId). + setErrorCode((short) 0). + setErrorMessage(null). + setNumPartitions(newParts.size()). + setReplicationFactor((short) newParts.get(0).replicas.length)); + records.add(new ApiMessageAndVersion(new TopicRecord(). + setName(topic.name()). + setTopicId(topicId), TOPIC_RECORD.highestSupportedVersion())); + for (Entry partEntry : newParts.entrySet()) { + int partitionIndex = partEntry.getKey(); + PartitionRegistration info = partEntry.getValue(); + records.add(info.toRecord(topicId, partitionIndex)); + } + return ApiError.NONE; + } + + private ApiError maybeCheckCreateTopicPolicy(Supplier supplier) { + if (createTopicPolicy.isPresent()) { + try { + createTopicPolicy.get().validate(supplier.get()); + } catch (PolicyViolationException e) { + return new ApiError(Errors.POLICY_VIOLATION, e.getMessage()); + } + } + return ApiError.NONE; + } + + static void validateNewTopicNames(Map topicErrors, + CreatableTopicCollection topics) { + for (CreatableTopic topic : topics) { + if (topicErrors.containsKey(topic.name())) continue; + try { + Topic.validate(topic.name()); + } catch (InvalidTopicException e) { + topicErrors.put(topic.name(), + new ApiError(Errors.INVALID_TOPIC_EXCEPTION, e.getMessage())); + } + } + } + + static Map>> + computeConfigChanges(Map topicErrors, + CreatableTopicCollection topics) { + Map>> configChanges = new HashMap<>(); + for (CreatableTopic topic : topics) { + if (topicErrors.containsKey(topic.name())) continue; + Map> topicConfigs = new HashMap<>(); + for (CreateTopicsRequestData.CreateableTopicConfig config : topic.configs()) { + topicConfigs.put(config.name(), new SimpleImmutableEntry<>(SET, config.value())); + } + if (!topicConfigs.isEmpty()) { + configChanges.put(new ConfigResource(TOPIC, topic.name()), topicConfigs); + } + } + return configChanges; + } + + Map> findTopicIds(long offset, Collection names) { + Map> results = new HashMap<>(names.size()); + for (String name : names) { + if (name == null) { + results.put(null, new ResultOrError<>(INVALID_REQUEST, "Invalid null topic name.")); + } else { + Uuid id = topicsByName.get(name, offset); + if (id == null) { + results.put(name, new ResultOrError<>( + new ApiError(UNKNOWN_TOPIC_OR_PARTITION))); + } else { + results.put(name, new ResultOrError<>(id)); + } + } + } + return results; + } + + Map> findTopicNames(long offset, Collection ids) { + Map> results = new HashMap<>(ids.size()); + for (Uuid id : ids) { + if (id == null || id.equals(Uuid.ZERO_UUID)) { + results.put(id, new ResultOrError<>(new ApiError(INVALID_REQUEST, + "Attempt to find topic with invalid 
topicId " + id))); + } else { + TopicControlInfo topic = topics.get(id, offset); + if (topic == null) { + results.put(id, new ResultOrError<>(new ApiError(UNKNOWN_TOPIC_ID))); + } else { + results.put(id, new ResultOrError<>(topic.name)); + } + } + } + return results; + } + + ControllerResult> deleteTopics(Collection ids) { + Map results = new HashMap<>(ids.size()); + List records = new ArrayList<>(ids.size()); + for (Uuid id : ids) { + try { + deleteTopic(id, records); + results.put(id, ApiError.NONE); + } catch (ApiException e) { + results.put(id, ApiError.fromThrowable(e)); + } catch (Exception e) { + log.error("Unexpected deleteTopics error for {}", id, e); + results.put(id, ApiError.fromThrowable(e)); + } + } + return ControllerResult.atomicOf(records, results); + } + + void deleteTopic(Uuid id, List records) { + TopicControlInfo topic = topics.get(id); + if (topic == null) { + throw new UnknownTopicIdException(UNKNOWN_TOPIC_ID.message()); + } + records.add(new ApiMessageAndVersion(new RemoveTopicRecord(). + setTopicId(id), REMOVE_TOPIC_RECORD.highestSupportedVersion())); + } + + // VisibleForTesting + PartitionRegistration getPartition(Uuid topicId, int partitionId) { + TopicControlInfo topic = topics.get(topicId); + if (topic == null) { + return null; + } + return topic.parts.get(partitionId); + } + + // VisibleForTesting + TopicControlInfo getTopic(Uuid topicId) { + return topics.get(topicId); + } + + // VisibleForTesting + BrokersToIsrs brokersToIsrs() { + return brokersToIsrs; + } + + ControllerResult alterIsr(AlterIsrRequestData request) { + clusterControl.checkBrokerEpoch(request.brokerId(), request.brokerEpoch()); + AlterIsrResponseData response = new AlterIsrResponseData(); + List records = new ArrayList<>(); + for (AlterIsrRequestData.TopicData topicData : request.topics()) { + AlterIsrResponseData.TopicData responseTopicData = + new AlterIsrResponseData.TopicData().setName(topicData.name()); + response.topics().add(responseTopicData); + Uuid topicId = topicsByName.get(topicData.name()); + if (topicId == null || !topics.containsKey(topicId)) { + for (AlterIsrRequestData.PartitionData partitionData : topicData.partitions()) { + responseTopicData.partitions().add(new AlterIsrResponseData.PartitionData(). + setPartitionIndex(partitionData.partitionIndex()). + setErrorCode(UNKNOWN_TOPIC_OR_PARTITION.code())); + } + log.info("Rejecting alterIsr request for unknown topic ID {}.", topicId); + continue; + } + TopicControlInfo topic = topics.get(topicId); + for (AlterIsrRequestData.PartitionData partitionData : topicData.partitions()) { + int partitionId = partitionData.partitionIndex(); + PartitionRegistration partition = topic.parts.get(partitionId); + if (partition == null) { + responseTopicData.partitions().add(new AlterIsrResponseData.PartitionData(). + setPartitionIndex(partitionId). + setErrorCode(UNKNOWN_TOPIC_OR_PARTITION.code())); + log.info("Rejecting alterIsr request for unknown partition {}-{}.", + topic.name, partitionId); + continue; + } + if (partitionData.leaderEpoch() != partition.leaderEpoch) { + responseTopicData.partitions().add(new AlterIsrResponseData.PartitionData(). + setPartitionIndex(partitionId). 
+ setErrorCode(FENCED_LEADER_EPOCH.code())); + log.debug("Rejecting alterIsr request from node {} for {}-{} because " + + "the current leader epoch is {}, not {}.", request.brokerId(), topic.name, + partitionId, partition.leaderEpoch, partitionData.leaderEpoch()); + continue; + } + if (request.brokerId() != partition.leader) { + responseTopicData.partitions().add(new AlterIsrResponseData.PartitionData(). + setPartitionIndex(partitionId). + setErrorCode(INVALID_REQUEST.code())); + log.info("Rejecting alterIsr request from node {} for {}-{} because " + + "the current leader is {}.", request.brokerId(), topic.name, + partitionId, partition.leader); + continue; + } + if (partitionData.currentIsrVersion() != partition.partitionEpoch) { + responseTopicData.partitions().add(new AlterIsrResponseData.PartitionData(). + setPartitionIndex(partitionId). + setErrorCode(INVALID_UPDATE_VERSION.code())); + log.info("Rejecting alterIsr request from node {} for {}-{} because " + + "the current partition epoch is {}, not {}.", request.brokerId(), + topic.name, partitionId, partition.partitionEpoch, + partitionData.currentIsrVersion()); + continue; + } + int[] newIsr = Replicas.toArray(partitionData.newIsr()); + if (!Replicas.validateIsr(partition.replicas, newIsr)) { + responseTopicData.partitions().add(new AlterIsrResponseData.PartitionData(). + setPartitionIndex(partitionId). + setErrorCode(INVALID_REQUEST.code())); + log.error("Rejecting alterIsr request from node {} for {}-{} because " + + "it specified an invalid ISR {}.", request.brokerId(), + topic.name, partitionId, partitionData.newIsr()); + continue; + } + if (!Replicas.contains(newIsr, partition.leader)) { + // An alterIsr request can't ask for the current leader to be removed. + responseTopicData.partitions().add(new AlterIsrResponseData.PartitionData(). + setPartitionIndex(partitionId). + setErrorCode(INVALID_REQUEST.code())); + log.error("Rejecting alterIsr request from node {} for {}-{} because " + + "it specified an invalid ISR {} that doesn't include itself.", + request.brokerId(), topic.name, partitionId, partitionData.newIsr()); + continue; + } + // At this point, we have decided to perform the ISR change. We use + // PartitionChangeBuilder to find out what its effect will be. + PartitionChangeBuilder builder = new PartitionChangeBuilder(partition, + topic.id, + partitionId, + r -> clusterControl.unfenced(r), + () -> configurationControl.uncleanLeaderElectionEnabledForTopic(topicData.name())); + builder.setTargetIsr(partitionData.newIsr()); + Optional record = builder.build(); + Errors result = Errors.NONE; + if (record.isPresent()) { + records.add(record.get()); + PartitionChangeRecord change = (PartitionChangeRecord) record.get().message(); + partition = partition.merge(change); + if (log.isDebugEnabled()) { + log.debug("Node {} has altered ISR for {}-{} to {}.", + request.brokerId(), topic.name, partitionId, change.isr()); + } + if (change.leader() != request.brokerId() && + change.leader() != NO_LEADER_CHANGE) { + // Normally, an alterIsr request, which is made by the partition + // leader itself, is not allowed to modify the partition leader. + // However, if there is an ongoing partition reassignment and the + // ISR change completes it, then the leader may change as part of + // the changes made during reassignment cleanup. + // + // In this case, we report back FENCED_LEADER_EPOCH to the leader + // which made the alterIsr request. This lets it know that it must + // fetch new metadata before trying again. 
+                        // This return code is unusual because we both return an error
+                        // and generate a new metadata record. We usually only do one or
+                        // the other.
+                        log.info("AlterIsr request from node {} for {}-{} completed " +
+                            "the ongoing partition reassignment and triggered a " +
+                            "leadership change. Returning FENCED_LEADER_EPOCH.",
+                            request.brokerId(), topic.name, partitionId);
+                        responseTopicData.partitions().add(new AlterIsrResponseData.PartitionData().
+                            setPartitionIndex(partitionId).
+                            setErrorCode(FENCED_LEADER_EPOCH.code()));
+                        continue;
+                    } else if (change.removingReplicas() != null ||
+                            change.addingReplicas() != null) {
+                        log.info("AlterIsr request from node {} for {}-{} completed " +
+                            "the ongoing partition reassignment.", request.brokerId(),
+                            topic.name, partitionId);
+                    }
+                }
+                responseTopicData.partitions().add(new AlterIsrResponseData.PartitionData().
+                    setPartitionIndex(partitionId).
+                    setErrorCode(result.code()).
+                    setLeaderId(partition.leader).
+                    setLeaderEpoch(partition.leaderEpoch).
+                    setCurrentIsrVersion(partition.partitionEpoch).
+                    setIsr(Replicas.toList(partition.isr)));
+            }
+        }
+        return ControllerResult.of(records, response);
+    }
+
+    /**
+     * Generate the appropriate records to handle a broker being fenced.
+     *
+     * First, we remove this broker from any non-singleton ISR. Then we generate a
+     * FenceBrokerRecord.
+     *
+     * @param brokerId      The broker id.
+     * @param records       The record list to append to.
+     */
+    void handleBrokerFenced(int brokerId, List<ApiMessageAndVersion> records) {
+        BrokerRegistration brokerRegistration = clusterControl.brokerRegistrations().get(brokerId);
+        if (brokerRegistration == null) {
+            throw new RuntimeException("Can't find broker registration for broker " + brokerId);
+        }
+        generateLeaderAndIsrUpdates("handleBrokerFenced", brokerId, NO_LEADER, records,
+            brokersToIsrs.partitionsWithBrokerInIsr(brokerId));
+        records.add(new ApiMessageAndVersion(new FenceBrokerRecord().
+            setId(brokerId).setEpoch(brokerRegistration.epoch()),
+            FENCE_BROKER_RECORD.highestSupportedVersion()));
+    }
+
+    /**
+     * Generate the appropriate records to handle a broker being unregistered.
+     *
+     * First, we remove this broker from any non-singleton ISR. Then we generate an
+     * UnregisterBrokerRecord.
+     *
+     * @param brokerId      The broker id.
+     * @param brokerEpoch   The broker epoch.
+     * @param records       The record list to append to.
+     */
+    void handleBrokerUnregistered(int brokerId, long brokerEpoch,
+                                  List<ApiMessageAndVersion> records) {
+        generateLeaderAndIsrUpdates("handleBrokerUnregistered", brokerId, NO_LEADER, records,
+            brokersToIsrs.partitionsWithBrokerInIsr(brokerId));
+        records.add(new ApiMessageAndVersion(new UnregisterBrokerRecord().
+            setBrokerId(brokerId).setBrokerEpoch(brokerEpoch),
+            UNREGISTER_BROKER_RECORD.highestSupportedVersion()));
+    }
+
+    /**
+     * Generate the appropriate records to handle a broker becoming unfenced.
+     *
+     * First, we create an UnfenceBrokerRecord. Then, we check if there are any
+     * partitions that don't currently have a leader that should be led by the newly
+     * unfenced broker.
+     *
+     * @param brokerId      The broker id.
+     * @param brokerEpoch   The broker epoch.
+     * @param records       The record list to append to.
+     */
+    void handleBrokerUnfenced(int brokerId, long brokerEpoch, List<ApiMessageAndVersion> records) {
+        records.add(new ApiMessageAndVersion(new UnfenceBrokerRecord().setId(brokerId).
+ setEpoch(brokerEpoch), UNFENCE_BROKER_RECORD.highestSupportedVersion())); + generateLeaderAndIsrUpdates("handleBrokerUnfenced", NO_LEADER, brokerId, records, + brokersToIsrs.partitionsWithNoLeader()); + } + + ControllerResult electLeaders(ElectLeadersRequestData request) { + ElectionType electionType = electionType(request.electionType()); + List records = new ArrayList<>(); + ElectLeadersResponseData response = new ElectLeadersResponseData(); + if (request.topicPartitions() == null) { + // If topicPartitions is null, we try to elect a new leader for every partition. There + // are some obvious issues with this wire protocol. For example, what if we have too + // many partitions to fit the results in a single RPC? This behavior should probably be + // removed from the protocol. For now, however, we have to implement this for + // compatibility with the old controller. + for (Entry topicEntry : topicsByName.entrySet()) { + String topicName = topicEntry.getKey(); + ReplicaElectionResult topicResults = + new ReplicaElectionResult().setTopic(topicName); + response.replicaElectionResults().add(topicResults); + TopicControlInfo topic = topics.get(topicEntry.getValue()); + if (topic != null) { + for (int partitionId : topic.parts.keySet()) { + ApiError error = electLeader(topicName, partitionId, electionType, records); + + // When electing leaders for all partitions, we do not return + // partitions which already have the desired leader. + if (error.error() != Errors.ELECTION_NOT_NEEDED) { + topicResults.partitionResult().add(new PartitionResult(). + setPartitionId(partitionId). + setErrorCode(error.error().code()). + setErrorMessage(error.message())); + } + } + } + } + } else { + for (TopicPartitions topic : request.topicPartitions()) { + ReplicaElectionResult topicResults = + new ReplicaElectionResult().setTopic(topic.topic()); + response.replicaElectionResults().add(topicResults); + for (int partitionId : topic.partitions()) { + ApiError error = electLeader(topic.topic(), partitionId, electionType, records); + topicResults.partitionResult().add(new PartitionResult(). + setPartitionId(partitionId). + setErrorCode(error.error().code()). 
+ setErrorMessage(error.message())); + } + } + } + return ControllerResult.of(records, response); + } + + private static ElectionType electionType(byte electionType) { + try { + return ElectionType.valueOf(electionType); + } catch (IllegalArgumentException e) { + throw new InvalidRequestException("Unknown election type " + (int) electionType); + } + } + + ApiError electLeader(String topic, int partitionId, ElectionType electionType, + List records) { + Uuid topicId = topicsByName.get(topic); + if (topicId == null) { + return new ApiError(UNKNOWN_TOPIC_OR_PARTITION, + "No such topic as " + topic); + } + TopicControlInfo topicInfo = topics.get(topicId); + if (topicInfo == null) { + return new ApiError(UNKNOWN_TOPIC_OR_PARTITION, + "No such topic id as " + topicId); + } + PartitionRegistration partition = topicInfo.parts.get(partitionId); + if (partition == null) { + return new ApiError(UNKNOWN_TOPIC_OR_PARTITION, + "No such partition as " + topic + "-" + partitionId); + } + if ((electionType == ElectionType.PREFERRED && partition.hasPreferredLeader()) + || (electionType == ElectionType.UNCLEAN && partition.hasLeader())) { + return new ApiError(Errors.ELECTION_NOT_NEEDED); + } + + PartitionChangeBuilder builder = new PartitionChangeBuilder(partition, + topicId, + partitionId, + r -> clusterControl.unfenced(r), + () -> electionType == ElectionType.UNCLEAN); + + builder.setAlwaysElectPreferredIfPossible(electionType == ElectionType.PREFERRED); + Optional record = builder.build(); + if (!record.isPresent()) { + if (electionType == ElectionType.PREFERRED) { + return new ApiError(Errors.PREFERRED_LEADER_NOT_AVAILABLE); + } else { + return new ApiError(Errors.ELIGIBLE_LEADERS_NOT_AVAILABLE); + } + } + records.add(record.get()); + return ApiError.NONE; + } + + ControllerResult processBrokerHeartbeat( + BrokerHeartbeatRequestData request, long lastCommittedOffset) { + int brokerId = request.brokerId(); + long brokerEpoch = request.brokerEpoch(); + clusterControl.checkBrokerEpoch(brokerId, brokerEpoch); + BrokerHeartbeatManager heartbeatManager = clusterControl.heartbeatManager(); + BrokerControlStates states = heartbeatManager.calculateNextBrokerState(brokerId, + request, lastCommittedOffset, () -> brokersToIsrs.hasLeaderships(brokerId)); + List records = new ArrayList<>(); + if (states.current() != states.next()) { + switch (states.next()) { + case FENCED: + handleBrokerFenced(brokerId, records); + break; + case UNFENCED: + handleBrokerUnfenced(brokerId, brokerEpoch, records); + break; + case CONTROLLED_SHUTDOWN: + generateLeaderAndIsrUpdates("enterControlledShutdown[" + brokerId + "]", + brokerId, NO_LEADER, records, brokersToIsrs.partitionsWithBrokerInIsr(brokerId)); + break; + case SHUTDOWN_NOW: + handleBrokerFenced(brokerId, records); + break; + } + } + heartbeatManager.touch(brokerId, + states.next().fenced(), + request.currentMetadataOffset()); + boolean isCaughtUp = request.currentMetadataOffset() >= lastCommittedOffset; + BrokerHeartbeatReply reply = new BrokerHeartbeatReply(isCaughtUp, + states.next().fenced(), + states.next().inControlledShutdown(), + states.next().shouldShutDown()); + return ControllerResult.of(records, reply); + } + + public ControllerResult unregisterBroker(int brokerId) { + BrokerRegistration registration = clusterControl.brokerRegistrations().get(brokerId); + if (registration == null) { + throw new BrokerIdNotRegisteredException("Broker ID " + brokerId + + " is not currently registered"); + } + List records = new ArrayList<>(); + handleBrokerUnregistered(brokerId, 
registration.epoch(), records); + return ControllerResult.of(records, null); + } + + ControllerResult maybeFenceOneStaleBroker() { + List records = new ArrayList<>(); + BrokerHeartbeatManager heartbeatManager = clusterControl.heartbeatManager(); + heartbeatManager.findOneStaleBroker().ifPresent(brokerId -> { + // Even though multiple brokers can go stale at a time, we will process + // fencing one at a time so that the effect of fencing each broker is visible + // to the system prior to processing the next one + log.info("Fencing broker {} because its session has timed out.", brokerId); + handleBrokerFenced(brokerId, records); + heartbeatManager.fence(brokerId); + }); + return ControllerResult.of(records, null); + } + + // Visible for testing + Boolean isBrokerUnfenced(int brokerId) { + return clusterControl.unfenced(brokerId); + } + + ControllerResult> + createPartitions(List topics) { + List records = new ArrayList<>(); + List results = new ArrayList<>(); + for (CreatePartitionsTopic topic : topics) { + ApiError apiError = ApiError.NONE; + try { + createPartitions(topic, records); + } catch (ApiException e) { + apiError = ApiError.fromThrowable(e); + } catch (Exception e) { + log.error("Unexpected createPartitions error for {}", topic, e); + apiError = ApiError.fromThrowable(e); + } + results.add(new CreatePartitionsTopicResult(). + setName(topic.name()). + setErrorCode(apiError.error().code()). + setErrorMessage(apiError.message())); + } + return new ControllerResult<>(records, results, true); + } + + void createPartitions(CreatePartitionsTopic topic, + List records) { + Uuid topicId = topicsByName.get(topic.name()); + if (topicId == null) { + throw new UnknownTopicOrPartitionException(); + } + TopicControlInfo topicInfo = topics.get(topicId); + if (topicInfo == null) { + throw new UnknownTopicOrPartitionException(); + } + if (topic.count() == topicInfo.parts.size()) { + throw new InvalidPartitionsException("Topic already has " + + topicInfo.parts.size() + " partition(s)."); + } else if (topic.count() < topicInfo.parts.size()) { + throw new InvalidPartitionsException("The topic " + topic.name() + " currently " + + "has " + topicInfo.parts.size() + " partition(s); " + topic.count() + + " would not be an increase."); + } + int additional = topic.count() - topicInfo.parts.size(); + if (topic.assignments() != null) { + if (topic.assignments().size() != additional) { + throw new InvalidReplicaAssignmentException("Attempted to add " + additional + + " additional partition(s), but only " + topic.assignments().size() + + " assignment(s) were specified."); + } + } + Iterator iterator = topicInfo.parts.values().iterator(); + if (!iterator.hasNext()) { + throw new UnknownServerException("Invalid state: topic " + topic.name() + + " appears to have no partitions."); + } + PartitionRegistration partitionInfo = iterator.next(); + if (partitionInfo.replicas.length > Short.MAX_VALUE) { + throw new UnknownServerException("Invalid replication factor " + + partitionInfo.replicas.length + ": expected a number equal to less than " + + Short.MAX_VALUE); + } + short replicationFactor = (short) partitionInfo.replicas.length; + int startPartitionId = topicInfo.parts.size(); + + List> placements; + List> isrs; + if (topic.assignments() != null) { + placements = new ArrayList<>(); + isrs = new ArrayList<>(); + for (int i = 0; i < topic.assignments().size(); i++) { + CreatePartitionsAssignment assignment = topic.assignments().get(i); + validateManualPartitionAssignment(assignment.brokerIds(), + 
OptionalInt.of(replicationFactor)); + placements.add(assignment.brokerIds()); + List isr = assignment.brokerIds().stream(). + filter(clusterControl::unfenced).collect(Collectors.toList()); + if (isr.isEmpty()) { + throw new InvalidReplicaAssignmentException( + "All brokers specified in the manual partition assignment for " + + "partition " + (startPartitionId + i) + " are fenced."); + } + isrs.add(isr); + } + } else { + placements = clusterControl.placeReplicas(startPartitionId, additional, + replicationFactor); + isrs = placements; + } + int partitionId = startPartitionId; + for (int i = 0; i < placements.size(); i++) { + List placement = placements.get(i); + List isr = isrs.get(i); + records.add(new ApiMessageAndVersion(new PartitionRecord(). + setPartitionId(partitionId). + setTopicId(topicId). + setReplicas(placement). + setIsr(isr). + setRemovingReplicas(Collections.emptyList()). + setAddingReplicas(Collections.emptyList()). + setLeader(isr.get(0)). + setLeaderEpoch(0). + setPartitionEpoch(0), PARTITION_RECORD.highestSupportedVersion())); + partitionId++; + } + } + + void validateManualPartitionAssignment(List assignment, + OptionalInt replicationFactor) { + if (assignment.isEmpty()) { + throw new InvalidReplicaAssignmentException("The manual partition " + + "assignment includes an empty replica list."); + } + List sortedBrokerIds = new ArrayList<>(assignment); + sortedBrokerIds.sort(Integer::compare); + Integer prevBrokerId = null; + for (Integer brokerId : sortedBrokerIds) { + if (!clusterControl.brokerRegistrations().containsKey(brokerId)) { + throw new InvalidReplicaAssignmentException("The manual partition " + + "assignment includes broker " + brokerId + ", but no such broker is " + + "registered."); + } + if (brokerId.equals(prevBrokerId)) { + throw new InvalidReplicaAssignmentException("The manual partition " + + "assignment includes the broker " + prevBrokerId + " more than " + + "once."); + } + prevBrokerId = brokerId; + } + if (replicationFactor.isPresent() && + sortedBrokerIds.size() != replicationFactor.getAsInt()) { + throw new InvalidReplicaAssignmentException("The manual partition " + + "assignment includes a partition with " + sortedBrokerIds.size() + + " replica(s), but this is not consistent with previous " + + "partitions, which have " + replicationFactor.getAsInt() + " replica(s)."); + } + } + + /** + * Iterate over a sequence of partitions and generate ISR changes and/or leader + * changes if necessary. + * + * @param context A human-readable context string used in log4j logging. + * @param brokerToRemove NO_LEADER if no broker is being removed; the ID of the + * broker to remove from the ISR and leadership, otherwise. + * @param brokerToAdd NO_LEADER if no broker is being added; the ID of the + * broker which is now eligible to be a leader, otherwise. + * @param records A list of records which we will append to. + * @param iterator The iterator containing the partitions to examine. + */ + void generateLeaderAndIsrUpdates(String context, + int brokerToRemove, + int brokerToAdd, + List records, + Iterator iterator) { + int oldSize = records.size(); + + // If the caller passed a valid broker ID for brokerToAdd, rather than passing + // NO_LEADER, that node will be considered an acceptable leader even if it is + // currently fenced. This is useful when handling unfencing. The reason is that + // while we're generating the records to handle unfencing, the ClusterControlManager + // still shows the node as fenced. 
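+        //
+        // For example (the broker ID here is made up for illustration): if broker 3 has
+        // just been unfenced, the caller passes brokerToAdd = 3 and brokerToRemove =
+        // NO_LEADER, so a leaderless partition whose ISR contains broker 3 can elect
+        // broker 3 as its new leader right away.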
+ // + // Similarly, if the caller passed a valid broker ID for brokerToRemove, rather + // than passing NO_LEADER, that node will never be considered an acceptable leader. + // This is useful when handling a newly fenced node. We also exclude brokerToRemove + // from the target ISR, but we need to exclude it here too, to handle the case + // where there is an unclean leader election which chooses a leader from outside + // the ISR. + Function isAcceptableLeader = + r -> (r != brokerToRemove) && (r == brokerToAdd || clusterControl.unfenced(r)); + + while (iterator.hasNext()) { + TopicIdPartition topicIdPart = iterator.next(); + TopicControlInfo topic = topics.get(topicIdPart.topicId()); + if (topic == null) { + throw new RuntimeException("Topic ID " + topicIdPart.topicId() + + " existed in isrMembers, but not in the topics map."); + } + PartitionRegistration partition = topic.parts.get(topicIdPart.partitionId()); + if (partition == null) { + throw new RuntimeException("Partition " + topicIdPart + + " existed in isrMembers, but not in the partitions map."); + } + PartitionChangeBuilder builder = new PartitionChangeBuilder(partition, + topicIdPart.topicId(), + topicIdPart.partitionId(), + isAcceptableLeader, + () -> configurationControl.uncleanLeaderElectionEnabledForTopic(topic.name)); + + // Note: if brokerToRemove was passed as NO_LEADER, this is a no-op (the new + // target ISR will be the same as the old one). + builder.setTargetIsr(Replicas.toList( + Replicas.copyWithout(partition.isr, brokerToRemove))); + + builder.build().ifPresent(records::add); + } + if (records.size() != oldSize) { + if (log.isDebugEnabled()) { + StringBuilder bld = new StringBuilder(); + String prefix = ""; + for (ListIterator iter = records.listIterator(oldSize); + iter.hasNext(); ) { + ApiMessageAndVersion apiMessageAndVersion = iter.next(); + PartitionChangeRecord record = (PartitionChangeRecord) apiMessageAndVersion.message(); + bld.append(prefix).append(topics.get(record.topicId()).name).append("-"). + append(record.partitionId()); + prefix = ", "; + } + log.debug("{}: changing partition(s): {}", context, bld.toString()); + } else if (log.isInfoEnabled()) { + log.info("{}: changing {} partition(s)", context, records.size() - oldSize); + } + } + } + + ControllerResult + alterPartitionReassignments(AlterPartitionReassignmentsRequestData request) { + List records = new ArrayList<>(); + AlterPartitionReassignmentsResponseData result = + new AlterPartitionReassignmentsResponseData().setErrorMessage(null); + int successfulAlterations = 0, totalAlterations = 0; + for (ReassignableTopic topic : request.topics()) { + ReassignableTopicResponse topicResponse = new ReassignableTopicResponse(). + setName(topic.name()); + for (ReassignablePartition partition : topic.partitions()) { + ApiError error = ApiError.NONE; + try { + alterPartitionReassignment(topic.name(), partition, records); + successfulAlterations++; + } catch (Throwable e) { + log.info("Unable to alter partition reassignment for " + + topic.name() + ":" + partition.partitionIndex() + " because " + + "of an " + e.getClass().getSimpleName() + " error: " + e.getMessage()); + error = ApiError.fromThrowable(e); + } + totalAlterations++; + topicResponse.partitions().add(new ReassignablePartitionResponse(). + setPartitionIndex(partition.partitionIndex()). + setErrorCode(error.error().code()). 
+ setErrorMessage(error.message())); + } + result.responses().add(topicResponse); + } + log.info("Successfully altered {} out of {} partition reassignment(s).", + successfulAlterations, totalAlterations); + return ControllerResult.atomicOf(records, result); + } + + void alterPartitionReassignment(String topicName, + ReassignablePartition target, + List records) { + Uuid topicId = topicsByName.get(topicName); + if (topicId == null) { + throw new UnknownTopicOrPartitionException("Unable to find a topic " + + "named " + topicName + "."); + } + TopicControlInfo topicInfo = topics.get(topicId); + if (topicInfo == null) { + throw new UnknownTopicOrPartitionException("Unable to find a topic " + + "with ID " + topicId + "."); + } + TopicIdPartition tp = new TopicIdPartition(topicId, target.partitionIndex()); + PartitionRegistration part = topicInfo.parts.get(target.partitionIndex()); + if (part == null) { + throw new UnknownTopicOrPartitionException("Unable to find partition " + + topicName + ":" + target.partitionIndex() + "."); + } + Optional record; + if (target.replicas() == null) { + record = cancelPartitionReassignment(topicName, tp, part); + } else { + record = changePartitionReassignment(tp, part, target); + } + record.ifPresent(records::add); + } + + Optional cancelPartitionReassignment(String topicName, + TopicIdPartition tp, + PartitionRegistration part) { + if (!part.isReassigning()) { + throw new NoReassignmentInProgressException(NO_REASSIGNMENT_IN_PROGRESS.message()); + } + PartitionReassignmentRevert revert = new PartitionReassignmentRevert(part); + if (revert.unclean()) { + if (!configurationControl.uncleanLeaderElectionEnabledForTopic(topicName)) { + throw new InvalidReplicaAssignmentException("Unable to revert partition " + + "assignment for " + topicName + ":" + tp.partitionId() + " because " + + "it would require an unclean leader election."); + } + } + PartitionChangeBuilder builder = new PartitionChangeBuilder(part, + tp.topicId(), + tp.partitionId(), + r -> clusterControl.unfenced(r), + () -> configurationControl.uncleanLeaderElectionEnabledForTopic(topicName)); + builder.setTargetIsr(revert.isr()). + setTargetReplicas(revert.replicas()). + setTargetRemoving(Collections.emptyList()). + setTargetAdding(Collections.emptyList()); + return builder.build(); + } + + /** + * Apply a given partition reassignment. In general a partition reassignment goes + * through several stages: + * + * 1. Issue a PartitionChangeRecord adding all the new replicas to the partition's + * main replica list, and setting removingReplicas and addingReplicas. + * + * 2. Wait for the partition to have an ISR that contains all the new replicas. Or + * if there are no new replicas, wait until we have an ISR that contains at least one + * replica that we are not removing. + * + * 3. Issue a second PartitionChangeRecord removing all removingReplicas from the + * partitions' main replica list, and clearing removingReplicas and addingReplicas. + * + * After stage 3, the reassignment is done. + * + * Under some conditions, steps #1 and #2 can be skipped entirely since the ISR is + * already suitable to progress to stage #3. For example, a partition reassignment + * that merely rearranges existing replicas in the list can bypass step #1 and #2 and + * complete immediately. + * + * @param tp The topic id and partition id. + * @param part The existing partition info. + * @param target The target partition info. 
+ * + * @return The ChangePartitionRecord for the new partition assignment, + * or empty if no change is needed. + */ + Optional changePartitionReassignment(TopicIdPartition tp, + PartitionRegistration part, + ReassignablePartition target) { + // Check that the requested partition assignment is valid. + validateManualPartitionAssignment(target.replicas(), OptionalInt.empty()); + + List currentReplicas = Replicas.toList(part.replicas); + PartitionReassignmentReplicas reassignment = + new PartitionReassignmentReplicas(currentReplicas, target.replicas()); + PartitionChangeBuilder builder = new PartitionChangeBuilder(part, + tp.topicId(), + tp.partitionId(), + r -> clusterControl.unfenced(r), + () -> false); + if (!reassignment.merged().equals(currentReplicas)) { + builder.setTargetReplicas(reassignment.merged()); + } + if (!reassignment.removing().isEmpty()) { + builder.setTargetRemoving(reassignment.removing()); + } + if (!reassignment.adding().isEmpty()) { + builder.setTargetAdding(reassignment.adding()); + } + return builder.build(); + } + + ListPartitionReassignmentsResponseData listPartitionReassignments( + List topicList) { + ListPartitionReassignmentsResponseData response = + new ListPartitionReassignmentsResponseData().setErrorMessage(null); + if (topicList == null) { + // List all reassigning topics. + for (Entry entry : reassigningTopics.entrySet()) { + listReassigningTopic(response, entry.getKey(), Replicas.toList(entry.getValue())); + } + } else { + // List the given topics. + for (ListPartitionReassignmentsTopics topic : topicList) { + Uuid topicId = topicsByName.get(topic.name()); + if (topicId != null) { + listReassigningTopic(response, topicId, topic.partitionIndexes()); + } + } + } + return response; + } + + private void listReassigningTopic(ListPartitionReassignmentsResponseData response, + Uuid topicId, + List partitionIds) { + TopicControlInfo topicInfo = topics.get(topicId); + if (topicInfo == null) return; + OngoingTopicReassignment ongoingTopic = new OngoingTopicReassignment(). + setName(topicInfo.name); + for (int partitionId : partitionIds) { + Optional ongoing = + getOngoingPartitionReassignment(topicInfo, partitionId); + if (ongoing.isPresent()) { + ongoingTopic.partitions().add(ongoing.get()); + } + } + if (!ongoingTopic.partitions().isEmpty()) { + response.topics().add(ongoingTopic); + } + } + + private Optional + getOngoingPartitionReassignment(TopicControlInfo topicInfo, int partitionId) { + PartitionRegistration partition = topicInfo.parts.get(partitionId); + if (partition == null || !partition.isReassigning()) { + return Optional.empty(); + } + return Optional.of(new OngoingPartitionReassignment(). + setAddingReplicas(Replicas.toList(partition.addingReplicas)). + setRemovingReplicas(Replicas.toList(partition.removingReplicas)). + setPartitionIndex(partitionId). + setReplicas(Replicas.toList(partition.replicas))); + } + + class ReplicationControlIterator implements Iterator> { + private final long epoch; + private final Iterator iterator; + + ReplicationControlIterator(long epoch) { + this.epoch = epoch; + this.iterator = topics.values(epoch).iterator(); + } + + @Override + public boolean hasNext() { + return iterator.hasNext(); + } + + @Override + public List next() { + if (!hasNext()) throw new NoSuchElementException(); + TopicControlInfo topic = iterator.next(); + List records = new ArrayList<>(); + records.add(new ApiMessageAndVersion(new TopicRecord(). + setName(topic.name). 
+ setTopicId(topic.id), TOPIC_RECORD.highestSupportedVersion())); + for (Entry entry : topic.parts.entrySet(epoch)) { + records.add(entry.getValue().toRecord(topic.id, entry.getKey())); + } + return records; + } + } + + ReplicationControlIterator iterator(long epoch) { + return new ReplicationControlIterator(epoch); + } +} diff --git a/metadata/src/main/java/org/apache/kafka/controller/ResultOrError.java b/metadata/src/main/java/org/apache/kafka/controller/ResultOrError.java new file mode 100644 index 0000000..6a548c4 --- /dev/null +++ b/metadata/src/main/java/org/apache/kafka/controller/ResultOrError.java @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.controller; + +import org.apache.kafka.common.protocol.Errors; +import org.apache.kafka.common.requests.ApiError; + +import java.util.Objects; + + +public class ResultOrError { + private final ApiError error; + private final T result; + + public ResultOrError(Errors error, String message) { + this(new ApiError(error, message)); + } + + public ResultOrError(ApiError error) { + Objects.requireNonNull(error); + this.error = error; + this.result = null; + } + + public ResultOrError(T result) { + this.error = null; + this.result = result; + } + + public static ResultOrError of(T result) { + return new ResultOrError<>(result); + } + + public static ResultOrError of(ApiError error) { + return new ResultOrError<>(error); + } + + public boolean isError() { + return error != null; + } + + public boolean isResult() { + return error == null; + } + + public ApiError error() { + return error; + } + + public T result() { + return result; + } + + @Override + public boolean equals(Object o) { + if (o == null || (!o.getClass().equals(getClass()))) { + return false; + } + ResultOrError other = (ResultOrError) o; + return Objects.equals(error, other.error) && + Objects.equals(result, other.result); + } + + @Override + public int hashCode() { + return Objects.hash(error, result); + } + + @Override + public String toString() { + if (error == null) { + return "ResultOrError(" + result + ")"; + } else { + return "ResultOrError(" + error + ")"; + } + } +} diff --git a/metadata/src/main/java/org/apache/kafka/controller/SnapshotGenerator.java b/metadata/src/main/java/org/apache/kafka/controller/SnapshotGenerator.java new file mode 100644 index 0000000..df4bc61 --- /dev/null +++ b/metadata/src/main/java/org/apache/kafka/controller/SnapshotGenerator.java @@ -0,0 +1,136 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.controller; + +import java.util.Collections; +import java.util.Iterator; +import java.util.List; +import java.util.OptionalLong; + +import org.apache.kafka.common.utils.ExponentialBackoff; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.server.common.ApiMessageAndVersion; +import org.apache.kafka.snapshot.SnapshotWriter; +import org.slf4j.Logger; + + +final class SnapshotGenerator { + static class Section { + private final String name; + private final Iterator> iterator; + + Section(String name, Iterator> iterator) { + this.name = name; + this.iterator = iterator; + } + + String name() { + return name; + } + + Iterator> iterator() { + return iterator; + } + } + + private final Logger log; + private final SnapshotWriter writer; + private final int maxBatchesPerGenerateCall; + private final ExponentialBackoff exponentialBackoff; + private final List
            <Section> sections;
+    private final Iterator<Section> sectionIterator;
+    private Iterator<List<ApiMessageAndVersion>> batchIterator;
+    private List<ApiMessageAndVersion> batch;
+    private Section section;
+    private long numRecords;
+
+    SnapshotGenerator(LogContext logContext,
+                      SnapshotWriter<ApiMessageAndVersion> writer,
+                      int maxBatchesPerGenerateCall,
+                      ExponentialBackoff exponentialBackoff,
+                      List<Section>
                sections) { + this.log = logContext.logger(SnapshotGenerator.class); + this.writer = writer; + this.maxBatchesPerGenerateCall = maxBatchesPerGenerateCall; + this.exponentialBackoff = exponentialBackoff; + this.sections = sections; + this.sectionIterator = this.sections.iterator(); + this.batchIterator = Collections.emptyIterator(); + this.batch = null; + this.section = null; + this.numRecords = 0; + } + + /** + * Returns the last offset from the log that will be included in the snapshot. + */ + long lastContainedLogOffset() { + return writer.lastContainedLogOffset(); + } + + SnapshotWriter writer() { + return writer; + } + + /** + * Generate and write the next batch of records. + * + * @return true if the last batch was generated, otherwise false + */ + private boolean generateBatch() throws Exception { + if (batch == null) { + while (!batchIterator.hasNext()) { + if (section != null) { + log.info("Generated {} record(s) for the {} section of snapshot {}.", + numRecords, section.name(), writer.snapshotId()); + section = null; + numRecords = 0; + } + if (!sectionIterator.hasNext()) { + writer.freeze(); + return true; + } + section = sectionIterator.next(); + log.info("Generating records for the {} section of snapshot {}.", + section.name(), writer.snapshotId()); + batchIterator = section.iterator(); + } + batch = batchIterator.next(); + } + + writer.append(batch); + numRecords += batch.size(); + batch = null; + return false; + } + + /** + * Generate the next few batches of records. + * + * @return The number of nanoseconds to delay before rescheduling the + * generateBatches event, or empty if the snapshot is done. + */ + OptionalLong generateBatches() throws Exception { + for (int numBatches = 0; numBatches < maxBatchesPerGenerateCall; numBatches++) { + if (generateBatch()) { + return OptionalLong.empty(); + } + } + return OptionalLong.of(0); + } +} diff --git a/metadata/src/main/java/org/apache/kafka/controller/StripedReplicaPlacer.java b/metadata/src/main/java/org/apache/kafka/controller/StripedReplicaPlacer.java new file mode 100644 index 0000000..031354c --- /dev/null +++ b/metadata/src/main/java/org/apache/kafka/controller/StripedReplicaPlacer.java @@ -0,0 +1,445 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.controller; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Random; + +import org.apache.kafka.common.errors.InvalidReplicationFactorException; +import org.apache.kafka.metadata.OptionalStringComparator; +import org.apache.kafka.metadata.UsableBroker; + + +/** + * The striped replica placer. 
+ *
+ *
+ * GOALS
+ * The design of this placer attempts to satisfy a few competing goals. Firstly, we want
+ * to spread the replicas as evenly as we can across racks. In the simple case where
+ * broker racks have not been configured, this goal is a no-op, of course. But it is the
+ * highest priority goal in multi-rack clusters.
+ *
+ * Our second goal is to spread the replicas evenly across brokers. Since we are placing
+ * multiple partitions, we try to avoid putting each partition on the same set of
+ * replicas, even if it does satisfy the rack placement goal. If any specific broker is
+ * fenced, we would like the new leaders to be distributed evenly across the remaining
+ * brokers.
+ *
+ * However, we treat the rack placement goal as higher priority than this goal -- if you
+ * configure 10 brokers in rack A and B, and 1 broker in rack C, you will end up with a
+ * lot of partitions on that one broker in rack C. If you were to place a lot of
+ * partitions with replication factor 3, each partition would try to get a replica there.
+ * In general, racks are supposed to be about the same size -- if they aren't, this is a
+ * user error.
+ *
+ * Finally, we would prefer to place replicas on unfenced brokers, rather than on fenced
+ * brokers.
+ *
+ *
+ * CONSTRAINTS
+ * In addition to these goals, we have two constraints. Unlike the goals, these are not
+ * optional -- they are mandatory. Placement will fail if a constraint cannot be
+ * satisfied. The first constraint is that we can't place more than one replica on the
+ * same broker. This imposes an upper limit on replication factor -- for example, a 3-node
+ * cluster can't have any topics with replication factor 4. This constraint comes from
+ * Kafka's internal design.
+ *
+ * The second constraint is that the leader of each partition must be an unfenced broker.
+ * This constraint is a bit arbitrary. In theory, we could allow people to create
+ * new topics even if every broker were fenced. However, this would be confusing for
+ * users.
+ *
+ *
+ * ALGORITHM
+ * The StripedReplicaPlacer constructor loads the broker data into rack objects. Each
+ * rack object contains a sorted list of fenced brokers, and a separate sorted list of
+ * unfenced brokers. The racks themselves are organized into a sorted list, stored inside
+ * the top-level RackList object.
+ *
+ * The general idea is that we place replicas onto racks in a round-robin fashion. So if
+ * we had racks A, B, C, and D, and we were creating a new partition with replication
+ * factor 3, our first replica might come from A, our second from B, and our third from C.
+ * Of course our placement would not be very fair if we always started with rack A.
+ * Therefore, we generate a random starting offset when the RackList is created. So one
+ * time we might go B, C, D. Another time we might go C, D, A. And so forth.
+ *
+ * Note that each partition we generate advances the starting offset by one.
+ * So in our 4-rack cluster, with 3 partitions, we might choose these racks:
+ *
+ * partition 1: A, B, C
+ * partition 2: B, C, A
+ * partition 3: C, A, B
+ *
+ * This is what generates the characteristic "striped" pattern of this placer.
+ *
+ * So far I haven't said anything about how we choose a replica from within a rack. In
+ * fact, this is also done in a round-robin fashion. So if rack A had replicas A0, A1, A2,
+ * and A3, we might return A0 the first time, A1 the second, A2 the third, and so on.
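+ * (Within a single rack, the brokers are likewise handed out in sorted order, starting
+ * from a per-rack offset that advances by one for each partition placed.)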
+ * Just like with the racks, we add a random starting offset to mix things up a bit. + * + * So let's say you had a cluster with racks A, B, and C, and each rack had 3 replicas, + * for 9 nodes in total. + * If all the offsets were 0, you'd get placements like this: + * + * partition 1: A0, B0, C0 + * partition 2: B1, C1, A1 + * partition 3: C2, A2, B2 + * + * One additional complication with choosing a replica within a rack is that we want to + * choose the unfenced replicas first. In a big cluster with lots of nodes available, + * we'd prefer not to place a new partition on a node that is fenced. Therefore, we + * actually maintain two lists, rather than the single list I described above. + * We only start using the fenced node list when the unfenced node list is totally + * exhausted. + * + * Furthermore, we cannot place the first replica (the leader) of a new partition on a + * fenced replica. Therefore, we have some special logic to ensure that this doesn't + * happen. + */ +public class StripedReplicaPlacer implements ReplicaPlacer { + /** + * A list of brokers that we can iterate through. + */ + static class BrokerList { + final static BrokerList EMPTY = new BrokerList(); + private final List brokers = new ArrayList<>(0); + + /** + * How many brokers we have retrieved from the list during the current iteration epoch. + */ + private int index = 0; + + /** + * The offset to add to the index in order to calculate the list entry to fetch. The + * addition is done modulo the list size. + */ + private int offset = 0; + + /** + * The last known iteration epoch. If we call next with a different epoch than this, the + * index and offset will be reset. + */ + private int epoch = 0; + + BrokerList add(int broker) { + this.brokers.add(broker); + return this; + } + + /** + * Initialize this broker list by sorting it and randomizing the start offset. + * + * @param random The random number generator. + */ + void initialize(Random random) { + if (!brokers.isEmpty()) { + brokers.sort(Integer::compareTo); + this.offset = random.nextInt(brokers.size()); + } + } + + /** + * Randomly shuffle the brokers in this list. + */ + void shuffle(Random random) { + Collections.shuffle(brokers, random); + } + + /** + * @return The number of brokers in this list. + */ + int size() { + return brokers.size(); + } + + /** + * Get the next broker in this list, or -1 if there are no more elements to be + * returned. + * + * @param epoch The current iteration epoch. + * + * @return The broker ID, or -1 if there are no more brokers to be + * returned in this epoch. + */ + int next(int epoch) { + if (brokers.size() == 0) return -1; + if (this.epoch != epoch) { + this.epoch = epoch; + this.index = 0; + this.offset = (offset + 1) % brokers.size(); + } + if (index >= brokers.size()) return -1; + int broker = brokers.get((index + offset) % brokers.size()); + index++; + return broker; + } + } + + /** + * A rack in the cluster, which contains brokers. + */ + static class Rack { + private final BrokerList fenced = new BrokerList(); + private final BrokerList unfenced = new BrokerList(); + + /** + * Initialize this rack. + * + * @param random The random number generator. 
+ */ + void initialize(Random random) { + fenced.initialize(random); + unfenced.initialize(random); + } + + void shuffle(Random random) { + fenced.shuffle(random); + unfenced.shuffle(random); + } + + BrokerList fenced() { + return fenced; + } + + BrokerList unfenced() { + return unfenced; + } + + /** + * Get the next unfenced broker in this rack, or -1 if there are no more brokers + * to be returned. + * + * @param epoch The current iteration epoch. + * + * @return The broker ID, or -1 if there are no more brokers to be + * returned in this epoch. + */ + int nextUnfenced(int epoch) { + return unfenced.next(epoch); + } + + /** + * Get the next broker in this rack, or -1 if there are no more brokers to be + * returned. + * + * @param epoch The current iteration epoch. + * + * @return The broker ID, or -1 if there are no more brokers to be + * returned in this epoch. + */ + int next(int epoch) { + int result = unfenced.next(epoch); + if (result >= 0) return result; + return fenced.next(epoch); + } + } + + /** + * A list of racks that we can iterate through. + */ + static class RackList { + /** + * The random number generator. + */ + private final Random random; + + /** + * A map from rack names to the brokers contained within them. + */ + private final Map, Rack> racks = new HashMap<>(); + + /** + * The names of all the racks in the cluster. + * + * Racks which have at least one unfenced broker come first (in sorted order), + * followed by racks which have only fenced brokers (also in sorted order). + */ + private final List> rackNames = new ArrayList<>(); + + /** + * The total number of brokers in the cluster, both fenced and unfenced. + */ + private final int numTotalBrokers; + + /** + * The total number of unfenced brokers in the cluster. + */ + private final int numUnfencedBrokers; + + /** + * The iteration epoch. + */ + private int epoch = 0; + + /** + * The offset we use to determine which rack is returned first. + */ + private int offset; + + RackList(Random random, Iterator iterator) { + this.random = random; + int numTotalBrokersCount = 0, numUnfencedBrokersCount = 0; + while (iterator.hasNext()) { + UsableBroker broker = iterator.next(); + Rack rack = racks.get(broker.rack()); + if (rack == null) { + rackNames.add(broker.rack()); + rack = new Rack(); + racks.put(broker.rack(), rack); + } + if (broker.fenced()) { + rack.fenced().add(broker.id()); + } else { + numUnfencedBrokersCount++; + rack.unfenced().add(broker.id()); + } + numTotalBrokersCount++; + } + for (Rack rack : racks.values()) { + rack.initialize(random); + } + this.rackNames.sort(OptionalStringComparator.INSTANCE); + this.numTotalBrokers = numTotalBrokersCount; + this.numUnfencedBrokers = numUnfencedBrokersCount; + this.offset = rackNames.isEmpty() ? 0 : random.nextInt(rackNames.size()); + } + + int numTotalBrokers() { + return numTotalBrokers; + } + + int numUnfencedBrokers() { + return numUnfencedBrokers; + } + + // VisibleForTesting + List> rackNames() { + return rackNames; + } + + List place(int replicationFactor) { + throwInvalidReplicationFactorIfNonPositive(replicationFactor); + throwInvalidReplicationFactorIfTooFewBrokers(replicationFactor, numTotalBrokers()); + throwInvalidReplicationFactorIfZero(numUnfencedBrokers()); + // If we have returned as many assignments as there are unfenced brokers in + // the cluster, shuffle the rack list and broker lists to try to avoid + // repeating the same assignments again. 
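+            // (Here, epoch counts how many partitions have been placed since the last shuffle.)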
+ // But don't reset the iteration epoch for a single unfenced broker -- otherwise we would loop forever + if (epoch == numUnfencedBrokers && numUnfencedBrokers > 1) { + shuffle(); + epoch = 0; + } + if (offset == rackNames.size()) { + offset = 0; + } + List brokers = new ArrayList<>(replicationFactor); + int firstRackIndex = offset; + while (true) { + Optional name = rackNames.get(firstRackIndex); + Rack rack = racks.get(name); + int result = rack.nextUnfenced(epoch); + if (result >= 0) { + brokers.add(result); + break; + } + firstRackIndex++; + if (firstRackIndex == rackNames.size()) { + firstRackIndex = 0; + } + } + int rackIndex = offset; + for (int replica = 1; replica < replicationFactor; replica++) { + int result = -1; + do { + if (rackIndex == firstRackIndex) { + firstRackIndex = -1; + } else { + Optional rackName = rackNames.get(rackIndex); + Rack rack = racks.get(rackName); + result = rack.next(epoch); + } + rackIndex++; + if (rackIndex == rackNames.size()) { + rackIndex = 0; + } + } while (result < 0); + brokers.add(result); + } + epoch++; + offset++; + return brokers; + } + + void shuffle() { + Collections.shuffle(rackNames, random); + for (Rack rack : racks.values()) { + rack.shuffle(random); + } + } + } + + private static void throwInvalidReplicationFactorIfNonPositive(int replicationFactor) { + if (replicationFactor <= 0) { + throw new InvalidReplicationFactorException("Invalid replication factor " + + replicationFactor + ": the replication factor must be positive."); + } + } + + private static void throwInvalidReplicationFactorIfZero(int numUnfenced) { + if (numUnfenced == 0) { + throw new InvalidReplicationFactorException("All brokers are currently fenced."); + } + } + + private static void throwInvalidReplicationFactorIfTooFewBrokers(int replicationFactor, int numTotalBrokers) { + if (replicationFactor > numTotalBrokers) { + throw new InvalidReplicationFactorException("The target replication factor " + + "of " + replicationFactor + " cannot be reached because only " + + numTotalBrokers + " broker(s) are registered."); + } + } + + private final Random random; + + public StripedReplicaPlacer(Random random) { + this.random = random; + } + + @Override + public List> place(int startPartition, + int numPartitions, + short replicationFactor, + Iterator iterator) { + RackList rackList = new RackList(random, iterator); + throwInvalidReplicationFactorIfNonPositive(replicationFactor); + throwInvalidReplicationFactorIfZero(rackList.numUnfencedBrokers()); + throwInvalidReplicationFactorIfTooFewBrokers(replicationFactor, rackList.numTotalBrokers()); + List> placements = new ArrayList<>(numPartitions); + for (int partition = 0; partition < numPartitions; partition++) { + placements.add(rackList.place(replicationFactor)); + } + return placements; + } +} diff --git a/metadata/src/main/java/org/apache/kafka/image/ClientQuotaDelta.java b/metadata/src/main/java/org/apache/kafka/image/ClientQuotaDelta.java new file mode 100644 index 0000000..2def59a --- /dev/null +++ b/metadata/src/main/java/org/apache/kafka/image/ClientQuotaDelta.java @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.image; + +import org.apache.kafka.common.metadata.ClientQuotaRecord; + +import java.util.HashMap; +import java.util.Map; +import java.util.OptionalDouble; + +import static java.util.Map.Entry; + + +public final class ClientQuotaDelta { + private final ClientQuotaImage image; + private final Map changes = new HashMap<>(); + + public ClientQuotaDelta(ClientQuotaImage image) { + this.image = image; + } + + public Map changes() { + return changes; + } + + public void finishSnapshot() { + for (String key : image.quotas().keySet()) { + if (!changes.containsKey(key)) { + // If a quota from the image did not appear in the snapshot, mark it as removed. + changes.put(key, OptionalDouble.empty()); + } + } + } + + public void replay(ClientQuotaRecord record) { + if (record.remove()) { + changes.put(record.key(), OptionalDouble.empty()); + } else { + changes.put(record.key(), OptionalDouble.of(record.value())); + } + } + + public ClientQuotaImage apply() { + Map newQuotas = new HashMap<>(image.quotas().size()); + for (Entry entry : image.quotas().entrySet()) { + OptionalDouble change = changes.get(entry.getKey()); + if (change == null) { + newQuotas.put(entry.getKey(), entry.getValue()); + } else if (change.isPresent()) { + newQuotas.put(entry.getKey(), change.getAsDouble()); + } + } + for (Entry entry : changes.entrySet()) { + if (!newQuotas.containsKey(entry.getKey())) { + if (entry.getValue().isPresent()) { + newQuotas.put(entry.getKey(), entry.getValue().getAsDouble()); + } + } + } + return new ClientQuotaImage(newQuotas); + } +} diff --git a/metadata/src/main/java/org/apache/kafka/image/ClientQuotaImage.java b/metadata/src/main/java/org/apache/kafka/image/ClientQuotaImage.java new file mode 100644 index 0000000..d5a47a7 --- /dev/null +++ b/metadata/src/main/java/org/apache/kafka/image/ClientQuotaImage.java @@ -0,0 +1,118 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.image; + +import org.apache.kafka.common.message.DescribeClientQuotasResponseData.ValueData; +import org.apache.kafka.common.metadata.ClientQuotaRecord; +import org.apache.kafka.common.metadata.ClientQuotaRecord.EntityData; +import org.apache.kafka.common.quota.ClientQuotaEntity; +import org.apache.kafka.server.common.ApiMessageAndVersion; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map.Entry; +import java.util.Map; +import java.util.Objects; +import java.util.function.Consumer; +import java.util.stream.Collectors; + +import static org.apache.kafka.common.metadata.MetadataRecordType.CLIENT_QUOTA_RECORD; + + +/** + * Represents a quota for a client entity in the metadata image. + * + * This class is thread-safe. + */ +public final class ClientQuotaImage { + public final static ClientQuotaImage EMPTY = new ClientQuotaImage(Collections.emptyMap()); + + private final Map quotas; + + public ClientQuotaImage(Map quotas) { + this.quotas = quotas; + } + + Map quotas() { + return quotas; + } + + public void write(ClientQuotaEntity entity, Consumer> out) { + List records = new ArrayList<>(quotas.size()); + for (Entry entry : quotas.entrySet()) { + records.add(new ApiMessageAndVersion(new ClientQuotaRecord(). + setEntity(entityToData(entity)). + setKey(entry.getKey()). + setValue(entry.getValue()). + setRemove(false), + CLIENT_QUOTA_RECORD.highestSupportedVersion())); + } + out.accept(records); + } + + public static List entityToData(ClientQuotaEntity entity) { + List entityData = new ArrayList<>(entity.entries().size()); + for (Entry entry : entity.entries().entrySet()) { + entityData.add(new EntityData(). + setEntityType(entry.getKey()). + setEntityName(entry.getValue())); + } + return entityData; + } + + public static ClientQuotaEntity dataToEntity(List entityData) { + Map entries = new HashMap<>(); + for (EntityData data : entityData) { + entries.put(data.entityType(), data.entityName()); + } + return new ClientQuotaEntity(Collections.unmodifiableMap(entries)); + } + + public List toDescribeValues() { + List values = new ArrayList<>(quotas.size()); + for (Entry entry : quotas.entrySet()) { + values.add(new ValueData().setKey(entry.getKey()).setValue(entry.getValue())); + } + return values; + } + + public boolean isEmpty() { + return quotas.isEmpty(); + } + + @Override + public boolean equals(Object o) { + if (!(o instanceof ClientQuotaImage)) return false; + ClientQuotaImage other = (ClientQuotaImage) o; + return quotas.equals(other.quotas); + } + + @Override + public int hashCode() { + return Objects.hash(quotas); + } + + @Override + public String toString() { + return "ClientQuotaImage(quotas=" + quotas.entrySet().stream(). + map(e -> e.getKey() + ":" + e.getValue()).collect(Collectors.joining(", ")) + + ")"; + } +} diff --git a/metadata/src/main/java/org/apache/kafka/image/ClientQuotasDelta.java b/metadata/src/main/java/org/apache/kafka/image/ClientQuotasDelta.java new file mode 100644 index 0000000..4b574b3 --- /dev/null +++ b/metadata/src/main/java/org/apache/kafka/image/ClientQuotasDelta.java @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.image; + +import org.apache.kafka.common.metadata.ClientQuotaRecord; +import org.apache.kafka.common.quota.ClientQuotaEntity; + +import java.util.HashMap; +import java.util.Map; +import java.util.Map.Entry; + + +public final class ClientQuotasDelta { + private final ClientQuotasImage image; + private final Map changes = new HashMap<>(); + + public ClientQuotasDelta(ClientQuotasImage image) { + this.image = image; + } + + public Map changes() { + return changes; + } + + public void finishSnapshot() { + for (Entry entry : image.entities().entrySet()) { + ClientQuotaEntity entity = entry.getKey(); + ClientQuotaImage quotaImage = entry.getValue(); + ClientQuotaDelta quotaDelta = changes.computeIfAbsent(entity, + __ -> new ClientQuotaDelta(quotaImage)); + quotaDelta.finishSnapshot(); + } + } + + public void replay(ClientQuotaRecord record) { + ClientQuotaEntity entity = ClientQuotaImage.dataToEntity(record.entity()); + ClientQuotaDelta change = changes.computeIfAbsent(entity, __ -> + new ClientQuotaDelta(image.entities(). + getOrDefault(entity, ClientQuotaImage.EMPTY))); + change.replay(record); + } + + public ClientQuotasImage apply() { + Map newEntities = + new HashMap<>(image.entities().size()); + for (Entry entry : image.entities().entrySet()) { + ClientQuotaEntity entity = entry.getKey(); + ClientQuotaDelta change = changes.get(entity); + if (change == null) { + newEntities.put(entity, entry.getValue()); + } else { + ClientQuotaImage quotaImage = change.apply(); + if (!quotaImage.isEmpty()) { + newEntities.put(entity, quotaImage); + } + } + } + for (Entry entry : changes.entrySet()) { + ClientQuotaEntity entity = entry.getKey(); + if (!newEntities.containsKey(entity)) { + ClientQuotaImage quotaImage = entry.getValue().apply(); + if (!quotaImage.isEmpty()) { + newEntities.put(entity, quotaImage); + } + } + } + return new ClientQuotasImage(newEntities); + } + + @Override + public String toString() { + return "ClientQuotasDelta(" + + "changes=" + changes + + ')'; + } +} diff --git a/metadata/src/main/java/org/apache/kafka/image/ClientQuotasImage.java b/metadata/src/main/java/org/apache/kafka/image/ClientQuotasImage.java new file mode 100644 index 0000000..98b9f0e --- /dev/null +++ b/metadata/src/main/java/org/apache/kafka/image/ClientQuotasImage.java @@ -0,0 +1,195 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.image; + +import org.apache.kafka.common.errors.InvalidRequestException; +import org.apache.kafka.common.errors.UnsupportedVersionException; +import org.apache.kafka.common.message.DescribeClientQuotasRequestData; +import org.apache.kafka.common.message.DescribeClientQuotasResponseData; +import org.apache.kafka.common.message.DescribeClientQuotasResponseData.EntityData; +import org.apache.kafka.common.message.DescribeClientQuotasResponseData.EntryData; +import org.apache.kafka.common.quota.ClientQuotaEntity; +import org.apache.kafka.server.common.ApiMessageAndVersion; + +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map.Entry; +import java.util.Map; +import java.util.Objects; +import java.util.Set; +import java.util.function.Consumer; +import java.util.stream.Collectors; + +import static org.apache.kafka.common.quota.ClientQuotaEntity.CLIENT_ID; +import static org.apache.kafka.common.quota.ClientQuotaEntity.IP; +import static org.apache.kafka.common.quota.ClientQuotaEntity.USER; +import static org.apache.kafka.common.requests.DescribeClientQuotasRequest.MATCH_TYPE_EXACT; +import static org.apache.kafka.common.requests.DescribeClientQuotasRequest.MATCH_TYPE_DEFAULT; +import static org.apache.kafka.common.requests.DescribeClientQuotasRequest.MATCH_TYPE_SPECIFIED; + + +/** + * Represents the client quotas in the metadata image. + * + * This class is thread-safe. 
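+ *
+ * The describe() method scans every entity and applies the request filter: a
+ * MATCH_TYPE_EXACT component requires the named entity type to be present with the
+ * given name, a MATCH_TYPE_DEFAULT component requires it to be present with a null
+ * (default) name, and a MATCH_TYPE_SPECIFIED component only requires that entity type
+ * to be present. IP components may not be combined with user or clientId components,
+ * and a strict request additionally rejects entities that have components beyond
+ * those named in the filter.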
+ */ +public final class ClientQuotasImage { + public final static ClientQuotasImage EMPTY = new ClientQuotasImage(Collections.emptyMap()); + + private final Map entities; + + public ClientQuotasImage(Map entities) { + this.entities = Collections.unmodifiableMap(entities); + } + + public boolean isEmpty() { + return entities.isEmpty(); + } + + Map entities() { + return entities; + } + + public void write(Consumer> out) { + for (Entry entry : entities.entrySet()) { + ClientQuotaEntity entity = entry.getKey(); + ClientQuotaImage clientQuotaImage = entry.getValue(); + clientQuotaImage.write(entity, out); + } + } + + public DescribeClientQuotasResponseData describe(DescribeClientQuotasRequestData request) { + DescribeClientQuotasResponseData response = new DescribeClientQuotasResponseData(); + Map exactMatch = new HashMap<>(); + Set typeMatch = new HashSet<>(); + for (DescribeClientQuotasRequestData.ComponentData component : request.components()) { + if (component.entityType().isEmpty()) { + throw new InvalidRequestException("Invalid empty entity type."); + } else if (exactMatch.containsKey(component.entityType()) || + typeMatch.contains(component.entityType())) { + throw new InvalidRequestException("Entity type " + component.entityType() + + " cannot appear more than once in the filter."); + } + if (!(component.entityType().equals(IP) || component.entityType().equals(USER) || + component.entityType().equals(CLIENT_ID))) { + throw new UnsupportedVersionException("Unsupported entity type " + + component.entityType()); + } + switch (component.matchType()) { + case MATCH_TYPE_EXACT: + if (component.match() == null) { + throw new InvalidRequestException("Request specified " + + "MATCH_TYPE_EXACT, but set match string to null."); + } + exactMatch.put(component.entityType(), component.match()); + break; + case MATCH_TYPE_DEFAULT: + if (component.match() != null) { + throw new InvalidRequestException("Request specified " + + "MATCH_TYPE_DEFAULT, but also specified a match string."); + } + exactMatch.put(component.entityType(), null); + break; + case MATCH_TYPE_SPECIFIED: + if (component.match() != null) { + throw new InvalidRequestException("Request specified " + + "MATCH_TYPE_SPECIFIED, but also specified a match string."); + } + typeMatch.add(component.entityType()); + break; + default: + throw new InvalidRequestException("Unknown match type " + component.matchType()); + } + } + if (exactMatch.containsKey(IP) || typeMatch.contains(IP)) { + if ((exactMatch.containsKey(USER) || typeMatch.contains(USER)) || + (exactMatch.containsKey(CLIENT_ID) || typeMatch.contains(CLIENT_ID))) { + throw new InvalidRequestException("Invalid entity filter component " + + "combination. IP filter component should not be used with " + + "user or clientId filter component."); + } + } + // TODO: this is O(N). We should add indexing here to speed it up. See KAFKA-13022. 
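+        // Linear scan: an entity matches if every exact-match component is present with
+        // an equal value, every type-match component is present, and (for strict
+        // requests) the entity has no components beyond those named in the filter.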
+ for (Entry entry : entities.entrySet()) { + ClientQuotaEntity entity = entry.getKey(); + ClientQuotaImage quotaImage = entry.getValue(); + if (matches(entity, exactMatch, typeMatch, request.strict())) { + response.entries().add(toDescribeEntry(entity, quotaImage)); + } + } + return response; + } + + private static boolean matches(ClientQuotaEntity entity, + Map exactMatch, + Set typeMatch, + boolean strict) { + if (strict) { + if (entity.entries().size() != exactMatch.size() + typeMatch.size()) { + return false; + } + } + for (Entry entry : exactMatch.entrySet()) { + if (!entity.entries().containsKey(entry.getKey())) { + return false; + } + if (!Objects.equals(entity.entries().get(entry.getKey()), entry.getValue())) { + return false; + } + } + for (String type : typeMatch) { + if (!entity.entries().containsKey(type)) { + return false; + } + } + return true; + } + + private static EntryData toDescribeEntry(ClientQuotaEntity entity, + ClientQuotaImage quotaImage) { + EntryData data = new EntryData(); + for (Entry entry : entity.entries().entrySet()) { + data.entity().add(new EntityData(). + setEntityType(entry.getKey()). + setEntityName(entry.getValue())); + } + data.setValues(quotaImage.toDescribeValues()); + return data; + } + + @Override + public boolean equals(Object o) { + if (!(o instanceof ClientQuotasImage)) return false; + ClientQuotasImage other = (ClientQuotasImage) o; + return entities.equals(other.entities); + } + + @Override + public int hashCode() { + return Objects.hash(entities); + } + + @Override + public String toString() { + return "ClientQuotasImage(entities=" + entities.entrySet().stream(). + map(e -> e.getKey() + ":" + e.getValue()).collect(Collectors.joining(", ")) + + ")"; + } +} diff --git a/metadata/src/main/java/org/apache/kafka/image/ClusterDelta.java b/metadata/src/main/java/org/apache/kafka/image/ClusterDelta.java new file mode 100644 index 0000000..6c48b8e --- /dev/null +++ b/metadata/src/main/java/org/apache/kafka/image/ClusterDelta.java @@ -0,0 +1,136 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.image; + +import org.apache.kafka.common.metadata.BrokerRegistrationChangeRecord; +import org.apache.kafka.common.metadata.FenceBrokerRecord; +import org.apache.kafka.common.metadata.RegisterBrokerRecord; +import org.apache.kafka.common.metadata.UnfenceBrokerRecord; +import org.apache.kafka.common.metadata.UnregisterBrokerRecord; +import org.apache.kafka.metadata.BrokerRegistration; + +import java.util.HashMap; +import java.util.Map; +import java.util.Map.Entry; +import java.util.Optional; + + +/** + * Represents changes to the cluster in the metadata image. 
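+ *
+ * An illustrative sketch of typical use (the record variables here are hypothetical):
+ * <pre>{@code
+ * ClusterDelta delta = new ClusterDelta(currentClusterImage);
+ * delta.replay(registerBrokerRecord);   // adds or replaces a broker registration
+ * delta.replay(fenceBrokerRecord);      // marks an already-registered broker as fenced
+ * ClusterImage next = delta.apply();    // currentClusterImage itself is never modified
+ * }</pre>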
+ */ +public final class ClusterDelta { + private final ClusterImage image; + private final HashMap> changedBrokers = new HashMap<>(); + + public ClusterDelta(ClusterImage image) { + this.image = image; + } + + public HashMap> changedBrokers() { + return changedBrokers; + } + + public BrokerRegistration broker(int nodeId) { + Optional result = changedBrokers.get(nodeId); + if (result != null) { + return result.orElse(null); + } + return image.broker(nodeId); + } + + public void finishSnapshot() { + for (Integer brokerId : image.brokers().keySet()) { + if (!changedBrokers.containsKey(brokerId)) { + changedBrokers.put(brokerId, Optional.empty()); + } + } + } + + public void replay(RegisterBrokerRecord record) { + BrokerRegistration broker = BrokerRegistration.fromRecord(record); + changedBrokers.put(broker.id(), Optional.of(broker)); + } + + public void replay(UnregisterBrokerRecord record) { + changedBrokers.put(record.brokerId(), Optional.empty()); + } + + private BrokerRegistration getBrokerOrThrow(int brokerId, long epoch, String action) { + BrokerRegistration broker = broker(brokerId); + if (broker == null) { + throw new IllegalStateException("Tried to " + action + " broker " + brokerId + + ", but that broker was not registered."); + } + if (broker.epoch() != epoch) { + throw new IllegalStateException("Tried to " + action + " broker " + brokerId + + ", but the given epoch, " + epoch + ", did not match the current broker " + + "epoch, " + broker.epoch()); + } + return broker; + } + + public void replay(FenceBrokerRecord record) { + BrokerRegistration broker = getBrokerOrThrow(record.id(), record.epoch(), "fence"); + changedBrokers.put(record.id(), Optional.of(broker.cloneWithFencing(true))); + } + + public void replay(UnfenceBrokerRecord record) { + BrokerRegistration broker = getBrokerOrThrow(record.id(), record.epoch(), "unfence"); + changedBrokers.put(record.id(), Optional.of(broker.cloneWithFencing(false))); + } + + public void replay(BrokerRegistrationChangeRecord record) { + BrokerRegistration broker = + getBrokerOrThrow(record.brokerId(), record.brokerEpoch(), "change"); + if (record.fenced() < 0) { + changedBrokers.put(record.brokerId(), Optional.of(broker.cloneWithFencing(false))); + } else if (record.fenced() > 0) { + changedBrokers.put(record.brokerId(), Optional.of(broker.cloneWithFencing(true))); + } + } + + public ClusterImage apply() { + Map newBrokers = new HashMap<>(image.brokers().size()); + for (Entry entry : image.brokers().entrySet()) { + int nodeId = entry.getKey(); + Optional change = changedBrokers.get(nodeId); + if (change == null) { + newBrokers.put(nodeId, entry.getValue()); + } else if (change.isPresent()) { + newBrokers.put(nodeId, change.get()); + } + } + for (Entry> entry : changedBrokers.entrySet()) { + int nodeId = entry.getKey(); + Optional brokerRegistration = entry.getValue(); + if (!newBrokers.containsKey(nodeId)) { + if (brokerRegistration.isPresent()) { + newBrokers.put(nodeId, brokerRegistration.get()); + } + } + } + return new ClusterImage(newBrokers); + } + + @Override + public String toString() { + return "ClusterDelta(" + + "changedBrokers=" + changedBrokers + + ')'; + } +} diff --git a/metadata/src/main/java/org/apache/kafka/image/ClusterImage.java b/metadata/src/main/java/org/apache/kafka/image/ClusterImage.java new file mode 100644 index 0000000..3cf36fa --- /dev/null +++ b/metadata/src/main/java/org/apache/kafka/image/ClusterImage.java @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * 
contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.image; + +import org.apache.kafka.metadata.BrokerRegistration; +import org.apache.kafka.server.common.ApiMessageAndVersion; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.function.Consumer; +import java.util.stream.Collectors; + + +/** + * Represents the cluster in the metadata image. + * + * This class is thread-safe. + */ +public final class ClusterImage { + public static final ClusterImage EMPTY = new ClusterImage(Collections.emptyMap()); + + private final Map brokers; + + public ClusterImage(Map brokers) { + this.brokers = Collections.unmodifiableMap(brokers); + } + + public boolean isEmpty() { + return brokers.isEmpty(); + } + + public Map brokers() { + return brokers; + } + + public BrokerRegistration broker(int nodeId) { + return brokers.get(nodeId); + } + + public void write(Consumer> out) { + List batch = new ArrayList<>(); + for (BrokerRegistration broker : brokers.values()) { + batch.add(broker.toRecord()); + } + out.accept(batch); + } + + @Override + public int hashCode() { + return brokers.hashCode(); + } + + @Override + public boolean equals(Object o) { + if (!(o instanceof ClusterImage)) return false; + ClusterImage other = (ClusterImage) o; + return brokers.equals(other.brokers); + } + + @Override + public String toString() { + return brokers.entrySet().stream(). + map(e -> e.getKey() + ":" + e.getValue()).collect(Collectors.joining(", ")); + } +} diff --git a/metadata/src/main/java/org/apache/kafka/image/ConfigurationDelta.java b/metadata/src/main/java/org/apache/kafka/image/ConfigurationDelta.java new file mode 100644 index 0000000..677f764 --- /dev/null +++ b/metadata/src/main/java/org/apache/kafka/image/ConfigurationDelta.java @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.image; + +import org.apache.kafka.common.metadata.ConfigRecord; + +import java.util.HashMap; +import java.util.Map; +import java.util.Map.Entry; +import java.util.Optional; + + +/** + * Represents changes to the configurations in the metadata image. + */ +public final class ConfigurationDelta { + private final ConfigurationImage image; + private final Map> changes = new HashMap<>(); + + public ConfigurationDelta(ConfigurationImage image) { + this.image = image; + } + + public void finishSnapshot() { + for (String key : image.data().keySet()) { + if (!changes.containsKey(key)) { + changes.put(key, Optional.empty()); + } + } + } + + public void replay(ConfigRecord record) { + changes.put(record.name(), Optional.ofNullable(record.value())); + } + + public void deleteAll() { + changes.clear(); + for (String key : image.data().keySet()) { + changes.put(key, Optional.empty()); + } + } + + public ConfigurationImage apply() { + Map newData = new HashMap<>(image.data().size()); + for (Entry entry : image.data().entrySet()) { + Optional change = changes.get(entry.getKey()); + if (change == null) { + newData.put(entry.getKey(), entry.getValue()); + } else if (change.isPresent()) { + newData.put(entry.getKey(), change.get()); + } + } + for (Entry> entry : changes.entrySet()) { + if (!newData.containsKey(entry.getKey())) { + if (entry.getValue().isPresent()) { + newData.put(entry.getKey(), entry.getValue().get()); + } + } + } + return new ConfigurationImage(newData); + } + + @Override + public String toString() { + // Values are intentionally left out of this so that sensitive configs + // do not end up in logging by mistake. + return "ConfigurationDelta(" + + "changedKeys=" + changes.keySet() + + ')'; + } +} diff --git a/metadata/src/main/java/org/apache/kafka/image/ConfigurationImage.java b/metadata/src/main/java/org/apache/kafka/image/ConfigurationImage.java new file mode 100644 index 0000000..fa004ac --- /dev/null +++ b/metadata/src/main/java/org/apache/kafka/image/ConfigurationImage.java @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.image; + +import org.apache.kafka.common.config.ConfigResource; +import org.apache.kafka.common.metadata.ConfigRecord; +import org.apache.kafka.server.common.ApiMessageAndVersion; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Properties; +import java.util.function.Consumer; +import java.util.stream.Collectors; + +import static org.apache.kafka.common.metadata.MetadataRecordType.CONFIG_RECORD; + + +/** + * Represents the configuration of a resource. + * + * This class is thread-safe. 
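+ *
+ * The map holds the raw configuration keys and values for a single resource;
+ * write() re-emits them as ConfigRecords for that resource, and toProperties()
+ * copies them into a java.util.Properties object.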
+ */ +public final class ConfigurationImage { + public static final ConfigurationImage EMPTY = new ConfigurationImage(Collections.emptyMap()); + + private final Map data; + + public ConfigurationImage(Map data) { + this.data = data; + } + + Map data() { + return data; + } + + public boolean isEmpty() { + return data.isEmpty(); + } + + public Properties toProperties() { + Properties properties = new Properties(); + properties.putAll(data); + return properties; + } + + public void write(ConfigResource configResource, Consumer> out) { + List records = new ArrayList<>(); + for (Map.Entry entry : data.entrySet()) { + records.add(new ApiMessageAndVersion(new ConfigRecord(). + setResourceType(configResource.type().id()). + setResourceName(configResource.name()). + setName(entry.getKey()). + setValue(entry.getValue()), CONFIG_RECORD.highestSupportedVersion())); + } + out.accept(records); + } + + @Override + public boolean equals(Object o) { + if (!(o instanceof ConfigurationImage)) return false; + ConfigurationImage other = (ConfigurationImage) o; + return data.equals(other.data); + } + + @Override + public int hashCode() { + return Objects.hash(data); + } + + @Override + public String toString() { + return "ConfigurationImage(data=" + data.entrySet().stream(). + map(e -> e.getKey() + ":" + e.getValue()).collect(Collectors.joining(", ")) + + ")"; + } +} diff --git a/metadata/src/main/java/org/apache/kafka/image/ConfigurationsDelta.java b/metadata/src/main/java/org/apache/kafka/image/ConfigurationsDelta.java new file mode 100644 index 0000000..d0f5848 --- /dev/null +++ b/metadata/src/main/java/org/apache/kafka/image/ConfigurationsDelta.java @@ -0,0 +1,106 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.image; + +import org.apache.kafka.common.config.ConfigResource; +import org.apache.kafka.common.config.ConfigResource.Type; +import org.apache.kafka.common.metadata.ConfigRecord; +import org.apache.kafka.common.metadata.RemoveTopicRecord; + +import java.util.HashMap; +import java.util.Map; +import java.util.Map.Entry; + + +/** + * Represents changes to the configurations in the metadata image. 
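+ *
+ * An illustrative sketch (the resource and configuration values are hypothetical):
+ * <pre>{@code
+ * ConfigurationsDelta delta = new ConfigurationsDelta(currentConfigsImage);
+ * delta.replay(new ConfigRecord().
+ *     setResourceType(ConfigResource.Type.TOPIC.id()).
+ *     setResourceName("my-topic").
+ *     setName("cleanup.policy").
+ *     setValue("compact"));
+ * ConfigurationsImage next = delta.apply();
+ * }</pre>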
+ */ +public final class ConfigurationsDelta { + private final ConfigurationsImage image; + private final Map changes = new HashMap<>(); + + public ConfigurationsDelta(ConfigurationsImage image) { + this.image = image; + } + + public Map changes() { + return changes; + } + + public void finishSnapshot() { + for (Entry entry : image.resourceData().entrySet()) { + ConfigResource resource = entry.getKey(); + ConfigurationImage configImage = entry.getValue(); + ConfigurationDelta configDelta = changes.computeIfAbsent(resource, + __ -> new ConfigurationDelta(configImage)); + configDelta.finishSnapshot(); + } + } + + public void replay(ConfigRecord record) { + ConfigResource resource = + new ConfigResource(Type.forId(record.resourceType()), record.resourceName()); + ConfigurationImage configImage = + image.resourceData().getOrDefault(resource, ConfigurationImage.EMPTY); + ConfigurationDelta delta = changes.computeIfAbsent(resource, + __ -> new ConfigurationDelta(configImage)); + delta.replay(record); + } + + public void replay(RemoveTopicRecord record, String topicName) { + ConfigResource resource = + new ConfigResource(Type.TOPIC, topicName); + ConfigurationImage configImage = + image.resourceData().getOrDefault(resource, ConfigurationImage.EMPTY); + ConfigurationDelta delta = changes.computeIfAbsent(resource, + __ -> new ConfigurationDelta(configImage)); + delta.deleteAll(); + } + + public ConfigurationsImage apply() { + Map newData = new HashMap<>(); + for (Entry entry : image.resourceData().entrySet()) { + ConfigResource resource = entry.getKey(); + ConfigurationDelta delta = changes.get(resource); + if (delta == null) { + newData.put(resource, entry.getValue()); + } else { + ConfigurationImage newImage = delta.apply(); + if (!newImage.isEmpty()) { + newData.put(resource, newImage); + } + } + } + for (Entry entry : changes.entrySet()) { + if (!newData.containsKey(entry.getKey())) { + ConfigurationImage newImage = entry.getValue().apply(); + if (!newImage.isEmpty()) { + newData.put(entry.getKey(), newImage); + } + } + } + return new ConfigurationsImage(newData); + } + + @Override + public String toString() { + return "ConfigurationsDelta(" + + "changes=" + changes + + ')'; + } +} diff --git a/metadata/src/main/java/org/apache/kafka/image/ConfigurationsImage.java b/metadata/src/main/java/org/apache/kafka/image/ConfigurationsImage.java new file mode 100644 index 0000000..87bc902 --- /dev/null +++ b/metadata/src/main/java/org/apache/kafka/image/ConfigurationsImage.java @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.image; + +import org.apache.kafka.common.config.ConfigResource; +import org.apache.kafka.server.common.ApiMessageAndVersion; + +import java.util.Collections; +import java.util.List; +import java.util.Map.Entry; +import java.util.Map; +import java.util.Objects; +import java.util.Properties; +import java.util.function.Consumer; +import java.util.stream.Collectors; + + +/** + * Represents the configurations in the metadata image. + * + * This class is thread-safe. + */ +public final class ConfigurationsImage { + public static final ConfigurationsImage EMPTY = + new ConfigurationsImage(Collections.emptyMap()); + + private final Map data; + + public ConfigurationsImage(Map data) { + this.data = Collections.unmodifiableMap(data); + } + + public boolean isEmpty() { + return data.isEmpty(); + } + + Map resourceData() { + return data; + } + + public Properties configProperties(ConfigResource configResource) { + ConfigurationImage configurationImage = data.get(configResource); + if (configurationImage != null) { + return configurationImage.toProperties(); + } else { + return new Properties(); + } + } + + public void write(Consumer> out) { + for (Entry entry : data.entrySet()) { + ConfigResource configResource = entry.getKey(); + ConfigurationImage configImage = entry.getValue(); + configImage.write(configResource, out); + } + } + + @Override + public boolean equals(Object o) { + if (!(o instanceof ConfigurationsImage)) return false; + ConfigurationsImage other = (ConfigurationsImage) o; + return data.equals(other.data); + } + + @Override + public int hashCode() { + return Objects.hash(data); + } + + @Override + public String toString() { + return "ConfigurationsImage(data=" + data.entrySet().stream(). + map(e -> e.getKey() + ":" + e.getValue()).collect(Collectors.joining(", ")) + + ")"; + } +} diff --git a/metadata/src/main/java/org/apache/kafka/image/FeaturesDelta.java b/metadata/src/main/java/org/apache/kafka/image/FeaturesDelta.java new file mode 100644 index 0000000..781c496 --- /dev/null +++ b/metadata/src/main/java/org/apache/kafka/image/FeaturesDelta.java @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.image; + +import org.apache.kafka.common.metadata.FeatureLevelRecord; +import org.apache.kafka.common.metadata.RemoveFeatureLevelRecord; +import org.apache.kafka.metadata.VersionRange; + +import java.util.HashMap; +import java.util.Map; +import java.util.Map.Entry; +import java.util.Optional; + + +/** + * Represents changes to the cluster in the metadata image. 
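+ *
+ * An illustrative sketch (the feature name and levels are hypothetical):
+ * <pre>{@code
+ * FeaturesDelta delta = new FeaturesDelta(currentFeaturesImage);
+ * delta.replay(new FeatureLevelRecord().
+ *     setName("example.feature").
+ *     setMinFeatureLevel((short) 1).
+ *     setMaxFeatureLevel((short) 2));
+ * FeaturesImage next = delta.apply();   // finalizes example.feature at [1, 2]
+ * }</pre>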
+ */ +public final class FeaturesDelta { + private final FeaturesImage image; + + private final Map> changes = new HashMap<>(); + + public FeaturesDelta(FeaturesImage image) { + this.image = image; + } + + public Map> changes() { + return changes; + } + + public void finishSnapshot() { + for (String featureName : image.finalizedVersions().keySet()) { + if (!changes.containsKey(featureName)) { + changes.put(featureName, Optional.empty()); + } + } + } + + public void replay(FeatureLevelRecord record) { + changes.put(record.name(), Optional.of( + new VersionRange(record.minFeatureLevel(), record.maxFeatureLevel()))); + } + + public void replay(RemoveFeatureLevelRecord record) { + changes.put(record.name(), Optional.empty()); + } + + public FeaturesImage apply() { + Map newFinalizedVersions = + new HashMap<>(image.finalizedVersions().size()); + for (Entry entry : image.finalizedVersions().entrySet()) { + String name = entry.getKey(); + Optional change = changes.get(name); + if (change == null) { + newFinalizedVersions.put(name, entry.getValue()); + } else if (change.isPresent()) { + newFinalizedVersions.put(name, change.get()); + } + } + for (Entry> entry : changes.entrySet()) { + String name = entry.getKey(); + Optional change = entry.getValue(); + if (!newFinalizedVersions.containsKey(name)) { + if (change.isPresent()) { + newFinalizedVersions.put(name, change.get()); + } + } + } + return new FeaturesImage(newFinalizedVersions); + } + + @Override + public String toString() { + return "FeaturesDelta(" + + "changes=" + changes + + ')'; + } +} diff --git a/metadata/src/main/java/org/apache/kafka/image/FeaturesImage.java b/metadata/src/main/java/org/apache/kafka/image/FeaturesImage.java new file mode 100644 index 0000000..f5f3729 --- /dev/null +++ b/metadata/src/main/java/org/apache/kafka/image/FeaturesImage.java @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.image; + +import org.apache.kafka.common.metadata.FeatureLevelRecord; +import org.apache.kafka.metadata.VersionRange; +import org.apache.kafka.server.common.ApiMessageAndVersion; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; +import java.util.Optional; +import java.util.function.Consumer; +import java.util.stream.Collectors; + +import static org.apache.kafka.common.metadata.MetadataRecordType.FEATURE_LEVEL_RECORD; + + +/** + * Represents the feature levels in the metadata image. + * + * This class is thread-safe. 
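+ *
+ * Maps each finalized feature name to its VersionRange; write() re-emits one
+ * FeatureLevelRecord per finalized feature.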
+ */ +public final class FeaturesImage { + public static final FeaturesImage EMPTY = new FeaturesImage(Collections.emptyMap()); + + private final Map finalizedVersions; + + public FeaturesImage(Map finalizedVersions) { + this.finalizedVersions = Collections.unmodifiableMap(finalizedVersions); + } + + public boolean isEmpty() { + return finalizedVersions.isEmpty(); + } + + Map finalizedVersions() { + return finalizedVersions; + } + + private Optional finalizedVersion(String feature) { + return Optional.ofNullable(finalizedVersions.get(feature)); + } + + public void write(Consumer> out) { + List batch = new ArrayList<>(); + for (Entry entry : finalizedVersions.entrySet()) { + batch.add(new ApiMessageAndVersion(new FeatureLevelRecord(). + setName(entry.getKey()). + setMinFeatureLevel(entry.getValue().min()). + setMaxFeatureLevel(entry.getValue().max()), + FEATURE_LEVEL_RECORD.highestSupportedVersion())); + } + out.accept(batch); + } + + @Override + public int hashCode() { + return finalizedVersions.hashCode(); + } + + @Override + public boolean equals(Object o) { + if (!(o instanceof FeaturesImage)) return false; + FeaturesImage other = (FeaturesImage) o; + return finalizedVersions.equals(other.finalizedVersions); + } + + @Override + public String toString() { + return finalizedVersions.entrySet().stream(). + map(e -> e.getKey() + ":" + e.getValue()).collect(Collectors.joining(", ")); + } +} diff --git a/metadata/src/main/java/org/apache/kafka/image/LocalReplicaChanges.java b/metadata/src/main/java/org/apache/kafka/image/LocalReplicaChanges.java new file mode 100644 index 0000000..3afdd1a --- /dev/null +++ b/metadata/src/main/java/org/apache/kafka/image/LocalReplicaChanges.java @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.image; + +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.Uuid; +import org.apache.kafka.metadata.PartitionRegistration; + +import java.util.Set; +import java.util.Map; + +public final class LocalReplicaChanges { + private final Set deletes; + private final Map leaders; + private final Map followers; + + LocalReplicaChanges( + Set deletes, + Map leaders, + Map followers + ) { + this.deletes = deletes; + this.leaders = leaders; + this.followers = followers; + } + + public Set deletes() { + return deletes; + } + + public Map leaders() { + return leaders; + } + + public Map followers() { + return followers; + } + + @Override + public String toString() { + return String.format( + "LocalReplicaChanges(deletes = %s, leaders = %s, followers = %s)", + deletes, + leaders, + followers + ); + } + + public static final class PartitionInfo { + private final Uuid topicId; + private final PartitionRegistration partition; + + public PartitionInfo(Uuid topicId, PartitionRegistration partition) { + this.topicId = topicId; + this.partition = partition; + } + + @Override + public String toString() { + return String.format("PartitionInfo(topicId = %s, partition = %s)", topicId, partition); + } + + public Uuid topicId() { + return topicId; + } + + public PartitionRegistration partition() { + return partition; + } + } +} diff --git a/metadata/src/main/java/org/apache/kafka/image/MetadataDelta.java b/metadata/src/main/java/org/apache/kafka/image/MetadataDelta.java new file mode 100644 index 0000000..aa7725b --- /dev/null +++ b/metadata/src/main/java/org/apache/kafka/image/MetadataDelta.java @@ -0,0 +1,288 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.image; + +import org.apache.kafka.common.metadata.BrokerRegistrationChangeRecord; +import org.apache.kafka.common.metadata.ClientQuotaRecord; +import org.apache.kafka.common.metadata.ConfigRecord; +import org.apache.kafka.common.metadata.FeatureLevelRecord; +import org.apache.kafka.common.metadata.FenceBrokerRecord; +import org.apache.kafka.common.metadata.MetadataRecordType; +import org.apache.kafka.common.metadata.PartitionChangeRecord; +import org.apache.kafka.common.metadata.PartitionRecord; +import org.apache.kafka.common.metadata.RegisterBrokerRecord; +import org.apache.kafka.common.metadata.RemoveFeatureLevelRecord; +import org.apache.kafka.common.metadata.RemoveTopicRecord; +import org.apache.kafka.common.metadata.TopicRecord; +import org.apache.kafka.common.metadata.UnfenceBrokerRecord; +import org.apache.kafka.common.metadata.UnregisterBrokerRecord; +import org.apache.kafka.common.protocol.ApiMessage; +import org.apache.kafka.raft.OffsetAndEpoch; +import org.apache.kafka.server.common.ApiMessageAndVersion; + +import java.util.Iterator; +import java.util.List; + + +/** + * A change to the broker metadata image. + * + * This class is thread-safe. + */ +public final class MetadataDelta { + private final MetadataImage image; + + private long highestOffset; + + private int highestEpoch; + + private FeaturesDelta featuresDelta = null; + + private ClusterDelta clusterDelta = null; + + private TopicsDelta topicsDelta = null; + + private ConfigurationsDelta configsDelta = null; + + private ClientQuotasDelta clientQuotasDelta = null; + + public MetadataDelta(MetadataImage image) { + this.image = image; + this.highestOffset = image.highestOffsetAndEpoch().offset; + this.highestEpoch = image.highestOffsetAndEpoch().epoch; + } + + public MetadataImage image() { + return image; + } + + public FeaturesDelta featuresDelta() { + return featuresDelta; + } + + public ClusterDelta clusterDelta() { + return clusterDelta; + } + + public TopicsDelta topicsDelta() { + return topicsDelta; + } + + public ConfigurationsDelta configsDelta() { + return configsDelta; + } + + public ClientQuotasDelta clientQuotasDelta() { + return clientQuotasDelta; + } + + public void read(long highestOffset, int highestEpoch, Iterator> reader) { + while (reader.hasNext()) { + List batch = reader.next(); + for (ApiMessageAndVersion messageAndVersion : batch) { + replay(highestOffset, highestEpoch, messageAndVersion.message()); + } + } + } + + public void replay(long offset, int epoch, ApiMessage record) { + highestOffset = offset; + highestEpoch = epoch; + + MetadataRecordType type = MetadataRecordType.fromId(record.apiKey()); + switch (type) { + case REGISTER_BROKER_RECORD: + replay((RegisterBrokerRecord) record); + break; + case UNREGISTER_BROKER_RECORD: + replay((UnregisterBrokerRecord) record); + break; + case TOPIC_RECORD: + replay((TopicRecord) record); + break; + case PARTITION_RECORD: + replay((PartitionRecord) record); + break; + case CONFIG_RECORD: + replay((ConfigRecord) record); + break; + case PARTITION_CHANGE_RECORD: + replay((PartitionChangeRecord) record); + break; + case FENCE_BROKER_RECORD: + replay((FenceBrokerRecord) record); + break; + case UNFENCE_BROKER_RECORD: + replay((UnfenceBrokerRecord) record); + break; + case REMOVE_TOPIC_RECORD: + replay((RemoveTopicRecord) record); + break; + case FEATURE_LEVEL_RECORD: + replay((FeatureLevelRecord) record); + break; + case CLIENT_QUOTA_RECORD: + replay((ClientQuotaRecord) record); + break; + case PRODUCER_IDS_RECORD: + // 
Nothing to do. + break; + case REMOVE_FEATURE_LEVEL_RECORD: + replay((RemoveFeatureLevelRecord) record); + break; + case BROKER_REGISTRATION_CHANGE_RECORD: + replay((BrokerRegistrationChangeRecord) record); + break; + default: + throw new RuntimeException("Unknown metadata record type " + type); + } + } + + public void replay(RegisterBrokerRecord record) { + if (clusterDelta == null) clusterDelta = new ClusterDelta(image.cluster()); + clusterDelta.replay(record); + } + + public void replay(UnregisterBrokerRecord record) { + if (clusterDelta == null) clusterDelta = new ClusterDelta(image.cluster()); + clusterDelta.replay(record); + } + + public void replay(TopicRecord record) { + if (topicsDelta == null) topicsDelta = new TopicsDelta(image.topics()); + topicsDelta.replay(record); + } + + public void replay(PartitionRecord record) { + if (topicsDelta == null) topicsDelta = new TopicsDelta(image.topics()); + topicsDelta.replay(record); + } + + public void replay(ConfigRecord record) { + if (configsDelta == null) configsDelta = new ConfigurationsDelta(image.configs()); + configsDelta.replay(record); + } + + public void replay(PartitionChangeRecord record) { + if (topicsDelta == null) topicsDelta = new TopicsDelta(image.topics()); + topicsDelta.replay(record); + } + + public void replay(FenceBrokerRecord record) { + if (clusterDelta == null) clusterDelta = new ClusterDelta(image.cluster()); + clusterDelta.replay(record); + } + + public void replay(UnfenceBrokerRecord record) { + if (clusterDelta == null) clusterDelta = new ClusterDelta(image.cluster()); + clusterDelta.replay(record); + } + + public void replay(RemoveTopicRecord record) { + if (topicsDelta == null) topicsDelta = new TopicsDelta(image.topics()); + String topicName = topicsDelta.replay(record); + if (configsDelta == null) configsDelta = new ConfigurationsDelta(image.configs()); + configsDelta.replay(record, topicName); + } + + public void replay(FeatureLevelRecord record) { + if (featuresDelta == null) featuresDelta = new FeaturesDelta(image.features()); + featuresDelta.replay(record); + } + + public void replay(BrokerRegistrationChangeRecord record) { + if (clusterDelta == null) clusterDelta = new ClusterDelta(image.cluster()); + clusterDelta.replay(record); + } + + public void replay(ClientQuotaRecord record) { + if (clientQuotasDelta == null) clientQuotasDelta = new ClientQuotasDelta(image.clientQuotas()); + clientQuotasDelta.replay(record); + } + + public void replay(RemoveFeatureLevelRecord record) { + if (featuresDelta == null) featuresDelta = new FeaturesDelta(image.features()); + featuresDelta.replay(record); + } + + /** + * Create removal deltas for anything which was in the base image, but which was not + * referenced in the snapshot records we just applied. 
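+ *
+ * This should only be called after every record in the snapshot has been replayed;
+ * each child delta then marks entries that were present in the base image but absent
+ * from the snapshot as deletions, so that apply() drops them.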
+ */ + public void finishSnapshot() { + if (featuresDelta != null) featuresDelta.finishSnapshot(); + if (clusterDelta != null) clusterDelta.finishSnapshot(); + if (topicsDelta != null) topicsDelta.finishSnapshot(); + if (configsDelta != null) configsDelta.finishSnapshot(); + if (clientQuotasDelta != null) clientQuotasDelta.finishSnapshot(); + } + + public MetadataImage apply() { + FeaturesImage newFeatures; + if (featuresDelta == null) { + newFeatures = image.features(); + } else { + newFeatures = featuresDelta.apply(); + } + ClusterImage newCluster; + if (clusterDelta == null) { + newCluster = image.cluster(); + } else { + newCluster = clusterDelta.apply(); + } + TopicsImage newTopics; + if (topicsDelta == null) { + newTopics = image.topics(); + } else { + newTopics = topicsDelta.apply(); + } + ConfigurationsImage newConfigs; + if (configsDelta == null) { + newConfigs = image.configs(); + } else { + newConfigs = configsDelta.apply(); + } + ClientQuotasImage newClientQuotas; + if (clientQuotasDelta == null) { + newClientQuotas = image.clientQuotas(); + } else { + newClientQuotas = clientQuotasDelta.apply(); + } + return new MetadataImage( + new OffsetAndEpoch(highestOffset, highestEpoch), + newFeatures, + newCluster, + newTopics, + newConfigs, + newClientQuotas + ); + } + + @Override + public String toString() { + return "MetadataDelta(" + + "highestOffset=" + highestOffset + + ", highestEpoch=" + highestEpoch + + ", featuresDelta=" + featuresDelta + + ", clusterDelta=" + clusterDelta + + ", topicsDelta=" + topicsDelta + + ", configsDelta=" + configsDelta + + ", clientQuotasDelta=" + clientQuotasDelta + + ')'; + } +} diff --git a/metadata/src/main/java/org/apache/kafka/image/MetadataImage.java b/metadata/src/main/java/org/apache/kafka/image/MetadataImage.java new file mode 100644 index 0000000..b9fdbc1 --- /dev/null +++ b/metadata/src/main/java/org/apache/kafka/image/MetadataImage.java @@ -0,0 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.image; + +import org.apache.kafka.raft.OffsetAndEpoch; +import org.apache.kafka.server.common.ApiMessageAndVersion; + +import java.util.List; +import java.util.Objects; +import java.util.function.Consumer; + + +/** + * The broker metadata image. + * + * This class is thread-safe. 
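+ *
+ * Images are immutable; changes are accumulated in a MetadataDelta and applied to
+ * produce a new image. An illustrative sketch (the record variable is hypothetical):
+ * <pre>{@code
+ * MetadataDelta delta = new MetadataDelta(image);
+ * delta.replay(offset, epoch, record);    // any supported metadata record type
+ * MetadataImage newImage = delta.apply(); // the original image is left untouched
+ * }</pre>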
+ */ +public final class MetadataImage { + public final static MetadataImage EMPTY = new MetadataImage( + new OffsetAndEpoch(0, 0), + FeaturesImage.EMPTY, + ClusterImage.EMPTY, + TopicsImage.EMPTY, + ConfigurationsImage.EMPTY, + ClientQuotasImage.EMPTY); + + private final OffsetAndEpoch highestOffsetAndEpoch; + + private final FeaturesImage features; + + private final ClusterImage cluster; + + private final TopicsImage topics; + + private final ConfigurationsImage configs; + + private final ClientQuotasImage clientQuotas; + + public MetadataImage( + OffsetAndEpoch highestOffsetAndEpoch, + FeaturesImage features, + ClusterImage cluster, + TopicsImage topics, + ConfigurationsImage configs, + ClientQuotasImage clientQuotas + ) { + this.highestOffsetAndEpoch = highestOffsetAndEpoch; + this.features = features; + this.cluster = cluster; + this.topics = topics; + this.configs = configs; + this.clientQuotas = clientQuotas; + } + + public boolean isEmpty() { + return features.isEmpty() && + cluster.isEmpty() && + topics.isEmpty() && + configs.isEmpty() && + clientQuotas.isEmpty(); + } + + public OffsetAndEpoch highestOffsetAndEpoch() { + return highestOffsetAndEpoch; + } + + public FeaturesImage features() { + return features; + } + + public ClusterImage cluster() { + return cluster; + } + + public TopicsImage topics() { + return topics; + } + + public ConfigurationsImage configs() { + return configs; + } + + public ClientQuotasImage clientQuotas() { + return clientQuotas; + } + + public void write(Consumer> out) { + features.write(out); + cluster.write(out); + topics.write(out); + configs.write(out); + clientQuotas.write(out); + } + + @Override + public boolean equals(Object o) { + if (!(o instanceof MetadataImage)) return false; + MetadataImage other = (MetadataImage) o; + return highestOffsetAndEpoch.equals(other.highestOffsetAndEpoch) && + features.equals(other.features) && + cluster.equals(other.cluster) && + topics.equals(other.topics) && + configs.equals(other.configs) && + clientQuotas.equals(other.clientQuotas); + } + + @Override + public int hashCode() { + return Objects.hash(highestOffsetAndEpoch, features, cluster, topics, configs, clientQuotas); + } + + @Override + public String toString() { + return "MetadataImage(highestOffsetAndEpoch=" + highestOffsetAndEpoch + + ", features=" + features + + ", cluster=" + cluster + + ", topics=" + topics + + ", configs=" + configs + + ", clientQuotas=" + clientQuotas + + ")"; + } +} diff --git a/metadata/src/main/java/org/apache/kafka/image/TopicDelta.java b/metadata/src/main/java/org/apache/kafka/image/TopicDelta.java new file mode 100644 index 0000000..3214a0c --- /dev/null +++ b/metadata/src/main/java/org/apache/kafka/image/TopicDelta.java @@ -0,0 +1,148 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.image; + +import org.apache.kafka.common.Uuid; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.metadata.PartitionChangeRecord; +import org.apache.kafka.common.metadata.PartitionRecord; +import org.apache.kafka.metadata.PartitionRegistration; +import org.apache.kafka.metadata.Replicas; + +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map.Entry; +import java.util.Map; +import java.util.Set; + +/** + * Represents changes to a topic in the metadata image. + */ +public final class TopicDelta { + private final TopicImage image; + private final Map partitionChanges = new HashMap<>(); + + public TopicDelta(TopicImage image) { + this.image = image; + } + + public TopicImage image() { + return image; + } + + public Map partitionChanges() { + return partitionChanges; + } + + public String name() { + return image.name(); + } + + public Uuid id() { + return image.id(); + } + + public void replay(PartitionRecord record) { + partitionChanges.put(record.partitionId(), new PartitionRegistration(record)); + } + + public void replay(PartitionChangeRecord record) { + PartitionRegistration partition = partitionChanges.get(record.partitionId()); + if (partition == null) { + partition = image.partitions().get(record.partitionId()); + if (partition == null) { + throw new RuntimeException("Unable to find partition " + + record.topicId() + ":" + record.partitionId()); + } + } + partitionChanges.put(record.partitionId(), partition.merge(record)); + } + + public TopicImage apply() { + Map newPartitions = new HashMap<>(); + for (Entry entry : image.partitions().entrySet()) { + int partitionId = entry.getKey(); + PartitionRegistration changedPartition = partitionChanges.get(partitionId); + if (changedPartition == null) { + newPartitions.put(partitionId, entry.getValue()); + } else { + newPartitions.put(partitionId, changedPartition); + } + } + for (Entry entry : partitionChanges.entrySet()) { + if (!newPartitions.containsKey(entry.getKey())) { + newPartitions.put(entry.getKey(), entry.getValue()); + } + } + return new TopicImage(image.name(), image.id(), newPartitions); + } + + /** + * Find the partitions that have change based on the replica given. + * + * The changes identified are: + * 1. partitions for which the broker is not a replica anymore + * 2. partitions for which the broker is now the leader + * 3. partitions for which the broker is now a follower + * + * @param brokerId the broker id + * @return the list of partitions which the broker should remove, become leader or become follower. 
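+ *
+ * Illustrative use of the result (the accessors below are defined on LocalReplicaChanges
+ * elsewhere in this patch):
+ * <pre>{@code
+ *   LocalReplicaChanges changes = topicDelta.localChanges(brokerId);
+ *   changes.deletes();    // partitions this broker should stop replicating
+ *   changes.leaders();    // partitions this broker becomes leader for
+ *   changes.followers();  // partitions this broker should follow
+ * }</pre>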
+ */ + public LocalReplicaChanges localChanges(int brokerId) { + Set deletes = new HashSet<>(); + Map leaders = new HashMap<>(); + Map followers = new HashMap<>(); + + for (Entry entry : partitionChanges.entrySet()) { + if (!Replicas.contains(entry.getValue().replicas, brokerId)) { + PartitionRegistration prevPartition = image.partitions().get(entry.getKey()); + if (prevPartition != null && Replicas.contains(prevPartition.replicas, brokerId)) { + deletes.add(new TopicPartition(name(), entry.getKey())); + } + } else if (entry.getValue().leader == brokerId) { + PartitionRegistration prevPartition = image.partitions().get(entry.getKey()); + if (prevPartition == null || prevPartition.partitionEpoch != entry.getValue().partitionEpoch) { + leaders.put( + new TopicPartition(name(), entry.getKey()), + new LocalReplicaChanges.PartitionInfo(id(), entry.getValue()) + ); + } + } else if ( + entry.getValue().leader != brokerId && + Replicas.contains(entry.getValue().replicas, brokerId) + ) { + PartitionRegistration prevPartition = image.partitions().get(entry.getKey()); + if (prevPartition == null || prevPartition.partitionEpoch != entry.getValue().partitionEpoch) { + followers.put( + new TopicPartition(name(), entry.getKey()), + new LocalReplicaChanges.PartitionInfo(id(), entry.getValue()) + ); + } + } + } + + return new LocalReplicaChanges(deletes, leaders, followers); + } + + @Override + public String toString() { + return "TopicDelta(" + + "partitionChanges=" + partitionChanges + + ')'; + } +} diff --git a/metadata/src/main/java/org/apache/kafka/image/TopicImage.java b/metadata/src/main/java/org/apache/kafka/image/TopicImage.java new file mode 100644 index 0000000..b31ef94 --- /dev/null +++ b/metadata/src/main/java/org/apache/kafka/image/TopicImage.java @@ -0,0 +1,102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.image; + +import org.apache.kafka.common.Uuid; +import org.apache.kafka.common.metadata.TopicRecord; +import org.apache.kafka.metadata.PartitionRegistration; +import org.apache.kafka.server.common.ApiMessageAndVersion; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map.Entry; +import java.util.Map; +import java.util.Objects; +import java.util.function.Consumer; +import java.util.stream.Collectors; + +import static org.apache.kafka.common.metadata.MetadataRecordType.TOPIC_RECORD; + + +/** + * Represents a topic in the metadata image. + * + * This class is thread-safe. 
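+ *
+ * A brief lookup sketch (illustrative only):
+ * <pre>{@code
+ *   TopicImage topic = ...;
+ *   Uuid topicId = topic.id();
+ *   PartitionRegistration partition0 = topic.partitions().get(0);
+ * }</pre>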
+ */ +public final class TopicImage { + private final String name; + + private final Uuid id; + + private final Map partitions; + + public TopicImage(String name, + Uuid id, + Map partitions) { + this.name = name; + this.id = id; + this.partitions = partitions; + } + + public String name() { + return name; + } + + public Uuid id() { + return id; + } + + public Map partitions() { + return partitions; + } + + public void write(Consumer> out) { + List batch = new ArrayList<>(); + batch.add(new ApiMessageAndVersion(new TopicRecord(). + setName(name). + setTopicId(id), TOPIC_RECORD.highestSupportedVersion())); + for (Entry entry : partitions.entrySet()) { + int partitionId = entry.getKey(); + PartitionRegistration partition = entry.getValue(); + batch.add(partition.toRecord(id, partitionId)); + } + out.accept(batch); + } + + @Override + public boolean equals(Object o) { + if (!(o instanceof TopicImage)) return false; + TopicImage other = (TopicImage) o; + return name.equals(other.name) && + id.equals(other.id) && + partitions.equals(other.partitions); + } + + @Override + public int hashCode() { + return Objects.hash(name, id, partitions); + } + + @Override + public String toString() { + return "TopicImage(name=" + name + ", id=" + id + ", partitions=" + + partitions.entrySet().stream(). + map(e -> e.getKey() + ":" + e.getValue()). + collect(Collectors.joining(", ")) + ")"; + } +} diff --git a/metadata/src/main/java/org/apache/kafka/image/TopicsDelta.java b/metadata/src/main/java/org/apache/kafka/image/TopicsDelta.java new file mode 100644 index 0000000..f9d8087 --- /dev/null +++ b/metadata/src/main/java/org/apache/kafka/image/TopicsDelta.java @@ -0,0 +1,212 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.image; + +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.Uuid; +import org.apache.kafka.common.metadata.PartitionChangeRecord; +import org.apache.kafka.common.metadata.PartitionRecord; +import org.apache.kafka.common.metadata.RemoveTopicRecord; +import org.apache.kafka.common.metadata.TopicRecord; +import org.apache.kafka.metadata.Replicas; + +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; +import java.util.Map.Entry; +import java.util.Set; + + +/** + * Represents changes to the topics in the metadata image. + */ +public final class TopicsDelta { + private final TopicsImage image; + + /** + * A map from topic IDs to the topic deltas for each topic. Topics which have been + * deleted will not appear in this map. + */ + private final Map changedTopics = new HashMap<>(); + + /** + * The IDs of topics that exist in the image but that have been deleted. 
Note that if + * a topic does not exist in the image, it will also not exist in this set. Topics + * that are created and then deleted within the same delta will leave no trace. + */ + private final Set deletedTopicIds = new HashSet<>(); + + public TopicsDelta(TopicsImage image) { + this.image = image; + } + + public TopicsImage image() { + return image; + } + + public Map changedTopics() { + return changedTopics; + } + + public void replay(TopicRecord record) { + TopicDelta delta = new TopicDelta( + new TopicImage(record.name(), record.topicId(), Collections.emptyMap())); + changedTopics.put(record.topicId(), delta); + } + + TopicDelta getOrCreateTopicDelta(Uuid id) { + TopicDelta topicDelta = changedTopics.get(id); + if (topicDelta == null) { + topicDelta = new TopicDelta(image.getTopic(id)); + changedTopics.put(id, topicDelta); + } + return topicDelta; + } + + public void replay(PartitionRecord record) { + TopicDelta topicDelta = getOrCreateTopicDelta(record.topicId()); + topicDelta.replay(record); + } + + public void replay(PartitionChangeRecord record) { + TopicDelta topicDelta = getOrCreateTopicDelta(record.topicId()); + topicDelta.replay(record); + } + + public String replay(RemoveTopicRecord record) { + TopicDelta topicDelta = changedTopics.remove(record.topicId()); + String topicName; + if (topicDelta != null) { + topicName = topicDelta.image().name(); + if (image.topicsById().containsKey(record.topicId())) { + deletedTopicIds.add(record.topicId()); + } + } else { + TopicImage topicImage = image.getTopic(record.topicId()); + if (topicImage == null) { + throw new RuntimeException("Unable to delete topic with id " + + record.topicId() + ": no such topic found."); + } + topicName = topicImage.name(); + deletedTopicIds.add(record.topicId()); + } + return topicName; + } + + public void finishSnapshot() { + for (Uuid topicId : image.topicsById().keySet()) { + if (!changedTopics.containsKey(topicId)) { + deletedTopicIds.add(topicId); + } + } + } + + public TopicsImage apply() { + Map newTopicsById = new HashMap<>(image.topicsById().size()); + Map newTopicsByName = new HashMap<>(image.topicsByName().size()); + for (Entry entry : image.topicsById().entrySet()) { + Uuid id = entry.getKey(); + TopicImage prevTopicImage = entry.getValue(); + TopicDelta delta = changedTopics.get(id); + if (delta == null) { + if (!deletedTopicIds.contains(id)) { + newTopicsById.put(id, prevTopicImage); + newTopicsByName.put(prevTopicImage.name(), prevTopicImage); + } + } else { + TopicImage newTopicImage = delta.apply(); + newTopicsById.put(id, newTopicImage); + newTopicsByName.put(delta.name(), newTopicImage); + } + } + for (Entry entry : changedTopics.entrySet()) { + if (!newTopicsById.containsKey(entry.getKey())) { + TopicImage newTopicImage = entry.getValue().apply(); + newTopicsById.put(newTopicImage.id(), newTopicImage); + newTopicsByName.put(newTopicImage.name(), newTopicImage); + } + } + return new TopicsImage(newTopicsById, newTopicsByName); + } + + public TopicDelta changedTopic(Uuid topicId) { + return changedTopics.get(topicId); + } + + /** + * Returns true if the topic with the given name was deleted. Note: this will return + * true even if a new topic with the same name was subsequently created. 
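+ *
+ * Illustrative example (the record variables below are hypothetical): if topic "foo" exists in
+ * the base image and is removed and then re-created within this delta, the deletion is still
+ * reported, because the check is keyed on the original topic ID:
+ * <pre>{@code
+ *   topicsDelta.replay(removeTopicRecordForFoo);
+ *   topicsDelta.replay(topicRecordRecreatingFoo);
+ *   topicsDelta.topicWasDeleted("foo");  // true
+ * }</pre>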
+ */ + public boolean topicWasDeleted(String topicName) { + TopicImage topicImage = image.getTopic(topicName); + if (topicImage == null) { + return false; + } + return deletedTopicIds.contains(topicImage.id()); + } + + public Set deletedTopicIds() { + return deletedTopicIds; + } + + /** + * Find the topic partitions that have change based on the replica given. + * + * The changes identified are: + * 1. topic partitions for which the broker is not a replica anymore + * 2. topic partitions for which the broker is now the leader + * 3. topic partitions for which the broker is now a follower + * + * @param brokerId the broker id + * @return the list of topic partitions which the broker should remove, become leader or become follower. + */ + public LocalReplicaChanges localChanges(int brokerId) { + Set deletes = new HashSet<>(); + Map leaders = new HashMap<>(); + Map followers = new HashMap<>(); + + for (TopicDelta delta : changedTopics.values()) { + LocalReplicaChanges changes = delta.localChanges(brokerId); + + deletes.addAll(changes.deletes()); + leaders.putAll(changes.leaders()); + followers.putAll(changes.followers()); + } + + // Add all of the removed topic partitions to the set of locally removed partitions + deletedTopicIds().forEach(topicId -> { + TopicImage topicImage = image().getTopic(topicId); + topicImage.partitions().forEach((partitionId, prevPartition) -> { + if (Replicas.contains(prevPartition.replicas, brokerId)) { + deletes.add(new TopicPartition(topicImage.name(), partitionId)); + } + }); + }); + + return new LocalReplicaChanges(deletes, leaders, followers); + } + + @Override + public String toString() { + return "TopicsDelta(" + + "changedTopics=" + changedTopics + + ", deletedTopicIds=" + deletedTopicIds + + ')'; + } +} diff --git a/metadata/src/main/java/org/apache/kafka/image/TopicsImage.java b/metadata/src/main/java/org/apache/kafka/image/TopicsImage.java new file mode 100644 index 0000000..c0b218b --- /dev/null +++ b/metadata/src/main/java/org/apache/kafka/image/TopicsImage.java @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.image; + +import org.apache.kafka.common.Uuid; +import org.apache.kafka.metadata.PartitionRegistration; +import org.apache.kafka.server.common.ApiMessageAndVersion; +import org.apache.kafka.server.util.TranslatedValueMapView; + +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.function.Consumer; +import java.util.stream.Collectors; + + +/** + * Represents the topics in the metadata image. + * + * This class is thread-safe. 
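+ *
+ * A lookup sketch (illustrative only):
+ * <pre>{@code
+ *   TopicImage byName = topicsImage.getTopic("foo");
+ *   TopicImage byId = topicsImage.getTopic(topicId);
+ *   Uuid id = topicsImage.topicNameToIdView().get("foo");
+ * }</pre>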
+ */ +public final class TopicsImage { + public static final TopicsImage EMPTY = + new TopicsImage(Collections.emptyMap(), Collections.emptyMap()); + + private final Map topicsById; + private final Map topicsByName; + + public TopicsImage(Map topicsById, + Map topicsByName) { + this.topicsById = Collections.unmodifiableMap(topicsById); + this.topicsByName = Collections.unmodifiableMap(topicsByName); + } + + public boolean isEmpty() { + return topicsById.isEmpty() && topicsByName.isEmpty(); + } + + public Map topicsById() { + return topicsById; + } + + public Map topicsByName() { + return topicsByName; + } + + public PartitionRegistration getPartition(Uuid id, int partitionId) { + TopicImage topicImage = topicsById.get(id); + if (topicImage == null) return null; + return topicImage.partitions().get(partitionId); + } + + public TopicImage getTopic(Uuid id) { + return topicsById.get(id); + } + + public TopicImage getTopic(String name) { + return topicsByName.get(name); + } + + public void write(Consumer> out) { + for (TopicImage topicImage : topicsById.values()) { + topicImage.write(out); + } + } + + @Override + public boolean equals(Object o) { + if (!(o instanceof TopicsImage)) return false; + TopicsImage other = (TopicsImage) o; + return topicsById.equals(other.topicsById) && + topicsByName.equals(other.topicsByName); + } + + @Override + public int hashCode() { + return Objects.hash(topicsById, topicsByName); + } + + /** + * Expose a view of this TopicsImage as a map from topic names to IDs. + * + * Like TopicsImage itself, this map is immutable. + */ + public Map topicNameToIdView() { + return new TranslatedValueMapView<>(topicsByName, image -> image.id()); + } + + /** + * Expose a view of this TopicsImage as a map from IDs to names. + * + * Like TopicsImage itself, this map is immutable. + */ + public Map topicIdToNameView() { + return new TranslatedValueMapView<>(topicsById, image -> image.name()); + } + + @Override + public String toString() { + return "TopicsImage(topicsById=" + topicsById.entrySet().stream(). + map(e -> e.getKey() + ":" + e.getValue()).collect(Collectors.joining(", ")) + + ", topicsByName=" + topicsByName.entrySet().stream(). + map(e -> e.getKey() + ":" + e.getValue()).collect(Collectors.joining(", ")) + + ")"; + } +} diff --git a/metadata/src/main/java/org/apache/kafka/metadata/BrokerHeartbeatReply.java b/metadata/src/main/java/org/apache/kafka/metadata/BrokerHeartbeatReply.java new file mode 100644 index 0000000..c936601 --- /dev/null +++ b/metadata/src/main/java/org/apache/kafka/metadata/BrokerHeartbeatReply.java @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.metadata; + +import java.util.Objects; + + +public class BrokerHeartbeatReply { + /** + * True if the heartbeat reply should tell the broker that it has caught up. + */ + private final boolean isCaughtUp; + + /** + * True if the heartbeat reply should tell the broker that it is fenced. + */ + private final boolean isFenced; + + /** + * True if the broker is currently in a controlled shutdown state. + */ + private final boolean inControlledShutdown; + + /** + * True if the heartbeat reply should tell the broker that it should shut down. + */ + private final boolean shouldShutDown; + + public BrokerHeartbeatReply(boolean isCaughtUp, + boolean isFenced, + boolean inControlledShutdown, + boolean shouldShutDown) { + this.isCaughtUp = isCaughtUp; + this.isFenced = isFenced; + this.inControlledShutdown = inControlledShutdown; + this.shouldShutDown = shouldShutDown; + } + + public boolean isCaughtUp() { + return isCaughtUp; + } + + public boolean isFenced() { + return isFenced; + } + + public boolean inControlledShutdown() { + return inControlledShutdown; + } + + public boolean shouldShutDown() { + return shouldShutDown; + } + + @Override + public int hashCode() { + return Objects.hash(isCaughtUp, isFenced, inControlledShutdown, shouldShutDown); + } + + @Override + public boolean equals(Object o) { + if (!(o instanceof BrokerHeartbeatReply)) return false; + BrokerHeartbeatReply other = (BrokerHeartbeatReply) o; + return other.isCaughtUp == isCaughtUp && + other.isFenced == isFenced && + other.inControlledShutdown == inControlledShutdown && + other.shouldShutDown == shouldShutDown; + } + + @Override + public String toString() { + return "BrokerHeartbeatReply(isCaughtUp=" + isCaughtUp + + ", isFenced=" + isFenced + + ", inControlledShutdown=" + inControlledShutdown + + ", shouldShutDown = " + shouldShutDown + + ")"; + } +} diff --git a/metadata/src/main/java/org/apache/kafka/metadata/BrokerRegistration.java b/metadata/src/main/java/org/apache/kafka/metadata/BrokerRegistration.java new file mode 100644 index 0000000..fd5eb65 --- /dev/null +++ b/metadata/src/main/java/org/apache/kafka/metadata/BrokerRegistration.java @@ -0,0 +1,224 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.metadata; + +import org.apache.kafka.common.Endpoint; +import org.apache.kafka.common.Node; +import org.apache.kafka.common.Uuid; +import org.apache.kafka.common.metadata.RegisterBrokerRecord; +import org.apache.kafka.common.metadata.RegisterBrokerRecord.BrokerEndpoint; +import org.apache.kafka.common.metadata.RegisterBrokerRecord.BrokerFeature; +import org.apache.kafka.common.security.auth.SecurityProtocol; +import org.apache.kafka.server.common.ApiMessageAndVersion; + +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; +import java.util.Objects; +import java.util.Optional; +import java.util.stream.Collectors; + +import static org.apache.kafka.common.metadata.MetadataRecordType.REGISTER_BROKER_RECORD; + + +/** + * An immutable class which represents broker registrations. + */ +public class BrokerRegistration { + private static Map listenersToMap(Collection listeners) { + Map listenersMap = new HashMap<>(); + for (Endpoint endpoint : listeners) { + listenersMap.put(endpoint.listenerName().get(), endpoint); + } + return listenersMap; + } + + private final int id; + private final long epoch; + private final Uuid incarnationId; + private final Map listeners; + private final Map supportedFeatures; + private final Optional rack; + private final boolean fenced; + + public BrokerRegistration(int id, + long epoch, + Uuid incarnationId, + List listeners, + Map supportedFeatures, + Optional rack, + boolean fenced) { + this(id, epoch, incarnationId, listenersToMap(listeners), supportedFeatures, rack, fenced); + } + + public BrokerRegistration(int id, + long epoch, + Uuid incarnationId, + Map listeners, + Map supportedFeatures, + Optional rack, + boolean fenced) { + this.id = id; + this.epoch = epoch; + this.incarnationId = incarnationId; + Map newListeners = new HashMap<>(listeners.size()); + for (Entry entry : listeners.entrySet()) { + if (!entry.getValue().listenerName().isPresent()) { + throw new IllegalArgumentException("Broker listeners must be named."); + } + newListeners.put(entry.getKey(), entry.getValue()); + } + this.listeners = Collections.unmodifiableMap(newListeners); + Objects.requireNonNull(supportedFeatures); + this.supportedFeatures = new HashMap<>(supportedFeatures); + Objects.requireNonNull(rack); + this.rack = rack; + this.fenced = fenced; + } + + public static BrokerRegistration fromRecord(RegisterBrokerRecord record) { + Map listeners = new HashMap<>(); + for (BrokerEndpoint endpoint : record.endPoints()) { + listeners.put(endpoint.name(), new Endpoint(endpoint.name(), + SecurityProtocol.forId(endpoint.securityProtocol()), + endpoint.host(), + endpoint.port())); + } + Map supportedFeatures = new HashMap<>(); + for (BrokerFeature feature : record.features()) { + supportedFeatures.put(feature.name(), new VersionRange( + feature.minSupportedVersion(), feature.maxSupportedVersion())); + } + return new BrokerRegistration(record.brokerId(), + record.brokerEpoch(), + record.incarnationId(), + listeners, + supportedFeatures, + Optional.ofNullable(record.rack()), + record.fenced()); + } + + public int id() { + return id; + } + + public long epoch() { + return epoch; + } + + public Uuid incarnationId() { + return incarnationId; + } + + public Map listeners() { + return listeners; + } + + public Optional node(String listenerName) { + Endpoint endpoint = listeners().get(listenerName); + if (endpoint == null) { + return Optional.empty(); + } + 
return Optional.of(new Node(id, endpoint.host(), endpoint.port(), rack.orElse(null))); + } + + public Map supportedFeatures() { + return supportedFeatures; + } + + public Optional rack() { + return rack; + } + + public boolean fenced() { + return fenced; + } + + public ApiMessageAndVersion toRecord() { + RegisterBrokerRecord registrationRecord = new RegisterBrokerRecord(). + setBrokerId(id). + setRack(rack.orElse(null)). + setBrokerEpoch(epoch). + setIncarnationId(incarnationId). + setFenced(fenced); + for (Entry entry : listeners.entrySet()) { + Endpoint endpoint = entry.getValue(); + registrationRecord.endPoints().add(new BrokerEndpoint(). + setName(entry.getKey()). + setHost(endpoint.host()). + setPort(endpoint.port()). + setSecurityProtocol(endpoint.securityProtocol().id)); + } + for (Entry entry : supportedFeatures.entrySet()) { + registrationRecord.features().add(new BrokerFeature(). + setName(entry.getKey()). + setMinSupportedVersion(entry.getValue().min()). + setMaxSupportedVersion(entry.getValue().max())); + } + return new ApiMessageAndVersion(registrationRecord, + REGISTER_BROKER_RECORD.highestSupportedVersion()); + } + + @Override + public int hashCode() { + return Objects.hash(id, epoch, incarnationId, listeners, supportedFeatures, + rack, fenced); + } + + @Override + public boolean equals(Object o) { + if (!(o instanceof BrokerRegistration)) return false; + BrokerRegistration other = (BrokerRegistration) o; + return other.id == id && + other.epoch == epoch && + other.incarnationId.equals(incarnationId) && + other.listeners.equals(listeners) && + other.supportedFeatures.equals(supportedFeatures) && + other.rack.equals(rack) && + other.fenced == fenced; + } + + @Override + public String toString() { + StringBuilder bld = new StringBuilder(); + bld.append("BrokerRegistration(id=").append(id); + bld.append(", epoch=").append(epoch); + bld.append(", incarnationId=").append(incarnationId); + bld.append(", listeners=[").append( + listeners.keySet().stream().sorted(). + map(n -> listeners.get(n).toString()). + collect(Collectors.joining(", "))); + bld.append("], supportedFeatures={").append( + supportedFeatures.entrySet().stream().sorted(). + map(e -> e.getKey() + ": " + e.getValue()). + collect(Collectors.joining(", "))); + bld.append("}"); + bld.append(", rack=").append(rack); + bld.append(", fenced=").append(fenced); + bld.append(")"); + return bld.toString(); + } + + public BrokerRegistration cloneWithFencing(boolean fencing) { + return new BrokerRegistration(id, epoch, incarnationId, listeners, + supportedFeatures, rack, fencing); + } +} diff --git a/metadata/src/main/java/org/apache/kafka/metadata/BrokerRegistrationReply.java b/metadata/src/main/java/org/apache/kafka/metadata/BrokerRegistrationReply.java new file mode 100644 index 0000000..40678ed --- /dev/null +++ b/metadata/src/main/java/org/apache/kafka/metadata/BrokerRegistrationReply.java @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.metadata; + +import java.util.Objects; + + +public class BrokerRegistrationReply { + private final long epoch; + + public BrokerRegistrationReply(long epoch) { + this.epoch = epoch; + } + + public long epoch() { + return epoch; + } + + @Override + public int hashCode() { + return Objects.hash(epoch); + } + + @Override + public boolean equals(Object o) { + if (!(o instanceof BrokerRegistrationReply)) return false; + BrokerRegistrationReply other = (BrokerRegistrationReply) o; + return other.epoch == epoch; + } + + @Override + public String toString() { + return "BrokerRegistrationReply(epoch=" + epoch + ")"; + } +} diff --git a/metadata/src/main/java/org/apache/kafka/metadata/BrokerState.java b/metadata/src/main/java/org/apache/kafka/metadata/BrokerState.java new file mode 100644 index 0000000..82f1215 --- /dev/null +++ b/metadata/src/main/java/org/apache/kafka/metadata/BrokerState.java @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.metadata; + +import org.apache.kafka.common.annotation.InterfaceStability; + +import java.util.HashMap; +import java.util.Map; + +/** + * The broker state. + * + * The numeric values used here are part of Kafka's public API. They appear in metrics, + * and are also sent over the wire in some cases. + * + * The expected state transitions are: + * + * NOT_RUNNING + * ↓ + * STARTING + * ↓ + * RECOVERY + * ↓ + * RUNNING + * ↓ + * PENDING_CONTROLLED_SHUTDOWN + * ↓ + * SHUTTING_DOWN + */ +@InterfaceStability.Evolving +public enum BrokerState { + /** + * The state the broker is in when it first starts up. + */ + NOT_RUNNING((byte) 0), + + /** + * The state the broker is in when it is catching up with cluster metadata. + */ + STARTING((byte) 1), + + /** + * The broker has caught up with cluster metadata, but has not yet + * been unfenced by the controller. + */ + RECOVERY((byte) 2), + + /** + * The state the broker is in when it has registered at least once, and is + * accepting client requests. + */ + RUNNING((byte) 3), + + /** + * The state the broker is in when it is attempting to perform a controlled + * shutdown. + */ + PENDING_CONTROLLED_SHUTDOWN((byte) 6), + + /** + * The state the broker is in when it is shutting down. + */ + SHUTTING_DOWN((byte) 7), + + /** + * The broker is in an unknown state. 
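+ *
+ * {@code fromValue} maps any unrecognized byte to this constant; for example,
+ * {@code BrokerState.fromValue((byte) 42)} returns {@code UNKNOWN}.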
+ */ + UNKNOWN((byte) 127); + + private final static Map VALUES_TO_ENUMS = new HashMap<>(); + + static { + for (BrokerState state : BrokerState.values()) { + VALUES_TO_ENUMS.put(state.value(), state); + } + } + + private final byte value; + + BrokerState(byte value) { + this.value = value; + } + + public static BrokerState fromValue(byte value) { + BrokerState state = VALUES_TO_ENUMS.get(value); + if (state == null) { + return UNKNOWN; + } + return state; + } + + public byte value() { + return value; + } +} diff --git a/metadata/src/main/java/org/apache/kafka/metadata/FeatureMap.java b/metadata/src/main/java/org/apache/kafka/metadata/FeatureMap.java new file mode 100644 index 0000000..272c87d --- /dev/null +++ b/metadata/src/main/java/org/apache/kafka/metadata/FeatureMap.java @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.metadata; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.Optional; +import java.util.stream.Collectors; + + +/** + * A map of feature names to their supported versions. + */ +public class FeatureMap { + private final Map features; + + public FeatureMap(Map features) { + this.features = Collections.unmodifiableMap(new HashMap<>(features)); + } + + public Optional get(String name) { + return Optional.ofNullable(features.get(name)); + } + + public Map features() { + return features; + } + + @Override + public int hashCode() { + return features.hashCode(); + } + + @Override + public boolean equals(Object o) { + if (!(o instanceof FeatureMap)) return false; + FeatureMap other = (FeatureMap) o; + return features.equals(other.features); + } + + @Override + public String toString() { + StringBuilder bld = new StringBuilder(); + bld.append("{"); + bld.append(features.keySet().stream().sorted(). + map(k -> k + ": " + features.get(k)). + collect(Collectors.joining(", "))); + bld.append("}"); + return bld.toString(); + } +} diff --git a/metadata/src/main/java/org/apache/kafka/metadata/FeatureMapAndEpoch.java b/metadata/src/main/java/org/apache/kafka/metadata/FeatureMapAndEpoch.java new file mode 100644 index 0000000..26096ea --- /dev/null +++ b/metadata/src/main/java/org/apache/kafka/metadata/FeatureMapAndEpoch.java @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.metadata; + +import java.util.Objects; + + +/** + * A map of feature names to their supported versions. + */ +public class FeatureMapAndEpoch { + private final FeatureMap map; + private final long epoch; + + public FeatureMapAndEpoch(FeatureMap map, long epoch) { + this.map = map; + this.epoch = epoch; + } + + public FeatureMap map() { + return map; + } + + public long epoch() { + return epoch; + } + + @Override + public int hashCode() { + return Objects.hash(map, epoch); + } + + @Override + public boolean equals(Object o) { + if (!(o instanceof FeatureMapAndEpoch)) return false; + FeatureMapAndEpoch other = (FeatureMapAndEpoch) o; + return map.equals(other.map) && epoch == other.epoch; + } + + @Override + public String toString() { + StringBuilder bld = new StringBuilder(); + bld.append("{"); + bld.append("map=").append(map.toString()); + bld.append(", epoch=").append(epoch); + bld.append("}"); + return bld.toString(); + } +} diff --git a/metadata/src/main/java/org/apache/kafka/metadata/LeaderConstants.java b/metadata/src/main/java/org/apache/kafka/metadata/LeaderConstants.java new file mode 100644 index 0000000..ad23931 --- /dev/null +++ b/metadata/src/main/java/org/apache/kafka/metadata/LeaderConstants.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.metadata; + + +public class LeaderConstants { + /** + * A special value used to represent the leader for a partition with no leader. + */ + public static final int NO_LEADER = -1; + + /** + * A special value used to represent a PartitionChangeRecord that does not change the + * partition leader. + */ + public static final int NO_LEADER_CHANGE = -2; +} diff --git a/metadata/src/main/java/org/apache/kafka/metadata/MetadataRecordSerde.java b/metadata/src/main/java/org/apache/kafka/metadata/MetadataRecordSerde.java new file mode 100644 index 0000000..7964fed --- /dev/null +++ b/metadata/src/main/java/org/apache/kafka/metadata/MetadataRecordSerde.java @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.metadata; + +import org.apache.kafka.common.metadata.MetadataRecordType; +import org.apache.kafka.common.protocol.ApiMessage; +import org.apache.kafka.server.common.serialization.AbstractApiMessageSerde; + +public class MetadataRecordSerde extends AbstractApiMessageSerde { + public static final MetadataRecordSerde INSTANCE = new MetadataRecordSerde(); + + @Override + public ApiMessage apiMessageFor(short apiKey) { + return MetadataRecordType.fromId(apiKey).newMetadataRecord(); + } +} diff --git a/metadata/src/main/java/org/apache/kafka/metadata/OptionalStringComparator.java b/metadata/src/main/java/org/apache/kafka/metadata/OptionalStringComparator.java new file mode 100644 index 0000000..3f20507 --- /dev/null +++ b/metadata/src/main/java/org/apache/kafka/metadata/OptionalStringComparator.java @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.metadata; + +import java.util.Comparator; +import java.util.Optional; + + +public class OptionalStringComparator implements Comparator> { + public static final OptionalStringComparator INSTANCE = new OptionalStringComparator(); + + @Override + public int compare(Optional a, Optional b) { + if (!a.isPresent()) { + if (!b.isPresent()) { + return 0; + } else { + return -1; + } + } else if (!b.isPresent()) { + return 1; + } + return a.get().compareTo(b.get()); + } +} diff --git a/metadata/src/main/java/org/apache/kafka/metadata/PartitionRegistration.java b/metadata/src/main/java/org/apache/kafka/metadata/PartitionRegistration.java new file mode 100644 index 0000000..933bda9 --- /dev/null +++ b/metadata/src/main/java/org/apache/kafka/metadata/PartitionRegistration.java @@ -0,0 +1,229 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.metadata; + +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.Uuid; +import org.apache.kafka.common.message.LeaderAndIsrRequestData.LeaderAndIsrPartitionState; +import org.apache.kafka.common.metadata.PartitionChangeRecord; +import org.apache.kafka.common.metadata.PartitionRecord; +import org.apache.kafka.server.common.ApiMessageAndVersion; +import org.slf4j.Logger; + +import java.util.Arrays; +import java.util.Objects; + +import static org.apache.kafka.common.metadata.MetadataRecordType.PARTITION_RECORD; +import static org.apache.kafka.metadata.LeaderConstants.NO_LEADER; +import static org.apache.kafka.metadata.LeaderConstants.NO_LEADER_CHANGE; + + +public class PartitionRegistration { + public final int[] replicas; + public final int[] isr; + public final int[] removingReplicas; + public final int[] addingReplicas; + public final int leader; + public final int leaderEpoch; + public final int partitionEpoch; + + public static boolean electionWasClean(int newLeader, int[] isr) { + return newLeader == NO_LEADER || Replicas.contains(isr, newLeader); + } + + public PartitionRegistration(PartitionRecord record) { + this(Replicas.toArray(record.replicas()), + Replicas.toArray(record.isr()), + Replicas.toArray(record.removingReplicas()), + Replicas.toArray(record.addingReplicas()), + record.leader(), + record.leaderEpoch(), + record.partitionEpoch()); + } + + public PartitionRegistration(int[] replicas, int[] isr, int[] removingReplicas, + int[] addingReplicas, int leader, int leaderEpoch, + int partitionEpoch) { + this.replicas = replicas; + this.isr = isr; + this.removingReplicas = removingReplicas; + this.addingReplicas = addingReplicas; + this.leader = leader; + this.leaderEpoch = leaderEpoch; + this.partitionEpoch = partitionEpoch; + } + + public PartitionRegistration merge(PartitionChangeRecord record) { + int[] newReplicas = (record.replicas() == null) ? + replicas : Replicas.toArray(record.replicas()); + int[] newIsr = (record.isr() == null) ? isr : Replicas.toArray(record.isr()); + int[] newRemovingReplicas = (record.removingReplicas() == null) ? + removingReplicas : Replicas.toArray(record.removingReplicas()); + int[] newAddingReplicas = (record.addingReplicas() == null) ? + addingReplicas : Replicas.toArray(record.addingReplicas()); + int newLeader; + int newLeaderEpoch; + if (record.leader() == NO_LEADER_CHANGE) { + newLeader = leader; + newLeaderEpoch = leaderEpoch; + } else { + newLeader = record.leader(); + newLeaderEpoch = leaderEpoch + 1; + } + return new PartitionRegistration(newReplicas, + newIsr, + newRemovingReplicas, + newAddingReplicas, + newLeader, + newLeaderEpoch, + partitionEpoch + 1); + } + + public String diff(PartitionRegistration prev) { + StringBuilder builder = new StringBuilder(); + String prefix = ""; + if (!Arrays.equals(replicas, prev.replicas)) { + builder.append(prefix).append("replicas: "). + append(Arrays.toString(prev.replicas)). 
+ append(" -> ").append(Arrays.toString(replicas)); + prefix = ", "; + } + if (!Arrays.equals(isr, prev.isr)) { + builder.append(prefix).append("isr: "). + append(Arrays.toString(prev.isr)). + append(" -> ").append(Arrays.toString(isr)); + prefix = ", "; + } + if (!Arrays.equals(removingReplicas, prev.removingReplicas)) { + builder.append(prefix).append("removingReplicas: "). + append(Arrays.toString(prev.removingReplicas)). + append(" -> ").append(Arrays.toString(removingReplicas)); + prefix = ", "; + } + if (!Arrays.equals(addingReplicas, prev.addingReplicas)) { + builder.append(prefix).append("addingReplicas: "). + append(Arrays.toString(prev.addingReplicas)). + append(" -> ").append(Arrays.toString(addingReplicas)); + prefix = ", "; + } + if (leader != prev.leader) { + builder.append(prefix).append("leader: "). + append(prev.leader).append(" -> ").append(leader); + prefix = ", "; + } + if (leaderEpoch != prev.leaderEpoch) { + builder.append(prefix).append("leaderEpoch: "). + append(prev.leaderEpoch).append(" -> ").append(leaderEpoch); + prefix = ", "; + } + if (partitionEpoch != prev.partitionEpoch) { + builder.append(prefix).append("partitionEpoch: "). + append(prev.partitionEpoch).append(" -> ").append(partitionEpoch); + } + return builder.toString(); + } + + public void maybeLogPartitionChange(Logger log, String description, PartitionRegistration prev) { + if (!electionWasClean(leader, prev.isr)) { + log.info("UNCLEAN partition change for {}: {}", description, diff(prev)); + } else if (log.isDebugEnabled()) { + log.debug("partition change for {}: {}", description, diff(prev)); + } + } + + public boolean hasLeader() { + return leader != LeaderConstants.NO_LEADER; + } + + public boolean hasPreferredLeader() { + return leader == preferredReplica(); + } + + public int preferredReplica() { + return replicas.length == 0 ? LeaderConstants.NO_LEADER : replicas[0]; + } + + public ApiMessageAndVersion toRecord(Uuid topicId, int partitionId) { + return new ApiMessageAndVersion(new PartitionRecord(). + setPartitionId(partitionId). + setTopicId(topicId). + setReplicas(Replicas.toList(replicas)). + setIsr(Replicas.toList(isr)). + setRemovingReplicas(Replicas.toList(removingReplicas)). + setAddingReplicas(Replicas.toList(addingReplicas)). + setLeader(leader). + setLeaderEpoch(leaderEpoch). + setPartitionEpoch(partitionEpoch), PARTITION_RECORD.highestSupportedVersion()); + } + + public LeaderAndIsrPartitionState toLeaderAndIsrPartitionState(TopicPartition tp, + boolean isNew) { + return new LeaderAndIsrPartitionState(). + setTopicName(tp.topic()). + setPartitionIndex(tp.partition()). + setControllerEpoch(-1). + setLeader(leader). + setLeaderEpoch(leaderEpoch). + setIsr(Replicas.toList(isr)). + setZkVersion(partitionEpoch). + setReplicas(Replicas.toList(replicas)). + setAddingReplicas(Replicas.toList(addingReplicas)). + setRemovingReplicas(Replicas.toList(removingReplicas)). + setIsNew(isNew); + } + + /** + * Returns true if this partition is reassigning. 
+ */ + public boolean isReassigning() { + return removingReplicas.length > 0 || addingReplicas.length > 0; + } + + @Override + public int hashCode() { + return Objects.hash(replicas, isr, removingReplicas, addingReplicas, leader, + leaderEpoch, partitionEpoch); + } + + @Override + public boolean equals(Object o) { + if (!(o instanceof PartitionRegistration)) return false; + PartitionRegistration other = (PartitionRegistration) o; + return Arrays.equals(replicas, other.replicas) && + Arrays.equals(isr, other.isr) && + Arrays.equals(removingReplicas, other.removingReplicas) && + Arrays.equals(addingReplicas, other.addingReplicas) && + leader == other.leader && + leaderEpoch == other.leaderEpoch && + partitionEpoch == other.partitionEpoch; + } + + @Override + public String toString() { + StringBuilder builder = new StringBuilder("PartitionRegistration("); + builder.append("replicas=").append(Arrays.toString(replicas)); + builder.append(", isr=").append(Arrays.toString(isr)); + builder.append(", removingReplicas=").append(Arrays.toString(removingReplicas)); + builder.append(", addingReplicas=").append(Arrays.toString(addingReplicas)); + builder.append(", leader=").append(leader); + builder.append(", leaderEpoch=").append(leaderEpoch); + builder.append(", partitionEpoch=").append(partitionEpoch); + builder.append(")"); + return builder.toString(); + } +} diff --git a/metadata/src/main/java/org/apache/kafka/metadata/Replicas.java b/metadata/src/main/java/org/apache/kafka/metadata/Replicas.java new file mode 100644 index 0000000..fa5ef4b --- /dev/null +++ b/metadata/src/main/java/org/apache/kafka/metadata/Replicas.java @@ -0,0 +1,248 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.metadata; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashSet; +import java.util.List; +import java.util.Set; + + +public class Replicas { + /** + * An empty replica array. + */ + public final static int[] NONE = new int[0]; + + /** + * Convert an array of integers to a list of ints. + * + * @param array The input array. + * @return The output list. + */ + public static List toList(int[] array) { + if (array == null) return null; + ArrayList list = new ArrayList<>(array.length); + for (int i = 0; i < array.length; i++) { + list.add(array[i]); + } + return list; + } + + /** + * Convert a list of integers to an array of ints. + * + * @param list The input list. + * @return The output array. + */ + public static int[] toArray(List list) { + if (list == null) return null; + int[] array = new int[list.size()]; + for (int i = 0; i < list.size(); i++) { + array[i] = list.get(i); + } + return array; + } + + /** + * Copy an array of ints. + * + * @param array The input array. + * @return A copy of the array. 
+ */ + public static int[] clone(int[] array) { + int[] clone = new int[array.length]; + System.arraycopy(array, 0, clone, 0, array.length); + return clone; + } + + /** + * Check that a replica set is valid. + * + * @param replicas The replica set. + * @return True if none of the replicas are negative, and there are no + * duplicates. + */ + public static boolean validate(int[] replicas) { + if (replicas.length == 0) return true; + int[] sortedReplicas = clone(replicas); + Arrays.sort(sortedReplicas); + int prev = sortedReplicas[0]; + if (prev < 0) return false; + for (int i = 1; i < sortedReplicas.length; i++) { + int replica = sortedReplicas[i]; + if (prev == replica) return false; + prev = replica; + } + return true; + } + + /** + * Check that an isr set is valid. + * + * @param replicas The replica set. + * @param isr The in-sync replica set. + * @return True if none of the in-sync replicas are negative, there are + * no duplicates, and all in-sync replicas are also replicas. + */ + public static boolean validateIsr(int[] replicas, int[] isr) { + if (isr.length == 0) return true; + if (replicas.length == 0) return false; + int[] sortedReplicas = clone(replicas); + Arrays.sort(sortedReplicas); + int[] sortedIsr = clone(isr); + Arrays.sort(sortedIsr); + int j = 0; + if (sortedIsr[0] < 0) return false; + int prevIsr = -1; + for (int i = 0; i < sortedIsr.length; i++) { + int curIsr = sortedIsr[i]; + if (prevIsr == curIsr) return false; + prevIsr = curIsr; + while (true) { + if (j == sortedReplicas.length) return false; + int curReplica = sortedReplicas[j++]; + if (curReplica == curIsr) break; + } + } + return true; + } + + /** + * Returns true if an array of replicas contains a specific value. + * + * @param replicas The replica array. + * @param value The value to look for. + * + * @return True only if the value is found in the array. + */ + public static boolean contains(int[] replicas, int value) { + for (int i = 0; i < replicas.length; i++) { + if (replicas[i] == value) return true; + } + return false; + } + + /** + * Check if the first list of integers contains the second. + * + * @param a The first list + * @param b The second list + * + * @return True only if the first contains the second. + */ + public static boolean contains(List a, int[] b) { + List aSorted = new ArrayList<>(a); + aSorted.sort(Integer::compareTo); + List bSorted = Replicas.toList(b); + bSorted.sort(Integer::compareTo); + int i = 0; + for (int replica : bSorted) { + while (true) { + if (i >= aSorted.size()) return false; + int replica2 = aSorted.get(i++); + if (replica2 == replica) break; + if (replica2 > replica) return false; + } + } + return true; + } + + /** + * Copy a replica array without any occurrences of the given value. + * + * @param replicas The replica array. + * @param value The value to filter out. + * + * @return A new array without the given value. + */ + public static int[] copyWithout(int[] replicas, int value) { + int size = 0; + for (int i = 0; i < replicas.length; i++) { + if (replicas[i] != value) { + size++; + } + } + int[] result = new int[size]; + int j = 0; + for (int i = 0; i < replicas.length; i++) { + int replica = replicas[i]; + if (replica != value) { + result[j++] = replica; + } + } + return result; + } + + /** + * Copy a replica array without any occurrences of the given values. + * + * @param replicas The replica array. + * @param values The values to filter out. + * + * @return A new array without the given value. 
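+ *
+ * For example, {@code copyWithout(new int[] {1, 2, 3, 4}, new int[] {2, 4})} returns
+ * {@code {1, 3}}.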
+ */ + public static int[] copyWithout(int[] replicas, int[] values) { + int size = 0; + for (int i = 0; i < replicas.length; i++) { + if (!Replicas.contains(values, replicas[i])) { + size++; + } + } + int[] result = new int[size]; + int j = 0; + for (int i = 0; i < replicas.length; i++) { + int replica = replicas[i]; + if (!Replicas.contains(values, replica)) { + result[j++] = replica; + } + } + return result; + } + + /** + * Copy a replica array with the given value. + * + * @param replicas The replica array. + * @param value The value to add. + * + * @return A new array with the given value. + */ + public static int[] copyWith(int[] replicas, int value) { + int[] newReplicas = new int[replicas.length + 1]; + System.arraycopy(replicas, 0, newReplicas, 0, replicas.length); + newReplicas[newReplicas.length - 1] = value; + return newReplicas; + } + + /** + * Convert a replica array to a set. + * + * @param replicas The replica array. + * + * @return A new array with the given value. + */ + public static Set toSet(int[] replicas) { + Set result = new HashSet<>(); + for (int replica : replicas) { + result.add(replica); + } + return result; + } +} diff --git a/metadata/src/main/java/org/apache/kafka/metadata/UsableBroker.java b/metadata/src/main/java/org/apache/kafka/metadata/UsableBroker.java new file mode 100644 index 0000000..9c04ebd --- /dev/null +++ b/metadata/src/main/java/org/apache/kafka/metadata/UsableBroker.java @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.metadata; + +import java.util.Objects; +import java.util.Optional; + + +/** + * A broker where a replica can be placed. 
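+ *
+ * A construction sketch (illustrative only):
+ * <pre>{@code
+ *   UsableBroker broker = new UsableBroker(1, Optional.of("rack-a"), false);
+ *   broker.rack();    // Optional[rack-a]
+ *   broker.fenced();  // false
+ * }</pre>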
+ */ +public class UsableBroker { + private final int id; + + private final Optional rack; + + private final boolean fenced; + + public UsableBroker(int id, Optional rack, boolean fenced) { + this.id = id; + this.rack = rack; + this.fenced = fenced; + } + + public int id() { + return id; + } + + public Optional rack() { + return rack; + } + + public boolean fenced() { + return fenced; + } + + @Override + public boolean equals(Object o) { + if (!(o instanceof UsableBroker)) return false; + UsableBroker other = (UsableBroker) o; + return other.id == id && other.rack.equals(rack) && other.fenced == fenced; + } + + @Override + public int hashCode() { + return Objects.hash(id, rack, fenced); + } + + @Override + public String toString() { + return "UsableBroker(id=" + id + ", rack=" + rack + ", fenced=" + fenced + ")"; + } +} diff --git a/metadata/src/main/java/org/apache/kafka/metadata/VersionRange.java b/metadata/src/main/java/org/apache/kafka/metadata/VersionRange.java new file mode 100644 index 0000000..f171ea1 --- /dev/null +++ b/metadata/src/main/java/org/apache/kafka/metadata/VersionRange.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.metadata; + +import java.util.Objects; + +/** + * An immutable class which represents version ranges. + */ +public class VersionRange { + public final static VersionRange ALL = new VersionRange((short) 0, Short.MAX_VALUE); + + private final short min; + private final short max; + + public VersionRange(short min, short max) { + this.min = min; + this.max = max; + } + + public short min() { + return min; + } + + public short max() { + return max; + } + + public boolean contains(VersionRange other) { + return other.min >= min && other.max <= max; + } + + @Override + public int hashCode() { + return Objects.hash(min, max); + } + + @Override + public boolean equals(Object o) { + if (!(o instanceof VersionRange)) return false; + VersionRange other = (VersionRange) o; + return other.min == min && other.max == max; + } + + @Override + public String toString() { + if (min == max) { + return String.valueOf(min); + } else if (max == Short.MAX_VALUE) { + return String.valueOf(min) + "+"; + } else { + return String.valueOf(min) + "-" + String.valueOf(max); + } + } +} diff --git a/metadata/src/main/java/org/apache/kafka/timeline/BaseHashTable.java b/metadata/src/main/java/org/apache/kafka/timeline/BaseHashTable.java new file mode 100644 index 0000000..0531546 --- /dev/null +++ b/metadata/src/main/java/org/apache/kafka/timeline/BaseHashTable.java @@ -0,0 +1,262 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.timeline; + +import java.util.ArrayList; +import java.util.List; + +/** + * A hash table which uses separate chaining. + * + * In order to optimize memory consumption a bit, the common case where there is + * one element per slot is handled by simply placing the element in the slot, + * and the case where there are multiple elements is handled by creating an + * array and putting that in the slot. Java is storing type info in memory + * about every object whether we want it or not, so let's get some benefit + * out of it. + * + * Arrays and null values cannot be inserted. + */ +@SuppressWarnings("unchecked") +class BaseHashTable { + /** + * The maximum load factor we will allow the hash table to climb to before expanding. + */ + private final static double MAX_LOAD_FACTOR = 0.75f; + + /** + * The minimum number of slots we can have in the hash table. + */ + final static int MIN_CAPACITY = 2; + + /** + * The maximum number of slots we can have in the hash table. + */ + final static int MAX_CAPACITY = 1 << 30; + + private Object[] elements; + private int size = 0; + + BaseHashTable(int expectedSize) { + this.elements = new Object[expectedSizeToCapacity(expectedSize)]; + } + + /** + * Calculate the capacity we should provision, given the expected size. + * + * Our capacity must always be a power of 2, and never less than 2 or more + * than MAX_CAPACITY. We use 64-bit numbers here to avoid overflow + * concerns. 
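+     *
+     * <p>For example, an expected size of 5 yields a minimum capacity of
+     * ceil(5 / 0.75) = 7, which is then rounded up to the next power of two, 8.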
+ */ + static int expectedSizeToCapacity(int expectedSize) { + long minCapacity = (long) Math.ceil((float) expectedSize / MAX_LOAD_FACTOR); + return Math.max(MIN_CAPACITY, + (int) Math.min(MAX_CAPACITY, roundUpToPowerOfTwo(minCapacity))); + } + + private static long roundUpToPowerOfTwo(long i) { + if (i <= 0) { + return 0; + } else if (i > (1L << 62)) { + throw new ArithmeticException("There are no 63-bit powers of 2 higher than " + + "or equal to " + i); + } else { + return 1L << -Long.numberOfLeadingZeros(i - 1); + } + } + + final int baseSize() { + return size; + } + + final Object[] baseElements() { + return elements; + } + + final T baseGet(Object key) { + int slot = findSlot(key, elements.length); + Object value = elements[slot]; + if (value == null) { + return null; + } else if (value instanceof Object[]) { + T[] array = (T[]) value; + for (T object : array) { + if (object.equals(key)) { + return object; + } + } + return null; + } else if (value.equals(key)) { + return (T) value; + } else { + return null; + } + } + + final T baseAddOrReplace(T newObject) { + if (((size + 1) * MAX_LOAD_FACTOR > elements.length) && + (elements.length < MAX_CAPACITY)) { + int newSize = elements.length * 2; + rehash(newSize); + } + int slot = findSlot(newObject, elements.length); + Object cur = elements[slot]; + if (cur == null) { + size++; + elements[slot] = newObject; + return null; + } else if (cur instanceof Object[]) { + T[] curArray = (T[]) cur; + for (int i = 0; i < curArray.length; i++) { + T value = curArray[i]; + if (value.equals(newObject)) { + curArray[i] = newObject; + return value; + } + } + size++; + T[] newArray = (T[]) new Object[curArray.length + 1]; + System.arraycopy(curArray, 0, newArray, 0, curArray.length); + newArray[curArray.length] = newObject; + elements[slot] = newArray; + return null; + } else if (cur.equals(newObject)) { + elements[slot] = newObject; + return (T) cur; + } else { + size++; + elements[slot] = new Object[] {cur, newObject}; + return null; + } + } + + final T baseRemove(Object key) { + int slot = findSlot(key, elements.length); + Object object = elements[slot]; + if (object == null) { + return null; + } else if (object instanceof Object[]) { + Object[] curArray = (Object[]) object; + for (int i = 0; i < curArray.length; i++) { + if (curArray[i].equals(key)) { + size--; + if (curArray.length <= 2) { + int j = i == 0 ? 1 : 0; + elements[slot] = curArray[j]; + } else { + Object[] newArray = new Object[curArray.length - 1]; + System.arraycopy(curArray, 0, newArray, 0, i); + System.arraycopy(curArray, i + 1, newArray, i, curArray.length - 1 - i); + elements[slot] = newArray; + } + return (T) curArray[i]; + } + } + return null; + } else if (object.equals(key)) { + size--; + elements[slot] = null; + return (T) object; + } else { + return null; + } + } + + /** + * Expand the hash table to a new size. Existing elements will be copied to new slots. 
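+     *
+     * <p>Because findSlot() uses one additional bit of the hash each time the
+     * table doubles, an element in old slot {@code s} can only land in new slot
+     * {@code 2 * s} or {@code 2 * s + 1} after the rehash.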
+ */ + final private void rehash(int newSize) { + Object[] prevElements = elements; + elements = new Object[newSize]; + List ready = new ArrayList<>(); + for (int slot = 0; slot < prevElements.length; slot++) { + unpackSlot(ready, prevElements, slot); + for (Object object : ready) { + int newSlot = findSlot(object, elements.length); + Object cur = elements[newSlot]; + if (cur == null) { + elements[newSlot] = object; + } else if (cur instanceof Object[]) { + Object[] curArray = (Object[]) cur; + Object[] newArray = new Object[curArray.length + 1]; + System.arraycopy(curArray, 0, newArray, 0, curArray.length); + newArray[curArray.length] = object; + elements[newSlot] = newArray; + } else { + elements[newSlot] = new Object[]{cur, object}; + } + } + ready.clear(); + } + } + + /** + * Find the slot in the array that an element should go into. + */ + static int findSlot(Object object, int numElements) { + // This performs a secondary hash using Knuth's multiplicative Fibonacci + // hashing. Then, we choose some of the highest bits. The number of bits + // we choose is based on the table size. If the size is 2, we need 1 bit; + // if the size is 4, we need 2 bits, etc. + int objectHashCode = object.hashCode(); + int log2size = 32 - Integer.numberOfLeadingZeros(numElements); + int shift = 65 - log2size; + return (int) ((objectHashCode * -7046029254386353131L) >>> shift); + } + + /** + * Copy any elements in the given slot into the output list. + */ + static void unpackSlot(List out, Object[] elements, int slot) { + Object value = elements[slot]; + if (value == null) { + return; + } else if (value instanceof Object[]) { + Object[] array = (Object[]) value; + for (Object object : array) { + out.add((T) object); + } + } else { + out.add((T) value); + } + } + + String baseToDebugString() { + StringBuilder bld = new StringBuilder(); + bld.append("BaseHashTable{"); + for (int i = 0; i < elements.length; i++) { + Object slotObject = elements[i]; + bld.append(String.format("%n%d: ", i)); + if (slotObject == null) { + bld.append("null"); + } else if (slotObject instanceof Object[]) { + Object[] array = (Object[]) slotObject; + String prefix = ""; + for (Object object : array) { + bld.append(prefix); + prefix = ", "; + bld.append(object); + } + } else { + bld.append(slotObject); + } + } + bld.append(String.format("%n}")); + return bld.toString(); + } +} diff --git a/metadata/src/main/java/org/apache/kafka/timeline/Delta.java b/metadata/src/main/java/org/apache/kafka/timeline/Delta.java new file mode 100644 index 0000000..f18302a --- /dev/null +++ b/metadata/src/main/java/org/apache/kafka/timeline/Delta.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.timeline; + +/** + * An API which snapshot delta structures implement. + */ +interface Delta { + /** + * Merge the source delta into this one. + * + * @param destinationEpoch The epoch of this delta. + * @param source The source delta. + */ + void mergeFrom(long destinationEpoch, Delta source); +} diff --git a/metadata/src/main/java/org/apache/kafka/timeline/Revertable.java b/metadata/src/main/java/org/apache/kafka/timeline/Revertable.java new file mode 100644 index 0000000..43eb117 --- /dev/null +++ b/metadata/src/main/java/org/apache/kafka/timeline/Revertable.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.timeline; + +/** + * An API which all snapshot data structures implement, indicating that their contents + * can be reverted to a point in time. + */ +interface Revertable { + /** + * Revert to the target epoch. + * + * @param targetEpoch The epoch to revert to. + * @param delta The delta associated with this epoch for this object. + */ + void executeRevert(long targetEpoch, Delta delta); + + /** + * Reverts to the initial value. + */ + void reset(); +} diff --git a/metadata/src/main/java/org/apache/kafka/timeline/Snapshot.java b/metadata/src/main/java/org/apache/kafka/timeline/Snapshot.java new file mode 100644 index 0000000..8efa61f --- /dev/null +++ b/metadata/src/main/java/org/apache/kafka/timeline/Snapshot.java @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.timeline; + +import java.util.IdentityHashMap; +import java.util.Map; + +/** + * A snapshot of some timeline data structures. + * + * The snapshot contains historical data for several timeline data structures. + * We use an IdentityHashMap to store this data. This way, we can easily drop all of + * the snapshot data. 
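+ *
+ * <p>Snapshots are also linked into a circular doubly-linked list via prev and
+ * next; the registry's sentinel head gives O(1) access to both the oldest and
+ * the newest snapshot.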
+ */ +class Snapshot { + private final long epoch; + private IdentityHashMap map = new IdentityHashMap<>(4); + private Snapshot prev = this; + private Snapshot next = this; + + Snapshot(long epoch) { + this.epoch = epoch; + } + + long epoch() { + return epoch; + } + + @SuppressWarnings("unchecked") + T getDelta(Revertable owner) { + return (T) map.get(owner); + } + + void setDelta(Revertable owner, Delta delta) { + map.put(owner, delta); + } + + void handleRevert() { + for (Map.Entry entry : map.entrySet()) { + entry.getKey().executeRevert(epoch, entry.getValue()); + } + } + + void mergeFrom(Snapshot source) { + // Merge the deltas from the source snapshot into this snapshot. + for (Map.Entry entry : source.map.entrySet()) { + // We first try to just copy over the object reference. That will work if + //we have no entry at all for the given Revertable. + Delta destinationDelta = map.putIfAbsent(entry.getKey(), entry.getValue()); + if (destinationDelta != null) { + // If we already have an entry for the Revertable, we need to merge the + // source delta into our delta. + destinationDelta.mergeFrom(epoch, entry.getValue()); + } + } + // Delete the source snapshot to make sure nobody tries to reuse it. We might now + // share some delta entries with it. + source.erase(); + } + + Snapshot prev() { + return prev; + } + + Snapshot next() { + return next; + } + + void appendNext(Snapshot newNext) { + newNext.prev = this; + newNext.next = next; + next.prev = newNext; + next = newNext; + } + + void erase() { + map = null; + next.prev = prev; + prev.next = next; + prev = this; + next = this; + } +} diff --git a/metadata/src/main/java/org/apache/kafka/timeline/SnapshotRegistry.java b/metadata/src/main/java/org/apache/kafka/timeline/SnapshotRegistry.java new file mode 100644 index 0000000..997d49a --- /dev/null +++ b/metadata/src/main/java/org/apache/kafka/timeline/SnapshotRegistry.java @@ -0,0 +1,284 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.timeline; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; +import java.util.stream.Collectors; + +import org.apache.kafka.common.utils.LogContext; +import org.slf4j.Logger; + + +/** + * A registry containing snapshots of timeline data structures. + * We generally expect a small number of snapshots-- perhaps 1 or 2 at a time. + * Therefore, we use ArrayLists here rather than a data structure with higher overhead. + */ +public class SnapshotRegistry { + public final static long LATEST_EPOCH = Long.MAX_VALUE; + + /** + * Iterate through the list of snapshots in order of creation, such that older + * snapshots come first. 
+ */ + class SnapshotIterator implements Iterator { + Snapshot cur; + Snapshot result = null; + + SnapshotIterator(Snapshot start) { + cur = start; + } + + @Override + public boolean hasNext() { + return cur != head; + } + + @Override + public Snapshot next() { + result = cur; + cur = cur.next(); + return result; + } + + @Override + public void remove() { + if (result == null) { + throw new IllegalStateException(); + } + deleteSnapshot(result); + result = null; + } + } + + /** + * Iterate through the list of snapshots in reverse order of creation, such that + * the newest snapshot is first. + */ + class ReverseSnapshotIterator implements Iterator { + Snapshot cur; + + ReverseSnapshotIterator() { + cur = head.prev(); + } + + @Override + public boolean hasNext() { + return cur != head; + } + + @Override + public Snapshot next() { + Snapshot result = cur; + cur = cur.prev(); + return result; + } + } + + private final Logger log; + + /** + * A map from snapshot epochs to snapshot data structures. + */ + private final HashMap snapshots = new HashMap<>(); + + /** + * The head of a list of snapshots, sorted by epoch. + */ + private final Snapshot head = new Snapshot(Long.MIN_VALUE); + + /** + * Collection of all Revertable registered with this registry + */ + private final List revertables = new ArrayList<>(); + + public SnapshotRegistry(LogContext logContext) { + this.log = logContext.logger(SnapshotRegistry.class); + } + + /** + * Returns a snapshot iterator that iterates from the snapshots with the + * lowest epoch to those with the highest. + */ + public Iterator iterator() { + return new SnapshotIterator(head.next()); + } + + /** + * Returns a snapshot iterator that iterates from the snapshots with the + * lowest epoch to those with the highest, starting at the snapshot with the + * given epoch. + */ + public Iterator iterator(long epoch) { + return iterator(getSnapshot(epoch)); + } + + /** + * Returns a snapshot iterator that iterates from the snapshots with the + * lowest epoch to those with the highest, starting at the given snapshot. + */ + public Iterator iterator(Snapshot snapshot) { + return new SnapshotIterator(snapshot); + } + + /** + * Returns a reverse snapshot iterator that iterates from the snapshots with the + * highest epoch to those with the lowest. + */ + public Iterator reverseIterator() { + return new ReverseSnapshotIterator(); + } + + /** + * Returns a sorted list of snapshot epochs. + */ + public List epochsList() { + List result = new ArrayList<>(); + for (Iterator iterator = iterator(); iterator.hasNext(); ) { + result.add(iterator.next().epoch()); + } + return result; + } + + public boolean hasSnapshot(long epoch) { + return snapshots.containsKey(epoch); + } + + /** + * Gets the snapshot for a specific epoch. + */ + public Snapshot getSnapshot(long epoch) { + Snapshot snapshot = snapshots.get(epoch); + if (snapshot == null) { + throw new RuntimeException("No snapshot for epoch " + epoch + ". Snapshot " + + "epochs are: " + epochsList().stream().map(e -> e.toString()). + collect(Collectors.joining(", "))); + } + return snapshot; + } + + /** + * Creates a new snapshot at the given epoch. + * + * If {@code epoch} already exists and it is the last snapshot then just return that snapshot. + * + * @param epoch The epoch to create the snapshot at. The current epoch + * will be advanced to one past this epoch. 
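+     *
+     * <p>A minimal sketch of the intended calling pattern (the epoch values here
+     * are made up for illustration):
+     * <pre>{@code
+     * SnapshotRegistry registry = new SnapshotRegistry(new LogContext());
+     * registry.getOrCreateSnapshot(10);  // capture the current state as epoch 10
+     * // ... mutate timeline structures registered with this registry ...
+     * registry.revertToSnapshot(10);     // roll every registered structure back
+     * }</pre>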
+ */ + public Snapshot getOrCreateSnapshot(long epoch) { + Snapshot last = head.prev(); + if (last.epoch() > epoch) { + throw new RuntimeException("Can't create a new snapshot at epoch " + epoch + + " because there is already a snapshot with epoch " + last.epoch()); + } else if (last.epoch() == epoch) { + return last; + } + Snapshot snapshot = new Snapshot(epoch); + last.appendNext(snapshot); + snapshots.put(epoch, snapshot); + log.debug("Creating snapshot {}", epoch); + return snapshot; + } + + /** + * Reverts the state of all data structures to the state at the given epoch. + * + * @param targetEpoch The epoch of the snapshot to revert to. + */ + public void revertToSnapshot(long targetEpoch) { + Snapshot target = getSnapshot(targetEpoch); + Iterator iterator = iterator(target); + iterator.next(); + while (iterator.hasNext()) { + Snapshot snapshot = iterator.next(); + log.debug("Deleting snapshot {} because we are reverting to {}", + snapshot.epoch(), targetEpoch); + iterator.remove(); + } + target.handleRevert(); + } + + /** + * Deletes the snapshot with the given epoch. + * + * @param targetEpoch The epoch of the snapshot to delete. + */ + public void deleteSnapshot(long targetEpoch) { + deleteSnapshot(getSnapshot(targetEpoch)); + } + + /** + * Deletes the given snapshot. + * + * @param snapshot The snapshot to delete. + */ + public void deleteSnapshot(Snapshot snapshot) { + Snapshot prev = snapshot.prev(); + if (prev != head) { + prev.mergeFrom(snapshot); + } else { + snapshot.erase(); + } + log.debug("Deleting snapshot {}", snapshot.epoch()); + snapshots.remove(snapshot.epoch(), snapshot); + } + + /** + * Deletes all the snapshots up to the given epoch + * + * @param targetEpoch The epoch to delete up to. + */ + public void deleteSnapshotsUpTo(long targetEpoch) { + for (Iterator iterator = iterator(); iterator.hasNext(); ) { + Snapshot snapshot = iterator.next(); + if (snapshot.epoch() >= targetEpoch) { + return; + } + log.debug("Deleting snapshot {}", snapshot.epoch()); + iterator.remove(); + } + } + + /** + * Return the latest epoch. + */ + public long latestEpoch() { + return head.prev().epoch(); + } + + /** + * Associate with this registry. + */ + public void register(Revertable revertable) { + revertables.add(revertable); + } + + /** + * Delete all snapshots and resets all of the Revertable object registered. + */ + public void reset() { + deleteSnapshotsUpTo(LATEST_EPOCH); + + for (Revertable revertable : revertables) { + revertable.reset(); + } + } +} diff --git a/metadata/src/main/java/org/apache/kafka/timeline/SnapshottableHashTable.java b/metadata/src/main/java/org/apache/kafka/timeline/SnapshottableHashTable.java new file mode 100644 index 0000000..cbd0a28 --- /dev/null +++ b/metadata/src/main/java/org/apache/kafka/timeline/SnapshottableHashTable.java @@ -0,0 +1,465 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.timeline; + +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; +import java.util.NoSuchElementException; + + +/** + * SnapshottableHashTable implements a hash table that supports creating point-in-time + * snapshots. Each snapshot is immutable once it is created; the past cannot be changed. + * We handle divergences between the current state and historical state by copying a + * reference to elements that have been deleted or overwritten into the most recent + * snapshot tier. + * + * Note that there are no keys in SnapshottableHashTable, only values. So it more similar + * to a hash set than a hash map. The subclasses implement full-featured maps and sets + * using this class as a building block. + * + * Each snapshot tier contains a size and a hash table. The size reflects the size at + * the time the snapshot was taken. Note that, as an optimization, snapshot tiers will + * be null if they don't contain anything. So for example, if snapshot 20 of Object O + * contains the same entries as snapshot 10 of that object, the snapshot 20 tier for + * object O will be null. + * + * The current tier's data is stored in the fields inherited from BaseHashTable. It + * would be conceptually simpler to have a separate BaseHashTable object, but since Java + * doesn't have value types, subclassing is the only way to avoid another pointer + * indirection and the associated extra memory cost. + * + * Note that each element in the hash table contains a start epoch, and a value. The + * start epoch is there to identify when the object was first inserted. This in turn + * determines which snapshots it is a member of. + * + * In order to retrieve an object from snapshot E, we start by checking to see if the + * object exists in the "current" hash tier. If it does, and its startEpoch extends back + * to E, we return that object. Otherwise, we check all the snapshot tiers, starting + * with E, and ending with the most recent snapshot, to see if the object is there. + * As an optimization, if we encounter the object in a snapshot tier but its epoch is too + * new, we know that its value at epoch E must be null, so we can return that immediately. + * + * The class hierarchy looks like this: + * + * Revertable BaseHashTable + * ↑ ↑ + * SnapshottableHashTable → SnapshotRegistry → Snapshot + * ↑ ↑ + * TimelineHashSet TimelineHashMap + * + * BaseHashTable is a simple hash table that uses separate chaining. The interface is + * pretty bare-bones since this class is not intended to be used directly by end-users. + * + * This class, SnapshottableHashTable, has the logic for snapshotting and iterating over + * snapshots. This is the core of the snapshotted hash table code and handles the + * tiering. + * + * TimelineHashSet and TimelineHashMap are mostly wrappers around this + * SnapshottableHashTable class. They implement standard Java APIs for Set and Map, + * respectively. There's a fair amount of boilerplate for this, but it's necessary so + * that timeline data structures can be used while writing idiomatic Java code. + * The accessor APIs have two versions -- one that looks at the current state, and one + * that looks at a historical snapshotted state. Mutation APIs only ever mutate the + * current state. 
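+ *
+ * <p>As a sketch of how the current-state and snapshot accessors fit together
+ * (the keys and epoch below are made up for illustration):
+ * <pre>{@code
+ * SnapshotRegistry registry = new SnapshotRegistry(new LogContext());
+ * TimelineHashMap<String, Integer> map = new TimelineHashMap<>(registry, 16);
+ * map.put("a", 1);
+ * registry.getOrCreateSnapshot(100);
+ * map.put("a", 2);
+ * map.get("a");       // 2, the current value
+ * map.get("a", 100);  // 1, the value as of snapshot 100
+ * }</pre>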
+ * + * One very important feature of SnapshottableHashTable is that we support iterating + * over a snapshot even while changes are being made to the current state. See the + * Javadoc for the iterator for more information about how this is accomplished. + * + * All of these classes require external synchronization, and don't support null keys or + * values. + */ +class SnapshottableHashTable + extends BaseHashTable implements Revertable { + + /** + * A special epoch value that represents the latest data. + */ + final static long LATEST_EPOCH = Long.MAX_VALUE; + + interface ElementWithStartEpoch { + void setStartEpoch(long startEpoch); + long startEpoch(); + } + + static class HashTier implements Delta { + private final int size; + private BaseHashTable deltaTable; + + HashTier(int size) { + this.size = size; + } + + @SuppressWarnings("unchecked") + @Override + public void mergeFrom(long epoch, Delta source) { + HashTier other = (HashTier) source; + List list = new ArrayList<>(); + Object[] otherElements = other.deltaTable.baseElements(); + for (int slot = 0; slot < otherElements.length; slot++) { + BaseHashTable.unpackSlot(list, otherElements, slot); + for (T element : list) { + // When merging in a later hash tier, we want to keep only the elements + // that were present at our epoch. + if (element.startEpoch() <= epoch) { + deltaTable.baseAddOrReplace(element); + } + } + } + } + } + + /** + * Iterate over the values that currently exist in the hash table. + * + * You can use this iterator even if you are making changes to the map. + * The changes may or may not be visible while you are iterating. + */ + class CurrentIterator implements Iterator { + private final Object[] topTier; + private final List ready; + private int slot; + private T lastReturned; + + CurrentIterator(Object[] topTier) { + this.topTier = topTier; + this.ready = new ArrayList<>(); + this.slot = 0; + this.lastReturned = null; + } + + @Override + public boolean hasNext() { + while (ready.isEmpty()) { + if (slot == topTier.length) { + return false; + } + BaseHashTable.unpackSlot(ready, topTier, slot); + slot++; + } + return true; + } + + @Override + public T next() { + if (!hasNext()) { + throw new NoSuchElementException(); + } + lastReturned = ready.remove(ready.size() - 1); + return lastReturned; + } + + @Override + public void remove() { + if (lastReturned == null) { + throw new UnsupportedOperationException("remove"); + } + snapshottableRemove(lastReturned); + lastReturned = null; + } + } + + /** + * Iterate over the values that existed in the hash table during a specific snapshot. + * + * You can use this iterator even if you are making changes to the map. + * The snapshot is immutable and will always show up the same. 
+ */ + class HistoricalIterator implements Iterator { + private final Object[] topTier; + private final Snapshot snapshot; + private final List temp; + private final List ready; + private int slot; + + HistoricalIterator(Object[] topTier, Snapshot snapshot) { + this.topTier = topTier; + this.snapshot = snapshot; + this.temp = new ArrayList<>(); + this.ready = new ArrayList<>(); + this.slot = 0; + } + + @Override + public boolean hasNext() { + while (ready.isEmpty()) { + if (slot == topTier.length) { + return false; + } + BaseHashTable.unpackSlot(temp, topTier, slot); + for (T object : temp) { + if (object.startEpoch() <= snapshot.epoch()) { + ready.add(object); + } + } + temp.clear(); + + /* + * As we iterate over the SnapshottableHashTable, elements may move from + * the top tier into the snapshot tiers. This would happen if something + * were deleted in the top tier, for example, but still retained in the + * snapshot. + * + * We don't want to return any elements twice, though. Therefore, we + * iterate over the top tier and the snapshot tier at the + * same time. The key to understanding how this works is realizing that + * both hash tables use the same hash function, but simply choose a + * different number of significant bits based on their size. + * So if the top tier has size 4 and the snapshot tier has size 2, we have + * the following mapping: + * + * Elements that would be in slot 0 or 1 in the top tier can only be in + * slot 0 in the snapshot tier. + * Elements that would be in slot 2 or 3 in the top tier can only be in + * slot 1 in the snapshot tier. + * + * Therefore, we can do something like this: + * 1. check slot 0 in the top tier and slot 0 in the snapshot tier. + * 2. check slot 1 in the top tier and slot 0 in the snapshot tier. + * 3. check slot 2 in the top tier and slot 1 in the snapshot tier. + * 4. check slot 3 in the top tier and slot 1 in the snapshot tier. + * + * If elements move from the top tier to the snapshot tier, then + * we'll still find them and report them exactly once. + * + * Note that while I used 4 and 2 as example sizes here, the same pattern + * holds for different powers of two. The "snapshot slot" of an element + * will be the top few bits of the top tier slot of that element. 
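+                 *
+                 * In general, if the top tier has 2^a slots and the snapshot tier
+                 * has 2^b slots (with b <= a), top tier slot s corresponds to
+                 * snapshot tier slot s >>> (a - b); that difference is exactly the
+                 * shift computed below from the two table lengths.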
+ */ + Iterator iterator = snapshotRegistry.iterator(snapshot); + while (iterator.hasNext()) { + Snapshot curSnapshot = iterator.next(); + HashTier tier = curSnapshot.getDelta(SnapshottableHashTable.this); + if (tier != null && tier.deltaTable != null) { + BaseHashTable deltaTable = tier.deltaTable; + int shift = Integer.numberOfLeadingZeros(deltaTable.baseElements().length) - + Integer.numberOfLeadingZeros(topTier.length); + int tierSlot = slot >>> shift; + BaseHashTable.unpackSlot(temp, deltaTable.baseElements(), tierSlot); + for (T object : temp) { + if (BaseHashTable.findSlot(object, topTier.length) == slot) { + ready.add(object); + } + } + temp.clear(); + } + } + slot++; + } + return true; + } + + @Override + public T next() { + if (!hasNext()) { + throw new NoSuchElementException(); + } + return ready.remove(ready.size() - 1); + } + } + + private final SnapshotRegistry snapshotRegistry; + + SnapshottableHashTable(SnapshotRegistry snapshotRegistry, int expectedSize) { + super(expectedSize); + this.snapshotRegistry = snapshotRegistry; + snapshotRegistry.register(this); + } + + int snapshottableSize(long epoch) { + if (epoch == LATEST_EPOCH) { + return baseSize(); + } else { + Iterator iterator = snapshotRegistry.iterator(epoch); + while (iterator.hasNext()) { + Snapshot snapshot = iterator.next(); + HashTier tier = snapshot.getDelta(SnapshottableHashTable.this); + if (tier != null) { + return tier.size; + } + } + return baseSize(); + } + } + + T snapshottableGet(Object key, long epoch) { + T result = baseGet(key); + if (result != null && result.startEpoch() <= epoch) { + return result; + } + if (epoch == LATEST_EPOCH) { + return null; + } + Iterator iterator = snapshotRegistry.iterator(epoch); + while (iterator.hasNext()) { + Snapshot snapshot = iterator.next(); + HashTier tier = snapshot.getDelta(SnapshottableHashTable.this); + if (tier != null && tier.deltaTable != null) { + result = tier.deltaTable.baseGet(key); + if (result != null) { + if (result.startEpoch() <= epoch) { + return result; + } else { + return null; + } + } + } + } + return null; + } + + boolean snapshottableAddUnlessPresent(T object) { + T prev = baseGet(object); + if (prev != null) { + return false; + } + object.setStartEpoch(snapshotRegistry.latestEpoch() + 1); + int prevSize = baseSize(); + baseAddOrReplace(object); + updateTierData(prevSize); + return true; + } + + T snapshottableAddOrReplace(T object) { + object.setStartEpoch(snapshotRegistry.latestEpoch() + 1); + int prevSize = baseSize(); + T prev = baseAddOrReplace(object); + if (prev == null) { + updateTierData(prevSize); + } else { + updateTierData(prev, prevSize); + } + return prev; + } + + T snapshottableRemove(Object object) { + T prev = baseRemove(object); + if (prev == null) { + return null; + } else { + updateTierData(prev, baseSize() + 1); + return prev; + } + } + + private void updateTierData(int prevSize) { + Iterator iterator = snapshotRegistry.reverseIterator(); + if (iterator.hasNext()) { + Snapshot snapshot = iterator.next(); + HashTier tier = snapshot.getDelta(SnapshottableHashTable.this); + if (tier == null) { + tier = new HashTier<>(prevSize); + snapshot.setDelta(SnapshottableHashTable.this, tier); + } + } + } + + private void updateTierData(T prev, int prevSize) { + Iterator iterator = snapshotRegistry.reverseIterator(); + if (iterator.hasNext()) { + Snapshot snapshot = iterator.next(); + // If the previous element was present in the most recent snapshot, add it to + // that snapshot's hash tier. 
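+            // If it was only created after that snapshot was taken, the snapshot
+            // never observed it, so there is nothing to preserve.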
+ if (prev.startEpoch() <= snapshot.epoch()) { + HashTier tier = snapshot.getDelta(SnapshottableHashTable.this); + if (tier == null) { + tier = new HashTier<>(prevSize); + snapshot.setDelta(SnapshottableHashTable.this, tier); + } + if (tier.deltaTable == null) { + tier.deltaTable = new BaseHashTable<>(1); + } + tier.deltaTable.baseAddOrReplace(prev); + } + } + } + + Iterator snapshottableIterator(long epoch) { + if (epoch == LATEST_EPOCH) { + return new CurrentIterator(baseElements()); + } else { + return new HistoricalIterator(baseElements(), snapshotRegistry.getSnapshot(epoch)); + } + } + + String snapshottableToDebugString() { + StringBuilder bld = new StringBuilder(); + bld.append(String.format("SnapshottableHashTable{%n")); + bld.append("top tier: "); + bld.append(baseToDebugString()); + bld.append(String.format(",%nsnapshot tiers: [%n")); + String prefix = ""; + for (Iterator iter = snapshotRegistry.iterator(); iter.hasNext(); ) { + Snapshot snapshot = iter.next(); + bld.append(prefix); + bld.append("epoch ").append(snapshot.epoch()).append(": "); + HashTier tier = snapshot.getDelta(this); + if (tier == null) { + bld.append("null"); + } else { + bld.append("HashTier{"); + bld.append("size=").append(tier.size); + bld.append(", deltaTable="); + if (tier.deltaTable == null) { + bld.append("null"); + } else { + bld.append(tier.deltaTable.baseToDebugString()); + } + bld.append("}"); + } + bld.append(String.format("%n")); + } + bld.append(String.format("]}%n")); + return bld.toString(); + } + + @SuppressWarnings("unchecked") + @Override + public void executeRevert(long targetEpoch, Delta delta) { + HashTier tier = (HashTier) delta; + Iterator iter = snapshottableIterator(LATEST_EPOCH); + while (iter.hasNext()) { + T element = iter.next(); + if (element.startEpoch() > targetEpoch) { + iter.remove(); + } + } + BaseHashTable deltaTable = tier.deltaTable; + if (deltaTable != null) { + List out = new ArrayList<>(); + for (int i = 0; i < deltaTable.baseElements().length; i++) { + BaseHashTable.unpackSlot(out, deltaTable.baseElements(), i); + for (T value : out) { + baseAddOrReplace(value); + } + out.clear(); + } + } + } + + @Override + public void reset() { + Iterator iter = snapshottableIterator(SnapshottableHashTable.LATEST_EPOCH); + while (iter.hasNext()) { + iter.next(); + iter.remove(); + } + } +} diff --git a/metadata/src/main/java/org/apache/kafka/timeline/TimelineHashMap.java b/metadata/src/main/java/org/apache/kafka/timeline/TimelineHashMap.java new file mode 100644 index 0000000..855e7ed --- /dev/null +++ b/metadata/src/main/java/org/apache/kafka/timeline/TimelineHashMap.java @@ -0,0 +1,411 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.timeline; + +import java.util.AbstractCollection; +import java.util.AbstractSet; +import java.util.Collection; +import java.util.Iterator; +import java.util.Map; +import java.util.Objects; +import java.util.Set; + +/** + * This is a hash map which can be snapshotted. + * + * See {@SnapshottableHashTable} for more details about the implementation. + * + * This class requires external synchronization. Null keys and values are not supported. + * + * @param The key type of the set. + * @param The value type of the set. + */ +public class TimelineHashMap + extends SnapshottableHashTable> + implements Map { + static class TimelineHashMapEntry + implements SnapshottableHashTable.ElementWithStartEpoch, Map.Entry { + private final K key; + private final V value; + private long startEpoch; + + TimelineHashMapEntry(K key, V value) { + this.key = key; + this.value = value; + this.startEpoch = SnapshottableHashTable.LATEST_EPOCH; + } + + @Override + public K getKey() { + return key; + } + + @Override + public V getValue() { + return value; + } + + @Override + public V setValue(V value) { + // This would be inefficient to support since we'd need a back-reference + // to the enclosing map in each Entry object. There would also be + // complications if this entry object was sourced from a historical iterator; + // we don't support modifying the past. Since we don't really need this API, + // let's just not support it. + throw new UnsupportedOperationException(); + } + + @Override + public void setStartEpoch(long startEpoch) { + this.startEpoch = startEpoch; + } + + @Override + public long startEpoch() { + return startEpoch; + } + + @SuppressWarnings("unchecked") + @Override + public boolean equals(Object o) { + if (!(o instanceof TimelineHashMapEntry)) return false; + TimelineHashMapEntry other = (TimelineHashMapEntry) o; + return key.equals(other.key); + } + + @Override + public int hashCode() { + return key.hashCode(); + } + } + + public TimelineHashMap(SnapshotRegistry snapshotRegistry, int expectedSize) { + super(snapshotRegistry, expectedSize); + } + + @Override + public int size() { + return size(SnapshottableHashTable.LATEST_EPOCH); + } + + public int size(long epoch) { + return snapshottableSize(epoch); + } + + @Override + public boolean isEmpty() { + return isEmpty(SnapshottableHashTable.LATEST_EPOCH); + } + + public boolean isEmpty(long epoch) { + return snapshottableSize(epoch) == 0; + } + + @Override + public boolean containsKey(Object key) { + return containsKey(key, SnapshottableHashTable.LATEST_EPOCH); + } + + public boolean containsKey(Object key, long epoch) { + return snapshottableGet(new TimelineHashMapEntry<>(key, null), epoch) != null; + } + + @Override + public boolean containsValue(Object value) { + Iterator> iter = entrySet().iterator(); + while (iter.hasNext()) { + Entry e = iter.next(); + if (value.equals(e.getValue())) { + return true; + } + } + return false; + } + + @Override + public V get(Object key) { + return get(key, SnapshottableHashTable.LATEST_EPOCH); + } + + public V get(Object key, long epoch) { + Entry entry = + snapshottableGet(new TimelineHashMapEntry<>(key, null), epoch); + if (entry == null) { + return null; + } + return entry.getValue(); + } + + @Override + public V put(K key, V value) { + Objects.requireNonNull(key); + Objects.requireNonNull(value); + TimelineHashMapEntry entry = new TimelineHashMapEntry<>(key, value); + TimelineHashMapEntry prev = snapshottableAddOrReplace(entry); + if (prev == null) { + return null; + } + 
return prev.getValue(); + } + + @Override + public V remove(Object key) { + TimelineHashMapEntry result = snapshottableRemove( + new TimelineHashMapEntry<>(key, null)); + return result == null ? null : result.value; + } + + @Override + public void putAll(Map map) { + for (Map.Entry e : map.entrySet()) { + put(e.getKey(), e.getValue()); + } + } + + @Override + public void clear() { + reset(); + } + + final class KeySet extends AbstractSet { + private final long epoch; + + KeySet(long epoch) { + this.epoch = epoch; + } + + public final int size() { + return TimelineHashMap.this.size(epoch); + } + + public final void clear() { + if (epoch != SnapshottableHashTable.LATEST_EPOCH) { + throw new RuntimeException("can't modify snapshot"); + } + TimelineHashMap.this.clear(); + } + + public final Iterator iterator() { + return new KeyIterator(epoch); + } + + public final boolean contains(Object o) { + return TimelineHashMap.this.containsKey(o, epoch); + } + + public final boolean remove(Object o) { + if (epoch != SnapshottableHashTable.LATEST_EPOCH) { + throw new RuntimeException("can't modify snapshot"); + } + return TimelineHashMap.this.remove(o) != null; + } + } + + final class KeyIterator implements Iterator { + private final Iterator> iter; + + KeyIterator(long epoch) { + this.iter = snapshottableIterator(epoch); + } + + @Override + public boolean hasNext() { + return iter.hasNext(); + } + + @Override + public K next() { + TimelineHashMapEntry next = iter.next(); + return next.getKey(); + } + + @Override + public void remove() { + iter.remove(); + } + } + + @Override + public Set keySet() { + return keySet(SnapshottableHashTable.LATEST_EPOCH); + } + + public Set keySet(long epoch) { + return new KeySet(epoch); + } + + final class Values extends AbstractCollection { + private final long epoch; + + Values(long epoch) { + this.epoch = epoch; + } + + public final int size() { + return TimelineHashMap.this.size(epoch); + } + + public final void clear() { + if (epoch != SnapshottableHashTable.LATEST_EPOCH) { + throw new RuntimeException("can't modify snapshot"); + } + TimelineHashMap.this.clear(); + } + + public final Iterator iterator() { + return new ValueIterator(epoch); + } + + public final boolean contains(Object o) { + return TimelineHashMap.this.containsKey(o, epoch); + } + } + + final class ValueIterator implements Iterator { + private final Iterator> iter; + + ValueIterator(long epoch) { + this.iter = snapshottableIterator(epoch); + } + + @Override + public boolean hasNext() { + return iter.hasNext(); + } + + @Override + public V next() { + TimelineHashMapEntry next = iter.next(); + return next.getValue(); + } + + @Override + public void remove() { + iter.remove(); + } + } + + @Override + public Collection values() { + return values(SnapshottableHashTable.LATEST_EPOCH); + } + + public Collection values(long epoch) { + return new Values(epoch); + } + + final class EntrySet extends AbstractSet> { + private final long epoch; + + EntrySet(long epoch) { + this.epoch = epoch; + } + + public final int size() { + return TimelineHashMap.this.size(epoch); + } + + public final void clear() { + if (epoch != SnapshottableHashTable.LATEST_EPOCH) { + throw new RuntimeException("can't modify snapshot"); + } + TimelineHashMap.this.clear(); + } + + public final Iterator> iterator() { + return new EntryIterator(epoch); + } + + public final boolean contains(Object o) { + return snapshottableGet(o, epoch) != null; + } + + public final boolean remove(Object o) { + if (epoch != SnapshottableHashTable.LATEST_EPOCH) 
{ + throw new RuntimeException("can't modify snapshot"); + } + return snapshottableRemove(new TimelineHashMapEntry<>(o, null)) != null; + } + } + + final class EntryIterator implements Iterator> { + private final Iterator> iter; + + EntryIterator(long epoch) { + this.iter = snapshottableIterator(epoch); + } + + @Override + public boolean hasNext() { + return iter.hasNext(); + } + + @Override + public Map.Entry next() { + return iter.next(); + } + + @Override + public void remove() { + iter.remove(); + } + } + + @Override + public Set> entrySet() { + return entrySet(SnapshottableHashTable.LATEST_EPOCH); + } + + public Set> entrySet(long epoch) { + return new EntrySet(epoch); + } + + @Override + public int hashCode() { + int hash = 0; + Iterator> iter = entrySet().iterator(); + while (iter.hasNext()) { + hash += iter.next().hashCode(); + } + return hash; + } + + @Override + public boolean equals(Object o) { + if (o == this) + return true; + if (!(o instanceof Map)) + return false; + Map m = (Map) o; + if (m.size() != size()) + return false; + try { + Iterator> iter = entrySet().iterator(); + while (iter.hasNext()) { + Entry entry = iter.next(); + if (!m.get(entry.getKey()).equals(entry.getValue())) { + return false; + } + } + } catch (ClassCastException unused) { + return false; + } + return true; + + } +} diff --git a/metadata/src/main/java/org/apache/kafka/timeline/TimelineHashSet.java b/metadata/src/main/java/org/apache/kafka/timeline/TimelineHashSet.java new file mode 100644 index 0000000..34efb10 --- /dev/null +++ b/metadata/src/main/java/org/apache/kafka/timeline/TimelineHashSet.java @@ -0,0 +1,256 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.timeline; + +import java.util.Collection; +import java.util.Iterator; +import java.util.Objects; +import java.util.Set; + +/** + * This is a hash set which can be snapshotted. + * + * See {@SnapshottableHashTable} for more details about the implementation. + * + * This class requires external synchronization. Null values are not supported. + * + * @param The value type of the set. 
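+ *
+ * <p>A minimal usage sketch (the values and epoch are made up, and {@code registry}
+ * is assumed to be an existing SnapshotRegistry):
+ * <pre>{@code
+ * TimelineHashSet<Integer> brokers = new TimelineHashSet<>(registry, 16);
+ * brokers.add(1);
+ * registry.getOrCreateSnapshot(200);
+ * brokers.add(2);
+ * brokers.contains(2);       // true in the current state
+ * brokers.contains(2, 200);  // false as of snapshot 200
+ * }</pre>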
+ */ +public class TimelineHashSet + extends SnapshottableHashTable> + implements Set { + static class TimelineHashSetEntry + implements SnapshottableHashTable.ElementWithStartEpoch { + private final T value; + private long startEpoch; + + TimelineHashSetEntry(T value) { + this.value = value; + this.startEpoch = SnapshottableHashTable.LATEST_EPOCH; + } + + public T getValue() { + return value; + } + + @Override + public void setStartEpoch(long startEpoch) { + this.startEpoch = startEpoch; + } + + @Override + public long startEpoch() { + return startEpoch; + } + + @SuppressWarnings("unchecked") + @Override + public boolean equals(Object o) { + if (!(o instanceof TimelineHashSetEntry)) return false; + TimelineHashSetEntry other = (TimelineHashSetEntry) o; + return value.equals(other.value); + } + + @Override + public int hashCode() { + return value.hashCode(); + } + } + + public TimelineHashSet(SnapshotRegistry snapshotRegistry, int expectedSize) { + super(snapshotRegistry, expectedSize); + } + + @Override + public int size() { + return size(SnapshottableHashTable.LATEST_EPOCH); + } + + public int size(long epoch) { + return snapshottableSize(epoch); + } + + @Override + public boolean isEmpty() { + return isEmpty(SnapshottableHashTable.LATEST_EPOCH); + } + + public boolean isEmpty(long epoch) { + return snapshottableSize(epoch) == 0; + } + + @Override + public boolean contains(Object key) { + return contains(key, SnapshottableHashTable.LATEST_EPOCH); + } + + public boolean contains(Object object, long epoch) { + return snapshottableGet(new TimelineHashSetEntry<>(object), epoch) != null; + } + + final class ValueIterator implements Iterator { + private final Iterator> iter; + + ValueIterator(long epoch) { + this.iter = snapshottableIterator(epoch); + } + + @Override + public boolean hasNext() { + return iter.hasNext(); + } + + @Override + public T next() { + return iter.next().value; + } + + @Override + public void remove() { + iter.remove(); + } + } + + @Override + public Iterator iterator() { + return iterator(SnapshottableHashTable.LATEST_EPOCH); + } + + public Iterator iterator(long epoch) { + return new ValueIterator(epoch); + } + + @Override + public Object[] toArray() { + Object[] result = new Object[size()]; + Iterator iter = iterator(); + int i = 0; + while (iter.hasNext()) { + result[i++] = iter.next(); + } + return result; + } + + @SuppressWarnings("unchecked") + @Override + public R[] toArray(R[] a) { + int size = size(); + if (size <= a.length) { + Iterator iter = iterator(); + int i = 0; + while (iter.hasNext()) { + a[i++] = (R) iter.next(); + } + while (i < a.length) { + a[i++] = null; + } + return a; + } else { + return (R[]) toArray(); + } + } + + @Override + public boolean add(T newValue) { + Objects.requireNonNull(newValue); + return snapshottableAddUnlessPresent(new TimelineHashSetEntry<>(newValue)); + } + + @Override + public boolean remove(Object value) { + return snapshottableRemove(new TimelineHashSetEntry<>(value)) != null; + } + + @Override + public boolean containsAll(Collection collection) { + for (Object value : collection) { + if (!contains(value)) return false; + } + return true; + } + + @Override + public boolean addAll(Collection collection) { + boolean modified = false; + for (T value : collection) { + if (add(value)) { + modified = true; + } + } + return modified; + } + + @Override + public boolean retainAll(Collection collection) { + Objects.requireNonNull(collection); + boolean modified = false; + Iterator it = iterator(); + while (it.hasNext()) { + if 
(!collection.contains(it.next())) { + it.remove(); + modified = true; + } + } + return modified; + } + + @Override + public boolean removeAll(Collection collection) { + Objects.requireNonNull(collection); + boolean modified = false; + Iterator it = iterator(); + while (it.hasNext()) { + if (collection.contains(it.next())) { + it.remove(); + modified = true; + } + } + return modified; + } + + @Override + public void clear() { + reset(); + } + + @Override + public int hashCode() { + int hash = 0; + Iterator iter = iterator(); + while (iter.hasNext()) { + hash += iter.next().hashCode(); + } + return hash; + } + + @Override + public boolean equals(Object o) { + if (o == this) + return true; + if (!(o instanceof Set)) + return false; + Collection c = (Collection) o; + if (c.size() != size()) + return false; + try { + return containsAll(c); + } catch (ClassCastException unused) { + return false; + } + } +} diff --git a/metadata/src/main/java/org/apache/kafka/timeline/TimelineInteger.java b/metadata/src/main/java/org/apache/kafka/timeline/TimelineInteger.java new file mode 100644 index 0000000..d158890 --- /dev/null +++ b/metadata/src/main/java/org/apache/kafka/timeline/TimelineInteger.java @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.timeline; + +import java.util.Iterator; + + +/** + * This is a mutable integer which can be snapshotted. + * + * This class requires external synchronization. 
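+ *
+ * <p>A small illustration ({@code registry} is assumed to be an existing
+ * SnapshotRegistry, and the epoch value is made up):
+ * <pre>{@code
+ * TimelineInteger count = new TimelineInteger(registry);
+ * count.increment();                // current value is 1
+ * registry.getOrCreateSnapshot(5);
+ * count.increment();                // current value is 2
+ * count.get(5);                     // still 1 as of snapshot 5
+ * }</pre>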
+ */ +public class TimelineInteger implements Revertable { + public static final int INIT = 0; + + static class IntegerContainer implements Delta { + private int value = INIT; + + int value() { + return value; + } + + void setValue(int value) { + this.value = value; + } + + @Override + public void mergeFrom(long destinationEpoch, Delta delta) { + // Nothing to do + } + } + + private final SnapshotRegistry snapshotRegistry; + private int value; + + public TimelineInteger(SnapshotRegistry snapshotRegistry) { + this.snapshotRegistry = snapshotRegistry; + this.value = INIT; + + snapshotRegistry.register(this); + } + + public int get() { + return value; + } + + public int get(long epoch) { + if (epoch == SnapshotRegistry.LATEST_EPOCH) return value; + Iterator iterator = snapshotRegistry.iterator(epoch); + while (iterator.hasNext()) { + Snapshot snapshot = iterator.next(); + IntegerContainer container = snapshot.getDelta(TimelineInteger.this); + if (container != null) return container.value(); + } + return value; + } + + public void set(int newValue) { + Iterator iterator = snapshotRegistry.reverseIterator(); + if (iterator.hasNext()) { + Snapshot snapshot = iterator.next(); + IntegerContainer container = snapshot.getDelta(TimelineInteger.this); + if (container == null) { + container = new IntegerContainer(); + snapshot.setDelta(TimelineInteger.this, container); + container.setValue(value); + } + } + this.value = newValue; + } + + public void increment() { + set(get() + 1); + } + + public void decrement() { + set(get() - 1); + } + + @SuppressWarnings("unchecked") + @Override + public void executeRevert(long targetEpoch, Delta delta) { + IntegerContainer container = (IntegerContainer) delta; + this.value = container.value; + } + + @Override + public void reset() { + set(INIT); + } + + @Override + public int hashCode() { + return value; + } + + @Override + public boolean equals(Object o) { + if (!(o instanceof TimelineInteger)) return false; + TimelineInteger other = (TimelineInteger) o; + return value == other.value; + } + + @Override + public String toString() { + return Integer.toString(value); + } +} diff --git a/metadata/src/main/java/org/apache/kafka/timeline/TimelineLong.java b/metadata/src/main/java/org/apache/kafka/timeline/TimelineLong.java new file mode 100644 index 0000000..9b401db --- /dev/null +++ b/metadata/src/main/java/org/apache/kafka/timeline/TimelineLong.java @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.timeline; + +import java.util.Iterator; + + +/** + * This is a mutable long which can be snapshotted. + * + * This class requires external synchronization. 
+ */ +public class TimelineLong implements Revertable { + public static final long INIT = 0; + + static class LongContainer implements Delta { + private long value = INIT; + + long value() { + return value; + } + + void setValue(long value) { + this.value = value; + } + + @Override + public void mergeFrom(long destinationEpoch, Delta delta) { + // Nothing to do + } + } + + private final SnapshotRegistry snapshotRegistry; + private long value; + + public TimelineLong(SnapshotRegistry snapshotRegistry) { + this.snapshotRegistry = snapshotRegistry; + this.value = INIT; + + snapshotRegistry.register(this); + } + + public long get() { + return value; + } + + public long get(long epoch) { + if (epoch == SnapshotRegistry.LATEST_EPOCH) return value; + Iterator iterator = snapshotRegistry.iterator(epoch); + while (iterator.hasNext()) { + Snapshot snapshot = iterator.next(); + LongContainer container = snapshot.getDelta(TimelineLong.this); + if (container != null) return container.value(); + } + return value; + } + + public void set(long newValue) { + Iterator iterator = snapshotRegistry.reverseIterator(); + if (iterator.hasNext()) { + Snapshot snapshot = iterator.next(); + LongContainer prevContainer = snapshot.getDelta(TimelineLong.this); + if (prevContainer == null) { + prevContainer = new LongContainer(); + snapshot.setDelta(TimelineLong.this, prevContainer); + prevContainer.setValue(value); + } + } + this.value = newValue; + } + + public void increment() { + set(get() + 1L); + } + + public void decrement() { + set(get() - 1L); + } + + @SuppressWarnings("unchecked") + @Override + public void executeRevert(long targetEpoch, Delta delta) { + LongContainer container = (LongContainer) delta; + this.value = container.value(); + } + + @Override + public void reset() { + set(INIT); + } + + @Override + public int hashCode() { + return ((int) value) ^ (int) (value >>> 32); + } + + @Override + public boolean equals(Object o) { + if (!(o instanceof TimelineLong)) return false; + TimelineLong other = (TimelineLong) o; + return value == other.value; + } + + @Override + public String toString() { + return Long.toString(value); + } +} diff --git a/metadata/src/main/resources/common/metadata/AccessControlRecord.json b/metadata/src/main/resources/common/metadata/AccessControlRecord.json new file mode 100644 index 0000000..deef33c --- /dev/null +++ b/metadata/src/main/resources/common/metadata/AccessControlRecord.json @@ -0,0 +1,38 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
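To make the intended use of the timeline types above concrete, here is a minimal, hypothetical sketch. Only the constructor, `register()`, `set()`/`get()`, and the epoch-based `get(epoch)` are visible in this hunk, so the `getOrCreateSnapshot()` and `revertToSnapshot()` calls below are assumed names for the snapshot-taking side of `SnapshotRegistry`:

```java
import org.apache.kafka.common.utils.LogContext;
import org.apache.kafka.timeline.SnapshotRegistry;
import org.apache.kafka.timeline.TimelineInteger;

public class TimelineIntegerSketch {
    public static void main(String[] args) {
        SnapshotRegistry registry = new SnapshotRegistry(new LogContext());
        TimelineInteger counter = new TimelineInteger(registry); // registers itself

        counter.set(3);
        registry.getOrCreateSnapshot(100);   // assumed API: capture state at epoch 100
        counter.increment();                 // latest value becomes 4

        System.out.println(counter.get());     // 4 (latest)
        System.out.println(counter.get(100));  // 3 (as of the snapshot)

        registry.revertToSnapshot(100);      // assumed API: roll back to epoch 100
        System.out.println(counter.get());     // 3 again
    }
}
```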
+ +{ + "apiKey": 6, + "type": "metadata", + "name": "AccessControlRecord", + "validVersions": "0", + "flexibleVersions": "0+", + "fields": [ + { "name": "ResourceType", "type": "int8", "versions": "0+", + "about": "The resource type" }, + { "name": "ResourceName", "type": "string", "versions": "0+", "nullableVersions": "0+", + "about": "The resource name, or null if this is for the default resource." }, + { "name": "PatternType", "type": "int8", "versions": "0+", + "about": "The pattern type (literal, prefixed, etc.)" }, + { "name": "Principal", "type": "string", "versions": "0+", + "about": "The principal name." }, + { "name": "Host", "type": "string", "versions": "0+", + "about": "The host." }, + { "name": "Operation", "type": "int8", "versions": "0+", + "about": "The operation type." }, + { "name": "PermissionType", "type": "int8", "versions": "0+", + "about": "The permission type (allow, deny)." } + ] +} diff --git a/metadata/src/main/resources/common/metadata/BrokerRegistrationChangeRecord.json b/metadata/src/main/resources/common/metadata/BrokerRegistrationChangeRecord.json new file mode 100644 index 0000000..152508c --- /dev/null +++ b/metadata/src/main/resources/common/metadata/BrokerRegistrationChangeRecord.json @@ -0,0 +1,30 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 17, + "type": "metadata", + "name": "BrokerRegistrationChangeRecord", + "validVersions": "0", + "flexibleVersions": "0+", + "fields": [ + { "name": "BrokerId", "type": "int32", "versions": "0+", "entityType": "brokerId", + "about": "The broker id." }, + { "name": "BrokerEpoch", "type": "int64", "versions": "0+", + "about": "The broker epoch assigned by the controller." }, + { "name": "Fenced", "type": "int8", "versions": "0+", "taggedVersions": "0+", "tag": 0, + "about": "-1 if the broker has been unfenced, 0 if no change, 1 if the broker has been fenced." } + ] +} diff --git a/metadata/src/main/resources/common/metadata/ClientQuotaRecord.json b/metadata/src/main/resources/common/metadata/ClientQuotaRecord.json new file mode 100644 index 0000000..9bb7aca --- /dev/null +++ b/metadata/src/main/resources/common/metadata/ClientQuotaRecord.json @@ -0,0 +1,37 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 14, + "type": "metadata", + "name": "ClientQuotaRecord", + "validVersions": "0", + "flexibleVersions": "0+", + "fields": [ + { "name": "Entity", "type": "[]EntityData", "versions": "0+", + "about": "The quota entity to alter.", "fields": [ + { "name": "EntityType", "type": "string", "versions": "0+", + "about": "The entity type." }, + { "name": "EntityName", "type": "string", "versions": "0+", "nullableVersions": "0+", + "about": "The name of the entity, or null if the default." } + ]}, + { "name": "Key", "type": "string", "versions": "0+", + "about": "The quota configuration key." }, + { "name": "Value", "type": "float64", "versions": "0+", + "about": "The value to set, otherwise ignored if the value is to be removed." }, + { "name": "Remove", "type": "bool", "versions": "0+", + "about": "Whether the quota configuration value should be removed, otherwise set." } + ] +} diff --git a/metadata/src/main/resources/common/metadata/ConfigRecord.json b/metadata/src/main/resources/common/metadata/ConfigRecord.json new file mode 100644 index 0000000..a0f0c3a --- /dev/null +++ b/metadata/src/main/resources/common/metadata/ConfigRecord.json @@ -0,0 +1,32 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 4, + "type": "metadata", + "name": "ConfigRecord", + "validVersions": "0", + "flexibleVersions": "0+", + "fields": [ + { "name": "ResourceType", "type": "int8", "versions": "0+", + "about": "The type of resource this configuration applies to." }, + { "name": "ResourceName", "type": "string", "versions": "0+", + "about": "The name of the resource this configuration applies to." }, + { "name": "Name", "type": "string", "versions": "0+", + "about": "The name of the configuration key." }, + { "name": "Value", "type": "string", "versions": "0+", "nullableVersions": "0+", + "about": "The value of the configuration, or null if the it should be deleted." } + ] +} diff --git a/metadata/src/main/resources/common/metadata/DelegationTokenRecord.json b/metadata/src/main/resources/common/metadata/DelegationTokenRecord.json new file mode 100644 index 0000000..d07c293 --- /dev/null +++ b/metadata/src/main/resources/common/metadata/DelegationTokenRecord.json @@ -0,0 +1,36 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. 
See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 10, + "type": "metadata", + "name": "DelegationTokenRecord", + "validVersions": "0", + "flexibleVersions": "0+", + "fields": [ + { "name": "Owner", "type": "string", "versions": "0+", + "about": "The delegation token owner." }, + { "name": "Renewers", "type": "[]string", "versions": "0+", + "about": "The principals which have renewed this token." }, + { "name": "IssueTimestamp", "type": "int64", "versions": "0+", + "about": "The time at which this timestamp was issued." }, + { "name": "MaxTimestamp", "type": "int64", "versions": "0+", + "about": "The time at which this token cannot be renewed any more." }, + { "name": "ExpirationTimestamp", "type": "int64", "versions": "0+", + "about": "The next time at which this token must be renewed." }, + { "name": "TokenId", "type": "string", "versions": "0+", + "about": "The token id." } + ] +} diff --git a/metadata/src/main/resources/common/metadata/FeatureLevelRecord.json b/metadata/src/main/resources/common/metadata/FeatureLevelRecord.json new file mode 100644 index 0000000..ac112f1 --- /dev/null +++ b/metadata/src/main/resources/common/metadata/FeatureLevelRecord.json @@ -0,0 +1,30 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 12, + "type": "metadata", + "name": "FeatureLevelRecord", + "validVersions": "0", + "flexibleVersions": "0+", + "fields": [ + { "name": "Name", "type": "string", "versions": "0+", + "about": "The feature name." }, + { "name": "MinFeatureLevel", "type": "int16", "versions": "0+", + "about": "The current finalized minimum feature level of this feature for the cluster." }, + { "name": "MaxFeatureLevel", "type": "int16", "versions": "0+", + "about": "The current finalized maximum feature level of this feature for the cluster." 
} + ] +} diff --git a/metadata/src/main/resources/common/metadata/FenceBrokerRecord.json b/metadata/src/main/resources/common/metadata/FenceBrokerRecord.json new file mode 100644 index 0000000..0cd29be --- /dev/null +++ b/metadata/src/main/resources/common/metadata/FenceBrokerRecord.json @@ -0,0 +1,28 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 7, + "type": "metadata", + "name": "FenceBrokerRecord", + "validVersions": "0", + "flexibleVersions": "0+", + "fields": [ + { "name": "Id", "type": "int32", "versions": "0+", "entityType": "brokerId", + "about": "The broker ID to fence. It will be removed from all ISRs." }, + { "name": "Epoch", "type": "int64", "versions": "0+", + "about": "The epoch of the broker to fence." } + ] +} diff --git a/metadata/src/main/resources/common/metadata/PartitionChangeRecord.json b/metadata/src/main/resources/common/metadata/PartitionChangeRecord.json new file mode 100644 index 0000000..7afaa42 --- /dev/null +++ b/metadata/src/main/resources/common/metadata/PartitionChangeRecord.json @@ -0,0 +1,43 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 5, + "type": "metadata", + "name": "PartitionChangeRecord", + "validVersions": "0", + "flexibleVersions": "0+", + "fields": [ + { "name": "PartitionId", "type": "int32", "versions": "0+", "default": "-1", + "about": "The partition id." }, + { "name": "TopicId", "type": "uuid", "versions": "0+", + "about": "The unique ID of this topic." }, + { "name": "Isr", "type": "[]int32", "default": "null", "entityType": "brokerId", + "versions": "0+", "nullableVersions": "0+", "taggedVersions": "0+", "tag": 0, + "about": "null if the ISR didn't change; the new in-sync replicas otherwise." }, + { "name": "Leader", "type": "int32", "default": "-2", "entityType": "brokerId", + "versions": "0+", "taggedVersions": "0+", "tag": 1, + "about": "-1 if there is now no leader; -2 if the leader didn't change; the new leader otherwise." 
}, + { "name": "Replicas", "type": "[]int32", "default": "null", "entityType": "brokerId", + "versions": "0+", "nullableVersions": "0+", "taggedVersions": "0+", "tag": 2, + "about": "null if the replicas didn't change; the new replicas otherwise." }, + { "name": "RemovingReplicas", "type": "[]int32", "default": "null", "entityType": "brokerId", + "versions": "0+", "nullableVersions": "0+", "taggedVersions": "0+", "tag": 3, + "about": "null if the removing replicas didn't change; the new removing replicas otherwise." }, + { "name": "AddingReplicas", "type": "[]int32", "default": "null", "entityType": "brokerId", + "versions": "0+", "nullableVersions": "0+", "taggedVersions": "0+", "tag": 4, + "about": "null if the adding replicas didn't change; the new adding replicas otherwise." } + ] +} diff --git a/metadata/src/main/resources/common/metadata/PartitionRecord.json b/metadata/src/main/resources/common/metadata/PartitionRecord.json new file mode 100644 index 0000000..66a13e2 --- /dev/null +++ b/metadata/src/main/resources/common/metadata/PartitionRecord.json @@ -0,0 +1,42 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 3, + "type": "metadata", + "name": "PartitionRecord", + "validVersions": "0", + "flexibleVersions": "0+", + "fields": [ + { "name": "PartitionId", "type": "int32", "versions": "0+", "default": "-1", + "about": "The partition id." }, + { "name": "TopicId", "type": "uuid", "versions": "0+", + "about": "The unique ID of this topic." }, + { "name": "Replicas", "type": "[]int32", "versions": "0+", "entityType": "brokerId", + "about": "The replicas of this partition, sorted by preferred order." }, + { "name": "Isr", "type": "[]int32", "versions": "0+", + "about": "The in-sync replicas of this partition" }, + { "name": "RemovingReplicas", "type": "[]int32", "versions": "0+", "entityType": "brokerId", + "about": "The replicas that we are in the process of removing." }, + { "name": "AddingReplicas", "type": "[]int32", "versions": "0+", "entityType": "brokerId", + "about": "The replicas that we are in the process of adding." }, + { "name": "Leader", "type": "int32", "versions": "0+", "default": "-1", "entityType": "brokerId", + "about": "The lead replica, or -1 if there is no leader." }, + { "name": "LeaderEpoch", "type": "int32", "versions": "0+", "default": "-1", + "about": "The epoch of the partition leader." }, + { "name": "PartitionEpoch", "type": "int32", "versions": "0+", "default": "-1", + "about": "An epoch that gets incremented each time we change anything in the partition." 
} + ] +} diff --git a/metadata/src/main/resources/common/metadata/ProducerIdsRecord.json b/metadata/src/main/resources/common/metadata/ProducerIdsRecord.json new file mode 100644 index 0000000..0467871 --- /dev/null +++ b/metadata/src/main/resources/common/metadata/ProducerIdsRecord.json @@ -0,0 +1,30 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 15, + "type": "metadata", + "name": "ProducerIdsRecord", + "validVersions": "0", + "flexibleVersions": "0+", + "fields": [ + { "name": "BrokerId", "type": "int32", "versions": "0+", "entityType": "brokerId", + "about": "The ID of the requesting broker" }, + { "name": "BrokerEpoch", "type": "int64", "versions": "0+", "default": "-1", + "about": "The epoch of the requesting broker" }, + { "name": "ProducerIdsEnd", "type": "int64", "versions": "0+", + "about": "The highest producer ID that has been generated"} + ] +} diff --git a/metadata/src/main/resources/common/metadata/RegisterBrokerRecord.json b/metadata/src/main/resources/common/metadata/RegisterBrokerRecord.json new file mode 100644 index 0000000..a0e7af2 --- /dev/null +++ b/metadata/src/main/resources/common/metadata/RegisterBrokerRecord.json @@ -0,0 +1,54 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 0, + "type": "metadata", + "name": "RegisterBrokerRecord", + "validVersions": "0", + "flexibleVersions": "0+", + "fields": [ + { "name": "BrokerId", "type": "int32", "versions": "0+", "entityType": "brokerId", + "about": "The broker id." }, + { "name": "IncarnationId", "type": "uuid", "versions": "0+", + "about": "The incarnation ID of the broker process" }, + { "name": "BrokerEpoch", "type": "int64", "versions": "0+", + "about": "The broker epoch assigned by the controller." }, + { "name": "EndPoints", "type": "[]BrokerEndpoint", "versions": "0+", + "about": "The endpoints that can be used to communicate with this broker.", "fields": [ + { "name": "Name", "type": "string", "versions": "0+", "mapKey": true, + "about": "The name of the endpoint." 
}, + { "name": "Host", "type": "string", "versions": "0+", + "about": "The hostname." }, + { "name": "Port", "type": "uint16", "versions": "0+", + "about": "The port." }, + { "name": "SecurityProtocol", "type": "int16", "versions": "0+", + "about": "The security protocol." } + ]}, + { "name": "Features", "type": "[]BrokerFeature", + "about": "The features on this broker", "versions": "0+", "fields": [ + { "name": "Name", "type": "string", "versions": "0+", "mapKey": true, + "about": "The feature name." }, + { "name": "MinSupportedVersion", "type": "int16", "versions": "0+", + "about": "The minimum supported feature level." }, + { "name": "MaxSupportedVersion", "type": "int16", "versions": "0+", + "about": "The maximum supported feature level." } + ]}, + { "name": "Rack", "type": "string", "versions": "0+", "nullableVersions": "0+", + "about": "The broker rack." }, + { "name": "Fenced", "type": "bool", "versions": "0+", "default": "true", + "about": "True if the broker is fenced." } + ] +} diff --git a/metadata/src/main/resources/common/metadata/RemoveFeatureLevelRecord.json b/metadata/src/main/resources/common/metadata/RemoveFeatureLevelRecord.json new file mode 100644 index 0000000..6ed7161 --- /dev/null +++ b/metadata/src/main/resources/common/metadata/RemoveFeatureLevelRecord.json @@ -0,0 +1,26 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 16, + "type": "metadata", + "name": "RemoveFeatureLevelRecord", + "validVersions": "0", + "flexibleVersions": "0+", + "fields": [ + { "name": "Name", "type": "string", "versions": "0+", + "about": "The feature name." } + ] +} diff --git a/metadata/src/main/resources/common/metadata/RemoveTopicRecord.json b/metadata/src/main/resources/common/metadata/RemoveTopicRecord.json new file mode 100644 index 0000000..be290e3 --- /dev/null +++ b/metadata/src/main/resources/common/metadata/RemoveTopicRecord.json @@ -0,0 +1,26 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
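Each of these JSON files is a schema for Kafka's message generator, which turns it into a Java class with a fluent setter per field; the controller tests later in this patch construct `ClientQuotaRecord` instances that way. As an illustration only (the exact generated API and the feature name used here are assumptions), a `RemoveFeatureLevelRecord`, defined just above, could be wrapped for the metadata log at schema version 0 like this:

```java
import org.apache.kafka.common.metadata.RemoveFeatureLevelRecord;
import org.apache.kafka.server.common.ApiMessageAndVersion;

public class RemoveFeatureLevelSketch {
    public static void main(String[] args) {
        // "Name" is the only field in the RemoveFeatureLevelRecord schema above;
        // the feature name is a placeholder, not one defined by this patch.
        ApiMessageAndVersion record = new ApiMessageAndVersion(
            new RemoveFeatureLevelRecord().setName("example.feature"),
            (short) 0);  // matches "validVersions": "0"
        System.out.println(record);
    }
}
```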
+ +{ + "apiKey": 9, + "type": "metadata", + "name": "RemoveTopicRecord", + "validVersions": "0", + "flexibleVersions": "0+", + "fields": [ + { "name": "TopicId", "type": "uuid", "versions": "0+", + "about": "The topic to remove. All associated partitions will be removed as well." } + ] +} diff --git a/metadata/src/main/resources/common/metadata/TopicRecord.json b/metadata/src/main/resources/common/metadata/TopicRecord.json new file mode 100644 index 0000000..6fa5a05 --- /dev/null +++ b/metadata/src/main/resources/common/metadata/TopicRecord.json @@ -0,0 +1,28 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 2, + "type": "metadata", + "name": "TopicRecord", + "validVersions": "0", + "flexibleVersions": "0+", + "fields": [ + { "name": "Name", "type": "string", "versions": "0+", "entityType": "topicName", + "about": "The topic name." }, + { "name": "TopicId", "type": "uuid", "versions": "0+", + "about": "The unique ID of this topic." } + ] +} diff --git a/metadata/src/main/resources/common/metadata/UnfenceBrokerRecord.json b/metadata/src/main/resources/common/metadata/UnfenceBrokerRecord.json new file mode 100644 index 0000000..92770a6 --- /dev/null +++ b/metadata/src/main/resources/common/metadata/UnfenceBrokerRecord.json @@ -0,0 +1,28 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 8, + "type": "metadata", + "name": "UnfenceBrokerRecord", + "validVersions": "0", + "flexibleVersions": "0+", + "fields": [ + { "name": "Id", "type": "int32", "versions": "0+", "entityType": "brokerId", + "about": "The broker ID to unfence." }, + { "name": "Epoch", "type": "int64", "versions": "0+", + "about": "The epoch of the broker to unfence." 
} + ] +} diff --git a/metadata/src/main/resources/common/metadata/UnregisterBrokerRecord.json b/metadata/src/main/resources/common/metadata/UnregisterBrokerRecord.json new file mode 100644 index 0000000..358d88a --- /dev/null +++ b/metadata/src/main/resources/common/metadata/UnregisterBrokerRecord.json @@ -0,0 +1,28 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 1, + "type": "metadata", + "name": "UnregisterBrokerRecord", + "validVersions": "0", + "flexibleVersions": "0+", + "fields": [ + { "name": "BrokerId", "type": "int32", "versions": "0+", "entityType": "brokerId", + "about": "The broker id." }, + { "name": "BrokerEpoch", "type": "int64", "versions": "0+", + "about": "The broker epoch." } + ] +} diff --git a/metadata/src/main/resources/common/metadata/UserScramCredentialRecord.json b/metadata/src/main/resources/common/metadata/UserScramCredentialRecord.json new file mode 100644 index 0000000..2f106ff --- /dev/null +++ b/metadata/src/main/resources/common/metadata/UserScramCredentialRecord.json @@ -0,0 +1,36 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 11, + "type": "metadata", + "name": "UserScramCredentialRecord", + "validVersions": "0", + "flexibleVersions": "0+", + "fields": [ + { "name": "Name", "type": "string", "versions": "0+", + "about": "The user name." }, + { "name": "CredentialInfos", "type": "[]CredentialInfo", "versions": "0+", + "about": "The mechanism and related information associated with the user's SCRAM credential.", "fields": [ + { "name": "Mechanism", "type": "int8", "versions": "0+", + "about": "The SCRAM mechanism." }, + { "name": "Salt", "type": "bytes", "versions": "0+", + "about": "A random salt generated by the client." }, + { "name": "SaltedPassword", "type": "bytes", "versions": "0+", + "about": "The salted password." }, + { "name": "Iterations", "type": "int32", "versions": "0+", + "about": "The number of iterations used in the SCRAM credential." 
}]} + ] +} diff --git a/metadata/src/test/java/org/apache/kafka/controller/BrokerHeartbeatManagerTest.java b/metadata/src/test/java/org/apache/kafka/controller/BrokerHeartbeatManagerTest.java new file mode 100644 index 0000000..c5c46ab --- /dev/null +++ b/metadata/src/test/java/org/apache/kafka/controller/BrokerHeartbeatManagerTest.java @@ -0,0 +1,302 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.controller; + +import java.util.Collections; +import java.util.HashSet; +import java.util.Iterator; +import java.util.Optional; +import java.util.Set; +import java.util.TreeSet; +import org.apache.kafka.common.message.BrokerHeartbeatRequestData; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.controller.BrokerHeartbeatManager.BrokerHeartbeatState; +import org.apache.kafka.controller.BrokerHeartbeatManager.BrokerHeartbeatStateIterator; +import org.apache.kafka.controller.BrokerHeartbeatManager.BrokerHeartbeatStateList; +import org.apache.kafka.controller.BrokerHeartbeatManager.UsableBrokerIterator; +import org.apache.kafka.metadata.UsableBroker; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; + +import static org.apache.kafka.controller.BrokerControlState.CONTROLLED_SHUTDOWN; +import static org.apache.kafka.controller.BrokerControlState.FENCED; +import static org.apache.kafka.controller.BrokerControlState.SHUTDOWN_NOW; +import static org.apache.kafka.controller.BrokerControlState.UNFENCED; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + + +@Timeout(40) +public class BrokerHeartbeatManagerTest { + private static BrokerHeartbeatManager newBrokerHeartbeatManager() { + LogContext logContext = new LogContext(); + MockTime time = new MockTime(0, 1_000_000, 0); + return new BrokerHeartbeatManager(logContext, time, 10_000_000); + } + + @Test + public void testHasValidSession() { + BrokerHeartbeatManager manager = newBrokerHeartbeatManager(); + MockTime time = (MockTime) manager.time(); + assertFalse(manager.hasValidSession(0)); + manager.touch(0, false, 0); + time.sleep(5); + manager.touch(1, false, 0); + manager.touch(2, false, 0); + assertTrue(manager.hasValidSession(0)); + assertTrue(manager.hasValidSession(1)); + assertTrue(manager.hasValidSession(2)); + assertFalse(manager.hasValidSession(3)); + time.sleep(6); + assertFalse(manager.hasValidSession(0)); + assertTrue(manager.hasValidSession(1)); + assertTrue(manager.hasValidSession(2)); + assertFalse(manager.hasValidSession(3)); + manager.remove(2); + 
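+        // A note on units for the assertions above: newBrokerHeartbeatManager() passes a
+        // 10_000_000 ns (10 ms) session timeout, while MockTime.sleep() advances the mock
+        // clock in milliseconds. Broker 0 was last touched 11 ms ago (5 ms + 6 ms), so its
+        // session has expired, while brokers 1 and 2 (touched 6 ms ago) are still valid.
+        // Removing a broker, as done here, discards its heartbeat state immediately,
+        // regardless of the timeout.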
assertFalse(manager.hasValidSession(2)); + manager.remove(1); + assertFalse(manager.hasValidSession(1)); + } + + @Test + public void testFindOneStaleBroker() { + BrokerHeartbeatManager manager = newBrokerHeartbeatManager(); + MockTime time = (MockTime) manager.time(); + assertFalse(manager.hasValidSession(0)); + manager.touch(0, false, 0); + time.sleep(5); + manager.touch(1, false, 0); + time.sleep(1); + manager.touch(2, false, 0); + + Iterator iter = manager.unfenced().iterator(); + assertEquals(0, iter.next().id()); + assertEquals(1, iter.next().id()); + assertEquals(2, iter.next().id()); + assertFalse(iter.hasNext()); + assertEquals(Optional.empty(), manager.findOneStaleBroker()); + + time.sleep(5); + assertEquals(Optional.of(0), manager.findOneStaleBroker()); + manager.fence(0); + assertEquals(Optional.empty(), manager.findOneStaleBroker()); + iter = manager.unfenced().iterator(); + assertEquals(1, iter.next().id()); + assertEquals(2, iter.next().id()); + assertFalse(iter.hasNext()); + + time.sleep(20); + assertEquals(Optional.of(1), manager.findOneStaleBroker()); + manager.fence(1); + assertEquals(Optional.of(2), manager.findOneStaleBroker()); + manager.fence(2); + + assertEquals(Optional.empty(), manager.findOneStaleBroker()); + iter = manager.unfenced().iterator(); + assertFalse(iter.hasNext()); + } + + @Test + public void testNextCheckTimeNs() { + BrokerHeartbeatManager manager = newBrokerHeartbeatManager(); + MockTime time = (MockTime) manager.time(); + assertEquals(Long.MAX_VALUE, manager.nextCheckTimeNs()); + manager.touch(0, false, 0); + time.sleep(2); + manager.touch(1, false, 0); + time.sleep(1); + manager.touch(2, false, 0); + time.sleep(1); + manager.touch(3, false, 0); + assertEquals(Optional.empty(), manager.findOneStaleBroker()); + assertEquals(10_000_000, manager.nextCheckTimeNs()); + time.sleep(7); + assertEquals(10_000_000, manager.nextCheckTimeNs()); + assertEquals(Optional.of(0), manager.findOneStaleBroker()); + manager.fence(0); + assertEquals(12_000_000, manager.nextCheckTimeNs()); + + time.sleep(3); + assertEquals(Optional.of(1), manager.findOneStaleBroker()); + manager.fence(1); + assertEquals(Optional.of(2), manager.findOneStaleBroker()); + manager.fence(2); + + assertEquals(14_000_000, manager.nextCheckTimeNs()); + } + + @Test + public void testMetadataOffsetComparator() { + TreeSet set = + new TreeSet<>(BrokerHeartbeatManager.MetadataOffsetComparator.INSTANCE); + BrokerHeartbeatState broker1 = new BrokerHeartbeatState(1); + BrokerHeartbeatState broker2 = new BrokerHeartbeatState(2); + BrokerHeartbeatState broker3 = new BrokerHeartbeatState(3); + set.add(broker1); + set.add(broker2); + set.add(broker3); + Iterator iterator = set.iterator(); + assertEquals(broker1, iterator.next()); + assertEquals(broker2, iterator.next()); + assertEquals(broker3, iterator.next()); + assertFalse(iterator.hasNext()); + assertTrue(set.remove(broker1)); + assertTrue(set.remove(broker2)); + assertTrue(set.remove(broker3)); + assertTrue(set.isEmpty()); + broker1.metadataOffset = 800; + broker2.metadataOffset = 400; + broker3.metadataOffset = 100; + set.add(broker1); + set.add(broker2); + set.add(broker3); + iterator = set.iterator(); + assertEquals(broker3, iterator.next()); + assertEquals(broker2, iterator.next()); + assertEquals(broker1, iterator.next()); + assertFalse(iterator.hasNext()); + } + + private static Set usableBrokersToSet(BrokerHeartbeatManager manager) { + Set brokers = new HashSet<>(); + for (Iterator iterator = new UsableBrokerIterator( + 
manager.brokers().iterator(), + id -> id % 2 == 0 ? Optional.of("rack1") : Optional.of("rack2")); + iterator.hasNext(); ) { + brokers.add(iterator.next()); + } + return brokers; + } + + @Test + public void testUsableBrokerIterator() { + BrokerHeartbeatManager manager = newBrokerHeartbeatManager(); + assertEquals(Collections.emptySet(), usableBrokersToSet(manager)); + manager.touch(0, false, 100); + manager.touch(1, false, 100); + manager.touch(2, false, 98); + manager.touch(3, false, 100); + manager.touch(4, true, 100); + assertEquals(98L, manager.lowestActiveOffset()); + Set expected = new HashSet<>(); + expected.add(new UsableBroker(0, Optional.of("rack1"), false)); + expected.add(new UsableBroker(1, Optional.of("rack2"), false)); + expected.add(new UsableBroker(2, Optional.of("rack1"), false)); + expected.add(new UsableBroker(3, Optional.of("rack2"), false)); + expected.add(new UsableBroker(4, Optional.of("rack1"), true)); + assertEquals(expected, usableBrokersToSet(manager)); + manager.updateControlledShutdownOffset(2, 0); + assertEquals(100L, manager.lowestActiveOffset()); + assertThrows(RuntimeException.class, + () -> manager.updateControlledShutdownOffset(4, 0)); + manager.touch(4, false, 100); + manager.updateControlledShutdownOffset(4, 0); + expected.remove(new UsableBroker(2, Optional.of("rack1"), false)); + expected.remove(new UsableBroker(4, Optional.of("rack1"), true)); + assertEquals(expected, usableBrokersToSet(manager)); + } + + @Test + public void testBrokerHeartbeatStateList() { + BrokerHeartbeatStateList list = new BrokerHeartbeatStateList(); + assertEquals(null, list.first()); + BrokerHeartbeatStateIterator iterator = list.iterator(); + assertFalse(iterator.hasNext()); + BrokerHeartbeatState broker0 = new BrokerHeartbeatState(0); + broker0.lastContactNs = 200; + BrokerHeartbeatState broker1 = new BrokerHeartbeatState(1); + broker1.lastContactNs = 100; + BrokerHeartbeatState broker2 = new BrokerHeartbeatState(2); + broker2.lastContactNs = 50; + BrokerHeartbeatState broker3 = new BrokerHeartbeatState(3); + broker3.lastContactNs = 150; + list.add(broker0); + list.add(broker1); + list.add(broker2); + list.add(broker3); + assertEquals(broker2, list.first()); + iterator = list.iterator(); + assertEquals(broker2, iterator.next()); + assertEquals(broker1, iterator.next()); + assertEquals(broker3, iterator.next()); + assertEquals(broker0, iterator.next()); + assertFalse(iterator.hasNext()); + list.remove(broker1); + iterator = list.iterator(); + assertEquals(broker2, iterator.next()); + assertEquals(broker3, iterator.next()); + assertEquals(broker0, iterator.next()); + assertFalse(iterator.hasNext()); + } + + @Test + public void testCalculateNextBrokerState() { + BrokerHeartbeatManager manager = newBrokerHeartbeatManager(); + manager.touch(0, true, 100); + manager.touch(1, false, 98); + manager.touch(2, false, 100); + manager.touch(3, false, 100); + manager.touch(4, true, 100); + manager.touch(5, false, 99); + manager.updateControlledShutdownOffset(5, 99); + + assertEquals(98L, manager.lowestActiveOffset()); + + assertEquals(new BrokerControlStates(FENCED, SHUTDOWN_NOW), + manager.calculateNextBrokerState(0, + new BrokerHeartbeatRequestData().setWantShutDown(true), 100, () -> false)); + assertEquals(new BrokerControlStates(FENCED, UNFENCED), + manager.calculateNextBrokerState(0, + new BrokerHeartbeatRequestData().setWantFence(false). 
+ setCurrentMetadataOffset(100), 100, () -> false)); + assertEquals(new BrokerControlStates(FENCED, FENCED), + manager.calculateNextBrokerState(0, + new BrokerHeartbeatRequestData().setWantFence(false). + setCurrentMetadataOffset(50), 100, () -> false)); + assertEquals(new BrokerControlStates(FENCED, FENCED), + manager.calculateNextBrokerState(0, + new BrokerHeartbeatRequestData().setWantFence(true), 100, () -> false)); + + assertEquals(new BrokerControlStates(UNFENCED, CONTROLLED_SHUTDOWN), + manager.calculateNextBrokerState(1, + new BrokerHeartbeatRequestData().setWantShutDown(true), 100, () -> true)); + assertEquals(new BrokerControlStates(UNFENCED, SHUTDOWN_NOW), + manager.calculateNextBrokerState(1, + new BrokerHeartbeatRequestData().setWantShutDown(true), 100, () -> false)); + assertEquals(new BrokerControlStates(UNFENCED, UNFENCED), + manager.calculateNextBrokerState(1, + new BrokerHeartbeatRequestData().setWantFence(false), 100, () -> false)); + + assertEquals(new BrokerControlStates(CONTROLLED_SHUTDOWN, CONTROLLED_SHUTDOWN), + manager.calculateNextBrokerState(5, + new BrokerHeartbeatRequestData().setWantShutDown(true), 100, () -> true)); + assertEquals(new BrokerControlStates(CONTROLLED_SHUTDOWN, CONTROLLED_SHUTDOWN), + manager.calculateNextBrokerState(5, + new BrokerHeartbeatRequestData().setWantShutDown(true), 100, () -> false)); + manager.fence(1); + assertEquals(new BrokerControlStates(CONTROLLED_SHUTDOWN, SHUTDOWN_NOW), + manager.calculateNextBrokerState(5, + new BrokerHeartbeatRequestData().setWantShutDown(true), 100, () -> false)); + assertEquals(new BrokerControlStates(CONTROLLED_SHUTDOWN, CONTROLLED_SHUTDOWN), + manager.calculateNextBrokerState(5, + new BrokerHeartbeatRequestData().setWantShutDown(true), 100, () -> true)); + } +} diff --git a/metadata/src/test/java/org/apache/kafka/controller/BrokersToIsrsTest.java b/metadata/src/test/java/org/apache/kafka/controller/BrokersToIsrsTest.java new file mode 100644 index 0000000..6510ee5 --- /dev/null +++ b/metadata/src/test/java/org/apache/kafka/controller/BrokersToIsrsTest.java @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.controller; + +import org.apache.kafka.common.Uuid; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.controller.BrokersToIsrs.PartitionsOnReplicaIterator; +import org.apache.kafka.controller.BrokersToIsrs.TopicIdPartition; +import org.apache.kafka.timeline.SnapshotRegistry; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; + +import java.util.HashSet; +import java.util.Set; + +import static org.junit.jupiter.api.Assertions.assertEquals; + + +@Timeout(40) +public class BrokersToIsrsTest { + private static final Uuid[] UUIDS = new Uuid[] { + Uuid.fromString("z5XgH_fQSAK3-RYoF2ymgw"), + Uuid.fromString("U52uRe20RsGI0RvpcTx33Q") + }; + + private static Set toSet(TopicIdPartition... partitions) { + HashSet set = new HashSet<>(); + for (TopicIdPartition partition : partitions) { + set.add(partition); + } + return set; + } + + private static Set toSet(PartitionsOnReplicaIterator iterator) { + HashSet set = new HashSet<>(); + while (iterator.hasNext()) { + set.add(iterator.next()); + } + return set; + } + + @Test + public void testIterator() { + SnapshotRegistry snapshotRegistry = new SnapshotRegistry(new LogContext()); + BrokersToIsrs brokersToIsrs = new BrokersToIsrs(snapshotRegistry); + assertEquals(toSet(), toSet(brokersToIsrs.iterator(1, false))); + brokersToIsrs.update(UUIDS[0], 0, null, new int[] {1, 2, 3}, -1, 1); + brokersToIsrs.update(UUIDS[1], 1, null, new int[] {2, 3, 4}, -1, 4); + assertEquals(toSet(new TopicIdPartition(UUIDS[0], 0)), + toSet(brokersToIsrs.iterator(1, false))); + assertEquals(toSet(new TopicIdPartition(UUIDS[0], 0), + new TopicIdPartition(UUIDS[1], 1)), + toSet(brokersToIsrs.iterator(2, false))); + assertEquals(toSet(new TopicIdPartition(UUIDS[1], 1)), + toSet(brokersToIsrs.iterator(4, false))); + assertEquals(toSet(), toSet(brokersToIsrs.iterator(5, false))); + brokersToIsrs.update(UUIDS[1], 2, null, new int[] {3, 2, 1}, -1, 3); + assertEquals(toSet(new TopicIdPartition(UUIDS[0], 0), + new TopicIdPartition(UUIDS[1], 1), + new TopicIdPartition(UUIDS[1], 2)), + toSet(brokersToIsrs.iterator(2, false))); + } + + @Test + public void testLeadersOnlyIterator() { + SnapshotRegistry snapshotRegistry = new SnapshotRegistry(new LogContext()); + BrokersToIsrs brokersToIsrs = new BrokersToIsrs(snapshotRegistry); + brokersToIsrs.update(UUIDS[0], 0, null, new int[]{1, 2, 3}, -1, 1); + brokersToIsrs.update(UUIDS[1], 1, null, new int[]{2, 3, 4}, -1, 4); + assertEquals(toSet(new TopicIdPartition(UUIDS[0], 0)), + toSet(brokersToIsrs.iterator(1, true))); + assertEquals(toSet(), toSet(brokersToIsrs.iterator(2, true))); + assertEquals(toSet(new TopicIdPartition(UUIDS[1], 1)), + toSet(brokersToIsrs.iterator(4, true))); + brokersToIsrs.update(UUIDS[0], 0, new int[]{1, 2, 3}, new int[]{1, 2, 3}, 1, 2); + assertEquals(toSet(), toSet(brokersToIsrs.iterator(1, true))); + assertEquals(toSet(new TopicIdPartition(UUIDS[0], 0)), + toSet(brokersToIsrs.iterator(2, true))); + } + + @Test + public void testNoLeader() { + SnapshotRegistry snapshotRegistry = new SnapshotRegistry(new LogContext()); + BrokersToIsrs brokersToIsrs = new BrokersToIsrs(snapshotRegistry); + brokersToIsrs.update(UUIDS[0], 2, null, new int[]{1, 2, 3}, -1, 3); + assertEquals(toSet(new TopicIdPartition(UUIDS[0], 2)), + toSet(brokersToIsrs.iterator(3, true))); + assertEquals(toSet(), toSet(brokersToIsrs.iterator(2, true))); + assertEquals(toSet(), toSet(brokersToIsrs.partitionsWithNoLeader())); + brokersToIsrs.update(UUIDS[0], 2, new int[]{1, 2, 3}, 
new int[]{1, 2, 3}, 3, -1); + assertEquals(toSet(new TopicIdPartition(UUIDS[0], 2)), + toSet(brokersToIsrs.partitionsWithNoLeader())); + } +} diff --git a/metadata/src/test/java/org/apache/kafka/controller/ClientQuotaControlManagerTest.java b/metadata/src/test/java/org/apache/kafka/controller/ClientQuotaControlManagerTest.java new file mode 100644 index 0000000..b915db3 --- /dev/null +++ b/metadata/src/test/java/org/apache/kafka/controller/ClientQuotaControlManagerTest.java @@ -0,0 +1,304 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.controller; + +import org.apache.kafka.common.config.internals.QuotaConfigs; +import org.apache.kafka.common.metadata.ClientQuotaRecord; +import org.apache.kafka.common.metadata.ClientQuotaRecord.EntityData; +import org.apache.kafka.common.protocol.Errors; +import org.apache.kafka.common.quota.ClientQuotaAlteration; +import org.apache.kafka.common.quota.ClientQuotaEntity; +import org.apache.kafka.common.requests.ApiError; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.metadata.RecordTestUtils; +import org.apache.kafka.server.common.ApiMessageAndVersion; +import org.apache.kafka.timeline.SnapshotRegistry; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.function.Consumer; +import java.util.stream.Collectors; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +@Timeout(value = 40) +public class ClientQuotaControlManagerTest { + + @Test + public void testInvalidEntityTypes() { + SnapshotRegistry snapshotRegistry = new SnapshotRegistry(new LogContext()); + ClientQuotaControlManager manager = new ClientQuotaControlManager(snapshotRegistry); + + // Unknown type "foo" + assertInvalidEntity(manager, entity("foo", "bar")); + + // Null type + assertInvalidEntity(manager, entity(null, "null")); + + // Valid + unknown combo + assertInvalidEntity(manager, entity(ClientQuotaEntity.USER, "user-1", "foo", "bar")); + assertInvalidEntity(manager, entity("foo", "bar", ClientQuotaEntity.IP, "1.2.3.4")); + + // Invalid combinations + assertInvalidEntity(manager, entity(ClientQuotaEntity.USER, "user-1", ClientQuotaEntity.IP, "1.2.3.4")); + assertInvalidEntity(manager, entity(ClientQuotaEntity.CLIENT_ID, "user-1", ClientQuotaEntity.IP, "1.2.3.4")); + + // Empty + assertInvalidEntity(manager, new ClientQuotaEntity(Collections.emptyMap())); + } + + private void assertInvalidEntity(ClientQuotaControlManager 
manager, ClientQuotaEntity entity) { + assertInvalidQuota(manager, entity, quotas(QuotaConfigs.PRODUCER_BYTE_RATE_OVERRIDE_CONFIG, 10000.0)); + } + + @Test + public void testInvalidQuotaKeys() { + SnapshotRegistry snapshotRegistry = new SnapshotRegistry(new LogContext()); + ClientQuotaControlManager manager = new ClientQuotaControlManager(snapshotRegistry); + ClientQuotaEntity entity = entity(ClientQuotaEntity.USER, "user-1"); + + // Invalid + valid keys + assertInvalidQuota(manager, entity, quotas("not.a.quota.key", 0.0, QuotaConfigs.REQUEST_PERCENTAGE_OVERRIDE_CONFIG, 99.9)); + + // Valid + invalid keys + assertInvalidQuota(manager, entity, quotas(QuotaConfigs.REQUEST_PERCENTAGE_OVERRIDE_CONFIG, 99.9, "not.a.quota.key", 0.0)); + + // Null key + assertInvalidQuota(manager, entity, quotas(null, 99.9)); + } + + private void assertInvalidQuota(ClientQuotaControlManager manager, ClientQuotaEntity entity, Map quota) { + List alters = new ArrayList<>(); + entityQuotaToAlterations(entity, quota, alters::add); + ControllerResult> result = manager.alterClientQuotas(alters); + assertEquals(Errors.INVALID_REQUEST, result.response().get(entity).error()); + assertEquals(0, result.records().size()); + } + + @Test + public void testAlterAndRemove() { + SnapshotRegistry snapshotRegistry = new SnapshotRegistry(new LogContext()); + ClientQuotaControlManager manager = new ClientQuotaControlManager(snapshotRegistry); + + ClientQuotaEntity userEntity = userEntity("user-1"); + List alters = new ArrayList<>(); + + // Add one quota + entityQuotaToAlterations(userEntity, quotas(QuotaConfigs.PRODUCER_BYTE_RATE_OVERRIDE_CONFIG, 10000.0), alters::add); + alterQuotas(alters, manager); + assertEquals(1, manager.clientQuotaData.get(userEntity).size()); + assertEquals(10000.0, manager.clientQuotaData.get(userEntity).get(QuotaConfigs.PRODUCER_BYTE_RATE_OVERRIDE_CONFIG), 1e-6); + + // Replace it and add another + alters.clear(); + entityQuotaToAlterations(userEntity, quotas( + QuotaConfigs.PRODUCER_BYTE_RATE_OVERRIDE_CONFIG, 10001.0, + QuotaConfigs.CONSUMER_BYTE_RATE_OVERRIDE_CONFIG, 20000.0 + ), alters::add); + alterQuotas(alters, manager); + assertEquals(2, manager.clientQuotaData.get(userEntity).size()); + assertEquals(10001.0, manager.clientQuotaData.get(userEntity).get(QuotaConfigs.PRODUCER_BYTE_RATE_OVERRIDE_CONFIG), 1e-6); + assertEquals(20000.0, manager.clientQuotaData.get(userEntity).get(QuotaConfigs.CONSUMER_BYTE_RATE_OVERRIDE_CONFIG), 1e-6); + + // Remove one of the quotas, the other remains + alters.clear(); + entityQuotaToAlterations(userEntity, quotas( + QuotaConfigs.PRODUCER_BYTE_RATE_OVERRIDE_CONFIG, null + ), alters::add); + alterQuotas(alters, manager); + assertEquals(1, manager.clientQuotaData.get(userEntity).size()); + assertEquals(20000.0, manager.clientQuotaData.get(userEntity).get(QuotaConfigs.CONSUMER_BYTE_RATE_OVERRIDE_CONFIG), 1e-6); + + // Remove non-existent quota, no change + alters.clear(); + entityQuotaToAlterations(userEntity, quotas( + QuotaConfigs.REQUEST_PERCENTAGE_OVERRIDE_CONFIG, null + ), alters::add); + alterQuotas(alters, manager); + assertEquals(1, manager.clientQuotaData.get(userEntity).size()); + assertEquals(20000.0, manager.clientQuotaData.get(userEntity).get(QuotaConfigs.CONSUMER_BYTE_RATE_OVERRIDE_CONFIG), 1e-6); + + // All quotas removed, we should cleanup the map + alters.clear(); + entityQuotaToAlterations(userEntity, quotas( + QuotaConfigs.CONSUMER_BYTE_RATE_OVERRIDE_CONFIG, null + ), alters::add); + alterQuotas(alters, manager); + 
assertFalse(manager.clientQuotaData.containsKey(userEntity)); + + // Remove non-existent quota, again no change + alters.clear(); + entityQuotaToAlterations(userEntity, quotas( + QuotaConfigs.CONSUMER_BYTE_RATE_OVERRIDE_CONFIG, null + ), alters::add); + alterQuotas(alters, manager); + assertFalse(manager.clientQuotaData.containsKey(userEntity)); + + // Mixed update + alters.clear(); + Map quotas = new HashMap<>(4); + quotas.put(QuotaConfigs.REQUEST_PERCENTAGE_OVERRIDE_CONFIG, 99.0); + quotas.put(QuotaConfigs.CONTROLLER_MUTATION_RATE_OVERRIDE_CONFIG, null); + quotas.put(QuotaConfigs.PRODUCER_BYTE_RATE_OVERRIDE_CONFIG, 10002.0); + quotas.put(QuotaConfigs.CONSUMER_BYTE_RATE_OVERRIDE_CONFIG, 20001.0); + + entityQuotaToAlterations(userEntity, quotas, alters::add); + alterQuotas(alters, manager); + assertEquals(3, manager.clientQuotaData.get(userEntity).size()); + assertEquals(20001.0, manager.clientQuotaData.get(userEntity).get(QuotaConfigs.CONSUMER_BYTE_RATE_OVERRIDE_CONFIG), 1e-6); + assertEquals(10002.0, manager.clientQuotaData.get(userEntity).get(QuotaConfigs.PRODUCER_BYTE_RATE_OVERRIDE_CONFIG), 1e-6); + assertEquals(99.0, manager.clientQuotaData.get(userEntity).get(QuotaConfigs.REQUEST_PERCENTAGE_OVERRIDE_CONFIG), 1e-6); + } + + @Test + public void testEntityTypes() throws Exception { + SnapshotRegistry snapshotRegistry = new SnapshotRegistry(new LogContext()); + ClientQuotaControlManager manager = new ClientQuotaControlManager(snapshotRegistry); + + Map> quotasToTest = new HashMap<>(); + quotasToTest.put(userClientEntity("user-1", "client-id-1"), + quotas(QuotaConfigs.REQUEST_PERCENTAGE_OVERRIDE_CONFIG, 50.50)); + quotasToTest.put(userClientEntity("user-2", "client-id-1"), + quotas(QuotaConfigs.REQUEST_PERCENTAGE_OVERRIDE_CONFIG, 51.51)); + quotasToTest.put(userClientEntity("user-3", "client-id-2"), + quotas(QuotaConfigs.REQUEST_PERCENTAGE_OVERRIDE_CONFIG, 52.52)); + quotasToTest.put(userClientEntity(null, "client-id-1"), + quotas(QuotaConfigs.REQUEST_PERCENTAGE_OVERRIDE_CONFIG, 53.53)); + quotasToTest.put(userClientEntity("user-1", null), + quotas(QuotaConfigs.REQUEST_PERCENTAGE_OVERRIDE_CONFIG, 54.54)); + quotasToTest.put(userClientEntity("user-3", null), + quotas(QuotaConfigs.REQUEST_PERCENTAGE_OVERRIDE_CONFIG, 55.55)); + quotasToTest.put(userEntity("user-1"), + quotas(QuotaConfigs.REQUEST_PERCENTAGE_OVERRIDE_CONFIG, 56.56)); + quotasToTest.put(userEntity("user-2"), + quotas(QuotaConfigs.REQUEST_PERCENTAGE_OVERRIDE_CONFIG, 57.57)); + quotasToTest.put(userEntity("user-3"), + quotas(QuotaConfigs.REQUEST_PERCENTAGE_OVERRIDE_CONFIG, 58.58)); + quotasToTest.put(userEntity(null), + quotas(QuotaConfigs.REQUEST_PERCENTAGE_OVERRIDE_CONFIG, 59.59)); + quotasToTest.put(clientEntity("client-id-2"), + quotas(QuotaConfigs.REQUEST_PERCENTAGE_OVERRIDE_CONFIG, 60.60)); + + List alters = new ArrayList<>(); + quotasToTest.forEach((entity, quota) -> entityQuotaToAlterations(entity, quota, alters::add)); + alterQuotas(alters, manager); + + RecordTestUtils.assertBatchIteratorContains(Arrays.asList( + Arrays.asList(new ApiMessageAndVersion(new ClientQuotaRecord().setEntity(Arrays.asList( + new EntityData().setEntityType("user").setEntityName("user-1"), + new EntityData().setEntityType("client-id").setEntityName("client-id-1"))). 
+ setKey("request_percentage").setValue(50.5).setRemove(false), (short) 0)), + Arrays.asList(new ApiMessageAndVersion(new ClientQuotaRecord().setEntity(Arrays.asList( + new EntityData().setEntityType("user").setEntityName("user-2"), + new EntityData().setEntityType("client-id").setEntityName("client-id-1"))). + setKey("request_percentage").setValue(51.51).setRemove(false), (short) 0)), + Arrays.asList(new ApiMessageAndVersion(new ClientQuotaRecord().setEntity(Arrays.asList( + new EntityData().setEntityType("user").setEntityName("user-3"), + new EntityData().setEntityType("client-id").setEntityName("client-id-2"))). + setKey("request_percentage").setValue(52.52).setRemove(false), (short) 0)), + Arrays.asList(new ApiMessageAndVersion(new ClientQuotaRecord().setEntity(Arrays.asList( + new EntityData().setEntityType("user").setEntityName(null), + new EntityData().setEntityType("client-id").setEntityName("client-id-1"))). + setKey("request_percentage").setValue(53.53).setRemove(false), (short) 0)), + Arrays.asList(new ApiMessageAndVersion(new ClientQuotaRecord().setEntity(Arrays.asList( + new EntityData().setEntityType("user").setEntityName("user-1"), + new EntityData().setEntityType("client-id").setEntityName(null))). + setKey("request_percentage").setValue(54.54).setRemove(false), (short) 0)), + Arrays.asList(new ApiMessageAndVersion(new ClientQuotaRecord().setEntity(Arrays.asList( + new EntityData().setEntityType("user").setEntityName("user-3"), + new EntityData().setEntityType("client-id").setEntityName(null))). + setKey("request_percentage").setValue(55.55).setRemove(false), (short) 0)), + Arrays.asList(new ApiMessageAndVersion(new ClientQuotaRecord().setEntity(Arrays.asList( + new EntityData().setEntityType("user").setEntityName("user-1"))). + setKey("request_percentage").setValue(56.56).setRemove(false), (short) 0)), + Arrays.asList(new ApiMessageAndVersion(new ClientQuotaRecord().setEntity(Arrays.asList( + new EntityData().setEntityType("user").setEntityName("user-2"))). + setKey("request_percentage").setValue(57.57).setRemove(false), (short) 0)), + Arrays.asList(new ApiMessageAndVersion(new ClientQuotaRecord().setEntity(Arrays.asList( + new EntityData().setEntityType("user").setEntityName("user-3"))). + setKey("request_percentage").setValue(58.58).setRemove(false), (short) 0)), + Arrays.asList(new ApiMessageAndVersion(new ClientQuotaRecord().setEntity(Arrays.asList( + new EntityData().setEntityType("user").setEntityName(null))). + setKey("request_percentage").setValue(59.59).setRemove(false), (short) 0)), + Arrays.asList(new ApiMessageAndVersion(new ClientQuotaRecord().setEntity(Arrays.asList( + new EntityData().setEntityType("client-id").setEntityName("client-id-2"))). 
+ setKey("request_percentage").setValue(60.60).setRemove(false), (short) 0))), + manager.iterator(Long.MAX_VALUE)); + } + + static void entityQuotaToAlterations(ClientQuotaEntity entity, Map quota, + Consumer acceptor) { + Collection ops = quota.entrySet().stream() + .map(quotaEntry -> new ClientQuotaAlteration.Op(quotaEntry.getKey(), quotaEntry.getValue())) + .collect(Collectors.toList()); + acceptor.accept(new ClientQuotaAlteration(entity, ops)); + } + + static void alterQuotas(List alterations, ClientQuotaControlManager manager) { + ControllerResult> result = manager.alterClientQuotas(alterations); + assertTrue(result.response().values().stream().allMatch(ApiError::isSuccess)); + result.records().forEach(apiMessageAndVersion -> + manager.replay((ClientQuotaRecord) apiMessageAndVersion.message())); + } + + static Map quotas(String key, Double value) { + return Collections.singletonMap(key, value); + } + + static Map quotas(String key1, Double value1, String key2, Double value2) { + Map quotas = new HashMap<>(2); + quotas.put(key1, value1); + quotas.put(key2, value2); + return quotas; + } + + static ClientQuotaEntity entity(String type, String name) { + return new ClientQuotaEntity(Collections.singletonMap(type, name)); + } + + static ClientQuotaEntity entity(String type1, String name1, String type2, String name2) { + Map entries = new HashMap<>(2); + entries.put(type1, name1); + entries.put(type2, name2); + return new ClientQuotaEntity(entries); + } + + static ClientQuotaEntity userEntity(String user) { + return new ClientQuotaEntity(Collections.singletonMap(ClientQuotaEntity.USER, user)); + } + + static ClientQuotaEntity clientEntity(String clientId) { + return new ClientQuotaEntity(Collections.singletonMap(ClientQuotaEntity.CLIENT_ID, clientId)); + } + + static ClientQuotaEntity userClientEntity(String user, String clientId) { + Map entries = new HashMap<>(2); + entries.put(ClientQuotaEntity.USER, user); + entries.put(ClientQuotaEntity.CLIENT_ID, clientId); + return new ClientQuotaEntity(entries); + } +} diff --git a/metadata/src/test/java/org/apache/kafka/controller/ClusterControlManagerTest.java b/metadata/src/test/java/org/apache/kafka/controller/ClusterControlManagerTest.java new file mode 100644 index 0000000..16625b5 --- /dev/null +++ b/metadata/src/test/java/org/apache/kafka/controller/ClusterControlManagerTest.java @@ -0,0 +1,207 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
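For readers skimming the quota tests above, a minimal sketch of the alter-then-replay flow they exercise. This is illustrative only: the entity name and quota value are made up, and "manager" stands for a ClientQuotaControlManager constructed as in the tests.

    // Build an entity and one alteration, submit it, then replay the emitted records.
    ClientQuotaEntity user = new ClientQuotaEntity(
        Collections.singletonMap(ClientQuotaEntity.USER, "alice"));            // hypothetical principal
    ClientQuotaAlteration alteration = new ClientQuotaAlteration(user,
        Collections.singletonList(new ClientQuotaAlteration.Op(
            QuotaConfigs.PRODUCER_BYTE_RATE_OVERRIDE_CONFIG, 1024.0)));
    ControllerResult<Map<ClientQuotaEntity, ApiError>> result =
        manager.alterClientQuotas(Collections.singletonList(alteration));      // validation happens here
    result.records().forEach(r ->
        manager.replay((ClientQuotaRecord) r.message()));                      // state only changes on replay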
+ */ + +package org.apache.kafka.controller; + +import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Optional; +import java.util.Random; +import org.apache.kafka.common.Endpoint; +import org.apache.kafka.common.Uuid; +import org.apache.kafka.common.errors.StaleBrokerEpochException; +import org.apache.kafka.common.metadata.RegisterBrokerRecord.BrokerEndpoint; +import org.apache.kafka.common.metadata.RegisterBrokerRecord.BrokerEndpointCollection; +import org.apache.kafka.common.metadata.RegisterBrokerRecord; +import org.apache.kafka.common.metadata.UnfenceBrokerRecord; +import org.apache.kafka.common.metadata.UnregisterBrokerRecord; +import org.apache.kafka.common.security.auth.SecurityProtocol; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.metadata.BrokerRegistration; +import org.apache.kafka.metadata.RecordTestUtils; +import org.apache.kafka.server.common.ApiMessageAndVersion; +import org.apache.kafka.timeline.SnapshotRegistry; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + + +@Timeout(value = 40) +public class ClusterControlManagerTest { + @Test + public void testReplay() { + MockTime time = new MockTime(0, 0, 0); + + SnapshotRegistry snapshotRegistry = new SnapshotRegistry(new LogContext()); + ClusterControlManager clusterControl = new ClusterControlManager( + new LogContext(), time, snapshotRegistry, 1000, + new StripedReplicaPlacer(new Random()), new MockControllerMetrics()); + clusterControl.activate(); + assertFalse(clusterControl.unfenced(0)); + + RegisterBrokerRecord brokerRecord = new RegisterBrokerRecord().setBrokerEpoch(100).setBrokerId(1); + brokerRecord.endPoints().add(new BrokerEndpoint(). + setSecurityProtocol(SecurityProtocol.PLAINTEXT.id). + setPort((short) 9092). + setName("PLAINTEXT"). + setHost("example.com")); + clusterControl.replay(brokerRecord); + clusterControl.checkBrokerEpoch(1, 100); + assertThrows(StaleBrokerEpochException.class, + () -> clusterControl.checkBrokerEpoch(1, 101)); + assertThrows(StaleBrokerEpochException.class, + () -> clusterControl.checkBrokerEpoch(2, 100)); + assertFalse(clusterControl.unfenced(0)); + assertFalse(clusterControl.unfenced(1)); + + UnfenceBrokerRecord unfenceBrokerRecord = + new UnfenceBrokerRecord().setId(1).setEpoch(100); + clusterControl.replay(unfenceBrokerRecord); + assertFalse(clusterControl.unfenced(0)); + assertTrue(clusterControl.unfenced(1)); + } + + @Test + public void testUnregister() throws Exception { + RegisterBrokerRecord brokerRecord = new RegisterBrokerRecord(). + setBrokerId(1). + setBrokerEpoch(100). + setIncarnationId(Uuid.fromString("fPZv1VBsRFmnlRvmGcOW9w")). + setRack("arack"); + brokerRecord.endPoints().add(new BrokerEndpoint(). + setSecurityProtocol(SecurityProtocol.PLAINTEXT.id). + setPort((short) 9092). + setName("PLAINTEXT"). 
+ setHost("example.com")); + SnapshotRegistry snapshotRegistry = new SnapshotRegistry(new LogContext()); + ClusterControlManager clusterControl = new ClusterControlManager( + new LogContext(), new MockTime(0, 0, 0), snapshotRegistry, 1000, + new StripedReplicaPlacer(new Random()), new MockControllerMetrics()); + clusterControl.activate(); + clusterControl.replay(brokerRecord); + assertEquals(new BrokerRegistration(1, 100, + Uuid.fromString("fPZv1VBsRFmnlRvmGcOW9w"), Collections.singletonMap("PLAINTEXT", + new Endpoint("PLAINTEXT", SecurityProtocol.PLAINTEXT, "example.com", 9092)), + Collections.emptyMap(), Optional.of("arack"), true), + clusterControl.brokerRegistrations().get(1)); + UnregisterBrokerRecord unregisterRecord = new UnregisterBrokerRecord(). + setBrokerId(1). + setBrokerEpoch(100); + clusterControl.replay(unregisterRecord); + assertFalse(clusterControl.brokerRegistrations().containsKey(1)); + } + + @ParameterizedTest + @ValueSource(ints = {3, 10}) + public void testPlaceReplicas(int numUsableBrokers) throws Exception { + MockTime time = new MockTime(0, 0, 0); + SnapshotRegistry snapshotRegistry = new SnapshotRegistry(new LogContext()); + MockRandom random = new MockRandom(); + ClusterControlManager clusterControl = new ClusterControlManager( + new LogContext(), time, snapshotRegistry, 1000, + new StripedReplicaPlacer(random), new MockControllerMetrics()); + clusterControl.activate(); + for (int i = 0; i < numUsableBrokers; i++) { + RegisterBrokerRecord brokerRecord = + new RegisterBrokerRecord().setBrokerEpoch(100).setBrokerId(i); + brokerRecord.endPoints().add(new BrokerEndpoint(). + setSecurityProtocol(SecurityProtocol.PLAINTEXT.id). + setPort((short) 9092). + setName("PLAINTEXT"). + setHost("example.com")); + clusterControl.replay(brokerRecord); + UnfenceBrokerRecord unfenceRecord = + new UnfenceBrokerRecord().setId(i).setEpoch(100); + clusterControl.replay(unfenceRecord); + clusterControl.heartbeatManager().touch(i, false, 0); + } + for (int i = 0; i < numUsableBrokers; i++) { + assertTrue(clusterControl.unfenced(i), + String.format("broker %d was not unfenced.", i)); + } + for (int i = 0; i < 100; i++) { + List> results = clusterControl.placeReplicas(0, 1, (short) 3); + HashSet seen = new HashSet<>(); + for (Integer result : results.get(0)) { + assertTrue(result >= 0); + assertTrue(result < numUsableBrokers); + assertTrue(seen.add(result)); + } + } + } + + @Test + public void testIterator() throws Exception { + MockTime time = new MockTime(0, 0, 0); + SnapshotRegistry snapshotRegistry = new SnapshotRegistry(new LogContext()); + ClusterControlManager clusterControl = new ClusterControlManager( + new LogContext(), time, snapshotRegistry, 1000, + new StripedReplicaPlacer(new Random()), new MockControllerMetrics()); + clusterControl.activate(); + assertFalse(clusterControl.unfenced(0)); + for (int i = 0; i < 3; i++) { + RegisterBrokerRecord brokerRecord = new RegisterBrokerRecord(). + setBrokerEpoch(100).setBrokerId(i).setRack(null); + brokerRecord.endPoints().add(new BrokerEndpoint(). + setSecurityProtocol(SecurityProtocol.PLAINTEXT.id). + setPort((short) 9092 + i). + setName("PLAINTEXT"). + setHost("example.com")); + clusterControl.replay(brokerRecord); + } + for (int i = 0; i < 2; i++) { + UnfenceBrokerRecord unfenceBrokerRecord = + new UnfenceBrokerRecord().setId(i).setEpoch(100); + clusterControl.replay(unfenceBrokerRecord); + } + RecordTestUtils.assertBatchIteratorContains(Arrays.asList( + Arrays.asList(new ApiMessageAndVersion(new RegisterBrokerRecord(). 
+ setBrokerEpoch(100).setBrokerId(0).setRack(null). + setEndPoints(new BrokerEndpointCollection(Collections.singleton( + new BrokerEndpoint().setSecurityProtocol(SecurityProtocol.PLAINTEXT.id). + setPort((short) 9092). + setName("PLAINTEXT"). + setHost("example.com")).iterator())). + setFenced(false), (short) 0)), + Arrays.asList(new ApiMessageAndVersion(new RegisterBrokerRecord(). + setBrokerEpoch(100).setBrokerId(1).setRack(null). + setEndPoints(new BrokerEndpointCollection(Collections.singleton( + new BrokerEndpoint().setSecurityProtocol(SecurityProtocol.PLAINTEXT.id). + setPort((short) 9093). + setName("PLAINTEXT"). + setHost("example.com")).iterator())). + setFenced(false), (short) 0)), + Arrays.asList(new ApiMessageAndVersion(new RegisterBrokerRecord(). + setBrokerEpoch(100).setBrokerId(2).setRack(null). + setEndPoints(new BrokerEndpointCollection(Collections.singleton( + new BrokerEndpoint().setSecurityProtocol(SecurityProtocol.PLAINTEXT.id). + setPort((short) 9094). + setName("PLAINTEXT"). + setHost("example.com")).iterator())). + setFenced(true), (short) 0))), + clusterControl.iterator(Long.MAX_VALUE)); + } +} diff --git a/metadata/src/test/java/org/apache/kafka/controller/ConfigurationControlManagerTest.java b/metadata/src/test/java/org/apache/kafka/controller/ConfigurationControlManagerTest.java new file mode 100644 index 0000000..f84b12e --- /dev/null +++ b/metadata/src/test/java/org/apache/kafka/controller/ConfigurationControlManagerTest.java @@ -0,0 +1,310 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
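A compressed view of the broker-registration lifecycle that ClusterControlManagerTest above walks through; a sketch under the same constructor and record shapes used in the test (endpoints omitted for brevity).

    ClusterControlManager clusterControl = new ClusterControlManager(
        new LogContext(), new MockTime(0, 0, 0), new SnapshotRegistry(new LogContext()),
        1000, new StripedReplicaPlacer(new Random()), new MockControllerMetrics());
    clusterControl.activate();
    clusterControl.replay(new RegisterBrokerRecord().setBrokerId(1).setBrokerEpoch(100)); // registered but fenced
    clusterControl.replay(new UnfenceBrokerRecord().setId(1).setEpoch(100));
    boolean usable = clusterControl.unfenced(1);                                          // true after the unfence record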
+ */ + +package org.apache.kafka.controller; + +import org.apache.kafka.common.config.ConfigDef; +import org.apache.kafka.common.config.ConfigResource; +import org.apache.kafka.common.errors.PolicyViolationException; +import org.apache.kafka.common.metadata.ConfigRecord; +import org.apache.kafka.common.protocol.Errors; +import org.apache.kafka.common.requests.ApiError; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.metadata.RecordTestUtils; +import org.apache.kafka.server.common.ApiMessageAndVersion; +import org.apache.kafka.server.policy.AlterConfigPolicy; +import org.apache.kafka.server.policy.AlterConfigPolicy.RequestMetadata; +import org.apache.kafka.timeline.SnapshotRegistry; + +import java.util.AbstractMap.SimpleImmutableEntry; +import java.util.Collections; +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; +import java.util.Optional; +import java.util.concurrent.atomic.AtomicLong; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; + +import static java.util.Arrays.asList; +import static org.apache.kafka.clients.admin.AlterConfigOp.OpType.APPEND; +import static org.apache.kafka.clients.admin.AlterConfigOp.OpType.DELETE; +import static org.apache.kafka.clients.admin.AlterConfigOp.OpType.SET; +import static org.apache.kafka.clients.admin.AlterConfigOp.OpType.SUBTRACT; +import static org.apache.kafka.common.config.ConfigResource.Type.BROKER; +import static org.apache.kafka.common.config.ConfigResource.Type.BROKER_LOGGER; +import static org.apache.kafka.common.config.ConfigResource.Type.TOPIC; +import static org.apache.kafka.common.config.ConfigResource.Type.UNKNOWN; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + + +@Timeout(value = 40) +public class ConfigurationControlManagerTest { + + static final Map CONFIGS = new HashMap<>(); + + static { + CONFIGS.put(BROKER, new ConfigDef(). + define("foo.bar", ConfigDef.Type.LIST, "1", ConfigDef.Importance.HIGH, "foo bar"). + define("baz", ConfigDef.Type.STRING, ConfigDef.Importance.HIGH, "baz"). + define("quux", ConfigDef.Type.INT, ConfigDef.Importance.HIGH, "quux")); + CONFIGS.put(TOPIC, new ConfigDef(). + define("abc", ConfigDef.Type.LIST, ConfigDef.Importance.HIGH, "abc"). + define("def", ConfigDef.Type.STRING, ConfigDef.Importance.HIGH, "def"). + define("ghi", ConfigDef.Type.BOOLEAN, true, ConfigDef.Importance.HIGH, "ghi")); + } + + static final ConfigResource BROKER0 = new ConfigResource(BROKER, "0"); + static final ConfigResource MYTOPIC = new ConfigResource(TOPIC, "mytopic"); + + @SuppressWarnings("unchecked") + private static Map toMap(Entry... entries) { + Map map = new LinkedHashMap<>(); + for (Entry entry : entries) { + map.put(entry.getKey(), entry.getValue()); + } + return map; + } + + static Entry entry(A a, B b) { + return new SimpleImmutableEntry<>(a, b); + } + + @Test + public void testReplay() throws Exception { + SnapshotRegistry snapshotRegistry = new SnapshotRegistry(new LogContext()); + ConfigurationControlManager manager = + new ConfigurationControlManager(new LogContext(), snapshotRegistry, CONFIGS, + Optional.empty(), ConfigurationValidator.NO_OP); + assertEquals(Collections.emptyMap(), manager.getConfigs(BROKER0)); + manager.replay(new ConfigRecord(). + setResourceType(BROKER.id()).setResourceName("0"). 
+ setName("foo.bar").setValue("1,2")); + assertEquals(Collections.singletonMap("foo.bar", "1,2"), + manager.getConfigs(BROKER0)); + manager.replay(new ConfigRecord(). + setResourceType(BROKER.id()).setResourceName("0"). + setName("foo.bar").setValue(null)); + assertEquals(Collections.emptyMap(), manager.getConfigs(BROKER0)); + manager.replay(new ConfigRecord(). + setResourceType(TOPIC.id()).setResourceName("mytopic"). + setName("abc").setValue("x,y,z")); + manager.replay(new ConfigRecord(). + setResourceType(TOPIC.id()).setResourceName("mytopic"). + setName("def").setValue("blah")); + assertEquals(toMap(entry("abc", "x,y,z"), entry("def", "blah")), + manager.getConfigs(MYTOPIC)); + RecordTestUtils.assertBatchIteratorContains(asList( + asList(new ApiMessageAndVersion(new ConfigRecord(). + setResourceType(TOPIC.id()).setResourceName("mytopic"). + setName("abc").setValue("x,y,z"), (short) 0), + new ApiMessageAndVersion(new ConfigRecord(). + setResourceType(TOPIC.id()).setResourceName("mytopic"). + setName("def").setValue("blah"), (short) 0))), + manager.iterator(Long.MAX_VALUE)); + } + + @Test + public void testCheckConfigResource() { + assertEquals(new ApiError(Errors.INVALID_REQUEST, "Unsupported " + + "configuration resource type BROKER_LOGGER ").toString(), + ConfigurationControlManager.checkConfigResource( + new ConfigResource(BROKER_LOGGER, "kafka.server.FetchContext")).toString()); + assertEquals(new ApiError(Errors.INVALID_REQUEST, "Illegal topic name.").toString(), + ConfigurationControlManager.checkConfigResource( + new ConfigResource(TOPIC, "* @ invalid$")).toString()); + assertEquals(new ApiError(Errors.INVALID_REQUEST, "Illegal topic name.").toString(), + ConfigurationControlManager.checkConfigResource( + new ConfigResource(TOPIC, "")).toString()); + assertEquals(new ApiError(Errors.INVALID_REQUEST, "Illegal non-integral " + + "BROKER resource type name.").toString(), + ConfigurationControlManager.checkConfigResource( + new ConfigResource(BROKER, "bob")).toString()); + assertEquals(new ApiError(Errors.NONE, null).toString(), + ConfigurationControlManager.checkConfigResource( + new ConfigResource(BROKER, "")).toString()); + assertEquals(new ApiError(Errors.INVALID_REQUEST, "Unsupported configuration " + + "resource type UNKNOWN.").toString(), + ConfigurationControlManager.checkConfigResource( + new ConfigResource(UNKNOWN, "bob")).toString()); + } + + @Test + public void testIncrementalAlterConfigs() { + SnapshotRegistry snapshotRegistry = new SnapshotRegistry(new LogContext()); + ConfigurationControlManager manager = + new ConfigurationControlManager(new LogContext(), snapshotRegistry, CONFIGS, + Optional.empty(), ConfigurationValidator.NO_OP); + + ControllerResult> result = manager. + incrementalAlterConfigs(toMap(entry(BROKER0, toMap( + entry("baz", entry(SUBTRACT, "abc")), + entry("quux", entry(SET, "abc")))), + entry(MYTOPIC, toMap(entry("abc", entry(APPEND, "123")))))); + + assertEquals(ControllerResult.atomicOf(Collections.singletonList(new ApiMessageAndVersion( + new ConfigRecord().setResourceType(TOPIC.id()).setResourceName("mytopic"). + setName("abc").setValue("123"), (short) 0)), + toMap(entry(BROKER0, new ApiError(Errors.INVALID_CONFIG, + "Can't SUBTRACT to key baz because its type is not LIST.")), + entry(MYTOPIC, ApiError.NONE))), result); + + RecordTestUtils.replayAll(manager, result.records()); + + assertEquals(ControllerResult.atomicOf(Collections.singletonList(new ApiMessageAndVersion( + new ConfigRecord().setResourceType(TOPIC.id()).setResourceName("mytopic"). 
+ setName("abc").setValue(null), (short) 0)), + toMap(entry(MYTOPIC, ApiError.NONE))), + manager.incrementalAlterConfigs(toMap(entry(MYTOPIC, toMap( + entry("abc", entry(DELETE, "xyz"))))))); + } + + private static class MockAlterConfigsPolicy implements AlterConfigPolicy { + private final List expecteds; + private final AtomicLong index = new AtomicLong(0); + + MockAlterConfigsPolicy(List expecteds) { + this.expecteds = expecteds; + } + + @Override + public void validate(RequestMetadata actual) throws PolicyViolationException { + long curIndex = index.getAndIncrement(); + if (curIndex >= expecteds.size()) { + throw new PolicyViolationException("Unexpected config alteration: index " + + "out of range at " + curIndex); + } + RequestMetadata expected = expecteds.get((int) curIndex); + if (!expected.equals(actual)) { + throw new PolicyViolationException("Expected: " + expected + + ". Got: " + actual); + } + } + + @Override + public void close() throws Exception { + // nothing to do + } + + @Override + public void configure(Map configs) { + // nothing to do + } + } + + @Test + public void testIncrementalAlterConfigsWithPolicy() { + SnapshotRegistry snapshotRegistry = new SnapshotRegistry(new LogContext()); + MockAlterConfigsPolicy policy = new MockAlterConfigsPolicy(asList( + new RequestMetadata(MYTOPIC, Collections.emptyMap()), + new RequestMetadata(BROKER0, toMap(entry("foo.bar", "123"), + entry("quux", "456"))))); + ConfigurationControlManager manager = new ConfigurationControlManager( + new LogContext(), snapshotRegistry, CONFIGS, Optional.of(policy), + ConfigurationValidator.NO_OP); + + assertEquals(ControllerResult.atomicOf(asList(new ApiMessageAndVersion( + new ConfigRecord().setResourceType(BROKER.id()).setResourceName("0"). + setName("foo.bar").setValue("123"), (short) 0), new ApiMessageAndVersion( + new ConfigRecord().setResourceType(BROKER.id()).setResourceName("0"). + setName("quux").setValue("456"), (short) 0)), + toMap(entry(MYTOPIC, new ApiError(Errors.POLICY_VIOLATION, + "Expected: AlterConfigPolicy.RequestMetadata(resource=ConfigResource(" + + "type=TOPIC, name='mytopic'), configs={}). 
Got: " + + "AlterConfigPolicy.RequestMetadata(resource=ConfigResource(" + + "type=TOPIC, name='mytopic'), configs={foo.bar=123})")), + entry(BROKER0, ApiError.NONE))), + manager.incrementalAlterConfigs(toMap(entry(MYTOPIC, toMap( + entry("foo.bar", entry(SET, "123")))), + entry(BROKER0, toMap( + entry("foo.bar", entry(SET, "123")), + entry("quux", entry(SET, "456"))))))); + } + + @Test + public void testIsSplittable() { + SnapshotRegistry snapshotRegistry = new SnapshotRegistry(new LogContext()); + ConfigurationControlManager manager = + new ConfigurationControlManager(new LogContext(), snapshotRegistry, CONFIGS, + Optional.empty(), ConfigurationValidator.NO_OP); + assertTrue(manager.isSplittable(BROKER, "foo.bar")); + assertFalse(manager.isSplittable(BROKER, "baz")); + assertFalse(manager.isSplittable(BROKER, "foo.baz.quux")); + assertFalse(manager.isSplittable(TOPIC, "baz")); + assertTrue(manager.isSplittable(TOPIC, "abc")); + } + + @Test + public void testGetConfigValueDefault() { + SnapshotRegistry snapshotRegistry = new SnapshotRegistry(new LogContext()); + ConfigurationControlManager manager = + new ConfigurationControlManager(new LogContext(), snapshotRegistry, CONFIGS, + Optional.empty(), ConfigurationValidator.NO_OP); + assertEquals("1", manager.getConfigValueDefault(BROKER, "foo.bar")); + assertEquals(null, manager.getConfigValueDefault(BROKER, "foo.baz.quux")); + assertEquals(null, manager.getConfigValueDefault(TOPIC, "abc")); + assertEquals("true", manager.getConfigValueDefault(TOPIC, "ghi")); + } + + @Test + public void testLegacyAlterConfigs() { + SnapshotRegistry snapshotRegistry = new SnapshotRegistry(new LogContext()); + ConfigurationControlManager manager = + new ConfigurationControlManager(new LogContext(), snapshotRegistry, CONFIGS, + Optional.empty(), ConfigurationValidator.NO_OP); + List expectedRecords1 = asList( + new ApiMessageAndVersion(new ConfigRecord(). + setResourceType(TOPIC.id()).setResourceName("mytopic"). + setName("abc").setValue("456"), (short) 0), + new ApiMessageAndVersion(new ConfigRecord(). + setResourceType(TOPIC.id()).setResourceName("mytopic"). + setName("def").setValue("901"), (short) 0)); + assertEquals( + ControllerResult.atomicOf( + expectedRecords1, + toMap(entry(MYTOPIC, ApiError.NONE)) + ), + manager.legacyAlterConfigs( + toMap(entry(MYTOPIC, toMap(entry("abc", "456"), entry("def", "901")))) + ) + ); + for (ApiMessageAndVersion message : expectedRecords1) { + manager.replay((ConfigRecord) message.message()); + } + assertEquals( + ControllerResult.atomicOf( + asList( + new ApiMessageAndVersion( + new ConfigRecord() + .setResourceType(TOPIC.id()) + .setResourceName("mytopic") + .setName("abc") + .setValue(null), + (short) 0 + ) + ), + toMap(entry(MYTOPIC, ApiError.NONE)) + ), + manager.legacyAlterConfigs(toMap(entry(MYTOPIC, toMap(entry("def", "901"))))) + ); + } +} diff --git a/metadata/src/test/java/org/apache/kafka/controller/ControllerPurgatoryTest.java b/metadata/src/test/java/org/apache/kafka/controller/ControllerPurgatoryTest.java new file mode 100644 index 0000000..57953e1 --- /dev/null +++ b/metadata/src/test/java/org/apache/kafka/controller/ControllerPurgatoryTest.java @@ -0,0 +1,102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
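To make the shape of the incremental-alter API in the tests above easier to see at a glance, a hedged sketch; the topic and config names reuse the test fixtures, and "manager" is a ConfigurationControlManager built as in the tests.

    // Each resource maps config keys to an (OpType, value) pair; emitted records are applied via replay().
    ConfigResource topic = new ConfigResource(ConfigResource.Type.TOPIC, "mytopic");
    ControllerResult<Map<ConfigResource, ApiError>> result = manager.incrementalAlterConfigs(
        Collections.singletonMap(topic,
            Collections.singletonMap("abc", new SimpleImmutableEntry<>(AlterConfigOp.OpType.SET, "x,y,z"))));
    result.records().forEach(r -> manager.replay((ConfigRecord) r.message()));
    Map<String, String> current = manager.getConfigs(topic);   // now contains abc=x,y,z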
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.controller; + +import java.util.Optional; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +@Timeout(value = 40) +public class ControllerPurgatoryTest { + + static class SampleDeferredEvent implements DeferredEvent { + private final CompletableFuture future = new CompletableFuture<>(); + + @Override + public void complete(Throwable exception) { + if (exception != null) { + future.completeExceptionally(exception); + } else { + future.complete(null); + } + } + + CompletableFuture future() { + return future; + } + } + + @Test + public void testCompleteEvents() { + ControllerPurgatory purgatory = new ControllerPurgatory(); + SampleDeferredEvent event1 = new SampleDeferredEvent(); + SampleDeferredEvent event2 = new SampleDeferredEvent(); + SampleDeferredEvent event3 = new SampleDeferredEvent(); + purgatory.add(1, event1); + assertEquals(Optional.of(1L), purgatory.highestPendingOffset()); + purgatory.add(1, event2); + assertEquals(Optional.of(1L), purgatory.highestPendingOffset()); + purgatory.add(3, event3); + assertEquals(Optional.of(3L), purgatory.highestPendingOffset()); + purgatory.completeUpTo(2); + assertTrue(event1.future.isDone()); + assertTrue(event2.future.isDone()); + assertFalse(event3.future.isDone()); + purgatory.completeUpTo(4); + assertTrue(event3.future.isDone()); + assertEquals(Optional.empty(), purgatory.highestPendingOffset()); + } + + @Test + public void testFailOnIncorrectOrdering() { + ControllerPurgatory purgatory = new ControllerPurgatory(); + SampleDeferredEvent event1 = new SampleDeferredEvent(); + SampleDeferredEvent event2 = new SampleDeferredEvent(); + purgatory.add(2, event1); + assertThrows(RuntimeException.class, () -> purgatory.add(1, event2)); + } + + @Test + public void testFailEvents() { + ControllerPurgatory purgatory = new ControllerPurgatory(); + SampleDeferredEvent event1 = new SampleDeferredEvent(); + SampleDeferredEvent event2 = new SampleDeferredEvent(); + SampleDeferredEvent event3 = new SampleDeferredEvent(); + purgatory.add(1, event1); + purgatory.add(3, event2); + purgatory.add(3, event3); + purgatory.completeUpTo(2); + assertTrue(event1.future.isDone()); + assertFalse(event2.future.isDone()); + assertFalse(event3.future.isDone()); + purgatory.failAll(new RuntimeException("failed")); + assertTrue(event2.future.isDone()); + assertTrue(event3.future.isDone()); + assertEquals(RuntimeException.class, assertThrows(ExecutionException.class, + () -> event2.future.get()).getCause().getClass()); + assertEquals(RuntimeException.class, assertThrows(ExecutionException.class, + () -> 
event3.future.get()).getCause().getClass()); + } +} diff --git a/metadata/src/test/java/org/apache/kafka/controller/FeatureControlManagerTest.java b/metadata/src/test/java/org/apache/kafka/controller/FeatureControlManagerTest.java new file mode 100644 index 0000000..680253c --- /dev/null +++ b/metadata/src/test/java/org/apache/kafka/controller/FeatureControlManagerTest.java @@ -0,0 +1,177 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.controller; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import org.apache.kafka.common.metadata.FeatureLevelRecord; +import org.apache.kafka.common.protocol.Errors; +import org.apache.kafka.common.requests.ApiError; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.metadata.FeatureMap; +import org.apache.kafka.metadata.FeatureMapAndEpoch; +import org.apache.kafka.metadata.RecordTestUtils; +import org.apache.kafka.metadata.VersionRange; +import org.apache.kafka.server.common.ApiMessageAndVersion; +import org.apache.kafka.timeline.SnapshotRegistry; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; + +import static org.junit.jupiter.api.Assertions.assertEquals; + + +@Timeout(value = 40) +public class FeatureControlManagerTest { + @SuppressWarnings("unchecked") + private static Map rangeMap(Object... args) { + Map result = new HashMap<>(); + for (int i = 0; i < args.length; i += 3) { + String feature = (String) args[i]; + Integer low = (Integer) args[i + 1]; + Integer high = (Integer) args[i + 2]; + result.put(feature, new VersionRange(low.shortValue(), high.shortValue())); + } + return result; + } + + @Test + public void testUpdateFeatures() { + SnapshotRegistry snapshotRegistry = new SnapshotRegistry(new LogContext()); + snapshotRegistry.getOrCreateSnapshot(-1); + FeatureControlManager manager = new FeatureControlManager( + rangeMap("foo", 1, 2), snapshotRegistry); + assertEquals(new FeatureMapAndEpoch(new FeatureMap(Collections.emptyMap()), -1), + manager.finalizedFeatures(-1)); + assertEquals(ControllerResult.atomicOf(Collections.emptyList(), Collections. 
+ singletonMap("foo", new ApiError(Errors.INVALID_UPDATE_VERSION, + "The controller does not support the given feature range."))), + manager.updateFeatures(rangeMap("foo", 1, 3), + Collections.singleton("foo"), + Collections.emptyMap())); + ControllerResult> result = manager.updateFeatures( + rangeMap("foo", 1, 2, "bar", 1, 1), Collections.emptySet(), + Collections.emptyMap()); + Map expectedMap = new HashMap<>(); + expectedMap.put("foo", ApiError.NONE); + expectedMap.put("bar", new ApiError(Errors.INVALID_UPDATE_VERSION, + "The controller does not support the given feature range.")); + assertEquals(expectedMap, result.response()); + List expectedMessages = new ArrayList<>(); + expectedMessages.add(new ApiMessageAndVersion(new FeatureLevelRecord(). + setName("foo").setMinFeatureLevel((short) 1).setMaxFeatureLevel((short) 2), + (short) 0)); + assertEquals(expectedMessages, result.records()); + } + + @Test + public void testReplay() { + FeatureLevelRecord record = new FeatureLevelRecord(). + setName("foo").setMinFeatureLevel((short) 1).setMaxFeatureLevel((short) 2); + SnapshotRegistry snapshotRegistry = new SnapshotRegistry(new LogContext()); + snapshotRegistry.getOrCreateSnapshot(-1); + FeatureControlManager manager = new FeatureControlManager( + rangeMap("foo", 1, 2), snapshotRegistry); + manager.replay(record); + snapshotRegistry.getOrCreateSnapshot(123); + assertEquals(new FeatureMapAndEpoch(new FeatureMap(rangeMap("foo", 1, 2)), 123), + manager.finalizedFeatures(123)); + } + + @Test + public void testUpdateFeaturesErrorCases() { + SnapshotRegistry snapshotRegistry = new SnapshotRegistry(new LogContext()); + FeatureControlManager manager = new FeatureControlManager( + rangeMap("foo", 1, 5, "bar", 1, 2), snapshotRegistry); + + assertEquals( + ControllerResult.atomicOf( + Collections.emptyList(), + Collections.singletonMap( + "foo", + new ApiError( + Errors.INVALID_UPDATE_VERSION, + "Broker 5 does not support the given feature range." + ) + ) + ), + manager.updateFeatures( + rangeMap("foo", 1, 3), + Collections.singleton("foo"), + Collections.singletonMap(5, rangeMap()) + ) + ); + + ControllerResult> result = manager.updateFeatures( + rangeMap("foo", 1, 3), Collections.emptySet(), Collections.emptyMap()); + assertEquals(Collections.singletonMap("foo", ApiError.NONE), result.response()); + manager.replay((FeatureLevelRecord) result.records().get(0).message()); + snapshotRegistry.getOrCreateSnapshot(3); + + assertEquals(ControllerResult.atomicOf(Collections.emptyList(), Collections. + singletonMap("foo", new ApiError(Errors.INVALID_UPDATE_VERSION, + "Can't downgrade the maximum version of this feature without " + + "setting downgradable to true."))), + manager.updateFeatures(rangeMap("foo", 1, 2), + Collections.emptySet(), Collections.emptyMap())); + + assertEquals( + ControllerResult.atomicOf( + Collections.singletonList( + new ApiMessageAndVersion( + new FeatureLevelRecord() + .setName("foo") + .setMinFeatureLevel((short) 1) + .setMaxFeatureLevel((short) 2), + (short) 0 + ) + ), + Collections.singletonMap("foo", ApiError.NONE) + ), + manager.updateFeatures( + rangeMap("foo", 1, 2), + Collections.singleton("foo"), + Collections.emptyMap() + ) + ); + } + + @Test + public void testFeatureControlIterator() throws Exception { + SnapshotRegistry snapshotRegistry = new SnapshotRegistry(new LogContext()); + FeatureControlManager manager = new FeatureControlManager( + rangeMap("foo", 1, 5, "bar", 1, 2), snapshotRegistry); + ControllerResult> result = manager. 
+ updateFeatures(rangeMap("foo", 1, 5, "bar", 1, 1), + Collections.emptySet(), Collections.emptyMap()); + RecordTestUtils.replayAll(manager, result.records()); + RecordTestUtils.assertBatchIteratorContains(Arrays.asList( + Arrays.asList(new ApiMessageAndVersion(new FeatureLevelRecord(). + setName("foo"). + setMinFeatureLevel((short) 1). + setMaxFeatureLevel((short) 5), (short) 0)), + Arrays.asList(new ApiMessageAndVersion(new FeatureLevelRecord(). + setName("bar"). + setMinFeatureLevel((short) 1). + setMaxFeatureLevel((short) 1), (short) 0))), + manager.iterator(Long.MAX_VALUE)); + } +} diff --git a/metadata/src/test/java/org/apache/kafka/controller/MockControllerMetrics.java b/metadata/src/test/java/org/apache/kafka/controller/MockControllerMetrics.java new file mode 100644 index 0000000..0120f15 --- /dev/null +++ b/metadata/src/test/java/org/apache/kafka/controller/MockControllerMetrics.java @@ -0,0 +1,128 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.controller; + +public final class MockControllerMetrics implements ControllerMetrics { + private volatile boolean active; + private volatile int fencedBrokers; + private volatile int activeBrokers; + private volatile int topics; + private volatile int partitions; + private volatile int offlinePartitions; + private volatile int preferredReplicaImbalances; + private volatile boolean closed = false; + + public MockControllerMetrics() { + this.active = false; + this.fencedBrokers = 0; + this.activeBrokers = 0; + this.topics = 0; + this.partitions = 0; + this.offlinePartitions = 0; + this.preferredReplicaImbalances = 0; + } + + @Override + public void setActive(boolean active) { + this.active = active; + } + + @Override + public boolean active() { + return this.active; + } + + @Override + public void updateEventQueueTime(long durationMs) { + // nothing to do + } + + @Override + public void updateEventQueueProcessingTime(long durationMs) { + // nothing to do + } + + @Override + public void setFencedBrokerCount(int brokerCount) { + this.fencedBrokers = brokerCount; + } + + @Override + public int fencedBrokerCount() { + return this.fencedBrokers; + } + + @Override + public void setActiveBrokerCount(int brokerCount) { + this.activeBrokers = brokerCount; + } + + @Override + public int activeBrokerCount() { + return activeBrokers; + } + + @Override + public void setGlobalTopicsCount(int topicCount) { + this.topics = topicCount; + } + + @Override + public int globalTopicsCount() { + return this.topics; + } + + @Override + public void setGlobalPartitionCount(int partitionCount) { + this.partitions = partitionCount; + } + + @Override + public int globalPartitionCount() { + return this.partitions; + } + + @Override + public void setOfflinePartitionCount(int 
offlinePartitions) { + this.offlinePartitions = offlinePartitions; + } + + @Override + public int offlinePartitionCount() { + return this.offlinePartitions; + } + + @Override + public void setPreferredReplicaImbalanceCount(int replicaImbalances) { + this.preferredReplicaImbalances = replicaImbalances; + } + + @Override + public int preferredReplicaImbalanceCount() { + return this.preferredReplicaImbalances; + } + + @Override + public void close() { + closed = true; + } + + public boolean isClosed() { + return this.closed; + } +} diff --git a/metadata/src/test/java/org/apache/kafka/controller/MockRandom.java b/metadata/src/test/java/org/apache/kafka/controller/MockRandom.java new file mode 100644 index 0000000..c42a158 --- /dev/null +++ b/metadata/src/test/java/org/apache/kafka/controller/MockRandom.java @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.controller; + +import java.util.Random; + + +/** + * A subclass of Random with a fixed seed and generation algorithm. + */ +public class MockRandom extends Random { + private long state = 17; + + @Override + protected int next(int bits) { + state = (state * 2862933555777941757L) + 3037000493L; + return (int) (state >>> (64 - bits)); + } +} diff --git a/metadata/src/test/java/org/apache/kafka/controller/PartitionChangeBuilderTest.java b/metadata/src/test/java/org/apache/kafka/controller/PartitionChangeBuilderTest.java new file mode 100644 index 0000000..f935a80 --- /dev/null +++ b/metadata/src/test/java/org/apache/kafka/controller/PartitionChangeBuilderTest.java @@ -0,0 +1,255 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
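Because MockRandom (above) overrides Random.next(int) with a constant-seeded linear congruential step, any placement logic driven by it is fully deterministic across runs. A tiny usage sketch:

    Random random = new MockRandom();
    int sample = random.nextInt(100);            // identical on every run, since the internal state always starts at 17
    StripedReplicaPlacer placer = new StripedReplicaPlacer(random);   // deterministic replica placement in tests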
+ */ + +package org.apache.kafka.controller; + +import org.apache.kafka.common.Uuid; +import org.apache.kafka.common.metadata.PartitionChangeRecord; +import org.apache.kafka.controller.PartitionChangeBuilder.BestLeader; +import org.apache.kafka.metadata.PartitionRegistration; +import org.apache.kafka.metadata.Replicas; +import org.apache.kafka.server.common.ApiMessageAndVersion; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; + +import java.util.Arrays; +import java.util.Collections; +import java.util.Optional; + +import static org.apache.kafka.common.metadata.MetadataRecordType.PARTITION_CHANGE_RECORD; +import static org.apache.kafka.controller.PartitionChangeBuilder.changeRecordIsNoOp; +import static org.apache.kafka.metadata.LeaderConstants.NO_LEADER; +import static org.apache.kafka.metadata.LeaderConstants.NO_LEADER_CHANGE; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + + +@Timeout(value = 40) +public class PartitionChangeBuilderTest { + @Test + public void testChangeRecordIsNoOp() { + assertTrue(changeRecordIsNoOp(new PartitionChangeRecord())); + assertFalse(changeRecordIsNoOp(new PartitionChangeRecord().setLeader(1))); + assertFalse(changeRecordIsNoOp(new PartitionChangeRecord(). + setIsr(Arrays.asList(1, 2, 3)))); + assertFalse(changeRecordIsNoOp(new PartitionChangeRecord(). + setRemovingReplicas(Arrays.asList(1)))); + assertFalse(changeRecordIsNoOp(new PartitionChangeRecord(). + setAddingReplicas(Arrays.asList(4)))); + } + + private final static PartitionRegistration FOO = new PartitionRegistration( + new int[] {2, 1, 3}, new int[] {2, 1, 3}, Replicas.NONE, Replicas.NONE, + 1, 100, 200); + + private final static Uuid FOO_ID = Uuid.fromString("FbrrdcfiR-KC2CPSTHaJrg"); + + private static PartitionChangeBuilder createFooBuilder(boolean allowUnclean) { + return new PartitionChangeBuilder(FOO, FOO_ID, 0, r -> r != 3, () -> allowUnclean); + } + + private final static PartitionRegistration BAR = new PartitionRegistration( + new int[] {1, 2, 3, 4}, new int[] {1, 2, 3}, new int[] {1}, new int[] {4}, + 1, 100, 200); + + private final static Uuid BAR_ID = Uuid.fromString("LKfUsCBnQKekvL9O5dY9nw"); + + private static PartitionChangeBuilder createBarBuilder(boolean allowUnclean) { + return new PartitionChangeBuilder(BAR, BAR_ID, 0, r -> r != 3, () -> allowUnclean); + } + + private static void assertBestLeaderEquals(PartitionChangeBuilder builder, + int expectedNode, + boolean expectedUnclean) { + BestLeader bestLeader = builder.new BestLeader(); + assertEquals(expectedNode, bestLeader.node); + assertEquals(expectedUnclean, bestLeader.unclean); + } + + @Test + public void testBestLeader() { + assertBestLeaderEquals(createFooBuilder(false), 2, false); + assertBestLeaderEquals(createFooBuilder(true), 2, false); + assertBestLeaderEquals(createFooBuilder(false). + setTargetIsr(Arrays.asList(1, 3)), 1, false); + assertBestLeaderEquals(createFooBuilder(true). + setTargetIsr(Arrays.asList(1, 3)), 1, false); + assertBestLeaderEquals(createFooBuilder(false). + setTargetIsr(Arrays.asList(3)), NO_LEADER, false); + assertBestLeaderEquals(createFooBuilder(true). + setTargetIsr(Arrays.asList(3)), 2, true); + assertBestLeaderEquals(createFooBuilder(true). 
+ setTargetIsr(Arrays.asList(4)).setTargetReplicas(Arrays.asList(2, 1, 3, 4)), + 4, false); + } + + @Test + public void testShouldTryElection() { + assertFalse(createFooBuilder(false).shouldTryElection()); + assertTrue(createFooBuilder(false).setAlwaysElectPreferredIfPossible(true). + shouldTryElection()); + assertTrue(createFooBuilder(false).setTargetIsr(Arrays.asList(2, 3)). + shouldTryElection()); + assertFalse(createFooBuilder(false).setTargetIsr(Arrays.asList(2, 1)). + shouldTryElection()); + + assertTrue(createFooBuilder(true) + .setTargetIsr(Arrays.asList(3)) + .shouldTryElection()); + assertTrue(createFooBuilder(true) + .setTargetIsr(Arrays.asList(4)) + .setTargetReplicas(Arrays.asList(2, 1, 3, 4)) + .shouldTryElection()); + } + + private static void testTriggerLeaderEpochBumpIfNeededLeader(PartitionChangeBuilder builder, + PartitionChangeRecord record, + int expectedLeader) { + builder.triggerLeaderEpochBumpIfNeeded(record); + assertEquals(expectedLeader, record.leader()); + } + + @Test + public void testTriggerLeaderEpochBumpIfNeeded() { + testTriggerLeaderEpochBumpIfNeededLeader(createFooBuilder(false), + new PartitionChangeRecord(), NO_LEADER_CHANGE); + testTriggerLeaderEpochBumpIfNeededLeader(createFooBuilder(false). + setTargetIsr(Arrays.asList(2, 1)), new PartitionChangeRecord(), 1); + testTriggerLeaderEpochBumpIfNeededLeader(createFooBuilder(false). + setTargetIsr(Arrays.asList(2, 1, 3, 4)), new PartitionChangeRecord(), + NO_LEADER_CHANGE); + testTriggerLeaderEpochBumpIfNeededLeader(createFooBuilder(false). + setTargetReplicas(Arrays.asList(2, 1, 3, 4)), new PartitionChangeRecord(), + NO_LEADER_CHANGE); + testTriggerLeaderEpochBumpIfNeededLeader(createFooBuilder(false). + setTargetReplicas(Arrays.asList(2, 1, 3, 4)), + new PartitionChangeRecord().setLeader(2), 2); + } + + @Test + public void testNoChange() { + assertEquals(Optional.empty(), createFooBuilder(false).build()); + assertEquals(Optional.empty(), createFooBuilder(true).build()); + assertEquals(Optional.empty(), createBarBuilder(false).build()); + assertEquals(Optional.empty(), createBarBuilder(true).build()); + } + + @Test + public void testIsrChangeAndLeaderBump() { + assertEquals(Optional.of(new ApiMessageAndVersion(new PartitionChangeRecord(). + setTopicId(FOO_ID). + setPartitionId(0). + setIsr(Arrays.asList(2, 1)). + setLeader(1), PARTITION_CHANGE_RECORD.highestSupportedVersion())), + createFooBuilder(false).setTargetIsr(Arrays.asList(2, 1)).build()); + } + + @Test + public void testIsrChangeAndLeaderChange() { + assertEquals(Optional.of(new ApiMessageAndVersion(new PartitionChangeRecord(). + setTopicId(FOO_ID). + setPartitionId(0). + setIsr(Arrays.asList(2, 3)). + setLeader(2), PARTITION_CHANGE_RECORD.highestSupportedVersion())), + createFooBuilder(false).setTargetIsr(Arrays.asList(2, 3)).build()); + } + + @Test + public void testReassignmentRearrangesReplicas() { + assertEquals(Optional.of(new ApiMessageAndVersion(new PartitionChangeRecord(). + setTopicId(FOO_ID). + setPartitionId(0). + setReplicas(Arrays.asList(3, 2, 1)), + PARTITION_CHANGE_RECORD.highestSupportedVersion())), + createFooBuilder(false).setTargetReplicas(Arrays.asList(3, 2, 1)).build()); + } + + @Test + public void testIsrEnlargementCompletesReassignment() { + assertEquals(Optional.of(new ApiMessageAndVersion(new PartitionChangeRecord(). + setTopicId(BAR_ID). + setPartitionId(0). + setReplicas(Arrays.asList(2, 3, 4)). + setIsr(Arrays.asList(2, 3, 4)). + setLeader(2). + setRemovingReplicas(Collections.emptyList()). 
+ setAddingReplicas(Collections.emptyList()), + PARTITION_CHANGE_RECORD.highestSupportedVersion())), + createBarBuilder(false).setTargetIsr(Arrays.asList(1, 2, 3, 4)).build()); + } + + @Test + public void testRevertReassignment() { + PartitionReassignmentRevert revert = new PartitionReassignmentRevert(BAR); + assertEquals(Arrays.asList(1, 2, 3), revert.replicas()); + assertEquals(Arrays.asList(1, 2, 3), revert.isr()); + assertEquals(Optional.of(new ApiMessageAndVersion(new PartitionChangeRecord(). + setTopicId(BAR_ID). + setPartitionId(0). + setReplicas(Arrays.asList(1, 2, 3)). + setLeader(1). + setRemovingReplicas(Collections.emptyList()). + setAddingReplicas(Collections.emptyList()), + PARTITION_CHANGE_RECORD.highestSupportedVersion())), + createBarBuilder(false). + setTargetReplicas(revert.replicas()). + setTargetIsr(revert.isr()). + setTargetRemoving(Collections.emptyList()). + setTargetAdding(Collections.emptyList()). + build()); + } + + @Test + public void testRemovingReplicaReassignment() { + PartitionReassignmentReplicas replicas = new PartitionReassignmentReplicas( + Replicas.toList(FOO.replicas), Arrays.asList(1, 2)); + assertEquals(Collections.singletonList(3), replicas.removing()); + assertEquals(Collections.emptyList(), replicas.adding()); + assertEquals(Arrays.asList(1, 2, 3), replicas.merged()); + assertEquals(Optional.of(new ApiMessageAndVersion(new PartitionChangeRecord(). + setTopicId(FOO_ID). + setPartitionId(0). + setReplicas(Arrays.asList(1, 2)). + setIsr(Arrays.asList(2, 1)). + setLeader(1), + PARTITION_CHANGE_RECORD.highestSupportedVersion())), + createFooBuilder(false). + setTargetReplicas(replicas.merged()). + setTargetRemoving(replicas.removing()). + build()); + } + + @Test + public void testAddingReplicaReassignment() { + PartitionReassignmentReplicas replicas = new PartitionReassignmentReplicas( + Replicas.toList(FOO.replicas), Arrays.asList(1, 2, 3, 4)); + assertEquals(Collections.emptyList(), replicas.removing()); + assertEquals(Collections.singletonList(4), replicas.adding()); + assertEquals(Arrays.asList(1, 2, 3, 4), replicas.merged()); + assertEquals(Optional.of(new ApiMessageAndVersion(new PartitionChangeRecord(). + setTopicId(FOO_ID). + setPartitionId(0). + setReplicas(Arrays.asList(1, 2, 3, 4)). + setAddingReplicas(Collections.singletonList(4)), + PARTITION_CHANGE_RECORD.highestSupportedVersion())), + createFooBuilder(false). + setTargetReplicas(replicas.merged()). + setTargetAdding(replicas.adding()). + build()); + } +} diff --git a/metadata/src/test/java/org/apache/kafka/controller/PartitionReassignmentReplicasTest.java b/metadata/src/test/java/org/apache/kafka/controller/PartitionReassignmentReplicasTest.java new file mode 100644 index 0000000..c74090e --- /dev/null +++ b/metadata/src/test/java/org/apache/kafka/controller/PartitionReassignmentReplicasTest.java @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
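The PartitionChangeBuilder tests above all follow one pattern: seed the builder with a PartitionRegistration, describe the target state, and inspect the optional PartitionChangeRecord it emits. A hedged sketch using the same constructor shape and fixture values as the test file:

    PartitionRegistration partition = new PartitionRegistration(
        new int[] {2, 1, 3}, new int[] {2, 1, 3}, Replicas.NONE, Replicas.NONE, 1, 100, 200);
    Uuid topicId = Uuid.fromString("FbrrdcfiR-KC2CPSTHaJrg");
    Optional<ApiMessageAndVersion> change =
        new PartitionChangeBuilder(partition, topicId, 0, replica -> replica != 3, () -> false)
            .setTargetIsr(Arrays.asList(2, 1))   // shrinking the ISR bumps the leader field, as in testIsrChangeAndLeaderBump
            .build();                            // Optional.empty() when nothing would change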
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.controller; + +import java.util.Arrays; +import java.util.Collections; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; + +import static org.junit.jupiter.api.Assertions.assertEquals; + + +@Timeout(40) +public class PartitionReassignmentReplicasTest { + @Test + public void testNoneAddedOrRemoved() { + PartitionReassignmentReplicas replicas = new PartitionReassignmentReplicas( + Arrays.asList(3, 2, 1), Arrays.asList(3, 2, 1)); + assertEquals(Collections.emptyList(), replicas.removing()); + assertEquals(Collections.emptyList(), replicas.adding()); + assertEquals(Arrays.asList(3, 2, 1), replicas.merged()); + } + + @Test + public void testAdditions() { + PartitionReassignmentReplicas replicas = new PartitionReassignmentReplicas( + Arrays.asList(3, 2, 1), Arrays.asList(3, 6, 2, 1, 5)); + assertEquals(Collections.emptyList(), replicas.removing()); + assertEquals(Arrays.asList(5, 6), replicas.adding()); + assertEquals(Arrays.asList(3, 6, 2, 1, 5), replicas.merged()); + } + + @Test + public void testRemovals() { + PartitionReassignmentReplicas replicas = new PartitionReassignmentReplicas( + Arrays.asList(3, 2, 1, 0), Arrays.asList(3, 1)); + assertEquals(Arrays.asList(0, 2), replicas.removing()); + assertEquals(Collections.emptyList(), replicas.adding()); + assertEquals(Arrays.asList(3, 1, 0, 2), replicas.merged()); + } + + @Test + public void testAdditionsAndRemovals() { + PartitionReassignmentReplicas replicas = new PartitionReassignmentReplicas( + Arrays.asList(3, 2, 1, 0), Arrays.asList(7, 3, 1, 9)); + assertEquals(Arrays.asList(0, 2), replicas.removing()); + assertEquals(Arrays.asList(7, 9), replicas.adding()); + assertEquals(Arrays.asList(7, 3, 1, 9, 0, 2), replicas.merged()); + } + + @Test + public void testRearrangement() { + PartitionReassignmentReplicas replicas = new PartitionReassignmentReplicas( + Arrays.asList(3, 2, 1, 0), Arrays.asList(0, 1, 3, 2)); + assertEquals(Collections.emptyList(), replicas.removing()); + assertEquals(Collections.emptyList(), replicas.adding()); + assertEquals(Arrays.asList(0, 1, 3, 2), replicas.merged()); + } +} diff --git a/metadata/src/test/java/org/apache/kafka/controller/PartitionReassignmentRevertTest.java b/metadata/src/test/java/org/apache/kafka/controller/PartitionReassignmentRevertTest.java new file mode 100644 index 0000000..26120be --- /dev/null +++ b/metadata/src/test/java/org/apache/kafka/controller/PartitionReassignmentRevertTest.java @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.controller; + +import java.util.Arrays; + +import org.apache.kafka.metadata.PartitionRegistration; +import org.apache.kafka.metadata.Replicas; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + + +@Timeout(40) +public class PartitionReassignmentRevertTest { + @Test + public void testNoneAddedOrRemoved() { + PartitionRegistration registration = new PartitionRegistration( + new int[] {3, 2, 1}, new int[] {3, 2}, + Replicas.NONE, Replicas.NONE, 3, 100, 200); + PartitionReassignmentRevert revert = new PartitionReassignmentRevert(registration); + assertEquals(Arrays.asList(3, 2, 1), revert.replicas()); + assertEquals(Arrays.asList(3, 2), revert.isr()); + assertFalse(revert.unclean()); + } + + @Test + public void testSomeRemoving() { + PartitionRegistration registration = new PartitionRegistration( + new int[] {3, 2, 1}, new int[] {3, 2}, + new int[] {2, 1}, Replicas.NONE, 3, 100, 200); + PartitionReassignmentRevert revert = new PartitionReassignmentRevert(registration); + assertEquals(Arrays.asList(3, 2, 1), revert.replicas()); + assertEquals(Arrays.asList(3, 2), revert.isr()); + assertFalse(revert.unclean()); + } + + @Test + public void testSomeAdding() { + PartitionRegistration registration = new PartitionRegistration( + new int[] {4, 5, 3, 2, 1}, new int[] {4, 5, 2}, + Replicas.NONE, new int[] {4, 5}, 3, 100, 200); + PartitionReassignmentRevert revert = new PartitionReassignmentRevert(registration); + assertEquals(Arrays.asList(3, 2, 1), revert.replicas()); + assertEquals(Arrays.asList(2), revert.isr()); + assertFalse(revert.unclean()); + } + + @Test + public void testSomeRemovingAndAdding() { + PartitionRegistration registration = new PartitionRegistration( + new int[] {4, 5, 3, 2, 1}, new int[] {4, 5, 2}, + new int[] {2}, new int[] {4, 5}, 3, 100, 200); + PartitionReassignmentRevert revert = new PartitionReassignmentRevert(registration); + assertEquals(Arrays.asList(3, 2, 1), revert.replicas()); + assertEquals(Arrays.asList(2), revert.isr()); + assertFalse(revert.unclean()); + } + + @Test + public void testIsrSpecialCase() { + PartitionRegistration registration = new PartitionRegistration( + new int[] {4, 5, 3, 2, 1}, new int[] {4, 5}, + new int[] {2}, new int[] {4, 5}, 3, 100, 200); + PartitionReassignmentRevert revert = new PartitionReassignmentRevert(registration); + assertEquals(Arrays.asList(3, 2, 1), revert.replicas()); + assertEquals(Arrays.asList(3), revert.isr()); + assertTrue(revert.unclean()); + } +} diff --git a/metadata/src/test/java/org/apache/kafka/controller/ProducerIdControlManagerTest.java b/metadata/src/test/java/org/apache/kafka/controller/ProducerIdControlManagerTest.java new file mode 100644 index 0000000..990395b --- /dev/null +++ b/metadata/src/test/java/org/apache/kafka/controller/ProducerIdControlManagerTest.java @@ -0,0 +1,173 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * 
contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.controller; + +import org.apache.kafka.common.errors.StaleBrokerEpochException; +import org.apache.kafka.common.errors.UnknownServerException; +import org.apache.kafka.common.metadata.ProducerIdsRecord; +import org.apache.kafka.common.metadata.RegisterBrokerRecord; +import org.apache.kafka.common.security.auth.SecurityProtocol; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.server.common.ApiMessageAndVersion; +import org.apache.kafka.server.common.ProducerIdsBlock; +import org.apache.kafka.timeline.SnapshotRegistry; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.util.Iterator; +import java.util.List; +import java.util.Random; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + + +public class ProducerIdControlManagerTest { + + private SnapshotRegistry snapshotRegistry; + private ClusterControlManager clusterControl; + private ProducerIdControlManager producerIdControlManager; + + @BeforeEach + public void setUp() { + final LogContext logContext = new LogContext(); + final MockTime time = new MockTime(); + final Random random = new Random(); + snapshotRegistry = new SnapshotRegistry(logContext); + clusterControl = new ClusterControlManager( + logContext, time, snapshotRegistry, 1000, + new StripedReplicaPlacer(random), new MockControllerMetrics()); + + clusterControl.activate(); + for (int i = 0; i < 4; i++) { + RegisterBrokerRecord brokerRecord = new RegisterBrokerRecord().setBrokerEpoch(100).setBrokerId(i); + brokerRecord.endPoints().add(new RegisterBrokerRecord.BrokerEndpoint(). + setSecurityProtocol(SecurityProtocol.PLAINTEXT.id). + setPort((short) 9092). + setName("PLAINTEXT"). 
+ setHost(String.format("broker-%02d.example.org", i))); + clusterControl.replay(brokerRecord); + } + + this.producerIdControlManager = new ProducerIdControlManager(clusterControl, snapshotRegistry); + } + + @Test + public void testInitialResult() { + ControllerResult result = + producerIdControlManager.generateNextProducerId(1, 100); + assertEquals(0, result.response().producerIdStart()); + assertEquals(1000, result.response().producerIdLen()); + ProducerIdsRecord record = (ProducerIdsRecord) result.records().get(0).message(); + assertEquals(1000, record.producerIdsEnd()); + } + + @Test + public void testMonotonic() { + producerIdControlManager.replay( + new ProducerIdsRecord() + .setBrokerId(1) + .setBrokerEpoch(100) + .setProducerIdsEnd(42)); + + ProducerIdsBlock range = + producerIdControlManager.generateNextProducerId(1, 100).response(); + assertEquals(42, range.producerIdStart()); + + // Can't go backwards in Producer IDs + assertThrows(RuntimeException.class, () -> { + producerIdControlManager.replay( + new ProducerIdsRecord() + .setBrokerId(1) + .setBrokerEpoch(100) + .setProducerIdsEnd(40)); + }, "Producer ID range must only increase"); + range = producerIdControlManager.generateNextProducerId(1, 100).response(); + assertEquals(42, range.producerIdStart()); + + // Gaps in the ID range are okay. + producerIdControlManager.replay( + new ProducerIdsRecord() + .setBrokerId(1) + .setBrokerEpoch(100) + .setProducerIdsEnd(50)); + range = producerIdControlManager.generateNextProducerId(1, 100).response(); + assertEquals(50, range.producerIdStart()); + } + + @Test + public void testUnknownBrokerOrEpoch() { + ControllerResult result; + + assertThrows(StaleBrokerEpochException.class, () -> + producerIdControlManager.generateNextProducerId(99, 0)); + + assertThrows(StaleBrokerEpochException.class, () -> + producerIdControlManager.generateNextProducerId(1, 99)); + } + + @Test + public void testMaxValue() { + producerIdControlManager.replay( + new ProducerIdsRecord() + .setBrokerId(1) + .setBrokerEpoch(100) + .setProducerIdsEnd(Long.MAX_VALUE - 1)); + + assertThrows(UnknownServerException.class, () -> + producerIdControlManager.generateNextProducerId(1, 100)); + } + + @Test + public void testSnapshotIterator() { + ProducerIdsBlock range = null; + for (int i = 0; i < 100; i++) { + range = generateProducerIds(producerIdControlManager, i % 4, 100); + } + + Iterator> snapshotIterator = producerIdControlManager.iterator(Long.MAX_VALUE); + assertTrue(snapshotIterator.hasNext()); + List batch = snapshotIterator.next(); + assertEquals(1, batch.size(), "Producer IDs record batch should only contain a single record"); + assertEquals(range.producerIdStart() + range.producerIdLen(), ((ProducerIdsRecord) batch.get(0).message()).producerIdsEnd()); + assertFalse(snapshotIterator.hasNext(), "Producer IDs iterator should only contain a single batch"); + + ProducerIdControlManager newProducerIdManager = new ProducerIdControlManager(clusterControl, snapshotRegistry); + snapshotIterator = producerIdControlManager.iterator(Long.MAX_VALUE); + while (snapshotIterator.hasNext()) { + snapshotIterator.next().forEach(message -> newProducerIdManager.replay((ProducerIdsRecord) message.message())); + } + + // Verify that after reloading state from this "snapshot", we don't produce any overlapping IDs + long lastProducerID = range.producerIdStart() + range.producerIdLen() - 1; + range = generateProducerIds(producerIdControlManager, 1, 100); + assertTrue(range.producerIdStart() > lastProducerID); + } + + static 
ProducerIdsBlock generateProducerIds( + ProducerIdControlManager producerIdControlManager, int brokerId, long brokerEpoch) { + ControllerResult result = + producerIdControlManager.generateNextProducerId(brokerId, brokerEpoch); + result.records().forEach(apiMessageAndVersion -> + producerIdControlManager.replay((ProducerIdsRecord) apiMessageAndVersion.message())); + return result.response(); + } +} diff --git a/metadata/src/test/java/org/apache/kafka/controller/QuorumControllerMetricsTest.java b/metadata/src/test/java/org/apache/kafka/controller/QuorumControllerMetricsTest.java new file mode 100644 index 0000000..74b24c7 --- /dev/null +++ b/metadata/src/test/java/org/apache/kafka/controller/QuorumControllerMetricsTest.java @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.controller; + +import com.yammer.metrics.core.MetricsRegistry; +import org.apache.kafka.common.utils.Utils; +import org.junit.jupiter.api.Test; + +import java.util.Set; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class QuorumControllerMetricsTest { + private static final String EXPECTED_GROUP = "kafka.controller"; + @Test + public void testKafkaControllerMetricNames() { + String expectedType = "KafkaController"; + Set expectedMetricNames = Utils.mkSet( + "ActiveControllerCount", + "GlobalTopicCount", + "GlobalPartitionCount", + "OfflinePartitionsCount", + "PreferredReplicaImbalanceCount"); + assertMetricsCreatedAndRemovedUponClose(expectedType, expectedMetricNames); + } + + @Test + public void testControllerEventManagerMetricNames() { + String expectedType = "ControllerEventManager"; + Set expectedMetricNames = Utils.mkSet( + "EventQueueTimeMs", + "EventQueueProcessingTimeMs"); + assertMetricsCreatedAndRemovedUponClose(expectedType, expectedMetricNames); + } + + private static void assertMetricsCreatedAndRemovedUponClose(String expectedType, Set expectedMetricNames) { + MetricsRegistry registry = new MetricsRegistry(); + try (QuorumControllerMetrics quorumControllerMetrics = new QuorumControllerMetrics(registry)) { + assertMetricsCreated(registry, expectedMetricNames, expectedType); + } + assertMetricsRemoved(registry, expectedMetricNames, expectedType); + } + + private static void assertMetricsCreated(MetricsRegistry registry, Set expectedMetricNames, String expectedType) { + expectedMetricNames.forEach(expectedMetricName -> assertTrue( + registry.allMetrics().keySet().stream().anyMatch(metricName -> { + if (metricName.getGroup().equals(EXPECTED_GROUP) && metricName.getType().equals(expectedType) + && metricName.getScope() == null && metricName.getName().equals(expectedMetricName)) { + // It has to exist AND the MBean name has to be correct; 
+ // fail right here if the MBean name doesn't match + String expectedMBeanPrefix = EXPECTED_GROUP + ":type=" + expectedType + ",name="; + assertEquals(expectedMBeanPrefix + expectedMetricName, metricName.getMBeanName(), + "Incorrect MBean name"); + return true; // the metric name exists and the associated MBean name matches + } else { + return false; // this one didn't match + } + }), "Missing metric: " + expectedMetricName)); + } + + private static void assertMetricsRemoved(MetricsRegistry registry, Set expectedMetricNames, String expectedType) { + expectedMetricNames.forEach(expectedMetricName -> assertTrue( + registry.allMetrics().keySet().stream().noneMatch(metricName -> + metricName.getGroup().equals(EXPECTED_GROUP) && metricName.getType().equals(expectedType) + && metricName.getScope() == null && metricName.getName().equals(expectedMetricName)), + "Metric not removed when closed: " + expectedMetricName)); + } +} diff --git a/metadata/src/test/java/org/apache/kafka/controller/QuorumControllerTest.java b/metadata/src/test/java/org/apache/kafka/controller/QuorumControllerTest.java new file mode 100644 index 0000000..c274de6 --- /dev/null +++ b/metadata/src/test/java/org/apache/kafka/controller/QuorumControllerTest.java @@ -0,0 +1,868 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.controller; + +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Spliterator; +import java.util.Spliterators; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Future; +import java.util.stream.Collectors; +import java.util.stream.StreamSupport; +import java.util.stream.IntStream; + +import org.apache.kafka.common.Uuid; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.common.config.ConfigResource; +import org.apache.kafka.common.errors.TimeoutException; +import org.apache.kafka.common.message.AllocateProducerIdsRequestData; +import org.apache.kafka.common.message.AlterIsrRequestData; +import org.apache.kafka.common.message.AlterPartitionReassignmentsRequestData.ReassignableTopic; +import org.apache.kafka.common.message.AlterPartitionReassignmentsRequestData; +import org.apache.kafka.common.message.AlterPartitionReassignmentsResponseData; +import org.apache.kafka.common.message.BrokerHeartbeatRequestData; +import org.apache.kafka.common.message.BrokerRegistrationRequestData.Listener; +import org.apache.kafka.common.message.BrokerRegistrationRequestData.ListenerCollection; +import org.apache.kafka.common.message.BrokerRegistrationRequestData; +import org.apache.kafka.common.message.CreatePartitionsRequestData.CreatePartitionsTopic; +import org.apache.kafka.common.message.CreatePartitionsResponseData.CreatePartitionsTopicResult; +import org.apache.kafka.common.message.CreateTopicsRequestData.CreatableReplicaAssignment; +import org.apache.kafka.common.message.CreateTopicsRequestData.CreatableReplicaAssignmentCollection; +import org.apache.kafka.common.message.CreateTopicsRequestData; +import org.apache.kafka.common.message.CreateTopicsRequestData.CreatableTopic; +import org.apache.kafka.common.message.CreateTopicsRequestData.CreatableTopicCollection; +import org.apache.kafka.common.message.CreateTopicsResponseData; +import org.apache.kafka.common.message.ElectLeadersRequestData; +import org.apache.kafka.common.message.ElectLeadersResponseData; +import org.apache.kafka.common.message.ListPartitionReassignmentsRequestData; +import org.apache.kafka.common.message.ListPartitionReassignmentsResponseData; +import org.apache.kafka.common.metadata.PartitionRecord; +import org.apache.kafka.common.metadata.ProducerIdsRecord; +import org.apache.kafka.common.metadata.RegisterBrokerRecord; +import org.apache.kafka.common.metadata.RegisterBrokerRecord.BrokerEndpoint; +import org.apache.kafka.common.metadata.RegisterBrokerRecord.BrokerEndpointCollection; +import org.apache.kafka.common.metadata.TopicRecord; +import org.apache.kafka.common.protocol.Errors; +import org.apache.kafka.common.requests.ApiError; +import org.apache.kafka.common.utils.BufferSupplier; +import org.apache.kafka.controller.BrokersToIsrs.TopicIdPartition; +import org.apache.kafka.metadata.BrokerHeartbeatReply; +import org.apache.kafka.metadata.BrokerRegistrationReply; +import org.apache.kafka.metadata.MetadataRecordSerde; +import org.apache.kafka.metadata.PartitionRegistration; +import org.apache.kafka.metadata.RecordTestUtils; +import org.apache.kafka.metalog.LocalLogManagerTestEnv; +import org.apache.kafka.raft.Batch; +import org.apache.kafka.server.common.ApiMessageAndVersion; +import 
org.apache.kafka.snapshot.RawSnapshotReader; +import org.apache.kafka.snapshot.SnapshotReader; +import org.apache.kafka.test.TestUtils; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; + +import static java.util.concurrent.TimeUnit.HOURS; +import static org.apache.kafka.clients.admin.AlterConfigOp.OpType.SET; +import static org.apache.kafka.controller.ConfigurationControlManagerTest.BROKER0; +import static org.apache.kafka.controller.ConfigurationControlManagerTest.CONFIGS; +import static org.apache.kafka.controller.ConfigurationControlManagerTest.entry; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertSame; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + + +@Timeout(value = 40) +public class QuorumControllerTest { + + /** + * Test creating a new QuorumController and closing it. + */ + @Test + public void testCreateAndClose() throws Throwable { + MockControllerMetrics metrics = new MockControllerMetrics(); + try ( + LocalLogManagerTestEnv logEnv = new LocalLogManagerTestEnv(1, Optional.empty()); + QuorumControllerTestEnv controlEnv = + new QuorumControllerTestEnv(logEnv, builder -> builder.setMetrics(metrics)) + ) { + } + assertTrue(metrics.isClosed(), "metrics were not closed"); + } + + /** + * Test setting some configuration values and reading them back. + */ + @Test + public void testConfigurationOperations() throws Throwable { + try ( + LocalLogManagerTestEnv logEnv = new LocalLogManagerTestEnv(1, Optional.empty()); + QuorumControllerTestEnv controlEnv = new QuorumControllerTestEnv(logEnv, b -> b.setConfigDefs(CONFIGS)) + ) { + testConfigurationOperations(controlEnv.activeController()); + } + } + + private void testConfigurationOperations(QuorumController controller) throws Throwable { + assertEquals(Collections.singletonMap(BROKER0, ApiError.NONE), + controller.incrementalAlterConfigs(Collections.singletonMap( + BROKER0, Collections.singletonMap("baz", entry(SET, "123"))), true).get()); + assertEquals(Collections.singletonMap(BROKER0, + new ResultOrError<>(Collections.emptyMap())), + controller.describeConfigs(Collections.singletonMap( + BROKER0, Collections.emptyList())).get()); + assertEquals(Collections.singletonMap(BROKER0, ApiError.NONE), + controller.incrementalAlterConfigs(Collections.singletonMap( + BROKER0, Collections.singletonMap("baz", entry(SET, "123"))), false).get()); + assertEquals(Collections.singletonMap(BROKER0, new ResultOrError<>(Collections. + singletonMap("baz", "123"))), + controller.describeConfigs(Collections.singletonMap( + BROKER0, Collections.emptyList())).get()); + } + + /** + * Test that an incrementalAlterConfigs operation doesn't complete until the records + * can be written to the metadata log. 
+     */
+    @Test
+    public void testDelayedConfigurationOperations() throws Throwable {
+        try (
+            LocalLogManagerTestEnv logEnv = new LocalLogManagerTestEnv(1, Optional.empty());
+            QuorumControllerTestEnv controlEnv = new QuorumControllerTestEnv(logEnv, b -> b.setConfigDefs(CONFIGS))
+        ) {
+            testDelayedConfigurationOperations(logEnv, controlEnv.activeController());
+        }
+    }
+
+    private void testDelayedConfigurationOperations(LocalLogManagerTestEnv logEnv,
+                                                    QuorumController controller)
+                                                    throws Throwable {
+        logEnv.logManagers().forEach(m -> m.setMaxReadOffset(0L));
+        CompletableFuture<Map<ConfigResource, ApiError>> future1 =
+            controller.incrementalAlterConfigs(Collections.singletonMap(
+                BROKER0, Collections.singletonMap("baz", entry(SET, "123"))), false);
+        assertFalse(future1.isDone());
+        assertEquals(Collections.singletonMap(BROKER0,
+            new ResultOrError<>(Collections.emptyMap())),
+            controller.describeConfigs(Collections.singletonMap(
+                BROKER0, Collections.emptyList())).get());
+        logEnv.logManagers().forEach(m -> m.setMaxReadOffset(1L));
+        assertEquals(Collections.singletonMap(BROKER0, ApiError.NONE), future1.get());
+    }
+
+    @Test
+    public void testFenceMultipleBrokers() throws Throwable {
+        List<Integer> allBrokers = Arrays.asList(1, 2, 3, 4, 5);
+        List<Integer> brokersToKeepUnfenced = Arrays.asList(1);
+        List<Integer> brokersToFence = Arrays.asList(2, 3, 4, 5);
+        short replicationFactor = 5;
+        long sessionTimeoutMillis = 1000;
+
+        try (
+            LocalLogManagerTestEnv logEnv = new LocalLogManagerTestEnv(1, Optional.empty());
+            QuorumControllerTestEnv controlEnv = new QuorumControllerTestEnv(
+                logEnv, b -> b.setConfigDefs(CONFIGS), Optional.of(sessionTimeoutMillis));
+        ) {
+            ListenerCollection listeners = new ListenerCollection();
+            listeners.add(new Listener().setName("PLAINTEXT").setHost("localhost").setPort(9092));
+            QuorumController active = controlEnv.activeController();
+            Map<Integer, Long> brokerEpochs = new HashMap<>();
+
+            for (Integer brokerId : allBrokers) {
+                CompletableFuture<BrokerRegistrationReply> reply = active.registerBroker(
+                    new BrokerRegistrationRequestData().
+                        setBrokerId(brokerId).
+                        setClusterId("06B-K3N1TBCNYFgruEVP0Q").
+                        setIncarnationId(Uuid.randomUuid()).
+                        setListeners(listeners));
+                brokerEpochs.put(brokerId, reply.get().epoch());
+            }
+
+            // Brokers are only registered and should still be fenced
+            allBrokers.forEach(brokerId -> {
+                assertFalse(active.replicationControl().isBrokerUnfenced(brokerId),
+                    "Broker " + brokerId + " should have been fenced");
+            });
+
+            // Unfence all brokers and create a topic foo
+            sendBrokerheartbeat(active, allBrokers, brokerEpochs);
+            CreateTopicsRequestData createTopicsRequestData = new CreateTopicsRequestData().setTopics(
+                new CreatableTopicCollection(Collections.singleton(
+                    new CreatableTopic().setName("foo").setNumPartitions(1).
+ setReplicationFactor(replicationFactor)).iterator())); + CreateTopicsResponseData createTopicsResponseData = active.createTopics(createTopicsRequestData).get(); + assertEquals(Errors.NONE, Errors.forCode(createTopicsResponseData.topics().find("foo").errorCode())); + Uuid topicIdFoo = createTopicsResponseData.topics().find("foo").topicId(); + + // Fence some of the brokers + TestUtils.waitForCondition(() -> { + sendBrokerheartbeat(active, brokersToKeepUnfenced, brokerEpochs); + for (Integer brokerId : brokersToFence) { + if (active.replicationControl().isBrokerUnfenced(brokerId)) { + return false; + } + } + return true; + }, sessionTimeoutMillis * 3, + "Fencing of brokers did not process within expected time" + ); + + // Send another heartbeat to the brokers we want to keep alive + sendBrokerheartbeat(active, brokersToKeepUnfenced, brokerEpochs); + + // At this point only the brokers we want fenced should be fenced. + brokersToKeepUnfenced.forEach(brokerId -> { + assertTrue(active.replicationControl().isBrokerUnfenced(brokerId), + "Broker " + brokerId + " should have been unfenced"); + }); + brokersToFence.forEach(brokerId -> { + assertFalse(active.replicationControl().isBrokerUnfenced(brokerId), + "Broker " + brokerId + " should have been fenced"); + }); + + // Verify the isr and leaders for the topic partition + int[] expectedIsr = {1}; + int[] isrFoo = active.replicationControl().getPartition(topicIdFoo, 0).isr; + + assertTrue(Arrays.equals(isrFoo, expectedIsr), + "The ISR for topic foo was " + Arrays.toString(isrFoo) + + ". It is expected to be " + Arrays.toString(expectedIsr)); + + int fooLeader = active.replicationControl().getPartition(topicIdFoo, 0).leader; + assertEquals(expectedIsr[0], fooLeader); + } + } + + @Test + public void testUnregisterBroker() throws Throwable { + try (LocalLogManagerTestEnv logEnv = new LocalLogManagerTestEnv(1, Optional.empty())) { + try (QuorumControllerTestEnv controlEnv = + new QuorumControllerTestEnv(logEnv, b -> b.setConfigDefs(CONFIGS))) { + ListenerCollection listeners = new ListenerCollection(); + listeners.add(new Listener().setName("PLAINTEXT"). + setHost("localhost").setPort(9092)); + QuorumController active = controlEnv.activeController(); + CompletableFuture reply = active.registerBroker( + new BrokerRegistrationRequestData(). + setBrokerId(0). + setClusterId("06B-K3N1TBCNYFgruEVP0Q"). + setIncarnationId(Uuid.fromString("kxAT73dKQsitIedpiPtwBA")). + setListeners(listeners)); + assertEquals(0L, reply.get().epoch()); + CreateTopicsRequestData createTopicsRequestData = + new CreateTopicsRequestData().setTopics( + new CreatableTopicCollection(Collections.singleton( + new CreatableTopic().setName("foo").setNumPartitions(1). + setReplicationFactor((short) 1)).iterator())); + assertEquals(Errors.INVALID_REPLICATION_FACTOR.code(), active.createTopics( + createTopicsRequestData).get().topics().find("foo").errorCode()); + assertEquals("Unable to replicate the partition 1 time(s): All brokers " + + "are currently fenced.", active.createTopics( + createTopicsRequestData).get().topics().find("foo").errorMessage()); + assertEquals(new BrokerHeartbeatReply(true, false, false, false), + active.processBrokerHeartbeat(new BrokerHeartbeatRequestData(). + setWantFence(false).setBrokerEpoch(0L).setBrokerId(0). 
+ setCurrentMetadataOffset(100000L)).get()); + assertEquals(Errors.NONE.code(), active.createTopics( + createTopicsRequestData).get().topics().find("foo").errorCode()); + CompletableFuture topicPartitionFuture = active.appendReadEvent( + "debugGetPartition", () -> { + Iterator iterator = active. + replicationControl().brokersToIsrs().iterator(0, true); + assertTrue(iterator.hasNext()); + return iterator.next(); + }); + assertEquals(0, topicPartitionFuture.get().partitionId()); + active.unregisterBroker(0).get(); + topicPartitionFuture = active.appendReadEvent( + "debugGetPartition", () -> { + Iterator iterator = active. + replicationControl().brokersToIsrs().partitionsWithNoLeader(); + assertTrue(iterator.hasNext()); + return iterator.next(); + }); + assertEquals(0, topicPartitionFuture.get().partitionId()); + } + } + } + + @Test + public void testSnapshotSaveAndLoad() throws Throwable { + final int numBrokers = 4; + Map brokerEpochs = new HashMap<>(); + RawSnapshotReader reader = null; + Uuid fooId; + try (LocalLogManagerTestEnv logEnv = new LocalLogManagerTestEnv(3, Optional.empty())) { + try (QuorumControllerTestEnv controlEnv = + new QuorumControllerTestEnv(logEnv, b -> b.setConfigDefs(CONFIGS))) { + QuorumController active = controlEnv.activeController(); + for (int i = 0; i < numBrokers; i++) { + BrokerRegistrationReply reply = active.registerBroker( + new BrokerRegistrationRequestData(). + setBrokerId(i). + setRack(null). + setClusterId("06B-K3N1TBCNYFgruEVP0Q"). + setIncarnationId(Uuid.fromString("kxAT73dKQsitIedpiPtwB" + i)). + setListeners(new ListenerCollection(Arrays.asList(new Listener(). + setName("PLAINTEXT").setHost("localhost"). + setPort(9092 + i)).iterator()))).get(); + brokerEpochs.put(i, reply.epoch()); + } + for (int i = 0; i < numBrokers - 1; i++) { + assertEquals(new BrokerHeartbeatReply(true, false, false, false), + active.processBrokerHeartbeat(new BrokerHeartbeatRequestData(). + setWantFence(false).setBrokerEpoch(brokerEpochs.get(i)). + setBrokerId(i).setCurrentMetadataOffset(100000L)).get()); + } + CreateTopicsResponseData fooData = active.createTopics( + new CreateTopicsRequestData().setTopics( + new CreatableTopicCollection(Collections.singleton( + new CreatableTopic().setName("foo").setNumPartitions(-1). + setReplicationFactor((short) -1). + setAssignments(new CreatableReplicaAssignmentCollection( + Arrays.asList(new CreatableReplicaAssignment(). + setPartitionIndex(0). + setBrokerIds(Arrays.asList(0, 1, 2)), + new CreatableReplicaAssignment(). + setPartitionIndex(1). + setBrokerIds(Arrays.asList(1, 2, 0))). 
+ iterator()))).iterator()))).get(); + fooId = fooData.topics().find("foo").topicId(); + active.allocateProducerIds( + new AllocateProducerIdsRequestData().setBrokerId(0).setBrokerEpoch(brokerEpochs.get(0))).get(); + long snapshotLogOffset = active.beginWritingSnapshot().get(); + reader = logEnv.waitForSnapshot(snapshotLogOffset); + SnapshotReader snapshot = createSnapshotReader(reader); + assertEquals(snapshotLogOffset, snapshot.lastContainedLogOffset()); + checkSnapshotContent(expectedSnapshotContent(fooId, brokerEpochs), snapshot); + } + } + + try (LocalLogManagerTestEnv logEnv = new LocalLogManagerTestEnv(3, Optional.of(reader))) { + try (QuorumControllerTestEnv controlEnv = + new QuorumControllerTestEnv(logEnv, b -> b.setConfigDefs(CONFIGS))) { + QuorumController active = controlEnv.activeController(); + long snapshotLogOffset = active.beginWritingSnapshot().get(); + SnapshotReader snapshot = createSnapshotReader( + logEnv.waitForSnapshot(snapshotLogOffset) + ); + assertEquals(snapshotLogOffset, snapshot.lastContainedLogOffset()); + checkSnapshotContent(expectedSnapshotContent(fooId, brokerEpochs), snapshot); + } + } + } + + @Test + public void testSnapshotConfiguration() throws Throwable { + final int numBrokers = 4; + final int maxNewRecordBytes = 4; + Map brokerEpochs = new HashMap<>(); + Uuid fooId; + try (LocalLogManagerTestEnv logEnv = new LocalLogManagerTestEnv(3, Optional.empty())) { + try (QuorumControllerTestEnv controlEnv = new QuorumControllerTestEnv(logEnv, + builder -> { + builder + .setConfigDefs(CONFIGS) + .setSnapshotMaxNewRecordBytes(maxNewRecordBytes); + }) + ) { + + QuorumController active = controlEnv.activeController(); + for (int i = 0; i < numBrokers; i++) { + BrokerRegistrationReply reply = active.registerBroker( + new BrokerRegistrationRequestData(). + setBrokerId(i). + setRack(null). + setClusterId("06B-K3N1TBCNYFgruEVP0Q"). + setIncarnationId(Uuid.fromString("kxAT73dKQsitIedpiPtwB" + i)). + setListeners(new ListenerCollection(Arrays.asList(new Listener(). + setName("PLAINTEXT").setHost("localhost"). + setPort(9092 + i)).iterator()))).get(); + brokerEpochs.put(i, reply.epoch()); + } + for (int i = 0; i < numBrokers - 1; i++) { + assertEquals(new BrokerHeartbeatReply(true, false, false, false), + active.processBrokerHeartbeat(new BrokerHeartbeatRequestData(). + setWantFence(false).setBrokerEpoch(brokerEpochs.get(i)). + setBrokerId(i).setCurrentMetadataOffset(100000L)).get()); + } + CreateTopicsResponseData fooData = active.createTopics( + new CreateTopicsRequestData().setTopics( + new CreatableTopicCollection(Collections.singleton( + new CreatableTopic().setName("foo").setNumPartitions(-1). + setReplicationFactor((short) -1). + setAssignments(new CreatableReplicaAssignmentCollection( + Arrays.asList(new CreatableReplicaAssignment(). + setPartitionIndex(0). + setBrokerIds(Arrays.asList(0, 1, 2)), + new CreatableReplicaAssignment(). + setPartitionIndex(1). + setBrokerIds(Arrays.asList(1, 2, 0))). 
+ iterator()))).iterator()))).get(); + fooId = fooData.topics().find("foo").topicId(); + active.allocateProducerIds( + new AllocateProducerIdsRequestData().setBrokerId(0).setBrokerEpoch(brokerEpochs.get(0))).get(); + + SnapshotReader snapshot = createSnapshotReader(logEnv.waitForLatestSnapshot()); + checkSnapshotSubcontent( + expectedSnapshotContent(fooId, brokerEpochs), + snapshot + ); + } + } + } + + @Test + public void testSnapshotOnlyAfterConfiguredMinBytes() throws Throwable { + final int numBrokers = 4; + final int maxNewRecordBytes = 1000; + Map brokerEpochs = new HashMap<>(); + try (LocalLogManagerTestEnv logEnv = new LocalLogManagerTestEnv(3, Optional.empty())) { + try (QuorumControllerTestEnv controlEnv = new QuorumControllerTestEnv(logEnv, + builder -> builder.setConfigDefs(CONFIGS). + setSnapshotMaxNewRecordBytes(maxNewRecordBytes)) + ) { + QuorumController active = controlEnv.activeController(); + for (int i = 0; i < numBrokers; i++) { + BrokerRegistrationReply reply = active.registerBroker( + new BrokerRegistrationRequestData(). + setBrokerId(i). + setRack(null). + setClusterId("06B-K3N1TBCNYFgruEVP0Q"). + setIncarnationId(Uuid.fromString("kxAT73dKQsitIedpiPtwB" + i)). + setListeners(new ListenerCollection(Arrays.asList(new Listener(). + setName("PLAINTEXT").setHost("localhost"). + setPort(9092 + i)).iterator()))).get(); + brokerEpochs.put(i, reply.epoch()); + assertEquals(new BrokerHeartbeatReply(true, false, false, false), + active.processBrokerHeartbeat(new BrokerHeartbeatRequestData(). + setWantFence(false).setBrokerEpoch(brokerEpochs.get(i)). + setBrokerId(i).setCurrentMetadataOffset(100000L)).get()); + } + + assertTrue(logEnv.appendedBytes() < maxNewRecordBytes, + String.format("%s appended bytes is not less than %s max new record bytes", + logEnv.appendedBytes(), + maxNewRecordBytes)); + + // Keep creating topic until we reached the max bytes limit + int counter = 0; + while (logEnv.appendedBytes() < maxNewRecordBytes) { + counter += 1; + String topicName = String.format("foo-%s", counter); + active.createTopics(new CreateTopicsRequestData().setTopics( + new CreatableTopicCollection(Collections.singleton( + new CreatableTopic().setName(topicName).setNumPartitions(-1). + setReplicationFactor((short) -1). + setAssignments(new CreatableReplicaAssignmentCollection( + Arrays.asList(new CreatableReplicaAssignment(). + setPartitionIndex(0). + setBrokerIds(Arrays.asList(0, 1, 2)), + new CreatableReplicaAssignment(). + setPartitionIndex(1). + setBrokerIds(Arrays.asList(1, 2, 0))). + iterator()))).iterator()))).get(); + } + logEnv.waitForLatestSnapshot(); + } + } + } + + private SnapshotReader createSnapshotReader(RawSnapshotReader reader) { + return SnapshotReader.of( + reader, + new MetadataRecordSerde(), + BufferSupplier.create(), + Integer.MAX_VALUE + ); + } + + private List expectedSnapshotContent(Uuid fooId, Map brokerEpochs) { + return Arrays.asList( + new ApiMessageAndVersion(new TopicRecord(). + setName("foo").setTopicId(fooId), (short) 0), + new ApiMessageAndVersion(new PartitionRecord().setPartitionId(0). + setTopicId(fooId).setReplicas(Arrays.asList(0, 1, 2)). + setIsr(Arrays.asList(0, 1, 2)).setRemovingReplicas(Collections.emptyList()). + setAddingReplicas(Collections.emptyList()).setLeader(0).setLeaderEpoch(0). + setPartitionEpoch(0), (short) 0), + new ApiMessageAndVersion(new PartitionRecord().setPartitionId(1). + setTopicId(fooId).setReplicas(Arrays.asList(1, 2, 0)). + setIsr(Arrays.asList(1, 2, 0)).setRemovingReplicas(Collections.emptyList()). 
+                setAddingReplicas(Collections.emptyList()).setLeader(1).setLeaderEpoch(0).
+                setPartitionEpoch(0), (short) 0),
+            new ApiMessageAndVersion(new RegisterBrokerRecord().
+                setBrokerId(0).setBrokerEpoch(brokerEpochs.get(0)).
+                setIncarnationId(Uuid.fromString("kxAT73dKQsitIedpiPtwB0")).
+                setEndPoints(
+                    new BrokerEndpointCollection(
+                        Arrays.asList(
+                            new BrokerEndpoint().setName("PLAINTEXT").setHost("localhost").
+                                setPort(9092).setSecurityProtocol((short) 0)).iterator())).
+                setRack(null).
+                setFenced(false), (short) 0),
+            new ApiMessageAndVersion(new RegisterBrokerRecord().
+                setBrokerId(1).setBrokerEpoch(brokerEpochs.get(1)).
+                setIncarnationId(Uuid.fromString("kxAT73dKQsitIedpiPtwB1")).
+                setEndPoints(
+                    new BrokerEndpointCollection(
+                        Arrays.asList(
+                            new BrokerEndpoint().setName("PLAINTEXT").setHost("localhost").
+                                setPort(9093).setSecurityProtocol((short) 0)).iterator())).
+                setRack(null).
+                setFenced(false), (short) 0),
+            new ApiMessageAndVersion(new RegisterBrokerRecord().
+                setBrokerId(2).setBrokerEpoch(brokerEpochs.get(2)).
+                setIncarnationId(Uuid.fromString("kxAT73dKQsitIedpiPtwB2")).
+                setEndPoints(
+                    new BrokerEndpointCollection(
+                        Arrays.asList(
+                            new BrokerEndpoint().setName("PLAINTEXT").setHost("localhost").
+                                setPort(9094).setSecurityProtocol((short) 0)).iterator())).
+                setRack(null).
+                setFenced(false), (short) 0),
+            new ApiMessageAndVersion(new RegisterBrokerRecord().
+                setBrokerId(3).setBrokerEpoch(brokerEpochs.get(3)).
+                setIncarnationId(Uuid.fromString("kxAT73dKQsitIedpiPtwB3")).
+                setEndPoints(new BrokerEndpointCollection(Arrays.asList(
+                    new BrokerEndpoint().setName("PLAINTEXT").setHost("localhost").
+                        setPort(9095).setSecurityProtocol((short) 0)).iterator())).
+                setRack(null), (short) 0),
+            new ApiMessageAndVersion(new ProducerIdsRecord().
+                setBrokerId(0).
+                setBrokerEpoch(brokerEpochs.get(0)).
+                setProducerIdsEnd(1000), (short) 0)
+        );
+    }
+
+    private void checkSnapshotContent(
+        List<ApiMessageAndVersion> expected,
+        Iterator<Batch<ApiMessageAndVersion>> iterator
+    ) throws Exception {
+        RecordTestUtils.assertBatchIteratorContains(
+            Arrays.asList(expected),
+            Arrays.asList(
+                StreamSupport.stream(Spliterators.spliteratorUnknownSize(iterator, Spliterator.ORDERED), false)
+                    .flatMap(batch -> batch.records().stream())
+                    .collect(Collectors.toList())
+            ).iterator()
+        );
+    }
+
+    /**
+     * This function checks that the iterator is a subset of the expected list.
+     *
+     * This is needed because, when generating snapshots through configuration, it is difficult to control
+     * exactly when a snapshot will be generated and which committed offset will be included in the snapshot.
+     */
+    private void checkSnapshotSubcontent(
+        List<ApiMessageAndVersion> expected,
+        Iterator<Batch<ApiMessageAndVersion>> iterator
+    ) throws Exception {
+        RecordTestUtils.deepSortRecords(expected);
+
+        List<ApiMessageAndVersion> actual = StreamSupport
+            .stream(Spliterators.spliteratorUnknownSize(iterator, Spliterator.ORDERED), false)
+            .flatMap(batch -> batch.records().stream())
+            .collect(Collectors.toList());
+
+        RecordTestUtils.deepSortRecords(actual);
+
+        int expectedIndex = 0;
+        for (ApiMessageAndVersion current : actual) {
+            while (expectedIndex < expected.size() && !expected.get(expectedIndex).equals(current)) {
+                expectedIndex += 1;
+            }
+            expectedIndex += 1;
+        }
+
+        assertTrue(
+            expectedIndex <= expected.size(),
+            String.format("actual is not a subset of expected: expected = %s; actual = %s", expected, actual)
+        );
+    }
+
+    /**
+     * Test that certain controller operations time out if they stay on the controller
+     * queue for too long.
+     */
+    @Test
+    public void testTimeouts() throws Throwable {
+        try (LocalLogManagerTestEnv logEnv = new LocalLogManagerTestEnv(1, Optional.empty())) {
+            try (QuorumControllerTestEnv controlEnv =
+                     new QuorumControllerTestEnv(logEnv, b -> b.setConfigDefs(CONFIGS))) {
+                QuorumController controller = controlEnv.activeController();
+                CountDownLatch countDownLatch = controller.pause();
+                CompletableFuture<CreateTopicsResponseData> createFuture =
+                    controller.createTopics(new CreateTopicsRequestData().setTimeoutMs(0).
+                        setTopics(new CreatableTopicCollection(Collections.singleton(
+                            new CreatableTopic().setName("foo")).iterator())));
+                long now = controller.time().nanoseconds();
+                CompletableFuture<Map<Uuid, ApiError>> deleteFuture =
+                    controller.deleteTopics(now, Collections.singletonList(Uuid.ZERO_UUID));
+                CompletableFuture<Map<String, ResultOrError<Uuid>>> findTopicIdsFuture =
+                    controller.findTopicIds(now, Collections.singletonList("foo"));
+                CompletableFuture<Map<Uuid, ResultOrError<String>>> findTopicNamesFuture =
+                    controller.findTopicNames(now, Collections.singletonList(Uuid.ZERO_UUID));
+                CompletableFuture<List<CreatePartitionsTopicResult>> createPartitionsFuture =
+                    controller.createPartitions(now, Collections.singletonList(
+                        new CreatePartitionsTopic()));
+                CompletableFuture<ElectLeadersResponseData> electLeadersFuture =
+                    controller.electLeaders(new ElectLeadersRequestData().setTimeoutMs(0).
+                        setTopicPartitions(null));
+                CompletableFuture<AlterPartitionReassignmentsResponseData> alterReassignmentsFuture =
+                    controller.alterPartitionReassignments(
+                        new AlterPartitionReassignmentsRequestData().setTimeoutMs(0).
+                            setTopics(Collections.singletonList(new ReassignableTopic())));
+                CompletableFuture<ListPartitionReassignmentsResponseData> listReassignmentsFuture =
+                    controller.listPartitionReassignments(
+                        new ListPartitionReassignmentsRequestData().setTopics(null).setTimeoutMs(0));
+                while (controller.time().nanoseconds() == now) {
+                    Thread.sleep(0, 10);
+                }
+                countDownLatch.countDown();
+                assertYieldsTimeout(createFuture);
+                assertYieldsTimeout(deleteFuture);
+                assertYieldsTimeout(findTopicIdsFuture);
+                assertYieldsTimeout(findTopicNamesFuture);
+                assertYieldsTimeout(createPartitionsFuture);
+                assertYieldsTimeout(electLeadersFuture);
+                assertYieldsTimeout(alterReassignmentsFuture);
+                assertYieldsTimeout(listReassignmentsFuture);
+            }
+        }
+    }
+
+    private static void assertYieldsTimeout(Future<?> future) {
+        assertEquals(TimeoutException.class, assertThrows(ExecutionException.class,
+            () -> future.get()).getCause().getClass());
+    }
+
+    /**
+     * Test that certain controller operations finish immediately without putting an event
+     * on the controller queue, if there is nothing to do.
+ */ + @Test + public void testEarlyControllerResults() throws Throwable { + try (LocalLogManagerTestEnv logEnv = new LocalLogManagerTestEnv(1, Optional.empty())) { + try (QuorumControllerTestEnv controlEnv = + new QuorumControllerTestEnv(logEnv, b -> b.setConfigDefs(CONFIGS))) { + QuorumController controller = controlEnv.activeController(); + CountDownLatch countDownLatch = controller.pause(); + CompletableFuture createFuture = + controller.createTopics(new CreateTopicsRequestData().setTimeoutMs(120000)); + long deadlineMs = controller.time().nanoseconds() + HOURS.toNanos(1); + CompletableFuture> deleteFuture = + controller.deleteTopics(deadlineMs, Collections.emptyList()); + CompletableFuture>> findTopicIdsFuture = + controller.findTopicIds(deadlineMs, Collections.emptyList()); + CompletableFuture>> findTopicNamesFuture = + controller.findTopicNames(deadlineMs, Collections.emptyList()); + CompletableFuture> createPartitionsFuture = + controller.createPartitions(deadlineMs, Collections.emptyList()); + CompletableFuture electLeadersFuture = + controller.electLeaders(new ElectLeadersRequestData().setTimeoutMs(120000)); + CompletableFuture alterReassignmentsFuture = + controller.alterPartitionReassignments( + new AlterPartitionReassignmentsRequestData().setTimeoutMs(12000)); + createFuture.get(); + deleteFuture.get(); + findTopicIdsFuture.get(); + findTopicNamesFuture.get(); + createPartitionsFuture.get(); + electLeadersFuture.get(); + alterReassignmentsFuture.get(); + countDownLatch.countDown(); + } + } + } + + @Test + public void testMissingInMemorySnapshot() throws Exception { + int numBrokers = 3; + int numPartitions = 3; + String topicName = "topic-name"; + + try ( + LocalLogManagerTestEnv logEnv = new LocalLogManagerTestEnv(1, Optional.empty()); + QuorumControllerTestEnv controlEnv = + new QuorumControllerTestEnv(logEnv, b -> b.setConfigDefs(CONFIGS)) + ) { + QuorumController controller = controlEnv.activeController(); + + Map brokerEpochs = registerBrokers(controller, numBrokers); + + // Create a lot of partitions + List partitions = IntStream + .range(0, numPartitions) + .mapToObj(partitionIndex -> new CreatableReplicaAssignment() + .setPartitionIndex(partitionIndex) + .setBrokerIds(Arrays.asList(0, 1, 2)) + ) + .collect(Collectors.toList()); + + Uuid topicId = controller.createTopics( + new CreateTopicsRequestData() + .setTopics( + new CreatableTopicCollection( + Collections.singleton( + new CreatableTopic() + .setName(topicName) + .setNumPartitions(-1) + .setReplicationFactor((short) -1) + .setAssignments(new CreatableReplicaAssignmentCollection(partitions.iterator())) + ).iterator() + ) + ) + ).get().topics().find(topicName).topicId(); + + // Create a lot of alter isr + List alterIsrs = IntStream + .range(0, numPartitions) + .mapToObj(partitionIndex -> { + PartitionRegistration partitionRegistration = controller.replicationControl().getPartition( + topicId, + partitionIndex + ); + + return new AlterIsrRequestData.PartitionData() + .setPartitionIndex(partitionIndex) + .setLeaderEpoch(partitionRegistration.leaderEpoch) + .setCurrentIsrVersion(partitionRegistration.partitionEpoch) + .setNewIsr(Arrays.asList(0, 1)); + }) + .collect(Collectors.toList()); + + AlterIsrRequestData.TopicData topicData = new AlterIsrRequestData.TopicData() + .setName(topicName); + topicData.partitions().addAll(alterIsrs); + + int leaderId = 0; + AlterIsrRequestData alterIsrRequest = new AlterIsrRequestData() + .setBrokerId(leaderId) + .setBrokerEpoch(brokerEpochs.get(leaderId)); + 
alterIsrRequest.topics().add(topicData); + + logEnv.logManagers().get(0).resignAfterNonAtomicCommit(); + + int oldClaimEpoch = controller.curClaimEpoch(); + assertThrows( + ExecutionException.class, + () -> controller.alterIsr(alterIsrRequest).get() + ); + + // Wait for the controller to become active again + assertSame(controller, controlEnv.activeController()); + assertTrue( + oldClaimEpoch < controller.curClaimEpoch(), + String.format("oldClaimEpoch = %s, newClaimEpoch = %s", oldClaimEpoch, controller.curClaimEpoch()) + ); + + // Since the alterIsr partially failed we expect to see + // some partitions to still have 2 in the ISR. + int partitionsWithReplica2 = Utils.toList( + controller + .replicationControl() + .brokersToIsrs() + .partitionsWithBrokerInIsr(2) + ).size(); + int partitionsWithReplica0 = Utils.toList( + controller + .replicationControl() + .brokersToIsrs() + .partitionsWithBrokerInIsr(0) + ).size(); + + assertEquals(numPartitions, partitionsWithReplica0); + assertNotEquals(0, partitionsWithReplica2); + assertTrue( + partitionsWithReplica0 > partitionsWithReplica2, + String.format( + "partitionsWithReplica0 = %s, partitionsWithReplica2 = %s", + partitionsWithReplica0, + partitionsWithReplica2 + ) + ); + } + } + + private Map registerBrokers(QuorumController controller, int numBrokers) throws Exception { + Map brokerEpochs = new HashMap<>(); + for (int brokerId = 0; brokerId < numBrokers; brokerId++) { + BrokerRegistrationReply reply = controller.registerBroker( + new BrokerRegistrationRequestData() + .setBrokerId(brokerId) + .setRack(null) + .setClusterId("06B-K3N1TBCNYFgruEVP0Q") + .setIncarnationId(Uuid.fromString("kxAT73dKQsitIedpiPtwB" + brokerId)) + .setListeners( + new ListenerCollection( + Arrays.asList( + new Listener() + .setName("PLAINTEXT") + .setHost("localhost") + .setPort(9092 + brokerId) + ).iterator() + ) + ) + ).get(); + brokerEpochs.put(brokerId, reply.epoch()); + + // Send heartbeat to unfence + controller.processBrokerHeartbeat( + new BrokerHeartbeatRequestData() + .setWantFence(false) + .setBrokerEpoch(brokerEpochs.get(brokerId)) + .setBrokerId(brokerId) + .setCurrentMetadataOffset(100000L) + ).get(); + } + + return brokerEpochs; + } + + private void sendBrokerheartbeat( + QuorumController controller, + List brokers, + Map brokerEpochs + ) throws Exception { + if (brokers.isEmpty()) { + return; + } + for (Integer brokerId : brokers) { + BrokerHeartbeatReply reply = controller.processBrokerHeartbeat( + new BrokerHeartbeatRequestData() + .setWantFence(false) + .setBrokerEpoch(brokerEpochs.get(brokerId)) + .setBrokerId(brokerId) + .setCurrentMetadataOffset(100000) + ).get(); + assertEquals(new BrokerHeartbeatReply(true, false, false, false), reply); + } + } + +} diff --git a/metadata/src/test/java/org/apache/kafka/controller/QuorumControllerTestEnv.java b/metadata/src/test/java/org/apache/kafka/controller/QuorumControllerTestEnv.java new file mode 100644 index 0000000..7487882 --- /dev/null +++ b/metadata/src/test/java/org/apache/kafka/controller/QuorumControllerTestEnv.java @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.controller; + +import static java.util.concurrent.TimeUnit.NANOSECONDS; + +import java.util.Optional; +import java.util.concurrent.TimeUnit; +import org.apache.kafka.controller.QuorumController.Builder; +import org.apache.kafka.metalog.LocalLogManagerTestEnv; +import org.apache.kafka.raft.LeaderAndEpoch; +import org.apache.kafka.test.TestUtils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.ArrayList; +import java.util.List; +import java.util.OptionalInt; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; + +public class QuorumControllerTestEnv implements AutoCloseable { + private static final Logger log = + LoggerFactory.getLogger(QuorumControllerTestEnv.class); + + private final List controllers; + private final LocalLogManagerTestEnv logEnv; + + public QuorumControllerTestEnv( + LocalLogManagerTestEnv logEnv, + Consumer builderConsumer + ) throws Exception { + this(logEnv, builderConsumer, Optional.empty()); + } + + public QuorumControllerTestEnv( + LocalLogManagerTestEnv logEnv, + Consumer builderConsumer, + Optional sessionTimeoutMillis + ) throws Exception { + this.logEnv = logEnv; + int numControllers = logEnv.logManagers().size(); + this.controllers = new ArrayList<>(numControllers); + try { + for (int i = 0; i < numControllers; i++) { + QuorumController.Builder builder = new QuorumController.Builder(i); + builder.setRaftClient(logEnv.logManagers().get(i)); + if (sessionTimeoutMillis.isPresent()) { + builder.setSessionTimeoutNs(NANOSECONDS.convert( + sessionTimeoutMillis.get(), TimeUnit.MILLISECONDS)); + } + builderConsumer.accept(builder); + this.controllers.add(builder.build()); + } + } catch (Exception e) { + close(); + throw e; + } + } + + QuorumController activeController() throws InterruptedException { + AtomicReference value = new AtomicReference<>(null); + TestUtils.retryOnExceptionWithTimeout(20000, 3, () -> { + LeaderAndEpoch leader = logEnv.leaderAndEpoch(); + for (QuorumController controller : controllers) { + if (OptionalInt.of(controller.nodeId()).equals(leader.leaderId()) && + controller.curClaimEpoch() == leader.epoch()) { + value.set(controller); + break; + } + } + + if (value.get() == null) { + throw new RuntimeException(String.format("Expected to see %s as leader", leader)); + } + }); + + return value.get(); + } + + public List controllers() { + return controllers; + } + + @Override + public void close() throws InterruptedException { + for (QuorumController controller : controllers) { + controller.beginShutdown(); + } + for (QuorumController controller : controllers) { + controller.close(); + } + } +} diff --git a/metadata/src/test/java/org/apache/kafka/controller/ReplicationControlManagerTest.java b/metadata/src/test/java/org/apache/kafka/controller/ReplicationControlManagerTest.java new file mode 100644 index 0000000..94fbe7c --- /dev/null +++ b/metadata/src/test/java/org/apache/kafka/controller/ReplicationControlManagerTest.java @@ -0,0 +1,1602 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license 
agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.controller; + +import org.apache.kafka.common.ElectionType; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.Uuid; +import org.apache.kafka.common.config.ConfigResource; +import org.apache.kafka.common.errors.InvalidReplicaAssignmentException; +import org.apache.kafka.common.errors.PolicyViolationException; +import org.apache.kafka.common.errors.StaleBrokerEpochException; +import org.apache.kafka.common.message.AlterIsrRequestData; +import org.apache.kafka.common.message.AlterIsrRequestData.PartitionData; +import org.apache.kafka.common.message.AlterIsrRequestData.TopicData; +import org.apache.kafka.common.message.AlterIsrResponseData; +import org.apache.kafka.common.message.AlterPartitionReassignmentsRequestData; +import org.apache.kafka.common.message.AlterPartitionReassignmentsRequestData.ReassignablePartition; +import org.apache.kafka.common.message.AlterPartitionReassignmentsRequestData.ReassignableTopic; +import org.apache.kafka.common.message.AlterPartitionReassignmentsResponseData; +import org.apache.kafka.common.message.AlterPartitionReassignmentsResponseData.ReassignablePartitionResponse; +import org.apache.kafka.common.message.AlterPartitionReassignmentsResponseData.ReassignableTopicResponse; +import org.apache.kafka.common.message.BrokerHeartbeatRequestData; +import org.apache.kafka.common.message.CreatePartitionsRequestData.CreatePartitionsAssignment; +import org.apache.kafka.common.message.CreatePartitionsRequestData.CreatePartitionsTopic; +import org.apache.kafka.common.message.CreatePartitionsResponseData.CreatePartitionsTopicResult; +import org.apache.kafka.common.message.CreateTopicsRequestData; +import org.apache.kafka.common.message.CreateTopicsRequestData.CreatableReplicaAssignment; +import org.apache.kafka.common.message.CreateTopicsRequestData.CreatableTopic; +import org.apache.kafka.common.message.CreateTopicsRequestData.CreatableTopicCollection; +import org.apache.kafka.common.message.CreateTopicsResponseData; +import org.apache.kafka.common.message.CreateTopicsResponseData.CreatableTopicResult; +import org.apache.kafka.common.message.ElectLeadersRequestData; +import org.apache.kafka.common.message.ElectLeadersRequestData.TopicPartitions; +import org.apache.kafka.common.message.ElectLeadersRequestData.TopicPartitionsCollection; +import org.apache.kafka.common.message.ElectLeadersResponseData; +import org.apache.kafka.common.message.ElectLeadersResponseData.PartitionResult; +import org.apache.kafka.common.message.ElectLeadersResponseData.ReplicaElectionResult; +import org.apache.kafka.common.message.ListPartitionReassignmentsRequestData.ListPartitionReassignmentsTopics; +import org.apache.kafka.common.message.ListPartitionReassignmentsResponseData; +import 
org.apache.kafka.common.message.ListPartitionReassignmentsResponseData.OngoingPartitionReassignment; +import org.apache.kafka.common.message.ListPartitionReassignmentsResponseData.OngoingTopicReassignment; +import org.apache.kafka.common.metadata.ConfigRecord; +import org.apache.kafka.common.metadata.PartitionChangeRecord; +import org.apache.kafka.common.metadata.PartitionRecord; +import org.apache.kafka.common.metadata.RegisterBrokerRecord; +import org.apache.kafka.common.metadata.TopicRecord; +import org.apache.kafka.common.protocol.Errors; +import org.apache.kafka.common.requests.ApiError; +import org.apache.kafka.common.security.auth.SecurityProtocol; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.metadata.BrokerHeartbeatReply; +import org.apache.kafka.metadata.BrokerRegistration; +import org.apache.kafka.metadata.PartitionRegistration; +import org.apache.kafka.metadata.RecordTestUtils; +import org.apache.kafka.metadata.Replicas; +import org.apache.kafka.server.common.ApiMessageAndVersion; +import org.apache.kafka.server.policy.CreateTopicPolicy; +import org.apache.kafka.timeline.SnapshotRegistry; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.concurrent.atomic.AtomicLong; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.OptionalInt; +import java.util.Set; +import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +import static java.util.Arrays.asList; +import static java.util.Collections.singletonList; +import static java.util.Collections.singletonMap; +import static org.apache.kafka.common.config.TopicConfig.SEGMENT_BYTES_CONFIG; +import static org.apache.kafka.common.protocol.Errors.ELECTION_NOT_NEEDED; +import static org.apache.kafka.common.protocol.Errors.ELIGIBLE_LEADERS_NOT_AVAILABLE; +import static org.apache.kafka.common.protocol.Errors.FENCED_LEADER_EPOCH; +import static org.apache.kafka.common.protocol.Errors.INVALID_PARTITIONS; +import static org.apache.kafka.common.protocol.Errors.INVALID_REPLICA_ASSIGNMENT; +import static org.apache.kafka.common.protocol.Errors.INVALID_TOPIC_EXCEPTION; +import static org.apache.kafka.common.protocol.Errors.NONE; +import static org.apache.kafka.common.protocol.Errors.NO_REASSIGNMENT_IN_PROGRESS; +import static org.apache.kafka.common.protocol.Errors.POLICY_VIOLATION; +import static org.apache.kafka.common.protocol.Errors.PREFERRED_LEADER_NOT_AVAILABLE; +import static org.apache.kafka.common.protocol.Errors.UNKNOWN_TOPIC_ID; +import static org.apache.kafka.common.protocol.Errors.UNKNOWN_TOPIC_OR_PARTITION; +import static org.apache.kafka.controller.BrokersToIsrs.TopicIdPartition; +import static org.apache.kafka.metadata.LeaderConstants.NO_LEADER; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static 
org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + + +@Timeout(40) +public class ReplicationControlManagerTest { + private final static Logger log = LoggerFactory.getLogger(ReplicationControlManagerTest.class); + private final static int BROKER_SESSION_TIMEOUT_MS = 1000; + + private static class ReplicationControlTestContext { + final SnapshotRegistry snapshotRegistry = new SnapshotRegistry(new LogContext()); + final LogContext logContext = new LogContext(); + final MockTime time = new MockTime(); + final MockRandom random = new MockRandom(); + final ControllerMetrics metrics = new MockControllerMetrics(); + final ClusterControlManager clusterControl = new ClusterControlManager( + logContext, time, snapshotRegistry, TimeUnit.MILLISECONDS.convert(BROKER_SESSION_TIMEOUT_MS, TimeUnit.NANOSECONDS), + new StripedReplicaPlacer(random), metrics); + final ConfigurationControlManager configurationControl = new ConfigurationControlManager( + new LogContext(), snapshotRegistry, Collections.emptyMap(), Optional.empty(), + (__, ___) -> { }); + final ReplicationControlManager replicationControl; + + void replay(List records) throws Exception { + RecordTestUtils.replayAll(clusterControl, records); + RecordTestUtils.replayAll(configurationControl, records); + RecordTestUtils.replayAll(replicationControl, records); + } + + ReplicationControlTestContext() { + this(Optional.empty()); + } + + ReplicationControlTestContext(Optional createTopicPolicy) { + this.replicationControl = new ReplicationControlManager(snapshotRegistry, + new LogContext(), + (short) 3, + 1, + configurationControl, + clusterControl, + metrics, + createTopicPolicy); + clusterControl.activate(); + } + + CreatableTopicResult createTestTopic(String name, + int numPartitions, + short replicationFactor, + short expectedErrorCode) throws Exception { + CreateTopicsRequestData request = new CreateTopicsRequestData(); + CreatableTopic topic = new CreatableTopic().setName(name); + topic.setNumPartitions(numPartitions).setReplicationFactor(replicationFactor); + request.topics().add(topic); + ControllerResult result = + replicationControl.createTopics(request); + CreatableTopicResult topicResult = result.response().topics().find(name); + assertNotNull(topicResult); + assertEquals(expectedErrorCode, topicResult.errorCode()); + if (expectedErrorCode == NONE.code()) { + replay(result.records()); + } + return topicResult; + } + + CreatableTopicResult createTestTopic(String name, int[][] replicas) throws Exception { + return createTestTopic(name, replicas, Collections.emptyMap(), (short) 0); + } + + CreatableTopicResult createTestTopic(String name, int[][] replicas, + short expectedErrorCode) throws Exception { + return createTestTopic(name, replicas, Collections.emptyMap(), expectedErrorCode); + } + + CreatableTopicResult createTestTopic(String name, int[][] replicas, + Map configs, + short expectedErrorCode) throws Exception { + assertFalse(replicas.length == 0); + CreateTopicsRequestData request = new CreateTopicsRequestData(); + CreatableTopic topic = new CreatableTopic().setName(name); + topic.setNumPartitions(-1).setReplicationFactor((short) -1); + for (int i = 0; i < replicas.length; i++) { + topic.assignments().add(new CreatableReplicaAssignment(). 
+ setPartitionIndex(i).setBrokerIds(Replicas.toList(replicas[i]))); + } + configs.entrySet().forEach(e -> topic.configs().add( + new CreateTopicsRequestData.CreateableTopicConfig().setName(e.getKey()). + setValue(e.getValue()))); + request.topics().add(topic); + ControllerResult result = + replicationControl.createTopics(request); + CreatableTopicResult topicResult = result.response().topics().find(name); + assertNotNull(topicResult); + assertEquals(expectedErrorCode, topicResult.errorCode()); + if (expectedErrorCode == NONE.code()) { + assertEquals(replicas.length, topicResult.numPartitions()); + assertEquals(replicas[0].length, topicResult.replicationFactor()); + replay(result.records()); + } + return topicResult; + } + + void createPartitions(int count, String name, + int[][] replicas, short expectedErrorCode) throws Exception { + assertFalse(replicas.length == 0); + CreatePartitionsTopic topic = new CreatePartitionsTopic(). + setName(name). + setCount(count); + for (int i = 0; i < replicas.length; i++) { + topic.assignments().add(new CreatePartitionsAssignment(). + setBrokerIds(Replicas.toList(replicas[i]))); + } + ControllerResult> result = + replicationControl.createPartitions(Collections.singletonList(topic)); + assertEquals(1, result.response().size()); + CreatePartitionsTopicResult topicResult = result.response().get(0); + assertEquals(name, topicResult.name()); + assertEquals(expectedErrorCode, topicResult.errorCode()); + replay(result.records()); + } + + void registerBrokers(Integer... brokerIds) throws Exception { + for (int brokerId : brokerIds) { + RegisterBrokerRecord brokerRecord = new RegisterBrokerRecord(). + setBrokerEpoch(brokerId + 100).setBrokerId(brokerId); + brokerRecord.endPoints().add(new RegisterBrokerRecord.BrokerEndpoint(). + setSecurityProtocol(SecurityProtocol.PLAINTEXT.id). + setPort((short) 9092 + brokerId). + setName("PLAINTEXT"). + setHost("localhost")); + replay(Collections.singletonList(new ApiMessageAndVersion(brokerRecord, (short) 0))); + } + } + + void alterIsr( + TopicIdPartition topicIdPartition, + int leaderId, + List isr + ) throws Exception { + BrokerRegistration registration = clusterControl.brokerRegistrations().get(leaderId); + assertFalse(registration.fenced()); + + PartitionRegistration partition = replicationControl.getPartition( + topicIdPartition.topicId(), + topicIdPartition.partitionId() + ); + assertNotNull(partition); + assertEquals(leaderId, partition.leader); + + PartitionData partitionData = new PartitionData() + .setPartitionIndex(topicIdPartition.partitionId()) + .setCurrentIsrVersion(partition.partitionEpoch) + .setLeaderEpoch(partition.leaderEpoch) + .setNewIsr(isr); + + String topicName = replicationControl.getTopic(topicIdPartition.topicId()).name(); + TopicData topicData = new TopicData() + .setName(topicName) + .setPartitions(singletonList(partitionData)); + + ControllerResult alterIsr = replicationControl.alterIsr( + new AlterIsrRequestData() + .setBrokerId(leaderId) + .setBrokerEpoch(registration.epoch()) + .setTopics(singletonList(topicData))); + replay(alterIsr.records()); + } + + void unfenceBrokers(Integer... brokerIds) throws Exception { + unfenceBrokers(Utils.mkSet(brokerIds)); + } + + void unfenceBrokers(Set brokerIds) throws Exception { + for (int brokerId : brokerIds) { + ControllerResult result = replicationControl. + processBrokerHeartbeat(new BrokerHeartbeatRequestData(). + setBrokerId(brokerId).setBrokerEpoch(brokerId + 100). + setCurrentMetadataOffset(1). 
+ setWantFence(false).setWantShutDown(false), 0); + assertEquals(new BrokerHeartbeatReply(true, false, false, false), + result.response()); + replay(result.records()); + } + } + + void alterTopicConfig( + String topic, + String configKey, + String configValue + ) throws Exception { + ConfigRecord configRecord = new ConfigRecord() + .setResourceType(ConfigResource.Type.TOPIC.id()) + .setResourceName(topic) + .setName(configKey) + .setValue(configValue); + replay(singletonList(new ApiMessageAndVersion(configRecord, (short) 0))); + } + + void fenceBrokers(Set brokerIds) throws Exception { + time.sleep(BROKER_SESSION_TIMEOUT_MS); + + Set unfencedBrokerIds = clusterControl.brokerRegistrations().keySet().stream() + .filter(brokerId -> !brokerIds.contains(brokerId)) + .collect(Collectors.toSet()); + unfenceBrokers(unfencedBrokerIds.toArray(new Integer[0])); + + Optional staleBroker = clusterControl.heartbeatManager().findOneStaleBroker(); + while (staleBroker.isPresent()) { + ControllerResult fenceResult = replicationControl.maybeFenceOneStaleBroker(); + replay(fenceResult.records()); + staleBroker = clusterControl.heartbeatManager().findOneStaleBroker(); + } + + assertEquals(brokerIds, clusterControl.fencedBrokerIds()); + } + + long currentBrokerEpoch(int brokerId) { + Map registrations = clusterControl.brokerRegistrations(); + BrokerRegistration registration = registrations.get(brokerId); + assertNotNull(registration, "No current registration for broker " + brokerId); + return registration.epoch(); + } + + OptionalInt currentLeader(TopicIdPartition topicIdPartition) { + PartitionRegistration partition = replicationControl. + getPartition(topicIdPartition.topicId(), topicIdPartition.partitionId()); + return (partition.leader < 0) ? OptionalInt.empty() : OptionalInt.of(partition.leader); + } + } + + private static class MockCreateTopicPolicy implements CreateTopicPolicy { + private final List expecteds; + private final AtomicLong index = new AtomicLong(0); + + MockCreateTopicPolicy(List expecteds) { + this.expecteds = expecteds; + } + + @Override + public void validate(RequestMetadata actual) throws PolicyViolationException { + long curIndex = index.getAndIncrement(); + if (curIndex >= expecteds.size()) { + throw new PolicyViolationException("Unexpected topic creation: index " + + "out of range at " + curIndex); + } + RequestMetadata expected = expecteds.get((int) curIndex); + if (!expected.equals(actual)) { + throw new PolicyViolationException("Expected: " + expected + + ". Got: " + actual); + } + } + + @Override + public void close() throws Exception { + // nothing to do + } + + @Override + public void configure(Map configs) { + // nothing to do + } + } + + @Test + public void testCreateTopics() throws Exception { + ReplicationControlTestContext ctx = new ReplicationControlTestContext(); + ReplicationControlManager replicationControl = ctx.replicationControl; + CreateTopicsRequestData request = new CreateTopicsRequestData(); + request.topics().add(new CreatableTopic().setName("foo"). + setNumPartitions(-1).setReplicationFactor((short) -1)); + ControllerResult result = + replicationControl.createTopics(request); + CreateTopicsResponseData expectedResponse = new CreateTopicsResponseData(); + expectedResponse.topics().add(new CreatableTopicResult().setName("foo"). + setErrorCode(Errors.INVALID_REPLICATION_FACTOR.code()). 
+ setErrorMessage("Unable to replicate the partition 3 time(s): All " + + "brokers are currently fenced.")); + assertEquals(expectedResponse, result.response()); + + ctx.registerBrokers(0, 1, 2); + ctx.unfenceBrokers(0, 1, 2); + ControllerResult result2 = + replicationControl.createTopics(request); + CreateTopicsResponseData expectedResponse2 = new CreateTopicsResponseData(); + expectedResponse2.topics().add(new CreatableTopicResult().setName("foo"). + setNumPartitions(1).setReplicationFactor((short) 3). + setErrorMessage(null).setErrorCode((short) 0). + setTopicId(result2.response().topics().find("foo").topicId())); + assertEquals(expectedResponse2, result2.response()); + ctx.replay(result2.records()); + assertEquals(new PartitionRegistration(new int[] {1, 2, 0}, + new int[] {1, 2, 0}, Replicas.NONE, Replicas.NONE, 1, 0, 0), + replicationControl.getPartition( + ((TopicRecord) result2.records().get(0).message()).topicId(), 0)); + ControllerResult result3 = + replicationControl.createTopics(request); + CreateTopicsResponseData expectedResponse3 = new CreateTopicsResponseData(); + expectedResponse3.topics().add(new CreatableTopicResult().setName("foo"). + setErrorCode(Errors.TOPIC_ALREADY_EXISTS.code()). + setErrorMessage("Topic 'foo' already exists.")); + assertEquals(expectedResponse3, result3.response()); + Uuid fooId = result2.response().topics().find("foo").topicId(); + RecordTestUtils.assertBatchIteratorContains(asList( + asList(new ApiMessageAndVersion(new PartitionRecord(). + setPartitionId(0).setTopicId(fooId). + setReplicas(asList(1, 2, 0)).setIsr(asList(1, 2, 0)). + setRemovingReplicas(Collections.emptyList()).setAddingReplicas(Collections.emptyList()).setLeader(1). + setLeaderEpoch(0).setPartitionEpoch(0), (short) 0), + new ApiMessageAndVersion(new TopicRecord(). + setTopicId(fooId).setName("foo"), (short) 0))), + ctx.replicationControl.iterator(Long.MAX_VALUE)); + } + + @Test + public void testBrokerCountMetrics() throws Exception { + ReplicationControlTestContext ctx = new ReplicationControlTestContext(); + ReplicationControlManager replicationControl = ctx.replicationControl; + + ctx.registerBrokers(0); + + assertEquals(1, ctx.metrics.fencedBrokerCount()); + assertEquals(0, ctx.metrics.activeBrokerCount()); + + ctx.unfenceBrokers(0); + + assertEquals(0, ctx.metrics.fencedBrokerCount()); + assertEquals(1, ctx.metrics.activeBrokerCount()); + + ctx.registerBrokers(1); + ctx.unfenceBrokers(1); + + assertEquals(2, ctx.metrics.activeBrokerCount()); + + ctx.registerBrokers(2); + ctx.unfenceBrokers(2); + + assertEquals(0, ctx.metrics.fencedBrokerCount()); + assertEquals(3, ctx.metrics.activeBrokerCount()); + + ControllerResult result = replicationControl.unregisterBroker(0); + ctx.replay(result.records()); + result = replicationControl.unregisterBroker(2); + ctx.replay(result.records()); + + assertEquals(0, ctx.metrics.fencedBrokerCount()); + assertEquals(1, ctx.metrics.activeBrokerCount()); + } + + @Test + public void testCreateTopicsWithValidateOnlyFlag() throws Exception { + ReplicationControlTestContext ctx = new ReplicationControlTestContext(); + ctx.registerBrokers(0, 1, 2); + ctx.unfenceBrokers(0, 1, 2); + CreateTopicsRequestData request = new CreateTopicsRequestData().setValidateOnly(true); + request.topics().add(new CreatableTopic().setName("foo"). 
+ setNumPartitions(1).setReplicationFactor((short) 3)); + ControllerResult result = + ctx.replicationControl.createTopics(request); + assertEquals(0, result.records().size()); + CreatableTopicResult topicResult = result.response().topics().find("foo"); + assertEquals((short) 0, topicResult.errorCode()); + } + + @Test + public void testInvalidCreateTopicsWithValidateOnlyFlag() throws Exception { + ReplicationControlTestContext ctx = new ReplicationControlTestContext(); + ctx.registerBrokers(0, 1, 2); + ctx.unfenceBrokers(0, 1, 2); + CreateTopicsRequestData request = new CreateTopicsRequestData().setValidateOnly(true); + request.topics().add(new CreatableTopic().setName("foo"). + setNumPartitions(1).setReplicationFactor((short) 4)); + ControllerResult result = + ctx.replicationControl.createTopics(request); + assertEquals(0, result.records().size()); + CreateTopicsResponseData expectedResponse = new CreateTopicsResponseData(); + expectedResponse.topics().add(new CreatableTopicResult().setName("foo"). + setErrorCode(Errors.INVALID_REPLICATION_FACTOR.code()). + setErrorMessage("Unable to replicate the partition 4 time(s): The target " + + "replication factor of 4 cannot be reached because only 3 broker(s) " + + "are registered.")); + assertEquals(expectedResponse, result.response()); + } + + @Test + public void testCreateTopicsWithPolicy() throws Exception { + MockCreateTopicPolicy createTopicPolicy = new MockCreateTopicPolicy(asList( + new CreateTopicPolicy.RequestMetadata("foo", 2, (short) 2, + null, Collections.emptyMap()), + new CreateTopicPolicy.RequestMetadata("bar", 3, (short) 2, + null, Collections.emptyMap()), + new CreateTopicPolicy.RequestMetadata("baz", null, null, + Collections.singletonMap(0, asList(2, 1, 0)), + Collections.singletonMap(SEGMENT_BYTES_CONFIG, "12300000")), + new CreateTopicPolicy.RequestMetadata("quux", null, null, + Collections.singletonMap(0, asList(2, 1, 0)), Collections.emptyMap()))); + ReplicationControlTestContext ctx = + new ReplicationControlTestContext(Optional.of(createTopicPolicy)); + ctx.registerBrokers(0, 1, 2); + ctx.unfenceBrokers(0, 1, 2); + ctx.createTestTopic("foo", 2, (short) 2, NONE.code()); + ctx.createTestTopic("bar", 3, (short) 3, POLICY_VIOLATION.code()); + ctx.createTestTopic("baz", new int[][] {new int[] {2, 1, 0}}, + Collections.singletonMap(SEGMENT_BYTES_CONFIG, "12300000"), NONE.code()); + ctx.createTestTopic("quux", new int[][] {new int[] {1, 2, 0}}, POLICY_VIOLATION.code()); + } + + @Test + public void testGlobalTopicAndPartitionMetrics() throws Exception { + ReplicationControlTestContext ctx = new ReplicationControlTestContext(); + ReplicationControlManager replicationControl = ctx.replicationControl; + CreateTopicsRequestData request = new CreateTopicsRequestData(); + request.topics().add(new CreatableTopic().setName("foo"). + setNumPartitions(1).setReplicationFactor((short) -1)); + + ctx.registerBrokers(0, 1, 2); + ctx.unfenceBrokers(0, 1, 2); + + List topicsToDelete = new ArrayList<>(); + + ControllerResult result = + replicationControl.createTopics(request); + topicsToDelete.add(result.response().topics().find("foo").topicId()); + + RecordTestUtils.replayAll(replicationControl, result.records()); + assertEquals(1, ctx.metrics.globalTopicsCount()); + + request = new CreateTopicsRequestData(); + request.topics().add(new CreatableTopic().setName("bar"). + setNumPartitions(1).setReplicationFactor((short) -1)); + request.topics().add(new CreatableTopic().setName("baz"). 
+ setNumPartitions(2).setReplicationFactor((short) -1)); + result = replicationControl.createTopics(request); + RecordTestUtils.replayAll(replicationControl, result.records()); + assertEquals(3, ctx.metrics.globalTopicsCount()); + assertEquals(4, ctx.metrics.globalPartitionCount()); + + topicsToDelete.add(result.response().topics().find("baz").topicId()); + ControllerResult> deleteResult = replicationControl.deleteTopics(topicsToDelete); + RecordTestUtils.replayAll(replicationControl, deleteResult.records()); + assertEquals(1, ctx.metrics.globalTopicsCount()); + assertEquals(1, ctx.metrics.globalPartitionCount()); + + Uuid topicToDelete = result.response().topics().find("bar").topicId(); + deleteResult = replicationControl.deleteTopics(Collections.singletonList(topicToDelete)); + RecordTestUtils.replayAll(replicationControl, deleteResult.records()); + assertEquals(0, ctx.metrics.globalTopicsCount()); + assertEquals(0, ctx.metrics.globalPartitionCount()); + } + + @Test + public void testOfflinePartitionAndReplicaImbalanceMetrics() throws Exception { + ReplicationControlTestContext ctx = new ReplicationControlTestContext(); + ReplicationControlManager replicationControl = ctx.replicationControl; + ctx.registerBrokers(0, 1, 2, 3); + ctx.unfenceBrokers(0, 1, 2, 3); + + CreatableTopicResult foo = ctx.createTestTopic("foo", new int[][] { + new int[] {0, 2}, new int[] {0, 1}}); + + CreatableTopicResult zar = ctx.createTestTopic("zar", new int[][] { + new int[] {0, 1, 2}, new int[] {1, 2, 3}, new int[] {1, 2, 0}}); + + ControllerResult result = replicationControl.unregisterBroker(0); + ctx.replay(result.records()); + + // All partitions should still be online after unregistering broker 0 + assertEquals(0, ctx.metrics.offlinePartitionCount()); + // Three partitions should not have their preferred (first) replica 0 + assertEquals(3, ctx.metrics.preferredReplicaImbalanceCount()); + + result = replicationControl.unregisterBroker(1); + ctx.replay(result.records()); + + // After unregistering broker 1, 1 partition for topic foo should go offline + assertEquals(1, ctx.metrics.offlinePartitionCount()); + // All five partitions should not have their preferred (first) replica at this point + assertEquals(5, ctx.metrics.preferredReplicaImbalanceCount()); + + result = replicationControl.unregisterBroker(2); + ctx.replay(result.records()); + + // After unregistering broker 2, the last partition for topic foo should go offline + // and 2 partitions for topic zar should go offline + assertEquals(4, ctx.metrics.offlinePartitionCount()); + + result = replicationControl.unregisterBroker(3); + ctx.replay(result.records()); + + // After unregistering broker 3 the last partition for topic zar should go offline + assertEquals(5, ctx.metrics.offlinePartitionCount()); + + // Deleting topic foo should bring the offline partition count down to 3 + ArrayList records = new ArrayList<>(); + replicationControl.deleteTopic(foo.topicId(), records); + ctx.replay(records); + + assertEquals(3, ctx.metrics.offlinePartitionCount()); + + // Deleting topic zar should bring the offline partition count down to 0 + records = new ArrayList<>(); + replicationControl.deleteTopic(zar.topicId(), records); + ctx.replay(records); + + assertEquals(0, ctx.metrics.offlinePartitionCount()); + } + + @Test + public void testValidateNewTopicNames() { + Map topicErrors = new HashMap<>(); + CreatableTopicCollection topics = new CreatableTopicCollection(); + topics.add(new CreatableTopic().setName("")); + topics.add(new CreatableTopic().setName("woo")); + 
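+        // "" and "." are invalid topic names ("." and ".." are reserved), so the
+        // validation below is expected to flag both while accepting "woo".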
topics.add(new CreatableTopic().setName(".")); + ReplicationControlManager.validateNewTopicNames(topicErrors, topics); + Map expectedTopicErrors = new HashMap<>(); + expectedTopicErrors.put("", new ApiError(INVALID_TOPIC_EXCEPTION, + "Topic name is illegal, it can't be empty")); + expectedTopicErrors.put(".", new ApiError(INVALID_TOPIC_EXCEPTION, + "Topic name cannot be \".\" or \"..\"")); + assertEquals(expectedTopicErrors, topicErrors); + } + + @Test + public void testRemoveLeaderships() throws Exception { + ReplicationControlTestContext ctx = new ReplicationControlTestContext(); + ReplicationControlManager replicationControl = ctx.replicationControl; + ctx.registerBrokers(0, 1, 2, 3); + ctx.unfenceBrokers(0, 1, 2, 3); + CreatableTopicResult result = ctx.createTestTopic("foo", + new int[][] { + new int[] {0, 1, 2}, + new int[] {1, 2, 3}, + new int[] {2, 3, 0}, + new int[] {0, 2, 1} + }); + Set expectedPartitions = new HashSet<>(); + expectedPartitions.add(new TopicIdPartition(result.topicId(), 0)); + expectedPartitions.add(new TopicIdPartition(result.topicId(), 3)); + assertEquals(expectedPartitions, RecordTestUtils. + iteratorToSet(replicationControl.brokersToIsrs().iterator(0, true))); + List records = new ArrayList<>(); + replicationControl.handleBrokerFenced(0, records); + ctx.replay(records); + assertEquals(Collections.emptySet(), RecordTestUtils. + iteratorToSet(replicationControl.brokersToIsrs().iterator(0, true))); + } + + @Test + public void testShrinkAndExpandIsr() throws Exception { + ReplicationControlTestContext ctx = new ReplicationControlTestContext(); + ReplicationControlManager replicationControl = ctx.replicationControl; + ctx.registerBrokers(0, 1, 2); + ctx.unfenceBrokers(0, 1, 2); + CreatableTopicResult createTopicResult = ctx.createTestTopic("foo", + new int[][] {new int[] {0, 1, 2}}); + + TopicIdPartition topicIdPartition = new TopicIdPartition(createTopicResult.topicId(), 0); + TopicPartition topicPartition = new TopicPartition("foo", 0); + assertEquals(OptionalInt.of(0), ctx.currentLeader(topicIdPartition)); + long brokerEpoch = ctx.currentBrokerEpoch(0); + PartitionData shrinkIsrRequest = newAlterIsrPartition( + replicationControl, topicIdPartition, asList(0, 1)); + ControllerResult shrinkIsrResult = sendAlterIsr( + replicationControl, 0, brokerEpoch, "foo", shrinkIsrRequest); + AlterIsrResponseData.PartitionData shrinkIsrResponse = assertAlterIsrResponse( + shrinkIsrResult, topicPartition, NONE); + assertConsistentAlterIsrResponse(replicationControl, topicIdPartition, shrinkIsrResponse); + + PartitionData expandIsrRequest = newAlterIsrPartition( + replicationControl, topicIdPartition, asList(0, 1, 2)); + ControllerResult expandIsrResult = sendAlterIsr( + replicationControl, 0, brokerEpoch, "foo", expandIsrRequest); + AlterIsrResponseData.PartitionData expandIsrResponse = assertAlterIsrResponse( + expandIsrResult, topicPartition, NONE); + assertConsistentAlterIsrResponse(replicationControl, topicIdPartition, expandIsrResponse); + } + + @Test + public void testInvalidAlterIsrRequests() throws Exception { + ReplicationControlTestContext ctx = new ReplicationControlTestContext(); + ReplicationControlManager replicationControl = ctx.replicationControl; + ctx.registerBrokers(0, 1, 2); + ctx.unfenceBrokers(0, 1, 2); + CreatableTopicResult createTopicResult = ctx.createTestTopic("foo", + new int[][] {new int[] {0, 1, 2}}); + + TopicIdPartition topicIdPartition = new TopicIdPartition(createTopicResult.topicId(), 0); + TopicPartition topicPartition = new 
TopicPartition("foo", 0); + assertEquals(OptionalInt.of(0), ctx.currentLeader(topicIdPartition)); + long brokerEpoch = ctx.currentBrokerEpoch(0); + + // Invalid leader + PartitionData invalidLeaderRequest = newAlterIsrPartition( + replicationControl, topicIdPartition, asList(0, 1)); + ControllerResult invalidLeaderResult = sendAlterIsr( + replicationControl, 1, ctx.currentBrokerEpoch(1), + "foo", invalidLeaderRequest); + assertAlterIsrResponse(invalidLeaderResult, topicPartition, Errors.INVALID_REQUEST); + + // Stale broker epoch + PartitionData invalidBrokerEpochRequest = newAlterIsrPartition( + replicationControl, topicIdPartition, asList(0, 1)); + assertThrows(StaleBrokerEpochException.class, () -> sendAlterIsr( + replicationControl, 0, brokerEpoch - 1, "foo", invalidBrokerEpochRequest)); + + // Invalid leader epoch + PartitionData invalidLeaderEpochRequest = newAlterIsrPartition( + replicationControl, topicIdPartition, asList(0, 1)); + invalidLeaderEpochRequest.setLeaderEpoch(500); + ControllerResult invalidLeaderEpochResult = sendAlterIsr( + replicationControl, 1, ctx.currentBrokerEpoch(1), + "foo", invalidLeaderEpochRequest); + assertAlterIsrResponse(invalidLeaderEpochResult, topicPartition, FENCED_LEADER_EPOCH); + + // Invalid ISR (3 is not a valid replica) + PartitionData invalidIsrRequest1 = newAlterIsrPartition( + replicationControl, topicIdPartition, asList(0, 1)); + invalidIsrRequest1.setNewIsr(asList(0, 1, 3)); + ControllerResult invalidIsrResult1 = sendAlterIsr( + replicationControl, 1, ctx.currentBrokerEpoch(1), + "foo", invalidIsrRequest1); + assertAlterIsrResponse(invalidIsrResult1, topicPartition, Errors.INVALID_REQUEST); + + // Invalid ISR (does not include leader 0) + PartitionData invalidIsrRequest2 = newAlterIsrPartition( + replicationControl, topicIdPartition, asList(0, 1)); + invalidIsrRequest2.setNewIsr(asList(1, 2)); + ControllerResult invalidIsrResult2 = sendAlterIsr( + replicationControl, 1, ctx.currentBrokerEpoch(1), + "foo", invalidIsrRequest2); + assertAlterIsrResponse(invalidIsrResult2, topicPartition, Errors.INVALID_REQUEST); + } + + private PartitionData newAlterIsrPartition( + ReplicationControlManager replicationControl, + TopicIdPartition topicIdPartition, + List newIsr + ) { + PartitionRegistration partitionControl = + replicationControl.getPartition(topicIdPartition.topicId(), topicIdPartition.partitionId()); + return new AlterIsrRequestData.PartitionData() + .setPartitionIndex(0) + .setLeaderEpoch(partitionControl.leaderEpoch) + .setCurrentIsrVersion(partitionControl.partitionEpoch) + .setNewIsr(newIsr); + } + + private ControllerResult sendAlterIsr( + ReplicationControlManager replicationControl, + int brokerId, + long brokerEpoch, + String topic, + AlterIsrRequestData.PartitionData partitionData + ) throws Exception { + AlterIsrRequestData request = new AlterIsrRequestData() + .setBrokerId(brokerId) + .setBrokerEpoch(brokerEpoch); + + AlterIsrRequestData.TopicData topicData = new AlterIsrRequestData.TopicData() + .setName(topic); + request.topics().add(topicData); + topicData.partitions().add(partitionData); + + ControllerResult result = replicationControl.alterIsr(request); + RecordTestUtils.replayAll(replicationControl, result.records()); + return result; + } + + private AlterIsrResponseData.PartitionData assertAlterIsrResponse( + ControllerResult alterIsrResult, + TopicPartition topicPartition, + Errors expectedError + ) { + AlterIsrResponseData response = alterIsrResult.response(); + assertEquals(1, response.topics().size()); + + 
AlterIsrResponseData.TopicData topicData = response.topics().get(0); + assertEquals(topicPartition.topic(), topicData.name()); + assertEquals(1, topicData.partitions().size()); + + AlterIsrResponseData.PartitionData partitionData = topicData.partitions().get(0); + assertEquals(topicPartition.partition(), partitionData.partitionIndex()); + assertEquals(expectedError, Errors.forCode(partitionData.errorCode())); + return partitionData; + } + + private void assertConsistentAlterIsrResponse( + ReplicationControlManager replicationControl, + TopicIdPartition topicIdPartition, + AlterIsrResponseData.PartitionData partitionData + ) { + PartitionRegistration partitionControl = + replicationControl.getPartition(topicIdPartition.topicId(), topicIdPartition.partitionId()); + assertEquals(partitionControl.leader, partitionData.leaderId()); + assertEquals(partitionControl.leaderEpoch, partitionData.leaderEpoch()); + assertEquals(partitionControl.partitionEpoch, partitionData.currentIsrVersion()); + List expectedIsr = IntStream.of(partitionControl.isr).boxed().collect(Collectors.toList()); + assertEquals(expectedIsr, partitionData.isr()); + } + + private void assertCreatedTopicConfigs( + ReplicationControlTestContext ctx, + String topic, + CreateTopicsRequestData.CreateableTopicConfigCollection requestConfigs + ) { + Map configs = ctx.configurationControl.getConfigs( + new ConfigResource(ConfigResource.Type.TOPIC, topic)); + assertEquals(requestConfigs.size(), configs.size()); + for (CreateTopicsRequestData.CreateableTopicConfig requestConfig : requestConfigs) { + String value = configs.get(requestConfig.name()); + assertEquals(requestConfig.value(), value); + } + } + + private void assertEmptyTopicConfigs( + ReplicationControlTestContext ctx, + String topic + ) { + Map configs = ctx.configurationControl.getConfigs( + new ConfigResource(ConfigResource.Type.TOPIC, topic)); + assertEquals(Collections.emptyMap(), configs); + } + + @Test + public void testDeleteTopics() throws Exception { + ReplicationControlTestContext ctx = new ReplicationControlTestContext(); + ReplicationControlManager replicationControl = ctx.replicationControl; + CreateTopicsRequestData request = new CreateTopicsRequestData(); + CreateTopicsRequestData.CreateableTopicConfigCollection requestConfigs = + new CreateTopicsRequestData.CreateableTopicConfigCollection(); + requestConfigs.add(new CreateTopicsRequestData.CreateableTopicConfig(). + setName("cleanup.policy").setValue("compact")); + requestConfigs.add(new CreateTopicsRequestData.CreateableTopicConfig(). + setName("min.cleanable.dirty.ratio").setValue("0.1")); + request.topics().add(new CreatableTopic().setName("foo"). + setNumPartitions(3).setReplicationFactor((short) 2). + setConfigs(requestConfigs)); + ctx.registerBrokers(0, 1); + ctx.unfenceBrokers(0, 1); + ControllerResult createResult = + replicationControl.createTopics(request); + CreateTopicsResponseData expectedResponse = new CreateTopicsResponseData(); + Uuid topicId = createResult.response().topics().find("foo").topicId(); + expectedResponse.topics().add(new CreatableTopicResult().setName("foo"). + setNumPartitions(3).setReplicationFactor((short) 2). + setErrorMessage(null).setErrorCode((short) 0). 
+ setTopicId(topicId)); + assertEquals(expectedResponse, createResult.response()); + // Until the records are replayed, no changes are made + assertNull(replicationControl.getPartition(topicId, 0)); + assertEmptyTopicConfigs(ctx, "foo"); + ctx.replay(createResult.records()); + assertNotNull(replicationControl.getPartition(topicId, 0)); + assertNotNull(replicationControl.getPartition(topicId, 1)); + assertNotNull(replicationControl.getPartition(topicId, 2)); + assertNull(replicationControl.getPartition(topicId, 3)); + assertCreatedTopicConfigs(ctx, "foo", requestConfigs); + + assertEquals(singletonMap(topicId, new ResultOrError<>("foo")), + replicationControl.findTopicNames(Long.MAX_VALUE, Collections.singleton(topicId))); + assertEquals(singletonMap("foo", new ResultOrError<>(topicId)), + replicationControl.findTopicIds(Long.MAX_VALUE, Collections.singleton("foo"))); + Uuid invalidId = new Uuid(topicId.getMostSignificantBits() + 1, + topicId.getLeastSignificantBits()); + assertEquals(singletonMap(invalidId, + new ResultOrError<>(new ApiError(UNKNOWN_TOPIC_ID))), + replicationControl.findTopicNames(Long.MAX_VALUE, Collections.singleton(invalidId))); + assertEquals(singletonMap("bar", + new ResultOrError<>(new ApiError(UNKNOWN_TOPIC_OR_PARTITION))), + replicationControl.findTopicIds(Long.MAX_VALUE, Collections.singleton("bar"))); + + ControllerResult> invalidDeleteResult = replicationControl. + deleteTopics(Collections.singletonList(invalidId)); + assertEquals(0, invalidDeleteResult.records().size()); + assertEquals(singletonMap(invalidId, new ApiError(UNKNOWN_TOPIC_ID, null)), + invalidDeleteResult.response()); + ControllerResult> deleteResult = replicationControl. + deleteTopics(Collections.singletonList(topicId)); + assertTrue(deleteResult.isAtomic()); + assertEquals(singletonMap(topicId, new ApiError(NONE, null)), + deleteResult.response()); + assertEquals(1, deleteResult.records().size()); + ctx.replay(deleteResult.records()); + assertNull(replicationControl.getPartition(topicId, 0)); + assertNull(replicationControl.getPartition(topicId, 1)); + assertNull(replicationControl.getPartition(topicId, 2)); + assertNull(replicationControl.getPartition(topicId, 3)); + assertEquals(singletonMap(topicId, new ResultOrError<>( + new ApiError(UNKNOWN_TOPIC_ID))), replicationControl.findTopicNames( + Long.MAX_VALUE, Collections.singleton(topicId))); + assertEquals(singletonMap("foo", new ResultOrError<>( + new ApiError(UNKNOWN_TOPIC_OR_PARTITION))), replicationControl.findTopicIds( + Long.MAX_VALUE, Collections.singleton("foo"))); + assertEmptyTopicConfigs(ctx, "foo"); + } + + + @Test + public void testCreatePartitions() throws Exception { + ReplicationControlTestContext ctx = new ReplicationControlTestContext(); + ReplicationControlManager replicationControl = ctx.replicationControl; + CreateTopicsRequestData request = new CreateTopicsRequestData(); + request.topics().add(new CreatableTopic().setName("foo"). + setNumPartitions(3).setReplicationFactor((short) 2)); + request.topics().add(new CreatableTopic().setName("bar"). + setNumPartitions(4).setReplicationFactor((short) 2)); + request.topics().add(new CreatableTopic().setName("quux"). + setNumPartitions(2).setReplicationFactor((short) 2)); + request.topics().add(new CreatableTopic().setName("foo2"). 
+ setNumPartitions(2).setReplicationFactor((short) 2)); + ctx.registerBrokers(0, 1); + ctx.unfenceBrokers(0, 1); + ControllerResult createTopicResult = + replicationControl.createTopics(request); + ctx.replay(createTopicResult.records()); + List topics = new ArrayList<>(); + topics.add(new CreatePartitionsTopic(). + setName("foo").setCount(5).setAssignments(null)); + topics.add(new CreatePartitionsTopic(). + setName("bar").setCount(3).setAssignments(null)); + topics.add(new CreatePartitionsTopic(). + setName("baz").setCount(3).setAssignments(null)); + topics.add(new CreatePartitionsTopic(). + setName("quux").setCount(2).setAssignments(null)); + ControllerResult> createPartitionsResult = + replicationControl.createPartitions(topics); + assertEquals(asList(new CreatePartitionsTopicResult(). + setName("foo"). + setErrorCode(NONE.code()). + setErrorMessage(null), + new CreatePartitionsTopicResult(). + setName("bar"). + setErrorCode(INVALID_PARTITIONS.code()). + setErrorMessage("The topic bar currently has 4 partition(s); 3 would not be an increase."), + new CreatePartitionsTopicResult(). + setName("baz"). + setErrorCode(UNKNOWN_TOPIC_OR_PARTITION.code()). + setErrorMessage(null), + new CreatePartitionsTopicResult(). + setName("quux"). + setErrorCode(INVALID_PARTITIONS.code()). + setErrorMessage("Topic already has 2 partition(s).")), + createPartitionsResult.response()); + ctx.replay(createPartitionsResult.records()); + List topics2 = new ArrayList<>(); + topics2.add(new CreatePartitionsTopic(). + setName("foo").setCount(6).setAssignments(asList( + new CreatePartitionsAssignment().setBrokerIds(asList(1, 0))))); + topics2.add(new CreatePartitionsTopic(). + setName("bar").setCount(5).setAssignments(asList( + new CreatePartitionsAssignment().setBrokerIds(asList(1))))); + topics2.add(new CreatePartitionsTopic(). + setName("quux").setCount(4).setAssignments(asList( + new CreatePartitionsAssignment().setBrokerIds(asList(1, 0))))); + topics2.add(new CreatePartitionsTopic(). + setName("foo2").setCount(3).setAssignments(asList( + new CreatePartitionsAssignment().setBrokerIds(asList(2, 0))))); + ControllerResult> createPartitionsResult2 = + replicationControl.createPartitions(topics2); + assertEquals(asList(new CreatePartitionsTopicResult(). + setName("foo"). + setErrorCode(NONE.code()). + setErrorMessage(null), + new CreatePartitionsTopicResult(). + setName("bar"). + setErrorCode(INVALID_REPLICA_ASSIGNMENT.code()). + setErrorMessage("The manual partition assignment includes a partition " + + "with 1 replica(s), but this is not consistent with previous " + + "partitions, which have 2 replica(s)."), + new CreatePartitionsTopicResult(). + setName("quux"). + setErrorCode(INVALID_REPLICA_ASSIGNMENT.code()). + setErrorMessage("Attempted to add 2 additional partition(s), but only 1 assignment(s) were specified."), + new CreatePartitionsTopicResult(). + setName("foo2"). + setErrorCode(INVALID_REPLICA_ASSIGNMENT.code()). 
+ setErrorMessage("The manual partition assignment includes broker 2, but " + + "no such broker is registered.")), + createPartitionsResult2.response()); + ctx.replay(createPartitionsResult2.records()); + } + + @Test + public void testValidateGoodManualPartitionAssignments() throws Exception { + ReplicationControlTestContext ctx = new ReplicationControlTestContext(); + ctx.registerBrokers(1, 2, 3); + ctx.replicationControl.validateManualPartitionAssignment(asList(1), + OptionalInt.of(1)); + ctx.replicationControl.validateManualPartitionAssignment(asList(1), + OptionalInt.empty()); + ctx.replicationControl.validateManualPartitionAssignment(asList(1, 2, 3), + OptionalInt.of(3)); + ctx.replicationControl.validateManualPartitionAssignment(asList(1, 2, 3), + OptionalInt.empty()); + } + + @Test + public void testValidateBadManualPartitionAssignments() throws Exception { + ReplicationControlTestContext ctx = new ReplicationControlTestContext(); + ctx.registerBrokers(1, 2); + assertEquals("The manual partition assignment includes an empty replica list.", + assertThrows(InvalidReplicaAssignmentException.class, () -> + ctx.replicationControl.validateManualPartitionAssignment(asList(), + OptionalInt.empty())).getMessage()); + assertEquals("The manual partition assignment includes broker 3, but no such " + + "broker is registered.", assertThrows(InvalidReplicaAssignmentException.class, () -> + ctx.replicationControl.validateManualPartitionAssignment(asList(1, 2, 3), + OptionalInt.empty())).getMessage()); + assertEquals("The manual partition assignment includes the broker 2 more than " + + "once.", assertThrows(InvalidReplicaAssignmentException.class, () -> + ctx.replicationControl.validateManualPartitionAssignment(asList(1, 2, 2), + OptionalInt.empty())).getMessage()); + assertEquals("The manual partition assignment includes a partition with 2 " + + "replica(s), but this is not consistent with previous partitions, which have " + + "3 replica(s).", assertThrows(InvalidReplicaAssignmentException.class, () -> + ctx.replicationControl.validateManualPartitionAssignment(asList(1, 2), + OptionalInt.of(3))).getMessage()); + } + + private final static ListPartitionReassignmentsResponseData NONE_REASSIGNING = + new ListPartitionReassignmentsResponseData().setErrorMessage(null); + + @Test + public void testReassignPartitions() throws Exception { + ReplicationControlTestContext ctx = new ReplicationControlTestContext(); + ReplicationControlManager replication = ctx.replicationControl; + ctx.registerBrokers(0, 1, 2, 3); + ctx.unfenceBrokers(0, 1, 2, 3); + Uuid fooId = ctx.createTestTopic("foo", new int[][] { + new int[] {1, 2, 3}, new int[] {3, 2, 1}}).topicId(); + ctx.createTestTopic("bar", new int[][] { + new int[] {1, 2, 3}}).topicId(); + assertEquals(NONE_REASSIGNING, replication.listPartitionReassignments(null)); + ControllerResult alterResult = + replication.alterPartitionReassignments( + new AlterPartitionReassignmentsRequestData().setTopics(asList( + new ReassignableTopic().setName("foo").setPartitions(asList( + new ReassignablePartition().setPartitionIndex(0). + setReplicas(asList(3, 2, 1)), + new ReassignablePartition().setPartitionIndex(1). + setReplicas(asList(0, 2, 1)), + new ReassignablePartition().setPartitionIndex(2). + setReplicas(asList(0, 2, 1)))), + new ReassignableTopic().setName("bar")))); + assertEquals(new AlterPartitionReassignmentsResponseData(). 
+ setErrorMessage(null).setResponses(asList( + new ReassignableTopicResponse().setName("foo").setPartitions(asList( + new ReassignablePartitionResponse().setPartitionIndex(0). + setErrorMessage(null), + new ReassignablePartitionResponse().setPartitionIndex(1). + setErrorMessage(null), + new ReassignablePartitionResponse().setPartitionIndex(2). + setErrorCode(UNKNOWN_TOPIC_OR_PARTITION.code()). + setErrorMessage("Unable to find partition foo:2."))), + new ReassignableTopicResponse(). + setName("bar"))), + alterResult.response()); + ctx.replay(alterResult.records()); + ListPartitionReassignmentsResponseData currentReassigning = + new ListPartitionReassignmentsResponseData().setErrorMessage(null). + setTopics(asList(new OngoingTopicReassignment(). + setName("foo").setPartitions(asList( + new OngoingPartitionReassignment().setPartitionIndex(1). + setRemovingReplicas(asList(3)). + setAddingReplicas(asList(0)). + setReplicas(asList(0, 2, 1, 3)))))); + assertEquals(currentReassigning, replication.listPartitionReassignments(null)); + assertEquals(NONE_REASSIGNING, replication.listPartitionReassignments(asList( + new ListPartitionReassignmentsTopics().setName("bar"). + setPartitionIndexes(asList(0, 1, 2))))); + assertEquals(currentReassigning, replication.listPartitionReassignments(asList( + new ListPartitionReassignmentsTopics().setName("foo"). + setPartitionIndexes(asList(0, 1, 2))))); + ControllerResult cancelResult = + replication.alterPartitionReassignments( + new AlterPartitionReassignmentsRequestData().setTopics(asList( + new ReassignableTopic().setName("foo").setPartitions(asList( + new ReassignablePartition().setPartitionIndex(0). + setReplicas(null), + new ReassignablePartition().setPartitionIndex(1). + setReplicas(null), + new ReassignablePartition().setPartitionIndex(2). + setReplicas(null))), + new ReassignableTopic().setName("bar").setPartitions(asList( + new ReassignablePartition().setPartitionIndex(0). + setReplicas(null)))))); + assertEquals(ControllerResult.atomicOf(Collections.singletonList(new ApiMessageAndVersion( + new PartitionChangeRecord().setTopicId(fooId). + setPartitionId(1). + setReplicas(asList(2, 1, 3)). + setLeader(3). + setRemovingReplicas(Collections.emptyList()). + setAddingReplicas(Collections.emptyList()), (short) 0)), + new AlterPartitionReassignmentsResponseData().setErrorMessage(null).setResponses(asList( + new ReassignableTopicResponse().setName("foo").setPartitions(asList( + new ReassignablePartitionResponse().setPartitionIndex(0). + setErrorCode(NO_REASSIGNMENT_IN_PROGRESS.code()).setErrorMessage(null), + new ReassignablePartitionResponse().setPartitionIndex(1). + setErrorCode(NONE.code()).setErrorMessage(null), + new ReassignablePartitionResponse().setPartitionIndex(2). + setErrorCode(UNKNOWN_TOPIC_OR_PARTITION.code()). + setErrorMessage("Unable to find partition foo:2."))), + new ReassignableTopicResponse().setName("bar").setPartitions(asList( + new ReassignablePartitionResponse().setPartitionIndex(0). + setErrorCode(NO_REASSIGNMENT_IN_PROGRESS.code()). + setErrorMessage(null)))))), + cancelResult); + log.info("running final alterIsr..."); + ControllerResult alterIsrResult = replication.alterIsr( + new AlterIsrRequestData().setBrokerId(3).setBrokerEpoch(103). + setTopics(asList(new TopicData().setName("foo").setPartitions(asList( + new PartitionData().setPartitionIndex(1).setCurrentIsrVersion(1). 
+ setLeaderEpoch(0).setNewIsr(asList(3, 0, 2, 1))))))); + assertEquals(new AlterIsrResponseData().setTopics(asList( + new AlterIsrResponseData.TopicData().setName("foo").setPartitions(asList( + new AlterIsrResponseData.PartitionData(). + setPartitionIndex(1). + setErrorCode(FENCED_LEADER_EPOCH.code()))))), + alterIsrResult.response()); + ctx.replay(alterIsrResult.records()); + assertEquals(NONE_REASSIGNING, replication.listPartitionReassignments(null)); + } + + @Test + public void testCancelReassignPartitions() throws Exception { + ReplicationControlTestContext ctx = new ReplicationControlTestContext(); + ReplicationControlManager replication = ctx.replicationControl; + ctx.registerBrokers(0, 1, 2, 3, 4); + ctx.unfenceBrokers(0, 1, 2, 3, 4); + Uuid fooId = ctx.createTestTopic("foo", new int[][] { + new int[] {1, 2, 3, 4}, new int[] {0, 1, 2, 3}, new int[] {4, 3, 1, 0}, + new int[] {2, 3, 4, 1}}).topicId(); + Uuid barId = ctx.createTestTopic("bar", new int[][] { + new int[] {4, 3, 2}}).topicId(); + assertEquals(NONE_REASSIGNING, replication.listPartitionReassignments(null)); + List fenceRecords = new ArrayList<>(); + replication.handleBrokerFenced(3, fenceRecords); + ctx.replay(fenceRecords); + assertEquals(new PartitionRegistration(new int[] {1, 2, 3, 4}, new int[] {1, 2, 4}, + new int[] {}, new int[] {}, 1, 1, 1), replication.getPartition(fooId, 0)); + ControllerResult alterResult = + replication.alterPartitionReassignments( + new AlterPartitionReassignmentsRequestData().setTopics(asList( + new ReassignableTopic().setName("foo").setPartitions(asList( + new ReassignablePartition().setPartitionIndex(0). + setReplicas(asList(1, 2, 3)), + new ReassignablePartition().setPartitionIndex(1). + setReplicas(asList(1, 2, 3, 0)), + new ReassignablePartition().setPartitionIndex(2). + setReplicas(asList(5, 6, 7)), + new ReassignablePartition().setPartitionIndex(3). + setReplicas(asList()))), + new ReassignableTopic().setName("bar").setPartitions(asList( + new ReassignablePartition().setPartitionIndex(0). + setReplicas(asList(1, 2, 3, 4, 0))))))); + assertEquals(new AlterPartitionReassignmentsResponseData(). + setErrorMessage(null).setResponses(asList( + new ReassignableTopicResponse().setName("foo").setPartitions(asList( + new ReassignablePartitionResponse().setPartitionIndex(0). + setErrorMessage(null), + new ReassignablePartitionResponse().setPartitionIndex(1). + setErrorMessage(null), + new ReassignablePartitionResponse().setPartitionIndex(2). + setErrorCode(INVALID_REPLICA_ASSIGNMENT.code()). + setErrorMessage("The manual partition assignment includes broker 5, " + + "but no such broker is registered."), + new ReassignablePartitionResponse().setPartitionIndex(3). + setErrorCode(INVALID_REPLICA_ASSIGNMENT.code()). + setErrorMessage("The manual partition assignment includes an empty " + + "replica list."))), + new ReassignableTopicResponse().setName("bar").setPartitions(asList( + new ReassignablePartitionResponse().setPartitionIndex(0). 
+ setErrorMessage(null))))), + alterResult.response()); + ctx.replay(alterResult.records()); + assertEquals(new PartitionRegistration(new int[] {1, 2, 3}, new int[] {1, 2}, + new int[] {}, new int[] {}, 1, 2, 2), replication.getPartition(fooId, 0)); + assertEquals(new PartitionRegistration(new int[] {1, 2, 3, 0}, new int[] {0, 1, 2}, + new int[] {}, new int[] {}, 0, 1, 2), replication.getPartition(fooId, 1)); + assertEquals(new PartitionRegistration(new int[] {1, 2, 3, 4, 0}, new int[] {4, 2}, + new int[] {}, new int[] {0, 1}, 4, 1, 2), replication.getPartition(barId, 0)); + ListPartitionReassignmentsResponseData currentReassigning = + new ListPartitionReassignmentsResponseData().setErrorMessage(null). + setTopics(asList(new OngoingTopicReassignment(). + setName("bar").setPartitions(asList( + new OngoingPartitionReassignment().setPartitionIndex(0). + setRemovingReplicas(Collections.emptyList()). + setAddingReplicas(asList(0, 1)). + setReplicas(asList(1, 2, 3, 4, 0)))))); + assertEquals(currentReassigning, replication.listPartitionReassignments(null)); + assertEquals(NONE_REASSIGNING, replication.listPartitionReassignments(asList( + new ListPartitionReassignmentsTopics().setName("foo"). + setPartitionIndexes(asList(0, 1, 2))))); + assertEquals(currentReassigning, replication.listPartitionReassignments(asList( + new ListPartitionReassignmentsTopics().setName("bar"). + setPartitionIndexes(asList(0, 1, 2))))); + ControllerResult alterIsrResult = replication.alterIsr( + new AlterIsrRequestData().setBrokerId(4).setBrokerEpoch(104). + setTopics(asList(new TopicData().setName("bar").setPartitions(asList( + new PartitionData().setPartitionIndex(0).setCurrentIsrVersion(2). + setLeaderEpoch(1).setNewIsr(asList(4, 1, 2, 3, 0))))))); + assertEquals(new AlterIsrResponseData().setTopics(asList( + new AlterIsrResponseData.TopicData().setName("bar").setPartitions(asList( + new AlterIsrResponseData.PartitionData(). + setPartitionIndex(0). + setLeaderId(4). + setLeaderEpoch(1). + setIsr(asList(4, 1, 2, 3, 0)). + setCurrentIsrVersion(3). + setErrorCode(NONE.code()))))), + alterIsrResult.response()); + ControllerResult cancelResult = + replication.alterPartitionReassignments( + new AlterPartitionReassignmentsRequestData().setTopics(asList( + new ReassignableTopic().setName("foo").setPartitions(asList( + new ReassignablePartition().setPartitionIndex(0). + setReplicas(null))), + new ReassignableTopic().setName("bar").setPartitions(asList( + new ReassignablePartition().setPartitionIndex(0). + setReplicas(null)))))); + assertEquals(ControllerResult.atomicOf(Collections.singletonList(new ApiMessageAndVersion( + new PartitionChangeRecord().setTopicId(barId). + setPartitionId(0). + setLeader(4). + setReplicas(asList(2, 3, 4)). + setRemovingReplicas(null). + setAddingReplicas(Collections.emptyList()), (short) 0)), + new AlterPartitionReassignmentsResponseData().setErrorMessage(null).setResponses(asList( + new ReassignableTopicResponse().setName("foo").setPartitions(asList( + new ReassignablePartitionResponse().setPartitionIndex(0). + setErrorCode(NO_REASSIGNMENT_IN_PROGRESS.code()).setErrorMessage(null))), + new ReassignableTopicResponse().setName("bar").setPartitions(asList( + new ReassignablePartitionResponse().setPartitionIndex(0). 
+ setErrorMessage(null)))))), + cancelResult); + ctx.replay(cancelResult.records()); + assertEquals(NONE_REASSIGNING, replication.listPartitionReassignments(null)); + assertEquals(new PartitionRegistration(new int[] {2, 3, 4}, new int[] {4, 2}, + new int[] {}, new int[] {}, 4, 2, 3), replication.getPartition(barId, 0)); + } + + @Test + public void testManualPartitionAssignmentOnAllFencedBrokers() throws Exception { + ReplicationControlTestContext ctx = new ReplicationControlTestContext(); + ctx.registerBrokers(0, 1, 2, 3); + ctx.createTestTopic("foo", new int[][] {new int[] {0, 1, 2}}, + INVALID_REPLICA_ASSIGNMENT.code()); + } + + @Test + public void testCreatePartitionsFailsWithManualAssignmentWithAllFenced() throws Exception { + ReplicationControlTestContext ctx = new ReplicationControlTestContext(); + ctx.registerBrokers(0, 1, 2, 3, 4, 5); + ctx.unfenceBrokers(0, 1, 2); + Uuid fooId = ctx.createTestTopic("foo", new int[][] {new int[] {0, 1, 2}}).topicId(); + ctx.createPartitions(2, "foo", new int[][] {new int[] {3, 4, 5}}, + INVALID_REPLICA_ASSIGNMENT.code()); + ctx.createPartitions(2, "foo", new int[][] {new int[] {2, 4, 5}}, NONE.code()); + assertEquals(new PartitionRegistration(new int[] {2, 4, 5}, + new int[] {2}, Replicas.NONE, Replicas.NONE, 2, 0, 0), + ctx.replicationControl.getPartition(fooId, 1)); + } + + private void assertLeaderAndIsr( + ReplicationControlManager replication, + TopicIdPartition topicIdPartition, + int leaderId, + int[] isr + ) { + PartitionRegistration registration = replication.getPartition( + topicIdPartition.topicId(), + topicIdPartition.partitionId() + ); + assertArrayEquals(isr, registration.isr); + assertEquals(leaderId, registration.leader); + } + + @ParameterizedTest + @ValueSource(booleans = {true, false}) + public void testElectUncleanLeaders(boolean electAllPartitions) throws Exception { + ReplicationControlTestContext ctx = new ReplicationControlTestContext(); + ReplicationControlManager replication = ctx.replicationControl; + ctx.registerBrokers(0, 1, 2, 3, 4); + ctx.unfenceBrokers(0, 1, 2, 3, 4); + + Uuid fooId = ctx.createTestTopic("foo", new int[][]{ + new int[]{1, 2, 3}, new int[]{2, 3, 4}, new int[]{0, 2, 1}}).topicId(); + + TopicIdPartition partition0 = new TopicIdPartition(fooId, 0); + TopicIdPartition partition1 = new TopicIdPartition(fooId, 1); + TopicIdPartition partition2 = new TopicIdPartition(fooId, 2); + + ctx.fenceBrokers(Utils.mkSet(2, 3)); + ctx.fenceBrokers(Utils.mkSet(1, 2, 3)); + + assertLeaderAndIsr(replication, partition0, NO_LEADER, new int[]{1}); + assertLeaderAndIsr(replication, partition1, 4, new int[]{4}); + assertLeaderAndIsr(replication, partition2, 0, new int[]{0}); + + ElectLeadersRequestData request = buildElectLeadersRequest( + ElectionType.UNCLEAN, + electAllPartitions ? 
null : singletonMap("foo", asList(0, 1, 2)) + ); + + // No election can be done yet because no replicas are available for partition 0 + ControllerResult result1 = replication.electLeaders(request); + assertEquals(Collections.emptyList(), result1.records()); + + ElectLeadersResponseData expectedResponse1 = buildElectLeadersResponse(NONE, electAllPartitions, Utils.mkMap( + Utils.mkEntry( + new TopicPartition("foo", 0), + new ApiError(ELIGIBLE_LEADERS_NOT_AVAILABLE) + ), + Utils.mkEntry( + new TopicPartition("foo", 1), + new ApiError(ELECTION_NOT_NEEDED) + ), + Utils.mkEntry( + new TopicPartition("foo", 2), + new ApiError(ELECTION_NOT_NEEDED) + ) + )); + assertElectLeadersResponse(expectedResponse1, result1.response()); + + // Now we bring 2 back online which should allow the unclean election of partition 0 + ctx.unfenceBrokers(Utils.mkSet(2)); + + // Bring 2 back into the ISR for partition 1. This allows us to verify that + // preferred election does not occur as a result of the unclean election request. + ctx.alterIsr(partition1, 4, asList(2, 4)); + + ControllerResult result = replication.electLeaders(request); + assertEquals(1, result.records().size()); + + ApiMessageAndVersion record = result.records().get(0); + assertTrue(record.message() instanceof PartitionChangeRecord); + + PartitionChangeRecord partitionChangeRecord = (PartitionChangeRecord) record.message(); + assertEquals(0, partitionChangeRecord.partitionId()); + assertEquals(2, partitionChangeRecord.leader()); + assertEquals(singletonList(2), partitionChangeRecord.isr()); + ctx.replay(result.records()); + + assertLeaderAndIsr(replication, partition0, 2, new int[]{2}); + assertLeaderAndIsr(replication, partition1, 4, new int[]{2, 4}); + assertLeaderAndIsr(replication, partition2, 0, new int[]{0}); + + ElectLeadersResponseData expectedResponse = buildElectLeadersResponse(NONE, electAllPartitions, Utils.mkMap( + Utils.mkEntry( + new TopicPartition("foo", 0), + ApiError.NONE + ), + Utils.mkEntry( + new TopicPartition("foo", 1), + new ApiError(ELECTION_NOT_NEEDED) + ), + Utils.mkEntry( + new TopicPartition("foo", 2), + new ApiError(ELECTION_NOT_NEEDED) + ) + )); + assertElectLeadersResponse(expectedResponse, result.response()); + } + + @Test + public void testPreferredElectionDoesNotTriggerUncleanElection() throws Exception { + ReplicationControlTestContext ctx = new ReplicationControlTestContext(); + ReplicationControlManager replication = ctx.replicationControl; + ctx.registerBrokers(1, 2, 3, 4); + ctx.unfenceBrokers(1, 2, 3, 4); + + Uuid fooId = ctx.createTestTopic("foo", new int[][]{new int[]{1, 2, 3}}).topicId(); + TopicIdPartition partition = new TopicIdPartition(fooId, 0); + + ctx.fenceBrokers(Utils.mkSet(2, 3)); + ctx.fenceBrokers(Utils.mkSet(1, 2, 3)); + ctx.unfenceBrokers(Utils.mkSet(2)); + + assertLeaderAndIsr(replication, partition, NO_LEADER, new int[]{1}); + + ctx.alterTopicConfig("foo", "unclean.leader.election.enable", "true"); + + ElectLeadersRequestData request = buildElectLeadersRequest( + ElectionType.PREFERRED, + singletonMap("foo", singletonList(0)) + ); + + // No election should be done even though unclean election is available + ControllerResult result = replication.electLeaders(request); + assertEquals(Collections.emptyList(), result.records()); + + ElectLeadersResponseData expectedResponse = buildElectLeadersResponse(NONE, false, singletonMap( + new TopicPartition("foo", 0), new ApiError(PREFERRED_LEADER_NOT_AVAILABLE) + )); + assertEquals(expectedResponse, result.response()); + } + + private 
ElectLeadersRequestData buildElectLeadersRequest( + ElectionType electionType, + Map> partitions + ) { + ElectLeadersRequestData request = new ElectLeadersRequestData(). + setElectionType(electionType.value); + + if (partitions == null) { + request.setTopicPartitions(null); + } else { + partitions.forEach((topic, partitionIds) -> { + request.topicPartitions().add(new TopicPartitions() + .setTopic(topic) + .setPartitions(partitionIds) + ); + }); + } + return request; + } + + @Test + public void testFenceMultipleBrokers() throws Exception { + ReplicationControlTestContext ctx = new ReplicationControlTestContext(); + ReplicationControlManager replication = ctx.replicationControl; + ctx.registerBrokers(0, 1, 2, 3, 4); + ctx.unfenceBrokers(0, 1, 2, 3, 4); + + Uuid fooId = ctx.createTestTopic("foo", new int[][]{ + new int[]{1, 2, 3}, new int[]{2, 3, 4}, new int[]{0, 2, 1}}).topicId(); + + assertTrue(ctx.clusterControl.fencedBrokerIds().isEmpty()); + ctx.fenceBrokers(Utils.mkSet(2, 3)); + + PartitionRegistration partition0 = replication.getPartition(fooId, 0); + PartitionRegistration partition1 = replication.getPartition(fooId, 1); + PartitionRegistration partition2 = replication.getPartition(fooId, 2); + + assertArrayEquals(new int[]{1, 2, 3}, partition0.replicas); + assertArrayEquals(new int[]{1}, partition0.isr); + assertEquals(1, partition0.leader); + + assertArrayEquals(new int[]{2, 3, 4}, partition1.replicas); + assertArrayEquals(new int[]{4}, partition1.isr); + assertEquals(4, partition1.leader); + + assertArrayEquals(new int[]{0, 2, 1}, partition2.replicas); + assertArrayEquals(new int[]{0, 1}, partition2.isr); + assertNotEquals(2, partition2.leader); + } + + @Test + public void testElectPreferredLeaders() throws Exception { + ReplicationControlTestContext ctx = new ReplicationControlTestContext(); + ReplicationControlManager replication = ctx.replicationControl; + ctx.registerBrokers(0, 1, 2, 3, 4); + ctx.unfenceBrokers(2, 3, 4); + Uuid fooId = ctx.createTestTopic("foo", new int[][]{ + new int[]{1, 2, 3}, new int[]{2, 3, 4}, new int[]{0, 2, 1}}).topicId(); + ElectLeadersRequestData request1 = new ElectLeadersRequestData(). + setElectionType(ElectionType.PREFERRED.value). + setTopicPartitions(new TopicPartitionsCollection(asList( + new TopicPartitions().setTopic("foo"). + setPartitions(asList(0, 1)), + new TopicPartitions().setTopic("bar"). + setPartitions(asList(0, 1))).iterator())); + ControllerResult election1Result = + replication.electLeaders(request1); + ElectLeadersResponseData expectedResponse1 = buildElectLeadersResponse(NONE, false, Utils.mkMap( + Utils.mkEntry( + new TopicPartition("foo", 0), + new ApiError(PREFERRED_LEADER_NOT_AVAILABLE) + ), + Utils.mkEntry( + new TopicPartition("foo", 1), + new ApiError(ELECTION_NOT_NEEDED) + ), + Utils.mkEntry( + new TopicPartition("bar", 0), + new ApiError(UNKNOWN_TOPIC_OR_PARTITION, "No such topic as bar") + ), + Utils.mkEntry( + new TopicPartition("bar", 1), + new ApiError(UNKNOWN_TOPIC_OR_PARTITION, "No such topic as bar") + ) + )); + assertElectLeadersResponse(expectedResponse1, election1Result.response()); + assertEquals(Collections.emptyList(), election1Result.records()); + ctx.unfenceBrokers(0, 1); + + ControllerResult alterIsrResult = replication.alterIsr( + new AlterIsrRequestData().setBrokerId(2).setBrokerEpoch(102). + setTopics(asList(new AlterIsrRequestData.TopicData().setName("foo"). + setPartitions(asList(new AlterIsrRequestData.PartitionData(). + setPartitionIndex(0).setCurrentIsrVersion(0). 
+ setLeaderEpoch(0).setNewIsr(asList(1, 2, 3))))))); + assertEquals(new AlterIsrResponseData().setTopics(asList( + new AlterIsrResponseData.TopicData().setName("foo").setPartitions(asList( + new AlterIsrResponseData.PartitionData(). + setPartitionIndex(0). + setLeaderId(2). + setLeaderEpoch(0). + setIsr(asList(1, 2, 3)). + setCurrentIsrVersion(1). + setErrorCode(NONE.code()))))), + alterIsrResult.response()); + + ElectLeadersResponseData expectedResponse2 = buildElectLeadersResponse(NONE, false, Utils.mkMap( + Utils.mkEntry( + new TopicPartition("foo", 0), + ApiError.NONE + ), + Utils.mkEntry( + new TopicPartition("foo", 1), + new ApiError(ELECTION_NOT_NEEDED) + ), + Utils.mkEntry( + new TopicPartition("bar", 0), + new ApiError(UNKNOWN_TOPIC_OR_PARTITION, "No such topic as bar") + ), + Utils.mkEntry( + new TopicPartition("bar", 1), + new ApiError(UNKNOWN_TOPIC_OR_PARTITION, "No such topic as bar") + ) + )); + + ctx.replay(alterIsrResult.records()); + ControllerResult election2Result = + replication.electLeaders(request1); + assertElectLeadersResponse(expectedResponse2, election2Result.response()); + assertEquals(asList(new ApiMessageAndVersion(new PartitionChangeRecord(). + setPartitionId(0). + setTopicId(fooId). + setLeader(1), (short) 0)), election2Result.records()); + } + + private void assertElectLeadersResponse( + ElectLeadersResponseData expected, + ElectLeadersResponseData actual + ) { + assertEquals(Errors.forCode(expected.errorCode()), Errors.forCode(actual.errorCode())); + assertEquals(collectElectLeadersErrors(expected), collectElectLeadersErrors(actual)); + } + + private Map collectElectLeadersErrors(ElectLeadersResponseData response) { + Map res = new HashMap<>(); + response.replicaElectionResults().forEach(topicResult -> { + String topic = topicResult.topic(); + topicResult.partitionResult().forEach(partitionResult -> { + TopicPartition topicPartition = new TopicPartition(topic, partitionResult.partitionId()); + res.put(topicPartition, partitionResult); + }); + }); + return res; + } + + private ElectLeadersResponseData buildElectLeadersResponse( + Errors topLevelError, + boolean electAllPartitions, + Map errors + ) { + Map>> errorsByTopic = errors.entrySet().stream() + .collect(Collectors.groupingBy(entry -> entry.getKey().topic())); + + ElectLeadersResponseData response = new ElectLeadersResponseData() + .setErrorCode(topLevelError.code()); + + errorsByTopic.forEach((topic, partitionErrors) -> { + ReplicaElectionResult electionResult = new ReplicaElectionResult().setTopic(topic); + electionResult.setPartitionResult(partitionErrors.stream() + .filter(entry -> !electAllPartitions || entry.getValue().error() != ELECTION_NOT_NEEDED) + .map(entry -> { + TopicPartition topicPartition = entry.getKey(); + ApiError error = entry.getValue(); + return new PartitionResult() + .setPartitionId(topicPartition.partition()) + .setErrorCode(error.error().code()) + .setErrorMessage(error.message()); + }) + .collect(Collectors.toList())); + response.replicaElectionResults().add(electionResult); + }); + + return response; + } + +} diff --git a/metadata/src/test/java/org/apache/kafka/controller/ResultOrErrorTest.java b/metadata/src/test/java/org/apache/kafka/controller/ResultOrErrorTest.java new file mode 100644 index 0000000..7d42b2e --- /dev/null +++ b/metadata/src/test/java/org/apache/kafka/controller/ResultOrErrorTest.java @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.controller; + +import org.apache.kafka.common.protocol.Errors; +import org.apache.kafka.common.requests.ApiError; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + + +@Timeout(value = 40) +public class ResultOrErrorTest { + @Test + public void testError() { + ResultOrError resultOrError = + new ResultOrError<>(Errors.INVALID_REQUEST, "missing foobar"); + assertTrue(resultOrError.isError()); + assertFalse(resultOrError.isResult()); + assertEquals(null, resultOrError.result()); + assertEquals(new ApiError(Errors.INVALID_REQUEST, "missing foobar"), + resultOrError.error()); + } + + @Test + public void testResult() { + ResultOrError resultOrError = new ResultOrError<>(123); + assertFalse(resultOrError.isError()); + assertTrue(resultOrError.isResult()); + assertEquals(123, resultOrError.result()); + assertEquals(null, resultOrError.error()); + } + + @Test + public void testEquals() { + ResultOrError a = new ResultOrError<>(Errors.INVALID_REQUEST, "missing foobar"); + ResultOrError b = new ResultOrError<>("abcd"); + assertFalse(a.equals(b)); + assertFalse(b.equals(a)); + assertTrue(a.equals(a)); + assertTrue(b.equals(b)); + ResultOrError c = new ResultOrError<>(Errors.INVALID_REQUEST, "missing baz"); + assertFalse(a.equals(c)); + assertFalse(c.equals(a)); + assertTrue(c.equals(c)); + } +} diff --git a/metadata/src/test/java/org/apache/kafka/controller/SnapshotGeneratorTest.java b/metadata/src/test/java/org/apache/kafka/controller/SnapshotGeneratorTest.java new file mode 100644 index 0000000..a7ac119 --- /dev/null +++ b/metadata/src/test/java/org/apache/kafka/controller/SnapshotGeneratorTest.java @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.kafka.controller;
+
+import org.apache.kafka.common.Uuid;
+import org.apache.kafka.common.config.ConfigResource;
+import org.apache.kafka.common.memory.MemoryPool;
+import org.apache.kafka.common.metadata.ConfigRecord;
+import org.apache.kafka.common.metadata.TopicRecord;
+import org.apache.kafka.common.record.CompressionType;
+import org.apache.kafka.common.utils.ExponentialBackoff;
+import org.apache.kafka.common.utils.LogContext;
+import org.apache.kafka.common.utils.MockTime;
+import org.apache.kafka.controller.SnapshotGenerator.Section;
+import org.apache.kafka.metadata.MetadataRecordSerde;
+import org.apache.kafka.raft.OffsetAndEpoch;
+import org.apache.kafka.server.common.ApiMessageAndVersion;
+import org.apache.kafka.snapshot.MockRawSnapshotWriter;
+import org.apache.kafka.snapshot.RawSnapshotWriter;
+import org.apache.kafka.snapshot.SnapshotWriter;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.Timeout;
+
+import java.util.Arrays;
+import java.util.List;
+import java.util.OptionalLong;
+import java.util.Optional;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertFalse;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+
+@Timeout(40)
+public class SnapshotGeneratorTest {
+    private static final List<List<ApiMessageAndVersion>> BATCHES;
+
+    static {
+        BATCHES = Arrays.asList(
+            Arrays.asList(new ApiMessageAndVersion(new TopicRecord().
+                setName("foo").setTopicId(Uuid.randomUuid()), (short) 0)),
+            Arrays.asList(new ApiMessageAndVersion(new TopicRecord().
+                setName("bar").setTopicId(Uuid.randomUuid()), (short) 0)),
+            Arrays.asList(new ApiMessageAndVersion(new TopicRecord().
+                setName("baz").setTopicId(Uuid.randomUuid()), (short) 0)),
+            Arrays.asList(new ApiMessageAndVersion(new ConfigRecord().
+                setResourceName("foo").setResourceType(ConfigResource.Type.TOPIC.id()).
+                setName("retention.ms").setValue("10000000"), (short) 0),
+                new ApiMessageAndVersion(new ConfigRecord().
+                setResourceName("foo").setResourceType(ConfigResource.Type.TOPIC.id()).
+                setName("max.message.bytes").setValue("100000000"), (short) 0)),
+            Arrays.asList(new ApiMessageAndVersion(new ConfigRecord().
+                setResourceName("bar").setResourceType(ConfigResource.Type.TOPIC.id()).
+                setName("retention.ms").setValue("5000000"), (short) 0)));
+    }
+
+    @Test
+    public void testGenerateBatches() throws Exception {
+        SnapshotWriter<ApiMessageAndVersion> writer = createSnapshotWriter(123, 0);
+        ExponentialBackoff exponentialBackoff =
+            new ExponentialBackoff(100, 2, 400, 0.0);
+        List<Section>
                sections = Arrays.asList(new Section("replication", + Arrays.asList(BATCHES.get(0), BATCHES.get(1), BATCHES.get(2)).iterator()), + new Section("configuration", + Arrays.asList(BATCHES.get(3), BATCHES.get(4)).iterator())); + SnapshotGenerator generator = new SnapshotGenerator(new LogContext(), + writer, 2, exponentialBackoff, sections); + assertFalse(writer.isFrozen()); + assertEquals(123L, generator.lastContainedLogOffset()); + assertEquals(writer, generator.writer()); + assertEquals(OptionalLong.of(0L), generator.generateBatches()); + assertEquals(OptionalLong.of(0L), generator.generateBatches()); + assertFalse(writer.isFrozen()); + assertEquals(OptionalLong.empty(), generator.generateBatches()); + assertTrue(writer.isFrozen()); + } + + private SnapshotWriter createSnapshotWriter( + long committedOffset, + long lastContainedLogTime + ) { + return SnapshotWriter.createWithHeader( + () -> createNewSnapshot(new OffsetAndEpoch(committedOffset + 1, 1)), + 1024, + MemoryPool.NONE, + new MockTime(), + lastContainedLogTime, + CompressionType.NONE, + new MetadataRecordSerde() + ).get(); + } + + private Optional createNewSnapshot( + OffsetAndEpoch snapshotId + ) { + return Optional.of(new MockRawSnapshotWriter(snapshotId, buffer -> { })); + } +} diff --git a/metadata/src/test/java/org/apache/kafka/controller/StripedReplicaPlacerTest.java b/metadata/src/test/java/org/apache/kafka/controller/StripedReplicaPlacerTest.java new file mode 100644 index 0000000..c3fbb09 --- /dev/null +++ b/metadata/src/test/java/org/apache/kafka/controller/StripedReplicaPlacerTest.java @@ -0,0 +1,278 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.controller; + +import org.apache.kafka.common.errors.InvalidReplicationFactorException; +import org.apache.kafka.controller.StripedReplicaPlacer.BrokerList; +import org.apache.kafka.controller.StripedReplicaPlacer.RackList; +import org.apache.kafka.metadata.UsableBroker; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; + +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + + +@Timeout(value = 40) +public class StripedReplicaPlacerTest { + /** + * Test that the BrokerList class works as expected. 
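+     * next(startIndex) is expected to return each broker exactly once, starting at
+     * startIndex and wrapping around, and then to return -1 once the list is
+     * exhausted; the empty list returns -1 immediately.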
+ */ + @Test + public void testBrokerList() { + assertEquals(0, BrokerList.EMPTY.size()); + assertEquals(-1, BrokerList.EMPTY.next(1)); + BrokerList brokers = new BrokerList().add(0).add(1).add(2).add(3); + assertEquals(4, brokers.size()); + assertEquals(0, brokers.next(0)); + assertEquals(1, brokers.next(0)); + assertEquals(2, brokers.next(0)); + assertEquals(3, brokers.next(0)); + assertEquals(-1, brokers.next(0)); + assertEquals(-1, brokers.next(0)); + assertEquals(1, brokers.next(1)); + assertEquals(2, brokers.next(1)); + assertEquals(3, brokers.next(1)); + assertEquals(0, brokers.next(1)); + assertEquals(-1, brokers.next(1)); + } + + /** + * Test that we perform striped replica placement as expected, and don't use the + * fenced replica if we don't have to. + */ + @Test + public void testAvoidFencedReplicaIfPossibleOnSingleRack() { + MockRandom random = new MockRandom(); + RackList rackList = new RackList(random, Arrays.asList( + new UsableBroker(3, Optional.empty(), false), + new UsableBroker(1, Optional.empty(), true), + new UsableBroker(0, Optional.empty(), false), + new UsableBroker(4, Optional.empty(), false), + new UsableBroker(2, Optional.empty(), false)).iterator()); + assertEquals(5, rackList.numTotalBrokers()); + assertEquals(4, rackList.numUnfencedBrokers()); + assertEquals(Collections.singletonList(Optional.empty()), rackList.rackNames()); + assertThrows(InvalidReplicationFactorException.class, () -> rackList.place(0)); + assertThrows(InvalidReplicationFactorException.class, () -> rackList.place(-1)); + assertEquals(Arrays.asList(3, 4, 0, 2), rackList.place(4)); + assertEquals(Arrays.asList(4, 0, 2, 3), rackList.place(4)); + assertEquals(Arrays.asList(0, 2, 3, 4), rackList.place(4)); + assertEquals(Arrays.asList(2, 3, 4, 0), rackList.place(4)); + assertEquals(Arrays.asList(0, 4, 3, 2), rackList.place(4)); + } + + /** + * Test that we perform striped replica placement as expected for a multi-partition topic + * on a single unfenced broker + */ + @Test + public void testMultiPartitionTopicPlacementOnSingleUnfencedBroker() { + MockRandom random = new MockRandom(); + StripedReplicaPlacer placer = new StripedReplicaPlacer(random); + assertEquals(Arrays.asList(Arrays.asList(0), + Arrays.asList(0), + Arrays.asList(0)), + placer.place(0, 3, (short) 1, Arrays.asList( + new UsableBroker(0, Optional.empty(), false), + new UsableBroker(1, Optional.empty(), true)).iterator())); + } + + /** + * Test that we will place on the fenced replica if we need to. 
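+     * With only two unfenced brokers and a replication factor of three, every
+     * placement below still includes the fenced broker (1), but always in the last
+     * position, after the unfenced brokers.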
+ */ + @Test + public void testPlacementOnFencedReplicaOnSingleRack() { + MockRandom random = new MockRandom(); + RackList rackList = new RackList(random, Arrays.asList( + new UsableBroker(3, Optional.empty(), false), + new UsableBroker(1, Optional.empty(), true), + new UsableBroker(2, Optional.empty(), false)).iterator()); + assertEquals(3, rackList.numTotalBrokers()); + assertEquals(2, rackList.numUnfencedBrokers()); + assertEquals(Collections.singletonList(Optional.empty()), rackList.rackNames()); + assertEquals(Arrays.asList(3, 2, 1), rackList.place(3)); + assertEquals(Arrays.asList(2, 3, 1), rackList.place(3)); + assertEquals(Arrays.asList(3, 2, 1), rackList.place(3)); + assertEquals(Arrays.asList(2, 3, 1), rackList.place(3)); + } + + @Test + public void testRackListWithMultipleRacks() { + MockRandom random = new MockRandom(); + RackList rackList = new RackList(random, Arrays.asList( + new UsableBroker(11, Optional.of("1"), false), + new UsableBroker(10, Optional.of("1"), false), + new UsableBroker(30, Optional.of("3"), false), + new UsableBroker(31, Optional.of("3"), false), + new UsableBroker(21, Optional.of("2"), false), + new UsableBroker(20, Optional.of("2"), true)).iterator()); + assertEquals(6, rackList.numTotalBrokers()); + assertEquals(5, rackList.numUnfencedBrokers()); + assertEquals(Arrays.asList(Optional.of("1"), Optional.of("2"), Optional.of("3")), rackList.rackNames()); + assertEquals(Arrays.asList(11, 21, 31, 10), rackList.place(4)); + assertEquals(Arrays.asList(21, 30, 10, 20), rackList.place(4)); + assertEquals(Arrays.asList(31, 11, 21, 30), rackList.place(4)); + } + + @Test + public void testRackListWithInvalidRacks() { + MockRandom random = new MockRandom(); + RackList rackList = new RackList(random, Arrays.asList( + new UsableBroker(11, Optional.of("1"), false), + new UsableBroker(10, Optional.of("1"), false), + new UsableBroker(30, Optional.of("3"), true), + new UsableBroker(31, Optional.of("3"), true), + new UsableBroker(20, Optional.of("2"), true), + new UsableBroker(21, Optional.of("2"), true), + new UsableBroker(41, Optional.of("4"), false), + new UsableBroker(40, Optional.of("4"), true)).iterator()); + assertEquals(8, rackList.numTotalBrokers()); + assertEquals(3, rackList.numUnfencedBrokers()); + assertEquals(Arrays.asList(Optional.of("1"), + Optional.of("2"), + Optional.of("3"), + Optional.of("4")), rackList.rackNames()); + assertEquals(Arrays.asList(41, 11, 21, 30), rackList.place(4)); + assertEquals(Arrays.asList(10, 20, 31, 41), rackList.place(4)); + assertEquals(Arrays.asList(41, 21, 30, 11), rackList.place(4)); + } + + @Test + public void testAllBrokersFenced() { + MockRandom random = new MockRandom(); + StripedReplicaPlacer placer = new StripedReplicaPlacer(random); + assertEquals("All brokers are currently fenced.", + assertThrows(InvalidReplicationFactorException.class, + () -> placer.place(0, 1, (short) 1, Arrays.asList( + new UsableBroker(11, Optional.of("1"), true), + new UsableBroker(10, Optional.of("1"), true)).iterator())).getMessage()); + } + + @Test + public void testNotEnoughBrokers() { + MockRandom random = new MockRandom(); + StripedReplicaPlacer placer = new StripedReplicaPlacer(random); + assertEquals("The target replication factor of 3 cannot be reached because only " + + "2 broker(s) are registered.", + assertThrows(InvalidReplicationFactorException.class, + () -> placer.place(0, 1, (short) 3, Arrays.asList( + new UsableBroker(11, Optional.of("1"), false), + new UsableBroker(10, Optional.of("1"), false)).iterator())).getMessage()); + 
}
+
+    @Test
+    public void testNonPositiveReplicationFactor() {
+        MockRandom random = new MockRandom();
+        StripedReplicaPlacer placer = new StripedReplicaPlacer(random);
+        assertEquals("Invalid replication factor 0: the replication factor must be positive.",
+            assertThrows(InvalidReplicationFactorException.class,
+                () -> placer.place(0, 1, (short) 0, Arrays.asList(
+                    new UsableBroker(11, Optional.of("1"), false),
+                    new UsableBroker(10, Optional.of("1"), false)).iterator())).getMessage());
+    }
+
+    @Test
+    public void testSuccessfulPlacement() {
+        MockRandom random = new MockRandom();
+        StripedReplicaPlacer placer = new StripedReplicaPlacer(random);
+        assertEquals(Arrays.asList(Arrays.asList(2, 3, 0),
+                Arrays.asList(3, 0, 1),
+                Arrays.asList(0, 1, 2),
+                Arrays.asList(1, 2, 3),
+                Arrays.asList(1, 0, 2)),
+            placer.place(0, 5, (short) 3, Arrays.asList(
+                new UsableBroker(0, Optional.empty(), false),
+                new UsableBroker(3, Optional.empty(), false),
+                new UsableBroker(2, Optional.empty(), false),
+                new UsableBroker(1, Optional.empty(), false)).iterator()));
+    }
+
+    @Test
+    public void testEvenDistribution() {
+        MockRandom random = new MockRandom();
+        StripedReplicaPlacer placer = new StripedReplicaPlacer(random);
+        List<List<Integer>> replicas = placer.place(0, 200, (short) 2, Arrays.asList(
+            new UsableBroker(0, Optional.empty(), false),
+            new UsableBroker(1, Optional.empty(), false),
+            new UsableBroker(2, Optional.empty(), false),
+            new UsableBroker(3, Optional.empty(), false)).iterator());
+        Map<List<Integer>, Integer> counts = new HashMap<>();
+        for (List<Integer> partitionReplicas : replicas) {
+            counts.put(partitionReplicas, counts.getOrDefault(partitionReplicas, 0) + 1);
+        }
+        assertEquals(14, counts.get(Arrays.asList(0, 1)));
+        assertEquals(22, counts.get(Arrays.asList(0, 2)));
+        assertEquals(14, counts.get(Arrays.asList(0, 3)));
+        assertEquals(17, counts.get(Arrays.asList(1, 0)));
+        assertEquals(17, counts.get(Arrays.asList(1, 2)));
+        assertEquals(16, counts.get(Arrays.asList(1, 3)));
+        assertEquals(13, counts.get(Arrays.asList(2, 0)));
+        assertEquals(17, counts.get(Arrays.asList(2, 1)));
+        assertEquals(20, counts.get(Arrays.asList(2, 3)));
+        assertEquals(20, counts.get(Arrays.asList(3, 0)));
+        assertEquals(19, counts.get(Arrays.asList(3, 1)));
+        assertEquals(11, counts.get(Arrays.asList(3, 2)));
+    }
+
+    @Test
+    public void testRackListAllBrokersFenced() {
+        // Placement should fail with a clear error when every broker in the rack is fenced.
+        MockRandom random = new MockRandom();
+        RackList rackList = new RackList(random, Arrays.asList(
+            new UsableBroker(0, Optional.empty(), true),
+            new UsableBroker(1, Optional.empty(), true),
+            new UsableBroker(2, Optional.empty(), true)).iterator());
+        assertEquals(3, rackList.numTotalBrokers());
+        assertEquals(0, rackList.numUnfencedBrokers());
+        assertEquals(Collections.singletonList(Optional.empty()), rackList.rackNames());
+        assertEquals("All brokers are currently fenced.",
+            assertThrows(InvalidReplicationFactorException.class,
+                () -> rackList.place(3)).getMessage());
+    }
+
+    @Test
+    public void testRackListNotEnoughBrokers() {
+        MockRandom random = new MockRandom();
+        RackList rackList = new RackList(random, Arrays.asList(
+            new UsableBroker(11, Optional.of("1"), false),
+            new UsableBroker(10, Optional.of("1"), false)).iterator());
+        assertEquals("The target replication factor of 3 cannot be reached because only " +
+            "2 broker(s) are registered.",
+            assertThrows(InvalidReplicationFactorException.class,
+                () -> rackList.place(3)).getMessage());
+    }
+
+    @Test
+    public void
testRackListNonPositiveReplicationFactor() { + MockRandom random = new MockRandom(); + RackList rackList = new RackList(random, Arrays.asList( + new UsableBroker(11, Optional.of("1"), false), + new UsableBroker(10, Optional.of("1"), false)).iterator()); + assertEquals("Invalid replication factor -1: the replication factor must be positive.", + assertThrows(InvalidReplicationFactorException.class, + () -> rackList.place(-1)).getMessage()); + } +} diff --git a/metadata/src/test/java/org/apache/kafka/image/ClientQuotasImageTest.java b/metadata/src/test/java/org/apache/kafka/image/ClientQuotasImageTest.java new file mode 100644 index 0000000..aad3488 --- /dev/null +++ b/metadata/src/test/java/org/apache/kafka/image/ClientQuotasImageTest.java @@ -0,0 +1,117 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.image; + +import org.apache.kafka.common.config.internals.QuotaConfigs; +import org.apache.kafka.common.metadata.ClientQuotaRecord; +import org.apache.kafka.common.metadata.ClientQuotaRecord.EntityData; +import org.apache.kafka.common.quota.ClientQuotaEntity; +import org.apache.kafka.metadata.RecordTestUtils; +import org.apache.kafka.server.common.ApiMessageAndVersion; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.apache.kafka.common.metadata.MetadataRecordType.CLIENT_QUOTA_RECORD; +import static org.junit.jupiter.api.Assertions.assertEquals; + + +@Timeout(value = 40) +public class ClientQuotasImageTest { + final static ClientQuotasImage IMAGE1; + + final static List DELTA1_RECORDS; + + final static ClientQuotasDelta DELTA1; + + final static ClientQuotasImage IMAGE2; + + static { + Map entities1 = new HashMap<>(); + Map fooUser = new HashMap<>(); + fooUser.put(ClientQuotaEntity.USER, "foo"); + Map fooUserQuotas = new HashMap<>(); + fooUserQuotas.put(QuotaConfigs.PRODUCER_BYTE_RATE_OVERRIDE_CONFIG, 123.0); + entities1.put(new ClientQuotaEntity(fooUser), new ClientQuotaImage(fooUserQuotas)); + Map barUserAndIp = new HashMap<>(); + barUserAndIp.put(ClientQuotaEntity.USER, "bar"); + barUserAndIp.put(ClientQuotaEntity.IP, "127.0.0.1"); + Map barUserAndIpQuotas = new HashMap<>(); + barUserAndIpQuotas.put(QuotaConfigs.CONSUMER_BYTE_RATE_OVERRIDE_CONFIG, 456.0); + entities1.put(new ClientQuotaEntity(barUserAndIp), + new ClientQuotaImage(barUserAndIpQuotas)); + IMAGE1 = new ClientQuotasImage(entities1); + + DELTA1_RECORDS = new ArrayList<>(); + DELTA1_RECORDS.add(new ApiMessageAndVersion(new ClientQuotaRecord(). 
+ setEntity(Arrays.asList( + new EntityData().setEntityType(ClientQuotaEntity.USER).setEntityName("bar"), + new EntityData().setEntityType(ClientQuotaEntity.IP).setEntityName("127.0.0.1"))). + setKey(QuotaConfigs.CONSUMER_BYTE_RATE_OVERRIDE_CONFIG). + setRemove(true), CLIENT_QUOTA_RECORD.highestSupportedVersion())); + DELTA1_RECORDS.add(new ApiMessageAndVersion(new ClientQuotaRecord(). + setEntity(Arrays.asList( + new EntityData().setEntityType(ClientQuotaEntity.USER).setEntityName("foo"))). + setKey(QuotaConfigs.CONSUMER_BYTE_RATE_OVERRIDE_CONFIG). + setValue(999.0), CLIENT_QUOTA_RECORD.highestSupportedVersion())); + + DELTA1 = new ClientQuotasDelta(IMAGE1); + RecordTestUtils.replayAll(DELTA1, DELTA1_RECORDS); + + Map entities2 = new HashMap<>(); + Map fooUserQuotas2 = new HashMap<>(); + fooUserQuotas2.put(QuotaConfigs.PRODUCER_BYTE_RATE_OVERRIDE_CONFIG, 123.0); + fooUserQuotas2.put(QuotaConfigs.CONSUMER_BYTE_RATE_OVERRIDE_CONFIG, 999.0); + entities2.put(new ClientQuotaEntity(fooUser), new ClientQuotaImage(fooUserQuotas2)); + IMAGE2 = new ClientQuotasImage(entities2); + } + + @Test + public void testEmptyImageRoundTrip() throws Throwable { + testToImageAndBack(ClientQuotasImage.EMPTY); + } + + @Test + public void testImage1RoundTrip() throws Throwable { + testToImageAndBack(IMAGE1); + } + + @Test + public void testApplyDelta1() throws Throwable { + assertEquals(IMAGE2, DELTA1.apply()); + } + + @Test + public void testImage2RoundTrip() throws Throwable { + testToImageAndBack(IMAGE2); + } + + private void testToImageAndBack(ClientQuotasImage image) throws Throwable { + MockSnapshotConsumer writer = new MockSnapshotConsumer(); + image.write(writer); + ClientQuotasDelta delta = new ClientQuotasDelta(ClientQuotasImage.EMPTY); + RecordTestUtils.replayAllBatches(delta, writer.batches()); + ClientQuotasImage nextImage = delta.apply(); + assertEquals(image, nextImage); + } +} diff --git a/metadata/src/test/java/org/apache/kafka/image/ClusterImageTest.java b/metadata/src/test/java/org/apache/kafka/image/ClusterImageTest.java new file mode 100644 index 0000000..6908cf2 --- /dev/null +++ b/metadata/src/test/java/org/apache/kafka/image/ClusterImageTest.java @@ -0,0 +1,140 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.image; + +import org.apache.kafka.common.Endpoint; +import org.apache.kafka.common.Uuid; +import org.apache.kafka.common.metadata.FenceBrokerRecord; +import org.apache.kafka.common.metadata.UnfenceBrokerRecord; +import org.apache.kafka.common.metadata.UnregisterBrokerRecord; +import org.apache.kafka.common.security.auth.SecurityProtocol; +import org.apache.kafka.metadata.BrokerRegistration; +import org.apache.kafka.metadata.RecordTestUtils; +import org.apache.kafka.metadata.VersionRange; +import org.apache.kafka.server.common.ApiMessageAndVersion; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +import static org.apache.kafka.common.metadata.MetadataRecordType.FENCE_BROKER_RECORD; +import static org.apache.kafka.common.metadata.MetadataRecordType.UNFENCE_BROKER_RECORD; +import static org.apache.kafka.common.metadata.MetadataRecordType.UNREGISTER_BROKER_RECORD; +import static org.junit.jupiter.api.Assertions.assertEquals; + + +@Timeout(value = 40) +public class ClusterImageTest { + public final static ClusterImage IMAGE1; + + static final List DELTA1_RECORDS; + + final static ClusterDelta DELTA1; + + final static ClusterImage IMAGE2; + + static { + Map map1 = new HashMap<>(); + map1.put(0, new BrokerRegistration(0, + 1000, + Uuid.fromString("vZKYST0pSA2HO5x_6hoO2Q"), + Arrays.asList(new Endpoint("PLAINTEXT", SecurityProtocol.PLAINTEXT, "localhost", 9092)), + Collections.singletonMap("foo", new VersionRange((short) 1, (short) 3)), + Optional.empty(), + true)); + map1.put(1, new BrokerRegistration(1, + 1001, + Uuid.fromString("U52uRe20RsGI0RvpcTx33Q"), + Arrays.asList(new Endpoint("PLAINTEXT", SecurityProtocol.PLAINTEXT, "localhost", 9093)), + Collections.singletonMap("foo", new VersionRange((short) 1, (short) 3)), + Optional.empty(), + false)); + map1.put(2, new BrokerRegistration(2, + 123, + Uuid.fromString("hr4TVh3YQiu3p16Awkka6w"), + Arrays.asList(new Endpoint("PLAINTEXT", SecurityProtocol.PLAINTEXT, "localhost", 9093)), + Collections.emptyMap(), + Optional.of("arack"), + false)); + IMAGE1 = new ClusterImage(map1); + + DELTA1_RECORDS = new ArrayList<>(); + DELTA1_RECORDS.add(new ApiMessageAndVersion(new UnfenceBrokerRecord(). + setId(0).setEpoch(1000), UNFENCE_BROKER_RECORD.highestSupportedVersion())); + DELTA1_RECORDS.add(new ApiMessageAndVersion(new FenceBrokerRecord(). + setId(1).setEpoch(1001), FENCE_BROKER_RECORD.highestSupportedVersion())); + DELTA1_RECORDS.add(new ApiMessageAndVersion(new UnregisterBrokerRecord(). 
+ setBrokerId(2).setBrokerEpoch(123), + UNREGISTER_BROKER_RECORD.highestSupportedVersion())); + + DELTA1 = new ClusterDelta(IMAGE1); + RecordTestUtils.replayAll(DELTA1, DELTA1_RECORDS); + + Map map2 = new HashMap<>(); + map2.put(0, new BrokerRegistration(0, + 1000, + Uuid.fromString("vZKYST0pSA2HO5x_6hoO2Q"), + Arrays.asList(new Endpoint("PLAINTEXT", SecurityProtocol.PLAINTEXT, "localhost", 9092)), + Collections.singletonMap("foo", new VersionRange((short) 1, (short) 3)), + Optional.empty(), + false)); + map2.put(1, new BrokerRegistration(1, + 1001, + Uuid.fromString("U52uRe20RsGI0RvpcTx33Q"), + Arrays.asList(new Endpoint("PLAINTEXT", SecurityProtocol.PLAINTEXT, "localhost", 9093)), + Collections.singletonMap("foo", new VersionRange((short) 1, (short) 3)), + Optional.empty(), + true)); + IMAGE2 = new ClusterImage(map2); + } + + @Test + public void testEmptyImageRoundTrip() throws Throwable { + testToImageAndBack(ClusterImage.EMPTY); + } + + @Test + public void testImage1RoundTrip() throws Throwable { + testToImageAndBack(IMAGE1); + } + + @Test + public void testApplyDelta1() throws Throwable { + assertEquals(IMAGE2, DELTA1.apply()); + } + + @Test + public void testImage2RoundTrip() throws Throwable { + testToImageAndBack(IMAGE2); + } + + private void testToImageAndBack(ClusterImage image) throws Throwable { + MockSnapshotConsumer writer = new MockSnapshotConsumer(); + image.write(writer); + ClusterDelta delta = new ClusterDelta(ClusterImage.EMPTY); + RecordTestUtils.replayAllBatches(delta, writer.batches()); + ClusterImage nextImage = delta.apply(); + assertEquals(image, nextImage); + } +} diff --git a/metadata/src/test/java/org/apache/kafka/image/ConfigurationsImageTest.java b/metadata/src/test/java/org/apache/kafka/image/ConfigurationsImageTest.java new file mode 100644 index 0000000..77d5c60 --- /dev/null +++ b/metadata/src/test/java/org/apache/kafka/image/ConfigurationsImageTest.java @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.image; + +import org.apache.kafka.common.config.ConfigResource; +import org.apache.kafka.common.metadata.ConfigRecord; +import org.apache.kafka.metadata.RecordTestUtils; +import org.apache.kafka.server.common.ApiMessageAndVersion; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.apache.kafka.common.config.ConfigResource.Type.BROKER; +import static org.apache.kafka.common.metadata.MetadataRecordType.CONFIG_RECORD; +import static org.junit.jupiter.api.Assertions.assertEquals; + + +@Timeout(value = 40) +public class ConfigurationsImageTest { + final static ConfigurationsImage IMAGE1; + + final static List DELTA1_RECORDS; + + final static ConfigurationsDelta DELTA1; + + final static ConfigurationsImage IMAGE2; + + static { + Map map1 = new HashMap<>(); + Map broker0Map = new HashMap<>(); + broker0Map.put("foo", "bar"); + broker0Map.put("baz", "quux"); + map1.put(new ConfigResource(BROKER, "0"), + new ConfigurationImage(broker0Map)); + Map broker1Map = new HashMap<>(); + broker1Map.put("foobar", "foobaz"); + map1.put(new ConfigResource(BROKER, "1"), + new ConfigurationImage(broker1Map)); + IMAGE1 = new ConfigurationsImage(map1); + + DELTA1_RECORDS = new ArrayList<>(); + DELTA1_RECORDS.add(new ApiMessageAndVersion(new ConfigRecord().setResourceType(BROKER.id()). + setResourceName("0").setName("foo").setValue(null), + CONFIG_RECORD.highestSupportedVersion())); + DELTA1_RECORDS.add(new ApiMessageAndVersion(new ConfigRecord().setResourceType(BROKER.id()). + setResourceName("1").setName("barfoo").setValue("bazfoo"), + CONFIG_RECORD.highestSupportedVersion())); + + DELTA1 = new ConfigurationsDelta(IMAGE1); + RecordTestUtils.replayAll(DELTA1, DELTA1_RECORDS); + + Map map2 = new HashMap<>(); + Map broker0Map2 = new HashMap<>(); + broker0Map2.put("baz", "quux"); + map2.put(new ConfigResource(BROKER, "0"), + new ConfigurationImage(broker0Map2)); + Map broker1Map2 = new HashMap<>(); + broker1Map2.put("foobar", "foobaz"); + broker1Map2.put("barfoo", "bazfoo"); + map2.put(new ConfigResource(BROKER, "1"), + new ConfigurationImage(broker1Map2)); + IMAGE2 = new ConfigurationsImage(map2); + } + + @Test + public void testEmptyImageRoundTrip() throws Throwable { + testToImageAndBack(ConfigurationsImage.EMPTY); + } + + @Test + public void testImage1RoundTrip() throws Throwable { + testToImageAndBack(IMAGE1); + } + + @Test + public void testApplyDelta1() throws Throwable { + assertEquals(IMAGE2, DELTA1.apply()); + } + + @Test + public void testImage2RoundTrip() throws Throwable { + testToImageAndBack(IMAGE2); + } + + private void testToImageAndBack(ConfigurationsImage image) throws Throwable { + MockSnapshotConsumer writer = new MockSnapshotConsumer(); + image.write(writer); + ConfigurationsDelta delta = new ConfigurationsDelta(ConfigurationsImage.EMPTY); + RecordTestUtils.replayAllBatches(delta, writer.batches()); + ConfigurationsImage nextImage = delta.apply(); + assertEquals(image, nextImage); + } +} diff --git a/metadata/src/test/java/org/apache/kafka/image/FeaturesImageTest.java b/metadata/src/test/java/org/apache/kafka/image/FeaturesImageTest.java new file mode 100644 index 0000000..720086f --- /dev/null +++ b/metadata/src/test/java/org/apache/kafka/image/FeaturesImageTest.java @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.image; + +import org.apache.kafka.common.metadata.FeatureLevelRecord; +import org.apache.kafka.common.metadata.RemoveFeatureLevelRecord; +import org.apache.kafka.metadata.RecordTestUtils; +import org.apache.kafka.metadata.VersionRange; +import org.apache.kafka.server.common.ApiMessageAndVersion; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.apache.kafka.common.metadata.MetadataRecordType.FEATURE_LEVEL_RECORD; +import static org.apache.kafka.common.metadata.MetadataRecordType.REMOVE_FEATURE_LEVEL_RECORD; +import static org.junit.jupiter.api.Assertions.assertEquals; + + +@Timeout(value = 40) +public class FeaturesImageTest { + final static FeaturesImage IMAGE1; + final static List DELTA1_RECORDS; + final static FeaturesDelta DELTA1; + final static FeaturesImage IMAGE2; + + static { + Map map1 = new HashMap<>(); + map1.put("foo", new VersionRange((short) 1, (short) 2)); + map1.put("bar", new VersionRange((short) 1, (short) 1)); + map1.put("baz", new VersionRange((short) 1, (short) 8)); + IMAGE1 = new FeaturesImage(map1); + + DELTA1_RECORDS = new ArrayList<>(); + DELTA1_RECORDS.add(new ApiMessageAndVersion(new FeatureLevelRecord(). + setName("foo").setMinFeatureLevel((short) 1).setMaxFeatureLevel((short) 3), + FEATURE_LEVEL_RECORD.highestSupportedVersion())); + DELTA1_RECORDS.add(new ApiMessageAndVersion(new RemoveFeatureLevelRecord(). + setName("bar"), REMOVE_FEATURE_LEVEL_RECORD.highestSupportedVersion())); + DELTA1_RECORDS.add(new ApiMessageAndVersion(new RemoveFeatureLevelRecord(). 
+ setName("baz"), REMOVE_FEATURE_LEVEL_RECORD.highestSupportedVersion())); + + DELTA1 = new FeaturesDelta(IMAGE1); + RecordTestUtils.replayAll(DELTA1, DELTA1_RECORDS); + + Map map2 = new HashMap<>(); + map2.put("foo", new VersionRange((short) 1, (short) 3)); + IMAGE2 = new FeaturesImage(map2); + } + + @Test + public void testEmptyImageRoundTrip() throws Throwable { + testToImageAndBack(FeaturesImage.EMPTY); + } + + @Test + public void testImage1RoundTrip() throws Throwable { + testToImageAndBack(IMAGE1); + } + + @Test + public void testApplyDelta1() throws Throwable { + assertEquals(IMAGE2, DELTA1.apply()); + } + + @Test + public void testImage2RoundTrip() throws Throwable { + testToImageAndBack(IMAGE2); + } + + private void testToImageAndBack(FeaturesImage image) throws Throwable { + MockSnapshotConsumer writer = new MockSnapshotConsumer(); + image.write(writer); + FeaturesDelta delta = new FeaturesDelta(FeaturesImage.EMPTY); + RecordTestUtils.replayAllBatches(delta, writer.batches()); + FeaturesImage nextImage = delta.apply(); + assertEquals(image, nextImage); + } +} diff --git a/metadata/src/test/java/org/apache/kafka/image/MetadataImageTest.java b/metadata/src/test/java/org/apache/kafka/image/MetadataImageTest.java new file mode 100644 index 0000000..43709ba --- /dev/null +++ b/metadata/src/test/java/org/apache/kafka/image/MetadataImageTest.java @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.image; + +import org.apache.kafka.metadata.RecordTestUtils; +import org.apache.kafka.raft.OffsetAndEpoch; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; + +import static org.junit.jupiter.api.Assertions.assertEquals; + + +@Timeout(value = 40) +public class MetadataImageTest { + public final static MetadataImage IMAGE1; + + public final static MetadataDelta DELTA1; + + public final static MetadataImage IMAGE2; + + static { + IMAGE1 = new MetadataImage( + new OffsetAndEpoch(100, 4), + FeaturesImageTest.IMAGE1, + ClusterImageTest.IMAGE1, + TopicsImageTest.IMAGE1, + ConfigurationsImageTest.IMAGE1, + ClientQuotasImageTest.IMAGE1); + + DELTA1 = new MetadataDelta(IMAGE1); + RecordTestUtils.replayAll(DELTA1, 200, 5, FeaturesImageTest.DELTA1_RECORDS); + RecordTestUtils.replayAll(DELTA1, 200, 5, ClusterImageTest.DELTA1_RECORDS); + RecordTestUtils.replayAll(DELTA1, 200, 5, TopicsImageTest.DELTA1_RECORDS); + RecordTestUtils.replayAll(DELTA1, 200, 5, ConfigurationsImageTest.DELTA1_RECORDS); + RecordTestUtils.replayAll(DELTA1, 200, 5, ClientQuotasImageTest.DELTA1_RECORDS); + + IMAGE2 = new MetadataImage( + new OffsetAndEpoch(200, 5), + FeaturesImageTest.IMAGE2, + ClusterImageTest.IMAGE2, + TopicsImageTest.IMAGE2, + ConfigurationsImageTest.IMAGE2, + ClientQuotasImageTest.IMAGE2); + } + + @Test + public void testEmptyImageRoundTrip() throws Throwable { + testToImageAndBack(MetadataImage.EMPTY); + } + + @Test + public void testImage1RoundTrip() throws Throwable { + testToImageAndBack(IMAGE1); + } + + @Test + public void testApplyDelta1() throws Throwable { + assertEquals(IMAGE2, DELTA1.apply()); + } + + @Test + public void testImage2RoundTrip() throws Throwable { + testToImageAndBack(IMAGE2); + } + + private void testToImageAndBack(MetadataImage image) throws Throwable { + MockSnapshotConsumer writer = new MockSnapshotConsumer(); + image.write(writer); + MetadataDelta delta = new MetadataDelta(MetadataImage.EMPTY); + RecordTestUtils.replayAllBatches( + delta, image.highestOffsetAndEpoch().offset, image.highestOffsetAndEpoch().epoch, writer.batches()); + MetadataImage nextImage = delta.apply(); + assertEquals(image, nextImage); + } +} diff --git a/metadata/src/test/java/org/apache/kafka/image/MockSnapshotConsumer.java b/metadata/src/test/java/org/apache/kafka/image/MockSnapshotConsumer.java new file mode 100644 index 0000000..90b10f2 --- /dev/null +++ b/metadata/src/test/java/org/apache/kafka/image/MockSnapshotConsumer.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.kafka.image;
+
+import org.apache.kafka.server.common.ApiMessageAndVersion;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.function.Consumer;
+
+
+public class MockSnapshotConsumer implements Consumer<List<ApiMessageAndVersion>> {
+    private final List<List<ApiMessageAndVersion>> batches = new ArrayList<>();
+
+    @Override
+    public void accept(List<ApiMessageAndVersion> batch) {
+        batches.add(batch);
+    }
+
+    public List<List<ApiMessageAndVersion>> batches() {
+        return batches;
+    }
+}
diff --git a/metadata/src/test/java/org/apache/kafka/image/TopicsImageTest.java b/metadata/src/test/java/org/apache/kafka/image/TopicsImageTest.java
new file mode 100644
index 0000000..e417fb2
--- /dev/null
+++ b/metadata/src/test/java/org/apache/kafka/image/TopicsImageTest.java
@@ -0,0 +1,423 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kafka.image;
+
+import org.apache.kafka.common.TopicPartition;
+import org.apache.kafka.common.Uuid;
+import org.apache.kafka.common.metadata.PartitionChangeRecord;
+import org.apache.kafka.common.metadata.PartitionRecord;
+import org.apache.kafka.common.metadata.RemoveTopicRecord;
+import org.apache.kafka.common.metadata.TopicRecord;
+import org.apache.kafka.metadata.PartitionRegistration;
+import org.apache.kafka.metadata.RecordTestUtils;
+import org.apache.kafka.metadata.Replicas;
+import org.apache.kafka.server.common.ApiMessageAndVersion;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.Timeout;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+
+import static org.apache.kafka.common.metadata.MetadataRecordType.PARTITION_CHANGE_RECORD;
+import static org.apache.kafka.common.metadata.MetadataRecordType.PARTITION_RECORD;
+import static org.apache.kafka.common.metadata.MetadataRecordType.REMOVE_TOPIC_RECORD;
+import static org.apache.kafka.common.metadata.MetadataRecordType.TOPIC_RECORD;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertFalse;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+
+@Timeout(value = 40)
+public class TopicsImageTest {
+    static final TopicsImage IMAGE1;
+
+    static final List<ApiMessageAndVersion> DELTA1_RECORDS;
+
+    static final TopicsDelta DELTA1;
+
+    static final TopicsImage IMAGE2;
+
+    static final List<TopicImage> TOPIC_IMAGES1;
+
+    private static TopicImage newTopicImage(String name, Uuid id, PartitionRegistration...
partitions) { + Map partitionMap = new HashMap<>(); + int i = 0; + for (PartitionRegistration partition : partitions) { + partitionMap.put(i++, partition); + } + return new TopicImage(name, id, partitionMap); + } + + private static Map newTopicsByIdMap(Collection topics) { + Map map = new HashMap<>(); + for (TopicImage topic : topics) { + map.put(topic.id(), topic); + } + return map; + } + + private static Map newTopicsByNameMap(Collection topics) { + Map map = new HashMap<>(); + for (TopicImage topic : topics) { + map.put(topic.name(), topic); + } + return map; + } + + private static final Uuid FOO_UUID = Uuid.fromString("ThIaNwRnSM2Nt9Mx1v0RvA"); + + private static final Uuid BAR_UUID = Uuid.fromString("f62ptyETTjet8SL5ZeREiw"); + + private static final Uuid BAZ_UUID = Uuid.fromString("tgHBnRglT5W_RlENnuG5vg"); + + static { + TOPIC_IMAGES1 = Arrays.asList( + newTopicImage("foo", FOO_UUID, + new PartitionRegistration(new int[] {2, 3, 4}, + new int[] {2, 3}, Replicas.NONE, Replicas.NONE, 2, 1, 345), + new PartitionRegistration(new int[] {3, 4, 5}, + new int[] {3, 4, 5}, Replicas.NONE, Replicas.NONE, 3, 4, 684), + new PartitionRegistration(new int[] {2, 4, 5}, + new int[] {2, 4, 5}, Replicas.NONE, Replicas.NONE, 2, 10, 84)), + newTopicImage("bar", BAR_UUID, + new PartitionRegistration(new int[] {0, 1, 2, 3, 4}, + new int[] {0, 1, 2, 3}, new int[] {1}, new int[] {3, 4}, 0, 1, 345))); + + IMAGE1 = new TopicsImage(newTopicsByIdMap(TOPIC_IMAGES1), newTopicsByNameMap(TOPIC_IMAGES1)); + + DELTA1_RECORDS = new ArrayList<>(); + DELTA1_RECORDS.add(new ApiMessageAndVersion(new RemoveTopicRecord(). + setTopicId(FOO_UUID), + REMOVE_TOPIC_RECORD.highestSupportedVersion())); + DELTA1_RECORDS.add(new ApiMessageAndVersion(new PartitionChangeRecord(). + setTopicId(BAR_UUID). + setPartitionId(0).setLeader(1), + PARTITION_CHANGE_RECORD.highestSupportedVersion())); + DELTA1_RECORDS.add(new ApiMessageAndVersion(new TopicRecord(). + setName("baz").setTopicId(BAZ_UUID), + TOPIC_RECORD.highestSupportedVersion())); + DELTA1_RECORDS.add(new ApiMessageAndVersion(new PartitionRecord(). + setPartitionId(0). + setTopicId(BAZ_UUID). + setReplicas(Arrays.asList(1, 2, 3, 4)). + setIsr(Arrays.asList(3, 4)). + setRemovingReplicas(Collections.singletonList(2)). + setAddingReplicas(Collections.singletonList(1)). + setLeader(3). + setLeaderEpoch(2). 
+ setPartitionEpoch(1), PARTITION_RECORD.highestSupportedVersion())); + + DELTA1 = new TopicsDelta(IMAGE1); + RecordTestUtils.replayAll(DELTA1, DELTA1_RECORDS); + + List topics2 = Arrays.asList( + newTopicImage("bar", BAR_UUID, + new PartitionRegistration(new int[] {0, 1, 2, 3, 4}, + new int[] {0, 1, 2, 3}, new int[] {1}, new int[] {3, 4}, 1, 2, 346)), + newTopicImage("baz", BAZ_UUID, + new PartitionRegistration(new int[] {1, 2, 3, 4}, + new int[] {3, 4}, new int[] {2}, new int[] {1}, 3, 2, 1))); + IMAGE2 = new TopicsImage(newTopicsByIdMap(topics2), newTopicsByNameMap(topics2)); + } + + private ApiMessageAndVersion newPartitionRecord(Uuid topicId, int partitionId, List replicas) { + return new ApiMessageAndVersion( + new PartitionRecord() + .setPartitionId(partitionId) + .setTopicId(topicId) + .setReplicas(replicas) + .setIsr(replicas) + .setLeader(replicas.get(0)) + .setLeaderEpoch(1) + .setPartitionEpoch(1), + PARTITION_RECORD.highestSupportedVersion() + ); + } + + private PartitionRegistration newPartition(int[] replicas) { + return new PartitionRegistration(replicas, replicas, Replicas.NONE, Replicas.NONE, replicas[0], 1, 1); + } + + @Test + public void testBasicLocalChanges() { + int localId = 3; + /* Changes already include in DELTA1_RECORDS and IMAGE1: + * foo - topic id deleted + * bar-0 - stay as follower with different partition epoch + * baz-0 - new topic to leader + */ + List topicRecords = new ArrayList<>(DELTA1_RECORDS); + + // Create a new foo topic with a different id + Uuid newFooId = Uuid.fromString("b66ybsWIQoygs01vdjH07A"); + topicRecords.add( + new ApiMessageAndVersion( + new TopicRecord().setName("foo") .setTopicId(newFooId), + TOPIC_RECORD.highestSupportedVersion() + ) + ); + topicRecords.add(newPartitionRecord(newFooId, 0, Arrays.asList(0, 1, 2))); + topicRecords.add(newPartitionRecord(newFooId, 1, Arrays.asList(0, 1, localId))); + + // baz-1 - new partition to follower + topicRecords.add( + new ApiMessageAndVersion( + new PartitionRecord() + .setPartitionId(1) + .setTopicId(BAZ_UUID) + .setReplicas(Arrays.asList(4, 2, localId)) + .setIsr(Arrays.asList(4, 2, localId)) + .setLeader(4) + .setLeaderEpoch(2) + .setPartitionEpoch(1), + PARTITION_RECORD.highestSupportedVersion() + ) + ); + + TopicsDelta delta = new TopicsDelta(IMAGE1); + RecordTestUtils.replayAll(delta, topicRecords); + + LocalReplicaChanges changes = delta.localChanges(localId); + assertEquals( + new HashSet<>(Arrays.asList(new TopicPartition("foo", 0), new TopicPartition("foo", 1))), + changes.deletes() + ); + assertEquals( + new HashSet<>(Arrays.asList(new TopicPartition("baz", 0))), + changes.leaders().keySet() + ); + assertEquals( + new HashSet<>( + Arrays.asList(new TopicPartition("baz", 1), new TopicPartition("bar", 0), new TopicPartition("foo", 1)) + ), + changes.followers().keySet() + ); + } + + @Test + public void testDeleteAfterChanges() { + int localId = 3; + Uuid zooId = Uuid.fromString("0hHJ3X5ZQ-CFfQ5xgpj90w"); + + List topics = new ArrayList<>(); + topics.add( + newTopicImage( + "zoo", + zooId, + newPartition(new int[] {localId, 1, 2}) + ) + ); + TopicsImage image = new TopicsImage(newTopicsByIdMap(topics), newTopicsByNameMap(topics)); + + List topicRecords = new ArrayList<>(); + // leader to follower + topicRecords.add( + new ApiMessageAndVersion( + new PartitionChangeRecord().setTopicId(zooId).setPartitionId(0).setLeader(1), + PARTITION_CHANGE_RECORD.highestSupportedVersion() + ) + ); + // remove zoo topic + topicRecords.add( + new ApiMessageAndVersion( + new 
RemoveTopicRecord().setTopicId(zooId), + REMOVE_TOPIC_RECORD.highestSupportedVersion() + ) + ); + + TopicsDelta delta = new TopicsDelta(image); + RecordTestUtils.replayAll(delta, topicRecords); + + LocalReplicaChanges changes = delta.localChanges(localId); + assertEquals(new HashSet<>(Arrays.asList(new TopicPartition("zoo", 0))), changes.deletes()); + assertEquals(Collections.emptyMap(), changes.leaders()); + assertEquals(Collections.emptyMap(), changes.followers()); + } + + @Test + public void testLocalReassignmentChanges() { + int localId = 3; + Uuid zooId = Uuid.fromString("0hHJ3X5ZQ-CFfQ5xgpj90w"); + + List topics = new ArrayList<>(); + topics.add( + newTopicImage( + "zoo", + zooId, + newPartition(new int[] {0, 1, localId}), + newPartition(new int[] {localId, 1, 2}), + newPartition(new int[] {0, 1, localId}), + newPartition(new int[] {localId, 1, 2}), + newPartition(new int[] {0, 1, 2}), + newPartition(new int[] {0, 1, 2}) + ) + ); + TopicsImage image = new TopicsImage(newTopicsByIdMap(topics), newTopicsByNameMap(topics)); + + List topicRecords = new ArrayList<>(); + // zoo-0 - follower to leader + topicRecords.add( + new ApiMessageAndVersion( + new PartitionChangeRecord().setTopicId(zooId).setPartitionId(0).setLeader(localId), + PARTITION_CHANGE_RECORD.highestSupportedVersion() + ) + ); + // zoo-1 - leader to follower + topicRecords.add( + new ApiMessageAndVersion( + new PartitionChangeRecord().setTopicId(zooId).setPartitionId(1).setLeader(1), + PARTITION_CHANGE_RECORD.highestSupportedVersion() + ) + ); + // zoo-2 - follower to removed + topicRecords.add( + new ApiMessageAndVersion( + new PartitionChangeRecord() + .setTopicId(zooId) + .setPartitionId(2) + .setIsr(Arrays.asList(0, 1, 2)) + .setReplicas(Arrays.asList(0, 1, 2)), + PARTITION_CHANGE_RECORD.highestSupportedVersion() + ) + ); + // zoo-3 - leader to removed + topicRecords.add( + new ApiMessageAndVersion( + new PartitionChangeRecord() + .setTopicId(zooId) + .setPartitionId(3) + .setLeader(0) + .setIsr(Arrays.asList(0, 1, 2)) + .setReplicas(Arrays.asList(0, 1, 2)), + PARTITION_CHANGE_RECORD.highestSupportedVersion() + ) + ); + // zoo-4 - not replica to leader + topicRecords.add( + new ApiMessageAndVersion( + new PartitionChangeRecord() + .setTopicId(zooId) + .setPartitionId(4) + .setLeader(localId) + .setIsr(Arrays.asList(localId, 1, 2)) + .setReplicas(Arrays.asList(localId, 1, 2)), + PARTITION_CHANGE_RECORD.highestSupportedVersion() + ) + ); + // zoo-5 - not replica to follower + topicRecords.add( + new ApiMessageAndVersion( + new PartitionChangeRecord() + .setTopicId(zooId) + .setPartitionId(5) + .setIsr(Arrays.asList(0, 1, localId)) + .setReplicas(Arrays.asList(0, 1, localId)), + PARTITION_CHANGE_RECORD.highestSupportedVersion() + ) + ); + + TopicsDelta delta = new TopicsDelta(image); + RecordTestUtils.replayAll(delta, topicRecords); + + LocalReplicaChanges changes = delta.localChanges(localId); + assertEquals( + new HashSet<>(Arrays.asList(new TopicPartition("zoo", 2), new TopicPartition("zoo", 3))), + changes.deletes() + ); + assertEquals( + new HashSet<>(Arrays.asList(new TopicPartition("zoo", 0), new TopicPartition("zoo", 4))), + changes.leaders().keySet() + ); + assertEquals( + new HashSet<>(Arrays.asList(new TopicPartition("zoo", 1), new TopicPartition("zoo", 5))), + changes.followers().keySet() + ); + } + + @Test + public void testEmptyImageRoundTrip() throws Throwable { + testToImageAndBack(TopicsImage.EMPTY); + } + + @Test + public void testImage1RoundTrip() throws Throwable { + testToImageAndBack(IMAGE1); + } + + 
@Test + public void testApplyDelta1() throws Throwable { + assertEquals(IMAGE2, DELTA1.apply()); + } + + @Test + public void testImage2RoundTrip() throws Throwable { + testToImageAndBack(IMAGE2); + } + + private void testToImageAndBack(TopicsImage image) throws Throwable { + MockSnapshotConsumer writer = new MockSnapshotConsumer(); + image.write(writer); + TopicsDelta delta = new TopicsDelta(TopicsImage.EMPTY); + RecordTestUtils.replayAllBatches(delta, writer.batches()); + TopicsImage nextImage = delta.apply(); + assertEquals(image, nextImage); + } + + @Test + public void testTopicNameToIdView() { + Map map = IMAGE1.topicNameToIdView(); + assertTrue(map.containsKey("foo")); + assertEquals(FOO_UUID, map.get("foo")); + assertTrue(map.containsKey("bar")); + assertEquals(BAR_UUID, map.get("bar")); + assertFalse(map.containsKey("baz")); + assertEquals(null, map.get("baz")); + HashSet uuids = new HashSet<>(); + map.values().iterator().forEachRemaining(u -> uuids.add(u)); + HashSet expectedUuids = new HashSet<>(Arrays.asList( + Uuid.fromString("ThIaNwRnSM2Nt9Mx1v0RvA"), + Uuid.fromString("f62ptyETTjet8SL5ZeREiw"))); + assertEquals(expectedUuids, uuids); + assertThrows(UnsupportedOperationException.class, () -> map.remove("foo")); + assertThrows(UnsupportedOperationException.class, () -> map.put("bar", FOO_UUID)); + } + + @Test + public void testTopicIdToNameView() { + Map map = IMAGE1.topicIdToNameView(); + assertTrue(map.containsKey(FOO_UUID)); + assertEquals("foo", map.get(FOO_UUID)); + assertTrue(map.containsKey(BAR_UUID)); + assertEquals("bar", map.get(BAR_UUID)); + assertFalse(map.containsKey(BAZ_UUID)); + assertEquals(null, map.get(BAZ_UUID)); + HashSet names = new HashSet<>(); + map.values().iterator().forEachRemaining(n -> names.add(n)); + HashSet expectedNames = new HashSet<>(Arrays.asList("foo", "bar")); + assertEquals(expectedNames, names); + assertThrows(UnsupportedOperationException.class, () -> map.remove(FOO_UUID)); + assertThrows(UnsupportedOperationException.class, () -> map.put(FOO_UUID, "bar")); + } +} diff --git a/metadata/src/test/java/org/apache/kafka/metadata/BrokerRegistrationTest.java b/metadata/src/test/java/org/apache/kafka/metadata/BrokerRegistrationTest.java new file mode 100644 index 0000000..0f350c4 --- /dev/null +++ b/metadata/src/test/java/org/apache/kafka/metadata/BrokerRegistrationTest.java @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.metadata; + +import org.apache.kafka.common.Endpoint; +import org.apache.kafka.common.Node; +import org.apache.kafka.common.Uuid; +import org.apache.kafka.common.metadata.RegisterBrokerRecord; +import org.apache.kafka.common.security.auth.SecurityProtocol; +import org.apache.kafka.server.common.ApiMessageAndVersion; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Optional; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +@Timeout(value = 40) +public class BrokerRegistrationTest { + private static final List REGISTRATIONS = Arrays.asList( + new BrokerRegistration(0, 0, Uuid.fromString("pc1GhUlBS92cGGaKXl6ipw"), + Arrays.asList(new Endpoint("INTERNAL", SecurityProtocol.PLAINTEXT, "localhost", 9090)), + Collections.singletonMap("foo", new VersionRange((short) 1, (short) 2)), + Optional.empty(), false), + new BrokerRegistration(1, 0, Uuid.fromString("3MfdxWlNSn2UDYsmDP1pYg"), + Arrays.asList(new Endpoint("INTERNAL", SecurityProtocol.PLAINTEXT, "localhost", 9091)), + Collections.singletonMap("foo", new VersionRange((short) 1, (short) 2)), + Optional.empty(), false), + new BrokerRegistration(2, 0, Uuid.fromString("eY7oaG1RREie5Kk9uy1l6g"), + Arrays.asList(new Endpoint("INTERNAL", SecurityProtocol.PLAINTEXT, "localhost", 9092)), + Collections.singletonMap("foo", new VersionRange((short) 2, (short) 3)), + Optional.of("myrack"), false)); + + @Test + public void testValues() { + assertEquals(0, REGISTRATIONS.get(0).id()); + assertEquals(1, REGISTRATIONS.get(1).id()); + assertEquals(2, REGISTRATIONS.get(2).id()); + } + + @Test + public void testEquals() { + assertFalse(REGISTRATIONS.get(0).equals(REGISTRATIONS.get(1))); + assertFalse(REGISTRATIONS.get(1).equals(REGISTRATIONS.get(0))); + assertFalse(REGISTRATIONS.get(0).equals(REGISTRATIONS.get(2))); + assertFalse(REGISTRATIONS.get(2).equals(REGISTRATIONS.get(0))); + assertTrue(REGISTRATIONS.get(0).equals(REGISTRATIONS.get(0))); + assertTrue(REGISTRATIONS.get(1).equals(REGISTRATIONS.get(1))); + assertTrue(REGISTRATIONS.get(2).equals(REGISTRATIONS.get(2))); + } + + @Test + public void testToString() { + assertEquals("BrokerRegistration(id=1, epoch=0, " + + "incarnationId=3MfdxWlNSn2UDYsmDP1pYg, listeners=[Endpoint(" + + "listenerName='INTERNAL', securityProtocol=PLAINTEXT, " + + "host='localhost', port=9091)], supportedFeatures={foo: 1-2}, " + + "rack=Optional.empty, fenced=false)", + REGISTRATIONS.get(1).toString()); + } + + @Test + public void testFromRecordAndToRecord() { + testRoundTrip(REGISTRATIONS.get(0)); + testRoundTrip(REGISTRATIONS.get(1)); + testRoundTrip(REGISTRATIONS.get(2)); + } + + private void testRoundTrip(BrokerRegistration registration) { + ApiMessageAndVersion messageAndVersion = registration.toRecord(); + BrokerRegistration registration2 = BrokerRegistration.fromRecord( + (RegisterBrokerRecord) messageAndVersion.message()); + assertEquals(registration, registration2); + ApiMessageAndVersion messageAndVersion2 = registration2.toRecord(); + assertEquals(messageAndVersion, messageAndVersion2); + } + + @Test + public void testToNode() { + assertEquals(Optional.empty(), REGISTRATIONS.get(0).node("NONEXISTENT")); + assertEquals(Optional.of(new Node(0, "localhost", 9090, null)), + REGISTRATIONS.get(0).node("INTERNAL")); + 
assertEquals(Optional.of(new Node(1, "localhost", 9091, null)), + REGISTRATIONS.get(1).node("INTERNAL")); + assertEquals(Optional.of(new Node(2, "localhost", 9092, "myrack")), + REGISTRATIONS.get(2).node("INTERNAL")); + } +} diff --git a/metadata/src/test/java/org/apache/kafka/metadata/BrokerStateTest.java b/metadata/src/test/java/org/apache/kafka/metadata/BrokerStateTest.java new file mode 100644 index 0000000..d590f01 --- /dev/null +++ b/metadata/src/test/java/org/apache/kafka/metadata/BrokerStateTest.java @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.metadata; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +@Timeout(value = 40) +public class BrokerStateTest { + private static final Logger log = LoggerFactory.getLogger(BrokerStateTest.class); + + @Test + public void testFromValue() { + for (BrokerState state : BrokerState.values()) { + BrokerState state2 = BrokerState.fromValue(state.value()); + assertEquals(state, state2); + } + } + + @Test + public void testUnknownValues() { + assertEquals(BrokerState.UNKNOWN, BrokerState.fromValue((byte) 126)); + } +} diff --git a/metadata/src/test/java/org/apache/kafka/metadata/MetadataRecordSerdeTest.java b/metadata/src/test/java/org/apache/kafka/metadata/MetadataRecordSerdeTest.java new file mode 100644 index 0000000..cbcbe85 --- /dev/null +++ b/metadata/src/test/java/org/apache/kafka/metadata/MetadataRecordSerdeTest.java @@ -0,0 +1,224 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.metadata; + +import org.apache.kafka.common.Uuid; +import org.apache.kafka.common.metadata.RegisterBrokerRecord; +import org.apache.kafka.common.metadata.TopicRecord; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.protocol.ObjectSerializationCache; +import org.apache.kafka.common.utils.ByteUtils; +import org.apache.kafka.server.common.ApiMessageAndVersion; +import org.apache.kafka.server.common.serialization.MetadataParseException; +import org.junit.jupiter.api.Test; + +import java.nio.ByteBuffer; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +class MetadataRecordSerdeTest { + + @Test + public void testSerde() { + TopicRecord topicRecord = new TopicRecord() + .setName("foo") + .setTopicId(Uuid.randomUuid()); + + MetadataRecordSerde serde = new MetadataRecordSerde(); + + for (short version = TopicRecord.LOWEST_SUPPORTED_VERSION; version <= TopicRecord.HIGHEST_SUPPORTED_VERSION; version++) { + ApiMessageAndVersion messageAndVersion = new ApiMessageAndVersion(topicRecord, version); + + ObjectSerializationCache cache = new ObjectSerializationCache(); + int size = serde.recordSize(messageAndVersion, cache); + + ByteBuffer buffer = ByteBuffer.allocate(size); + ByteBufferAccessor bufferAccessor = new ByteBufferAccessor(buffer); + + serde.write(messageAndVersion, cache, bufferAccessor); + buffer.flip(); + + assertEquals(size, buffer.remaining()); + ApiMessageAndVersion readMessageAndVersion = serde.read(bufferAccessor, size); + assertEquals(messageAndVersion, readMessageAndVersion); + } + } + + @Test + public void testDeserializeWithUnhandledFrameVersion() { + ByteBuffer buffer = ByteBuffer.allocate(16); + ByteUtils.writeUnsignedVarint(15, buffer); + buffer.flip(); + + MetadataRecordSerde serde = new MetadataRecordSerde(); + assertStartsWith("Could not deserialize metadata record due to unknown frame version", + assertThrows(MetadataParseException.class, + () -> serde.read(new ByteBufferAccessor(buffer), 16)).getMessage()); + } + + /** + * Test attempting to parse an event which has a malformed frame version type varint. + */ + @Test + public void testParsingMalformedFrameVersionVarint() { + MetadataRecordSerde serde = new MetadataRecordSerde(); + ByteBuffer buffer = ByteBuffer.allocate(64); + buffer.clear(); + buffer.put((byte) 0x80); + buffer.put((byte) 0x80); + buffer.put((byte) 0x80); + buffer.put((byte) 0x80); + buffer.put((byte) 0x80); + buffer.put((byte) 0x80); + buffer.position(0); + buffer.limit(64); + assertStartsWith("Error while reading frame version", + assertThrows(MetadataParseException.class, + () -> serde.read(new ByteBufferAccessor(buffer), buffer.remaining())).getMessage()); + } + + /** + * Test attempting to parse an event which has a malformed message type varint. 
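+     * <p>
+     * The buffer written below starts with a valid frame version byte (0x01) and is then filled
+     * with 0x80 bytes. Every 0x80 byte has the varint continuation bit set, so the message type
+     * varint never terminates and the serde is expected to fail while reading the type.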
+ */ + @Test + public void testParsingMalformedMessageTypeVarint() { + MetadataRecordSerde serde = new MetadataRecordSerde(); + ByteBuffer buffer = ByteBuffer.allocate(64); + buffer.clear(); + buffer.put((byte) 0x01); + buffer.put((byte) 0x80); + buffer.put((byte) 0x80); + buffer.put((byte) 0x80); + buffer.put((byte) 0x80); + buffer.put((byte) 0x80); + buffer.put((byte) 0x80); + buffer.position(0); + buffer.limit(64); + assertStartsWith("Error while reading type", + assertThrows(MetadataParseException.class, + () -> serde.read(new ByteBufferAccessor(buffer), buffer.remaining())).getMessage()); + } + + /** + * Test attempting to parse an event which has a malformed message version varint. + */ + @Test + public void testParsingMalformedMessageVersionVarint() { + MetadataRecordSerde serde = new MetadataRecordSerde(); + ByteBuffer buffer = ByteBuffer.allocate(64); + buffer.clear(); + buffer.put((byte) 0x01); + buffer.put((byte) 0x08); + buffer.put((byte) 0x80); + buffer.put((byte) 0x80); + buffer.put((byte) 0x80); + buffer.put((byte) 0x80); + buffer.put((byte) 0x80); + buffer.position(0); + buffer.limit(64); + assertStartsWith("Error while reading version", + assertThrows(MetadataParseException.class, + () -> serde.read(new ByteBufferAccessor(buffer), buffer.remaining())).getMessage()); + } + + /** + * Test attempting to parse an event which has a version > Short.MAX_VALUE + */ + @Test + public void testParsingVersionTooLarge() { + MetadataRecordSerde serde = new MetadataRecordSerde(); + ByteBuffer buffer = ByteBuffer.allocate(64); + buffer.clear(); + buffer.put((byte) 0x01); // frame version + buffer.put((byte) 0x08); // apiKey + buffer.put((byte) 0xff); // api version + buffer.put((byte) 0xff); // api version + buffer.put((byte) 0xff); // api version + buffer.put((byte) 0x7f); // api version end + buffer.put((byte) 0x80); + buffer.position(0); + buffer.limit(64); + assertStartsWith("Value for version was too large", + assertThrows(MetadataParseException.class, + () -> serde.read(new ByteBufferAccessor(buffer), buffer.remaining())).getMessage()); + } + + /** + * Test attempting to parse an event which has a unsupported version + */ + @Test + public void testParsingUnsupportedApiKey() { + MetadataRecordSerde serde = new MetadataRecordSerde(); + ByteBuffer buffer = ByteBuffer.allocate(64); + buffer.put((byte) 0x01); // frame version + buffer.put((byte) 0xff); // apiKey + buffer.put((byte) 0x7f); // apiKey + buffer.put((byte) 0x00); // api version + buffer.put((byte) 0x80); + buffer.position(0); + buffer.limit(64); + assertStartsWith("Unknown metadata id ", + assertThrows(MetadataParseException.class, + () -> serde.read(new ByteBufferAccessor(buffer), buffer.remaining())).getCause().getMessage()); + } + + /** + * Test attempting to parse an event which has a malformed message body. + */ + @Test + public void testParsingMalformedMessage() { + MetadataRecordSerde serde = new MetadataRecordSerde(); + ByteBuffer buffer = ByteBuffer.allocate(4); + buffer.put((byte) 0x01); // frame version + buffer.put((byte) 0x00); // apiKey + buffer.put((byte) 0x00); // apiVersion + buffer.put((byte) 0x80); // malformed data + buffer.position(0); + buffer.limit(4); + assertStartsWith("Failed to deserialize record with type", + assertThrows(MetadataParseException.class, + () -> serde.read(new ByteBufferAccessor(buffer), buffer.remaining())).getMessage()); + } + + /** + * Test attempting to parse an event which has a malformed message version varint. 
+ */ + @Test + public void testParsingRecordWithGarbageAtEnd() { + MetadataRecordSerde serde = new MetadataRecordSerde(); + RegisterBrokerRecord message = new RegisterBrokerRecord().setBrokerId(1).setBrokerEpoch(2); + + ObjectSerializationCache cache = new ObjectSerializationCache(); + ApiMessageAndVersion messageAndVersion = new ApiMessageAndVersion(message, (short) 0); + int size = serde.recordSize(messageAndVersion, cache); + ByteBuffer buffer = ByteBuffer.allocate(size + 1); + + serde.write(messageAndVersion, cache, new ByteBufferAccessor(buffer)); + buffer.clear(); + assertStartsWith("Found 1 byte(s) of garbage after", + assertThrows(MetadataParseException.class, + () -> serde.read(new ByteBufferAccessor(buffer), size + 1)).getMessage()); + } + + private static void assertStartsWith(String prefix, String str) { + assertTrue(str.startsWith(prefix), + "Expected string '" + str + "' to start with '" + prefix + "'"); + } + +} diff --git a/metadata/src/test/java/org/apache/kafka/metadata/OptionalStringComparatorTest.java b/metadata/src/test/java/org/apache/kafka/metadata/OptionalStringComparatorTest.java new file mode 100644 index 0000000..68f2544 --- /dev/null +++ b/metadata/src/test/java/org/apache/kafka/metadata/OptionalStringComparatorTest.java @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.metadata; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; + +import java.util.Optional; + +import static org.apache.kafka.metadata.OptionalStringComparator.INSTANCE; +import static org.junit.jupiter.api.Assertions.assertEquals; + + +@Timeout(value = 40) +public class OptionalStringComparatorTest { + @Test + public void testComparisons() { + assertEquals(0, INSTANCE.compare(Optional.of("foo"), Optional.of("foo"))); + assertEquals(-1, INSTANCE.compare(Optional.of("a"), Optional.of("b"))); + assertEquals(1, INSTANCE.compare(Optional.of("b"), Optional.of("a"))); + assertEquals(-1, INSTANCE.compare(Optional.empty(), Optional.of("a"))); + assertEquals(1, INSTANCE.compare(Optional.of("a"), Optional.empty())); + assertEquals(0, INSTANCE.compare(Optional.empty(), Optional.empty())); + } +} diff --git a/metadata/src/test/java/org/apache/kafka/metadata/PartitionRegistrationTest.java b/metadata/src/test/java/org/apache/kafka/metadata/PartitionRegistrationTest.java new file mode 100644 index 0000000..9b1be5d --- /dev/null +++ b/metadata/src/test/java/org/apache/kafka/metadata/PartitionRegistrationTest.java @@ -0,0 +1,128 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.metadata; + +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.Uuid; +import org.apache.kafka.common.message.LeaderAndIsrRequestData.LeaderAndIsrPartitionState; +import org.apache.kafka.common.metadata.PartitionChangeRecord; +import org.apache.kafka.common.metadata.PartitionRecord; +import org.apache.kafka.server.common.ApiMessageAndVersion; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; + +import java.util.Arrays; +import java.util.Collections; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + + +@Timeout(40) +public class PartitionRegistrationTest { + @Test + public void testElectionWasClean() { + assertTrue(PartitionRegistration.electionWasClean(1, new int[]{1, 2})); + assertFalse(PartitionRegistration.electionWasClean(1, new int[]{0, 2})); + assertFalse(PartitionRegistration.electionWasClean(1, new int[]{})); + assertTrue(PartitionRegistration.electionWasClean(3, new int[]{1, 2, 3, 4, 5, 6})); + } + + @Test + public void testPartitionControlInfoMergeAndDiff() { + PartitionRegistration a = new PartitionRegistration( + new int[]{1, 2, 3}, new int[]{1, 2}, Replicas.NONE, Replicas.NONE, 1, 0, 0); + PartitionRegistration b = new PartitionRegistration( + new int[]{1, 2, 3}, new int[]{3}, Replicas.NONE, Replicas.NONE, 3, 1, 1); + PartitionRegistration c = new PartitionRegistration( + new int[]{1, 2, 3}, new int[]{1}, Replicas.NONE, Replicas.NONE, 1, 0, 1); + assertEquals(b, a.merge(new PartitionChangeRecord(). + setLeader(3).setIsr(Arrays.asList(3)))); + assertEquals("isr: [1, 2] -> [3], leader: 1 -> 3, leaderEpoch: 0 -> 1, partitionEpoch: 0 -> 1", + b.diff(a)); + assertEquals("isr: [1, 2] -> [1], partitionEpoch: 0 -> 1", + c.diff(a)); + } + + @Test + public void testRecordRoundTrip() { + PartitionRegistration registrationA = new PartitionRegistration( + new int[]{1, 2, 3}, new int[]{1, 2}, new int[]{1}, Replicas.NONE, 1, 0, 0); + Uuid topicId = Uuid.fromString("OGdAI5nxT_m-ds3rJMqPLA"); + int partitionId = 4; + ApiMessageAndVersion record = registrationA.toRecord(topicId, partitionId); + PartitionRegistration registrationB = + new PartitionRegistration((PartitionRecord) record.message()); + assertEquals(registrationA, registrationB); + } + + @Test + public void testToLeaderAndIsrPartitionState() { + PartitionRegistration a = new PartitionRegistration( + new int[]{1, 2, 3}, new int[]{1, 2}, Replicas.NONE, Replicas.NONE, 1, 123, 456); + PartitionRegistration b = new PartitionRegistration( + new int[]{2, 3, 4}, new int[]{2, 3, 4}, Replicas.NONE, Replicas.NONE, 2, 234, 567); + assertEquals(new LeaderAndIsrPartitionState(). + setTopicName("foo"). + setPartitionIndex(1). + setControllerEpoch(-1). + setLeader(1). + setLeaderEpoch(123). + setIsr(Arrays.asList(1, 2)). + setZkVersion(456). 
+ setReplicas(Arrays.asList(1, 2, 3)). + setAddingReplicas(Collections.emptyList()). + setRemovingReplicas(Collections.emptyList()). + setIsNew(true).toString(), + a.toLeaderAndIsrPartitionState(new TopicPartition("foo", 1), true).toString()); + assertEquals(new LeaderAndIsrPartitionState(). + setTopicName("bar"). + setPartitionIndex(0). + setControllerEpoch(-1). + setLeader(2). + setLeaderEpoch(234). + setIsr(Arrays.asList(2, 3, 4)). + setZkVersion(567). + setReplicas(Arrays.asList(2, 3, 4)). + setAddingReplicas(Collections.emptyList()). + setRemovingReplicas(Collections.emptyList()). + setIsNew(false).toString(), + b.toLeaderAndIsrPartitionState(new TopicPartition("bar", 0), false).toString()); + } + + @Test + public void testMergePartitionChangeRecordWithReassignmentData() { + PartitionRegistration partition0 = new PartitionRegistration(new int[] {1, 2, 3}, + new int[] {1, 2, 3}, Replicas.NONE, Replicas.NONE, 1, 100, 200); + PartitionRegistration partition1 = partition0.merge(new PartitionChangeRecord(). + setRemovingReplicas(Collections.singletonList(3)). + setAddingReplicas(Collections.singletonList(4)). + setReplicas(Arrays.asList(1, 2, 3, 4))); + assertEquals(new PartitionRegistration(new int[] {1, 2, 3, 4}, + new int[] {1, 2, 3}, new int[] {3}, new int[] {4}, 1, 100, 201), partition1); + PartitionRegistration partition2 = partition1.merge(new PartitionChangeRecord(). + setIsr(Arrays.asList(1, 2, 4)). + setRemovingReplicas(Collections.emptyList()). + setAddingReplicas(Collections.emptyList()). + setReplicas(Arrays.asList(1, 2, 4))); + assertEquals(new PartitionRegistration(new int[] {1, 2, 4}, + new int[] {1, 2, 4}, Replicas.NONE, Replicas.NONE, 1, 100, 202), partition2); + assertFalse(partition2.isReassigning()); + } +} diff --git a/metadata/src/test/java/org/apache/kafka/metadata/RecordTestUtils.java b/metadata/src/test/java/org/apache/kafka/metadata/RecordTestUtils.java new file mode 100644 index 0000000..431c9bb --- /dev/null +++ b/metadata/src/test/java/org/apache/kafka/metadata/RecordTestUtils.java @@ -0,0 +1,230 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.kafka.metadata;
+
+import org.apache.kafka.common.protocol.ApiMessage;
+import org.apache.kafka.common.protocol.Message;
+import org.apache.kafka.common.protocol.ObjectSerializationCache;
+import org.apache.kafka.common.utils.ImplicitLinkedHashCollection;
+import org.apache.kafka.image.MetadataDelta;
+import org.apache.kafka.raft.Batch;
+import org.apache.kafka.raft.BatchReader;
+import org.apache.kafka.raft.internals.MemoryBatchReader;
+import org.apache.kafka.server.common.ApiMessageAndVersion;
+
+import java.lang.reflect.Field;
+import java.lang.reflect.InvocationTargetException;
+import java.lang.reflect.Method;
+import java.util.ArrayList;
+import java.util.Comparator;
+import java.util.HashSet;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Set;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+
+/**
+ * Utilities for testing classes that deal with metadata records.
+ */
+public class RecordTestUtils {
+    /**
+     * Replay a list of records.
+     *
+     * @param target                The object to invoke the replay function on.
+     * @param recordsAndVersions    A list of records.
+     */
+    public static void replayAll(Object target,
+                                 List<ApiMessageAndVersion> recordsAndVersions) {
+        for (ApiMessageAndVersion recordAndVersion : recordsAndVersions) {
+            ApiMessage record = recordAndVersion.message();
+            try {
+                Method method = target.getClass().getMethod("replay", record.getClass());
+                method.invoke(target, record);
+            } catch (NoSuchMethodException e) {
+                // ignore
+            } catch (InvocationTargetException e) {
+                throw new RuntimeException(e);
+            } catch (IllegalAccessException e) {
+                throw new RuntimeException(e);
+            }
+        }
+    }
+
+    /**
+     * Replay a list of records to the metadata delta.
+     *
+     * @param delta the metadata delta on which to replay the records
+     * @param highestOffset highest offset from the list of records
+     * @param highestEpoch highest epoch from the list of records
+     * @param recordsAndVersions list of records
+     */
+    public static void replayAll(
+        MetadataDelta delta,
+        long highestOffset,
+        int highestEpoch,
+        List<ApiMessageAndVersion> recordsAndVersions
+    ) {
+        for (ApiMessageAndVersion recordAndVersion : recordsAndVersions) {
+            ApiMessage record = recordAndVersion.message();
+            delta.replay(highestOffset, highestEpoch, record);
+        }
+    }
+
+    /**
+     * Replay a list of record batches.
+     *
+     * @param target     The object to invoke the replay function on.
+     * @param batches    A list of batches of records.
+     */
+    public static void replayAllBatches(Object target,
+                                        List<List<ApiMessageAndVersion>> batches) {
+        for (List<ApiMessageAndVersion> batch : batches) {
+            replayAll(target, batch);
+        }
+    }
+
+    /**
+     * Replay a list of record batches to the metadata delta.
+     *
+     * @param delta the metadata delta on which to replay the records
+     * @param highestOffset highest offset from the list of record batches
+     * @param highestEpoch highest epoch from the list of record batches
+     * @param batches list of batches of records
+     */
+    public static void replayAllBatches(
+        MetadataDelta delta,
+        long highestOffset,
+        int highestEpoch,
+        List<List<ApiMessageAndVersion>> batches
+    ) {
+        for (List<ApiMessageAndVersion> batch : batches) {
+            replayAll(delta, highestOffset, highestEpoch, batch);
+        }
+    }
+
+    /**
+     * Materialize the output of an iterator into a set.
+     *
+     * @param iterator      The input iterator.
+     *
+     * @return              The output set.
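+     *
+     * <p>The result is a HashSet, so duplicate elements are collapsed and the iteration order
+     * of the input is not preserved.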
+     */
+    public static <T> Set<T> iteratorToSet(Iterator<T> iterator) {
+        HashSet<T> set = new HashSet<>();
+        while (iterator.hasNext()) {
+            set.add(iterator.next());
+        }
+        return set;
+    }
+
+    /**
+     * Assert that a batch iterator yields a given set of record batches.
+     *
+     * @param batches       A list of record batches.
+     * @param iterator      The input iterator.
+     */
+    public static void assertBatchIteratorContains(List<List<ApiMessageAndVersion>> batches,
+                                                   Iterator<List<ApiMessageAndVersion>> iterator) throws Exception {
+        List<List<ApiMessageAndVersion>> actual = new ArrayList<>();
+        while (iterator.hasNext()) {
+            actual.add(new ArrayList<>(iterator.next()));
+        }
+        deepSortRecords(actual);
+        List<List<ApiMessageAndVersion>> expected = new ArrayList<>();
+        for (List<ApiMessageAndVersion> batch : batches) {
+            expected.add(new ArrayList<>(batch));
+        }
+        deepSortRecords(expected);
+        assertEquals(expected, actual);
+    }
+
+    /**
+     * Sort the contents of an object which contains records.
+     *
+     * @param o     The input object. It will be modified in-place.
+     */
+    @SuppressWarnings("unchecked")
+    public static void deepSortRecords(Object o) throws Exception {
+        if (o == null) {
+            return;
+        } else if (o instanceof List) {
+            List list = (List) o;
+            for (Object entry : list) {
+                if (entry != null) {
+                    if (Number.class.isAssignableFrom(entry.getClass())) {
+                        return;
+                    }
+                    deepSortRecords(entry);
+                }
+            }
+            list.sort(Comparator.comparing(Object::toString));
+        } else if (o instanceof ImplicitLinkedHashCollection) {
+            ImplicitLinkedHashCollection coll = (ImplicitLinkedHashCollection) o;
+            for (Object entry : coll) {
+                deepSortRecords(entry);
+            }
+            coll.sort(Comparator.comparing(Object::toString));
+        } else if (o instanceof Message || o instanceof ApiMessageAndVersion) {
+            for (Field field : o.getClass().getDeclaredFields()) {
+                field.setAccessible(true);
+                deepSortRecords(field.get(o));
+            }
+        }
+    }
+
+    /**
+     * Create a batch reader for testing.
+     *
+     * @param lastOffset    The last offset of the given list of records.
+     * @param records       The records.
+     * @return              A batch reader which will return the given records.
+     */
+    public static BatchReader<ApiMessageAndVersion>
+            mockBatchReader(long lastOffset, List<ApiMessageAndVersion> records) {
+        List<Batch<ApiMessageAndVersion>> batches = new ArrayList<>();
+        long offset = lastOffset - records.size() + 1;
+        Iterator<ApiMessageAndVersion> iterator = records.iterator();
+        List<ApiMessageAndVersion> curRecords = new ArrayList<>();
+        assertTrue(iterator.hasNext()); // At least one record is required
+        while (true) {
+            if (!iterator.hasNext() || curRecords.size() >= 2) {
+                batches.add(Batch.data(offset, 0, 0, sizeInBytes(curRecords), curRecords));
+                if (!iterator.hasNext()) {
+                    break;
+                }
+                offset += curRecords.size();
+                curRecords = new ArrayList<>();
+            }
+            curRecords.add(iterator.next());
+        }
+        return MemoryBatchReader.of(batches, __ -> { });
+    }
+
+
+    private static int sizeInBytes(List<ApiMessageAndVersion> records) {
+        int size = 0;
+        for (ApiMessageAndVersion record : records) {
+            ObjectSerializationCache cache = new ObjectSerializationCache();
+            size += MetadataRecordSerde.INSTANCE.recordSize(record, cache);
+        }
+        return size;
+    }
+}
diff --git a/metadata/src/test/java/org/apache/kafka/metadata/ReplicasTest.java b/metadata/src/test/java/org/apache/kafka/metadata/ReplicasTest.java
new file mode 100644
index 0000000..7a26d48
--- /dev/null
+++ b/metadata/src/test/java/org/apache/kafka/metadata/ReplicasTest.java
@@ -0,0 +1,128 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.metadata; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; + +import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + + +@Timeout(40) +public class ReplicasTest { + @Test + public void testToList() { + assertEquals(Arrays.asList(1, 2, 3, 4), Replicas.toList(new int[] {1, 2, 3, 4})); + assertEquals(Arrays.asList(), Replicas.toList(Replicas.NONE)); + assertEquals(Arrays.asList(2), Replicas.toList(new int[] {2})); + } + + @Test + public void testToArray() { + assertArrayEquals(new int[] {3, 2, 1}, Replicas.toArray(Arrays.asList(3, 2, 1))); + assertArrayEquals(new int[] {}, Replicas.toArray(Arrays.asList())); + assertArrayEquals(new int[] {2}, Replicas.toArray(Arrays.asList(2))); + } + + @Test + public void testClone() { + assertArrayEquals(new int[]{3, 2, 1}, Replicas.clone(new int[]{3, 2, 1})); + assertArrayEquals(new int[]{}, Replicas.clone(new int[]{})); + assertArrayEquals(new int[]{2}, Replicas.clone(new int[]{2})); + } + + @Test + public void testValidate() { + assertTrue(Replicas.validate(new int[] {})); + assertTrue(Replicas.validate(new int[] {3})); + assertTrue(Replicas.validate(new int[] {3, 1, 2, 6})); + assertFalse(Replicas.validate(new int[] {3, 3})); + assertFalse(Replicas.validate(new int[] {4, -1, 3})); + assertFalse(Replicas.validate(new int[] {-1})); + assertFalse(Replicas.validate(new int[] {3, 1, 2, 6, 1})); + assertTrue(Replicas.validate(new int[] {1, 100})); + } + + @Test + public void testValidateIsr() { + assertTrue(Replicas.validateIsr(new int[] {}, new int[] {})); + assertTrue(Replicas.validateIsr(new int[] {1, 2, 3}, new int[] {})); + assertTrue(Replicas.validateIsr(new int[] {1, 2, 3}, new int[] {1, 2, 3})); + assertTrue(Replicas.validateIsr(new int[] {3, 1, 2}, new int[] {2, 1})); + assertFalse(Replicas.validateIsr(new int[] {3, 1, 2}, new int[] {4, 1})); + assertFalse(Replicas.validateIsr(new int[] {1, 2, 4}, new int[] {4, 4})); + } + + @Test + public void testContains() { + assertTrue(Replicas.contains(new int[] {3, 0, 1}, 0)); + assertFalse(Replicas.contains(new int[] {}, 0)); + assertTrue(Replicas.contains(new int[] {1}, 1)); + } + + @Test + public void testCopyWithout() { + assertArrayEquals(new int[] {}, Replicas.copyWithout(new int[] {}, 0)); + assertArrayEquals(new int[] {}, Replicas.copyWithout(new int[] {1}, 1)); + assertArrayEquals(new int[] {1, 3}, Replicas.copyWithout(new int[] {1, 2, 3}, 2)); + assertArrayEquals(new int[] {4, 1}, Replicas.copyWithout(new int[] {4, 2, 2, 1}, 2)); + } + + @Test + public void testCopyWithout2() { + assertArrayEquals(new int[] {}, Replicas.copyWithout(new int[] {}, new int[] {})); + assertArrayEquals(new 
int[] {}, Replicas.copyWithout(new int[] {1}, new int[] {1})); + assertArrayEquals(new int[] {1, 3}, + Replicas.copyWithout(new int[] {1, 2, 3}, new int[]{2, 4})); + assertArrayEquals(new int[] {4}, + Replicas.copyWithout(new int[] {4, 2, 2, 1}, new int[]{2, 1})); + } + + @Test + public void testCopyWith() { + assertArrayEquals(new int[] {-1}, Replicas.copyWith(new int[] {}, -1)); + assertArrayEquals(new int[] {1, 2, 3, 4}, Replicas.copyWith(new int[] {1, 2, 3}, 4)); + } + + @Test + public void testToSet() { + assertEquals(Collections.emptySet(), Replicas.toSet(new int[] {})); + assertEquals(new HashSet<>(Arrays.asList(3, 1, 5)), + Replicas.toSet(new int[] {1, 3, 5})); + assertEquals(new HashSet<>(Arrays.asList(1, 2, 10)), + Replicas.toSet(new int[] {1, 1, 2, 10, 10})); + } + + @Test + public void testContains2() { + assertTrue(Replicas.contains(Collections.emptyList(), Replicas.NONE)); + assertFalse(Replicas.contains(Collections.emptyList(), new int[] {1})); + assertTrue(Replicas.contains(Arrays.asList(1, 2, 3), new int[] {3, 2, 1})); + assertTrue(Replicas.contains(Arrays.asList(1, 2, 3, 4), new int[] {3})); + assertTrue(Replicas.contains(Arrays.asList(1, 2, 3, 4), new int[] {3, 1})); + assertFalse(Replicas.contains(Arrays.asList(1, 2, 3, 4), new int[] {3, 1, 7})); + assertTrue(Replicas.contains(Arrays.asList(1, 2, 3, 4), new int[] {})); + } +} diff --git a/metadata/src/test/java/org/apache/kafka/metadata/VersionRangeTest.java b/metadata/src/test/java/org/apache/kafka/metadata/VersionRangeTest.java new file mode 100644 index 0000000..88082a6 --- /dev/null +++ b/metadata/src/test/java/org/apache/kafka/metadata/VersionRangeTest.java @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.metadata; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +@Timeout(value = 40) +public class VersionRangeTest { + @SuppressWarnings("unchecked") + private static VersionRange v(int a, int b) { + assertTrue(a <= Short.MAX_VALUE); + assertTrue(a >= Short.MIN_VALUE); + assertTrue(b <= Short.MAX_VALUE); + assertTrue(b >= Short.MIN_VALUE); + return new VersionRange((short) a, (short) b); + } + + @Test + public void testEquality() { + assertEquals(v(1, 1), v(1, 1)); + assertFalse(v(1, 1).equals(v(1, 2))); + assertFalse(v(2, 1).equals(v(1, 2))); + assertFalse(v(2, 1).equals(v(2, 2))); + } + + @Test + public void testContains() { + assertTrue(v(1, 1).contains(v(1, 1))); + assertFalse(v(1, 1).contains(v(1, 2))); + assertTrue(v(1, 2).contains(v(1, 1))); + assertFalse(v(4, 10).contains(v(3, 8))); + assertTrue(v(2, 12).contains(v(3, 11))); + } + + @Test + public void testToString() { + assertEquals("1-2", v(1, 2).toString()); + assertEquals("1", v(1, 1).toString()); + assertEquals("1+", v(1, Short.MAX_VALUE).toString()); + assertEquals("100-200", v(100, 200).toString()); + } +} diff --git a/metadata/src/test/java/org/apache/kafka/metalog/LocalLogManager.java b/metadata/src/test/java/org/apache/kafka/metalog/LocalLogManager.java new file mode 100644 index 0000000..531ecc8 --- /dev/null +++ b/metadata/src/test/java/org/apache/kafka/metalog/LocalLogManager.java @@ -0,0 +1,769 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.metalog; + +import org.apache.kafka.common.memory.MemoryPool; +import org.apache.kafka.common.protocol.ObjectSerializationCache; +import org.apache.kafka.common.record.CompressionType; +import org.apache.kafka.common.utils.BufferSupplier; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.metadata.MetadataRecordSerde; +import org.apache.kafka.queue.EventQueue; +import org.apache.kafka.queue.KafkaEventQueue; +import org.apache.kafka.raft.Batch; +import org.apache.kafka.raft.LeaderAndEpoch; +import org.apache.kafka.raft.OffsetAndEpoch; +import org.apache.kafka.raft.RaftClient; +import org.apache.kafka.raft.internals.MemoryBatchReader; +import org.apache.kafka.server.common.ApiMessageAndVersion; +import org.apache.kafka.snapshot.MockRawSnapshotReader; +import org.apache.kafka.snapshot.MockRawSnapshotWriter; +import org.apache.kafka.snapshot.RawSnapshotReader; +import org.apache.kafka.snapshot.RawSnapshotWriter; +import org.apache.kafka.snapshot.SnapshotReader; +import org.apache.kafka.snapshot.SnapshotWriter; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.AbstractMap.SimpleImmutableEntry; +import java.util.Collections; +import java.util.HashMap; +import java.util.IdentityHashMap; +import java.util.Iterator; +import java.util.List; +import java.util.Map.Entry; +import java.util.Map; +import java.util.NavigableMap; +import java.util.Objects; +import java.util.Optional; +import java.util.OptionalInt; +import java.util.OptionalLong; +import java.util.TreeMap; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ThreadLocalRandom; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; + +/** + * The LocalLogManager is a test implementation that relies on the contents of memory. 
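+ * <p>
+ * Several LocalLogManager instances can share a single SharedLogData object, which keeps the
+ * record batches, the current leader, and any snapshots in memory. This lets tests exercise
+ * leader changes, commits, and snapshot loading without real networking or disk I/O.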
+ */ +public final class LocalLogManager implements RaftClient, AutoCloseable { + interface LocalBatch { + int epoch(); + int size(); + } + + static class LeaderChangeBatch implements LocalBatch { + private final LeaderAndEpoch newLeader; + + LeaderChangeBatch(LeaderAndEpoch newLeader) { + this.newLeader = newLeader; + } + + @Override + public int epoch() { + return newLeader.epoch(); + } + + @Override + public int size() { + return 1; + } + + @Override + public boolean equals(Object o) { + if (!(o instanceof LeaderChangeBatch)) return false; + LeaderChangeBatch other = (LeaderChangeBatch) o; + if (!other.newLeader.equals(newLeader)) return false; + return true; + } + + @Override + public int hashCode() { + return Objects.hash(newLeader); + } + + @Override + public String toString() { + return "LeaderChangeBatch(newLeader=" + newLeader + ")"; + } + } + + static class LocalRecordBatch implements LocalBatch { + private final int leaderEpoch; + private final long appendTimestamp; + private final List records; + + LocalRecordBatch(int leaderEpoch, long appendTimestamp, List records) { + this.leaderEpoch = leaderEpoch; + this.appendTimestamp = appendTimestamp; + this.records = records; + } + + @Override + public int epoch() { + return leaderEpoch; + } + + @Override + public int size() { + return records.size(); + } + + @Override + public boolean equals(Object o) { + if (!(o instanceof LocalRecordBatch)) return false; + LocalRecordBatch other = (LocalRecordBatch) o; + + return leaderEpoch == other.leaderEpoch && + appendTimestamp == other.appendTimestamp && + Objects.equals(records, other.records); + } + + @Override + public int hashCode() { + return Objects.hash(leaderEpoch, appendTimestamp, records); + } + + @Override + public String toString() { + return String.format( + "LocalRecordBatch(leaderEpoch=%s, appendTimestamp=%s, records=%s)", + leaderEpoch, + appendTimestamp, + records + ); + } + } + + public static class SharedLogData { + private final Logger log = LoggerFactory.getLogger(SharedLogData.class); + + /** + * Maps node IDs to the matching log managers. + */ + private final HashMap logManagers = new HashMap<>(); + + /** + * Maps offsets to record batches. + */ + private final TreeMap batches = new TreeMap<>(); + + /** + * The current leader. + */ + private LeaderAndEpoch leader = new LeaderAndEpoch(OptionalInt.empty(), 0); + + /** + * The start offset of the last batch that was created, or -1 if no batches have + * been created. + */ + private long prevOffset; + + /** + * Maps committed offset to snapshot reader. + */ + private NavigableMap snapshots = new TreeMap<>(); + + public SharedLogData(Optional snapshot) { + if (snapshot.isPresent()) { + RawSnapshotReader initialSnapshot = snapshot.get(); + prevOffset = initialSnapshot.snapshotId().offset - 1; + snapshots.put(prevOffset, initialSnapshot); + } else { + prevOffset = -1; + } + } + + synchronized void registerLogManager(LocalLogManager logManager) { + if (logManagers.put(logManager.nodeId, logManager) != null) { + throw new RuntimeException("Can't have multiple LocalLogManagers " + + "with id " + logManager.nodeId()); + } + electLeaderIfNeeded(); + } + + synchronized void unregisterLogManager(LocalLogManager logManager) { + if (!logManagers.remove(logManager.nodeId, logManager)) { + throw new RuntimeException("Log manager " + logManager.nodeId() + + " was not found."); + } + } + + synchronized long tryAppend(int nodeId, int epoch, List batch) { + // No easy access to the concept of time. 
Use the base offset as the append timestamp
+            long appendTimestamp = (prevOffset + 1) * 10;
+            return tryAppend(nodeId, epoch, new LocalRecordBatch(epoch, appendTimestamp, batch));
+        }
+
+        synchronized long tryAppend(int nodeId, int epoch, LocalBatch batch) {
+            if (epoch != leader.epoch()) {
+                log.trace("tryAppend(nodeId={}, epoch={}): the provided epoch does not " +
+                    "match the current leader epoch of {}.", nodeId, epoch, leader.epoch());
+                return Long.MAX_VALUE;
+            }
+            if (!leader.isLeader(nodeId)) {
+                log.trace("tryAppend(nodeId={}, epoch={}): the given node id does not " +
+                    "match the current leader id of {}.", nodeId, epoch, leader.leaderId());
+                return Long.MAX_VALUE;
+            }
+            log.trace("tryAppend(nodeId={}): appending {}.", nodeId, batch);
+            long offset = append(batch);
+            electLeaderIfNeeded();
+            return offset;
+        }
+
+        synchronized long append(LocalBatch batch) {
+            prevOffset += batch.size();
+            log.debug("append(batch={}, prevOffset={})", batch, prevOffset);
+            batches.put(prevOffset, batch);
+            if (batch instanceof LeaderChangeBatch) {
+                LeaderChangeBatch leaderChangeBatch = (LeaderChangeBatch) batch;
+                leader = leaderChangeBatch.newLeader;
+            }
+            for (LocalLogManager logManager : logManagers.values()) {
+                logManager.scheduleLogCheck();
+            }
+            return prevOffset;
+        }
+
+        synchronized void electLeaderIfNeeded() {
+            if (leader.leaderId().isPresent() || logManagers.isEmpty()) {
+                return;
+            }
+            int nextLeaderIndex = ThreadLocalRandom.current().nextInt(logManagers.size());
+            Iterator<Integer> iter = logManagers.keySet().iterator();
+            Integer nextLeaderNode = null;
+            for (int i = 0; i <= nextLeaderIndex; i++) {
+                nextLeaderNode = iter.next();
+            }
+            LeaderAndEpoch newLeader = new LeaderAndEpoch(OptionalInt.of(nextLeaderNode), leader.epoch() + 1);
+            log.info("Elected new leader: {}.", newLeader);
+            append(new LeaderChangeBatch(newLeader));
+        }
+
+        synchronized LeaderAndEpoch leaderAndEpoch() {
+            return leader;
+        }
+
+        synchronized Entry<Long, LocalBatch> nextBatch(long offset) {
+            Entry<Long, LocalBatch> entry = batches.higherEntry(offset);
+            if (entry == null) {
+                return null;
+            }
+            return new SimpleImmutableEntry<>(entry.getKey(), entry.getValue());
+        }
+
+        /**
+         * Optionally return a snapshot reader if the given offset is at or before the end
+         * offset of the latest snapshot.
+         */
+        synchronized Optional<RawSnapshotReader> nextSnapshot(long offset) {
+            return Optional.ofNullable(snapshots.lastEntry()).flatMap(entry -> {
+                if (offset <= entry.getKey()) {
+                    return Optional.of(entry.getValue());
+                }
+
+                return Optional.empty();
+            });
+        }
+
+        /**
+         * Stores a new snapshot and notifies all threads waiting for a snapshot.
+         */
+        synchronized void addSnapshot(RawSnapshotReader newSnapshot) {
+            if (newSnapshot.snapshotId().offset - 1 > prevOffset) {
+                log.error(
+                    "Ignored attempt to add a snapshot {} that is greater than the latest offset {}",
+                    newSnapshot,
+                    prevOffset
+                );
+            } else {
+                snapshots.put(newSnapshot.snapshotId().offset - 1, newSnapshot);
+                this.notifyAll();
+            }
+        }
+
+        /**
+         * Returns the snapshot whose last offset is the committed offset.
+         *
+         * If such a snapshot doesn't exist, this method waits until it does.
+         */
+        synchronized RawSnapshotReader waitForSnapshot(long committedOffset) throws InterruptedException {
+            while (true) {
+                RawSnapshotReader reader = snapshots.get(committedOffset);
+                if (reader != null) {
+                    return reader;
+                } else {
+                    this.wait();
+                }
+            }
+        }
+
+        /**
+         * Returns the latest snapshot.
+         *
+         * If no snapshot exists yet, this method waits until one does.
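+         *
+         * <p>Waiting threads are woken up by addSnapshot, which calls notifyAll() once a new
+         * snapshot has been stored.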
+ */ + synchronized RawSnapshotReader waitForLatestSnapshot() throws InterruptedException { + while (snapshots.isEmpty()) { + this.wait(); + } + + return Objects.requireNonNull(snapshots.lastEntry()).getValue(); + } + + synchronized long appendedBytes() { + ObjectSerializationCache objectCache = new ObjectSerializationCache(); + + return batches + .values() + .stream() + .flatMapToInt(batch -> { + if (batch instanceof LocalRecordBatch) { + LocalRecordBatch localBatch = (LocalRecordBatch) batch; + return localBatch.records.stream().mapToInt(record -> messageSize(record, objectCache)); + } else { + return IntStream.empty(); + } + }) + .sum(); + } + } + + private static class MetaLogListenerData { + private long offset = -1; + private LeaderAndEpoch notifiedLeader = new LeaderAndEpoch(OptionalInt.empty(), 0); + + private final RaftClient.Listener listener; + + MetaLogListenerData(RaftClient.Listener listener) { + this.listener = listener; + } + + long offset() { + return offset; + } + + void setOffset(long offset) { + this.offset = offset; + } + + LeaderAndEpoch notifiedLeader() { + return notifiedLeader; + } + + void handleCommit(MemoryBatchReader reader) { + listener.handleCommit(reader); + offset = reader.lastOffset().getAsLong(); + } + + void handleSnapshot(SnapshotReader reader) { + listener.handleSnapshot(reader); + offset = reader.lastContainedLogOffset(); + } + + void handleLeaderChange(long offset, LeaderAndEpoch leader) { + listener.handleLeaderChange(leader); + notifiedLeader = leader; + this.offset = offset; + } + + void beginShutdown() { + listener.beginShutdown(); + } + } + + private final Logger log; + + /** + * The node ID of this local log manager. Each log manager must have a unique ID. + */ + private final int nodeId; + + /** + * A reference to the in-memory state that unites all the log managers in use. + */ + private final SharedLogData shared; + + /** + * The event queue used by this local log manager. + */ + private final EventQueue eventQueue; + + /** + * Whether this LocalLogManager has been initialized. + */ + private boolean initialized = false; + + /** + * Whether this LocalLogManager has been shut down. + */ + private boolean shutdown = false; + + /** + * An offset that the log manager will not read beyond. This exists only for testing + * purposes. + */ + private long maxReadOffset = Long.MAX_VALUE; + + /** + * The listener objects attached to this local log manager. + */ + private final Map, MetaLogListenerData> listeners = new IdentityHashMap<>(); + + /** + * The current leader, as seen by this log manager. + */ + private volatile LeaderAndEpoch leader = new LeaderAndEpoch(OptionalInt.empty(), 0); + + /* + * If this variable is true the next non-atomic append with more than 1 record will + * result is half the records getting appended with leader election following that. + * This is done to emulate having some of the records not getting committed. 
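+     *
+     * See scheduleAppend(): when this flag is set, the batch is split in half, only the first
+     * half is appended, and leadership is then given up, so the remaining records are never
+     * committed.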
+ */ + private AtomicBoolean resignAfterNonAtomicCommit = new AtomicBoolean(false); + + public LocalLogManager(LogContext logContext, + int nodeId, + SharedLogData shared, + String threadNamePrefix) { + this.log = logContext.logger(LocalLogManager.class); + this.nodeId = nodeId; + this.shared = shared; + this.eventQueue = new KafkaEventQueue(Time.SYSTEM, logContext, threadNamePrefix); + shared.registerLogManager(this); + } + + private void scheduleLogCheck() { + eventQueue.append(() -> { + try { + log.debug("Node {}: running log check.", nodeId); + int numEntriesFound = 0; + for (MetaLogListenerData listenerData : listeners.values()) { + while (true) { + // Load the snapshot if needed and we are not the leader + LeaderAndEpoch notifiedLeader = listenerData.notifiedLeader(); + if (!OptionalInt.of(nodeId).equals(notifiedLeader.leaderId())) { + Optional snapshot = shared.nextSnapshot(listenerData.offset()); + if (snapshot.isPresent()) { + log.trace("Node {}: handling snapshot with id {}.", nodeId, snapshot.get().snapshotId()); + listenerData.handleSnapshot( + SnapshotReader.of( + snapshot.get(), + new MetadataRecordSerde(), + BufferSupplier.create(), + Integer.MAX_VALUE + ) + ); + } + } + + Entry entry = shared.nextBatch(listenerData.offset()); + if (entry == null) { + log.trace("Node {}: reached the end of the log after finding " + + "{} entries.", nodeId, numEntriesFound); + break; + } + long entryOffset = entry.getKey(); + if (entryOffset > maxReadOffset) { + log.trace("Node {}: after {} entries, not reading the next " + + "entry because its offset is {}, and maxReadOffset is {}.", + nodeId, numEntriesFound, entryOffset, maxReadOffset); + break; + } + if (entry.getValue() instanceof LeaderChangeBatch) { + LeaderChangeBatch batch = (LeaderChangeBatch) entry.getValue(); + log.trace("Node {}: handling LeaderChange to {}.", + nodeId, batch.newLeader); + // Only notify the listener if it equals the shared leader state + LeaderAndEpoch sharedLeader = shared.leaderAndEpoch(); + if (batch.newLeader.equals(sharedLeader)) { + listenerData.handleLeaderChange(entryOffset, batch.newLeader); + if (batch.newLeader.epoch() > leader.epoch()) { + leader = batch.newLeader; + } + } else { + log.debug("Node {}: Ignoring {} since it doesn't match the latest known leader {}", + nodeId, batch.newLeader, sharedLeader); + listenerData.setOffset(entryOffset); + } + } else if (entry.getValue() instanceof LocalRecordBatch) { + LocalRecordBatch batch = (LocalRecordBatch) entry.getValue(); + log.trace("Node {}: handling LocalRecordBatch with offset {}.", + nodeId, entryOffset); + ObjectSerializationCache objectCache = new ObjectSerializationCache(); + + listenerData.handleCommit( + MemoryBatchReader.of( + Collections.singletonList( + Batch.data( + entryOffset - batch.records.size() + 1, + batch.leaderEpoch, + batch.appendTimestamp, + batch + .records + .stream() + .mapToInt(record -> messageSize(record, objectCache)) + .sum(), + batch.records + ) + ), + reader -> { } + ) + ); + } + numEntriesFound++; + } + } + log.trace("Completed log check for node " + nodeId); + } catch (Exception e) { + log.error("Exception while handling log check", e); + } + }); + } + + private static int messageSize(ApiMessageAndVersion messageAndVersion, ObjectSerializationCache objectCache) { + return new MetadataRecordSerde().recordSize(messageAndVersion, objectCache); + } + + public void beginShutdown() { + eventQueue.beginShutdown("beginShutdown", () -> { + try { + if (initialized && !shutdown) { + log.debug("Node {}: beginning shutdown.", 
nodeId); + resign(leader.epoch()); + for (MetaLogListenerData listenerData : listeners.values()) { + listenerData.beginShutdown(); + } + shared.unregisterLogManager(this); + } + } catch (Exception e) { + log.error("Unexpected exception while sending beginShutdown callbacks", e); + } + shutdown = true; + }); + } + + @Override + public void close() { + log.debug("Node {}: closing.", nodeId); + beginShutdown(); + + try { + eventQueue.close(); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new RuntimeException(e); + } + } + + /** + * Shutdown the log manager. + * + * Even though the API suggests a non-blocking shutdown, this method always returns a completed + * future. This means that shutdown is a blocking operation. + */ + @Override + public CompletableFuture shutdown(int timeoutMs) { + CompletableFuture shutdownFuture = new CompletableFuture<>(); + try { + close(); + shutdownFuture.complete(null); + } catch (Throwable t) { + shutdownFuture.completeExceptionally(t); + } + return shutdownFuture; + } + + @Override + public void initialize() { + eventQueue.append(() -> { + log.debug("initialized local log manager for node " + nodeId); + initialized = true; + }); + } + + @Override + public void register(RaftClient.Listener listener) { + CompletableFuture future = new CompletableFuture<>(); + eventQueue.append(() -> { + if (shutdown) { + log.info("Node {}: can't register because local log manager has " + + "already been shut down.", nodeId); + future.complete(null); + } else if (initialized) { + int id = System.identityHashCode(listener); + if (listeners.putIfAbsent(listener, new MetaLogListenerData(listener)) != null) { + log.error("Node {}: can't register because listener {} already exists", nodeId, id); + } else { + log.info("Node {}: registered MetaLogListener {}", nodeId, id); + } + shared.electLeaderIfNeeded(); + scheduleLogCheck(); + future.complete(null); + } else { + log.info("Node {}: can't register because local log manager has not " + + "been initialized.", nodeId); + future.completeExceptionally(new RuntimeException( + "LocalLogManager was not initialized.")); + } + }); + try { + future.get(); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new RuntimeException(e); + } catch (ExecutionException e) { + throw new RuntimeException(e); + } + } + + @Override + public void unregister(RaftClient.Listener listener) { + eventQueue.append(() -> { + if (shutdown) { + log.info("Node {}: can't unregister because local log manager is shutdown", nodeId); + } else { + int id = System.identityHashCode(listener); + if (listeners.remove(listener) == null) { + log.error("Node {}: can't unregister because the listener {} doesn't exists", nodeId, id); + } else { + log.info("Node {}: unregistered MetaLogListener {}", nodeId, id); + } + } + }); + } + + @Override + public long scheduleAppend(int epoch, List batch) { + if (batch.isEmpty()) { + throw new IllegalArgumentException("Batch cannot be empty"); + } + + List first = batch.subList(0, batch.size() / 2); + List second = batch.subList(batch.size() / 2, batch.size()); + + assertEquals(batch.size(), first.size() + second.size()); + assertFalse(second.isEmpty()); + + OptionalLong firstOffset = first + .stream() + .mapToLong(record -> scheduleAtomicAppend(epoch, Collections.singletonList(record))) + .max(); + + if (firstOffset.isPresent() && resignAfterNonAtomicCommit.getAndSet(false)) { + // Emulate losing leadership in the middle of a non-atomic append by not writing + // the rest of 
the batch and instead writing a leader change message + resign(leader.epoch()); + + return firstOffset.getAsLong() + second.size(); + } else { + return second + .stream() + .mapToLong(record -> scheduleAtomicAppend(epoch, Collections.singletonList(record))) + .max() + .getAsLong(); + } + } + + @Override + public long scheduleAtomicAppend(int epoch, List batch) { + return shared.tryAppend(nodeId, leader.epoch(), batch); + } + + @Override + public void resign(int epoch) { + LeaderAndEpoch curLeader = leader; + LeaderAndEpoch nextLeader = new LeaderAndEpoch(OptionalInt.empty(), curLeader.epoch() + 1); + shared.tryAppend(nodeId, curLeader.epoch(), new LeaderChangeBatch(nextLeader)); + } + + @Override + public Optional> createSnapshot( + long committedOffset, + int committedEpoch, + long lastContainedLogTimestamp + ) { + OffsetAndEpoch snapshotId = new OffsetAndEpoch(committedOffset + 1, committedEpoch); + return SnapshotWriter.createWithHeader( + () -> createNewSnapshot(snapshotId), + 1024, + MemoryPool.NONE, + new MockTime(), + lastContainedLogTimestamp, + CompressionType.NONE, + new MetadataRecordSerde() + ); + } + + private Optional createNewSnapshot(OffsetAndEpoch snapshotId) { + return Optional.of( + new MockRawSnapshotWriter(snapshotId, buffer -> { + shared.addSnapshot(new MockRawSnapshotReader(snapshotId, buffer)); + }) + ); + } + + @Override + public LeaderAndEpoch leaderAndEpoch() { + return leader; + } + + @Override + public OptionalInt nodeId() { + return OptionalInt.of(nodeId); + } + + public List> listeners() { + final CompletableFuture>> future = new CompletableFuture<>(); + eventQueue.append(() -> { + future.complete(listeners.values().stream().map(l -> l.listener).collect(Collectors.toList())); + }); + try { + return future.get(); + } catch (ExecutionException | InterruptedException e) { + throw new RuntimeException(e); + } + } + + public void setMaxReadOffset(long maxReadOffset) { + CompletableFuture future = new CompletableFuture<>(); + eventQueue.append(() -> { + log.trace("Node {}: set maxReadOffset to {}.", nodeId, maxReadOffset); + this.maxReadOffset = maxReadOffset; + scheduleLogCheck(); + future.complete(null); + }); + try { + future.get(); + } catch (ExecutionException | InterruptedException e) { + throw new RuntimeException(e); + } + } + + public void resignAfterNonAtomicCommit() { + resignAfterNonAtomicCommit.set(true); + } +} diff --git a/metadata/src/test/java/org/apache/kafka/metalog/LocalLogManagerTest.java b/metadata/src/test/java/org/apache/kafka/metalog/LocalLogManagerTest.java new file mode 100644 index 0000000..7b5e26d --- /dev/null +++ b/metadata/src/test/java/org/apache/kafka/metalog/LocalLogManagerTest.java @@ -0,0 +1,162 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.metalog; + +import org.apache.kafka.common.metadata.RegisterBrokerRecord; +import org.apache.kafka.raft.LeaderAndEpoch; +import org.apache.kafka.server.common.ApiMessageAndVersion; +import org.apache.kafka.test.TestUtils; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; + +import java.util.Arrays; +import java.util.List; +import java.util.Optional; +import java.util.OptionalInt; +import java.util.stream.Collectors; + +import static org.apache.kafka.metalog.MockMetaLogManagerListener.COMMIT; +import static org.apache.kafka.metalog.MockMetaLogManagerListener.LAST_COMMITTED_OFFSET; +import static org.apache.kafka.metalog.MockMetaLogManagerListener.SHUTDOWN; +import static org.junit.jupiter.api.Assertions.assertEquals; + + +@Timeout(value = 40) +public class LocalLogManagerTest { + + /** + * Test creating a LocalLogManager and closing it. + */ + @Test + public void testCreateAndClose() throws Exception { + try (LocalLogManagerTestEnv env = + LocalLogManagerTestEnv.createWithMockListeners(1, Optional.empty())) { + env.close(); + assertEquals(null, env.firstError.get()); + } + } + + /** + * Test that the local log manager will claim leadership. + */ + @Test + public void testClaimsLeadership() throws Exception { + try (LocalLogManagerTestEnv env = + LocalLogManagerTestEnv.createWithMockListeners(1, Optional.empty())) { + assertEquals(new LeaderAndEpoch(OptionalInt.of(0), 1), env.waitForLeader()); + env.close(); + assertEquals(null, env.firstError.get()); + } + } + + /** + * Test that we can pass leadership back and forth between log managers. + */ + @Test + public void testPassLeadership() throws Exception { + try (LocalLogManagerTestEnv env = + LocalLogManagerTestEnv.createWithMockListeners(3, Optional.empty())) { + LeaderAndEpoch first = env.waitForLeader(); + LeaderAndEpoch cur = first; + do { + int currentLeaderId = cur.leaderId().orElseThrow(() -> + new AssertionError("Current leader is undefined") + ); + env.logManagers().get(currentLeaderId).resign(cur.epoch()); + + LeaderAndEpoch next = env.waitForLeader(); + while (next.epoch() == cur.epoch()) { + Thread.sleep(1); + next = env.waitForLeader(); + } + long expectedNextEpoch = cur.epoch() + 2; + assertEquals(expectedNextEpoch, next.epoch(), "Expected next epoch to be " + expectedNextEpoch + + ", but found " + next); + cur = next; + } while (cur.leaderId().equals(first.leaderId())); + env.close(); + assertEquals(null, env.firstError.get()); + } + } + + private static void waitForLastCommittedOffset(long targetOffset, + LocalLogManager logManager) throws InterruptedException { + TestUtils.retryOnExceptionWithTimeout(20000, 3, () -> { + MockMetaLogManagerListener listener = + (MockMetaLogManagerListener) logManager.listeners().get(0); + long highestOffset = -1; + for (String event : listener.serializedEvents()) { + if (event.startsWith(LAST_COMMITTED_OFFSET)) { + long offset = Long.valueOf( + event.substring(LAST_COMMITTED_OFFSET.length() + 1)); + if (offset < highestOffset) { + throw new RuntimeException("Invalid offset: " + offset + + " is less than the previous offset of " + highestOffset); + } + highestOffset = offset; + } + } + if (highestOffset < targetOffset) { + throw new RuntimeException("Offset for log manager " + + logManager.nodeId() + " only reached " + highestOffset); + } + }); + } + + /** + * Test that all the log managers see all the commits. 
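+ * + * The elected leader appends three RegisterBrokerRecord messages, and every log manager's listener is expected to observe the same COMMIT events and reach the same last committed offset.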
+ */ + @Test + public void testCommits() throws Exception { + try (LocalLogManagerTestEnv env = + LocalLogManagerTestEnv.createWithMockListeners(3, Optional.empty())) { + LeaderAndEpoch leaderInfo = env.waitForLeader(); + int leaderId = leaderInfo.leaderId().orElseThrow(() -> + new AssertionError("Current leader is undefined") + ); + + LocalLogManager activeLogManager = env.logManagers().get(leaderId); + int epoch = activeLogManager.leaderAndEpoch().epoch(); + List messages = Arrays.asList( + new ApiMessageAndVersion(new RegisterBrokerRecord().setBrokerId(0), (short) 0), + new ApiMessageAndVersion(new RegisterBrokerRecord().setBrokerId(1), (short) 0), + new ApiMessageAndVersion(new RegisterBrokerRecord().setBrokerId(2), (short) 0)); + assertEquals(3, activeLogManager.scheduleAppend(epoch, messages)); + for (LocalLogManager logManager : env.logManagers()) { + waitForLastCommittedOffset(3, logManager); + } + List listeners = env.logManagers().stream(). + map(m -> (MockMetaLogManagerListener) m.listeners().get(0)). + collect(Collectors.toList()); + env.close(); + for (MockMetaLogManagerListener listener : listeners) { + List events = listener.serializedEvents(); + assertEquals(SHUTDOWN, events.get(events.size() - 1)); + int foundIndex = 0; + for (String event : events) { + if (event.startsWith(COMMIT)) { + assertEquals(messages.get(foundIndex).message().toString(), + event.substring(COMMIT.length() + 1)); + foundIndex++; + } + } + assertEquals(messages.size(), foundIndex); + } + } + } +} diff --git a/metadata/src/test/java/org/apache/kafka/metalog/LocalLogManagerTestEnv.java b/metadata/src/test/java/org/apache/kafka/metalog/LocalLogManagerTestEnv.java new file mode 100644 index 0000000..8a22434 --- /dev/null +++ b/metadata/src/test/java/org/apache/kafka/metalog/LocalLogManagerTestEnv.java @@ -0,0 +1,166 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.metalog; + +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.metalog.LocalLogManager.SharedLogData; +import org.apache.kafka.raft.LeaderAndEpoch; +import org.apache.kafka.snapshot.RawSnapshotReader; +import org.apache.kafka.test.TestUtils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.File; +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; +import java.util.concurrent.atomic.AtomicReference; + +public class LocalLogManagerTestEnv implements AutoCloseable { + private static final Logger log = + LoggerFactory.getLogger(LocalLogManagerTestEnv.class); + + /** + * The first error we encountered during this test, or the empty string if we have + * not encountered any. 
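+ * + * (The reference starts out as null; tests assert that it is still null when no error occurred.)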
+ */ + final AtomicReference firstError = new AtomicReference<>(null); + + /** + * The test directory, which we will delete once the test is over. + */ + private final File dir; + + /** + * The shared data for our LocalLogManager instances. + */ + private final SharedLogData shared; + + /** + * A list of log managers. + */ + private final List logManagers; + + public static LocalLogManagerTestEnv createWithMockListeners( + int numManagers, + Optional snapshot + ) throws Exception { + LocalLogManagerTestEnv testEnv = new LocalLogManagerTestEnv(numManagers, snapshot); + try { + for (LocalLogManager logManager : testEnv.logManagers) { + logManager.register(new MockMetaLogManagerListener(logManager.nodeId().getAsInt())); + } + } catch (Exception e) { + testEnv.close(); + throw e; + } + return testEnv; + } + + public LocalLogManagerTestEnv(int numManagers, Optional snapshot) throws Exception { + dir = TestUtils.tempDirectory(); + shared = new SharedLogData(snapshot); + List newLogManagers = new ArrayList<>(numManagers); + try { + for (int nodeId = 0; nodeId < numManagers; nodeId++) { + newLogManagers.add(new LocalLogManager( + new LogContext(String.format("[LocalLogManager %d] ", nodeId)), + nodeId, + shared, + String.format("LocalLogManager-%d_", nodeId))); + } + for (LocalLogManager logManager : newLogManagers) { + logManager.initialize(); + } + } catch (Throwable t) { + for (LocalLogManager logManager : newLogManagers) { + logManager.close(); + } + throw t; + } + this.logManagers = newLogManagers; + } + + AtomicReference firstError() { + return firstError; + } + + File dir() { + return dir; + } + + LeaderAndEpoch waitForLeader() throws InterruptedException { + AtomicReference value = new AtomicReference<>(null); + TestUtils.retryOnExceptionWithTimeout(20000, 3, () -> { + LeaderAndEpoch result = null; + for (LocalLogManager logManager : logManagers) { + LeaderAndEpoch leader = logManager.leaderAndEpoch(); + int nodeId = logManager.nodeId().getAsInt(); + if (leader.isLeader(nodeId)) { + if (result != null) { + throw new RuntimeException("node " + nodeId + + " thinks it's the leader, but so does " + result.leaderId()); + } + result = leader; + } + } + if (result == null) { + throw new RuntimeException("No leader found."); + } + value.set(result); + }); + return value.get(); + } + + public List logManagers() { + return logManagers; + } + + public RawSnapshotReader waitForSnapshot(long committedOffset) throws InterruptedException { + return shared.waitForSnapshot(committedOffset); + } + + public RawSnapshotReader waitForLatestSnapshot() throws InterruptedException { + return shared.waitForLatestSnapshot(); + } + + public long appendedBytes() { + return shared.appendedBytes(); + } + + public LeaderAndEpoch leaderAndEpoch() { + return shared.leaderAndEpoch(); + } + + @Override + public void close() throws InterruptedException { + try { + for (LocalLogManager logManager : logManagers) { + logManager.beginShutdown(); + } + for (LocalLogManager logManager : logManagers) { + logManager.close(); + } + Utils.delete(dir); + } catch (IOException e) { + log.error("Error deleting {}", dir.getAbsolutePath(), e); + } + } +} diff --git a/metadata/src/test/java/org/apache/kafka/metalog/MockMetaLogManagerListener.java b/metadata/src/test/java/org/apache/kafka/metalog/MockMetaLogManagerListener.java new file mode 100644 index 0000000..53a10ca --- /dev/null +++ b/metadata/src/test/java/org/apache/kafka/metalog/MockMetaLogManagerListener.java @@ -0,0 +1,121 @@ +/* + * Licensed to the Apache Software Foundation (ASF) 
under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.metalog; + +import org.apache.kafka.common.protocol.ApiMessage; +import org.apache.kafka.raft.Batch; +import org.apache.kafka.raft.BatchReader; +import org.apache.kafka.raft.LeaderAndEpoch; +import org.apache.kafka.raft.RaftClient; +import org.apache.kafka.snapshot.SnapshotReader; +import org.apache.kafka.server.common.ApiMessageAndVersion; + +import java.util.ArrayList; +import java.util.List; +import java.util.OptionalInt; + +public class MockMetaLogManagerListener implements RaftClient.Listener { + public static final String COMMIT = "COMMIT"; + public static final String LAST_COMMITTED_OFFSET = "LAST_COMMITTED_OFFSET"; + public static final String NEW_LEADER = "NEW_LEADER"; + public static final String RENOUNCE = "RENOUNCE"; + public static final String SHUTDOWN = "SHUTDOWN"; + public static final String SNAPSHOT = "SNAPSHOT"; + + private final int nodeId; + private final List serializedEvents = new ArrayList<>(); + private LeaderAndEpoch leaderAndEpoch = new LeaderAndEpoch(OptionalInt.empty(), 0); + + public MockMetaLogManagerListener(int nodeId) { + this.nodeId = nodeId; + } + + @Override + public synchronized void handleCommit(BatchReader reader) { + try { + while (reader.hasNext()) { + Batch batch = reader.next(); + long lastCommittedOffset = batch.lastOffset(); + + for (ApiMessageAndVersion messageAndVersion : batch.records()) { + ApiMessage message = messageAndVersion.message(); + StringBuilder bld = new StringBuilder(); + bld.append(COMMIT).append(" ").append(message.toString()); + serializedEvents.add(bld.toString()); + } + StringBuilder bld = new StringBuilder(); + bld.append(LAST_COMMITTED_OFFSET).append(" ").append(lastCommittedOffset); + serializedEvents.add(bld.toString()); + } + } finally { + reader.close(); + } + } + + @Override + public synchronized void handleSnapshot(SnapshotReader reader) { + long lastCommittedOffset = reader.lastContainedLogOffset(); + try { + while (reader.hasNext()) { + Batch batch = reader.next(); + + for (ApiMessageAndVersion messageAndVersion : batch.records()) { + ApiMessage message = messageAndVersion.message(); + StringBuilder bld = new StringBuilder(); + bld.append(SNAPSHOT).append(" ").append(message.toString()); + serializedEvents.add(bld.toString()); + } + StringBuilder bld = new StringBuilder(); + bld.append(LAST_COMMITTED_OFFSET).append(" ").append(lastCommittedOffset); + serializedEvents.add(bld.toString()); + } + } finally { + reader.close(); + } + } + + @Override + public synchronized void handleLeaderChange(LeaderAndEpoch newLeaderAndEpoch) { + LeaderAndEpoch oldLeaderAndEpoch = this.leaderAndEpoch; + this.leaderAndEpoch = newLeaderAndEpoch; + + if (newLeaderAndEpoch.isLeader(nodeId)) { + StringBuilder bld = new StringBuilder(); + bld.append(NEW_LEADER).append(" 
"). + append(nodeId).append(" ").append(newLeaderAndEpoch.epoch()); + serializedEvents.add(bld.toString()); + } else if (oldLeaderAndEpoch.isLeader(nodeId)) { + StringBuilder bld = new StringBuilder(); + bld.append(RENOUNCE).append(" ").append(newLeaderAndEpoch.epoch()); + serializedEvents.add(bld.toString()); + } + } + + @Override + public void beginShutdown() { + StringBuilder bld = new StringBuilder(); + bld.append(SHUTDOWN); + synchronized (this) { + serializedEvents.add(bld.toString()); + } + } + + public synchronized List serializedEvents() { + return new ArrayList<>(serializedEvents); + } +} diff --git a/metadata/src/test/java/org/apache/kafka/timeline/BaseHashTableTest.java b/metadata/src/test/java/org/apache/kafka/timeline/BaseHashTableTest.java new file mode 100644 index 0000000..a73357c --- /dev/null +++ b/metadata/src/test/java/org/apache/kafka/timeline/BaseHashTableTest.java @@ -0,0 +1,144 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.timeline; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.util.HashSet; +import java.util.Random; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; + +@Timeout(value = 40) +public class BaseHashTableTest { + + @Test + public void testEmptyTable() { + BaseHashTable table = new BaseHashTable<>(0); + assertEquals(0, table.baseSize()); + assertEquals(null, table.baseGet(Integer.valueOf(1))); + } + + @Test + public void testFindSlot() { + Random random = new Random(123); + for (int i = 1; i <= 5; i++) { + int numSlots = 2 << i; + HashSet slotsReturned = new HashSet<>(); + while (slotsReturned.size() < numSlots) { + int slot = BaseHashTable.findSlot(random.nextInt(), numSlots); + assertTrue(slot >= 0); + assertTrue(slot < numSlots); + slotsReturned.add(slot); + } + } + } + + @Test + public void testInsertAndRemove() { + BaseHashTable table = new BaseHashTable<>(20); + Integer one = Integer.valueOf(1); + Integer two = Integer.valueOf(2); + Integer three = Integer.valueOf(3); + Integer four = Integer.valueOf(4); + assertEquals(null, table.baseAddOrReplace(one)); + assertEquals(null, table.baseAddOrReplace(two)); + assertEquals(null, table.baseAddOrReplace(three)); + assertEquals(3, table.baseSize()); + assertEquals(one, table.baseGet(one)); + assertEquals(two, table.baseGet(two)); + assertEquals(three, table.baseGet(three)); + assertEquals(null, table.baseGet(four)); + assertEquals(one, table.baseRemove(one)); + assertEquals(2, table.baseSize()); + assertEquals(null, table.baseGet(one)); + assertEquals(2, table.baseSize()); + } + + static class Foo { + @Override + public boolean equals(Object o) { + return this == o; + } + + @Override + public int hashCode() 
{ + return 42; + } + } + + @Test + public void testHashCollisons() { + Foo one = new Foo(); + Foo two = new Foo(); + Foo three = new Foo(); + Foo four = new Foo(); + BaseHashTable table = new BaseHashTable<>(20); + assertEquals(null, table.baseAddOrReplace(one)); + assertEquals(null, table.baseAddOrReplace(two)); + assertEquals(null, table.baseAddOrReplace(three)); + assertEquals(3, table.baseSize()); + assertEquals(one, table.baseGet(one)); + assertEquals(two, table.baseGet(two)); + assertEquals(three, table.baseGet(three)); + assertEquals(null, table.baseGet(four)); + assertEquals(one, table.baseRemove(one)); + assertEquals(three, table.baseRemove(three)); + assertEquals(1, table.baseSize()); + assertEquals(null, table.baseGet(four)); + assertEquals(two, table.baseGet(two)); + assertEquals(two, table.baseRemove(two)); + assertEquals(0, table.baseSize()); + } + + @Test + public void testExpansion() { + BaseHashTable table = new BaseHashTable<>(0); + + for (int i = 0; i < 4096; i++) { + assertEquals(i, table.baseSize()); + assertEquals(null, table.baseAddOrReplace(Integer.valueOf(i))); + } + + for (int i = 0; i < 4096; i++) { + assertEquals(4096 - i, table.baseSize()); + assertEquals(Integer.valueOf(i), table.baseRemove(Integer.valueOf(i))); + } + } + + @Test + public void testExpectedSizeToCapacity() { + assertEquals(2, BaseHashTable.expectedSizeToCapacity(Integer.MIN_VALUE)); + assertEquals(2, BaseHashTable.expectedSizeToCapacity(-123)); + assertEquals(2, BaseHashTable.expectedSizeToCapacity(0)); + assertEquals(2, BaseHashTable.expectedSizeToCapacity(1)); + assertEquals(4, BaseHashTable.expectedSizeToCapacity(2)); + assertEquals(4, BaseHashTable.expectedSizeToCapacity(3)); + assertEquals(8, BaseHashTable.expectedSizeToCapacity(4)); + assertEquals(16, BaseHashTable.expectedSizeToCapacity(12)); + assertEquals(32, BaseHashTable.expectedSizeToCapacity(13)); + assertEquals(0x2000000, BaseHashTable.expectedSizeToCapacity(0x1010400)); + assertEquals(0x4000000, BaseHashTable.expectedSizeToCapacity(0x2000000)); + assertEquals(0x4000000, BaseHashTable.expectedSizeToCapacity(0x2000001)); + assertEquals(BaseHashTable.MAX_CAPACITY, BaseHashTable.expectedSizeToCapacity(BaseHashTable.MAX_CAPACITY)); + assertEquals(BaseHashTable.MAX_CAPACITY, BaseHashTable.expectedSizeToCapacity(BaseHashTable.MAX_CAPACITY + 1)); + assertEquals(BaseHashTable.MAX_CAPACITY, BaseHashTable.expectedSizeToCapacity(Integer.MAX_VALUE - 1)); + assertEquals(BaseHashTable.MAX_CAPACITY, BaseHashTable.expectedSizeToCapacity(Integer.MAX_VALUE)); + } +} diff --git a/metadata/src/test/java/org/apache/kafka/timeline/SnapshotRegistryTest.java b/metadata/src/test/java/org/apache/kafka/timeline/SnapshotRegistryTest.java new file mode 100644 index 0000000..8922423 --- /dev/null +++ b/metadata/src/test/java/org/apache/kafka/timeline/SnapshotRegistryTest.java @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.timeline; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; + +import org.apache.kafka.common.utils.LogContext; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; + + +@Timeout(value = 40) +public class SnapshotRegistryTest { + @Test + public void testEmptyRegistry() { + SnapshotRegistry registry = new SnapshotRegistry(new LogContext()); + assertThrows(RuntimeException.class, () -> registry.getSnapshot(0)); + assertIteratorContains(registry.iterator()); + } + + private static void assertIteratorContains(Iterator iter, + Snapshot... snapshots) { + List expected = new ArrayList<>(); + for (Snapshot snapshot : snapshots) { + expected.add(snapshot); + } + List actual = new ArrayList<>(); + while (iter.hasNext()) { + Snapshot snapshot = iter.next(); + actual.add(snapshot); + } + assertEquals(expected, actual); + } + + @Test + public void testCreateSnapshots() { + SnapshotRegistry registry = new SnapshotRegistry(new LogContext()); + Snapshot snapshot123 = registry.getOrCreateSnapshot(123); + assertEquals(snapshot123, registry.getSnapshot(123)); + assertThrows(RuntimeException.class, () -> registry.getSnapshot(456)); + assertIteratorContains(registry.iterator(), snapshot123); + assertEquals("Can't create a new snapshot at epoch 1 because there is already " + + "a snapshot with epoch 123", assertThrows(RuntimeException.class, + () -> registry.getOrCreateSnapshot(1)).getMessage()); + Snapshot snapshot456 = registry.getOrCreateSnapshot(456); + assertIteratorContains(registry.iterator(), snapshot123, snapshot456); + } + + @Test + public void testCreateAndDeleteSnapshots() { + SnapshotRegistry registry = new SnapshotRegistry(new LogContext()); + Snapshot snapshot123 = registry.getOrCreateSnapshot(123); + Snapshot snapshot456 = registry.getOrCreateSnapshot(456); + Snapshot snapshot789 = registry.getOrCreateSnapshot(789); + registry.deleteSnapshot(snapshot456.epoch()); + assertIteratorContains(registry.iterator(), snapshot123, snapshot789); + } + + @Test + public void testDeleteSnapshotUpTo() { + SnapshotRegistry registry = new SnapshotRegistry(new LogContext()); + registry.getOrCreateSnapshot(10); + registry.getOrCreateSnapshot(12); + Snapshot snapshot14 = registry.getOrCreateSnapshot(14); + registry.deleteSnapshotsUpTo(14); + assertIteratorContains(registry.iterator(), snapshot14); + } + + @Test + public void testCreateSnapshotOfLatest() { + SnapshotRegistry registry = new SnapshotRegistry(new LogContext()); + registry.getOrCreateSnapshot(10); + Snapshot latest = registry.getOrCreateSnapshot(12); + Snapshot duplicate = registry.getOrCreateSnapshot(12); + + assertEquals(latest, duplicate); + } +} diff --git a/metadata/src/test/java/org/apache/kafka/timeline/SnapshottableHashTableTest.java b/metadata/src/test/java/org/apache/kafka/timeline/SnapshottableHashTableTest.java new file mode 100644 index 0000000..7f1ddcc --- /dev/null +++ 
b/metadata/src/test/java/org/apache/kafka/timeline/SnapshottableHashTableTest.java @@ -0,0 +1,257 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.timeline; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.IdentityHashMap; +import java.util.Iterator; +import java.util.List; +import java.util.stream.Collectors; + +import org.apache.kafka.common.utils.LogContext; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.assertFalse; + +@Timeout(value = 40) +public class SnapshottableHashTableTest { + + /** + * The class of test elements. + * + * This class is intended to help test how the table handles distinct objects which + * are equal to each other. Therefore, for the purpose of hashing and equality, we + * only check i here, and ignore j. 
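+ * + * For example, E_1A and E_1B below are equal to each other (both have i == 1) while remaining distinct objects, so the tests can tell them apart by reference.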
+ */ + static class TestElement implements SnapshottableHashTable.ElementWithStartEpoch { + private final int i; + private final char j; + private long startEpoch = Long.MAX_VALUE; + + TestElement(int i, char j) { + this.i = i; + this.j = j; + } + + @Override + public void setStartEpoch(long startEpoch) { + this.startEpoch = startEpoch; + } + + @Override + public long startEpoch() { + return startEpoch; + } + + @Override + public int hashCode() { + return i; + } + + @Override + public boolean equals(Object o) { + if (!(o instanceof TestElement)) { + return false; + } + TestElement other = (TestElement) o; + return other.i == i; + } + + @Override + public String toString() { + return String.format("E_%d%c(%s)", i, j, System.identityHashCode(this)); + } + } + + private static final TestElement E_1A = new TestElement(1, 'A'); + private static final TestElement E_1B = new TestElement(1, 'B'); + private static final TestElement E_2A = new TestElement(2, 'A'); + private static final TestElement E_3A = new TestElement(3, 'A'); + private static final TestElement E_3B = new TestElement(3, 'B'); + + @Test + public void testEmptyTable() { + SnapshotRegistry registry = new SnapshotRegistry(new LogContext()); + SnapshottableHashTable table = + new SnapshottableHashTable<>(registry, 1); + assertEquals(0, table.snapshottableSize(Long.MAX_VALUE)); + } + + @Test + public void testAddAndRemove() { + SnapshotRegistry registry = new SnapshotRegistry(new LogContext()); + SnapshottableHashTable table = + new SnapshottableHashTable<>(registry, 1); + assertTrue(null == table.snapshottableAddOrReplace(E_1B)); + assertEquals(1, table.snapshottableSize(Long.MAX_VALUE)); + registry.getOrCreateSnapshot(0); + assertTrue(E_1B == table.snapshottableAddOrReplace(E_1A)); + assertTrue(E_1B == table.snapshottableGet(E_1A, 0)); + assertTrue(E_1A == table.snapshottableGet(E_1A, Long.MAX_VALUE)); + assertEquals(null, table.snapshottableAddOrReplace(E_2A)); + assertEquals(null, table.snapshottableAddOrReplace(E_3A)); + assertEquals(3, table.snapshottableSize(Long.MAX_VALUE)); + assertEquals(1, table.snapshottableSize(0)); + registry.getOrCreateSnapshot(1); + assertEquals(E_1A, table.snapshottableRemove(E_1B)); + assertEquals(E_2A, table.snapshottableRemove(E_2A)); + assertEquals(E_3A, table.snapshottableRemove(E_3A)); + assertEquals(0, table.snapshottableSize(Long.MAX_VALUE)); + assertEquals(1, table.snapshottableSize(0)); + assertEquals(3, table.snapshottableSize(1)); + registry.deleteSnapshot(0); + assertEquals("No snapshot for epoch 0. 
Snapshot epochs are: 1", + assertThrows(RuntimeException.class, () -> + table.snapshottableSize(0)).getMessage()); + registry.deleteSnapshot(1); + assertEquals(0, table.snapshottableSize(Long.MAX_VALUE)); + } + + @Test + public void testIterateOverSnapshot() { + SnapshotRegistry registry = new SnapshotRegistry(new LogContext()); + SnapshottableHashTable table = + new SnapshottableHashTable<>(registry, 1); + assertTrue(table.snapshottableAddUnlessPresent(E_1B)); + assertFalse(table.snapshottableAddUnlessPresent(E_1A)); + assertTrue(table.snapshottableAddUnlessPresent(E_2A)); + assertTrue(table.snapshottableAddUnlessPresent(E_3A)); + registry.getOrCreateSnapshot(0); + assertIteratorYields(table.snapshottableIterator(0), E_1B, E_2A, E_3A); + assertEquals(E_1B, table.snapshottableRemove(E_1B)); + assertIteratorYields(table.snapshottableIterator(0), E_1B, E_2A, E_3A); + assertEquals(null, table.snapshottableRemove(E_1A)); + assertIteratorYields(table.snapshottableIterator(Long.MAX_VALUE), E_2A, E_3A); + assertEquals(E_2A, table.snapshottableRemove(E_2A)); + assertEquals(E_3A, table.snapshottableRemove(E_3A)); + assertIteratorYields(table.snapshottableIterator(0), E_1B, E_2A, E_3A); + } + + @Test + public void testIterateOverSnapshotWhileExpandingTable() { + SnapshotRegistry registry = new SnapshotRegistry(new LogContext()); + SnapshottableHashTable table = + new SnapshottableHashTable<>(registry, 1); + assertEquals(null, table.snapshottableAddOrReplace(E_1A)); + registry.getOrCreateSnapshot(0); + Iterator iter = table.snapshottableIterator(0); + assertTrue(table.snapshottableAddUnlessPresent(E_2A)); + assertTrue(table.snapshottableAddUnlessPresent(E_3A)); + assertIteratorYields(iter, E_1A); + } + + @Test + public void testIterateOverSnapshotWhileDeletingAndReplacing() { + SnapshotRegistry registry = new SnapshotRegistry(new LogContext()); + SnapshottableHashTable table = + new SnapshottableHashTable<>(registry, 1); + assertEquals(null, table.snapshottableAddOrReplace(E_1A)); + assertEquals(null, table.snapshottableAddOrReplace(E_2A)); + assertEquals(null, table.snapshottableAddOrReplace(E_3A)); + assertEquals(E_1A, table.snapshottableRemove(E_1A)); + assertEquals(null, table.snapshottableAddOrReplace(E_1B)); + registry.getOrCreateSnapshot(0); + Iterator iter = table.snapshottableIterator(0); + List iterElements = new ArrayList<>(); + iterElements.add(iter.next()); + assertEquals(E_2A, table.snapshottableRemove(E_2A)); + assertEquals(E_3A, table.snapshottableAddOrReplace(E_3B)); + iterElements.add(iter.next()); + assertEquals(E_1B, table.snapshottableRemove(E_1B)); + iterElements.add(iter.next()); + assertFalse(iter.hasNext()); + assertIteratorYields(iterElements.iterator(), E_1B, E_2A, E_3A); + } + + @Test + public void testRevert() { + SnapshotRegistry registry = new SnapshotRegistry(new LogContext()); + SnapshottableHashTable table = + new SnapshottableHashTable<>(registry, 1); + assertEquals(null, table.snapshottableAddOrReplace(E_1A)); + assertEquals(null, table.snapshottableAddOrReplace(E_2A)); + assertEquals(null, table.snapshottableAddOrReplace(E_3A)); + registry.getOrCreateSnapshot(0); + assertEquals(E_1A, table.snapshottableAddOrReplace(E_1B)); + assertEquals(E_3A, table.snapshottableAddOrReplace(E_3B)); + registry.getOrCreateSnapshot(1); + assertEquals(3, table.snapshottableSize(Long.MAX_VALUE)); + assertIteratorYields(table.snapshottableIterator(Long.MAX_VALUE), E_1B, E_2A, E_3B); + table.snapshottableRemove(E_1B); + table.snapshottableRemove(E_2A); + table.snapshottableRemove(E_3B); 
+ assertEquals(0, table.snapshottableSize(Long.MAX_VALUE)); + assertEquals(3, table.snapshottableSize(0)); + assertEquals(3, table.snapshottableSize(1)); + registry.revertToSnapshot(0); + assertIteratorYields(table.snapshottableIterator(Long.MAX_VALUE), E_1A, E_2A, E_3A); + } + + @Test + public void testReset() { + SnapshotRegistry registry = new SnapshotRegistry(new LogContext()); + SnapshottableHashTable table = + new SnapshottableHashTable<>(registry, 1); + assertEquals(null, table.snapshottableAddOrReplace(E_1A)); + assertEquals(null, table.snapshottableAddOrReplace(E_2A)); + assertEquals(null, table.snapshottableAddOrReplace(E_3A)); + registry.getOrCreateSnapshot(0); + assertEquals(E_1A, table.snapshottableAddOrReplace(E_1B)); + assertEquals(E_3A, table.snapshottableAddOrReplace(E_3B)); + registry.getOrCreateSnapshot(1); + + registry.reset(); + + assertEquals(Collections.emptyList(), registry.epochsList()); + // Check that the table is empty + assertIteratorYields(table.snapshottableIterator(Long.MAX_VALUE)); + } + + /** + * Assert that the given iterator contains the given elements, in any order. + * We compare using reference equality here, rather than object equality. + */ + private static void assertIteratorYields(Iterator iter, + Object... expected) { + IdentityHashMap remaining = new IdentityHashMap<>(); + for (Object object : expected) { + remaining.put(object, true); + } + List extraObjects = new ArrayList<>(); + int i = 0; + while (iter.hasNext()) { + Object object = iter.next(); + assertNotNull(object); + if (remaining.remove(object) == null) { + extraObjects.add(object); + } + } + if (!extraObjects.isEmpty() || !remaining.isEmpty()) { + throw new RuntimeException("Found extra object(s): [" + String.join(", ", + extraObjects.stream().map(e -> e.toString()).collect(Collectors.toList())) + + "] and didn't find object(s): [" + String.join(", ", + remaining.keySet().stream().map(e -> e.toString()).collect(Collectors.toList())) + "]"); + } + } +} diff --git a/metadata/src/test/java/org/apache/kafka/timeline/TimelineHashMapTest.java b/metadata/src/test/java/org/apache/kafka/timeline/TimelineHashMapTest.java new file mode 100644 index 0000000..afffd3d --- /dev/null +++ b/metadata/src/test/java/org/apache/kafka/timeline/TimelineHashMapTest.java @@ -0,0 +1,115 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.timeline; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.Iterator; +import java.util.List; +import java.util.Map; + +import org.apache.kafka.common.utils.LogContext; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +@Timeout(value = 40) +public class TimelineHashMapTest { + + @Test + public void testEmptyMap() { + SnapshotRegistry registry = new SnapshotRegistry(new LogContext()); + TimelineHashMap map = new TimelineHashMap<>(registry, 1); + assertTrue(map.isEmpty()); + assertEquals(0, map.size()); + map.clear(); + assertTrue(map.isEmpty()); + } + + @Test + public void testNullsForbidden() { + SnapshotRegistry registry = new SnapshotRegistry(new LogContext()); + TimelineHashMap map = new TimelineHashMap<>(registry, 1); + assertThrows(NullPointerException.class, () -> map.put(null, true)); + assertThrows(NullPointerException.class, () -> map.put("abc", null)); + assertThrows(NullPointerException.class, () -> map.put(null, null)); + } + + @Test + public void testIteration() { + SnapshotRegistry registry = new SnapshotRegistry(new LogContext()); + TimelineHashMap map = new TimelineHashMap<>(registry, 1); + map.put(123, "abc"); + map.put(456, "def"); + assertThat(iteratorToList(map.keySet().iterator()), containsInAnyOrder(123, 456)); + assertThat(iteratorToList(map.values().iterator()), containsInAnyOrder("abc", "def")); + assertTrue(map.containsValue("abc")); + assertTrue(map.containsKey(456)); + assertFalse(map.isEmpty()); + registry.getOrCreateSnapshot(2); + Iterator> iter = map.entrySet(2).iterator(); + map.clear(); + List snapshotValues = new ArrayList<>(); + snapshotValues.add(iter.next().getValue()); + snapshotValues.add(iter.next().getValue()); + assertFalse(iter.hasNext()); + assertThat(snapshotValues, containsInAnyOrder("abc", "def")); + assertFalse(map.isEmpty(2)); + assertTrue(map.isEmpty()); + } + + static List iteratorToList(Iterator iter) { + List list = new ArrayList<>(); + while (iter.hasNext()) { + list.add(iter.next()); + } + return list; + } + + @Test + public void testMapMethods() { + SnapshotRegistry registry = new SnapshotRegistry(new LogContext()); + TimelineHashMap map = new TimelineHashMap<>(registry, 1); + assertEquals(null, map.putIfAbsent(1, "xyz")); + assertEquals("xyz", map.putIfAbsent(1, "123")); + assertEquals("xyz", map.putIfAbsent(1, "ghi")); + map.putAll(Collections.singletonMap(2, "b")); + assertTrue(map.containsKey(2)); + assertEquals("xyz", map.remove(1)); + assertEquals("b", map.remove(2)); + } + + @Test + public void testMapEquals() { + SnapshotRegistry registry = new SnapshotRegistry(new LogContext()); + TimelineHashMap map1 = new TimelineHashMap<>(registry, 1); + assertEquals(null, map1.putIfAbsent(1, "xyz")); + assertEquals(null, map1.putIfAbsent(2, "abc")); + TimelineHashMap map2 = new TimelineHashMap<>(registry, 1); + assertEquals(null, map2.putIfAbsent(1, "xyz")); + assertFalse(map1.equals(map2)); + assertEquals(null, map2.putIfAbsent(2, "abc")); + assertEquals(map1, map2); + } +} diff --git a/metadata/src/test/java/org/apache/kafka/timeline/TimelineHashSetTest.java 
b/metadata/src/test/java/org/apache/kafka/timeline/TimelineHashSetTest.java new file mode 100644 index 0000000..070893c --- /dev/null +++ b/metadata/src/test/java/org/apache/kafka/timeline/TimelineHashSetTest.java @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.timeline; + +import java.util.Arrays; + +import org.apache.kafka.common.utils.LogContext; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +@Timeout(value = 40) +public class TimelineHashSetTest { + + @Test + public void testEmptySet() { + SnapshotRegistry registry = new SnapshotRegistry(new LogContext()); + TimelineHashSet set = new TimelineHashSet<>(registry, 1); + assertTrue(set.isEmpty()); + assertEquals(0, set.size()); + set.clear(); + assertTrue(set.isEmpty()); + } + + @Test + public void testNullsForbidden() { + SnapshotRegistry registry = new SnapshotRegistry(new LogContext()); + TimelineHashSet set = new TimelineHashSet<>(registry, 1); + assertThrows(NullPointerException.class, () -> set.add(null)); + } + + @Test + public void testIteration() { + SnapshotRegistry registry = new SnapshotRegistry(new LogContext()); + TimelineHashSet set = new TimelineHashSet<>(registry, 1); + set.add("a"); + set.add("b"); + set.add("c"); + set.add("d"); + assertTrue(set.retainAll(Arrays.asList("a", "b", "c"))); + assertFalse(set.retainAll(Arrays.asList("a", "b", "c"))); + assertFalse(set.removeAll(Arrays.asList("d"))); + registry.getOrCreateSnapshot(2); + assertTrue(set.removeAll(Arrays.asList("c"))); + assertThat(TimelineHashMapTest.iteratorToList(set.iterator(2)), + containsInAnyOrder("a", "b", "c")); + assertThat(TimelineHashMapTest.iteratorToList(set.iterator()), + containsInAnyOrder("a", "b")); + assertEquals(2, set.size()); + assertEquals(3, set.size(2)); + set.clear(); + assertTrue(set.isEmpty()); + assertFalse(set.isEmpty(2)); + } + + @Test + public void testToArray() { + SnapshotRegistry registry = new SnapshotRegistry(new LogContext()); + TimelineHashSet set = new TimelineHashSet<>(registry, 1); + set.add("z"); + assertArrayEquals(new String[] {"z"}, set.toArray()); + assertArrayEquals(new String[] {"z", null}, set.toArray(new String[2])); + assertArrayEquals(new String[] {"z"}, set.toArray(new String[0])); + } + + @Test + public void testSetMethods() { + SnapshotRegistry registry = new 
SnapshotRegistry(new LogContext()); + TimelineHashSet set = new TimelineHashSet<>(registry, 1); + assertTrue(set.add("xyz")); + assertFalse(set.add("xyz")); + assertTrue(set.remove("xyz")); + assertFalse(set.remove("xyz")); + assertTrue(set.addAll(Arrays.asList("abc", "def", "ghi"))); + assertFalse(set.addAll(Arrays.asList("abc", "def", "ghi"))); + assertTrue(set.addAll(Arrays.asList("abc", "def", "ghi", "jkl"))); + assertTrue(set.containsAll(Arrays.asList("def", "jkl"))); + assertFalse(set.containsAll(Arrays.asList("abc", "def", "xyz"))); + assertTrue(set.removeAll(Arrays.asList("def", "ghi", "xyz"))); + registry.getOrCreateSnapshot(5); + assertThat(TimelineHashMapTest.iteratorToList(set.iterator(5)), + containsInAnyOrder("abc", "jkl")); + assertThat(TimelineHashMapTest.iteratorToList(set.iterator()), + containsInAnyOrder("abc", "jkl")); + set.removeIf(e -> e.startsWith("a")); + assertThat(TimelineHashMapTest.iteratorToList(set.iterator()), + containsInAnyOrder("jkl")); + assertThat(TimelineHashMapTest.iteratorToList(set.iterator(5)), + containsInAnyOrder("abc", "jkl")); + } +} diff --git a/metadata/src/test/java/org/apache/kafka/timeline/TimelineIntegerTest.java b/metadata/src/test/java/org/apache/kafka/timeline/TimelineIntegerTest.java new file mode 100644 index 0000000..736c4cb --- /dev/null +++ b/metadata/src/test/java/org/apache/kafka/timeline/TimelineIntegerTest.java @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.timeline; + +import java.util.Collections; + +import org.apache.kafka.common.utils.LogContext; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; + +import static org.junit.jupiter.api.Assertions.assertEquals; + + +@Timeout(value = 40) +public class TimelineIntegerTest { + @Test + public void testModifyValue() { + SnapshotRegistry registry = new SnapshotRegistry(new LogContext()); + TimelineInteger integer = new TimelineInteger(registry); + assertEquals(0, integer.get()); + assertEquals(0, integer.get(Long.MAX_VALUE)); + integer.set(1); + integer.set(2); + assertEquals(2, integer.get()); + assertEquals(2, integer.get(Long.MAX_VALUE)); + } + + @Test + public void testToStringAndEquals() { + SnapshotRegistry registry = new SnapshotRegistry(new LogContext()); + TimelineInteger integer = new TimelineInteger(registry); + assertEquals("0", integer.toString()); + integer.set(1); + TimelineInteger integer2 = new TimelineInteger(registry); + integer2.set(1); + assertEquals("1", integer2.toString()); + assertEquals(integer, integer2); + } + + @Test + public void testSnapshot() { + SnapshotRegistry registry = new SnapshotRegistry(new LogContext()); + TimelineInteger integer = new TimelineInteger(registry); + registry.getOrCreateSnapshot(2); + integer.set(1); + registry.getOrCreateSnapshot(3); + integer.set(2); + integer.increment(); + integer.increment(); + integer.decrement(); + registry.getOrCreateSnapshot(4); + assertEquals(0, integer.get(2)); + assertEquals(1, integer.get(3)); + assertEquals(3, integer.get(4)); + registry.revertToSnapshot(3); + assertEquals(1, integer.get()); + registry.revertToSnapshot(2); + assertEquals(0, integer.get()); + } + + @Test + public void testReset() { + SnapshotRegistry registry = new SnapshotRegistry(new LogContext()); + TimelineInteger value = new TimelineInteger(registry); + registry.getOrCreateSnapshot(2); + value.set(1); + registry.getOrCreateSnapshot(3); + value.set(2); + + registry.reset(); + + assertEquals(Collections.emptyList(), registry.epochsList()); + assertEquals(TimelineInteger.INIT, value.get()); + } +} diff --git a/metadata/src/test/java/org/apache/kafka/timeline/TimelineLongTest.java b/metadata/src/test/java/org/apache/kafka/timeline/TimelineLongTest.java new file mode 100644 index 0000000..26412a5 --- /dev/null +++ b/metadata/src/test/java/org/apache/kafka/timeline/TimelineLongTest.java @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.timeline; + +import java.util.Collections; + +import org.apache.kafka.common.utils.LogContext; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; + +import static org.junit.jupiter.api.Assertions.assertEquals; + + +@Timeout(value = 40) +public class TimelineLongTest { + @Test + public void testModifyValue() { + SnapshotRegistry registry = new SnapshotRegistry(new LogContext()); + TimelineLong value = new TimelineLong(registry); + assertEquals(0L, value.get()); + assertEquals(0L, value.get(Long.MAX_VALUE)); + value.set(1L); + value.set(Long.MAX_VALUE); + assertEquals(Long.MAX_VALUE, value.get()); + assertEquals(Long.MAX_VALUE, value.get(Long.MAX_VALUE)); + } + + @Test + public void testToStringAndEquals() { + SnapshotRegistry registry = new SnapshotRegistry(new LogContext()); + TimelineLong value = new TimelineLong(registry); + assertEquals("0", value.toString()); + value.set(1L); + TimelineLong integer2 = new TimelineLong(registry); + integer2.set(1); + assertEquals("1", integer2.toString()); + assertEquals(value, integer2); + } + + @Test + public void testSnapshot() { + SnapshotRegistry registry = new SnapshotRegistry(new LogContext()); + TimelineLong value = new TimelineLong(registry); + registry.getOrCreateSnapshot(2); + value.set(1L); + registry.getOrCreateSnapshot(3); + value.set(2L); + value.increment(); + value.increment(); + value.decrement(); + registry.getOrCreateSnapshot(4); + assertEquals(0L, value.get(2)); + assertEquals(1L, value.get(3)); + assertEquals(3L, value.get(4)); + registry.revertToSnapshot(3); + assertEquals(1L, value.get()); + registry.revertToSnapshot(2); + assertEquals(0L, value.get()); + } + + @Test + public void testReset() { + SnapshotRegistry registry = new SnapshotRegistry(new LogContext()); + TimelineLong value = new TimelineLong(registry); + registry.getOrCreateSnapshot(2); + value.set(1L); + registry.getOrCreateSnapshot(3); + value.set(2L); + + registry.reset(); + + assertEquals(Collections.emptyList(), registry.epochsList()); + assertEquals(TimelineLong.INIT, value.get()); + } +} diff --git a/metadata/src/test/resources/log4j.properties b/metadata/src/test/resources/log4j.properties new file mode 100644 index 0000000..db38793 --- /dev/null +++ b/metadata/src/test/resources/log4j.properties @@ -0,0 +1,22 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+log4j.rootLogger=DEBUG, stdout + +log4j.appender.stdout=org.apache.log4j.ConsoleAppender +log4j.appender.stdout.layout=org.apache.log4j.PatternLayout +log4j.appender.stdout.layout.ConversionPattern=[%d] %p %m (%c:%L)%n + +log4j.logger.org.apache.kafka=DEBUG +log4j.logger.org.apache.zookeeper=WARN diff --git a/raft/README.md b/raft/README.md new file mode 100644 index 0000000..bc95c93 --- /dev/null +++ b/raft/README.md @@ -0,0 +1,51 @@ +KRaft (Kafka Raft) +================== +KRaft (Kafka Raft) is a protocol based on the [Raft Consensus Protocol](https://www.usenix.org/system/files/conference/atc14/atc14-paper-ongaro.pdf) +tailored for Apache Kafka. + +This is used by Apache Kafka in the [KRaft (Kafka Raft Metadata) mode](https://github.com/apache/kafka/blob/trunk/config/kraft/README.md). We +also have a standalone test server which can be used for performance testing. We describe the details to set this up below. + +### Run Single Quorum ### + bin/test-kraft-server-start.sh --config config/kraft.properties + +### Run Multi Node Quorum ### +Create 3 separate KRaft quorum properties as the following: + +`cat << EOF >> config/kraft-quorum-1.properties` + + node.id=1 + listeners=PLAINTEXT://localhost:9092 + controller.listener.names=PLAINTEXT + controller.quorum.voters=1@localhost:9092,2@localhost:9093,3@localhost:9094 + log.dirs=/tmp/kraft-logs-1 + EOF + +`cat << EOF >> config/kraft-quorum-2.properties` + + node.id=2 + listeners=PLAINTEXT://localhost:9093 + controller.listener.names=PLAINTEXT + controller.quorum.voters=1@localhost:9092,2@localhost:9093,3@localhost:9094 + log.dirs=/tmp/kraft-logs-2 + EOF + +`cat << EOF >> config/kraft-quorum-3.properties` + + node.id=3 + listeners=PLAINTEXT://localhost:9094 + controller.listener.names=PLAINTEXT + controller.quorum.voters=1@localhost:9092,2@localhost:9093,3@localhost:9094 + log.dirs=/tmp/kraft-logs-3 + EOF + +Open up 3 separate terminals, and run individual commands: + + bin/test-kraft-server-start.sh --config config/kraft-quorum-1.properties + bin/test-kraft-server-start.sh --config config/kraft-quorum-2.properties + bin/test-kraft-server-start.sh --config config/kraft-quorum-3.properties + +Once a leader is elected, it will begin writing to an internal +`__raft_performance_test` topic with a steady workload of random data. +You can control the workload using the `--throughput` and `--record-size` +arguments passed to `test-kraft-server-start.sh`. diff --git a/raft/bin/test-kraft-server-start.sh b/raft/bin/test-kraft-server-start.sh new file mode 100755 index 0000000..701bc18 --- /dev/null +++ b/raft/bin/test-kraft-server-start.sh @@ -0,0 +1,39 @@ +#!/bin/bash +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
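As an illustration of the `--throughput` and `--record-size` arguments mentioned at the end of the README above, a node could be started with an explicit workload, e.g. `bin/test-kraft-server-start.sh --config config/kraft-quorum-1.properties --throughput 5000 --record-size 256`; the numeric values here are arbitrary examples rather than recommended settings.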
+ +base_dir=$(dirname $0) + +if [ "x$KAFKA_LOG4J_OPTS" = "x" ]; then + export KAFKA_LOG4J_OPTS="-Dlog4j.configuration=file:$base_dir/../config/kraft-log4j.properties" +fi + +if [ "x$KAFKA_HEAP_OPTS" = "x" ]; then + export KAFKA_HEAP_OPTS="-Xmx1G -Xms1G" +fi + +EXTRA_ARGS=${EXTRA_ARGS-'-name kafkaServer -loggc'} + +COMMAND=$1 +case $COMMAND in + -daemon) + EXTRA_ARGS="-daemon "$EXTRA_ARGS + shift + ;; + *) + ;; +esac + +exec $base_dir/../../bin/kafka-run-class.sh $EXTRA_ARGS kafka.tools.TestRaftServer "$@" diff --git a/raft/config/kraft-log4j.properties b/raft/config/kraft-log4j.properties new file mode 100644 index 0000000..14f739a --- /dev/null +++ b/raft/config/kraft-log4j.properties @@ -0,0 +1,24 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +log4j.rootLogger=INFO, stderr + +log4j.appender.stderr=org.apache.log4j.ConsoleAppender +log4j.appender.stderr.layout=org.apache.log4j.PatternLayout +log4j.appender.stderr.layout.ConversionPattern=[%d] %p %m (%c)%n +log4j.appender.stderr.Target=System.err + +log4j.logger.org.apache.kafka.raft=INFO +log4j.logger.org.apache.kafka.snapshot=INFO diff --git a/raft/config/kraft.properties b/raft/config/kraft.properties new file mode 100644 index 0000000..a8556db --- /dev/null +++ b/raft/config/kraft.properties @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +node.id=0 +listeners=PLAINTEXT://localhost:9092 +controller.listener.names=PLAINTEXT +controller.quorum.voters=0@localhost:9092 +log.dirs=/tmp/kraft-logs diff --git a/raft/src/main/java/org/apache/kafka/raft/Batch.java b/raft/src/main/java/org/apache/kafka/raft/Batch.java new file mode 100644 index 0000000..685a758 --- /dev/null +++ b/raft/src/main/java/org/apache/kafka/raft/Batch.java @@ -0,0 +1,200 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.raft; + +import java.util.Collections; +import java.util.Iterator; +import java.util.List; +import java.util.Objects; + +/** + * A batch of records. + * + * This type contains a list of records `T` along with the information associated with those records. + */ +public final class Batch implements Iterable { + private final long baseOffset; + private final int epoch; + private final long appendTimestamp; + private final int sizeInBytes; + private final long lastOffset; + private final List records; + + private Batch( + long baseOffset, + int epoch, + long appendTimestamp, + int sizeInBytes, + long lastOffset, + List records + ) { + this.baseOffset = baseOffset; + this.epoch = epoch; + this.appendTimestamp = appendTimestamp; + this.sizeInBytes = sizeInBytes; + this.lastOffset = lastOffset; + this.records = records; + } + + /** + * The offset of the last record in the batch. + */ + public long lastOffset() { + return lastOffset; + } + + /** + * The offset of the first record in the batch. + */ + public long baseOffset() { + return baseOffset; + } + + /** + * The append timestamp in milliseconds of the batch. + */ + public long appendTimestamp() { + return appendTimestamp; + } + + /** + * The list of records in the batch. + */ + public List records() { + return records; + } + + /** + * The epoch of the leader that appended the record batch. + */ + public int epoch() { + return epoch; + } + + /** + * The number of bytes used by this batch. + */ + public int sizeInBytes() { + return sizeInBytes; + } + + @Override + public Iterator iterator() { + return records.iterator(); + } + + @Override + public String toString() { + return "Batch(" + + "baseOffset=" + baseOffset + + ", epoch=" + epoch + + ", appendTimestamp=" + appendTimestamp + + ", sizeInBytes=" + sizeInBytes + + ", lastOffset=" + lastOffset + + ", records=" + records + + ')'; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + Batch batch = (Batch) o; + return baseOffset == batch.baseOffset && + epoch == batch.epoch && + appendTimestamp == batch.appendTimestamp && + sizeInBytes == batch.sizeInBytes && + lastOffset == batch.lastOffset && + Objects.equals(records, batch.records); + } + + @Override + public int hashCode() { + return Objects.hash( + baseOffset, + epoch, + appendTimestamp, + sizeInBytes, + lastOffset, + records + ); + } + + /** + * Create a control batch without any data records. + * + * Internally this is used to propagate offset information for control batches which do not decode to the type T. 
+ * + * @param baseOffset offset of the batch + * @param epoch epoch of the leader that created this batch + * @param appendTimestamp timestamp in milliseconds of when the batch was appended + * @param sizeInBytes number of bytes used by this batch + * @param lastOffset offset of the last record of this batch + */ + public static Batch control( + long baseOffset, + int epoch, + long appendTimestamp, + int sizeInBytes, + long lastOffset + ) { + return new Batch<>( + baseOffset, + epoch, + appendTimestamp, + sizeInBytes, + lastOffset, + Collections.emptyList() + ); + } + + /** + * Create a data batch with the given base offset, epoch and records. + * + * @param baseOffset offset of the first record in the batch + * @param epoch epoch of the leader that created this batch + * @param appendTimestamp timestamp in milliseconds of when the batch was appended + * @param sizeInBytes number of bytes used by this batch + * @param records the list of records in this batch + */ + public static Batch data( + long baseOffset, + int epoch, + long appendTimestamp, + int sizeInBytes, + List records + ) { + if (records.isEmpty()) { + throw new IllegalArgumentException( + String.format( + "Batch must contain at least one record; baseOffset = %s; epoch = %s", + baseOffset, + epoch + ) + ); + } + + return new Batch<>( + baseOffset, + epoch, + appendTimestamp, + sizeInBytes, + baseOffset + records.size() - 1, + records + ); + } +} diff --git a/raft/src/main/java/org/apache/kafka/raft/BatchReader.java b/raft/src/main/java/org/apache/kafka/raft/BatchReader.java new file mode 100644 index 0000000..79e6614 --- /dev/null +++ b/raft/src/main/java/org/apache/kafka/raft/BatchReader.java @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.raft; + +import java.util.Iterator; +import java.util.OptionalLong; + +/** + * This interface is used to send committed data from the {@link RaftClient} + * down to registered {@link RaftClient.Listener} instances. + * + * The advantage of hiding the consumption of committed batches behind an interface + * is that it allows us to push blocking operations such as reads from disk outside + * of the Raft IO thread. This helps to ensure that a slow state machine will not + * affect replication. + * + * @param record type (see {@link org.apache.kafka.server.common.serialization.RecordSerde}) + */ +public interface BatchReader extends Iterator>, AutoCloseable { + + /** + * Get the base offset of the readable batches. Note that this value is a constant + * which is defined when the {@link BatchReader} instance is constructed. It does + * not change based on reader progress. + * + * @return the base offset + */ + long baseOffset(); + + /** + * Get the last offset of the batch if it is known. 
When reading from disk, we may + * not know the last offset of a set of records until it has been read from disk. + * In this case, the state machine cannot advance to the next committed data until + * all batches from the {@link BatchReader} instance have been consumed. + * + * @return optional last offset + */ + OptionalLong lastOffset(); + + /** + * Close this reader. It is the responsibility of the {@link RaftClient.Listener} + * to close each reader passed to {@link RaftClient.Listener#handleCommit(BatchReader)}. + */ + @Override + void close(); +} diff --git a/raft/src/main/java/org/apache/kafka/raft/CandidateState.java b/raft/src/main/java/org/apache/kafka/raft/CandidateState.java new file mode 100644 index 0000000..e9e1e0e --- /dev/null +++ b/raft/src/main/java/org/apache/kafka/raft/CandidateState.java @@ -0,0 +1,275 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.raft; + +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.common.utils.Timer; +import org.slf4j.Logger; + +import java.util.HashMap; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.stream.Collectors; + +public class CandidateState implements EpochState { + private final int localId; + private final int epoch; + private final int retries; + private final Map voteStates = new HashMap<>(); + private final Optional highWatermark; + private final int electionTimeoutMs; + private final Timer electionTimer; + private final Timer backoffTimer; + private final Logger log; + + /** + * The life time of a candidate state is the following: + * + * 1. Once started, it would keep record of the received votes. + * 2. If majority votes granted, it can then end its life and will be replaced by a leader state; + * 3. If majority votes rejected or election timed out, it would transit into a backing off phase; + * after the backoff phase completes, it would end its left and be replaced by a new candidate state with bumped retry. 
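The `Batch` and `BatchReader` contracts above are easiest to see in a short usage sketch. The helper below is not part of the patch: the class name, the `appliedRecords` parameter, and the concrete values in `exampleBatch` are illustrative assumptions. It only shows, under the javadoc contract quoted above, how a registered listener might drain a reader and why it must close it, and how a data batch's last offset follows from its base offset.

    import java.util.Arrays;
    import java.util.List;

    import org.apache.kafka.raft.Batch;
    import org.apache.kafka.raft.BatchReader;

    final class BatchReaderUsageSketch {

        // Drain every committed batch, hand each record to the state machine, and
        // then close the reader; per the javadoc, the listener owns that last step.
        static <T> void consumeCommitted(BatchReader<T> reader, List<T> appliedRecords) {
            try {
                while (reader.hasNext()) {             // BatchReader iterates over Batch<T>
                    Batch<T> batch = reader.next();
                    for (T record : batch) {           // Batch<T> is Iterable<T>
                        appliedRecords.add(record);    // stand-in for applying to the state machine
                    }
                }
            } finally {
                reader.close();
            }
        }

        // Offsets of a data batch are derived from the base offset and record count.
        static Batch<String> exampleBatch() {
            Batch<String> batch = Batch.data(100L, 3, 0L, 64, Arrays.asList("a", "b", "c"));
            assert batch.lastOffset() == 102L;         // baseOffset + records.size() - 1
            return batch;
        }
    }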
+ */ + private boolean isBackingOff; + + protected CandidateState( + Time time, + int localId, + int epoch, + Set voters, + Optional highWatermark, + int retries, + int electionTimeoutMs, + LogContext logContext + ) { + this.localId = localId; + this.epoch = epoch; + this.highWatermark = highWatermark; + this.retries = retries; + this.isBackingOff = false; + this.electionTimeoutMs = electionTimeoutMs; + this.electionTimer = time.timer(electionTimeoutMs); + this.backoffTimer = time.timer(0); + this.log = logContext.logger(CandidateState.class); + + for (Integer voterId : voters) { + voteStates.put(voterId, State.UNRECORDED); + } + voteStates.put(localId, State.GRANTED); + } + + public int localId() { + return localId; + } + + public int majoritySize() { + return voteStates.size() / 2 + 1; + } + + private long numGranted() { + return voteStates.values().stream().filter(state -> state == State.GRANTED).count(); + } + + private long numUnrecorded() { + return voteStates.values().stream().filter(state -> state == State.UNRECORDED).count(); + } + + /** + * Check if the candidate is backing off for the next election + */ + public boolean isBackingOff() { + return isBackingOff; + } + + public int retries() { + return retries; + } + + /** + * Check whether we have received enough votes to conclude the election and become leader. + * + * @return true if at least a majority of nodes have granted the vote + */ + public boolean isVoteGranted() { + return numGranted() >= majoritySize(); + } + + /** + * Check if we have received enough rejections that it is no longer possible to reach a + * majority of grants. + * + * @return true if the vote is rejected, false if the vote is already or can still be granted + */ + public boolean isVoteRejected() { + return numGranted() + numUnrecorded() < majoritySize(); + } + + /** + * Record a granted vote from one of the voters. + * + * @param remoteNodeId The id of the voter + * @return true if the voter had not been previously recorded + * @throws IllegalArgumentException if the remote node is not a voter or if the vote had already been + * rejected by this node + */ + public boolean recordGrantedVote(int remoteNodeId) { + State state = voteStates.get(remoteNodeId); + if (state == null) { + throw new IllegalArgumentException("Attempt to grant vote to non-voter " + remoteNodeId); + } else if (state == State.REJECTED) { + throw new IllegalArgumentException("Attempt to grant vote from node " + remoteNodeId + + " which previously rejected our request"); + } + return voteStates.put(remoteNodeId, State.GRANTED) == State.UNRECORDED; + } + + /** + * Record a rejected vote from one of the voters. 
+ * + * @param remoteNodeId The id of the voter + * @return true if the rejected vote had not been previously recorded + * @throws IllegalArgumentException if the remote node is not a voter or if the vote had already been + * granted by this node + */ + public boolean recordRejectedVote(int remoteNodeId) { + State state = voteStates.get(remoteNodeId); + if (state == null) { + throw new IllegalArgumentException("Attempt to reject vote to non-voter " + remoteNodeId); + } else if (state == State.GRANTED) { + throw new IllegalArgumentException("Attempt to reject vote from node " + remoteNodeId + + " which previously granted our request"); + } + + return voteStates.put(remoteNodeId, State.REJECTED) == State.UNRECORDED; + } + + /** + * Record the current election has failed since we've either received sufficient rejecting voters or election timed out + */ + public void startBackingOff(long currentTimeMs, long backoffDurationMs) { + this.backoffTimer.update(currentTimeMs); + this.backoffTimer.reset(backoffDurationMs); + this.isBackingOff = true; + } + + /** + * Get the set of voters which have not been counted as granted or rejected yet. + * + * @return The set of unrecorded voters + */ + public Set unrecordedVoters() { + return votersInState(State.UNRECORDED); + } + + /** + * Get the set of voters that have granted our vote requests. + * + * @return The set of granting voters, which should always contain the ID of the candidate + */ + public Set grantingVoters() { + return votersInState(State.GRANTED); + } + + /** + * Get the set of voters that have rejected our candidacy. + * + * @return The set of rejecting voters + */ + public Set rejectingVoters() { + return votersInState(State.REJECTED); + } + + private Set votersInState(State state) { + return voteStates.entrySet().stream() + .filter(entry -> entry.getValue() == state) + .map(Map.Entry::getKey) + .collect(Collectors.toSet()); + } + + public boolean hasElectionTimeoutExpired(long currentTimeMs) { + electionTimer.update(currentTimeMs); + return electionTimer.isExpired(); + } + + public boolean isBackoffComplete(long currentTimeMs) { + backoffTimer.update(currentTimeMs); + return backoffTimer.isExpired(); + } + + public long remainingBackoffMs(long currentTimeMs) { + if (!isBackingOff) { + throw new IllegalStateException("Candidate is not currently backing off"); + } + backoffTimer.update(currentTimeMs); + return backoffTimer.remainingMs(); + } + + public long remainingElectionTimeMs(long currentTimeMs) { + electionTimer.update(currentTimeMs); + return electionTimer.remainingMs(); + } + + @Override + public ElectionState election() { + return ElectionState.withVotedCandidate(epoch, localId, voteStates.keySet()); + } + + @Override + public int epoch() { + return epoch; + } + + @Override + public Optional highWatermark() { + return highWatermark; + } + + @Override + public boolean canGrantVote(int candidateId, boolean isLogUpToDate) { + // Still reject vote request even candidateId = localId, Although the candidate votes for + // itself, this vote is implicit and not "granted". 
+ log.debug("Rejecting vote request from candidate {} since we are already candidate in epoch {}", + candidateId, epoch); + return false; + } + + @Override + public String toString() { + return "CandidateState(" + + "localId=" + localId + + ", epoch=" + epoch + + ", retries=" + retries + + ", electionTimeoutMs=" + electionTimeoutMs + + ')'; + } + + @Override + public String name() { + return "Candidate"; + } + + @Override + public void close() {} + + private enum State { + UNRECORDED, + GRANTED, + REJECTED + } +} diff --git a/raft/src/main/java/org/apache/kafka/raft/ElectionState.java b/raft/src/main/java/org/apache/kafka/raft/ElectionState.java new file mode 100644 index 0000000..5c6372e --- /dev/null +++ b/raft/src/main/java/org/apache/kafka/raft/ElectionState.java @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.raft; + +import java.util.OptionalInt; +import java.util.Set; + +/** + * Encapsulate election state stored on disk after every state change. + */ +public class ElectionState { + public final int epoch; + public final OptionalInt leaderIdOpt; + public final OptionalInt votedIdOpt; + private final Set voters; + + ElectionState(int epoch, + OptionalInt leaderIdOpt, + OptionalInt votedIdOpt, + Set voters) { + this.epoch = epoch; + this.leaderIdOpt = leaderIdOpt; + this.votedIdOpt = votedIdOpt; + this.voters = voters; + } + + public static ElectionState withVotedCandidate(int epoch, int votedId, Set voters) { + if (votedId < 0) + throw new IllegalArgumentException("Illegal voted Id " + votedId + ": must be non-negative"); + if (!voters.contains(votedId)) + throw new IllegalArgumentException("Voted candidate with id " + votedId + " is not among the valid voters"); + return new ElectionState(epoch, OptionalInt.empty(), OptionalInt.of(votedId), voters); + } + + public static ElectionState withElectedLeader(int epoch, int leaderId, Set voters) { + if (leaderId < 0) + throw new IllegalArgumentException("Illegal leader Id " + leaderId + ": must be non-negative"); + if (!voters.contains(leaderId)) + throw new IllegalArgumentException("Leader with id " + leaderId + " is not among the valid voters"); + return new ElectionState(epoch, OptionalInt.of(leaderId), OptionalInt.empty(), voters); + } + + public static ElectionState withUnknownLeader(int epoch, Set voters) { + return new ElectionState(epoch, OptionalInt.empty(), OptionalInt.empty(), voters); + } + + public boolean isLeader(int nodeId) { + if (nodeId < 0) + throw new IllegalArgumentException("Invalid negative nodeId: " + nodeId); + return leaderIdOpt.orElse(-1) == nodeId; + } + + public boolean isVotedCandidate(int nodeId) { + if (nodeId < 0) + throw new IllegalArgumentException("Invalid negative nodeId: " + nodeId); + return 
votedIdOpt.orElse(-1) == nodeId; + } + + public int leaderId() { + if (!leaderIdOpt.isPresent()) + throw new IllegalStateException("Attempt to access nil leaderId"); + return leaderIdOpt.getAsInt(); + } + + public int votedId() { + if (!votedIdOpt.isPresent()) + throw new IllegalStateException("Attempt to access nil votedId"); + return votedIdOpt.getAsInt(); + } + + public Set voters() { + return voters; + } + + public boolean hasLeader() { + return leaderIdOpt.isPresent(); + } + + public boolean hasVoted() { + return votedIdOpt.isPresent(); + } + + + @Override + public String toString() { + return "Election(epoch=" + epoch + + ", leaderIdOpt=" + leaderIdOpt + + ", votedIdOpt=" + votedIdOpt + + ')'; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + ElectionState that = (ElectionState) o; + + if (epoch != that.epoch) return false; + if (!leaderIdOpt.equals(that.leaderIdOpt)) return false; + return votedIdOpt.equals(that.votedIdOpt); + } + + @Override + public int hashCode() { + int result = epoch; + result = 31 * result + leaderIdOpt.hashCode(); + result = 31 * result + votedIdOpt.hashCode(); + return result; + } +} diff --git a/raft/src/main/java/org/apache/kafka/raft/EpochState.java b/raft/src/main/java/org/apache/kafka/raft/EpochState.java new file mode 100644 index 0000000..89e8f0a --- /dev/null +++ b/raft/src/main/java/org/apache/kafka/raft/EpochState.java @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.raft; + +import java.io.Closeable; +import java.util.Optional; + +public interface EpochState extends Closeable { + + default Optional highWatermark() { + return Optional.empty(); + } + + /** + * Decide whether to grant a vote to a candidate, it is the responsibility of the caller to invoke + * {@link QuorumState##transitionToVoted(int, int)} if vote is granted. + * + * @param candidateId The ID of the voter who attempt to become leader + * @param isLogUpToDate Whether the candidate’s log is at least as up-to-date as receiver’s log, it + * is the responsibility of the caller to compare the log in advance + * @return true If grant vote. + */ + boolean canGrantVote(int candidateId, boolean isLogUpToDate); + + /** + * Get the current election state, which is guaranteed to be immutable. + */ + ElectionState election(); + + /** + * Get the current (immutable) epoch. 
+ */ + int epoch(); + + /** + * User-friendly description of the state + */ + String name(); + +} diff --git a/raft/src/main/java/org/apache/kafka/raft/ExpirationService.java b/raft/src/main/java/org/apache/kafka/raft/ExpirationService.java new file mode 100644 index 0000000..e8e8b02 --- /dev/null +++ b/raft/src/main/java/org/apache/kafka/raft/ExpirationService.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.raft; + +import java.util.concurrent.CompletableFuture; + +public interface ExpirationService { + /** + * Get a new completable future which will automatically fail exceptionally with a + * {@link org.apache.kafka.common.errors.TimeoutException} if not completed before + * the provided time limit expires. + * + * @param timeoutMs the duration in milliseconds before the future is completed exceptionally + * @param arbitrary future type (the service must set no expectation on the this type) + * @return the completable future + */ + CompletableFuture failAfter(long timeoutMs); +} diff --git a/raft/src/main/java/org/apache/kafka/raft/FileBasedStateStore.java b/raft/src/main/java/org/apache/kafka/raft/FileBasedStateStore.java new file mode 100644 index 0000000..e403613 --- /dev/null +++ b/raft/src/main/java/org/apache/kafka/raft/FileBasedStateStore.java @@ -0,0 +1,185 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
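Since `ExpirationService` is just a contract, a toy implementation may help make it concrete. The sketch below is not how this patch provides expiration; it assumes `failAfter` is a generic method (as the "arbitrary future type" parameter in the javadoc suggests) and simply backs it with a single-threaded scheduler, with shutdown handling omitted. It demonstrates the promised behaviour: the returned future is completed exceptionally with a `TimeoutException` once the deadline passes, unless something completes it first.

    import java.util.concurrent.CompletableFuture;
    import java.util.concurrent.Executors;
    import java.util.concurrent.ScheduledExecutorService;
    import java.util.concurrent.TimeUnit;

    import org.apache.kafka.common.errors.TimeoutException;
    import org.apache.kafka.raft.ExpirationService;

    // Toy ExpirationService: schedule an exceptional completion at the deadline.
    // If the caller completes the future first, the later completeExceptionally
    // call is a no-op, which is consistent with the contract described above.
    final class SchedulerExpirationService implements ExpirationService {
        private final ScheduledExecutorService scheduler =
            Executors.newSingleThreadScheduledExecutor();

        @Override
        public <T> CompletableFuture<T> failAfter(long timeoutMs) {
            CompletableFuture<T> future = new CompletableFuture<>();
            scheduler.schedule(() -> {
                future.completeExceptionally(new TimeoutException(
                    "Future failed to be completed before timeout of " + timeoutMs + " ms expired"));
            }, timeoutMs, TimeUnit.MILLISECONDS);
            return future;
        }
    }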
+ */ +package org.apache.kafka.raft; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.node.ObjectNode; +import com.fasterxml.jackson.databind.node.ShortNode; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.raft.generated.QuorumStateData; +import org.apache.kafka.raft.generated.QuorumStateData.Voter; +import org.apache.kafka.raft.generated.QuorumStateDataJsonConverter; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.BufferedReader; +import java.io.BufferedWriter; +import java.io.UncheckedIOException; +import java.io.EOFException; +import java.io.File; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.OutputStreamWriter; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.util.List; +import java.util.OptionalInt; +import java.util.Set; +import java.util.stream.Collectors; + +/** + * Local file based quorum state store. It takes the JSON format of {@link QuorumStateData} + * with an extra data version number as part of the data for easy deserialization. + * + * Example format: + *
                + * {"clusterId":"",
                + *   "leaderId":1,
                + *   "leaderEpoch":2,
                + *   "votedId":-1,
                + *   "appliedOffset":0,
                + *   "currentVoters":[],
                + *   "data_version":0}
                + * 
                + * */ +public class FileBasedStateStore implements QuorumStateStore { + private static final Logger log = LoggerFactory.getLogger(FileBasedStateStore.class); + + private final File stateFile; + + static final String DATA_VERSION = "data_version"; + + public FileBasedStateStore(final File stateFile) { + this.stateFile = stateFile; + } + + private QuorumStateData readStateFromFile(File file) { + try (final BufferedReader reader = Files.newBufferedReader(file.toPath())) { + final String line = reader.readLine(); + if (line == null) { + throw new EOFException("File ended prematurely."); + } + + final ObjectMapper objectMapper = new ObjectMapper(); + JsonNode readNode = objectMapper.readTree(line); + + if (!(readNode instanceof ObjectNode)) { + throw new IOException("Deserialized node " + readNode + + " is not an object node"); + } + final ObjectNode dataObject = (ObjectNode) readNode; + + JsonNode dataVersionNode = dataObject.get(DATA_VERSION); + if (dataVersionNode == null) { + throw new IOException("Deserialized node " + readNode + + " does not have " + DATA_VERSION + " field"); + } + + final short dataVersion = dataVersionNode.shortValue(); + return QuorumStateDataJsonConverter.read(dataObject, dataVersion); + } catch (IOException e) { + throw new UncheckedIOException( + String.format("Error while reading the Quorum status from the file %s", file), e); + } + } + + /** + * Reads the election state from local file. + */ + @Override + public ElectionState readElectionState() { + if (!stateFile.exists()) { + return null; + } + + QuorumStateData data = readStateFromFile(stateFile); + + return new ElectionState(data.leaderEpoch(), + data.leaderId() == UNKNOWN_LEADER_ID ? OptionalInt.empty() : + OptionalInt.of(data.leaderId()), + data.votedId() == NOT_VOTED ? OptionalInt.empty() : + OptionalInt.of(data.votedId()), + data.currentVoters() + .stream().map(Voter::voterId).collect(Collectors.toSet())); + } + + @Override + public void writeElectionState(ElectionState latest) { + QuorumStateData data = new QuorumStateData() + .setLeaderEpoch(latest.epoch) + .setVotedId(latest.hasVoted() ? latest.votedId() : NOT_VOTED) + .setLeaderId(latest.hasLeader() ? 
latest.leaderId() : UNKNOWN_LEADER_ID) + .setCurrentVoters(voters(latest.voters())); + writeElectionStateToFile(stateFile, data); + } + + private List voters(Set votersId) { + return votersId.stream().map( + voterId -> new Voter().setVoterId(voterId)).collect(Collectors.toList()); + } + + private void writeElectionStateToFile(final File stateFile, QuorumStateData state) { + final File temp = new File(stateFile.getAbsolutePath() + ".tmp"); + deleteFileIfExists(temp); + + log.trace("Writing tmp quorum state {}", temp.getAbsolutePath()); + + try (final FileOutputStream fileOutputStream = new FileOutputStream(temp); + final BufferedWriter writer = new BufferedWriter( + new OutputStreamWriter(fileOutputStream, StandardCharsets.UTF_8))) { + short version = state.highestSupportedVersion(); + + ObjectNode jsonState = (ObjectNode) QuorumStateDataJsonConverter.write(state, version); + jsonState.set(DATA_VERSION, new ShortNode(version)); + writer.write(jsonState.toString()); + writer.flush(); + fileOutputStream.getFD().sync(); + Utils.atomicMoveWithFallback(temp.toPath(), stateFile.toPath()); + } catch (IOException e) { + throw new UncheckedIOException( + String.format("Error while writing the Quorum status from the file %s", + stateFile.getAbsolutePath()), e); + } finally { + // cleanup the temp file when the write finishes (either success or fail). + deleteFileIfExists(temp); + } + } + + /** + * Clear state store by deleting the local quorum state file + */ + @Override + public void clear() { + deleteFileIfExists(stateFile); + deleteFileIfExists(new File(stateFile.getAbsolutePath() + ".tmp")); + } + + @Override + public String toString() { + return "Quorum state filepath: " + stateFile.getAbsolutePath(); + } + + private void deleteFileIfExists(File file) { + try { + Files.deleteIfExists(file.toPath()); + } catch (IOException e) { + throw new UncheckedIOException( + String.format("Error while deleting file %s", file.getAbsoluteFile()), e); + } + } +} diff --git a/raft/src/main/java/org/apache/kafka/raft/FollowerState.java b/raft/src/main/java/org/apache/kafka/raft/FollowerState.java new file mode 100644 index 0000000..e3a3047 --- /dev/null +++ b/raft/src/main/java/org/apache/kafka/raft/FollowerState.java @@ -0,0 +1,172 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
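Because `FileBasedStateStore` and the `ElectionState` factories both appear in this patch, the persistence contract can be exercised end to end. The snippet below is purely illustrative: the file path, epoch, and node ids are made-up values, and the wrapping class name is hypothetical.

    import java.io.File;
    import java.util.HashSet;
    import java.util.Set;

    import org.apache.kafka.raft.ElectionState;
    import org.apache.kafka.raft.FileBasedStateStore;

    public final class QuorumStateFileExample {
        public static void main(String[] args) {
            Set<Integer> voters = new HashSet<>();
            voters.add(1);
            voters.add(2);
            voters.add(3);

            // Hypothetical location; a real broker derives this from its log directory.
            FileBasedStateStore store = new FileBasedStateStore(new File("/tmp/quorum-state"));

            // Persist "node 1 is the elected leader of epoch 2", then read it back.
            store.writeElectionState(ElectionState.withElectedLeader(2, 1, voters));
            ElectionState restored = store.readElectionState();

            System.out.println(restored);              // Election(epoch=2, leaderIdOpt=OptionalInt[1], ...)
            System.out.println(restored.isLeader(1));  // true
        }
    }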
+ */ +package org.apache.kafka.raft; + +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.common.utils.Timer; +import org.apache.kafka.snapshot.RawSnapshotWriter; +import org.slf4j.Logger; + +import java.util.Optional; +import java.util.OptionalInt; +import java.util.OptionalLong; +import java.util.Set; + +public class FollowerState implements EpochState { + private final int fetchTimeoutMs; + private final int epoch; + private final int leaderId; + private final Set voters; + // Used for tracking the expiration of both the Fetch and FetchSnapshot requests + private final Timer fetchTimer; + private Optional highWatermark; + /* Used to track the currently fetching snapshot. When fetching snapshot regular + * Fetch request are paused + */ + private Optional fetchingSnapshot; + + private final Logger log; + + public FollowerState( + Time time, + int epoch, + int leaderId, + Set voters, + Optional highWatermark, + int fetchTimeoutMs, + LogContext logContext + ) { + this.fetchTimeoutMs = fetchTimeoutMs; + this.epoch = epoch; + this.leaderId = leaderId; + this.voters = voters; + this.fetchTimer = time.timer(fetchTimeoutMs); + this.highWatermark = highWatermark; + this.fetchingSnapshot = Optional.empty(); + this.log = logContext.logger(FollowerState.class); + } + + @Override + public ElectionState election() { + return new ElectionState( + epoch, + OptionalInt.of(leaderId), + OptionalInt.empty(), + voters + ); + } + + @Override + public int epoch() { + return epoch; + } + + @Override + public String name() { + return "Follower"; + } + + public long remainingFetchTimeMs(long currentTimeMs) { + fetchTimer.update(currentTimeMs); + return fetchTimer.remainingMs(); + } + + public int leaderId() { + return leaderId; + } + + public boolean hasFetchTimeoutExpired(long currentTimeMs) { + fetchTimer.update(currentTimeMs); + return fetchTimer.isExpired(); + } + + public void resetFetchTimeout(long currentTimeMs) { + fetchTimer.update(currentTimeMs); + fetchTimer.reset(fetchTimeoutMs); + } + + public void overrideFetchTimeout(long currentTimeMs, long timeoutMs) { + fetchTimer.update(currentTimeMs); + fetchTimer.reset(timeoutMs); + } + + public boolean updateHighWatermark(OptionalLong highWatermark) { + if (!highWatermark.isPresent() && this.highWatermark.isPresent()) + throw new IllegalArgumentException("Attempt to overwrite current high watermark " + this.highWatermark + + " with unknown value"); + + if (this.highWatermark.isPresent()) { + long previousHighWatermark = this.highWatermark.get().offset; + long updatedHighWatermark = highWatermark.getAsLong(); + + if (updatedHighWatermark < 0) + throw new IllegalArgumentException("Illegal negative high watermark update"); + if (previousHighWatermark > updatedHighWatermark) + throw new IllegalArgumentException("Non-monotonic update of high watermark attempted"); + if (previousHighWatermark == updatedHighWatermark) + return false; + } + + this.highWatermark = highWatermark.isPresent() ? 
+ Optional.of(new LogOffsetMetadata(highWatermark.getAsLong())) : + Optional.empty(); + return true; + } + + @Override + public Optional highWatermark() { + return highWatermark; + } + + public Optional fetchingSnapshot() { + return fetchingSnapshot; + } + + public void setFetchingSnapshot(Optional newSnapshot) { + if (fetchingSnapshot.isPresent()) { + fetchingSnapshot.get().close(); + } + fetchingSnapshot = newSnapshot; + } + + @Override + public boolean canGrantVote(int candidateId, boolean isLogUpToDate) { + log.debug("Rejecting vote request from candidate {} since we already have a leader {} in epoch {}", + candidateId, leaderId(), epoch); + return false; + } + + @Override + public String toString() { + return "FollowerState(" + + "fetchTimeoutMs=" + fetchTimeoutMs + + ", epoch=" + epoch + + ", leaderId=" + leaderId + + ", voters=" + voters + + ", highWatermark=" + highWatermark + + ", fetchingSnapshot=" + fetchingSnapshot + + ')'; + } + + @Override + public void close() { + if (fetchingSnapshot.isPresent()) { + fetchingSnapshot.get().close(); + } + } +} diff --git a/raft/src/main/java/org/apache/kafka/raft/Isolation.java b/raft/src/main/java/org/apache/kafka/raft/Isolation.java new file mode 100644 index 0000000..a87d010 --- /dev/null +++ b/raft/src/main/java/org/apache/kafka/raft/Isolation.java @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.raft; + +public enum Isolation { + COMMITTED, + UNCOMMITTED +} diff --git a/raft/src/main/java/org/apache/kafka/raft/KafkaRaftClient.java b/raft/src/main/java/org/apache/kafka/raft/KafkaRaftClient.java new file mode 100644 index 0000000..24acb5e --- /dev/null +++ b/raft/src/main/java/org/apache/kafka/raft/KafkaRaftClient.java @@ -0,0 +1,2603 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
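One property of `FollowerState` above worth highlighting is that `updateHighWatermark` only lets the follower's high watermark move forward, and reports whether anything actually changed. The sketch below is illustrative only: the constructor arguments (epoch, leader id, voter set, timeout) are arbitrary example values, and the wrapping class is hypothetical.

    import java.util.Collections;
    import java.util.Optional;
    import java.util.OptionalLong;

    import org.apache.kafka.common.utils.LogContext;
    import org.apache.kafka.common.utils.Time;
    import org.apache.kafka.raft.FollowerState;

    public final class FollowerHighWatermarkExample {
        public static void main(String[] args) {
            FollowerState state = new FollowerState(
                Time.SYSTEM,
                5,                          // epoch
                1,                          // leaderId
                Collections.singleton(1),   // voters
                Optional.empty(),           // high watermark not yet known
                30_000,                     // fetchTimeoutMs
                new LogContext()
            );

            System.out.println(state.updateHighWatermark(OptionalLong.of(10L))); // true: first update
            System.out.println(state.updateHighWatermark(OptionalLong.of(10L))); // false: unchanged
            // updateHighWatermark(OptionalLong.of(5L)) would throw IllegalArgumentException,
            // because the high watermark must never move backwards; likewise, once a value
            // is known it cannot be overwritten with OptionalLong.empty().
        }
    }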
+ */ +package org.apache.kafka.raft; + +import org.apache.kafka.common.KafkaException; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.errors.ClusterAuthorizationException; +import org.apache.kafka.common.errors.NotLeaderOrFollowerException; +import org.apache.kafka.common.memory.MemoryPool; +import org.apache.kafka.common.message.BeginQuorumEpochRequestData; +import org.apache.kafka.common.message.BeginQuorumEpochResponseData; +import org.apache.kafka.common.message.DescribeQuorumRequestData; +import org.apache.kafka.common.message.DescribeQuorumResponseData; +import org.apache.kafka.common.message.DescribeQuorumResponseData.ReplicaState; +import org.apache.kafka.common.message.EndQuorumEpochRequestData; +import org.apache.kafka.common.message.EndQuorumEpochResponseData; +import org.apache.kafka.common.message.FetchRequestData; +import org.apache.kafka.common.message.FetchResponseData; +import org.apache.kafka.common.message.FetchSnapshotRequestData; +import org.apache.kafka.common.message.FetchSnapshotResponseData; +import org.apache.kafka.common.message.VoteRequestData; +import org.apache.kafka.common.message.VoteResponseData; +import org.apache.kafka.common.metrics.Metrics; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ApiMessage; +import org.apache.kafka.common.protocol.Errors; +import org.apache.kafka.common.record.CompressionType; +import org.apache.kafka.common.record.MemoryRecords; +import org.apache.kafka.common.record.Records; +import org.apache.kafka.common.record.UnalignedMemoryRecords; +import org.apache.kafka.common.record.UnalignedRecords; +import org.apache.kafka.common.requests.BeginQuorumEpochRequest; +import org.apache.kafka.common.requests.BeginQuorumEpochResponse; +import org.apache.kafka.common.requests.DescribeQuorumRequest; +import org.apache.kafka.common.requests.DescribeQuorumResponse; +import org.apache.kafka.common.requests.EndQuorumEpochRequest; +import org.apache.kafka.common.requests.EndQuorumEpochResponse; +import org.apache.kafka.common.requests.FetchResponse; +import org.apache.kafka.common.requests.FetchSnapshotRequest; +import org.apache.kafka.common.requests.FetchSnapshotResponse; +import org.apache.kafka.common.requests.VoteRequest; +import org.apache.kafka.common.requests.VoteResponse; +import org.apache.kafka.common.utils.BufferSupplier; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.common.utils.Timer; +import org.apache.kafka.raft.RequestManager.ConnectionState; +import org.apache.kafka.raft.errors.NotLeaderException; +import org.apache.kafka.raft.internals.BatchAccumulator; +import org.apache.kafka.raft.internals.BatchMemoryPool; +import org.apache.kafka.raft.internals.BlockingMessageQueue; +import org.apache.kafka.raft.internals.CloseListener; +import org.apache.kafka.raft.internals.FuturePurgatory; +import org.apache.kafka.raft.internals.KafkaRaftMetrics; +import org.apache.kafka.raft.internals.MemoryBatchReader; +import org.apache.kafka.raft.internals.RecordsBatchReader; +import org.apache.kafka.raft.internals.ThresholdPurgatory; +import org.apache.kafka.server.common.serialization.RecordSerde; +import org.apache.kafka.snapshot.RawSnapshotReader; +import org.apache.kafka.snapshot.RawSnapshotWriter; +import org.apache.kafka.snapshot.SnapshotReader; +import org.apache.kafka.snapshot.SnapshotWriter; +import org.slf4j.Logger; + +import java.util.Collections; +import java.util.IdentityHashMap; +import 
java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.OptionalInt; +import java.util.OptionalLong; +import java.util.Random; +import java.util.Set; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeoutException; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Supplier; +import java.util.stream.Collectors; + +import static java.util.concurrent.CompletableFuture.completedFuture; +import static org.apache.kafka.raft.RaftUtil.hasValidTopicPartition; + +/** + * This class implements a Kafkaesque version of the Raft protocol. Leader election + * is more or less pure Raft, but replication is driven by replica fetching and we use Kafka's + * log reconciliation protocol to truncate the log to a common point following each leader + * election. + * + * Like Zookeeper, this protocol distinguishes between voters and observers. Voters are + * the only ones who are eligible to handle protocol requests and they are the only ones + * who take part in elections. The protocol does not yet support dynamic quorum changes. + * + * These are the APIs in this protocol: + * + * 1) {@link VoteRequestData}: Sent by valid voters when their election timeout expires and they + * become a candidate. This request includes the last offset in the log which electors use + * to tell whether or not to grant the vote. + * + * 2) {@link BeginQuorumEpochRequestData}: Sent by the leader of an epoch only to valid voters to + * assert its leadership of the new epoch. This request will be retried indefinitely for + * each voter until it acknowledges the request or a new election occurs. + * + * This is not needed in usual Raft because the leader can use an empty data push + * to achieve the same purpose. The Kafka Raft implementation, however, is driven by + * fetch requests from followers, so there must be a way to find the new leader after + * an election has completed. + * + * 3) {@link EndQuorumEpochRequestData}: Sent by the leader of an epoch to valid voters in order to + * gracefully resign from the current epoch. This causes remaining voters to immediately + * begin a new election. + * + * 4) {@link FetchRequestData}: This is the same as the usual Fetch API in Kafka, but we add snapshot + * check before responding, and we also piggyback some additional metadata on responses (i.e. current + * leader and epoch). Unlike partition replication, we also piggyback truncation detection on this API + * rather than through a separate truncation state. + * + * 5) {@link FetchSnapshotRequestData}: Sent by the follower to the epoch leader in order to fetch a snapshot. + * This happens when a FetchResponse includes a snapshot ID due to the follower's log end offset being less + * than the leader's log start offset. This API is similar to the Fetch API since the snapshot is stored + * as FileRecords, but we use {@link UnalignedRecords} in FetchSnapshotResponse because the records + * are not necessarily offset-aligned. 
+ */ +public class KafkaRaftClient implements RaftClient { + private static final int RETRY_BACKOFF_BASE_MS = 100; + public static final int MAX_FETCH_WAIT_MS = 500; + public static final int MAX_BATCH_SIZE_BYTES = 8 * 1024 * 1024; + public static final int MAX_FETCH_SIZE_BYTES = MAX_BATCH_SIZE_BYTES; + + private final AtomicReference shutdown = new AtomicReference<>(); + private final Logger logger; + private final Time time; + private final int fetchMaxWaitMs; + private final String clusterId; + private final NetworkChannel channel; + private final ReplicatedLog log; + private final Random random; + private final FuturePurgatory appendPurgatory; + private final FuturePurgatory fetchPurgatory; + private final RecordSerde serde; + private final MemoryPool memoryPool; + private final RaftMessageQueue messageQueue; + private final RaftConfig raftConfig; + private final KafkaRaftMetrics kafkaRaftMetrics; + private final QuorumState quorum; + private final RequestManager requestManager; + private final RaftMetadataLogCleanerManager snapshotCleaner; + + private final Map, ListenerContext> listenerContexts = new IdentityHashMap<>(); + private final ConcurrentLinkedQueue> pendingRegistrations = new ConcurrentLinkedQueue<>(); + + /** + * Create a new instance. + * + * Note that if the node ID is empty, then the client will behave as a + * non-participating observer. + */ + public KafkaRaftClient( + RecordSerde serde, + NetworkChannel channel, + ReplicatedLog log, + QuorumStateStore quorumStateStore, + Time time, + Metrics metrics, + ExpirationService expirationService, + LogContext logContext, + String clusterId, + OptionalInt nodeId, + RaftConfig raftConfig + ) { + this(serde, + channel, + new BlockingMessageQueue(), + log, + quorumStateStore, + new BatchMemoryPool(5, MAX_BATCH_SIZE_BYTES), + time, + metrics, + expirationService, + MAX_FETCH_WAIT_MS, + clusterId, + nodeId, + logContext, + new Random(), + raftConfig); + } + + KafkaRaftClient( + RecordSerde serde, + NetworkChannel channel, + RaftMessageQueue messageQueue, + ReplicatedLog log, + QuorumStateStore quorumStateStore, + MemoryPool memoryPool, + Time time, + Metrics metrics, + ExpirationService expirationService, + int fetchMaxWaitMs, + String clusterId, + OptionalInt nodeId, + LogContext logContext, + Random random, + RaftConfig raftConfig + ) { + this.serde = serde; + this.channel = channel; + this.messageQueue = messageQueue; + this.log = log; + this.memoryPool = memoryPool; + this.fetchPurgatory = new ThresholdPurgatory<>(expirationService); + this.appendPurgatory = new ThresholdPurgatory<>(expirationService); + this.time = time; + this.clusterId = clusterId; + this.fetchMaxWaitMs = fetchMaxWaitMs; + this.logger = logContext.logger(KafkaRaftClient.class); + this.random = random; + this.raftConfig = raftConfig; + this.snapshotCleaner = new RaftMetadataLogCleanerManager(logger, time, 60000, log::maybeClean); + Set quorumVoterIds = raftConfig.quorumVoterIds(); + this.requestManager = new RequestManager(quorumVoterIds, raftConfig.retryBackoffMs(), + raftConfig.requestTimeoutMs(), random); + this.quorum = new QuorumState( + nodeId, + quorumVoterIds, + raftConfig.electionTimeoutMs(), + raftConfig.fetchTimeoutMs(), + quorumStateStore, + time, + logContext, + random); + this.kafkaRaftMetrics = new KafkaRaftMetrics(metrics, "raft", quorum); + kafkaRaftMetrics.updateNumUnknownVoterConnections(quorum.remoteVoters().size()); + + // Update the voter endpoints with what's in RaftConfig + Map voterAddresses = raftConfig.quorumVoterConnections(); + 
voterAddresses.entrySet().stream() + .filter(e -> e.getValue() instanceof RaftConfig.InetAddressSpec) + .forEach(e -> this.channel.updateEndpoint(e.getKey(), (RaftConfig.InetAddressSpec) e.getValue())); + } + + private void updateFollowerHighWatermark( + FollowerState state, + OptionalLong highWatermarkOpt + ) { + highWatermarkOpt.ifPresent(highWatermark -> { + long newHighWatermark = Math.min(endOffset().offset, highWatermark); + if (state.updateHighWatermark(OptionalLong.of(newHighWatermark))) { + logger.debug("Follower high watermark updated to {}", newHighWatermark); + log.updateHighWatermark(new LogOffsetMetadata(newHighWatermark)); + updateListenersProgress(newHighWatermark); + } + }); + } + + private void updateLeaderEndOffsetAndTimestamp( + LeaderState state, + long currentTimeMs + ) { + final LogOffsetMetadata endOffsetMetadata = log.endOffset(); + + if (state.updateLocalState(currentTimeMs, endOffsetMetadata)) { + onUpdateLeaderHighWatermark(state, currentTimeMs); + } + + fetchPurgatory.maybeComplete(endOffsetMetadata.offset, currentTimeMs); + } + + private void onUpdateLeaderHighWatermark( + LeaderState state, + long currentTimeMs + ) { + state.highWatermark().ifPresent(highWatermark -> { + logger.debug("Leader high watermark updated to {}", highWatermark); + log.updateHighWatermark(highWatermark); + + // After updating the high watermark, we first clear the append + // purgatory so that we have an opportunity to route the pending + // records still held in memory directly to the listener + appendPurgatory.maybeComplete(highWatermark.offset, currentTimeMs); + + // It is also possible that the high watermark is being updated + // for the first time following the leader election, so we need + // to give lagging listeners an opportunity to catch up as well + updateListenersProgress(highWatermark.offset); + }); + } + + private void updateListenersProgress(long highWatermark) { + for (ListenerContext listenerContext : listenerContexts.values()) { + listenerContext.nextExpectedOffset().ifPresent(nextExpectedOffset -> { + if (nextExpectedOffset < log.startOffset() && nextExpectedOffset < highWatermark) { + SnapshotReader snapshot = latestSnapshot().orElseThrow(() -> new IllegalStateException( + String.format( + "Snapshot expected since next offset of %s is %s, log start offset is %s and high-watermark is %s", + listenerContext.listenerName(), + nextExpectedOffset, + log.startOffset(), + highWatermark + ) + )); + listenerContext.fireHandleSnapshot(snapshot); + } + }); + + // Re-read the expected offset in case the snapshot had to be reloaded + listenerContext.nextExpectedOffset().ifPresent(nextExpectedOffset -> { + if (nextExpectedOffset < highWatermark) { + LogFetchInfo readInfo = log.read(nextExpectedOffset, Isolation.COMMITTED); + listenerContext.fireHandleCommit(nextExpectedOffset, readInfo.records); + } + }); + } + } + + private Optional> latestSnapshot() { + return log.latestSnapshot().map(reader -> + SnapshotReader.of(reader, serde, BufferSupplier.create(), MAX_BATCH_SIZE_BYTES) + ); + } + + private void maybeFireHandleCommit(long baseOffset, int epoch, long appendTimestamp, int sizeInBytes, List records) { + for (ListenerContext listenerContext : listenerContexts.values()) { + listenerContext.nextExpectedOffset().ifPresent(nextOffset -> { + if (nextOffset == baseOffset) { + listenerContext.fireHandleCommit(baseOffset, epoch, appendTimestamp, sizeInBytes, records); + } + }); + } + } + + private void maybeFireLeaderChange(LeaderState state) { + for (ListenerContext listenerContext 
: listenerContexts.values()) { + listenerContext.maybeFireLeaderChange(quorum.leaderAndEpoch(), state.epochStartOffset()); + } + } + + private void maybeFireLeaderChange() { + for (ListenerContext listenerContext : listenerContexts.values()) { + listenerContext.maybeFireLeaderChange(quorum.leaderAndEpoch()); + } + } + + @Override + public void initialize() { + quorum.initialize(new OffsetAndEpoch(log.endOffset().offset, log.lastFetchedEpoch())); + + long currentTimeMs = time.milliseconds(); + if (quorum.isLeader()) { + throw new IllegalStateException("Voter cannot initialize as a Leader"); + } else if (quorum.isCandidate()) { + onBecomeCandidate(currentTimeMs); + } else if (quorum.isFollower()) { + onBecomeFollower(currentTimeMs); + } + + // When there is only a single voter, become candidate immediately + if (quorum.isVoter() + && quorum.remoteVoters().isEmpty() + && !quorum.isCandidate()) { + + transitionToCandidate(currentTimeMs); + } + } + + @Override + public void register(Listener listener) { + pendingRegistrations.add(Registration.register(listener)); + wakeup(); + } + + @Override + public void unregister(Listener listener) { + pendingRegistrations.add(Registration.unregister(listener)); + // No need to wakeup the polling thread. It is a removal so the updates can be + // delayed until the polling thread wakes up for other reasons. + } + + @Override + public LeaderAndEpoch leaderAndEpoch() { + return quorum.leaderAndEpoch(); + } + + @Override + public OptionalInt nodeId() { + return quorum.localId(); + } + + private OffsetAndEpoch endOffset() { + return new OffsetAndEpoch(log.endOffset().offset, log.lastFetchedEpoch()); + } + + private void resetConnections() { + requestManager.resetAll(); + } + + private void onBecomeLeader(long currentTimeMs) { + long endOffset = log.endOffset().offset; + + BatchAccumulator accumulator = new BatchAccumulator<>( + quorum.epoch(), + endOffset, + raftConfig.appendLingerMs(), + MAX_BATCH_SIZE_BYTES, + memoryPool, + time, + CompressionType.NONE, + serde + ); + + LeaderState state = quorum.transitionToLeader(endOffset, accumulator); + maybeFireLeaderChange(state); + + log.initializeLeaderEpoch(quorum.epoch()); + + // The high watermark can only be advanced once we have written a record + // from the new leader's epoch. Hence we write a control message immediately + // to ensure there is no delay committing pending data. + state.appendLeaderChangeMessage(currentTimeMs); + + resetConnections(); + kafkaRaftMetrics.maybeUpdateElectionLatency(currentTimeMs); + } + + private void flushLeaderLog(LeaderState state, long currentTimeMs) { + // We update the end offset before flushing so that parked fetches can return sooner. 
+ updateLeaderEndOffsetAndTimestamp(state, currentTimeMs); + log.flush(); + } + + private boolean maybeTransitionToLeader(CandidateState state, long currentTimeMs) { + if (state.isVoteGranted()) { + onBecomeLeader(currentTimeMs); + return true; + } else { + return false; + } + } + + private void onBecomeCandidate(long currentTimeMs) { + CandidateState state = quorum.candidateStateOrThrow(); + if (!maybeTransitionToLeader(state, currentTimeMs)) { + resetConnections(); + kafkaRaftMetrics.updateElectionStartMs(currentTimeMs); + } + } + + private void transitionToCandidate(long currentTimeMs) { + quorum.transitionToCandidate(); + maybeFireLeaderChange(); + onBecomeCandidate(currentTimeMs); + } + + private void transitionToUnattached(int epoch) { + quorum.transitionToUnattached(epoch); + maybeFireLeaderChange(); + resetConnections(); + } + + private void transitionToResigned(List preferredSuccessors) { + fetchPurgatory.completeAllExceptionally( + Errors.NOT_LEADER_OR_FOLLOWER.exception("Not handling request since this node is resigning")); + quorum.transitionToResigned(preferredSuccessors); + maybeFireLeaderChange(); + resetConnections(); + } + + private void transitionToVoted(int candidateId, int epoch) { + quorum.transitionToVoted(epoch, candidateId); + maybeFireLeaderChange(); + resetConnections(); + } + + private void onBecomeFollower(long currentTimeMs) { + kafkaRaftMetrics.maybeUpdateElectionLatency(currentTimeMs); + + resetConnections(); + + // After becoming a follower, we need to complete all pending fetches so that + // they can be re-sent to the leader without waiting for their expirations + fetchPurgatory.completeAllExceptionally(new NotLeaderOrFollowerException( + "Cannot process the fetch request because the node is no longer the leader.")); + + // Clearing the append purgatory should complete all futures exceptionally since this node is no longer the leader + appendPurgatory.completeAllExceptionally(new NotLeaderOrFollowerException( + "Failed to receive sufficient acknowledgments for this append before leader change.")); + } + + private void transitionToFollower( + int epoch, + int leaderId, + long currentTimeMs + ) { + quorum.transitionToFollower(epoch, leaderId); + maybeFireLeaderChange(); + onBecomeFollower(currentTimeMs); + } + + private VoteResponseData buildVoteResponse(Errors partitionLevelError, boolean voteGranted) { + return VoteResponse.singletonResponse( + Errors.NONE, + log.topicPartition(), + partitionLevelError, + quorum.epoch(), + quorum.leaderIdOrSentinel(), + voteGranted); + } + + /** + * Handle a Vote request. This API may return the following errors: + * + * - {@link Errors#INCONSISTENT_CLUSTER_ID} if the cluster id is presented in request + * but different from this node + * - {@link Errors#BROKER_NOT_AVAILABLE} if this node is currently shutting down + * - {@link Errors#FENCED_LEADER_EPOCH} if the epoch is smaller than this node's epoch + * - {@link Errors#INCONSISTENT_VOTER_SET} if the request suggests inconsistent voter membership (e.g. 
+ * if this node or the sender is not one of the current known voters) + * - {@link Errors#INVALID_REQUEST} if the last epoch or offset are invalid + */ + private VoteResponseData handleVoteRequest( + RaftRequest.Inbound requestMetadata + ) { + VoteRequestData request = (VoteRequestData) requestMetadata.data; + + if (!hasValidClusterId(request.clusterId())) { + return new VoteResponseData().setErrorCode(Errors.INCONSISTENT_CLUSTER_ID.code()); + } + + if (!hasValidTopicPartition(request, log.topicPartition())) { + // Until we support multi-raft, we treat individual topic partition mismatches as invalid requests + return new VoteResponseData().setErrorCode(Errors.INVALID_REQUEST.code()); + } + + VoteRequestData.PartitionData partitionRequest = + request.topics().get(0).partitions().get(0); + + int candidateId = partitionRequest.candidateId(); + int candidateEpoch = partitionRequest.candidateEpoch(); + + int lastEpoch = partitionRequest.lastOffsetEpoch(); + long lastEpochEndOffset = partitionRequest.lastOffset(); + if (lastEpochEndOffset < 0 || lastEpoch < 0 || lastEpoch >= candidateEpoch) { + return buildVoteResponse(Errors.INVALID_REQUEST, false); + } + + Optional errorOpt = validateVoterOnlyRequest(candidateId, candidateEpoch); + if (errorOpt.isPresent()) { + return buildVoteResponse(errorOpt.get(), false); + } + + if (candidateEpoch > quorum.epoch()) { + transitionToUnattached(candidateEpoch); + } + + OffsetAndEpoch lastEpochEndOffsetAndEpoch = new OffsetAndEpoch(lastEpochEndOffset, lastEpoch); + boolean voteGranted = quorum.canGrantVote(candidateId, lastEpochEndOffsetAndEpoch.compareTo(endOffset()) >= 0); + + if (voteGranted && quorum.isUnattached()) { + transitionToVoted(candidateId, candidateEpoch); + } + + logger.info("Vote request {} with epoch {} is {}", request, candidateEpoch, voteGranted ? "granted" : "rejected"); + return buildVoteResponse(Errors.NONE, voteGranted); + } + + private boolean handleVoteResponse( + RaftResponse.Inbound responseMetadata, + long currentTimeMs + ) { + int remoteNodeId = responseMetadata.sourceId(); + VoteResponseData response = (VoteResponseData) responseMetadata.data; + Errors topLevelError = Errors.forCode(response.errorCode()); + if (topLevelError != Errors.NONE) { + return handleTopLevelError(topLevelError, responseMetadata); + } + + if (!hasValidTopicPartition(response, log.topicPartition())) { + return false; + } + + VoteResponseData.PartitionData partitionResponse = + response.topics().get(0).partitions().get(0); + + Errors error = Errors.forCode(partitionResponse.errorCode()); + OptionalInt responseLeaderId = optionalLeaderId(partitionResponse.leaderId()); + int responseEpoch = partitionResponse.leaderEpoch(); + + Optional handled = maybeHandleCommonResponse( + error, responseLeaderId, responseEpoch, currentTimeMs); + if (handled.isPresent()) { + return handled.get(); + } else if (error == Errors.NONE) { + if (quorum.isLeader()) { + logger.debug("Ignoring vote response {} since we already became leader for epoch {}", + partitionResponse, quorum.epoch()); + } else if (quorum.isCandidate()) { + CandidateState state = quorum.candidateStateOrThrow(); + if (partitionResponse.voteGranted()) { + state.recordGrantedVote(remoteNodeId); + maybeTransitionToLeader(state, currentTimeMs); + } else { + state.recordRejectedVote(remoteNodeId); + + // If our vote is rejected, we go immediately to the random backoff. This + // ensures that we are not stuck waiting for the election timeout when the + // vote has become gridlocked. 
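The random backoff referred to in the comment above comes from binaryExponentialElectionBackoffMs, defined just below. The following is a standalone sketch of the same scheme, using the file's RETRY_BACKOFF_BASE_MS of 100 ms and an assumed cap of 1000 ms for the example (the real cap comes from RaftConfig.electionBackoffMaxMs).

    import java.util.Random;

    // Standalone sketch of the randomized binary exponential election backoff.
    // After `retries` failed elections, the backoff is a uniform pick from
    // [0, 100 * 2^min(21, retries) ) ms, capped at electionBackoffMaxMs.
    final class ElectionBackoffSketch {
        private static final int RETRY_BACKOFF_BASE_MS = 100;

        static int backoffMs(Random random, int retries, int electionBackoffMaxMs) {
            if (retries <= 0)
                throw new IllegalArgumentException("retries must be positive");
            // 2 << (retries - 1) == 2^retries; the exponent is clamped to avoid overflow
            int bound = 2 << Math.min(20, retries - 1);
            return Math.min(RETRY_BACKOFF_BASE_MS * random.nextInt(bound), electionBackoffMaxMs);
        }

        public static void main(String[] args) {
            Random random = new Random(42);
            for (int retries = 1; retries <= 4; retries++) {
                // retries=1 draws from [0, 200) ms, retries=2 from [0, 400) ms, and so on,
                // but never more than the assumed 1000 ms cap.
                System.out.println("retries=" + retries + " -> " + backoffMs(random, retries, 1000) + " ms");
            }
        }
    }

Because each candidate draws uniformly from a doubling range, two gridlocked candidates are unlikely to retry at exactly the same moment, which is the point of the randomization.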
+ if (state.isVoteRejected() && !state.isBackingOff()) { + logger.info("Insufficient remaining votes to become leader (rejected by {}). " + + "We will backoff before retrying election again", state.rejectingVoters()); + + state.startBackingOff( + currentTimeMs, + binaryExponentialElectionBackoffMs(state.retries()) + ); + } + } + } else { + logger.debug("Ignoring vote response {} since we are no longer a candidate in epoch {}", + partitionResponse, quorum.epoch()); + } + return true; + } else { + return handleUnexpectedError(error, responseMetadata); + } + } + + private int binaryExponentialElectionBackoffMs(int retries) { + if (retries <= 0) { + throw new IllegalArgumentException("Retries " + retries + " should be larger than zero"); + } + // upper limit exponential co-efficients at 20 to avoid overflow + return Math.min(RETRY_BACKOFF_BASE_MS * random.nextInt(2 << Math.min(20, retries - 1)), + raftConfig.electionBackoffMaxMs()); + } + + private int strictExponentialElectionBackoffMs(int positionInSuccessors, int totalNumSuccessors) { + if (positionInSuccessors <= 0 || positionInSuccessors >= totalNumSuccessors) { + throw new IllegalArgumentException("Position " + positionInSuccessors + " should be larger than zero" + + " and smaller than total number of successors " + totalNumSuccessors); + } + + int retryBackOffBaseMs = raftConfig.electionBackoffMaxMs() >> (totalNumSuccessors - 1); + return Math.min(raftConfig.electionBackoffMaxMs(), retryBackOffBaseMs << (positionInSuccessors - 1)); + } + + private BeginQuorumEpochResponseData buildBeginQuorumEpochResponse(Errors partitionLevelError) { + return BeginQuorumEpochResponse.singletonResponse( + Errors.NONE, + log.topicPartition(), + partitionLevelError, + quorum.epoch(), + quorum.leaderIdOrSentinel()); + } + + /** + * Handle a BeginEpoch request. This API may return the following errors: + * + * - {@link Errors#INCONSISTENT_CLUSTER_ID} if the cluster id is presented in request + * but different from this node + * - {@link Errors#BROKER_NOT_AVAILABLE} if this node is currently shutting down + * - {@link Errors#INCONSISTENT_VOTER_SET} if the request suggests inconsistent voter membership (e.g. 
+ * if this node or the sender is not one of the current known voters) + * - {@link Errors#FENCED_LEADER_EPOCH} if the epoch is smaller than this node's epoch + */ + private BeginQuorumEpochResponseData handleBeginQuorumEpochRequest( + RaftRequest.Inbound requestMetadata, + long currentTimeMs + ) { + BeginQuorumEpochRequestData request = (BeginQuorumEpochRequestData) requestMetadata.data; + + if (!hasValidClusterId(request.clusterId())) { + return new BeginQuorumEpochResponseData().setErrorCode(Errors.INCONSISTENT_CLUSTER_ID.code()); + } + + if (!hasValidTopicPartition(request, log.topicPartition())) { + // Until we support multi-raft, we treat topic partition mismatches as invalid requests + return new BeginQuorumEpochResponseData().setErrorCode(Errors.INVALID_REQUEST.code()); + } + + BeginQuorumEpochRequestData.PartitionData partitionRequest = + request.topics().get(0).partitions().get(0); + + int requestLeaderId = partitionRequest.leaderId(); + int requestEpoch = partitionRequest.leaderEpoch(); + + Optional errorOpt = validateVoterOnlyRequest(requestLeaderId, requestEpoch); + if (errorOpt.isPresent()) { + return buildBeginQuorumEpochResponse(errorOpt.get()); + } + + maybeTransition(OptionalInt.of(requestLeaderId), requestEpoch, currentTimeMs); + return buildBeginQuorumEpochResponse(Errors.NONE); + } + + private boolean handleBeginQuorumEpochResponse( + RaftResponse.Inbound responseMetadata, + long currentTimeMs + ) { + int remoteNodeId = responseMetadata.sourceId(); + BeginQuorumEpochResponseData response = (BeginQuorumEpochResponseData) responseMetadata.data; + Errors topLevelError = Errors.forCode(response.errorCode()); + if (topLevelError != Errors.NONE) { + return handleTopLevelError(topLevelError, responseMetadata); + } + + if (!hasValidTopicPartition(response, log.topicPartition())) { + return false; + } + + BeginQuorumEpochResponseData.PartitionData partitionResponse = + response.topics().get(0).partitions().get(0); + + Errors partitionError = Errors.forCode(partitionResponse.errorCode()); + OptionalInt responseLeaderId = optionalLeaderId(partitionResponse.leaderId()); + int responseEpoch = partitionResponse.leaderEpoch(); + + Optional handled = maybeHandleCommonResponse( + partitionError, responseLeaderId, responseEpoch, currentTimeMs); + if (handled.isPresent()) { + return handled.get(); + } else if (partitionError == Errors.NONE) { + if (quorum.isLeader()) { + LeaderState state = quorum.leaderStateOrThrow(); + state.addAcknowledgementFrom(remoteNodeId); + } else { + logger.debug("Ignoring BeginQuorumEpoch response {} since " + + "this node is not the leader anymore", response); + } + return true; + } else { + return handleUnexpectedError(partitionError, responseMetadata); + } + } + + private EndQuorumEpochResponseData buildEndQuorumEpochResponse(Errors partitionLevelError) { + return EndQuorumEpochResponse.singletonResponse( + Errors.NONE, + log.topicPartition(), + partitionLevelError, + quorum.epoch(), + quorum.leaderIdOrSentinel()); + } + + /** + * Handle an EndEpoch request. This API may return the following errors: + * + * - {@link Errors#INCONSISTENT_CLUSTER_ID} if the cluster id is presented in request + * but different from this node + * - {@link Errors#BROKER_NOT_AVAILABLE} if this node is currently shutting down + * - {@link Errors#INCONSISTENT_VOTER_SET} if the request suggests inconsistent voter membership (e.g. 
+ * if this node or the sender is not one of the current known voters) + * - {@link Errors#FENCED_LEADER_EPOCH} if the epoch is smaller than this node's epoch + */ + private EndQuorumEpochResponseData handleEndQuorumEpochRequest( + RaftRequest.Inbound requestMetadata, + long currentTimeMs + ) { + EndQuorumEpochRequestData request = (EndQuorumEpochRequestData) requestMetadata.data; + + if (!hasValidClusterId(request.clusterId())) { + return new EndQuorumEpochResponseData().setErrorCode(Errors.INCONSISTENT_CLUSTER_ID.code()); + } + + if (!hasValidTopicPartition(request, log.topicPartition())) { + // Until we support multi-raft, we treat topic partition mismatches as invalid requests + return new EndQuorumEpochResponseData().setErrorCode(Errors.INVALID_REQUEST.code()); + } + + EndQuorumEpochRequestData.PartitionData partitionRequest = + request.topics().get(0).partitions().get(0); + + int requestEpoch = partitionRequest.leaderEpoch(); + int requestLeaderId = partitionRequest.leaderId(); + + Optional errorOpt = validateVoterOnlyRequest(requestLeaderId, requestEpoch); + if (errorOpt.isPresent()) { + return buildEndQuorumEpochResponse(errorOpt.get()); + } + maybeTransition(OptionalInt.of(requestLeaderId), requestEpoch, currentTimeMs); + + if (quorum.isFollower()) { + FollowerState state = quorum.followerStateOrThrow(); + if (state.leaderId() == requestLeaderId) { + List preferredSuccessors = partitionRequest.preferredSuccessors(); + long electionBackoffMs = endEpochElectionBackoff(preferredSuccessors); + logger.debug("Overriding follower fetch timeout to {} after receiving " + + "EndQuorumEpoch request from leader {} in epoch {}", electionBackoffMs, + requestLeaderId, requestEpoch); + state.overrideFetchTimeout(currentTimeMs, electionBackoffMs); + } + } + return buildEndQuorumEpochResponse(Errors.NONE); + } + + private long endEpochElectionBackoff(List preferredSuccessors) { + // Based on the priority inside the preferred successors, choose the corresponding delayed + // election backoff time based on strict exponential mechanism so that the most up-to-date + // voter has a higher chance to be elected. If the node's priority is highest, become + // candidate immediately instead of waiting for next poll. 
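To make the comment above concrete, here is a standalone sketch of the combined behaviour of endEpochElectionBackoff and strictExponentialElectionBackoffMs (both defined nearby), using an assumed electionBackoffMaxMs of 1000 ms for the worked numbers.

    // Standalone sketch of the position-based ("strict exponential") election backoff
    // applied to the leader's preferred successors.
    final class SuccessorBackoffSketch {

        static long backoffMs(int position, int totalSuccessors, int electionBackoffMaxMs) {
            if (position <= 0) {
                // Highest-priority successor (or not in the list): become candidate immediately.
                return 0;
            }
            int baseMs = electionBackoffMaxMs >> (totalSuccessors - 1);
            return Math.min(electionBackoffMaxMs, (long) baseMs << (position - 1));
        }

        public static void main(String[] args) {
            // With three preferred successors and a 1000 ms cap:
            //   position 0 -> 0 ms, position 1 -> 250 ms, position 2 -> 500 ms
            for (int position = 0; position < 3; position++) {
                System.out.println("position=" + position + " -> " + backoffMs(position, 3, 1000) + " ms");
            }
        }
    }

The first preferred successor stands for election immediately, while lower-priority successors wait progressively longer, which biases the next election toward the most caught-up voter.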
+ int position = preferredSuccessors.indexOf(quorum.localIdOrThrow()); + if (position <= 0) { + return 0; + } else { + return strictExponentialElectionBackoffMs(position, preferredSuccessors.size()); + } + } + + private boolean handleEndQuorumEpochResponse( + RaftResponse.Inbound responseMetadata, + long currentTimeMs + ) { + EndQuorumEpochResponseData response = (EndQuorumEpochResponseData) responseMetadata.data; + Errors topLevelError = Errors.forCode(response.errorCode()); + if (topLevelError != Errors.NONE) { + return handleTopLevelError(topLevelError, responseMetadata); + } + + if (!hasValidTopicPartition(response, log.topicPartition())) { + return false; + } + + EndQuorumEpochResponseData.PartitionData partitionResponse = + response.topics().get(0).partitions().get(0); + + Errors partitionError = Errors.forCode(partitionResponse.errorCode()); + OptionalInt responseLeaderId = optionalLeaderId(partitionResponse.leaderId()); + int responseEpoch = partitionResponse.leaderEpoch(); + + Optional handled = maybeHandleCommonResponse( + partitionError, responseLeaderId, responseEpoch, currentTimeMs); + if (handled.isPresent()) { + return handled.get(); + } else if (partitionError == Errors.NONE) { + ResignedState resignedState = quorum.resignedStateOrThrow(); + resignedState.acknowledgeResignation(responseMetadata.sourceId()); + return true; + } else { + return handleUnexpectedError(partitionError, responseMetadata); + } + } + + private FetchResponseData buildFetchResponse( + Errors error, + Records records, + ValidOffsetAndEpoch validOffsetAndEpoch, + Optional highWatermark + ) { + return RaftUtil.singletonFetchResponse(log.topicPartition(), log.topicId(), Errors.NONE, partitionData -> { + partitionData + .setRecords(records) + .setErrorCode(error.code()) + .setLogStartOffset(log.startOffset()) + .setHighWatermark(highWatermark + .map(offsetMetadata -> offsetMetadata.offset) + .orElse(-1L)); + + partitionData.currentLeader() + .setLeaderEpoch(quorum.epoch()) + .setLeaderId(quorum.leaderIdOrSentinel()); + + switch (validOffsetAndEpoch.kind()) { + case DIVERGING: + partitionData.divergingEpoch() + .setEpoch(validOffsetAndEpoch.offsetAndEpoch().epoch) + .setEndOffset(validOffsetAndEpoch.offsetAndEpoch().offset); + break; + case SNAPSHOT: + partitionData.snapshotId() + .setEpoch(validOffsetAndEpoch.offsetAndEpoch().epoch) + .setEndOffset(validOffsetAndEpoch.offsetAndEpoch().offset); + break; + default: + } + }); + } + + private FetchResponseData buildEmptyFetchResponse( + Errors error, + Optional highWatermark + ) { + return buildFetchResponse( + error, + MemoryRecords.EMPTY, + ValidOffsetAndEpoch.valid(), + highWatermark + ); + } + + private boolean hasValidClusterId(String requestClusterId) { + // We don't enforce the cluster id if it is not provided. + if (requestClusterId == null) { + return true; + } + return clusterId.equals(requestClusterId); + } + + /** + * Handle a Fetch request. The fetch offset and last fetched epoch are always + * validated against the current log. In the case that they do not match, the response will + * indicate the diverging offset/epoch. A follower is expected to truncate its log in this + * case and resend the fetch. 
+ * + * This API may return the following errors: + * + * - {@link Errors#INCONSISTENT_CLUSTER_ID} if the cluster id is presented in request + * but different from this node + * - {@link Errors#BROKER_NOT_AVAILABLE} if this node is currently shutting down + * - {@link Errors#FENCED_LEADER_EPOCH} if the epoch is smaller than this node's epoch + * - {@link Errors#INVALID_REQUEST} if the request epoch is larger than the leader's current epoch + * or if either the fetch offset or the last fetched epoch is invalid + */ + private CompletableFuture handleFetchRequest( + RaftRequest.Inbound requestMetadata, + long currentTimeMs + ) { + FetchRequestData request = (FetchRequestData) requestMetadata.data; + + if (!hasValidClusterId(request.clusterId())) { + return completedFuture(new FetchResponseData().setErrorCode(Errors.INCONSISTENT_CLUSTER_ID.code())); + } + + if (!hasValidTopicPartition(request, log.topicPartition(), log.topicId())) { + // Until we support multi-raft, we treat topic partition mismatches as invalid requests + return completedFuture(new FetchResponseData().setErrorCode(Errors.INVALID_REQUEST.code())); + } + // If the ID is valid, we can set the topic name. + request.topics().get(0).setTopic(log.topicPartition().topic()); + + FetchRequestData.FetchPartition fetchPartition = request.topics().get(0).partitions().get(0); + if (request.maxWaitMs() < 0 + || fetchPartition.fetchOffset() < 0 + || fetchPartition.lastFetchedEpoch() < 0 + || fetchPartition.lastFetchedEpoch() > fetchPartition.currentLeaderEpoch()) { + return completedFuture(buildEmptyFetchResponse( + Errors.INVALID_REQUEST, Optional.empty())); + } + + FetchResponseData response = tryCompleteFetchRequest(request.replicaId(), fetchPartition, currentTimeMs); + FetchResponseData.PartitionData partitionResponse = + response.responses().get(0).partitions().get(0); + + if (partitionResponse.errorCode() != Errors.NONE.code() + || FetchResponse.recordsSize(partitionResponse) > 0 + || request.maxWaitMs() == 0) { + return completedFuture(response); + } + + CompletableFuture future = fetchPurgatory.await( + fetchPartition.fetchOffset(), + request.maxWaitMs()); + + return future.handle((completionTimeMs, exception) -> { + if (exception != null) { + Throwable cause = exception instanceof ExecutionException ? + exception.getCause() : exception; + + // If the fetch timed out in purgatory, it means no new data is available, + // and we will complete the fetch successfully. Otherwise, if there was + // any other error, we need to return it. 
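The comment above leans on the purgatory contract used throughout this class: a fetch future is parked with the requested fetch offset as its threshold, completed normally (with the completion time) once updateLeaderEndOffsetAndTimestamp calls fetchPurgatory.maybeComplete with a log end offset at or past that threshold, and completed exceptionally on expiration or a leader change. The class below is a deliberately simplified sketch of that contract, assuming a single waiter per threshold and ignoring the maxWaitMs expiration; it is not the actual ThresholdPurgatory.

    import java.util.Map;
    import java.util.concurrent.CompletableFuture;
    import java.util.concurrent.ConcurrentSkipListMap;

    final class ThresholdPurgatorySketch {
        private final ConcurrentSkipListMap<Long, CompletableFuture<Long>> waiting = new ConcurrentSkipListMap<>();

        // Park a future until some observed value reaches `threshold`.
        CompletableFuture<Long> await(long threshold) {
            return waiting.computeIfAbsent(threshold, k -> new CompletableFuture<>());
        }

        // Complete every waiter whose threshold has been reached by `value`.
        void maybeComplete(long value, long currentTimeMs) {
            Map<Long, CompletableFuture<Long>> reached = waiting.headMap(value, true);
            reached.values().forEach(future -> future.complete(currentTimeMs));
            reached.clear();
        }

        // Fail all waiters, e.g. on resignation or when becoming a follower.
        void completeAllExceptionally(Exception cause) {
            waiting.values().forEach(future -> future.completeExceptionally(cause));
            waiting.clear();
        }
    }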
+ Errors error = Errors.forException(cause); + if (error != Errors.REQUEST_TIMED_OUT) { + logger.debug("Failed to handle fetch from {} at {} due to {}", + request.replicaId(), fetchPartition.fetchOffset(), error); + return buildEmptyFetchResponse(error, Optional.empty()); + } + } + + // FIXME: `completionTimeMs`, which can be null + logger.trace("Completing delayed fetch from {} starting at offset {} at {}", + request.replicaId(), fetchPartition.fetchOffset(), completionTimeMs); + + return tryCompleteFetchRequest(request.replicaId(), fetchPartition, time.milliseconds()); + }); + } + + private FetchResponseData tryCompleteFetchRequest( + int replicaId, + FetchRequestData.FetchPartition request, + long currentTimeMs + ) { + try { + Optional errorOpt = validateLeaderOnlyRequest(request.currentLeaderEpoch()); + if (errorOpt.isPresent()) { + return buildEmptyFetchResponse(errorOpt.get(), Optional.empty()); + } + + long fetchOffset = request.fetchOffset(); + int lastFetchedEpoch = request.lastFetchedEpoch(); + LeaderState state = quorum.leaderStateOrThrow(); + ValidOffsetAndEpoch validOffsetAndEpoch = log.validateOffsetAndEpoch(fetchOffset, lastFetchedEpoch); + + final Records records; + if (validOffsetAndEpoch.kind() == ValidOffsetAndEpoch.Kind.VALID) { + LogFetchInfo info = log.read(fetchOffset, Isolation.UNCOMMITTED); + + if (state.updateReplicaState(replicaId, currentTimeMs, info.startOffsetMetadata)) { + onUpdateLeaderHighWatermark(state, currentTimeMs); + } + + records = info.records; + } else { + records = MemoryRecords.EMPTY; + } + + return buildFetchResponse(Errors.NONE, records, validOffsetAndEpoch, state.highWatermark()); + } catch (Exception e) { + logger.error("Caught unexpected error in fetch completion of request {}", request, e); + return buildEmptyFetchResponse(Errors.UNKNOWN_SERVER_ERROR, Optional.empty()); + } + } + + private static OptionalInt optionalLeaderId(int leaderIdOrNil) { + if (leaderIdOrNil < 0) + return OptionalInt.empty(); + return OptionalInt.of(leaderIdOrNil); + } + + private static String listenerName(Listener listener) { + return String.format("%s@%s", listener.getClass().getTypeName(), System.identityHashCode(listener)); + } + + private boolean handleFetchResponse( + RaftResponse.Inbound responseMetadata, + long currentTimeMs + ) { + FetchResponseData response = (FetchResponseData) responseMetadata.data; + Errors topLevelError = Errors.forCode(response.errorCode()); + if (topLevelError != Errors.NONE) { + return handleTopLevelError(topLevelError, responseMetadata); + } + + if (!RaftUtil.hasValidTopicPartition(response, log.topicPartition(), log.topicId())) { + return false; + } + // If the ID is valid, we can set the topic name. 
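A little further below, the follower reacts to one of three mutually exclusive outcomes in a successful Fetch response. The sketch compresses that decision; the names are illustrative and the real handling is inline in handleFetchResponse, where hasDivergingEpoch corresponds to divergingEpoch.epoch() >= 0 and hasSnapshotId to a snapshot id with a non-negative epoch or end offset.

    // Illustrative summary of the follower-side branches implemented below.
    final class FetchResponseOutcomeSketch {
        enum Outcome { TRUNCATE_TO_DIVERGING_EPOCH, FETCH_SNAPSHOT, APPEND_RECORDS }

        static Outcome classify(boolean hasDivergingEpoch, boolean hasSnapshotId) {
            if (hasDivergingEpoch) {
                // The leader's log diverges from ours: truncate to the returned
                // (epoch, end offset) and fetch again from the truncation point.
                return Outcome.TRUNCATE_TO_DIVERGING_EPOCH;
            } else if (hasSnapshotId) {
                // We are behind the leader's log start offset: download a snapshot
                // instead of records.
                return Outcome.FETCH_SNAPSHOT;
            } else {
                // Normal case: append the returned records and advance the follower
                // high watermark.
                return Outcome.APPEND_RECORDS;
            }
        }
    }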
+ response.responses().get(0).setTopic(log.topicPartition().topic()); + + FetchResponseData.PartitionData partitionResponse = + response.responses().get(0).partitions().get(0); + + FetchResponseData.LeaderIdAndEpoch currentLeaderIdAndEpoch = partitionResponse.currentLeader(); + OptionalInt responseLeaderId = optionalLeaderId(currentLeaderIdAndEpoch.leaderId()); + int responseEpoch = currentLeaderIdAndEpoch.leaderEpoch(); + Errors error = Errors.forCode(partitionResponse.errorCode()); + + Optional handled = maybeHandleCommonResponse( + error, responseLeaderId, responseEpoch, currentTimeMs); + if (handled.isPresent()) { + return handled.get(); + } + + FollowerState state = quorum.followerStateOrThrow(); + if (error == Errors.NONE) { + FetchResponseData.EpochEndOffset divergingEpoch = partitionResponse.divergingEpoch(); + if (divergingEpoch.epoch() >= 0) { + // The leader is asking us to truncate before continuing + final OffsetAndEpoch divergingOffsetAndEpoch = new OffsetAndEpoch( + divergingEpoch.endOffset(), divergingEpoch.epoch()); + + state.highWatermark().ifPresent(highWatermark -> { + if (divergingOffsetAndEpoch.offset < highWatermark.offset) { + throw new KafkaException("The leader requested truncation to offset " + + divergingOffsetAndEpoch.offset + ", which is below the current high watermark" + + " " + highWatermark); + } + }); + + long truncationOffset = log.truncateToEndOffset(divergingOffsetAndEpoch); + logger.info("Truncated to offset {} from Fetch response from leader {}", truncationOffset, quorum.leaderIdOrSentinel()); + } else if (partitionResponse.snapshotId().epoch() >= 0 || + partitionResponse.snapshotId().endOffset() >= 0) { + // The leader is asking us to fetch a snapshot + + if (partitionResponse.snapshotId().epoch() < 0) { + logger.error( + "The leader sent a snapshot id with a valid end offset {} but with an invalid epoch {}", + partitionResponse.snapshotId().endOffset(), + partitionResponse.snapshotId().epoch() + ); + return false; + } else if (partitionResponse.snapshotId().endOffset() < 0) { + logger.error( + "The leader sent a snapshot id with a valid epoch {} but with an invalid end offset {}", + partitionResponse.snapshotId().epoch(), + partitionResponse.snapshotId().endOffset() + ); + return false; + } else { + final OffsetAndEpoch snapshotId = new OffsetAndEpoch( + partitionResponse.snapshotId().endOffset(), + partitionResponse.snapshotId().epoch() + ); + + // Do not validate the snapshot id against the local replicated log + // since this snapshot is expected to reference offsets and epochs + // greater than the log end offset and high-watermark + state.setFetchingSnapshot(log.storeSnapshot(snapshotId)); + } + } else { + Records records = FetchResponse.recordsOrFail(partitionResponse); + if (records.sizeInBytes() > 0) { + appendAsFollower(records); + } + + OptionalLong highWatermark = partitionResponse.highWatermark() < 0 ? 
+ OptionalLong.empty() : OptionalLong.of(partitionResponse.highWatermark()); + updateFollowerHighWatermark(state, highWatermark); + } + + state.resetFetchTimeout(currentTimeMs); + return true; + } else { + return handleUnexpectedError(error, responseMetadata); + } + } + + private void appendAsFollower( + Records records + ) { + LogAppendInfo info = log.appendAsFollower(records); + log.flush(); + + OffsetAndEpoch endOffset = endOffset(); + kafkaRaftMetrics.updateFetchedRecords(info.lastOffset - info.firstOffset + 1); + kafkaRaftMetrics.updateLogEnd(endOffset); + logger.trace("Follower end offset updated to {} after append", endOffset); + } + + private LogAppendInfo appendAsLeader( + Records records + ) { + LogAppendInfo info = log.appendAsLeader(records, quorum.epoch()); + OffsetAndEpoch endOffset = endOffset(); + kafkaRaftMetrics.updateAppendRecords(info.lastOffset - info.firstOffset + 1); + kafkaRaftMetrics.updateLogEnd(endOffset); + logger.trace("Leader appended records at base offset {}, new end offset is {}", info.firstOffset, endOffset); + return info; + } + + private DescribeQuorumResponseData handleDescribeQuorumRequest( + RaftRequest.Inbound requestMetadata, + long currentTimeMs + ) { + DescribeQuorumRequestData describeQuorumRequestData = (DescribeQuorumRequestData) requestMetadata.data; + if (!hasValidTopicPartition(describeQuorumRequestData, log.topicPartition())) { + return DescribeQuorumRequest.getPartitionLevelErrorResponse( + describeQuorumRequestData, Errors.UNKNOWN_TOPIC_OR_PARTITION); + } + + if (!quorum.isLeader()) { + return DescribeQuorumRequest.getTopLevelErrorResponse(Errors.INVALID_REQUEST); + } + + LeaderState leaderState = quorum.leaderStateOrThrow(); + return DescribeQuorumResponse.singletonResponse(log.topicPartition(), + leaderState.localId(), + leaderState.epoch(), + leaderState.highWatermark().isPresent() ? leaderState.highWatermark().get().offset : -1, + convertToReplicaStates(leaderState.getVoterEndOffsets()), + convertToReplicaStates(leaderState.getObserverStates(currentTimeMs)) + ); + } + + /** + * Handle a FetchSnapshot request, similar to the Fetch request but we use {@link UnalignedRecords} + * in response because the records are not necessarily offset-aligned. 
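Before the error cases, it may help to see the follower's half of this exchange in one place: the follower always asks for the snapshot starting at the number of bytes it has already written locally (buildFetchSnapshotRequest passes snapshot.sizeInBytes() as the position), appends each returned chunk, and freezes the snapshot once its local size matches the leader-reported total (handleFetchSnapshotResponse). The sketch below compresses that loop; SnapshotChunk and SnapshotChunkSource are hypothetical stand-ins for the request/response plumbing in this file, and the import paths are assumed.

    import org.apache.kafka.common.record.UnalignedMemoryRecords;
    import org.apache.kafka.snapshot.RawSnapshotWriter;

    final class SnapshotDownloadSketch {
        // Hypothetical helpers standing in for "send a FetchSnapshot request at this
        // position and unpack the response".
        interface SnapshotChunk {
            UnalignedMemoryRecords records();
            long totalSize();    // leader-reported size of the complete snapshot
        }

        interface SnapshotChunkSource {
            SnapshotChunk fetch(long position);
        }

        static void download(RawSnapshotWriter writer, SnapshotChunkSource source) {
            while (true) {
                long position = writer.sizeInBytes();       // resume from the bytes already written
                SnapshotChunk chunk = source.fetch(position);
                writer.append(chunk.records());
                if (writer.sizeInBytes() == chunk.totalSize()) {
                    writer.freeze();                         // complete; the log is then truncated to the snapshot
                    return;
                }
            }
        }
    }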
+ * + * This API may return the following errors: + * + * - {@link Errors#INCONSISTENT_CLUSTER_ID} if the cluster id is presented in request + * but different from this node + * - {@link Errors#BROKER_NOT_AVAILABLE} if this node is currently shutting down + * - {@link Errors#FENCED_LEADER_EPOCH} if the epoch is smaller than this node's epoch + * - {@link Errors#INVALID_REQUEST} if the request epoch is larger than the leader's current epoch + * or if either the fetch offset or the last fetched epoch is invalid + * - {@link Errors#SNAPSHOT_NOT_FOUND} if the request snapshot id does not exists + * - {@link Errors#POSITION_OUT_OF_RANGE} if the request snapshot offset out of range + */ + private FetchSnapshotResponseData handleFetchSnapshotRequest( + RaftRequest.Inbound requestMetadata + ) { + FetchSnapshotRequestData data = (FetchSnapshotRequestData) requestMetadata.data; + + if (!hasValidClusterId(data.clusterId())) { + return new FetchSnapshotResponseData().setErrorCode(Errors.INCONSISTENT_CLUSTER_ID.code()); + } + + if (data.topics().size() != 1 && data.topics().get(0).partitions().size() != 1) { + return FetchSnapshotResponse.withTopLevelError(Errors.INVALID_REQUEST); + } + + Optional partitionSnapshotOpt = FetchSnapshotRequest + .forTopicPartition(data, log.topicPartition()); + if (!partitionSnapshotOpt.isPresent()) { + // The Raft client assumes that there is only one topic partition. + TopicPartition unknownTopicPartition = new TopicPartition( + data.topics().get(0).name(), + data.topics().get(0).partitions().get(0).partition() + ); + + return FetchSnapshotResponse.singleton( + unknownTopicPartition, + responsePartitionSnapshot -> responsePartitionSnapshot + .setErrorCode(Errors.UNKNOWN_TOPIC_OR_PARTITION.code()) + ); + } + + FetchSnapshotRequestData.PartitionSnapshot partitionSnapshot = partitionSnapshotOpt.get(); + Optional leaderValidation = validateLeaderOnlyRequest( + partitionSnapshot.currentLeaderEpoch() + ); + if (leaderValidation.isPresent()) { + return FetchSnapshotResponse.singleton( + log.topicPartition(), + responsePartitionSnapshot -> addQuorumLeader(responsePartitionSnapshot) + .setErrorCode(leaderValidation.get().code()) + ); + } + + OffsetAndEpoch snapshotId = new OffsetAndEpoch( + partitionSnapshot.snapshotId().endOffset(), + partitionSnapshot.snapshotId().epoch() + ); + Optional snapshotOpt = log.readSnapshot(snapshotId); + if (!snapshotOpt.isPresent()) { + return FetchSnapshotResponse.singleton( + log.topicPartition(), + responsePartitionSnapshot -> addQuorumLeader(responsePartitionSnapshot) + .setErrorCode(Errors.SNAPSHOT_NOT_FOUND.code()) + ); + } + + RawSnapshotReader snapshot = snapshotOpt.get(); + long snapshotSize = snapshot.sizeInBytes(); + if (partitionSnapshot.position() < 0 || partitionSnapshot.position() >= snapshotSize) { + return FetchSnapshotResponse.singleton( + log.topicPartition(), + responsePartitionSnapshot -> addQuorumLeader(responsePartitionSnapshot) + .setErrorCode(Errors.POSITION_OUT_OF_RANGE.code()) + ); + } + + if (partitionSnapshot.position() > Integer.MAX_VALUE) { + throw new IllegalStateException( + String.format( + "Trying to fetch a snapshot with size (%s) and a position (%s) larger than %s", + snapshotSize, + partitionSnapshot.position(), + Integer.MAX_VALUE + ) + ); + } + + int maxSnapshotSize; + try { + maxSnapshotSize = Math.toIntExact(snapshotSize); + } catch (ArithmeticException e) { + maxSnapshotSize = Integer.MAX_VALUE; + } + + UnalignedRecords records = snapshot.slice(partitionSnapshot.position(), Math.min(data.maxBytes(), 
maxSnapshotSize)); + + return FetchSnapshotResponse.singleton( + log.topicPartition(), + responsePartitionSnapshot -> { + addQuorumLeader(responsePartitionSnapshot) + .snapshotId() + .setEndOffset(snapshotId.offset) + .setEpoch(snapshotId.epoch); + + return responsePartitionSnapshot + .setSize(snapshotSize) + .setPosition(partitionSnapshot.position()) + .setUnalignedRecords(records); + } + ); + } + + private boolean handleFetchSnapshotResponse( + RaftResponse.Inbound responseMetadata, + long currentTimeMs + ) { + FetchSnapshotResponseData data = (FetchSnapshotResponseData) responseMetadata.data; + Errors topLevelError = Errors.forCode(data.errorCode()); + if (topLevelError != Errors.NONE) { + return handleTopLevelError(topLevelError, responseMetadata); + } + + if (data.topics().size() != 1 && data.topics().get(0).partitions().size() != 1) { + return false; + } + + Optional partitionSnapshotOpt = FetchSnapshotResponse + .forTopicPartition(data, log.topicPartition()); + if (!partitionSnapshotOpt.isPresent()) { + return false; + } + + FetchSnapshotResponseData.PartitionSnapshot partitionSnapshot = partitionSnapshotOpt.get(); + + FetchSnapshotResponseData.LeaderIdAndEpoch currentLeaderIdAndEpoch = partitionSnapshot.currentLeader(); + OptionalInt responseLeaderId = optionalLeaderId(currentLeaderIdAndEpoch.leaderId()); + int responseEpoch = currentLeaderIdAndEpoch.leaderEpoch(); + Errors error = Errors.forCode(partitionSnapshot.errorCode()); + + Optional handled = maybeHandleCommonResponse( + error, responseLeaderId, responseEpoch, currentTimeMs); + if (handled.isPresent()) { + return handled.get(); + } + + FollowerState state = quorum.followerStateOrThrow(); + + if (Errors.forCode(partitionSnapshot.errorCode()) == Errors.SNAPSHOT_NOT_FOUND || + partitionSnapshot.snapshotId().endOffset() < 0 || + partitionSnapshot.snapshotId().epoch() < 0) { + + /* The leader deleted the snapshot before the follower could download it. Start over by + * reseting the fetching snapshot state and sending another fetch request. + */ + logger.trace( + "Leader doesn't know about snapshot id {}, returned error {} and snapshot id {}", + state.fetchingSnapshot(), + partitionSnapshot.errorCode(), + partitionSnapshot.snapshotId() + ); + state.setFetchingSnapshot(Optional.empty()); + state.resetFetchTimeout(currentTimeMs); + return true; + } + + OffsetAndEpoch snapshotId = new OffsetAndEpoch( + partitionSnapshot.snapshotId().endOffset(), + partitionSnapshot.snapshotId().epoch() + ); + + RawSnapshotWriter snapshot; + if (state.fetchingSnapshot().isPresent()) { + snapshot = state.fetchingSnapshot().get(); + } else { + throw new IllegalStateException( + String.format("Received unexpected fetch snapshot response: %s", partitionSnapshot) + ); + } + + if (!snapshot.snapshotId().equals(snapshotId)) { + throw new IllegalStateException( + String.format( + "Received fetch snapshot response with an invalid id. Expected %s; Received %s", + snapshot.snapshotId(), + snapshotId + ) + ); + } + if (snapshot.sizeInBytes() != partitionSnapshot.position()) { + throw new IllegalStateException( + String.format( + "Received fetch snapshot response with an invalid position. 
Expected %s; Received %s", + snapshot.sizeInBytes(), + partitionSnapshot.position() + ) + ); + } + + final UnalignedMemoryRecords records; + if (partitionSnapshot.unalignedRecords() instanceof MemoryRecords) { + records = new UnalignedMemoryRecords(((MemoryRecords) partitionSnapshot.unalignedRecords()).buffer()); + } else if (partitionSnapshot.unalignedRecords() instanceof UnalignedMemoryRecords) { + records = (UnalignedMemoryRecords) partitionSnapshot.unalignedRecords(); + } else { + throw new IllegalStateException(String.format("Received unexpected fetch snapshot response: %s", partitionSnapshot)); + } + snapshot.append(records); + + if (snapshot.sizeInBytes() == partitionSnapshot.size()) { + // Finished fetching the snapshot. + snapshot.freeze(); + state.setFetchingSnapshot(Optional.empty()); + + if (log.truncateToLatestSnapshot()) { + updateFollowerHighWatermark(state, OptionalLong.of(log.highWatermark().offset)); + } else { + throw new IllegalStateException( + String.format( + "Full log truncation expected but didn't happen. Snapshot of %s, log end offset %s, last fetched %s", + snapshot.snapshotId(), + log.endOffset(), + log.lastFetchedEpoch() + ) + ); + } + } + + state.resetFetchTimeout(currentTimeMs); + return true; + } + + List convertToReplicaStates(Map replicaEndOffsets) { + return replicaEndOffsets.entrySet().stream() + .map(entry -> new ReplicaState() + .setReplicaId(entry.getKey()) + .setLogEndOffset(entry.getValue())) + .collect(Collectors.toList()); + } + + private boolean hasConsistentLeader(int epoch, OptionalInt leaderId) { + // Only elected leaders are sent in the request/response header, so if we have an elected + // leaderId, it should be consistent with what is in the message. + if (leaderId.isPresent() && leaderId.getAsInt() == quorum.localIdOrSentinel()) { + // The response indicates that we should be the leader, so we verify that is the case + return quorum.isLeader(); + } else { + return epoch != quorum.epoch() + || !leaderId.isPresent() + || !quorum.leaderId().isPresent() + || leaderId.equals(quorum.leaderId()); + } + } + + /** + * Handle response errors that are common across request types. + * + * @param error Error from the received response + * @param leaderId Optional leaderId from the response + * @param epoch Epoch received from the response + * @param currentTimeMs Current epoch time in milliseconds + * @return Optional value indicating whether the error was handled here and the outcome of + * that handling. Specifically: + * + * - Optional.empty means that the response was not handled here and the custom + * API handler should be applied + * - Optional.of(true) indicates that the response was successfully handled here and + * the request does not need to be retried + * - Optional.of(false) indicates that the response was handled here, but that the request + * will need to be retried + */ + private Optional maybeHandleCommonResponse( + Errors error, + OptionalInt leaderId, + int epoch, + long currentTimeMs + ) { + if (epoch < quorum.epoch() || error == Errors.UNKNOWN_LEADER_EPOCH) { + // We have a larger epoch, so the response is no longer relevant + return Optional.of(true); + } else if (epoch > quorum.epoch() + || error == Errors.FENCED_LEADER_EPOCH + || error == Errors.NOT_LEADER_OR_FOLLOWER) { + + // The response indicates that the request had a stale epoch, but we need + // to validate the epoch from the response against our current state. 
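Stepping back from the branch logic for a moment, the tri-state return value documented above is consumed by every response handler in this file through the same pattern, as it appears in handleVoteResponse, handleBeginQuorumEpochResponse and the others:

    Optional<Boolean> handled = maybeHandleCommonResponse(error, responseLeaderId, responseEpoch, currentTimeMs);
    if (handled.isPresent()) {
        // true: fully handled, do not retry; false: handled, but the caller propagates it
        // so handleResponse records a response error and the request can be retried.
        return handled.get();
    } else if (error == Errors.NONE) {
        // Not handled here and no error: fall through to the API-specific success handling.
        return true;
    } else {
        return handleUnexpectedError(error, responseMetadata);
    }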
+ maybeTransition(leaderId, epoch, currentTimeMs); + return Optional.of(true); + } else if (epoch == quorum.epoch() + && leaderId.isPresent() + && !quorum.hasLeader()) { + + // Since we are transitioning to Follower, we will only forward the + // request to the handler if there is no error. Otherwise, we will let + // the request be retried immediately (if needed) after the transition. + // This handling allows an observer to discover the leader and append + // to the log in the same Fetch request. + transitionToFollower(epoch, leaderId.getAsInt(), currentTimeMs); + if (error == Errors.NONE) { + return Optional.empty(); + } else { + return Optional.of(true); + } + } else if (error == Errors.BROKER_NOT_AVAILABLE) { + return Optional.of(false); + } else if (error == Errors.INCONSISTENT_GROUP_PROTOCOL) { + // For now we treat this as a fatal error. Once we have support for quorum + // reassignment, this error could suggest that either we or the recipient of + // the request just has stale voter information, which means we can retry + // after backing off. + throw new IllegalStateException("Received error indicating inconsistent voter sets"); + } else if (error == Errors.INVALID_REQUEST) { + throw new IllegalStateException("Received unexpected invalid request error"); + } + + return Optional.empty(); + } + + private void maybeTransition( + OptionalInt leaderId, + int epoch, + long currentTimeMs + ) { + if (!hasConsistentLeader(epoch, leaderId)) { + throw new IllegalStateException("Received request or response with leader " + leaderId + + " and epoch " + epoch + " which is inconsistent with current leader " + + quorum.leaderId() + " and epoch " + quorum.epoch()); + } else if (epoch > quorum.epoch()) { + if (leaderId.isPresent()) { + transitionToFollower(epoch, leaderId.getAsInt(), currentTimeMs); + } else { + transitionToUnattached(epoch); + } + } else if (leaderId.isPresent() && !quorum.hasLeader()) { + // The request or response indicates the leader of the current epoch, + // which is currently unknown + transitionToFollower(epoch, leaderId.getAsInt(), currentTimeMs); + } + } + + private boolean handleTopLevelError(Errors error, RaftResponse.Inbound response) { + if (error == Errors.BROKER_NOT_AVAILABLE) { + return false; + } else if (error == Errors.CLUSTER_AUTHORIZATION_FAILED) { + throw new ClusterAuthorizationException("Received cluster authorization error in response " + response); + } else { + return handleUnexpectedError(error, response); + } + } + + private boolean handleUnexpectedError(Errors error, RaftResponse.Inbound response) { + logger.error("Unexpected error {} in {} response: {}", + error, ApiKeys.forId(response.data.apiKey()), response); + return false; + } + + private void handleResponse(RaftResponse.Inbound response, long currentTimeMs) { + // The response epoch matches the local epoch, so we can handle the response + ApiKeys apiKey = ApiKeys.forId(response.data.apiKey()); + final boolean handledSuccessfully; + + switch (apiKey) { + case FETCH: + handledSuccessfully = handleFetchResponse(response, currentTimeMs); + break; + + case VOTE: + handledSuccessfully = handleVoteResponse(response, currentTimeMs); + break; + + case BEGIN_QUORUM_EPOCH: + handledSuccessfully = handleBeginQuorumEpochResponse(response, currentTimeMs); + break; + + case END_QUORUM_EPOCH: + handledSuccessfully = handleEndQuorumEpochResponse(response, currentTimeMs); + break; + + case FETCH_SNAPSHOT: + handledSuccessfully = handleFetchSnapshotResponse(response, currentTimeMs); + break; + + default: + throw 
new IllegalArgumentException("Received unexpected response type: " + apiKey); + } + + ConnectionState connection = requestManager.getOrCreate(response.sourceId()); + if (handledSuccessfully) { + connection.onResponseReceived(response.correlationId); + } else { + connection.onResponseError(response.correlationId, currentTimeMs); + } + } + + /** + * Validate a request which is only valid between voters. If an error is + * present in the returned value, it should be returned in the response. + */ + private Optional validateVoterOnlyRequest(int remoteNodeId, int requestEpoch) { + if (requestEpoch < quorum.epoch()) { + return Optional.of(Errors.FENCED_LEADER_EPOCH); + } else if (remoteNodeId < 0) { + return Optional.of(Errors.INVALID_REQUEST); + } else if (quorum.isObserver() || !quorum.isVoter(remoteNodeId)) { + return Optional.of(Errors.INCONSISTENT_VOTER_SET); + } else { + return Optional.empty(); + } + } + + /** + * Validate a request which is intended for the current quorum leader. + * If an error is present in the returned value, it should be returned + * in the response. + */ + private Optional validateLeaderOnlyRequest(int requestEpoch) { + if (requestEpoch < quorum.epoch()) { + return Optional.of(Errors.FENCED_LEADER_EPOCH); + } else if (requestEpoch > quorum.epoch()) { + return Optional.of(Errors.UNKNOWN_LEADER_EPOCH); + } else if (!quorum.isLeader()) { + // In general, non-leaders do not expect to receive requests + // matching their own epoch, but it is possible when observers + // are using the Fetch API to find the result of an election. + return Optional.of(Errors.NOT_LEADER_OR_FOLLOWER); + } else if (shutdown.get() != null) { + return Optional.of(Errors.BROKER_NOT_AVAILABLE); + } else { + return Optional.empty(); + } + } + + private void handleRequest(RaftRequest.Inbound request, long currentTimeMs) { + ApiKeys apiKey = ApiKeys.forId(request.data.apiKey()); + final CompletableFuture responseFuture; + + switch (apiKey) { + case FETCH: + responseFuture = handleFetchRequest(request, currentTimeMs); + break; + + case VOTE: + responseFuture = completedFuture(handleVoteRequest(request)); + break; + + case BEGIN_QUORUM_EPOCH: + responseFuture = completedFuture(handleBeginQuorumEpochRequest(request, currentTimeMs)); + break; + + case END_QUORUM_EPOCH: + responseFuture = completedFuture(handleEndQuorumEpochRequest(request, currentTimeMs)); + break; + + case DESCRIBE_QUORUM: + responseFuture = completedFuture(handleDescribeQuorumRequest(request, currentTimeMs)); + break; + + case FETCH_SNAPSHOT: + responseFuture = completedFuture(handleFetchSnapshotRequest(request)); + break; + + default: + throw new IllegalArgumentException("Unexpected request type " + apiKey); + } + + responseFuture.whenComplete((response, exception) -> { + final ApiMessage message; + if (response != null) { + message = response; + } else { + message = RaftUtil.errorResponse(apiKey, Errors.forException(exception)); + } + + RaftResponse.Outbound responseMessage = new RaftResponse.Outbound(request.correlationId(), message); + request.completion.complete(responseMessage); + logger.trace("Sent response {} to inbound request {}", responseMessage, request); + }); + } + + private void handleInboundMessage(RaftMessage message, long currentTimeMs) { + logger.trace("Received inbound message {}", message); + + if (message instanceof RaftRequest.Inbound) { + RaftRequest.Inbound request = (RaftRequest.Inbound) message; + handleRequest(request, currentTimeMs); + } else if (message instanceof RaftResponse.Inbound) { + 
RaftResponse.Inbound response = (RaftResponse.Inbound) message; + ConnectionState connection = requestManager.getOrCreate(response.sourceId()); + if (connection.isResponseExpected(response.correlationId)) { + handleResponse(response, currentTimeMs); + } else { + logger.debug("Ignoring response {} since it is no longer needed", response); + } + } else { + throw new IllegalArgumentException("Unexpected message " + message); + } + } + + /** + * Attempt to send a request. Return the time to wait before the request can be retried. + */ + private long maybeSendRequest( + long currentTimeMs, + int destinationId, + Supplier requestSupplier + ) { + ConnectionState connection = requestManager.getOrCreate(destinationId); + + if (connection.isBackingOff(currentTimeMs)) { + long remainingBackoffMs = connection.remainingBackoffMs(currentTimeMs); + logger.debug("Connection for {} is backing off for {} ms", destinationId, remainingBackoffMs); + return remainingBackoffMs; + } + + if (connection.isReady(currentTimeMs)) { + int correlationId = channel.newCorrelationId(); + ApiMessage request = requestSupplier.get(); + + RaftRequest.Outbound requestMessage = new RaftRequest.Outbound( + correlationId, + request, + destinationId, + currentTimeMs + ); + + requestMessage.completion.whenComplete((response, exception) -> { + if (exception != null) { + ApiKeys api = ApiKeys.forId(request.apiKey()); + Errors error = Errors.forException(exception); + ApiMessage errorResponse = RaftUtil.errorResponse(api, error); + + response = new RaftResponse.Inbound( + correlationId, + errorResponse, + destinationId + ); + } + + messageQueue.add(response); + }); + + channel.send(requestMessage); + logger.trace("Sent outbound request: {}", requestMessage); + connection.onRequestSent(correlationId, currentTimeMs); + return Long.MAX_VALUE; + } + + return connection.remainingRequestTimeMs(currentTimeMs); + } + + private EndQuorumEpochRequestData buildEndQuorumEpochRequest( + ResignedState state + ) { + return EndQuorumEpochRequest.singletonRequest( + log.topicPartition(), + clusterId, + quorum.epoch(), + quorum.localIdOrThrow(), + state.preferredSuccessors() + ); + } + + private long maybeSendRequests( + long currentTimeMs, + Set destinationIds, + Supplier requestSupplier + ) { + long minBackoffMs = Long.MAX_VALUE; + for (Integer destinationId : destinationIds) { + long backoffMs = maybeSendRequest(currentTimeMs, destinationId, requestSupplier); + if (backoffMs < minBackoffMs) { + minBackoffMs = backoffMs; + } + } + return minBackoffMs; + } + + private BeginQuorumEpochRequestData buildBeginQuorumEpochRequest() { + return BeginQuorumEpochRequest.singletonRequest( + log.topicPartition(), + clusterId, + quorum.epoch(), + quorum.localIdOrThrow() + ); + } + + private VoteRequestData buildVoteRequest() { + OffsetAndEpoch endOffset = endOffset(); + return VoteRequest.singletonRequest( + log.topicPartition(), + clusterId, + quorum.epoch(), + quorum.localIdOrThrow(), + endOffset.epoch, + endOffset.offset + ); + } + + private FetchRequestData buildFetchRequest() { + FetchRequestData request = RaftUtil.singletonFetchRequest(log.topicPartition(), log.topicId(), fetchPartition -> { + fetchPartition + .setCurrentLeaderEpoch(quorum.epoch()) + .setLastFetchedEpoch(log.lastFetchedEpoch()) + .setFetchOffset(log.endOffset().offset); + }); + return request + .setMaxBytes(MAX_FETCH_SIZE_BYTES) + .setMaxWaitMs(fetchMaxWaitMs) + .setClusterId(clusterId) + .setReplicaId(quorum.localIdOrSentinel()); + } + + private long maybeSendAnyVoterFetch(long 
currentTimeMs) { + OptionalInt readyVoterIdOpt = requestManager.findReadyVoter(currentTimeMs); + if (readyVoterIdOpt.isPresent()) { + return maybeSendRequest( + currentTimeMs, + readyVoterIdOpt.getAsInt(), + this::buildFetchRequest + ); + } else { + return requestManager.backoffBeforeAvailableVoter(currentTimeMs); + } + } + + private FetchSnapshotRequestData buildFetchSnapshotRequest(OffsetAndEpoch snapshotId, long snapshotSize) { + FetchSnapshotRequestData.SnapshotId requestSnapshotId = new FetchSnapshotRequestData.SnapshotId() + .setEpoch(snapshotId.epoch) + .setEndOffset(snapshotId.offset); + + FetchSnapshotRequestData request = FetchSnapshotRequest.singleton( + clusterId, + log.topicPartition(), + snapshotPartition -> { + return snapshotPartition + .setCurrentLeaderEpoch(quorum.epoch()) + .setSnapshotId(requestSnapshotId) + .setPosition(snapshotSize); + } + ); + + return request.setReplicaId(quorum.localIdOrSentinel()); + } + + private FetchSnapshotResponseData.PartitionSnapshot addQuorumLeader( + FetchSnapshotResponseData.PartitionSnapshot partitionSnapshot + ) { + partitionSnapshot.currentLeader() + .setLeaderEpoch(quorum.epoch()) + .setLeaderId(quorum.leaderIdOrSentinel()); + + return partitionSnapshot; + } + + public boolean isRunning() { + GracefulShutdown gracefulShutdown = shutdown.get(); + return gracefulShutdown == null || !gracefulShutdown.isFinished(); + } + + public boolean isShuttingDown() { + GracefulShutdown gracefulShutdown = shutdown.get(); + return gracefulShutdown != null && !gracefulShutdown.isFinished(); + } + + private void appendBatch( + LeaderState state, + BatchAccumulator.CompletedBatch batch, + long appendTimeMs + ) { + try { + int epoch = state.epoch(); + LogAppendInfo info = appendAsLeader(batch.data); + OffsetAndEpoch offsetAndEpoch = new OffsetAndEpoch(info.lastOffset, epoch); + CompletableFuture future = appendPurgatory.await( + offsetAndEpoch.offset + 1, Integer.MAX_VALUE); + + future.whenComplete((commitTimeMs, exception) -> { + if (exception != null) { + logger.debug("Failed to commit {} records at {}", batch.numRecords, offsetAndEpoch, exception); + } else { + long elapsedTime = Math.max(0, commitTimeMs - appendTimeMs); + double elapsedTimePerRecord = (double) elapsedTime / batch.numRecords; + kafkaRaftMetrics.updateCommitLatency(elapsedTimePerRecord, appendTimeMs); + logger.debug("Completed commit of {} records at {}", batch.numRecords, offsetAndEpoch); + batch.records.ifPresent(records -> { + maybeFireHandleCommit(batch.baseOffset, epoch, batch.appendTimestamp(), batch.sizeInBytes(), records); + }); + } + }); + } finally { + batch.release(); + } + } + + private long maybeAppendBatches( + LeaderState state, + long currentTimeMs + ) { + long timeUntilDrain = state.accumulator().timeUntilDrain(currentTimeMs); + if (timeUntilDrain <= 0) { + List> batches = state.accumulator().drain(); + Iterator> iterator = batches.iterator(); + + try { + while (iterator.hasNext()) { + BatchAccumulator.CompletedBatch batch = iterator.next(); + appendBatch(state, batch, currentTimeMs); + } + flushLeaderLog(state, currentTimeMs); + } finally { + // Release and discard any batches which failed to be appended + while (iterator.hasNext()) { + iterator.next().release(); + } + } + } + return timeUntilDrain; + } + + private long pollResigned(long currentTimeMs) { + ResignedState state = quorum.resignedStateOrThrow(); + long endQuorumBackoffMs = maybeSendRequests( + currentTimeMs, + state.unackedVoters(), + () -> buildEndQuorumEpochRequest(state) + ); + + GracefulShutdown 
shutdown = this.shutdown.get(); + final long stateTimeoutMs; + if (shutdown != null) { + // If we are shutting down, then we will remain in the resigned state + // until either the shutdown expires or an election bumps the epoch + stateTimeoutMs = shutdown.remainingTimeMs(); + } else if (state.hasElectionTimeoutExpired(currentTimeMs)) { + transitionToCandidate(currentTimeMs); + stateTimeoutMs = 0L; + } else { + stateTimeoutMs = state.remainingElectionTimeMs(currentTimeMs); + } + + return Math.min(stateTimeoutMs, endQuorumBackoffMs); + } + + private long pollLeader(long currentTimeMs) { + LeaderState state = quorum.leaderStateOrThrow(); + maybeFireLeaderChange(state); + + if (shutdown.get() != null || state.isResignRequested()) { + transitionToResigned(state.nonLeaderVotersByDescendingFetchOffset()); + return 0L; + } + + long timeUntilFlush = maybeAppendBatches( + state, + currentTimeMs + ); + + long timeUntilSend = maybeSendRequests( + currentTimeMs, + state.nonAcknowledgingVoters(), + this::buildBeginQuorumEpochRequest + ); + + return Math.min(timeUntilFlush, timeUntilSend); + } + + private long maybeSendVoteRequests( + CandidateState state, + long currentTimeMs + ) { + // Continue sending Vote requests as long as we still have a chance to win the election + if (!state.isVoteRejected()) { + return maybeSendRequests( + currentTimeMs, + state.unrecordedVoters(), + this::buildVoteRequest + ); + } + return Long.MAX_VALUE; + } + + private long pollCandidate(long currentTimeMs) { + CandidateState state = quorum.candidateStateOrThrow(); + GracefulShutdown shutdown = this.shutdown.get(); + + if (shutdown != null) { + // If we happen to shutdown while we are a candidate, we will continue + // with the current election until one of the following conditions is met: + // 1) we are elected as leader (which allows us to resign) + // 2) another leader is elected + // 3) the shutdown timer expires + long minRequestBackoffMs = maybeSendVoteRequests(state, currentTimeMs); + return Math.min(shutdown.remainingTimeMs(), minRequestBackoffMs); + } else if (state.isBackingOff()) { + if (state.isBackoffComplete(currentTimeMs)) { + logger.info("Re-elect as candidate after election backoff has completed"); + transitionToCandidate(currentTimeMs); + return 0L; + } + return state.remainingBackoffMs(currentTimeMs); + } else if (state.hasElectionTimeoutExpired(currentTimeMs)) { + long backoffDurationMs = binaryExponentialElectionBackoffMs(state.retries()); + logger.debug("Election has timed out, backing off for {}ms before becoming a candidate again", + backoffDurationMs); + state.startBackingOff(currentTimeMs, backoffDurationMs); + return backoffDurationMs; + } else { + long minRequestBackoffMs = maybeSendVoteRequests(state, currentTimeMs); + return Math.min(minRequestBackoffMs, state.remainingElectionTimeMs(currentTimeMs)); + } + } + + private long pollFollower(long currentTimeMs) { + FollowerState state = quorum.followerStateOrThrow(); + if (quorum.isVoter()) { + return pollFollowerAsVoter(state, currentTimeMs); + } else { + return pollFollowerAsObserver(state, currentTimeMs); + } + } + + private long pollFollowerAsVoter(FollowerState state, long currentTimeMs) { + GracefulShutdown shutdown = this.shutdown.get(); + if (shutdown != null) { + // If we are a follower, then we can shutdown immediately. We want to + // skip the transition to candidate in any case. 
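        /*
         * [Reviewer aside: illustrative sketch, not part of this patch.]
         * The candidate path in pollCandidate() above backs off between failed elections via
         * binaryExponentialElectionBackoffMs(state.retries()). One plausible shape of that kind of
         * backoff, with made-up constants and jitter (the real base and cap are configured elsewhere):
         *
         *   import java.util.Random;
         *
         *   final class ElectionBackoffSketch {
         *       private static final int BACKOFF_BASE_MS = 50;    // assumed, for illustration only
         *       private static final int MAX_BACKOFF_MS = 1_000;  // assumed, for illustration only
         *       private static final Random RANDOM = new Random();
         *
         *       // Double the window with each retry, then add jitter so that competing
         *       // candidates are unlikely to time out in lock step.
         *       static long backoffMs(int retries) {
         *           long window = Math.min(MAX_BACKOFF_MS, BACKOFF_BASE_MS * (1L << Math.min(retries, 10)));
         *           return window / 2 + RANDOM.nextInt((int) (window / 2) + 1);
         *       }
         *   }
         */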
+ return 0; + } else if (state.hasFetchTimeoutExpired(currentTimeMs)) { + logger.info("Become candidate due to fetch timeout"); + transitionToCandidate(currentTimeMs); + return 0L; + } else { + long backoffMs = maybeSendFetchOrFetchSnapshot(state, currentTimeMs); + + return Math.min(backoffMs, state.remainingFetchTimeMs(currentTimeMs)); + } + } + + private long pollFollowerAsObserver(FollowerState state, long currentTimeMs) { + if (state.hasFetchTimeoutExpired(currentTimeMs)) { + return maybeSendAnyVoterFetch(currentTimeMs); + } else { + final long backoffMs; + + // If the current leader is backing off due to some failure or if the + // request has timed out, then we attempt to send the Fetch to another + // voter in order to discover if there has been a leader change. + ConnectionState connection = requestManager.getOrCreate(state.leaderId()); + if (connection.hasRequestTimedOut(currentTimeMs)) { + backoffMs = maybeSendAnyVoterFetch(currentTimeMs); + connection.reset(); + } else if (connection.isBackingOff(currentTimeMs)) { + backoffMs = maybeSendAnyVoterFetch(currentTimeMs); + } else { + backoffMs = maybeSendFetchOrFetchSnapshot(state, currentTimeMs); + } + + return Math.min(backoffMs, state.remainingFetchTimeMs(currentTimeMs)); + } + } + + private long maybeSendFetchOrFetchSnapshot(FollowerState state, long currentTimeMs) { + final Supplier requestSupplier; + + if (state.fetchingSnapshot().isPresent()) { + RawSnapshotWriter snapshot = state.fetchingSnapshot().get(); + long snapshotSize = snapshot.sizeInBytes(); + + requestSupplier = () -> buildFetchSnapshotRequest(snapshot.snapshotId(), snapshotSize); + } else { + requestSupplier = this::buildFetchRequest; + } + + return maybeSendRequest(currentTimeMs, state.leaderId(), requestSupplier); + } + + private long pollVoted(long currentTimeMs) { + VotedState state = quorum.votedStateOrThrow(); + GracefulShutdown shutdown = this.shutdown.get(); + + if (shutdown != null) { + // If shutting down, then remain in this state until either the + // shutdown completes or an epoch bump forces another state transition + return shutdown.remainingTimeMs(); + } else if (state.hasElectionTimeoutExpired(currentTimeMs)) { + transitionToCandidate(currentTimeMs); + return 0L; + } else { + return state.remainingElectionTimeMs(currentTimeMs); + } + } + + private long pollUnattached(long currentTimeMs) { + UnattachedState state = quorum.unattachedStateOrThrow(); + if (quorum.isVoter()) { + return pollUnattachedAsVoter(state, currentTimeMs); + } else { + return pollUnattachedAsObserver(state, currentTimeMs); + } + } + + private long pollUnattachedAsVoter(UnattachedState state, long currentTimeMs) { + GracefulShutdown shutdown = this.shutdown.get(); + if (shutdown != null) { + // If shutting down, then remain in this state until either the + // shutdown completes or an epoch bump forces another state transition + return shutdown.remainingTimeMs(); + } else if (state.hasElectionTimeoutExpired(currentTimeMs)) { + transitionToCandidate(currentTimeMs); + return 0L; + } else { + return state.remainingElectionTimeMs(currentTimeMs); + } + } + + private long pollUnattachedAsObserver(UnattachedState state, long currentTimeMs) { + long fetchBackoffMs = maybeSendAnyVoterFetch(currentTimeMs); + return Math.min(fetchBackoffMs, state.remainingElectionTimeMs(currentTimeMs)); + } + + private long pollCurrentState(long currentTimeMs) { + if (quorum.isLeader()) { + return pollLeader(currentTimeMs); + } else if (quorum.isCandidate()) { + return pollCandidate(currentTimeMs); + } else 
if (quorum.isFollower()) { + return pollFollower(currentTimeMs); + } else if (quorum.isVoted()) { + return pollVoted(currentTimeMs); + } else if (quorum.isUnattached()) { + return pollUnattached(currentTimeMs); + } else if (quorum.isResigned()) { + return pollResigned(currentTimeMs); + } else { + throw new IllegalStateException("Unexpected quorum state " + quorum); + } + } + + private void pollListeners() { + // Apply all of the pending registration + while (true) { + Registration registration = pendingRegistrations.poll(); + if (registration == null) { + break; + } + + processRegistration(registration); + } + + // Check listener progress to see if reads are expected + quorum.highWatermark().ifPresent(highWatermarkMetadata -> { + updateListenersProgress(highWatermarkMetadata.offset); + }); + } + + private void processRegistration(Registration registration) { + Listener listener = registration.listener(); + Registration.Ops ops = registration.ops(); + + if (ops == Registration.Ops.REGISTER) { + if (listenerContexts.putIfAbsent(listener, new ListenerContext(listener)) != null) { + logger.error("Attempting to add a listener that already exists: {}", listenerName(listener)); + } else { + logger.info("Registered the listener {}", listenerName(listener)); + } + } else { + if (listenerContexts.remove(listener) == null) { + logger.error("Attempting to remove a listener that doesn't exists: {}", listenerName(listener)); + } else { + logger.info("Unregistered the listener {}", listenerName(listener)); + } + } + } + + private boolean maybeCompleteShutdown(long currentTimeMs) { + GracefulShutdown shutdown = this.shutdown.get(); + if (shutdown == null) { + return false; + } + + shutdown.update(currentTimeMs); + if (shutdown.hasTimedOut()) { + shutdown.failWithTimeout(); + return true; + } + + if (quorum.isObserver() + || quorum.remoteVoters().isEmpty() + || quorum.hasRemoteLeader()) { + + shutdown.complete(); + return true; + } + + return false; + } + + /** + * A simple timer based log cleaner + */ + private static class RaftMetadataLogCleanerManager { + private final Logger logger; + private final Timer timer; + private final long delayMs; + private final Runnable cleaner; + + RaftMetadataLogCleanerManager(Logger logger, Time time, long delayMs, Runnable cleaner) { + this.logger = logger; + this.timer = time.timer(delayMs); + this.delayMs = delayMs; + this.cleaner = cleaner; + } + + public long maybeClean(long currentTimeMs) { + timer.update(currentTimeMs); + if (timer.isExpired()) { + try { + cleaner.run(); + } catch (Throwable t) { + logger.error("Had an error during log cleaning", t); + } + timer.reset(delayMs); + } + return timer.remainingMs(); + } + } + + private void wakeup() { + messageQueue.wakeup(); + } + + /** + * Handle an inbound request. The response will be returned through + * {@link RaftRequest.Inbound#completion}. + * + * @param request The inbound request + */ + public void handle(RaftRequest.Inbound request) { + messageQueue.add(Objects.requireNonNull(request)); + } + + /** + * Poll for new events. This allows the client to handle inbound + * requests and send any needed outbound requests. 
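 *
 * <p>[Reviewer aside: illustrative sketch, not part of this patch.] poll() is meant to be driven
 * repeatedly from a dedicated IO thread; a minimal driver that uses only the poll() and
 * isRunning() methods defined in this file might look like the following (the method and
 * parameter names are my own):
 * <pre>{@code
 * static void runIoLoop(KafkaRaftClient<?> client) {
 *     while (client.isRunning()) {
 *         client.poll();
 *     }
 * }
 * }</pre>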
+ */ + public void poll() { + pollListeners(); + + long currentTimeMs = time.milliseconds(); + if (maybeCompleteShutdown(currentTimeMs)) { + return; + } + + long pollStateTimeoutMs = pollCurrentState(currentTimeMs); + long cleaningTimeoutMs = snapshotCleaner.maybeClean(currentTimeMs); + long pollTimeoutMs = Math.min(pollStateTimeoutMs, cleaningTimeoutMs); + + kafkaRaftMetrics.updatePollStart(currentTimeMs); + + RaftMessage message = messageQueue.poll(pollTimeoutMs); + + currentTimeMs = time.milliseconds(); + kafkaRaftMetrics.updatePollEnd(currentTimeMs); + + if (message != null) { + handleInboundMessage(message, currentTimeMs); + } + } + + @Override + public long scheduleAppend(int epoch, List records) { + return append(epoch, records, false); + } + + @Override + public long scheduleAtomicAppend(int epoch, List records) { + return append(epoch, records, true); + } + + private long append(int epoch, List records, boolean isAtomic) { + LeaderState leaderState = quorum.maybeLeaderState().orElseThrow( + () -> new NotLeaderException("Append failed because the replication is not the current leader") + ); + + BatchAccumulator accumulator = leaderState.accumulator(); + boolean isFirstAppend = accumulator.isEmpty(); + final long offset; + if (isAtomic) { + offset = accumulator.appendAtomic(epoch, records); + } else { + offset = accumulator.append(epoch, records); + } + + // Wakeup the network channel if either this is the first append + // or the accumulator is ready to drain now. Checking for the first + // append ensures that we give the IO thread a chance to observe + // the linger timeout so that it can schedule its own wakeup in case + // there are no additional appends. + if (isFirstAppend || accumulator.needsDrain(time.milliseconds())) { + wakeup(); + } + return offset; + } + + @Override + public CompletableFuture shutdown(int timeoutMs) { + logger.info("Beginning graceful shutdown"); + CompletableFuture shutdownComplete = new CompletableFuture<>(); + shutdown.set(new GracefulShutdown(timeoutMs, shutdownComplete)); + wakeup(); + return shutdownComplete; + } + + @Override + public void resign(int epoch) { + if (epoch < 0) { + throw new IllegalArgumentException("Attempt to resign from an invalid negative epoch " + epoch); + } + + if (!quorum.isVoter()) { + throw new IllegalStateException("Attempt to resign by a non-voter"); + } + + LeaderAndEpoch leaderAndEpoch = leaderAndEpoch(); + int currentEpoch = leaderAndEpoch.epoch(); + + if (epoch > currentEpoch) { + throw new IllegalArgumentException("Attempt to resign from epoch " + epoch + + " which is larger than the current epoch " + currentEpoch); + } else if (epoch < currentEpoch) { + // If the passed epoch is smaller than the current epoch, then it might mean + // that the listener has not been notified about a leader change that already + // took place. In this case, we consider the call as already fulfilled and + // take no further action. + logger.debug("Ignoring call to resign from epoch {} since it is smaller than the " + + "current epoch {}", epoch, currentEpoch); + return; + } else if (!leaderAndEpoch.isLeader(quorum.localIdOrThrow())) { + throw new IllegalArgumentException("Cannot resign from epoch " + epoch + + " since we are not the leader"); + } else { + // Note that if we transition to another state before we have a chance to + // request resignation, then we consider the call fulfilled. 
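        /*
         * [Reviewer aside: illustrative sketch, not part of this patch.]
         * The epoch checks in this method amount to a small decision table. A standalone
         * restatement (the enum and method names below are mine, not part of this class):
         *
         *   enum ResignDecision { REJECT_FUTURE_EPOCH, IGNORE_STALE_EPOCH, REJECT_NOT_LEADER, PROCEED }
         *
         *   static ResignDecision classifyResign(int requestedEpoch, int currentEpoch, boolean isLocalLeader) {
         *       if (requestedEpoch > currentEpoch) {
         *           return ResignDecision.REJECT_FUTURE_EPOCH;  // rejected with IllegalArgumentException
         *       } else if (requestedEpoch < currentEpoch) {
         *           return ResignDecision.IGNORE_STALE_EPOCH;   // treated as already fulfilled
         *       } else if (!isLocalLeader) {
         *           return ResignDecision.REJECT_NOT_LEADER;    // rejected with IllegalArgumentException
         *       } else {
         *           return ResignDecision.PROCEED;              // request resignation and wake the IO thread
         *       }
         *   }
         */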
+ Optional> leaderStateOpt = quorum.maybeLeaderState(); + if (!leaderStateOpt.isPresent()) { + logger.debug("Ignoring call to resign from epoch {} since this node is " + + "no longer the leader", epoch); + return; + } + + LeaderState leaderState = leaderStateOpt.get(); + if (leaderState.epoch() != epoch) { + logger.debug("Ignoring call to resign from epoch {} since it is smaller than the " + + "current epoch {}", epoch, leaderState.epoch()); + } else { + logger.info("Received user request to resign from the current epoch {}", currentEpoch); + leaderState.requestResign(); + wakeup(); + } + } + } + + @Override + public Optional> createSnapshot( + long committedOffset, + int committedEpoch, + long lastContainedLogTime + ) { + return SnapshotWriter.createWithHeader( + () -> log.createNewSnapshot(new OffsetAndEpoch(committedOffset + 1, committedEpoch)), + MAX_BATCH_SIZE_BYTES, + memoryPool, + time, + lastContainedLogTime, + CompressionType.NONE, + serde + ); + } + + @Override + public void close() { + if (kafkaRaftMetrics != null) { + kafkaRaftMetrics.close(); + } + } + + QuorumState quorum() { + return quorum; + } + + public OptionalLong highWatermark() { + if (quorum.highWatermark().isPresent()) { + return OptionalLong.of(quorum.highWatermark().get().offset); + } else { + return OptionalLong.empty(); + } + } + + private class GracefulShutdown { + final Timer finishTimer; + final CompletableFuture completeFuture; + + public GracefulShutdown(long shutdownTimeoutMs, + CompletableFuture completeFuture) { + this.finishTimer = time.timer(shutdownTimeoutMs); + this.completeFuture = completeFuture; + } + + public void update(long currentTimeMs) { + finishTimer.update(currentTimeMs); + } + + public boolean hasTimedOut() { + return finishTimer.isExpired(); + } + + public boolean isFinished() { + return completeFuture.isDone(); + } + + public long remainingTimeMs() { + return finishTimer.remainingMs(); + } + + public void failWithTimeout() { + logger.warn("Graceful shutdown timed out after {}ms", finishTimer.timeoutMs()); + completeFuture.completeExceptionally( + new TimeoutException("Timeout expired before graceful shutdown completed")); + } + + public void complete() { + logger.info("Graceful shutdown completed"); + completeFuture.complete(null); + } + } + + private static final class Registration { + private final Ops ops; + private final Listener listener; + + private Registration(Ops ops, Listener listener) { + this.ops = ops; + this.listener = listener; + } + + private Ops ops() { + return ops; + } + + private Listener listener() { + return listener; + } + + private enum Ops { + REGISTER, UNREGISTER + } + + private static Registration register(Listener listener) { + return new Registration<>(Ops.REGISTER, listener); + } + + private static Registration unregister(Listener listener) { + return new Registration<>(Ops.UNREGISTER, listener); + } + } + + private final class ListenerContext implements CloseListener> { + private final RaftClient.Listener listener; + // This field is used only by the Raft IO thread + private LeaderAndEpoch lastFiredLeaderChange = new LeaderAndEpoch(OptionalInt.empty(), 0); + + // These fields are visible to both the Raft IO thread and the listener + // and are protected through synchronization on this ListenerContext instance + private BatchReader lastSent = null; + private long nextOffset = 0; + + private ListenerContext(Listener listener) { + this.listener = listener; + } + + /** + * Get the last acked offset, which is one greater than the offset of the + * last record 
which was acked by the state machine. + */ + private synchronized long nextOffset() { + return nextOffset; + } + + /** + * Get the next expected offset, which might be larger than the last acked + * offset if there are inflight batches which have not been acked yet. + * Note that when fetching from disk, we may not know the last offset of + * inflight data until it has been processed by the state machine. In this case, + * we delay sending additional data until the state machine has read to the + * end and the last offset is determined. + */ + private synchronized OptionalLong nextExpectedOffset() { + if (lastSent != null) { + OptionalLong lastSentOffset = lastSent.lastOffset(); + if (lastSentOffset.isPresent()) { + return OptionalLong.of(lastSentOffset.getAsLong() + 1); + } else { + return OptionalLong.empty(); + } + } else { + return OptionalLong.of(nextOffset); + } + } + + /** + * This API is used when the Listener needs to be notified of a new snapshot. This happens + * when the context's next offset is less than the log start offset. + */ + private void fireHandleSnapshot(SnapshotReader reader) { + synchronized (this) { + nextOffset = reader.snapshotId().offset; + lastSent = null; + } + + logger.debug("Notifying listener {} of snapshot {}", listenerName(), reader.snapshotId()); + listener.handleSnapshot(reader); + } + + /** + * This API is used for committed records that have been received through + * replication. In general, followers will write new data to disk before they + * know whether it has been committed. Rather than retaining the uncommitted + * data in memory, we let the state machine read the records from disk. + */ + private void fireHandleCommit(long baseOffset, Records records) { + fireHandleCommit( + RecordsBatchReader.of( + baseOffset, + records, + serde, + BufferSupplier.create(), + MAX_BATCH_SIZE_BYTES, + this + ) + ); + } + + /** + * This API is used for committed records originating from {@link #scheduleAppend(int, List)} + * or {@link #scheduleAtomicAppend(int, List)} on this instance. In this case, we are able to + * save the original record objects, which saves the need to read them back from disk. This is + * a nice optimization for the leader which is typically doing more work than all of the + * followers. 
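 *
 * <p>[Reviewer aside: illustrative sketch, not part of this patch.] The idea is that the leader
 * already holds the deserialized records it appended, so the commit notification can hand those
 * objects back to the listener instead of re-reading them from the log. In miniature, with
 * entirely hypothetical names and java.util plus java.util.function imports assumed:
 * <pre>{@code
 * class CommitCache<T> {
 *     private final Map<Long, List<T>> recordsByBaseOffset = new HashMap<>();
 *
 *     void onAppend(long baseOffset, List<T> records) {
 *         recordsByBaseOffset.put(baseOffset, records);          // keep the original objects
 *     }
 *
 *     List<T> onCommit(long baseOffset, Supplier<List<T>> readFromLog) {
 *         List<T> cached = recordsByBaseOffset.remove(baseOffset);
 *         return cached != null ? cached : readFromLog.get();    // followers fall back to the log
 *     }
 * }
 * }</pre>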
+ */ + private void fireHandleCommit( + long baseOffset, + int epoch, + long appendTimestamp, + int sizeInBytes, + List records + ) { + Batch batch = Batch.data(baseOffset, epoch, appendTimestamp, sizeInBytes, records); + MemoryBatchReader reader = MemoryBatchReader.of(Collections.singletonList(batch), this); + fireHandleCommit(reader); + } + + private String listenerName() { + return KafkaRaftClient.listenerName(listener); + } + + private void fireHandleCommit(BatchReader reader) { + synchronized (this) { + this.lastSent = reader; + } + logger.debug( + "Notifying listener {} of batch for baseOffset {} and lastOffset {}", + listenerName(), + reader.baseOffset(), + reader.lastOffset() + ); + listener.handleCommit(reader); + } + + private void maybeFireLeaderChange(LeaderAndEpoch leaderAndEpoch) { + if (shouldFireLeaderChange(leaderAndEpoch)) { + lastFiredLeaderChange = leaderAndEpoch; + logger.debug("Notifying listener {} of leader change {}", listenerName(), leaderAndEpoch); + listener.handleLeaderChange(leaderAndEpoch); + } + } + + private boolean shouldFireLeaderChange(LeaderAndEpoch leaderAndEpoch) { + if (leaderAndEpoch.equals(lastFiredLeaderChange)) { + return false; + } else if (leaderAndEpoch.epoch() > lastFiredLeaderChange.epoch()) { + return true; + } else { + return leaderAndEpoch.leaderId().isPresent() && + !lastFiredLeaderChange.leaderId().isPresent(); + } + } + + private void maybeFireLeaderChange(LeaderAndEpoch leaderAndEpoch, long epochStartOffset) { + // If this node is becoming the leader, then we can fire `handleClaim` as soon + // as the listener has caught up to the start of the leader epoch. This guarantees + // that the state machine has seen the full committed state before it becomes + // leader and begins writing to the log. + if (shouldFireLeaderChange(leaderAndEpoch) && nextOffset() >= epochStartOffset) { + lastFiredLeaderChange = leaderAndEpoch; + listener.handleLeaderChange(leaderAndEpoch); + } + } + + public synchronized void onClose(BatchReader reader) { + OptionalLong lastOffset = reader.lastOffset(); + + if (lastOffset.isPresent()) { + nextOffset = lastOffset.getAsLong() + 1; + } + + if (lastSent == reader) { + lastSent = null; + wakeup(); + } + } + } +} diff --git a/raft/src/main/java/org/apache/kafka/raft/LeaderAndEpoch.java b/raft/src/main/java/org/apache/kafka/raft/LeaderAndEpoch.java new file mode 100644 index 0000000..fee0c2f --- /dev/null +++ b/raft/src/main/java/org/apache/kafka/raft/LeaderAndEpoch.java @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.raft; + +import java.util.Objects; +import java.util.OptionalInt; + +public class LeaderAndEpoch { + private final OptionalInt leaderId; + private final int epoch; + + public LeaderAndEpoch(OptionalInt leaderId, int epoch) { + this.leaderId = Objects.requireNonNull(leaderId); + this.epoch = epoch; + } + + public OptionalInt leaderId() { + return leaderId; + } + + public int epoch() { + return epoch; + } + + public boolean isLeader(int nodeId) { + return leaderId.isPresent() && leaderId.getAsInt() == nodeId; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + LeaderAndEpoch that = (LeaderAndEpoch) o; + return epoch == that.epoch && + leaderId.equals(that.leaderId); + } + + @Override + public int hashCode() { + return Objects.hash(leaderId, epoch); + } + + @Override + public String toString() { + return "LeaderAndEpoch(" + + "leaderId=" + leaderId + + ", epoch=" + epoch + + ')'; + } +} diff --git a/raft/src/main/java/org/apache/kafka/raft/LeaderState.java b/raft/src/main/java/org/apache/kafka/raft/LeaderState.java new file mode 100644 index 0000000..de08b7b --- /dev/null +++ b/raft/src/main/java/org/apache/kafka/raft/LeaderState.java @@ -0,0 +1,394 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.raft; + +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.raft.internals.BatchAccumulator; +import org.slf4j.Logger; + +import org.apache.kafka.common.message.LeaderChangeMessage; +import org.apache.kafka.common.message.LeaderChangeMessage.Voter; +import org.apache.kafka.common.record.ControlRecordUtils; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.OptionalLong; +import java.util.Set; +import java.util.stream.Collectors; + +/** + * In the context of LeaderState, an acknowledged voter means one who has acknowledged the current leader by either + * responding to a `BeginQuorumEpoch` request from the leader or by beginning to send `Fetch` requests. + * More specifically, the set of unacknowledged voters are targets for BeginQuorumEpoch requests from the leader until + * they acknowledge the leader. 
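 *
 * <p>[Reviewer aside: illustrative sketch, not part of this patch.] Acknowledgement is tracked per
 * voter, and the leader keeps sending BeginQuorumEpoch only to the voters that have not yet
 * acknowledged it. Restated on its own, with a simplified map in place of the per-replica state
 * kept below (java.util imports assumed):
 * <pre>{@code
 * static Set<Integer> beginQuorumEpochTargets(Map<Integer, Boolean> hasAcknowledgedByVoterId) {
 *     Set<Integer> targets = new HashSet<>();
 *     for (Map.Entry<Integer, Boolean> entry : hasAcknowledgedByVoterId.entrySet()) {
 *         if (!entry.getValue()) {
 *             targets.add(entry.getKey());  // still needs a BeginQuorumEpoch request
 *         }
 *     }
 *     return targets;
 * }
 * }</pre>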
+ */ +public class LeaderState implements EpochState { + static final long OBSERVER_SESSION_TIMEOUT_MS = 300_000L; + + private final int localId; + private final int epoch; + private final long epochStartOffset; + + private Optional highWatermark; + private final Map voterStates = new HashMap<>(); + private final Map observerStates = new HashMap<>(); + private final Set grantingVoters = new HashSet<>(); + private final Logger log; + private final BatchAccumulator accumulator; + + // This is volatile because resignation can be requested from an external thread. + private volatile boolean resignRequested = false; + + protected LeaderState( + int localId, + int epoch, + long epochStartOffset, + Set voters, + Set grantingVoters, + BatchAccumulator accumulator, + LogContext logContext + ) { + this.localId = localId; + this.epoch = epoch; + this.epochStartOffset = epochStartOffset; + this.highWatermark = Optional.empty(); + + for (int voterId : voters) { + boolean hasAcknowledgedLeader = voterId == localId; + this.voterStates.put(voterId, new ReplicaState(voterId, hasAcknowledgedLeader)); + } + this.grantingVoters.addAll(grantingVoters); + this.log = logContext.logger(LeaderState.class); + this.accumulator = Objects.requireNonNull(accumulator, "accumulator must be non-null"); + } + + public BatchAccumulator accumulator() { + return this.accumulator; + } + + private static List convertToVoters(Set voterIds) { + return voterIds.stream() + .map(follower -> new Voter().setVoterId(follower)) + .collect(Collectors.toList()); + } + + public void appendLeaderChangeMessage(long currentTimeMs) { + List voters = convertToVoters(voterStates.keySet()); + List grantingVoters = convertToVoters(this.grantingVoters()); + + LeaderChangeMessage leaderChangeMessage = new LeaderChangeMessage() + .setVersion(ControlRecordUtils.LEADER_CHANGE_SCHEMA_HIGHEST_VERSION) + .setLeaderId(this.election().leaderId()) + .setVoters(voters) + .setGrantingVoters(grantingVoters); + + accumulator.appendLeaderChangeMessage(leaderChangeMessage, currentTimeMs); + accumulator.forceDrain(); + } + + public boolean isResignRequested() { + return resignRequested; + } + + public void requestResign() { + this.resignRequested = true; + } + + @Override + public Optional highWatermark() { + return highWatermark; + } + + @Override + public ElectionState election() { + return ElectionState.withElectedLeader(epoch, localId, voterStates.keySet()); + } + + @Override + public int epoch() { + return epoch; + } + + public Set grantingVoters() { + return this.grantingVoters; + } + + public int localId() { + return localId; + } + + public Set nonAcknowledgingVoters() { + Set nonAcknowledging = new HashSet<>(); + for (ReplicaState state : voterStates.values()) { + if (!state.hasAcknowledgedLeader) + nonAcknowledging.add(state.nodeId); + } + return nonAcknowledging; + } + + private boolean updateHighWatermark() { + // Find the largest offset which is replicated to a majority of replicas (the leader counts) + List followersByDescendingFetchOffset = followersByDescendingFetchOffset(); + + int indexOfHw = voterStates.size() / 2; + Optional highWatermarkUpdateOpt = followersByDescendingFetchOffset.get(indexOfHw).endOffset; + + if (highWatermarkUpdateOpt.isPresent()) { + + // The KRaft protocol requires an extra condition on commitment after a leader + // election. The leader must commit one record from its own epoch before it is + // allowed to expose records from any previous epoch. 
This guarantees that its + // log will contain the largest record (in terms of epoch/offset) in any log + // which ensures that any future leader will have replicated this record as well + // as all records from previous epochs that the current leader has committed. + + LogOffsetMetadata highWatermarkUpdateMetadata = highWatermarkUpdateOpt.get(); + long highWatermarkUpdateOffset = highWatermarkUpdateMetadata.offset; + + if (highWatermarkUpdateOffset > epochStartOffset) { + if (highWatermark.isPresent()) { + LogOffsetMetadata currentHighWatermarkMetadata = highWatermark.get(); + if (highWatermarkUpdateOffset > currentHighWatermarkMetadata.offset + || (highWatermarkUpdateOffset == currentHighWatermarkMetadata.offset && + !highWatermarkUpdateMetadata.metadata.equals(currentHighWatermarkMetadata.metadata))) { + highWatermark = highWatermarkUpdateOpt; + log.trace( + "High watermark updated to {} based on indexOfHw {} and voters {}", + highWatermark, + indexOfHw, + followersByDescendingFetchOffset + ); + return true; + } else if (highWatermarkUpdateOffset < currentHighWatermarkMetadata.offset) { + log.error("The latest computed high watermark {} is smaller than the current " + + "value {}, which suggests that one of the voters has lost committed data. " + + "Full voter replication state: {}", highWatermarkUpdateOffset, + currentHighWatermarkMetadata.offset, voterStates.values()); + return false; + } else { + return false; + } + } else { + highWatermark = highWatermarkUpdateOpt; + log.trace( + "High watermark set to {} based on indexOfHw {} and voters {}", + highWatermark, + indexOfHw, + followersByDescendingFetchOffset + ); + return true; + } + } + } + return false; + } + + /** + * Update the local replica state. + * + * See {@link #updateReplicaState(int, long, LogOffsetMetadata)} + */ + public boolean updateLocalState(long fetchTimestamp, LogOffsetMetadata logOffsetMetadata) { + return updateReplicaState(localId, fetchTimestamp, logOffsetMetadata); + } + + /** + * Update the replica state in terms of fetch time and log end offsets. + * + * @param replicaId replica id + * @param fetchTimestamp fetch timestamp + * @param logOffsetMetadata new log offset and metadata + * @return true if the high watermark is updated too + */ + public boolean updateReplicaState(int replicaId, + long fetchTimestamp, + LogOffsetMetadata logOffsetMetadata) { + // Ignore fetches from negative replica id, as it indicates + // the fetch is from non-replica. For example, a consumer. 
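        /*
         * [Reviewer aside: illustrative sketch, not part of this patch.]
         * The high watermark rule in updateHighWatermark() above can be restated on its own:
         * sort the voters' end offsets in descending order, take the entry at index size/2
         * (the largest offset replicated by a majority), and only expose it once it is past the
         * leader's epoch start offset. The real code additionally guards against the watermark
         * ever moving backwards. With java.util imports assumed:
         *
         *   static OptionalLong majorityHighWatermark(List<Long> voterEndOffsets, long epochStartOffset) {
         *       List<Long> sorted = new ArrayList<>(voterEndOffsets);
         *       sorted.sort(Comparator.reverseOrder());          // largest end offset first
         *       long candidate = sorted.get(sorted.size() / 2);  // replicated by a majority of voters
         *       return candidate > epochStartOffset ? OptionalLong.of(candidate) : OptionalLong.empty();
         *   }
         */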
+ if (replicaId < 0) { + return false; + } + + ReplicaState state = getReplicaState(replicaId); + state.updateFetchTimestamp(fetchTimestamp); + return updateEndOffset(state, logOffsetMetadata); + } + + public List nonLeaderVotersByDescendingFetchOffset() { + return followersByDescendingFetchOffset().stream() + .filter(state -> state.nodeId != localId) + .map(state -> state.nodeId) + .collect(Collectors.toList()); + } + + private List followersByDescendingFetchOffset() { + return new ArrayList<>(this.voterStates.values()).stream() + .sorted() + .collect(Collectors.toList()); + } + + private boolean updateEndOffset(ReplicaState state, + LogOffsetMetadata endOffsetMetadata) { + state.endOffset.ifPresent(currentEndOffset -> { + if (currentEndOffset.offset > endOffsetMetadata.offset) { + if (state.nodeId == localId) { + throw new IllegalStateException("Detected non-monotonic update of local " + + "end offset: " + currentEndOffset.offset + " -> " + endOffsetMetadata.offset); + } else { + log.warn("Detected non-monotonic update of fetch offset from nodeId {}: {} -> {}", + state.nodeId, currentEndOffset.offset, endOffsetMetadata.offset); + } + } + }); + + state.endOffset = Optional.of(endOffsetMetadata); + state.hasAcknowledgedLeader = true; + return isVoter(state.nodeId) && updateHighWatermark(); + } + + public void addAcknowledgementFrom(int remoteNodeId) { + ReplicaState voterState = ensureValidVoter(remoteNodeId); + voterState.hasAcknowledgedLeader = true; + } + + private ReplicaState ensureValidVoter(int remoteNodeId) { + ReplicaState state = voterStates.get(remoteNodeId); + if (state == null) + throw new IllegalArgumentException("Unexpected acknowledgement from non-voter " + remoteNodeId); + return state; + } + + public long epochStartOffset() { + return epochStartOffset; + } + + private ReplicaState getReplicaState(int remoteNodeId) { + ReplicaState state = voterStates.get(remoteNodeId); + if (state == null) { + observerStates.putIfAbsent(remoteNodeId, new ReplicaState(remoteNodeId, false)); + return observerStates.get(remoteNodeId); + } + return state; + } + + Map getVoterEndOffsets() { + return getReplicaEndOffsets(voterStates); + } + + Map getObserverStates(final long currentTimeMs) { + clearInactiveObservers(currentTimeMs); + return getReplicaEndOffsets(observerStates); + } + + private static Map getReplicaEndOffsets( + Map replicaStates) { + return replicaStates.entrySet().stream() + .collect(Collectors.toMap(Map.Entry::getKey, + e -> e.getValue().endOffset.map( + logOffsetMetadata -> logOffsetMetadata.offset).orElse(-1L)) + ); + } + + private void clearInactiveObservers(final long currentTimeMs) { + observerStates.entrySet().removeIf( + integerReplicaStateEntry -> + currentTimeMs - integerReplicaStateEntry.getValue().lastFetchTimestamp.orElse(-1) + >= OBSERVER_SESSION_TIMEOUT_MS); + } + + private boolean isVoter(int remoteNodeId) { + return voterStates.containsKey(remoteNodeId); + } + + private static class ReplicaState implements Comparable { + final int nodeId; + Optional endOffset; + OptionalLong lastFetchTimestamp; + boolean hasAcknowledgedLeader; + + public ReplicaState(int nodeId, boolean hasAcknowledgedLeader) { + this.nodeId = nodeId; + this.endOffset = Optional.empty(); + this.lastFetchTimestamp = OptionalLong.empty(); + this.hasAcknowledgedLeader = hasAcknowledgedLeader; + } + + void updateFetchTimestamp(long currentFetchTimeMs) { + // To be resilient to system time shifts we do not strictly + // require the timestamp be monotonically increasing. 
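        /*
         * [Reviewer aside: illustrative sketch, not part of this patch.]
         * The Math.max below keeps the recorded fetch time from moving backwards if the clock
         * shifts, and that timestamp is what clearInactiveObservers() above uses to expire
         * observers that have stopped fetching. The pruning step, restated on a plain map
         * (java.util imports assumed):
         *
         *   static void pruneInactiveObservers(Map<Integer, Long> lastFetchTimeMsByObserverId,
         *                                      long currentTimeMs,
         *                                      long sessionTimeoutMs) {
         *       lastFetchTimeMsByObserverId.values()
         *           .removeIf(lastFetchTimeMs -> currentTimeMs - lastFetchTimeMs >= sessionTimeoutMs);
         *   }
         */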
+ lastFetchTimestamp = OptionalLong.of(Math.max(lastFetchTimestamp.orElse(-1L), currentFetchTimeMs)); + } + + @Override + public int compareTo(ReplicaState that) { + if (this.endOffset.equals(that.endOffset)) + return Integer.compare(this.nodeId, that.nodeId); + else if (!this.endOffset.isPresent()) + return 1; + else if (!that.endOffset.isPresent()) + return -1; + else + return Long.compare(that.endOffset.get().offset, this.endOffset.get().offset); + } + + @Override + public String toString() { + return String.format( + "ReplicaState(nodeId=%s, endOffset=%s, lastFetchTimestamp=%s, hasAcknowledgedLeader=%s)", + nodeId, + endOffset, + lastFetchTimestamp, + hasAcknowledgedLeader + ); + } + } + + @Override + public boolean canGrantVote(int candidateId, boolean isLogUpToDate) { + log.debug("Rejecting vote request from candidate {} since we are already leader in epoch {}", + candidateId, epoch); + return false; + } + + @Override + public String toString() { + return String.format( + "Leader(localId=%s, epoch=%s, epochStartOffset=%s, highWatermark=%s, voterStates=%s)", + localId, + epoch, + epochStartOffset, + highWatermark, + voterStates + ); + } + + @Override + public String name() { + return "Leader"; + } + + @Override + public void close() { + accumulator.close(); + } + +} diff --git a/raft/src/main/java/org/apache/kafka/raft/LogAppendInfo.java b/raft/src/main/java/org/apache/kafka/raft/LogAppendInfo.java new file mode 100644 index 0000000..6dc036f --- /dev/null +++ b/raft/src/main/java/org/apache/kafka/raft/LogAppendInfo.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.raft; + +/** + * Metadata for the record batch appended to log + */ +public class LogAppendInfo { + + public final long firstOffset; + public final long lastOffset; + + public LogAppendInfo(long firstOffset, long lastOffset) { + this.firstOffset = firstOffset; + this.lastOffset = lastOffset; + } +} diff --git a/raft/src/main/java/org/apache/kafka/raft/LogFetchInfo.java b/raft/src/main/java/org/apache/kafka/raft/LogFetchInfo.java new file mode 100644 index 0000000..7aca7ea --- /dev/null +++ b/raft/src/main/java/org/apache/kafka/raft/LogFetchInfo.java @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.raft; + +import org.apache.kafka.common.record.Records; + +/** + * Metadata for the records fetched from log, including the records itself + */ +public class LogFetchInfo { + + public final Records records; + public final LogOffsetMetadata startOffsetMetadata; + + public LogFetchInfo(Records records, LogOffsetMetadata startOffsetMetadata) { + this.records = records; + this.startOffsetMetadata = startOffsetMetadata; + } +} diff --git a/raft/src/main/java/org/apache/kafka/raft/LogOffsetMetadata.java b/raft/src/main/java/org/apache/kafka/raft/LogOffsetMetadata.java new file mode 100644 index 0000000..6a96619 --- /dev/null +++ b/raft/src/main/java/org/apache/kafka/raft/LogOffsetMetadata.java @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.raft; + +import java.util.Objects; +import java.util.Optional; + +/** + * Metadata for specific local log offset + */ +public class LogOffsetMetadata { + + public final long offset; + public final Optional metadata; + + public LogOffsetMetadata(long offset) { + this(offset, Optional.empty()); + } + + public LogOffsetMetadata(long offset, Optional metadata) { + this.offset = offset; + this.metadata = metadata; + } + + @Override + public String toString() { + return "LogOffsetMetadata(offset=" + offset + + ", metadata=" + metadata + ")"; + } + + @Override + public boolean equals(Object obj) { + if (obj instanceof LogOffsetMetadata) { + LogOffsetMetadata other = (LogOffsetMetadata) obj; + return this.offset == other.offset && + this.metadata.equals(other.metadata); + } else { + return false; + } + } + + @Override + public int hashCode() { + return Objects.hash(offset, metadata); + } +} diff --git a/raft/src/main/java/org/apache/kafka/raft/NetworkChannel.java b/raft/src/main/java/org/apache/kafka/raft/NetworkChannel.java new file mode 100644 index 0000000..e3482e5 --- /dev/null +++ b/raft/src/main/java/org/apache/kafka/raft/NetworkChannel.java @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.raft; + +import java.io.Closeable; + +/** + * A simple network interface with few assumptions. We do not assume ordering + * of requests or even that every outbound request will receive a response. + */ +public interface NetworkChannel extends Closeable { + + /** + * Generate a new and unique correlationId for a new request to be sent. + */ + int newCorrelationId(); + + /** + * Send an outbound request message. + * + * @param request outbound request to send + */ + void send(RaftRequest.Outbound request); + + /** + * Update connection information for the given id. + */ + void updateEndpoint(int id, RaftConfig.InetAddressSpec address); + + default void close() {} + +} diff --git a/raft/src/main/java/org/apache/kafka/raft/OffsetAndEpoch.java b/raft/src/main/java/org/apache/kafka/raft/OffsetAndEpoch.java new file mode 100644 index 0000000..a4b98d7 --- /dev/null +++ b/raft/src/main/java/org/apache/kafka/raft/OffsetAndEpoch.java @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.raft; + +public class OffsetAndEpoch implements Comparable { + public final long offset; + public final int epoch; + + public OffsetAndEpoch(long offset, int epoch) { + this.offset = offset; + this.epoch = epoch; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + OffsetAndEpoch that = (OffsetAndEpoch) o; + + if (offset != that.offset) return false; + return epoch == that.epoch; + } + + @Override + public int hashCode() { + int result = (int) (offset ^ (offset >>> 32)); + result = 31 * result + epoch; + return result; + } + + @Override + public String toString() { + return "OffsetAndEpoch(" + + "offset=" + offset + + ", epoch=" + epoch + + ')'; + } + + @Override + public int compareTo(OffsetAndEpoch o) { + if (epoch == o.epoch) + return Long.compare(offset, o.offset); + return Integer.compare(epoch, o.epoch); + } +} diff --git a/raft/src/main/java/org/apache/kafka/raft/OffsetMetadata.java b/raft/src/main/java/org/apache/kafka/raft/OffsetMetadata.java new file mode 100644 index 0000000..3cd89e9 --- /dev/null +++ b/raft/src/main/java/org/apache/kafka/raft/OffsetMetadata.java @@ -0,0 +1,21 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.raft; + +// Opaque metadata type which should be instantiated by the log implementation +public interface OffsetMetadata { +} diff --git a/raft/src/main/java/org/apache/kafka/raft/QuorumState.java b/raft/src/main/java/org/apache/kafka/raft/QuorumState.java new file mode 100644 index 0000000..23447a9 --- /dev/null +++ b/raft/src/main/java/org/apache/kafka/raft/QuorumState.java @@ -0,0 +1,568 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.raft; + +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.raft.internals.BatchAccumulator; +import org.slf4j.Logger; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Optional; +import java.util.OptionalInt; +import java.util.Random; +import java.util.Set; +import java.util.stream.Collectors; + +/** + * This class is responsible for managing the current state of this node and ensuring + * only valid state transitions. Below we define the possible state transitions and + * how they are triggered: + * + * Unattached|Resigned transitions to: + * Unattached: After learning of a new election with a higher epoch + * Voted: After granting a vote to a candidate + * Candidate: After expiration of the election timeout + * Follower: After discovering a leader with an equal or larger epoch + * + * Voted transitions to: + * Unattached: After learning of a new election with a higher epoch + * Candidate: After expiration of the election timeout + * + * Candidate transitions to: + * Unattached: After learning of a new election with a higher epoch + * Candidate: After expiration of the election timeout + * Leader: After receiving a majority of votes + * + * Leader transitions to: + * Unattached: After learning of a new election with a higher epoch + * Resigned: When shutting down gracefully + * + * Follower transitions to: + * Unattached: After learning of a new election with a higher epoch + * Candidate: After expiration of the fetch timeout + * Follower: After discovering a leader with a larger epoch + * + * Observers follow a simpler state machine. The Voted/Candidate/Leader/Resigned + * states are not possible for observers, so the only transitions that are possible + * are between Unattached and Follower. + * + * Unattached transitions to: + * Unattached: After learning of a new election with a higher epoch + * Follower: After discovering a leader with an equal or larger epoch + * + * Follower transitions to: + * Unattached: After learning of a new election with a higher epoch + * Follower: After discovering a leader with a larger epoch + * + */ +public class QuorumState { + private final OptionalInt localId; + private final Time time; + private final Logger log; + private final QuorumStateStore store; + private final Set voters; + private final Random random; + private final int electionTimeoutMs; + private final int fetchTimeoutMs; + private final LogContext logContext; + + private volatile EpochState state; + + public QuorumState(OptionalInt localId, + Set voters, + int electionTimeoutMs, + int fetchTimeoutMs, + QuorumStateStore store, + Time time, + LogContext logContext, + Random random) { + this.localId = localId; + this.voters = new HashSet<>(voters); + this.electionTimeoutMs = electionTimeoutMs; + this.fetchTimeoutMs = fetchTimeoutMs; + this.store = store; + this.time = time; + this.log = logContext.logger(QuorumState.class); + this.random = random; + this.logContext = logContext; + } + + public void initialize(OffsetAndEpoch logEndOffsetAndEpoch) throws IllegalStateException { + // We initialize in whatever state we were in on shutdown. If we were a leader + // or candidate, probably an election was held, but we will find out about it + // when we send Vote or BeginEpoch requests. 
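        /*
         * [Reviewer aside: illustrative sketch, not part of this patch.]
         * The voter-side transition rules from the class javadoc above, condensed into data.
         * The phase names are mine; they mirror the EpochState implementations in this package
         * (java.util imports assumed):
         *
         *   enum Phase { UNATTACHED, VOTED, CANDIDATE, LEADER, FOLLOWER, RESIGNED }
         *
         *   static final Map<Phase, Set<Phase>> VALID_VOTER_TRANSITIONS = Map.of(
         *       Phase.UNATTACHED, EnumSet.of(Phase.UNATTACHED, Phase.VOTED, Phase.CANDIDATE, Phase.FOLLOWER),
         *       Phase.RESIGNED,   EnumSet.of(Phase.UNATTACHED, Phase.VOTED, Phase.CANDIDATE, Phase.FOLLOWER),
         *       Phase.VOTED,      EnumSet.of(Phase.UNATTACHED, Phase.CANDIDATE),
         *       Phase.CANDIDATE,  EnumSet.of(Phase.UNATTACHED, Phase.CANDIDATE, Phase.LEADER),
         *       Phase.LEADER,     EnumSet.of(Phase.UNATTACHED, Phase.RESIGNED),
         *       Phase.FOLLOWER,   EnumSet.of(Phase.UNATTACHED, Phase.CANDIDATE, Phase.FOLLOWER)
         *   );
         */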
+ + ElectionState election; + try { + election = store.readElectionState(); + if (election == null) { + election = ElectionState.withUnknownLeader(0, voters); + } + } catch (final UncheckedIOException e) { + // For exceptions during state file loading (missing or not readable), + // we could assume the file is corrupted already and should be cleaned up. + log.warn("Clearing local quorum state store after error loading state {}", + store.toString(), e); + store.clear(); + election = ElectionState.withUnknownLeader(0, voters); + } + + final EpochState initialState; + if (!election.voters().isEmpty() && !voters.equals(election.voters())) { + throw new IllegalStateException("Configured voter set: " + voters + + " is different from the voter set read from the state file: " + election.voters() + + ". Check if the quorum configuration is up to date, " + + "or wipe out the local state file if necessary"); + } else if (election.hasVoted() && !isVoter()) { + String localIdDescription = localId.isPresent() ? + localId.getAsInt() + " is not a voter" : + "is undefined"; + throw new IllegalStateException("Initialized quorum state " + election + + " with a voted candidate, which indicates this node was previously " + + " a voter, but the local id " + localIdDescription); + } else if (election.epoch < logEndOffsetAndEpoch.epoch) { + log.warn("Epoch from quorum-state file is {}, which is " + + "smaller than last written epoch {} in the log", + election.epoch, logEndOffsetAndEpoch.epoch); + initialState = new UnattachedState( + time, + logEndOffsetAndEpoch.epoch, + voters, + Optional.empty(), + randomElectionTimeoutMs(), + logContext + ); + } else if (localId.isPresent() && election.isLeader(localId.getAsInt())) { + // If we were previously a leader, then we will start out as resigned + // in the same epoch. This serves two purposes: + // 1. It ensures that we cannot vote for another leader in the same epoch. + // 2. It protects the invariant that each record is uniquely identified by + // offset and epoch, which might otherwise be violated if unflushed data + // is lost after restarting. 
+ initialState = new ResignedState( + time, + localId.getAsInt(), + election.epoch, + voters, + randomElectionTimeoutMs(), + Collections.emptyList(), + logContext + ); + } else if (localId.isPresent() && election.isVotedCandidate(localId.getAsInt())) { + initialState = new CandidateState( + time, + localId.getAsInt(), + election.epoch, + voters, + Optional.empty(), + 1, + randomElectionTimeoutMs(), + logContext + ); + } else if (election.hasVoted()) { + initialState = new VotedState( + time, + election.epoch, + election.votedId(), + voters, + Optional.empty(), + randomElectionTimeoutMs(), + logContext + ); + } else if (election.hasLeader()) { + initialState = new FollowerState( + time, + election.epoch, + election.leaderId(), + voters, + Optional.empty(), + fetchTimeoutMs, + logContext + ); + } else { + initialState = new UnattachedState( + time, + election.epoch, + voters, + Optional.empty(), + randomElectionTimeoutMs(), + logContext + ); + } + + transitionTo(initialState); + } + + public Set<Integer> remoteVoters() { + return voters.stream().filter(voterId -> voterId != localIdOrSentinel()).collect(Collectors.toSet()); + } + + public int localIdOrSentinel() { + return localId.orElse(-1); + } + + public int localIdOrThrow() { + return localId.orElseThrow(() -> new IllegalStateException("Required local id is not present")); + } + + public OptionalInt localId() { + return localId; + } + + public int epoch() { + return state.epoch(); + } + + public int leaderIdOrSentinel() { + return leaderId().orElse(-1); + } + + public Optional<LogOffsetMetadata> highWatermark() { + return state.highWatermark(); + } + + public OptionalInt leaderId() { + + ElectionState election = state.election(); + if (election.hasLeader()) + return OptionalInt.of(state.election().leaderId()); + else + return OptionalInt.empty(); + } + + public boolean hasLeader() { + return leaderId().isPresent(); + } + + public boolean hasRemoteLeader() { + return hasLeader() && leaderIdOrSentinel() != localIdOrSentinel(); + } + + public boolean isVoter() { + return localId.isPresent() && voters.contains(localId.getAsInt()); + } + + public boolean isVoter(int nodeId) { + return voters.contains(nodeId); + } + + public boolean isObserver() { + return !isVoter(); + } + + public void transitionToResigned(List<Integer> preferredSuccessors) { + if (!isLeader()) { + throw new IllegalStateException("Invalid transition to Resigned state from " + state); + } + + // The Resigned state is a soft state which does not need to be persisted. + // A leader will always be re-initialized in this state. + int epoch = state.epoch(); + this.state = new ResignedState( + time, + localIdOrThrow(), + epoch, + voters, + randomElectionTimeoutMs(), + preferredSuccessors, + logContext + ); + log.info("Completed transition to {}", state); + } + + /** + * Transition to the "unattached" state. This means we have found an epoch greater than + * or equal to the current epoch, but we do not yet know of the elected leader.
+ */ + public void transitionToUnattached(int epoch) { + int currentEpoch = state.epoch(); + if (epoch <= currentEpoch) { + throw new IllegalStateException("Cannot transition to Unattached with epoch= " + epoch + + " from current state " + state); + } + + final long electionTimeoutMs; + if (isObserver()) { + electionTimeoutMs = Long.MAX_VALUE; + } else if (isCandidate()) { + electionTimeoutMs = candidateStateOrThrow().remainingElectionTimeMs(time.milliseconds()); + } else if (isVoted()) { + electionTimeoutMs = votedStateOrThrow().remainingElectionTimeMs(time.milliseconds()); + } else if (isUnattached()) { + electionTimeoutMs = unattachedStateOrThrow().remainingElectionTimeMs(time.milliseconds()); + } else { + electionTimeoutMs = randomElectionTimeoutMs(); + } + + transitionTo(new UnattachedState( + time, + epoch, + voters, + state.highWatermark(), + electionTimeoutMs, + logContext + )); + } + + /** + * Grant a vote to a candidate and become a follower for this epoch. We will remain in this + * state until either the election timeout expires or a leader is elected. In particular, + * we do not begin fetching until the election has concluded and {@link #transitionToFollower(int, int)} + * is invoked. + */ + public void transitionToVoted( + int epoch, + int candidateId + ) { + if (localId.isPresent() && candidateId == localId.getAsInt()) { + throw new IllegalStateException("Cannot transition to Voted with votedId=" + candidateId + + " and epoch=" + epoch + " since it matches the local broker.id"); + } else if (isObserver()) { + throw new IllegalStateException("Cannot transition to Voted with votedId=" + candidateId + + " and epoch=" + epoch + " since the local broker.id=" + localId + " is not a voter"); + } else if (!isVoter(candidateId)) { + throw new IllegalStateException("Cannot transition to Voted with voterId=" + candidateId + + " and epoch=" + epoch + " since it is not one of the voters " + voters); + } + + int currentEpoch = state.epoch(); + if (epoch < currentEpoch) { + throw new IllegalStateException("Cannot transition to Voted with votedId=" + candidateId + + " and epoch=" + epoch + " since the current epoch " + currentEpoch + " is larger"); + } else if (epoch == currentEpoch && !isUnattached()) { + throw new IllegalStateException("Cannot transition to Voted with votedId=" + candidateId + + " and epoch=" + epoch + " from the current state " + state); + } + + // Note that we reset the election timeout after voting for a candidate because we + // know that the candidate has at least as good of a chance of getting elected as us + + transitionTo(new VotedState( + time, + epoch, + candidateId, + voters, + state.highWatermark(), + randomElectionTimeoutMs(), + logContext + )); + } + + /** + * Become a follower of an elected leader so that we can begin fetching. 
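 *
 * <p>[Reviewer aside: illustrative sketch, not part of this patch.] The transition methods in this
 * class share one invariant worth calling out: the epoch never moves backwards, and a same-epoch
 * transition is only allowed when it adds information (for example, Unattached to Voted or to
 * Follower). A condensed guard, with a method name of my own:
 * <pre>{@code
 * static void checkEpochTransition(int currentEpoch, int targetEpoch, boolean sameEpochAllowed) {
 *     if (targetEpoch < currentEpoch) {
 *         throw new IllegalStateException("Epoch cannot move backwards: " + currentEpoch + " -> " + targetEpoch);
 *     } else if (targetEpoch == currentEpoch && !sameEpochAllowed) {
 *         throw new IllegalStateException("Invalid same-epoch transition in epoch " + targetEpoch);
 *     }
 * }
 * }</pre>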
+ */ + public void transitionToFollower( + int epoch, + int leaderId + ) { + if (localId.isPresent() && leaderId == localId.getAsInt()) { + throw new IllegalStateException("Cannot transition to Follower with leaderId=" + leaderId + + " and epoch=" + epoch + " since it matches the local broker.id=" + localId); + } else if (!isVoter(leaderId)) { + throw new IllegalStateException("Cannot transition to Follower with leaderId=" + leaderId + + " and epoch=" + epoch + " since it is not one of the voters " + voters); + } + + int currentEpoch = state.epoch(); + if (epoch < currentEpoch) { + throw new IllegalStateException("Cannot transition to Follower with leaderId=" + leaderId + + " and epoch=" + epoch + " since the current epoch " + currentEpoch + " is larger"); + } else if (epoch == currentEpoch + && (isFollower() || isLeader())) { + throw new IllegalStateException("Cannot transition to Follower with leaderId=" + leaderId + + " and epoch=" + epoch + " from state " + state); + } + + transitionTo(new FollowerState( + time, + epoch, + leaderId, + voters, + state.highWatermark(), + fetchTimeoutMs, + logContext + )); + } + + public void transitionToCandidate() { + if (isObserver()) { + throw new IllegalStateException("Cannot transition to Candidate since the local broker.id=" + localId + + " is not one of the voters " + voters); + } else if (isLeader()) { + throw new IllegalStateException("Cannot transition to Candidate since the local broker.id=" + localId + + " since this node is already a Leader with state " + state); + } + + int retries = isCandidate() ? candidateStateOrThrow().retries() + 1 : 1; + int newEpoch = epoch() + 1; + int electionTimeoutMs = randomElectionTimeoutMs(); + + transitionTo(new CandidateState( + time, + localIdOrThrow(), + newEpoch, + voters, + state.highWatermark(), + retries, + electionTimeoutMs, + logContext + )); + } + + public LeaderState transitionToLeader(long epochStartOffset, BatchAccumulator accumulator) { + if (isObserver()) { + throw new IllegalStateException("Cannot transition to Leader since the local broker.id=" + localId + + " is not one of the voters " + voters); + } else if (!isCandidate()) { + throw new IllegalStateException("Cannot transition to Leader from current state " + state); + } + + CandidateState candidateState = candidateStateOrThrow(); + if (!candidateState.isVoteGranted()) + throw new IllegalStateException("Cannot become leader without majority votes granted"); + + // Note that the leader does not retain the high watermark that was known + // in the previous state. The reason for this is to protect the monotonicity + // of the global high watermark, which is exposed through the leader. The + // only way a new leader can be sure that the high watermark is increasing + // monotonically is to wait until a majority of the voters have reached the + // starting offset of the new epoch. The downside of this is that the local + // state machine is temporarily stalled by the advancement of the global + // high watermark even though it only depends on local monotonicity. We + // could address this problem by decoupling the local high watermark, but + // we typically expect the state machine to be caught up anyway. 
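+ // For example (illustrative numbers): with voters {1, 2, 3} and an epoch start offset of
+ // 100, the new leader does not advertise a high watermark until a majority of the voters
+ // (two of the three, counting the leader itself) have replicated up to offset 100, even
+ // if the leader's own log already extends well beyond that point.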
+ + LeaderState state = new LeaderState<>( + localIdOrThrow(), + epoch(), + epochStartOffset, + voters, + candidateState.grantingVoters(), + accumulator, + logContext + ); + transitionTo(state); + return state; + } + + private void transitionTo(EpochState state) { + if (this.state != null) { + try { + this.state.close(); + } catch (IOException e) { + throw new UncheckedIOException( + "Failed to transition from " + this.state.name() + " to " + state.name(), e); + } + } + + this.store.writeElectionState(state.election()); + this.state = state; + log.info("Completed transition to {}", state); + } + + private int randomElectionTimeoutMs() { + if (electionTimeoutMs == 0) + return 0; + return electionTimeoutMs + random.nextInt(electionTimeoutMs); + } + + public boolean canGrantVote(int candidateId, boolean isLogUpToDate) { + return state.canGrantVote(candidateId, isLogUpToDate); + } + + public FollowerState followerStateOrThrow() { + if (isFollower()) + return (FollowerState) state; + throw new IllegalStateException("Expected to be Follower, but the current state is " + state); + } + + public VotedState votedStateOrThrow() { + if (isVoted()) + return (VotedState) state; + throw new IllegalStateException("Expected to be Voted, but current state is " + state); + } + + public UnattachedState unattachedStateOrThrow() { + if (isUnattached()) + return (UnattachedState) state; + throw new IllegalStateException("Expected to be Unattached, but current state is " + state); + } + + @SuppressWarnings("unchecked") + public LeaderState leaderStateOrThrow() { + if (isLeader()) + return (LeaderState) state; + throw new IllegalStateException("Expected to be Leader, but current state is " + state); + } + + @SuppressWarnings("unchecked") + public Optional> maybeLeaderState() { + EpochState state = this.state; + if (state instanceof LeaderState) { + return Optional.of((LeaderState) state); + } else { + return Optional.empty(); + } + } + + public ResignedState resignedStateOrThrow() { + if (isResigned()) + return (ResignedState) state; + throw new IllegalStateException("Expected to be Resigned, but current state is " + state); + } + + public CandidateState candidateStateOrThrow() { + if (isCandidate()) + return (CandidateState) state; + throw new IllegalStateException("Expected to be Candidate, but current state is " + state); + } + + public LeaderAndEpoch leaderAndEpoch() { + ElectionState election = state.election(); + return new LeaderAndEpoch(election.leaderIdOpt, election.epoch); + } + + public boolean isFollower() { + return state instanceof FollowerState; + } + + public boolean isVoted() { + return state instanceof VotedState; + } + + public boolean isUnattached() { + return state instanceof UnattachedState; + } + + public boolean isLeader() { + return state instanceof LeaderState; + } + + public boolean isResigned() { + return state instanceof ResignedState; + } + + public boolean isCandidate() { + return state instanceof CandidateState; + } + +} diff --git a/raft/src/main/java/org/apache/kafka/raft/QuorumStateStore.java b/raft/src/main/java/org/apache/kafka/raft/QuorumStateStore.java new file mode 100644 index 0000000..7e64e2e --- /dev/null +++ b/raft/src/main/java/org/apache/kafka/raft/QuorumStateStore.java @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.raft; + +/** + * Maintain the save and retrieval of quorum state information, so far only supports + * read and write of election states. + */ +public interface QuorumStateStore { + + int UNKNOWN_LEADER_ID = -1; + int NOT_VOTED = -1; + + /** + * Read the latest election state. + * + * @return The latest written election state or `null` if there is none + */ + ElectionState readElectionState(); + + /** + * Persist the updated election state. This must be atomic, both writing the full updated state + * and replacing the old state. + * @param latest The latest election state + */ + void writeElectionState(ElectionState latest); + + /** + * Clear any state associated to the store for a fresh start + */ + void clear(); +} diff --git a/raft/src/main/java/org/apache/kafka/raft/RaftClient.java b/raft/src/main/java/org/apache/kafka/raft/RaftClient.java new file mode 100644 index 0000000..8e4f50e --- /dev/null +++ b/raft/src/main/java/org/apache/kafka/raft/RaftClient.java @@ -0,0 +1,223 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.raft; + +import org.apache.kafka.raft.errors.BufferAllocationException; +import org.apache.kafka.raft.errors.NotLeaderException; +import org.apache.kafka.snapshot.SnapshotReader; +import org.apache.kafka.snapshot.SnapshotWriter; + +import java.util.List; +import java.util.Optional; +import java.util.OptionalInt; +import java.util.concurrent.CompletableFuture; + +public interface RaftClient extends AutoCloseable { + + interface Listener { + /** + * Callback which is invoked for all records committed to the log. + * It is the responsibility of this implementation to invoke {@link BatchReader#close()} + * after consuming the reader. + * + * Note that there is not a one-to-one correspondence between writes through + * {@link #scheduleAppend(int, List)} or {@link #scheduleAtomicAppend(int, List)} + * and this callback. The Raft implementation is free to batch together the records + * from multiple append calls provided that batch boundaries are respected. 
Records
+ * specified through {@link #scheduleAtomicAppend(int, List)} are guaranteed to be a
+ * subset of a batch provided by the {@link BatchReader}. Records specified through
+ * {@link #scheduleAppend(int, List)} are guaranteed to be in the same order but
+ * they can map to any number of batches provided by the {@link BatchReader}.
+ *
+ * @param reader reader instance which must be iterated and closed
+ */
+ void handleCommit(BatchReader reader);
+
+ /**
+ * Callback which is invoked when the Listener needs to load a snapshot.
+ * It is the responsibility of this implementation to invoke {@link SnapshotReader#close()}
+ * after consuming the reader.
+ *
+ * When handling this call, the implementation must assume that all previous calls
+ * to {@link #handleCommit} contain invalid data.
+ *
+ * @param reader snapshot reader instance which must be iterated and closed
+ */
+ void handleSnapshot(SnapshotReader reader);
+
+ /**
+ * Called on any change to leadership. This includes both when a leader is elected and
+ * when a leader steps down or fails.
+ *
+ * If this node is the leader, then the notification of leadership will be delayed until
+ * the implementation of this interface has caught up to the high-watermark through calls to
+ * {@link #handleSnapshot(SnapshotReader)} and {@link #handleCommit(BatchReader)}.
+ *
+ * If this node is not the leader, then this method will be called as soon as possible. In
+ * this case the leader may or may not be known for the current epoch.
+ *
+ * Subsequent calls to this method will expose a monotonically increasing epoch. For a
+ * given epoch the leader may be unknown ({@code leader.leaderId} is {@code OptionalInt#empty})
+ * or known ({@code leader.leaderId} is {@code OptionalInt#of}). Once a leader is known for
+ * a given epoch it will remain the leader for that epoch. In other words, implementations
+ * should expect this method to be called at most twice for each epoch: once when the epoch
+ * changes but the leader is not yet known, and once when the leader becomes known for the
+ * current epoch.
+ *
+ * @param leader the current leader and epoch
+ */
+ default void handleLeaderChange(LeaderAndEpoch leader) {}
+
+ default void beginShutdown() {}
+ }
+
+ /**
+ * Initialize the client. This should only be called once on startup.
+ */
+ void initialize();
+
+ /**
+ * Register a listener to get commit, snapshot and leader notifications.
+ *
+ * The implementation of this interface assumes that each call to {@code register} uses
+ * a different {@code Listener} instance. If the same instance is used for multiple calls
+ * to this method, then only one {@code Listener} will be registered.
+ *
+ * @param listener the listener to register
+ */
+ void register(Listener listener);
+
+ /**
+ * Unregisters a listener.
+ *
+ * To distinguish events that happened before the call to {@code unregister} from those after a
+ * future call to {@code register}, different {@code Listener} instances must be used.
+ *
+ * If the {@code Listener} provided was never registered then the unregistration is ignored.
+ *
+ * @param listener the listener to unregister
+ */
+ void unregister(Listener listener);
+
+ /**
+ * Return the current {@link LeaderAndEpoch}.
+ *
+ * @return the current leader and epoch
+ */
+ LeaderAndEpoch leaderAndEpoch();
+
+ /**
+ * Get the local nodeId if one is defined. This may be absent when the client is used
+ * as an anonymous observer, as in the case of the metadata shell.
+ *
+ * @return optional node id
+ */
+ OptionalInt nodeId();
+
+ /**
+ * Append a list of records to the log. The write will be scheduled for some time
+ * in the future. There is no guarantee that appended records will be written to
+ * the log and eventually committed. While the order of the records is preserved, they can
+ * be appended to the log using one or more batches. Each record may be committed independently.
+ * If a record is committed, then all records scheduled for append during this epoch
+ * and prior to this record are also committed.
+ *
+ * If the provided current leader epoch does not match the current epoch, which
+ * is possible when the state machine has yet to observe the epoch change, then
+ * this method will throw an {@link NotLeaderException} to indicate that the leader
+ * should resign its leadership. The state machine is expected to discard all
+ * uncommitted entries after observing an epoch change.
+ *
+ * @param epoch the current leader epoch
+ * @param records the list of records to append
+ * @return the expected offset of the last record if the append succeeds
+ * @throws org.apache.kafka.common.errors.RecordBatchTooLargeException if the size of the records is greater than the maximum
+ * batch size; if this exception is thrown, none of the elements in records were
+ * committed
+ * @throws NotLeaderException if we are not the current leader or the epoch doesn't match the leader epoch
+ * @throws BufferAllocationException if we failed to allocate memory for the records
+ */
+ long scheduleAppend(int epoch, List records);
+
+ /**
+ * Append a list of records to the log. The write will be scheduled for some time
+ * in the future. There is no guarantee that appended records will be written to
+ * the log and eventually committed. However, it is guaranteed that if any of the
+ * records become committed, then all of them will be.
+ *
+ * If the provided current leader epoch does not match the current epoch, which
+ * is possible when the state machine has yet to observe the epoch change, then
+ * this method will throw an {@link NotLeaderException} to indicate that the leader
+ * should resign its leadership. The state machine is expected to discard all
+ * uncommitted entries after observing an epoch change.
+ *
+ * @param epoch the current leader epoch
+ * @param records the list of records to append
+ * @return the expected offset of the last record if the append succeeds
+ * @throws org.apache.kafka.common.errors.RecordBatchTooLargeException if the size of the records is greater than the maximum
+ * batch size; if this exception is thrown, none of the elements in records were
+ * committed
+ * @throws NotLeaderException if we are not the current leader or the epoch doesn't match the leader epoch
+ * @throws BufferAllocationException if we failed to allocate memory for the records
+ */
+ long scheduleAtomicAppend(int epoch, List records);
+
+ /**
+ * Attempt a graceful shutdown of the client. This allows the leader to proactively
+ * resign and help a new leader to get elected rather than forcing the remaining
+ * voters to wait for the fetch timeout.
+ *
+ * Note that if the client has hit an unexpected exception which has left it in an
+ * indeterminate state, then the call to shutdown should be skipped. However, it
+ * is still expected that {@link #close()} will be used to clean up any resources
+ * in use.
+ *
+ * @param timeoutMs How long to wait for graceful completion of pending operations.
+ * @return A future which is completed when shutdown completes successfully or the timeout expires. + */ + CompletableFuture shutdown(int timeoutMs); + + /** + * Resign the leadership. The leader will give up its leadership in the passed epoch + * (if it matches the current epoch), and a new election will be held. Note that nothing + * prevents this node from being reelected as the leader. + * + * Notification of successful resignation can be observed through + * {@link Listener#handleLeaderChange(LeaderAndEpoch)}. + * + * @param epoch the epoch to resign from. If this does not match the current epoch, this + * call will be ignored. + */ + void resign(int epoch); + + /** + * Create a writable snapshot file for a committed offset and epoch. + * + * The RaftClient assumes that the snapshot returned will contain the records up to and + * including the committed offset and epoch. See {@link SnapshotWriter} for details on + * how to use this object. If a snapshot already exists then returns an + * {@link Optional#empty()}. + * + * @param committedEpoch the epoch of the committed offset + * @param committedOffset the last committed offset that will be included in the snapshot + * @param lastContainedLogTime The append time of the highest record contained in this snapshot + * @return a writable snapshot if it doesn't already exists + * @throws IllegalArgumentException if the committed offset is greater than the high-watermark + * or less than the log start offset. + */ + Optional> createSnapshot(long committedOffset, int committedEpoch, long lastContainedLogTime); +} diff --git a/raft/src/main/java/org/apache/kafka/raft/RaftConfig.java b/raft/src/main/java/org/apache/kafka/raft/RaftConfig.java new file mode 100644 index 0000000..0833df0 --- /dev/null +++ b/raft/src/main/java/org/apache/kafka/raft/RaftConfig.java @@ -0,0 +1,278 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.raft; + +import org.apache.kafka.clients.CommonClientConfigs; +import org.apache.kafka.common.Node; +import org.apache.kafka.common.config.AbstractConfig; +import org.apache.kafka.common.config.ConfigDef; +import org.apache.kafka.common.config.ConfigException; +import org.apache.kafka.common.utils.Utils; + +import java.net.InetSocketAddress; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Set; +import java.util.stream.Collectors; + +/** + * RaftConfig encapsulates configuration specific to the Raft quorum voter nodes. + * Specifically, this class parses the voter node endpoints into an AddressSpec + * for use with the KafkaRaftClient/KafkaNetworkChannel. 
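+ *
+ * A rough sketch of how the voter string is parsed (illustrative values only):
+ * <pre>{@code
+ * List<String> voters = Arrays.asList("1@localhost:9092", "2@localhost:9093", "3@localhost:9094");
+ * Map<Integer, RaftConfig.AddressSpec> connections = RaftConfig.parseVoterConnections(voters);
+ * // Each resolvable entry becomes an InetAddressSpec keyed by its voter id; entries that use
+ * // the non-routable 0.0.0.0:0 placeholder map to the shared UnknownAddressSpec instance.
+ * }</pre>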
+ * + * If the voter endpoints are not known at startup, a non-routable address can be provided instead. + * For example: `1@0.0.0.0:0,2@0.0.0.0:0,3@0.0.0.0:0` + * This will assign an {@link UnknownAddressSpec} to the voter entries + * + * The default raft timeouts are relatively low compared to some other timeouts such as + * request.timeout.ms. This is part of a general design philosophy where we see changing + * the leader of a Raft cluster as a relatively quick operation. For example, the KIP-631 + * controller should be able to transition from standby to active without reloading all of + * the metadata. The standby is a "hot" standby, not a "cold" one. + */ +public class RaftConfig { + + private static final String QUORUM_PREFIX = "controller.quorum."; + + // Non-routable address represents an endpoint that does not resolve to any particular node + public static final InetSocketAddress NON_ROUTABLE_ADDRESS = new InetSocketAddress("0.0.0.0", 0); + public static final UnknownAddressSpec UNKNOWN_ADDRESS_SPEC_INSTANCE = new UnknownAddressSpec(); + + public static final String QUORUM_VOTERS_CONFIG = QUORUM_PREFIX + "voters"; + public static final String QUORUM_VOTERS_DOC = "Map of id/endpoint information for " + + "the set of voters in a comma-separated list of `{id}@{host}:{port}` entries. " + + "For example: `1@localhost:9092,2@localhost:9093,3@localhost:9094`"; + public static final List DEFAULT_QUORUM_VOTERS = Collections.emptyList(); + + public static final String QUORUM_ELECTION_TIMEOUT_MS_CONFIG = QUORUM_PREFIX + "election.timeout.ms"; + public static final String QUORUM_ELECTION_TIMEOUT_MS_DOC = "Maximum time in milliseconds to wait " + + "without being able to fetch from the leader before triggering a new election"; + public static final int DEFAULT_QUORUM_ELECTION_TIMEOUT_MS = 1_000; + + public static final String QUORUM_FETCH_TIMEOUT_MS_CONFIG = QUORUM_PREFIX + "fetch.timeout.ms"; + public static final String QUORUM_FETCH_TIMEOUT_MS_DOC = "Maximum time without a successful fetch from " + + "the current leader before becoming a candidate and triggering a election for voters; Maximum time without " + + "receiving fetch from a majority of the quorum before asking around to see if there's a new epoch for leader"; + public static final int DEFAULT_QUORUM_FETCH_TIMEOUT_MS = 2_000; + + public static final String QUORUM_ELECTION_BACKOFF_MAX_MS_CONFIG = QUORUM_PREFIX + "election.backoff.max.ms"; + public static final String QUORUM_ELECTION_BACKOFF_MAX_MS_DOC = "Maximum time in milliseconds before starting new elections. 
" + + "This is used in the binary exponential backoff mechanism that helps prevent gridlocked elections"; + public static final int DEFAULT_QUORUM_ELECTION_BACKOFF_MAX_MS = 1_000; + + public static final String QUORUM_LINGER_MS_CONFIG = QUORUM_PREFIX + "append.linger.ms"; + public static final String QUORUM_LINGER_MS_DOC = "The duration in milliseconds that the leader will " + + "wait for writes to accumulate before flushing them to disk."; + public static final int DEFAULT_QUORUM_LINGER_MS = 25; + + public static final String QUORUM_REQUEST_TIMEOUT_MS_CONFIG = QUORUM_PREFIX + + CommonClientConfigs.REQUEST_TIMEOUT_MS_CONFIG; + public static final String QUORUM_REQUEST_TIMEOUT_MS_DOC = CommonClientConfigs.REQUEST_TIMEOUT_MS_DOC; + public static final int DEFAULT_QUORUM_REQUEST_TIMEOUT_MS = 2_000; + + public static final String QUORUM_RETRY_BACKOFF_MS_CONFIG = QUORUM_PREFIX + + CommonClientConfigs.RETRY_BACKOFF_MS_CONFIG; + public static final String QUORUM_RETRY_BACKOFF_MS_DOC = CommonClientConfigs.RETRY_BACKOFF_MS_DOC; + public static final int DEFAULT_QUORUM_RETRY_BACKOFF_MS = 20; + + private final int requestTimeoutMs; + private final int retryBackoffMs; + private final int electionTimeoutMs; + private final int electionBackoffMaxMs; + private final int fetchTimeoutMs; + private final int appendLingerMs; + private final Map voterConnections; + + public interface AddressSpec { + } + + public static class InetAddressSpec implements AddressSpec { + public final InetSocketAddress address; + + public InetAddressSpec(InetSocketAddress address) { + if (address == null || address.equals(NON_ROUTABLE_ADDRESS)) { + throw new IllegalArgumentException("Invalid address: " + address); + } + this.address = address; + } + + @Override + public int hashCode() { + return address.hashCode(); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + + if (obj == null || getClass() != obj.getClass()) { + return false; + } + + final InetAddressSpec that = (InetAddressSpec) obj; + return that.address.equals(address); + } + } + + public static class UnknownAddressSpec implements AddressSpec { + private UnknownAddressSpec() { + } + } + + public RaftConfig(AbstractConfig abstractConfig) { + this(parseVoterConnections(abstractConfig.getList(QUORUM_VOTERS_CONFIG)), + abstractConfig.getInt(QUORUM_REQUEST_TIMEOUT_MS_CONFIG), + abstractConfig.getInt(QUORUM_RETRY_BACKOFF_MS_CONFIG), + abstractConfig.getInt(QUORUM_ELECTION_TIMEOUT_MS_CONFIG), + abstractConfig.getInt(QUORUM_ELECTION_BACKOFF_MAX_MS_CONFIG), + abstractConfig.getInt(QUORUM_FETCH_TIMEOUT_MS_CONFIG), + abstractConfig.getInt(QUORUM_LINGER_MS_CONFIG)); + } + + public RaftConfig( + Map voterConnections, + int requestTimeoutMs, + int retryBackoffMs, + int electionTimeoutMs, + int electionBackoffMaxMs, + int fetchTimeoutMs, + int appendLingerMs + ) { + this.voterConnections = voterConnections; + this.requestTimeoutMs = requestTimeoutMs; + this.retryBackoffMs = retryBackoffMs; + this.electionTimeoutMs = electionTimeoutMs; + this.electionBackoffMaxMs = electionBackoffMaxMs; + this.fetchTimeoutMs = fetchTimeoutMs; + this.appendLingerMs = appendLingerMs; + } + + public int requestTimeoutMs() { + return requestTimeoutMs; + } + + public int retryBackoffMs() { + return retryBackoffMs; + } + + public int electionTimeoutMs() { + return electionTimeoutMs; + } + + public int electionBackoffMaxMs() { + return electionBackoffMaxMs; + } + + public int fetchTimeoutMs() { + return fetchTimeoutMs; + } + + public int appendLingerMs() { + 
return appendLingerMs; + } + + public Set quorumVoterIds() { + return quorumVoterConnections().keySet(); + } + + public Map quorumVoterConnections() { + return voterConnections; + } + + private static Integer parseVoterId(String idString) { + try { + return Integer.parseInt(idString); + } catch (NumberFormatException e) { + throw new ConfigException("Failed to parse voter ID as an integer from " + idString); + } + } + + public static Map parseVoterConnections(List voterEntries) { + Map voterMap = new HashMap<>(); + for (String voterMapEntry : voterEntries) { + String[] idAndAddress = voterMapEntry.split("@"); + if (idAndAddress.length != 2) { + throw new ConfigException("Invalid configuration value for " + QUORUM_VOTERS_CONFIG + + ". Each entry should be in the form `{id}@{host}:{port}`."); + } + + Integer voterId = parseVoterId(idAndAddress[0]); + String host = Utils.getHost(idAndAddress[1]); + if (host == null) { + throw new ConfigException("Failed to parse host name from entry " + voterMapEntry + + " for the configuration " + QUORUM_VOTERS_CONFIG + + ". Each entry should be in the form `{id}@{host}:{port}`."); + } + + Integer port = Utils.getPort(idAndAddress[1]); + if (port == null) { + throw new ConfigException("Failed to parse host port from entry " + voterMapEntry + + " for the configuration " + QUORUM_VOTERS_CONFIG + + ". Each entry should be in the form `{id}@{host}:{port}`."); + } + + InetSocketAddress address = new InetSocketAddress(host, port); + if (address.equals(NON_ROUTABLE_ADDRESS)) { + voterMap.put(voterId, UNKNOWN_ADDRESS_SPEC_INSTANCE); + } else { + voterMap.put(voterId, new InetAddressSpec(address)); + } + } + + return voterMap; + } + + public static List quorumVoterStringsToNodes(List voters) { + return voterConnectionsToNodes(parseVoterConnections(voters)); + } + + public static List voterConnectionsToNodes(Map voterConnections) { + return voterConnections.entrySet().stream() + .filter(Objects::nonNull) + .filter(connection -> connection.getValue() instanceof InetAddressSpec) + .map(connection -> { + InetAddressSpec spec = (InetAddressSpec) connection.getValue(); + return new Node(connection.getKey(), spec.address.getHostName(), spec.address.getPort()); + }) + .collect(Collectors.toList()); + } + + public static class ControllerQuorumVotersValidator implements ConfigDef.Validator { + @Override + public void ensureValid(String name, Object value) { + if (value == null) { + throw new ConfigException(name, null); + } + + @SuppressWarnings("unchecked") + List voterStrings = (List) value; + + // Attempt to parse the connect strings + parseVoterConnections(voterStrings); + } + + @Override + public String toString() { + return "non-empty list"; + } + } +} diff --git a/raft/src/main/java/org/apache/kafka/raft/RaftMessage.java b/raft/src/main/java/org/apache/kafka/raft/RaftMessage.java new file mode 100644 index 0000000..f50ec90 --- /dev/null +++ b/raft/src/main/java/org/apache/kafka/raft/RaftMessage.java @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.raft; + +import org.apache.kafka.common.protocol.ApiMessage; + +public interface RaftMessage { + int correlationId(); + + ApiMessage data(); + +} diff --git a/raft/src/main/java/org/apache/kafka/raft/RaftMessageQueue.java b/raft/src/main/java/org/apache/kafka/raft/RaftMessageQueue.java new file mode 100644 index 0000000..7d1e4b7 --- /dev/null +++ b/raft/src/main/java/org/apache/kafka/raft/RaftMessageQueue.java @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.raft; + +/** + * This class is used to serialize inbound requests or responses to outbound requests. + * It basically just allows us to wrap a blocking queue so that we can have a mocked + * implementation which does not depend on system time. + * + * See {@link org.apache.kafka.raft.internals.BlockingMessageQueue}. + */ +public interface RaftMessageQueue { + + /** + * Block for the arrival of a new message. + * + * @param timeoutMs timeout in milliseconds to wait for a new event + * @return the event or null if either the timeout was reached or there was + * a call to {@link #wakeup()} before any events became available + */ + RaftMessage poll(long timeoutMs); + + /** + * Add a new message to the queue. + * + * @param message the message to deliver + * @throws IllegalStateException if the queue cannot accept the message + */ + void add(RaftMessage message); + + /** + * Check whether there are pending messages awaiting delivery. + * + * @return if there are no pending messages to deliver + */ + boolean isEmpty(); + + /** + * Wakeup the thread blocking in {@link #poll(long)}. This will cause + * {@link #poll(long)} to return null if no messages are available. + */ + void wakeup(); + +} diff --git a/raft/src/main/java/org/apache/kafka/raft/RaftRequest.java b/raft/src/main/java/org/apache/kafka/raft/RaftRequest.java new file mode 100644 index 0000000..28e63c1 --- /dev/null +++ b/raft/src/main/java/org/apache/kafka/raft/RaftRequest.java @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.raft; + +import org.apache.kafka.common.protocol.ApiMessage; + +import java.util.concurrent.CompletableFuture; + +public abstract class RaftRequest implements RaftMessage { + protected final int correlationId; + protected final ApiMessage data; + protected final long createdTimeMs; + + public RaftRequest(int correlationId, ApiMessage data, long createdTimeMs) { + this.correlationId = correlationId; + this.data = data; + this.createdTimeMs = createdTimeMs; + } + + @Override + public int correlationId() { + return correlationId; + } + + @Override + public ApiMessage data() { + return data; + } + + public long createdTimeMs() { + return createdTimeMs; + } + + public static class Inbound extends RaftRequest { + public final CompletableFuture completion = new CompletableFuture<>(); + + public Inbound(int correlationId, ApiMessage data, long createdTimeMs) { + super(correlationId, data, createdTimeMs); + } + + @Override + public String toString() { + return "InboundRequest(" + + "correlationId=" + correlationId + + ", data=" + data + + ", createdTimeMs=" + createdTimeMs + + ')'; + } + } + + public static class Outbound extends RaftRequest { + private final int destinationId; + public final CompletableFuture completion = new CompletableFuture<>(); + + public Outbound(int correlationId, ApiMessage data, int destinationId, long createdTimeMs) { + super(correlationId, data, createdTimeMs); + this.destinationId = destinationId; + } + + public int destinationId() { + return destinationId; + } + + @Override + public String toString() { + return "OutboundRequest(" + + "correlationId=" + correlationId + + ", data=" + data + + ", createdTimeMs=" + createdTimeMs + + ", destinationId=" + destinationId + + ')'; + } + } +} diff --git a/raft/src/main/java/org/apache/kafka/raft/RaftResponse.java b/raft/src/main/java/org/apache/kafka/raft/RaftResponse.java new file mode 100644 index 0000000..71101a6 --- /dev/null +++ b/raft/src/main/java/org/apache/kafka/raft/RaftResponse.java @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.raft; + +import org.apache.kafka.common.protocol.ApiMessage; + +public abstract class RaftResponse implements RaftMessage { + protected final int correlationId; + protected final ApiMessage data; + + protected RaftResponse(int correlationId, ApiMessage data) { + this.correlationId = correlationId; + this.data = data; + } + + @Override + public int correlationId() { + return correlationId; + } + + @Override + public ApiMessage data() { + return data; + } + + public static class Inbound extends RaftResponse { + private final int sourceId; + + public Inbound(int correlationId, ApiMessage data, int sourceId) { + super(correlationId, data); + this.sourceId = sourceId; + } + + public int sourceId() { + return sourceId; + } + + @Override + public String toString() { + return "InboundResponse(" + + "correlationId=" + correlationId + + ", data=" + data + + ", sourceId=" + sourceId + + ')'; + } + } + + public static class Outbound extends RaftResponse { + public Outbound(int requestId, ApiMessage data) { + super(requestId, data); + } + + @Override + public String toString() { + return "OutboundResponse(" + + "correlationId=" + correlationId + + ", data=" + data + + ')'; + } + } +} diff --git a/raft/src/main/java/org/apache/kafka/raft/RaftUtil.java b/raft/src/main/java/org/apache/kafka/raft/RaftUtil.java new file mode 100644 index 0000000..9ff0361 --- /dev/null +++ b/raft/src/main/java/org/apache/kafka/raft/RaftUtil.java @@ -0,0 +1,162 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.raft; + +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.Uuid; +import org.apache.kafka.common.message.BeginQuorumEpochRequestData; +import org.apache.kafka.common.message.BeginQuorumEpochResponseData; +import org.apache.kafka.common.message.DescribeQuorumRequestData; +import org.apache.kafka.common.message.EndQuorumEpochRequestData; +import org.apache.kafka.common.message.EndQuorumEpochResponseData; +import org.apache.kafka.common.message.FetchRequestData; +import org.apache.kafka.common.message.FetchResponseData; +import org.apache.kafka.common.message.VoteRequestData; +import org.apache.kafka.common.message.VoteResponseData; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ApiMessage; +import org.apache.kafka.common.protocol.Errors; + +import java.util.Collections; +import java.util.function.Consumer; + +import static java.util.Collections.singletonList; + +public class RaftUtil { + + public static ApiMessage errorResponse(ApiKeys apiKey, Errors error) { + switch (apiKey) { + case VOTE: + return new VoteResponseData().setErrorCode(error.code()); + case BEGIN_QUORUM_EPOCH: + return new BeginQuorumEpochResponseData().setErrorCode(error.code()); + case END_QUORUM_EPOCH: + return new EndQuorumEpochResponseData().setErrorCode(error.code()); + case FETCH: + return new FetchResponseData().setErrorCode(error.code()); + default: + throw new IllegalArgumentException("Received response for unexpected request type: " + apiKey); + } + } + + public static FetchRequestData singletonFetchRequest( + TopicPartition topicPartition, + Uuid topicId, + Consumer partitionConsumer + ) { + FetchRequestData.FetchPartition fetchPartition = + new FetchRequestData.FetchPartition() + .setPartition(topicPartition.partition()); + partitionConsumer.accept(fetchPartition); + + FetchRequestData.FetchTopic fetchTopic = + new FetchRequestData.FetchTopic() + .setTopic(topicPartition.topic()) + .setTopicId(topicId) + .setPartitions(singletonList(fetchPartition)); + + return new FetchRequestData() + .setTopics(singletonList(fetchTopic)); + } + + public static FetchResponseData singletonFetchResponse( + TopicPartition topicPartition, + Uuid topicId, + Errors topLevelError, + Consumer partitionConsumer + ) { + FetchResponseData.PartitionData fetchablePartition = + new FetchResponseData.PartitionData(); + + fetchablePartition.setPartitionIndex(topicPartition.partition()); + + partitionConsumer.accept(fetchablePartition); + + FetchResponseData.FetchableTopicResponse fetchableTopic = + new FetchResponseData.FetchableTopicResponse() + .setTopic(topicPartition.topic()) + .setTopicId(topicId) + .setPartitions(Collections.singletonList(fetchablePartition)); + + return new FetchResponseData() + .setErrorCode(topLevelError.code()) + .setResponses(Collections.singletonList(fetchableTopic)); + } + + static boolean hasValidTopicPartition(FetchRequestData data, TopicPartition topicPartition, Uuid topicId) { + return data.topics().size() == 1 && + data.topics().get(0).topicId().equals(topicId) && + data.topics().get(0).partitions().size() == 1 && + data.topics().get(0).partitions().get(0).partition() == topicPartition.partition(); + } + + static boolean hasValidTopicPartition(FetchResponseData data, TopicPartition topicPartition, Uuid topicId) { + return data.responses().size() == 1 && + data.responses().get(0).topicId().equals(topicId) && + data.responses().get(0).partitions().size() == 1 && + 
data.responses().get(0).partitions().get(0).partitionIndex() == topicPartition.partition(); + } + + static boolean hasValidTopicPartition(VoteResponseData data, TopicPartition topicPartition) { + return data.topics().size() == 1 && + data.topics().get(0).topicName().equals(topicPartition.topic()) && + data.topics().get(0).partitions().size() == 1 && + data.topics().get(0).partitions().get(0).partitionIndex() == topicPartition.partition(); + } + + static boolean hasValidTopicPartition(VoteRequestData data, TopicPartition topicPartition) { + return data.topics().size() == 1 && + data.topics().get(0).topicName().equals(topicPartition.topic()) && + data.topics().get(0).partitions().size() == 1 && + data.topics().get(0).partitions().get(0).partitionIndex() == topicPartition.partition(); + } + + static boolean hasValidTopicPartition(BeginQuorumEpochRequestData data, TopicPartition topicPartition) { + return data.topics().size() == 1 && + data.topics().get(0).topicName().equals(topicPartition.topic()) && + data.topics().get(0).partitions().size() == 1 && + data.topics().get(0).partitions().get(0).partitionIndex() == topicPartition.partition(); + } + + static boolean hasValidTopicPartition(BeginQuorumEpochResponseData data, TopicPartition topicPartition) { + return data.topics().size() == 1 && + data.topics().get(0).topicName().equals(topicPartition.topic()) && + data.topics().get(0).partitions().size() == 1 && + data.topics().get(0).partitions().get(0).partitionIndex() == topicPartition.partition(); + } + + static boolean hasValidTopicPartition(EndQuorumEpochRequestData data, TopicPartition topicPartition) { + return data.topics().size() == 1 && + data.topics().get(0).topicName().equals(topicPartition.topic()) && + data.topics().get(0).partitions().size() == 1 && + data.topics().get(0).partitions().get(0).partitionIndex() == topicPartition.partition(); + } + + static boolean hasValidTopicPartition(EndQuorumEpochResponseData data, TopicPartition topicPartition) { + return data.topics().size() == 1 && + data.topics().get(0).topicName().equals(topicPartition.topic()) && + data.topics().get(0).partitions().size() == 1 && + data.topics().get(0).partitions().get(0).partitionIndex() == topicPartition.partition(); + } + + static boolean hasValidTopicPartition(DescribeQuorumRequestData data, TopicPartition topicPartition) { + return data.topics().size() == 1 && + data.topics().get(0).topicName().equals(topicPartition.topic()) && + data.topics().get(0).partitions().size() == 1 && + data.topics().get(0).partitions().get(0).partitionIndex() == topicPartition.partition(); + } +} diff --git a/raft/src/main/java/org/apache/kafka/raft/ReplicatedCounter.java b/raft/src/main/java/org/apache/kafka/raft/ReplicatedCounter.java new file mode 100644 index 0000000..66303c6 --- /dev/null +++ b/raft/src/main/java/org/apache/kafka/raft/ReplicatedCounter.java @@ -0,0 +1,187 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.raft; + +import org.apache.kafka.common.KafkaException; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.raft.errors.NotLeaderException; +import org.apache.kafka.snapshot.SnapshotReader; +import org.apache.kafka.snapshot.SnapshotWriter; +import org.slf4j.Logger; + +import java.util.Optional; +import java.util.OptionalInt; + +import static java.util.Collections.singletonList; + +public class ReplicatedCounter implements RaftClient.Listener { + private final int nodeId; + private final Logger log; + private final RaftClient client; + private final int snapshotDelayInRecords = 10; + + private int committed = 0; + private int uncommitted = 0; + private OptionalInt claimedEpoch = OptionalInt.empty(); + private long lastOffsetSnapshotted = -1; + + private int handleSnapshotCalls = 0; + + public ReplicatedCounter( + int nodeId, + RaftClient client, + LogContext logContext + ) { + this.nodeId = nodeId; + this.client = client; + log = logContext.logger(ReplicatedCounter.class); + } + + public synchronized boolean isWritable() { + return claimedEpoch.isPresent(); + } + + public synchronized void increment() { + if (!claimedEpoch.isPresent()) { + throw new KafkaException("Counter is not currently writable"); + } + + int epoch = claimedEpoch.getAsInt(); + uncommitted += 1; + try { + long offset = client.scheduleAppend(epoch, singletonList(uncommitted)); + log.debug("Scheduled append of record {} with epoch {} at offset {}", + uncommitted, epoch, offset); + } catch (NotLeaderException e) { + log.info("Appending failed, transition to resigned", e); + client.resign(epoch); + } + } + + @Override + public synchronized void handleCommit(BatchReader reader) { + try { + int initialCommitted = committed; + long lastCommittedOffset = -1; + int lastCommittedEpoch = 0; + long lastCommittedTimestamp = -1; + + while (reader.hasNext()) { + Batch batch = reader.next(); + log.debug( + "Handle commit of batch with records {} at base offset {}", + batch.records(), + batch.baseOffset() + ); + for (Integer nextCommitted: batch.records()) { + if (nextCommitted != committed + 1) { + throw new AssertionError( + String.format( + "Expected next committed value to be %s, but instead found %s on node %s", + committed + 1, + nextCommitted, + nodeId + ) + ); + } + committed = nextCommitted; + } + + lastCommittedOffset = batch.lastOffset(); + lastCommittedEpoch = batch.epoch(); + lastCommittedTimestamp = batch.appendTimestamp(); + } + log.debug("Counter incremented from {} to {}", initialCommitted, committed); + + if (lastOffsetSnapshotted + snapshotDelayInRecords < lastCommittedOffset) { + log.debug( + "Generating new snapshot with committed offset {} and epoch {} since the previoud snapshot includes {}", + lastCommittedOffset, + lastCommittedEpoch, + lastOffsetSnapshotted + ); + Optional> snapshot = client.createSnapshot( + lastCommittedOffset, + lastCommittedEpoch, + lastCommittedTimestamp); + if (snapshot.isPresent()) { + try { + snapshot.get().append(singletonList(committed)); + snapshot.get().freeze(); + lastOffsetSnapshotted = lastCommittedOffset; + } 
finally { + snapshot.get().close(); + } + } else { + lastOffsetSnapshotted = lastCommittedOffset; + } + } + } finally { + reader.close(); + } + } + + @Override + public synchronized void handleSnapshot(SnapshotReader reader) { + try { + log.debug("Loading snapshot {}", reader.snapshotId()); + while (reader.hasNext()) { + Batch batch = reader.next(); + if (batch.records().size() != 1) { + throw new AssertionError( + String.format( + "Expected the snapshot at %s to only contain one record %s", + reader.snapshotId(), + batch.records() + ) + ); + } + + for (Integer value : batch) { + log.debug("Setting value: {}", value); + committed = value; + uncommitted = value; + } + } + lastOffsetSnapshotted = reader.lastContainedLogOffset(); + handleSnapshotCalls += 1; + log.debug("Finished loading snapshot. Set value: {}", committed); + } finally { + reader.close(); + } + } + + @Override + public synchronized void handleLeaderChange(LeaderAndEpoch newLeader) { + if (newLeader.isLeader(nodeId)) { + log.debug("Counter uncommitted value initialized to {} after claiming leadership in epoch {}", + committed, newLeader); + uncommitted = committed; + claimedEpoch = OptionalInt.of(newLeader.epoch()); + } else { + log.debug("Counter uncommitted value reset after resigning leadership"); + uncommitted = -1; + claimedEpoch = OptionalInt.empty(); + } + handleSnapshotCalls = 0; + } + + /** Use handleSnapshotCalls to verify leader is never asked to load snapshot */ + public int handleSnapshotCalls() { + return handleSnapshotCalls; + } +} diff --git a/raft/src/main/java/org/apache/kafka/raft/ReplicatedLog.java b/raft/src/main/java/org/apache/kafka/raft/ReplicatedLog.java new file mode 100644 index 0000000..f60a796 --- /dev/null +++ b/raft/src/main/java/org/apache/kafka/raft/ReplicatedLog.java @@ -0,0 +1,315 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.raft; + +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.Uuid; +import org.apache.kafka.common.record.Records; +import org.apache.kafka.snapshot.RawSnapshotReader; +import org.apache.kafka.snapshot.RawSnapshotWriter; + +import java.util.Optional; + +public interface ReplicatedLog extends AutoCloseable { + + /** + * Write a set of records to the local leader log. These messages will either + * be written atomically in a single batch or the call will fail and raise an + * exception. + * + * @return the metadata information of the appended batch + * @throws IllegalArgumentException if the record set is empty + * @throws RuntimeException if the batch base offset doesn't match the log end offset + */ + LogAppendInfo appendAsLeader(Records records, int epoch); + + /** + * Append a set of records that were replicated from the leader. 
The main + * difference from appendAsLeader is that we do not need to assign the epoch + * or do additional validation. + * + * @return the metadata information of the appended batch + * @throws IllegalArgumentException if the record set is empty + * @throws RuntimeException if the batch base offset doesn't match the log end offset + */ + LogAppendInfo appendAsFollower(Records records); + + /** + * Read a set of records within a range of offsets. + */ + LogFetchInfo read(long startOffsetInclusive, Isolation isolation); + + /** + * Return the latest epoch. For an empty log, the latest epoch is defined + * as 0. We refer to this as the "primordial epoch" and it is never allowed + * to have a leader or any records associated with it (leader epochs always start + * from 1). Basically this just saves us the trouble of having to use `Option` + * all over the place. + */ + int lastFetchedEpoch(); + + /** + * Validate the given offset and epoch against the log and oldest snapshot. + * + * Returns the largest valid offset and epoch given `offset` and `epoch` as the upper bound. + * This can result in three possible values returned: + * + * 1. ValidOffsetAndEpoch.valid if the given offset and epoch is valid in the log. + * + * 2. ValidOffsetAndEpoch.diverging if the given offset and epoch is not valid; and the + * largest valid offset and epoch is in the log. + * + * 3. ValidOffsetAndEpoch.snapshot if the given offset and epoch is not valid; and the largest + * valid offset and epoch is less than the oldest snapshot. + * + * @param offset the offset to validate + * @param epoch the epoch of the record at offset - 1 + * @return the largest valid offset and epoch + */ + default ValidOffsetAndEpoch validateOffsetAndEpoch(long offset, int epoch) { + if (startOffset() == 0 && offset == 0) { + return ValidOffsetAndEpoch.valid(new OffsetAndEpoch(0, 0)); + } + + Optional earliestSnapshotId = earliestSnapshotId(); + if (earliestSnapshotId.isPresent() && + ((offset < startOffset()) || + (offset == startOffset() && epoch != earliestSnapshotId.get().epoch) || + (epoch < earliestSnapshotId.get().epoch)) + ) { + /* Send a snapshot if the leader has a snapshot at the log start offset and + * 1. the fetch offset is less than the log start offset or + * 2. the fetch offset is equal to the log start offset and last fetch epoch doesn't match + * the oldest snapshot or + * 3. last fetch epoch is less than the oldest snapshot's epoch + */ + OffsetAndEpoch latestSnapshotId = latestSnapshotId().orElseThrow(() -> new IllegalStateException( + String.format( + "Log start offset (%s) is greater than zero but latest snapshot was not found", + startOffset() + ) + )); + + return ValidOffsetAndEpoch.snapshot(latestSnapshotId); + } else { + OffsetAndEpoch endOffsetAndEpoch = endOffsetForEpoch(epoch); + + if (endOffsetAndEpoch.epoch != epoch || endOffsetAndEpoch.offset < offset) { + return ValidOffsetAndEpoch.diverging(endOffsetAndEpoch); + } else { + return ValidOffsetAndEpoch.valid(new OffsetAndEpoch(offset, epoch)); + } + } + } + + /** + * Find the first epoch less than or equal to the given epoch and its end offset. + */ + OffsetAndEpoch endOffsetForEpoch(int epoch); + + /** + * Get the current log end offset metadata. This is always one plus the offset of the last + * written record. When the log is empty, the end offset is equal to the start offset. + */ + LogOffsetMetadata endOffset(); + + /** + * Get the high watermark. + */ + LogOffsetMetadata highWatermark(); + + /** + * Get the current log start offset. 
This is the offset of the first written + * entry, if one exists, or the end offset otherwise. + */ + long startOffset(); + + /** + * Initialize a new leader epoch beginning at the current log end offset. This API is invoked + * after becoming a leader and ensures that we can always determine the end offset and epoch + * with {@link #endOffsetForEpoch(int)} for any previous epoch. + * + * @param epoch Epoch of the newly elected leader + */ + void initializeLeaderEpoch(int epoch); + + /** + * Truncate the log to the given offset. All records with offsets greater than or equal to + * the given offset will be removed. + * + * @param offset The offset to truncate to + */ + void truncateTo(long offset); + + /** + * Fully truncate the log if the latest snapshot is later than the log end offset. + * + * In general this operation empties the log and sets the log start offset, high watermark and + * log end offset to the latest snapshot's end offset. + * + * @return true when the log is fully truncated, otherwise returns false + */ + boolean truncateToLatestSnapshot(); + + /** + * Update the high watermark and associated metadata (which is used to avoid + * index lookups when handling reads with {@link #read(long, Isolation)} with + * the {@link Isolation#COMMITTED} isolation level. + * + * @param offsetMetadata The offset and optional metadata + */ + void updateHighWatermark(LogOffsetMetadata offsetMetadata); + + /** + * Delete all snapshots prior to the given snapshot + * + * The replicated log's start offset can be increased and older segments can be deleted when + * there is a snapshot greater than the current log start offset. + */ + boolean deleteBeforeSnapshot(OffsetAndEpoch snapshotId); + + /** + * Flush the current log to disk. + */ + void flush(); + + /** + * Possibly perform cleaning of snapshots and logs + */ + boolean maybeClean(); + + /** + * Get the last offset which has been flushed to disk. + */ + long lastFlushedOffset(); + + /** + * Return the topic partition associated with the log. + */ + TopicPartition topicPartition(); + + /** + * Return the topic ID associated with the log. + */ + Uuid topicId(); + + /** + * Truncate to an offset and epoch. + * + * @param endOffset offset and epoch to truncate to + * @return the truncation offset + */ + default long truncateToEndOffset(OffsetAndEpoch endOffset) { + final long truncationOffset; + int leaderEpoch = endOffset.epoch; + if (leaderEpoch == 0) { + truncationOffset = Math.min(endOffset.offset, endOffset().offset); + } else { + OffsetAndEpoch localEndOffset = endOffsetForEpoch(leaderEpoch); + if (localEndOffset.epoch == leaderEpoch) { + truncationOffset = Math.min(localEndOffset.offset, endOffset.offset); + } else { + truncationOffset = localEndOffset.offset; + } + } + + truncateTo(truncationOffset); + return truncationOffset; + } + + /** + * Create a writable snapshot for the given snapshot id. + * + * See {@link RawSnapshotWriter} for details on how to use this object. The caller of + * this method is responsible for invoking {@link RawSnapshotWriter#close()}. If a + * snapshot already exists or it is less than log start offset then return an + * {@link Optional#empty()}. + * + * Snapshots created using this method will be validated against the existing snapshots + * and the replicated log. 
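+ *
+ * A rough usage sketch (illustrative only; the {@code log} variable is assumed):
+ * <pre>{@code
+ * Optional<RawSnapshotWriter> writer = log.createNewSnapshot(new OffsetAndEpoch(offset, epoch));
+ * // The caller owns the returned writer and must invoke close() on it when done; an empty
+ * // Optional means a snapshot already exists or the id is older than the log start offset.
+ * }</pre>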
+     *
+     * @param snapshotId the end offset and epoch that identifies the snapshot
+     * @return a writable snapshot if one does not already exist and the snapshot id is
+     *         greater than the log start offset
+     * @throws IllegalArgumentException if the snapshot end offset is greater than the high-watermark
+     */
+    Optional<RawSnapshotWriter> createNewSnapshot(OffsetAndEpoch snapshotId);
+
+    /**
+     * Create a writable snapshot for the given snapshot id.
+     *
+     * See {@link RawSnapshotWriter} for details on how to use this object. The caller of
+     * this method is responsible for invoking {@link RawSnapshotWriter#close()}. If a
+     * snapshot already exists then return an {@link Optional#empty()}.
+     *
+     * Snapshots created using this method will not be validated against the existing snapshots
+     * and the replicated log. This is useful when creating a snapshot from a trusted source like
+     * the quorum leader.
+     *
+     * @param snapshotId the end offset and epoch that identifies the snapshot
+     * @return a writable snapshot if one does not already exist
+     */
+    Optional<RawSnapshotWriter> storeSnapshot(OffsetAndEpoch snapshotId);
+
+    /**
+     * Opens a readable snapshot for the given snapshot id.
+     *
+     * Returns an Optional with a readable snapshot, if the snapshot exists, otherwise
+     * returns an empty Optional. See {@link RawSnapshotReader} for details on how to
+     * use this object.
+     *
+     * @param snapshotId the end offset and epoch that identifies the snapshot
+     * @return an Optional with a readable snapshot, if the snapshot exists, otherwise
+     *         returns an empty Optional
+     */
+    Optional<RawSnapshotReader> readSnapshot(OffsetAndEpoch snapshotId);
+
+    /**
+     * Returns the latest readable snapshot if one exists.
+     *
+     * @return an Optional with the latest readable snapshot, if one exists, otherwise
+     *         returns an empty Optional
+     */
+    Optional<RawSnapshotReader> latestSnapshot();
+
+    /**
+     * Returns the latest snapshot id if one exists.
+     *
+     * @return an Optional snapshot id of the latest snapshot if one exists, otherwise returns an
+     *         empty Optional
+     */
+    Optional<OffsetAndEpoch> latestSnapshotId();
+
+    /**
+     * Returns the snapshot id at the log start offset.
+     *
+     * If the log start offset is nonzero then it is expected that there is a snapshot with an end
+     * offset equal to the start offset.
+     *
+     * @return an Optional snapshot id at the log start offset if nonzero, otherwise returns an
+     *         empty Optional
+     */
+    Optional<OffsetAndEpoch> earliestSnapshotId();
+
+    /**
+     * Notifies the replicated log when a new snapshot is available.
+     */
+    void onSnapshotFrozen(OffsetAndEpoch snapshotId);
+
+    default void close() {}
+}
diff --git a/raft/src/main/java/org/apache/kafka/raft/RequestManager.java b/raft/src/main/java/org/apache/kafka/raft/RequestManager.java
new file mode 100644
index 0000000..5a5cb00
--- /dev/null
+++ b/raft/src/main/java/org/apache/kafka/raft/RequestManager.java
@@ -0,0 +1,214 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.raft; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.OptionalInt; +import java.util.OptionalLong; +import java.util.Random; +import java.util.Set; + +public class RequestManager { + private final Map connections = new HashMap<>(); + private final List voters = new ArrayList<>(); + + private final int retryBackoffMs; + private final int requestTimeoutMs; + private final Random random; + + public RequestManager(Set voterIds, + int retryBackoffMs, + int requestTimeoutMs, + Random random) { + + this.retryBackoffMs = retryBackoffMs; + this.requestTimeoutMs = requestTimeoutMs; + this.voters.addAll(voterIds); + this.random = random; + + for (Integer voterId: voterIds) { + ConnectionState connection = new ConnectionState(voterId); + connections.put(voterId, connection); + } + } + + public ConnectionState getOrCreate(int id) { + return connections.computeIfAbsent(id, key -> new ConnectionState(id)); + } + + public OptionalInt findReadyVoter(long currentTimeMs) { + int startIndex = random.nextInt(voters.size()); + OptionalInt res = OptionalInt.empty(); + for (int i = 0; i < voters.size(); i++) { + int index = (startIndex + i) % voters.size(); + Integer voterId = voters.get(index); + ConnectionState connection = connections.get(voterId); + boolean isReady = connection.isReady(currentTimeMs); + + if (isReady) { + res = OptionalInt.of(voterId); + } else if (connection.inFlightCorrelationId.isPresent()) { + res = OptionalInt.empty(); + break; + } + } + return res; + } + + public long backoffBeforeAvailableVoter(long currentTimeMs) { + long minBackoffMs = Long.MAX_VALUE; + for (Integer voterId : voters) { + ConnectionState connection = connections.get(voterId); + if (connection.isReady(currentTimeMs)) { + return 0L; + } else if (connection.isBackingOff(currentTimeMs)) { + minBackoffMs = Math.min(minBackoffMs, connection.remainingBackoffMs(currentTimeMs)); + } else { + minBackoffMs = Math.min(minBackoffMs, connection.remainingRequestTimeMs(currentTimeMs)); + } + } + return minBackoffMs; + } + + public void resetAll() { + for (ConnectionState connectionState : connections.values()) + connectionState.reset(); + } + + private enum State { + AWAITING_REQUEST, + BACKING_OFF, + READY + } + + public class ConnectionState { + private final long id; + private State state = State.READY; + private long lastSendTimeMs = 0L; + private long lastFailTimeMs = 0L; + private OptionalLong inFlightCorrelationId = OptionalLong.empty(); + + public ConnectionState(long id) { + this.id = id; + } + + private boolean isBackoffComplete(long timeMs) { + return state == State.BACKING_OFF && timeMs >= lastFailTimeMs + retryBackoffMs; + } + + boolean hasRequestTimedOut(long timeMs) { + return state == State.AWAITING_REQUEST && timeMs >= lastSendTimeMs + requestTimeoutMs; + } + + public long id() { + return id; + } + + boolean isReady(long timeMs) { + if (isBackoffComplete(timeMs) || hasRequestTimedOut(timeMs)) { + state = State.READY; + } + return state == State.READY; + } + + boolean isBackingOff(long timeMs) { + if (state != State.BACKING_OFF) { + return false; + } else { + return !isBackoffComplete(timeMs); + } + } + + boolean hasInflightRequest(long timeMs) { + if (state != State.AWAITING_REQUEST) { + return false; + } else { + return !hasRequestTimedOut(timeMs); + } + } + + long remainingRequestTimeMs(long timeMs) { + if 
(hasInflightRequest(timeMs)) { + return lastSendTimeMs + requestTimeoutMs - timeMs; + } else { + return 0; + } + } + + long remainingBackoffMs(long timeMs) { + if (isBackingOff(timeMs)) { + return lastFailTimeMs + retryBackoffMs - timeMs; + } else { + return 0; + } + } + + boolean isResponseExpected(long correlationId) { + return inFlightCorrelationId.isPresent() && inFlightCorrelationId.getAsLong() == correlationId; + } + + void onResponseError(long correlationId, long timeMs) { + inFlightCorrelationId.ifPresent(inflightRequestId -> { + if (inflightRequestId == correlationId) { + lastFailTimeMs = timeMs; + state = State.BACKING_OFF; + inFlightCorrelationId = OptionalLong.empty(); + } + }); + } + + void onResponseReceived(long correlationId) { + inFlightCorrelationId.ifPresent(inflightRequestId -> { + if (inflightRequestId == correlationId) { + state = State.READY; + inFlightCorrelationId = OptionalLong.empty(); + } + }); + } + + void onRequestSent(long correlationId, long timeMs) { + lastSendTimeMs = timeMs; + inFlightCorrelationId = OptionalLong.of(correlationId); + state = State.AWAITING_REQUEST; + } + + /** + * Ignore in-flight requests or backoff and become available immediately. This is used + * when there is a state change which usually means in-flight requests are obsolete + * and we need to send new requests. + */ + void reset() { + state = State.READY; + inFlightCorrelationId = OptionalLong.empty(); + } + + @Override + public String toString() { + return "ConnectionState(" + + "id=" + id + + ", state=" + state + + ", lastSendTimeMs=" + lastSendTimeMs + + ", lastFailTimeMs=" + lastFailTimeMs + + ", inFlightCorrelationId=" + inFlightCorrelationId + + ')'; + } + } + +} diff --git a/raft/src/main/java/org/apache/kafka/raft/ResignedState.java b/raft/src/main/java/org/apache/kafka/raft/ResignedState.java new file mode 100644 index 0000000..899823a --- /dev/null +++ b/raft/src/main/java/org/apache/kafka/raft/ResignedState.java @@ -0,0 +1,159 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.raft; + +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.common.utils.Timer; +import org.slf4j.Logger; + +import java.util.HashSet; +import java.util.List; +import java.util.Set; + +/** + * This state represents a leader which has fenced itself either because it + * is shutting down or because it has encountered a soft failure of some sort. + * No writes are accepted in this state and we are not permitted to vote for + * any other candidate in this epoch. + * + * A resigned leader may initiate a new election by sending `EndQuorumEpoch` + * requests to all of the voters. This state tracks delivery of this request + * in order to prevent unnecessary retries. 
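A hedged sketch of the `ConnectionState` lifecycle managed by `RequestManager` above. The transition methods are package-private, so this is written as if it ran from a test in the `org.apache.kafka.raft` package; the voter ids, correlation id, and timeout values are arbitrary.

```java
// Sketch only: READY -> AWAITING_REQUEST -> BACKING_OFF -> READY.
// Assumes java.util imports and same-package access to the package-private methods.
Set<Integer> voters = new HashSet<>(Arrays.asList(1, 2, 3));
RequestManager manager = new RequestManager(voters, 100 /* retryBackoffMs */,
    2000 /* requestTimeoutMs */, new Random());

RequestManager.ConnectionState connection = manager.getOrCreate(1);
long now = 0L;
assert connection.isReady(now);            // READY until a request is sent

connection.onRequestSent(42L, now);        // AWAITING_REQUEST: no new requests for this voter
assert !connection.isReady(now);

connection.onResponseError(42L, now + 10); // BACKING_OFF for retryBackoffMs
assert connection.isBackingOff(now + 10);
assert connection.remainingBackoffMs(now + 10) == 100;

assert connection.isReady(now + 200);      // backoff elapsed: READY again
```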
+ * + * A voter will remain in the `Resigned` state until we either learn about + * another election, or our own election timeout expires and we become a + * Candidate. + */ +public class ResignedState implements EpochState { + private final int localId; + private final int epoch; + private final Set voters; + private final long electionTimeoutMs; + private final Set unackedVoters; + private final Timer electionTimer; + private final List preferredSuccessors; + private final Logger log; + + public ResignedState( + Time time, + int localId, + int epoch, + Set voters, + long electionTimeoutMs, + List preferredSuccessors, + LogContext logContext + ) { + this.localId = localId; + this.epoch = epoch; + this.voters = voters; + this.unackedVoters = new HashSet<>(voters); + this.unackedVoters.remove(localId); + this.electionTimeoutMs = electionTimeoutMs; + this.electionTimer = time.timer(electionTimeoutMs); + this.preferredSuccessors = preferredSuccessors; + this.log = logContext.logger(ResignedState.class); + } + + @Override + public ElectionState election() { + return ElectionState.withElectedLeader(epoch, localId, voters); + } + + @Override + public int epoch() { + return epoch; + } + + /** + * Get the set of voters which have yet to acknowledge the resignation. + * This node will send `EndQuorumEpoch` requests to this set until these + * voters acknowledge the request or we transition to another state. + * + * @return the set of unacknowledged voters + */ + public Set unackedVoters() { + return unackedVoters; + } + + /** + * Invoked after receiving a successful `EndQuorumEpoch` response. This + * is in order to prevent unnecessary retries. + * + * @param voterId the ID of the voter that send the successful response + */ + public void acknowledgeResignation(int voterId) { + if (!voters.contains(voterId)) { + throw new IllegalArgumentException("Attempt to acknowledge delivery of `EndQuorumEpoch` " + + "by a non-voter " + voterId); + } + unackedVoters.remove(voterId); + } + + /** + * Check whether the timeout has expired. + * + * @param currentTimeMs current time in milliseconds + * @return true if the timeout has expired, false otherwise + */ + public boolean hasElectionTimeoutExpired(long currentTimeMs) { + electionTimer.update(currentTimeMs); + return electionTimer.isExpired(); + } + + /** + * Check the time remaining until the timeout expires. 
+ * + * @param currentTimeMs current time in milliseconds + * @return the duration in milliseconds from the current time before the timeout expires + */ + public long remainingElectionTimeMs(long currentTimeMs) { + electionTimer.update(currentTimeMs); + return electionTimer.remainingMs(); + } + + public List preferredSuccessors() { + return preferredSuccessors; + } + + @Override + public boolean canGrantVote(int candidateId, boolean isLogUpToDate) { + log.debug("Rejecting vote request from candidate {} since we have resigned as candidate/leader in epoch {}", + candidateId, epoch); + return false; + } + + @Override + public String name() { + return "Resigned"; + } + + @Override + public String toString() { + return "ResignedState(" + + "localId=" + localId + + ", epoch=" + epoch + + ", voters=" + voters + + ", electionTimeoutMs=" + electionTimeoutMs + + ", unackedVoters=" + unackedVoters + + ", preferredSuccessors=" + preferredSuccessors + + ')'; + } + + @Override + public void close() {} +} diff --git a/raft/src/main/java/org/apache/kafka/raft/UnattachedState.java b/raft/src/main/java/org/apache/kafka/raft/UnattachedState.java new file mode 100644 index 0000000..4dc5fc7 --- /dev/null +++ b/raft/src/main/java/org/apache/kafka/raft/UnattachedState.java @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.raft; + +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.common.utils.Timer; +import org.slf4j.Logger; + +import java.util.Optional; +import java.util.OptionalInt; +import java.util.Set; + +/** + * A voter is "unattached" when it learns of an ongoing election (typically + * by observing a bumped epoch), but has yet to cast its vote or become a + * candidate itself. 
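A hedged sketch of how the resignation bookkeeping above is used: the local id is excluded from the unacknowledged set, and each successful `EndQuorumEpoch` response removes one voter. Ids, epoch, and timeout values are arbitrary stand-ins.

```java
// Sketch only: tracking EndQuorumEpoch acknowledgements with ResignedState.
// Assumes java.util imports plus Time and LogContext from org.apache.kafka.common.utils.
Set<Integer> voters = new HashSet<>(Arrays.asList(1, 2, 3));
ResignedState resigned = new ResignedState(
    Time.SYSTEM,
    1,                              // localId
    5,                              // epoch being resigned
    voters,
    5000L,                          // electionTimeoutMs
    Collections.singletonList(2),   // preferredSuccessors
    new LogContext("[resigned] ")
);

// The local id is excluded, so only voters 2 and 3 still need to acknowledge.
assert resigned.unackedVoters().equals(new HashSet<>(Arrays.asList(2, 3)));

resigned.acknowledgeResignation(2);  // successful EndQuorumEpoch response from voter 2
assert resigned.unackedVoters().equals(Collections.singleton(3));
```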
+ */ +public class UnattachedState implements EpochState { + private final int epoch; + private final Set voters; + private final long electionTimeoutMs; + private final Timer electionTimer; + private final Optional highWatermark; + private final Logger log; + + public UnattachedState( + Time time, + int epoch, + Set voters, + Optional highWatermark, + long electionTimeoutMs, + LogContext logContext + ) { + this.epoch = epoch; + this.voters = voters; + this.highWatermark = highWatermark; + this.electionTimeoutMs = electionTimeoutMs; + this.electionTimer = time.timer(electionTimeoutMs); + this.log = logContext.logger(UnattachedState.class); + } + + @Override + public ElectionState election() { + return new ElectionState( + epoch, + OptionalInt.empty(), + OptionalInt.empty(), + voters + ); + } + + @Override + public int epoch() { + return epoch; + } + + @Override + public String name() { + return "Unattached"; + } + + public long electionTimeoutMs() { + return electionTimeoutMs; + } + + public long remainingElectionTimeMs(long currentTimeMs) { + electionTimer.update(currentTimeMs); + return electionTimer.remainingMs(); + } + + public boolean hasElectionTimeoutExpired(long currentTimeMs) { + electionTimer.update(currentTimeMs); + return electionTimer.isExpired(); + } + + @Override + public Optional highWatermark() { + return highWatermark; + } + + @Override + public boolean canGrantVote(int candidateId, boolean isLogUpToDate) { + if (!isLogUpToDate) { + log.debug("Rejecting vote request from candidate {} since candidate epoch/offset is not up to date with us", + candidateId); + } + return isLogUpToDate; + } + + @Override + public String toString() { + return "Unattached(" + + "epoch=" + epoch + + ", voters=" + voters + + ", electionTimeoutMs=" + electionTimeoutMs + + ')'; + } + + @Override + public void close() {} +} diff --git a/raft/src/main/java/org/apache/kafka/raft/ValidOffsetAndEpoch.java b/raft/src/main/java/org/apache/kafka/raft/ValidOffsetAndEpoch.java new file mode 100644 index 0000000..320e3d9 --- /dev/null +++ b/raft/src/main/java/org/apache/kafka/raft/ValidOffsetAndEpoch.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
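A hedged sketch of the election-timeout pattern shared by the unattached, voted, and resigned states, shown here with `UnattachedState`: the poll loop either acts on an expired timeout or waits at most the remaining time. Epoch, voter ids, and timeout are arbitrary.

```java
// Sketch only: driving state transitions off the election timer.
// Assumes java.util imports plus Time and LogContext from org.apache.kafka.common.utils.
UnattachedState unattached = new UnattachedState(
    Time.SYSTEM,
    6,                                        // epoch observed from another voter
    new HashSet<>(Arrays.asList(1, 2, 3)),    // voters
    Optional.empty(),                         // highWatermark (unknown here)
    5000L,                                    // electionTimeoutMs
    new LogContext()
);

long now = Time.SYSTEM.milliseconds();
if (unattached.hasElectionTimeoutExpired(now)) {
    // No leader emerged in time: become a candidate for the next epoch.
} else {
    // Otherwise poll the message queue for at most the remaining election time.
    long pollTimeoutMs = unattached.remainingElectionTimeMs(now);
}
```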
+ */ +package org.apache.kafka.raft; + +import java.util.Objects; + +public final class ValidOffsetAndEpoch { + final private Kind kind; + final private OffsetAndEpoch offsetAndEpoch; + + private ValidOffsetAndEpoch(Kind kind, OffsetAndEpoch offsetAndEpoch) { + this.kind = kind; + this.offsetAndEpoch = offsetAndEpoch; + } + + public Kind kind() { + return kind; + } + + public OffsetAndEpoch offsetAndEpoch() { + return offsetAndEpoch; + } + + public enum Kind { + DIVERGING, SNAPSHOT, VALID + } + + public static ValidOffsetAndEpoch diverging(OffsetAndEpoch offsetAndEpoch) { + return new ValidOffsetAndEpoch(Kind.DIVERGING, offsetAndEpoch); + } + + public static ValidOffsetAndEpoch snapshot(OffsetAndEpoch offsetAndEpoch) { + return new ValidOffsetAndEpoch(Kind.SNAPSHOT, offsetAndEpoch); + } + + public static ValidOffsetAndEpoch valid(OffsetAndEpoch offsetAndEpoch) { + return new ValidOffsetAndEpoch(Kind.VALID, offsetAndEpoch); + } + + public static ValidOffsetAndEpoch valid() { + return valid(new OffsetAndEpoch(-1, -1)); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) return true; + if (obj == null || getClass() != obj.getClass()) return false; + ValidOffsetAndEpoch that = (ValidOffsetAndEpoch) obj; + return kind == that.kind && + offsetAndEpoch.equals(that.offsetAndEpoch); + } + + @Override + public int hashCode() { + return Objects.hash(kind, offsetAndEpoch); + } + + @Override + public String toString() { + return String.format( + "ValidOffsetAndEpoch(kind=%s, offsetAndEpoch=%s)", + kind, + offsetAndEpoch + ); + } +} diff --git a/raft/src/main/java/org/apache/kafka/raft/VotedState.java b/raft/src/main/java/org/apache/kafka/raft/VotedState.java new file mode 100644 index 0000000..2ae5026 --- /dev/null +++ b/raft/src/main/java/org/apache/kafka/raft/VotedState.java @@ -0,0 +1,128 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.raft; + +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.common.utils.Timer; +import org.slf4j.Logger; + +import java.util.Optional; +import java.util.OptionalInt; +import java.util.Set; + +/** + * The "voted" state is for voters who have cast their vote for a specific candidate. + * Once a vote has been cast, it is not possible for a voter to change its vote until a + * new election is started. If the election timeout expires before a new leader is elected, + * then the voter will become a candidate. 
+ */ +public class VotedState implements EpochState { + private final int epoch; + private final int votedId; + private final Set voters; + private final int electionTimeoutMs; + private final Timer electionTimer; + private final Optional highWatermark; + private final Logger log; + + public VotedState( + Time time, + int epoch, + int votedId, + Set voters, + Optional highWatermark, + int electionTimeoutMs, + LogContext logContext + ) { + this.epoch = epoch; + this.votedId = votedId; + this.voters = voters; + this.highWatermark = highWatermark; + this.electionTimeoutMs = electionTimeoutMs; + this.electionTimer = time.timer(electionTimeoutMs); + this.log = logContext.logger(VotedState.class); + } + + @Override + public ElectionState election() { + return new ElectionState( + epoch, + OptionalInt.empty(), + OptionalInt.of(votedId), + voters + ); + } + + public int votedId() { + return votedId; + } + + @Override + public int epoch() { + return epoch; + } + + @Override + public String name() { + return "Voted"; + } + + public long remainingElectionTimeMs(long currentTimeMs) { + electionTimer.update(currentTimeMs); + return electionTimer.remainingMs(); + } + + public boolean hasElectionTimeoutExpired(long currentTimeMs) { + electionTimer.update(currentTimeMs); + return electionTimer.isExpired(); + } + + public void overrideElectionTimeout(long currentTimeMs, long timeoutMs) { + electionTimer.update(currentTimeMs); + electionTimer.reset(timeoutMs); + } + + @Override + public boolean canGrantVote(int candidateId, boolean isLogUpToDate) { + if (votedId() == candidateId) { + return true; + } + + log.debug("Rejecting vote request from candidate {} since we already have voted for " + + "another candidate {} in epoch {}", candidateId, votedId(), epoch); + return false; + } + + @Override + public Optional highWatermark() { + return highWatermark; + } + + @Override + public String toString() { + return "Voted(" + + "epoch=" + epoch + + ", votedId=" + votedId + + ", voters=" + voters + + ", electionTimeoutMs=" + electionTimeoutMs + + ')'; + } + + @Override + public void close() {} +} diff --git a/raft/src/main/java/org/apache/kafka/raft/errors/BufferAllocationException.java b/raft/src/main/java/org/apache/kafka/raft/errors/BufferAllocationException.java new file mode 100644 index 0000000..ecd36b3 --- /dev/null +++ b/raft/src/main/java/org/apache/kafka/raft/errors/BufferAllocationException.java @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.raft.errors; + +/** + * Indicates that an operation is failed because we failed to allocate memory for it. 
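A hedged sketch of the vote-granting rule above: once a vote is cast, `VotedState` re-grants only to the same candidate and rejects everyone else for the epoch. The constructor arguments are arbitrary stand-ins.

```java
// Sketch only: VotedState grants a vote only to the candidate it already voted for.
// Assumes java.util imports plus Time and LogContext from org.apache.kafka.common.utils.
VotedState voted = new VotedState(
    Time.SYSTEM,
    7,                                        // epoch
    2,                                        // votedId: we already voted for candidate 2
    new HashSet<>(Arrays.asList(1, 2, 3)),    // voters
    Optional.empty(),                         // highWatermark
    5000,                                     // electionTimeoutMs
    new LogContext()
);

assert voted.canGrantVote(2, true);   // same candidate: grant (idempotent)
assert !voted.canGrantVote(3, true);  // different candidate: reject for this epoch
```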
+ */ +public class BufferAllocationException extends RaftException { + + private final static long serialVersionUID = 1L; + + public BufferAllocationException(String s) { + super(s); + } +} diff --git a/raft/src/main/java/org/apache/kafka/raft/errors/NotLeaderException.java b/raft/src/main/java/org/apache/kafka/raft/errors/NotLeaderException.java new file mode 100644 index 0000000..7f737fa --- /dev/null +++ b/raft/src/main/java/org/apache/kafka/raft/errors/NotLeaderException.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.raft.errors; + +/** + * Indicates that an operation is not allowed because this node is not the + * current leader or the epoch is not the same with the current leader epoch. + */ +public class NotLeaderException extends RaftException { + + private final static long serialVersionUID = 1L; + + public NotLeaderException(String s) { + super(s); + } + + public NotLeaderException(String s, Throwable throwable) { + super(s, throwable); + } + + public NotLeaderException(Throwable throwable) { + super(throwable); + } +} diff --git a/raft/src/main/java/org/apache/kafka/raft/errors/RaftException.java b/raft/src/main/java/org/apache/kafka/raft/errors/RaftException.java new file mode 100644 index 0000000..6df196b --- /dev/null +++ b/raft/src/main/java/org/apache/kafka/raft/errors/RaftException.java @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.raft.errors; + +import org.apache.kafka.common.KafkaException; + +/** + * RaftException is the top-level exception type generated by Kafka raft implementations. 
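Since these error types are unchecked (`RaftException` extends `KafkaException`), append paths typically branch on the subtype. A hedged sketch; `accumulator`, `epoch`, and `records` are hypothetical locals of a leader-side caller using the `BatchAccumulator` introduced later in this patch.

```java
// Sketch only: handling the raft error hierarchy around an append call.
try {
    accumulator.append(epoch, records);
} catch (NotLeaderException e) {
    // Leadership was lost or a stale epoch was used; surface a NOT_LEADER style error.
} catch (BufferAllocationException e) {
    // The memory pool is exhausted; back off and retry the append later.
} catch (RaftException e) {
    // Any other raft-level failure; subclasses must be caught before this.
}
```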
+ */ +public class RaftException extends KafkaException { + + private final static long serialVersionUID = 1L; + + public RaftException(String s) { + super(s); + } + + public RaftException(String s, Throwable throwable) { + super(s, throwable); + } + + public RaftException(Throwable throwable) { + super(throwable); + } +} diff --git a/raft/src/main/java/org/apache/kafka/raft/internals/BatchAccumulator.java b/raft/src/main/java/org/apache/kafka/raft/internals/BatchAccumulator.java new file mode 100644 index 0000000..697394d --- /dev/null +++ b/raft/src/main/java/org/apache/kafka/raft/internals/BatchAccumulator.java @@ -0,0 +1,546 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.raft.internals; + +import org.apache.kafka.common.errors.RecordBatchTooLargeException; +import org.apache.kafka.common.memory.MemoryPool; +import org.apache.kafka.common.protocol.ObjectSerializationCache; +import org.apache.kafka.common.record.CompressionType; +import org.apache.kafka.common.record.MemoryRecords; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.raft.errors.BufferAllocationException; +import org.apache.kafka.raft.errors.NotLeaderException; +import org.apache.kafka.server.common.serialization.RecordSerde; + +import org.apache.kafka.common.message.LeaderChangeMessage; +import org.apache.kafka.common.message.SnapshotHeaderRecord; +import org.apache.kafka.common.message.SnapshotFooterRecord; +import java.io.Closeable; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.List; +import java.util.Objects; +import java.util.Optional; +import java.util.OptionalInt; +import java.util.function.Function; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.locks.ReentrantLock; + +public class BatchAccumulator implements Closeable { + private final int epoch; + private final Time time; + private final SimpleTimer lingerTimer; + private final int lingerMs; + private final int maxBatchSize; + private final CompressionType compressionType; + private final MemoryPool memoryPool; + private final ReentrantLock appendLock; + private final RecordSerde serde; + + private final ConcurrentLinkedQueue> completed; + private volatile DrainStatus drainStatus; + + // These fields are protected by the append lock + private long nextOffset; + private BatchBuilder currentBatch; + + private enum DrainStatus { + STARTED, FINISHED, NONE + } + + public BatchAccumulator( + int epoch, + long baseOffset, + int lingerMs, + int maxBatchSize, + MemoryPool memoryPool, + Time time, + CompressionType compressionType, + RecordSerde serde + ) { + this.epoch = epoch; + this.lingerMs = lingerMs; 
+ this.maxBatchSize = maxBatchSize; + this.memoryPool = memoryPool; + this.time = time; + this.lingerTimer = new SimpleTimer(); + this.compressionType = compressionType; + this.serde = serde; + this.nextOffset = baseOffset; + this.drainStatus = DrainStatus.NONE; + this.completed = new ConcurrentLinkedQueue<>(); + this.appendLock = new ReentrantLock(); + } + + /** + * Append a list of records into as many batches as necessary. + * + * The order of the elements in the records argument will match the order in the batches. + * This method will use as many batches as necessary to serialize all of the records. Since + * this method can split the records into multiple batches it is possible that some of the + * records will get committed while other will not when the leader fails. + * + * @param epoch the expected leader epoch. If this does not match, then {@link NotLeaderException} + * will be thrown + * @param records the list of records to include in the batches + * @return the expected offset of the last record + * @throws RecordBatchTooLargeException if the size of one record T is greater than the maximum + * batch size; if this exception is throw some of the elements in records may have + * been committed + * @throws NotLeaderException if the epoch doesn't match the leader epoch + * @throws BufferAllocationException if we failed to allocate memory for the records + */ + public long append(int epoch, List records) { + return append(epoch, records, false); + } + + /** + * Append a list of records into an atomic batch. We guarantee all records are included in the + * same underlying record batch so that either all of the records become committed or none of + * them do. + * + * @param epoch the expected leader epoch. If this does not match, then {@link NotLeaderException} + * will be thrown + * @param records the list of records to include in a batch + * @return the expected offset of the last record + * @throws RecordBatchTooLargeException if the size of the records is greater than the maximum + * batch size; if this exception is throw none of the elements in records were + * committed + * @throws NotLeaderException if the epoch doesn't match the leader epoch + * @throws BufferAllocationException if we failed to allocate memory for the records + */ + public long appendAtomic(int epoch, List records) { + return append(epoch, records, true); + } + + private long append(int epoch, List records, boolean isAtomic) { + if (epoch < this.epoch) { + throw new NotLeaderException("Append failed because the epoch doesn't match"); + } else if (epoch > this.epoch) { + throw new IllegalArgumentException("Attempt to append from epoch " + epoch + + " which is larger than the current epoch " + this.epoch); + } + + ObjectSerializationCache serializationCache = new ObjectSerializationCache(); + + appendLock.lock(); + try { + maybeCompleteDrain(); + + BatchBuilder batch = null; + if (isAtomic) { + batch = maybeAllocateBatch(records, serializationCache); + } + + for (T record : records) { + if (!isAtomic) { + batch = maybeAllocateBatch(Collections.singleton(record), serializationCache); + } + + if (batch == null) { + throw new BufferAllocationException("Append failed because we failed to allocate memory to write the batch"); + } + + batch.appendRecord(record, serializationCache); + nextOffset += 1; + } + + maybeResetLinger(); + + return nextOffset - 1; + } finally { + appendLock.unlock(); + } + } + + private void maybeResetLinger() { + if (!lingerTimer.isRunning()) { + lingerTimer.reset(time.milliseconds() + 
lingerMs); + } + } + + private BatchBuilder maybeAllocateBatch( + Collection records, + ObjectSerializationCache serializationCache + ) { + if (currentBatch == null) { + startNewBatch(); + } + + if (currentBatch != null) { + OptionalInt bytesNeeded = currentBatch.bytesNeeded(records, serializationCache); + if (bytesNeeded.isPresent() && bytesNeeded.getAsInt() > maxBatchSize) { + throw new RecordBatchTooLargeException( + String.format( + "The total record(s) size of %s exceeds the maximum allowed batch size of %s", + bytesNeeded.getAsInt(), + maxBatchSize + ) + ); + } else if (bytesNeeded.isPresent()) { + completeCurrentBatch(); + startNewBatch(); + } + } + + return currentBatch; + } + + private void completeCurrentBatch() { + MemoryRecords data = currentBatch.build(); + completed.add(new CompletedBatch<>( + currentBatch.baseOffset(), + currentBatch.records(), + data, + memoryPool, + currentBatch.initialBuffer() + )); + currentBatch = null; + } + + /** + * Append a control batch from a supplied memory record. + * + * See the {@code valueCreator} parameter description for requirements on this function. + * + * @param valueCreator a function that uses the passed buffer to create the control + * batch that will be appended. The memory records returned must contain one + * control batch and that control batch have one record. + */ + private void appendControlMessage(Function valueCreator) { + appendLock.lock(); + try { + ByteBuffer buffer = memoryPool.tryAllocate(256); + if (buffer != null) { + try { + forceDrain(); + completed.add( + new CompletedBatch<>( + nextOffset, + 1, + valueCreator.apply(buffer), + memoryPool, + buffer + ) + ); + nextOffset += 1; + } catch (Exception e) { + // Release the buffer now since the buffer was not stored in completed for a delayed release + memoryPool.release(buffer); + throw e; + } + } else { + throw new IllegalStateException("Could not allocate buffer for the control record"); + } + } finally { + appendLock.unlock(); + } + } + + /** + * Append a {@link LeaderChangeMessage} record to the batch + * + * @param @LeaderChangeMessage The message to append + * @param @currentTimeMs The timestamp of message generation + * @throws IllegalStateException on failure to allocate a buffer for the record + */ + public void appendLeaderChangeMessage( + LeaderChangeMessage leaderChangeMessage, + long currentTimeMs + ) { + appendControlMessage(buffer -> { + return MemoryRecords.withLeaderChangeMessage( + this.nextOffset, + currentTimeMs, + this.epoch, + buffer, + leaderChangeMessage + ); + }); + } + + + /** + * Append a {@link SnapshotHeaderRecord} record to the batch + * + * @param snapshotHeaderRecord The message to append + * @throws IllegalStateException on failure to allocate a buffer for the record + */ + public void appendSnapshotHeaderMessage( + SnapshotHeaderRecord snapshotHeaderRecord, + long currentTimeMs + ) { + appendControlMessage(buffer -> { + return MemoryRecords.withSnapshotHeaderRecord( + this.nextOffset, + currentTimeMs, + this.epoch, + buffer, + snapshotHeaderRecord + ); + }); + } + + /** + * Append a {@link SnapshotFooterRecord} record to the batch + * + * @param snapshotFooterRecord The message to append + * @param currentTimeMs + * @throws IllegalStateException on failure to allocate a buffer for the record + */ + public void appendSnapshotFooterMessage( + SnapshotFooterRecord snapshotFooterRecord, + long currentTimeMs + ) { + appendControlMessage(buffer -> { + return MemoryRecords.withSnapshotFooterRecord( + this.nextOffset, + currentTimeMs, + 
this.epoch, + buffer, + snapshotFooterRecord + ); + }); + } + + public void forceDrain() { + appendLock.lock(); + try { + drainStatus = DrainStatus.STARTED; + maybeCompleteDrain(); + } finally { + appendLock.unlock(); + } + } + + private void maybeCompleteDrain() { + if (drainStatus == DrainStatus.STARTED) { + if (currentBatch != null && currentBatch.nonEmpty()) { + completeCurrentBatch(); + } + // Reset the timer to a large value. The linger clock will begin + // ticking after the next append. + lingerTimer.reset(Long.MAX_VALUE); + drainStatus = DrainStatus.FINISHED; + } + } + + private void startNewBatch() { + ByteBuffer buffer = memoryPool.tryAllocate(maxBatchSize); + if (buffer != null) { + currentBatch = new BatchBuilder<>( + buffer, + serde, + compressionType, + nextOffset, + time.milliseconds(), + false, + epoch, + maxBatchSize + ); + } + } + + /** + * Check whether there are any batches which need to be drained now. + * + * @param currentTimeMs current time in milliseconds + * @return true if there are batches ready to drain, false otherwise + */ + public boolean needsDrain(long currentTimeMs) { + return timeUntilDrain(currentTimeMs) <= 0; + } + + /** + * Check the time remaining until the next needed drain. If the accumulator + * is empty, then {@link Long#MAX_VALUE} will be returned. + * + * @param currentTimeMs current time in milliseconds + * @return the delay in milliseconds before the next expected drain + */ + public long timeUntilDrain(long currentTimeMs) { + if (drainStatus == DrainStatus.FINISHED) { + return 0; + } else { + return lingerTimer.remainingMs(currentTimeMs); + } + } + + /** + * Get the leader epoch, which is constant for each instance. + * + * @return the leader epoch + */ + public int epoch() { + return epoch; + } + + /** + * Drain completed batches. The caller is expected to first check whether + * {@link #needsDrain(long)} returns true in order to avoid unnecessary draining. + * + * Note on thread-safety: this method is safe in the presence of concurrent + * appends, but it assumes a single thread is responsible for draining. + * + * This call will not block, but the drain may require multiple attempts before + * it can be completed if the thread responsible for appending is holding the + * append lock. In the worst case, the append will be completed on the next + * call to {@link #append(int, List)} following the initial call to this method. + * The caller should respect the time to the next flush as indicated by + * {@link #timeUntilDrain(long)}. + * + * @return the list of completed batches + */ + public List> drain() { + // Start the drain if it has not been started already + if (drainStatus == DrainStatus.NONE) { + drainStatus = DrainStatus.STARTED; + } + + // Complete the drain ourselves if we can acquire the lock + if (appendLock.tryLock()) { + try { + maybeCompleteDrain(); + } finally { + appendLock.unlock(); + } + } + + // If the drain has finished, then all of the batches will be completed + if (drainStatus == DrainStatus.FINISHED) { + drainStatus = DrainStatus.NONE; + return drainCompleted(); + } else { + return Collections.emptyList(); + } + } + + private List> drainCompleted() { + List> res = new ArrayList<>(completed.size()); + while (true) { + CompletedBatch batch = completed.poll(); + if (batch == null) { + return res; + } else { + res.add(batch); + } + } + } + + public boolean isEmpty() { + // The linger timer begins running when we have pending batches. 
+ // We use this to infer when the accumulator is empty to avoid the + // need to acquire the append lock. + return !lingerTimer.isRunning(); + } + + /** + * Get the number of completed batches which are ready to be drained. + * This does not include the batch that is currently being filled. + */ + public int numCompletedBatches() { + return completed.size(); + } + + @Override + public void close() { + List> unwritten = drain(); + unwritten.forEach(CompletedBatch::release); + } + + public static class CompletedBatch { + public final long baseOffset; + public final int numRecords; + public final Optional> records; + public final MemoryRecords data; + private final MemoryPool pool; + // Buffer that was allocated by the MemoryPool (pool). This may not be the buffer used in + // the MemoryRecords (data) object. + private final ByteBuffer initialBuffer; + + private CompletedBatch( + long baseOffset, + List records, + MemoryRecords data, + MemoryPool pool, + ByteBuffer initialBuffer + ) { + Objects.requireNonNull(data.firstBatch(), "Exptected memory records to contain one batch"); + + this.baseOffset = baseOffset; + this.records = Optional.of(records); + this.numRecords = records.size(); + this.data = data; + this.pool = pool; + this.initialBuffer = initialBuffer; + } + + private CompletedBatch( + long baseOffset, + int numRecords, + MemoryRecords data, + MemoryPool pool, + ByteBuffer initialBuffer + ) { + Objects.requireNonNull(data.firstBatch(), "Exptected memory records to contain one batch"); + + this.baseOffset = baseOffset; + this.records = Optional.empty(); + this.numRecords = numRecords; + this.data = data; + this.pool = pool; + this.initialBuffer = initialBuffer; + } + + public int sizeInBytes() { + return data.sizeInBytes(); + } + + public void release() { + pool.release(initialBuffer); + } + + public long appendTimestamp() { + // 1. firstBatch is not null because data has one and only one batch + // 2. maxTimestamp is the append time of the batch. This needs to be changed + // to return the LastContainedLogTimestamp of the SnapshotHeaderRecord + return data.firstBatch().maxTimestamp(); + } + } + + private static class SimpleTimer { + // We use an atomic long so that the Raft IO thread can query the linger + // time without any locking + private final AtomicLong deadlineMs = new AtomicLong(Long.MAX_VALUE); + + boolean isRunning() { + return deadlineMs.get() != Long.MAX_VALUE; + } + + void reset(long deadlineMs) { + this.deadlineMs.set(deadlineMs); + } + + long remainingMs(long currentTimeMs) { + return Math.max(0, deadlineMs.get() - currentTimeMs); + } + } + +} diff --git a/raft/src/main/java/org/apache/kafka/raft/internals/BatchBuilder.java b/raft/src/main/java/org/apache/kafka/raft/internals/BatchBuilder.java new file mode 100644 index 0000000..982040b --- /dev/null +++ b/raft/src/main/java/org/apache/kafka/raft/internals/BatchBuilder.java @@ -0,0 +1,355 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
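A hedged sketch of the intended division of labour around `BatchAccumulator`: any thread may append, while a single IO thread drains once `needsDrain` fires (or `timeUntilDrain` elapses) and releases each completed batch back to the memory pool after writing it. The record type `String` and the names `acc`, `time`, and `writeToLog` are stand-ins.

```java
// Sketch only: the single-drainer loop. CompletedBatch#release must run even if the
// write fails, otherwise the pooled buffer leaks.
long now = time.milliseconds();
if (acc.needsDrain(now)) {
    for (BatchAccumulator.CompletedBatch<String> batch : acc.drain()) {
        try {
            // Hand batch.data (MemoryRecords) to the log, e.g. the leader append path.
            writeToLog(batch.data, acc.epoch());
        } finally {
            batch.release();  // return the pooled buffer
        }
    }
}
```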
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.raft.internals; + +import org.apache.kafka.common.protocol.DataOutputStreamWritable; +import org.apache.kafka.common.protocol.ObjectSerializationCache; +import org.apache.kafka.common.protocol.Writable; +import org.apache.kafka.common.record.AbstractRecords; +import org.apache.kafka.common.record.CompressionType; +import org.apache.kafka.common.record.DefaultRecord; +import org.apache.kafka.common.record.DefaultRecordBatch; +import org.apache.kafka.common.record.MemoryRecords; +import org.apache.kafka.common.record.RecordBatch; +import org.apache.kafka.common.record.TimestampType; +import org.apache.kafka.common.utils.ByteBufferOutputStream; +import org.apache.kafka.common.utils.ByteUtils; +import org.apache.kafka.server.common.serialization.RecordSerde; + +import java.io.DataOutputStream; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; +import java.util.OptionalInt; + +/** + * Collect a set of records into a single batch. New records are added + * through {@link #appendRecord(Object, ObjectSerializationCache)}, but the caller must first + * check whether there is room using {@link #bytesNeeded(Collection, ObjectSerializationCache)}. Once the + * batch is ready, then {@link #build()} should be used to get the resulting + * {@link MemoryRecords} instance. + * + * @param record type indicated by {@link RecordSerde} passed in constructor + */ +public class BatchBuilder { + private final ByteBuffer initialBuffer; + private final CompressionType compressionType; + private final ByteBufferOutputStream batchOutput; + private final DataOutputStreamWritable recordOutput; + private final long baseOffset; + private final long appendTime; + private final boolean isControlBatch; + private final int leaderEpoch; + private final int initialPosition; + private final int maxBytes; + private final RecordSerde serde; + private final List records; + + private long nextOffset; + private int unflushedBytes; + private boolean isOpenForAppends = true; + + public BatchBuilder( + ByteBuffer buffer, + RecordSerde serde, + CompressionType compressionType, + long baseOffset, + long appendTime, + boolean isControlBatch, + int leaderEpoch, + int maxBytes + ) { + this.initialBuffer = buffer; + this.batchOutput = new ByteBufferOutputStream(buffer); + this.serde = serde; + this.compressionType = compressionType; + this.baseOffset = baseOffset; + this.nextOffset = baseOffset; + this.appendTime = appendTime; + this.isControlBatch = isControlBatch; + this.initialPosition = batchOutput.position(); + this.leaderEpoch = leaderEpoch; + this.maxBytes = maxBytes; + this.records = new ArrayList<>(); + + // field compressionType must be set before calculating the batch header size + int batchHeaderSizeInBytes = batchHeaderSizeInBytes(); + batchOutput.position(initialPosition + batchHeaderSizeInBytes); + + this.recordOutput = new DataOutputStreamWritable(new DataOutputStream( + compressionType.wrapForOutput(this.batchOutput, RecordBatch.MAGIC_VALUE_V2))); + } + + /** + * Append a record to this batch. 
The caller must first verify there is room for the batch + * using {@link #bytesNeeded(Collection, ObjectSerializationCache)}. + * + * @param record the record to append + * @param serializationCache serialization cache for use in {@link RecordSerde#write(Object, ObjectSerializationCache, Writable)} + * @return the offset of the appended batch + */ + public long appendRecord(T record, ObjectSerializationCache serializationCache) { + if (!isOpenForAppends) { + throw new IllegalStateException("Cannot append new records after the batch has been built"); + } + + if (nextOffset - baseOffset > Integer.MAX_VALUE) { + throw new IllegalArgumentException("Cannot include more than " + Integer.MAX_VALUE + + " records in a single batch"); + } + + long offset = nextOffset++; + int recordSizeInBytes = writeRecord( + offset, + record, + serializationCache + ); + unflushedBytes += recordSizeInBytes; + records.add(record); + return offset; + } + + /** + * Check whether the batch has enough room for all the record values. + * + * Returns an empty {@link OptionalInt} if the batch builder has room for this list of records. + * Otherwise it returns the expected number of bytes needed for a batch to contain these records. + * + * @param records the records to use when checking for room + * @param serializationCache serialization cache for computing sizes + * @return empty {@link OptionalInt} if there is room for the records to be appended, otherwise + * returns the number of bytes needed + */ + public OptionalInt bytesNeeded(Collection records, ObjectSerializationCache serializationCache) { + int bytesNeeded = bytesNeededForRecords( + records, + serializationCache + ); + + if (!isOpenForAppends) { + return OptionalInt.of(batchHeaderSizeInBytes() + bytesNeeded); + } + + int approxUnusedSizeInBytes = maxBytes - approximateSizeInBytes(); + if (approxUnusedSizeInBytes >= bytesNeeded) { + return OptionalInt.empty(); + } else if (unflushedBytes > 0) { + recordOutput.flush(); + unflushedBytes = 0; + int unusedSizeInBytes = maxBytes - flushedSizeInBytes(); + if (unusedSizeInBytes >= bytesNeeded) { + return OptionalInt.empty(); + } + } + + return OptionalInt.of(batchHeaderSizeInBytes() + bytesNeeded); + } + + private int flushedSizeInBytes() { + return batchOutput.position() - initialPosition; + } + + /** + * Get an estimate of the current size of the appended data. This estimate + * is precise if no compression is in use. + * + * @return estimated size in bytes of the appended records + */ + public int approximateSizeInBytes() { + return flushedSizeInBytes() + unflushedBytes; + } + + /** + * Get the base offset of this batch. This is constant upon constructing + * the builder instance. + * + * @return the base offset + */ + public long baseOffset() { + return baseOffset; + } + + /** + * Return the offset of the last appended record. This is updated after + * every append and can be used after the batch has been built to obtain + * the last offset. + * + * @return the offset of the last appended record + */ + public long lastOffset() { + return nextOffset - 1; + } + + /** + * Get the number of records appended to the batch. This is updated after + * each append. + * + * @return the number of appended records + */ + public int numRecords() { + return (int) (nextOffset - baseOffset); + } + + /** + * Check whether there has been at least one record appended to the batch. 
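A hedged sketch of the check-then-append contract between `bytesNeeded` and `appendRecord`, mirroring what `BatchAccumulator#maybeAllocateBatch` does earlier in this patch. `builder`, `record`, and `cache` are hypothetical locals.

```java
// Sketch only: an empty OptionalInt means the current batch has room; a present value
// is the minimum batch size a fresh BatchBuilder would need for these records.
OptionalInt bytesNeeded = builder.bytesNeeded(Collections.singleton(record), cache);
if (!bytesNeeded.isPresent()) {
    builder.appendRecord(record, cache);
} else {
    MemoryRecords full = builder.build();   // seal the current batch
    // ... hand `full` off, allocate a new buffer of at least bytesNeeded.getAsInt()
    // bytes (bounded by the max batch size), and start a new BatchBuilder ...
}
```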
+ * + * @return true if one or more records have been appended + */ + public boolean nonEmpty() { + return numRecords() > 0; + } + + /** + * Return the reference to the initial buffer passed through the constructor. + * This is used in case the buffer needs to be returned to a pool (e.g. + * in {@link org.apache.kafka.common.memory.MemoryPool#release(ByteBuffer)}. + * + * @return the initial buffer passed to the constructor + */ + public ByteBuffer initialBuffer() { + return initialBuffer; + } + + /** + * Get a list of the records appended to the batch. + * @return a list of records + */ + public List records() { + return records; + } + + private void writeDefaultBatchHeader() { + ByteBuffer buffer = batchOutput.buffer(); + int lastPosition = buffer.position(); + + buffer.position(initialPosition); + int size = lastPosition - initialPosition; + int lastOffsetDelta = (int) (lastOffset() - baseOffset); + + DefaultRecordBatch.writeHeader( + buffer, + baseOffset, + lastOffsetDelta, + size, + RecordBatch.MAGIC_VALUE_V2, + compressionType, + TimestampType.CREATE_TIME, + appendTime, + appendTime, + RecordBatch.NO_PRODUCER_ID, + RecordBatch.NO_PRODUCER_EPOCH, + RecordBatch.NO_SEQUENCE, + false, + isControlBatch, + false, + leaderEpoch, + numRecords() + ); + + buffer.position(lastPosition); + } + + public MemoryRecords build() { + recordOutput.close(); + writeDefaultBatchHeader(); + ByteBuffer buffer = batchOutput.buffer().duplicate(); + buffer.flip(); + buffer.position(initialPosition); + isOpenForAppends = false; + return MemoryRecords.readableRecords(buffer.slice()); + } + + public int writeRecord( + long offset, + T payload, + ObjectSerializationCache serializationCache + ) { + int offsetDelta = (int) (offset - baseOffset); + long timestampDelta = 0; + + int payloadSize = serde.recordSize(payload, serializationCache); + int sizeInBytes = DefaultRecord.sizeOfBodyInBytes( + offsetDelta, + timestampDelta, + -1, + payloadSize, + DefaultRecord.EMPTY_HEADERS + ); + recordOutput.writeVarint(sizeInBytes); + + // Write attributes (currently unused) + recordOutput.writeByte((byte) 0); + + // Write timestamp and offset + recordOutput.writeVarlong(timestampDelta); + recordOutput.writeVarint(offsetDelta); + + // Write key, which is always null for controller messages + recordOutput.writeVarint(-1); + + // Write value + recordOutput.writeVarint(payloadSize); + serde.write(payload, serializationCache, recordOutput); + + // Write headers (currently unused) + recordOutput.writeVarint(0); + return ByteUtils.sizeOfVarint(sizeInBytes) + sizeInBytes; + } + + private int batchHeaderSizeInBytes() { + return AbstractRecords.recordBatchHeaderSizeInBytes( + RecordBatch.MAGIC_VALUE_V2, + compressionType + ); + } + + private int bytesNeededForRecords( + Collection records, + ObjectSerializationCache serializationCache + ) { + long expectedNextOffset = nextOffset; + int bytesNeeded = 0; + for (T record : records) { + if (expectedNextOffset - baseOffset >= Integer.MAX_VALUE) { + throw new IllegalArgumentException( + String.format( + "Adding %s records to a batch with base offset of %s and next offset of %s", + records.size(), + baseOffset, + expectedNextOffset + ) + ); + } + + int recordSizeInBytes = DefaultRecord.sizeOfBodyInBytes( + (int) (expectedNextOffset - baseOffset), + 0, + -1, + serde.recordSize(record, serializationCache), + DefaultRecord.EMPTY_HEADERS + ); + + bytesNeeded = Math.addExact(bytesNeeded, ByteUtils.sizeOfVarint(recordSizeInBytes)); + bytesNeeded = Math.addExact(bytesNeeded, recordSizeInBytes); + + 
expectedNextOffset += 1; + } + + return bytesNeeded; + } +} diff --git a/raft/src/main/java/org/apache/kafka/raft/internals/BatchMemoryPool.java b/raft/src/main/java/org/apache/kafka/raft/internals/BatchMemoryPool.java new file mode 100644 index 0000000..5cd3e33 --- /dev/null +++ b/raft/src/main/java/org/apache/kafka/raft/internals/BatchMemoryPool.java @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.raft.internals; + +import org.apache.kafka.common.memory.MemoryPool; + +import java.nio.ByteBuffer; +import java.util.ArrayDeque; +import java.util.Deque; +import java.util.concurrent.locks.ReentrantLock; + +/** + * Simple memory pool which maintains a limited number of fixed-size buffers. + */ +public class BatchMemoryPool implements MemoryPool { + private final ReentrantLock lock; + private final Deque free; + private final int maxBatches; + private final int batchSize; + + private int numAllocatedBatches = 0; + + public BatchMemoryPool(int maxBatches, int batchSize) { + this.maxBatches = maxBatches; + this.batchSize = batchSize; + this.free = new ArrayDeque<>(maxBatches); + this.lock = new ReentrantLock(); + } + + @Override + public ByteBuffer tryAllocate(int sizeBytes) { + if (sizeBytes > batchSize) { + throw new IllegalArgumentException("Cannot allocate buffers larger than max " + + "batch size of " + batchSize); + } + + lock.lock(); + try { + ByteBuffer buffer = free.poll(); + if (buffer == null && numAllocatedBatches < maxBatches) { + buffer = ByteBuffer.allocate(batchSize); + numAllocatedBatches += 1; + } + return buffer; + } finally { + lock.unlock(); + } + } + + @Override + public void release(ByteBuffer previouslyAllocated) { + lock.lock(); + try { + previouslyAllocated.clear(); + + if (previouslyAllocated.limit() != batchSize) { + throw new IllegalArgumentException("Released buffer with unexpected size " + + previouslyAllocated.limit()); + } + + free.offer(previouslyAllocated); + } finally { + lock.unlock(); + } + } + + @Override + public long size() { + lock.lock(); + try { + return numAllocatedBatches * (long) batchSize; + } finally { + lock.unlock(); + } + } + + @Override + public long availableMemory() { + lock.lock(); + try { + int freeBatches = free.size() + (maxBatches - numAllocatedBatches); + return freeBatches * (long) batchSize; + } finally { + lock.unlock(); + } + } + + @Override + public boolean isOutOfMemory() { + return availableMemory() == 0; + } + +} diff --git a/raft/src/main/java/org/apache/kafka/raft/internals/BlockingMessageQueue.java b/raft/src/main/java/org/apache/kafka/raft/internals/BlockingMessageQueue.java new file mode 100644 index 0000000..9343cca --- /dev/null +++ b/raft/src/main/java/org/apache/kafka/raft/internals/BlockingMessageQueue.java @@ -0,0 +1,76 @@ 
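A hedged sketch of the pool semantics above: every allocation hands out a full `batchSize` buffer, `tryAllocate` returns null instead of blocking once `maxBatches` buffers are outstanding, and `release` clears the buffer and returns it to the free list. Sizes are arbitrary.

```java
// Sketch only: fixed-size, bounded buffer pool behaviour. Assumes java.nio.ByteBuffer.
BatchMemoryPool pool = new BatchMemoryPool(2 /* maxBatches */, 1024 /* batchSize */);

ByteBuffer first = pool.tryAllocate(512);   // any request <= batchSize gets a 1024-byte buffer
ByteBuffer second = pool.tryAllocate(1024);
ByteBuffer third = pool.tryAllocate(100);   // pool exhausted: returns null rather than blocking

assert third == null;
assert pool.availableMemory() == 0 && pool.isOutOfMemory();

pool.release(first);                        // cleared and put back on the free list
assert pool.availableMemory() == 1024;
```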
+/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.raft.internals; + +import org.apache.kafka.common.errors.InterruptException; +import org.apache.kafka.common.protocol.ApiMessage; +import org.apache.kafka.raft.RaftMessage; +import org.apache.kafka.raft.RaftMessageQueue; + +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; + +public class BlockingMessageQueue implements RaftMessageQueue { + private static final RaftMessage WAKEUP_MESSAGE = new RaftMessage() { + @Override + public int correlationId() { + return 0; + } + + @Override + public ApiMessage data() { + return null; + } + }; + + private final BlockingQueue queue = new LinkedBlockingQueue<>(); + private final AtomicInteger size = new AtomicInteger(0); + + @Override + public RaftMessage poll(long timeoutMs) { + try { + RaftMessage message = queue.poll(timeoutMs, TimeUnit.MILLISECONDS); + if (message == null || message == WAKEUP_MESSAGE) { + return null; + } else { + size.decrementAndGet(); + return message; + } + } catch (InterruptedException e) { + throw new InterruptException(e); + } + } + + @Override + public void add(RaftMessage message) { + queue.add(message); + size.incrementAndGet(); + } + + @Override + public boolean isEmpty() { + return size.get() == 0; + } + + @Override + public void wakeup() { + queue.add(WAKEUP_MESSAGE); + } + +} diff --git a/raft/src/main/java/org/apache/kafka/raft/internals/CloseListener.java b/raft/src/main/java/org/apache/kafka/raft/internals/CloseListener.java new file mode 100644 index 0000000..e54ff96 --- /dev/null +++ b/raft/src/main/java/org/apache/kafka/raft/internals/CloseListener.java @@ -0,0 +1,23 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
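A minimal sketch of how the BlockingMessageQueue above might be driven, assuming one thread blocks in poll() while another hands off messages or wakes it up; the timeout and thread structure are illustrative only:

import org.apache.kafka.raft.RaftMessage;
import org.apache.kafka.raft.internals.BlockingMessageQueue;

public class MessageQueueExample {
    public static void main(String[] args) throws InterruptedException {
        BlockingMessageQueue queue = new BlockingMessageQueue();

        Thread poller = new Thread(() -> {
            // Blocks for up to 500 ms; returns null on timeout or when wakeup() is called,
            // because the internal wakeup sentinel is filtered out and never handed to callers.
            RaftMessage message = queue.poll(500);
            System.out.println(message == null ? "no work" : "correlationId=" + message.correlationId());
        });
        poller.start();

        // Cut the wait short without enqueuing real work; isEmpty() still reports true.
        queue.wakeup();
        poller.join();
    }
}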
+ */
+package org.apache.kafka.raft.internals;
+
+public interface CloseListener<T> {
+
+    void onClose(T closeable);
+
+}
diff --git a/raft/src/main/java/org/apache/kafka/raft/internals/FuturePurgatory.java b/raft/src/main/java/org/apache/kafka/raft/internals/FuturePurgatory.java
new file mode 100644
index 0000000..b37fb3a
--- /dev/null
+++ b/raft/src/main/java/org/apache/kafka/raft/internals/FuturePurgatory.java
@@ -0,0 +1,91 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.raft.internals;
+
+import java.util.concurrent.CompletableFuture;
+
+/**
+ * Simple purgatory interface which supports waiting with expiration for a given threshold
+ * to be reached. The threshold is specified through {@link #await(Comparable, long)}.
+ * The returned future can be completed in the following ways:
+ *
+ * 1) The future is completed successfully if the threshold value is reached
+ *    in a call to {@link #maybeComplete(Comparable, long)}.
+ * 2) The future is completed successfully if {@link #completeAll(long)} is called.
+ * 3) The future is completed exceptionally if {@link #completeAllExceptionally(Throwable)}
+ *    is called.
+ * 4) If none of the above happens before the expiration of the timeout passed to
+ *    {@link #await(Comparable, long)}, then the future will be completed exceptionally
+ *    with a {@link org.apache.kafka.common.errors.TimeoutException}.
+ *
+ * It is also possible for the future to be completed externally, but this should
+ * generally be avoided.
+ *
+ * Note that implementations should keep the awaiting futures ordered by threshold so that
+ * completion can stop early rather than traverse every awaiting future.
+ *
+ * @param <T> threshold value type
+ */
+public interface FuturePurgatory<T extends Comparable<T>> {
+
+    /**
+     * Create a new future which is tracked by the purgatory.
+     *
+     * @param threshold the minimum value that must be reached for the future
+     *     to be successfully completed by {@link #maybeComplete(Comparable, long)}
+     * @param maxWaitTimeMs the maximum time to wait for completion. If this
+     *     timeout is reached, then the future will be completed exceptionally
+     *     with a {@link org.apache.kafka.common.errors.TimeoutException}
+     *
+     * @return the future tracking the expected completion
+     */
+    CompletableFuture<Long> await(T threshold, long maxWaitTimeMs);
+
+    /**
+     * Complete awaiting futures whose threshold value is less than or equal to the given value.
+     * The completion callbacks will be triggered from the calling thread.
+     *
+     * @param value the threshold value used to determine which awaiting futures can be completed
+     * @param currentTimeMs the current time in milliseconds that will be passed to
+     *     {@link CompletableFuture#complete(Object)} when the futures are completed
+     */
+    void maybeComplete(T value, long currentTimeMs);
+
+    /**
+     * Complete all awaiting futures successfully.
+     *
+     * @param currentTimeMs the current time in milliseconds that will be passed to
+     *     {@link CompletableFuture#complete(Object)} when the futures are completed
+     */
+    void completeAll(long currentTimeMs);
+
+    /**
+     * Complete all awaiting futures exceptionally. The completion callbacks will be
+     * triggered with the passed in exception.
+     *
+     * @param exception the exception passed to
+     *     {@link CompletableFuture#completeExceptionally(Throwable)} when completing the futures
+     */
+    void completeAllExceptionally(Throwable exception);
+
+    /**
+     * The number of currently waiting futures.
+     *
+     * @return the number of waiting futures
+     */
+    int numWaiting();
+}
diff --git a/raft/src/main/java/org/apache/kafka/raft/internals/KafkaRaftMetrics.java b/raft/src/main/java/org/apache/kafka/raft/internals/KafkaRaftMetrics.java
new file mode 100644
index 0000000..c7ffcfb
--- /dev/null
+++ b/raft/src/main/java/org/apache/kafka/raft/internals/KafkaRaftMetrics.java
@@ -0,0 +1,204 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
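To make the interface contract above concrete, a hedged sketch of a caller waiting on a threshold; `purgatory` is assumed to be some FuturePurgatory<Long> implementation (such as the ThresholdPurgatory added later in this patch), and the offsets and timeout are illustrative:

// Wait up to 5 seconds for the tracked value to reach 100.
CompletableFuture<Long> future = purgatory.await(100L, 5_000L);

future.whenComplete((completionTimeMs, exception) -> {
    if (exception != null) {
        // Either completeAllExceptionally(...) was invoked or the wait expired
        // with a TimeoutException.
        System.err.println("wait failed: " + exception);
    } else {
        System.out.println("threshold reached at " + completionTimeMs);
    }
});

// Completes every future awaiting a threshold of 150 or less, handing them the current time.
purgatory.maybeComplete(150L, System.currentTimeMillis());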
+ */ +package org.apache.kafka.raft.internals; + +import org.apache.kafka.common.MetricName; +import org.apache.kafka.common.metrics.Gauge; +import org.apache.kafka.common.metrics.Metrics; +import org.apache.kafka.common.metrics.Sensor; +import org.apache.kafka.common.metrics.stats.Avg; +import org.apache.kafka.common.metrics.stats.Max; +import org.apache.kafka.common.metrics.stats.Rate; +import org.apache.kafka.common.metrics.stats.WindowedSum; +import org.apache.kafka.raft.OffsetAndEpoch; +import org.apache.kafka.raft.QuorumState; + +import java.util.OptionalLong; +import java.util.concurrent.TimeUnit; + +public class KafkaRaftMetrics implements AutoCloseable { + + private final Metrics metrics; + + private OffsetAndEpoch logEndOffset; + private int numUnknownVoterConnections; + private OptionalLong electionStartMs; + private OptionalLong pollStartMs; + private OptionalLong pollEndMs; + + private final MetricName currentLeaderIdMetricName; + private final MetricName currentVotedIdMetricName; + private final MetricName currentEpochMetricName; + private final MetricName currentStateMetricName; + private final MetricName highWatermarkMetricName; + private final MetricName logEndOffsetMetricName; + private final MetricName logEndEpochMetricName; + private final MetricName numUnknownVoterConnectionsMetricName; + private final Sensor commitTimeSensor; + private final Sensor electionTimeSensor; + private final Sensor fetchRecordsSensor; + private final Sensor appendRecordsSensor; + private final Sensor pollIdleSensor; + + public KafkaRaftMetrics(Metrics metrics, String metricGrpPrefix, QuorumState state) { + this.metrics = metrics; + String metricGroupName = metricGrpPrefix + "-metrics"; + + this.pollStartMs = OptionalLong.empty(); + this.pollEndMs = OptionalLong.empty(); + this.electionStartMs = OptionalLong.empty(); + this.numUnknownVoterConnections = 0; + this.logEndOffset = new OffsetAndEpoch(0L, 0); + + this.currentStateMetricName = metrics.metricName("current-state", metricGroupName, "The current state of this member; possible values are leader, candidate, voted, follower, unattached"); + Gauge stateProvider = (mConfig, currentTimeMs) -> { + if (state.isLeader()) { + return "leader"; + } else if (state.isCandidate()) { + return "candidate"; + } else if (state.isVoted()) { + return "voted"; + } else if (state.isFollower()) { + return "follower"; + } else { + return "unattached"; + } + }; + metrics.addMetric(this.currentStateMetricName, null, stateProvider); + + this.currentLeaderIdMetricName = metrics.metricName("current-leader", metricGroupName, "The current quorum leader's id; -1 indicates unknown"); + metrics.addMetric(this.currentLeaderIdMetricName, (mConfig, currentTimeMs) -> state.leaderId().orElse(-1)); + + this.currentVotedIdMetricName = metrics.metricName("current-vote", metricGroupName, "The current voted leader's id; -1 indicates not voted for anyone"); + metrics.addMetric(this.currentVotedIdMetricName, (mConfig, currentTimeMs) -> { + if (state.isLeader() || state.isCandidate()) { + return state.localIdOrThrow(); + } else if (state.isVoted()) { + return state.votedStateOrThrow().votedId(); + } else { + return -1; + } + }); + + this.currentEpochMetricName = metrics.metricName("current-epoch", metricGroupName, "The current quorum epoch."); + metrics.addMetric(this.currentEpochMetricName, (mConfig, currentTimeMs) -> state.epoch()); + + this.highWatermarkMetricName = metrics.metricName("high-watermark", metricGroupName, "The high watermark maintained on this member; -1 if it is 
unknown"); + metrics.addMetric(this.highWatermarkMetricName, (mConfig, currentTimeMs) -> state.highWatermark().map(hw -> hw.offset).orElse(-1L)); + + this.logEndOffsetMetricName = metrics.metricName("log-end-offset", metricGroupName, "The current raft log end offset."); + metrics.addMetric(this.logEndOffsetMetricName, (mConfig, currentTimeMs) -> logEndOffset.offset); + + this.logEndEpochMetricName = metrics.metricName("log-end-epoch", metricGroupName, "The current raft log end epoch."); + metrics.addMetric(this.logEndEpochMetricName, (mConfig, currentTimeMs) -> logEndOffset.epoch); + + this.numUnknownVoterConnectionsMetricName = metrics.metricName("number-unknown-voter-connections", metricGroupName, "The number of voter connections recognized at this member."); + metrics.addMetric(this.numUnknownVoterConnectionsMetricName, (mConfig, currentTimeMs) -> numUnknownVoterConnections); + + this.commitTimeSensor = metrics.sensor("commit-latency"); + this.commitTimeSensor.add(metrics.metricName("commit-latency-avg", metricGroupName, + "The average time in milliseconds to commit an entry in the raft log."), new Avg()); + this.commitTimeSensor.add(metrics.metricName("commit-latency-max", metricGroupName, + "The maximum time in milliseconds to commit an entry in the raft log."), new Max()); + + this.electionTimeSensor = metrics.sensor("election-latency"); + this.electionTimeSensor.add(metrics.metricName("election-latency-avg", metricGroupName, + "The average time in milliseconds to elect a new leader."), new Avg()); + this.electionTimeSensor.add(metrics.metricName("election-latency-max", metricGroupName, + "The maximum time in milliseconds to elect a new leader."), new Max()); + + this.fetchRecordsSensor = metrics.sensor("fetch-records"); + this.fetchRecordsSensor.add(metrics.metricName("fetch-records-rate", metricGroupName, + "The average number of records fetched from the leader of the raft quorum."), + new Rate(TimeUnit.SECONDS, new WindowedSum())); + + this.appendRecordsSensor = metrics.sensor("append-records"); + this.appendRecordsSensor.add(metrics.metricName("append-records-rate", metricGroupName, + "The average number of records appended per sec as the leader of the raft quorum."), + new Rate(TimeUnit.SECONDS, new WindowedSum())); + + this.pollIdleSensor = metrics.sensor("poll-idle-ratio"); + this.pollIdleSensor.add(metrics.metricName("poll-idle-ratio-avg", + metricGroupName, + "The average fraction of time the client's poll() is idle as opposed to waiting for the user code to process records."), + new Avg()); + } + + public void updatePollStart(long currentTimeMs) { + if (pollEndMs.isPresent() && pollStartMs.isPresent()) { + long pollTimeMs = Math.max(pollEndMs.getAsLong() - pollStartMs.getAsLong(), 0L); + long totalTimeMs = Math.max(currentTimeMs - pollStartMs.getAsLong(), 1L); + this.pollIdleSensor.record(pollTimeMs / (double) totalTimeMs, currentTimeMs); + } + + this.pollStartMs = OptionalLong.of(currentTimeMs); + this.pollEndMs = OptionalLong.empty(); + } + + public void updatePollEnd(long currentTimeMs) { + this.pollEndMs = OptionalLong.of(currentTimeMs); + } + + public void updateLogEnd(OffsetAndEpoch logEndOffset) { + this.logEndOffset = logEndOffset; + } + + public void updateNumUnknownVoterConnections(int numUnknownVoterConnections) { + this.numUnknownVoterConnections = numUnknownVoterConnections; + } + + public void updateAppendRecords(long numRecords) { + appendRecordsSensor.record(numRecords); + } + + public void updateFetchedRecords(long numRecords) { + 
fetchRecordsSensor.record(numRecords); + } + + public void updateCommitLatency(double latencyMs, long currentTimeMs) { + commitTimeSensor.record(latencyMs, currentTimeMs); + } + + public void updateElectionStartMs(long currentTimeMs) { + electionStartMs = OptionalLong.of(currentTimeMs); + } + + public void maybeUpdateElectionLatency(long currentTimeMs) { + if (electionStartMs.isPresent()) { + electionTimeSensor.record(currentTimeMs - electionStartMs.getAsLong(), currentTimeMs); + electionStartMs = OptionalLong.empty(); + } + } + + @Override + public void close() { + metrics.removeMetric(currentLeaderIdMetricName); + metrics.removeMetric(currentVotedIdMetricName); + metrics.removeMetric(currentEpochMetricName); + metrics.removeMetric(currentStateMetricName); + metrics.removeMetric(highWatermarkMetricName); + metrics.removeMetric(logEndOffsetMetricName); + metrics.removeMetric(logEndEpochMetricName); + metrics.removeMetric(numUnknownVoterConnectionsMetricName); + + metrics.removeSensor(commitTimeSensor.name()); + metrics.removeSensor(electionTimeSensor.name()); + metrics.removeSensor(fetchRecordsSensor.name()); + metrics.removeSensor(appendRecordsSensor.name()); + metrics.removeSensor(pollIdleSensor.name()); + } +} diff --git a/raft/src/main/java/org/apache/kafka/raft/internals/MemoryBatchReader.java b/raft/src/main/java/org/apache/kafka/raft/internals/MemoryBatchReader.java new file mode 100644 index 0000000..df0ec50 --- /dev/null +++ b/raft/src/main/java/org/apache/kafka/raft/internals/MemoryBatchReader.java @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.raft.internals; + +import org.apache.kafka.raft.Batch; +import org.apache.kafka.raft.BatchReader; + +import java.util.Collections; +import java.util.Iterator; +import java.util.List; +import java.util.OptionalLong; + +public class MemoryBatchReader implements BatchReader { + private final CloseListener> closeListener; + private final Iterator> iterator; + private final long baseOffset; + private final long lastOffset; + + private MemoryBatchReader( + long baseOffset, + long lastOffset, + Iterator> iterator, + CloseListener> closeListener + ) { + this.baseOffset = baseOffset; + this.lastOffset = lastOffset; + this.iterator = iterator; + this.closeListener = closeListener; + } + + @Override + public boolean hasNext() { + return iterator.hasNext(); + } + + @Override + public Batch next() { + return iterator.next(); + } + + @Override + public long baseOffset() { + return baseOffset; + } + + @Override + public OptionalLong lastOffset() { + return OptionalLong.of(lastOffset); + } + + @Override + public void close() { + closeListener.onClose(this); + } + + public static MemoryBatchReader empty( + long baseOffset, + long lastOffset, + CloseListener> closeListener + ) { + return new MemoryBatchReader<>( + baseOffset, + lastOffset, + Collections.emptyIterator(), + closeListener + ); + } + + public static MemoryBatchReader of( + List> batches, + CloseListener> closeListener + ) { + if (batches.isEmpty()) { + throw new IllegalArgumentException("MemoryBatchReader requires at least " + + "one batch to iterate, but an empty list was provided"); + } + + return new MemoryBatchReader<>( + batches.get(0).baseOffset(), + batches.get(batches.size() - 1).lastOffset(), + batches.iterator(), + closeListener + ); + } +} diff --git a/raft/src/main/java/org/apache/kafka/raft/internals/RecordsBatchReader.java b/raft/src/main/java/org/apache/kafka/raft/internals/RecordsBatchReader.java new file mode 100644 index 0000000..e952061 --- /dev/null +++ b/raft/src/main/java/org/apache/kafka/raft/internals/RecordsBatchReader.java @@ -0,0 +1,131 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.raft.internals; + +import org.apache.kafka.common.record.Records; +import org.apache.kafka.common.utils.BufferSupplier; +import org.apache.kafka.raft.Batch; +import org.apache.kafka.raft.BatchReader; +import org.apache.kafka.server.common.serialization.RecordSerde; + +import java.util.NoSuchElementException; +import java.util.Optional; +import java.util.OptionalLong; + +public final class RecordsBatchReader implements BatchReader { + private final long baseOffset; + private final RecordsIterator iterator; + private final CloseListener> closeListener; + + private long lastReturnedOffset; + + private Optional> nextBatch = Optional.empty(); + private boolean isClosed = false; + + private RecordsBatchReader( + long baseOffset, + RecordsIterator iterator, + CloseListener> closeListener + ) { + this.baseOffset = baseOffset; + this.iterator = iterator; + this.closeListener = closeListener; + this.lastReturnedOffset = baseOffset; + } + + @Override + public boolean hasNext() { + ensureOpen(); + + if (!nextBatch.isPresent()) { + nextBatch = nextBatch(); + } + + return nextBatch.isPresent(); + } + + @Override + public Batch next() { + if (!hasNext()) { + throw new NoSuchElementException("Records batch reader doesn't have any more elements"); + } + + Batch batch = nextBatch.get(); + nextBatch = Optional.empty(); + + lastReturnedOffset = batch.lastOffset(); + return batch; + } + + @Override + public long baseOffset() { + return baseOffset; + } + + public OptionalLong lastOffset() { + if (isClosed) { + return OptionalLong.of(lastReturnedOffset); + } else { + return OptionalLong.empty(); + } + } + + @Override + public void close() { + if (!isClosed) { + isClosed = true; + + iterator.close(); + closeListener.onClose(this); + } + } + + public static RecordsBatchReader of( + long baseOffset, + Records records, + RecordSerde serde, + BufferSupplier bufferSupplier, + int maxBatchSize, + CloseListener> closeListener + ) { + return new RecordsBatchReader<>( + baseOffset, + new RecordsIterator<>(records, serde, bufferSupplier, maxBatchSize), + closeListener + ); + } + + private void ensureOpen() { + if (isClosed) { + throw new IllegalStateException("Records batch reader was closed"); + } + } + + private Optional> nextBatch() { + while (iterator.hasNext()) { + Batch batch = iterator.next(); + + if (batch.records().isEmpty()) { + lastReturnedOffset = batch.lastOffset(); + } else { + return Optional.of(batch); + } + } + + return Optional.empty(); + } +} diff --git a/raft/src/main/java/org/apache/kafka/raft/internals/RecordsIterator.java b/raft/src/main/java/org/apache/kafka/raft/internals/RecordsIterator.java new file mode 100644 index 0000000..b36d4f1 --- /dev/null +++ b/raft/src/main/java/org/apache/kafka/raft/internals/RecordsIterator.java @@ -0,0 +1,252 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.raft.internals; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Collections; +import java.util.Iterator; +import java.util.List; +import java.util.NoSuchElementException; +import java.util.Optional; +import org.apache.kafka.common.protocol.DataInputStreamReadable; +import org.apache.kafka.common.protocol.Readable; +import org.apache.kafka.common.record.DefaultRecordBatch; +import org.apache.kafka.common.record.FileRecords; +import org.apache.kafka.common.record.MemoryRecords; +import org.apache.kafka.common.record.MutableRecordBatch; +import org.apache.kafka.common.record.Records; +import org.apache.kafka.common.utils.BufferSupplier; +import org.apache.kafka.raft.Batch; +import org.apache.kafka.server.common.serialization.RecordSerde; + +public final class RecordsIterator implements Iterator>, AutoCloseable { + private final Records records; + private final RecordSerde serde; + private final BufferSupplier bufferSupplier; + private final int batchSize; + + private Iterator nextBatches = Collections.emptyIterator(); + private Optional> nextBatch = Optional.empty(); + // Buffer used as the backing store for nextBatches if needed + private Optional allocatedBuffer = Optional.empty(); + // Number of bytes from records read up to now + private int bytesRead = 0; + private boolean isClosed = false; + + public RecordsIterator( + Records records, + RecordSerde serde, + BufferSupplier bufferSupplier, + int batchSize + ) { + this.records = records; + this.serde = serde; + this.bufferSupplier = bufferSupplier; + this.batchSize = Math.max(batchSize, Records.HEADER_SIZE_UP_TO_MAGIC); + } + + @Override + public boolean hasNext() { + ensureOpen(); + + if (!nextBatch.isPresent()) { + nextBatch = nextBatch(); + } + + return nextBatch.isPresent(); + } + + @Override + public Batch next() { + if (!hasNext()) { + throw new NoSuchElementException("Batch iterator doesn't have any more elements"); + } + + Batch batch = nextBatch.get(); + nextBatch = Optional.empty(); + + return batch; + } + + @Override + public void close() { + isClosed = true; + allocatedBuffer.ifPresent(bufferSupplier::release); + allocatedBuffer = Optional.empty(); + } + + private void ensureOpen() { + if (isClosed) { + throw new IllegalStateException("Serde record batch itererator was closed"); + } + } + + private MemoryRecords readFileRecords(FileRecords fileRecords, ByteBuffer buffer) { + int start = buffer.position(); + try { + fileRecords.readInto(buffer, bytesRead); + } catch (IOException e) { + throw new UncheckedIOException("Failed to read records into memory", e); + } + + bytesRead += buffer.limit() - start; + return MemoryRecords.readableRecords(buffer.slice()); + } + + private MemoryRecords createMemoryRecords(FileRecords fileRecords) { + final ByteBuffer buffer; + if (allocatedBuffer.isPresent()) { + buffer = allocatedBuffer.get(); + buffer.compact(); + } else { + buffer = bufferSupplier.get(Math.min(batchSize, records.sizeInBytes())); + allocatedBuffer = Optional.of(buffer); + } + + MemoryRecords memoryRecords = readFileRecords(fileRecords, buffer); + + // firstBatchSize() is always non-null because the minimum buffer is HEADER_SIZE_UP_TO_MAGIC. 
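        // If the first batch fits in what was just read, return the records as-is; otherwise grow
        // the buffer to the full batch size, carry over the bytes already read, and finish reading
        // the batch from the file below.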
+ if (memoryRecords.firstBatchSize() <= buffer.remaining()) { + return memoryRecords; + } else { + // Not enough bytes read; create a bigger buffer + ByteBuffer newBuffer = bufferSupplier.get(memoryRecords.firstBatchSize()); + allocatedBuffer = Optional.of(newBuffer); + + newBuffer.put(buffer); + bufferSupplier.release(buffer); + + return readFileRecords(fileRecords, newBuffer); + } + } + + private Iterator nextBatches() { + int recordSize = records.sizeInBytes(); + if (bytesRead < recordSize) { + final MemoryRecords memoryRecords; + if (records instanceof MemoryRecords) { + bytesRead = recordSize; + memoryRecords = (MemoryRecords) records; + } else if (records instanceof FileRecords) { + memoryRecords = createMemoryRecords((FileRecords) records); + } else { + throw new IllegalStateException(String.format("Unexpected Records type %s", records.getClass())); + } + + return memoryRecords.batchIterator(); + } + + return Collections.emptyIterator(); + } + + private Optional> nextBatch() { + if (!nextBatches.hasNext()) { + nextBatches = nextBatches(); + } + + if (nextBatches.hasNext()) { + MutableRecordBatch nextBatch = nextBatches.next(); + + // Update the buffer position to reflect the read batch + allocatedBuffer.ifPresent(buffer -> buffer.position(buffer.position() + nextBatch.sizeInBytes())); + + if (!(nextBatch instanceof DefaultRecordBatch)) { + throw new IllegalStateException( + String.format("DefaultRecordBatch expected by record type was %s", nextBatch.getClass()) + ); + } + + return Optional.of(readBatch((DefaultRecordBatch) nextBatch)); + } + + return Optional.empty(); + } + + private Batch readBatch(DefaultRecordBatch batch) { + final Batch result; + if (batch.isControlBatch()) { + result = Batch.control( + batch.baseOffset(), + batch.partitionLeaderEpoch(), + batch.maxTimestamp(), + batch.sizeInBytes(), + batch.lastOffset() + ); + } else { + Integer numRecords = batch.countOrNull(); + if (numRecords == null) { + throw new IllegalStateException("Expected a record count for the records batch"); + } + + List records = new ArrayList<>(numRecords); + try (DataInputStreamReadable input = new DataInputStreamReadable(batch.recordInputStream(bufferSupplier))) { + for (int i = 0; i < numRecords; i++) { + T record = readRecord(input); + records.add(record); + } + } + + result = Batch.data( + batch.baseOffset(), + batch.partitionLeaderEpoch(), + batch.maxTimestamp(), + batch.sizeInBytes(), + records + ); + } + + return result; + } + + private T readRecord(Readable input) { + // Read size of body in bytes + input.readVarint(); + + // Read unused attributes + input.readByte(); + + long timestampDelta = input.readVarlong(); + if (timestampDelta != 0) { + throw new IllegalArgumentException(); + } + + // Read offset delta + input.readVarint(); + + int keySize = input.readVarint(); + if (keySize != -1) { + throw new IllegalArgumentException("Unexpected key size " + keySize); + } + + int valueSize = input.readVarint(); + if (valueSize < 0) { + throw new IllegalArgumentException(); + } + + T record = serde.read(input, valueSize); + + int numHeaders = input.readVarint(); + if (numHeaders != 0) { + throw new IllegalArgumentException(); + } + + return record; + } +} diff --git a/raft/src/main/java/org/apache/kafka/raft/internals/StringSerde.java b/raft/src/main/java/org/apache/kafka/raft/internals/StringSerde.java new file mode 100644 index 0000000..c2a011a --- /dev/null +++ b/raft/src/main/java/org/apache/kafka/raft/internals/StringSerde.java @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software 
Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.raft.internals; + +import org.apache.kafka.common.protocol.ObjectSerializationCache; +import org.apache.kafka.common.protocol.Readable; +import org.apache.kafka.common.protocol.Writable; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.server.common.serialization.RecordSerde; + +public class StringSerde implements RecordSerde { + + @Override + public int recordSize(String data, ObjectSerializationCache serializationCache) { + return recordSize(data); + } + + public int recordSize(String data) { + return Utils.utf8Length(data); + } + + @Override + public void write(String data, ObjectSerializationCache serializationCache, Writable out) { + out.writeByteArray(Utils.utf8(data)); + } + + @Override + public String read(Readable input, int size) { + byte[] data = new byte[size]; + input.readArray(data); + return Utils.utf8(data); + } + +} diff --git a/raft/src/main/java/org/apache/kafka/raft/internals/ThresholdPurgatory.java b/raft/src/main/java/org/apache/kafka/raft/internals/ThresholdPurgatory.java new file mode 100644 index 0000000..eec3911 --- /dev/null +++ b/raft/src/main/java/org/apache/kafka/raft/internals/ThresholdPurgatory.java @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+package org.apache.kafka.raft.internals;
+
+import org.apache.kafka.raft.ExpirationService;
+
+import java.util.NavigableMap;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ConcurrentNavigableMap;
+import java.util.concurrent.ConcurrentSkipListMap;
+import java.util.concurrent.atomic.AtomicLong;
+
+public class ThresholdPurgatory<T extends Comparable<T>> implements FuturePurgatory<T> {
+    private final AtomicLong idGenerator = new AtomicLong(0);
+    private final ExpirationService expirationService;
+    private final ConcurrentNavigableMap<ThresholdKey<T>, CompletableFuture<Long>> thresholdMap =
+        new ConcurrentSkipListMap<>();
+
+    public ThresholdPurgatory(ExpirationService expirationService) {
+        this.expirationService = expirationService;
+    }
+
+    @Override
+    public CompletableFuture<Long> await(T threshold, long maxWaitTimeMs) {
+        ThresholdKey<T> key = new ThresholdKey<>(idGenerator.incrementAndGet(), threshold);
+        CompletableFuture<Long> future = expirationService.failAfter(maxWaitTimeMs);
+        thresholdMap.put(key, future);
+        future.whenComplete((timeMs, exception) -> thresholdMap.remove(key));
+        return future;
+    }
+
+    @Override
+    public void maybeComplete(T value, long currentTimeMs) {
+        ThresholdKey<T> maxKey = new ThresholdKey<>(Long.MAX_VALUE, value);
+        NavigableMap<ThresholdKey<T>, CompletableFuture<Long>> submap = thresholdMap.headMap(maxKey);
+        for (CompletableFuture<Long> completion : submap.values()) {
+            completion.complete(currentTimeMs);
+        }
+    }
+
+    @Override
+    public void completeAll(long currentTimeMs) {
+        for (CompletableFuture<Long> completion : thresholdMap.values()) {
+            completion.complete(currentTimeMs);
+        }
+    }
+
+    @Override
+    public void completeAllExceptionally(Throwable exception) {
+        for (CompletableFuture<Long> completion : thresholdMap.values()) {
+            completion.completeExceptionally(exception);
+        }
+    }
+
+    @Override
+    public int numWaiting() {
+        return thresholdMap.size();
+    }
+
+    private static class ThresholdKey<T extends Comparable<T>> implements Comparable<ThresholdKey<T>> {
+        private final long id;
+        private final T threshold;
+
+        private ThresholdKey(long id, T threshold) {
+            this.id = id;
+            this.threshold = threshold;
+        }
+
+        @Override
+        public int compareTo(ThresholdKey<T> o) {
+            int res = this.threshold.compareTo(o.threshold);
+            if (res != 0) {
+                return res;
+            } else {
+                return Long.compare(this.id, o.id);
+            }
+        }
+    }
+
+}
diff --git a/raft/src/main/java/org/apache/kafka/snapshot/FileRawSnapshotReader.java b/raft/src/main/java/org/apache/kafka/snapshot/FileRawSnapshotReader.java
new file mode 100644
index 0000000..7d2955d
--- /dev/null
+++ b/raft/src/main/java/org/apache/kafka/snapshot/FileRawSnapshotReader.java
@@ -0,0 +1,94 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
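A brief hedged sketch of the ordering behaviour implemented by ThresholdPurgatory above; the `expirationService` instance, offsets, and timestamps are assumed/illustrative:

FuturePurgatory<Long> purgatory = new ThresholdPurgatory<>(expirationService);

CompletableFuture<Long> a = purgatory.await(5L, 30_000L);
CompletableFuture<Long> b = purgatory.await(7L, 30_000L);
CompletableFuture<Long> c = purgatory.await(7L, 30_000L);  // same threshold, distinct id

// The head map bounded by an id of Long.MAX_VALUE captures every entry at or below the value:
purgatory.maybeComplete(6L, 1000L);  // completes only a
purgatory.maybeComplete(7L, 2000L);  // completes b and c as well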
+ */ +package org.apache.kafka.snapshot; + +import org.apache.kafka.common.record.FileRecords; +import org.apache.kafka.common.record.Records; +import org.apache.kafka.common.record.UnalignedRecords; +import org.apache.kafka.raft.OffsetAndEpoch; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.nio.file.Path; + +public final class FileRawSnapshotReader implements RawSnapshotReader, AutoCloseable { + private final FileRecords fileRecords; + private final OffsetAndEpoch snapshotId; + + private FileRawSnapshotReader(FileRecords fileRecords, OffsetAndEpoch snapshotId) { + this.fileRecords = fileRecords; + this.snapshotId = snapshotId; + } + + @Override + public OffsetAndEpoch snapshotId() { + return snapshotId; + } + + @Override + public long sizeInBytes() { + return fileRecords.sizeInBytes(); + } + + @Override + public UnalignedRecords slice(long position, int size) { + return fileRecords.sliceUnaligned(Math.toIntExact(position), size); + } + + @Override + public Records records() { + return fileRecords; + } + + @Override + public void close() { + try { + fileRecords.close(); + } catch (IOException e) { + throw new UncheckedIOException( + String.format("Unable to close snapshot reader %s at %s", snapshotId, fileRecords), + e + ); + } + } + + /** + * Opens a snapshot for reading. + * + * @param logDir the directory for the topic partition + * @param snapshotId the end offset and epoch for the snapshotId + */ + public static FileRawSnapshotReader open(Path logDir, OffsetAndEpoch snapshotId) { + FileRecords fileRecords; + Path filePath = Snapshots.snapshotPath(logDir, snapshotId); + try { + fileRecords = FileRecords.open( + filePath.toFile(), + false, // mutable + true, // fileAlreadyExists + 0, // initFileSize + false // preallocate + ); + } catch (IOException e) { + throw new UncheckedIOException( + String.format("Unable to Opens a snapshot file %s", filePath.toAbsolutePath()), e + ); + } + + return new FileRawSnapshotReader(fileRecords, snapshotId); + } +} diff --git a/raft/src/main/java/org/apache/kafka/snapshot/FileRawSnapshotWriter.java b/raft/src/main/java/org/apache/kafka/snapshot/FileRawSnapshotWriter.java new file mode 100644 index 0000000..badefd3 --- /dev/null +++ b/raft/src/main/java/org/apache/kafka/snapshot/FileRawSnapshotWriter.java @@ -0,0 +1,199 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.snapshot; + +import org.apache.kafka.common.record.MemoryRecords; +import org.apache.kafka.common.record.UnalignedMemoryRecords; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.raft.OffsetAndEpoch; +import org.apache.kafka.raft.ReplicatedLog; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.nio.channels.FileChannel; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.StandardOpenOption; +import java.util.Optional; + +public final class FileRawSnapshotWriter implements RawSnapshotWriter { + private final Path tempSnapshotPath; + private final FileChannel channel; + private final OffsetAndEpoch snapshotId; + private final Optional replicatedLog; + private boolean frozen = false; + + private FileRawSnapshotWriter( + Path tempSnapshotPath, + FileChannel channel, + OffsetAndEpoch snapshotId, + Optional replicatedLog + ) { + this.tempSnapshotPath = tempSnapshotPath; + this.channel = channel; + this.snapshotId = snapshotId; + this.replicatedLog = replicatedLog; + } + + @Override + public OffsetAndEpoch snapshotId() { + return snapshotId; + } + + @Override + public long sizeInBytes() { + try { + return channel.size(); + } catch (IOException e) { + throw new UncheckedIOException( + String.format( + "Error calculating snapshot size. temp path = %s, snapshotId = %s.", + tempSnapshotPath, + snapshotId), + e + ); + } + } + + @Override + public void append(UnalignedMemoryRecords records) { + try { + checkIfFrozen("Append"); + Utils.writeFully(channel, records.buffer()); + } catch (IOException e) { + throw new UncheckedIOException( + String.format("Error writing file snapshot, " + + "temp path = %s, snapshotId = %s.", this.tempSnapshotPath, this.snapshotId), + e + ); + } + } + + @Override + public void append(MemoryRecords records) { + try { + checkIfFrozen("Append"); + Utils.writeFully(channel, records.buffer()); + } catch (IOException e) { + throw new UncheckedIOException( + String.format("Error writing file snapshot, " + + "temp path = %s, snapshotId = %s.", this.tempSnapshotPath, this.snapshotId), + e + ); + } + } + + @Override + public boolean isFrozen() { + return frozen; + } + + @Override + public void freeze() { + try { + checkIfFrozen("Freeze"); + + channel.close(); + frozen = true; + + if (!tempSnapshotPath.toFile().setReadOnly()) { + throw new IllegalStateException(String.format("Unable to set file (%s) as read-only", tempSnapshotPath)); + } + + Path destination = Snapshots.moveRename(tempSnapshotPath, snapshotId); + Utils.atomicMoveWithFallback(tempSnapshotPath, destination); + + replicatedLog.ifPresent(log -> log.onSnapshotFrozen(snapshotId)); + } catch (IOException e) { + throw new UncheckedIOException( + String.format("Error freezing file snapshot, " + + "temp path = %s, snapshotId = %s.", this.tempSnapshotPath, this.snapshotId), + e + ); + } + } + + @Override + public void close() { + try { + channel.close(); + // This is a noop if freeze was called before calling close + Files.deleteIfExists(tempSnapshotPath); + } catch (IOException e) { + throw new UncheckedIOException( + String.format("Error closing snapshot writer, " + + "temp path = %s, snapshotId %s.", this.tempSnapshotPath, this.snapshotId), + e + ); + } + } + + @Override + public String toString() { + return String.format( + "FileRawSnapshotWriter(path=%s, snapshotId=%s, frozen=%s)", + tempSnapshotPath, + snapshotId, + frozen + ); + } + + void checkIfFrozen(String operation) { + if (frozen) { + throw new 
IllegalStateException( + String.format( + "%s is not supported. Snapshot is already frozen: id = %s; temp path = %s", + operation, + snapshotId, + tempSnapshotPath + ) + ); + } + } + + /** + * Create a snapshot writer for topic partition log dir and snapshot id. + * + * @param logDir the directory for the topic partition + * @param snapshotId the end offset and epoch for the snapshotId + */ + public static FileRawSnapshotWriter create( + Path logDir, + OffsetAndEpoch snapshotId, + Optional replicatedLog + ) { + Path path = Snapshots.createTempFile(logDir, snapshotId); + + try { + return new FileRawSnapshotWriter( + path, + FileChannel.open(path, StandardOpenOption.WRITE, StandardOpenOption.APPEND), + snapshotId, + replicatedLog + ); + } catch (IOException e) { + throw new UncheckedIOException( + String.format( + "Error creating snapshot writer. path = %s, snapshotId %s.", + path, + snapshotId + ), + e + ); + } + } +} diff --git a/raft/src/main/java/org/apache/kafka/snapshot/RawSnapshotReader.java b/raft/src/main/java/org/apache/kafka/snapshot/RawSnapshotReader.java new file mode 100644 index 0000000..1a51999 --- /dev/null +++ b/raft/src/main/java/org/apache/kafka/snapshot/RawSnapshotReader.java @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.snapshot; + +import org.apache.kafka.common.record.Records; +import org.apache.kafka.common.record.UnalignedRecords; +import org.apache.kafka.raft.OffsetAndEpoch; + +/** + * Interface for reading snapshots as a sequence of records. + */ +public interface RawSnapshotReader { + /** + * Returns the end offset and epoch for the snapshot. + */ + OffsetAndEpoch snapshotId(); + + /** + * Returns the number of bytes for the snapshot. + */ + long sizeInBytes(); + + /** + * Creates a slize of unaligned records from the position up to a size. + * + * @param position the starting position of the slice in the snapshot + * @param size the maximum size of the slice + * @return an unaligned slice of records in the snapshot + */ + UnalignedRecords slice(long position, int size); + + /** + * Returns all of the records backing this snapshot reader. + * + * @return all of the records for this snapshot + */ + Records records(); +} diff --git a/raft/src/main/java/org/apache/kafka/snapshot/RawSnapshotWriter.java b/raft/src/main/java/org/apache/kafka/snapshot/RawSnapshotWriter.java new file mode 100644 index 0000000..07d8271 --- /dev/null +++ b/raft/src/main/java/org/apache/kafka/snapshot/RawSnapshotWriter.java @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kafka.snapshot;
+
+import org.apache.kafka.common.record.MemoryRecords;
+import org.apache.kafka.common.record.UnalignedMemoryRecords;
+import org.apache.kafka.raft.OffsetAndEpoch;
+
+/**
+ * Interface for writing a snapshot as a sequence of records.
+ */
+public interface RawSnapshotWriter extends AutoCloseable {
+    /**
+     * Returns the end offset and epoch for the snapshot.
+     */
+    OffsetAndEpoch snapshotId();
+
+    /**
+     * Returns the number of bytes for the snapshot.
+     */
+    long sizeInBytes();
+
+    /**
+     * Fully appends the memory record set to the snapshot.
+     *
+     * If the method returns without an exception, the given record set was fully written to
+     * the snapshot.
+     *
+     * @param records the region to append
+     */
+    void append(MemoryRecords records);
+
+    /**
+     * Fully appends the memory record set to the snapshot. The difference from
+     * {@link RawSnapshotWriter#append(MemoryRecords)} is that these records were fetched from
+     * the leader by a FetchSnapshotRequest, so they may be unaligned.
+     *
+     * If the method returns without an exception, the given records were fully written to
+     * the snapshot.
+     *
+     * @param records the region to append
+     */
+    void append(UnalignedMemoryRecords records);
+
+    /**
+     * Returns true if the snapshot has been frozen, otherwise false.
+     *
+     * Modifications to the snapshot are not allowed once it is frozen.
+     */
+    boolean isFrozen();
+
+    /**
+     * Freezes the snapshot, marking it as immutable.
+     */
+    void freeze();
+
+    /**
+     * Closes the snapshot writer.
+     *
+     * If close is called without first calling freeze, the snapshot is aborted.
+     */
+    void close();
+}
diff --git a/raft/src/main/java/org/apache/kafka/snapshot/SnapshotPath.java b/raft/src/main/java/org/apache/kafka/snapshot/SnapshotPath.java
new file mode 100644
index 0000000..16237a9
--- /dev/null
+++ b/raft/src/main/java/org/apache/kafka/snapshot/SnapshotPath.java
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
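As a rough illustration of the raw writer contract above, a hedged sketch using the FileRawSnapshotWriter added earlier in this patch; the log directory, snapshot id, and record contents are placeholders:

import java.nio.file.Paths;
import java.util.Optional;
import org.apache.kafka.common.record.CompressionType;
import org.apache.kafka.common.record.MemoryRecords;
import org.apache.kafka.common.record.SimpleRecord;
import org.apache.kafka.raft.OffsetAndEpoch;
import org.apache.kafka.snapshot.FileRawSnapshotWriter;
import org.apache.kafka.snapshot.RawSnapshotWriter;

public class RawSnapshotWriterExample {
    public static void main(String[] args) {
        OffsetAndEpoch snapshotId = new OffsetAndEpoch(100L, 3);
        try (RawSnapshotWriter writer = FileRawSnapshotWriter.create(
                Paths.get("/tmp/partition-0"), snapshotId, Optional.empty())) {
            writer.append(MemoryRecords.withRecords(
                CompressionType.NONE, new SimpleRecord("state".getBytes())));
            // freeze() closes the channel, renames the temp file into place, and marks the
            // snapshot immutable; skipping it would abort the snapshot when close() runs.
            writer.freeze();
        }
    }
}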
+ */ +package org.apache.kafka.snapshot; + +import java.nio.file.Path; +import org.apache.kafka.raft.OffsetAndEpoch; + +public final class SnapshotPath { + public final Path path; + public final OffsetAndEpoch snapshotId; + public final boolean partial; + public final boolean deleted; + + public SnapshotPath(Path path, OffsetAndEpoch snapshotId, boolean partial, boolean deleted) { + this.path = path; + this.snapshotId = snapshotId; + this.partial = partial; + this.deleted = deleted; + } + + @Override + public String toString() { + return String.format("SnapshotPath(path=%s, snapshotId=%s, partial=%s)", path, snapshotId, partial); + } +} diff --git a/raft/src/main/java/org/apache/kafka/snapshot/SnapshotReader.java b/raft/src/main/java/org/apache/kafka/snapshot/SnapshotReader.java new file mode 100644 index 0000000..8c6e8e6 --- /dev/null +++ b/raft/src/main/java/org/apache/kafka/snapshot/SnapshotReader.java @@ -0,0 +1,161 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.snapshot; + +import java.util.Iterator; +import java.util.NoSuchElementException; +import java.util.Optional; +import java.util.OptionalLong; + +import org.apache.kafka.common.utils.BufferSupplier; +import org.apache.kafka.raft.Batch; +import org.apache.kafka.raft.OffsetAndEpoch; +import org.apache.kafka.server.common.serialization.RecordSerde; +import org.apache.kafka.raft.internals.RecordsIterator; + +/** + * A type for reading an immutable snapshot. + * + * A snapshot reader can be used to scan through all of the objects T in a snapshot. It + * is assumed that the content of the snapshot represents all of the objects T for the topic + * partition from offset 0 up to but not including the end offset in the snapshot id. + * + * The offsets ({@code baseOffset()} and {@code lastOffset()} stored in {@code Batch} + * objects returned by this iterator are independent of the offset of the records in the + * log used to generate this batch. + * + * Use {@code lastContainedLogOffset()} and {@code lastContainedLogEpoch()} to query which + * offsets and epoch from the log are included in this snapshot. Both of these values are + * inclusive. + */ +public final class SnapshotReader implements AutoCloseable, Iterator> { + private final OffsetAndEpoch snapshotId; + private final RecordsIterator iterator; + + private Optional> nextBatch = Optional.empty(); + private OptionalLong lastContainedLogTimestamp = OptionalLong.empty(); + + private SnapshotReader( + OffsetAndEpoch snapshotId, + RecordsIterator iterator + ) { + this.snapshotId = snapshotId; + this.iterator = iterator; + } + + /** + * Returns the end offset and epoch for the snapshot. 
+ */ + public OffsetAndEpoch snapshotId() { + return snapshotId; + } + + /** + * Returns the last log offset which is represented in the snapshot. + */ + public long lastContainedLogOffset() { + return snapshotId.offset - 1; + } + + /** + * Returns the epoch of the last log offset which is represented in the snapshot. + */ + public int lastContainedLogEpoch() { + return snapshotId.epoch; + } + + /** + * Returns the timestamp of the last log offset which is represented in the snapshot. + */ + public long lastContainedLogTimestamp() { + if (!lastContainedLogTimestamp.isPresent()) { + nextBatch.ifPresent(batch -> { + throw new IllegalStateException( + String.format( + "nextBatch was present when last contained log timestamp was not present", + batch + ) + ); + }); + nextBatch = nextBatch(); + } + + return lastContainedLogTimestamp.getAsLong(); + } + + @Override + public boolean hasNext() { + if (!nextBatch.isPresent()) { + nextBatch = nextBatch(); + } + + return nextBatch.isPresent(); + } + + @Override + public Batch next() { + if (!hasNext()) { + throw new NoSuchElementException("Snapshot reader doesn't have any more elements"); + } + + Batch batch = nextBatch.get(); + nextBatch = Optional.empty(); + + return batch; + } + + /** + * Closes the snapshot reader. + */ + public void close() { + iterator.close(); + } + + public static SnapshotReader of( + RawSnapshotReader snapshot, + RecordSerde serde, + BufferSupplier bufferSupplier, + int maxBatchSize + ) { + return new SnapshotReader<>( + snapshot.snapshotId(), + new RecordsIterator<>(snapshot.records(), serde, bufferSupplier, maxBatchSize) + ); + } + + /** + * Returns the next non-control Batch + */ + private Optional> nextBatch() { + while (iterator.hasNext()) { + Batch batch = iterator.next(); + + if (!lastContainedLogTimestamp.isPresent()) { + // The Batch type doesn't support returning control batches. For now lets just use + // the append time of the first batch + lastContainedLogTimestamp = OptionalLong.of(batch.appendTimestamp()); + } + + if (!batch.records().isEmpty()) { + return Optional.of(batch); + } + } + + return Optional.empty(); + } +} diff --git a/raft/src/main/java/org/apache/kafka/snapshot/SnapshotWriter.java b/raft/src/main/java/org/apache/kafka/snapshot/SnapshotWriter.java new file mode 100644 index 0000000..62fc2d7 --- /dev/null +++ b/raft/src/main/java/org/apache/kafka/snapshot/SnapshotWriter.java @@ -0,0 +1,234 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
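A hedged sketch of reading a snapshot with the classes above, assuming the StringSerde added earlier in the patch and an already existing raw snapshot file; the path, snapshot id, and batch size are illustrative:

import java.nio.file.Paths;
import org.apache.kafka.common.utils.BufferSupplier;
import org.apache.kafka.raft.Batch;
import org.apache.kafka.raft.OffsetAndEpoch;
import org.apache.kafka.raft.internals.StringSerde;
import org.apache.kafka.snapshot.FileRawSnapshotReader;
import org.apache.kafka.snapshot.SnapshotReader;

public class SnapshotReaderExample {
    public static void main(String[] args) {
        OffsetAndEpoch snapshotId = new OffsetAndEpoch(100L, 3);
        FileRawSnapshotReader raw = FileRawSnapshotReader.open(Paths.get("/tmp/partition-0"), snapshotId);

        try (SnapshotReader<String> reader = SnapshotReader.of(
                raw, new StringSerde(), BufferSupplier.create(), 8 * 1024)) {
            System.out.println("covers log offsets up to " + reader.lastContainedLogOffset());
            while (reader.hasNext()) {
                Batch<String> batch = reader.next();
                batch.records().forEach(System.out::println);
            }
        }
        raw.close();
    }
}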
+ */ + +package org.apache.kafka.snapshot; + +import org.apache.kafka.common.memory.MemoryPool; +import org.apache.kafka.common.record.CompressionType; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.raft.OffsetAndEpoch; +import org.apache.kafka.server.common.serialization.RecordSerde; +import org.apache.kafka.raft.internals.BatchAccumulator; +import org.apache.kafka.raft.internals.BatchAccumulator.CompletedBatch; +import org.apache.kafka.common.message.SnapshotHeaderRecord; +import org.apache.kafka.common.message.SnapshotFooterRecord; +import org.apache.kafka.common.record.ControlRecordUtils; + +import java.util.Optional; +import java.util.List; +import java.util.function.Supplier; + +/** + * A type for writing a snapshot for a given end offset and epoch. + * + * A snapshot writer can be used to append objects until freeze is called. When freeze is + * called the snapshot is validated and marked as immutable. After freeze is called any + * append will fail with an exception. + * + * It is assumed that the content of the snapshot represents all of the records for the + * topic partition from offset 0 up to but not including the end offset in the snapshot + * id. + * + * @see org.apache.kafka.raft.KafkaRaftClient#createSnapshot(long, int, long) + */ +final public class SnapshotWriter implements AutoCloseable { + final private RawSnapshotWriter snapshot; + final private BatchAccumulator accumulator; + final private Time time; + final private long lastContainedLogTimestamp; + + private SnapshotWriter( + RawSnapshotWriter snapshot, + int maxBatchSize, + MemoryPool memoryPool, + Time time, + long lastContainedLogTimestamp, + CompressionType compressionType, + RecordSerde serde + ) { + this.snapshot = snapshot; + this.time = time; + this.lastContainedLogTimestamp = lastContainedLogTimestamp; + + this.accumulator = new BatchAccumulator<>( + snapshot.snapshotId().epoch, + 0, + Integer.MAX_VALUE, + maxBatchSize, + memoryPool, + time, + compressionType, + serde + ); + } + + /** + * Adds a {@link SnapshotHeaderRecord} to snapshot + * + * @throws IllegalStateException if the snapshot is not empty + */ + private void initializeSnapshotWithHeader() { + if (snapshot.sizeInBytes() != 0) { + String message = String.format( + "Initializing writer with a non-empty snapshot: id = '%s'.", + snapshot.snapshotId() + ); + throw new IllegalStateException(message); + } + + SnapshotHeaderRecord headerRecord = new SnapshotHeaderRecord() + .setVersion(ControlRecordUtils.SNAPSHOT_HEADER_HIGHEST_VERSION) + .setLastContainedLogTimestamp(lastContainedLogTimestamp); + accumulator.appendSnapshotHeaderMessage(headerRecord, time.milliseconds()); + accumulator.forceDrain(); + } + + /** + * Adds a {@link SnapshotFooterRecord} to the snapshot + * + * No more records should be appended to the snapshot after calling this method + */ + private void finalizeSnapshotWithFooter() { + SnapshotFooterRecord footerRecord = new SnapshotFooterRecord() + .setVersion(ControlRecordUtils.SNAPSHOT_FOOTER_HIGHEST_VERSION); + accumulator.appendSnapshotFooterMessage(footerRecord, time.milliseconds()); + accumulator.forceDrain(); + } + + /** + * Create an instance of this class and initialize + * the underlying snapshot with {@link SnapshotHeaderRecord} + * + * @param snapshot a lambda to create the low level snapshot writer + * @param maxBatchSize the maximum size in byte for a batch + * @param memoryPool the memory pool for buffer allocation + * @param time the clock implementation + * @param lastContainedLogTimestamp The 
append time of the highest record contained in this snapshot + * @param compressionType the compression algorithm to use + * @param serde the record serialization and deserialization implementation + * @return {@link Optional}{@link SnapshotWriter} + */ + public static Optional> createWithHeader( + Supplier> supplier, + int maxBatchSize, + MemoryPool memoryPool, + Time snapshotTime, + long lastContainedLogTimestamp, + CompressionType compressionType, + RecordSerde serde + ) { + Optional> writer = supplier.get().map(snapshot -> { + return new SnapshotWriter( + snapshot, + maxBatchSize, + memoryPool, + snapshotTime, + lastContainedLogTimestamp, + CompressionType.NONE, + serde); + }); + writer.ifPresent(SnapshotWriter::initializeSnapshotWithHeader); + return writer; + } + + /** + * Returns the end offset and epoch for the snapshot. + */ + public OffsetAndEpoch snapshotId() { + return snapshot.snapshotId(); + } + + /** + * Returns the last log offset which is represented in the snapshot. + */ + public long lastContainedLogOffset() { + return snapshot.snapshotId().offset - 1; + } + + /** + * Returns the epoch of the last log offset which is represented in the snapshot. + */ + public int lastContainedLogEpoch() { + return snapshot.snapshotId().epoch; + } + + /** + * Returns true if the snapshot has been frozen, otherwise false is returned. + * + * Modification to the snapshot are not allowed once it is frozen. + */ + public boolean isFrozen() { + return snapshot.isFrozen(); + } + + /** + * Appends a list of values to the snapshot. + * + * The list of record passed are guaranteed to get written together. + * + * @param records the list of records to append to the snapshot + * @throws IllegalStateException if append is called when isFrozen is true + */ + public void append(List records) { + if (snapshot.isFrozen()) { + String message = String.format( + "Append not supported. Snapshot is already frozen: id = '%s'.", + snapshot.snapshotId() + ); + + throw new IllegalStateException(message); + } + + accumulator.append(snapshot.snapshotId().epoch, records); + + if (accumulator.needsDrain(time.milliseconds())) { + appendBatches(accumulator.drain()); + } + } + + /** + * Freezes the snapshot by flushing all pending writes and marking it as immutable. + * + * Also adds a {@link SnapshotFooterRecord} to the end of the snapshot + */ + public void freeze() { + finalizeSnapshotWithFooter(); + appendBatches(accumulator.drain()); + snapshot.freeze(); + accumulator.close(); + } + + /** + * Closes the snapshot writer. + * + * If close is called without first calling freeze the snapshot is aborted. + */ + public void close() { + snapshot.close(); + accumulator.close(); + } + + private void appendBatches(List> batches) { + try { + for (CompletedBatch batch : batches) { + snapshot.append(batch.data); + } + } finally { + batches.forEach(CompletedBatch::release); + } + } +} diff --git a/raft/src/main/java/org/apache/kafka/snapshot/Snapshots.java b/raft/src/main/java/org/apache/kafka/snapshot/Snapshots.java new file mode 100644 index 0000000..a4d3b5a --- /dev/null +++ b/raft/src/main/java/org/apache/kafka/snapshot/Snapshots.java @@ -0,0 +1,151 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
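SnapshotWriter follows a create, append, freeze, close lifecycle: records passed to a single append() call are written together, freeze() drains the accumulator, appends the SnapshotFooterRecord and marks the snapshot immutable, and close() without a prior freeze() aborts the snapshot. A rough sketch of that lifecycle from the caller's side; the wrapper class is illustrative, and the writer itself would come from something like the KafkaRaftClient#createSnapshot call referenced in the javadoc above:

import java.util.Arrays;

import org.apache.kafka.snapshot.SnapshotWriter;

// Illustrative only: the append/freeze/close lifecycle of a snapshot writer.
final class SnapshotWriteSketch {
    static void writeAndFreeze(SnapshotWriter<String> writer) {
        try {
            // Records passed in one append() call are guaranteed to be written together.
            writer.append(Arrays.asList("a", "b", "c"));
            writer.append(Arrays.asList("d", "e", "f"));

            // freeze() flushes pending batches, appends the SnapshotFooterRecord and
            // makes the snapshot immutable; any later append() throws IllegalStateException.
            writer.freeze();
        } finally {
            // close() releases the accumulator; without a prior freeze() the snapshot is aborted.
            writer.close();
        }
    }
}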
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.snapshot; + +import org.apache.kafka.raft.OffsetAndEpoch; +import org.apache.kafka.common.utils.Utils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.text.NumberFormat; +import java.util.Optional; + +public final class Snapshots { + private static final Logger log = LoggerFactory.getLogger(Snapshots.class); + private static final String SUFFIX = ".checkpoint"; + private static final String PARTIAL_SUFFIX = String.format("%s.part", SUFFIX); + private static final String DELETE_SUFFIX = String.format("%s.deleted", SUFFIX); + + private static final NumberFormat OFFSET_FORMATTER = NumberFormat.getInstance(); + private static final NumberFormat EPOCH_FORMATTER = NumberFormat.getInstance(); + + private static final int OFFSET_WIDTH = 20; + private static final int EPOCH_WIDTH = 10; + + static { + OFFSET_FORMATTER.setMinimumIntegerDigits(OFFSET_WIDTH); + OFFSET_FORMATTER.setGroupingUsed(false); + + EPOCH_FORMATTER.setMinimumIntegerDigits(EPOCH_WIDTH); + EPOCH_FORMATTER.setGroupingUsed(false); + } + + static Path snapshotDir(Path logDir) { + return logDir; + } + + static String filenameFromSnapshotId(OffsetAndEpoch snapshotId) { + return String.format("%s-%s", OFFSET_FORMATTER.format(snapshotId.offset), EPOCH_FORMATTER.format(snapshotId.epoch)); + } + + static Path moveRename(Path source, OffsetAndEpoch snapshotId) { + return source.resolveSibling(filenameFromSnapshotId(snapshotId) + SUFFIX); + } + + static Path deleteRename(Path source, OffsetAndEpoch snapshotId) { + return source.resolveSibling(filenameFromSnapshotId(snapshotId) + DELETE_SUFFIX); + } + + public static Path snapshotPath(Path logDir, OffsetAndEpoch snapshotId) { + return snapshotDir(logDir).resolve(filenameFromSnapshotId(snapshotId) + SUFFIX); + } + + public static Path createTempFile(Path logDir, OffsetAndEpoch snapshotId) { + Path dir = snapshotDir(logDir); + + try { + // Create the snapshot directory if it doesn't exists + Files.createDirectories(dir); + String prefix = String.format("%s-", filenameFromSnapshotId(snapshotId)); + return Files.createTempFile(dir, prefix, PARTIAL_SUFFIX); + } catch (IOException e) { + throw new UncheckedIOException( + String.format("Error creating temporary file, logDir = %s, snapshotId = %s.", + dir.toAbsolutePath(), snapshotId), e); + } + } + + public static Optional parse(Path path) { + Path filename = path.getFileName(); + if (filename == null) { + return Optional.empty(); + } + + String name = filename.toString(); + + boolean partial = false; + boolean deleted = false; + if (name.endsWith(PARTIAL_SUFFIX)) { + partial = true; + } else if (name.endsWith(DELETE_SUFFIX)) { + deleted = true; + } else if (!name.endsWith(SUFFIX)) { + return Optional.empty(); + } + + long endOffset = Long.parseLong(name.substring(0, OFFSET_WIDTH)); + int epoch = 
Integer.parseInt( + name.substring(OFFSET_WIDTH + 1, OFFSET_WIDTH + EPOCH_WIDTH + 1) + ); + + return Optional.of(new SnapshotPath(path, new OffsetAndEpoch(endOffset, epoch), partial, deleted)); + } + + /** + * Delete the snapshot from the filesystem. + */ + public static boolean deleteIfExists(Path logDir, OffsetAndEpoch snapshotId) { + Path immutablePath = snapshotPath(logDir, snapshotId); + Path deletedPath = deleteRename(immutablePath, snapshotId); + try { + boolean deleted = Files.deleteIfExists(immutablePath) | Files.deleteIfExists(deletedPath); + if (deleted) { + log.info("Deleted snapshot files for snapshot {}.", snapshotId); + } else { + log.info("Did not delete snapshot files for snapshot {} since they did not exist.", snapshotId); + } + return deleted; + } catch (IOException e) { + log.error("Error deleting snapshot files {} and {}", immutablePath, deletedPath, e); + return false; + } + } + + /** + * Mark a snapshot for deletion by renaming with the deleted suffix + */ + public static void markForDelete(Path logDir, OffsetAndEpoch snapshotId) { + Path immutablePath = snapshotPath(logDir, snapshotId); + Path deletedPath = deleteRename(immutablePath, snapshotId); + try { + Utils.atomicMoveWithFallback(immutablePath, deletedPath, false); + } catch (IOException e) { + throw new UncheckedIOException( + String.format( + "Error renaming snapshot file from %s to %s.", + immutablePath, + deletedPath + ), + e + ); + } + } +} diff --git a/raft/src/main/resources/common/message/QuorumStateData.json b/raft/src/main/resources/common/message/QuorumStateData.json new file mode 100644 index 0000000..d71a32c --- /dev/null +++ b/raft/src/main/resources/common/message/QuorumStateData.json @@ -0,0 +1,34 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
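The Snapshots helper above fixes the on-disk naming scheme: a zero-padded 20-digit end offset, a dash, a zero-padded 10-digit epoch, and a ".checkpoint" suffix, with ".checkpoint.part" for in-progress files and ".checkpoint.deleted" for files marked for removal. A small sketch of the expected round trip through snapshotPath() and parse(); the directory path is a placeholder:

import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Optional;

import org.apache.kafka.raft.OffsetAndEpoch;
import org.apache.kafka.snapshot.SnapshotPath;
import org.apache.kafka.snapshot.Snapshots;

// Illustrative only: file name layout and parsing for snapshot files.
final class SnapshotNamingSketch {
    public static void main(String[] args) {
        Path logDir = Paths.get("/tmp/metadata-log");          // placeholder log directory
        OffsetAndEpoch snapshotId = new OffsetAndEpoch(3, 1);

        // Expected to end with 00000000000000000003-0000000001.checkpoint
        Path path = Snapshots.snapshotPath(logDir, snapshotId);
        System.out.println(path);

        // parse() recovers the snapshot id plus the partial/deleted flags from the file name.
        Optional<SnapshotPath> parsed = Snapshots.parse(path);
        parsed.ifPresent(snapshotPath ->
            System.out.println(snapshotPath.snapshotId + " partial=" + snapshotPath.partial + " deleted=" + snapshotPath.deleted)
        );
    }
}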
+ +{ + "type": "data", + "name": "QuorumStateData", + "validVersions": "0", + "flexibleVersions": "0+", + "fields": [ + {"name": "ClusterId", "type": "string", "versions": "0+"}, + {"name": "LeaderId", "type": "int32", "versions": "0+", "default": "-1"}, + {"name": "LeaderEpoch", "type": "int32", "versions": "0+", "default": "-1"}, + {"name": "VotedId", "type": "int32", "versions": "0+", "default": "-1"}, + {"name": "AppliedOffset", "type": "int64", "versions": "0+"}, + {"name": "CurrentVoters", "type": "[]Voter", "versions": "0+", "nullableVersions": "0+"} + ], + "commonStructs": [ + { "name": "Voter", "versions": "0+", "fields": [ + {"name": "VoterId", "type": "int32", "versions": "0+"} + ]} + ] +} diff --git a/raft/src/test/java/org/apache/kafka/raft/CandidateStateTest.java b/raft/src/test/java/org/apache/kafka/raft/CandidateStateTest.java new file mode 100644 index 0000000..71a2375 --- /dev/null +++ b/raft/src/test/java/org/apache/kafka/raft/CandidateStateTest.java @@ -0,0 +1,198 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.raft; + +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.common.utils.Utils; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; + +import java.util.Collections; +import java.util.Optional; +import java.util.Set; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class CandidateStateTest { + private final int localId = 0; + private final int epoch = 5; + private final MockTime time = new MockTime(); + private final int electionTimeoutMs = 5000; + private final LogContext logContext = new LogContext(); + + private CandidateState newCandidateState( + Set voters, + Optional highWatermark + ) { + return new CandidateState( + time, + localId, + epoch, + voters, + highWatermark, + 0, + electionTimeoutMs, + logContext + ); + } + + @Test + public void testSingleNodeQuorum() { + CandidateState state = newCandidateState(Collections.singleton(localId), Optional.empty()); + assertTrue(state.isVoteGranted()); + assertFalse(state.isVoteRejected()); + assertEquals(Collections.emptySet(), state.unrecordedVoters()); + } + + @Test + public void testTwoNodeQuorumVoteRejected() { + int otherNodeId = 1; + CandidateState state = newCandidateState(Utils.mkSet(localId, otherNodeId), Optional.empty()); + assertFalse(state.isVoteGranted()); + assertFalse(state.isVoteRejected()); + assertEquals(Collections.singleton(otherNodeId), state.unrecordedVoters()); + assertTrue(state.recordRejectedVote(otherNodeId)); + assertFalse(state.isVoteGranted()); + assertTrue(state.isVoteRejected()); + } + + @Test + public void testTwoNodeQuorumVoteGranted() { + int otherNodeId = 1; + CandidateState state = newCandidateState( + Utils.mkSet(localId, otherNodeId), Optional.empty()); + assertFalse(state.isVoteGranted()); + assertFalse(state.isVoteRejected()); + assertEquals(Collections.singleton(otherNodeId), state.unrecordedVoters()); + assertTrue(state.recordGrantedVote(otherNodeId)); + assertEquals(Collections.emptySet(), state.unrecordedVoters()); + assertFalse(state.isVoteRejected()); + assertTrue(state.isVoteGranted()); + } + + @Test + public void testThreeNodeQuorumVoteGranted() { + int node1 = 1; + int node2 = 2; + CandidateState state = newCandidateState( + Utils.mkSet(localId, node1, node2), Optional.empty()); + assertFalse(state.isVoteGranted()); + assertFalse(state.isVoteRejected()); + assertEquals(Utils.mkSet(node1, node2), state.unrecordedVoters()); + assertTrue(state.recordGrantedVote(node1)); + assertEquals(Collections.singleton(node2), state.unrecordedVoters()); + assertTrue(state.isVoteGranted()); + assertFalse(state.isVoteRejected()); + assertTrue(state.recordRejectedVote(node2)); + assertEquals(Collections.emptySet(), state.unrecordedVoters()); + assertTrue(state.isVoteGranted()); + assertFalse(state.isVoteRejected()); + } + + @Test + public void testThreeNodeQuorumVoteRejected() { + int node1 = 1; + int node2 = 2; + CandidateState state = newCandidateState( + Utils.mkSet(localId, node1, node2), Optional.empty()); + assertFalse(state.isVoteGranted()); + assertFalse(state.isVoteRejected()); + assertEquals(Utils.mkSet(node1, node2), state.unrecordedVoters()); + assertTrue(state.recordRejectedVote(node1)); + 
assertEquals(Collections.singleton(node2), state.unrecordedVoters()); + assertFalse(state.isVoteGranted()); + assertFalse(state.isVoteRejected()); + assertTrue(state.recordRejectedVote(node2)); + assertEquals(Collections.emptySet(), state.unrecordedVoters()); + assertFalse(state.isVoteGranted()); + assertTrue(state.isVoteRejected()); + } + + @Test + public void testCannotRejectVoteFromLocalId() { + int otherNodeId = 1; + CandidateState state = newCandidateState( + Utils.mkSet(localId, otherNodeId), Optional.empty()); + assertThrows(IllegalArgumentException.class, () -> state.recordRejectedVote(localId)); + } + + @Test + public void testCannotChangeVoteGrantedToRejected() { + int otherNodeId = 1; + CandidateState state = newCandidateState( + Utils.mkSet(localId, otherNodeId), Optional.empty()); + assertTrue(state.recordGrantedVote(otherNodeId)); + assertThrows(IllegalArgumentException.class, () -> state.recordRejectedVote(otherNodeId)); + assertTrue(state.isVoteGranted()); + } + + @Test + public void testCannotChangeVoteRejectedToGranted() { + int otherNodeId = 1; + CandidateState state = newCandidateState( + Utils.mkSet(localId, otherNodeId), Optional.empty()); + assertTrue(state.recordRejectedVote(otherNodeId)); + assertThrows(IllegalArgumentException.class, () -> state.recordGrantedVote(otherNodeId)); + assertTrue(state.isVoteRejected()); + } + + @Test + public void testCannotGrantOrRejectNonVoters() { + int nonVoterId = 1; + CandidateState state = newCandidateState( + Collections.singleton(localId), Optional.empty()); + assertThrows(IllegalArgumentException.class, () -> state.recordGrantedVote(nonVoterId)); + assertThrows(IllegalArgumentException.class, () -> state.recordRejectedVote(nonVoterId)); + } + + @Test + public void testIdempotentGrant() { + int otherNodeId = 1; + CandidateState state = newCandidateState( + Utils.mkSet(localId, otherNodeId), Optional.empty()); + assertTrue(state.recordGrantedVote(otherNodeId)); + assertFalse(state.recordGrantedVote(otherNodeId)); + } + + @Test + public void testIdempotentReject() { + int otherNodeId = 1; + CandidateState state = newCandidateState( + Utils.mkSet(localId, otherNodeId), Optional.empty()); + assertTrue(state.recordRejectedVote(otherNodeId)); + assertFalse(state.recordRejectedVote(otherNodeId)); + } + + @ParameterizedTest + @ValueSource(booleans = {true, false}) + public void testGrantVote(boolean isLogUpToDate) { + CandidateState state = newCandidateState( + Utils.mkSet(1, 2, 3), + Optional.empty() + ); + + assertFalse(state.canGrantVote(1, isLogUpToDate)); + assertFalse(state.canGrantVote(2, isLogUpToDate)); + assertFalse(state.canGrantVote(3, isLogUpToDate)); + } + +} diff --git a/raft/src/test/java/org/apache/kafka/raft/FileBasedStateStoreTest.java b/raft/src/test/java/org/apache/kafka/raft/FileBasedStateStoreTest.java new file mode 100644 index 0000000..5fa4f5c --- /dev/null +++ b/raft/src/test/java/org/apache/kafka/raft/FileBasedStateStoreTest.java @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
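The CandidateState tests above encode the quorum arithmetic: the candidate's own vote counts as granted, the election is won once a majority of the voter set has granted, and it is lost once enough rejections arrive that a grant majority is no longer reachable. A simplified illustration of that rule, not the CandidateState implementation itself, just the arithmetic the tests exercise:

import java.util.Set;

// Illustrative only: the majority arithmetic behind isVoteGranted()/isVoteRejected().
final class VoteMajoritySketch {
    // The granted set should already include the candidate's own vote.
    static boolean voteGranted(Set<Integer> voters, Set<Integer> granted) {
        return granted.size() >= majority(voters.size());
    }

    // Rejected once a grant majority is unreachable: even if every voter that has not
    // answered yet granted its vote, the candidate would still fall short.
    static boolean voteRejected(Set<Integer> voters, Set<Integer> granted, Set<Integer> rejected) {
        int unrecorded = voters.size() - granted.size() - rejected.size();
        return granted.size() + unrecorded < majority(voters.size());
    }

    private static int majority(int voterCount) {
        return voterCount / 2 + 1;
    }
}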
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.raft; + +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.test.TestUtils; +import org.junit.jupiter.api.AfterEach; + +import java.io.File; +import java.io.IOException; +import java.util.OptionalInt; +import java.util.Set; +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class FileBasedStateStoreTest { + + private FileBasedStateStore stateStore; + + @Test + public void testReadElectionState() throws IOException { + final File stateFile = TestUtils.tempFile(); + + stateStore = new FileBasedStateStore(stateFile); + + final int leaderId = 1; + final int epoch = 2; + Set voters = Utils.mkSet(leaderId); + + stateStore.writeElectionState(ElectionState.withElectedLeader(epoch, leaderId, voters)); + assertTrue(stateFile.exists()); + assertEquals(ElectionState.withElectedLeader(epoch, leaderId, voters), stateStore.readElectionState()); + + // Start another state store and try to read from the same file. + final FileBasedStateStore secondStateStore = new FileBasedStateStore(stateFile); + assertEquals(ElectionState.withElectedLeader(epoch, leaderId, voters), secondStateStore.readElectionState()); + } + + @Test + public void testWriteElectionState() throws IOException { + final File stateFile = TestUtils.tempFile(); + + stateStore = new FileBasedStateStore(stateFile); + + // We initialized a state from the metadata log + assertTrue(stateFile.exists()); + + // The temp file should be removed + final File createdTempFile = new File(stateFile.getAbsolutePath() + ".tmp"); + assertFalse(createdTempFile.exists()); + + final int epoch = 2; + final int leaderId = 1; + final int votedId = 5; + Set voters = Utils.mkSet(leaderId, votedId); + + stateStore.writeElectionState(ElectionState.withElectedLeader(epoch, leaderId, voters)); + + assertEquals(stateStore.readElectionState(), new ElectionState(epoch, + OptionalInt.of(leaderId), OptionalInt.empty(), voters)); + + stateStore.writeElectionState(ElectionState.withVotedCandidate(epoch, votedId, voters)); + + assertEquals(stateStore.readElectionState(), new ElectionState(epoch, + OptionalInt.empty(), OptionalInt.of(votedId), voters)); + + final FileBasedStateStore rebootStateStore = new FileBasedStateStore(stateFile); + + assertEquals(rebootStateStore.readElectionState(), new ElectionState(epoch, + OptionalInt.empty(), OptionalInt.of(votedId), voters)); + + stateStore.clear(); + assertFalse(stateFile.exists()); + } + + @AfterEach + public void cleanup() throws IOException { + if (stateStore != null) { + stateStore.clear(); + } + } +} diff --git a/raft/src/test/java/org/apache/kafka/raft/FollowerStateTest.java b/raft/src/test/java/org/apache/kafka/raft/FollowerStateTest.java new file mode 100644 index 0000000..42c6bc9 --- /dev/null +++ b/raft/src/test/java/org/apache/kafka/raft/FollowerStateTest.java @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
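FileBasedStateStoreTest above also checks that the transient ".tmp" file is gone after a write, which is consistent with a write-to-temp-then-atomically-rename pattern so that readers never observe a partially written quorum-state file. A small sketch of the store usage the test exercises; the file path is a placeholder:

import java.io.File;

import org.apache.kafka.common.utils.Utils;
import org.apache.kafka.raft.ElectionState;
import org.apache.kafka.raft.FileBasedStateStore;

// Illustrative only: persisting and recovering quorum election state.
final class QuorumStateStoreSketch {
    public static void main(String[] args) throws Exception {
        File stateFile = new File("/tmp/quorum-state");   // placeholder path
        FileBasedStateStore store = new FileBasedStateStore(stateFile);

        int epoch = 2;
        int leaderId = 1;
        store.writeElectionState(ElectionState.withElectedLeader(epoch, leaderId, Utils.mkSet(leaderId)));

        // A second store pointed at the same file, e.g. after a restart, reads back the same state.
        ElectionState restored = new FileBasedStateStore(stateFile).readElectionState();
        System.out.println(restored);

        store.clear();   // removes the state file
    }
}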
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.raft; + +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.common.utils.Utils; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; + +import java.util.Optional; +import java.util.OptionalLong; +import java.util.Set; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class FollowerStateTest { + private final MockTime time = new MockTime(); + private final LogContext logContext = new LogContext(); + private final int epoch = 5; + private final int fetchTimeoutMs = 15000; + int leaderId = 3; + + private FollowerState newFollowerState( + Set voters, + Optional highWatermark + ) { + return new FollowerState( + time, + epoch, + leaderId, + voters, + highWatermark, + fetchTimeoutMs, + logContext + ); + } + + @Test + public void testFetchTimeoutExpiration() { + FollowerState state = newFollowerState(Utils.mkSet(1, 2, 3), Optional.empty()); + + assertFalse(state.hasFetchTimeoutExpired(time.milliseconds())); + assertEquals(fetchTimeoutMs, state.remainingFetchTimeMs(time.milliseconds())); + + time.sleep(5000); + assertFalse(state.hasFetchTimeoutExpired(time.milliseconds())); + assertEquals(fetchTimeoutMs - 5000, state.remainingFetchTimeMs(time.milliseconds())); + + time.sleep(10000); + assertTrue(state.hasFetchTimeoutExpired(time.milliseconds())); + assertEquals(0, state.remainingFetchTimeMs(time.milliseconds())); + } + + @Test + public void testMonotonicHighWatermark() { + FollowerState state = newFollowerState(Utils.mkSet(1, 2, 3), Optional.empty()); + + OptionalLong highWatermark = OptionalLong.of(15L); + state.updateHighWatermark(highWatermark); + assertThrows(IllegalArgumentException.class, () -> state.updateHighWatermark(OptionalLong.empty())); + assertThrows(IllegalArgumentException.class, () -> state.updateHighWatermark(OptionalLong.of(14L))); + state.updateHighWatermark(highWatermark); + assertEquals(Optional.of(new LogOffsetMetadata(15L)), state.highWatermark()); + } + + @ParameterizedTest + @ValueSource(booleans = {true, false}) + public void testGrantVote(boolean isLogUpToDate) { + FollowerState state = newFollowerState( + Utils.mkSet(1, 2, 3), + Optional.empty() + ); + + assertFalse(state.canGrantVote(1, isLogUpToDate)); + assertFalse(state.canGrantVote(2, isLogUpToDate)); + assertFalse(state.canGrantVote(3, isLogUpToDate)); + } + +} diff --git a/raft/src/test/java/org/apache/kafka/raft/KafkaRaftClientSnapshotTest.java b/raft/src/test/java/org/apache/kafka/raft/KafkaRaftClientSnapshotTest.java new file mode 
100644 index 0000000..42db40c --- /dev/null +++ b/raft/src/test/java/org/apache/kafka/raft/KafkaRaftClientSnapshotTest.java @@ -0,0 +1,1809 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.raft; + +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.Uuid; +import org.apache.kafka.common.memory.MemoryPool; +import org.apache.kafka.common.message.FetchResponseData; +import org.apache.kafka.common.message.FetchSnapshotRequestData; +import org.apache.kafka.common.message.FetchSnapshotResponseData; +import org.apache.kafka.common.protocol.Errors; +import org.apache.kafka.common.record.CompressionType; +import org.apache.kafka.common.record.MemoryRecords; +import org.apache.kafka.common.record.UnalignedMemoryRecords; +import org.apache.kafka.common.requests.FetchSnapshotRequest; +import org.apache.kafka.common.requests.FetchSnapshotResponse; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.raft.internals.StringSerde; +import org.apache.kafka.snapshot.RawSnapshotReader; +import org.apache.kafka.snapshot.RawSnapshotWriter; +import org.apache.kafka.snapshot.SnapshotReader; +import org.apache.kafka.snapshot.SnapshotWriter; +import org.apache.kafka.snapshot.SnapshotWriterReaderTest; +import org.junit.jupiter.api.Test; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.Arrays; +import java.util.List; +import java.util.Optional; +import java.util.OptionalLong; +import java.util.OptionalInt; +import java.util.Set; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.assertThrows; + +final public class KafkaRaftClientSnapshotTest { + @Test + public void testLeaderListenerNotified() throws Exception { + int localId = 0; + int otherNodeId = localId + 1; + Set voters = Utils.mkSet(localId, otherNodeId); + OffsetAndEpoch snapshotId = new OffsetAndEpoch(3, 1); + + RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters) + .appendToLog(snapshotId.epoch, Arrays.asList("a", "b", "c")) + .appendToLog(snapshotId.epoch, Arrays.asList("d", "e", "f")) + .withEmptySnapshot(snapshotId) + .deleteBeforeSnapshot(snapshotId) + .build(); + + context.becomeLeader(); + int epoch = context.currentEpoch(); + + // Advance the highWatermark + long localLogEndOffset = context.log.endOffset().offset; + context.deliverRequest(context.fetchRequest(epoch, otherNodeId, localLogEndOffset, epoch, 0)); + context.pollUntilResponse(); + context.assertSentFetchPartitionResponse(Errors.NONE, epoch, OptionalInt.of(localId)); + assertEquals(localLogEndOffset, context.client.highWatermark().getAsLong()); + + // Check that listener was notified of the 
new snapshot + try (SnapshotReader snapshot = context.listener.drainHandledSnapshot().get()) { + assertEquals(snapshotId, snapshot.snapshotId()); + SnapshotWriterReaderTest.assertSnapshot(Arrays.asList(), snapshot); + } + } + + @Test + public void testFollowerListenerNotified() throws Exception { + int localId = 0; + int leaderId = localId + 1; + Set voters = Utils.mkSet(localId, leaderId); + int epoch = 2; + OffsetAndEpoch snapshotId = new OffsetAndEpoch(3, 1); + + RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters) + .appendToLog(snapshotId.epoch, Arrays.asList("a", "b", "c")) + .appendToLog(snapshotId.epoch, Arrays.asList("d", "e", "f")) + .withEmptySnapshot(snapshotId) + .deleteBeforeSnapshot(snapshotId) + .withElectedLeader(epoch, leaderId) + .build(); + + // Advance the highWatermark + long localLogEndOffset = context.log.endOffset().offset; + context.pollUntilRequest(); + RaftRequest.Outbound fetchRequest = context.assertSentFetchRequest(); + context.assertFetchRequestData(fetchRequest, epoch, localLogEndOffset, snapshotId.epoch); + context.deliverResponse( + fetchRequest.correlationId, + fetchRequest.destinationId(), + context.fetchResponse(epoch, leaderId, MemoryRecords.EMPTY, localLogEndOffset, Errors.NONE) + ); + + context.pollUntilRequest(); + context.assertSentFetchRequest(epoch, localLogEndOffset, snapshotId.epoch); + + // Check that listener was notified of the new snapshot + try (SnapshotReader snapshot = context.listener.drainHandledSnapshot().get()) { + assertEquals(snapshotId, snapshot.snapshotId()); + SnapshotWriterReaderTest.assertSnapshot(Arrays.asList(), snapshot); + } + } + + @Test + public void testSecondListenerNotified() throws Exception { + int localId = 0; + int leaderId = localId + 1; + Set voters = Utils.mkSet(localId, leaderId); + int epoch = 2; + OffsetAndEpoch snapshotId = new OffsetAndEpoch(3, 1); + + RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters) + .appendToLog(snapshotId.epoch, Arrays.asList("a", "b", "c")) + .appendToLog(snapshotId.epoch, Arrays.asList("d", "e", "f")) + .withEmptySnapshot(snapshotId) + .deleteBeforeSnapshot(snapshotId) + .withElectedLeader(epoch, leaderId) + .build(); + + // Advance the highWatermark + long localLogEndOffset = context.log.endOffset().offset; + context.pollUntilRequest(); + RaftRequest.Outbound fetchRequest = context.assertSentFetchRequest(); + context.assertFetchRequestData(fetchRequest, epoch, localLogEndOffset, snapshotId.epoch); + context.deliverResponse( + fetchRequest.correlationId, + fetchRequest.destinationId(), + context.fetchResponse(epoch, leaderId, MemoryRecords.EMPTY, localLogEndOffset, Errors.NONE) + ); + + context.pollUntilRequest(); + context.assertSentFetchRequest(epoch, localLogEndOffset, snapshotId.epoch); + + RaftClientTestContext.MockListener secondListener = new RaftClientTestContext.MockListener(OptionalInt.of(localId)); + context.client.register(secondListener); + context.client.poll(); + + // Check that the second listener was notified of the new snapshot + try (SnapshotReader snapshot = secondListener.drainHandledSnapshot().get()) { + assertEquals(snapshotId, snapshot.snapshotId()); + SnapshotWriterReaderTest.assertSnapshot(Arrays.asList(), snapshot); + } + } + + @Test + public void testListenerRenotified() throws Exception { + int localId = 0; + int otherNodeId = localId + 1; + Set voters = Utils.mkSet(localId, otherNodeId); + OffsetAndEpoch snapshotId = new OffsetAndEpoch(3, 1); + + RaftClientTestContext context = new 
RaftClientTestContext.Builder(localId, voters) + .appendToLog(snapshotId.epoch, Arrays.asList("a", "b", "c")) + .appendToLog(snapshotId.epoch, Arrays.asList("d", "e", "f")) + .appendToLog(snapshotId.epoch, Arrays.asList("g", "h", "i")) + .withEmptySnapshot(snapshotId) + .deleteBeforeSnapshot(snapshotId) + .build(); + + context.becomeLeader(); + int epoch = context.currentEpoch(); + + // Stop the listener from reading commit batches + context.listener.updateReadCommit(false); + + // Advance the highWatermark + long localLogEndOffset = context.log.endOffset().offset; + context.deliverRequest(context.fetchRequest(epoch, otherNodeId, localLogEndOffset, epoch, 0)); + context.pollUntilResponse(); + context.assertSentFetchPartitionResponse(Errors.NONE, epoch, OptionalInt.of(localId)); + assertEquals(localLogEndOffset, context.client.highWatermark().getAsLong()); + + // Check that listener was notified of the new snapshot + try (SnapshotReader snapshot = context.listener.drainHandledSnapshot().get()) { + assertEquals(snapshotId, snapshot.snapshotId()); + SnapshotWriterReaderTest.assertSnapshot(Arrays.asList(), snapshot); + } + + // Generate a new snapshot + OffsetAndEpoch secondSnapshotId = new OffsetAndEpoch(localLogEndOffset, epoch); + try (SnapshotWriter snapshot = context.client.createSnapshot(secondSnapshotId.offset - 1, secondSnapshotId.epoch, 0).get()) { + assertEquals(secondSnapshotId, snapshot.snapshotId()); + snapshot.freeze(); + } + context.log.deleteBeforeSnapshot(secondSnapshotId); + context.client.poll(); + + // Resume the listener from reading commit batches + context.listener.updateReadCommit(true); + + context.client.poll(); + // Check that listener was notified of the second snapshot + try (SnapshotReader snapshot = context.listener.drainHandledSnapshot().get()) { + assertEquals(secondSnapshotId, snapshot.snapshotId()); + SnapshotWriterReaderTest.assertSnapshot(Arrays.asList(), snapshot); + } + } + + @Test + public void testFetchRequestOffsetLessThanLogStart() throws Exception { + int localId = 0; + int otherNodeId = localId + 1; + Set voters = Utils.mkSet(localId, otherNodeId); + + RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters) + .withAppendLingerMs(1) + .build(); + + context.becomeLeader(); + int epoch = context.currentEpoch(); + + List appendRecords = Arrays.asList("a", "b", "c"); + context.client.scheduleAppend(epoch, appendRecords); + context.time.sleep(context.appendLingerMs()); + context.client.poll(); + + long localLogEndOffset = context.log.endOffset().offset; + assertTrue( + appendRecords.size() <= localLogEndOffset, + String.format("Record length = %s, log end offset = %s", appendRecords.size(), localLogEndOffset) + ); + + // Advance the highWatermark + context.advanceLocalLeaderHighWatermarkToLogEndOffset(); + + OffsetAndEpoch snapshotId = new OffsetAndEpoch(localLogEndOffset, epoch); + try (SnapshotWriter snapshot = context.client.createSnapshot(snapshotId.offset - 1, snapshotId.epoch, 0).get()) { + assertEquals(snapshotId, snapshot.snapshotId()); + snapshot.freeze(); + } + context.log.deleteBeforeSnapshot(snapshotId); + context.client.poll(); + + // Send Fetch request less than start offset + context.deliverRequest(context.fetchRequest(epoch, otherNodeId, 0, epoch, 0)); + context.pollUntilResponse(); + FetchResponseData.PartitionData partitionResponse = context.assertSentFetchPartitionResponse(); + assertEquals(Errors.NONE, Errors.forCode(partitionResponse.errorCode())); + assertEquals(epoch, 
partitionResponse.currentLeader().leaderEpoch()); + assertEquals(localId, partitionResponse.currentLeader().leaderId()); + assertEquals(snapshotId.epoch, partitionResponse.snapshotId().epoch()); + assertEquals(snapshotId.offset, partitionResponse.snapshotId().endOffset()); + } + + @Test + public void testFetchRequestWithLargerLastFetchedEpoch() throws Exception { + int localId = 0; + int otherNodeId = localId + 1; + Set voters = Utils.mkSet(localId, otherNodeId); + + OffsetAndEpoch oldestSnapshotId = new OffsetAndEpoch(3, 2); + + RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters) + .appendToLog(oldestSnapshotId.epoch, Arrays.asList("a", "b", "c")) + .appendToLog(oldestSnapshotId.epoch, Arrays.asList("d", "e", "f")) + .withAppendLingerMs(1) + .build(); + + context.becomeLeader(); + int epoch = context.currentEpoch(); + assertEquals(oldestSnapshotId.epoch + 1, epoch); + + // Advance the highWatermark + context.advanceLocalLeaderHighWatermarkToLogEndOffset(); + + // Create a snapshot at the high watermark + try (SnapshotWriter snapshot = context.client.createSnapshot(oldestSnapshotId.offset - 1, oldestSnapshotId.epoch, 0).get()) { + assertEquals(oldestSnapshotId, snapshot.snapshotId()); + snapshot.freeze(); + } + context.client.poll(); + + context.client.scheduleAppend(epoch, Arrays.asList("g", "h", "i")); + context.time.sleep(context.appendLingerMs()); + context.client.poll(); + + // It is an invalid request to send an last fetched epoch greater than the current epoch + context.deliverRequest(context.fetchRequest(epoch, otherNodeId, oldestSnapshotId.offset + 1, epoch + 1, 0)); + context.pollUntilResponse(); + context.assertSentFetchPartitionResponse(Errors.INVALID_REQUEST, epoch, OptionalInt.of(localId)); + } + + @Test + public void testFetchRequestTruncateToLogStart() throws Exception { + int localId = 0; + int otherNodeId = localId + 1; + int syncNodeId = otherNodeId + 1; + Set voters = Utils.mkSet(localId, otherNodeId, syncNodeId); + + OffsetAndEpoch oldestSnapshotId = new OffsetAndEpoch(3, 2); + + RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters) + .appendToLog(oldestSnapshotId.epoch, Arrays.asList("a", "b", "c")) + .appendToLog(oldestSnapshotId.epoch + 2, Arrays.asList("d", "e", "f")) + .withAppendLingerMs(1) + .build(); + + context.becomeLeader(); + int epoch = context.currentEpoch(); + assertEquals(oldestSnapshotId.epoch + 2 + 1, epoch); + + // Advance the highWatermark + context.advanceLocalLeaderHighWatermarkToLogEndOffset(); + + // Create a snapshot at the high watermark + try (SnapshotWriter snapshot = context.client.createSnapshot(oldestSnapshotId.offset - 1, oldestSnapshotId.epoch, 0).get()) { + assertEquals(oldestSnapshotId, snapshot.snapshotId()); + snapshot.freeze(); + } + context.client.poll(); + + // This should truncate to the old snapshot + context.deliverRequest( + context.fetchRequest(epoch, otherNodeId, oldestSnapshotId.offset + 1, oldestSnapshotId.epoch + 1, 0) + ); + context.pollUntilResponse(); + FetchResponseData.PartitionData partitionResponse = context.assertSentFetchPartitionResponse(); + assertEquals(Errors.NONE, Errors.forCode(partitionResponse.errorCode())); + assertEquals(epoch, partitionResponse.currentLeader().leaderEpoch()); + assertEquals(localId, partitionResponse.currentLeader().leaderId()); + assertEquals(oldestSnapshotId.epoch, partitionResponse.divergingEpoch().epoch()); + assertEquals(oldestSnapshotId.offset, partitionResponse.divergingEpoch().endOffset()); + } + + @Test + public void 
testFetchRequestAtLogStartOffsetWithValidEpoch() throws Exception { + int localId = 0; + int otherNodeId = localId + 1; + int syncNodeId = otherNodeId + 1; + Set voters = Utils.mkSet(localId, otherNodeId, syncNodeId); + + OffsetAndEpoch oldestSnapshotId = new OffsetAndEpoch(3, 2); + + RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters) + .appendToLog(oldestSnapshotId.epoch, Arrays.asList("a", "b", "c")) + .appendToLog(oldestSnapshotId.epoch, Arrays.asList("d", "e", "f")) + .appendToLog(oldestSnapshotId.epoch + 2, Arrays.asList("g", "h", "i")) + .withAppendLingerMs(1) + .build(); + + context.becomeLeader(); + int epoch = context.currentEpoch(); + assertEquals(oldestSnapshotId.epoch + 2 + 1, epoch); + + // Advance the highWatermark + context.advanceLocalLeaderHighWatermarkToLogEndOffset(); + + // Create a snapshot at the high watermark + try (SnapshotWriter snapshot = context.client.createSnapshot(oldestSnapshotId.offset - 1, oldestSnapshotId.epoch, 0).get()) { + assertEquals(oldestSnapshotId, snapshot.snapshotId()); + snapshot.freeze(); + } + context.client.poll(); + + // Send fetch request at log start offset with valid last fetched epoch + context.deliverRequest( + context.fetchRequest(epoch, otherNodeId, oldestSnapshotId.offset, oldestSnapshotId.epoch, 0) + ); + context.pollUntilResponse(); + context.assertSentFetchPartitionResponse(Errors.NONE, epoch, OptionalInt.of(localId)); + } + + @Test + public void testFetchRequestAtLogStartOffsetWithInvalidEpoch() throws Exception { + int localId = 0; + int otherNodeId = localId + 1; + int syncNodeId = otherNodeId + 1; + Set voters = Utils.mkSet(localId, otherNodeId, syncNodeId); + + OffsetAndEpoch oldestSnapshotId = new OffsetAndEpoch(3, 2); + + RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters) + .appendToLog(oldestSnapshotId.epoch, Arrays.asList("a", "b", "c")) + .appendToLog(oldestSnapshotId.epoch, Arrays.asList("d", "e", "f")) + .appendToLog(oldestSnapshotId.epoch + 2, Arrays.asList("g", "h", "i")) + .withAppendLingerMs(1) + .build(); + + context.becomeLeader(); + int epoch = context.currentEpoch(); + assertEquals(oldestSnapshotId.epoch + 2 + 1, epoch); + + // Advance the highWatermark + context.advanceLocalLeaderHighWatermarkToLogEndOffset(); + + // Create a snapshot at the high watermark + try (SnapshotWriter snapshot = context.client.createSnapshot(oldestSnapshotId.offset - 1, oldestSnapshotId.epoch, 0).get()) { + assertEquals(oldestSnapshotId, snapshot.snapshotId()); + snapshot.freeze(); + } + context.log.deleteBeforeSnapshot(oldestSnapshotId); + context.client.poll(); + + // Send fetch with log start offset and invalid last fetched epoch + context.deliverRequest( + context.fetchRequest(epoch, otherNodeId, oldestSnapshotId.offset, oldestSnapshotId.epoch + 1, 0) + ); + context.pollUntilResponse(); + FetchResponseData.PartitionData partitionResponse = context.assertSentFetchPartitionResponse(); + assertEquals(Errors.NONE, Errors.forCode(partitionResponse.errorCode())); + assertEquals(epoch, partitionResponse.currentLeader().leaderEpoch()); + assertEquals(localId, partitionResponse.currentLeader().leaderId()); + assertEquals(oldestSnapshotId.epoch, partitionResponse.snapshotId().epoch()); + assertEquals(oldestSnapshotId.offset, partitionResponse.snapshotId().endOffset()); + } + + @Test + public void testFetchRequestWithLastFetchedEpochLessThanOldestSnapshot() throws Exception { + int localId = 0; + int otherNodeId = localId + 1; + int syncNodeId = otherNodeId + 1; + Set voters = 
Utils.mkSet(localId, otherNodeId, syncNodeId); + + OffsetAndEpoch oldestSnapshotId = new OffsetAndEpoch(3, 2); + + RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters) + .appendToLog(oldestSnapshotId.epoch, Arrays.asList("a", "b", "c")) + .appendToLog(oldestSnapshotId.epoch, Arrays.asList("d", "e", "f")) + .appendToLog(oldestSnapshotId.epoch + 2, Arrays.asList("g", "h", "i")) + .withAppendLingerMs(1) + .build(); + + context.becomeLeader(); + int epoch = context.currentEpoch(); + assertEquals(oldestSnapshotId.epoch + 2 + 1, epoch); + + // Advance the highWatermark + context.advanceLocalLeaderHighWatermarkToLogEndOffset(); + + // Create a snapshot at the high watermark + try (SnapshotWriter snapshot = context.client.createSnapshot(oldestSnapshotId.offset - 1, oldestSnapshotId.epoch, 0).get()) { + assertEquals(oldestSnapshotId, snapshot.snapshotId()); + snapshot.freeze(); + } + context.client.poll(); + + // Send a epoch less than the oldest snapshot + context.deliverRequest( + context.fetchRequest( + epoch, + otherNodeId, + context.log.endOffset().offset, + oldestSnapshotId.epoch - 1, + 0 + ) + ); + context.pollUntilResponse(); + FetchResponseData.PartitionData partitionResponse = context.assertSentFetchPartitionResponse(); + assertEquals(Errors.NONE, Errors.forCode(partitionResponse.errorCode())); + assertEquals(epoch, partitionResponse.currentLeader().leaderEpoch()); + assertEquals(localId, partitionResponse.currentLeader().leaderId()); + assertEquals(oldestSnapshotId.epoch, partitionResponse.snapshotId().epoch()); + assertEquals(oldestSnapshotId.offset, partitionResponse.snapshotId().endOffset()); + } + + @Test + public void testFetchSnapshotRequestMissingSnapshot() throws Exception { + int localId = 0; + int epoch = 2; + Set voters = Utils.mkSet(localId, localId + 1); + + RaftClientTestContext context = RaftClientTestContext.initializeAsLeader(localId, voters, epoch); + + context.deliverRequest( + fetchSnapshotRequest( + context.metadataPartition, + epoch, + new OffsetAndEpoch(0, 0), + Integer.MAX_VALUE, + 0 + ) + ); + + context.client.poll(); + + FetchSnapshotResponseData.PartitionSnapshot response = context.assertSentFetchSnapshotResponse(context.metadataPartition).get(); + assertEquals(Errors.SNAPSHOT_NOT_FOUND, Errors.forCode(response.errorCode())); + } + + @Test + public void testFetchSnapshotRequestUnknownPartition() throws Exception { + int localId = 0; + Set voters = Utils.mkSet(localId, localId + 1); + int epoch = 2; + TopicPartition topicPartition = new TopicPartition("unknown", 0); + + RaftClientTestContext context = RaftClientTestContext.initializeAsLeader(localId, voters, epoch); + + context.deliverRequest( + fetchSnapshotRequest( + topicPartition, + epoch, + new OffsetAndEpoch(0, 0), + Integer.MAX_VALUE, + 0 + ) + ); + + context.client.poll(); + + FetchSnapshotResponseData.PartitionSnapshot response = context.assertSentFetchSnapshotResponse(topicPartition).get(); + assertEquals(Errors.UNKNOWN_TOPIC_OR_PARTITION, Errors.forCode(response.errorCode())); + } + + @Test + public void testFetchSnapshotRequestAsLeader() throws Exception { + int localId = 0; + Set voters = Utils.mkSet(localId, localId + 1); + OffsetAndEpoch snapshotId = new OffsetAndEpoch(1, 1); + List records = Arrays.asList("foo", "bar"); + + RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters) + .appendToLog(snapshotId.epoch, Arrays.asList("a")) + .build(); + + context.becomeLeader(); + int epoch = context.currentEpoch(); + + 
context.advanceLocalLeaderHighWatermarkToLogEndOffset(); + + try (SnapshotWriter snapshot = context.client.createSnapshot(snapshotId.offset - 1, snapshotId.epoch, 0).get()) { + assertEquals(snapshotId, snapshot.snapshotId()); + snapshot.append(records); + snapshot.freeze(); + } + + RawSnapshotReader snapshot = context.log.readSnapshot(snapshotId).get(); + context.deliverRequest( + fetchSnapshotRequest( + context.metadataPartition, + epoch, + snapshotId, + Integer.MAX_VALUE, + 0 + ) + ); + + context.client.poll(); + + FetchSnapshotResponseData.PartitionSnapshot response = context + .assertSentFetchSnapshotResponse(context.metadataPartition) + .get(); + + assertEquals(Errors.NONE, Errors.forCode(response.errorCode())); + assertEquals(snapshot.sizeInBytes(), response.size()); + assertEquals(0, response.position()); + assertEquals(snapshot.sizeInBytes(), response.unalignedRecords().sizeInBytes()); + + UnalignedMemoryRecords memoryRecords = (UnalignedMemoryRecords) snapshot.slice(0, Math.toIntExact(snapshot.sizeInBytes())); + + assertEquals(memoryRecords.buffer(), ((UnalignedMemoryRecords) response.unalignedRecords()).buffer()); + } + + @Test + public void testPartialFetchSnapshotRequestAsLeader() throws Exception { + int localId = 0; + Set voters = Utils.mkSet(localId, localId + 1); + OffsetAndEpoch snapshotId = new OffsetAndEpoch(2, 1); + List records = Arrays.asList("foo", "bar"); + + RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters) + .appendToLog(snapshotId.epoch, records) + .build(); + + context.becomeLeader(); + int epoch = context.currentEpoch(); + + context.advanceLocalLeaderHighWatermarkToLogEndOffset(); + + try (SnapshotWriter snapshot = context.client.createSnapshot(snapshotId.offset - 1, snapshotId.epoch, 0).get()) { + assertEquals(snapshotId, snapshot.snapshotId()); + snapshot.append(records); + snapshot.freeze(); + } + + RawSnapshotReader snapshot = context.log.readSnapshot(snapshotId).get(); + // Fetch half of the snapshot + context.deliverRequest( + fetchSnapshotRequest( + context.metadataPartition, + epoch, + snapshotId, + Math.toIntExact(snapshot.sizeInBytes() / 2), + 0 + ) + ); + + context.client.poll(); + + FetchSnapshotResponseData.PartitionSnapshot response = context + .assertSentFetchSnapshotResponse(context.metadataPartition) + .get(); + + assertEquals(Errors.NONE, Errors.forCode(response.errorCode())); + assertEquals(snapshot.sizeInBytes(), response.size()); + assertEquals(0, response.position()); + assertEquals(snapshot.sizeInBytes() / 2, response.unalignedRecords().sizeInBytes()); + + UnalignedMemoryRecords memoryRecords = (UnalignedMemoryRecords) snapshot.slice(0, Math.toIntExact(snapshot.sizeInBytes())); + ByteBuffer snapshotBuffer = memoryRecords.buffer(); + + ByteBuffer responseBuffer = ByteBuffer.allocate(Math.toIntExact(snapshot.sizeInBytes())); + responseBuffer.put(((UnalignedMemoryRecords) response.unalignedRecords()).buffer()); + + ByteBuffer expectedBytes = snapshotBuffer.duplicate(); + expectedBytes.limit(Math.toIntExact(snapshot.sizeInBytes() / 2)); + + assertEquals(expectedBytes, responseBuffer.duplicate().flip()); + + // Fetch the remainder of the snapshot + context.deliverRequest( + fetchSnapshotRequest( + context.metadataPartition, + epoch, + snapshotId, + Integer.MAX_VALUE, + responseBuffer.position() + ) + ); + + context.client.poll(); + + response = context.assertSentFetchSnapshotResponse(context.metadataPartition).get(); + assertEquals(Errors.NONE, Errors.forCode(response.errorCode())); + 
assertEquals(snapshot.sizeInBytes(), response.size()); + assertEquals(responseBuffer.position(), response.position()); + assertEquals(snapshot.sizeInBytes() - (snapshot.sizeInBytes() / 2), response.unalignedRecords().sizeInBytes()); + + responseBuffer.put(((UnalignedMemoryRecords) response.unalignedRecords()).buffer()); + assertEquals(snapshotBuffer, responseBuffer.flip()); + } + + @Test + public void testFetchSnapshotRequestAsFollower() throws IOException { + int localId = 0; + int leaderId = localId + 1; + Set voters = Utils.mkSet(localId, leaderId); + int epoch = 2; + OffsetAndEpoch snapshotId = new OffsetAndEpoch(0, 0); + + RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters) + .withElectedLeader(epoch, leaderId) + .build(); + + context.deliverRequest( + fetchSnapshotRequest( + context.metadataPartition, + epoch, + snapshotId, + Integer.MAX_VALUE, + 0 + ) + ); + + context.client.poll(); + + FetchSnapshotResponseData.PartitionSnapshot response = context.assertSentFetchSnapshotResponse(context.metadataPartition).get(); + assertEquals(Errors.NOT_LEADER_OR_FOLLOWER, Errors.forCode(response.errorCode())); + assertEquals(epoch, response.currentLeader().leaderEpoch()); + assertEquals(leaderId, response.currentLeader().leaderId()); + } + + @Test + public void testFetchSnapshotRequestWithInvalidPosition() throws Exception { + int localId = 0; + Set voters = Utils.mkSet(localId, localId + 1); + OffsetAndEpoch snapshotId = new OffsetAndEpoch(1, 1); + List records = Arrays.asList("foo", "bar"); + + RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters) + .appendToLog(snapshotId.epoch, Arrays.asList("a")) + .build(); + + context.becomeLeader(); + int epoch = context.currentEpoch(); + + context.advanceLocalLeaderHighWatermarkToLogEndOffset(); + + try (SnapshotWriter snapshot = context.client.createSnapshot(snapshotId.offset - 1, snapshotId.epoch, 0).get()) { + assertEquals(snapshotId, snapshot.snapshotId()); + snapshot.append(records); + snapshot.freeze(); + } + + context.deliverRequest( + fetchSnapshotRequest( + context.metadataPartition, + epoch, + snapshotId, + Integer.MAX_VALUE, + -1 + ) + ); + + context.client.poll(); + + FetchSnapshotResponseData.PartitionSnapshot response = context.assertSentFetchSnapshotResponse(context.metadataPartition).get(); + assertEquals(Errors.POSITION_OUT_OF_RANGE, Errors.forCode(response.errorCode())); + assertEquals(epoch, response.currentLeader().leaderEpoch()); + assertEquals(localId, response.currentLeader().leaderId()); + + RawSnapshotReader snapshot = context.log.readSnapshot(snapshotId).get(); + context.deliverRequest( + fetchSnapshotRequest( + context.metadataPartition, + epoch, + snapshotId, + Integer.MAX_VALUE, + snapshot.sizeInBytes() + ) + ); + + context.client.poll(); + + response = context.assertSentFetchSnapshotResponse(context.metadataPartition).get(); + assertEquals(Errors.POSITION_OUT_OF_RANGE, Errors.forCode(response.errorCode())); + assertEquals(epoch, response.currentLeader().leaderEpoch()); + assertEquals(localId, response.currentLeader().leaderId()); + } + + @Test + public void testFetchSnapshotRequestWithOlderEpoch() throws Exception { + int localId = 0; + Set voters = Utils.mkSet(localId, localId + 1); + int epoch = 2; + OffsetAndEpoch snapshotId = new OffsetAndEpoch(0, 0); + + RaftClientTestContext context = RaftClientTestContext.initializeAsLeader(localId, voters, epoch); + + context.deliverRequest( + fetchSnapshotRequest( + context.metadataPartition, + epoch - 1, + snapshotId, + 
Integer.MAX_VALUE, + 0 + ) + ); + + context.client.poll(); + + FetchSnapshotResponseData.PartitionSnapshot response = context.assertSentFetchSnapshotResponse(context.metadataPartition).get(); + assertEquals(Errors.FENCED_LEADER_EPOCH, Errors.forCode(response.errorCode())); + assertEquals(epoch, response.currentLeader().leaderEpoch()); + assertEquals(localId, response.currentLeader().leaderId()); + } + + @Test + public void testFetchSnapshotRequestWithNewerEpoch() throws Exception { + int localId = 0; + Set voters = Utils.mkSet(localId, localId + 1); + int epoch = 2; + OffsetAndEpoch snapshotId = new OffsetAndEpoch(0, 0); + + RaftClientTestContext context = RaftClientTestContext.initializeAsLeader(localId, voters, epoch); + + context.deliverRequest( + fetchSnapshotRequest( + context.metadataPartition, + epoch + 1, + snapshotId, + Integer.MAX_VALUE, + 0 + ) + ); + + context.client.poll(); + + FetchSnapshotResponseData.PartitionSnapshot response = context.assertSentFetchSnapshotResponse(context.metadataPartition).get(); + assertEquals(Errors.UNKNOWN_LEADER_EPOCH, Errors.forCode(response.errorCode())); + assertEquals(epoch, response.currentLeader().leaderEpoch()); + assertEquals(localId, response.currentLeader().leaderId()); + } + + @Test + public void testFetchResponseWithInvalidSnapshotId() throws Exception { + int localId = 0; + int leaderId = localId + 1; + Set voters = Utils.mkSet(localId, leaderId); + int epoch = 2; + OffsetAndEpoch invalidEpoch = new OffsetAndEpoch(100L, -1); + OffsetAndEpoch invalidEndOffset = new OffsetAndEpoch(-1L, 1); + int slept = 0; + + RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters) + .withElectedLeader(epoch, leaderId) + .build(); + + context.pollUntilRequest(); + RaftRequest.Outbound fetchRequest = context.assertSentFetchRequest(); + context.assertFetchRequestData(fetchRequest, epoch, 0L, 0); + + context.deliverResponse( + fetchRequest.correlationId, + fetchRequest.destinationId(), + snapshotFetchResponse(context.metadataPartition, context.metadataTopicId, epoch, leaderId, invalidEpoch, 200L) + ); + + // Handle the invalid response + context.client.poll(); + + // Expect another fetch request after backoff has expired + context.time.sleep(context.retryBackoffMs); + slept += context.retryBackoffMs; + + context.pollUntilRequest(); + fetchRequest = context.assertSentFetchRequest(); + context.assertFetchRequestData(fetchRequest, epoch, 0L, 0); + + context.deliverResponse( + fetchRequest.correlationId, + fetchRequest.destinationId(), + snapshotFetchResponse(context.metadataPartition, context.metadataTopicId, epoch, leaderId, invalidEndOffset, 200L) + ); + + // Handle the invalid response + context.client.poll(); + + // Expect another fetch request after backoff has expired + context.time.sleep(context.retryBackoffMs); + slept += context.retryBackoffMs; + + context.pollUntilRequest(); + fetchRequest = context.assertSentFetchRequest(); + context.assertFetchRequestData(fetchRequest, epoch, 0L, 0); + + // Fetch timer is not reset; sleeping for remainder should transition to candidate + context.time.sleep(context.fetchTimeoutMs - slept); + + context.pollUntilRequest(); + + context.assertSentVoteRequest(epoch + 1, 0, 0L, 1); + context.assertVotedCandidate(epoch + 1, localId); + } + + @Test + public void testFetchResponseWithSnapshotId() throws Exception { + int localId = 0; + int leaderId = localId + 1; + Set voters = Utils.mkSet(localId, leaderId); + int epoch = 2; + OffsetAndEpoch snapshotId = new OffsetAndEpoch(100L, 1); + + 
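+        // Start as a follower of leaderId; the first Fetch response names a snapshotId, which
+        // should cause this replica to switch to sending FetchSnapshot requests.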
RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters) + .withElectedLeader(epoch, leaderId) + .build(); + + context.pollUntilRequest(); + RaftRequest.Outbound fetchRequest = context.assertSentFetchRequest(); + context.assertFetchRequestData(fetchRequest, epoch, 0L, 0); + + context.deliverResponse( + fetchRequest.correlationId, + fetchRequest.destinationId(), + snapshotFetchResponse(context.metadataPartition, context.metadataTopicId, epoch, leaderId, snapshotId, 200L) + ); + + context.pollUntilRequest(); + RaftRequest.Outbound snapshotRequest = context.assertSentFetchSnapshotRequest(); + FetchSnapshotRequestData.PartitionSnapshot request = assertFetchSnapshotRequest( + snapshotRequest, + context.metadataPartition, + localId, + Integer.MAX_VALUE + ).get(); + assertEquals(snapshotId.offset, request.snapshotId().endOffset()); + assertEquals(snapshotId.epoch, request.snapshotId().epoch()); + assertEquals(0, request.position()); + + List records = Arrays.asList("foo", "bar"); + MemorySnapshotWriter memorySnapshot = new MemorySnapshotWriter(snapshotId); + try (SnapshotWriter snapshotWriter = snapshotWriter(context, memorySnapshot)) { + snapshotWriter.append(records); + snapshotWriter.freeze(); + } + + context.deliverResponse( + snapshotRequest.correlationId, + snapshotRequest.destinationId(), + fetchSnapshotResponse( + context.metadataPartition, + epoch, + leaderId, + snapshotId, + memorySnapshot.buffer().remaining(), + 0L, + memorySnapshot.buffer().slice() + ) + ); + + context.pollUntilRequest(); + fetchRequest = context.assertSentFetchRequest(); + context.assertFetchRequestData(fetchRequest, epoch, snapshotId.offset, snapshotId.epoch); + + // Check that the snapshot was written to the log + RawSnapshotReader snapshot = context.log.readSnapshot(snapshotId).get(); + assertEquals(memorySnapshot.buffer().remaining(), snapshot.sizeInBytes()); + SnapshotWriterReaderTest.assertSnapshot(Arrays.asList(records), snapshot); + + // Check that listener was notified of the new snapshot + try (SnapshotReader reader = context.listener.drainHandledSnapshot().get()) { + assertEquals(snapshotId, reader.snapshotId()); + SnapshotWriterReaderTest.assertSnapshot(Arrays.asList(records), reader); + } + } + + @Test + public void testFetchSnapshotResponsePartialData() throws Exception { + int localId = 0; + int leaderId = localId + 1; + Set voters = Utils.mkSet(localId, leaderId); + int epoch = 2; + OffsetAndEpoch snapshotId = new OffsetAndEpoch(100L, 1); + + RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters) + .withElectedLeader(epoch, leaderId) + .build(); + + context.pollUntilRequest(); + RaftRequest.Outbound fetchRequest = context.assertSentFetchRequest(); + context.assertFetchRequestData(fetchRequest, epoch, 0L, 0); + + context.deliverResponse( + fetchRequest.correlationId, + fetchRequest.destinationId(), + snapshotFetchResponse(context.metadataPartition, context.metadataTopicId, epoch, leaderId, snapshotId, 200L) + ); + + context.pollUntilRequest(); + RaftRequest.Outbound snapshotRequest = context.assertSentFetchSnapshotRequest(); + FetchSnapshotRequestData.PartitionSnapshot request = assertFetchSnapshotRequest( + snapshotRequest, + context.metadataPartition, + localId, + Integer.MAX_VALUE + ).get(); + assertEquals(snapshotId.offset, request.snapshotId().endOffset()); + assertEquals(snapshotId.epoch, request.snapshotId().epoch()); + assertEquals(0, request.position()); + + List records = Arrays.asList("foo", "bar"); + MemorySnapshotWriter memorySnapshot = 
new MemorySnapshotWriter(snapshotId); + try (SnapshotWriter snapshotWriter = snapshotWriter(context, memorySnapshot)) { + snapshotWriter.append(records); + snapshotWriter.freeze(); + } + + ByteBuffer sendingBuffer = memorySnapshot.buffer().slice(); + sendingBuffer.limit(sendingBuffer.limit() / 2); + + context.deliverResponse( + snapshotRequest.correlationId, + snapshotRequest.destinationId(), + fetchSnapshotResponse( + context.metadataPartition, + epoch, + leaderId, + snapshotId, + memorySnapshot.buffer().remaining(), + 0L, + sendingBuffer + ) + ); + + context.pollUntilRequest(); + snapshotRequest = context.assertSentFetchSnapshotRequest(); + request = assertFetchSnapshotRequest( + snapshotRequest, + context.metadataPartition, + localId, + Integer.MAX_VALUE + ).get(); + assertEquals(snapshotId.offset, request.snapshotId().endOffset()); + assertEquals(snapshotId.epoch, request.snapshotId().epoch()); + assertEquals(sendingBuffer.limit(), request.position()); + + sendingBuffer = memorySnapshot.buffer().slice(); + sendingBuffer.position(Math.toIntExact(request.position())); + + context.deliverResponse( + snapshotRequest.correlationId, + snapshotRequest.destinationId(), + fetchSnapshotResponse( + context.metadataPartition, + epoch, + leaderId, + snapshotId, + memorySnapshot.buffer().remaining(), + request.position(), + sendingBuffer + ) + ); + + context.pollUntilRequest(); + fetchRequest = context.assertSentFetchRequest(); + context.assertFetchRequestData(fetchRequest, epoch, snapshotId.offset, snapshotId.epoch); + + // Check that the snapshot was written to the log + RawSnapshotReader snapshot = context.log.readSnapshot(snapshotId).get(); + assertEquals(memorySnapshot.buffer().remaining(), snapshot.sizeInBytes()); + SnapshotWriterReaderTest.assertSnapshot(Arrays.asList(records), snapshot); + + // Check that listener was notified of the new snapshot + try (SnapshotReader reader = context.listener.drainHandledSnapshot().get()) { + assertEquals(snapshotId, reader.snapshotId()); + SnapshotWriterReaderTest.assertSnapshot(Arrays.asList(records), reader); + } + } + + @Test + public void testFetchSnapshotResponseMissingSnapshot() throws Exception { + int localId = 0; + int leaderId = localId + 1; + Set voters = Utils.mkSet(localId, leaderId); + int epoch = 2; + OffsetAndEpoch snapshotId = new OffsetAndEpoch(100L, 1); + + RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters) + .withElectedLeader(epoch, leaderId) + .build(); + + context.pollUntilRequest(); + RaftRequest.Outbound fetchRequest = context.assertSentFetchRequest(); + context.assertFetchRequestData(fetchRequest, epoch, 0L, 0); + + context.deliverResponse( + fetchRequest.correlationId, + fetchRequest.destinationId(), + snapshotFetchResponse(context.metadataPartition, context.metadataTopicId, epoch, leaderId, snapshotId, 200L) + ); + + context.pollUntilRequest(); + RaftRequest.Outbound snapshotRequest = context.assertSentFetchSnapshotRequest(); + FetchSnapshotRequestData.PartitionSnapshot request = assertFetchSnapshotRequest( + snapshotRequest, + context.metadataPartition, + localId, + Integer.MAX_VALUE + ).get(); + assertEquals(snapshotId.offset, request.snapshotId().endOffset()); + assertEquals(snapshotId.epoch, request.snapshotId().epoch()); + assertEquals(0, request.position()); + + // Reply with a snapshot not found error + context.deliverResponse( + snapshotRequest.correlationId, + snapshotRequest.destinationId(), + FetchSnapshotResponse.singleton( + context.metadataPartition, + responsePartitionSnapshot -> { 
+ responsePartitionSnapshot + .currentLeader() + .setLeaderEpoch(epoch) + .setLeaderId(leaderId); + + return responsePartitionSnapshot + .setErrorCode(Errors.SNAPSHOT_NOT_FOUND.code()); + } + ) + ); + + context.pollUntilRequest(); + fetchRequest = context.assertSentFetchRequest(); + context.assertFetchRequestData(fetchRequest, epoch, 0L, 0); + } + + @Test + public void testFetchSnapshotResponseFromNewerEpochNotLeader() throws Exception { + int localId = 0; + int firstLeaderId = localId + 1; + int secondLeaderId = firstLeaderId + 1; + Set voters = Utils.mkSet(localId, firstLeaderId, secondLeaderId); + int epoch = 2; + OffsetAndEpoch snapshotId = new OffsetAndEpoch(100L, 1); + + RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters) + .withElectedLeader(epoch, firstLeaderId) + .build(); + + context.pollUntilRequest(); + RaftRequest.Outbound fetchRequest = context.assertSentFetchRequest(); + context.assertFetchRequestData(fetchRequest, epoch, 0L, 0); + + context.deliverResponse( + fetchRequest.correlationId, + fetchRequest.destinationId(), + snapshotFetchResponse(context.metadataPartition, context.metadataTopicId, epoch, firstLeaderId, snapshotId, 200L) + ); + + context.pollUntilRequest(); + RaftRequest.Outbound snapshotRequest = context.assertSentFetchSnapshotRequest(); + FetchSnapshotRequestData.PartitionSnapshot request = assertFetchSnapshotRequest( + snapshotRequest, + context.metadataPartition, + localId, + Integer.MAX_VALUE + ).get(); + assertEquals(snapshotId.offset, request.snapshotId().endOffset()); + assertEquals(snapshotId.epoch, request.snapshotId().epoch()); + assertEquals(0, request.position()); + + // Reply with new leader response + context.deliverResponse( + snapshotRequest.correlationId, + snapshotRequest.destinationId(), + FetchSnapshotResponse.singleton( + context.metadataPartition, + responsePartitionSnapshot -> { + responsePartitionSnapshot + .currentLeader() + .setLeaderEpoch(epoch + 1) + .setLeaderId(secondLeaderId); + + return responsePartitionSnapshot + .setErrorCode(Errors.FENCED_LEADER_EPOCH.code()); + } + ) + ); + + context.pollUntilRequest(); + fetchRequest = context.assertSentFetchRequest(); + context.assertFetchRequestData(fetchRequest, epoch + 1, 0L, 0); + } + + @Test + public void testFetchSnapshotResponseFromNewerEpochLeader() throws Exception { + int localId = 0; + int leaderId = localId + 1; + Set voters = Utils.mkSet(localId, leaderId); + int epoch = 2; + OffsetAndEpoch snapshotId = new OffsetAndEpoch(100L, 1); + + RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters) + .withElectedLeader(epoch, leaderId) + .build(); + + context.pollUntilRequest(); + RaftRequest.Outbound fetchRequest = context.assertSentFetchRequest(); + context.assertFetchRequestData(fetchRequest, epoch, 0L, 0); + + context.deliverResponse( + fetchRequest.correlationId, + fetchRequest.destinationId(), + snapshotFetchResponse(context.metadataPartition, context.metadataTopicId, epoch, leaderId, snapshotId, 200L) + ); + + context.pollUntilRequest(); + RaftRequest.Outbound snapshotRequest = context.assertSentFetchSnapshotRequest(); + FetchSnapshotRequestData.PartitionSnapshot request = assertFetchSnapshotRequest( + snapshotRequest, + context.metadataPartition, + localId, + Integer.MAX_VALUE + ).get(); + assertEquals(snapshotId.offset, request.snapshotId().endOffset()); + assertEquals(snapshotId.epoch, request.snapshotId().epoch()); + assertEquals(0, request.position()); + + // Reply with new leader epoch + context.deliverResponse( + 
snapshotRequest.correlationId, + snapshotRequest.destinationId(), + FetchSnapshotResponse.singleton( + context.metadataPartition, + responsePartitionSnapshot -> { + responsePartitionSnapshot + .currentLeader() + .setLeaderEpoch(epoch + 1) + .setLeaderId(leaderId); + + return responsePartitionSnapshot + .setErrorCode(Errors.FENCED_LEADER_EPOCH.code()); + } + ) + ); + + context.pollUntilRequest(); + fetchRequest = context.assertSentFetchRequest(); + context.assertFetchRequestData(fetchRequest, epoch + 1, 0L, 0); + } + + @Test + public void testFetchSnapshotResponseFromOlderEpoch() throws Exception { + int localId = 0; + int leaderId = localId + 1; + Set voters = Utils.mkSet(localId, leaderId); + int epoch = 2; + OffsetAndEpoch snapshotId = new OffsetAndEpoch(100L, 1); + + RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters) + .withElectedLeader(epoch, leaderId) + .build(); + + context.pollUntilRequest(); + RaftRequest.Outbound fetchRequest = context.assertSentFetchRequest(); + context.assertFetchRequestData(fetchRequest, epoch, 0L, 0); + + context.deliverResponse( + fetchRequest.correlationId, + fetchRequest.destinationId(), + snapshotFetchResponse(context.metadataPartition, context.metadataTopicId, epoch, leaderId, snapshotId, 200L) + ); + + context.pollUntilRequest(); + RaftRequest.Outbound snapshotRequest = context.assertSentFetchSnapshotRequest(); + FetchSnapshotRequestData.PartitionSnapshot request = assertFetchSnapshotRequest( + snapshotRequest, + context.metadataPartition, + localId, + Integer.MAX_VALUE + ).get(); + assertEquals(snapshotId.offset, request.snapshotId().endOffset()); + assertEquals(snapshotId.epoch, request.snapshotId().epoch()); + assertEquals(0, request.position()); + + // Reply with unknown leader epoch + context.deliverResponse( + snapshotRequest.correlationId, + snapshotRequest.destinationId(), + FetchSnapshotResponse.singleton( + context.metadataPartition, + responsePartitionSnapshot -> { + responsePartitionSnapshot + .currentLeader() + .setLeaderEpoch(epoch - 1) + .setLeaderId(leaderId + 1); + + return responsePartitionSnapshot + .setErrorCode(Errors.UNKNOWN_LEADER_EPOCH.code()); + } + ) + ); + + context.pollUntilRequest(); + + // Follower should resend the fetch snapshot request + snapshotRequest = context.assertSentFetchSnapshotRequest(); + request = assertFetchSnapshotRequest( + snapshotRequest, + context.metadataPartition, + localId, + Integer.MAX_VALUE + ).get(); + assertEquals(snapshotId.offset, request.snapshotId().endOffset()); + assertEquals(snapshotId.epoch, request.snapshotId().epoch()); + assertEquals(0, request.position()); + } + + @Test + public void testFetchSnapshotResponseWithInvalidId() throws Exception { + int localId = 0; + int leaderId = localId + 1; + Set voters = Utils.mkSet(localId, leaderId); + int epoch = 2; + OffsetAndEpoch snapshotId = new OffsetAndEpoch(100L, 1); + + RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters) + .withElectedLeader(epoch, leaderId) + .build(); + + context.pollUntilRequest(); + RaftRequest.Outbound fetchRequest = context.assertSentFetchRequest(); + context.assertFetchRequestData(fetchRequest, epoch, 0L, 0); + + context.deliverResponse( + fetchRequest.correlationId, + fetchRequest.destinationId(), + snapshotFetchResponse(context.metadataPartition, context.metadataTopicId, epoch, leaderId, snapshotId, 200L) + ); + + context.pollUntilRequest(); + RaftRequest.Outbound snapshotRequest = context.assertSentFetchSnapshotRequest(); + 
FetchSnapshotRequestData.PartitionSnapshot request = assertFetchSnapshotRequest( + snapshotRequest, + context.metadataPartition, + localId, + Integer.MAX_VALUE + ).get(); + assertEquals(snapshotId.offset, request.snapshotId().endOffset()); + assertEquals(snapshotId.epoch, request.snapshotId().epoch()); + assertEquals(0, request.position()); + + // Reply with an invalid snapshot id endOffset + context.deliverResponse( + snapshotRequest.correlationId, + snapshotRequest.destinationId(), + FetchSnapshotResponse.singleton( + context.metadataPartition, + responsePartitionSnapshot -> { + responsePartitionSnapshot + .currentLeader() + .setLeaderEpoch(epoch) + .setLeaderId(leaderId); + + responsePartitionSnapshot + .snapshotId() + .setEndOffset(-1) + .setEpoch(snapshotId.epoch); + + return responsePartitionSnapshot; + } + ) + ); + + context.pollUntilRequest(); + + // Follower should send a fetch request + fetchRequest = context.assertSentFetchRequest(); + context.assertFetchRequestData(fetchRequest, epoch, 0L, 0); + + context.deliverResponse( + fetchRequest.correlationId, + fetchRequest.destinationId(), + snapshotFetchResponse(context.metadataPartition, context.metadataTopicId, epoch, leaderId, snapshotId, 200L) + ); + + context.pollUntilRequest(); + + snapshotRequest = context.assertSentFetchSnapshotRequest(); + request = assertFetchSnapshotRequest( + snapshotRequest, + context.metadataPartition, + localId, + Integer.MAX_VALUE + ).get(); + assertEquals(snapshotId.offset, request.snapshotId().endOffset()); + assertEquals(snapshotId.epoch, request.snapshotId().epoch()); + assertEquals(0, request.position()); + + // Reply with an invalid snapshot id epoch + context.deliverResponse( + snapshotRequest.correlationId, + snapshotRequest.destinationId(), + FetchSnapshotResponse.singleton( + context.metadataPartition, + responsePartitionSnapshot -> { + responsePartitionSnapshot + .currentLeader() + .setLeaderEpoch(epoch) + .setLeaderId(leaderId); + + responsePartitionSnapshot + .snapshotId() + .setEndOffset(snapshotId.offset) + .setEpoch(-1); + + return responsePartitionSnapshot; + } + ) + ); + + context.pollUntilRequest(); + + // Follower should send a fetch request + fetchRequest = context.assertSentFetchRequest(); + context.assertFetchRequestData(fetchRequest, epoch, 0L, 0); + } + + @Test + public void testFetchSnapshotResponseToNotFollower() throws Exception { + int localId = 0; + int leaderId = localId + 1; + Set voters = Utils.mkSet(localId, leaderId); + int epoch = 2; + OffsetAndEpoch snapshotId = new OffsetAndEpoch(100L, 1); + + RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters) + .withElectedLeader(epoch, leaderId) + .build(); + + context.pollUntilRequest(); + RaftRequest.Outbound fetchRequest = context.assertSentFetchRequest(); + context.assertFetchRequestData(fetchRequest, epoch, 0L, 0); + + context.deliverResponse( + fetchRequest.correlationId, + fetchRequest.destinationId(), + snapshotFetchResponse(context.metadataPartition, context.metadataTopicId, epoch, leaderId, snapshotId, 200L) + ); + + context.pollUntilRequest(); + + RaftRequest.Outbound snapshotRequest = context.assertSentFetchSnapshotRequest(); + FetchSnapshotRequestData.PartitionSnapshot request = assertFetchSnapshotRequest( + snapshotRequest, + context.metadataPartition, + localId, + Integer.MAX_VALUE + ).get(); + assertEquals(snapshotId.offset, request.snapshotId().endOffset()); + assertEquals(snapshotId.epoch, request.snapshotId().epoch()); + assertEquals(0, request.position()); + + // Sleeping for 
fetch timeout should transition to candidate + context.time.sleep(context.fetchTimeoutMs); + + context.pollUntilRequest(); + + context.assertSentVoteRequest(epoch + 1, 0, 0L, 1); + context.assertVotedCandidate(epoch + 1, localId); + + // Send the response late + context.deliverResponse( + snapshotRequest.correlationId, + snapshotRequest.destinationId(), + FetchSnapshotResponse.singleton( + context.metadataPartition, + responsePartitionSnapshot -> { + responsePartitionSnapshot + .currentLeader() + .setLeaderEpoch(epoch) + .setLeaderId(leaderId); + + responsePartitionSnapshot + .snapshotId() + .setEndOffset(snapshotId.offset) + .setEpoch(snapshotId.epoch); + + return responsePartitionSnapshot; + } + ) + ); + + // Assert that the response is ignored and the replicas stays as a candidate + context.client.poll(); + context.assertVotedCandidate(epoch + 1, localId); + } + + @Test + public void testFetchSnapshotRequestClusterIdValidation() throws Exception { + int localId = 0; + int otherNodeId = 1; + int epoch = 5; + Set voters = Utils.mkSet(localId, otherNodeId); + + RaftClientTestContext context = RaftClientTestContext.initializeAsLeader(localId, voters, epoch); + + // null cluster id is accepted + context.deliverRequest( + fetchSnapshotRequest( + context.clusterId.toString(), + context.metadataPartition, + epoch, + new OffsetAndEpoch(0, 0), + Integer.MAX_VALUE, + 0 + ) + ); + context.pollUntilResponse(); + context.assertSentFetchSnapshotResponse(context.metadataPartition); + + // null cluster id is accepted + context.deliverRequest( + fetchSnapshotRequest( + null, + context.metadataPartition, + epoch, + new OffsetAndEpoch(0, 0), + Integer.MAX_VALUE, + 0 + ) + ); + context.pollUntilResponse(); + context.assertSentFetchSnapshotResponse(context.metadataPartition); + + // empty cluster id is rejected + context.deliverRequest( + fetchSnapshotRequest( + "", + context.metadataPartition, + epoch, + new OffsetAndEpoch(0, 0), + Integer.MAX_VALUE, + 0 + ) + ); + context.pollUntilResponse(); + context.assertSentFetchSnapshotResponse(Errors.INCONSISTENT_CLUSTER_ID); + + // invalid cluster id is rejected + context.deliverRequest( + fetchSnapshotRequest( + "invalid-uuid", + context.metadataPartition, + epoch, + new OffsetAndEpoch(0, 0), + Integer.MAX_VALUE, + 0 + ) + ); + context.pollUntilResponse(); + context.assertSentFetchSnapshotResponse(Errors.INCONSISTENT_CLUSTER_ID); + } + + @Test + public void testCreateSnapshotAsLeaderWithInvalidSnapshotId() throws Exception { + int localId = 0; + int otherNodeId = localId + 1; + Set voters = Utils.mkSet(localId, otherNodeId); + int epoch = 2; + + List appendRecords = Arrays.asList("a", "b", "c"); + OffsetAndEpoch invalidSnapshotId1 = new OffsetAndEpoch(3, epoch); + + RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters) + .appendToLog(epoch, appendRecords) + .withAppendLingerMs(1) + .build(); + + context.becomeLeader(); + int currentEpoch = context.currentEpoch(); + + // When leader creating snapshot: + // 1.1 high watermark cannot be empty + assertEquals(OptionalLong.empty(), context.client.highWatermark()); + assertThrows(IllegalArgumentException.class, () -> context.client.createSnapshot(invalidSnapshotId1.offset, invalidSnapshotId1.epoch, 0)); + + // 1.2 high watermark must larger than or equal to the snapshotId's endOffset + context.advanceLocalLeaderHighWatermarkToLogEndOffset(); + // append some more records to make the LEO > high watermark + List newRecords = Arrays.asList("d", "e", "f"); + 
context.client.scheduleAppend(currentEpoch, newRecords); + context.time.sleep(context.appendLingerMs()); + context.client.poll(); + assertEquals(context.log.endOffset().offset, context.client.highWatermark().getAsLong() + newRecords.size()); + + OffsetAndEpoch invalidSnapshotId2 = new OffsetAndEpoch(context.client.highWatermark().getAsLong() + 1, currentEpoch); + assertThrows(IllegalArgumentException.class, () -> context.client.createSnapshot(invalidSnapshotId2.offset, invalidSnapshotId2.epoch, 0)); + + // 2 the quorum epoch must larger than or equal to the snapshotId's epoch + OffsetAndEpoch invalidSnapshotId3 = new OffsetAndEpoch(context.client.highWatermark().getAsLong() - 2, currentEpoch + 1); + assertThrows(IllegalArgumentException.class, () -> context.client.createSnapshot(invalidSnapshotId3.offset, invalidSnapshotId3.epoch, 0)); + + // 3 the snapshotId should be validated against endOffsetForEpoch + OffsetAndEpoch endOffsetForEpoch = context.log.endOffsetForEpoch(epoch); + assertEquals(epoch, endOffsetForEpoch.epoch); + OffsetAndEpoch invalidSnapshotId4 = new OffsetAndEpoch(endOffsetForEpoch.offset + 1, epoch); + assertThrows(IllegalArgumentException.class, () -> context.client.createSnapshot(invalidSnapshotId4.offset, invalidSnapshotId4.epoch, 0)); + } + + @Test + public void testCreateSnapshotAsFollowerWithInvalidSnapshotId() throws Exception { + int localId = 0; + int leaderId = 1; + int otherFollowerId = 2; + int epoch = 5; + Set voters = Utils.mkSet(localId, leaderId, otherFollowerId); + + RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters) + .withElectedLeader(epoch, leaderId) + .build(); + context.assertElectedLeader(epoch, leaderId); + + // When follower creating snapshot: + // 1) The high watermark cannot be empty + assertEquals(OptionalLong.empty(), context.client.highWatermark()); + OffsetAndEpoch invalidSnapshotId1 = new OffsetAndEpoch(0, 0); + assertThrows(IllegalArgumentException.class, () -> context.client.createSnapshot(invalidSnapshotId1.offset, invalidSnapshotId1.epoch, 0)); + + // Poll for our first fetch request + context.pollUntilRequest(); + RaftRequest.Outbound fetchRequest = context.assertSentFetchRequest(); + assertTrue(voters.contains(fetchRequest.destinationId())); + context.assertFetchRequestData(fetchRequest, epoch, 0L, 0); + + // The response does not advance the high watermark + List records1 = Arrays.asList("a", "b", "c"); + MemoryRecords batch1 = context.buildBatch(0L, 3, records1); + context.deliverResponse(fetchRequest.correlationId, fetchRequest.destinationId(), + context.fetchResponse(epoch, leaderId, batch1, 0L, Errors.NONE)); + context.client.poll(); + + // 2) The high watermark must be larger than or equal to the snapshotId's endOffset + int currentEpoch = context.currentEpoch(); + OffsetAndEpoch invalidSnapshotId2 = new OffsetAndEpoch(context.client.highWatermark().getAsLong() + 1, currentEpoch); + assertThrows(IllegalArgumentException.class, () -> context.client.createSnapshot(invalidSnapshotId2.offset, invalidSnapshotId2.epoch, 0)); + + // 3) The quorum epoch must be larger than or equal to the snapshotId's epoch + OffsetAndEpoch invalidSnapshotId3 = new OffsetAndEpoch(context.client.highWatermark().getAsLong(), currentEpoch + 1); + assertThrows(IllegalArgumentException.class, () -> context.client.createSnapshot(invalidSnapshotId3.offset, invalidSnapshotId3.epoch, 0)); + + // The high watermark advances to be larger than log.endOffsetForEpoch(3), to test the case 3 + context.pollUntilRequest(); + fetchRequest 
= context.assertSentFetchRequest(); + assertTrue(voters.contains(fetchRequest.destinationId())); + context.assertFetchRequestData(fetchRequest, epoch, 3L, 3); + + List records2 = Arrays.asList("d", "e", "f"); + MemoryRecords batch2 = context.buildBatch(3L, 4, records2); + context.deliverResponse(fetchRequest.correlationId, fetchRequest.destinationId(), + context.fetchResponse(epoch, leaderId, batch2, 6L, Errors.NONE)); + context.client.poll(); + assertEquals(6L, context.client.highWatermark().getAsLong()); + + // 4) The snapshotId should be validated against endOffsetForEpoch + OffsetAndEpoch endOffsetForEpoch = context.log.endOffsetForEpoch(3); + assertEquals(3, endOffsetForEpoch.epoch); + OffsetAndEpoch invalidSnapshotId4 = new OffsetAndEpoch(endOffsetForEpoch.offset + 1, epoch); + assertThrows(IllegalArgumentException.class, () -> context.client.createSnapshot(invalidSnapshotId4.offset, invalidSnapshotId4.epoch, 0)); + } + + private static FetchSnapshotRequestData fetchSnapshotRequest( + TopicPartition topicPartition, + int epoch, + OffsetAndEpoch offsetAndEpoch, + int maxBytes, + long position + ) { + return fetchSnapshotRequest(null, topicPartition, epoch, offsetAndEpoch, maxBytes, position); + } + + private static FetchSnapshotRequestData fetchSnapshotRequest( + String clusterId, + TopicPartition topicPartition, + int epoch, + OffsetAndEpoch offsetAndEpoch, + int maxBytes, + long position + ) { + FetchSnapshotRequestData.SnapshotId snapshotId = new FetchSnapshotRequestData.SnapshotId() + .setEndOffset(offsetAndEpoch.offset) + .setEpoch(offsetAndEpoch.epoch); + + FetchSnapshotRequestData request = FetchSnapshotRequest.singleton( + clusterId, + topicPartition, + snapshotPartition -> { + return snapshotPartition + .setCurrentLeaderEpoch(epoch) + .setSnapshotId(snapshotId) + .setPosition(position); + } + ); + + return request.setMaxBytes(maxBytes); + } + + private static FetchSnapshotResponseData fetchSnapshotResponse( + TopicPartition topicPartition, + int leaderEpoch, + int leaderId, + OffsetAndEpoch snapshotId, + long size, + long position, + ByteBuffer buffer + ) { + return FetchSnapshotResponse.singleton( + topicPartition, + partitionSnapshot -> { + partitionSnapshot.currentLeader() + .setLeaderEpoch(leaderEpoch) + .setLeaderId(leaderId); + + partitionSnapshot.snapshotId() + .setEndOffset(snapshotId.offset) + .setEpoch(snapshotId.epoch); + + return partitionSnapshot + .setSize(size) + .setPosition(position) + .setUnalignedRecords(MemoryRecords.readableRecords(buffer.slice())); + } + ); + } + + private static FetchResponseData snapshotFetchResponse( + TopicPartition topicPartition, + Uuid topicId, + int epoch, + int leaderId, + OffsetAndEpoch snapshotId, + long highWatermark + ) { + return RaftUtil.singletonFetchResponse(topicPartition, topicId, Errors.NONE, partitionData -> { + partitionData.setHighWatermark(highWatermark); + + partitionData.currentLeader() + .setLeaderEpoch(epoch) + .setLeaderId(leaderId); + + partitionData.snapshotId() + .setEpoch(snapshotId.epoch) + .setEndOffset(snapshotId.offset); + }); + } + + private static Optional assertFetchSnapshotRequest( + RaftRequest.Outbound request, + TopicPartition topicPartition, + int replicaId, + int maxBytes + ) { + assertTrue(request.data() instanceof FetchSnapshotRequestData); + + FetchSnapshotRequestData data = (FetchSnapshotRequestData) request.data(); + + assertEquals(replicaId, data.replicaId()); + assertEquals(maxBytes, data.maxBytes()); + + return FetchSnapshotRequest.forTopicPartition(data, topicPartition); + } + + 
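+    // Test helpers: snapshotWriter() wraps a RawSnapshotWriter in a SnapshotWriter<String>
+    // (no compression, String serde), and MemorySnapshotWriter is an in-memory RawSnapshotWriter
+    // that grows a ByteBuffer as records are appended so tests can inspect and replay the raw
+    // snapshot bytes.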
+    private static SnapshotWriter<String> snapshotWriter(RaftClientTestContext context, RawSnapshotWriter snapshot) {
+        return SnapshotWriter.createWithHeader(
+            () -> Optional.of(snapshot),
+            4 * 1024,
+            MemoryPool.NONE,
+            context.time,
+            0,
+            CompressionType.NONE,
+            new StringSerde()
+        ).get();
+    }
+
+    private final static class MemorySnapshotWriter implements RawSnapshotWriter {
+        private final OffsetAndEpoch snapshotId;
+        private ByteBuffer data;
+        private boolean frozen;
+
+        public MemorySnapshotWriter(OffsetAndEpoch snapshotId) {
+            this.snapshotId = snapshotId;
+            this.data = ByteBuffer.allocate(0);
+            this.frozen = false;
+        }
+
+        @Override
+        public OffsetAndEpoch snapshotId() {
+            return snapshotId;
+        }
+
+        @Override
+        public long sizeInBytes() {
+            if (frozen) {
+                throw new RuntimeException("Snapshot is already frozen " + snapshotId);
+            }
+
+            return data.position();
+        }
+
+        @Override
+        public void append(UnalignedMemoryRecords records) {
+            if (frozen) {
+                throw new RuntimeException("Snapshot is already frozen " + snapshotId);
+            }
+            append(records.buffer());
+        }
+
+        @Override
+        public void append(MemoryRecords records) {
+            if (frozen) {
+                throw new RuntimeException("Snapshot is already frozen " + snapshotId);
+            }
+            append(records.buffer());
+        }
+
+        private void append(ByteBuffer buffer) {
+            if (!(data.remaining() >= buffer.remaining())) {
+                ByteBuffer old = data;
+                old.flip();
+
+                int newSize = Math.max(data.capacity() * 2, data.capacity() + buffer.remaining());
+                data = ByteBuffer.allocate(newSize);
+
+                data.put(old);
+            }
+            data.put(buffer);
+        }
+
+        @Override
+        public boolean isFrozen() {
+            return frozen;
+        }
+
+        @Override
+        public void freeze() {
+            if (frozen) {
+                throw new RuntimeException("Snapshot is already frozen " + snapshotId);
+            }
+
+            frozen = true;
+            data.flip();
+        }
+
+        @Override
+        public void close() {}
+
+        public ByteBuffer buffer() {
+            return data;
+        }
+    }
+}
diff --git a/raft/src/test/java/org/apache/kafka/raft/KafkaRaftClientTest.java b/raft/src/test/java/org/apache/kafka/raft/KafkaRaftClientTest.java
new file mode 100644
index 0000000..9b2771d
--- /dev/null
+++ b/raft/src/test/java/org/apache/kafka/raft/KafkaRaftClientTest.java
@@ -0,0 +1,2783 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +package org.apache.kafka.raft; + +import org.apache.kafka.common.errors.ClusterAuthorizationException; +import org.apache.kafka.common.errors.RecordBatchTooLargeException; +import org.apache.kafka.common.memory.MemoryPool; +import org.apache.kafka.common.message.BeginQuorumEpochResponseData; +import org.apache.kafka.common.message.DescribeQuorumResponseData.ReplicaState; +import org.apache.kafka.common.message.EndQuorumEpochResponseData; +import org.apache.kafka.common.message.FetchResponseData; +import org.apache.kafka.common.message.VoteResponseData; +import org.apache.kafka.common.metrics.KafkaMetric; +import org.apache.kafka.common.metrics.Metrics; +import org.apache.kafka.common.protocol.Errors; +import org.apache.kafka.common.record.MemoryRecords; +import org.apache.kafka.common.record.MutableRecordBatch; +import org.apache.kafka.common.record.Record; +import org.apache.kafka.common.record.RecordBatch; +import org.apache.kafka.common.record.Records; +import org.apache.kafka.common.requests.DescribeQuorumRequest; +import org.apache.kafka.common.requests.EndQuorumEpochResponse; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.raft.errors.BufferAllocationException; +import org.apache.kafka.raft.errors.NotLeaderException; +import org.apache.kafka.test.TestUtils; +import org.junit.jupiter.api.Test; +import org.mockito.Mockito; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Optional; +import java.util.OptionalInt; +import java.util.OptionalLong; +import java.util.Set; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.TimeoutException; + +import static java.util.Collections.singletonList; +import static org.apache.kafka.raft.RaftClientTestContext.Builder.DEFAULT_ELECTION_TIMEOUT_MS; +import static org.apache.kafka.test.TestUtils.assertFutureThrows; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class KafkaRaftClientTest { + + @Test + public void testInitializeSingleMemberQuorum() throws IOException { + int localId = 0; + RaftClientTestContext context = new RaftClientTestContext.Builder(localId, Collections.singleton(localId)).build(); + context.assertElectedLeader(1, localId); + } + + @Test + public void testInitializeAsLeaderFromStateStoreSingleMemberQuorum() throws Exception { + // Start off as leader. 
We should still bump the epoch after initialization + + int localId = 0; + int initialEpoch = 2; + Set voters = Collections.singleton(localId); + RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters) + .withElectedLeader(initialEpoch, localId) + .build(); + + context.pollUntil(() -> context.log.endOffset().offset == 1L); + assertEquals(1L, context.log.endOffset().offset); + assertEquals(initialEpoch + 1, context.log.lastFetchedEpoch()); + assertEquals(new LeaderAndEpoch(OptionalInt.of(localId), initialEpoch + 1), + context.currentLeaderAndEpoch()); + context.assertElectedLeader(initialEpoch + 1, localId); + } + + @Test + public void testRejectVotesFromSameEpochAfterResigningLeadership() throws Exception { + int localId = 0; + int remoteId = 1; + Set voters = Utils.mkSet(localId, remoteId); + int epoch = 2; + + RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters) + .updateRandom(r -> r.mockNextInt(DEFAULT_ELECTION_TIMEOUT_MS, 0)) + .withElectedLeader(epoch, localId) + .build(); + + assertEquals(0L, context.log.endOffset().offset); + context.assertElectedLeader(epoch, localId); + + // Since we were the leader in epoch 2, we should ensure that we will not vote for any + // other voter in the same epoch, even if it has caught up to the same position. + context.deliverRequest(context.voteRequest(epoch, remoteId, + context.log.lastFetchedEpoch(), context.log.endOffset().offset)); + context.pollUntilResponse(); + context.assertSentVoteResponse(Errors.NONE, epoch, OptionalInt.of(localId), false); + } + + @Test + public void testRejectVotesFromSameEpochAfterResigningCandidacy() throws Exception { + int localId = 0; + int remoteId = 1; + Set voters = Utils.mkSet(localId, remoteId); + int epoch = 2; + + RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters) + .updateRandom(r -> r.mockNextInt(DEFAULT_ELECTION_TIMEOUT_MS, 0)) + .withVotedCandidate(epoch, localId) + .build(); + + assertEquals(0L, context.log.endOffset().offset); + context.assertVotedCandidate(epoch, localId); + + // Since we were the leader in epoch 2, we should ensure that we will not vote for any + // other voter in the same epoch, even if it has caught up to the same position. 
+ context.deliverRequest(context.voteRequest(epoch, remoteId, + context.log.lastFetchedEpoch(), context.log.endOffset().offset)); + context.pollUntilResponse(); + context.assertSentVoteResponse(Errors.NONE, epoch, OptionalInt.empty(), false); + } + + @Test + public void testGrantVotesFromHigherEpochAfterResigningLeadership() throws Exception { + int localId = 0; + int remoteId = 1; + Set voters = Utils.mkSet(localId, remoteId); + int epoch = 2; + + RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters) + .updateRandom(r -> r.mockNextInt(DEFAULT_ELECTION_TIMEOUT_MS, 0)) + .withElectedLeader(epoch, localId) + .build(); + + // Resign from leader, will restart in resigned state + assertTrue(context.client.quorum().isResigned()); + assertEquals(0L, context.log.endOffset().offset); + context.assertElectedLeader(epoch, localId); + + // Send vote request with higher epoch + context.deliverRequest(context.voteRequest(epoch + 1, remoteId, + context.log.lastFetchedEpoch(), context.log.endOffset().offset)); + context.client.poll(); + + // We will first transition to unattached and then grant vote and then transition to voted + assertTrue(context.client.quorum().isVoted()); + context.assertVotedCandidate(epoch + 1, remoteId); + context.assertSentVoteResponse(Errors.NONE, epoch + 1, OptionalInt.empty(), true); + } + + @Test + public void testGrantVotesFromHigherEpochAfterResigningCandidacy() throws Exception { + int localId = 0; + int remoteId = 1; + Set voters = Utils.mkSet(localId, remoteId); + int epoch = 2; + + RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters) + .updateRandom(r -> r.mockNextInt(DEFAULT_ELECTION_TIMEOUT_MS, 0)) + .withVotedCandidate(epoch, localId) + .build(); + + // Resign from candidate, will restart in candidate state + assertTrue(context.client.quorum().isCandidate()); + assertEquals(0L, context.log.endOffset().offset); + context.assertVotedCandidate(epoch, localId); + + // Send vote request with higher epoch + context.deliverRequest(context.voteRequest(epoch + 1, remoteId, + context.log.lastFetchedEpoch(), context.log.endOffset().offset)); + context.client.poll(); + + // We will first transition to unattached and then grant vote and then transition to voted + assertTrue(context.client.quorum().isVoted()); + context.assertVotedCandidate(epoch + 1, remoteId); + context.assertSentVoteResponse(Errors.NONE, epoch + 1, OptionalInt.empty(), true); + } + + @Test + public void testGrantVotesWhenShuttingDown() throws Exception { + int localId = 0; + int remoteId = 1; + Set voters = Utils.mkSet(localId, remoteId); + int epoch = 2; + + RaftClientTestContext context = RaftClientTestContext.initializeAsLeader(localId, voters, epoch); + + // Beginning shutdown + context.client.shutdown(1000); + assertTrue(context.client.isShuttingDown()); + + // Send vote request with higher epoch + context.deliverRequest(context.voteRequest(epoch + 1, remoteId, + context.log.lastFetchedEpoch(), context.log.endOffset().offset)); + context.client.poll(); + + // We will first transition to unattached and then grant vote and then transition to voted + assertTrue(context.client.quorum().isVoted()); + context.assertVotedCandidate(epoch + 1, remoteId); + context.assertSentVoteResponse(Errors.NONE, epoch + 1, OptionalInt.empty(), true); + } + + @Test + public void testInitializeAsResignedAndBecomeCandidate() throws Exception { + int localId = 0; + int remoteId = 1; + Set voters = Utils.mkSet(localId, remoteId); + int epoch = 2; + + RaftClientTestContext 
context = new RaftClientTestContext.Builder(localId, voters) + .updateRandom(r -> r.mockNextInt(DEFAULT_ELECTION_TIMEOUT_MS, 0)) + .withElectedLeader(epoch, localId) + .build(); + + // Resign from leader, will restart in resigned state + assertTrue(context.client.quorum().isResigned()); + assertEquals(0L, context.log.endOffset().offset); + context.assertElectedLeader(epoch, localId); + + // Election timeout + context.time.sleep(context.electionTimeoutMs()); + context.client.poll(); + + // Become candidate in a new epoch + assertTrue(context.client.quorum().isCandidate()); + context.assertVotedCandidate(epoch + 1, localId); + } + + @Test + public void testInitializeAsResignedLeaderFromStateStore() throws Exception { + int localId = 0; + Set voters = Utils.mkSet(localId, 1); + int epoch = 2; + + RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters) + .updateRandom(r -> r.mockNextInt(DEFAULT_ELECTION_TIMEOUT_MS, 0)) + .withElectedLeader(epoch, localId) + .build(); + + // The node will remain elected, but start up in a resigned state + // in which no additional writes are accepted. + assertEquals(0L, context.log.endOffset().offset); + context.assertElectedLeader(epoch, localId); + context.client.poll(); + assertThrows(NotLeaderException.class, () -> context.client.scheduleAppend(epoch, Arrays.asList("a", "b"))); + + context.pollUntilRequest(); + int correlationId = context.assertSentEndQuorumEpochRequest(epoch, 1); + context.deliverResponse(correlationId, 1, context.endEpochResponse(epoch, OptionalInt.of(localId))); + context.client.poll(); + + context.time.sleep(context.electionTimeoutMs()); + context.pollUntilRequest(); + context.assertVotedCandidate(epoch + 1, localId); + context.assertSentVoteRequest(epoch + 1, 0, 0L, 1); + } + + @Test + public void testAppendFailedWithNotLeaderException() throws Exception { + int localId = 0; + Set voters = Utils.mkSet(localId, 1); + int epoch = 2; + + RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters) + .withUnknownLeader(epoch) + .build(); + + assertThrows(NotLeaderException.class, () -> context.client.scheduleAppend(epoch, Arrays.asList("a", "b"))); + } + + @Test + public void testAppendFailedWithBufferAllocationException() throws Exception { + int localId = 0; + int otherNodeId = 1; + Set voters = Utils.mkSet(localId, otherNodeId); + + MemoryPool memoryPool = Mockito.mock(MemoryPool.class); + ByteBuffer leaderBuffer = ByteBuffer.allocate(256); + // Return null when allocation error + Mockito.when(memoryPool.tryAllocate(KafkaRaftClient.MAX_BATCH_SIZE_BYTES)) + .thenReturn(null); + Mockito.when(memoryPool.tryAllocate(256)) + .thenReturn(leaderBuffer); + + RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters) + .withMemoryPool(memoryPool) + .build(); + + context.becomeLeader(); + assertEquals(OptionalInt.of(localId), context.currentLeader()); + int epoch = context.currentEpoch(); + + assertThrows(BufferAllocationException.class, () -> context.client.scheduleAppend(epoch, singletonList("a"))); + } + + @Test + public void testAppendFailedWithFencedEpoch() throws Exception { + int localId = 0; + int otherNodeId = 1; + Set voters = Utils.mkSet(localId, otherNodeId); + + RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters) + .build(); + + context.becomeLeader(); + assertEquals(OptionalInt.of(localId), context.currentLeader()); + int epoch = context.currentEpoch(); + + // Throws IllegalArgumentException on higher epoch + 
assertThrows(IllegalArgumentException.class, () -> context.client.scheduleAppend(epoch + 1, singletonList("a"))); + // Throws NotLeaderException on smaller epoch + assertThrows(NotLeaderException.class, () -> context.client.scheduleAppend(epoch - 1, singletonList("a"))); + } + + @Test + public void testAppendFailedWithRecordBatchTooLargeException() throws Exception { + int localId = 0; + int otherNodeId = 1; + Set voters = Utils.mkSet(localId, otherNodeId); + + RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters) + .build(); + + context.becomeLeader(); + assertEquals(OptionalInt.of(localId), context.currentLeader()); + int epoch = context.currentEpoch(); + + int size = KafkaRaftClient.MAX_BATCH_SIZE_BYTES / 8 + 1; // 8 is the estimate min size of each record + List batchToLarge = new ArrayList<>(size + 1); + for (int i = 0; i < size; i++) + batchToLarge.add("a"); + + assertThrows(RecordBatchTooLargeException.class, () -> context.client.scheduleAtomicAppend(epoch, batchToLarge)); + } + + @Test + public void testEndQuorumEpochRetriesWhileResigned() throws Exception { + int localId = 0; + int voter1 = 1; + int voter2 = 2; + Set voters = Utils.mkSet(localId, voter1, voter2); + int epoch = 19; + + // Start off as leader so that we will initialize in the Resigned state. + // Note that we intentionally set a request timeout which is smaller than + // the election timeout so that we can still in the Resigned state and + // verify retry behavior. + RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters) + .withElectionTimeoutMs(10000) + .withRequestTimeoutMs(5000) + .withElectedLeader(epoch, localId) + .build(); + + context.pollUntilRequest(); + List requests = context.collectEndQuorumRequests( + epoch, Utils.mkSet(voter1, voter2), Optional.empty()); + assertEquals(2, requests.size()); + + // Respond to one of the requests so that we can verify that no additional + // request to this node is sent. + RaftRequest.Outbound endEpochOutbound = requests.get(0); + context.deliverResponse(endEpochOutbound.correlationId, endEpochOutbound.destinationId(), + context.endEpochResponse(epoch, OptionalInt.of(localId))); + context.client.poll(); + assertEquals(Collections.emptyList(), context.channel.drainSendQueue()); + + // Now sleep for the request timeout and verify that we get only one + // retried request from the voter that hasn't responded yet. 
+ int nonRespondedId = requests.get(1).destinationId(); + context.time.sleep(6000); + context.pollUntilRequest(); + List retries = context.collectEndQuorumRequests( + epoch, Utils.mkSet(nonRespondedId), Optional.empty()); + assertEquals(1, retries.size()); + } + + @Test + public void testResignWillCompleteFetchPurgatory() throws Exception { + int localId = 0; + int otherNodeId = 1; + Set voters = Utils.mkSet(localId, otherNodeId); + + RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters) + .build(); + + context.becomeLeader(); + assertEquals(OptionalInt.of(localId), context.currentLeader()); + + // send fetch request when become leader + int epoch = context.currentEpoch(); + context.deliverRequest(context.fetchRequest(epoch, otherNodeId, context.log.endOffset().offset, epoch, 1000)); + context.client.poll(); + + // append some record, but the fetch in purgatory will still fail + context.log.appendAsLeader( + context.buildBatch(context.log.endOffset().offset, epoch, singletonList("raft")), + epoch + ); + + // when transition to resign, all request in fetchPurgatory will fail + context.client.shutdown(1000); + context.client.poll(); + context.assertSentFetchPartitionResponse(Errors.NOT_LEADER_OR_FOLLOWER, epoch, OptionalInt.of(localId)); + context.assertResignedLeader(epoch, localId); + + // shutting down finished + context.time.sleep(1000); + context.client.poll(); + assertFalse(context.client.isRunning()); + assertFalse(context.client.isShuttingDown()); + } + + @Test + public void testResignInOlderEpochIgnored() throws Exception { + int localId = 0; + int otherNodeId = 1; + Set voters = Utils.mkSet(localId, otherNodeId); + + RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters).build(); + + context.becomeLeader(); + assertEquals(OptionalInt.of(localId), context.currentLeader()); + + int currentEpoch = context.currentEpoch(); + context.client.resign(currentEpoch - 1); + context.client.poll(); + + // Ensure we are still leader even after expiration of the election timeout. 
+ context.time.sleep(context.electionTimeoutMs() * 2); + context.client.poll(); + context.assertElectedLeader(currentEpoch, localId); + } + + @Test + public void testHandleBeginQuorumEpochAfterUserInitiatedResign() throws Exception { + int localId = 0; + int remoteId1 = 1; + int remoteId2 = 2; + Set voters = Utils.mkSet(localId, remoteId1, remoteId2); + + RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters).build(); + + context.becomeLeader(); + assertEquals(OptionalInt.of(localId), context.currentLeader()); + + int resignedEpoch = context.currentEpoch(); + + context.client.resign(resignedEpoch); + context.pollUntil(context.client.quorum()::isResigned); + + context.deliverRequest(context.beginEpochRequest(resignedEpoch + 1, remoteId1)); + context.pollUntilResponse(); + context.assertSentBeginQuorumEpochResponse(Errors.NONE); + context.assertElectedLeader(resignedEpoch + 1, remoteId1); + assertEquals(new LeaderAndEpoch(OptionalInt.of(remoteId1), resignedEpoch + 1), + context.listener.currentLeaderAndEpoch()); + } + + @Test + public void testElectionTimeoutAfterUserInitiatedResign() throws Exception { + int localId = 0; + int otherNodeId = 1; + Set voters = Utils.mkSet(localId, otherNodeId); + + RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters).build(); + + context.becomeLeader(); + assertEquals(OptionalInt.of(localId), context.currentLeader()); + + int resignedEpoch = context.currentEpoch(); + + context.client.resign(resignedEpoch); + context.pollUntil(context.client.quorum()::isResigned); + + context.pollUntilRequest(); + int correlationId = context.assertSentEndQuorumEpochRequest(resignedEpoch, otherNodeId); + + EndQuorumEpochResponseData response = EndQuorumEpochResponse.singletonResponse( + Errors.NONE, + context.metadataPartition, + Errors.NONE, + resignedEpoch, + localId + ); + + context.deliverResponse(correlationId, otherNodeId, response); + context.client.poll(); + + // We do not resend `EndQuorumRequest` once the other voter has acknowledged it. + context.time.sleep(context.retryBackoffMs); + context.client.poll(); + assertFalse(context.channel.hasSentRequests()); + + // Any `Fetch` received in the resigned state should result in a NOT_LEADER error. + context.deliverRequest(context.fetchRequest(1, -1, 0, 0, 0)); + context.pollUntilResponse(); + context.assertSentFetchPartitionResponse(Errors.NOT_LEADER_OR_FOLLOWER, + resignedEpoch, OptionalInt.of(localId)); + + // After the election timer, we should become a candidate. 
+ context.time.sleep(2 * context.electionTimeoutMs()); + context.pollUntil(context.client.quorum()::isCandidate); + assertEquals(resignedEpoch + 1, context.currentEpoch()); + assertEquals(new LeaderAndEpoch(OptionalInt.empty(), resignedEpoch + 1), + context.listener.currentLeaderAndEpoch()); + } + + @Test + public void testCannotResignWithLargerEpochThanCurrentEpoch() throws Exception { + int localId = 0; + int otherNodeId = 1; + Set voters = Utils.mkSet(localId, otherNodeId); + + RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters).build(); + context.becomeLeader(); + + assertThrows(IllegalArgumentException.class, + () -> context.client.resign(context.currentEpoch() + 1)); + } + + @Test + public void testCannotResignIfNotLeader() throws Exception { + int localId = 0; + int otherNodeId = 1; + int leaderEpoch = 2; + Set voters = Utils.mkSet(localId, otherNodeId); + + RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters) + .withElectedLeader(leaderEpoch, otherNodeId) + .build(); + + assertEquals(OptionalInt.of(otherNodeId), context.currentLeader()); + assertThrows(IllegalArgumentException.class, () -> context.client.resign(leaderEpoch)); + } + + @Test + public void testCannotResignIfObserver() throws Exception { + int leaderId = 1; + int otherNodeId = 2; + int epoch = 5; + Set voters = Utils.mkSet(leaderId, otherNodeId); + + RaftClientTestContext context = new RaftClientTestContext.Builder(OptionalInt.empty(), voters).build(); + context.pollUntilRequest(); + + RaftRequest.Outbound fetchRequest = context.assertSentFetchRequest(); + assertTrue(voters.contains(fetchRequest.destinationId())); + context.assertFetchRequestData(fetchRequest, 0, 0L, 0); + + context.deliverResponse(fetchRequest.correlationId, fetchRequest.destinationId(), + context.fetchResponse(epoch, leaderId, MemoryRecords.EMPTY, 0L, Errors.FENCED_LEADER_EPOCH)); + + context.client.poll(); + context.assertElectedLeader(epoch, leaderId); + assertThrows(IllegalStateException.class, () -> context.client.resign(epoch)); + } + + @Test + public void testInitializeAsCandidateFromStateStore() throws Exception { + int localId = 0; + // Need 3 node to require a 2-node majority + Set voters = Utils.mkSet(localId, 1, 2); + + RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters) + .withVotedCandidate(2, localId) + .build(); + context.assertVotedCandidate(2, localId); + assertEquals(0L, context.log.endOffset().offset); + + // The candidate will resume the election after reinitialization + context.pollUntilRequest(); + List voteRequests = context.collectVoteRequests(2, 0, 0); + assertEquals(2, voteRequests.size()); + } + + @Test + public void testInitializeAsCandidateAndBecomeLeader() throws Exception { + int localId = 0; + final int otherNodeId = 1; + Set voters = Utils.mkSet(localId, otherNodeId); + RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters).build(); + + context.assertUnknownLeader(0); + context.time.sleep(2 * context.electionTimeoutMs()); + + context.pollUntilRequest(); + context.assertVotedCandidate(1, localId); + + int correlationId = context.assertSentVoteRequest(1, 0, 0L, 1); + context.deliverResponse(correlationId, otherNodeId, context.voteResponse(true, Optional.empty(), 1)); + + // Become leader after receiving the vote + context.pollUntil(() -> context.log.endOffset().offset == 1L); + context.assertElectedLeader(1, localId); + long electionTimestamp = context.time.milliseconds(); + + // Leader change record 
appended + assertEquals(1L, context.log.endOffset().offset); + assertEquals(1L, context.log.lastFlushedOffset()); + + // Send BeginQuorumEpoch to voters + context.client.poll(); + context.assertSentBeginQuorumEpochRequest(1, 1); + + Records records = context.log.read(0, Isolation.UNCOMMITTED).records; + RecordBatch batch = records.batches().iterator().next(); + assertTrue(batch.isControlBatch()); + + Record record = batch.iterator().next(); + assertEquals(electionTimestamp, record.timestamp()); + RaftClientTestContext.verifyLeaderChangeMessage(localId, Arrays.asList(localId, otherNodeId), + Arrays.asList(otherNodeId, localId), record.key(), record.value()); + } + + @Test + public void testInitializeAsCandidateAndBecomeLeaderQuorumOfThree() throws Exception { + int localId = 0; + final int firstNodeId = 1; + final int secondNodeId = 2; + Set voters = Utils.mkSet(localId, firstNodeId, secondNodeId); + RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters).build(); + + context.assertUnknownLeader(0); + context.time.sleep(2 * context.electionTimeoutMs()); + + context.pollUntilRequest(); + context.assertVotedCandidate(1, localId); + + int correlationId = context.assertSentVoteRequest(1, 0, 0L, 2); + context.deliverResponse(correlationId, firstNodeId, context.voteResponse(true, Optional.empty(), 1)); + + // Become leader after receiving the vote + context.pollUntil(() -> context.log.endOffset().offset == 1L); + context.assertElectedLeader(1, localId); + long electionTimestamp = context.time.milliseconds(); + + // Leader change record appended + assertEquals(1L, context.log.endOffset().offset); + assertEquals(1L, context.log.lastFlushedOffset()); + + // Send BeginQuorumEpoch to voters + context.client.poll(); + context.assertSentBeginQuorumEpochRequest(1, 2); + + Records records = context.log.read(0, Isolation.UNCOMMITTED).records; + RecordBatch batch = records.batches().iterator().next(); + assertTrue(batch.isControlBatch()); + + Record record = batch.iterator().next(); + assertEquals(electionTimestamp, record.timestamp()); + RaftClientTestContext.verifyLeaderChangeMessage(localId, Arrays.asList(localId, firstNodeId, secondNodeId), + Arrays.asList(firstNodeId, localId), record.key(), record.value()); + } + + @Test + public void testHandleBeginQuorumRequest() throws Exception { + int localId = 0; + int otherNodeId = 1; + int votedCandidateEpoch = 2; + Set voters = Utils.mkSet(localId, otherNodeId); + + RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters) + .withVotedCandidate(votedCandidateEpoch, otherNodeId) + .build(); + + context.deliverRequest(context.beginEpochRequest(votedCandidateEpoch, otherNodeId)); + context.pollUntilResponse(); + + context.assertElectedLeader(votedCandidateEpoch, otherNodeId); + + context.assertSentBeginQuorumEpochResponse(Errors.NONE, votedCandidateEpoch, OptionalInt.of(otherNodeId)); + } + + @Test + public void testHandleBeginQuorumResponse() throws Exception { + int localId = 0; + int otherNodeId = 1; + int leaderEpoch = 2; + Set voters = Utils.mkSet(localId, otherNodeId); + + RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters) + .withElectedLeader(leaderEpoch, localId) + .build(); + + context.deliverRequest(context.beginEpochRequest(leaderEpoch + 1, otherNodeId)); + context.pollUntilResponse(); + + context.assertElectedLeader(leaderEpoch + 1, otherNodeId); + } + + @Test + public void testEndQuorumIgnoredAsCandidateIfOlderEpoch() throws Exception { + int localId = 0; + int 
otherNodeId = 1; + int epoch = 5; + int jitterMs = 85; + Set voters = Utils.mkSet(localId, otherNodeId); + + RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters) + .updateRandom(r -> r.mockNextInt(jitterMs)) + .withUnknownLeader(epoch - 1) + .build(); + + // Sleep a little to ensure that we become a candidate + context.time.sleep(context.electionTimeoutMs() + jitterMs); + context.client.poll(); + context.assertVotedCandidate(epoch, localId); + + context.deliverRequest(context.endEpochRequest(epoch - 2, otherNodeId, + Collections.singletonList(localId))); + + context.client.poll(); + context.assertSentEndQuorumEpochResponse(Errors.FENCED_LEADER_EPOCH, epoch, OptionalInt.empty()); + + // We should still be candidate until expiration of election timeout + context.time.sleep(context.electionTimeoutMs() + jitterMs - 1); + context.client.poll(); + context.assertVotedCandidate(epoch, localId); + + // Enter the backoff period + context.time.sleep(1); + context.client.poll(); + context.assertVotedCandidate(epoch, localId); + + // After backoff, we will become a candidate again + context.time.sleep(context.electionBackoffMaxMs); + context.client.poll(); + context.assertVotedCandidate(epoch + 1, localId); + } + + @Test + public void testEndQuorumIgnoredAsLeaderIfOlderEpoch() throws Exception { + int localId = 0; + int voter2 = localId + 1; + int voter3 = localId + 2; + int epoch = 7; + Set voters = Utils.mkSet(localId, voter2, voter3); + + RaftClientTestContext context = RaftClientTestContext.initializeAsLeader(localId, voters, epoch); + + // One of the voters may have sent EndQuorumEpoch from an earlier epoch + context.deliverRequest(context.endEpochRequest(epoch - 2, voter2, Arrays.asList(localId, voter3))); + + context.pollUntilResponse(); + context.assertSentEndQuorumEpochResponse(Errors.FENCED_LEADER_EPOCH, epoch, OptionalInt.of(localId)); + + // We should still be leader as long as fetch timeout has not expired + context.time.sleep(context.fetchTimeoutMs - 1); + context.client.poll(); + context.assertElectedLeader(epoch, localId); + } + + @Test + public void testEndQuorumStartsNewElectionImmediatelyIfFollowerUnattached() throws Exception { + int localId = 0; + int voter2 = localId + 1; + int voter3 = localId + 2; + int epoch = 2; + Set voters = Utils.mkSet(localId, voter2, voter3); + + RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters) + .withUnknownLeader(epoch) + .build(); + + context.deliverRequest(context.endEpochRequest(epoch, voter2, + Arrays.asList(localId, voter3))); + + context.pollUntilResponse(); + context.assertSentEndQuorumEpochResponse(Errors.NONE, epoch, OptionalInt.of(voter2)); + + // Should become a candidate immediately + context.client.poll(); + context.assertVotedCandidate(epoch + 1, localId); + } + + @Test + public void testAccumulatorClearedAfterBecomingFollower() throws Exception { + int localId = 0; + int otherNodeId = 1; + int lingerMs = 50; + Set voters = Utils.mkSet(localId, otherNodeId); + + MemoryPool memoryPool = Mockito.mock(MemoryPool.class); + ByteBuffer buffer = ByteBuffer.allocate(KafkaRaftClient.MAX_BATCH_SIZE_BYTES); + ByteBuffer leaderBuffer = ByteBuffer.allocate(256); + Mockito.when(memoryPool.tryAllocate(KafkaRaftClient.MAX_BATCH_SIZE_BYTES)) + .thenReturn(buffer); + Mockito.when(memoryPool.tryAllocate(256)) + .thenReturn(leaderBuffer); + + RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters) + .withAppendLingerMs(lingerMs) + .withMemoryPool(memoryPool) + .build(); 
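+        // The mocked MemoryPool above hands the accumulator a dedicated batch buffer, which lets
+        // the test verify below (Mockito.verify(memoryPool).release(buffer)) that the buffer is
+        // released back to the pool once this node steps down to follower.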
+ + context.becomeLeader(); + assertEquals(OptionalInt.of(localId), context.currentLeader()); + int epoch = context.currentEpoch(); + + assertEquals(1L, context.client.scheduleAppend(epoch, singletonList("a"))); + context.deliverRequest(context.beginEpochRequest(epoch + 1, otherNodeId)); + context.pollUntilResponse(); + + context.assertElectedLeader(epoch + 1, otherNodeId); + Mockito.verify(memoryPool).release(buffer); + } + + @Test + public void testAccumulatorClearedAfterBecomingVoted() throws Exception { + int localId = 0; + int otherNodeId = 1; + int lingerMs = 50; + Set voters = Utils.mkSet(localId, otherNodeId); + + MemoryPool memoryPool = Mockito.mock(MemoryPool.class); + ByteBuffer buffer = ByteBuffer.allocate(KafkaRaftClient.MAX_BATCH_SIZE_BYTES); + ByteBuffer leaderBuffer = ByteBuffer.allocate(256); + Mockito.when(memoryPool.tryAllocate(KafkaRaftClient.MAX_BATCH_SIZE_BYTES)) + .thenReturn(buffer); + Mockito.when(memoryPool.tryAllocate(256)) + .thenReturn(leaderBuffer); + + RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters) + .withAppendLingerMs(lingerMs) + .withMemoryPool(memoryPool) + .build(); + + context.becomeLeader(); + assertEquals(OptionalInt.of(localId), context.currentLeader()); + int epoch = context.currentEpoch(); + + assertEquals(1L, context.client.scheduleAppend(epoch, singletonList("a"))); + context.deliverRequest(context.voteRequest(epoch + 1, otherNodeId, epoch, + context.log.endOffset().offset)); + context.pollUntilResponse(); + + context.assertVotedCandidate(epoch + 1, otherNodeId); + Mockito.verify(memoryPool).release(buffer); + } + + @Test + public void testAccumulatorClearedAfterBecomingUnattached() throws Exception { + int localId = 0; + int otherNodeId = 1; + int lingerMs = 50; + Set voters = Utils.mkSet(localId, otherNodeId); + + MemoryPool memoryPool = Mockito.mock(MemoryPool.class); + ByteBuffer buffer = ByteBuffer.allocate(KafkaRaftClient.MAX_BATCH_SIZE_BYTES); + ByteBuffer leaderBuffer = ByteBuffer.allocate(256); + Mockito.when(memoryPool.tryAllocate(KafkaRaftClient.MAX_BATCH_SIZE_BYTES)) + .thenReturn(buffer); + Mockito.when(memoryPool.tryAllocate(256)) + .thenReturn(leaderBuffer); + + RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters) + .withAppendLingerMs(lingerMs) + .withMemoryPool(memoryPool) + .build(); + + context.becomeLeader(); + assertEquals(OptionalInt.of(localId), context.currentLeader()); + int epoch = context.currentEpoch(); + + assertEquals(1L, context.client.scheduleAppend(epoch, singletonList("a"))); + context.deliverRequest(context.voteRequest(epoch + 1, otherNodeId, epoch, 0L)); + context.pollUntilResponse(); + + context.assertUnknownLeader(epoch + 1); + Mockito.verify(memoryPool).release(buffer); + } + + @Test + public void testChannelWokenUpIfLingerTimeoutReachedWithoutAppend() throws Exception { + // This test verifies that the client will set its poll timeout accounting + // for the lingerMs of a pending append + + int localId = 0; + int otherNodeId = 1; + int lingerMs = 50; + Set voters = Utils.mkSet(localId, otherNodeId); + + RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters) + .withAppendLingerMs(lingerMs) + .build(); + + context.becomeLeader(); + assertEquals(OptionalInt.of(localId), context.currentLeader()); + assertEquals(1L, context.log.endOffset().offset); + + int epoch = context.currentEpoch(); + assertEquals(1L, context.client.scheduleAppend(epoch, singletonList("a"))); + assertTrue(context.messageQueue.wakeupRequested()); + 
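+        // Each subsequent poll should shrink the poll timeout to the time remaining in the linger
+        // window (50ms at first, then 30ms once 20ms have elapsed) until the batch is drained.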
+ context.client.poll(); + assertEquals(OptionalLong.of(lingerMs), context.messageQueue.lastPollTimeoutMs()); + + context.time.sleep(20); + context.client.poll(); + assertEquals(OptionalLong.of(30), context.messageQueue.lastPollTimeoutMs()); + + context.time.sleep(30); + context.client.poll(); + assertEquals(2L, context.log.endOffset().offset); + } + + @Test + public void testChannelWokenUpIfLingerTimeoutReachedDuringAppend() throws Exception { + // This test verifies that the client will get woken up immediately + // if the linger timeout has expired during an append + + int localId = 0; + int otherNodeId = 1; + int lingerMs = 50; + Set voters = Utils.mkSet(localId, otherNodeId); + + RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters) + .withAppendLingerMs(lingerMs) + .build(); + + context.becomeLeader(); + assertEquals(OptionalInt.of(localId), context.currentLeader()); + assertEquals(1L, context.log.endOffset().offset); + + int epoch = context.currentEpoch(); + assertEquals(1L, context.client.scheduleAppend(epoch, singletonList("a"))); + assertTrue(context.messageQueue.wakeupRequested()); + + context.client.poll(); + assertFalse(context.messageQueue.wakeupRequested()); + assertEquals(OptionalLong.of(lingerMs), context.messageQueue.lastPollTimeoutMs()); + + context.time.sleep(lingerMs); + assertEquals(2L, context.client.scheduleAppend(epoch, singletonList("b"))); + assertTrue(context.messageQueue.wakeupRequested()); + + context.client.poll(); + assertEquals(3L, context.log.endOffset().offset); + } + + @Test + public void testHandleEndQuorumRequest() throws Exception { + int localId = 0; + int oldLeaderId = 1; + int leaderEpoch = 2; + Set voters = Utils.mkSet(localId, oldLeaderId); + + RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters) + .withElectedLeader(leaderEpoch, oldLeaderId) + .build(); + + context.deliverRequest(context.endEpochRequest(leaderEpoch, oldLeaderId, + Collections.singletonList(localId))); + + context.pollUntilResponse(); + context.assertSentEndQuorumEpochResponse(Errors.NONE, leaderEpoch, OptionalInt.of(oldLeaderId)); + + context.client.poll(); + context.assertVotedCandidate(leaderEpoch + 1, localId); + } + + @Test + public void testHandleEndQuorumRequestWithLowerPriorityToBecomeLeader() throws Exception { + int localId = 0; + int oldLeaderId = 1; + int leaderEpoch = 2; + int preferredNextLeader = 3; + Set voters = Utils.mkSet(localId, oldLeaderId, preferredNextLeader); + + RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters) + .withElectedLeader(leaderEpoch, oldLeaderId) + .build(); + + context.deliverRequest(context.endEpochRequest(leaderEpoch, oldLeaderId, + Arrays.asList(preferredNextLeader, localId))); + + context.pollUntilResponse(); + context.assertSentEndQuorumEpochResponse(Errors.NONE, leaderEpoch, OptionalInt.of(oldLeaderId)); + + // The election won't trigger by one round retry backoff + context.time.sleep(1); + + context.pollUntilRequest(); + + context.assertSentFetchRequest(leaderEpoch, 0, 0); + + context.time.sleep(context.retryBackoffMs); + + context.pollUntilRequest(); + + List voteRequests = context.collectVoteRequests(leaderEpoch + 1, 0, 0); + assertEquals(2, voteRequests.size()); + + // Should have already done self-voting + context.assertVotedCandidate(leaderEpoch + 1, localId); + } + + @Test + public void testVoteRequestTimeout() throws Exception { + int localId = 0; + int epoch = 1; + int otherNodeId = 1; + Set voters = Utils.mkSet(localId, otherNodeId); + 
+ RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters).build(); + context.assertUnknownLeader(0); + + context.time.sleep(2 * context.electionTimeoutMs()); + context.pollUntilRequest(); + context.assertVotedCandidate(epoch, localId); + + int correlationId = context.assertSentVoteRequest(epoch, 0, 0L, 1); + + context.time.sleep(context.requestTimeoutMs()); + context.client.poll(); + int retryCorrelationId = context.assertSentVoteRequest(epoch, 0, 0L, 1); + + // We will ignore the timed out response if it arrives late + context.deliverResponse(correlationId, otherNodeId, context.voteResponse(true, Optional.empty(), 1)); + context.client.poll(); + context.assertVotedCandidate(epoch, localId); + + // Become leader after receiving the retry response + context.deliverResponse(retryCorrelationId, otherNodeId, context.voteResponse(true, Optional.empty(), 1)); + context.client.poll(); + context.assertElectedLeader(epoch, localId); + } + + @Test + public void testHandleValidVoteRequestAsFollower() throws Exception { + int localId = 0; + int epoch = 2; + int otherNodeId = 1; + Set voters = Utils.mkSet(localId, otherNodeId); + + RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters) + .withUnknownLeader(epoch) + .build(); + + context.deliverRequest(context.voteRequest(epoch, otherNodeId, epoch - 1, 1)); + context.pollUntilResponse(); + + context.assertSentVoteResponse(Errors.NONE, epoch, OptionalInt.empty(), true); + + context.assertVotedCandidate(epoch, otherNodeId); + } + + @Test + public void testHandleVoteRequestAsFollowerWithElectedLeader() throws Exception { + int localId = 0; + int epoch = 2; + int otherNodeId = 1; + int electedLeaderId = 3; + Set voters = Utils.mkSet(localId, otherNodeId, electedLeaderId); + + RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters) + .withElectedLeader(epoch, electedLeaderId) + .build(); + + context.deliverRequest(context.voteRequest(epoch, otherNodeId, epoch - 1, 1)); + context.pollUntilResponse(); + + context.assertSentVoteResponse(Errors.NONE, epoch, OptionalInt.of(electedLeaderId), false); + + context.assertElectedLeader(epoch, electedLeaderId); + } + + @Test + public void testHandleVoteRequestAsFollowerWithVotedCandidate() throws Exception { + int localId = 0; + int epoch = 2; + int otherNodeId = 1; + int votedCandidateId = 3; + Set voters = Utils.mkSet(localId, otherNodeId, votedCandidateId); + + RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters) + .withVotedCandidate(epoch, votedCandidateId) + .build(); + + context.deliverRequest(context.voteRequest(epoch, otherNodeId, epoch - 1, 1)); + context.pollUntilResponse(); + + context.assertSentVoteResponse(Errors.NONE, epoch, OptionalInt.empty(), false); + context.assertVotedCandidate(epoch, votedCandidateId); + } + + @Test + public void testHandleInvalidVoteRequestWithOlderEpoch() throws Exception { + int localId = 0; + int epoch = 2; + int otherNodeId = 1; + Set voters = Utils.mkSet(localId, otherNodeId); + + RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters) + .withUnknownLeader(epoch) + .build(); + + context.deliverRequest(context.voteRequest(epoch - 1, otherNodeId, epoch - 2, 1)); + context.pollUntilResponse(); + + context.assertSentVoteResponse(Errors.FENCED_LEADER_EPOCH, epoch, OptionalInt.empty(), false); + context.assertUnknownLeader(epoch); + } + + @Test + public void testHandleInvalidVoteRequestAsObserver() throws Exception { + int localId = 0; + int 
epoch = 2; + int otherNodeId = 1; + int otherNodeId2 = 2; + Set voters = Utils.mkSet(otherNodeId, otherNodeId2); + + RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters) + .withUnknownLeader(epoch) + .build(); + + context.deliverRequest(context.voteRequest(epoch + 1, otherNodeId, epoch, 1)); + context.pollUntilResponse(); + + context.assertSentVoteResponse(Errors.INCONSISTENT_VOTER_SET, epoch, OptionalInt.empty(), false); + context.assertUnknownLeader(epoch); + } + + @Test + public void testLeaderIgnoreVoteRequestOnSameEpoch() throws Exception { + int localId = 0; + int otherNodeId = 1; + int leaderEpoch = 2; + Set voters = Utils.mkSet(localId, otherNodeId); + + RaftClientTestContext context = RaftClientTestContext.initializeAsLeader(localId, voters, leaderEpoch); + + context.deliverRequest(context.voteRequest(leaderEpoch, otherNodeId, leaderEpoch - 1, 1)); + + context.client.poll(); + + context.assertSentVoteResponse(Errors.NONE, leaderEpoch, OptionalInt.of(localId), false); + context.assertElectedLeader(leaderEpoch, localId); + } + + @Test + public void testListenerCommitCallbackAfterLeaderWrite() throws Exception { + int localId = 0; + int otherNodeId = 1; + int epoch = 5; + Set voters = Utils.mkSet(localId, otherNodeId); + + RaftClientTestContext context = RaftClientTestContext.initializeAsLeader(localId, voters, epoch); + + // First poll has no high watermark advance + context.client.poll(); + assertEquals(OptionalLong.empty(), context.client.highWatermark()); + assertEquals(1L, context.log.endOffset().offset); + + // Let follower send a fetch to initialize the high watermark, + // note the offset 0 would be a control message for becoming the leader + context.deliverRequest(context.fetchRequest(epoch, otherNodeId, 1L, epoch, 0)); + context.pollUntilResponse(); + context.assertSentFetchPartitionResponse(Errors.NONE, epoch, OptionalInt.of(localId)); + assertEquals(OptionalLong.of(1L), context.client.highWatermark()); + + List records = Arrays.asList("a", "b", "c"); + long offset = context.client.scheduleAppend(epoch, records); + context.client.poll(); + assertEquals(OptionalLong.empty(), context.listener.lastCommitOffset()); + + // Let the follower send a fetch, it should advance the high watermark + context.deliverRequest(context.fetchRequest(epoch, otherNodeId, 1L, epoch, 500)); + context.pollUntilResponse(); + context.assertSentFetchPartitionResponse(Errors.NONE, epoch, OptionalInt.of(localId)); + assertEquals(OptionalLong.of(1L), context.client.highWatermark()); + assertEquals(OptionalLong.empty(), context.listener.lastCommitOffset()); + + // Let the follower send another fetch from offset 4 + context.deliverRequest(context.fetchRequest(epoch, otherNodeId, 4L, epoch, 500)); + context.pollUntil(() -> context.client.highWatermark().equals(OptionalLong.of(4L))); + assertEquals(records, context.listener.commitWithLastOffset(offset)); + } + + @Test + public void testCandidateIgnoreVoteRequestOnSameEpoch() throws Exception { + int localId = 0; + int otherNodeId = 1; + int leaderEpoch = 2; + Set voters = Utils.mkSet(localId, otherNodeId); + + RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters) + .withVotedCandidate(leaderEpoch, localId) + .build(); + + context.pollUntilRequest(); + + context.deliverRequest(context.voteRequest(leaderEpoch, otherNodeId, leaderEpoch - 1, 1)); + context.client.poll(); + context.assertSentVoteResponse(Errors.NONE, leaderEpoch, OptionalInt.empty(), false); + context.assertVotedCandidate(leaderEpoch, 
localId); + } + + @Test + public void testRetryElection() throws Exception { + int localId = 0; + int otherNodeId = 1; + int epoch = 1; + int exponentialFactor = 85; // set it large enough so that we will bound on jitter + Set voters = Utils.mkSet(localId, otherNodeId); + + RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters) + .updateRandom(r -> r.mockNextInt(exponentialFactor)) + .build(); + + context.assertUnknownLeader(0); + + context.time.sleep(2 * context.electionTimeoutMs()); + context.pollUntilRequest(); + context.assertVotedCandidate(epoch, localId); + + // Quorum size is two. If the other member rejects, then we need to schedule a revote. + int correlationId = context.assertSentVoteRequest(epoch, 0, 0L, 1); + context.deliverResponse(correlationId, otherNodeId, context.voteResponse(false, Optional.empty(), 1)); + + context.client.poll(); + + // All nodes have rejected our candidacy, but we should still remember that we had voted + context.assertVotedCandidate(epoch, localId); + + // Even though our candidacy was rejected, we will backoff for jitter period + // before we bump the epoch and start a new election. + context.time.sleep(context.electionBackoffMaxMs - 1); + context.client.poll(); + context.assertVotedCandidate(epoch, localId); + + // After jitter expires, we become a candidate again + context.time.sleep(1); + context.client.poll(); + context.pollUntilRequest(); + context.assertVotedCandidate(epoch + 1, localId); + context.assertSentVoteRequest(epoch + 1, 0, 0L, 1); + } + + @Test + public void testInitializeAsFollowerEmptyLog() throws Exception { + int localId = 0; + int otherNodeId = 1; + int epoch = 5; + Set voters = Utils.mkSet(localId, otherNodeId); + + RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters) + .withElectedLeader(epoch, otherNodeId) + .build(); + + context.assertElectedLeader(epoch, otherNodeId); + + context.pollUntilRequest(); + + context.assertSentFetchRequest(epoch, 0L, 0); + } + + @Test + public void testInitializeAsFollowerNonEmptyLog() throws Exception { + int localId = 0; + int otherNodeId = 1; + int epoch = 5; + int lastEpoch = 3; + Set voters = Utils.mkSet(localId, otherNodeId); + + RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters) + .withElectedLeader(epoch, otherNodeId) + .appendToLog(lastEpoch, singletonList("foo")) + .build(); + + context.assertElectedLeader(epoch, otherNodeId); + + context.pollUntilRequest(); + context.assertSentFetchRequest(epoch, 1L, lastEpoch); + } + + @Test + public void testVoterBecomeCandidateAfterFetchTimeout() throws Exception { + int localId = 0; + int otherNodeId = 1; + int epoch = 5; + int lastEpoch = 3; + Set voters = Utils.mkSet(localId, otherNodeId); + + RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters) + .withElectedLeader(epoch, otherNodeId) + .appendToLog(lastEpoch, singletonList("foo")) + .build(); + context.assertElectedLeader(epoch, otherNodeId); + + context.pollUntilRequest(); + context.assertSentFetchRequest(epoch, 1L, lastEpoch); + + context.time.sleep(context.fetchTimeoutMs); + + context.pollUntilRequest(); + + context.assertSentVoteRequest(epoch + 1, lastEpoch, 1L, 1); + context.assertVotedCandidate(epoch + 1, localId); + } + + @Test + public void testInitializeObserverNoPreviousState() throws Exception { + int localId = 0; + int leaderId = 1; + int otherNodeId = 2; + int epoch = 5; + Set voters = Utils.mkSet(leaderId, otherNodeId); + + RaftClientTestContext context = new 
RaftClientTestContext.Builder(localId, voters).build(); + + context.pollUntilRequest(); + RaftRequest.Outbound fetchRequest = context.assertSentFetchRequest(); + assertTrue(voters.contains(fetchRequest.destinationId())); + context.assertFetchRequestData(fetchRequest, 0, 0L, 0); + + context.deliverResponse(fetchRequest.correlationId, fetchRequest.destinationId(), + context.fetchResponse(epoch, leaderId, MemoryRecords.EMPTY, 0L, Errors.FENCED_LEADER_EPOCH)); + + context.client.poll(); + context.assertElectedLeader(epoch, leaderId); + } + + @Test + public void testObserverQuorumDiscoveryFailure() throws Exception { + int localId = 0; + int leaderId = 1; + int epoch = 5; + Set voters = Utils.mkSet(leaderId); + + RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters).build(); + + context.pollUntilRequest(); + RaftRequest.Outbound fetchRequest = context.assertSentFetchRequest(); + assertTrue(voters.contains(fetchRequest.destinationId())); + context.assertFetchRequestData(fetchRequest, 0, 0L, 0); + + context.deliverResponse(fetchRequest.correlationId, fetchRequest.destinationId(), + context.fetchResponse(-1, -1, MemoryRecords.EMPTY, -1, Errors.UNKNOWN_SERVER_ERROR)); + context.client.poll(); + + context.time.sleep(context.retryBackoffMs); + context.pollUntilRequest(); + + fetchRequest = context.assertSentFetchRequest(); + assertTrue(voters.contains(fetchRequest.destinationId())); + context.assertFetchRequestData(fetchRequest, 0, 0L, 0); + + context.deliverResponse(fetchRequest.correlationId, fetchRequest.destinationId(), + context.fetchResponse(epoch, leaderId, MemoryRecords.EMPTY, 0L, Errors.FENCED_LEADER_EPOCH)); + context.client.poll(); + + context.assertElectedLeader(epoch, leaderId); + } + + @Test + public void testObserverSendDiscoveryFetchAfterFetchTimeout() throws Exception { + int localId = 0; + int leaderId = 1; + int otherNodeId = 2; + int epoch = 5; + Set voters = Utils.mkSet(leaderId, otherNodeId); + + RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters).build(); + + context.pollUntilRequest(); + RaftRequest.Outbound fetchRequest = context.assertSentFetchRequest(); + assertTrue(voters.contains(fetchRequest.destinationId())); + context.assertFetchRequestData(fetchRequest, 0, 0L, 0); + + context.deliverResponse(fetchRequest.correlationId, fetchRequest.destinationId(), + context.fetchResponse(epoch, leaderId, MemoryRecords.EMPTY, 0L, Errors.FENCED_LEADER_EPOCH)); + context.client.poll(); + + context.assertElectedLeader(epoch, leaderId); + context.time.sleep(context.fetchTimeoutMs); + + context.pollUntilRequest(); + fetchRequest = context.assertSentFetchRequest(); + assertTrue(voters.contains(fetchRequest.destinationId())); + context.assertFetchRequestData(fetchRequest, epoch, 0L, 0); + } + + @Test + public void testInvalidFetchRequest() throws Exception { + int localId = 0; + int otherNodeId = 1; + int epoch = 5; + Set voters = Utils.mkSet(localId, otherNodeId); + + RaftClientTestContext context = RaftClientTestContext.initializeAsLeader(localId, voters, epoch); + + context.deliverRequest(context.fetchRequest( + epoch, otherNodeId, -5L, 0, 0)); + context.pollUntilResponse(); + context.assertSentFetchPartitionResponse(Errors.INVALID_REQUEST, epoch, OptionalInt.of(localId)); + + context.deliverRequest(context.fetchRequest( + epoch, otherNodeId, 0L, -1, 0)); + context.pollUntilResponse(); + context.assertSentFetchPartitionResponse(Errors.INVALID_REQUEST, epoch, OptionalInt.of(localId)); + + 
context.deliverRequest(context.fetchRequest( + epoch, otherNodeId, 0L, epoch + 1, 0)); + context.pollUntilResponse(); + context.assertSentFetchPartitionResponse(Errors.INVALID_REQUEST, epoch, OptionalInt.of(localId)); + + context.deliverRequest(context.fetchRequest( + epoch + 1, otherNodeId, 0L, 0, 0)); + context.pollUntilResponse(); + context.assertSentFetchPartitionResponse(Errors.UNKNOWN_LEADER_EPOCH, epoch, OptionalInt.of(localId)); + + context.deliverRequest(context.fetchRequest( + epoch, otherNodeId, 0L, 0, -1)); + context.pollUntilResponse(); + context.assertSentFetchPartitionResponse(Errors.INVALID_REQUEST, epoch, OptionalInt.of(localId)); + } + + @Test + public void testFetchRequestClusterIdValidation() throws Exception { + int localId = 0; + int otherNodeId = 1; + int epoch = 5; + Set voters = Utils.mkSet(localId, otherNodeId); + + RaftClientTestContext context = RaftClientTestContext.initializeAsLeader(localId, voters, epoch); + + // valid cluster id is accepted + context.deliverRequest(context.fetchRequest( + epoch, context.clusterId.toString(), otherNodeId, -5L, 0, 0)); + context.pollUntilResponse(); + context.assertSentFetchPartitionResponse(Errors.INVALID_REQUEST, epoch, OptionalInt.of(localId)); + + // null cluster id is accepted + context.deliverRequest(context.fetchRequest( + epoch, null, otherNodeId, -5L, 0, 0)); + context.pollUntilResponse(); + context.assertSentFetchPartitionResponse(Errors.INVALID_REQUEST, epoch, OptionalInt.of(localId)); + + // empty cluster id is rejected + context.deliverRequest(context.fetchRequest( + epoch, "", otherNodeId, -5L, 0, 0)); + context.pollUntilResponse(); + context.assertSentFetchPartitionResponse(Errors.INCONSISTENT_CLUSTER_ID); + + // invalid cluster id is rejected + context.deliverRequest(context.fetchRequest( + epoch, "invalid-uuid", otherNodeId, -5L, 0, 0)); + context.pollUntilResponse(); + context.assertSentFetchPartitionResponse(Errors.INCONSISTENT_CLUSTER_ID); + } + + @Test + public void testVoteRequestClusterIdValidation() throws Exception { + int localId = 0; + int otherNodeId = 1; + int epoch = 5; + Set voters = Utils.mkSet(localId, otherNodeId); + + RaftClientTestContext context = RaftClientTestContext.initializeAsLeader(localId, voters, epoch); + + // valid cluster id is accepted + context.deliverRequest(context.voteRequest(epoch, localId, 0, 0)); + context.pollUntilResponse(); + context.assertSentVoteResponse(Errors.NONE, epoch, OptionalInt.of(localId), false); + + // null cluster id is accepted + context.deliverRequest(context.voteRequest(epoch, localId, 0, 0)); + context.pollUntilResponse(); + context.assertSentVoteResponse(Errors.NONE, epoch, OptionalInt.of(localId), false); + + // empty cluster id is rejected + context.deliverRequest(context.voteRequest("", epoch, localId, 0, 0)); + context.pollUntilResponse(); + context.assertSentVoteResponse(Errors.INCONSISTENT_CLUSTER_ID); + + // invalid cluster id is rejected + context.deliverRequest(context.voteRequest("invalid-uuid", epoch, localId, 0, 0)); + context.pollUntilResponse(); + context.assertSentVoteResponse(Errors.INCONSISTENT_CLUSTER_ID); + } + + @Test + public void testBeginQuorumEpochRequestClusterIdValidation() throws Exception { + int localId = 0; + int otherNodeId = 1; + int epoch = 5; + Set voters = Utils.mkSet(localId, otherNodeId); + + RaftClientTestContext context = RaftClientTestContext.initializeAsLeader(localId, voters, epoch); + + // valid cluster id is accepted + context.deliverRequest(context.beginEpochRequest(context.clusterId.toString(), epoch, 
localId)); + context.pollUntilResponse(); + context.assertSentBeginQuorumEpochResponse(Errors.NONE, epoch, OptionalInt.of(localId)); + + // null cluster id is accepted + context.deliverRequest(context.beginEpochRequest(epoch, localId)); + context.pollUntilResponse(); + context.assertSentBeginQuorumEpochResponse(Errors.NONE, epoch, OptionalInt.of(localId)); + + // empty cluster id is rejected + context.deliverRequest(context.beginEpochRequest("", epoch, localId)); + context.pollUntilResponse(); + context.assertSentBeginQuorumEpochResponse(Errors.INCONSISTENT_CLUSTER_ID); + + // invalid cluster id is rejected + context.deliverRequest(context.beginEpochRequest("invalid-uuid", epoch, localId)); + context.pollUntilResponse(); + context.assertSentBeginQuorumEpochResponse(Errors.INCONSISTENT_CLUSTER_ID); + } + + @Test + public void testEndQuorumEpochRequestClusterIdValidation() throws Exception { + int localId = 0; + int otherNodeId = 1; + int epoch = 5; + Set voters = Utils.mkSet(localId, otherNodeId); + + RaftClientTestContext context = RaftClientTestContext.initializeAsLeader(localId, voters, epoch); + + // valid cluster id is accepted + context.deliverRequest(context.endEpochRequest(context.clusterId.toString(), epoch, localId, Collections.singletonList(otherNodeId))); + context.pollUntilResponse(); + context.assertSentEndQuorumEpochResponse(Errors.NONE, epoch, OptionalInt.of(localId)); + + // null cluster id is accepted + context.deliverRequest(context.endEpochRequest(epoch, localId, Collections.singletonList(otherNodeId))); + context.pollUntilResponse(); + context.assertSentEndQuorumEpochResponse(Errors.NONE, epoch, OptionalInt.of(localId)); + + // empty cluster id is rejected + context.deliverRequest(context.endEpochRequest("", epoch, localId, Collections.singletonList(otherNodeId))); + context.pollUntilResponse(); + context.assertSentEndQuorumEpochResponse(Errors.INCONSISTENT_CLUSTER_ID); + + // invalid cluster id is rejected + context.deliverRequest(context.endEpochRequest("invalid-uuid", epoch, localId, Collections.singletonList(otherNodeId))); + context.pollUntilResponse(); + context.assertSentEndQuorumEpochResponse(Errors.INCONSISTENT_CLUSTER_ID); + } + + @Test + public void testVoterOnlyRequestValidation() throws Exception { + int localId = 0; + int otherNodeId = 1; + int epoch = 5; + Set voters = Utils.mkSet(localId, otherNodeId); + + RaftClientTestContext context = RaftClientTestContext.initializeAsLeader(localId, voters, epoch); + + int nonVoterId = 2; + context.deliverRequest(context.voteRequest(epoch, nonVoterId, 0, 0)); + context.client.poll(); + context.assertSentVoteResponse(Errors.INCONSISTENT_VOTER_SET, epoch, OptionalInt.of(localId), false); + + context.deliverRequest(context.beginEpochRequest(epoch, nonVoterId)); + context.client.poll(); + context.assertSentBeginQuorumEpochResponse(Errors.INCONSISTENT_VOTER_SET, epoch, OptionalInt.of(localId)); + + context.deliverRequest(context.endEpochRequest(epoch, nonVoterId, Collections.singletonList(otherNodeId))); + context.client.poll(); + + // The sent request has no localId as a preferable voter. 
+ context.assertSentEndQuorumEpochResponse(Errors.INCONSISTENT_VOTER_SET, epoch, OptionalInt.of(localId)); + } + + @Test + public void testInvalidVoteRequest() throws Exception { + int localId = 0; + int otherNodeId = 1; + int epoch = 5; + Set voters = Utils.mkSet(localId, otherNodeId); + + RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters) + .withElectedLeader(epoch, otherNodeId) + .build(); + context.assertElectedLeader(epoch, otherNodeId); + + context.deliverRequest(context.voteRequest(epoch + 1, otherNodeId, 0, -5L)); + context.pollUntilResponse(); + context.assertSentVoteResponse(Errors.INVALID_REQUEST, epoch, OptionalInt.of(otherNodeId), false); + context.assertElectedLeader(epoch, otherNodeId); + + context.deliverRequest(context.voteRequest(epoch + 1, otherNodeId, -1, 0L)); + context.pollUntilResponse(); + context.assertSentVoteResponse(Errors.INVALID_REQUEST, epoch, OptionalInt.of(otherNodeId), false); + context.assertElectedLeader(epoch, otherNodeId); + + context.deliverRequest(context.voteRequest(epoch + 1, otherNodeId, epoch + 1, 0L)); + context.pollUntilResponse(); + context.assertSentVoteResponse(Errors.INVALID_REQUEST, epoch, OptionalInt.of(otherNodeId), false); + context.assertElectedLeader(epoch, otherNodeId); + } + + @Test + public void testPurgatoryFetchTimeout() throws Exception { + int localId = 0; + int otherNodeId = 1; + int epoch = 5; + Set voters = Utils.mkSet(localId, otherNodeId); + + RaftClientTestContext context = RaftClientTestContext.initializeAsLeader(localId, voters, epoch); + + // Follower sends a fetch which cannot be satisfied immediately + int maxWaitTimeMs = 500; + context.deliverRequest(context.fetchRequest(epoch, otherNodeId, 1L, epoch, maxWaitTimeMs)); + context.client.poll(); + assertEquals(0, context.channel.drainSendQueue().size()); + + // After expiration of the max wait time, the fetch returns an empty record set + context.time.sleep(maxWaitTimeMs); + context.client.poll(); + MemoryRecords fetchedRecords = context.assertSentFetchPartitionResponse(Errors.NONE, epoch, OptionalInt.of(localId)); + assertEquals(0, fetchedRecords.sizeInBytes()); + } + + @Test + public void testPurgatoryFetchSatisfiedByWrite() throws Exception { + int localId = 0; + int otherNodeId = 1; + int epoch = 5; + Set voters = Utils.mkSet(localId, otherNodeId); + + RaftClientTestContext context = RaftClientTestContext.initializeAsLeader(localId, voters, epoch); + + // Follower sends a fetch which cannot be satisfied immediately + context.deliverRequest(context.fetchRequest(epoch, otherNodeId, 1L, epoch, 500)); + context.client.poll(); + assertEquals(0, context.channel.drainSendQueue().size()); + + // Append some records that can fulfill the Fetch request + String[] appendRecords = new String[] {"a", "b", "c"}; + context.client.scheduleAppend(epoch, Arrays.asList(appendRecords)); + context.client.poll(); + + MemoryRecords fetchedRecords = context.assertSentFetchPartitionResponse(Errors.NONE, epoch, OptionalInt.of(localId)); + RaftClientTestContext.assertMatchingRecords(appendRecords, fetchedRecords); + } + + @Test + public void testPurgatoryFetchCompletedByFollowerTransition() throws Exception { + int localId = 0; + int voter1 = localId; + int voter2 = localId + 1; + int voter3 = localId + 2; + int epoch = 5; + Set voters = Utils.mkSet(voter1, voter2, voter3); + + RaftClientTestContext context = RaftClientTestContext.initializeAsLeader(localId, voters, epoch); + + // Follower sends a fetch which cannot be satisfied immediately + 
context.deliverRequest(context.fetchRequest(epoch, voter2, 1L, epoch, 500)); + context.client.poll(); + assertTrue(context.channel.drainSendQueue().stream() + .noneMatch(msg -> msg.data() instanceof FetchResponseData)); + + // Now we get a BeginEpoch from the other voter and become a follower + context.deliverRequest(context.beginEpochRequest(epoch + 1, voter3)); + context.pollUntilResponse(); + context.assertElectedLeader(epoch + 1, voter3); + + // We expect the BeginQuorumEpoch response and a failed Fetch response + context.assertSentBeginQuorumEpochResponse(Errors.NONE, epoch + 1, OptionalInt.of(voter3)); + + // The fetch should be satisfied immediately and return an error + MemoryRecords fetchedRecords = context.assertSentFetchPartitionResponse( + Errors.NOT_LEADER_OR_FOLLOWER, epoch + 1, OptionalInt.of(voter3)); + assertEquals(0, fetchedRecords.sizeInBytes()); + } + + @Test + public void testFetchResponseIgnoredAfterBecomingCandidate() throws Exception { + int localId = 0; + int otherNodeId = 1; + int epoch = 5; + // The other node starts out as the leader + Set voters = Utils.mkSet(localId, otherNodeId); + + RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters) + .withElectedLeader(epoch, otherNodeId) + .build(); + context.assertElectedLeader(epoch, otherNodeId); + + // Wait until we have a Fetch inflight to the leader + context.pollUntilRequest(); + int fetchCorrelationId = context.assertSentFetchRequest(epoch, 0L, 0); + + // Now await the fetch timeout and become a candidate + context.time.sleep(context.fetchTimeoutMs); + context.client.poll(); + context.assertVotedCandidate(epoch + 1, localId); + + // The fetch response from the old leader returns, but it should be ignored + Records records = context.buildBatch(0L, 3, Arrays.asList("a", "b")); + context.deliverResponse(fetchCorrelationId, otherNodeId, + context.fetchResponse(epoch, otherNodeId, records, 0L, Errors.NONE)); + + context.client.poll(); + assertEquals(0, context.log.endOffset().offset); + context.assertVotedCandidate(epoch + 1, localId); + } + + @Test + public void testFetchResponseIgnoredAfterBecomingFollowerOfDifferentLeader() throws Exception { + int localId = 0; + int voter1 = localId; + int voter2 = localId + 1; + int voter3 = localId + 2; + int epoch = 5; + // Start out with `voter2` as the leader + Set voters = Utils.mkSet(voter1, voter2, voter3); + + RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters) + .withElectedLeader(epoch, voter2) + .build(); + context.assertElectedLeader(epoch, voter2); + + // Wait until we have a Fetch inflight to the leader + context.pollUntilRequest(); + int fetchCorrelationId = context.assertSentFetchRequest(epoch, 0L, 0); + + // Now receive a BeginEpoch from `voter3` + context.deliverRequest(context.beginEpochRequest(epoch + 1, voter3)); + context.client.poll(); + context.assertElectedLeader(epoch + 1, voter3); + + // The fetch response from the old leader returns, but it should be ignored + Records records = context.buildBatch(0L, 3, Arrays.asList("a", "b")); + FetchResponseData response = context.fetchResponse(epoch, voter2, records, 0L, Errors.NONE); + context.deliverResponse(fetchCorrelationId, voter2, response); + + context.client.poll(); + assertEquals(0, context.log.endOffset().offset); + context.assertElectedLeader(epoch + 1, voter3); + } + + @Test + public void testVoteResponseIgnoredAfterBecomingFollower() throws Exception { + int localId = 0; + int voter1 = localId; + int voter2 = localId + 1; + int voter3 = 
localId + 2;
+        int epoch = 5;
+        Set<Integer> voters = Utils.mkSet(voter1, voter2, voter3);
+
+        RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters)
+            .withUnknownLeader(epoch - 1)
+            .build();
+        context.assertUnknownLeader(epoch - 1);
+
+        // Sleep a little to ensure that we become a candidate
+        context.time.sleep(context.electionTimeoutMs() * 2);
+
+        // Wait until the vote requests are inflight
+        context.pollUntilRequest();
+        context.assertVotedCandidate(epoch, localId);
+        List<RaftRequest.Outbound> voteRequests = context.collectVoteRequests(epoch, 0, 0);
+        assertEquals(2, voteRequests.size());
+
+        // While the vote requests are still inflight, we receive a BeginEpoch for the same epoch
+        context.deliverRequest(context.beginEpochRequest(epoch, voter3));
+        context.client.poll();
+        context.assertElectedLeader(epoch, voter3);
+
+        // The vote requests now return and should be ignored
+        VoteResponseData voteResponse1 = context.voteResponse(false, Optional.empty(), epoch);
+        context.deliverResponse(voteRequests.get(0).correlationId, voter2, voteResponse1);
+
+        VoteResponseData voteResponse2 = context.voteResponse(false, Optional.of(voter3), epoch);
+        context.deliverResponse(voteRequests.get(1).correlationId, voter3, voteResponse2);
+
+        context.client.poll();
+        context.assertElectedLeader(epoch, voter3);
+    }
+
+    @Test
+    public void testObserverLeaderRediscoveryAfterBrokerNotAvailableError() throws Exception {
+        int localId = 0;
+        int leaderId = 1;
+        int otherNodeId = 2;
+        int epoch = 5;
+        Set<Integer> voters = Utils.mkSet(leaderId, otherNodeId);
+
+        RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters).build();
+
+        context.discoverLeaderAsObserver(leaderId, epoch);
+
+        context.pollUntilRequest();
+        RaftRequest.Outbound fetchRequest1 = context.assertSentFetchRequest();
+        assertEquals(leaderId, fetchRequest1.destinationId());
+        context.assertFetchRequestData(fetchRequest1, epoch, 0L, 0);
+
+        context.deliverResponse(fetchRequest1.correlationId, fetchRequest1.destinationId(),
+            context.fetchResponse(epoch, -1, MemoryRecords.EMPTY, -1, Errors.BROKER_NOT_AVAILABLE));
+        context.pollUntilRequest();
+
+        // We should retry the Fetch against the other voter since the original
+        // voter connection will be backing off.
+        RaftRequest.Outbound fetchRequest2 = context.assertSentFetchRequest();
+        assertNotEquals(leaderId, fetchRequest2.destinationId());
+        assertTrue(voters.contains(fetchRequest2.destinationId()));
+        context.assertFetchRequestData(fetchRequest2, epoch, 0L, 0);
+
+        Errors error = fetchRequest2.destinationId() == leaderId ?
+ Errors.NONE : Errors.NOT_LEADER_OR_FOLLOWER; + context.deliverResponse(fetchRequest2.correlationId, fetchRequest2.destinationId(), + context.fetchResponse(epoch, leaderId, MemoryRecords.EMPTY, 0L, error)); + context.client.poll(); + + context.assertElectedLeader(epoch, leaderId); + } + + @Test + public void testObserverLeaderRediscoveryAfterRequestTimeout() throws Exception { + int localId = 0; + int leaderId = 1; + int otherNodeId = 2; + int epoch = 5; + Set voters = Utils.mkSet(leaderId, otherNodeId); + + RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters).build(); + + context.discoverLeaderAsObserver(leaderId, epoch); + + context.pollUntilRequest(); + RaftRequest.Outbound fetchRequest1 = context.assertSentFetchRequest(); + assertEquals(leaderId, fetchRequest1.destinationId()); + context.assertFetchRequestData(fetchRequest1, epoch, 0L, 0); + + context.time.sleep(context.requestTimeoutMs()); + context.pollUntilRequest(); + + // We should retry the Fetch against the other voter since the original + // voter connection will be backing off. + RaftRequest.Outbound fetchRequest2 = context.assertSentFetchRequest(); + assertNotEquals(leaderId, fetchRequest2.destinationId()); + assertTrue(voters.contains(fetchRequest2.destinationId())); + context.assertFetchRequestData(fetchRequest2, epoch, 0L, 0); + + context.deliverResponse(fetchRequest2.correlationId, fetchRequest2.destinationId(), + context.fetchResponse(epoch, leaderId, MemoryRecords.EMPTY, 0L, Errors.FENCED_LEADER_EPOCH)); + context.client.poll(); + + context.assertElectedLeader(epoch, leaderId); + } + + @Test + public void testLeaderGracefulShutdown() throws Exception { + int localId = 0; + int otherNodeId = 1; + int epoch = 1; + Set voters = Utils.mkSet(localId, otherNodeId); + + RaftClientTestContext context = RaftClientTestContext.initializeAsLeader(localId, voters, epoch); + + // Now shutdown + int shutdownTimeoutMs = 5000; + CompletableFuture shutdownFuture = context.client.shutdown(shutdownTimeoutMs); + + // We should still be running until we have had a chance to send EndQuorumEpoch + assertTrue(context.client.isShuttingDown()); + assertTrue(context.client.isRunning()); + assertFalse(shutdownFuture.isDone()); + + // Send EndQuorumEpoch request to the other voter + context.pollUntilRequest(); + assertTrue(context.client.isShuttingDown()); + assertTrue(context.client.isRunning()); + context.assertSentEndQuorumEpochRequest(1, otherNodeId); + + // We should still be able to handle vote requests during graceful shutdown + // in order to help the new leader get elected + context.deliverRequest(context.voteRequest(epoch + 1, otherNodeId, epoch, 1L)); + context.client.poll(); + context.assertSentVoteResponse(Errors.NONE, epoch + 1, OptionalInt.empty(), true); + + // Graceful shutdown completes when a new leader is elected + context.deliverRequest(context.beginEpochRequest(2, otherNodeId)); + + TestUtils.waitForCondition(() -> { + context.client.poll(); + return !context.client.isRunning(); + }, 5000, "Client failed to shutdown before expiration of timeout"); + assertFalse(context.client.isShuttingDown()); + assertTrue(shutdownFuture.isDone()); + assertNull(shutdownFuture.get()); + } + + @Test + public void testEndQuorumEpochSentBasedOnFetchOffset() throws Exception { + int localId = 0; + int closeFollower = 2; + int laggingFollower = 1; + int epoch = 1; + Set voters = Utils.mkSet(localId, closeFollower, laggingFollower); + + RaftClientTestContext context = RaftClientTestContext.initializeAsLeader(localId, 
voters, epoch); + + // The lagging follower fetches first + context.deliverRequest(context.fetchRequest(1, laggingFollower, 1L, epoch, 0)); + context.pollUntilResponse(); + context.assertSentFetchPartitionResponse(1L, epoch); + + // Append some records, so that the close follower will be able to advance further. + context.client.scheduleAppend(epoch, Arrays.asList("foo", "bar")); + context.client.poll(); + + context.deliverRequest(context.fetchRequest(epoch, closeFollower, 3L, epoch, 0)); + context.pollUntilResponse(); + context.assertSentFetchPartitionResponse(3L, epoch); + + // Now shutdown + context.client.shutdown(context.electionTimeoutMs() * 2); + + // We should still be running until we have had a chance to send EndQuorumEpoch + assertTrue(context.client.isRunning()); + + // Send EndQuorumEpoch request to the close follower + context.pollUntilRequest(); + assertTrue(context.client.isRunning()); + + context.collectEndQuorumRequests( + epoch, + Utils.mkSet(closeFollower, laggingFollower), + Optional.of(Arrays.asList(closeFollower, laggingFollower)) + ); + } + + @Test + public void testDescribeQuorum() throws Exception { + int localId = 0; + int closeFollower = 2; + int laggingFollower = 1; + int epoch = 1; + Set voters = Utils.mkSet(localId, closeFollower, laggingFollower); + + RaftClientTestContext context = RaftClientTestContext.initializeAsLeader(localId, voters, epoch); + + context.deliverRequest(context.fetchRequest(1, laggingFollower, 1L, epoch, 0)); + context.pollUntilResponse(); + context.assertSentFetchPartitionResponse(1L, epoch); + + context.client.scheduleAppend(epoch, Arrays.asList("foo", "bar")); + context.client.poll(); + + context.deliverRequest(context.fetchRequest(epoch, closeFollower, 3L, epoch, 0)); + context.pollUntilResponse(); + context.assertSentFetchPartitionResponse(3L, epoch); + + // Create observer + int observerId = 3; + context.deliverRequest(context.fetchRequest(epoch, observerId, 0L, 0, 0)); + context.pollUntilResponse(); + context.assertSentFetchPartitionResponse(3L, epoch); + + context.deliverRequest(DescribeQuorumRequest.singletonRequest(context.metadataPartition)); + context.pollUntilResponse(); + + context.assertSentDescribeQuorumResponse(localId, epoch, 3L, + Arrays.asList( + new ReplicaState() + .setReplicaId(localId) + // As we are appending the records directly to the log, + // the leader end offset hasn't been updated yet. 
+ .setLogEndOffset(3L), + new ReplicaState() + .setReplicaId(laggingFollower) + .setLogEndOffset(1L), + new ReplicaState() + .setReplicaId(closeFollower) + .setLogEndOffset(3)), + singletonList( + new ReplicaState() + .setReplicaId(observerId) + .setLogEndOffset(0L))); + } + + @Test + public void testLeaderGracefulShutdownTimeout() throws Exception { + int localId = 0; + int otherNodeId = 1; + int epoch = 1; + Set voters = Utils.mkSet(localId, otherNodeId); + + RaftClientTestContext context = RaftClientTestContext.initializeAsLeader(localId, voters, epoch); + + // Now shutdown + int shutdownTimeoutMs = 5000; + CompletableFuture shutdownFuture = context.client.shutdown(shutdownTimeoutMs); + + // We should still be running until we have had a chance to send EndQuorumEpoch + assertTrue(context.client.isRunning()); + assertFalse(shutdownFuture.isDone()); + + // Send EndQuorumEpoch request to the other vote + context.pollUntilRequest(); + assertTrue(context.client.isRunning()); + + context.assertSentEndQuorumEpochRequest(epoch, otherNodeId); + + // The shutdown timeout is hit before we receive any requests or responses indicating an epoch bump + context.time.sleep(shutdownTimeoutMs); + + context.client.poll(); + assertFalse(context.client.isRunning()); + assertTrue(shutdownFuture.isCompletedExceptionally()); + assertFutureThrows(shutdownFuture, TimeoutException.class); + } + + @Test + public void testFollowerGracefulShutdown() throws Exception { + int localId = 0; + int otherNodeId = 1; + int epoch = 5; + Set voters = Utils.mkSet(localId, otherNodeId); + + RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters) + .withElectedLeader(epoch, otherNodeId) + .build(); + context.assertElectedLeader(epoch, otherNodeId); + + context.client.poll(); + + int shutdownTimeoutMs = 5000; + CompletableFuture shutdownFuture = context.client.shutdown(shutdownTimeoutMs); + assertTrue(context.client.isRunning()); + assertFalse(shutdownFuture.isDone()); + + context.client.poll(); + assertFalse(context.client.isRunning()); + assertTrue(shutdownFuture.isDone()); + assertNull(shutdownFuture.get()); + } + + @Test + public void testObserverGracefulShutdown() throws Exception { + int localId = 0; + int voter1 = 1; + int voter2 = 2; + Set voters = Utils.mkSet(voter1, voter2); + + RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters) + .withUnknownLeader(5) + .build(); + context.client.poll(); + context.assertUnknownLeader(5); + + // Observer shutdown should complete immediately even if the + // current leader is unknown + CompletableFuture shutdownFuture = context.client.shutdown(5000); + assertTrue(context.client.isRunning()); + assertFalse(shutdownFuture.isDone()); + + context.client.poll(); + assertFalse(context.client.isRunning()); + assertTrue(shutdownFuture.isDone()); + assertNull(shutdownFuture.get()); + } + + @Test + public void testGracefulShutdownSingleMemberQuorum() throws IOException { + int localId = 0; + RaftClientTestContext context = new RaftClientTestContext.Builder(localId, Collections.singleton(localId)).build(); + + context.assertElectedLeader(1, localId); + context.client.poll(); + assertEquals(0, context.channel.drainSendQueue().size()); + int shutdownTimeoutMs = 5000; + context.client.shutdown(shutdownTimeoutMs); + assertTrue(context.client.isRunning()); + context.client.poll(); + assertFalse(context.client.isRunning()); + } + + @Test + public void testFollowerReplication() throws Exception { + int localId = 0; + int otherNodeId = 1; + int 
epoch = 5; + Set voters = Utils.mkSet(localId, otherNodeId); + + RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters) + .withElectedLeader(epoch, otherNodeId) + .build(); + context.assertElectedLeader(epoch, otherNodeId); + + context.pollUntilRequest(); + + int fetchQuorumCorrelationId = context.assertSentFetchRequest(epoch, 0L, 0); + Records records = context.buildBatch(0L, 3, Arrays.asList("a", "b")); + FetchResponseData response = context.fetchResponse(epoch, otherNodeId, records, 0L, Errors.NONE); + context.deliverResponse(fetchQuorumCorrelationId, otherNodeId, response); + + context.client.poll(); + assertEquals(2L, context.log.endOffset().offset); + assertEquals(2L, context.log.lastFlushedOffset()); + } + + @Test + public void testEmptyRecordSetInFetchResponse() throws Exception { + int localId = 0; + int otherNodeId = 1; + int epoch = 5; + Set voters = Utils.mkSet(localId, otherNodeId); + + RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters) + .withElectedLeader(epoch, otherNodeId) + .build(); + context.assertElectedLeader(epoch, otherNodeId); + + // Receive an empty fetch response + context.pollUntilRequest(); + int fetchQuorumCorrelationId = context.assertSentFetchRequest(epoch, 0L, 0); + FetchResponseData fetchResponse = context.fetchResponse(epoch, otherNodeId, + MemoryRecords.EMPTY, 0L, Errors.NONE); + context.deliverResponse(fetchQuorumCorrelationId, otherNodeId, fetchResponse); + context.client.poll(); + assertEquals(0L, context.log.endOffset().offset); + assertEquals(OptionalLong.of(0L), context.client.highWatermark()); + + // Receive some records in the next poll, but do not advance high watermark + context.pollUntilRequest(); + Records records = context.buildBatch(0L, epoch, Arrays.asList("a", "b")); + fetchQuorumCorrelationId = context.assertSentFetchRequest(epoch, 0L, 0); + fetchResponse = context.fetchResponse(epoch, otherNodeId, + records, 0L, Errors.NONE); + context.deliverResponse(fetchQuorumCorrelationId, otherNodeId, fetchResponse); + context.client.poll(); + assertEquals(2L, context.log.endOffset().offset); + assertEquals(OptionalLong.of(0L), context.client.highWatermark()); + + // The next fetch response is empty, but should still advance the high watermark + context.pollUntilRequest(); + fetchQuorumCorrelationId = context.assertSentFetchRequest(epoch, 2L, epoch); + fetchResponse = context.fetchResponse(epoch, otherNodeId, + MemoryRecords.EMPTY, 2L, Errors.NONE); + context.deliverResponse(fetchQuorumCorrelationId, otherNodeId, fetchResponse); + context.client.poll(); + assertEquals(2L, context.log.endOffset().offset); + assertEquals(OptionalLong.of(2L), context.client.highWatermark()); + } + + @Test + public void testFetchShouldBeTreatedAsLeaderAcknowledgement() throws Exception { + int localId = 0; + int otherNodeId = 1; + int epoch = 5; + Set voters = Utils.mkSet(localId, otherNodeId); + + RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters) + .updateRandom(r -> r.mockNextInt(DEFAULT_ELECTION_TIMEOUT_MS, 0)) + .withUnknownLeader(epoch - 1) + .build(); + + context.time.sleep(context.electionTimeoutMs()); + context.expectAndGrantVotes(epoch); + + context.pollUntilRequest(); + + // We send BeginEpoch, but it gets lost and the destination finds the leader through the Fetch API + context.assertSentBeginQuorumEpochRequest(epoch, 1); + + context.deliverRequest(context.fetchRequest( + epoch, otherNodeId, 0L, 0, 500)); + + context.client.poll(); + + // The BeginEpoch request 
eventually times out. We should not send another one. + context.assertSentFetchPartitionResponse(Errors.NONE, epoch, OptionalInt.of(localId)); + context.time.sleep(context.requestTimeoutMs()); + + context.client.poll(); + + List sentMessages = context.channel.drainSendQueue(); + assertEquals(0, sentMessages.size()); + } + + @Test + public void testLeaderAppendSingleMemberQuorum() throws Exception { + int localId = 0; + Set voters = Collections.singleton(localId); + + RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters).build(); + long now = context.time.milliseconds(); + + context.pollUntil(() -> context.log.endOffset().offset == 1L); + context.assertElectedLeader(1, localId); + + // We still write the leader change message + assertEquals(OptionalLong.of(1L), context.client.highWatermark()); + + String[] appendRecords = new String[] {"a", "b", "c"}; + + // First poll has no high watermark advance + context.client.poll(); + assertEquals(OptionalLong.of(1L), context.client.highWatermark()); + + context.client.scheduleAppend(context.currentEpoch(), Arrays.asList(appendRecords)); + + // Then poll the appended data with leader change record + context.client.poll(); + assertEquals(OptionalLong.of(4L), context.client.highWatermark()); + + // Now try reading it + int otherNodeId = 1; + List batches = new ArrayList<>(2); + boolean appended = true; + + // Continue to fetch until the leader returns an empty response + while (appended) { + long fetchOffset = 0; + int lastFetchedEpoch = 0; + if (!batches.isEmpty()) { + MutableRecordBatch lastBatch = batches.get(batches.size() - 1); + fetchOffset = lastBatch.lastOffset() + 1; + lastFetchedEpoch = lastBatch.partitionLeaderEpoch(); + } + + context.deliverRequest(context.fetchRequest(1, otherNodeId, fetchOffset, lastFetchedEpoch, 0)); + context.pollUntilResponse(); + + MemoryRecords fetchedRecords = context.assertSentFetchPartitionResponse(Errors.NONE, 1, OptionalInt.of(localId)); + List fetchedBatch = Utils.toList(fetchedRecords.batchIterator()); + batches.addAll(fetchedBatch); + + appended = !fetchedBatch.isEmpty(); + } + + assertEquals(2, batches.size()); + + MutableRecordBatch leaderChangeBatch = batches.get(0); + assertTrue(leaderChangeBatch.isControlBatch()); + List readRecords = Utils.toList(leaderChangeBatch.iterator()); + assertEquals(1, readRecords.size()); + + Record record = readRecords.get(0); + assertEquals(now, record.timestamp()); + RaftClientTestContext.verifyLeaderChangeMessage(localId, Collections.singletonList(localId), + Collections.singletonList(localId), record.key(), record.value()); + + MutableRecordBatch batch = batches.get(1); + assertEquals(1, batch.partitionLeaderEpoch()); + readRecords = Utils.toList(batch.iterator()); + assertEquals(3, readRecords.size()); + + for (int i = 0; i < appendRecords.length; i++) { + assertEquals(appendRecords[i], Utils.utf8(readRecords.get(i).value())); + } + } + + @Test + public void testFollowerLogReconciliation() throws Exception { + int localId = 0; + int otherNodeId = 1; + int epoch = 5; + int lastEpoch = 3; + Set voters = Utils.mkSet(localId, otherNodeId); + + RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters) + .withElectedLeader(epoch, otherNodeId) + .appendToLog(lastEpoch, Arrays.asList("foo", "bar")) + .appendToLog(lastEpoch, Arrays.asList("baz")) + .build(); + + context.assertElectedLeader(epoch, otherNodeId); + assertEquals(3L, context.log.endOffset().offset); + + context.pollUntilRequest(); + + int correlationId = 
context.assertSentFetchRequest(epoch, 3L, lastEpoch); + + FetchResponseData response = context.divergingFetchResponse(epoch, otherNodeId, 2L, + lastEpoch, 1L); + context.deliverResponse(correlationId, otherNodeId, response); + + // Poll again to complete truncation + context.client.poll(); + assertEquals(2L, context.log.endOffset().offset); + + // Now we should be fetching + context.client.poll(); + context.assertSentFetchRequest(epoch, 2L, lastEpoch); + } + + @Test + public void testMetrics() throws Exception { + int localId = 0; + int epoch = 1; + RaftClientTestContext context = new RaftClientTestContext.Builder(localId, Collections.singleton(localId)) + .build(); + context.pollUntil(() -> context.log.endOffset().offset == 1L); + + assertNotNull(getMetric(context.metrics, "current-state")); + assertNotNull(getMetric(context.metrics, "current-leader")); + assertNotNull(getMetric(context.metrics, "current-vote")); + assertNotNull(getMetric(context.metrics, "current-epoch")); + assertNotNull(getMetric(context.metrics, "high-watermark")); + assertNotNull(getMetric(context.metrics, "log-end-offset")); + assertNotNull(getMetric(context.metrics, "log-end-epoch")); + assertNotNull(getMetric(context.metrics, "number-unknown-voter-connections")); + assertNotNull(getMetric(context.metrics, "poll-idle-ratio-avg")); + assertNotNull(getMetric(context.metrics, "commit-latency-avg")); + assertNotNull(getMetric(context.metrics, "commit-latency-max")); + assertNotNull(getMetric(context.metrics, "election-latency-avg")); + assertNotNull(getMetric(context.metrics, "election-latency-max")); + assertNotNull(getMetric(context.metrics, "fetch-records-rate")); + assertNotNull(getMetric(context.metrics, "append-records-rate")); + + assertEquals("leader", getMetric(context.metrics, "current-state").metricValue()); + assertEquals((double) localId, getMetric(context.metrics, "current-leader").metricValue()); + assertEquals((double) localId, getMetric(context.metrics, "current-vote").metricValue()); + assertEquals((double) epoch, getMetric(context.metrics, "current-epoch").metricValue()); + assertEquals((double) 1L, getMetric(context.metrics, "high-watermark").metricValue()); + assertEquals((double) 1L, getMetric(context.metrics, "log-end-offset").metricValue()); + assertEquals((double) epoch, getMetric(context.metrics, "log-end-epoch").metricValue()); + + context.client.scheduleAppend(epoch, Arrays.asList("a", "b", "c")); + context.client.poll(); + + assertEquals((double) 4L, getMetric(context.metrics, "high-watermark").metricValue()); + assertEquals((double) 4L, getMetric(context.metrics, "log-end-offset").metricValue()); + assertEquals((double) epoch, getMetric(context.metrics, "log-end-epoch").metricValue()); + + context.client.close(); + + // should only have total-metrics-count left + assertEquals(1, context.metrics.metrics().size()); + } + + @Test + public void testClusterAuthorizationFailedInFetch() throws Exception { + int localId = 0; + int otherNodeId = 1; + int epoch = 5; + Set voters = Utils.mkSet(localId, otherNodeId); + + RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters) + .withElectedLeader(epoch, otherNodeId) + .build(); + + context.assertElectedLeader(epoch, otherNodeId); + + context.pollUntilRequest(); + + int correlationId = context.assertSentFetchRequest(epoch, 0, 0); + FetchResponseData response = new FetchResponseData() + .setErrorCode(Errors.CLUSTER_AUTHORIZATION_FAILED.code()); + context.deliverResponse(correlationId, otherNodeId, response); + 
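+        // Descriptive note (added): the CLUSTER_AUTHORIZATION_FAILED error code is surfaced to the caller as a ClusterAuthorizationException on the next poll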
assertThrows(ClusterAuthorizationException.class, context.client::poll); + } + + @Test + public void testClusterAuthorizationFailedInBeginQuorumEpoch() throws Exception { + int localId = 0; + int otherNodeId = 1; + int epoch = 5; + Set voters = Utils.mkSet(localId, otherNodeId); + + RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters) + .updateRandom(r -> r.mockNextInt(DEFAULT_ELECTION_TIMEOUT_MS, 0)) + .withUnknownLeader(epoch - 1) + .build(); + + context.time.sleep(context.electionTimeoutMs()); + context.expectAndGrantVotes(epoch); + + context.pollUntilRequest(); + int correlationId = context.assertSentBeginQuorumEpochRequest(epoch, 1); + BeginQuorumEpochResponseData response = new BeginQuorumEpochResponseData() + .setErrorCode(Errors.CLUSTER_AUTHORIZATION_FAILED.code()); + + context.deliverResponse(correlationId, otherNodeId, response); + assertThrows(ClusterAuthorizationException.class, context.client::poll); + } + + @Test + public void testClusterAuthorizationFailedInVote() throws Exception { + int localId = 0; + int otherNodeId = 1; + int epoch = 5; + Set voters = Utils.mkSet(localId, otherNodeId); + + RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters) + .withUnknownLeader(epoch - 1) + .build(); + + // Sleep a little to ensure that we become a candidate + context.time.sleep(context.electionTimeoutMs() * 2); + context.pollUntilRequest(); + context.assertVotedCandidate(epoch, localId); + + int correlationId = context.assertSentVoteRequest(epoch, 0, 0L, 1); + VoteResponseData response = new VoteResponseData() + .setErrorCode(Errors.CLUSTER_AUTHORIZATION_FAILED.code()); + + context.deliverResponse(correlationId, otherNodeId, response); + assertThrows(ClusterAuthorizationException.class, context.client::poll); + } + + @Test + public void testClusterAuthorizationFailedInEndQuorumEpoch() throws Exception { + int localId = 0; + int otherNodeId = 1; + int epoch = 2; + Set voters = Utils.mkSet(localId, otherNodeId); + + RaftClientTestContext context = RaftClientTestContext.initializeAsLeader(localId, voters, epoch); + + context.client.shutdown(5000); + context.pollUntilRequest(); + + int correlationId = context.assertSentEndQuorumEpochRequest(epoch, otherNodeId); + EndQuorumEpochResponseData response = new EndQuorumEpochResponseData() + .setErrorCode(Errors.CLUSTER_AUTHORIZATION_FAILED.code()); + + context.deliverResponse(correlationId, otherNodeId, response); + assertThrows(ClusterAuthorizationException.class, context.client::poll); + } + + @Test + public void testHandleClaimFiresImmediatelyOnEmptyLog() throws Exception { + int localId = 0; + int otherNodeId = 1; + int epoch = 5; + Set voters = Utils.mkSet(localId, otherNodeId); + + RaftClientTestContext context = RaftClientTestContext.initializeAsLeader(localId, voters, epoch); + assertEquals(OptionalInt.of(epoch), context.listener.currentClaimedEpoch()); + } + + @Test + public void testHandleClaimCallbackFiresAfterHighWatermarkReachesEpochStartOffset() throws Exception { + int localId = 0; + int otherNodeId = 1; + int epoch = 5; + Set voters = Utils.mkSet(localId, otherNodeId); + + List batch1 = Arrays.asList("1", "2", "3"); + List batch2 = Arrays.asList("4", "5", "6"); + List batch3 = Arrays.asList("7", "8", "9"); + + List> expectedBatches = Arrays.asList(batch1, batch2, batch3); + RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters) + .appendToLog(1, batch1) + .appendToLog(1, batch2) + .appendToLog(2, batch3) + .withUnknownLeader(epoch - 1) + 
.build(); + + context.becomeLeader(); + context.client.poll(); + + // After becoming leader, we expect the `LeaderChange` record to be appended + // in addition to the initial 9 records in the log. + assertEquals(10L, context.log.endOffset().offset); + + // The high watermark is not known to the leader until the followers + // begin fetching, so we should not have fired the `handleClaim` callback. + assertEquals(OptionalInt.empty(), context.listener.currentClaimedEpoch()); + assertEquals(OptionalLong.empty(), context.listener.lastCommitOffset()); + + // Deliver a fetch from the other voter. The high watermark will not + // be exposed until it is able to reach the start of the leader epoch, + // so we are unable to deliver committed data or fire `handleClaim`. + context.deliverRequest(context.fetchRequest(epoch, otherNodeId, 3L, 1, 500)); + context.client.poll(); + assertEquals(OptionalInt.empty(), context.listener.currentClaimedEpoch()); + assertEquals(OptionalLong.empty(), context.listener.lastCommitOffset()); + + // Now catch up to the start of the leader epoch so that the high + // watermark advances and we can start sending committed data to the + // listener. Note that the `LeaderChange` control record is filtered. + context.deliverRequest(context.fetchRequest(epoch, otherNodeId, 10L, epoch, 500)); + context.pollUntil(() -> { + int committedBatches = context.listener.numCommittedBatches(); + long baseOffset = 0; + for (int index = 0; index < committedBatches; index++) { + List expectedBatch = expectedBatches.get(index); + assertEquals(expectedBatch, context.listener.commitWithBaseOffset(baseOffset)); + baseOffset += expectedBatch.size(); + } + + return context.listener.currentClaimedEpoch().isPresent(); + }); + + assertEquals(OptionalInt.of(epoch), context.listener.currentClaimedEpoch()); + // Note that last committed offset is inclusive, hence we subtract 1. 
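+        // For example: the nine data records occupy offsets 0-8 and the LeaderChange control record at offset 9 is filtered, so the expected value is 9 - 1 = 8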
+ assertEquals( + OptionalLong.of(expectedBatches.stream().mapToInt(List::size).sum() - 1), + context.listener.lastCommitOffset() + ); + } + + @Test + public void testLateRegisteredListenerCatchesUp() throws Exception { + int localId = 0; + int otherNodeId = 1; + int epoch = 5; + Set voters = Utils.mkSet(localId, otherNodeId); + + List batch1 = Arrays.asList("1", "2", "3"); + List batch2 = Arrays.asList("4", "5", "6"); + List batch3 = Arrays.asList("7", "8", "9"); + + RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters) + .appendToLog(1, batch1) + .appendToLog(1, batch2) + .appendToLog(2, batch3) + .withUnknownLeader(epoch - 1) + .build(); + + context.becomeLeader(); + context.client.poll(); + assertEquals(10L, context.log.endOffset().offset); + + // Let the initial listener catch up + context.deliverRequest(context.fetchRequest(epoch, otherNodeId, 10L, epoch, 0)); + context.pollUntil(() -> OptionalInt.of(epoch).equals(context.listener.currentClaimedEpoch())); + assertEquals(OptionalLong.of(10L), context.client.highWatermark()); + assertEquals(OptionalLong.of(8L), context.listener.lastCommitOffset()); + assertEquals(OptionalInt.of(epoch), context.listener.currentClaimedEpoch()); + // Ensure that the `handleClaim` callback was not fired early + assertEquals(9L, context.listener.claimedEpochStartOffset(epoch)); + + // Register a second listener and allow it to catch up to the high watermark + RaftClientTestContext.MockListener secondListener = new RaftClientTestContext.MockListener(OptionalInt.of(localId)); + context.client.register(secondListener); + context.pollUntil(() -> OptionalInt.of(epoch).equals(secondListener.currentClaimedEpoch())); + assertEquals(OptionalLong.of(8L), secondListener.lastCommitOffset()); + assertEquals(OptionalInt.of(epoch), context.listener.currentClaimedEpoch()); + // Ensure that the `handleClaim` callback was not fired early + assertEquals(9L, secondListener.claimedEpochStartOffset(epoch)); + } + + @Test + public void testReregistrationChangesListenerContext() throws Exception { + int localId = 0; + int otherNodeId = 1; + int epoch = 5; + Set voters = Utils.mkSet(localId, otherNodeId); + + List batch1 = Arrays.asList("1", "2", "3"); + List batch2 = Arrays.asList("4", "5", "6"); + List batch3 = Arrays.asList("7", "8", "9"); + + RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters) + .appendToLog(1, batch1) + .appendToLog(1, batch2) + .appendToLog(2, batch3) + .withUnknownLeader(epoch - 1) + .build(); + + context.becomeLeader(); + context.client.poll(); + assertEquals(10L, context.log.endOffset().offset); + + // Let the initial listener catch up + context.advanceLocalLeaderHighWatermarkToLogEndOffset(); + context.pollUntil(() -> OptionalLong.of(8).equals(context.listener.lastCommitOffset())); + + // Register a second listener + RaftClientTestContext.MockListener secondListener = new RaftClientTestContext.MockListener(OptionalInt.of(localId)); + context.client.register(secondListener); + context.pollUntil(() -> OptionalLong.of(8).equals(secondListener.lastCommitOffset())); + context.client.unregister(secondListener); + + // Write to the log and show that the default listener gets updated... + assertEquals(10L, context.client.scheduleAppend(epoch, singletonList("a"))); + context.client.poll(); + context.advanceLocalLeaderHighWatermarkToLogEndOffset(); + context.pollUntil(() -> OptionalLong.of(10).equals(context.listener.lastCommitOffset())); + // ... 
but unregister listener doesn't + assertEquals(OptionalLong.of(8), secondListener.lastCommitOffset()); + } + + @Test + public void testHandleCommitCallbackFiresAfterFollowerHighWatermarkAdvances() throws Exception { + int localId = 0; + int otherNodeId = 1; + int epoch = 5; + Set voters = Utils.mkSet(localId, otherNodeId); + + RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters) + .withElectedLeader(epoch, otherNodeId) + .build(); + assertEquals(OptionalLong.empty(), context.client.highWatermark()); + + // Poll for our first fetch request + context.pollUntilRequest(); + RaftRequest.Outbound fetchRequest = context.assertSentFetchRequest(); + assertTrue(voters.contains(fetchRequest.destinationId())); + context.assertFetchRequestData(fetchRequest, epoch, 0L, 0); + + // The response does not advance the high watermark + List records1 = Arrays.asList("a", "b", "c"); + MemoryRecords batch1 = context.buildBatch(0L, 3, records1); + context.deliverResponse(fetchRequest.correlationId, fetchRequest.destinationId(), + context.fetchResponse(epoch, otherNodeId, batch1, 0L, Errors.NONE)); + context.client.poll(); + + // The listener should not have seen any data + assertEquals(OptionalLong.of(0L), context.client.highWatermark()); + assertEquals(0, context.listener.numCommittedBatches()); + assertEquals(OptionalInt.empty(), context.listener.currentClaimedEpoch()); + + // Now look for the next fetch request + context.pollUntilRequest(); + fetchRequest = context.assertSentFetchRequest(); + assertTrue(voters.contains(fetchRequest.destinationId())); + context.assertFetchRequestData(fetchRequest, epoch, 3L, 3); + + // The high watermark advances to include the first batch we fetched + List records2 = Arrays.asList("d", "e", "f"); + MemoryRecords batch2 = context.buildBatch(3L, 3, records2); + context.deliverResponse(fetchRequest.correlationId, fetchRequest.destinationId(), + context.fetchResponse(epoch, otherNodeId, batch2, 3L, Errors.NONE)); + context.client.poll(); + + // The listener should have seen only the data from the first batch + assertEquals(OptionalLong.of(3L), context.client.highWatermark()); + assertEquals(1, context.listener.numCommittedBatches()); + assertEquals(OptionalLong.of(2L), context.listener.lastCommitOffset()); + assertEquals(records1, context.listener.lastCommit().records()); + assertEquals(OptionalInt.empty(), context.listener.currentClaimedEpoch()); + } + + @Test + public void testHandleCommitCallbackFiresInVotedState() throws Exception { + // This test verifies that the state machine can still catch up even while + // an election is in progress as long as the high watermark is known. 
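+        // Granting a vote does not discard the previously established high watermark, so a newly registered listener can still be caught up to the committed data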
+ + int localId = 0; + int otherNodeId = 1; + int epoch = 7; + Set voters = Utils.mkSet(localId, otherNodeId); + + RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters) + .appendToLog(2, Arrays.asList("a", "b", "c")) + .appendToLog(4, Arrays.asList("d", "e", "f")) + .appendToLog(4, Arrays.asList("g", "h", "i")) + .withUnknownLeader(epoch - 1) + .build(); + + // Start off as the leader and receive a fetch to initialize the high watermark + context.becomeLeader(); + context.deliverRequest(context.fetchRequest(epoch, otherNodeId, 10L, epoch, 500)); + context.client.poll(); + assertEquals(OptionalLong.of(10L), context.client.highWatermark()); + + // Now we receive a vote request which transitions us to the 'voted' state + int candidateEpoch = epoch + 1; + context.deliverRequest(context.voteRequest(candidateEpoch, otherNodeId, epoch, 10L)); + context.pollUntilResponse(); + context.assertVotedCandidate(candidateEpoch, otherNodeId); + assertEquals(OptionalLong.of(10L), context.client.highWatermark()); + + // Register another listener and verify that it catches up while we remain 'voted' + RaftClientTestContext.MockListener secondListener = new RaftClientTestContext.MockListener(OptionalInt.of(localId)); + context.client.register(secondListener); + context.client.poll(); + context.assertVotedCandidate(candidateEpoch, otherNodeId); + + // Note the offset is 8 because the record at offset 9 is a control record + context.pollUntil(() -> secondListener.lastCommitOffset().equals(OptionalLong.of(8L))); + assertEquals(OptionalLong.of(8L), secondListener.lastCommitOffset()); + assertEquals(OptionalInt.empty(), secondListener.currentClaimedEpoch()); + } + + @Test + public void testHandleCommitCallbackFiresInCandidateState() throws Exception { + // This test verifies that the state machine can still catch up even while + // an election is in progress as long as the high watermark is known. 
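+        // The same holds after the election timeout promotes this node to candidate: the high watermark learned while it was leader remains available to listeners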
+ + int localId = 0; + int otherNodeId = 1; + int epoch = 7; + Set voters = Utils.mkSet(localId, otherNodeId); + + RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters) + .appendToLog(2, Arrays.asList("a", "b", "c")) + .appendToLog(4, Arrays.asList("d", "e", "f")) + .appendToLog(4, Arrays.asList("g", "h", "i")) + .withUnknownLeader(epoch - 1) + .build(); + + // Start off as the leader and receive a fetch to initialize the high watermark + context.becomeLeader(); + assertEquals(10L, context.log.endOffset().offset); + + context.deliverRequest(context.fetchRequest(epoch, otherNodeId, 10L, epoch, 0)); + context.pollUntilResponse(); + assertEquals(OptionalLong.of(10L), context.client.highWatermark()); + context.assertSentFetchPartitionResponse(Errors.NONE, epoch, OptionalInt.of(localId)); + + // Now we receive a vote request which transitions us to the 'unattached' state + context.deliverRequest(context.voteRequest(epoch + 1, otherNodeId, epoch, 9L)); + context.pollUntilResponse(); + context.assertUnknownLeader(epoch + 1); + assertEquals(OptionalLong.of(10L), context.client.highWatermark()); + + // Timeout the election and become candidate + int candidateEpoch = epoch + 2; + context.time.sleep(context.electionTimeoutMs() * 2); + context.client.poll(); + context.assertVotedCandidate(candidateEpoch, localId); + + // Register another listener and verify that it catches up + RaftClientTestContext.MockListener secondListener = new RaftClientTestContext.MockListener(OptionalInt.of(localId)); + context.client.register(secondListener); + context.client.poll(); + context.assertVotedCandidate(candidateEpoch, localId); + + // Note the offset is 8 because the record at offset 9 is a control record + context.pollUntil(() -> secondListener.lastCommitOffset().equals(OptionalLong.of(8L))); + assertEquals(OptionalLong.of(8L), secondListener.lastCommitOffset()); + assertEquals(OptionalInt.empty(), secondListener.currentClaimedEpoch()); + } + + @Test + public void testObserverFetchWithNoLocalId() throws Exception { + // When no `localId` is defined, the client will behave as an observer. + // This is designed for tooling/debugging use cases. 
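+        // The observer discovers the current leader from the error in its first Fetch response and then fetches records from that leader directly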
+ + Set voters = Utils.mkSet(1, 2); + RaftClientTestContext context = new RaftClientTestContext.Builder(OptionalInt.empty(), voters) + .build(); + + // First fetch discovers the current leader and epoch + + context.pollUntilRequest(); + RaftRequest.Outbound fetchRequest1 = context.assertSentFetchRequest(); + assertTrue(voters.contains(fetchRequest1.destinationId())); + context.assertFetchRequestData(fetchRequest1, 0, 0L, 0); + + int leaderEpoch = 5; + int leaderId = 1; + + context.deliverResponse(fetchRequest1.correlationId, fetchRequest1.destinationId(), + context.fetchResponse(5, leaderId, MemoryRecords.EMPTY, 0L, Errors.FENCED_LEADER_EPOCH)); + context.client.poll(); + context.assertElectedLeader(leaderEpoch, leaderId); + + // Second fetch goes to the discovered leader + + context.pollUntilRequest(); + RaftRequest.Outbound fetchRequest2 = context.assertSentFetchRequest(); + assertEquals(leaderId, fetchRequest2.destinationId()); + context.assertFetchRequestData(fetchRequest2, leaderEpoch, 0L, 0); + + List records = Arrays.asList("a", "b", "c"); + MemoryRecords batch1 = context.buildBatch(0L, 3, records); + context.deliverResponse(fetchRequest2.correlationId, fetchRequest2.destinationId(), + context.fetchResponse(leaderEpoch, leaderId, batch1, 0L, Errors.NONE)); + context.client.poll(); + assertEquals(3L, context.log.endOffset().offset); + assertEquals(3, context.log.lastFetchedEpoch()); + } + + private static KafkaMetric getMetric(final Metrics metrics, final String name) { + return metrics.metrics().get(metrics.metricName(name, "raft-metrics")); + } + +} diff --git a/raft/src/test/java/org/apache/kafka/raft/LeaderStateTest.java b/raft/src/test/java/org/apache/kafka/raft/LeaderStateTest.java new file mode 100644 index 0000000..5f9989d --- /dev/null +++ b/raft/src/test/java/org/apache/kafka/raft/LeaderStateTest.java @@ -0,0 +1,310 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.raft; + +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.raft.internals.BatchAccumulator; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; +import org.mockito.Mockito; + +import java.util.Arrays; +import java.util.Collections; +import java.util.Objects; +import java.util.Optional; +import java.util.Set; + +import static java.util.Collections.emptySet; +import static java.util.Collections.singleton; +import static org.apache.kafka.common.utils.Utils.mkEntry; +import static org.apache.kafka.common.utils.Utils.mkMap; +import static org.apache.kafka.common.utils.Utils.mkSet; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class LeaderStateTest { + private final int localId = 0; + private final int epoch = 5; + private final LogContext logContext = new LogContext(); + + private final BatchAccumulator accumulator = Mockito.mock(BatchAccumulator.class); + + private LeaderState newLeaderState( + Set voters, + long epochStartOffset + ) { + return new LeaderState<>( + localId, + epoch, + epochStartOffset, + voters, + voters, + accumulator, + logContext + ); + } + + @Test + public void testRequireNonNullAccumulator() { + assertThrows(NullPointerException.class, () -> new LeaderState<>( + localId, + epoch, + 0, + Collections.emptySet(), + Collections.emptySet(), + null, + logContext + )); + } + + @Test + public void testFollowerAcknowledgement() { + int node1 = 1; + int node2 = 2; + LeaderState state = newLeaderState(mkSet(localId, node1, node2), 0L); + assertEquals(mkSet(node1, node2), state.nonAcknowledgingVoters()); + state.addAcknowledgementFrom(node1); + assertEquals(singleton(node2), state.nonAcknowledgingVoters()); + state.addAcknowledgementFrom(node2); + assertEquals(emptySet(), state.nonAcknowledgingVoters()); + } + + @Test + public void testNonFollowerAcknowledgement() { + int nonVoterId = 1; + LeaderState state = newLeaderState(singleton(localId), 0L); + assertThrows(IllegalArgumentException.class, () -> state.addAcknowledgementFrom(nonVoterId)); + } + + @Test + public void testUpdateHighWatermarkQuorumSizeOne() { + LeaderState state = newLeaderState(singleton(localId), 15L); + assertEquals(Optional.empty(), state.highWatermark()); + assertFalse(state.updateLocalState(0, new LogOffsetMetadata(15L))); + assertEquals(emptySet(), state.nonAcknowledgingVoters()); + assertEquals(Optional.empty(), state.highWatermark()); + assertTrue(state.updateLocalState(0, new LogOffsetMetadata(16L))); + assertEquals(Optional.of(new LogOffsetMetadata(16L)), state.highWatermark()); + assertTrue(state.updateLocalState(0, new LogOffsetMetadata(20))); + assertEquals(Optional.of(new LogOffsetMetadata(20L)), state.highWatermark()); + } + + @Test + public void testNonMonotonicLocalEndOffsetUpdate() { + LeaderState state = newLeaderState(singleton(localId), 15L); + assertEquals(Optional.empty(), state.highWatermark()); + assertTrue(state.updateLocalState(0, new LogOffsetMetadata(16L))); + assertEquals(Optional.of(new LogOffsetMetadata(16L)), state.highWatermark()); + assertThrows(IllegalStateException.class, + () -> state.updateLocalState(0, new LogOffsetMetadata(15L))); + 
} + + @Test + public void testIdempotentEndOffsetUpdate() { + LeaderState state = newLeaderState(singleton(localId), 15L); + assertEquals(Optional.empty(), state.highWatermark()); + assertTrue(state.updateLocalState(0, new LogOffsetMetadata(16L))); + assertFalse(state.updateLocalState(0, new LogOffsetMetadata(16L))); + assertEquals(Optional.of(new LogOffsetMetadata(16L)), state.highWatermark()); + } + + @Test + public void testUpdateHighWatermarkMetadata() { + LeaderState state = newLeaderState(singleton(localId), 15L); + assertEquals(Optional.empty(), state.highWatermark()); + + LogOffsetMetadata initialHw = new LogOffsetMetadata(16L, Optional.of(new MockOffsetMetadata("bar"))); + assertTrue(state.updateLocalState(0, initialHw)); + assertEquals(Optional.of(initialHw), state.highWatermark()); + + LogOffsetMetadata updateHw = new LogOffsetMetadata(16L, Optional.of(new MockOffsetMetadata("baz"))); + assertTrue(state.updateLocalState(0, updateHw)); + assertEquals(Optional.of(updateHw), state.highWatermark()); + } + + @Test + public void testUpdateHighWatermarkQuorumSizeTwo() { + int otherNodeId = 1; + LeaderState state = newLeaderState(mkSet(localId, otherNodeId), 10L); + assertFalse(state.updateLocalState(0, new LogOffsetMetadata(13L))); + assertEquals(singleton(otherNodeId), state.nonAcknowledgingVoters()); + assertEquals(Optional.empty(), state.highWatermark()); + assertFalse(state.updateReplicaState(otherNodeId, 0, new LogOffsetMetadata(10L))); + assertEquals(emptySet(), state.nonAcknowledgingVoters()); + assertEquals(Optional.empty(), state.highWatermark()); + assertTrue(state.updateReplicaState(otherNodeId, 0, new LogOffsetMetadata(11L))); + assertEquals(Optional.of(new LogOffsetMetadata(11L)), state.highWatermark()); + assertTrue(state.updateReplicaState(otherNodeId, 0, new LogOffsetMetadata(13L))); + assertEquals(Optional.of(new LogOffsetMetadata(13L)), state.highWatermark()); + } + + @Test + public void testUpdateHighWatermarkQuorumSizeThree() { + int node1 = 1; + int node2 = 2; + LeaderState state = newLeaderState(mkSet(localId, node1, node2), 10L); + assertFalse(state.updateLocalState(0, new LogOffsetMetadata(15L))); + assertEquals(mkSet(node1, node2), state.nonAcknowledgingVoters()); + assertEquals(Optional.empty(), state.highWatermark()); + assertFalse(state.updateReplicaState(node1, 0, new LogOffsetMetadata(10L))); + assertEquals(singleton(node2), state.nonAcknowledgingVoters()); + assertEquals(Optional.empty(), state.highWatermark()); + assertFalse(state.updateReplicaState(node2, 0, new LogOffsetMetadata(10L))); + assertEquals(emptySet(), state.nonAcknowledgingVoters()); + assertEquals(Optional.empty(), state.highWatermark()); + assertTrue(state.updateReplicaState(node2, 0, new LogOffsetMetadata(15L))); + assertEquals(Optional.of(new LogOffsetMetadata(15L)), state.highWatermark()); + assertFalse(state.updateLocalState(0, new LogOffsetMetadata(20L))); + assertEquals(Optional.of(new LogOffsetMetadata(15L)), state.highWatermark()); + assertTrue(state.updateReplicaState(node1, 0, new LogOffsetMetadata(20L))); + assertEquals(Optional.of(new LogOffsetMetadata(20L)), state.highWatermark()); + assertFalse(state.updateReplicaState(node2, 0, new LogOffsetMetadata(20L))); + assertEquals(Optional.of(new LogOffsetMetadata(20L)), state.highWatermark()); + } + + @Test + public void testNonMonotonicHighWatermarkUpdate() { + MockTime time = new MockTime(); + int node1 = 1; + LeaderState state = newLeaderState(mkSet(localId, node1), 0L); + state.updateLocalState(time.milliseconds(), new 
LogOffsetMetadata(10L)); + state.updateReplicaState(node1, time.milliseconds(), new LogOffsetMetadata(10L)); + assertEquals(Optional.of(new LogOffsetMetadata(10L)), state.highWatermark()); + + // Follower crashes and disk is lost. It fetches an earlier offset to rebuild state. + // The leader will report an error in the logs, but will not let the high watermark rewind + assertFalse(state.updateReplicaState(node1, time.milliseconds(), new LogOffsetMetadata(5L))); + assertEquals(5L, state.getVoterEndOffsets().get(node1)); + assertEquals(Optional.of(new LogOffsetMetadata(10L)), state.highWatermark()); + } + + @Test + public void testGetNonLeaderFollowersByFetchOffsetDescending() { + int node1 = 1; + int node2 = 2; + long leaderStartOffset = 10L; + long leaderEndOffset = 15L; + + LeaderState state = setUpLeaderAndFollowers(node1, node2, leaderStartOffset, leaderEndOffset); + + // Leader should not be included; the follower with larger offset should be prioritized. + assertEquals(Arrays.asList(node2, node1), state.nonLeaderVotersByDescendingFetchOffset()); + } + + @Test + public void testGetVoterStates() { + int node1 = 1; + int node2 = 2; + long leaderStartOffset = 10L; + long leaderEndOffset = 15L; + + LeaderState state = setUpLeaderAndFollowers(node1, node2, leaderStartOffset, leaderEndOffset); + + assertEquals(mkMap( + mkEntry(localId, leaderEndOffset), + mkEntry(node1, leaderStartOffset), + mkEntry(node2, leaderEndOffset) + ), state.getVoterEndOffsets()); + } + + private LeaderState setUpLeaderAndFollowers(int follower1, + int follower2, + long leaderStartOffset, + long leaderEndOffset) { + LeaderState state = newLeaderState(mkSet(localId, follower1, follower2), leaderStartOffset); + state.updateLocalState(0, new LogOffsetMetadata(leaderEndOffset)); + assertEquals(Optional.empty(), state.highWatermark()); + state.updateReplicaState(follower1, 0, new LogOffsetMetadata(leaderStartOffset)); + state.updateReplicaState(follower2, 0, new LogOffsetMetadata(leaderEndOffset)); + return state; + } + + @Test + public void testGetObserverStatesWithObserver() { + int observerId = 10; + long epochStartOffset = 10L; + + LeaderState state = newLeaderState(mkSet(localId), epochStartOffset); + long timestamp = 20L; + assertFalse(state.updateReplicaState(observerId, timestamp, new LogOffsetMetadata(epochStartOffset))); + + assertEquals(Collections.singletonMap(observerId, epochStartOffset), state.getObserverStates(timestamp)); + } + + @Test + public void testNoOpForNegativeRemoteNodeId() { + int observerId = -1; + long epochStartOffset = 10L; + + LeaderState state = newLeaderState(mkSet(localId), epochStartOffset); + assertFalse(state.updateReplicaState(observerId, 0, new LogOffsetMetadata(epochStartOffset))); + + assertEquals(Collections.emptyMap(), state.getObserverStates(10)); + } + + @Test + public void testObserverStateExpiration() { + MockTime time = new MockTime(); + int observerId = 10; + long epochStartOffset = 10L; + LeaderState state = newLeaderState(mkSet(localId), epochStartOffset); + + state.updateReplicaState(observerId, time.milliseconds(), new LogOffsetMetadata(epochStartOffset)); + assertEquals(singleton(observerId), state.getObserverStates(time.milliseconds()).keySet()); + + time.sleep(LeaderState.OBSERVER_SESSION_TIMEOUT_MS); + assertEquals(emptySet(), state.getObserverStates(time.milliseconds()).keySet()); + } + + @ParameterizedTest + @ValueSource(booleans = {true, false}) + public void testGrantVote(boolean isLogUpToDate) { + LeaderState state = newLeaderState(Utils.mkSet(1, 2, 3), 1); 
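+        // A node that is already leader does not grant votes, regardless of whether the requesting voter's log is up to date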
+ + assertFalse(state.canGrantVote(1, isLogUpToDate)); + assertFalse(state.canGrantVote(2, isLogUpToDate)); + assertFalse(state.canGrantVote(3, isLogUpToDate)); + } + + private static class MockOffsetMetadata implements OffsetMetadata { + private final String value; + + private MockOffsetMetadata(String value) { + this.value = value; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + MockOffsetMetadata that = (MockOffsetMetadata) o; + return Objects.equals(value, that.value); + } + + @Override + public int hashCode() { + return Objects.hash(value); + } + } + +} diff --git a/raft/src/test/java/org/apache/kafka/raft/MockExpirationService.java b/raft/src/test/java/org/apache/kafka/raft/MockExpirationService.java new file mode 100644 index 0000000..f71e1db --- /dev/null +++ b/raft/src/test/java/org/apache/kafka/raft/MockExpirationService.java @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.raft; + +import org.apache.kafka.common.errors.TimeoutException; +import org.apache.kafka.common.utils.MockTime; + +import java.util.PriorityQueue; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.atomic.AtomicLong; + +public class MockExpirationService implements ExpirationService, MockTime.Listener { + private final AtomicLong idGenerator = new AtomicLong(0); + private final MockTime time; + private final PriorityQueue> queue = new PriorityQueue<>(); + + public MockExpirationService(MockTime time) { + this.time = time; + time.addListener(this); + } + + @Override + public CompletableFuture failAfter(long timeoutMs) { + long deadlineMs = time.milliseconds() + timeoutMs; + long id = idGenerator.incrementAndGet(); + ExpirationFuture future = new ExpirationFuture<>(id, deadlineMs); + queue.add(future); + return future; + } + + @Override + public void onTimeUpdated() { + long currentTimeMs = time.milliseconds(); + while (true) { + ExpirationFuture future = queue.peek(); + if (future == null || future.deadlineMs > currentTimeMs) { + break; + } + ExpirationFuture polled = queue.poll(); + polled.completeExceptionally(new TimeoutException()); + } + } + + private static class ExpirationFuture extends CompletableFuture implements Comparable> { + private final long id; + private final long deadlineMs; + + private ExpirationFuture(long id, long deadlineMs) { + this.id = id; + this.deadlineMs = deadlineMs; + } + + @Override + public int compareTo(ExpirationFuture o) { + int res = Long.compare(this.deadlineMs, o.deadlineMs); + if (res != 0) { + return res; + } else { + return Long.compare(this.id, o.id); + } + } + } + +} diff --git a/raft/src/test/java/org/apache/kafka/raft/MockExpirationServiceTest.java 
b/raft/src/test/java/org/apache/kafka/raft/MockExpirationServiceTest.java new file mode 100644 index 0000000..10ee16a --- /dev/null +++ b/raft/src/test/java/org/apache/kafka/raft/MockExpirationServiceTest.java @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.raft; + +import org.apache.kafka.common.errors.TimeoutException; +import org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.test.TestUtils; +import org.junit.jupiter.api.Test; + +import java.util.concurrent.CompletableFuture; + +import static org.junit.jupiter.api.Assertions.assertFalse; + +class MockExpirationServiceTest { + + private final MockTime time = new MockTime(); + private final MockExpirationService expirationService = new MockExpirationService(time); + + @Test + public void testFailAfter() { + CompletableFuture future1 = expirationService.failAfter(50); + CompletableFuture future2 = expirationService.failAfter(25); + CompletableFuture future3 = expirationService.failAfter(75); + CompletableFuture future4 = expirationService.failAfter(50); + + time.sleep(25); + TestUtils.assertFutureThrows(future2, TimeoutException.class); + assertFalse(future1.isDone()); + assertFalse(future3.isDone()); + assertFalse(future4.isDone()); + + time.sleep(25); + TestUtils.assertFutureThrows(future1, TimeoutException.class); + TestUtils.assertFutureThrows(future4, TimeoutException.class); + assertFalse(future3.isDone()); + + time.sleep(25); + TestUtils.assertFutureThrows(future3, TimeoutException.class); + } + +} \ No newline at end of file diff --git a/raft/src/test/java/org/apache/kafka/raft/MockLog.java b/raft/src/test/java/org/apache/kafka/raft/MockLog.java new file mode 100644 index 0000000..50cdeec --- /dev/null +++ b/raft/src/test/java/org/apache/kafka/raft/MockLog.java @@ -0,0 +1,702 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.raft; + +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.Uuid; +import org.apache.kafka.common.errors.OffsetOutOfRangeException; +import org.apache.kafka.common.record.CompressionType; +import org.apache.kafka.common.record.MemoryRecords; +import org.apache.kafka.common.record.MemoryRecordsBuilder; +import org.apache.kafka.common.record.Record; +import org.apache.kafka.common.record.RecordBatch; +import org.apache.kafka.common.record.Records; +import org.apache.kafka.common.record.SimpleRecord; +import org.apache.kafka.common.record.TimestampType; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.snapshot.MockRawSnapshotReader; +import org.apache.kafka.snapshot.MockRawSnapshotWriter; +import org.apache.kafka.snapshot.RawSnapshotReader; +import org.apache.kafka.snapshot.RawSnapshotWriter; +import org.slf4j.Logger; + +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.NavigableMap; +import java.util.Objects; +import java.util.Optional; +import java.util.OptionalInt; +import java.util.OptionalLong; +import java.util.TreeMap; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Function; +import java.util.function.Predicate; +import java.util.stream.Collectors; + +public class MockLog implements ReplicatedLog { + private static final AtomicLong ID_GENERATOR = new AtomicLong(); + + private final List epochStartOffsets = new ArrayList<>(); + private final List batches = new ArrayList<>(); + private final NavigableMap snapshots = new TreeMap<>(); + private final TopicPartition topicPartition; + private final Uuid topicId; + private final Logger logger; + + private long nextId = ID_GENERATOR.getAndIncrement(); + private LogOffsetMetadata highWatermark = new LogOffsetMetadata(0, Optional.empty()); + private long lastFlushedOffset = 0; + + public MockLog( + TopicPartition topicPartition, + Uuid topicId, + LogContext logContext + ) { + this.topicPartition = topicPartition; + this.topicId = topicId; + this.logger = logContext.logger(MockLog.class); + } + + @Override + public void truncateTo(long offset) { + if (offset < highWatermark.offset) { + throw new IllegalArgumentException("Illegal attempt to truncate to offset " + offset + + " which is below the current high watermark " + highWatermark); + } + + batches.removeIf(entry -> entry.lastOffset() >= offset); + epochStartOffsets.removeIf(epochStartOffset -> epochStartOffset.startOffset >= offset); + } + + @Override + public boolean truncateToLatestSnapshot() { + AtomicBoolean truncated = new AtomicBoolean(false); + latestSnapshotId().ifPresent(snapshotId -> { + if (snapshotId.epoch > logLastFetchedEpoch().orElse(0) || + (snapshotId.epoch == logLastFetchedEpoch().orElse(0) && + snapshotId.offset > endOffset().offset)) { + + batches.clear(); + epochStartOffsets.clear(); + snapshots.headMap(snapshotId, false).clear(); + updateHighWatermark(new LogOffsetMetadata(snapshotId.offset)); + flush(); + + truncated.set(true); + } + }); + + return truncated.get(); + } + + @Override + public void updateHighWatermark(LogOffsetMetadata offsetMetadata) { + if (this.highWatermark.offset > offsetMetadata.offset) { + throw new IllegalArgumentException("Non-monotonic update of current high watermark " + + highWatermark 
+ " to new value " + offsetMetadata); + } else if (offsetMetadata.offset > endOffset().offset) { + throw new IllegalArgumentException("Attempt to update high watermark to " + offsetMetadata + + " which is larger than the current end offset " + endOffset()); + } else if (offsetMetadata.offset < startOffset()) { + throw new IllegalArgumentException("Attempt to update high watermark to " + offsetMetadata + + " which is smaller than the current start offset " + startOffset()); + } + + assertValidHighWatermarkMetadata(offsetMetadata); + this.highWatermark = offsetMetadata; + } + + @Override + public LogOffsetMetadata highWatermark() { + return highWatermark; + } + + @Override + public TopicPartition topicPartition() { + return topicPartition; + } + + @Override + public Uuid topicId() { + return topicId; + } + + private Optional metadataForOffset(long offset) { + if (offset == endOffset().offset) { + return endOffset().metadata; + } + + for (LogBatch batch : batches) { + if (batch.lastOffset() < offset) + continue; + + for (LogEntry entry : batch.entries) { + if (entry.offset == offset) { + return Optional.of(entry.metadata); + } + } + } + + return Optional.empty(); + } + + private void assertValidHighWatermarkMetadata(LogOffsetMetadata offsetMetadata) { + if (!offsetMetadata.metadata.isPresent()) { + return; + } + + long id = ((MockOffsetMetadata) offsetMetadata.metadata.get()).id; + long offset = offsetMetadata.offset; + + metadataForOffset(offset).ifPresent(metadata -> { + long entryId = ((MockOffsetMetadata) metadata).id; + if (entryId != id) { + throw new IllegalArgumentException("High watermark " + offset + + " metadata uuid " + id + " does not match the " + + " log's record entry maintained uuid " + entryId); + } + }); + } + + private OptionalInt logLastFetchedEpoch() { + if (epochStartOffsets.isEmpty()) { + return OptionalInt.empty(); + } else { + return OptionalInt.of(epochStartOffsets.get(epochStartOffsets.size() - 1).epoch); + } + } + + @Override + public int lastFetchedEpoch() { + return logLastFetchedEpoch().orElseGet(() -> latestSnapshotId().map(id -> id.epoch).orElse(0)); + } + + @Override + public OffsetAndEpoch endOffsetForEpoch(int epoch) { + return lastOffsetAndEpochFiltered(epochStartOffset -> epochStartOffset.epoch <= epoch); + } + + private OffsetAndEpoch epochForEndOffset(long endOffset) { + return lastOffsetAndEpochFiltered(epochStartOffset -> epochStartOffset.startOffset < endOffset); + } + + private OffsetAndEpoch lastOffsetAndEpochFiltered(Predicate predicate) { + int epochLowerBound = earliestSnapshotId().map(id -> id.epoch).orElse(0); + for (EpochStartOffset epochStartOffset : epochStartOffsets) { + if (!predicate.test(epochStartOffset)) { + return new OffsetAndEpoch(epochStartOffset.startOffset, epochLowerBound); + } + epochLowerBound = epochStartOffset.epoch; + } + + return new OffsetAndEpoch(endOffset().offset, lastFetchedEpoch()); + } + + private Optional lastEntry() { + if (batches.isEmpty()) + return Optional.empty(); + return Optional.of(batches.get(batches.size() - 1).last()); + } + + private Optional firstEntry() { + if (batches.isEmpty()) + return Optional.empty(); + return Optional.of(batches.get(0).first()); + } + + @Override + public LogOffsetMetadata endOffset() { + long nextOffset = lastEntry() + .map(entry -> entry.offset + 1) + .orElse( + latestSnapshotId() + .map(id -> id.offset) + .orElse(0L) + ); + return new LogOffsetMetadata(nextOffset, Optional.of(new MockOffsetMetadata(nextId))); + } + + @Override + public long startOffset() { + return 
firstEntry() + .map(entry -> entry.offset) + .orElse( + earliestSnapshotId() + .map(id -> id.offset) + .orElse(0L) + ); + } + + private List buildEntries(RecordBatch batch, Function offsetSupplier) { + List entries = new ArrayList<>(); + for (Record record : batch) { + long offset = offsetSupplier.apply(record); + long timestamp = record.timestamp(); + ByteBuffer key = copy(record.key()); + ByteBuffer value = copy(record.value()); + entries.add(buildEntry(offset, new SimpleRecord(timestamp, key, value))); + } + return entries; + } + + private ByteBuffer copy(ByteBuffer nullableByteBuffer) { + if (nullableByteBuffer == null) { + return null; + } else { + byte[] array = Utils.toArray(nullableByteBuffer, nullableByteBuffer.position(), nullableByteBuffer.limit()); + return ByteBuffer.wrap(array); + } + } + + private LogEntry buildEntry(Long offset, SimpleRecord record) { + long id = nextId; + nextId = ID_GENERATOR.getAndIncrement(); + return new LogEntry(new MockOffsetMetadata(id), offset, record); + } + + + @Override + public LogAppendInfo appendAsLeader(Records records, int epoch) { + return append(records, OptionalInt.of(epoch)); + } + + private Long appendBatch(LogBatch batch) { + if (batch.epoch > lastFetchedEpoch()) { + epochStartOffsets.add(new EpochStartOffset(batch.epoch, batch.firstOffset())); + } + batches.add(batch); + return batch.firstOffset(); + } + + @Override + public LogAppendInfo appendAsFollower(Records records) { + return append(records, OptionalInt.empty()); + } + + private LogAppendInfo append(Records records, OptionalInt epoch) { + if (records.sizeInBytes() == 0) + throw new IllegalArgumentException("Attempt to append an empty record set"); + + long baseOffset = endOffset().offset; + long lastOffset = baseOffset; + for (RecordBatch batch : records.batches()) { + if (batch.baseOffset() != endOffset().offset) { + /* KafkaMetadataLog throws an kafka.common.UnexpectedAppendOffsetException this is the + * best we can do from this module. + */ + throw new RuntimeException( + String.format( + "Illegal append at offset %s with current end offset of %s", + batch.baseOffset(), + endOffset().offset + ) + ); + } + + List entries = buildEntries(batch, Record::offset); + appendBatch( + new LogBatch( + epoch.orElseGet(batch::partitionLeaderEpoch), + batch.isControlBatch(), + entries + ) + ); + lastOffset = entries.get(entries.size() - 1).offset; + } + + return new LogAppendInfo(baseOffset, lastOffset); + } + + @Override + public void flush() { + lastFlushedOffset = endOffset().offset; + } + + @Override + public boolean maybeClean() { + return false; + } + + @Override + public long lastFlushedOffset() { + return lastFlushedOffset; + } + + /** + * Reopening the log causes all unflushed data to be lost. 
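+     * Batches and epoch start offsets at or above the last flushed offset are removed and the high watermark is reset to zero.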
+     */
+    public void reopen() {
+        batches.removeIf(batch -> batch.firstOffset() >= lastFlushedOffset);
+        epochStartOffsets.removeIf(epochStartOffset -> epochStartOffset.startOffset >= lastFlushedOffset);
+        highWatermark = new LogOffsetMetadata(0L, Optional.empty());
+    }
+
+    public List<LogBatch> readBatches(long startOffset, OptionalLong maxOffsetOpt) {
+        verifyOffsetInRange(startOffset);
+
+        long maxOffset = maxOffsetOpt.orElse(endOffset().offset);
+        if (startOffset == maxOffset) {
+            return Collections.emptyList();
+        }
+
+        return batches.stream()
+            .filter(batch -> batch.lastOffset() >= startOffset && batch.lastOffset() < maxOffset)
+            .collect(Collectors.toList());
+    }
+
+    private void verifyOffsetInRange(long offset) {
+        if (offset > endOffset().offset) {
+            throw new OffsetOutOfRangeException("Requested offset " + offset + " is larger than "
+                + "the log end offset " + endOffset().offset);
+        }
+
+        if (offset < this.startOffset()) {
+            throw new OffsetOutOfRangeException("Requested offset " + offset + " is smaller than "
+                + "the log start offset " + this.startOffset());
+        }
+    }
+
+    @Override
+    public LogFetchInfo read(long startOffset, Isolation isolation) {
+        OptionalLong maxOffsetOpt = isolation == Isolation.COMMITTED ?
+            OptionalLong.of(highWatermark.offset) :
+            OptionalLong.empty();
+
+        verifyOffsetInRange(startOffset);
+
+        long maxOffset = maxOffsetOpt.orElse(endOffset().offset);
+        if (startOffset >= maxOffset) {
+            return new LogFetchInfo(MemoryRecords.EMPTY, new LogOffsetMetadata(
+                startOffset, metadataForOffset(startOffset)));
+        }
+
+        ByteBuffer buffer = ByteBuffer.allocate(512);
+        int batchCount = 0;
+        LogOffsetMetadata batchStartOffset = null;
+
+        for (LogBatch batch : batches) {
+            // Note that start offset is inclusive while max offset is exclusive. We only return
+            // complete batches, so batches which end at an offset larger than the max offset are
+            // filtered, which is effectively the same as having the consumer drop an incomplete
+            // batch returned in a fetch response.
+            if (batch.lastOffset() >= startOffset && batch.lastOffset() < maxOffset && !batch.entries.isEmpty()) {
+                buffer = batch.writeTo(buffer);
+
+                if (batchStartOffset == null) {
+                    batchStartOffset = batch.entries.get(0).logOffsetMetadata();
+                }
+
+                // Read on the mock log should return at most 2 batches. This is a simple solution
+                // for testing interesting partial read scenarios.
+ batchCount += 1; + if (batchCount >= 2) { + break; + } + } + } + + buffer.flip(); + Records records = MemoryRecords.readableRecords(buffer); + + if (batchStartOffset == null) { + throw new RuntimeException("Expected to find at least one entry starting from offset " + + startOffset + " but found none"); + } + + return new LogFetchInfo(records, batchStartOffset); + } + + @Override + public void initializeLeaderEpoch(int epoch) { + long startOffset = endOffset().offset; + epochStartOffsets.removeIf(epochStartOffset -> + epochStartOffset.startOffset >= startOffset || epochStartOffset.epoch >= epoch); + epochStartOffsets.add(new EpochStartOffset(epoch, startOffset)); + } + + @Override + public Optional createNewSnapshot(OffsetAndEpoch snapshotId) { + if (snapshotId.offset < startOffset()) { + logger.info( + "Cannot create a snapshot with an id ({}) less than the log start offset ({})", + snapshotId, + startOffset() + ); + + return Optional.empty(); + } + + long highWatermarkOffset = highWatermark().offset; + if (snapshotId.offset > highWatermarkOffset) { + throw new IllegalArgumentException( + String.format( + "Cannot create a snapshot with an id (%s) greater than the high-watermark (%s)", + snapshotId, + highWatermarkOffset + ) + ); + } + + ValidOffsetAndEpoch validOffsetAndEpoch = validateOffsetAndEpoch(snapshotId.offset, snapshotId.epoch); + if (validOffsetAndEpoch.kind() != ValidOffsetAndEpoch.Kind.VALID) { + throw new IllegalArgumentException( + String.format( + "Snapshot id (%s) is not valid according to the log: %s", + snapshotId, + validOffsetAndEpoch + ) + ); + } + + return storeSnapshot(snapshotId); + } + + @Override + public Optional storeSnapshot(OffsetAndEpoch snapshotId) { + if (snapshots.containsKey(snapshotId)) { + return Optional.empty(); + } else { + return Optional.of( + new MockRawSnapshotWriter(snapshotId, buffer -> { + snapshots.putIfAbsent(snapshotId, new MockRawSnapshotReader(snapshotId, buffer)); + }) + ); + } + } + + @Override + public Optional readSnapshot(OffsetAndEpoch snapshotId) { + return Optional.ofNullable(snapshots.get(snapshotId)); + } + + @Override + public Optional latestSnapshot() { + return latestSnapshotId().flatMap(this::readSnapshot); + } + + @Override + public Optional latestSnapshotId() { + return Optional.ofNullable(snapshots.lastEntry()) + .map(Map.Entry::getKey); + } + + @Override + public Optional earliestSnapshotId() { + return Optional.ofNullable(snapshots.firstEntry()) + .map(Map.Entry::getKey); + } + + @Override + public void onSnapshotFrozen(OffsetAndEpoch snapshotId) {} + + @Override + public boolean deleteBeforeSnapshot(OffsetAndEpoch snapshotId) { + if (startOffset() > snapshotId.offset) { + throw new OffsetOutOfRangeException( + String.format( + "New log start (%s) is less than the curent log start offset (%s)", + snapshotId, + startOffset() + ) + ); + } + if (highWatermark.offset < snapshotId.offset) { + throw new OffsetOutOfRangeException( + String.format( + "New log start (%s) is greater than the high watermark (%s)", + snapshotId, + highWatermark.offset + ) + ); + } + + boolean updated = false; + if (snapshots.containsKey(snapshotId)) { + snapshots.headMap(snapshotId, false).clear(); + + batches.removeIf(entry -> entry.lastOffset() < snapshotId.offset); + + AtomicReference> last = new AtomicReference<>(Optional.empty()); + epochStartOffsets.removeIf(epochStartOffset -> { + if (epochStartOffset.startOffset <= snapshotId.offset) { + last.set(Optional.of(epochStartOffset)); + return true; + } + + return false; + }); + + 
last.get().ifPresent(epochStartOffset -> { + epochStartOffsets.add( + 0, + new EpochStartOffset(epochStartOffset.epoch, snapshotId.offset) + ); + }); + + updated = true; + } + + return updated; + } + + static class MockOffsetMetadata implements OffsetMetadata { + final long id; + + MockOffsetMetadata(long id) { + this.id = id; + } + + @Override + public String toString() { + return "MockOffsetMetadata(" + + "id=" + id + + ')'; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + MockOffsetMetadata that = (MockOffsetMetadata) o; + return id == that.id; + } + + @Override + public int hashCode() { + return Objects.hash(id); + } + } + + static class LogEntry { + final MockOffsetMetadata metadata; + final long offset; + final SimpleRecord record; + + LogEntry(MockOffsetMetadata metadata, long offset, SimpleRecord record) { + this.metadata = metadata; + this.offset = offset; + this.record = record; + } + + LogOffsetMetadata logOffsetMetadata() { + return new LogOffsetMetadata(offset, Optional.of(metadata)); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + LogEntry logEntry = (LogEntry) o; + return offset == logEntry.offset && + Objects.equals(metadata, logEntry.metadata) && + Objects.equals(record, logEntry.record); + } + + @Override + public int hashCode() { + return Objects.hash(metadata, offset, record); + } + + @Override + public String toString() { + return String.format( + "LogEntry(metadata=%s, offset=%s, record=%s)", + metadata, + offset, + record + ); + } + } + + static class LogBatch { + final List entries; + final int epoch; + final boolean isControlBatch; + + LogBatch(int epoch, boolean isControlBatch, List entries) { + if (entries.isEmpty()) + throw new IllegalArgumentException("Empty batches are not supported"); + this.entries = entries; + this.epoch = epoch; + this.isControlBatch = isControlBatch; + } + + long firstOffset() { + return first().offset; + } + + LogEntry first() { + return entries.get(0); + } + + long lastOffset() { + return last().offset; + } + + LogEntry last() { + return entries.get(entries.size() - 1); + } + + ByteBuffer writeTo(ByteBuffer buffer) { + LogEntry first = first(); + + MemoryRecordsBuilder builder = MemoryRecords.builder( + buffer, RecordBatch.CURRENT_MAGIC_VALUE, CompressionType.NONE, + TimestampType.CREATE_TIME, first.offset, first.record.timestamp(), + RecordBatch.NO_PRODUCER_ID, RecordBatch.NO_PRODUCER_EPOCH, + RecordBatch.NO_SEQUENCE, false, + isControlBatch, epoch); + + for (LogEntry entry : entries) { + if (isControlBatch) { + builder.appendControlRecordWithOffset(entry.offset, entry.record); + } else { + builder.appendWithOffset(entry.offset, entry.record); + } + } + + builder.close(); + return builder.buffer(); + } + + @Override + public String toString() { + return String.format("LogBatch(entries=%s, epoch=%s, isControlBatch=%s)", entries, epoch, isControlBatch); + } + } + + private static class EpochStartOffset { + final int epoch; + final long startOffset; + + private EpochStartOffset(int epoch, long startOffset) { + this.epoch = epoch; + this.startOffset = startOffset; + } + + @Override + public String toString() { + return String.format("EpochStartOffset(epoch=%s, startOffset=%s)", epoch, startOffset); + } + } +} diff --git a/raft/src/test/java/org/apache/kafka/raft/MockLogTest.java b/raft/src/test/java/org/apache/kafka/raft/MockLogTest.java new file 
mode 100644 index 0000000..9365640 --- /dev/null +++ b/raft/src/test/java/org/apache/kafka/raft/MockLogTest.java @@ -0,0 +1,999 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.raft; + +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.Uuid; +import org.apache.kafka.common.errors.OffsetOutOfRangeException; +import org.apache.kafka.common.message.LeaderChangeMessage; +import org.apache.kafka.common.record.CompressionType; +import org.apache.kafka.common.record.ControlRecordUtils; +import org.apache.kafka.common.record.MemoryRecords; +import org.apache.kafka.common.record.Record; +import org.apache.kafka.common.record.RecordBatch; +import org.apache.kafka.common.record.Records; +import org.apache.kafka.common.record.SimpleRecord; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.snapshot.RawSnapshotReader; +import org.apache.kafka.snapshot.RawSnapshotWriter; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.List; +import java.util.Objects; +import java.util.Optional; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class MockLogTest { + + private MockLog log; + private final TopicPartition topicPartition = new TopicPartition("mock-topic", 0); + private final Uuid topicId = Uuid.randomUuid(); + + @BeforeEach + public void setup() { + log = new MockLog(topicPartition, topicId, new LogContext()); + } + + @AfterEach + public void cleanup() { + log.close(); + } + + @Test + public void testTopicPartition() { + assertEquals(topicPartition, log.topicPartition()); + } + + @Test + public void testTopicId() { + assertEquals(topicId, log.topicId()); + } + + @Test + public void testTruncateTo() { + int epoch = 2; + SimpleRecord recordOne = new SimpleRecord("one".getBytes()); + SimpleRecord recordTwo = new SimpleRecord("two".getBytes()); + appendAsLeader(Arrays.asList(recordOne, recordTwo), epoch); + + SimpleRecord recordThree = new SimpleRecord("three".getBytes()); + appendAsLeader(Collections.singleton(recordThree), epoch); + + assertEquals(0L, log.startOffset()); + assertEquals(3L, log.endOffset().offset); + + log.truncateTo(2); + assertEquals(0L, log.startOffset()); + assertEquals(2L, log.endOffset().offset); + + log.truncateTo(1); + assertEquals(0L, 
log.startOffset()); + assertEquals(0L, log.endOffset().offset); + } + + @Test + public void testTruncateBelowHighWatermark() { + appendBatch(5, 1); + LogOffsetMetadata highWatermark = new LogOffsetMetadata(5L); + log.updateHighWatermark(highWatermark); + assertEquals(highWatermark, log.highWatermark()); + assertThrows(IllegalArgumentException.class, () -> log.truncateTo(4L)); + assertEquals(highWatermark, log.highWatermark()); + } + + @Test + public void testUpdateHighWatermark() { + appendBatch(5, 1); + LogOffsetMetadata newOffset = new LogOffsetMetadata(5L); + log.updateHighWatermark(newOffset); + assertEquals(newOffset.offset, log.highWatermark().offset); + } + + @Test + public void testDecrementHighWatermark() { + appendBatch(5, 1); + LogOffsetMetadata newOffset = new LogOffsetMetadata(4L); + log.updateHighWatermark(newOffset); + assertThrows(IllegalArgumentException.class, () -> log.updateHighWatermark(new LogOffsetMetadata(3L))); + } + + @Test + public void testAssignEpochStartOffset() { + log.initializeLeaderEpoch(2); + assertEquals(2, log.lastFetchedEpoch()); + } + + @Test + public void testAppendAsLeader() { + int epoch = 2; + SimpleRecord recordOne = new SimpleRecord("one".getBytes()); + List expectedRecords = new ArrayList<>(); + + expectedRecords.add(recordOne); + appendAsLeader(Collections.singleton(recordOne), epoch); + + assertEquals(new OffsetAndEpoch(expectedRecords.size(), epoch), log.endOffsetForEpoch(epoch)); + assertEquals(epoch, log.lastFetchedEpoch()); + validateReadRecords(expectedRecords, log); + + SimpleRecord recordTwo = new SimpleRecord("two".getBytes()); + SimpleRecord recordThree = new SimpleRecord("three".getBytes()); + expectedRecords.add(recordTwo); + expectedRecords.add(recordThree); + appendAsLeader(Arrays.asList(recordTwo, recordThree), epoch); + + assertEquals(new OffsetAndEpoch(expectedRecords.size(), epoch), log.endOffsetForEpoch(epoch)); + assertEquals(epoch, log.lastFetchedEpoch()); + validateReadRecords(expectedRecords, log); + } + + @Test + public void testUnexpectedAppendOffset() { + SimpleRecord recordFoo = new SimpleRecord("foo".getBytes()); + final int currentEpoch = 3; + final long initialOffset = log.endOffset().offset; + + log.appendAsLeader( + MemoryRecords.withRecords(initialOffset, CompressionType.NONE, currentEpoch, recordFoo), + currentEpoch + ); + + // Throw exception for out of order records + assertThrows( + RuntimeException.class, + () -> { + log.appendAsLeader( + MemoryRecords.withRecords(initialOffset, CompressionType.NONE, currentEpoch, recordFoo), + currentEpoch + ); + } + ); + + assertThrows( + RuntimeException.class, + () -> { + log.appendAsFollower( + MemoryRecords.withRecords(initialOffset, CompressionType.NONE, currentEpoch, recordFoo) + ); + } + ); + } + + @Test + public void testAppendControlRecord() { + final long initialOffset = 0; + final int currentEpoch = 3; + LeaderChangeMessage messageData = new LeaderChangeMessage().setLeaderId(0); + ByteBuffer buffer = ByteBuffer.allocate(256); + log.appendAsLeader( + MemoryRecords.withLeaderChangeMessage(initialOffset, 0L, 2, buffer, messageData), + currentEpoch + ); + + assertEquals(0, log.startOffset()); + assertEquals(1, log.endOffset().offset); + assertEquals(currentEpoch, log.lastFetchedEpoch()); + + Records records = log.read(0, Isolation.UNCOMMITTED).records; + for (RecordBatch batch : records.batches()) { + assertTrue(batch.isControlBatch()); + } + List extractRecords = new ArrayList<>(); + for (Record record : records.records()) { + LeaderChangeMessage deserializedData 
= ControlRecordUtils.deserializeLeaderChangeMessage(record); + assertEquals(deserializedData, messageData); + extractRecords.add(record.value()); + } + + assertEquals(1, extractRecords.size()); + assertEquals(new OffsetAndEpoch(1, currentEpoch), log.endOffsetForEpoch(currentEpoch)); + } + + @Test + public void testAppendAsFollower() throws IOException { + final long initialOffset = 5; + final int epoch = 3; + SimpleRecord recordFoo = new SimpleRecord("foo".getBytes()); + + try (RawSnapshotWriter snapshot = log.storeSnapshot(new OffsetAndEpoch(initialOffset, 0)).get()) { + snapshot.freeze(); + } + log.truncateToLatestSnapshot(); + + log.appendAsFollower(MemoryRecords.withRecords(initialOffset, CompressionType.NONE, epoch, recordFoo)); + + assertEquals(initialOffset, log.startOffset()); + assertEquals(initialOffset + 1, log.endOffset().offset); + assertEquals(3, log.lastFetchedEpoch()); + + Records records = log.read(5L, Isolation.UNCOMMITTED).records; + List extractRecords = new ArrayList<>(); + for (Record record : records.records()) { + extractRecords.add(record.value()); + } + + assertEquals(1, extractRecords.size()); + assertEquals(recordFoo.value(), extractRecords.get(0)); + assertEquals(new OffsetAndEpoch(5, 0), log.endOffsetForEpoch(0)); + assertEquals(new OffsetAndEpoch(log.endOffset().offset, epoch), log.endOffsetForEpoch(epoch)); + } + + @Test + public void testReadRecords() { + int epoch = 2; + + ByteBuffer recordOneBuffer = ByteBuffer.allocate(4); + recordOneBuffer.putInt(1); + SimpleRecord recordOne = new SimpleRecord(recordOneBuffer); + + ByteBuffer recordTwoBuffer = ByteBuffer.allocate(4); + recordTwoBuffer.putInt(2); + SimpleRecord recordTwo = new SimpleRecord(recordTwoBuffer); + + appendAsLeader(Arrays.asList(recordOne, recordTwo), epoch); + + Records records = log.read(0, Isolation.UNCOMMITTED).records; + + List extractRecords = new ArrayList<>(); + for (Record record : records.records()) { + extractRecords.add(record.value()); + } + assertEquals(Arrays.asList(recordOne.value(), recordTwo.value()), extractRecords); + } + + @Test + public void testReadUpToLogEnd() { + appendBatch(20, 1); + appendBatch(10, 1); + appendBatch(30, 1); + + assertEquals(Optional.of(new OffsetRange(0L, 59L)), readOffsets(0L, Isolation.UNCOMMITTED)); + assertEquals(Optional.of(new OffsetRange(0L, 59L)), readOffsets(10L, Isolation.UNCOMMITTED)); + assertEquals(Optional.of(new OffsetRange(20L, 59L)), readOffsets(20L, Isolation.UNCOMMITTED)); + assertEquals(Optional.of(new OffsetRange(20L, 59L)), readOffsets(25L, Isolation.UNCOMMITTED)); + assertEquals(Optional.of(new OffsetRange(30L, 59L)), readOffsets(30L, Isolation.UNCOMMITTED)); + assertEquals(Optional.of(new OffsetRange(30L, 59L)), readOffsets(33L, Isolation.UNCOMMITTED)); + assertEquals(Optional.empty(), readOffsets(60L, Isolation.UNCOMMITTED)); + assertThrows(OffsetOutOfRangeException.class, () -> log.read(61L, Isolation.UNCOMMITTED)); + + // Verify range after truncation + log.truncateTo(20L); + assertThrows(OffsetOutOfRangeException.class, () -> log.read(21L, Isolation.UNCOMMITTED)); + } + + @Test + public void testReadUpToHighWatermark() { + appendBatch(20, 1); + appendBatch(10, 1); + appendBatch(30, 1); + + log.updateHighWatermark(new LogOffsetMetadata(0L)); + assertEquals(Optional.empty(), readOffsets(0L, Isolation.COMMITTED)); + assertEquals(Optional.empty(), readOffsets(10L, Isolation.COMMITTED)); + + log.updateHighWatermark(new LogOffsetMetadata(20L)); + assertEquals(Optional.of(new OffsetRange(0L, 19L)), readOffsets(0L, 
Isolation.COMMITTED)); + assertEquals(Optional.of(new OffsetRange(0L, 19L)), readOffsets(10L, Isolation.COMMITTED)); + assertEquals(Optional.empty(), readOffsets(20L, Isolation.COMMITTED)); + assertEquals(Optional.empty(), readOffsets(30L, Isolation.COMMITTED)); + + log.updateHighWatermark(new LogOffsetMetadata(30L)); + assertEquals(Optional.of(new OffsetRange(0L, 29L)), readOffsets(0L, Isolation.COMMITTED)); + assertEquals(Optional.of(new OffsetRange(0L, 29L)), readOffsets(10L, Isolation.COMMITTED)); + assertEquals(Optional.of(new OffsetRange(20L, 29L)), readOffsets(20L, Isolation.COMMITTED)); + assertEquals(Optional.of(new OffsetRange(20L, 29L)), readOffsets(25L, Isolation.COMMITTED)); + assertEquals(Optional.empty(), readOffsets(30L, Isolation.COMMITTED)); + assertEquals(Optional.empty(), readOffsets(50L, Isolation.COMMITTED)); + + log.updateHighWatermark(new LogOffsetMetadata(60L)); + assertEquals(Optional.of(new OffsetRange(0L, 59L)), readOffsets(0L, Isolation.COMMITTED)); + assertEquals(Optional.of(new OffsetRange(0L, 59L)), readOffsets(10L, Isolation.COMMITTED)); + assertEquals(Optional.of(new OffsetRange(20L, 59L)), readOffsets(20L, Isolation.COMMITTED)); + assertEquals(Optional.of(new OffsetRange(20L, 59L)), readOffsets(25L, Isolation.COMMITTED)); + assertEquals(Optional.of(new OffsetRange(30, 59L)), readOffsets(30L, Isolation.COMMITTED)); + assertEquals(Optional.of(new OffsetRange(30L, 59L)), readOffsets(50L, Isolation.COMMITTED)); + assertEquals(Optional.empty(), readOffsets(60L, Isolation.COMMITTED)); + assertThrows(OffsetOutOfRangeException.class, () -> log.read(61L, Isolation.COMMITTED)); + } + + @Test + public void testMetadataValidation() { + appendBatch(5, 1); + appendBatch(5, 1); + appendBatch(5, 1); + + LogFetchInfo readInfo = log.read(5, Isolation.UNCOMMITTED); + assertEquals(5L, readInfo.startOffsetMetadata.offset); + assertTrue(readInfo.startOffsetMetadata.metadata.isPresent()); + MockLog.MockOffsetMetadata offsetMetadata = (MockLog.MockOffsetMetadata) + readInfo.startOffsetMetadata.metadata.get(); + + // Update to a high watermark with valid offset metadata + log.updateHighWatermark(readInfo.startOffsetMetadata); + assertEquals(readInfo.startOffsetMetadata.offset, log.highWatermark().offset); + + // Now update to a high watermark with invalid metadata + assertThrows(IllegalArgumentException.class, () -> + log.updateHighWatermark(new LogOffsetMetadata(10L, + Optional.of(new MockLog.MockOffsetMetadata(98230980L))))); + + // Ensure we can update the high watermark to the end offset + LogFetchInfo readFromEndInfo = log.read(15L, Isolation.UNCOMMITTED); + assertEquals(15, readFromEndInfo.startOffsetMetadata.offset); + assertTrue(readFromEndInfo.startOffsetMetadata.metadata.isPresent()); + log.updateHighWatermark(readFromEndInfo.startOffsetMetadata); + + // Ensure that the end offset metadata is valid after new entries are appended + appendBatch(5, 1); + log.updateHighWatermark(readFromEndInfo.startOffsetMetadata); + + // Check handling of a fetch from the middle of a batch + LogFetchInfo readFromMiddleInfo = log.read(16L, Isolation.UNCOMMITTED); + assertEquals(readFromEndInfo.startOffsetMetadata, readFromMiddleInfo.startOffsetMetadata); + } + + @Test + public void testEndOffsetForEpoch() { + appendBatch(5, 1); + appendBatch(10, 1); + appendBatch(5, 3); + appendBatch(10, 4); + + assertEquals(new OffsetAndEpoch(0, 0), log.endOffsetForEpoch(0)); + assertEquals(new OffsetAndEpoch(15L, 1), log.endOffsetForEpoch(1)); + assertEquals(new OffsetAndEpoch(15L, 1), 
log.endOffsetForEpoch(2)); + assertEquals(new OffsetAndEpoch(20L, 3), log.endOffsetForEpoch(3)); + assertEquals(new OffsetAndEpoch(30L, 4), log.endOffsetForEpoch(4)); + assertEquals(new OffsetAndEpoch(30L, 4), log.endOffsetForEpoch(5)); + } + + @Test + public void testEmptyAppendNotAllowed() { + assertThrows(IllegalArgumentException.class, () -> log.appendAsFollower(MemoryRecords.EMPTY)); + assertThrows(IllegalArgumentException.class, () -> log.appendAsLeader(MemoryRecords.EMPTY, 1)); + } + + @Test + public void testReadOutOfRangeOffset() throws IOException { + final long initialOffset = 5L; + final int epoch = 3; + SimpleRecord recordFoo = new SimpleRecord("foo".getBytes()); + + try (RawSnapshotWriter snapshot = log.storeSnapshot(new OffsetAndEpoch(initialOffset, 0)).get()) { + snapshot.freeze(); + } + log.truncateToLatestSnapshot(); + + log.appendAsFollower(MemoryRecords.withRecords(initialOffset, CompressionType.NONE, epoch, recordFoo)); + + assertThrows(OffsetOutOfRangeException.class, () -> log.read(log.startOffset() - 1, + Isolation.UNCOMMITTED)); + assertThrows(OffsetOutOfRangeException.class, () -> log.read(log.endOffset().offset + 1, + Isolation.UNCOMMITTED)); + } + + @Test + public void testMonotonicEpochStartOffset() { + appendBatch(5, 1); + assertEquals(5L, log.endOffset().offset); + + log.initializeLeaderEpoch(2); + assertEquals(new OffsetAndEpoch(5L, 1), log.endOffsetForEpoch(1)); + assertEquals(new OffsetAndEpoch(5L, 2), log.endOffsetForEpoch(2)); + + // Initialize a new epoch at the same end offset. The epoch cache ensures + // that the start offset of each retained epoch increases monotonically. + log.initializeLeaderEpoch(3); + assertEquals(new OffsetAndEpoch(5L, 1), log.endOffsetForEpoch(1)); + assertEquals(new OffsetAndEpoch(5L, 1), log.endOffsetForEpoch(2)); + assertEquals(new OffsetAndEpoch(5L, 3), log.endOffsetForEpoch(3)); + } + + @Test + public void testUnflushedRecordsLostAfterReopen() { + appendBatch(5, 1); + appendBatch(10, 2); + log.flush(); + + appendBatch(5, 3); + appendBatch(10, 4); + log.reopen(); + + assertEquals(15L, log.endOffset().offset); + assertEquals(2, log.lastFetchedEpoch()); + } + + @Test + public void testCreateSnapshot() throws IOException { + int numberOfRecords = 10; + int epoch = 0; + OffsetAndEpoch snapshotId = new OffsetAndEpoch(numberOfRecords, epoch); + appendBatch(numberOfRecords, epoch); + log.updateHighWatermark(new LogOffsetMetadata(numberOfRecords)); + + try (RawSnapshotWriter snapshot = log.createNewSnapshot(snapshotId).get()) { + snapshot.freeze(); + } + + RawSnapshotReader snapshot = log.readSnapshot(snapshotId).get(); + assertEquals(0, snapshot.sizeInBytes()); + } + + @Test + public void testCreateSnapshotValidation() { + int numberOfRecords = 10; + int firstEpoch = 1; + int secondEpoch = 3; + + appendBatch(numberOfRecords, firstEpoch); + appendBatch(numberOfRecords, secondEpoch); + log.updateHighWatermark(new LogOffsetMetadata(2 * numberOfRecords)); + + // Test snapshot id for the first epoch + log.createNewSnapshot(new OffsetAndEpoch(numberOfRecords, firstEpoch)).get().close(); + log.createNewSnapshot(new OffsetAndEpoch(numberOfRecords - 1, firstEpoch)).get().close(); + log.createNewSnapshot(new OffsetAndEpoch(1, firstEpoch)).get().close(); + + // Test snapshot id for the second epoch + log.createNewSnapshot(new OffsetAndEpoch(2 * numberOfRecords, secondEpoch)).get().close(); + log.createNewSnapshot(new OffsetAndEpoch(2 * numberOfRecords - 1, secondEpoch)).get().close(); + log.createNewSnapshot(new 
OffsetAndEpoch(numberOfRecords + 1, secondEpoch)).get().close(); + } + + @Test + public void testCreateSnapshotLaterThanHighWatermark() { + int numberOfRecords = 10; + int epoch = 1; + + appendBatch(numberOfRecords, epoch); + log.updateHighWatermark(new LogOffsetMetadata(numberOfRecords)); + + assertThrows( + IllegalArgumentException.class, + () -> log.createNewSnapshot(new OffsetAndEpoch(numberOfRecords + 1, epoch)) + ); + } + + @Test + public void testCreateSnapshotMuchLaterEpoch() { + int numberOfRecords = 10; + int epoch = 1; + + appendBatch(numberOfRecords, epoch); + log.updateHighWatermark(new LogOffsetMetadata(numberOfRecords)); + + assertThrows( + IllegalArgumentException.class, + () -> log.createNewSnapshot(new OffsetAndEpoch(numberOfRecords, epoch + 1)) + ); + } + + @Test + public void testCreateSnapshotBeforeLogStartOffset() { + int numberOfRecords = 10; + int epoch = 1; + OffsetAndEpoch snapshotId = new OffsetAndEpoch(numberOfRecords, epoch); + + appendBatch(numberOfRecords, epoch); + log.updateHighWatermark(new LogOffsetMetadata(numberOfRecords)); + + try (RawSnapshotWriter snapshot = log.createNewSnapshot(snapshotId).get()) { + snapshot.freeze(); + } + + assertTrue(log.deleteBeforeSnapshot(snapshotId)); + assertEquals(snapshotId.offset, log.startOffset()); + + assertEquals(Optional.empty(), log.createNewSnapshot(new OffsetAndEpoch(numberOfRecords - 1, epoch))); + } + + @Test + public void testCreateSnapshotMuchEalierEpoch() { + int numberOfRecords = 10; + int epoch = 2; + OffsetAndEpoch snapshotId = new OffsetAndEpoch(numberOfRecords, epoch); + + appendBatch(numberOfRecords, epoch); + log.updateHighWatermark(new LogOffsetMetadata(numberOfRecords)); + + try (RawSnapshotWriter snapshot = log.createNewSnapshot(snapshotId).get()) { + snapshot.freeze(); + } + + assertTrue(log.deleteBeforeSnapshot(snapshotId)); + assertEquals(snapshotId.offset, log.startOffset()); + + assertThrows( + IllegalArgumentException.class, + () -> log.createNewSnapshot(new OffsetAndEpoch(numberOfRecords, epoch - 1)) + ); + } + + @Test + public void testCreateSnapshotWithMissingEpoch() { + int firstBatchRecords = 5; + int firstEpoch = 1; + int missingEpoch = firstEpoch + 1; + int secondBatchRecords = 5; + int secondEpoch = missingEpoch + 1; + + int numberOfRecords = firstBatchRecords + secondBatchRecords; + + appendBatch(firstBatchRecords, firstEpoch); + appendBatch(secondBatchRecords, secondEpoch); + log.updateHighWatermark(new LogOffsetMetadata(numberOfRecords)); + + assertThrows( + IllegalArgumentException.class, + () -> log.createNewSnapshot(new OffsetAndEpoch(1, missingEpoch)) + ); + assertThrows( + IllegalArgumentException.class, + () -> log.createNewSnapshot(new OffsetAndEpoch(firstBatchRecords, missingEpoch)) + ); + assertThrows( + IllegalArgumentException.class, + () -> log.createNewSnapshot(new OffsetAndEpoch(secondBatchRecords, missingEpoch)) + ); + } + + @Test + public void testCreateExistingSnapshot() { + int numberOfRecords = 10; + int epoch = 1; + OffsetAndEpoch snapshotId = new OffsetAndEpoch(numberOfRecords, epoch); + + appendBatch(numberOfRecords, epoch); + log.updateHighWatermark(new LogOffsetMetadata(numberOfRecords)); + + try (RawSnapshotWriter snapshot = log.createNewSnapshot(snapshotId).get()) { + snapshot.freeze(); + } + + assertTrue(log.deleteBeforeSnapshot(snapshotId)); + assertEquals(snapshotId.offset, log.startOffset()); + assertEquals(Optional.empty(), log.createNewSnapshot(snapshotId)); + } + + @Test + public void testReadMissingSnapshot() { + assertFalse(log.readSnapshot(new 
OffsetAndEpoch(10, 0)).isPresent()); + } + + @Test + public void testUpdateLogStartOffset() throws IOException { + int offset = 10; + int epoch = 0; + OffsetAndEpoch snapshotId = new OffsetAndEpoch(offset, epoch); + + appendBatch(offset, epoch); + log.updateHighWatermark(new LogOffsetMetadata(offset)); + + try (RawSnapshotWriter snapshot = log.createNewSnapshot(snapshotId).get()) { + snapshot.freeze(); + } + + assertTrue(log.deleteBeforeSnapshot(snapshotId)); + assertEquals(offset, log.startOffset()); + assertEquals(epoch, log.lastFetchedEpoch()); + assertEquals(offset, log.endOffset().offset); + + int newRecords = 10; + appendBatch(newRecords, epoch + 1); + log.updateHighWatermark(new LogOffsetMetadata(offset + newRecords)); + + // Start offset should not change since a new snapshot was not generated + assertFalse(log.deleteBeforeSnapshot(new OffsetAndEpoch(offset + newRecords, epoch))); + assertEquals(offset, log.startOffset()); + + assertEquals(epoch + 1, log.lastFetchedEpoch()); + assertEquals(offset + newRecords, log.endOffset().offset); + assertEquals(offset + newRecords, log.highWatermark().offset); + } + + @Test + public void testUpdateLogStartOffsetWithMissingSnapshot() { + int offset = 10; + int epoch = 0; + + appendBatch(offset, epoch); + log.updateHighWatermark(new LogOffsetMetadata(offset)); + + assertFalse(log.deleteBeforeSnapshot(new OffsetAndEpoch(1, epoch))); + assertEquals(0, log.startOffset()); + assertEquals(epoch, log.lastFetchedEpoch()); + assertEquals(offset, log.endOffset().offset); + assertEquals(offset, log.highWatermark().offset); + } + + @Test + public void testFailToIncreaseLogStartPastHighWatermark() throws IOException { + int offset = 10; + int epoch = 0; + OffsetAndEpoch snapshotId = new OffsetAndEpoch(2 * offset, epoch); + + appendBatch(3 * offset, epoch); + log.updateHighWatermark(new LogOffsetMetadata(offset)); + + try (RawSnapshotWriter snapshot = log.storeSnapshot(snapshotId).get()) { + snapshot.freeze(); + } + + assertThrows( + OffsetOutOfRangeException.class, + () -> log.deleteBeforeSnapshot(snapshotId) + ); + } + + @Test + public void testTruncateFullyToLatestSnapshot() throws IOException { + int numberOfRecords = 10; + int epoch = 0; + OffsetAndEpoch sameEpochSnapshotId = new OffsetAndEpoch(2 * numberOfRecords, epoch); + + appendBatch(numberOfRecords, epoch); + + try (RawSnapshotWriter snapshot = log.storeSnapshot(sameEpochSnapshotId).get()) { + snapshot.freeze(); + } + + assertTrue(log.truncateToLatestSnapshot()); + assertEquals(sameEpochSnapshotId.offset, log.startOffset()); + assertEquals(sameEpochSnapshotId.epoch, log.lastFetchedEpoch()); + assertEquals(sameEpochSnapshotId.offset, log.endOffset().offset); + assertEquals(sameEpochSnapshotId.offset, log.highWatermark().offset); + + OffsetAndEpoch greaterEpochSnapshotId = new OffsetAndEpoch(3 * numberOfRecords, epoch + 1); + + appendBatch(numberOfRecords, epoch); + + try (RawSnapshotWriter snapshot = log.storeSnapshot(greaterEpochSnapshotId).get()) { + snapshot.freeze(); + } + + assertTrue(log.truncateToLatestSnapshot()); + assertEquals(greaterEpochSnapshotId.offset, log.startOffset()); + assertEquals(greaterEpochSnapshotId.epoch, log.lastFetchedEpoch()); + assertEquals(greaterEpochSnapshotId.offset, log.endOffset().offset); + assertEquals(greaterEpochSnapshotId.offset, log.highWatermark().offset); + } + + @Test + public void testDoesntTruncateFully() throws IOException { + int numberOfRecords = 10; + int epoch = 1; + + appendBatch(numberOfRecords, epoch); + + OffsetAndEpoch olderEpochSnapshotId = 
new OffsetAndEpoch(numberOfRecords, epoch - 1); + try (RawSnapshotWriter snapshot = log.storeSnapshot(olderEpochSnapshotId).get()) { + snapshot.freeze(); + } + + assertFalse(log.truncateToLatestSnapshot()); + + appendBatch(numberOfRecords, epoch); + + OffsetAndEpoch olderOffsetSnapshotId = new OffsetAndEpoch(numberOfRecords, epoch); + try (RawSnapshotWriter snapshot = log.storeSnapshot(olderOffsetSnapshotId).get()) { + snapshot.freeze(); + } + + assertFalse(log.truncateToLatestSnapshot()); + } + + @Test + public void testTruncateWillRemoveOlderSnapshot() throws IOException { + int numberOfRecords = 10; + int epoch = 1; + + OffsetAndEpoch sameEpochSnapshotId = new OffsetAndEpoch(numberOfRecords, epoch); + appendBatch(numberOfRecords, epoch); + log.updateHighWatermark(new LogOffsetMetadata(sameEpochSnapshotId.offset)); + + try (RawSnapshotWriter snapshot = log.createNewSnapshot(sameEpochSnapshotId).get()) { + snapshot.freeze(); + } + + OffsetAndEpoch greaterEpochSnapshotId = new OffsetAndEpoch(2 * numberOfRecords, epoch + 1); + appendBatch(numberOfRecords, epoch); + + try (RawSnapshotWriter snapshot = log.storeSnapshot(greaterEpochSnapshotId).get()) { + snapshot.freeze(); + } + + assertTrue(log.truncateToLatestSnapshot()); + assertEquals(Optional.empty(), log.readSnapshot(sameEpochSnapshotId)); + } + + @Test + public void testUpdateLogStartOffsetWillRemoveOlderSnapshot() throws IOException { + int numberOfRecords = 10; + int epoch = 1; + + OffsetAndEpoch sameEpochSnapshotId = new OffsetAndEpoch(numberOfRecords, epoch); + appendBatch(numberOfRecords, epoch); + log.updateHighWatermark(new LogOffsetMetadata(sameEpochSnapshotId.offset)); + + try (RawSnapshotWriter snapshot = log.createNewSnapshot(sameEpochSnapshotId).get()) { + snapshot.freeze(); + } + + OffsetAndEpoch greaterEpochSnapshotId = new OffsetAndEpoch(2 * numberOfRecords, epoch + 1); + appendBatch(numberOfRecords, greaterEpochSnapshotId.epoch); + log.updateHighWatermark(new LogOffsetMetadata(greaterEpochSnapshotId.offset)); + + try (RawSnapshotWriter snapshot = log.createNewSnapshot(greaterEpochSnapshotId).get()) { + snapshot.freeze(); + } + + assertTrue(log.deleteBeforeSnapshot(greaterEpochSnapshotId)); + assertEquals(Optional.empty(), log.readSnapshot(sameEpochSnapshotId)); + } + + @Test + public void testValidateEpochGreaterThanLastKnownEpoch() { + int numberOfRecords = 1; + int epoch = 1; + + appendBatch(numberOfRecords, epoch); + + ValidOffsetAndEpoch resultOffsetAndEpoch = log.validateOffsetAndEpoch(numberOfRecords, epoch + 1); + assertEquals(ValidOffsetAndEpoch.diverging(new OffsetAndEpoch(log.endOffset().offset, epoch)), + resultOffsetAndEpoch); + } + + @Test + public void testValidateEpochLessThanOldestSnapshotEpoch() throws IOException { + int offset = 1; + int epoch = 1; + + OffsetAndEpoch olderEpochSnapshotId = new OffsetAndEpoch(offset, epoch); + try (RawSnapshotWriter snapshot = log.storeSnapshot(olderEpochSnapshotId).get()) { + snapshot.freeze(); + } + log.truncateToLatestSnapshot(); + + ValidOffsetAndEpoch resultOffsetAndEpoch = log.validateOffsetAndEpoch(offset, epoch - 1); + assertEquals(ValidOffsetAndEpoch.snapshot(olderEpochSnapshotId), resultOffsetAndEpoch); + } + + @Test + public void testValidateOffsetLessThanOldestSnapshotOffset() throws IOException { + int offset = 2; + int epoch = 1; + + OffsetAndEpoch olderEpochSnapshotId = new OffsetAndEpoch(offset, epoch); + try (RawSnapshotWriter snapshot = log.storeSnapshot(olderEpochSnapshotId).get()) { + snapshot.freeze(); + } + log.truncateToLatestSnapshot(); + + 
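+        // An offset strictly below the oldest snapshot can no longer be validated against the log,
+        // so the expected result is ValidOffsetAndEpoch.snapshot(olderEpochSnapshotId), i.e. a hint
+        // to fetch the latest snapshot instead of records.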
ValidOffsetAndEpoch resultOffsetAndEpoch = log.validateOffsetAndEpoch(offset - 1, epoch); + assertEquals(ValidOffsetAndEpoch.snapshot(olderEpochSnapshotId), resultOffsetAndEpoch); + } + + @Test + public void testValidateOffsetEqualToOldestSnapshotOffset() throws IOException { + int offset = 2; + int epoch = 1; + + OffsetAndEpoch olderEpochSnapshotId = new OffsetAndEpoch(offset, epoch); + try (RawSnapshotWriter snapshot = log.storeSnapshot(olderEpochSnapshotId).get()) { + snapshot.freeze(); + } + log.truncateToLatestSnapshot(); + + ValidOffsetAndEpoch resultOffsetAndEpoch = log.validateOffsetAndEpoch(offset, epoch); + assertEquals(ValidOffsetAndEpoch.Kind.VALID, resultOffsetAndEpoch.kind()); + } + + @Test + public void testValidateUnknownEpochLessThanLastKnownGreaterThanOldestSnapshot() throws IOException { + int numberOfRecords = 5; + int offset = 10; + + OffsetAndEpoch olderEpochSnapshotId = new OffsetAndEpoch(offset, 1); + try (RawSnapshotWriter snapshot = log.storeSnapshot(olderEpochSnapshotId).get()) { + snapshot.freeze(); + } + log.truncateToLatestSnapshot(); + + appendBatch(numberOfRecords, 1); + appendBatch(numberOfRecords, 2); + appendBatch(numberOfRecords, 4); + + // offset is not equal to oldest snapshot's offset + ValidOffsetAndEpoch resultOffsetAndEpoch = log.validateOffsetAndEpoch(100, 3); + assertEquals(ValidOffsetAndEpoch.diverging(new OffsetAndEpoch(20, 2)), resultOffsetAndEpoch); + } + + @Test + public void testValidateEpochLessThanFirstEpochInLog() throws IOException { + int numberOfRecords = 5; + int offset = 10; + + OffsetAndEpoch olderEpochSnapshotId = new OffsetAndEpoch(offset, 1); + try (RawSnapshotWriter snapshot = log.storeSnapshot(olderEpochSnapshotId).get()) { + snapshot.freeze(); + } + log.truncateToLatestSnapshot(); + + appendBatch(numberOfRecords, 3); + + // offset is not equal to oldest snapshot's offset + ValidOffsetAndEpoch resultOffsetAndEpoch = log.validateOffsetAndEpoch(100, 2); + assertEquals(ValidOffsetAndEpoch.diverging(olderEpochSnapshotId), resultOffsetAndEpoch); + } + + @Test + public void testValidateOffsetGreatThanEndOffset() { + int numberOfRecords = 1; + int epoch = 1; + + appendBatch(numberOfRecords, epoch); + + ValidOffsetAndEpoch resultOffsetAndEpoch = log.validateOffsetAndEpoch(numberOfRecords + 1, epoch); + assertEquals(ValidOffsetAndEpoch.diverging(new OffsetAndEpoch(log.endOffset().offset, epoch)), + resultOffsetAndEpoch); + } + + @Test + public void testValidateOffsetLessThanLEO() { + int numberOfRecords = 10; + int epoch = 1; + + appendBatch(numberOfRecords, epoch); + appendBatch(numberOfRecords, epoch + 1); + + ValidOffsetAndEpoch resultOffsetAndEpoch = log.validateOffsetAndEpoch(11, epoch); + assertEquals(ValidOffsetAndEpoch.diverging(new OffsetAndEpoch(10, epoch)), resultOffsetAndEpoch); + } + + @Test + public void testValidateValidEpochAndOffset() { + int numberOfRecords = 5; + int epoch = 1; + + appendBatch(numberOfRecords, epoch); + + ValidOffsetAndEpoch resultOffsetAndEpoch = log.validateOffsetAndEpoch(numberOfRecords - 1, epoch); + assertEquals(ValidOffsetAndEpoch.Kind.VALID, resultOffsetAndEpoch.kind()); + } + + private Optional readOffsets(long startOffset, Isolation isolation) { + // The current MockLog implementation reads at most one batch + + long firstReadOffset = -1L; + long lastReadOffset = -1L; + + long currentStart = startOffset; + boolean foundRecord = true; + while (foundRecord) { + foundRecord = false; + + Records records = log.read(currentStart, isolation).records; + for (Record record : records.records()) { + 
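+                    // Track the lowest and highest offsets seen so far; the enclosing while loop keeps
+                    // issuing reads until one of them returns no records.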
foundRecord = true; + + if (firstReadOffset < 0L) { + firstReadOffset = record.offset(); + } + + if (record.offset() > lastReadOffset) { + lastReadOffset = record.offset(); + } + } + + currentStart = lastReadOffset + 1; + } + + if (firstReadOffset < 0) { + return Optional.empty(); + } else { + return Optional.of(new OffsetRange(firstReadOffset, lastReadOffset)); + } + } + + private static class OffsetRange { + public final long startOffset; + public final long endOffset; + + private OffsetRange(long startOffset, long endOffset) { + this.startOffset = startOffset; + this.endOffset = endOffset; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + OffsetRange that = (OffsetRange) o; + return startOffset == that.startOffset && + endOffset == that.endOffset; + } + + @Override + public int hashCode() { + return Objects.hash(startOffset, endOffset); + } + + @Override + public String toString() { + return String.format("OffsetRange(startOffset=%s, endOffset=%s)", startOffset, endOffset); + } + } + + private void appendAsLeader(Collection records, int epoch) { + log.appendAsLeader( + MemoryRecords.withRecords( + log.endOffset().offset, + CompressionType.NONE, + records.toArray(new SimpleRecord[records.size()]) + ), + epoch + ); + } + + private void appendBatch(int numRecords, int epoch) { + List records = new ArrayList<>(numRecords); + for (int i = 0; i < numRecords; i++) { + records.add(new SimpleRecord(String.valueOf(i).getBytes())); + } + + appendAsLeader(records, epoch); + } + + private static void validateReadRecords(List expectedRecords, MockLog log) { + assertEquals(0L, log.startOffset()); + assertEquals(expectedRecords.size(), log.endOffset().offset); + + int currentOffset = 0; + while (currentOffset < log.endOffset().offset) { + Records records = log.read(currentOffset, Isolation.UNCOMMITTED).records; + List batches = Utils.toList(records.batches().iterator()); + + assertTrue(batches.size() > 0); + for (RecordBatch batch : batches) { + assertTrue(batch.countOrNull() > 0); + assertEquals(currentOffset, batch.baseOffset()); + assertEquals(currentOffset + batch.countOrNull() - 1, batch.lastOffset()); + + for (Record record : batch) { + assertEquals(currentOffset, record.offset()); + assertEquals(expectedRecords.get(currentOffset), new SimpleRecord(record)); + currentOffset += 1; + } + + assertEquals(currentOffset - 1, batch.lastOffset()); + } + } + } +} diff --git a/raft/src/test/java/org/apache/kafka/raft/MockMessageQueue.java b/raft/src/test/java/org/apache/kafka/raft/MockMessageQueue.java new file mode 100644 index 0000000..5fcd599 --- /dev/null +++ b/raft/src/test/java/org/apache/kafka/raft/MockMessageQueue.java @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.raft; + +import java.util.ArrayDeque; +import java.util.OptionalLong; +import java.util.Queue; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicLong; + +/** + * Mocked implementation which does not block in {@link #poll(long)}.. + */ +public class MockMessageQueue implements RaftMessageQueue { + private final Queue messages = new ArrayDeque<>(); + private final AtomicBoolean wakeupRequested = new AtomicBoolean(false); + private final AtomicLong lastPollTimeout = new AtomicLong(-1); + + @Override + public RaftMessage poll(long timeoutMs) { + wakeupRequested.set(false); + lastPollTimeout.set(timeoutMs); + return messages.poll(); + } + + @Override + public void add(RaftMessage message) { + messages.offer(message); + } + + public OptionalLong lastPollTimeoutMs() { + long lastTimeoutMs = lastPollTimeout.get(); + if (lastTimeoutMs < 0) { + return OptionalLong.empty(); + } else { + return OptionalLong.of(lastTimeoutMs); + } + } + + public boolean wakeupRequested() { + return wakeupRequested.get(); + } + + @Override + public boolean isEmpty() { + return messages.isEmpty(); + } + + @Override + public void wakeup() { + wakeupRequested.set(true); + } +} diff --git a/raft/src/test/java/org/apache/kafka/raft/MockNetworkChannel.java b/raft/src/test/java/org/apache/kafka/raft/MockNetworkChannel.java new file mode 100644 index 0000000..2a97931 --- /dev/null +++ b/raft/src/test/java/org/apache/kafka/raft/MockNetworkChannel.java @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.raft; + +import org.apache.kafka.common.protocol.ApiKeys; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.concurrent.atomic.AtomicInteger; + +public class MockNetworkChannel implements NetworkChannel { + private final AtomicInteger correlationIdCounter; + private final Set nodeCache; + private final List sendQueue = new ArrayList<>(); + private final Map awaitingResponse = new HashMap<>(); + + public MockNetworkChannel(AtomicInteger correlationIdCounter, Set destinationIds) { + this.correlationIdCounter = correlationIdCounter; + this.nodeCache = destinationIds; + } + + public MockNetworkChannel(Set destinationIds) { + this(new AtomicInteger(0), destinationIds); + } + + @Override + public int newCorrelationId() { + return correlationIdCounter.getAndIncrement(); + } + + @Override + public void send(RaftRequest.Outbound request) { + if (!nodeCache.contains(request.destinationId())) { + throw new IllegalArgumentException("Attempted to send to destination " + + request.destinationId() + ", but its address is not yet known"); + } + sendQueue.add(request); + } + + @Override + public void updateEndpoint(int id, RaftConfig.InetAddressSpec address) { + // empty + } + + public List drainSendQueue() { + return drainSentRequests(Optional.empty()); + } + + public List drainSentRequests(Optional apiKeyFilter) { + List requests = new ArrayList<>(); + Iterator iterator = sendQueue.iterator(); + while (iterator.hasNext()) { + RaftRequest.Outbound request = iterator.next(); + if (!apiKeyFilter.isPresent() || request.data().apiKey() == apiKeyFilter.get().id) { + awaitingResponse.put(request.correlationId, request); + requests.add(request); + iterator.remove(); + } + } + return requests; + } + + + public boolean hasSentRequests() { + return !sendQueue.isEmpty(); + } + + public void mockReceive(RaftResponse.Inbound response) { + RaftRequest.Outbound request = awaitingResponse.get(response.correlationId); + if (request == null) { + throw new IllegalStateException("Received response for a request which is not being awaited"); + } + request.completion.complete(response); + } + +} diff --git a/raft/src/test/java/org/apache/kafka/raft/MockQuorumStateStore.java b/raft/src/test/java/org/apache/kafka/raft/MockQuorumStateStore.java new file mode 100644 index 0000000..02e3bc8 --- /dev/null +++ b/raft/src/test/java/org/apache/kafka/raft/MockQuorumStateStore.java @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.raft; + +public class MockQuorumStateStore implements QuorumStateStore { + private ElectionState current; + + @Override + public ElectionState readElectionState() { + return current; + } + + @Override + public void writeElectionState(ElectionState update) { + this.current = update; + } + + @Override + public void clear() { + current = null; + } +} diff --git a/raft/src/test/java/org/apache/kafka/raft/MockableRandom.java b/raft/src/test/java/org/apache/kafka/raft/MockableRandom.java new file mode 100644 index 0000000..b487b16 --- /dev/null +++ b/raft/src/test/java/org/apache/kafka/raft/MockableRandom.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.raft; + +import java.util.OptionalInt; +import java.util.Random; +import java.util.function.IntFunction; + +/** + * A Random instance that makes it easy to modify the behavior of certain methods for test purposes. + */ +class MockableRandom extends Random { + + private IntFunction nextIntFunction = __ -> OptionalInt.empty(); + + public MockableRandom(long seed) { + super(seed); + } + + public void mockNextInt(int expectedBound, int returnValue) { + this.nextIntFunction = b -> { + if (b == expectedBound) + return OptionalInt.of(returnValue); + else + return OptionalInt.empty(); + }; + } + + public void mockNextInt(int returnValue) { + this.nextIntFunction = __ -> OptionalInt.of(returnValue); + } + + @Override + public int nextInt(int bound) { + return nextIntFunction.apply(bound).orElse(super.nextInt(bound)); + } +} diff --git a/raft/src/test/java/org/apache/kafka/raft/QuorumStateTest.java b/raft/src/test/java/org/apache/kafka/raft/QuorumStateTest.java new file mode 100644 index 0000000..4463a5d --- /dev/null +++ b/raft/src/test/java/org/apache/kafka/raft/QuorumStateTest.java @@ -0,0 +1,1073 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.raft; + +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.raft.internals.BatchAccumulator; +import org.junit.jupiter.api.Test; +import org.mockito.Mockito; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.Collections; +import java.util.Optional; +import java.util.OptionalInt; +import java.util.OptionalLong; +import java.util.Set; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class QuorumStateTest { + private final int localId = 0; + private final int logEndEpoch = 0; + private final MockQuorumStateStore store = new MockQuorumStateStore(); + private final MockTime time = new MockTime(); + private final int electionTimeoutMs = 5000; + private final int fetchTimeoutMs = 10000; + private final MockableRandom random = new MockableRandom(1L); + + private BatchAccumulator accumulator = Mockito.mock(BatchAccumulator.class); + + private QuorumState buildQuorumState(Set voters) { + return buildQuorumState(OptionalInt.of(localId), voters); + } + + private QuorumState buildQuorumState( + OptionalInt localId, + Set voters + ) { + return new QuorumState( + localId, + voters, + electionTimeoutMs, + fetchTimeoutMs, + store, + time, + new LogContext(), + random + ); + } + + @Test + public void testInitializePrimordialEpoch() throws IOException { + Set voters = Utils.mkSet(localId); + assertNull(store.readElectionState()); + + QuorumState state = initializeEmptyState(voters); + assertTrue(state.isUnattached()); + assertEquals(0, state.epoch()); + state.transitionToCandidate(); + CandidateState candidateState = state.candidateStateOrThrow(); + assertTrue(candidateState.isVoteGranted()); + assertEquals(1, candidateState.epoch()); + } + + @Test + public void testInitializeAsUnattached() throws IOException { + int node1 = 1; + int node2 = 2; + int epoch = 5; + Set voters = Utils.mkSet(localId, node1, node2); + store.writeElectionState(ElectionState.withUnknownLeader(epoch, voters)); + + int jitterMs = 2500; + random.mockNextInt(jitterMs); + + QuorumState state = buildQuorumState(voters); + state.initialize(new OffsetAndEpoch(0L, 0)); + + assertTrue(state.isUnattached()); + UnattachedState unattachedState = state.unattachedStateOrThrow(); + assertEquals(epoch, unattachedState.epoch()); + assertEquals(electionTimeoutMs + jitterMs, + unattachedState.remainingElectionTimeMs(time.milliseconds())); + } + + @Test + public void testInitializeAsFollower() throws IOException { + int node1 = 1; + int node2 = 2; + int epoch = 5; + Set voters = Utils.mkSet(localId, node1, node2); + store.writeElectionState(ElectionState.withElectedLeader(epoch, node1, voters)); + + QuorumState state = buildQuorumState(voters); + state.initialize(new OffsetAndEpoch(0L, logEndEpoch)); + assertTrue(state.isFollower()); + assertEquals(epoch, state.epoch()); + + FollowerState followerState = state.followerStateOrThrow(); + assertEquals(epoch, followerState.epoch()); + assertEquals(node1, followerState.leaderId()); + assertEquals(fetchTimeoutMs, followerState.remainingFetchTimeMs(time.milliseconds())); + } + + @Test + public void testInitializeAsVoted() throws IOException { + int node1 = 1; + int 
node2 = 2; + int epoch = 5; + Set voters = Utils.mkSet(localId, node1, node2); + store.writeElectionState(ElectionState.withVotedCandidate(epoch, node1, voters)); + + int jitterMs = 2500; + random.mockNextInt(jitterMs); + + QuorumState state = buildQuorumState(voters); + state.initialize(new OffsetAndEpoch(0L, logEndEpoch)); + assertTrue(state.isVoted()); + assertEquals(epoch, state.epoch()); + + VotedState votedState = state.votedStateOrThrow(); + assertEquals(epoch, votedState.epoch()); + assertEquals(node1, votedState.votedId()); + assertEquals(electionTimeoutMs + jitterMs, + votedState.remainingElectionTimeMs(time.milliseconds())); + } + + @Test + public void testInitializeAsResignedCandidate() throws IOException { + int node1 = 1; + int node2 = 2; + int epoch = 5; + Set voters = Utils.mkSet(localId, node1, node2); + ElectionState election = ElectionState.withVotedCandidate(epoch, localId, voters); + store.writeElectionState(election); + + int jitterMs = 2500; + random.mockNextInt(jitterMs); + + QuorumState state = buildQuorumState(voters); + state.initialize(new OffsetAndEpoch(0L, logEndEpoch)); + assertTrue(state.isCandidate()); + assertEquals(epoch, state.epoch()); + + CandidateState candidateState = state.candidateStateOrThrow(); + assertEquals(epoch, candidateState.epoch()); + assertEquals(election, candidateState.election()); + assertEquals(Utils.mkSet(node1, node2), candidateState.unrecordedVoters()); + assertEquals(Utils.mkSet(localId), candidateState.grantingVoters()); + assertEquals(Collections.emptySet(), candidateState.rejectingVoters()); + assertEquals(electionTimeoutMs + jitterMs, + candidateState.remainingElectionTimeMs(time.milliseconds())); + } + + @Test + public void testInitializeAsResignedLeader() throws IOException { + int node1 = 1; + int node2 = 2; + int epoch = 5; + Set voters = Utils.mkSet(localId, node1, node2); + ElectionState election = ElectionState.withElectedLeader(epoch, localId, voters); + store.writeElectionState(election); + + // If we were previously a leader, we will start as resigned in order to ensure + // a new leader gets elected. This ensures that records are always uniquely + // defined by epoch and offset even accounting for the loss of unflushed data. 
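+        // Concretely: the persisted election state above names the local node as the elected leader of
+        // epoch 5, so initialize() below must land in ResignedState for that epoch rather than LeaderState.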
+ + // The election timeout should be reset after we become a candidate again + int jitterMs = 2500; + random.mockNextInt(jitterMs); + + QuorumState state = buildQuorumState(voters); + state.initialize(new OffsetAndEpoch(0L, logEndEpoch)); + assertFalse(state.isLeader()); + assertEquals(epoch, state.epoch()); + + ResignedState resignedState = state.resignedStateOrThrow(); + assertEquals(epoch, resignedState.epoch()); + assertEquals(election, resignedState.election()); + assertEquals(Utils.mkSet(node1, node2), resignedState.unackedVoters()); + assertEquals(electionTimeoutMs + jitterMs, + resignedState.remainingElectionTimeMs(time.milliseconds())); + } + + @Test + public void testCandidateToCandidate() throws IOException { + int node1 = 1; + int node2 = 2; + Set voters = Utils.mkSet(localId, node1, node2); + assertNull(store.readElectionState()); + + QuorumState state = initializeEmptyState(voters); + state.transitionToCandidate(); + assertTrue(state.isCandidate()); + assertEquals(1, state.epoch()); + + CandidateState candidate1 = state.candidateStateOrThrow(); + candidate1.recordRejectedVote(node2); + + // Check backoff behavior before transitioning + int backoffMs = 500; + candidate1.startBackingOff(time.milliseconds(), backoffMs); + assertTrue(candidate1.isBackingOff()); + assertFalse(candidate1.isBackoffComplete(time.milliseconds())); + + time.sleep(backoffMs - 1); + assertTrue(candidate1.isBackingOff()); + assertFalse(candidate1.isBackoffComplete(time.milliseconds())); + + time.sleep(1); + assertTrue(candidate1.isBackingOff()); + assertTrue(candidate1.isBackoffComplete(time.milliseconds())); + + // The election timeout should be reset after we become a candidate again + int jitterMs = 2500; + random.mockNextInt(jitterMs); + + state.transitionToCandidate(); + assertTrue(state.isCandidate()); + CandidateState candidate2 = state.candidateStateOrThrow(); + assertEquals(2, state.epoch()); + assertEquals(Collections.singleton(localId), candidate2.grantingVoters()); + assertEquals(Collections.emptySet(), candidate2.rejectingVoters()); + assertEquals(electionTimeoutMs + jitterMs, + candidate2.remainingElectionTimeMs(time.milliseconds())); + } + + @Test + public void testCandidateToResigned() throws IOException { + int node1 = 1; + int node2 = 2; + Set voters = Utils.mkSet(localId, node1, node2); + assertNull(store.readElectionState()); + + QuorumState state = initializeEmptyState(voters); + state.transitionToCandidate(); + assertTrue(state.isCandidate()); + assertEquals(1, state.epoch()); + + assertThrows(IllegalStateException.class, () -> + state.transitionToResigned(Collections.singletonList(localId))); + assertTrue(state.isCandidate()); + } + + @Test + public void testCandidateToLeader() throws IOException { + Set voters = Utils.mkSet(localId); + assertNull(store.readElectionState()); + + QuorumState state = initializeEmptyState(voters); + state.transitionToCandidate(); + assertTrue(state.isCandidate()); + assertEquals(1, state.epoch()); + + state.transitionToLeader(0L, accumulator); + LeaderState leaderState = state.leaderStateOrThrow(); + assertTrue(state.isLeader()); + assertEquals(1, leaderState.epoch()); + assertEquals(Optional.empty(), leaderState.highWatermark()); + } + + @Test + public void testCandidateToLeaderWithoutGrantedVote() throws IOException { + int otherNodeId = 1; + Set voters = Utils.mkSet(localId, otherNodeId); + QuorumState state = initializeEmptyState(voters); + state.initialize(new OffsetAndEpoch(0L, logEndEpoch)); + state.transitionToCandidate(); + 
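+        // With two voters, the candidate's own vote is not a majority; it needs the other node's
+        // granted vote before transitionToLeader() is permitted.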
assertFalse(state.candidateStateOrThrow().isVoteGranted()); + assertThrows(IllegalStateException.class, () -> state.transitionToLeader(0L, accumulator)); + state.candidateStateOrThrow().recordGrantedVote(otherNodeId); + assertTrue(state.candidateStateOrThrow().isVoteGranted()); + state.transitionToLeader(0L, accumulator); + assertTrue(state.isLeader()); + } + + @Test + public void testCandidateToFollower() throws IOException { + int otherNodeId = 1; + Set voters = Utils.mkSet(localId, otherNodeId); + QuorumState state = initializeEmptyState(voters); + state.initialize(new OffsetAndEpoch(0L, logEndEpoch)); + state.transitionToCandidate(); + + state.transitionToFollower(5, otherNodeId); + assertEquals(5, state.epoch()); + assertEquals(OptionalInt.of(otherNodeId), state.leaderId()); + assertEquals(ElectionState.withElectedLeader(5, otherNodeId, voters), store.readElectionState()); + } + + @Test + public void testCandidateToUnattached() throws IOException { + int otherNodeId = 1; + Set voters = Utils.mkSet(localId, otherNodeId); + QuorumState state = initializeEmptyState(voters); + state.initialize(new OffsetAndEpoch(0L, logEndEpoch)); + state.transitionToCandidate(); + + state.transitionToUnattached(5); + assertEquals(5, state.epoch()); + assertEquals(OptionalInt.empty(), state.leaderId()); + assertEquals(ElectionState.withUnknownLeader(5, voters), store.readElectionState()); + } + + @Test + public void testCandidateToVoted() throws IOException { + int otherNodeId = 1; + Set voters = Utils.mkSet(localId, otherNodeId); + QuorumState state = initializeEmptyState(voters); + state.initialize(new OffsetAndEpoch(0L, logEndEpoch)); + state.transitionToCandidate(); + + state.transitionToVoted(5, otherNodeId); + assertEquals(5, state.epoch()); + assertEquals(OptionalInt.empty(), state.leaderId()); + + VotedState followerState = state.votedStateOrThrow(); + assertEquals(otherNodeId, followerState.votedId()); + assertEquals(ElectionState.withVotedCandidate(5, otherNodeId, voters), store.readElectionState()); + } + + @Test + public void testCandidateToAnyStateLowerEpoch() throws IOException { + int otherNodeId = 1; + Set voters = Utils.mkSet(localId, otherNodeId); + QuorumState state = initializeEmptyState(voters); + state.initialize(new OffsetAndEpoch(0L, logEndEpoch)); + state.transitionToUnattached(5); + state.transitionToCandidate(); + assertThrows(IllegalStateException.class, () -> state.transitionToUnattached(4)); + assertThrows(IllegalStateException.class, () -> state.transitionToVoted(4, otherNodeId)); + assertThrows(IllegalStateException.class, () -> state.transitionToFollower(4, otherNodeId)); + assertEquals(6, state.epoch()); + assertEquals(ElectionState.withVotedCandidate(6, localId, voters), store.readElectionState()); + } + + @Test + public void testLeaderToLeader() throws IOException { + Set voters = Utils.mkSet(localId); + assertNull(store.readElectionState()); + + QuorumState state = initializeEmptyState(voters); + state.initialize(new OffsetAndEpoch(0L, logEndEpoch)); + state.transitionToCandidate(); + state.transitionToLeader(0L, accumulator); + assertTrue(state.isLeader()); + assertEquals(1, state.epoch()); + + assertThrows(IllegalStateException.class, () -> state.transitionToLeader(0L, accumulator)); + assertTrue(state.isLeader()); + assertEquals(1, state.epoch()); + } + + @Test + public void testLeaderToResigned() throws IOException { + Set voters = Utils.mkSet(localId); + assertNull(store.readElectionState()); + + QuorumState state = initializeEmptyState(voters); + 
state.initialize(new OffsetAndEpoch(0L, logEndEpoch)); + state.transitionToCandidate(); + state.transitionToLeader(0L, accumulator); + assertTrue(state.isLeader()); + assertEquals(1, state.epoch()); + + state.transitionToResigned(Collections.singletonList(localId)); + assertTrue(state.isResigned()); + ResignedState resignedState = state.resignedStateOrThrow(); + assertEquals(ElectionState.withElectedLeader(1, localId, voters), + resignedState.election()); + assertEquals(1, resignedState.epoch()); + assertEquals(Collections.emptySet(), resignedState.unackedVoters()); + } + + @Test + public void testLeaderToCandidate() throws IOException { + Set voters = Utils.mkSet(localId); + assertNull(store.readElectionState()); + + QuorumState state = initializeEmptyState(voters); + state.initialize(new OffsetAndEpoch(0L, logEndEpoch)); + state.transitionToCandidate(); + state.transitionToLeader(0L, accumulator); + assertTrue(state.isLeader()); + assertEquals(1, state.epoch()); + + assertThrows(IllegalStateException.class, state::transitionToCandidate); + assertTrue(state.isLeader()); + assertEquals(1, state.epoch()); + } + + @Test + public void testLeaderToFollower() throws IOException { + int otherNodeId = 1; + Set voters = Utils.mkSet(localId, otherNodeId); + + QuorumState state = initializeEmptyState(voters); + + state.transitionToCandidate(); + state.candidateStateOrThrow().recordGrantedVote(otherNodeId); + state.transitionToLeader(0L, accumulator); + state.transitionToFollower(5, otherNodeId); + + assertEquals(5, state.epoch()); + assertEquals(OptionalInt.of(otherNodeId), state.leaderId()); + assertEquals(ElectionState.withElectedLeader(5, otherNodeId, voters), store.readElectionState()); + } + + @Test + public void testLeaderToUnattached() throws IOException { + int otherNodeId = 1; + Set voters = Utils.mkSet(localId, otherNodeId); + QuorumState state = initializeEmptyState(voters); + state.initialize(new OffsetAndEpoch(0L, logEndEpoch)); + state.transitionToCandidate(); + state.candidateStateOrThrow().recordGrantedVote(otherNodeId); + state.transitionToLeader(0L, accumulator); + state.transitionToUnattached(5); + assertEquals(5, state.epoch()); + assertEquals(OptionalInt.empty(), state.leaderId()); + assertEquals(ElectionState.withUnknownLeader(5, voters), store.readElectionState()); + } + + @Test + public void testLeaderToVoted() throws IOException { + int otherNodeId = 1; + Set voters = Utils.mkSet(localId, otherNodeId); + QuorumState state = initializeEmptyState(voters); + state.initialize(new OffsetAndEpoch(0L, logEndEpoch)); + state.transitionToCandidate(); + state.candidateStateOrThrow().recordGrantedVote(otherNodeId); + state.transitionToLeader(0L, accumulator); + state.transitionToVoted(5, otherNodeId); + + assertEquals(5, state.epoch()); + assertEquals(OptionalInt.empty(), state.leaderId()); + VotedState votedState = state.votedStateOrThrow(); + assertEquals(otherNodeId, votedState.votedId()); + assertEquals(ElectionState.withVotedCandidate(5, otherNodeId, voters), store.readElectionState()); + } + + @Test + public void testLeaderToAnyStateLowerEpoch() throws IOException { + int otherNodeId = 1; + Set voters = Utils.mkSet(localId, otherNodeId); + QuorumState state = initializeEmptyState(voters); + state.initialize(new OffsetAndEpoch(0L, logEndEpoch)); + state.transitionToUnattached(5); + state.transitionToCandidate(); + state.candidateStateOrThrow().recordGrantedVote(otherNodeId); + state.transitionToLeader(0L, accumulator); + assertThrows(IllegalStateException.class, () -> 
state.transitionToUnattached(4)); + assertThrows(IllegalStateException.class, () -> state.transitionToVoted(4, otherNodeId)); + assertThrows(IllegalStateException.class, () -> state.transitionToFollower(4, otherNodeId)); + assertEquals(6, state.epoch()); + assertEquals(ElectionState.withElectedLeader(6, localId, voters), store.readElectionState()); + } + + @Test + public void testCannotFollowOrVoteForSelf() throws IOException { + Set voters = Utils.mkSet(localId); + assertNull(store.readElectionState()); + QuorumState state = initializeEmptyState(voters); + + assertThrows(IllegalStateException.class, () -> state.transitionToFollower(0, localId)); + assertThrows(IllegalStateException.class, () -> state.transitionToVoted(0, localId)); + } + + @Test + public void testUnattachedToLeaderOrResigned() throws IOException { + int leaderId = 1; + int epoch = 5; + Set voters = Utils.mkSet(localId, leaderId); + store.writeElectionState(ElectionState.withVotedCandidate(epoch, leaderId, voters)); + QuorumState state = initializeEmptyState(voters); + state.initialize(new OffsetAndEpoch(0L, logEndEpoch)); + assertTrue(state.isUnattached()); + assertThrows(IllegalStateException.class, () -> state.transitionToLeader(0L, accumulator)); + assertThrows(IllegalStateException.class, () -> state.transitionToResigned(Collections.emptyList())); + } + + @Test + public void testUnattachedToVotedSameEpoch() throws IOException { + int otherNodeId = 1; + Set voters = Utils.mkSet(localId, otherNodeId); + QuorumState state = initializeEmptyState(voters); + state.initialize(new OffsetAndEpoch(0L, logEndEpoch)); + state.transitionToUnattached(5); + + int jitterMs = 2500; + random.mockNextInt(electionTimeoutMs, jitterMs); + state.transitionToVoted(5, otherNodeId); + + VotedState votedState = state.votedStateOrThrow(); + assertEquals(5, votedState.epoch()); + assertEquals(otherNodeId, votedState.votedId()); + assertEquals(ElectionState.withVotedCandidate(5, otherNodeId, voters), store.readElectionState()); + + // Verify election timeout is reset when we vote for a candidate + assertEquals(electionTimeoutMs + jitterMs, + votedState.remainingElectionTimeMs(time.milliseconds())); + } + + @Test + public void testUnattachedToVotedHigherEpoch() throws IOException { + int otherNodeId = 1; + Set voters = Utils.mkSet(localId, otherNodeId); + QuorumState state = initializeEmptyState(voters); + state.initialize(new OffsetAndEpoch(0L, logEndEpoch)); + state.transitionToUnattached(5); + state.transitionToVoted(8, otherNodeId); + + VotedState votedState = state.votedStateOrThrow(); + assertEquals(8, votedState.epoch()); + assertEquals(otherNodeId, votedState.votedId()); + assertEquals(ElectionState.withVotedCandidate(8, otherNodeId, voters), store.readElectionState()); + } + + @Test + public void testUnattachedToCandidate() throws IOException { + int otherNodeId = 1; + Set voters = Utils.mkSet(localId, otherNodeId); + QuorumState state = initializeEmptyState(voters); + state.initialize(new OffsetAndEpoch(0L, logEndEpoch)); + state.transitionToUnattached(5); + + int jitterMs = 2500; + random.mockNextInt(electionTimeoutMs, jitterMs); + state.transitionToCandidate(); + + assertTrue(state.isCandidate()); + CandidateState candidateState = state.candidateStateOrThrow(); + assertEquals(6, candidateState.epoch()); + assertEquals(electionTimeoutMs + jitterMs, + candidateState.remainingElectionTimeMs(time.milliseconds())); + } + + @Test + public void testUnattachedToUnattached() throws IOException { + int otherNodeId = 1; + Set voters = 
Utils.mkSet(localId, otherNodeId); + QuorumState state = initializeEmptyState(voters); + state.initialize(new OffsetAndEpoch(0L, logEndEpoch)); + state.transitionToUnattached(5); + + long remainingElectionTimeMs = state.unattachedStateOrThrow().remainingElectionTimeMs(time.milliseconds()); + time.sleep(1000); + + state.transitionToUnattached(6); + UnattachedState unattachedState = state.unattachedStateOrThrow(); + assertEquals(6, unattachedState.epoch()); + + // Verify that the election timer does not get reset + assertEquals(remainingElectionTimeMs - 1000, + unattachedState.remainingElectionTimeMs(time.milliseconds())); + } + + @Test + public void testUnattachedToFollowerSameEpoch() throws IOException { + int otherNodeId = 1; + Set voters = Utils.mkSet(localId, otherNodeId); + QuorumState state = initializeEmptyState(voters); + state.initialize(new OffsetAndEpoch(0L, logEndEpoch)); + state.transitionToUnattached(5); + + state.transitionToFollower(5, otherNodeId); + assertTrue(state.isFollower()); + FollowerState followerState = state.followerStateOrThrow(); + assertEquals(5, followerState.epoch()); + assertEquals(otherNodeId, followerState.leaderId()); + assertEquals(fetchTimeoutMs, followerState.remainingFetchTimeMs(time.milliseconds())); + } + + @Test + public void testUnattachedToFollowerHigherEpoch() throws IOException { + int otherNodeId = 1; + Set voters = Utils.mkSet(localId, otherNodeId); + QuorumState state = initializeEmptyState(voters); + state.initialize(new OffsetAndEpoch(0L, logEndEpoch)); + state.transitionToUnattached(5); + + state.transitionToFollower(8, otherNodeId); + assertTrue(state.isFollower()); + FollowerState followerState = state.followerStateOrThrow(); + assertEquals(8, followerState.epoch()); + assertEquals(otherNodeId, followerState.leaderId()); + assertEquals(fetchTimeoutMs, followerState.remainingFetchTimeMs(time.milliseconds())); + } + + @Test + public void testUnattachedToAnyStateLowerEpoch() throws IOException { + int otherNodeId = 1; + Set voters = Utils.mkSet(localId, otherNodeId); + QuorumState state = initializeEmptyState(voters); + state.initialize(new OffsetAndEpoch(0L, logEndEpoch)); + state.transitionToUnattached(5); + assertThrows(IllegalStateException.class, () -> state.transitionToUnattached(4)); + assertThrows(IllegalStateException.class, () -> state.transitionToVoted(4, otherNodeId)); + assertThrows(IllegalStateException.class, () -> state.transitionToFollower(4, otherNodeId)); + assertEquals(5, state.epoch()); + assertEquals(ElectionState.withUnknownLeader(5, voters), store.readElectionState()); + } + + @Test + public void testVotedToInvalidLeaderOrResigned() throws IOException { + int node1 = 1; + int node2 = 2; + Set voters = Utils.mkSet(localId, node1, node2); + QuorumState state = initializeEmptyState(voters); + state.initialize(new OffsetAndEpoch(0L, logEndEpoch)); + state.transitionToVoted(5, node1); + assertThrows(IllegalStateException.class, () -> state.transitionToLeader(0, accumulator)); + assertThrows(IllegalStateException.class, () -> state.transitionToResigned(Collections.emptyList())); + } + + @Test + public void testVotedToCandidate() throws IOException { + int node1 = 1; + int node2 = 2; + Set voters = Utils.mkSet(localId, node1, node2); + QuorumState state = initializeEmptyState(voters); + state.initialize(new OffsetAndEpoch(0L, logEndEpoch)); + state.transitionToVoted(5, node1); + + int jitterMs = 2500; + random.mockNextInt(electionTimeoutMs, jitterMs); + state.transitionToCandidate(); + assertTrue(state.isCandidate()); + 
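+        // Becoming a candidate bumps the epoch from 5 to 6 and restarts the election timer using the mocked jitter.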
CandidateState candidateState = state.candidateStateOrThrow(); + assertEquals(6, candidateState.epoch()); + assertEquals(electionTimeoutMs + jitterMs, + candidateState.remainingElectionTimeMs(time.milliseconds())); + } + + @Test + public void testVotedToVotedSameEpoch() throws IOException { + int node1 = 1; + int node2 = 2; + Set voters = Utils.mkSet(localId, node1, node2); + QuorumState state = initializeEmptyState(voters); + state.initialize(new OffsetAndEpoch(0L, logEndEpoch)); + state.transitionToUnattached(5); + state.transitionToVoted(8, node1); + assertThrows(IllegalStateException.class, () -> state.transitionToVoted(8, node1)); + assertThrows(IllegalStateException.class, () -> state.transitionToVoted(8, node2)); + } + + @Test + public void testVotedToFollowerSameEpoch() throws IOException { + int node1 = 1; + int node2 = 2; + Set voters = Utils.mkSet(localId, node1, node2); + QuorumState state = initializeEmptyState(voters); + state.initialize(new OffsetAndEpoch(0L, logEndEpoch)); + state.transitionToVoted(5, node1); + state.transitionToFollower(5, node2); + + FollowerState followerState = state.followerStateOrThrow(); + assertEquals(5, followerState.epoch()); + assertEquals(node2, followerState.leaderId()); + assertEquals(ElectionState.withElectedLeader(5, node2, voters), store.readElectionState()); + } + + @Test + public void testVotedToFollowerHigherEpoch() throws IOException { + int node1 = 1; + int node2 = 2; + Set voters = Utils.mkSet(localId, node1, node2); + QuorumState state = initializeEmptyState(voters); + state.initialize(new OffsetAndEpoch(0L, logEndEpoch)); + state.transitionToVoted(5, node1); + state.transitionToFollower(8, node2); + + FollowerState followerState = state.followerStateOrThrow(); + assertEquals(8, followerState.epoch()); + assertEquals(node2, followerState.leaderId()); + assertEquals(ElectionState.withElectedLeader(8, node2, voters), store.readElectionState()); + } + + @Test + public void testVotedToUnattachedSameEpoch() throws IOException { + int node1 = 1; + int node2 = 2; + Set voters = Utils.mkSet(localId, node1, node2); + QuorumState state = initializeEmptyState(voters); + state.initialize(new OffsetAndEpoch(0L, logEndEpoch)); + state.transitionToVoted(5, node1); + assertThrows(IllegalStateException.class, () -> state.transitionToUnattached(5)); + } + + @Test + public void testVotedToUnattachedHigherEpoch() throws IOException { + int otherNodeId = 1; + Set voters = Utils.mkSet(localId, otherNodeId); + QuorumState state = initializeEmptyState(voters); + state.initialize(new OffsetAndEpoch(0L, logEndEpoch)); + state.transitionToVoted(5, otherNodeId); + + long remainingElectionTimeMs = state.votedStateOrThrow().remainingElectionTimeMs(time.milliseconds()); + time.sleep(1000); + + state.transitionToUnattached(6); + UnattachedState unattachedState = state.unattachedStateOrThrow(); + assertEquals(6, unattachedState.epoch()); + + // Verify that the election timer does not get reset + assertEquals(remainingElectionTimeMs - 1000, + unattachedState.remainingElectionTimeMs(time.milliseconds())); + } + + @Test + public void testVotedToAnyStateLowerEpoch() throws IOException { + int otherNodeId = 1; + Set voters = Utils.mkSet(localId, otherNodeId); + QuorumState state = initializeEmptyState(voters); + state.initialize(new OffsetAndEpoch(0L, logEndEpoch)); + state.transitionToVoted(5, otherNodeId); + assertThrows(IllegalStateException.class, () -> state.transitionToUnattached(4)); + assertThrows(IllegalStateException.class, () -> state.transitionToVoted(4, 
otherNodeId)); + assertThrows(IllegalStateException.class, () -> state.transitionToFollower(4, otherNodeId)); + assertEquals(5, state.epoch()); + assertEquals(ElectionState.withVotedCandidate(5, otherNodeId, voters), store.readElectionState()); + } + + @Test + public void testFollowerToFollowerSameEpoch() throws IOException { + int node1 = 1; + int node2 = 2; + Set voters = Utils.mkSet(localId, node1, node2); + QuorumState state = initializeEmptyState(voters); + state.initialize(new OffsetAndEpoch(0L, logEndEpoch)); + state.transitionToFollower(8, node2); + assertThrows(IllegalStateException.class, () -> state.transitionToFollower(8, node1)); + assertThrows(IllegalStateException.class, () -> state.transitionToFollower(8, node2)); + + FollowerState followerState = state.followerStateOrThrow(); + assertEquals(8, followerState.epoch()); + assertEquals(node2, followerState.leaderId()); + assertEquals(ElectionState.withElectedLeader(8, node2, voters), store.readElectionState()); + } + + @Test + public void testFollowerToFollowerHigherEpoch() throws IOException { + int node1 = 1; + int node2 = 2; + Set voters = Utils.mkSet(localId, node1, node2); + QuorumState state = initializeEmptyState(voters); + state.initialize(new OffsetAndEpoch(0L, logEndEpoch)); + state.transitionToFollower(8, node2); + state.transitionToFollower(9, node1); + + FollowerState followerState = state.followerStateOrThrow(); + assertEquals(9, followerState.epoch()); + assertEquals(node1, followerState.leaderId()); + assertEquals(ElectionState.withElectedLeader(9, node1, voters), store.readElectionState()); + } + + @Test + public void testFollowerToLeaderOrResigned() throws IOException { + int node1 = 1; + int node2 = 2; + Set voters = Utils.mkSet(localId, node1, node2); + QuorumState state = initializeEmptyState(voters); + state.initialize(new OffsetAndEpoch(0L, logEndEpoch)); + state.transitionToFollower(8, node2); + assertThrows(IllegalStateException.class, () -> state.transitionToLeader(0, accumulator)); + assertThrows(IllegalStateException.class, () -> state.transitionToResigned(Collections.emptyList())); + } + + @Test + public void testFollowerToCandidate() throws IOException { + int node1 = 1; + int node2 = 2; + Set voters = Utils.mkSet(localId, node1, node2); + QuorumState state = initializeEmptyState(voters); + state.initialize(new OffsetAndEpoch(0L, logEndEpoch)); + state.transitionToFollower(8, node2); + + int jitterMs = 2500; + random.mockNextInt(electionTimeoutMs, jitterMs); + state.transitionToCandidate(); + assertTrue(state.isCandidate()); + CandidateState candidateState = state.candidateStateOrThrow(); + assertEquals(9, candidateState.epoch()); + assertEquals(electionTimeoutMs + jitterMs, + candidateState.remainingElectionTimeMs(time.milliseconds())); + } + + @Test + public void testFollowerToUnattachedSameEpoch() throws IOException { + int node1 = 1; + int node2 = 2; + Set voters = Utils.mkSet(localId, node1, node2); + QuorumState state = initializeEmptyState(voters); + state.initialize(new OffsetAndEpoch(0L, logEndEpoch)); + state.transitionToFollower(8, node2); + assertThrows(IllegalStateException.class, () -> state.transitionToUnattached(8)); + } + + @Test + public void testFollowerToUnattachedHigherEpoch() throws IOException { + int node1 = 1; + int node2 = 2; + Set voters = Utils.mkSet(localId, node1, node2); + QuorumState state = initializeEmptyState(voters); + state.initialize(new OffsetAndEpoch(0L, logEndEpoch)); + state.transitionToFollower(8, node2); + + int jitterMs = 2500; + 
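+        // Make the jitter deterministic so the remaining election time after the transition can be asserted exactly.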
random.mockNextInt(electionTimeoutMs, jitterMs); + state.transitionToUnattached(9); + assertTrue(state.isUnattached()); + UnattachedState unattachedState = state.unattachedStateOrThrow(); + assertEquals(9, unattachedState.epoch()); + assertEquals(electionTimeoutMs + jitterMs, + unattachedState.remainingElectionTimeMs(time.milliseconds())); + } + + @Test + public void testFollowerToVotedSameEpoch() throws IOException { + int node1 = 1; + int node2 = 2; + Set voters = Utils.mkSet(localId, node1, node2); + QuorumState state = initializeEmptyState(voters); + state.initialize(new OffsetAndEpoch(0L, logEndEpoch)); + state.transitionToFollower(8, node2); + + assertThrows(IllegalStateException.class, () -> state.transitionToVoted(8, node1)); + assertThrows(IllegalStateException.class, () -> state.transitionToVoted(8, localId)); + assertThrows(IllegalStateException.class, () -> state.transitionToVoted(8, node2)); + } + + @Test + public void testFollowerToVotedHigherEpoch() throws IOException { + int node1 = 1; + int node2 = 2; + Set voters = Utils.mkSet(localId, node1, node2); + QuorumState state = initializeEmptyState(voters); + state.initialize(new OffsetAndEpoch(0L, logEndEpoch)); + state.transitionToFollower(8, node2); + + int jitterMs = 2500; + random.mockNextInt(electionTimeoutMs, jitterMs); + state.transitionToVoted(9, node1); + assertTrue(state.isVoted()); + VotedState votedState = state.votedStateOrThrow(); + assertEquals(9, votedState.epoch()); + assertEquals(node1, votedState.votedId()); + assertEquals(electionTimeoutMs + jitterMs, + votedState.remainingElectionTimeMs(time.milliseconds())); + } + + @Test + public void testFollowerToAnyStateLowerEpoch() throws IOException { + int otherNodeId = 1; + Set voters = Utils.mkSet(localId, otherNodeId); + QuorumState state = initializeEmptyState(voters); + state.initialize(new OffsetAndEpoch(0L, logEndEpoch)); + state.transitionToFollower(5, otherNodeId); + assertThrows(IllegalStateException.class, () -> state.transitionToUnattached(4)); + assertThrows(IllegalStateException.class, () -> state.transitionToVoted(4, otherNodeId)); + assertThrows(IllegalStateException.class, () -> state.transitionToFollower(4, otherNodeId)); + assertEquals(5, state.epoch()); + assertEquals(ElectionState.withElectedLeader(5, otherNodeId, voters), store.readElectionState()); + } + + @Test + public void testCannotBecomeFollowerOfNonVoter() throws IOException { + int otherNodeId = 1; + int nonVoterId = 2; + Set voters = Utils.mkSet(localId, otherNodeId); + QuorumState state = initializeEmptyState(voters); + state.initialize(new OffsetAndEpoch(0L, logEndEpoch)); + assertThrows(IllegalStateException.class, () -> state.transitionToVoted(4, nonVoterId)); + assertThrows(IllegalStateException.class, () -> state.transitionToFollower(4, nonVoterId)); + } + + @Test + public void testObserverCannotBecomeCandidateOrLeaderOrVoted() throws IOException { + int otherNodeId = 1; + Set voters = Utils.mkSet(otherNodeId); + QuorumState state = initializeEmptyState(voters); + state.initialize(new OffsetAndEpoch(0L, logEndEpoch)); + assertTrue(state.isObserver()); + assertThrows(IllegalStateException.class, state::transitionToCandidate); + assertThrows(IllegalStateException.class, () -> state.transitionToLeader(0L, accumulator)); + assertThrows(IllegalStateException.class, () -> state.transitionToVoted(5, otherNodeId)); + } + + @Test + public void testObserverFollowerToUnattached() throws IOException { + int node1 = 1; + int node2 = 2; + Set voters = Utils.mkSet(node1, node2); + QuorumState 
state = initializeEmptyState(voters); + state.initialize(new OffsetAndEpoch(0L, logEndEpoch)); + assertTrue(state.isObserver()); + + state.transitionToFollower(2, node1); + state.transitionToUnattached(3); + assertTrue(state.isUnattached()); + UnattachedState unattachedState = state.unattachedStateOrThrow(); + assertEquals(3, unattachedState.epoch()); + + // Observers can remain in the unattached state indefinitely until a leader is found + assertEquals(Long.MAX_VALUE, unattachedState.electionTimeoutMs()); + } + + @Test + public void testObserverUnattachedToFollower() throws IOException { + int node1 = 1; + int node2 = 2; + Set voters = Utils.mkSet(node1, node2); + QuorumState state = initializeEmptyState(voters); + state.initialize(new OffsetAndEpoch(0L, logEndEpoch)); + assertTrue(state.isObserver()); + + state.transitionToUnattached(2); + state.transitionToFollower(3, node1); + assertTrue(state.isFollower()); + FollowerState followerState = state.followerStateOrThrow(); + assertEquals(3, followerState.epoch()); + assertEquals(node1, followerState.leaderId()); + assertEquals(fetchTimeoutMs, followerState.remainingFetchTimeMs(time.milliseconds())); + } + + @Test + public void testInitializeWithCorruptedStore() { + QuorumStateStore stateStore = Mockito.mock(QuorumStateStore.class); + Mockito.doThrow(UncheckedIOException.class).when(stateStore).readElectionState(); + + QuorumState state = buildQuorumState(Utils.mkSet(localId)); + + int epoch = 2; + state.initialize(new OffsetAndEpoch(0L, epoch)); + assertEquals(epoch, state.epoch()); + assertTrue(state.isUnattached()); + assertFalse(state.hasLeader()); + } + + @Test + public void testInconsistentVotersBetweenConfigAndState() throws IOException { + int otherNodeId = 1; + Set voters = Utils.mkSet(localId, otherNodeId); + + QuorumState state = initializeEmptyState(voters); + + int unknownVoterId = 2; + Set stateVoters = Utils.mkSet(localId, otherNodeId, unknownVoterId); + + int epoch = 5; + store.writeElectionState(ElectionState.withElectedLeader(epoch, localId, stateVoters)); + assertThrows(IllegalStateException.class, + () -> state.initialize(new OffsetAndEpoch(0L, logEndEpoch))); + } + + @Test + public void testHasRemoteLeader() throws IOException { + int otherNodeId = 1; + Set voters = Utils.mkSet(localId, otherNodeId); + + QuorumState state = initializeEmptyState(voters); + assertFalse(state.hasRemoteLeader()); + + state.transitionToCandidate(); + assertFalse(state.hasRemoteLeader()); + + state.candidateStateOrThrow().recordGrantedVote(otherNodeId); + state.transitionToLeader(0L, accumulator); + assertFalse(state.hasRemoteLeader()); + + state.transitionToUnattached(state.epoch() + 1); + assertFalse(state.hasRemoteLeader()); + + state.transitionToVoted(state.epoch() + 1, otherNodeId); + assertFalse(state.hasRemoteLeader()); + + state.transitionToFollower(state.epoch() + 1, otherNodeId); + assertTrue(state.hasRemoteLeader()); + } + + @Test + public void testHighWatermarkRetained() throws IOException { + int otherNodeId = 1; + Set voters = Utils.mkSet(localId, otherNodeId); + + QuorumState state = initializeEmptyState(voters); + state.transitionToFollower(5, otherNodeId); + + FollowerState followerState = state.followerStateOrThrow(); + followerState.updateHighWatermark(OptionalLong.of(10L)); + + Optional highWatermark = Optional.of(new LogOffsetMetadata(10L)); + assertEquals(highWatermark, state.highWatermark()); + + state.transitionToUnattached(6); + assertEquals(highWatermark, state.highWatermark()); + + state.transitionToVoted(7, 
otherNodeId); + assertEquals(highWatermark, state.highWatermark()); + + state.transitionToCandidate(); + assertEquals(highWatermark, state.highWatermark()); + + CandidateState candidateState = state.candidateStateOrThrow(); + candidateState.recordGrantedVote(otherNodeId); + assertTrue(candidateState.isVoteGranted()); + + state.transitionToLeader(10L, accumulator); + assertEquals(Optional.empty(), state.highWatermark()); + } + + @Test + public void testInitializeWithEmptyLocalId() throws IOException { + QuorumState state = buildQuorumState(OptionalInt.empty(), Utils.mkSet(0, 1)); + state.initialize(new OffsetAndEpoch(0L, 0)); + + assertTrue(state.isObserver()); + assertFalse(state.isVoter()); + + assertThrows(IllegalStateException.class, state::transitionToCandidate); + assertThrows(IllegalStateException.class, () -> state.transitionToVoted(1, 1)); + assertThrows(IllegalStateException.class, () -> state.transitionToLeader(0L, accumulator)); + + state.transitionToFollower(1, 1); + assertTrue(state.isFollower()); + + state.transitionToUnattached(2); + assertTrue(state.isUnattached()); + } + + @Test + public void testObserverInitializationFailsIfElectionStateHasVotedCandidate() { + Set voters = Utils.mkSet(0, 1); + int epoch = 5; + int votedId = 1; + + store.writeElectionState(ElectionState.withVotedCandidate(epoch, votedId, voters)); + + QuorumState state1 = buildQuorumState(OptionalInt.of(2), voters); + assertThrows(IllegalStateException.class, () -> state1.initialize(new OffsetAndEpoch(0, 0))); + + QuorumState state2 = buildQuorumState(OptionalInt.empty(), voters); + assertThrows(IllegalStateException.class, () -> state2.initialize(new OffsetAndEpoch(0, 0))); + } + + private QuorumState initializeEmptyState(Set voters) throws IOException { + QuorumState state = buildQuorumState(voters); + store.writeElectionState(ElectionState.withUnknownLeader(0, voters)); + state.initialize(new OffsetAndEpoch(0L, logEndEpoch)); + return state; + } +} diff --git a/raft/src/test/java/org/apache/kafka/raft/RaftClientTestContext.java b/raft/src/test/java/org/apache/kafka/raft/RaftClientTestContext.java new file mode 100644 index 0000000..00c351c --- /dev/null +++ b/raft/src/test/java/org/apache/kafka/raft/RaftClientTestContext.java @@ -0,0 +1,1199 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.raft; + +import java.util.function.Consumer; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.Uuid; +import org.apache.kafka.common.memory.MemoryPool; +import org.apache.kafka.common.message.BeginQuorumEpochRequestData; +import org.apache.kafka.common.message.BeginQuorumEpochResponseData; +import org.apache.kafka.common.message.DescribeQuorumResponseData; +import org.apache.kafka.common.message.DescribeQuorumResponseData.ReplicaState; +import org.apache.kafka.common.message.EndQuorumEpochRequestData; +import org.apache.kafka.common.message.EndQuorumEpochResponseData; +import org.apache.kafka.common.message.FetchRequestData; +import org.apache.kafka.common.message.FetchResponseData; +import org.apache.kafka.common.message.FetchSnapshotResponseData; +import org.apache.kafka.common.message.LeaderChangeMessage; +import org.apache.kafka.common.message.LeaderChangeMessage.Voter; +import org.apache.kafka.common.message.VoteRequestData; +import org.apache.kafka.common.message.VoteResponseData; +import org.apache.kafka.common.metrics.Metrics; +import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ApiMessage; +import org.apache.kafka.common.protocol.Errors; +import org.apache.kafka.common.record.CompressionType; +import org.apache.kafka.common.record.ControlRecordType; +import org.apache.kafka.common.record.ControlRecordUtils; +import org.apache.kafka.common.record.MemoryRecords; +import org.apache.kafka.common.record.Record; +import org.apache.kafka.common.record.Records; +import org.apache.kafka.common.requests.BeginQuorumEpochRequest; +import org.apache.kafka.common.requests.BeginQuorumEpochResponse; +import org.apache.kafka.common.requests.DescribeQuorumResponse; +import org.apache.kafka.common.requests.EndQuorumEpochRequest; +import org.apache.kafka.common.requests.EndQuorumEpochResponse; +import org.apache.kafka.common.requests.FetchSnapshotResponse; +import org.apache.kafka.common.requests.VoteRequest; +import org.apache.kafka.common.requests.VoteResponse; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.raft.internals.BatchBuilder; +import org.apache.kafka.raft.internals.StringSerde; +import org.apache.kafka.server.common.serialization.RecordSerde; +import org.apache.kafka.snapshot.RawSnapshotWriter; +import org.apache.kafka.snapshot.SnapshotReader; +import org.apache.kafka.test.TestCondition; +import org.apache.kafka.test.TestUtils; + +import java.io.IOException; +import java.net.InetSocketAddress; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.OptionalInt; +import java.util.OptionalLong; +import java.util.Set; +import java.util.stream.Collectors; + +import static org.apache.kafka.raft.RaftUtil.hasValidTopicPartition; +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public final class RaftClientTestContext { + public final RecordSerde serde = Builder.SERDE; + final TopicPartition metadataPartition = Builder.METADATA_PARTITION; + final Uuid metadataTopicId = Uuid.METADATA_TOPIC_ID; + final int electionBackoffMaxMs = 
Builder.ELECTION_BACKOFF_MAX_MS; + final int fetchMaxWaitMs = Builder.FETCH_MAX_WAIT_MS; + final int fetchTimeoutMs = Builder.FETCH_TIMEOUT_MS; + final int retryBackoffMs = Builder.RETRY_BACKOFF_MS; + + private int electionTimeoutMs; + private int requestTimeoutMs; + private int appendLingerMs; + + private final QuorumStateStore quorumStateStore; + final Uuid clusterId; + private final OptionalInt localId; + public final KafkaRaftClient client; + final Metrics metrics; + public final MockLog log; + final MockNetworkChannel channel; + final MockMessageQueue messageQueue; + final MockTime time; + final MockListener listener; + final Set voters; + + private final List sentResponses = new ArrayList<>(); + + public static final class Builder { + static final int DEFAULT_ELECTION_TIMEOUT_MS = 10000; + + private static final RecordSerde SERDE = new StringSerde(); + private static final TopicPartition METADATA_PARTITION = new TopicPartition("metadata", 0); + private static final int ELECTION_BACKOFF_MAX_MS = 100; + private static final int FETCH_MAX_WAIT_MS = 0; + // fetch timeout is usually larger than election timeout + private static final int FETCH_TIMEOUT_MS = 50000; + private static final int DEFAULT_REQUEST_TIMEOUT_MS = 5000; + private static final int RETRY_BACKOFF_MS = 50; + private static final int DEFAULT_APPEND_LINGER_MS = 0; + + private final MockMessageQueue messageQueue = new MockMessageQueue(); + private final MockTime time = new MockTime(); + private final QuorumStateStore quorumStateStore = new MockQuorumStateStore(); + private final MockableRandom random = new MockableRandom(1L); + private final LogContext logContext = new LogContext(); + private final MockLog log = new MockLog(METADATA_PARTITION, Uuid.METADATA_TOPIC_ID, logContext); + private final Set voters; + private final OptionalInt localId; + + private Uuid clusterId = Uuid.randomUuid(); + private int requestTimeoutMs = DEFAULT_REQUEST_TIMEOUT_MS; + private int electionTimeoutMs = DEFAULT_ELECTION_TIMEOUT_MS; + private int appendLingerMs = DEFAULT_APPEND_LINGER_MS; + private MemoryPool memoryPool = MemoryPool.NONE; + + public Builder(int localId, Set voters) { + this(OptionalInt.of(localId), voters); + } + + public Builder(OptionalInt localId, Set voters) { + this.voters = voters; + this.localId = localId; + } + + Builder withElectedLeader(int epoch, int leaderId) throws IOException { + quorumStateStore.writeElectionState(ElectionState.withElectedLeader(epoch, leaderId, voters)); + return this; + } + + Builder withUnknownLeader(int epoch) throws IOException { + quorumStateStore.writeElectionState(ElectionState.withUnknownLeader(epoch, voters)); + return this; + } + + Builder withVotedCandidate(int epoch, int votedId) throws IOException { + quorumStateStore.writeElectionState(ElectionState.withVotedCandidate(epoch, votedId, voters)); + return this; + } + + Builder updateRandom(Consumer consumer) { + consumer.accept(random); + return this; + } + + Builder withMemoryPool(MemoryPool pool) { + this.memoryPool = pool; + return this; + } + + Builder withAppendLingerMs(int appendLingerMs) { + this.appendLingerMs = appendLingerMs; + return this; + } + + public Builder appendToLog(int epoch, List records) { + MemoryRecords batch = buildBatch( + time.milliseconds(), + log.endOffset().offset, + epoch, + records + ); + log.appendAsLeader(batch, epoch); + return this; + } + + Builder withEmptySnapshot(OffsetAndEpoch snapshotId) throws IOException { + try (RawSnapshotWriter snapshot = log.storeSnapshot(snapshotId).get()) { + 
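+            // Freezing the writer finalizes the (empty) snapshot before the test context is built.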
snapshot.freeze(); + } + return this; + } + + Builder deleteBeforeSnapshot(OffsetAndEpoch snapshotId) throws IOException { + if (snapshotId.offset > log.highWatermark().offset) { + log.updateHighWatermark(new LogOffsetMetadata(snapshotId.offset)); + } + log.deleteBeforeSnapshot(snapshotId); + + return this; + } + + Builder withElectionTimeoutMs(int electionTimeoutMs) { + this.electionTimeoutMs = electionTimeoutMs; + return this; + } + + Builder withRequestTimeoutMs(int requestTimeoutMs) { + this.requestTimeoutMs = requestTimeoutMs; + return this; + } + + Builder withClusterId(Uuid clusterId) { + this.clusterId = clusterId; + return this; + } + + public RaftClientTestContext build() throws IOException { + Metrics metrics = new Metrics(time); + MockNetworkChannel channel = new MockNetworkChannel(voters); + MockListener listener = new MockListener(localId); + Map voterAddressMap = voters.stream() + .collect(Collectors.toMap(id -> id, RaftClientTestContext::mockAddress)); + RaftConfig raftConfig = new RaftConfig(voterAddressMap, requestTimeoutMs, RETRY_BACKOFF_MS, electionTimeoutMs, + ELECTION_BACKOFF_MAX_MS, FETCH_TIMEOUT_MS, appendLingerMs); + + KafkaRaftClient client = new KafkaRaftClient<>( + SERDE, + channel, + messageQueue, + log, + quorumStateStore, + memoryPool, + time, + metrics, + new MockExpirationService(time), + FETCH_MAX_WAIT_MS, + clusterId.toString(), + localId, + logContext, + random, + raftConfig + ); + + client.register(listener); + client.initialize(); + + RaftClientTestContext context = new RaftClientTestContext( + clusterId, + localId, + client, + log, + channel, + messageQueue, + time, + quorumStateStore, + voters, + metrics, + listener + ); + + context.electionTimeoutMs = electionTimeoutMs; + context.requestTimeoutMs = requestTimeoutMs; + context.appendLingerMs = appendLingerMs; + + return context; + } + } + + private RaftClientTestContext( + Uuid clusterId, + OptionalInt localId, + KafkaRaftClient client, + MockLog log, + MockNetworkChannel channel, + MockMessageQueue messageQueue, + MockTime time, + QuorumStateStore quorumStateStore, + Set voters, + Metrics metrics, + MockListener listener + ) { + this.clusterId = clusterId; + this.localId = localId; + this.client = client; + this.log = log; + this.channel = channel; + this.messageQueue = messageQueue; + this.time = time; + this.quorumStateStore = quorumStateStore; + this.voters = voters; + this.metrics = metrics; + this.listener = listener; + } + + int electionTimeoutMs() { + return electionTimeoutMs; + } + + int requestTimeoutMs() { + return requestTimeoutMs; + } + + int appendLingerMs() { + return appendLingerMs; + } + + MemoryRecords buildBatch( + long baseOffset, + int epoch, + List records + ) { + return buildBatch(time.milliseconds(), baseOffset, epoch, records); + } + + static MemoryRecords buildBatch( + long timestamp, + long baseOffset, + int epoch, + List records + ) { + ByteBuffer buffer = ByteBuffer.allocate(512); + BatchBuilder builder = new BatchBuilder<>( + buffer, + Builder.SERDE, + CompressionType.NONE, + baseOffset, + timestamp, + false, + epoch, + 512 + ); + + for (String record : records) { + builder.appendRecord(record, null); + } + + return builder.build(); + } + + static RaftClientTestContext initializeAsLeader(int localId, Set voters, int epoch) throws Exception { + if (epoch <= 0) { + throw new IllegalArgumentException("Cannot become leader in epoch " + epoch); + } + + RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters) + .withUnknownLeader(epoch - 1) + 
.build(); + + context.assertUnknownLeader(epoch - 1); + context.becomeLeader(); + return context; + } + + public void becomeLeader() throws Exception { + int currentEpoch = currentEpoch(); + time.sleep(electionTimeoutMs * 2); + expectAndGrantVotes(currentEpoch + 1); + expectBeginEpoch(currentEpoch + 1); + } + + public OptionalInt currentLeader() { + return currentLeaderAndEpoch().leaderId(); + } + + public int currentEpoch() { + return currentLeaderAndEpoch().epoch(); + } + + LeaderAndEpoch currentLeaderAndEpoch() { + ElectionState election = quorumStateStore.readElectionState(); + return new LeaderAndEpoch(election.leaderIdOpt, election.epoch); + } + + void expectAndGrantVotes( + int epoch + ) throws Exception { + pollUntilRequest(); + + List voteRequests = collectVoteRequests(epoch, + log.lastFetchedEpoch(), log.endOffset().offset); + + for (RaftRequest.Outbound request : voteRequests) { + VoteResponseData voteResponse = voteResponse(true, Optional.empty(), epoch); + deliverResponse(request.correlationId, request.destinationId(), voteResponse); + } + + client.poll(); + assertElectedLeader(epoch, localIdOrThrow()); + } + + private int localIdOrThrow() { + return localId.orElseThrow(() -> new AssertionError("Required local id is not defined")); + } + + private void expectBeginEpoch( + int epoch + ) throws Exception { + pollUntilRequest(); + for (RaftRequest.Outbound request : collectBeginEpochRequests(epoch)) { + BeginQuorumEpochResponseData beginEpochResponse = beginEpochResponse(epoch, localIdOrThrow()); + deliverResponse(request.correlationId, request.destinationId(), beginEpochResponse); + } + client.poll(); + } + + public void pollUntil(TestCondition condition) throws InterruptedException { + TestUtils.waitForCondition(() -> { + client.poll(); + return condition.conditionMet(); + }, 5000, "Condition failed to be satisfied before timeout"); + } + + void pollUntilResponse() throws InterruptedException { + pollUntil(() -> !sentResponses.isEmpty()); + } + + void pollUntilRequest() throws InterruptedException { + pollUntil(channel::hasSentRequests); + } + + void assertVotedCandidate(int epoch, int leaderId) throws IOException { + assertEquals(ElectionState.withVotedCandidate(epoch, leaderId, voters), quorumStateStore.readElectionState()); + } + + public void assertElectedLeader(int epoch, int leaderId) throws IOException { + assertEquals(ElectionState.withElectedLeader(epoch, leaderId, voters), quorumStateStore.readElectionState()); + } + + void assertUnknownLeader(int epoch) throws IOException { + assertEquals(ElectionState.withUnknownLeader(epoch, voters), quorumStateStore.readElectionState()); + } + + void assertResignedLeader(int epoch, int leaderId) throws IOException { + assertTrue(client.quorum().isResigned()); + assertEquals(ElectionState.withElectedLeader(epoch, leaderId, voters), quorumStateStore.readElectionState()); + } + + int assertSentDescribeQuorumResponse( + int leaderId, + int leaderEpoch, + long highWatermark, + List voterStates, + List observerStates + ) { + List sentMessages = drainSentResponses(ApiKeys.DESCRIBE_QUORUM); + assertEquals(1, sentMessages.size()); + RaftResponse.Outbound raftMessage = sentMessages.get(0); + assertTrue( + raftMessage.data() instanceof DescribeQuorumResponseData, + "Unexpected request type " + raftMessage.data()); + DescribeQuorumResponseData response = (DescribeQuorumResponseData) raftMessage.data(); + + DescribeQuorumResponseData expectedResponse = DescribeQuorumResponse.singletonResponse( + metadataPartition, + leaderId, + leaderEpoch, + 
highWatermark, + voterStates, + observerStates); + + assertEquals(expectedResponse, response); + return raftMessage.correlationId(); + } + + int assertSentVoteRequest(int epoch, int lastEpoch, long lastEpochOffset, int numVoteReceivers) { + List voteRequests = collectVoteRequests(epoch, lastEpoch, lastEpochOffset); + assertEquals(numVoteReceivers, voteRequests.size()); + return voteRequests.iterator().next().correlationId(); + } + + void assertSentVoteResponse( + Errors error + ) { + List sentMessages = drainSentResponses(ApiKeys.VOTE); + assertEquals(1, sentMessages.size()); + RaftMessage raftMessage = sentMessages.get(0); + assertTrue(raftMessage.data() instanceof VoteResponseData); + VoteResponseData response = (VoteResponseData) raftMessage.data(); + + assertEquals(error, Errors.forCode(response.errorCode())); + } + + void assertSentVoteResponse( + Errors error, + int epoch, + OptionalInt leaderId, + boolean voteGranted + ) { + List sentMessages = drainSentResponses(ApiKeys.VOTE); + assertEquals(1, sentMessages.size()); + RaftMessage raftMessage = sentMessages.get(0); + assertTrue(raftMessage.data() instanceof VoteResponseData); + VoteResponseData response = (VoteResponseData) raftMessage.data(); + assertTrue(hasValidTopicPartition(response, metadataPartition)); + + VoteResponseData.PartitionData partitionResponse = response.topics().get(0).partitions().get(0); + + assertEquals(voteGranted, partitionResponse.voteGranted()); + assertEquals(error, Errors.forCode(partitionResponse.errorCode())); + assertEquals(epoch, partitionResponse.leaderEpoch()); + assertEquals(leaderId.orElse(-1), partitionResponse.leaderId()); + } + + List collectVoteRequests( + int epoch, + int lastEpoch, + long lastEpochOffset + ) { + List voteRequests = new ArrayList<>(); + for (RaftMessage raftMessage : channel.drainSendQueue()) { + if (raftMessage.data() instanceof VoteRequestData) { + VoteRequestData request = (VoteRequestData) raftMessage.data(); + VoteRequestData.PartitionData partitionRequest = unwrap(request); + + assertEquals(epoch, partitionRequest.candidateEpoch()); + assertEquals(localIdOrThrow(), partitionRequest.candidateId()); + assertEquals(lastEpoch, partitionRequest.lastOffsetEpoch()); + assertEquals(lastEpochOffset, partitionRequest.lastOffset()); + voteRequests.add((RaftRequest.Outbound) raftMessage); + } + } + return voteRequests; + } + + void deliverRequest(ApiMessage request) { + RaftRequest.Inbound inboundRequest = new RaftRequest.Inbound( + channel.newCorrelationId(), request, time.milliseconds()); + inboundRequest.completion.whenComplete((response, exception) -> { + if (exception != null) { + throw new RuntimeException(exception); + } else { + sentResponses.add(response); + } + }); + client.handle(inboundRequest); + } + + void deliverResponse(int correlationId, int sourceId, ApiMessage response) { + channel.mockReceive(new RaftResponse.Inbound(correlationId, response, sourceId)); + } + + int assertSentBeginQuorumEpochRequest(int epoch, int numBeginEpochRequests) { + List requests = collectBeginEpochRequests(epoch); + assertEquals(numBeginEpochRequests, requests.size()); + return requests.get(0).correlationId; + } + + private List drainSentResponses( + ApiKeys apiKey + ) { + List res = new ArrayList<>(); + Iterator iterator = sentResponses.iterator(); + while (iterator.hasNext()) { + RaftResponse.Outbound response = iterator.next(); + if (response.data.apiKey() == apiKey.id) { + res.add(response); + iterator.remove(); + } + } + return res; + } + + void assertSentBeginQuorumEpochResponse( + 
Errors responseError + ) { + List sentMessages = drainSentResponses(ApiKeys.BEGIN_QUORUM_EPOCH); + assertEquals(1, sentMessages.size()); + RaftMessage raftMessage = sentMessages.get(0); + assertTrue(raftMessage.data() instanceof BeginQuorumEpochResponseData); + BeginQuorumEpochResponseData response = (BeginQuorumEpochResponseData) raftMessage.data(); + assertEquals(responseError, Errors.forCode(response.errorCode())); + } + + void assertSentBeginQuorumEpochResponse( + Errors partitionError, + int epoch, + OptionalInt leaderId + ) { + List sentMessages = drainSentResponses(ApiKeys.BEGIN_QUORUM_EPOCH); + assertEquals(1, sentMessages.size()); + RaftMessage raftMessage = sentMessages.get(0); + assertTrue(raftMessage.data() instanceof BeginQuorumEpochResponseData); + BeginQuorumEpochResponseData response = (BeginQuorumEpochResponseData) raftMessage.data(); + assertEquals(Errors.NONE, Errors.forCode(response.errorCode())); + + BeginQuorumEpochResponseData.PartitionData partitionResponse = + response.topics().get(0).partitions().get(0); + + assertEquals(epoch, partitionResponse.leaderEpoch()); + assertEquals(leaderId.orElse(-1), partitionResponse.leaderId()); + assertEquals(partitionError, Errors.forCode(partitionResponse.errorCode())); + } + + int assertSentEndQuorumEpochRequest(int epoch, int destinationId) { + List endQuorumRequests = collectEndQuorumRequests( + epoch, Collections.singleton(destinationId), Optional.empty()); + assertEquals(1, endQuorumRequests.size()); + return endQuorumRequests.get(0).correlationId(); + } + + void assertSentEndQuorumEpochResponse( + Errors responseError + ) { + List sentMessages = drainSentResponses(ApiKeys.END_QUORUM_EPOCH); + assertEquals(1, sentMessages.size()); + RaftMessage raftMessage = sentMessages.get(0); + assertTrue(raftMessage.data() instanceof EndQuorumEpochResponseData); + EndQuorumEpochResponseData response = (EndQuorumEpochResponseData) raftMessage.data(); + assertEquals(responseError, Errors.forCode(response.errorCode())); + } + + void assertSentEndQuorumEpochResponse( + Errors partitionError, + int epoch, + OptionalInt leaderId + ) { + List sentMessages = drainSentResponses(ApiKeys.END_QUORUM_EPOCH); + assertEquals(1, sentMessages.size()); + RaftMessage raftMessage = sentMessages.get(0); + assertTrue(raftMessage.data() instanceof EndQuorumEpochResponseData); + EndQuorumEpochResponseData response = (EndQuorumEpochResponseData) raftMessage.data(); + assertEquals(Errors.NONE, Errors.forCode(response.errorCode())); + + EndQuorumEpochResponseData.PartitionData partitionResponse = + response.topics().get(0).partitions().get(0); + + assertEquals(epoch, partitionResponse.leaderEpoch()); + assertEquals(leaderId.orElse(-1), partitionResponse.leaderId()); + assertEquals(partitionError, Errors.forCode(partitionResponse.errorCode())); + } + + RaftRequest.Outbound assertSentFetchRequest() { + List sentRequests = channel.drainSentRequests(Optional.of(ApiKeys.FETCH)); + assertEquals(1, sentRequests.size()); + return sentRequests.get(0); + } + + int assertSentFetchRequest( + int epoch, + long fetchOffset, + int lastFetchedEpoch + ) { + List sentMessages = channel.drainSendQueue(); + assertEquals(1, sentMessages.size()); + + // TODO: Use more specific type + RaftMessage raftMessage = sentMessages.get(0); + assertFetchRequestData(raftMessage, epoch, fetchOffset, lastFetchedEpoch); + return raftMessage.correlationId(); + } + + FetchResponseData.PartitionData assertSentFetchPartitionResponse() { + List sentMessages = drainSentResponses(ApiKeys.FETCH); + 
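+        // Exactly one Fetch response should have been sent; the single partition it carries is unwrapped and returned below.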
assertEquals( + 1, sentMessages.size(), "Found unexpected sent messages " + sentMessages); + RaftResponse.Outbound raftMessage = sentMessages.get(0); + assertEquals(ApiKeys.FETCH.id, raftMessage.data.apiKey()); + FetchResponseData response = (FetchResponseData) raftMessage.data(); + assertEquals(Errors.NONE, Errors.forCode(response.errorCode())); + + assertEquals(1, response.responses().size()); + assertEquals(metadataPartition.topic(), response.responses().get(0).topic()); + assertEquals(1, response.responses().get(0).partitions().size()); + return response.responses().get(0).partitions().get(0); + } + + void assertSentFetchPartitionResponse(Errors topLevelError) { + List sentMessages = drainSentResponses(ApiKeys.FETCH); + assertEquals( + 1, sentMessages.size(), "Found unexpected sent messages " + sentMessages); + RaftResponse.Outbound raftMessage = sentMessages.get(0); + assertEquals(ApiKeys.FETCH.id, raftMessage.data.apiKey()); + FetchResponseData response = (FetchResponseData) raftMessage.data(); + assertEquals(topLevelError, Errors.forCode(response.errorCode())); + } + + + MemoryRecords assertSentFetchPartitionResponse( + Errors error, + int epoch, + OptionalInt leaderId + ) { + FetchResponseData.PartitionData partitionResponse = assertSentFetchPartitionResponse(); + assertEquals(error, Errors.forCode(partitionResponse.errorCode())); + assertEquals(epoch, partitionResponse.currentLeader().leaderEpoch()); + assertEquals(leaderId.orElse(-1), partitionResponse.currentLeader().leaderId()); + assertEquals(-1, partitionResponse.divergingEpoch().endOffset()); + assertEquals(-1, partitionResponse.divergingEpoch().epoch()); + assertEquals(-1, partitionResponse.snapshotId().endOffset()); + assertEquals(-1, partitionResponse.snapshotId().epoch()); + return (MemoryRecords) partitionResponse.records(); + } + + MemoryRecords assertSentFetchPartitionResponse( + long highWatermark, + int leaderEpoch + ) { + FetchResponseData.PartitionData partitionResponse = assertSentFetchPartitionResponse(); + assertEquals(Errors.NONE, Errors.forCode(partitionResponse.errorCode())); + assertEquals(leaderEpoch, partitionResponse.currentLeader().leaderEpoch()); + assertEquals(highWatermark, partitionResponse.highWatermark()); + assertEquals(-1, partitionResponse.divergingEpoch().endOffset()); + assertEquals(-1, partitionResponse.divergingEpoch().epoch()); + assertEquals(-1, partitionResponse.snapshotId().endOffset()); + assertEquals(-1, partitionResponse.snapshotId().epoch()); + return (MemoryRecords) partitionResponse.records(); + } + + RaftRequest.Outbound assertSentFetchSnapshotRequest() { + List sentRequests = channel.drainSentRequests(Optional.of(ApiKeys.FETCH_SNAPSHOT)); + assertEquals(1, sentRequests.size()); + + return sentRequests.get(0); + } + + void assertSentFetchSnapshotResponse(Errors responseError) { + List sentMessages = drainSentResponses(ApiKeys.FETCH_SNAPSHOT); + assertEquals(1, sentMessages.size()); + + RaftMessage message = sentMessages.get(0); + assertTrue(message.data() instanceof FetchSnapshotResponseData); + + FetchSnapshotResponseData response = (FetchSnapshotResponseData) message.data(); + assertEquals(responseError, Errors.forCode(response.errorCode())); + } + + Optional assertSentFetchSnapshotResponse(TopicPartition topicPartition) { + List sentMessages = drainSentResponses(ApiKeys.FETCH_SNAPSHOT); + assertEquals(1, sentMessages.size()); + + RaftMessage message = sentMessages.get(0); + assertTrue(message.data() instanceof FetchSnapshotResponseData); + + FetchSnapshotResponseData response 
= (FetchSnapshotResponseData) message.data(); + assertEquals(Errors.NONE, Errors.forCode(response.errorCode())); + + return FetchSnapshotResponse.forTopicPartition(response, topicPartition); + } + + List collectEndQuorumRequests( + int epoch, + Set destinationIdSet, + Optional> preferredSuccessorsOpt + ) { + List endQuorumRequests = new ArrayList<>(); + Set collectedDestinationIdSet = new HashSet<>(); + for (RaftMessage raftMessage : channel.drainSendQueue()) { + if (raftMessage.data() instanceof EndQuorumEpochRequestData) { + EndQuorumEpochRequestData request = (EndQuorumEpochRequestData) raftMessage.data(); + + EndQuorumEpochRequestData.PartitionData partitionRequest = + request.topics().get(0).partitions().get(0); + + assertEquals(epoch, partitionRequest.leaderEpoch()); + assertEquals(localIdOrThrow(), partitionRequest.leaderId()); + preferredSuccessorsOpt.ifPresent(preferredSuccessors -> { + assertEquals(preferredSuccessors, partitionRequest.preferredSuccessors()); + }); + + RaftRequest.Outbound outboundRequest = (RaftRequest.Outbound) raftMessage; + collectedDestinationIdSet.add(outboundRequest.destinationId()); + endQuorumRequests.add(outboundRequest); + } + } + assertEquals(destinationIdSet, collectedDestinationIdSet); + return endQuorumRequests; + } + + void discoverLeaderAsObserver( + int leaderId, + int epoch + ) throws Exception { + pollUntilRequest(); + RaftRequest.Outbound fetchRequest = assertSentFetchRequest(); + assertTrue(voters.contains(fetchRequest.destinationId())); + assertFetchRequestData(fetchRequest, 0, 0L, 0); + + deliverResponse(fetchRequest.correlationId, fetchRequest.destinationId(), + fetchResponse(epoch, leaderId, MemoryRecords.EMPTY, 0L, Errors.NONE)); + client.poll(); + assertElectedLeader(epoch, leaderId); + } + + private List collectBeginEpochRequests(int epoch) { + List requests = new ArrayList<>(); + for (RaftRequest.Outbound raftRequest : channel.drainSentRequests(Optional.of(ApiKeys.BEGIN_QUORUM_EPOCH))) { + assertTrue(raftRequest.data() instanceof BeginQuorumEpochRequestData); + BeginQuorumEpochRequestData request = (BeginQuorumEpochRequestData) raftRequest.data(); + + BeginQuorumEpochRequestData.PartitionData partitionRequest = + request.topics().get(0).partitions().get(0); + + assertEquals(epoch, partitionRequest.leaderEpoch()); + assertEquals(localIdOrThrow(), partitionRequest.leaderId()); + requests.add(raftRequest); + } + return requests; + } + + private static RaftConfig.AddressSpec mockAddress(int id) { + return new RaftConfig.InetAddressSpec(new InetSocketAddress("localhost", 9990 + id)); + } + + EndQuorumEpochResponseData endEpochResponse( + int epoch, + OptionalInt leaderId + ) { + return EndQuorumEpochResponse.singletonResponse( + Errors.NONE, + metadataPartition, + Errors.NONE, + epoch, + leaderId.orElse(-1) + ); + } + + EndQuorumEpochRequestData endEpochRequest( + int epoch, + int leaderId, + List preferredSuccessors + ) { + return EndQuorumEpochRequest.singletonRequest( + metadataPartition, + epoch, + leaderId, + preferredSuccessors + ); + } + + EndQuorumEpochRequestData endEpochRequest( + String clusterId, + int epoch, + int leaderId, + List preferredSuccessors + ) { + return EndQuorumEpochRequest.singletonRequest( + metadataPartition, + clusterId, + epoch, + leaderId, + preferredSuccessors + ); + } + + BeginQuorumEpochRequestData beginEpochRequest(String clusterId, int epoch, int leaderId) { + return BeginQuorumEpochRequest.singletonRequest( + metadataPartition, + clusterId, + epoch, + leaderId + ); + } + + BeginQuorumEpochRequestData 
beginEpochRequest(int epoch, int leaderId) { + return BeginQuorumEpochRequest.singletonRequest( + metadataPartition, + epoch, + leaderId + ); + } + + private BeginQuorumEpochResponseData beginEpochResponse(int epoch, int leaderId) { + return BeginQuorumEpochResponse.singletonResponse( + Errors.NONE, + metadataPartition, + Errors.NONE, + epoch, + leaderId + ); + } + + VoteRequestData voteRequest(int epoch, int candidateId, int lastEpoch, long lastEpochOffset) { + return VoteRequest.singletonRequest( + metadataPartition, + clusterId.toString(), + epoch, + candidateId, + lastEpoch, + lastEpochOffset + ); + } + + VoteRequestData voteRequest( + String clusterId, + int epoch, + int candidateId, + int lastEpoch, + long lastEpochOffset + ) { + return VoteRequest.singletonRequest( + metadataPartition, + clusterId, + epoch, + candidateId, + lastEpoch, + lastEpochOffset + ); + } + + VoteResponseData voteResponse(boolean voteGranted, Optional leaderId, int epoch) { + return VoteResponse.singletonResponse( + Errors.NONE, + metadataPartition, + Errors.NONE, + epoch, + leaderId.orElse(-1), + voteGranted + ); + } + + private VoteRequestData.PartitionData unwrap(VoteRequestData voteRequest) { + assertTrue(RaftUtil.hasValidTopicPartition(voteRequest, metadataPartition)); + return voteRequest.topics().get(0).partitions().get(0); + } + + static void assertMatchingRecords( + String[] expected, + Records actual + ) { + List recordList = Utils.toList(actual.records()); + assertEquals(expected.length, recordList.size()); + for (int i = 0; i < expected.length; i++) { + Record record = recordList.get(i); + assertEquals(expected[i], Utils.utf8(record.value()), + "Record at offset " + record.offset() + " does not match expected"); + } + } + + static void verifyLeaderChangeMessage( + int leaderId, + List voters, + List grantingVoters, + ByteBuffer recordKey, + ByteBuffer recordValue + ) { + assertEquals(ControlRecordType.LEADER_CHANGE, ControlRecordType.parse(recordKey)); + + LeaderChangeMessage leaderChangeMessage = ControlRecordUtils.deserializeLeaderChangeMessage(recordValue); + assertEquals(leaderId, leaderChangeMessage.leaderId()); + assertEquals(voters.stream().map(voterId -> new Voter().setVoterId(voterId)).collect(Collectors.toList()), + leaderChangeMessage.voters()); + assertEquals(grantingVoters.stream().map(voterId -> new Voter().setVoterId(voterId)).collect(Collectors.toSet()), + new HashSet<>(leaderChangeMessage.grantingVoters())); + } + + void assertFetchRequestData( + RaftMessage message, + int epoch, + long fetchOffset, + int lastFetchedEpoch + ) { + assertTrue( + message.data() instanceof FetchRequestData, "Unexpected request type " + message.data()); + FetchRequestData request = (FetchRequestData) message.data(); + assertEquals(KafkaRaftClient.MAX_FETCH_SIZE_BYTES, request.maxBytes()); + assertEquals(fetchMaxWaitMs, request.maxWaitMs()); + + assertEquals(1, request.topics().size()); + assertEquals(metadataPartition.topic(), request.topics().get(0).topic()); + assertEquals(1, request.topics().get(0).partitions().size()); + + FetchRequestData.FetchPartition fetchPartition = request.topics().get(0).partitions().get(0); + assertEquals(epoch, fetchPartition.currentLeaderEpoch()); + assertEquals(fetchOffset, fetchPartition.fetchOffset()); + assertEquals(lastFetchedEpoch, fetchPartition.lastFetchedEpoch()); + assertEquals(localId.orElse(-1), request.replicaId()); + } + + FetchRequestData fetchRequest( + int epoch, + int replicaId, + long fetchOffset, + int lastFetchedEpoch, + int maxWaitTimeMs + ) { + 
return fetchRequest( + epoch, + clusterId.toString(), + replicaId, + fetchOffset, + lastFetchedEpoch, + maxWaitTimeMs + ); + } + + FetchRequestData fetchRequest( + int epoch, + String clusterId, + int replicaId, + long fetchOffset, + int lastFetchedEpoch, + int maxWaitTimeMs + ) { + FetchRequestData request = RaftUtil.singletonFetchRequest(metadataPartition, metadataTopicId, fetchPartition -> { + fetchPartition + .setCurrentLeaderEpoch(epoch) + .setLastFetchedEpoch(lastFetchedEpoch) + .setFetchOffset(fetchOffset); + }); + return request + .setMaxWaitMs(maxWaitTimeMs) + .setClusterId(clusterId) + .setReplicaId(replicaId); + } + + FetchResponseData fetchResponse( + int epoch, + int leaderId, + Records records, + long highWatermark, + Errors error + ) { + return RaftUtil.singletonFetchResponse(metadataPartition, metadataTopicId, Errors.NONE, partitionData -> { + partitionData + .setRecords(records) + .setErrorCode(error.code()) + .setHighWatermark(highWatermark); + + partitionData.currentLeader() + .setLeaderEpoch(epoch) + .setLeaderId(leaderId); + }); + } + + FetchResponseData divergingFetchResponse( + int epoch, + int leaderId, + long divergingEpochEndOffset, + int divergingEpoch, + long highWatermark + ) { + return RaftUtil.singletonFetchResponse(metadataPartition, metadataTopicId, Errors.NONE, partitionData -> { + partitionData.setHighWatermark(highWatermark); + + partitionData.currentLeader() + .setLeaderEpoch(epoch) + .setLeaderId(leaderId); + + partitionData.divergingEpoch() + .setEpoch(divergingEpoch) + .setEndOffset(divergingEpochEndOffset); + }); + } + + public void advanceLocalLeaderHighWatermarkToLogEndOffset() throws InterruptedException { + assertEquals(localId, currentLeader()); + long localLogEndOffset = log.endOffset().offset; + Set followers = voters.stream().filter(voter -> voter != localId.getAsInt()).collect(Collectors.toSet()); + + // Send a request from every follower + for (int follower : followers) { + deliverRequest( + fetchRequest(currentEpoch(), follower, localLogEndOffset, currentEpoch(), 0) + ); + pollUntilResponse(); + assertSentFetchPartitionResponse(Errors.NONE, currentEpoch(), localId); + } + + pollUntil(() -> OptionalLong.of(localLogEndOffset).equals(client.highWatermark())); + } + + static class MockListener implements RaftClient.Listener { + private final List> commits = new ArrayList<>(); + private final List> savedBatches = new ArrayList<>(); + private final Map claimedEpochStartOffsets = new HashMap<>(); + private LeaderAndEpoch currentLeaderAndEpoch = new LeaderAndEpoch(OptionalInt.empty(), 0); + private final OptionalInt localId; + private Optional> snapshot = Optional.empty(); + private boolean readCommit = true; + + MockListener(OptionalInt localId) { + this.localId = localId; + } + + int numCommittedBatches() { + return commits.size(); + } + + Long claimedEpochStartOffset(int epoch) { + return claimedEpochStartOffsets.get(epoch); + } + + LeaderAndEpoch currentLeaderAndEpoch() { + return currentLeaderAndEpoch; + } + + Batch lastCommit() { + if (commits.isEmpty()) { + return null; + } else { + return commits.get(commits.size() - 1); + } + } + + OptionalLong lastCommitOffset() { + if (commits.isEmpty()) { + return OptionalLong.empty(); + } else { + return OptionalLong.of(commits.get(commits.size() - 1).lastOffset()); + } + } + + OptionalInt currentClaimedEpoch() { + if (localId.isPresent() && currentLeaderAndEpoch.isLeader(localId.getAsInt())) { + return OptionalInt.of(currentLeaderAndEpoch.epoch()); + } else { + return OptionalInt.empty(); + } + } + 
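+        // Look up the records of a committed batch by its base offset (below) or by its last offset; null if no such commit was observed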
+ List commitWithBaseOffset(long baseOffset) { + return commits.stream() + .filter(batch -> batch.baseOffset() == baseOffset) + .findFirst() + .map(batch -> batch.records()) + .orElse(null); + } + + List commitWithLastOffset(long lastOffset) { + return commits.stream() + .filter(batch -> batch.lastOffset() == lastOffset) + .findFirst() + .map(batch -> batch.records()) + .orElse(null); + } + + Optional> drainHandledSnapshot() { + Optional> temp = snapshot; + snapshot = Optional.empty(); + return temp; + } + + void updateReadCommit(boolean readCommit) { + this.readCommit = readCommit; + + if (readCommit) { + for (BatchReader batch : savedBatches) { + readBatch(batch); + } + + savedBatches.clear(); + } + } + + void readBatch(BatchReader reader) { + try { + while (reader.hasNext()) { + long nextOffset = lastCommitOffset().isPresent() ? + lastCommitOffset().getAsLong() + 1 : 0L; + Batch batch = reader.next(); + // We expect monotonic offsets, but not necessarily sequential + // offsets since control records will be filtered. + assertTrue(batch.baseOffset() >= nextOffset, + "Received non-monotonic commit " + batch + + ". We expected an offset at least as large as " + nextOffset); + commits.add(batch); + } + } finally { + reader.close(); + } + } + + @Override + public void handleLeaderChange(LeaderAndEpoch leaderAndEpoch) { + // We record the next expected offset as the claimed epoch's start + // offset. This is useful to verify that the `handleLeaderChange` callback + // was not received early. + this.currentLeaderAndEpoch = leaderAndEpoch; + + currentClaimedEpoch().ifPresent(claimedEpoch -> { + long claimedEpochStartOffset = lastCommitOffset().isPresent() ? + lastCommitOffset().getAsLong() + 1 : 0L; + this.claimedEpochStartOffsets.put(leaderAndEpoch.epoch(), claimedEpochStartOffset); + }); + } + + @Override + public void handleCommit(BatchReader reader) { + if (readCommit) { + readBatch(reader); + } else { + savedBatches.add(reader); + } + } + + @Override + public void handleSnapshot(SnapshotReader reader) { + snapshot.ifPresent(snapshot -> assertDoesNotThrow(snapshot::close)); + commits.clear(); + savedBatches.clear(); + snapshot = Optional.of(reader); + } + } +} diff --git a/raft/src/test/java/org/apache/kafka/raft/RaftEventSimulationTest.java b/raft/src/test/java/org/apache/kafka/raft/RaftEventSimulationTest.java new file mode 100644 index 0000000..120eca3 --- /dev/null +++ b/raft/src/test/java/org/apache/kafka/raft/RaftEventSimulationTest.java @@ -0,0 +1,1241 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.raft; + +import net.jqwik.api.AfterFailureMode; +import net.jqwik.api.ForAll; +import net.jqwik.api.Property; +import net.jqwik.api.Tag; +import net.jqwik.api.constraints.IntRange; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.Uuid; +import org.apache.kafka.common.memory.MemoryPool; +import org.apache.kafka.common.metrics.Metrics; +import org.apache.kafka.common.protocol.ObjectSerializationCache; +import org.apache.kafka.common.protocol.Readable; +import org.apache.kafka.common.protocol.Writable; +import org.apache.kafka.common.protocol.types.Type; +import org.apache.kafka.common.utils.BufferSupplier; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.raft.MockLog.LogBatch; +import org.apache.kafka.raft.MockLog.LogEntry; +import org.apache.kafka.raft.internals.BatchMemoryPool; +import org.apache.kafka.server.common.serialization.RecordSerde; +import org.apache.kafka.snapshot.SnapshotReader; + +import java.net.InetSocketAddress; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.OptionalInt; +import java.util.OptionalLong; +import java.util.PriorityQueue; +import java.util.Random; +import java.util.Set; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicLong; +import java.util.function.Consumer; +import java.util.function.Supplier; +import java.util.stream.Collectors; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; + +/** + * The simulation testing framework provides a way to verify quorum behavior under + * different conditions. It is similar to system testing in that the test involves + * independently executing nodes, but there are several important differences: + * + * 1. Simulation behavior is deterministic provided an initial random seed. This + * makes it easy to reproduce and debug test failures. + * 2. The simulation uses an in-memory message router instead of a real network. + * Not only is this much cheaper and faster, it provides an easy way to create + * flaky network conditions or even network partitions without losing the + * simulation determinism. + * 3. Similarly, persistent state is stored in memory. We can nevertheless simulate + * different kinds of failures, such as the loss of unflushed data after a hard + * node restart using {@link MockLog}. + * + * The framework uses a single event scheduler in order to provide deterministic + * executions. Each test is setup as a specific scenario with a variable number of + * voters and observers. Much like system tests, there is typically a warmup + * period, followed by some cluster event (such as a node failure), and then some + * logic to validate behavior after recovery. + * + * If any of the tests fail, the output will indicate the arguments that failed. + * The easiest way to reproduce the failure for debugging is to create a separate + * `@Test` case which invokes the `@Property` method with those arguments directly. 
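+ * As a sketch (the argument values and method name below are hypothetical), such a
+ * reproduction could look like:
+ *
+ * <pre>{@code
+ * void reproduceFailedSample() {          // annotate with @Test
+ *     // arguments copied from the failing jqwik sample output (hypothetical values)
+ *     canElectInitialLeader(-12345, 3, 1);
+ * }
+ * }</pre>
+ *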
+ * This ensures that logging output will only include output from a single + * simulation execution. + */ +@Tag("integration") +public class RaftEventSimulationTest { + private static final TopicPartition METADATA_PARTITION = new TopicPartition("__cluster_metadata", 0); + private static final int ELECTION_TIMEOUT_MS = 1000; + private static final int ELECTION_JITTER_MS = 100; + private static final int FETCH_TIMEOUT_MS = 3000; + private static final int RETRY_BACKOFF_MS = 50; + private static final int REQUEST_TIMEOUT_MS = 3000; + private static final int FETCH_MAX_WAIT_MS = 100; + private static final int LINGER_MS = 0; + + @Property(tries = 100, afterFailure = AfterFailureMode.SAMPLE_ONLY) + void canElectInitialLeader( + @ForAll int seed, + @ForAll @IntRange(min = 1, max = 5) int numVoters, + @ForAll @IntRange(min = 0, max = 5) int numObservers + ) { + Random random = new Random(seed); + Cluster cluster = new Cluster(numVoters, numObservers, random); + MessageRouter router = new MessageRouter(cluster); + EventScheduler scheduler = schedulerWithDefaultInvariants(cluster); + + cluster.startAll(); + schedulePolling(scheduler, cluster, 3, 5); + scheduler.schedule(router::deliverAll, 0, 2, 1); + scheduler.schedule(new SequentialAppendAction(cluster), 0, 2, 3); + scheduler.runUntil(cluster::hasConsistentLeader); + scheduler.runUntil(() -> cluster.allReachedHighWatermark(10)); + } + + @Property(tries = 100, afterFailure = AfterFailureMode.SAMPLE_ONLY) + void canElectNewLeaderAfterOldLeaderFailure( + @ForAll int seed, + @ForAll @IntRange(min = 3, max = 5) int numVoters, + @ForAll @IntRange(min = 0, max = 5) int numObservers, + @ForAll boolean isGracefulShutdown + ) { + Random random = new Random(seed); + Cluster cluster = new Cluster(numVoters, numObservers, random); + MessageRouter router = new MessageRouter(cluster); + EventScheduler scheduler = schedulerWithDefaultInvariants(cluster); + + // Seed the cluster with some data + cluster.startAll(); + schedulePolling(scheduler, cluster, 3, 5); + scheduler.schedule(router::deliverAll, 0, 2, 1); + scheduler.schedule(new SequentialAppendAction(cluster), 0, 2, 3); + scheduler.runUntil(cluster::hasConsistentLeader); + scheduler.runUntil(() -> cluster.anyReachedHighWatermark(10)); + + // Shutdown the leader and write some more data. We can verify the new leader has been elected + // by verifying that the high watermark can still advance. 
+ int leaderId = cluster.latestLeader().orElseThrow(() -> + new AssertionError("Failed to find current leader") + ); + + if (isGracefulShutdown) { + cluster.shutdown(leaderId); + } else { + cluster.kill(leaderId); + } + + scheduler.runUntil(() -> cluster.allReachedHighWatermark(20)); + long highWatermark = cluster.maxHighWatermarkReached(); + + // Restart the node and verify it catches up + cluster.start(leaderId); + scheduler.runUntil(() -> cluster.allReachedHighWatermark(highWatermark + 10)); + } + + @Property(tries = 100, afterFailure = AfterFailureMode.SAMPLE_ONLY) + void canRecoverAfterAllNodesKilled( + @ForAll int seed, + @ForAll @IntRange(min = 1, max = 5) int numVoters, + @ForAll @IntRange(min = 0, max = 5) int numObservers + ) { + Random random = new Random(seed); + Cluster cluster = new Cluster(numVoters, numObservers, random); + MessageRouter router = new MessageRouter(cluster); + EventScheduler scheduler = schedulerWithDefaultInvariants(cluster); + + // Seed the cluster with some data + cluster.startAll(); + schedulePolling(scheduler, cluster, 3, 5); + scheduler.schedule(router::deliverAll, 0, 2, 1); + scheduler.schedule(new SequentialAppendAction(cluster), 0, 2, 3); + scheduler.runUntil(cluster::hasConsistentLeader); + scheduler.runUntil(() -> cluster.anyReachedHighWatermark(10)); + long highWatermark = cluster.maxHighWatermarkReached(); + + // We kill all of the nodes. Then we bring back a majority and verify that + // they are able to elect a leader and continue making progress + cluster.killAll(); + + Iterator nodeIdsIterator = cluster.nodes().iterator(); + for (int i = 0; i < cluster.majoritySize(); i++) { + Integer nodeId = nodeIdsIterator.next(); + cluster.start(nodeId); + } + + scheduler.runUntil(() -> cluster.allReachedHighWatermark(highWatermark + 10)); + } + + @Property(tries = 100, afterFailure = AfterFailureMode.SAMPLE_ONLY) + void canElectNewLeaderAfterOldLeaderPartitionedAway( + @ForAll int seed, + @ForAll @IntRange(min = 3, max = 5) int numVoters, + @ForAll @IntRange(min = 0, max = 5) int numObservers + ) { + Random random = new Random(seed); + Cluster cluster = new Cluster(numVoters, numObservers, random); + MessageRouter router = new MessageRouter(cluster); + EventScheduler scheduler = schedulerWithDefaultInvariants(cluster); + + // Seed the cluster with some data + cluster.startAll(); + schedulePolling(scheduler, cluster, 3, 5); + scheduler.schedule(router::deliverAll, 0, 2, 2); + scheduler.schedule(new SequentialAppendAction(cluster), 0, 2, 3); + scheduler.runUntil(cluster::hasConsistentLeader); + scheduler.runUntil(() -> cluster.anyReachedHighWatermark(10)); + + // The leader gets partitioned off. 
We can verify the new leader has been elected + // by writing some data and ensuring that it gets replicated + int leaderId = cluster.latestLeader().orElseThrow(() -> + new AssertionError("Failed to find current leader") + ); + router.filter(leaderId, new DropAllTraffic()); + + Set nonPartitionedNodes = new HashSet<>(cluster.nodes()); + nonPartitionedNodes.remove(leaderId); + + scheduler.runUntil(() -> cluster.allReachedHighWatermark(20, nonPartitionedNodes)); + } + + @Property(tries = 100, afterFailure = AfterFailureMode.SAMPLE_ONLY) + void canMakeProgressIfMajorityIsReachable( + @ForAll int seed, + @ForAll @IntRange(min = 0, max = 3) int numObservers + ) { + int numVoters = 5; + Random random = new Random(seed); + Cluster cluster = new Cluster(numVoters, numObservers, random); + MessageRouter router = new MessageRouter(cluster); + EventScheduler scheduler = schedulerWithDefaultInvariants(cluster); + + // Seed the cluster with some data + cluster.startAll(); + schedulePolling(scheduler, cluster, 3, 5); + scheduler.schedule(router::deliverAll, 0, 2, 2); + scheduler.schedule(new SequentialAppendAction(cluster), 0, 2, 3); + scheduler.runUntil(cluster::hasConsistentLeader); + scheduler.runUntil(() -> cluster.anyReachedHighWatermark(10)); + + // Partition the nodes into two sets. Nodes are reachable within each set, + // but the two sets cannot communicate with each other. We should be able + // to make progress even if an election is needed in the larger set. + router.filter(0, new DropOutboundRequestsFrom(Utils.mkSet(2, 3, 4))); + router.filter(1, new DropOutboundRequestsFrom(Utils.mkSet(2, 3, 4))); + router.filter(2, new DropOutboundRequestsFrom(Utils.mkSet(0, 1))); + router.filter(3, new DropOutboundRequestsFrom(Utils.mkSet(0, 1))); + router.filter(4, new DropOutboundRequestsFrom(Utils.mkSet(0, 1))); + + long partitionLogEndOffset = cluster.maxLogEndOffset(); + scheduler.runUntil(() -> cluster.anyReachedHighWatermark(2 * partitionLogEndOffset)); + + long minorityHighWatermark = cluster.maxHighWatermarkReached(Utils.mkSet(0, 1)); + long majorityHighWatermark = cluster.maxHighWatermarkReached(Utils.mkSet(2, 3, 4)); + + assertTrue( + majorityHighWatermark > minorityHighWatermark, + String.format( + "majorityHighWatermark = %s, minorityHighWatermark = %s", + majorityHighWatermark, + minorityHighWatermark + ) + ); + + // Now restore the partition and verify everyone catches up + router.filter(0, new PermitAllTraffic()); + router.filter(1, new PermitAllTraffic()); + router.filter(2, new PermitAllTraffic()); + router.filter(3, new PermitAllTraffic()); + router.filter(4, new PermitAllTraffic()); + + long restoredLogEndOffset = cluster.maxLogEndOffset(); + scheduler.runUntil(() -> cluster.allReachedHighWatermark(2 * restoredLogEndOffset)); + } + + @Property(tries = 100, afterFailure = AfterFailureMode.SAMPLE_ONLY) + void canMakeProgressAfterBackToBackLeaderFailures( + @ForAll int seed, + @ForAll @IntRange(min = 3, max = 5) int numVoters, + @ForAll @IntRange(min = 0, max = 5) int numObservers + ) { + Random random = new Random(seed); + Cluster cluster = new Cluster(numVoters, numObservers, random); + MessageRouter router = new MessageRouter(cluster); + EventScheduler scheduler = schedulerWithDefaultInvariants(cluster); + + // Seed the cluster with some data + cluster.startAll(); + schedulePolling(scheduler, cluster, 3, 5); + scheduler.schedule(router::deliverAll, 0, 2, 5); + scheduler.schedule(new SequentialAppendAction(cluster), 0, 2, 3); + scheduler.runUntil(cluster::hasConsistentLeader); + 
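+        // Let the cluster commit a little data before partitioning the current leader away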
scheduler.runUntil(() -> cluster.anyReachedHighWatermark(10)); + + int leaderId = cluster.latestLeader().getAsInt(); + router.filter(leaderId, new DropAllTraffic()); + scheduler.runUntil(() -> cluster.latestLeader().isPresent() && cluster.latestLeader().getAsInt() != leaderId); + + // As soon as we have a new leader, restore traffic to the old leader and partition the new leader + int newLeaderId = cluster.latestLeader().getAsInt(); + router.filter(leaderId, new PermitAllTraffic()); + router.filter(newLeaderId, new DropAllTraffic()); + + // Verify now that we can make progress + long targetHighWatermark = cluster.maxHighWatermarkReached() + 10; + scheduler.runUntil(() -> cluster.anyReachedHighWatermark(targetHighWatermark)); + } + + @Property(tries = 100, afterFailure = AfterFailureMode.SAMPLE_ONLY) + void canRecoverFromSingleNodeCommittedDataLoss( + @ForAll int seed, + @ForAll @IntRange(min = 3, max = 5) int numVoters, + @ForAll @IntRange(min = 0, max = 2) int numObservers + ) { + // We run this test without the `MonotonicEpoch` and `MajorityReachedHighWatermark` + // invariants since the loss of committed data on one node can violate them. + + Random random = new Random(seed); + Cluster cluster = new Cluster(numVoters, numObservers, random); + EventScheduler scheduler = new EventScheduler(cluster.random, cluster.time); + scheduler.addInvariant(new MonotonicHighWatermark(cluster)); + scheduler.addInvariant(new SingleLeader(cluster)); + scheduler.addValidation(new ConsistentCommittedData(cluster)); + + MessageRouter router = new MessageRouter(cluster); + + cluster.startAll(); + schedulePolling(scheduler, cluster, 3, 5); + scheduler.schedule(router::deliverAll, 0, 2, 5); + scheduler.schedule(new SequentialAppendAction(cluster), 0, 2, 3); + scheduler.runUntil(() -> cluster.anyReachedHighWatermark(10)); + + RaftNode node = cluster.randomRunning().orElseThrow(() -> + new AssertionError("Failed to find running node") + ); + + // Kill a random node and drop all of its persistent state. The Raft + // protocol guarantees should still ensure we lose no committed data + // as long as a new leader is elected before the failed node is restarted. + cluster.killAndDeletePersistentState(node.nodeId); + scheduler.runUntil(() -> !cluster.hasLeader(node.nodeId) && cluster.hasConsistentLeader()); + + // Now restart the failed node and ensure that it recovers. 
+ long highWatermarkBeforeRestart = cluster.maxHighWatermarkReached(); + cluster.start(node.nodeId); + scheduler.runUntil(() -> cluster.allReachedHighWatermark(highWatermarkBeforeRestart + 10)); + } + + private EventScheduler schedulerWithDefaultInvariants(Cluster cluster) { + EventScheduler scheduler = new EventScheduler(cluster.random, cluster.time); + scheduler.addInvariant(new MonotonicHighWatermark(cluster)); + scheduler.addInvariant(new MonotonicEpoch(cluster)); + scheduler.addInvariant(new MajorityReachedHighWatermark(cluster)); + scheduler.addInvariant(new SingleLeader(cluster)); + scheduler.addInvariant(new SnapshotAtLogStart(cluster)); + scheduler.addInvariant(new LeaderNeverLoadSnapshot(cluster)); + scheduler.addValidation(new ConsistentCommittedData(cluster)); + return scheduler; + } + + private void schedulePolling(EventScheduler scheduler, + Cluster cluster, + int pollIntervalMs, + int pollJitterMs) { + int delayMs = 0; + for (int nodeId : cluster.nodes()) { + scheduler.schedule(() -> cluster.pollIfRunning(nodeId), delayMs, pollIntervalMs, pollJitterMs); + delayMs++; + } + } + + private static abstract class Event implements Comparable { + final int eventId; + final long deadlineMs; + final Runnable action; + + protected Event(Runnable action, int eventId, long deadlineMs) { + this.action = action; + this.eventId = eventId; + this.deadlineMs = deadlineMs; + } + + void execute(EventScheduler scheduler) { + action.run(); + } + + public int compareTo(Event other) { + int compare = Long.compare(deadlineMs, other.deadlineMs); + if (compare != 0) + return compare; + return Integer.compare(eventId, other.eventId); + } + } + + private static class PeriodicEvent extends Event { + final Random random; + final int periodMs; + final int jitterMs; + + protected PeriodicEvent(Runnable action, + int eventId, + Random random, + long deadlineMs, + int periodMs, + int jitterMs) { + super(action, eventId, deadlineMs); + this.random = random; + this.periodMs = periodMs; + this.jitterMs = jitterMs; + } + + @Override + void execute(EventScheduler scheduler) { + super.execute(scheduler); + int nextExecDelayMs = periodMs + (jitterMs == 0 ? 
0 : random.nextInt(jitterMs)); + scheduler.schedule(action, nextExecDelayMs, periodMs, jitterMs); + } + } + + private static class SequentialAppendAction implements Runnable { + final Cluster cluster; + + private SequentialAppendAction(Cluster cluster) { + this.cluster = cluster; + } + + @Override + public void run() { + cluster.withCurrentLeader(node -> { + if (!node.client.isShuttingDown() && node.counter.isWritable()) + node.counter.increment(); + }); + } + } + + private interface Invariant { + void verify(); + } + + private interface Validation { + void validate(); + } + + private static class EventScheduler { + private static final int MAX_ITERATIONS = 500000; + + final AtomicInteger eventIdGenerator = new AtomicInteger(0); + final PriorityQueue queue = new PriorityQueue<>(); + final Random random; + final Time time; + final List invariants = new ArrayList<>(); + final List validations = new ArrayList<>(); + + private EventScheduler(Random random, Time time) { + this.random = random; + this.time = time; + } + + // Add an invariant, which is checked after every event + private void addInvariant(Invariant invariant) { + invariants.add(invariant); + } + + // Add a validation, which is checked at the end of the simulation + private void addValidation(Validation validation) { + validations.add(validation); + } + + void schedule(Runnable action, int delayMs, int periodMs, int jitterMs) { + long initialDeadlineMs = time.milliseconds() + delayMs; + int eventId = eventIdGenerator.incrementAndGet(); + PeriodicEvent event = new PeriodicEvent(action, eventId, random, initialDeadlineMs, periodMs, jitterMs); + queue.offer(event); + } + + void runUntil(Supplier exitCondition) { + for (int iteration = 0; iteration < MAX_ITERATIONS; iteration++) { + if (exitCondition.get()) { + break; + } + + if (queue.isEmpty()) { + throw new IllegalStateException("Event queue exhausted before condition was satisfied"); + } + + Event event = queue.poll(); + long delayMs = Math.max(event.deadlineMs - time.milliseconds(), 0); + time.sleep(delayMs); + event.execute(this); + invariants.forEach(Invariant::verify); + } + + assertTrue(exitCondition.get(), "Simulation condition was not satisfied after " + + MAX_ITERATIONS + " iterations"); + + validations.forEach(Validation::validate); + } + } + + private static class PersistentState { + final MockQuorumStateStore store = new MockQuorumStateStore(); + final MockLog log; + + PersistentState(int nodeId) { + log = new MockLog( + METADATA_PARTITION, + Uuid.METADATA_TOPIC_ID, + new LogContext(String.format("[Node %s] ", nodeId)) + ); + } + } + + private static class Cluster { + final Random random; + final AtomicInteger correlationIdCounter = new AtomicInteger(); + final MockTime time = new MockTime(); + final Uuid clusterId = Uuid.randomUuid(); + final Set voters = new HashSet<>(); + final Map nodes = new HashMap<>(); + final Map running = new HashMap<>(); + + private Cluster(int numVoters, int numObservers, Random random) { + this.random = random; + + int nodeId = 0; + for (; nodeId < numVoters; nodeId++) { + voters.add(nodeId); + nodes.put(nodeId, new PersistentState(nodeId)); + } + + for (; nodeId < numVoters + numObservers; nodeId++) { + nodes.put(nodeId, new PersistentState(nodeId)); + } + } + + Set nodes() { + return nodes.keySet(); + } + + int majoritySize() { + return voters.size() / 2 + 1; + } + + long maxLogEndOffset() { + return running + .values() + .stream() + .mapToLong(RaftNode::logEndOffset) + .max() + .orElse(0L); + } + + OptionalLong leaderHighWatermark() { + 
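+            // During an election more than one running node may still claim leadership; pick one of them and report its high watermark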
Optional leaderWithMaxEpoch = running + .values() + .stream() + .filter(node -> node.client.quorum().isLeader()) + .max((node1, node2) -> Integer.compare(node2.client.quorum().epoch(), node1.client.quorum().epoch())); + if (leaderWithMaxEpoch.isPresent()) { + return leaderWithMaxEpoch.get().client.highWatermark(); + } else { + return OptionalLong.empty(); + } + } + + boolean anyReachedHighWatermark(long offset) { + return running.values().stream() + .anyMatch(node -> node.highWatermark() > offset); + } + + long maxHighWatermarkReached() { + return running.values().stream() + .mapToLong(RaftNode::highWatermark) + .max() + .orElse(0L); + } + + long maxHighWatermarkReached(Set nodeIds) { + return running.values().stream() + .filter(node -> nodeIds.contains(node.nodeId)) + .mapToLong(RaftNode::highWatermark) + .max() + .orElse(0L); + } + + boolean allReachedHighWatermark(long offset, Set nodeIds) { + return nodeIds.stream() + .allMatch(nodeId -> running.get(nodeId).highWatermark() >= offset); + } + + boolean allReachedHighWatermark(long offset) { + return running.values().stream() + .allMatch(node -> node.highWatermark() >= offset); + } + + boolean hasLeader(int nodeId) { + OptionalInt latestLeader = latestLeader(); + return latestLeader.isPresent() && latestLeader.getAsInt() == nodeId; + } + + OptionalInt latestLeader() { + OptionalInt latestLeader = OptionalInt.empty(); + int latestEpoch = 0; + + for (RaftNode node : running.values()) { + if (node.client.quorum().epoch() > latestEpoch) { + latestLeader = node.client.quorum().leaderId(); + latestEpoch = node.client.quorum().epoch(); + } else if (node.client.quorum().epoch() == latestEpoch && node.client.quorum().leaderId().isPresent()) { + latestLeader = node.client.quorum().leaderId(); + } + } + return latestLeader; + } + + boolean hasConsistentLeader() { + Iterator iter = running.values().iterator(); + if (!iter.hasNext()) + return false; + + RaftNode first = iter.next(); + ElectionState election = first.store.readElectionState(); + if (!election.hasLeader()) + return false; + + while (iter.hasNext()) { + RaftNode next = iter.next(); + if (!election.equals(next.store.readElectionState())) + return false; + } + + return true; + } + + void killAll() { + running.clear(); + } + + void kill(int nodeId) { + running.remove(nodeId); + } + + void shutdown(int nodeId) { + RaftNode node = running.get(nodeId); + if (node == null) { + throw new IllegalStateException("Attempt to shutdown a node which is not currently running"); + } + node.client.shutdown(500).whenComplete((res, exception) -> kill(nodeId)); + } + + void pollIfRunning(int nodeId) { + ifRunning(nodeId, RaftNode::poll); + } + + Optional nodeIfRunning(int nodeId) { + return Optional.ofNullable(running.get(nodeId)); + } + + Collection running() { + return running.values(); + } + + void ifRunning(int nodeId, Consumer action) { + nodeIfRunning(nodeId).ifPresent(action); + } + + Optional randomRunning() { + List nodes = new ArrayList<>(running.values()); + if (nodes.isEmpty()) { + return Optional.empty(); + } else { + return Optional.of(nodes.get(random.nextInt(nodes.size()))); + } + } + + void withCurrentLeader(Consumer action) { + for (RaftNode node : running.values()) { + if (node.client.quorum().isLeader()) { + action.accept(node); + } + } + } + + void forAllRunning(Consumer action) { + running.values().forEach(action); + } + + void startAll() { + if (!running.isEmpty()) + throw new IllegalStateException("Some nodes are already started"); + for (int voterId : nodes.keySet()) { + 
start(voterId); + } + } + + void killAndDeletePersistentState(int nodeId) { + kill(nodeId); + nodes.put(nodeId, new PersistentState(nodeId)); + } + + private static RaftConfig.AddressSpec nodeAddress(int id) { + return new RaftConfig.InetAddressSpec(new InetSocketAddress("localhost", 9990 + id)); + } + + void start(int nodeId) { + LogContext logContext = new LogContext("[Node " + nodeId + "] "); + PersistentState persistentState = nodes.get(nodeId); + MockNetworkChannel channel = new MockNetworkChannel(correlationIdCounter, voters); + MockMessageQueue messageQueue = new MockMessageQueue(); + Map voterAddressMap = voters.stream() + .collect(Collectors.toMap(id -> id, Cluster::nodeAddress)); + RaftConfig raftConfig = new RaftConfig(voterAddressMap, REQUEST_TIMEOUT_MS, RETRY_BACKOFF_MS, ELECTION_TIMEOUT_MS, + ELECTION_JITTER_MS, FETCH_TIMEOUT_MS, LINGER_MS); + Metrics metrics = new Metrics(time); + + persistentState.log.reopen(); + + IntSerde serde = new IntSerde(); + MemoryPool memoryPool = new BatchMemoryPool(2, KafkaRaftClient.MAX_BATCH_SIZE_BYTES); + + KafkaRaftClient client = new KafkaRaftClient<>( + serde, + channel, + messageQueue, + persistentState.log, + persistentState.store, + memoryPool, + time, + metrics, + new MockExpirationService(time), + FETCH_MAX_WAIT_MS, + clusterId.toString(), + OptionalInt.of(nodeId), + logContext, + random, + raftConfig + ); + RaftNode node = new RaftNode( + nodeId, + client, + persistentState.log, + channel, + messageQueue, + persistentState.store, + logContext, + time, + random, + serde + ); + node.initialize(); + running.put(nodeId, node); + } + } + + private static class RaftNode { + final int nodeId; + final KafkaRaftClient client; + final MockLog log; + final MockNetworkChannel channel; + final MockMessageQueue messageQueue; + final MockQuorumStateStore store; + final LogContext logContext; + final ReplicatedCounter counter; + final Time time; + final Random random; + final RecordSerde intSerde; + + private RaftNode( + int nodeId, + KafkaRaftClient client, + MockLog log, + MockNetworkChannel channel, + MockMessageQueue messageQueue, + MockQuorumStateStore store, + LogContext logContext, + Time time, + Random random, + RecordSerde intSerde + ) { + this.nodeId = nodeId; + this.client = client; + this.log = log; + this.channel = channel; + this.messageQueue = messageQueue; + this.store = store; + this.logContext = logContext; + this.time = time; + this.random = random; + this.counter = new ReplicatedCounter(nodeId, client, logContext); + this.intSerde = intSerde; + } + + void initialize() { + client.register(this.counter); + client.initialize(); + } + + void poll() { + try { + do { + client.poll(); + } while (client.isRunning() && !messageQueue.isEmpty()); + } catch (Exception e) { + throw new RuntimeException("Uncaught exception during poll of node " + nodeId, e); + } + } + + long highWatermark() { + return client.quorum().highWatermark() + .map(hw -> hw.offset) + .orElse(0L); + } + + long logEndOffset() { + return log.endOffset().offset; + } + + @Override + public String toString() { + return String.format( + "Node(id=%s, hw=%s, logEndOffset=%s)", + nodeId, + highWatermark(), + logEndOffset() + ); + } + } + + private static class InflightRequest { + final int correlationId; + final int sourceId; + final int destinationId; + + private InflightRequest(int correlationId, int sourceId, int destinationId) { + this.correlationId = correlationId; + this.sourceId = sourceId; + this.destinationId = destinationId; + } + } + + private interface NetworkFilter { 
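+        // Return false to drop the message; inbound filters apply at the receiving node, outbound filters at the sending node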
+ boolean acceptInbound(RaftMessage message); + boolean acceptOutbound(RaftMessage message); + } + + private static class PermitAllTraffic implements NetworkFilter { + + @Override + public boolean acceptInbound(RaftMessage message) { + return true; + } + + @Override + public boolean acceptOutbound(RaftMessage message) { + return true; + } + } + + private static class DropAllTraffic implements NetworkFilter { + + @Override + public boolean acceptInbound(RaftMessage message) { + return false; + } + + @Override + public boolean acceptOutbound(RaftMessage message) { + return false; + } + } + + private static class DropOutboundRequestsFrom implements NetworkFilter { + + private final Set unreachable; + + private DropOutboundRequestsFrom(Set unreachable) { + this.unreachable = unreachable; + } + + @Override + public boolean acceptInbound(RaftMessage message) { + return true; + } + + @Override + public boolean acceptOutbound(RaftMessage message) { + if (message instanceof RaftRequest.Outbound) { + RaftRequest.Outbound request = (RaftRequest.Outbound) message; + return !unreachable.contains(request.destinationId()); + } + return true; + } + } + + private static class MonotonicEpoch implements Invariant { + final Cluster cluster; + final Map nodeEpochs = new HashMap<>(); + + private MonotonicEpoch(Cluster cluster) { + this.cluster = cluster; + for (Map.Entry nodeStateEntry : cluster.nodes.entrySet()) { + Integer nodeId = nodeStateEntry.getKey(); + nodeEpochs.put(nodeId, 0); + } + } + + @Override + public void verify() { + for (Map.Entry nodeStateEntry : cluster.nodes.entrySet()) { + Integer nodeId = nodeStateEntry.getKey(); + PersistentState state = nodeStateEntry.getValue(); + Integer oldEpoch = nodeEpochs.get(nodeId); + + ElectionState electionState = state.store.readElectionState(); + if (electionState == null) { + continue; + } + + Integer newEpoch = electionState.epoch; + if (oldEpoch > newEpoch) { + fail("Non-monotonic update of epoch detected on node " + nodeId + ": " + + oldEpoch + " -> " + newEpoch); + } + cluster.ifRunning(nodeId, nodeState -> { + assertEquals(newEpoch.intValue(), nodeState.client.quorum().epoch()); + }); + nodeEpochs.put(nodeId, newEpoch); + } + } + } + + private static class MajorityReachedHighWatermark implements Invariant { + final Cluster cluster; + + private MajorityReachedHighWatermark(Cluster cluster) { + this.cluster = cluster; + } + + @Override + public void verify() { + cluster.leaderHighWatermark().ifPresent(highWatermark -> { + long numReachedHighWatermark = cluster.nodes.entrySet().stream() + .filter(entry -> cluster.voters.contains(entry.getKey())) + .filter(entry -> entry.getValue().log.endOffset().offset >= highWatermark) + .count(); + assertTrue( + numReachedHighWatermark >= cluster.majoritySize(), + "Insufficient nodes have reached current high watermark"); + }); + } + } + + private static class SingleLeader implements Invariant { + final Cluster cluster; + int epoch = 0; + OptionalInt leaderId = OptionalInt.empty(); + + private SingleLeader(Cluster cluster) { + this.cluster = cluster; + } + + @Override + public void verify() { + for (Map.Entry nodeEntry : cluster.nodes.entrySet()) { + PersistentState state = nodeEntry.getValue(); + ElectionState electionState = state.store.readElectionState(); + + if (electionState != null && electionState.epoch >= epoch && electionState.hasLeader()) { + if (epoch == electionState.epoch && leaderId.isPresent()) { + assertEquals(leaderId.getAsInt(), electionState.leaderId()); + } else { + epoch = electionState.epoch; + 
leaderId = OptionalInt.of(electionState.leaderId()); + } + } + } + } + } + + private static class MonotonicHighWatermark implements Invariant { + final Cluster cluster; + long highWatermark = 0; + + private MonotonicHighWatermark(Cluster cluster) { + this.cluster = cluster; + } + + @Override + public void verify() { + OptionalLong leaderHighWatermark = cluster.leaderHighWatermark(); + leaderHighWatermark.ifPresent(newHighWatermark -> { + long oldHighWatermark = highWatermark; + this.highWatermark = newHighWatermark; + if (newHighWatermark < oldHighWatermark) { + fail("Non-monotonic update of high watermark detected: " + + oldHighWatermark + " -> " + newHighWatermark); + } + }); + } + } + + private static class SnapshotAtLogStart implements Invariant { + final Cluster cluster; + + private SnapshotAtLogStart(Cluster cluster) { + this.cluster = cluster; + } + + @Override + public void verify() { + for (Map.Entry nodeEntry : cluster.nodes.entrySet()) { + int nodeId = nodeEntry.getKey(); + ReplicatedLog log = nodeEntry.getValue().log; + log.earliestSnapshotId().ifPresent(earliestSnapshotId -> { + long logStartOffset = log.startOffset(); + ValidOffsetAndEpoch validateOffsetAndEpoch = log.validateOffsetAndEpoch( + earliestSnapshotId.offset, + earliestSnapshotId.epoch + ); + + assertTrue( + logStartOffset <= earliestSnapshotId.offset, + () -> String.format( + "invalid log start offset (%s) and snapshotId offset (%s): nodeId = %s", + logStartOffset, + earliestSnapshotId.offset, + nodeId + ) + ); + assertEquals( + ValidOffsetAndEpoch.valid(earliestSnapshotId), + validateOffsetAndEpoch, + () -> String.format("invalid leader epoch cache: nodeId = %s", nodeId) + ); + + if (logStartOffset > 0) { + assertEquals( + logStartOffset, + earliestSnapshotId.offset, + () -> String.format("mising snapshot at log start offset: nodeId = %s", nodeId) + ); + } + }); + } + } + } + + private static class LeaderNeverLoadSnapshot implements Invariant { + final Cluster cluster; + + private LeaderNeverLoadSnapshot(Cluster cluster) { + this.cluster = cluster; + } + + @Override + public void verify() { + for (RaftNode raftNode : cluster.running()) { + if (raftNode.counter.isWritable()) { + assertEquals(0, raftNode.counter.handleSnapshotCalls()); + } + } + } + } + + /** + * Validating the committed data is expensive, so we do this as a {@link Validation}. We depend + * on the following external invariants: + * + * - High watermark increases monotonically + * - Truncation below the high watermark is not permitted + * - A majority of nodes reach the high watermark + * + * Under these assumptions, once the simulation finishes, we validate that all nodes have + * consistent data below the respective high watermark that has been recorded. 
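+ * For example, if two nodes have both advanced their high watermark past offset 42, they must
+ * have recorded the same counter value at offset 42.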
+ */ + private static class ConsistentCommittedData implements Validation { + final Cluster cluster; + final Map committedSequenceNumbers = new HashMap<>(); + + private ConsistentCommittedData(Cluster cluster) { + this.cluster = cluster; + } + + private int parseSequenceNumber(ByteBuffer value) { + return (int) Type.INT32.read(value); + } + + private void assertCommittedData(RaftNode node) { + final int nodeId = node.nodeId; + final KafkaRaftClient manager = node.client; + final MockLog log = node.log; + + OptionalLong highWatermark = manager.highWatermark(); + if (!highWatermark.isPresent()) { + // We cannot do validation if the current high watermark is unknown + return; + } + + AtomicLong startOffset = new AtomicLong(0); + log.earliestSnapshotId().ifPresent(snapshotId -> { + assertTrue(snapshotId.offset <= highWatermark.getAsLong()); + startOffset.set(snapshotId.offset); + + try (SnapshotReader snapshot = + SnapshotReader.of(log.readSnapshot(snapshotId).get(), node.intSerde, BufferSupplier.create(), Integer.MAX_VALUE)) { + // Expect only one batch with only one record + assertTrue(snapshot.hasNext()); + Batch batch = snapshot.next(); + assertFalse(snapshot.hasNext()); + assertEquals(1, batch.records().size()); + + // The snapshotId offset is an "end offset" + long offset = snapshotId.offset - 1; + int sequence = batch.records().get(0); + committedSequenceNumbers.putIfAbsent(offset, sequence); + + assertEquals( + committedSequenceNumbers.get(offset), + sequence, + String.format("Committed sequence at offset %s changed on node %s", offset, nodeId) + ); + } + }); + + for (LogBatch batch : log.readBatches(startOffset.get(), highWatermark)) { + if (batch.isControlBatch) { + continue; + } + + for (LogEntry entry : batch.entries) { + long offset = entry.offset; + assertTrue(offset < highWatermark.getAsLong()); + + int sequence = parseSequenceNumber(entry.record.value().duplicate()); + committedSequenceNumbers.putIfAbsent(offset, sequence); + + int committedSequence = committedSequenceNumbers.get(offset); + assertEquals( + committedSequence, sequence, + "Committed sequence at offset " + offset + " changed on node " + nodeId); + } + } + } + + @Override + public void validate() { + cluster.forAllRunning(this::assertCommittedData); + } + } + + private static class MessageRouter { + final Map inflight = new HashMap<>(); + final Map filters = new HashMap<>(); + final Cluster cluster; + + private MessageRouter(Cluster cluster) { + this.cluster = cluster; + for (int nodeId : cluster.nodes.keySet()) + filters.put(nodeId, new PermitAllTraffic()); + } + + void deliver(int senderId, RaftRequest.Outbound outbound) { + if (!filters.get(senderId).acceptOutbound(outbound)) + return; + + int correlationId = outbound.correlationId(); + int destinationId = outbound.destinationId(); + RaftRequest.Inbound inbound = new RaftRequest.Inbound(correlationId, outbound.data(), + cluster.time.milliseconds()); + + if (!filters.get(destinationId).acceptInbound(inbound)) + return; + + cluster.nodeIfRunning(destinationId).ifPresent(node -> { + inflight.put(correlationId, new InflightRequest(correlationId, senderId, destinationId)); + + inbound.completion.whenComplete((response, exception) -> { + if (response != null && filters.get(destinationId).acceptOutbound(response)) { + deliver(destinationId, response); + } + }); + + node.client.handle(inbound); + }); + } + + void deliver(int senderId, RaftResponse.Outbound outbound) { + int correlationId = outbound.correlationId(); + RaftResponse.Inbound inbound = new 
RaftResponse.Inbound(correlationId, outbound.data(), senderId); + InflightRequest inflightRequest = inflight.remove(correlationId); + + if (!filters.get(inflightRequest.sourceId).acceptInbound(inbound)) + return; + + cluster.nodeIfRunning(inflightRequest.sourceId).ifPresent(node -> { + node.channel.mockReceive(inbound); + }); + } + + void filter(int nodeId, NetworkFilter filter) { + filters.put(nodeId, filter); + } + + void deliverTo(RaftNode node) { + node.channel.drainSendQueue().forEach(msg -> deliver(node.nodeId, msg)); + } + + void deliverAll() { + for (RaftNode node : cluster.running()) { + deliverTo(node); + } + } + } + + private static class IntSerde implements RecordSerde { + @Override + public int recordSize(Integer data, ObjectSerializationCache serializationCache) { + return Type.INT32.sizeOf(data); + } + + @Override + public void write(Integer data, ObjectSerializationCache serializationCache, Writable out) { + out.writeInt(data); + } + + @Override + public Integer read(Readable input, int size) { + return input.readInt(); + } + } + +} diff --git a/raft/src/test/java/org/apache/kafka/raft/RequestManagerTest.java b/raft/src/test/java/org/apache/kafka/raft/RequestManagerTest.java new file mode 100644 index 0000000..e6e2f7c --- /dev/null +++ b/raft/src/test/java/org/apache/kafka/raft/RequestManagerTest.java @@ -0,0 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.raft; + +import org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.common.utils.Utils; +import org.junit.jupiter.api.Test; + +import java.util.Random; + +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class RequestManagerTest { + private final MockTime time = new MockTime(); + private final int requestTimeoutMs = 30000; + private final int retryBackoffMs = 100; + private final Random random = new Random(1); + + @Test + public void testResetAllConnections() { + RequestManager cache = new RequestManager( + Utils.mkSet(1, 2, 3), + retryBackoffMs, + requestTimeoutMs, + random); + + // One host has an inflight request + RequestManager.ConnectionState connectionState1 = cache.getOrCreate(1); + connectionState1.onRequestSent(1, time.milliseconds()); + assertFalse(connectionState1.isReady(time.milliseconds())); + + // Another is backing off + RequestManager.ConnectionState connectionState2 = cache.getOrCreate(2); + connectionState2.onRequestSent(2, time.milliseconds()); + connectionState2.onResponseError(2, time.milliseconds()); + assertFalse(connectionState2.isReady(time.milliseconds())); + + cache.resetAll(); + + // Now both should be ready + assertTrue(connectionState1.isReady(time.milliseconds())); + assertTrue(connectionState2.isReady(time.milliseconds())); + } + + @Test + public void testBackoffAfterFailure() { + RequestManager cache = new RequestManager( + Utils.mkSet(1, 2, 3), + retryBackoffMs, + requestTimeoutMs, + random); + + RequestManager.ConnectionState connectionState = cache.getOrCreate(1); + assertTrue(connectionState.isReady(time.milliseconds())); + + long correlationId = 1; + connectionState.onRequestSent(correlationId, time.milliseconds()); + assertFalse(connectionState.isReady(time.milliseconds())); + + connectionState.onResponseError(correlationId, time.milliseconds()); + assertFalse(connectionState.isReady(time.milliseconds())); + + time.sleep(retryBackoffMs); + assertTrue(connectionState.isReady(time.milliseconds())); + } + + @Test + public void testSuccessfulResponse() { + RequestManager cache = new RequestManager( + Utils.mkSet(1, 2, 3), + retryBackoffMs, + requestTimeoutMs, + random); + + RequestManager.ConnectionState connectionState = cache.getOrCreate(1); + + long correlationId = 1; + connectionState.onRequestSent(correlationId, time.milliseconds()); + assertFalse(connectionState.isReady(time.milliseconds())); + connectionState.onResponseReceived(correlationId); + assertTrue(connectionState.isReady(time.milliseconds())); + } + + @Test + public void testIgnoreUnexpectedResponse() { + RequestManager cache = new RequestManager( + Utils.mkSet(1, 2, 3), + retryBackoffMs, + requestTimeoutMs, + random); + + RequestManager.ConnectionState connectionState = cache.getOrCreate(1); + + long correlationId = 1; + connectionState.onRequestSent(correlationId, time.milliseconds()); + assertFalse(connectionState.isReady(time.milliseconds())); + connectionState.onResponseReceived(correlationId + 1); + assertFalse(connectionState.isReady(time.milliseconds())); + } + + @Test + public void testRequestTimeout() { + RequestManager cache = new RequestManager( + Utils.mkSet(1, 2, 3), + retryBackoffMs, + requestTimeoutMs, + random); + + RequestManager.ConnectionState connectionState = cache.getOrCreate(1); + + long correlationId = 1; + connectionState.onRequestSent(correlationId, time.milliseconds()); + assertFalse(connectionState.isReady(time.milliseconds())); + + 
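+        // The connection stays unavailable until the full request timeout has elapsed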
time.sleep(requestTimeoutMs - 1); + assertFalse(connectionState.isReady(time.milliseconds())); + + time.sleep(1); + assertTrue(connectionState.isReady(time.milliseconds())); + } + +} diff --git a/raft/src/test/java/org/apache/kafka/raft/ResignedStateTest.java b/raft/src/test/java/org/apache/kafka/raft/ResignedStateTest.java new file mode 100644 index 0000000..770297b --- /dev/null +++ b/raft/src/test/java/org/apache/kafka/raft/ResignedStateTest.java @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.raft; + +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.common.utils.Utils; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; + +import java.util.Collections; +import java.util.List; +import java.util.Set; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +class ResignedStateTest { + + private final MockTime time = new MockTime(); + private final LogContext logContext = new LogContext(); + int electionTimeoutMs = 5000; + int localId = 0; + int epoch = 5; + + private ResignedState newResignedState( + Set voters, + List preferredSuccessors + ) { + return new ResignedState( + time, + localId, + epoch, + voters, + electionTimeoutMs, + preferredSuccessors, + logContext + ); + } + + @Test + public void testResignedState() { + int remoteId = 1; + Set voters = Utils.mkSet(localId, remoteId); + + ResignedState state = newResignedState(voters, Collections.emptyList()); + + assertEquals(ElectionState.withElectedLeader(epoch, localId, voters), state.election()); + assertEquals(epoch, state.epoch()); + + assertEquals(Collections.singleton(remoteId), state.unackedVoters()); + state.acknowledgeResignation(remoteId); + assertEquals(Collections.emptySet(), state.unackedVoters()); + + assertEquals(electionTimeoutMs, state.remainingElectionTimeMs(time.milliseconds())); + assertFalse(state.hasElectionTimeoutExpired(time.milliseconds())); + time.sleep(electionTimeoutMs / 2); + assertEquals(electionTimeoutMs / 2, state.remainingElectionTimeMs(time.milliseconds())); + assertFalse(state.hasElectionTimeoutExpired(time.milliseconds())); + time.sleep(electionTimeoutMs / 2); + assertEquals(0, state.remainingElectionTimeMs(time.milliseconds())); + assertTrue(state.hasElectionTimeoutExpired(time.milliseconds())); + } + + @ParameterizedTest + @ValueSource(booleans = {true, false}) + public void testGrantVote(boolean isLogUpToDate) { + ResignedState state = newResignedState( + Utils.mkSet(1, 2, 3), + Collections.emptyList() + ); + + 
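+        // A node that has resigned leadership does not grant votes in its current epoch, regardless of log state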
assertFalse(state.canGrantVote(1, isLogUpToDate)); + assertFalse(state.canGrantVote(2, isLogUpToDate)); + assertFalse(state.canGrantVote(3, isLogUpToDate)); + } +} diff --git a/raft/src/test/java/org/apache/kafka/raft/UnattachedStateTest.java b/raft/src/test/java/org/apache/kafka/raft/UnattachedStateTest.java new file mode 100644 index 0000000..96f2a52 --- /dev/null +++ b/raft/src/test/java/org/apache/kafka/raft/UnattachedStateTest.java @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.raft; + +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.common.utils.Utils; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; + +import java.util.Optional; +import java.util.Set; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class UnattachedStateTest { + + private final MockTime time = new MockTime(); + private final LogContext logContext = new LogContext(); + private final int epoch = 5; + private final int electionTimeoutMs = 10000; + + private UnattachedState newUnattachedState(Set voters, Optional highWatermark) { + return new UnattachedState( + time, + epoch, + voters, + highWatermark, + electionTimeoutMs, + logContext + ); + } + + @Test + public void testElectionTimeout() { + Set voters = Utils.mkSet(1, 2, 3); + + UnattachedState state = newUnattachedState( + voters, + Optional.empty() + ); + + assertEquals(epoch, state.epoch()); + + assertEquals(ElectionState.withUnknownLeader(epoch, voters), state.election()); + assertEquals(electionTimeoutMs, state.remainingElectionTimeMs(time.milliseconds())); + assertFalse(state.hasElectionTimeoutExpired(time.milliseconds())); + + time.sleep(electionTimeoutMs / 2); + assertEquals(electionTimeoutMs / 2, state.remainingElectionTimeMs(time.milliseconds())); + assertFalse(state.hasElectionTimeoutExpired(time.milliseconds())); + + time.sleep(electionTimeoutMs / 2); + assertEquals(0, state.remainingElectionTimeMs(time.milliseconds())); + assertTrue(state.hasElectionTimeoutExpired(time.milliseconds())); + } + + @ParameterizedTest + @ValueSource(booleans = {true, false}) + public void testGrantVote(boolean isLogUpToDate) { + UnattachedState state = newUnattachedState( + Utils.mkSet(1, 2, 3), + Optional.empty() + ); + + assertEquals(isLogUpToDate, state.canGrantVote(1, isLogUpToDate)); + assertEquals(isLogUpToDate, state.canGrantVote(2, isLogUpToDate)); + assertEquals(isLogUpToDate, state.canGrantVote(3, isLogUpToDate)); + } +} diff --git 
a/raft/src/test/java/org/apache/kafka/raft/VotedStateTest.java b/raft/src/test/java/org/apache/kafka/raft/VotedStateTest.java new file mode 100644 index 0000000..317b80f --- /dev/null +++ b/raft/src/test/java/org/apache/kafka/raft/VotedStateTest.java @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.raft; + +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.common.utils.Utils; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; + +import java.util.Optional; +import java.util.Set; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +class VotedStateTest { + + private final MockTime time = new MockTime(); + private final LogContext logContext = new LogContext(); + private final int epoch = 5; + private final int votedId = 1; + private final int electionTimeoutMs = 10000; + + private VotedState newVotedState( + Set voters, + Optional highWatermark + ) { + return new VotedState( + time, + epoch, + votedId, + voters, + highWatermark, + electionTimeoutMs, + logContext + ); + } + + @Test + public void testElectionTimeout() { + Set voters = Utils.mkSet(1, 2, 3); + + VotedState state = newVotedState(voters, Optional.empty()); + + assertEquals(epoch, state.epoch()); + assertEquals(votedId, state.votedId()); + assertEquals(ElectionState.withVotedCandidate(epoch, votedId, voters), state.election()); + assertEquals(electionTimeoutMs, state.remainingElectionTimeMs(time.milliseconds())); + assertFalse(state.hasElectionTimeoutExpired(time.milliseconds())); + + time.sleep(5000); + assertEquals(electionTimeoutMs - 5000, state.remainingElectionTimeMs(time.milliseconds())); + assertFalse(state.hasElectionTimeoutExpired(time.milliseconds())); + + time.sleep(5000); + assertEquals(0, state.remainingElectionTimeMs(time.milliseconds())); + assertTrue(state.hasElectionTimeoutExpired(time.milliseconds())); + } + + @ParameterizedTest + @ValueSource(booleans = {true, false}) + public void testGrantVote(boolean isLogUpToDate) { + VotedState state = newVotedState( + Utils.mkSet(1, 2, 3), + Optional.empty() + ); + + assertTrue(state.canGrantVote(1, isLogUpToDate)); + assertFalse(state.canGrantVote(2, isLogUpToDate)); + assertFalse(state.canGrantVote(3, isLogUpToDate)); + } +} diff --git a/raft/src/test/java/org/apache/kafka/raft/internals/BatchAccumulatorTest.java b/raft/src/test/java/org/apache/kafka/raft/internals/BatchAccumulatorTest.java new file mode 100644 index 0000000..1149977 --- /dev/null +++ 
b/raft/src/test/java/org/apache/kafka/raft/internals/BatchAccumulatorTest.java @@ -0,0 +1,522 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.raft.internals; + +import org.apache.kafka.common.memory.MemoryPool; +import org.apache.kafka.common.message.LeaderChangeMessage; +import org.apache.kafka.common.protocol.ObjectSerializationCache; +import org.apache.kafka.common.protocol.Writable; +import org.apache.kafka.common.record.AbstractRecords; +import org.apache.kafka.common.record.CompressionType; +import org.apache.kafka.common.record.DefaultRecord; +import org.apache.kafka.common.record.RecordBatch; +import org.apache.kafka.common.utils.ByteUtils; +import org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.common.utils.Utils; +import org.junit.jupiter.api.Test; +import org.mockito.Mockito; + +import java.nio.ByteBuffer; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.CountDownLatch; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import static java.util.Arrays.asList; +import static java.util.Collections.singletonList; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +class BatchAccumulatorTest { + private final MemoryPool memoryPool = Mockito.mock(MemoryPool.class); + private final MockTime time = new MockTime(); + private final StringSerde serde = new StringSerde(); + + private BatchAccumulator<String> buildAccumulator( + int leaderEpoch, + long baseOffset, + int lingerMs, + int maxBatchSize + ) { + return new BatchAccumulator<>( + leaderEpoch, + baseOffset, + lingerMs, + maxBatchSize, + memoryPool, + time, + CompressionType.NONE, + serde + ); + } + + @Test + public void testLeaderChangeMessageWritten() { + int leaderEpoch = 17; + long baseOffset = 0; + int lingerMs = 50; + int maxBatchSize = 512; + + ByteBuffer buffer = ByteBuffer.allocate(256); + Mockito.when(memoryPool.tryAllocate(256)) + .thenReturn(buffer); + + BatchAccumulator<String> acc = buildAccumulator( + leaderEpoch, + baseOffset, + lingerMs, + maxBatchSize + ); + + acc.appendLeaderChangeMessage(new LeaderChangeMessage(), time.milliseconds()); + assertTrue(acc.needsDrain(time.milliseconds())); + + List<BatchAccumulator.CompletedBatch<String>> batches = acc.drain(); + assertEquals(1, batches.size()); + + BatchAccumulator.CompletedBatch<String> batch = batches.get(0); + batch.release(); + Mockito.verify(memoryPool).release(buffer); + } + + @Test + public void testForceDrain() { + asList(APPEND, APPEND_ATOMIC).forEach(appender -> { + int leaderEpoch = 17; + long baseOffset = 157; + int lingerMs = 50; + int maxBatchSize = 512; + + Mockito.when(memoryPool.tryAllocate(maxBatchSize)) +
.thenReturn(ByteBuffer.allocate(maxBatchSize)); + + BatchAccumulator acc = buildAccumulator( + leaderEpoch, + baseOffset, + lingerMs, + maxBatchSize + ); + + List records = asList("a", "b", "c", "d", "e", "f", "g", "h", "i"); + + // Append records + assertEquals(baseOffset, appender.call(acc, leaderEpoch, records.subList(0, 1))); + assertEquals(baseOffset + 2, appender.call(acc, leaderEpoch, records.subList(1, 3))); + assertEquals(baseOffset + 5, appender.call(acc, leaderEpoch, records.subList(3, 6))); + assertEquals(baseOffset + 7, appender.call(acc, leaderEpoch, records.subList(6, 8))); + assertEquals(baseOffset + 8, appender.call(acc, leaderEpoch, records.subList(8, 9))); + + assertFalse(acc.needsDrain(time.milliseconds())); + acc.forceDrain(); + assertTrue(acc.needsDrain(time.milliseconds())); + assertEquals(0, acc.timeUntilDrain(time.milliseconds())); + + // Drain completed batches + List> batches = acc.drain(); + + assertEquals(1, batches.size()); + assertFalse(acc.needsDrain(time.milliseconds())); + assertEquals(Long.MAX_VALUE - time.milliseconds(), acc.timeUntilDrain(time.milliseconds())); + + BatchAccumulator.CompletedBatch batch = batches.get(0); + assertEquals(records, batch.records.get()); + assertEquals(baseOffset, batch.baseOffset); + assertEquals(time.milliseconds(), batch.appendTimestamp()); + }); + } + + @Test + public void testForceDrainBeforeAppendLeaderChangeMessage() { + asList(APPEND, APPEND_ATOMIC).forEach(appender -> { + int leaderEpoch = 17; + long baseOffset = 157; + int lingerMs = 50; + int maxBatchSize = 512; + + Mockito.when(memoryPool.tryAllocate(maxBatchSize)) + .thenReturn(ByteBuffer.allocate(maxBatchSize)); + Mockito.when(memoryPool.tryAllocate(256)) + .thenReturn(ByteBuffer.allocate(256)); + + BatchAccumulator acc = buildAccumulator( + leaderEpoch, + baseOffset, + lingerMs, + maxBatchSize + ); + + List records = asList("a", "b", "c", "d", "e", "f", "g", "h", "i"); + + // Append records + assertEquals(baseOffset, appender.call(acc, leaderEpoch, records.subList(0, 1))); + assertEquals(baseOffset + 2, appender.call(acc, leaderEpoch, records.subList(1, 3))); + assertEquals(baseOffset + 5, appender.call(acc, leaderEpoch, records.subList(3, 6))); + assertEquals(baseOffset + 7, appender.call(acc, leaderEpoch, records.subList(6, 8))); + assertEquals(baseOffset + 8, appender.call(acc, leaderEpoch, records.subList(8, 9))); + + assertFalse(acc.needsDrain(time.milliseconds())); + + // Append a leader change message + acc.appendLeaderChangeMessage(new LeaderChangeMessage(), time.milliseconds()); + + assertTrue(acc.needsDrain(time.milliseconds())); + + // Test that drain status is FINISHED + assertEquals(0, acc.timeUntilDrain(time.milliseconds())); + + // Drain completed batches + List> batches = acc.drain(); + + // Should have 2 batches, one consisting of `records` and one `leaderChangeMessage` + assertEquals(2, batches.size()); + assertFalse(acc.needsDrain(time.milliseconds())); + assertEquals(Long.MAX_VALUE - time.milliseconds(), acc.timeUntilDrain(time.milliseconds())); + + BatchAccumulator.CompletedBatch batch = batches.get(0); + assertEquals(records, batch.records.get()); + assertEquals(baseOffset, batch.baseOffset); + assertEquals(time.milliseconds(), batch.appendTimestamp()); + }); + } + + @Test + public void testLingerIgnoredIfAccumulatorEmpty() { + int leaderEpoch = 17; + long baseOffset = 157; + int lingerMs = 50; + int maxBatchSize = 512; + + BatchAccumulator acc = buildAccumulator( + leaderEpoch, + baseOffset, + lingerMs, + maxBatchSize + ); + + 
assertTrue(acc.isEmpty()); + assertFalse(acc.needsDrain(time.milliseconds())); + assertEquals(Long.MAX_VALUE - time.milliseconds(), acc.timeUntilDrain(time.milliseconds())); + } + + @Test + public void testLingerBeginsOnFirstWrite() { + int leaderEpoch = 17; + long baseOffset = 157; + int lingerMs = 50; + int maxBatchSize = 512; + + Mockito.when(memoryPool.tryAllocate(maxBatchSize)) + .thenReturn(ByteBuffer.allocate(maxBatchSize)); + + BatchAccumulator acc = buildAccumulator( + leaderEpoch, + baseOffset, + lingerMs, + maxBatchSize + ); + + time.sleep(15); + assertEquals(baseOffset, acc.append(leaderEpoch, singletonList("a"))); + assertEquals(lingerMs, acc.timeUntilDrain(time.milliseconds())); + assertFalse(acc.isEmpty()); + + time.sleep(lingerMs / 2); + assertEquals(lingerMs / 2, acc.timeUntilDrain(time.milliseconds())); + assertFalse(acc.isEmpty()); + + time.sleep(lingerMs / 2); + assertEquals(0, acc.timeUntilDrain(time.milliseconds())); + assertTrue(acc.needsDrain(time.milliseconds())); + assertFalse(acc.isEmpty()); + } + + @Test + public void testCompletedBatchReleaseBuffer() { + int leaderEpoch = 17; + long baseOffset = 157; + int lingerMs = 50; + int maxBatchSize = 512; + + ByteBuffer buffer = ByteBuffer.allocate(maxBatchSize); + Mockito.when(memoryPool.tryAllocate(maxBatchSize)) + .thenReturn(buffer); + + BatchAccumulator acc = buildAccumulator( + leaderEpoch, + baseOffset, + lingerMs, + maxBatchSize + ); + + assertEquals(baseOffset, acc.append(leaderEpoch, singletonList("a"))); + time.sleep(lingerMs); + + List> batches = acc.drain(); + assertEquals(1, batches.size()); + + BatchAccumulator.CompletedBatch batch = batches.get(0); + batch.release(); + Mockito.verify(memoryPool).release(buffer); + } + + @Test + public void testUnflushedBuffersReleasedByClose() { + int leaderEpoch = 17; + long baseOffset = 157; + int lingerMs = 50; + int maxBatchSize = 512; + + ByteBuffer buffer = ByteBuffer.allocate(maxBatchSize); + Mockito.when(memoryPool.tryAllocate(maxBatchSize)) + .thenReturn(buffer); + + BatchAccumulator acc = buildAccumulator( + leaderEpoch, + baseOffset, + lingerMs, + maxBatchSize + ); + + assertEquals(baseOffset, acc.append(leaderEpoch, singletonList("a"))); + acc.close(); + Mockito.verify(memoryPool).release(buffer); + } + + @Test + public void testSingleBatchAccumulation() { + asList(APPEND, APPEND_ATOMIC).forEach(appender -> { + int leaderEpoch = 17; + long baseOffset = 157; + int lingerMs = 50; + int maxBatchSize = 512; + + Mockito.when(memoryPool.tryAllocate(maxBatchSize)) + .thenReturn(ByteBuffer.allocate(maxBatchSize)); + + BatchAccumulator acc = buildAccumulator( + leaderEpoch, + baseOffset, + lingerMs, + maxBatchSize + ); + + List records = asList("a", "b", "c", "d", "e", "f", "g", "h", "i"); + assertEquals(baseOffset, appender.call(acc, leaderEpoch, records.subList(0, 1))); + assertEquals(baseOffset + 2, appender.call(acc, leaderEpoch, records.subList(1, 3))); + assertEquals(baseOffset + 5, appender.call(acc, leaderEpoch, records.subList(3, 6))); + assertEquals(baseOffset + 7, appender.call(acc, leaderEpoch, records.subList(6, 8))); + assertEquals(baseOffset + 8, appender.call(acc, leaderEpoch, records.subList(8, 9))); + + long expectedAppendTimestamp = time.milliseconds(); + time.sleep(lingerMs); + assertTrue(acc.needsDrain(time.milliseconds())); + + List> batches = acc.drain(); + assertEquals(1, batches.size()); + assertFalse(acc.needsDrain(time.milliseconds())); + assertEquals(Long.MAX_VALUE - time.milliseconds(), acc.timeUntilDrain(time.milliseconds())); + + 
BatchAccumulator.CompletedBatch batch = batches.get(0); + assertEquals(records, batch.records.get()); + assertEquals(baseOffset, batch.baseOffset); + assertEquals(expectedAppendTimestamp, batch.appendTimestamp()); + }); + } + + @Test + public void testMultipleBatchAccumulation() { + asList(APPEND, APPEND_ATOMIC).forEach(appender -> { + int leaderEpoch = 17; + long baseOffset = 157; + int lingerMs = 50; + int maxBatchSize = 256; + + Mockito.when(memoryPool.tryAllocate(maxBatchSize)) + .thenReturn(ByteBuffer.allocate(maxBatchSize)); + + BatchAccumulator acc = buildAccumulator( + leaderEpoch, + baseOffset, + lingerMs, + maxBatchSize + ); + + // Append entries until we have 4 batches to drain (3 completed, 1 building) + while (acc.numCompletedBatches() < 3) { + appender.call(acc, leaderEpoch, singletonList("foo")); + } + + List> batches = acc.drain(); + assertEquals(4, batches.size()); + assertTrue(batches.stream().allMatch(batch -> batch.data.sizeInBytes() <= maxBatchSize)); + }); + } + + @Test + public void testRecordsAreSplit() { + int leaderEpoch = 17; + long baseOffset = 157; + int lingerMs = 50; + String record = "a"; + int numberOfRecords = 9; + int recordsPerBatch = 2; + int batchHeaderSize = AbstractRecords.recordBatchHeaderSizeInBytes( + RecordBatch.MAGIC_VALUE_V2, + CompressionType.NONE + ); + int maxBatchSize = batchHeaderSize + recordsPerBatch * recordSizeInBytes(record, recordsPerBatch); + + Mockito.when(memoryPool.tryAllocate(maxBatchSize)) + .thenReturn(ByteBuffer.allocate(maxBatchSize)); + + BatchAccumulator acc = buildAccumulator( + leaderEpoch, + baseOffset, + lingerMs, + maxBatchSize + ); + + List records = Stream + .generate(() -> record) + .limit(numberOfRecords) + .collect(Collectors.toList()); + assertEquals(baseOffset + numberOfRecords - 1, acc.append(leaderEpoch, records)); + + time.sleep(lingerMs); + assertTrue(acc.needsDrain(time.milliseconds())); + + List> batches = acc.drain(); + // ceilingDiv(records.size(), recordsPerBatch) + int expectedBatches = (records.size() + recordsPerBatch - 1) / recordsPerBatch; + assertEquals(expectedBatches, batches.size()); + assertTrue(batches.stream().allMatch(batch -> batch.data.sizeInBytes() <= maxBatchSize)); + } + + @Test + public void testCloseWhenEmpty() { + int leaderEpoch = 17; + long baseOffset = 157; + int lingerMs = 50; + int maxBatchSize = 256; + + BatchAccumulator acc = buildAccumulator( + leaderEpoch, + baseOffset, + lingerMs, + maxBatchSize + ); + + acc.close(); + Mockito.verifyNoInteractions(memoryPool); + } + + @Test + public void testDrainDoesNotBlockWithConcurrentAppend() throws Exception { + int leaderEpoch = 17; + long baseOffset = 157; + int lingerMs = 50; + int maxBatchSize = 256; + + StringSerde serde = Mockito.spy(new StringSerde()); + BatchAccumulator acc = new BatchAccumulator<>( + leaderEpoch, + baseOffset, + lingerMs, + maxBatchSize, + memoryPool, + time, + CompressionType.NONE, + serde + ); + + CountDownLatch acquireLockLatch = new CountDownLatch(1); + CountDownLatch releaseLockLatch = new CountDownLatch(1); + + // Do the first append outside the thread to start the linger timer + Mockito.when(memoryPool.tryAllocate(maxBatchSize)) + .thenReturn(ByteBuffer.allocate(maxBatchSize)); + acc.append(leaderEpoch, singletonList("a")); + + // Let the serde block to simulate a slow append + Mockito.doAnswer(invocation -> { + Writable writable = invocation.getArgument(2); + acquireLockLatch.countDown(); + releaseLockLatch.await(); + writable.writeByteArray(Utils.utf8("b")); + return null; + }).when(serde).write( + 
Mockito.eq("b"), + Mockito.any(ObjectSerializationCache.class), + Mockito.any(Writable.class) + ); + + Thread appendThread = new Thread(() -> acc.append(leaderEpoch, singletonList("b"))); + appendThread.start(); + + // Attempt to drain while the append thread is holding the lock + acquireLockLatch.await(); + time.sleep(lingerMs); + assertTrue(acc.needsDrain(time.milliseconds())); + assertEquals(Collections.emptyList(), acc.drain()); + assertTrue(acc.needsDrain(time.milliseconds())); + + // Now let the append thread complete and verify that we can finish the drain + releaseLockLatch.countDown(); + appendThread.join(); + List> drained = acc.drain(); + assertEquals(1, drained.size()); + assertEquals(Long.MAX_VALUE - time.milliseconds(), acc.timeUntilDrain(time.milliseconds())); + drained.stream().forEach(completedBatch -> { + completedBatch.data.batches().forEach(recordBatch -> { + assertEquals(leaderEpoch, recordBatch.partitionLeaderEpoch()); }); + }); + } + + int recordSizeInBytes(String record, int numberOfRecords) { + int serdeSize = serde.recordSize("a", new ObjectSerializationCache()); + + int recordSizeInBytes = DefaultRecord.sizeOfBodyInBytes( + numberOfRecords, + 0, + -1, + serdeSize, + DefaultRecord.EMPTY_HEADERS + ); + + return ByteUtils.sizeOfVarint(recordSizeInBytes) + recordSizeInBytes; + } + + static interface Appender { + Long call(BatchAccumulator acc, int epoch, List records); + } + + static final Appender APPEND_ATOMIC = new Appender() { + @Override + public Long call(BatchAccumulator acc, int epoch, List records) { + return acc.appendAtomic(epoch, records); + } + }; + + static final Appender APPEND = new Appender() { + @Override + public Long call(BatchAccumulator acc, int epoch, List records) { + return acc.append(epoch, records); + } + }; +} diff --git a/raft/src/test/java/org/apache/kafka/raft/internals/BatchBuilderTest.java b/raft/src/test/java/org/apache/kafka/raft/internals/BatchBuilderTest.java new file mode 100644 index 0000000..e4611f1 --- /dev/null +++ b/raft/src/test/java/org/apache/kafka/raft/internals/BatchBuilderTest.java @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.raft.internals; + +import org.apache.kafka.common.record.CompressionType; +import org.apache.kafka.common.record.MemoryRecords; +import org.apache.kafka.common.record.MutableRecordBatch; +import org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.common.utils.Utils; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.EnumSource; +import org.junit.jupiter.params.provider.ValueSource; + +import java.nio.ByteBuffer; +import java.util.Arrays; +import java.util.List; +import java.util.stream.Collectors; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +class BatchBuilderTest { + private StringSerde serde = new StringSerde(); + private MockTime time = new MockTime(); + + @ParameterizedTest + @EnumSource(CompressionType.class) + void testBuildBatch(CompressionType compressionType) { + ByteBuffer buffer = ByteBuffer.allocate(1024); + long baseOffset = 57; + long logAppendTime = time.milliseconds(); + boolean isControlBatch = false; + int leaderEpoch = 15; + + BatchBuilder builder = new BatchBuilder<>( + buffer, + serde, + compressionType, + baseOffset, + logAppendTime, + isControlBatch, + leaderEpoch, + buffer.limit() + ); + + List records = Arrays.asList( + "a", + "ap", + "app", + "appl", + "apple" + ); + + records.forEach(record -> builder.appendRecord(record, null)); + MemoryRecords builtRecordSet = builder.build(); + assertTrue(builder.bytesNeeded(Arrays.asList("a"), null).isPresent()); + assertThrows(IllegalStateException.class, () -> builder.appendRecord("a", null)); + + List builtBatches = Utils.toList(builtRecordSet.batchIterator()); + assertEquals(1, builtBatches.size()); + assertEquals(records, builder.records()); + + MutableRecordBatch batch = builtBatches.get(0); + assertEquals(5, batch.countOrNull()); + assertEquals(compressionType, batch.compressionType()); + assertEquals(baseOffset, batch.baseOffset()); + assertEquals(logAppendTime, batch.maxTimestamp()); + assertEquals(isControlBatch, batch.isControlBatch()); + assertEquals(leaderEpoch, batch.partitionLeaderEpoch()); + + List builtRecords = Utils.toList(batch).stream() + .map(record -> Utils.utf8(record.value())) + .collect(Collectors.toList()); + assertEquals(records, builtRecords); + } + + + @ParameterizedTest + @ValueSource(ints = {128, 157, 256, 433, 512, 777, 1024}) + public void testHasRoomForUncompressed(int batchSize) { + ByteBuffer buffer = ByteBuffer.allocate(batchSize); + long baseOffset = 57; + long logAppendTime = time.milliseconds(); + boolean isControlBatch = false; + int leaderEpoch = 15; + + BatchBuilder builder = new BatchBuilder<>( + buffer, + serde, + CompressionType.NONE, + baseOffset, + logAppendTime, + isControlBatch, + leaderEpoch, + buffer.limit() + ); + + String record = "i am a record"; + + while (!builder.bytesNeeded(Arrays.asList(record), null).isPresent()) { + builder.appendRecord(record, null); + } + + // Approximate size should be exact when compression is not used + int sizeInBytes = builder.approximateSizeInBytes(); + MemoryRecords records = builder.build(); + assertEquals(sizeInBytes, records.sizeInBytes()); + assertTrue(sizeInBytes <= batchSize, "Built batch size " + + sizeInBytes + " is larger than max batch size " + batchSize); + } +} diff --git a/raft/src/test/java/org/apache/kafka/raft/internals/BatchMemoryPoolTest.java 
b/raft/src/test/java/org/apache/kafka/raft/internals/BatchMemoryPoolTest.java new file mode 100644 index 0000000..4177de1 --- /dev/null +++ b/raft/src/test/java/org/apache/kafka/raft/internals/BatchMemoryPoolTest.java @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.raft.internals; + +import org.junit.jupiter.api.Test; + +import java.nio.ByteBuffer; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertSame; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +class BatchMemoryPoolTest { + + @Test + public void testAllocateAndRelease() { + int batchSize = 1024; + int maxBatches = 1; + + BatchMemoryPool pool = new BatchMemoryPool(maxBatches, batchSize); + assertEquals(batchSize, pool.availableMemory()); + assertFalse(pool.isOutOfMemory()); + + ByteBuffer allocated = pool.tryAllocate(batchSize); + assertNotNull(allocated); + assertEquals(0, allocated.position()); + assertEquals(batchSize, allocated.limit()); + assertEquals(0, pool.availableMemory()); + assertTrue(pool.isOutOfMemory()); + assertNull(pool.tryAllocate(batchSize)); + + allocated.position(512); + allocated.limit(724); + + pool.release(allocated); + ByteBuffer reallocated = pool.tryAllocate(batchSize); + assertSame(allocated, reallocated); + assertEquals(0, allocated.position()); + assertEquals(batchSize, allocated.limit()); + } + + @Test + public void testMultipleAllocations() { + int batchSize = 1024; + int maxBatches = 3; + + BatchMemoryPool pool = new BatchMemoryPool(maxBatches, batchSize); + assertEquals(batchSize * maxBatches, pool.availableMemory()); + + ByteBuffer batch1 = pool.tryAllocate(batchSize); + assertNotNull(batch1); + + ByteBuffer batch2 = pool.tryAllocate(batchSize); + assertNotNull(batch2); + + ByteBuffer batch3 = pool.tryAllocate(batchSize); + assertNotNull(batch3); + + assertNull(pool.tryAllocate(batchSize)); + + pool.release(batch2); + assertSame(batch2, pool.tryAllocate(batchSize)); + + pool.release(batch1); + pool.release(batch3); + ByteBuffer buffer = pool.tryAllocate(batchSize); + assertTrue(buffer == batch1 || buffer == batch3); + } + + @Test + public void testOversizeAllocation() { + int batchSize = 1024; + int maxBatches = 3; + + BatchMemoryPool pool = new BatchMemoryPool(maxBatches, batchSize); + assertThrows(IllegalArgumentException.class, () -> pool.tryAllocate(batchSize + 1)); + } + + @Test + public void testReleaseBufferNotMatchingBatchSize() { + int batchSize = 1024; + int maxBatches = 3; + 
+ BatchMemoryPool pool = new BatchMemoryPool(maxBatches, batchSize); + ByteBuffer buffer = ByteBuffer.allocate(1023); + assertThrows(IllegalArgumentException.class, () -> pool.release(buffer)); + } + +} diff --git a/raft/src/test/java/org/apache/kafka/raft/internals/BlockingMessageQueueTest.java b/raft/src/test/java/org/apache/kafka/raft/internals/BlockingMessageQueueTest.java new file mode 100644 index 0000000..e752fbd --- /dev/null +++ b/raft/src/test/java/org/apache/kafka/raft/internals/BlockingMessageQueueTest.java @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.raft.internals; + +import org.apache.kafka.raft.RaftMessage; +import org.junit.jupiter.api.Test; +import org.mockito.Mockito; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +class BlockingMessageQueueTest { + + @Test + public void testOfferAndPoll() { + BlockingMessageQueue queue = new BlockingMessageQueue(); + assertTrue(queue.isEmpty()); + assertNull(queue.poll(0)); + + RaftMessage message1 = Mockito.mock(RaftMessage.class); + queue.add(message1); + assertFalse(queue.isEmpty()); + assertEquals(message1, queue.poll(0)); + assertTrue(queue.isEmpty()); + + RaftMessage message2 = Mockito.mock(RaftMessage.class); + RaftMessage message3 = Mockito.mock(RaftMessage.class); + queue.add(message2); + queue.add(message3); + assertFalse(queue.isEmpty()); + assertEquals(message2, queue.poll(0)); + assertEquals(message3, queue.poll(0)); + + } + + @Test + public void testWakeupFromPoll() { + BlockingMessageQueue queue = new BlockingMessageQueue(); + queue.wakeup(); + assertNull(queue.poll(Long.MAX_VALUE)); + } + +} \ No newline at end of file diff --git a/raft/src/test/java/org/apache/kafka/raft/internals/KafkaRaftMetricsTest.java b/raft/src/test/java/org/apache/kafka/raft/internals/KafkaRaftMetricsTest.java new file mode 100644 index 0000000..0d64eac --- /dev/null +++ b/raft/src/test/java/org/apache/kafka/raft/internals/KafkaRaftMetricsTest.java @@ -0,0 +1,267 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.raft.internals; + + +import org.apache.kafka.common.metrics.KafkaMetric; +import org.apache.kafka.common.metrics.Metrics; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.raft.LogOffsetMetadata; +import org.apache.kafka.raft.MockQuorumStateStore; +import org.apache.kafka.raft.OffsetAndEpoch; +import org.apache.kafka.raft.QuorumState; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Test; +import org.mockito.Mockito; + +import java.io.IOException; +import java.util.Collections; +import java.util.OptionalInt; +import java.util.OptionalLong; +import java.util.Random; +import java.util.Set; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class KafkaRaftMetricsTest { + + private final int localId = 0; + private final int electionTimeoutMs = 5000; + private final int fetchTimeoutMs = 10000; + + private final Time time = new MockTime(); + private final Metrics metrics = new Metrics(time); + private final Random random = new Random(1); + private KafkaRaftMetrics raftMetrics; + + private BatchAccumulator accumulator = Mockito.mock(BatchAccumulator.class); + + @AfterEach + public void tearDown() { + if (raftMetrics != null) { + raftMetrics.close(); + } + metrics.close(); + } + + private QuorumState buildQuorumState(Set voters) { + return new QuorumState( + OptionalInt.of(localId), + voters, + electionTimeoutMs, + fetchTimeoutMs, + new MockQuorumStateStore(), + time, + new LogContext("kafka-raft-metrics-test"), + random + ); + } + + @Test + public void shouldRecordVoterQuorumState() throws IOException { + QuorumState state = buildQuorumState(Utils.mkSet(localId, 1, 2)); + + state.initialize(new OffsetAndEpoch(0L, 0)); + raftMetrics = new KafkaRaftMetrics(metrics, "raft", state); + + assertEquals("unattached", getMetric(metrics, "current-state").metricValue()); + assertEquals((double) -1L, getMetric(metrics, "current-leader").metricValue()); + assertEquals((double) -1L, getMetric(metrics, "current-vote").metricValue()); + assertEquals((double) 0, getMetric(metrics, "current-epoch").metricValue()); + assertEquals((double) -1L, getMetric(metrics, "high-watermark").metricValue()); + + state.transitionToCandidate(); + assertEquals("candidate", getMetric(metrics, "current-state").metricValue()); + assertEquals((double) -1L, getMetric(metrics, "current-leader").metricValue()); + assertEquals((double) localId, getMetric(metrics, "current-vote").metricValue()); + assertEquals((double) 1, getMetric(metrics, "current-epoch").metricValue()); + assertEquals((double) -1L, getMetric(metrics, "high-watermark").metricValue()); + + state.candidateStateOrThrow().recordGrantedVote(1); + state.transitionToLeader(2L, accumulator); + assertEquals("leader", getMetric(metrics, "current-state").metricValue()); + assertEquals((double) localId, getMetric(metrics, "current-leader").metricValue()); + assertEquals((double) localId, getMetric(metrics, 
"current-vote").metricValue()); + assertEquals((double) 1, getMetric(metrics, "current-epoch").metricValue()); + assertEquals((double) -1L, getMetric(metrics, "high-watermark").metricValue()); + + state.leaderStateOrThrow().updateLocalState(0, new LogOffsetMetadata(5L)); + state.leaderStateOrThrow().updateReplicaState(1, 0, new LogOffsetMetadata(5L)); + assertEquals((double) 5L, getMetric(metrics, "high-watermark").metricValue()); + + state.transitionToFollower(2, 1); + assertEquals("follower", getMetric(metrics, "current-state").metricValue()); + assertEquals((double) 1, getMetric(metrics, "current-leader").metricValue()); + assertEquals((double) -1, getMetric(metrics, "current-vote").metricValue()); + assertEquals((double) 2, getMetric(metrics, "current-epoch").metricValue()); + assertEquals((double) 5L, getMetric(metrics, "high-watermark").metricValue()); + + state.followerStateOrThrow().updateHighWatermark(OptionalLong.of(10L)); + assertEquals((double) 10L, getMetric(metrics, "high-watermark").metricValue()); + + state.transitionToVoted(3, 2); + assertEquals("voted", getMetric(metrics, "current-state").metricValue()); + assertEquals((double) -1, getMetric(metrics, "current-leader").metricValue()); + assertEquals((double) 2, getMetric(metrics, "current-vote").metricValue()); + assertEquals((double) 3, getMetric(metrics, "current-epoch").metricValue()); + assertEquals((double) 10L, getMetric(metrics, "high-watermark").metricValue()); + + state.transitionToUnattached(4); + assertEquals("unattached", getMetric(metrics, "current-state").metricValue()); + assertEquals((double) -1, getMetric(metrics, "current-leader").metricValue()); + assertEquals((double) -1, getMetric(metrics, "current-vote").metricValue()); + assertEquals((double) 4, getMetric(metrics, "current-epoch").metricValue()); + assertEquals((double) 10L, getMetric(metrics, "high-watermark").metricValue()); + } + + @Test + public void shouldRecordNonVoterQuorumState() throws IOException { + QuorumState state = buildQuorumState(Utils.mkSet(1, 2, 3)); + state.initialize(new OffsetAndEpoch(0L, 0)); + raftMetrics = new KafkaRaftMetrics(metrics, "raft", state); + + assertEquals("unattached", getMetric(metrics, "current-state").metricValue()); + assertEquals((double) -1L, getMetric(metrics, "current-leader").metricValue()); + assertEquals((double) -1L, getMetric(metrics, "current-vote").metricValue()); + assertEquals((double) 0, getMetric(metrics, "current-epoch").metricValue()); + assertEquals((double) -1L, getMetric(metrics, "high-watermark").metricValue()); + + state.transitionToFollower(2, 1); + assertEquals("follower", getMetric(metrics, "current-state").metricValue()); + assertEquals((double) 1, getMetric(metrics, "current-leader").metricValue()); + assertEquals((double) -1, getMetric(metrics, "current-vote").metricValue()); + assertEquals((double) 2, getMetric(metrics, "current-epoch").metricValue()); + assertEquals((double) -1L, getMetric(metrics, "high-watermark").metricValue()); + + state.followerStateOrThrow().updateHighWatermark(OptionalLong.of(10L)); + assertEquals((double) 10L, getMetric(metrics, "high-watermark").metricValue()); + + state.transitionToUnattached(4); + assertEquals("unattached", getMetric(metrics, "current-state").metricValue()); + assertEquals((double) -1, getMetric(metrics, "current-leader").metricValue()); + assertEquals((double) -1, getMetric(metrics, "current-vote").metricValue()); + assertEquals((double) 4, getMetric(metrics, "current-epoch").metricValue()); + assertEquals((double) 10L, 
getMetric(metrics, "high-watermark").metricValue()); + } + + @Test + public void shouldRecordLogEnd() throws IOException { + QuorumState state = buildQuorumState(Collections.singleton(localId)); + state.initialize(new OffsetAndEpoch(0L, 0)); + raftMetrics = new KafkaRaftMetrics(metrics, "raft", state); + + assertEquals((double) 0L, getMetric(metrics, "log-end-offset").metricValue()); + assertEquals((double) 0, getMetric(metrics, "log-end-epoch").metricValue()); + + raftMetrics.updateLogEnd(new OffsetAndEpoch(5L, 1)); + + assertEquals((double) 5L, getMetric(metrics, "log-end-offset").metricValue()); + assertEquals((double) 1, getMetric(metrics, "log-end-epoch").metricValue()); + } + + @Test + public void shouldRecordNumUnknownVoterConnections() throws IOException { + QuorumState state = buildQuorumState(Collections.singleton(localId)); + state.initialize(new OffsetAndEpoch(0L, 0)); + raftMetrics = new KafkaRaftMetrics(metrics, "raft", state); + + assertEquals((double) 0, getMetric(metrics, "number-unknown-voter-connections").metricValue()); + + raftMetrics.updateNumUnknownVoterConnections(2); + + assertEquals((double) 2, getMetric(metrics, "number-unknown-voter-connections").metricValue()); + } + + @Test + public void shouldRecordPollIdleRatio() throws IOException { + QuorumState state = buildQuorumState(Collections.singleton(localId)); + state.initialize(new OffsetAndEpoch(0L, 0)); + raftMetrics = new KafkaRaftMetrics(metrics, "raft", state); + + raftMetrics.updatePollStart(time.milliseconds()); + time.sleep(100L); + raftMetrics.updatePollEnd(time.milliseconds()); + time.sleep(900L); + raftMetrics.updatePollStart(time.milliseconds()); + + assertEquals(0.1, getMetric(metrics, "poll-idle-ratio-avg").metricValue()); + + time.sleep(100L); + raftMetrics.updatePollEnd(time.milliseconds()); + time.sleep(100L); + raftMetrics.updatePollStart(time.milliseconds()); + + assertEquals(0.3, getMetric(metrics, "poll-idle-ratio-avg").metricValue()); + } + + @Test + public void shouldRecordLatency() throws IOException { + QuorumState state = buildQuorumState(Collections.singleton(localId)); + state.initialize(new OffsetAndEpoch(0L, 0)); + raftMetrics = new KafkaRaftMetrics(metrics, "raft", state); + + raftMetrics.updateElectionStartMs(time.milliseconds()); + time.sleep(1000L); + raftMetrics.maybeUpdateElectionLatency(time.milliseconds()); + + assertEquals((double) 1000, getMetric(metrics, "election-latency-avg").metricValue()); + assertEquals((double) 1000, getMetric(metrics, "election-latency-max").metricValue()); + + raftMetrics.updateElectionStartMs(time.milliseconds()); + time.sleep(800L); + raftMetrics.maybeUpdateElectionLatency(time.milliseconds()); + + assertEquals((double) 900, getMetric(metrics, "election-latency-avg").metricValue()); + assertEquals((double) 1000, getMetric(metrics, "election-latency-max").metricValue()); + + raftMetrics.updateCommitLatency(50, time.milliseconds()); + + assertEquals(50.0, getMetric(metrics, "commit-latency-avg").metricValue()); + assertEquals(50.0, getMetric(metrics, "commit-latency-max").metricValue()); + + raftMetrics.updateCommitLatency(60, time.milliseconds()); + + assertEquals(55.0, getMetric(metrics, "commit-latency-avg").metricValue()); + assertEquals(60.0, getMetric(metrics, "commit-latency-max").metricValue()); + } + + @Test + public void shouldRecordRate() throws IOException { + QuorumState state = buildQuorumState(Collections.singleton(localId)); + state.initialize(new OffsetAndEpoch(0L, 0)); + raftMetrics = new KafkaRaftMetrics(metrics, "raft", state); 
+ + raftMetrics.updateAppendRecords(12); + assertEquals(0.4, getMetric(metrics, "append-records-rate").metricValue()); + + raftMetrics.updateAppendRecords(9); + assertEquals(0.7, getMetric(metrics, "append-records-rate").metricValue()); + + raftMetrics.updateFetchedRecords(24); + assertEquals(0.8, getMetric(metrics, "fetch-records-rate").metricValue()); + + raftMetrics.updateFetchedRecords(48); + assertEquals(2.4, getMetric(metrics, "fetch-records-rate").metricValue()); + } + + private KafkaMetric getMetric(final Metrics metrics, final String name) { + return metrics.metrics().get(metrics.metricName(name, "raft-metrics")); + } +} diff --git a/raft/src/test/java/org/apache/kafka/raft/internals/MemoryBatchReaderTest.java b/raft/src/test/java/org/apache/kafka/raft/internals/MemoryBatchReaderTest.java new file mode 100644 index 0000000..631b6b4 --- /dev/null +++ b/raft/src/test/java/org/apache/kafka/raft/internals/MemoryBatchReaderTest.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.raft.internals; + +import org.apache.kafka.raft.Batch; +import org.apache.kafka.raft.BatchReader; +import org.junit.jupiter.api.Test; +import org.mockito.Mockito; + +import java.util.Arrays; +import java.util.OptionalLong; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +class MemoryBatchReaderTest { + + @Test + public void testIteration() { + Batch<String> batch1 = Batch.data( + 0L, 1, 0L, 3, Arrays.asList("a", "b", "c") + ); + Batch<String> batch2 = Batch.data( + 3L, 2, 1L, 2, Arrays.asList("d", "e") + ); + Batch<String> batch3 = Batch.data( + 5L, 2, 3L, 4, Arrays.asList("f", "g", "h", "i") + ); + + @SuppressWarnings("unchecked") + CloseListener<BatchReader<String>> listener = Mockito.mock(CloseListener.class); + MemoryBatchReader<String> reader = MemoryBatchReader.of( + Arrays.asList(batch1, batch2, batch3), + listener + ); + + assertEquals(0L, reader.baseOffset()); + assertEquals(OptionalLong.of(8L), reader.lastOffset()); + + assertTrue(reader.hasNext()); + assertEquals(batch1, reader.next()); + + assertTrue(reader.hasNext()); + assertEquals(batch2, reader.next()); + + assertTrue(reader.hasNext()); + assertEquals(batch3, reader.next()); + + assertFalse(reader.hasNext()); + + reader.close(); + Mockito.verify(listener).onClose(reader); + } + +} diff --git a/raft/src/test/java/org/apache/kafka/raft/internals/RecordsBatchReaderTest.java b/raft/src/test/java/org/apache/kafka/raft/internals/RecordsBatchReaderTest.java new file mode 100644 index 0000000..6fe5407 --- /dev/null +++ b/raft/src/test/java/org/apache/kafka/raft/internals/RecordsBatchReaderTest.java @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF)
under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.raft.internals; + +import org.apache.kafka.common.record.CompressionType; +import org.apache.kafka.common.record.FileRecords; +import org.apache.kafka.common.record.MemoryRecords; +import org.apache.kafka.common.record.Records; +import org.apache.kafka.common.utils.BufferSupplier; +import org.apache.kafka.raft.BatchReader; +import org.apache.kafka.raft.internals.RecordsIteratorTest.TestBatch; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.EnumSource; +import org.mockito.Mockito; + +import java.nio.ByteBuffer; +import java.util.Collections; +import java.util.IdentityHashMap; +import java.util.List; +import java.util.NoSuchElementException; +import java.util.Set; + +import static org.apache.kafka.test.TestUtils.tempFile; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +class RecordsBatchReaderTest { + private static final int MAX_BATCH_BYTES = 128; + + private final StringSerde serde = new StringSerde(); + + @ParameterizedTest + @EnumSource(CompressionType.class) + public void testReadFromMemoryRecords(CompressionType compressionType) { + long baseOffset = 57; + + List<TestBatch<String>> batches = RecordsIteratorTest.createBatches(baseOffset); + MemoryRecords memRecords = RecordsIteratorTest.buildRecords(compressionType, batches); + + testBatchReader(baseOffset, memRecords, batches); + } + + @ParameterizedTest + @EnumSource(CompressionType.class) + public void testReadFromFileRecords(CompressionType compressionType) throws Exception { + long baseOffset = 57; + + List<TestBatch<String>> batches = RecordsIteratorTest.createBatches(baseOffset); + MemoryRecords memRecords = RecordsIteratorTest.buildRecords(compressionType, batches); + + FileRecords fileRecords = FileRecords.open(tempFile()); + fileRecords.append(memRecords); + + testBatchReader(baseOffset, fileRecords, batches); + } + + private void testBatchReader( + long baseOffset, + Records records, + List<TestBatch<String>> expectedBatches + ) { + BufferSupplier bufferSupplier = Mockito.mock(BufferSupplier.class); + Set<ByteBuffer> allocatedBuffers = Collections.newSetFromMap(new IdentityHashMap<>()); + + Mockito.when(bufferSupplier.get(Mockito.anyInt())).thenAnswer(invocation -> { + int size = invocation.getArgument(0); + ByteBuffer buffer = ByteBuffer.allocate(size); + allocatedBuffers.add(buffer); + return buffer; + }); + + Mockito.doAnswer(invocation -> { + ByteBuffer released = invocation.getArgument(0); + allocatedBuffers.remove(released); + return null; + }).when(bufferSupplier).release(Mockito.any(ByteBuffer.class)); + + @SuppressWarnings("unchecked") + CloseListener<BatchReader<String>> closeListener =
Mockito.mock(CloseListener.class); + + RecordsBatchReader reader = RecordsBatchReader.of( + baseOffset, + records, + serde, + bufferSupplier, + MAX_BATCH_BYTES, + closeListener + ); + + for (TestBatch batch : expectedBatches) { + assertTrue(reader.hasNext()); + assertEquals(batch, TestBatch.from(reader.next())); + } + + assertFalse(reader.hasNext()); + assertThrows(NoSuchElementException.class, reader::next); + + reader.close(); + Mockito.verify(closeListener).onClose(reader); + assertEquals(Collections.emptySet(), allocatedBuffers); + } + +} diff --git a/raft/src/test/java/org/apache/kafka/raft/internals/RecordsIteratorTest.java b/raft/src/test/java/org/apache/kafka/raft/internals/RecordsIteratorTest.java new file mode 100644 index 0000000..4a10f57 --- /dev/null +++ b/raft/src/test/java/org/apache/kafka/raft/internals/RecordsIteratorTest.java @@ -0,0 +1,234 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.raft.internals; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Collections; +import java.util.IdentityHashMap; +import java.util.List; +import java.util.NoSuchElementException; +import java.util.Objects; +import java.util.Random; +import java.util.Set; +import java.util.stream.Collectors; +import java.util.stream.Stream; +import net.jqwik.api.ForAll; +import net.jqwik.api.Property; +import org.apache.kafka.common.record.CompressionType; +import org.apache.kafka.common.record.FileRecords; +import org.apache.kafka.common.record.MemoryRecords; +import org.apache.kafka.common.record.Records; +import org.apache.kafka.common.utils.BufferSupplier; +import org.apache.kafka.raft.Batch; +import org.apache.kafka.server.common.serialization.RecordSerde; +import org.apache.kafka.test.TestUtils; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import org.mockito.Mockito; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public final class RecordsIteratorTest { + private static final RecordSerde STRING_SERDE = new StringSerde(); + + private static Stream emptyRecords() throws IOException { + return Stream.of( + FileRecords.open(TestUtils.tempFile()), + MemoryRecords.EMPTY + ).map(Arguments::of); + } + + @ParameterizedTest + @MethodSource("emptyRecords") + void testEmptyRecords(Records records) throws IOException { + testIterator(Collections.emptyList(), records); + } + + @Property + public void testMemoryRecords( + @ForAll CompressionType compressionType, + 
@ForAll long seed + ) { + List<TestBatch<String>> batches = createBatches(seed); + + MemoryRecords memRecords = buildRecords(compressionType, batches); + testIterator(batches, memRecords); + } + + @Property + public void testFileRecords( + @ForAll CompressionType compressionType, + @ForAll long seed + ) throws IOException { + List<TestBatch<String>> batches = createBatches(seed); + + MemoryRecords memRecords = buildRecords(compressionType, batches); + FileRecords fileRecords = FileRecords.open(TestUtils.tempFile()); + fileRecords.append(memRecords); + + testIterator(batches, fileRecords); + } + + private void testIterator( + List<TestBatch<String>> expectedBatches, + Records records + ) { + Set<ByteBuffer> allocatedBuffers = Collections.newSetFromMap(new IdentityHashMap<>()); + + RecordsIterator<String> iterator = createIterator( + records, + mockBufferSupplier(allocatedBuffers) + ); + + for (TestBatch<String> batch : expectedBatches) { + assertTrue(iterator.hasNext()); + assertEquals(batch, TestBatch.from(iterator.next())); + } + + assertFalse(iterator.hasNext()); + assertThrows(NoSuchElementException.class, iterator::next); + + iterator.close(); + assertEquals(Collections.emptySet(), allocatedBuffers); + } + + static RecordsIterator<String> createIterator(Records records, BufferSupplier bufferSupplier) { + return new RecordsIterator<>(records, STRING_SERDE, bufferSupplier, Records.HEADER_SIZE_UP_TO_MAGIC); + } + + static BufferSupplier mockBufferSupplier(Set<ByteBuffer> buffers) { + BufferSupplier bufferSupplier = Mockito.mock(BufferSupplier.class); + + Mockito.when(bufferSupplier.get(Mockito.anyInt())).thenAnswer(invocation -> { + int size = invocation.getArgument(0); + ByteBuffer buffer = ByteBuffer.allocate(size); + buffers.add(buffer); + return buffer; + }); + + Mockito.doAnswer(invocation -> { + ByteBuffer released = invocation.getArgument(0); + buffers.remove(released); + return null; + }).when(bufferSupplier).release(Mockito.any(ByteBuffer.class)); + + return bufferSupplier; + } + + public static List<TestBatch<String>> createBatches(long seed) { + Random random = new Random(seed); + long baseOffset = random.nextInt(100); + int epoch = random.nextInt(3) + 1; + long appendTimestamp = random.nextInt(1000); + + int numberOfBatches = random.nextInt(100) + 1; + List<TestBatch<String>> batches = new ArrayList<>(numberOfBatches); + for (int i = 0; i < numberOfBatches; i++) { + int numberOfRecords = random.nextInt(100) + 1; + List<String> records = random + .ints(numberOfRecords, 0, 10) + .mapToObj(String::valueOf) + .collect(Collectors.toList()); + + batches.add(new TestBatch<>(baseOffset, epoch, appendTimestamp, records)); + baseOffset += records.size(); + if (i % 5 == 0) { + epoch += random.nextInt(3); + } + appendTimestamp += random.nextInt(1000); + } + + return batches; + } + + public static MemoryRecords buildRecords( + CompressionType compressionType, + List<TestBatch<String>> batches + ) { + ByteBuffer buffer = ByteBuffer.allocate(102400); + + for (TestBatch<String> batch : batches) { + BatchBuilder<String> builder = new BatchBuilder<>( + buffer, + STRING_SERDE, + compressionType, + batch.baseOffset, + batch.appendTimestamp, + false, + batch.epoch, + 1024 + ); + + for (String record : batch.records) { + builder.appendRecord(record, null); + } + + builder.build(); + } + + buffer.flip(); + return MemoryRecords.readableRecords(buffer); + } + + public static final class TestBatch<T> { + final long baseOffset; + final int epoch; + final long appendTimestamp; + final List<T> records; + + TestBatch(long baseOffset, int epoch, long appendTimestamp, List<T> records) { + this.baseOffset = baseOffset; + this.epoch = epoch; + this.appendTimestamp = appendTimestamp; +
this.records = records; + } + + @Override + public String toString() { + return String.format( + "TestBatch(baseOffset=%s, epoch=%s, records=%s)", + baseOffset, + epoch, + records + ); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + TestBatch testBatch = (TestBatch) o; + return baseOffset == testBatch.baseOffset && + epoch == testBatch.epoch && + Objects.equals(records, testBatch.records); + } + + @Override + public int hashCode() { + return Objects.hash(baseOffset, epoch, records); + } + + static TestBatch from(Batch batch) { + return new TestBatch<>(batch.baseOffset(), batch.epoch(), batch.appendTimestamp(), batch.records()); + } + } +} diff --git a/raft/src/test/java/org/apache/kafka/raft/internals/ThresholdPurgatoryTest.java b/raft/src/test/java/org/apache/kafka/raft/internals/ThresholdPurgatoryTest.java new file mode 100644 index 0000000..7816427 --- /dev/null +++ b/raft/src/test/java/org/apache/kafka/raft/internals/ThresholdPurgatoryTest.java @@ -0,0 +1,162 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.raft.internals; + +import org.apache.kafka.common.errors.NotLeaderOrFollowerException; +import org.apache.kafka.common.errors.TimeoutException; +import org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.raft.MockExpirationService; +import org.junit.jupiter.api.Test; + +import java.util.concurrent.CompletableFuture; + +import static org.apache.kafka.test.TestUtils.assertFutureThrows; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +class ThresholdPurgatoryTest { + private final MockTime time = new MockTime(); + private final MockExpirationService expirationService = new MockExpirationService(time); + private final ThresholdPurgatory purgatory = new ThresholdPurgatory<>(expirationService); + + @Test + public void testThresholdCompletion() throws Exception { + CompletableFuture future1 = purgatory.await(3L, 500); + CompletableFuture future2 = purgatory.await(1L, 500); + CompletableFuture future3 = purgatory.await(5L, 500); + assertEquals(3, purgatory.numWaiting()); + + long completionTime1 = time.milliseconds(); + purgatory.maybeComplete(1L, completionTime1); + assertTrue(future2.isDone()); + assertFalse(future1.isDone()); + assertFalse(future3.isDone()); + assertEquals(completionTime1, future2.get()); + assertEquals(2, purgatory.numWaiting()); + + time.sleep(100); + purgatory.maybeComplete(2L, time.milliseconds()); + assertFalse(future1.isDone()); + assertFalse(future3.isDone()); + + time.sleep(100); + long completionTime2 = time.milliseconds(); + purgatory.maybeComplete(3L, completionTime2); + assertTrue(future1.isDone()); + assertFalse(future3.isDone()); + assertEquals(completionTime2, future1.get()); + assertEquals(1, purgatory.numWaiting()); + + time.sleep(100); + purgatory.maybeComplete(4L, time.milliseconds()); + assertFalse(future3.isDone()); + + time.sleep(100); + long completionTime3 = time.milliseconds(); + purgatory.maybeComplete(5L, completionTime3); + assertTrue(future3.isDone()); + assertEquals(completionTime3, future3.get()); + assertEquals(0, purgatory.numWaiting()); + } + + @Test + public void testExpiration() { + CompletableFuture future1 = purgatory.await(1L, 200); + CompletableFuture future2 = purgatory.await(1L, 200); + assertEquals(2, purgatory.numWaiting()); + + time.sleep(100); + CompletableFuture future3 = purgatory.await(5L, 50); + CompletableFuture future4 = purgatory.await(5L, 200); + CompletableFuture future5 = purgatory.await(5L, 100); + assertEquals(5, purgatory.numWaiting()); + + time.sleep(50); + assertFutureThrows(future3, TimeoutException.class); + assertFalse(future1.isDone()); + assertFalse(future2.isDone()); + assertFalse(future4.isDone()); + assertFalse(future5.isDone()); + assertEquals(4, purgatory.numWaiting()); + + time.sleep(50); + assertFutureThrows(future1, TimeoutException.class); + assertFutureThrows(future2, TimeoutException.class); + assertFutureThrows(future5, TimeoutException.class); + assertFalse(future4.isDone()); + assertEquals(1, purgatory.numWaiting()); + + time.sleep(50); + assertFalse(future4.isDone()); + assertEquals(1, purgatory.numWaiting()); + + time.sleep(50); + assertFutureThrows(future4, TimeoutException.class); + assertEquals(0, purgatory.numWaiting()); + } + + @Test + public void testCompleteAll() throws Exception { + CompletableFuture future1 = purgatory.await(3L, 500); + CompletableFuture future2 = purgatory.await(1L, 500); + CompletableFuture 
future3 = purgatory.await(5L, 500); + assertEquals(3, purgatory.numWaiting()); + + long completionTime = time.milliseconds(); + purgatory.completeAll(completionTime); + assertEquals(completionTime, future1.get()); + assertEquals(completionTime, future2.get()); + assertEquals(completionTime, future3.get()); + assertEquals(0, purgatory.numWaiting()); + } + + @Test + public void testCompleteAllExceptionally() { + CompletableFuture future1 = purgatory.await(3L, 500); + CompletableFuture future2 = purgatory.await(1L, 500); + CompletableFuture future3 = purgatory.await(5L, 500); + assertEquals(3, purgatory.numWaiting()); + + purgatory.completeAllExceptionally(new NotLeaderOrFollowerException()); + assertFutureThrows(future1, NotLeaderOrFollowerException.class); + assertFutureThrows(future2, NotLeaderOrFollowerException.class); + assertFutureThrows(future3, NotLeaderOrFollowerException.class); + assertEquals(0, purgatory.numWaiting()); + } + + @Test + public void testExternalCompletion() { + CompletableFuture future1 = purgatory.await(3L, 500); + CompletableFuture future2 = purgatory.await(1L, 500); + CompletableFuture future3 = purgatory.await(5L, 500); + assertEquals(3, purgatory.numWaiting()); + + future2.complete(time.milliseconds()); + assertFalse(future1.isDone()); + assertFalse(future3.isDone()); + assertEquals(2, purgatory.numWaiting()); + + future1.complete(time.milliseconds()); + assertFalse(future3.isDone()); + assertEquals(1, purgatory.numWaiting()); + + future3.complete(time.milliseconds()); + assertEquals(0, purgatory.numWaiting()); + } + +} diff --git a/raft/src/test/java/org/apache/kafka/snapshot/FileRawSnapshotTest.java b/raft/src/test/java/org/apache/kafka/snapshot/FileRawSnapshotTest.java new file mode 100644 index 0000000..ef38d46 --- /dev/null +++ b/raft/src/test/java/org/apache/kafka/snapshot/FileRawSnapshotTest.java @@ -0,0 +1,358 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.snapshot; + +import org.apache.kafka.common.utils.BufferSupplier.GrowableBufferSupplier; +import org.apache.kafka.common.record.CompressionType; +import org.apache.kafka.common.record.MemoryRecords; +import org.apache.kafka.common.record.Record; +import org.apache.kafka.common.record.RecordBatch; +import org.apache.kafka.common.record.SimpleRecord; +import org.apache.kafka.common.record.UnalignedFileRecords; +import org.apache.kafka.common.record.UnalignedMemoryRecords; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.raft.OffsetAndEpoch; +import org.apache.kafka.test.TestUtils; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.Arrays; +import java.util.Iterator; +import java.util.Optional; +import java.util.stream.IntStream; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public final class FileRawSnapshotTest { + private Path tempDir = null; + + @BeforeEach + public void setUp() { + tempDir = TestUtils.tempDirectory().toPath(); + } + + @AfterEach + public void tearDown() throws IOException { + Utils.delete(tempDir.toFile()); + } + + @Test + public void testWritingSnapshot() throws IOException { + OffsetAndEpoch offsetAndEpoch = new OffsetAndEpoch(10L, 3); + int bufferSize = 256; + int numberOfBatches = 10; + int expectedSize = 0; + + try (FileRawSnapshotWriter snapshot = createSnapshotWriter(tempDir, offsetAndEpoch)) { + assertEquals(0, snapshot.sizeInBytes()); + + UnalignedMemoryRecords records = buildRecords(ByteBuffer.wrap(randomBytes(bufferSize))); + for (int i = 0; i < numberOfBatches; i++) { + snapshot.append(records); + expectedSize += records.sizeInBytes(); + } + + assertEquals(expectedSize, snapshot.sizeInBytes()); + + snapshot.freeze(); + } + + // File should exist and the size should be the sum of all the buffers + assertTrue(Files.exists(Snapshots.snapshotPath(tempDir, offsetAndEpoch))); + assertEquals(expectedSize, Files.size(Snapshots.snapshotPath(tempDir, offsetAndEpoch))); + } + + @Test + public void testWriteReadSnapshot() throws IOException { + OffsetAndEpoch offsetAndEpoch = new OffsetAndEpoch(10L, 3); + int bufferSize = 256; + int numberOfBatches = 10; + + ByteBuffer expectedBuffer = ByteBuffer.wrap(randomBytes(bufferSize)); + + try (FileRawSnapshotWriter snapshot = createSnapshotWriter(tempDir, offsetAndEpoch)) { + UnalignedMemoryRecords records = buildRecords(expectedBuffer); + for (int i = 0; i < numberOfBatches; i++) { + snapshot.append(records); + } + + snapshot.freeze(); + } + + try (FileRawSnapshotReader snapshot = FileRawSnapshotReader.open(tempDir, offsetAndEpoch)) { + int countBatches = 0; + int countRecords = 0; + + Iterator batches = Utils.covariantCast(snapshot.records().batchIterator()); + while (batches.hasNext()) { + RecordBatch batch = batches.next(); + countBatches += 1; + + Iterator records = batch.streamingIterator(new GrowableBufferSupplier()); + while (records.hasNext()) { + Record record = records.next(); + + countRecords += 1; + + assertFalse(record.hasKey()); + assertTrue(record.hasValue()); + assertEquals(bufferSize, record.value().remaining()); + assertEquals(expectedBuffer, record.value()); 
+ } + } + + assertEquals(numberOfBatches, countBatches); + assertEquals(numberOfBatches, countRecords); + } + } + + @Test + public void testPartialWriteReadSnapshot() throws IOException { + Path tempDir = TestUtils.tempDirectory().toPath(); + OffsetAndEpoch offsetAndEpoch = new OffsetAndEpoch(10L, 3); + + ByteBuffer records = buildRecords(ByteBuffer.wrap(Utils.utf8("foo"))).buffer(); + + ByteBuffer expectedBuffer = ByteBuffer.wrap(records.array()); + + ByteBuffer buffer1 = expectedBuffer.duplicate(); + buffer1.position(0); + buffer1.limit(expectedBuffer.limit() / 2); + ByteBuffer buffer2 = expectedBuffer.duplicate(); + buffer2.position(expectedBuffer.limit() / 2); + buffer2.limit(expectedBuffer.limit()); + + try (FileRawSnapshotWriter snapshot = createSnapshotWriter(tempDir, offsetAndEpoch)) { + snapshot.append(new UnalignedMemoryRecords(buffer1)); + snapshot.append(new UnalignedMemoryRecords(buffer2)); + snapshot.freeze(); + } + + try (FileRawSnapshotReader snapshot = FileRawSnapshotReader.open(tempDir, offsetAndEpoch)) { + int totalSize = Math.toIntExact(snapshot.sizeInBytes()); + assertEquals(expectedBuffer.remaining(), totalSize); + + UnalignedFileRecords record1 = (UnalignedFileRecords) snapshot.slice(0, totalSize / 2); + UnalignedFileRecords record2 = (UnalignedFileRecords) snapshot.slice(totalSize / 2, totalSize - totalSize / 2); + + assertEquals(buffer1, TestUtils.toBuffer(record1)); + assertEquals(buffer2, TestUtils.toBuffer(record2)); + + ByteBuffer readBuffer = ByteBuffer.allocate(record1.sizeInBytes() + record2.sizeInBytes()); + readBuffer.put(TestUtils.toBuffer(record1)); + readBuffer.put(TestUtils.toBuffer(record2)); + readBuffer.flip(); + assertEquals(expectedBuffer, readBuffer); + } + } + + @Test + public void testBatchWriteReadSnapshot() throws IOException { + OffsetAndEpoch offsetAndEpoch = new OffsetAndEpoch(10L, 3); + int bufferSize = 256; + int batchSize = 3; + int numberOfBatches = 10; + + try (FileRawSnapshotWriter snapshot = createSnapshotWriter(tempDir, offsetAndEpoch)) { + for (int i = 0; i < numberOfBatches; i++) { + ByteBuffer[] buffers = IntStream + .range(0, batchSize) + .mapToObj(ignore -> ByteBuffer.wrap(randomBytes(bufferSize))).toArray(ByteBuffer[]::new); + + snapshot.append(buildRecords(buffers)); + } + + snapshot.freeze(); + } + + try (FileRawSnapshotReader snapshot = FileRawSnapshotReader.open(tempDir, offsetAndEpoch)) { + int countBatches = 0; + int countRecords = 0; + + Iterator batches = Utils.covariantCast(snapshot.records().batchIterator()); + while (batches.hasNext()) { + RecordBatch batch = batches.next(); + countBatches += 1; + + Iterator records = batch.streamingIterator(new GrowableBufferSupplier()); + while (records.hasNext()) { + Record record = records.next(); + + countRecords += 1; + + assertFalse(record.hasKey()); + assertTrue(record.hasValue()); + assertEquals(bufferSize, record.value().remaining()); + } + } + + assertEquals(numberOfBatches, countBatches); + assertEquals(numberOfBatches * batchSize, countRecords); + } + } + + @Test + public void testBufferWriteReadSnapshot() throws IOException { + OffsetAndEpoch offsetAndEpoch = new OffsetAndEpoch(10L, 3); + int bufferSize = 256; + int batchSize = 3; + int numberOfBatches = 10; + int expectedSize = 0; + + try (FileRawSnapshotWriter snapshot = createSnapshotWriter(tempDir, offsetAndEpoch)) { + for (int i = 0; i < numberOfBatches; i++) { + ByteBuffer[] buffers = IntStream + .range(0, batchSize) + .mapToObj(ignore -> ByteBuffer.wrap(randomBytes(bufferSize))).toArray(ByteBuffer[]::new); + + 
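+                // Wrap this batch of random buffers into a single set of unaligned records, append it to the
+                // snapshot, and keep a running total of the expected size in bytes for the assertions below.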
UnalignedMemoryRecords records = buildRecords(buffers); + snapshot.append(records); + expectedSize += records.sizeInBytes(); + } + + assertEquals(expectedSize, snapshot.sizeInBytes()); + + snapshot.freeze(); + } + + // File should exist and the size should be the sum of all the buffers + assertTrue(Files.exists(Snapshots.snapshotPath(tempDir, offsetAndEpoch))); + assertEquals(expectedSize, Files.size(Snapshots.snapshotPath(tempDir, offsetAndEpoch))); + + try (FileRawSnapshotReader snapshot = FileRawSnapshotReader.open(tempDir, offsetAndEpoch)) { + int countBatches = 0; + int countRecords = 0; + + Iterator batches = Utils.covariantCast(snapshot.records().batchIterator()); + while (batches.hasNext()) { + RecordBatch batch = batches.next(); + countBatches += 1; + + Iterator records = batch.streamingIterator(new GrowableBufferSupplier()); + while (records.hasNext()) { + Record record = records.next(); + + countRecords += 1; + + assertFalse(record.hasKey()); + assertTrue(record.hasValue()); + assertEquals(bufferSize, record.value().remaining()); + } + } + + assertEquals(numberOfBatches, countBatches); + assertEquals(numberOfBatches * batchSize, countRecords); + } + } + + @Test + public void testAbortedSnapshot() throws IOException { + OffsetAndEpoch offsetAndEpoch = new OffsetAndEpoch(20L, 2); + int bufferSize = 256; + int numberOfBatches = 10; + + try (FileRawSnapshotWriter snapshot = createSnapshotWriter(tempDir, offsetAndEpoch)) { + UnalignedMemoryRecords records = buildRecords(ByteBuffer.wrap(randomBytes(bufferSize))); + for (int i = 0; i < numberOfBatches; i++) { + snapshot.append(records); + } + } + + // File should not exist since freeze was not called before + assertFalse(Files.exists(Snapshots.snapshotPath(tempDir, offsetAndEpoch))); + assertEquals(0, Files.list(Snapshots.snapshotDir(tempDir)).count()); + } + + @Test + public void testAppendToFrozenSnapshot() throws IOException { + OffsetAndEpoch offsetAndEpoch = new OffsetAndEpoch(10L, 3); + int bufferSize = 256; + int numberOfBatches = 10; + + try (FileRawSnapshotWriter snapshot = createSnapshotWriter(tempDir, offsetAndEpoch)) { + UnalignedMemoryRecords records = buildRecords(ByteBuffer.wrap(randomBytes(bufferSize))); + for (int i = 0; i < numberOfBatches; i++) { + snapshot.append(records); + } + + snapshot.freeze(); + + assertThrows(RuntimeException.class, () -> snapshot.append(records)); + } + + // File should exist and the size should be greater than the sum of all the buffers + assertTrue(Files.exists(Snapshots.snapshotPath(tempDir, offsetAndEpoch))); + assertTrue(Files.size(Snapshots.snapshotPath(tempDir, offsetAndEpoch)) > bufferSize * numberOfBatches); + } + + @Test + public void testCreateSnapshotWithSameId() throws IOException { + OffsetAndEpoch offsetAndEpoch = new OffsetAndEpoch(20L, 2); + int bufferSize = 256; + int numberOfBatches = 1; + + try (FileRawSnapshotWriter snapshot = createSnapshotWriter(tempDir, offsetAndEpoch)) { + UnalignedMemoryRecords records = buildRecords(ByteBuffer.wrap(randomBytes(bufferSize))); + for (int i = 0; i < numberOfBatches; i++) { + snapshot.append(records); + } + + snapshot.freeze(); + } + + // Create another snapshot with the same id + try (FileRawSnapshotWriter snapshot = createSnapshotWriter(tempDir, offsetAndEpoch)) { + UnalignedMemoryRecords records = buildRecords(ByteBuffer.wrap(randomBytes(bufferSize))); + for (int i = 0; i < numberOfBatches; i++) { + snapshot.append(records); + } + + snapshot.freeze(); + } + } + + private static byte[] randomBytes(int size) { + byte[] array = new 
byte[size]; + + TestUtils.SEEDED_RANDOM.nextBytes(array); + + return array; + } + + private static UnalignedMemoryRecords buildRecords(ByteBuffer... buffers) { + MemoryRecords records = MemoryRecords.withRecords( + CompressionType.NONE, + Arrays.stream(buffers).map(SimpleRecord::new).toArray(SimpleRecord[]::new) + ); + return new UnalignedMemoryRecords(records.buffer()); + } + + private static FileRawSnapshotWriter createSnapshotWriter( + Path dir, + OffsetAndEpoch snapshotId + ) throws IOException { + return FileRawSnapshotWriter.create(dir, snapshotId, Optional.empty()); + } +} diff --git a/raft/src/test/java/org/apache/kafka/snapshot/MockRawSnapshotReader.java b/raft/src/test/java/org/apache/kafka/snapshot/MockRawSnapshotReader.java new file mode 100644 index 0000000..0e0a35d --- /dev/null +++ b/raft/src/test/java/org/apache/kafka/snapshot/MockRawSnapshotReader.java @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.snapshot; + +import java.nio.ByteBuffer; +import org.apache.kafka.common.record.MemoryRecords; +import org.apache.kafka.common.record.Records; +import org.apache.kafka.common.record.UnalignedMemoryRecords; +import org.apache.kafka.common.record.UnalignedRecords; +import org.apache.kafka.raft.OffsetAndEpoch; + +public final class MockRawSnapshotReader implements RawSnapshotReader { + private final OffsetAndEpoch snapshotId; + private final MemoryRecords data; + + public MockRawSnapshotReader(OffsetAndEpoch snapshotId, ByteBuffer data) { + this.snapshotId = snapshotId; + this.data = MemoryRecords.readableRecords(data); + } + + @Override + public OffsetAndEpoch snapshotId() { + return snapshotId; + } + + @Override + public long sizeInBytes() { + return data.sizeInBytes(); + } + + @Override + public UnalignedRecords slice(long position, int size) { + ByteBuffer buffer = data.buffer(); + buffer.position(Math.toIntExact(position)); + buffer.limit(Math.min(buffer.limit(), Math.toIntExact(position + size))); + return new UnalignedMemoryRecords(buffer.slice()); + } + + @Override + public Records records() { + return data; + } + + @Override + public String toString() { + return String.format("MockRawSnapshotReader(snapshotId=%s, data=%s)", snapshotId, data); + } +} diff --git a/raft/src/test/java/org/apache/kafka/snapshot/MockRawSnapshotWriter.java b/raft/src/test/java/org/apache/kafka/snapshot/MockRawSnapshotWriter.java new file mode 100644 index 0000000..0b5cc66 --- /dev/null +++ b/raft/src/test/java/org/apache/kafka/snapshot/MockRawSnapshotWriter.java @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.snapshot; + +import java.nio.ByteBuffer; +import java.util.function.Consumer; +import org.apache.kafka.common.record.MemoryRecords; +import org.apache.kafka.common.record.UnalignedMemoryRecords; +import org.apache.kafka.common.utils.ByteBufferOutputStream; +import org.apache.kafka.raft.OffsetAndEpoch; + +public final class MockRawSnapshotWriter implements RawSnapshotWriter { + private final ByteBufferOutputStream data = new ByteBufferOutputStream(0); + private final OffsetAndEpoch snapshotId; + private final Consumer frozenHandler; + + private boolean frozen = false; + private boolean closed = false; + + public MockRawSnapshotWriter( + OffsetAndEpoch snapshotId, + Consumer frozenHandler + ) { + this.snapshotId = snapshotId; + this.frozenHandler = frozenHandler; + } + + @Override + public OffsetAndEpoch snapshotId() { + return snapshotId; + } + + @Override + public long sizeInBytes() { + ensureNotFrozenOrClosed(); + return data.position(); + } + + @Override + public void append(UnalignedMemoryRecords records) { + ensureNotFrozenOrClosed(); + data.write(records.buffer()); + } + + @Override + public void append(MemoryRecords records) { + ensureNotFrozenOrClosed(); + data.write(records.buffer()); + } + + @Override + public boolean isFrozen() { + return frozen; + } + + @Override + public void freeze() { + ensureNotFrozenOrClosed(); + + frozen = true; + ByteBuffer buffer = data.buffer(); + buffer.flip(); + + frozenHandler.accept(buffer); + } + + @Override + public void close() { + ensureOpen(); + closed = true; + } + + @Override + public String toString() { + return String.format("MockRawSnapshotWriter(snapshotId=%s, data=%s)", snapshotId, data.buffer()); + } + + private void ensureNotFrozenOrClosed() { + if (frozen) { + throw new IllegalStateException("Snapshot is already frozen " + snapshotId); + } + ensureOpen(); + } + + private void ensureOpen() { + if (closed) { + throw new IllegalStateException("Snapshot is already closed " + snapshotId); + } + } +} diff --git a/raft/src/test/java/org/apache/kafka/snapshot/SnapshotWriterReaderTest.java b/raft/src/test/java/org/apache/kafka/snapshot/SnapshotWriterReaderTest.java new file mode 100644 index 0000000..c0f80c5 --- /dev/null +++ b/raft/src/test/java/org/apache/kafka/snapshot/SnapshotWriterReaderTest.java @@ -0,0 +1,267 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.snapshot; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Optional; +import java.util.OptionalInt; +import java.util.Random; +import java.util.Iterator; +import java.util.Set; +import org.apache.kafka.common.message.SnapshotFooterRecord; +import org.apache.kafka.common.message.SnapshotHeaderRecord; +import org.apache.kafka.common.utils.BufferSupplier; +import org.apache.kafka.raft.Batch; +import org.apache.kafka.raft.OffsetAndEpoch; +import org.apache.kafka.raft.RaftClientTestContext; +import org.apache.kafka.raft.internals.StringSerde; +import org.apache.kafka.common.utils.BufferSupplier.GrowableBufferSupplier; +import org.apache.kafka.common.record.ControlRecordUtils; +import org.apache.kafka.common.record.Record; +import org.apache.kafka.common.record.RecordBatch; +import org.apache.kafka.common.utils.Utils; + +import org.junit.jupiter.api.Test; +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +final public class SnapshotWriterReaderTest { + private final int localId = 0; + private final Set voters = Collections.singleton(localId); + + @Test + public void testSnapshotDelimiters() throws Exception { + int recordsPerBatch = 1; + int batches = 0; + int delimiterCount = 2; + long magicTimestamp = 0xDEADBEEF; + OffsetAndEpoch id = new OffsetAndEpoch(recordsPerBatch * batches, 3); + + RaftClientTestContext.Builder contextBuilder = new RaftClientTestContext.Builder(localId, voters); + RaftClientTestContext context = contextBuilder.build(); + + context.pollUntil(() -> context.currentLeader().equals(OptionalInt.of(localId))); + context.advanceLocalLeaderHighWatermarkToLogEndOffset(); + + // Create an empty snapshot and freeze it immediately + try (SnapshotWriter snapshot = context.client.createSnapshot(id.offset - 1, id.epoch, magicTimestamp).get()) { + assertEquals(id, snapshot.snapshotId()); + snapshot.freeze(); + } + + // Verify that an empty snapshot has only the Header and Footer + try (SnapshotReader reader = readSnapshot(context, id, Integer.MAX_VALUE)) { + RawSnapshotReader snapshot = context.log.readSnapshot(id).get(); + int recordCount = validateDelimiters(snapshot, magicTimestamp); + assertEquals((recordsPerBatch * batches) + delimiterCount, recordCount); + } + } + + @Test + public void testWritingSnapshot() throws Exception { + int recordsPerBatch = 3; + int batches = 3; + int delimiterCount = 2; + long magicTimestamp = 0xDEADBEEF; + OffsetAndEpoch id = new OffsetAndEpoch(recordsPerBatch * batches, 3); + List> expected = buildRecords(recordsPerBatch, batches); + + RaftClientTestContext.Builder contextBuilder = new RaftClientTestContext.Builder(localId, voters); + for (List batch : expected) { + contextBuilder.appendToLog(id.epoch, batch); + } + 
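+        // Build the test context, wait for this node to be elected leader, and advance the local
+        // high watermark to the log end offset before creating the snapshot below.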
RaftClientTestContext context = contextBuilder.build(); + + context.pollUntil(() -> context.currentLeader().equals(OptionalInt.of(localId))); + int epoch = context.currentEpoch(); + + context.advanceLocalLeaderHighWatermarkToLogEndOffset(); + + try (SnapshotWriter snapshot = context.client.createSnapshot(id.offset - 1, id.epoch, magicTimestamp).get()) { + assertEquals(id, snapshot.snapshotId()); + expected.forEach(batch -> assertDoesNotThrow(() -> snapshot.append(batch))); + snapshot.freeze(); + } + + try (SnapshotReader reader = readSnapshot(context, id, Integer.MAX_VALUE)) { + RawSnapshotReader snapshot = context.log.readSnapshot(id).get(); + int recordCount = validateDelimiters(snapshot, magicTimestamp); + assertEquals((recordsPerBatch * batches) + delimiterCount, recordCount); + assertSnapshot(expected, reader); + } + } + + @Test + public void testAbortedSnapshot() throws Exception { + int recordsPerBatch = 3; + int batches = 3; + OffsetAndEpoch id = new OffsetAndEpoch(recordsPerBatch * batches, 3); + List> expected = buildRecords(recordsPerBatch, batches); + + RaftClientTestContext.Builder contextBuilder = new RaftClientTestContext.Builder(localId, voters); + for (List batch : expected) { + contextBuilder.appendToLog(id.epoch, batch); + } + RaftClientTestContext context = contextBuilder.build(); + + context.pollUntil(() -> context.currentLeader().equals(OptionalInt.of(localId))); + int epoch = context.currentEpoch(); + + context.advanceLocalLeaderHighWatermarkToLogEndOffset(); + + try (SnapshotWriter snapshot = context.client.createSnapshot(id.offset - 1, id.epoch, 0).get()) { + assertEquals(id, snapshot.snapshotId()); + expected.forEach(batch -> { + assertDoesNotThrow(() -> snapshot.append(batch)); + }); + } + + assertEquals(Optional.empty(), context.log.readSnapshot(id)); + } + + @Test + public void testAppendToFrozenSnapshot() throws Exception { + int recordsPerBatch = 3; + int batches = 3; + OffsetAndEpoch id = new OffsetAndEpoch(recordsPerBatch * batches, 3); + List> expected = buildRecords(recordsPerBatch, batches); + + RaftClientTestContext.Builder contextBuilder = new RaftClientTestContext.Builder(localId, voters); + for (List batch : expected) { + contextBuilder.appendToLog(id.epoch, batch); + } + RaftClientTestContext context = contextBuilder.build(); + + context.pollUntil(() -> context.currentLeader().equals(OptionalInt.of(localId))); + int epoch = context.currentEpoch(); + + context.advanceLocalLeaderHighWatermarkToLogEndOffset(); + + try (SnapshotWriter snapshot = context.client.createSnapshot(id.offset - 1, id.epoch, 0).get()) { + assertEquals(id, snapshot.snapshotId()); + expected.forEach(batch -> { + assertDoesNotThrow(() -> snapshot.append(batch)); + }); + + snapshot.freeze(); + + assertThrows(RuntimeException.class, () -> snapshot.append(expected.get(0))); + } + } + + private List> buildRecords(int recordsPerBatch, int batches) { + Random random = new Random(0); + List> result = new ArrayList<>(batches); + for (int i = 0; i < batches; i++) { + List batch = new ArrayList<>(recordsPerBatch); + for (int j = 0; j < recordsPerBatch; j++) { + batch.add(String.valueOf(random.nextInt())); + } + result.add(batch); + } + + return result; + } + + private SnapshotReader readSnapshot( + RaftClientTestContext context, + OffsetAndEpoch snapshotId, + int maxBatchSize + ) { + return SnapshotReader.of( + context.log.readSnapshot(snapshotId).get(), + context.serde, + BufferSupplier.create(), + maxBatchSize + ); + } + + private int validateDelimiters( + RawSnapshotReader snapshot, + 
long lastContainedLogTime + ) { + assertNotEquals(0, snapshot.sizeInBytes()); + + int countRecords = 0; + + Iterator recordBatches = Utils.covariantCast(snapshot.records().batchIterator()); + + assertTrue(recordBatches.hasNext()); + RecordBatch batch = recordBatches.next(); + + Iterator records = batch.streamingIterator(new GrowableBufferSupplier()); + + // Verify existence of the header record + assertTrue(batch.isControlBatch()); + assertTrue(records.hasNext()); + Record record = records.next(); + countRecords += 1; + + SnapshotHeaderRecord headerRecord = ControlRecordUtils.deserializedSnapshotHeaderRecord(record); + assertEquals(headerRecord.version(), ControlRecordUtils.SNAPSHOT_HEADER_HIGHEST_VERSION); + assertEquals(headerRecord.lastContainedLogTimestamp(), lastContainedLogTime); + + assertFalse(records.hasNext()); + + // Loop over remaining records + while (recordBatches.hasNext()) { + batch = recordBatches.next(); + records = batch.streamingIterator(new GrowableBufferSupplier()); + + while (records.hasNext()) { + countRecords += 1; + record = records.next(); + } + } + + // Verify existence of the footer record in the end + assertTrue(batch.isControlBatch()); + + SnapshotFooterRecord footerRecord = ControlRecordUtils.deserializedSnapshotFooterRecord(record); + assertEquals(footerRecord.version(), ControlRecordUtils.SNAPSHOT_HEADER_HIGHEST_VERSION); + + return countRecords; + } + + public static void assertSnapshot(List> batches, RawSnapshotReader reader) { + assertSnapshot( + batches, + SnapshotReader.of(reader, new StringSerde(), BufferSupplier.create(), Integer.MAX_VALUE) + ); + } + + public static void assertSnapshot(List> batches, SnapshotReader reader) { + List expected = new ArrayList<>(); + batches.forEach(expected::addAll); + + List actual = new ArrayList<>(expected.size()); + while (reader.hasNext()) { + Batch batch = reader.next(); + for (String value : batch) { + actual.add(value); + } + } + + assertEquals(expected, actual); + } +} diff --git a/raft/src/test/java/org/apache/kafka/snapshot/SnapshotsTest.java b/raft/src/test/java/org/apache/kafka/snapshot/SnapshotsTest.java new file mode 100644 index 0000000..ae89543 --- /dev/null +++ b/raft/src/test/java/org/apache/kafka/snapshot/SnapshotsTest.java @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.snapshot; + +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.raft.OffsetAndEpoch; +import org.apache.kafka.test.TestUtils; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; + +import java.io.IOException; +import java.nio.file.FileSystems; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.Optional; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +final public class SnapshotsTest { + + @Test + public void testValidSnapshotFilename() { + OffsetAndEpoch snapshotId = new OffsetAndEpoch( + TestUtils.RANDOM.nextInt(Integer.MAX_VALUE), + TestUtils.RANDOM.nextInt(Integer.MAX_VALUE) + ); + Path path = Snapshots.snapshotPath(TestUtils.tempDirectory().toPath(), snapshotId); + SnapshotPath snapshotPath = Snapshots.parse(path).get(); + + assertEquals(path, snapshotPath.path); + assertEquals(snapshotId, snapshotPath.snapshotId); + assertFalse(snapshotPath.partial); + assertFalse(snapshotPath.deleted); + } + + @Test + public void testValidPartialSnapshotFilename() throws IOException { + OffsetAndEpoch snapshotId = new OffsetAndEpoch( + TestUtils.RANDOM.nextInt(Integer.MAX_VALUE), + TestUtils.RANDOM.nextInt(Integer.MAX_VALUE) + ); + + Path path = Snapshots.createTempFile(TestUtils.tempDirectory().toPath(), snapshotId); + // Delete it as we only need the path for testing + Files.delete(path); + + SnapshotPath snapshotPath = Snapshots.parse(path).get(); + + assertEquals(path, snapshotPath.path); + assertEquals(snapshotId, snapshotPath.snapshotId); + assertTrue(snapshotPath.partial); + } + + @Test + public void testValidDeletedSnapshotFilename() { + OffsetAndEpoch snapshotId = new OffsetAndEpoch( + TestUtils.RANDOM.nextInt(Integer.MAX_VALUE), + TestUtils.RANDOM.nextInt(Integer.MAX_VALUE) + ); + Path path = Snapshots.snapshotPath(TestUtils.tempDirectory().toPath(), snapshotId); + Path deletedPath = Snapshots.deleteRename(path, snapshotId); + SnapshotPath snapshotPath = Snapshots.parse(deletedPath).get(); + + assertEquals(snapshotId, snapshotPath.snapshotId); + assertTrue(snapshotPath.deleted); + } + + @Test + public void testInvalidSnapshotFilenames() { + Path root = FileSystems.getDefault().getPath("/"); + // Doesn't parse log files + assertEquals(Optional.empty(), Snapshots.parse(root.resolve("00000000000000000000.log"))); + // Doesn't parse producer snapshots + assertEquals(Optional.empty(), Snapshots.parse(root.resolve("00000000000000000000.snapshot"))); + // Doesn't parse offset indexes + assertEquals(Optional.empty(), Snapshots.parse(root.resolve("00000000000000000000.index"))); + assertEquals(Optional.empty(), Snapshots.parse(root.resolve("00000000000000000000.timeindex"))); + // Leader epoch checkpoint + assertEquals(Optional.empty(), Snapshots.parse(root.resolve("leader-epoch-checkpoint"))); + // partition metadata + assertEquals(Optional.empty(), Snapshots.parse(root.resolve("partition.metadata"))); + } + + @ParameterizedTest + @ValueSource(booleans = {true, false}) + public void testDeleteSnapshot(boolean renameBeforeDeleting) throws IOException { + + OffsetAndEpoch snapshotId = new OffsetAndEpoch( + TestUtils.RANDOM.nextInt(Integer.MAX_VALUE), + TestUtils.RANDOM.nextInt(Integer.MAX_VALUE) + ); + + Path logDirPath = TestUtils.tempDirectory().toPath(); + try (FileRawSnapshotWriter snapshot = 
FileRawSnapshotWriter.create(logDirPath, snapshotId, Optional.empty())) {
+            snapshot.freeze();
+
+            Path snapshotPath = Snapshots.snapshotPath(logDirPath, snapshotId);
+            assertTrue(Files.exists(snapshotPath));
+
+            if (renameBeforeDeleting)
+                // rename snapshot before deleting
+                Utils.atomicMoveWithFallback(snapshotPath, Snapshots.deleteRename(snapshotPath, snapshotId), false);
+
+            assertTrue(Snapshots.deleteIfExists(logDirPath, snapshot.snapshotId()));
+            assertFalse(Files.exists(snapshotPath));
+            assertFalse(Files.exists(Snapshots.deleteRename(snapshotPath, snapshotId)));
+        }
+    }
+}
diff --git a/raft/src/test/resources/log4j.properties b/raft/src/test/resources/log4j.properties
new file mode 100644
index 0000000..6d90f6d
--- /dev/null
+++ b/raft/src/test/resources/log4j.properties
@@ -0,0 +1,22 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+log4j.rootLogger=OFF, stdout
+
+log4j.appender.stdout=org.apache.log4j.ConsoleAppender
+log4j.appender.stdout.layout=org.apache.log4j.PatternLayout
+log4j.appender.stdout.layout.ConversionPattern=[%d] %p %m (%c:%L)%n
+
+log4j.logger.org.apache.kafka.raft=ERROR
+log4j.logger.org.apache.kafka.snapshot=ERROR
diff --git a/release.py b/release.py
new file mode 100755
index 0000000..c690d18
--- /dev/null
+++ b/release.py
@@ -0,0 +1,760 @@
+#!/usr/bin/env python
+
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""
+Utility for creating release candidates and promoting release candidates to a final release.
+
+Usage: release.py [subcommand]
+
+release.py stage
+
+  Builds and stages an RC for a release.
+
+  The utility is interactive; you will be prompted for basic release information and guided through the process.
+
+  This utility assumes you already have a local kafka git folder and that you
+  have added remotes corresponding to both:
+  (i) the github apache kafka mirror and
+  (ii) the apache kafka git repo.
+
+release.py stage-docs [kafka-site-path]
+
+  Builds the documentation and stages it into an instance of the Kafka website repository.
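+
+  For example, if the kafka-site repository is checked out side by side with this one, no extra argument is needed:
+
+      ./release.py stage-docs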
+ + This is meant to automate the integration between the main Kafka website repository (https://github.com/apache/kafka-site) + and the versioned documentation maintained in the main Kafka repository. This is useful both for local testing and + development of docs (follow the instructions here: https://cwiki.apache.org/confluence/display/KAFKA/Setup+Kafka+Website+on+Local+Apache+Server) + as well as for committers to deploy docs (run this script, then validate, commit, and push to kafka-site). + + With no arguments this script assumes you have the Kafka repository and kafka-site repository checked out side-by-side, but + you can specify a full path to the kafka-site repository if this is not the case. + +release.py release-email + + Generates the email content/template for sending release announcement email. + +""" + +import datetime +from getpass import getpass +import json +import os +import subprocess +import sys +import tempfile +import time +import re + +PROJECT_NAME = "kafka" +CAPITALIZED_PROJECT_NAME = "kafka".upper() +SCRIPT_DIR = os.path.abspath(os.path.dirname(__file__)) +# Location of the local git repository +REPO_HOME = os.environ.get("%s_HOME" % CAPITALIZED_PROJECT_NAME, SCRIPT_DIR) +# Remote name, which points to Github by default +PUSH_REMOTE_NAME = os.environ.get("PUSH_REMOTE_NAME", "apache-github") +PREFS_FILE = os.path.join(SCRIPT_DIR, '.release-settings.json') +PUBLIC_HTML = "public_html" + +delete_gitrefs = False +work_dir = None + +def fail(msg): + if work_dir: + cmd("Cleaning up work directory", "rm -rf %s" % work_dir) + + if delete_gitrefs: + try: + cmd("Resetting repository working state to branch %s" % starting_branch, "git reset --hard HEAD && git checkout %s" % starting_branch, shell=True) + cmd("Deleting git branches %s" % release_version, "git branch -D %s" % release_version, shell=True) + cmd("Deleting git tag %s" %rc_tag , "git tag -d %s" % rc_tag, shell=True) + except subprocess.CalledProcessError: + print("Failed when trying to clean up git references added by this script. You may need to clean up branches/tags yourself before retrying.") + print("Expected git branch: " + release_version) + print("Expected git tag: " + rc_tag) + print(msg) + sys.exit(1) + +def print_output(output): + if output is None or len(output) == 0: + return + for line in output.split('\n'): + print(">", line) + +def cmd(action, cmd_arg, *args, **kwargs): + if isinstance(cmd_arg, str) and not kwargs.get("shell", False): + cmd_arg = cmd_arg.split() + allow_failure = kwargs.pop("allow_failure", False) + num_retries = kwargs.pop("num_retries", 0) + + stdin_log = "" + if "stdin" in kwargs and isinstance(kwargs["stdin"], str): + stdin_log = "--> " + kwargs["stdin"] + stdin = tempfile.TemporaryFile() + stdin.write(kwargs["stdin"].encode('utf-8')) + stdin.seek(0) + kwargs["stdin"] = stdin + + print(action, cmd_arg, stdin_log) + try: + output = subprocess.check_output(cmd_arg, *args, stderr=subprocess.STDOUT, **kwargs) + print_output(output.decode('utf-8')) + except subprocess.CalledProcessError as e: + print_output(e.output.decode('utf-8')) + + if num_retries > 0: + kwargs['num_retries'] = num_retries - 1 + kwargs['allow_failure'] = allow_failure + print("Retrying... %d remaining retries" % (num_retries - 1)) + time.sleep(4. / (num_retries + 1)) # e.g., if retries=3, sleep for 1s, 1.3s, 2s + return cmd(action, cmd_arg, *args, **kwargs) + + if allow_failure: + return + + print("*************************************************") + print("*** First command failure occurred here. 
***") + print("*** Will now try to clean up working state. ***") + print("*************************************************") + fail("") + + +def cmd_output(cmd, *args, **kwargs): + if isinstance(cmd, str): + cmd = cmd.split() + return subprocess.check_output(cmd, *args, stderr=subprocess.STDOUT, **kwargs).decode('utf-8') + +def replace(path, pattern, replacement): + updated = [] + with open(path, 'r') as f: + for line in f: + updated.append((replacement + '\n') if line.startswith(pattern) else line) + + with open(path, 'w') as f: + for line in updated: + f.write(line) + +def regexReplace(path, pattern, replacement): + updated = [] + with open(path, 'r') as f: + for line in f: + updated.append(re.sub(pattern, replacement, line)) + + with open(path, 'w') as f: + for line in updated: + f.write(line) + +def user_ok(msg): + ok = input(msg) + return ok.strip().lower() == 'y' + +def sftp_mkdir(dir): + try: + cmd_str = """ +mkdir %s +""" % dir + cmd("Creating '%s' in your Apache home directory if it does not exist (errors are ok if the directory already exists)" % dir, "sftp -b - %s@home.apache.org" % apache_id, stdin=cmd_str, allow_failure=True, num_retries=3) + except subprocess.CalledProcessError: + # This is ok. The command fails if the directory already exists + pass + +def sftp_upload(dir): + try: + cmd_str = """ +cd %s +put -r %s +""" % (PUBLIC_HTML, dir) + cmd("Uploading '%s' under %s in your Apache home directory" % (dir, PUBLIC_HTML), "sftp -b - %s@home.apache.org" % apache_id, stdin=cmd_str, allow_failure=True, num_retries=3) + except subprocess.CalledProcessError: + fail("Failed uploading %s to your Apache home directory" % dir) + +def get_pref(prefs, name, request_fn): + "Get a preference from existing preference dictionary or invoke a function that can collect it from the user" + val = prefs.get(name) + if not val: + val = request_fn() + prefs[name] = val + return val + +def load_prefs(): + """Load saved preferences""" + prefs = {} + if os.path.exists(PREFS_FILE): + with open(PREFS_FILE, 'r') as prefs_fp: + prefs = json.load(prefs_fp) + return prefs + +def save_prefs(prefs): + """Save preferences""" + print("Saving preferences to %s" % PREFS_FILE) + with open(PREFS_FILE, 'w') as prefs_fp: + prefs = json.dump(prefs, prefs_fp) + +def get_jdk(prefs, version): + """ + Get settings for the specified JDK version. + """ + jdk_java_home = get_pref(prefs, 'jdk%d' % version, lambda: input("Enter the path for JAVA_HOME for a JDK%d compiler (blank to use default JAVA_HOME): " % version)) + jdk_env = dict(os.environ) if jdk_java_home.strip() else None + if jdk_env is not None: jdk_env['JAVA_HOME'] = jdk_java_home + java_version = cmd_output("%s/bin/java -version" % jdk_java_home, env=jdk_env) + if version == 8: + if "1.8.0" not in java_version: + fail("JDK 8 is required") + elif "%d.0" % version not in java_version and '"%d"' % version not in java_version: + fail("JDK %s is required" % version) + return jdk_env + +def get_version(repo=REPO_HOME): + """ + Extracts the full version information as a str from gradle.properties + """ + with open(os.path.join(repo, 'gradle.properties')) as fp: + for line in fp: + parts = line.split('=') + if parts[0].strip() != 'version': continue + return parts[1].strip() + fail("Couldn't extract version from gradle.properties") + +def docs_version(version): + """ + Detects the major/minor version and converts it to the format used for docs on the website, e.g. 
gets 0.10.2.0-SNAPSHOT + from gradle.properties and converts it to 0102 + """ + version_parts = version.strip().split('.') + # 1.0+ will only have 3 version components as opposed to pre-1.0 that had 4 + major_minor = version_parts[0:3] if version_parts[0] == '0' else version_parts[0:2] + return ''.join(major_minor) + +def docs_release_version(version): + """ + Detects the version from gradle.properties and converts it to a release version number that should be valid for the + current release branch. For example, 0.10.2.0-SNAPSHOT would remain 0.10.2.0-SNAPSHOT (because no release has been + made on that branch yet); 0.10.2.1-SNAPSHOT would be converted to 0.10.2.0 because 0.10.2.1 is still in development + but 0.10.2.0 should have already been released. Regular version numbers (e.g. as encountered on a release branch) + will remain the same. + """ + version_parts = version.strip().split('.') + if '-SNAPSHOT' in version_parts[-1]: + bugfix = int(version_parts[-1].split('-')[0]) + if bugfix > 0: + version_parts[-1] = str(bugfix - 1) + return '.'.join(version_parts) + +def command_stage_docs(): + kafka_site_repo_path = sys.argv[2] if len(sys.argv) > 2 else os.path.join(REPO_HOME, '..', 'kafka-site') + if not os.path.exists(kafka_site_repo_path) or not os.path.exists(os.path.join(kafka_site_repo_path, 'powered-by.html')): + sys.exit("%s doesn't exist or does not appear to be the kafka-site repository" % kafka_site_repo_path) + + prefs = load_prefs() + jdk17_env = get_jdk(prefs, 17) + save_prefs(prefs) + + version = get_version() + # We explicitly override the version of the project that we normally get from gradle.properties since we want to be + # able to run this from a release branch where we made some updates, but the build would show an incorrect SNAPSHOT + # version due to already having bumped the bugfix version number. + gradle_version_override = docs_release_version(version) + + cmd("Building docs", "./gradlew -Pversion=%s clean siteDocsTar aggregatedJavadoc" % gradle_version_override, cwd=REPO_HOME, env=jdk17_env) + + docs_tar = os.path.join(REPO_HOME, 'core', 'build', 'distributions', 'kafka_2.13-%s-site-docs.tgz' % gradle_version_override) + + versioned_docs_path = os.path.join(kafka_site_repo_path, docs_version(version)) + if not os.path.exists(versioned_docs_path): + os.mkdir(versioned_docs_path, 755) + + # The contents of the docs jar are site-docs/. 
We need to get rid of the site-docs prefix and dump everything + # inside it into the docs version subdirectory in the kafka-site repo + cmd('Extracting site-docs', 'tar xf %s --strip-components 1' % docs_tar, cwd=versioned_docs_path) + + javadocs_src_dir = os.path.join(REPO_HOME, 'build', 'docs', 'javadoc') + + cmd('Copying javadocs', 'cp -R %s %s' % (javadocs_src_dir, versioned_docs_path)) + + sys.exit(0) + +def validate_release_version_parts(version): + try: + version_parts = version.split('.') + if len(version_parts) != 3: + fail("Invalid release version, should have 3 version number components") + # Validate each part is a number + [int(x) for x in version_parts] + except ValueError: + fail("Invalid release version, should be a dotted version number") + +def get_release_version_parts(version): + validate_release_version_parts(version) + return version.split('.') + +def validate_release_num(version): + tags = cmd_output('git tag').split() + if version not in tags: + fail("The specified version is not a valid release version number") + validate_release_version_parts(version) + +def command_release_announcement_email(): + tags = cmd_output('git tag').split() + release_tag_pattern = re.compile('^[0-9]+\.[0-9]+\.[0-9]+$') + release_tags = sorted([t for t in tags if re.match(release_tag_pattern, t)]) + release_version_num = release_tags[-1] + if not user_ok("""Is the current release %s ? (y/n): """ % release_version_num): + release_version_num = input('What is the current release version:') + validate_release_num(release_version_num) + previous_release_version_num = release_tags[-2] + if not user_ok("""Is the previous release %s ? (y/n): """ % previous_release_version_num): + previous_release_version_num = input('What is the previous release version:') + validate_release_num(previous_release_version_num) + if release_version_num < previous_release_version_num : + fail("Current release version number can't be less than previous release version number") + number_of_contributors = int(subprocess.check_output('git shortlog -sn --no-merges %s..%s | wc -l' % (previous_release_version_num, release_version_num) , shell=True).decode('utf-8')) + contributors = subprocess.check_output("git shortlog -sn --no-merges %s..%s | cut -f2 | sort --ignore-case" % (previous_release_version_num, release_version_num), shell=True).decode('utf-8') + release_announcement_data = { + 'number_of_contributors': number_of_contributors, + 'contributors': ', '.join(str(x) for x in filter(None, contributors.split('\n'))), + 'release_version': release_version_num + } + + release_announcement_email = """ +To: announce@apache.org, dev@kafka.apache.org, users@kafka.apache.org, kafka-clients@googlegroups.com +Subject: [ANNOUNCE] Apache Kafka %(release_version)s + +The Apache Kafka community is pleased to announce the release for Apache Kafka %(release_version)s + +
+<Describe major changes in this release>
+
+All of the changes in this release can be found in the release notes:
+https://www.apache.org/dist/kafka/%(release_version)s/RELEASE_NOTES.html
+
+
+You can download the source and binary release (Scala <versions>) from:
+https://kafka.apache.org/downloads#%(release_version)s
+
+---------------------------------------------------------------------------------------------------
+
+
+Apache Kafka is a distributed streaming platform with four core APIs:
+
+
+** The Producer API allows an application to publish a stream of records to
+one or more Kafka topics.
+
+** The Consumer API allows an application to subscribe to one or more
+topics and process the stream of records produced to them.
+
+** The Streams API allows an application to act as a stream processor,
+consuming an input stream from one or more topics and producing an
+output stream to one or more output topics, effectively transforming the
+input streams to output streams.
+
+** The Connector API allows building and running reusable producers or
+consumers that connect Kafka topics to existing applications or data
+systems. For example, a connector to a relational database might
+capture every change to a table.
+
+
+With these APIs, Kafka can be used for two broad classes of application:
+
+** Building real-time streaming data pipelines that reliably get data
+between systems or applications.
+
+** Building real-time streaming applications that transform or react
+to the streams of data.
+
+
+Apache Kafka is in use at large and small companies worldwide, including
+Capital One, Goldman Sachs, ING, LinkedIn, Netflix, Pinterest, Rabobank,
+Target, The New York Times, Uber, Yelp, and Zalando, among others.
+
+A big thank you for the following %(number_of_contributors)d contributors to this release!
+
+%(contributors)s
+
+We welcome your help and feedback. For more information on how to
+report problems, and to get involved, visit the project website at
+https://kafka.apache.org/
+
+Thank you!
+
+
+Regards,
+
+""" % release_announcement_data
+
+    print()
+    print("*****************************************************************")
+    print()
+    print(release_announcement_email)
+    print()
+    print("*****************************************************************")
+    print()
+    print("Use the above template to send the announcement for the release to the mailing list.")
+    print("IMPORTANT: Note that there are still some substitutions that need to be made in the template:")
+    print(" - Describe major changes in this release")
+    print(" - Scala versions")
+    print(" - Fill in your name in the signature")
+    print(" - You will need to use your apache email address to send out the email (otherwise, it won't be delivered to announce@apache.org)")
+    print(" - Finally, validate all the links before shipping!")
+    print("Note that all substitutions are annotated with <> around them.")
+    sys.exit(0)
+
+
+# Dispatch to subcommand
+subcommand = sys.argv[1] if len(sys.argv) > 1 else None
+if subcommand == 'stage-docs':
+    command_stage_docs()
+elif subcommand == 'release-email':
+    command_release_announcement_email()
+elif not (subcommand is None or subcommand == 'stage'):
+    fail("Unknown subcommand: %s" % subcommand)
+# else -> default subcommand stage
+
+
+## Default 'stage' subcommand implementation isn't isolated to its own function yet for historical reasons
+
+prefs = load_prefs()
+
+if not user_ok("""Requirements:
+1. Updated docs to reference the new release version where appropriate.
+2. JDK8 and JDK17 compilers and libraries
+3. 
Your Apache ID, already configured with SSH keys on id.apache.org and SSH keys available in this shell session +4. All issues in the target release resolved with valid resolutions (if not, this script will report the problematic JIRAs) +5. A GPG key used for signing the release. This key should have been added to public Apache servers and the KEYS file on the Kafka site +6. Standard toolset installed -- git, gpg, gradle, sftp, etc. +7. ~/.gradle/gradle.properties configured with the signing properties described in the release process wiki, i.e. + + mavenUrl=https://repository.apache.org/service/local/staging/deploy/maven2 + mavenUsername=your-apache-id + mavenPassword=your-apache-passwd + signing.keyId=your-gpgkeyId + signing.password=your-gpg-passphrase + signing.secretKeyRingFile=/Users/your-id/.gnupg/secring.gpg (if you are using GPG 2.1 and beyond, then this file will no longer exist anymore, and you have to manually create it from the new private key directory with "gpg --export-secret-keys -o ~/.gnupg/secring.gpg") +8. ~/.m2/settings.xml configured for pgp signing and uploading to apache release maven, i.e., + + apache.releases.https + your-apache-id + your-apache-passwd + + + your-gpgkeyId + your-gpg-passphrase + + + gpg-signing + + your-gpgkeyId + your-gpgkeyId + + +9. You may also need to update some gnupgp configs: + ~/.gnupg/gpg-agent.conf + allow-loopback-pinentry + + ~/.gnupg/gpg.conf + use-agent + pinentry-mode loopback + + echo RELOADAGENT | gpg-connect-agent + +If any of these are missing, see https://cwiki.apache.org/confluence/display/KAFKA/Release+Process for instructions on setting them up. + +Some of these may be used from these previous settings loaded from %s: + +%s + +Do you have all of of these setup? (y/n): """ % (PREFS_FILE, json.dumps(prefs, indent=2))): + fail("Please try again once you have all the prerequisites ready.") + + +starting_branch = cmd_output('git rev-parse --abbrev-ref HEAD') + +cmd("Verifying that you have no unstaged git changes", 'git diff --exit-code --quiet') +cmd("Verifying that you have no staged git changes", 'git diff --cached --exit-code --quiet') + +release_version = input("Release version (without any RC info, e.g. 1.0.0): ") +release_version_parts = get_release_version_parts(release_version) + +rc = input("Release candidate number: ") + +dev_branch = '.'.join(release_version_parts[:2]) +docs_release_version = docs_version(release_version) + +# Validate that the release doesn't already exist and that the +cmd("Fetching tags from upstream", 'git fetch --tags %s' % PUSH_REMOTE_NAME) +tags = cmd_output('git tag').split() + +if release_version in tags: + fail("The specified version has already been tagged and released.") + +# TODO promotion +if not rc: + fail("Automatic Promotion is not yet supported.") + + # Find the latest RC and make sure they want to promote that one + rc_tag = sorted([t for t in tags if t.startswith(release_version + '-rc')])[-1] + if not user_ok("Found %s as latest RC for this release. Is this correct? 
(y/n): "): + fail("This script couldn't determine which RC tag to promote, you'll need to fix up the RC tags and re-run the script.") + + sys.exit(0) + +# Prereq checks +apache_id = get_pref(prefs, 'apache_id', lambda: input("Enter your apache username: ")) + +jdk8_env = get_jdk(prefs, 8) +jdk17_env = get_jdk(prefs, 17) + +def select_gpg_key(): + print("Here are the available GPG keys:") + available_keys = cmd_output("gpg --list-secret-keys") + print(available_keys) + key_name = input("Which user name (enter the user name without email address): ") + if key_name not in available_keys: + fail("Couldn't find the requested key.") + return key_name + +key_name = get_pref(prefs, 'gpg-key', select_gpg_key) + +gpg_passphrase = get_pref(prefs, 'gpg-pass', lambda: getpass("Passphrase for this GPG key: ")) +# Do a quick validation so we can fail fast if the password is incorrect +with tempfile.NamedTemporaryFile() as gpg_test_tempfile: + gpg_test_tempfile.write("abcdefg".encode('utf-8')) + cmd("Testing GPG key & passphrase", ["gpg", "--batch", "--pinentry-mode", "loopback", "--passphrase-fd", "0", "-u", key_name, "--armor", "--output", gpg_test_tempfile.name + ".asc", "--detach-sig", gpg_test_tempfile.name], stdin=gpg_passphrase) + +save_prefs(prefs) + +# Generate RC +try: + int(rc) +except ValueError: + fail("Invalid release candidate number: %s" % rc) +rc_tag = release_version + '-rc' + rc + +delete_gitrefs = True # Since we are about to start creating new git refs, enable cleanup function on failure to try to delete them +cmd("Checking out current development branch", "git checkout -b %s %s" % (release_version, PUSH_REMOTE_NAME + "/" + dev_branch)) +print("Updating version numbers") +replace("gradle.properties", "version", "version=%s" % release_version) +replace("tests/kafkatest/__init__.py", "__version__", "__version__ = '%s'" % release_version) +print("updating streams quickstart pom") +regexReplace("streams/quickstart/pom.xml", "-SNAPSHOT", "") +print("updating streams quickstart java pom") +regexReplace("streams/quickstart/java/pom.xml", "-SNAPSHOT", "") +print("updating streams quickstart archetype pom") +regexReplace("streams/quickstart/java/src/main/resources/archetype-resources/pom.xml", "-SNAPSHOT", "") +print("updating ducktape version.py") +regexReplace("./tests/kafkatest/version.py", "^DEV_VERSION =.*", + "DEV_VERSION = KafkaVersion(\"%s-SNAPSHOT\")" % release_version) +# Command in explicit list due to messages with spaces +cmd("Committing version number updates", ["git", "commit", "-a", "-m", "Bump version to %s" % release_version]) +# Command in explicit list due to messages with spaces +cmd("Tagging release candidate %s" % rc_tag, ["git", "tag", "-a", rc_tag, "-m", rc_tag]) +rc_githash = cmd_output("git show-ref --hash " + rc_tag) +cmd("Switching back to your starting branch", "git checkout %s" % starting_branch) + +# Note that we don't use tempfile here because mkdtemp causes problems with sftp and being able to determine the absolute path to a file. +# Instead we rely on a fixed path and if it +work_dir = os.path.join(REPO_HOME, ".release_work_dir") +if os.path.exists(work_dir): + fail("A previous attempt at a release left dirty state in the work directory. Clean up %s before proceeding. 
(This attempt will try to cleanup, simply retrying may be sufficient now...)" % work_dir) +os.makedirs(work_dir) +print("Temporary build working director:", work_dir) +kafka_dir = os.path.join(work_dir, 'kafka') +streams_quickstart_dir = os.path.join(kafka_dir, 'streams/quickstart') +print("Streams quickstart dir", streams_quickstart_dir) +artifact_name = "kafka-" + rc_tag +cmd("Creating staging area for release artifacts", "mkdir " + artifact_name, cwd=work_dir) +artifacts_dir = os.path.join(work_dir, artifact_name) +cmd("Cloning clean copy of repo", "git clone %s kafka" % REPO_HOME, cwd=work_dir) +cmd("Checking out RC tag", "git checkout -b %s %s" % (release_version, rc_tag), cwd=kafka_dir) +current_year = datetime.datetime.now().year +cmd("Verifying the correct year in NOTICE", "grep %s NOTICE" % current_year, cwd=kafka_dir) + +with open(os.path.join(artifacts_dir, "RELEASE_NOTES.html"), 'w') as f: + print("Generating release notes") + try: + subprocess.check_call([sys.executable, "./release_notes.py", release_version], stdout=f) + except subprocess.CalledProcessError as e: + print_output(e.output) + + print("*************************************************") + print("*** First command failure occurred here. ***") + print("*** Will now try to clean up working state. ***") + print("*************************************************") + fail("") + + +params = { 'release_version': release_version, + 'rc_tag': rc_tag, + 'artifacts_dir': artifacts_dir + } +cmd("Creating source archive", "git archive --format tar.gz --prefix kafka-%(release_version)s-src/ -o %(artifacts_dir)s/kafka-%(release_version)s-src.tgz %(rc_tag)s" % params) + +cmd("Building artifacts", "./gradlew clean && ./gradlewAll releaseTarGz", cwd=kafka_dir, env=jdk8_env, shell=True) +cmd("Copying artifacts", "cp %s/core/build/distributions/* %s" % (kafka_dir, artifacts_dir), shell=True) +cmd("Building docs", "./gradlew clean aggregatedJavadoc", cwd=kafka_dir, env=jdk17_env) +cmd("Copying docs", "cp -R %s/build/docs/javadoc %s" % (kafka_dir, artifacts_dir)) + +for filename in os.listdir(artifacts_dir): + full_path = os.path.join(artifacts_dir, filename) + if not os.path.isfile(full_path): + continue + # Commands in explicit list due to key_name possibly containing spaces + cmd("Signing " + full_path, ["gpg", "--batch", "--passphrase-fd", "0", "-u", key_name, "--armor", "--output", full_path + ".asc", "--detach-sig", full_path], stdin=gpg_passphrase) + cmd("Verifying " + full_path, ["gpg", "--verify", full_path + ".asc", full_path]) + # Note that for verification, we need to make sure only the filename is used with --print-md because the command line + # argument for the file is included in the output and verification uses a simple diff that will break if an absolut path + # is used. 
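+ # For example (illustrative, with a placeholder version), a voter re-creates the digest from the download directory
+ # and diffs it against the uploaded checksum file, matching the validation steps printed later by this script:
+ #   gpg --print-md sha512 kafka-<version>-src.tgz | diff - kafka-<version>-src.tgz.sha512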
+ dir, fname = os.path.split(full_path) + cmd("Generating MD5 for " + full_path, "gpg --print-md md5 %s > %s.md5" % (fname, fname), shell=True, cwd=dir) + cmd("Generating SHA1 for " + full_path, "gpg --print-md sha1 %s > %s.sha1" % (fname, fname), shell=True, cwd=dir) + cmd("Generating SHA512 for " + full_path, "gpg --print-md sha512 %s > %s.sha512" % (fname, fname), shell=True, cwd=dir) + +cmd("Listing artifacts to be uploaded:", "ls -R %s" % artifacts_dir) + +cmd("Zipping artifacts", "tar -czf %s.tar.gz %s" % (artifact_name, artifact_name), cwd=work_dir) +sftp_mkdir(PUBLIC_HTML) +sftp_upload(artifacts_dir) +if not user_ok("Confirm the artifact is present under %s in your Apache home directory: https://home.apache.org/~%s/ (y/n)?: " % (PUBLIC_HTML, apache_id)): + fail("Ok, giving up") + +with open(os.path.expanduser("~/.gradle/gradle.properties")) as f: + contents = f.read() +if not user_ok("Going to build and upload mvn artifacts based on these settings:\n" + contents + '\nOK (y/n)?: '): + fail("Retry again later") +cmd("Building and uploading archives", "./gradlewAll publish", cwd=kafka_dir, env=jdk8_env, shell=True) +cmd("Building and uploading archives", "mvn deploy -Pgpg-signing", cwd=streams_quickstart_dir, env=jdk8_env, shell=True) + +release_notification_props = { 'release_version': release_version, + 'rc': rc, + 'rc_tag': rc_tag, + 'rc_githash': rc_githash, + 'dev_branch': dev_branch, + 'docs_version': docs_release_version, + 'apache_id': apache_id, + } + +# TODO: Many of these suggested validation steps could be automated and would help pre-validate a lot of the stuff voters test +print(""" +******************************************************************************************************************************************************* +Ok. We've built and staged everything for the %(rc_tag)s. + +Now you should sanity check it before proceeding. All subsequent steps start making RC data public. + +Some suggested steps: + + * Grab the source archive and make sure it compiles: https://home.apache.org/~%(apache_id)s/kafka-%(rc_tag)s/kafka-%(release_version)s-src.tgz + * Grab one of the binary distros and run the quickstarts against them: https://home.apache.org/~%(apache_id)s/kafka-%(rc_tag)s/kafka_2.13-%(release_version)s.tgz + * Extract and verify one of the site docs jars: https://home.apache.org/~%(apache_id)s/kafka-%(rc_tag)s/kafka_2.13-%(release_version)s-site-docs.tgz + * Build a sample against jars in the staging repo: (TODO: Can we get a temporary URL before "closing" the staged artifacts?) 
+ * Validate GPG signatures on at least one file: + wget https://home.apache.org/~%(apache_id)s/kafka-%(rc_tag)s/kafka-%(release_version)s-src.tgz && + wget https://home.apache.org/~%(apache_id)s/kafka-%(rc_tag)s/kafka-%(release_version)s-src.tgz.asc && + wget https://home.apache.org/~%(apache_id)s/kafka-%(rc_tag)s/kafka-%(release_version)s-src.tgz.md5 && + wget https://home.apache.org/~%(apache_id)s/kafka-%(rc_tag)s/kafka-%(release_version)s-src.tgz.sha1 && + wget https://home.apache.org/~%(apache_id)s/kafka-%(rc_tag)s/kafka-%(release_version)s-src.tgz.sha512 && + gpg --verify kafka-%(release_version)s-src.tgz.asc kafka-%(release_version)s-src.tgz && + gpg --print-md md5 kafka-%(release_version)s-src.tgz | diff - kafka-%(release_version)s-src.tgz.md5 && + gpg --print-md sha1 kafka-%(release_version)s-src.tgz | diff - kafka-%(release_version)s-src.tgz.sha1 && + gpg --print-md sha512 kafka-%(release_version)s-src.tgz | diff - kafka-%(release_version)s-src.tgz.sha512 && + rm kafka-%(release_version)s-src.tgz* && + echo "OK" || echo "Failed" + * Validate the javadocs look ok. They are at https://home.apache.org/~%(apache_id)s/kafka-%(rc_tag)s/javadoc/ + +******************************************************************************************************************************************************* +""" % release_notification_props) +if not user_ok("Have you sufficiently verified the release artifacts (y/n)?: "): + fail("Ok, giving up") + +print("Next, we need to get the Maven artifacts we published into the staging repository.") +# TODO: Can we get this closed via a REST API since we already need to collect credentials for this repo? +print("Go to https://repository.apache.org/#stagingRepositories and hit 'Close' for the new repository that was created by uploading artifacts.") +print("If this is not the first RC, you need to 'Drop' the previous artifacts.") +print("Confirm the correct artifacts are visible at https://repository.apache.org/content/groups/staging/org/apache/kafka/") +if not user_ok("Have you successfully deployed the artifacts (y/n)?: "): + fail("Ok, giving up") +if not user_ok("Ok to push RC tag %s (y/n)?: " % rc_tag): + fail("Ok, giving up") +cmd("Pushing RC tag", "git push %s %s" % (PUSH_REMOTE_NAME, rc_tag)) + +# Move back to starting branch and clean out the temporary release branch (e.g. 1.0.0) we used to generate everything +cmd("Resetting repository working state", "git reset --hard HEAD && git checkout %s" % starting_branch, shell=True) +cmd("Deleting git branches %s" % release_version, "git branch -D %s" % release_version, shell=True) + + +email_contents = """ +To: dev@kafka.apache.org, users@kafka.apache.org, kafka-clients@googlegroups.com +Subject: [VOTE] %(release_version)s RC%(rc)s + +Hello Kafka users, developers and client-developers, + +This is the first candidate for release of Apache Kafka %(release_version)s. 
+ + + +Release notes for the %(release_version)s release: +https://home.apache.org/~%(apache_id)s/kafka-%(rc_tag)s/RELEASE_NOTES.html + +*** Please download, test and vote by + +Kafka's KEYS file containing PGP keys we use to sign the release: +https://kafka.apache.org/KEYS + +* Release artifacts to be voted upon (source and binary): +https://home.apache.org/~%(apache_id)s/kafka-%(rc_tag)s/ + +* Maven artifacts to be voted upon: +https://repository.apache.org/content/groups/staging/org/apache/kafka/ + +* Javadoc: +https://home.apache.org/~%(apache_id)s/kafka-%(rc_tag)s/javadoc/ + +* Tag to be voted upon (off %(dev_branch)s branch) is the %(release_version)s tag: +https://github.com/apache/kafka/releases/tag/%(rc_tag)s + +* Documentation: +https://kafka.apache.org/%(docs_version)s/documentation.html + +* Protocol: +https://kafka.apache.org/%(docs_version)s/protocol.html + +* Successful Jenkins builds for the %(dev_branch)s branch: +Unit/integration tests: https://ci-builds.apache.org/job/Kafka/job/kafka/job/%(dev_branch)s// +System tests: https://jenkins.confluent.io/job/system-test-kafka/job/%(dev_branch)s// + +/************************************** + +Thanks, + +""" % release_notification_props + +print() +print() +print("*****************************************************************") +print() +print(email_contents) +print() +print("*****************************************************************") +print() +print("All artifacts should now be fully staged. Use the above template to send the announcement for the RC to the mailing list.") +print("IMPORTANT: Note that there are still some substitutions that need to be made in the template:") +print(" - Describe major changes in this release") +print(" - Deadline for voting, which should be at least 3 days after you send out the email") +print(" - Jenkins build numbers for successful unit & system test builds") +print(" - Fill in your name in the signature") +print(" - Finally, validate all the links before shipping!") +print("Note that all substitutions are annotated with <> around them.") diff --git a/release_notes.py b/release_notes.py new file mode 100755 index 0000000..e44c74d --- /dev/null +++ b/release_notes.py @@ -0,0 +1,117 @@ +#!/usr/bin/env python + +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Usage: release_notes.py > RELEASE_NOTES.html + +Generates release notes for a Kafka release by generating an HTML doc containing some introductory information about the + release with links to the Kafka docs followed by a list of issues resolved in the release. The script will fail if it finds + any unresolved issues still marked with the target release. You should run this script after either resolving all issues or + moving outstanding issues to a later release. 
+ +""" + +from jira import JIRA +import itertools, sys + +if len(sys.argv) < 2: + print("Usage: release_notes.py ", file=sys.stderr) + sys.exit(1) + +version = sys.argv[1] +minor_version_dotless = "".join(version.split(".")[:2]) # i.e., 10 if version == 1.0.1 + +JIRA_BASE_URL = 'https://issues.apache.org/jira' +MAX_RESULTS = 100 # This is constrained for cloud instances so we need to fix this value + +def get_issues(jira, query, **kwargs): + """ + Get all issues matching the JQL query from the JIRA instance. This handles expanding paginated results for you. Any additional keyword arguments are forwarded to the JIRA.search_issues call. + """ + results = [] + startAt = 0 + new_results = None + while new_results == None or len(new_results) == MAX_RESULTS: + new_results = jira.search_issues(query, startAt=startAt, maxResults=MAX_RESULTS, **kwargs) + results += new_results + startAt += len(new_results) + return results + +def issue_link(issue): + return "%s/browse/%s" % (JIRA_BASE_URL, issue.key) + + +if __name__ == "__main__": + apache = JIRA(JIRA_BASE_URL) + issues = get_issues(apache, 'project=KAFKA and fixVersion=%s' % version) + if not issues: + print("Didn't find any issues for the target fix version", file=sys.stderr) + sys.exit(1) + + # Some resolutions, including a lack of resolution, indicate that the bug hasn't actually been addressed and we shouldn't even be able to create a release until they are fixed + UNRESOLVED_RESOLUTIONS = [None, + "Unresolved", + "Duplicate", + "Invalid", + "Not A Problem", + "Not A Bug", + "Won't Fix", + "Incomplete", + "Cannot Reproduce", + "Later", + "Works for Me", + "Workaround", + "Information Provided" + ] + unresolved_issues = [issue for issue in issues if issue.fields.resolution in UNRESOLVED_RESOLUTIONS or issue.fields.resolution.name in UNRESOLVED_RESOLUTIONS] + if unresolved_issues: + print("The release is not completed since unresolved issues or improperly resolved issues were found still tagged with this release as the fix version:", file=sys.stderr) + for issue in unresolved_issues: + print("Unresolved issue: %15s %20s %s" % (issue.key, issue.fields.resolution, issue_link(issue)), file=sys.stderr) + print("", file=sys.stderr) + print("Note that for some resolutions, you should simply remove the fix version as they have not been truly fixed in this release.", file=sys.stderr) + sys.exit(1) + + # Get list of (issue type, [issues]) sorted by the issue ID type, with each subset of issues sorted by their key so they + # are in increasing order of bug #. To get a nice ordering of the issue types we customize the key used to sort by issue + # type a bit to ensure features and improvements end up first. + def issue_type_key(issue): + if issue.fields.issuetype.name == 'New Feature': + return -2 + if issue.fields.issuetype.name == 'Improvement': + return -1 + return int(issue.fields.issuetype.id) + + by_group = [(k,sorted(g, key=lambda issue: issue.id)) for k,g in itertools.groupby(sorted(issues, key=issue_type_key), lambda issue: issue.fields.issuetype.name)] + + print("

<h1>Release Notes - Kafka - Version %s</h1>" % version)
+ print("""<p>Below is a summary of the JIRA issues addressed in the %(version)s release of Kafka. For full documentation of the
+ release, a guide to get started, and information about the project, see the <a href="https://kafka.apache.org/">Kafka
+ project site</a>.</p>
+
+ <p><b>Note about upgrades:</b> Please carefully review the
+ <a href="https://kafka.apache.org/%(minor)s/documentation.html#upgrade">upgrade documentation</a> for this release thoroughly
+ before upgrading your cluster. The upgrade notes discuss any critical information about incompatibilities and breaking
+ changes, performance changes, and any other changes that might impact your production deployment of Kafka.</p>
+
+ <p>The documentation for the most recent release can be found at
+ <a href="https://kafka.apache.org/documentation.html">https://kafka.apache.org/documentation.html</a>.</p>""" % { 'version': version, 'minor': minor_version_dotless })
+ for itype, issues in by_group:
+ print("<h2>%s</h2>" % itype)
+ print("<ul>")
+ for issue in issues:
+ print('<li>[<a href="%(link)s">%(key)s</a>] - %(summary)s</li>' % {'key': issue.key, 'link': issue_link(issue), 'summary': issue.fields.summary})
+ print("</ul>
                ") diff --git a/server-common/src/main/java/org/apache/kafka/queue/EventQueue.java b/server-common/src/main/java/org/apache/kafka/queue/EventQueue.java new file mode 100644 index 0000000..c282917 --- /dev/null +++ b/server-common/src/main/java/org/apache/kafka/queue/EventQueue.java @@ -0,0 +1,261 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.queue; + +import org.slf4j.Logger; + +import java.util.OptionalLong; +import java.util.concurrent.RejectedExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.function.Function; + + +public interface EventQueue extends AutoCloseable { + interface Event { + /** + * Run the event. + */ + void run() throws Exception; + + /** + * Handle an exception that was either generated by running the event, or by the + * event queue's inability to run the event. + * + * @param e The exception. This will be a TimeoutException if the event hit + * its deadline before it could be scheduled. + * It will be a RejectedExecutionException if the event could not be + * scheduled because the event queue has already been closed. + * Otherweise, it will be whatever exception was thrown by run(). 
+ */ + default void handleException(Throwable e) {} + } + + abstract class FailureLoggingEvent implements Event { + private final Logger log; + + public FailureLoggingEvent(Logger log) { + this.log = log; + } + + @Override + public void handleException(Throwable e) { + if (e instanceof RejectedExecutionException) { + log.info("Not processing {} because the event queue is closed.", this); + } else { + log.error("Unexpected error handling {}", this, e); + } + } + + @Override + public String toString() { + return this.getClass().getSimpleName(); + } + } + + class NoDeadlineFunction implements Function { + public static final NoDeadlineFunction INSTANCE = new NoDeadlineFunction(); + + @Override + public OptionalLong apply(OptionalLong ignored) { + return OptionalLong.empty(); + } + } + + class DeadlineFunction implements Function { + private final long deadlineNs; + + public DeadlineFunction(long deadlineNs) { + this.deadlineNs = deadlineNs; + } + + @Override + public OptionalLong apply(OptionalLong ignored) { + return OptionalLong.of(deadlineNs); + } + } + + class EarliestDeadlineFunction implements Function { + private final long newDeadlineNs; + + public EarliestDeadlineFunction(long newDeadlineNs) { + this.newDeadlineNs = newDeadlineNs; + } + + @Override + public OptionalLong apply(OptionalLong prevDeadlineNs) { + if (!prevDeadlineNs.isPresent()) { + return OptionalLong.of(newDeadlineNs); + } else if (prevDeadlineNs.getAsLong() < newDeadlineNs) { + return prevDeadlineNs; + } else { + return OptionalLong.of(newDeadlineNs); + } + } + } + + class VoidEvent implements Event { + public final static VoidEvent INSTANCE = new VoidEvent(); + + @Override + public void run() throws Exception { + } + } + + /** + * Add an element to the front of the queue. + * + * @param event The mandatory event to prepend. + */ + default void prepend(Event event) { + enqueue(EventInsertionType.PREPEND, null, NoDeadlineFunction.INSTANCE, event); + } + + /** + * Add an element to the end of the queue. + * + * @param event The event to append. + */ + default void append(Event event) { + enqueue(EventInsertionType.APPEND, null, NoDeadlineFunction.INSTANCE, event); + } + + /** + * Add an event to the end of the queue. + * + * @param deadlineNs The deadline for starting the event, in monotonic + * nanoseconds. If the event has not started by this + * deadline, handleException is called with a + * {@link org.apache.kafka.common.errors.TimeoutException}, + * and the event is cancelled. + * @param event The event to append. + */ + default void appendWithDeadline(long deadlineNs, Event event) { + enqueue(EventInsertionType.APPEND, null, new DeadlineFunction(deadlineNs), event); + } + + /** + * Schedule an event to be run at a specific time. + * + * @param tag If this is non-null, the unique tag to use for this + * event. If an event with this tag already exists, it + * will be cancelled. + * @param deadlineNsCalculator A function which takes as an argument the existing + * deadline for the event with this tag (or empty if the + * event has no tag, or if there is none such), and + * produces the deadline to use for this event. + * Once the deadline has arrived, the event will be + * run. Events whose deadlines are only a few nanoseconds + * apart may be executed in any order. + * @param event The event to schedule. + */ + default void scheduleDeferred(String tag, + Function deadlineNsCalculator, + Event event) { + enqueue(EventInsertionType.DEFERRED, tag, deadlineNsCalculator, event); + } + + /** + * Cancel a deferred event. 
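+ * <p>
+ * For illustration (hypothetical tag and deadline, not from the original javadoc): an event scheduled via
+ * {@code scheduleDeferred("reconcile", new DeadlineFunction(deadlineNs), event)} can be dropped before it runs
+ * by calling {@code cancelDeferred("reconcile")}.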
+ * + * @param tag The unique tag for the event to be cancelled. Must be + * non-null. If the event with the tag has not been + * scheduled, this call will be ignored. + */ + void cancelDeferred(String tag); + + enum EventInsertionType { + PREPEND, + APPEND, + DEFERRED + } + + /** + * Add an event to the queue. + * + * @param insertionType How to insert the event. + * PREPEND means insert the event as the first thing + * to run. APPEND means insert the event as the last + * thing to run. DEFERRED means insert the event to + * run after a delay. + * @param tag If this is non-null, the unique tag to use for + * this event. If an event with this tag already + * exists, it will be cancelled. + * @param deadlineNsCalculator If this is non-null, it is a function which takes + * as an argument the existing deadline for the + * event with this tag (or null if the event has no + * tag, or if there is none such), and produces the + * deadline to use for this event (or empty to use + * none.) Events whose deadlines are only a few + * nanoseconds apart may be executed in any order. + * @param event The event to enqueue. + */ + void enqueue(EventInsertionType insertionType, + String tag, + Function deadlineNsCalculator, + Event event); + + /** + * Asynchronously shut down the event queue with no unnecessary delay. + * @see #beginShutdown(String, Event, long, TimeUnit) + * + * @param source The source of the shutdown. + */ + default void beginShutdown(String source) { + beginShutdown(source, new VoidEvent()); + } + + /** + * Asynchronously shut down the event queue with no unnecessary delay. + * + * @param source The source of the shutdown. + * @param cleanupEvent The mandatory event to invoke after all other events have + * been processed. + * @see #beginShutdown(String, Event, long, TimeUnit) + */ + default void beginShutdown(String source, Event cleanupEvent) { + beginShutdown(source, cleanupEvent, 0, TimeUnit.SECONDS); + } + + /** + * Asynchronously shut down the event queue. + * + * No new events will be accepted, and the timeout will be initiated + * for all existing events. + * + * @param source The source of the shutdown. + * @param cleanupEvent The mandatory event to invoke after all other events have + * been processed. + * @param timeSpan The amount of time to use for the timeout. + * Once the timeout elapses, any remaining queued + * events will get a + * {@link org.apache.kafka.common.errors.TimeoutException}. + * @param timeUnit The time unit to use for the timeout. + */ + void beginShutdown(String source, Event cleanupEvent, long timeSpan, TimeUnit timeUnit); + + /** + * This method is used during unit tests where MockTime is in use. + * It is used to alert the queue that the mock time has changed. + */ + default void wakeup() { } + + /** + * Synchronously close the event queue and wait for any threads to be joined. + */ + void close() throws InterruptedException; +} diff --git a/server-common/src/main/java/org/apache/kafka/queue/KafkaEventQueue.java b/server-common/src/main/java/org/apache/kafka/queue/KafkaEventQueue.java new file mode 100644 index 0000000..05642f6 --- /dev/null +++ b/server-common/src/main/java/org/apache/kafka/queue/KafkaEventQueue.java @@ -0,0 +1,426 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.queue; + +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; +import java.util.OptionalLong; +import java.util.TreeMap; +import java.util.concurrent.RejectedExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.locks.Condition; +import java.util.concurrent.locks.ReentrantLock; +import java.util.function.Function; +import org.apache.kafka.common.errors.TimeoutException; +import org.apache.kafka.common.utils.KafkaThread; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.common.utils.Time; +import org.slf4j.Logger; + + +public final class KafkaEventQueue implements EventQueue { + /** + * A context object that wraps events. + */ + private static class EventContext { + /** + * The caller-supplied event. + */ + private final Event event; + + /** + * How this event was inserted. + */ + private final EventInsertionType insertionType; + + /** + * The previous pointer of our circular doubly-linked list. + */ + private EventContext prev = this; + + /** + * The next pointer in our circular doubly-linked list. + */ + private EventContext next = this; + + /** + * If this event is in the delay map, this is the key it is there under. + * If it is not in the map, this is null. + */ + private OptionalLong deadlineNs = OptionalLong.empty(); + + /** + * The tag associated with this event. + */ + private String tag; + + EventContext(Event event, EventInsertionType insertionType, String tag) { + this.event = event; + this.insertionType = insertionType; + this.tag = tag; + } + + /** + * Insert the event context in the circularly linked list after this node. + */ + void insertAfter(EventContext other) { + this.next.prev = other; + other.next = this.next; + other.prev = this; + this.next = other; + } + + /** + * Insert a new node in the circularly linked list before this node. + */ + void insertBefore(EventContext other) { + this.prev.next = other; + other.prev = this.prev; + other.next = this; + this.prev = other; + } + + /** + * Remove this node from the circularly linked list. + */ + void remove() { + this.prev.next = this.next; + this.next.prev = this.prev; + this.prev = this; + this.next = this; + } + + /** + * Returns true if this node is the only element in its list. + */ + boolean isSingleton() { + return prev == this && next == this; + } + + /** + * Run the event associated with this EventContext. + */ + void run(Logger log) throws InterruptedException { + try { + event.run(); + } catch (InterruptedException e) { + throw e; + } catch (Exception e) { + try { + event.handleException(e); + } catch (Throwable t) { + log.error("Unexpected exception in handleException", t); + } + } + } + + /** + * Complete the event associated with this EventContext with a timeout exception. 
+ */ + void completeWithTimeout() { + completeWithException(new TimeoutException()); + } + + /** + * Complete the event associated with this EventContext with the specified + * exception. + */ + void completeWithException(Throwable t) { + event.handleException(t); + } + } + + private class EventHandler implements Runnable { + /** + * Event contexts indexed by tag. Events without a tag are not included here. + */ + private final Map tagToEventContext = new HashMap<>(); + + /** + * The head of the event queue. + */ + private final EventContext head = new EventContext(null, null, null); + + /** + * An ordered map of times in monotonic nanoseconds to events to time out. + */ + private final TreeMap deadlineMap = new TreeMap<>(); + + /** + * A condition variable for waking up the event handler thread. + */ + private final Condition cond = lock.newCondition(); + + @Override + public void run() { + try { + handleEvents(); + cleanupEvent.run(); + } catch (Throwable e) { + log.warn("event handler thread exiting with exception", e); + } + } + + private void remove(EventContext eventContext) { + eventContext.remove(); + if (eventContext.deadlineNs.isPresent()) { + deadlineMap.remove(eventContext.deadlineNs.getAsLong()); + eventContext.deadlineNs = OptionalLong.empty(); + } + if (eventContext.tag != null) { + tagToEventContext.remove(eventContext.tag, eventContext); + eventContext.tag = null; + } + } + + private void handleEvents() throws InterruptedException { + EventContext toTimeout = null; + EventContext toRun = null; + while (true) { + if (toTimeout != null) { + toTimeout.completeWithTimeout(); + toTimeout = null; + } else if (toRun != null) { + toRun.run(log); + toRun = null; + } + lock.lock(); + try { + long awaitNs = Long.MAX_VALUE; + Map.Entry entry = deadlineMap.firstEntry(); + if (entry != null) { + // Search for timed-out events or deferred events that are ready + // to run. + long now = time.nanoseconds(); + long timeoutNs = entry.getKey(); + EventContext eventContext = entry.getValue(); + if (timeoutNs <= now) { + if (eventContext.insertionType == EventInsertionType.DEFERRED) { + // The deferred event is ready to run. Prepend it to the + // queue. (The value for deferred events is a schedule time + // rather than a timeout.) + remove(eventContext); + toRun = eventContext; + } else { + // not a deferred event, so it is a deadline, and it is timed out. + remove(eventContext); + toTimeout = eventContext; + } + continue; + } else if (closingTimeNs <= now) { + remove(eventContext); + toTimeout = eventContext; + continue; + } + awaitNs = timeoutNs - now; + } + if (head.next == head) { + if ((closingTimeNs != Long.MAX_VALUE) && deadlineMap.isEmpty()) { + // If there are no more entries to process, and the queue is + // closing, exit the thread. 
+ return; + } + } else { + toRun = head.next; + remove(toRun); + continue; + } + if (closingTimeNs != Long.MAX_VALUE) { + long now = time.nanoseconds(); + if (awaitNs > closingTimeNs - now) { + awaitNs = closingTimeNs - now; + } + } + if (awaitNs == Long.MAX_VALUE) { + cond.await(); + } else { + cond.awaitNanos(awaitNs); + } + } finally { + lock.unlock(); + } + } + } + + Exception enqueue(EventContext eventContext, + Function deadlineNsCalculator) { + lock.lock(); + try { + if (closingTimeNs != Long.MAX_VALUE) { + return new RejectedExecutionException(); + } + OptionalLong existingDeadlineNs = OptionalLong.empty(); + if (eventContext.tag != null) { + EventContext toRemove = + tagToEventContext.put(eventContext.tag, eventContext); + if (toRemove != null) { + existingDeadlineNs = toRemove.deadlineNs; + remove(toRemove); + } + } + OptionalLong deadlineNs = deadlineNsCalculator.apply(existingDeadlineNs); + boolean queueWasEmpty = head.isSingleton(); + boolean shouldSignal = false; + switch (eventContext.insertionType) { + case APPEND: + head.insertBefore(eventContext); + if (queueWasEmpty) { + shouldSignal = true; + } + break; + case PREPEND: + head.insertAfter(eventContext); + if (queueWasEmpty) { + shouldSignal = true; + } + break; + case DEFERRED: + if (!deadlineNs.isPresent()) { + return new RuntimeException( + "You must specify a deadline for deferred events."); + } + break; + } + if (deadlineNs.isPresent()) { + long insertNs = deadlineNs.getAsLong(); + long prevStartNs = deadlineMap.isEmpty() ? Long.MAX_VALUE : deadlineMap.firstKey(); + // If the time in nanoseconds is already taken, take the next one. + while (deadlineMap.putIfAbsent(insertNs, eventContext) != null) { + insertNs++; + } + eventContext.deadlineNs = OptionalLong.of(insertNs); + // If the new timeout is before all the existing ones, wake up the + // timeout thread. + if (insertNs <= prevStartNs) { + shouldSignal = true; + } + } + if (shouldSignal) { + cond.signal(); + } + } finally { + lock.unlock(); + } + return null; + } + + void cancelDeferred(String tag) { + lock.lock(); + try { + EventContext eventContext = tagToEventContext.get(tag); + if (eventContext != null) { + remove(eventContext); + } + } finally { + lock.unlock(); + } + } + + void wakeUp() { + lock.lock(); + try { + eventHandler.cond.signal(); + } finally { + lock.unlock(); + } + } + } + + private final Time time; + private final ReentrantLock lock; + private final Logger log; + private final EventHandler eventHandler; + private final Thread eventHandlerThread; + + /** + * The time in monotonic nanoseconds when the queue is closing, or Long.MAX_VALUE if + * the queue is not currently closing. 
+ */ + private long closingTimeNs; + + private Event cleanupEvent; + + public KafkaEventQueue(Time time, + LogContext logContext, + String threadNamePrefix) { + this.time = time; + this.lock = new ReentrantLock(); + this.log = logContext.logger(KafkaEventQueue.class); + this.eventHandler = new EventHandler(); + this.eventHandlerThread = new KafkaThread(threadNamePrefix + "EventHandler", + this.eventHandler, false); + this.closingTimeNs = Long.MAX_VALUE; + this.cleanupEvent = null; + this.eventHandlerThread.start(); + } + + @Override + public void enqueue(EventInsertionType insertionType, + String tag, + Function deadlineNsCalculator, + Event event) { + EventContext eventContext = new EventContext(event, insertionType, tag); + Exception e = eventHandler.enqueue(eventContext, deadlineNsCalculator); + if (e != null) { + eventContext.completeWithException(e); + } + } + + @Override + public void cancelDeferred(String tag) { + eventHandler.cancelDeferred(tag); + } + + @Override + public void beginShutdown(String source, Event newCleanupEvent, + long timeSpan, TimeUnit timeUnit) { + if (timeSpan < 0) { + throw new IllegalArgumentException("beginShutdown must be called with a " + + "non-negative timeout."); + } + Objects.requireNonNull(newCleanupEvent); + lock.lock(); + try { + if (cleanupEvent != null) { + log.debug("{}: Event queue is already shutting down.", source); + return; + } + log.info("{}: shutting down event queue.", source); + cleanupEvent = newCleanupEvent; + long newClosingTimeNs = time.nanoseconds() + timeUnit.toNanos(timeSpan); + if (closingTimeNs >= newClosingTimeNs) + closingTimeNs = newClosingTimeNs; + eventHandler.cond.signal(); + } finally { + lock.unlock(); + } + } + + @Override + public void wakeup() { + eventHandler.wakeUp(); + } + + @Override + public void close() throws InterruptedException { + beginShutdown("KafkaEventQueue#close"); + eventHandlerThread.join(); + log.info("closed event queue."); + } +} diff --git a/server-common/src/main/java/org/apache/kafka/server/common/ApiMessageAndVersion.java b/server-common/src/main/java/org/apache/kafka/server/common/ApiMessageAndVersion.java new file mode 100644 index 0000000..66a625b --- /dev/null +++ b/server-common/src/main/java/org/apache/kafka/server/common/ApiMessageAndVersion.java @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.server.common; + +import org.apache.kafka.common.protocol.ApiMessage; + +import java.util.Objects; + +/** + * An ApiMessage and an associated version. 
+ */ +public class ApiMessageAndVersion { + private final ApiMessage message; + private final short version; + + public ApiMessageAndVersion(ApiMessage message, short version) { + this.message = message; + this.version = version; + } + + public ApiMessage message() { + return message; + } + + public short version() { + return version; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + ApiMessageAndVersion that = (ApiMessageAndVersion) o; + return version == that.version && + Objects.equals(message, that.message); + } + + @Override + public int hashCode() { + return Objects.hash(message, version); + } + + @Override + public String toString() { + return "ApiMessageAndVersion(" + message + " at version " + version + ")"; + } +} diff --git a/server-common/src/main/java/org/apache/kafka/server/common/CheckpointFile.java b/server-common/src/main/java/org/apache/kafka/server/common/CheckpointFile.java new file mode 100644 index 0000000..a1f708e --- /dev/null +++ b/server-common/src/main/java/org/apache/kafka/server/common/CheckpointFile.java @@ -0,0 +1,196 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.server.common; + +import org.apache.kafka.common.utils.Utils; + +import java.io.BufferedReader; +import java.io.BufferedWriter; +import java.io.File; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.OutputStreamWriter; +import java.nio.charset.StandardCharsets; +import java.nio.file.FileAlreadyExistsException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.List; +import java.util.Optional; + +/** + * This class represents a utility to capture a checkpoint in a file. It writes down to the file in the below format. + * + * ========= File beginning ========= + * version: int + * entries-count: int + * entry-as-string-on-each-line + * ========= File end =============== + * + * Each entry is represented as a string on each line in the checkpoint file. {@link EntryFormatter} is used + * to convert the entry into a string and vice versa. + * + * @param entry type. + */ +public class CheckpointFile { + + private final int version; + private final EntryFormatter formatter; + private final Object lock = new Object(); + private final Path absolutePath; + private final Path tempPath; + + public CheckpointFile(File file, + int version, + EntryFormatter formatter) throws IOException { + this.version = version; + this.formatter = formatter; + try { + // Create the file if it does not exist. 
+ Files.createFile(file.toPath()); + } catch (FileAlreadyExistsException ex) { + // Ignore if file already exists. + } + absolutePath = file.toPath().toAbsolutePath(); + tempPath = Paths.get(absolutePath.toString() + ".tmp"); + } + + public void write(Collection entries) throws IOException { + synchronized (lock) { + // write to temp file and then swap with the existing file + try (FileOutputStream fileOutputStream = new FileOutputStream(tempPath.toFile()); + BufferedWriter writer = new BufferedWriter(new OutputStreamWriter(fileOutputStream, StandardCharsets.UTF_8))) { + // Write the version + writer.write(Integer.toString(version)); + writer.newLine(); + + // Write the entries count + writer.write(Integer.toString(entries.size())); + writer.newLine(); + + // Write each entry on a new line. + for (T entry : entries) { + writer.write(formatter.toString(entry)); + writer.newLine(); + } + + writer.flush(); + fileOutputStream.getFD().sync(); + } + + Utils.atomicMoveWithFallback(tempPath, absolutePath); + } + } + + public List read() throws IOException { + synchronized (lock) { + try (BufferedReader reader = Files.newBufferedReader(absolutePath)) { + CheckpointReadBuffer checkpointBuffer = new CheckpointReadBuffer<>(absolutePath.toString(), reader, version, formatter); + return checkpointBuffer.read(); + } + } + } + + private static class CheckpointReadBuffer { + + private final String location; + private final BufferedReader reader; + private final int version; + private final EntryFormatter formatter; + + CheckpointReadBuffer(String location, + BufferedReader reader, + int version, + EntryFormatter formatter) { + this.location = location; + this.reader = reader; + this.version = version; + this.formatter = formatter; + } + + List read() throws IOException { + String line = reader.readLine(); + if (line == null) + return Collections.emptyList(); + + int readVersion = toInt(line); + if (readVersion != version) { + throw new IOException("Unrecognised version:" + readVersion + ", expected version: " + version + + " in checkpoint file at: " + location); + } + + line = reader.readLine(); + if (line == null) { + return Collections.emptyList(); + } + int expectedSize = toInt(line); + List entries = new ArrayList<>(expectedSize); + line = reader.readLine(); + while (line != null) { + Optional maybeEntry = formatter.fromString(line); + if (!maybeEntry.isPresent()) { + throw buildMalformedLineException(line); + } + entries.add(maybeEntry.get()); + line = reader.readLine(); + } + + if (entries.size() != expectedSize) { + throw new IOException("Expected [" + expectedSize + "] entries in checkpoint file [" + + location + "], but found only [" + entries.size() + "]"); + } + + return entries; + } + + private int toInt(String line) throws IOException { + try { + return Integer.parseInt(line); + } catch (NumberFormatException e) { + throw buildMalformedLineException(line); + } + } + + private IOException buildMalformedLineException(String line) { + return new IOException(String.format("Malformed line in checkpoint file [%s]: %s", location, line)); + } + } + + /** + * This is used to convert the given entry of type {@code T} into a string and vice versa. + * + * @param entry type + */ + public interface EntryFormatter { + + /** + * @param entry entry to be converted into string. + * @return String representation of the given entry. + */ + String toString(T entry); + + /** + * @param value string representation of an entry. + * @return entry converted from the given string representation if possible. 
{@link Optional#empty()} represents + * that the given string representation could not be converted into an entry. + */ + Optional fromString(String value); + } +} diff --git a/server-common/src/main/java/org/apache/kafka/server/common/ProducerIdsBlock.java b/server-common/src/main/java/org/apache/kafka/server/common/ProducerIdsBlock.java new file mode 100644 index 0000000..8a0fd84 --- /dev/null +++ b/server-common/src/main/java/org/apache/kafka/server/common/ProducerIdsBlock.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.server.common; + +import java.util.Objects; + +/** + * Holds a range of Producer IDs used for Transactional and EOS producers. + * + * The start and end of the ID block are inclusive. + */ +public class ProducerIdsBlock { + public static final int PRODUCER_ID_BLOCK_SIZE = 1000; + + public static final ProducerIdsBlock EMPTY = new ProducerIdsBlock(-1, 0, 0); + + private final int brokerId; + private final long producerIdStart; + private final int producerIdLen; + + public ProducerIdsBlock(int brokerId, long producerIdStart, int producerIdLen) { + this.brokerId = brokerId; + this.producerIdStart = producerIdStart; + this.producerIdLen = producerIdLen; + } + + public int brokerId() { + return brokerId; + } + + public long producerIdStart() { + return producerIdStart; + } + + public int producerIdLen() { + return producerIdLen; + } + + public long producerIdEnd() { + return producerIdStart + producerIdLen - 1; + } + + + @Override + public String toString() { + return "ProducerIdsBlock{" + + "brokerId=" + brokerId + + ", producerIdStart=" + producerIdStart + + ", producerIdLen=" + producerIdLen + + '}'; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + ProducerIdsBlock that = (ProducerIdsBlock) o; + return brokerId == that.brokerId && producerIdStart == that.producerIdStart && producerIdLen == that.producerIdLen; + } + + @Override + public int hashCode() { + return Objects.hash(brokerId, producerIdStart, producerIdLen); + } +} diff --git a/server-common/src/main/java/org/apache/kafka/server/common/serialization/AbstractApiMessageSerde.java b/server-common/src/main/java/org/apache/kafka/server/common/serialization/AbstractApiMessageSerde.java new file mode 100644 index 0000000..7533178 --- /dev/null +++ b/server-common/src/main/java/org/apache/kafka/server/common/serialization/AbstractApiMessageSerde.java @@ -0,0 +1,121 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.server.common.serialization; + +import org.apache.kafka.common.protocol.ApiMessage; +import org.apache.kafka.common.protocol.ObjectSerializationCache; +import org.apache.kafka.common.protocol.Readable; +import org.apache.kafka.common.protocol.Writable; +import org.apache.kafka.common.utils.ByteUtils; +import org.apache.kafka.server.common.ApiMessageAndVersion; + +/** + * This is an implementation of {@code RecordSerde} with {@link ApiMessageAndVersion} but implementors need to implement + * {@link #apiMessageFor(short)} to return a {@code ApiMessage} instance for the given {@code apiKey}. + * + * This can be used as the underlying serialization mechanism for records defined with {@link ApiMessage}s. + *

+ * <p>
+ * Serialization format for the given {@code ApiMessageAndVersion} is below:
+ * <pre>
+     [data_frame_version header message]
+     header => [api_key version]
+
+     data_frame_version   : This is the header version, current value is 1. Header includes both api_key and version.
+     api_key              : apiKey of {@code ApiMessageAndVersion} object.
+     version              : version of {@code ApiMessageAndVersion} object.
+     message              : serialized message of {@code ApiMessageAndVersion} object.
+ * </pre>
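+ * <p>
+ * Worked example (illustrative values, not from the original javadoc): with the {@code write()} implementation in this
+ * class, a record whose message has {@code apiKey = 3} and {@code version = 0} is framed as the unsigned varints
+ * {@code 1 3 0} (frame version, api key, record version) followed by the serialized message bytes.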
                + */ +public abstract class AbstractApiMessageSerde implements RecordSerde { + private static final short DEFAULT_FRAME_VERSION = 1; + private static final int DEFAULT_FRAME_VERSION_SIZE = ByteUtils.sizeOfUnsignedVarint(DEFAULT_FRAME_VERSION); + + private static short unsignedIntToShort(Readable input, String entity) { + int val; + try { + val = input.readUnsignedVarint(); + } catch (Exception e) { + throw new MetadataParseException("Error while reading " + entity, e); + } + if (val > Short.MAX_VALUE) { + throw new MetadataParseException("Value for " + entity + " was too large."); + } + return (short) val; + } + + @Override + public int recordSize(ApiMessageAndVersion data, + ObjectSerializationCache serializationCache) { + int size = DEFAULT_FRAME_VERSION_SIZE; + size += ByteUtils.sizeOfUnsignedVarint(data.message().apiKey()); + size += ByteUtils.sizeOfUnsignedVarint(data.version()); + size += data.message().size(serializationCache, data.version()); + return size; + } + + @Override + public void write(ApiMessageAndVersion data, + ObjectSerializationCache serializationCache, + Writable out) { + out.writeUnsignedVarint(DEFAULT_FRAME_VERSION); + out.writeUnsignedVarint(data.message().apiKey()); + out.writeUnsignedVarint(data.version()); + data.message().write(out, serializationCache, data.version()); + } + + @Override + public ApiMessageAndVersion read(Readable input, + int size) { + short frameVersion = unsignedIntToShort(input, "frame version"); + + if (frameVersion == 0) { + throw new MetadataParseException("Could not deserialize metadata record with frame version 0. " + + "Note that upgrades from the preview release of KRaft in 2.8 to newer versions are not supported."); + } else if (frameVersion != DEFAULT_FRAME_VERSION) { + throw new MetadataParseException("Could not deserialize metadata record due to unknown frame version " + + frameVersion + "(only frame version " + DEFAULT_FRAME_VERSION + " is supported)"); + } + short apiKey = unsignedIntToShort(input, "type"); + short version = unsignedIntToShort(input, "version"); + + ApiMessage record; + try { + record = apiMessageFor(apiKey); + } catch (Exception e) { + throw new MetadataParseException(e); + } + try { + record.read(input, version); + } catch (Exception e) { + throw new MetadataParseException("Failed to deserialize record with type " + apiKey, e); + } + if (input.remaining() > 0) { + throw new MetadataParseException("Found " + input.remaining() + + " byte(s) of garbage after " + apiKey); + } + return new ApiMessageAndVersion(record, version); + } + + /** + * Return {@code ApiMessage} instance for the given {@code apiKey}. This is used while deserializing the bytes + * payload into the respective {@code ApiMessage} in {@link #read(Readable, int)} method. + * + * @param apiKey apiKey for which a {@code ApiMessage} to be created. + */ + public abstract ApiMessage apiMessageFor(short apiKey); +} diff --git a/server-common/src/main/java/org/apache/kafka/server/common/serialization/BytesApiMessageSerde.java b/server-common/src/main/java/org/apache/kafka/server/common/serialization/BytesApiMessageSerde.java new file mode 100644 index 0000000..668bbfb --- /dev/null +++ b/server-common/src/main/java/org/apache/kafka/server/common/serialization/BytesApiMessageSerde.java @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.server.common.serialization; + +import org.apache.kafka.common.protocol.ApiMessage; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.protocol.ObjectSerializationCache; +import org.apache.kafka.common.protocol.Readable; +import org.apache.kafka.server.common.ApiMessageAndVersion; + +import java.nio.ByteBuffer; + +/** + * This class provides conversion of {@code ApiMessageAndVersion} to bytes and vice versa.. This can be used as serialization protocol for any + * metadata records derived of {@code ApiMessage}s. It internally uses {@link AbstractApiMessageSerde} for serialization/deserialization + * mechanism. + *

                + * Implementors need to extend this class and implement {@link #apiMessageFor(short)} method to return a respective + * {@code ApiMessage} for the given {@code apiKey}. This is required to deserialize the bytes to build the respective + * {@code ApiMessage} instance. + */ +public abstract class BytesApiMessageSerde { + + private final AbstractApiMessageSerde apiMessageSerde = new AbstractApiMessageSerde() { + @Override + public ApiMessage apiMessageFor(short apiKey) { + return BytesApiMessageSerde.this.apiMessageFor(apiKey); + } + }; + + public byte[] serialize(ApiMessageAndVersion messageAndVersion) { + ObjectSerializationCache cache = new ObjectSerializationCache(); + int size = apiMessageSerde.recordSize(messageAndVersion, cache); + ByteBufferAccessor writable = new ByteBufferAccessor(ByteBuffer.allocate(size)); + apiMessageSerde.write(messageAndVersion, cache, writable); + + return writable.buffer().array(); + } + + public ApiMessageAndVersion deserialize(byte[] data) { + Readable readable = new ByteBufferAccessor(ByteBuffer.wrap(data)); + + return apiMessageSerde.read(readable, data.length); + } + + /** + * Return {@code ApiMessage} instance for the given {@code apiKey}. This is used while deserializing the bytes + * payload into the respective {@code ApiMessage} in {@link #deserialize(byte[])} method. + * + * @param apiKey apiKey for which a {@code ApiMessage} to be created. + */ + public abstract ApiMessage apiMessageFor(short apiKey); + +} \ No newline at end of file diff --git a/server-common/src/main/java/org/apache/kafka/server/common/serialization/MetadataParseException.java b/server-common/src/main/java/org/apache/kafka/server/common/serialization/MetadataParseException.java new file mode 100644 index 0000000..49eea2d --- /dev/null +++ b/server-common/src/main/java/org/apache/kafka/server/common/serialization/MetadataParseException.java @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.server.common.serialization; + +/** + * An exception indicating that we failed to parse a metadata entry. 
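+ * <p> + * Typically thrown by {@link AbstractApiMessageSerde} when the frame version, record type, or record payload of a serialized metadata record cannot be decoded.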
+ */ +public class MetadataParseException extends RuntimeException { + private static final long serialVersionUID = 1L; + + public MetadataParseException(String message) { + super(message); + } + + public MetadataParseException(Throwable e) { + super(e); + } + + public MetadataParseException(String message, Throwable throwable) { + super(message, throwable); + } +} diff --git a/server-common/src/main/java/org/apache/kafka/server/common/serialization/RecordSerde.java b/server-common/src/main/java/org/apache/kafka/server/common/serialization/RecordSerde.java new file mode 100644 index 0000000..70642c6 --- /dev/null +++ b/server-common/src/main/java/org/apache/kafka/server/common/serialization/RecordSerde.java @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.server.common.serialization; + +import org.apache.kafka.common.protocol.ObjectSerializationCache; +import org.apache.kafka.common.protocol.Readable; +import org.apache.kafka.common.protocol.Writable; + +/** + * Serde interface for records written to a metadata log. This class assumes + * a two-pass serialization, with the first pass used to compute the size of the + * serialized record, and the second pass to write the object. + */ +public interface RecordSerde { + /** + * Get the size of a record. This must be called first before writing + * the data through {@link #write(Object, ObjectSerializationCache, Writable)}. + * + * @param data the record that will be serialized + * @param serializationCache serialization cache + * @return the size in bytes of the serialized record + */ + int recordSize(T data, ObjectSerializationCache serializationCache); + + /** + * Write the record to the output stream. This must be called after + * computing the size with {@link #recordSize(Object, ObjectSerializationCache)}. + * The same {@link ObjectSerializationCache} instance must be used in both calls. + * + * @param data the record to serialize and write + * @param serializationCache serialization cache + * @param out the output stream to write the record to + */ + void write(T data, ObjectSerializationCache serializationCache, Writable out); + + /** + * Read a record from a {@link Readable} input. 
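+ * The {@code size} argument is the total serialized size of the record in bytes, normally the value that was + * computed by {@link #recordSize(Object, ObjectSerializationCache)} when the record was written.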
+ * + * @param input the input stream to deserialize + * @param size the size of the record in bytes + * @return the deserialized record + */ + T read(Readable input, int size); +} diff --git a/server-common/src/main/java/org/apache/kafka/server/util/TranslatedValueMapView.java b/server-common/src/main/java/org/apache/kafka/server/util/TranslatedValueMapView.java new file mode 100644 index 0000000..9c85f6c --- /dev/null +++ b/server-common/src/main/java/org/apache/kafka/server/util/TranslatedValueMapView.java @@ -0,0 +1,117 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.server.util; + +import java.util.AbstractMap; +import java.util.AbstractSet; +import java.util.Iterator; +import java.util.Map; +import java.util.Objects; +import java.util.Set; +import java.util.function.Function; + + +/** + * A map which presents a lightweight view of another "underlying" map. Values in the + * underlying map will be translated by a callback before they are returned. + * + * This class is not internally synchronized. (Typically the underlyingMap is treated as + * immutable.) 
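+ * <p> + * A small usage sketch (the map contents here are illustrative): + * <pre>{@code + * Map<String, Integer> underlying = new TreeMap<>(); + * underlying.put("foo", 2); + * Map<String, String> view = new TranslatedValueMapView<>(underlying, v -> v.toString()); + * view.get("foo"); // returns "2"; the translation is applied on access + * }</pre>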
+ */ +public final class TranslatedValueMapView extends AbstractMap { + class TranslatedValueSetView extends AbstractSet> { + @Override + public Iterator> iterator() { + return new TranslatedValueEntryIterator(underlyingMap.entrySet().iterator()); + } + + @SuppressWarnings("rawtypes") + @Override + public boolean contains(Object o) { + if (!(o instanceof Entry)) return false; + Entry other = (Entry) o; + if (!underlyingMap.containsKey(other.getKey())) return false; + B value = underlyingMap.get(other.getKey()); + V translatedValue = valueMapping.apply(value); + return Objects.equals(translatedValue, other.getValue()); + } + + @Override + public boolean isEmpty() { + return underlyingMap.isEmpty(); + } + + @Override + public int size() { + return underlyingMap.size(); + } + } + + class TranslatedValueEntryIterator implements Iterator> { + private final Iterator> underlyingIterator; + + TranslatedValueEntryIterator(Iterator> underlyingIterator) { + this.underlyingIterator = underlyingIterator; + } + + @Override + public boolean hasNext() { + return underlyingIterator.hasNext(); + } + + @Override + public Entry next() { + Entry underlyingEntry = underlyingIterator.next(); + return new AbstractMap.SimpleImmutableEntry<>(underlyingEntry.getKey(), + valueMapping.apply(underlyingEntry.getValue())); + } + } + + private final Map underlyingMap; + private final Function valueMapping; + private final TranslatedValueSetView set; + + public TranslatedValueMapView(Map underlyingMap, + Function valueMapping) { + this.underlyingMap = underlyingMap; + this.valueMapping = valueMapping; + this.set = new TranslatedValueSetView(); + } + + @Override + public boolean containsKey(Object key) { + return underlyingMap.containsKey(key); + } + + @Override + public V get(Object key) { + if (!underlyingMap.containsKey(key)) return null; + B value = underlyingMap.get(key); + return valueMapping.apply(value); + } + + @Override + public Set> entrySet() { + return set; + } + + @Override + public boolean isEmpty() { + return underlyingMap.isEmpty(); + } +} diff --git a/server-common/src/test/java/org/apache/kafka/queue/KafkaEventQueueTest.java b/server-common/src/test/java/org/apache/kafka/queue/KafkaEventQueueTest.java new file mode 100644 index 0000000..c3ee62a --- /dev/null +++ b/server-common/src/test/java/org/apache/kafka/queue/KafkaEventQueueTest.java @@ -0,0 +1,243 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.queue; + +import java.util.Arrays; +import java.util.List; +import java.util.OptionalLong; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.RejectedExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicLong; +import java.util.function.Supplier; +import org.apache.kafka.common.errors.TimeoutException; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.common.utils.Time; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertThrows; + + +@Timeout(value = 60) +public class KafkaEventQueueTest { + private static class FutureEvent implements EventQueue.Event { + private final CompletableFuture future; + private final Supplier supplier; + + FutureEvent(CompletableFuture future, Supplier supplier) { + this.future = future; + this.supplier = supplier; + } + + @Override + public void run() throws Exception { + T value = supplier.get(); + future.complete(value); + } + + @Override + public void handleException(Throwable e) { + future.completeExceptionally(e); + } + } + + @Test + public void testCreateAndClose() throws Exception { + KafkaEventQueue queue = + new KafkaEventQueue(Time.SYSTEM, new LogContext(), "testCreateAndClose"); + queue.close(); + } + + @Test + public void testHandleEvents() throws Exception { + KafkaEventQueue queue = + new KafkaEventQueue(Time.SYSTEM, new LogContext(), "testHandleEvents"); + AtomicInteger numEventsExecuted = new AtomicInteger(0); + CompletableFuture future1 = new CompletableFuture<>(); + queue.prepend(new FutureEvent<>(future1, () -> { + assertEquals(1, numEventsExecuted.incrementAndGet()); + return 1; + })); + CompletableFuture future2 = new CompletableFuture<>(); + queue.appendWithDeadline(Time.SYSTEM.nanoseconds() + TimeUnit.SECONDS.toNanos(60), + new FutureEvent<>(future2, () -> { + assertEquals(2, numEventsExecuted.incrementAndGet()); + return 2; + })); + CompletableFuture future3 = new CompletableFuture<>(); + queue.append(new FutureEvent<>(future3, () -> { + assertEquals(3, numEventsExecuted.incrementAndGet()); + return 3; + })); + assertEquals(Integer.valueOf(1), future1.get()); + assertEquals(Integer.valueOf(3), future3.get()); + assertEquals(Integer.valueOf(2), future2.get()); + CompletableFuture future4 = new CompletableFuture<>(); + queue.appendWithDeadline(Time.SYSTEM.nanoseconds() + TimeUnit.SECONDS.toNanos(60), + new FutureEvent<>(future4, () -> { + assertEquals(4, numEventsExecuted.incrementAndGet()); + return 4; + })); + future4.get(); + queue.beginShutdown("testHandleEvents"); + queue.close(); + } + + @Test + public void testTimeouts() throws Exception { + KafkaEventQueue queue = + new KafkaEventQueue(Time.SYSTEM, new LogContext(), "testTimeouts"); + AtomicInteger numEventsExecuted = new AtomicInteger(0); + CompletableFuture future1 = new CompletableFuture<>(); + queue.append(new FutureEvent<>(future1, () -> { + assertEquals(1, numEventsExecuted.incrementAndGet()); + return 1; + })); + CompletableFuture future2 = new CompletableFuture<>(); + queue.append(new FutureEvent<>(future2, () -> { + assertEquals(2, numEventsExecuted.incrementAndGet()); + Time.SYSTEM.sleep(1); + return 
2; + })); + CompletableFuture future3 = new CompletableFuture<>(); + queue.appendWithDeadline(Time.SYSTEM.nanoseconds() + 1, + new FutureEvent<>(future3, () -> { + numEventsExecuted.incrementAndGet(); + return 3; + })); + CompletableFuture future4 = new CompletableFuture<>(); + queue.append(new FutureEvent<>(future4, () -> { + numEventsExecuted.incrementAndGet(); + return 4; + })); + assertEquals(Integer.valueOf(1), future1.get()); + assertEquals(Integer.valueOf(2), future2.get()); + assertEquals(Integer.valueOf(4), future4.get()); + assertEquals(TimeoutException.class, + assertThrows(ExecutionException.class, + () -> future3.get()).getCause().getClass()); + queue.close(); + assertEquals(3, numEventsExecuted.get()); + } + + @Test + public void testScheduleDeferred() throws Exception { + KafkaEventQueue queue = + new KafkaEventQueue(Time.SYSTEM, new LogContext(), "testAppendDeferred"); + + // Wait for the deferred event to happen after the non-deferred event. + // It may not happen every time, so we keep trying until it does. + AtomicLong counter = new AtomicLong(0); + CompletableFuture future1; + do { + counter.addAndGet(1); + future1 = new CompletableFuture<>(); + queue.scheduleDeferred(null, + __ -> OptionalLong.of(Time.SYSTEM.nanoseconds() + 1000000), + new FutureEvent<>(future1, () -> counter.get() % 2 == 0)); + CompletableFuture future2 = new CompletableFuture<>(); + queue.append(new FutureEvent<>(future2, () -> counter.addAndGet(1))); + future2.get(); + } while (!future1.get()); + queue.close(); + } + + private final static long ONE_HOUR_NS = TimeUnit.NANOSECONDS.convert(1, TimeUnit.HOURS); + + @Test + public void testScheduleDeferredWithTagReplacement() throws Exception { + KafkaEventQueue queue = new KafkaEventQueue(Time.SYSTEM, new LogContext(), + "testScheduleDeferredWithTagReplacement"); + + AtomicInteger ai = new AtomicInteger(0); + CompletableFuture future1 = new CompletableFuture<>(); + queue.scheduleDeferred("foo", + __ -> OptionalLong.of(Time.SYSTEM.nanoseconds() + ONE_HOUR_NS), + new FutureEvent<>(future1, () -> ai.addAndGet(1000))); + CompletableFuture future2 = new CompletableFuture<>(); + queue.scheduleDeferred("foo", prev -> OptionalLong.of(prev.orElse(0) - ONE_HOUR_NS), + new FutureEvent<>(future2, () -> ai.addAndGet(1))); + assertFalse(future1.isDone()); + assertEquals(Integer.valueOf(1), future2.get()); + assertEquals(1, ai.get()); + queue.close(); + } + + @Test + public void testDeferredIsQueuedAfterTriggering() throws Exception { + MockTime time = new MockTime(0, 100000, 1); + KafkaEventQueue queue = new KafkaEventQueue(time, new LogContext(), + "testDeferredIsQueuedAfterTriggering"); + AtomicInteger count = new AtomicInteger(0); + List> futures = Arrays.asList( + new CompletableFuture(), + new CompletableFuture(), + new CompletableFuture()); + queue.scheduleDeferred("foo", __ -> OptionalLong.of(2L), + new FutureEvent<>(futures.get(0), () -> count.getAndIncrement())); + queue.append(new FutureEvent<>(futures.get(1), () -> count.getAndAdd(1))); + assertEquals(Integer.valueOf(0), futures.get(1).get()); + time.sleep(1); + queue.append(new FutureEvent<>(futures.get(2), () -> count.getAndAdd(1))); + assertEquals(Integer.valueOf(1), futures.get(0).get()); + assertEquals(Integer.valueOf(2), futures.get(2).get()); + queue.close(); + } + + @Test + public void testShutdownBeforeDeferred() throws Exception { + KafkaEventQueue queue = new KafkaEventQueue(Time.SYSTEM, new LogContext(), + "testShutdownBeforeDeferred"); + final AtomicInteger count = new AtomicInteger(0); + 
CompletableFuture future = new CompletableFuture<>(); + queue.scheduleDeferred("myDeferred", + __ -> OptionalLong.of(Time.SYSTEM.nanoseconds() + TimeUnit.HOURS.toNanos(1)), + new FutureEvent<>(future, () -> count.getAndAdd(1))); + queue.beginShutdown("testShutdownBeforeDeferred"); + assertThrows(ExecutionException.class, () -> future.get()); + assertEquals(0, count.get()); + queue.close(); + } + + @Test + public void testRejectedExecutionExecption() throws Exception { + KafkaEventQueue queue = new KafkaEventQueue(Time.SYSTEM, new LogContext(), + "testRejectedExecutionExecption"); + queue.close(); + CompletableFuture future = new CompletableFuture<>(); + queue.append(new EventQueue.Event() { + @Override + public void run() throws Exception { + future.complete(null); + } + + @Override + public void handleException(Throwable e) { + future.completeExceptionally(e); + } + }); + assertEquals(RejectedExecutionException.class, assertThrows( + ExecutionException.class, () -> future.get()).getCause().getClass()); + } +} diff --git a/server-common/src/test/java/org/apache/kafka/server/util/TranslatedValueMapViewTest.java b/server-common/src/test/java/org/apache/kafka/server/util/TranslatedValueMapViewTest.java new file mode 100644 index 0000000..cc8feea --- /dev/null +++ b/server-common/src/test/java/org/apache/kafka/server/util/TranslatedValueMapViewTest.java @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.server.util; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; + +import java.util.AbstractMap.SimpleImmutableEntry; +import java.util.HashMap; +import java.util.Iterator; +import java.util.Map; +import java.util.Map.Entry; +import java.util.TreeMap; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + + +@Timeout(value = 60) +public class TranslatedValueMapViewTest { + private static Map createTestMap() { + Map testMap = new TreeMap<>(); + testMap.put("foo", 2); + testMap.put("bar", 3); + testMap.put("baz", 5); + return testMap; + } + + @Test + public void testContains() { + Map underlying = createTestMap(); + TranslatedValueMapView view = + new TranslatedValueMapView<>(underlying, v -> v.toString()); + assertTrue(view.containsKey("foo")); + assertTrue(view.containsKey("bar")); + assertTrue(view.containsKey("baz")); + assertFalse(view.containsKey("quux")); + underlying.put("quux", 101); + assertTrue(view.containsKey("quux")); + } + + @Test + public void testIsEmptyAndSize() { + Map underlying = new HashMap<>(); + TranslatedValueMapView view = + new TranslatedValueMapView<>(underlying, v -> v.toString()); + assertTrue(view.isEmpty()); + assertEquals(0, view.size()); + underlying.put("quux", 101); + assertFalse(view.isEmpty()); + assertEquals(1, view.size()); + } + + @Test + public void testGet() { + Map underlying = createTestMap(); + TranslatedValueMapView view = + new TranslatedValueMapView<>(underlying, v -> v.toString()); + assertEquals("2", view.get("foo")); + assertEquals("3", view.get("bar")); + assertEquals("5", view.get("baz")); + assertNull(view.get("quux")); + underlying.put("quux", 101); + assertEquals("101", view.get("quux")); + } + + @Test + public void testEntrySet() { + Map underlying = createTestMap(); + TranslatedValueMapView view = + new TranslatedValueMapView<>(underlying, v -> v.toString()); + assertEquals(3, view.entrySet().size()); + assertFalse(view.entrySet().isEmpty()); + assertTrue(view.entrySet().contains(new SimpleImmutableEntry<>("foo", "2"))); + assertFalse(view.entrySet().contains(new SimpleImmutableEntry<>("bar", "4"))); + } + + @Test + public void testEntrySetIterator() { + Map underlying = createTestMap(); + TranslatedValueMapView view = + new TranslatedValueMapView<>(underlying, v -> v.toString()); + Iterator> iterator = view.entrySet().iterator(); + assertTrue(iterator.hasNext()); + assertEquals(new SimpleImmutableEntry<>("bar", "3"), iterator.next()); + assertTrue(iterator.hasNext()); + assertEquals(new SimpleImmutableEntry<>("baz", "5"), iterator.next()); + assertTrue(iterator.hasNext()); + assertEquals(new SimpleImmutableEntry<>("foo", "2"), iterator.next()); + assertFalse(iterator.hasNext()); + } +} diff --git a/settings.gradle b/settings.gradle new file mode 100644 index 0000000..6ebabce --- /dev/null +++ b/settings.gradle @@ -0,0 +1,56 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +include 'clients', + 'connect:api', + 'connect:basic-auth-extension', + 'connect:file', + 'connect:json', + 'connect:mirror', + 'connect:mirror-client', + 'connect:runtime', + 'connect:transforms', + 'core', + 'examples', + 'generator', + 'jmh-benchmarks', + 'log4j-appender', + 'metadata', + 'raft', + 'server-common', + 'shell', + 'storage', + 'storage:api', + 'streams', + 'streams:examples', + 'streams:streams-scala', + 'streams:test-utils', + 'streams:upgrade-system-tests-0100', + 'streams:upgrade-system-tests-0101', + 'streams:upgrade-system-tests-0102', + 'streams:upgrade-system-tests-0110', + 'streams:upgrade-system-tests-10', + 'streams:upgrade-system-tests-11', + 'streams:upgrade-system-tests-20', + 'streams:upgrade-system-tests-21', + 'streams:upgrade-system-tests-22', + 'streams:upgrade-system-tests-23', + 'streams:upgrade-system-tests-24', + 'streams:upgrade-system-tests-25', + 'streams:upgrade-system-tests-26', + 'streams:upgrade-system-tests-27', + 'streams:upgrade-system-tests-28', + 'tools', + 'trogdor' diff --git a/shell/src/main/java/org/apache/kafka/shell/CatCommandHandler.java b/shell/src/main/java/org/apache/kafka/shell/CatCommandHandler.java new file mode 100644 index 0000000..3fc9427 --- /dev/null +++ b/shell/src/main/java/org/apache/kafka/shell/CatCommandHandler.java @@ -0,0 +1,120 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.shell; + +import net.sourceforge.argparse4j.inf.ArgumentParser; +import net.sourceforge.argparse4j.inf.Namespace; +import org.apache.kafka.shell.MetadataNode.DirectoryNode; +import org.apache.kafka.shell.MetadataNode.FileNode; +import org.jline.reader.Candidate; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.PrintWriter; +import java.util.List; +import java.util.Objects; +import java.util.Optional; + +/** + * Implements the cat command. 
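+ * <p> + * For example, {@code cat /some/node} (the path shown is illustrative) prints the contents of each matching file node; + * a directory target is reported with an "Is a directory" message, and a missing target with "No such file or directory".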
+ */ +public final class CatCommandHandler implements Commands.Handler { + private static final Logger log = LoggerFactory.getLogger(CatCommandHandler.class); + + public final static Commands.Type TYPE = new CatCommandType(); + + public static class CatCommandType implements Commands.Type { + private CatCommandType() { + } + + @Override + public String name() { + return "cat"; + } + + @Override + public String description() { + return "Show the contents of metadata nodes."; + } + + @Override + public boolean shellOnly() { + return false; + } + + @Override + public void addArguments(ArgumentParser parser) { + parser.addArgument("targets"). + nargs("+"). + help("The metadata nodes to display."); + } + + @Override + public Commands.Handler createHandler(Namespace namespace) { + return new CatCommandHandler(namespace.getList("targets")); + } + + @Override + public void completeNext(MetadataNodeManager nodeManager, List nextWords, + List candidates) throws Exception { + CommandUtils.completePath(nodeManager, nextWords.get(nextWords.size() - 1), + candidates); + } + } + + private final List targets; + + public CatCommandHandler(List targets) { + this.targets = targets; + } + + @Override + public void run(Optional shell, + PrintWriter writer, + MetadataNodeManager manager) throws Exception { + log.trace("cat " + targets); + for (String target : targets) { + manager.visit(new GlobVisitor(target, entryOption -> { + if (entryOption.isPresent()) { + MetadataNode node = entryOption.get().node(); + if (node instanceof DirectoryNode) { + writer.println("cat: " + target + ": Is a directory"); + } else if (node instanceof FileNode) { + FileNode fileNode = (FileNode) node; + writer.println(fileNode.contents()); + } + } else { + writer.println("cat: " + target + ": No such file or directory."); + } + })); + } + } + + @Override + public int hashCode() { + return targets.hashCode(); + } + + @Override + public boolean equals(Object other) { + if (!(other instanceof CatCommandHandler)) return false; + CatCommandHandler o = (CatCommandHandler) other; + if (!Objects.equals(o.targets, targets)) return false; + return true; + } +} diff --git a/shell/src/main/java/org/apache/kafka/shell/CdCommandHandler.java b/shell/src/main/java/org/apache/kafka/shell/CdCommandHandler.java new file mode 100644 index 0000000..8d270e5 --- /dev/null +++ b/shell/src/main/java/org/apache/kafka/shell/CdCommandHandler.java @@ -0,0 +1,117 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.shell; + +import net.sourceforge.argparse4j.inf.ArgumentParser; +import net.sourceforge.argparse4j.inf.Namespace; +import org.apache.kafka.shell.MetadataNode.DirectoryNode; +import org.jline.reader.Candidate; + +import java.io.PrintWriter; +import java.util.List; +import java.util.Optional; +import java.util.function.Consumer; + +/** + * Implements the cd command. + */ +public final class CdCommandHandler implements Commands.Handler { + public final static Commands.Type TYPE = new CdCommandType(); + + public static class CdCommandType implements Commands.Type { + private CdCommandType() { + } + + @Override + public String name() { + return "cd"; + } + + @Override + public String description() { + return "Set the current working directory."; + } + + @Override + public boolean shellOnly() { + return true; + } + + @Override + public void addArguments(ArgumentParser parser) { + parser.addArgument("target"). + nargs("?"). + help("The directory to change to."); + } + + @Override + public Commands.Handler createHandler(Namespace namespace) { + return new CdCommandHandler(Optional.ofNullable(namespace.getString("target"))); + } + + @Override + public void completeNext(MetadataNodeManager nodeManager, List nextWords, + List candidates) throws Exception { + if (nextWords.size() == 1) { + CommandUtils.completePath(nodeManager, nextWords.get(0), candidates); + } + } + } + + private final Optional target; + + public CdCommandHandler(Optional target) { + this.target = target; + } + + @Override + public void run(Optional shell, + PrintWriter writer, + MetadataNodeManager manager) throws Exception { + String effectiveTarget = target.orElse("/"); + manager.visit(new Consumer() { + @Override + public void accept(MetadataNodeManager.Data data) { + new GlobVisitor(effectiveTarget, entryOption -> { + if (entryOption.isPresent()) { + if (!(entryOption.get().node() instanceof DirectoryNode)) { + writer.println("cd: " + effectiveTarget + ": not a directory."); + } else { + data.setWorkingDirectory(entryOption.get().absolutePath()); + } + } else { + writer.println("cd: " + effectiveTarget + ": no such directory."); + } + }).accept(data); + } + }); + } + + @Override + public int hashCode() { + return target.hashCode(); + } + + @Override + public boolean equals(Object other) { + if (!(other instanceof CdCommandHandler)) return false; + CdCommandHandler o = (CdCommandHandler) other; + if (!o.target.equals(target)) return false; + return true; + } +} diff --git a/shell/src/main/java/org/apache/kafka/shell/CommandUtils.java b/shell/src/main/java/org/apache/kafka/shell/CommandUtils.java new file mode 100644 index 0000000..5febfb8 --- /dev/null +++ b/shell/src/main/java/org/apache/kafka/shell/CommandUtils.java @@ -0,0 +1,148 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.shell; + +import org.apache.kafka.shell.MetadataNode.DirectoryNode; +import org.jline.reader.Candidate; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map.Entry; + +/** + * Utility functions for command handlers. + */ +public final class CommandUtils { + /** + * Convert a list of paths into the effective list of paths which should be used. + * Empty strings will be removed. If no paths are given, the current working + * directory will be used. + * + * @param paths The input paths. Non-null. + * + * @return The output paths. + */ + public static List getEffectivePaths(List paths) { + List effectivePaths = new ArrayList<>(); + for (String path : paths) { + if (!path.isEmpty()) { + effectivePaths.add(path); + } + } + if (effectivePaths.isEmpty()) { + effectivePaths.add("."); + } + return effectivePaths; + } + + /** + * Generate a list of potential completions for a prefix of a command name. + * + * @param commandPrefix The command prefix. Non-null. + * @param candidates The list to add the output completions to. + */ + public static void completeCommand(String commandPrefix, List candidates) { + String command = Commands.TYPES.ceilingKey(commandPrefix); + while (command != null && command.startsWith(commandPrefix)) { + candidates.add(new Candidate(command)); + command = Commands.TYPES.higherKey(command); + } + } + + /** + * Convert a path to a list of path components. + * Multiple slashes in a row are treated the same as a single slash. + * Trailing slashes are ignored. + */ + public static List splitPath(String path) { + List results = new ArrayList<>(); + String[] components = path.split("/"); + for (int i = 0; i < components.length; i++) { + if (!components[i].isEmpty()) { + results.add(components[i]); + } + } + return results; + } + + public static List stripDotPathComponents(List input) { + List output = new ArrayList<>(); + for (String string : input) { + if (string.equals("..")) { + if (output.size() > 0) { + output.remove(output.size() - 1); + } + } else if (!string.equals(".")) { + output.add(string); + } + } + return output; + } + + /** + * Generate a list of potential completions for a path. + * + * @param nodeManager The NodeManager. + * @param pathPrefix The path prefix. Non-null. + * @param candidates The list to add the output completions to. + */ + public static void completePath(MetadataNodeManager nodeManager, + String pathPrefix, + List candidates) throws Exception { + nodeManager.visit(data -> { + String absolutePath = pathPrefix.startsWith("/") ? + pathPrefix : data.workingDirectory() + "/" + pathPrefix; + List pathComponents = stripDotPathComponents(splitPath(absolutePath)); + DirectoryNode directory = data.root(); + int numDirectories = pathPrefix.endsWith("/") ? 
+ pathComponents.size() : pathComponents.size() - 1; + for (int i = 0; i < numDirectories; i++) { + MetadataNode node = directory.child(pathComponents.get(i)); + if (!(node instanceof DirectoryNode)) { + return; + } + directory = (DirectoryNode) node; + } + String lastComponent = ""; + if (numDirectories >= 0 && numDirectories < pathComponents.size()) { + lastComponent = pathComponents.get(numDirectories); + } + Entry candidate = + directory.children().ceilingEntry(lastComponent); + String effectivePrefix; + int lastSlash = pathPrefix.lastIndexOf('/'); + if (lastSlash < 0) { + effectivePrefix = ""; + } else { + effectivePrefix = pathPrefix.substring(0, lastSlash + 1); + } + while (candidate != null && candidate.getKey().startsWith(lastComponent)) { + StringBuilder candidateBuilder = new StringBuilder(); + candidateBuilder.append(effectivePrefix).append(candidate.getKey()); + boolean complete = true; + if (candidate.getValue() instanceof DirectoryNode) { + candidateBuilder.append("/"); + complete = false; + } + candidates.add(new Candidate(candidateBuilder.toString(), + candidateBuilder.toString(), null, null, null, null, complete)); + candidate = directory.children().higherEntry(candidate.getKey()); + } + }); + } +} diff --git a/shell/src/main/java/org/apache/kafka/shell/Commands.java b/shell/src/main/java/org/apache/kafka/shell/Commands.java new file mode 100644 index 0000000..db16411 --- /dev/null +++ b/shell/src/main/java/org/apache/kafka/shell/Commands.java @@ -0,0 +1,154 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.shell; + +import net.sourceforge.argparse4j.ArgumentParsers; +import net.sourceforge.argparse4j.inf.ArgumentParser; +import net.sourceforge.argparse4j.inf.ArgumentParserException; +import net.sourceforge.argparse4j.inf.Namespace; +import net.sourceforge.argparse4j.inf.Subparser; +import net.sourceforge.argparse4j.inf.Subparsers; +import net.sourceforge.argparse4j.internal.HelpScreenException; +import org.jline.reader.Candidate; + +import java.io.PrintWriter; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.NavigableMap; +import java.util.Optional; +import java.util.TreeMap; + +/** + * The commands for the Kafka metadata tool. + */ +public final class Commands { + /** + * A map from command names to command types. 
+ */ + static final NavigableMap TYPES; + + static { + TreeMap typesMap = new TreeMap<>(); + for (Type type : Arrays.asList( + CatCommandHandler.TYPE, + CdCommandHandler.TYPE, + ExitCommandHandler.TYPE, + FindCommandHandler.TYPE, + HelpCommandHandler.TYPE, + HistoryCommandHandler.TYPE, + LsCommandHandler.TYPE, + ManCommandHandler.TYPE, + PwdCommandHandler.TYPE)) { + typesMap.put(type.name(), type); + } + TYPES = Collections.unmodifiableNavigableMap(typesMap); + } + + /** + * Command handler objects are instantiated with specific arguments to + * execute commands. + */ + public interface Handler { + void run(Optional shell, + PrintWriter writer, + MetadataNodeManager manager) throws Exception; + } + + /** + * An object which describes a type of command handler. This includes + * information like its name, help text, and whether it should be accessible + * from non-interactive mode. + */ + public interface Type { + String name(); + String description(); + boolean shellOnly(); + void addArguments(ArgumentParser parser); + Handler createHandler(Namespace namespace); + void completeNext(MetadataNodeManager nodeManager, + List nextWords, + List candidates) throws Exception; + } + + private final ArgumentParser parser; + + /** + * Create the commands instance. + * + * @param addShellCommands True if we should include the shell-only commands. + */ + public Commands(boolean addShellCommands) { + this.parser = ArgumentParsers.newArgumentParser("", false); + Subparsers subparsers = this.parser.addSubparsers().dest("command"); + for (Type type : TYPES.values()) { + if (addShellCommands || !type.shellOnly()) { + Subparser subParser = subparsers.addParser(type.name()); + subParser.help(type.description()); + type.addArguments(subParser); + } + } + } + + ArgumentParser parser() { + return parser; + } + + /** + * Handle the given command. + * + * In general this function should not throw exceptions. Instead, it should + * return ErroneousCommandHandler if the input was invalid. + * + * @param arguments The command line arguments. + * @return The command handler. 
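+ * For example, {@code parseCommand(Arrays.asList("cat", "/some/node"))} (the path is illustrative) returns a + * {@code CatCommandHandler} for that target, while an empty or blank argument list yields a {@code NoOpCommandHandler}.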
+ */ + public Handler parseCommand(List arguments) { + List trimmedArguments = new ArrayList<>(arguments); + while (true) { + if (trimmedArguments.isEmpty()) { + return new NoOpCommandHandler(); + } + String last = trimmedArguments.get(trimmedArguments.size() - 1); + if (!last.isEmpty()) { + break; + } + trimmedArguments.remove(trimmedArguments.size() - 1); + } + Namespace namespace; + try { + namespace = parser.parseArgs(trimmedArguments.toArray(new String[0])); + } catch (HelpScreenException e) { + return new NoOpCommandHandler(); + } catch (ArgumentParserException e) { + return new ErroneousCommandHandler(e.getMessage()); + } + String command = namespace.get("command"); + if (!command.equals(trimmedArguments.get(0))) { + return new ErroneousCommandHandler("invalid choice: '" + + trimmedArguments.get(0) + "': did you mean '" + command + "'?"); + } + Type type = TYPES.get(command); + if (type == null) { + return new ErroneousCommandHandler("Unknown command specified: " + command); + } else { + return type.createHandler(namespace); + } + } +} diff --git a/shell/src/main/java/org/apache/kafka/shell/ErroneousCommandHandler.java b/shell/src/main/java/org/apache/kafka/shell/ErroneousCommandHandler.java new file mode 100644 index 0000000..d52c55f --- /dev/null +++ b/shell/src/main/java/org/apache/kafka/shell/ErroneousCommandHandler.java @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.shell; + +import java.io.PrintWriter; +import java.util.Objects; +import java.util.Optional; + +/** + * Handles erroneous commands. + */ +public final class ErroneousCommandHandler implements Commands.Handler { + private final String message; + + public ErroneousCommandHandler(String message) { + this.message = message; + } + + @Override + public void run(Optional shell, + PrintWriter writer, + MetadataNodeManager manager) { + writer.println(message); + } + + @Override + public int hashCode() { + return Objects.hashCode(message); + } + + @Override + public boolean equals(Object other) { + if (!(other instanceof ErroneousCommandHandler)) return false; + ErroneousCommandHandler o = (ErroneousCommandHandler) other; + if (!Objects.equals(o.message, message)) return false; + return true; + } + + @Override + public String toString() { + return "ErroneousCommandHandler(" + message + ")"; + } +} diff --git a/shell/src/main/java/org/apache/kafka/shell/ExitCommandHandler.java b/shell/src/main/java/org/apache/kafka/shell/ExitCommandHandler.java new file mode 100644 index 0000000..2b11b35 --- /dev/null +++ b/shell/src/main/java/org/apache/kafka/shell/ExitCommandHandler.java @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.shell; + +import net.sourceforge.argparse4j.inf.ArgumentParser; +import net.sourceforge.argparse4j.inf.Namespace; +import org.apache.kafka.common.utils.Exit; +import org.jline.reader.Candidate; + +import java.io.PrintWriter; +import java.util.List; +import java.util.Optional; + +/** + * Implements the exit command. + */ +public final class ExitCommandHandler implements Commands.Handler { + public final static Commands.Type TYPE = new ExitCommandType(); + + public static class ExitCommandType implements Commands.Type { + private ExitCommandType() { + } + + @Override + public String name() { + return "exit"; + } + + @Override + public String description() { + return "Exit the metadata shell."; + } + + @Override + public boolean shellOnly() { + return true; + } + + @Override + public void addArguments(ArgumentParser parser) { + // nothing to do + } + + @Override + public Commands.Handler createHandler(Namespace namespace) { + return new ExitCommandHandler(); + } + + @Override + public void completeNext(MetadataNodeManager nodeManager, List nextWords, + List candidates) throws Exception { + // nothing to do + } + } + + @Override + public void run(Optional shell, + PrintWriter writer, + MetadataNodeManager manager) { + Exit.exit(0); + } + + @Override + public int hashCode() { + return 0; + } + + @Override + public boolean equals(Object other) { + if (!(other instanceof ExitCommandHandler)) return false; + return true; + } +} diff --git a/shell/src/main/java/org/apache/kafka/shell/FindCommandHandler.java b/shell/src/main/java/org/apache/kafka/shell/FindCommandHandler.java new file mode 100644 index 0000000..6d9ae44 --- /dev/null +++ b/shell/src/main/java/org/apache/kafka/shell/FindCommandHandler.java @@ -0,0 +1,121 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.shell; + +import net.sourceforge.argparse4j.inf.ArgumentParser; +import net.sourceforge.argparse4j.inf.Namespace; +import org.apache.kafka.shell.MetadataNode.DirectoryNode; +import org.jline.reader.Candidate; + +import java.io.PrintWriter; +import java.util.List; +import java.util.Map.Entry; +import java.util.Objects; +import java.util.Optional; + +/** + * Implements the find command. + */ +public final class FindCommandHandler implements Commands.Handler { + public final static Commands.Type TYPE = new FindCommandType(); + + public static class FindCommandType implements Commands.Type { + private FindCommandType() { + } + + @Override + public String name() { + return "find"; + } + + @Override + public String description() { + return "Search for nodes in the directory hierarchy."; + } + + @Override + public boolean shellOnly() { + return false; + } + + @Override + public void addArguments(ArgumentParser parser) { + parser.addArgument("paths"). + nargs("*"). + help("The paths to start at."); + } + + @Override + public Commands.Handler createHandler(Namespace namespace) { + return new FindCommandHandler(namespace.getList("paths")); + } + + @Override + public void completeNext(MetadataNodeManager nodeManager, List nextWords, + List candidates) throws Exception { + CommandUtils.completePath(nodeManager, nextWords.get(nextWords.size() - 1), + candidates); + } + } + + private final List paths; + + public FindCommandHandler(List paths) { + this.paths = paths; + } + + @Override + public void run(Optional shell, + PrintWriter writer, + MetadataNodeManager manager) throws Exception { + for (String path : CommandUtils.getEffectivePaths(paths)) { + manager.visit(new GlobVisitor(path, entryOption -> { + if (entryOption.isPresent()) { + find(writer, path, entryOption.get().node()); + } else { + writer.println("find: " + path + ": no such file or directory."); + } + })); + } + } + + private void find(PrintWriter writer, String path, MetadataNode node) { + writer.println(path); + if (node instanceof DirectoryNode) { + DirectoryNode directory = (DirectoryNode) node; + for (Entry entry : directory.children().entrySet()) { + String nextPath = path.equals("/") ? + path + entry.getKey() : path + "/" + entry.getKey(); + find(writer, nextPath, entry.getValue()); + } + } + } + + @Override + public int hashCode() { + return Objects.hashCode(paths); + } + + @Override + public boolean equals(Object other) { + if (!(other instanceof FindCommandHandler)) return false; + FindCommandHandler o = (FindCommandHandler) other; + if (!Objects.equals(o.paths, paths)) return false; + return true; + } +} diff --git a/shell/src/main/java/org/apache/kafka/shell/GlobComponent.java b/shell/src/main/java/org/apache/kafka/shell/GlobComponent.java new file mode 100644 index 0000000..b93382b --- /dev/null +++ b/shell/src/main/java/org/apache/kafka/shell/GlobComponent.java @@ -0,0 +1,179 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.shell; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.regex.Pattern; + +/** + * Implements a per-path-component glob. + */ +public final class GlobComponent { + private static final Logger log = LoggerFactory.getLogger(GlobComponent.class); + + /** + * Returns true if the character is a special character for regular expressions. + */ + private static boolean isRegularExpressionSpecialCharacter(char ch) { + switch (ch) { + case '$': + case '(': + case ')': + case '+': + case '.': + case '[': + case ']': + case '^': + case '{': + case '|': + return true; + default: + break; + } + return false; + } + + /** + * Returns true if the character is a special character for globs. + */ + private static boolean isGlobSpecialCharacter(char ch) { + switch (ch) { + case '*': + case '?': + case '\\': + case '{': + case '}': + return true; + default: + break; + } + return false; + } + + /** + * Converts a glob string to a regular expression string. + * Returns null if the glob should be handled as a literal (can only match one string). + * Throws an exception if the glob is malformed. + */ + static String toRegularExpression(String glob) { + StringBuilder output = new StringBuilder("^"); + boolean literal = true; + boolean processingGroup = false; + + for (int i = 0; i < glob.length(); ) { + char c = glob.charAt(i++); + switch (c) { + case '?': + literal = false; + output.append("."); + break; + case '*': + literal = false; + output.append(".*"); + break; + case '\\': + if (i == glob.length()) { + output.append(c); + } else { + char next = glob.charAt(i); + i++; + if (isGlobSpecialCharacter(next) || + isRegularExpressionSpecialCharacter(next)) { + output.append('\\'); + } + output.append(next); + } + break; + case '{': + if (processingGroup) { + throw new RuntimeException("Can't nest glob groups."); + } + literal = false; + output.append("(?:(?:"); + processingGroup = true; + break; + case ',': + if (processingGroup) { + literal = false; + output.append(")|(?:"); + } else { + output.append(c); + } + break; + case '}': + if (processingGroup) { + literal = false; + output.append("))"); + processingGroup = false; + } else { + output.append(c); + } + break; + // TODO: handle character ranges + default: + if (isRegularExpressionSpecialCharacter(c)) { + output.append('\\'); + } + output.append(c); + } + } + if (processingGroup) { + throw new RuntimeException("Unterminated glob group."); + } + if (literal) { + return null; + } + output.append('$'); + return output.toString(); + } + + private final String component; + private final Pattern pattern; + + public GlobComponent(String component) { + this.component = component; + Pattern newPattern = null; + try { + String regularExpression = toRegularExpression(component); + if (regularExpression != null) { + newPattern = Pattern.compile(regularExpression); + } + } catch (RuntimeException e) { + log.debug("Invalid glob pattern: " + e.getMessage()); + } + this.pattern = newPattern; + } + + public String component() { + return component; + } + + public boolean literal() { + return pattern == 
null; + } + + public boolean matches(String nodeName) { + if (pattern == null) { + return component.equals(nodeName); + } else { + return pattern.matcher(nodeName).matches(); + } + } +} diff --git a/shell/src/main/java/org/apache/kafka/shell/GlobVisitor.java b/shell/src/main/java/org/apache/kafka/shell/GlobVisitor.java new file mode 100644 index 0000000..8081b7e --- /dev/null +++ b/shell/src/main/java/org/apache/kafka/shell/GlobVisitor.java @@ -0,0 +1,148 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.shell; + +import java.util.Arrays; +import java.util.List; +import java.util.Map.Entry; +import java.util.Objects; +import java.util.Optional; +import java.util.function.Consumer; + +/** + * Visits metadata paths based on a glob string. + */ +public final class GlobVisitor implements Consumer { + private final String glob; + private final Consumer> handler; + + public GlobVisitor(String glob, + Consumer> handler) { + this.glob = glob; + this.handler = handler; + } + + public static class MetadataNodeInfo { + private final String[] path; + private final MetadataNode node; + + MetadataNodeInfo(String[] path, MetadataNode node) { + this.path = path; + this.node = node; + } + + public String[] path() { + return path; + } + + public MetadataNode node() { + return node; + } + + public String lastPathComponent() { + if (path.length == 0) { + return "/"; + } else { + return path[path.length - 1]; + } + } + + public String absolutePath() { + return "/" + String.join("/", path); + } + + @Override + public int hashCode() { + return Objects.hash(path, node); + } + + @Override + public boolean equals(Object o) { + if (!(o instanceof MetadataNodeInfo)) return false; + MetadataNodeInfo other = (MetadataNodeInfo) o; + if (!Arrays.equals(path, other.path)) return false; + if (!node.equals(other.node)) return false; + return true; + } + + @Override + public String toString() { + StringBuilder bld = new StringBuilder("MetadataNodeInfo(path="); + for (int i = 0; i < path.length; i++) { + bld.append("/"); + bld.append(path[i]); + } + bld.append(", node=").append(node).append(")"); + return bld.toString(); + } + } + + @Override + public void accept(MetadataNodeManager.Data data) { + String fullGlob = glob.startsWith("/") ? 
glob : + data.workingDirectory() + "/" + glob; + List globComponents = + CommandUtils.stripDotPathComponents(CommandUtils.splitPath(fullGlob)); + if (!accept(globComponents, 0, data.root(), new String[0])) { + handler.accept(Optional.empty()); + } + } + + private boolean accept(List globComponents, + int componentIndex, + MetadataNode node, + String[] path) { + if (componentIndex >= globComponents.size()) { + handler.accept(Optional.of(new MetadataNodeInfo(path, node))); + return true; + } + String globComponentString = globComponents.get(componentIndex); + GlobComponent globComponent = new GlobComponent(globComponentString); + if (globComponent.literal()) { + if (!(node instanceof MetadataNode.DirectoryNode)) { + return false; + } + MetadataNode.DirectoryNode directory = (MetadataNode.DirectoryNode) node; + MetadataNode child = directory.child(globComponent.component()); + if (child == null) { + return false; + } + String[] newPath = new String[path.length + 1]; + System.arraycopy(path, 0, newPath, 0, path.length); + newPath[path.length] = globComponent.component(); + return accept(globComponents, componentIndex + 1, child, newPath); + } + if (!(node instanceof MetadataNode.DirectoryNode)) { + return false; + } + MetadataNode.DirectoryNode directory = (MetadataNode.DirectoryNode) node; + boolean matchedAny = false; + for (Entry entry : directory.children().entrySet()) { + String nodeName = entry.getKey(); + if (globComponent.matches(nodeName)) { + String[] newPath = new String[path.length + 1]; + System.arraycopy(path, 0, newPath, 0, path.length); + newPath[path.length] = nodeName; + if (accept(globComponents, componentIndex + 1, entry.getValue(), newPath)) { + matchedAny = true; + } + } + } + return matchedAny; + } +} diff --git a/shell/src/main/java/org/apache/kafka/shell/HelpCommandHandler.java b/shell/src/main/java/org/apache/kafka/shell/HelpCommandHandler.java new file mode 100644 index 0000000..829274e --- /dev/null +++ b/shell/src/main/java/org/apache/kafka/shell/HelpCommandHandler.java @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.shell; + +import net.sourceforge.argparse4j.inf.ArgumentParser; +import net.sourceforge.argparse4j.inf.Namespace; +import org.jline.reader.Candidate; + +import java.io.PrintWriter; +import java.util.List; +import java.util.Optional; + +/** + * Implements the help command. 
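+ *
+ * Running {@code help} at the shell prompt prints a short welcome banner followed by
+ * the generated usage text for every registered command.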
+ */ +public final class HelpCommandHandler implements Commands.Handler { + public final static Commands.Type TYPE = new HelpCommandType(); + + public static class HelpCommandType implements Commands.Type { + private HelpCommandType() { + } + + @Override + public String name() { + return "help"; + } + + @Override + public String description() { + return "Display this help message."; + } + + @Override + public boolean shellOnly() { + return true; + } + + @Override + public void addArguments(ArgumentParser parser) { + // nothing to do + } + + @Override + public Commands.Handler createHandler(Namespace namespace) { + return new HelpCommandHandler(); + } + + @Override + public void completeNext(MetadataNodeManager nodeManager, List nextWords, + List candidates) throws Exception { + // nothing to do + } + } + + @Override + public void run(Optional shell, + PrintWriter writer, + MetadataNodeManager manager) { + writer.printf("Welcome to the Apache Kafka metadata shell.%n%n"); + new Commands(true).parser().printHelp(writer); + } + + @Override + public int hashCode() { + return 0; + } + + @Override + public boolean equals(Object other) { + if (!(other instanceof HelpCommandHandler)) return false; + return true; + } +} diff --git a/shell/src/main/java/org/apache/kafka/shell/HistoryCommandHandler.java b/shell/src/main/java/org/apache/kafka/shell/HistoryCommandHandler.java new file mode 100644 index 0000000..edf9def --- /dev/null +++ b/shell/src/main/java/org/apache/kafka/shell/HistoryCommandHandler.java @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.shell; + +import net.sourceforge.argparse4j.inf.ArgumentParser; +import net.sourceforge.argparse4j.inf.Namespace; +import org.jline.reader.Candidate; + +import java.io.PrintWriter; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +/** + * Implements the history command. + */ +public final class HistoryCommandHandler implements Commands.Handler { + public final static Commands.Type TYPE = new HistoryCommandType(); + + public static class HistoryCommandType implements Commands.Type { + private HistoryCommandType() { + } + + @Override + public String name() { + return "history"; + } + + @Override + public String description() { + return "Print command history."; + } + + @Override + public boolean shellOnly() { + return true; + } + + @Override + public void addArguments(ArgumentParser parser) { + parser.addArgument("numEntriesToShow"). + nargs("?"). + type(Integer.class). 
+ help("The number of entries to show."); + } + + @Override + public Commands.Handler createHandler(Namespace namespace) { + Integer numEntriesToShow = namespace.getInt("numEntriesToShow"); + return new HistoryCommandHandler(numEntriesToShow == null ? + Integer.MAX_VALUE : numEntriesToShow); + } + + @Override + public void completeNext(MetadataNodeManager nodeManager, List nextWords, + List candidates) throws Exception { + // nothing to do + } + } + + private final int numEntriesToShow; + + public HistoryCommandHandler(int numEntriesToShow) { + this.numEntriesToShow = numEntriesToShow; + } + + @Override + public void run(Optional shell, + PrintWriter writer, + MetadataNodeManager manager) throws Exception { + if (!shell.isPresent()) { + throw new RuntimeException("The history command requires a shell."); + } + Iterator> iter = shell.get().history(numEntriesToShow); + while (iter.hasNext()) { + Map.Entry entry = iter.next(); + writer.printf("% 5d %s%n", entry.getKey(), entry.getValue()); + } + } + + @Override + public int hashCode() { + return numEntriesToShow; + } + + @Override + public boolean equals(Object other) { + if (!(other instanceof HistoryCommandHandler)) return false; + HistoryCommandHandler o = (HistoryCommandHandler) other; + return o.numEntriesToShow == numEntriesToShow; + } +} diff --git a/shell/src/main/java/org/apache/kafka/shell/InteractiveShell.java b/shell/src/main/java/org/apache/kafka/shell/InteractiveShell.java new file mode 100644 index 0000000..aa4d4ea --- /dev/null +++ b/shell/src/main/java/org/apache/kafka/shell/InteractiveShell.java @@ -0,0 +1,172 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.shell; + +import org.jline.reader.Candidate; +import org.jline.reader.Completer; +import org.jline.reader.EndOfFileException; +import org.jline.reader.History; +import org.jline.reader.LineReader; +import org.jline.reader.LineReaderBuilder; +import org.jline.reader.ParsedLine; +import org.jline.reader.Parser; +import org.jline.reader.UserInterruptException; +import org.jline.reader.impl.DefaultParser; +import org.jline.reader.impl.history.DefaultHistory; +import org.jline.terminal.Terminal; +import org.jline.terminal.TerminalBuilder; + +import java.io.IOException; +import java.util.AbstractMap; +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; +import java.util.Map.Entry; +import java.util.NoSuchElementException; +import java.util.Optional; + +/** + * The Kafka metadata shell. 
+ */ +public final class InteractiveShell implements AutoCloseable { + static class MetadataShellCompleter implements Completer { + private final MetadataNodeManager nodeManager; + + MetadataShellCompleter(MetadataNodeManager nodeManager) { + this.nodeManager = nodeManager; + } + + @Override + public void complete(LineReader reader, ParsedLine line, List candidates) { + if (line.words().size() == 0) { + CommandUtils.completeCommand("", candidates); + } else if (line.words().size() == 1) { + CommandUtils.completeCommand(line.words().get(0), candidates); + } else { + Iterator iter = line.words().iterator(); + String command = iter.next(); + List nextWords = new ArrayList<>(); + while (iter.hasNext()) { + nextWords.add(iter.next()); + } + Commands.Type type = Commands.TYPES.get(command); + if (type == null) { + return; + } + try { + type.completeNext(nodeManager, nextWords, candidates); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + } + } + + private final MetadataNodeManager nodeManager; + private final Terminal terminal; + private final Parser parser; + private final History history; + private final MetadataShellCompleter completer; + private final LineReader reader; + + public InteractiveShell(MetadataNodeManager nodeManager) throws IOException { + this.nodeManager = nodeManager; + TerminalBuilder builder = TerminalBuilder.builder(). + system(true). + nativeSignals(true); + this.terminal = builder.build(); + this.parser = new DefaultParser(); + this.history = new DefaultHistory(); + this.completer = new MetadataShellCompleter(nodeManager); + this.reader = LineReaderBuilder.builder(). + terminal(terminal). + parser(parser). + history(history). + completer(completer). + option(LineReader.Option.AUTO_FRESH_LINE, false). + build(); + } + + public void runMainLoop() throws Exception { + terminal.writer().println("[ Kafka Metadata Shell ]"); + terminal.flush(); + Commands commands = new Commands(true); + while (true) { + try { + reader.readLine(">> "); + ParsedLine parsedLine = reader.getParsedLine(); + Commands.Handler handler = commands.parseCommand(parsedLine.words()); + handler.run(Optional.of(this), terminal.writer(), nodeManager); + terminal.writer().flush(); + } catch (UserInterruptException eof) { + // Handle the user pressing control-C. 
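+                // JLine reports Ctrl-C as UserInterruptException; echo "^C" and
+                // return to the prompt rather than exiting the shell.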
+ terminal.writer().println("^C"); + } catch (EndOfFileException eof) { + return; + } + } + } + + public int screenWidth() { + return terminal.getWidth(); + } + + public Iterator> history(int numEntriesToShow) { + if (numEntriesToShow < 0) { + numEntriesToShow = 0; + } + int last = history.last(); + if (numEntriesToShow > last + 1) { + numEntriesToShow = last + 1; + } + int first = last - numEntriesToShow + 1; + if (first < history.first()) { + first = history.first(); + } + return new HistoryIterator(first, last); + } + + public class HistoryIterator implements Iterator> { + private int index; + private int last; + + HistoryIterator(int index, int last) { + this.index = index; + this.last = last; + } + + @Override + public boolean hasNext() { + return index <= last; + } + + @Override + public Entry next() { + if (index > last) { + throw new NoSuchElementException(); + } + int p = index++; + return new AbstractMap.SimpleImmutableEntry<>(p, history.get(p)); + } + } + + @Override + public void close() throws IOException { + terminal.close(); + } +} diff --git a/shell/src/main/java/org/apache/kafka/shell/LsCommandHandler.java b/shell/src/main/java/org/apache/kafka/shell/LsCommandHandler.java new file mode 100644 index 0000000..6260d12 --- /dev/null +++ b/shell/src/main/java/org/apache/kafka/shell/LsCommandHandler.java @@ -0,0 +1,299 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.shell; + +import net.sourceforge.argparse4j.inf.ArgumentParser; +import net.sourceforge.argparse4j.inf.Namespace; +import org.apache.kafka.shell.GlobVisitor.MetadataNodeInfo; +import org.apache.kafka.shell.MetadataNode.DirectoryNode; +import org.apache.kafka.shell.MetadataNode.FileNode; +import org.jline.reader.Candidate; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.PrintWriter; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Objects; +import java.util.Optional; +import java.util.OptionalInt; + +/** + * Implements the ls command. + */ +public final class LsCommandHandler implements Commands.Handler { + private static final Logger log = LoggerFactory.getLogger(LsCommandHandler.class); + + public final static Commands.Type TYPE = new LsCommandType(); + + public static class LsCommandType implements Commands.Type { + private LsCommandType() { + } + + @Override + public String name() { + return "ls"; + } + + @Override + public String description() { + return "List metadata nodes."; + } + + @Override + public boolean shellOnly() { + return false; + } + + @Override + public void addArguments(ArgumentParser parser) { + parser.addArgument("targets"). + nargs("*"). 
+ help("The metadata node paths to list."); + } + + @Override + public Commands.Handler createHandler(Namespace namespace) { + return new LsCommandHandler(namespace.getList("targets")); + } + + @Override + public void completeNext(MetadataNodeManager nodeManager, List nextWords, + List candidates) throws Exception { + CommandUtils.completePath(nodeManager, nextWords.get(nextWords.size() - 1), + candidates); + } + } + + private final List targets; + + public LsCommandHandler(List targets) { + this.targets = targets; + } + + static class TargetDirectory { + private final String name; + private final List children; + + TargetDirectory(String name, List children) { + this.name = name; + this.children = children; + } + } + + @Override + public void run(Optional shell, + PrintWriter writer, + MetadataNodeManager manager) throws Exception { + List targetFiles = new ArrayList<>(); + List targetDirectories = new ArrayList<>(); + for (String target : CommandUtils.getEffectivePaths(targets)) { + manager.visit(new GlobVisitor(target, entryOption -> { + if (entryOption.isPresent()) { + MetadataNodeInfo info = entryOption.get(); + MetadataNode node = info.node(); + if (node instanceof DirectoryNode) { + DirectoryNode directory = (DirectoryNode) node; + List children = new ArrayList<>(); + children.addAll(directory.children().keySet()); + targetDirectories.add( + new TargetDirectory(info.lastPathComponent(), children)); + } else if (node instanceof FileNode) { + targetFiles.add(info.lastPathComponent()); + } + } else { + writer.println("ls: " + target + ": no such file or directory."); + } + })); + } + OptionalInt screenWidth = shell.isPresent() ? + OptionalInt.of(shell.get().screenWidth()) : OptionalInt.empty(); + log.trace("LS : targetFiles = {}, targetDirectories = {}, screenWidth = {}", + targetFiles, targetDirectories, screenWidth); + printTargets(writer, screenWidth, targetFiles, targetDirectories); + } + + static void printTargets(PrintWriter writer, + OptionalInt screenWidth, + List targetFiles, + List targetDirectories) { + printEntries(writer, "", screenWidth, targetFiles); + boolean needIntro = targetFiles.size() > 0 || targetDirectories.size() > 1; + boolean firstIntro = targetFiles.isEmpty(); + for (TargetDirectory targetDirectory : targetDirectories) { + String intro = ""; + if (needIntro) { + if (!firstIntro) { + intro = intro + String.format("%n"); + } + intro = intro + targetDirectory.name + ":"; + firstIntro = false; + } + log.trace("LS : targetDirectory name = {}, children = {}", + targetDirectory.name, targetDirectory.children); + printEntries(writer, intro, screenWidth, targetDirectory.children); + } + } + + static void printEntries(PrintWriter writer, + String intro, + OptionalInt screenWidth, + List entries) { + if (entries.isEmpty()) { + return; + } + if (!intro.isEmpty()) { + writer.println(intro); + } + ColumnSchema columnSchema = calculateColumnSchema(screenWidth, entries); + int numColumns = columnSchema.numColumns(); + int numLines = (entries.size() + numColumns - 1) / numColumns; + for (int line = 0; line < numLines; line++) { + StringBuilder output = new StringBuilder(); + for (int column = 0; column < numColumns; column++) { + int entryIndex = line + (column * columnSchema.entriesPerColumn()); + if (entryIndex < entries.size()) { + String entry = entries.get(entryIndex); + output.append(entry); + if (column < numColumns - 1) { + int width = columnSchema.columnWidth(column); + for (int i = 0; i < width - entry.length(); i++) { + output.append(" "); + } + } + } + } + 
writer.println(output.toString()); + } + } + + static ColumnSchema calculateColumnSchema(OptionalInt screenWidth, + List entries) { + if (!screenWidth.isPresent()) { + return new ColumnSchema(1, entries.size()); + } + int maxColumns = screenWidth.getAsInt() / 4; + if (maxColumns <= 1) { + return new ColumnSchema(1, entries.size()); + } + ColumnSchema[] schemas = new ColumnSchema[maxColumns]; + for (int numColumns = 1; numColumns <= maxColumns; numColumns++) { + schemas[numColumns - 1] = new ColumnSchema(numColumns, + (entries.size() + numColumns - 1) / numColumns); + } + for (int i = 0; i < entries.size(); i++) { + String entry = entries.get(i); + for (int s = 0; s < schemas.length; s++) { + ColumnSchema schema = schemas[s]; + schema.process(i, entry); + } + } + for (int s = schemas.length - 1; s > 0; s--) { + ColumnSchema schema = schemas[s]; + if (schema.columnWidths[schema.columnWidths.length - 1] != 0 && + schema.totalWidth() <= screenWidth.getAsInt()) { + return schema; + } + } + return schemas[0]; + } + + static class ColumnSchema { + private final int[] columnWidths; + private final int entriesPerColumn; + + ColumnSchema(int numColumns, int entriesPerColumn) { + this.columnWidths = new int[numColumns]; + this.entriesPerColumn = entriesPerColumn; + } + + ColumnSchema setColumnWidths(Integer... widths) { + for (int i = 0; i < widths.length; i++) { + columnWidths[i] = widths[i]; + } + return this; + } + + void process(int entryIndex, String output) { + int columnIndex = entryIndex / entriesPerColumn; + columnWidths[columnIndex] = Math.max( + columnWidths[columnIndex], output.length() + 2); + } + + int totalWidth() { + int total = 0; + for (int i = 0; i < columnWidths.length; i++) { + total += columnWidths[i]; + } + return total; + } + + int numColumns() { + return columnWidths.length; + } + + int columnWidth(int columnIndex) { + return columnWidths[columnIndex]; + } + + int entriesPerColumn() { + return entriesPerColumn; + } + + @Override + public int hashCode() { + return Objects.hash(columnWidths, entriesPerColumn); + } + + @Override + public boolean equals(Object o) { + if (!(o instanceof ColumnSchema)) return false; + ColumnSchema other = (ColumnSchema) o; + if (entriesPerColumn != other.entriesPerColumn) return false; + if (!Arrays.equals(columnWidths, other.columnWidths)) return false; + return true; + } + + @Override + public String toString() { + StringBuilder bld = new StringBuilder("ColumnSchema(columnWidths=["); + String prefix = ""; + for (int i = 0; i < columnWidths.length; i++) { + bld.append(prefix); + bld.append(columnWidths[i]); + prefix = ", "; + } + bld.append("], entriesPerColumn=").append(entriesPerColumn).append(")"); + return bld.toString(); + } + } + + @Override + public int hashCode() { + return Objects.hashCode(targets); + } + + @Override + public boolean equals(Object other) { + if (!(other instanceof LsCommandHandler)) return false; + LsCommandHandler o = (LsCommandHandler) other; + if (!Objects.equals(o.targets, targets)) return false; + return true; + } +} diff --git a/shell/src/main/java/org/apache/kafka/shell/ManCommandHandler.java b/shell/src/main/java/org/apache/kafka/shell/ManCommandHandler.java new file mode 100644 index 0000000..dcd0b8c --- /dev/null +++ b/shell/src/main/java/org/apache/kafka/shell/ManCommandHandler.java @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.shell; + +import net.sourceforge.argparse4j.ArgumentParsers; +import net.sourceforge.argparse4j.inf.ArgumentParser; +import net.sourceforge.argparse4j.inf.Namespace; +import org.jline.reader.Candidate; + +import java.io.PrintWriter; +import java.util.List; +import java.util.Optional; + +/** + * Implements the manual command. + */ +public final class ManCommandHandler implements Commands.Handler { + private final String cmd; + + public final static Commands.Type TYPE = new ManCommandType(); + + public static class ManCommandType implements Commands.Type { + private ManCommandType() { + } + + @Override + public String name() { + return "man"; + } + + @Override + public String description() { + return "Show the help text for a specific command."; + } + + @Override + public boolean shellOnly() { + return true; + } + + @Override + public void addArguments(ArgumentParser parser) { + parser.addArgument("cmd"). + nargs(1). + help("The command to get help text for."); + } + + @Override + public Commands.Handler createHandler(Namespace namespace) { + return new ManCommandHandler(namespace.getList("cmd").get(0)); + } + + @Override + public void completeNext(MetadataNodeManager nodeManager, List nextWords, + List candidates) throws Exception { + if (nextWords.size() == 1) { + CommandUtils.completeCommand(nextWords.get(0), candidates); + } + } + } + + public ManCommandHandler(String cmd) { + this.cmd = cmd; + } + + @Override + public void run(Optional shell, + PrintWriter writer, + MetadataNodeManager manager) { + Commands.Type type = Commands.TYPES.get(cmd); + if (type == null) { + writer.println("man: unknown command " + cmd + + ". Type help to get a list of commands."); + } else { + ArgumentParser parser = ArgumentParsers.newArgumentParser(type.name(), false); + type.addArguments(parser); + writer.printf("%s: %s%n%n", cmd, type.description()); + parser.printHelp(writer); + } + } + + @Override + public int hashCode() { + return cmd.hashCode(); + } + + @Override + public boolean equals(Object other) { + if (!(other instanceof ManCommandHandler)) return false; + ManCommandHandler o = (ManCommandHandler) other; + if (!o.cmd.equals(cmd)) return false; + return true; + } +} diff --git a/shell/src/main/java/org/apache/kafka/shell/MetadataNode.java b/shell/src/main/java/org/apache/kafka/shell/MetadataNode.java new file mode 100644 index 0000000..ad0b3cb --- /dev/null +++ b/shell/src/main/java/org/apache/kafka/shell/MetadataNode.java @@ -0,0 +1,140 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.shell; + +import java.util.NavigableMap; +import java.util.TreeMap; + +/** + * A node in the metadata tool. + */ +public interface MetadataNode { + class DirectoryNode implements MetadataNode { + private final TreeMap children = new TreeMap<>(); + + public DirectoryNode mkdirs(String... names) { + if (names.length == 0) { + throw new RuntimeException("Invalid zero-length path"); + } + DirectoryNode node = this; + for (int i = 0; i < names.length; i++) { + MetadataNode nextNode = node.children.get(names[i]); + if (nextNode == null) { + nextNode = new DirectoryNode(); + node.children.put(names[i], nextNode); + } else { + if (!(nextNode instanceof DirectoryNode)) { + throw new NotDirectoryException(); + } + } + node = (DirectoryNode) nextNode; + } + return node; + } + + public void rmrf(String... names) { + if (names.length == 0) { + throw new RuntimeException("Invalid zero-length path"); + } + DirectoryNode node = this; + for (int i = 0; i < names.length - 1; i++) { + MetadataNode nextNode = node.children.get(names[i]); + if (!(nextNode instanceof DirectoryNode)) { + throw new RuntimeException("Unable to locate directory /" + + String.join("/", names)); + } + node = (DirectoryNode) nextNode; + } + node.children.remove(names[names.length - 1]); + } + + public FileNode create(String name) { + MetadataNode node = children.get(name); + if (node == null) { + node = new FileNode(); + children.put(name, node); + } else { + if (!(node instanceof FileNode)) { + throw new NotFileException(); + } + } + return (FileNode) node; + } + + public MetadataNode child(String component) { + return children.get(component); + } + + public NavigableMap children() { + return children; + } + + public void addChild(String name, DirectoryNode child) { + children.put(name, child); + } + + public DirectoryNode directory(String... names) { + if (names.length == 0) { + throw new RuntimeException("Invalid zero-length path"); + } + DirectoryNode node = this; + for (int i = 0; i < names.length; i++) { + MetadataNode nextNode = node.children.get(names[i]); + if (!(nextNode instanceof DirectoryNode)) { + throw new RuntimeException("Unable to locate directory /" + + String.join("/", names)); + } + node = (DirectoryNode) nextNode; + } + return node; + } + + public FileNode file(String... 
names) { + if (names.length == 0) { + throw new RuntimeException("Invalid zero-length path"); + } + DirectoryNode node = this; + for (int i = 0; i < names.length - 1; i++) { + MetadataNode nextNode = node.children.get(names[i]); + if (!(nextNode instanceof DirectoryNode)) { + throw new RuntimeException("Unable to locate file /" + + String.join("/", names)); + } + node = (DirectoryNode) nextNode; + } + MetadataNode nextNode = node.child(names[names.length - 1]); + if (!(nextNode instanceof FileNode)) { + throw new RuntimeException("Unable to locate file /" + + String.join("/", names)); + } + return (FileNode) nextNode; + } + } + + class FileNode implements MetadataNode { + private String contents; + + void setContents(String contents) { + this.contents = contents; + } + + String contents() { + return contents; + } + } +} diff --git a/shell/src/main/java/org/apache/kafka/shell/MetadataNodeManager.java b/shell/src/main/java/org/apache/kafka/shell/MetadataNodeManager.java new file mode 100644 index 0000000..fa1b411 --- /dev/null +++ b/shell/src/main/java/org/apache/kafka/shell/MetadataNodeManager.java @@ -0,0 +1,338 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.shell; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.datatype.jdk8.Jdk8Module; +import org.apache.kafka.common.config.ConfigResource; +import org.apache.kafka.common.metadata.ClientQuotaRecord; +import org.apache.kafka.common.metadata.ClientQuotaRecord.EntityData; +import org.apache.kafka.common.metadata.ConfigRecord; +import org.apache.kafka.common.metadata.FenceBrokerRecord; +import org.apache.kafka.common.metadata.MetadataRecordType; +import org.apache.kafka.common.metadata.PartitionChangeRecord; +import org.apache.kafka.common.metadata.PartitionRecord; +import org.apache.kafka.common.metadata.PartitionRecordJsonConverter; +import org.apache.kafka.common.metadata.RegisterBrokerRecord; +import org.apache.kafka.common.metadata.RemoveTopicRecord; +import org.apache.kafka.common.metadata.TopicRecord; +import org.apache.kafka.common.metadata.UnfenceBrokerRecord; +import org.apache.kafka.common.metadata.UnregisterBrokerRecord; +import org.apache.kafka.common.protocol.ApiMessage; +import org.apache.kafka.common.utils.AppInfoParser; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.queue.EventQueue; +import org.apache.kafka.queue.KafkaEventQueue; +import org.apache.kafka.raft.Batch; +import org.apache.kafka.raft.BatchReader; +import org.apache.kafka.raft.LeaderAndEpoch; +import org.apache.kafka.raft.RaftClient; +import org.apache.kafka.server.common.ApiMessageAndVersion; +import org.apache.kafka.shell.MetadataNode.DirectoryNode; +import org.apache.kafka.shell.MetadataNode.FileNode; +import org.apache.kafka.snapshot.SnapshotReader; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.TreeMap; +import java.util.concurrent.CompletableFuture; +import java.util.function.Consumer; + +/** + * Maintains the in-memory metadata for the metadata tool. 
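+ *
+ * Metadata records read from the log or from a snapshot are applied to an in-memory
+ * tree of directory and file nodes, for example {@code /brokers/<id>/registration}
+ * or {@code /topics/<name>/id}. Readers access the tree through {@code visit()},
+ * which runs on the manager's single-threaded event queue.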
+ */ +public final class MetadataNodeManager implements AutoCloseable { + private static final int NO_LEADER_CHANGE = -2; + + private static final Logger log = LoggerFactory.getLogger(MetadataNodeManager.class); + + public static class Data { + private final DirectoryNode root = new DirectoryNode(); + private String workingDirectory = "/"; + + public DirectoryNode root() { + return root; + } + + public String workingDirectory() { + return workingDirectory; + } + + public void setWorkingDirectory(String workingDirectory) { + this.workingDirectory = workingDirectory; + } + } + + class LogListener implements RaftClient.Listener { + @Override + public void handleCommit(BatchReader reader) { + try { + while (reader.hasNext()) { + Batch batch = reader.next(); + log.debug("handleCommits " + batch.records() + " at offset " + batch.lastOffset()); + DirectoryNode dir = data.root.mkdirs("metadataQuorum"); + dir.create("offset").setContents(String.valueOf(batch.lastOffset())); + for (ApiMessageAndVersion messageAndVersion : batch.records()) { + handleMessage(messageAndVersion.message()); + } + } + } finally { + reader.close(); + } + } + + @Override + public void handleSnapshot(SnapshotReader reader) { + try { + while (reader.hasNext()) { + Batch batch = reader.next(); + for (ApiMessageAndVersion messageAndVersion : batch) { + handleMessage(messageAndVersion.message()); + } + } + } finally { + reader.close(); + } + } + + @Override + public void handleLeaderChange(LeaderAndEpoch leader) { + appendEvent("handleNewLeader", () -> { + log.debug("handleNewLeader " + leader); + DirectoryNode dir = data.root.mkdirs("metadataQuorum"); + dir.create("leader").setContents(leader.toString()); + }, null); + } + + @Override + public void beginShutdown() { + log.debug("Metadata log listener sent beginShutdown"); + } + } + + private final Data data = new Data(); + private final LogListener logListener = new LogListener(); + private final ObjectMapper objectMapper; + private final KafkaEventQueue queue; + + public MetadataNodeManager() { + this.objectMapper = new ObjectMapper(); + this.objectMapper.registerModule(new Jdk8Module()); + this.queue = new KafkaEventQueue(Time.SYSTEM, + new LogContext("[node-manager-event-queue] "), ""); + } + + public void setup() throws Exception { + CompletableFuture future = new CompletableFuture<>(); + appendEvent("createShellNodes", () -> { + DirectoryNode directory = data.root().mkdirs("local"); + directory.create("version").setContents(AppInfoParser.getVersion()); + directory.create("commitId").setContents(AppInfoParser.getCommitId()); + future.complete(null); + }, future); + future.get(); + } + + public LogListener logListener() { + return logListener; + } + + // VisibleForTesting + Data getData() { + return data; + } + + @Override + public void close() throws Exception { + queue.close(); + } + + public void visit(Consumer consumer) throws Exception { + CompletableFuture future = new CompletableFuture<>(); + appendEvent("visit", () -> { + consumer.accept(data); + future.complete(null); + }, future); + future.get(); + } + + private void appendEvent(String name, Runnable runnable, CompletableFuture future) { + queue.append(new EventQueue.Event() { + @Override + public void run() throws Exception { + runnable.run(); + } + + @Override + public void handleException(Throwable e) { + log.error("Unexpected error while handling event " + name, e); + if (future != null) { + future.completeExceptionally(e); + } + } + }); + } + + // VisibleForTesting + void handleMessage(ApiMessage message) { + 
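+        // Look up the record type and apply the record to the in-memory tree.
+        // Records that cannot be processed are logged and skipped so that a single
+        // bad record does not abort the whole load.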
try { + MetadataRecordType type = MetadataRecordType.fromId(message.apiKey()); + handleCommitImpl(type, message); + } catch (Exception e) { + log.error("Error processing record of type " + message.apiKey(), e); + } + } + + private void handleCommitImpl(MetadataRecordType type, ApiMessage message) + throws Exception { + switch (type) { + case REGISTER_BROKER_RECORD: { + DirectoryNode brokersNode = data.root.mkdirs("brokers"); + RegisterBrokerRecord record = (RegisterBrokerRecord) message; + DirectoryNode brokerNode = brokersNode. + mkdirs(Integer.toString(record.brokerId())); + FileNode registrationNode = brokerNode.create("registration"); + registrationNode.setContents(record.toString()); + brokerNode.create("isFenced").setContents("true"); + break; + } + case UNREGISTER_BROKER_RECORD: { + UnregisterBrokerRecord record = (UnregisterBrokerRecord) message; + data.root.rmrf("brokers", Integer.toString(record.brokerId())); + break; + } + case TOPIC_RECORD: { + TopicRecord record = (TopicRecord) message; + DirectoryNode topicsDirectory = data.root.mkdirs("topics"); + DirectoryNode topicDirectory = topicsDirectory.mkdirs(record.name()); + topicDirectory.create("id").setContents(record.topicId().toString()); + topicDirectory.create("name").setContents(record.name().toString()); + DirectoryNode topicIdsDirectory = data.root.mkdirs("topicIds"); + topicIdsDirectory.addChild(record.topicId().toString(), topicDirectory); + break; + } + case PARTITION_RECORD: { + PartitionRecord record = (PartitionRecord) message; + DirectoryNode topicDirectory = + data.root.mkdirs("topicIds").mkdirs(record.topicId().toString()); + DirectoryNode partitionDirectory = + topicDirectory.mkdirs(Integer.toString(record.partitionId())); + JsonNode node = PartitionRecordJsonConverter. + write(record, PartitionRecord.HIGHEST_SUPPORTED_VERSION); + partitionDirectory.create("data").setContents(node.toPrettyString()); + break; + } + case CONFIG_RECORD: { + ConfigRecord record = (ConfigRecord) message; + String typeString = ""; + switch (ConfigResource.Type.forId(record.resourceType())) { + case BROKER: + typeString = "broker"; + break; + case TOPIC: + typeString = "topic"; + break; + default: + throw new RuntimeException("Error processing CONFIG_RECORD: " + + "Can't handle ConfigResource.Type " + record.resourceType()); + } + DirectoryNode configDirectory = data.root.mkdirs("configs"). + mkdirs(typeString).mkdirs(record.resourceName()); + if (record.value() == null) { + configDirectory.rmrf(record.name()); + } else { + configDirectory.create(record.name()).setContents(record.value()); + } + break; + } + case PARTITION_CHANGE_RECORD: { + PartitionChangeRecord record = (PartitionChangeRecord) message; + FileNode file = data.root.file("topicIds", record.topicId().toString(), + Integer.toString(record.partitionId()), "data"); + JsonNode node = objectMapper.readTree(file.contents()); + PartitionRecord partition = PartitionRecordJsonConverter. 
+ read(node, PartitionRecord.HIGHEST_SUPPORTED_VERSION); + if (record.isr() != null) { + partition.setIsr(record.isr()); + } + if (record.leader() != NO_LEADER_CHANGE) { + partition.setLeader(record.leader()); + partition.setLeaderEpoch(partition.leaderEpoch() + 1); + } + partition.setPartitionEpoch(partition.partitionEpoch() + 1); + file.setContents(PartitionRecordJsonConverter.write(partition, + PartitionRecord.HIGHEST_SUPPORTED_VERSION).toPrettyString()); + break; + } + case FENCE_BROKER_RECORD: { + FenceBrokerRecord record = (FenceBrokerRecord) message; + data.root.mkdirs("brokers", Integer.toString(record.id())). + create("isFenced").setContents("true"); + break; + } + case UNFENCE_BROKER_RECORD: { + UnfenceBrokerRecord record = (UnfenceBrokerRecord) message; + data.root.mkdirs("brokers", Integer.toString(record.id())). + create("isFenced").setContents("false"); + break; + } + case REMOVE_TOPIC_RECORD: { + RemoveTopicRecord record = (RemoveTopicRecord) message; + DirectoryNode topicsDirectory = + data.root.directory("topicIds", record.topicId().toString()); + String name = topicsDirectory.file("name").contents(); + data.root.rmrf("topics", name); + data.root.rmrf("topicIds", record.topicId().toString()); + break; + } + case CLIENT_QUOTA_RECORD: { + ClientQuotaRecord record = (ClientQuotaRecord) message; + List directories = clientQuotaRecordDirectories(record.entity()); + DirectoryNode node = data.root; + for (String directory : directories) { + node = node.mkdirs(directory); + } + if (record.remove()) + node.rmrf(record.key()); + else + node.create(record.key()).setContents(record.value() + ""); + break; + } + default: + throw new RuntimeException("Unhandled metadata record type"); + } + } + + static List clientQuotaRecordDirectories(List entityData) { + List result = new ArrayList<>(); + result.add("client-quotas"); + TreeMap entries = new TreeMap<>(); + entityData.forEach(e -> entries.put(e.entityType(), e)); + for (Map.Entry entry : entries.entrySet()) { + result.add(entry.getKey()); + result.add(entry.getValue().entityName() == null ? + "" : entry.getValue().entityName()); + } + return result; + } +} diff --git a/shell/src/main/java/org/apache/kafka/shell/MetadataShell.java b/shell/src/main/java/org/apache/kafka/shell/MetadataShell.java new file mode 100644 index 0000000..1d99623 --- /dev/null +++ b/shell/src/main/java/org/apache/kafka/shell/MetadataShell.java @@ -0,0 +1,180 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.shell; + +import kafka.raft.KafkaRaftManager; +import kafka.tools.TerseFailure; +import net.sourceforge.argparse4j.ArgumentParsers; +import net.sourceforge.argparse4j.inf.ArgumentParser; +import net.sourceforge.argparse4j.inf.Namespace; +import org.apache.kafka.common.utils.Exit; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.server.common.ApiMessageAndVersion; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.BufferedWriter; +import java.io.OutputStreamWriter; +import java.io.PrintWriter; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.List; +import java.util.Optional; +import java.util.concurrent.ExecutionException; + + +/** + * The Kafka metadata shell. + */ +public final class MetadataShell { + private static final Logger log = LoggerFactory.getLogger(MetadataShell.class); + + public static class Builder { + private String snapshotPath; + + public Builder setSnapshotPath(String snapshotPath) { + this.snapshotPath = snapshotPath; + return this; + } + + public MetadataShell build() throws Exception { + if (snapshotPath == null) { + throw new RuntimeException("You must supply the log path via --snapshot"); + } + MetadataNodeManager nodeManager = null; + SnapshotFileReader reader = null; + try { + nodeManager = new MetadataNodeManager(); + reader = new SnapshotFileReader(snapshotPath, nodeManager.logListener()); + return new MetadataShell(null, reader, nodeManager); + } catch (Throwable e) { + log.error("Initialization error", e); + if (reader != null) { + reader.close(); + } + if (nodeManager != null) { + nodeManager.close(); + } + throw e; + } + } + } + + private final KafkaRaftManager raftManager; + + private final SnapshotFileReader snapshotFileReader; + + private final MetadataNodeManager nodeManager; + + public MetadataShell(KafkaRaftManager raftManager, + SnapshotFileReader snapshotFileReader, + MetadataNodeManager nodeManager) { + this.raftManager = raftManager; + this.snapshotFileReader = snapshotFileReader; + this.nodeManager = nodeManager; + } + + public void run(List args) throws Exception { + nodeManager.setup(); + if (raftManager != null) { + raftManager.startup(); + raftManager.register(nodeManager.logListener()); + } else if (snapshotFileReader != null) { + snapshotFileReader.startup(); + } else { + throw new RuntimeException("Expected either a raft manager or snapshot reader"); + } + if (args == null || args.isEmpty()) { + // Interactive mode. + System.out.println("Loading..."); + waitUntilCaughtUp(); + System.out.println("Starting..."); + try (InteractiveShell shell = new InteractiveShell(nodeManager)) { + shell.runMainLoop(); + } + } else { + // Non-interactive mode. 
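+            // Wait for the snapshot to be fully read, then run the single command
+            // given on the command line and flush the output.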
+ waitUntilCaughtUp(); + Commands commands = new Commands(false); + try (PrintWriter writer = new PrintWriter(new BufferedWriter( + new OutputStreamWriter(System.out, StandardCharsets.UTF_8)))) { + Commands.Handler handler = commands.parseCommand(args); + handler.run(Optional.empty(), writer, nodeManager); + writer.flush(); + } + } + } + + public void close() throws Exception { + if (raftManager != null) { + raftManager.shutdown(); + } + if (snapshotFileReader != null) { + snapshotFileReader.close(); + } + nodeManager.close(); + } + + public static void main(String[] args) throws Exception { + ArgumentParser parser = ArgumentParsers + .newArgumentParser("metadata-tool") + .defaultHelp(true) + .description("The Apache Kafka metadata tool"); + parser.addArgument("--snapshot", "-s") + .type(String.class) + .help("The snapshot file to read."); + parser.addArgument("command") + .nargs("*") + .help("The command to run."); + Namespace res = parser.parseArgsOrFail(args); + try { + Builder builder = new Builder(); + builder.setSnapshotPath(res.getString("snapshot")); + Path tempDir = Files.createTempDirectory("MetadataShell"); + Exit.addShutdownHook("agent-shutdown-hook", () -> { + log.debug("Removing temporary directory " + tempDir.toAbsolutePath().toString()); + try { + Utils.delete(tempDir.toFile()); + } catch (Exception e) { + log.error("Got exception while removing temporary directory " + + tempDir.toAbsolutePath().toString()); + } + }); + MetadataShell shell = builder.build(); + try { + shell.run(res.getList("command")); + } finally { + shell.close(); + } + Exit.exit(0); + } catch (TerseFailure e) { + System.err.println("Error: " + e.getMessage()); + Exit.exit(1); + } catch (Throwable e) { + System.err.println("Unexpected error: " + + (e.getMessage() == null ? "" : e.getMessage())); + e.printStackTrace(System.err); + Exit.exit(1); + } + } + + void waitUntilCaughtUp() throws ExecutionException, InterruptedException { + snapshotFileReader.caughtUpFuture().get(); + } +} diff --git a/shell/src/main/java/org/apache/kafka/shell/NoOpCommandHandler.java b/shell/src/main/java/org/apache/kafka/shell/NoOpCommandHandler.java new file mode 100644 index 0000000..1756ba7 --- /dev/null +++ b/shell/src/main/java/org/apache/kafka/shell/NoOpCommandHandler.java @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.shell; + +import java.io.PrintWriter; +import java.util.Optional; + +/** + * Does nothing. 
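+ * The command parser returns this handler for a blank command line, so pressing
+ * enter at an empty prompt has no effect.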
+ */ +public final class NoOpCommandHandler implements Commands.Handler { + @Override + public void run(Optional shell, + PrintWriter writer, + MetadataNodeManager manager) { + } + + @Override + public int hashCode() { + return 0; + } + + @Override + public boolean equals(Object other) { + if (!(other instanceof NoOpCommandHandler)) return false; + return true; + } +} diff --git a/shell/src/main/java/org/apache/kafka/shell/NotDirectoryException.java b/shell/src/main/java/org/apache/kafka/shell/NotDirectoryException.java new file mode 100644 index 0000000..6925347 --- /dev/null +++ b/shell/src/main/java/org/apache/kafka/shell/NotDirectoryException.java @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.shell; + +/** + * An exception that is thrown when a non-directory node is treated like a + * directory. + */ +public class NotDirectoryException extends RuntimeException { + private static final long serialVersionUID = 1L; + + public NotDirectoryException() { + super(); + } +} diff --git a/shell/src/main/java/org/apache/kafka/shell/NotFileException.java b/shell/src/main/java/org/apache/kafka/shell/NotFileException.java new file mode 100644 index 0000000..cbc2a83 --- /dev/null +++ b/shell/src/main/java/org/apache/kafka/shell/NotFileException.java @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.shell; + +/** + * An exception that is thrown when a non-file node is treated like a + * file. 
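+ * For example, {@code DirectoryNode.create} throws this when the requested name is
+ * already bound to a directory node.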
+ */ +public class NotFileException extends RuntimeException { + private static final long serialVersionUID = 1L; + + public NotFileException() { + super(); + } +} diff --git a/shell/src/main/java/org/apache/kafka/shell/PwdCommandHandler.java b/shell/src/main/java/org/apache/kafka/shell/PwdCommandHandler.java new file mode 100644 index 0000000..1e5b5da --- /dev/null +++ b/shell/src/main/java/org/apache/kafka/shell/PwdCommandHandler.java @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.shell; + +import net.sourceforge.argparse4j.inf.ArgumentParser; +import net.sourceforge.argparse4j.inf.Namespace; +import org.jline.reader.Candidate; + +import java.io.PrintWriter; +import java.util.List; +import java.util.Optional; + +/** + * Implements the pwd command. + */ +public final class PwdCommandHandler implements Commands.Handler { + public final static Commands.Type TYPE = new PwdCommandType(); + + public static class PwdCommandType implements Commands.Type { + private PwdCommandType() { + } + + @Override + public String name() { + return "pwd"; + } + + @Override + public String description() { + return "Print the current working directory."; + } + + @Override + public boolean shellOnly() { + return true; + } + + @Override + public void addArguments(ArgumentParser parser) { + // nothing to do + } + + @Override + public Commands.Handler createHandler(Namespace namespace) { + return new PwdCommandHandler(); + } + + @Override + public void completeNext(MetadataNodeManager nodeManager, List nextWords, + List candidates) throws Exception { + // nothing to do + } + } + + @Override + public void run(Optional shell, + PrintWriter writer, + MetadataNodeManager manager) throws Exception { + manager.visit(data -> { + writer.println(data.workingDirectory()); + }); + } + + @Override + public int hashCode() { + return 0; + } + + @Override + public boolean equals(Object other) { + if (!(other instanceof PwdCommandHandler)) return false; + return true; + } +} diff --git a/shell/src/main/java/org/apache/kafka/shell/SnapshotFileReader.java b/shell/src/main/java/org/apache/kafka/shell/SnapshotFileReader.java new file mode 100644 index 0000000..9edf868 --- /dev/null +++ b/shell/src/main/java/org/apache/kafka/shell/SnapshotFileReader.java @@ -0,0 +1,204 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.shell; + +import org.apache.kafka.common.message.LeaderChangeMessage; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.record.ControlRecordType; +import org.apache.kafka.common.record.FileLogInputStream.FileChannelRecordBatch; +import org.apache.kafka.common.record.FileRecords; +import org.apache.kafka.common.record.Record; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.metadata.MetadataRecordSerde; +import org.apache.kafka.queue.EventQueue; +import org.apache.kafka.queue.KafkaEventQueue; +import org.apache.kafka.raft.Batch; +import org.apache.kafka.raft.LeaderAndEpoch; +import org.apache.kafka.raft.RaftClient; +import org.apache.kafka.raft.internals.MemoryBatchReader; +import org.apache.kafka.server.common.ApiMessageAndVersion; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.File; +import java.util.ArrayList; +import java.util.Collections; +import java.util.Iterator; +import java.util.List; +import java.util.OptionalInt; +import java.util.concurrent.CompletableFuture; + + +/** + * Reads Kafka metadata snapshots. + */ +public final class SnapshotFileReader implements AutoCloseable { + private static final Logger log = LoggerFactory.getLogger(SnapshotFileReader.class); + + private final String snapshotPath; + private final RaftClient.Listener listener; + private final KafkaEventQueue queue; + private final CompletableFuture caughtUpFuture; + private FileRecords fileRecords; + private Iterator batchIterator; + private final MetadataRecordSerde serde = new MetadataRecordSerde(); + + public SnapshotFileReader(String snapshotPath, RaftClient.Listener listener) { + this.snapshotPath = snapshotPath; + this.listener = listener; + this.queue = new KafkaEventQueue(Time.SYSTEM, + new LogContext("[snapshotReaderQueue] "), "snapshotReaderQueue_"); + this.caughtUpFuture = new CompletableFuture<>(); + } + + public void startup() throws Exception { + CompletableFuture future = new CompletableFuture<>(); + queue.append(new EventQueue.Event() { + @Override + public void run() throws Exception { + fileRecords = FileRecords.open(new File(snapshotPath), false); + batchIterator = fileRecords.batches().iterator(); + scheduleHandleNextBatch(); + future.complete(null); + } + + @Override + public void handleException(Throwable e) { + future.completeExceptionally(e); + beginShutdown("startup error"); + } + }); + future.get(); + } + + private void handleNextBatch() { + if (!batchIterator.hasNext()) { + beginShutdown("done"); + return; + } + FileChannelRecordBatch batch = batchIterator.next(); + if (batch.isControlBatch()) { + handleControlBatch(batch); + } else { + handleMetadataBatch(batch); + } + scheduleHandleNextBatch(); + } + + private void scheduleHandleNextBatch() { + queue.append(new EventQueue.Event() { + @Override + public void run() { + handleNextBatch(); + } + + @Override + public void handleException(Throwable e) { + log.error("Unexpected error while handling a batch of events", e); + beginShutdown("handleBatch 
error"); + } + }); + } + + private void handleControlBatch(FileChannelRecordBatch batch) { + for (Iterator iter = batch.iterator(); iter.hasNext(); ) { + Record record = iter.next(); + try { + short typeId = ControlRecordType.parseTypeId(record.key()); + ControlRecordType type = ControlRecordType.fromTypeId(typeId); + switch (type) { + case LEADER_CHANGE: + LeaderChangeMessage message = new LeaderChangeMessage(); + message.read(new ByteBufferAccessor(record.value()), (short) 0); + listener.handleLeaderChange(new LeaderAndEpoch( + OptionalInt.of(message.leaderId()), + batch.partitionLeaderEpoch() + )); + break; + default: + log.error("Ignoring control record with type {} at offset {}", + type, record.offset()); + } + } catch (Throwable e) { + log.error("unable to read control record at offset {}", record.offset(), e); + } + } + } + + private void handleMetadataBatch(FileChannelRecordBatch batch) { + List messages = new ArrayList<>(); + for (Record record : batch) { + ByteBufferAccessor accessor = new ByteBufferAccessor(record.value()); + try { + ApiMessageAndVersion messageAndVersion = serde.read(accessor, record.valueSize()); + messages.add(messageAndVersion); + } catch (Throwable e) { + log.error("unable to read metadata record at offset {}", record.offset(), e); + } + } + listener.handleCommit( + MemoryBatchReader.of( + Collections.singletonList( + Batch.data( + batch.baseOffset(), + batch.partitionLeaderEpoch(), + batch.maxTimestamp(), + batch.sizeInBytes(), + messages + ) + ), + reader -> { } + ) + ); + } + + public void beginShutdown(String reason) { + if (reason.equals("done")) { + caughtUpFuture.complete(null); + } else { + caughtUpFuture.completeExceptionally(new RuntimeException(reason)); + } + queue.beginShutdown(reason, new EventQueue.Event() { + @Override + public void run() throws Exception { + listener.beginShutdown(); + if (fileRecords != null) { + fileRecords.close(); + fileRecords = null; + } + batchIterator = null; + } + + @Override + public void handleException(Throwable e) { + log.error("shutdown error", e); + } + }); + } + + @Override + public void close() throws Exception { + beginShutdown("closing"); + queue.close(); + } + + public CompletableFuture caughtUpFuture() { + return caughtUpFuture; + } +} diff --git a/shell/src/test/java/org/apache/kafka/shell/CommandTest.java b/shell/src/test/java/org/apache/kafka/shell/CommandTest.java new file mode 100644 index 0000000..c896a06 --- /dev/null +++ b/shell/src/test/java/org/apache/kafka/shell/CommandTest.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.shell; + +import static java.util.concurrent.TimeUnit.MILLISECONDS; +import static org.junit.jupiter.api.Assertions.assertEquals; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; + +import java.util.Arrays; +import java.util.Collections; +import java.util.Optional; + +@Timeout(value = 120000, unit = MILLISECONDS) +public class CommandTest { + @Test + public void testParseCommands() { + assertEquals(new CatCommandHandler(Arrays.asList("foo")), + new Commands(true).parseCommand(Arrays.asList("cat", "foo"))); + assertEquals(new CdCommandHandler(Optional.empty()), + new Commands(true).parseCommand(Arrays.asList("cd"))); + assertEquals(new CdCommandHandler(Optional.of("foo")), + new Commands(true).parseCommand(Arrays.asList("cd", "foo"))); + assertEquals(new ExitCommandHandler(), + new Commands(true).parseCommand(Arrays.asList("exit"))); + assertEquals(new HelpCommandHandler(), + new Commands(true).parseCommand(Arrays.asList("help"))); + assertEquals(new HistoryCommandHandler(3), + new Commands(true).parseCommand(Arrays.asList("history", "3"))); + assertEquals(new HistoryCommandHandler(Integer.MAX_VALUE), + new Commands(true).parseCommand(Arrays.asList("history"))); + assertEquals(new LsCommandHandler(Collections.emptyList()), + new Commands(true).parseCommand(Arrays.asList("ls"))); + assertEquals(new LsCommandHandler(Arrays.asList("abc", "123")), + new Commands(true).parseCommand(Arrays.asList("ls", "abc", "123"))); + assertEquals(new PwdCommandHandler(), + new Commands(true).parseCommand(Arrays.asList("pwd"))); + } + + @Test + public void testParseInvalidCommand() { + assertEquals(new ErroneousCommandHandler("invalid choice: 'blah' (choose " + + "from 'cat', 'cd', 'exit', 'find', 'help', 'history', 'ls', 'man', 'pwd')"), + new Commands(true).parseCommand(Arrays.asList("blah"))); + } + + @Test + public void testEmptyCommandLine() { + assertEquals(new NoOpCommandHandler(), + new Commands(true).parseCommand(Arrays.asList(""))); + assertEquals(new NoOpCommandHandler(), + new Commands(true).parseCommand(Collections.emptyList())); + } +} diff --git a/shell/src/test/java/org/apache/kafka/shell/CommandUtilsTest.java b/shell/src/test/java/org/apache/kafka/shell/CommandUtilsTest.java new file mode 100644 index 0000000..90c3b5c --- /dev/null +++ b/shell/src/test/java/org/apache/kafka/shell/CommandUtilsTest.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.shell; + +import static java.util.concurrent.TimeUnit.MILLISECONDS; +import static org.junit.jupiter.api.Assertions.assertEquals; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; + +import java.util.Arrays; + +@Timeout(value = 120000, unit = MILLISECONDS) +public class CommandUtilsTest { + @Test + public void testSplitPath() { + assertEquals(Arrays.asList("alpha", "beta"), + CommandUtils.splitPath("/alpha/beta")); + assertEquals(Arrays.asList("alpha", "beta"), + CommandUtils.splitPath("//alpha/beta/")); + } +} diff --git a/shell/src/test/java/org/apache/kafka/shell/GlobComponentTest.java b/shell/src/test/java/org/apache/kafka/shell/GlobComponentTest.java new file mode 100644 index 0000000..da3a7ec --- /dev/null +++ b/shell/src/test/java/org/apache/kafka/shell/GlobComponentTest.java @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.shell; + +import static java.util.concurrent.TimeUnit.MILLISECONDS; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; + +@Timeout(value = 120000, unit = MILLISECONDS) +public class GlobComponentTest { + private void verifyIsLiteral(GlobComponent globComponent, String component) { + assertTrue(globComponent.literal()); + assertEquals(component, globComponent.component()); + assertTrue(globComponent.matches(component)); + assertFalse(globComponent.matches(component + "foo")); + } + + @Test + public void testLiteralComponent() { + verifyIsLiteral(new GlobComponent("abc"), "abc"); + verifyIsLiteral(new GlobComponent(""), ""); + verifyIsLiteral(new GlobComponent("foobar_123"), "foobar_123"); + verifyIsLiteral(new GlobComponent("$blah+"), "$blah+"); + } + + @Test + public void testToRegularExpression() { + assertEquals(null, GlobComponent.toRegularExpression("blah")); + assertEquals(null, GlobComponent.toRegularExpression("")); + assertEquals(null, GlobComponent.toRegularExpression("does not need a regex, actually")); + assertEquals("^\\$blah.*$", GlobComponent.toRegularExpression("$blah*")); + assertEquals("^.*$", GlobComponent.toRegularExpression("*")); + assertEquals("^foo(?:(?:bar)|(?:baz))$", GlobComponent.toRegularExpression("foo{bar,baz}")); + } + + @Test + public void testGlobMatch() { + GlobComponent star = new GlobComponent("*"); + assertFalse(star.literal()); + assertTrue(star.matches("")); + assertTrue(star.matches("anything")); + GlobComponent question = new GlobComponent("b?b"); + assertFalse(question.literal()); + assertFalse(question.matches("")); + assertTrue(question.matches("bob")); + 
assertTrue(question.matches("bib")); + assertFalse(question.matches("bic")); + GlobComponent foobarOrFoobaz = new GlobComponent("foo{bar,baz}"); + assertFalse(foobarOrFoobaz.literal()); + assertTrue(foobarOrFoobaz.matches("foobar")); + assertTrue(foobarOrFoobaz.matches("foobaz")); + assertFalse(foobarOrFoobaz.matches("foobah")); + assertFalse(foobarOrFoobaz.matches("foo")); + assertFalse(foobarOrFoobaz.matches("baz")); + } +} diff --git a/shell/src/test/java/org/apache/kafka/shell/GlobVisitorTest.java b/shell/src/test/java/org/apache/kafka/shell/GlobVisitorTest.java new file mode 100644 index 0000000..59eeb5d --- /dev/null +++ b/shell/src/test/java/org/apache/kafka/shell/GlobVisitorTest.java @@ -0,0 +1,144 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.shell; + +import static java.util.concurrent.TimeUnit.MILLISECONDS; +import static org.junit.jupiter.api.Assertions.assertEquals; + +import org.apache.kafka.shell.GlobVisitor.MetadataNodeInfo; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Optional; +import java.util.function.Consumer; + +@Timeout(value = 120000, unit = MILLISECONDS) +public class GlobVisitorTest { + static private final MetadataNodeManager.Data DATA; + + static { + DATA = new MetadataNodeManager.Data(); + DATA.root().mkdirs("alpha", "beta", "gamma"); + DATA.root().mkdirs("alpha", "theta"); + DATA.root().mkdirs("foo", "a"); + DATA.root().mkdirs("foo", "beta"); + DATA.root().mkdirs("zeta").create("c"); + DATA.root().mkdirs("zeta"); + DATA.root().create("zzz"); + DATA.setWorkingDirectory("foo"); + } + + static class InfoConsumer implements Consumer> { + private Optional> infos = null; + + @Override + public void accept(Optional info) { + if (infos == null) { + if (info.isPresent()) { + infos = Optional.of(new ArrayList<>()); + infos.get().add(info.get()); + } else { + infos = Optional.empty(); + } + } else { + if (info.isPresent()) { + infos.get().add(info.get()); + } else { + throw new RuntimeException("Saw non-empty info after seeing empty info"); + } + } + } + } + + @Test + public void testStarGlob() { + InfoConsumer consumer = new InfoConsumer(); + GlobVisitor visitor = new GlobVisitor("*", consumer); + visitor.accept(DATA); + assertEquals(Optional.of(Arrays.asList( + new MetadataNodeInfo(new String[] {"foo", "a"}, + DATA.root().directory("foo").child("a")), + new MetadataNodeInfo(new String[] {"foo", "beta"}, + DATA.root().directory("foo").child("beta")))), consumer.infos); + } + + @Test + public void testDotDot() { + InfoConsumer consumer = new InfoConsumer(); + GlobVisitor visitor = new GlobVisitor("..", consumer); + visitor.accept(DATA); + 
assertEquals(Optional.of(Arrays.asList( + new MetadataNodeInfo(new String[0], DATA.root()))), consumer.infos); + } + + @Test + public void testDoubleDotDot() { + InfoConsumer consumer = new InfoConsumer(); + GlobVisitor visitor = new GlobVisitor("../..", consumer); + visitor.accept(DATA); + assertEquals(Optional.of(Arrays.asList( + new MetadataNodeInfo(new String[0], DATA.root()))), consumer.infos); + } + + @Test + public void testZGlob() { + InfoConsumer consumer = new InfoConsumer(); + GlobVisitor visitor = new GlobVisitor("../z*", consumer); + visitor.accept(DATA); + assertEquals(Optional.of(Arrays.asList( + new MetadataNodeInfo(new String[] {"zeta"}, + DATA.root().child("zeta")), + new MetadataNodeInfo(new String[] {"zzz"}, + DATA.root().child("zzz")))), consumer.infos); + } + + @Test + public void testBetaOrThetaGlob() { + InfoConsumer consumer = new InfoConsumer(); + GlobVisitor visitor = new GlobVisitor("../*/{beta,theta}", consumer); + visitor.accept(DATA); + assertEquals(Optional.of(Arrays.asList( + new MetadataNodeInfo(new String[] {"alpha", "beta"}, + DATA.root().directory("alpha").child("beta")), + new MetadataNodeInfo(new String[] {"alpha", "theta"}, + DATA.root().directory("alpha").child("theta")), + new MetadataNodeInfo(new String[] {"foo", "beta"}, + DATA.root().directory("foo").child("beta")))), consumer.infos); + } + + @Test + public void testNotFoundGlob() { + InfoConsumer consumer = new InfoConsumer(); + GlobVisitor visitor = new GlobVisitor("epsilon", consumer); + visitor.accept(DATA); + assertEquals(Optional.empty(), consumer.infos); + } + + @Test + public void testAbsoluteGlob() { + InfoConsumer consumer = new InfoConsumer(); + GlobVisitor visitor = new GlobVisitor("/a?pha", consumer); + visitor.accept(DATA); + assertEquals(Optional.of(Arrays.asList( + new MetadataNodeInfo(new String[] {"alpha"}, + DATA.root().directory("alpha")))), consumer.infos); + } +} diff --git a/shell/src/test/java/org/apache/kafka/shell/LsCommandHandlerTest.java b/shell/src/test/java/org/apache/kafka/shell/LsCommandHandlerTest.java new file mode 100644 index 0000000..c845706 --- /dev/null +++ b/shell/src/test/java/org/apache/kafka/shell/LsCommandHandlerTest.java @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.shell; + +import static java.util.concurrent.TimeUnit.MILLISECONDS; +import static org.junit.jupiter.api.Assertions.assertEquals; + +import org.apache.kafka.shell.LsCommandHandler.ColumnSchema; +import org.apache.kafka.shell.LsCommandHandler.TargetDirectory; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; + +import java.io.ByteArrayOutputStream; +import java.io.OutputStreamWriter; +import java.io.PrintWriter; +import java.nio.charset.StandardCharsets; +import java.util.Arrays; +import java.util.Collections; +import java.util.OptionalInt; + +@Timeout(value = 120000, unit = MILLISECONDS) +public class LsCommandHandlerTest { + @Test + public void testCalculateColumnSchema() { + assertEquals(new ColumnSchema(1, 3), + LsCommandHandler.calculateColumnSchema(OptionalInt.empty(), + Arrays.asList("abc", "def", "ghi"))); + assertEquals(new ColumnSchema(1, 2), + LsCommandHandler.calculateColumnSchema(OptionalInt.of(0), + Arrays.asList("abc", "def"))); + assertEquals(new ColumnSchema(3, 1).setColumnWidths(3, 8, 6), + LsCommandHandler.calculateColumnSchema(OptionalInt.of(80), + Arrays.asList("a", "abcdef", "beta"))); + assertEquals(new ColumnSchema(2, 3).setColumnWidths(10, 7), + LsCommandHandler.calculateColumnSchema(OptionalInt.of(18), + Arrays.asList("alphabet", "beta", "gamma", "theta", "zeta"))); + } + + @Test + public void testPrintEntries() throws Exception { + try (ByteArrayOutputStream stream = new ByteArrayOutputStream()) { + try (PrintWriter writer = new PrintWriter(new OutputStreamWriter( + stream, StandardCharsets.UTF_8))) { + LsCommandHandler.printEntries(writer, "", OptionalInt.of(18), + Arrays.asList("alphabet", "beta", "gamma", "theta", "zeta")); + } + assertEquals(String.join(String.format("%n"), Arrays.asList( + "alphabet theta", + "beta zeta", + "gamma")), stream.toString().trim()); + } + } + + @Test + public void testPrintTargets() throws Exception { + try (ByteArrayOutputStream stream = new ByteArrayOutputStream()) { + try (PrintWriter writer = new PrintWriter(new OutputStreamWriter( + stream, StandardCharsets.UTF_8))) { + LsCommandHandler.printTargets(writer, OptionalInt.of(18), + Arrays.asList("foo", "foobarbaz", "quux"), Arrays.asList( + new TargetDirectory("/some/dir", + Collections.singletonList("supercalifragalistic")), + new TargetDirectory("/some/other/dir", + Arrays.asList("capability", "delegation", "elephant", + "fungible", "green")))); + } + assertEquals(String.join(String.format("%n"), Arrays.asList( + "foo quux", + "foobarbaz ", + "", + "/some/dir:", + "supercalifragalistic", + "", + "/some/other/dir:", + "capability", + "delegation", + "elephant", + "fungible", + "green")), stream.toString().trim()); + } + } +} + diff --git a/shell/src/test/java/org/apache/kafka/shell/MetadataNodeManagerTest.java b/shell/src/test/java/org/apache/kafka/shell/MetadataNodeManagerTest.java new file mode 100644 index 0000000..81483f5 --- /dev/null +++ b/shell/src/test/java/org/apache/kafka/shell/MetadataNodeManagerTest.java @@ -0,0 +1,293 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.shell; + +import org.apache.kafka.common.Uuid; +import org.apache.kafka.common.config.ConfigResource; +import org.apache.kafka.common.metadata.ClientQuotaRecord; +import org.apache.kafka.common.metadata.ConfigRecord; +import org.apache.kafka.common.metadata.FenceBrokerRecord; +import org.apache.kafka.common.metadata.PartitionChangeRecord; +import org.apache.kafka.common.metadata.PartitionRecord; +import org.apache.kafka.common.metadata.PartitionRecordJsonConverter; +import org.apache.kafka.common.metadata.RegisterBrokerRecord; +import org.apache.kafka.common.metadata.RemoveTopicRecord; +import org.apache.kafka.common.metadata.TopicRecord; +import org.apache.kafka.common.metadata.UnfenceBrokerRecord; +import org.apache.kafka.common.metadata.UnregisterBrokerRecord; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.util.Arrays; + +import static org.apache.kafka.metadata.LeaderConstants.NO_LEADER_CHANGE; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; + + +public class MetadataNodeManagerTest { + + private MetadataNodeManager metadataNodeManager; + + @BeforeEach + public void setup() throws Exception { + metadataNodeManager = new MetadataNodeManager(); + metadataNodeManager.setup(); + } + + @AfterEach + public void cleanup() throws Exception { + metadataNodeManager.close(); + } + + @Test + public void testRegisterBrokerRecordAndUnregisterBrokerRecord() { + // Register broker + RegisterBrokerRecord record = new RegisterBrokerRecord() + .setBrokerId(1) + .setBrokerEpoch(2); + metadataNodeManager.handleMessage(record); + + assertEquals(record.toString(), + metadataNodeManager.getData().root().directory("brokers", "1").file("registration").contents()); + assertEquals("true", + metadataNodeManager.getData().root().directory("brokers", "1").file("isFenced").contents()); + + // Unregister broker + UnregisterBrokerRecord unregisterBrokerRecord = new UnregisterBrokerRecord() + .setBrokerId(1); + metadataNodeManager.handleMessage(unregisterBrokerRecord); + assertFalse(metadataNodeManager.getData().root().directory("brokers").children().containsKey("1")); + } + + @Test + public void testTopicRecordAndRemoveTopicRecord() { + // Add topic + TopicRecord topicRecord = new TopicRecord() + .setName("topicName") + .setTopicId(Uuid.fromString("GcaQDl2UTsCNs1p9s37XkQ")); + + metadataNodeManager.handleMessage(topicRecord); + + assertEquals("topicName", + metadataNodeManager.getData().root().directory("topics", "topicName").file("name").contents()); + assertEquals("GcaQDl2UTsCNs1p9s37XkQ", + metadataNodeManager.getData().root().directory("topics", "topicName").file("id").contents()); + assertEquals("topicName", + metadataNodeManager.getData().root().directory("topicIds", "GcaQDl2UTsCNs1p9s37XkQ").file("name").contents()); + assertEquals("GcaQDl2UTsCNs1p9s37XkQ", + metadataNodeManager.getData().root().directory("topicIds", "GcaQDl2UTsCNs1p9s37XkQ").file("id").contents()); + + // Remove topic + RemoveTopicRecord 
removeTopicRecord = new RemoveTopicRecord() + .setTopicId(Uuid.fromString("GcaQDl2UTsCNs1p9s37XkQ")); + + metadataNodeManager.handleMessage(removeTopicRecord); + + assertFalse( + metadataNodeManager.getData().root().directory("topicIds").children().containsKey("GcaQDl2UTsCNs1p9s37XkQ")); + assertFalse( + metadataNodeManager.getData().root().directory("topics").children().containsKey("topicName")); + } + + @Test + public void testPartitionRecord() { + PartitionRecord record = new PartitionRecord() + .setTopicId(Uuid.fromString("GcaQDl2UTsCNs1p9s37XkQ")) + .setPartitionId(0) + .setLeaderEpoch(1) + .setReplicas(Arrays.asList(1, 2, 3)) + .setIsr(Arrays.asList(1, 2, 3)); + + metadataNodeManager.handleMessage(record); + assertEquals( + PartitionRecordJsonConverter.write(record, PartitionRecord.HIGHEST_SUPPORTED_VERSION).toPrettyString(), + metadataNodeManager.getData().root().directory("topicIds", "GcaQDl2UTsCNs1p9s37XkQ", "0").file("data").contents()); + } + + @Test + public void testValidConfigRecord() { + checkValidConfigRecord(ConfigResource.Type.BROKER.id(), "broker"); + checkValidConfigRecord(ConfigResource.Type.TOPIC.id(), "topic"); + } + + private void checkValidConfigRecord(byte resourceType, String typeString) { + ConfigRecord configRecord = new ConfigRecord() + .setResourceType(resourceType) + .setResourceName("0") + .setName("name") + .setValue("kraft"); + + metadataNodeManager.handleMessage(configRecord); + assertEquals("kraft", + metadataNodeManager.getData().root().directory("configs", typeString, "0").file("name").contents()); + + // null value indicates delete + configRecord.setValue(null); + metadataNodeManager.handleMessage(configRecord); + assertFalse( + metadataNodeManager.getData().root().directory("configs", typeString, "0").children().containsKey("name")); + } + + @Test + public void testInvalidConfigRecord() { + checkInvalidConfigRecord(ConfigResource.Type.BROKER_LOGGER.id()); + checkInvalidConfigRecord(ConfigResource.Type.UNKNOWN.id()); + } + + private void checkInvalidConfigRecord(byte resourceType) { + ConfigRecord configRecord = new ConfigRecord() + .setResourceType(resourceType) + .setResourceName("0") + .setName("name") + .setValue("kraft"); + metadataNodeManager.handleMessage(configRecord); + assertFalse(metadataNodeManager.getData().root().children().containsKey("configs")); + } + + @Test + public void testPartitionChangeRecord() { + PartitionRecord oldPartitionRecord = new PartitionRecord() + .setTopicId(Uuid.fromString("GcaQDl2UTsCNs1p9s37XkQ")) + .setPartitionId(0) + .setPartitionEpoch(0) + .setLeader(0) + .setLeaderEpoch(0) + .setIsr(Arrays.asList(0, 1, 2)) + .setReplicas(Arrays.asList(0, 1, 2)); + + PartitionChangeRecord partitionChangeRecord = new PartitionChangeRecord() + .setTopicId(Uuid.fromString("GcaQDl2UTsCNs1p9s37XkQ")) + .setPartitionId(0) + .setLeader(NO_LEADER_CHANGE) + .setReplicas(Arrays.asList(0, 1, 2)); + + PartitionRecord newPartitionRecord = new PartitionRecord() + .setTopicId(Uuid.fromString("GcaQDl2UTsCNs1p9s37XkQ")) + .setPartitionId(0) + .setPartitionEpoch(1) + .setLeader(0) + .setLeaderEpoch(0) + .setIsr(Arrays.asList(0, 1, 2)) + .setReplicas(Arrays.asList(0, 1, 2)); + + // Change nothing + checkPartitionChangeRecord( + oldPartitionRecord, + partitionChangeRecord, + newPartitionRecord + ); + + // Change isr + checkPartitionChangeRecord( + oldPartitionRecord, + partitionChangeRecord.duplicate().setIsr(Arrays.asList(0, 2)), + newPartitionRecord.duplicate().setIsr(Arrays.asList(0, 2)) + ); + + // Change leader + checkPartitionChangeRecord( 
+ oldPartitionRecord, + partitionChangeRecord.duplicate().setLeader(1), + newPartitionRecord.duplicate().setLeader(1).setLeaderEpoch(1) + ); + } + + private void checkPartitionChangeRecord(PartitionRecord oldPartitionRecord, + PartitionChangeRecord partitionChangeRecord, + PartitionRecord newPartitionRecord) { + metadataNodeManager.handleMessage(oldPartitionRecord); + metadataNodeManager.handleMessage(partitionChangeRecord); + assertEquals( + PartitionRecordJsonConverter.write(newPartitionRecord, PartitionRecord.HIGHEST_SUPPORTED_VERSION).toPrettyString(), + metadataNodeManager.getData().root() + .directory("topicIds", oldPartitionRecord.topicId().toString(), oldPartitionRecord.partitionId() + "") + .file("data").contents() + ); + } + + @Test + public void testUnfenceBrokerRecordAndFenceBrokerRecord() { + RegisterBrokerRecord record = new RegisterBrokerRecord() + .setBrokerId(1) + .setBrokerEpoch(2); + metadataNodeManager.handleMessage(record); + + assertEquals("true", + metadataNodeManager.getData().root().directory("brokers", "1").file("isFenced").contents()); + + UnfenceBrokerRecord unfenceBrokerRecord = new UnfenceBrokerRecord() + .setId(1) + .setEpoch(2); + metadataNodeManager.handleMessage(unfenceBrokerRecord); + assertEquals("false", + metadataNodeManager.getData().root().directory("brokers", "1").file("isFenced").contents()); + + FenceBrokerRecord fenceBrokerRecord = new FenceBrokerRecord() + .setId(1) + .setEpoch(2); + metadataNodeManager.handleMessage(fenceBrokerRecord); + assertEquals("true", + metadataNodeManager.getData().root().directory("brokers", "1").file("isFenced").contents()); + } + + @Test + public void testClientQuotaRecord() { + ClientQuotaRecord record = new ClientQuotaRecord() + .setEntity(Arrays.asList( + new ClientQuotaRecord.EntityData() + .setEntityType("user") + .setEntityName("kraft"), + new ClientQuotaRecord.EntityData() + .setEntityType("client") + .setEntityName("kstream") + )) + .setKey("producer_byte_rate") + .setValue(1000.0); + + metadataNodeManager.handleMessage(record); + + assertEquals("1000.0", + metadataNodeManager.getData().root().directory("client-quotas", + "client", "kstream", + "user", "kraft").file("producer_byte_rate").contents()); + + metadataNodeManager.handleMessage(record.setRemove(true)); + + assertFalse( + metadataNodeManager.getData().root().directory("client-quotas", + "client", "kstream", + "user", "kraft").children().containsKey("producer_byte_rate")); + + record = new ClientQuotaRecord() + .setEntity(Arrays.asList( + new ClientQuotaRecord.EntityData() + .setEntityType("user") + .setEntityName(null) + )) + .setKey("producer_byte_rate") + .setValue(2000.0); + + metadataNodeManager.handleMessage(record); + + assertEquals("2000.0", + metadataNodeManager.getData().root().directory("client-quotas", + "user", "").file("producer_byte_rate").contents()); + } +} diff --git a/shell/src/test/java/org/apache/kafka/shell/MetadataNodeTest.java b/shell/src/test/java/org/apache/kafka/shell/MetadataNodeTest.java new file mode 100644 index 0000000..42223c7 --- /dev/null +++ b/shell/src/test/java/org/apache/kafka/shell/MetadataNodeTest.java @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.shell; + +import static java.util.concurrent.TimeUnit.MILLISECONDS; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import org.apache.kafka.shell.MetadataNode.DirectoryNode; +import org.apache.kafka.shell.MetadataNode.FileNode; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; + +import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; + +@Timeout(value = 120000, unit = MILLISECONDS) +public class MetadataNodeTest { + @Test + public void testMkdirs() { + DirectoryNode root = new DirectoryNode(); + DirectoryNode defNode = root.mkdirs("abc", "def"); + DirectoryNode defNode2 = root.mkdirs("abc", "def"); + assertTrue(defNode == defNode2); + DirectoryNode defNode3 = root.directory("abc", "def"); + assertTrue(defNode == defNode3); + root.mkdirs("ghi"); + assertEquals(new HashSet<>(Arrays.asList("abc", "ghi")), root.children().keySet()); + assertEquals(Collections.singleton("def"), root.mkdirs("abc").children().keySet()); + assertEquals(Collections.emptySet(), defNode.children().keySet()); + } + + @Test + public void testRmrf() { + DirectoryNode root = new DirectoryNode(); + DirectoryNode foo = root.mkdirs("foo"); + foo.mkdirs("a"); + foo.mkdirs("b"); + root.mkdirs("baz"); + assertEquals(new HashSet<>(Arrays.asList("foo", "baz")), root.children().keySet()); + root.rmrf("foo", "a"); + assertEquals(new HashSet<>(Arrays.asList("b")), foo.children().keySet()); + root.rmrf("foo"); + assertEquals(new HashSet<>(Collections.singleton("baz")), root.children().keySet()); + } + + @Test + public void testCreateFiles() { + DirectoryNode root = new DirectoryNode(); + DirectoryNode abcdNode = root.mkdirs("abcd"); + FileNode quuxNodde = abcdNode.create("quux"); + quuxNodde.setContents("quux contents"); + assertEquals("quux contents", quuxNodde.contents()); + assertThrows(NotDirectoryException.class, () -> root.mkdirs("abcd", "quux")); + } +} diff --git a/storage/api/src/main/java/org/apache/kafka/server/log/remote/storage/LogSegmentData.java b/storage/api/src/main/java/org/apache/kafka/server/log/remote/storage/LogSegmentData.java new file mode 100644 index 0000000..905bfd0 --- /dev/null +++ b/storage/api/src/main/java/org/apache/kafka/server/log/remote/storage/LogSegmentData.java @@ -0,0 +1,140 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.server.log.remote.storage; + +import org.apache.kafka.common.annotation.InterfaceStability; + +import java.nio.ByteBuffer; +import java.nio.file.Path; +import java.util.Objects; +import java.util.Optional; + +/** + * This represents all the required data and indexes for a specific log segment that needs to be stored in the remote + * storage. This is passed with {@link RemoteStorageManager#copyLogSegmentData(RemoteLogSegmentMetadata, LogSegmentData)} + * while copying a specific log segment to the remote storage. + */ +@InterfaceStability.Evolving +public class LogSegmentData { + + private final Path logSegment; + private final Path offsetIndex; + private final Path timeIndex; + private final Optional transactionIndex; + private final Path producerSnapshotIndex; + private final ByteBuffer leaderEpochIndex; + + /** + * Creates a LogSegmentData instance with data and indexes. + * + * @param logSegment actual log segment file + * @param offsetIndex offset index file + * @param timeIndex time index file + * @param transactionIndex transaction index file, which can be null + * @param producerSnapshotIndex producer snapshot until this segment + * @param leaderEpochIndex leader-epoch-index until this segment + */ + public LogSegmentData(Path logSegment, + Path offsetIndex, + Path timeIndex, + Optional transactionIndex, + Path producerSnapshotIndex, + ByteBuffer leaderEpochIndex) { + this.logSegment = Objects.requireNonNull(logSegment, "logSegment can not be null"); + this.offsetIndex = Objects.requireNonNull(offsetIndex, "offsetIndex can not be null"); + this.timeIndex = Objects.requireNonNull(timeIndex, "timeIndex can not be null"); + this.transactionIndex = Objects.requireNonNull(transactionIndex, "transactionIndex can not be null"); + this.producerSnapshotIndex = Objects.requireNonNull(producerSnapshotIndex, "producerSnapshotIndex can not be null"); + this.leaderEpochIndex = Objects.requireNonNull(leaderEpochIndex, "leaderEpochIndex can not be null"); + } + + /** + * @return Log segment file of this segment. + */ + public Path logSegment() { + return logSegment; + } + + /** + * @return Offset index file. + */ + public Path offsetIndex() { + return offsetIndex; + } + + /** + * @return Time index file of this segment. + */ + public Path timeIndex() { + return timeIndex; + } + + /** + * @return Transaction index file of this segment if it exists. + */ + public Optional transactionIndex() { + return transactionIndex; + } + + /** + * @return Producer snapshot file until this segment. + */ + public Path producerSnapshotIndex() { + return producerSnapshotIndex; + } + + /** + * @return Leader epoch index until this segment. 
+ */ + public ByteBuffer leaderEpochIndex() { + return leaderEpochIndex; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + LogSegmentData that = (LogSegmentData) o; + return Objects.equals(logSegment, that.logSegment) && + Objects.equals(offsetIndex, that.offsetIndex) && + Objects.equals(timeIndex, that.timeIndex) && + Objects.equals(transactionIndex, that.transactionIndex) && + Objects.equals(producerSnapshotIndex, that.producerSnapshotIndex) && + Objects.equals(leaderEpochIndex, that.leaderEpochIndex); + } + + @Override + public int hashCode() { + return Objects.hash(logSegment, offsetIndex, timeIndex, transactionIndex, producerSnapshotIndex, leaderEpochIndex); + } + + @Override + public String toString() { + return "LogSegmentData{" + + "logSegment=" + logSegment + + ", offsetIndex=" + offsetIndex + + ", timeIndex=" + timeIndex + + ", txnIndex=" + transactionIndex + + ", producerSnapshotIndex=" + producerSnapshotIndex + + ", leaderEpochIndex=" + leaderEpochIndex + + '}'; + } +} \ No newline at end of file diff --git a/storage/api/src/main/java/org/apache/kafka/server/log/remote/storage/RemoteLogMetadata.java b/storage/api/src/main/java/org/apache/kafka/server/log/remote/storage/RemoteLogMetadata.java new file mode 100644 index 0000000..74d5c3d --- /dev/null +++ b/storage/api/src/main/java/org/apache/kafka/server/log/remote/storage/RemoteLogMetadata.java @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.server.log.remote.storage; + +import org.apache.kafka.common.TopicIdPartition; +import org.apache.kafka.common.annotation.InterfaceStability; + +/** + * Base class for remote log metadata objects like {@link RemoteLogSegmentMetadata}, {@link RemoteLogSegmentMetadataUpdate}, + * and {@link RemotePartitionDeleteMetadata}. + */ +@InterfaceStability.Evolving +public abstract class RemoteLogMetadata { + + /** + * Broker id from which this event is generated. + */ + private final int brokerId; + + /** + * Epoch time in milli seconds at which this event is generated. + */ + private final long eventTimestampMs; + + protected RemoteLogMetadata(int brokerId, long eventTimestampMs) { + this.brokerId = brokerId; + this.eventTimestampMs = eventTimestampMs; + } + + /** + * @return Epoch time in milli seconds at which this event is occurred. + */ + public long eventTimestampMs() { + return eventTimestampMs; + } + + /** + * @return Broker id from which this event is generated. + */ + public int brokerId() { + return brokerId; + } + + /** + * @return TopicIdPartition for which this event is generated. 
+ */ + public abstract TopicIdPartition topicIdPartition(); +} diff --git a/storage/api/src/main/java/org/apache/kafka/server/log/remote/storage/RemoteLogMetadataManager.java b/storage/api/src/main/java/org/apache/kafka/server/log/remote/storage/RemoteLogMetadataManager.java new file mode 100644 index 0000000..9a29746 --- /dev/null +++ b/storage/api/src/main/java/org/apache/kafka/server/log/remote/storage/RemoteLogMetadataManager.java @@ -0,0 +1,204 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.server.log.remote.storage; + +import org.apache.kafka.common.Configurable; +import org.apache.kafka.common.TopicIdPartition; +import org.apache.kafka.common.annotation.InterfaceStability; + +import java.io.Closeable; +import java.util.Iterator; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.concurrent.CompletableFuture; + +/** + * This interface provides storing and fetching remote log segment metadata with strongly consistent semantics. + *
+ * This class can be plugged into a Kafka cluster by adding the implementation class as the
+ * remote.log.metadata.manager.class.name property value. There is an inbuilt implementation backed by
+ * topic storage in the local cluster. This is used as the default implementation if
+ * remote.log.metadata.manager.class.name is not configured.
+ *
+ * The remote.log.metadata.manager.class.path property is the class path of the RemoteLogMetadataManager
+ * implementation. If specified, the RemoteLogMetadataManager implementation and its dependent libraries will be loaded
+ * by a dedicated classloader which searches this class path before the Kafka broker class path. The syntax of this
+ * parameter is the same as the standard Java class path string.
+ *
+ * The remote.log.metadata.manager.listener.name property is the listener name of the local broker to which
+ * the RemoteLogMetadataManager implementation should connect if needed. When this is configured, all other
+ * required properties can be passed with the prefix 'remote.log.metadata.manager.listener.'.
+ *
+ * "cluster.id", "broker.id" and all other properties prefixed with "remote.log.metadata." are passed when
+ * {@link #configure(Map)} is invoked on this instance.
+ */
+@InterfaceStability.Evolving
+public interface RemoteLogMetadataManager extends Configurable, Closeable {
+
+    /**
+     * This method is used to add {@link RemoteLogSegmentMetadata} asynchronously with the containing
+     * {@link RemoteLogSegmentId} into {@link RemoteLogMetadataManager}.
+     *
+     * RemoteLogSegmentMetadata is identified by RemoteLogSegmentId and it should have the initial state,
+     * which is {@link RemoteLogSegmentState#COPY_SEGMENT_STARTED}.
+     *
+     * {@link #updateRemoteLogSegmentMetadata(RemoteLogSegmentMetadataUpdate)} should be used to update an existing RemoteLogSegmentMetadata.
+     *
+     * @param remoteLogSegmentMetadata metadata about the remote log segment.
+     * @throws RemoteStorageException if there are any storage related errors occurred.
+     * @throws IllegalArgumentException if the given metadata instance does not have the state as {@link RemoteLogSegmentState#COPY_SEGMENT_STARTED}
+     * @return a CompletableFuture which will complete once this operation is finished.
+     */
+    CompletableFuture<Void> addRemoteLogSegmentMetadata(RemoteLogSegmentMetadata remoteLogSegmentMetadata) throws RemoteStorageException;
+
+    /**
+     * This method is used to update the {@link RemoteLogSegmentMetadata} asynchronously. Currently, it allows updating
+     * the state based on the life cycle of the segment. It can go through the below state transitions.
+     *
+     * <pre>
+     * +---------------------+            +----------------------+
+     * |COPY_SEGMENT_STARTED |----------->|COPY_SEGMENT_FINISHED |
+     * +-------------------+-+            +--+-------------------+
+     *                     |                 |
+     *                     |                 |
+     *                     v                 v
+     *                  +--+-----------------+-+
+     *                  |DELETE_SEGMENT_STARTED|
+     *                  +-----------+----------+
+     *                              |
+     *                              |
+     *                              v
+     *                  +-----------+-----------+
+     *                  |DELETE_SEGMENT_FINISHED|
+     *                  +-----------------------+
+     * </pre>
+     *
+     * {@link RemoteLogSegmentState#COPY_SEGMENT_STARTED} - This state indicates that the segment copying to remote storage is started but not yet finished.
+     * {@link RemoteLogSegmentState#COPY_SEGMENT_FINISHED} - This state indicates that the segment copying to remote storage is finished.
+     *
+     * The leader broker copies the log segments to the remote storage and puts the remote log segment metadata with the
+     * state as "COPY_SEGMENT_STARTED" and updates the state as "COPY_SEGMENT_FINISHED" once the copy is successful.
+     *
+     * {@link RemoteLogSegmentState#DELETE_SEGMENT_STARTED} - This state indicates that the segment deletion is started but not yet finished.
+     * {@link RemoteLogSegmentState#DELETE_SEGMENT_FINISHED} - This state indicates that the segment is deleted successfully.
+     *
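For illustration only (this sketch is not part of the patch): a rough example of how a leader-side caller might drive the copy life cycle described above. It assumes the sketch lives in the org.apache.kafka.server.log.remote.storage package, and it guesses at the RemoteLogSegmentMetadataUpdate constructor (segment id, event timestamp, target state, broker id), since that class does not appear in this excerpt.

    package org.apache.kafka.server.log.remote.storage;

    import org.apache.kafka.common.TopicIdPartition;
    import org.apache.kafka.common.Uuid;

    import java.util.Collections;
    import java.util.Map;

    final class CopyLifecycleSketch {
        // Publishes COPY_SEGMENT_STARTED, copies the data, then publishes COPY_SEGMENT_FINISHED.
        static void copyOneSegment(RemoteLogMetadataManager rlmm,
                                   TopicIdPartition partition,
                                   int brokerId,
                                   int segmentSizeInBytes) throws Exception {
            RemoteLogSegmentId segmentId = new RemoteLogSegmentId(partition, Uuid.randomUuid());
            // Leader epoch -> start offset of that epoch within the segment; example values.
            Map<Integer, Long> segmentLeaderEpochs = Collections.singletonMap(0, 0L);

            // This constructor defaults the segment state to COPY_SEGMENT_STARTED.
            RemoteLogSegmentMetadata copyStarted = new RemoteLogSegmentMetadata(segmentId,
                0L, 999L, System.currentTimeMillis(), brokerId, System.currentTimeMillis(),
                segmentSizeInBytes, segmentLeaderEpochs);
            rlmm.addRemoteLogSegmentMetadata(copyStarted).get();

            // ... copy the segment and its indexes to remote storage (RemoteStorageManager#copyLogSegmentData) ...

            // Assumed constructor argument order for RemoteLogSegmentMetadataUpdate (not shown in this excerpt).
            rlmm.updateRemoteLogSegmentMetadata(new RemoteLogSegmentMetadataUpdate(segmentId,
                System.currentTimeMillis(), RemoteLogSegmentState.COPY_SEGMENT_FINISHED, brokerId)).get();
        }
    }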
                + * Leader partitions publish both the above delete segment events when remote log retention is reached for the + * respective segments. Remote Partition Removers also publish these events when a segment is deleted as part of + * the remote partition deletion. + * + * @param remoteLogSegmentMetadataUpdate update of the remote log segment metadata. + * @throws RemoteStorageException if there are any storage related errors occurred. + * @throws RemoteResourceNotFoundException when there are no resources associated with the given remoteLogSegmentMetadataUpdate. + * @throws IllegalArgumentException if the given metadata instance has the state as {@link RemoteLogSegmentState#COPY_SEGMENT_STARTED} + * @return a CompletableFuture which will complete once this operation is finished. + */ + CompletableFuture updateRemoteLogSegmentMetadata(RemoteLogSegmentMetadataUpdate remoteLogSegmentMetadataUpdate) + throws RemoteStorageException; + + /** + * Returns {@link RemoteLogSegmentMetadata} if it exists for the given topic partition containing the offset with + * the given leader-epoch for the offset, else returns {@link Optional#empty()}. + * + * @param topicIdPartition topic partition + * @param epochForOffset leader epoch for the given offset + * @param offset offset + * @return the requested remote log segment metadata if it exists. + * @throws RemoteStorageException if there are any storage related errors occurred. + */ + Optional remoteLogSegmentMetadata(TopicIdPartition topicIdPartition, + int epochForOffset, + long offset) + throws RemoteStorageException; + + /** + * Returns the highest log offset of topic partition for the given leader epoch in remote storage. This is used by + * remote log management subsystem to know up to which offset the segments have been copied to remote storage for + * a given leader epoch. + * + * @param topicIdPartition topic partition + * @param leaderEpoch leader epoch + * @return the requested highest log offset if exists. + * @throws RemoteStorageException if there are any storage related errors occurred. + */ + Optional highestOffsetForEpoch(TopicIdPartition topicIdPartition, + int leaderEpoch) throws RemoteStorageException; + + /** + * This method is used to update the metadata about remote partition delete event asynchronously. Currently, it allows updating the + * state ({@link RemotePartitionDeleteState}) of a topic partition in remote metadata storage. Controller invokes + * this method with {@link RemotePartitionDeleteMetadata} having state as {@link RemotePartitionDeleteState#DELETE_PARTITION_MARKED}. + * So, remote partition removers can act on this event to clean the respective remote log segments of the partition. + *
+     * In the case of default RLMM implementation, remote partition remover processes {@link RemotePartitionDeleteState#DELETE_PARTITION_MARKED}
+     * <ul>
+     * <li>sends an event with state as {@link RemotePartitionDeleteState#DELETE_PARTITION_STARTED}</li>
+     * <li>gets all the remote log segments and deletes them.</li>
+     * <li>sends an event with state as {@link RemotePartitionDeleteState#DELETE_PARTITION_FINISHED} once all the remote log segments are
+     * deleted.</li>
+     * </ul>
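Again for illustration only (not part of the patch): a sketch of the delete-partition flow just listed, as a remote partition remover might drive it. The RemotePartitionDeleteMetadata constructor used here (partition, state, event timestamp, broker id) is an assumption, since that class is not included in this excerpt; actual segment data deletion would go through the RemoteStorageManager plugin.

    package org.apache.kafka.server.log.remote.storage;

    import org.apache.kafka.common.TopicIdPartition;

    import java.util.Iterator;

    final class DeletePartitionSketch {
        // Publishes DELETE_PARTITION_STARTED, walks the partition's segments, then publishes DELETE_PARTITION_FINISHED.
        static void deletePartition(RemoteLogMetadataManager rlmm,
                                    TopicIdPartition partition,
                                    int brokerId) throws Exception {
            rlmm.putRemotePartitionDeleteMetadata(new RemotePartitionDeleteMetadata(partition,
                RemotePartitionDeleteState.DELETE_PARTITION_STARTED, System.currentTimeMillis(), brokerId)).get();

            Iterator<RemoteLogSegmentMetadata> segments = rlmm.listRemoteLogSegments(partition);
            while (segments.hasNext()) {
                RemoteLogSegmentMetadata segment = segments.next();
                // Delete 'segment' from the remote store here, then publish DELETE_SEGMENT_STARTED/FINISHED updates.
            }

            rlmm.putRemotePartitionDeleteMetadata(new RemotePartitionDeleteMetadata(partition,
                RemotePartitionDeleteState.DELETE_PARTITION_FINISHED, System.currentTimeMillis(), brokerId)).get();
        }
    }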
                + * + * @param remotePartitionDeleteMetadata update on delete state of a partition. + * @throws RemoteStorageException if there are any storage related errors occurred. + * @throws RemoteResourceNotFoundException when there are no resources associated with the given remotePartitionDeleteMetadata. + * @return a CompletableFuture which will complete once this operation is finished. + */ + CompletableFuture putRemotePartitionDeleteMetadata(RemotePartitionDeleteMetadata remotePartitionDeleteMetadata) + throws RemoteStorageException; + + /** + * Returns all the remote log segment metadata of the given topicIdPartition. + *
                + * Remote Partition Removers uses this method to fetch all the segments for a given topic partition, so that they + * can delete them. + * + * @return Iterator of all the remote log segment metadata for the given topic partition. + */ + Iterator listRemoteLogSegments(TopicIdPartition topicIdPartition) + throws RemoteStorageException; + + /** + * Returns iterator of remote log segment metadata, sorted by {@link RemoteLogSegmentMetadata#startOffset()} in + * ascending order which contains the given leader epoch. This is used by remote log retention management subsystem + * to fetch the segment metadata for a given leader epoch. + * + * @param topicIdPartition topic partition + * @param leaderEpoch leader epoch + * @return Iterator of remote segments, sorted by start offset in ascending order. + */ + Iterator listRemoteLogSegments(TopicIdPartition topicIdPartition, + int leaderEpoch) throws RemoteStorageException; + + /** + * This method is invoked only when there are changes in leadership of the topic partitions that this broker is + * responsible for. + * + * @param leaderPartitions partitions that have become leaders on this broker. + * @param followerPartitions partitions that have become followers on this broker. + */ + void onPartitionLeadershipChanges(Set leaderPartitions, + Set followerPartitions); + + /** + * This method is invoked only when the topic partitions are stopped on this broker. This can happen when a + * partition is emigrated to other broker or a partition is deleted. + * + * @param partitions topic partitions that have been stopped. + */ + void onStopPartitions(Set partitions); +} \ No newline at end of file diff --git a/storage/api/src/main/java/org/apache/kafka/server/log/remote/storage/RemoteLogSegmentId.java b/storage/api/src/main/java/org/apache/kafka/server/log/remote/storage/RemoteLogSegmentId.java new file mode 100644 index 0000000..cbebd9f --- /dev/null +++ b/storage/api/src/main/java/org/apache/kafka/server/log/remote/storage/RemoteLogSegmentId.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.server.log.remote.storage; + +import org.apache.kafka.common.TopicIdPartition; +import org.apache.kafka.common.Uuid; +import org.apache.kafka.common.annotation.InterfaceStability; + +import java.util.Objects; + +/** + * This class represents a universally unique identifier associated to a topic partition's log segment. This will be + * regenerated for every attempt of copying a specific log segment in {@link RemoteStorageManager#copyLogSegmentData(RemoteLogSegmentMetadata, LogSegmentData)}. + * Once it is stored in remote storage, it is used to access that segment later from remote log metadata storage. 
+ */ +@InterfaceStability.Evolving +public class RemoteLogSegmentId { + + private final TopicIdPartition topicIdPartition; + private final Uuid id; + + public RemoteLogSegmentId(TopicIdPartition topicIdPartition, Uuid id) { + this.topicIdPartition = Objects.requireNonNull(topicIdPartition, "topicIdPartition can not be null"); + this.id = Objects.requireNonNull(id, "id can not be null"); + } + + /** + * @return TopicIdPartition of this remote log segment. + */ + public TopicIdPartition topicIdPartition() { + return topicIdPartition; + } + + /** + * @return Universally Unique Id of this remote log segment. + */ + public Uuid id() { + return id; + } + + @Override + public String toString() { + return "RemoteLogSegmentId{" + + "topicIdPartition=" + topicIdPartition + + ", id=" + id + + '}'; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + RemoteLogSegmentId that = (RemoteLogSegmentId) o; + return Objects.equals(topicIdPartition, that.topicIdPartition) && Objects.equals(id, that.id); + } + + @Override + public int hashCode() { + return Objects.hash(topicIdPartition, id); + } + +} diff --git a/storage/api/src/main/java/org/apache/kafka/server/log/remote/storage/RemoteLogSegmentMetadata.java b/storage/api/src/main/java/org/apache/kafka/server/log/remote/storage/RemoteLogSegmentMetadata.java new file mode 100644 index 0000000..e0cbb79 --- /dev/null +++ b/storage/api/src/main/java/org/apache/kafka/server/log/remote/storage/RemoteLogSegmentMetadata.java @@ -0,0 +1,265 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.server.log.remote.storage; + +import org.apache.kafka.common.TopicIdPartition; +import org.apache.kafka.common.annotation.InterfaceStability; + +import java.util.Collections; +import java.util.Map; +import java.util.NavigableMap; +import java.util.Objects; +import java.util.TreeMap; + +/** + * It describes the metadata about a topic partition's remote log segment in the remote storage. This is uniquely + * represented with {@link RemoteLogSegmentId}. + *

                + * New instance is always created with the state as {@link RemoteLogSegmentState#COPY_SEGMENT_STARTED}. This can be + * updated by applying {@link RemoteLogSegmentMetadataUpdate} for the respective {@link RemoteLogSegmentId} of the + * {@code RemoteLogSegmentMetadata}. + */ +@InterfaceStability.Evolving +public class RemoteLogSegmentMetadata extends RemoteLogMetadata { + + /** + * Universally unique remote log segment id. + */ + private final RemoteLogSegmentId remoteLogSegmentId; + + /** + * Start offset of this segment. + */ + private final long startOffset; + + /** + * End offset of this segment. + */ + private final long endOffset; + + /** + * Maximum timestamp in milli seconds in the segment + */ + private final long maxTimestampMs; + + /** + * LeaderEpoch vs offset for messages within this segment. + */ + private final NavigableMap segmentLeaderEpochs; + + /** + * Size of the segment in bytes. + */ + private final int segmentSizeInBytes; + + /** + * It indicates the state in which the action is executed on this segment. + */ + private final RemoteLogSegmentState state; + + /** + * Creates an instance with the given metadata of remote log segment. + *

                + * {@code segmentLeaderEpochs} can not be empty. If all the records in this segment belong to the same leader epoch + * then it should have an entry with epoch mapping to start-offset of this segment. + * + * @param remoteLogSegmentId Universally unique remote log segment id. + * @param startOffset Start offset of this segment (inclusive). + * @param endOffset End offset of this segment (inclusive). + * @param maxTimestampMs Maximum timestamp in milli seconds in this segment. + * @param brokerId Broker id from which this event is generated. + * @param eventTimestampMs Epoch time in milli seconds at which the remote log segment is copied to the remote tier storage. + * @param segmentSizeInBytes Size of this segment in bytes. + * @param state State of the respective segment of remoteLogSegmentId. + * @param segmentLeaderEpochs leader epochs occurred within this segment. + */ + public RemoteLogSegmentMetadata(RemoteLogSegmentId remoteLogSegmentId, + long startOffset, + long endOffset, + long maxTimestampMs, + int brokerId, + long eventTimestampMs, + int segmentSizeInBytes, + RemoteLogSegmentState state, + Map segmentLeaderEpochs) { + super(brokerId, eventTimestampMs); + this.remoteLogSegmentId = Objects.requireNonNull(remoteLogSegmentId, "remoteLogSegmentId can not be null"); + this.state = Objects.requireNonNull(state, "state can not be null"); + + this.startOffset = startOffset; + this.endOffset = endOffset; + this.maxTimestampMs = maxTimestampMs; + this.segmentSizeInBytes = segmentSizeInBytes; + + if (segmentLeaderEpochs == null || segmentLeaderEpochs.isEmpty()) { + throw new IllegalArgumentException("segmentLeaderEpochs can not be null or empty"); + } + + this.segmentLeaderEpochs = Collections.unmodifiableNavigableMap(new TreeMap<>(segmentLeaderEpochs)); + } + + /** + * Creates an instance with the given metadata of remote log segment and its state as {@link RemoteLogSegmentState#COPY_SEGMENT_STARTED}. + *
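A minimal construction sketch with hypothetical offsets, size and a single leader epoch entry; it shows that the constructor omitting the state defaults to COPY_SEGMENT_STARTED and that segmentLeaderEpochs must carry at least one epoch-to-start-offset mapping.

import org.apache.kafka.common.TopicIdPartition;
import org.apache.kafka.common.TopicPartition;
import org.apache.kafka.common.Uuid;
import org.apache.kafka.server.log.remote.storage.RemoteLogSegmentId;
import org.apache.kafka.server.log.remote.storage.RemoteLogSegmentMetadata;
import org.apache.kafka.server.log.remote.storage.RemoteLogSegmentState;

import java.util.Map;
import java.util.TreeMap;

public class RemoteLogSegmentMetadataExample {
    public static void main(String[] args) {
        TopicIdPartition tp = new TopicIdPartition(Uuid.randomUuid(), new TopicPartition("orders", 0));
        RemoteLogSegmentId segmentId = new RemoteLogSegmentId(tp, Uuid.randomUuid());

        // Leader epoch -> start offset of that epoch within the segment; at least one entry is required.
        Map<Integer, Long> leaderEpochs = new TreeMap<>();
        leaderEpochs.put(5, 1000L);

        // The shorter constructor defaults the state to COPY_SEGMENT_STARTED. All values below are illustrative.
        RemoteLogSegmentMetadata metadata = new RemoteLogSegmentMetadata(
                segmentId,
                1000L,                        // start offset (inclusive)
                1999L,                        // end offset (inclusive)
                System.currentTimeMillis(),   // max record timestamp in ms (made up)
                0,                            // broker id that generated this event
                System.currentTimeMillis(),   // event timestamp in ms
                1024 * 1024,                  // segment size in bytes
                leaderEpochs);

        System.out.println(metadata.state() == RemoteLogSegmentState.COPY_SEGMENT_STARTED); // true
    }
}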

                + * {@code segmentLeaderEpochs} can not be empty. If all the records in this segment belong to the same leader epoch + * then it should have an entry with epoch mapping to start-offset of this segment. + * + * @param remoteLogSegmentId Universally unique remote log segment id. + * @param startOffset Start offset of this segment (inclusive). + * @param endOffset End offset of this segment (inclusive). + * @param maxTimestampMs Maximum timestamp in this segment + * @param brokerId Broker id from which this event is generated. + * @param eventTimestampMs Epoch time in milli seconds at which the remote log segment is copied to the remote tier storage. + * @param segmentSizeInBytes Size of this segment in bytes. + * @param segmentLeaderEpochs leader epochs occurred within this segment + */ + public RemoteLogSegmentMetadata(RemoteLogSegmentId remoteLogSegmentId, + long startOffset, + long endOffset, + long maxTimestampMs, + int brokerId, + long eventTimestampMs, + int segmentSizeInBytes, + Map segmentLeaderEpochs) { + this(remoteLogSegmentId, + startOffset, + endOffset, + maxTimestampMs, + brokerId, + eventTimestampMs, segmentSizeInBytes, + RemoteLogSegmentState.COPY_SEGMENT_STARTED, + segmentLeaderEpochs); + } + + + /** + * @return unique id of this segment. + */ + public RemoteLogSegmentId remoteLogSegmentId() { + return remoteLogSegmentId; + } + + /** + * @return Start offset of this segment (inclusive). + */ + public long startOffset() { + return startOffset; + } + + /** + * @return End offset of this segment (inclusive). + */ + public long endOffset() { + return endOffset; + } + + /** + * @return Total size of this segment in bytes. + */ + public int segmentSizeInBytes() { + return segmentSizeInBytes; + } + + /** + * @return Maximum timestamp in milli seconds of a record within this segment. + */ + public long maxTimestampMs() { + return maxTimestampMs; + } + + /** + * @return Map of leader epoch vs offset for the records available in this segment. + */ + public NavigableMap segmentLeaderEpochs() { + return segmentLeaderEpochs; + } + + /** + * Returns the current state of this remote log segment. It can be any of the below + *

                  + * {@link RemoteLogSegmentState#COPY_SEGMENT_STARTED} + * {@link RemoteLogSegmentState#COPY_SEGMENT_FINISHED} + * {@link RemoteLogSegmentState#DELETE_SEGMENT_STARTED} + * {@link RemoteLogSegmentState#DELETE_SEGMENT_FINISHED} + *
                + */ + public RemoteLogSegmentState state() { + return state; + } + + /** + * Creates a new RemoteLogSegmentMetadata applying the given {@code rlsmUpdate} on this instance. This method will + * not update this instance. + * + * @param rlsmUpdate update to be applied. + * @return a new instance created by applying the given update on this instance. + */ + public RemoteLogSegmentMetadata createWithUpdates(RemoteLogSegmentMetadataUpdate rlsmUpdate) { + if (!remoteLogSegmentId.equals(rlsmUpdate.remoteLogSegmentId())) { + throw new IllegalArgumentException("Given rlsmUpdate does not have this instance's remoteLogSegmentId."); + } + + return new RemoteLogSegmentMetadata(remoteLogSegmentId, startOffset, + endOffset, maxTimestampMs, rlsmUpdate.brokerId(), rlsmUpdate.eventTimestampMs(), + segmentSizeInBytes, rlsmUpdate.state(), segmentLeaderEpochs); + } + + @Override + public TopicIdPartition topicIdPartition() { + return remoteLogSegmentId.topicIdPartition(); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + RemoteLogSegmentMetadata that = (RemoteLogSegmentMetadata) o; + return startOffset == that.startOffset && endOffset == that.endOffset + && maxTimestampMs == that.maxTimestampMs + && segmentSizeInBytes == that.segmentSizeInBytes + && Objects.equals(remoteLogSegmentId, that.remoteLogSegmentId) + && Objects.equals(segmentLeaderEpochs, that.segmentLeaderEpochs) && state == that.state + && eventTimestampMs() == that.eventTimestampMs() + && brokerId() == that.brokerId(); + } + + @Override + public int hashCode() { + return Objects.hash(remoteLogSegmentId, startOffset, endOffset, brokerId(), maxTimestampMs, + eventTimestampMs(), segmentLeaderEpochs, segmentSizeInBytes, state); + } + + @Override + public String toString() { + return "RemoteLogSegmentMetadata{" + + "remoteLogSegmentId=" + remoteLogSegmentId + + ", startOffset=" + startOffset + + ", endOffset=" + endOffset + + ", brokerId=" + brokerId() + + ", maxTimestampMs=" + maxTimestampMs + + ", eventTimestampMs=" + eventTimestampMs() + + ", segmentLeaderEpochs=" + segmentLeaderEpochs + + ", segmentSizeInBytes=" + segmentSizeInBytes + + ", state=" + state + + '}'; + } + +} diff --git a/storage/api/src/main/java/org/apache/kafka/server/log/remote/storage/RemoteLogSegmentMetadataUpdate.java b/storage/api/src/main/java/org/apache/kafka/server/log/remote/storage/RemoteLogSegmentMetadataUpdate.java new file mode 100644 index 0000000..a01df96 --- /dev/null +++ b/storage/api/src/main/java/org/apache/kafka/server/log/remote/storage/RemoteLogSegmentMetadataUpdate.java @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.server.log.remote.storage; + +import org.apache.kafka.common.TopicIdPartition; +import org.apache.kafka.common.annotation.InterfaceStability; + +import java.util.Objects; + +/** + * It describes the metadata update about the log segment in the remote storage. This is currently used to update the + * state of the remote log segment by using {@link RemoteLogMetadataManager#updateRemoteLogSegmentMetadata(RemoteLogSegmentMetadataUpdate)}. + * This also includes the timestamp of this event. + */ +@InterfaceStability.Evolving +public class RemoteLogSegmentMetadataUpdate extends RemoteLogMetadata { + + /** + * Universally unique remote log segment id. + */ + private final RemoteLogSegmentId remoteLogSegmentId; + + /** + * It indicates the state in which the action is executed on this segment. + */ + private final RemoteLogSegmentState state; + + /** + * @param remoteLogSegmentId Universally unique remote log segment id. + * @param eventTimestampMs Epoch time in milli seconds at which the remote log segment is copied to the remote tier storage. + * @param state State of the remote log segment. + * @param brokerId Broker id from which this event is generated. + */ + public RemoteLogSegmentMetadataUpdate(RemoteLogSegmentId remoteLogSegmentId, long eventTimestampMs, + RemoteLogSegmentState state, int brokerId) { + super(brokerId, eventTimestampMs); + this.remoteLogSegmentId = Objects.requireNonNull(remoteLogSegmentId, "remoteLogSegmentId can not be null"); + this.state = Objects.requireNonNull(state, "state can not be null"); + } + + /** + * @return Universally unique id of this remote log segment. + */ + public RemoteLogSegmentId remoteLogSegmentId() { + return remoteLogSegmentId; + } + + /** + * It represents the state of the remote log segment. It can be one of the values of {@link RemoteLogSegmentState}. + */ + public RemoteLogSegmentState state() { + return state; + } + + @Override + public TopicIdPartition topicIdPartition() { + return remoteLogSegmentId.topicIdPartition(); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + RemoteLogSegmentMetadataUpdate that = (RemoteLogSegmentMetadataUpdate) o; + return Objects.equals(remoteLogSegmentId, that.remoteLogSegmentId) && + state == that.state && + eventTimestampMs() == that.eventTimestampMs() && + brokerId() == that.brokerId(); + } + + @Override + public int hashCode() { + return Objects.hash(remoteLogSegmentId, state, eventTimestampMs(), brokerId()); + } + + @Override + public String toString() { + return "RemoteLogSegmentMetadataUpdate{" + + "remoteLogSegmentId=" + remoteLogSegmentId + + ", state=" + state + + ", eventTimestampMs=" + eventTimestampMs() + + ", brokerId=" + brokerId() + + '}'; + } +} diff --git a/storage/api/src/main/java/org/apache/kafka/server/log/remote/storage/RemoteLogSegmentState.java b/storage/api/src/main/java/org/apache/kafka/server/log/remote/storage/RemoteLogSegmentState.java new file mode 100644 index 0000000..c618321 --- /dev/null +++ b/storage/api/src/main/java/org/apache/kafka/server/log/remote/storage/RemoteLogSegmentState.java @@ -0,0 +1,115 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
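Building on the previous sketch: a hedged example (illustrative broker id and timestamp) of representing the COPY_SEGMENT_FINISHED transition as a RemoteLogSegmentMetadataUpdate and applying it with createWithUpdates.

import org.apache.kafka.server.log.remote.storage.RemoteLogSegmentMetadata;
import org.apache.kafka.server.log.remote.storage.RemoteLogSegmentMetadataUpdate;
import org.apache.kafka.server.log.remote.storage.RemoteLogSegmentState;

public class SegmentCopyFinishedExample {
    // Builds the metadata that describes the same segment after its copy has finished.
    static RemoteLogSegmentMetadata markCopyFinished(RemoteLogSegmentMetadata current, int brokerId) {
        RemoteLogSegmentMetadataUpdate update = new RemoteLogSegmentMetadataUpdate(
                current.remoteLogSegmentId(),
                System.currentTimeMillis(),                 // event timestamp of the update
                RemoteLogSegmentState.COPY_SEGMENT_FINISHED,
                brokerId);

        // createWithUpdates does not mutate 'current'; it returns a new instance with the updated state,
        // broker id and event timestamp while keeping the offsets, size and leader epochs.
        return current.createWithUpdates(update);
    }
}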
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.server.log.remote.storage; + +import org.apache.kafka.common.annotation.InterfaceStability; + +import java.util.Arrays; +import java.util.Collections; +import java.util.Map; +import java.util.Objects; +import java.util.function.Function; +import java.util.stream.Collectors; + +/** + * This enum indicates the state of the remote log segment. This will be based on the action executed on this + * segment by the remote log service implementation. + *

                + * It goes through the state transitions shown below. A self-transition is treated as valid, which allows updating with the + * same state in case of retries and failover. + *

                + *

                + * +---------------------+            +----------------------+
                + * |COPY_SEGMENT_STARTED |----------> |COPY_SEGMENT_FINISHED |
                + * +-------------------+-+            +--+-------------------+
                + *                     |                 |
                + *                     |                 |
                + *                     v                 v
                + *                  +--+-----------------+-+
                + *                  |DELETE_SEGMENT_STARTED|
                + *                  +-----------+----------+
                + *                              |
                + *                              |
                + *                              v
                + *                  +-----------+-----------+
                + *                  |DELETE_SEGMENT_FINISHED|
                + *                  +-----------------------+
                + * 
                + */ +@InterfaceStability.Evolving +public enum RemoteLogSegmentState { + + /** + * This state indicates that the segment copying to remote storage is started but not yet finished. + */ + COPY_SEGMENT_STARTED((byte) 0), + + /** + * This state indicates that the segment copying to remote storage is finished. + */ + COPY_SEGMENT_FINISHED((byte) 1), + + /** + * This state indicates that the segment deletion is started but not yet finished. + */ + DELETE_SEGMENT_STARTED((byte) 2), + + /** + * This state indicates that the segment is deleted successfully. + */ + DELETE_SEGMENT_FINISHED((byte) 3); + + private static final Map STATE_TYPES = Collections.unmodifiableMap( + Arrays.stream(values()).collect(Collectors.toMap(RemoteLogSegmentState::id, Function.identity()))); + + private final byte id; + + RemoteLogSegmentState(byte id) { + this.id = id; + } + + public byte id() { + return id; + } + + public static RemoteLogSegmentState forId(byte id) { + return STATE_TYPES.get(id); + } + + public static boolean isValidTransition(RemoteLogSegmentState srcState, RemoteLogSegmentState targetState) { + Objects.requireNonNull(targetState, "targetState can not be null"); + + if (srcState == null) { + // If the source state is null, check the target state as the initial state viz COPY_SEGMENT_STARTED + // This ensures simplicity here as we don't have to define one more type to represent the state 'null' like + // COPY_SEGMENT_NOT_STARTED, have the null check by the caller and pass that state. + return targetState == COPY_SEGMENT_STARTED; + } else if (srcState == targetState) { + // Self transition is treated as valid. This is to maintain the idempotency for the state in case of retries + // or failover. + return true; + } else if (srcState == COPY_SEGMENT_STARTED) { + return targetState == COPY_SEGMENT_FINISHED || targetState == DELETE_SEGMENT_STARTED; + } else if (srcState == COPY_SEGMENT_FINISHED) { + return targetState == DELETE_SEGMENT_STARTED; + } else if (srcState == DELETE_SEGMENT_STARTED) { + return targetState == DELETE_SEGMENT_FINISHED; + } else { + return false; + } + } +} diff --git a/storage/api/src/main/java/org/apache/kafka/server/log/remote/storage/RemotePartitionDeleteMetadata.java b/storage/api/src/main/java/org/apache/kafka/server/log/remote/storage/RemotePartitionDeleteMetadata.java new file mode 100644 index 0000000..c84e1d7 --- /dev/null +++ b/storage/api/src/main/java/org/apache/kafka/server/log/remote/storage/RemotePartitionDeleteMetadata.java @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
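A small sketch exercising the transition rules from the diagram above via isValidTransition, plus a byte-id round trip with forId; it relies only on the enum defined here and the printed values are what the checks return.

import org.apache.kafka.server.log.remote.storage.RemoteLogSegmentState;

public class SegmentStateTransitionExample {
    public static void main(String[] args) {
        // A null source state means "no state yet", so only COPY_SEGMENT_STARTED is a valid target.
        System.out.println(RemoteLogSegmentState.isValidTransition(null,
                RemoteLogSegmentState.COPY_SEGMENT_STARTED));   // true

        // Normal copy lifecycle.
        System.out.println(RemoteLogSegmentState.isValidTransition(
                RemoteLogSegmentState.COPY_SEGMENT_STARTED,
                RemoteLogSegmentState.COPY_SEGMENT_FINISHED));  // true

        // A finished copy cannot jump straight to DELETE_SEGMENT_FINISHED.
        System.out.println(RemoteLogSegmentState.isValidTransition(
                RemoteLogSegmentState.COPY_SEGMENT_FINISHED,
                RemoteLogSegmentState.DELETE_SEGMENT_FINISHED)); // false

        // Round-trip through the serialized byte id.
        byte id = RemoteLogSegmentState.DELETE_SEGMENT_STARTED.id();
        System.out.println(RemoteLogSegmentState.forId(id));    // DELETE_SEGMENT_STARTED
    }
}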
+ */ +package org.apache.kafka.server.log.remote.storage; + +import org.apache.kafka.common.TopicIdPartition; +import org.apache.kafka.common.annotation.InterfaceStability; + +import java.util.Objects; + +/** + * This class represents the metadata about the remote partition. It can be created/updated with {@link RemoteLogMetadataManager#putRemotePartitionDeleteMetadata(RemotePartitionDeleteMetadata)}. + * Possible state transitions are mentioned at {@link RemotePartitionDeleteState}. + */ +@InterfaceStability.Evolving +public class RemotePartitionDeleteMetadata extends RemoteLogMetadata { + + private final TopicIdPartition topicIdPartition; + private final RemotePartitionDeleteState state; + + /** + * Creates an instance of this class with the given metadata. + * + * @param topicIdPartition topic partition for which this event is meant for. + * @param state State of the remote topic partition. + * @param eventTimestampMs Epoch time in milli seconds at which this event is occurred. + * @param brokerId Id of the broker in which this event is raised. + */ + public RemotePartitionDeleteMetadata(TopicIdPartition topicIdPartition, + RemotePartitionDeleteState state, + long eventTimestampMs, + int brokerId) { + super(brokerId, eventTimestampMs); + this.topicIdPartition = Objects.requireNonNull(topicIdPartition); + this.state = Objects.requireNonNull(state); + } + + /** + * @return TopicIdPartition for which this event is meant for. + */ + public TopicIdPartition topicIdPartition() { + return topicIdPartition; + } + + /** + * It represents the state of the remote partition. It can be one of the values of {@link RemotePartitionDeleteState}. + */ + public RemotePartitionDeleteState state() { + return state; + } + + @Override + public String toString() { + return "RemotePartitionDeleteMetadata{" + + "topicPartition=" + topicIdPartition + + ", state=" + state + + ", eventTimestampMs=" + eventTimestampMs() + + ", brokerId=" + brokerId() + + '}'; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + RemotePartitionDeleteMetadata that = (RemotePartitionDeleteMetadata) o; + return Objects.equals(topicIdPartition, that.topicIdPartition) && + state == that.state && + eventTimestampMs() == that.eventTimestampMs() && + brokerId() == that.brokerId(); + } + + @Override + public int hashCode() { + return Objects.hash(topicIdPartition, state, eventTimestampMs(), brokerId()); + } +} \ No newline at end of file diff --git a/storage/api/src/main/java/org/apache/kafka/server/log/remote/storage/RemotePartitionDeleteState.java b/storage/api/src/main/java/org/apache/kafka/server/log/remote/storage/RemotePartitionDeleteState.java new file mode 100644 index 0000000..e2fad1a --- /dev/null +++ b/storage/api/src/main/java/org/apache/kafka/server/log/remote/storage/RemotePartitionDeleteState.java @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.server.log.remote.storage; + +import org.apache.kafka.common.annotation.InterfaceStability; + +import java.util.Arrays; +import java.util.Collections; +import java.util.Map; +import java.util.Objects; +import java.util.function.Function; +import java.util.stream.Collectors; + +/** + * This enum indicates the deletion state of the remote topic partition. This will be based on the action executed on this + * partition by the remote log service implementation. + * State transitions are mentioned below. Self transition is treated as valid. This allows updating with the + * same state in case of retries and failover. + *

                + *

                + * +-------------------------+
                + * |DELETE_PARTITION_MARKED  |
                + * +-----------+-------------+
                + *             |
                + *             |
                + * +-----------v--------------+
                + * |DELETE_PARTITION_STARTED  |
                + * +-----------+--------------+
                + *             |
                + *             |
                + * +-----------v--------------+
                + * |DELETE_PARTITION_FINISHED |
                + * +--------------------------+
                + * 
                + *

                + */ +@InterfaceStability.Evolving +public enum RemotePartitionDeleteState { + + /** + * This is used when a topic/partition is marked for delete by the controller. + * That means, all its remote log segments are eligible for deletion so that remote partition removers can + * start deleting them. + */ + DELETE_PARTITION_MARKED((byte) 0), + + /** + * This state indicates that the partition deletion is started but not yet finished. + */ + DELETE_PARTITION_STARTED((byte) 1), + + /** + * This state indicates that the partition is deleted successfully. + */ + DELETE_PARTITION_FINISHED((byte) 2); + + private static final Map STATE_TYPES = Collections.unmodifiableMap( + Arrays.stream(values()).collect(Collectors.toMap(RemotePartitionDeleteState::id, Function.identity()))); + + private final byte id; + + RemotePartitionDeleteState(byte id) { + this.id = id; + } + + public byte id() { + return id; + } + + public static RemotePartitionDeleteState forId(byte id) { + return STATE_TYPES.get(id); + } + + public static boolean isValidTransition(RemotePartitionDeleteState srcState, + RemotePartitionDeleteState targetState) { + Objects.requireNonNull(targetState, "targetState can not be null"); + + if (srcState == null) { + // If the source state is null, check the target state as the initial state viz DELETE_PARTITION_MARKED. + // This ensures simplicity here as we don't have to define one more type to represent the state 'null' like + // DELETE_PARTITION_NOT_MARKED, have the null check by the caller and pass that state. + return targetState == DELETE_PARTITION_MARKED; + } else if (srcState == targetState) { + // Self transition is treated as valid. This is to maintain the idempotency for the state in case of retries + // or failover. + return true; + } else if (srcState == DELETE_PARTITION_MARKED) { + return targetState == DELETE_PARTITION_STARTED; + } else if (srcState == DELETE_PARTITION_STARTED) { + return targetState == DELETE_PARTITION_FINISHED; + } else { + return false; + } + } +} \ No newline at end of file diff --git a/storage/api/src/main/java/org/apache/kafka/server/log/remote/storage/RemoteResourceNotFoundException.java b/storage/api/src/main/java/org/apache/kafka/server/log/remote/storage/RemoteResourceNotFoundException.java new file mode 100644 index 0000000..f6ac4ec --- /dev/null +++ b/storage/api/src/main/java/org/apache/kafka/server/log/remote/storage/RemoteResourceNotFoundException.java @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.server.log.remote.storage; + +/** + * Exception thrown when a resource is not found on the remote storage. + *
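In the same spirit, a hedged sketch of the partition delete lifecycle: it builds illustrative RemotePartitionDeleteMetadata events (made-up topic, ids and broker id) and checks the allowed transitions with isValidTransition.

import org.apache.kafka.common.TopicIdPartition;
import org.apache.kafka.common.TopicPartition;
import org.apache.kafka.common.Uuid;
import org.apache.kafka.server.log.remote.storage.RemotePartitionDeleteMetadata;
import org.apache.kafka.server.log.remote.storage.RemotePartitionDeleteState;

public class PartitionDeleteLifecycleExample {
    public static void main(String[] args) {
        TopicIdPartition tp = new TopicIdPartition(Uuid.randomUuid(), new TopicPartition("orders", 0));
        int brokerId = 0;

        // One metadata event per step of the delete lifecycle.
        RemotePartitionDeleteMetadata marked = new RemotePartitionDeleteMetadata(
                tp, RemotePartitionDeleteState.DELETE_PARTITION_MARKED, System.currentTimeMillis(), brokerId);
        RemotePartitionDeleteMetadata started = new RemotePartitionDeleteMetadata(
                tp, RemotePartitionDeleteState.DELETE_PARTITION_STARTED, System.currentTimeMillis(), brokerId);

        // MARKED -> STARTED is allowed; skipping straight from MARKED to FINISHED is not.
        System.out.println(RemotePartitionDeleteState.isValidTransition(
                marked.state(), started.state()));                           // true
        System.out.println(RemotePartitionDeleteState.isValidTransition(
                RemotePartitionDeleteState.DELETE_PARTITION_MARKED,
                RemotePartitionDeleteState.DELETE_PARTITION_FINISHED));      // false
    }
}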

                + * A resource can be a log segment, any of the indexes or any which was stored in remote storage for a particular log + * segment. + */ +public class RemoteResourceNotFoundException extends RemoteStorageException { + private static final long serialVersionUID = 1L; + + public RemoteResourceNotFoundException(final String message) { + super(message); + } + + public RemoteResourceNotFoundException(final Throwable cause) { + super("Requested remote resource was not found", cause); + } + + public RemoteResourceNotFoundException(String message, Throwable cause) { + super(message, cause); + } +} diff --git a/storage/api/src/main/java/org/apache/kafka/server/log/remote/storage/RemoteStorageException.java b/storage/api/src/main/java/org/apache/kafka/server/log/remote/storage/RemoteStorageException.java new file mode 100644 index 0000000..bd392fc --- /dev/null +++ b/storage/api/src/main/java/org/apache/kafka/server/log/remote/storage/RemoteStorageException.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.server.log.remote.storage; + +/** + * Exception thrown when there is a remote storage error. This can be used as the base exception by implementors of + * {@link RemoteStorageManager} or {@link RemoteLogMetadataManager} to create extended exceptions. + */ +public class RemoteStorageException extends Exception { + private static final long serialVersionUID = 1L; + + public RemoteStorageException(final String message) { + super(message); + } + + public RemoteStorageException(final String message, final Throwable cause) { + super(message, cause); + } + + public RemoteStorageException(Throwable cause) { + super(cause); + } +} diff --git a/storage/api/src/main/java/org/apache/kafka/server/log/remote/storage/RemoteStorageManager.java b/storage/api/src/main/java/org/apache/kafka/server/log/remote/storage/RemoteStorageManager.java new file mode 100644 index 0000000..6231d8e --- /dev/null +++ b/storage/api/src/main/java/org/apache/kafka/server/log/remote/storage/RemoteStorageManager.java @@ -0,0 +1,140 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.server.log.remote.storage; + +import org.apache.kafka.common.Configurable; +import org.apache.kafka.common.annotation.InterfaceStability; + +import java.io.Closeable; +import java.io.InputStream; + +/** + * This interface provides the lifecycle of remote log segments that includes copy, fetch, and delete from remote + * storage. + *

                + * Each upload or copy of a segment is initiated with a {@link RemoteLogSegmentMetadata} containing a {@link RemoteLogSegmentId}, + * which is universally unique even for the same topic partition and offsets. + *

                + * {@link RemoteLogSegmentMetadata} is stored in {@link RemoteLogMetadataManager} before and after copy/delete operations on + * {@link RemoteStorageManager} with the respective {@link RemoteLogSegmentState}. {@link RemoteLogMetadataManager} is + * responsible for storing and fetching metadata about the remote log segments in a strongly consistent manner. + * This allows {@link RemoteStorageManager} to have eventual consistency on metadata (although the data is stored + * in strongly consistent semantics). + */ +@InterfaceStability.Evolving +public interface RemoteStorageManager extends Configurable, Closeable { + + /** + * Type of the index file. + */ + enum IndexType { + /** + * Represents offset index. + */ + OFFSET, + + /** + * Represents timestamp index. + */ + TIMESTAMP, + + /** + * Represents producer snapshot index. + */ + PRODUCER_SNAPSHOT, + + /** + * Represents transaction index. + */ + TRANSACTION, + + /** + * Represents leader epoch index. + */ + LEADER_EPOCH, + } + + /** + * Copies the given {@link LogSegmentData} provided for the given {@code remoteLogSegmentMetadata}. This includes + * log segment and its auxiliary indexes like offset index, time index, transaction index, leader epoch index, and + * producer snapshot index. + *
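A hedged sketch of a copy call, assuming an already-configured RemoteStorageManager instance (here named rsm) and local files that exist; the LogSegmentData argument order follows LogSegmentDataTest further down in this patch, and the trailing ByteBuffer is assumed to hold the leader epoch checkpoint bytes.

import org.apache.kafka.server.log.remote.storage.LogSegmentData;
import org.apache.kafka.server.log.remote.storage.RemoteLogSegmentMetadata;
import org.apache.kafka.server.log.remote.storage.RemoteStorageException;
import org.apache.kafka.server.log.remote.storage.RemoteStorageManager;

import java.nio.ByteBuffer;
import java.nio.file.Path;
import java.util.Optional;

public class CopySegmentExample {
    // rsm and metadata are assumed to be created elsewhere: a plugin instance and a
    // COPY_SEGMENT_STARTED metadata with a fresh RemoteLogSegmentId, as described above.
    static void copy(RemoteStorageManager rsm, RemoteLogSegmentMetadata metadata, Path logDir)
            throws RemoteStorageException {
        LogSegmentData segmentData = new LogSegmentData(
                logDir.resolve("00000000000000001000.log"),        // log segment
                logDir.resolve("00000000000000001000.index"),      // offset index
                logDir.resolve("00000000000000001000.timeindex"),  // time index
                Optional.empty(),                                  // transaction index is optional
                logDir.resolve("00000000000000001000.snapshot"),   // producer snapshot
                ByteBuffer.wrap(new byte[0]));                     // leader epoch checkpoint bytes (assumed)

        rsm.copyLogSegmentData(metadata, segmentData);
    }
}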

                + * Invoker of this API should always send a unique id as part of {@link RemoteLogSegmentMetadata#remoteLogSegmentId()} + * even when it retries to invoke this method for the same log segment data. + * + * @param remoteLogSegmentMetadata metadata about the remote log segment. + * @param logSegmentData data to be copied to tiered storage. + * @throws RemoteStorageException if there are any errors in storing the data of the segment. + */ + void copyLogSegmentData(RemoteLogSegmentMetadata remoteLogSegmentMetadata, + LogSegmentData logSegmentData) + throws RemoteStorageException; + + /** + * Returns the remote log segment data file/object as InputStream for the given {@link RemoteLogSegmentMetadata} + * starting from the given startPosition. The stream will end at the end of the remote log segment data file/object. + * + * @param remoteLogSegmentMetadata metadata about the remote log segment. + * @param startPosition start position of log segment to be read, inclusive. + * @return input stream of the requested log segment data. + * @throws RemoteStorageException if there are any errors while fetching the desired segment. + * @throws RemoteResourceNotFoundException when there are no resources associated with the given remoteLogSegmentMetadata. + */ + InputStream fetchLogSegment(RemoteLogSegmentMetadata remoteLogSegmentMetadata, + int startPosition) throws RemoteStorageException; + + /** + * Returns the remote log segment data file/object as InputStream for the given {@link RemoteLogSegmentMetadata} + * starting from the given startPosition. The stream will end at the smaller of endPosition and the end of the + * remote log segment data file/object. + * + * @param remoteLogSegmentMetadata metadata about the remote log segment. + * @param startPosition start position of log segment to be read, inclusive. + * @param endPosition end position of log segment to be read, inclusive. + * @return input stream of the requested log segment data. + * @throws RemoteStorageException if there are any errors while fetching the desired segment. + * @throws RemoteResourceNotFoundException when there are no resources associated with the given remoteLogSegmentMetadata. + */ + InputStream fetchLogSegment(RemoteLogSegmentMetadata remoteLogSegmentMetadata, + int startPosition, + int endPosition) throws RemoteStorageException; + + /** + * Returns the index for the respective log segment of {@link RemoteLogSegmentMetadata}. + * + * @param remoteLogSegmentMetadata metadata about the remote log segment. + * @param indexType type of the index to be fetched for the segment. + * @return input stream of the requested index. + * @throws RemoteStorageException if there are any errors while fetching the index. + * @throws RemoteResourceNotFoundException when there are no resources associated with the given remoteLogSegmentMetadata. + */ + InputStream fetchIndex(RemoteLogSegmentMetadata remoteLogSegmentMetadata, + IndexType indexType) throws RemoteStorageException; + + /** + * Deletes the resources associated with the given {@code remoteLogSegmentMetadata}. Deletion is considered as + * successful if this call returns successfully without any errors. It will throw {@link RemoteStorageException} if + * there are any errors in deleting the file. + *
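A matching read-side sketch for the fetch methods of this interface: it pulls a byte range of a copied segment and one of its indexes. The positions are illustrative, and the caller is responsible for closing the returned streams.

import org.apache.kafka.server.log.remote.storage.RemoteLogSegmentMetadata;
import org.apache.kafka.server.log.remote.storage.RemoteStorageException;
import org.apache.kafka.server.log.remote.storage.RemoteStorageManager;

import java.io.IOException;
import java.io.InputStream;

public class FetchSegmentExample {
    static void fetch(RemoteStorageManager rsm, RemoteLogSegmentMetadata metadata)
            throws RemoteStorageException, IOException {
        // Bytes 0..4095 (both inclusive) of the remote segment file/object.
        try (InputStream segment = rsm.fetchLogSegment(metadata, 0, 4095)) {
            // read and serve the data ...
        }

        // The offset index that was uploaded along with the segment.
        try (InputStream offsetIndex = rsm.fetchIndex(metadata, RemoteStorageManager.IndexType.OFFSET)) {
            // rebuild or cache the index locally ...
        }
    }
}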

                + * + * @param remoteLogSegmentMetadata metadata about the remote log segment to be deleted. + * @throws RemoteResourceNotFoundException if the requested resource is not found + * @throws RemoteStorageException if there are any storage related errors occurred. + * @throws RemoteResourceNotFoundException when there are no resources associated with the given remoteLogSegmentMetadata. + */ + void deleteLogSegmentData(RemoteLogSegmentMetadata remoteLogSegmentMetadata) throws RemoteStorageException; +} \ No newline at end of file diff --git a/storage/api/src/test/java/org/apache/kafka/server/log/remote/storage/LogSegmentDataTest.java b/storage/api/src/test/java/org/apache/kafka/server/log/remote/storage/LogSegmentDataTest.java new file mode 100644 index 0000000..e0a2022 --- /dev/null +++ b/storage/api/src/test/java/org/apache/kafka/server/log/remote/storage/LogSegmentDataTest.java @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.server.log.remote.storage; + +import org.apache.kafka.test.TestUtils; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; + +import java.io.File; +import java.nio.ByteBuffer; +import java.util.Optional; + +import static org.junit.jupiter.api.Assertions.assertFalse; + +public class LogSegmentDataTest { + + @Test + public void testOptionalTransactionIndex() { + File dir = TestUtils.tempDirectory(); + LogSegmentData logSegmentDataWithTransactionIndex = new LogSegmentData( + new File(dir, "log-segment").toPath(), + new File(dir, "offset-index").toPath(), + new File(dir, "time-index").toPath(), + Optional.of(new File(dir, "transaction-index").toPath()), + new File(dir, "producer-snapshot").toPath(), + ByteBuffer.allocate(1) + ); + Assertions.assertTrue(logSegmentDataWithTransactionIndex.transactionIndex().isPresent()); + + LogSegmentData logSegmentDataWithNoTransactionIndex = new LogSegmentData( + new File(dir, "log-segment").toPath(), + new File(dir, "offset-index").toPath(), + new File(dir, "time-index").toPath(), + Optional.empty(), + new File(dir, "producer-snapshot").toPath(), + ByteBuffer.allocate(1) + ); + assertFalse(logSegmentDataWithNoTransactionIndex.transactionIndex().isPresent()); + } +} \ No newline at end of file diff --git a/storage/src/main/java/org/apache/kafka/server/log/remote/metadata/storage/CommittedOffsetsFile.java b/storage/src/main/java/org/apache/kafka/server/log/remote/metadata/storage/CommittedOffsetsFile.java new file mode 100644 index 0000000..1eddc0b --- /dev/null +++ b/storage/src/main/java/org/apache/kafka/server/log/remote/metadata/storage/CommittedOffsetsFile.java @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.server.log.remote.metadata.storage; + +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.server.common.CheckpointFile; + +import java.io.File; +import java.io.IOException; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.regex.Pattern; + +/** + * This class represents a file containing the committed offsets of remote log metadata partitions. + */ +public class CommittedOffsetsFile { + private static final int CURRENT_VERSION = 0; + private static final String SEPARATOR = " "; + + private static final Pattern MINIMUM_ONE_WHITESPACE = Pattern.compile("\\s+"); + private final CheckpointFile> checkpointFile; + + CommittedOffsetsFile(File offsetsFile) throws IOException { + CheckpointFile.EntryFormatter> formatter = new EntryFormatter(); + checkpointFile = new CheckpointFile<>(offsetsFile, CURRENT_VERSION, formatter); + } + + private static class EntryFormatter implements CheckpointFile.EntryFormatter> { + + @Override + public String toString(Map.Entry entry) { + // Each entry is stored in a new line as + return entry.getKey() + SEPARATOR + entry.getValue(); + } + + @Override + public Optional> fromString(String line) { + String[] strings = MINIMUM_ONE_WHITESPACE.split(line); + if (strings.length != 2) { + return Optional.empty(); + } + + try { + return Optional.of(Utils.mkEntry(Integer.parseInt(strings[0]), Long.parseLong(strings[1]))); + } catch (NumberFormatException e) { + return Optional.empty(); + } + + } + } + + public synchronized void writeEntries(Map committedOffsets) throws IOException { + checkpointFile.write(committedOffsets.entrySet()); + } + + public synchronized Map readEntries() throws IOException { + List> entries = checkpointFile.read(); + Map partitionToOffsets = new HashMap<>(entries.size()); + for (Map.Entry entry : entries) { + Long existingValue = partitionToOffsets.put(entry.getKey(), entry.getValue()); + if (existingValue != null) { + throw new IOException("Multiple entries exist for key: " + entry.getKey()); + } + } + + return partitionToOffsets; + } +} \ No newline at end of file diff --git a/storage/src/main/java/org/apache/kafka/server/log/remote/metadata/storage/ConsumerManager.java b/storage/src/main/java/org/apache/kafka/server/log/remote/metadata/storage/ConsumerManager.java new file mode 100644 index 0000000..77f83fb --- /dev/null +++ b/storage/src/main/java/org/apache/kafka/server/log/remote/metadata/storage/ConsumerManager.java @@ -0,0 +1,151 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
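A short sketch of how this checkpoint file round-trips metadata-partition offsets. The path and offsets are made up, and the example sits in the same package because the constructor is package-private.

package org.apache.kafka.server.log.remote.metadata.storage;

import java.io.File;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;

public class CommittedOffsetsFileExample {
    public static void main(String[] args) throws IOException {
        CommittedOffsetsFile offsetsFile = new CommittedOffsetsFile(new File("/tmp/_rlmm_committed_offsets"));

        // Metadata topic partition -> last consumed offset, stored as "<partition> <offset>" per line.
        Map<Integer, Long> offsets = new HashMap<>();
        offsets.put(0, 125L);
        offsets.put(7, 4302L);
        offsetsFile.writeEntries(offsets);

        // Reading back returns the same mapping; duplicate keys in the file raise an IOException.
        Map<Integer, Long> readBack = offsetsFile.readEntries();
        System.out.println(readBack);
    }
}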
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.server.log.remote.metadata.storage; + +import org.apache.kafka.clients.consumer.KafkaConsumer; +import org.apache.kafka.clients.producer.RecordMetadata; +import org.apache.kafka.common.KafkaException; +import org.apache.kafka.common.TopicIdPartition; +import org.apache.kafka.common.utils.KafkaThread; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.common.utils.Utils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.Closeable; +import java.io.File; +import java.io.IOException; +import java.nio.file.Path; +import java.util.Optional; +import java.util.Set; +import java.util.concurrent.TimeoutException; + +/** + * This class manages the consumer thread viz {@link ConsumerTask} that polls messages from the assigned metadata topic partitions. + * It also provides a way to wait until the given record is received by the consumer before it is timed out with an interval of + * {@link TopicBasedRemoteLogMetadataManagerConfig#consumeWaitMs()}. + */ +public class ConsumerManager implements Closeable { + + public static final String COMMITTED_OFFSETS_FILE_NAME = "_rlmm_committed_offsets"; + + private static final Logger log = LoggerFactory.getLogger(ConsumerManager.class); + private static final long CONSUME_RECHECK_INTERVAL_MS = 50L; + + private final TopicBasedRemoteLogMetadataManagerConfig rlmmConfig; + private final Time time; + private final ConsumerTask consumerTask; + private final Thread consumerTaskThread; + + public ConsumerManager(TopicBasedRemoteLogMetadataManagerConfig rlmmConfig, + RemotePartitionMetadataEventHandler remotePartitionMetadataEventHandler, + RemoteLogMetadataTopicPartitioner topicPartitioner, + Time time) { + this.rlmmConfig = rlmmConfig; + this.time = time; + + //Create a task to consume messages and submit the respective events to RemotePartitionMetadataEventHandler. + KafkaConsumer consumer = new KafkaConsumer<>(rlmmConfig.consumerProperties()); + Path committedOffsetsPath = new File(rlmmConfig.logDir(), COMMITTED_OFFSETS_FILE_NAME).toPath(); + consumerTask = new ConsumerTask(consumer, remotePartitionMetadataEventHandler, topicPartitioner, committedOffsetsPath, time, 60_000L); + consumerTaskThread = KafkaThread.nonDaemon("RLMMConsumerTask", consumerTask); + } + + public void startConsumerThread() { + try { + // Start a thread to continuously consume records from topic partitions. + consumerTaskThread.start(); + log.info("RLMM Consumer task thread is started"); + } catch (Exception e) { + throw new KafkaException("Error encountered while initializing and scheduling ConsumerTask thread", e); + } + } + + /** + * Waits if necessary for the consumption to reach the offset of the given {@code recordMetadata}. + * + * @param recordMetadata record metadata to be checked for consumption. 
+ * @throws TimeoutException if this method execution did not complete with in the wait time configured with + * property {@code TopicBasedRemoteLogMetadataManagerConfig#REMOTE_LOG_METADATA_CONSUME_WAIT_MS_PROP}. + */ + public void waitTillConsumptionCatchesUp(RecordMetadata recordMetadata) throws TimeoutException { + waitTillConsumptionCatchesUp(recordMetadata, rlmmConfig.consumeWaitMs()); + } + + /** + * Waits if necessary for the consumption to reach the offset of the given {@code recordMetadata}. + * + * @param recordMetadata record metadata to be checked for consumption. + * @param timeoutMs wait timeout in milli seconds + * @throws TimeoutException if this method execution did not complete with in the given {@code timeoutMs}. + */ + public void waitTillConsumptionCatchesUp(RecordMetadata recordMetadata, + long timeoutMs) throws TimeoutException { + final int partition = recordMetadata.partition(); + final long consumeCheckIntervalMs = Math.min(CONSUME_RECHECK_INTERVAL_MS, timeoutMs); + + // If the current assignment does not have the subscription for this partition then return immediately. + if (!consumerTask.isPartitionAssigned(partition)) { + throw new KafkaException("This consumer is not subscribed to the target partition " + partition + " on which message is produced."); + } + + final long offset = recordMetadata.offset(); + long startTimeMs = time.milliseconds(); + while (true) { + long receivedOffset = consumerTask.receivedOffsetForPartition(partition).orElse(-1L); + if (receivedOffset >= offset) { + return; + } + + log.debug("Committed offset [{}] for partition [{}], but the target offset: [{}], Sleeping for [{}] to retry again", + offset, partition, receivedOffset, consumeCheckIntervalMs); + + if (time.milliseconds() - startTimeMs > timeoutMs) { + log.warn("Committed offset for partition:[{}] is : [{}], but the target offset: [{}] ", + partition, receivedOffset, offset); + throw new TimeoutException("Timed out in catching up with the expected offset by consumer."); + } + + time.sleep(consumeCheckIntervalMs); + } + } + + @Override + public void close() throws IOException { + // Consumer task will close the task and it internally closes all the resources including the consumer. + Utils.closeQuietly(consumerTask, "ConsumerTask"); + + // Wait until the consumer thread finishes. + try { + consumerTaskThread.join(); + } catch (Exception e) { + log.error("Encountered error while waiting for consumerTaskThread to finish.", e); + } + } + + public void addAssignmentsForPartitions(Set partitions) { + consumerTask.addAssignmentsForPartitions(partitions); + } + + public void removeAssignmentsForPartitions(Set partitions) { + consumerTask.removeAssignmentsForPartitions(partitions); + } + + public Optional receivedOffsetForPartition(int metadataPartition) { + return consumerTask.receivedOffsetForPartition(metadataPartition); + } +} diff --git a/storage/src/main/java/org/apache/kafka/server/log/remote/metadata/storage/ConsumerTask.java b/storage/src/main/java/org/apache/kafka/server/log/remote/metadata/storage/ConsumerTask.java new file mode 100644 index 0000000..2509a44 --- /dev/null +++ b/storage/src/main/java/org/apache/kafka/server/log/remote/metadata/storage/ConsumerTask.java @@ -0,0 +1,356 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
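A hedged sketch of the intended call sequence around ConsumerManager, assuming the config, event handler, partitioner, assigned partitions and the RecordMetadata of a just-produced metadata record are supplied by the caller. This is not how the surrounding manager wires things up, only an illustration of the waiting contract.

package org.apache.kafka.server.log.remote.metadata.storage;

import org.apache.kafka.clients.producer.RecordMetadata;
import org.apache.kafka.common.TopicIdPartition;
import org.apache.kafka.common.utils.Time;

import java.io.IOException;
import java.util.Set;
import java.util.concurrent.TimeoutException;

public class ConsumerManagerUsageExample {
    static void example(TopicBasedRemoteLogMetadataManagerConfig rlmmConfig,
                        RemotePartitionMetadataEventHandler eventHandler,
                        RemoteLogMetadataTopicPartitioner partitioner,
                        Set<TopicIdPartition> leaderAndFollowerPartitions,
                        RecordMetadata publishedRecord) throws TimeoutException, IOException {
        try (ConsumerManager consumerManager =
                     new ConsumerManager(rlmmConfig, eventHandler, partitioner, Time.SYSTEM)) {
            consumerManager.startConsumerThread();

            // Subscribe to the metadata partitions backing these user partitions.
            consumerManager.addAssignmentsForPartitions(leaderAndFollowerPartitions);

            // Block until the consumer has caught up to the offset of a record that was just produced
            // to the metadata topic, or throw TimeoutException after the configured consumeWaitMs.
            consumerManager.waitTillConsumptionCatchesUp(publishedRecord);
        }
    }
}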
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.server.log.remote.metadata.storage; + +import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.clients.consumer.ConsumerRecords; +import org.apache.kafka.clients.consumer.KafkaConsumer; +import org.apache.kafka.common.KafkaException; +import org.apache.kafka.common.TopicIdPartition; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.server.log.remote.metadata.storage.serialization.RemoteLogMetadataSerde; +import org.apache.kafka.server.log.remote.storage.RemoteLogMetadata; +import org.apache.kafka.server.log.remote.storage.RemoteLogSegmentMetadata; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.Closeable; +import java.io.IOException; +import java.nio.file.Path; +import java.time.Duration; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; +import java.util.stream.Collectors; + +import static org.apache.kafka.server.log.remote.metadata.storage.TopicBasedRemoteLogMetadataManagerConfig.REMOTE_LOG_METADATA_TOPIC_NAME; + +/** + * This class is responsible for consuming messages from remote log metadata topic ({@link TopicBasedRemoteLogMetadataManagerConfig#REMOTE_LOG_METADATA_TOPIC_NAME}) + * partitions and maintain the state of the remote log segment metadata. It gives an API to add or remove + * for what topic partition's metadata should be consumed by this instance using + * {{@link #addAssignmentsForPartitions(Set)}} and {@link #removeAssignmentsForPartitions(Set)} respectively. + *

                + * When a broker is started, the controller sends the topic partitions that this broker is a leader or follower for, along + * with the partitions to be deleted. This class receives those notifications through + * {@link #addAssignmentsForPartitions(Set)} and {@link #removeAssignmentsForPartitions(Set)} and assigns the consumer to the + * respective remote log metadata partitions using {@link RemoteLogMetadataTopicPartitioner#metadataPartition(TopicIdPartition)}. + * Later leadership changes arrive through the same APIs. Partitions deleted from this broker are removed from the assignment + * when they are received through {@link #removeAssignmentsForPartitions(Set)}. + *
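To make that assignment mapping concrete, a tiny hedged sketch (made-up topic and ids): every user partition is routed to exactly one metadata topic partition via the partitioner, and that is the partition this task subscribes to.

package org.apache.kafka.server.log.remote.metadata.storage;

import org.apache.kafka.common.TopicIdPartition;
import org.apache.kafka.common.TopicPartition;
import org.apache.kafka.common.Uuid;

public class MetadataPartitionMappingExample {
    static int metadataPartitionFor(RemoteLogMetadataTopicPartitioner partitioner) {
        // Hypothetical user partition; the topic name and id are made up.
        TopicIdPartition userPartition =
                new TopicIdPartition(Uuid.randomUuid(), new TopicPartition("orders", 0));

        // All remote log metadata events for this user partition land on, and are consumed from,
        // this metadata topic partition.
        return partitioner.metadataPartition(userPartition);
    }
}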

                + * After receiving these events it invokes {@link RemotePartitionMetadataEventHandler#handleRemoteLogSegmentMetadata(RemoteLogSegmentMetadata)}, + * which maintains in-memory representation of the state of {@link RemoteLogSegmentMetadata}. + */ +class ConsumerTask implements Runnable, Closeable { + private static final Logger log = LoggerFactory.getLogger(ConsumerTask.class); + + private static final long POLL_INTERVAL_MS = 100L; + + private final RemoteLogMetadataSerde serde = new RemoteLogMetadataSerde(); + private final KafkaConsumer consumer; + private final RemotePartitionMetadataEventHandler remotePartitionMetadataEventHandler; + private final RemoteLogMetadataTopicPartitioner topicPartitioner; + private final Time time; + + // It indicates whether the closing process has been started or not. If it is set as true, + // consumer will stop consuming messages and it will not allow partition assignments to be updated. + private volatile boolean closing = false; + + // It indicates whether the consumer needs to assign the partitions or not. This is set when it is + // determined that the consumer needs to be assigned with the updated partitions. + private volatile boolean assignPartitions = false; + + // It represents a lock for any operations related to the assignedTopicPartitions. + private final Object assignPartitionsLock = new Object(); + + // Remote log metadata topic partitions that consumer is assigned to. + private volatile Set assignedMetaPartitions = Collections.emptySet(); + + // User topic partitions that this broker is a leader/follower for. + private Set assignedTopicPartitions = Collections.emptySet(); + + // Map of remote log metadata topic partition to consumed offsets. Received consumer records + // may or may not have been processed based on the assigned topic partitions. + private final Map partitionToConsumedOffsets = new ConcurrentHashMap<>(); + + // Map of remote log metadata topic partition to processed offsets that were synced in committedOffsetsFile. + private Map lastSyncedPartitionToConsumedOffsets = Collections.emptyMap(); + + private final long committedOffsetSyncIntervalMs; + private CommittedOffsetsFile committedOffsetsFile; + private long lastSyncedTimeMs; + + public ConsumerTask(KafkaConsumer consumer, + RemotePartitionMetadataEventHandler remotePartitionMetadataEventHandler, + RemoteLogMetadataTopicPartitioner topicPartitioner, + Path committedOffsetsPath, + Time time, + long committedOffsetSyncIntervalMs) { + this.consumer = Objects.requireNonNull(consumer); + this.remotePartitionMetadataEventHandler = Objects.requireNonNull(remotePartitionMetadataEventHandler); + this.topicPartitioner = Objects.requireNonNull(topicPartitioner); + this.time = Objects.requireNonNull(time); + this.committedOffsetSyncIntervalMs = committedOffsetSyncIntervalMs; + + initializeConsumerAssignment(committedOffsetsPath); + } + + private void initializeConsumerAssignment(Path committedOffsetsPath) { + try { + committedOffsetsFile = new CommittedOffsetsFile(committedOffsetsPath.toFile()); + } catch (IOException e) { + throw new KafkaException(e); + } + + Map committedOffsets = Collections.emptyMap(); + try { + // Load committed offset and assign them in the consumer. + committedOffsets = committedOffsetsFile.readEntries(); + } catch (IOException e) { + // Ignore the error and consumer consumes from the earliest offset. + log.error("Encountered error while building committed offsets from the file. 
" + + "Consumer will consume from the earliest offset for the assigned partitions.", e); + } + + if (!committedOffsets.isEmpty()) { + // Assign topic partitions from the earlier committed offsets file. + Set earlierAssignedPartitions = committedOffsets.keySet(); + assignedMetaPartitions = Collections.unmodifiableSet(earlierAssignedPartitions); + Set metadataTopicPartitions = earlierAssignedPartitions.stream() + .map(x -> new TopicPartition(REMOTE_LOG_METADATA_TOPIC_NAME, x)) + .collect(Collectors.toSet()); + consumer.assign(metadataTopicPartitions); + + // Seek to the committed offsets + for (Map.Entry entry : committedOffsets.entrySet()) { + partitionToConsumedOffsets.put(entry.getKey(), entry.getValue()); + consumer.seek(new TopicPartition(REMOTE_LOG_METADATA_TOPIC_NAME, entry.getKey()), entry.getValue()); + } + + lastSyncedPartitionToConsumedOffsets = Collections.unmodifiableMap(committedOffsets); + } + } + + @Override + public void run() { + log.info("Started Consumer task thread."); + lastSyncedTimeMs = time.milliseconds(); + try { + while (!closing) { + maybeWaitForPartitionsAssignment(); + + log.info("Polling consumer to receive remote log metadata topic records"); + ConsumerRecords consumerRecords = consumer.poll(Duration.ofMillis(POLL_INTERVAL_MS)); + for (ConsumerRecord record : consumerRecords) { + processConsumerRecord(record); + } + + maybeSyncCommittedDataAndOffsets(false); + } + } catch (Exception e) { + log.error("Error occurred in consumer task, close:[{}]", closing, e); + } finally { + maybeSyncCommittedDataAndOffsets(true); + closeConsumer(); + log.info("Exiting from consumer task thread"); + } + } + + private void processConsumerRecord(ConsumerRecord record) { + // Taking assignPartitionsLock here as updateAssignmentsForPartitions changes assignedTopicPartitions + // and also calls remotePartitionMetadataEventHandler.clearTopicPartition(removedPartition) for the removed + // partitions. + RemoteLogMetadata remoteLogMetadata = serde.deserialize(record.value()); + synchronized (assignPartitionsLock) { + if (assignedTopicPartitions.contains(remoteLogMetadata.topicIdPartition())) { + remotePartitionMetadataEventHandler.handleRemoteLogMetadata(remoteLogMetadata); + } else { + log.debug("This event {} is skipped as the topic partition is not assigned for this instance.", remoteLogMetadata); + } + partitionToConsumedOffsets.put(record.partition(), record.offset()); + } + } + + private void maybeSyncCommittedDataAndOffsets(boolean forceSync) { + // Return immediately if there is no consumption from last time. + boolean noConsumedOffsetUpdates = partitionToConsumedOffsets.equals(lastSyncedPartitionToConsumedOffsets); + if (noConsumedOffsetUpdates || !forceSync && time.milliseconds() - lastSyncedTimeMs < committedOffsetSyncIntervalMs) { + log.debug("Skip syncing committed offsets, noConsumedOffsetUpdates: {}, forceSync: {}", noConsumedOffsetUpdates, forceSync); + return; + } + + try { + // Need to take lock on assignPartitionsLock as assignedTopicPartitions might + // get updated by other threads. 
+ synchronized (assignPartitionsLock) { + for (TopicIdPartition topicIdPartition : assignedTopicPartitions) { + int metadataPartition = topicPartitioner.metadataPartition(topicIdPartition); + Long offset = partitionToConsumedOffsets.get(metadataPartition); + if (offset != null) { + remotePartitionMetadataEventHandler.syncLogMetadataSnapshot(topicIdPartition, metadataPartition, offset); + } else { + log.debug("Skipping syncup of the remote-log-metadata-file for partition:{} , with remote log metadata partition{}, and no offset", + topicIdPartition, metadataPartition); + } + } + + // Write partitionToConsumedOffsets into committed offsets file as we do not want to process them again + // in case of restarts. + committedOffsetsFile.writeEntries(partitionToConsumedOffsets); + lastSyncedPartitionToConsumedOffsets = new HashMap<>(partitionToConsumedOffsets); + } + + lastSyncedTimeMs = time.milliseconds(); + } catch (IOException e) { + throw new KafkaException("Error encountered while writing committed offsets to a local file", e); + } + } + + private void closeConsumer() { + log.info("Closing the consumer instance"); + try { + consumer.close(Duration.ofSeconds(30)); + } catch (Exception e) { + log.error("Error encountered while closing the consumer", e); + } + } + + private void maybeWaitForPartitionsAssignment() { + Set assignedMetaPartitionsSnapshot = Collections.emptySet(); + synchronized (assignPartitionsLock) { + // If it is closing, return immediately. This should be inside the assignPartitionsLock as the closing is updated + // in close() method with in the same lock to avoid any race conditions. + if (closing) { + return; + } + + while (assignedMetaPartitions.isEmpty()) { + // If no partitions are assigned, wait until they are assigned. + log.debug("Waiting for assigned remote log metadata partitions.."); + try { + // No timeout is set here, as it is always notified. Even when it is closed, the race can happen + // between the thread calling this method and the thread calling close(). We should have a check + // for closing as that might have been set and notified with assignPartitionsLock by `close` + // method. 
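+ // Guarded wait: the enclosing while loop re-checks assignedMetaPartitions.isEmpty() after every
+ // wake-up, and the closing check right after wait() handles a close() that acquired this lock and
+ // called notifyAll() while this thread was parked here.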
+ assignPartitionsLock.wait(); + + if (closing) { + return; + } + } catch (InterruptedException e) { + throw new KafkaException(e); + } + } + + if (assignPartitions) { + assignedMetaPartitionsSnapshot = new HashSet<>(assignedMetaPartitions); + // Removing unassigned meta partitions from partitionToConsumedOffsets and partitionToCommittedOffsets + partitionToConsumedOffsets.entrySet().removeIf(entry -> !assignedMetaPartitions.contains(entry.getKey())); + + assignPartitions = false; + } + } + + if (!assignedMetaPartitionsSnapshot.isEmpty()) { + executeReassignment(assignedMetaPartitionsSnapshot); + } + } + + private void executeReassignment(Set assignedMetaPartitionsSnapshot) { + Set assignedMetaTopicPartitions = + assignedMetaPartitionsSnapshot.stream() + .map(partitionNum -> new TopicPartition(REMOTE_LOG_METADATA_TOPIC_NAME, partitionNum)) + .collect(Collectors.toSet()); + log.info("Reassigning partitions to consumer task [{}]", assignedMetaTopicPartitions); + consumer.assign(assignedMetaTopicPartitions); + } + + public void addAssignmentsForPartitions(Set partitions) { + updateAssignmentsForPartitions(partitions, Collections.emptySet()); + } + + public void removeAssignmentsForPartitions(Set partitions) { + updateAssignmentsForPartitions(Collections.emptySet(), partitions); + } + + private void updateAssignmentsForPartitions(Set addedPartitions, + Set removedPartitions) { + log.info("Updating assignments for addedPartitions: {} and removedPartition: {}", addedPartitions, removedPartitions); + + Objects.requireNonNull(addedPartitions, "addedPartitions must not be null"); + Objects.requireNonNull(removedPartitions, "removedPartitions must not be null"); + + if (addedPartitions.isEmpty() && removedPartitions.isEmpty()) { + return; + } + + synchronized (assignPartitionsLock) { + Set updatedReassignedPartitions = new HashSet<>(assignedTopicPartitions); + updatedReassignedPartitions.addAll(addedPartitions); + updatedReassignedPartitions.removeAll(removedPartitions); + Set updatedAssignedMetaPartitions = new HashSet<>(); + for (TopicIdPartition tp : updatedReassignedPartitions) { + updatedAssignedMetaPartitions.add(topicPartitioner.metadataPartition(tp)); + } + + // Clear removed topic partitions from inmemory cache. + for (TopicIdPartition removedPartition : removedPartitions) { + remotePartitionMetadataEventHandler.clearTopicPartition(removedPartition); + } + + assignedTopicPartitions = Collections.unmodifiableSet(updatedReassignedPartitions); + log.debug("Assigned topic partitions: {}", assignedTopicPartitions); + + if (!updatedAssignedMetaPartitions.equals(assignedMetaPartitions)) { + assignedMetaPartitions = Collections.unmodifiableSet(updatedAssignedMetaPartitions); + log.debug("Assigned metadata topic partitions: {}", assignedMetaPartitions); + + assignPartitions = true; + assignPartitionsLock.notifyAll(); + } else { + log.debug("No change in assigned metadata topic partitions: {}", assignedMetaPartitions); + } + } + } + + public Optional receivedOffsetForPartition(int partition) { + return Optional.ofNullable(partitionToConsumedOffsets.get(partition)); + } + + public boolean isPartitionAssigned(int partition) { + return assignedMetaPartitions.contains(partition); + } + + public void close() { + if (!closing) { + synchronized (assignPartitionsLock) { + // Closing should be updated only after acquiring the lock to avoid race in + // maybeWaitForPartitionsAssignment() where it waits on assignPartitionsLock. It should not wait + // if the closing is already set. 
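+ // Setting the flag, waking the consumer and notifying the lock covers all blocking points:
+ // consumer.wakeup() aborts a poll() that is blocked in the run() loop, and notifyAll() releases a
+ // thread parked in assignPartitionsLock.wait() inside maybeWaitForPartitionsAssignment().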
+ closing = true; + consumer.wakeup(); + assignPartitionsLock.notifyAll(); + } + } + } +} diff --git a/storage/src/main/java/org/apache/kafka/server/log/remote/metadata/storage/FileBasedRemoteLogMetadataCache.java b/storage/src/main/java/org/apache/kafka/server/log/remote/metadata/storage/FileBasedRemoteLogMetadataCache.java new file mode 100644 index 0000000..15e4562 --- /dev/null +++ b/storage/src/main/java/org/apache/kafka/server/log/remote/metadata/storage/FileBasedRemoteLogMetadataCache.java @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.server.log.remote.metadata.storage; + +import org.apache.kafka.common.KafkaException; +import org.apache.kafka.common.TopicIdPartition; +import org.apache.kafka.server.log.remote.storage.RemoteLogSegmentId; +import org.apache.kafka.server.log.remote.storage.RemoteLogSegmentMetadata; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.nio.file.Path; +import java.util.ArrayList; +import java.util.List; + +/** + * This is a wrapper around {@link RemoteLogMetadataCache} providing a file based snapshot of + * {@link RemoteLogMetadataCache} for the given {@code topicIdPartition}. Snapshot is stored in the given + * {@code partitionDir}. 
+ */ +public class FileBasedRemoteLogMetadataCache extends RemoteLogMetadataCache { + private static final Logger log = LoggerFactory.getLogger(FileBasedRemoteLogMetadataCache.class); + private final RemoteLogMetadataSnapshotFile snapshotFile; + private final TopicIdPartition topicIdPartition; + + public FileBasedRemoteLogMetadataCache(TopicIdPartition topicIdPartition, + Path partitionDir) { + if (!partitionDir.toFile().exists() || !partitionDir.toFile().isDirectory()) { + throw new KafkaException("Given partition directory:" + partitionDir + " must be an existing directory."); + } + + this.topicIdPartition = topicIdPartition; + snapshotFile = new RemoteLogMetadataSnapshotFile(partitionDir); + + try { + snapshotFile.read().ifPresent(snapshot -> loadRemoteLogSegmentMetadata(snapshot)); + } catch (IOException e) { + throw new KafkaException(e); + } + } + + protected void loadRemoteLogSegmentMetadata(RemoteLogMetadataSnapshotFile.Snapshot snapshot) { + log.info("Loading snapshot for partition {} is: {}", topicIdPartition, snapshot); + for (RemoteLogSegmentMetadataSnapshot metadataSnapshot : snapshot.remoteLogSegmentMetadataSnapshots()) { + switch (metadataSnapshot.state()) { + case COPY_SEGMENT_STARTED: + addCopyInProgressSegment(createRemoteLogSegmentMetadata(metadataSnapshot)); + break; + case COPY_SEGMENT_FINISHED: + handleSegmentWithCopySegmentFinishedState(createRemoteLogSegmentMetadata(metadataSnapshot)); + break; + case DELETE_SEGMENT_STARTED: + handleSegmentWithDeleteSegmentStartedState(createRemoteLogSegmentMetadata(metadataSnapshot)); + break; + case DELETE_SEGMENT_FINISHED: + default: + throw new IllegalArgumentException("Given remoteLogSegmentMetadata has invalid state: " + metadataSnapshot); + } + } + } + + private RemoteLogSegmentMetadata createRemoteLogSegmentMetadata(RemoteLogSegmentMetadataSnapshot snapshot) { + return new RemoteLogSegmentMetadata(new RemoteLogSegmentId(topicIdPartition, snapshot.segmentId()), snapshot.startOffset(), + snapshot.endOffset(), snapshot.maxTimestampMs(), snapshot.brokerId(), snapshot.eventTimestampMs(), + snapshot.segmentSizeInBytes(), snapshot.state(), snapshot.segmentLeaderEpochs()); + } + + /** + * Flushes the in-memory state to the snapshot file. + * + * @param metadataPartition remote log metadata partition from which the messages have been consumed for the given + * user topic partition. + * @param metadataPartitionOffset remote log metadata partition offset up to which the messages have been consumed. + * @throws IOException if any errors occurred while writing the snapshot to the file. + */ + public void flushToFile(int metadataPartition, + Long metadataPartitionOffset) throws IOException { + List snapshots = new ArrayList<>(idToSegmentMetadata.size()); + for (RemoteLogLeaderEpochState state : leaderEpochEntries.values()) { + // Add unreferenced segments first, as to maintain the order when these segments are again read from + // the snapshot to build RemoteLogMetadataCache. + for (RemoteLogSegmentId id : state.unreferencedSegmentIds()) { + snapshots.add(RemoteLogSegmentMetadataSnapshot.create(idToSegmentMetadata.get(id))); + } + + // Add referenced segments. 
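+ // referencedSegmentIds() is backed by the epoch's offset-ordered navigable map, so referenced
+ // segments are written to the snapshot in ascending start-offset order for each leader epoch.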
+ for (RemoteLogSegmentId id : state.referencedSegmentIds()) { + snapshots.add(RemoteLogSegmentMetadataSnapshot.create(idToSegmentMetadata.get(id))); + } + } + + snapshotFile.write(new RemoteLogMetadataSnapshotFile.Snapshot(metadataPartition, metadataPartitionOffset, snapshots)); + } +} diff --git a/storage/src/main/java/org/apache/kafka/server/log/remote/metadata/storage/ProducerManager.java b/storage/src/main/java/org/apache/kafka/server/log/remote/metadata/storage/ProducerManager.java new file mode 100644 index 0000000..ca40754 --- /dev/null +++ b/storage/src/main/java/org/apache/kafka/server/log/remote/metadata/storage/ProducerManager.java @@ -0,0 +1,102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.server.log.remote.metadata.storage; + +import org.apache.kafka.clients.producer.Callback; +import org.apache.kafka.clients.producer.KafkaProducer; +import org.apache.kafka.clients.producer.ProducerRecord; +import org.apache.kafka.clients.producer.RecordMetadata; +import org.apache.kafka.common.KafkaException; +import org.apache.kafka.common.TopicIdPartition; +import org.apache.kafka.server.log.remote.metadata.storage.serialization.RemoteLogMetadataSerde; +import org.apache.kafka.server.log.remote.storage.RemoteLogMetadata; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.Closeable; +import java.time.Duration; +import java.util.concurrent.CompletableFuture; + +/** + * This class is responsible for publishing messages into the remote log metadata topic partitions. + * + * Caller of this class should take care of not sending messages once the closing of this instance is initiated. + */ +public class ProducerManager implements Closeable { + private static final Logger log = LoggerFactory.getLogger(ProducerManager.class); + + private final RemoteLogMetadataSerde serde = new RemoteLogMetadataSerde(); + private final KafkaProducer producer; + private final RemoteLogMetadataTopicPartitioner topicPartitioner; + private final TopicBasedRemoteLogMetadataManagerConfig rlmmConfig; + + public ProducerManager(TopicBasedRemoteLogMetadataManagerConfig rlmmConfig, + RemoteLogMetadataTopicPartitioner rlmmTopicPartitioner) { + this.rlmmConfig = rlmmConfig; + this.producer = new KafkaProducer<>(rlmmConfig.producerProperties()); + topicPartitioner = rlmmTopicPartitioner; + } + + /** + * Returns {@link CompletableFuture} which will complete only after publishing of the given {@code remoteLogMetadata} + * is considered complete. 
+ * + * @param remoteLogMetadata RemoteLogMetadata to be published + * @return + */ + public CompletableFuture publishMessage(RemoteLogMetadata remoteLogMetadata) { + CompletableFuture future = new CompletableFuture<>(); + + TopicIdPartition topicIdPartition = remoteLogMetadata.topicIdPartition(); + int metadataPartitionNum = topicPartitioner.metadataPartition(topicIdPartition); + log.debug("Publishing metadata message of partition:[{}] into metadata topic partition:[{}] with payload: [{}]", + topicIdPartition, metadataPartitionNum, remoteLogMetadata); + if (metadataPartitionNum >= rlmmConfig.metadataTopicPartitionsCount()) { + // This should never occur as long as metadata partitions always remain the same. + throw new KafkaException("Chosen partition no " + metadataPartitionNum + + " must be less than the partition count: " + rlmmConfig.metadataTopicPartitionsCount()); + } + + try { + Callback callback = new Callback() { + @Override + public void onCompletion(RecordMetadata metadata, + Exception exception) { + if (exception != null) { + future.completeExceptionally(exception); + } else { + future.complete(metadata); + } + } + }; + producer.send(new ProducerRecord<>(rlmmConfig.remoteLogMetadataTopicName(), metadataPartitionNum, null, + serde.serialize(remoteLogMetadata)), callback); + } catch (Exception ex) { + future.completeExceptionally(ex); + } + + return future; + } + + public void close() { + try { + producer.close(Duration.ofSeconds(30)); + } catch (Exception e) { + log.error("Error encountered while closing the producer", e); + } + } +} diff --git a/storage/src/main/java/org/apache/kafka/server/log/remote/metadata/storage/RemoteLogLeaderEpochState.java b/storage/src/main/java/org/apache/kafka/server/log/remote/metadata/storage/RemoteLogLeaderEpochState.java new file mode 100644 index 0000000..d5787dd --- /dev/null +++ b/storage/src/main/java/org/apache/kafka/server/log/remote/metadata/storage/RemoteLogLeaderEpochState.java @@ -0,0 +1,178 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.server.log.remote.metadata.storage; + +import org.apache.kafka.server.log.remote.storage.RemoteLogSegmentId; +import org.apache.kafka.server.log.remote.storage.RemoteLogSegmentMetadata; +import org.apache.kafka.server.log.remote.storage.RemoteResourceNotFoundException; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.Comparator; +import java.util.Iterator; +import java.util.Map; +import java.util.NavigableMap; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentSkipListMap; + +/** + * This class represents the in-memory state of segments associated with a leader epoch. 
This includes the mapping of offset to + * segment ids and unreferenced segments which are not mapped to any offset but they exist in remote storage. + *

                + * This is used by {@link RemoteLogMetadataCache} to track the segments for each leader epoch. + */ +class RemoteLogLeaderEpochState { + + // It contains offset to segment ids mapping with the segment state as COPY_SEGMENT_FINISHED. + private final NavigableMap offsetToId = new ConcurrentSkipListMap<>(); + + /** + * It represents unreferenced segments for this leader epoch. It contains the segments still in COPY_SEGMENT_STARTED + * and DELETE_SEGMENT_STARTED state or these have been replaced by callers with other segments having the same + * start offset for the leader epoch. These will be returned by {@link RemoteLogMetadataCache#listAllRemoteLogSegments()} + * and {@link RemoteLogMetadataCache#listRemoteLogSegments(int leaderEpoch)} so that callers can clean them up if + * they still exist. These will be cleaned from the cache once they reach DELETE_SEGMENT_FINISHED state. + */ + private final Set unreferencedSegmentIds = ConcurrentHashMap.newKeySet(); + + // It represents the highest log offset of the segments that reached the COPY_SEGMENT_FINISHED state. + private volatile Long highestLogOffset; + + /** + * Returns all the segments associated with this leader epoch sorted by start offset in ascending order. + * + * @param idToSegmentMetadata mapping of id to segment metadata. This will be used to get RemoteLogSegmentMetadata + * for an id to be used for sorting. + */ + Iterator listAllRemoteLogSegments(Map idToSegmentMetadata) + throws RemoteResourceNotFoundException { + // Return all the segments including unreferenced metadata. + int size = offsetToId.size() + unreferencedSegmentIds.size(); + if (size == 0) { + return Collections.emptyIterator(); + } + + ArrayList metadataList = new ArrayList<>(size); + collectConvertedIdToMetadata(offsetToId.values(), idToSegmentMetadata, metadataList); + + if (!unreferencedSegmentIds.isEmpty()) { + collectConvertedIdToMetadata(unreferencedSegmentIds, idToSegmentMetadata, metadataList); + + // Sort only when unreferenced entries exist as they are already sorted in offsetToId. + metadataList.sort(Comparator.comparingLong(RemoteLogSegmentMetadata::startOffset)); + } + + return metadataList.iterator(); + } + + private void collectConvertedIdToMetadata(Collection segmentIds, + Map idToSegmentMetadata, + Collection result) throws RemoteResourceNotFoundException { + for (RemoteLogSegmentId id : segmentIds) { + RemoteLogSegmentMetadata metadata = idToSegmentMetadata.get(id); + if (metadata == null) { + throw new RemoteResourceNotFoundException("No remote log segment metadata found for :" + id); + } + result.add(metadata); + } + } + + void handleSegmentWithCopySegmentStartedState(RemoteLogSegmentId remoteLogSegmentId) { + // Add this to unreferenced set of segments for the respective leader epoch. + unreferencedSegmentIds.add(remoteLogSegmentId); + } + + void handleSegmentWithCopySegmentFinishedState(Long startOffset, RemoteLogSegmentId remoteLogSegmentId, + Long leaderEpochEndOffset) { + // Add the segment epochs mapping as the segment is copied successfully. + RemoteLogSegmentId oldEntry = offsetToId.put(startOffset, remoteLogSegmentId); + + // Remove the metadata from unreferenced entries as it is successfully copied and added to the offset mapping. + unreferencedSegmentIds.remove(remoteLogSegmentId); + + // Add the old entry to unreferenced entries as the mapping is removed for the old entry. 
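+ // A non-null oldEntry means a different segment previously occupied this start offset for the epoch.
+ // Moving it to the unreferenced set keeps it visible through listRemoteLogSegments() and
+ // listAllRemoteLogSegments() so that callers can still clean it up from remote storage.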
+ if (oldEntry != null) { + unreferencedSegmentIds.add(oldEntry); + } + + // Update the highest offset entry for this leader epoch as we added a new mapping. + if (highestLogOffset == null || leaderEpochEndOffset > highestLogOffset) { + highestLogOffset = leaderEpochEndOffset; + } + } + + void handleSegmentWithDeleteSegmentStartedState(Long startOffset, RemoteLogSegmentId remoteLogSegmentId) { + // Remove the offset mappings as this segment is getting deleted. + offsetToId.remove(startOffset, remoteLogSegmentId); + + // Add this entry to unreferenced set for the leader epoch as it is being deleted. + // This allows any retries of deletion as these are returned from listAllSegments and listSegments(leaderEpoch). + unreferencedSegmentIds.add(remoteLogSegmentId); + } + + void handleSegmentWithDeleteSegmentFinishedState(RemoteLogSegmentId remoteLogSegmentId) { + // It completely removes the tracking of this segment as it is considered as deleted. + unreferencedSegmentIds.remove(remoteLogSegmentId); + } + + Long highestLogOffset() { + return highestLogOffset; + } + + /** + * Returns the RemoteLogSegmentId of a segment for the given offset, if there exists a mapping associated with + * the greatest offset less than or equal to the given offset, or null if there is no such mapping. + * + * @param offset offset + */ + RemoteLogSegmentId floorEntry(long offset) { + Map.Entry entry = offsetToId.floorEntry(offset); + + return entry == null ? null : entry.getValue(); + } + + Collection unreferencedSegmentIds() { + return Collections.unmodifiableCollection(unreferencedSegmentIds); + } + + Collection referencedSegmentIds() { + return Collections.unmodifiableCollection(offsetToId.values()); + } + + /** + * Action interface to act on remote log segment transition for the given {@link RemoteLogLeaderEpochState}. + */ + @FunctionalInterface + interface Action { + + /** + * Performs this operation with the given {@code remoteLogLeaderEpochState}. + * + * @param leaderEpoch leader epoch value + * @param remoteLogLeaderEpochState In-memory state of the segments for a leader epoch. + * @param startOffset start offset of the segment. + * @param segmentId segment id. + */ + void accept(int leaderEpoch, + RemoteLogLeaderEpochState remoteLogLeaderEpochState, + long startOffset, + RemoteLogSegmentId segmentId); + } + +} diff --git a/storage/src/main/java/org/apache/kafka/server/log/remote/metadata/storage/RemoteLogMetadataCache.java b/storage/src/main/java/org/apache/kafka/server/log/remote/metadata/storage/RemoteLogMetadataCache.java new file mode 100644 index 0000000..ed88bc4 --- /dev/null +++ b/storage/src/main/java/org/apache/kafka/server/log/remote/metadata/storage/RemoteLogMetadataCache.java @@ -0,0 +1,312 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.server.log.remote.metadata.storage; + +import org.apache.kafka.server.log.remote.storage.RemoteLogSegmentId; +import org.apache.kafka.server.log.remote.storage.RemoteLogSegmentMetadata; +import org.apache.kafka.server.log.remote.storage.RemoteLogSegmentMetadataUpdate; +import org.apache.kafka.server.log.remote.storage.RemoteLogSegmentState; +import org.apache.kafka.server.log.remote.storage.RemoteResourceNotFoundException; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Collections; +import java.util.Iterator; +import java.util.Map; +import java.util.NavigableMap; +import java.util.Objects; +import java.util.Optional; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; + +/** + * This class provides an in-memory cache of remote log segment metadata. This maintains the lineage of segments + * with respect to leader epochs. + *

                + * Remote log segment can go through the state transitions as mentioned in {@link RemoteLogSegmentState}. + *

                + * This class will have all the segments which did not reach terminal state viz DELETE_SEGMENT_FINISHED. That means,any + * segment reaching the terminal state will get cleared from this instance. + * This class provides different methods to fetch segment metadata like {@link #remoteLogSegmentMetadata(int, long)}, + * {@link #highestOffsetForEpoch(int)}, {@link #listRemoteLogSegments(int)}, {@link #listAllRemoteLogSegments()}. Those + * methods have different semantics to fetch the segment based on its state. + *

+ * <p>
+ * <ul>
+ * <li>
+ * {@link RemoteLogSegmentState#COPY_SEGMENT_STARTED}:
+ * <br>
+ * Segment in this state indicates it is not yet copied successfully. So, these segments will not be
+ * accessible for reads but these are considered for cleanups when a partition is deleted.
+ * </li>
+ * <li>
+ * {@link RemoteLogSegmentState#COPY_SEGMENT_FINISHED}:
+ * <br>
+ * Segment in this state indicates it is successfully copied and it is available for reads. So, these segments
+ * will be accessible for reads. But this should be available for any cleanup activity like deleting segments by the
+ * caller of this class.
+ * </li>
+ * <li>
+ * {@link RemoteLogSegmentState#DELETE_SEGMENT_STARTED}:
+ * Segment in this state indicates it is getting deleted. That means, it is not available for reads. But it should be
+ * available for any cleanup activity like deleting segments by the caller of this class.
+ * </li>
+ * <li>
+ * {@link RemoteLogSegmentState#DELETE_SEGMENT_FINISHED}:
+ * Segment in this state indicates it is already deleted. That means, it is not available for any activity including
+ * reads or cleanup activity. This cache will clear entries containing this state.
+ * </li>
+ * </ul>
+ *
+ * <p>
+ * The below table summarizes whether a segment in the respective state is available for the given methods.
+ * <pre>
                + * +---------------------------------+----------------------+------------------------+-------------------------+-------------------------+
                + * |  Method / SegmentState          | COPY_SEGMENT_STARTED | COPY_SEGMENT_FINISHED  | DELETE_SEGMENT_STARTED  | DELETE_SEGMENT_FINISHED |
                + * |---------------------------------+----------------------+------------------------+-------------------------+-------------------------|
                + * | remoteLogSegmentMetadata        |        No            |           Yes          |          No             |           No            |
                + * | (int leaderEpoch, long offset)  |                      |                        |                         |                         |
                + * |---------------------------------+----------------------+------------------------+-------------------------+-------------------------|
                + * | listRemoteLogSegments           |        Yes           |           Yes          |          Yes            |           No            |
                + * | (int leaderEpoch)               |                      |                        |                         |                         |
                + * |---------------------------------+----------------------+------------------------+-------------------------+-------------------------|
                + * | listAllRemoteLogSegments()      |        Yes           |           Yes          |          Yes            |           No            |
                + * |                                 |                      |                        |                         |                         |
                + * +---------------------------------+----------------------+------------------------+-------------------------+-------------------------+
                + * 
+ * </pre>
+ *

                + */ +public class RemoteLogMetadataCache { + + private static final Logger log = LoggerFactory.getLogger(RemoteLogMetadataCache.class); + + // It contains all the segment-id to metadata mappings which did not reach the terminal state viz DELETE_SEGMENT_FINISHED. + protected final ConcurrentMap idToSegmentMetadata + = new ConcurrentHashMap<>(); + + // It contains leader epoch to the respective entry containing the state. + // TODO We are not clearing the entry for epoch when RemoteLogLeaderEpochState becomes empty. This will be addressed + // later. We will look into it when we integrate these APIs along with RemoteLogManager changes. + // https://issues.apache.org/jira/browse/KAFKA-12641 + protected final ConcurrentMap leaderEpochEntries = new ConcurrentHashMap<>(); + + /** + * Returns {@link RemoteLogSegmentMetadata} if it exists for the given leader-epoch containing the offset and with + * {@link RemoteLogSegmentState#COPY_SEGMENT_FINISHED} state, else returns {@link Optional#empty()}. + * + * @param leaderEpoch leader epoch for the given offset + * @param offset offset + * @return the requested remote log segment metadata if it exists. + */ + public Optional remoteLogSegmentMetadata(int leaderEpoch, long offset) { + RemoteLogLeaderEpochState remoteLogLeaderEpochState = leaderEpochEntries.get(leaderEpoch); + + if (remoteLogLeaderEpochState == null) { + return Optional.empty(); + } + + // Look for floor entry as the given offset may exist in this entry. + RemoteLogSegmentId remoteLogSegmentId = remoteLogLeaderEpochState.floorEntry(offset); + if (remoteLogSegmentId == null) { + // If the offset is lower than the minimum offset available in metadata then return empty. + return Optional.empty(); + } + + RemoteLogSegmentMetadata metadata = idToSegmentMetadata.get(remoteLogSegmentId); + // Check whether the given offset with leaderEpoch exists in this segment. + // Check for epoch's offset boundaries with in this segment. + // 1. Get the next epoch's start offset -1 if exists + // 2. If no next epoch exists, then segment end offset can be considered as epoch's relative end offset. + Map.Entry nextEntry = metadata.segmentLeaderEpochs().higherEntry(leaderEpoch); + long epochEndOffset = (nextEntry != null) ? nextEntry.getValue() - 1 : metadata.endOffset(); + + // Return empty when target offset > epoch's end offset. + return offset > epochEndOffset ? Optional.empty() : Optional.of(metadata); + } + + public void updateRemoteLogSegmentMetadata(RemoteLogSegmentMetadataUpdate metadataUpdate) + throws RemoteResourceNotFoundException { + log.debug("Updating remote log segment metadata: [{}]", metadataUpdate); + Objects.requireNonNull(metadataUpdate, "metadataUpdate can not be null"); + + RemoteLogSegmentState targetState = metadataUpdate.state(); + RemoteLogSegmentId remoteLogSegmentId = metadataUpdate.remoteLogSegmentId(); + RemoteLogSegmentMetadata existingMetadata = idToSegmentMetadata.get(remoteLogSegmentId); + if (existingMetadata == null) { + throw new RemoteResourceNotFoundException("No remote log segment metadata found for :" + + remoteLogSegmentId); + } + + // Check the state transition. + checkStateTransition(existingMetadata.state(), targetState); + + switch (targetState) { + case COPY_SEGMENT_STARTED: + // Callers should use addCopyInProgressSegment to add RemoteLogSegmentMetadata with state as + // RemoteLogSegmentState.COPY_SEGMENT_STARTED. 
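+ // An update can never target the initial state; checkStateTransition() above has already run, and
+ // COPY_SEGMENT_STARTED metadata is only ever created through addCopyInProgressSegment().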
+ throw new IllegalArgumentException("metadataUpdate: " + metadataUpdate + " with state " + RemoteLogSegmentState.COPY_SEGMENT_STARTED + + " can not be updated"); + case COPY_SEGMENT_FINISHED: + handleSegmentWithCopySegmentFinishedState(existingMetadata.createWithUpdates(metadataUpdate)); + break; + case DELETE_SEGMENT_STARTED: + handleSegmentWithDeleteSegmentStartedState(existingMetadata.createWithUpdates(metadataUpdate)); + break; + case DELETE_SEGMENT_FINISHED: + handleSegmentWithDeleteSegmentFinishedState(existingMetadata.createWithUpdates(metadataUpdate)); + break; + default: + throw new IllegalArgumentException("Metadata with the state " + targetState + " is not supported"); + } + } + + protected final void handleSegmentWithCopySegmentFinishedState(RemoteLogSegmentMetadata remoteLogSegmentMetadata) { + doHandleSegmentStateTransitionForLeaderEpochs(remoteLogSegmentMetadata, + (leaderEpoch, remoteLogLeaderEpochState, startOffset, segmentId) -> { + long leaderEpochEndOffset = highestOffsetForEpoch(leaderEpoch, + remoteLogSegmentMetadata); + remoteLogLeaderEpochState.handleSegmentWithCopySegmentFinishedState(startOffset, + segmentId, + leaderEpochEndOffset); + }); + + // Put the entry with the updated metadata. + idToSegmentMetadata.put(remoteLogSegmentMetadata.remoteLogSegmentId(), remoteLogSegmentMetadata); + } + + protected final void handleSegmentWithDeleteSegmentStartedState(RemoteLogSegmentMetadata remoteLogSegmentMetadata) { + log.debug("Cleaning up the state for : [{}]", remoteLogSegmentMetadata); + + doHandleSegmentStateTransitionForLeaderEpochs(remoteLogSegmentMetadata, + (leaderEpoch, remoteLogLeaderEpochState, startOffset, segmentId) -> + remoteLogLeaderEpochState.handleSegmentWithDeleteSegmentStartedState(startOffset, segmentId)); + + // Put the entry with the updated metadata. + idToSegmentMetadata.put(remoteLogSegmentMetadata.remoteLogSegmentId(), remoteLogSegmentMetadata); + } + + private void handleSegmentWithDeleteSegmentFinishedState(RemoteLogSegmentMetadata remoteLogSegmentMetadata) { + log.debug("Removing the entry as it reached the terminal state: [{}]", remoteLogSegmentMetadata); + + doHandleSegmentStateTransitionForLeaderEpochs(remoteLogSegmentMetadata, + (leaderEpoch, remoteLogLeaderEpochState, startOffset, segmentId) -> + remoteLogLeaderEpochState.handleSegmentWithDeleteSegmentFinishedState(segmentId)); + + // Remove the segment's id to metadata mapping because this segment is considered as deleted and it cleared all + // the state of this segment in the cache. + idToSegmentMetadata.remove(remoteLogSegmentMetadata.remoteLogSegmentId()); + } + + private void doHandleSegmentStateTransitionForLeaderEpochs(RemoteLogSegmentMetadata remoteLogSegmentMetadata, + RemoteLogLeaderEpochState.Action action) { + RemoteLogSegmentId remoteLogSegmentId = remoteLogSegmentMetadata.remoteLogSegmentId(); + Map leaderEpochToOffset = remoteLogSegmentMetadata.segmentLeaderEpochs(); + + // Go through all the leader epochs and apply the given action. + for (Map.Entry entry : leaderEpochToOffset.entrySet()) { + Integer leaderEpoch = entry.getKey(); + Long startOffset = entry.getValue(); + // leaderEpochEntries will be empty when resorting the metadata from snapshot. 
+ RemoteLogLeaderEpochState remoteLogLeaderEpochState = leaderEpochEntries.computeIfAbsent( + leaderEpoch, x -> new RemoteLogLeaderEpochState()); + action.accept(leaderEpoch, remoteLogLeaderEpochState, startOffset, remoteLogSegmentId); + } + } + + private static long highestOffsetForEpoch(Integer leaderEpoch, RemoteLogSegmentMetadata segmentMetadata) { + // Compute the highest offset for the leader epoch with in the segment + NavigableMap epochToOffset = segmentMetadata.segmentLeaderEpochs(); + Map.Entry nextEntry = epochToOffset.higherEntry(leaderEpoch); + + return nextEntry != null ? nextEntry.getValue() - 1 : segmentMetadata.endOffset(); + } + + /** + * Returns all the segments stored in this cache. + * + * @return + */ + public Iterator listAllRemoteLogSegments() { + // Return all the segments including unreferenced metadata. + return Collections.unmodifiableCollection(idToSegmentMetadata.values()).iterator(); + } + + /** + * Returns all the segments mapped to the leader epoch that exist in this cache sorted by {@link RemoteLogSegmentMetadata#startOffset()}. + * + * @param leaderEpoch leader epoch. + */ + public Iterator listRemoteLogSegments(int leaderEpoch) + throws RemoteResourceNotFoundException { + RemoteLogLeaderEpochState remoteLogLeaderEpochState = leaderEpochEntries.get(leaderEpoch); + if (remoteLogLeaderEpochState == null) { + return Collections.emptyIterator(); + } + + return remoteLogLeaderEpochState.listAllRemoteLogSegments(idToSegmentMetadata); + } + + /** + * Returns the highest offset of a segment for the given leader epoch if exists, else it returns empty. The segments + * that have reached the {@link RemoteLogSegmentState#COPY_SEGMENT_FINISHED} or later states are considered here. + * + * @param leaderEpoch leader epoch + */ + public Optional highestOffsetForEpoch(int leaderEpoch) { + RemoteLogLeaderEpochState entry = leaderEpochEntries.get(leaderEpoch); + return entry != null ? Optional.ofNullable(entry.highestLogOffset()) : Optional.empty(); + } + + /** + * This method tracks the given remote segment as not yet available for reads. It does not add the segment + * leader epoch offset mapping until this segment reaches COPY_SEGMENT_FINISHED state. + * + * @param remoteLogSegmentMetadata RemoteLogSegmentMetadata instance + */ + public void addCopyInProgressSegment(RemoteLogSegmentMetadata remoteLogSegmentMetadata) { + log.debug("Adding to in-progress state: [{}]", remoteLogSegmentMetadata); + Objects.requireNonNull(remoteLogSegmentMetadata, "remoteLogSegmentMetadata can not be null"); + + // This method is allowed only to add remote log segment with the initial state(which is RemoteLogSegmentState.COPY_SEGMENT_STARTED) + // but not to update the existing remote log segment metadata. + if (remoteLogSegmentMetadata.state() != RemoteLogSegmentState.COPY_SEGMENT_STARTED) { + throw new IllegalArgumentException( + "Given remoteLogSegmentMetadata:" + remoteLogSegmentMetadata + " should have state as " + RemoteLogSegmentState.COPY_SEGMENT_STARTED + + " but it contains state as: " + remoteLogSegmentMetadata.state()); + } + + RemoteLogSegmentId remoteLogSegmentId = remoteLogSegmentMetadata.remoteLogSegmentId(); + RemoteLogSegmentMetadata existingMetadata = idToSegmentMetadata.get(remoteLogSegmentId); + checkStateTransition(existingMetadata != null ? 
existingMetadata.state() : null, + remoteLogSegmentMetadata.state()); + + for (Integer epoch : remoteLogSegmentMetadata.segmentLeaderEpochs().keySet()) { + leaderEpochEntries.computeIfAbsent(epoch, leaderEpoch -> new RemoteLogLeaderEpochState()) + .handleSegmentWithCopySegmentStartedState(remoteLogSegmentId); + } + + idToSegmentMetadata.put(remoteLogSegmentId, remoteLogSegmentMetadata); + } + + private void checkStateTransition(RemoteLogSegmentState existingState, RemoteLogSegmentState targetState) { + if (!RemoteLogSegmentState.isValidTransition(existingState, targetState)) { + throw new IllegalStateException( + "Current state: " + existingState + " can not be transitioned to target state: " + targetState); + } + } + +} diff --git a/storage/src/main/java/org/apache/kafka/server/log/remote/metadata/storage/RemoteLogMetadataSnapshotFile.java b/storage/src/main/java/org/apache/kafka/server/log/remote/metadata/storage/RemoteLogMetadataSnapshotFile.java new file mode 100644 index 0000000..cee77ee --- /dev/null +++ b/storage/src/main/java/org/apache/kafka/server/log/remote/metadata/storage/RemoteLogMetadataSnapshotFile.java @@ -0,0 +1,267 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.server.log.remote.metadata.storage; + +import org.apache.kafka.common.KafkaException; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.server.log.remote.metadata.storage.serialization.RemoteLogMetadataSerde; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.File; +import java.io.FileInputStream; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.channels.Channels; +import java.nio.channels.FileChannel; +import java.nio.channels.ReadableByteChannel; +import java.nio.file.Path; +import java.nio.file.StandardOpenOption; +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; +import java.util.Objects; +import java.util.Optional; + +/** + * This class represents the remote log data snapshot stored in a file for a specific topic partition. This is used by + * {@link TopicBasedRemoteLogMetadataManager} to store the remote log metadata received for a specific partition from + * remote log metadata topic. This will avoid reading the remote log metadata messages from the topic again when a + * broker restarts. + */ +public class RemoteLogMetadataSnapshotFile { + private static final Logger log = LoggerFactory.getLogger(RemoteLogMetadataSnapshotFile.class); + + public static final String COMMITTED_LOG_METADATA_SNAPSHOT_FILE_NAME = "remote_log_snapshot"; + + // File format: + //
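+ // Layout, as implied by HEADER_SIZE and the write()/read() methods below:
+ //   header: <version:short><metadata-partition:int><metadata-partition-offset:long><entries-size:int>
+ //   entry : <entry-length:int><entry-bytes>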
                [...] + // header: + // entry: + + // header size: 2 (version) + 4 (partition num) + 8 (offset) + 4 (entries size) = 18 + private static final int HEADER_SIZE = 18; + + private final File metadataStoreFile; + private final RemoteLogMetadataSerde serde = new RemoteLogMetadataSerde(); + + /** + * Creates a CommittedLogMetadataSnapshotFile instance backed by a file with the name `remote_log_snapshot` in + * the given {@code metadataStoreDir}. It creates the file if it does not exist. + * + * @param metadataStoreDir directory in which the snapshot file to be created. + */ + RemoteLogMetadataSnapshotFile(Path metadataStoreDir) { + this.metadataStoreFile = new File(metadataStoreDir.toFile(), COMMITTED_LOG_METADATA_SNAPSHOT_FILE_NAME); + + // Create an empty file if it does not exist. + try { + boolean newFileCreated = metadataStoreFile.createNewFile(); + log.info("Remote log metadata snapshot file: [{}], newFileCreated: [{}]", metadataStoreFile, newFileCreated); + } catch (IOException e) { + throw new KafkaException(e); + } + } + + /** + * Writes the given snapshot replacing the earlier snapshot data. + * + * @param snapshot Snapshot to be stored. + * @throws IOException if there4 is any error in writing the given snapshot to the file. + */ + public synchronized void write(Snapshot snapshot) throws IOException { + Path newMetadataSnapshotFilePath = new File(metadataStoreFile.getAbsolutePath() + ".tmp").toPath(); + try (FileChannel fileChannel = FileChannel.open(newMetadataSnapshotFilePath, + StandardOpenOption.CREATE, StandardOpenOption.READ, StandardOpenOption.WRITE)) { + + // header: + ByteBuffer headerBuffer = ByteBuffer.allocate(HEADER_SIZE); + + // Write version + headerBuffer.putShort(snapshot.version()); + + // Write metadata partition and metadata partition offset + headerBuffer.putInt(snapshot.metadataPartition()); + + // Write metadata partition offset + headerBuffer.putLong(snapshot.metadataPartitionOffset()); + + // Write entries size + Collection metadataSnapshots = snapshot.remoteLogSegmentMetadataSnapshots(); + headerBuffer.putInt(metadataSnapshots.size()); + + // Write header + headerBuffer.flip(); + fileChannel.write(headerBuffer); + + // Write each entry + ByteBuffer lenBuffer = ByteBuffer.allocate(4); + for (RemoteLogSegmentMetadataSnapshot metadataSnapshot : metadataSnapshots) { + final byte[] serializedBytes = serde.serialize(metadataSnapshot); + // entry format: + + // Write entry length + lenBuffer.putInt(serializedBytes.length); + lenBuffer.flip(); + fileChannel.write(lenBuffer); + lenBuffer.rewind(); + + // Write entry bytes + fileChannel.write(ByteBuffer.wrap(serializedBytes)); + } + + fileChannel.force(true); + } + + Utils.atomicMoveWithFallback(newMetadataSnapshotFilePath, metadataStoreFile.toPath()); + } + + /** + * @return the Snapshot if it exists. + * @throws IOException if there is any error in reading the stored snapshot. + */ + public synchronized Optional read() throws IOException { + + // Checking for empty files. 
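+ // A zero-length file is what the constructor's createNewFile() leaves behind on a fresh start, so
+ // there is no snapshot to restore yet.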
+ if (metadataStoreFile.length() == 0) { + return Optional.empty(); + } + + try (ReadableByteChannel channel = Channels.newChannel(new FileInputStream(metadataStoreFile))) { + + // header: + // Read header + ByteBuffer headerBuffer = ByteBuffer.allocate(HEADER_SIZE); + channel.read(headerBuffer); + headerBuffer.rewind(); + short version = headerBuffer.getShort(); + int metadataPartition = headerBuffer.getInt(); + long metadataPartitionOffset = headerBuffer.getLong(); + int metadataSnapshotsSize = headerBuffer.getInt(); + + List result = new ArrayList<>(metadataSnapshotsSize); + ByteBuffer lenBuffer = ByteBuffer.allocate(4); + int lenBufferReadCt; + while ((lenBufferReadCt = channel.read(lenBuffer)) > 0) { + lenBuffer.rewind(); + + if (lenBufferReadCt != lenBuffer.capacity()) { + throw new IOException("Invalid amount of data read for the length of an entry, file may have been corrupted."); + } + + // entry format: + + // Read the length of each entry + final int len = lenBuffer.getInt(); + lenBuffer.rewind(); + + // Read the entry + ByteBuffer data = ByteBuffer.allocate(len); + final int read = channel.read(data); + if (read != len) { + throw new IOException("Invalid amount of data read, file may have been corrupted."); + } + + // We are always adding RemoteLogSegmentMetadata only as you can see in #write() method. + // Did not add a specific serde for RemoteLogSegmentMetadata and reusing RemoteLogMetadataSerde + final RemoteLogSegmentMetadataSnapshot remoteLogSegmentMetadata = + (RemoteLogSegmentMetadataSnapshot) serde.deserialize(data.array()); + result.add(remoteLogSegmentMetadata); + } + + if (metadataSnapshotsSize != result.size()) { + throw new IOException("Unexpected entries in the snapshot file. Expected size: " + metadataSnapshotsSize + + ", but found: " + result.size()); + } + + return Optional.of(new Snapshot(version, metadataPartition, metadataPartitionOffset, result)); + } + } + + /** + * This class represents the collection of remote log metadata for a specific topic partition. + */ + public static final class Snapshot { + private static final short CURRENT_VERSION = 0; + + private final short version; + private final int metadataPartition; + private final long metadataPartitionOffset; + private final Collection remoteLogSegmentMetadataSnapshots; + + public Snapshot(int metadataPartition, + long metadataPartitionOffset, + Collection remoteLogSegmentMetadataSnapshots) { + this(CURRENT_VERSION, metadataPartition, metadataPartitionOffset, remoteLogSegmentMetadataSnapshots); + } + + public Snapshot(short version, + int metadataPartition, + long metadataPartitionOffset, + Collection remoteLogSegmentMetadataSnapshots) { + // We will add multiple version support in future if needed. For now, the only supported version is CURRENT_VERSION viz 0. 
+ if (version != CURRENT_VERSION) { + throw new IllegalArgumentException("Unexpected version received: " + version); + } + this.version = version; + this.metadataPartition = metadataPartition; + this.metadataPartitionOffset = metadataPartitionOffset; + this.remoteLogSegmentMetadataSnapshots = remoteLogSegmentMetadataSnapshots; + } + + public short version() { + return version; + } + + public int metadataPartition() { + return metadataPartition; + } + + public long metadataPartitionOffset() { + return metadataPartitionOffset; + } + + public Collection remoteLogSegmentMetadataSnapshots() { + return remoteLogSegmentMetadataSnapshots; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (!(o instanceof Snapshot)) return false; + Snapshot snapshot = (Snapshot) o; + return version == snapshot.version && metadataPartition == snapshot.metadataPartition + && metadataPartitionOffset == snapshot.metadataPartitionOffset + && Objects.equals(remoteLogSegmentMetadataSnapshots, snapshot.remoteLogSegmentMetadataSnapshots); + } + + @Override + public int hashCode() { + return Objects.hash(version, metadataPartition, metadataPartitionOffset, remoteLogSegmentMetadataSnapshots); + } + + @Override + public String toString() { + return "Snapshot{" + + "version=" + version + + ", metadataPartition=" + metadataPartition + + ", metadataPartitionOffset=" + metadataPartitionOffset + + ", remoteLogSegmentMetadataSnapshotsSize" + remoteLogSegmentMetadataSnapshots.size() + + '}'; + } + } +} diff --git a/storage/src/main/java/org/apache/kafka/server/log/remote/metadata/storage/RemoteLogMetadataTopicPartitioner.java b/storage/src/main/java/org/apache/kafka/server/log/remote/metadata/storage/RemoteLogMetadataTopicPartitioner.java new file mode 100644 index 0000000..af12647 --- /dev/null +++ b/storage/src/main/java/org/apache/kafka/server/log/remote/metadata/storage/RemoteLogMetadataTopicPartitioner.java @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.server.log.remote.metadata.storage; + +import org.apache.kafka.common.TopicIdPartition; +import org.apache.kafka.common.utils.Utils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Objects; + +public class RemoteLogMetadataTopicPartitioner { + public static final Logger log = LoggerFactory.getLogger(RemoteLogMetadataTopicPartitioner.class); + private final int numMetadataTopicPartitions; + + public RemoteLogMetadataTopicPartitioner(int numMetadataTopicPartitions) { + this.numMetadataTopicPartitions = numMetadataTopicPartitions; + } + + public int metadataPartition(TopicIdPartition topicIdPartition) { + Objects.requireNonNull(topicIdPartition, "TopicPartition can not be null"); + + int partitionNum = Utils.toPositive(Utils.murmur2(toBytes(topicIdPartition))) % numMetadataTopicPartitions; + log.debug("No of partitions [{}], partitionNum: [{}] for given topic: [{}]", numMetadataTopicPartitions, partitionNum, topicIdPartition); + return partitionNum; + } + + private byte[] toBytes(TopicIdPartition topicIdPartition) { + // We do not want to depend upon hash code generation of Uuid as that may change. + int hash = Objects.hash(topicIdPartition.topicId().getLeastSignificantBits(), + topicIdPartition.topicId().getMostSignificantBits(), + topicIdPartition.partition()); + + return toBytes(hash); + } + + private byte[] toBytes(int n) { + return new byte[]{ + (byte) (n >> 24), + (byte) (n >> 16), + (byte) (n >> 8), + (byte) n + }; + } +} diff --git a/storage/src/main/java/org/apache/kafka/server/log/remote/metadata/storage/RemoteLogSegmentMetadataSnapshot.java b/storage/src/main/java/org/apache/kafka/server/log/remote/metadata/storage/RemoteLogSegmentMetadataSnapshot.java new file mode 100644 index 0000000..e7292c2 --- /dev/null +++ b/storage/src/main/java/org/apache/kafka/server/log/remote/metadata/storage/RemoteLogSegmentMetadataSnapshot.java @@ -0,0 +1,209 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.server.log.remote.metadata.storage; + +import org.apache.kafka.common.TopicIdPartition; +import org.apache.kafka.common.Uuid; +import org.apache.kafka.server.log.remote.storage.RemoteLogMetadata; +import org.apache.kafka.server.log.remote.storage.RemoteLogSegmentMetadata; +import org.apache.kafka.server.log.remote.storage.RemoteLogSegmentState; + +import java.util.Collections; +import java.util.Map; +import java.util.NavigableMap; +import java.util.Objects; +import java.util.TreeMap; + +/** + * This class represents the entry containing the metadata about a remote log segment. This is similar to + * {@link RemoteLogSegmentMetadata} but it does not contain topic partition information. This class keeps + * only remote log segment ID but not the topic partition. 
+ * + * This class is used in storing the snapshot of remote log metadata for a specific topic partition as mentioned + * in {@link RemoteLogMetadataSnapshotFile.Snapshot}. + */ +public class RemoteLogSegmentMetadataSnapshot extends RemoteLogMetadata { + + /** + * Universally unique remote log segment id. + */ + private final Uuid segmentId; + + /** + * Start offset of this segment. + */ + private final long startOffset; + + /** + * End offset of this segment. + */ + private final long endOffset; + + /** + * Maximum timestamp in milli seconds in the segment + */ + private final long maxTimestampMs; + + /** + * LeaderEpoch vs offset for messages within this segment. + */ + private final NavigableMap segmentLeaderEpochs; + + /** + * Size of the segment in bytes. + */ + private final int segmentSizeInBytes; + + /** + * It indicates the state in which the action is executed on this segment. + */ + private final RemoteLogSegmentState state; + + /** + * Creates an instance with the given metadata of remote log segment. + *

                + * {@code segmentLeaderEpochs} can not be empty. If all the records in this segment belong to the same leader epoch + * then it should have an entry with epoch mapping to start-offset of this segment. + * + * @param segmentId Universally unique remote log segment id. + * @param startOffset Start offset of this segment (inclusive). + * @param endOffset End offset of this segment (inclusive). + * @param maxTimestampMs Maximum timestamp in milli seconds in this segment. + * @param brokerId Broker id from which this event is generated. + * @param eventTimestampMs Epoch time in milli seconds at which the remote log segment is copied to the remote tier storage. + * @param segmentSizeInBytes Size of this segment in bytes. + * @param state State of the respective segment of remoteLogSegmentId. + * @param segmentLeaderEpochs leader epochs occurred within this segment. + */ + public RemoteLogSegmentMetadataSnapshot(Uuid segmentId, + long startOffset, + long endOffset, + long maxTimestampMs, + int brokerId, + long eventTimestampMs, + int segmentSizeInBytes, + RemoteLogSegmentState state, + Map segmentLeaderEpochs) { + super(brokerId, eventTimestampMs); + this.segmentId = Objects.requireNonNull(segmentId, "remoteLogSegmentId can not be null"); + this.state = Objects.requireNonNull(state, "state can not be null"); + + this.startOffset = startOffset; + this.endOffset = endOffset; + this.maxTimestampMs = maxTimestampMs; + this.segmentSizeInBytes = segmentSizeInBytes; + + if (segmentLeaderEpochs == null || segmentLeaderEpochs.isEmpty()) { + throw new IllegalArgumentException("segmentLeaderEpochs can not be null or empty"); + } + + this.segmentLeaderEpochs = Collections.unmodifiableNavigableMap(new TreeMap<>(segmentLeaderEpochs)); + } + + public static RemoteLogSegmentMetadataSnapshot create(RemoteLogSegmentMetadata metadata) { + return new RemoteLogSegmentMetadataSnapshot(metadata.remoteLogSegmentId().id(), metadata.startOffset(), metadata.endOffset(), + metadata.maxTimestampMs(), metadata.brokerId(), metadata.eventTimestampMs(), + metadata.segmentSizeInBytes(), metadata.state(), metadata.segmentLeaderEpochs()); + } + + /** + * @return unique id of this segment. + */ + public Uuid segmentId() { + return segmentId; + } + + /** + * @return Start offset of this segment (inclusive). + */ + public long startOffset() { + return startOffset; + } + + /** + * @return End offset of this segment (inclusive). + */ + public long endOffset() { + return endOffset; + } + + /** + * @return Total size of this segment in bytes. + */ + public int segmentSizeInBytes() { + return segmentSizeInBytes; + } + + /** + * @return Maximum timestamp in milli seconds of a record within this segment. + */ + public long maxTimestampMs() { + return maxTimestampMs; + } + + /** + * @return Map of leader epoch vs offset for the records available in this segment. + */ + public NavigableMap segmentLeaderEpochs() { + return segmentLeaderEpochs; + } + + /** + * Returns the current state of this remote log segment. It can be any of the below + *
+ * <ul>
+ *     <li>{@link RemoteLogSegmentState#COPY_SEGMENT_STARTED}
+ *     <li>{@link RemoteLogSegmentState#COPY_SEGMENT_FINISHED}
+ *     <li>{@link RemoteLogSegmentState#DELETE_SEGMENT_STARTED}
+ *     <li>{@link RemoteLogSegmentState#DELETE_SEGMENT_FINISHED}
+ * </ul>
                + */ + public RemoteLogSegmentState state() { + return state; + } + + @Override + public TopicIdPartition topicIdPartition() { + throw new UnsupportedOperationException("This metadata does not have topic partition with it."); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (!(o instanceof RemoteLogSegmentMetadataSnapshot)) return false; + RemoteLogSegmentMetadataSnapshot that = (RemoteLogSegmentMetadataSnapshot) o; + return startOffset == that.startOffset && endOffset == that.endOffset && maxTimestampMs == that.maxTimestampMs && segmentSizeInBytes == that.segmentSizeInBytes && Objects.equals( + segmentId, that.segmentId) && Objects.equals(segmentLeaderEpochs, that.segmentLeaderEpochs) && state == that.state; + } + + @Override + public int hashCode() { + return Objects.hash(segmentId, startOffset, endOffset, maxTimestampMs, segmentLeaderEpochs, segmentSizeInBytes, state); + } + + @Override + public String toString() { + return "RemoteLogSegmentMetadataSnapshot{" + + "segmentId=" + segmentId + + ", startOffset=" + startOffset + + ", endOffset=" + endOffset + + ", maxTimestampMs=" + maxTimestampMs + + ", segmentLeaderEpochs=" + segmentLeaderEpochs + + ", segmentSizeInBytes=" + segmentSizeInBytes + + ", state=" + state + + '}'; + } +} diff --git a/storage/src/main/java/org/apache/kafka/server/log/remote/metadata/storage/RemotePartitionMetadataEventHandler.java b/storage/src/main/java/org/apache/kafka/server/log/remote/metadata/storage/RemotePartitionMetadataEventHandler.java new file mode 100644 index 0000000..c92a51e --- /dev/null +++ b/storage/src/main/java/org/apache/kafka/server/log/remote/metadata/storage/RemotePartitionMetadataEventHandler.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.server.log.remote.metadata.storage; + +import org.apache.kafka.common.TopicIdPartition; +import org.apache.kafka.server.log.remote.storage.RemoteLogMetadata; +import org.apache.kafka.server.log.remote.storage.RemoteLogSegmentMetadata; +import org.apache.kafka.server.log.remote.storage.RemoteLogSegmentMetadataUpdate; +import org.apache.kafka.server.log.remote.storage.RemotePartitionDeleteMetadata; + +import java.io.IOException; + +public abstract class RemotePartitionMetadataEventHandler { + + public void handleRemoteLogMetadata(RemoteLogMetadata remoteLogMetadata) { + if (remoteLogMetadata instanceof RemoteLogSegmentMetadata) { + handleRemoteLogSegmentMetadata((RemoteLogSegmentMetadata) remoteLogMetadata); + } else if (remoteLogMetadata instanceof RemoteLogSegmentMetadataUpdate) { + handleRemoteLogSegmentMetadataUpdate((RemoteLogSegmentMetadataUpdate) remoteLogMetadata); + } else if (remoteLogMetadata instanceof RemotePartitionDeleteMetadata) { + handleRemotePartitionDeleteMetadata((RemotePartitionDeleteMetadata) remoteLogMetadata); + } else { + throw new IllegalArgumentException("remoteLogMetadata: " + remoteLogMetadata + " is not supported."); + } + } + + protected abstract void handleRemoteLogSegmentMetadata(RemoteLogSegmentMetadata remoteLogSegmentMetadata); + + protected abstract void handleRemoteLogSegmentMetadataUpdate(RemoteLogSegmentMetadataUpdate remoteLogSegmentMetadataUpdate); + + protected abstract void handleRemotePartitionDeleteMetadata(RemotePartitionDeleteMetadata remotePartitionDeleteMetadata); + + public abstract void syncLogMetadataSnapshot(TopicIdPartition topicIdPartition, + int metadataPartition, + Long metadataPartitionOffset) throws IOException; + + public abstract void clearTopicPartition(TopicIdPartition topicIdPartition); + +} \ No newline at end of file diff --git a/storage/src/main/java/org/apache/kafka/server/log/remote/metadata/storage/RemotePartitionMetadataStore.java b/storage/src/main/java/org/apache/kafka/server/log/remote/metadata/storage/RemotePartitionMetadataStore.java new file mode 100644 index 0000000..7051d18 --- /dev/null +++ b/storage/src/main/java/org/apache/kafka/server/log/remote/metadata/storage/RemotePartitionMetadataStore.java @@ -0,0 +1,188 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.server.log.remote.metadata.storage; + +import org.apache.kafka.common.TopicIdPartition; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.server.log.remote.storage.RemoteLogSegmentId; +import org.apache.kafka.server.log.remote.storage.RemoteLogSegmentMetadata; +import org.apache.kafka.server.log.remote.storage.RemoteLogSegmentMetadataUpdate; +import org.apache.kafka.server.log.remote.storage.RemotePartitionDeleteMetadata; +import org.apache.kafka.server.log.remote.storage.RemotePartitionDeleteState; +import org.apache.kafka.server.log.remote.storage.RemoteResourceNotFoundException; +import org.apache.kafka.server.log.remote.storage.RemoteStorageException; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.Closeable; +import java.io.File; +import java.io.IOException; +import java.nio.file.Path; +import java.util.Collections; +import java.util.Iterator; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.concurrent.ConcurrentHashMap; + +/** + * This class represents a store to maintain the {@link RemotePartitionDeleteMetadata} and {@link RemoteLogMetadataCache} for each topic partition. + */ +public class RemotePartitionMetadataStore extends RemotePartitionMetadataEventHandler implements Closeable { + private static final Logger log = LoggerFactory.getLogger(RemotePartitionMetadataStore.class); + + private final Path logDir; + + private Map idToPartitionDeleteMetadata = + new ConcurrentHashMap<>(); + + private Map idToRemoteLogMetadataCache = + new ConcurrentHashMap<>(); + + public RemotePartitionMetadataStore(Path logDir) { + this.logDir = logDir; + } + + @Override + public void handleRemoteLogSegmentMetadata(RemoteLogSegmentMetadata remoteLogSegmentMetadata) { + log.debug("Adding remote log segment : [{}]", remoteLogSegmentMetadata); + + final RemoteLogSegmentId remoteLogSegmentId = remoteLogSegmentMetadata.remoteLogSegmentId(); + TopicIdPartition topicIdPartition = remoteLogSegmentId.topicIdPartition(); + + // This should have been already existing as it is loaded when the partitions are assigned. 
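+        // The cache entry is created by maybeLoadPartition(), which TopicBasedRemoteLogMetadataManager invokes from
+        // assignPartitions() when the partition is assigned to this broker.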
+ RemoteLogMetadataCache remoteLogMetadataCache = idToRemoteLogMetadataCache.get(topicIdPartition); + if (remoteLogMetadataCache != null) { + remoteLogMetadataCache.addCopyInProgressSegment(remoteLogSegmentMetadata); + } else { + throw new IllegalStateException("No partition metadata found for : " + topicIdPartition); + } + } + + private Path partitionLogDirectory(TopicPartition topicPartition) { + return new File(logDir.toFile(), topicPartition.topic() + "-" + topicPartition.partition()).toPath(); + } + + @Override + public void handleRemoteLogSegmentMetadataUpdate(RemoteLogSegmentMetadataUpdate rlsmUpdate) { + log.debug("Updating remote log segment: [{}]", rlsmUpdate); + RemoteLogSegmentId remoteLogSegmentId = rlsmUpdate.remoteLogSegmentId(); + TopicIdPartition topicIdPartition = remoteLogSegmentId.topicIdPartition(); + RemoteLogMetadataCache remoteLogMetadataCache = idToRemoteLogMetadataCache.get(topicIdPartition); + if (remoteLogMetadataCache != null) { + try { + remoteLogMetadataCache.updateRemoteLogSegmentMetadata(rlsmUpdate); + } catch (RemoteResourceNotFoundException e) { + log.warn("Error occurred while updating the remote log segment.", e); + } + } else { + throw new IllegalStateException("No partition metadata found for : " + topicIdPartition); + } + } + + @Override + public void handleRemotePartitionDeleteMetadata(RemotePartitionDeleteMetadata remotePartitionDeleteMetadata) { + log.debug("Received partition delete state with: [{}]", remotePartitionDeleteMetadata); + + TopicIdPartition topicIdPartition = remotePartitionDeleteMetadata.topicIdPartition(); + idToPartitionDeleteMetadata.put(topicIdPartition, remotePartitionDeleteMetadata); + // there will be a trigger to receive delete partition marker and act on that to delete all the segments. + + if (remotePartitionDeleteMetadata.state() == RemotePartitionDeleteState.DELETE_PARTITION_FINISHED) { + // remove the association for the partition. 
+ idToRemoteLogMetadataCache.remove(topicIdPartition); + idToPartitionDeleteMetadata.remove(topicIdPartition); + } + } + + @Override + public void syncLogMetadataSnapshot(TopicIdPartition topicIdPartition, + int metadataPartition, + Long metadataPartitionOffset) throws IOException { + RemotePartitionDeleteMetadata partitionDeleteMetadata = idToPartitionDeleteMetadata.get(topicIdPartition); + if (partitionDeleteMetadata != null) { + log.info("Skipping syncing of metadata snapshot as remote partition [{}] is with state: [{}] ", topicIdPartition, + partitionDeleteMetadata); + } else { + FileBasedRemoteLogMetadataCache remoteLogMetadataCache = idToRemoteLogMetadataCache.get(topicIdPartition); + if (remoteLogMetadataCache != null) { + remoteLogMetadataCache.flushToFile(metadataPartition, metadataPartitionOffset); + } + } + } + + @Override + public void clearTopicPartition(TopicIdPartition topicIdPartition) { + idToRemoteLogMetadataCache.remove(topicIdPartition); + } + + public Iterator listRemoteLogSegments(TopicIdPartition topicIdPartition) + throws RemoteStorageException { + Objects.requireNonNull(topicIdPartition, "topicIdPartition can not be null"); + + return getRemoteLogMetadataCache(topicIdPartition).listAllRemoteLogSegments(); + } + + public Iterator listRemoteLogSegments(TopicIdPartition topicIdPartition, int leaderEpoch) + throws RemoteStorageException { + Objects.requireNonNull(topicIdPartition, "topicIdPartition can not be null"); + + return getRemoteLogMetadataCache(topicIdPartition).listRemoteLogSegments(leaderEpoch); + } + + private FileBasedRemoteLogMetadataCache getRemoteLogMetadataCache(TopicIdPartition topicIdPartition) + throws RemoteResourceNotFoundException { + FileBasedRemoteLogMetadataCache remoteLogMetadataCache = idToRemoteLogMetadataCache.get(topicIdPartition); + if (remoteLogMetadataCache == null) { + throw new RemoteResourceNotFoundException("No resource found for partition: " + topicIdPartition); + } + + return remoteLogMetadataCache; + } + + public Optional remoteLogSegmentMetadata(TopicIdPartition topicIdPartition, + long offset, + int epochForOffset) + throws RemoteStorageException { + Objects.requireNonNull(topicIdPartition, "topicIdPartition can not be null"); + + return getRemoteLogMetadataCache(topicIdPartition).remoteLogSegmentMetadata(epochForOffset, offset); + } + + public Optional highestLogOffset(TopicIdPartition topicIdPartition, + int leaderEpoch) throws RemoteStorageException { + Objects.requireNonNull(topicIdPartition, "topicIdPartition can not be null"); + + return getRemoteLogMetadataCache(topicIdPartition).highestOffsetForEpoch(leaderEpoch); + } + + @Override + public void close() throws IOException { + log.info("Clearing the entries from the store."); + + // Clear the entries by creating unmodifiable empty maps. + // Practically, we do not use the same instances that are closed. 
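+        // Collections.emptyMap() is immutable, so any accidental write after close() fails fast with an
+        // UnsupportedOperationException instead of silently mutating a stale map.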
+ idToPartitionDeleteMetadata = Collections.emptyMap(); + idToRemoteLogMetadataCache = Collections.emptyMap(); + } + + public void maybeLoadPartition(TopicIdPartition partition) { + idToRemoteLogMetadataCache.computeIfAbsent(partition, + topicIdPartition -> new FileBasedRemoteLogMetadataCache(topicIdPartition, partitionLogDirectory(topicIdPartition.topicPartition()))); + } + +} diff --git a/storage/src/main/java/org/apache/kafka/server/log/remote/metadata/storage/TopicBasedRemoteLogMetadataManager.java b/storage/src/main/java/org/apache/kafka/server/log/remote/metadata/storage/TopicBasedRemoteLogMetadataManager.java new file mode 100644 index 0000000..0271780 --- /dev/null +++ b/storage/src/main/java/org/apache/kafka/server/log/remote/metadata/storage/TopicBasedRemoteLogMetadataManager.java @@ -0,0 +1,539 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.server.log.remote.metadata.storage; + +import org.apache.kafka.clients.admin.AdminClient; +import org.apache.kafka.clients.admin.NewTopic; +import org.apache.kafka.clients.admin.TopicDescription; +import org.apache.kafka.clients.producer.RecordMetadata; +import org.apache.kafka.common.KafkaException; +import org.apache.kafka.common.TopicIdPartition; +import org.apache.kafka.common.config.TopicConfig; +import org.apache.kafka.common.errors.RetriableException; +import org.apache.kafka.common.errors.TopicExistsException; +import org.apache.kafka.common.internals.FatalExitError; +import org.apache.kafka.common.utils.KafkaThread; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.server.log.remote.storage.RemoteLogMetadata; +import org.apache.kafka.server.log.remote.storage.RemoteLogMetadataManager; +import org.apache.kafka.server.log.remote.storage.RemoteLogSegmentMetadata; +import org.apache.kafka.server.log.remote.storage.RemoteLogSegmentMetadataUpdate; +import org.apache.kafka.server.log.remote.storage.RemoteLogSegmentState; +import org.apache.kafka.server.log.remote.storage.RemotePartitionDeleteMetadata; +import org.apache.kafka.server.log.remote.storage.RemoteStorageException; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.File; +import java.io.IOException; +import java.time.Duration; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Iterator; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.Set; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeoutException; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.locks.ReentrantReadWriteLock; + +/** + * This is the 
{@link RemoteLogMetadataManager} implementation with storage as an internal topic with name {@link TopicBasedRemoteLogMetadataManagerConfig#REMOTE_LOG_METADATA_TOPIC_NAME}. + * This is used to publish and fetch {@link RemoteLogMetadata} for the registered user topic partitions with + * {@link #onPartitionLeadershipChanges(Set, Set)}. Each broker will have an instance of this class and it subscribes + * to metadata updates for the registered user topic partitions. + */ +public class TopicBasedRemoteLogMetadataManager implements RemoteLogMetadataManager { + private static final Logger log = LoggerFactory.getLogger(TopicBasedRemoteLogMetadataManager.class); + + private volatile boolean configured = false; + + // It indicates whether the close process of this instance is started or not via #close() method. + // Using AtomicBoolean instead of volatile as it may encounter http://findbugs.sourceforge.net/bugDescriptions.html#SP_SPIN_ON_FIELD + // if the field is read but not updated in a spin loop like in #initializeResources() method. + private final AtomicBoolean closing = new AtomicBoolean(false); + private final AtomicBoolean initialized = new AtomicBoolean(false); + private final Time time = Time.SYSTEM; + private final boolean startConsumerThread; + + private Thread initializationThread; + private volatile ProducerManager producerManager; + private volatile ConsumerManager consumerManager; + + // This allows to gracefully close this instance using {@link #close()} method while there are some pending or new + // requests calling different methods which use the resources like producer/consumer managers. + private final ReentrantReadWriteLock lock = new ReentrantReadWriteLock(); + + private RemotePartitionMetadataStore remotePartitionMetadataStore; + private volatile TopicBasedRemoteLogMetadataManagerConfig rlmmConfig; + private volatile RemoteLogMetadataTopicPartitioner rlmmTopicPartitioner; + private final Set pendingAssignPartitions = Collections.synchronizedSet(new HashSet<>()); + private volatile boolean initializationFailed; + + public TopicBasedRemoteLogMetadataManager() { + this(true); + } + + // Visible for testing. + public TopicBasedRemoteLogMetadataManager(boolean startConsumerThread) { + this.startConsumerThread = startConsumerThread; + } + + @Override + public CompletableFuture addRemoteLogSegmentMetadata(RemoteLogSegmentMetadata remoteLogSegmentMetadata) + throws RemoteStorageException { + Objects.requireNonNull(remoteLogSegmentMetadata, "remoteLogSegmentMetadata can not be null"); + + // This allows gracefully rejecting the requests while closing of this instance is in progress, which triggers + // closing the producer/consumer manager instances. + lock.readLock().lock(); + try { + ensureInitializedAndNotClosed(); + + // This method is allowed only to add remote log segment with the initial state(which is RemoteLogSegmentState.COPY_SEGMENT_STARTED) + // but not to update the existing remote log segment metadata. + if (remoteLogSegmentMetadata.state() != RemoteLogSegmentState.COPY_SEGMENT_STARTED) { + throw new IllegalArgumentException( + "Given remoteLogSegmentMetadata should have state as " + RemoteLogSegmentState.COPY_SEGMENT_STARTED + + " but it contains state as: " + remoteLogSegmentMetadata.state()); + } + + // Publish the message to the topic. 
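+            // The returned future completes only after the internal consumer has caught up with the offset of the
+            // produced metadata record (see storeRemoteLogMetadata below), so the metadata is locally readable once
+            // the future completes.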
+ return storeRemoteLogMetadata(remoteLogSegmentMetadata.remoteLogSegmentId().topicIdPartition(), + remoteLogSegmentMetadata); + } finally { + lock.readLock().unlock(); + } + } + + @Override + public CompletableFuture updateRemoteLogSegmentMetadata(RemoteLogSegmentMetadataUpdate segmentMetadataUpdate) + throws RemoteStorageException { + Objects.requireNonNull(segmentMetadataUpdate, "segmentMetadataUpdate can not be null"); + + lock.readLock().lock(); + try { + ensureInitializedAndNotClosed(); + + // Callers should use addRemoteLogSegmentMetadata to add RemoteLogSegmentMetadata with state as + // RemoteLogSegmentState.COPY_SEGMENT_STARTED. + if (segmentMetadataUpdate.state() == RemoteLogSegmentState.COPY_SEGMENT_STARTED) { + throw new IllegalArgumentException("Given remoteLogSegmentMetadata should not have the state as: " + + RemoteLogSegmentState.COPY_SEGMENT_STARTED); + } + + // Publish the message to the topic. + return storeRemoteLogMetadata(segmentMetadataUpdate.remoteLogSegmentId().topicIdPartition(), segmentMetadataUpdate); + } finally { + lock.readLock().unlock(); + } + } + + @Override + public CompletableFuture putRemotePartitionDeleteMetadata(RemotePartitionDeleteMetadata remotePartitionDeleteMetadata) + throws RemoteStorageException { + Objects.requireNonNull(remotePartitionDeleteMetadata, "remotePartitionDeleteMetadata can not be null"); + + lock.readLock().lock(); + try { + ensureInitializedAndNotClosed(); + + return storeRemoteLogMetadata(remotePartitionDeleteMetadata.topicIdPartition(), remotePartitionDeleteMetadata); + } finally { + lock.readLock().unlock(); + } + } + + /** + * Returns {@link CompletableFuture} which will complete only after publishing of the given {@code remoteLogMetadata} into + * the remote log metadata topic and the internal consumer is caught up until the produced record's offset. + * + * @param topicIdPartition partition of the given remoteLogMetadata. + * @param remoteLogMetadata RemoteLogMetadata to be stored. + * @return + * @throws RemoteStorageException if there are any storage errors occur. + */ + private CompletableFuture storeRemoteLogMetadata(TopicIdPartition topicIdPartition, + RemoteLogMetadata remoteLogMetadata) + throws RemoteStorageException { + log.debug("Storing metadata for partition: [{}] with context: [{}]", topicIdPartition, remoteLogMetadata); + + try { + // Publish the message to the metadata topic. + CompletableFuture produceFuture = producerManager.publishMessage(remoteLogMetadata); + + // Create and return a `CompletableFuture` instance which completes when the consumer is caught up with the produced record's offset. 
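+            // thenApplyAsync keeps the wait off the producer callback thread; the wait itself is bounded by the
+            // configured remote.log.metadata.consume.wait.ms, and a TimeoutException is surfaced as a KafkaException.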
+ return produceFuture.thenApplyAsync(recordMetadata -> { + try { + consumerManager.waitTillConsumptionCatchesUp(recordMetadata); + } catch (TimeoutException e) { + throw new KafkaException(e); + } + return null; + }); + } catch (KafkaException e) { + if (e instanceof RetriableException) { + throw e; + } else { + throw new RemoteStorageException(e); + } + } + } + + @Override + public Optional remoteLogSegmentMetadata(TopicIdPartition topicIdPartition, + int epochForOffset, + long offset) + throws RemoteStorageException { + lock.readLock().lock(); + try { + ensureInitializedAndNotClosed(); + + return remotePartitionMetadataStore.remoteLogSegmentMetadata(topicIdPartition, offset, epochForOffset); + } finally { + lock.readLock().unlock(); + } + } + + @Override + public Optional highestOffsetForEpoch(TopicIdPartition topicIdPartition, + int leaderEpoch) + throws RemoteStorageException { + lock.readLock().lock(); + try { + + ensureInitializedAndNotClosed(); + + return remotePartitionMetadataStore.highestLogOffset(topicIdPartition, leaderEpoch); + } finally { + lock.readLock().unlock(); + } + + } + + @Override + public Iterator listRemoteLogSegments(TopicIdPartition topicIdPartition) + throws RemoteStorageException { + Objects.requireNonNull(topicIdPartition, "topicIdPartition can not be null"); + + lock.readLock().lock(); + try { + ensureInitializedAndNotClosed(); + + return remotePartitionMetadataStore.listRemoteLogSegments(topicIdPartition); + } finally { + lock.readLock().unlock(); + } + } + + @Override + public Iterator listRemoteLogSegments(TopicIdPartition topicIdPartition, int leaderEpoch) + throws RemoteStorageException { + Objects.requireNonNull(topicIdPartition, "topicIdPartition can not be null"); + + lock.readLock().lock(); + try { + ensureInitializedAndNotClosed(); + + return remotePartitionMetadataStore.listRemoteLogSegments(topicIdPartition, leaderEpoch); + } finally { + lock.readLock().unlock(); + } + } + + public int metadataPartition(TopicIdPartition topicIdPartition) { + return rlmmTopicPartitioner.metadataPartition(topicIdPartition); + } + + // Visible For Testing + public Optional receivedOffsetForPartition(int metadataPartition) { + return consumerManager.receivedOffsetForPartition(metadataPartition); + } + + @Override + public void onPartitionLeadershipChanges(Set leaderPartitions, + Set followerPartitions) { + Objects.requireNonNull(leaderPartitions, "leaderPartitions can not be null"); + Objects.requireNonNull(followerPartitions, "followerPartitions can not be null"); + + log.info("Received leadership notifications with leader partitions {} and follower partitions {}", + leaderPartitions, followerPartitions); + + lock.readLock().lock(); + try { + if (closing.get()) { + throw new IllegalStateException("This instance is in closing state"); + } + + HashSet allPartitions = new HashSet<>(leaderPartitions); + allPartitions.addAll(followerPartitions); + if (!initialized.get()) { + // If it is not yet initialized, then keep them as pending partitions and assign them + // when it is initialized successfully in initializeResources(). 
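+                // The read lock held here and the write lock taken in initializeResources() keep this hand-off of
+                // pending partitions mutually exclusive with the initialization path that drains them.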
+ this.pendingAssignPartitions.addAll(allPartitions); + } else { + assignPartitions(allPartitions); + } + } finally { + lock.readLock().unlock(); + } + } + + private void assignPartitions(Set allPartitions) { + for (TopicIdPartition partition : allPartitions) { + remotePartitionMetadataStore.maybeLoadPartition(partition); + } + + consumerManager.addAssignmentsForPartitions(allPartitions); + } + + @Override + public void onStopPartitions(Set partitions) { + lock.readLock().lock(); + try { + if (closing.get()) { + throw new IllegalStateException("This instance is in closing state"); + } + + if (!initialized.get()) { + // If it is not yet initialized, then remove them from the pending partitions if any. + if (!pendingAssignPartitions.isEmpty()) { + pendingAssignPartitions.removeAll(partitions); + } + } else { + consumerManager.removeAssignmentsForPartitions(partitions); + } + } finally { + lock.readLock().unlock(); + } + } + + @Override + public void configure(Map configs) { + Objects.requireNonNull(configs, "configs can not be null."); + + lock.writeLock().lock(); + try { + if (configured) { + log.info("Skipping configure as it is already configured."); + return; + } + + log.info("Started initializing with configs: {}", configs); + + rlmmConfig = new TopicBasedRemoteLogMetadataManagerConfig(configs); + rlmmTopicPartitioner = new RemoteLogMetadataTopicPartitioner(rlmmConfig.metadataTopicPartitionsCount()); + remotePartitionMetadataStore = new RemotePartitionMetadataStore(new File(rlmmConfig.logDir()).toPath()); + configured = true; + log.info("Successfully initialized with rlmmConfig: {}", rlmmConfig); + + // Scheduling the initialization producer/consumer managers in a separate thread. Required resources may + // not yet be available now. This thread makes sure that it is retried at regular intervals until it is + // successful. + initializationThread = KafkaThread.nonDaemon("RLMMInitializationThread", () -> initializeResources()); + initializationThread.start(); + } finally { + lock.writeLock().unlock(); + } + } + + private void initializeResources() { + log.info("Initializing the resources."); + final NewTopic remoteLogMetadataTopicRequest = createRemoteLogMetadataTopicRequest(); + boolean topicCreated = false; + long startTimeMs = time.milliseconds(); + AdminClient adminClient = null; + try { + adminClient = AdminClient.create(rlmmConfig.producerProperties()); + + // Stop if it is already initialized or closing. + while (!(initialized.get() || closing.get())) { + + // If it is timed out then raise an error to exit. + if (time.milliseconds() - startTimeMs > rlmmConfig.initializationRetryMaxTimeoutMs()) { + log.error("Timed out in initializing the resources, retried to initialize the resource for [{}] ms.", + rlmmConfig.initializationRetryMaxTimeoutMs()); + initializationFailed = true; + return; + } + + if (!topicCreated) { + topicCreated = createTopic(adminClient, remoteLogMetadataTopicRequest); + } + + if (!topicCreated) { + // Sleep for INITIALIZATION_RETRY_INTERVAL_MS before trying to create the topic again. + log.info("Sleep for : {} ms before it is retried again.", rlmmConfig.initializationRetryIntervalMs()); + Utils.sleep(rlmmConfig.initializationRetryIntervalMs()); + continue; + } else { + // If topic is already created, validate the existing topic partitions. + try { + String topicName = remoteLogMetadataTopicRequest.name(); + // If the existing topic partition size is not same as configured, mark initialization as failed and exit. 
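+                        // A mismatch is treated as fatal because RemoteLogMetadataTopicPartitioner is built with the
+                        // configured partition count; if the live topic had a different count, metadata could be
+                        // routed to unexpected metadata partitions.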
+ if (!isPartitionsCountSameAsConfigured(adminClient, topicName)) { + initializationFailed = true; + } + } catch (Exception e) { + log.info("Sleep for : {} ms before it is retried again.", rlmmConfig.initializationRetryIntervalMs()); + Utils.sleep(rlmmConfig.initializationRetryIntervalMs()); + continue; + } + } + + // Create producer and consumer managers. + lock.writeLock().lock(); + try { + producerManager = new ProducerManager(rlmmConfig, rlmmTopicPartitioner); + consumerManager = new ConsumerManager(rlmmConfig, remotePartitionMetadataStore, rlmmTopicPartitioner, time); + if (startConsumerThread) { + consumerManager.startConsumerThread(); + } else { + log.info("RLMM Consumer task thread is not configured to be started."); + } + + if (!pendingAssignPartitions.isEmpty()) { + assignPartitions(pendingAssignPartitions); + pendingAssignPartitions.clear(); + } + + initialized.set(true); + log.info("Initialized resources successfully."); + } catch (Exception e) { + log.error("Encountered error while initializing producer/consumer", e); + return; + } finally { + lock.writeLock().unlock(); + } + } + + } finally { + if (adminClient != null) { + try { + adminClient.close(Duration.ofSeconds(10)); + } catch (Exception e) { + // Ignore the error. + log.debug("Error occurred while closing the admin client", e); + } + } + } + } + + private boolean isPartitionsCountSameAsConfigured(AdminClient adminClient, + String topicName) throws InterruptedException, ExecutionException { + log.debug("Getting topic details to check for partition count and replication factor."); + TopicDescription topicDescription = adminClient.describeTopics(Collections.singleton(topicName)) + .topicNameValues().get(topicName).get(); + int expectedPartitions = rlmmConfig.metadataTopicPartitionsCount(); + int topicPartitionsSize = topicDescription.partitions().size(); + + if (topicPartitionsSize != expectedPartitions) { + log.error("Existing topic partition count [{}] is not same as the expected partition count [{}]", + topicPartitionsSize, expectedPartitions); + return false; + } + + return true; + } + + private NewTopic createRemoteLogMetadataTopicRequest() { + Map topicConfigs = new HashMap<>(); + topicConfigs.put(TopicConfig.RETENTION_MS_CONFIG, Long.toString(rlmmConfig.metadataTopicRetentionMs())); + topicConfigs.put(TopicConfig.CLEANUP_POLICY_CONFIG, TopicConfig.CLEANUP_POLICY_DELETE); + return new NewTopic(rlmmConfig.remoteLogMetadataTopicName(), + rlmmConfig.metadataTopicPartitionsCount(), + rlmmConfig.metadataTopicReplicationFactor()).configs(topicConfigs); + } + + /** + * @param topic topic to be created. + * @return Returns true if the topic already exists or it is created successfully. + */ + private boolean createTopic(AdminClient adminClient, NewTopic topic) { + boolean topicCreated = false; + try { + adminClient.createTopics(Collections.singleton(topic)).all().get(); + topicCreated = true; + } catch (Exception e) { + if (e.getCause() instanceof TopicExistsException) { + log.info("Topic [{}] already exists", topic.name()); + topicCreated = true; + } else { + log.error("Encountered error while creating remote log metadata topic.", e); + } + } + + return topicCreated; + } + + public boolean isInitialized() { + return initialized.get(); + } + + private void ensureInitializedAndNotClosed() { + if (initializationFailed) { + // If initialization is failed, shutdown the broker. 
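+            // FatalExitError propagates up and eventually terminates the broker process, which matches the
+            // documented behaviour of remote.log.metadata.initialization.retry.max.timeout.ms.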
+ throw new FatalExitError(); + } + if (closing.get() || !initialized.get()) { + throw new IllegalStateException("This instance is in invalid state, initialized: " + initialized + + " close: " + closing); + } + } + + // Visible for testing. + public TopicBasedRemoteLogMetadataManagerConfig config() { + return rlmmConfig; + } + + // Visible for testing. + public void startConsumerThread() { + if (consumerManager != null) { + consumerManager.startConsumerThread(); + } + } + + @Override + public void close() throws IOException { + // Close all the resources. + log.info("Closing the resources."); + if (closing.compareAndSet(false, true)) { + lock.writeLock().lock(); + try { + if (initializationThread != null) { + try { + initializationThread.join(); + } catch (InterruptedException e) { + log.error("Initialization thread was interrupted while waiting to join on close.", e); + } + } + + Utils.closeQuietly(producerManager, "ProducerTask"); + Utils.closeQuietly(consumerManager, "RLMMConsumerManager"); + Utils.closeQuietly(remotePartitionMetadataStore, "RemotePartitionMetadataStore"); + } finally { + lock.writeLock().unlock(); + log.info("Closed the resources."); + } + } + } +} diff --git a/storage/src/main/java/org/apache/kafka/server/log/remote/metadata/storage/TopicBasedRemoteLogMetadataManagerConfig.java b/storage/src/main/java/org/apache/kafka/server/log/remote/metadata/storage/TopicBasedRemoteLogMetadataManagerConfig.java new file mode 100644 index 0000000..7e52519 --- /dev/null +++ b/storage/src/main/java/org/apache/kafka/server/log/remote/metadata/storage/TopicBasedRemoteLogMetadataManagerConfig.java @@ -0,0 +1,239 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.server.log.remote.metadata.storage; + +import org.apache.kafka.clients.CommonClientConfigs; +import org.apache.kafka.clients.consumer.ConsumerConfig; +import org.apache.kafka.clients.producer.ProducerConfig; +import org.apache.kafka.common.config.ConfigDef; +import org.apache.kafka.common.serialization.ByteArrayDeserializer; +import org.apache.kafka.common.serialization.ByteArraySerializer; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; + +import static org.apache.kafka.common.config.ConfigDef.Importance.LOW; +import static org.apache.kafka.common.config.ConfigDef.Range.atLeast; +import static org.apache.kafka.common.config.ConfigDef.Type.INT; +import static org.apache.kafka.common.config.ConfigDef.Type.LONG; +import static org.apache.kafka.common.config.ConfigDef.Type.SHORT; + +/** + * This class defines the configuration of topic based {@link org.apache.kafka.server.log.remote.storage.RemoteLogMetadataManager} implementation. 
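+ * <p>
+ * For illustration, a broker relying mostly on the defaults below could be configured with properties such as
+ * (the bootstrap servers and log directory values are placeholders):
+ * <pre>
+ * remote.log.metadata.topic.num.partitions=50
+ * remote.log.metadata.topic.replication.factor=3
+ * remote.log.metadata.topic.retention.ms=-1
+ * remote.log.metadata.consume.wait.ms=120000
+ * remote.log.metadata.common.client.bootstrap.servers=localhost:9092
+ * log.dir=/tmp/kafka-logs
+ * </pre>
+ * Keys carrying the {@code remote.log.metadata.common.client.}, {@code remote.log.metadata.producer.} and
+ * {@code remote.log.metadata.consumer.} prefixes are stripped of the prefix and handed to the underlying producer
+ * and consumer clients.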
+ */ +public final class TopicBasedRemoteLogMetadataManagerConfig { + + public static final String REMOTE_LOG_METADATA_TOPIC_NAME = "__remote_log_metadata"; + + public static final String REMOTE_LOG_METADATA_TOPIC_REPLICATION_FACTOR_PROP = "remote.log.metadata.topic.replication.factor"; + public static final String REMOTE_LOG_METADATA_TOPIC_PARTITIONS_PROP = "remote.log.metadata.topic.num.partitions"; + public static final String REMOTE_LOG_METADATA_TOPIC_RETENTION_MS_PROP = "remote.log.metadata.topic.retention.ms"; + public static final String REMOTE_LOG_METADATA_CONSUME_WAIT_MS_PROP = "remote.log.metadata.consume.wait.ms"; + public static final String REMOTE_LOG_METADATA_INITIALIZATION_RETRY_MAX_TIMEOUT_MS_PROP = "remote.log.metadata.initialization.retry.max.timeout.ms"; + public static final String REMOTE_LOG_METADATA_INITIALIZATION_RETRY_INTERVAL_MS_PROP = "remote.log.metadata.initialization.retry.interval.ms"; + + public static final int DEFAULT_REMOTE_LOG_METADATA_TOPIC_PARTITIONS = 50; + public static final long DEFAULT_REMOTE_LOG_METADATA_TOPIC_RETENTION_MILLIS = -1L; + public static final short DEFAULT_REMOTE_LOG_METADATA_TOPIC_REPLICATION_FACTOR = 3; + public static final long DEFAULT_REMOTE_LOG_METADATA_CONSUME_WAIT_MS = 2 * 60 * 1000L; + public static final long DEFAULT_REMOTE_LOG_METADATA_INITIALIZATION_RETRY_MAX_TIMEOUT_MS = 2 * 60 * 1000L; + public static final long DEFAULT_REMOTE_LOG_METADATA_INITIALIZATION_RETRY_INTERVAL_MS = 5 * 1000L; + + public static final String REMOTE_LOG_METADATA_TOPIC_REPLICATION_FACTOR_DOC = "Replication factor of remote log metadata Topic."; + public static final String REMOTE_LOG_METADATA_TOPIC_PARTITIONS_DOC = "The number of partitions for remote log metadata Topic."; + public static final String REMOTE_LOG_METADATA_TOPIC_RETENTION_MS_DOC = "Remote log metadata topic log retention in milli seconds." + + "Default: -1, that means unlimited. Users can configure this value based on their use cases. " + + "To avoid any data loss, this value should be more than the maximum retention period of any topic enabled with " + + "tiered storage in the cluster."; + public static final String REMOTE_LOG_METADATA_CONSUME_WAIT_MS_DOC = "The amount of time in milli seconds to wait for the local consumer to " + + "receive the published event."; + public static final String REMOTE_LOG_METADATA_INITIALIZATION_RETRY_INTERVAL_MS_DOC = "The retry interval in milli seconds for " + + " retrying RemoteLogMetadataManager resources initialization again."; + + public static final String REMOTE_LOG_METADATA_INITIALIZATION_RETRY_MAX_TIMEOUT_MS_DOC = "The maximum amount of time in milli seconds " + + " for retrying RemoteLogMetadataManager resources initialization. 
When total retry intervals reach this timeout, initialization" + + " is considered as failed and broker starts shutting down."; + + public static final String REMOTE_LOG_METADATA_COMMON_CLIENT_PREFIX = "remote.log.metadata.common.client."; + public static final String REMOTE_LOG_METADATA_PRODUCER_PREFIX = "remote.log.metadata.producer."; + public static final String REMOTE_LOG_METADATA_CONSUMER_PREFIX = "remote.log.metadata.consumer."; + public static final String BROKER_ID = "broker.id"; + public static final String LOG_DIR = "log.dir"; + + private static final String REMOTE_LOG_METADATA_CLIENT_PREFIX = "__remote_log_metadata_client"; + + private static final ConfigDef CONFIG = new ConfigDef(); + static { + CONFIG.define(REMOTE_LOG_METADATA_TOPIC_REPLICATION_FACTOR_PROP, SHORT, DEFAULT_REMOTE_LOG_METADATA_TOPIC_REPLICATION_FACTOR, atLeast(1), LOW, + REMOTE_LOG_METADATA_TOPIC_REPLICATION_FACTOR_DOC) + .define(REMOTE_LOG_METADATA_TOPIC_PARTITIONS_PROP, INT, DEFAULT_REMOTE_LOG_METADATA_TOPIC_PARTITIONS, atLeast(1), LOW, + REMOTE_LOG_METADATA_TOPIC_PARTITIONS_DOC) + .define(REMOTE_LOG_METADATA_TOPIC_RETENTION_MS_PROP, LONG, DEFAULT_REMOTE_LOG_METADATA_TOPIC_RETENTION_MILLIS, LOW, + REMOTE_LOG_METADATA_TOPIC_RETENTION_MS_DOC) + .define(REMOTE_LOG_METADATA_CONSUME_WAIT_MS_PROP, LONG, DEFAULT_REMOTE_LOG_METADATA_CONSUME_WAIT_MS, atLeast(0), LOW, + REMOTE_LOG_METADATA_CONSUME_WAIT_MS_DOC) + .define(REMOTE_LOG_METADATA_INITIALIZATION_RETRY_MAX_TIMEOUT_MS_PROP, LONG, + DEFAULT_REMOTE_LOG_METADATA_INITIALIZATION_RETRY_MAX_TIMEOUT_MS, atLeast(0), LOW, + REMOTE_LOG_METADATA_INITIALIZATION_RETRY_MAX_TIMEOUT_MS_DOC) + .define(REMOTE_LOG_METADATA_INITIALIZATION_RETRY_INTERVAL_MS_PROP, LONG, + DEFAULT_REMOTE_LOG_METADATA_INITIALIZATION_RETRY_INTERVAL_MS, atLeast(0), LOW, + REMOTE_LOG_METADATA_INITIALIZATION_RETRY_INTERVAL_MS_DOC); + } + + private final String clientIdPrefix; + private final int metadataTopicPartitionsCount; + private final String logDir; + private final long consumeWaitMs; + private final long metadataTopicRetentionMs; + private final short metadataTopicReplicationFactor; + private final long initializationRetryMaxTimeoutMs; + private final long initializationRetryIntervalMs; + + private Map consumerProps; + private Map producerProps; + + public TopicBasedRemoteLogMetadataManagerConfig(Map props) { + Objects.requireNonNull(props, "props can not be null"); + + Map parsedConfigs = CONFIG.parse(props); + + logDir = (String) props.get(LOG_DIR); + if (logDir == null || logDir.isEmpty()) { + throw new IllegalArgumentException(LOG_DIR + " config must not be null or empty."); + } + + metadataTopicPartitionsCount = (int) parsedConfigs.get(REMOTE_LOG_METADATA_TOPIC_PARTITIONS_PROP); + metadataTopicReplicationFactor = (short) parsedConfigs.get(REMOTE_LOG_METADATA_TOPIC_REPLICATION_FACTOR_PROP); + metadataTopicRetentionMs = (long) parsedConfigs.get(REMOTE_LOG_METADATA_TOPIC_RETENTION_MS_PROP); + if (metadataTopicRetentionMs != -1 && metadataTopicRetentionMs <= 0) { + throw new IllegalArgumentException("Invalid metadata topic retention in millis: " + metadataTopicRetentionMs); + } + consumeWaitMs = (long) parsedConfigs.get(REMOTE_LOG_METADATA_CONSUME_WAIT_MS_PROP); + initializationRetryIntervalMs = (long) parsedConfigs.get(REMOTE_LOG_METADATA_INITIALIZATION_RETRY_INTERVAL_MS_PROP); + initializationRetryMaxTimeoutMs = (long) parsedConfigs.get(REMOTE_LOG_METADATA_INITIALIZATION_RETRY_MAX_TIMEOUT_MS_PROP); + + clientIdPrefix = REMOTE_LOG_METADATA_CLIENT_PREFIX + "_" + props.get(BROKER_ID); + + 
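+        // For example, broker.id=0 yields the client ids "__remote_log_metadata_client_0_producer" and
+        // "__remote_log_metadata_client_0_consumer" (see createProducerProps and createConsumerProps below).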
initializeProducerConsumerProperties(props); + } + + private void initializeProducerConsumerProperties(Map configs) { + Map commonClientConfigs = new HashMap<>(); + Map producerOnlyConfigs = new HashMap<>(); + Map consumerOnlyConfigs = new HashMap<>(); + + for (Map.Entry entry : configs.entrySet()) { + String key = entry.getKey(); + if (key.startsWith(REMOTE_LOG_METADATA_COMMON_CLIENT_PREFIX)) { + commonClientConfigs.put(key.substring(REMOTE_LOG_METADATA_COMMON_CLIENT_PREFIX.length()), entry.getValue()); + } else if (key.startsWith(REMOTE_LOG_METADATA_PRODUCER_PREFIX)) { + producerOnlyConfigs.put(key.substring(REMOTE_LOG_METADATA_PRODUCER_PREFIX.length()), entry.getValue()); + } else if (key.startsWith(REMOTE_LOG_METADATA_CONSUMER_PREFIX)) { + consumerOnlyConfigs.put(key.substring(REMOTE_LOG_METADATA_CONSUMER_PREFIX.length()), entry.getValue()); + } + } + + HashMap allProducerConfigs = new HashMap<>(commonClientConfigs); + allProducerConfigs.putAll(producerOnlyConfigs); + producerProps = createProducerProps(allProducerConfigs); + + HashMap allConsumerConfigs = new HashMap<>(commonClientConfigs); + allConsumerConfigs.putAll(consumerOnlyConfigs); + consumerProps = createConsumerProps(allConsumerConfigs); + } + + public String remoteLogMetadataTopicName() { + return REMOTE_LOG_METADATA_TOPIC_NAME; + } + + public int metadataTopicPartitionsCount() { + return metadataTopicPartitionsCount; + } + + public short metadataTopicReplicationFactor() { + return metadataTopicReplicationFactor; + } + + public long metadataTopicRetentionMs() { + return metadataTopicRetentionMs; + } + + public long consumeWaitMs() { + return consumeWaitMs; + } + + public long initializationRetryMaxTimeoutMs() { + return initializationRetryMaxTimeoutMs; + } + + public long initializationRetryIntervalMs() { + return initializationRetryIntervalMs; + } + + public String logDir() { + return logDir; + } + + public Map consumerProperties() { + return consumerProps; + } + + public Map producerProperties() { + return producerProps; + } + + private Map createConsumerProps(HashMap allConsumerConfigs) { + Map props = new HashMap<>(allConsumerConfigs); + + props.put(CommonClientConfigs.CLIENT_ID_CONFIG, clientIdPrefix + "_consumer"); + props.put(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG, false); + props.put(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "earliest"); + props.put(ConsumerConfig.EXCLUDE_INTERNAL_TOPICS_CONFIG, false); + props.put(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, ByteArrayDeserializer.class.getName()); + props.put(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, ByteArrayDeserializer.class.getName()); + return props; + } + + private Map createProducerProps(HashMap allProducerConfigs) { + Map props = new HashMap<>(allProducerConfigs); + + props.put(ProducerConfig.CLIENT_ID_CONFIG, clientIdPrefix + "_producer"); + props.put(ProducerConfig.ACKS_CONFIG, "all"); + props.put(ProducerConfig.MAX_IN_FLIGHT_REQUESTS_PER_CONNECTION, 1); + props.put(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG, ByteArraySerializer.class.getName()); + props.put(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG, ByteArraySerializer.class.getName()); + + return Collections.unmodifiableMap(props); + } + + @Override + public String toString() { + return "TopicBasedRemoteLogMetadataManagerConfig{" + + "clientIdPrefix='" + clientIdPrefix + '\'' + + ", metadataTopicPartitionsCount=" + metadataTopicPartitionsCount + + ", consumeWaitMs=" + consumeWaitMs + + ", metadataTopicRetentionMs=" + metadataTopicRetentionMs + + ", metadataTopicReplicationFactor=" + 
metadataTopicReplicationFactor + + ", initializationRetryMaxTimeoutMs=" + initializationRetryMaxTimeoutMs + + ", initializationRetryIntervalMs=" + initializationRetryIntervalMs + + ", consumerProps=" + consumerProps + + ", producerProps=" + producerProps + + '}'; + } +} diff --git a/storage/src/main/java/org/apache/kafka/server/log/remote/metadata/storage/serialization/RemoteLogMetadataSerde.java b/storage/src/main/java/org/apache/kafka/server/log/remote/metadata/storage/serialization/RemoteLogMetadataSerde.java new file mode 100644 index 0000000..4a63b56 --- /dev/null +++ b/storage/src/main/java/org/apache/kafka/server/log/remote/metadata/storage/serialization/RemoteLogMetadataSerde.java @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.server.log.remote.metadata.storage.serialization; + +import org.apache.kafka.common.protocol.ApiMessage; +import org.apache.kafka.server.common.ApiMessageAndVersion; +import org.apache.kafka.server.common.serialization.BytesApiMessageSerde; +import org.apache.kafka.server.log.remote.metadata.storage.RemoteLogSegmentMetadataSnapshot; +import org.apache.kafka.server.log.remote.metadata.storage.generated.MetadataRecordType; +import org.apache.kafka.server.log.remote.metadata.storage.generated.RemoteLogSegmentMetadataRecord; +import org.apache.kafka.server.log.remote.metadata.storage.generated.RemoteLogSegmentMetadataSnapshotRecord; +import org.apache.kafka.server.log.remote.metadata.storage.generated.RemoteLogSegmentMetadataUpdateRecord; +import org.apache.kafka.server.log.remote.metadata.storage.generated.RemotePartitionDeleteMetadataRecord; +import org.apache.kafka.server.log.remote.storage.RemoteLogMetadata; +import org.apache.kafka.server.log.remote.storage.RemoteLogSegmentMetadata; +import org.apache.kafka.server.log.remote.storage.RemoteLogSegmentMetadataUpdate; +import org.apache.kafka.server.log.remote.storage.RemotePartitionDeleteMetadata; + +import java.util.HashMap; +import java.util.Map; + +/** + * This class provides serialization and deserialization for {@link RemoteLogMetadata}. This is the root serde + * for the messages that are stored in internal remote log metadata topic. 
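+ * <p>
+ * An illustrative round trip, where {@code segmentMetadata} is an existing {@link RemoteLogSegmentMetadata} instance:
+ * <pre>
+ * RemoteLogMetadataSerde serde = new RemoteLogMetadataSerde();
+ * byte[] payload = serde.serialize(segmentMetadata);
+ * RemoteLogMetadata deserialized = serde.deserialize(payload);
+ * </pre>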
+ */ +public class RemoteLogMetadataSerde { + private static final short REMOTE_LOG_SEGMENT_METADATA_API_KEY = new RemoteLogSegmentMetadataRecord().apiKey(); + private static final short REMOTE_LOG_SEGMENT_METADATA_UPDATE_API_KEY = new RemoteLogSegmentMetadataUpdateRecord().apiKey(); + private static final short REMOTE_PARTITION_DELETE_API_KEY = new RemotePartitionDeleteMetadataRecord().apiKey(); + private static final short REMOTE_LOG_SEGMENT_METADATA_SNAPSHOT_API_KEY = new RemoteLogSegmentMetadataSnapshotRecord().apiKey(); + + private final Map remoteLogStorageClassToApiKey; + private final Map keyToTransform; + private final BytesApiMessageSerde bytesApiMessageSerde; + + public RemoteLogMetadataSerde() { + remoteLogStorageClassToApiKey = createRemoteLogStorageClassToApiKeyMap(); + keyToTransform = createRemoteLogMetadataTransforms(); + bytesApiMessageSerde = new BytesApiMessageSerde() { + @Override + public ApiMessage apiMessageFor(short apiKey) { + return newApiMessage(apiKey); + } + }; + } + + protected ApiMessage newApiMessage(short apiKey) { + return MetadataRecordType.fromId(apiKey).newMetadataRecord(); + } + + protected Map createRemoteLogMetadataTransforms() { + Map map = new HashMap<>(); + map.put(REMOTE_LOG_SEGMENT_METADATA_API_KEY, new RemoteLogSegmentMetadataTransform()); + map.put(REMOTE_LOG_SEGMENT_METADATA_UPDATE_API_KEY, new RemoteLogSegmentMetadataUpdateTransform()); + map.put(REMOTE_PARTITION_DELETE_API_KEY, new RemotePartitionDeleteMetadataTransform()); + map.put(REMOTE_LOG_SEGMENT_METADATA_SNAPSHOT_API_KEY, new RemoteLogSegmentMetadataSnapshotTransform()); + return map; + } + + protected Map createRemoteLogStorageClassToApiKeyMap() { + Map map = new HashMap<>(); + map.put(RemoteLogSegmentMetadata.class.getName(), REMOTE_LOG_SEGMENT_METADATA_API_KEY); + map.put(RemoteLogSegmentMetadataUpdate.class.getName(), REMOTE_LOG_SEGMENT_METADATA_UPDATE_API_KEY); + map.put(RemotePartitionDeleteMetadata.class.getName(), REMOTE_PARTITION_DELETE_API_KEY); + map.put(RemoteLogSegmentMetadataSnapshot.class.getName(), REMOTE_LOG_SEGMENT_METADATA_SNAPSHOT_API_KEY); + return map; + } + + public byte[] serialize(RemoteLogMetadata remoteLogMetadata) { + Short apiKey = remoteLogStorageClassToApiKey.get(remoteLogMetadata.getClass().getName()); + if (apiKey == null) { + throw new IllegalArgumentException("ApiKey for given RemoteStorageMetadata class: " + remoteLogMetadata.getClass() + + " does not exist."); + } + + @SuppressWarnings("unchecked") + ApiMessageAndVersion apiMessageAndVersion = remoteLogMetadataTransform(apiKey).toApiMessageAndVersion(remoteLogMetadata); + + return bytesApiMessageSerde.serialize(apiMessageAndVersion); + } + + public RemoteLogMetadata deserialize(byte[] data) { + ApiMessageAndVersion apiMessageAndVersion = bytesApiMessageSerde.deserialize(data); + + return remoteLogMetadataTransform(apiMessageAndVersion.message().apiKey()).fromApiMessageAndVersion(apiMessageAndVersion); + } + + private RemoteLogMetadataTransform remoteLogMetadataTransform(short apiKey) { + RemoteLogMetadataTransform metadataTransform = keyToTransform.get(apiKey); + if (metadataTransform == null) { + throw new IllegalArgumentException("RemoteLogMetadataTransform for apikey: " + apiKey + " does not exist."); + } + + return metadataTransform; + } +} \ No newline at end of file diff --git a/storage/src/main/java/org/apache/kafka/server/log/remote/metadata/storage/serialization/RemoteLogMetadataTransform.java 
b/storage/src/main/java/org/apache/kafka/server/log/remote/metadata/storage/serialization/RemoteLogMetadataTransform.java new file mode 100644 index 0000000..b6e3582 --- /dev/null +++ b/storage/src/main/java/org/apache/kafka/server/log/remote/metadata/storage/serialization/RemoteLogMetadataTransform.java @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.server.log.remote.metadata.storage.serialization; + +import org.apache.kafka.server.common.ApiMessageAndVersion; +import org.apache.kafka.server.log.remote.storage.RemoteLogMetadata; + +/** + * This interface is about transforming {@link RemoteLogMetadata} objects into the respective {@link ApiMessageAndVersion} or vice versa. + *
+ * <p>
+ * Those metadata objects can be {@link org.apache.kafka.server.log.remote.storage.RemoteLogSegmentMetadata},
+ * {@link org.apache.kafka.server.log.remote.storage.RemoteLogSegmentMetadataUpdate}, or
+ * {@link org.apache.kafka.server.log.remote.storage.RemotePartitionDeleteMetadata}.
+ * <p>
                + * @param metadata type. + * + * @see RemoteLogSegmentMetadataTransform + * @see RemoteLogSegmentMetadataUpdateTransform + * @see RemotePartitionDeleteMetadataTransform + */ +public interface RemoteLogMetadataTransform { + + /** + * Transforms the given {@code metadata} object into the respective {@code ApiMessageAndVersion} object. + * + * @param metadata metadata object to be transformed. + * @return transformed {@code ApiMessageAndVersion} object. + */ + ApiMessageAndVersion toApiMessageAndVersion(T metadata); + + /** + * Return the metadata object transformed from the given {@code apiMessageAndVersion}. + * + * @param apiMessageAndVersion ApiMessageAndVersion object to be transformed. + * @return transformed {@code T} metadata object. + */ + T fromApiMessageAndVersion(ApiMessageAndVersion apiMessageAndVersion); + +} diff --git a/storage/src/main/java/org/apache/kafka/server/log/remote/metadata/storage/serialization/RemoteLogSegmentMetadataSnapshotTransform.java b/storage/src/main/java/org/apache/kafka/server/log/remote/metadata/storage/serialization/RemoteLogSegmentMetadataSnapshotTransform.java new file mode 100644 index 0000000..bd613f8 --- /dev/null +++ b/storage/src/main/java/org/apache/kafka/server/log/remote/metadata/storage/serialization/RemoteLogSegmentMetadataSnapshotTransform.java @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.server.log.remote.metadata.storage.serialization; + +import org.apache.kafka.server.common.ApiMessageAndVersion; +import org.apache.kafka.server.log.remote.metadata.storage.RemoteLogSegmentMetadataSnapshot; +import org.apache.kafka.server.log.remote.metadata.storage.generated.RemoteLogSegmentMetadataSnapshotRecord; +import org.apache.kafka.server.log.remote.storage.RemoteLogSegmentState; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +public class RemoteLogSegmentMetadataSnapshotTransform implements RemoteLogMetadataTransform { + + public ApiMessageAndVersion toApiMessageAndVersion(RemoteLogSegmentMetadataSnapshot segmentMetadata) { + RemoteLogSegmentMetadataSnapshotRecord record = new RemoteLogSegmentMetadataSnapshotRecord() + .setSegmentId(segmentMetadata.segmentId()) + .setStartOffset(segmentMetadata.startOffset()) + .setEndOffset(segmentMetadata.endOffset()) + .setBrokerId(segmentMetadata.brokerId()) + .setEventTimestampMs(segmentMetadata.eventTimestampMs()) + .setMaxTimestampMs(segmentMetadata.maxTimestampMs()) + .setSegmentSizeInBytes(segmentMetadata.segmentSizeInBytes()) + .setSegmentLeaderEpochs(createSegmentLeaderEpochsEntry(segmentMetadata.segmentLeaderEpochs())) + .setRemoteLogSegmentState(segmentMetadata.state().id()); + + return new ApiMessageAndVersion(record, record.highestSupportedVersion()); + } + + private List createSegmentLeaderEpochsEntry(Map leaderEpochs) { + return leaderEpochs.entrySet().stream() + .map(entry -> new RemoteLogSegmentMetadataSnapshotRecord.SegmentLeaderEpochEntry() + .setLeaderEpoch(entry.getKey()) + .setOffset(entry.getValue())) + .collect(Collectors.toList()); + } + + @Override + public RemoteLogSegmentMetadataSnapshot fromApiMessageAndVersion(ApiMessageAndVersion apiMessageAndVersion) { + RemoteLogSegmentMetadataSnapshotRecord record = (RemoteLogSegmentMetadataSnapshotRecord) apiMessageAndVersion.message(); + Map segmentLeaderEpochs = new HashMap<>(); + for (RemoteLogSegmentMetadataSnapshotRecord.SegmentLeaderEpochEntry segmentLeaderEpoch : record.segmentLeaderEpochs()) { + segmentLeaderEpochs.put(segmentLeaderEpoch.leaderEpoch(), segmentLeaderEpoch.offset()); + } + + return new RemoteLogSegmentMetadataSnapshot(record.segmentId(), + record.startOffset(), + record.endOffset(), + record.maxTimestampMs(), + record.brokerId(), + record.eventTimestampMs(), + record.segmentSizeInBytes(), + RemoteLogSegmentState.forId(record.remoteLogSegmentState()), + segmentLeaderEpochs); + } + +} diff --git a/storage/src/main/java/org/apache/kafka/server/log/remote/metadata/storage/serialization/RemoteLogSegmentMetadataTransform.java b/storage/src/main/java/org/apache/kafka/server/log/remote/metadata/storage/serialization/RemoteLogSegmentMetadataTransform.java new file mode 100644 index 0000000..4282b9e --- /dev/null +++ b/storage/src/main/java/org/apache/kafka/server/log/remote/metadata/storage/serialization/RemoteLogSegmentMetadataTransform.java @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.server.log.remote.metadata.storage.serialization; + +import org.apache.kafka.common.TopicIdPartition; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.server.common.ApiMessageAndVersion; +import org.apache.kafka.server.log.remote.metadata.storage.generated.RemoteLogSegmentMetadataRecord; +import org.apache.kafka.server.log.remote.storage.RemoteLogSegmentId; +import org.apache.kafka.server.log.remote.storage.RemoteLogSegmentMetadata; +import org.apache.kafka.server.log.remote.storage.RemoteLogSegmentMetadataUpdate; +import org.apache.kafka.server.log.remote.storage.RemoteLogSegmentState; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +public class RemoteLogSegmentMetadataTransform implements RemoteLogMetadataTransform { + + public ApiMessageAndVersion toApiMessageAndVersion(RemoteLogSegmentMetadata segmentMetadata) { + RemoteLogSegmentMetadataRecord record = new RemoteLogSegmentMetadataRecord() + .setRemoteLogSegmentId(createRemoteLogSegmentIdEntry(segmentMetadata)) + .setStartOffset(segmentMetadata.startOffset()) + .setEndOffset(segmentMetadata.endOffset()) + .setBrokerId(segmentMetadata.brokerId()) + .setEventTimestampMs(segmentMetadata.eventTimestampMs()) + .setMaxTimestampMs(segmentMetadata.maxTimestampMs()) + .setSegmentSizeInBytes(segmentMetadata.segmentSizeInBytes()) + .setSegmentLeaderEpochs(createSegmentLeaderEpochsEntry(segmentMetadata)) + .setRemoteLogSegmentState(segmentMetadata.state().id()); + + return new ApiMessageAndVersion(record, record.highestSupportedVersion()); + } + + private List createSegmentLeaderEpochsEntry(RemoteLogSegmentMetadata data) { + return data.segmentLeaderEpochs().entrySet().stream() + .map(entry -> new RemoteLogSegmentMetadataRecord.SegmentLeaderEpochEntry() + .setLeaderEpoch(entry.getKey()) + .setOffset(entry.getValue())) + .collect(Collectors.toList()); + } + + private RemoteLogSegmentMetadataRecord.RemoteLogSegmentIdEntry createRemoteLogSegmentIdEntry(RemoteLogSegmentMetadata data) { + return new RemoteLogSegmentMetadataRecord.RemoteLogSegmentIdEntry() + .setTopicIdPartition( + new RemoteLogSegmentMetadataRecord.TopicIdPartitionEntry() + .setId(data.remoteLogSegmentId().topicIdPartition().topicId()) + .setName(data.remoteLogSegmentId().topicIdPartition().topic()) + .setPartition(data.remoteLogSegmentId().topicIdPartition().partition())) + .setId(data.remoteLogSegmentId().id()); + } + + @Override + public RemoteLogSegmentMetadata fromApiMessageAndVersion(ApiMessageAndVersion apiMessageAndVersion) { + RemoteLogSegmentMetadataRecord record = (RemoteLogSegmentMetadataRecord) apiMessageAndVersion.message(); + RemoteLogSegmentId remoteLogSegmentId = buildRemoteLogSegmentId(record.remoteLogSegmentId()); + + Map segmentLeaderEpochs = new HashMap<>(); + for (RemoteLogSegmentMetadataRecord.SegmentLeaderEpochEntry segmentLeaderEpoch : record.segmentLeaderEpochs()) { + segmentLeaderEpochs.put(segmentLeaderEpoch.leaderEpoch(), segmentLeaderEpoch.offset()); + } + + RemoteLogSegmentMetadata remoteLogSegmentMetadata = + new 
RemoteLogSegmentMetadata(remoteLogSegmentId, record.startOffset(), record.endOffset(), + record.maxTimestampMs(), record.brokerId(), + record.eventTimestampMs(), record.segmentSizeInBytes(), + segmentLeaderEpochs); + RemoteLogSegmentMetadataUpdate rlsmUpdate + = new RemoteLogSegmentMetadataUpdate(remoteLogSegmentId, record.eventTimestampMs(), + RemoteLogSegmentState.forId(record.remoteLogSegmentState()), + record.brokerId()); + + return remoteLogSegmentMetadata.createWithUpdates(rlsmUpdate); + } + + private RemoteLogSegmentId buildRemoteLogSegmentId(RemoteLogSegmentMetadataRecord.RemoteLogSegmentIdEntry entry) { + TopicIdPartition topicIdPartition = + new TopicIdPartition(entry.topicIdPartition().id(), + new TopicPartition(entry.topicIdPartition().name(), entry.topicIdPartition().partition())); + + return new RemoteLogSegmentId(topicIdPartition, entry.id()); + } +} diff --git a/storage/src/main/java/org/apache/kafka/server/log/remote/metadata/storage/serialization/RemoteLogSegmentMetadataUpdateTransform.java b/storage/src/main/java/org/apache/kafka/server/log/remote/metadata/storage/serialization/RemoteLogSegmentMetadataUpdateTransform.java new file mode 100644 index 0000000..3db7765 --- /dev/null +++ b/storage/src/main/java/org/apache/kafka/server/log/remote/metadata/storage/serialization/RemoteLogSegmentMetadataUpdateTransform.java @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
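A similar sketch for the segment metadata transform above: it serializes a RemoteLogSegmentMetadata and reads it back, where fromApiMessageAndVersion reapplies the serialized state through createWithUpdates. The topic name and numbers are made up; the round-trip equality mirrors what the serde tests later in this patch assert.

    import org.apache.kafka.common.TopicIdPartition;
    import org.apache.kafka.common.TopicPartition;
    import org.apache.kafka.common.Uuid;
    import org.apache.kafka.server.common.ApiMessageAndVersion;
    import org.apache.kafka.server.log.remote.metadata.storage.serialization.RemoteLogSegmentMetadataTransform;
    import org.apache.kafka.server.log.remote.storage.RemoteLogSegmentId;
    import org.apache.kafka.server.log.remote.storage.RemoteLogSegmentMetadata;

    import java.util.Collections;

    public class SegmentMetadataTransformExample {
        public static void main(String[] args) {
            TopicIdPartition partition = new TopicIdPartition(Uuid.randomUuid(), new TopicPartition("orders", 0));
            RemoteLogSegmentId segmentId = new RemoteLogSegmentId(partition, Uuid.randomUuid());

            long now = System.currentTimeMillis();
            // Segment covering offsets [0, 100] copied by broker 1, single leader epoch 0 starting at offset 0.
            RemoteLogSegmentMetadata metadata = new RemoteLogSegmentMetadata(segmentId, 0L, 100L, now, 1,
                    now, 1024 * 1024, Collections.singletonMap(0, 0L));

            RemoteLogSegmentMetadataTransform transform = new RemoteLogSegmentMetadataTransform();
            ApiMessageAndVersion serialized = transform.toApiMessageAndVersion(metadata);

            // fromApiMessageAndVersion rebuilds the metadata and applies the serialized state
            // (plus broker id and event timestamp) through createWithUpdates.
            RemoteLogSegmentMetadata deserialized = transform.fromApiMessageAndVersion(serialized);
            System.out.println(metadata.equals(deserialized));
        }
    }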
+ */ +package org.apache.kafka.server.log.remote.metadata.storage.serialization; + +import org.apache.kafka.common.TopicIdPartition; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.server.common.ApiMessageAndVersion; +import org.apache.kafka.server.log.remote.metadata.storage.generated.RemoteLogSegmentMetadataUpdateRecord; +import org.apache.kafka.server.log.remote.storage.RemoteLogSegmentId; +import org.apache.kafka.server.log.remote.storage.RemoteLogSegmentMetadataUpdate; +import org.apache.kafka.server.log.remote.storage.RemoteLogSegmentState; + +public class RemoteLogSegmentMetadataUpdateTransform implements RemoteLogMetadataTransform { + + public ApiMessageAndVersion toApiMessageAndVersion(RemoteLogSegmentMetadataUpdate segmentMetadataUpdate) { + RemoteLogSegmentMetadataUpdateRecord record = new RemoteLogSegmentMetadataUpdateRecord() + .setRemoteLogSegmentId(createRemoteLogSegmentIdEntry(segmentMetadataUpdate)) + .setBrokerId(segmentMetadataUpdate.brokerId()) + .setEventTimestampMs(segmentMetadataUpdate.eventTimestampMs()) + .setRemoteLogSegmentState(segmentMetadataUpdate.state().id()); + + return new ApiMessageAndVersion(record, record.highestSupportedVersion()); + } + + public RemoteLogSegmentMetadataUpdate fromApiMessageAndVersion(ApiMessageAndVersion apiMessageAndVersion) { + RemoteLogSegmentMetadataUpdateRecord record = (RemoteLogSegmentMetadataUpdateRecord) apiMessageAndVersion.message(); + RemoteLogSegmentMetadataUpdateRecord.RemoteLogSegmentIdEntry entry = record.remoteLogSegmentId(); + TopicIdPartition topicIdPartition = new TopicIdPartition(entry.topicIdPartition().id(), + new TopicPartition(entry.topicIdPartition().name(), entry.topicIdPartition().partition())); + + return new RemoteLogSegmentMetadataUpdate(new RemoteLogSegmentId(topicIdPartition, entry.id()), + record.eventTimestampMs(), RemoteLogSegmentState.forId(record.remoteLogSegmentState()), record.brokerId()); + } + + private RemoteLogSegmentMetadataUpdateRecord.RemoteLogSegmentIdEntry createRemoteLogSegmentIdEntry(RemoteLogSegmentMetadataUpdate data) { + return new RemoteLogSegmentMetadataUpdateRecord.RemoteLogSegmentIdEntry() + .setId(data.remoteLogSegmentId().id()) + .setTopicIdPartition( + new RemoteLogSegmentMetadataUpdateRecord.TopicIdPartitionEntry() + .setName(data.remoteLogSegmentId().topicIdPartition().topic()) + .setPartition(data.remoteLogSegmentId().topicIdPartition().partition()) + .setId(data.remoteLogSegmentId().topicIdPartition().topicId())); + } + +} diff --git a/storage/src/main/java/org/apache/kafka/server/log/remote/metadata/storage/serialization/RemotePartitionDeleteMetadataTransform.java b/storage/src/main/java/org/apache/kafka/server/log/remote/metadata/storage/serialization/RemotePartitionDeleteMetadataTransform.java new file mode 100644 index 0000000..d94830f --- /dev/null +++ b/storage/src/main/java/org/apache/kafka/server/log/remote/metadata/storage/serialization/RemotePartitionDeleteMetadataTransform.java @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.server.log.remote.metadata.storage.serialization; + +import org.apache.kafka.common.TopicIdPartition; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.server.common.ApiMessageAndVersion; +import org.apache.kafka.server.log.remote.metadata.storage.generated.RemotePartitionDeleteMetadataRecord; +import org.apache.kafka.server.log.remote.storage.RemotePartitionDeleteMetadata; +import org.apache.kafka.server.log.remote.storage.RemotePartitionDeleteState; + +public final class RemotePartitionDeleteMetadataTransform implements RemoteLogMetadataTransform { + + @Override + public ApiMessageAndVersion toApiMessageAndVersion(RemotePartitionDeleteMetadata partitionDeleteMetadata) { + RemotePartitionDeleteMetadataRecord record = new RemotePartitionDeleteMetadataRecord() + .setTopicIdPartition(createTopicIdPartitionEntry(partitionDeleteMetadata.topicIdPartition())) + .setEventTimestampMs(partitionDeleteMetadata.eventTimestampMs()) + .setBrokerId(partitionDeleteMetadata.brokerId()) + .setRemotePartitionDeleteState(partitionDeleteMetadata.state().id()); + return new ApiMessageAndVersion(record, record.highestSupportedVersion()); + } + + private RemotePartitionDeleteMetadataRecord.TopicIdPartitionEntry createTopicIdPartitionEntry(TopicIdPartition topicIdPartition) { + return new RemotePartitionDeleteMetadataRecord.TopicIdPartitionEntry() + .setName(topicIdPartition.topic()) + .setPartition(topicIdPartition.partition()) + .setId(topicIdPartition.topicId()); + } + + public RemotePartitionDeleteMetadata fromApiMessageAndVersion(ApiMessageAndVersion apiMessageAndVersion) { + RemotePartitionDeleteMetadataRecord record = (RemotePartitionDeleteMetadataRecord) apiMessageAndVersion.message(); + TopicIdPartition topicIdPartition = new TopicIdPartition(record.topicIdPartition().id(), + new TopicPartition(record.topicIdPartition().name(), record.topicIdPartition().partition())); + + return new RemotePartitionDeleteMetadata(topicIdPartition, + RemotePartitionDeleteState.forId(record.remotePartitionDeleteState()), + record.eventTimestampMs(), record.brokerId()); + } +} diff --git a/storage/src/main/java/org/apache/kafka/server/log/remote/storage/RemoteLogManagerConfig.java b/storage/src/main/java/org/apache/kafka/server/log/remote/storage/RemoteLogManagerConfig.java new file mode 100644 index 0000000..f5700bb --- /dev/null +++ b/storage/src/main/java/org/apache/kafka/server/log/remote/storage/RemoteLogManagerConfig.java @@ -0,0 +1,411 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
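The three transforms above are tied together by RemoteLogMetadataSerde (exercised by the tests later in this patch), which selects the transform for the concrete RemoteLogMetadata subtype on serialization and reverses it on deserialization. A minimal sketch of that entry point, with made-up values:

    import org.apache.kafka.common.TopicIdPartition;
    import org.apache.kafka.common.TopicPartition;
    import org.apache.kafka.common.Uuid;
    import org.apache.kafka.server.log.remote.metadata.storage.serialization.RemoteLogMetadataSerde;
    import org.apache.kafka.server.log.remote.storage.RemoteLogMetadata;
    import org.apache.kafka.server.log.remote.storage.RemotePartitionDeleteMetadata;
    import org.apache.kafka.server.log.remote.storage.RemotePartitionDeleteState;

    public class RemoteLogMetadataSerdeExample {
        public static void main(String[] args) {
            TopicIdPartition partition = new TopicIdPartition(Uuid.randomUuid(), new TopicPartition("orders", 0));

            // A DELETE_PARTITION_MARKED event generated by broker 0.
            RemotePartitionDeleteMetadata deleteMetadata = new RemotePartitionDeleteMetadata(partition,
                    RemotePartitionDeleteState.DELETE_PARTITION_MARKED, System.currentTimeMillis(), 0);

            // serialize() picks the transform for the concrete metadata type; deserialize() reverses it.
            RemoteLogMetadataSerde serde = new RemoteLogMetadataSerde();
            byte[] payload = serde.serialize(deleteMetadata);
            RemoteLogMetadata readBack = serde.deserialize(payload);
            System.out.println(deleteMetadata.equals(readBack));
        }
    }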
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.server.log.remote.storage; + +import org.apache.kafka.common.config.AbstractConfig; +import org.apache.kafka.common.config.ConfigDef; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; + +import static org.apache.kafka.common.config.ConfigDef.Importance.LOW; +import static org.apache.kafka.common.config.ConfigDef.Importance.MEDIUM; +import static org.apache.kafka.common.config.ConfigDef.Range.atLeast; +import static org.apache.kafka.common.config.ConfigDef.Range.between; +import static org.apache.kafka.common.config.ConfigDef.Type.BOOLEAN; +import static org.apache.kafka.common.config.ConfigDef.Type.DOUBLE; +import static org.apache.kafka.common.config.ConfigDef.Type.INT; +import static org.apache.kafka.common.config.ConfigDef.Type.LONG; +import static org.apache.kafka.common.config.ConfigDef.Type.STRING; + +public final class RemoteLogManagerConfig { + + /** + * Prefix used for properties to be passed to {@link RemoteStorageManager} implementation. Remote log subsystem collects all the properties having + * this prefix and passes to {@code RemoteStorageManager} using {@link RemoteStorageManager#configure(Map)}. + */ + public static final String REMOTE_STORAGE_MANAGER_CONFIG_PREFIX_PROP = "remote.log.storage.manager.impl.prefix"; + public static final String REMOTE_STORAGE_MANAGER_CONFIG_PREFIX_DOC = "Prefix used for properties to be passed to RemoteStorageManager " + + "implementation. For example this value can be `rsm.s3.`."; + + /** + * Prefix used for properties to be passed to {@link RemoteLogMetadataManager} implementation. Remote log subsystem collects all the properties having + * this prefix and passed to {@code RemoteLogMetadataManager} using {@link RemoteLogMetadataManager#configure(Map)}. + */ + public static final String REMOTE_LOG_METADATA_MANAGER_CONFIG_PREFIX_PROP = "remote.log.metadata.manager.impl.prefix"; + public static final String REMOTE_LOG_METADATA_MANAGER_CONFIG_PREFIX_DOC = "Prefix used for properties to be passed to RemoteLogMetadataManager " + + "implementation. For example this value can be `rlmm.s3.`."; + + public static final String REMOTE_LOG_STORAGE_SYSTEM_ENABLE_PROP = "remote.log.storage.system.enable"; + public static final String REMOTE_LOG_STORAGE_SYSTEM_ENABLE_DOC = "Whether to enable tier storage functionality in a broker or not. Valid values " + + "are `true` or `false` and the default value is false. 
When it is true broker starts all the services required for tiered storage functionality."; + public static final boolean DEFAULT_REMOTE_LOG_STORAGE_SYSTEM_ENABLE = false; + + public static final String REMOTE_STORAGE_MANAGER_CLASS_NAME_PROP = "remote.log.storage.manager.class.name"; + public static final String REMOTE_STORAGE_MANAGER_CLASS_NAME_DOC = "Fully qualified class name of `RemoteLogStorageManager` implementation."; + + public static final String REMOTE_STORAGE_MANAGER_CLASS_PATH_PROP = "remote.log.storage.manager.class.path"; + public static final String REMOTE_STORAGE_MANAGER_CLASS_PATH_DOC = "Class path of the `RemoteLogStorageManager` implementation." + + "If specified, the RemoteLogStorageManager implementation and its dependent libraries will be loaded by a dedicated" + + "classloader which searches this class path before the Kafka broker class path. The syntax of this parameter is same" + + "with the standard Java class path string."; + + public static final String REMOTE_LOG_METADATA_MANAGER_CLASS_NAME_PROP = "remote.log.metadata.manager.class.name"; + public static final String REMOTE_LOG_METADATA_MANAGER_CLASS_NAME_DOC = "Fully qualified class name of `RemoteLogMetadataManager` implementation."; + //todo add the default topic based RLMM class name. + public static final String DEFAULT_REMOTE_LOG_METADATA_MANAGER_CLASS_NAME = ""; + + public static final String REMOTE_LOG_METADATA_MANAGER_CLASS_PATH_PROP = "remote.log.metadata.manager.class.path"; + public static final String REMOTE_LOG_METADATA_MANAGER_CLASS_PATH_DOC = "Class path of the `RemoteLogMetadataManager` implementation." + + "If specified, the RemoteLogMetadataManager implementation and its dependent libraries will be loaded by a dedicated" + + "classloader which searches this class path before the Kafka broker class path. 
The syntax of this parameter is same" + + "with the standard Java class path string."; + + public static final String REMOTE_LOG_METADATA_MANAGER_LISTENER_NAME_PROP = "remote.log.metadata.manager.listener.name"; + public static final String REMOTE_LOG_METADATA_MANAGER_LISTENER_NAME_DOC = "Listener name of the local broker to which it should get connected if " + + "needed by RemoteLogMetadataManager implementation."; + + public static final String REMOTE_LOG_INDEX_FILE_CACHE_TOTAL_SIZE_BYTES_PROP = "remote.log.index.file.cache.total.size.bytes"; + public static final String REMOTE_LOG_INDEX_FILE_CACHE_TOTAL_SIZE_BYTES_DOC = "The total size of the space allocated to store index files fetched " + + "from remote storage in the local storage."; + public static final long DEFAULT_REMOTE_LOG_INDEX_FILE_CACHE_TOTAL_SIZE_BYTES = 1024 * 1024 * 1024L; + + public static final String REMOTE_LOG_MANAGER_THREAD_POOL_SIZE_PROP = "remote.log.manager.thread.pool.size"; + public static final String REMOTE_LOG_MANAGER_THREAD_POOL_SIZE_DOC = "Size of the thread pool used in scheduling tasks to copy " + + "segments, fetch remote log indexes and clean up remote log segments."; + public static final int DEFAULT_REMOTE_LOG_MANAGER_THREAD_POOL_SIZE = 10; + + public static final String REMOTE_LOG_MANAGER_TASK_INTERVAL_MS_PROP = "remote.log.manager.task.interval.ms"; + public static final String REMOTE_LOG_MANAGER_TASK_INTERVAL_MS_DOC = "Interval at which remote log manager runs the scheduled tasks like copy " + + "segments, and clean up remote log segments."; + public static final long DEFAULT_REMOTE_LOG_MANAGER_TASK_INTERVAL_MS = 30 * 1000L; + + public static final String REMOTE_LOG_MANAGER_TASK_RETRY_BACK_OFF_MS_PROP = "remote.log.manager.task.retry.backoff.ms"; + public static final String REMOTE_LOG_MANAGER_TASK_RETRY_BACK_OFF_MS_DOC = "The initial amount of wait in milli seconds before the request is retried again."; + public static final long DEFAULT_REMOTE_LOG_MANAGER_TASK_RETRY_BACK_OFF_MS = 500L; + + public static final String REMOTE_LOG_MANAGER_TASK_RETRY_BACK_OFF_MAX_MS_PROP = "remote.log.manager.task.retry.backoff.max.ms"; + public static final String REMOTE_LOG_MANAGER_TASK_RETRY_BACK_OFF_MAX_MS_DOC = "The maximum amount of time in milliseconds to wait when the request " + + "is retried again. The retry duration will increase exponentially for each request failure up to this maximum wait interval."; + public static final long DEFAULT_REMOTE_LOG_MANAGER_TASK_RETRY_BACK_OFF_MAX_MS = 30 * 1000L; + + public static final String REMOTE_LOG_MANAGER_TASK_RETRY_JITTER_PROP = "remote.log.manager.task.retry.jitter"; + public static final String REMOTE_LOG_MANAGER_TASK_RETRY_JITTER_DOC = "The value used in defining the range for computing random jitter factor. " + + "It is applied to the effective exponential term for computing the resultant retry backoff interval. This will avoid thundering herds " + + "of requests. The default value is 0.2 and valid value should be between 0(inclusive) and 0.5(inclusive). " + + "For ex: remote.log.manager.task.retry.jitter = 0.25, then the range to compute random jitter will be [1-0.25, 1+0.25) viz [0.75, 1.25). 
" + + "So, jitter factor can be any random value with in that range."; + public static final double DEFAULT_REMOTE_LOG_MANAGER_TASK_RETRY_JITTER = 0.2; + + public static final String REMOTE_LOG_READER_THREADS_PROP = "remote.log.reader.threads"; + public static final String REMOTE_LOG_READER_THREADS_DOC = "Size of the thread pool that is allocated for handling remote log reads."; + public static final int DEFAULT_REMOTE_LOG_READER_THREADS = 10; + + public static final String REMOTE_LOG_READER_MAX_PENDING_TASKS_PROP = "remote.log.reader.max.pending.tasks"; + public static final String REMOTE_LOG_READER_MAX_PENDING_TASKS_DOC = "Maximum remote log reader thread pool task queue size. If the task queue " + + "is full, fetch requests are served with an error."; + public static final int DEFAULT_REMOTE_LOG_READER_MAX_PENDING_TASKS = 100; + + public static final ConfigDef CONFIG_DEF = new ConfigDef(); + + static { + CONFIG_DEF.defineInternal(REMOTE_LOG_STORAGE_SYSTEM_ENABLE_PROP, + BOOLEAN, + DEFAULT_REMOTE_LOG_STORAGE_SYSTEM_ENABLE, + null, + MEDIUM, + REMOTE_LOG_STORAGE_SYSTEM_ENABLE_DOC) + .defineInternal(REMOTE_STORAGE_MANAGER_CONFIG_PREFIX_PROP, + STRING, + null, + new ConfigDef.NonEmptyString(), + MEDIUM, + REMOTE_STORAGE_MANAGER_CONFIG_PREFIX_DOC) + .defineInternal(REMOTE_LOG_METADATA_MANAGER_CONFIG_PREFIX_PROP, + STRING, + null, + new ConfigDef.NonEmptyString(), + MEDIUM, + REMOTE_LOG_METADATA_MANAGER_CONFIG_PREFIX_DOC) + .defineInternal(REMOTE_STORAGE_MANAGER_CLASS_NAME_PROP, STRING, + null, + new ConfigDef.NonEmptyString(), + MEDIUM, + REMOTE_STORAGE_MANAGER_CLASS_NAME_DOC) + .defineInternal(REMOTE_STORAGE_MANAGER_CLASS_PATH_PROP, STRING, + null, + new ConfigDef.NonEmptyString(), + MEDIUM, + REMOTE_STORAGE_MANAGER_CLASS_PATH_DOC) + .defineInternal(REMOTE_LOG_METADATA_MANAGER_CLASS_NAME_PROP, + STRING, null, + new ConfigDef.NonEmptyString(), + MEDIUM, + REMOTE_LOG_METADATA_MANAGER_CLASS_NAME_DOC) + .defineInternal(REMOTE_LOG_METADATA_MANAGER_CLASS_PATH_PROP, + STRING, + null, + new ConfigDef.NonEmptyString(), + MEDIUM, + REMOTE_LOG_METADATA_MANAGER_CLASS_PATH_DOC) + .defineInternal(REMOTE_LOG_METADATA_MANAGER_LISTENER_NAME_PROP, STRING, + null, + new ConfigDef.NonEmptyString(), + MEDIUM, + REMOTE_LOG_METADATA_MANAGER_LISTENER_NAME_DOC) + .defineInternal(REMOTE_LOG_INDEX_FILE_CACHE_TOTAL_SIZE_BYTES_PROP, + LONG, + DEFAULT_REMOTE_LOG_INDEX_FILE_CACHE_TOTAL_SIZE_BYTES, + atLeast(1), + LOW, + REMOTE_LOG_INDEX_FILE_CACHE_TOTAL_SIZE_BYTES_DOC) + .defineInternal(REMOTE_LOG_MANAGER_THREAD_POOL_SIZE_PROP, + INT, + DEFAULT_REMOTE_LOG_MANAGER_THREAD_POOL_SIZE, + atLeast(1), + MEDIUM, + REMOTE_LOG_MANAGER_THREAD_POOL_SIZE_DOC) + .defineInternal(REMOTE_LOG_MANAGER_TASK_INTERVAL_MS_PROP, + LONG, + DEFAULT_REMOTE_LOG_MANAGER_TASK_INTERVAL_MS, + atLeast(1), + LOW, + REMOTE_LOG_MANAGER_TASK_INTERVAL_MS_DOC) + .defineInternal(REMOTE_LOG_MANAGER_TASK_RETRY_BACK_OFF_MS_PROP, + LONG, + DEFAULT_REMOTE_LOG_MANAGER_TASK_RETRY_BACK_OFF_MS, + atLeast(1), + LOW, + REMOTE_LOG_MANAGER_TASK_RETRY_BACK_OFF_MS_DOC) + .defineInternal(REMOTE_LOG_MANAGER_TASK_RETRY_BACK_OFF_MAX_MS_PROP, + LONG, + DEFAULT_REMOTE_LOG_MANAGER_TASK_RETRY_BACK_OFF_MAX_MS, + atLeast(1), LOW, + REMOTE_LOG_MANAGER_TASK_RETRY_BACK_OFF_MAX_MS_DOC) + .defineInternal(REMOTE_LOG_MANAGER_TASK_RETRY_JITTER_PROP, + DOUBLE, + DEFAULT_REMOTE_LOG_MANAGER_TASK_RETRY_JITTER, + between(0, 0.5), + LOW, + REMOTE_LOG_MANAGER_TASK_RETRY_JITTER_DOC) + .defineInternal(REMOTE_LOG_READER_THREADS_PROP, + INT, + DEFAULT_REMOTE_LOG_READER_THREADS, + atLeast(1), + MEDIUM, + 
REMOTE_LOG_READER_THREADS_DOC) + .defineInternal(REMOTE_LOG_READER_MAX_PENDING_TASKS_PROP, + INT, + DEFAULT_REMOTE_LOG_READER_MAX_PENDING_TASKS, + atLeast(1), + MEDIUM, + REMOTE_LOG_READER_MAX_PENDING_TASKS_DOC); + } + + private final boolean enableRemoteStorageSystem; + private final String remoteStorageManagerClassName; + private final String remoteStorageManagerClassPath; + private final String remoteLogMetadataManagerClassName; + private final String remoteLogMetadataManagerClassPath; + private final long remoteLogIndexFileCacheTotalSizeBytes; + private final int remoteLogManagerThreadPoolSize; + private final long remoteLogManagerTaskIntervalMs; + private final long remoteLogManagerTaskRetryBackoffMs; + private final long remoteLogManagerTaskRetryBackoffMaxMs; + private final double remoteLogManagerTaskRetryJitter; + private final int remoteLogReaderThreads; + private final int remoteLogReaderMaxPendingTasks; + private final String remoteStorageManagerPrefix; + private final HashMap remoteStorageManagerProps; + private final String remoteLogMetadataManagerPrefix; + private final HashMap remoteLogMetadataManagerProps; + private final String remoteLogMetadataManagerListenerName; + + public RemoteLogManagerConfig(AbstractConfig config) { + this(config.getBoolean(REMOTE_LOG_STORAGE_SYSTEM_ENABLE_PROP), + config.getString(REMOTE_STORAGE_MANAGER_CLASS_NAME_PROP), + config.getString(REMOTE_STORAGE_MANAGER_CLASS_PATH_PROP), + config.getString(REMOTE_LOG_METADATA_MANAGER_CLASS_NAME_PROP), + config.getString(REMOTE_LOG_METADATA_MANAGER_CLASS_PATH_PROP), + config.getString(REMOTE_LOG_METADATA_MANAGER_LISTENER_NAME_PROP), + config.getLong(REMOTE_LOG_INDEX_FILE_CACHE_TOTAL_SIZE_BYTES_PROP), + config.getInt(REMOTE_LOG_MANAGER_THREAD_POOL_SIZE_PROP), + config.getLong(REMOTE_LOG_MANAGER_TASK_INTERVAL_MS_PROP), + config.getLong(REMOTE_LOG_MANAGER_TASK_RETRY_BACK_OFF_MS_PROP), + config.getLong(REMOTE_LOG_MANAGER_TASK_RETRY_BACK_OFF_MAX_MS_PROP), + config.getDouble(REMOTE_LOG_MANAGER_TASK_RETRY_JITTER_PROP), + config.getInt(REMOTE_LOG_READER_THREADS_PROP), + config.getInt(REMOTE_LOG_READER_MAX_PENDING_TASKS_PROP), + config.getString(REMOTE_STORAGE_MANAGER_CONFIG_PREFIX_PROP), + config.getString(REMOTE_STORAGE_MANAGER_CONFIG_PREFIX_PROP) != null + ? config.originalsWithPrefix(config.getString(REMOTE_STORAGE_MANAGER_CONFIG_PREFIX_PROP)) + : Collections.emptyMap(), + config.getString(REMOTE_LOG_METADATA_MANAGER_CONFIG_PREFIX_PROP), + config.getString(REMOTE_LOG_METADATA_MANAGER_CONFIG_PREFIX_PROP) != null + ? 
config.originalsWithPrefix(config.getString(REMOTE_LOG_METADATA_MANAGER_CONFIG_PREFIX_PROP)) + : Collections.emptyMap()); + } + + // Visible for testing + public RemoteLogManagerConfig(boolean enableRemoteStorageSystem, + String remoteStorageManagerClassName, + String remoteStorageManagerClassPath, + String remoteLogMetadataManagerClassName, + String remoteLogMetadataManagerClassPath, + String remoteLogMetadataManagerListenerName, + long remoteLogIndexFileCacheTotalSizeBytes, + int remoteLogManagerThreadPoolSize, + long remoteLogManagerTaskIntervalMs, + long remoteLogManagerTaskRetryBackoffMs, + long remoteLogManagerTaskRetryBackoffMaxMs, + double remoteLogManagerTaskRetryJitter, + int remoteLogReaderThreads, + int remoteLogReaderMaxPendingTasks, + String remoteStorageManagerPrefix, + Map remoteStorageManagerProps, /* properties having keys stripped out with remoteStorageManagerPrefix */ + String remoteLogMetadataManagerPrefix, + Map remoteLogMetadataManagerProps /* properties having keys stripped out with remoteLogMetadataManagerPrefix */ + ) { + this.enableRemoteStorageSystem = enableRemoteStorageSystem; + this.remoteStorageManagerClassName = remoteStorageManagerClassName; + this.remoteStorageManagerClassPath = remoteStorageManagerClassPath; + this.remoteLogMetadataManagerClassName = remoteLogMetadataManagerClassName; + this.remoteLogMetadataManagerClassPath = remoteLogMetadataManagerClassPath; + this.remoteLogIndexFileCacheTotalSizeBytes = remoteLogIndexFileCacheTotalSizeBytes; + this.remoteLogManagerThreadPoolSize = remoteLogManagerThreadPoolSize; + this.remoteLogManagerTaskIntervalMs = remoteLogManagerTaskIntervalMs; + this.remoteLogManagerTaskRetryBackoffMs = remoteLogManagerTaskRetryBackoffMs; + this.remoteLogManagerTaskRetryBackoffMaxMs = remoteLogManagerTaskRetryBackoffMaxMs; + this.remoteLogManagerTaskRetryJitter = remoteLogManagerTaskRetryJitter; + this.remoteLogReaderThreads = remoteLogReaderThreads; + this.remoteLogReaderMaxPendingTasks = remoteLogReaderMaxPendingTasks; + this.remoteStorageManagerPrefix = remoteStorageManagerPrefix; + this.remoteStorageManagerProps = new HashMap<>(remoteStorageManagerProps); + this.remoteLogMetadataManagerPrefix = remoteLogMetadataManagerPrefix; + this.remoteLogMetadataManagerProps = new HashMap<>(remoteLogMetadataManagerProps); + this.remoteLogMetadataManagerListenerName = remoteLogMetadataManagerListenerName; + } + + public boolean enableRemoteStorageSystem() { + return enableRemoteStorageSystem; + } + + public String remoteStorageManagerClassName() { + return remoteStorageManagerClassName; + } + + public String remoteStorageManagerClassPath() { + return remoteStorageManagerClassPath; + } + + public String remoteLogMetadataManagerClassName() { + return remoteLogMetadataManagerClassName; + } + + public String remoteLogMetadataManagerClassPath() { + return remoteLogMetadataManagerClassPath; + } + + public long remoteLogIndexFileCacheTotalSizeBytes() { + return remoteLogIndexFileCacheTotalSizeBytes; + } + + public int remoteLogManagerThreadPoolSize() { + return remoteLogManagerThreadPoolSize; + } + + public long remoteLogManagerTaskIntervalMs() { + return remoteLogManagerTaskIntervalMs; + } + + public long remoteLogManagerTaskRetryBackoffMs() { + return remoteLogManagerTaskRetryBackoffMs; + } + + public long remoteLogManagerTaskRetryBackoffMaxMs() { + return remoteLogManagerTaskRetryBackoffMaxMs; + } + + public double remoteLogManagerTaskRetryJitter() { + return remoteLogManagerTaskRetryJitter; + } + + public int remoteLogReaderThreads() { + 
return remoteLogReaderThreads; + } + + public int remoteLogReaderMaxPendingTasks() { + return remoteLogReaderMaxPendingTasks; + } + + public String remoteLogMetadataManagerListenerName() { + return remoteLogMetadataManagerListenerName; + } + + public String remoteStorageManagerPrefix() { + return remoteStorageManagerPrefix; + } + + public String remoteLogMetadataManagerPrefix() { + return remoteLogMetadataManagerPrefix; + } + + public Map remoteStorageManagerProps() { + return Collections.unmodifiableMap(remoteStorageManagerProps); + } + + public Map remoteLogMetadataManagerProps() { + return Collections.unmodifiableMap(remoteLogMetadataManagerProps); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (!(o instanceof RemoteLogManagerConfig)) return false; + RemoteLogManagerConfig that = (RemoteLogManagerConfig) o; + return enableRemoteStorageSystem == that.enableRemoteStorageSystem + && remoteLogIndexFileCacheTotalSizeBytes == that.remoteLogIndexFileCacheTotalSizeBytes + && remoteLogManagerThreadPoolSize == that.remoteLogManagerThreadPoolSize + && remoteLogManagerTaskIntervalMs == that.remoteLogManagerTaskIntervalMs + && remoteLogManagerTaskRetryBackoffMs == that.remoteLogManagerTaskRetryBackoffMs + && remoteLogManagerTaskRetryBackoffMaxMs == that.remoteLogManagerTaskRetryBackoffMaxMs + && remoteLogManagerTaskRetryJitter == that.remoteLogManagerTaskRetryJitter + && remoteLogReaderThreads == that.remoteLogReaderThreads + && remoteLogReaderMaxPendingTasks == that.remoteLogReaderMaxPendingTasks + && Objects.equals(remoteStorageManagerClassName, that.remoteStorageManagerClassName) + && Objects.equals(remoteStorageManagerClassPath, that.remoteStorageManagerClassPath) + && Objects.equals(remoteLogMetadataManagerClassName, that.remoteLogMetadataManagerClassName) + && Objects.equals(remoteLogMetadataManagerClassPath, that.remoteLogMetadataManagerClassPath) + && Objects.equals(remoteLogMetadataManagerListenerName, that.remoteLogMetadataManagerListenerName) + && Objects.equals(remoteStorageManagerProps, that.remoteStorageManagerProps) + && Objects.equals(remoteLogMetadataManagerProps, that.remoteLogMetadataManagerProps) + && Objects.equals(remoteStorageManagerPrefix, that.remoteStorageManagerPrefix) + && Objects.equals(remoteLogMetadataManagerPrefix, that.remoteLogMetadataManagerPrefix); + } + + @Override + public int hashCode() { + return Objects.hash(enableRemoteStorageSystem, remoteStorageManagerClassName, remoteStorageManagerClassPath, + remoteLogMetadataManagerClassName, remoteLogMetadataManagerClassPath, remoteLogMetadataManagerListenerName, + remoteLogIndexFileCacheTotalSizeBytes, remoteLogManagerThreadPoolSize, remoteLogManagerTaskIntervalMs, + remoteLogManagerTaskRetryBackoffMs, remoteLogManagerTaskRetryBackoffMaxMs, remoteLogManagerTaskRetryJitter, + remoteLogReaderThreads, remoteLogReaderMaxPendingTasks, remoteStorageManagerProps, remoteLogMetadataManagerProps, + remoteStorageManagerPrefix, remoteLogMetadataManagerPrefix); + } +} diff --git a/storage/src/main/resources/message/RemoteLogSegmentMetadataRecord.json b/storage/src/main/resources/message/RemoteLogSegmentMetadataRecord.json new file mode 100644 index 0000000..d18144e --- /dev/null +++ b/storage/src/main/resources/message/RemoteLogSegmentMetadataRecord.json @@ -0,0 +1,126 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. 
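To make the prefix handling above concrete, the sketch below builds a RemoteLogManagerConfig from a plain property map. The com.example class name, the rsm.s3. prefix and the bucket.name key are made-up values, used only to show how prefixed properties end up in remoteStorageManagerProps() with the prefix stripped.

    import org.apache.kafka.common.config.AbstractConfig;
    import org.apache.kafka.server.log.remote.storage.RemoteLogManagerConfig;

    import java.util.HashMap;
    import java.util.Map;

    public class RemoteLogManagerConfigExample {
        public static void main(String[] args) {
            Map<String, Object> props = new HashMap<>();
            props.put(RemoteLogManagerConfig.REMOTE_LOG_STORAGE_SYSTEM_ENABLE_PROP, true);
            // Hypothetical plugin class; any RemoteStorageManager implementation on the class path.
            props.put(RemoteLogManagerConfig.REMOTE_STORAGE_MANAGER_CLASS_NAME_PROP, "com.example.S3RemoteStorageManager");
            // Everything under this prefix is collected and handed to the RemoteStorageManager via configure().
            props.put(RemoteLogManagerConfig.REMOTE_STORAGE_MANAGER_CONFIG_PREFIX_PROP, "rsm.s3.");
            props.put("rsm.s3.bucket.name", "my-bucket");

            RemoteLogManagerConfig config =
                    new RemoteLogManagerConfig(new AbstractConfig(RemoteLogManagerConfig.CONFIG_DEF, props));

            System.out.println(config.enableRemoteStorageSystem());     // true
            System.out.println(config.remoteStorageManagerClassName()); // com.example.S3RemoteStorageManager
            System.out.println(config.remoteStorageManagerProps());     // {bucket.name=my-bucket}, prefix stripped
        }
    }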
+// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 0, + "type": "metadata", + "name": "RemoteLogSegmentMetadataRecord", + "validVersions": "0", + "flexibleVersions": "0+", + "fields": [ + { + "name": "RemoteLogSegmentId", + "type": "RemoteLogSegmentIdEntry", + "versions": "0+", + "about": "Unique representation of the remote log segment.", + "fields": [ + { + "name": "TopicIdPartition", + "type": "TopicIdPartitionEntry", + "versions": "0+", + "about": "Represents unique topic partition.", + "fields": [ + { + "name": "Name", + "type": "string", + "versions": "0+", + "about": "Topic name." + }, + { + "name": "Id", + "type": "uuid", + "versions": "0+", + "about": "Unique identifier of the topic." + }, + { + "name": "Partition", + "type": "int32", + "versions": "0+", + "about": "Partition number." + } + ] + }, + { + "name": "Id", + "type": "uuid", + "versions": "0+", + "about": "Unique identifier of the remote log segment." + } + ] + }, + { + "name": "StartOffset", + "type": "int64", + "versions": "0+", + "about": "Start offset of the segment." + }, + { + "name": "EndOffset", + "type": "int64", + "versions": "0+", + "about": "End offset of the segment." + }, + { + "name": "BrokerId", + "type": "int32", + "versions": "0+", + "about": "Broker id from which this event is generated." + }, + { + "name": "MaxTimestampMs", + "type": "int64", + "versions": "0+", + "about": "Maximum timestamp in milli seconds with in this segment." + }, + { + "name": "EventTimestampMs", + "type": "int64", + "versions": "0+", + "about": "Epoch time in milli seconds at which this event is generated." + }, + { + "name": "SegmentLeaderEpochs", + "type": "[]SegmentLeaderEpochEntry", + "versions": "0+", + "about": "Leader epoch to start-offset mappings for the records with in this segment.", + "fields": [ + { + "name": "LeaderEpoch", + "type": "int32", + "versions": "0+", + "about": "Leader epoch" + }, + { + "name": "Offset", + "type": "int64", + "versions": "0+", + "about": "Start offset for the leader epoch." + } + ] + }, + { + "name": "SegmentSizeInBytes", + "type": "int32", + "versions": "0+", + "about": "Segment size in bytes." + }, + { + "name": "RemoteLogSegmentState", + "type": "int8", + "versions": "0+", + "about": "State identifier of the remote log segment, which is RemoteLogSegmentState.id()." + } + ] +} \ No newline at end of file diff --git a/storage/src/main/resources/message/RemoteLogSegmentMetadataSnapshotRecord.json b/storage/src/main/resources/message/RemoteLogSegmentMetadataSnapshotRecord.json new file mode 100644 index 0000000..dbb2913 --- /dev/null +++ b/storage/src/main/resources/message/RemoteLogSegmentMetadataSnapshotRecord.json @@ -0,0 +1,92 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. 
+// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 3, + "type": "metadata", + "name": "RemoteLogSegmentMetadataSnapshotRecord", + "validVersions": "0", + "flexibleVersions": "0+", + "fields": [ + { + "name": "SegmentId", + "type": "uuid", + "versions": "0+", + "about": "Unique identifier of the log segment" + }, + { + "name": "StartOffset", + "type": "int64", + "versions": "0+", + "about": "Start offset of the segment." + }, + { + "name": "EndOffset", + "type": "int64", + "versions": "0+", + "about": "End offset of the segment." + }, + { + "name": "BrokerId", + "type": "int32", + "versions": "0+", + "about": "Broker (controller or leader) id from which this event is created or updated." + }, + { + "name": "MaxTimestampMs", + "type": "int64", + "versions": "0+", + "about": "Maximum timestamp with in this segment." + }, + { + "name": "EventTimestampMs", + "type": "int64", + "versions": "0+", + "about": "Event timestamp of this segment." + }, + { + "name": "SegmentLeaderEpochs", + "type": "[]SegmentLeaderEpochEntry", + "versions": "0+", + "about": "Leader epochs of this segment.", + "fields": [ + { + "name": "LeaderEpoch", + "type": "int32", + "versions": "0+", + "about": "Leader epoch" + }, + { + "name": "Offset", + "type": "int64", + "versions": "0+", + "about": "Start offset for the leader epoch" + } + ] + }, + { + "name": "SegmentSizeInBytes", + "type": "int32", + "versions": "0+", + "about": "Segment size in bytes" + }, + { + "name": "RemoteLogSegmentState", + "type": "int8", + "versions": "0+", + "about": "State of the remote log segment" + } + ] +} \ No newline at end of file diff --git a/storage/src/main/resources/message/RemoteLogSegmentMetadataUpdateRecord.json b/storage/src/main/resources/message/RemoteLogSegmentMetadataUpdateRecord.json new file mode 100644 index 0000000..24003dc --- /dev/null +++ b/storage/src/main/resources/message/RemoteLogSegmentMetadataUpdateRecord.json @@ -0,0 +1,82 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +{ + "apiKey": 1, + "type": "metadata", + "name": "RemoteLogSegmentMetadataUpdateRecord", + "validVersions": "0", + "flexibleVersions": "0+", + "fields": [ + { + "name": "RemoteLogSegmentId", + "type": "RemoteLogSegmentIdEntry", + "versions": "0+", + "about": "Unique representation of the remote log segment.", + "fields": [ + { + "name": "TopicIdPartition", + "type": "TopicIdPartitionEntry", + "versions": "0+", + "about": "Represents unique topic partition.", + "fields": [ + { + "name": "Name", + "type": "string", + "versions": "0+", + "about": "Topic name." + }, + { + "name": "Id", + "type": "uuid", + "versions": "0+", + "about": "Unique identifier of the topic." + }, + { + "name": "Partition", + "type": "int32", + "versions": "0+", + "about": "Partition number." + } + ] + }, + { + "name": "Id", + "type": "uuid", + "versions": "0+", + "about": "Unique identifier of the remote log segment." + } + ] + }, + { + "name": "BrokerId", + "type": "int32", + "versions": "0+", + "about": "Broker id from which this event is generated." + }, + { + "name": "EventTimestampMs", + "type": "int64", + "versions": "0+", + "about": "Epoch time in milli seconds at which this event is generated." + }, + { + "name": "RemoteLogSegmentState", + "type": "int8", + "versions": "0+", + "about": "State identifier of the remote log segment, which is RemoteLogSegmentState.id()." + } + ] +} \ No newline at end of file diff --git a/storage/src/main/resources/message/RemotePartitionDeleteMetadataRecord.json b/storage/src/main/resources/message/RemotePartitionDeleteMetadataRecord.json new file mode 100644 index 0000000..f5e955b --- /dev/null +++ b/storage/src/main/resources/message/RemotePartitionDeleteMetadataRecord.json @@ -0,0 +1,68 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "apiKey": 2, + "type": "metadata", + "name": "RemotePartitionDeleteMetadataRecord", + "validVersions": "0", + "flexibleVersions": "0+", + "fields": [ + { + "name": "TopicIdPartition", + "type": "TopicIdPartitionEntry", + "versions": "0+", + "about": "Represents unique topic partition.", + "fields": [ + { + "name": "Name", + "type": "string", + "versions": "0+", + "about": "Topic name." + }, + { + "name": "Id", + "type": "uuid", + "versions": "0+", + "about": "Unique identifier of the topic." + }, + { + "name": "Partition", + "type": "int32", + "versions": "0+", + "about": "Partition number." + } + ] + }, + { + "name": "BrokerId", + "type": "int32", + "versions": "0+", + "about": "Broker (controller or leader) id from which this event is created. DELETE_PARTITION_MARKED is sent by the controller. DELETE_PARTITION_STARTED and DELETE_PARTITION_FINISHED are sent by remote log metadata topic partition leader." 
+ }, + { + "name": "EventTimestampMs", + "type": "int64", + "versions": "0+", + "about": "Epoch time in milli seconds at which this event is generated." + }, + { + "name": "RemotePartitionDeleteState", + "type": "int8", + "versions": "0+", + "about": "Deletion state identifier of the remote partition, which is RemotePartitionDeleteState.id()." + } + ] +} \ No newline at end of file diff --git a/storage/src/test/java/org/apache/kafka/server/log/remote/metadata/storage/FileBasedRemoteLogMetadataCacheTest.java b/storage/src/test/java/org/apache/kafka/server/log/remote/metadata/storage/FileBasedRemoteLogMetadataCacheTest.java new file mode 100644 index 0000000..5f77417 --- /dev/null +++ b/storage/src/test/java/org/apache/kafka/server/log/remote/metadata/storage/FileBasedRemoteLogMetadataCacheTest.java @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.server.log.remote.metadata.storage; + +import org.apache.kafka.common.TopicIdPartition; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.Uuid; +import org.apache.kafka.server.log.remote.storage.RemoteLogSegmentId; +import org.apache.kafka.server.log.remote.storage.RemoteLogSegmentMetadata; +import org.apache.kafka.server.log.remote.storage.RemoteLogSegmentMetadataUpdate; +import org.apache.kafka.server.log.remote.storage.RemoteLogSegmentState; +import org.apache.kafka.test.TestUtils; +import org.junit.jupiter.api.Test; + +import java.nio.file.Path; +import java.util.Collections; +import java.util.Optional; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class FileBasedRemoteLogMetadataCacheTest { + + @Test + public void testFileBasedRemoteLogMetadataCacheWithUnreferencedSegments() throws Exception { + TopicIdPartition partition = new TopicIdPartition(Uuid.randomUuid(), new TopicPartition("test", 0)); + int brokerId = 0; + Path path = TestUtils.tempDirectory().toPath(); + + // Create file based metadata cache. + FileBasedRemoteLogMetadataCache cache = new FileBasedRemoteLogMetadataCache(partition, path); + + // Add a segment with start offset as 0 for leader epoch 0. 
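+        // The segment enters the cache in COPY_SEGMENT_STARTED state via addCopyInProgressSegment() and is
+        // returned by remoteLogSegmentMetadata() lookups only after the COPY_SEGMENT_FINISHED update below.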
+ RemoteLogSegmentId segmentId1 = new RemoteLogSegmentId(partition, Uuid.randomUuid()); + RemoteLogSegmentMetadata metadata1 = new RemoteLogSegmentMetadata(segmentId1, + 0, 100, System.currentTimeMillis(), brokerId, System.currentTimeMillis(), + 1024 * 1024, Collections.singletonMap(0, 0L)); + cache.addCopyInProgressSegment(metadata1); + RemoteLogSegmentMetadataUpdate metadataUpdate1 = new RemoteLogSegmentMetadataUpdate(segmentId1, System.currentTimeMillis(), + RemoteLogSegmentState.COPY_SEGMENT_FINISHED, brokerId); + cache.updateRemoteLogSegmentMetadata(metadataUpdate1); + Optional receivedMetadata = cache.remoteLogSegmentMetadata(0, 0L); + assertTrue(receivedMetadata.isPresent()); + assertEquals(metadata1.createWithUpdates(metadataUpdate1), receivedMetadata.get()); + + // Add a new segment with start offset as 0 for leader epoch 0, which should replace the earlier segment. + RemoteLogSegmentId segmentId2 = new RemoteLogSegmentId(partition, Uuid.randomUuid()); + RemoteLogSegmentMetadata metadata2 = new RemoteLogSegmentMetadata(segmentId2, + 0, 900, System.currentTimeMillis(), brokerId, System.currentTimeMillis(), + 1024 * 1024, Collections.singletonMap(0, 0L)); + cache.addCopyInProgressSegment(metadata2); + RemoteLogSegmentMetadataUpdate metadataUpdate2 = new RemoteLogSegmentMetadataUpdate(segmentId2, System.currentTimeMillis(), + RemoteLogSegmentState.COPY_SEGMENT_FINISHED, brokerId); + cache.updateRemoteLogSegmentMetadata(metadataUpdate2); + + // Fetch segment for leader epoch:0 and start offset:0, it should be the newly added segment. + Optional receivedMetadata2 = cache.remoteLogSegmentMetadata(0, 0L); + assertTrue(receivedMetadata2.isPresent()); + assertEquals(metadata2.createWithUpdates(metadataUpdate2), receivedMetadata2.get()); + // Flush the cache to the file. + cache.flushToFile(0, 0L); + + // Create a new cache with loading from the stored path. + FileBasedRemoteLogMetadataCache loadedCache = new FileBasedRemoteLogMetadataCache(partition, path); + + // Fetch segment for leader epoch:0 and start offset:0, it should be metadata2. + // This ensures that the ordering of metadata is taken care after loading from the stored snapshots. + Optional receivedMetadataAfterLoad = loadedCache.remoteLogSegmentMetadata(0, 0L); + assertTrue(receivedMetadataAfterLoad.isPresent()); + assertEquals(metadata2.createWithUpdates(metadataUpdate2), receivedMetadataAfterLoad.get()); + } +} diff --git a/storage/src/test/java/org/apache/kafka/server/log/remote/metadata/storage/RemoteLogMetadataCacheTest.java b/storage/src/test/java/org/apache/kafka/server/log/remote/metadata/storage/RemoteLogMetadataCacheTest.java new file mode 100644 index 0000000..789997f --- /dev/null +++ b/storage/src/test/java/org/apache/kafka/server/log/remote/metadata/storage/RemoteLogMetadataCacheTest.java @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.server.log.remote.metadata.storage; + +import org.apache.kafka.common.TopicIdPartition; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.Uuid; +import org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.server.log.remote.storage.RemoteLogSegmentId; +import org.apache.kafka.server.log.remote.storage.RemoteLogSegmentMetadata; +import org.apache.kafka.server.log.remote.storage.RemoteLogSegmentMetadataUpdate; +import org.apache.kafka.server.log.remote.storage.RemoteLogSegmentState; +import org.apache.kafka.server.log.remote.storage.RemoteResourceNotFoundException; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; + +import java.util.Collections; +import java.util.Map; + +public class RemoteLogMetadataCacheTest { + + private static final TopicIdPartition TP0 = new TopicIdPartition(Uuid.randomUuid(), + new TopicPartition("foo", 0)); + private static final int SEG_SIZE = 1024 * 1024; + private static final int BROKER_ID_0 = 0; + private static final int BROKER_ID_1 = 1; + + private final Time time = new MockTime(1); + + @Test + public void testAPIsWithInvalidArgs() { + RemoteLogMetadataCache cache = new RemoteLogMetadataCache(); + + Assertions.assertThrows(NullPointerException.class, () -> cache.addCopyInProgressSegment(null)); + Assertions.assertThrows(NullPointerException.class, () -> cache.updateRemoteLogSegmentMetadata(null)); + + // Check for invalid state updates to addCopyInProgressSegment method. + for (RemoteLogSegmentState state : RemoteLogSegmentState.values()) { + if (state != RemoteLogSegmentState.COPY_SEGMENT_STARTED) { + RemoteLogSegmentMetadata segmentMetadata = new RemoteLogSegmentMetadata( + new RemoteLogSegmentId(TP0, Uuid.randomUuid()), 0, 100L, + -1L, BROKER_ID_0, time.milliseconds(), SEG_SIZE, Collections.singletonMap(0, 0L)); + RemoteLogSegmentMetadata updatedMetadata = segmentMetadata + .createWithUpdates(new RemoteLogSegmentMetadataUpdate(segmentMetadata.remoteLogSegmentId(), + time.milliseconds(), state, BROKER_ID_1)); + Assertions.assertThrows(IllegalArgumentException.class, () -> + cache.addCopyInProgressSegment(updatedMetadata)); + } + } + + // Check for updating non existing segment-id. + Assertions.assertThrows(RemoteResourceNotFoundException.class, () -> { + RemoteLogSegmentId nonExistingId = new RemoteLogSegmentId(TP0, Uuid.randomUuid()); + cache.updateRemoteLogSegmentMetadata(new RemoteLogSegmentMetadataUpdate(nonExistingId, + time.milliseconds(), RemoteLogSegmentState.DELETE_SEGMENT_STARTED, BROKER_ID_1)); + }); + + // Check for invalid state transition. 
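+        // Valid transitions are COPY_SEGMENT_STARTED -> COPY_SEGMENT_FINISHED -> DELETE_SEGMENT_STARTED -> DELETE_SEGMENT_FINISHED,
+        // so moving straight from COPY_SEGMENT_FINISHED to DELETE_SEGMENT_FINISHED (skipping DELETE_SEGMENT_STARTED) is rejected below.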
+ Assertions.assertThrows(IllegalStateException.class, () -> { + RemoteLogSegmentMetadata segmentMetadata = createSegmentUpdateWithState(cache, Collections.singletonMap(0, 0L), 0, + 100, RemoteLogSegmentState.COPY_SEGMENT_FINISHED); + cache.updateRemoteLogSegmentMetadata(new RemoteLogSegmentMetadataUpdate(segmentMetadata.remoteLogSegmentId(), + time.milliseconds(), RemoteLogSegmentState.DELETE_SEGMENT_FINISHED, BROKER_ID_1)); + }); + } + + private RemoteLogSegmentMetadata createSegmentUpdateWithState(RemoteLogMetadataCache cache, + Map segmentLeaderEpochs, + long startOffset, + long endOffset, + RemoteLogSegmentState state) + throws RemoteResourceNotFoundException { + RemoteLogSegmentId segmentId = new RemoteLogSegmentId(TP0, Uuid.randomUuid()); + RemoteLogSegmentMetadata segmentMetadata = new RemoteLogSegmentMetadata(segmentId, startOffset, endOffset, -1L, + BROKER_ID_0, time.milliseconds(), SEG_SIZE, segmentLeaderEpochs); + cache.addCopyInProgressSegment(segmentMetadata); + + RemoteLogSegmentMetadataUpdate segMetadataUpdate = new RemoteLogSegmentMetadataUpdate(segmentId, + time.milliseconds(), state, BROKER_ID_1); + cache.updateRemoteLogSegmentMetadata(segMetadataUpdate); + + return segmentMetadata.createWithUpdates(segMetadataUpdate); + } + +} \ No newline at end of file diff --git a/storage/src/test/java/org/apache/kafka/server/log/remote/metadata/storage/RemoteLogMetadataSerdeTest.java b/storage/src/test/java/org/apache/kafka/server/log/remote/metadata/storage/RemoteLogMetadataSerdeTest.java new file mode 100644 index 0000000..402d1a2 --- /dev/null +++ b/storage/src/test/java/org/apache/kafka/server/log/remote/metadata/storage/RemoteLogMetadataSerdeTest.java @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
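For contrast with the invalid cases above, a happy-path sketch of the cache API, using illustrative values: a segment is registered while its copy is in progress, marked as finished, and then becomes visible to offset lookups.

    import org.apache.kafka.common.TopicIdPartition;
    import org.apache.kafka.common.TopicPartition;
    import org.apache.kafka.common.Uuid;
    import org.apache.kafka.server.log.remote.metadata.storage.RemoteLogMetadataCache;
    import org.apache.kafka.server.log.remote.storage.RemoteLogSegmentId;
    import org.apache.kafka.server.log.remote.storage.RemoteLogSegmentMetadata;
    import org.apache.kafka.server.log.remote.storage.RemoteLogSegmentMetadataUpdate;
    import org.apache.kafka.server.log.remote.storage.RemoteLogSegmentState;

    import java.util.Collections;
    import java.util.Optional;

    public class RemoteLogMetadataCacheExample {
        public static void main(String[] args) throws Exception {
            TopicIdPartition partition = new TopicIdPartition(Uuid.randomUuid(), new TopicPartition("orders", 0));
            RemoteLogMetadataCache cache = new RemoteLogMetadataCache();

            long now = System.currentTimeMillis();
            RemoteLogSegmentId segmentId = new RemoteLogSegmentId(partition, Uuid.randomUuid());
            RemoteLogSegmentMetadata metadata = new RemoteLogSegmentMetadata(segmentId, 0L, 100L, now, 0,
                    now, 1024 * 1024, Collections.singletonMap(0, 0L));

            // 1. Register the segment while the copy to remote storage is still in progress.
            cache.addCopyInProgressSegment(metadata);

            // 2. Mark the copy as finished; this makes the segment eligible for offset lookups.
            cache.updateRemoteLogSegmentMetadata(new RemoteLogSegmentMetadataUpdate(segmentId, now,
                    RemoteLogSegmentState.COPY_SEGMENT_FINISHED, 0));

            // 3. Look up the segment that contains offset 0 for leader epoch 0.
            Optional<RemoteLogSegmentMetadata> found = cache.remoteLogSegmentMetadata(0, 0L);
            System.out.println(found.isPresent());
        }
    }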
+ */ +package org.apache.kafka.server.log.remote.metadata.storage; + +import org.apache.kafka.common.TopicIdPartition; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.Uuid; +import org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.server.log.remote.metadata.storage.serialization.RemoteLogMetadataSerde; +import org.apache.kafka.server.log.remote.storage.RemoteLogSegmentId; +import org.apache.kafka.server.log.remote.storage.RemoteLogSegmentMetadata; +import org.apache.kafka.server.log.remote.storage.RemoteLogSegmentMetadataUpdate; +import org.apache.kafka.server.log.remote.storage.RemoteLogSegmentState; +import org.apache.kafka.server.log.remote.storage.RemotePartitionDeleteMetadata; +import org.apache.kafka.server.log.remote.storage.RemotePartitionDeleteState; +import org.apache.kafka.server.log.remote.storage.RemoteLogMetadata; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; + +import java.util.HashMap; +import java.util.Map; + +public class RemoteLogMetadataSerdeTest { + + public static final String TOPIC = "foo"; + private static final TopicIdPartition TP0 = new TopicIdPartition(Uuid.randomUuid(), new TopicPartition(TOPIC, 0)); + private final Time time = new MockTime(1); + + @Test + public void testRemoteLogSegmentMetadataSerde() { + RemoteLogSegmentMetadata remoteLogSegmentMetadata = createRemoteLogSegmentMetadata(); + + doTestRemoteLogMetadataSerde(remoteLogSegmentMetadata); + } + + @Test + public void testRemoteLogSegmentMetadataUpdateSerde() { + RemoteLogSegmentMetadataUpdate remoteLogSegmentMetadataUpdate = createRemoteLogSegmentMetadataUpdate(); + + doTestRemoteLogMetadataSerde(remoteLogSegmentMetadataUpdate); + } + + @Test + public void testRemotePartitionDeleteMetadataSerde() { + RemotePartitionDeleteMetadata remotePartitionDeleteMetadata = createRemotePartitionDeleteMetadata(); + + doTestRemoteLogMetadataSerde(remotePartitionDeleteMetadata); + } + + private RemoteLogSegmentMetadata createRemoteLogSegmentMetadata() { + Map segLeaderEpochs = new HashMap<>(); + segLeaderEpochs.put(0, 0L); + segLeaderEpochs.put(1, 20L); + segLeaderEpochs.put(2, 80L); + RemoteLogSegmentId remoteLogSegmentId = new RemoteLogSegmentId(TP0, Uuid.randomUuid()); + return new RemoteLogSegmentMetadata(remoteLogSegmentId, 0L, 100L, -1L, 1, + time.milliseconds(), 1024, segLeaderEpochs); + } + + private RemoteLogSegmentMetadataUpdate createRemoteLogSegmentMetadataUpdate() { + RemoteLogSegmentId remoteLogSegmentId = new RemoteLogSegmentId(TP0, Uuid.randomUuid()); + return new RemoteLogSegmentMetadataUpdate(remoteLogSegmentId, time.milliseconds(), + RemoteLogSegmentState.COPY_SEGMENT_FINISHED, 2); + } + + private RemotePartitionDeleteMetadata createRemotePartitionDeleteMetadata() { + return new RemotePartitionDeleteMetadata(TP0, RemotePartitionDeleteState.DELETE_PARTITION_MARKED, + time.milliseconds(), 0); + } + + private void doTestRemoteLogMetadataSerde(RemoteLogMetadata remoteLogMetadata) { + // Serialize metadata and get the bytes. + RemoteLogMetadataSerde serializer = new RemoteLogMetadataSerde(); + byte[] metadataBytes = serializer.serialize(remoteLogMetadata); + + // Deserialize the bytes and check the RemoteLogMetadata object is as expected. + // Created another RemoteLogMetadataSerde instance to depict the real usecase of serializer and deserializer having their own instances. 
+ RemoteLogMetadataSerde deserializer = new RemoteLogMetadataSerde(); + RemoteLogMetadata deserializedRemoteLogMetadata = deserializer.deserialize(metadataBytes); + Assertions.assertEquals(remoteLogMetadata, deserializedRemoteLogMetadata); + } + + @Test + public void testInvalidRemoteStorageMetadata() { + // Serializing receives an exception as it does not have the expected RemoteLogMetadata registered in serdes. + Assertions.assertThrows(IllegalArgumentException.class, + () -> new RemoteLogMetadataSerde().serialize(new InvalidRemoteLogMetadata(1, time.milliseconds()))); + } + + private static class InvalidRemoteLogMetadata extends RemoteLogMetadata { + public InvalidRemoteLogMetadata(int brokerId, long eventTimestampMs) { + super(brokerId, eventTimestampMs); + } + + @Override + public TopicIdPartition topicIdPartition() { + throw new UnsupportedOperationException(); + } + } + +} \ No newline at end of file diff --git a/storage/src/test/java/org/apache/kafka/server/log/remote/metadata/storage/RemoteLogMetadataSnapshotFileTest.java b/storage/src/test/java/org/apache/kafka/server/log/remote/metadata/storage/RemoteLogMetadataSnapshotFileTest.java new file mode 100644 index 0000000..1b46028 --- /dev/null +++ b/storage/src/test/java/org/apache/kafka/server/log/remote/metadata/storage/RemoteLogMetadataSnapshotFileTest.java @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.server.log.remote.metadata.storage; + +import org.apache.kafka.common.Uuid; +import org.apache.kafka.server.log.remote.storage.RemoteLogSegmentState; +import org.apache.kafka.test.TestUtils; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; + +import java.io.File; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Optional; + +public class RemoteLogMetadataSnapshotFileTest { + + @Test + public void testEmptyCommittedLogMetadataFile() throws Exception { + File metadataStoreDir = TestUtils.tempDirectory("_rlmm_committed"); + RemoteLogMetadataSnapshotFile snapshotFile = new RemoteLogMetadataSnapshotFile(metadataStoreDir.toPath()); + + // There should be an empty snapshot as nothing is written into it. + Assertions.assertFalse(snapshotFile.read().isPresent()); + } + + @Test + public void testEmptySnapshotWithCommittedLogMetadataFile() throws Exception { + File metadataStoreDir = TestUtils.tempDirectory("_rlmm_committed"); + RemoteLogMetadataSnapshotFile snapshotFile = new RemoteLogMetadataSnapshotFile(metadataStoreDir.toPath()); + + snapshotFile.write(new RemoteLogMetadataSnapshotFile.Snapshot(0, 0L, Collections.emptyList())); + + // There should be an empty snapshot as the written snapshot did not have any remote log segment metadata. 
+        Assertions.assertTrue(snapshotFile.read().isPresent());
+        Assertions.assertTrue(snapshotFile.read().get().remoteLogSegmentMetadataSnapshots().isEmpty());
+    }
+
+    @Test
+    public void testWriteReadCommittedLogMetadataFile() throws Exception {
+        File metadataStoreDir = TestUtils.tempDirectory("_rlmm_committed");
+        RemoteLogMetadataSnapshotFile snapshotFile = new RemoteLogMetadataSnapshotFile(metadataStoreDir.toPath());
+
+        List<RemoteLogSegmentMetadataSnapshot> remoteLogSegmentMetadatas = new ArrayList<>();
+        long startOffset = 0;
+        for (int i = 0; i < 100; i++) {
+            long endOffset = startOffset + 100L;
+            remoteLogSegmentMetadatas.add(
+                    new RemoteLogSegmentMetadataSnapshot(Uuid.randomUuid(), startOffset, endOffset,
+                            System.currentTimeMillis(), 1, 100, 1024,
+                            RemoteLogSegmentState.COPY_SEGMENT_FINISHED, Collections.singletonMap(i, startOffset)));
+            startOffset = endOffset + 1;
+        }
+
+        RemoteLogMetadataSnapshotFile.Snapshot snapshot = new RemoteLogMetadataSnapshotFile.Snapshot(0, 120,
+                remoteLogSegmentMetadatas);
+        snapshotFile.write(snapshot);
+
+        Optional<RemoteLogMetadataSnapshotFile.Snapshot> maybeReadSnapshot = snapshotFile.read();
+        Assertions.assertTrue(maybeReadSnapshot.isPresent());
+
+        Assertions.assertEquals(snapshot, maybeReadSnapshot.get());
+        Assertions.assertEquals(new HashSet<>(snapshot.remoteLogSegmentMetadataSnapshots()),
+                new HashSet<>(maybeReadSnapshot.get().remoteLogSegmentMetadataSnapshots()));
+    }
+}
diff --git a/storage/src/test/java/org/apache/kafka/server/log/remote/metadata/storage/RemoteLogMetadataTransformTest.java b/storage/src/test/java/org/apache/kafka/server/log/remote/metadata/storage/RemoteLogMetadataTransformTest.java
new file mode 100644
index 0000000..87e7683
--- /dev/null
+++ b/storage/src/test/java/org/apache/kafka/server/log/remote/metadata/storage/RemoteLogMetadataTransformTest.java
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +package org.apache.kafka.server.log.remote.metadata.storage; + +import org.apache.kafka.common.TopicIdPartition; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.Uuid; +import org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.server.common.ApiMessageAndVersion; +import org.apache.kafka.server.log.remote.metadata.storage.serialization.RemoteLogSegmentMetadataTransform; +import org.apache.kafka.server.log.remote.metadata.storage.serialization.RemoteLogSegmentMetadataUpdateTransform; +import org.apache.kafka.server.log.remote.metadata.storage.serialization.RemotePartitionDeleteMetadataTransform; +import org.apache.kafka.server.log.remote.storage.RemoteLogSegmentId; +import org.apache.kafka.server.log.remote.storage.RemoteLogSegmentMetadata; +import org.apache.kafka.server.log.remote.storage.RemoteLogSegmentMetadataUpdate; +import org.apache.kafka.server.log.remote.storage.RemoteLogSegmentState; +import org.apache.kafka.server.log.remote.storage.RemotePartitionDeleteMetadata; +import org.apache.kafka.server.log.remote.storage.RemotePartitionDeleteState; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; + +import java.util.Collections; + +public class RemoteLogMetadataTransformTest { + private static final TopicIdPartition TP0 = new TopicIdPartition(Uuid.randomUuid(), new TopicPartition("foo", 0)); + private final Time time = new MockTime(1); + + @Test + public void testRemoteLogSegmentMetadataTransform() { + RemoteLogSegmentMetadataTransform metadataTransform = new RemoteLogSegmentMetadataTransform(); + + RemoteLogSegmentMetadata metadata = createRemoteLogSegmentMetadata(); + ApiMessageAndVersion apiMessageAndVersion = metadataTransform.toApiMessageAndVersion(metadata); + RemoteLogSegmentMetadata remoteLogSegmentMetadataFromRecord = metadataTransform + .fromApiMessageAndVersion(apiMessageAndVersion); + + Assertions.assertEquals(metadata, remoteLogSegmentMetadataFromRecord); + } + + @Test + public void testRemoteLogSegmentMetadataUpdateTransform() { + RemoteLogSegmentMetadataUpdateTransform metadataUpdateTransform = new RemoteLogSegmentMetadataUpdateTransform(); + + RemoteLogSegmentMetadataUpdate metadataUpdate = + new RemoteLogSegmentMetadataUpdate(new RemoteLogSegmentId(TP0, Uuid.randomUuid()), time.milliseconds(), + RemoteLogSegmentState.COPY_SEGMENT_FINISHED, 1); + ApiMessageAndVersion apiMessageAndVersion = metadataUpdateTransform.toApiMessageAndVersion(metadataUpdate); + RemoteLogSegmentMetadataUpdate metadataUpdateFromRecord = metadataUpdateTransform.fromApiMessageAndVersion(apiMessageAndVersion); + + Assertions.assertEquals(metadataUpdate, metadataUpdateFromRecord); + } + + private RemoteLogSegmentMetadata createRemoteLogSegmentMetadata() { + RemoteLogSegmentId remoteLogSegmentId = new RemoteLogSegmentId(TP0, Uuid.randomUuid()); + return new RemoteLogSegmentMetadata(remoteLogSegmentId, 0L, 100L, -1L, 1, + time.milliseconds(), 1024, Collections.singletonMap(0, 0L)); + } + + @Test + public void testRemoteLogPartitionMetadataTransform() { + RemotePartitionDeleteMetadataTransform transform = new RemotePartitionDeleteMetadataTransform(); + + RemotePartitionDeleteMetadata partitionDeleteMetadata + = new RemotePartitionDeleteMetadata(TP0, RemotePartitionDeleteState.DELETE_PARTITION_STARTED, time.milliseconds(), 1); + ApiMessageAndVersion apiMessageAndVersion = transform.toApiMessageAndVersion(partitionDeleteMetadata); + RemotePartitionDeleteMetadata 
partitionDeleteMetadataFromRecord = transform.fromApiMessageAndVersion(apiMessageAndVersion); + + Assertions.assertEquals(partitionDeleteMetadata, partitionDeleteMetadataFromRecord); + } +} diff --git a/storage/src/test/java/org/apache/kafka/server/log/remote/metadata/storage/RemoteLogSegmentLifecycleManager.java b/storage/src/test/java/org/apache/kafka/server/log/remote/metadata/storage/RemoteLogSegmentLifecycleManager.java new file mode 100644 index 0000000..b887059 --- /dev/null +++ b/storage/src/test/java/org/apache/kafka/server/log/remote/metadata/storage/RemoteLogSegmentLifecycleManager.java @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.server.log.remote.metadata.storage; + +import org.apache.kafka.common.TopicIdPartition; +import org.apache.kafka.server.log.remote.storage.RemoteLogSegmentMetadata; +import org.apache.kafka.server.log.remote.storage.RemoteLogSegmentMetadataUpdate; +import org.apache.kafka.server.log.remote.storage.RemoteStorageException; + +import java.io.Closeable; +import java.io.IOException; +import java.util.Iterator; +import java.util.Optional; + +/** + * This interface defines the lifecycle methods for {@code RemoteLogSegmentMetadata}. {@link RemoteLogSegmentLifecycleTest} tests + * different implementations of this interface. This is responsible for managing all the segments for a given {@code topicIdPartition} + * registered with {@link #initialize(TopicIdPartition)}. + * + * @see org.apache.kafka.server.log.remote.metadata.storage.RemoteLogSegmentLifecycleTest.RemoteLogMetadataCacheWrapper + * @see org.apache.kafka.server.log.remote.metadata.storage.RemoteLogSegmentLifecycleTest.TopicBasedRemoteLogMetadataManagerWrapper + */ +public interface RemoteLogSegmentLifecycleManager extends Closeable { + + /** + * Initialize the resources for this instance and register the given {@code topicIdPartition}. + * + * @param topicIdPartition topic partition to be registered with this instance. 
+     */
+    default void initialize(TopicIdPartition topicIdPartition) {
+    }
+
+    @Override
+    default void close() throws IOException {
+    }
+
+    void addRemoteLogSegmentMetadata(RemoteLogSegmentMetadata segmentMetadata) throws RemoteStorageException;
+
+    void updateRemoteLogSegmentMetadata(RemoteLogSegmentMetadataUpdate segmentMetadataUpdate) throws RemoteStorageException;
+
+    Optional<Long> highestOffsetForEpoch(int epoch) throws RemoteStorageException;
+
+    Optional<RemoteLogSegmentMetadata> remoteLogSegmentMetadata(int leaderEpoch,
+                                                                long offset) throws RemoteStorageException;
+
+    Iterator<RemoteLogSegmentMetadata> listRemoteLogSegments(int leaderEpoch) throws RemoteStorageException;
+
+    Iterator<RemoteLogSegmentMetadata> listAllRemoteLogSegments() throws RemoteStorageException;
+}
diff --git a/storage/src/test/java/org/apache/kafka/server/log/remote/metadata/storage/RemoteLogSegmentLifecycleTest.java b/storage/src/test/java/org/apache/kafka/server/log/remote/metadata/storage/RemoteLogSegmentLifecycleTest.java
new file mode 100644
index 0000000..8272b8d
--- /dev/null
+++ b/storage/src/test/java/org/apache/kafka/server/log/remote/metadata/storage/RemoteLogSegmentLifecycleTest.java
@@ -0,0 +1,526 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +package org.apache.kafka.server.log.remote.metadata.storage; + +import org.apache.kafka.common.TopicIdPartition; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.Uuid; +import org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.server.log.remote.storage.RemoteLogSegmentId; +import org.apache.kafka.server.log.remote.storage.RemoteLogSegmentMetadata; +import org.apache.kafka.server.log.remote.storage.RemoteLogSegmentMetadataUpdate; +import org.apache.kafka.server.log.remote.storage.RemoteLogSegmentState; +import org.apache.kafka.server.log.remote.storage.RemoteResourceNotFoundException; +import org.apache.kafka.server.log.remote.storage.RemoteStorageException; +import org.apache.kafka.test.TestUtils; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.Set; +import java.util.stream.Stream; + +public class RemoteLogSegmentLifecycleTest { + private static final Logger log = LoggerFactory.getLogger(RemoteLogSegmentLifecycleTest.class); + + private static final int SEG_SIZE = 1024 * 1024; + private static final int BROKER_ID_0 = 0; + private static final int BROKER_ID_1 = 1; + + private final TopicIdPartition topicIdPartition = new TopicIdPartition(Uuid.randomUuid(), new TopicPartition("foo", 0)); + private final Time time = new MockTime(1); + + @ParameterizedTest(name = "remoteLogSegmentLifecycleManager = {0}") + @MethodSource("remoteLogSegmentLifecycleManagers") + public void testRemoteLogSegmentLifeCycle(RemoteLogSegmentLifecycleManager remoteLogSegmentLifecycleManager) throws Exception { + try { + remoteLogSegmentLifecycleManager.initialize(topicIdPartition); + + // segment 0 + // offsets: [0-100] + // leader epochs (0,0), (1,20), (2,80) + Map segment0LeaderEpochs = new HashMap<>(); + segment0LeaderEpochs.put(0, 0L); + segment0LeaderEpochs.put(1, 20L); + segment0LeaderEpochs.put(2, 80L); + RemoteLogSegmentId segment0Id = new RemoteLogSegmentId(topicIdPartition, Uuid.randomUuid()); + RemoteLogSegmentMetadata segment0Metadata = new RemoteLogSegmentMetadata(segment0Id, 0L, 100L, + -1L, BROKER_ID_0, time.milliseconds(), SEG_SIZE, + segment0LeaderEpochs); + remoteLogSegmentLifecycleManager.addRemoteLogSegmentMetadata(segment0Metadata); + + // We should not get this as the segment is still getting copied and it is not yet considered successful until + // it reaches RemoteLogSegmentState.COPY_SEGMENT_FINISHED. + Assertions.assertFalse(remoteLogSegmentLifecycleManager.remoteLogSegmentMetadata(40, 1).isPresent()); + + // Check that these leader epochs are not to be considered for highestOffsetForEpoch API as they are still getting copied. 
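+            // highestOffsetForEpoch declares RemoteStorageException, so the lambda below has to catch it and turn it
+            // into an explicit test failure.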
+ Stream.of(0, 1, 2).forEach(epoch -> { + try { + Assertions.assertFalse(remoteLogSegmentLifecycleManager.highestOffsetForEpoch(epoch).isPresent()); + } catch (RemoteStorageException e) { + Assertions.fail(e); + } + }); + + RemoteLogSegmentMetadataUpdate segment0Update = new RemoteLogSegmentMetadataUpdate( + segment0Id, time.milliseconds(), RemoteLogSegmentState.COPY_SEGMENT_FINISHED, BROKER_ID_1); + remoteLogSegmentLifecycleManager.updateRemoteLogSegmentMetadata(segment0Update); + RemoteLogSegmentMetadata expectedSegment0Metadata = segment0Metadata.createWithUpdates(segment0Update); + + // segment 1 + // offsets: [101 - 200] + // no changes in leadership with in this segment + // leader epochs (2, 101) + Map segment1LeaderEpochs = Collections.singletonMap(2, 101L); + RemoteLogSegmentMetadata segment1Metadata = createSegmentUpdateWithState(remoteLogSegmentLifecycleManager, segment1LeaderEpochs, 101L, + 200L, + RemoteLogSegmentState.COPY_SEGMENT_FINISHED); + + // segment 2 + // offsets: [201 - 300] + // moved to epoch 3 in between + // leader epochs (2, 201), (3, 240) + Map segment2LeaderEpochs = new HashMap<>(); + segment2LeaderEpochs.put(2, 201L); + segment2LeaderEpochs.put(3, 240L); + RemoteLogSegmentMetadata segment2Metadata = createSegmentUpdateWithState(remoteLogSegmentLifecycleManager, segment2LeaderEpochs, 201L, + 300L, + RemoteLogSegmentState.COPY_SEGMENT_FINISHED); + + // segment 3 + // offsets: [250 - 400] + // leader epochs (3, 250), (4, 370) + Map segment3LeaderEpochs = new HashMap<>(); + segment3LeaderEpochs.put(3, 250L); + segment3LeaderEpochs.put(4, 370L); + RemoteLogSegmentMetadata segment3Metadata = createSegmentUpdateWithState(remoteLogSegmentLifecycleManager, segment3LeaderEpochs, 250L, + 400L, + RemoteLogSegmentState.COPY_SEGMENT_FINISHED); + + ////////////////////////////////////////////////////////////////////////////////////////// + // Four segments are added with different boundaries and leader epochs. + // Search for cache.remoteLogSegmentMetadata(leaderEpoch, offset) for different + // epochs and offsets + ////////////////////////////////////////////////////////////////////////////////////////// + + HashMap expectedEpochOffsetToSegmentMetadata = new HashMap<>(); + // Existing metadata entries. + expectedEpochOffsetToSegmentMetadata.put(new EpochOffset(1, 40), expectedSegment0Metadata); + expectedEpochOffsetToSegmentMetadata.put(new EpochOffset(2, 110), segment1Metadata); + expectedEpochOffsetToSegmentMetadata.put(new EpochOffset(3, 240), segment2Metadata); + expectedEpochOffsetToSegmentMetadata.put(new EpochOffset(3, 250), segment3Metadata); + expectedEpochOffsetToSegmentMetadata.put(new EpochOffset(4, 375), segment3Metadata); + + // Non existing metadata entries. + // Search for offset 110, epoch 1, and it should not exist. + expectedEpochOffsetToSegmentMetadata.put(new EpochOffset(1, 110), null); + // Search for non existing offset 401, epoch 4. + expectedEpochOffsetToSegmentMetadata.put(new EpochOffset(4, 401), null); + // Search for non existing epoch 5. 
+ expectedEpochOffsetToSegmentMetadata.put(new EpochOffset(5, 301), null); + + for (Map.Entry entry : expectedEpochOffsetToSegmentMetadata.entrySet()) { + EpochOffset epochOffset = entry.getKey(); + Optional segmentMetadata = remoteLogSegmentLifecycleManager + .remoteLogSegmentMetadata(epochOffset.epoch, epochOffset.offset); + RemoteLogSegmentMetadata expectedSegmentMetadata = entry.getValue(); + log.debug("Searching for {} , result: {}, expected: {} ", epochOffset, segmentMetadata, + expectedSegmentMetadata); + if (expectedSegmentMetadata != null) { + Assertions.assertEquals(Optional.of(expectedSegmentMetadata), segmentMetadata); + } else { + Assertions.assertFalse(segmentMetadata.isPresent()); + } + } + + // Update segment with state as DELETE_SEGMENT_STARTED. + // It should not be available when we search for that segment. + remoteLogSegmentLifecycleManager + .updateRemoteLogSegmentMetadata(new RemoteLogSegmentMetadataUpdate(expectedSegment0Metadata.remoteLogSegmentId(), + time.milliseconds(), + RemoteLogSegmentState.DELETE_SEGMENT_STARTED, + BROKER_ID_1)); + Assertions.assertFalse(remoteLogSegmentLifecycleManager.remoteLogSegmentMetadata(0, 10).isPresent()); + + // Update segment with state as DELETE_SEGMENT_FINISHED. + // It should not be available when we search for that segment. + remoteLogSegmentLifecycleManager + .updateRemoteLogSegmentMetadata(new RemoteLogSegmentMetadataUpdate(expectedSegment0Metadata.remoteLogSegmentId(), + time.milliseconds(), + RemoteLogSegmentState.DELETE_SEGMENT_FINISHED, + BROKER_ID_1)); + Assertions.assertFalse(remoteLogSegmentLifecycleManager.remoteLogSegmentMetadata(0, 10).isPresent()); + + ////////////////////////////////////////////////////////////////////////////////////////// + // Search for cache.highestLogOffset(leaderEpoch) for all the leader epochs + ////////////////////////////////////////////////////////////////////////////////////////// + + Map expectedEpochToHighestOffset = new HashMap<>(); + expectedEpochToHighestOffset.put(0, 19L); + expectedEpochToHighestOffset.put(1, 79L); + expectedEpochToHighestOffset.put(2, 239L); + expectedEpochToHighestOffset.put(3, 369L); + expectedEpochToHighestOffset.put(4, 400L); + + for (Map.Entry entry : expectedEpochToHighestOffset.entrySet()) { + Integer epoch = entry.getKey(); + Long expectedOffset = entry.getValue(); + Optional offset = remoteLogSegmentLifecycleManager.highestOffsetForEpoch(epoch); + log.debug("Fetching highest offset for epoch: {} , returned: {} , expected: {}", epoch, offset, expectedOffset); + Assertions.assertEquals(Optional.of(expectedOffset), offset); + } + + // Search for non existing leader epoch + Optional highestOffsetForEpoch5 = remoteLogSegmentLifecycleManager.highestOffsetForEpoch(5); + Assertions.assertFalse(highestOffsetForEpoch5.isPresent()); + } finally { + Utils.closeQuietly(remoteLogSegmentLifecycleManager, "RemoteLogSegmentLifecycleManager"); + } + } + + private RemoteLogSegmentMetadata createSegmentUpdateWithState(RemoteLogSegmentLifecycleManager remoteLogSegmentLifecycleManager, + Map segmentLeaderEpochs, + long startOffset, + long endOffset, + RemoteLogSegmentState state) + throws RemoteStorageException { + RemoteLogSegmentId segmentId = new RemoteLogSegmentId(topicIdPartition, Uuid.randomUuid()); + RemoteLogSegmentMetadata segmentMetadata = new RemoteLogSegmentMetadata(segmentId, startOffset, endOffset, -1L, BROKER_ID_0, + time.milliseconds(), SEG_SIZE, segmentLeaderEpochs); + remoteLogSegmentLifecycleManager.addRemoteLogSegmentMetadata(segmentMetadata); + + 
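+        // The metadata added above is in COPY_SEGMENT_STARTED state at this point; the update below transitions the
+        // segment to the state requested by the caller.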
RemoteLogSegmentMetadataUpdate segMetadataUpdate = new RemoteLogSegmentMetadataUpdate(segmentId, time.milliseconds(), state, BROKER_ID_1); + remoteLogSegmentLifecycleManager.updateRemoteLogSegmentMetadata(segMetadataUpdate); + + return segmentMetadata.createWithUpdates(segMetadataUpdate); + } + + private static class EpochOffset { + final int epoch; + final long offset; + + private EpochOffset(int epoch, + long offset) { + this.epoch = epoch; + this.offset = offset; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + EpochOffset that = (EpochOffset) o; + return epoch == that.epoch && offset == that.offset; + } + + @Override + public int hashCode() { + return Objects.hash(epoch, offset); + } + + @Override + public String toString() { + return "EpochOffset{" + + "epoch=" + epoch + + ", offset=" + offset + + '}'; + } + } + + private static Collection remoteLogSegmentLifecycleManagers() { + return Arrays.asList(Arguments.of(new RemoteLogMetadataCacheWrapper()), + Arguments.of(new TopicBasedRemoteLogMetadataManagerWrapper())); + } + + private void checkListSegments(RemoteLogSegmentLifecycleManager remoteLogSegmentLifecycleManager, + int leaderEpoch, + RemoteLogSegmentMetadata expectedSegment) + throws RemoteStorageException { + // cache.listRemoteLogSegments(leaderEpoch) should contain the above segment. + Iterator segmentsIter = remoteLogSegmentLifecycleManager.listRemoteLogSegments(leaderEpoch); + Assertions.assertTrue(segmentsIter.hasNext() && Objects.equals(segmentsIter.next(), expectedSegment)); + + // cache.listAllRemoteLogSegments() should contain the above segment. + Iterator allSegmentsIter = remoteLogSegmentLifecycleManager.listAllRemoteLogSegments(); + Assertions.assertTrue(allSegmentsIter.hasNext() && Objects.equals(allSegmentsIter.next(), expectedSegment)); + } + + @ParameterizedTest(name = "remoteLogSegmentLifecycleManager = {0}") + @MethodSource("remoteLogSegmentLifecycleManagers") + public void testCacheSegmentWithCopySegmentStartedState(RemoteLogSegmentLifecycleManager remoteLogSegmentLifecycleManager) throws Exception { + + try { + remoteLogSegmentLifecycleManager.initialize(topicIdPartition); + + // Create a segment with state COPY_SEGMENT_STARTED, and check for searching that segment and listing the + // segments. + RemoteLogSegmentId segmentId = new RemoteLogSegmentId(topicIdPartition, Uuid.randomUuid()); + RemoteLogSegmentMetadata segmentMetadata = new RemoteLogSegmentMetadata(segmentId, 0L, 50L, -1L, BROKER_ID_0, + time.milliseconds(), SEG_SIZE, Collections.singletonMap(0, 0L)); + remoteLogSegmentLifecycleManager.addRemoteLogSegmentMetadata(segmentMetadata); + + // This segment should not be available as the state is not reached to COPY_SEGMENT_FINISHED. + Optional segMetadataForOffset0Epoch0 = remoteLogSegmentLifecycleManager.remoteLogSegmentMetadata(0, 0); + Assertions.assertFalse(segMetadataForOffset0Epoch0.isPresent()); + + // cache.listRemoteLogSegments APIs should contain the above segment. 
+ checkListSegments(remoteLogSegmentLifecycleManager, 0, segmentMetadata); + } finally { + Utils.closeQuietly(remoteLogSegmentLifecycleManager, "RemoteLogSegmentLifecycleManager"); + } + } + + @ParameterizedTest(name = "remoteLogSegmentLifecycleManager = {0}") + @MethodSource("remoteLogSegmentLifecycleManagers") + public void testCacheSegmentWithCopySegmentFinishedState(RemoteLogSegmentLifecycleManager remoteLogSegmentLifecycleManager) throws Exception { + try { + remoteLogSegmentLifecycleManager.initialize(topicIdPartition); + + // Create a segment and move it to state COPY_SEGMENT_FINISHED. and check for searching that segment and + // listing the segments. + RemoteLogSegmentMetadata segmentMetadata = createSegmentUpdateWithState(remoteLogSegmentLifecycleManager, + Collections.singletonMap(0, 101L), + 101L, 200L, RemoteLogSegmentState.COPY_SEGMENT_FINISHED); + + // Search should return the above segment. + Optional segMetadataForOffset150 = remoteLogSegmentLifecycleManager.remoteLogSegmentMetadata(0, 150); + Assertions.assertEquals(Optional.of(segmentMetadata), segMetadataForOffset150); + + // cache.listRemoteLogSegments should contain the above segments. + checkListSegments(remoteLogSegmentLifecycleManager, 0, segmentMetadata); + } finally { + Utils.closeQuietly(remoteLogSegmentLifecycleManager, "RemoteLogSegmentLifecycleManager"); + } + } + + @ParameterizedTest(name = "remoteLogSegmentLifecycleManager = {0}") + @MethodSource("remoteLogSegmentLifecycleManagers") + public void testCacheSegmentWithDeleteSegmentStartedState(RemoteLogSegmentLifecycleManager remoteLogSegmentLifecycleManager) throws Exception { + try { + remoteLogSegmentLifecycleManager.initialize(topicIdPartition); + + // Create a segment and move it to state DELETE_SEGMENT_STARTED, and check for searching that segment and + // listing the segments. + RemoteLogSegmentMetadata segmentMetadata = createSegmentUpdateWithState(remoteLogSegmentLifecycleManager, + Collections.singletonMap(0, 201L), + 201L, 300L, RemoteLogSegmentState.DELETE_SEGMENT_STARTED); + + // Search should not return the above segment as their leader epoch state is cleared. + Optional segmentMetadataForOffset250Epoch0 = remoteLogSegmentLifecycleManager.remoteLogSegmentMetadata(0, 250); + Assertions.assertFalse(segmentMetadataForOffset250Epoch0.isPresent()); + + checkListSegments(remoteLogSegmentLifecycleManager, 0, segmentMetadata); + } finally { + Utils.closeQuietly(remoteLogSegmentLifecycleManager, "RemoteLogSegmentLifecycleManager"); + } + } + + @ParameterizedTest(name = "remoteLogSegmentLifecycleManager = {0}") + @MethodSource("remoteLogSegmentLifecycleManagers") + public void testCacheSegmentsWithDeleteSegmentFinishedState(RemoteLogSegmentLifecycleManager remoteLogSegmentLifecycleManager) throws Exception { + try { + remoteLogSegmentLifecycleManager.initialize(topicIdPartition); + + // Create a segment and move it to state DELETE_SEGMENT_FINISHED, and check for searching that segment and + // listing the segments. + RemoteLogSegmentMetadata segmentMetadata = createSegmentUpdateWithState(remoteLogSegmentLifecycleManager, + Collections.singletonMap(0, 301L), + 301L, 400L, RemoteLogSegmentState.DELETE_SEGMENT_STARTED); + + // Search should not return the above segment as their leader epoch state is cleared. 
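+            // Listing APIs still include a segment in DELETE_SEGMENT_STARTED state (see the previous test); it is only
+            // dropped from the listings after the DELETE_SEGMENT_FINISHED update further below.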
+ Assertions.assertFalse(remoteLogSegmentLifecycleManager.remoteLogSegmentMetadata(0, 350).isPresent()); + + RemoteLogSegmentMetadataUpdate segmentMetadataUpdate = new RemoteLogSegmentMetadataUpdate(segmentMetadata.remoteLogSegmentId(), + time.milliseconds(), + RemoteLogSegmentState.DELETE_SEGMENT_FINISHED, + BROKER_ID_1); + remoteLogSegmentLifecycleManager.updateRemoteLogSegmentMetadata(segmentMetadataUpdate); + + // listRemoteLogSegments(0) and listRemoteLogSegments() should not contain the above segment. + Assertions.assertFalse(remoteLogSegmentLifecycleManager.listRemoteLogSegments(0).hasNext()); + Assertions.assertFalse(remoteLogSegmentLifecycleManager.listAllRemoteLogSegments().hasNext()); + } finally { + Utils.closeQuietly(remoteLogSegmentLifecycleManager, "RemoteLogSegmentLifecycleManager"); + } + } + + @ParameterizedTest(name = "remoteLogSegmentLifecycleManager = {0}") + @MethodSource("remoteLogSegmentLifecycleManagers") + public void testCacheListSegments(RemoteLogSegmentLifecycleManager remoteLogSegmentLifecycleManager) throws Exception { + try { + remoteLogSegmentLifecycleManager.initialize(topicIdPartition); + + // Create a few segments and add them to the cache. + RemoteLogSegmentMetadata segment0 = createSegmentUpdateWithState(remoteLogSegmentLifecycleManager, Collections.singletonMap(0, 0L), 0, + 100, + RemoteLogSegmentState.COPY_SEGMENT_FINISHED); + RemoteLogSegmentMetadata segment1 = createSegmentUpdateWithState(remoteLogSegmentLifecycleManager, Collections.singletonMap(0, 101L), 101, + 200, + RemoteLogSegmentState.COPY_SEGMENT_FINISHED); + Map segment2LeaderEpochs = new HashMap<>(); + segment2LeaderEpochs.put(0, 201L); + segment2LeaderEpochs.put(1, 301L); + RemoteLogSegmentMetadata segment2 = createSegmentUpdateWithState(remoteLogSegmentLifecycleManager, segment2LeaderEpochs, 201, 400, + RemoteLogSegmentState.COPY_SEGMENT_FINISHED); + + // listRemoteLogSegments(0) and listAllRemoteLogSegments() should contain all the above segments. + List expectedSegmentsForEpoch0 = Arrays.asList(segment0, segment1, segment2); + Assertions.assertTrue(TestUtils.sameElementsWithOrder(remoteLogSegmentLifecycleManager.listRemoteLogSegments(0), + expectedSegmentsForEpoch0.iterator())); + Assertions.assertTrue(TestUtils.sameElementsWithoutOrder(remoteLogSegmentLifecycleManager.listAllRemoteLogSegments(), + expectedSegmentsForEpoch0.iterator())); + + // listRemoteLogSegments(1) should contain only segment2. + List expectedSegmentsForEpoch1 = Collections.singletonList(segment2); + Assertions.assertTrue(TestUtils.sameElementsWithOrder(remoteLogSegmentLifecycleManager.listRemoteLogSegments(1), + expectedSegmentsForEpoch1.iterator())); + } finally { + Utils.closeQuietly(remoteLogSegmentLifecycleManager, "RemoteLogSegmentLifecycleManager"); + } + } + + /** + * This is a wrapper with {@link TopicBasedRemoteLogMetadataManager} implementing {@link RemoteLogSegmentLifecycleManager}. + * This is passed to {@link #testRemoteLogSegmentLifeCycle(RemoteLogSegmentLifecycleManager)} to test + * {@code RemoteLogMetadataCache} for several lifecycle operations. + *

                + * This starts a Kafka cluster with {@link #initialize(Set, boolean)} )} with {@link #brokerCount()} no of servers. It also + * creates the remote log metadata topic required for {@code TopicBasedRemoteLogMetadataManager}. This cluster will + * be stopped by invoking {@link #close()}. + */ + static class TopicBasedRemoteLogMetadataManagerWrapper extends TopicBasedRemoteLogMetadataManagerHarness implements RemoteLogSegmentLifecycleManager { + + private TopicIdPartition topicIdPartition; + + @Override + public synchronized void initialize(TopicIdPartition topicIdPartition) { + this.topicIdPartition = topicIdPartition; + super.initialize(Collections.singleton(topicIdPartition), true); + } + + @Override + public void addRemoteLogSegmentMetadata(RemoteLogSegmentMetadata segmentMetadata) throws RemoteStorageException { + try { + // Wait until the segment is added successfully. + remoteLogMetadataManager().addRemoteLogSegmentMetadata(segmentMetadata).get(); + } catch (Exception e) { + throw new RemoteStorageException(e); + } + } + + @Override + public void updateRemoteLogSegmentMetadata(RemoteLogSegmentMetadataUpdate segmentMetadataUpdate) throws RemoteStorageException { + try { + // Wait until the segment is updated successfully. + remoteLogMetadataManager().updateRemoteLogSegmentMetadata(segmentMetadataUpdate).get(); + } catch (Exception e) { + throw new RemoteStorageException(e); + } + } + + @Override + public Optional highestOffsetForEpoch(int leaderEpoch) throws RemoteStorageException { + return remoteLogMetadataManager().highestOffsetForEpoch(topicIdPartition, leaderEpoch); + } + + @Override + public Optional remoteLogSegmentMetadata(int leaderEpoch, + long offset) throws RemoteStorageException { + return remoteLogMetadataManager().remoteLogSegmentMetadata(topicIdPartition, leaderEpoch, offset); + } + + @Override + public Iterator listRemoteLogSegments(int leaderEpoch) throws RemoteStorageException { + return remoteLogMetadataManager().listRemoteLogSegments(topicIdPartition, leaderEpoch); + } + + @Override + public Iterator listAllRemoteLogSegments() throws RemoteStorageException { + return remoteLogMetadataManager().listRemoteLogSegments(topicIdPartition); + } + + @Override + public void close() throws IOException { + tearDown(); + } + + @Override + public int brokerCount() { + return 3; + } + } + + /** + * This is a wrapper with {@link RemoteLogMetadataCache} implementing {@link RemoteLogSegmentLifecycleManager}. + * This is passed to {@link #testRemoteLogSegmentLifeCycle(RemoteLogSegmentLifecycleManager)} to test + * {@code RemoteLogMetadataCache} for several lifecycle operations. 
+ */ + static class RemoteLogMetadataCacheWrapper implements RemoteLogSegmentLifecycleManager { + + private final RemoteLogMetadataCache metadataCache = new RemoteLogMetadataCache(); + + @Override + public void updateRemoteLogSegmentMetadata(RemoteLogSegmentMetadataUpdate segmentMetadataUpdate) throws RemoteStorageException { + metadataCache.updateRemoteLogSegmentMetadata(segmentMetadataUpdate); + } + + @Override + public Optional highestOffsetForEpoch(int epoch) throws RemoteStorageException { + return metadataCache.highestOffsetForEpoch(epoch); + } + + @Override + public Optional remoteLogSegmentMetadata(int leaderEpoch, + long offset) throws RemoteStorageException { + return metadataCache.remoteLogSegmentMetadata(leaderEpoch, offset); + } + + @Override + public Iterator listRemoteLogSegments(int leaderEpoch) throws RemoteResourceNotFoundException { + return metadataCache.listRemoteLogSegments(leaderEpoch); + } + + @Override + public Iterator listAllRemoteLogSegments() { + return metadataCache.listAllRemoteLogSegments(); + } + + @Override + public void addRemoteLogSegmentMetadata(RemoteLogSegmentMetadata segmentMetadata) throws RemoteStorageException { + metadataCache.addCopyInProgressSegment(segmentMetadata); + } + } +} diff --git a/storage/src/test/java/org/apache/kafka/server/log/remote/metadata/storage/TopicBasedRemoteLogMetadataManagerConfigTest.java b/storage/src/test/java/org/apache/kafka/server/log/remote/metadata/storage/TopicBasedRemoteLogMetadataManagerConfigTest.java new file mode 100644 index 0000000..3785c8d --- /dev/null +++ b/storage/src/test/java/org/apache/kafka/server/log/remote/metadata/storage/TopicBasedRemoteLogMetadataManagerConfigTest.java @@ -0,0 +1,146 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.server.log.remote.metadata.storage; + +import org.apache.kafka.clients.CommonClientConfigs; +import org.apache.kafka.clients.consumer.ConsumerConfig; +import org.apache.kafka.clients.producer.ProducerConfig; +import org.apache.kafka.test.TestUtils; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.AbstractMap; +import java.util.HashMap; +import java.util.Map; + +import static org.apache.kafka.server.log.remote.metadata.storage.TopicBasedRemoteLogMetadataManagerConfig.BROKER_ID; +import static org.apache.kafka.server.log.remote.metadata.storage.TopicBasedRemoteLogMetadataManagerConfig.LOG_DIR; +import static org.apache.kafka.server.log.remote.metadata.storage.TopicBasedRemoteLogMetadataManagerConfig.REMOTE_LOG_METADATA_COMMON_CLIENT_PREFIX; +import static org.apache.kafka.server.log.remote.metadata.storage.TopicBasedRemoteLogMetadataManagerConfig.REMOTE_LOG_METADATA_CONSUMER_PREFIX; +import static org.apache.kafka.server.log.remote.metadata.storage.TopicBasedRemoteLogMetadataManagerConfig.REMOTE_LOG_METADATA_PRODUCER_PREFIX; +import static org.apache.kafka.server.log.remote.metadata.storage.TopicBasedRemoteLogMetadataManagerConfig.REMOTE_LOG_METADATA_TOPIC_PARTITIONS_PROP; +import static org.apache.kafka.server.log.remote.metadata.storage.TopicBasedRemoteLogMetadataManagerConfig.REMOTE_LOG_METADATA_TOPIC_REPLICATION_FACTOR_PROP; +import static org.apache.kafka.server.log.remote.metadata.storage.TopicBasedRemoteLogMetadataManagerConfig.REMOTE_LOG_METADATA_TOPIC_RETENTION_MS_PROP; + +public class TopicBasedRemoteLogMetadataManagerConfigTest { + private static final Logger log = LoggerFactory.getLogger(TopicBasedRemoteLogMetadataManagerConfigTest.class); + + private static final String BOOTSTRAP_SERVERS = "localhost:9091"; + + @Test + public void testValidConfig() { + + Map commonClientConfig = new HashMap<>(); + commonClientConfig.put(CommonClientConfigs.RETRIES_CONFIG, 10); + commonClientConfig.put(CommonClientConfigs.RETRY_BACKOFF_MS_CONFIG, 1000L); + commonClientConfig.put(CommonClientConfigs.METADATA_MAX_AGE_CONFIG, 60000L); + + Map producerConfig = new HashMap<>(); + producerConfig.put(ProducerConfig.ACKS_CONFIG, "all"); + + Map consumerConfig = new HashMap<>(); + consumerConfig.put(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG, false); + + Map props = createValidConfigProps(commonClientConfig, producerConfig, consumerConfig); + + // Check for topic properties + TopicBasedRemoteLogMetadataManagerConfig rlmmConfig = new TopicBasedRemoteLogMetadataManagerConfig(props); + Assertions.assertEquals(props.get(REMOTE_LOG_METADATA_TOPIC_PARTITIONS_PROP), rlmmConfig.metadataTopicPartitionsCount()); + + // Check for common client configs. + for (Map.Entry entry : commonClientConfig.entrySet()) { + log.info("Checking config: " + entry.getKey()); + Assertions.assertEquals(entry.getValue(), + rlmmConfig.producerProperties().get(entry.getKey())); + Assertions.assertEquals(entry.getValue(), + rlmmConfig.consumerProperties().get(entry.getKey())); + } + + // Check for producer configs. + for (Map.Entry entry : producerConfig.entrySet()) { + log.info("Checking config: " + entry.getKey()); + Assertions.assertEquals(entry.getValue(), + rlmmConfig.producerProperties().get(entry.getKey())); + } + + // Check for consumer configs. 
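+        // (Producer and consumer specific overrides of these common client properties are exercised separately in
+        // testProducerConsumerOverridesConfig below.)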
+ for (Map.Entry entry : consumerConfig.entrySet()) { + log.info("Checking config: " + entry.getKey()); + Assertions.assertEquals(entry.getValue(), + rlmmConfig.consumerProperties().get(entry.getKey())); + } + } + + @Test + public void testProducerConsumerOverridesConfig() { + Map.Entry overrideEntry = new AbstractMap.SimpleImmutableEntry<>(CommonClientConfigs.METADATA_MAX_AGE_CONFIG, 60000L); + Map commonClientConfig = new HashMap<>(); + commonClientConfig.put(CommonClientConfigs.RETRIES_CONFIG, 10); + commonClientConfig.put(CommonClientConfigs.RETRY_BACKOFF_MS_CONFIG, 1000L); + commonClientConfig.put(overrideEntry.getKey(), overrideEntry.getValue()); + + Map producerConfig = new HashMap<>(); + producerConfig.put(ProducerConfig.ACKS_CONFIG, -1); + Long overriddenProducerPropValue = overrideEntry.getValue() * 2; + producerConfig.put(overrideEntry.getKey(), overriddenProducerPropValue); + + Map consumerConfig = new HashMap<>(); + consumerConfig.put(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG, false); + Long overriddenConsumerPropValue = overrideEntry.getValue() * 3; + consumerConfig.put(overrideEntry.getKey(), overriddenConsumerPropValue); + + Map props = createValidConfigProps(commonClientConfig, producerConfig, consumerConfig); + TopicBasedRemoteLogMetadataManagerConfig rlmmConfig = new TopicBasedRemoteLogMetadataManagerConfig(props); + + Assertions.assertEquals(overriddenProducerPropValue, + rlmmConfig.producerProperties().get(overrideEntry.getKey())); + Assertions.assertEquals(overriddenConsumerPropValue, + rlmmConfig.consumerProperties().get(overrideEntry.getKey())); + } + + private Map createValidConfigProps(Map commonClientConfig, + Map producerConfig, + Map consumerConfig) { + Map props = new HashMap<>(); + props.put(CommonClientConfigs.BOOTSTRAP_SERVERS_CONFIG, BOOTSTRAP_SERVERS); + props.put(BROKER_ID, 1); + props.put(LOG_DIR, TestUtils.tempDirectory().getAbsolutePath()); + + props.put(REMOTE_LOG_METADATA_TOPIC_REPLICATION_FACTOR_PROP, (short) 3); + props.put(REMOTE_LOG_METADATA_TOPIC_PARTITIONS_PROP, 10); + props.put(REMOTE_LOG_METADATA_TOPIC_RETENTION_MS_PROP, 60 * 60 * 1000L); + + // common client configs + for (Map.Entry entry : commonClientConfig.entrySet()) { + props.put(REMOTE_LOG_METADATA_COMMON_CLIENT_PREFIX + entry.getKey(), entry.getValue()); + } + + // producer configs + for (Map.Entry entry : producerConfig.entrySet()) { + props.put(REMOTE_LOG_METADATA_PRODUCER_PREFIX + entry.getKey(), entry.getValue()); + } + + //consumer configs + for (Map.Entry entry : consumerConfig.entrySet()) { + props.put(REMOTE_LOG_METADATA_CONSUMER_PREFIX + entry.getKey(), entry.getValue()); + } + + return props; + } +} \ No newline at end of file diff --git a/storage/src/test/java/org/apache/kafka/server/log/remote/metadata/storage/TopicBasedRemoteLogMetadataManagerHarness.java b/storage/src/test/java/org/apache/kafka/server/log/remote/metadata/storage/TopicBasedRemoteLogMetadataManagerHarness.java new file mode 100644 index 0000000..e1dd217 --- /dev/null +++ b/storage/src/test/java/org/apache/kafka/server/log/remote/metadata/storage/TopicBasedRemoteLogMetadataManagerHarness.java @@ -0,0 +1,148 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.server.log.remote.metadata.storage; + +import kafka.api.IntegrationTestHarness; +import kafka.utils.EmptyTestInfo; +import org.apache.kafka.clients.CommonClientConfigs; +import org.apache.kafka.common.KafkaException; +import org.apache.kafka.common.TopicIdPartition; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.test.TestUtils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.File; +import java.io.IOException; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.TimeoutException; + +import static org.apache.kafka.server.log.remote.metadata.storage.TopicBasedRemoteLogMetadataManagerConfig.BROKER_ID; +import static org.apache.kafka.server.log.remote.metadata.storage.TopicBasedRemoteLogMetadataManagerConfig.REMOTE_LOG_METADATA_COMMON_CLIENT_PREFIX; +import static org.apache.kafka.server.log.remote.metadata.storage.TopicBasedRemoteLogMetadataManagerConfig.LOG_DIR; +import static org.apache.kafka.server.log.remote.metadata.storage.TopicBasedRemoteLogMetadataManagerConfig.REMOTE_LOG_METADATA_TOPIC_PARTITIONS_PROP; +import static org.apache.kafka.server.log.remote.metadata.storage.TopicBasedRemoteLogMetadataManagerConfig.REMOTE_LOG_METADATA_TOPIC_REPLICATION_FACTOR_PROP; +import static org.apache.kafka.server.log.remote.metadata.storage.TopicBasedRemoteLogMetadataManagerConfig.REMOTE_LOG_METADATA_TOPIC_RETENTION_MS_PROP; + +/** + * A test harness class that brings up 3 brokers and registers {@link TopicBasedRemoteLogMetadataManager} on broker with id as 0. + */ +public class TopicBasedRemoteLogMetadataManagerHarness extends IntegrationTestHarness { + private static final Logger log = LoggerFactory.getLogger(TopicBasedRemoteLogMetadataManagerHarness.class); + + protected static final int METADATA_TOPIC_PARTITIONS_COUNT = 3; + protected static final short METADATA_TOPIC_REPLICATION_FACTOR = 2; + protected static final long METADATA_TOPIC_RETENTION_MS = 24 * 60 * 60 * 1000L; + + private TopicBasedRemoteLogMetadataManager topicBasedRemoteLogMetadataManager; + + protected Map overrideRemoteLogMetadataManagerProps() { + return Collections.emptyMap(); + } + + public void initialize(Set topicIdPartitions, + boolean startConsumerThread) { + // Call setup to start the cluster. + super.setUp(new EmptyTestInfo()); + + initializeRemoteLogMetadataManager(topicIdPartitions, startConsumerThread); + } + + public void initializeRemoteLogMetadataManager(Set topicIdPartitions, + boolean startConsumerThread) { + String logDir = TestUtils.tempDirectory("rlmm_segs_").getAbsolutePath(); + topicBasedRemoteLogMetadataManager = new TopicBasedRemoteLogMetadataManager(startConsumerThread) { + @Override + public void onPartitionLeadershipChanges(Set leaderPartitions, + Set followerPartitions) { + Set allReplicas = new HashSet<>(leaderPartitions); + allReplicas.addAll(followerPartitions); + // Make sure the topic partition dirs exist as the topics might not have been created on this broker. 
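+                // (Assumption: TopicBasedRemoteLogMetadataManager keeps its per-partition metadata snapshot files under
+                // the configured log dir, so the partition directories are created up front even for partitions this
+                // broker does not host.)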
+ for (TopicIdPartition topicIdPartition : allReplicas) { + // Create partition directory in the log directory created by topicBasedRemoteLogMetadataManager. + File partitionDir = new File(new File(config().logDir()), topicIdPartition.topicPartition().topic() + "-" + topicIdPartition.topicPartition().partition()); + partitionDir.mkdirs(); + if (!partitionDir.exists()) { + throw new KafkaException("Partition directory:[" + partitionDir + "] could not be created successfully."); + } + } + + super.onPartitionLeadershipChanges(leaderPartitions, followerPartitions); + } + }; + + // Initialize TopicBasedRemoteLogMetadataManager. + Map configs = new HashMap<>(); + configs.put(REMOTE_LOG_METADATA_COMMON_CLIENT_PREFIX + CommonClientConfigs.BOOTSTRAP_SERVERS_CONFIG, brokerList()); + configs.put(BROKER_ID, 0); + configs.put(LOG_DIR, logDir); + configs.put(REMOTE_LOG_METADATA_TOPIC_PARTITIONS_PROP, METADATA_TOPIC_PARTITIONS_COUNT); + configs.put(REMOTE_LOG_METADATA_TOPIC_REPLICATION_FACTOR_PROP, METADATA_TOPIC_REPLICATION_FACTOR); + configs.put(REMOTE_LOG_METADATA_TOPIC_RETENTION_MS_PROP, METADATA_TOPIC_RETENTION_MS); + + log.debug("TopicBasedRemoteLogMetadataManager configs before adding overridden properties: {}", configs); + // Add override properties. + configs.putAll(overrideRemoteLogMetadataManagerProps()); + log.debug("TopicBasedRemoteLogMetadataManager configs after adding overridden properties: {}", configs); + + topicBasedRemoteLogMetadataManager.configure(configs); + try { + waitUntilInitialized(60_000); + } catch (TimeoutException e) { + throw new KafkaException(e); + } + + topicBasedRemoteLogMetadataManager.onPartitionLeadershipChanges(topicIdPartitions, Collections.emptySet()); + } + + // Visible for testing. + public void waitUntilInitialized(long waitTimeMs) throws TimeoutException { + long startMs = System.currentTimeMillis(); + while (!topicBasedRemoteLogMetadataManager.isInitialized()) { + long currentTimeMs = System.currentTimeMillis(); + if (currentTimeMs > startMs + waitTimeMs) { + throw new TimeoutException("Time out reached before it is initialized successfully"); + } + + Utils.sleep(100); + } + } + + @Override + public int brokerCount() { + return 3; + } + + protected TopicBasedRemoteLogMetadataManager remoteLogMetadataManager() { + return topicBasedRemoteLogMetadataManager; + } + + public void close() throws IOException { + closeRemoteLogMetadataManager(); + + // Stop the servers and zookeeper. + tearDown(); + } + + public void closeRemoteLogMetadataManager() { + Utils.closeQuietly(topicBasedRemoteLogMetadataManager, "TopicBasedRemoteLogMetadataManager"); + } +} \ No newline at end of file diff --git a/storage/src/test/java/org/apache/kafka/server/log/remote/metadata/storage/TopicBasedRemoteLogMetadataManagerRestartTest.java b/storage/src/test/java/org/apache/kafka/server/log/remote/metadata/storage/TopicBasedRemoteLogMetadataManagerRestartTest.java new file mode 100644 index 0000000..2c7baf8 --- /dev/null +++ b/storage/src/test/java/org/apache/kafka/server/log/remote/metadata/storage/TopicBasedRemoteLogMetadataManagerRestartTest.java @@ -0,0 +1,183 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.server.log.remote.metadata.storage; + +import org.apache.kafka.common.TopicIdPartition; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.Uuid; +import org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.server.log.remote.storage.RemoteLogSegmentId; +import org.apache.kafka.server.log.remote.storage.RemoteLogSegmentMetadata; +import org.apache.kafka.test.TestUtils; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import scala.collection.JavaConverters; +import scala.collection.Seq; + +import java.io.File; +import java.io.IOException; +import java.nio.file.Path; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +import static org.apache.kafka.server.log.remote.metadata.storage.ConsumerManager.COMMITTED_OFFSETS_FILE_NAME; +import static org.apache.kafka.server.log.remote.metadata.storage.TopicBasedRemoteLogMetadataManagerConfig.LOG_DIR; + +@SuppressWarnings("deprecation") // Added for Scala 2.12 compatibility for usages of JavaConverters +public class TopicBasedRemoteLogMetadataManagerRestartTest { + + private static final int SEG_SIZE = 1024 * 1024; + + private final Time time = new MockTime(1); + private final String logDir = TestUtils.tempDirectory("_rlmm_segs_").getAbsolutePath(); + + private TopicBasedRemoteLogMetadataManagerHarness remoteLogMetadataManagerHarness; + + @BeforeEach + public void setup() { + // Start the cluster and initialize TopicBasedRemoteLogMetadataManager. + remoteLogMetadataManagerHarness = new TopicBasedRemoteLogMetadataManagerHarness() { + protected Map overrideRemoteLogMetadataManagerProps() { + Map props = new HashMap<>(); + props.put(LOG_DIR, logDir); + return props; + } + }; + remoteLogMetadataManagerHarness.initialize(Collections.emptySet(), true); + } + + private void startTopicBasedRemoteLogMetadataManagerHarness(boolean startConsumerThread) { + remoteLogMetadataManagerHarness.initializeRemoteLogMetadataManager(Collections.emptySet(), startConsumerThread); + } + + @AfterEach + public void teardown() throws IOException { + if (remoteLogMetadataManagerHarness != null) { + remoteLogMetadataManagerHarness.close(); + } + } + + private void stopTopicBasedRemoteLogMetadataManagerHarness() throws IOException { + remoteLogMetadataManagerHarness.closeRemoteLogMetadataManager(); + } + + private TopicBasedRemoteLogMetadataManager topicBasedRlmm() { + return remoteLogMetadataManagerHarness.remoteLogMetadataManager(); + } + + @Test + public void testRLMMAPIsAfterRestart() throws Exception { + // Create topics. + String leaderTopic = "new-leader"; + HashMap> assignedLeaderTopicReplicas = new HashMap<>(); + List leaderTopicReplicas = new ArrayList<>(); + // Set broker id 0 as the first entry which is taken as the leader. 
+ leaderTopicReplicas.add(0); + leaderTopicReplicas.add(1); + leaderTopicReplicas.add(2); + assignedLeaderTopicReplicas.put(0, JavaConverters.asScalaBuffer(leaderTopicReplicas)); + remoteLogMetadataManagerHarness.createTopic(leaderTopic, JavaConverters.mapAsScalaMap(assignedLeaderTopicReplicas)); + + String followerTopic = "new-follower"; + HashMap> assignedFollowerTopicReplicas = new HashMap<>(); + List followerTopicReplicas = new ArrayList<>(); + // Set broker id 1 as the first entry which is taken as the leader. + followerTopicReplicas.add(1); + followerTopicReplicas.add(2); + followerTopicReplicas.add(0); + assignedFollowerTopicReplicas.put(0, JavaConverters.asScalaBuffer(followerTopicReplicas)); + remoteLogMetadataManagerHarness.createTopic(followerTopic, JavaConverters.mapAsScalaMap(assignedFollowerTopicReplicas)); + + final TopicIdPartition leaderTopicIdPartition = new TopicIdPartition(Uuid.randomUuid(), new TopicPartition(leaderTopic, 0)); + final TopicIdPartition followerTopicIdPartition = new TopicIdPartition(Uuid.randomUuid(), new TopicPartition(followerTopic, 0)); + + // Register these partitions to RLMM. + topicBasedRlmm().onPartitionLeadershipChanges(Collections.singleton(leaderTopicIdPartition), Collections.singleton(followerTopicIdPartition)); + + // Add segments for these partitions but they are not available as they have not yet been subscribed. + RemoteLogSegmentMetadata leaderSegmentMetadata = new RemoteLogSegmentMetadata( + new RemoteLogSegmentId(leaderTopicIdPartition, Uuid.randomUuid()), + 0, 100, -1L, 0, + time.milliseconds(), SEG_SIZE, Collections.singletonMap(0, 0L)); + topicBasedRlmm().addRemoteLogSegmentMetadata(leaderSegmentMetadata).get(); + + RemoteLogSegmentMetadata followerSegmentMetadata = new RemoteLogSegmentMetadata( + new RemoteLogSegmentId(followerTopicIdPartition, Uuid.randomUuid()), + 0, 100, -1L, 0, + time.milliseconds(), SEG_SIZE, Collections.singletonMap(0, 0L)); + topicBasedRlmm().addRemoteLogSegmentMetadata(followerSegmentMetadata).get(); + + // Stop TopicBasedRemoteLogMetadataManager only. + stopTopicBasedRemoteLogMetadataManagerHarness(); + + // Start TopicBasedRemoteLogMetadataManager but do not start consumer thread to check whether the stored metadata is + // loaded successfully or not. + startTopicBasedRemoteLogMetadataManagerHarness(false); + + // Register these partitions to RLMM, which loads the respective metadata snapshots. + topicBasedRlmm().onPartitionLeadershipChanges(Collections.singleton(leaderTopicIdPartition), Collections.singleton(followerTopicIdPartition)); + + // Check for the stored entries from the earlier run. + Assertions.assertTrue(TestUtils.sameElementsWithoutOrder(Collections.singleton(leaderSegmentMetadata).iterator(), + topicBasedRlmm().listRemoteLogSegments(leaderTopicIdPartition))); + Assertions.assertTrue(TestUtils.sameElementsWithoutOrder(Collections.singleton(followerSegmentMetadata).iterator(), + topicBasedRlmm().listRemoteLogSegments(followerTopicIdPartition))); + // Check whether the check-pointed consumer offsets are stored or not. 
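+        // COMMITTED_OFFSETS_FILE_NAME comes from ConsumerManager and is looked up under the configured LOG_DIR,
+        // which is why this test pins LOG_DIR to a fixed directory across the restart.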
+ Path committedOffsetsPath = new File(logDir, COMMITTED_OFFSETS_FILE_NAME).toPath(); + Assertions.assertTrue(committedOffsetsPath.toFile().exists()); + CommittedOffsetsFile committedOffsetsFile = new CommittedOffsetsFile(committedOffsetsPath.toFile()); + + int metadataPartition1 = topicBasedRlmm().metadataPartition(leaderTopicIdPartition); + int metadataPartition2 = topicBasedRlmm().metadataPartition(followerTopicIdPartition); + Optional receivedOffsetForPartition1 = topicBasedRlmm().receivedOffsetForPartition(metadataPartition1); + Optional receivedOffsetForPartition2 = topicBasedRlmm().receivedOffsetForPartition(metadataPartition2); + Assertions.assertTrue(receivedOffsetForPartition1.isPresent()); + Assertions.assertTrue(receivedOffsetForPartition2.isPresent()); + + // Make sure these offsets are at least 0. + Assertions.assertTrue(receivedOffsetForPartition1.get() >= 0); + Assertions.assertTrue(receivedOffsetForPartition2.get() >= 0); + + // Check the stored entries and the offsets that were set on consumer are the same. + Map partitionToOffset = committedOffsetsFile.readEntries(); + Assertions.assertEquals(partitionToOffset.get(metadataPartition1), receivedOffsetForPartition1.get()); + Assertions.assertEquals(partitionToOffset.get(metadataPartition2), receivedOffsetForPartition2.get()); + + // Start Consumer thread + topicBasedRlmm().startConsumerThread(); + + // Add one more segment + RemoteLogSegmentMetadata leaderSegmentMetadata2 = new RemoteLogSegmentMetadata( + new RemoteLogSegmentId(leaderTopicIdPartition, Uuid.randomUuid()), + 101, 200, -1L, 0, + time.milliseconds(), SEG_SIZE, Collections.singletonMap(0, 101L)); + topicBasedRlmm().addRemoteLogSegmentMetadata(leaderSegmentMetadata2).get(); + + // Check that both the stored segment and recently added segment are available. + Assertions.assertTrue(TestUtils.sameElementsWithoutOrder(Arrays.asList(leaderSegmentMetadata, leaderSegmentMetadata2).iterator(), + topicBasedRlmm().listRemoteLogSegments(leaderTopicIdPartition))); + } + +} \ No newline at end of file diff --git a/storage/src/test/java/org/apache/kafka/server/log/remote/metadata/storage/TopicBasedRemoteLogMetadataManagerTest.java b/storage/src/test/java/org/apache/kafka/server/log/remote/metadata/storage/TopicBasedRemoteLogMetadataManagerTest.java new file mode 100644 index 0000000..89b25c6 --- /dev/null +++ b/storage/src/test/java/org/apache/kafka/server/log/remote/metadata/storage/TopicBasedRemoteLogMetadataManagerTest.java @@ -0,0 +1,167 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.server.log.remote.metadata.storage; + +import org.apache.kafka.common.TopicIdPartition; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.Uuid; +import org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.server.log.remote.storage.RemoteLogSegmentId; +import org.apache.kafka.server.log.remote.storage.RemoteLogSegmentMetadata; +import org.apache.kafka.server.log.remote.storage.RemoteResourceNotFoundException; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import scala.collection.JavaConverters; +import scala.collection.Seq; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.concurrent.TimeoutException; + +@SuppressWarnings("deprecation") // Added for Scala 2.12 compatibility for usages of JavaConverters +public class TopicBasedRemoteLogMetadataManagerTest { + private static final Logger log = LoggerFactory.getLogger(TopicBasedRemoteLogMetadataManagerTest.class); + + private static final int SEG_SIZE = 1024 * 1024; + + private final Time time = new MockTime(1); + private final TopicBasedRemoteLogMetadataManagerHarness remoteLogMetadataManagerHarness = new TopicBasedRemoteLogMetadataManagerHarness(); + + @BeforeEach + public void setup() { + // Start the cluster and initialize TopicBasedRemoteLogMetadataManager. + remoteLogMetadataManagerHarness.initialize(Collections.emptySet(), true); + } + + @AfterEach + public void teardown() throws IOException { + remoteLogMetadataManagerHarness.close(); + } + + public TopicBasedRemoteLogMetadataManager topicBasedRlmm() { + return remoteLogMetadataManagerHarness.remoteLogMetadataManager(); + } + + @Test + public void testWithNoAssignedPartitions() throws Exception { + // This test checks simple lifecycle of TopicBasedRemoteLogMetadataManager with out assigning any leader/follower partitions. + // This should close successfully releasing the resources. + log.info("Not assigning any partitions on TopicBasedRemoteLogMetadataManager"); + } + + @Test + public void testNewPartitionUpdates() throws Exception { + // Create topics. + String leaderTopic = "new-leader"; + HashMap> assignedLeaderTopicReplicas = new HashMap<>(); + List leaderTopicReplicas = new ArrayList<>(); + // Set broker id 0 as the first entry which is taken as the leader. + leaderTopicReplicas.add(0); + leaderTopicReplicas.add(1); + leaderTopicReplicas.add(2); + assignedLeaderTopicReplicas.put(0, JavaConverters.asScalaBuffer(leaderTopicReplicas)); + remoteLogMetadataManagerHarness.createTopic(leaderTopic, JavaConverters.mapAsScalaMap(assignedLeaderTopicReplicas)); + + String followerTopic = "new-follower"; + HashMap> assignedFollowerTopicReplicas = new HashMap<>(); + List followerTopicReplicas = new ArrayList<>(); + // Set broker id 1 as the first entry which is taken as the leader. 
+ followerTopicReplicas.add(1); + followerTopicReplicas.add(2); + followerTopicReplicas.add(0); + assignedFollowerTopicReplicas.put(0, JavaConverters.asScalaBuffer(followerTopicReplicas)); + remoteLogMetadataManagerHarness.createTopic(followerTopic, JavaConverters.mapAsScalaMap(assignedFollowerTopicReplicas)); + + final TopicIdPartition newLeaderTopicIdPartition = new TopicIdPartition(Uuid.randomUuid(), new TopicPartition(leaderTopic, 0)); + final TopicIdPartition newFollowerTopicIdPartition = new TopicIdPartition(Uuid.randomUuid(), new TopicPartition(followerTopic, 0)); + + // Add segments for these partitions but an exception is received as they have not yet been subscribed. + // These messages would have been published to the respective metadata topic partitions but the ConsumerManager + // has not yet been subscribing as they are not yet registered. + RemoteLogSegmentMetadata leaderSegmentMetadata = new RemoteLogSegmentMetadata(new RemoteLogSegmentId(newLeaderTopicIdPartition, Uuid.randomUuid()), + 0, 100, -1L, 0, + time.milliseconds(), SEG_SIZE, Collections.singletonMap(0, 0L)); + Assertions.assertThrows(Exception.class, () -> topicBasedRlmm().addRemoteLogSegmentMetadata(leaderSegmentMetadata).get()); + + RemoteLogSegmentMetadata followerSegmentMetadata = new RemoteLogSegmentMetadata(new RemoteLogSegmentId(newFollowerTopicIdPartition, Uuid.randomUuid()), + 0, 100, -1L, 0, + time.milliseconds(), SEG_SIZE, Collections.singletonMap(0, 0L)); + Assertions.assertThrows(Exception.class, () -> topicBasedRlmm().addRemoteLogSegmentMetadata(followerSegmentMetadata).get()); + + // `listRemoteLogSegments` will receive an exception as these topic partitions are not yet registered. + Assertions.assertThrows(RemoteResourceNotFoundException.class, () -> topicBasedRlmm().listRemoteLogSegments(newLeaderTopicIdPartition)); + Assertions.assertThrows(RemoteResourceNotFoundException.class, () -> topicBasedRlmm().listRemoteLogSegments(newFollowerTopicIdPartition)); + + topicBasedRlmm().onPartitionLeadershipChanges(Collections.singleton(newLeaderTopicIdPartition), + Collections.singleton(newFollowerTopicIdPartition)); + + // RemoteLogSegmentMetadata events are already published, and topicBasedRlmm's consumer manager will start + // fetching those events and build the cache. 
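Because the cache is built asynchronously by the consumer thread, the test polls rather than asserting immediately. A stripped-down version of that polling pattern, with hypothetical names, is sketched below; the real helper, waitUntilConsumerCatchesup, is defined later in this class and additionally handles the case where both user partitions map to the same metadata partition.

```java
import java.util.concurrent.TimeoutException;
import java.util.function.BooleanSupplier;

// Generic "poll until the condition holds or the timeout expires" helper.
// A simplified stand-in for the waitUntilConsumerCatchesup(...) helper used in this test.
public final class PollUntil {
    private PollUntil() { }

    public static void condition(BooleanSupplier condition, long timeoutMs, long pollIntervalMs)
            throws InterruptedException, TimeoutException {
        long deadline = System.currentTimeMillis() + timeoutMs;
        while (!condition.getAsBoolean()) {
            if (System.currentTimeMillis() > deadline) {
                throw new TimeoutException("Condition not met within " + timeoutMs + " ms");
            }
            Thread.sleep(pollIntervalMs);
        }
    }
}
```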
+ waitUntilConsumerCatchesup(newLeaderTopicIdPartition, newFollowerTopicIdPartition, 30_000L); + + Assertions.assertTrue(topicBasedRlmm().listRemoteLogSegments(newLeaderTopicIdPartition).hasNext()); + Assertions.assertTrue(topicBasedRlmm().listRemoteLogSegments(newFollowerTopicIdPartition).hasNext()); + } + + private void waitUntilConsumerCatchesup(TopicIdPartition newLeaderTopicIdPartition, + TopicIdPartition newFollowerTopicIdPartition, + long timeoutMs) throws TimeoutException { + int leaderMetadataPartition = topicBasedRlmm().metadataPartition(newLeaderTopicIdPartition); + int followerMetadataPartition = topicBasedRlmm().metadataPartition(newFollowerTopicIdPartition); + + log.debug("Metadata partition for newLeaderTopicIdPartition: [{}], is: [{}]", newLeaderTopicIdPartition, leaderMetadataPartition); + log.debug("Metadata partition for newFollowerTopicIdPartition: [{}], is: [{}]", newFollowerTopicIdPartition, followerMetadataPartition); + + long sleepMs = 100L; + long time = System.currentTimeMillis(); + + while (true) { + if (System.currentTimeMillis() - time > timeoutMs) { + throw new TimeoutException("Timed out after " + timeoutMs + "ms "); + } + + // If both the leader and follower partitions are mapped to the same metadata partition then it should have at least + // 2 messages. That means, received offset should be >= 1 (including duplicate messages if any). + if (leaderMetadataPartition == followerMetadataPartition) { + if (topicBasedRlmm().receivedOffsetForPartition(leaderMetadataPartition).orElse(-1L) >= 1) { + break; + } + } else { + // If the leader partition and the follower partition are mapped to different metadata partitions then + // each of those metadata partitions will have at least 1 message. That means, received offset should + // be >= 0 (including duplicate messages if any). + if (topicBasedRlmm().receivedOffsetForPartition(leaderMetadataPartition).orElse(-1L) >= 0 || + topicBasedRlmm().receivedOffsetForPartition(followerMetadataPartition).orElse(-1L) >= 0) { + break; + } + } + + log.debug("Sleeping for: " + sleepMs); + Utils.sleep(sleepMs); + } + } + +} diff --git a/storage/src/test/java/org/apache/kafka/server/log/remote/metadata/storage/TopicBasedRemoteLogMetadataManagerWrapperWithHarness.java b/storage/src/test/java/org/apache/kafka/server/log/remote/metadata/storage/TopicBasedRemoteLogMetadataManagerWrapperWithHarness.java new file mode 100644 index 0000000..ef9287d --- /dev/null +++ b/storage/src/test/java/org/apache/kafka/server/log/remote/metadata/storage/TopicBasedRemoteLogMetadataManagerWrapperWithHarness.java @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.server.log.remote.metadata.storage; + +import org.apache.kafka.common.TopicIdPartition; +import org.apache.kafka.server.log.remote.storage.RemoteLogMetadataManager; +import org.apache.kafka.server.log.remote.storage.RemoteLogSegmentMetadata; +import org.apache.kafka.server.log.remote.storage.RemoteLogSegmentMetadataUpdate; +import org.apache.kafka.server.log.remote.storage.RemotePartitionDeleteMetadata; +import org.apache.kafka.server.log.remote.storage.RemoteStorageException; + +import java.io.IOException; +import java.util.Collections; +import java.util.Iterator; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.concurrent.CompletableFuture; + +public class TopicBasedRemoteLogMetadataManagerWrapperWithHarness implements RemoteLogMetadataManager { + + private final TopicBasedRemoteLogMetadataManagerHarness remoteLogMetadataManagerHarness = new TopicBasedRemoteLogMetadataManagerHarness(); + + @Override + public CompletableFuture addRemoteLogSegmentMetadata(RemoteLogSegmentMetadata remoteLogSegmentMetadata) throws RemoteStorageException { + return remoteLogMetadataManagerHarness.remoteLogMetadataManager().addRemoteLogSegmentMetadata(remoteLogSegmentMetadata); + } + + @Override + public CompletableFuture updateRemoteLogSegmentMetadata(RemoteLogSegmentMetadataUpdate remoteLogSegmentMetadataUpdate) throws RemoteStorageException { + return remoteLogMetadataManagerHarness.remoteLogMetadataManager().updateRemoteLogSegmentMetadata(remoteLogSegmentMetadataUpdate); + } + + @Override + public Optional remoteLogSegmentMetadata(TopicIdPartition topicIdPartition, + int epochForOffset, + long offset) throws RemoteStorageException { + return remoteLogMetadataManagerHarness.remoteLogMetadataManager().remoteLogSegmentMetadata(topicIdPartition, epochForOffset, offset); + } + + @Override + public Optional highestOffsetForEpoch(TopicIdPartition topicIdPartition, + int leaderEpoch) throws RemoteStorageException { + return remoteLogMetadataManagerHarness.remoteLogMetadataManager().highestOffsetForEpoch(topicIdPartition, leaderEpoch); + } + + @Override + public CompletableFuture putRemotePartitionDeleteMetadata(RemotePartitionDeleteMetadata remotePartitionDeleteMetadata) throws RemoteStorageException { + return remoteLogMetadataManagerHarness.remoteLogMetadataManager().putRemotePartitionDeleteMetadata(remotePartitionDeleteMetadata); + } + + @Override + public Iterator listRemoteLogSegments(TopicIdPartition topicIdPartition) throws RemoteStorageException { + return remoteLogMetadataManagerHarness.remoteLogMetadataManager().listRemoteLogSegments(topicIdPartition); + } + + @Override + public Iterator listRemoteLogSegments(TopicIdPartition topicIdPartition, + int leaderEpoch) throws RemoteStorageException { + return remoteLogMetadataManagerHarness.remoteLogMetadataManager().listRemoteLogSegments(topicIdPartition, leaderEpoch); + } + + @Override + public void onPartitionLeadershipChanges(Set leaderPartitions, + Set followerPartitions) { + + remoteLogMetadataManagerHarness.remoteLogMetadataManager().onPartitionLeadershipChanges(leaderPartitions, followerPartitions); + } + + @Override + public void onStopPartitions(Set partitions) { + remoteLogMetadataManagerHarness.remoteLogMetadataManager().onStopPartitions(partitions); + } + + @Override + public void close() throws IOException { + remoteLogMetadataManagerHarness.remoteLogMetadataManager().close(); + } + + @Override + public void configure(Map configs) { + // This will make sure the cluster is up 
and TopicBasedRemoteLogMetadataManager is initialized. + remoteLogMetadataManagerHarness.initialize(Collections.emptySet(), true); + remoteLogMetadataManagerHarness.remoteLogMetadataManager().configure(configs); + } +} diff --git a/storage/src/test/java/org/apache/kafka/server/log/remote/storage/InmemoryRemoteLogMetadataManager.java b/storage/src/test/java/org/apache/kafka/server/log/remote/storage/InmemoryRemoteLogMetadataManager.java new file mode 100644 index 0000000..cac37ae --- /dev/null +++ b/storage/src/test/java/org/apache/kafka/server/log/remote/storage/InmemoryRemoteLogMetadataManager.java @@ -0,0 +1,173 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.server.log.remote.storage; + +import org.apache.kafka.common.TopicIdPartition; +import org.apache.kafka.server.log.remote.metadata.storage.RemoteLogMetadataCache; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.util.Collections; +import java.util.Iterator; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.Set; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ConcurrentHashMap; + +/** + * This class is an implementation of {@link RemoteLogMetadataManager} backed by in-memory store. + * This class is not completely thread safe. 
+ */ +public class InmemoryRemoteLogMetadataManager implements RemoteLogMetadataManager { + private static final Logger log = LoggerFactory.getLogger(InmemoryRemoteLogMetadataManager.class); + + private Map idToPartitionDeleteMetadata = + new ConcurrentHashMap<>(); + + private Map idToRemoteLogMetadataCache = new ConcurrentHashMap<>(); + + private static final CompletableFuture COMPLETED_FUTURE = new CompletableFuture<>(); + static { + COMPLETED_FUTURE.complete(null); + } + + @Override + public CompletableFuture addRemoteLogSegmentMetadata(RemoteLogSegmentMetadata remoteLogSegmentMetadata) + throws RemoteStorageException { + log.debug("Adding remote log segment : [{}]", remoteLogSegmentMetadata); + Objects.requireNonNull(remoteLogSegmentMetadata, "remoteLogSegmentMetadata can not be null"); + + RemoteLogSegmentId remoteLogSegmentId = remoteLogSegmentMetadata.remoteLogSegmentId(); + + idToRemoteLogMetadataCache + .computeIfAbsent(remoteLogSegmentId.topicIdPartition(), id -> new RemoteLogMetadataCache()) + .addCopyInProgressSegment(remoteLogSegmentMetadata); + + return COMPLETED_FUTURE; + } + + @Override + public CompletableFuture updateRemoteLogSegmentMetadata(RemoteLogSegmentMetadataUpdate metadataUpdate) + throws RemoteStorageException { + log.debug("Updating remote log segment: [{}]", metadataUpdate); + Objects.requireNonNull(metadataUpdate, "metadataUpdate can not be null"); + + getRemoteLogMetadataCache(metadataUpdate.remoteLogSegmentId().topicIdPartition()) + .updateRemoteLogSegmentMetadata(metadataUpdate); + + return COMPLETED_FUTURE; + } + + private RemoteLogMetadataCache getRemoteLogMetadataCache(TopicIdPartition topicIdPartition) + throws RemoteResourceNotFoundException { + RemoteLogMetadataCache remoteLogMetadataCache = idToRemoteLogMetadataCache.get(topicIdPartition); + if (remoteLogMetadataCache == null) { + throw new RemoteResourceNotFoundException("No existing metadata found for partition: " + topicIdPartition); + } + + return remoteLogMetadataCache; + } + + @Override + public Optional remoteLogSegmentMetadata(TopicIdPartition topicIdPartition, + int epochForOffset, + long offset) + throws RemoteStorageException { + Objects.requireNonNull(topicIdPartition, "topicIdPartition can not be null"); + + return getRemoteLogMetadataCache(topicIdPartition).remoteLogSegmentMetadata(epochForOffset, offset); + } + + @Override + public Optional highestOffsetForEpoch(TopicIdPartition topicIdPartition, + int leaderEpoch) throws RemoteStorageException { + Objects.requireNonNull(topicIdPartition, "topicIdPartition can not be null"); + + return getRemoteLogMetadataCache(topicIdPartition).highestOffsetForEpoch(leaderEpoch); + } + + @Override + public CompletableFuture putRemotePartitionDeleteMetadata(RemotePartitionDeleteMetadata remotePartitionDeleteMetadata) + throws RemoteStorageException { + log.debug("Adding delete state with: [{}]", remotePartitionDeleteMetadata); + Objects.requireNonNull(remotePartitionDeleteMetadata, "remotePartitionDeleteMetadata can not be null"); + + TopicIdPartition topicIdPartition = remotePartitionDeleteMetadata.topicIdPartition(); + + RemotePartitionDeleteState targetState = remotePartitionDeleteMetadata.state(); + RemotePartitionDeleteMetadata existingMetadata = idToPartitionDeleteMetadata.get(topicIdPartition); + RemotePartitionDeleteState existingState = existingMetadata != null ? 
existingMetadata.state() : null; + if (!RemotePartitionDeleteState.isValidTransition(existingState, targetState)) { + throw new IllegalStateException("Current state: " + existingState + ", target state: " + targetState); + } + + idToPartitionDeleteMetadata.put(topicIdPartition, remotePartitionDeleteMetadata); + + if (targetState == RemotePartitionDeleteState.DELETE_PARTITION_FINISHED) { + // Remove the association for the partition. + idToRemoteLogMetadataCache.remove(topicIdPartition); + idToPartitionDeleteMetadata.remove(topicIdPartition); + } + + return COMPLETED_FUTURE; + } + + @Override + public Iterator listRemoteLogSegments(TopicIdPartition topicIdPartition) + throws RemoteStorageException { + Objects.requireNonNull(topicIdPartition, "topicIdPartition can not be null"); + + return getRemoteLogMetadataCache(topicIdPartition).listAllRemoteLogSegments(); + } + + @Override + public Iterator listRemoteLogSegments(TopicIdPartition topicIdPartition, int leaderEpoch) + throws RemoteStorageException { + Objects.requireNonNull(topicIdPartition, "topicIdPartition can not be null"); + + return getRemoteLogMetadataCache(topicIdPartition).listRemoteLogSegments(leaderEpoch); + } + + @Override + public void onPartitionLeadershipChanges(Set leaderPartitions, + Set followerPartitions) { + // It is not applicable for this implementation. This will track the segments that are added/updated as part of + // this instance. It does not depend upon any leader or follower transitions. + } + + @Override + public void onStopPartitions(Set partitions) { + // It is not applicable for this implementation. This will track the segments that are added/updated as part of + // this instance. It does not depend upon stopped partitions. + } + + @Override + public void close() throws IOException { + // Clearing the references to the map and assigning empty immutable maps. + // Practically, this instance will not be used once it is closed. + idToPartitionDeleteMetadata = Collections.emptyMap(); + idToRemoteLogMetadataCache = Collections.emptyMap(); + } + + @Override + public void configure(Map configs) { + // Intentionally left blank here as nothing to be initialized here. + } +} diff --git a/storage/src/test/java/org/apache/kafka/server/log/remote/storage/InmemoryRemoteStorageManager.java b/storage/src/test/java/org/apache/kafka/server/log/remote/storage/InmemoryRemoteStorageManager.java new file mode 100644 index 0000000..9e5c3be --- /dev/null +++ b/storage/src/test/java/org/apache/kafka/server/log/remote/storage/InmemoryRemoteStorageManager.java @@ -0,0 +1,179 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.server.log.remote.storage; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.nio.file.Files; +import java.util.Collections; +import java.util.Map; +import java.util.Objects; +import java.util.concurrent.ConcurrentHashMap; + +/** + * This class is an implementation of {@link RemoteStorageManager} backed by in-memory store. + */ +public class InmemoryRemoteStorageManager implements RemoteStorageManager { + private static final Logger log = LoggerFactory.getLogger(InmemoryRemoteStorageManager.class); + + // Map of key to log data, which can be segment or any of its indexes. + private Map keyToLogData = new ConcurrentHashMap<>(); + + static String generateKeyForSegment(RemoteLogSegmentMetadata remoteLogSegmentMetadata) { + return remoteLogSegmentMetadata.remoteLogSegmentId().id().toString() + ".segment"; + } + + static String generateKeyForIndex(RemoteLogSegmentMetadata remoteLogSegmentMetadata, + IndexType indexType) { + return remoteLogSegmentMetadata.remoteLogSegmentId().id().toString() + "." + indexType.toString(); + } + + // visible for testing. + boolean containsKey(String key) { + return keyToLogData.containsKey(key); + } + + @Override + public void copyLogSegmentData(RemoteLogSegmentMetadata remoteLogSegmentMetadata, + LogSegmentData logSegmentData) + throws RemoteStorageException { + log.debug("copying log segment and indexes for : {}", remoteLogSegmentMetadata); + Objects.requireNonNull(remoteLogSegmentMetadata, "remoteLogSegmentMetadata can not be null"); + Objects.requireNonNull(logSegmentData, "logSegmentData can not be null"); + + if (keyToLogData.containsKey(generateKeyForSegment(remoteLogSegmentMetadata))) { + throw new RemoteStorageException("It already contains the segment for the given id: " + + remoteLogSegmentMetadata.remoteLogSegmentId()); + } + + try { + keyToLogData.put(generateKeyForSegment(remoteLogSegmentMetadata), + Files.readAllBytes(logSegmentData.logSegment())); + if (logSegmentData.transactionIndex().isPresent()) { + keyToLogData.put(generateKeyForIndex(remoteLogSegmentMetadata, IndexType.TRANSACTION), + Files.readAllBytes(logSegmentData.transactionIndex().get())); + } + keyToLogData.put(generateKeyForIndex(remoteLogSegmentMetadata, IndexType.LEADER_EPOCH), + logSegmentData.leaderEpochIndex().array()); + keyToLogData.put(generateKeyForIndex(remoteLogSegmentMetadata, IndexType.PRODUCER_SNAPSHOT), + Files.readAllBytes(logSegmentData.producerSnapshotIndex())); + keyToLogData.put(generateKeyForIndex(remoteLogSegmentMetadata, IndexType.OFFSET), + Files.readAllBytes(logSegmentData.offsetIndex())); + keyToLogData.put(generateKeyForIndex(remoteLogSegmentMetadata, IndexType.TIMESTAMP), + Files.readAllBytes(logSegmentData.timeIndex())); + } catch (Exception e) { + throw new RemoteStorageException(e); + } + log.debug("copied log segment and indexes for : {} successfully.", remoteLogSegmentMetadata); + } + + @Override + public InputStream fetchLogSegment(RemoteLogSegmentMetadata remoteLogSegmentMetadata, + int startPosition) + throws RemoteStorageException { + log.debug("Received fetch segment request at start position: [{}] for [{}]", startPosition, remoteLogSegmentMetadata); + Objects.requireNonNull(remoteLogSegmentMetadata, "remoteLogSegmentMetadata can not be null"); + + return fetchLogSegment(remoteLogSegmentMetadata, startPosition, Integer.MAX_VALUE); + } + + @Override + public InputStream 
fetchLogSegment(RemoteLogSegmentMetadata remoteLogSegmentMetadata, + int startPosition, + int endPosition) throws RemoteStorageException { + log.debug("Received fetch segment request at start position: [{}] and end position: [{}] for segment [{}]", + startPosition, endPosition, remoteLogSegmentMetadata); + + Objects.requireNonNull(remoteLogSegmentMetadata, "remoteLogSegmentMetadata can not be null"); + + if (startPosition < 0 || endPosition < 0) { + throw new IllegalArgumentException("Given start position or end position must not be negative."); + } + + if (endPosition < startPosition) { + throw new IllegalArgumentException("end position must be greater than or equal to start position"); + } + + String key = generateKeyForSegment(remoteLogSegmentMetadata); + byte[] segment = keyToLogData.get(key); + + if (segment == null) { + throw new RemoteResourceNotFoundException("No remote log segment found with start offset:" + + remoteLogSegmentMetadata.startOffset() + " and id: " + + remoteLogSegmentMetadata.remoteLogSegmentId()); + } + + if (startPosition >= segment.length) { + throw new IllegalArgumentException("start position: " + startPosition + + " must be less than the length of the segment: " + segment.length); + } + + // If the given (endPosition + 1) is more than the segment length then the segment length is taken into account. + // Computed length should never be more than the existing segment size. + int length = Math.min(segment.length - 1, endPosition) - startPosition + 1; + log.debug("Length of the segment to be sent: [{}], for segment: [{}]", length, remoteLogSegmentMetadata); + + return new ByteArrayInputStream(segment, startPosition, length); + } + + @Override + public InputStream fetchIndex(RemoteLogSegmentMetadata remoteLogSegmentMetadata, + IndexType indexType) throws RemoteStorageException { + log.debug("Received fetch request for index type: [{}], segment [{}]", indexType, remoteLogSegmentMetadata); + Objects.requireNonNull(remoteLogSegmentMetadata, "remoteLogSegmentMetadata can not be null"); + Objects.requireNonNull(indexType, "indexType can not be null"); + + String key = generateKeyForIndex(remoteLogSegmentMetadata, indexType); + byte[] index = keyToLogData.get(key); + if (index == null) { + throw new RemoteResourceNotFoundException("No remote log segment index found with start offset:" + + remoteLogSegmentMetadata.startOffset() + " and id: " + + remoteLogSegmentMetadata.remoteLogSegmentId()); + } + + return new ByteArrayInputStream(index); + } + + @Override + public void deleteLogSegmentData(RemoteLogSegmentMetadata remoteLogSegmentMetadata) throws RemoteStorageException { + log.info("Deleting log segment for: [{}]", remoteLogSegmentMetadata); + Objects.requireNonNull(remoteLogSegmentMetadata, "remoteLogSegmentMetadata can not be null"); + String segmentKey = generateKeyForSegment(remoteLogSegmentMetadata); + keyToLogData.remove(segmentKey); + for (IndexType indexType : IndexType.values()) { + String key = generateKeyForIndex(remoteLogSegmentMetadata, indexType); + keyToLogData.remove(key); + } + log.info("Deleted log segment successfully for: [{}]", remoteLogSegmentMetadata); + } + + @Override + public void close() throws IOException { + // Clearing the references to the map and assigning empty immutable map. + // Practically, this instance will not be used once it is closed. + keyToLogData = Collections.emptyMap(); + } + + @Override + public void configure(Map configs) { + // Intentionally left blank here as nothing to be initialized here. 
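As an aside on the range handling in fetchLogSegment above: both positions are inclusive and the end position is capped at the last byte of the segment, so the returned length works out as in this small, self-contained sketch (the class, method, and values are illustrative only; the 100-byte size matches the segments created in the tests below).

```java
// Worked example of the inclusive end-position arithmetic used by fetchLogSegment(...).
public class RangeLengthExample {
    static int rangeLength(int segmentLength, int startPosition, int endPosition) {
        // Same formula as above: cap the end position at the last valid index, then count inclusively.
        return Math.min(segmentLength - 1, endPosition) - startPosition + 1;
    }

    public static void main(String[] args) {
        System.out.println(rangeLength(100, 0, 40));                 // 41 bytes: positions 0..40
        System.out.println(rangeLength(100, 90, Integer.MAX_VALUE)); // 10 bytes: positions 90..99
        System.out.println(rangeLength(100, 99, 99));                // 1 byte: the last position
    }
}
```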
+ } +} diff --git a/storage/src/test/java/org/apache/kafka/server/log/remote/storage/InmemoryRemoteStorageManagerTest.java b/storage/src/test/java/org/apache/kafka/server/log/remote/storage/InmemoryRemoteStorageManagerTest.java new file mode 100644 index 0000000..44984f7 --- /dev/null +++ b/storage/src/test/java/org/apache/kafka/server/log/remote/storage/InmemoryRemoteStorageManagerTest.java @@ -0,0 +1,251 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.server.log.remote.storage; + +import org.apache.kafka.common.TopicIdPartition; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.Uuid; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.test.TestUtils; +import org.junit.jupiter.api.Test; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.File; +import java.io.IOException; +import java.io.InputStream; +import java.nio.ByteBuffer; +import java.nio.channels.SeekableByteChannel; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.Optional; +import java.util.Random; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class InmemoryRemoteStorageManagerTest { + private static final Logger log = LoggerFactory.getLogger(InmemoryRemoteStorageManagerTest.class); + + private static final TopicPartition TP = new TopicPartition("foo", 1); + private static final File DIR = TestUtils.tempDirectory("inmem-rsm-"); + private static final Random RANDOM = new Random(); + + @Test + public void testCopyLogSegment() throws Exception { + InmemoryRemoteStorageManager rsm = new InmemoryRemoteStorageManager(); + RemoteLogSegmentMetadata segmentMetadata = createRemoteLogSegmentMetadata(); + LogSegmentData logSegmentData = createLogSegmentData(); + // Copy all the segment data. + rsm.copyLogSegmentData(segmentMetadata, logSegmentData); + + // Check that the segment data exists in in-memory RSM. + boolean containsSegment = rsm.containsKey(InmemoryRemoteStorageManager.generateKeyForSegment(segmentMetadata)); + assertTrue(containsSegment); + + // Check that the indexes exist in in-memory RSM. 
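The loop that follows checks one key per index type. As a rough sketch of the key scheme built by generateKeyForSegment and generateKeyForIndex above, the keys look like the output below; the id string and local enum are stand-ins, and the sketch assumes the index type's default enum toString.

```java
// Illustration of the in-memory store's key scheme: one entry for the segment bytes plus one per index type.
public class KeySchemeExample {
    enum IndexType { OFFSET, TIMESTAMP, PRODUCER_SNAPSHOT, TRANSACTION, LEADER_EPOCH }

    public static void main(String[] args) {
        String segmentIdString = "8VCahqYsR5WXyHnNilnGfw"; // stands in for RemoteLogSegmentId.id().toString()
        System.out.println(segmentIdString + ".segment");
        for (IndexType indexType : IndexType.values()) {
            System.out.println(segmentIdString + "." + indexType);
        }
    }
}
```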
+ for (RemoteStorageManager.IndexType indexType : RemoteStorageManager.IndexType.values()) { + boolean containsIndex = rsm.containsKey(InmemoryRemoteStorageManager.generateKeyForIndex(segmentMetadata, indexType)); + assertTrue(containsIndex); + } + } + + private RemoteLogSegmentMetadata createRemoteLogSegmentMetadata() { + TopicIdPartition topicPartition = new TopicIdPartition(Uuid.randomUuid(), TP); + RemoteLogSegmentId id = new RemoteLogSegmentId(topicPartition, Uuid.randomUuid()); + return new RemoteLogSegmentMetadata(id, 100L, 200L, System.currentTimeMillis(), 0, + System.currentTimeMillis(), 100, Collections.singletonMap(1, 100L)); + } + + @Test + public void testFetchLogSegmentIndexes() throws Exception { + InmemoryRemoteStorageManager rsm = new InmemoryRemoteStorageManager(); + RemoteLogSegmentMetadata segmentMetadata = createRemoteLogSegmentMetadata(); + int segSize = 100; + LogSegmentData logSegmentData = createLogSegmentData(segSize); + + // Copy the segment + rsm.copyLogSegmentData(segmentMetadata, logSegmentData); + + // Check segment data exists for the copied segment. + try (InputStream segmentStream = rsm.fetchLogSegment(segmentMetadata, 0)) { + checkContentSame(segmentStream, logSegmentData.logSegment()); + } + + HashMap expectedIndexToPaths = new HashMap<>(); + expectedIndexToPaths.put(RemoteStorageManager.IndexType.OFFSET, logSegmentData.offsetIndex()); + expectedIndexToPaths.put(RemoteStorageManager.IndexType.TIMESTAMP, logSegmentData.timeIndex()); + expectedIndexToPaths.put(RemoteStorageManager.IndexType.PRODUCER_SNAPSHOT, logSegmentData.producerSnapshotIndex()); + + logSegmentData.transactionIndex().ifPresent(txnIndex -> expectedIndexToPaths.put(RemoteStorageManager.IndexType.TRANSACTION, txnIndex)); + + // Check all segment indexes exist for the copied segment. + for (Map.Entry entry : expectedIndexToPaths.entrySet()) { + RemoteStorageManager.IndexType indexType = entry.getKey(); + Path indexPath = entry.getValue(); + log.debug("Fetching index type: {}, indexPath: {}", indexType, indexPath); + + try (InputStream offsetIndexStream = rsm.fetchIndex(segmentMetadata, indexType)) { + checkContentSame(offsetIndexStream, indexPath); + } + } + + try (InputStream leaderEpochIndexStream = rsm.fetchIndex(segmentMetadata, RemoteStorageManager.IndexType.LEADER_EPOCH)) { + ByteBuffer leaderEpochIndex = logSegmentData.leaderEpochIndex(); + assertEquals(leaderEpochIndex, + readAsByteBuffer(leaderEpochIndexStream, leaderEpochIndex.array().length)); + } + } + + @Test + public void testFetchSegmentsForRange() throws Exception { + InmemoryRemoteStorageManager rsm = new InmemoryRemoteStorageManager(); + RemoteLogSegmentMetadata segmentMetadata = createRemoteLogSegmentMetadata(); + int segSize = 100; + LogSegmentData logSegmentData = createLogSegmentData(segSize); + Path path = logSegmentData.logSegment(); + + // Copy the segment + rsm.copyLogSegmentData(segmentMetadata, logSegmentData); + + // 1. Fetch segment for startPos at 0 + doTestFetchForRange(rsm, segmentMetadata, path, 0, 40); + + // 2. Fetch segment for start and end positions as start and end of the segment. + doTestFetchForRange(rsm, segmentMetadata, path, 0, segSize); + + // 3. Fetch segment for endPos at the end of segment. + doTestFetchForRange(rsm, segmentMetadata, path, 90, segSize - 90); + + // 4. Fetch segment only for the start position. + doTestFetchForRange(rsm, segmentMetadata, path, 0, 1); + + // 5. Fetch segment only for the end position. + doTestFetchForRange(rsm, segmentMetadata, path, segSize - 1, 1); + + // 6. 
Fetch for any range other than boundaries. + doTestFetchForRange(rsm, segmentMetadata, path, 3, 90); + } + + private void doTestFetchForRange(InmemoryRemoteStorageManager rsm, RemoteLogSegmentMetadata rlsm, Path path, + int startPos, int len) throws Exception { + // Read from the segment for the expected range. + ByteBuffer expectedSegRangeBytes = ByteBuffer.allocate(len); + try (SeekableByteChannel seekableByteChannel = Files.newByteChannel(path)) { + seekableByteChannel.position(startPos).read(expectedSegRangeBytes); + } + expectedSegRangeBytes.rewind(); + + // Fetch from in-memory RSM for the same range + ByteBuffer fetchedSegRangeBytes = ByteBuffer.allocate(len); + try (InputStream segmentRangeStream = rsm.fetchLogSegment(rlsm, startPos, startPos + len - 1)) { + Utils.readFully(segmentRangeStream, fetchedSegRangeBytes); + } + fetchedSegRangeBytes.rewind(); + assertEquals(expectedSegRangeBytes, fetchedSegRangeBytes); + } + + @Test + public void testFetchInvalidRange() throws Exception { + InmemoryRemoteStorageManager rsm = new InmemoryRemoteStorageManager(); + RemoteLogSegmentMetadata remoteLogSegmentMetadata = createRemoteLogSegmentMetadata(); + int segSize = 100; + LogSegmentData logSegmentData = createLogSegmentData(segSize); + + // Copy the segment + rsm.copyLogSegmentData(remoteLogSegmentMetadata, logSegmentData); + + // Check fetch segments with invalid ranges like startPos < endPos + assertThrows(Exception.class, () -> rsm.fetchLogSegment(remoteLogSegmentMetadata, 2, 1)); + + // Check fetch segments with invalid ranges like startPos or endPos as negative. + assertThrows(Exception.class, () -> rsm.fetchLogSegment(remoteLogSegmentMetadata, -1, 0)); + assertThrows(Exception.class, () -> rsm.fetchLogSegment(remoteLogSegmentMetadata, -2, -1)); + } + + @Test + public void testDeleteSegment() throws Exception { + InmemoryRemoteStorageManager rsm = new InmemoryRemoteStorageManager(); + RemoteLogSegmentMetadata segmentMetadata = createRemoteLogSegmentMetadata(); + LogSegmentData logSegmentData = createLogSegmentData(); + + // Copy a log segment. + rsm.copyLogSegmentData(segmentMetadata, logSegmentData); + + // Check that the copied segment exists in rsm and it is same. + try (InputStream segmentStream = rsm.fetchLogSegment(segmentMetadata, 0)) { + checkContentSame(segmentStream, logSegmentData.logSegment()); + } + + // Delete segment and check that it does not exist in RSM. + rsm.deleteLogSegmentData(segmentMetadata); + + // Check that the segment data does not exist. + assertThrows(RemoteResourceNotFoundException.class, () -> rsm.fetchLogSegment(segmentMetadata, 0)); + + // Check that the segment data does not exist for range. + assertThrows(RemoteResourceNotFoundException.class, () -> rsm.fetchLogSegment(segmentMetadata, 0, 1)); + + // Check that all the indexes are not found. 
+ for (RemoteStorageManager.IndexType indexType : RemoteStorageManager.IndexType.values()) { + assertThrows(RemoteResourceNotFoundException.class, () -> rsm.fetchIndex(segmentMetadata, indexType)); + } + } + + private void checkContentSame(InputStream segmentStream, Path path) throws IOException { + byte[] segmentBytes = Files.readAllBytes(path); + ByteBuffer byteBuffer = readAsByteBuffer(segmentStream, segmentBytes.length); + assertEquals(ByteBuffer.wrap(segmentBytes), byteBuffer); + } + + private ByteBuffer readAsByteBuffer(InputStream segmentStream, + int len) throws IOException { + ByteBuffer byteBuffer = ByteBuffer.wrap(new byte[len]); + Utils.readFully(segmentStream, byteBuffer); + byteBuffer.rewind(); + return byteBuffer; + } + + private LogSegmentData createLogSegmentData() throws Exception { + return createLogSegmentData(100); + } + + private LogSegmentData createLogSegmentData(int segSize) throws Exception { + int prefix = Math.abs(RANDOM.nextInt()); + Path segment = new File(DIR, prefix + ".seg").toPath(); + Files.write(segment, TestUtils.randomBytes(segSize)); + + Path offsetIndex = new File(DIR, prefix + ".oi").toPath(); + Files.write(offsetIndex, TestUtils.randomBytes(10)); + + Path timeIndex = new File(DIR, prefix + ".ti").toPath(); + Files.write(timeIndex, TestUtils.randomBytes(10)); + + Path txnIndex = new File(DIR, prefix + ".txni").toPath(); + Files.write(txnIndex, TestUtils.randomBytes(10)); + + Path producerSnapshotIndex = new File(DIR, prefix + ".psi").toPath(); + Files.write(producerSnapshotIndex, TestUtils.randomBytes(10)); + + ByteBuffer leaderEpochIndex = ByteBuffer.wrap(TestUtils.randomBytes(10)); + return new LogSegmentData(segment, offsetIndex, timeIndex, Optional.of(txnIndex), producerSnapshotIndex, leaderEpochIndex); + } +} diff --git a/storage/src/test/java/org/apache/kafka/server/log/remote/storage/RemoteLogManagerConfigTest.java b/storage/src/test/java/org/apache/kafka/server/log/remote/storage/RemoteLogManagerConfigTest.java new file mode 100644 index 0000000..bb3c2ff --- /dev/null +++ b/storage/src/test/java/org/apache/kafka/server/log/remote/storage/RemoteLogManagerConfigTest.java @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.server.log.remote.storage; + +import org.apache.kafka.common.config.AbstractConfig; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +public class RemoteLogManagerConfigTest { + + private static class TestConfig extends AbstractConfig { + public TestConfig(Map originals) { + super(RemoteLogManagerConfig.CONFIG_DEF, originals, true); + } + } + + @Test + public void testValidConfigs() { + String rsmPrefix = "__custom.rsm."; + String rlmmPrefix = "__custom.rlmm."; + Map rsmProps = Collections.singletonMap("rsm.prop", "val"); + Map rlmmProps = Collections.singletonMap("rlmm.prop", "val"); + RemoteLogManagerConfig expectedRemoteLogManagerConfig + = new RemoteLogManagerConfig(true, "dummy.remote.storage.class", "dummy.remote.storage.class.path", + "dummy.remote.log.metadata.class", "dummy.remote.log.metadata.class.path", + "listener.name", 1024 * 1024L, 1, 60000L, 100L, 60000L, 0.3, 10, 100, + rsmPrefix, rsmProps, rlmmPrefix, rlmmProps); + + Map props = extractProps(expectedRemoteLogManagerConfig); + rsmProps.forEach((k, v) -> props.put(rsmPrefix + k, v)); + rlmmProps.forEach((k, v) -> props.put(rlmmPrefix + k, v)); + TestConfig config = new TestConfig(props); + RemoteLogManagerConfig remoteLogManagerConfig = new RemoteLogManagerConfig(config); + Assertions.assertEquals(expectedRemoteLogManagerConfig, remoteLogManagerConfig); + } + + private Map extractProps(RemoteLogManagerConfig remoteLogManagerConfig) { + Map props = new HashMap<>(); + props.put(RemoteLogManagerConfig.REMOTE_LOG_STORAGE_SYSTEM_ENABLE_PROP, + remoteLogManagerConfig.enableRemoteStorageSystem()); + props.put(RemoteLogManagerConfig.REMOTE_STORAGE_MANAGER_CLASS_NAME_PROP, + remoteLogManagerConfig.remoteStorageManagerClassName()); + props.put(RemoteLogManagerConfig.REMOTE_STORAGE_MANAGER_CLASS_PATH_PROP, + remoteLogManagerConfig.remoteStorageManagerClassPath()); + props.put(RemoteLogManagerConfig.REMOTE_LOG_METADATA_MANAGER_CLASS_NAME_PROP, + remoteLogManagerConfig.remoteLogMetadataManagerClassName()); + props.put(RemoteLogManagerConfig.REMOTE_LOG_METADATA_MANAGER_CLASS_PATH_PROP, + remoteLogManagerConfig.remoteLogMetadataManagerClassPath()); + props.put(RemoteLogManagerConfig.REMOTE_LOG_METADATA_MANAGER_LISTENER_NAME_PROP, + remoteLogManagerConfig.remoteLogMetadataManagerListenerName()); + props.put(RemoteLogManagerConfig.REMOTE_LOG_INDEX_FILE_CACHE_TOTAL_SIZE_BYTES_PROP, + remoteLogManagerConfig.remoteLogIndexFileCacheTotalSizeBytes()); + props.put(RemoteLogManagerConfig.REMOTE_LOG_MANAGER_THREAD_POOL_SIZE_PROP, + remoteLogManagerConfig.remoteLogManagerThreadPoolSize()); + props.put(RemoteLogManagerConfig.REMOTE_LOG_MANAGER_TASK_INTERVAL_MS_PROP, + remoteLogManagerConfig.remoteLogManagerTaskIntervalMs()); + props.put(RemoteLogManagerConfig.REMOTE_LOG_MANAGER_TASK_RETRY_BACK_OFF_MS_PROP, + remoteLogManagerConfig.remoteLogManagerTaskRetryBackoffMs()); + props.put(RemoteLogManagerConfig.REMOTE_LOG_MANAGER_TASK_RETRY_BACK_OFF_MAX_MS_PROP, + remoteLogManagerConfig.remoteLogManagerTaskRetryBackoffMaxMs()); + props.put(RemoteLogManagerConfig.REMOTE_LOG_MANAGER_TASK_RETRY_JITTER_PROP, + remoteLogManagerConfig.remoteLogManagerTaskRetryJitter()); + props.put(RemoteLogManagerConfig.REMOTE_LOG_READER_THREADS_PROP, + remoteLogManagerConfig.remoteLogReaderThreads()); + props.put(RemoteLogManagerConfig.REMOTE_LOG_READER_MAX_PENDING_TASKS_PROP, + remoteLogManagerConfig.remoteLogReaderMaxPendingTasks()); 
+ props.put(RemoteLogManagerConfig.REMOTE_STORAGE_MANAGER_CONFIG_PREFIX_PROP, + remoteLogManagerConfig.remoteStorageManagerPrefix()); + props.put(RemoteLogManagerConfig.REMOTE_LOG_METADATA_MANAGER_CONFIG_PREFIX_PROP, + remoteLogManagerConfig.remoteLogMetadataManagerPrefix()); + return props; + } +} diff --git a/storage/src/test/java/org/apache/kafka/server/log/remote/storage/RemoteLogMetadataManagerTest.java b/storage/src/test/java/org/apache/kafka/server/log/remote/storage/RemoteLogMetadataManagerTest.java new file mode 100644 index 0000000..95521c4 --- /dev/null +++ b/storage/src/test/java/org/apache/kafka/server/log/remote/storage/RemoteLogMetadataManagerTest.java @@ -0,0 +1,157 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.server.log.remote.storage; + +import org.apache.kafka.common.TopicIdPartition; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.Uuid; +import org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.server.log.remote.metadata.storage.TopicBasedRemoteLogMetadataManagerWrapperWithHarness; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; + +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.Optional; + +/** + * This class covers basic tests for {@link RemoteLogMetadataManager} implementations like {@link InmemoryRemoteLogMetadataManager}, + * and {@link org.apache.kafka.server.log.remote.metadata.storage.TopicBasedRemoteLogMetadataManager}. + */ +public class RemoteLogMetadataManagerTest { + + private static final TopicIdPartition TP0 = new TopicIdPartition(Uuid.randomUuid(), + new TopicPartition("foo", 0)); + private static final int SEG_SIZE = 1024 * 1024; + private static final int BROKER_ID_0 = 0; + private static final int BROKER_ID_1 = 1; + + private final Time time = new MockTime(1); + + @ParameterizedTest(name = "remoteLogMetadataManager = {0}") + @MethodSource("remoteLogMetadataManagers") + public void testFetchSegments(RemoteLogMetadataManager remoteLogMetadataManager) throws Exception { + try { + remoteLogMetadataManager.configure(Collections.emptyMap()); + remoteLogMetadataManager.onPartitionLeadershipChanges(Collections.singleton(TP0), Collections.emptySet()); + + // 1.Create a segment with state COPY_SEGMENT_STARTED, and this segment should not be available. 
+ Map segmentLeaderEpochs = Collections.singletonMap(0, 101L); + RemoteLogSegmentId segmentId = new RemoteLogSegmentId(TP0, Uuid.randomUuid()); + RemoteLogSegmentMetadata segmentMetadata = new RemoteLogSegmentMetadata(segmentId, 101L, 200L, -1L, BROKER_ID_0, + time.milliseconds(), SEG_SIZE, segmentLeaderEpochs); + // Wait until the segment is added successfully. + remoteLogMetadataManager.addRemoteLogSegmentMetadata(segmentMetadata).get(); + + // Search should not return the above segment. + Assertions.assertFalse(remoteLogMetadataManager.remoteLogSegmentMetadata(TP0, 0, 150).isPresent()); + + // 2.Move that segment to COPY_SEGMENT_FINISHED state and this segment should be available. + RemoteLogSegmentMetadataUpdate segmentMetadataUpdate = new RemoteLogSegmentMetadataUpdate(segmentId, time.milliseconds(), + RemoteLogSegmentState.COPY_SEGMENT_FINISHED, + BROKER_ID_1); + // Wait until the segment is updated successfully. + remoteLogMetadataManager.updateRemoteLogSegmentMetadata(segmentMetadataUpdate).get(); + RemoteLogSegmentMetadata expectedSegmentMetadata = segmentMetadata.createWithUpdates(segmentMetadataUpdate); + + // Search should return the above segment. + Optional segmentMetadataForOffset150 = remoteLogMetadataManager.remoteLogSegmentMetadata(TP0, 0, 150); + Assertions.assertEquals(Optional.of(expectedSegmentMetadata), segmentMetadataForOffset150); + } finally { + Utils.closeQuietly(remoteLogMetadataManager, "RemoteLogMetadataManager"); + } + } + + @ParameterizedTest(name = "remoteLogMetadataManager = {0}") + @MethodSource("remoteLogMetadataManagers") + public void testRemotePartitionDeletion(RemoteLogMetadataManager remoteLogMetadataManager) throws Exception { + try { + remoteLogMetadataManager.configure(Collections.emptyMap()); + remoteLogMetadataManager.onPartitionLeadershipChanges(Collections.singleton(TP0), Collections.emptySet()); + + // Create remote log segment metadata and add them to RLMM. + + // segment 0 + // offsets: [0-100] + // leader epochs (0,0), (1,20), (2,80) + Map segmentLeaderEpochs = new HashMap<>(); + segmentLeaderEpochs.put(0, 0L); + segmentLeaderEpochs.put(1, 20L); + segmentLeaderEpochs.put(2, 50L); + segmentLeaderEpochs.put(3, 80L); + RemoteLogSegmentId segmentId = new RemoteLogSegmentId(TP0, Uuid.randomUuid()); + RemoteLogSegmentMetadata segmentMetadata = new RemoteLogSegmentMetadata(segmentId, 0L, 100L, + -1L, BROKER_ID_0, time.milliseconds(), SEG_SIZE, + segmentLeaderEpochs); + // Wait until the segment is added successfully. + remoteLogMetadataManager.addRemoteLogSegmentMetadata(segmentMetadata).get(); + + RemoteLogSegmentMetadataUpdate segmentMetadataUpdate = new RemoteLogSegmentMetadataUpdate( + segmentId, time.milliseconds(), RemoteLogSegmentState.COPY_SEGMENT_FINISHED, BROKER_ID_1); + // Wait until the segment is updated successfully. + remoteLogMetadataManager.updateRemoteLogSegmentMetadata(segmentMetadataUpdate).get(); + + RemoteLogSegmentMetadata expectedSegMetadata = segmentMetadata.createWithUpdates(segmentMetadataUpdate); + + // Check that the segment exists in RLMM. + Optional segMetadataForOffset30Epoch1 = remoteLogMetadataManager.remoteLogSegmentMetadata(TP0, 1, 30L); + Assertions.assertEquals(Optional.of(expectedSegMetadata), segMetadataForOffset30Epoch1); + + // Mark the partition for deletion and wait for it to be updated successfully. 
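The next three steps walk the partition through MARKED, STARTED, and FINISHED, mirroring the isValidTransition check used by InmemoryRemoteLogMetadataManager above. The snippet below is a simplified local model of that lifecycle for illustration only, not the real RemotePartitionDeleteState class or its exact transition rules.

```java
// Simplified model of the partition-deletion lifecycle exercised by this test:
// (no state) -> DELETE_PARTITION_MARKED -> DELETE_PARTITION_STARTED -> DELETE_PARTITION_FINISHED.
public class DeleteLifecycleSketch {
    enum State { DELETE_PARTITION_MARKED, DELETE_PARTITION_STARTED, DELETE_PARTITION_FINISHED }

    static boolean isValidTransition(State from, State to) {
        if (from == null) {
            // A partition enters the lifecycle by being marked for deletion.
            return to == State.DELETE_PARTITION_MARKED;
        }
        switch (from) {
            case DELETE_PARTITION_MARKED:  return to == State.DELETE_PARTITION_STARTED;
            case DELETE_PARTITION_STARTED: return to == State.DELETE_PARTITION_FINISHED;
            default:                       return false; // FINISHED is terminal in this sketch.
        }
    }

    public static void main(String[] args) {
        System.out.println(isValidTransition(null, State.DELETE_PARTITION_MARKED));   // true
        System.out.println(isValidTransition(State.DELETE_PARTITION_MARKED,
                                             State.DELETE_PARTITION_FINISHED));       // false in this model
    }
}
```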
+ remoteLogMetadataManager.putRemotePartitionDeleteMetadata( + createRemotePartitionDeleteMetadata(RemotePartitionDeleteState.DELETE_PARTITION_MARKED)).get(); + + Optional segmentMetadataAfterDelMark = remoteLogMetadataManager.remoteLogSegmentMetadata(TP0, + 1, 30L); + Assertions.assertEquals(Optional.of(expectedSegMetadata), segmentMetadataAfterDelMark); + + // Set the partition deletion state as started. Partition and segments should still be accessible as they are not + // yet deleted. Wait until the segment state is updated successfully. + remoteLogMetadataManager.putRemotePartitionDeleteMetadata( + createRemotePartitionDeleteMetadata(RemotePartitionDeleteState.DELETE_PARTITION_STARTED)).get(); + + Optional segmentMetadataAfterDelStart = remoteLogMetadataManager.remoteLogSegmentMetadata(TP0, + 1, 30L); + Assertions.assertEquals(Optional.of(expectedSegMetadata), segmentMetadataAfterDelStart); + + // Set the partition deletion state as finished. RLMM should clear all its internal state for that partition. + // Wait until the segment state is updated successfully. + remoteLogMetadataManager.putRemotePartitionDeleteMetadata( + createRemotePartitionDeleteMetadata(RemotePartitionDeleteState.DELETE_PARTITION_FINISHED)).get(); + + Assertions.assertThrows(RemoteResourceNotFoundException.class, + () -> remoteLogMetadataManager.remoteLogSegmentMetadata(TP0, 1, 30L)); + } finally { + Utils.closeQuietly(remoteLogMetadataManager, "RemoteLogMetadataManager"); + } + } + + private RemotePartitionDeleteMetadata createRemotePartitionDeleteMetadata(RemotePartitionDeleteState state) { + return new RemotePartitionDeleteMetadata(TP0, state, time.milliseconds(), BROKER_ID_0); + } + + private static Collection remoteLogMetadataManagers() { + return Arrays.asList(Arguments.of(new InmemoryRemoteLogMetadataManager()), Arguments.of(new TopicBasedRemoteLogMetadataManagerWrapperWithHarness())); + } +} \ No newline at end of file diff --git a/storage/src/test/resources/log4j.properties b/storage/src/test/resources/log4j.properties new file mode 100644 index 0000000..113e15e --- /dev/null +++ b/storage/src/test/resources/log4j.properties @@ -0,0 +1,22 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+log4j.rootLogger=OFF, stdout + +log4j.appender.stdout=org.apache.log4j.ConsoleAppender +log4j.appender.stdout.layout=org.apache.log4j.PatternLayout +log4j.appender.stdout.layout.ConversionPattern=[%d] %p %m (%c:%L)%n + +log4j.logger.org.apache.kafka.server.log.remote.storage=INFO +log4j.logger.org.apache.kafka.server.log.remote.metadata.storage=INFO diff --git a/streams/.gitignore b/streams/.gitignore new file mode 100644 index 0000000..ae3c172 --- /dev/null +++ b/streams/.gitignore @@ -0,0 +1 @@ +/bin/ diff --git a/streams/examples/src/main/java/org/apache/kafka/streams/examples/pageview/JsonTimestampExtractor.java b/streams/examples/src/main/java/org/apache/kafka/streams/examples/pageview/JsonTimestampExtractor.java new file mode 100644 index 0000000..d760183 --- /dev/null +++ b/streams/examples/src/main/java/org/apache/kafka/streams/examples/pageview/JsonTimestampExtractor.java @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.examples.pageview; + +import com.fasterxml.jackson.databind.JsonNode; +import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.streams.processor.TimestampExtractor; + +/** + * A timestamp extractor implementation that tries to extract event time from + * the "timestamp" field in the Json formatted message. + */ +public class JsonTimestampExtractor implements TimestampExtractor { + + @Override + public long extract(final ConsumerRecord record, final long partitionTime) { + if (record.value() instanceof PageViewTypedDemo.PageView) { + return ((PageViewTypedDemo.PageView) record.value()).timestamp; + } + + if (record.value() instanceof PageViewTypedDemo.UserProfile) { + return ((PageViewTypedDemo.UserProfile) record.value()).timestamp; + } + + if (record.value() instanceof JsonNode) { + return ((JsonNode) record.value()).get("timestamp").longValue(); + } + + throw new IllegalArgumentException("JsonTimestampExtractor cannot recognize the record value " + record.value()); + } +} diff --git a/streams/examples/src/main/java/org/apache/kafka/streams/examples/pageview/PageViewTypedDemo.java b/streams/examples/src/main/java/org/apache/kafka/streams/examples/pageview/PageViewTypedDemo.java new file mode 100644 index 0000000..a5086de --- /dev/null +++ b/streams/examples/src/main/java/org/apache/kafka/streams/examples/pageview/PageViewTypedDemo.java @@ -0,0 +1,250 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.examples.pageview; + +import com.fasterxml.jackson.annotation.JsonSubTypes; +import com.fasterxml.jackson.annotation.JsonTypeInfo; +import com.fasterxml.jackson.databind.ObjectMapper; +import org.apache.kafka.clients.consumer.ConsumerConfig; +import org.apache.kafka.common.errors.SerializationException; +import org.apache.kafka.common.serialization.Deserializer; +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.serialization.Serializer; +import org.apache.kafka.streams.KafkaStreams; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.kstream.Consumed; +import org.apache.kafka.streams.kstream.Produced; +import org.apache.kafka.streams.kstream.Grouped; +import org.apache.kafka.streams.kstream.KStream; +import org.apache.kafka.streams.kstream.KTable; +import org.apache.kafka.streams.kstream.TimeWindows; + +import java.io.IOException; +import java.time.Duration; +import java.util.Map; +import java.util.Properties; +import java.util.concurrent.CountDownLatch; + +/** + * Demonstrates how to perform a join between a KStream and a KTable, i.e. an example of a stateful computation, + * using specific data types (here: JSON POJO; but can also be Avro specific bindings, etc.) for serdes + * in Kafka Streams. + * + * In this example, we join a stream of pageviews (aka clickstreams) that reads from a topic named "streams-pageview-input" + * with a user profile table that reads from a topic named "streams-userprofile-input", where the data format + * is JSON string representing a record in the stream or table, to compute the number of pageviews per user region. + * + * Before running this example you must create the input topics and the output topic (e.g. via + * bin/kafka-topics --create ...), and write some data to the input topics (e.g. via + * bin/kafka-console-producer). Otherwise you won't see any data arriving in the output topic. + * + * The inputs for this example are: + * - Topic: streams-pageview-input + * Key Format: (String) USER_ID + * Value Format: (JSON) {"_t": "pv", "user": (String USER_ID), "page": (String PAGE_ID), "timestamp": (long ms TIMESTAMP)} + * + * - Topic: streams-userprofile-input + * Key Format: (String) USER_ID + * Value Format: (JSON) {"_t": "up", "region": (String REGION), "timestamp": (long ms TIMESTAMP)} + * + * To observe the results, read the output topic (e.g., via bin/kafka-console-consumer) + * - Topic: streams-pageviewstats-typed-output + * Key Format: (JSON) {"_t": "wpvbr", "windowStart": (long ms WINDOW_TIMESTAMP), "region": (String REGION)} + * Value Format: (JSON) {"_t": "rc", "count": (long REGION_COUNT), "region": (String REGION)} + * + * Note, the "_t" field is necessary to help Jackson identify the correct class for deserialization in the + * generic {@link JSONSerde}. If you instead specify a specific serde per class, you won't need the extra "_t" field. 
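To make the role of the "_t" tag concrete, here is a minimal, self-contained sketch (illustrative only: TypeTagSketch and its nested types are hypothetical stand-ins for the demo's POJOs, and it relies only on the Jackson databind dependency the demo already uses) showing Jackson resolving the concrete subtype from that property:

import com.fasterxml.jackson.annotation.JsonSubTypes;
import com.fasterxml.jackson.annotation.JsonTypeInfo;
import com.fasterxml.jackson.databind.ObjectMapper;

import java.nio.charset.StandardCharsets;

public class TypeTagSketch {

    // The "_t" property carries the logical type name of the payload.
    @JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.PROPERTY, property = "_t")
    @JsonSubTypes({@JsonSubTypes.Type(value = PageView.class, name = "pv")})
    public interface Tagged { }

    public static class PageView implements Tagged {
        public String user;
        public String page;
        public Long timestamp;
    }

    public static void main(final String[] args) throws Exception {
        final byte[] json = "{\"_t\":\"pv\",\"user\":\"alice\",\"page\":\"index\",\"timestamp\":1}"
                .getBytes(StandardCharsets.UTF_8);
        // Jackson reads the "_t" property and picks the registered subtype "pv" -> PageView.
        final Tagged value = new ObjectMapper().readValue(json, Tagged.class);
        System.out.println(value.getClass().getSimpleName()); // prints PageView
    }
}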
+ */ +@SuppressWarnings({"WeakerAccess", "unused"}) +public class PageViewTypedDemo { + + /** + * A serde for any class that implements {@link JSONSerdeCompatible}. Note that the classes also need to + * be registered in the {@code @JsonSubTypes} annotation on {@link JSONSerdeCompatible}. + * + * @param The concrete type of the class that gets de/serialized + */ + public static class JSONSerde implements Serializer, Deserializer, Serde { + private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper(); + + @Override + public void configure(final Map configs, final boolean isKey) {} + + @SuppressWarnings("unchecked") + @Override + public T deserialize(final String topic, final byte[] data) { + if (data == null) { + return null; + } + + try { + return (T) OBJECT_MAPPER.readValue(data, JSONSerdeCompatible.class); + } catch (final IOException e) { + throw new SerializationException(e); + } + } + + @Override + public byte[] serialize(final String topic, final T data) { + if (data == null) { + return null; + } + + try { + return OBJECT_MAPPER.writeValueAsBytes(data); + } catch (final Exception e) { + throw new SerializationException("Error serializing JSON message", e); + } + } + + @Override + public void close() {} + + @Override + public Serializer serializer() { + return this; + } + + @Override + public Deserializer deserializer() { + return this; + } + } + + /** + * An interface for registering types that can be de/serialized with {@link JSONSerde}. + */ + @SuppressWarnings("DefaultAnnotationParam") // being explicit for the example + @JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.PROPERTY, property = "_t") + @JsonSubTypes({ + @JsonSubTypes.Type(value = PageView.class, name = "pv"), + @JsonSubTypes.Type(value = UserProfile.class, name = "up"), + @JsonSubTypes.Type(value = PageViewByRegion.class, name = "pvbr"), + @JsonSubTypes.Type(value = WindowedPageViewByRegion.class, name = "wpvbr"), + @JsonSubTypes.Type(value = RegionCount.class, name = "rc") + }) + public interface JSONSerdeCompatible { + + } + + // POJO classes + static public class PageView implements JSONSerdeCompatible { + public String user; + public String page; + public Long timestamp; + } + + static public class UserProfile implements JSONSerdeCompatible { + public String region; + public Long timestamp; + } + + static public class PageViewByRegion implements JSONSerdeCompatible { + public String user; + public String page; + public String region; + } + + static public class WindowedPageViewByRegion implements JSONSerdeCompatible { + public long windowStart; + public String region; + } + + static public class RegionCount implements JSONSerdeCompatible { + public long count; + public String region; + } + + public static void main(final String[] args) { + final Properties props = new Properties(); + props.put(StreamsConfig.APPLICATION_ID_CONFIG, "streams-pageview-typed"); + props.put(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9092"); + props.put(StreamsConfig.DEFAULT_TIMESTAMP_EXTRACTOR_CLASS_CONFIG, JsonTimestampExtractor.class); + props.put(StreamsConfig.DEFAULT_KEY_SERDE_CLASS_CONFIG, JSONSerde.class); + props.put(StreamsConfig.DEFAULT_VALUE_SERDE_CLASS_CONFIG, JSONSerde.class); + props.put(StreamsConfig.CACHE_MAX_BYTES_BUFFERING_CONFIG, 0); + props.put(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG, 1000L); + + // setting offset reset to earliest so that we can re-run the demo code with the same pre-loaded data + props.put(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "earliest"); + + final 
StreamsBuilder builder = new StreamsBuilder(); + + final KStream views = builder.stream("streams-pageview-input", Consumed.with(Serdes.String(), new JSONSerde<>())); + + final KTable users = builder.table("streams-userprofile-input", Consumed.with(Serdes.String(), new JSONSerde<>())); + + final Duration duration24Hours = Duration.ofHours(24); + + final KStream regionCount = views + .leftJoin(users, (view, profile) -> { + final PageViewByRegion viewByRegion = new PageViewByRegion(); + viewByRegion.user = view.user; + viewByRegion.page = view.page; + + if (profile != null) { + viewByRegion.region = profile.region; + } else { + viewByRegion.region = "UNKNOWN"; + } + return viewByRegion; + }) + .map((user, viewRegion) -> new KeyValue<>(viewRegion.region, viewRegion)) + .groupByKey(Grouped.with(Serdes.String(), new JSONSerde<>())) + .windowedBy(TimeWindows.ofSizeAndGrace(Duration.ofDays(7), duration24Hours).advanceBy(Duration.ofSeconds(1))) + .count() + .toStream() + .map((key, value) -> { + final WindowedPageViewByRegion wViewByRegion = new WindowedPageViewByRegion(); + wViewByRegion.windowStart = key.window().start(); + wViewByRegion.region = key.key(); + + final RegionCount rCount = new RegionCount(); + rCount.region = key.key(); + rCount.count = value; + + return new KeyValue<>(wViewByRegion, rCount); + }); + + // write to the result topic + regionCount.to("streams-pageviewstats-typed-output", Produced.with(new JSONSerde<>(), new JSONSerde<>())); + + final KafkaStreams streams = new KafkaStreams(builder.build(), props); + final CountDownLatch latch = new CountDownLatch(1); + + // attach shutdown handler to catch control-c + Runtime.getRuntime().addShutdownHook(new Thread("streams-pipe-shutdown-hook") { + @Override + public void run() { + streams.close(); + latch.countDown(); + } + }); + + try { + streams.start(); + latch.await(); + } catch (final Throwable e) { + e.printStackTrace(); + System.exit(1); + } + System.exit(0); + } +} diff --git a/streams/examples/src/main/java/org/apache/kafka/streams/examples/pageview/PageViewUntypedDemo.java b/streams/examples/src/main/java/org/apache/kafka/streams/examples/pageview/PageViewUntypedDemo.java new file mode 100644 index 0000000..cdb3639 --- /dev/null +++ b/streams/examples/src/main/java/org/apache/kafka/streams/examples/pageview/PageViewUntypedDemo.java @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.examples.pageview; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.node.JsonNodeFactory; +import com.fasterxml.jackson.databind.node.ObjectNode; +import java.time.Duration; +import org.apache.kafka.clients.consumer.ConsumerConfig; +import org.apache.kafka.common.serialization.Deserializer; +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.serialization.Serializer; +import org.apache.kafka.connect.json.JsonDeserializer; +import org.apache.kafka.connect.json.JsonSerializer; +import org.apache.kafka.streams.kstream.Consumed; +import org.apache.kafka.streams.KafkaStreams; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.kstream.Grouped; +import org.apache.kafka.streams.kstream.KStream; +import org.apache.kafka.streams.kstream.KTable; +import org.apache.kafka.streams.kstream.Produced; +import org.apache.kafka.streams.kstream.TimeWindows; + +import java.util.Properties; + +/** + * Demonstrates how to perform a join between a KStream and a KTable, i.e. an example of a stateful computation, + * using general data types (here: JSON; but can also be Avro generic bindings, etc.) for serdes + * in Kafka Streams. + * + * In this example, we join a stream of pageviews (aka clickstreams) that reads from a topic named "streams-pageview-input" + * with a user profile table that reads from a topic named "streams-userprofile-input", where the data format + * is JSON string representing a record in the stream or table, to compute the number of pageviews per user region. + * + * Before running this example you must create the input topics and the output topic (e.g. via + * bin/kafka-topics.sh --create ...), and write some data to the input topics (e.g. via + * bin/kafka-console-producer.sh). Otherwise you won't see any data arriving in the output topic. 
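The topics mentioned above can also be created programmatically instead of through the shell scripts. The following is a hedged sketch (the class name CreateDemoTopics is illustrative, and it assumes a single local broker at localhost:9092, as the demos do) using the Kafka Admin client:

import org.apache.kafka.clients.admin.Admin;
import org.apache.kafka.clients.admin.AdminClientConfig;
import org.apache.kafka.clients.admin.NewTopic;

import java.util.Arrays;
import java.util.Properties;

public class CreateDemoTopics {
    public static void main(final String[] args) throws Exception {
        final Properties props = new Properties();
        props.put(AdminClientConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9092");
        try (final Admin admin = Admin.create(props)) {
            // Single partition, replication factor 1: enough for a local demo broker.
            admin.createTopics(Arrays.asList(
                new NewTopic("streams-pageview-input", 1, (short) 1),
                new NewTopic("streams-userprofile-input", 1, (short) 1),
                new NewTopic("streams-pageviewstats-untyped-output", 1, (short) 1)
            )).all().get(); // wait for the broker to acknowledge the creation
        }
    }
}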
+ */ +public class PageViewUntypedDemo { + + public static void main(final String[] args) throws Exception { + final Properties props = new Properties(); + props.put(StreamsConfig.APPLICATION_ID_CONFIG, "streams-pageview-untyped"); + props.put(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9092"); + props.put(StreamsConfig.DEFAULT_TIMESTAMP_EXTRACTOR_CLASS_CONFIG, JsonTimestampExtractor.class); + props.put(StreamsConfig.CACHE_MAX_BYTES_BUFFERING_CONFIG, 0); + + // setting offset reset to earliest so that we can re-run the demo code with the same pre-loaded data + props.put(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "earliest"); + + final StreamsBuilder builder = new StreamsBuilder(); + + final Serializer jsonSerializer = new JsonSerializer(); + final Deserializer jsonDeserializer = new JsonDeserializer(); + final Serde jsonSerde = Serdes.serdeFrom(jsonSerializer, jsonDeserializer); + + final Consumed consumed = Consumed.with(Serdes.String(), jsonSerde); + final KStream views = builder.stream("streams-pageview-input", consumed); + + final KTable users = builder.table("streams-userprofile-input", consumed); + + final KTable userRegions = users.mapValues(record -> record.get("region").textValue()); + + final Duration duration24Hours = Duration.ofHours(24); + + final KStream regionCount = views + .leftJoin(userRegions, (view, region) -> { + final ObjectNode jNode = JsonNodeFactory.instance.objectNode(); + return (JsonNode) jNode.put("user", view.get("user").textValue()) + .put("page", view.get("page").textValue()) + .put("region", region == null ? "UNKNOWN" : region); + + }) + .map((user, viewRegion) -> new KeyValue<>(viewRegion.get("region").textValue(), viewRegion)) + .groupByKey(Grouped.with(Serdes.String(), jsonSerde)) + .windowedBy(TimeWindows.ofSizeAndGrace(Duration.ofDays(7), duration24Hours).advanceBy(Duration.ofSeconds(1))) + .count() + .toStream() + .map((key, value) -> { + final ObjectNode keyNode = JsonNodeFactory.instance.objectNode(); + keyNode.put("window-start", key.window().start()) + .put("region", key.key()); + + final ObjectNode valueNode = JsonNodeFactory.instance.objectNode(); + valueNode.put("count", value); + + return new KeyValue<>((JsonNode) keyNode, (JsonNode) valueNode); + }); + + // write to the result topic + regionCount.to("streams-pageviewstats-untyped-output", Produced.with(jsonSerde, jsonSerde)); + + final KafkaStreams streams = new KafkaStreams(builder.build(), props); + streams.start(); + + // usually the stream application would be running forever, + // in this example we just let it run for some time and stop since the input data is finite. + Thread.sleep(5000L); + + streams.close(); + } +} diff --git a/streams/examples/src/main/java/org/apache/kafka/streams/examples/pipe/PipeDemo.java b/streams/examples/src/main/java/org/apache/kafka/streams/examples/pipe/PipeDemo.java new file mode 100644 index 0000000..860f2ff --- /dev/null +++ b/streams/examples/src/main/java/org/apache/kafka/streams/examples/pipe/PipeDemo.java @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.examples.pipe; + +import org.apache.kafka.clients.consumer.ConsumerConfig; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.streams.KafkaStreams; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.StreamsConfig; + +import java.util.Properties; +import java.util.concurrent.CountDownLatch; + +/** + * Demonstrates, using the high-level KStream DSL, how to read data from a source (input) topic and how to + * write data to a sink (output) topic. + * + * In this example, we implement a simple "pipe" program that reads from a source topic "streams-plaintext-input" + * and writes the data as-is (i.e. unmodified) into a sink topic "streams-pipe-output". + * + * Before running this example you must create the input topic and the output topic (e.g. via + * bin/kafka-topics.sh --create ...), and write some data to the input topic (e.g. via + * bin/kafka-console-producer.sh). Otherwise you won't see any data arriving in the output topic. + */ +public class PipeDemo { + + public static void main(final String[] args) { + final Properties props = new Properties(); + props.put(StreamsConfig.APPLICATION_ID_CONFIG, "streams-pipe"); + props.put(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9092"); + props.put(StreamsConfig.DEFAULT_KEY_SERDE_CLASS_CONFIG, Serdes.String().getClass()); + props.put(StreamsConfig.DEFAULT_VALUE_SERDE_CLASS_CONFIG, Serdes.String().getClass()); + + // setting offset reset to earliest so that we can re-run the demo code with the same pre-loaded data + props.put(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "earliest"); + + final StreamsBuilder builder = new StreamsBuilder(); + + builder.stream("streams-plaintext-input").to("streams-pipe-output"); + + final KafkaStreams streams = new KafkaStreams(builder.build(), props); + final CountDownLatch latch = new CountDownLatch(1); + + // attach shutdown handler to catch control-c + Runtime.getRuntime().addShutdownHook(new Thread("streams-pipe-shutdown-hook") { + @Override + public void run() { + streams.close(); + latch.countDown(); + } + }); + + try { + streams.start(); + latch.await(); + } catch (final Throwable e) { + System.exit(1); + } + System.exit(0); + } +} diff --git a/streams/examples/src/main/java/org/apache/kafka/streams/examples/temperature/TemperatureDemo.java b/streams/examples/src/main/java/org/apache/kafka/streams/examples/temperature/TemperatureDemo.java new file mode 100644 index 0000000..6384466 --- /dev/null +++ b/streams/examples/src/main/java/org/apache/kafka/streams/examples/temperature/TemperatureDemo.java @@ -0,0 +1,129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.examples.temperature; + +import org.apache.kafka.clients.consumer.ConsumerConfig; +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.streams.KafkaStreams; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.kstream.KStream; +import org.apache.kafka.streams.kstream.Produced; +import org.apache.kafka.streams.kstream.TimeWindows; +import org.apache.kafka.streams.kstream.Windowed; +import org.apache.kafka.streams.kstream.WindowedSerdes; + +import java.time.Duration; +import java.util.Properties; +import java.util.concurrent.CountDownLatch; + +/** + * Demonstrates, using the high-level KStream DSL, how to implement an IoT demo application + * which ingests temperature value processing the maximum value in the latest TEMPERATURE_WINDOW_SIZE seconds (which + * is 5 seconds) sending a new message if it exceeds the TEMPERATURE_THRESHOLD (which is 20) + * + * In this example, the input stream reads from a topic named "iot-temperature", where the values of messages + * represent temperature values; using a TEMPERATURE_WINDOW_SIZE seconds "tumbling" window, the maximum value is processed and + * sent to a topic named "iot-temperature-max" if it exceeds the TEMPERATURE_THRESHOLD. 
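As a concrete illustration of the tumbling window described above (a sketch, not part of the demo; the class name is made up), TimeWindows assigns each record timestamp to exactly one window whose start is the timestamp rounded down to the window size:

import java.time.Duration;

import org.apache.kafka.streams.kstream.TimeWindows;

public class TumblingWindowSketch {
    public static void main(final String[] args) {
        final TimeWindows windows = TimeWindows.ofSizeAndGrace(Duration.ofSeconds(5), Duration.ofHours(24));
        // A record with timestamp 12,300 ms falls into exactly one 5-second window, [10000, 15000).
        System.out.println(windows.windowsFor(12_300L).keySet()); // prints [10000]
    }
}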
+ * + * Before running this example you must create the input topic for temperature values in the following way : + * + * bin/kafka-topics.sh --create --bootstrap-server localhost:9092 --replication-factor 1 --partitions 1 --topic iot-temperature + * + * and at same time the output topic for filtered values : + * + * bin/kafka-topics.sh --create --bootstrap-server localhost:9092 --replication-factor 1 --partitions 1 --topic iot-temperature-max + * + * After that, a console consumer can be started in order to read filtered values from the "iot-temperature-max" topic : + * + * bin/kafka-console-consumer.sh --bootstrap-server localhost:9092 --topic iot-temperature-max --from-beginning + * + * On the other side, a console producer can be used for sending temperature values (which needs to be integers) + * to "iot-temperature" typing them on the console : + * + * bin/kafka-console-producer.sh --broker-list localhost:9092 --topic iot-temperature + * > 10 + * > 15 + * > 22 + */ +public class TemperatureDemo { + + // threshold used for filtering max temperature values + private static final int TEMPERATURE_THRESHOLD = 20; + // window size within which the filtering is applied + private static final int TEMPERATURE_WINDOW_SIZE = 5; + + public static void main(final String[] args) { + + final Properties props = new Properties(); + props.put(StreamsConfig.APPLICATION_ID_CONFIG, "streams-temperature"); + props.put(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9092"); + props.put(StreamsConfig.DEFAULT_KEY_SERDE_CLASS_CONFIG, Serdes.String().getClass()); + props.put(StreamsConfig.DEFAULT_VALUE_SERDE_CLASS_CONFIG, Serdes.String().getClass()); + + props.put(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "earliest"); + props.put(StreamsConfig.CACHE_MAX_BYTES_BUFFERING_CONFIG, 0); + + final Duration duration24Hours = Duration.ofHours(24); + + final StreamsBuilder builder = new StreamsBuilder(); + + final KStream source = builder.stream("iot-temperature"); + + final KStream, String> max = source + // temperature values are sent without a key (null), so in order + // to group and reduce them, a key is needed ("temp" has been chosen) + .selectKey((key, value) -> "temp") + .groupByKey() + .windowedBy(TimeWindows.ofSizeAndGrace(Duration.ofSeconds(TEMPERATURE_WINDOW_SIZE), duration24Hours)) + .reduce((value1, value2) -> { + if (Integer.parseInt(value1) > Integer.parseInt(value2)) { + return value1; + } else { + return value2; + } + }) + .toStream() + .filter((key, value) -> Integer.parseInt(value) > TEMPERATURE_THRESHOLD); + + final Serde> windowedSerde = WindowedSerdes.timeWindowedSerdeFrom(String.class, TEMPERATURE_WINDOW_SIZE); + + // need to override key serde to Windowed type + max.to("iot-temperature-max", Produced.with(windowedSerde, Serdes.String())); + + final KafkaStreams streams = new KafkaStreams(builder.build(), props); + final CountDownLatch latch = new CountDownLatch(1); + + // attach shutdown handler to catch control-c + Runtime.getRuntime().addShutdownHook(new Thread("streams-temperature-shutdown-hook") { + @Override + public void run() { + streams.close(); + latch.countDown(); + } + }); + + try { + streams.start(); + latch.await(); + } catch (final Throwable e) { + System.exit(1); + } + System.exit(0); + } +} diff --git a/streams/examples/src/main/java/org/apache/kafka/streams/examples/wordcount/WordCountDemo.java b/streams/examples/src/main/java/org/apache/kafka/streams/examples/wordcount/WordCountDemo.java new file mode 100644 index 0000000..4ca5d73 --- /dev/null +++ 
b/streams/examples/src/main/java/org/apache/kafka/streams/examples/wordcount/WordCountDemo.java @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.examples.wordcount; + +import org.apache.kafka.clients.consumer.ConsumerConfig; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.streams.KafkaStreams; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.kstream.KStream; +import org.apache.kafka.streams.kstream.KTable; +import org.apache.kafka.streams.kstream.Produced; + +import java.io.FileInputStream; +import java.io.IOException; +import java.util.Arrays; +import java.util.Locale; +import java.util.Properties; +import java.util.concurrent.CountDownLatch; + +/** + * Demonstrates, using the high-level KStream DSL, how to implement the WordCount program + * that computes a simple word occurrence histogram from an input text. + *
<p>
                + * In this example, the input stream reads from a topic named "streams-plaintext-input", where the values of messages + * represent lines of text; and the histogram output is written to topic "streams-wordcount-output" where each record + * is an updated count of a single word. + *
<p>
                + * Before running this example you must create the input topic and the output topic (e.g. via + * {@code bin/kafka-topics.sh --create ...}), and write some data to the input topic (e.g. via + * {@code bin/kafka-console-producer.sh}). Otherwise you won't see any data arriving in the output topic. + */ +public final class WordCountDemo { + + public static final String INPUT_TOPIC = "streams-plaintext-input"; + public static final String OUTPUT_TOPIC = "streams-wordcount-output"; + + static Properties getStreamsConfig(final String[] args) throws IOException { + final Properties props = new Properties(); + if (args != null && args.length > 0) { + try (final FileInputStream fis = new FileInputStream(args[0])) { + props.load(fis); + } + if (args.length > 1) { + System.out.println("Warning: Some command line arguments were ignored. This demo only accepts an optional configuration file."); + } + } + props.putIfAbsent(StreamsConfig.APPLICATION_ID_CONFIG, "streams-wordcount"); + props.putIfAbsent(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9092"); + props.putIfAbsent(StreamsConfig.CACHE_MAX_BYTES_BUFFERING_CONFIG, 0); + props.putIfAbsent(StreamsConfig.DEFAULT_KEY_SERDE_CLASS_CONFIG, Serdes.String().getClass().getName()); + props.putIfAbsent(StreamsConfig.DEFAULT_VALUE_SERDE_CLASS_CONFIG, Serdes.String().getClass().getName()); + + // setting offset reset to earliest so that we can re-run the demo code with the same pre-loaded data + // Note: To re-run the demo, you need to use the offset reset tool: + // https://cwiki.apache.org/confluence/display/KAFKA/Kafka+Streams+Application+Reset+Tool + props.putIfAbsent(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "earliest"); + return props; + } + + static void createWordCountStream(final StreamsBuilder builder) { + final KStream source = builder.stream(INPUT_TOPIC); + + final KTable counts = source + .flatMapValues(value -> Arrays.asList(value.toLowerCase(Locale.getDefault()).split("\\W+"))) + .groupBy((key, value) -> value) + .count(); + + // need to override value serde to Long type + counts.toStream().to(OUTPUT_TOPIC, Produced.with(Serdes.String(), Serdes.Long())); + } + + public static void main(final String[] args) throws IOException { + final Properties props = getStreamsConfig(args); + + final StreamsBuilder builder = new StreamsBuilder(); + createWordCountStream(builder); + final KafkaStreams streams = new KafkaStreams(builder.build(), props); + final CountDownLatch latch = new CountDownLatch(1); + + // attach shutdown handler to catch control-c + Runtime.getRuntime().addShutdownHook(new Thread("streams-wordcount-shutdown-hook") { + @Override + public void run() { + streams.close(); + latch.countDown(); + } + }); + + try { + streams.start(); + latch.await(); + } catch (final Throwable e) { + System.exit(1); + } + System.exit(0); + } +} diff --git a/streams/examples/src/main/java/org/apache/kafka/streams/examples/wordcount/WordCountProcessorDemo.java b/streams/examples/src/main/java/org/apache/kafka/streams/examples/wordcount/WordCountProcessorDemo.java new file mode 100644 index 0000000..014923f --- /dev/null +++ b/streams/examples/src/main/java/org/apache/kafka/streams/examples/wordcount/WordCountProcessorDemo.java @@ -0,0 +1,152 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.examples.wordcount; + +import java.io.FileInputStream; +import java.io.IOException; +import java.time.Duration; +import java.util.Locale; +import java.util.Properties; +import java.util.concurrent.CountDownLatch; +import org.apache.kafka.clients.consumer.ConsumerConfig; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.streams.KafkaStreams; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.Topology; +import org.apache.kafka.streams.processor.PunctuationType; +import org.apache.kafka.streams.processor.api.Processor; +import org.apache.kafka.streams.processor.api.ProcessorContext; +import org.apache.kafka.streams.processor.api.Record; +import org.apache.kafka.streams.state.KeyValueIterator; +import org.apache.kafka.streams.state.KeyValueStore; +import org.apache.kafka.streams.state.Stores; + +/** + * Demonstrates, using the low-level Processor APIs, how to implement the WordCount program + * that computes a simple word occurrence histogram from an input text. + *
<p>
                + * Note: This is simplified code that only works correctly for single partition input topics. + * Check out {@link WordCountDemo} for a generic example. + *
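+ * (The limitation exists because each input partition is processed by its own stream task with its own local "Counts" store, so with multiple partitions the same word would be counted separately per partition; WordCountDemo avoids this by grouping by word, which repartitions the input before counting.)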
<p>
                + * In this example, the input stream reads from a topic named "streams-plaintext-input", where the values of messages + * represent lines of text; and the histogram output is written to topic "streams-wordcount-processor-output" where each record + * is an updated count of a single word. + *
<p>
                + * Before running this example you must create the input topic and the output topic (e.g. via + * {@code bin/kafka-topics.sh --create ...}), and write some data to the input topic (e.g. via + * {@code bin/kafka-console-producer.sh}). Otherwise you won't see any data arriving in the output topic. + */ +public final class WordCountProcessorDemo { + static class WordCountProcessor implements Processor { + private KeyValueStore kvStore; + + @Override + public void init(final ProcessorContext context) { + context.schedule(Duration.ofSeconds(1), PunctuationType.STREAM_TIME, timestamp -> { + try (final KeyValueIterator iter = kvStore.all()) { + System.out.println("----------- " + timestamp + " ----------- "); + + while (iter.hasNext()) { + final KeyValue entry = iter.next(); + + System.out.println("[" + entry.key + ", " + entry.value + "]"); + + context.forward(new Record<>(entry.key, entry.value.toString(), timestamp)); + } + } + }); + kvStore = context.getStateStore("Counts"); + } + + @Override + public void process(final Record record) { + final String[] words = record.value().toLowerCase(Locale.getDefault()).split("\\W+"); + + for (final String word : words) { + final Integer oldValue = kvStore.get(word); + + if (oldValue == null) { + kvStore.put(word, 1); + } else { + kvStore.put(word, oldValue + 1); + } + } + } + + @Override + public void close() { + // close any resources managed by this processor + // Note: Do not close any StateStores as these are managed by the library + } + } + + public static void main(final String[] args) throws IOException { + final Properties props = new Properties(); + if (args != null && args.length > 0) { + try (final FileInputStream fis = new FileInputStream(args[0])) { + props.load(fis); + } + if (args.length > 1) { + System.out.println("Warning: Some command line arguments were ignored. 
This demo only accepts an optional configuration file."); + } + } + + props.putIfAbsent(StreamsConfig.APPLICATION_ID_CONFIG, "streams-wordcount-processor"); + props.putIfAbsent(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9092"); + props.putIfAbsent(StreamsConfig.CACHE_MAX_BYTES_BUFFERING_CONFIG, 0); + props.putIfAbsent(StreamsConfig.DEFAULT_KEY_SERDE_CLASS_CONFIG, Serdes.String().getClass()); + props.putIfAbsent(StreamsConfig.DEFAULT_VALUE_SERDE_CLASS_CONFIG, Serdes.String().getClass()); + + // setting offset reset to earliest so that we can re-run the demo code with the same pre-loaded data + props.putIfAbsent(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "earliest"); + + final Topology builder = new Topology(); + + builder.addSource("Source", "streams-plaintext-input"); + + builder.addProcessor("Process", WordCountProcessor::new, "Source"); + + builder.addStateStore(Stores.keyValueStoreBuilder( + Stores.inMemoryKeyValueStore("Counts"), + Serdes.String(), + Serdes.Integer()), + "Process"); + + builder.addSink("Sink", "streams-wordcount-processor-output", "Process"); + + final KafkaStreams streams = new KafkaStreams(builder, props); + final CountDownLatch latch = new CountDownLatch(1); + + // attach shutdown handler to catch control-c + Runtime.getRuntime().addShutdownHook(new Thread("streams-wordcount-shutdown-hook") { + @Override + public void run() { + streams.close(); + latch.countDown(); + } + }); + + try { + streams.start(); + latch.await(); + } catch (final Throwable e) { + System.exit(1); + } + System.exit(0); + } +} diff --git a/streams/examples/src/main/java/org/apache/kafka/streams/examples/wordcount/WordCountTransformerDemo.java b/streams/examples/src/main/java/org/apache/kafka/streams/examples/wordcount/WordCountTransformerDemo.java new file mode 100644 index 0000000..028d317 --- /dev/null +++ b/streams/examples/src/main/java/org/apache/kafka/streams/examples/wordcount/WordCountTransformerDemo.java @@ -0,0 +1,167 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.examples.wordcount; + +import org.apache.kafka.clients.consumer.ConsumerConfig; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.streams.KafkaStreams; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.kstream.Transformer; +import org.apache.kafka.streams.kstream.TransformerSupplier; +import org.apache.kafka.streams.processor.ConnectedStoreProvider; +import org.apache.kafka.streams.processor.ProcessorContext; +import org.apache.kafka.streams.processor.PunctuationType; +import org.apache.kafka.streams.state.KeyValueIterator; +import org.apache.kafka.streams.state.KeyValueStore; +import org.apache.kafka.streams.state.StoreBuilder; +import org.apache.kafka.streams.state.Stores; + +import java.io.FileInputStream; +import java.io.IOException; +import java.time.Duration; +import java.util.Collections; +import java.util.Locale; +import java.util.Properties; +import java.util.Set; +import java.util.concurrent.CountDownLatch; + +/** + * Demonstrates, using a {@link Transformer} which combines the low-level Processor APIs with the high-level Kafka Streams DSL, + * how to implement the WordCount program that computes a simple word occurrence histogram from an input text. + *
<p>
                + * Note: This is simplified code that only works correctly for single partition input topics. + * Check out {@link WordCountDemo} for a generic example. + *
<p>
                + * In this example, the input stream reads from a topic named "streams-plaintext-input", where the values of messages + * represent lines of text; and the histogram output is written to topic "streams-wordcount-processor-output" where each record + * is an updated count of a single word. + *
<p>
                + * This example differs from {@link WordCountProcessorDemo} in that it uses a {@link Transformer} to define the word + * count logic, and the topology is wired up through a {@link StreamsBuilder}, which more closely resembles the high-level DSL. + * Additionally, the {@link TransformerSupplier} specifies the {@link StoreBuilder} that the {@link Transformer} needs + * by implementing {@link ConnectedStoreProvider#stores()}. + *
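For contrast with the stores() hook just described, here is a hedged sketch (all class names are illustrative, not part of the demo) of the manual wiring it replaces: when the supplier does not implement ConnectedStoreProvider#stores(), the StoreBuilder must be registered on the StreamsBuilder and the store name passed explicitly to transform().

import org.apache.kafka.common.serialization.Serdes;
import org.apache.kafka.streams.KeyValue;
import org.apache.kafka.streams.StreamsBuilder;
import org.apache.kafka.streams.Topology;
import org.apache.kafka.streams.kstream.Transformer;
import org.apache.kafka.streams.kstream.TransformerSupplier;
import org.apache.kafka.streams.processor.ProcessorContext;
import org.apache.kafka.streams.state.KeyValueStore;
import org.apache.kafka.streams.state.Stores;

import java.util.Locale;

public class ManualStoreWiringSketch {

    // A counting transformer whose supplier does NOT implement ConnectedStoreProvider#stores().
    static class CountingTransformer implements Transformer<String, String, KeyValue<String, String>> {
        private KeyValueStore<String, Integer> kvStore;

        @Override
        public void init(final ProcessorContext context) {
            kvStore = context.getStateStore("Counts");
        }

        @Override
        public KeyValue<String, String> transform(final String key, final String line) {
            for (final String word : line.toLowerCase(Locale.getDefault()).split("\\W+")) {
                final Integer oldValue = kvStore.get(word);
                kvStore.put(word, oldValue == null ? 1 : oldValue + 1);
            }
            return null; // counts live in the store; nothing is forwarded per record
        }

        @Override
        public void close() {}
    }

    public static Topology build() {
        final StreamsBuilder builder = new StreamsBuilder();
        // 1. Register the store by hand ...
        builder.addStateStore(Stores.keyValueStoreBuilder(
                Stores.inMemoryKeyValueStore("Counts"), Serdes.String(), Serdes.Integer()));
        // 2. ... and name it explicitly when attaching the transformer.
        builder.<String, String>stream("streams-plaintext-input")
               .transform(new TransformerSupplier<String, String, KeyValue<String, String>>() {
                   @Override
                   public Transformer<String, String, KeyValue<String, String>> get() {
                       return new CountingTransformer();
                   }
               }, "Counts")
               .to("streams-wordcount-processor-output");
        return builder.build();
    }
}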
<p>
                + * Before running this example you must create the input topic and the output topic (e.g. via + * {@code bin/kafka-topics.sh --create ...}), and write some data to the input topic (e.g. via + * {@code bin/kafka-console-producer.sh}). Otherwise you won't see any data arriving in the output topic. + */ +public final class WordCountTransformerDemo { + + static class MyTransformerSupplier implements TransformerSupplier> { + + @Override + public Transformer> get() { + return new Transformer>() { + private KeyValueStore kvStore; + + @Override + public void init(final ProcessorContext context) { + context.schedule(Duration.ofSeconds(1), PunctuationType.STREAM_TIME, timestamp -> { + try (final KeyValueIterator iter = kvStore.all()) { + System.out.println("----------- " + timestamp + " ----------- "); + + while (iter.hasNext()) { + final KeyValue entry = iter.next(); + + System.out.println("[" + entry.key + ", " + entry.value + "]"); + + context.forward(entry.key, entry.value.toString()); + } + } + }); + this.kvStore = context.getStateStore("Counts"); + } + + @Override + public KeyValue transform(final String dummy, final String line) { + final String[] words = line.toLowerCase(Locale.getDefault()).split("\\W+"); + + for (final String word : words) { + final Integer oldValue = this.kvStore.get(word); + + if (oldValue == null) { + this.kvStore.put(word, 1); + } else { + this.kvStore.put(word, oldValue + 1); + } + } + + return null; + } + + @Override + public void close() {} + }; + } + + @Override + public Set> stores() { + return Collections.singleton(Stores.keyValueStoreBuilder( + Stores.inMemoryKeyValueStore("Counts"), + Serdes.String(), + Serdes.Integer())); + } + } + + public static void main(final String[] args) throws IOException { + final Properties props = new Properties(); + if (args != null && args.length > 0) { + try (final FileInputStream fis = new FileInputStream(args[0])) { + props.load(fis); + } + if (args.length > 1) { + System.out.println("Warning: Some command line arguments were ignored. 
This demo only accepts an optional configuration file."); + } + } + props.putIfAbsent(StreamsConfig.APPLICATION_ID_CONFIG, "streams-wordcount-transformer"); + props.putIfAbsent(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9092"); + props.putIfAbsent(StreamsConfig.CACHE_MAX_BYTES_BUFFERING_CONFIG, 0); + props.putIfAbsent(StreamsConfig.DEFAULT_KEY_SERDE_CLASS_CONFIG, Serdes.String().getClass()); + props.putIfAbsent(StreamsConfig.DEFAULT_VALUE_SERDE_CLASS_CONFIG, Serdes.String().getClass()); + + // setting offset reset to earliest so that we can re-run the demo code with the same pre-loaded data + props.putIfAbsent(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "earliest"); + + final StreamsBuilder builder = new StreamsBuilder(); + + builder.stream("streams-plaintext-input") + .transform(new MyTransformerSupplier()) + .to("streams-wordcount-processor-output"); + + final KafkaStreams streams = new KafkaStreams(builder.build(), props); + final CountDownLatch latch = new CountDownLatch(1); + + // attach shutdown handler to catch control-c + Runtime.getRuntime().addShutdownHook(new Thread("streams-wordcount-shutdown-hook") { + @Override + public void run() { + streams.close(); + latch.countDown(); + } + }); + + try { + streams.start(); + latch.await(); + } catch (final Throwable e) { + System.exit(1); + } + System.exit(0); + } +} \ No newline at end of file diff --git a/streams/examples/src/test/java/org/apache/kafka/streams/examples/docs/DeveloperGuideTesting.java b/streams/examples/src/test/java/org/apache/kafka/streams/examples/docs/DeveloperGuideTesting.java new file mode 100644 index 0000000..41b61e3 --- /dev/null +++ b/streams/examples/src/test/java/org/apache/kafka/streams/examples/docs/DeveloperGuideTesting.java @@ -0,0 +1,184 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.examples.docs; + +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.TestInputTopic; +import org.apache.kafka.streams.TestOutputTopic; +import org.apache.kafka.streams.Topology; +import org.apache.kafka.streams.TopologyTestDriver; +import org.apache.kafka.streams.processor.PunctuationType; +import org.apache.kafka.streams.processor.api.Processor; +import org.apache.kafka.streams.processor.api.ProcessorContext; +import org.apache.kafka.streams.processor.api.ProcessorSupplier; +import org.apache.kafka.streams.processor.api.Record; +import org.apache.kafka.streams.state.KeyValueIterator; +import org.apache.kafka.streams.state.KeyValueStore; +import org.apache.kafka.streams.state.Stores; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.time.Duration; +import java.time.Instant; +import java.util.Properties; + +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.core.Is.is; + +/** + * This is code sample in docs/streams/developer-guide/testing.html + */ + +public class DeveloperGuideTesting { + private TopologyTestDriver testDriver; + private TestInputTopic inputTopic; + private TestOutputTopic outputTopic; + private KeyValueStore store; + + private Serde stringSerde = new Serdes.StringSerde(); + private Serde longSerde = new Serdes.LongSerde(); + + @BeforeEach + public void setup() { + final Topology topology = new Topology(); + topology.addSource("sourceProcessor", "input-topic"); + topology.addProcessor("aggregator", new CustomMaxAggregatorSupplier(), "sourceProcessor"); + topology.addStateStore( + Stores.keyValueStoreBuilder( + Stores.inMemoryKeyValueStore("aggStore"), + Serdes.String(), + Serdes.Long()).withLoggingDisabled(), // need to disable logging to allow store pre-populating + "aggregator"); + topology.addSink("sinkProcessor", "result-topic", "aggregator"); + + // setup test driver + final Properties props = new Properties(); + props.setProperty(StreamsConfig.DEFAULT_KEY_SERDE_CLASS_CONFIG, Serdes.String().getClass().getName()); + props.setProperty(StreamsConfig.DEFAULT_VALUE_SERDE_CLASS_CONFIG, Serdes.Long().getClass().getName()); + testDriver = new TopologyTestDriver(topology, props); + + // setup test topics + inputTopic = testDriver.createInputTopic("input-topic", stringSerde.serializer(), longSerde.serializer()); + outputTopic = testDriver.createOutputTopic("result-topic", stringSerde.deserializer(), longSerde.deserializer()); + + // pre-populate store + store = testDriver.getKeyValueStore("aggStore"); + store.put("a", 21L); + } + + @AfterEach + public void tearDown() { + testDriver.close(); + } + + + @Test + public void shouldFlushStoreForFirstInput() { + inputTopic.pipeInput("a", 1L); + assertThat(outputTopic.readKeyValue(), equalTo(new KeyValue<>("a", 21L))); + assertThat(outputTopic.isEmpty(), is(true)); + } + + @Test + public void shouldNotUpdateStoreForSmallerValue() { + inputTopic.pipeInput("a", 1L); + assertThat(store.get("a"), equalTo(21L)); + assertThat(outputTopic.readKeyValue(), equalTo(new KeyValue<>("a", 21L))); + assertThat(outputTopic.isEmpty(), is(true)); + } + + @Test + public void shouldNotUpdateStoreForLargerValue() { + inputTopic.pipeInput("a", 42L); + assertThat(store.get("a"), 
equalTo(42L)); + assertThat(outputTopic.readKeyValue(), equalTo(new KeyValue<>("a", 42L))); + assertThat(outputTopic.isEmpty(), is(true)); + } + + @Test + public void shouldUpdateStoreForNewKey() { + inputTopic.pipeInput("b", 21L); + assertThat(store.get("b"), equalTo(21L)); + assertThat(outputTopic.readKeyValue(), equalTo(new KeyValue<>("a", 21L))); + assertThat(outputTopic.readKeyValue(), equalTo(new KeyValue<>("b", 21L))); + assertThat(outputTopic.isEmpty(), is(true)); + } + + @Test + public void shouldPunctuateIfEvenTimeAdvances() { + final Instant recordTime = Instant.now(); + inputTopic.pipeInput("a", 1L, recordTime); + assertThat(outputTopic.readKeyValue(), equalTo(new KeyValue<>("a", 21L))); + + inputTopic.pipeInput("a", 1L, recordTime); + assertThat(outputTopic.isEmpty(), is(true)); + + inputTopic.pipeInput("a", 1L, recordTime.plusSeconds(10L)); + assertThat(outputTopic.readKeyValue(), equalTo(new KeyValue<>("a", 21L))); + assertThat(outputTopic.isEmpty(), is(true)); + } + + @Test + public void shouldPunctuateIfWallClockTimeAdvances() { + testDriver.advanceWallClockTime(Duration.ofSeconds(60)); + assertThat(outputTopic.readKeyValue(), equalTo(new KeyValue<>("a", 21L))); + assertThat(outputTopic.isEmpty(), is(true)); + } + + public static class CustomMaxAggregatorSupplier implements ProcessorSupplier { + @Override + public Processor get() { + return new CustomMaxAggregator(); + } + } + + public static class CustomMaxAggregator implements Processor { + ProcessorContext context; + private KeyValueStore store; + + @SuppressWarnings("unchecked") + @Override + public void init(final ProcessorContext context) { + this.context = context; + context.schedule(Duration.ofSeconds(60), PunctuationType.WALL_CLOCK_TIME, this::flushStore); + context.schedule(Duration.ofSeconds(10), PunctuationType.STREAM_TIME, this::flushStore); + store = context.getStateStore("aggStore"); + } + + @Override + public void process(final Record record) { + final Long oldValue = store.get(record.key()); + if (oldValue == null || record.value() > oldValue) { + store.put(record.key(), record.value()); + } + } + + private void flushStore(final long timestamp) { + try (final KeyValueIterator it = store.all()) { + while (it.hasNext()) { + final KeyValue next = it.next(); + context.forward(new Record<>(next.key, next.value, timestamp)); + } + } + } + } +} diff --git a/streams/examples/src/test/java/org/apache/kafka/streams/examples/wordcount/WordCountDemoTest.java b/streams/examples/src/test/java/org/apache/kafka/streams/examples/wordcount/WordCountDemoTest.java new file mode 100644 index 0000000..ccb8a74 --- /dev/null +++ b/streams/examples/src/test/java/org/apache/kafka/streams/examples/wordcount/WordCountDemoTest.java @@ -0,0 +1,132 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.examples.wordcount; + +import org.apache.kafka.clients.CommonClientConfigs; +import org.apache.kafka.common.serialization.LongDeserializer; +import org.apache.kafka.common.serialization.StringDeserializer; +import org.apache.kafka.common.serialization.StringSerializer; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.TopologyTestDriver; +import org.apache.kafka.streams.TestInputTopic; +import org.apache.kafka.streams.TestOutputTopic; +import org.apache.kafka.test.TestUtils; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.io.File; +import java.io.IOException; +import java.nio.file.Files; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Properties; + +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.MatcherAssert.assertThat; + +/** + * Unit test of {@link WordCountDemo} stream using TopologyTestDriver. + */ +public class WordCountDemoTest { + + private TopologyTestDriver testDriver; + private TestInputTopic inputTopic; + private TestOutputTopic outputTopic; + + @BeforeEach + public void setup() throws IOException { + final StreamsBuilder builder = new StreamsBuilder(); + //Create Actual Stream Processing pipeline + WordCountDemo.createWordCountStream(builder); + testDriver = new TopologyTestDriver(builder.build(), WordCountDemo.getStreamsConfig(null)); + inputTopic = testDriver.createInputTopic(WordCountDemo.INPUT_TOPIC, new StringSerializer(), new StringSerializer()); + outputTopic = testDriver.createOutputTopic(WordCountDemo.OUTPUT_TOPIC, new StringDeserializer(), new LongDeserializer()); + } + + @AfterEach + public void tearDown() { + try { + testDriver.close(); + } catch (final RuntimeException e) { + // https://issues.apache.org/jira/browse/KAFKA-6647 causes exception when executed in Windows, ignoring it + // Logged stacktrace cannot be avoided + System.out.println("Ignoring exception, test failing in Windows due this exception:" + e.getLocalizedMessage()); + } + } + + + /** + * Simple test validating count of one word + */ + @Test + public void testOneWord() { + //Feed word "Hello" to inputTopic and no kafka key, timestamp is irrelevant in this case + inputTopic.pipeInput("Hello"); + //Read and validate output to match word as key and count as value + assertThat(outputTopic.readKeyValue(), equalTo(new KeyValue<>("hello", 1L))); + //No more output in topic + assertThat(outputTopic.isEmpty(), is(true)); + } + + /** + * Test Word count of sentence list. 
+ */ + @Test + public void testCountListOfWords() { + final List inputValues = Arrays.asList( + "Apache Kafka Streams Example", + "Using \t\t Kafka Streams\tTest Utils", + "Reading and Writing Kafka Topic" + ); + final Map expectedWordCounts = new HashMap<>(); + expectedWordCounts.put("apache", 1L); + expectedWordCounts.put("kafka", 3L); + expectedWordCounts.put("streams", 2L); + expectedWordCounts.put("example", 1L); + expectedWordCounts.put("using", 1L); + expectedWordCounts.put("test", 1L); + expectedWordCounts.put("utils", 1L); + expectedWordCounts.put("reading", 1L); + expectedWordCounts.put("and", 1L); + expectedWordCounts.put("writing", 1L); + expectedWordCounts.put("topic", 1L); + + inputTopic.pipeValueList(inputValues); + final Map actualWordCounts = outputTopic.readKeyValuesToMap(); + assertThat(actualWordCounts, equalTo(expectedWordCounts)); + } + + @Test + public void testGetStreamsConfig() throws IOException { + final File tmp = TestUtils.tempFile("bootstrap.servers=localhost:1234"); + try { + Properties config = WordCountDemo.getStreamsConfig(new String[] {tmp.getPath()}); + assertThat("localhost:1234", equalTo(config.getProperty(CommonClientConfigs.BOOTSTRAP_SERVERS_CONFIG))); + + config = WordCountDemo.getStreamsConfig(new String[] {tmp.getPath(), "extra", "args"}); + assertThat("localhost:1234", equalTo(config.getProperty(CommonClientConfigs.BOOTSTRAP_SERVERS_CONFIG))); + } finally { + Files.deleteIfExists(tmp.toPath()); + } + } + +} diff --git a/streams/examples/src/test/java/org/apache/kafka/streams/examples/wordcount/WordCountProcessorTest.java b/streams/examples/src/test/java/org/apache/kafka/streams/examples/wordcount/WordCountProcessorTest.java new file mode 100644 index 0000000..1343a29 --- /dev/null +++ b/streams/examples/src/test/java/org/apache/kafka/streams/examples/wordcount/WordCountProcessorTest.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.examples.wordcount; + +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.streams.processor.api.MockProcessorContext; +import org.apache.kafka.streams.processor.api.Processor; +import org.apache.kafka.streams.processor.api.Record; +import org.apache.kafka.streams.state.KeyValueStore; +import org.apache.kafka.streams.state.Stores; +import org.junit.jupiter.api.Test; + +import java.util.Arrays; +import java.util.List; + +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.jupiter.api.Assertions.assertTrue; + +/** + * Demonstrate the use of {@link MockProcessorContext} for testing the {@link Processor} in the {@link WordCountProcessorDemo}. 
+ */ +public class WordCountProcessorTest { + @Test + public void test() { + final MockProcessorContext context = new MockProcessorContext(); + + // Create, initialize, and register the state store. + final KeyValueStore store = + Stores.keyValueStoreBuilder(Stores.inMemoryKeyValueStore("Counts"), Serdes.String(), Serdes.Integer()) + .withLoggingDisabled() // Changelog is not supported by MockProcessorContext. + // Caching is disabled by default, but FYI: caching is also not supported by MockProcessorContext. + .build(); + store.init(context.getStateStoreContext(), store); + + // Create and initialize the processor under test + final Processor processor = new WordCountProcessorDemo.WordCountProcessor(); + processor.init(context); + + // send a record to the processor + processor.process(new Record<>("key", "alpha beta\tgamma\n\talpha", 0L)); + + // note that the processor does not forward during process() + assertTrue(context.forwarded().isEmpty()); + + // now, we trigger the punctuator, which iterates over the state store and forwards the contents. + context.scheduledPunctuators().get(0).getPunctuator().punctuate(0L); + + // finally, we can verify the output. + final List> expected = Arrays.asList( + new MockProcessorContext.CapturedForward<>(new Record<>("alpha", "2", 0L)), + new MockProcessorContext.CapturedForward<>(new Record<>("beta", "1", 0L)), + new MockProcessorContext.CapturedForward<>(new Record<>("gamma", "1", 0L)) + ); + assertThat(context.forwarded(), is(expected)); + } +} diff --git a/streams/examples/src/test/java/org/apache/kafka/streams/examples/wordcount/WordCountTransformerTest.java b/streams/examples/src/test/java/org/apache/kafka/streams/examples/wordcount/WordCountTransformerTest.java new file mode 100644 index 0000000..95a6391 --- /dev/null +++ b/streams/examples/src/test/java/org/apache/kafka/streams/examples/wordcount/WordCountTransformerTest.java @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.examples.wordcount; + +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.kstream.Transformer; +import org.apache.kafka.streams.processor.Cancellable; +import org.apache.kafka.streams.processor.PunctuationType; +import org.apache.kafka.streams.processor.Punctuator; +import org.apache.kafka.streams.processor.StateStore; +import org.apache.kafka.streams.processor.api.MockProcessorContext; +import org.apache.kafka.streams.processor.api.Record; +import org.apache.kafka.streams.state.StoreBuilder; +import org.junit.jupiter.api.Test; + +import java.time.Duration; +import java.util.List; + +import static java.util.Arrays.asList; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.is; +import static org.junit.jupiter.api.Assertions.assertTrue; + +/** + * Demonstrate the use of {@link MockProcessorContext} for testing the {@link Transformer} in the {@link WordCountTransformerDemo}. + */ +public class WordCountTransformerTest { + @Test + public void test() { + final MockProcessorContext context = new MockProcessorContext<>(); + + // Create and initialize the transformer under test; including its provided store + final WordCountTransformerDemo.MyTransformerSupplier supplier = new WordCountTransformerDemo.MyTransformerSupplier(); + for (final StoreBuilder storeBuilder : supplier.stores()) { + final StateStore store = storeBuilder + .withLoggingDisabled() // Changelog is not supported by MockProcessorContext. + // Caching is disabled by default, but FYI: caching is also not supported by MockProcessorContext. + .build(); + store.init(context.getStateStoreContext(), store); + context.getStateStoreContext().register(store, null); + } + final Transformer> transformer = supplier.get(); + transformer.init(new org.apache.kafka.streams.processor.MockProcessorContext() { + @Override + public S getStateStore(final String name) { + return context.getStateStore(name); + } + + @Override + public void forward(final K key, final V value) { + context.forward(new Record<>((String) key, (String) value, 0L)); + } + + @Override + public Cancellable schedule(final Duration interval, final PunctuationType type, final Punctuator callback) { + return context.schedule(interval, type, callback); + } + }); + + // send a record to the transformer + transformer.transform("key", "alpha beta\tgamma\n\talpha"); + + // note that the transformer does not forward during transform() + assertTrue(context.forwarded().isEmpty()); + + // now, we trigger the punctuator, which iterates over the state store and forwards the contents. + context.scheduledPunctuators().get(0).getPunctuator().punctuate(0L); + + // finally, we can verify the output. + final List> capturedForwards = context.forwarded(); + final List> expected = asList( + new MockProcessorContext.CapturedForward<>(new Record<>("alpha", "2", 0L)), + new MockProcessorContext.CapturedForward<>(new Record<>("beta", "1", 0L)), + new MockProcessorContext.CapturedForward<>(new Record<>("gamma", "1", 0L)) + ); + assertThat(capturedForwards, is(expected)); + } +} diff --git a/streams/quickstart/java/pom.xml b/streams/quickstart/java/pom.xml new file mode 100644 index 0000000..0f93c6e --- /dev/null +++ b/streams/quickstart/java/pom.xml @@ -0,0 +1,36 @@ + + + + 4.0.0 + + + UTF-8 + + + + org.apache.kafka + streams-quickstart + 3.1.0 + .. 
+ + + streams-quickstart-java + maven-archetype + + diff --git a/streams/quickstart/java/src/main/resources/META-INF/maven/archetype-metadata.xml b/streams/quickstart/java/src/main/resources/META-INF/maven/archetype-metadata.xml new file mode 100644 index 0000000..9e0d8bd --- /dev/null +++ b/streams/quickstart/java/src/main/resources/META-INF/maven/archetype-metadata.xml @@ -0,0 +1,34 @@ + + + + + + src/main/java + + **/*.java + + + + src/main/resources + + + diff --git a/streams/quickstart/java/src/main/resources/archetype-resources/pom.xml b/streams/quickstart/java/src/main/resources/archetype-resources/pom.xml new file mode 100644 index 0000000..07a0b67 --- /dev/null +++ b/streams/quickstart/java/src/main/resources/archetype-resources/pom.xml @@ -0,0 +1,136 @@ + + + + 4.0.0 + + ${groupId} + ${artifactId} + ${version} + jar + + Kafka Streams Quickstart :: Java + + + UTF-8 + 3.1.0 + 1.7.7 + 1.2.17 + + + + + apache.snapshots + Apache Development Snapshot Repository + https://repository.apache.org/content/repositories/snapshots/ + + false + + + true + + + + + + + + + + org.apache.maven.plugins + maven-compiler-plugin + 3.1 + + 1.8 + 1.8 + + + + + + + + maven-compiler-plugin + + 1.8 + 1.8 + jdt + + + + org.eclipse.tycho + tycho-compiler-jdt + 0.21.0 + + + + + org.eclipse.m2e + lifecycle-mapping + 1.0.0 + + + + + + org.apache.maven.plugins + maven-assembly-plugin + [2.4,) + + single + + + + + + + + + org.apache.maven.plugins + maven-compiler-plugin + [3.1,) + + testCompile + compile + + + + + + + + + + + + + + + + + + org.apache.kafka + kafka-streams + ${kafka.version} + + + diff --git a/streams/quickstart/java/src/main/resources/archetype-resources/src/main/java/LineSplit.java b/streams/quickstart/java/src/main/resources/archetype-resources/src/main/java/LineSplit.java new file mode 100644 index 0000000..d712a83 --- /dev/null +++ b/streams/quickstart/java/src/main/resources/archetype-resources/src/main/java/LineSplit.java @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package ${package}; + +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.streams.KafkaStreams; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.Topology; +import org.apache.kafka.streams.kstream.ValueMapper; + +import java.util.Arrays; +import java.util.Properties; +import java.util.concurrent.CountDownLatch; + +/** + * In this example, we implement a simple LineSplit program using the high-level Streams DSL + * that reads from a source topic "streams-plaintext-input", where the values of messages represent lines of text; + * the code split each text line in string into words and then write back into a sink topic "streams-linesplit-output" where + * each record represents a single word. + */ +public class LineSplit { + + public static void main(String[] args) throws Exception { + Properties props = new Properties(); + props.put(StreamsConfig.APPLICATION_ID_CONFIG, "streams-linesplit"); + props.put(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9092"); + props.put(StreamsConfig.DEFAULT_KEY_SERDE_CLASS_CONFIG, Serdes.String().getClass()); + props.put(StreamsConfig.DEFAULT_VALUE_SERDE_CLASS_CONFIG, Serdes.String().getClass()); + + final StreamsBuilder builder = new StreamsBuilder(); + + builder.stream("streams-plaintext-input") + .flatMapValues(value -> Arrays.asList(value.split("\\W+"))) + .to("streams-linesplit-output"); + + final Topology topology = builder.build(); + final KafkaStreams streams = new KafkaStreams(topology, props); + final CountDownLatch latch = new CountDownLatch(1); + + // attach shutdown handler to catch control-c + Runtime.getRuntime().addShutdownHook(new Thread("streams-shutdown-hook") { + @Override + public void run() { + streams.close(); + latch.countDown(); + } + }); + + try { + streams.start(); + latch.await(); + } catch (Throwable e) { + System.exit(1); + } + System.exit(0); + } +} diff --git a/streams/quickstart/java/src/main/resources/archetype-resources/src/main/java/Pipe.java b/streams/quickstart/java/src/main/resources/archetype-resources/src/main/java/Pipe.java new file mode 100644 index 0000000..b3152a7 --- /dev/null +++ b/streams/quickstart/java/src/main/resources/archetype-resources/src/main/java/Pipe.java @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package ${package}; + +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.streams.KafkaStreams; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.Topology; + +import java.util.Properties; +import java.util.concurrent.CountDownLatch; + +/** + * In this example, we implement a simple LineSplit program using the high-level Streams DSL + * that reads from a source topic "streams-plaintext-input", where the values of messages represent lines of text, + * and writes the messages as-is into a sink topic "streams-pipe-output". + */ +public class Pipe { + + public static void main(String[] args) throws Exception { + Properties props = new Properties(); + props.put(StreamsConfig.APPLICATION_ID_CONFIG, "streams-pipe"); + props.put(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9092"); + props.put(StreamsConfig.DEFAULT_KEY_SERDE_CLASS_CONFIG, Serdes.String().getClass()); + props.put(StreamsConfig.DEFAULT_VALUE_SERDE_CLASS_CONFIG, Serdes.String().getClass()); + + final StreamsBuilder builder = new StreamsBuilder(); + + builder.stream("streams-plaintext-input").to("streams-pipe-output"); + + final Topology topology = builder.build(); + final KafkaStreams streams = new KafkaStreams(topology, props); + final CountDownLatch latch = new CountDownLatch(1); + + // attach shutdown handler to catch control-c + Runtime.getRuntime().addShutdownHook(new Thread("streams-shutdown-hook") { + @Override + public void run() { + streams.close(); + latch.countDown(); + } + }); + + try { + streams.start(); + latch.await(); + } catch (Throwable e) { + System.exit(1); + } + System.exit(0); + } +} diff --git a/streams/quickstart/java/src/main/resources/archetype-resources/src/main/java/WordCount.java b/streams/quickstart/java/src/main/resources/archetype-resources/src/main/java/WordCount.java new file mode 100644 index 0000000..bdbefed --- /dev/null +++ b/streams/quickstart/java/src/main/resources/archetype-resources/src/main/java/WordCount.java @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package ${package}; + +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.streams.KafkaStreams; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.Topology; +import org.apache.kafka.streams.kstream.KeyValueMapper; +import org.apache.kafka.streams.kstream.Materialized; +import org.apache.kafka.streams.kstream.Produced; +import org.apache.kafka.streams.kstream.ValueMapper; +import org.apache.kafka.streams.state.KeyValueStore; + +import java.util.Arrays; +import java.util.Locale; +import java.util.Properties; +import java.util.concurrent.CountDownLatch; + +/** + * In this example, we implement a simple WordCount program using the high-level Streams DSL + * that reads from a source topic "streams-plaintext-input", where the values of messages represent lines of text, + * split each text line into words and then compute the word occurence histogram, write the continuous updated histogram + * into a topic "streams-wordcount-output" where each record is an updated count of a single word. + */ +public class WordCount { + + public static void main(String[] args) throws Exception { + Properties props = new Properties(); + props.put(StreamsConfig.APPLICATION_ID_CONFIG, "streams-wordcount"); + props.put(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9092"); + props.put(StreamsConfig.DEFAULT_KEY_SERDE_CLASS_CONFIG, Serdes.String().getClass()); + props.put(StreamsConfig.DEFAULT_VALUE_SERDE_CLASS_CONFIG, Serdes.String().getClass()); + + final StreamsBuilder builder = new StreamsBuilder(); + + builder.stream("streams-plaintext-input") + .flatMapValues(value -> Arrays.asList(value.toLowerCase(Locale.getDefault()).split("\\W+"))) + .groupBy((key, value) -> value) + .count(Materialized.>as("counts-store")) + .toStream() + .to("streams-wordcount-output", Produced.with(Serdes.String(), Serdes.Long())); + + final Topology topology = builder.build(); + final KafkaStreams streams = new KafkaStreams(topology, props); + final CountDownLatch latch = new CountDownLatch(1); + + // attach shutdown handler to catch control-c + Runtime.getRuntime().addShutdownHook(new Thread("streams-shutdown-hook") { + @Override + public void run() { + streams.close(); + latch.countDown(); + } + }); + + try { + streams.start(); + latch.await(); + } catch (Throwable e) { + System.exit(1); + } + System.exit(0); + } +} diff --git a/streams/quickstart/java/src/main/resources/archetype-resources/src/main/resources/log4j.properties b/streams/quickstart/java/src/main/resources/archetype-resources/src/main/resources/log4j.properties new file mode 100644 index 0000000..b620f1b --- /dev/null +++ b/streams/quickstart/java/src/main/resources/archetype-resources/src/main/resources/log4j.properties @@ -0,0 +1,19 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +log4j.rootLogger=INFO, console + +log4j.appender.console=org.apache.log4j.ConsoleAppender +log4j.appender.console.layout=org.apache.log4j.PatternLayout +log4j.appender.console.layout.ConversionPattern=%d{HH:mm:ss,SSS} %-5p %-60c %x - %m%n \ No newline at end of file diff --git a/streams/quickstart/java/src/test/resources/projects/basic/archetype.properties b/streams/quickstart/java/src/test/resources/projects/basic/archetype.properties new file mode 100644 index 0000000..c4a7c16 --- /dev/null +++ b/streams/quickstart/java/src/test/resources/projects/basic/archetype.properties @@ -0,0 +1,18 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +groupId=org.apache.kafka.archtypetest +version=0.1 +artifactId=basic +package=org.apache.kafka.archetypetest diff --git a/streams/quickstart/java/src/test/resources/projects/basic/goal.txt b/streams/quickstart/java/src/test/resources/projects/basic/goal.txt new file mode 100644 index 0000000..f8808ba --- /dev/null +++ b/streams/quickstart/java/src/test/resources/projects/basic/goal.txt @@ -0,0 +1 @@ +compile \ No newline at end of file diff --git a/streams/quickstart/pom.xml b/streams/quickstart/pom.xml new file mode 100644 index 0000000..e063a39 --- /dev/null +++ b/streams/quickstart/pom.xml @@ -0,0 +1,121 @@ + + + + 4.0.0 + + org.apache.kafka + streams-quickstart + pom + 3.1.0 + + Kafka Streams :: Quickstart + + + org.apache + apache + 18 + + + + java + + + + + org.apache.maven.archetype + archetype-packaging + 2.2 + + + + + + org.apache.maven.plugins + maven-archetype-plugin + 2.2 + + + + + + maven-archetype-plugin + 2.2 + + true + + + + + org.apache.maven.plugins + maven-shade-plugin + 3.1.0 + + + + + + + + + com.github.siom79.japicmp + japicmp-maven-plugin + 0.11.0 + + true + + + + + + org.apache.maven.plugins + maven-resources-plugin + + false + + @ + + + + + org.apache.maven.plugins + maven-gpg-plugin + 1.6 + + + sign-artifacts + verify + + sign + + + ${gpg.keyname} + ${gpg.keyname} + + + + + + + + src/main/resources + true + + + + diff --git a/streams/src/main/java/org/apache/kafka/streams/KafkaClientSupplier.java b/streams/src/main/java/org/apache/kafka/streams/KafkaClientSupplier.java new file mode 100644 index 0000000..fc96ca7 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/KafkaClientSupplier.java @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams; + +import org.apache.kafka.clients.admin.Admin; +import org.apache.kafka.clients.consumer.Consumer; +import org.apache.kafka.clients.producer.Producer; +import org.apache.kafka.streams.kstream.GlobalKTable; +import org.apache.kafka.streams.processor.StateStore; + +import java.util.Map; + +/** + * {@code KafkaClientSupplier} can be used to provide custom Kafka clients to a {@link KafkaStreams} instance. + * + * @see KafkaStreams#KafkaStreams(Topology, java.util.Properties, KafkaClientSupplier) + */ +public interface KafkaClientSupplier { + /** + * Create an {@link Admin} which is used for internal topic management. + * + * @param config Supplied by the {@link java.util.Properties} given to the {@link KafkaStreams} + * @return an instance of {@link Admin} + */ + default Admin getAdmin(final Map config) { + throw new UnsupportedOperationException("Implementations of KafkaClientSupplier should implement the getAdmin() method."); + } + + /** + * Create a {@link Producer} which is used to write records to sink topics. + * + * @param config {@link StreamsConfig#getProducerConfigs(String) producer config} which is supplied by the + * {@link java.util.Properties} given to the {@link KafkaStreams} instance + * @return an instance of Kafka producer + */ + Producer getProducer(final Map config); + + /** + * Create a {@link Consumer} which is used to read records of source topics. + * + * @param config {@link StreamsConfig#getMainConsumerConfigs(String, String, int) consumer config} which is + * supplied by the {@link java.util.Properties} given to the {@link KafkaStreams} instance + * @return an instance of Kafka consumer + */ + Consumer getConsumer(final Map config); + + /** + * Create a {@link Consumer} which is used to read records to restore {@link StateStore}s. + * + * @param config {@link StreamsConfig#getRestoreConsumerConfigs(String) restore consumer config} which is supplied + * by the {@link java.util.Properties} given to the {@link KafkaStreams} + * @return an instance of Kafka consumer + */ + Consumer getRestoreConsumer(final Map config); + + /** + * Create a {@link Consumer} which is used to consume records for {@link GlobalKTable}. + * + * @param config {@link StreamsConfig#getGlobalConsumerConfigs(String) global consumer config} which is supplied + * by the {@link java.util.Properties} given to the {@link KafkaStreams} + * @return an instance of Kafka consumer + */ + Consumer getGlobalConsumer(final Map config); +} diff --git a/streams/src/main/java/org/apache/kafka/streams/KafkaStreams.java b/streams/src/main/java/org/apache/kafka/streams/KafkaStreams.java new file mode 100644 index 0000000..5067da6 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/KafkaStreams.java @@ -0,0 +1,1718 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
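// --- Editorial aside (illustrative sketch, not part of this patch): one way the KafkaClientSupplier
// interface above can be implemented. The class name is hypothetical; the byte-array serdes mirror
// what the default supplier shipped with Streams uses. Assumed extra imports:
// org.apache.kafka.clients.consumer.KafkaConsumer, org.apache.kafka.clients.producer.KafkaProducer,
// org.apache.kafka.common.serialization.ByteArrayDeserializer and ByteArraySerializer.
public class MyClientSupplier implements KafkaClientSupplier {
    @Override
    public Admin getAdmin(final Map<String, Object> config) {
        return Admin.create(config);
    }

    @Override
    public Producer<byte[], byte[]> getProducer(final Map<String, Object> config) {
        // a natural place to plug in interceptors, custom metric reporters, tracing, etc.
        return new KafkaProducer<>(config, new ByteArraySerializer(), new ByteArraySerializer());
    }

    @Override
    public Consumer<byte[], byte[]> getConsumer(final Map<String, Object> config) {
        return new KafkaConsumer<>(config, new ByteArrayDeserializer(), new ByteArrayDeserializer());
    }

    @Override
    public Consumer<byte[], byte[]> getRestoreConsumer(final Map<String, Object> config) {
        return getConsumer(config);
    }

    @Override
    public Consumer<byte[], byte[]> getGlobalConsumer(final Map<String, Object> config) {
        return getConsumer(config);
    }
}
// Usage (see the KafkaStreams constructors later in this patch):
//     new KafkaStreams(topology, props, new MyClientSupplier());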
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams; + +import org.apache.kafka.clients.CommonClientConfigs; +import org.apache.kafka.clients.admin.Admin; +import org.apache.kafka.clients.admin.ListOffsetsResult.ListOffsetsResultInfo; +import org.apache.kafka.clients.admin.MemberToRemove; +import org.apache.kafka.clients.admin.RemoveMembersFromConsumerGroupOptions; +import org.apache.kafka.clients.admin.RemoveMembersFromConsumerGroupResult; +import org.apache.kafka.clients.consumer.KafkaConsumer; +import org.apache.kafka.clients.producer.KafkaProducer; +import org.apache.kafka.common.Metric; +import org.apache.kafka.common.MetricName; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.errors.TimeoutException; +import org.apache.kafka.common.metrics.JmxReporter; +import org.apache.kafka.common.metrics.KafkaMetricsContext; +import org.apache.kafka.common.metrics.MetricConfig; +import org.apache.kafka.common.metrics.Metrics; +import org.apache.kafka.common.metrics.MetricsContext; +import org.apache.kafka.common.metrics.MetricsReporter; +import org.apache.kafka.common.metrics.Sensor; +import org.apache.kafka.common.metrics.Sensor.RecordingLevel; +import org.apache.kafka.common.serialization.Serializer; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.streams.errors.InvalidStateStoreException; +import org.apache.kafka.streams.errors.ProcessorStateException; +import org.apache.kafka.streams.errors.StreamsException; +import org.apache.kafka.streams.errors.StreamsNotStartedException; +import org.apache.kafka.streams.errors.StreamsUncaughtExceptionHandler; +import org.apache.kafka.streams.errors.UnknownStateStoreException; +import org.apache.kafka.streams.errors.InvalidStateStorePartitionException; +import org.apache.kafka.streams.internals.metrics.ClientMetrics; +import org.apache.kafka.streams.processor.StateRestoreListener; +import org.apache.kafka.streams.processor.StateStore; +import org.apache.kafka.streams.processor.StreamPartitioner; +import org.apache.kafka.streams.processor.internals.ClientUtils; +import org.apache.kafka.streams.processor.internals.DefaultKafkaClientSupplier; +import org.apache.kafka.streams.processor.internals.GlobalStreamThread; +import org.apache.kafka.streams.processor.internals.TopologyMetadata; +import org.apache.kafka.streams.processor.internals.StateDirectory; +import org.apache.kafka.streams.processor.internals.StreamThread; +import org.apache.kafka.streams.processor.internals.StreamsMetadataState; +import org.apache.kafka.streams.processor.internals.Task; +import org.apache.kafka.streams.processor.internals.ThreadStateTransitionValidator; +import org.apache.kafka.streams.processor.internals.assignment.AssignorError; +import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl; +import org.apache.kafka.streams.state.HostInfo; +import 
org.apache.kafka.streams.state.internals.GlobalStateStoreProvider; +import org.apache.kafka.streams.state.internals.QueryableStoreProvider; +import org.apache.kafka.streams.state.internals.StreamThreadStateStoreProvider; +import org.slf4j.Logger; + +import java.time.Duration; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.LinkedHashMap; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.Properties; +import java.util.Set; +import java.util.TreeMap; +import java.util.UUID; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Executors; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.TimeUnit; +import java.util.function.Consumer; +import java.util.stream.Collectors; + +import static org.apache.kafka.streams.StreamsConfig.METRICS_RECORDING_LEVEL_CONFIG; +import static org.apache.kafka.streams.errors.StreamsUncaughtExceptionHandler.StreamThreadExceptionResponse.SHUTDOWN_CLIENT; +import static org.apache.kafka.streams.internals.ApiUtils.prepareMillisCheckFailMsgPrefix; +import static org.apache.kafka.streams.internals.ApiUtils.validateMillisecondDuration; +import static org.apache.kafka.streams.processor.internals.ClientUtils.fetchEndOffsets; + +/** + * A Kafka client that allows for performing continuous computation on input coming from one or more input topics and + * sends output to zero, one, or more output topics. + *

+ * The computational logic can be specified either by using the {@link Topology} to define a DAG topology of
+ * {@link org.apache.kafka.streams.processor.api.Processor}s or by using the {@link StreamsBuilder} which provides the high-level DSL to define
+ * transformations.
+ * <p>
+ * One {@code KafkaStreams} instance can contain one or more threads specified in the configs for the processing work.
+ * <p>
+ * A {@code KafkaStreams} instance can co-ordinate with any other instances with the same
+ * {@link StreamsConfig#APPLICATION_ID_CONFIG application ID} (whether in the same process, on other processes on this
+ * machine, or on remote machines) as a single (possibly distributed) stream processing application.
+ * These instances will divide up the work based on the assignment of the input topic partitions so that all partitions
+ * are being consumed.
+ * If instances are added or fail, all (remaining) instances will rebalance the partition assignment among themselves
+ * to balance processing load and ensure that all input topic partitions are processed.
+ * <p>
+ * Internally a {@code KafkaStreams} instance contains a normal {@link KafkaProducer} and {@link KafkaConsumer} instance
+ * that is used for reading input and writing output.
+ * <p>
+ * A simple example might look like this:
+ * <pre> {@code
+ * Properties props = new Properties();
+ * props.put(StreamsConfig.APPLICATION_ID_CONFIG, "my-stream-processing-application");
+ * props.put(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9092");
+ * props.put(StreamsConfig.DEFAULT_KEY_SERDE_CLASS_CONFIG, Serdes.String().getClass());
+ * props.put(StreamsConfig.DEFAULT_VALUE_SERDE_CLASS_CONFIG, Serdes.String().getClass());
+ *
+ * StreamsBuilder builder = new StreamsBuilder();
+ * builder.stream("my-input-topic").mapValues(value -> String.valueOf(value.length())).to("my-output-topic");
+ *
+ * KafkaStreams streams = new KafkaStreams(builder.build(), props);
+ * streams.start();
+ * }</pre>
                + * + * @see org.apache.kafka.streams.StreamsBuilder + * @see org.apache.kafka.streams.Topology + */ +public class KafkaStreams implements AutoCloseable { + + private static final String JMX_PREFIX = "kafka.streams"; + + private static final Set> EXCEPTIONS_NOT_TO_BE_HANDLED_BY_USERS = + new HashSet<>(Arrays.asList(IllegalStateException.class, IllegalArgumentException.class)); + + // processId is expected to be unique across JVMs and to be used + // in userData of the subscription request to allow assignor be aware + // of the co-location of stream thread's consumers. It is for internal + // usage only and should not be exposed to users at all. + private final Time time; + private final Logger log; + private final String clientId; + private final Metrics metrics; + private final StreamsConfig config; + protected final List threads; + protected final StateDirectory stateDirectory; + private final StreamsMetadataState streamsMetadataState; + private final ScheduledExecutorService stateDirCleaner; + private final ScheduledExecutorService rocksDBMetricsRecordingService; + private final Admin adminClient; + private final StreamsMetricsImpl streamsMetrics; + private final long totalCacheSize; + private final StreamStateListener streamStateListener; + private final StateRestoreListener delegatingStateRestoreListener; + private final Map threadState; + private final UUID processId; + private final KafkaClientSupplier clientSupplier; + protected final TopologyMetadata topologyMetadata; + private final QueryableStoreProvider queryableStoreProvider; + + GlobalStreamThread globalStreamThread; + private KafkaStreams.StateListener stateListener; + private StateRestoreListener globalStateRestoreListener; + private boolean oldHandler; + private java.util.function.Consumer streamsUncaughtExceptionHandler; + private final Object changeThreadCount = new Object(); + + // container states + /** + * Kafka Streams states are the possible state that a Kafka Streams instance can be in. + * An instance must only be in one state at a time. + * The expected state transition with the following defined states is: + * + *
                +     *                 +--------------+
                +     *         +<----- | Created (0)  |
                +     *         |       +-----+--------+
                +     *         |             |
                +     *         |             v
                +     *         |       +----+--+------+
                +     *         |       | Re-          |
                +     *         +<----- | Balancing (1)| -------->+
                +     *         |       +-----+-+------+          |
                +     *         |             | ^                 |
                +     *         |             v |                 |
                +     *         |       +--------------+          v
                +     *         |       | Running (2)  | -------->+
                +     *         |       +------+-------+          |
                +     *         |              |                  |
                +     *         |              v                  |
                +     *         |       +------+-------+     +----+-------+
                +     *         +-----> | Pending      |     | Pending    |
                +     *                 | Shutdown (3) |     | Error (5)  |
                +     *                 +------+-------+     +-----+------+
                +     *                        |                   |
                +     *                        v                   v
                +     *                 +------+-------+     +-----+--------+
                +     *                 | Not          |     | Error (6)    |
                +     *                 | Running (4)  |     +--------------+
                +     *                 +--------------+
                +     *
                +     *
                +     * 
                + * Note the following: + * - RUNNING state will transit to REBALANCING if any of its threads is in PARTITION_REVOKED or PARTITIONS_ASSIGNED state + * - REBALANCING state will transit to RUNNING if all of its threads are in RUNNING state + * - Any state except NOT_RUNNING, PENDING_ERROR or ERROR can go to PENDING_SHUTDOWN (whenever close is called) + * - Of special importance: If the global stream thread dies, or all stream threads die (or both) then + * the instance will be in the ERROR state. The user will not need to close it. + */ + public enum State { + // Note: if you add a new state, check the below methods and how they are used within Streams to see if + // any of them should be updated to include the new state. For example a new shutdown path or terminal + // state would likely need to be included in methods like isShuttingDown(), hasCompletedShutdown(), etc. + CREATED(1, 3), // 0 + REBALANCING(2, 3, 5), // 1 + RUNNING(1, 2, 3, 5), // 2 + PENDING_SHUTDOWN(4), // 3 + NOT_RUNNING, // 4 + PENDING_ERROR(6), // 5 + ERROR; // 6 + + private final Set validTransitions = new HashSet<>(); + + State(final Integer... validTransitions) { + this.validTransitions.addAll(Arrays.asList(validTransitions)); + } + + public boolean hasNotStarted() { + return equals(CREATED); + } + + public boolean isRunningOrRebalancing() { + return equals(RUNNING) || equals(REBALANCING); + } + + public boolean isShuttingDown() { + return equals(PENDING_SHUTDOWN) || equals(PENDING_ERROR); + } + + public boolean hasCompletedShutdown() { + return equals(NOT_RUNNING) || equals(ERROR); + } + + public boolean hasStartedOrFinishedShuttingDown() { + return isShuttingDown() || hasCompletedShutdown(); + } + + public boolean isValidTransition(final State newState) { + return validTransitions.contains(newState.ordinal()); + } + } + + private final Object stateLock = new Object(); + protected volatile State state = State.CREATED; + + private boolean waitOnState(final State targetState, final long waitMs) { + final long begin = time.milliseconds(); + synchronized (stateLock) { + boolean interrupted = false; + long elapsedMs = 0L; + try { + while (state != targetState) { + if (waitMs > elapsedMs) { + final long remainingMs = waitMs - elapsedMs; + try { + stateLock.wait(remainingMs); + } catch (final InterruptedException e) { + interrupted = true; + } + } else { + log.debug("Cannot transit to {} within {}ms", targetState, waitMs); + return false; + } + elapsedMs = time.milliseconds() - begin; + } + } finally { + // Make sure to restore the interruption status before returning. + // We do not always own the current thread that executes this method, i.e., we do not know the + // interruption policy of the thread. The least we can do is restore the interruption status before + // the current thread exits this method. 
+ if (interrupted) { + Thread.currentThread().interrupt(); + } + } + return true; + } + } + + /** + * Sets the state + * @param newState New state + */ + private boolean setState(final State newState) { + final State oldState; + + synchronized (stateLock) { + oldState = state; + + if (state == State.PENDING_SHUTDOWN && newState != State.NOT_RUNNING) { + // when the state is already in PENDING_SHUTDOWN, all other transitions than NOT_RUNNING (due to thread dying) will be + // refused but we do not throw exception here, to allow appropriate error handling + return false; + } else if (state == State.NOT_RUNNING && (newState == State.PENDING_SHUTDOWN || newState == State.NOT_RUNNING)) { + // when the state is already in NOT_RUNNING, its transition to PENDING_SHUTDOWN or NOT_RUNNING (due to consecutive close calls) + // will be refused but we do not throw exception here, to allow idempotent close calls + return false; + } else if (state == State.REBALANCING && newState == State.REBALANCING) { + // when the state is already in REBALANCING, it should not transit to REBALANCING again + return false; + } else if (state == State.ERROR && (newState == State.PENDING_ERROR || newState == State.ERROR)) { + // when the state is already in ERROR, its transition to PENDING_ERROR or ERROR (due to consecutive close calls) + return false; + } else if (state == State.PENDING_ERROR && newState != State.ERROR) { + // when the state is already in PENDING_ERROR, all other transitions than ERROR (due to thread dying) will be + // refused but we do not throw exception here, to allow appropriate error handling + return false; + } else if (!state.isValidTransition(newState)) { + throw new IllegalStateException("Stream-client " + clientId + ": Unexpected state transition from " + oldState + " to " + newState); + } else { + log.info("State transition from {} to {}", oldState, newState); + } + state = newState; + stateLock.notifyAll(); + } + + // we need to call the user customized state listener outside the state lock to avoid potential deadlocks + if (stateListener != null) { + stateListener.onChange(newState, oldState); + } + + return true; + } + + /** + * Return the current {@link State} of this {@code KafkaStreams} instance. + * + * @return the current state of this Kafka Streams instance + */ + public State state() { + return state; + } + + protected boolean isRunningOrRebalancing() { + synchronized (stateLock) { + return state.isRunningOrRebalancing(); + } + } + + protected boolean hasStartedOrFinishedShuttingDown() { + synchronized (stateLock) { + return state.hasStartedOrFinishedShuttingDown(); + } + } + + private void validateIsRunningOrRebalancing() { + synchronized (stateLock) { + if (state.hasNotStarted()) { + throw new StreamsNotStartedException("KafkaStreams has not been started, you can retry after calling start()"); + } + if (!state.isRunningOrRebalancing()) { + throw new IllegalStateException("KafkaStreams is not running. State is " + state + "."); + } + } + } + + /** + * Listen to {@link State} change events. + */ + public interface StateListener { + + /** + * Called when state changes. + * + * @param newState new state + * @param oldState previous state + */ + void onChange(final State newState, final State oldState); + } + + /** + * An app can set a single {@link KafkaStreams.StateListener} so that the app is notified when state changes. + * + * @param listener a new state listener + * @throws IllegalStateException if this {@code KafkaStreams} instance has already been started. 
+ */ + public void setStateListener(final KafkaStreams.StateListener listener) { + synchronized (stateLock) { + if (state.hasNotStarted()) { + stateListener = listener; + } else { + throw new IllegalStateException("Can only set StateListener before calling start(). Current state is: " + state); + } + } + } + + /** + * Set the handler invoked when an internal {@link StreamsConfig#NUM_STREAM_THREADS_CONFIG stream thread} abruptly + * terminates due to an uncaught exception. + * + * @param uncaughtExceptionHandler the uncaught exception handler for all internal threads; {@code null} deletes the current handler + * @throws IllegalStateException if this {@code KafkaStreams} instance has already been started. + * + * @deprecated Since 2.8.0. Use {@link KafkaStreams#setUncaughtExceptionHandler(StreamsUncaughtExceptionHandler)} instead. + * + */ + @Deprecated + public void setUncaughtExceptionHandler(final Thread.UncaughtExceptionHandler uncaughtExceptionHandler) { + synchronized (stateLock) { + if (state.hasNotStarted()) { + oldHandler = true; + processStreamThread(thread -> thread.setUncaughtExceptionHandler(uncaughtExceptionHandler)); + + if (globalStreamThread != null) { + globalStreamThread.setUncaughtExceptionHandler(uncaughtExceptionHandler); + } + } else { + throw new IllegalStateException("Can only set UncaughtExceptionHandler before calling start(). " + + "Current state is: " + state); + } + } + } + + /** + * Set the handler invoked when an internal {@link StreamsConfig#NUM_STREAM_THREADS_CONFIG stream thread} + * throws an unexpected exception. + * These might be exceptions indicating rare bugs in Kafka Streams, or they + * might be exceptions thrown by your code, for example a NullPointerException thrown from your processor logic. + * The handler will execute on the thread that produced the exception. + * In order to get the thread that threw the exception, use {@code Thread.currentThread()}. + *
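// --- Editorial aside (illustrative sketch, not part of this patch): registering the StateListener
// described above. It must be set before start(); `topology`, `props` and `log` are hypothetical
// placeholders for the application's own topology, configuration and logger.
final KafkaStreams streams = new KafkaStreams(topology, props);
streams.setStateListener((newState, oldState) -> {
    // e.g. export the current state as a gauge, or page someone when the client dies
    if (newState == KafkaStreams.State.ERROR) {
        log.error("Streams client transitioned from {} to {}", oldState, newState);
    }
});
streams.start();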

                + * Note, this handler must be threadsafe, since it will be shared among all threads, and invoked from any + * thread that encounters such an exception. + * + * @param streamsUncaughtExceptionHandler the uncaught exception handler of type {@link StreamsUncaughtExceptionHandler} for all internal threads + * @throws IllegalStateException if this {@code KafkaStreams} instance has already been started. + * @throws NullPointerException if streamsUncaughtExceptionHandler is null. + */ + public void setUncaughtExceptionHandler(final StreamsUncaughtExceptionHandler streamsUncaughtExceptionHandler) { + final Consumer handler = exception -> handleStreamsUncaughtException(exception, streamsUncaughtExceptionHandler); + synchronized (stateLock) { + if (state.hasNotStarted()) { + this.streamsUncaughtExceptionHandler = handler; + Objects.requireNonNull(streamsUncaughtExceptionHandler); + processStreamThread(thread -> thread.setStreamsUncaughtExceptionHandler(handler)); + if (globalStreamThread != null) { + globalStreamThread.setUncaughtExceptionHandler(handler); + } + } else { + throw new IllegalStateException("Can only set UncaughtExceptionHandler before calling start(). " + + "Current state is: " + state); + } + } + } + + private void defaultStreamsUncaughtExceptionHandler(final Throwable throwable) { + if (oldHandler) { + threads.remove(Thread.currentThread()); + if (throwable instanceof RuntimeException) { + throw (RuntimeException) throwable; + } else if (throwable instanceof Error) { + throw (Error) throwable; + } else { + throw new RuntimeException("Unexpected checked exception caught in the uncaught exception handler", throwable); + } + } else { + handleStreamsUncaughtException(throwable, t -> SHUTDOWN_CLIENT); + } + } + + private void replaceStreamThread(final Throwable throwable) { + if (globalStreamThread != null && Thread.currentThread().getName().equals(globalStreamThread.getName())) { + log.warn("The global thread cannot be replaced. Reverting to shutting down the client."); + log.error("Encountered the following exception during processing " + + " The streams client is going to shut down now. 
", throwable); + closeToError(); + } + final StreamThread deadThread = (StreamThread) Thread.currentThread(); + deadThread.shutdown(); + addStreamThread(); + if (throwable instanceof RuntimeException) { + throw (RuntimeException) throwable; + } else if (throwable instanceof Error) { + throw (Error) throwable; + } else { + throw new RuntimeException("Unexpected checked exception caught in the uncaught exception handler", throwable); + } + } + + private boolean wrappedExceptionIsIn(final Throwable throwable, final Set> exceptionsOfInterest) { + return throwable.getCause() != null && exceptionsOfInterest.contains(throwable.getCause().getClass()); + } + + private StreamsUncaughtExceptionHandler.StreamThreadExceptionResponse getActionForThrowable(final Throwable throwable, + final StreamsUncaughtExceptionHandler streamsUncaughtExceptionHandler) { + final StreamsUncaughtExceptionHandler.StreamThreadExceptionResponse action; + if (wrappedExceptionIsIn(throwable, EXCEPTIONS_NOT_TO_BE_HANDLED_BY_USERS)) { + action = SHUTDOWN_CLIENT; + } else { + action = streamsUncaughtExceptionHandler.handle(throwable); + } + return action; + } + + private void handleStreamsUncaughtException(final Throwable throwable, + final StreamsUncaughtExceptionHandler streamsUncaughtExceptionHandler) { + final StreamsUncaughtExceptionHandler.StreamThreadExceptionResponse action = getActionForThrowable(throwable, streamsUncaughtExceptionHandler); + if (oldHandler) { + log.warn("Stream's new uncaught exception handler is set as well as the deprecated old handler." + + "The old handler will be ignored as long as a new handler is set."); + } + switch (action) { + case REPLACE_THREAD: + log.error("Replacing thread in the streams uncaught exception handler", throwable); + replaceStreamThread(throwable); + break; + case SHUTDOWN_CLIENT: + log.error("Encountered the following exception during processing " + + "and Kafka Streams opted to " + action + "." + + " The streams client is going to shut down now. ", throwable); + closeToError(); + break; + case SHUTDOWN_APPLICATION: + if (getNumLiveStreamThreads() == 1) { + log.warn("Attempt to shut down the application requires adding a thread to communicate the shutdown. No processing will be done on this thread"); + addStreamThread(); + } + if (throwable instanceof Error) { + log.error("This option requires running threads to shut down the application." + + "but the uncaught exception was an Error, which means this runtime is no " + + "longer in a well-defined state. Attempting to send the shutdown command anyway.", throwable); + } + if (Thread.currentThread().equals(globalStreamThread) && getNumLiveStreamThreads() == 0) { + log.error("Exception in global thread caused the application to attempt to shutdown." + + " This action will succeed only if there is at least one StreamThread running on this client." + + " Currently there are no running threads so will now close the client."); + closeToError(); + break; + } + processStreamThread(thread -> thread.sendShutdownRequest(AssignorError.SHUTDOWN_REQUESTED)); + log.error("Encountered the following exception during processing " + + "and sent shutdown request for the entire application.", throwable); + break; + } + } + + /** + * Set the listener which is triggered whenever a {@link StateStore} is being restored in order to resume + * processing. + * + * @param globalStateRestoreListener The listener triggered when {@link StateStore} is being restored. 
+ * @throws IllegalStateException if this {@code KafkaStreams} instance has already been started. + */ + public void setGlobalStateRestoreListener(final StateRestoreListener globalStateRestoreListener) { + synchronized (stateLock) { + if (state.hasNotStarted()) { + this.globalStateRestoreListener = globalStateRestoreListener; + } else { + throw new IllegalStateException("Can only set GlobalStateRestoreListener before calling start(). " + + "Current state is: " + state); + } + } + } + + /** + * Get read-only handle on global metrics registry, including streams client's own metrics plus + * its embedded producer, consumer and admin clients' metrics. + * + * @return Map of all metrics. + */ + public Map metrics() { + final Map result = new LinkedHashMap<>(); + // producer and consumer clients are per-thread + processStreamThread(thread -> { + result.putAll(thread.producerMetrics()); + result.putAll(thread.consumerMetrics()); + // admin client is shared, so we can actually move it + // to result.putAll(adminClient.metrics()). + // we did it intentionally just for flexibility. + result.putAll(thread.adminClientMetrics()); + }); + // global thread's consumer client + if (globalStreamThread != null) { + result.putAll(globalStreamThread.consumerMetrics()); + } + // self streams metrics + result.putAll(metrics.metrics()); + return Collections.unmodifiableMap(result); + } + + /** + * Class that handles stream thread transitions + */ + final class StreamStateListener implements StreamThread.StateListener { + private final Map threadState; + private GlobalStreamThread.State globalThreadState; + // this lock should always be held before the state lock + private final Object threadStatesLock; + + StreamStateListener(final Map threadState, + final GlobalStreamThread.State globalThreadState) { + this.threadState = threadState; + this.globalThreadState = globalThreadState; + this.threadStatesLock = new Object(); + } + + /** + * If all threads are up, including the global thread, set to RUNNING + */ + private void maybeSetRunning() { + // state can be transferred to RUNNING if all threads are either RUNNING or DEAD + for (final StreamThread.State state : threadState.values()) { + if (state != StreamThread.State.RUNNING && state != StreamThread.State.DEAD) { + return; + } + } + + // the global state thread is relevant only if it is started. 
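// --- Editorial aside (illustrative sketch, not part of this patch): wiring the
// setUncaughtExceptionHandler(StreamsUncaughtExceptionHandler) variant shown earlier in this class,
// before start(). The decision logic is hypothetical; the responses are the ones handled by
// handleStreamsUncaughtException(...): REPLACE_THREAD, SHUTDOWN_CLIENT and SHUTDOWN_APPLICATION.
streams.setUncaughtExceptionHandler(exception -> {
    if (exception instanceof StreamsException && exception.getCause() instanceof TimeoutException) {
        // treat a transient broker timeout as recoverable: only the failed thread is replaced
        return StreamsUncaughtExceptionHandler.StreamThreadExceptionResponse.REPLACE_THREAD;
    }
    // anything else: stop this client, leaving other instances of the application running
    return StreamsUncaughtExceptionHandler.StreamThreadExceptionResponse.SHUTDOWN_CLIENT;
});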
There are cases + // when we don't have a global state thread at all, e.g., when we don't have global KTables + if (globalThreadState != null && globalThreadState != GlobalStreamThread.State.RUNNING) { + return; + } + + setState(State.RUNNING); + } + + + @Override + public synchronized void onChange(final Thread thread, + final ThreadStateTransitionValidator abstractNewState, + final ThreadStateTransitionValidator abstractOldState) { + synchronized (threadStatesLock) { + // StreamThreads first + if (thread instanceof StreamThread) { + final StreamThread.State newState = (StreamThread.State) abstractNewState; + threadState.put(thread.getId(), newState); + + if (newState == StreamThread.State.PARTITIONS_REVOKED || newState == StreamThread.State.PARTITIONS_ASSIGNED) { + setState(State.REBALANCING); + } else if (newState == StreamThread.State.RUNNING) { + maybeSetRunning(); + } + } else if (thread instanceof GlobalStreamThread) { + // global stream thread has different invariants + final GlobalStreamThread.State newState = (GlobalStreamThread.State) abstractNewState; + globalThreadState = newState; + + if (newState == GlobalStreamThread.State.RUNNING) { + maybeSetRunning(); + } else if (newState == GlobalStreamThread.State.DEAD) { + log.error("Global thread has died. The streams application or client will now close to ERROR."); + closeToError(); + } + } + } + } + } + + final class DelegatingStateRestoreListener implements StateRestoreListener { + private void throwOnFatalException(final Exception fatalUserException, + final TopicPartition topicPartition, + final String storeName) { + throw new StreamsException( + String.format("Fatal user code error in store restore listener for store %s, partition %s.", + storeName, + topicPartition), + fatalUserException); + } + + @Override + public void onRestoreStart(final TopicPartition topicPartition, + final String storeName, + final long startingOffset, + final long endingOffset) { + if (globalStateRestoreListener != null) { + try { + globalStateRestoreListener.onRestoreStart(topicPartition, storeName, startingOffset, endingOffset); + } catch (final Exception fatalUserException) { + throwOnFatalException(fatalUserException, topicPartition, storeName); + } + } + } + + @Override + public void onBatchRestored(final TopicPartition topicPartition, + final String storeName, + final long batchEndOffset, + final long numRestored) { + if (globalStateRestoreListener != null) { + try { + globalStateRestoreListener.onBatchRestored(topicPartition, storeName, batchEndOffset, numRestored); + } catch (final Exception fatalUserException) { + throwOnFatalException(fatalUserException, topicPartition, storeName); + } + } + } + + @Override + public void onRestoreEnd(final TopicPartition topicPartition, final String storeName, final long totalRestored) { + if (globalStateRestoreListener != null) { + try { + globalStateRestoreListener.onRestoreEnd(topicPartition, storeName, totalRestored); + } catch (final Exception fatalUserException) { + throwOnFatalException(fatalUserException, topicPartition, storeName); + } + } + } + } + + /** + * Create a {@code KafkaStreams} instance. + *
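// --- Editorial aside (illustrative sketch, not part of this patch): a StateRestoreListener that
// could be passed to setGlobalStateRestoreListener(...) above, before start(), to log restoration
// progress. The logger is a hypothetical placeholder; the three callbacks are the ones relayed to
// it by the DelegatingStateRestoreListener defined just above.
streams.setGlobalStateRestoreListener(new StateRestoreListener() {
    @Override
    public void onRestoreStart(final TopicPartition topicPartition, final String storeName,
                               final long startingOffset, final long endingOffset) {
        log.info("Restoring {} for {}: offsets {} to {}", storeName, topicPartition, startingOffset, endingOffset);
    }

    @Override
    public void onBatchRestored(final TopicPartition topicPartition, final String storeName,
                                final long batchEndOffset, final long numRestored) {
        log.debug("Restored a batch of {} records for {}", numRestored, storeName);
    }

    @Override
    public void onRestoreEnd(final TopicPartition topicPartition, final String storeName,
                             final long totalRestored) {
        log.info("Finished restoring {} ({} records)", storeName, totalRestored);
    }
});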

                + * Note: even if you never call {@link #start()} on a {@code KafkaStreams} instance, + * you still must {@link #close()} it to avoid resource leaks. + * + * @param topology the topology specifying the computational logic + * @param props properties for {@link StreamsConfig} + * @throws StreamsException if any fatal error occurs + */ + public KafkaStreams(final Topology topology, + final Properties props) { + this(topology, new StreamsConfig(props), new DefaultKafkaClientSupplier()); + } + + /** + * Create a {@code KafkaStreams} instance. + *
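+ * A minimal lifecycle sketch for illustration; the application id, bootstrap servers, and the
+ * {@code topology} variable below are placeholders:
+ * <pre>{@code
+ * Properties props = new Properties();
+ * props.put(StreamsConfig.APPLICATION_ID_CONFIG, "my-streams-app");       // placeholder id
+ * props.put(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9092");    // placeholder brokers
+ * KafkaStreams streams = new KafkaStreams(topology, props);
+ * streams.start();
+ * // ... and on shutdown, even if start() was never called:
+ * streams.close();
+ * }</pre>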

                + * Note: even if you never call {@link #start()} on a {@code KafkaStreams} instance, + * you still must {@link #close()} it to avoid resource leaks. + * + * @param topology the topology specifying the computational logic + * @param props properties for {@link StreamsConfig} + * @param clientSupplier the Kafka clients supplier which provides underlying producer and consumer clients + * for the new {@code KafkaStreams} instance + * @throws StreamsException if any fatal error occurs + */ + public KafkaStreams(final Topology topology, + final Properties props, + final KafkaClientSupplier clientSupplier) { + this(topology, new StreamsConfig(props), clientSupplier, Time.SYSTEM); + } + + /** + * Create a {@code KafkaStreams} instance. + *

                + * Note: even if you never call {@link #start()} on a {@code KafkaStreams} instance, + * you still must {@link #close()} it to avoid resource leaks. + * + * @param topology the topology specifying the computational logic + * @param props properties for {@link StreamsConfig} + * @param time {@code Time} implementation; cannot be null + * @throws StreamsException if any fatal error occurs + */ + public KafkaStreams(final Topology topology, + final Properties props, + final Time time) { + this(topology, new StreamsConfig(props), new DefaultKafkaClientSupplier(), time); + } + + /** + * Create a {@code KafkaStreams} instance. + *

                + * Note: even if you never call {@link #start()} on a {@code KafkaStreams} instance, + * you still must {@link #close()} it to avoid resource leaks. + * + * @param topology the topology specifying the computational logic + * @param props properties for {@link StreamsConfig} + * @param clientSupplier the Kafka clients supplier which provides underlying producer and consumer clients + * for the new {@code KafkaStreams} instance + * @param time {@code Time} implementation; cannot be null + * @throws StreamsException if any fatal error occurs + */ + public KafkaStreams(final Topology topology, + final Properties props, + final KafkaClientSupplier clientSupplier, + final Time time) { + this(topology, new StreamsConfig(props), clientSupplier, time); + } + + /** + * Create a {@code KafkaStreams} instance. + *

                + * Note: even if you never call {@link #start()} on a {@code KafkaStreams} instance, + * you still must {@link #close()} it to avoid resource leaks. + * + * @param topology the topology specifying the computational logic + * @param config configs for Kafka Streams + * @throws StreamsException if any fatal error occurs + */ + public KafkaStreams(final Topology topology, + final StreamsConfig config) { + this(topology, config, new DefaultKafkaClientSupplier()); + } + + /** + * Create a {@code KafkaStreams} instance. + *

                + * Note: even if you never call {@link #start()} on a {@code KafkaStreams} instance, + * you still must {@link #close()} it to avoid resource leaks. + * + * @param topology the topology specifying the computational logic + * @param config configs for Kafka Streams + * @param clientSupplier the Kafka clients supplier which provides underlying producer and consumer clients + * for the new {@code KafkaStreams} instance + * @throws StreamsException if any fatal error occurs + */ + public KafkaStreams(final Topology topology, + final StreamsConfig config, + final KafkaClientSupplier clientSupplier) { + this(new TopologyMetadata(topology.internalTopologyBuilder, config), config, clientSupplier); + } + + /** + * Create a {@code KafkaStreams} instance. + *

                + * Note: even if you never call {@link #start()} on a {@code KafkaStreams} instance, + * you still must {@link #close()} it to avoid resource leaks. + * + * @param topology the topology specifying the computational logic + * @param config configs for Kafka Streams + * @param time {@code Time} implementation; cannot be null + * @throws StreamsException if any fatal error occurs + */ + public KafkaStreams(final Topology topology, + final StreamsConfig config, + final Time time) { + this(new TopologyMetadata(topology.internalTopologyBuilder, config), config, new DefaultKafkaClientSupplier(), time); + } + + private KafkaStreams(final Topology topology, + final StreamsConfig config, + final KafkaClientSupplier clientSupplier, + final Time time) throws StreamsException { + this(new TopologyMetadata(topology.internalTopologyBuilder, config), config, clientSupplier, time); + } + + protected KafkaStreams(final TopologyMetadata topologyMetadata, + final StreamsConfig config, + final KafkaClientSupplier clientSupplier) throws StreamsException { + this(topologyMetadata, config, clientSupplier, Time.SYSTEM); + } + + private KafkaStreams(final TopologyMetadata topologyMetadata, + final StreamsConfig config, + final KafkaClientSupplier clientSupplier, + final Time time) throws StreamsException { + this.config = config; + this.time = time; + + this.topologyMetadata = topologyMetadata; + this.topologyMetadata.buildAndRewriteTopology(); + + final boolean hasGlobalTopology = topologyMetadata.hasGlobalTopology(); + + try { + stateDirectory = new StateDirectory(config, time, topologyMetadata.hasPersistentStores(), topologyMetadata.hasNamedTopologies()); + processId = stateDirectory.initializeProcessId(); + } catch (final ProcessorStateException fatal) { + throw new StreamsException(fatal); + } + + // The application ID is a required config and hence should always have value + final String userClientId = config.getString(StreamsConfig.CLIENT_ID_CONFIG); + final String applicationId = config.getString(StreamsConfig.APPLICATION_ID_CONFIG); + if (userClientId.length() <= 0) { + clientId = applicationId + "-" + processId; + } else { + clientId = userClientId; + } + final LogContext logContext = new LogContext(String.format("stream-client [%s] ", clientId)); + this.log = logContext.logger(getClass()); + + // use client id instead of thread client id since this admin client may be shared among threads + this.clientSupplier = clientSupplier; + adminClient = clientSupplier.getAdmin(config.getAdminConfigs(ClientUtils.getSharedAdminClientId(clientId))); + + log.info("Kafka Streams version: {}", ClientMetrics.version()); + log.info("Kafka Streams commit ID: {}", ClientMetrics.commitId()); + + metrics = getMetrics(config, time, clientId); + streamsMetrics = new StreamsMetricsImpl( + metrics, + clientId, + config.getString(StreamsConfig.BUILT_IN_METRICS_VERSION_CONFIG), + time + ); + + ClientMetrics.addVersionMetric(streamsMetrics); + ClientMetrics.addCommitIdMetric(streamsMetrics); + ClientMetrics.addApplicationIdMetric(streamsMetrics, config.getString(StreamsConfig.APPLICATION_ID_CONFIG)); + ClientMetrics.addTopologyDescriptionMetric(streamsMetrics, (metricsConfig, now) -> this.topologyMetadata.topologyDescriptionString()); + ClientMetrics.addStateMetric(streamsMetrics, (metricsConfig, now) -> state); + threads = Collections.synchronizedList(new LinkedList<>()); + ClientMetrics.addNumAliveStreamThreadMetric(streamsMetrics, (metricsConfig, now) -> getNumLiveStreamThreads()); + + streamsMetadataState = new 
StreamsMetadataState( + this.topologyMetadata, + parseHostInfo(config.getString(StreamsConfig.APPLICATION_SERVER_CONFIG))); + + oldHandler = false; + streamsUncaughtExceptionHandler = this::defaultStreamsUncaughtExceptionHandler; + delegatingStateRestoreListener = new DelegatingStateRestoreListener(); + + totalCacheSize = config.getLong(StreamsConfig.CACHE_MAX_BYTES_BUFFERING_CONFIG); + final int numStreamThreads = topologyMetadata.getNumStreamThreads(config); + final long cacheSizePerThread = getCacheSizePerThread(numStreamThreads); + + GlobalStreamThread.State globalThreadState = null; + if (hasGlobalTopology) { + final String globalThreadId = clientId + "-GlobalStreamThread"; + globalStreamThread = new GlobalStreamThread( + topologyMetadata.globalTaskTopology(), + config, + clientSupplier.getGlobalConsumer(config.getGlobalConsumerConfigs(clientId)), + stateDirectory, + cacheSizePerThread, + streamsMetrics, + time, + globalThreadId, + delegatingStateRestoreListener, + streamsUncaughtExceptionHandler + ); + globalThreadState = globalStreamThread.state(); + } + + threadState = new HashMap<>(numStreamThreads); + streamStateListener = new StreamStateListener(threadState, globalThreadState); + + final GlobalStateStoreProvider globalStateStoreProvider = new GlobalStateStoreProvider(this.topologyMetadata.globalStateStores()); + + if (hasGlobalTopology) { + globalStreamThread.setStateListener(streamStateListener); + } + + queryableStoreProvider = new QueryableStoreProvider(globalStateStoreProvider); + for (int i = 1; i <= numStreamThreads; i++) { + createAndAddStreamThread(cacheSizePerThread, i); + } + + stateDirCleaner = setupStateDirCleaner(); + rocksDBMetricsRecordingService = maybeCreateRocksDBMetricsRecordingService(clientId, config); + } + + private StreamThread createAndAddStreamThread(final long cacheSizePerThread, final int threadIdx) { + final StreamThread streamThread = StreamThread.create( + topologyMetadata, + config, + clientSupplier, + adminClient, + processId, + clientId, + streamsMetrics, + time, + streamsMetadataState, + cacheSizePerThread, + stateDirectory, + delegatingStateRestoreListener, + threadIdx, + KafkaStreams.this::closeToError, + streamsUncaughtExceptionHandler + ); + streamThread.setStateListener(streamStateListener); + threads.add(streamThread); + threadState.put(streamThread.getId(), streamThread.state()); + queryableStoreProvider.addStoreProviderForThread(streamThread.getName(), new StreamThreadStateStoreProvider(streamThread)); + return streamThread; + } + + private static Metrics getMetrics(final StreamsConfig config, final Time time, final String clientId) { + final MetricConfig metricConfig = new MetricConfig() + .samples(config.getInt(StreamsConfig.METRICS_NUM_SAMPLES_CONFIG)) + .recordLevel(Sensor.RecordingLevel.forName(config.getString(StreamsConfig.METRICS_RECORDING_LEVEL_CONFIG))) + .timeWindow(config.getLong(StreamsConfig.METRICS_SAMPLE_WINDOW_MS_CONFIG), TimeUnit.MILLISECONDS); + final List reporters = config.getConfiguredInstances(StreamsConfig.METRIC_REPORTER_CLASSES_CONFIG, + MetricsReporter.class, + Collections.singletonMap(StreamsConfig.CLIENT_ID_CONFIG, clientId)); + final JmxReporter jmxReporter = new JmxReporter(); + jmxReporter.configure(config.originals()); + reporters.add(jmxReporter); + final MetricsContext metricsContext = new KafkaMetricsContext(JMX_PREFIX, + config.originalsWithPrefix(CommonClientConfigs.METRICS_CONTEXT_PREFIX)); + return new Metrics(metricConfig, reporters, time, metricsContext); + } + + /** + * Adds and starts a stream 
thread in addition to the stream threads that are already running in this + * Kafka Streams client. + *

                + * Since the number of stream threads increases, the sizes of the caches in the new stream thread + * and the existing stream threads are adapted so that the sum of the cache sizes over all stream + * threads does not exceed the total cache size specified in configuration + * {@link StreamsConfig#CACHE_MAX_BYTES_BUFFERING_CONFIG}. + *
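+ * For illustration, a sketch of scaling up a client that is assumed to be running as {@code streams}:
+ * <pre>{@code
+ * Optional<String> added = streams.addStreamThread();
+ * if (added.isPresent()) {
+ *     System.out.println("Added " + added.get());
+ * } else {
+ *     System.out.println("Client not in RUNNING or REBALANCING; no thread was added");
+ * }
+ * }</pre>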

                + * Stream threads can only be added if this Kafka Streams client is in state RUNNING or REBALANCING. + * + * @return name of the added stream thread or empty if a new stream thread could not be added + */ + public Optional addStreamThread() { + if (isRunningOrRebalancing()) { + final StreamThread streamThread; + synchronized (changeThreadCount) { + final int threadIdx = getNextThreadIndex(); + final int numLiveThreads = getNumLiveStreamThreads(); + final long cacheSizePerThread = getCacheSizePerThread(numLiveThreads + 1); + log.info("Adding StreamThread-{}, there will now be {} live threads and the new cache size per thread is {}", + threadIdx, numLiveThreads + 1, cacheSizePerThread); + resizeThreadCache(cacheSizePerThread); + // Creating thread should hold the lock in order to avoid duplicate thread index. + // If the duplicate index happen, the metadata of thread may be duplicate too. + streamThread = createAndAddStreamThread(cacheSizePerThread, threadIdx); + } + + synchronized (stateLock) { + if (isRunningOrRebalancing()) { + streamThread.start(); + return Optional.of(streamThread.getName()); + } else { + log.warn("Terminating the new thread because the Kafka Streams client is in state {}", state); + streamThread.shutdown(); + threads.remove(streamThread); + final long cacheSizePerThread = getCacheSizePerThread(getNumLiveStreamThreads()); + log.info("Resizing thread cache due to terminating added thread, new cache size per thread is {}", cacheSizePerThread); + resizeThreadCache(cacheSizePerThread); + return Optional.empty(); + } + } + } else { + log.warn("Cannot add a stream thread when Kafka Streams client is in state {}", state); + return Optional.empty(); + } + } + + /** + * Removes one stream thread out of the running stream threads from this Kafka Streams client. + *

                + * The removed stream thread is gracefully shut down. This method does not specify which stream + * thread is shut down. + *

+ * Since the number of stream threads decreases, the sizes of the caches in the remaining stream
+ * threads are adapted so that the sum of the cache sizes over all stream threads equals the total
+ * cache size specified in configuration {@link StreamsConfig#CACHE_MAX_BYTES_BUFFERING_CONFIG}.
+ *
+ * @return name of the removed stream thread or empty if a stream thread could not be removed because
+ *         no stream threads are alive
+ */
+ public Optional<String> removeStreamThread() {
+     return removeStreamThread(Long.MAX_VALUE);
+ }
+
+ /**
+ * Removes one stream thread out of the running stream threads from this Kafka Streams client.
+ * <p>

                + * The removed stream thread is gracefully shut down. This method does not specify which stream + * thread is shut down. + *
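+ * For illustration, a sketch of scaling down with a bounded wait; {@code streams} is assumed to be
+ * a running instance:
+ * <pre>{@code
+ * try {
+ *     Optional<String> removed = streams.removeStreamThread(Duration.ofSeconds(30));
+ *     removed.ifPresent(name -> System.out.println("Removed " + name));
+ * } catch (org.apache.kafka.common.errors.TimeoutException e) {
+ *     // the selected thread did not stop within 30 seconds
+ * }
+ * }</pre>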

                + * Since the number of stream threads decreases, the sizes of the caches in the remaining stream + * threads are adapted so that the sum of the cache sizes over all stream threads equals the total + * cache size specified in configuration {@link StreamsConfig#CACHE_MAX_BYTES_BUFFERING_CONFIG}. + * + * @param timeout The length of time to wait for the thread to shutdown + * @throws org.apache.kafka.common.errors.TimeoutException if the thread does not stop in time + * @return name of the removed stream thread or empty if a stream thread could not be removed because + * no stream threads are alive + */ + public Optional removeStreamThread(final Duration timeout) { + final String msgPrefix = prepareMillisCheckFailMsgPrefix(timeout, "timeout"); + final long timeoutMs = validateMillisecondDuration(timeout, msgPrefix); + return removeStreamThread(timeoutMs); + } + + private Optional removeStreamThread(final long timeoutMs) throws TimeoutException { + final long startMs = time.milliseconds(); + + if (isRunningOrRebalancing()) { + synchronized (changeThreadCount) { + // make a copy of threads to avoid holding lock + for (final StreamThread streamThread : new ArrayList<>(threads)) { + final boolean callingThreadIsNotCurrentStreamThread = !streamThread.getName().equals(Thread.currentThread().getName()); + if (streamThread.isAlive() && (callingThreadIsNotCurrentStreamThread || getNumLiveStreamThreads() == 1)) { + log.info("Removing StreamThread " + streamThread.getName()); + final Optional groupInstanceID = streamThread.getGroupInstanceID(); + streamThread.requestLeaveGroupDuringShutdown(); + streamThread.shutdown(); + if (!streamThread.getName().equals(Thread.currentThread().getName())) { + final long remainingTimeMs = timeoutMs - (time.milliseconds() - startMs); + if (remainingTimeMs <= 0 || !streamThread.waitOnThreadState(StreamThread.State.DEAD, remainingTimeMs)) { + log.warn("{} did not shutdown in the allotted time.", streamThread.getName()); + // Don't remove from threads until shutdown is complete. 
We will trim it from the + // list once it reaches DEAD, and if for some reason it's hanging indefinitely in the + // shutdown then we should just consider this thread.id to be burned + } else { + log.info("Successfully removed {} in {}ms", streamThread.getName(), time.milliseconds() - startMs); + threads.remove(streamThread); + queryableStoreProvider.removeStoreProviderForThread(streamThread.getName()); + } + } else { + log.info("{} is the last remaining thread and must remove itself, therefore we cannot wait " + + "for it to complete shutdown as this will result in deadlock.", streamThread.getName()); + } + + final long cacheSizePerThread = getCacheSizePerThread(getNumLiveStreamThreads()); + log.info("Resizing thread cache due to thread removal, new cache size per thread is {}", cacheSizePerThread); + resizeThreadCache(cacheSizePerThread); + if (groupInstanceID.isPresent() && callingThreadIsNotCurrentStreamThread) { + final MemberToRemove memberToRemove = new MemberToRemove(groupInstanceID.get()); + final Collection membersToRemove = Collections.singletonList(memberToRemove); + final RemoveMembersFromConsumerGroupResult removeMembersFromConsumerGroupResult = + adminClient.removeMembersFromConsumerGroup( + config.getString(StreamsConfig.APPLICATION_ID_CONFIG), + new RemoveMembersFromConsumerGroupOptions(membersToRemove) + ); + try { + final long remainingTimeMs = timeoutMs - (time.milliseconds() - startMs); + removeMembersFromConsumerGroupResult.memberResult(memberToRemove).get(remainingTimeMs, TimeUnit.MILLISECONDS); + } catch (final java.util.concurrent.TimeoutException e) { + log.error("Could not remove static member {} from consumer group {} due to a timeout: {}", + groupInstanceID.get(), config.getString(StreamsConfig.APPLICATION_ID_CONFIG), e); + throw new TimeoutException(e.getMessage(), e); + } catch (final InterruptedException e) { + Thread.currentThread().interrupt(); + } catch (final ExecutionException e) { + log.error("Could not remove static member {} from consumer group {} due to: {}", + groupInstanceID.get(), config.getString(StreamsConfig.APPLICATION_ID_CONFIG), e); + throw new StreamsException( + "Could not remove static member " + groupInstanceID.get() + + " from consumer group " + config.getString(StreamsConfig.APPLICATION_ID_CONFIG) + + " for the following reason: ", + e.getCause() + ); + } + } + final long remainingTimeMs = timeoutMs - (time.milliseconds() - startMs); + if (remainingTimeMs <= 0) { + throw new TimeoutException("Thread " + streamThread.getName() + " did not stop in the allotted time"); + } + return Optional.of(streamThread.getName()); + } + } + } + log.warn("There are no threads eligible for removal"); + } else { + log.warn("Cannot remove a stream thread when Kafka Streams client is in state " + state()); + } + return Optional.empty(); + } + + /** + * Takes a snapshot and counts the number of stream threads which are not in PENDING_SHUTDOWN or DEAD + * + * note: iteration over SynchronizedList is not thread safe so it must be manually synchronized. However, we may + * require other locks when looping threads and it could cause deadlock. Hence, we create a copy to avoid holding + * threads lock when looping threads. 
+ * @return number of alive stream threads + */ + private int getNumLiveStreamThreads() { + final AtomicInteger numLiveThreads = new AtomicInteger(0); + + synchronized (threads) { + processStreamThread(thread -> { + if (thread.state() == StreamThread.State.DEAD) { + log.debug("Trimming thread {} from the threads list since it's state is {}", thread.getName(), StreamThread.State.DEAD); + threads.remove(thread); + } else if (thread.state() == StreamThread.State.PENDING_SHUTDOWN) { + log.debug("Skipping thread {} from num live threads computation since it's state is {}", + thread.getName(), StreamThread.State.PENDING_SHUTDOWN); + } else { + numLiveThreads.incrementAndGet(); + } + }); + return numLiveThreads.get(); + } + } + + private int getNextThreadIndex() { + final HashSet allLiveThreadNames = new HashSet<>(); + final AtomicInteger maxThreadId = new AtomicInteger(1); + synchronized (threads) { + processStreamThread(thread -> { + // trim any DEAD threads from the list so we can reuse the thread.id + // this is only safe to do once the thread has fully completed shutdown + if (thread.state() == StreamThread.State.DEAD) { + threads.remove(thread); + } else { + allLiveThreadNames.add(thread.getName()); + // Assume threads are always named with the "-StreamThread-" suffix + final int threadId = Integer.parseInt(thread.getName().substring(thread.getName().lastIndexOf("-") + 1)); + if (threadId > maxThreadId.get()) { + maxThreadId.set(threadId); + } + } + }); + + final String baseName = clientId + "-StreamThread-"; + for (int i = 1; i <= maxThreadId.get(); i++) { + final String name = baseName + i; + if (!allLiveThreadNames.contains(name)) { + return i; + } + } + // It's safe to use threads.size() rather than getNumLiveStreamThreads() to infer the number of threads + // here since we trimmed any DEAD threads earlier in this method while holding the lock + return threads.size() + 1; + } + } + + private long getCacheSizePerThread(final int numStreamThreads) { + if (numStreamThreads == 0) { + return totalCacheSize; + } + return totalCacheSize / (numStreamThreads + (topologyMetadata.hasGlobalTopology() ? 1 : 0)); + } + + private void resizeThreadCache(final long cacheSizePerThread) { + processStreamThread(thread -> thread.resizeCache(cacheSizePerThread)); + if (globalStreamThread != null) { + globalStreamThread.resize(cacheSizePerThread); + } + } + + private ScheduledExecutorService setupStateDirCleaner() { + return Executors.newSingleThreadScheduledExecutor(r -> { + final Thread thread = new Thread(r, clientId + "-CleanupThread"); + thread.setDaemon(true); + return thread; + }); + } + + private static ScheduledExecutorService maybeCreateRocksDBMetricsRecordingService(final String clientId, + final StreamsConfig config) { + if (RecordingLevel.forName(config.getString(METRICS_RECORDING_LEVEL_CONFIG)) == RecordingLevel.DEBUG) { + return Executors.newSingleThreadScheduledExecutor(r -> { + final Thread thread = new Thread(r, clientId + "-RocksDBMetricsRecordingTrigger"); + thread.setDaemon(true); + return thread; + }); + } + return null; + } + + private static HostInfo parseHostInfo(final String endPoint) { + final HostInfo hostInfo = HostInfo.buildFromEndpoint(endPoint); + if (hostInfo == null) { + return StreamsMetadataState.UNKNOWN_HOST; + } else { + return hostInfo; + } + } + + /** + * Start the {@code KafkaStreams} instance by starting all its threads. + * This function is expected to be called only once during the life cycle of the client. + *

                + * Because threads are started in the background, this method does not block. + * However, if you have global stores in your topology, this method blocks until all global stores are restored. + * As a consequence, any fatal exception that happens during processing is by default only logged. + * If you want to be notified about dying threads, you can + * {@link #setUncaughtExceptionHandler(Thread.UncaughtExceptionHandler) register an uncaught exception handler} + * before starting the {@code KafkaStreams} instance. + *
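+ * For illustration, a sketch using the handler variant referenced above; {@code streams} is assumed
+ * to be a not-yet-started instance:
+ * <pre>{@code
+ * streams.setUncaughtExceptionHandler((thread, throwable) ->
+ *     System.err.println("Stream thread " + thread.getName() + " died: " + throwable));
+ * streams.start();
+ * }</pre>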

                + * Note, for brokers with version {@code 0.9.x} or lower, the broker version cannot be checked. + * There will be no error and the client will hang and retry to verify the broker version until it + * {@link StreamsConfig#REQUEST_TIMEOUT_MS_CONFIG times out}. + + * @throws IllegalStateException if process was already started + * @throws StreamsException if the Kafka brokers have version 0.10.0.x or + * if {@link StreamsConfig#PROCESSING_GUARANTEE_CONFIG exactly-once} is enabled for pre 0.11.0.x brokers + */ + public synchronized void start() throws IllegalStateException, StreamsException { + if (setState(State.REBALANCING)) { + log.debug("Starting Streams client"); + + if (globalStreamThread != null) { + globalStreamThread.start(); + } + + processStreamThread(StreamThread::start); + + final Long cleanupDelay = config.getLong(StreamsConfig.STATE_CLEANUP_DELAY_MS_CONFIG); + stateDirCleaner.scheduleAtFixedRate(() -> { + // we do not use lock here since we only read on the value and act on it + if (state == State.RUNNING) { + stateDirectory.cleanRemovedTasks(cleanupDelay); + } + }, cleanupDelay, cleanupDelay, TimeUnit.MILLISECONDS); + + final long recordingDelay = 0; + final long recordingInterval = 1; + if (rocksDBMetricsRecordingService != null) { + rocksDBMetricsRecordingService.scheduleAtFixedRate( + streamsMetrics.rocksDBMetricsRecordingTrigger(), + recordingDelay, + recordingInterval, + TimeUnit.MINUTES + ); + } + } else { + throw new IllegalStateException("The client is either already started or already stopped, cannot re-start"); + } + } + + /** + * Shutdown this {@code KafkaStreams} instance by signaling all the threads to stop, and then wait for them to join. + * This will block until all threads have stopped. + */ + public void close() { + close(Long.MAX_VALUE); + } + + private Thread shutdownHelper(final boolean error) { + stateDirCleaner.shutdownNow(); + if (rocksDBMetricsRecordingService != null) { + rocksDBMetricsRecordingService.shutdownNow(); + } + + // wait for all threads to join in a separate thread; + // save the current thread so that if it is a stream thread + // we don't attempt to join it and cause a deadlock + return new Thread(() -> { + // notify all the threads to stop; avoid deadlocks by stopping any + // further state reports from the thread since we're shutting down + processStreamThread(StreamThread::shutdown); + topologyMetadata.wakeupThreads(); + + processStreamThread(thread -> { + try { + if (!thread.isRunning()) { + thread.join(); + } + } catch (final InterruptedException ex) { + Thread.currentThread().interrupt(); + } + }); + + if (globalStreamThread != null) { + globalStreamThread.shutdown(); + } + + if (globalStreamThread != null && !globalStreamThread.stillRunning()) { + try { + globalStreamThread.join(); + } catch (final InterruptedException e) { + Thread.currentThread().interrupt(); + } + globalStreamThread = null; + } + + stateDirectory.close(); + adminClient.close(); + + streamsMetrics.removeAllClientLevelSensorsAndMetrics(); + metrics.close(); + if (!error) { + setState(State.NOT_RUNNING); + } else { + setState(State.ERROR); + } + }, "kafka-streams-close-thread"); + } + + private boolean close(final long timeoutMs) { + if (state.hasCompletedShutdown()) { + log.info("Streams client is already in the terminal {} state, all resources are closed and the client has stopped.", state); + return true; + } + if (state.isShuttingDown()) { + log.info("Streams client is in {}, all resources are being closed and the client will be stopped.", 
state); + if (state == State.PENDING_ERROR && waitOnState(State.ERROR, timeoutMs)) { + log.info("Streams client stopped to ERROR completely"); + return true; + } else if (state == State.PENDING_SHUTDOWN && waitOnState(State.NOT_RUNNING, timeoutMs)) { + log.info("Streams client stopped to NOT_RUNNING completely"); + return true; + } else { + log.warn("Streams client cannot transition to {}} completely within the timeout", + state == State.PENDING_SHUTDOWN ? State.NOT_RUNNING : State.ERROR); + return false; + } + } + + if (!setState(State.PENDING_SHUTDOWN)) { + // if we can't transition to PENDING_SHUTDOWN but not because we're already shutting down, then it must be fatal + log.error("Failed to transition to PENDING_SHUTDOWN, current state is {}", state); + throw new StreamsException("Failed to shut down while in state " + state); + } else { + final Thread shutdownThread = shutdownHelper(false); + + shutdownThread.setDaemon(true); + shutdownThread.start(); + } + + if (waitOnState(State.NOT_RUNNING, timeoutMs)) { + log.info("Streams client stopped completely"); + return true; + } else { + log.info("Streams client cannot stop completely within the timeout"); + return false; + } + } + + private void closeToError() { + if (!setState(State.PENDING_ERROR)) { + log.info("Skipping shutdown since we are already in " + state()); + } else { + final Thread shutdownThread = shutdownHelper(true); + + shutdownThread.setDaemon(true); + shutdownThread.start(); + } + } + + /** + * Shutdown this {@code KafkaStreams} by signaling all the threads to stop, and then wait up to the timeout for the + * threads to join. + * A {@code timeout} of Duration.ZERO (or any other zero duration) makes the close operation asynchronous. + * Negative-duration timeouts are rejected. + * + * @param timeout how long to wait for the threads to shutdown + * @return {@code true} if all threads were successfully stopped—{@code false} if the timeout was reached + * before all threads stopped + * Note that this method must not be called in the {@link StateListener#onChange(KafkaStreams.State, KafkaStreams.State)} callback of {@link StateListener}. + * @throws IllegalArgumentException if {@code timeout} can't be represented as {@code long milliseconds} + */ + public synchronized boolean close(final Duration timeout) throws IllegalArgumentException { + final String msgPrefix = prepareMillisCheckFailMsgPrefix(timeout, "timeout"); + final long timeoutMs = validateMillisecondDuration(timeout, msgPrefix); + if (timeoutMs < 0) { + throw new IllegalArgumentException("Timeout can't be negative."); + } + + log.debug("Stopping Streams client with timeoutMillis = {} ms.", timeoutMs); + + return close(timeoutMs); + } + + /** + * Do a clean up of the local {@link StateStore} directory ({@link StreamsConfig#STATE_DIR_CONFIG}) by deleting all + * data with regard to the {@link StreamsConfig#APPLICATION_ID_CONFIG application ID}. + *

                + * May only be called either before this {@code KafkaStreams} instance is {@link #start() started} or after the + * instance is {@link #close() closed}. + *
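+ * For illustration, a sketch of resetting local state before starting; {@code topology} and
+ * {@code props} are placeholders:
+ * <pre>{@code
+ * KafkaStreams streams = new KafkaStreams(topology, props);
+ * streams.cleanUp();   // allowed: the instance has not been started yet
+ * streams.start();
+ * }</pre>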

                + * Calling this method triggers a restore of local {@link StateStore}s on the next {@link #start() application start}. + * + * @throws IllegalStateException if this {@code KafkaStreams} instance has been started and hasn't fully shut down + * @throws StreamsException if cleanup failed + */ + public void cleanUp() { + if (!(state.hasNotStarted() || state.hasCompletedShutdown())) { + throw new IllegalStateException("Cannot clean up while running."); + } + stateDirectory.clean(); + } + + /** + * Find all currently running {@code KafkaStreams} instances (potentially remotely) that use the same + * {@link StreamsConfig#APPLICATION_ID_CONFIG application ID} as this instance (i.e., all instances that belong to + * the same Kafka Streams application) and return {@link StreamsMetadata} for each discovered instance. + *

+ * Note: this is a point in time view and it may change due to partition reassignment.
+ *
+ * @return {@link StreamsMetadata} for each {@code KafkaStreams} instance of this application
+ * @deprecated since 3.0.0 use {@link KafkaStreams#metadataForAllStreamsClients}
+ */
+ @Deprecated
+ public Collection<org.apache.kafka.streams.state.StreamsMetadata> allMetadata() {
+     validateIsRunningOrRebalancing();
+     return streamsMetadataState.getAllMetadata().stream().map(streamsMetadata ->
+         new org.apache.kafka.streams.state.StreamsMetadata(streamsMetadata.hostInfo(),
+             streamsMetadata.stateStoreNames(),
+             streamsMetadata.topicPartitions(),
+             streamsMetadata.standbyStateStoreNames(),
+             streamsMetadata.standbyTopicPartitions()))
+         .collect(Collectors.toSet());
+ }
+
+ /**
+ * Find all currently running {@code KafkaStreams} instances (potentially remotely) that use the same
+ * {@link StreamsConfig#APPLICATION_ID_CONFIG application ID} as this instance (i.e., all instances that belong to
+ * the same Kafka Streams application) and return {@link StreamsMetadata} for each discovered instance.
+ * <p>
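+ * For illustration, a sketch of listing the discovered instances; {@code streams} is assumed to be
+ * a running instance:
+ * <pre>{@code
+ * for (StreamsMetadata metadata : streams.metadataForAllStreamsClients()) {
+ *     System.out.println(metadata.host() + ":" + metadata.port()
+ *         + " hosts " + metadata.stateStoreNames());
+ * }
+ * }</pre>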

+ * Note: this is a point in time view and it may change due to partition reassignment.
+ *
+ * @return {@link StreamsMetadata} for each {@code KafkaStreams} instance of this application
+ */
+ public Collection<StreamsMetadata> metadataForAllStreamsClients() {
+     validateIsRunningOrRebalancing();
+     return streamsMetadataState.getAllMetadata();
+ }
+
+ /**
+ * Find all currently running {@code KafkaStreams} instances (potentially remotely) that

+ * <ul>
+ *   <li>use the same {@link StreamsConfig#APPLICATION_ID_CONFIG application ID} as this instance (i.e., all
+ *       instances that belong to the same Kafka Streams application)</li>
+ *   <li>and that contain a {@link StateStore} with the given {@code storeName}</li>
+ * </ul>
                + * and return {@link StreamsMetadata} for each discovered instance. + *

                + * Note: this is a point in time view and it may change due to partition reassignment. + * + * @param storeName the {@code storeName} to find metadata for + * @return {@link StreamsMetadata} for each {@code KafkaStreams} instances with the provide {@code storeName} of + * this application + * @deprecated since 3.0.0 use {@link KafkaStreams#streamsMetadataForStore} instead + */ + @Deprecated + public Collection allMetadataForStore(final String storeName) { + validateIsRunningOrRebalancing(); + return streamsMetadataState.getAllMetadataForStore(storeName).stream().map(streamsMetadata -> + new org.apache.kafka.streams.state.StreamsMetadata(streamsMetadata.hostInfo(), + streamsMetadata.stateStoreNames(), + streamsMetadata.topicPartitions(), + streamsMetadata.standbyStateStoreNames(), + streamsMetadata.standbyTopicPartitions())) + .collect(Collectors.toSet()); + } + + /** + * Find all currently running {@code KafkaStreams} instances (potentially remotely) that + *

+ * <ul>
+ *   <li>use the same {@link StreamsConfig#APPLICATION_ID_CONFIG application ID} as this instance (i.e., all
+ *       instances that belong to the same Kafka Streams application)</li>
+ *   <li>and that contain a {@link StateStore} with the given {@code storeName}</li>
+ * </ul>
                + * and return {@link StreamsMetadata} for each discovered instance. + *
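+ * For illustration, a sketch of locating and then querying a store locally; the store name
+ * {@code "word-counts"}, the key, and the running {@code streams} instance are placeholders:
+ * <pre>{@code
+ * Collection<StreamsMetadata> hosts = streams.streamsMetadataForStore("word-counts");
+ * // query the partitions of that store hosted by this instance
+ * ReadOnlyKeyValueStore<String, Long> store =
+ *     streams.store(StoreQueryParameters.fromNameAndType("word-counts", QueryableStoreTypes.keyValueStore()));
+ * Long count = store.get("hello");
+ * }</pre>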

                + * Note: this is a point in time view and it may change due to partition reassignment. + * + * @param storeName the {@code storeName} to find metadata for + * @return {@link StreamsMetadata} for each {@code KafkaStreams} instances with the provide {@code storeName} of + * this application + */ + public Collection streamsMetadataForStore(final String storeName) { + validateIsRunningOrRebalancing(); + return streamsMetadataState.getAllMetadataForStore(storeName); + } + + /** + * Finds the metadata containing the active hosts and standby hosts where the key being queried would reside. + * + * @param storeName the {@code storeName} to find metadata for + * @param key the key to find metadata for + * @param keySerializer serializer for the key + * @param key type + * Returns {@link KeyQueryMetadata} containing all metadata about hosting the given key for the given store, + * or {@code null} if no matching metadata could be found. + */ + public KeyQueryMetadata queryMetadataForKey(final String storeName, + final K key, + final Serializer keySerializer) { + validateIsRunningOrRebalancing(); + return streamsMetadataState.getKeyQueryMetadataForKey(storeName, key, keySerializer); + } + + /** + * Finds the metadata containing the active hosts and standby hosts where the key being queried would reside. + * + * @param storeName the {@code storeName} to find metadata for + * @param key the key to find metadata for + * @param partitioner the partitioner to be use to locate the host for the key + * @param key type + * Returns {@link KeyQueryMetadata} containing all metadata about hosting the given key for the given store, using the + * the supplied partitioner, or {@code null} if no matching metadata could be found. + */ + public KeyQueryMetadata queryMetadataForKey(final String storeName, + final K key, + final StreamPartitioner partitioner) { + validateIsRunningOrRebalancing(); + return streamsMetadataState.getKeyQueryMetadataForKey(storeName, key, partitioner); + } + + /** + * Get a facade wrapping the local {@link StateStore} instances with the provided {@link StoreQueryParameters}. + * The returned object can be used to query the {@link StateStore} instances. + * + * @param storeQueryParameters the parameters used to fetch a queryable store + * @return A facade wrapping the local {@link StateStore} instances + * @throws StreamsNotStartedException If Streams has not yet been started. Just call {@link KafkaStreams#start()} + * and then retry this call. + * @throws UnknownStateStoreException If the specified store name does not exist in the topology. + * @throws InvalidStateStorePartitionException If the specified partition does not exist. + * @throws InvalidStateStoreException If the Streams instance isn't in a queryable state. + * If the store's type does not match the QueryableStoreType, + * the Streams instance is not in a queryable state with respect + * to the parameters, or if the store is not available locally, then + * an InvalidStateStoreException is thrown upon store access. + */ + public T store(final StoreQueryParameters storeQueryParameters) { + validateIsRunningOrRebalancing(); + final String storeName = storeQueryParameters.storeName(); + if (!topologyMetadata.hasStore(storeName)) { + throw new UnknownStateStoreException( + "Cannot get state store " + storeName + " because no such store is registered in the topology." + ); + } + return queryableStoreProvider.getStore(storeQueryParameters); + } + + /** + * handle each stream thread in a snapshot of threads. 
+ * noted: iteration over SynchronizedList is not thread safe so it must be manually synchronized. However, we may + * require other locks when looping threads and it could cause deadlock. Hence, we create a copy to avoid holding + * threads lock when looping threads. + * @param consumer handler + */ + protected void processStreamThread(final Consumer consumer) { + final List copy = new ArrayList<>(threads); + for (final StreamThread thread : copy) consumer.accept(thread); + } + + /** + * Returns runtime information about the local threads of this {@link KafkaStreams} instance. + * + * @return the set of {@link org.apache.kafka.streams.processor.ThreadMetadata}. + * @deprecated since 3.0 use {@link #metadataForLocalThreads()} + */ + @Deprecated + @SuppressWarnings("deprecation") + public Set localThreadsMetadata() { + return metadataForLocalThreads().stream().map(threadMetadata -> new org.apache.kafka.streams.processor.ThreadMetadata( + threadMetadata.threadName(), + threadMetadata.threadState(), + threadMetadata.consumerClientId(), + threadMetadata.restoreConsumerClientId(), + threadMetadata.producerClientIds(), + threadMetadata.adminClientId(), + threadMetadata.activeTasks().stream().map(taskMetadata -> new org.apache.kafka.streams.processor.TaskMetadata( + taskMetadata.taskId().toString(), + taskMetadata.topicPartitions(), + taskMetadata.committedOffsets(), + taskMetadata.endOffsets(), + taskMetadata.timeCurrentIdlingStarted()) + ).collect(Collectors.toSet()), + threadMetadata.standbyTasks().stream().map(taskMetadata -> new org.apache.kafka.streams.processor.TaskMetadata( + taskMetadata.taskId().toString(), + taskMetadata.topicPartitions(), + taskMetadata.committedOffsets(), + taskMetadata.endOffsets(), + taskMetadata.timeCurrentIdlingStarted()) + ).collect(Collectors.toSet()))) + .collect(Collectors.toSet()); + } + + /** + * Returns runtime information about the local threads of this {@link KafkaStreams} instance. + * + * @return the set of {@link ThreadMetadata}. + */ + public Set metadataForLocalThreads() { + final Set threadMetadata = new HashSet<>(); + processStreamThread(thread -> { + synchronized (thread.getStateLock()) { + if (thread.state() != StreamThread.State.DEAD) { + threadMetadata.add(thread.threadMetadata()); + } + } + }); + return threadMetadata; + } + + /** + * Returns {@link LagInfo}, for all store partitions (active or standby) local to this Streams instance. Note that the + * values returned are just estimates and meant to be used for making soft decisions on whether the data in the store + * partition is fresh enough for querying. + * + * Note: Each invocation of this method issues a call to the Kafka brokers. Thus its advisable to limit the frequency + * of invocation to once every few seconds. 
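+ * For illustration, a sketch of logging local lag; {@code streams} is assumed to be a running instance:
+ * <pre>{@code
+ * Map<String, Map<Integer, LagInfo>> lags = streams.allLocalStorePartitionLags();
+ * lags.forEach((store, partitions) ->
+ *     partitions.forEach((partition, lag) ->
+ *         System.out.println(store + "-" + partition + " lag: " + lag.offsetLag())));
+ * }</pre>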
+ * + * @return map of store names to another map of partition to {@link LagInfo}s + * @throws StreamsException if the admin client request throws exception + */ + public Map> allLocalStorePartitionLags() { + final Map> localStorePartitionLags = new TreeMap<>(); + final Collection allPartitions = new LinkedList<>(); + final Map allChangelogPositions = new HashMap<>(); + + // Obtain the current positions, of all the active-restoring and standby tasks + processStreamThread(thread -> { + for (final Task task : thread.allTasks().values()) { + allPartitions.addAll(task.changelogPartitions()); + // Note that not all changelog partitions, will have positions; since some may not have started + allChangelogPositions.putAll(task.changelogOffsets()); + } + }); + + log.debug("Current changelog positions: {}", allChangelogPositions); + final Map allEndOffsets; + allEndOffsets = fetchEndOffsets(allPartitions, adminClient); + log.debug("Current end offsets :{}", allEndOffsets); + + for (final Map.Entry entry : allEndOffsets.entrySet()) { + // Avoiding an extra admin API lookup by computing lags for not-yet-started restorations + // from zero instead of the real "earliest offset" for the changelog. + // This will yield the correct relative order of lagginess for the tasks in the cluster, + // but it is an over-estimate of how much work remains to restore the task from scratch. + final long earliestOffset = 0L; + final long changelogPosition = allChangelogPositions.getOrDefault(entry.getKey(), earliestOffset); + final long latestOffset = entry.getValue().offset(); + final LagInfo lagInfo = new LagInfo(changelogPosition == Task.LATEST_OFFSET ? latestOffset : changelogPosition, latestOffset); + final String storeName = streamsMetadataState.getStoreForChangelogTopic(entry.getKey().topic()); + localStorePartitionLags.computeIfAbsent(storeName, ignored -> new TreeMap<>()) + .put(entry.getKey().partition(), lagInfo); + } + + return Collections.unmodifiableMap(localStorePartitionLags); + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/KeyQueryMetadata.java b/streams/src/main/java/org/apache/kafka/streams/KeyQueryMetadata.java new file mode 100644 index 0000000..9ca4952 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/KeyQueryMetadata.java @@ -0,0 +1,136 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams; + +import org.apache.kafka.streams.state.HostInfo; + +import java.util.Collections; +import java.util.Objects; +import java.util.Set; + +/** + * Represents all the metadata related to a key, where a particular key resides in a {@link KafkaStreams} application. + * It contains the active {@link HostInfo} and a set of standby {@link HostInfo}s, denoting the instances where the key resides. 
+ * It also contains the partition number where the key belongs, which could be useful when used in conjunction with other APIs. + * e.g: Relating with lags for that store partition. + * NOTE: This is a point in time view. It may change as rebalances happen. + */ +public class KeyQueryMetadata { + /** + * Sentinel to indicate that the KeyQueryMetadata is currently unavailable. This can occur during rebalance + * operations. + */ + public static final KeyQueryMetadata NOT_AVAILABLE = + new KeyQueryMetadata(HostInfo.unavailable(), Collections.emptySet(), -1); + + private final HostInfo activeHost; + + private final Set standbyHosts; + + private final int partition; + + public KeyQueryMetadata(final HostInfo activeHost, final Set standbyHosts, final int partition) { + this.activeHost = activeHost; + this.standbyHosts = standbyHosts; + this.partition = partition; + } + + /** + * Get the active Kafka Streams instance for given key. + * + * @return active instance's {@link HostInfo} + * @deprecated Use {@link #activeHost()} instead. + */ + @Deprecated + public HostInfo getActiveHost() { + return activeHost; + } + + /** + * Get the Kafka Streams instances that host the key as standbys. + * + * @return set of standby {@link HostInfo} or a empty set, if no standbys are configured + * @deprecated Use {@link #standbyHosts()} instead. + */ + @Deprecated + public Set getStandbyHosts() { + return standbyHosts; + } + + /** + * Get the store partition corresponding to the key. + * + * @return store partition number + * @deprecated Use {@link #partition()} instead. + */ + @Deprecated + public int getPartition() { + return partition; + } + + /** + * Get the active Kafka Streams instance for given key. + * + * @return active instance's {@link HostInfo} + */ + public HostInfo activeHost() { + return activeHost; + } + + /** + * Get the Kafka Streams instances that host the key as standbys. + * + * @return set of standby {@link HostInfo} or a empty set, if no standbys are configured + */ + public Set standbyHosts() { + return standbyHosts; + } + + /** + * Get the store partition corresponding to the key. + * + * @return store partition number + */ + public int partition() { + return partition; + } + + @Override + public boolean equals(final Object obj) { + if (!(obj instanceof KeyQueryMetadata)) { + return false; + } + final KeyQueryMetadata keyQueryMetadata = (KeyQueryMetadata) obj; + return Objects.equals(keyQueryMetadata.activeHost, activeHost) + && Objects.equals(keyQueryMetadata.standbyHosts, standbyHosts) + && Objects.equals(keyQueryMetadata.partition, partition); + } + + @Override + public String toString() { + return "KeyQueryMetadata {" + + "activeHost=" + activeHost + + ", standbyHosts=" + standbyHosts + + ", partition=" + partition + + '}'; + } + + @Override + public int hashCode() { + return Objects.hash(activeHost, standbyHosts, partition); + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/KeyValue.java b/streams/src/main/java/org/apache/kafka/streams/KeyValue.java new file mode 100644 index 0000000..b534d11 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/KeyValue.java @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams; + +import java.util.Objects; + +/** + * A key-value pair defined for a single Kafka Streams record. + * If the record comes directly from a Kafka topic then its key/value are defined as the message key/value. + * + * @param Key type + * @param Value type + */ +public class KeyValue { + + /** The key of the key-value pair. */ + public final K key; + /** The value of the key-value pair. */ + public final V value; + + /** + * Create a new key-value pair. + * + * @param key the key + * @param value the value + */ + public KeyValue(final K key, final V value) { + this.key = key; + this.value = value; + } + + /** + * Create a new key-value pair. + * + * @param key the key + * @param value the value + * @param the type of the key + * @param the type of the value + * @return a new key-value pair + */ + public static KeyValue pair(final K key, final V value) { + return new KeyValue<>(key, value); + } + + @Override + public String toString() { + return "KeyValue(" + key + ", " + value + ")"; + } + + @Override + public boolean equals(final Object obj) { + if (this == obj) { + return true; + } + + if (!(obj instanceof KeyValue)) { + return false; + } + + final KeyValue other = (KeyValue) obj; + return Objects.equals(key, other.key) && Objects.equals(value, other.value); + } + + @Override + public int hashCode() { + return Objects.hash(key, value); + } + +} diff --git a/streams/src/main/java/org/apache/kafka/streams/LagInfo.java b/streams/src/main/java/org/apache/kafka/streams/LagInfo.java new file mode 100644 index 0000000..4a6f642 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/LagInfo.java @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams; + +import java.util.Objects; + +/** + * Encapsulates information about lag, at a store partition replica (active or standby). This information is constantly changing as the + * tasks process records and thus, they should be treated as simply instantaenous measure of lag. 
+ */ +public class LagInfo { + + private final long currentOffsetPosition; + + private final long endOffsetPosition; + + private final long offsetLag; + + LagInfo(final long currentOffsetPosition, final long endOffsetPosition) { + this.currentOffsetPosition = currentOffsetPosition; + this.endOffsetPosition = endOffsetPosition; + this.offsetLag = Math.max(0, endOffsetPosition - currentOffsetPosition); + } + + /** + * Get the current maximum offset on the store partition's changelog topic, that has been successfully written into + * the store partition's state store. + * + * @return current consume offset for standby/restoring store partitions & simply endoffset for active store partition replicas + */ + public long currentOffsetPosition() { + return this.currentOffsetPosition; + } + + /** + * Get the end offset position for this store partition's changelog topic on the Kafka brokers. + * + * @return last offset written to the changelog topic partition + */ + public long endOffsetPosition() { + return this.endOffsetPosition; + } + + /** + * Get the measured lag between current and end offset positions, for this store partition replica + * + * @return lag as measured by message offsets + */ + public long offsetLag() { + return this.offsetLag; + } + + @Override + public boolean equals(final Object obj) { + if (!(obj instanceof LagInfo)) { + return false; + } + final LagInfo other = (LagInfo) obj; + return currentOffsetPosition == other.currentOffsetPosition + && endOffsetPosition == other.endOffsetPosition + && this.offsetLag == other.offsetLag; + } + + @Override + public int hashCode() { + return Objects.hash(currentOffsetPosition, endOffsetPosition, offsetLag); + } + + @Override + public String toString() { + return "LagInfo {" + + " currentOffsetPosition=" + currentOffsetPosition + + ", endOffsetPosition=" + endOffsetPosition + + ", offsetLag=" + offsetLag + + '}'; + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/StoreQueryParameters.java b/streams/src/main/java/org/apache/kafka/streams/StoreQueryParameters.java new file mode 100644 index 0000000..aa5785c --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/StoreQueryParameters.java @@ -0,0 +1,129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams; + +import org.apache.kafka.streams.state.QueryableStoreType; + +import java.util.Objects; + +/** + * {@code StoreQueryParameters} allows you to pass a variety of parameters when fetching a store for interactive query. 
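+ * For illustration, a sketch that restricts a query to one partition and allows serving from stale
+ * stores; the store name and partition below are placeholders:
+ * <pre>{@code
+ * StoreQueryParameters<ReadOnlyKeyValueStore<String, Long>> params =
+ *     StoreQueryParameters
+ *         .fromNameAndType("word-counts", QueryableStoreTypes.<String, Long>keyValueStore())
+ *         .withPartition(5)
+ *         .enableStaleStores();
+ * }</pre>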
+ */ +public class StoreQueryParameters { + + private final Integer partition; + private final boolean staleStores; + private final String storeName; + private final QueryableStoreType queryableStoreType; + + private StoreQueryParameters(final String storeName, final QueryableStoreType queryableStoreType, final Integer partition, final boolean staleStores) { + this.storeName = storeName; + this.queryableStoreType = queryableStoreType; + this.partition = partition; + this.staleStores = staleStores; + } + + public static StoreQueryParameters fromNameAndType(final String storeName, + final QueryableStoreType queryableStoreType) { + return new StoreQueryParameters<>(storeName, queryableStoreType, null, false); + } + + /** + * Set a specific partition that should be queried exclusively. + * + * @param partition The specific integer partition to be fetched from the stores list by using {@link StoreQueryParameters}. + * + * @return StoreQueryParameters a new {@code StoreQueryParameters} instance configured with the specified partition + */ + public StoreQueryParameters withPartition(final Integer partition) { + return new StoreQueryParameters<>(storeName, queryableStoreType, partition, staleStores); + } + + /** + * Enable querying of stale state stores, i.e., allow to query active tasks during restore as well as standby tasks. + * + * @return StoreQueryParameters a new {@code StoreQueryParameters} instance configured with serving from stale stores enabled + */ + public StoreQueryParameters enableStaleStores() { + return new StoreQueryParameters<>(storeName, queryableStoreType, partition, true); + } + + /** + * Get the name of the state store that should be queried. + * + * @return String state store name + */ + public String storeName() { + return storeName; + } + + /** + * Get the queryable store type for which key is queried by the user. + * + * @return QueryableStoreType type of queryable store + */ + public QueryableStoreType queryableStoreType() { + return queryableStoreType; + } + + /** + * Get the store partition that will be queried. + * If the method returns {@code null}, it would mean that no specific partition has been requested, + * so all the local partitions for the store will be queried. + * + * @return Integer partition + */ + public Integer partition() { + return partition; + } + + /** + * Get the flag staleStores. If {@code true}, include standbys and recovering stores along with running stores. 
+ * + * @return boolean staleStores + */ + public boolean staleStoresEnabled() { + return staleStores; + } + + @Override + public boolean equals(final Object obj) { + if (!(obj instanceof StoreQueryParameters)) { + return false; + } + final StoreQueryParameters storeQueryParameters = (StoreQueryParameters) obj; + return Objects.equals(storeQueryParameters.partition, partition) + && Objects.equals(storeQueryParameters.staleStores, staleStores) + && Objects.equals(storeQueryParameters.storeName, storeName) + && Objects.equals(storeQueryParameters.queryableStoreType, queryableStoreType); + } + + @Override + public String toString() { + return "StoreQueryParameters {" + + "partition=" + partition + + ", staleStores=" + staleStores + + ", storeName=" + storeName + + ", queryableStoreType=" + queryableStoreType + + '}'; + } + + @Override + public int hashCode() { + return Objects.hash(partition, staleStores, storeName, queryableStoreType); + } +} \ No newline at end of file diff --git a/streams/src/main/java/org/apache/kafka/streams/StreamsBuilder.java b/streams/src/main/java/org/apache/kafka/streams/StreamsBuilder.java new file mode 100644 index 0000000..f10dc93 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/StreamsBuilder.java @@ -0,0 +1,616 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams; + +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.streams.errors.TopologyException; +import org.apache.kafka.streams.kstream.Consumed; +import org.apache.kafka.streams.kstream.GlobalKTable; +import org.apache.kafka.streams.kstream.KGroupedStream; +import org.apache.kafka.streams.kstream.KGroupedTable; +import org.apache.kafka.streams.kstream.KStream; +import org.apache.kafka.streams.kstream.KTable; +import org.apache.kafka.streams.kstream.Materialized; +import org.apache.kafka.streams.kstream.Transformer; +import org.apache.kafka.streams.kstream.ValueTransformer; +import org.apache.kafka.streams.kstream.internals.ConsumedInternal; +import org.apache.kafka.streams.kstream.internals.InternalStreamsBuilder; +import org.apache.kafka.streams.kstream.internals.MaterializedInternal; +import org.apache.kafka.streams.processor.StateStore; +import org.apache.kafka.streams.processor.TimestampExtractor; +import org.apache.kafka.streams.processor.api.Processor; +import org.apache.kafka.streams.processor.api.ProcessorSupplier; +import org.apache.kafka.streams.processor.internals.InternalTopologyBuilder; +import org.apache.kafka.streams.processor.internals.ProcessorAdapter; +import org.apache.kafka.streams.processor.internals.ProcessorNode; +import org.apache.kafka.streams.processor.internals.SourceNode; +import org.apache.kafka.streams.state.KeyValueStore; +import org.apache.kafka.streams.state.ReadOnlyKeyValueStore; +import org.apache.kafka.streams.state.StoreBuilder; + +import java.util.Collection; +import java.util.Collections; +import java.util.Objects; +import java.util.Properties; +import java.util.regex.Pattern; + +/** + * {@code StreamsBuilder} provide the high-level Kafka Streams DSL to specify a Kafka Streams topology. + * + *

                + * It is a requirement that the processing logic ({@link Topology}) be defined in a deterministic way, + * as in, the order in which all operators are added must be predictable and the same across all application + * instances. + * Topologies are only identical if all operators are added in the same order. + * If different {@link KafkaStreams} instances of the same application build different topologies the result may be + * incompatible runtime code and unexpected results or errors + * + * @see Topology + * @see KStream + * @see KTable + * @see GlobalKTable + */ +public class StreamsBuilder { + + /** The actual topology that is constructed by this StreamsBuilder. */ + protected final Topology topology; + + /** The topology's internal builder. */ + protected final InternalTopologyBuilder internalTopologyBuilder; + + protected final InternalStreamsBuilder internalStreamsBuilder; + + public StreamsBuilder() { + topology = getNewTopology(); + internalTopologyBuilder = topology.internalTopologyBuilder; + internalStreamsBuilder = new InternalStreamsBuilder(internalTopologyBuilder); + } + + protected Topology getNewTopology() { + return new Topology(); + } + + /** + * Create a {@link KStream} from the specified topic. + * The default {@code "auto.offset.reset"} strategy, default {@link TimestampExtractor}, and default key and value + * deserializers as specified in the {@link StreamsConfig config} are used. + *

                + * If multiple topics are specified there is no ordering guarantee for records from different topics. + *

                + * Note that the specified input topic must be partitioned by key. + * If this is not the case it is the user's responsibility to repartition the data before any key based operation + * (like aggregation or join) is applied to the returned {@link KStream}. + * + * @param topic the topic name; cannot be {@code null} + * @return a {@link KStream} for the specified topic + */ + public synchronized KStream stream(final String topic) { + return stream(Collections.singleton(topic)); + } + + /** + * Create a {@link KStream} from the specified topic. + * The {@code "auto.offset.reset"} strategy, {@link TimestampExtractor}, key and value deserializers + * are defined by the options in {@link Consumed} are used. + *
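+     * <p>For example (a sketch; the topic name and serdes are illustrative):
+     * <pre>{@code
+     * StreamsBuilder builder = new StreamsBuilder();
+     * KStream<String, Long> stream =
+     *     builder.stream("input-topic", Consumed.with(Serdes.String(), Serdes.Long()));
+     * }</pre>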

                + * Note that the specified input topic must be partitioned by key. + * If this is not the case it is the user's responsibility to repartition the data before any key based operation + * (like aggregation or join) is applied to the returned {@link KStream}. + * + * @param topic the topic names; cannot be {@code null} + * @param consumed the instance of {@link Consumed} used to define optional parameters + * @return a {@link KStream} for the specified topic + */ + public synchronized KStream stream(final String topic, + final Consumed consumed) { + return stream(Collections.singleton(topic), consumed); + } + + /** + * Create a {@link KStream} from the specified topics. + * The default {@code "auto.offset.reset"} strategy, default {@link TimestampExtractor}, and default key and value + * deserializers as specified in the {@link StreamsConfig config} are used. + *

                + * If multiple topics are specified there is no ordering guarantee for records from different topics. + *

                + * Note that the specified input topics must be partitioned by key. + * If this is not the case it is the user's responsibility to repartition the data before any key based operation + * (like aggregation or join) is applied to the returned {@link KStream}. + * + * @param topics the topic names; must contain at least one topic name + * @return a {@link KStream} for the specified topics + */ + public synchronized KStream stream(final Collection topics) { + return stream(topics, Consumed.with(null, null, null, null)); + } + + /** + * Create a {@link KStream} from the specified topics. + * The {@code "auto.offset.reset"} strategy, {@link TimestampExtractor}, key and value deserializers + * are defined by the options in {@link Consumed} are used. + *

                + * If multiple topics are specified there is no ordering guarantee for records from different topics. + *

                + * Note that the specified input topics must be partitioned by key. + * If this is not the case it is the user's responsibility to repartition the data before any key based operation + * (like aggregation or join) is applied to the returned {@link KStream}. + * + * @param topics the topic names; must contain at least one topic name + * @param consumed the instance of {@link Consumed} used to define optional parameters + * @return a {@link KStream} for the specified topics + */ + public synchronized KStream stream(final Collection topics, + final Consumed consumed) { + Objects.requireNonNull(topics, "topics can't be null"); + Objects.requireNonNull(consumed, "consumed can't be null"); + return internalStreamsBuilder.stream(topics, new ConsumedInternal<>(consumed)); + } + + + /** + * Create a {@link KStream} from the specified topic pattern. + * The default {@code "auto.offset.reset"} strategy, default {@link TimestampExtractor}, and default key and value + * deserializers as specified in the {@link StreamsConfig config} are used. + *

                + * If multiple topics are matched by the specified pattern, the created {@link KStream} will read data from all of + * them and there is no ordering guarantee between records from different topics. This also means that the work + * will not be parallelized for multiple topics, and the number of tasks will scale with the maximum partition + * count of any matching topic rather than the total number of partitions across all topics. + *

                + * Note that the specified input topics must be partitioned by key. + * If this is not the case it is the user's responsibility to repartition the data before any key based operation + * (like aggregation or join) is applied to the returned {@link KStream}. + * + * @param topicPattern the pattern to match for topic names + * @return a {@link KStream} for topics matching the regex pattern. + */ + public synchronized KStream stream(final Pattern topicPattern) { + return stream(topicPattern, Consumed.with(null, null)); + } + + /** + * Create a {@link KStream} from the specified topic pattern. + * The {@code "auto.offset.reset"} strategy, {@link TimestampExtractor}, key and value deserializers + * are defined by the options in {@link Consumed} are used. + *

                + * If multiple topics are matched by the specified pattern, the created {@link KStream} will read data from all of + * them and there is no ordering guarantee between records from different topics. This also means that the work + * will not be parallelized for multiple topics, and the number of tasks will scale with the maximum partition + * count of any matching topic rather than the total number of partitions across all topics. + *
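+     * <p>For example (a sketch; the pattern and serdes are illustrative):
+     * <pre>{@code
+     * KStream<String, String> stream =
+     *     builder.stream(Pattern.compile("sensor-.*"), Consumed.with(Serdes.String(), Serdes.String()));
+     * }</pre>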

                + * Note that the specified input topics must be partitioned by key. + * If this is not the case it is the user's responsibility to repartition the data before any key based operation + * (like aggregation or join) is applied to the returned {@link KStream}. + * + * @param topicPattern the pattern to match for topic names + * @param consumed the instance of {@link Consumed} used to define optional parameters + * @return a {@link KStream} for topics matching the regex pattern. + */ + public synchronized KStream stream(final Pattern topicPattern, + final Consumed consumed) { + Objects.requireNonNull(topicPattern, "topicPattern can't be null"); + Objects.requireNonNull(consumed, "consumed can't be null"); + return internalStreamsBuilder.stream(topicPattern, new ConsumedInternal<>(consumed)); + } + + /** + * Create a {@link KTable} for the specified topic. + * The {@code "auto.offset.reset"} strategy, {@link TimestampExtractor}, key and value deserializers + * are defined by the options in {@link Consumed} are used. + * Input {@link KeyValue records} with {@code null} key will be dropped. + *

                + * Note that the specified input topic must be partitioned by key. + * If this is not the case the returned {@link KTable} will be corrupted. + *

                + * The resulting {@link KTable} will be materialized in a local {@link KeyValueStore} using the given + * {@code Materialized} instance. + * An internal changelog topic is created by default. Because the source topic can + * be used for recovery, you can avoid creating the changelog topic by setting + * the {@code "topology.optimization"} to {@code "all"} in the {@link StreamsConfig}. + *
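+     * <p>For example, to enable the optimization when building the topology (a sketch; the property
+     * names are the {@link StreamsConfig} constants referenced above):
+     * <pre>{@code
+     * Properties props = new Properties();
+     * props.put(StreamsConfig.TOPOLOGY_OPTIMIZATION_CONFIG, StreamsConfig.OPTIMIZE);
+     * Topology topology = builder.build(props);
+     * }</pre>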

                + * You should only specify serdes in the {@link Consumed} instance as these will also be used to overwrite the + * serdes in {@link Materialized}, i.e., + *

                 {@code
+     * streamBuilder.table(topic, Consumed.with(Serdes.String(), Serdes.String()), Materialized.as(storeName))
                +     * }
                +     * 
                + * To query the local {@link ReadOnlyKeyValueStore} it must be obtained via + * {@link KafkaStreams#store(StoreQueryParameters) KafkaStreams#store(...)}: + *
                {@code
                +     * KafkaStreams streams = ...
+     * ReadOnlyKeyValueStore<K, ValueAndTimestamp<V>> localStore = streams.store(queryableStoreName, QueryableStoreTypes.<K, ValueAndTimestamp<V>>timestampedKeyValueStore());
                +     * K key = "some-key";
+     * ValueAndTimestamp<V> valueForKey = localStore.get(key); // key must be local (application state is shared over all running Kafka Streams instances)
                +     * }
                + * For non-local keys, a custom RPC mechanism must be implemented using {@link KafkaStreams#metadataForAllStreamsClients()} to + * query the value of the key on a parallel running instance of your Kafka Streams application. + * + * @param topic the topic name; cannot be {@code null} + * @param consumed the instance of {@link Consumed} used to define optional parameters; cannot be {@code null} + * @param materialized the instance of {@link Materialized} used to materialize a state store; cannot be {@code null} + * @return a {@link KTable} for the specified topic + */ + public synchronized KTable table(final String topic, + final Consumed consumed, + final Materialized> materialized) { + Objects.requireNonNull(topic, "topic can't be null"); + Objects.requireNonNull(consumed, "consumed can't be null"); + Objects.requireNonNull(materialized, "materialized can't be null"); + final ConsumedInternal consumedInternal = new ConsumedInternal<>(consumed); + materialized.withKeySerde(consumedInternal.keySerde()).withValueSerde(consumedInternal.valueSerde()); + + final MaterializedInternal> materializedInternal = + new MaterializedInternal<>(materialized, internalStreamsBuilder, topic + "-"); + + return internalStreamsBuilder.table(topic, consumedInternal, materializedInternal); + } + + /** + * Create a {@link KTable} for the specified topic. + * The default {@code "auto.offset.reset"} strategy and default key and value deserializers as specified in the + * {@link StreamsConfig config} are used. + * Input {@link KeyValue records} with {@code null} key will be dropped. + *

                + * Note that the specified input topics must be partitioned by key. + * If this is not the case the returned {@link KTable} will be corrupted. + *

                + * The resulting {@link KTable} will be materialized in a local {@link KeyValueStore} with an internal + * store name. Note that store name may not be queryable through Interactive Queries. + * An internal changelog topic is created by default. Because the source topic can + * be used for recovery, you can avoid creating the changelog topic by setting + * the {@code "topology.optimization"} to {@code "all"} in the {@link StreamsConfig}. + * + * @param topic the topic name; cannot be {@code null} + * @return a {@link KTable} for the specified topic + */ + public synchronized KTable table(final String topic) { + return table(topic, new ConsumedInternal<>()); + } + + /** + * Create a {@link KTable} for the specified topic. + * The {@code "auto.offset.reset"} strategy, {@link TimestampExtractor}, key and value deserializers + * are defined by the options in {@link Consumed} are used. + * Input {@link KeyValue records} with {@code null} key will be dropped. + *
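+     * <p>For example (a sketch; the topic name and serdes are illustrative):
+     * <pre>{@code
+     * KTable<String, Long> table =
+     *     builder.table("count-topic", Consumed.with(Serdes.String(), Serdes.Long()));
+     * }</pre>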

                + * Note that the specified input topics must be partitioned by key. + * If this is not the case the returned {@link KTable} will be corrupted. + *

                + * The resulting {@link KTable} will be materialized in a local {@link KeyValueStore} with an internal + * store name. Note that store name may not be queryable through Interactive Queries. + * An internal changelog topic is created by default. Because the source topic can + * be used for recovery, you can avoid creating the changelog topic by setting + * the {@code "topology.optimization"} to {@code "all"} in the {@link StreamsConfig}. + * + * @param topic the topic name; cannot be {@code null} + * @param consumed the instance of {@link Consumed} used to define optional parameters; cannot be {@code null} + * @return a {@link KTable} for the specified topic + */ + public synchronized KTable table(final String topic, + final Consumed consumed) { + Objects.requireNonNull(topic, "topic can't be null"); + Objects.requireNonNull(consumed, "consumed can't be null"); + final ConsumedInternal consumedInternal = new ConsumedInternal<>(consumed); + + final MaterializedInternal> materializedInternal = + new MaterializedInternal<>( + Materialized.with(consumedInternal.keySerde(), consumedInternal.valueSerde()), + internalStreamsBuilder, + topic + "-"); + + return internalStreamsBuilder.table(topic, consumedInternal, materializedInternal); + } + + /** + * Create a {@link KTable} for the specified topic. + * The default {@code "auto.offset.reset"} strategy as specified in the {@link StreamsConfig config} are used. + * Key and value deserializers as defined by the options in {@link Materialized} are used. + * Input {@link KeyValue records} with {@code null} key will be dropped. + *

                + * Note that the specified input topics must be partitioned by key. + * If this is not the case the returned {@link KTable} will be corrupted. + *

                + * The resulting {@link KTable} will be materialized in a local {@link KeyValueStore} using the {@link Materialized} instance. + * An internal changelog topic is created by default. Because the source topic can + * be used for recovery, you can avoid creating the changelog topic by setting + * the {@code "topology.optimization"} to {@code "all"} in the {@link StreamsConfig}. + * + * @param topic the topic name; cannot be {@code null} + * @param materialized the instance of {@link Materialized} used to materialize a state store; cannot be {@code null} + * @return a {@link KTable} for the specified topic + */ + public synchronized KTable table(final String topic, + final Materialized> materialized) { + Objects.requireNonNull(topic, "topic can't be null"); + Objects.requireNonNull(materialized, "materialized can't be null"); + + final MaterializedInternal> materializedInternal = + new MaterializedInternal<>(materialized, internalStreamsBuilder, topic + "-"); + + final ConsumedInternal consumedInternal = + new ConsumedInternal<>(Consumed.with(materializedInternal.keySerde(), materializedInternal.valueSerde())); + + return internalStreamsBuilder.table(topic, consumedInternal, materializedInternal); + } + + /** + * Create a {@link GlobalKTable} for the specified topic. + * Input {@link KeyValue records} with {@code null} key will be dropped. + *

                + * The resulting {@link GlobalKTable} will be materialized in a local {@link KeyValueStore} with an internal + * store name. Note that store name may not be queryable through Interactive Queries. + * No internal changelog topic is created since the original input topic can be used for recovery (cf. + * methods of {@link KGroupedStream} and {@link KGroupedTable} that return a {@link KTable}). + *

                + * Note that {@link GlobalKTable} always applies {@code "auto.offset.reset"} strategy {@code "earliest"} + * regardless of the specified value in {@link StreamsConfig} or {@link Consumed}. + * + * @param topic the topic name; cannot be {@code null} + * @param consumed the instance of {@link Consumed} used to define optional parameters + * @return a {@link GlobalKTable} for the specified topic + */ + public synchronized GlobalKTable globalTable(final String topic, + final Consumed consumed) { + Objects.requireNonNull(topic, "topic can't be null"); + Objects.requireNonNull(consumed, "consumed can't be null"); + final ConsumedInternal consumedInternal = new ConsumedInternal<>(consumed); + + final MaterializedInternal> materializedInternal = + new MaterializedInternal<>( + Materialized.with(consumedInternal.keySerde(), consumedInternal.valueSerde()), + internalStreamsBuilder, topic + "-"); + + return internalStreamsBuilder.globalTable(topic, consumedInternal, materializedInternal); + } + + /** + * Create a {@link GlobalKTable} for the specified topic. + * The default key and value deserializers as specified in the {@link StreamsConfig config} are used. + * Input {@link KeyValue records} with {@code null} key will be dropped. + *
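+     * <p>For example (a sketch; the topic name is illustrative and the default serdes from the config are assumed):
+     * <pre>{@code
+     * GlobalKTable<String, String> users = builder.globalTable("users-topic");
+     * }</pre>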

                + * The resulting {@link GlobalKTable} will be materialized in a local {@link KeyValueStore} with an internal + * store name. Note that store name may not be queryable through Interactive Queries. + * No internal changelog topic is created since the original input topic can be used for recovery (cf. + * methods of {@link KGroupedStream} and {@link KGroupedTable} that return a {@link KTable}). + *

                + * Note that {@link GlobalKTable} always applies {@code "auto.offset.reset"} strategy {@code "earliest"} + * regardless of the specified value in {@link StreamsConfig}. + * + * @param topic the topic name; cannot be {@code null} + * @return a {@link GlobalKTable} for the specified topic + */ + public synchronized GlobalKTable globalTable(final String topic) { + return globalTable(topic, Consumed.with(null, null)); + } + + /** + * Create a {@link GlobalKTable} for the specified topic. + * + * Input {@link KeyValue} pairs with {@code null} key will be dropped. + *

                + * The resulting {@link GlobalKTable} will be materialized in a local {@link KeyValueStore} configured with + * the provided instance of {@link Materialized}. + * However, no internal changelog topic is created since the original input topic can be used for recovery (cf. + * methods of {@link KGroupedStream} and {@link KGroupedTable} that return a {@link KTable}). + *

                + * You should only specify serdes in the {@link Consumed} instance as these will also be used to overwrite the + * serdes in {@link Materialized}, i.e., + *

                 {@code
+     * streamBuilder.globalTable(topic, Consumed.with(Serdes.String(), Serdes.String()), Materialized.as(storeName))
                +     * }
                +     * 
                + * To query the local {@link ReadOnlyKeyValueStore} it must be obtained via + * {@link KafkaStreams#store(StoreQueryParameters) KafkaStreams#store(...)}: + *
                {@code
                +     * KafkaStreams streams = ...
+     * ReadOnlyKeyValueStore<K, ValueAndTimestamp<V>> localStore = streams.store(queryableStoreName, QueryableStoreTypes.<K, ValueAndTimestamp<V>>timestampedKeyValueStore());
                +     * K key = "some-key";
+     * ValueAndTimestamp<V> valueForKey = localStore.get(key);
                +     * }
                + * Note that {@link GlobalKTable} always applies {@code "auto.offset.reset"} strategy {@code "earliest"} + * regardless of the specified value in {@link StreamsConfig} or {@link Consumed}. + * + * @param topic the topic name; cannot be {@code null} + * @param consumed the instance of {@link Consumed} used to define optional parameters; can't be {@code null} + * @param materialized the instance of {@link Materialized} used to materialize a state store; cannot be {@code null} + * @return a {@link GlobalKTable} for the specified topic + */ + public synchronized GlobalKTable globalTable(final String topic, + final Consumed consumed, + final Materialized> materialized) { + Objects.requireNonNull(topic, "topic can't be null"); + Objects.requireNonNull(consumed, "consumed can't be null"); + Objects.requireNonNull(materialized, "materialized can't be null"); + final ConsumedInternal consumedInternal = new ConsumedInternal<>(consumed); + // always use the serdes from consumed + materialized.withKeySerde(consumedInternal.keySerde()).withValueSerde(consumedInternal.valueSerde()); + + final MaterializedInternal> materializedInternal = + new MaterializedInternal<>(materialized, internalStreamsBuilder, topic + "-"); + + return internalStreamsBuilder.globalTable(topic, consumedInternal, materializedInternal); + } + + /** + * Create a {@link GlobalKTable} for the specified topic. + * + * Input {@link KeyValue} pairs with {@code null} key will be dropped. + *

                + * The resulting {@link GlobalKTable} will be materialized in a local {@link KeyValueStore} configured with + * the provided instance of {@link Materialized}. + * However, no internal changelog topic is created since the original input topic can be used for recovery (cf. + * methods of {@link KGroupedStream} and {@link KGroupedTable} that return a {@link KTable}). + *

                + * To query the local {@link ReadOnlyKeyValueStore} it must be obtained via + * {@link KafkaStreams#store(StoreQueryParameters) KafkaStreams#store(...)}: + *

                {@code
                +     * KafkaStreams streams = ...
+     * ReadOnlyKeyValueStore<K, ValueAndTimestamp<V>> localStore = streams.store(queryableStoreName, QueryableStoreTypes.<K, ValueAndTimestamp<V>>timestampedKeyValueStore());
                +     * K key = "some-key";
+     * ValueAndTimestamp<V> valueForKey = localStore.get(key);
                +     * }
                + * Note that {@link GlobalKTable} always applies {@code "auto.offset.reset"} strategy {@code "earliest"} + * regardless of the specified value in {@link StreamsConfig}. + * + * @param topic the topic name; cannot be {@code null} + * @param materialized the instance of {@link Materialized} used to materialize a state store; cannot be {@code null} + * @return a {@link GlobalKTable} for the specified topic + */ + public synchronized GlobalKTable globalTable(final String topic, + final Materialized> materialized) { + Objects.requireNonNull(topic, "topic can't be null"); + Objects.requireNonNull(materialized, "materialized can't be null"); + final MaterializedInternal> materializedInternal = + new MaterializedInternal<>(materialized, internalStreamsBuilder, topic + "-"); + + return internalStreamsBuilder.globalTable(topic, + new ConsumedInternal<>(Consumed.with(materializedInternal.keySerde(), + materializedInternal.valueSerde())), + materializedInternal); + } + + + /** + * Adds a state store to the underlying {@link Topology}. + *
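+     * <p>For example, registering a key-value store (a sketch; the store name and serdes are illustrative):
+     * <pre>{@code
+     * StoreBuilder<KeyValueStore<String, Long>> storeBuilder = Stores.keyValueStoreBuilder(
+     *     Stores.persistentKeyValueStore("my-state-store"),
+     *     Serdes.String(),
+     *     Serdes.Long());
+     * builder.addStateStore(storeBuilder);
+     * }</pre>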

                + * It is required to connect state stores to {@link org.apache.kafka.streams.processor.api.Processor Processors}, + * {@link Transformer Transformers}, + * or {@link ValueTransformer ValueTransformers} before they can be used. + * + * @param builder the builder used to obtain this state store {@link StateStore} instance + * @return itself + * @throws TopologyException if state store supplier is already added + */ + public synchronized StreamsBuilder addStateStore(final StoreBuilder builder) { + Objects.requireNonNull(builder, "builder can't be null"); + internalStreamsBuilder.addStateStore(builder); + return this; + } + + /** + * Adds a global {@link StateStore} to the topology. + * The {@link StateStore} sources its data from all partitions of the provided input topic. + * There will be exactly one instance of this {@link StateStore} per Kafka Streams instance. + *

                + * A {@link SourceNode} with the provided sourceName will be added to consume the data arriving from the partitions + * of the input topic. + *

                + * The provided {@link org.apache.kafka.streams.processor.ProcessorSupplier} will be used to create an {@link ProcessorNode} that will receive all + * records forwarded from the {@link SourceNode}. NOTE: you should not use the {@code Processor} to insert transformed records into + * the global state store. This store uses the source topic as changelog and during restore will insert records directly + * from the source. + * This {@link ProcessorNode} should be used to keep the {@link StateStore} up-to-date. + * The default {@link TimestampExtractor} as specified in the {@link StreamsConfig config} is used. + *

                + * It is not required to connect a global store to {@link org.apache.kafka.streams.processor.api.Processor Processors}, + * {@link Transformer Transformers}, + * or {@link ValueTransformer ValueTransformer}; those have read-only access to all global stores by default. + *

                + * The supplier should always generate a new instance each time {@link ProcessorSupplier#get()} gets called. Creating + * a single {@link Processor} object and returning the same object reference in {@link ProcessorSupplier#get()} would be + * a violation of the supplier pattern and leads to runtime exceptions. + * + * @param storeBuilder user defined {@link StoreBuilder}; can't be {@code null} + * @param topic the topic to source the data from + * @param consumed the instance of {@link Consumed} used to define optional parameters; can't be {@code null} + * @param stateUpdateSupplier the instance of {@link org.apache.kafka.streams.processor.ProcessorSupplier} + * @return itself + * @throws TopologyException if the processor of state is already registered + * @deprecated Since 2.7.0; use {@link #addGlobalStore(StoreBuilder, String, Consumed, ProcessorSupplier)} instead. + */ + @Deprecated + public synchronized StreamsBuilder addGlobalStore(final StoreBuilder storeBuilder, + final String topic, + final Consumed consumed, + final org.apache.kafka.streams.processor.ProcessorSupplier stateUpdateSupplier) { + Objects.requireNonNull(storeBuilder, "storeBuilder can't be null"); + Objects.requireNonNull(consumed, "consumed can't be null"); + internalStreamsBuilder.addGlobalStore( + storeBuilder, + topic, + new ConsumedInternal<>(consumed), + () -> ProcessorAdapter.adapt(stateUpdateSupplier.get()) + ); + return this; + } + + /** + * Adds a global {@link StateStore} to the topology. + * The {@link StateStore} sources its data from all partitions of the provided input topic. + * There will be exactly one instance of this {@link StateStore} per Kafka Streams instance. + *

                + * A {@link SourceNode} with the provided sourceName will be added to consume the data arriving from the partitions + * of the input topic. + *

                + * The provided {@link ProcessorSupplier}} will be used to create an + * {@link Processor} that will receive all records forwarded from the {@link SourceNode}. + * The supplier should always generate a new instance. Creating a single {@link Processor} object + * and returning the same object reference in {@link ProcessorSupplier#get()} is a + * violation of the supplier pattern and leads to runtime exceptions. + * NOTE: you should not use the {@link Processor} to insert transformed records into + * the global state store. This store uses the source topic as changelog and during restore will insert records directly + * from the source. + * This {@link Processor} should be used to keep the {@link StateStore} up-to-date. + * The default {@link TimestampExtractor} as specified in the {@link StreamsConfig config} is used. + *

                + * It is not required to connect a global store to the {@link Processor Processors}, + * {@link Transformer Transformers}, or {@link ValueTransformer ValueTransformer}; those have read-only access to all global stores by default. + * + * @param storeBuilder user defined {@link StoreBuilder}; can't be {@code null} + * @param topic the topic to source the data from + * @param consumed the instance of {@link Consumed} used to define optional parameters; can't be {@code null} + * @param stateUpdateSupplier the instance of {@link ProcessorSupplier} + * @return itself + * @throws TopologyException if the processor of state is already registered + */ + public synchronized StreamsBuilder addGlobalStore(final StoreBuilder storeBuilder, + final String topic, + final Consumed consumed, + final ProcessorSupplier stateUpdateSupplier) { + Objects.requireNonNull(storeBuilder, "storeBuilder can't be null"); + Objects.requireNonNull(consumed, "consumed can't be null"); + internalStreamsBuilder.addGlobalStore( + storeBuilder, + topic, + new ConsumedInternal<>(consumed), + stateUpdateSupplier + ); + return this; + } + + /** + * Returns the {@link Topology} that represents the specified processing logic. + * Note that using this method means no optimizations are performed. + * + * @return the {@link Topology} that represents the specified processing logic + */ + public synchronized Topology build() { + return build(null); + } + + /** + * Returns the {@link Topology} that represents the specified processing logic and accepts + * a {@link Properties} instance used to indicate whether to optimize topology or not. + * + * @param props the {@link Properties} used for building possibly optimized topology + * @return the {@link Topology} that represents the specified processing logic + */ + public synchronized Topology build(final Properties props) { + internalStreamsBuilder.buildAndOptimizeTopology(props); + return topology; + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/StreamsConfig.java b/streams/src/main/java/org/apache/kafka/streams/StreamsConfig.java new file mode 100644 index 0000000..61fb27d --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/StreamsConfig.java @@ -0,0 +1,1522 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams; + +import org.apache.kafka.clients.CommonClientConfigs; +import org.apache.kafka.clients.admin.Admin; +import org.apache.kafka.clients.admin.AdminClientConfig; +import org.apache.kafka.clients.consumer.ConsumerConfig; +import org.apache.kafka.clients.consumer.KafkaConsumer; +import org.apache.kafka.clients.producer.KafkaProducer; +import org.apache.kafka.clients.producer.ProducerConfig; +import org.apache.kafka.common.config.AbstractConfig; +import org.apache.kafka.common.config.ConfigDef; +import org.apache.kafka.common.config.ConfigDef.Importance; +import org.apache.kafka.common.config.ConfigDef.Type; +import org.apache.kafka.common.config.ConfigException; +import org.apache.kafka.common.config.TopicConfig; +import org.apache.kafka.common.metrics.Sensor; +import org.apache.kafka.common.metrics.Sensor.RecordingLevel; +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.streams.errors.DefaultProductionExceptionHandler; +import org.apache.kafka.streams.errors.DeserializationExceptionHandler; +import org.apache.kafka.streams.errors.LogAndFailExceptionHandler; +import org.apache.kafka.streams.errors.ProductionExceptionHandler; +import org.apache.kafka.streams.errors.StreamsException; +import org.apache.kafka.streams.processor.FailOnInvalidTimestamp; +import org.apache.kafka.streams.processor.TimestampExtractor; +import org.apache.kafka.streams.processor.internals.StreamThread; +import org.apache.kafka.streams.processor.internals.StreamsPartitionAssignor; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.File; +import java.time.Duration; +import java.util.Collections; +import java.util.HashMap; +import java.util.Locale; +import java.util.Map; +import java.util.Properties; +import java.util.Set; + +import static org.apache.kafka.common.IsolationLevel.READ_COMMITTED; +import static org.apache.kafka.common.config.ConfigDef.Range.atLeast; +import static org.apache.kafka.common.config.ConfigDef.Range.between; +import static org.apache.kafka.common.config.ConfigDef.ValidString.in; +import static org.apache.kafka.common.config.ConfigDef.parseType; + +/** + * Configuration for a {@link KafkaStreams} instance. + * Can also be used to configure the Kafka Streams internal {@link KafkaConsumer}, {@link KafkaProducer} and {@link Admin}. + * To avoid consumer/producer/admin property conflicts, you should prefix those properties using + * {@link #consumerPrefix(String)}, {@link #producerPrefix(String)} and {@link #adminClientPrefix(String)}, respectively. + *

                + * Example: + *

                {@code
                + * // potentially wrong: sets "metadata.max.age.ms" to 1 minute for producer AND consumer
                + * Properties streamsProperties = new Properties();
                + * streamsProperties.put(ConsumerConfig.METADATA_MAX_AGE_CONFIG, 60000);
                + * // or
                + * streamsProperties.put(ProducerConfig.METADATA_MAX_AGE_CONFIG, 60000);
                + *
                + * // suggested:
                + * Properties streamsProperties = new Properties();
                + * // sets "metadata.max.age.ms" to 1 minute for consumer only
                + * streamsProperties.put(StreamsConfig.consumerPrefix(ConsumerConfig.METADATA_MAX_AGE_CONFIG), 60000);
                + * // sets "metadata.max.age.ms" to 1 minute for producer only
                + * streamsProperties.put(StreamsConfig.producerPrefix(ProducerConfig.METADATA_MAX_AGE_CONFIG), 60000);
                + *
                + * StreamsConfig streamsConfig = new StreamsConfig(streamsProperties);
                + * }
                + * + * This instance can also be used to pass in custom configurations to different modules (e.g. passing a special config in your customized serde class). + * The consumer/producer/admin prefix can also be used to distinguish these custom config values passed to different clients with the same config name. + * * Example: + *
                {@code
                + * Properties streamsProperties = new Properties();
                + * // sets "my.custom.config" to "foo" for consumer only
                + * streamsProperties.put(StreamsConfig.consumerPrefix("my.custom.config"), "foo");
                + * // sets "my.custom.config" to "bar" for producer only
                + * streamsProperties.put(StreamsConfig.producerPrefix("my.custom.config"), "bar");
                + * // sets "my.custom.config2" to "boom" for all clients universally
                + * streamsProperties.put("my.custom.config2", "boom");
                + *
                + * // as a result, inside producer's serde class configure(..) function,
                + * // users can now read both key-value pairs "my.custom.config" -> "foo"
                + * // and "my.custom.config2" -> "boom" from the config map
                + * StreamsConfig streamsConfig = new StreamsConfig(streamsProperties);
                + * }
                + * + * + * When increasing {@link ProducerConfig#MAX_BLOCK_MS_CONFIG} to be more resilient to non-available brokers you should also + * increase {@link ConsumerConfig#MAX_POLL_INTERVAL_MS_CONFIG} using the following guidance: + *
                + *     max.poll.interval.ms > max.block.ms
                + * 
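+ * For example (a sketch; the values are illustrative and only need to satisfy the inequality above):
+ * <pre>{@code
+ * streamsProperties.put(StreamsConfig.producerPrefix(ProducerConfig.MAX_BLOCK_MS_CONFIG), 180000);
+ * streamsProperties.put(StreamsConfig.consumerPrefix(ConsumerConfig.MAX_POLL_INTERVAL_MS_CONFIG), 300000);
+ * }</pre>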
                + * + * + * Kafka Streams requires at least the following properties to be set: + *
+ * <ul>
+ *   <li>{@link #APPLICATION_ID_CONFIG "application.id"}</li>
+ *   <li>{@link #BOOTSTRAP_SERVERS_CONFIG "bootstrap.servers"}</li>
+ * </ul>
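+ * For example, a minimal configuration (a sketch; the application id and bootstrap servers are illustrative):
+ * <pre>{@code
+ * Properties props = new Properties();
+ * props.put(StreamsConfig.APPLICATION_ID_CONFIG, "my-streams-app");
+ * props.put(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, "broker1:9092");
+ * StreamsConfig streamsConfig = new StreamsConfig(props);
+ * }</pre>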
                + * + * By default, Kafka Streams does not allow users to overwrite the following properties (Streams setting shown in parentheses): + *
+ * <ul>
+ *   <li>{@link ConsumerConfig#GROUP_ID_CONFIG "group.id"} (<application.id>) - Streams client will always use the application ID as consumer group ID</li>
+ *   <li>{@link ConsumerConfig#ENABLE_AUTO_COMMIT_CONFIG "enable.auto.commit"} (false) - Streams client will always disable/turn off auto committing</li>
+ *   <li>{@link ConsumerConfig#PARTITION_ASSIGNMENT_STRATEGY_CONFIG "partition.assignment.strategy"} (StreamsPartitionAssignor) - Streams client will always use its own partition assignor</li>
+ * </ul>
                + * + * If {@link #PROCESSING_GUARANTEE_CONFIG "processing.guarantee"} is set to {@link #EXACTLY_ONCE_V2 "exactly_once_v2"}, + * {@link #EXACTLY_ONCE "exactly_once"} (deprecated), or {@link #EXACTLY_ONCE_BETA "exactly_once_beta"} (deprecated), Kafka Streams does not + * allow users to overwrite the following properties (Streams setting shown in parentheses): + *
+ * <ul>
+ *   <li>{@link ConsumerConfig#ISOLATION_LEVEL_CONFIG "isolation.level"} (read_committed) - Consumers will always read committed data only</li>
+ *   <li>{@link ProducerConfig#ENABLE_IDEMPOTENCE_CONFIG "enable.idempotence"} (true) - Producer will always have idempotency enabled</li>
+ * </ul>
                + * + * @see KafkaStreams#KafkaStreams(org.apache.kafka.streams.Topology, Properties) + * @see ConsumerConfig + * @see ProducerConfig + */ +public class StreamsConfig extends AbstractConfig { + + private static final Logger log = LoggerFactory.getLogger(StreamsConfig.class); + + private static final ConfigDef CONFIG; + + private final boolean eosEnabled; + private static final long DEFAULT_COMMIT_INTERVAL_MS = 30000L; + private static final long EOS_DEFAULT_COMMIT_INTERVAL_MS = 100L; + private static final int DEFAULT_TRANSACTION_TIMEOUT = 10000; + + public static final int DUMMY_THREAD_INDEX = 1; + public static final long MAX_TASK_IDLE_MS_DISABLED = -1; + + /** + * Prefix used to provide default topic configs to be applied when creating internal topics. + * These should be valid properties from {@link org.apache.kafka.common.config.TopicConfig TopicConfig}. + * It is recommended to use {@link #topicPrefix(String)}. + */ + // TODO: currently we cannot get the full topic configurations and hence cannot allow topic configs without the prefix, + // this can be lifted once kafka.log.LogConfig is completely deprecated by org.apache.kafka.common.config.TopicConfig + @SuppressWarnings("WeakerAccess") + public static final String TOPIC_PREFIX = "topic."; + + /** + * Prefix used to isolate {@link KafkaConsumer consumer} configs from other client configs. + * It is recommended to use {@link #consumerPrefix(String)} to add this prefix to {@link ConsumerConfig consumer + * properties}. + */ + @SuppressWarnings("WeakerAccess") + public static final String CONSUMER_PREFIX = "consumer."; + + /** + * Prefix used to override {@link KafkaConsumer consumer} configs for the main consumer client from + * the general consumer client configs. The override precedence is the following (from highest to lowest precedence): + * 1. main.consumer.[config-name] + * 2. consumer.[config-name] + * 3. [config-name] + */ + @SuppressWarnings("WeakerAccess") + public static final String MAIN_CONSUMER_PREFIX = "main.consumer."; + + /** + * Prefix used to override {@link KafkaConsumer consumer} configs for the restore consumer client from + * the general consumer client configs. The override precedence is the following (from highest to lowest precedence): + * 1. restore.consumer.[config-name] + * 2. consumer.[config-name] + * 3. [config-name] + */ + @SuppressWarnings("WeakerAccess") + public static final String RESTORE_CONSUMER_PREFIX = "restore.consumer."; + + /** + * Prefix used to override {@link KafkaConsumer consumer} configs for the global consumer client from + * the general consumer client configs. The override precedence is the following (from highest to lowest precedence): + * 1. global.consumer.[config-name] + * 2. consumer.[config-name] + * 3. [config-name] + */ + @SuppressWarnings("WeakerAccess") + public static final String GLOBAL_CONSUMER_PREFIX = "global.consumer."; + + /** + * Prefix used to isolate {@link KafkaProducer producer} configs from other client configs. + * It is recommended to use {@link #producerPrefix(String)} to add this prefix to {@link ProducerConfig producer + * properties}. + */ + @SuppressWarnings("WeakerAccess") + public static final String PRODUCER_PREFIX = "producer."; + + /** + * Prefix used to isolate {@link Admin admin} configs from other client configs. + * It is recommended to use {@link #adminClientPrefix(String)} to add this prefix to {@link AdminClientConfig admin + * client properties}. 
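+     * <p>For example (a sketch; the chosen admin config is illustrative):
+     * <pre>{@code
+     * streamsProperties.put(StreamsConfig.adminClientPrefix(AdminClientConfig.RETRIES_CONFIG), 5);
+     * }</pre>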
+ */ + @SuppressWarnings("WeakerAccess") + public static final String ADMIN_CLIENT_PREFIX = "admin."; + + /** + * Config value for parameter {@link #TOPOLOGY_OPTIMIZATION_CONFIG "topology.optimization"} for disabling topology optimization + */ + public static final String NO_OPTIMIZATION = "none"; + + /** + * Config value for parameter {@link #TOPOLOGY_OPTIMIZATION_CONFIG "topology.optimization"} for enabling topology optimization + */ + public static final String OPTIMIZE = "all"; + + /** + * Config value for parameter {@link #UPGRADE_FROM_CONFIG "upgrade.from"} for upgrading an application from version {@code 0.10.0.x}. + */ + @SuppressWarnings("WeakerAccess") + public static final String UPGRADE_FROM_0100 = "0.10.0"; + + /** + * Config value for parameter {@link #UPGRADE_FROM_CONFIG "upgrade.from"} for upgrading an application from version {@code 0.10.1.x}. + */ + @SuppressWarnings("WeakerAccess") + public static final String UPGRADE_FROM_0101 = "0.10.1"; + + /** + * Config value for parameter {@link #UPGRADE_FROM_CONFIG "upgrade.from"} for upgrading an application from version {@code 0.10.2.x}. + */ + @SuppressWarnings("WeakerAccess") + public static final String UPGRADE_FROM_0102 = "0.10.2"; + + /** + * Config value for parameter {@link #UPGRADE_FROM_CONFIG "upgrade.from"} for upgrading an application from version {@code 0.11.0.x}. + */ + @SuppressWarnings("WeakerAccess") + public static final String UPGRADE_FROM_0110 = "0.11.0"; + + /** + * Config value for parameter {@link #UPGRADE_FROM_CONFIG "upgrade.from"} for upgrading an application from version {@code 1.0.x}. + */ + @SuppressWarnings("WeakerAccess") + public static final String UPGRADE_FROM_10 = "1.0"; + + /** + * Config value for parameter {@link #UPGRADE_FROM_CONFIG "upgrade.from"} for upgrading an application from version {@code 1.1.x}. + */ + @SuppressWarnings("WeakerAccess") + public static final String UPGRADE_FROM_11 = "1.1"; + + /** + * Config value for parameter {@link #UPGRADE_FROM_CONFIG "upgrade.from"} for upgrading an application from version {@code 2.0.x}. + */ + @SuppressWarnings("WeakerAccess") + public static final String UPGRADE_FROM_20 = "2.0"; + + /** + * Config value for parameter {@link #UPGRADE_FROM_CONFIG "upgrade.from"} for upgrading an application from version {@code 2.1.x}. + */ + @SuppressWarnings("WeakerAccess") + public static final String UPGRADE_FROM_21 = "2.1"; + + /** + * Config value for parameter {@link #UPGRADE_FROM_CONFIG "upgrade.from"} for upgrading an application from version {@code 2.2.x}. + */ + @SuppressWarnings("WeakerAccess") + public static final String UPGRADE_FROM_22 = "2.2"; + + /** + * Config value for parameter {@link #UPGRADE_FROM_CONFIG "upgrade.from"} for upgrading an application from version {@code 2.3.x}. + */ + @SuppressWarnings("WeakerAccess") + public static final String UPGRADE_FROM_23 = "2.3"; + + /** + * Config value for parameter {@link #PROCESSING_GUARANTEE_CONFIG "processing.guarantee"} for at-least-once processing guarantees. + */ + @SuppressWarnings("WeakerAccess") + public static final String AT_LEAST_ONCE = "at_least_once"; + + /** + * Config value for parameter {@link #PROCESSING_GUARANTEE_CONFIG "processing.guarantee"} for exactly-once processing guarantees. + *

                + * Enabling exactly-once processing semantics requires broker version 0.11.0 or higher. + * If you enable this feature Kafka Streams will use more resources (like broker connections) + * compared to {@link #AT_LEAST_ONCE "at_least_once"} and {@link #EXACTLY_ONCE_V2 "exactly_once_v2"}. + * + * @deprecated Since 3.0.0, will be removed in 4.0. Use {@link #EXACTLY_ONCE_V2 "exactly_once_v2"} instead. + */ + @SuppressWarnings("WeakerAccess") + @Deprecated + public static final String EXACTLY_ONCE = "exactly_once"; + + /** + * Config value for parameter {@link #PROCESSING_GUARANTEE_CONFIG "processing.guarantee"} for exactly-once processing guarantees. + *

                + * Enabling exactly-once (beta) requires broker version 2.5 or higher. + * If you enable this feature Kafka Streams will use fewer resources (like broker connections) + * compared to the {@link #EXACTLY_ONCE} (deprecated) case. + * + * @deprecated Since 3.0.0, will be removed in 4.0. Use {@link #EXACTLY_ONCE_V2 "exactly_once_v2"} instead. + */ + @SuppressWarnings("WeakerAccess") + @Deprecated + public static final String EXACTLY_ONCE_BETA = "exactly_once_beta"; + + /** + * Config value for parameter {@link #PROCESSING_GUARANTEE_CONFIG "processing.guarantee"} for exactly-once processing guarantees. + *
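+     * <p>For example (a sketch):
+     * <pre>{@code
+     * streamsProperties.put(StreamsConfig.PROCESSING_GUARANTEE_CONFIG, StreamsConfig.EXACTLY_ONCE_V2);
+     * }</pre>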

                + * Enabling exactly-once-v2 requires broker version 2.5 or higher. + */ + @SuppressWarnings("WeakerAccess") + public static final String EXACTLY_ONCE_V2 = "exactly_once_v2"; + + /** + * Config value for parameter {@link #BUILT_IN_METRICS_VERSION_CONFIG "built.in.metrics.version"} for the latest built-in metrics version. + */ + public static final String METRICS_LATEST = "latest"; + + /** {@code acceptable.recovery.lag} */ + public static final String ACCEPTABLE_RECOVERY_LAG_CONFIG = "acceptable.recovery.lag"; + private static final String ACCEPTABLE_RECOVERY_LAG_DOC = "The maximum acceptable lag (number of offsets to catch up) for a client to be considered caught-up for an active task." + + "Should correspond to a recovery time of well under a minute for a given workload. Must be at least 0."; + + /** {@code application.id} */ + @SuppressWarnings("WeakerAccess") + public static final String APPLICATION_ID_CONFIG = "application.id"; + private static final String APPLICATION_ID_DOC = "An identifier for the stream processing application. Must be unique within the Kafka cluster. It is used as 1) the default client-id prefix, 2) the group-id for membership management, 3) the changelog topic prefix."; + + /**{@code application.server} */ + @SuppressWarnings("WeakerAccess") + public static final String APPLICATION_SERVER_CONFIG = "application.server"; + private static final String APPLICATION_SERVER_DOC = "A host:port pair pointing to a user-defined endpoint that can be used for state store discovery and interactive queries on this KafkaStreams instance."; + + /** {@code bootstrap.servers} */ + @SuppressWarnings("WeakerAccess") + public static final String BOOTSTRAP_SERVERS_CONFIG = CommonClientConfigs.BOOTSTRAP_SERVERS_CONFIG; + + /** {@code buffered.records.per.partition} */ + @SuppressWarnings("WeakerAccess") + public static final String BUFFERED_RECORDS_PER_PARTITION_CONFIG = "buffered.records.per.partition"; + private static final String BUFFERED_RECORDS_PER_PARTITION_DOC = "Maximum number of records to buffer per partition."; + + /** {@code built.in.metrics.version} */ + public static final String BUILT_IN_METRICS_VERSION_CONFIG = "built.in.metrics.version"; + private static final String BUILT_IN_METRICS_VERSION_DOC = "Version of the built-in metrics to use."; + + /** {@code cache.max.bytes.buffering} */ + @SuppressWarnings("WeakerAccess") + public static final String CACHE_MAX_BYTES_BUFFERING_CONFIG = "cache.max.bytes.buffering"; + private static final String CACHE_MAX_BYTES_BUFFERING_DOC = "Maximum number of memory bytes to be used for buffering across all threads"; + + /** {@code client.id} */ + @SuppressWarnings("WeakerAccess") + public static final String CLIENT_ID_CONFIG = CommonClientConfigs.CLIENT_ID_CONFIG; + private static final String CLIENT_ID_DOC = "An ID prefix string used for the client IDs of internal consumer, producer and restore-consumer," + + " with pattern '-StreamThread--'."; + + /** {@code commit.interval.ms} */ + @SuppressWarnings("WeakerAccess") + public static final String COMMIT_INTERVAL_MS_CONFIG = "commit.interval.ms"; + private static final String COMMIT_INTERVAL_MS_DOC = "The frequency in milliseconds with which to save the position of the processor." 
+ + " (Note, if processing.guarantee is set to " + EXACTLY_ONCE_V2 + ", " + EXACTLY_ONCE + ",the default value is " + EOS_DEFAULT_COMMIT_INTERVAL_MS + "," + + " otherwise the default value is " + DEFAULT_COMMIT_INTERVAL_MS + "."; + + /** {@code connections.max.idle.ms} */ + @SuppressWarnings("WeakerAccess") + public static final String CONNECTIONS_MAX_IDLE_MS_CONFIG = CommonClientConfigs.CONNECTIONS_MAX_IDLE_MS_CONFIG; + + /** {@code default.deserialization.exception.handler} */ + @SuppressWarnings("WeakerAccess") + public static final String DEFAULT_DESERIALIZATION_EXCEPTION_HANDLER_CLASS_CONFIG = "default.deserialization.exception.handler"; + private static final String DEFAULT_DESERIALIZATION_EXCEPTION_HANDLER_CLASS_DOC = "Exception handling class that implements the org.apache.kafka.streams.errors.DeserializationExceptionHandler interface."; + + /** {@code default.production.exception.handler} */ + @SuppressWarnings("WeakerAccess") + public static final String DEFAULT_PRODUCTION_EXCEPTION_HANDLER_CLASS_CONFIG = "default.production.exception.handler"; + private static final String DEFAULT_PRODUCTION_EXCEPTION_HANDLER_CLASS_DOC = "Exception handling class that implements the org.apache.kafka.streams.errors.ProductionExceptionHandler interface."; + + /** {@code default.windowed.key.serde.inner} */ + @SuppressWarnings("WeakerAccess") + @Deprecated + public static final String DEFAULT_WINDOWED_KEY_SERDE_INNER_CLASS = "default.windowed.key.serde.inner"; + private static final String DEFAULT_WINDOWED_KEY_SERDE_INNER_CLASS_DOC = "Default serializer / deserializer for the inner class of a windowed key. Must implement the " + + "org.apache.kafka.common.serialization.Serde interface."; + + /** {@code default.windowed.value.serde.inner} */ + @SuppressWarnings("WeakerAccess") + @Deprecated + public static final String DEFAULT_WINDOWED_VALUE_SERDE_INNER_CLASS = "default.windowed.value.serde.inner"; + private static final String DEFAULT_WINDOWED_VALUE_SERDE_INNER_CLASS_DOC = "Default serializer / deserializer for the inner class of a windowed value. Must implement the " + + "org.apache.kafka.common.serialization.Serde interface."; + + public static final String WINDOWED_INNER_CLASS_SERDE = "windowed.inner.class.serde"; + private static final String WINDOWED_INNER_CLASS_SERDE_DOC = " Default serializer / deserializer for the inner class of a windowed record. Must implement the \" +\n" + + " \"org.apache.kafka.common.serialization.Serde interface.. Note that setting this config in KafkaStreams application would result " + + "in an error as it is meant to be used only from Plain consumer client."; + + /** {@code default key.serde} */ + @SuppressWarnings("WeakerAccess") + public static final String DEFAULT_KEY_SERDE_CLASS_CONFIG = "default.key.serde"; + private static final String DEFAULT_KEY_SERDE_CLASS_DOC = " Default serializer / deserializer class for key that implements the org.apache.kafka.common.serialization.Serde interface. 
" + + "Note when windowed serde class is used, one needs to set the inner serde class that implements the org.apache.kafka.common.serialization.Serde interface via '" + + DEFAULT_WINDOWED_KEY_SERDE_INNER_CLASS + "' or '" + DEFAULT_WINDOWED_VALUE_SERDE_INNER_CLASS + "' as well"; + + /** {@code default value.serde} */ + @SuppressWarnings("WeakerAccess") + public static final String DEFAULT_VALUE_SERDE_CLASS_CONFIG = "default.value.serde"; + private static final String DEFAULT_VALUE_SERDE_CLASS_DOC = "Default serializer / deserializer class for value that implements the org.apache.kafka.common.serialization.Serde interface. " + + "Note when windowed serde class is used, one needs to set the inner serde class that implements the org.apache.kafka.common.serialization.Serde interface via '" + + DEFAULT_WINDOWED_KEY_SERDE_INNER_CLASS + "' or '" + DEFAULT_WINDOWED_VALUE_SERDE_INNER_CLASS + "' as well"; + + /** {@code default.timestamp.extractor} */ + @SuppressWarnings("WeakerAccess") + public static final String DEFAULT_TIMESTAMP_EXTRACTOR_CLASS_CONFIG = "default.timestamp.extractor"; + private static final String DEFAULT_TIMESTAMP_EXTRACTOR_CLASS_DOC = "Default timestamp extractor class that implements the org.apache.kafka.streams.processor.TimestampExtractor interface."; + + /** {@code max.task.idle.ms} */ + public static final String MAX_TASK_IDLE_MS_CONFIG = "max.task.idle.ms"; + private static final String MAX_TASK_IDLE_MS_DOC = "This config controls whether joins and merges" + + " may produce out-of-order results." + + " The config value is the maximum amount of time in milliseconds a stream task will stay idle" + + " when it is fully caught up on some (but not all) input partitions" + + " to wait for producers to send additional records and avoid potential" + + " out-of-order record processing across multiple input streams." + + " The default (zero) does not wait for producers to send more records," + + " but it does wait to fetch data that is already present on the brokers." + + " This default means that for records that are already present on the brokers," + + " Streams will process them in timestamp order." + + " Set to -1 to disable idling entirely and process any locally available data," + + " even though doing so may produce out-of-order processing."; + + /** {@code max.warmup.replicas} */ + public static final String MAX_WARMUP_REPLICAS_CONFIG = "max.warmup.replicas"; + private static final String MAX_WARMUP_REPLICAS_DOC = "The maximum number of warmup replicas (extra standbys beyond the configured num.standbys) that can be assigned at once for the purpose of keeping " + + " the task available on one instance while it is warming up on another instance it has been reassigned to. Used to throttle how much extra broker " + + " traffic and cluster state can be used for high availability. 
Must be at least 1."; + + /** {@code metadata.max.age.ms} */ + @SuppressWarnings("WeakerAccess") + public static final String METADATA_MAX_AGE_CONFIG = CommonClientConfigs.METADATA_MAX_AGE_CONFIG; + + /** {@code metrics.num.samples} */ + @SuppressWarnings("WeakerAccess") + public static final String METRICS_NUM_SAMPLES_CONFIG = CommonClientConfigs.METRICS_NUM_SAMPLES_CONFIG; + + /** {@code metrics.record.level} */ + @SuppressWarnings("WeakerAccess") + public static final String METRICS_RECORDING_LEVEL_CONFIG = CommonClientConfigs.METRICS_RECORDING_LEVEL_CONFIG; + + /** {@code metric.reporters} */ + @SuppressWarnings("WeakerAccess") + public static final String METRIC_REPORTER_CLASSES_CONFIG = CommonClientConfigs.METRIC_REPORTER_CLASSES_CONFIG; + + /** {@code metrics.sample.window.ms} */ + @SuppressWarnings("WeakerAccess") + public static final String METRICS_SAMPLE_WINDOW_MS_CONFIG = CommonClientConfigs.METRICS_SAMPLE_WINDOW_MS_CONFIG; + + /** {@code num.standby.replicas} */ + @SuppressWarnings("WeakerAccess") + public static final String NUM_STANDBY_REPLICAS_CONFIG = "num.standby.replicas"; + private static final String NUM_STANDBY_REPLICAS_DOC = "The number of standby replicas for each task."; + + /** {@code num.stream.threads} */ + @SuppressWarnings("WeakerAccess") + public static final String NUM_STREAM_THREADS_CONFIG = "num.stream.threads"; + private static final String NUM_STREAM_THREADS_DOC = "The number of threads to execute stream processing."; + + /** {@code poll.ms} */ + @SuppressWarnings("WeakerAccess") + public static final String POLL_MS_CONFIG = "poll.ms"; + private static final String POLL_MS_DOC = "The amount of time in milliseconds to block waiting for input."; + + /** {@code probing.rebalance.interval.ms} */ + public static final String PROBING_REBALANCE_INTERVAL_MS_CONFIG = "probing.rebalance.interval.ms"; + private static final String PROBING_REBALANCE_INTERVAL_MS_DOC = "The maximum time in milliseconds to wait before triggering a rebalance to probe for warmup replicas that have finished warming up and are ready to become active." + + " Probing rebalances will continue to be triggered until the assignment is balanced. Must be at least 1 minute."; + + /** {@code processing.guarantee} */ + @SuppressWarnings("WeakerAccess") + public static final String PROCESSING_GUARANTEE_CONFIG = "processing.guarantee"; + private static final String PROCESSING_GUARANTEE_DOC = "The processing guarantee that should be used. " + + "Possible values are " + AT_LEAST_ONCE + " (default) " + + "and " + EXACTLY_ONCE_V2 + " (requires brokers version 2.5 or higher). " + + "Deprecated options are " + EXACTLY_ONCE + " (requires brokers version 0.11.0 or higher) " + + "and " + EXACTLY_ONCE_BETA + " (requires brokers version 2.5 or higher). 
" + + "Note that exactly-once processing requires a cluster of at least three brokers by default what is the " + + "recommended setting for production; for development you can change this, by adjusting broker setting " + + "transaction.state.log.replication.factor and transaction.state.log.min.isr."; + + /** {@code receive.buffer.bytes} */ + @SuppressWarnings("WeakerAccess") + public static final String RECEIVE_BUFFER_CONFIG = CommonClientConfigs.RECEIVE_BUFFER_CONFIG; + + /** {@code reconnect.backoff.ms} */ + @SuppressWarnings("WeakerAccess") + public static final String RECONNECT_BACKOFF_MS_CONFIG = CommonClientConfigs.RECONNECT_BACKOFF_MS_CONFIG; + + /** {@code reconnect.backoff.max} */ + @SuppressWarnings("WeakerAccess") + public static final String RECONNECT_BACKOFF_MAX_MS_CONFIG = CommonClientConfigs.RECONNECT_BACKOFF_MAX_MS_CONFIG; + + /** {@code replication.factor} */ + @SuppressWarnings("WeakerAccess") + public static final String REPLICATION_FACTOR_CONFIG = "replication.factor"; + private static final String REPLICATION_FACTOR_DOC = "The replication factor for change log topics and repartition topics created by the stream processing application." + + " The default of -1 (meaning: use broker default replication factor) requires broker version 2.4 or newer"; + + /** {@code request.timeout.ms} */ + @SuppressWarnings("WeakerAccess") + public static final String REQUEST_TIMEOUT_MS_CONFIG = CommonClientConfigs.REQUEST_TIMEOUT_MS_CONFIG; + + /** + * {@code retries} + *

                + * This config is ignored by Kafka Streams. Note, that the internal clients (producer, admin) are still impacted by this config. + * + * @deprecated since 2.7 + */ + @SuppressWarnings("WeakerAccess") + @Deprecated + public static final String RETRIES_CONFIG = CommonClientConfigs.RETRIES_CONFIG; + + /** {@code retry.backoff.ms} */ + @SuppressWarnings("WeakerAccess") + public static final String RETRY_BACKOFF_MS_CONFIG = CommonClientConfigs.RETRY_BACKOFF_MS_CONFIG; + + /** {@code rocksdb.config.setter} */ + @SuppressWarnings("WeakerAccess") + public static final String ROCKSDB_CONFIG_SETTER_CLASS_CONFIG = "rocksdb.config.setter"; + private static final String ROCKSDB_CONFIG_SETTER_CLASS_DOC = "A Rocks DB config setter class or class name that implements the org.apache.kafka.streams.state.RocksDBConfigSetter interface"; + + /** {@code security.protocol} */ + @SuppressWarnings("WeakerAccess") + public static final String SECURITY_PROTOCOL_CONFIG = CommonClientConfigs.SECURITY_PROTOCOL_CONFIG; + + /** {@code send.buffer.bytes} */ + @SuppressWarnings("WeakerAccess") + public static final String SEND_BUFFER_CONFIG = CommonClientConfigs.SEND_BUFFER_CONFIG; + + /** {@code state.cleanup.delay} */ + @SuppressWarnings("WeakerAccess") + public static final String STATE_CLEANUP_DELAY_MS_CONFIG = "state.cleanup.delay.ms"; + private static final String STATE_CLEANUP_DELAY_MS_DOC = "The amount of time in milliseconds to wait before deleting state when a partition has migrated. Only state directories that have not been modified for at least state.cleanup.delay.ms will be removed"; + + /** {@code state.dir} */ + @SuppressWarnings("WeakerAccess") + public static final String STATE_DIR_CONFIG = "state.dir"; + private static final String STATE_DIR_DOC = "Directory location for state store. This path must be unique for each streams instance sharing the same underlying filesystem."; + + /** {@code task.timeout.ms} */ + public static final String TASK_TIMEOUT_MS_CONFIG = "task.timeout.ms"; + public static final String TASK_TIMEOUT_MS_DOC = "The maximum amount of time in milliseconds a task might stall due to internal errors and retries until an error is raised. " + + "For a timeout of 0ms, a task would raise an error for the first internal error. " + + "For any timeout larger than 0ms, a task will retry at least once before an error is raised."; + + /** {@code topology.optimization} */ + public static final String TOPOLOGY_OPTIMIZATION_CONFIG = "topology.optimization"; + private static final String TOPOLOGY_OPTIMIZATION_DOC = "A configuration telling Kafka Streams if it should optimize the topology, disabled by default"; + + /** {@code window.size.ms} */ + public static final String WINDOW_SIZE_MS_CONFIG = "window.size.ms"; + private static final String WINDOW_SIZE_MS_DOC = "Sets window size for the deserializer in order to calculate window end times."; + + /** {@code upgrade.from} */ + @SuppressWarnings("WeakerAccess") + public static final String UPGRADE_FROM_CONFIG = "upgrade.from"; + private static final String UPGRADE_FROM_DOC = "Allows upgrading in a backward compatible way. " + + "This is needed when upgrading from [0.10.0, 1.1] to 2.0+, or when upgrading from [2.0, 2.3] to 2.4+. " + + "When upgrading from 2.4 to a newer version it is not required to specify this config. Default is `null`. 
" + + "Accepted values are \"" + UPGRADE_FROM_0100 + "\", \"" + UPGRADE_FROM_0101 + "\", \"" + + UPGRADE_FROM_0102 + "\", \"" + UPGRADE_FROM_0110 + "\", \"" + UPGRADE_FROM_10 + "\", \"" + + UPGRADE_FROM_11 + "\", \"" + UPGRADE_FROM_20 + "\", \"" + UPGRADE_FROM_21 + "\", \"" + + UPGRADE_FROM_22 + "\", \"" + UPGRADE_FROM_23 + "\" (for upgrading from the corresponding old version)."; + + /** {@code windowstore.changelog.additional.retention.ms} */ + @SuppressWarnings("WeakerAccess") + public static final String WINDOW_STORE_CHANGE_LOG_ADDITIONAL_RETENTION_MS_CONFIG = "windowstore.changelog.additional.retention.ms"; + private static final String WINDOW_STORE_CHANGE_LOG_ADDITIONAL_RETENTION_MS_DOC = "Added to a windows maintainMs to ensure data is not deleted from the log prematurely. Allows for clock drift. Default is 1 day"; + + /** + * {@code topology.optimization} + * @deprecated since 2.7; use {@link #TOPOLOGY_OPTIMIZATION_CONFIG} instead + */ + @Deprecated + public static final String TOPOLOGY_OPTIMIZATION = TOPOLOGY_OPTIMIZATION_CONFIG; + + + private static final String[] NON_CONFIGURABLE_CONSUMER_DEFAULT_CONFIGS = + new String[] {ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG}; + private static final String[] NON_CONFIGURABLE_CONSUMER_EOS_CONFIGS = + new String[] {ConsumerConfig.ISOLATION_LEVEL_CONFIG}; + private static final String[] NON_CONFIGURABLE_PRODUCER_EOS_CONFIGS = + new String[] { + ProducerConfig.ENABLE_IDEMPOTENCE_CONFIG, + ProducerConfig.MAX_IN_FLIGHT_REQUESTS_PER_CONNECTION, + ProducerConfig.TRANSACTIONAL_ID_CONFIG + }; + + static { + CONFIG = new ConfigDef() + + // HIGH + + .define(APPLICATION_ID_CONFIG, // required with no default value + Type.STRING, + Importance.HIGH, + APPLICATION_ID_DOC) + .define(BOOTSTRAP_SERVERS_CONFIG, // required with no default value + Type.LIST, + Importance.HIGH, + CommonClientConfigs.BOOTSTRAP_SERVERS_DOC) + .define(NUM_STANDBY_REPLICAS_CONFIG, + Type.INT, + 0, + Importance.HIGH, + NUM_STANDBY_REPLICAS_DOC) + .define(STATE_DIR_CONFIG, + Type.STRING, + System.getProperty("java.io.tmpdir") + File.separator + "kafka-streams", + Importance.HIGH, + STATE_DIR_DOC) + + // MEDIUM + + .define(ACCEPTABLE_RECOVERY_LAG_CONFIG, + Type.LONG, + 10_000L, + atLeast(0), + Importance.MEDIUM, + ACCEPTABLE_RECOVERY_LAG_DOC) + .define(CACHE_MAX_BYTES_BUFFERING_CONFIG, + Type.LONG, + 10 * 1024 * 1024L, + atLeast(0), + Importance.MEDIUM, + CACHE_MAX_BYTES_BUFFERING_DOC) + .define(CLIENT_ID_CONFIG, + Type.STRING, + "", + Importance.MEDIUM, + CLIENT_ID_DOC) + .define(DEFAULT_DESERIALIZATION_EXCEPTION_HANDLER_CLASS_CONFIG, + Type.CLASS, + LogAndFailExceptionHandler.class.getName(), + Importance.MEDIUM, + DEFAULT_DESERIALIZATION_EXCEPTION_HANDLER_CLASS_DOC) + .define(DEFAULT_KEY_SERDE_CLASS_CONFIG, + Type.CLASS, + null, + Importance.MEDIUM, + DEFAULT_KEY_SERDE_CLASS_DOC) + .define(CommonClientConfigs.DEFAULT_LIST_KEY_SERDE_INNER_CLASS, + Type.CLASS, + null, + Importance.MEDIUM, + CommonClientConfigs.DEFAULT_LIST_KEY_SERDE_INNER_CLASS_DOC) + .define(CommonClientConfigs.DEFAULT_LIST_VALUE_SERDE_INNER_CLASS, + Type.CLASS, + null, + Importance.MEDIUM, + CommonClientConfigs.DEFAULT_LIST_VALUE_SERDE_INNER_CLASS_DOC) + .define(CommonClientConfigs.DEFAULT_LIST_KEY_SERDE_TYPE_CLASS, + Type.CLASS, + null, + Importance.MEDIUM, + CommonClientConfigs.DEFAULT_LIST_KEY_SERDE_TYPE_CLASS_DOC) + .define(CommonClientConfigs.DEFAULT_LIST_VALUE_SERDE_TYPE_CLASS, + Type.CLASS, + null, + Importance.MEDIUM, + CommonClientConfigs.DEFAULT_LIST_VALUE_SERDE_TYPE_CLASS_DOC) + 
.define(DEFAULT_PRODUCTION_EXCEPTION_HANDLER_CLASS_CONFIG, + Type.CLASS, + DefaultProductionExceptionHandler.class.getName(), + Importance.MEDIUM, + DEFAULT_PRODUCTION_EXCEPTION_HANDLER_CLASS_DOC) + .define(DEFAULT_TIMESTAMP_EXTRACTOR_CLASS_CONFIG, + Type.CLASS, + FailOnInvalidTimestamp.class.getName(), + Importance.MEDIUM, + DEFAULT_TIMESTAMP_EXTRACTOR_CLASS_DOC) + .define(DEFAULT_VALUE_SERDE_CLASS_CONFIG, + Type.CLASS, + null, + Importance.MEDIUM, + DEFAULT_VALUE_SERDE_CLASS_DOC) + .define(MAX_TASK_IDLE_MS_CONFIG, + Type.LONG, + 0L, + Importance.MEDIUM, + MAX_TASK_IDLE_MS_DOC) + .define(MAX_WARMUP_REPLICAS_CONFIG, + Type.INT, + 2, + atLeast(1), + Importance.MEDIUM, + MAX_WARMUP_REPLICAS_DOC) + .define(NUM_STREAM_THREADS_CONFIG, + Type.INT, + 1, + Importance.MEDIUM, + NUM_STREAM_THREADS_DOC) + .define(PROCESSING_GUARANTEE_CONFIG, + Type.STRING, + AT_LEAST_ONCE, + in(AT_LEAST_ONCE, EXACTLY_ONCE, EXACTLY_ONCE_BETA, EXACTLY_ONCE_V2), + Importance.MEDIUM, + PROCESSING_GUARANTEE_DOC) + .define(REPLICATION_FACTOR_CONFIG, + Type.INT, + -1, + Importance.MEDIUM, + REPLICATION_FACTOR_DOC) + .define(SECURITY_PROTOCOL_CONFIG, + Type.STRING, + CommonClientConfigs.DEFAULT_SECURITY_PROTOCOL, + Importance.MEDIUM, + CommonClientConfigs.SECURITY_PROTOCOL_DOC) + .define(TASK_TIMEOUT_MS_CONFIG, + Type.LONG, + Duration.ofMinutes(5L).toMillis(), + atLeast(0L), + Importance.MEDIUM, + TASK_TIMEOUT_MS_DOC) + .define(TOPOLOGY_OPTIMIZATION_CONFIG, + Type.STRING, + NO_OPTIMIZATION, + in(NO_OPTIMIZATION, OPTIMIZE), + Importance.MEDIUM, + TOPOLOGY_OPTIMIZATION_DOC) + + // LOW + + .define(APPLICATION_SERVER_CONFIG, + Type.STRING, + "", + Importance.LOW, + APPLICATION_SERVER_DOC) + .define(BUFFERED_RECORDS_PER_PARTITION_CONFIG, + Type.INT, + 1000, + Importance.LOW, + BUFFERED_RECORDS_PER_PARTITION_DOC) + .define(BUILT_IN_METRICS_VERSION_CONFIG, + Type.STRING, + METRICS_LATEST, + in( + METRICS_LATEST + ), + Importance.LOW, + BUILT_IN_METRICS_VERSION_DOC) + .define(COMMIT_INTERVAL_MS_CONFIG, + Type.LONG, + DEFAULT_COMMIT_INTERVAL_MS, + atLeast(0), + Importance.LOW, + COMMIT_INTERVAL_MS_DOC) + .define(CONNECTIONS_MAX_IDLE_MS_CONFIG, + ConfigDef.Type.LONG, + 9 * 60 * 1000L, + ConfigDef.Importance.LOW, + CommonClientConfigs.CONNECTIONS_MAX_IDLE_MS_DOC) + .define(METADATA_MAX_AGE_CONFIG, + ConfigDef.Type.LONG, + 5 * 60 * 1000L, + atLeast(0), + ConfigDef.Importance.LOW, + CommonClientConfigs.METADATA_MAX_AGE_DOC) + .define(METRICS_NUM_SAMPLES_CONFIG, + Type.INT, + 2, + atLeast(1), + Importance.LOW, + CommonClientConfigs.METRICS_NUM_SAMPLES_DOC) + .define(METRIC_REPORTER_CLASSES_CONFIG, + Type.LIST, + "", + Importance.LOW, + CommonClientConfigs.METRIC_REPORTER_CLASSES_DOC) + .define(METRICS_RECORDING_LEVEL_CONFIG, + Type.STRING, + Sensor.RecordingLevel.INFO.toString(), + in(Sensor.RecordingLevel.INFO.toString(), Sensor.RecordingLevel.DEBUG.toString(), RecordingLevel.TRACE.toString()), + Importance.LOW, + CommonClientConfigs.METRICS_RECORDING_LEVEL_DOC) + .define(METRICS_SAMPLE_WINDOW_MS_CONFIG, + Type.LONG, + 30000L, + atLeast(0), + Importance.LOW, + CommonClientConfigs.METRICS_SAMPLE_WINDOW_MS_DOC) + .define(POLL_MS_CONFIG, + Type.LONG, + 100L, + Importance.LOW, + POLL_MS_DOC) + .define(PROBING_REBALANCE_INTERVAL_MS_CONFIG, + Type.LONG, + 10 * 60 * 1000L, + atLeast(60 * 1000L), + Importance.LOW, + PROBING_REBALANCE_INTERVAL_MS_DOC) + .define(RECEIVE_BUFFER_CONFIG, + Type.INT, + 32 * 1024, + atLeast(CommonClientConfigs.RECEIVE_BUFFER_LOWER_BOUND), + Importance.LOW, + CommonClientConfigs.RECEIVE_BUFFER_DOC) + 
.define(RECONNECT_BACKOFF_MS_CONFIG, + Type.LONG, + 50L, + atLeast(0L), + Importance.LOW, + CommonClientConfigs.RECONNECT_BACKOFF_MS_DOC) + .define(RECONNECT_BACKOFF_MAX_MS_CONFIG, + Type.LONG, + 1000L, + atLeast(0L), + ConfigDef.Importance.LOW, + CommonClientConfigs.RECONNECT_BACKOFF_MAX_MS_DOC) + .define(RETRIES_CONFIG, + Type.INT, + 0, + between(0, Integer.MAX_VALUE), + ConfigDef.Importance.LOW, + CommonClientConfigs.RETRIES_DOC) + .define(RETRY_BACKOFF_MS_CONFIG, + Type.LONG, + 100L, + atLeast(0L), + ConfigDef.Importance.LOW, + CommonClientConfigs.RETRY_BACKOFF_MS_DOC) + .define(REQUEST_TIMEOUT_MS_CONFIG, + Type.INT, + 40 * 1000, + atLeast(0), + ConfigDef.Importance.LOW, + CommonClientConfigs.REQUEST_TIMEOUT_MS_DOC) + .define(ROCKSDB_CONFIG_SETTER_CLASS_CONFIG, + Type.CLASS, + null, + Importance.LOW, + ROCKSDB_CONFIG_SETTER_CLASS_DOC) + .define(SEND_BUFFER_CONFIG, + Type.INT, + 128 * 1024, + atLeast(CommonClientConfigs.SEND_BUFFER_LOWER_BOUND), + Importance.LOW, + CommonClientConfigs.SEND_BUFFER_DOC) + .define(STATE_CLEANUP_DELAY_MS_CONFIG, + Type.LONG, + 10 * 60 * 1000L, + Importance.LOW, + STATE_CLEANUP_DELAY_MS_DOC) + .define(UPGRADE_FROM_CONFIG, + ConfigDef.Type.STRING, + null, + in(null, + UPGRADE_FROM_0100, + UPGRADE_FROM_0101, + UPGRADE_FROM_0102, + UPGRADE_FROM_0110, + UPGRADE_FROM_10, + UPGRADE_FROM_11, + UPGRADE_FROM_20, + UPGRADE_FROM_21, + UPGRADE_FROM_22, + UPGRADE_FROM_23), + Importance.LOW, + UPGRADE_FROM_DOC) + .define(WINDOWED_INNER_CLASS_SERDE, + Type.STRING, + null, + Importance.LOW, + WINDOWED_INNER_CLASS_SERDE_DOC) + .define(WINDOW_STORE_CHANGE_LOG_ADDITIONAL_RETENTION_MS_CONFIG, + Type.LONG, + 24 * 60 * 60 * 1000L, + Importance.LOW, + WINDOW_STORE_CHANGE_LOG_ADDITIONAL_RETENTION_MS_DOC) + .define(WINDOW_SIZE_MS_CONFIG, + Type.LONG, + null, + Importance.LOW, + WINDOW_SIZE_MS_DOC); + } + + // this is the list of configs for underlying clients + // that streams prefer different default values + private static final Map PRODUCER_DEFAULT_OVERRIDES; + static { + final Map tempProducerDefaultOverrides = new HashMap<>(); + tempProducerDefaultOverrides.put(ProducerConfig.LINGER_MS_CONFIG, "100"); + PRODUCER_DEFAULT_OVERRIDES = Collections.unmodifiableMap(tempProducerDefaultOverrides); + } + + private static final Map PRODUCER_EOS_OVERRIDES; + static { + final Map tempProducerDefaultOverrides = new HashMap<>(PRODUCER_DEFAULT_OVERRIDES); + tempProducerDefaultOverrides.put(ProducerConfig.DELIVERY_TIMEOUT_MS_CONFIG, Integer.MAX_VALUE); + tempProducerDefaultOverrides.put(ProducerConfig.ENABLE_IDEMPOTENCE_CONFIG, true); + // Reduce the transaction timeout for quicker pending offset expiration on broker side. 
+ tempProducerDefaultOverrides.put(ProducerConfig.TRANSACTION_TIMEOUT_CONFIG, DEFAULT_TRANSACTION_TIMEOUT); + + PRODUCER_EOS_OVERRIDES = Collections.unmodifiableMap(tempProducerDefaultOverrides); + } + + private static final Map CONSUMER_DEFAULT_OVERRIDES; + static { + final Map tempConsumerDefaultOverrides = new HashMap<>(); + tempConsumerDefaultOverrides.put(ConsumerConfig.MAX_POLL_RECORDS_CONFIG, "1000"); + tempConsumerDefaultOverrides.put(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "earliest"); + tempConsumerDefaultOverrides.put(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG, "false"); + tempConsumerDefaultOverrides.put("internal.leave.group.on.close", false); + CONSUMER_DEFAULT_OVERRIDES = Collections.unmodifiableMap(tempConsumerDefaultOverrides); + } + + private static final Map CONSUMER_EOS_OVERRIDES; + static { + final Map tempConsumerDefaultOverrides = new HashMap<>(CONSUMER_DEFAULT_OVERRIDES); + tempConsumerDefaultOverrides.put(ConsumerConfig.ISOLATION_LEVEL_CONFIG, READ_COMMITTED.name().toLowerCase(Locale.ROOT)); + CONSUMER_EOS_OVERRIDES = Collections.unmodifiableMap(tempConsumerDefaultOverrides); + } + + public static class InternalConfig { + // This is settable in the main Streams config, but it's a private API for now + public static final String INTERNAL_TASK_ASSIGNOR_CLASS = "internal.task.assignor.class"; + + // These are not settable in the main Streams config; they are set by the StreamThread to pass internal + // state into the assignor. + public static final String REFERENCE_CONTAINER_PARTITION_ASSIGNOR = "__reference.container.instance__"; + + // This is settable in the main Streams config, but it's a private API for testing + public static final String ASSIGNMENT_LISTENER = "__assignment.listener__"; + + // Private API used to control the emit latency for left/outer join results (https://issues.apache.org/jira/browse/KAFKA-10847) + public static final String EMIT_INTERVAL_MS_KSTREAMS_OUTER_JOIN_SPURIOUS_RESULTS_FIX = "__emit.interval.ms.kstreams.outer.join.spurious.results.fix__"; + + public static boolean getBoolean(final Map configs, final String key, final boolean defaultValue) { + final Object value = configs.getOrDefault(key, defaultValue); + if (value instanceof Boolean) { + return (boolean) value; + } else if (value instanceof String) { + return Boolean.parseBoolean((String) value); + } else { + log.warn("Invalid value (" + value + ") on internal configuration '" + key + "'. Please specify a true/false value."); + return defaultValue; + } + } + + public static long getLong(final Map configs, final String key, final long defaultValue) { + final Object value = configs.getOrDefault(key, defaultValue); + if (value instanceof Number) { + return ((Number) value).longValue(); + } else if (value instanceof String) { + return Long.parseLong((String) value); + } else { + log.warn("Invalid value (" + value + ") on internal configuration '" + key + "'. Please specify a numeric value."); + return defaultValue; + } + } + } + + /** + * Prefix a property with {@link #CONSUMER_PREFIX}. This is used to isolate {@link ConsumerConfig consumer configs} + * from other client configs. + * + * @param consumerProp the consumer property to be masked + * @return {@link #CONSUMER_PREFIX} + {@code consumerProp} + */ + @SuppressWarnings("WeakerAccess") + public static String consumerPrefix(final String consumerProp) { + return CONSUMER_PREFIX + consumerProp; + } + + /** + * Prefix a property with {@link #MAIN_CONSUMER_PREFIX}. 
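To make the prefix mechanism concrete, here is a minimal, hypothetical sketch of how an application would route client-specific overrides through a single Streams property map using consumerPrefix() and producerPrefix(); the application id, bootstrap servers, and override values are placeholders, not taken from this patch:

```java
import java.util.Properties;
import org.apache.kafka.clients.consumer.ConsumerConfig;
import org.apache.kafka.clients.producer.ProducerConfig;
import org.apache.kafka.streams.StreamsConfig;

public class PrefixedOverridesExample {
    public static Properties buildProps() {
        final Properties props = new Properties();
        props.put(StreamsConfig.APPLICATION_ID_CONFIG, "prefix-demo");        // placeholder
        props.put(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9092");  // placeholder
        // Applies only to the consumers created by Streams.
        props.put(StreamsConfig.consumerPrefix(ConsumerConfig.MAX_POLL_RECORDS_CONFIG), 250);
        // Applies only to the producer created by Streams.
        props.put(StreamsConfig.producerPrefix(ProducerConfig.LINGER_MS_CONFIG), 50);
        return props;
    }
}
```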
This is used to isolate {@link ConsumerConfig main consumer configs} + * from other client configs. + * + * @param consumerProp the consumer property to be masked + * @return {@link #MAIN_CONSUMER_PREFIX} + {@code consumerProp} + */ + @SuppressWarnings("WeakerAccess") + public static String mainConsumerPrefix(final String consumerProp) { + return MAIN_CONSUMER_PREFIX + consumerProp; + } + + /** + * Prefix a property with {@link #RESTORE_CONSUMER_PREFIX}. This is used to isolate {@link ConsumerConfig restore consumer configs} + * from other client configs. + * + * @param consumerProp the consumer property to be masked + * @return {@link #RESTORE_CONSUMER_PREFIX} + {@code consumerProp} + */ + @SuppressWarnings("WeakerAccess") + public static String restoreConsumerPrefix(final String consumerProp) { + return RESTORE_CONSUMER_PREFIX + consumerProp; + } + + /** + * Prefix a property with {@link #GLOBAL_CONSUMER_PREFIX}. This is used to isolate {@link ConsumerConfig global consumer configs} + * from other client configs. + * + * @param consumerProp the consumer property to be masked + * @return {@link #GLOBAL_CONSUMER_PREFIX} + {@code consumerProp} + */ + @SuppressWarnings("WeakerAccess") + public static String globalConsumerPrefix(final String consumerProp) { + return GLOBAL_CONSUMER_PREFIX + consumerProp; + } + + /** + * Prefix a property with {@link #PRODUCER_PREFIX}. This is used to isolate {@link ProducerConfig producer configs} + * from other client configs. + * + * @param producerProp the producer property to be masked + * @return PRODUCER_PREFIX + {@code producerProp} + */ + @SuppressWarnings("WeakerAccess") + public static String producerPrefix(final String producerProp) { + return PRODUCER_PREFIX + producerProp; + } + + /** + * Prefix a property with {@link #ADMIN_CLIENT_PREFIX}. This is used to isolate {@link AdminClientConfig admin configs} + * from other client configs. + * + * @param adminClientProp the admin client property to be masked + * @return ADMIN_CLIENT_PREFIX + {@code adminClientProp} + */ + @SuppressWarnings("WeakerAccess") + public static String adminClientPrefix(final String adminClientProp) { + return ADMIN_CLIENT_PREFIX + adminClientProp; + } + + /** + * Prefix a property with {@link #TOPIC_PREFIX} + * used to provide default topic configs to be applied when creating internal topics. + * + * @param topicProp the topic property to be masked + * @return TOPIC_PREFIX + {@code topicProp} + */ + @SuppressWarnings("WeakerAccess") + public static String topicPrefix(final String topicProp) { + return TOPIC_PREFIX + topicProp; + } + + /** + * Return a copy of the config definition. + * + * @return a copy of the config definition + */ + @SuppressWarnings("unused") + public static ConfigDef configDef() { + return new ConfigDef(CONFIG); + } + + /** + * Create a new {@code StreamsConfig} using the given properties. + * + * @param props properties that specify Kafka Streams and internal consumer/producer configuration + */ + public StreamsConfig(final Map props) { + this(props, true); + } + + protected StreamsConfig(final Map props, + final boolean doLog) { + super(CONFIG, props, doLog); + eosEnabled = StreamThread.eosEnabled(this); + + final String processingModeConfig = getString(StreamsConfig.PROCESSING_GUARANTEE_CONFIG); + if (processingModeConfig.equals(EXACTLY_ONCE)) { + log.warn("Configuration parameter `{}` is deprecated and will be removed in the 4.0.0 release. " + + "Please use `{}` instead. 
Note that this requires broker version 2.5+ so you should prepare " + + "to upgrade your brokers if necessary.", EXACTLY_ONCE, EXACTLY_ONCE_V2); + } + if (processingModeConfig.equals(EXACTLY_ONCE_BETA)) { + log.warn("Configuration parameter `{}` is deprecated and will be removed in the 4.0.0 release. " + + "Please use `{}` instead.", EXACTLY_ONCE_BETA, EXACTLY_ONCE_V2); + } + + if (props.containsKey(RETRIES_CONFIG)) { + log.warn("Configuration parameter `{}` is deprecated and will be removed in the 4.0.0 release.", RETRIES_CONFIG); + } + + if (eosEnabled) { + verifyEOSTransactionTimeoutCompatibility(); + } + } + + private void verifyEOSTransactionTimeoutCompatibility() { + final long commitInterval = getLong(COMMIT_INTERVAL_MS_CONFIG); + final String transactionTimeoutConfigKey = producerPrefix(ProducerConfig.TRANSACTION_TIMEOUT_CONFIG); + final int transactionTimeout = originals().containsKey(transactionTimeoutConfigKey) ? (int) parseType( + transactionTimeoutConfigKey, originals().get(transactionTimeoutConfigKey), Type.INT) : DEFAULT_TRANSACTION_TIMEOUT; + + if (transactionTimeout < commitInterval) { + throw new IllegalArgumentException(String.format("Transaction timeout %d was set lower than " + + "streams commit interval %d. This will cause ongoing transaction always timeout due to inactivity " + + "caused by long commit interval. Consider reconfiguring commit interval to match " + + "transaction timeout by tuning 'commit.interval.ms' config, or increase the transaction timeout to match " + + "commit interval by tuning `producer.transaction.timeout.ms` config.", + transactionTimeout, commitInterval)); + } + } + + @Override + protected Map postProcessParsedConfig(final Map parsedValues) { + final Map configUpdates = + CommonClientConfigs.postProcessReconnectBackoffConfigs(this, parsedValues); + + if (StreamThread.eosEnabled(this) && !originals().containsKey(COMMIT_INTERVAL_MS_CONFIG)) { + log.debug("Using {} default value of {} as exactly once is enabled.", + COMMIT_INTERVAL_MS_CONFIG, EOS_DEFAULT_COMMIT_INTERVAL_MS); + configUpdates.put(COMMIT_INTERVAL_MS_CONFIG, EOS_DEFAULT_COMMIT_INTERVAL_MS); + } + + return configUpdates; + } + + private Map getCommonConsumerConfigs() { + final Map clientProvidedProps = getClientPropsWithPrefix(CONSUMER_PREFIX, ConsumerConfig.configNames()); + + checkIfUnexpectedUserSpecifiedConsumerConfig(clientProvidedProps, NON_CONFIGURABLE_CONSUMER_DEFAULT_CONFIGS); + checkIfUnexpectedUserSpecifiedConsumerConfig(clientProvidedProps, NON_CONFIGURABLE_CONSUMER_EOS_CONFIGS); + + final Map consumerProps = new HashMap<>(eosEnabled ? CONSUMER_EOS_OVERRIDES : CONSUMER_DEFAULT_OVERRIDES); + if (StreamThread.processingMode(this) == StreamThread.ProcessingMode.EXACTLY_ONCE_V2) { + consumerProps.put("internal.throw.on.fetch.stable.offset.unsupported", true); + } + consumerProps.putAll(getClientCustomProps()); + consumerProps.putAll(clientProvidedProps); + + // bootstrap.servers should be from StreamsConfig + consumerProps.put(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, originals().get(BOOTSTRAP_SERVERS_CONFIG)); + + return consumerProps; + } + + private void checkIfUnexpectedUserSpecifiedConsumerConfig(final Map clientProvidedProps, + final String[] nonConfigurableConfigs) { + // Streams does not allow users to configure certain consumer/producer configurations, for example, + // enable.auto.commit. In cases where user tries to override such non-configurable + // consumer/producer configurations, log a warning and remove the user defined value from the Map. 
+ // Thus the default values for these consumer/producer configurations that are suitable for + // Streams will be used instead. + + final String nonConfigurableConfigMessage = "Unexpected user-specified %s config: %s found. %sUser setting (%s) will be ignored and the Streams default setting (%s) will be used "; + final String eosMessage = PROCESSING_GUARANTEE_CONFIG + " is set to " + getString(PROCESSING_GUARANTEE_CONFIG) + ". Hence, "; + + for (final String config: nonConfigurableConfigs) { + if (clientProvidedProps.containsKey(config)) { + + if (CONSUMER_DEFAULT_OVERRIDES.containsKey(config)) { + if (!clientProvidedProps.get(config).equals(CONSUMER_DEFAULT_OVERRIDES.get(config))) { + log.warn(String.format(nonConfigurableConfigMessage, "consumer", config, "", clientProvidedProps.get(config), CONSUMER_DEFAULT_OVERRIDES.get(config))); + clientProvidedProps.remove(config); + } + } else if (eosEnabled) { + if (CONSUMER_EOS_OVERRIDES.containsKey(config)) { + if (!clientProvidedProps.get(config).equals(CONSUMER_EOS_OVERRIDES.get(config))) { + log.warn(String.format(nonConfigurableConfigMessage, + "consumer", config, eosMessage, clientProvidedProps.get(config), CONSUMER_EOS_OVERRIDES.get(config))); + clientProvidedProps.remove(config); + } + } else if (PRODUCER_EOS_OVERRIDES.containsKey(config)) { + if (!clientProvidedProps.get(config).equals(PRODUCER_EOS_OVERRIDES.get(config))) { + log.warn(String.format(nonConfigurableConfigMessage, + "producer", config, eosMessage, clientProvidedProps.get(config), PRODUCER_EOS_OVERRIDES.get(config))); + clientProvidedProps.remove(config); + } + } else if (ProducerConfig.TRANSACTIONAL_ID_CONFIG.equals(config)) { + log.warn(String.format(nonConfigurableConfigMessage, + "producer", config, eosMessage, clientProvidedProps.get(config), "-")); + clientProvidedProps.remove(config); + } + } + } + } + + if (eosEnabled) { + verifyMaxInFlightRequestPerConnection(clientProvidedProps.get(ProducerConfig.MAX_IN_FLIGHT_REQUESTS_PER_CONNECTION)); + } + } + + private void verifyMaxInFlightRequestPerConnection(final Object maxInFlightRequests) { + if (maxInFlightRequests != null) { + final int maxInFlightRequestsAsInteger; + if (maxInFlightRequests instanceof Integer) { + maxInFlightRequestsAsInteger = (Integer) maxInFlightRequests; + } else if (maxInFlightRequests instanceof String) { + try { + maxInFlightRequestsAsInteger = Integer.parseInt(((String) maxInFlightRequests).trim()); + } catch (final NumberFormatException e) { + throw new ConfigException(ProducerConfig.MAX_IN_FLIGHT_REQUESTS_PER_CONNECTION, maxInFlightRequests, "String value could not be parsed as 32-bit integer"); + } + } else { + throw new ConfigException(ProducerConfig.MAX_IN_FLIGHT_REQUESTS_PER_CONNECTION, maxInFlightRequests, "Expected value to be a 32-bit integer, but it was a " + maxInFlightRequests.getClass().getName()); + } + + if (maxInFlightRequestsAsInteger > 5) { + throw new ConfigException(ProducerConfig.MAX_IN_FLIGHT_REQUESTS_PER_CONNECTION, maxInFlightRequestsAsInteger, "Can't exceed 5 when exactly-once processing is enabled"); + } + } + } + + /** + * Get the configs to the {@link KafkaConsumer main consumer}. 
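As an illustration of the exactly-once guard rails enforced here (non-configurable consumer settings are dropped, and the producer's max.in.flight is capped at 5), the following hedged sketch shows a configuration that passes these checks; the application id and bootstrap servers are placeholders:

```java
import java.util.Properties;
import org.apache.kafka.clients.producer.ProducerConfig;
import org.apache.kafka.streams.StreamsConfig;

public class EosProducerLimitsExample {
    public static Properties eosProps() {
        final Properties props = new Properties();
        props.put(StreamsConfig.APPLICATION_ID_CONFIG, "eos-demo");           // placeholder
        props.put(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9092");  // placeholder
        props.put(StreamsConfig.PROCESSING_GUARANTEE_CONFIG, StreamsConfig.EXACTLY_ONCE_V2);
        // At most 5 in-flight requests are allowed with exactly-once enabled;
        // a larger value fails verifyMaxInFlightRequestPerConnection with a ConfigException.
        props.put(StreamsConfig.producerPrefix(ProducerConfig.MAX_IN_FLIGHT_REQUESTS_PER_CONNECTION), 5);
        return props;
    }
}
```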
+ * Properties using the prefix {@link #MAIN_CONSUMER_PREFIX} will be used in favor over + * the properties prefixed with {@link #CONSUMER_PREFIX} and the non-prefixed versions + * (read the override precedence ordering in {@link #MAIN_CONSUMER_PREFIX} + * except in the case of {@link ConsumerConfig#BOOTSTRAP_SERVERS_CONFIG} where we always use the non-prefixed + * version as we only support reading/writing from/to the same Kafka Cluster. + * If not specified by {@link #MAIN_CONSUMER_PREFIX}, main consumer will share the general consumer configs + * prefixed by {@link #CONSUMER_PREFIX}. + * + * @param groupId consumer groupId + * @param clientId clientId + * @param threadIdx stream thread index + * @return Map of the consumer configuration. + */ + @SuppressWarnings("WeakerAccess") + public Map getMainConsumerConfigs(final String groupId, final String clientId, final int threadIdx) { + final Map consumerProps = getCommonConsumerConfigs(); + + // Get main consumer override configs + final Map mainConsumerProps = originalsWithPrefix(MAIN_CONSUMER_PREFIX); + for (final Map.Entry entry: mainConsumerProps.entrySet()) { + consumerProps.put(entry.getKey(), entry.getValue()); + } + + // this is a hack to work around StreamsConfig constructor inside StreamsPartitionAssignor to avoid casting + consumerProps.put(APPLICATION_ID_CONFIG, groupId); + + // add group id, client id with stream client id prefix, and group instance id + consumerProps.put(ConsumerConfig.GROUP_ID_CONFIG, groupId); + consumerProps.put(CommonClientConfigs.CLIENT_ID_CONFIG, clientId); + final String groupInstanceId = (String) consumerProps.get(ConsumerConfig.GROUP_INSTANCE_ID_CONFIG); + // Suffix each thread consumer with thread.id to enforce uniqueness of group.instance.id. + if (groupInstanceId != null) { + consumerProps.put(ConsumerConfig.GROUP_INSTANCE_ID_CONFIG, groupInstanceId + "-" + threadIdx); + } + + // add configs required for stream partition assignor + consumerProps.put(UPGRADE_FROM_CONFIG, getString(UPGRADE_FROM_CONFIG)); + consumerProps.put(REPLICATION_FACTOR_CONFIG, getInt(REPLICATION_FACTOR_CONFIG)); + consumerProps.put(APPLICATION_SERVER_CONFIG, getString(APPLICATION_SERVER_CONFIG)); + consumerProps.put(NUM_STANDBY_REPLICAS_CONFIG, getInt(NUM_STANDBY_REPLICAS_CONFIG)); + consumerProps.put(ACCEPTABLE_RECOVERY_LAG_CONFIG, getLong(ACCEPTABLE_RECOVERY_LAG_CONFIG)); + consumerProps.put(MAX_WARMUP_REPLICAS_CONFIG, getInt(MAX_WARMUP_REPLICAS_CONFIG)); + consumerProps.put(PROBING_REBALANCE_INTERVAL_MS_CONFIG, getLong(PROBING_REBALANCE_INTERVAL_MS_CONFIG)); + consumerProps.put(ConsumerConfig.PARTITION_ASSIGNMENT_STRATEGY_CONFIG, StreamsPartitionAssignor.class.getName()); + consumerProps.put(WINDOW_STORE_CHANGE_LOG_ADDITIONAL_RETENTION_MS_CONFIG, getLong(WINDOW_STORE_CHANGE_LOG_ADDITIONAL_RETENTION_MS_CONFIG)); + + // disable auto topic creation + consumerProps.put(ConsumerConfig.ALLOW_AUTO_CREATE_TOPICS_CONFIG, "false"); + + // verify that producer batch config is no larger than segment size, then add topic configs required for creating topics + final Map topicProps = originalsWithPrefix(TOPIC_PREFIX, false); + final Map producerProps = getClientPropsWithPrefix(PRODUCER_PREFIX, ProducerConfig.configNames()); + + if (topicProps.containsKey(topicPrefix(TopicConfig.SEGMENT_BYTES_CONFIG)) && + producerProps.containsKey(ProducerConfig.BATCH_SIZE_CONFIG)) { + final int segmentSize = Integer.parseInt(topicProps.get(topicPrefix(TopicConfig.SEGMENT_BYTES_CONFIG)).toString()); + final int batchSize = 
Integer.parseInt(producerProps.get(ProducerConfig.BATCH_SIZE_CONFIG).toString()); + + if (segmentSize < batchSize) { + throw new IllegalArgumentException(String.format("Specified topic segment size %d is is smaller than the configured producer batch size %d, this will cause produced batch not able to be appended to the topic", + segmentSize, + batchSize)); + } + } + + consumerProps.putAll(topicProps); + + return consumerProps; + } + + /** + * Get the configs for the {@link KafkaConsumer restore-consumer}. + * Properties using the prefix {@link #RESTORE_CONSUMER_PREFIX} will be used in favor over + * the properties prefixed with {@link #CONSUMER_PREFIX} and the non-prefixed versions + * (read the override precedence ordering in {@link #RESTORE_CONSUMER_PREFIX} + * except in the case of {@link ConsumerConfig#BOOTSTRAP_SERVERS_CONFIG} where we always use the non-prefixed + * version as we only support reading/writing from/to the same Kafka Cluster. + * If not specified by {@link #RESTORE_CONSUMER_PREFIX}, restore consumer will share the general consumer configs + * prefixed by {@link #CONSUMER_PREFIX}. + * + * @param clientId clientId + * @return Map of the restore consumer configuration. + */ + @SuppressWarnings("WeakerAccess") + public Map getRestoreConsumerConfigs(final String clientId) { + final Map baseConsumerProps = getCommonConsumerConfigs(); + + // Get restore consumer override configs + final Map restoreConsumerProps = originalsWithPrefix(RESTORE_CONSUMER_PREFIX); + for (final Map.Entry entry: restoreConsumerProps.entrySet()) { + baseConsumerProps.put(entry.getKey(), entry.getValue()); + } + + // no need to set group id for a restore consumer + baseConsumerProps.remove(ConsumerConfig.GROUP_ID_CONFIG); + // no need to set instance id for a restore consumer + baseConsumerProps.remove(ConsumerConfig.GROUP_INSTANCE_ID_CONFIG); + + // add client id with stream client id prefix + baseConsumerProps.put(CommonClientConfigs.CLIENT_ID_CONFIG, clientId); + baseConsumerProps.put(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "none"); + + return baseConsumerProps; + } + + /** + * Get the configs for the {@link KafkaConsumer global consumer}. + * Properties using the prefix {@link #GLOBAL_CONSUMER_PREFIX} will be used in favor over + * the properties prefixed with {@link #CONSUMER_PREFIX} and the non-prefixed versions + * (read the override precedence ordering in {@link #GLOBAL_CONSUMER_PREFIX} + * except in the case of {@link ConsumerConfig#BOOTSTRAP_SERVERS_CONFIG} where we always use the non-prefixed + * version as we only support reading/writing from/to the same Kafka Cluster. + * If not specified by {@link #GLOBAL_CONSUMER_PREFIX}, global consumer will share the general consumer configs + * prefixed by {@link #CONSUMER_PREFIX}. + * + * @param clientId clientId + * @return Map of the global consumer configuration. 
+ */ + @SuppressWarnings("WeakerAccess") + public Map getGlobalConsumerConfigs(final String clientId) { + final Map baseConsumerProps = getCommonConsumerConfigs(); + + // Get global consumer override configs + final Map globalConsumerProps = originalsWithPrefix(GLOBAL_CONSUMER_PREFIX); + for (final Map.Entry entry: globalConsumerProps.entrySet()) { + baseConsumerProps.put(entry.getKey(), entry.getValue()); + } + + // no need to set group id for a global consumer + baseConsumerProps.remove(ConsumerConfig.GROUP_ID_CONFIG); + // no need to set instance id for a restore consumer + baseConsumerProps.remove(ConsumerConfig.GROUP_INSTANCE_ID_CONFIG); + + // add client id with stream client id prefix + baseConsumerProps.put(CommonClientConfigs.CLIENT_ID_CONFIG, clientId + "-global-consumer"); + baseConsumerProps.put(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "none"); + + return baseConsumerProps; + } + + /** + * Get the configs for the {@link KafkaProducer producer}. + * Properties using the prefix {@link #PRODUCER_PREFIX} will be used in favor over their non-prefixed versions + * except in the case of {@link ProducerConfig#BOOTSTRAP_SERVERS_CONFIG} where we always use the non-prefixed + * version as we only support reading/writing from/to the same Kafka Cluster. + * + * @param clientId clientId + * @return Map of the producer configuration. + */ + @SuppressWarnings("WeakerAccess") + public Map getProducerConfigs(final String clientId) { + final Map clientProvidedProps = getClientPropsWithPrefix(PRODUCER_PREFIX, ProducerConfig.configNames()); + + checkIfUnexpectedUserSpecifiedConsumerConfig(clientProvidedProps, NON_CONFIGURABLE_PRODUCER_EOS_CONFIGS); + + // generate producer configs from original properties and overridden maps + final Map props = new HashMap<>(eosEnabled ? PRODUCER_EOS_OVERRIDES : PRODUCER_DEFAULT_OVERRIDES); + props.putAll(getClientCustomProps()); + props.putAll(clientProvidedProps); + + // When using EOS alpha, stream should auto-downgrade the transactional commit protocol to be compatible with older brokers. + if (StreamThread.processingMode(this) == StreamThread.ProcessingMode.EXACTLY_ONCE_ALPHA) { + props.put("internal.auto.downgrade.txn.commit", true); + } + + props.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, originals().get(BOOTSTRAP_SERVERS_CONFIG)); + // add client id with stream client id prefix + props.put(CommonClientConfigs.CLIENT_ID_CONFIG, clientId); + + return props; + } + + /** + * Get the configs for the {@link Admin admin client}. + * @param clientId clientId + * @return Map of the admin client configuration. + */ + @SuppressWarnings("WeakerAccess") + public Map getAdminConfigs(final String clientId) { + final Map clientProvidedProps = getClientPropsWithPrefix(ADMIN_CLIENT_PREFIX, AdminClientConfig.configNames()); + + final Map props = new HashMap<>(); + props.putAll(getClientCustomProps()); + props.putAll(clientProvidedProps); + + // add client id with stream client id prefix + props.put(CommonClientConfigs.CLIENT_ID_CONFIG, clientId); + + return props; + } + + private Map getClientPropsWithPrefix(final String prefix, + final Set configNames) { + final Map props = clientProps(configNames, originals()); + props.putAll(originalsWithPrefix(prefix)); + return props; + } + + /** + * Get a map of custom configs by removing from the originals all the Streams, Consumer, Producer, and AdminClient configs. + * Prefixed properties are also removed because they are already added by {@link #getClientPropsWithPrefix(String, Set)}. 
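The per-client getters defined here are invoked internally when Streams builds its clients, but they can also be called directly to inspect what a property map resolves to after the default overrides (for example linger.ms=100 for the producer) are applied. A small sketch with placeholder ids:

```java
import java.util.Map;
import java.util.Properties;
import org.apache.kafka.streams.StreamsConfig;

public class ResolvedClientConfigsExample {
    public static void main(final String[] args) {
        final Properties props = new Properties();
        props.put(StreamsConfig.APPLICATION_ID_CONFIG, "resolve-demo");       // placeholder
        props.put(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9092");  // placeholder

        final StreamsConfig config = new StreamsConfig(props);
        // Producer configs after Streams applies its own defaults and overrides.
        final Map<String, Object> producerConfigs = config.getProducerConfigs("resolve-demo-producer");
        // Admin client configs; the client id passed in is used as-is.
        final Map<String, Object> adminConfigs = config.getAdminConfigs("resolve-demo-admin");
        System.out.println(producerConfigs);
        System.out.println(adminConfigs);
    }
}
```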
+ * This allows to set a custom property for a specific client alone if specified using a prefix, or for all + * when no prefix is used. + * + * @return a map with the custom properties + */ + private Map getClientCustomProps() { + final Map props = originals(); + props.keySet().removeAll(CONFIG.names()); + props.keySet().removeAll(ConsumerConfig.configNames()); + props.keySet().removeAll(ProducerConfig.configNames()); + props.keySet().removeAll(AdminClientConfig.configNames()); + props.keySet().removeAll(originalsWithPrefix(CONSUMER_PREFIX, false).keySet()); + props.keySet().removeAll(originalsWithPrefix(PRODUCER_PREFIX, false).keySet()); + props.keySet().removeAll(originalsWithPrefix(ADMIN_CLIENT_PREFIX, false).keySet()); + return props; + } + + /** + * Return an {@link Serde#configure(Map, boolean) configured} instance of {@link #DEFAULT_KEY_SERDE_CLASS_CONFIG key Serde + * class}. + * + * @return an configured instance of key Serde class + */ + @SuppressWarnings("WeakerAccess") + public Serde defaultKeySerde() { + final Object keySerdeConfigSetting = get(DEFAULT_KEY_SERDE_CLASS_CONFIG); + if (keySerdeConfigSetting == null) { + throw new ConfigException("Please specify a key serde or set one through StreamsConfig#DEFAULT_KEY_SERDE_CLASS_CONFIG"); + } + try { + final Serde serde = getConfiguredInstance(DEFAULT_KEY_SERDE_CLASS_CONFIG, Serde.class); + serde.configure(originals(), true); + return serde; + } catch (final Exception e) { + throw new StreamsException( + String.format("Failed to configure key serde %s", keySerdeConfigSetting), e); + } + } + + /** + * Return an {@link Serde#configure(Map, boolean) configured} instance of {@link #DEFAULT_VALUE_SERDE_CLASS_CONFIG value + * Serde class}. + * + * @return an configured instance of value Serde class + */ + @SuppressWarnings("WeakerAccess") + public Serde defaultValueSerde() { + final Object valueSerdeConfigSetting = get(DEFAULT_VALUE_SERDE_CLASS_CONFIG); + if (valueSerdeConfigSetting == null) { + throw new ConfigException("Please specify a value serde or set one through StreamsConfig#DEFAULT_VALUE_SERDE_CLASS_CONFIG"); + } + try { + final Serde serde = getConfiguredInstance(DEFAULT_VALUE_SERDE_CLASS_CONFIG, Serde.class); + serde.configure(originals(), false); + return serde; + } catch (final Exception e) { + throw new StreamsException( + String.format("Failed to configure value serde %s", valueSerdeConfigSetting), e); + } + } + + @SuppressWarnings("WeakerAccess") + public TimestampExtractor defaultTimestampExtractor() { + return getConfiguredInstance(DEFAULT_TIMESTAMP_EXTRACTOR_CLASS_CONFIG, TimestampExtractor.class); + } + + @SuppressWarnings("WeakerAccess") + public DeserializationExceptionHandler defaultDeserializationExceptionHandler() { + return getConfiguredInstance(DEFAULT_DESERIALIZATION_EXCEPTION_HANDLER_CLASS_CONFIG, DeserializationExceptionHandler.class); + } + + @SuppressWarnings("WeakerAccess") + public ProductionExceptionHandler defaultProductionExceptionHandler() { + return getConfiguredInstance(DEFAULT_PRODUCTION_EXCEPTION_HANDLER_CLASS_CONFIG, ProductionExceptionHandler.class); + } + + /** + * Override any client properties in the original configs with overrides + * + * @param configNames The given set of configuration names. + * @param originals The original configs to be filtered. 
+ * @return client config with any overrides + */ + private Map clientProps(final Set configNames, + final Map originals) { + // iterate all client config names, filter out non-client configs from the original + // property map and use the overridden values when they are not specified by users + final Map parsed = new HashMap<>(); + for (final String configName: configNames) { + if (originals.containsKey(configName)) { + parsed.put(configName, originals.get(configName)); + } + } + + return parsed; + } + + public static void main(final String[] args) { + System.out.println(CONFIG.toHtml(4, config -> "streamsconfigs_" + config)); + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/StreamsMetadata.java b/streams/src/main/java/org/apache/kafka/streams/StreamsMetadata.java new file mode 100644 index 0000000..11c4941 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/StreamsMetadata.java @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams; + +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.streams.state.HostInfo; + +import java.util.Set; + +/** + * Metadata of a Kafka Streams client. + */ +public interface StreamsMetadata { + + /** + * The value of {@link StreamsConfig#APPLICATION_SERVER_CONFIG} configured for the Streams + * client. + * + * @return {@link HostInfo} corresponding to the Streams client + */ + HostInfo hostInfo(); + + /** + * Names of the state stores assigned to active tasks of the Streams client. + * + * @return names of the state stores assigned to active tasks + */ + Set stateStoreNames(); + + /** + * Source topic partitions of the active tasks of the Streams client. + * + * @return source topic partitions of the active tasks + */ + Set topicPartitions(); + + /** + * Changelog topic partitions for the state stores the standby tasks of the Streams client replicates. + * + * @return set of changelog topic partitions of the standby tasks + */ + Set standbyTopicPartitions(); + + /** + * Names of the state stores assigned to standby tasks of the Streams client. + * + * @return names of the state stores assigned to standby tasks + */ + Set standbyStateStoreNames(); + + /** + * Host where the Streams client runs. + * + * This method is equivalent to {@code StreamsMetadata.hostInfo().host();} + * + * @return the host where the Streams client runs + */ + String host(); + + /** + * Port on which the Streams client listens. + * + * This method is equivalent to {@code StreamsMetadata.hostInfo().port();} + * + * @return the port on which Streams client listens + */ + int port(); + + /** + * Compares the specified object with this StreamsMetadata. 
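The StreamsMetadata interface introduced here is what discovery and interactive-query tooling consumes at runtime. A hedged sketch, assuming the accompanying KafkaStreams#metadataForAllStreamsClients() accessor and an already running instance:

```java
import java.util.Collection;
import org.apache.kafka.streams.KafkaStreams;
import org.apache.kafka.streams.StreamsMetadata;

public class MetadataDiscoveryExample {
    // Logs where active and standby state stores are hosted across the application.
    public static void logStoreLocations(final KafkaStreams streams) {
        final Collection<StreamsMetadata> clients = streams.metadataForAllStreamsClients();
        for (final StreamsMetadata metadata : clients) {
            System.out.printf("host=%s:%d active-stores=%s standby-stores=%s%n",
                metadata.host(),
                metadata.port(),
                metadata.stateStoreNames(),
                metadata.standbyStateStoreNames());
        }
    }
}
```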
Returns {@code true} if and only if the specified object is + * also a StreamsMetadata and for both {@code hostInfo()} are equal, and {@code stateStoreNames()}, {@code topicPartitions()}, + * {@code standbyStateStoreNames()}, and {@code standbyTopicPartitions()} contain the same elements. + * + * @return {@code true} if this object is the same as the obj argument; {@code false} otherwise. + */ + boolean equals(Object o); + + /** + * Returns the hash code value for this StreamsMetadata. The hash code of a list is defined to be the result of the following calculation: + *

+     * <pre>
+     * {@code
+     * Objects.hash(hostInfo(), stateStoreNames(), topicPartitions(), standbyStateStoreNames(), standbyTopicPartitions());
+     * }
+     * </pre>
                + * + * @return a hash code value for this object. + */ + int hashCode(); + +} diff --git a/streams/src/main/java/org/apache/kafka/streams/StreamsMetrics.java b/streams/src/main/java/org/apache/kafka/streams/StreamsMetrics.java new file mode 100644 index 0000000..cbf2169 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/StreamsMetrics.java @@ -0,0 +1,144 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams; + +import org.apache.kafka.common.Metric; +import org.apache.kafka.common.MetricName; +import org.apache.kafka.common.metrics.Sensor; + +import java.util.Map; + +/** + * The Kafka Streams metrics interface for adding metric sensors and collecting metric values. + */ +public interface StreamsMetrics { + + /** + * Get read-only handle on global metrics registry. + * + * @return Map of all metrics. + */ + Map metrics(); + + /** + * Add a latency, rate and total sensor for a specific operation, which will include the following metrics: + *
+     * <ol>
+     *   <li>average latency</li>
+     *   <li>max latency</li>
+     *   <li>invocation rate (num.operations / seconds)</li>
+     *   <li>total invocation count</li>
+     * </ol>
                + * Whenever a user records this sensor via {@link Sensor#record(double)} etc, it will be counted as one invocation + * of the operation, and hence the rate / count metrics will be updated accordingly; and the recorded latency value + * will be used to update the average / max latency as well. + * + * Note that you can add more metrics to this sensor after you created it, which can then be updated upon + * {@link Sensor#record(double)} calls. + * + * The added sensor and its metrics can be removed with {@link #removeSensor(Sensor) removeSensor()}. + * + * @param scopeName name of the scope, which will be used as part of the metric type, e.g.: "stream-[scope]-metrics". + * @param entityName name of the entity, which will be used as part of the metric tags, e.g.: "[scope]-id" = "[entity]". + * @param operationName name of the operation, which will be used as the name of the metric, e.g.: "[operation]-latency-avg". + * @param recordingLevel the recording level (e.g., INFO or DEBUG) for this sensor. + * @param tags additional tags of the sensor + * @return The added sensor. + * @see #addRateTotalSensor(String, String, String, Sensor.RecordingLevel, String...) + * @see #removeSensor(Sensor) + * @see #addSensor(String, Sensor.RecordingLevel, Sensor...) + */ + Sensor addLatencyRateTotalSensor(final String scopeName, + final String entityName, + final String operationName, + final Sensor.RecordingLevel recordingLevel, + final String... tags); + + /** + * Add a rate and a total sensor for a specific operation, which will include the following metrics: + *
+     * <ol>
+     *   <li>invocation rate (num.operations / time unit)</li>
+     *   <li>total invocation count</li>
+     * </ol>
                + * Whenever a user records this sensor via {@link Sensor#record(double)} etc, + * it will be counted as one invocation of the operation, and hence the rate / count metrics will be updated accordingly. + * + * Note that you can add more metrics to this sensor after you created it, which can then be updated upon + * {@link Sensor#record(double)} calls. + * + * The added sensor and its metrics can be removed with {@link #removeSensor(Sensor) removeSensor()}. + * + * @param scopeName name of the scope, which will be used as part of the metrics type, e.g.: "stream-[scope]-metrics". + * @param entityName name of the entity, which will be used as part of the metric tags, e.g.: "[scope]-id" = "[entity]". + * @param operationName name of the operation, which will be used as the name of the metric, e.g.: "[operation]-total". + * @param recordingLevel the recording level (e.g., INFO or DEBUG) for this sensor. + * @param tags additional tags of the sensor + * @return The added sensor. + * @see #addLatencyRateTotalSensor(String, String, String, Sensor.RecordingLevel, String...) + * @see #removeSensor(Sensor) + * @see #addSensor(String, Sensor.RecordingLevel, Sensor...) + */ + Sensor addRateTotalSensor(final String scopeName, + final String entityName, + final String operationName, + final Sensor.RecordingLevel recordingLevel, + final String... tags); + + /** + * Generic method to create a sensor. + * Note that for most cases it is advisable to use + * {@link #addRateTotalSensor(String, String, String, Sensor.RecordingLevel, String...) addRateTotalSensor()} + * or {@link #addLatencyRateTotalSensor(String, String, String, Sensor.RecordingLevel, String...) addLatencyRateTotalSensor()} + * to ensure metric name well-formedness and conformity with the rest of the Kafka Streams code base. + * However, if the above two methods are not sufficient, this method can also be used. + * + * @param name name of the sensor. + * @param recordingLevel the recording level (e.g., INFO or DEBUG) for this sensor + * @return The added sensor. + * @see #addRateTotalSensor(String, String, String, Sensor.RecordingLevel, String...) + * @see #addLatencyRateTotalSensor(String, String, String, Sensor.RecordingLevel, String...) + * @see #removeSensor(Sensor) + */ + Sensor addSensor(final String name, + final Sensor.RecordingLevel recordingLevel); + + /** + * Generic method to create a sensor with parent sensors. + * Note that for most cases it is advisable to use + * {@link #addRateTotalSensor(String, String, String, Sensor.RecordingLevel, String...) addRateTotalSensor()} + * or {@link #addLatencyRateTotalSensor(String, String, String, Sensor.RecordingLevel, String...) addLatencyRateTotalSensor()} + * to ensure metric name well-formedness and conformity with the rest of the Kafka Streams code base. + * However, if the above two methods are not sufficient, this method can also be used. + * + * @param name name of the sensor + * @param recordingLevel the recording level (e.g., INFO or DEBUG) for this sensor + * @return The added sensor. + * @see #addRateTotalSensor(String, String, String, Sensor.RecordingLevel, String...) + * @see #addLatencyRateTotalSensor(String, String, String, Sensor.RecordingLevel, String...) + * @see #removeSensor(Sensor) + */ + Sensor addSensor(final String name, + final Sensor.RecordingLevel recordingLevel, + final Sensor... parents); + + /** + * Remove a sensor. 
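A common use of this interface is registering a custom sensor from a processor: create it in init(), record to it in process(), and remove it in close(). The sketch below assumes the org.apache.kafka.streams.processor.api.Processor API; the scope, entity, and operation names are placeholders:

```java
import org.apache.kafka.common.metrics.Sensor;
import org.apache.kafka.streams.StreamsMetrics;
import org.apache.kafka.streams.processor.api.Processor;
import org.apache.kafka.streams.processor.api.ProcessorContext;
import org.apache.kafka.streams.processor.api.Record;

public class CountingProcessor implements Processor<String, String, String, String> {
    private ProcessorContext<String, String> context;
    private StreamsMetrics metrics;
    private Sensor processedSensor;

    @Override
    public void init(final ProcessorContext<String, String> context) {
        this.context = context;
        this.metrics = context.metrics();
        // Scope, entity, and operation names here are arbitrary placeholders.
        this.processedSensor = metrics.addRateTotalSensor(
            "custom", "counting-processor", "records-processed", Sensor.RecordingLevel.INFO);
    }

    @Override
    public void process(final Record<String, String> record) {
        processedSensor.record(1.0d);          // one invocation per record
        context.forward(record);               // pass the record through unchanged
    }

    @Override
    public void close() {
        metrics.removeSensor(processedSensor); // clean up when the processor is closed
    }
}
```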
+ * @param sensor sensor to be removed + */ + void removeSensor(final Sensor sensor); +} + + diff --git a/streams/src/main/java/org/apache/kafka/streams/TaskMetadata.java b/streams/src/main/java/org/apache/kafka/streams/TaskMetadata.java new file mode 100644 index 0000000..0ef7429 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/TaskMetadata.java @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams; + +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.streams.processor.TaskId; + +import java.util.Map; +import java.util.Optional; +import java.util.Set; + + +/** + * Metadata of a task. + */ +public interface TaskMetadata { + + /** + * Task ID of the task. + * + * @return task ID consisting of subtopology and partition ID + */ + TaskId taskId(); + + /** + * Source topic partitions of the task. + * + * @return source topic partitions + */ + Set topicPartitions(); + + /** + * Offsets of the source topic partitions committed so far by the task. + * + * @return map from source topic partitions to committed offsets + */ + Map committedOffsets(); + + /** + * End offsets of the source topic partitions of the task. + * + * @return map source topic partition to end offsets + */ + Map endOffsets(); + + /** + * Time task idling started. If the task is not currently idling it will return empty. + * + * @return time when task idling started, empty {@code Optional} if the task is currently not idling + */ + Optional timeCurrentIdlingStarted(); + + /** + * Compares the specified object with this TaskMetadata. Returns {@code true} if and only if the specified object is + * also a TaskMetadata and both {@code taskId()} and {@code topicPartitions()} are equal. + * + * @return {@code true} if this object is the same as the obj argument; {@code false} otherwise. + */ + boolean equals(final Object o); + + /** + * Returns the hash code value for this TaskMetadata. The hash code of a list is defined to be the result of the following calculation: + *
                +     * {@code
                +     * Objects.hash(taskId(), topicPartitions());
                +     * }
                +     * 
                + * + * @return a hash code value for this object. + */ + int hashCode(); + +} diff --git a/streams/src/main/java/org/apache/kafka/streams/ThreadMetadata.java b/streams/src/main/java/org/apache/kafka/streams/ThreadMetadata.java new file mode 100644 index 0000000..f611fe7 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/ThreadMetadata.java @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams; + +import java.util.Set; + +/** + * Metadata of a stream thread. + */ +public interface ThreadMetadata { + + + /** + * State of the stream thread + * + * @return the state + */ + String threadState(); + + /** + * Name of the stream thread + * + * @return the name + */ + String threadName(); + + /** + * Metadata of the active tasks assigned to the stream thread. + * + * @return metadata of the active tasks + */ + Set activeTasks(); + + /** + * Metadata of the standby tasks assigned to the stream thread. + * + * @return metadata of the standby tasks + + */ + Set standbyTasks(); + + /** + * Client ID of the Kafka consumer used by the stream thread. + * + * @return client ID of the Kafka consumer + */ + String consumerClientId(); + + /** + * Client ID of the restore Kafka consumer used by the stream thread + * + * @return client ID of the restore Kafka consumer + */ + String restoreConsumerClientId(); + + /** + * Client IDs of the Kafka producers used by the stream thread. + * + * @return client IDs of the Kafka producers + */ + Set producerClientIds(); + + /** + * Client ID of the admin client used by the stream thread. + * + * @return client ID of the admin client + */ + String adminClientId(); + + /** + * Compares the specified object with this ThreadMetadata. Returns {@code true} if and only if the specified object is + * also a ThreadMetadata and both {@code threadName()} are equal, {@code threadState()} are equal, {@code activeTasks()} contain the same + * elements, {@code standbyTasks()} contain the same elements, {@code mainConsumerClientId()} are equal, {@code restoreConsumerClientId()} + * are equal, {@code producerClientIds()} are equal, {@code producerClientIds} contain the same elements, and {@code adminClientId()} are equal. + * + * @return {@code true} if this object is the same as the obj argument; {@code false} otherwise. + */ + boolean equals(Object o); + + /** + * Returns the hash code value for this ThreadMetadata. The hash code of a list is defined to be the result of the following calculation: + *
                +     * {@code
                +     * Objects.hash(
                +     *             threadName,
                +     *             threadState,
                +     *             activeTasks,
                +     *             standbyTasks,
                +     *             mainConsumerClientId,
                +     *             restoreConsumerClientId,
                +     *             producerClientIds,
                +     *             adminClientId
                +     *             );
                +     * }
                +     * 
                + * + * @return a hash code value for this object. + */ + int hashCode(); +} diff --git a/streams/src/main/java/org/apache/kafka/streams/Topology.java b/streams/src/main/java/org/apache/kafka/streams/Topology.java new file mode 100644 index 0000000..7c45de7 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/Topology.java @@ -0,0 +1,941 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams; + +import org.apache.kafka.common.serialization.Deserializer; +import org.apache.kafka.common.serialization.Serializer; +import org.apache.kafka.streams.errors.TopologyException; +import org.apache.kafka.streams.kstream.KStream; +import org.apache.kafka.streams.kstream.KTable; +import org.apache.kafka.streams.processor.ConnectedStoreProvider; +import org.apache.kafka.streams.processor.StateStore; +import org.apache.kafka.streams.processor.StreamPartitioner; +import org.apache.kafka.streams.processor.TimestampExtractor; +import org.apache.kafka.streams.processor.TopicNameExtractor; +import org.apache.kafka.streams.processor.api.Processor; +import org.apache.kafka.streams.processor.api.ProcessorSupplier; +import org.apache.kafka.streams.processor.internals.InternalTopologyBuilder; +import org.apache.kafka.streams.processor.internals.ProcessorAdapter; +import org.apache.kafka.streams.processor.internals.ProcessorNode; +import org.apache.kafka.streams.processor.internals.ProcessorTopology; +import org.apache.kafka.streams.processor.internals.SinkNode; +import org.apache.kafka.streams.processor.internals.SourceNode; +import org.apache.kafka.streams.state.StoreBuilder; + +import java.util.Set; +import java.util.regex.Pattern; + +/** + * A logical representation of a {@link ProcessorTopology}. + * A topology is an acyclic graph of sources, processors, and sinks. + * A {@link SourceNode source} is a node in the graph that consumes one or more Kafka topics and forwards them to its + * successor nodes. + * A {@link Processor processor} is a node in the graph that receives input records from upstream nodes, processes the + * records, and optionally forwarding new records to one or all of its downstream nodes. + * Finally, a {@link SinkNode sink} is a node in the graph that receives records from upstream nodes and writes them to + * a Kafka topic. + * A {@code Topology} allows you to construct an acyclic graph of these nodes, and then passed into a new + * {@link KafkaStreams} instance that will then {@link KafkaStreams#start() begin consuming, processing, and producing + * records}. 
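+ * <p>
+ * For illustration only, a minimal sketch (imports omitted) of how such a topology might be assembled and
+ * started; the topic names, the {@code MyProcessor} class, and the {@code props} settings are hypothetical
+ * placeholders:
+ * <pre>{@code
+ * final Topology topology = new Topology();
+ * topology.addSource("Source", "input-topic");
+ *
+ * final ProcessorSupplier<String, String, String, String> processorSupplier = MyProcessor::new;
+ * topology.addProcessor("Process", processorSupplier, "Source");
+ *
+ * topology.addSink("Sink", "output-topic", "Process");
+ *
+ * final KafkaStreams streams = new KafkaStreams(topology, props);
+ * streams.start();
+ * }</pre>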
+ */ +public class Topology { + + protected final InternalTopologyBuilder internalTopologyBuilder = new InternalTopologyBuilder(); + + /** + * Sets the {@code auto.offset.reset} configuration when + * {@link #addSource(AutoOffsetReset, String, String...) adding a source processor} or when creating {@link KStream} + * or {@link KTable} via {@link StreamsBuilder}. + */ + public enum AutoOffsetReset { + EARLIEST, LATEST + } + + /** + * Add a new source that consumes the named topics and forward the records to child processor and/or sink nodes. + * The source will use the {@link StreamsConfig#DEFAULT_KEY_SERDE_CLASS_CONFIG default key deserializer} and + * {@link StreamsConfig#DEFAULT_VALUE_SERDE_CLASS_CONFIG default value deserializer} specified in the + * {@link StreamsConfig stream configuration}. + * The default {@link TimestampExtractor} as specified in the {@link StreamsConfig config} is used. + * + * @param name the unique name of the source used to reference this node when + * {@link #addProcessor(String, ProcessorSupplier, String...) adding processor children}. + * @param topics the name of one or more Kafka topics that this source is to consume + * @return itself + * @throws TopologyException if processor is already added or if topics have already been registered by another source + */ + public synchronized Topology addSource(final String name, + final String... topics) { + internalTopologyBuilder.addSource(null, name, null, null, null, topics); + return this; + } + + /** + * Add a new source that consumes from topics matching the given pattern + * and forward the records to child processor and/or sink nodes. + * The source will use the {@link StreamsConfig#DEFAULT_KEY_SERDE_CLASS_CONFIG default key deserializer} and + * {@link StreamsConfig#DEFAULT_VALUE_SERDE_CLASS_CONFIG default value deserializer} specified in the + * {@link StreamsConfig stream configuration}. + * The default {@link TimestampExtractor} as specified in the {@link StreamsConfig config} is used. + * + * @param name the unique name of the source used to reference this node when + * {@link #addProcessor(String, ProcessorSupplier, String...) adding processor children}. + * @param topicPattern regular expression pattern to match Kafka topics that this source is to consume + * @return itself + * @throws TopologyException if processor is already added or if topics have already been registered by another source + */ + public synchronized Topology addSource(final String name, + final Pattern topicPattern) { + internalTopologyBuilder.addSource(null, name, null, null, null, topicPattern); + return this; + } + + /** + * Add a new source that consumes the named topics and forward the records to child processor and/or sink nodes. + * The source will use the {@link StreamsConfig#DEFAULT_KEY_SERDE_CLASS_CONFIG default key deserializer} and + * {@link StreamsConfig#DEFAULT_VALUE_SERDE_CLASS_CONFIG default value deserializer} specified in the + * {@link StreamsConfig stream configuration}. + * The default {@link TimestampExtractor} as specified in the {@link StreamsConfig config} is used. + * + * @param offsetReset the auto offset reset policy to use for this source if no committed offsets found; acceptable values earliest or latest + * @param name the unique name of the source used to reference this node when + * {@link #addProcessor(String, ProcessorSupplier, String...) adding processor children}. 
+ * @param topics the name of one or more Kafka topics that this source is to consume + * @return itself + * @throws TopologyException if processor is already added or if topics have already been registered by another source + */ + public synchronized Topology addSource(final AutoOffsetReset offsetReset, + final String name, + final String... topics) { + internalTopologyBuilder.addSource(offsetReset, name, null, null, null, topics); + return this; + } + + /** + * Add a new source that consumes from topics matching the given pattern + * and forward the records to child processor and/or sink nodes. + * The source will use the {@link StreamsConfig#DEFAULT_KEY_SERDE_CLASS_CONFIG default key deserializer} and + * {@link StreamsConfig#DEFAULT_VALUE_SERDE_CLASS_CONFIG default value deserializer} specified in the + * {@link StreamsConfig stream configuration}. + * The default {@link TimestampExtractor} as specified in the {@link StreamsConfig config} is used. + * + * @param offsetReset the auto offset reset policy value for this source if no committed offsets found; acceptable values earliest or latest. + * @param name the unique name of the source used to reference this node when + * {@link #addProcessor(String, ProcessorSupplier, String...) adding processor children}. + * @param topicPattern regular expression pattern to match Kafka topics that this source is to consume + * @return itself + * @throws TopologyException if processor is already added or if topics have already been registered by another source + */ + public synchronized Topology addSource(final AutoOffsetReset offsetReset, + final String name, + final Pattern topicPattern) { + internalTopologyBuilder.addSource(offsetReset, name, null, null, null, topicPattern); + return this; + } + + /** + * Add a new source that consumes the named topics and forward the records to child processor and/or sink nodes. + * The source will use the {@link StreamsConfig#DEFAULT_KEY_SERDE_CLASS_CONFIG default key deserializer} and + * {@link StreamsConfig#DEFAULT_VALUE_SERDE_CLASS_CONFIG default value deserializer} specified in the + * {@link StreamsConfig stream configuration}. + * + * @param timestampExtractor the stateless timestamp extractor used for this source, + * if not specified the default extractor defined in the configs will be used + * @param name the unique name of the source used to reference this node when + * {@link #addProcessor(String, ProcessorSupplier, String...) adding processor children}. + * @param topics the name of one or more Kafka topics that this source is to consume + * @return itself + * @throws TopologyException if processor is already added or if topics have already been registered by another source + */ + public synchronized Topology addSource(final TimestampExtractor timestampExtractor, + final String name, + final String... topics) { + internalTopologyBuilder.addSource(null, name, timestampExtractor, null, null, topics); + return this; + } + + /** + * Add a new source that consumes from topics matching the given pattern + * and forward the records to child processor and/or sink nodes. + * The source will use the {@link StreamsConfig#DEFAULT_KEY_SERDE_CLASS_CONFIG default key deserializer} and + * {@link StreamsConfig#DEFAULT_VALUE_SERDE_CLASS_CONFIG default value deserializer} specified in the + * {@link StreamsConfig stream configuration}. 
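+     * For illustration, one hypothetical call using the built-in {@code WallclockTimestampExtractor}; the node
+     * and topic names are placeholders:
+     * <pre>{@code
+     * topology.addSource(new WallclockTimestampExtractor(), "Source", "input-topic");
+     * }</pre>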
+ * + * @param timestampExtractor the stateless timestamp extractor used for this source, + * if not specified the default extractor defined in the configs will be used + * @param name the unique name of the source used to reference this node when + * {@link #addProcessor(String, ProcessorSupplier, String...) adding processor children}. + * @param topicPattern regular expression pattern to match Kafka topics that this source is to consume + * @return itself + * @throws TopologyException if processor is already added or if topics have already been registered by another source + */ + public synchronized Topology addSource(final TimestampExtractor timestampExtractor, + final String name, + final Pattern topicPattern) { + internalTopologyBuilder.addSource(null, name, timestampExtractor, null, null, topicPattern); + return this; + } + + /** + * Add a new source that consumes the named topics and forward the records to child processor and/or sink nodes. + * The source will use the {@link StreamsConfig#DEFAULT_KEY_SERDE_CLASS_CONFIG default key deserializer} and + * {@link StreamsConfig#DEFAULT_VALUE_SERDE_CLASS_CONFIG default value deserializer} specified in the + * {@link StreamsConfig stream configuration}. + * + * @param offsetReset the auto offset reset policy to use for this source if no committed offsets found; + * acceptable values earliest or latest + * @param timestampExtractor the stateless timestamp extractor used for this source, + * if not specified the default extractor defined in the configs will be used + * @param name the unique name of the source used to reference this node when + * {@link #addProcessor(String, ProcessorSupplier, String...) adding processor children}. + * @param topics the name of one or more Kafka topics that this source is to consume + * @return itself + * @throws TopologyException if processor is already added or if topics have already been registered by another source + */ + public synchronized Topology addSource(final AutoOffsetReset offsetReset, + final TimestampExtractor timestampExtractor, + final String name, + final String... topics) { + internalTopologyBuilder.addSource(offsetReset, name, timestampExtractor, null, null, topics); + return this; + } + + /** + * Add a new source that consumes from topics matching the given pattern and forward the records to child processor + * and/or sink nodes. + * The source will use the {@link StreamsConfig#DEFAULT_KEY_SERDE_CLASS_CONFIG default key deserializer} and + * {@link StreamsConfig#DEFAULT_VALUE_SERDE_CLASS_CONFIG default value deserializer} specified in the + * {@link StreamsConfig stream configuration}. + * + * @param offsetReset the auto offset reset policy value for this source if no committed offsets found; + * acceptable values earliest or latest. + * @param timestampExtractor the stateless timestamp extractor used for this source, + * if not specified the default extractor defined in the configs will be used + * @param name the unique name of the source used to reference this node when + * {@link #addProcessor(String, ProcessorSupplier, String...) adding processor children}. 
+ * @param topicPattern regular expression pattern to match Kafka topics that this source is to consume + * @return itself + * @throws TopologyException if processor is already added or if topics have already been registered by another source + */ + public synchronized Topology addSource(final AutoOffsetReset offsetReset, + final TimestampExtractor timestampExtractor, + final String name, + final Pattern topicPattern) { + internalTopologyBuilder.addSource(offsetReset, name, timestampExtractor, null, null, topicPattern); + return this; + } + + /** + * Add a new source that consumes the named topics and forwards the records to child processor and/or sink nodes. + * The source will use the specified key and value deserializers. + * The default {@link TimestampExtractor} as specified in the {@link StreamsConfig config} is used. + * + * @param name the unique name of the source used to reference this node when + * {@link #addProcessor(String, ProcessorSupplier, String...) adding processor children} + * @param keyDeserializer key deserializer used to read this source, if not specified the default + * key deserializer defined in the configs will be used + * @param valueDeserializer value deserializer used to read this source, + * if not specified the default value deserializer defined in the configs will be used + * @param topics the name of one or more Kafka topics that this source is to consume + * @return itself + * @throws TopologyException if processor is already added or if topics have already been registered by another source + */ + public synchronized Topology addSource(final String name, + final Deserializer keyDeserializer, + final Deserializer valueDeserializer, + final String... topics) { + internalTopologyBuilder.addSource(null, name, null, keyDeserializer, valueDeserializer, topics); + return this; + } + + /** + * Add a new source that consumes from topics matching the given pattern and forwards the records to child processor + * and/or sink nodes. + * The source will use the specified key and value deserializers. + * The provided de-/serializers will be used for all matched topics, so care should be taken to specify patterns for + * topics that share the same key-value data format. + * The default {@link TimestampExtractor} as specified in the {@link StreamsConfig config} is used. + * + * @param name the unique name of the source used to reference this node when + * {@link #addProcessor(String, ProcessorSupplier, String...) adding processor children} + * @param keyDeserializer key deserializer used to read this source, if not specified the default + * key deserializer defined in the configs will be used + * @param valueDeserializer value deserializer used to read this source, + * if not specified the default value deserializer defined in the configs will be used + * @param topicPattern regular expression pattern to match Kafka topics that this source is to consume + * @return itself + * @throws TopologyException if processor is already added or if topics have already been registered by name + */ + public synchronized Topology addSource(final String name, + final Deserializer keyDeserializer, + final Deserializer valueDeserializer, + final Pattern topicPattern) { + internalTopologyBuilder.addSource(null, name, null, keyDeserializer, valueDeserializer, topicPattern); + return this; + } + + /** + * Add a new source that consumes from topics matching the given pattern and forwards the records to child processor + * and/or sink nodes. 
+ * The source will use the specified key and value deserializers. + * The provided de-/serializers will be used for all the specified topics, so care should be taken when specifying + * topics that share the same key-value data format. + * + * @param offsetReset the auto offset reset policy to use for this stream if no committed offsets found; + * acceptable values are earliest or latest + * @param name the unique name of the source used to reference this node when + * {@link #addProcessor(String, ProcessorSupplier, String...) adding processor children} + * @param keyDeserializer key deserializer used to read this source, if not specified the default + * key deserializer defined in the configs will be used + * @param valueDeserializer value deserializer used to read this source, + * if not specified the default value deserializer defined in the configs will be used + * @param topics the name of one or more Kafka topics that this source is to consume + * @return itself + * @throws TopologyException if processor is already added or if topics have already been registered by name + */ + @SuppressWarnings("overloads") + public synchronized Topology addSource(final AutoOffsetReset offsetReset, + final String name, + final Deserializer keyDeserializer, + final Deserializer valueDeserializer, + final String... topics) { + internalTopologyBuilder.addSource(offsetReset, name, null, keyDeserializer, valueDeserializer, topics); + return this; + } + + /** + * Add a new source that consumes from topics matching the given pattern and forwards the records to child processor + * and/or sink nodes. + * The source will use the specified key and value deserializers. + * The provided de-/serializers will be used for all matched topics, so care should be taken to specify patterns for + * topics that share the same key-value data format. + * + * @param offsetReset the auto offset reset policy to use for this stream if no committed offsets found; + * acceptable values are earliest or latest + * @param name the unique name of the source used to reference this node when + * {@link #addProcessor(String, ProcessorSupplier, String...) adding processor children} + * @param keyDeserializer key deserializer used to read this source, if not specified the default + * key deserializer defined in the configs will be used + * @param valueDeserializer value deserializer used to read this source, + * if not specified the default value deserializer defined in the configs will be used + * @param topicPattern regular expression pattern to match Kafka topics that this source is to consume + * @return itself + * @throws TopologyException if processor is already added or if topics have already been registered by name + */ + public synchronized Topology addSource(final AutoOffsetReset offsetReset, + final String name, + final Deserializer keyDeserializer, + final Deserializer valueDeserializer, + final Pattern topicPattern) { + internalTopologyBuilder.addSource(offsetReset, name, null, keyDeserializer, valueDeserializer, topicPattern); + return this; + } + + /** + * Add a new source that consumes the named topics and forwards the records to child processor and/or sink nodes. + * The source will use the specified key and value deserializers. + * + * @param offsetReset the auto offset reset policy to use for this stream if no committed offsets found; + * acceptable values are earliest or latest. + * @param name the unique name of the source used to reference this node when + * {@link #addProcessor(String, ProcessorSupplier, String...) 
adding processor children}. + * @param timestampExtractor the stateless timestamp extractor used for this source, + * if not specified the default extractor defined in the configs will be used + * @param keyDeserializer key deserializer used to read this source, if not specified the default + * key deserializer defined in the configs will be used + * @param valueDeserializer value deserializer used to read this source, + * if not specified the default value deserializer defined in the configs will be used + * @param topics the name of one or more Kafka topics that this source is to consume + * @return itself + * @throws TopologyException if processor is already added or if topics have already been registered by another source + */ + @SuppressWarnings("overloads") + public synchronized Topology addSource(final AutoOffsetReset offsetReset, + final String name, + final TimestampExtractor timestampExtractor, + final Deserializer keyDeserializer, + final Deserializer valueDeserializer, + final String... topics) { + internalTopologyBuilder.addSource(offsetReset, name, timestampExtractor, keyDeserializer, valueDeserializer, topics); + return this; + } + + /** + * Add a new source that consumes from topics matching the given pattern and forwards the records to child processor + * and/or sink nodes. + * The source will use the specified key and value deserializers. + * The provided de-/serializers will be used for all matched topics, so care should be taken to specify patterns for + * topics that share the same key-value data format. + * + * @param offsetReset the auto offset reset policy to use for this stream if no committed offsets found; + * acceptable values are earliest or latest + * @param name the unique name of the source used to reference this node when + * {@link #addProcessor(String, ProcessorSupplier, String...) adding processor children}. + * @param timestampExtractor the stateless timestamp extractor used for this source, + * if not specified the default extractor defined in the configs will be used + * @param keyDeserializer key deserializer used to read this source, if not specified the default + * key deserializer defined in the configs will be used + * @param valueDeserializer value deserializer used to read this source, + * if not specified the default value deserializer defined in the configs will be used + * @param topicPattern regular expression pattern to match Kafka topics that this source is to consume + * @return itself + * @throws TopologyException if processor is already added or if topics have already been registered by name + */ + @SuppressWarnings("overloads") + public synchronized Topology addSource(final AutoOffsetReset offsetReset, + final String name, + final TimestampExtractor timestampExtractor, + final Deserializer keyDeserializer, + final Deserializer valueDeserializer, + final Pattern topicPattern) { + internalTopologyBuilder.addSource(offsetReset, name, timestampExtractor, keyDeserializer, valueDeserializer, topicPattern); + return this; + } + + /** + * Add a new sink that forwards records from upstream parent processor and/or source nodes to the named Kafka topic. + * The sink will use the {@link StreamsConfig#DEFAULT_KEY_SERDE_CLASS_CONFIG default key serializer} and + * {@link StreamsConfig#DEFAULT_VALUE_SERDE_CLASS_CONFIG default value serializer} specified in the + * {@link StreamsConfig stream configuration}. 
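+     * For illustration, a single sink may fan in the output of several upstream nodes; the node and topic names
+     * here are hypothetical:
+     * <pre>{@code
+     * topology.addSink("Sink", "output-topic", "ProcessA", "ProcessB");
+     * }</pre>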
+ * + * @param name the unique name of the sink + * @param topic the name of the Kafka topic to which this sink should write its records + * @param parentNames the name of one or more source or processor nodes whose output records this sink should consume + * and write to its topic + * @return itself + * @throws TopologyException if parent processor is not added yet, or if this processor's name is equal to the parent's name + * @see #addSink(String, String, StreamPartitioner, String...) + * @see #addSink(String, String, Serializer, Serializer, String...) + * @see #addSink(String, String, Serializer, Serializer, StreamPartitioner, String...) + */ + public synchronized Topology addSink(final String name, + final String topic, + final String... parentNames) { + internalTopologyBuilder.addSink(name, topic, null, null, null, parentNames); + return this; + } + + /** + * Add a new sink that forwards records from upstream parent processor and/or source nodes to the named Kafka topic, + * using the supplied partitioner. + * The sink will use the {@link StreamsConfig#DEFAULT_KEY_SERDE_CLASS_CONFIG default key serializer} and + * {@link StreamsConfig#DEFAULT_VALUE_SERDE_CLASS_CONFIG default value serializer} specified in the + * {@link StreamsConfig stream configuration}. + *
<p>
                + * The sink will also use the specified {@link StreamPartitioner} to determine how records are distributed among + * the named Kafka topic's partitions. + * Such control is often useful with topologies that use {@link #addStateStore(StoreBuilder, String...) state + * stores} in its processors. + * In most other cases, however, a partitioner needs not be specified and Kafka will automatically distribute + * records among partitions using Kafka's default partitioning logic. + * + * @param name the unique name of the sink + * @param topic the name of the Kafka topic to which this sink should write its records + * @param partitioner the function that should be used to determine the partition for each record processed by the sink + * @param parentNames the name of one or more source or processor nodes whose output records this sink should consume + * and write to its topic + * @return itself + * @throws TopologyException if parent processor is not added yet, or if this processor's name is equal to the parent's name + * @see #addSink(String, String, String...) + * @see #addSink(String, String, Serializer, Serializer, String...) + * @see #addSink(String, String, Serializer, Serializer, StreamPartitioner, String...) + */ + public synchronized Topology addSink(final String name, + final String topic, + final StreamPartitioner partitioner, + final String... parentNames) { + internalTopologyBuilder.addSink(name, topic, null, null, partitioner, parentNames); + return this; + } + + /** + * Add a new sink that forwards records from upstream parent processor and/or source nodes to the named Kafka topic. + * The sink will use the specified key and value serializers. + * + * @param name the unique name of the sink + * @param topic the name of the Kafka topic to which this sink should write its records + * @param keySerializer the {@link Serializer key serializer} used when consuming records; may be null if the sink + * should use the {@link StreamsConfig#DEFAULT_KEY_SERDE_CLASS_CONFIG default key serializer} specified in the + * {@link StreamsConfig stream configuration} + * @param valueSerializer the {@link Serializer value serializer} used when consuming records; may be null if the sink + * should use the {@link StreamsConfig#DEFAULT_VALUE_SERDE_CLASS_CONFIG default value serializer} specified in the + * {@link StreamsConfig stream configuration} + * @param parentNames the name of one or more source or processor nodes whose output records this sink should consume + * and write to its topic + * @return itself + * @throws TopologyException if parent processor is not added yet, or if this processor's name is equal to the parent's name + * @see #addSink(String, String, String...) + * @see #addSink(String, String, StreamPartitioner, String...) + * @see #addSink(String, String, Serializer, Serializer, StreamPartitioner, String...) + */ + public synchronized Topology addSink(final String name, + final String topic, + final Serializer keySerializer, + final Serializer valueSerializer, + final String... parentNames) { + internalTopologyBuilder.addSink(name, topic, keySerializer, valueSerializer, null, parentNames); + return this; + } + + /** + * Add a new sink that forwards records from upstream parent processor and/or source nodes to the named Kafka topic. + * The sink will use the specified key and value serializers, and the supplied partitioner. 
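+     * For illustration, a hypothetical sketch that writes {@code String} keys and values and sends every record
+     * to partition 0 (node and topic names are placeholders; a realistic partitioner would usually derive the
+     * partition from the key or value):
+     * <pre>{@code
+     * topology.addSink(
+     *     "Sink",
+     *     "output-topic",
+     *     new StringSerializer(),
+     *     new StringSerializer(),
+     *     (topic, key, value, numPartitions) -> 0,
+     *     "Process");
+     * }</pre>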
+ * + * @param name the unique name of the sink + * @param topic the name of the Kafka topic to which this sink should write its records + * @param keySerializer the {@link Serializer key serializer} used when consuming records; may be null if the sink + * should use the {@link StreamsConfig#DEFAULT_KEY_SERDE_CLASS_CONFIG default key serializer} specified in the + * {@link StreamsConfig stream configuration} + * @param valueSerializer the {@link Serializer value serializer} used when consuming records; may be null if the sink + * should use the {@link StreamsConfig#DEFAULT_VALUE_SERDE_CLASS_CONFIG default value serializer} specified in the + * {@link StreamsConfig stream configuration} + * @param partitioner the function that should be used to determine the partition for each record processed by the sink + * @param parentNames the name of one or more source or processor nodes whose output records this sink should consume + * and write to its topic + * @return itself + * @throws TopologyException if parent processor is not added yet, or if this processor's name is equal to the parent's name + * @see #addSink(String, String, String...) + * @see #addSink(String, String, StreamPartitioner, String...) + * @see #addSink(String, String, Serializer, Serializer, String...) + */ + public synchronized Topology addSink(final String name, + final String topic, + final Serializer keySerializer, + final Serializer valueSerializer, + final StreamPartitioner partitioner, + final String... parentNames) { + internalTopologyBuilder.addSink(name, topic, keySerializer, valueSerializer, partitioner, parentNames); + return this; + } + + /** + * Add a new sink that forwards records from upstream parent processor and/or source nodes to Kafka topics based on {@code topicExtractor}. + * The topics that it may ever send to should be pre-created. + * The sink will use the {@link StreamsConfig#DEFAULT_KEY_SERDE_CLASS_CONFIG default key serializer} and + * {@link StreamsConfig#DEFAULT_VALUE_SERDE_CLASS_CONFIG default value serializer} specified in the + * {@link StreamsConfig stream configuration}. + * + * @param name the unique name of the sink + * @param topicExtractor the extractor to determine the name of the Kafka topic to which this sink should write for each record + * @param parentNames the name of one or more source or processor nodes whose output records this sink should consume + * and dynamically write to topics + * @return itself + * @throws TopologyException if parent processor is not added yet, or if this processor's name is equal to the parent's name + * @see #addSink(String, String, StreamPartitioner, String...) + * @see #addSink(String, String, Serializer, Serializer, String...) + * @see #addSink(String, String, Serializer, Serializer, StreamPartitioner, String...) + */ + public synchronized Topology addSink(final String name, + final TopicNameExtractor topicExtractor, + final String... parentNames) { + internalTopologyBuilder.addSink(name, topicExtractor, null, null, null, parentNames); + return this; + } + + /** + * Add a new sink that forwards records from upstream parent processor and/or source nodes to Kafka topics based on {@code topicExtractor}, + * using the supplied partitioner. + * The topics that it may ever send to should be pre-created. + * The sink will use the {@link StreamsConfig#DEFAULT_KEY_SERDE_CLASS_CONFIG default key serializer} and + * {@link StreamsConfig#DEFAULT_VALUE_SERDE_CLASS_CONFIG default value serializer} specified in the + * {@link StreamsConfig stream configuration}. + *
<p>
                + * The sink will also use the specified {@link StreamPartitioner} to determine how records are distributed among + * the named Kafka topic's partitions. + * Such control is often useful with topologies that use {@link #addStateStore(StoreBuilder, String...) state + * stores} in its processors. + * In most other cases, however, a partitioner needs not be specified and Kafka will automatically distribute + * records among partitions using Kafka's default partitioning logic. + * + * @param name the unique name of the sink + * @param topicExtractor the extractor to determine the name of the Kafka topic to which this sink should write for each record + * @param partitioner the function that should be used to determine the partition for each record processed by the sink + * @param parentNames the name of one or more source or processor nodes whose output records this sink should consume + * and dynamically write to topics + * @return itself + * @throws TopologyException if parent processor is not added yet, or if this processor's name is equal to the parent's name + * @see #addSink(String, String, String...) + * @see #addSink(String, String, Serializer, Serializer, String...) + * @see #addSink(String, String, Serializer, Serializer, StreamPartitioner, String...) + */ + public synchronized Topology addSink(final String name, + final TopicNameExtractor topicExtractor, + final StreamPartitioner partitioner, + final String... parentNames) { + internalTopologyBuilder.addSink(name, topicExtractor, null, null, partitioner, parentNames); + return this; + } + + /** + * Add a new sink that forwards records from upstream parent processor and/or source nodes to Kafka topics based on {@code topicExtractor}. + * The topics that it may ever send to should be pre-created. + * The sink will use the specified key and value serializers. + * + * @param name the unique name of the sink + * @param topicExtractor the extractor to determine the name of the Kafka topic to which this sink should write for each record + * @param keySerializer the {@link Serializer key serializer} used when consuming records; may be null if the sink + * should use the {@link StreamsConfig#DEFAULT_KEY_SERDE_CLASS_CONFIG default key serializer} specified in the + * {@link StreamsConfig stream configuration} + * @param valueSerializer the {@link Serializer value serializer} used when consuming records; may be null if the sink + * should use the {@link StreamsConfig#DEFAULT_VALUE_SERDE_CLASS_CONFIG default value serializer} specified in the + * {@link StreamsConfig stream configuration} + * @param parentNames the name of one or more source or processor nodes whose output records this sink should consume + * and dynamically write to topics + * @return itself + * @throws TopologyException if parent processor is not added yet, or if this processor's name is equal to the parent's name + * @see #addSink(String, String, String...) + * @see #addSink(String, String, StreamPartitioner, String...) + * @see #addSink(String, String, Serializer, Serializer, StreamPartitioner, String...) + */ + public synchronized Topology addSink(final String name, + final TopicNameExtractor topicExtractor, + final Serializer keySerializer, + final Serializer valueSerializer, + final String... 
parentNames) { + internalTopologyBuilder.addSink(name, topicExtractor, keySerializer, valueSerializer, null, parentNames); + return this; + } + + /** + * Add a new sink that forwards records from upstream parent processor and/or source nodes to Kafka topics based on {@code topicExtractor}. + * The topics that it may ever send to should be pre-created. + * The sink will use the specified key and value serializers, and the supplied partitioner. + * + * @param name the unique name of the sink + * @param topicExtractor the extractor to determine the name of the Kafka topic to which this sink should write for each record + * @param keySerializer the {@link Serializer key serializer} used when consuming records; may be null if the sink + * should use the {@link StreamsConfig#DEFAULT_KEY_SERDE_CLASS_CONFIG default key serializer} specified in the + * {@link StreamsConfig stream configuration} + * @param valueSerializer the {@link Serializer value serializer} used when consuming records; may be null if the sink + * should use the {@link StreamsConfig#DEFAULT_VALUE_SERDE_CLASS_CONFIG default value serializer} specified in the + * {@link StreamsConfig stream configuration} + * @param partitioner the function that should be used to determine the partition for each record processed by the sink + * @param parentNames the name of one or more source or processor nodes whose output records this sink should consume + * and dynamically write to topics + * @return itself + * @throws TopologyException if parent processor is not added yet, or if this processor's name is equal to the parent's name + * @see #addSink(String, String, String...) + * @see #addSink(String, String, StreamPartitioner, String...) + * @see #addSink(String, String, Serializer, Serializer, String...) + */ + public synchronized Topology addSink(final String name, + final TopicNameExtractor topicExtractor, + final Serializer keySerializer, + final Serializer valueSerializer, + final StreamPartitioner partitioner, + final String... parentNames) { + internalTopologyBuilder.addSink(name, topicExtractor, keySerializer, valueSerializer, partitioner, parentNames); + return this; + } + + /** + * Add a new processor node that receives and processes records output by one or more parent source or processor + * node. + * Any new record output by this processor will be forwarded to its child processor or sink nodes. + * The supplier should always generate a new instance each time + * {@link org.apache.kafka.streams.processor.ProcessorSupplier#get()} gets called. Creating a single + * {@link org.apache.kafka.streams.processor.Processor} object and returning the same object reference in + * {@link org.apache.kafka.streams.processor.ProcessorSupplier#get()} would be a violation of the supplier pattern + * and leads to runtime exceptions. + * If {@code supplier} provides stores via {@link ConnectedStoreProvider#stores()}, the provided {@link StoreBuilder}s + * will be added to the topology and connected to this processor automatically. 
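+     * For illustration, a sketch of the supplier contract with a hypothetical {@code MyProcessor} class that
+     * implements the (deprecated) {@code Processor} interface:
+     * <pre>{@code
+     * // correct: a fresh Processor instance is created on every call to get()
+     * final org.apache.kafka.streams.processor.ProcessorSupplier<String, String> good = MyProcessor::new;
+     *
+     * // incorrect: every call to get() would return the same shared instance
+     * final MyProcessor shared = new MyProcessor();
+     * final org.apache.kafka.streams.processor.ProcessorSupplier<String, String> bad = () -> shared;
+     * }</pre>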
+ * + * @param name the unique name of the processor node + * @param supplier the supplier used to obtain this node's {@link org.apache.kafka.streams.processor.Processor} instance + * @param parentNames the name of one or more source or processor nodes whose output records this processor should receive + * and process + * @return itself + * @throws TopologyException if parent processor is not added yet, or if this processor's name is equal to the parent's name + * @deprecated Since 2.7.0 Use {@link #addProcessor(String, ProcessorSupplier, String...)} instead. + */ + @SuppressWarnings("rawtypes") + @Deprecated + public synchronized Topology addProcessor(final String name, + final org.apache.kafka.streams.processor.ProcessorSupplier supplier, + final String... parentNames) { + return addProcessor( + name, + new ProcessorSupplier() { + @Override + public Set> stores() { + return supplier.stores(); + } + + @Override + public org.apache.kafka.streams.processor.api.Processor get() { + return ProcessorAdapter.adaptRaw(supplier.get()); + } + }, + parentNames + ); + } + + /** + * Add a new processor node that receives and processes records output by one or more parent source or processor + * node. + * Any new record output by this processor will be forwarded to its child processor or sink nodes. + * If {@code supplier} provides stores via {@link ConnectedStoreProvider#stores()}, the provided {@link StoreBuilder}s + * will be added to the topology and connected to this processor automatically. + * + * @param name the unique name of the processor node + * @param supplier the supplier used to obtain this node's {@link Processor} instance + * @param parentNames the name of one or more source or processor nodes whose output records this processor should receive + * and process + * @return itself + * @throws TopologyException if parent processor is not added yet, or if this processor's name is equal to the parent's name + */ + public synchronized Topology addProcessor(final String name, + final ProcessorSupplier supplier, + final String... parentNames) { + internalTopologyBuilder.addProcessor(name, supplier, parentNames); + final Set> stores = supplier.stores(); + if (stores != null) { + for (final StoreBuilder storeBuilder : stores) { + internalTopologyBuilder.addStateStore(storeBuilder, name); + } + } + return this; + } + + /** + * Adds a state store. + * + * @param storeBuilder the storeBuilder used to obtain this state store {@link StateStore} instance + * @param processorNames the names of the processors that should be able to access the provided store + * @return itself + * @throws TopologyException if state store supplier is already added + */ + public synchronized Topology addStateStore(final StoreBuilder storeBuilder, + final String... processorNames) { + internalTopologyBuilder.addStateStore(storeBuilder, processorNames); + return this; + } + + /** + * Adds a global {@link StateStore} to the topology. + * The {@link StateStore} sources its data from all partitions of the provided input topic. + * There will be exactly one instance of this {@link StateStore} per Kafka Streams instance. + *
<p>
                + * A {@link SourceNode} with the provided sourceName will be added to consume the data arriving from the partitions + * of the input topic. + *
<p>
                + * The provided {@link org.apache.kafka.streams.processor.ProcessorSupplier} will be used to create an {@link ProcessorNode} that will receive all + * records forwarded from the {@link SourceNode}. + * The supplier should always generate a new instance each time + * {@link org.apache.kafka.streams.processor.ProcessorSupplier#get()} gets called. Creating a single + * {@link org.apache.kafka.streams.processor.Processor} object and returning the same object reference in + * {@link org.apache.kafka.streams.processor.ProcessorSupplier#get()} would be a violation of the supplier pattern + * and leads to runtime exceptions. + * This {@link ProcessorNode} should be used to keep the {@link StateStore} up-to-date. + * The default {@link TimestampExtractor} as specified in the {@link StreamsConfig config} is used. + * + * @param storeBuilder user defined state store builder + * @param sourceName name of the {@link SourceNode} that will be automatically added + * @param keyDeserializer the {@link Deserializer} to deserialize keys with + * @param valueDeserializer the {@link Deserializer} to deserialize values with + * @param topic the topic to source the data from + * @param processorName the name of the {@link org.apache.kafka.streams.processor.ProcessorSupplier} + * @param stateUpdateSupplier the instance of {@link org.apache.kafka.streams.processor.ProcessorSupplier} + * @return itself + * @throws TopologyException if the processor of state is already registered + * @deprecated Since 2.7.0. Use {@link #addGlobalStore(StoreBuilder, String, Deserializer, Deserializer, String, String, ProcessorSupplier)} instead. + */ + @Deprecated + public synchronized Topology addGlobalStore(final StoreBuilder storeBuilder, + final String sourceName, + final Deserializer keyDeserializer, + final Deserializer valueDeserializer, + final String topic, + final String processorName, + final org.apache.kafka.streams.processor.ProcessorSupplier stateUpdateSupplier) { + internalTopologyBuilder.addGlobalStore( + storeBuilder, + sourceName, + null, + keyDeserializer, + valueDeserializer, + topic, + processorName, + () -> ProcessorAdapter.adapt(stateUpdateSupplier.get()) + ); + return this; + } + + /** + * Adds a global {@link StateStore} to the topology. + * The {@link StateStore} sources its data from all partitions of the provided input topic. + * There will be exactly one instance of this {@link StateStore} per Kafka Streams instance. + *
<p>
                + * A {@link SourceNode} with the provided sourceName will be added to consume the data arriving from the partitions + * of the input topic. + *
<p>
                + * The provided {@link org.apache.kafka.streams.processor.ProcessorSupplier} will be used to create an {@link ProcessorNode} that will receive all + * records forwarded from the {@link SourceNode}. + * The supplier should always generate a new instance each time + * {@link org.apache.kafka.streams.processor.ProcessorSupplier#get()} gets called. Creating a single + * {@link org.apache.kafka.streams.processor.Processor} object and returning the same object reference in + * {@link org.apache.kafka.streams.processor.ProcessorSupplier#get()} would be a violation of the supplier pattern + * and leads to runtime exceptions. + * This {@link ProcessorNode} should be used to keep the {@link StateStore} up-to-date. + * + * @param storeBuilder user defined key value store builder + * @param sourceName name of the {@link SourceNode} that will be automatically added + * @param timestampExtractor the stateless timestamp extractor used for this source, + * if not specified the default extractor defined in the configs will be used + * @param keyDeserializer the {@link Deserializer} to deserialize keys with + * @param valueDeserializer the {@link Deserializer} to deserialize values with + * @param topic the topic to source the data from + * @param processorName the name of the {@link org.apache.kafka.streams.processor.ProcessorSupplier} + * @param stateUpdateSupplier the instance of {@link org.apache.kafka.streams.processor.ProcessorSupplier} + * @return itself + * @throws TopologyException if the processor of state is already registered + * @deprecated Since 2.7.0. Use {@link #addGlobalStore(StoreBuilder, String, TimestampExtractor, Deserializer, Deserializer, String, String, ProcessorSupplier)} instead. + */ + @Deprecated + public synchronized Topology addGlobalStore(final StoreBuilder storeBuilder, + final String sourceName, + final TimestampExtractor timestampExtractor, + final Deserializer keyDeserializer, + final Deserializer valueDeserializer, + final String topic, + final String processorName, + final org.apache.kafka.streams.processor.ProcessorSupplier stateUpdateSupplier) { + internalTopologyBuilder.addGlobalStore( + storeBuilder, + sourceName, + timestampExtractor, + keyDeserializer, + valueDeserializer, + topic, + processorName, + () -> ProcessorAdapter.adapt(stateUpdateSupplier.get()) + ); + return this; + } + + /** + * Adds a global {@link StateStore} to the topology. + * The {@link StateStore} sources its data from all partitions of the provided input topic. + * There will be exactly one instance of this {@link StateStore} per Kafka Streams instance. + *
<p>
                + * A {@link SourceNode} with the provided sourceName will be added to consume the data arriving from the partitions + * of the input topic. + *
<p>
                + * The provided {@link ProcessorSupplier} will be used to create an {@link ProcessorNode} that will receive all + * records forwarded from the {@link SourceNode}. + * This {@link ProcessorNode} should be used to keep the {@link StateStore} up-to-date. + * The default {@link TimestampExtractor} as specified in the {@link StreamsConfig config} is used. + * + * @param storeBuilder user defined state store builder + * @param sourceName name of the {@link SourceNode} that will be automatically added + * @param keyDeserializer the {@link Deserializer} to deserialize keys with + * @param valueDeserializer the {@link Deserializer} to deserialize values with + * @param topic the topic to source the data from + * @param processorName the name of the {@link ProcessorSupplier} + * @param stateUpdateSupplier the instance of {@link ProcessorSupplier} + * @return itself + * @throws TopologyException if the processor of state is already registered + */ + public synchronized Topology addGlobalStore(final StoreBuilder storeBuilder, + final String sourceName, + final Deserializer keyDeserializer, + final Deserializer valueDeserializer, + final String topic, + final String processorName, + final ProcessorSupplier stateUpdateSupplier) { + internalTopologyBuilder.addGlobalStore( + storeBuilder, + sourceName, + null, + keyDeserializer, + valueDeserializer, + topic, + processorName, + stateUpdateSupplier + ); + return this; + } + + /** + * Adds a global {@link StateStore} to the topology. + * The {@link StateStore} sources its data from all partitions of the provided input topic. + * There will be exactly one instance of this {@link StateStore} per Kafka Streams instance. + *
<p>
                + * A {@link SourceNode} with the provided sourceName will be added to consume the data arriving from the partitions + * of the input topic. + *
<p>
                + * The provided {@link ProcessorSupplier} will be used to create an {@link ProcessorNode} that will receive all + * records forwarded from the {@link SourceNode}. + * This {@link ProcessorNode} should be used to keep the {@link StateStore} up-to-date. + * + * @param storeBuilder user defined key value store builder + * @param sourceName name of the {@link SourceNode} that will be automatically added + * @param timestampExtractor the stateless timestamp extractor used for this source, + * if not specified the default extractor defined in the configs will be used + * @param keyDeserializer the {@link Deserializer} to deserialize keys with + * @param valueDeserializer the {@link Deserializer} to deserialize values with + * @param topic the topic to source the data from + * @param processorName the name of the {@link ProcessorSupplier} + * @param stateUpdateSupplier the instance of {@link ProcessorSupplier} + * @return itself + * @throws TopologyException if the processor of state is already registered + */ + public synchronized Topology addGlobalStore(final StoreBuilder storeBuilder, + final String sourceName, + final TimestampExtractor timestampExtractor, + final Deserializer keyDeserializer, + final Deserializer valueDeserializer, + final String topic, + final String processorName, + final ProcessorSupplier stateUpdateSupplier) { + internalTopologyBuilder.addGlobalStore( + storeBuilder, + sourceName, + timestampExtractor, + keyDeserializer, + valueDeserializer, + topic, + processorName, + stateUpdateSupplier + ); + return this; + } + + /** + * Connects the processor and the state stores. + * + * @param processorName the name of the processor + * @param stateStoreNames the names of state stores that the processor uses + * @return itself + * @throws TopologyException if the processor or a state store is unknown + */ + public synchronized Topology connectProcessorAndStateStores(final String processorName, + final String... stateStoreNames) { + internalTopologyBuilder.connectProcessorAndStateStores(processorName, stateStoreNames); + return this; + } + + /** + * Returns a description of the specified {@code Topology}. + * + * @return a description of the topology. + */ + + public synchronized TopologyDescription describe() { + return internalTopologyBuilder.describe(); + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/TopologyDescription.java b/streams/src/main/java/org/apache/kafka/streams/TopologyDescription.java new file mode 100644 index 0000000..6f26779 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/TopologyDescription.java @@ -0,0 +1,173 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams; + +import org.apache.kafka.streams.processor.TopicNameExtractor; +import org.apache.kafka.streams.processor.api.ProcessorSupplier; +import org.apache.kafka.streams.processor.internals.StreamTask; + +import java.util.Set; +import java.util.regex.Pattern; + +/** + * A meta representation of a {@link Topology topology}. + *
<p>
                + * The nodes of a topology are grouped into {@link Subtopology sub-topologies} if they are connected. + * In contrast, two sub-topologies are not connected but can be linked to each other via topics, i.e., if one + * sub-topology {@link Topology#addSink(String, String, String...) writes} into a topic and another sub-topology + * {@link Topology#addSource(String, String...) reads} from the same topic. + *

                + * When {@link KafkaStreams#start()} is called, different sub-topologies will be constructed and executed as independent + * {@link StreamTask tasks}. + */ +public interface TopologyDescription { + /** + * A connected sub-graph of a {@link Topology}. + *

                + * Nodes of a {@code Subtopology} are connected + * {@link Topology#addProcessor(String, ProcessorSupplier, String...) directly} or indirectly via + * {@link Topology#connectProcessorAndStateStores(String, String...) state stores} + * (i.e., if multiple processors share the same state). + */ + interface Subtopology { + /** + * Internally assigned unique ID. + * @return the ID of the sub-topology + */ + int id(); + + /** + * All nodes of this sub-topology. + * @return set of all nodes within the sub-topology + */ + Set nodes(); + } + + /** + * Represents a {@link Topology#addGlobalStore(org.apache.kafka.streams.state.StoreBuilder, String, + * org.apache.kafka.common.serialization.Deserializer, org.apache.kafka.common.serialization.Deserializer, String, + * String, org.apache.kafka.streams.processor.api.ProcessorSupplier) global store}. + * Adding a global store results in adding a source node and one stateful processor node. + * Note, that all added global stores form a single unit (similar to a {@link Subtopology}) even if different + * global stores are not connected to each other. + * Furthermore, global stores are available to all processors without connecting them explicitly, and thus global + * stores will never be part of any {@link Subtopology}. + */ + interface GlobalStore { + /** + * The source node reading from a "global" topic. + * @return the "global" source node + */ + Source source(); + + /** + * The processor node maintaining the global store. + * @return the "global" processor node + */ + Processor processor(); + + int id(); + } + + /** + * A node of a topology. Can be a source, sink, or processor node. + */ + interface Node { + /** + * The name of the node. Will never be {@code null}. + * @return the name of the node + */ + String name(); + /** + * The predecessors of this node within a sub-topology. + * Note, sources do not have any predecessors. + * Will never be {@code null}. + * @return set of all predecessors + */ + Set predecessors(); + /** + * The successor of this node within a sub-topology. + * Note, sinks do not have any successors. + * Will never be {@code null}. + * @return set of all successor + */ + Set successors(); + } + + + /** + * A source node of a topology. + */ + interface Source extends Node { + + /** + * The topic names this source node is reading from. + * @return a set of topic names + */ + Set topicSet(); + + /** + * The pattern used to match topic names that is reading from. + * @return the pattern used to match topic names + */ + Pattern topicPattern(); + } + + /** + * A processor node of a topology. + */ + interface Processor extends Node { + /** + * The names of all connected stores. + * @return set of store names + */ + Set stores(); + } + + /** + * A sink node of a topology. + */ + interface Sink extends Node { + /** + * The topic name this sink node is writing to. + * Could be {@code null} if the topic name can only be dynamically determined based on {@link TopicNameExtractor} + * @return a topic name + */ + String topic(); + + /** + * The {@link TopicNameExtractor} class that this sink node uses to dynamically extract the topic name to write to. + * Could be {@code null} if the topic name is not dynamically determined. + * @return the {@link TopicNameExtractor} class used get the topic name + */ + TopicNameExtractor topicNameExtractor(); + } + + /** + * All sub-topologies of the represented topology. + * @return set of all sub-topologies + */ + Set subtopologies(); + + /** + * All global stores of the represented topology. 
+ * @return set of all global stores + */ + Set globalStores(); + +} + diff --git a/streams/src/main/java/org/apache/kafka/streams/errors/BrokerNotFoundException.java b/streams/src/main/java/org/apache/kafka/streams/errors/BrokerNotFoundException.java new file mode 100644 index 0000000..24c5fcb --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/errors/BrokerNotFoundException.java @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.errors; + + +/** + * Indicates that none of the specified {@link org.apache.kafka.streams.StreamsConfig#BOOTSTRAP_SERVERS_CONFIG brokers} + * could be found. + * + * @see org.apache.kafka.streams.StreamsConfig + */ +public class BrokerNotFoundException extends StreamsException { + + private final static long serialVersionUID = 1L; + + public BrokerNotFoundException(final String message) { + super(message); + } + + public BrokerNotFoundException(final String message, final Throwable throwable) { + super(message, throwable); + } + + public BrokerNotFoundException(final Throwable throwable) { + super(throwable); + } + +} diff --git a/streams/src/main/java/org/apache/kafka/streams/errors/DefaultProductionExceptionHandler.java b/streams/src/main/java/org/apache/kafka/streams/errors/DefaultProductionExceptionHandler.java new file mode 100644 index 0000000..4fdb1a3 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/errors/DefaultProductionExceptionHandler.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.errors; + +import java.util.Map; +import org.apache.kafka.clients.producer.ProducerRecord; + +/** + * {@code ProductionExceptionHandler} that always instructs streams to fail when an exception + * happens while attempting to produce result records. 
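+ * <p>
+ * For illustration only: a minimal sketch of explicitly configuring this handler, assuming a {@code Properties}
+ * instance named {@code props} that is used to create the Streams configuration:
+ * <pre>{@code
+ * props.put(StreamsConfig.DEFAULT_PRODUCTION_EXCEPTION_HANDLER_CLASS_CONFIG,
+ *           DefaultProductionExceptionHandler.class);
+ * }</pre>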
+ */ +public class DefaultProductionExceptionHandler implements ProductionExceptionHandler { + @Override + public ProductionExceptionHandlerResponse handle(final ProducerRecord record, + final Exception exception) { + return ProductionExceptionHandlerResponse.FAIL; + } + + @Override + public void configure(final Map configs) { + // ignore + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/errors/DeserializationExceptionHandler.java b/streams/src/main/java/org/apache/kafka/streams/errors/DeserializationExceptionHandler.java new file mode 100644 index 0000000..4c382b6 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/errors/DeserializationExceptionHandler.java @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.errors; + + +import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.common.Configurable; +import org.apache.kafka.streams.processor.ProcessorContext; + +/** + * Interface that specifies how an exception from source node deserialization + * (e.g., reading from Kafka) should be handled. + */ +public interface DeserializationExceptionHandler extends Configurable { + + /** + * Inspect a record and the exception received. + *
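+ * For illustration only: a handler is typically registered via the Streams configuration, after which this method
+ * is invoked for every record that fails deserialization; a minimal sketch, assuming a {@code Properties} instance
+ * named {@code props} that is used to create the Streams configuration:
+ * <pre>{@code
+ * props.put(StreamsConfig.DEFAULT_DESERIALIZATION_EXCEPTION_HANDLER_CLASS_CONFIG,
+ *           LogAndContinueExceptionHandler.class);
+ * }</pre>
+ *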

                + * Note, that the passed in {@link ProcessorContext} only allows to access metadata like the task ID. + * However, it cannot be used to emit records via {@link ProcessorContext#forward(Object, Object)}; + * calling {@code forward()} (and some other methods) would result in a runtime exception. + * + * @param context processor context + * @param record record that failed deserialization + * @param exception the actual exception + */ + DeserializationHandlerResponse handle(final ProcessorContext context, + final ConsumerRecord record, + final Exception exception); + + /** + * Enumeration that describes the response from the exception handler. + */ + enum DeserializationHandlerResponse { + /* continue with processing */ + CONTINUE(0, "CONTINUE"), + /* fail the processing and stop */ + FAIL(1, "FAIL"); + + /** an english description of the api--this is for debugging and can change */ + public final String name; + + /** the permanent and immutable id of an API--this can't change ever */ + public final int id; + + DeserializationHandlerResponse(final int id, final String name) { + this.id = id; + this.name = name; + } + } + +} diff --git a/streams/src/main/java/org/apache/kafka/streams/errors/InvalidStateStoreException.java b/streams/src/main/java/org/apache/kafka/streams/errors/InvalidStateStoreException.java new file mode 100644 index 0000000..50c961b --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/errors/InvalidStateStoreException.java @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.errors; + + +/** + * Indicates that there was a problem when trying to access a {@link org.apache.kafka.streams.processor.StateStore StateStore}. + * {@code InvalidStateStoreException} is not thrown directly but only its following sub-classes. + */ +public class InvalidStateStoreException extends StreamsException { + + private final static long serialVersionUID = 1L; + + public InvalidStateStoreException(final String message) { + super(message); + } + + public InvalidStateStoreException(final String message, final Throwable throwable) { + super(message, throwable); + } + + public InvalidStateStoreException(final Throwable throwable) { + super(throwable); + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/errors/InvalidStateStorePartitionException.java b/streams/src/main/java/org/apache/kafka/streams/errors/InvalidStateStorePartitionException.java new file mode 100644 index 0000000..e85a037 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/errors/InvalidStateStorePartitionException.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.errors; + +import org.apache.kafka.streams.KafkaStreams; + +/** + * Indicates that the specific state store being queried via + * {@link org.apache.kafka.streams.StoreQueryParameters} used a partitioning that is not assigned to this instance. + * You can use {@link KafkaStreams#metadataForAllStreamsClients()} to discover the correct instance that hosts the requested partition. + */ +public class InvalidStateStorePartitionException extends InvalidStateStoreException { + + private static final long serialVersionUID = 1L; + + public InvalidStateStorePartitionException(final String message) { + super(message); + } + + public InvalidStateStorePartitionException(final String message, final Throwable throwable) { + super(message, throwable); + } + +} diff --git a/streams/src/main/java/org/apache/kafka/streams/errors/LockException.java b/streams/src/main/java/org/apache/kafka/streams/errors/LockException.java new file mode 100644 index 0000000..ddaa69f --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/errors/LockException.java @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.errors; + + +/** + * Indicates that the state store directory lock could not be acquired because another thread holds the lock. 
+ * + * @see org.apache.kafka.streams.processor.StateStore + */ +public class LockException extends StreamsException { + + private final static long serialVersionUID = 1L; + + public LockException(final String message) { + super(message); + } + + public LockException(final String message, final Throwable throwable) { + super(message, throwable); + } + + public LockException(final Throwable throwable) { + super(throwable); + } + +} diff --git a/streams/src/main/java/org/apache/kafka/streams/errors/LogAndContinueExceptionHandler.java b/streams/src/main/java/org/apache/kafka/streams/errors/LogAndContinueExceptionHandler.java new file mode 100644 index 0000000..4f9a096 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/errors/LogAndContinueExceptionHandler.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.errors; + + +import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.streams.processor.ProcessorContext; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Map; + +/** + * Deserialization handler that logs a deserialization exception and then + * signals the processing pipeline to continue processing more records. + */ +public class LogAndContinueExceptionHandler implements DeserializationExceptionHandler { + private static final Logger log = LoggerFactory.getLogger(LogAndContinueExceptionHandler.class); + + @Override + public DeserializationHandlerResponse handle(final ProcessorContext context, + final ConsumerRecord record, + final Exception exception) { + + log.warn("Exception caught during Deserialization, " + + "taskId: {}, topic: {}, partition: {}, offset: {}", + context.taskId(), record.topic(), record.partition(), record.offset(), + exception); + + return DeserializationHandlerResponse.CONTINUE; + } + + @Override + public void configure(final Map configs) { + // ignore + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/errors/LogAndFailExceptionHandler.java b/streams/src/main/java/org/apache/kafka/streams/errors/LogAndFailExceptionHandler.java new file mode 100644 index 0000000..61d2106 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/errors/LogAndFailExceptionHandler.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.errors; + +import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.streams.processor.ProcessorContext; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Map; + + +/** + * Deserialization handler that logs a deserialization exception and then + * signals the processing pipeline to stop processing more records and fail. + */ +public class LogAndFailExceptionHandler implements DeserializationExceptionHandler { + private static final Logger log = LoggerFactory.getLogger(LogAndFailExceptionHandler.class); + + @Override + public DeserializationHandlerResponse handle(final ProcessorContext context, + final ConsumerRecord record, + final Exception exception) { + + log.error("Exception caught during Deserialization, " + + "taskId: {}, topic: {}, partition: {}, offset: {}", + context.taskId(), record.topic(), record.partition(), record.offset(), + exception); + + return DeserializationHandlerResponse.FAIL; + } + + @Override + public void configure(final Map configs) { + // ignore + } +} \ No newline at end of file diff --git a/streams/src/main/java/org/apache/kafka/streams/errors/MissingSourceTopicException.java b/streams/src/main/java/org/apache/kafka/streams/errors/MissingSourceTopicException.java new file mode 100644 index 0000000..c23ea4f --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/errors/MissingSourceTopicException.java @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.errors; + +public class MissingSourceTopicException extends StreamsException { + + private final static long serialVersionUID = 1L; + + public MissingSourceTopicException(final String message) { + super(message); + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/errors/ProcessorStateException.java b/streams/src/main/java/org/apache/kafka/streams/errors/ProcessorStateException.java new file mode 100644 index 0000000..8f1f6ac --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/errors/ProcessorStateException.java @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.errors; + + +/** + * Indicates a processor state operation (e.g. put, get) has failed. + * + * @see org.apache.kafka.streams.processor.StateStore + */ +public class ProcessorStateException extends StreamsException { + + private final static long serialVersionUID = 1L; + + public ProcessorStateException(final String message) { + super(message); + } + + public ProcessorStateException(final String message, final Throwable throwable) { + super(message, throwable); + } + + public ProcessorStateException(final Throwable throwable) { + super(throwable); + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/errors/ProductionExceptionHandler.java b/streams/src/main/java/org/apache/kafka/streams/errors/ProductionExceptionHandler.java new file mode 100644 index 0000000..a24f9d2 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/errors/ProductionExceptionHandler.java @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.errors; + +import org.apache.kafka.clients.producer.ProducerRecord; +import org.apache.kafka.common.Configurable; + +/** + * Interface that specifies how an exception when attempting to produce a result to + * Kafka should be handled. + */ +public interface ProductionExceptionHandler extends Configurable { + /** + * Inspect a record that we attempted to produce, and the exception that resulted + * from attempting to produce it and determine whether or not to continue processing. 
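+ * <p>
+ * For illustration only: a sketch of a custom handler that skips records rejected for being too large but fails
+ * on any other error (the class name and behaviour are assumptions, not part of this interface's contract):
+ * <pre>{@code
+ * public class IgnoreRecordTooLargeHandler implements ProductionExceptionHandler {
+ *     public void configure(final Map<String, ?> configs) { }
+ *
+ *     public ProductionExceptionHandlerResponse handle(final ProducerRecord<byte[], byte[]> record,
+ *                                                      final Exception exception) {
+ *         return exception instanceof RecordTooLargeException
+ *             ? ProductionExceptionHandlerResponse.CONTINUE
+ *             : ProductionExceptionHandlerResponse.FAIL;
+ *     }
+ * }
+ * }</pre>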
+ * + * @param record The record that failed to produce + * @param exception The exception that occurred during production + */ + ProductionExceptionHandlerResponse handle(final ProducerRecord record, + final Exception exception); + + enum ProductionExceptionHandlerResponse { + /* continue processing */ + CONTINUE(0, "CONTINUE"), + /* fail processing */ + FAIL(1, "FAIL"); + + /** + * an english description of the api--this is for debugging and can change + */ + public final String name; + + /** + * the permanent and immutable id of an API--this can't change ever + */ + public final int id; + + ProductionExceptionHandlerResponse(final int id, + final String name) { + this.id = id; + this.name = name; + } + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/errors/StateStoreMigratedException.java b/streams/src/main/java/org/apache/kafka/streams/errors/StateStoreMigratedException.java new file mode 100644 index 0000000..45329c8 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/errors/StateStoreMigratedException.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.errors; + +/** + * Indicates that the state store being queried is closed although the Kafka Streams state is + * {@link org.apache.kafka.streams.KafkaStreams.State#RUNNING RUNNING} or + * {@link org.apache.kafka.streams.KafkaStreams.State#REBALANCING REBALANCING}. + * This could happen because the store moved to some other instance during a rebalance so + * rediscovery of the state store is required before retrying. + */ +public class StateStoreMigratedException extends InvalidStateStoreException { + + private static final long serialVersionUID = 1L; + + public StateStoreMigratedException(final String message) { + super(message); + } + + public StateStoreMigratedException(final String message, final Throwable throwable) { + super(message, throwable); + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/errors/StateStoreNotAvailableException.java b/streams/src/main/java/org/apache/kafka/streams/errors/StateStoreNotAvailableException.java new file mode 100644 index 0000000..7cec17c --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/errors/StateStoreNotAvailableException.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.errors; + +/** + * Indicates that the state store being queried is already closed. This could happen when Kafka Streams is in + * {@link org.apache.kafka.streams.KafkaStreams.State#PENDING_SHUTDOWN PENDING_SHUTDOWN} or + * {@link org.apache.kafka.streams.KafkaStreams.State#NOT_RUNNING NOT_RUNNING} or + * {@link org.apache.kafka.streams.KafkaStreams.State#ERROR ERROR} state. + */ +public class StateStoreNotAvailableException extends InvalidStateStoreException { + + private static final long serialVersionUID = 1L; + + public StateStoreNotAvailableException(final String message) { + super(message); + } + + public StateStoreNotAvailableException(final String message, final Throwable throwable) { + super(message, throwable); + } + +} diff --git a/streams/src/main/java/org/apache/kafka/streams/errors/StreamsException.java b/streams/src/main/java/org/apache/kafka/streams/errors/StreamsException.java new file mode 100644 index 0000000..32beb7e --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/errors/StreamsException.java @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.errors; + +import org.apache.kafka.common.KafkaException; +import org.apache.kafka.streams.processor.TaskId; + +import java.util.Optional; + +/** + * {@link StreamsException} is the top-level exception type generated by Kafka Streams, and indicates errors have + * occurred during a {@link org.apache.kafka.streams.processor.internals.StreamThread StreamThread's} processing. It + * is guaranteed that any exception thrown up to the {@link StreamsUncaughtExceptionHandler} will be of the type + * {@code StreamsException}. For example, any user exceptions will be wrapped as a {@code StreamsException}. 
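+ * <p>
+ * For illustration only: a minimal sketch of inspecting the originating task of a caught {@code StreamsException}
+ * (the surrounding call is assumed, not prescribed):
+ * <pre>{@code
+ * try {
+ *     // ... call into Kafka Streams ...
+ * } catch (final StreamsException e) {
+ *     e.taskId().ifPresent(id -> System.err.println("Failure originated from task " + id));
+ * }
+ * }</pre>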
+ */
+public class StreamsException extends KafkaException {
+
+    private final static long serialVersionUID = 1L;
+
+    private TaskId taskId;
+
+    public StreamsException(final String message) {
+        this(message, (TaskId) null);
+    }
+
+    public StreamsException(final String message, final TaskId taskId) {
+        super(message);
+        this.taskId = taskId;
+    }
+
+    public StreamsException(final String message, final Throwable throwable) {
+        this(message, throwable, null);
+    }
+
+    public StreamsException(final String message, final Throwable throwable, final TaskId taskId) {
+        super(message, throwable);
+        this.taskId = taskId;
+    }
+
+    public StreamsException(final Throwable throwable) {
+        this(throwable, null);
+    }
+
+    public StreamsException(final Throwable throwable, final TaskId taskId) {
+        super(throwable);
+        this.taskId = taskId;
+    }
+
+    /**
+     * @return the {@link TaskId} that this exception originated from, or {@link Optional#empty()} if the exception
+     *         cannot be traced back to a particular task. Note that the {@code TaskId} being empty does not
+     *         guarantee that the exception wasn't directly related to a specific task.
+     */
+    public Optional<TaskId> taskId() {
+        return Optional.ofNullable(taskId);
+    }
+
+    public void setTaskId(final TaskId taskId) {
+        this.taskId = taskId;
+    }
+}
diff --git a/streams/src/main/java/org/apache/kafka/streams/errors/StreamsNotStartedException.java b/streams/src/main/java/org/apache/kafka/streams/errors/StreamsNotStartedException.java
new file mode 100644
index 0000000..562be0e
--- /dev/null
+++ b/streams/src/main/java/org/apache/kafka/streams/errors/StreamsNotStartedException.java
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.streams.errors;
+
+import org.apache.kafka.streams.KafkaStreams;
+import org.apache.kafka.streams.KafkaStreams.State;
+
+/**
+ * Indicates that Kafka Streams is in state {@link State CREATED} and thus state stores cannot be queried yet.
+ * To query state stores, it is required to first start Kafka Streams via {@link KafkaStreams#start()}.
+ * You can retry to query the state after the state has transitioned to {@link State RUNNING}.
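+ * <p>
+ * For illustration only: a minimal sketch of the intended call order, assuming a configured {@code KafkaStreams}
+ * instance named {@code streams} and a queryable store named {@code "counts"}:
+ * <pre>{@code
+ * streams.start();
+ * // querying before start() would throw StreamsNotStartedException
+ * final ReadOnlyKeyValueStore<String, Long> store =
+ *     streams.store(StoreQueryParameters.fromNameAndType("counts", QueryableStoreTypes.keyValueStore()));
+ * }</pre>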
+ */ +public class StreamsNotStartedException extends InvalidStateStoreException { + + private static final long serialVersionUID = 1L; + + public StreamsNotStartedException(final String message) { + super(message); + } + + public StreamsNotStartedException(final String message, final Throwable throwable) { + super(message, throwable); + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/errors/StreamsRebalancingException.java b/streams/src/main/java/org/apache/kafka/streams/errors/StreamsRebalancingException.java new file mode 100644 index 0000000..4b8e14c --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/errors/StreamsRebalancingException.java @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.errors; + +/** + * Indicates that Kafka Streams is in state {@link org.apache.kafka.streams.KafkaStreams.State#REBALANCING REBALANCING} and thus + * cannot be queried by default. You can retry to query after the rebalance finished. As an alternative, you can also query + * (potentially stale) state stores during a rebalance via {@link org.apache.kafka.streams.StoreQueryParameters#enableStaleStores()}. + */ +public class StreamsRebalancingException extends InvalidStateStoreException { + + private static final long serialVersionUID = 1L; + + public StreamsRebalancingException(final String message) { + super(message); + } + + public StreamsRebalancingException(final String message, final Throwable throwable) { + super(message, throwable); + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/errors/StreamsUncaughtExceptionHandler.java b/streams/src/main/java/org/apache/kafka/streams/errors/StreamsUncaughtExceptionHandler.java new file mode 100644 index 0000000..5502b35 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/errors/StreamsUncaughtExceptionHandler.java @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.errors; + +public interface StreamsUncaughtExceptionHandler { + /** + * Inspect the exception received in a stream thread and respond with an action. + * @param exception the actual exception + */ + StreamThreadExceptionResponse handle(final Throwable exception); + + /** + * Enumeration that describes the response from the exception handler. + */ + enum StreamThreadExceptionResponse { + REPLACE_THREAD(0, "REPLACE_THREAD"), + SHUTDOWN_CLIENT(1, "SHUTDOWN_KAFKA_STREAMS_CLIENT"), + SHUTDOWN_APPLICATION(2, "SHUTDOWN_KAFKA_STREAMS_APPLICATION"); + + /** an english description of the api--this is for debugging and can change */ + public final String name; + + /** the permanent and immutable id of an API--this can't change ever */ + public final int id; + + StreamThreadExceptionResponse(final int id, final String name) { + this.id = id; + this.name = name; + } + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/errors/TaskAssignmentException.java b/streams/src/main/java/org/apache/kafka/streams/errors/TaskAssignmentException.java new file mode 100644 index 0000000..2be43d6 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/errors/TaskAssignmentException.java @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.errors; + +/** + * Indicates a run time error incurred while trying to assign + * {@link org.apache.kafka.streams.processor.internals.StreamTask stream tasks} to + * {@link org.apache.kafka.streams.processor.internals.StreamThread threads}. + */ +public class TaskAssignmentException extends StreamsException { + + private final static long serialVersionUID = 1L; + + public TaskAssignmentException(final String message) { + super(message); + } + + public TaskAssignmentException(final String message, final Throwable throwable) { + super(message, throwable); + } + + public TaskAssignmentException(final Throwable throwable) { + super(throwable); + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/errors/TaskCorruptedException.java b/streams/src/main/java/org/apache/kafka/streams/errors/TaskCorruptedException.java new file mode 100644 index 0000000..bf5bd17 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/errors/TaskCorruptedException.java @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.errors; + +import org.apache.kafka.clients.consumer.InvalidOffsetException; +import org.apache.kafka.streams.processor.TaskId; + +import java.util.Set; + +/** + * Indicates a specific task is corrupted and need to be re-initialized. It can be thrown when + * + * 1) Under EOS, if the checkpoint file does not contain offsets for corresponding store's changelogs, meaning + * previously it was not close cleanly; + * 2) Out-of-range exception thrown during restoration, meaning that the changelog has been modified and we re-bootstrap + * the store. + */ +public class TaskCorruptedException extends StreamsException { + + private final Set corruptedTasks; + + public TaskCorruptedException(final Set corruptedTasks) { + super("Tasks " + corruptedTasks + " are corrupted and hence needs to be re-initialized"); + this.corruptedTasks = corruptedTasks; + } + + public TaskCorruptedException(final Set corruptedTasks, + final InvalidOffsetException e) { + super("Tasks " + corruptedTasks + " are corrupted and hence needs to be re-initialized", e); + this.corruptedTasks = corruptedTasks; + } + + public Set corruptedTasks() { + return corruptedTasks; + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/errors/TaskIdFormatException.java b/streams/src/main/java/org/apache/kafka/streams/errors/TaskIdFormatException.java new file mode 100644 index 0000000..6349343 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/errors/TaskIdFormatException.java @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.errors; + + +/** + * Indicates a run time error incurred while trying parse the {@link org.apache.kafka.streams.processor.TaskId task id} + * from the read string. + * + * @see org.apache.kafka.streams.processor.internals.StreamTask + */ +public class TaskIdFormatException extends StreamsException { + + private static final long serialVersionUID = 1L; + + public TaskIdFormatException(final String message) { + super("Task id cannot be parsed correctly" + (message == null ? "" : " from " + message)); + } + + public TaskIdFormatException(final String message, final Throwable throwable) { + super("Task id cannot be parsed correctly" + (message == null ? 
"" : " from " + message), throwable); + } + + public TaskIdFormatException(final Throwable throwable) { + super(throwable); + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/errors/TaskMigratedException.java b/streams/src/main/java/org/apache/kafka/streams/errors/TaskMigratedException.java new file mode 100644 index 0000000..fdb3ab8 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/errors/TaskMigratedException.java @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.errors; + + +/** + * Indicates that all tasks belongs to the thread have migrated to another thread. This exception can be thrown when + * the thread gets fenced (either by the consumer coordinator or by the transaction coordinator), which means it is + * no longer part of the group but a "zombie" already + */ +public class TaskMigratedException extends StreamsException { + + private final static long serialVersionUID = 1L; + + public TaskMigratedException(final String message) { + super(message + "; it means all tasks belonging to this thread should be migrated."); + } + + public TaskMigratedException(final String message, final Throwable throwable) { + super(message + "; it means all tasks belonging to this thread should be migrated.", throwable); + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/errors/TopologyException.java b/streams/src/main/java/org/apache/kafka/streams/errors/TopologyException.java new file mode 100644 index 0000000..1eaef06 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/errors/TopologyException.java @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.errors; + + +/** + * Indicates a pre run time error occurred while parsing the {@link org.apache.kafka.streams.Topology logical topology} + * to construct the {@link org.apache.kafka.streams.processor.internals.ProcessorTopology physical processor topology}. 
+ */ +public class TopologyException extends StreamsException { + + private static final long serialVersionUID = 1L; + + public TopologyException(final String message) { + super("Invalid topology" + (message == null ? "" : ": " + message)); + } + + public TopologyException(final String message, + final Throwable throwable) { + super("Invalid topology" + (message == null ? "" : ": " + message), throwable); + } + + public TopologyException(final Throwable throwable) { + super(throwable); + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/errors/UnknownStateStoreException.java b/streams/src/main/java/org/apache/kafka/streams/errors/UnknownStateStoreException.java new file mode 100644 index 0000000..0ee0658 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/errors/UnknownStateStoreException.java @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.errors; + +/** + * Indicates that the state store being queried is unknown, i.e., the state store does either not exist in your topology + * or it is not queryable. + */ +public class UnknownStateStoreException extends InvalidStateStoreException { + + private static final long serialVersionUID = 1L; + + public UnknownStateStoreException(final String message) { + super(message); + } + + public UnknownStateStoreException(final String message, final Throwable throwable) { + super(message, throwable); + } + +} diff --git a/streams/src/main/java/org/apache/kafka/streams/internals/ApiUtils.java b/streams/src/main/java/org/apache/kafka/streams/internals/ApiUtils.java new file mode 100644 index 0000000..c62de23 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/internals/ApiUtils.java @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.internals; + +import org.apache.kafka.streams.kstream.ValueTransformerSupplier; + +import java.time.Duration; +import java.time.Instant; +import java.util.function.Supplier; + +import static java.lang.String.format; + +public final class ApiUtils { + + private static final String MILLISECOND_VALIDATION_FAIL_MSG_FRMT = "Invalid value for parameter \"%s\" (value was: %s). "; + private static final String VALIDATE_MILLISECOND_NULL_SUFFIX = "It shouldn't be null."; + private static final String VALIDATE_MILLISECOND_OVERFLOW_SUFFIX = "It can't be converted to milliseconds."; + + private ApiUtils() { + } + + /** + * Validates that milliseconds from {@code duration} can be retrieved. + * @param duration Duration to check. + * @param messagePrefix Prefix text for an error message. + * @return Milliseconds from {@code duration}. + */ + public static long validateMillisecondDuration(final Duration duration, final String messagePrefix) { + try { + if (duration == null) { + throw new IllegalArgumentException(messagePrefix + VALIDATE_MILLISECOND_NULL_SUFFIX); + } + + return duration.toMillis(); + } catch (final ArithmeticException e) { + throw new IllegalArgumentException(messagePrefix + VALIDATE_MILLISECOND_OVERFLOW_SUFFIX, e); + } + } + + /** + * Validates that milliseconds from {@code instant} can be retrieved. + * @param instant Instant to check. + * @param messagePrefix Prefix text for an error message. + * @return Milliseconds from {@code instant}. + */ + public static long validateMillisecondInstant(final Instant instant, final String messagePrefix) { + try { + if (instant == null) { + throw new IllegalArgumentException(messagePrefix + VALIDATE_MILLISECOND_NULL_SUFFIX); + } + + return instant.toEpochMilli(); + } catch (final ArithmeticException e) { + throw new IllegalArgumentException(messagePrefix + VALIDATE_MILLISECOND_OVERFLOW_SUFFIX, e); + } + } + + /** + * Generates the prefix message for validateMillisecondXXXXXX() utility + * @param value Object to be converted to milliseconds + * @param name Object name + * @return Error message prefix to use in exception + */ + public static String prepareMillisCheckFailMsgPrefix(final Object value, final String name) { + return format(MILLISECOND_VALIDATION_FAIL_MSG_FRMT, name, value); + } + + /** + * @throws IllegalArgumentException if the same instance is obtained each time + */ + public static void checkSupplier(final Supplier supplier) { + if (supplier.get() == supplier.get()) { + final String supplierClass = supplier.getClass().getName(); + throw new IllegalArgumentException(String.format("%s generates single reference." + + " %s#get() must return a new object each time it is called.", supplierClass, supplierClass)); + } + } + + /** + * @throws IllegalArgumentException if the same instance is obtained each time + */ + public static void checkSupplier(final ValueTransformerSupplier supplier) { + if (supplier.get() == supplier.get()) { + final String supplierClass = supplier.getClass().getName(); + throw new IllegalArgumentException(String.format("%s generates single reference." 
+ + " %s#get() must return a new object each time it is called.", supplierClass, supplierClass)); + } + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/internals/metrics/ClientMetrics.java b/streams/src/main/java/org/apache/kafka/streams/internals/metrics/ClientMetrics.java new file mode 100644 index 0000000..a57b4bb --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/internals/metrics/ClientMetrics.java @@ -0,0 +1,147 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.internals.metrics; + +import org.apache.kafka.common.metrics.Gauge; +import org.apache.kafka.common.metrics.Sensor; +import org.apache.kafka.common.metrics.Sensor.RecordingLevel; +import org.apache.kafka.streams.KafkaStreams.State; +import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.InputStream; +import java.util.Properties; + +import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.CLIENT_LEVEL_GROUP; +import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.addSumMetricToSensor; + +public class ClientMetrics { + private ClientMetrics() {} + + private static final Logger log = LoggerFactory.getLogger(ClientMetrics.class); + private static final String VERSION = "version"; + private static final String COMMIT_ID = "commit-id"; + private static final String APPLICATION_ID = "application-id"; + private static final String TOPOLOGY_DESCRIPTION = "topology-description"; + private static final String STATE = "state"; + private static final String ALIVE_STREAM_THREADS = "alive-stream-threads"; + private static final String VERSION_FROM_FILE; + private static final String COMMIT_ID_FROM_FILE; + private static final String DEFAULT_VALUE = "unknown"; + private static final String FAILED_STREAM_THREADS = "failed-stream-threads"; + + static { + final Properties props = new Properties(); + try (InputStream resourceStream = ClientMetrics.class.getResourceAsStream( + "/kafka/kafka-streams-version.properties")) { + + props.load(resourceStream); + } catch (final Exception exception) { + log.warn("Error while loading kafka-streams-version.properties", exception); + } + VERSION_FROM_FILE = props.getProperty("version", DEFAULT_VALUE).trim(); + COMMIT_ID_FROM_FILE = props.getProperty("commitId", DEFAULT_VALUE).trim(); + } + + private static final String VERSION_DESCRIPTION = "The version of the Kafka Streams client"; + private static final String COMMIT_ID_DESCRIPTION = "The version control commit ID of the Kafka Streams client"; + private static final String APPLICATION_ID_DESCRIPTION = "The application ID of the Kafka Streams client"; + private static final String 
TOPOLOGY_DESCRIPTION_DESCRIPTION = + "The description of the topology executed in the Kafka Streams client"; + private static final String STATE_DESCRIPTION = "The state of the Kafka Streams client"; + private static final String ALIVE_STREAM_THREADS_DESCRIPTION = "The current number of alive stream threads that are running or participating in rebalance"; + private static final String FAILED_STREAM_THREADS_DESCRIPTION = "The number of failed stream threads since the start of the Kafka Streams client"; + + public static String version() { + return VERSION_FROM_FILE; + } + + public static String commitId() { + return COMMIT_ID_FROM_FILE; + } + + public static void addVersionMetric(final StreamsMetricsImpl streamsMetrics) { + streamsMetrics.addClientLevelImmutableMetric( + VERSION, + VERSION_DESCRIPTION, + RecordingLevel.INFO, + VERSION_FROM_FILE + ); + } + + public static void addCommitIdMetric(final StreamsMetricsImpl streamsMetrics) { + streamsMetrics.addClientLevelImmutableMetric( + COMMIT_ID, + COMMIT_ID_DESCRIPTION, + RecordingLevel.INFO, + COMMIT_ID_FROM_FILE + ); + } + + public static void addApplicationIdMetric(final StreamsMetricsImpl streamsMetrics, final String applicationId) { + streamsMetrics.addClientLevelImmutableMetric( + APPLICATION_ID, + APPLICATION_ID_DESCRIPTION, + RecordingLevel.INFO, + applicationId + ); + } + + public static void addTopologyDescriptionMetric(final StreamsMetricsImpl streamsMetrics, + final Gauge topologyDescription) { + streamsMetrics.addClientLevelMutableMetric( + TOPOLOGY_DESCRIPTION, + TOPOLOGY_DESCRIPTION_DESCRIPTION, + RecordingLevel.INFO, + topologyDescription + ); + } + + public static void addStateMetric(final StreamsMetricsImpl streamsMetrics, + final Gauge stateProvider) { + streamsMetrics.addClientLevelMutableMetric( + STATE, + STATE_DESCRIPTION, + RecordingLevel.INFO, + stateProvider + ); + } + + public static void addNumAliveStreamThreadMetric(final StreamsMetricsImpl streamsMetrics, + final Gauge stateProvider) { + streamsMetrics.addClientLevelMutableMetric( + ALIVE_STREAM_THREADS, + ALIVE_STREAM_THREADS_DESCRIPTION, + RecordingLevel.INFO, + stateProvider + ); + } + + public static Sensor failedStreamThreadSensor(final StreamsMetricsImpl streamsMetrics) { + final Sensor sensor = streamsMetrics.clientLevelSensor(FAILED_STREAM_THREADS, RecordingLevel.INFO); + addSumMetricToSensor( + sensor, + CLIENT_LEVEL_GROUP, + streamsMetrics.clientLevelTagMap(), + FAILED_STREAM_THREADS, + false, + FAILED_STREAM_THREADS_DESCRIPTION + ); + return sensor; + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/Aggregator.java b/streams/src/main/java/org/apache/kafka/streams/kstream/Aggregator.java new file mode 100644 index 0000000..217a145 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/Aggregator.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.kstream; + + +/** + * The {@code Aggregator} interface for aggregating values of the given key. + * This is a generalization of {@link Reducer} and allows to have different types for input value and aggregation + * result. + * {@code Aggregator} is used in combination with {@link Initializer} that provides an initial aggregation value. + *

                + * {@code Aggregator} can be used to implement aggregation functions like count. + + * @param key type + * @param input value type + * @param aggregate value type + * @see Initializer + * @see KGroupedStream#aggregate(Initializer, Aggregator) + * @see KGroupedStream#aggregate(Initializer, Aggregator, Materialized) + * @see TimeWindowedKStream#aggregate(Initializer, Aggregator) + * @see TimeWindowedKStream#aggregate(Initializer, Aggregator, Materialized) + * @see SessionWindowedKStream#aggregate(Initializer, Aggregator, Merger) + * @see SessionWindowedKStream#aggregate(Initializer, Aggregator, Merger, Materialized) + * @see Reducer + */ +public interface Aggregator { + + /** + * Compute a new aggregate from the key and value of a record and the current aggregate of the same key. + * + * @param key the key of the record + * @param value the value of the record + * @param aggregate the current aggregate value + * @return the new aggregate value + */ + VA apply(final K key, final V value, final VA aggregate); +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/Branched.java b/streams/src/main/java/org/apache/kafka/streams/kstream/Branched.java new file mode 100644 index 0000000..3e85a73 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/Branched.java @@ -0,0 +1,149 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.kstream; + +import java.util.Objects; +import java.util.function.Consumer; +import java.util.function.Function; + +/** + * The {@code Branched} class is used to define the optional parameters when building branches with + * {@link BranchedKStream}. + * + * @param type of record key + * @param type of record value + */ +public class Branched implements NamedOperation> { + + protected final String name; + protected final Function, ? extends KStream> chainFunction; + protected final Consumer> chainConsumer; + + protected Branched(final String name, + final Function, ? extends KStream> chainFunction, + final Consumer> chainConsumer) { + this.name = name; + this.chainFunction = chainFunction; + this.chainConsumer = chainConsumer; + } + + /** + * Create an instance of {@code Branched} with provided branch name suffix. + * + * @param name the branch name suffix to be used (see {@link BranchedKStream} description for details) + * @param key type + * @param value type + * @return a new instance of {@code Branched} + */ + public static Branched as(final String name) { + Objects.requireNonNull(name, "name cannot be null"); + return new Branched<>(name, null, null); + } + + /** + * Create an instance of {@code Branched} with provided chain function. + * + * @param chain A function that will be applied to the branch. 
If the provided function returns + * {@code null}, its result is ignored, otherwise it is added to the {@code Map} returned + * by {@link BranchedKStream#defaultBranch()} or {@link BranchedKStream#noDefaultBranch()} (see + * {@link BranchedKStream} description for details). + * @param key type + * @param value type + * @return a new instance of {@code Branched} + */ + public static Branched withFunction( + final Function, ? extends KStream> chain) { + Objects.requireNonNull(chain, "chain function cannot be null"); + return new Branched<>(null, chain, null); + } + + /** + * Create an instance of {@code Branched} with provided chain consumer. + * + * @param chain A consumer to which the branch will be sent. If a consumer is provided, + * the respective branch will not be added to the resulting {@code Map} returned + * by {@link BranchedKStream#defaultBranch()} or {@link BranchedKStream#noDefaultBranch()} (see + * {@link BranchedKStream} description for details). + * @param key type + * @param value type + * @return a new instance of {@code Branched} + */ + public static Branched withConsumer(final Consumer> chain) { + Objects.requireNonNull(chain, "chain consumer cannot be null"); + return new Branched<>(null, null, chain); + } + + /** + * Create an instance of {@code Branched} with provided chain function and branch name suffix. + * + * @param chain A function that will be applied to the branch. If the provided function returns + * {@code null}, its result is ignored, otherwise it is added to the {@code Map} returned + * by {@link BranchedKStream#defaultBranch()} or {@link BranchedKStream#noDefaultBranch()} (see + * {@link BranchedKStream} description for details). + * @param name the branch name suffix to be used. If {@code null}, a default branch name suffix will be generated + * (see {@link BranchedKStream} description for details) + * @param key type + * @param value type + * @return a new instance of {@code Branched} + */ + public static Branched withFunction( + final Function, ? extends KStream> chain, final String name) { + Objects.requireNonNull(chain, "chain function cannot be null"); + return new Branched<>(name, chain, null); + } + + /** + * Create an instance of {@code Branched} with provided chain consumer and branch name suffix. + * + * @param chain A consumer to which the branch will be sent. If a non-null consumer is provided, + * the respective branch will not be added to the resulting {@code Map} returned + * by {@link BranchedKStream#defaultBranch()} or {@link BranchedKStream#noDefaultBranch()} (see + * {@link BranchedKStream} description for details). + * @param name the branch name suffix to be used. If {@code null}, a default branch name suffix will be generated + * (see {@link BranchedKStream} description for details) + * @param key type + * @param value type + * @return a new instance of {@code Branched} + */ + public static Branched withConsumer(final Consumer> chain, + final String name) { + Objects.requireNonNull(chain, "chain consumer cannot be null"); + return new Branched<>(name, null, chain); + } + + /** + * Create an instance of {@code Branched} from an existing instance. + * + * @param branched the instance of {@code Branched} to copy + */ + protected Branched(final Branched branched) { + this(branched.name, branched.chainFunction, branched.chainConsumer); + } + + /** + * Configure the instance of {@code Branched} with a branch name suffix. + * + * @param name the branch name suffix to be used. 
If {@code null} a default branch name suffix will be generated (see + * {@link BranchedKStream} description for details) + * @return {@code this} + */ + @Override + public Branched withName(final String name) { + Objects.requireNonNull(name, "name cannot be null"); + return new Branched<>(name, chainFunction, chainConsumer); + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/BranchedKStream.java b/streams/src/main/java/org/apache/kafka/streams/kstream/BranchedKStream.java new file mode 100644 index 0000000..2115170 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/BranchedKStream.java @@ -0,0 +1,172 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.kstream; + +import java.util.Map; + +/** + * Branches the records in the original stream based on the predicates supplied for the branch definitions. + *

                + * Branches are defined with {@link BranchedKStream#branch(Predicate, Branched)} or + * {@link BranchedKStream#defaultBranch(Branched)} methods. Each record is evaluated against the {@code predicate} + * supplied via {@link Branched} parameters, and is routed to the first branch for which its respective predicate + * evaluates to {@code true}. If a record does not match any predicates, it will be routed to the default branch, + * or dropped if no default branch is created. + *

                + * Each branch (which is a {@link KStream} instance) then can be processed either by + * a {@link java.util.function.Function} or a {@link java.util.function.Consumer} provided via a {@link Branched} + * parameter. If certain conditions are met, it also can be accessed from the {@link Map} returned by an optional + * {@link BranchedKStream#defaultBranch(Branched)} or {@link BranchedKStream#noDefaultBranch()} method call + * (see usage examples). + *

                + * The branching happens on a first-match basis: A record in the original stream is assigned to the corresponding result + * stream for the first predicate that evaluates to {@code true}, and is assigned to this stream only. If you need + * to route a record to multiple streams, you can apply multiple {@link KStream#filter(Predicate)} operators + * to the same {@link KStream} instance, one for each predicate, instead of branching. + *

                + * The process of routing the records to different branches is a stateless record-by-record operation. + * + *

+ * <h2>Rules of forming the resulting map</h2>
+ * The keys of the {@code Map<String, KStream<K, V>>} entries returned by {@link BranchedKStream#defaultBranch(Branched)} or
+ * {@link BranchedKStream#noDefaultBranch()} are defined by the following rules:
+ * <ul>
+ * <li>If a {@link Named} parameter was provided for {@link KStream#split(Named)}, its value is used as
+ * a prefix for each key. By default, no prefix is used</li>
+ * <li>If a branch name is provided in {@link BranchedKStream#branch(Predicate, Branched)} via the
+ * {@link Branched} parameter, its value is appended to the prefix to form the {@code Map} key</li>
+ * <li>If a name is not provided for the branch, then the key defaults to {@code prefix + position} of the branch
+ * as a decimal number, starting from {@code "1"}</li>
+ * <li>If a name is not provided for the {@link BranchedKStream#defaultBranch()}, then the key defaults
+ * to {@code prefix + "0"}</li>
+ * </ul>
+ * The values of the respective {@code Map<String, KStream<K, V>>} entries are formed as follows:
+ * <ul>
+ * <li>If no chain function or consumer is provided in {@link BranchedKStream#branch(Predicate, Branched)} via
+ * the {@link Branched} parameter, then the branch itself is added to the {@code Map}</li>
+ * <li>If a chain function is provided and returns a non-null value for a given branch, then the value
+ * is the result returned by this function</li>
+ * <li>If a chain function returns {@code null} for a given branch, then no entry is added to the map</li>
+ * <li>If a consumer is provided for a given branch, then no entry is added to the map</li>
+ * </ul>
+ * For example:
+ * <pre> {@code
+ * Map<String, KStream<String, String>> result =
+ *   source.split(Named.as("foo-"))
+ *     .branch(predicate1, Branched.as("bar"))                      // "foo-bar"
+ *     .branch(predicate2, Branched.withConsumer(ks -> ks.to("A"))) // no entry: a Consumer is provided
+ *     .branch(predicate3, Branched.withFunction(ks -> null))       // no entry: chain function returns null
+ *     .branch(predicate4, Branched.withFunction(ks -> ks))         // "foo-4": chain function returns non-null value
+ *     .branch(predicate5)                                          // "foo-5": name defaults to the branch position
+ *     .defaultBranch();                                            // "foo-0": "0" is the default name for the default branch
+ * }</pre>
+ *
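+ * <p>
+ * The following is an illustrative, more self-contained sketch (not part of the original example; the input
+ * topic {@code "input"} and the String-typed predicates are assumptions) showing how the keys of the resulting
+ * map follow the rules above:
+ * <pre> {@code
+ * StreamsBuilder builder = new StreamsBuilder();
+ * KStream<String, String> source = builder.stream("input");
+ * Map<String, KStream<String, String>> result =
+ *     source.split(Named.as("foo-"))
+ *         .branch((k, v) -> v.startsWith("A"), Branched.as("bar"))  // key "foo-bar"
+ *         .branch((k, v) -> v.startsWith("B"))                      // key "foo-2" (position-based)
+ *         .defaultBranch();                                         // key "foo-0"
+ * // result.keySet() contains "foo-bar", "foo-2" and "foo-0"
+ * }</pre>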

+ * <h2>Usage examples</h2>
+ *
+ * <h3>Direct Branch Consuming</h3>
+ * In many cases we do not need to have a single scope for all the branches; each branch can be processed completely
+ * independently of the others. Then we can use 'consuming' lambdas or method references in the {@link Branched} parameter:
+ *
+ * <pre> {@code
+ * source.split()
+ *     .branch(predicate1, Branched.withConsumer(ks -> ks.to("A")))
+ *     .branch(predicate2, Branched.withConsumer(ks -> ks.to("B")))
+ *     .defaultBranch(Branched.withConsumer(ks -> ks.to("C")));
+ * }</pre>
+ *
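+ * <p>
+ * An illustrative, self-contained variant of the snippet above (the topic names and predicates are
+ * assumptions); each consumer may also transform its branch before writing it out:
+ * <pre> {@code
+ * StreamsBuilder builder = new StreamsBuilder();
+ * KStream<String, String> source = builder.stream("input");
+ * source.split()
+ *     .branch((k, v) -> v.startsWith("a"), Branched.withConsumer(ks -> ks.mapValues(String::toUpperCase).to("A")))
+ *     .branch((k, v) -> v.startsWith("b"), Branched.withConsumer(ks -> ks.to("B")))
+ *     .defaultBranch(Branched.withConsumer(ks -> ks.to("C")));
+ * }</pre>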

+ * <h3>Collecting branches in a single scope</h3>
+ * In other cases we want to combine branches again after splitting. The map returned by the
+ * {@link BranchedKStream#defaultBranch()} or {@link BranchedKStream#noDefaultBranch()} methods provides
+ * access to all the branches in the same scope:
+ *
+ * <pre> {@code
+ * Map<String, KStream<String, String>> branches = source.split(Named.as("split-"))
+ *     .branch((key, value) -> value == null, Branched.withFunction(s -> s.mapValues(v -> "NULL"), "null"))
+ *     .defaultBranch(Branched.as("non-null"));
+ *
+ * KStream<String, String> merged = branches.get("split-non-null").merge(branches.get("split-null"));
+ * }</pre>
+ *
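+ * <p>
+ * A self-contained variant of the snippet above, for illustration only (the topic names are assumptions),
+ * writing the merged stream back out:
+ * <pre> {@code
+ * StreamsBuilder builder = new StreamsBuilder();
+ * KStream<String, String> source = builder.stream("input");
+ * Map<String, KStream<String, String>> branches = source.split(Named.as("split-"))
+ *     .branch((key, value) -> value == null, Branched.withFunction(s -> s.mapValues(v -> "NULL"), "null"))
+ *     .defaultBranch(Branched.as("non-null"));
+ * branches.get("split-non-null")
+ *     .merge(branches.get("split-null"))
+ *     .to("output");
+ * }</pre>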

+ * <h3>Dynamic branching</h3>
+ * There is also a case where we might need to create branches dynamically, e.g. one per enum value:
+ *
+ * <pre> {@code
+ * BranchedKStream branched = stream.split();
+ * for (RecordType recordType : RecordType.values())
+ *     branched.branch((k, v) -> v.getRecType() == recordType,
+ *         Branched.withConsumer(recordType::processRecords));
+ * }</pre>
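+ * <p>
+ * A self-contained sketch of the dynamic-branching pattern (the {@code RecordType} enum, its
+ * {@code processRecords} method and the topic names are assumptions used only for illustration):
+ * <pre> {@code
+ * enum RecordType {
+ *     A, B, C;
+ *     void processRecords(KStream<String, String> ks) { ks.to("output-" + name()); }
+ * }
+ *
+ * StreamsBuilder builder = new StreamsBuilder();
+ * KStream<String, String> stream = builder.stream("input");
+ * BranchedKStream<String, String> branched = stream.split();
+ * for (RecordType recordType : RecordType.values()) {
+ *     branched.branch((k, v) -> v.startsWith(recordType.name()),
+ *         Branched.withConsumer(recordType::processRecords));
+ * }
+ * branched.defaultBranch();
+ * }</pre>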
                + * + * @param Type of keys + * @param Type of values + * @see KStream + */ +public interface BranchedKStream { + /** + * Define a branch for records that match the predicate. + * + * @param predicate A {@link Predicate} instance, against which each record will be evaluated. + * If this predicate returns {@code true} for a given record, the record will be + * routed to the current branch and will not be evaluated against the predicates + * for the remaining branches. + * @return {@code this} to facilitate method chaining + */ + BranchedKStream branch(Predicate predicate); + + /** + * Define a branch for records that match the predicate. + * + * @param predicate A {@link Predicate} instance, against which each record will be evaluated. + * If this predicate returns {@code true} for a given record, the record will be + * routed to the current branch and will not be evaluated against the predicates + * for the remaining branches. + * @param branched A {@link Branched} parameter, that allows to define a branch name, an in-place + * branch consumer or branch mapper (see code examples + * for {@link BranchedKStream}) + * @return {@code this} to facilitate method chaining + */ + BranchedKStream branch(Predicate predicate, Branched branched); + + /** + * Finalize the construction of branches and defines the default branch for the messages not intercepted + * by other branches. Calling {@code defaultBranch} or {@link #noDefaultBranch()} is optional. + * + * @return {@link Map} of named branches. For rules of forming the resulting map, see {@code BranchedKStream} + * description. + */ + Map> defaultBranch(); + + /** + * Finalize the construction of branches and defines the default branch for the messages not intercepted + * by other branches. Calling {@code defaultBranch} or {@link #noDefaultBranch()} is optional. + * + * @param branched A {@link Branched} parameter, that allows to define a branch name, an in-place + * branch consumer or branch mapper (see code examples + * for {@link BranchedKStream}) + * @return {@link Map} of named branches. For rules of forming the resulting map, see {@link BranchedKStream} + * description. + */ + Map> defaultBranch(Branched branched); + + /** + * Finalize the construction of branches without forming a default branch. Calling {@code #noDefaultBranch()} + * or {@link #defaultBranch()} is optional. + * + * @return {@link Map} of named branches. For rules of forming the resulting map, see {@link BranchedKStream} + * description. + */ + Map> noDefaultBranch(); +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/CogroupedKStream.java b/streams/src/main/java/org/apache/kafka/streams/kstream/CogroupedKStream.java new file mode 100644 index 0000000..051396f --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/CogroupedKStream.java @@ -0,0 +1,301 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.kstream; + +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.streams.KafkaStreams; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.StoreQueryParameters; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.Topology; +import org.apache.kafka.streams.state.KeyValueStore; +import org.apache.kafka.streams.state.ReadOnlyKeyValueStore; +import org.apache.kafka.streams.state.TimestampedKeyValueStore; + +/** + * {@code CogroupedKStream} is an abstraction of multiple grouped record streams of {@link KeyValue} pairs. + *

                + * It is an intermediate representation after a grouping of {@link KStream}s, before the + * aggregations are applied to the new partitions resulting in a {@link KTable}. + *

                + * A {@code CogroupedKStream} must be obtained from a {@link KGroupedStream} via + * {@link KGroupedStream#cogroup(Aggregator) cogroup(...)}. + * + * @param Type of keys + * @param Type of values after agg + */ +public interface CogroupedKStream { + + /** + * Add an already {@link KGroupedStream grouped KStream} to this {@code CogroupedKStream}. + *

                + * The added {@link KGroupedStream grouped KStream} must have the same number of partitions as all existing + * streams of this {@code CogroupedKStream}. + * If this is not the case, you would need to call {@link KStream#repartition(Repartitioned)} before + * {@link KStream#groupByKey() grouping} the {@link KStream} and specify the "correct" number of + * partitions via {@link Repartitioned} parameter. + *

                + * The specified {@link Aggregator} is applied in the actual {@link #aggregate(Initializer) aggregation} step for + * each input record and computes a new aggregate using the current aggregate (or for the very first record per key + * using the initial intermediate aggregation result provided via the {@link Initializer} that is passed into + * {@link #aggregate(Initializer)}) and the record's value. + * + * @param groupedStream a group stream + * @param aggregator an {@link Aggregator} that computes a new aggregate result + * @param Type of input values + * @return a {@code CogroupedKStream} + */ + CogroupedKStream cogroup(final KGroupedStream groupedStream, + final Aggregator aggregator); + + /** + * Aggregate the values of records in these streams by the grouped key. + * Records with {@code null} key or value are ignored. + * The result is written into a local {@link KeyValueStore} (which is basically an ever-updating materialized view) + * that can be queried by the given store name in {@code materialized}. + * Furthermore, updates to the store are sent downstream into a {@link KTable} changelog stream. + *

                + * To compute the aggregation the corresponding {@link Aggregator} as specified in + * {@link #cogroup(KGroupedStream, Aggregator) cogroup(...)} is used per input stream. + * The specified {@link Initializer} is applied once per key, directly before the first input record per key is + * processed to provide an initial intermediate aggregation result that is used to process the first record. + *

                + * Not all updates might get sent downstream, as an internal cache is used to deduplicate consecutive updates to the + * same key. + * The rate of propagated updates depends on your input data rate, the number of distinct keys, the number of + * parallel running Kafka Streams instances, and the {@link StreamsConfig configuration} parameters for + * {@link StreamsConfig#CACHE_MAX_BYTES_BUFFERING_CONFIG cache size}, and + * {@link StreamsConfig#COMMIT_INTERVAL_MS_CONFIG commit interval}. + *
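+ * <p>
+ * As an illustrative sketch only (the topic names, the Long-valued streams, the store name and the use of
+ * default serdes are assumptions), such a queryable {@link KTable} could be produced by cogrouping two
+ * grouped streams and aggregating them:
+ * <pre> {@code
+ * StreamsBuilder builder = new StreamsBuilder();
+ * KGroupedStream<String, Long> clicks = builder.<String, Long>stream("clicks").groupByKey();
+ * KGroupedStream<String, Long> views = builder.<String, Long>stream("views").groupByKey();
+ * Aggregator<String, Long, Long> sum = (key, value, aggregate) -> aggregate + value;
+ * KTable<String, Long> totals = clicks
+ *     .cogroup(sum)
+ *     .cogroup(views, sum)
+ *     .aggregate(() -> 0L, Materialized.as("totals-store"));
+ * }</pre>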

                + * To query the local {@link ReadOnlyKeyValueStore} it must be obtained via + * {@link KafkaStreams#store(StoreQueryParameters) KafkaStreams#store(...)}: + *

+     * <pre> {@code
+     * KafkaStreams streams = ... // some aggregation on value type double
+     * String queryableStoreName = "storeName"; // the store name should be the name of the store as defined by the Materialized instance
+     * ReadOnlyKeyValueStore<K, ValueAndTimestamp<VOut>> localStore = streams.store(queryableStoreName, QueryableStoreTypes.<K, ValueAndTimestamp<VOut>>timestampedKeyValueStore());
+     * K key = "some-key";
+     * ValueAndTimestamp<VOut> aggForKey = localStore.get(key); // key must be local (application state is shared over all running Kafka Streams instances)
+     * }</pre>
                + * For non-local keys, a custom RPC mechanism must be implemented using {@link KafkaStreams#metadataForAllStreamsClients()} to query + * the value of the key on a parallel running instance of your Kafka Streams application. + *

                + * For failure and recovery the store (which always will be of type {@link TimestampedKeyValueStore}) will be backed by + * an internal changelog topic that will be created in Kafka. + * Therefore, the store name defined by the Materialized instance must be a valid Kafka topic name and cannot + * contain characters other than ASCII alphanumerics, '.', '_' and '-'. + * The changelog topic will be named "${applicationId}-${storeName}-changelog", where "applicationId" is + * user-specified in {@link StreamsConfig} via parameter + * {@link StreamsConfig#APPLICATION_ID_CONFIG APPLICATION_ID_CONFIG}, "storeName" is a generated value, and + * "-changelog" is a fixed suffix. + *

                + * You can retrieve all generated internal topic names via {@link Topology#describe()}. + * + * @param initializer an {@link Initializer} that computes an initial intermediate aggregation + * result. Cannot be {@code null}. + * @return a {@link KTable} that contains "update" records with unmodified keys, and values that + * represent the latest (rolling) aggregate for each key + */ + KTable aggregate(final Initializer initializer); + + /** + * Aggregate the values of records in these streams by the grouped key. + * Records with {@code null} key or value are ignored. + * The result is written into a local {@link KeyValueStore} (which is basically an ever-updating materialized view) + * that can be queried by the given store name in {@code materialized}. + * Furthermore, updates to the store are sent downstream into a {@link KTable} changelog stream. + *

                + * To compute the aggregation the corresponding {@link Aggregator} as specified in + * {@link #cogroup(KGroupedStream, Aggregator) cogroup(...)} is used per input stream. + * The specified {@link Initializer} is applied once per key, directly before the first input record per key is + * processed to provide an initial intermediate aggregation result that is used to process the first record. + * The specified {@link Named} is applied once to the processor combining the grouped streams. + *

                + * Not all updates might get sent downstream, as an internal cache is used to deduplicate consecutive updates to the + * same key. + * The rate of propagated updates depends on your input data rate, the number of distinct keys, the number of + * parallel running Kafka Streams instances, and the {@link StreamsConfig configuration} parameters for + * {@link StreamsConfig#CACHE_MAX_BYTES_BUFFERING_CONFIG cache size}, and + * {@link StreamsConfig#COMMIT_INTERVAL_MS_CONFIG commit interval}. + *

                + * To query the local {@link ReadOnlyKeyValueStore} it must be obtained via + * {@link KafkaStreams#store(StoreQueryParameters) KafkaStreams#store(...)}: + *

+     * <pre> {@code
+     * KafkaStreams streams = ... // some aggregation on value type double
+     * String queryableStoreName = "storeName"; // the store name should be the name of the store as defined by the Materialized instance
+     * ReadOnlyKeyValueStore<K, ValueAndTimestamp<VOut>> localStore = streams.store(queryableStoreName, QueryableStoreTypes.<K, ValueAndTimestamp<VOut>>timestampedKeyValueStore());
+     * K key = "some-key";
+     * ValueAndTimestamp<VOut> aggForKey = localStore.get(key); // key must be local (application state is shared over all running Kafka Streams instances)
+     * }</pre>
                + * For non-local keys, a custom RPC mechanism must be implemented using {@link KafkaStreams#metadataForAllStreamsClients()} to query + * the value of the key on a parallel running instance of your Kafka Streams application. + *

                + * For failure and recovery the store (which always will be of type {@link TimestampedKeyValueStore}) will be backed by + * an internal changelog topic that will be created in Kafka. + * Therefore, the store name defined by the Materialized instance must be a valid Kafka topic name and cannot + * contain characters other than ASCII alphanumerics, '.', '_' and '-'. + * The changelog topic will be named "${applicationId}-${storeName}-changelog", where "applicationId" is + * user-specified in {@link StreamsConfig} via parameter + * {@link StreamsConfig#APPLICATION_ID_CONFIG APPLICATION_ID_CONFIG}, "storeName" is the provide store name defined + * in {@code Materialized}, and "-changelog" is a fixed suffix. + *

                + * You can retrieve all generated internal topic names via {@link Topology#describe()}. + * + * @param initializer an {@link Initializer} that computes an initial intermediate aggregation + * result. Cannot be {@code null}. + * @param named name the processor. Cannot be {@code null}. + * @return a {@link KTable} that contains "update" records with unmodified keys, and values that + * represent the latest (rolling) aggregate for each key + */ + KTable aggregate(final Initializer initializer, + final Named named); + + /** + * Aggregate the values of records in these streams by the grouped key. + * Records with {@code null} key or value are ignored. + * The result is written into a local {@link KeyValueStore} (which is basically an ever-updating materialized view) + * that can be queried by the given store name in {@code materialized}. + * Furthermore, updates to the store are sent downstream into a {@link KTable} changelog stream. + *

                + * To compute the aggregation the corresponding {@link Aggregator} as specified in + * {@link #cogroup(KGroupedStream, Aggregator) cogroup(...)} is used per input stream. + * The specified {@link Initializer} is applied once per key, directly before the first input record per key is + * processed to provide an initial intermediate aggregation result that is used to process the first record. + *

                + * Not all updates might get sent downstream, as an internal cache is used to deduplicate consecutive updates to the + * same key. + * The rate of propagated updates depends on your input data rate, the number of distinct keys, the number of + * parallel running Kafka Streams instances, and the {@link StreamsConfig configuration} parameters for + * {@link StreamsConfig#CACHE_MAX_BYTES_BUFFERING_CONFIG cache size}, and + * {@link StreamsConfig#COMMIT_INTERVAL_MS_CONFIG commit interval}. + *

                + * To query the local {@link ReadOnlyKeyValueStore} it must be obtained via + * {@link KafkaStreams#store(StoreQueryParameters) KafkaStreams#store(...)}: + *

+     * <pre> {@code
+     * KafkaStreams streams = ... // some aggregation on value type double
+     * String queryableStoreName = "storeName"; // the store name should be the name of the store as defined by the Materialized instance
+     * ReadOnlyKeyValueStore<K, ValueAndTimestamp<VOut>> localStore = streams.store(queryableStoreName, QueryableStoreTypes.<K, ValueAndTimestamp<VOut>>timestampedKeyValueStore());
+     * K key = "some-key";
+     * ValueAndTimestamp<VOut> aggForKey = localStore.get(key); // key must be local (application state is shared over all running Kafka Streams instances)
+     * }</pre>
                + * For non-local keys, a custom RPC mechanism must be implemented using {@link KafkaStreams#metadataForAllStreamsClients()} to query + * the value of the key on a parallel running instance of your Kafka Streams application. + *

                + * For failure and recovery the store (which always will be of type {@link TimestampedKeyValueStore} -- regardless of what + * is specified in the parameter {@code materialized}) will be backed by an internal changelog topic that will be created in Kafka. + * Therefore, the store name defined by the Materialized instance must be a valid Kafka topic name and cannot + * contain characters other than ASCII alphanumerics, '.', '_' and '-'. + * The changelog topic will be named "${applicationId}-${storeName}-changelog", where "applicationId" is + * user-specified in {@link StreamsConfig} via parameter + * {@link StreamsConfig#APPLICATION_ID_CONFIG APPLICATION_ID_CONFIG}, "storeName" is the provide store name defined + * in {@code Materialized}, and "-changelog" is a fixed suffix. + *

                + * You can retrieve all generated internal topic names via {@link Topology#describe()}. + * + * @param initializer an {@link Initializer} that computes an initial intermediate aggregation + * result. Cannot be {@code null}. + * @param materialized an instance of {@link Materialized} used to materialize a state store. + * Cannot be {@code null}. + * @return a {@link KTable} that contains "update" records with unmodified keys, and values that + * represent the latest (rolling) aggregate for each key + */ + KTable aggregate(final Initializer initializer, + final Materialized> materialized); + + /** + * Aggregate the values of records in these streams by the grouped key. + * Records with {@code null} key or value are ignored. + * The result is written into a local {@link KeyValueStore} (which is basically an ever-updating materialized view) + * that can be queried by the given store name in {@code materialized}. + * Furthermore, updates to the store are sent downstream into a {@link KTable} changelog stream. + *

                + * To compute the aggregation the corresponding {@link Aggregator} as specified in + * {@link #cogroup(KGroupedStream, Aggregator) cogroup(...)} is used per input stream. + * The specified {@link Initializer} is applied once per key, directly before the first input record per key is + * processed to provide an initial intermediate aggregation result that is used to process the first record. + * The specified {@link Named} is used to name the processor combining the grouped streams. + *

                + * Not all updates might get sent downstream, as an internal cache is used to deduplicate consecutive updates to the + * same key. + * The rate of propagated updates depends on your input data rate, the number of distinct keys, the number of + * parallel running Kafka Streams instances, and the {@link StreamsConfig configuration} parameters for + * {@link StreamsConfig#CACHE_MAX_BYTES_BUFFERING_CONFIG cache size}, and + * {@link StreamsConfig#COMMIT_INTERVAL_MS_CONFIG commit interval}. + *

                + * To query the local {@link org.apache.kafka.streams.state.ReadOnlyKeyValueStore} it must be obtained via + * {@link KafkaStreams#store(StoreQueryParameters) KafkaStreams#store(...)}: + *

+     * <pre> {@code
+     * KafkaStreams streams = ... // some aggregation on value type double
+     * String queryableStoreName = "storeName"; // the store name should be the name of the store as defined by the Materialized instance
+     * ReadOnlyKeyValueStore<K, ValueAndTimestamp<VOut>> localStore = streams.store(queryableStoreName, QueryableStoreTypes.<K, ValueAndTimestamp<VOut>>timestampedKeyValueStore());
+     * K key = "some-key";
+     * ValueAndTimestamp<VOut> aggForKey = localStore.get(key); // key must be local (application state is shared over all running Kafka Streams instances)
+     * }</pre>
                + * For non-local keys, a custom RPC mechanism must be implemented using {@link KafkaStreams#metadataForAllStreamsClients()} to query + * the value of the key on a parallel running instance of your Kafka Streams application. + *

                + * For failure and recovery the store (which always will be of type {@link TimestampedKeyValueStore} -- regardless of what + * is specified in the parameter {@code materialized}) will be backed by an internal changelog topic that will be created in Kafka. + * Therefore, the store name defined by the Materialized instance must be a valid Kafka topic name and cannot + * contain characters other than ASCII alphanumerics, '.', '_' and '-'. + * The changelog topic will be named "${applicationId}-${storeName}-changelog", where "applicationId" is + * user-specified in {@link StreamsConfig} via parameter + * {@link StreamsConfig#APPLICATION_ID_CONFIG APPLICATION_ID_CONFIG}, "storeName" is the provide store name defined + * in {@code Materialized}, and "-changelog" is a fixed suffix. + *

                + * You can retrieve all generated internal topic names via {@link Topology#describe()}. + * + * @param initializer an {@link Initializer} that computes an initial intermediate aggregation + * result. Cannot be {@code null}. + * @param materialized an instance of {@link Materialized} used to materialize a state store. + * Cannot be {@code null}. + * @param named name the processors. Cannot be {@code null}. + * @return a {@link KTable} that contains "update" records with unmodified keys, and values that + * represent the latest (rolling) aggregate for each key + */ + KTable aggregate(final Initializer initializer, + final Named named, + final Materialized> materialized); + + /** + * Create a new {@link TimeWindowedCogroupedKStream} instance that can be used to perform windowed + * aggregations. + * + * @param windows the specification of the aggregation {@link Windows} + * @param the window type + * @return an instance of {@link TimeWindowedCogroupedKStream} + */ + TimeWindowedCogroupedKStream windowedBy(final Windows windows); + + /** + * Create a new {@link TimeWindowedCogroupedKStream} instance that can be used to perform sliding + * windowed aggregations. + * + * @param windows the specification of the aggregation {@link SlidingWindows} + * @return an instance of {@link TimeWindowedCogroupedKStream} + */ + TimeWindowedCogroupedKStream windowedBy(final SlidingWindows windows); + + /** + * Create a new {@link SessionWindowedCogroupedKStream} instance that can be used to perform session + * windowed aggregations. + * + * @param windows the specification of the aggregation {@link SessionWindows} + * @return an instance of {@link SessionWindowedCogroupedKStream} + */ + SessionWindowedCogroupedKStream windowedBy(final SessionWindows windows); + +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/Consumed.java b/streams/src/main/java/org/apache/kafka/streams/kstream/Consumed.java new file mode 100644 index 0000000..423ca60 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/Consumed.java @@ -0,0 +1,230 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.kstream; + +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.Topology; +import org.apache.kafka.streams.processor.TimestampExtractor; + +import java.util.Objects; + +/** + * The {@code Consumed} class is used to define the optional parameters when using {@link StreamsBuilder} to + * build instances of {@link KStream}, {@link KTable}, and {@link GlobalKTable}. + *

                + * For example, you can read a topic as {@link KStream} with a custom timestamp extractor and specify the corresponding + * key and value serdes like: + *

+ * <pre> {@code
+ * StreamsBuilder builder = new StreamsBuilder();
+ * KStream<String, Long> stream = builder.stream(
+ *   "topicName",
+ *   Consumed.with(Serdes.String(), Serdes.Long())
+ *           .withTimestampExtractor(new LogAndSkipOnInvalidTimestamp()));
+ * }</pre>
                + * Similarly, you can read a topic as {@link KTable} with a custom {@code auto.offset.reset} configuration and force a + * state store {@link org.apache.kafka.streams.kstream.Materialized materialization} to access the content via + * interactive queries: + *
+ * <pre> {@code
+ * StreamsBuilder builder = new StreamsBuilder();
+ * KTable table = builder.table(
+ *   "topicName",
+ *   Consumed.with(AutoOffsetReset.LATEST),
+ *   Materialized.as("queryable-store-name"));
+ * }</pre>
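+ * <p>
+ * A further illustrative sketch (the topic name, types, and processor name are assumptions), combining
+ * several options on a single {@code Consumed} instance:
+ * <pre> {@code
+ * StreamsBuilder builder = new StreamsBuilder();
+ * Consumed<String, Long> consumed = Consumed.with(Serdes.String(), Serdes.Long())
+ *     .withOffsetResetPolicy(Topology.AutoOffsetReset.EARLIEST)
+ *     .withName("input-source");
+ * KStream<String, Long> stream = builder.stream("topicName", consumed);
+ * }</pre>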
                + * + * @param type of record key + * @param type of record value + */ +public class Consumed implements NamedOperation> { + + protected Serde keySerde; + protected Serde valueSerde; + protected TimestampExtractor timestampExtractor; + protected Topology.AutoOffsetReset resetPolicy; + protected String processorName; + + private Consumed(final Serde keySerde, + final Serde valueSerde, + final TimestampExtractor timestampExtractor, + final Topology.AutoOffsetReset resetPolicy, + final String processorName) { + this.keySerde = keySerde; + this.valueSerde = valueSerde; + this.timestampExtractor = timestampExtractor; + this.resetPolicy = resetPolicy; + this.processorName = processorName; + } + + /** + * Create an instance of {@link Consumed} from an existing instance. + * @param consumed the instance of {@link Consumed} to copy + */ + protected Consumed(final Consumed consumed) { + this(consumed.keySerde, + consumed.valueSerde, + consumed.timestampExtractor, + consumed.resetPolicy, + consumed.processorName + ); + } + + /** + * Create an instance of {@link Consumed} with the supplied arguments. {@code null} values are acceptable. + * + * @param keySerde the key serde. If {@code null} the default key serde from config will be used + * @param valueSerde the value serde. If {@code null} the default value serde from config will be used + * @param timestampExtractor the timestamp extractor to used. If {@code null} the default timestamp extractor from config will be used + * @param resetPolicy the offset reset policy to be used. If {@code null} the default reset policy from config will be used + * @param key type + * @param value type + * @return a new instance of {@link Consumed} + */ + public static Consumed with(final Serde keySerde, + final Serde valueSerde, + final TimestampExtractor timestampExtractor, + final Topology.AutoOffsetReset resetPolicy) { + return new Consumed<>(keySerde, valueSerde, timestampExtractor, resetPolicy, null); + + } + + /** + * Create an instance of {@link Consumed} with key and value {@link Serde}s. + * + * @param keySerde the key serde. If {@code null} the default key serde from config will be used + * @param valueSerde the value serde. If {@code null} the default value serde from config will be used + * @param key type + * @param value type + * @return a new instance of {@link Consumed} + */ + public static Consumed with(final Serde keySerde, + final Serde valueSerde) { + return new Consumed<>(keySerde, valueSerde, null, null, null); + } + + /** + * Create an instance of {@link Consumed} with a {@link TimestampExtractor}. + * + * @param timestampExtractor the timestamp extractor to used. If {@code null} the default timestamp extractor from config will be used + * @param key type + * @param value type + * @return a new instance of {@link Consumed} + */ + public static Consumed with(final TimestampExtractor timestampExtractor) { + return new Consumed<>(null, null, timestampExtractor, null, null); + } + + /** + * Create an instance of {@link Consumed} with a {@link org.apache.kafka.streams.Topology.AutoOffsetReset Topology.AutoOffsetReset}. + * + * @param resetPolicy the offset reset policy to be used. 
If {@code null} the default reset policy from config will be used + * @param key type + * @param value type + * @return a new instance of {@link Consumed} + */ + public static Consumed with(final Topology.AutoOffsetReset resetPolicy) { + return new Consumed<>(null, null, null, resetPolicy, null); + } + + /** + * Create an instance of {@link Consumed} with provided processor name. + * + * @param processorName the processor name to be used. If {@code null} a default processor name will be generated + * @param key type + * @param value type + * @return a new instance of {@link Consumed} + */ + public static Consumed as(final String processorName) { + return new Consumed<>(null, null, null, null, processorName); + } + + /** + * Configure the instance of {@link Consumed} with a key {@link Serde}. + * + * @param keySerde the key serde. If {@code null}the default key serde from config will be used + * @return this + */ + public Consumed withKeySerde(final Serde keySerde) { + this.keySerde = keySerde; + return this; + } + + /** + * Configure the instance of {@link Consumed} with a value {@link Serde}. + * + * @param valueSerde the value serde. If {@code null} the default value serde from config will be used + * @return this + */ + public Consumed withValueSerde(final Serde valueSerde) { + this.valueSerde = valueSerde; + return this; + } + + /** + * Configure the instance of {@link Consumed} with a {@link TimestampExtractor}. + * + * @param timestampExtractor the timestamp extractor to used. If {@code null} the default timestamp extractor from config will be used + * @return this + */ + public Consumed withTimestampExtractor(final TimestampExtractor timestampExtractor) { + this.timestampExtractor = timestampExtractor; + return this; + } + + /** + * Configure the instance of {@link Consumed} with a {@link org.apache.kafka.streams.Topology.AutoOffsetReset Topology.AutoOffsetReset}. + * + * @param resetPolicy the offset reset policy to be used. If {@code null} the default reset policy from config will be used + * @return this + */ + public Consumed withOffsetResetPolicy(final Topology.AutoOffsetReset resetPolicy) { + this.resetPolicy = resetPolicy; + return this; + } + + /** + * Configure the instance of {@link Consumed} with a processor name. + * + * @param processorName the processor name to be used. If {@code null} a default processor name will be generated + * @return this + */ + @Override + public Consumed withName(final String processorName) { + this.processorName = processorName; + return this; + } + + @Override + public boolean equals(final Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + final Consumed consumed = (Consumed) o; + return Objects.equals(keySerde, consumed.keySerde) && + Objects.equals(valueSerde, consumed.valueSerde) && + Objects.equals(timestampExtractor, consumed.timestampExtractor) && + resetPolicy == consumed.resetPolicy; + } + + @Override + public int hashCode() { + return Objects.hash(keySerde, valueSerde, timestampExtractor, resetPolicy); + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/ForeachAction.java b/streams/src/main/java/org/apache/kafka/streams/kstream/ForeachAction.java new file mode 100644 index 0000000..f20d77c --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/ForeachAction.java @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.kstream; + + +/** + * The {@code ForeachAction} interface for performing an action on a {@link org.apache.kafka.streams.KeyValue key-value + * pair}. + * This is a stateless record-by-record operation, i.e, {@link #apply(Object, Object)} is invoked individually for each + * record of a stream. + * If stateful processing is required, consider using + * {@link KStream#process(org.apache.kafka.streams.processor.api.ProcessorSupplier, String...) KStream#process(...)}. + * + * @param key type + * @param value type + * @see KStream#foreach(ForeachAction) + */ +public interface ForeachAction { + + /** + * Perform an action for each record of a stream. + * + * @param key the key of the record + * @param value the value of the record + */ + void apply(final K key, final V value); +} + + diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/ForeachProcessor.java b/streams/src/main/java/org/apache/kafka/streams/kstream/ForeachProcessor.java new file mode 100644 index 0000000..cccd298 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/ForeachProcessor.java @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.kstream; + +import org.apache.kafka.streams.processor.api.Processor; +import org.apache.kafka.streams.processor.api.Record; + +public class ForeachProcessor implements Processor { + + private final ForeachAction action; + + public ForeachProcessor(final ForeachAction action) { + this.action = action; + } + + @Override + public void process(final Record record) { + action.apply(record.key(), record.value()); + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/GlobalKTable.java b/streams/src/main/java/org/apache/kafka/streams/kstream/GlobalKTable.java new file mode 100644 index 0000000..73efbc3 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/GlobalKTable.java @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.kstream; + +import org.apache.kafka.streams.KafkaStreams; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.state.ReadOnlyKeyValueStore; + +/** + * {@code GlobalKTable} is an abstraction of a changelog stream from a primary-keyed table. + * Each record in this changelog stream is an update on the primary-keyed table with the record key as the primary key. + *

                + * {@code GlobalKTable} can only be used as right-hand side input for {@link KStream stream}-table joins. + *

                + * In contrast to a {@link KTable} that is partitioned over all {@link KafkaStreams} instances, a {@code GlobalKTable} + * is fully replicated per {@link KafkaStreams} instance. + * Every partition of the underlying topic is consumed by each {@code GlobalKTable}, such that the full set of data is + * available in every {@link KafkaStreams} instance. + * This provides the ability to perform joins with {@link KStream} without having to repartition the input stream. + * All joins with the {@code GlobalKTable} require that a {@link KeyValueMapper} is provided that can map from the + * {@link KeyValue} of the left hand side {@link KStream} to the key of the right hand side {@code GlobalKTable}. + *

                + * A {@code GlobalKTable} is created via a {@link StreamsBuilder}. For example: + *

+ * <pre> {@code
+ * builder.globalTable("topic-name", "queryable-store-name");
+ * }</pre>
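+ * <p>
+ * With a current {@link StreamsBuilder} the store name is typically supplied via {@link Materialized}
+ * rather than as a second String argument; a hedged sketch, with the topic name, serdes and store name
+ * chosen only for illustration:
+ * <pre> {@code
+ * GlobalKTable<String, Long> table = builder.globalTable(
+ *     "topic-name",
+ *     Consumed.with(Serdes.String(), Serdes.Long()),
+ *     Materialized.as("queryable-store-name"));
+ * }</pre>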
+ * All {@code GlobalKTable}s are backed by a {@link ReadOnlyKeyValueStore} and are therefore queryable via the
+ * interactive queries API.
+ * For example:
+ *
+ * <pre> {@code
+ * final GlobalKTable globalOne = builder.globalTable("g1", "g1-store");
+ * final GlobalKTable globalTwo = builder.globalTable("g2", "g2-store");
+ * ...
+ * final KafkaStreams streams = ...;
+ * streams.start();
+ * ...
+ * ReadOnlyKeyValueStore view = streams.store("g1-store", QueryableStoreTypes.timestampedKeyValueStore());
+ * view.get(key); // can be done on any key, as all keys are present
+ * }</pre>
                + * Note that in contrast to {@link KTable} a {@code GlobalKTable}'s state holds a full copy of the underlying topic, + * thus all keys can be queried locally. + *
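To make the stream-table join described above concrete, here is a minimal sketch (not part of this patch) of a KStream-GlobalKTable join; the topic names, the String types, and the serdes are illustrative assumptions.

import org.apache.kafka.common.serialization.Serdes;
import org.apache.kafka.streams.StreamsBuilder;
import org.apache.kafka.streams.kstream.Consumed;
import org.apache.kafka.streams.kstream.GlobalKTable;
import org.apache.kafka.streams.kstream.KStream;
import org.apache.kafka.streams.kstream.Produced;

public class GlobalKTableJoinSketch {
    public static void main(final String[] args) {
        final StreamsBuilder builder = new StreamsBuilder();

        // Orders keyed by order id, carrying the customer id as value (illustrative).
        final KStream<String, String> orders =
            builder.stream("orders", Consumed.with(Serdes.String(), Serdes.String()));

        // Customer profiles, fully replicated to every KafkaStreams instance.
        final GlobalKTable<String, String> customers =
            builder.globalTable("customers", Consumed.with(Serdes.String(), Serdes.String()));

        // The KeyValueMapper maps each stream record to the table's key,
        // so no repartitioning of the "orders" stream is required.
        orders.join(customers,
                (orderId, customerId) -> customerId,
                (customerId, customerProfile) -> customerProfile + ":" + customerId)
            .to("enriched-orders", Produced.with(Serdes.String(), Serdes.String()));

        System.out.println(builder.build().describe());
    }
}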

                + * Records from the source topic that have null keys are dropped. + * + * @param Type of primary keys + * @param Type of value changes + * @see KTable + * @see StreamsBuilder#globalTable(String) + * @see KStream#join(GlobalKTable, KeyValueMapper, ValueJoiner) + * @see KStream#leftJoin(GlobalKTable, KeyValueMapper, ValueJoiner) + */ +public interface GlobalKTable { + /** + * Get the name of the local state store that can be used to query this {@code GlobalKTable}. + * + * @return the underlying state store name, or {@code null} if this {@code GlobalKTable} cannot be queried. + */ + String queryableStoreName(); +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/Grouped.java b/streams/src/main/java/org/apache/kafka/streams/kstream/Grouped.java new file mode 100644 index 0000000..44b0740 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/Grouped.java @@ -0,0 +1,156 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.kstream; + +import org.apache.kafka.common.serialization.Serde; + +/** + * The class that is used to capture the key and value {@link Serde}s and set the part of name used for + * repartition topics when performing {@link KStream#groupBy(KeyValueMapper, Grouped)}, {@link + * KStream#groupByKey(Grouped)}, or {@link KTable#groupBy(KeyValueMapper, Grouped)} operations. Note + * that Kafka Streams does not always create repartition topics for grouping operations. + * + * @param the key type + * @param the value type + */ +public class Grouped implements NamedOperation> { + + protected final Serde keySerde; + protected final Serde valueSerde; + protected final String name; + + private Grouped(final String name, + final Serde keySerde, + final Serde valueSerde) { + this.name = name; + this.keySerde = keySerde; + this.valueSerde = valueSerde; + } + + protected Grouped(final Grouped grouped) { + this(grouped.name, grouped.keySerde, grouped.valueSerde); + } + + /** + * Create a {@link Grouped} instance with the provided name used as part of the repartition topic if required. + * + * @param name the name used for a repartition topic if required + * @return a new {@link Grouped} configured with the name + * @see KStream#groupByKey(Grouped) + * @see KStream#groupBy(KeyValueMapper, Grouped) + * @see KTable#groupBy(KeyValueMapper, Grouped) + */ + public static Grouped as(final String name) { + return new Grouped<>(name, null, null); + } + + + /** + * Create a {@link Grouped} instance with the provided keySerde. If {@code null} the default key serde from config will be used. + * + * @param keySerde the Serde used for serializing the key. 
If {@code null} the default key serde from config will be used + * @return a new {@link Grouped} configured with the keySerde + * @see KStream#groupByKey(Grouped) + * @see KStream#groupBy(KeyValueMapper, Grouped) + * @see KTable#groupBy(KeyValueMapper, Grouped) + */ + public static Grouped keySerde(final Serde keySerde) { + return new Grouped<>(null, keySerde, null); + } + + + /** + * Create a {@link Grouped} instance with the provided valueSerde. If {@code null} the default value serde from config will be used. + * + * @param valueSerde the {@link Serde} used for serializing the value. If {@code null} the default value serde from config will be used + * @return a new {@link Grouped} configured with the valueSerde + * @see KStream#groupByKey(Grouped) + * @see KStream#groupBy(KeyValueMapper, Grouped) + * @see KTable#groupBy(KeyValueMapper, Grouped) + */ + public static Grouped valueSerde(final Serde valueSerde) { + return new Grouped<>(null, null, valueSerde); + } + + /** + * Create a {@link Grouped} instance with the provided name, keySerde, and valueSerde. If the keySerde and/or the valueSerde is + * {@code null} the default value for the respective serde from config will be used. + * + * @param name the name used as part of the repartition topic name if required + * @param keySerde the {@link Serde} used for serializing the key. If {@code null} the default key serde from config will be used + * @param valueSerde the {@link Serde} used for serializing the value. If {@code null} the default value serde from config will be used + * @return a new {@link Grouped} configured with the name, keySerde, and valueSerde + * @see KStream#groupByKey(Grouped) + * @see KStream#groupBy(KeyValueMapper, Grouped) + * @see KTable#groupBy(KeyValueMapper, Grouped) + */ + public static Grouped with(final String name, + final Serde keySerde, + final Serde valueSerde) { + return new Grouped<>(name, keySerde, valueSerde); + } + + + /** + * Create a {@link Grouped} instance with the provided keySerde and valueSerde. If the keySerde and/or the valueSerde is + * {@code null} the default value for the respective serde from config will be used. + * + * @param keySerde the {@link Serde} used for serializing the key. If {@code null} the default key serde from config will be used + * @param valueSerde the {@link Serde} used for serializing the value. If {@code null} the default value serde from config will be used + * @return a new {@link Grouped} configured with the keySerde, and valueSerde + * @see KStream#groupByKey(Grouped) + * @see KStream#groupBy(KeyValueMapper, Grouped) + * @see KTable#groupBy(KeyValueMapper, Grouped) + */ + public static Grouped with(final Serde keySerde, + final Serde valueSerde) { + return new Grouped<>(null, keySerde, valueSerde); + } + + /** + * Perform the grouping operation with the name for a repartition topic if required. Note + * that Kafka Streams does not always create repartition topics for grouping operations. + * + * @param name the name used for the processor name and as part of the repartition topic name if required + * @return a new {@link Grouped} instance configured with the name + * */ + @Override + public Grouped withName(final String name) { + return new Grouped<>(name, keySerde, valueSerde); + } + + /** + * Perform the grouping operation using the provided keySerde for serializing the key. + * + * @param keySerde {@link Serde} to use for serializing the key. 
If {@code null} the default key serde from config will be used + * @return a new {@link Grouped} instance configured with the keySerde + */ + public Grouped withKeySerde(final Serde keySerde) { + return new Grouped<>(name, keySerde, valueSerde); + } + + /** + * Perform the grouping operation using the provided valueSerde for serializing the value. + * + * @param valueSerde {@link Serde} to use for serializing the value. If {@code null} the default value serde from config will be used + * @return a new {@link Grouped} instance configured with the valueSerde + */ + public Grouped withValueSerde(final Serde valueSerde) { + return new Grouped<>(name, keySerde, valueSerde); + } + +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/Initializer.java b/streams/src/main/java/org/apache/kafka/streams/kstream/Initializer.java new file mode 100644 index 0000000..1b59c64 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/Initializer.java @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.kstream; + + +/** + * The {@code Initializer} interface for creating an initial value in aggregations. + * {@code Initializer} is used in combination with {@link Aggregator}. + * + * @param aggregate value type + * @see Aggregator + * @see KGroupedStream#aggregate(Initializer, Aggregator) + * @see KGroupedStream#aggregate(Initializer, Aggregator, Materialized) + * @see TimeWindowedKStream#aggregate(Initializer, Aggregator) + * @see TimeWindowedKStream#aggregate(Initializer, Aggregator, Materialized) + * @see SessionWindowedKStream#aggregate(Initializer, Aggregator, Merger) + * @see SessionWindowedKStream#aggregate(Initializer, Aggregator, Merger, Materialized) + */ +public interface Initializer { + + /** + * Return the initial value for an aggregation. + * + * @return the initial value for an aggregation + */ + VA apply(); +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/JoinWindows.java b/streams/src/main/java/org/apache/kafka/streams/kstream/JoinWindows.java new file mode 100644 index 0000000..f26aee5 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/JoinWindows.java @@ -0,0 +1,271 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.kstream; + +import org.apache.kafka.streams.processor.TimestampExtractor; + +import java.time.Duration; +import java.util.Map; +import java.util.Objects; + +import static org.apache.kafka.streams.internals.ApiUtils.prepareMillisCheckFailMsgPrefix; +import static org.apache.kafka.streams.internals.ApiUtils.validateMillisecondDuration; + +/** + * The window specifications used for joins. + *

+ * A {@code JoinWindows} instance defines a maximum time difference for a {@link KStream#join(KStream, ValueJoiner,
+ * JoinWindows) join over two streams} on the same key.
+ * In SQL-style you would express this join as
+ * <pre>{@code
+ *     SELECT * FROM stream1, stream2
+ *     WHERE
+ *       stream1.key = stream2.key
+ *       AND
+ *       stream1.ts - before <= stream2.ts AND stream2.ts <= stream1.ts + after
+ * }</pre>
+ * There are three different window configurations supported:
+ * <ul>
+ *     <li>before = after = time-difference</li>
+ *     <li>before = 0 and after = time-difference</li>
+ *     <li>before = time-difference and after = 0</li>
+ * </ul>
+ * A join is symmetric in the sense that a join specification on the first stream returns the same result record as
+ * a join specification on the second stream with flipped before and after values.
+ * <p>
+ * Both values (before and after) must not result in an "inverse" window, i.e., the upper interval bound cannot be
+ * smaller than the lower interval bound.
+ * <p>
+ * {@code JoinWindows} are sliding windows, thus they are aligned to the actual record timestamps.
+ * This implies that each input record defines its own window, with start and end time being relative to the record's
+ * timestamp.
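As a hedged illustration of these before/after bounds (not part of this patch), the sketch below joins two streams with a symmetric five-minute window and a one-minute grace period; the topic names and serdes are assumptions.

import java.time.Duration;

import org.apache.kafka.common.serialization.Serdes;
import org.apache.kafka.streams.StreamsBuilder;
import org.apache.kafka.streams.kstream.Consumed;
import org.apache.kafka.streams.kstream.JoinWindows;
import org.apache.kafka.streams.kstream.KStream;
import org.apache.kafka.streams.kstream.Produced;
import org.apache.kafka.streams.kstream.StreamJoined;

public class JoinWindowsSketch {
    public static void main(final String[] args) {
        final StreamsBuilder builder = new StreamsBuilder();
        final KStream<String, String> clicks =
            builder.stream("clicks", Consumed.with(Serdes.String(), Serdes.String()));
        final KStream<String, String> views =
            builder.stream("views", Consumed.with(Serdes.String(), Serdes.String()));

        // before = after = 5 minutes; out-of-order records arriving more than
        // 1 minute after a window closes are dropped.
        final JoinWindows symmetric =
            JoinWindows.ofTimeDifferenceAndGrace(Duration.ofMinutes(5), Duration.ofMinutes(1));
        // An asymmetric variant (before = 0, after = 5 minutes) would be:
        // symmetric.before(Duration.ZERO)

        clicks.join(views,
                (click, view) -> click + "/" + view,
                symmetric,
                StreamJoined.with(Serdes.String(), Serdes.String(), Serdes.String()))
            .to("click-view-joins", Produced.with(Serdes.String(), Serdes.String()));

        System.out.println(builder.build().describe());
    }
}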

                + * For time semantics, see {@link TimestampExtractor}. + * + * @see TimeWindows + * @see UnlimitedWindows + * @see SessionWindows + * @see KStream#join(KStream, ValueJoiner, JoinWindows) + * @see KStream#join(KStream, ValueJoiner, JoinWindows, StreamJoined) + * @see KStream#leftJoin(KStream, ValueJoiner, JoinWindows) + * @see KStream#leftJoin(KStream, ValueJoiner, JoinWindows, StreamJoined) + * @see KStream#outerJoin(KStream, ValueJoiner, JoinWindows) + * @see KStream#outerJoin(KStream, ValueJoiner, JoinWindows, StreamJoined) + * @see TimestampExtractor + */ +public class JoinWindows extends Windows { + + /** Maximum time difference for tuples that are before the join tuple. */ + public final long beforeMs; + /** Maximum time difference for tuples that are after the join tuple. */ + public final long afterMs; + + private final long graceMs; + + /** + * Enable left/outer stream-stream join, by not emitting left/outer results eagerly, but only after the grace period passed. + * This flag can only be enabled via ofTimeDifferenceAndGrace or ofTimeDifferenceWithNoGrace. + */ + protected final boolean enableSpuriousResultFix; + + protected JoinWindows(final JoinWindows joinWindows) { + this(joinWindows.beforeMs, joinWindows.afterMs, joinWindows.graceMs, joinWindows.enableSpuriousResultFix); + } + + private JoinWindows(final long beforeMs, + final long afterMs, + final long graceMs, + final boolean enableSpuriousResultFix) { + if (beforeMs + afterMs < 0) { + throw new IllegalArgumentException("Window interval (ie, beforeMs+afterMs) must not be negative."); + } + + if (graceMs < 0) { + throw new IllegalArgumentException("Grace period must not be negative."); + } + + this.afterMs = afterMs; + this.beforeMs = beforeMs; + this.graceMs = graceMs; + this.enableSpuriousResultFix = enableSpuriousResultFix; + } + + /** + * Specifies that records of the same key are joinable if their timestamps are within {@code timeDifference}, + * i.e., the timestamp of a record from the secondary stream is max {@code timeDifference} before or after + * the timestamp of the record from the primary stream. + *

                + * Using this method explicitly sets the grace period to the duration specified by {@code afterWindowEnd}, which + * means that only out-of-order records arriving more than the grace period after the window end will be dropped. + * The window close, after which any incoming records are considered late and will be rejected, is defined as + * {@code windowEnd + afterWindowEnd} + * + * @param timeDifference join window interval + * @param afterWindowEnd The grace period to admit out-of-order events to a window. + * @return A new JoinWindows object with the specified window definition and grace period + * @throws IllegalArgumentException if {@code timeDifference} is negative or can't be represented as {@code long milliseconds} + * if {@code afterWindowEnd} is negative or can't be represented as {@code long milliseconds} + */ + public static JoinWindows ofTimeDifferenceAndGrace(final Duration timeDifference, final Duration afterWindowEnd) { + final String timeDifferenceMsgPrefix = prepareMillisCheckFailMsgPrefix(timeDifference, "timeDifference"); + final long timeDifferenceMs = validateMillisecondDuration(timeDifference, timeDifferenceMsgPrefix); + + final String afterWindowEndMsgPrefix = prepareMillisCheckFailMsgPrefix(afterWindowEnd, "afterWindowEnd"); + final long afterWindowEndMs = validateMillisecondDuration(afterWindowEnd, afterWindowEndMsgPrefix); + + return new JoinWindows(timeDifferenceMs, timeDifferenceMs, afterWindowEndMs, true); + } + + /** + * Specifies that records of the same key are joinable if their timestamps are within {@code timeDifference}, + * i.e., the timestamp of a record from the secondary stream is max {@code timeDifference} before or after + * the timestamp of the record from the primary stream. + *

                + * CAUTION: Using this method implicitly sets the grace period to zero, which means that any out-of-order + * records arriving after the window ends are considered late and will be dropped. + * + * @param timeDifference join window interval + * @return a new JoinWindows object with the window definition and no grace period. Note that this means out-of-order records arriving after the window end will be dropped + * @throws IllegalArgumentException if {@code timeDifference} is negative or can't be represented as {@code long milliseconds} + */ + public static JoinWindows ofTimeDifferenceWithNoGrace(final Duration timeDifference) { + return ofTimeDifferenceAndGrace(timeDifference, Duration.ofMillis(NO_GRACE_PERIOD)); + } + + /** + * Specifies that records of the same key are joinable if their timestamps are within {@code timeDifference}, + * i.e., the timestamp of a record from the secondary stream is max {@code timeDifference} before or after + * the timestamp of the record from the primary stream. + * + * @param timeDifference join window interval + * @return a new JoinWindows object with the window definition with and grace period (default to 24 hours minus {@code timeDifference}) + * @throws IllegalArgumentException if {@code timeDifference} is negative or can't be represented as {@code long milliseconds} + * @deprecated since 3.0. Use {@link #ofTimeDifferenceWithNoGrace(Duration)}} instead + */ + @Deprecated + public static JoinWindows of(final Duration timeDifference) throws IllegalArgumentException { + final String msgPrefix = prepareMillisCheckFailMsgPrefix(timeDifference, "timeDifference"); + final long timeDifferenceMs = validateMillisecondDuration(timeDifference, msgPrefix); + return new JoinWindows(timeDifferenceMs, timeDifferenceMs, Math.max(DEPRECATED_DEFAULT_24_HR_GRACE_PERIOD - timeDifferenceMs * 2, 0), false); + } + + /** + * Changes the start window boundary to {@code timeDifference} but keep the end window boundary as is. + * Thus, records of the same key are joinable if the timestamp of a record from the secondary stream is at most + * {@code timeDifference} earlier than the timestamp of the record from the primary stream. + * {@code timeDifference} can be negative but its absolute value must not be larger than current window "after" + * value (which would result in a negative window size). + * + * @param timeDifference relative window start time + * @throws IllegalArgumentException if the resulting window size is negative or {@code timeDifference} can't be represented as {@code long milliseconds} + */ + public JoinWindows before(final Duration timeDifference) throws IllegalArgumentException { + final String msgPrefix = prepareMillisCheckFailMsgPrefix(timeDifference, "timeDifference"); + final long timeDifferenceMs = validateMillisecondDuration(timeDifference, msgPrefix); + return new JoinWindows(timeDifferenceMs, afterMs, graceMs, enableSpuriousResultFix); + } + + /** + * Changes the end window boundary to {@code timeDifference} but keep the start window boundary as is. + * Thus, records of the same key are joinable if the timestamp of a record from the secondary stream is at most + * {@code timeDifference} later than the timestamp of the record from the primary stream. + * {@code timeDifference} can be negative but its absolute value must not be larger than current window "before" + * value (which would result in a negative window size). 
+ * + * @param timeDifference relative window end time + * @throws IllegalArgumentException if the resulting window size is negative or {@code timeDifference} can't be represented as {@code long milliseconds} + */ + public JoinWindows after(final Duration timeDifference) throws IllegalArgumentException { + final String msgPrefix = prepareMillisCheckFailMsgPrefix(timeDifference, "timeDifference"); + final long timeDifferenceMs = validateMillisecondDuration(timeDifference, msgPrefix); + return new JoinWindows(beforeMs, timeDifferenceMs, graceMs, enableSpuriousResultFix); + } + + /** + * Not supported by {@code JoinWindows}. + * Throws {@link UnsupportedOperationException}. + * + * @throws UnsupportedOperationException at every invocation + */ + @Override + public Map windowsFor(final long timestamp) { + throw new UnsupportedOperationException("windowsFor() is not supported by JoinWindows."); + } + + @Override + public long size() { + return beforeMs + afterMs; + } + + /** + * Reject out-of-order events that are delayed more than {@code afterWindowEnd} + * after the end of its window. + *

                + * Delay is defined as (stream_time - record_timestamp). + * + * @param afterWindowEnd The grace period to admit out-of-order events to a window. + * @return this updated builder + * @throws IllegalArgumentException if the {@code afterWindowEnd} is negative or can't be represented as {@code long milliseconds} + * @throws IllegalStateException if {@link #grace(Duration)} is called after {@link #ofTimeDifferenceAndGrace(Duration, Duration)} or {@link #ofTimeDifferenceWithNoGrace(Duration)} + * @deprecated since 3.0. Use {@link #ofTimeDifferenceAndGrace(Duration, Duration)} instead + */ + @Deprecated + public JoinWindows grace(final Duration afterWindowEnd) throws IllegalArgumentException { + // re-use the enableSpuriousResultFix flag to identify if grace is called after ofTimeDifferenceAndGrace/ofTimeDifferenceWithNoGrace + if (this.enableSpuriousResultFix) { + throw new IllegalStateException( + "Cannot call grace() after setting grace value via ofTimeDifferenceAndGrace or ofTimeDifferenceWithNoGrace."); + } + + final String msgPrefix = prepareMillisCheckFailMsgPrefix(afterWindowEnd, "afterWindowEnd"); + final long afterWindowEndMs = validateMillisecondDuration(afterWindowEnd, msgPrefix); + return new JoinWindows(beforeMs, afterMs, afterWindowEndMs, false); + } + + @Override + public long gracePeriodMs() { + return graceMs; + } + + @Override + public boolean equals(final Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + final JoinWindows that = (JoinWindows) o; + return beforeMs == that.beforeMs && + afterMs == that.afterMs && + graceMs == that.graceMs; + } + + @Override + public int hashCode() { + return Objects.hash(beforeMs, afterMs, graceMs); + } + + @Override + public String toString() { + return "JoinWindows{" + + "beforeMs=" + beforeMs + + ", afterMs=" + afterMs + + ", graceMs=" + graceMs + + '}'; + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/Joined.java b/streams/src/main/java/org/apache/kafka/streams/kstream/Joined.java new file mode 100644 index 0000000..a2793af --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/Joined.java @@ -0,0 +1,206 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.kstream; + +import org.apache.kafka.common.serialization.Serde; + +/** + * The {@code Joined} class represents optional params that can be passed to + * {@link KStream#join(KTable, ValueJoiner, Joined) KStream#join(KTable,...)} and + * {@link KStream#leftJoin(KTable, ValueJoiner) KStream#leftJoin(KTable,...)} operations. 
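A minimal sketch (not part of this patch) of passing these optional parameters to a KStream-KTable join; the topic names, the value types, and the "orders-join" base name are illustrative assumptions.

import org.apache.kafka.common.serialization.Serdes;
import org.apache.kafka.streams.StreamsBuilder;
import org.apache.kafka.streams.kstream.Consumed;
import org.apache.kafka.streams.kstream.Joined;
import org.apache.kafka.streams.kstream.KStream;
import org.apache.kafka.streams.kstream.KTable;
import org.apache.kafka.streams.kstream.Produced;

public class JoinedSketch {
    public static void main(final String[] args) {
        final StreamsBuilder builder = new StreamsBuilder();
        final KStream<String, Long> amounts =
            builder.stream("amounts", Consumed.with(Serdes.String(), Serdes.Long()));
        final KTable<String, String> accounts =
            builder.table("accounts", Consumed.with(Serdes.String(), Serdes.String()));

        // Joined carries the key serde, the stream's value serde, and the table's value serde;
        // "orders-join" is used as the base for naming processors and any internal topics.
        amounts.join(accounts,
                (amount, account) -> account + ":" + amount,
                Joined.with(Serdes.String(), Serdes.Long(), Serdes.String(), "orders-join"))
            .to("enriched-amounts", Produced.with(Serdes.String(), Serdes.String()));

        System.out.println(builder.build().describe());
    }
}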
+ */ +public class Joined implements NamedOperation> { + + protected final Serde keySerde; + protected final Serde valueSerde; + protected final Serde otherValueSerde; + protected final String name; + + private Joined(final Serde keySerde, + final Serde valueSerde, + final Serde otherValueSerde, + final String name) { + this.keySerde = keySerde; + this.valueSerde = valueSerde; + this.otherValueSerde = otherValueSerde; + this.name = name; + } + + protected Joined(final Joined joined) { + this(joined.keySerde, joined.valueSerde, joined.otherValueSerde, joined.name); + } + + /** + * Create an instance of {@code Joined} with key, value, and otherValue {@link Serde} instances. + * {@code null} values are accepted and will be replaced by the default serdes as defined in config. + * + * @param keySerde the key serde to use. If {@code null} the default key serde from config will be used + * @param valueSerde the value serde to use. If {@code null} the default value serde from config will be used + * @param otherValueSerde the otherValue serde to use. If {@code null} the default value serde from config will be used + * @param key type + * @param value type + * @param other value type + * @return new {@code Joined} instance with the provided serdes + */ + public static Joined with(final Serde keySerde, + final Serde valueSerde, + final Serde otherValueSerde) { + return new Joined<>(keySerde, valueSerde, otherValueSerde, null); + } + + /** + * Create an instance of {@code Joined} with key, value, and otherValue {@link Serde} instances. + * {@code null} values are accepted and will be replaced by the default serdes as defined in + * config. + * + * @param keySerde the key serde to use. If {@code null} the default key serde from config will be + * used + * @param valueSerde the value serde to use. If {@code null} the default value serde from config + * will be used + * @param otherValueSerde the otherValue serde to use. If {@code null} the default value serde + * from config will be used + * @param name the name used as the base for naming components of the join including any + * repartition topics + * @param key type + * @param value type + * @param other value type + * @return new {@code Joined} instance with the provided serdes + */ + public static Joined with(final Serde keySerde, + final Serde valueSerde, + final Serde otherValueSerde, + final String name) { + return new Joined<>(keySerde, valueSerde, otherValueSerde, name); + } + + /** + * Create an instance of {@code Joined} with a key {@link Serde}. + * {@code null} values are accepted and will be replaced by the default key serde as defined in config. + * + * @param keySerde the key serde to use. If {@code null} the default key serde from config will be used + * @param key type + * @param value type + * @param other value type + * @return new {@code Joined} instance configured with the keySerde + */ + public static Joined keySerde(final Serde keySerde) { + return new Joined<>(keySerde, null, null, null); + } + + /** + * Create an instance of {@code Joined} with a value {@link Serde}. + * {@code null} values are accepted and will be replaced by the default value serde as defined in config. + * + * @param valueSerde the value serde to use. 
If {@code null} the default value serde from config will be used + * @param key type + * @param value type + * @param other value type + * @return new {@code Joined} instance configured with the valueSerde + */ + public static Joined valueSerde(final Serde valueSerde) { + return new Joined<>(null, valueSerde, null, null); + } + + /** + * Create an instance of {@code Joined} with an other value {@link Serde}. + * {@code null} values are accepted and will be replaced by the default value serde as defined in config. + * + * @param otherValueSerde the otherValue serde to use. If {@code null} the default value serde from config will be used + * @param key type + * @param value type + * @param other value type + * @return new {@code Joined} instance configured with the otherValueSerde + */ + public static Joined otherValueSerde(final Serde otherValueSerde) { + return new Joined<>(null, null, otherValueSerde, null); + } + + /** + * Create an instance of {@code Joined} with base name for all components of the join, this may + * include any repartition topics created to complete the join. + * + * @param name the name used as the base for naming components of the join including any + * repartition topics + * @param key type + * @param value type + * @param other value type + * @return new {@code Joined} instance configured with the name + * + */ + public static Joined as(final String name) { + return new Joined<>(null, null, null, name); + } + + + /** + * Set the key {@link Serde} to be used. Null values are accepted and will be replaced by the default + * key serde as defined in config + * + * @param keySerde the key serde to use. If null the default key serde from config will be used + * @return new {@code Joined} instance configured with the {@code name} + */ + public Joined withKeySerde(final Serde keySerde) { + return new Joined<>(keySerde, valueSerde, otherValueSerde, name); + } + + /** + * Set the value {@link Serde} to be used. Null values are accepted and will be replaced by the default + * value serde as defined in config + * + * @param valueSerde the value serde to use. If null the default value serde from config will be used + * @return new {@code Joined} instance configured with the {@code valueSerde} + */ + public Joined withValueSerde(final Serde valueSerde) { + return new Joined<>(keySerde, valueSerde, otherValueSerde, name); + } + + /** + * Set the otherValue {@link Serde} to be used. Null values are accepted and will be replaced by the default + * value serde as defined in config + * + * @param otherValueSerde the otherValue serde to use. If null the default value serde from config will be used + * @return new {@code Joined} instance configured with the {@code valueSerde} + */ + public Joined withOtherValueSerde(final Serde otherValueSerde) { + return new Joined<>(keySerde, valueSerde, otherValueSerde, name); + } + + /** + * Set the base name used for all components of the join, this may include any repartition topics + * created to complete the join. 
+ * + * @param name the name used as the base for naming components of the join including any + * repartition topics + * @return new {@code Joined} instance configured with the {@code name} + */ + @Override + public Joined withName(final String name) { + return new Joined<>(keySerde, valueSerde, otherValueSerde, name); + } + + public Serde keySerde() { + return keySerde; + } + + public Serde valueSerde() { + return valueSerde; + } + + public Serde otherValueSerde() { + return otherValueSerde; + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/KGroupedStream.java b/streams/src/main/java/org/apache/kafka/streams/kstream/KGroupedStream.java new file mode 100644 index 0000000..072558c --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/KGroupedStream.java @@ -0,0 +1,575 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.kstream; + +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.streams.KafkaStreams; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.StoreQueryParameters; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.Topology; +import org.apache.kafka.streams.state.KeyValueStore; +import org.apache.kafka.streams.state.ReadOnlyKeyValueStore; +import org.apache.kafka.streams.state.TimestampedKeyValueStore; + +/** + * {@code KGroupedStream} is an abstraction of a grouped record stream of {@link KeyValue} pairs. + * It is an intermediate representation of a {@link KStream} in order to apply an aggregation operation on the original + * {@link KStream} records. + *

                + * It is an intermediate representation after a grouping of a {@link KStream} before an aggregation is applied to the + * new partitions resulting in a {@link KTable}. + *

                + * A {@code KGroupedStream} must be obtained from a {@link KStream} via {@link KStream#groupByKey() groupByKey()} or + * {@link KStream#groupBy(KeyValueMapper) groupBy(...)}. + * + * @param Type of keys + * @param Type of values + * @see KStream + */ +public interface KGroupedStream { + + /** + * Count the number of records in this stream by the grouped key. + * Records with {@code null} key or value are ignored. + * The result is written into a local {@link KeyValueStore} (which is basically an ever-updating materialized view). + * Furthermore, updates to the store are sent downstream into a {@link KTable} changelog stream. + *

                + * Not all updates might get sent downstream, as an internal cache is used to deduplicate consecutive updates to + * the same key. + * The rate of propagated updates depends on your input data rate, the number of distinct keys, the number of + * parallel running Kafka Streams instances, and the {@link StreamsConfig configuration} parameters for + * {@link StreamsConfig#CACHE_MAX_BYTES_BUFFERING_CONFIG cache size}, and + * {@link StreamsConfig#COMMIT_INTERVAL_MS_CONFIG commit interval}. + *

                + * For failure and recovery the store (which always will be of type {@link TimestampedKeyValueStore}) will be backed by + * an internal changelog topic that will be created in Kafka. + * The changelog topic will be named "${applicationId}-${internalStoreName}-changelog", where "applicationId" is + * user-specified in {@link StreamsConfig} via parameter + * {@link StreamsConfig#APPLICATION_ID_CONFIG APPLICATION_ID_CONFIG}, "internalStoreName" is an internal name + * and "-changelog" is a fixed suffix. + * Note that the internal store name may not be queryable through Interactive Queries. + * + * You can retrieve all generated internal topic names via {@link Topology#describe()}. + * + * @return a {@link KTable} that contains "update" records with unmodified keys and {@link Long} values that + * represent the latest (rolling) count (i.e., number of records) for each key + */ + KTable count(); + + /** + * Count the number of records in this stream by the grouped key. + * Records with {@code null} key or value are ignored. + * The result is written into a local {@link KeyValueStore} (which is basically an ever-updating materialized view). + * Furthermore, updates to the store are sent downstream into a {@link KTable} changelog stream. + *

                + * Not all updates might get sent downstream, as an internal cache is used to deduplicate consecutive updates to + * the same key. + * The rate of propagated updates depends on your input data rate, the number of distinct keys, the number of + * parallel running Kafka Streams instances, and the {@link StreamsConfig configuration} parameters for + * {@link StreamsConfig#CACHE_MAX_BYTES_BUFFERING_CONFIG cache size}, and + * {@link StreamsConfig#COMMIT_INTERVAL_MS_CONFIG commit interval}. + *

                + * For failure and recovery the store (which always will be of type {@link TimestampedKeyValueStore}) will be backed by + * an internal changelog topic that will be created in Kafka. + * The changelog topic will be named "${applicationId}-${internalStoreName}-changelog", where "applicationId" is + * user-specified in {@link StreamsConfig} via parameter + * {@link StreamsConfig#APPLICATION_ID_CONFIG APPLICATION_ID_CONFIG}, "internalStoreName" is an internal name + * and "-changelog" is a fixed suffix. + * Note that the internal store name may not be queryable through Interactive Queries. + * + * You can retrieve all generated internal topic names via {@link Topology#describe()}. + * + * @param named a {@link Named} config used to name the processor in the topology + * + * @return a {@link KTable} that contains "update" records with unmodified keys and {@link Long} values that + * represent the latest (rolling) count (i.e., number of records) for each key + */ + KTable count(final Named named); + + /** + * Count the number of records in this stream by the grouped key. + * Records with {@code null} key or value are ignored. + * The result is written into a local {@link KeyValueStore} (which is basically an ever-updating materialized view) + * provided by the given store name in {@code materialized}. + * Furthermore, updates to the store are sent downstream into a {@link KTable} changelog stream. + *

                + * Not all updates might get sent downstream, as an internal cache is used to deduplicate consecutive updates to + * the same key. + * The rate of propagated updates depends on your input data rate, the number of distinct keys, the number of + * parallel running Kafka Streams instances, and the {@link StreamsConfig configuration} parameters for + * {@link StreamsConfig#CACHE_MAX_BYTES_BUFFERING_CONFIG cache size}, and + * {@link StreamsConfig#COMMIT_INTERVAL_MS_CONFIG commit interval}. + *

                + * To query the local {@link ReadOnlyKeyValueStore} it must be obtained via + * {@link KafkaStreams#store(StoreQueryParameters) KafkaStreams#store(...)}. + *

+     * <pre>{@code
+     * KafkaStreams streams = ... // counting words
+     * String queryableStoreName = "storeName"; // the store name should be the name of the store as defined by the Materialized instance
+     * ReadOnlyKeyValueStore<K, ValueAndTimestamp<Long>> localStore = streams.store(queryableStoreName, QueryableStoreTypes.<K, ValueAndTimestamp<Long>>timestampedKeyValueStore());
+     * K key = "some-word";
+     * ValueAndTimestamp<Long> countForWord = localStore.get(key); // key must be local (application state is shared over all running Kafka Streams instances)
+     * }</pre>
                + * For non-local keys, a custom RPC mechanism must be implemented using {@link KafkaStreams#metadataForAllStreamsClients()} to + * query the value of the key on a parallel running instance of your Kafka Streams application. + * + *
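For a concrete picture of how such a counting topology is wired up, here is a minimal sketch (not part of this patch); the topic names, the "by-word" repartition name, and the "word-counts" store name are illustrative assumptions. "word-counts" is the name you would later hand to KafkaStreams#store for interactive queries.

import java.util.Arrays;

import org.apache.kafka.common.serialization.Serdes;
import org.apache.kafka.common.utils.Bytes;
import org.apache.kafka.streams.StreamsBuilder;
import org.apache.kafka.streams.kstream.Consumed;
import org.apache.kafka.streams.kstream.Grouped;
import org.apache.kafka.streams.kstream.KTable;
import org.apache.kafka.streams.kstream.Materialized;
import org.apache.kafka.streams.kstream.Produced;
import org.apache.kafka.streams.state.KeyValueStore;

public class WordCountSketch {
    public static void main(final String[] args) {
        final StreamsBuilder builder = new StreamsBuilder();

        final KTable<String, Long> wordCounts = builder
            .stream("sentences", Consumed.with(Serdes.String(), Serdes.String()))
            .flatMapValues(sentence -> Arrays.asList(sentence.toLowerCase().split("\\W+")))
            // re-keying by word forces a repartition topic, named after "by-word"
            .groupBy((key, word) -> word, Grouped.with("by-word", Serdes.String(), Serdes.String()))
            // "word-counts" is the queryable store name and part of the changelog topic name
            .count(Materialized.<String, Long, KeyValueStore<Bytes, byte[]>>as("word-counts"));

        wordCounts.toStream().to("word-counts-topic", Produced.with(Serdes.String(), Serdes.Long()));
        System.out.println(builder.build().describe());
    }
}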

                + * For failure and recovery the store (which always will be of type {@link TimestampedKeyValueStore} -- regardless of what + * is specified in the parameter {@code materialized}) will be backed by an internal changelog topic that will be created in Kafka. + * Therefore, the store name defined by the Materialized instance must be a valid Kafka topic name and cannot contain characters other than ASCII + * alphanumerics, '.', '_' and '-'. + * The changelog topic will be named "${applicationId}-${storeName}-changelog", where "applicationId" is + * user-specified in {@link StreamsConfig} via parameter + * {@link StreamsConfig#APPLICATION_ID_CONFIG APPLICATION_ID_CONFIG}, "storeName" is the + * provide store name defined in {@code Materialized}, and "-changelog" is a fixed suffix. + * + * You can retrieve all generated internal topic names via {@link Topology#describe()}. + * + * @param materialized an instance of {@link Materialized} used to materialize a state store. Cannot be {@code null}. + * Note: the valueSerde will be automatically set to {@link org.apache.kafka.common.serialization.Serdes#Long() Serdes#Long()} + * if there is no valueSerde provided + * @return a {@link KTable} that contains "update" records with unmodified keys and {@link Long} values that + * represent the latest (rolling) count (i.e., number of records) for each key + */ + KTable count(final Materialized> materialized); + + /** + * Count the number of records in this stream by the grouped key. + * Records with {@code null} key or value are ignored. + * The result is written into a local {@link KeyValueStore} (which is basically an ever-updating materialized view) + * provided by the given store name in {@code materialized}. + * Furthermore, updates to the store are sent downstream into a {@link KTable} changelog stream. + *

                + * Not all updates might get sent downstream, as an internal cache is used to deduplicate consecutive updates to + * the same key. + * The rate of propagated updates depends on your input data rate, the number of distinct keys, the number of + * parallel running Kafka Streams instances, and the {@link StreamsConfig configuration} parameters for + * {@link StreamsConfig#CACHE_MAX_BYTES_BUFFERING_CONFIG cache size}, and + * {@link StreamsConfig#COMMIT_INTERVAL_MS_CONFIG commit interval}. + *

                + * To query the local {@link ReadOnlyKeyValueStore} it must be obtained via + * {@link KafkaStreams#store(StoreQueryParameters) KafkaStreams#store(...)}. + *

+     * <pre>{@code
+     * KafkaStreams streams = ... // counting words
+     * String queryableStoreName = "storeName"; // the store name should be the name of the store as defined by the Materialized instance
+     * ReadOnlyKeyValueStore<K, ValueAndTimestamp<Long>> localStore = streams.store(queryableStoreName, QueryableStoreTypes.<K, ValueAndTimestamp<Long>>timestampedKeyValueStore());
+     * K key = "some-word";
+     * ValueAndTimestamp<Long> countForWord = localStore.get(key); // key must be local (application state is shared over all running Kafka Streams instances)
+     * }</pre>
                + * For non-local keys, a custom RPC mechanism must be implemented using {@link KafkaStreams#metadataForAllStreamsClients()} to + * query the value of the key on a parallel running instance of your Kafka Streams application. + * + *

                + * For failure and recovery the store (which always will be of type {@link TimestampedKeyValueStore} -- regardless of what + * is specified in the parameter {@code materialized}) will be backed by an internal changelog topic that will be created in Kafka. + * Therefore, the store name defined by the Materialized instance must be a valid Kafka topic name and cannot contain characters other than ASCII + * alphanumerics, '.', '_' and '-'. + * The changelog topic will be named "${applicationId}-${storeName}-changelog", where "applicationId" is + * user-specified in {@link StreamsConfig} via parameter + * {@link StreamsConfig#APPLICATION_ID_CONFIG APPLICATION_ID_CONFIG}, "storeName" is the + * provide store name defined in {@code Materialized}, and "-changelog" is a fixed suffix. + * + * You can retrieve all generated internal topic names via {@link Topology#describe()}. + * + * @param named a {@link Named} config used to name the processor in the topology + * @param materialized an instance of {@link Materialized} used to materialize a state store. Cannot be {@code null}. + * Note: the valueSerde will be automatically set to {@link org.apache.kafka.common.serialization.Serdes#Long() Serdes#Long()} + * if there is no valueSerde provided + * @return a {@link KTable} that contains "update" records with unmodified keys and {@link Long} values that + * represent the latest (rolling) count (i.e., number of records) for each key + */ + KTable count(final Named named, + final Materialized> materialized); + + /** + * Combine the values of records in this stream by the grouped key. + * Records with {@code null} key or value are ignored. + * Combining implies that the type of the aggregate result is the same as the type of the input value + * (c.f. {@link #aggregate(Initializer, Aggregator)}). + *

                + * The specified {@link Reducer} is applied for each input record and computes a new aggregate using the current + * aggregate and the record's value. + * If there is no current aggregate the {@link Reducer} is not applied and the new aggregate will be the record's + * value as-is. + * Thus, {@code reduce(Reducer)} can be used to compute aggregate functions like sum, min, or max. + *

                + * Not all updates might get sent downstream, as an internal cache is used to deduplicate consecutive updates to + * the same key. + * The rate of propagated updates depends on your input data rate, the number of distinct keys, the number of + * parallel running Kafka Streams instances, and the {@link StreamsConfig configuration} parameters for + * {@link StreamsConfig#CACHE_MAX_BYTES_BUFFERING_CONFIG cache size}, and + * {@link StreamsConfig#COMMIT_INTERVAL_MS_CONFIG commit interval}. + * + *

                + * For failure and recovery the store (which always will be of type {@link TimestampedKeyValueStore}) will be backed by + * an internal changelog topic that will be created in Kafka. + * The changelog topic will be named "${applicationId}-${internalStoreName}-changelog", where "applicationId" is + * user-specified in {@link StreamsConfig} via parameter + * {@link StreamsConfig#APPLICATION_ID_CONFIG APPLICATION_ID_CONFIG}, "internalStoreName" is an internal name + * and "-changelog" is a fixed suffix. + * Note that the internal store name may not be queryable through Interactive Queries. + * + * You can retrieve all generated internal topic names via {@link Topology#describe()}. + * + * @param reducer a {@link Reducer} that computes a new aggregate result. Cannot be {@code null}. + * @return a {@link KTable} that contains "update" records with unmodified keys, and values that represent the + * latest (rolling) aggregate for each key. If the reduce function returns {@code null}, it is then interpreted as + * deletion for the key, and future messages of the same key coming from upstream operators + * will be handled as newly initialized value. + */ + KTable reduce(final Reducer reducer); + + /** + * Combine the value of records in this stream by the grouped key. + * Records with {@code null} key or value are ignored. + * Combining implies that the type of the aggregate result is the same as the type of the input value + * (c.f. {@link #aggregate(Initializer, Aggregator, Materialized)}). + * The result is written into a local {@link KeyValueStore} (which is basically an ever-updating materialized view) + * provided by the given store name in {@code materialized}. + * Furthermore, updates to the store are sent downstream into a {@link KTable} changelog stream. + *

                + * The specified {@link Reducer} is applied for each input record and computes a new aggregate using the current + * aggregate (first argument) and the record's value (second argument): + *

+     * <pre>{@code
+     * // At the example of a Reducer<Long>
+     * new Reducer<Long>() {
+     *   public Long apply(Long aggValue, Long currValue) {
+     *     return aggValue + currValue;
+     *   }
+     * }
+     * }</pre>
+     * <p>

                + * If there is no current aggregate the {@link Reducer} is not applied and the new aggregate will be the record's + * value as-is. + * Thus, {@code reduce(Reducer, Materialized)} can be used to compute aggregate functions like sum, min, or + * max. + *

                + * Not all updates might get sent downstream, as an internal cache is used to deduplicate consecutive updates to + * the same key. + * The rate of propagated updates depends on your input data rate, the number of distinct keys, the number of + * parallel running Kafka Streams instances, and the {@link StreamsConfig configuration} parameters for + * {@link StreamsConfig#CACHE_MAX_BYTES_BUFFERING_CONFIG cache size}, and + * {@link StreamsConfig#COMMIT_INTERVAL_MS_CONFIG commit interval}. + *

                + * To query the local {@link ReadOnlyKeyValueStore} it must be obtained via + * {@link KafkaStreams#store(StoreQueryParameters) KafkaStreams#store(...)}. + *

+     * <pre>{@code
+     * KafkaStreams streams = ... // compute sum
+     * String queryableStoreName = "storeName"; // the store name should be the name of the store as defined by the Materialized instance
+     * ReadOnlyKeyValueStore<K, ValueAndTimestamp<V>> localStore = streams.store(queryableStoreName, QueryableStoreTypes.<K, ValueAndTimestamp<V>>timestampedKeyValueStore());
+     * K key = "some-key";
+     * ValueAndTimestamp<V> reduceForKey = localStore.get(key); // key must be local (application state is shared over all running Kafka Streams instances)
+     * }</pre>
                + * For non-local keys, a custom RPC mechanism must be implemented using {@link KafkaStreams#metadataForAllStreamsClients()} to + * query the value of the key on a parallel running instance of your Kafka Streams application. + * + *
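A minimal sketch (not part of this patch) of reduce(Reducer, Materialized) computing a running sum per key; the topic name, the "sum-store" name, and the Long value type are illustrative assumptions.

import org.apache.kafka.common.serialization.Serdes;
import org.apache.kafka.common.utils.Bytes;
import org.apache.kafka.streams.StreamsBuilder;
import org.apache.kafka.streams.kstream.Consumed;
import org.apache.kafka.streams.kstream.Grouped;
import org.apache.kafka.streams.kstream.KTable;
import org.apache.kafka.streams.kstream.Materialized;
import org.apache.kafka.streams.kstream.Produced;
import org.apache.kafka.streams.state.KeyValueStore;

public class RunningSumSketch {
    public static void main(final String[] args) {
        final StreamsBuilder builder = new StreamsBuilder();

        final KTable<String, Long> sums = builder
            .stream("purchases", Consumed.with(Serdes.String(), Serdes.Long()))
            .groupByKey(Grouped.with(Serdes.String(), Serdes.Long()))
            // the Reducer combines the current aggregate with the new value; here it adds them
            .reduce(Long::sum,
                    Materialized.<String, Long, KeyValueStore<Bytes, byte[]>>as("sum-store")
                            .withKeySerde(Serdes.String())
                            .withValueSerde(Serdes.Long()));

        sums.toStream().to("running-sums", Produced.with(Serdes.String(), Serdes.Long()));
        System.out.println(builder.build().describe());
    }
}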

                + * For failure and recovery the store (which always will be of type {@link TimestampedKeyValueStore} -- regardless of what + * is specified in the parameter {@code materialized}) will be backed by an internal changelog topic that will be created in Kafka. + * The changelog topic will be named "${applicationId}-${internalStoreName}-changelog", where "applicationId" is + * user-specified in {@link StreamsConfig} via parameter + * {@link StreamsConfig#APPLICATION_ID_CONFIG APPLICATION_ID_CONFIG}, "internalStoreName" is an internal name + * and "-changelog" is a fixed suffix. + * Note that the internal store name may not be queryable through Interactive Queries. + * + * You can retrieve all generated internal topic names via {@link Topology#describe()}. + * + * @param reducer a {@link Reducer} that computes a new aggregate result. Cannot be {@code null}. + * @param materialized an instance of {@link Materialized} used to materialize a state store. Cannot be {@code null}. + * @return a {@link KTable} that contains "update" records with unmodified keys, and values that represent the + * latest (rolling) aggregate for each key + */ + KTable reduce(final Reducer reducer, + final Materialized> materialized); + + + /** + * Combine the value of records in this stream by the grouped key. + * Records with {@code null} key or value are ignored. + * Combining implies that the type of the aggregate result is the same as the type of the input value + * (c.f. {@link #aggregate(Initializer, Aggregator, Materialized)}). + * The result is written into a local {@link KeyValueStore} (which is basically an ever-updating materialized view) + * provided by the given store name in {@code materialized}. + * Furthermore, updates to the store are sent downstream into a {@link KTable} changelog stream. + *

                + * The specified {@link Reducer} is applied for each input record and computes a new aggregate using the current + * aggregate (first argument) and the record's value (second argument): + *

+     * <pre>{@code
+     * // At the example of a Reducer<Long>
+     * new Reducer<Long>() {
+     *   public Long apply(Long aggValue, Long currValue) {
+     *     return aggValue + currValue;
+     *   }
+     * }
+     * }</pre>
+     * <p>

                + * If there is no current aggregate the {@link Reducer} is not applied and the new aggregate will be the record's + * value as-is. + * Thus, {@code reduce(Reducer, Materialized)} can be used to compute aggregate functions like sum, min, or + * max. + *

                + * Not all updates might get sent downstream, as an internal cache is used to deduplicate consecutive updates to + * the same key. + * The rate of propagated updates depends on your input data rate, the number of distinct keys, the number of + * parallel running Kafka Streams instances, and the {@link StreamsConfig configuration} parameters for + * {@link StreamsConfig#CACHE_MAX_BYTES_BUFFERING_CONFIG cache size}, and + * {@link StreamsConfig#COMMIT_INTERVAL_MS_CONFIG commit interval}. + *

                + * To query the local {@link ReadOnlyKeyValueStore} it must be obtained via + * {@link KafkaStreams#store(StoreQueryParameters) KafkaStreams#store(...)}. + *

+     * <pre>{@code
+     * KafkaStreams streams = ... // compute sum
+     * String queryableStoreName = "storeName"; // the store name should be the name of the store as defined by the Materialized instance
+     * ReadOnlyKeyValueStore<K, ValueAndTimestamp<V>> localStore = streams.store(queryableStoreName, QueryableStoreTypes.<K, ValueAndTimestamp<V>>timestampedKeyValueStore());
+     * K key = "some-key";
+     * ValueAndTimestamp<V> reduceForKey = localStore.get(key); // key must be local (application state is shared over all running Kafka Streams instances)
+     * }</pre>
                + * For non-local keys, a custom RPC mechanism must be implemented using {@link KafkaStreams#metadataForAllStreamsClients()} to + * query the value of the key on a parallel running instance of your Kafka Streams application. + * + *

                + * For failure and recovery the store (which always will be of type {@link TimestampedKeyValueStore} -- regardless of what + * is specified in the parameter {@code materialized}) will be backed by an internal changelog topic that will be created in Kafka. + * The changelog topic will be named "${applicationId}-${internalStoreName}-changelog", where "applicationId" is + * user-specified in {@link StreamsConfig} via parameter + * {@link StreamsConfig#APPLICATION_ID_CONFIG APPLICATION_ID_CONFIG}, "internalStoreName" is an internal name + * and "-changelog" is a fixed suffix. + * Note that the internal store name may not be queryable through Interactive Queries. + * + * You can retrieve all generated internal topic names via {@link Topology#describe()}. + * + * @param reducer a {@link Reducer} that computes a new aggregate result. Cannot be {@code null}. + * @param named a {@link Named} config used to name the processor in the topology. + * @param materialized an instance of {@link Materialized} used to materialize a state store. Cannot be {@code null}. + * @return a {@link KTable} that contains "update" records with unmodified keys, and values that represent the + * latest (rolling) aggregate for each key. If the reduce function returns {@code null}, it is then interpreted as + * deletion for the key, and future messages of the same key coming from upstream operators + * will be handled as newly initialized value. + */ + KTable reduce(final Reducer reducer, + final Named named, + final Materialized> materialized); + + /** + * Aggregate the values of records in this stream by the grouped key. + * Records with {@code null} key or value are ignored. + * Aggregating is a generalization of {@link #reduce(Reducer) combining via reduce(...)} as it, for example, + * allows the result to have a different type than the input values. + *

                + * The specified {@link Initializer} is applied once directly before the first input record is processed to + * provide an initial intermediate aggregation result that is used to process the first record. + * The specified {@link Aggregator} is applied for each input record and computes a new aggregate using the current + * aggregate (or for the very first record using the intermediate aggregation result provided via the + * {@link Initializer}) and the record's value. + * Thus, {@code aggregate(Initializer, Aggregator)} can be used to compute aggregate functions like + * count (c.f. {@link #count()}). + *
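To illustrate the Initializer/Aggregator pair (not part of this patch), the sketch below sums the character count of all messages per key. It uses the Materialized overload so the Long result serde can be set explicitly; the topic and store names are illustrative assumptions.

import org.apache.kafka.common.serialization.Serdes;
import org.apache.kafka.common.utils.Bytes;
import org.apache.kafka.streams.StreamsBuilder;
import org.apache.kafka.streams.kstream.Consumed;
import org.apache.kafka.streams.kstream.Grouped;
import org.apache.kafka.streams.kstream.KTable;
import org.apache.kafka.streams.kstream.Materialized;
import org.apache.kafka.streams.kstream.Produced;
import org.apache.kafka.streams.state.KeyValueStore;

public class AggregateSketch {
    public static void main(final String[] args) {
        final StreamsBuilder builder = new StreamsBuilder();

        // The aggregate type (Long, total characters) differs from the input value type (String),
        // which is exactly what aggregate() allows and reduce() does not.
        final KTable<String, Long> totalChars = builder
            .stream("messages", Consumed.with(Serdes.String(), Serdes.String()))
            .groupByKey(Grouped.with(Serdes.String(), Serdes.String()))
            .aggregate(
                () -> 0L,                                           // Initializer: seed per key
                (key, message, total) -> total + message.length(),  // Aggregator: fold each record
                Materialized.<String, Long, KeyValueStore<Bytes, byte[]>>as("char-count-store")
                        .withValueSerde(Serdes.Long()));

        totalChars.toStream().to("char-counts", Produced.with(Serdes.String(), Serdes.Long()));
        System.out.println(builder.build().describe());
    }
}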

                + * The default value serde from config will be used for serializing the result. + * If a different serde is required then you should use {@link #aggregate(Initializer, Aggregator, Materialized)}. + *

                + * Not all updates might get sent downstream, as an internal cache is used to deduplicate consecutive updates to + * the same key. + * The rate of propagated updates depends on your input data rate, the number of distinct keys, the number of + * parallel running Kafka Streams instances, and the {@link StreamsConfig configuration} parameters for + * {@link StreamsConfig#CACHE_MAX_BYTES_BUFFERING_CONFIG cache size}, and + * {@link StreamsConfig#COMMIT_INTERVAL_MS_CONFIG commit interval}. + * + *
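The caching and commit-interval behaviour described above is controlled through the two StreamsConfig settings named in that paragraph; the sketch below shows them in a typical Properties block (values are arbitrary examples, not recommendations).

    final Properties props = new Properties();
    props.put(StreamsConfig.APPLICATION_ID_CONFIG, "my-streams-app");
    props.put(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9092");
    props.put(StreamsConfig.CACHE_MAX_BYTES_BUFFERING_CONFIG, 10 * 1024 * 1024L);  // record cache size in bytes
    props.put(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG, 1000L);                     // commit/flush interval in ms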

                + * For failure and recovery the store (which always will be of type {@link TimestampedKeyValueStore}) will be backed by + * an internal changelog topic that will be created in Kafka. + * The changelog topic will be named "${applicationId}-${internalStoreName}-changelog", where "applicationId" is + * user-specified in {@link StreamsConfig} via parameter + * {@link StreamsConfig#APPLICATION_ID_CONFIG APPLICATION_ID_CONFIG}, "internalStoreName" is an internal name + * and "-changelog" is a fixed suffix. + * Note that the internal store name may not be queryable through Interactive Queries. + * + * You can retrieve all generated internal topic names via {@link Topology#describe()}. + * + * @param initializer an {@link Initializer} that computes an initial intermediate aggregation result + * @param aggregator an {@link Aggregator} that computes a new aggregate result + * @param the value type of the resulting {@link KTable} + * @return a {@link KTable} that contains "update" records with unmodified keys, and values that represent the + * latest (rolling) aggregate for each key. If the aggregate function returns {@code null}, it is then interpreted as + * deletion for the key, and future messages of the same key coming from upstream operators + * will be handled as newly initialized value. + */ + KTable aggregate(final Initializer initializer, + final Aggregator aggregator); + + /** + * Aggregate the values of records in this stream by the grouped key. + * Records with {@code null} key or value are ignored. + * Aggregating is a generalization of {@link #reduce(Reducer) combining via reduce(...)} as it, for example, + * allows the result to have a different type than the input values. + * The result is written into a local {@link KeyValueStore} (which is basically an ever-updating materialized view) + * that can be queried by the given store name in {@code materialized}. + * Furthermore, updates to the store are sent downstream into a {@link KTable} changelog stream. + *

                + * The specified {@link Initializer} is applied once directly before the first input record is processed to + * provide an initial intermediate aggregation result that is used to process the first record. + * The specified {@link Aggregator} is applied for each input record and computes a new aggregate using the current + * aggregate (or for the very first record using the intermediate aggregation result provided via the + * {@link Initializer}) and the record's value. + * Thus, {@code aggregate(Initializer, Aggregator, Materialized)} can be used to compute aggregate functions like + * count (c.f. {@link #count()}). + *

                + * Not all updates might get sent downstream, as an internal cache is used to deduplicate consecutive updates to + * the same key. + * The rate of propagated updates depends on your input data rate, the number of distinct keys, the number of + * parallel running Kafka Streams instances, and the {@link StreamsConfig configuration} parameters for + * {@link StreamsConfig#CACHE_MAX_BYTES_BUFFERING_CONFIG cache size}, and + * {@link StreamsConfig#COMMIT_INTERVAL_MS_CONFIG commit interval}. + *

                + * To query the local {@link ReadOnlyKeyValueStore} it must be obtained via + * {@link KafkaStreams#store(StoreQueryParameters) KafkaStreams#store(...)}: + *

+     * <pre>{@code
+     * KafkaStreams streams = ... // some aggregation on value type double
+     * String queryableStoreName = "storeName" // the store name should be the name of the store as defined by the Materialized instance
+     * ReadOnlyKeyValueStore<K, ValueAndTimestamp<VR>> localStore = streams.store(queryableStoreName, QueryableStoreTypes.<K, ValueAndTimestamp<VR>>timestampedKeyValueStore());
+     * K key = "some-key";
+     * ValueAndTimestamp<VR> aggForKey = localStore.get(key); // key must be local (application state is shared over all running Kafka Streams instances)
+     * }</pre>
                + * For non-local keys, a custom RPC mechanism must be implemented using {@link KafkaStreams#metadataForAllStreamsClients()} to + * query the value of the key on a parallel running instance of your Kafka Streams application. + * + *

                + * For failure and recovery the store (which always will be of type {@link TimestampedKeyValueStore} -- regardless of what + * is specified in the parameter {@code materialized}) will be backed by an internal changelog topic that will be created in Kafka. + * Therefore, the store name defined by the Materialized instance must be a valid Kafka topic name and cannot contain characters other than ASCII + * alphanumerics, '.', '_' and '-'. + * The changelog topic will be named "${applicationId}-${storeName}-changelog", where "applicationId" is + * user-specified in {@link StreamsConfig} via parameter + * {@link StreamsConfig#APPLICATION_ID_CONFIG APPLICATION_ID_CONFIG}, "storeName" is the + * provide store name defined in {@code Materialized}, and "-changelog" is a fixed suffix. + * + * You can retrieve all generated internal topic names via {@link Topology#describe()}. + * + * @param initializer an {@link Initializer} that computes an initial intermediate aggregation result + * @param aggregator an {@link Aggregator} that computes a new aggregate result + * @param materialized an instance of {@link Materialized} used to materialize a state store. Cannot be {@code null}. + * @param the value type of the resulting {@link KTable} + * @return a {@link KTable} that contains "update" records with unmodified keys, and values that represent the + * latest (rolling) aggregate for each key + */ + KTable aggregate(final Initializer initializer, + final Aggregator aggregator, + final Materialized> materialized); + + /** + * Aggregate the values of records in this stream by the grouped key. + * Records with {@code null} key or value are ignored. + * Aggregating is a generalization of {@link #reduce(Reducer) combining via reduce(...)} as it, for example, + * allows the result to have a different type than the input values. + * The result is written into a local {@link KeyValueStore} (which is basically an ever-updating materialized view) + * that can be queried by the given store name in {@code materialized}. + * Furthermore, updates to the store are sent downstream into a {@link KTable} changelog stream. + *
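As an illustration of the naming rules above, here is a hedged sketch of the Materialized overload with an explicit, topic-name-safe store name ("aggregated-store" and the grouped stream are assumptions):

    final KTable<String, Long> aggregated = grouped.aggregate(
        () -> 0L,
        (key, value, aggregate) -> aggregate + value.length(),
        Materialized.<String, Long, KeyValueStore<Bytes, byte[]>>as("aggregated-store")  // becomes part of the changelog topic name
            .withValueSerde(Serdes.Long()));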

                + * The specified {@link Initializer} is applied once directly before the first input record is processed to + * provide an initial intermediate aggregation result that is used to process the first record. + * The specified {@link Aggregator} is applied for each input record and computes a new aggregate using the current + * aggregate (or for the very first record using the intermediate aggregation result provided via the + * {@link Initializer}) and the record's value. + * Thus, {@code aggregate(Initializer, Aggregator, Materialized)} can be used to compute aggregate functions like + * count (c.f. {@link #count()}). + *

                + * Not all updates might get sent downstream, as an internal cache is used to deduplicate consecutive updates to + * the same key. + * The rate of propagated updates depends on your input data rate, the number of distinct keys, the number of + * parallel running Kafka Streams instances, and the {@link StreamsConfig configuration} parameters for + * {@link StreamsConfig#CACHE_MAX_BYTES_BUFFERING_CONFIG cache size}, and + * {@link StreamsConfig#COMMIT_INTERVAL_MS_CONFIG commit interval}. + *

                + * To query the local {@link ReadOnlyKeyValueStore} it must be obtained via + * {@link KafkaStreams#store(StoreQueryParameters) KafkaStreams#store(...)}: + *

+     * <pre>{@code
+     * KafkaStreams streams = ... // some aggregation on value type double
+     * String queryableStoreName = "storeName" // the store name should be the name of the store as defined by the Materialized instance
+     * ReadOnlyKeyValueStore<K, ValueAndTimestamp<VR>> localStore = streams.store(queryableStoreName, QueryableStoreTypes.<K, ValueAndTimestamp<VR>>timestampedKeyValueStore());
+     * K key = "some-key";
+     * ValueAndTimestamp<VR> aggForKey = localStore.get(key); // key must be local (application state is shared over all running Kafka Streams instances)
+     * }</pre>
                + * For non-local keys, a custom RPC mechanism must be implemented using {@link KafkaStreams#metadataForAllStreamsClients()} to + * query the value of the key on a parallel running instance of your Kafka Streams application. + * + *

                + * For failure and recovery the store (which always will be of type {@link TimestampedKeyValueStore} -- regardless of what + * is specified in the parameter {@code materialized}) will be backed by an internal changelog topic that will be created in Kafka. + * Therefore, the store name defined by the Materialized instance must be a valid Kafka topic name and cannot contain characters other than ASCII + * alphanumerics, '.', '_' and '-'. + * The changelog topic will be named "${applicationId}-${storeName}-changelog", where "applicationId" is + * user-specified in {@link StreamsConfig} via parameter + * {@link StreamsConfig#APPLICATION_ID_CONFIG APPLICATION_ID_CONFIG}, "storeName" is the + * provide store name defined in {@code Materialized}, and "-changelog" is a fixed suffix. + * + * You can retrieve all generated internal topic names via {@link Topology#describe()}. + * + * @param initializer an {@link Initializer} that computes an initial intermediate aggregation result + * @param aggregator an {@link Aggregator} that computes a new aggregate result + * @param named a {@link Named} config used to name the processor in the topology + * @param materialized an instance of {@link Materialized} used to materialize a state store. Cannot be {@code null}. + * @param the value type of the resulting {@link KTable} + * @return a {@link KTable} that contains "update" records with unmodified keys, and values that represent the + * latest (rolling) aggregate for each key. If the aggregate function returns {@code null}, it is then interpreted as + * deletion for the key, and future messages of the same key coming from upstream operators + * will be handled as newly initialized value. + */ + KTable aggregate(final Initializer initializer, + final Aggregator aggregator, + final Named named, + final Materialized> materialized); + + /** + * Create a new {@link TimeWindowedKStream} instance that can be used to perform windowed aggregations. + * @param windows the specification of the aggregation {@link Windows} + * @param the window type + * @return an instance of {@link TimeWindowedKStream} + */ + TimeWindowedKStream windowedBy(final Windows windows); + + /** + * Create a new {@link TimeWindowedKStream} instance that can be used to perform sliding windowed aggregations. + * @param windows the specification of the aggregation {@link SlidingWindows} + * @return an instance of {@link TimeWindowedKStream} + */ + TimeWindowedKStream windowedBy(final SlidingWindows windows); + + /** + * Create a new {@link SessionWindowedKStream} instance that can be used to perform session windowed aggregations. + * @param windows the specification of the aggregation {@link SessionWindows} + * @return an instance of {@link TimeWindowedKStream} + */ + SessionWindowedKStream windowedBy(final SessionWindows windows); + + /** + * Create a new {@link CogroupedKStream} from the this grouped KStream to allow cogrouping other + * {@code KGroupedStream} to it. + * {@link CogroupedKStream} is an abstraction of multiple grouped record streams of {@link KeyValue} pairs. + * It is an intermediate representation after a grouping of {@link KStream}s, before the + * aggregations are applied to the new partitions resulting in a {@link KTable}. + *
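A brief sketch of the three windowedBy(...) variants declared above; the window sizes are arbitrary and the exact factory methods (TimeWindows.of, SlidingWindows.withTimeDifferenceAndGrace, SessionWindows.with) depend on the Kafka Streams version in use.

    final TimeWindowedKStream<String, String> tumbling =
        grouped.windowedBy(TimeWindows.of(Duration.ofMinutes(5)));
    final TimeWindowedKStream<String, String> sliding =
        grouped.windowedBy(SlidingWindows.withTimeDifferenceAndGrace(Duration.ofMinutes(5), Duration.ofSeconds(30)));
    final SessionWindowedKStream<String, String> sessions =
        grouped.windowedBy(SessionWindows.with(Duration.ofMinutes(5)));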

                + * The specified {@link Aggregator} is applied in the actual {@link CogroupedKStream#aggregate(Initializer) + * aggregation} step for each input record and computes a new aggregate using the current aggregate (or for the very + * first record per key using the initial intermediate aggregation result provided via the {@link Initializer} that + * is passed into {@link CogroupedKStream#aggregate(Initializer)}) and the record's value. + * + * @param aggregator an {@link Aggregator} that computes a new aggregate result + * @param the type of the output values + * @return a {@link CogroupedKStream} + */ + CogroupedKStream cogroup(final Aggregator aggregator); + +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/KGroupedTable.java b/streams/src/main/java/org/apache/kafka/streams/kstream/KGroupedTable.java new file mode 100644 index 0000000..06d12e1 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/KGroupedTable.java @@ -0,0 +1,700 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.kstream; + +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.streams.KafkaStreams; +import org.apache.kafka.streams.StoreQueryParameters; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.Topology; +import org.apache.kafka.streams.state.KeyValueStore; +import org.apache.kafka.streams.state.ReadOnlyKeyValueStore; + +/** + * {@code KGroupedTable} is an abstraction of a re-grouped changelog stream from a primary-keyed table, + * usually on a different grouping key than the original primary key. + *
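The cogroup(...) entry point that closes the KGroupedStream interface above can be combined with further grouped streams before a single aggregation; a hedged sketch (stream names and value types are assumptions):

    final KGroupedStream<String, Long> clicks = clickStream.groupByKey();
    final KGroupedStream<String, Long> purchases = purchaseStream.groupByKey();
    final KTable<String, Long> activity = clicks
        .cogroup((key, value, aggregate) -> aggregate + value)              // aggregator for this stream
        .cogroup(purchases, (key, value, aggregate) -> aggregate + value)   // add another grouped stream
        .aggregate(() -> 0L);                                               // shared initializer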

                + * It is an intermediate representation after a re-grouping of a {@link KTable} before an aggregation is applied to the + * new partitions resulting in a new {@link KTable}. + *

                + * A {@code KGroupedTable} must be obtained from a {@link KTable} via {@link KTable#groupBy(KeyValueMapper) + * groupBy(...)}. + * + * @param Type of keys + * @param Type of values + * @see KTable + */ +public interface KGroupedTable { + + /** + * Count number of records of the original {@link KTable} that got {@link KTable#groupBy(KeyValueMapper) mapped} to + * the same key into a new instance of {@link KTable}. + * Records with {@code null} key are ignored. + * The result is written into a local {@link KeyValueStore} (which is basically an ever-updating materialized view) + * that can be queried using the provided {@code queryableStoreName}. + * Furthermore, updates to the store are sent downstream into a {@link KTable} changelog stream. + *

                + * Not all updates might get sent downstream, as an internal cache is used to deduplicate consecutive updates to + * the same key. + * The rate of propagated updates depends on your input data rate, the number of distinct keys, the number of + * parallel running Kafka Streams instances, and the {@link StreamsConfig configuration} parameters for + * {@link StreamsConfig#CACHE_MAX_BYTES_BUFFERING_CONFIG cache size}, and + * {@link StreamsConfig#COMMIT_INTERVAL_MS_CONFIG commit interval}. + *

                + * To query the local {@link ReadOnlyKeyValueStore} it must be obtained via + * {@link KafkaStreams#store(StoreQueryParameters) KafkaStreams#store(...)}: + *

+     * <pre>{@code
+     * KafkaStreams streams = ... // counting words
+     * ReadOnlyKeyValueStore<K, ValueAndTimestamp<Long>> localStore = streams.store(queryableStoreName, QueryableStoreTypes.<K, ValueAndTimestamp<Long>> timestampedKeyValueStore());
+     * K key = "some-word";
+     * ValueAndTimestamp<Long> countForWord = localStore.get(key); // key must be local (application state is shared over all running Kafka Streams instances)
+     * }</pre>
                + * For non-local keys, a custom RPC mechanism must be implemented using {@link KafkaStreams#metadataForAllStreamsClients()} to + * query the value of the key on a parallel running instance of your Kafka Streams application. + * + *

                + * For failure and recovery the store will be backed by an internal changelog topic that will be created in Kafka. + * Therefore, the store name defined by the Materialized instance must be a valid Kafka topic name and cannot contain characters other than ASCII + * alphanumerics, '.', '_' and '-'. + * The changelog topic will be named "${applicationId}-${storeName}-changelog", where "applicationId" is + * user-specified in {@link StreamsConfig} via parameter + * {@link StreamsConfig#APPLICATION_ID_CONFIG APPLICATION_ID_CONFIG}, "storeName" is the + * provide store name defined in {@code Materialized}, and "-changelog" is a fixed suffix. + * + * You can retrieve all generated internal topic names via {@link Topology#describe()}. + * + * @param materialized the instance of {@link Materialized} used to materialize the state store. Cannot be {@code null} + * @return a {@link KTable} that contains "update" records with unmodified keys and {@link Long} values that + * represent the latest (rolling) count (i.e., number of records) for each key + */ + KTable count(final Materialized> materialized); + + /** + * Count number of records of the original {@link KTable} that got {@link KTable#groupBy(KeyValueMapper) mapped} to + * the same key into a new instance of {@link KTable}. + * Records with {@code null} key are ignored. + * The result is written into a local {@link KeyValueStore} (which is basically an ever-updating materialized view) + * that can be queried using the provided {@code queryableStoreName}. + * Furthermore, updates to the store are sent downstream into a {@link KTable} changelog stream. + *
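A short usage sketch for count(Materialized) on a re-grouped table; the StreamsBuilder named builder, the topic, the table contents, and the store name are illustrative assumptions.

    final KTable<String, String> userRegions = builder.table("user-regions");   // userId -> region
    final KTable<String, Long> usersPerRegion = userRegions
        .groupBy((userId, region) -> KeyValue.pair(region, userId))             // re-key by region
        .count(Materialized.as("users-per-region"));                            // queryable, topic-name-safe store name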

                + * Not all updates might get sent downstream, as an internal cache is used to deduplicate consecutive updates to + * the same key. + * The rate of propagated updates depends on your input data rate, the number of distinct keys, the number of + * parallel running Kafka Streams instances, and the {@link StreamsConfig configuration} parameters for + * {@link StreamsConfig#CACHE_MAX_BYTES_BUFFERING_CONFIG cache size}, and + * {@link StreamsConfig#COMMIT_INTERVAL_MS_CONFIG commit interval}. + *

                + * To query the local {@link ReadOnlyKeyValueStore} it must be obtained via + * {@link KafkaStreams#store(StoreQueryParameters) KafkaStreams#store(...)}: + *

+     * <pre>{@code
+     * KafkaStreams streams = ... // counting words
+     * ReadOnlyKeyValueStore<K, ValueAndTimestamp<Long>> localStore = streams.store(queryableStoreName, QueryableStoreTypes.<K, ValueAndTimestamp<Long>> timestampedKeyValueStore());
+     * K key = "some-word";
+     * ValueAndTimestamp<Long> countForWord = localStore.get(key); // key must be local (application state is shared over all running Kafka Streams instances)
+     * }</pre>
                + * For non-local keys, a custom RPC mechanism must be implemented using {@link KafkaStreams#metadataForAllStreamsClients()} to + * query the value of the key on a parallel running instance of your Kafka Streams application. + * + *

                + * For failure and recovery the store will be backed by an internal changelog topic that will be created in Kafka. + * Therefore, the store name defined by the Materialized instance must be a valid Kafka topic name and cannot contain characters other than ASCII + * alphanumerics, '.', '_' and '-'. + * The changelog topic will be named "${applicationId}-${storeName}-changelog", where "applicationId" is + * user-specified in {@link StreamsConfig} via parameter + * {@link StreamsConfig#APPLICATION_ID_CONFIG APPLICATION_ID_CONFIG}, "storeName" is the + * provide store name defined in {@code Materialized}, and "-changelog" is a fixed suffix. + * + * You can retrieve all generated internal topic names via {@link Topology#describe()}. + * + * @param named the {@link Named} config used to name the processor in the topology + * @param materialized the instance of {@link Materialized} used to materialize the state store. Cannot be {@code null} + * @return a {@link KTable} that contains "update" records with unmodified keys and {@link Long} values that + * represent the latest (rolling) count (i.e., number of records) for each key + */ + KTable count(final Named named, final Materialized> materialized); + + /** + * Count number of records of the original {@link KTable} that got {@link KTable#groupBy(KeyValueMapper) mapped} to + * the same key into a new instance of {@link KTable}. + * Records with {@code null} key are ignored. + * The result is written into a local {@link KeyValueStore} (which is basically an ever-updating materialized view) + * Furthermore, updates to the store are sent downstream into a {@link KTable} changelog stream. + *

                + * Not all updates might get sent downstream, as an internal cache is used to deduplicate consecutive updates to + * the same key. + * The rate of propagated updates depends on your input data rate, the number of distinct keys, the number of + * parallel running Kafka Streams instances, and the {@link StreamsConfig configuration} parameters for + * {@link StreamsConfig#CACHE_MAX_BYTES_BUFFERING_CONFIG cache size}, and + * {@link StreamsConfig#COMMIT_INTERVAL_MS_CONFIG commit interval}. + *

                + * For failure and recovery the store will be backed by an internal changelog topic that will be created in Kafka. + * The changelog topic will be named "${applicationId}-${internalStoreName}-changelog", where "applicationId" is + * user-specified in {@link StreamsConfig} via parameter + * {@link StreamsConfig#APPLICATION_ID_CONFIG APPLICATION_ID_CONFIG}, "internalStoreName" is an internal name + * and "-changelog" is a fixed suffix. + * Note that the internal store name may not be queryable through Interactive Queries. + * + * You can retrieve all generated internal topic names via {@link Topology#describe()}. + * + * @return a {@link KTable} that contains "update" records with unmodified keys and {@link Long} values that + * represent the latest (rolling) count (i.e., number of records) for each key + */ + KTable count(); + + + /** + * Count number of records of the original {@link KTable} that got {@link KTable#groupBy(KeyValueMapper) mapped} to + * the same key into a new instance of {@link KTable}. + * Records with {@code null} key are ignored. + * The result is written into a local {@link KeyValueStore} (which is basically an ever-updating materialized view) + * Furthermore, updates to the store are sent downstream into a {@link KTable} changelog stream. + *

                + * Not all updates might get sent downstream, as an internal cache is used to deduplicate consecutive updates to + * the same key. + * The rate of propagated updates depends on your input data rate, the number of distinct keys, the number of + * parallel running Kafka Streams instances, and the {@link StreamsConfig configuration} parameters for + * {@link StreamsConfig#CACHE_MAX_BYTES_BUFFERING_CONFIG cache size}, and + * {@link StreamsConfig#COMMIT_INTERVAL_MS_CONFIG commit interval}. + *

                + * For failure and recovery the store will be backed by an internal changelog topic that will be created in Kafka. + * The changelog topic will be named "${applicationId}-${internalStoreName}-changelog", where "applicationId" is + * user-specified in {@link StreamsConfig} via parameter + * {@link StreamsConfig#APPLICATION_ID_CONFIG APPLICATION_ID_CONFIG}, "internalStoreName" is an internal name + * and "-changelog" is a fixed suffix. + * Note that the internal store name may not be queryable through Interactive Queries. + * + * You can retrieve all generated internal topic names via {@link Topology#describe()}. + * + * @param named the {@link Named} config used to name the processor in the topology + * @return a {@link KTable} that contains "update" records with unmodified keys and {@link Long} values that + * represent the latest (rolling) count (i.e., number of records) for each key + */ + KTable count(final Named named); + + /** + * Combine the value of records of the original {@link KTable} that got {@link KTable#groupBy(KeyValueMapper) + * mapped} to the same key into a new instance of {@link KTable}. + * Records with {@code null} key are ignored. + * Combining implies that the type of the aggregate result is the same as the type of the input value + * (c.f. {@link #aggregate(Initializer, Aggregator, Aggregator, Materialized)}). + * The result is written into a local {@link KeyValueStore} (which is basically an ever-updating materialized view) + * that can be queried using the provided {@code queryableStoreName}. + * Furthermore, updates to the store are sent downstream into a {@link KTable} changelog stream. + *

                + * Each update to the original {@link KTable} results in a two step update of the result {@link KTable}. + * The specified {@link Reducer adder} is applied for each update record and computes a new aggregate using the + * current aggregate (first argument) and the record's value (second argument) by adding the new record to the + * aggregate. + * The specified {@link Reducer subtractor} is applied for each "replaced" record of the original {@link KTable} + * and computes a new aggregate using the current aggregate (first argument) and the record's value (second + * argument) by "removing" the "replaced" record from the aggregate. + * If there is no current aggregate the {@link Reducer} is not applied and the new aggregate will be the record's + * value as-is. + * Thus, {@code reduce(Reducer, Reducer, String)} can be used to compute aggregate functions like sum. + * For sum, the adder and subtractor would work as follows: + *

+     * <pre>{@code
+     * public class SumAdder implements Reducer<Integer> {
+     *   public Integer apply(Integer currentAgg, Integer newValue) {
+     *     return currentAgg + newValue;
+     *   }
+     * }
+     *
+     * public class SumSubtractor implements Reducer<Integer> {
+     *   public Integer apply(Integer currentAgg, Integer oldValue) {
+     *     return currentAgg - oldValue;
+     *   }
+     * }
+     * }</pre>
                + * Not all updates might get sent downstream, as an internal cache is used to deduplicate consecutive updates to + * the same key. + * The rate of propagated updates depends on your input data rate, the number of distinct keys, the number of + * parallel running Kafka Streams instances, and the {@link StreamsConfig configuration} parameters for + * {@link StreamsConfig#CACHE_MAX_BYTES_BUFFERING_CONFIG cache size}, and + * {@link StreamsConfig#COMMIT_INTERVAL_MS_CONFIG commit interval}. + *

                + * To query the local {@link ReadOnlyKeyValueStore} it must be obtained via + * {@link KafkaStreams#store(StoreQueryParameters) KafkaStreams#store(...)}: + *

+     * <pre>{@code
+     * KafkaStreams streams = ... // counting words
+     * ReadOnlyKeyValueStore<K, ValueAndTimestamp<V>> localStore = streams.store(queryableStoreName, QueryableStoreTypes.<K, ValueAndTimestamp<V>> timestampedKeyValueStore());
+     * K key = "some-word";
+     * ValueAndTimestamp<V> reduceForWord = localStore.get(key); // key must be local (application state is shared over all running Kafka Streams instances)
+     * }</pre>
                + * For non-local keys, a custom RPC mechanism must be implemented using {@link KafkaStreams#metadataForAllStreamsClients()} to + * query the value of the key on a parallel running instance of your Kafka Streams application. + *

                + * For failure and recovery the store will be backed by an internal changelog topic that will be created in Kafka. + * Therefore, the store name defined by the Materialized instance must be a valid Kafka topic name and cannot contain characters other than ASCII + * alphanumerics, '.', '_' and '-'. + * The changelog topic will be named "${applicationId}-${storeName}-changelog", where "applicationId" is + * user-specified in {@link StreamsConfig} via parameter + * {@link StreamsConfig#APPLICATION_ID_CONFIG APPLICATION_ID_CONFIG}, "storeName" is the + * provide store name defined in {@code Materialized}, and "-changelog" is a fixed suffix. + * + * You can retrieve all generated internal topic names via {@link Topology#describe()}. + * + * @param adder a {@link Reducer} that adds a new value to the aggregate result + * @param subtractor a {@link Reducer} that removed an old value from the aggregate result + * @param materialized the instance of {@link Materialized} used to materialize the state store. Cannot be {@code null} + * @return a {@link KTable} that contains "update" records with unmodified keys, and values that represent the + * latest (rolling) aggregate for each key + */ + KTable reduce(final Reducer adder, + final Reducer subtractor, + final Materialized> materialized); + + + /** + * Combine the value of records of the original {@link KTable} that got {@link KTable#groupBy(KeyValueMapper) + * mapped} to the same key into a new instance of {@link KTable}. + * Records with {@code null} key are ignored. + * Combining implies that the type of the aggregate result is the same as the type of the input value + * (c.f. {@link #aggregate(Initializer, Aggregator, Aggregator, Materialized)}). + * The result is written into a local {@link KeyValueStore} (which is basically an ever-updating materialized view) + * that can be queried using the provided {@code queryableStoreName}. + * Furthermore, updates to the store are sent downstream into a {@link KTable} changelog stream. + *

                + * Each update to the original {@link KTable} results in a two step update of the result {@link KTable}. + * The specified {@link Reducer adder} is applied for each update record and computes a new aggregate using the + * current aggregate (first argument) and the record's value (second argument) by adding the new record to the + * aggregate. + * The specified {@link Reducer subtractor} is applied for each "replaced" record of the original {@link KTable} + * and computes a new aggregate using the current aggregate (first argument) and the record's value (second + * argument) by "removing" the "replaced" record from the aggregate. + * If there is no current aggregate the {@link Reducer} is not applied and the new aggregate will be the record's + * value as-is. + * Thus, {@code reduce(Reducer, Reducer, String)} can be used to compute aggregate functions like sum. + * For sum, the adder and subtractor would work as follows: + *

+     * <pre>{@code
+     * public class SumAdder implements Reducer<Integer> {
+     *   public Integer apply(Integer currentAgg, Integer newValue) {
+     *     return currentAgg + newValue;
+     *   }
+     * }
+     *
+     * public class SumSubtractor implements Reducer<Integer> {
+     *   public Integer apply(Integer currentAgg, Integer oldValue) {
+     *     return currentAgg - oldValue;
+     *   }
+     * }
+     * }</pre>
                + * Not all updates might get sent downstream, as an internal cache is used to deduplicate consecutive updates to + * the same key. + * The rate of propagated updates depends on your input data rate, the number of distinct keys, the number of + * parallel running Kafka Streams instances, and the {@link StreamsConfig configuration} parameters for + * {@link StreamsConfig#CACHE_MAX_BYTES_BUFFERING_CONFIG cache size}, and + * {@link StreamsConfig#COMMIT_INTERVAL_MS_CONFIG commit interval}. + *

                + * To query the local {@link ReadOnlyKeyValueStore} it must be obtained via + * {@link KafkaStreams#store(StoreQueryParameters) KafkaStreams#store(...)}: + *

+     * <pre>{@code
+     * KafkaStreams streams = ... // counting words
+     * ReadOnlyKeyValueStore<K, ValueAndTimestamp<V>> localStore = streams.store(queryableStoreName, QueryableStoreTypes.<K, ValueAndTimestamp<V>> timestampedKeyValueStore());
+     * K key = "some-word";
+     * ValueAndTimestamp<V> reduceForWord = localStore.get(key); // key must be local (application state is shared over all running Kafka Streams instances)
+     * }</pre>
                + * For non-local keys, a custom RPC mechanism must be implemented using {@link KafkaStreams#metadataForAllStreamsClients()} to + * query the value of the key on a parallel running instance of your Kafka Streams application. + *

                + * For failure and recovery the store will be backed by an internal changelog topic that will be created in Kafka. + * Therefore, the store name defined by the Materialized instance must be a valid Kafka topic name and cannot contain characters other than ASCII + * alphanumerics, '.', '_' and '-'. + * The changelog topic will be named "${applicationId}-${storeName}-changelog", where "applicationId" is + * user-specified in {@link StreamsConfig} via parameter + * {@link StreamsConfig#APPLICATION_ID_CONFIG APPLICATION_ID_CONFIG}, "storeName" is the + * provide store name defined in {@code Materialized}, and "-changelog" is a fixed suffix. + * + * You can retrieve all generated internal topic names via {@link Topology#describe()}. + * + * @param adder a {@link Reducer} that adds a new value to the aggregate result + * @param subtractor a {@link Reducer} that removed an old value from the aggregate result + * @param named a {@link Named} config used to name the processor in the topology + * @param materialized the instance of {@link Materialized} used to materialize the state store. Cannot be {@code null} + * @return a {@link KTable} that contains "update" records with unmodified keys, and values that represent the + * latest (rolling) aggregate for each key + */ + KTable reduce(final Reducer adder, + final Reducer subtractor, + final Named named, + final Materialized> materialized); + + /** + * Combine the value of records of the original {@link KTable} that got {@link KTable#groupBy(KeyValueMapper) + * mapped} to the same key into a new instance of {@link KTable}. + * Records with {@code null} key are ignored. + * Combining implies that the type of the aggregate result is the same as the type of the input value + * (c.f. {@link #aggregate(Initializer, Aggregator, Aggregator)}). + * The result is written into a local {@link KeyValueStore} (which is basically an ever-updating materialized view) + * Furthermore, updates to the store are sent downstream into a {@link KTable} changelog stream. + *

                + * Each update to the original {@link KTable} results in a two step update of the result {@link KTable}. + * The specified {@link Reducer adder} is applied for each update record and computes a new aggregate using the + * current aggregate and the record's value by adding the new record to the aggregate. + * The specified {@link Reducer subtractor} is applied for each "replaced" record of the original {@link KTable} + * and computes a new aggregate using the current aggregate and the record's value by "removing" the "replaced" + * record from the aggregate. + * If there is no current aggregate the {@link Reducer} is not applied and the new aggregate will be the record's + * value as-is. + * Thus, {@code reduce(Reducer, Reducer)} can be used to compute aggregate functions like sum. + * For sum, the adder and subtractor would work as follows: + *

+     * <pre>{@code
+     * public class SumAdder implements Reducer<Integer> {
+     *   public Integer apply(Integer currentAgg, Integer newValue) {
+     *     return currentAgg + newValue;
+     *   }
+     * }
+     *
+     * public class SumSubtractor implements Reducer<Integer> {
+     *   public Integer apply(Integer currentAgg, Integer oldValue) {
+     *     return currentAgg - oldValue;
+     *   }
+     * }
+     * }</pre>
                + * Not all updates might get sent downstream, as an internal cache is used to deduplicate consecutive updates to + * the same key. + * The rate of propagated updates depends on your input data rate, the number of distinct keys, the number of + * parallel running Kafka Streams instances, and the {@link StreamsConfig configuration} parameters for + * {@link StreamsConfig#CACHE_MAX_BYTES_BUFFERING_CONFIG cache size}, and + * {@link StreamsConfig#COMMIT_INTERVAL_MS_CONFIG commit interval}. + *

                + * For failure and recovery the store will be backed by an internal changelog topic that will be created in Kafka. + * The changelog topic will be named "${applicationId}-${internalStoreName}-changelog", where "applicationId" is + * user-specified in {@link StreamsConfig} via parameter + * {@link StreamsConfig#APPLICATION_ID_CONFIG APPLICATION_ID_CONFIG}, "internalStoreName" is an internal name + * and "-changelog" is a fixed suffix. + * Note that the internal store name may not be queryable through Interactive Queries. + * + * You can retrieve all generated internal topic names via {@link Topology#describe()}. + * + * @param adder a {@link Reducer} that adds a new value to the aggregate result + * @param subtractor a {@link Reducer} that removed an old value from the aggregate result + * @return a {@link KTable} that contains "update" records with unmodified keys, and values that represent the + * latest (rolling) aggregate for each key + */ + KTable reduce(final Reducer adder, + final Reducer subtractor); + + /** + * Aggregate the value of records of the original {@link KTable} that got {@link KTable#groupBy(KeyValueMapper) + * mapped} to the same key into a new instance of {@link KTable}. + * Records with {@code null} key are ignored. + * Aggregating is a generalization of {@link #reduce(Reducer, Reducer, Materialized) combining via reduce(...)} as it, + * for example, allows the result to have a different type than the input values. + * The result is written into a local {@link KeyValueStore} (which is basically an ever-updating materialized view) + * that can be queried using the provided {@code queryableStoreName}. + * Furthermore, updates to the store are sent downstream into a {@link KTable} changelog stream. + *
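The adder/subtractor pair shown in the example above can also be written as lambdas; a sketch assuming a KGroupedTable<String, Integer> named groupedTable:

    final KTable<String, Integer> sum = groupedTable.reduce(
        (currentAgg, newValue) -> currentAgg + newValue,   // adder: the new/updated record joins the aggregate
        (currentAgg, oldValue) -> currentAgg - oldValue);  // subtractor: the replaced record leaves the aggregate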

                + * The specified {@link Initializer} is applied once directly before the first input record is processed to + * provide an initial intermediate aggregation result that is used to process the first record. + * Each update to the original {@link KTable} results in a two step update of the result {@link KTable}. + * The specified {@link Aggregator adder} is applied for each update record and computes a new aggregate using the + * current aggregate (or for the very first record using the intermediate aggregation result provided via the + * {@link Initializer}) and the record's value by adding the new record to the aggregate. + * The specified {@link Aggregator subtractor} is applied for each "replaced" record of the original {@link KTable} + * and computes a new aggregate using the current aggregate and the record's value by "removing" the "replaced" + * record from the aggregate. + * Thus, {@code aggregate(Initializer, Aggregator, Aggregator, Materialized)} can be used to compute aggregate functions + * like sum. + * For sum, the initializer, adder, and subtractor would work as follows: + *

+     * <pre>{@code
+     * // in this example, LongSerde.class must be set as value serde in Materialized#withValueSerde
+     * public class SumInitializer implements Initializer<Long> {
+     *   public Long apply() {
+     *     return 0L;
+     *   }
+     * }
+     *
+     * public class SumAdder implements Aggregator<String, Integer, Long> {
+     *   public Long apply(String key, Integer newValue, Long aggregate) {
+     *     return aggregate + newValue;
+     *   }
+     * }
+     *
+     * public class SumSubtractor implements Aggregator<String, Integer, Long> {
+     *   public Long apply(String key, Integer oldValue, Long aggregate) {
+     *     return aggregate - oldValue;
+     *   }
+     * }
+     * }</pre>
                + * Not all updates might get sent downstream, as an internal cache is used to deduplicate consecutive updates to + * the same key. + * The rate of propagated updates depends on your input data rate, the number of distinct keys, the number of + * parallel running Kafka Streams instances, and the {@link StreamsConfig configuration} parameters for + * {@link StreamsConfig#CACHE_MAX_BYTES_BUFFERING_CONFIG cache size}, and + * {@link StreamsConfig#COMMIT_INTERVAL_MS_CONFIG commit interval}. + *

                + * To query the local {@link ReadOnlyKeyValueStore} it must be obtained via + * {@link KafkaStreams#store(StoreQueryParameters) KafkaStreams#store(...)}: + *

+     * <pre>{@code
+     * KafkaStreams streams = ... // counting words
+     * ReadOnlyKeyValueStore<K, ValueAndTimestamp<VR>> localStore = streams.store(queryableStoreName, QueryableStoreTypes.<K, ValueAndTimestamp<VR>> timestampedKeyValueStore());
+     * K key = "some-word";
+     * ValueAndTimestamp<VR> aggregateForWord = localStore.get(key); // key must be local (application state is shared over all running Kafka Streams instances)
+     * }</pre>
                + * For non-local keys, a custom RPC mechanism must be implemented using {@link KafkaStreams#metadataForAllStreamsClients()} to + * query the value of the key on a parallel running instance of your Kafka Streams application. + *

                + * For failure and recovery the store will be backed by an internal changelog topic that will be created in Kafka. + * Therefore, the store name defined by the Materialized instance must be a valid Kafka topic name and cannot contain characters other than ASCII + * alphanumerics, '.', '_' and '-'. + * The changelog topic will be named "${applicationId}-${storeName}-changelog", where "applicationId" is + * user-specified in {@link StreamsConfig} via parameter + * {@link StreamsConfig#APPLICATION_ID_CONFIG APPLICATION_ID_CONFIG}, "storeName" is the + * provide store name defined in {@code Materialized}, and "-changelog" is a fixed suffix. + * + * You can retrieve all generated internal topic names via {@link Topology#describe()}. + * + * @param initializer an {@link Initializer} that provides an initial aggregate result value + * @param adder an {@link Aggregator} that adds a new record to the aggregate result + * @param subtractor an {@link Aggregator} that removed an old record from the aggregate result + * @param materialized the instance of {@link Materialized} used to materialize the state store. Cannot be {@code null} + * @param the value type of the aggregated {@link KTable} + * @return a {@link KTable} that contains "update" records with unmodified keys, and values that represent the + * latest (rolling) aggregate for each key + */ + KTable aggregate(final Initializer initializer, + final Aggregator adder, + final Aggregator subtractor, + final Materialized> materialized); + + + /** + * Aggregate the value of records of the original {@link KTable} that got {@link KTable#groupBy(KeyValueMapper) + * mapped} to the same key into a new instance of {@link KTable}. + * Records with {@code null} key are ignored. + * Aggregating is a generalization of {@link #reduce(Reducer, Reducer, Materialized) combining via reduce(...)} as it, + * for example, allows the result to have a different type than the input values. + * The result is written into a local {@link KeyValueStore} (which is basically an ever-updating materialized view) + * that can be queried using the provided {@code queryableStoreName}. + * Furthermore, updates to the store are sent downstream into a {@link KTable} changelog stream. + *
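For comparison, a lambda-based sketch of the initializer/adder/subtractor triple materialized with an explicit Long serde, assuming a KGroupedTable<String, Integer> named groupedTable and an example store name "sum-store":

    final KTable<String, Long> sum = groupedTable.aggregate(
        () -> 0L,                                              // initializer
        (key, newValue, aggregate) -> aggregate + newValue,    // adder
        (key, oldValue, aggregate) -> aggregate - oldValue,    // subtractor
        Materialized.<String, Long, KeyValueStore<Bytes, byte[]>>as("sum-store")
            .withValueSerde(Serdes.Long()));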

                + * The specified {@link Initializer} is applied once directly before the first input record is processed to + * provide an initial intermediate aggregation result that is used to process the first record. + * Each update to the original {@link KTable} results in a two step update of the result {@link KTable}. + * The specified {@link Aggregator adder} is applied for each update record and computes a new aggregate using the + * current aggregate (or for the very first record using the intermediate aggregation result provided via the + * {@link Initializer}) and the record's value by adding the new record to the aggregate. + * The specified {@link Aggregator subtractor} is applied for each "replaced" record of the original {@link KTable} + * and computes a new aggregate using the current aggregate and the record's value by "removing" the "replaced" + * record from the aggregate. + * Thus, {@code aggregate(Initializer, Aggregator, Aggregator, Materialized)} can be used to compute aggregate functions + * like sum. + * For sum, the initializer, adder, and subtractor would work as follows: + *

+     * <pre>{@code
+     * // in this example, LongSerde.class must be set as value serde in Materialized#withValueSerde
+     * public class SumInitializer implements Initializer<Long> {
+     *   public Long apply() {
+     *     return 0L;
+     *   }
+     * }
+     *
+     * public class SumAdder implements Aggregator<String, Integer, Long> {
+     *   public Long apply(String key, Integer newValue, Long aggregate) {
+     *     return aggregate + newValue;
+     *   }
+     * }
+     *
+     * public class SumSubtractor implements Aggregator<String, Integer, Long> {
+     *   public Long apply(String key, Integer oldValue, Long aggregate) {
+     *     return aggregate - oldValue;
+     *   }
+     * }
+     * }</pre>
                + * Not all updates might get sent downstream, as an internal cache is used to deduplicate consecutive updates to + * the same key. + * The rate of propagated updates depends on your input data rate, the number of distinct keys, the number of + * parallel running Kafka Streams instances, and the {@link StreamsConfig configuration} parameters for + * {@link StreamsConfig#CACHE_MAX_BYTES_BUFFERING_CONFIG cache size}, and + * {@link StreamsConfig#COMMIT_INTERVAL_MS_CONFIG commit interval}. + *

                + * To query the local {@link ReadOnlyKeyValueStore} it must be obtained via + * {@link KafkaStreams#store(StoreQueryParameters) KafkaStreams#store(...)}: + *

+     * <pre>{@code
+     * KafkaStreams streams = ... // counting words
+     * ReadOnlyKeyValueStore<K, ValueAndTimestamp<VR>> localStore = streams.store(queryableStoreName, QueryableStoreTypes.<K, ValueAndTimestamp<VR>> timestampedKeyValueStore());
+     * K key = "some-word";
+     * ValueAndTimestamp<VR> aggregateForWord = localStore.get(key); // key must be local (application state is shared over all running Kafka Streams instances)
+     * }</pre>
                + * For non-local keys, a custom RPC mechanism must be implemented using {@link KafkaStreams#metadataForAllStreamsClients()} to + * query the value of the key on a parallel running instance of your Kafka Streams application. + *

                + * For failure and recovery the store will be backed by an internal changelog topic that will be created in Kafka. + * Therefore, the store name defined by the Materialized instance must be a valid Kafka topic name and cannot contain characters other than ASCII + * alphanumerics, '.', '_' and '-'. + * The changelog topic will be named "${applicationId}-${storeName}-changelog", where "applicationId" is + * user-specified in {@link StreamsConfig} via parameter + * {@link StreamsConfig#APPLICATION_ID_CONFIG APPLICATION_ID_CONFIG}, "storeName" is the + * provide store name defined in {@code Materialized}, and "-changelog" is a fixed suffix. + * + * You can retrieve all generated internal topic names via {@link Topology#describe()}. + * + * @param initializer an {@link Initializer} that provides an initial aggregate result value + * @param adder an {@link Aggregator} that adds a new record to the aggregate result + * @param subtractor an {@link Aggregator} that removed an old record from the aggregate result + * @param named a {@link Named} config used to name the processor in the topology + * @param materialized the instance of {@link Materialized} used to materialize the state store. Cannot be {@code null} + * @param the value type of the aggregated {@link KTable} + * @return a {@link KTable} that contains "update" records with unmodified keys, and values that represent the + * latest (rolling) aggregate for each key + */ + KTable aggregate(final Initializer initializer, + final Aggregator adder, + final Aggregator subtractor, + final Named named, + final Materialized> materialized); + + /** + * Aggregate the value of records of the original {@link KTable} that got {@link KTable#groupBy(KeyValueMapper) + * mapped} to the same key into a new instance of {@link KTable} using default serializers and deserializers. + * Records with {@code null} key are ignored. + * Aggregating is a generalization of {@link #reduce(Reducer, Reducer) combining via reduce(...)} as it, + * for example, allows the result to have a different type than the input values. + * If the result value type does not match the {@link StreamsConfig#DEFAULT_VALUE_SERDE_CLASS_CONFIG default value + * serde} you should use {@link #aggregate(Initializer, Aggregator, Aggregator, Materialized)}. + * The result is written into a local {@link KeyValueStore} (which is basically an ever-updating materialized view) + * Furthermore, updates to the store are sent downstream into a {@link KTable} changelog stream. + *

                + * The specified {@link Initializer} is applied once directly before the first input record is processed to + * provide an initial intermediate aggregation result that is used to process the first record. + * Each update to the original {@link KTable} results in a two step update of the result {@link KTable}. + * The specified {@link Aggregator adder} is applied for each update record and computes a new aggregate using the + * current aggregate (or for the very first record using the intermediate aggregation result provided via the + * {@link Initializer}) and the record's value by adding the new record to the aggregate. + * The specified {@link Aggregator subtractor} is applied for each "replaced" record of the original {@link KTable} + * and computes a new aggregate using the current aggregate and the record's value by "removing" the "replaced" + * record from the aggregate. + * Thus, {@code aggregate(Initializer, Aggregator, Aggregator, String)} can be used to compute aggregate functions + * like sum. + * For sum, the initializer, adder, and subtractor would work as follows: + *

+     * <pre>{@code
+     * // in this example, LongSerde.class must be set as default value serde in StreamsConfig
+     * public class SumInitializer implements Initializer<Long> {
+     *   public Long apply() {
+     *     return 0L;
+     *   }
+     * }
+     *
+     * public class SumAdder implements Aggregator<String, Integer, Long> {
+     *   public Long apply(String key, Integer newValue, Long aggregate) {
+     *     return aggregate + newValue;
+     *   }
+     * }
+     *
+     * public class SumSubtractor implements Aggregator<String, Integer, Long> {
+     *   public Long apply(String key, Integer oldValue, Long aggregate) {
+     *     return aggregate - oldValue;
+     *   }
+     * }
+     * }</pre>
                + * Not all updates might get sent downstream, as an internal cache is used to deduplicate consecutive updates to + * the same key. + * The rate of propagated updates depends on your input data rate, the number of distinct keys, the number of + * parallel running Kafka Streams instances, and the {@link StreamsConfig configuration} parameters for + * {@link StreamsConfig#CACHE_MAX_BYTES_BUFFERING_CONFIG cache size}, and + * {@link StreamsConfig#COMMIT_INTERVAL_MS_CONFIG commit interval}. + * For failure and recovery the store will be backed by an internal changelog topic that will be created in Kafka. + * The changelog topic will be named "${applicationId}-${internalStoreName}-changelog", where "applicationId" is + * user-specified in {@link StreamsConfig} via parameter + * {@link StreamsConfig#APPLICATION_ID_CONFIG APPLICATION_ID_CONFIG}, "internalStoreName" is an internal name + * and "-changelog" is a fixed suffix. + * Note that the internal store name may not be queryable through Interactive Queries. + * + * You can retrieve all generated internal topic names via {@link Topology#describe()}. + * + * @param initializer a {@link Initializer} that provides an initial aggregate result value + * @param adder a {@link Aggregator} that adds a new record to the aggregate result + * @param subtractor a {@link Aggregator} that removed an old record from the aggregate result + * @param the value type of the aggregated {@link KTable} + * @return a {@link KTable} that contains "update" records with unmodified keys, and values that represent the + * latest (rolling) aggregate for each key + */ + KTable aggregate(final Initializer initializer, + final Aggregator adder, + final Aggregator subtractor); + + + /** + * Aggregate the value of records of the original {@link KTable} that got {@link KTable#groupBy(KeyValueMapper) + * mapped} to the same key into a new instance of {@link KTable} using default serializers and deserializers. + * Records with {@code null} key are ignored. + * Aggregating is a generalization of {@link #reduce(Reducer, Reducer) combining via reduce(...)} as it, + * for example, allows the result to have a different type than the input values. + * If the result value type does not match the {@link StreamsConfig#DEFAULT_VALUE_SERDE_CLASS_CONFIG default value + * serde} you should use {@link #aggregate(Initializer, Aggregator, Aggregator, Materialized)}. + * The result is written into a local {@link KeyValueStore} (which is basically an ever-updating materialized view) + * Furthermore, updates to the store are sent downstream into a {@link KTable} changelog stream. + *

                + * The specified {@link Initializer} is applied once directly before the first input record is processed to + * provide an initial intermediate aggregation result that is used to process the first record. + * Each update to the original {@link KTable} results in a two step update of the result {@link KTable}. + * The specified {@link Aggregator adder} is applied for each update record and computes a new aggregate using the + * current aggregate (or for the very first record using the intermediate aggregation result provided via the + * {@link Initializer}) and the record's value by adding the new record to the aggregate. + * The specified {@link Aggregator subtractor} is applied for each "replaced" record of the original {@link KTable} + * and computes a new aggregate using the current aggregate and the record's value by "removing" the "replaced" + * record from the aggregate. + * Thus, {@code aggregate(Initializer, Aggregator, Aggregator, String)} can be used to compute aggregate functions + * like sum. + * For sum, the initializer, adder, and subtractor would work as follows: + *

                {@code
                +     * // in this example, LongSerde.class must be set as default value serde in StreamsConfig
+     * public class SumInitializer implements Initializer<Long> {
                +     *   public Long apply() {
                +     *     return 0L;
                +     *   }
                +     * }
                +     *
+     * public class SumAdder implements Aggregator<String, Integer, Long> {
                +     *   public Long apply(String key, Integer newValue, Long aggregate) {
                +     *     return aggregate + newValue;
                +     *   }
                +     * }
                +     *
+     * public class SumSubtractor implements Aggregator<String, Integer, Long> {
                +     *   public Long apply(String key, Integer oldValue, Long aggregate) {
                +     *     return aggregate - oldValue;
                +     *   }
                +     * }
                +     * }
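For comparison, here is a compact, editor-supplied sketch of the same rolling sum written with lambdas. The input topic, the table setup, and the groupBy mapper are illustrative, and it still assumes Long is configured as the default value serde, as noted in the example above.

import org.apache.kafka.streams.KeyValue;
import org.apache.kafka.streams.StreamsBuilder;
import org.apache.kafka.streams.kstream.KGroupedTable;
import org.apache.kafka.streams.kstream.KTable;

StreamsBuilder builder = new StreamsBuilder();
KTable<String, Integer> amounts = builder.table("amounts-topic");   // hypothetical input table
KGroupedTable<String, Integer> grouped =
    amounts.groupBy((user, amount) -> KeyValue.pair(user, amount)); // re-group, keeping the key
KTable<String, Long> sums = grouped.aggregate(
    () -> 0L,                                            // initializer
    (key, newValue, aggregate) -> aggregate + newValue,  // adder
    (key, oldValue, aggregate) -> aggregate - oldValue); // subtractor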
                + * Not all updates might get sent downstream, as an internal cache is used to deduplicate consecutive updates to + * the same key. + * The rate of propagated updates depends on your input data rate, the number of distinct keys, the number of + * parallel running Kafka Streams instances, and the {@link StreamsConfig configuration} parameters for + * {@link StreamsConfig#CACHE_MAX_BYTES_BUFFERING_CONFIG cache size}, and + * {@link StreamsConfig#COMMIT_INTERVAL_MS_CONFIG commit interval}. + * For failure and recovery the store will be backed by an internal changelog topic that will be created in Kafka. + * The changelog topic will be named "${applicationId}-${internalStoreName}-changelog", where "applicationId" is + * user-specified in {@link StreamsConfig} via parameter + * {@link StreamsConfig#APPLICATION_ID_CONFIG APPLICATION_ID_CONFIG}, "internalStoreName" is an internal name + * and "-changelog" is a fixed suffix. + * Note that the internal store name may not be queryable through Interactive Queries. + * + * You can retrieve all generated internal topic names via {@link Topology#describe()}. + * + * @param initializer a {@link Initializer} that provides an initial aggregate result value + * @param adder a {@link Aggregator} that adds a new record to the aggregate result + * @param subtractor a {@link Aggregator} that removed an old record from the aggregate result + * @param named a {@link Named} config used to name the processor in the topology + * @param the value type of the aggregated {@link KTable} + * @return a {@link KTable} that contains "update" records with unmodified keys, and values that represent the + * latest (rolling) aggregate for each key + */ + KTable aggregate(final Initializer initializer, + final Aggregator adder, + final Aggregator subtractor, + final Named named); +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/KStream.java b/streams/src/main/java/org/apache/kafka/streams/kstream/KStream.java new file mode 100644 index 0000000..c2ec757 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/KStream.java @@ -0,0 +1,4939 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.kstream; + +import org.apache.kafka.clients.producer.internals.DefaultPartitioner; +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.Topology; +import org.apache.kafka.streams.processor.ConnectedStoreProvider; +import org.apache.kafka.streams.processor.api.Processor; +import org.apache.kafka.streams.processor.api.ProcessorSupplier; +import org.apache.kafka.streams.processor.api.ProcessorContext; +import org.apache.kafka.streams.processor.StreamPartitioner; +import org.apache.kafka.streams.processor.TopicNameExtractor; +import org.apache.kafka.streams.state.KeyValueStore; +import org.apache.kafka.streams.state.StoreBuilder; + +/** + * {@code KStream} is an abstraction of a record stream of {@link KeyValue} pairs, i.e., each record is an + * independent entity/event in the real world. + * For example a user X might buy two items I1 and I2, and thus there might be two records {@code , } + * in the stream. + *

                + * A {@code KStream} is either {@link StreamsBuilder#stream(String) defined from one or multiple Kafka topics} that + * are consumed message by message or the result of a {@code KStream} transformation. + * A {@link KTable} can also be {@link KTable#toStream() converted} into a {@code KStream}. + *

                + * A {@code KStream} can be transformed record by record, joined with another {@code KStream}, {@link KTable}, + * {@link GlobalKTable}, or can be aggregated into a {@link KTable}. + * Kafka Streams DSL can be mixed-and-matched with Processor API (PAPI) (c.f. {@link Topology}) via + * {@link #process(ProcessorSupplier, String...) process(...)}, + * {@link #transform(TransformerSupplier, String...) transform(...)}, and + * {@link #transformValues(ValueTransformerSupplier, String...) transformValues(...)}. + * + * @param Type of keys + * @param Type of values + * @see KTable + * @see KGroupedStream + * @see StreamsBuilder#stream(String) + */ +public interface KStream { + + /** + * Create a new {@code KStream} that consists of all records of this stream which satisfy the given predicate. + * All records that do not satisfy the predicate are dropped. + * This is a stateless record-by-record operation. + * + * @param predicate a filter {@link Predicate} that is applied to each record + * @return a {@code KStream} that contains only those records that satisfy the given predicate + * @see #filterNot(Predicate) + */ + KStream filter(final Predicate predicate); + + /** + * Create a new {@code KStream} that consists of all records of this stream which satisfy the given predicate. + * All records that do not satisfy the predicate are dropped. + * This is a stateless record-by-record operation. + * + * @param predicate a filter {@link Predicate} that is applied to each record + * @param named a {@link Named} config used to name the processor in the topology + * @return a {@code KStream} that contains only those records that satisfy the given predicate + * @see #filterNot(Predicate) + */ + KStream filter(final Predicate predicate, final Named named); + + /** + * Create a new {@code KStream} that consists all records of this stream which do not satisfy the given + * predicate. + * All records that do satisfy the predicate are dropped. + * This is a stateless record-by-record operation. + * + * @param predicate a filter {@link Predicate} that is applied to each record + * @return a {@code KStream} that contains only those records that do not satisfy the given predicate + * @see #filter(Predicate) + */ + KStream filterNot(final Predicate predicate); + + /** + * Create a new {@code KStream} that consists all records of this stream which do not satisfy the given + * predicate. + * All records that do satisfy the predicate are dropped. + * This is a stateless record-by-record operation. + * + * @param predicate a filter {@link Predicate} that is applied to each record + * @param named a {@link Named} config used to name the processor in the topology + * @return a {@code KStream} that contains only those records that do not satisfy the given predicate + * @see #filter(Predicate) + */ + KStream filterNot(final Predicate predicate, final Named named); + + /** + * Set a new key (with possibly new type) for each input record. + * The provided {@link KeyValueMapper} is applied to each input record and computes a new key for it. + * Thus, an input record {@code } can be transformed into an output record {@code }. + * This is a stateless record-by-record operation. + *

                + * For example, you can use this transformation to set a key for a key-less input record {@code } by + * extracting a key from the value within your {@link KeyValueMapper}. The example below computes the new key as the + * length of the value string. + *

                {@code
+     * KStream<Byte[], String> keyLessStream = builder.stream("key-less-topic");
+     * KStream<Integer, String> keyedStream = keyLessStream.selectKey(new KeyValueMapper<Byte[], String, Integer>() {
                +     *     Integer apply(Byte[] key, String value) {
                +     *         return value.length();
                +     *     }
                +     * });
                +     * }
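A shorter, editor-supplied sketch of the same idea using a lambda; the topic name is illustrative.

import org.apache.kafka.streams.StreamsBuilder;
import org.apache.kafka.streams.kstream.KStream;

StreamsBuilder builder = new StreamsBuilder();
KStream<byte[], String> keyLessStream = builder.stream("key-less-topic");
// derive the new key from the value; a later key-based operation may trigger a repartition
KStream<Integer, String> keyedStream = keyLessStream.selectKey((key, value) -> value.length());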
                + * Setting a new key might result in an internal data redistribution if a key based operator (like an aggregation or + * join) is applied to the result {@code KStream}. + * + * @param mapper a {@link KeyValueMapper} that computes a new key for each record + * @param the new key type of the result stream + * @return a {@code KStream} that contains records with new key (possibly of different type) and unmodified value + * @see #map(KeyValueMapper) + * @see #flatMap(KeyValueMapper) + * @see #mapValues(ValueMapper) + * @see #mapValues(ValueMapperWithKey) + * @see #flatMapValues(ValueMapper) + * @see #flatMapValues(ValueMapperWithKey) + */ + KStream selectKey(final KeyValueMapper mapper); + + /** + * Set a new key (with possibly new type) for each input record. + * The provided {@link KeyValueMapper} is applied to each input record and computes a new key for it. + * Thus, an input record {@code } can be transformed into an output record {@code }. + * This is a stateless record-by-record operation. + *

                + * For example, you can use this transformation to set a key for a key-less input record {@code } by + * extracting a key from the value within your {@link KeyValueMapper}. The example below computes the new key as the + * length of the value string. + *

                {@code
+     * KStream<Byte[], String> keyLessStream = builder.stream("key-less-topic");
+     * KStream<Integer, String> keyedStream = keyLessStream.selectKey(new KeyValueMapper<Byte[], String, Integer>() {
                +     *     Integer apply(Byte[] key, String value) {
                +     *         return value.length();
                +     *     }
                +     * });
                +     * }
                + * Setting a new key might result in an internal data redistribution if a key based operator (like an aggregation or + * join) is applied to the result {@code KStream}. + * + * @param mapper a {@link KeyValueMapper} that computes a new key for each record + * @param named a {@link Named} config used to name the processor in the topology + * @param the new key type of the result stream + * @return a {@code KStream} that contains records with new key (possibly of different type) and unmodified value + * @see #map(KeyValueMapper) + * @see #flatMap(KeyValueMapper) + * @see #mapValues(ValueMapper) + * @see #mapValues(ValueMapperWithKey) + * @see #flatMapValues(ValueMapper) + * @see #flatMapValues(ValueMapperWithKey) + */ + KStream selectKey(final KeyValueMapper mapper, + final Named named); + + /** + * Transform each record of the input stream into a new record in the output stream (both key and value type can be + * altered arbitrarily). + * The provided {@link KeyValueMapper} is applied to each input record and computes a new output record. + * Thus, an input record {@code } can be transformed into an output record {@code }. + * This is a stateless record-by-record operation (cf. {@link #transform(TransformerSupplier, String...)} for + * stateful record transformation). + *

+ * The example below normalizes the String key to upper-case letters and counts the number of tokens of the value string. + *

                {@code
+     * KStream<String, String> inputStream = builder.stream("topic");
+     * KStream<String, Integer> outputStream = inputStream.map(new KeyValueMapper<String, String, KeyValue<String, Integer>>() {
+     *     KeyValue<String, Integer> apply(String key, String value) {
                +     *         return new KeyValue<>(key.toUpperCase(), value.split(" ").length);
                +     *     }
                +     * });
                +     * }
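An equivalent, editor-supplied lambda sketch of map(); the topic name is illustrative.

import org.apache.kafka.streams.KeyValue;
import org.apache.kafka.streams.StreamsBuilder;
import org.apache.kafka.streams.kstream.KStream;

StreamsBuilder builder = new StreamsBuilder();
KStream<String, String> inputStream = builder.stream("topic");
// upper-case the key and count the value's tokens; the mapper must never return null
KStream<String, Integer> outputStream =
    inputStream.map((key, value) -> KeyValue.pair(key.toUpperCase(), value.split(" ").length));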
                + * The provided {@link KeyValueMapper} must return a {@link KeyValue} type and must not return {@code null}. + *

                + * Mapping records might result in an internal data redistribution if a key based operator (like an aggregation or + * join) is applied to the result {@code KStream}. (cf. {@link #mapValues(ValueMapper)}) + * + * @param mapper a {@link KeyValueMapper} that computes a new output record + * @param the key type of the result stream + * @param the value type of the result stream + * @return a {@code KStream} that contains records with new key and value (possibly both of different type) + * @see #selectKey(KeyValueMapper) + * @see #flatMap(KeyValueMapper) + * @see #mapValues(ValueMapper) + * @see #mapValues(ValueMapperWithKey) + * @see #flatMapValues(ValueMapper) + * @see #flatMapValues(ValueMapperWithKey) + * @see #transform(TransformerSupplier, String...) + * @see #transformValues(ValueTransformerSupplier, String...) + * @see #transformValues(ValueTransformerWithKeySupplier, String...) + */ + KStream map(final KeyValueMapper> mapper); + + /** + * Transform each record of the input stream into a new record in the output stream (both key and value type can be + * altered arbitrarily). + * The provided {@link KeyValueMapper} is applied to each input record and computes a new output record. + * Thus, an input record {@code } can be transformed into an output record {@code }. + * This is a stateless record-by-record operation (cf. {@link #transform(TransformerSupplier, String...)} for + * stateful record transformation). + *

+ * The example below normalizes the String key to upper-case letters and counts the number of tokens of the value string. + *

                {@code
+     * KStream<String, String> inputStream = builder.stream("topic");
+     * KStream<String, Integer> outputStream = inputStream.map(new KeyValueMapper<String, String, KeyValue<String, Integer>>() {
+     *     KeyValue<String, Integer> apply(String key, String value) {
                +     *         return new KeyValue<>(key.toUpperCase(), value.split(" ").length);
                +     *     }
                +     * });
                +     * }
                + * The provided {@link KeyValueMapper} must return a {@link KeyValue} type and must not return {@code null}. + *

                + * Mapping records might result in an internal data redistribution if a key based operator (like an aggregation or + * join) is applied to the result {@code KStream}. (cf. {@link #mapValues(ValueMapper)}) + * + * @param mapper a {@link KeyValueMapper} that computes a new output record + * @param named a {@link Named} config used to name the processor in the topology + * @param the key type of the result stream + * @param the value type of the result stream + * @return a {@code KStream} that contains records with new key and value (possibly both of different type) + * @see #selectKey(KeyValueMapper) + * @see #flatMap(KeyValueMapper) + * @see #mapValues(ValueMapper) + * @see #mapValues(ValueMapperWithKey) + * @see #flatMapValues(ValueMapper) + * @see #flatMapValues(ValueMapperWithKey) + * @see #transform(TransformerSupplier, String...) + * @see #transformValues(ValueTransformerSupplier, String...) + * @see #transformValues(ValueTransformerWithKeySupplier, String...) + */ + KStream map(final KeyValueMapper> mapper, + final Named named); + + /** + * Transform the value of each input record into a new value (with possible new type) of the output record. + * The provided {@link ValueMapper} is applied to each input record value and computes a new value for it. + * Thus, an input record {@code } can be transformed into an output record {@code }. + * This is a stateless record-by-record operation (cf. + * {@link #transformValues(ValueTransformerSupplier, String...)} for stateful value transformation). + *

+ * The example below counts the number of tokens of the value string. + *

                {@code
+     * KStream<String, String> inputStream = builder.stream("topic");
+     * KStream<String, Integer> outputStream = inputStream.mapValues(new ValueMapper<String, Integer>() {
                +     *     Integer apply(String value) {
                +     *         return value.split(" ").length;
                +     *     }
                +     * });
                +     * }
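An equivalent, editor-supplied lambda sketch of mapValues(); the topic name is illustrative.

import org.apache.kafka.streams.StreamsBuilder;
import org.apache.kafka.streams.kstream.KStream;

StreamsBuilder builder = new StreamsBuilder();
KStream<String, String> inputStream = builder.stream("topic");
// value-only transformation, so no repartitioning is needed downstream
KStream<String, Integer> outputStream = inputStream.mapValues(value -> value.split(" ").length);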
                + * Setting a new value preserves data co-location with respect to the key. + * Thus, no internal data redistribution is required if a key based operator (like an aggregation or join) + * is applied to the result {@code KStream}. (cf. {@link #map(KeyValueMapper)}) + * + * @param mapper a {@link ValueMapper} that computes a new output value + * @param the value type of the result stream + * @return a {@code KStream} that contains records with unmodified key and new values (possibly of different type) + * @see #selectKey(KeyValueMapper) + * @see #map(KeyValueMapper) + * @see #flatMap(KeyValueMapper) + * @see #flatMapValues(ValueMapper) + * @see #flatMapValues(ValueMapperWithKey) + * @see #transform(TransformerSupplier, String...) + * @see #transformValues(ValueTransformerSupplier, String...) + * @see #transformValues(ValueTransformerWithKeySupplier, String...) + */ + KStream mapValues(final ValueMapper mapper); + + /** + * Transform the value of each input record into a new value (with possible new type) of the output record. + * The provided {@link ValueMapper} is applied to each input record value and computes a new value for it. + * Thus, an input record {@code } can be transformed into an output record {@code }. + * This is a stateless record-by-record operation (cf. + * {@link #transformValues(ValueTransformerSupplier, String...)} for stateful value transformation). + *

+ * The example below counts the number of tokens of the value string. + *

                {@code
+     * KStream<String, String> inputStream = builder.stream("topic");
+     * KStream<String, Integer> outputStream = inputStream.mapValues(new ValueMapper<String, Integer>() {
                +     *     Integer apply(String value) {
                +     *         return value.split(" ").length;
                +     *     }
                +     * });
                +     * }
                + * Setting a new value preserves data co-location with respect to the key. + * Thus, no internal data redistribution is required if a key based operator (like an aggregation or join) + * is applied to the result {@code KStream}. (cf. {@link #map(KeyValueMapper)}) + * + * @param mapper a {@link ValueMapper} that computes a new output value + * @param named a {@link Named} config used to name the processor in the topology + * @param the value type of the result stream + * @return a {@code KStream} that contains records with unmodified key and new values (possibly of different type) + * @see #selectKey(KeyValueMapper) + * @see #map(KeyValueMapper) + * @see #flatMap(KeyValueMapper) + * @see #flatMapValues(ValueMapper) + * @see #flatMapValues(ValueMapperWithKey) + * @see #transform(TransformerSupplier, String...) + * @see #transformValues(ValueTransformerSupplier, String...) + * @see #transformValues(ValueTransformerWithKeySupplier, String...) + */ + KStream mapValues(final ValueMapper mapper, + final Named named); + + /** + * Transform the value of each input record into a new value (with possible new type) of the output record. + * The provided {@link ValueMapperWithKey} is applied to each input record value and computes a new value for it. + * Thus, an input record {@code } can be transformed into an output record {@code }. + * This is a stateless record-by-record operation (cf. + * {@link #transformValues(ValueTransformerWithKeySupplier, String...)} for stateful value transformation). + *

                + * The example below counts the number of tokens of key and value strings. + *

                {@code
+     * KStream<String, String> inputStream = builder.stream("topic");
+     * KStream<String, Integer> outputStream = inputStream.mapValues(new ValueMapperWithKey<String, String, Integer>() {
                +     *     Integer apply(String readOnlyKey, String value) {
                +     *         return readOnlyKey.split(" ").length + value.split(" ").length;
                +     *     }
                +     * });
                +     * }
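An equivalent, editor-supplied lambda sketch of mapValues() with read-only key access; the topic name is illustrative.

import org.apache.kafka.streams.StreamsBuilder;
import org.apache.kafka.streams.kstream.KStream;

StreamsBuilder builder = new StreamsBuilder();
KStream<String, String> inputStream = builder.stream("topic");
// the key is read-only here; only the value is replaced
KStream<String, Integer> outputStream = inputStream.mapValues(
    (readOnlyKey, value) -> readOnlyKey.split(" ").length + value.split(" ").length);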
                + * Note that the key is read-only and should not be modified, as this can lead to corrupt partitioning. + * So, setting a new value preserves data co-location with respect to the key. + * Thus, no internal data redistribution is required if a key based operator (like an aggregation or join) + * is applied to the result {@code KStream}. (cf. {@link #map(KeyValueMapper)}) + * + * @param mapper a {@link ValueMapperWithKey} that computes a new output value + * @param the value type of the result stream + * @return a {@code KStream} that contains records with unmodified key and new values (possibly of different type) + * @see #selectKey(KeyValueMapper) + * @see #map(KeyValueMapper) + * @see #flatMap(KeyValueMapper) + * @see #flatMapValues(ValueMapper) + * @see #flatMapValues(ValueMapperWithKey) + * @see #transform(TransformerSupplier, String...) + * @see #transformValues(ValueTransformerSupplier, String...) + * @see #transformValues(ValueTransformerWithKeySupplier, String...) + */ + KStream mapValues(final ValueMapperWithKey mapper); + + /** + * Transform the value of each input record into a new value (with possible new type) of the output record. + * The provided {@link ValueMapperWithKey} is applied to each input record value and computes a new value for it. + * Thus, an input record {@code } can be transformed into an output record {@code }. + * This is a stateless record-by-record operation (cf. + * {@link #transformValues(ValueTransformerWithKeySupplier, String...)} for stateful value transformation). + *

                + * The example below counts the number of tokens of key and value strings. + *

                {@code
+     * KStream<String, String> inputStream = builder.stream("topic");
+     * KStream<String, Integer> outputStream = inputStream.mapValues(new ValueMapperWithKey<String, String, Integer>() {
                +     *     Integer apply(String readOnlyKey, String value) {
                +     *         return readOnlyKey.split(" ").length + value.split(" ").length;
                +     *     }
                +     * });
                +     * }
                + * Note that the key is read-only and should not be modified, as this can lead to corrupt partitioning. + * So, setting a new value preserves data co-location with respect to the key. + * Thus, no internal data redistribution is required if a key based operator (like an aggregation or join) + * is applied to the result {@code KStream}. (cf. {@link #map(KeyValueMapper)}) + * + * @param mapper a {@link ValueMapperWithKey} that computes a new output value + * @param named a {@link Named} config used to name the processor in the topology + * @param the value type of the result stream + * @return a {@code KStream} that contains records with unmodified key and new values (possibly of different type) + * @see #selectKey(KeyValueMapper) + * @see #map(KeyValueMapper) + * @see #flatMap(KeyValueMapper) + * @see #flatMapValues(ValueMapper) + * @see #flatMapValues(ValueMapperWithKey) + * @see #transform(TransformerSupplier, String...) + * @see #transformValues(ValueTransformerSupplier, String...) + * @see #transformValues(ValueTransformerWithKeySupplier, String...) + */ + KStream mapValues(final ValueMapperWithKey mapper, + final Named named); + + /** + * Transform each record of the input stream into zero or more records in the output stream (both key and value type + * can be altered arbitrarily). + * The provided {@link KeyValueMapper} is applied to each input record and computes zero or more output records. + * Thus, an input record {@code } can be transformed into output records {@code , , ...}. + * This is a stateless record-by-record operation (cf. {@link #transform(TransformerSupplier, String...)} for + * stateful record transformation). + *

+ * The example below splits input records {@code } containing sentences as values into their words + * and emits a record {@code } for each word. + *

                {@code
+     * KStream<byte[], String> inputStream = builder.stream("topic");
+     * KStream<String, Integer> outputStream = inputStream.flatMap(
+     *     new KeyValueMapper<byte[], String, Iterable<KeyValue<String, Integer>>>() {
+     *         Iterable<KeyValue<String, Integer>> apply(byte[] key, String value) {
+     *             String[] tokens = value.split(" ");
+     *             List<KeyValue<String, Integer>> result = new ArrayList<>(tokens.length);
                +     *
                +     *             for(String token : tokens) {
                +     *                 result.add(new KeyValue<>(token, 1));
                +     *             }
                +     *
                +     *             return result;
                +     *         }
                +     *     });
                +     * }
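An equivalent, editor-supplied lambda sketch of flatMap(); the topic name is illustrative.

import java.util.ArrayList;
import java.util.List;
import org.apache.kafka.streams.KeyValue;
import org.apache.kafka.streams.StreamsBuilder;
import org.apache.kafka.streams.kstream.KStream;

StreamsBuilder builder = new StreamsBuilder();
KStream<byte[], String> inputStream = builder.stream("topic");
// split each sentence into <word, 1> records; the returned Iterable must never be null
KStream<String, Integer> outputStream = inputStream.flatMap((key, value) -> {
    List<KeyValue<String, Integer>> result = new ArrayList<>();
    for (String token : value.split(" ")) {
        result.add(KeyValue.pair(token, 1));
    }
    return result;
});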
                + * The provided {@link KeyValueMapper} must return an {@link Iterable} (e.g., any {@link java.util.Collection} type) + * and the return value must not be {@code null}. + *

                + * Flat-mapping records might result in an internal data redistribution if a key based operator (like an aggregation + * or join) is applied to the result {@code KStream}. (cf. {@link #flatMapValues(ValueMapper)}) + * + * @param mapper a {@link KeyValueMapper} that computes the new output records + * @param the key type of the result stream + * @param the value type of the result stream + * @return a {@code KStream} that contains more or less records with new key and value (possibly of different type) + * @see #selectKey(KeyValueMapper) + * @see #map(KeyValueMapper) + * @see #mapValues(ValueMapper) + * @see #mapValues(ValueMapperWithKey) + * @see #flatMapValues(ValueMapper) + * @see #flatMapValues(ValueMapperWithKey) + * @see #transform(TransformerSupplier, String...) + * @see #flatTransform(TransformerSupplier, String...) + * @see #transformValues(ValueTransformerSupplier, String...) + * @see #transformValues(ValueTransformerWithKeySupplier, String...) + * @see #flatTransformValues(ValueTransformerSupplier, String...) + * @see #flatTransformValues(ValueTransformerWithKeySupplier, String...) + */ + KStream flatMap(final KeyValueMapper>> mapper); + + /** + * Transform each record of the input stream into zero or more records in the output stream (both key and value type + * can be altered arbitrarily). + * The provided {@link KeyValueMapper} is applied to each input record and computes zero or more output records. + * Thus, an input record {@code } can be transformed into output records {@code , , ...}. + * This is a stateless record-by-record operation (cf. {@link #transform(TransformerSupplier, String...)} for + * stateful record transformation). + *

+ * The example below splits input records {@code } containing sentences as values into their words + * and emits a record {@code } for each word. + *

                {@code
+     * KStream<byte[], String> inputStream = builder.stream("topic");
+     * KStream<String, Integer> outputStream = inputStream.flatMap(
+     *     new KeyValueMapper<byte[], String, Iterable<KeyValue<String, Integer>>>() {
+     *         Iterable<KeyValue<String, Integer>> apply(byte[] key, String value) {
+     *             String[] tokens = value.split(" ");
+     *             List<KeyValue<String, Integer>> result = new ArrayList<>(tokens.length);
                +     *
                +     *             for(String token : tokens) {
                +     *                 result.add(new KeyValue<>(token, 1));
                +     *             }
                +     *
                +     *             return result;
                +     *         }
                +     *     });
                +     * }
                + * The provided {@link KeyValueMapper} must return an {@link Iterable} (e.g., any {@link java.util.Collection} type) + * and the return value must not be {@code null}. + *

                + * Flat-mapping records might result in an internal data redistribution if a key based operator (like an aggregation + * or join) is applied to the result {@code KStream}. (cf. {@link #flatMapValues(ValueMapper)}) + * + * @param mapper a {@link KeyValueMapper} that computes the new output records + * @param named a {@link Named} config used to name the processor in the topology + * @param the key type of the result stream + * @param the value type of the result stream + * @return a {@code KStream} that contains more or less records with new key and value (possibly of different type) + * @see #selectKey(KeyValueMapper) + * @see #map(KeyValueMapper) + * @see #mapValues(ValueMapper) + * @see #mapValues(ValueMapperWithKey) + * @see #flatMapValues(ValueMapper) + * @see #flatMapValues(ValueMapperWithKey) + * @see #transform(TransformerSupplier, String...) + * @see #flatTransform(TransformerSupplier, String...) + * @see #transformValues(ValueTransformerSupplier, String...) + * @see #transformValues(ValueTransformerWithKeySupplier, String...) + * @see #flatTransformValues(ValueTransformerSupplier, String...) + * @see #flatTransformValues(ValueTransformerWithKeySupplier, String...) + */ + KStream flatMap(final KeyValueMapper>> mapper, + final Named named); + + /** + * Create a new {@code KStream} by transforming the value of each record in this stream into zero or more values + * with the same key in the new stream. + * Transform the value of each input record into zero or more records with the same (unmodified) key in the output + * stream (value type can be altered arbitrarily). + * The provided {@link ValueMapper} is applied to each input record and computes zero or more output values. + * Thus, an input record {@code } can be transformed into output records {@code , , ...}. + * This is a stateless record-by-record operation (cf. {@link #transformValues(ValueTransformerSupplier, String...)} + * for stateful value transformation). + *

                + * The example below splits input records {@code } containing sentences as values into their words. + *

                {@code
+     * KStream<byte[], String> inputStream = builder.stream("topic");
+     * KStream<byte[], String> outputStream = inputStream.flatMapValues(new ValueMapper<String, Iterable<String>>() {
+     *     Iterable<String> apply(String value) {
                +     *         return Arrays.asList(value.split(" "));
                +     *     }
                +     * });
                +     * }
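An equivalent, editor-supplied lambda sketch of flatMapValues(); the topic name is illustrative.

import java.util.Arrays;
import org.apache.kafka.streams.StreamsBuilder;
import org.apache.kafka.streams.kstream.KStream;

StreamsBuilder builder = new StreamsBuilder();
KStream<byte[], String> inputStream = builder.stream("topic");
// one output record per word, all keeping the original key
KStream<byte[], String> outputStream = inputStream.flatMapValues(value -> Arrays.asList(value.split(" ")));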
                + * The provided {@link ValueMapper} must return an {@link Iterable} (e.g., any {@link java.util.Collection} type) + * and the return value must not be {@code null}. + *

                + * Splitting a record into multiple records with the same key preserves data co-location with respect to the key. + * Thus, no internal data redistribution is required if a key based operator (like an aggregation or join) + * is applied to the result {@code KStream}. (cf. {@link #flatMap(KeyValueMapper)}) + * + * @param mapper a {@link ValueMapper} the computes the new output values + * @param the value type of the result stream + * @return a {@code KStream} that contains more or less records with unmodified keys and new values of different type + * @see #selectKey(KeyValueMapper) + * @see #map(KeyValueMapper) + * @see #flatMap(KeyValueMapper) + * @see #mapValues(ValueMapper) + * @see #mapValues(ValueMapperWithKey) + * @see #transform(TransformerSupplier, String...) + * @see #flatTransform(TransformerSupplier, String...) + * @see #transformValues(ValueTransformerSupplier, String...) + * @see #transformValues(ValueTransformerWithKeySupplier, String...) + * @see #flatTransformValues(ValueTransformerSupplier, String...) + * @see #flatTransformValues(ValueTransformerWithKeySupplier, String...) + */ + KStream flatMapValues(final ValueMapper> mapper); + + /** + * Create a new {@code KStream} by transforming the value of each record in this stream into zero or more values + * with the same key in the new stream. + * Transform the value of each input record into zero or more records with the same (unmodified) key in the output + * stream (value type can be altered arbitrarily). + * The provided {@link ValueMapper} is applied to each input record and computes zero or more output values. + * Thus, an input record {@code } can be transformed into output records {@code , , ...}. + * This is a stateless record-by-record operation (cf. {@link #transformValues(ValueTransformerSupplier, String...)} + * for stateful value transformation). + *

                + * The example below splits input records {@code } containing sentences as values into their words. + *

                {@code
+     * KStream<byte[], String> inputStream = builder.stream("topic");
+     * KStream<byte[], String> outputStream = inputStream.flatMapValues(new ValueMapper<String, Iterable<String>>() {
+     *     Iterable<String> apply(String value) {
                +     *         return Arrays.asList(value.split(" "));
                +     *     }
                +     * });
                +     * }
                + * The provided {@link ValueMapper} must return an {@link Iterable} (e.g., any {@link java.util.Collection} type) + * and the return value must not be {@code null}. + *

                + * Splitting a record into multiple records with the same key preserves data co-location with respect to the key. + * Thus, no internal data redistribution is required if a key based operator (like an aggregation or join) + * is applied to the result {@code KStream}. (cf. {@link #flatMap(KeyValueMapper)}) + * + * @param mapper a {@link ValueMapper} the computes the new output values + * @param named a {@link Named} config used to name the processor in the topology + * @param the value type of the result stream + * @return a {@code KStream} that contains more or less records with unmodified keys and new values of different type + * @see #selectKey(KeyValueMapper) + * @see #map(KeyValueMapper) + * @see #flatMap(KeyValueMapper) + * @see #mapValues(ValueMapper) + * @see #mapValues(ValueMapperWithKey) + * @see #transform(TransformerSupplier, String...) + * @see #flatTransform(TransformerSupplier, String...) + * @see #transformValues(ValueTransformerSupplier, String...) + * @see #transformValues(ValueTransformerWithKeySupplier, String...) + * @see #flatTransformValues(ValueTransformerSupplier, String...) + * @see #flatTransformValues(ValueTransformerWithKeySupplier, String...) + */ + KStream flatMapValues(final ValueMapper> mapper, + final Named named); + /** + * Create a new {@code KStream} by transforming the value of each record in this stream into zero or more values + * with the same key in the new stream. + * Transform the value of each input record into zero or more records with the same (unmodified) key in the output + * stream (value type can be altered arbitrarily). + * The provided {@link ValueMapperWithKey} is applied to each input record and computes zero or more output values. + * Thus, an input record {@code } can be transformed into output records {@code , , ...}. + * This is a stateless record-by-record operation (cf. {@link #transformValues(ValueTransformerWithKeySupplier, String...)} + * for stateful value transformation). + *

                + * The example below splits input records {@code }, with key=1, containing sentences as values + * into their words. + *

                {@code
+     * KStream<Integer, String> inputStream = builder.stream("topic");
+     * KStream<Integer, String> outputStream = inputStream.flatMapValues(new ValueMapperWithKey<Integer, String, Iterable<String>>() {
+     *     Iterable<String> apply(Integer readOnlyKey, String value) {
                +     *         if(readOnlyKey == 1) {
                +     *             return Arrays.asList(value.split(" "));
                +     *         } else {
                +     *             return Arrays.asList(value);
                +     *         }
                +     *     }
                +     * });
                +     * }
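An equivalent, editor-supplied lambda sketch of flatMapValues() with read-only key access; the topic name is illustrative.

import java.util.Arrays;
import java.util.Collections;
import org.apache.kafka.streams.StreamsBuilder;
import org.apache.kafka.streams.kstream.KStream;

StreamsBuilder builder = new StreamsBuilder();
KStream<Integer, String> inputStream = builder.stream("topic");
// split only records with key 1; pass other values through unchanged
KStream<Integer, String> outputStream = inputStream.flatMapValues((readOnlyKey, value) ->
    readOnlyKey == 1 ? Arrays.asList(value.split(" ")) : Collections.singletonList(value));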
                + * The provided {@link ValueMapperWithKey} must return an {@link Iterable} (e.g., any {@link java.util.Collection} type) + * and the return value must not be {@code null}. + *

                + * Note that the key is read-only and should not be modified, as this can lead to corrupt partitioning. + * So, splitting a record into multiple records with the same key preserves data co-location with respect to the key. + * Thus, no internal data redistribution is required if a key based operator (like an aggregation or join) + * is applied to the result {@code KStream}. (cf. {@link #flatMap(KeyValueMapper)}) + * + * @param mapper a {@link ValueMapperWithKey} the computes the new output values + * @param the value type of the result stream + * @return a {@code KStream} that contains more or less records with unmodified keys and new values of different type + * @see #selectKey(KeyValueMapper) + * @see #map(KeyValueMapper) + * @see #flatMap(KeyValueMapper) + * @see #mapValues(ValueMapper) + * @see #mapValues(ValueMapperWithKey) + * @see #transform(TransformerSupplier, String...) + * @see #flatTransform(TransformerSupplier, String...) + * @see #transformValues(ValueTransformerSupplier, String...) + * @see #transformValues(ValueTransformerWithKeySupplier, String...) + * @see #flatTransformValues(ValueTransformerSupplier, String...) + * @see #flatTransformValues(ValueTransformerWithKeySupplier, String...) + */ + KStream flatMapValues(final ValueMapperWithKey> mapper); + + /** + * Create a new {@code KStream} by transforming the value of each record in this stream into zero or more values + * with the same key in the new stream. + * Transform the value of each input record into zero or more records with the same (unmodified) key in the output + * stream (value type can be altered arbitrarily). + * The provided {@link ValueMapperWithKey} is applied to each input record and computes zero or more output values. + * Thus, an input record {@code } can be transformed into output records {@code , , ...}. + * This is a stateless record-by-record operation (cf. {@link #transformValues(ValueTransformerWithKeySupplier, String...)} + * for stateful value transformation). + *

                + * The example below splits input records {@code }, with key=1, containing sentences as values + * into their words. + *

                {@code
+     * KStream<Integer, String> inputStream = builder.stream("topic");
+     * KStream<Integer, String> outputStream = inputStream.flatMapValues(new ValueMapperWithKey<Integer, String, Iterable<String>>() {
+     *     Iterable<String> apply(Integer readOnlyKey, String value) {
                +     *         if(readOnlyKey == 1) {
                +     *             return Arrays.asList(value.split(" "));
                +     *         } else {
                +     *             return Arrays.asList(value);
                +     *         }
                +     *     }
                +     * });
                +     * }
                + * The provided {@link ValueMapperWithKey} must return an {@link Iterable} (e.g., any {@link java.util.Collection} type) + * and the return value must not be {@code null}. + *

                + * Note that the key is read-only and should not be modified, as this can lead to corrupt partitioning. + * So, splitting a record into multiple records with the same key preserves data co-location with respect to the key. + * Thus, no internal data redistribution is required if a key based operator (like an aggregation or join) + * is applied to the result {@code KStream}. (cf. {@link #flatMap(KeyValueMapper)}) + * + * @param mapper a {@link ValueMapperWithKey} the computes the new output values + * @param named a {@link Named} config used to name the processor in the topology + * @param the value type of the result stream + * @return a {@code KStream} that contains more or less records with unmodified keys and new values of different type + * @see #selectKey(KeyValueMapper) + * @see #map(KeyValueMapper) + * @see #flatMap(KeyValueMapper) + * @see #mapValues(ValueMapper) + * @see #mapValues(ValueMapperWithKey) + * @see #transform(TransformerSupplier, String...) + * @see #flatTransform(TransformerSupplier, String...) + * @see #transformValues(ValueTransformerSupplier, String...) + * @see #transformValues(ValueTransformerWithKeySupplier, String...) + * @see #flatTransformValues(ValueTransformerSupplier, String...) + * @see #flatTransformValues(ValueTransformerWithKeySupplier, String...) + */ + KStream flatMapValues(final ValueMapperWithKey> mapper, + final Named named); + + /** + * Print the records of this KStream using the options provided by {@link Printed} + * Note that this is mainly for debugging/testing purposes, and it will try to flush on each record print. + * It SHOULD NOT be used for production usage if performance requirements are concerned. + * + * @param printed options for printing + */ + void print(final Printed printed); + + /** + * Perform an action on each record of {@code KStream}. + * This is a stateless record-by-record operation (cf. {@link #process(ProcessorSupplier, String...)}). + * Note that this is a terminal operation that returns void. + * + * @param action an action to perform on each record + * @see #process(ProcessorSupplier, String...) + */ + void foreach(final ForeachAction action); + + /** + * Perform an action on each record of {@code KStream}. + * This is a stateless record-by-record operation (cf. {@link #process(ProcessorSupplier, String...)}). + * Note that this is a terminal operation that returns void. + * + * @param action an action to perform on each record + * @param named a {@link Named} config used to name the processor in the topology + * @see #process(ProcessorSupplier, String...) + */ + void foreach(final ForeachAction action, final Named named); + + /** + * Perform an action on each record of {@code KStream}. + * This is a stateless record-by-record operation (cf. {@link #process(ProcessorSupplier, String...)}). + *

                + * Peek is a non-terminal operation that triggers a side effect (such as logging or statistics collection) + * and returns an unchanged stream. + *

                + * Note that since this operation is stateless, it may execute multiple times for a single record in failure cases. + * + * @param action an action to perform on each record + * @see #process(ProcessorSupplier, String...) + * @return itself + */ + KStream peek(final ForeachAction action); + + /** + * Perform an action on each record of {@code KStream}. + * This is a stateless record-by-record operation (cf. {@link #process(ProcessorSupplier, String...)}). + *
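As a brief, editor-supplied illustration of the peek() behavior described here (topic names are illustrative):

import org.apache.kafka.streams.StreamsBuilder;
import org.apache.kafka.streams.kstream.KStream;

StreamsBuilder builder = new StreamsBuilder();
KStream<String, String> stream = builder.stream("topic");
// log each record as a side effect and continue with the unchanged stream
stream.peek((key, value) -> System.out.println("key=" + key + ", value=" + value))
      .to("output-topic");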

                + * Peek is a non-terminal operation that triggers a side effect (such as logging or statistics collection) + * and returns an unchanged stream. + *

                + * Note that since this operation is stateless, it may execute multiple times for a single record in failure cases. + * + * @param action an action to perform on each record + * @param named a {@link Named} config used to name the processor in the topology + * @see #process(ProcessorSupplier, String...) + * @return itself + */ + KStream peek(final ForeachAction action, final Named named); + + /** + * Creates an array of {@code KStream} from this stream by branching the records in the original stream based on + * the supplied predicates. + * Each record is evaluated against the supplied predicates, and predicates are evaluated in order. + * Each stream in the result array corresponds position-wise (index) to the predicate in the supplied predicates. + * The branching happens on first-match: A record in the original stream is assigned to the corresponding result + * stream for the first predicate that evaluates to true, and is assigned to this stream only. + * A record will be dropped if none of the predicates evaluate to true. + * This is a stateless record-by-record operation. + * + * @param predicates the ordered list of {@link Predicate} instances + * @return multiple distinct substreams of this {@code KStream} + * @deprecated since 2.8. Use {@link #split()} instead. + */ + @Deprecated + @SuppressWarnings("unchecked") + KStream[] branch(final Predicate... predicates); + + /** + * Creates an array of {@code KStream} from this stream by branching the records in the original stream based on + * the supplied predicates. + * Each record is evaluated against the supplied predicates, and predicates are evaluated in order. + * Each stream in the result array corresponds position-wise (index) to the predicate in the supplied predicates. + * The branching happens on first-match: A record in the original stream is assigned to the corresponding result + * stream for the first predicate that evaluates to true, and is assigned to this stream only. + * A record will be dropped if none of the predicates evaluate to true. + * This is a stateless record-by-record operation. + * + * @param named a {@link Named} config used to name the processor in the topology + * @param predicates the ordered list of {@link Predicate} instances + * @return multiple distinct substreams of this {@code KStream} + * @deprecated since 2.8. Use {@link #split(Named)} instead. + */ + @Deprecated + @SuppressWarnings("unchecked") + KStream[] branch(final Named named, final Predicate... predicates); + + /** + * Split this stream into different branches. The returned {@link BranchedKStream} instance can be used for routing + * the records to different branches depending on evaluation against the supplied predicates. + *

                + * Note: Stream branching is a stateless record-by-record operation. + * Please check {@link BranchedKStream} for detailed description and usage example + * + * @return {@link BranchedKStream} that provides methods for routing the records to different branches. + */ + BranchedKStream split(); + + /** + * Split this stream into different branches. The returned {@link BranchedKStream} instance can be used for routing + * the records to different branches depending on evaluation against the supplied predicates. + *
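A minimal, editor-supplied sketch of split(); the topic, predicates, and branch names are illustrative, and {@link BranchedKStream} remains the authoritative reference for naming semantics.

import java.util.Map;
import org.apache.kafka.streams.StreamsBuilder;
import org.apache.kafka.streams.kstream.Branched;
import org.apache.kafka.streams.kstream.KStream;
import org.apache.kafka.streams.kstream.Named;

StreamsBuilder builder = new StreamsBuilder();
KStream<String, String> stream = builder.stream("topic");
Map<String, KStream<String, String>> branches = stream
    .split(Named.as("split-"))
    .branch((key, value) -> value.startsWith("A"), Branched.as("a"))
    .defaultBranch(Branched.as("other"));
// map keys combine the prefix with each branch name (e.g. "split-a", "split-other")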

                + * Note: Stream branching is a stateless record-by-record operation. + * Please check {@link BranchedKStream} for detailed description and usage example + * + * @param named a {@link Named} config used to name the processor in the topology and also to set the name prefix + * for the resulting branches (see {@link BranchedKStream}) + * @return {@link BranchedKStream} that provides methods for routing the records to different branches. + */ + BranchedKStream split(final Named named); + + /** + * Merge this stream and the given stream into one larger stream. + *

                + * There is no ordering guarantee between records from this {@code KStream} and records from + * the provided {@code KStream} in the merged stream. + * Relative order is preserved within each input stream though (ie, records within one input + * stream are processed in order). + * + * @param stream a stream which is to be merged into this stream + * @return a merged stream containing all records from this and the provided {@code KStream} + */ + KStream merge(final KStream stream); + + /** + * Merge this stream and the given stream into one larger stream. + *
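A small, editor-supplied sketch of merge(); the topic names are illustrative.

import org.apache.kafka.streams.StreamsBuilder;
import org.apache.kafka.streams.kstream.KStream;

StreamsBuilder builder = new StreamsBuilder();
KStream<String, String> left = builder.stream("topic-a");
KStream<String, String> right = builder.stream("topic-b");
// interleaves both inputs; no ordering guarantee across the two source streams
KStream<String, String> merged = left.merge(right);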

                + * There is no ordering guarantee between records from this {@code KStream} and records from + * the provided {@code KStream} in the merged stream. + * Relative order is preserved within each input stream though (ie, records within one input + * stream are processed in order). + * + * @param stream a stream which is to be merged into this stream + * @param named a {@link Named} config used to name the processor in the topology + * @return a merged stream containing all records from this and the provided {@code KStream} + */ + KStream merge(final KStream stream, final Named named); + + /** + * Materialize this stream to a topic and creates a new {@code KStream} from the topic using default serializers, + * deserializers, and producer's {@link DefaultPartitioner}. + * The specified topic should be manually created before it is used (i.e., before the Kafka Streams application is + * started). + *

                + * This is similar to calling {@link #to(String) #to(someTopicName)} and + * {@link StreamsBuilder#stream(String) StreamsBuilder#stream(someTopicName)}. + * Note that {@code through()} uses a hard coded {@link org.apache.kafka.streams.processor.FailOnInvalidTimestamp + * timestamp extractor} and does not allow to customize it, to ensure correct timestamp propagation. + * + * @param topic the topic name + * @return a {@code KStream} that contains the exact same (and potentially repartitioned) records as this {@code KStream} + * @deprecated since 2.6; use {@link #repartition()} instead + */ + // TODO: when removed, update `StreamsResetter` decription of --intermediate-topics + @Deprecated + KStream through(final String topic); + + /** + * Materialize this stream to a topic and creates a new {@code KStream} from the topic using the + * {@link Produced} instance for configuration of the {@link Serde key serde}, {@link Serde value serde}, + * and {@link StreamPartitioner}. + * The specified topic should be manually created before it is used (i.e., before the Kafka Streams application is + * started). + *

                + * This is similar to calling {@link #to(String, Produced) to(someTopic, Produced.with(keySerde, valueSerde)} + * and {@link StreamsBuilder#stream(String, Consumed) StreamsBuilder#stream(someTopicName, Consumed.with(keySerde, valueSerde))}. + * Note that {@code through()} uses a hard coded {@link org.apache.kafka.streams.processor.FailOnInvalidTimestamp + * timestamp extractor} and does not allow to customize it, to ensure correct timestamp propagation. + * + * @param topic the topic name + * @param produced the options to use when producing to the topic + * @return a {@code KStream} that contains the exact same (and potentially repartitioned) records as this {@code KStream} + * @deprecated since 2.6; use {@link #repartition(Repartitioned)} instead + */ + @Deprecated + KStream through(final String topic, + final Produced produced); + + /** + * Materialize this stream to an auto-generated repartition topic and create a new {@code KStream} + * from the auto-generated topic using default serializers, deserializers, and producer's {@link DefaultPartitioner}. + * The number of partitions is determined based on the upstream topics partition numbers. + *

                + * The created topic is considered as an internal topic and is meant to be used only by the current Kafka Streams instance. + * Similar to auto-repartitioning, the topic will be created with infinite retention time and data will be automatically purged by Kafka Streams. + * The topic will be named as "${applicationId}-<name>-repartition", where "applicationId" is user-specified in + * {@link StreamsConfig} via parameter {@link StreamsConfig#APPLICATION_ID_CONFIG APPLICATION_ID_CONFIG}, + * "<name>" is an internally generated name, and "-repartition" is a fixed suffix. + * + * @return {@code KStream} that contains the exact same repartitioned records as this {@code KStream}. + */ + KStream repartition(); + + /** + * Materialize this stream to an auto-generated repartition topic and create a new {@code KStream} + * from the auto-generated topic using {@link Serde key serde}, {@link Serde value serde}, {@link StreamPartitioner}, + * number of partitions, and topic name part as defined by {@link Repartitioned}. + *
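A hedged, editor-supplied sketch of repartition() with an explicit {@link Repartitioned} configuration; the topic, name, serdes, and partition count are illustrative.

import org.apache.kafka.common.serialization.Serdes;
import org.apache.kafka.streams.StreamsBuilder;
import org.apache.kafka.streams.kstream.KStream;
import org.apache.kafka.streams.kstream.Repartitioned;

StreamsBuilder builder = new StreamsBuilder();
KStream<String, String> stream = builder.stream("topic");
// creates "${applicationId}-by-key-repartition" with 4 partitions and the given serdes
KStream<String, String> repartitioned = stream.repartition(
    Repartitioned.<String, String>as("by-key")
        .withKeySerde(Serdes.String())
        .withValueSerde(Serdes.String())
        .withNumberOfPartitions(4));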

                + * The created topic is considered as an internal topic and is meant to be used only by the current Kafka Streams instance. + * Similar to auto-repartitioning, the topic will be created with infinite retention time and data will be automatically purged by Kafka Streams. + * The topic will be named as "${applicationId}-<name>-repartition", where "applicationId" is user-specified in + * {@link StreamsConfig} via parameter {@link StreamsConfig#APPLICATION_ID_CONFIG APPLICATION_ID_CONFIG}, + * "<name>" is either provided via {@link Repartitioned#as(String)} or an internally + * generated name, and "-repartition" is a fixed suffix. + * + * @param repartitioned the {@link Repartitioned} instance used to specify {@link Serdes}, + * {@link StreamPartitioner} which determines how records are distributed among partitions of the topic, + * part of the topic name, and number of partitions for a repartition topic. + * @return a {@code KStream} that contains the exact same repartitioned records as this {@code KStream}. + */ + KStream repartition(final Repartitioned repartitioned); + + /** + * Materialize this stream to a topic using default serializers specified in the config and producer's + * {@link DefaultPartitioner}. + * The specified topic should be manually created before it is used (i.e., before the Kafka Streams application is + * started). + * + * @param topic the topic name + */ + void to(final String topic); + + /** + * Materialize this stream to a topic using the provided {@link Produced} instance. + * The specified topic should be manually created before it is used (i.e., before the Kafka Streams application is + * started). + * + * @param topic the topic name + * @param produced the options to use when producing to the topic + */ + void to(final String topic, + final Produced produced); + + /** + * Dynamically materialize this stream to topics using default serializers specified in the config and producer's + * {@link DefaultPartitioner}. + * The topic names for each record to send to is dynamically determined based on the {@link TopicNameExtractor}. + * + * @param topicExtractor the extractor to determine the name of the Kafka topic to write to for each record + */ + void to(final TopicNameExtractor topicExtractor); + + /** + * Dynamically materialize this stream to topics using the provided {@link Produced} instance. + * The topic names for each record to send to is dynamically determined based on the {@link TopicNameExtractor}. + * + * @param topicExtractor the extractor to determine the name of the Kafka topic to write to for each record + * @param produced the options to use when producing to the topic + */ + void to(final TopicNameExtractor topicExtractor, + final Produced produced); + + /** + * Convert this stream to a {@link KTable}. + *

                + * If a key changing operator was used before this operation (e.g., {@link #selectKey(KeyValueMapper)}, + * {@link #map(KeyValueMapper)}, {@link #flatMap(KeyValueMapper)} or + * {@link #transform(TransformerSupplier, String...)}) an internal repartitioning topic will be created in Kafka. + * This topic will be named "${applicationId}-<name>-repartition", where "applicationId" is user-specified in + * {@link StreamsConfig} via parameter {@link StreamsConfig#APPLICATION_ID_CONFIG APPLICATION_ID_CONFIG}, + * "<name>" is an internally generated name, and "-repartition" is a fixed suffix. + *

                + * You can retrieve all generated internal topic names via {@link Topology#describe()}. + *

                + * For this case, all data of this stream will be redistributed through the repartitioning topic by writing all + * records to it, and rereading all records from it, such that the resulting {@link KTable} is partitioned + * correctly on its key. + * Note that you cannot enable {@link StreamsConfig#TOPOLOGY_OPTIMIZATION_CONFIG} config for this case, because + * repartition topics are considered transient and don't allow to recover the result {@link KTable} in cause of + * a failure; hence, a dedicated changelog topic is required to guarantee fault-tolerance. + *

                + * Note that this is a logical operation and only changes the "interpretation" of the stream, i.e., each record of + * it was a "fact/event" and is re-interpreted as update now (cf. {@link KStream} vs {@code KTable}). + * + * @return a {@link KTable} that contains the same records as this {@code KStream} + */ + KTable toTable(); + + /** + * Convert this stream to a {@link KTable}. + *
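A minimal, editor-supplied sketch of toTable(); the topic and processor name are illustrative.

import org.apache.kafka.streams.StreamsBuilder;
import org.apache.kafka.streams.kstream.KStream;
import org.apache.kafka.streams.kstream.KTable;
import org.apache.kafka.streams.kstream.Named;

StreamsBuilder builder = new StreamsBuilder();
KStream<String, String> stream = builder.stream("topic");
// re-interpret the event stream as a changelog: later records update earlier ones per key
KTable<String, String> table = stream.toTable(Named.as("stream-to-table"));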

                + * If a key changing operator was used before this operation (e.g., {@link #selectKey(KeyValueMapper)}, + * {@link #map(KeyValueMapper)}, {@link #flatMap(KeyValueMapper)} or + * {@link #transform(TransformerSupplier, String...)}) an internal repartitioning topic will be created in Kafka. + * This topic will be named "${applicationId}-<name>-repartition", where "applicationId" is user-specified in + * {@link StreamsConfig} via parameter {@link StreamsConfig#APPLICATION_ID_CONFIG APPLICATION_ID_CONFIG}, + * "<name>" is an internally generated name, and "-repartition" is a fixed suffix. + *

                + * You can retrieve all generated internal topic names via {@link Topology#describe()}. + *

                + * For this case, all data of this stream will be redistributed through the repartitioning topic by writing all + * records to it, and rereading all records from it, such that the resulting {@link KTable} is partitioned + * correctly on its key. + * Note that you cannot enable the {@link StreamsConfig#TOPOLOGY_OPTIMIZATION_CONFIG} config for this case, because + * repartition topics are considered transient and don't allow recovering the result {@link KTable} in case of + * a failure; hence, a dedicated changelog topic is required to guarantee fault-tolerance. + *

                + * Note that this is a logical operation and only changes the "interpretation" of the stream, i.e., each record of + * it was a "fact/event" and is re-interpreted as update now (cf. {@link KStream} vs {@code KTable}). + * + * @param named a {@link Named} config used to name the processor in the topology + * @return a {@link KTable} that contains the same records as this {@code KStream} + */ + KTable toTable(final Named named); + + /** + * Convert this stream to a {@link KTable}. + *

                + * If a key changing operator was used before this operation (e.g., {@link #selectKey(KeyValueMapper)}, + * {@link #map(KeyValueMapper)}, {@link #flatMap(KeyValueMapper)} or + * {@link #transform(TransformerSupplier, String...)}) an internal repartitioning topic will be created in Kafka. + * This topic will be named "${applicationId}-<name>-repartition", where "applicationId" is user-specified in + * {@link StreamsConfig} via parameter {@link StreamsConfig#APPLICATION_ID_CONFIG APPLICATION_ID_CONFIG}, + * "<name>" is an internally generated name, and "-repartition" is a fixed suffix. + *

                + * You can retrieve all generated internal topic names via {@link Topology#describe()}. + *

                + * For this case, all data of this stream will be redistributed through the repartitioning topic by writing all + * records to it, and rereading all records from it, such that the resulting {@link KTable} is partitioned + * correctly on its key. + * Note that you cannot enable the {@link StreamsConfig#TOPOLOGY_OPTIMIZATION_CONFIG} config for this case, because + * repartition topics are considered transient and don't allow recovering the result {@link KTable} in case of + * a failure; hence, a dedicated changelog topic is required to guarantee fault-tolerance. + *

                + * Note that this is a logical operation and only changes the "interpretation" of the stream, i.e., each record of + * it was a "fact/event" and is re-interpreted as update now (cf. {@link KStream} vs {@code KTable}). + * + * @param materialized an instance of {@link Materialized} used to describe how the state store of the + * resulting table should be materialized. + * @return a {@link KTable} that contains the same records as this {@code KStream} + */ + KTable<K, V> toTable(final Materialized<K, V, KeyValueStore<Bytes, byte[]>> materialized); + + /** + * Convert this stream to a {@link KTable}. + *

                + * If a key changing operator was used before this operation (e.g., {@link #selectKey(KeyValueMapper)}, + * {@link #map(KeyValueMapper)}, {@link #flatMap(KeyValueMapper)} or + * {@link #transform(TransformerSupplier, String...)}) an internal repartitioning topic will be created in Kafka. + * This topic will be named "${applicationId}-<name>-repartition", where "applicationId" is user-specified in + * {@link StreamsConfig} via parameter {@link StreamsConfig#APPLICATION_ID_CONFIG APPLICATION_ID_CONFIG}, + * "<name>" is an internally generated name, and "-repartition" is a fixed suffix. + *

                + * You can retrieve all generated internal topic names via {@link Topology#describe()}. + *

                + * For this case, all data of this stream will be redistributed through the repartitioning topic by writing all + * records to it, and rereading all records from it, such that the resulting {@link KTable} is partitioned + * correctly on its key. + * Note that you cannot enable the {@link StreamsConfig#TOPOLOGY_OPTIMIZATION_CONFIG} config for this case, because + * repartition topics are considered transient and don't allow recovering the result {@link KTable} in case of + * a failure; hence, a dedicated changelog topic is required to guarantee fault-tolerance. + *
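                + * For illustration only, a minimal sketch naming the processor and materializing the result into a queryable
                + * store, assuming an existing {@code KStream<String, Long> balanceUpdates}; the processor name, store name,
                + * and serdes are example choices, not part of the API:
                + * <pre>{@code
                + * KTable<String, Long> latestBalance = balanceUpdates.toTable(
                + *     Named.as("balance-to-table"),
                + *     Materialized.<String, Long, KeyValueStore<Bytes, byte[]>>as("latest-balance-store")
                + *         .withKeySerde(Serdes.String())
                + *         .withValueSerde(Serdes.Long()));
                + * }</pre>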

                + * Note that this is a logical operation and only changes the "interpretation" of the stream, i.e., each record of + * it was a "fact/event" and is re-interpreted as update now (cf. {@link KStream} vs {@code KTable}). + * + * @param named a {@link Named} config used to name the processor in the topology + * @param materialized an instance of {@link Materialized} used to describe how the state store of the + * resulting table should be materialized. + * @return a {@link KTable} that contains the same records as this {@code KStream} + */ + KTable toTable(final Named named, + final Materialized> materialized); + + /** + * Group the records of this {@code KStream} on a new key that is selected using the provided {@link KeyValueMapper} + * and default serializers and deserializers. + * {@link KGroupedStream} can be further grouped with other streams to form a {@link CogroupedKStream}. + * Grouping a stream on the record key is required before an aggregation operator can be applied to the data + * (cf. {@link KGroupedStream}). + * The {@link KeyValueMapper} selects a new key (which may or may not be of the same type) while preserving the + * original values. + * If the new record key is {@code null} the record will not be included in the resulting {@link KGroupedStream} + *

                + * Because a new key is selected, an internal repartitioning topic may need to be created in Kafka if a + * later operator depends on the newly selected key. + * This topic will be named "${applicationId}-<name>-repartition", where "applicationId" is user-specified in + * {@link StreamsConfig} via parameter {@link StreamsConfig#APPLICATION_ID_CONFIG APPLICATION_ID_CONFIG}, + * "<name>" is an internally generated name, and "-repartition" is a fixed suffix. + *

                + * You can retrieve all generated internal topic names via {@link Topology#describe()}. + *

                + * All data of this stream will be redistributed through the repartitioning topic by writing all records to it, + * and rereading all records from it, such that the resulting {@link KGroupedStream} is partitioned on the new key. + *
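                + * For illustration only, a minimal sketch that re-keys a click stream by page id and counts clicks per page
                + * (the topic and variable names are assumptions made for this example; default String serdes are assumed to
                + * be configured):
                + * <pre>{@code
                + * KStream<String, String> clicks = builder.stream("clicks");   // key: user id, value: page id
                + * KTable<String, Long> clicksPerPage = clicks
                + *     .groupBy((userId, pageId) -> pageId)   // select the page id as the new grouping key
                + *     .count();
                + * }</pre>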

                + * This operation is equivalent to calling {@link #selectKey(KeyValueMapper)} followed by {@link #groupByKey()}. + * If the key type is changed, it is recommended to use {@link #groupBy(KeyValueMapper, Grouped)} instead. + * + * @param keySelector a {@link KeyValueMapper} that computes a new key for grouping + * @param the key type of the result {@link KGroupedStream} + * @return a {@link KGroupedStream} that contains the grouped records of the original {@code KStream} + */ + KGroupedStream groupBy(final KeyValueMapper keySelector); + + /** + * Group the records of this {@code KStream} on a new key that is selected using the provided {@link KeyValueMapper} + * and {@link Serde}s as specified by {@link Grouped}. + * {@link KGroupedStream} can be further grouped with other streams to form a {@link CogroupedKStream}. + * Grouping a stream on the record key is required before an aggregation operator can be applied to the data + * (cf. {@link KGroupedStream}). + * The {@link KeyValueMapper} selects a new key (which may or may not be of the same type) while preserving the + * original values. + * If the new record key is {@code null} the record will not be included in the resulting {@link KGroupedStream}. + *

                + * Because a new key is selected, an internal repartitioning topic may need to be created in Kafka if a later + * operator depends on the newly selected key. + * This topic will be named "${applicationId}-<name>-repartition", where "applicationId" is user-specified in + * {@link StreamsConfig} via parameter {@link StreamsConfig#APPLICATION_ID_CONFIG APPLICATION_ID_CONFIG}, + * "<name>" is either provided via {@link org.apache.kafka.streams.kstream.Grouped#as(String)} or an + * internally generated name. + *

                + * You can retrieve all generated internal topic names via {@link Topology#describe()}. + *

                + * All data of this stream will be redistributed through the repartitioning topic by writing all records to it, + * and rereading all records from it, such that the resulting {@link KGroupedStream} is partitioned on the new key. + *
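                + * For illustration only, a minimal sketch assuming an existing {@code KStream<String, String> clicks}
                + * (key: user id, value: page id); the serdes and the "clicks-by-page" name fragment are example choices:
                + * <pre>{@code
                + * KTable<String, Long> clicksPerPage = clicks
                + *     .groupBy((userId, pageId) -> pageId,
                + *              Grouped.with("clicks-by-page", Serdes.String(), Serdes.String()))
                + *     .count();
                + * }</pre>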

                + * This operation is equivalent to calling {@link #selectKey(KeyValueMapper)} followed by {@link #groupByKey()}. + * + * @param keySelector a {@link KeyValueMapper} that computes a new key for grouping + * @param grouped the {@link Grouped} instance used to specify {@link org.apache.kafka.common.serialization.Serdes} + * and part of the name for a repartition topic if repartitioning is required. + * @param the key type of the result {@link KGroupedStream} + * @return a {@link KGroupedStream} that contains the grouped records of the original {@code KStream} + */ + KGroupedStream groupBy(final KeyValueMapper keySelector, + final Grouped grouped); + + /** + * Group the records by their current key into a {@link KGroupedStream} while preserving the original values + * and default serializers and deserializers. + * {@link KGroupedStream} can be further grouped with other streams to form a {@link CogroupedKStream}. + * Grouping a stream on the record key is required before an aggregation operator can be applied to the data + * (cf. {@link KGroupedStream}). + * If a record key is {@code null} the record will not be included in the resulting {@link KGroupedStream}. + *

                + * If a key changing operator was used before this operation (e.g., {@link #selectKey(KeyValueMapper)}, + * {@link #map(KeyValueMapper)}, {@link #flatMap(KeyValueMapper)} or + * {@link #transform(TransformerSupplier, String...)}) an internal repartitioning topic may need to be created in + * Kafka if a later operator depends on the newly selected key. + * This topic will be named "${applicationId}-<name>-repartition", where "applicationId" is user-specified in + * {@link StreamsConfig} via parameter {@link StreamsConfig#APPLICATION_ID_CONFIG APPLICATION_ID_CONFIG}, + * "<name>" is an internally generated name, and "-repartition" is a fixed suffix. + *

                + * You can retrieve all generated internal topic names via {@link Topology#describe()}. + *
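                + * For illustration only, a minimal sketch assuming an existing {@code KStream<String, String> clicks} that is
                + * already keyed by user id and has not been re-keyed upstream:
                + * <pre>{@code
                + * KTable<String, Long> clicksPerUser = clicks
                + *     .groupByKey()
                + *     .count();
                + * }</pre>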

                + * For this case, all data of this stream will be redistributed through the repartitioning topic by writing all + * records to it, and rereading all records from it, such that the resulting {@link KGroupedStream} is partitioned + * correctly on its key. + * If the last key changing operator changed the key type, it is recommended to use + * {@link #groupByKey(org.apache.kafka.streams.kstream.Grouped)} instead. + * + * @return a {@link KGroupedStream} that contains the grouped records of the original {@code KStream} + * @see #groupBy(KeyValueMapper) + */ + KGroupedStream groupByKey(); + + /** + * Group the records by their current key into a {@link KGroupedStream} while preserving the original values + * and using the serializers as defined by {@link Grouped}. + * {@link KGroupedStream} can be further grouped with other streams to form a {@link CogroupedKStream}. + * Grouping a stream on the record key is required before an aggregation operator can be applied to the data + * (cf. {@link KGroupedStream}). + * If a record key is {@code null} the record will not be included in the resulting {@link KGroupedStream}. + *

                + * If a key changing operator was used before this operation (e.g., {@link #selectKey(KeyValueMapper)}, + * {@link #map(KeyValueMapper)}, {@link #flatMap(KeyValueMapper)} or + * {@link #transform(TransformerSupplier, String...)}) an internal repartitioning topic may need to be created in + * Kafka if a later operator depends on the newly selected key. + * This topic will be named "${applicationId}-<name>-repartition", where "applicationId" is user-specified in + * {@link StreamsConfig} via parameter {@link StreamsConfig#APPLICATION_ID_CONFIG APPLICATION_ID_CONFIG}, + * <name> is either provided via {@link org.apache.kafka.streams.kstream.Grouped#as(String)} or an internally + * generated name, and "-repartition" is a fixed suffix. + *

                + * You can retrieve all generated internal topic names via {@link Topology#describe()}. + *
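                + * For illustration only, a minimal sketch assuming an existing {@code KStream<String, String> clicks}; the
                + * explicit serdes override the configured defaults for the grouping (and for a repartition topic, if one is
                + * needed):
                + * <pre>{@code
                + * KTable<String, Long> clicksPerUser = clicks
                + *     .groupByKey(Grouped.with(Serdes.String(), Serdes.String()))
                + *     .count();
                + * }</pre>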

                + * For this case, all data of this stream will be redistributed through the repartitioning topic by writing all + * records to it, and rereading all records from it, such that the resulting {@link KGroupedStream} is partitioned + * correctly on its key. + * + * @param grouped the {@link Grouped} instance used to specify {@link Serdes} + * and part of the name for a repartition topic if repartitioning is required. + * @return a {@link KGroupedStream} that contains the grouped records of the original {@code KStream} + * @see #groupBy(KeyValueMapper) + */ + KGroupedStream groupByKey(final Grouped grouped); + + /** + * Join records of this stream with another {@code KStream}'s records using windowed inner equi join with default + * serializers and deserializers. + * The join is computed on the records' key with join attribute {@code thisKStream.key == otherKStream.key}. + * Furthermore, two records are only joined if their timestamps are close to each other as defined by the given + * {@link JoinWindows}, i.e., the window defines an additional join predicate on the record timestamps. + *

                + * For each pair of records meeting both join predicates the provided {@link ValueJoiner} will be called to compute + * a value (with arbitrary type) for the result record. + * The key of the result record is the same as for both joining input records. + * If an input record key or value is {@code null} the record will not be included in the join operation and thus no + * output record will be added to the resulting {@code KStream}. + *

                + * Example (assuming all input records belong to the correct windows):
                this      |  other   |  result
                <K1:A>    |          |
                <K2:B>    |  <K2:b>  |  <K2:ValueJoiner(B,b)>
                          |  <K3:c>  |
                + * Both input streams (or to be more precise, their underlying source topics) need to have the same number of + * partitions. + * If this is not the case, you would need to call {@link #repartition(Repartitioned)} (for one input stream) before + * doing the join and specify the "correct" number of partitions via {@link Repartitioned} parameter. + * Furthermore, both input streams need to be co-partitioned on the join key (i.e., use the same partitioner). + * If this requirement is not met, Kafka Streams will automatically repartition the data, i.e., it will create an + * internal repartitioning topic in Kafka and write and re-read the data via this topic before the actual join. + * The repartitioning topic will be named "${applicationId}-<name>-repartition", where "applicationId" is + * user-specified in {@link StreamsConfig} via parameter + * {@link StreamsConfig#APPLICATION_ID_CONFIG APPLICATION_ID_CONFIG}, "<name>" is an internally generated + * name, and "-repartition" is a fixed suffix. + *

                + * Repartitioning can happen for one or both of the joining {@code KStream}s. + * For this case, all data of the stream will be redistributed through the repartitioning topic by writing all + * records to it, and rereading all records from it, such that the join input {@code KStream} is partitioned + * correctly on its key. + *

                + * Both of the joining {@code KStream}s will be materialized in local state stores with auto-generated store names. + * For failure and recovery each store will be backed by an internal changelog topic that will be created in Kafka. + * The changelog topic will be named "${applicationId}-<storename>-changelog", where "applicationId" is user-specified + * in {@link StreamsConfig} via parameter + * {@link StreamsConfig#APPLICATION_ID_CONFIG APPLICATION_ID_CONFIG}, "storeName" is an + * internally generated name, and "-changelog" is a fixed suffix. + *
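                + * For illustration only, a minimal sketch joining two streams keyed by order id (the topic names, the
                + * five-minute window, and the string concatenation are assumptions made for this example; default serdes are
                + * assumed to be configured):
                + * <pre>{@code
                + * KStream<String, String> orders = builder.stream("orders");
                + * KStream<String, String> payments = builder.stream("payments");
                + * KStream<String, String> paidOrders = orders.join(
                + *     payments,
                + *     (orderValue, paymentValue) -> orderValue + "/" + paymentValue,
                + *     JoinWindows.of(Duration.ofMinutes(5)));
                + * }</pre>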

                + * You can retrieve all generated internal topic names via {@link Topology#describe()}. + * + * @param otherStream the {@code KStream} to be joined with this stream + * @param joiner a {@link ValueJoiner} that computes the join result for a pair of matching records + * @param windows the specification of the {@link JoinWindows} + * @param the value type of the other stream + * @param the value type of the result stream + * @return a {@code KStream} that contains join-records for each key and values computed by the given + * {@link ValueJoiner}, one for each matched record-pair with the same key and within the joining window intervals + * @see #leftJoin(KStream, ValueJoiner, JoinWindows) + * @see #outerJoin(KStream, ValueJoiner, JoinWindows) + */ + KStream join(final KStream otherStream, + final ValueJoiner joiner, + final JoinWindows windows); + + /** + * Join records of this stream with another {@code KStream}'s records using windowed inner equi join with default + * serializers and deserializers. + * The join is computed on the records' key with join attribute {@code thisKStream.key == otherKStream.key}. + * Furthermore, two records are only joined if their timestamps are close to each other as defined by the given + * {@link JoinWindows}, i.e., the window defines an additional join predicate on the record timestamps. + *

                + * For each pair of records meeting both join predicates the provided {@link ValueJoinerWithKey} will be called to compute + * a value (with arbitrary type) for the result record. + * Note that the key is read-only and should not be modified, as this can lead to undefined behaviour. + * The key of the result record is the same as for both joining input records. + * If an input record key or value is {@code null} the record will not be included in the join operation and thus no + * output record will be added to the resulting {@code KStream}. + *

                + * Example (assuming all input records belong to the correct windows):
                this      |  other   |  result
                <K1:A>    |          |
                <K2:B>    |  <K2:b>  |  <K2:ValueJoinerWithKey(K2,B,b)>
                          |  <K3:c>  |
                + * Both input streams (or to be more precise, their underlying source topics) need to have the same number of + * partitions. + * If this is not the case, you would need to call {@link #repartition(Repartitioned)} (for one input stream) before + * doing the join and specify the "correct" number of partitions via {@link Repartitioned} parameter. + * Furthermore, both input streams need to be co-partitioned on the join key (i.e., use the same partitioner). + * If this requirement is not met, Kafka Streams will automatically repartition the data, i.e., it will create an + * internal repartitioning topic in Kafka and write and re-read the data via this topic before the actual join. + * The repartitioning topic will be named "${applicationId}-<name>-repartition", where "applicationId" is + * user-specified in {@link StreamsConfig} via parameter + * {@link StreamsConfig#APPLICATION_ID_CONFIG APPLICATION_ID_CONFIG}, "<name>" is an internally generated + * name, and "-repartition" is a fixed suffix. + *

                + * Repartitioning can happen for one or both of the joining {@code KStream}s. + * For this case, all data of the stream will be redistributed through the repartitioning topic by writing all + * records to it, and rereading all records from it, such that the join input {@code KStream} is partitioned + * correctly on its key. + *

                + * Both of the joining {@code KStream}s will be materialized in local state stores with auto-generated store names. + * For failure and recovery each store will be backed by an internal changelog topic that will be created in Kafka. + * The changelog topic will be named "${applicationId}-<storename>-changelog", where "applicationId" is user-specified + * in {@link StreamsConfig} via parameter + * {@link StreamsConfig#APPLICATION_ID_CONFIG APPLICATION_ID_CONFIG}, "storeName" is an + * internally generated name, and "-changelog" is a fixed suffix. + *
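                + * For illustration only, a minimal sketch using the read-only key in the join result, assuming existing
                + * co-partitioned {@code KStream<String, String>} instances {@code orders} and {@code payments}; names and the
                + * window size are example choices:
                + * <pre>{@code
                + * KStream<String, String> paidOrders = orders.join(
                + *     payments,
                + *     (orderId, orderValue, paymentValue) -> orderId + ":" + orderValue + "/" + paymentValue,
                + *     JoinWindows.of(Duration.ofMinutes(5)));
                + * }</pre>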

                + * You can retrieve all generated internal topic names via {@link Topology#describe()}. + * + * @param otherStream the {@code KStream} to be joined with this stream + * @param joiner a {@link ValueJoinerWithKey} that computes the join result for a pair of matching records + * @param windows the specification of the {@link JoinWindows} + * @param the value type of the other stream + * @param the value type of the result stream + * @return a {@code KStream} that contains join-records for each key and values computed by the given + * {@link ValueJoinerWithKey}, one for each matched record-pair with the same key and within the joining window intervals + * @see #leftJoin(KStream, ValueJoinerWithKey, JoinWindows) + * @see #outerJoin(KStream, ValueJoinerWithKey, JoinWindows) + */ + KStream join(final KStream otherStream, + final ValueJoinerWithKey joiner, + final JoinWindows windows); + + /** + * Join records of this stream with another {@code KStream}'s records using windowed inner equi join using the + * {@link StreamJoined} instance for configuration of the {@link Serde key serde}, {@link Serde this stream's value + * serde}, {@link Serde the other stream's value serde}, and used state stores. + * The join is computed on the records' key with join attribute {@code thisKStream.key == otherKStream.key}. + * Furthermore, two records are only joined if their timestamps are close to each other as defined by the given + * {@link JoinWindows}, i.e., the window defines an additional join predicate on the record timestamps. + *

                + * For each pair of records meeting both join predicates the provided {@link ValueJoiner} will be called to compute + * a value (with arbitrary type) for the result record. + * The key of the result record is the same as for both joining input records. + * If an input record key or value is {@code null} the record will not be included in the join operation and thus no + * output record will be added to the resulting {@code KStream}. + *

                + * Example (assuming all input records belong to the correct windows):
                this      |  other   |  result
                <K1:A>    |          |
                <K2:B>    |  <K2:b>  |  <K2:ValueJoiner(B,b)>
                          |  <K3:c>  |
                + * Both input streams (or to be more precise, their underlying source topics) need to have the same number of + * partitions. + * If this is not the case, you would need to call {@link #repartition(Repartitioned)} (for one input stream) before + * doing the join and specify the "correct" number of partitions via {@link Repartitioned} parameter. + * Furthermore, both input streams need to be co-partitioned on the join key (i.e., use the same partitioner). + * If this requirement is not met, Kafka Streams will automatically repartition the data, i.e., it will create an + * internal repartitioning topic in Kafka and write and re-read the data via this topic before the actual join. + * The repartitioning topic will be named "${applicationId}-<name>-repartition", where "applicationId" is + * user-specified in {@link StreamsConfig} via parameter + * {@link StreamsConfig#APPLICATION_ID_CONFIG APPLICATION_ID_CONFIG}, "<name>" is an internally generated + * name, and "-repartition" is a fixed suffix. + *

                + * Repartitioning can happen for one or both of the joining {@code KStream}s. + * For this case, all data of the stream will be redistributed through the repartitioning topic by writing all + * records to it, and rereading all records from it, such that the join input {@code KStream} is partitioned + * correctly on its key. + *

                + * Both of the joining {@code KStream}s will be materialized in local state stores with auto-generated store names, + * unless a name is provided via a {@code Materialized} instance. + * For failure and recovery each store will be backed by an internal changelog topic that will be created in Kafka. + * The changelog topic will be named "${applicationId}-<storename>-changelog", where "applicationId" is user-specified + * in {@link StreamsConfig} via parameter + * {@link StreamsConfig#APPLICATION_ID_CONFIG APPLICATION_ID_CONFIG}, "storeName" is an + * internally generated name, and "-changelog" is a fixed suffix. + *
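                + * For illustration only, a minimal sketch assuming existing co-partitioned {@code KStream<String, String>}
                + * instances {@code orders} and {@code payments}; the serdes and the store name passed via {@link StreamJoined}
                + * are example choices:
                + * <pre>{@code
                + * KStream<String, String> paidOrders = orders.join(
                + *     payments,
                + *     (orderValue, paymentValue) -> orderValue + "/" + paymentValue,
                + *     JoinWindows.of(Duration.ofMinutes(5)),
                + *     StreamJoined.with(Serdes.String(), Serdes.String(), Serdes.String())
                + *         .withStoreName("order-payment-join"));
                + * }</pre>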

                + * You can retrieve all generated internal topic names via {@link Topology#describe()}. + * + * @param the value type of the other stream + * @param the value type of the result stream + * @param otherStream the {@code KStream} to be joined with this stream + * @param joiner a {@link ValueJoiner} that computes the join result for a pair of matching records + * @param windows the specification of the {@link JoinWindows} + * @param streamJoined a {@link StreamJoined} used to configure join stores + * @return a {@code KStream} that contains join-records for each key and values computed by the given + * {@link ValueJoiner}, one for each matched record-pair with the same key and within the joining window intervals + * @see #leftJoin(KStream, ValueJoiner, JoinWindows, StreamJoined) + * @see #outerJoin(KStream, ValueJoiner, JoinWindows, StreamJoined) + */ + KStream join(final KStream otherStream, + final ValueJoiner joiner, + final JoinWindows windows, + final StreamJoined streamJoined); + + /** + * Join records of this stream with another {@code KStream}'s records using windowed inner equi join using the + * {@link StreamJoined} instance for configuration of the {@link Serde key serde}, {@link Serde this stream's value + * serde}, {@link Serde the other stream's value serde}, and used state stores. + * The join is computed on the records' key with join attribute {@code thisKStream.key == otherKStream.key}. + * Furthermore, two records are only joined if their timestamps are close to each other as defined by the given + * {@link JoinWindows}, i.e., the window defines an additional join predicate on the record timestamps. + *

                + * For each pair of records meeting both join predicates the provided {@link ValueJoinerWithKey} will be called to compute + * a value (with arbitrary type) for the result record. + * Note that the key is read-only and should not be modified, as this can lead to undefined behaviour. + * The key of the result record is the same as for both joining input records. + * If an input record key or value is {@code null} the record will not be included in the join operation and thus no + * output record will be added to the resulting {@code KStream}. + *

                + * Example (assuming all input records belong to the correct windows):
                this      |  other   |  result
                <K1:A>    |          |
                <K2:B>    |  <K2:b>  |  <K2:ValueJoinerWithKey(K2,B,b)>
                          |  <K3:c>  |
                + * Both input streams (or to be more precise, their underlying source topics) need to have the same number of + * partitions. + * If this is not the case, you would need to call {@link #repartition(Repartitioned)} (for one input stream) before + * doing the join and specify the "correct" number of partitions via {@link Repartitioned} parameter. + * Furthermore, both input streams need to be co-partitioned on the join key (i.e., use the same partitioner). + * If this requirement is not met, Kafka Streams will automatically repartition the data, i.e., it will create an + * internal repartitioning topic in Kafka and write and re-read the data via this topic before the actual join. + * The repartitioning topic will be named "${applicationId}-<name>-repartition", where "applicationId" is + * user-specified in {@link StreamsConfig} via parameter + * {@link StreamsConfig#APPLICATION_ID_CONFIG APPLICATION_ID_CONFIG}, "<name>" is an internally generated + * name, and "-repartition" is a fixed suffix. + *

                + * Repartitioning can happen for one or both of the joining {@code KStream}s. + * For this case, all data of the stream will be redistributed through the repartitioning topic by writing all + * records to it, and rereading all records from it, such that the join input {@code KStream} is partitioned + * correctly on its key. + *

                + * Both of the joining {@code KStream}s will be materialized in local state stores with auto-generated store names, + * unless a name is provided via a {@code Materialized} instance. + * For failure and recovery each store will be backed by an internal changelog topic that will be created in Kafka. + * The changelog topic will be named "${applicationId}-<storename>-changelog", where "applicationId" is user-specified + * in {@link StreamsConfig} via parameter + * {@link StreamsConfig#APPLICATION_ID_CONFIG APPLICATION_ID_CONFIG}, "storeName" is an + * internally generated name, and "-changelog" is a fixed suffix. + *

                + * You can retrieve all generated internal topic names via {@link Topology#describe()}. + * + * @param the value type of the other stream + * @param the value type of the result stream + * @param otherStream the {@code KStream} to be joined with this stream + * @param joiner a {@link ValueJoinerWithKey} that computes the join result for a pair of matching records + * @param windows the specification of the {@link JoinWindows} + * @param streamJoined a {@link StreamJoined} used to configure join stores + * @return a {@code KStream} that contains join-records for each key and values computed by the given + * {@link ValueJoinerWithKey}, one for each matched record-pair with the same key and within the joining window intervals + * @see #leftJoin(KStream, ValueJoinerWithKey, JoinWindows, StreamJoined) + * @see #outerJoin(KStream, ValueJoinerWithKey, JoinWindows, StreamJoined) + */ + KStream join(final KStream otherStream, + final ValueJoinerWithKey joiner, + final JoinWindows windows, + final StreamJoined streamJoined); + /** + * Join records of this stream with another {@code KStream}'s records using windowed left equi join with default + * serializers and deserializers. + * In contrast to {@link #join(KStream, ValueJoiner, JoinWindows) inner-join}, all records from this stream will + * produce at least one output record (cf. below). + * The join is computed on the records' key with join attribute {@code thisKStream.key == otherKStream.key}. + * Furthermore, two records are only joined if their timestamps are close to each other as defined by the given + * {@link JoinWindows}, i.e., the window defines an additional join predicate on the record timestamps. + *

                + * For each pair of records meeting both join predicates the provided {@link ValueJoiner} will be called to compute + * a value (with arbitrary type) for the result record. + * The key of the result record is the same as for both joining input records. + * Furthermore, for each input record of this {@code KStream} that does not satisfy the join predicate the provided + * {@link ValueJoiner} will be called with a {@code null} value for the other stream. + * If an input record key or value is {@code null} the record will not be included in the join operation and thus no + * output record will be added to the resulting {@code KStream}. + *

                + * Example (assuming all input records belong to the correct windows):
                this      |  other   |  result
                <K1:A>    |          |  <K1:ValueJoiner(A,null)>
                <K2:B>    |  <K2:b>  |  <K2:ValueJoiner(B,b)>
                          |  <K3:c>  |
                + * Both input streams (or to be more precise, their underlying source topics) need to have the same number of + * partitions. + * If this is not the case, you would need to call {@link #repartition(Repartitioned)} (for one input stream) before + * doing the join and specify the "correct" number of partitions via {@link Repartitioned} parameter. + * Furthermore, both input streams need to be co-partitioned on the join key (i.e., use the same partitioner). + * If this requirement is not met, Kafka Streams will automatically repartition the data, i.e., it will create an + * internal repartitioning topic in Kafka and write and re-read the data via this topic before the actual join. + * The repartitioning topic will be named "${applicationId}-<name>-repartition", where "applicationId" is + * user-specified in {@link StreamsConfig} via parameter + * {@link StreamsConfig#APPLICATION_ID_CONFIG APPLICATION_ID_CONFIG}, "<name>" is an internally generated + * name, and "-repartition" is a fixed suffix. + *

                + * Repartitioning can happen for one or both of the joining {@code KStream}s. + * For this case, all data of the stream will be redistributed through the repartitioning topic by writing all + * records to it, and rereading all records from it, such that the join input {@code KStream} is partitioned + * correctly on its key. + *

                + * Both of the joining {@code KStream}s will be materialized in local state stores with auto-generated store names. + * For failure and recovery each store will be backed by an internal changelog topic that will be created in Kafka. + * The changelog topic will be named "${applicationId}-<storename>-changelog", where "applicationId" is user-specified + * in {@link StreamsConfig} via parameter {@link StreamsConfig#APPLICATION_ID_CONFIG APPLICATION_ID_CONFIG}, + * "storeName" is an internally generated name, and "-changelog" is a fixed suffix. + *
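                + * For illustration only, a minimal sketch assuming existing co-partitioned {@code KStream<String, String>}
                + * instances {@code orders} and {@code payments}; the joiner must handle a {@code null} right-hand value for
                + * unmatched orders:
                + * <pre>{@code
                + * KStream<String, String> ordersWithPaymentStatus = orders.leftJoin(
                + *     payments,
                + *     (orderValue, paymentValue) ->
                + *         paymentValue == null ? orderValue + "/unpaid" : orderValue + "/" + paymentValue,
                + *     JoinWindows.of(Duration.ofMinutes(5)));
                + * }</pre>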

                + * You can retrieve all generated internal topic names via {@link Topology#describe()}. + * + * @param otherStream the {@code KStream} to be joined with this stream + * @param joiner a {@link ValueJoiner} that computes the join result for a pair of matching records + * @param windows the specification of the {@link JoinWindows} + * @param the value type of the other stream + * @param the value type of the result stream + * @return a {@code KStream} that contains join-records for each key and values computed by the given + * {@link ValueJoiner}, one for each matched record-pair with the same key plus one for each non-matching record of + * this {@code KStream} and within the joining window intervals + * @see #join(KStream, ValueJoiner, JoinWindows) + * @see #outerJoin(KStream, ValueJoiner, JoinWindows) + */ + KStream leftJoin(final KStream otherStream, + final ValueJoiner joiner, + final JoinWindows windows); + /** + * Join records of this stream with another {@code KStream}'s records using windowed left equi join with default + * serializers and deserializers. + * In contrast to {@link #join(KStream, ValueJoinerWithKey, JoinWindows) inner-join}, all records from this stream will + * produce at least one output record (cf. below). + * The join is computed on the records' key with join attribute {@code thisKStream.key == otherKStream.key}. + * Furthermore, two records are only joined if their timestamps are close to each other as defined by the given + * {@link JoinWindows}, i.e., the window defines an additional join predicate on the record timestamps. + *

                + * For each pair of records meeting both join predicates the provided {@link ValueJoinerWithKey} will be called to compute + * a value (with arbitrary type) for the result record. + * Note that the key is read-only and should not be modified, as this can lead to undefined behaviour. + * The key of the result record is the same as for both joining input records. + * Furthermore, for each input record of this {@code KStream} that does not satisfy the join predicate the provided + * {@link ValueJoinerWithKey} will be called with a {@code null} value for the other stream. + * If an input record key or value is {@code null} the record will not be included in the join operation and thus no + * output record will be added to the resulting {@code KStream}. + *

                + * Example (assuming all input records belong to the correct windows):
                this      |  other   |  result
                <K1:A>    |          |  <K1:ValueJoinerWithKey(K1,A,null)>
                <K2:B>    |  <K2:b>  |  <K2:ValueJoinerWithKey(K2,B,b)>
                          |  <K3:c>  |
                + * Both input streams (or to be more precise, their underlying source topics) need to have the same number of + * partitions. + * If this is not the case, you would need to call {@link #repartition(Repartitioned)} (for one input stream) before + * doing the join and specify the "correct" number of partitions via {@link Repartitioned} parameter. + * Furthermore, both input streams need to be co-partitioned on the join key (i.e., use the same partitioner). + * If this requirement is not met, Kafka Streams will automatically repartition the data, i.e., it will create an + * internal repartitioning topic in Kafka and write and re-read the data via this topic before the actual join. + * The repartitioning topic will be named "${applicationId}-<name>-repartition", where "applicationId" is + * user-specified in {@link StreamsConfig} via parameter + * {@link StreamsConfig#APPLICATION_ID_CONFIG APPLICATION_ID_CONFIG}, "<name>" is an internally generated + * name, and "-repartition" is a fixed suffix. + *

                + * Repartitioning can happen for one or both of the joining {@code KStream}s. + * For this case, all data of the stream will be redistributed through the repartitioning topic by writing all + * records to it, and rereading all records from it, such that the join input {@code KStream} is partitioned + * correctly on its key. + *

                + * Both of the joining {@code KStream}s will be materialized in local state stores with auto-generated store names. + * For failure and recovery each store will be backed by an internal changelog topic that will be created in Kafka. + * The changelog topic will be named "${applicationId}-<storename>-changelog", where "applicationId" is user-specified + * in {@link StreamsConfig} via parameter {@link StreamsConfig#APPLICATION_ID_CONFIG APPLICATION_ID_CONFIG}, + * "storeName" is an internally generated name, and "-changelog" is a fixed suffix. + *

                + * You can retrieve all generated internal topic names via {@link Topology#describe()}. + * + * @param otherStream the {@code KStream} to be joined with this stream + * @param joiner a {@link ValueJoinerWithKey} that computes the join result for a pair of matching records + * @param windows the specification of the {@link JoinWindows} + * @param the value type of the other stream + * @param the value type of the result stream + * @return a {@code KStream} that contains join-records for each key and values computed by the given + * {@link ValueJoinerWithKey}, one for each matched record-pair with the same key plus one for each non-matching record of + * this {@code KStream} and within the joining window intervals + * @see #join(KStream, ValueJoinerWithKey, JoinWindows) + * @see #outerJoin(KStream, ValueJoinerWithKey, JoinWindows) + */ + KStream leftJoin(final KStream otherStream, + final ValueJoinerWithKey joiner, + final JoinWindows windows); + + /** + * Join records of this stream with another {@code KStream}'s records using windowed left equi join using the + * {@link StreamJoined} instance for configuration of the {@link Serde key serde}, {@link Serde this stream's value + * serde}, {@link Serde the other stream's value serde}, and used state stores. + * In contrast to {@link #join(KStream, ValueJoiner, JoinWindows) inner-join}, all records from this stream will + * produce at least one output record (cf. below). + * The join is computed on the records' key with join attribute {@code thisKStream.key == otherKStream.key}. + * Furthermore, two records are only joined if their timestamps are close to each other as defined by the given + * {@link JoinWindows}, i.e., the window defines an additional join predicate on the record timestamps. + *

                + * For each pair of records meeting both join predicates the provided {@link ValueJoiner} will be called to compute + * a value (with arbitrary type) for the result record. + * The key of the result record is the same as for both joining input records. + * Furthermore, for each input record of this {@code KStream} that does not satisfy the join predicate the provided + * {@link ValueJoiner} will be called with a {@code null} value for the other stream. + * If an input record key or value is {@code null} the record will not be included in the join operation and thus no + * output record will be added to the resulting {@code KStream}. + *

                + * Example (assuming all input records belong to the correct windows):
                this      |  other   |  result
                <K1:A>    |          |  <K1:ValueJoiner(A,null)>
                <K2:B>    |  <K2:b>  |  <K2:ValueJoiner(B,b)>
                          |  <K3:c>  |
                + * Both input streams (or to be more precise, their underlying source topics) need to have the same number of + * partitions. + * If this is not the case, you would need to call {@link #repartition(Repartitioned)} (for one input stream) before + * doing the join and specify the "correct" number of partitions via {@link Repartitioned} parameter. + * Furthermore, both input streams need to be co-partitioned on the join key (i.e., use the same partitioner). + * If this requirement is not met, Kafka Streams will automatically repartition the data, i.e., it will create an + * internal repartitioning topic in Kafka and write and re-read the data via this topic before the actual join. + * The repartitioning topic will be named "${applicationId}-<name>-repartition", where "applicationId" is + * user-specified in {@link StreamsConfig} via parameter + * {@link StreamsConfig#APPLICATION_ID_CONFIG APPLICATION_ID_CONFIG}, "<name>" is an internally generated + * name, and "-repartition" is a fixed suffix. + *

                + * Repartitioning can happen for one or both of the joining {@code KStream}s. + * For this case, all data of the stream will be redistributed through the repartitioning topic by writing all + * records to it, and rereading all records from it, such that the join input {@code KStream} is partitioned + * correctly on its key. + *

                + * Both of the joining {@code KStream}s will be materialized in local state stores with auto-generated store names, + * unless a name is provided via a {@code Materialized} instance. + * For failure and recovery each store will be backed by an internal changelog topic that will be created in Kafka. + * The changelog topic will be named "${applicationId}-<storename>-changelog", where "applicationId" is user-specified + * in {@link StreamsConfig} via parameter {@link StreamsConfig#APPLICATION_ID_CONFIG APPLICATION_ID_CONFIG}, + * "storeName" is an internally generated name, and "-changelog" is a fixed suffix. + *

                + * You can retrieve all generated internal topic names via {@link Topology#describe()}. + * + * @param the value type of the other stream + * @param the value type of the result stream + * @param otherStream the {@code KStream} to be joined with this stream + * @param joiner a {@link ValueJoiner} that computes the join result for a pair of matching records + * @param windows the specification of the {@link JoinWindows} + * @param streamJoined a {@link StreamJoined} instance to configure serdes and state stores + * @return a {@code KStream} that contains join-records for each key and values computed by the given + * {@link ValueJoiner}, one for each matched record-pair with the same key plus one for each non-matching record of + * this {@code KStream} and within the joining window intervals + * @see #join(KStream, ValueJoiner, JoinWindows, StreamJoined) + * @see #outerJoin(KStream, ValueJoiner, JoinWindows, StreamJoined) + */ + KStream leftJoin(final KStream otherStream, + final ValueJoiner joiner, + final JoinWindows windows, + final StreamJoined streamJoined); + + /** + * Join records of this stream with another {@code KStream}'s records using windowed left equi join using the + * {@link StreamJoined} instance for configuration of the {@link Serde key serde}, {@link Serde this stream's value + * serde}, {@link Serde the other stream's value serde}, and used state stores. + * In contrast to {@link #join(KStream, ValueJoinerWithKey, JoinWindows) inner-join}, all records from this stream will + * produce at least one output record (cf. below). + * The join is computed on the records' key with join attribute {@code thisKStream.key == otherKStream.key}. + * Furthermore, two records are only joined if their timestamps are close to each other as defined by the given + * {@link JoinWindows}, i.e., the window defines an additional join predicate on the record timestamps. + *

                + * For each pair of records meeting both join predicates the provided {@link ValueJoinerWithKey} will be called to compute + * a value (with arbitrary type) for the result record. + * Note that the key is read-only and should not be modified, as this can lead to undefined behaviour. + * The key of the result record is the same as for both joining input records. + * Furthermore, for each input record of this {@code KStream} that does not satisfy the join predicate the provided + * {@link ValueJoinerWithKey} will be called with a {@code null} value for the other stream. + * If an input record key or value is {@code null} the record will not be included in the join operation and thus no + * output record will be added to the resulting {@code KStream}. + *

                + * Example (assuming all input records belong to the correct windows):
                this      |  other   |  result
                <K1:A>    |          |  <K1:ValueJoinerWithKey(K1,A,null)>
                <K2:B>    |  <K2:b>  |  <K2:ValueJoinerWithKey(K2,B,b)>
                          |  <K3:c>  |
                + * Both input streams (or to be more precise, their underlying source topics) need to have the same number of + * partitions. + * If this is not the case, you would need to call {@link #repartition(Repartitioned)} (for one input stream) before + * doing the join and specify the "correct" number of partitions via {@link Repartitioned} parameter. + * Furthermore, both input streams need to be co-partitioned on the join key (i.e., use the same partitioner). + * If this requirement is not met, Kafka Streams will automatically repartition the data, i.e., it will create an + * internal repartitioning topic in Kafka and write and re-read the data via this topic before the actual join. + * The repartitioning topic will be named "${applicationId}-<name>-repartition", where "applicationId" is + * user-specified in {@link StreamsConfig} via parameter + * {@link StreamsConfig#APPLICATION_ID_CONFIG APPLICATION_ID_CONFIG}, "<name>" is an internally generated + * name, and "-repartition" is a fixed suffix. + *

                + * Repartitioning can happen for one or both of the joining {@code KStream}s. + * For this case, all data of the stream will be redistributed through the repartitioning topic by writing all + * records to it, and rereading all records from it, such that the join input {@code KStream} is partitioned + * correctly on its key. + *

                + * Both of the joining {@code KStream}s will be materialized in local state stores with auto-generated store names, + * unless a name is provided via a {@code Materialized} instance. + * For failure and recovery each store will be backed by an internal changelog topic that will be created in Kafka. + * The changelog topic will be named "${applicationId}-<storename>-changelog", where "applicationId" is user-specified + * in {@link StreamsConfig} via parameter {@link StreamsConfig#APPLICATION_ID_CONFIG APPLICATION_ID_CONFIG}, + * "storeName" is an internally generated name, and "-changelog" is a fixed suffix. + *

                + * You can retrieve all generated internal topic names via {@link Topology#describe()}. + * + * @param the value type of the other stream + * @param the value type of the result stream + * @param otherStream the {@code KStream} to be joined with this stream + * @param joiner a {@link ValueJoinerWithKey} that computes the join result for a pair of matching records + * @param windows the specification of the {@link JoinWindows} + * @param streamJoined a {@link StreamJoined} instance to configure serdes and state stores + * @return a {@code KStream} that contains join-records for each key and values computed by the given + * {@link ValueJoinerWithKey}, one for each matched record-pair with the same key plus one for each non-matching record of + * this {@code KStream} and within the joining window intervals + * @see #join(KStream, ValueJoinerWithKey, JoinWindows, StreamJoined) + * @see #outerJoin(KStream, ValueJoinerWithKey, JoinWindows, StreamJoined) + */ + KStream leftJoin(final KStream otherStream, + final ValueJoinerWithKey joiner, + final JoinWindows windows, + final StreamJoined streamJoined); + /** + * Join records of this stream with another {@code KStream}'s records using windowed outer equi join with default + * serializers and deserializers. + * In contrast to {@link #join(KStream, ValueJoiner, JoinWindows) inner-join} or + * {@link #leftJoin(KStream, ValueJoiner, JoinWindows) left-join}, all records from both streams will produce at + * least one output record (cf. below). + * The join is computed on the records' key with join attribute {@code thisKStream.key == otherKStream.key}. + * Furthermore, two records are only joined if their timestamps are close to each other as defined by the given + * {@link JoinWindows}, i.e., the window defines an additional join predicate on the record timestamps. + *

                + * For each pair of records meeting both join predicates the provided {@link ValueJoiner} will be called to compute + * a value (with arbitrary type) for the result record. + * The key of the result record is the same as for both joining input records. + * Furthermore, for each input record of both {@code KStream}s that does not satisfy the join predicate the provided + * {@link ValueJoiner} will be called with a {@code null} value for the this/other stream, respectively. + * If an input record key or value is {@code null} the record will not be included in the join operation and thus no + * output record will be added to the resulting {@code KStream}. + *

                + * Example (assuming all input records belong to the correct windows):
                this      |  other   |  result
                <K1:A>    |          |  <K1:ValueJoiner(A,null)>
                <K2:B>    |  <K2:b>  |  <K2:ValueJoiner(null,b)>
                          |          |  <K2:ValueJoiner(B,b)>
                          |  <K3:c>  |  <K3:ValueJoiner(null,c)>
                + * Both input streams (or to be more precise, their underlying source topics) need to have the same number of + * partitions. + * If this is not the case, you would need to call {@link #repartition(Repartitioned)} (for one input stream) before + * doing the join and specify the "correct" number of partitions via {@link Repartitioned} parameter. + * Furthermore, both input streams need to be co-partitioned on the join key (i.e., use the same partitioner). + * If this requirement is not met, Kafka Streams will automatically repartition the data, i.e., it will create an + * internal repartitioning topic in Kafka and write and re-read the data via this topic before the actual join. + * The repartitioning topic will be named "${applicationId}-<name>-repartition", where "applicationId" is + * user-specified in {@link StreamsConfig} via parameter + * {@link StreamsConfig#APPLICATION_ID_CONFIG APPLICATION_ID_CONFIG}, "<name>" is an internally generated + * name, and "-repartition" is a fixed suffix. + *

                + * Repartitioning can happen for one or both of the joining {@code KStream}s. + * For this case, all data of the stream will be redistributed through the repartitioning topic by writing all + * records to it, and rereading all records from it, such that the join input {@code KStream} is partitioned + * correctly on its key. + *

                + * Both of the joining {@code KStream}s will be materialized in local state stores with auto-generated store names. + * For failure and recovery each store will be backed by an internal changelog topic that will be created in Kafka. + * The changelog topic will be named "${applicationId}-<storename>-changelog", where "applicationId" is user-specified + * in {@link StreamsConfig} via parameter {@link StreamsConfig#APPLICATION_ID_CONFIG APPLICATION_ID_CONFIG}, + * "storeName" is an internally generated name, and "-changelog" is a fixed suffix. + *
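                + * For illustration only, a minimal sketch assuming existing co-partitioned {@code KStream<String, String>}
                + * instances {@code orders} and {@code payments}; the joiner must handle a {@code null} value on either side:
                + * <pre>{@code
                + * KStream<String, String> merged = orders.outerJoin(
                + *     payments,
                + *     (orderValue, paymentValue) ->
                + *         (orderValue == null ? "no-order" : orderValue)
                + *             + "/" + (paymentValue == null ? "no-payment" : paymentValue),
                + *     JoinWindows.of(Duration.ofMinutes(5)));
                + * }</pre>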

                + * You can retrieve all generated internal topic names via {@link Topology#describe()}. + * + * @param otherStream the {@code KStream} to be joined with this stream + * @param joiner a {@link ValueJoiner} that computes the join result for a pair of matching records + * @param windows the specification of the {@link JoinWindows} + * @param the value type of the other stream + * @param the value type of the result stream + * @return a {@code KStream} that contains join-records for each key and values computed by the given + * {@link ValueJoiner}, one for each matched record-pair with the same key plus one for each non-matching record of + * both {@code KStream} and within the joining window intervals + * @see #join(KStream, ValueJoiner, JoinWindows) + * @see #leftJoin(KStream, ValueJoiner, JoinWindows) + */ + KStream outerJoin(final KStream otherStream, + final ValueJoiner joiner, + final JoinWindows windows); + /** + * Join records of this stream with another {@code KStream}'s records using windowed outer equi join with default + * serializers and deserializers. + * In contrast to {@link #join(KStream, ValueJoinerWithKey, JoinWindows) inner-join} or + * {@link #leftJoin(KStream, ValueJoinerWithKey, JoinWindows) left-join}, all records from both streams will produce at + * least one output record (cf. below). + * The join is computed on the records' key with join attribute {@code thisKStream.key == otherKStream.key}. + * Furthermore, two records are only joined if their timestamps are close to each other as defined by the given + * {@link JoinWindows}, i.e., the window defines an additional join predicate on the record timestamps. + *

                + * For each pair of records meeting both join predicates the provided {@link ValueJoinerWithKey} will be called to compute + * a value (with arbitrary type) for the result record. + * Note that the key is read-only and should not be modified, as this can lead to undefined behaviour. + * The key of the result record is the same as for both joining input records. + * Furthermore, for each input record of both {@code KStream}s that does not satisfy the join predicate the provided + * {@link ValueJoinerWithKey} will be called with a {@code null} value for the this/other stream, respectively. + * If an input record key or value is {@code null} the record will not be included in the join operation and thus no + * output record will be added to the resulting {@code KStream}. + *

                + * Example (assuming all input records belong to the correct windows):
                this      |  other   |  result
                <K1:A>    |          |  <K1:ValueJoinerWithKey(K1,A,null)>
                <K2:B>    |  <K2:b>  |  <K2:ValueJoinerWithKey(K2,null,b)>
                          |          |  <K2:ValueJoinerWithKey(K2,B,b)>
                          |  <K3:c>  |  <K3:ValueJoinerWithKey(K3,null,c)>
                + * Both input streams (or to be more precise, their underlying source topics) need to have the same number of + * partitions. + * If this is not the case, you would need to call {@link #repartition(Repartitioned)} (for one input stream) before + * doing the join and specify the "correct" number of partitions via {@link Repartitioned} parameter. + * Furthermore, both input streams need to be co-partitioned on the join key (i.e., use the same partitioner). + * If this requirement is not met, Kafka Streams will automatically repartition the data, i.e., it will create an + * internal repartitioning topic in Kafka and write and re-read the data via this topic before the actual join. + * The repartitioning topic will be named "${applicationId}-<name>-repartition", where "applicationId" is + * user-specified in {@link StreamsConfig} via parameter + * {@link StreamsConfig#APPLICATION_ID_CONFIG APPLICATION_ID_CONFIG}, "<name>" is an internally generated + * name, and "-repartition" is a fixed suffix. + *

                + * Repartitioning can happen for one or both of the joining {@code KStream}s. + * For this case, all data of the stream will be redistributed through the repartitioning topic by writing all + * records to it, and rereading all records from it, such that the join input {@code KStream} is partitioned + * correctly on its key. + *

                + * Both of the joining {@code KStream}s will be materialized in local state stores with auto-generated store names. + * For failure and recovery each store will be backed by an internal changelog topic that will be created in Kafka. + * The changelog topic will be named "${applicationId}-<storename>-changelog", where "applicationId" is user-specified + * in {@link StreamsConfig} via parameter {@link StreamsConfig#APPLICATION_ID_CONFIG APPLICATION_ID_CONFIG}, + * "storeName" is an internally generated name, and "-changelog" is a fixed suffix. + *

                + * You can retrieve all generated internal topic names via {@link Topology#describe()}. + * + * @param otherStream the {@code KStream} to be joined with this stream + * @param joiner a {@link ValueJoinerWithKey} that computes the join result for a pair of matching records + * @param windows the specification of the {@link JoinWindows} + * @param the value type of the other stream + * @param the value type of the result stream + * @return a {@code KStream} that contains join-records for each key and values computed by the given + * {@link ValueJoinerWithKey}, one for each matched record-pair with the same key plus one for each non-matching record of + * both {@code KStream} and within the joining window intervals + * @see #join(KStream, ValueJoinerWithKey, JoinWindows) + * @see #leftJoin(KStream, ValueJoinerWithKey, JoinWindows) + */ + KStream outerJoin(final KStream otherStream, + final ValueJoinerWithKey joiner, + final JoinWindows windows); + + /** + * Join records of this stream with another {@code KStream}'s records using windowed outer equi join using the + * {@link StreamJoined} instance for configuration of the {@link Serde key serde}, {@link Serde this stream's value + * serde}, {@link Serde the other stream's value serde}, and used state stores. + * In contrast to {@link #join(KStream, ValueJoiner, JoinWindows) inner-join} or + * {@link #leftJoin(KStream, ValueJoiner, JoinWindows) left-join}, all records from both streams will produce at + * least one output record (cf. below). + * The join is computed on the records' key with join attribute {@code thisKStream.key == otherKStream.key}. + * Furthermore, two records are only joined if their timestamps are close to each other as defined by the given + * {@link JoinWindows}, i.e., the window defines an additional join predicate on the record timestamps. + *

                + * For each pair of records meeting both join predicates the provided {@link ValueJoiner} will be called to compute + * a value (with arbitrary type) for the result record. + * The key of the result record is the same as for both joining input records. + * Furthermore, for each input record of both {@code KStream}s that does not satisfy the join predicate the provided + * {@link ValueJoiner} will be called with a {@code null} value for this/other stream, respectively. + * If an input record key or value is {@code null} the record will not be included in the join operation and thus no + * output record will be added to the resulting {@code KStream}. + *

+ * Example (assuming all input records belong to the correct windows):
+ * <table border='1'>
+ * <tr><th>this</th><th>other</th><th>result</th></tr>
+ * <tr><td>&lt;K1:A&gt;</td><td></td><td>&lt;K1:ValueJoiner(A,null)&gt;</td></tr>
+ * <tr><td>&lt;K2:B&gt;</td><td>&lt;K2:b&gt;</td><td>&lt;K2:ValueJoiner(null,b)&gt;<br>&lt;K2:ValueJoiner(B,b)&gt;</td></tr>
+ * <tr><td></td><td>&lt;K3:c&gt;</td><td>&lt;K3:ValueJoiner(null,c)&gt;</td></tr>
+ * </table>
                + * Both input streams (or to be more precise, their underlying source topics) need to have the same number of + * partitions. + * If this is not the case, you would need to call {@link #repartition(Repartitioned)} (for one input stream) before + * doing the join and specify the "correct" number of partitions via {@link Repartitioned} parameter. + * Furthermore, both input streams need to be co-partitioned on the join key (i.e., use the same partitioner). + * If this requirement is not met, Kafka Streams will automatically repartition the data, i.e., it will create an + * internal repartitioning topic in Kafka and write and re-read the data via this topic before the actual join. + * The repartitioning topic will be named "${applicationId}-<name>-repartition", where "applicationId" is + * user-specified in {@link StreamsConfig} via parameter + * {@link StreamsConfig#APPLICATION_ID_CONFIG APPLICATION_ID_CONFIG}, "<name>" is an internally generated + * name, and "-repartition" is a fixed suffix. + *

                + * Repartitioning can happen for one or both of the joining {@code KStream}s. + * For this case, all data of the stream will be redistributed through the repartitioning topic by writing all + * records to it, and rereading all records from it, such that the join input {@code KStream} is partitioned + * correctly on its key. + *

                + * Both of the joining {@code KStream}s will be materialized in local state stores with auto-generated store names, + * unless a name is provided via a {@code Materialized} instance. + * For failure and recovery each store will be backed by an internal changelog topic that will be created in Kafka. + * The changelog topic will be named "${applicationId}-<storename>-changelog", where "applicationId" is user-specified + * in {@link StreamsConfig} via parameter {@link StreamsConfig#APPLICATION_ID_CONFIG APPLICATION_ID_CONFIG}, + * "storeName" is an internally generated name, and "-changelog" is a fixed suffix. + *

                + * You can retrieve all generated internal topic names via {@link Topology#describe()}. + * + * @param the value type of the other stream + * @param the value type of the result stream + * @param otherStream the {@code KStream} to be joined with this stream + * @param joiner a {@link ValueJoiner} that computes the join result for a pair of matching records + * @param windows the specification of the {@link JoinWindows} + * @param streamJoined a {@link StreamJoined} instance to configure serdes and state stores + * @return a {@code KStream} that contains join-records for each key and values computed by the given + * {@link ValueJoiner}, one for each matched record-pair with the same key plus one for each non-matching record of + * both {@code KStream} and within the joining window intervals + * @see #join(KStream, ValueJoiner, JoinWindows, StreamJoined) + * @see #leftJoin(KStream, ValueJoiner, JoinWindows, StreamJoined) + */ + KStream outerJoin(final KStream otherStream, + final ValueJoiner joiner, + final JoinWindows windows, + final StreamJoined streamJoined); + + /** + * Join records of this stream with another {@code KStream}'s records using windowed outer equi join using the + * {@link StreamJoined} instance for configuration of the {@link Serde key serde}, {@link Serde this stream's value + * serde}, {@link Serde the other stream's value serde}, and used state stores. + * In contrast to {@link #join(KStream, ValueJoinerWithKey, JoinWindows) inner-join} or + * {@link #leftJoin(KStream, ValueJoinerWithKey, JoinWindows) left-join}, all records from both streams will produce at + * least one output record (cf. below). + * The join is computed on the records' key with join attribute {@code thisKStream.key == otherKStream.key}. + * Furthermore, two records are only joined if their timestamps are close to each other as defined by the given + * {@link JoinWindows}, i.e., the window defines an additional join predicate on the record timestamps. + *

                + * For each pair of records meeting both join predicates the provided {@link ValueJoinerWithKey} will be called to compute + * a value (with arbitrary type) for the result record. + * Note that the key is read-only and should not be modified, as this can lead to undefined behaviour. + * The key of the result record is the same as for both joining input records. + * Furthermore, for each input record of both {@code KStream}s that does not satisfy the join predicate the provided + * {@link ValueJoinerWithKey} will be called with a {@code null} value for this/other stream, respectively. + * If an input record key or value is {@code null} the record will not be included in the join operation and thus no + * output record will be added to the resulting {@code KStream}. + *

+ * Example (assuming all input records belong to the correct windows):
+ * <table border='1'>
+ * <tr><th>this</th><th>other</th><th>result</th></tr>
+ * <tr><td>&lt;K1:A&gt;</td><td></td><td>&lt;K1:ValueJoinerWithKey(K1,A,null)&gt;</td></tr>
+ * <tr><td>&lt;K2:B&gt;</td><td>&lt;K2:b&gt;</td><td>&lt;K2:ValueJoinerWithKey(K2,null,b)&gt;<br>&lt;K2:ValueJoinerWithKey(K2,B,b)&gt;</td></tr>
+ * <tr><td></td><td>&lt;K3:c&gt;</td><td>&lt;K3:ValueJoinerWithKey(K3,null,c)&gt;</td></tr>
+ * </table>
                + * Both input streams (or to be more precise, their underlying source topics) need to have the same number of + * partitions. + * If this is not the case, you would need to call {@link #repartition(Repartitioned)} (for one input stream) before + * doing the join and specify the "correct" number of partitions via {@link Repartitioned} parameter. + * Furthermore, both input streams need to be co-partitioned on the join key (i.e., use the same partitioner). + * If this requirement is not met, Kafka Streams will automatically repartition the data, i.e., it will create an + * internal repartitioning topic in Kafka and write and re-read the data via this topic before the actual join. + * The repartitioning topic will be named "${applicationId}-<name>-repartition", where "applicationId" is + * user-specified in {@link StreamsConfig} via parameter + * {@link StreamsConfig#APPLICATION_ID_CONFIG APPLICATION_ID_CONFIG}, "<name>" is an internally generated + * name, and "-repartition" is a fixed suffix. + *

                + * Repartitioning can happen for one or both of the joining {@code KStream}s. + * For this case, all data of the stream will be redistributed through the repartitioning topic by writing all + * records to it, and rereading all records from it, such that the join input {@code KStream} is partitioned + * correctly on its key. + *

                + * Both of the joining {@code KStream}s will be materialized in local state stores with auto-generated store names, + * unless a name is provided via a {@code Materialized} instance. + * For failure and recovery each store will be backed by an internal changelog topic that will be created in Kafka. + * The changelog topic will be named "${applicationId}-<storename>-changelog", where "applicationId" is user-specified + * in {@link StreamsConfig} via parameter {@link StreamsConfig#APPLICATION_ID_CONFIG APPLICATION_ID_CONFIG}, + * "storeName" is an internally generated name, and "-changelog" is a fixed suffix. + *

                + * You can retrieve all generated internal topic names via {@link Topology#describe()}. + * + * @param the value type of the other stream + * @param the value type of the result stream + * @param otherStream the {@code KStream} to be joined with this stream + * @param joiner a {@link ValueJoinerWithKey} that computes the join result for a pair of matching records + * @param windows the specification of the {@link JoinWindows} + * @param streamJoined a {@link StreamJoined} instance to configure serdes and state stores + * @return a {@code KStream} that contains join-records for each key and values computed by the given + * {@link ValueJoinerWithKey}, one for each matched record-pair with the same key plus one for each non-matching record of + * both {@code KStream} and within the joining window intervals + * @see #join(KStream, ValueJoinerWithKey, JoinWindows, StreamJoined) + * @see #leftJoin(KStream, ValueJoinerWithKey, JoinWindows, StreamJoined) + */ + KStream outerJoin(final KStream otherStream, + final ValueJoinerWithKey joiner, + final JoinWindows windows, + final StreamJoined streamJoined); + /** + * Join records of this stream with {@link KTable}'s records using non-windowed inner equi join with default + * serializers and deserializers. + * The join is a primary key table lookup join with join attribute {@code stream.key == table.key}. + * "Table lookup join" means, that results are only computed if {@code KStream} records are processed. + * This is done by performing a lookup for matching records in the current (i.e., processing time) internal + * {@link KTable} state. + * In contrast, processing {@link KTable} input records will only update the internal {@link KTable} state and + * will not produce any result records. + *
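// ---------------------------------------------------------------------------------------------
// Illustrative sketch (not part of this patch): one way the windowed outerJoin overloads
// documented above might be wired up. The topic names, serdes, the 5-minute window, and the use
// of JoinWindows.ofTimeDifferenceWithNoGrace (available in newer Kafka Streams releases; older
// releases would use JoinWindows.of) are assumptions made for this example only.
// ---------------------------------------------------------------------------------------------
import java.time.Duration;

import org.apache.kafka.common.serialization.Serdes;
import org.apache.kafka.streams.StreamsBuilder;
import org.apache.kafka.streams.Topology;
import org.apache.kafka.streams.kstream.JoinWindows;
import org.apache.kafka.streams.kstream.KStream;
import org.apache.kafka.streams.kstream.StreamJoined;

public class OuterJoinSketch {

    static Topology build() {
        final StreamsBuilder builder = new StreamsBuilder();
        final KStream<String, String> left = builder.stream("left-topic");   // assumed topic
        final KStream<String, Long> right = builder.stream("right-topic");   // assumed topic

        // Records join when their keys are equal and their timestamps are at most 5 minutes apart;
        // unmatched records from either side still produce output with null for the missing value.
        left.outerJoin(
                right,
                (leftValue, rightValue) -> leftValue + "/" + rightValue,     // ValueJoiner
                JoinWindows.ofTimeDifferenceWithNoGrace(Duration.ofMinutes(5)),
                StreamJoined.with(Serdes.String(), Serdes.String(), Serdes.Long()))
            .to("joined-topic");                                             // assumed output topic

        // Topology#describe() on the built topology lists the generated repartition/changelog topics.
        return builder.build();
    }
}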

                + * For each {@code KStream} record that finds a corresponding record in {@link KTable} the provided + * {@link ValueJoiner} will be called to compute a value (with arbitrary type) for the result record. + * The key of the result record is the same as for both joining input records. + * If an {@code KStream} input record key or value is {@code null} the record will not be included in the join + * operation and thus no output record will be added to the resulting {@code KStream}. + *

+ * Example:
+ * <table border='1'>
+ * <tr><th>KStream</th><th>KTable</th><th>state</th><th>result</th></tr>
+ * <tr><td>&lt;K1:A&gt;</td><td></td><td></td><td></td></tr>
+ * <tr><td></td><td>&lt;K1:b&gt;</td><td>&lt;K1:b&gt;</td><td></td></tr>
+ * <tr><td>&lt;K1:C&gt;</td><td></td><td>&lt;K1:b&gt;</td><td>&lt;K1:ValueJoiner(C,b)&gt;</td></tr>
+ * </table>
                + * Both input streams (or to be more precise, their underlying source topics) need to have the same number of + * partitions. + * If this is not the case, you would need to call {@link #repartition(Repartitioned)} for this {@code KStream} + * before doing the join, specifying the same number of partitions via {@link Repartitioned} parameter as the given + * {@link KTable}. + * Furthermore, both input streams need to be co-partitioned on the join key (i.e., use the same partitioner); + * cf. {@link #join(GlobalKTable, KeyValueMapper, ValueJoiner)}. + * If this requirement is not met, Kafka Streams will automatically repartition the data, i.e., it will create an + * internal repartitioning topic in Kafka and write and re-read the data via this topic before the actual join. + * The repartitioning topic will be named "${applicationId}-<name>-repartition", where "applicationId" is + * user-specified in {@link StreamsConfig} via parameter + * {@link StreamsConfig#APPLICATION_ID_CONFIG APPLICATION_ID_CONFIG}, "<name>" is an internally generated + * name, and "-repartition" is a fixed suffix. + *

                + * You can retrieve all generated internal topic names via {@link Topology#describe()}. + *

                + * Repartitioning can happen only for this {@code KStream} but not for the provided {@link KTable}. + * For this case, all data of the stream will be redistributed through the repartitioning topic by writing all + * records to it, and rereading all records from it, such that the join input {@code KStream} is partitioned + * correctly on its key. + * + * @param table the {@link KTable} to be joined with this stream + * @param joiner a {@link ValueJoiner} that computes the join result for a pair of matching records + * @param the value type of the table + * @param the value type of the result stream + * @return a {@code KStream} that contains join-records for each key and values computed by the given + * {@link ValueJoiner}, one for each matched record-pair with the same key + * @see #leftJoin(KTable, ValueJoiner) + * @see #join(GlobalKTable, KeyValueMapper, ValueJoiner) + */ + KStream join(final KTable table, + final ValueJoiner joiner); + + /** + * Join records of this stream with {@link KTable}'s records using non-windowed inner equi join with default + * serializers and deserializers. + * The join is a primary key table lookup join with join attribute {@code stream.key == table.key}. + * "Table lookup join" means, that results are only computed if {@code KStream} records are processed. + * This is done by performing a lookup for matching records in the current (i.e., processing time) internal + * {@link KTable} state. + * In contrast, processing {@link KTable} input records will only update the internal {@link KTable} state and + * will not produce any result records. + *

                + * For each {@code KStream} record that finds a corresponding record in {@link KTable} the provided + * {@link ValueJoinerWithKey} will be called to compute a value (with arbitrary type) for the result record. + * Note that the key is read-only and should not be modified, as this can lead to undefined behaviour. + * + * The key of the result record is the same as for both joining input records. + * If an {@code KStream} input record key or value is {@code null} the record will not be included in the join + * operation and thus no output record will be added to the resulting {@code KStream}. + *

+ * Example:
+ * <table border='1'>
+ * <tr><th>KStream</th><th>KTable</th><th>state</th><th>result</th></tr>
+ * <tr><td>&lt;K1:A&gt;</td><td></td><td></td><td></td></tr>
+ * <tr><td></td><td>&lt;K1:b&gt;</td><td>&lt;K1:b&gt;</td><td></td></tr>
+ * <tr><td>&lt;K1:C&gt;</td><td></td><td>&lt;K1:b&gt;</td><td>&lt;K1:ValueJoinerWithKey(K1,C,b)&gt;</td></tr>
+ * </table>
                + * Both input streams (or to be more precise, their underlying source topics) need to have the same number of + * partitions. + * If this is not the case, you would need to call {@link #repartition(Repartitioned)} for this {@code KStream} + * before doing the join, specifying the same number of partitions via {@link Repartitioned} parameter as the given + * {@link KTable}. + * Furthermore, both input streams need to be co-partitioned on the join key (i.e., use the same partitioner); + * cf. {@link #join(GlobalKTable, KeyValueMapper, ValueJoinerWithKey)}. + * If this requirement is not met, Kafka Streams will automatically repartition the data, i.e., it will create an + * internal repartitioning topic in Kafka and write and re-read the data via this topic before the actual join. + * The repartitioning topic will be named "${applicationId}-<name>-repartition", where "applicationId" is + * user-specified in {@link StreamsConfig} via parameter + * {@link StreamsConfig#APPLICATION_ID_CONFIG APPLICATION_ID_CONFIG}, "<name>" is an internally generated + * name, and "-repartition" is a fixed suffix. + *

                + * You can retrieve all generated internal topic names via {@link Topology#describe()}. + *

                + * Repartitioning can happen only for this {@code KStream} but not for the provided {@link KTable}. + * For this case, all data of the stream will be redistributed through the repartitioning topic by writing all + * records to it, and rereading all records from it, such that the join input {@code KStream} is partitioned + * correctly on its key. + * + * @param table the {@link KTable} to be joined with this stream + * @param joiner a {@link ValueJoinerWithKey} that computes the join result for a pair of matching records + * @param the value type of the table + * @param the value type of the result stream + * @return a {@code KStream} that contains join-records for each key and values computed by the given + * {@link ValueJoinerWithKey}, one for each matched record-pair with the same key + * @see #leftJoin(KTable, ValueJoinerWithKey) + * @see #join(GlobalKTable, KeyValueMapper, ValueJoinerWithKey) + */ + KStream join(final KTable table, + final ValueJoinerWithKey joiner); + + /** + * Join records of this stream with {@link KTable}'s records using non-windowed inner equi join with default + * serializers and deserializers. + * The join is a primary key table lookup join with join attribute {@code stream.key == table.key}. + * "Table lookup join" means, that results are only computed if {@code KStream} records are processed. + * This is done by performing a lookup for matching records in the current (i.e., processing time) internal + * {@link KTable} state. + * In contrast, processing {@link KTable} input records will only update the internal {@link KTable} state and + * will not produce any result records. + *

                + * For each {@code KStream} record that finds a corresponding record in {@link KTable} the provided + * {@link ValueJoiner} will be called to compute a value (with arbitrary type) for the result record. + * The key of the result record is the same as for both joining input records. + * If an {@code KStream} input record key or value is {@code null} the record will not be included in the join + * operation and thus no output record will be added to the resulting {@code KStream}. + *

+ * Example:
+ * <table border='1'>
+ * <tr><th>KStream</th><th>KTable</th><th>state</th><th>result</th></tr>
+ * <tr><td>&lt;K1:A&gt;</td><td></td><td></td><td></td></tr>
+ * <tr><td></td><td>&lt;K1:b&gt;</td><td>&lt;K1:b&gt;</td><td></td></tr>
+ * <tr><td>&lt;K1:C&gt;</td><td></td><td>&lt;K1:b&gt;</td><td>&lt;K1:ValueJoiner(C,b)&gt;</td></tr>
+ * </table>
                + * Both input streams (or to be more precise, their underlying source topics) need to have the same number of + * partitions. + * If this is not the case, you would need to call {@link #repartition(Repartitioned)} for this {@code KStream} + * before doing the join, specifying the same number of partitions via {@link Repartitioned} parameter as the given + * {@link KTable}. + * Furthermore, both input streams need to be co-partitioned on the join key (i.e., use the same partitioner); + * cf. {@link #join(GlobalKTable, KeyValueMapper, ValueJoiner)}. + * If this requirement is not met, Kafka Streams will automatically repartition the data, i.e., it will create an + * internal repartitioning topic in Kafka and write and re-read the data via this topic before the actual join. + * The repartitioning topic will be named "${applicationId}-<name>-repartition", where "applicationId" is + * user-specified in {@link StreamsConfig} via parameter + * {@link StreamsConfig#APPLICATION_ID_CONFIG APPLICATION_ID_CONFIG}, "<name>" is an internally generated + * name, and "-repartition" is a fixed suffix. + *

                + * You can retrieve all generated internal topic names via {@link Topology#describe()}. + *

                + * Repartitioning can happen only for this {@code KStream} but not for the provided {@link KTable}. + * For this case, all data of the stream will be redistributed through the repartitioning topic by writing all + * records to it, and rereading all records from it, such that the join input {@code KStream} is partitioned + * correctly on its key. + * + * @param table the {@link KTable} to be joined with this stream + * @param joiner a {@link ValueJoiner} that computes the join result for a pair of matching records + * @param joined a {@link Joined} instance that defines the serdes to + * be used to serialize/deserialize inputs of the joined streams + * @param the value type of the table + * @param the value type of the result stream + * @return a {@code KStream} that contains join-records for each key and values computed by the given + * {@link ValueJoiner}, one for each matched record-pair with the same key + * @see #leftJoin(KTable, ValueJoiner, Joined) + * @see #join(GlobalKTable, KeyValueMapper, ValueJoiner) + */ + KStream join(final KTable table, + final ValueJoiner joiner, + final Joined joined); + + /** + * Join records of this stream with {@link KTable}'s records using non-windowed inner equi join with default + * serializers and deserializers. + * The join is a primary key table lookup join with join attribute {@code stream.key == table.key}. + * "Table lookup join" means, that results are only computed if {@code KStream} records are processed. + * This is done by performing a lookup for matching records in the current (i.e., processing time) internal + * {@link KTable} state. + * In contrast, processing {@link KTable} input records will only update the internal {@link KTable} state and + * will not produce any result records. + *

                + * For each {@code KStream} record that finds a corresponding record in {@link KTable} the provided + * {@link ValueJoinerWithKey} will be called to compute a value (with arbitrary type) for the result record. + * The key of the result record is the same as for both joining input records. + * Note that the key is read-only and should not be modified, as this can lead to undefined behaviour. + * + * If an {@code KStream} input record key or value is {@code null} the record will not be included in the join + * operation and thus no output record will be added to the resulting {@code KStream}. + *

+ * Example:
+ * <table border='1'>
+ * <tr><th>KStream</th><th>KTable</th><th>state</th><th>result</th></tr>
+ * <tr><td>&lt;K1:A&gt;</td><td></td><td></td><td></td></tr>
+ * <tr><td></td><td>&lt;K1:b&gt;</td><td>&lt;K1:b&gt;</td><td></td></tr>
+ * <tr><td>&lt;K1:C&gt;</td><td></td><td>&lt;K1:b&gt;</td><td>&lt;K1:ValueJoinerWithKey(K1,C,b)&gt;</td></tr>
+ * </table>
                + * Both input streams (or to be more precise, their underlying source topics) need to have the same number of + * partitions. + * If this is not the case, you would need to call {@link #repartition(Repartitioned)} for this {@code KStream} + * before doing the join, specifying the same number of partitions via {@link Repartitioned} parameter as the given + * {@link KTable}. + * Furthermore, both input streams need to be co-partitioned on the join key (i.e., use the same partitioner); + * cf. {@link #join(GlobalKTable, KeyValueMapper, ValueJoinerWithKey)}. + * If this requirement is not met, Kafka Streams will automatically repartition the data, i.e., it will create an + * internal repartitioning topic in Kafka and write and re-read the data via this topic before the actual join. + * The repartitioning topic will be named "${applicationId}-<name>-repartition", where "applicationId" is + * user-specified in {@link StreamsConfig} via parameter + * {@link StreamsConfig#APPLICATION_ID_CONFIG APPLICATION_ID_CONFIG}, "<name>" is an internally generated + * name, and "-repartition" is a fixed suffix. + *

                + * You can retrieve all generated internal topic names via {@link Topology#describe()}. + *

                + * Repartitioning can happen only for this {@code KStream} but not for the provided {@link KTable}. + * For this case, all data of the stream will be redistributed through the repartitioning topic by writing all + * records to it, and rereading all records from it, such that the join input {@code KStream} is partitioned + * correctly on its key. + * + * @param table the {@link KTable} to be joined with this stream + * @param joiner a {@link ValueJoinerWithKey} that computes the join result for a pair of matching records + * @param joined a {@link Joined} instance that defines the serdes to + * be used to serialize/deserialize inputs of the joined streams + * @param the value type of the table + * @param the value type of the result stream + * @return a {@code KStream} that contains join-records for each key and values computed by the given + * {@link ValueJoinerWithKey}, one for each matched record-pair with the same key + * @see #leftJoin(KTable, ValueJoinerWithKey, Joined) + * @see #join(GlobalKTable, KeyValueMapper, ValueJoinerWithKey) + */ + KStream join(final KTable table, + final ValueJoinerWithKey joiner, + final Joined joined); + + /** + * Join records of this stream with {@link KTable}'s records using non-windowed left equi join with default + * serializers and deserializers. + * In contrast to {@link #join(KTable, ValueJoiner) inner-join}, all records from this stream will produce an + * output record (cf. below). + * The join is a primary key table lookup join with join attribute {@code stream.key == table.key}. + * "Table lookup join" means, that results are only computed if {@code KStream} records are processed. + * This is done by performing a lookup for matching records in the current (i.e., processing time) internal + * {@link KTable} state. + * In contrast, processing {@link KTable} input records will only update the internal {@link KTable} state and + * will not produce any result records. + *
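// Illustrative sketch (not part of this patch): using the stream-table join(KTable, ...)
// overloads documented above to enrich a stream against the current table state. The topic
// names, serdes, and joiner logic are assumptions made for this example only.
import org.apache.kafka.common.serialization.Serdes;
import org.apache.kafka.streams.StreamsBuilder;
import org.apache.kafka.streams.Topology;
import org.apache.kafka.streams.kstream.Joined;
import org.apache.kafka.streams.kstream.KStream;
import org.apache.kafka.streams.kstream.KTable;

public class StreamTableJoinSketch {

    static Topology build() {
        final StreamsBuilder builder = new StreamsBuilder();
        final KStream<String, Long> orders = builder.stream("orders");       // assumed topic
        final KTable<String, String> customers = builder.table("customers"); // assumed topic

        // Inner table-lookup join: only order records with a non-null key/value and a matching
        // customer in the current table state produce output; Joined supplies the serdes.
        orders.join(
                customers,
                (orderValue, customerName) -> customerName + " ordered " + orderValue,
                Joined.with(Serdes.String(), Serdes.Long(), Serdes.String()))
            .to("enriched-orders");                                          // assumed output topic

        return builder.build();
    }
}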

                + * For each {@code KStream} record whether or not it finds a corresponding record in {@link KTable} the provided + * {@link ValueJoiner} will be called to compute a value (with arbitrary type) for the result record. + * If no {@link KTable} record was found during lookup, a {@code null} value will be provided to {@link ValueJoiner}. + * The key of the result record is the same as for both joining input records. + * If an {@code KStream} input record key or value is {@code null} the record will not be included in the join + * operation and thus no output record will be added to the resulting {@code KStream}. + *

+ * Example:
+ * <table border='1'>
+ * <tr><th>KStream</th><th>KTable</th><th>state</th><th>result</th></tr>
+ * <tr><td>&lt;K1:A&gt;</td><td></td><td></td><td>&lt;K1:ValueJoiner(A,null)&gt;</td></tr>
+ * <tr><td></td><td>&lt;K1:b&gt;</td><td>&lt;K1:b&gt;</td><td></td></tr>
+ * <tr><td>&lt;K1:C&gt;</td><td></td><td>&lt;K1:b&gt;</td><td>&lt;K1:ValueJoiner(C,b)&gt;</td></tr>
+ * </table>
                + * Both input streams (or to be more precise, their underlying source topics) need to have the same number of + * partitions. + * If this is not the case, you would need to call {@link #repartition(Repartitioned)} for this {@code KStream} + * before doing the join, specifying the same number of partitions via {@link Repartitioned} parameter as the given + * {@link KTable}. + * Furthermore, both input streams need to be co-partitioned on the join key (i.e., use the same partitioner); + * cf. {@link #join(GlobalKTable, KeyValueMapper, ValueJoiner)}. + * If this requirement is not met, Kafka Streams will automatically repartition the data, i.e., it will create an + * internal repartitioning topic in Kafka and write and re-read the data via this topic before the actual join. + * The repartitioning topic will be named "${applicationId}-<name>-repartition", where "applicationId" is + * user-specified in {@link StreamsConfig} via parameter + * {@link StreamsConfig#APPLICATION_ID_CONFIG APPLICATION_ID_CONFIG}, "<name>" is an internally generated + * name, and "-repartition" is a fixed suffix. + *

                + * You can retrieve all generated internal topic names via {@link Topology#describe()}. + *

                + * Repartitioning can happen only for this {@code KStream} but not for the provided {@link KTable}. + * For this case, all data of the stream will be redistributed through the repartitioning topic by writing all + * records to it, and rereading all records from it, such that the join input {@code KStream} is partitioned + * correctly on its key. + * + * @param table the {@link KTable} to be joined with this stream + * @param joiner a {@link ValueJoiner} that computes the join result for a pair of matching records + * @param the value type of the table + * @param the value type of the result stream + * @return a {@code KStream} that contains join-records for each key and values computed by the given + * {@link ValueJoiner}, one output for each input {@code KStream} record + * @see #join(KTable, ValueJoiner) + * @see #leftJoin(GlobalKTable, KeyValueMapper, ValueJoiner) + */ + KStream leftJoin(final KTable table, + final ValueJoiner joiner); + + /** + * Join records of this stream with {@link KTable}'s records using non-windowed left equi join with default + * serializers and deserializers. + * In contrast to {@link #join(KTable, ValueJoinerWithKey) inner-join}, all records from this stream will produce an + * output record (cf. below). + * The join is a primary key table lookup join with join attribute {@code stream.key == table.key}. + * "Table lookup join" means, that results are only computed if {@code KStream} records are processed. + * This is done by performing a lookup for matching records in the current (i.e., processing time) internal + * {@link KTable} state. + * In contrast, processing {@link KTable} input records will only update the internal {@link KTable} state and + * will not produce any result records. + *

                + * For each {@code KStream} record whether or not it finds a corresponding record in {@link KTable} the provided + * {@link ValueJoinerWithKey} will be called to compute a value (with arbitrary type) for the result record. + * If no {@link KTable} record was found during lookup, a {@code null} value will be provided to {@link ValueJoinerWithKey}. + * The key of the result record is the same as for both joining input records. + * Note that the key is read-only and should not be modified, as this can lead to undefined behaviour. + * If an {@code KStream} input record key or value is {@code null} the record will not be included in the join + * operation and thus no output record will be added to the resulting {@code KStream}. + *

+ * Example:
+ * <table border='1'>
+ * <tr><th>KStream</th><th>KTable</th><th>state</th><th>result</th></tr>
+ * <tr><td>&lt;K1:A&gt;</td><td></td><td></td><td>&lt;K1:ValueJoinerWithKey(K1,A,null)&gt;</td></tr>
+ * <tr><td></td><td>&lt;K1:b&gt;</td><td>&lt;K1:b&gt;</td><td></td></tr>
+ * <tr><td>&lt;K1:C&gt;</td><td></td><td>&lt;K1:b&gt;</td><td>&lt;K1:ValueJoinerWithKey(K1,C,b)&gt;</td></tr>
+ * </table>
                + * Both input streams (or to be more precise, their underlying source topics) need to have the same number of + * partitions. + * If this is not the case, you would need to call {@link #repartition(Repartitioned)} for this {@code KStream} + * before doing the join, specifying the same number of partitions via {@link Repartitioned} parameter as the given + * {@link KTable}. + * Furthermore, both input streams need to be co-partitioned on the join key (i.e., use the same partitioner); + * cf. {@link #join(GlobalKTable, KeyValueMapper, ValueJoinerWithKey)}. + * If this requirement is not met, Kafka Streams will automatically repartition the data, i.e., it will create an + * internal repartitioning topic in Kafka and write and re-read the data via this topic before the actual join. + * The repartitioning topic will be named "${applicationId}-<name>-repartition", where "applicationId" is + * user-specified in {@link StreamsConfig} via parameter + * {@link StreamsConfig#APPLICATION_ID_CONFIG APPLICATION_ID_CONFIG}, "<name>" is an internally generated + * name, and "-repartition" is a fixed suffix. + *

                + * You can retrieve all generated internal topic names via {@link Topology#describe()}. + *

                + * Repartitioning can happen only for this {@code KStream} but not for the provided {@link KTable}. + * For this case, all data of the stream will be redistributed through the repartitioning topic by writing all + * records to it, and rereading all records from it, such that the join input {@code KStream} is partitioned + * correctly on its key. + * + * @param table the {@link KTable} to be joined with this stream + * @param joiner a {@link ValueJoinerWithKey} that computes the join result for a pair of matching records + * @param the value type of the table + * @param the value type of the result stream + * @return a {@code KStream} that contains join-records for each key and values computed by the given + * {@link ValueJoinerWithKey}, one output for each input {@code KStream} record + * @see #join(KTable, ValueJoinerWithKey) + * @see #leftJoin(GlobalKTable, KeyValueMapper, ValueJoinerWithKey) + */ + KStream leftJoin(final KTable table, + final ValueJoinerWithKey joiner); + + /** + * Join records of this stream with {@link KTable}'s records using non-windowed left equi join with default + * serializers and deserializers. + * In contrast to {@link #join(KTable, ValueJoiner) inner-join}, all records from this stream will produce an + * output record (cf. below). + * The join is a primary key table lookup join with join attribute {@code stream.key == table.key}. + * "Table lookup join" means, that results are only computed if {@code KStream} records are processed. + * This is done by performing a lookup for matching records in the current (i.e., processing time) internal + * {@link KTable} state. + * In contrast, processing {@link KTable} input records will only update the internal {@link KTable} state and + * will not produce any result records. + *

                + * For each {@code KStream} record whether or not it finds a corresponding record in {@link KTable} the provided + * {@link ValueJoiner} will be called to compute a value (with arbitrary type) for the result record. + * If no {@link KTable} record was found during lookup, a {@code null} value will be provided to {@link ValueJoiner}. + * The key of the result record is the same as for both joining input records. + * If an {@code KStream} input record key or value is {@code null} the record will not be included in the join + * operation and thus no output record will be added to the resulting {@code KStream}. + *

+ * Example:
+ * <table border='1'>
+ * <tr><th>KStream</th><th>KTable</th><th>state</th><th>result</th></tr>
+ * <tr><td>&lt;K1:A&gt;</td><td></td><td></td><td>&lt;K1:ValueJoiner(A,null)&gt;</td></tr>
+ * <tr><td></td><td>&lt;K1:b&gt;</td><td>&lt;K1:b&gt;</td><td></td></tr>
+ * <tr><td>&lt;K1:C&gt;</td><td></td><td>&lt;K1:b&gt;</td><td>&lt;K1:ValueJoiner(C,b)&gt;</td></tr>
+ * </table>
                + * Both input streams (or to be more precise, their underlying source topics) need to have the same number of + * partitions. + * If this is not the case, you would need to call {@link #repartition(Repartitioned)} for this {@code KStream} + * before doing the join, specifying the same number of partitions via {@link Repartitioned} parameter as the given + * {@link KTable}. + * Furthermore, both input streams need to be co-partitioned on the join key (i.e., use the same partitioner); + * cf. {@link #join(GlobalKTable, KeyValueMapper, ValueJoiner)}. + * If this requirement is not met, Kafka Streams will automatically repartition the data, i.e., it will create an + * internal repartitioning topic in Kafka and write and re-read the data via this topic before the actual join. + * The repartitioning topic will be named "${applicationId}-<name>-repartition", where "applicationId" is + * user-specified in {@link StreamsConfig} via parameter + * {@link StreamsConfig#APPLICATION_ID_CONFIG APPLICATION_ID_CONFIG}, "<name>" is an internally generated + * name, and "-repartition" is a fixed suffix. + *

                + * You can retrieve all generated internal topic names via {@link Topology#describe()}. + *

                + * Repartitioning can happen only for this {@code KStream} but not for the provided {@link KTable}. + * For this case, all data of the stream will be redistributed through the repartitioning topic by writing all + * records to it, and rereading all records from it, such that the join input {@code KStream} is partitioned + * correctly on its key. + * + * @param table the {@link KTable} to be joined with this stream + * @param joiner a {@link ValueJoiner} that computes the join result for a pair of matching records + * @param joined a {@link Joined} instance that defines the serdes to + * be used to serialize/deserialize inputs and outputs of the joined streams + * @param the value type of the table + * @param the value type of the result stream + * @return a {@code KStream} that contains join-records for each key and values computed by the given + * {@link ValueJoiner}, one output for each input {@code KStream} record + * @see #join(KTable, ValueJoiner, Joined) + * @see #leftJoin(GlobalKTable, KeyValueMapper, ValueJoiner) + */ + KStream leftJoin(final KTable table, + final ValueJoiner joiner, + final Joined joined); + + /** + * Join records of this stream with {@link KTable}'s records using non-windowed left equi join with default + * serializers and deserializers. + * In contrast to {@link #join(KTable, ValueJoinerWithKey) inner-join}, all records from this stream will produce an + * output record (cf. below). + * The join is a primary key table lookup join with join attribute {@code stream.key == table.key}. + * "Table lookup join" means, that results are only computed if {@code KStream} records are processed. + * This is done by performing a lookup for matching records in the current (i.e., processing time) internal + * {@link KTable} state. + * In contrast, processing {@link KTable} input records will only update the internal {@link KTable} state and + * will not produce any result records. + *

                + * For each {@code KStream} record whether or not it finds a corresponding record in {@link KTable} the provided + * {@link ValueJoinerWithKey} will be called to compute a value (with arbitrary type) for the result record. + * If no {@link KTable} record was found during lookup, a {@code null} value will be provided to {@link ValueJoinerWithKey}. + * The key of the result record is the same as for both joining input records. + * Note that the key is read-only and should not be modified, as this can lead to undefined behaviour. + * If an {@code KStream} input record key or value is {@code null} the record will not be included in the join + * operation and thus no output record will be added to the resulting {@code KStream}. + *

+ * Example:
+ * <table border='1'>
+ * <tr><th>KStream</th><th>KTable</th><th>state</th><th>result</th></tr>
+ * <tr><td>&lt;K1:A&gt;</td><td></td><td></td><td>&lt;K1:ValueJoinerWithKey(K1,A,null)&gt;</td></tr>
+ * <tr><td></td><td>&lt;K1:b&gt;</td><td>&lt;K1:b&gt;</td><td></td></tr>
+ * <tr><td>&lt;K1:C&gt;</td><td></td><td>&lt;K1:b&gt;</td><td>&lt;K1:ValueJoinerWithKey(K1,C,b)&gt;</td></tr>
+ * </table>
                + * Both input streams (or to be more precise, their underlying source topics) need to have the same number of + * partitions. + * If this is not the case, you would need to call {@link #repartition(Repartitioned)} for this {@code KStream} + * before doing the join, specifying the same number of partitions via {@link Repartitioned} parameter as the given + * {@link KTable}. + * Furthermore, both input streams need to be co-partitioned on the join key (i.e., use the same partitioner); + * cf. {@link #join(GlobalKTable, KeyValueMapper, ValueJoinerWithKey)}. + * If this requirement is not met, Kafka Streams will automatically repartition the data, i.e., it will create an + * internal repartitioning topic in Kafka and write and re-read the data via this topic before the actual join. + * The repartitioning topic will be named "${applicationId}-<name>-repartition", where "applicationId" is + * user-specified in {@link StreamsConfig} via parameter + * {@link StreamsConfig#APPLICATION_ID_CONFIG APPLICATION_ID_CONFIG}, "<name>" is an internally generated + * name, and "-repartition" is a fixed suffix. + *

                + * You can retrieve all generated internal topic names via {@link Topology#describe()}. + *

                + * Repartitioning can happen only for this {@code KStream} but not for the provided {@link KTable}. + * For this case, all data of the stream will be redistributed through the repartitioning topic by writing all + * records to it, and rereading all records from it, such that the join input {@code KStream} is partitioned + * correctly on its key. + * + * @param table the {@link KTable} to be joined with this stream + * @param joiner a {@link ValueJoinerWithKey} that computes the join result for a pair of matching records + * @param joined a {@link Joined} instance that defines the serdes to + * be used to serialize/deserialize inputs and outputs of the joined streams + * @param the value type of the table + * @param the value type of the result stream + * @return a {@code KStream} that contains join-records for each key and values computed by the given + * {@link ValueJoinerWithKey}, one output for each input {@code KStream} record + * @see #join(KTable, ValueJoinerWithKey, Joined) + * @see #leftJoin(GlobalKTable, KeyValueMapper, ValueJoinerWithKey) + */ + KStream leftJoin(final KTable table, + final ValueJoinerWithKey joiner, + final Joined joined); + + /** + * Join records of this stream with {@link GlobalKTable}'s records using non-windowed inner equi join. + * The join is a primary key table lookup join with join attribute + * {@code keyValueMapper.map(stream.keyValue) == table.key}. + * "Table lookup join" means, that results are only computed if {@code KStream} records are processed. + * This is done by performing a lookup for matching records in the current internal {@link GlobalKTable} + * state. + * In contrast, processing {@link GlobalKTable} input records will only update the internal {@link GlobalKTable} + * state and will not produce any result records. + *
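// Illustrative sketch (not part of this patch): a stream-table left join with a
// ValueJoinerWithKey, per the leftJoin(KTable, ...) overloads documented above; stream records
// without a table match still produce output with a null table value. The topic names and
// joiner logic are assumptions made for this example only.
import org.apache.kafka.streams.StreamsBuilder;
import org.apache.kafka.streams.Topology;
import org.apache.kafka.streams.kstream.KStream;
import org.apache.kafka.streams.kstream.KTable;

public class StreamTableLeftJoinSketch {

    static Topology build() {
        final StreamsBuilder builder = new StreamsBuilder();
        final KStream<String, String> clicks = builder.stream("clicks");     // assumed topic
        final KTable<String, String> profiles = builder.table("profiles");   // assumed topic

        // The read-only key is passed to the joiner; a missing profile arrives as null.
        clicks.leftJoin(
                profiles,
                (userId, click, profile) ->
                        profile == null ? "unknown:" + click : profile + ":" + click)
            .to("clicks-with-profile");                                      // assumed output topic

        return builder.build();
    }
}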

                + * For each {@code KStream} record that finds a corresponding record in {@link GlobalKTable} the provided + * {@link ValueJoiner} will be called to compute a value (with arbitrary type) for the result record. + * The key of the result record is the same as the key of this {@code KStream}. + * If a {@code KStream} input record key or value is {@code null} the record will not be included in the join + * operation and thus no output record will be added to the resulting {@code KStream}. + * If {@code keyValueMapper} returns {@code null} implying no match exists, no output record will be added to the + * resulting {@code KStream}. + * + * @param globalTable the {@link GlobalKTable} to be joined with this stream + * @param keySelector instance of {@link KeyValueMapper} used to map from the (key, value) of this stream + * to the key of the {@link GlobalKTable} + * @param joiner a {@link ValueJoiner} that computes the join result for a pair of matching records + * @param the key type of {@link GlobalKTable} + * @param the value type of the {@link GlobalKTable} + * @param the value type of the resulting {@code KStream} + * @return a {@code KStream} that contains join-records for each key and values computed by the given + * {@link ValueJoiner}, one output for each input {@code KStream} record + * @see #leftJoin(GlobalKTable, KeyValueMapper, ValueJoiner) + */ + KStream join(final GlobalKTable globalTable, + final KeyValueMapper keySelector, + final ValueJoiner joiner); + + /** + * Join records of this stream with {@link GlobalKTable}'s records using non-windowed inner equi join. + * The join is a primary key table lookup join with join attribute + * {@code keyValueMapper.map(stream.keyValue) == table.key}. + * "Table lookup join" means, that results are only computed if {@code KStream} records are processed. + * This is done by performing a lookup for matching records in the current internal {@link GlobalKTable} + * state. + * In contrast, processing {@link GlobalKTable} input records will only update the internal {@link GlobalKTable} + * state and will not produce any result records. + *

                + * For each {@code KStream} record that finds a corresponding record in {@link GlobalKTable} the provided + * {@link ValueJoinerWithKey} will be called to compute a value (with arbitrary type) for the result record. + * The key of the result record is the same as the key of this {@code KStream}. + * Note that the key is read-only and should not be modified, as this can lead to undefined behaviour. + * If a {@code KStream} input record key or value is {@code null} the record will not be included in the join + * operation and thus no output record will be added to the resulting {@code KStream}. + * If {@code keyValueMapper} returns {@code null} implying no match exists, no output record will be added to the + * resulting {@code KStream}. + * + * @param globalTable the {@link GlobalKTable} to be joined with this stream + * @param keySelector instance of {@link KeyValueMapper} used to map from the (key, value) of this stream + * to the key of the {@link GlobalKTable} + * @param joiner a {@link ValueJoinerWithKey} that computes the join result for a pair of matching records + * @param the key type of {@link GlobalKTable} + * @param the value type of the {@link GlobalKTable} + * @param the value type of the resulting {@code KStream} + * @return a {@code KStream} that contains join-records for each key and values computed by the given + * {@link ValueJoinerWithKey}, one output for each input {@code KStream} record + * @see #leftJoin(GlobalKTable, KeyValueMapper, ValueJoinerWithKey) + */ + KStream join(final GlobalKTable globalTable, + final KeyValueMapper keySelector, + final ValueJoinerWithKey joiner); + + /** + * Join records of this stream with {@link GlobalKTable}'s records using non-windowed inner equi join. + * The join is a primary key table lookup join with join attribute + * {@code keyValueMapper.map(stream.keyValue) == table.key}. + * "Table lookup join" means, that results are only computed if {@code KStream} records are processed. + * This is done by performing a lookup for matching records in the current internal {@link GlobalKTable} + * state. + * In contrast, processing {@link GlobalKTable} input records will only update the internal {@link GlobalKTable} + * state and will not produce any result records. + *

                + * For each {@code KStream} record that finds a corresponding record in {@link GlobalKTable} the provided + * {@link ValueJoiner} will be called to compute a value (with arbitrary type) for the result record. + * The key of the result record is the same as the key of this {@code KStream}. + * If a {@code KStream} input record key or value is {@code null} the record will not be included in the join + * operation and thus no output record will be added to the resulting {@code KStream}. + * If {@code keyValueMapper} returns {@code null} implying no match exists, no output record will be added to the + * resulting {@code KStream}. + * + * @param globalTable the {@link GlobalKTable} to be joined with this stream + * @param keySelector instance of {@link KeyValueMapper} used to map from the (key, value) of this stream + * to the key of the {@link GlobalKTable} + * @param joiner a {@link ValueJoiner} that computes the join result for a pair of matching records + * @param named a {@link Named} config used to name the processor in the topology + * @param the key type of {@link GlobalKTable} + * @param the value type of the {@link GlobalKTable} + * @param the value type of the resulting {@code KStream} + * @return a {@code KStream} that contains join-records for each key and values computed by the given + * {@link ValueJoiner}, one output for each input {@code KStream} record + * @see #leftJoin(GlobalKTable, KeyValueMapper, ValueJoiner) + */ + KStream join(final GlobalKTable globalTable, + final KeyValueMapper keySelector, + final ValueJoiner joiner, + final Named named); + + /** + * Join records of this stream with {@link GlobalKTable}'s records using non-windowed inner equi join. + * The join is a primary key table lookup join with join attribute + * {@code keyValueMapper.map(stream.keyValue) == table.key}. + * "Table lookup join" means, that results are only computed if {@code KStream} records are processed. + * This is done by performing a lookup for matching records in the current internal {@link GlobalKTable} + * state. + * In contrast, processing {@link GlobalKTable} input records will only update the internal {@link GlobalKTable} + * state and will not produce any result records. + *

                + * For each {@code KStream} record that finds a corresponding record in {@link GlobalKTable} the provided + * {@link ValueJoinerWithKey} will be called to compute a value (with arbitrary type) for the result record. + * The key of the result record is the same as the key of this {@code KStream}. + * Note that the key is read-only and should not be modified, as this can lead to undefined behaviour. + * If a {@code KStream} input record key or value is {@code null} the record will not be included in the join + * operation and thus no output record will be added to the resulting {@code KStream}. + * If {@code keyValueMapper} returns {@code null} implying no match exists, no output record will be added to the + * resulting {@code KStream}. + * + * @param globalTable the {@link GlobalKTable} to be joined with this stream + * @param keySelector instance of {@link KeyValueMapper} used to map from the (key, value) of this stream + * to the key of the {@link GlobalKTable} + * @param joiner a {@link ValueJoinerWithKey} that computes the join result for a pair of matching records + * @param named a {@link Named} config used to name the processor in the topology + * @param the key type of {@link GlobalKTable} + * @param the value type of the {@link GlobalKTable} + * @param the value type of the resulting {@code KStream} + * @return a {@code KStream} that contains join-records for each key and values computed by the given + * {@link ValueJoinerWithKey}, one output for each input {@code KStream} record + * @see #leftJoin(GlobalKTable, KeyValueMapper, ValueJoinerWithKey) + */ + KStream join(final GlobalKTable globalTable, + final KeyValueMapper keySelector, + final ValueJoinerWithKey joiner, + final Named named); + + /** + * Join records of this stream with {@link GlobalKTable}'s records using non-windowed left equi join. + * In contrast to {@link #join(GlobalKTable, KeyValueMapper, ValueJoiner) inner-join}, all records from this stream + * will produce an output record (cf. below). + * The join is a primary key table lookup join with join attribute + * {@code keyValueMapper.map(stream.keyValue) == table.key}. + * "Table lookup join" means, that results are only computed if {@code KStream} records are processed. + * This is done by performing a lookup for matching records in the current internal {@link GlobalKTable} + * state. + * In contrast, processing {@link GlobalKTable} input records will only update the internal {@link GlobalKTable} + * state and will not produce any result records. + *
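// Illustrative sketch (not part of this patch): joining a stream against a GlobalKTable via a
// key selector, per the join(GlobalKTable, ...) overloads documented above; no co-partitioning
// or repartitioning is required. The topic names, the selector, and the processor name are
// assumptions made for this example only.
import org.apache.kafka.streams.StreamsBuilder;
import org.apache.kafka.streams.Topology;
import org.apache.kafka.streams.kstream.GlobalKTable;
import org.apache.kafka.streams.kstream.KStream;
import org.apache.kafka.streams.kstream.Named;

public class GlobalTableJoinSketch {

    static Topology build() {
        final StreamsBuilder builder = new StreamsBuilder();
        final KStream<String, String> orders = builder.stream("orders");               // assumed: value holds a productId
        final GlobalKTable<String, String> products = builder.globalTable("products"); // assumed topic

        orders.join(
                products,
                (orderId, productId) -> productId,             // KeyValueMapper: derive the table key
                (order, product) -> order + " -> " + product,  // ValueJoiner for matching records
                Named.as("order-product-join"))                // assumed processor name
            .to("orders-with-product");                        // assumed output topic

        return builder.build();
    }
}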

                + * For each {@code KStream} record whether or not it finds a corresponding record in {@link GlobalKTable} the + * provided {@link ValueJoiner} will be called to compute a value (with arbitrary type) for the result record. + * The key of the result record is the same as this {@code KStream}. + * If a {@code KStream} input record key or value is {@code null} the record will not be included in the join + * operation and thus no output record will be added to the resulting {@code KStream}. + * If {@code keyValueMapper} returns {@code null} implying no match exists, a {@code null} value will be + * provided to {@link ValueJoiner}. + * If no {@link GlobalKTable} record was found during lookup, a {@code null} value will be provided to + * {@link ValueJoiner}. + * + * @param globalTable the {@link GlobalKTable} to be joined with this stream + * @param keySelector instance of {@link KeyValueMapper} used to map from the (key, value) of this stream + * to the key of the {@link GlobalKTable} + * @param valueJoiner a {@link ValueJoiner} that computes the join result for a pair of matching records + * @param the key type of {@link GlobalKTable} + * @param the value type of the {@link GlobalKTable} + * @param the value type of the resulting {@code KStream} + * @return a {@code KStream} that contains join-records for each key and values computed by the given + * {@link ValueJoiner}, one output for each input {@code KStream} record + * @see #join(GlobalKTable, KeyValueMapper, ValueJoiner) + */ + KStream leftJoin(final GlobalKTable globalTable, + final KeyValueMapper keySelector, + final ValueJoiner valueJoiner); + + /** + * Join records of this stream with {@link GlobalKTable}'s records using non-windowed left equi join. + * In contrast to {@link #join(GlobalKTable, KeyValueMapper, ValueJoinerWithKey) inner-join}, all records from this stream + * will produce an output record (cf. below). + * The join is a primary key table lookup join with join attribute + * {@code keyValueMapper.map(stream.keyValue) == table.key}. + * "Table lookup join" means, that results are only computed if {@code KStream} records are processed. + * This is done by performing a lookup for matching records in the current internal {@link GlobalKTable} + * state. + * In contrast, processing {@link GlobalKTable} input records will only update the internal {@link GlobalKTable} + * state and will not produce any result records. + *

+ * For each {@code KStream} record, whether or not it finds a corresponding record in {@link GlobalKTable}, the + * provided {@link ValueJoinerWithKey} will be called to compute a value (with arbitrary type) for the result record. + * The key of the result record is the same as the key of this {@code KStream}. + * Note that the key is read-only and should not be modified, as this can lead to undefined behaviour. + * If a {@code KStream} input record key or value is {@code null} the record will not be included in the join + * operation and thus no output record will be added to the resulting {@code KStream}. + * If {@code keyValueMapper} returns {@code null} implying no match exists, a {@code null} value will be + * provided to {@link ValueJoinerWithKey}. + * If no {@link GlobalKTable} record was found during lookup, a {@code null} value will be provided to + * {@link ValueJoinerWithKey}. + * + * @param globalTable the {@link GlobalKTable} to be joined with this stream + * @param keySelector instance of {@link KeyValueMapper} used to map from the (key, value) of this stream + * to the key of the {@link GlobalKTable} + * @param valueJoiner a {@link ValueJoinerWithKey} that computes the join result for a pair of matching records + * @param the key type of {@link GlobalKTable} + * @param the value type of the {@link GlobalKTable} + * @param the value type of the resulting {@code KStream} + * @return a {@code KStream} that contains join-records for each key and values computed by the given + * {@link ValueJoinerWithKey}, one output for each input {@code KStream} record + * @see #join(GlobalKTable, KeyValueMapper, ValueJoinerWithKey) + */ + KStream leftJoin(final GlobalKTable globalTable, + final KeyValueMapper keySelector, + final ValueJoinerWithKey valueJoiner); + + /** + * Join records of this stream with {@link GlobalKTable}'s records using non-windowed left equi join. + * In contrast to {@link #join(GlobalKTable, KeyValueMapper, ValueJoiner) inner-join}, all records from this stream + * will produce an output record (cf. below). + * The join is a primary key table lookup join with join attribute + * {@code keyValueMapper.map(stream.keyValue) == table.key}. + * "Table lookup join" means that results are only computed if {@code KStream} records are processed. + * This is done by performing a lookup for matching records in the current internal {@link GlobalKTable} + * state. + * In contrast, processing {@link GlobalKTable} input records will only update the internal {@link GlobalKTable} + * state and will not produce any result records. + *
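A corresponding hedged sketch of the left join documented here, reusing the hypothetical orders/products streams from above; the joiner must tolerate a {@code null} product when the table lookup finds no match:

    final KStream<String, String> enrichedOrDefault = orders.leftJoin(
            products,
            (orderId, order) -> order.split(",")[0],   // may yield a key with no matching table record
            (order, product) -> product == null ? order + " -> (no product)" : order + " -> " + product,
            Named.as("order-product-left-join"));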

                + * For each {@code KStream} record whether or not it finds a corresponding record in {@link GlobalKTable} the + * provided {@link ValueJoiner} will be called to compute a value (with arbitrary type) for the result record. + * The key of the result record is the same as this {@code KStream}. + * If a {@code KStream} input record key or value is {@code null} the record will not be included in the join + * operation and thus no output record will be added to the resulting {@code KStream}. + * If {@code keyValueMapper} returns {@code null} implying no match exists, a {@code null} value will be + * provided to {@link ValueJoiner}. + * If no {@link GlobalKTable} record was found during lookup, a {@code null} value will be provided to + * {@link ValueJoiner}. + * + * @param globalTable the {@link GlobalKTable} to be joined with this stream + * @param keySelector instance of {@link KeyValueMapper} used to map from the (key, value) of this stream + * to the key of the {@link GlobalKTable} + * @param valueJoiner a {@link ValueJoiner} that computes the join result for a pair of matching records + * @param named a {@link Named} config used to name the processor in the topology + * @param the key type of {@link GlobalKTable} + * @param the value type of the {@link GlobalKTable} + * @param the value type of the resulting {@code KStream} + * @return a {@code KStream} that contains join-records for each key and values computed by the given + * {@link ValueJoiner}, one output for each input {@code KStream} record + * @see #join(GlobalKTable, KeyValueMapper, ValueJoiner) + */ + KStream leftJoin(final GlobalKTable globalTable, + final KeyValueMapper keySelector, + final ValueJoiner valueJoiner, + final Named named); + + /** + * Join records of this stream with {@link GlobalKTable}'s records using non-windowed left equi join. + * In contrast to {@link #join(GlobalKTable, KeyValueMapper, ValueJoinerWithKey) inner-join}, all records from this stream + * will produce an output record (cf. below). + * The join is a primary key table lookup join with join attribute + * {@code keyValueMapper.map(stream.keyValue) == table.key}. + * "Table lookup join" means, that results are only computed if {@code KStream} records are processed. + * This is done by performing a lookup for matching records in the current internal {@link GlobalKTable} + * state. + * In contrast, processing {@link GlobalKTable} input records will only update the internal {@link GlobalKTable} + * state and will not produce any result records. + *

                + * For each {@code KStream} record whether or not it finds a corresponding record in {@link GlobalKTable} the + * provided {@link ValueJoinerWithKey} will be called to compute a value (with arbitrary type) for the result record. + * The key of the result record is the same as this {@code KStream}. + * If a {@code KStream} input record key or value is {@code null} the record will not be included in the join + * operation and thus no output record will be added to the resulting {@code KStream}. + * If {@code keyValueMapper} returns {@code null} implying no match exists, a {@code null} value will be + * provided to {@link ValueJoinerWithKey}. + * If no {@link GlobalKTable} record was found during lookup, a {@code null} value will be provided to + * {@link ValueJoinerWithKey}. + * + * @param globalTable the {@link GlobalKTable} to be joined with this stream + * @param keySelector instance of {@link KeyValueMapper} used to map from the (key, value) of this stream + * to the key of the {@link GlobalKTable} + * @param valueJoiner a {@link ValueJoinerWithKey} that computes the join result for a pair of matching records + * @param named a {@link Named} config used to name the processor in the topology + * @param the key type of {@link GlobalKTable} + * @param the value type of the {@link GlobalKTable} + * @param the value type of the resulting {@code KStream} + * @return a {@code KStream} that contains join-records for each key and values computed by the given + * {@link ValueJoinerWithKey}, one output for each input {@code KStream} record + * @see #join(GlobalKTable, KeyValueMapper, ValueJoinerWithKey) + */ + KStream leftJoin(final GlobalKTable globalTable, + final KeyValueMapper keySelector, + final ValueJoinerWithKey valueJoiner, + final Named named); + + /** + * Transform each record of the input stream into zero or one record in the output stream (both key and value type + * can be altered arbitrarily). + * A {@link Transformer} (provided by the given {@link TransformerSupplier}) is applied to each input record and + * returns zero or one output record. + * Thus, an input record {@code } can be transformed into an output record {@code }. + * Attaching a state store makes this a stateful record-by-record operation (cf. {@link #map(KeyValueMapper) map()}). + * If you choose not to attach one, this operation is similar to the stateless {@link #map(KeyValueMapper) map()} + * but allows access to the {@code ProcessorContext} and record metadata. + * This is essentially mixing the Processor API into the DSL, and provides all the functionality of the PAPI. + * Furthermore, via {@link org.apache.kafka.streams.processor.Punctuator#punctuate(long) Punctuator#punctuate()}, + * the processing progress can be observed and additional periodic actions can be performed. + *

                + * In order for the transformer to use state stores, the stores must be added to the topology and connected to the + * transformer using at least one of two strategies (though it's not required to connect global state stores; read-only + * access to global state stores is available by default). + *

                + * The first strategy is to manually add the {@link StoreBuilder}s via {@link Topology#addStateStore(StoreBuilder, String...)}, + * and specify the store names via {@code stateStoreNames} so they will be connected to the transformer. + *

                {@code
                +     * // create store
+     * StoreBuilder<KeyValueStore<String, String>> keyValueStoreBuilder =
                +     *         Stores.keyValueStoreBuilder(Stores.persistentKeyValueStore("myTransformState"),
                +     *                 Serdes.String(),
                +     *                 Serdes.String());
                +     * // add store
                +     * builder.addStateStore(keyValueStoreBuilder);
                +     *
                +     * KStream outputStream = inputStream.transform(new TransformerSupplier() {
                +     *     public Transformer get() {
                +     *         return new MyTransformer();
                +     *     }
                +     * }, "myTransformState");
                +     * }
                + * The second strategy is for the given {@link TransformerSupplier} to implement {@link ConnectedStoreProvider#stores()}, + * which provides the {@link StoreBuilder}s to be automatically added to the topology and connected to the transformer. + *
                {@code
                +     * class MyTransformerSupplier implements TransformerSupplier {
                +     *     // supply transformer
                +     *     Transformer get() {
                +     *         return new MyTransformer();
                +     *     }
                +     *
                +     *     // provide store(s) that will be added and connected to the associated transformer
                +     *     // the store name from the builder ("myTransformState") is used to access the store later via the ProcessorContext
+     *     Set<StoreBuilder<?>> stores() {
+     *         StoreBuilder<KeyValueStore<String, String>> keyValueStoreBuilder =
                +     *                   Stores.keyValueStoreBuilder(Stores.persistentKeyValueStore("myTransformState"),
                +     *                   Serdes.String(),
                +     *                   Serdes.String());
                +     *         return Collections.singleton(keyValueStoreBuilder);
                +     *     }
                +     * }
                +     *
                +     * ...
                +     *
                +     * KStream outputStream = inputStream.transform(new MyTransformerSupplier());
                +     * }
                + *

                + * With either strategy, within the {@link Transformer}, the state is obtained via the {@link ProcessorContext}. + * To trigger periodic actions via {@link org.apache.kafka.streams.processor.Punctuator#punctuate(long) punctuate()}, + * a schedule must be registered. + * The {@link Transformer} must return a {@link KeyValue} type in {@link Transformer#transform(Object, Object) + * transform()}. + * The return value of {@link Transformer#transform(Object, Object) Transformer#transform()} may be {@code null}, + * in which case no record is emitted. + *

                {@code
                +     * class MyTransformer implements Transformer {
                +     *     private ProcessorContext context;
                +     *     private StateStore state;
                +     *
                +     *     void init(ProcessorContext context) {
                +     *         this.context = context;
                +     *         this.state = context.getStateStore("myTransformState");
                +     *         // punctuate each second; can access this.state
                +     *         context.schedule(Duration.ofSeconds(1), PunctuationType.WALL_CLOCK_TIME, new Punctuator(..));
                +     *     }
                +     *
                +     *     KeyValue transform(K key, V value) {
                +     *         // can access this.state
                +     *         return new KeyValue(key, value); // can emit a single value via return -- can also be null
                +     *     }
                +     *
                +     *     void close() {
                +     *         // can access this.state
                +     *     }
                +     * }
                +     * }
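To make the skeleton concrete, a hedged sketch of a transformer that keeps a per-key counter in the attached store and emits the running count with each value; it assumes a store named "myTransformState" that was built with Serdes.String() and Serdes.Long() (names and types are illustrative, not part of the API above):

    class CountingTransformer implements Transformer<String, String, KeyValue<String, String>> {
        private KeyValueStore<String, Long> counts;

        @SuppressWarnings("unchecked")
        @Override
        public void init(final ProcessorContext context) {
            // the store must have been added to the topology and connected to this transformer
            this.counts = (KeyValueStore<String, Long>) context.getStateStore("myTransformState");
        }

        @Override
        public KeyValue<String, String> transform(final String key, final String value) {
            final Long previous = counts.get(key);
            final long seen = (previous == null ? 0L : previous) + 1L;
            counts.put(key, seen);
            return KeyValue.pair(key, value + " (seen " + seen + " times)");
        }

        @Override
        public void close() { }
    }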
                + * Even if any upstream operation was key-changing, no auto-repartition is triggered. + * If repartitioning is required, a call to {@link #repartition()} should be performed before {@code transform()}. + *
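For example (a hedged sketch; {@code stream} and the key-changing {@code selectKey()} step are assumptions, and CountingTransformer is the sketch above):

    final KStream<String, String> counted =
            stream.selectKey((key, value) -> value.split(",")[0])  // key-changing operation upstream
                  .repartition()                                   // redistribute by the new key first
                  .transform(CountingTransformer::new, "myTransformState");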

                + * Transforming records might result in an internal data redistribution if a key based operator (like an aggregation + * or join) is applied to the result {@code KStream}. + * (cf. {@link #transformValues(ValueTransformerSupplier, String...) transformValues()} ) + *

                + * Note that it is possible to emit multiple records for each input record by using + * {@link org.apache.kafka.streams.processor.ProcessorContext#forward(Object, Object) context#forward()} in + * {@link Transformer#transform(Object, Object) Transformer#transform()} and + * {@link org.apache.kafka.streams.processor.Punctuator#punctuate(long) Punctuator#punctuate()}. + * Be aware that a mismatch between the types of the emitted records and the type of the stream would only be + * detected at runtime. + * To ensure type-safety at compile-time, + * {@link org.apache.kafka.streams.processor.ProcessorContext#forward(Object, Object) context#forward()} should + * not be used in {@link Transformer#transform(Object, Object) Transformer#transform()} and + * {@link org.apache.kafka.streams.processor.Punctuator#punctuate(long) Punctuator#punctuate()}. + * If in {@link Transformer#transform(Object, Object) Transformer#transform()} multiple records need to be emitted + * for each input record, it is recommended to use {@link #flatTransform(TransformerSupplier, String...) + * flatTransform()}. + * The supplier should always generate a new instance each time {@link TransformerSupplier#get()} gets called. Creating + * a single {@link Transformer} object and returning the same object reference in {@link TransformerSupplier#get()} would be + * a violation of the supplier pattern and leads to runtime exceptions. + * + * @param transformerSupplier an instance of {@link TransformerSupplier} that generates a newly constructed + * {@link Transformer} + * @param stateStoreNames the names of the state stores used by the processor; not required if the supplier + * implements {@link ConnectedStoreProvider#stores()} + * @param the key type of the new stream + * @param the value type of the new stream + * @return a {@code KStream} that contains more or less records with new key and value (possibly of different type) + * @see #map(KeyValueMapper) + * @see #flatTransform(TransformerSupplier, String...) + * @see #transformValues(ValueTransformerSupplier, String...) + * @see #transformValues(ValueTransformerWithKeySupplier, String...) + * @see #process(ProcessorSupplier, String...) + */ + KStream transform(final TransformerSupplier> transformerSupplier, + final String... stateStoreNames); + + /** + * Transform each record of the input stream into zero or one record in the output stream (both key and value type + * can be altered arbitrarily). + * A {@link Transformer} (provided by the given {@link TransformerSupplier}) is applied to each input record and + * returns zero or one output record. + * Thus, an input record {@code } can be transformed into an output record {@code }. + * Attaching a state store makes this a stateful record-by-record operation (cf. {@link #map(KeyValueMapper) map()}). + * If you choose not to attach one, this operation is similar to the stateless {@link #map(KeyValueMapper) map()} + * but allows access to the {@code ProcessorContext} and record metadata. + * This is essentially mixing the Processor API into the DSL, and provides all the functionality of the PAPI. + * Furthermore, via {@link org.apache.kafka.streams.processor.Punctuator#punctuate(long) Punctuator#punctuate()}, + * the processing progress can be observed and additional periodic actions can be performed. + *

                + * In order for the transformer to use state stores, the stores must be added to the topology and connected to the + * transformer using at least one of two strategies (though it's not required to connect global state stores; read-only + * access to global state stores is available by default). + *

                + * The first strategy is to manually add the {@link StoreBuilder}s via {@link Topology#addStateStore(StoreBuilder, String...)}, + * and specify the store names via {@code stateStoreNames} so they will be connected to the transformer. + *

                {@code
                +     * // create store
+     * StoreBuilder<KeyValueStore<String, String>> keyValueStoreBuilder =
                +     *         Stores.keyValueStoreBuilder(Stores.persistentKeyValueStore("myTransformState"),
                +     *                 Serdes.String(),
                +     *                 Serdes.String());
                +     * // add store
                +     * builder.addStateStore(keyValueStoreBuilder);
                +     *
                +     * KStream outputStream = inputStream.transform(new TransformerSupplier() {
                +     *     public Transformer get() {
                +     *         return new MyTransformer();
                +     *     }
                +     * }, "myTransformState");
                +     * }
                + * The second strategy is for the given {@link TransformerSupplier} to implement {@link ConnectedStoreProvider#stores()}, + * which provides the {@link StoreBuilder}s to be automatically added to the topology and connected to the transformer. + *
                {@code
                +     * class MyTransformerSupplier implements TransformerSupplier {
                +     *     // supply transformer
                +     *     Transformer get() {
                +     *         return new MyTransformer();
                +     *     }
                +     *
                +     *     // provide store(s) that will be added and connected to the associated transformer
                +     *     // the store name from the builder ("myTransformState") is used to access the store later via the ProcessorContext
+     *     Set<StoreBuilder<?>> stores() {
+     *         StoreBuilder<KeyValueStore<String, String>> keyValueStoreBuilder =
                +     *                   Stores.keyValueStoreBuilder(Stores.persistentKeyValueStore("myTransformState"),
                +     *                   Serdes.String(),
                +     *                   Serdes.String());
                +     *         return Collections.singleton(keyValueStoreBuilder);
                +     *     }
                +     * }
                +     *
                +     * ...
                +     *
                +     * KStream outputStream = inputStream.transform(new MyTransformerSupplier());
                +     * }
                + *

                + * With either strategy, within the {@link Transformer}, the state is obtained via the {@link ProcessorContext}. + * To trigger periodic actions via {@link org.apache.kafka.streams.processor.Punctuator#punctuate(long) punctuate()}, + * a schedule must be registered. + * The {@link Transformer} must return a {@link KeyValue} type in {@link Transformer#transform(Object, Object) + * transform()}. + * The return value of {@link Transformer#transform(Object, Object) Transformer#transform()} may be {@code null}, + * in which case no record is emitted. + *

                {@code
                +     * class MyTransformer implements Transformer {
                +     *     private ProcessorContext context;
                +     *     private StateStore state;
                +     *
                +     *     void init(ProcessorContext context) {
                +     *         this.context = context;
                +     *         this.state = context.getStateStore("myTransformState");
                +     *         // punctuate each second; can access this.state
                +     *         context.schedule(Duration.ofSeconds(1), PunctuationType.WALL_CLOCK_TIME, new Punctuator(..));
                +     *     }
                +     *
                +     *     KeyValue transform(K key, V value) {
                +     *         // can access this.state
                +     *         return new KeyValue(key, value); // can emit a single value via return -- can also be null
                +     *     }
                +     *
                +     *     void close() {
                +     *         // can access this.state
                +     *     }
                +     * }
                +     * }
                + * Even if any upstream operation was key-changing, no auto-repartition is triggered. + * If repartitioning is required, a call to {@link #repartition()} should be performed before {@code transform()}. + *

                + * Transforming records might result in an internal data redistribution if a key based operator (like an aggregation + * or join) is applied to the result {@code KStream}. + * (cf. {@link #transformValues(ValueTransformerSupplier, String...) transformValues()} ) + *

                + * Note that it is possible to emit multiple records for each input record by using + * {@link org.apache.kafka.streams.processor.ProcessorContext#forward(Object, Object) context#forward()} in + * {@link Transformer#transform(Object, Object) Transformer#transform()} and + * {@link org.apache.kafka.streams.processor.Punctuator#punctuate(long) Punctuator#punctuate()}. + * Be aware that a mismatch between the types of the emitted records and the type of the stream would only be + * detected at runtime. + * To ensure type-safety at compile-time, + * {@link org.apache.kafka.streams.processor.ProcessorContext#forward(Object, Object) context#forward()} should + * not be used in {@link Transformer#transform(Object, Object) Transformer#transform()} and + * {@link org.apache.kafka.streams.processor.Punctuator#punctuate(long) Punctuator#punctuate()}. + * If in {@link Transformer#transform(Object, Object) Transformer#transform()} multiple records need to be emitted + * for each input record, it is recommended to use {@link #flatTransform(TransformerSupplier, String...) + * flatTransform()}. + * The supplier should always generate a new instance each time {@link TransformerSupplier#get()} gets called. Creating + * a single {@link Transformer} object and returning the same object reference in {@link TransformerSupplier#get()} would be + * a violation of the supplier pattern and leads to runtime exceptions. + * + * @param transformerSupplier an instance of {@link TransformerSupplier} that generates a newly constructed + * {@link Transformer} + * @param named a {@link Named} config used to name the processor in the topology + * @param stateStoreNames the names of the state stores used by the processor; not required if the supplier + * implements {@link ConnectedStoreProvider#stores()} + * @param the key type of the new stream + * @param the value type of the new stream + * @return a {@code KStream} that contains more or less records with new key and value (possibly of different type) + * @see #map(KeyValueMapper) + * @see #flatTransform(TransformerSupplier, String...) + * @see #transformValues(ValueTransformerSupplier, String...) + * @see #transformValues(ValueTransformerWithKeySupplier, String...) + * @see #process(ProcessorSupplier, String...) + */ + KStream transform(final TransformerSupplier> transformerSupplier, + final Named named, + final String... stateStoreNames); + + /** + * Transform each record of the input stream into zero or more records in the output stream (both key and value type + * can be altered arbitrarily). + * A {@link Transformer} (provided by the given {@link TransformerSupplier}) is applied to each input record and + * returns zero or more output records. + * Thus, an input record {@code } can be transformed into output records {@code , , ...}. + * Attaching a state store makes this a stateful record-by-record operation (cf. {@link #flatMap(KeyValueMapper) flatMap()}). + * If you choose not to attach one, this operation is similar to the stateless {@link #flatMap(KeyValueMapper) flatMap()} + * but allows access to the {@code ProcessorContext} and record metadata. + * Furthermore, via {@link org.apache.kafka.streams.processor.Punctuator#punctuate(long) Punctuator#punctuate()} + * the processing progress can be observed and additional periodic actions can be performed. + *
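As a quick illustration of the one-to-many shape (a hedged sketch; {@code lines} is an assumed KStream of String keys and String values, no state store is attached, and java.util plus org.apache.kafka.streams.* imports are omitted):

    final KStream<String, String> words = lines.flatTransform(
        () -> new Transformer<String, String, Iterable<KeyValue<String, String>>>() {
            @Override
            public void init(final ProcessorContext context) { }

            @Override
            public Iterable<KeyValue<String, String>> transform(final String key, final String line) {
                final List<KeyValue<String, String>> out = new ArrayList<>();
                for (final String word : line.split("\\s+")) {
                    out.add(KeyValue.pair(key, word));  // zero or more output records per input record
                }
                return out;
            }

            @Override
            public void close() { }
        });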

                + * In order for the transformer to use state stores, the stores must be added to the topology and connected to the + * transformer using at least one of two strategies (though it's not required to connect global state stores; read-only + * access to global state stores is available by default). + *

                + * The first strategy is to manually add the {@link StoreBuilder}s via {@link Topology#addStateStore(StoreBuilder, String...)}, + * and specify the store names via {@code stateStoreNames} so they will be connected to the transformer. + *

                {@code
                +     * // create store
+     * StoreBuilder<KeyValueStore<String, String>> keyValueStoreBuilder =
                +     *         Stores.keyValueStoreBuilder(Stores.persistentKeyValueStore("myTransformState"),
                +     *                 Serdes.String(),
                +     *                 Serdes.String());
                +     * // add store
                +     * builder.addStateStore(keyValueStoreBuilder);
                +     *
+     * KStream outputStream = inputStream.flatTransform(new TransformerSupplier() {
                +     *     public Transformer get() {
                +     *         return new MyTransformer();
                +     *     }
                +     * }, "myTransformState");
                +     * }
                + * The second strategy is for the given {@link TransformerSupplier} to implement {@link ConnectedStoreProvider#stores()}, + * which provides the {@link StoreBuilder}s to be automatically added to the topology and connected to the transformer. + *
                {@code
                +     * class MyTransformerSupplier implements TransformerSupplier {
                +     *     // supply transformer
                +     *     Transformer get() {
                +     *         return new MyTransformer();
                +     *     }
                +     *
                +     *     // provide store(s) that will be added and connected to the associated transformer
                +     *     // the store name from the builder ("myTransformState") is used to access the store later via the ProcessorContext
+     *     Set<StoreBuilder<?>> stores() {
+     *         StoreBuilder<KeyValueStore<String, String>> keyValueStoreBuilder =
                +     *                   Stores.keyValueStoreBuilder(Stores.persistentKeyValueStore("myTransformState"),
                +     *                   Serdes.String(),
                +     *                   Serdes.String());
                +     *         return Collections.singleton(keyValueStoreBuilder);
                +     *     }
                +     * }
                +     *
                +     * ...
                +     *
                +     * KStream outputStream = inputStream.flatTransform(new MyTransformerSupplier());
                +     * }
                + *

                + * With either strategy, within the {@link Transformer}, the state is obtained via the {@link ProcessorContext}. + * To trigger periodic actions via {@link org.apache.kafka.streams.processor.Punctuator#punctuate(long) punctuate()}, + * a schedule must be registered. + * The {@link Transformer} must return an {@link java.lang.Iterable} type (e.g., any {@link java.util.Collection} + * type) in {@link Transformer#transform(Object, Object) transform()}. + * The return value of {@link Transformer#transform(Object, Object) Transformer#transform()} may be {@code null}, + * which is equal to returning an empty {@link java.lang.Iterable Iterable}, i.e., no records are emitted. + *

                {@code
                +     * class MyTransformer implements Transformer {
                +     *     private ProcessorContext context;
                +     *     private StateStore state;
                +     *
                +     *     void init(ProcessorContext context) {
                +     *         this.context = context;
                +     *         this.state = context.getStateStore("myTransformState");
                +     *         // punctuate each second; can access this.state
                +     *         context.schedule(Duration.ofSeconds(1), PunctuationType.WALL_CLOCK_TIME, new Punctuator(..));
                +     *     }
                +     *
                +     *     Iterable transform(K key, V value) {
                +     *         // can access this.state
                +     *         List result = new ArrayList<>();
                +     *         for (int i = 0; i < 3; i++) {
                +     *             result.add(new KeyValue(key, value));
                +     *         }
                +     *         return result; // emits a list of key-value pairs via return
                +     *     }
                +     *
                +     *     void close() {
                +     *         // can access this.state
                +     *     }
                +     * }
                +     * }
                + * Even if any upstream operation was key-changing, no auto-repartition is triggered. + * If repartitioning is required, a call to {@link #repartition()} should be performed before + * {@code flatTransform()}. + *

                + * Transforming records might result in an internal data redistribution if a key based operator (like an aggregation + * or join) is applied to the result {@code KStream}. + * (cf. {@link #transformValues(ValueTransformerSupplier, String...) transformValues()}) + *

                + * Note that it is possible to emit records by using + * {@link org.apache.kafka.streams.processor.ProcessorContext#forward(Object, Object) + * context#forward()} in {@link Transformer#transform(Object, Object) Transformer#transform()} and + * {@link org.apache.kafka.streams.processor.Punctuator#punctuate(long) Punctuator#punctuate()}. + * Be aware that a mismatch between the types of the emitted records and the type of the stream would only be + * detected at runtime. + * To ensure type-safety at compile-time, + * {@link org.apache.kafka.streams.processor.ProcessorContext#forward(Object, Object) context#forward()} should + * not be used in {@link Transformer#transform(Object, Object) Transformer#transform()} and + * {@link org.apache.kafka.streams.processor.Punctuator#punctuate(long) Punctuator#punctuate()}. + * The supplier should always generate a new instance each time {@link TransformerSupplier#get()} gets called. Creating + * a single {@link Transformer} object and returning the same object reference in {@link TransformerSupplier#get()} would be + * a violation of the supplier pattern and leads to runtime exceptions. + * + * @param transformerSupplier an instance of {@link TransformerSupplier} that generates a newly constructed {@link Transformer} + * @param stateStoreNames the names of the state stores used by the processor; not required if the supplier + * implements {@link ConnectedStoreProvider#stores()} + * @param the key type of the new stream + * @param the value type of the new stream + * @return a {@code KStream} that contains more or less records with new key and value (possibly of different type) + * @see #flatMap(KeyValueMapper) + * @see #transform(TransformerSupplier, String...) + * @see #transformValues(ValueTransformerSupplier, String...) + * @see #transformValues(ValueTransformerWithKeySupplier, String...) + * @see #process(ProcessorSupplier, String...) + */ + KStream flatTransform(final TransformerSupplier>> transformerSupplier, + final String... stateStoreNames); + + /** + * Transform each record of the input stream into zero or more records in the output stream (both key and value type + * can be altered arbitrarily). + * A {@link Transformer} (provided by the given {@link TransformerSupplier}) is applied to each input record and + * returns zero or more output records. + * Thus, an input record {@code } can be transformed into output records {@code , , ...}. + * Attaching a state store makes this a stateful record-by-record operation (cf. {@link #flatMap(KeyValueMapper) flatMap()}). + * If you choose not to attach one, this operation is similar to the stateless {@link #flatMap(KeyValueMapper) flatMap()} + * but allows access to the {@code ProcessorContext} and record metadata. + * Furthermore, via {@link org.apache.kafka.streams.processor.Punctuator#punctuate(long) Punctuator#punctuate()} + * the processing progress can be observed and additional periodic actions can be performed. + *

                + * In order for the transformer to use state stores, the stores must be added to the topology and connected to the + * transformer using at least one of two strategies (though it's not required to connect global state stores; read-only + * access to global state stores is available by default). + *

                + * The first strategy is to manually add the {@link StoreBuilder}s via {@link Topology#addStateStore(StoreBuilder, String...)}, + * and specify the store names via {@code stateStoreNames} so they will be connected to the transformer. + *

                {@code
                +     * // create store
+     * StoreBuilder<KeyValueStore<String, String>> keyValueStoreBuilder =
                +     *         Stores.keyValueStoreBuilder(Stores.persistentKeyValueStore("myTransformState"),
                +     *                 Serdes.String(),
                +     *                 Serdes.String());
                +     * // add store
                +     * builder.addStateStore(keyValueStoreBuilder);
                +     *
+     * KStream outputStream = inputStream.flatTransform(new TransformerSupplier() {
                +     *     public Transformer get() {
                +     *         return new MyTransformer();
                +     *     }
                +     * }, "myTransformState");
                +     * }
                + * The second strategy is for the given {@link TransformerSupplier} to implement {@link ConnectedStoreProvider#stores()}, + * which provides the {@link StoreBuilder}s to be automatically added to the topology and connected to the transformer. + *
                {@code
                +     * class MyTransformerSupplier implements TransformerSupplier {
                +     *     // supply transformer
                +     *     Transformer get() {
                +     *         return new MyTransformer();
                +     *     }
                +     *
                +     *     // provide store(s) that will be added and connected to the associated transformer
                +     *     // the store name from the builder ("myTransformState") is used to access the store later via the ProcessorContext
+     *     Set<StoreBuilder<?>> stores() {
+     *         StoreBuilder<KeyValueStore<String, String>> keyValueStoreBuilder =
                +     *                   Stores.keyValueStoreBuilder(Stores.persistentKeyValueStore("myTransformState"),
                +     *                   Serdes.String(),
                +     *                   Serdes.String());
                +     *         return Collections.singleton(keyValueStoreBuilder);
                +     *     }
                +     * }
                +     *
                +     * ...
                +     *
                +     * KStream outputStream = inputStream.flatTransform(new MyTransformerSupplier());
                +     * }
                + *

                + * With either strategy, within the {@link Transformer}, the state is obtained via the {@link ProcessorContext}. + * To trigger periodic actions via {@link org.apache.kafka.streams.processor.Punctuator#punctuate(long) punctuate()}, + * a schedule must be registered. + * The {@link Transformer} must return an {@link java.lang.Iterable} type (e.g., any {@link java.util.Collection} + * type) in {@link Transformer#transform(Object, Object) transform()}. + * The return value of {@link Transformer#transform(Object, Object) Transformer#transform()} may be {@code null}, + * which is equal to returning an empty {@link java.lang.Iterable Iterable}, i.e., no records are emitted. + *

                {@code
                +     * class MyTransformer implements Transformer {
                +     *     private ProcessorContext context;
                +     *     private StateStore state;
                +     *
                +     *     void init(ProcessorContext context) {
                +     *         this.context = context;
                +     *         this.state = context.getStateStore("myTransformState");
                +     *         // punctuate each second; can access this.state
                +     *         context.schedule(Duration.ofSeconds(1), PunctuationType.WALL_CLOCK_TIME, new Punctuator(..));
                +     *     }
                +     *
                +     *     Iterable transform(K key, V value) {
                +     *         // can access this.state
                +     *         List result = new ArrayList<>();
                +     *         for (int i = 0; i < 3; i++) {
                +     *             result.add(new KeyValue(key, value));
                +     *         }
                +     *         return result; // emits a list of key-value pairs via return
                +     *     }
                +     *
                +     *     void close() {
                +     *         // can access this.state
                +     *     }
                +     * }
                +     * }
                + * Even if any upstream operation was key-changing, no auto-repartition is triggered. + * If repartitioning is required, a call to {@link #repartition()} should be performed before + * {@code flatTransform()}. + *

                + * Transforming records might result in an internal data redistribution if a key based operator (like an aggregation + * or join) is applied to the result {@code KStream}. + * (cf. {@link #transformValues(ValueTransformerSupplier, String...) transformValues()}) + *

                + * Note that it is possible to emit records by using + * {@link org.apache.kafka.streams.processor.ProcessorContext#forward(Object, Object) + * context#forward()} in {@link Transformer#transform(Object, Object) Transformer#transform()} and + * {@link org.apache.kafka.streams.processor.Punctuator#punctuate(long) Punctuator#punctuate()}. + * Be aware that a mismatch between the types of the emitted records and the type of the stream would only be + * detected at runtime. + * To ensure type-safety at compile-time, + * {@link org.apache.kafka.streams.processor.ProcessorContext#forward(Object, Object) context#forward()} should + * not be used in {@link Transformer#transform(Object, Object) Transformer#transform()} and + * {@link org.apache.kafka.streams.processor.Punctuator#punctuate(long) Punctuator#punctuate()}. + * The supplier should always generate a new instance each time {@link TransformerSupplier#get()} gets called. Creating + * a single {@link Transformer} object and returning the same object reference in {@link TransformerSupplier#get()} would be + * a violation of the supplier pattern and leads to runtime exceptions. + * + * @param transformerSupplier an instance of {@link TransformerSupplier} that generates a newly constructed {@link Transformer} + * @param named a {@link Named} config used to name the processor in the topology + * @param stateStoreNames the names of the state stores used by the processor; not required if the supplier + * implements {@link ConnectedStoreProvider#stores()} + * @param the key type of the new stream + * @param the value type of the new stream + * @return a {@code KStream} that contains more or less records with new key and value (possibly of different type) + * @see #flatMap(KeyValueMapper) + * @see #transform(TransformerSupplier, String...) + * @see #transformValues(ValueTransformerSupplier, String...) + * @see #transformValues(ValueTransformerWithKeySupplier, String...) + * @see #process(ProcessorSupplier, String...) + */ + KStream flatTransform(final TransformerSupplier>> transformerSupplier, + final Named named, + final String... stateStoreNames); + + /** + * Transform the value of each input record into a new value (with possibly a new type) of the output record. + * A {@link ValueTransformer} (provided by the given {@link ValueTransformerSupplier}) is applied to each input + * record value and computes a new value for it. + * Thus, an input record {@code } can be transformed into an output record {@code }. + * Attaching a state store makes this a stateful record-by-record operation (cf. {@link #mapValues(ValueMapper) mapValues()}). + * If you choose not to attach one, this operation is similar to the stateless {@link #mapValues(ValueMapper) mapValues()} + * but allows access to the {@code ProcessorContext} and record metadata. + * Furthermore, via {@link org.apache.kafka.streams.processor.Punctuator#punctuate(long)} the processing progress + * can be observed and additional periodic actions can be performed. + *
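A hedged usage sketch (hypothetical {@code inputStream}; MyValueTransformer and the store name refer to the examples that follow). Because the key is untouched, the result can be grouped and aggregated without an internal repartition:

    final KTable<String, Long> countsByKey =
            inputStream.transformValues(() -> new MyValueTransformer(), "myValueTransformState")
                       .groupByKey()
                       .count();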

                + * In order for the transformer to use state stores, the stores must be added to the topology and connected to the + * transformer using at least one of two strategies (though it's not required to connect global state stores; read-only + * access to global state stores is available by default). + *

                + * The first strategy is to manually add the {@link StoreBuilder}s via {@link Topology#addStateStore(StoreBuilder, String...)}, + * and specify the store names via {@code stateStoreNames} so they will be connected to the transformer. + *

                {@code
                +     * // create store
+     * StoreBuilder<KeyValueStore<String, String>> keyValueStoreBuilder =
                +     *         Stores.keyValueStoreBuilder(Stores.persistentKeyValueStore("myValueTransformState"),
                +     *                 Serdes.String(),
                +     *                 Serdes.String());
                +     * // add store
                +     * builder.addStateStore(keyValueStoreBuilder);
                +     *
                +     * KStream outputStream = inputStream.transformValues(new ValueTransformerSupplier() {
                +     *     public ValueTransformer get() {
                +     *         return new MyValueTransformer();
                +     *     }
                +     * }, "myValueTransformState");
                +     * }
                + * The second strategy is for the given {@link ValueTransformerSupplier} to implement {@link ConnectedStoreProvider#stores()}, + * which provides the {@link StoreBuilder}s to be automatically added to the topology and connected to the transformer. + *
                {@code
                +     * class MyValueTransformerSupplier implements ValueTransformerSupplier {
                +     *     // supply transformer
                +     *     ValueTransformer get() {
                +     *         return new MyValueTransformer();
                +     *     }
                +     *
                +     *     // provide store(s) that will be added and connected to the associated transformer
                +     *     // the store name from the builder ("myValueTransformState") is used to access the store later via the ProcessorContext
+     *     Set<StoreBuilder<?>> stores() {
+     *         StoreBuilder<KeyValueStore<String, String>> keyValueStoreBuilder =
                +     *                   Stores.keyValueStoreBuilder(Stores.persistentKeyValueStore("myValueTransformState"),
                +     *                   Serdes.String(),
                +     *                   Serdes.String());
                +     *         return Collections.singleton(keyValueStoreBuilder);
                +     *     }
                +     * }
                +     *
                +     * ...
                +     *
                +     * KStream outputStream = inputStream.transformValues(new MyValueTransformerSupplier());
                +     * }
                + *

                + * With either strategy, within the {@link ValueTransformer}, the state is obtained via the {@link ProcessorContext}. + * To trigger periodic actions via {@link org.apache.kafka.streams.processor.Punctuator#punctuate(long) punctuate()}, + * a schedule must be registered. + * The {@link ValueTransformer} must return the new value in {@link ValueTransformer#transform(Object) transform()}. + * In contrast to {@link #transform(TransformerSupplier, String...) transform()}, no additional {@link KeyValue} + * pairs can be emitted via + * {@link org.apache.kafka.streams.processor.ProcessorContext#forward(Object, Object) ProcessorContext.forward()}. + * A {@link org.apache.kafka.streams.errors.StreamsException} is thrown if the {@link ValueTransformer} tries to + * emit a {@link KeyValue} pair. + *

                {@code
                +     * class MyValueTransformer implements ValueTransformer {
                +     *     private StateStore state;
                +     *
                +     *     void init(ProcessorContext context) {
                +     *         this.state = context.getStateStore("myValueTransformState");
                +     *         // punctuate each second, can access this.state
                +     *         context.schedule(Duration.ofSeconds(1), PunctuationType.WALL_CLOCK_TIME, new Punctuator(..));
                +     *     }
                +     *
                +     *     NewValueType transform(V value) {
                +     *         // can access this.state
                +     *         return new NewValueType(); // or null
                +     *     }
                +     *
                +     *     void close() {
                +     *         // can access this.state
                +     *     }
                +     * }
                +     * }
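A hedged, concrete variant of the skeleton above: a value transformer that counts how often each value has been seen, assuming a String/Long store named "myValueTransformState" (names and types are illustrative):

    class RunningCountTransformer implements ValueTransformer<String, String> {
        private KeyValueStore<String, Long> state;

        @SuppressWarnings("unchecked")
        @Override
        public void init(final ProcessorContext context) {
            this.state = (KeyValueStore<String, Long>) context.getStateStore("myValueTransformState");
        }

        @Override
        public String transform(final String value) {
            final Long previous = state.get(value);
            final long seen = (previous == null ? 0L : previous) + 1L;
            state.put(value, seen);
            return value + " (#" + seen + ")";
        }

        @Override
        public void close() { }
    }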
                + * Even if any upstream operation was key-changing, no auto-repartition is triggered. + * If repartitioning is required, a call to {@link #repartition()} should be performed before + * {@code transformValues()}. + *

                + * Setting a new value preserves data co-location with respect to the key. + * Thus, no internal data redistribution is required if a key based operator (like an aggregation or join) + * is applied to the result {@code KStream}. (cf. {@link #transform(TransformerSupplier, String...)}) + * + * @param valueTransformerSupplier an instance of {@link ValueTransformerSupplier} that generates a newly constructed {@link ValueTransformer} + * The supplier should always generate a new instance. Creating a single {@link ValueTransformer} object + * and returning the same object reference in {@link ValueTransformer} is a + * violation of the supplier pattern and leads to runtime exceptions. + * @param stateStoreNames the names of the state stores used by the processor; not required if the supplier + * implements {@link ConnectedStoreProvider#stores()} + * @param the value type of the result stream + * @return a {@code KStream} that contains records with unmodified key and new values (possibly of different type) + * @see #mapValues(ValueMapper) + * @see #mapValues(ValueMapperWithKey) + * @see #transform(TransformerSupplier, String...) + */ + KStream transformValues(final ValueTransformerSupplier valueTransformerSupplier, + final String... stateStoreNames); + /** + * Transform the value of each input record into a new value (with possibly a new type) of the output record. + * A {@link ValueTransformer} (provided by the given {@link ValueTransformerSupplier}) is applied to each input + * record value and computes a new value for it. + * Thus, an input record {@code } can be transformed into an output record {@code }. + * Attaching a state store makes this a stateful record-by-record operation (cf. {@link #mapValues(ValueMapper) mapValues()}). + * If you choose not to attach one, this operation is similar to the stateless {@link #mapValues(ValueMapper) mapValues()} + * but allows access to the {@code ProcessorContext} and record metadata. + * This is essentially mixing the Processor API into the DSL, and provides all the functionality of the PAPI. + * Furthermore, via {@link org.apache.kafka.streams.processor.Punctuator#punctuate(long)} the processing progress + * can be observed and additional periodic actions can be performed. + *

                + * In order for the transformer to use state stores, the stores must be added to the topology and connected to the + * transformer using at least one of two strategies (though it's not required to connect global state stores; read-only + * access to global state stores is available by default). + *

                + * The first strategy is to manually add the {@link StoreBuilder}s via {@link Topology#addStateStore(StoreBuilder, String...)}, + * and specify the store names via {@code stateStoreNames} so they will be connected to the transformer. + *

                {@code
                +     * // create store
+     * StoreBuilder<KeyValueStore<String, String>> keyValueStoreBuilder =
                +     *         Stores.keyValueStoreBuilder(Stores.persistentKeyValueStore("myValueTransformState"),
                +     *                 Serdes.String(),
                +     *                 Serdes.String());
                +     * // add store
                +     * builder.addStateStore(keyValueStoreBuilder);
                +     *
                +     * KStream outputStream = inputStream.transformValues(new ValueTransformerSupplier() {
                +     *     public ValueTransformer get() {
                +     *         return new MyValueTransformer();
                +     *     }
                +     * }, "myValueTransformState");
                +     * }
                + * The second strategy is for the given {@link ValueTransformerSupplier} to implement {@link ConnectedStoreProvider#stores()}, + * which provides the {@link StoreBuilder}s to be automatically added to the topology and connected to the transformer. + *
                {@code
                +     * class MyValueTransformerSupplier implements ValueTransformerSupplier {
                +     *     // supply transformer
                +     *     ValueTransformer get() {
                +     *         return new MyValueTransformer();
                +     *     }
                +     *
                +     *     // provide store(s) that will be added and connected to the associated transformer
                +     *     // the store name from the builder ("myValueTransformState") is used to access the store later via the ProcessorContext
+     *     Set<StoreBuilder<?>> stores() {
+     *         StoreBuilder<KeyValueStore<String, String>> keyValueStoreBuilder =
                +     *                   Stores.keyValueStoreBuilder(Stores.persistentKeyValueStore("myValueTransformState"),
                +     *                   Serdes.String(),
                +     *                   Serdes.String());
                +     *         return Collections.singleton(keyValueStoreBuilder);
                +     *     }
                +     * }
                +     *
                +     * ...
                +     *
                +     * KStream outputStream = inputStream.transformValues(new MyValueTransformerSupplier());
                +     * }
                + *

                + * With either strategy, within the {@link ValueTransformer}, the state is obtained via the {@link ProcessorContext}. + * To trigger periodic actions via {@link org.apache.kafka.streams.processor.Punctuator#punctuate(long) punctuate()}, + * a schedule must be registered. + * The {@link ValueTransformer} must return the new value in {@link ValueTransformer#transform(Object) transform()}. + * In contrast to {@link #transform(TransformerSupplier, String...) transform()}, no additional {@link KeyValue} + * pairs can be emitted via + * {@link org.apache.kafka.streams.processor.ProcessorContext#forward(Object, Object) ProcessorContext.forward()}. + * A {@link org.apache.kafka.streams.errors.StreamsException} is thrown if the {@link ValueTransformer} tries to + * emit a {@link KeyValue} pair. + *

                {@code
                +     * class MyValueTransformer implements ValueTransformer {
                +     *     private StateStore state;
                +     *
                +     *     void init(ProcessorContext context) {
                +     *         this.state = context.getStateStore("myValueTransformState");
                +     *         // punctuate each second, can access this.state
                +     *         context.schedule(Duration.ofSeconds(1), PunctuationType.WALL_CLOCK_TIME, new Punctuator(..));
                +     *     }
                +     *
                +     *     NewValueType transform(V value) {
                +     *         // can access this.state
                +     *         return new NewValueType(); // or null
                +     *     }
                +     *
                +     *     void close() {
                +     *         // can access this.state
                +     *     }
                +     * }
                +     * }
                + * Even if any upstream operation was key-changing, no auto-repartition is triggered. + * If repartitioning is required, a call to {@link #repartition()} should be performed before + * {@code transformValues()}. + *

                + * Setting a new value preserves data co-location with respect to the key. + * Thus, no internal data redistribution is required if a key based operator (like an aggregation or join) + * is applied to the result {@code KStream}. (cf. {@link #transform(TransformerSupplier, String...)}) + * + * @param valueTransformerSupplier an instance of {@link ValueTransformerSupplier} that generates a newly constructed {@link ValueTransformer} + * The supplier should always generate a new instance. Creating a single {@link ValueTransformer} object + * and returning the same object reference in {@link ValueTransformer} is a + * violation of the supplier pattern and leads to runtime exceptions. + * @param named a {@link Named} config used to name the processor in the topology + * @param stateStoreNames the names of the state stores used by the processor; not required if the supplier + * implements {@link ConnectedStoreProvider#stores()} + * @param the value type of the result stream + * @return a {@code KStream} that contains records with unmodified key and new values (possibly of different type) + * @see #mapValues(ValueMapper) + * @see #mapValues(ValueMapperWithKey) + * @see #transform(TransformerSupplier, String...) + */ + KStream transformValues(final ValueTransformerSupplier valueTransformerSupplier, + final Named named, + final String... stateStoreNames); + + /** + * Transform the value of each input record into a new value (with possibly a new type) of the output record. + * A {@link ValueTransformerWithKey} (provided by the given {@link ValueTransformerWithKeySupplier}) is applied to + * each input record value and computes a new value for it. + * Thus, an input record {@code } can be transformed into an output record {@code }. + * Attaching a state store makes this a stateful record-by-record operation (cf. {@link #mapValues(ValueMapperWithKey) mapValues()}). + * If you choose not to attach one, this operation is similar to the stateless {@link #mapValues(ValueMapperWithKey) mapValues()} + * but allows access to the {@code ProcessorContext} and record metadata. + * This is essentially mixing the Processor API into the DSL, and provides all the functionality of the PAPI. + * Furthermore, via {@link org.apache.kafka.streams.processor.Punctuator#punctuate(long)} the processing progress + * can be observed and additional periodic actions can be performed. + *
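For contrast with the key-less variant, a hedged sketch of a key-aware value transformer that tags each value with its (read-only) key; the types are assumptions and no state store is used:

    class KeyTaggingTransformer implements ValueTransformerWithKey<String, String, String> {
        @Override
        public void init(final ProcessorContext context) { }

        @Override
        public String transform(final String readOnlyKey, final String value) {
            return readOnlyKey + ":" + value;  // the key itself must not be modified
        }

        @Override
        public void close() { }
    }

    final ValueTransformerWithKeySupplier<String, String, String> tagSupplier = KeyTaggingTransformer::new;
    final KStream<String, String> tagged = inputStream.transformValues(tagSupplier);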

                + * In order for the transformer to use state stores, the stores must be added to the topology and connected to the + * transformer using at least one of two strategies (though it's not required to connect global state stores; read-only + * access to global state stores is available by default). + *

                + * The first strategy is to manually add the {@link StoreBuilder}s via {@link Topology#addStateStore(StoreBuilder, String...)}, + * and specify the store names via {@code stateStoreNames} so they will be connected to the transformer. + *

                {@code
                +     * // create store
                +     * StoreBuilder> keyValueStoreBuilder =
                +     *         Stores.keyValueStoreBuilder(Stores.persistentKeyValueStore("myValueTransformState"),
                +     *                 Serdes.String(),
                +     *                 Serdes.String());
                +     * // add store
                +     * builder.addStateStore(keyValueStoreBuilder);
                +     *
                +     * KStream outputStream = inputStream.transformValues(new ValueTransformerWithKeySupplier() {
+     *     public ValueTransformerWithKey get() {
+     *         return new MyValueTransformerWithKey();
                +     *     }
                +     * }, "myValueTransformState");
                +     * }
+     * The second strategy is for the given {@link ValueTransformerWithKeySupplier} to implement {@link ConnectedStoreProvider#stores()},
+     * which provides the {@link StoreBuilder}s to be automatically added to the topology and connected to the transformer.
+     *
                {@code
                +     * class MyValueTransformerWithKeySupplier implements ValueTransformerWithKeySupplier {
                +     *     // supply transformer
                +     *     ValueTransformerWithKey get() {
                +     *         return new MyValueTransformerWithKey();
                +     *     }
                +     *
                +     *     // provide store(s) that will be added and connected to the associated transformer
                +     *     // the store name from the builder ("myValueTransformState") is used to access the store later via the ProcessorContext
                +     *     Set stores() {
                +     *         StoreBuilder> keyValueStoreBuilder =
                +     *                   Stores.keyValueStoreBuilder(Stores.persistentKeyValueStore("myValueTransformState"),
                +     *                   Serdes.String(),
                +     *                   Serdes.String());
                +     *         return Collections.singleton(keyValueStoreBuilder);
                +     *     }
                +     * }
                +     *
                +     * ...
                +     *
                +     * KStream outputStream = inputStream.transformValues(new MyValueTransformerWithKeySupplier());
                +     * }
+     *
+     * With either strategy, within the {@link ValueTransformerWithKey}, the state is obtained via the {@link ProcessorContext}.
+     * To trigger periodic actions via {@link org.apache.kafka.streams.processor.Punctuator#punctuate(long) punctuate()},
+     * a schedule must be registered.
+     * The {@link ValueTransformerWithKey} must return the new value in
+     * {@link ValueTransformerWithKey#transform(Object, Object) transform()}.
+     * In contrast to {@link #transform(TransformerSupplier, String...) transform()} and
+     * {@link #flatTransform(TransformerSupplier, String...) flatTransform()}, no additional {@link KeyValue} pairs
+     * can be emitted via
+     * {@link org.apache.kafka.streams.processor.ProcessorContext#forward(Object, Object) ProcessorContext.forward()}.
+     * A {@link org.apache.kafka.streams.errors.StreamsException} is thrown if the {@link ValueTransformerWithKey} tries
+     * to emit a {@link KeyValue} pair.
+     *

                {@code
                +     * class MyValueTransformerWithKey implements ValueTransformerWithKey {
                +     *     private StateStore state;
                +     *
                +     *     void init(ProcessorContext context) {
                +     *         this.state = context.getStateStore("myValueTransformState");
                +     *         // punctuate each second, can access this.state
                +     *         context.schedule(Duration.ofSeconds(1), PunctuationType.WALL_CLOCK_TIME, new Punctuator(..));
                +     *     }
                +     *
                +     *     NewValueType transform(K readOnlyKey, V value) {
                +     *         // can access this.state and use read-only key
                +     *         return new NewValueType(readOnlyKey); // or null
                +     *     }
                +     *
                +     *     void close() {
                +     *         // can access this.state
                +     *     }
                +     * }
                +     * }
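// --- Editorial sketch (not part of the diff above) ---------------------------------------
// The generic parameters were lost from the Javadoc examples during extraction, so this is a
// fully typed ValueTransformerWithKey, assuming String keys/values and an Integer result.
// The key is only read, never modified, which is what keeps co-partitioning intact.
import org.apache.kafka.streams.kstream.ValueTransformerWithKey;
import org.apache.kafka.streams.processor.ProcessorContext;

class ValueLengthTransformer implements ValueTransformerWithKey<String, String, Integer> {
    @Override
    public void init(final ProcessorContext context) {
        // a state store could be looked up here via context.getStateStore(...)
    }

    @Override
    public Integer transform(final String readOnlyKey, final String value) {
        return value == null ? null : value.length();   // new value, key left untouched
    }

    @Override
    public void close() { }
}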
+     * Even if any upstream operation was key-changing, no auto-repartition is triggered.
+     * If repartitioning is required, a call to {@link #repartition()} should be performed before
+     * {@code transformValues()}.
+     *
+     * Note that the key is read-only and should not be modified, as this can lead to corrupt partitioning.
+     * So, setting a new value preserves data co-location with respect to the key.
+     * Thus, no internal data redistribution is required if a key based operator (like an aggregation or join)
+     * is applied to the result {@code KStream}. (cf. {@link #transform(TransformerSupplier, String...)})
+     *
+     * @param valueTransformerSupplier an instance of {@link ValueTransformerWithKeySupplier} that generates a newly constructed {@link ValueTransformerWithKey}
+     *        The supplier should always generate a new instance. Creating a single {@link ValueTransformerWithKey} object
+     *        and returning the same object reference in {@link ValueTransformerWithKey} is a
+     *        violation of the supplier pattern and leads to runtime exceptions.
+     * @param stateStoreNames the names of the state stores used by the processor; not required if the supplier
+     *        implements {@link ConnectedStoreProvider#stores()}
+     * @param <VR> the value type of the result stream
+     * @return a {@code KStream} that contains records with unmodified key and new values (possibly of different type)
+     * @see #mapValues(ValueMapper)
+     * @see #mapValues(ValueMapperWithKey)
+     * @see #transform(TransformerSupplier, String...)
+     */
+    <VR> KStream<K, VR> transformValues(final ValueTransformerWithKeySupplier<? super K, ? super V, ? extends VR> valueTransformerSupplier,
+                                        final String... stateStoreNames);
+
+    /**
+     * Transform the value of each input record into a new value (with possibly a new type) of the output record.
+     * A {@link ValueTransformerWithKey} (provided by the given {@link ValueTransformerWithKeySupplier}) is applied to
+     * each input record value and computes a new value for it.
+     * Thus, an input record {@code } can be transformed into an output record {@code }.
+     * Attaching a state store makes this a stateful record-by-record operation (cf. {@link #mapValues(ValueMapperWithKey) mapValues()}).
+     * If you choose not to attach one, this operation is similar to the stateless {@link #mapValues(ValueMapperWithKey) mapValues()}
+     * but allows access to the {@code ProcessorContext} and record metadata.
+     * This is essentially mixing the Processor API into the DSL, and provides all the functionality of the PAPI.
+     * Furthermore, via {@link org.apache.kafka.streams.processor.Punctuator#punctuate(long)} the processing progress
+     * can be observed and additional periodic actions can be performed.
+     *
+     * In order for the transformer to use state stores, the stores must be added to the topology and connected to the
+     * transformer using at least one of two strategies (though it's not required to connect global state stores; read-only
+     * access to global state stores is available by default).
+     *
+     * The first strategy is to manually add the {@link StoreBuilder}s via {@link Topology#addStateStore(StoreBuilder, String...)},
+     * and specify the store names via {@code stateStoreNames} so they will be connected to the transformer.
+     *

                {@code
                +     * // create store
                +     * StoreBuilder> keyValueStoreBuilder =
                +     *         Stores.keyValueStoreBuilder(Stores.persistentKeyValueStore("myValueTransformState"),
                +     *                 Serdes.String(),
                +     *                 Serdes.String());
                +     * // add store
                +     * builder.addStateStore(keyValueStoreBuilder);
                +     *
                +     * KStream outputStream = inputStream.transformValues(new ValueTransformerWithKeySupplier() {
                +     *     public ValueTransformerWithKey get() {
                +     *         return new MyValueTransformerWithKey();
                +     *     }
                +     * }, "myValueTransformState");
                +     * }
+     * The second strategy is for the given {@link ValueTransformerWithKeySupplier} to implement {@link ConnectedStoreProvider#stores()},
+     * which provides the {@link StoreBuilder}s to be automatically added to the topology and connected to the transformer.
+     *
                {@code
                +     * class MyValueTransformerWithKeySupplier implements ValueTransformerWithKeySupplier {
                +     *     // supply transformer
                +     *     ValueTransformerWithKey get() {
                +     *         return new MyValueTransformerWithKey();
                +     *     }
                +     *
                +     *     // provide store(s) that will be added and connected to the associated transformer
                +     *     // the store name from the builder ("myValueTransformState") is used to access the store later via the ProcessorContext
                +     *     Set stores() {
                +     *         StoreBuilder> keyValueStoreBuilder =
                +     *                   Stores.keyValueStoreBuilder(Stores.persistentKeyValueStore("myValueTransformState"),
                +     *                   Serdes.String(),
                +     *                   Serdes.String());
                +     *         return Collections.singleton(keyValueStoreBuilder);
                +     *     }
                +     * }
                +     *
                +     * ...
                +     *
                +     * KStream outputStream = inputStream.transformValues(new MyValueTransformerWithKeySupplier());
                +     * }
+     *
+     * With either strategy, within the {@link ValueTransformerWithKey}, the state is obtained via the {@link ProcessorContext}.
+     * To trigger periodic actions via {@link org.apache.kafka.streams.processor.Punctuator#punctuate(long) punctuate()},
+     * a schedule must be registered.
+     * The {@link ValueTransformerWithKey} must return the new value in
+     * {@link ValueTransformerWithKey#transform(Object, Object) transform()}.
+     * In contrast to {@link #transform(TransformerSupplier, String...) transform()} and
+     * {@link #flatTransform(TransformerSupplier, String...) flatTransform()}, no additional {@link KeyValue} pairs
+     * can be emitted via
+     * {@link org.apache.kafka.streams.processor.ProcessorContext#forward(Object, Object) ProcessorContext.forward()}.
+     * A {@link org.apache.kafka.streams.errors.StreamsException} is thrown if the {@link ValueTransformerWithKey} tries
+     * to emit a {@link KeyValue} pair.
+     *

                {@code
                +     * class MyValueTransformerWithKey implements ValueTransformerWithKey {
                +     *     private StateStore state;
                +     *
                +     *     void init(ProcessorContext context) {
                +     *         this.state = context.getStateStore("myValueTransformState");
                +     *         // punctuate each second, can access this.state
                +     *         context.schedule(Duration.ofSeconds(1), PunctuationType.WALL_CLOCK_TIME, new Punctuator(..));
                +     *     }
                +     *
                +     *     NewValueType transform(K readOnlyKey, V value) {
                +     *         // can access this.state and use read-only key
                +     *         return new NewValueType(readOnlyKey); // or null
                +     *     }
                +     *
                +     *     void close() {
                +     *         // can access this.state
                +     *     }
                +     * }
                +     * }
+     * Even if any upstream operation was key-changing, no auto-repartition is triggered.
+     * If repartitioning is required, a call to {@link #repartition()} should be performed before
+     * {@code transformValues()}.
+     *
+     * Note that the key is read-only and should not be modified, as this can lead to corrupt partitioning.
+     * So, setting a new value preserves data co-location with respect to the key.
+     * Thus, no internal data redistribution is required if a key based operator (like an aggregation or join)
+     * is applied to the result {@code KStream}. (cf. {@link #transform(TransformerSupplier, String...)})
+     *
+     * @param valueTransformerSupplier an instance of {@link ValueTransformerWithKeySupplier} that generates a newly constructed {@link ValueTransformerWithKey}
+     *        The supplier should always generate a new instance. Creating a single {@link ValueTransformerWithKey} object
+     *        and returning the same object reference in {@link ValueTransformerWithKey} is a
+     *        violation of the supplier pattern and leads to runtime exceptions.
+     * @param named a {@link Named} config used to name the processor in the topology
+     * @param stateStoreNames the names of the state stores used by the processor; not required if the supplier
+     *        implements {@link ConnectedStoreProvider#stores()}
+     * @param <VR> the value type of the result stream
+     * @return a {@code KStream} that contains records with unmodified key and new values (possibly of different type)
+     * @see #mapValues(ValueMapper)
+     * @see #mapValues(ValueMapperWithKey)
+     * @see #transform(TransformerSupplier, String...)
+     */
+    <VR> KStream<K, VR> transformValues(final ValueTransformerWithKeySupplier<? super K, ? super V, ? extends VR> valueTransformerSupplier,
+                                        final Named named,
+                                        final String... stateStoreNames);
+    /**
+     * Transform the value of each input record into zero or more new values (with possibly a new
+     * type) and emit for each new value a record with the same key of the input record and the value.
+     * A {@link ValueTransformer} (provided by the given {@link ValueTransformerSupplier}) is applied to each input
+     * record value and computes zero or more new values.
+     * Thus, an input record {@code } can be transformed into output records {@code , , ...}.
+     * Attaching a state store makes this a stateful record-by-record operation (cf. {@link #mapValues(ValueMapper) mapValues()}).
+     * If you choose not to attach one, this operation is similar to the stateless {@link #mapValues(ValueMapper) mapValues()}
+     * but allows access to the {@code ProcessorContext} and record metadata.
+     * This is essentially mixing the Processor API into the DSL, and provides all the functionality of the PAPI.
+     * Furthermore, via {@link org.apache.kafka.streams.processor.Punctuator#punctuate(long) Punctuator#punctuate()}
+     * the processing progress can be observed and additional periodic actions can be performed.
+     *
+     * In order for the transformer to use state stores, the stores must be added to the topology and connected to the
+     * transformer using at least one of two strategies (though it's not required to connect global state stores; read-only
+     * access to global state stores is available by default).
+     *
+     * The first strategy is to manually add the {@link StoreBuilder}s via {@link Topology#addStateStore(StoreBuilder, String...)},
+     * and specify the store names via {@code stateStoreNames} so they will be connected to the transformer.
+     *

                {@code
                +     * // create store
                +     * StoreBuilder> keyValueStoreBuilder =
                +     *         Stores.keyValueStoreBuilder(Stores.persistentKeyValueStore("myValueTransformState"),
                +     *                 Serdes.String(),
                +     *                 Serdes.String());
                +     * // add store
                +     * builder.addStateStore(keyValueStoreBuilder);
                +     *
                +     * KStream outputStream = inputStream.flatTransformValues(new ValueTransformerSupplier() {
                +     *     public ValueTransformer get() {
                +     *         return new MyValueTransformer();
                +     *     }
                +     * }, "myValueTransformState");
                +     * }
+     * The second strategy is for the given {@link ValueTransformerSupplier} to implement {@link ConnectedStoreProvider#stores()},
+     * which provides the {@link StoreBuilder}s to be automatically added to the topology and connected to the transformer.
+     *
                {@code
                +     * class MyValueTransformerSupplier implements ValueTransformerSupplier {
                +     *     // supply transformer
+     *     ValueTransformer get() {
+     *         return new MyValueTransformer();
                +     *     }
                +     *
                +     *     // provide store(s) that will be added and connected to the associated transformer
                +     *     // the store name from the builder ("myValueTransformState") is used to access the store later via the ProcessorContext
                +     *     Set stores() {
                +     *         StoreBuilder> keyValueStoreBuilder =
                +     *                   Stores.keyValueStoreBuilder(Stores.persistentKeyValueStore("myValueTransformState"),
                +     *                   Serdes.String(),
                +     *                   Serdes.String());
                +     *         return Collections.singleton(keyValueStoreBuilder);
                +     *     }
                +     * }
                +     *
                +     * ...
                +     *
+     * KStream outputStream = inputStream.flatTransformValues(new MyValueTransformerSupplier());
                +     * }
+     *
+     * With either strategy, within the {@link ValueTransformer}, the state is obtained via the {@link ProcessorContext}.
+     * To trigger periodic actions via {@link org.apache.kafka.streams.processor.Punctuator#punctuate(long) punctuate()},
+     * a schedule must be registered.
+     * The {@link ValueTransformer} must return an {@link java.lang.Iterable} type (e.g., any
+     * {@link java.util.Collection} type) in {@link ValueTransformer#transform(Object)
+     * transform()}.
+     * If the return value of {@link ValueTransformer#transform(Object) ValueTransformer#transform()} is an empty
+     * {@link java.lang.Iterable Iterable} or {@code null}, no records are emitted.
+     * In contrast to {@link #transform(TransformerSupplier, String...) transform()} and
+     * {@link #flatTransform(TransformerSupplier, String...) flatTransform()}, no additional {@link KeyValue} pairs
+     * can be emitted via
+     * {@link org.apache.kafka.streams.processor.ProcessorContext#forward(Object, Object) ProcessorContext.forward()}.
+     * A {@link org.apache.kafka.streams.errors.StreamsException} is thrown if the {@link ValueTransformer} tries to
+     * emit a {@link KeyValue} pair.
+     *

                {@code
                +     * class MyValueTransformer implements ValueTransformer {
                +     *     private StateStore state;
                +     *
                +     *     void init(ProcessorContext context) {
                +     *         this.state = context.getStateStore("myValueTransformState");
                +     *         // punctuate each second, can access this.state
                +     *         context.schedule(Duration.ofSeconds(1), PunctuationType.WALL_CLOCK_TIME, new Punctuator(..));
                +     *     }
                +     *
                +     *     Iterable transform(V value) {
                +     *         // can access this.state
                +     *         List result = new ArrayList<>();
                +     *         for (int i = 0; i < 3; i++) {
                +     *             result.add(new NewValueType(value));
                +     *         }
                +     *         return result; // values
                +     *     }
                +     *
                +     *     void close() {
                +     *         // can access this.state
                +     *     }
                +     * }
                +     * }
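// --- Editorial sketch (not part of the diff above) ---------------------------------------
// A stateless flatTransformValues() call (no store attached), assuming a String-valued
// inputStream as in the Javadoc (java.util.Arrays and java.util.Collections imports assumed):
// each value fans out into zero or more values, and returning an empty collection drops the
// record entirely.
KStream<String, String> outputStream = inputStream.flatTransformValues(
        () -> new ValueTransformer<String, Iterable<String>>() {
            public void init(final ProcessorContext context) { }
            public Iterable<String> transform(final String value) {
                return value.isEmpty()
                        ? Collections.emptyList()                      // emit nothing
                        : Arrays.asList(value, value.toUpperCase());   // emit two values
            }
            public void close() { }
        });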
+     * Even if any upstream operation was key-changing, no auto-repartition is triggered.
+     * If repartitioning is required, a call to {@link #repartition()} should be performed before
+     * {@code flatTransformValues()}.
+     *
+     * Setting a new value preserves data co-location with respect to the key.
+     * Thus, no internal data redistribution is required if a key based operator (like an aggregation or join)
+     * is applied to the result {@code KStream}. (cf. {@link #flatTransform(TransformerSupplier, String...)
+     * flatTransform()})
+     *
+     * @param valueTransformerSupplier an instance of {@link ValueTransformerSupplier} that generates a newly constructed {@link ValueTransformer}
+     *        The supplier should always generate a new instance. Creating a single {@link ValueTransformer} object
+     *        and returning the same object reference in {@link ValueTransformer} is a
+     *        violation of the supplier pattern and leads to runtime exceptions.
+     * @param stateStoreNames the names of the state stores used by the processor; not required if the supplier
+     *        implements {@link ConnectedStoreProvider#stores()}
+     * @param <VR> the value type of the result stream
+     * @return a {@code KStream} that contains more or less records with unmodified key and new values (possibly of
+     *         different type)
+     * @see #mapValues(ValueMapper)
+     * @see #mapValues(ValueMapperWithKey)
+     * @see #transform(TransformerSupplier, String...)
+     * @see #flatTransform(TransformerSupplier, String...)
+     */
+    <VR> KStream<K, VR> flatTransformValues(final ValueTransformerSupplier<? super V, Iterable<VR>> valueTransformerSupplier,
+                                            final String... stateStoreNames);
+
+    /**
+     * Transform the value of each input record into zero or more new values (with possibly a new
+     * type) and emit for each new value a record with the same key of the input record and the value.
+     * A {@link ValueTransformer} (provided by the given {@link ValueTransformerSupplier}) is applied to each input
+     * record value and computes zero or more new values.
+     * Thus, an input record {@code } can be transformed into output records {@code , , ...}.
+     * Attaching a state store makes this a stateful record-by-record operation (cf. {@link #mapValues(ValueMapper) mapValues()}).
+     * If you choose not to attach one, this operation is similar to the stateless {@link #mapValues(ValueMapper) mapValues()}
+     * but allows access to the {@code ProcessorContext} and record metadata.
+     * This is essentially mixing the Processor API into the DSL, and provides all the functionality of the PAPI.
+     * Furthermore, via {@link org.apache.kafka.streams.processor.Punctuator#punctuate(long) Punctuator#punctuate()}
+     * the processing progress can be observed and additional periodic actions can be performed.
+     *
+     * In order for the transformer to use state stores, the stores must be added to the topology and connected to the
+     * transformer using at least one of two strategies (though it's not required to connect global state stores; read-only
+     * access to global state stores is available by default).
+     *
+     * The first strategy is to manually add the {@link StoreBuilder}s via {@link Topology#addStateStore(StoreBuilder, String...)},
+     * and specify the store names via {@code stateStoreNames} so they will be connected to the transformer.
+     *

                {@code
                +     * // create store
                +     * StoreBuilder> keyValueStoreBuilder =
                +     *         Stores.keyValueStoreBuilder(Stores.persistentKeyValueStore("myValueTransformState"),
                +     *                 Serdes.String(),
                +     *                 Serdes.String());
                +     * // add store
                +     * builder.addStateStore(keyValueStoreBuilder);
                +     *
                +     * KStream outputStream = inputStream.flatTransformValues(new ValueTransformerSupplier() {
                +     *     public ValueTransformer get() {
                +     *         return new MyValueTransformer();
                +     *     }
                +     * }, "myValueTransformState");
                +     * }
+     * The second strategy is for the given {@link ValueTransformerSupplier} to implement {@link ConnectedStoreProvider#stores()},
+     * which provides the {@link StoreBuilder}s to be automatically added to the topology and connected to the transformer.
+     *
                {@code
                +     * class MyValueTransformerSupplier implements ValueTransformerSupplier {
                +     *     // supply transformer
+     *     ValueTransformer get() {
+     *         return new MyValueTransformer();
                +     *     }
                +     *
                +     *     // provide store(s) that will be added and connected to the associated transformer
                +     *     // the store name from the builder ("myValueTransformState") is used to access the store later via the ProcessorContext
                +     *     Set stores() {
                +     *         StoreBuilder> keyValueStoreBuilder =
                +     *                   Stores.keyValueStoreBuilder(Stores.persistentKeyValueStore("myValueTransformState"),
                +     *                   Serdes.String(),
                +     *                   Serdes.String());
                +     *         return Collections.singleton(keyValueStoreBuilder);
                +     *     }
                +     * }
                +     *
                +     * ...
                +     *
+     * KStream outputStream = inputStream.flatTransformValues(new MyValueTransformerSupplier());
                +     * }
+     *
+     * With either strategy, within the {@link ValueTransformer}, the state is obtained via the {@link ProcessorContext}.
+     * To trigger periodic actions via {@link org.apache.kafka.streams.processor.Punctuator#punctuate(long) punctuate()},
+     * a schedule must be registered.
+     * The {@link ValueTransformer} must return an {@link java.lang.Iterable} type (e.g., any
+     * {@link java.util.Collection} type) in {@link ValueTransformer#transform(Object)
+     * transform()}.
+     * If the return value of {@link ValueTransformer#transform(Object) ValueTransformer#transform()} is an empty
+     * {@link java.lang.Iterable Iterable} or {@code null}, no records are emitted.
+     * In contrast to {@link #transform(TransformerSupplier, String...) transform()} and
+     * {@link #flatTransform(TransformerSupplier, String...) flatTransform()}, no additional {@link KeyValue} pairs
+     * can be emitted via
+     * {@link org.apache.kafka.streams.processor.ProcessorContext#forward(Object, Object) ProcessorContext.forward()}.
+     * A {@link org.apache.kafka.streams.errors.StreamsException} is thrown if the {@link ValueTransformer} tries to
+     * emit a {@link KeyValue} pair.
+     *

                {@code
                +     * class MyValueTransformer implements ValueTransformer {
                +     *     private StateStore state;
                +     *
                +     *     void init(ProcessorContext context) {
                +     *         this.state = context.getStateStore("myValueTransformState");
                +     *         // punctuate each second, can access this.state
                +     *         context.schedule(Duration.ofSeconds(1), PunctuationType.WALL_CLOCK_TIME, new Punctuator(..));
                +     *     }
                +     *
                +     *     Iterable transform(V value) {
                +     *         // can access this.state
                +     *         List result = new ArrayList<>();
                +     *         for (int i = 0; i < 3; i++) {
                +     *             result.add(new NewValueType(value));
                +     *         }
                +     *         return result; // values
                +     *     }
                +     *
                +     *     void close() {
                +     *         // can access this.state
                +     *     }
                +     * }
                +     * }
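// --- Editorial sketch (not part of the diff above) ---------------------------------------
// The "a schedule must be registered" step from the Javadoc, shown inside a transformer's
// init() method; Punctuator is a functional interface, so the callback can be a lambda.
// The field `state` and the store name "myValueTransformState" are the hypothetical ones
// used throughout these examples.
@Override
public void init(final ProcessorContext context) {
    this.state = (KeyValueStore<String, String>) context.getStateStore("myValueTransformState");
    context.schedule(Duration.ofSeconds(1),
                     PunctuationType.WALL_CLOCK_TIME,
                     timestamp -> System.out.println("wall-clock punctuation at " + timestamp));
}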
+     * Even if any upstream operation was key-changing, no auto-repartition is triggered.
+     * If repartitioning is required, a call to {@link #repartition()} should be performed before
+     * {@code flatTransformValues()}.
+     *
+     * Setting a new value preserves data co-location with respect to the key.
+     * Thus, no internal data redistribution is required if a key based operator (like an aggregation or join)
+     * is applied to the result {@code KStream}. (cf. {@link #flatTransform(TransformerSupplier, String...)
+     * flatTransform()})
+     *
+     * @param valueTransformerSupplier an instance of {@link ValueTransformerSupplier} that generates a newly constructed {@link ValueTransformer}
+     *        The supplier should always generate a new instance. Creating a single {@link ValueTransformer} object
+     *        and returning the same object reference in {@link ValueTransformer} is a
+     *        violation of the supplier pattern and leads to runtime exceptions.
+     * @param named a {@link Named} config used to name the processor in the topology
+     * @param stateStoreNames the names of the state stores used by the processor; not required if the supplier
+     *        implements {@link ConnectedStoreProvider#stores()}
+     * @param <VR> the value type of the result stream
+     * @return a {@code KStream} that contains more or less records with unmodified key and new values (possibly of
+     *         different type)
+     * @see #mapValues(ValueMapper)
+     * @see #mapValues(ValueMapperWithKey)
+     * @see #transform(TransformerSupplier, String...)
+     * @see #flatTransform(TransformerSupplier, String...)
+     */
+    <VR> KStream<K, VR> flatTransformValues(final ValueTransformerSupplier<? super V, Iterable<VR>> valueTransformerSupplier,
+                                            final Named named,
+                                            final String... stateStoreNames);
+
+    /**
+     * Transform the value of each input record into zero or more new values (with possibly a new
+     * type) and emit for each new value a record with the same key of the input record and the value.
+     * A {@link ValueTransformerWithKey} (provided by the given {@link ValueTransformerWithKeySupplier}) is applied to
+     * each input record value and computes zero or more new values.
+     * Thus, an input record {@code } can be transformed into output records {@code , , ...}.
+     * Attaching a state store makes this a stateful record-by-record operation (cf. {@link #flatMapValues(ValueMapperWithKey) flatMapValues()}).
+     * If you choose not to attach one, this operation is similar to the stateless {@link #flatMapValues(ValueMapperWithKey) flatMapValues()}
+     * but allows access to the {@code ProcessorContext} and record metadata.
+     * This is essentially mixing the Processor API into the DSL, and provides all the functionality of the PAPI.
+     * Furthermore, via {@link org.apache.kafka.streams.processor.Punctuator#punctuate(long)} the processing progress can
+     * be observed and additional periodic actions can be performed.
+     *
+     * In order for the transformer to use state stores, the stores must be added to the topology and connected to the
+     * transformer using at least one of two strategies (though it's not required to connect global state stores; read-only
+     * access to global state stores is available by default).
+     *
+     * The first strategy is to manually add the {@link StoreBuilder}s via {@link Topology#addStateStore(StoreBuilder, String...)},
+     * and specify the store names via {@code stateStoreNames} so they will be connected to the transformer.
+     *

                {@code
                +     * // create store
                +     * StoreBuilder> keyValueStoreBuilder =
                +     *         Stores.keyValueStoreBuilder(Stores.persistentKeyValueStore("myValueTransformState"),
                +     *                 Serdes.String(),
                +     *                 Serdes.String());
                +     * // add store
                +     * builder.addStateStore(keyValueStoreBuilder);
                +     *
                +     * KStream outputStream = inputStream.flatTransformValues(new ValueTransformerWithKeySupplier() {
                +     *     public ValueTransformerWithKey get() {
                +     *         return new MyValueTransformerWithKey();
                +     *     }
                +     * }, "myValueTransformState");
                +     * }
+     * The second strategy is for the given {@link ValueTransformerSupplier} to implement {@link ConnectedStoreProvider#stores()},
+     * which provides the {@link StoreBuilder}s to be automatically added to the topology and connected to the transformer.
+     *
                {@code
                +     * class MyValueTransformerWithKeySupplier implements ValueTransformerWithKeySupplier {
                +     *     // supply transformer
                +     *     ValueTransformerWithKey get() {
                +     *         return new MyValueTransformerWithKey();
                +     *     }
                +     *
                +     *     // provide store(s) that will be added and connected to the associated transformer
                +     *     // the store name from the builder ("myValueTransformState") is used to access the store later via the ProcessorContext
                +     *     Set stores() {
                +     *         StoreBuilder> keyValueStoreBuilder =
                +     *                   Stores.keyValueStoreBuilder(Stores.persistentKeyValueStore("myValueTransformState"),
                +     *                   Serdes.String(),
                +     *                   Serdes.String());
                +     *         return Collections.singleton(keyValueStoreBuilder);
                +     *     }
                +     * }
                +     *
                +     * ...
                +     *
+     * KStream outputStream = inputStream.flatTransformValues(new MyValueTransformerWithKeySupplier());
                +     * }
+     *
+     * With either strategy, within the {@link ValueTransformerWithKey}, the state is obtained via the {@link ProcessorContext}.
+     * To trigger periodic actions via {@link org.apache.kafka.streams.processor.Punctuator#punctuate(long) punctuate()},
+     * a schedule must be registered.
+     * The {@link ValueTransformerWithKey} must return an {@link java.lang.Iterable} type (e.g., any
+     * {@link java.util.Collection} type) in {@link ValueTransformerWithKey#transform(Object, Object)
+     * transform()}.
+     * If the return value of {@link ValueTransformerWithKey#transform(Object, Object) ValueTransformerWithKey#transform()}
+     * is an empty {@link java.lang.Iterable Iterable} or {@code null}, no records are emitted.
+     * In contrast to {@link #transform(TransformerSupplier, String...) transform()} and
+     * {@link #flatTransform(TransformerSupplier, String...) flatTransform()}, no additional {@link KeyValue} pairs
+     * can be emitted via
+     * {@link org.apache.kafka.streams.processor.ProcessorContext#forward(Object, Object) ProcessorContext.forward()}.
+     * A {@link org.apache.kafka.streams.errors.StreamsException} is thrown if the {@link ValueTransformerWithKey} tries
+     * to emit a {@link KeyValue} pair.
+     *

                {@code
                +     * class MyValueTransformerWithKey implements ValueTransformerWithKey {
                +     *     private StateStore state;
                +     *
                +     *     void init(ProcessorContext context) {
                +     *         this.state = context.getStateStore("myValueTransformState");
                +     *         // punctuate each second, can access this.state
                +     *         context.schedule(Duration.ofSeconds(1), PunctuationType.WALL_CLOCK_TIME, new Punctuator(..));
                +     *     }
                +     *
                +     *     Iterable transform(K readOnlyKey, V value) {
                +     *         // can access this.state and use read-only key
                +     *         List result = new ArrayList<>();
                +     *         for (int i = 0; i < 3; i++) {
                +     *             result.add(new NewValueType(readOnlyKey));
                +     *         }
                +     *         return result; // values
                +     *     }
                +     *
                +     *     void close() {
                +     *         // can access this.state
                +     *     }
                +     * }
                +     * }
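// --- Editorial sketch (not part of the diff above) ---------------------------------------
// The "second strategy" with explicit generics, assuming String serdes: because the supplier
// implements stores(), the store is added to the topology and connected automatically, so no
// builder.addStateStore(...) call and no stateStoreNames argument are required.
class MyTypedTransformerSupplier implements ValueTransformerWithKeySupplier<String, String, String> {
    @Override
    public ValueTransformerWithKey<String, String, String> get() {
        return new MyValueTransformerWithKey();   // hypothetical transformer from the Javadoc
    }

    @Override
    public Set<StoreBuilder<?>> stores() {
        return Collections.singleton(
                Stores.keyValueStoreBuilder(Stores.persistentKeyValueStore("myValueTransformState"),
                                            Serdes.String(),
                                            Serdes.String()));
    }
}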
+     * Even if any upstream operation was key-changing, no auto-repartition is triggered.
+     * If repartitioning is required, a call to {@link #repartition()} should be performed before
+     * {@code flatTransformValues()}.
+     *
+     * Note that the key is read-only and should not be modified, as this can lead to corrupt partitioning.
+     * So, setting a new value preserves data co-location with respect to the key.
+     * Thus, no internal data redistribution is required if a key based operator (like an aggregation or join)
+     * is applied to the result {@code KStream}. (cf. {@link #flatTransform(TransformerSupplier, String...)
+     * flatTransform()})
+     *
+     * @param valueTransformerSupplier an instance of {@link ValueTransformerWithKeySupplier} that generates a newly constructed {@link ValueTransformerWithKey}
+     *        The supplier should always generate a new instance. Creating a single {@link ValueTransformerWithKey} object
+     *        and returning the same object reference in {@link ValueTransformerWithKey} is a
+     *        violation of the supplier pattern and leads to runtime exceptions.
+     * @param stateStoreNames the names of the state stores used by the processor; not required if the supplier
+     *        implements {@link ConnectedStoreProvider#stores()}
+     * @param <VR> the value type of the result stream
+     * @return a {@code KStream} that contains more or less records with unmodified key and new values (possibly of
+     *         different type)
+     * @see #mapValues(ValueMapper)
+     * @see #mapValues(ValueMapperWithKey)
+     * @see #transform(TransformerSupplier, String...)
+     * @see #flatTransform(TransformerSupplier, String...)
+     */
+    <VR> KStream<K, VR> flatTransformValues(final ValueTransformerWithKeySupplier<? super K, ? super V, Iterable<VR>> valueTransformerSupplier,
+                                            final String... stateStoreNames);
+
+    /**
+     * Transform the value of each input record into zero or more new values (with possibly a new
+     * type) and emit for each new value a record with the same key of the input record and the value.
+     * A {@link ValueTransformerWithKey} (provided by the given {@link ValueTransformerWithKeySupplier}) is applied to
+     * each input record value and computes zero or more new values.
+     * Thus, an input record {@code } can be transformed into output records {@code , , ...}.
+     * Attaching a state store makes this a stateful record-by-record operation (cf. {@link #flatMapValues(ValueMapperWithKey) flatMapValues()}).
+     * If you choose not to attach one, this operation is similar to the stateless {@link #flatMapValues(ValueMapperWithKey) flatMapValues()}
+     * but allows access to the {@code ProcessorContext} and record metadata.
+     * This is essentially mixing the Processor API into the DSL, and provides all the functionality of the PAPI.
+     * Furthermore, via {@link org.apache.kafka.streams.processor.Punctuator#punctuate(long)} the processing progress can
+     * be observed and additional periodic actions can be performed.
+     *
+     * In order for the transformer to use state stores, the stores must be added to the topology and connected to the
+     * transformer using at least one of two strategies (though it's not required to connect global state stores; read-only
+     * access to global state stores is available by default).
+     *
+     * The first strategy is to manually add the {@link StoreBuilder}s via {@link Topology#addStateStore(StoreBuilder, String...)},
+     * and specify the store names via {@code stateStoreNames} so they will be connected to the transformer.
+     *

                {@code
                +     * // create store
                +     * StoreBuilder> keyValueStoreBuilder =
                +     *         Stores.keyValueStoreBuilder(Stores.persistentKeyValueStore("myValueTransformState"),
                +     *                 Serdes.String(),
                +     *                 Serdes.String());
                +     * // add store
                +     * builder.addStateStore(keyValueStoreBuilder);
                +     *
                +     * KStream outputStream = inputStream.flatTransformValues(new ValueTransformerWithKeySupplier() {
                +     *     public ValueTransformerWithKey get() {
                +     *         return new MyValueTransformerWithKey();
                +     *     }
                +     * }, "myValueTransformState");
                +     * }
+     * The second strategy is for the given {@link ValueTransformerSupplier} to implement {@link ConnectedStoreProvider#stores()},
+     * which provides the {@link StoreBuilder}s to be automatically added to the topology and connected to the transformer.
+     *
                {@code
                +     * class MyValueTransformerWithKeySupplier implements ValueTransformerWithKeySupplier {
                +     *     // supply transformer
                +     *     ValueTransformerWithKey get() {
                +     *         return new MyValueTransformerWithKey();
                +     *     }
                +     *
                +     *     // provide store(s) that will be added and connected to the associated transformer
                +     *     // the store name from the builder ("myValueTransformState") is used to access the store later via the ProcessorContext
                +     *     Set stores() {
                +     *         StoreBuilder> keyValueStoreBuilder =
                +     *                   Stores.keyValueStoreBuilder(Stores.persistentKeyValueStore("myValueTransformState"),
                +     *                   Serdes.String(),
                +     *                   Serdes.String());
                +     *         return Collections.singleton(keyValueStoreBuilder);
                +     *     }
                +     * }
                +     *
                +     * ...
                +     *
+     * KStream outputStream = inputStream.flatTransformValues(new MyValueTransformerWithKeySupplier());
                +     * }
+     *
+     * With either strategy, within the {@link ValueTransformerWithKey}, the state is obtained via the {@link ProcessorContext}.
+     * To trigger periodic actions via {@link org.apache.kafka.streams.processor.Punctuator#punctuate(long) punctuate()},
+     * a schedule must be registered.
+     * The {@link ValueTransformerWithKey} must return an {@link java.lang.Iterable} type (e.g., any
+     * {@link java.util.Collection} type) in {@link ValueTransformerWithKey#transform(Object, Object)
+     * transform()}.
+     * If the return value of {@link ValueTransformerWithKey#transform(Object, Object) ValueTransformerWithKey#transform()}
+     * is an empty {@link java.lang.Iterable Iterable} or {@code null}, no records are emitted.
+     * In contrast to {@link #transform(TransformerSupplier, String...) transform()} and
+     * {@link #flatTransform(TransformerSupplier, String...) flatTransform()}, no additional {@link KeyValue} pairs
+     * can be emitted via
+     * {@link org.apache.kafka.streams.processor.ProcessorContext#forward(Object, Object) ProcessorContext.forward()}.
+     * A {@link org.apache.kafka.streams.errors.StreamsException} is thrown if the {@link ValueTransformerWithKey} tries
+     * to emit a {@link KeyValue} pair.
+     *

                {@code
                +     * class MyValueTransformerWithKey implements ValueTransformerWithKey {
                +     *     private StateStore state;
                +     *
                +     *     void init(ProcessorContext context) {
                +     *         this.state = context.getStateStore("myValueTransformState");
                +     *         // punctuate each second, can access this.state
                +     *         context.schedule(Duration.ofSeconds(1), PunctuationType.WALL_CLOCK_TIME, new Punctuator(..));
                +     *     }
                +     *
                +     *     Iterable transform(K readOnlyKey, V value) {
                +     *         // can access this.state and use read-only key
                +     *         List result = new ArrayList<>();
                +     *         for (int i = 0; i < 3; i++) {
                +     *             result.add(new NewValueType(readOnlyKey));
                +     *         }
                +     *         return result; // values
                +     *     }
                +     *
                +     *     void close() {
                +     *         // can access this.state
                +     *     }
                +     * }
                +     * }
+     * Even if any upstream operation was key-changing, no auto-repartition is triggered.
+     * If repartitioning is required, a call to {@link #repartition()} should be performed before
+     * {@code flatTransformValues()}.
+     *
+     * Note that the key is read-only and should not be modified, as this can lead to corrupt partitioning.
+     * So, setting a new value preserves data co-location with respect to the key.
+     * Thus, no internal data redistribution is required if a key based operator (like an aggregation or join)
+     * is applied to the result {@code KStream}. (cf. {@link #flatTransform(TransformerSupplier, String...)
+     * flatTransform()})
+     *
+     * @param valueTransformerSupplier an instance of {@link ValueTransformerWithKeySupplier} that generates a newly constructed {@link ValueTransformerWithKey}
+     *        The supplier should always generate a new instance. Creating a single {@link ValueTransformerWithKey} object
+     *        and returning the same object reference in {@link ValueTransformerWithKey} is a
+     *        violation of the supplier pattern and leads to runtime exceptions.
+     * @param named a {@link Named} config used to name the processor in the topology
+     * @param stateStoreNames the names of the state stores used by the processor; not required if the supplier
+     *        implements {@link ConnectedStoreProvider#stores()}
+     * @param <VR> the value type of the result stream
+     * @return a {@code KStream} that contains more or less records with unmodified key and new values (possibly of
+     *         different type)
+     * @see #mapValues(ValueMapper)
+     * @see #mapValues(ValueMapperWithKey)
+     * @see #transform(TransformerSupplier, String...)
+     * @see #flatTransform(TransformerSupplier, String...)
+     */
+    <VR> KStream<K, VR> flatTransformValues(final ValueTransformerWithKeySupplier<? super K, ? super V, Iterable<VR>> valueTransformerSupplier,
+                                            final Named named,
+                                            final String... stateStoreNames);
+
+    /**
+     * Process all records in this stream, one record at a time, by applying a
+     * {@link org.apache.kafka.streams.processor.Processor} (provided by the given
+     * {@link org.apache.kafka.streams.processor.ProcessorSupplier}).
+     * Attaching a state store makes this a stateful record-by-record operation (cf. {@link #foreach(ForeachAction)}).
+     * If you choose not to attach one, this operation is similar to the stateless {@link #foreach(ForeachAction)}
+     * but allows access to the {@code ProcessorContext} and record metadata.
+     * This is essentially mixing the Processor API into the DSL, and provides all the functionality of the PAPI.
+     * Furthermore, via {@link org.apache.kafka.streams.processor.Punctuator#punctuate(long)} the processing progress
+     * can be observed and additional periodic actions can be performed.
+     * Note that this is a terminal operation that returns void.
+     *
+     * In order for the processor to use state stores, the stores must be added to the topology and connected to the
+     * processor using at least one of two strategies (though it's not required to connect global state stores; read-only
+     * access to global state stores is available by default).
+     *
+     * The first strategy is to manually add the {@link StoreBuilder}s via {@link Topology#addStateStore(StoreBuilder, String...)},
+     * and specify the store names via {@code stateStoreNames} so they will be connected to the processor.
+     *

                {@code
                +     * // create store
                +     * StoreBuilder> keyValueStoreBuilder =
                +     *         Stores.keyValueStoreBuilder(Stores.persistentKeyValueStore("myProcessorState"),
                +     *                 Serdes.String(),
                +     *                 Serdes.String());
                +     * // add store
                +     * builder.addStateStore(keyValueStoreBuilder);
                +     *
+     * inputStream.process(new ProcessorSupplier() {
                +     *     public Processor get() {
                +     *         return new MyProcessor();
                +     *     }
                +     * }, "myProcessorState");
                +     * }
+     * The second strategy is for the given {@link org.apache.kafka.streams.processor.ProcessorSupplier}
+     * to implement {@link ConnectedStoreProvider#stores()},
+     * which provides the {@link StoreBuilder}s to be automatically added to the topology and connected to the processor.
+     *
                {@code
                +     * class MyProcessorSupplier implements ProcessorSupplier {
                +     *     // supply processor
                +     *     Processor get() {
                +     *         return new MyProcessor();
                +     *     }
                +     *
                +     *     // provide store(s) that will be added and connected to the associated processor
                +     *     // the store name from the builder ("myProcessorState") is used to access the store later via the ProcessorContext
                +     *     Set stores() {
                +     *         StoreBuilder> keyValueStoreBuilder =
                +     *                   Stores.keyValueStoreBuilder(Stores.persistentKeyValueStore("myProcessorState"),
                +     *                   Serdes.String(),
                +     *                   Serdes.String());
                +     *         return Collections.singleton(keyValueStoreBuilder);
                +     *     }
                +     * }
                +     *
                +     * ...
                +     *
                +     * KStream outputStream = inputStream.process(new MyProcessorSupplier());
                +     * }
+     *
+     * With either strategy, within the {@link org.apache.kafka.streams.processor.Processor},
+     * the state is obtained via the {@link org.apache.kafka.streams.processor.ProcessorContext}.
+     * To trigger periodic actions via {@link org.apache.kafka.streams.processor.Punctuator#punctuate(long) punctuate()},
+     * a schedule must be registered.
+     *

                {@code
                +     * class MyProcessor implements Processor {
                +     *     private StateStore state;
                +     *
                +     *     void init(ProcessorContext context) {
                +     *         this.state = context.getStateStore("myProcessorState");
                +     *         // punctuate each second, can access this.state
                +     *         context.schedule(Duration.ofSeconds(1), PunctuationType.WALL_CLOCK_TIME, new Punctuator(..));
                +     *     }
                +     *
                +     *     void process(K key, V value) {
                +     *         // can access this.state
                +     *     }
                +     *
                +     *     void close() {
                +     *         // can access this.state
                +     *     }
                +     * }
                +     * }
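// --- Editorial sketch (not part of the diff above) ---------------------------------------
// Since this process(...) overload is deprecated in 3.0, the same terminal step can be
// expressed against the new org.apache.kafka.streams.processor.api types. MyProcessor is a
// hypothetical rewrite of the Javadoc's example that implements api.Processor; the stream is
// assumed to carry String keys and values.
inputStream.process(new org.apache.kafka.streams.processor.api.ProcessorSupplier<String, String, Void, Void>() {
    @Override
    public org.apache.kafka.streams.processor.api.Processor<String, String, Void, Void> get() {
        return new MyProcessor();
    }
}, "myProcessorState");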
+     * Even if any upstream operation was key-changing, no auto-repartition is triggered.
+     * If repartitioning is required, a call to {@link #repartition()} should be performed before {@code process()}.
+     *
+     * @param processorSupplier an instance of {@link org.apache.kafka.streams.processor.ProcessorSupplier}
+     *        that generates a newly constructed {@link org.apache.kafka.streams.processor.Processor}
+     *        The supplier should always generate a new instance. Creating a single
+     *        {@link org.apache.kafka.streams.processor.Processor} object
+     *        and returning the same object reference in
+     *        {@link org.apache.kafka.streams.processor.ProcessorSupplier#get()} is a
+     *        violation of the supplier pattern and leads to runtime exceptions.
+     * @param stateStoreNames the names of the state stores used by the processor; not required if the supplier
+     *        implements {@link ConnectedStoreProvider#stores()}
+     * @see #foreach(ForeachAction)
+     * @see #transform(TransformerSupplier, String...)
+     * @deprecated Since 3.0. Use {@link KStream#process(org.apache.kafka.streams.processor.api.ProcessorSupplier, java.lang.String...)} instead.
+     */
+    @Deprecated
+    void process(final org.apache.kafka.streams.processor.ProcessorSupplier<? super K, ? super V> processorSupplier,
+                 final String... stateStoreNames);
+
+
+    /**
+     * Process all records in this stream, one record at a time, by applying a {@link Processor} (provided by the given
+     * {@link ProcessorSupplier}).
+     * Attaching a state store makes this a stateful record-by-record operation (cf. {@link #foreach(ForeachAction)}).
+     * If you choose not to attach one, this operation is similar to the stateless {@link #foreach(ForeachAction)}
+     * but allows access to the {@code ProcessorContext} and record metadata.
+     * This is essentially mixing the Processor API into the DSL, and provides all the functionality of the PAPI.
+     * Furthermore, via {@link org.apache.kafka.streams.processor.Punctuator#punctuate(long)} the processing progress
+     * can be observed and additional periodic actions can be performed.
+     * Note that this is a terminal operation that returns void.
+     *
+     * In order for the processor to use state stores, the stores must be added to the topology and connected to the
+     * processor using at least one of two strategies (though it's not required to connect global state stores; read-only
+     * access to global state stores is available by default).
+     *
+     * The first strategy is to manually add the {@link StoreBuilder}s via {@link Topology#addStateStore(StoreBuilder, String...)},
+     * and specify the store names via {@code stateStoreNames} so they will be connected to the processor.
+     *

                {@code
                +     * // create store
                +     * StoreBuilder> keyValueStoreBuilder =
                +     *         Stores.keyValueStoreBuilder(Stores.persistentKeyValueStore("myProcessorState"),
                +     *                 Serdes.String(),
                +     *                 Serdes.String());
                +     * // add store
                +     * builder.addStateStore(keyValueStoreBuilder);
                +     *
+     * inputStream.process(new ProcessorSupplier() {
                +     *     public Processor get() {
                +     *         return new MyProcessor();
                +     *     }
                +     * }, "myProcessorState");
                +     * }
+     * The second strategy is for the given {@link ProcessorSupplier} to implement {@link ConnectedStoreProvider#stores()},
+     * which provides the {@link StoreBuilder}s to be automatically added to the topology and connected to the processor.
+     *
+     * <pre>{@code
+     * class MyProcessorSupplier implements ProcessorSupplier {
+     *     // supply processor
+     *     Processor get() {
+     *         return new MyProcessor();
+     *     }
+     *
+     *     // provide store(s) that will be added and connected to the associated processor
+     *     // the store name from the builder ("myProcessorState") is used to access the store later via the ProcessorContext
+     *     Set<StoreBuilder<?>> stores() {
+     *         StoreBuilder<KeyValueStore<String, String>> keyValueStoreBuilder =
+     *                   Stores.keyValueStoreBuilder(Stores.persistentKeyValueStore("myProcessorState"),
+     *                   Serdes.String(),
+     *                   Serdes.String());
+     *         return Collections.singleton(keyValueStoreBuilder);
+     *     }
+     * }
+     *
+     * ...
+     *
+     * inputStream.process(new MyProcessorSupplier());
+     * }</pre>
+     * <p>
+     * With either strategy, within the {@link Processor}, the state is obtained via the {@link ProcessorContext}.
+     * To trigger periodic actions via {@link org.apache.kafka.streams.processor.Punctuator#punctuate(long) punctuate()},
+     * a schedule must be registered.
+     * <pre>{@code
                +     * class MyProcessor implements Processor {
                +     *     private StateStore state;
                +     *
                +     *     void init(ProcessorContext context) {
                +     *         this.state = context.getStateStore("myProcessorState");
                +     *         // punctuate each second, can access this.state
                +     *         context.schedule(Duration.ofSeconds(1), PunctuationType.WALL_CLOCK_TIME, new Punctuator(..));
                +     *     }
                +     *
                +     *     void process(K key, V value) {
                +     *         // can access this.state
                +     *     }
                +     *
                +     *     void close() {
                +     *         // can access this.state
                +     *     }
                +     * }
+     * }</pre>
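+     * <p>
+     * A minimal sketch of the same processor written against the typed
+     * {@link org.apache.kafka.streams.processor.api.Processor} interface that this overload accepts (the
+     * {@code <String, String, Void, Void>} type parameters and the {@code Record}-based signature are assumed
+     * from that API; adapt them to your actual key and value types):
+     * <pre>{@code
+     * class MyTypedProcessor implements Processor<String, String, Void, Void> {
+     *     private KeyValueStore<String, String> state;
+     *
+     *     @Override
+     *     public void init(final ProcessorContext<Void, Void> context) {
+     *         state = context.getStateStore("myProcessorState");
+     *         // punctuate each second; the punctuator may access state
+     *         context.schedule(Duration.ofSeconds(1), PunctuationType.WALL_CLOCK_TIME,
+     *             timestamp -> System.out.println("punctuated at " + timestamp));
+     *     }
+     *
+     *     @Override
+     *     public void process(final Record<String, String> record) {
+     *         // record.key() and record.value() are available here; state may be read or updated
+     *     }
+     * }
+     * }</pre>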
                + * Even if any upstream operation was key-changing, no auto-repartition is triggered. + * If repartitioning is required, a call to {@link #repartition()} should be performed before {@code process()}. + * + * @param processorSupplier an instance of {@link ProcessorSupplier} that generates a newly constructed {@link Processor} + * The supplier should always generate a new instance. Creating a single {@link Processor} object + * and returning the same object reference in {@link ProcessorSupplier#get()} is a + * violation of the supplier pattern and leads to runtime exceptions. + * @param stateStoreNames the names of the state stores used by the processor; not required if the supplier + * implements {@link ConnectedStoreProvider#stores()} + * @see #foreach(ForeachAction) + * @see #transform(TransformerSupplier, String...) + */ + void process(final ProcessorSupplier processorSupplier, + final String... stateStoreNames); + + /** + * Process all records in this stream, one record at a time, by applying a + * {@link org.apache.kafka.streams.processor.Processor} (provided by the given + * {@link org.apache.kafka.streams.processor.ProcessorSupplier}). + * Attaching a state store makes this a stateful record-by-record operation (cf. {@link #foreach(ForeachAction)}). + * If you choose not to attach one, this operation is similar to the stateless {@link #foreach(ForeachAction)} + * but allows access to the {@code ProcessorContext} and record metadata. + * This is essentially mixing the Processor API into the DSL, and provides all the functionality of the PAPI. + * Furthermore, via {@link org.apache.kafka.streams.processor.Punctuator#punctuate(long)} the processing progress + * can be observed and additional periodic actions can be performed. + * Note that this is a terminal operation that returns void. + *

                + * In order for the processor to use state stores, the stores must be added to the topology and connected to the + * processor using at least one of two strategies (though it's not required to connect global state stores; read-only + * access to global state stores is available by default). + *

                + * The first strategy is to manually add the {@link StoreBuilder}s via {@link Topology#addStateStore(StoreBuilder, String...)}, + * and specify the store names via {@code stateStoreNames} so they will be connected to the processor. + *

+     * <pre>{@code
+     * // create store
+     * StoreBuilder<KeyValueStore<String, String>> keyValueStoreBuilder =
+     *         Stores.keyValueStoreBuilder(Stores.persistentKeyValueStore("myProcessorState"),
+     *                 Serdes.String(),
+     *                 Serdes.String());
+     * // add store
+     * builder.addStateStore(keyValueStoreBuilder);
+     *
+     * inputStream.process(new ProcessorSupplier<String, String>() {
+     *     public Processor<String, String> get() {
+     *         return new MyProcessor();
+     *     }
+     * }, "myProcessorState");
+     * }</pre>
                + * The second strategy is for the given {@link org.apache.kafka.streams.processor.ProcessorSupplier} + * to implement {@link ConnectedStoreProvider#stores()}, + * which provides the {@link StoreBuilder}s to be automatically added to the topology and connected to the processor. + *
+     * <pre>{@code
+     * class MyProcessorSupplier implements ProcessorSupplier<String, String> {
+     *     // supply processor
+     *     Processor<String, String> get() {
+     *         return new MyProcessor();
+     *     }
+     *
+     *     // provide store(s) that will be added and connected to the associated processor
+     *     // the store name from the builder ("myProcessorState") is used to access the store later via the ProcessorContext
+     *     Set<StoreBuilder<?>> stores() {
+     *         StoreBuilder<KeyValueStore<String, String>> keyValueStoreBuilder =
+     *                   Stores.keyValueStoreBuilder(Stores.persistentKeyValueStore("myProcessorState"),
+     *                   Serdes.String(),
+     *                   Serdes.String());
+     *         return Collections.singleton(keyValueStoreBuilder);
+     *     }
+     * }
+     *
+     * ...
+     *
+     * inputStream.process(new MyProcessorSupplier());
+     * }</pre>
                + *

                + * With either strategy, within the {@link org.apache.kafka.streams.processor.Processor}, + * the state is obtained via the {@link org.apache.kafka.streams.processor.ProcessorContext}. + * To trigger periodic actions via {@link org.apache.kafka.streams.processor.Punctuator#punctuate(long) punctuate()}, + * a schedule must be registered. + *

+     * <pre>{@code
+     * class MyProcessor implements Processor<String, String> {
+     *     private StateStore state;
+     *
+     *     void init(ProcessorContext context) {
+     *         this.state = context.getStateStore("myProcessorState");
+     *         // punctuate each second, can access this.state
+     *         context.schedule(Duration.ofSeconds(1), PunctuationType.WALL_CLOCK_TIME, new Punctuator(..));
+     *     }
+     *
+     *     void process(String key, String value) {
+     *         // can access this.state
+     *     }
+     *
+     *     void close() {
+     *         // can access this.state
+     *     }
+     * }
+     * }</pre>
                + * Even if any upstream operation was key-changing, no auto-repartition is triggered. + * If repartitioning is required, a call to {@link #repartition()} should be performed before {@code process()}. + * + * @param processorSupplier an instance of {@link org.apache.kafka.streams.processor.ProcessorSupplier} + * that generates a newly constructed {@link org.apache.kafka.streams.processor.Processor} + * The supplier should always generate a new instance. Creating a single + * {@link org.apache.kafka.streams.processor.Processor} object + * and returning the same object reference in + * {@link org.apache.kafka.streams.processor.ProcessorSupplier#get()} is a + * violation of the supplier pattern and leads to runtime exceptions. + * @param named a {@link Named} config used to name the processor in the topology + * @param stateStoreNames the names of the state store used by the processor + * @see #foreach(ForeachAction) + * @see #transform(TransformerSupplier, String...) + * @deprecated Since 3.0. Use {@link KStream#process(org.apache.kafka.streams.processor.api.ProcessorSupplier, org.apache.kafka.streams.kstream.Named, java.lang.String...)} instead. + */ + @Deprecated + void process(final org.apache.kafka.streams.processor.ProcessorSupplier processorSupplier, + final Named named, + final String... stateStoreNames); + + /** + * Process all records in this stream, one record at a time, by applying a {@link Processor} (provided by the given + * {@link ProcessorSupplier}). + * Attaching a state store makes this a stateful record-by-record operation (cf. {@link #foreach(ForeachAction)}). + * If you choose not to attach one, this operation is similar to the stateless {@link #foreach(ForeachAction)} + * but allows access to the {@code ProcessorContext} and record metadata. + * This is essentially mixing the Processor API into the DSL, and provides all the functionality of the PAPI. + * Furthermore, via {@link org.apache.kafka.streams.processor.Punctuator#punctuate(long)} the processing progress + * can be observed and additional periodic actions can be performed. + * Note that this is a terminal operation that returns void. + *

                + * In order for the processor to use state stores, the stores must be added to the topology and connected to the + * processor using at least one of two strategies (though it's not required to connect global state stores; read-only + * access to global state stores is available by default). + *

                + * The first strategy is to manually add the {@link StoreBuilder}s via {@link Topology#addStateStore(StoreBuilder, String...)}, + * and specify the store names via {@code stateStoreNames} so they will be connected to the processor. + *

+     * <pre>{@code
+     * // create store
+     * StoreBuilder<KeyValueStore<String, String>> keyValueStoreBuilder =
+     *         Stores.keyValueStoreBuilder(Stores.persistentKeyValueStore("myProcessorState"),
+     *                 Serdes.String(),
+     *                 Serdes.String());
+     * // add store
+     * builder.addStateStore(keyValueStoreBuilder);
+     *
+     * inputStream.process(new ProcessorSupplier() {
+     *     public Processor get() {
+     *         return new MyProcessor();
+     *     }
+     * }, "myProcessorState");
+     * }</pre>
                + * The second strategy is for the given {@link ProcessorSupplier} to implement {@link ConnectedStoreProvider#stores()}, + * which provides the {@link StoreBuilder}s to be automatically added to the topology and connected to the processor. + *
+     * <pre>{@code
+     * class MyProcessorSupplier implements ProcessorSupplier {
+     *     // supply processor
+     *     Processor get() {
+     *         return new MyProcessor();
+     *     }
+     *
+     *     // provide store(s) that will be added and connected to the associated processor
+     *     // the store name from the builder ("myProcessorState") is used to access the store later via the ProcessorContext
+     *     Set<StoreBuilder<?>> stores() {
+     *         StoreBuilder<KeyValueStore<String, String>> keyValueStoreBuilder =
+     *                   Stores.keyValueStoreBuilder(Stores.persistentKeyValueStore("myProcessorState"),
+     *                   Serdes.String(),
+     *                   Serdes.String());
+     *         return Collections.singleton(keyValueStoreBuilder);
+     *     }
+     * }
+     *
+     * ...
+     *
+     * inputStream.process(new MyProcessorSupplier());
+     * }</pre>
                + *

                + * With either strategy, within the {@link Processor}, the state is obtained via the {@link ProcessorContext}. + * To trigger periodic actions via {@link org.apache.kafka.streams.processor.Punctuator#punctuate(long) punctuate()}, + * a schedule must be registered. + *

+     * <pre>{@code
                +     * class MyProcessor implements Processor {
                +     *     private StateStore state;
                +     *
                +     *     void init(ProcessorContext context) {
                +     *         this.state = context.getStateStore("myProcessorState");
                +     *         // punctuate each second, can access this.state
                +     *         context.schedule(Duration.ofSeconds(1), PunctuationType.WALL_CLOCK_TIME, new Punctuator(..));
                +     *     }
                +     *
                +     *     void process(K key, V value) {
                +     *         // can access this.state
                +     *     }
                +     *
                +     *     void close() {
                +     *         // can access this.state
                +     *     }
                +     * }
+     * }</pre>
+     * Even if any upstream operation was key-changing, no auto-repartition is triggered.
+     * If repartitioning is required, a call to {@link #repartition()} should be performed before {@code process()}.
+     *
+     * @param processorSupplier an instance of {@link ProcessorSupplier} that generates a newly constructed {@link Processor}
+     *                          The supplier should always generate a new instance. Creating a single {@link Processor} object
+     *                          and returning the same object reference in {@link ProcessorSupplier#get()} is a
+     *                          violation of the supplier pattern and leads to runtime exceptions.
+     * @param named             a {@link Named} config used to name the processor in the topology
+     * @param stateStoreNames   the names of the state store used by the processor
+     * @see #foreach(ForeachAction)
+     * @see #transform(TransformerSupplier, String...)
+     */
+    void process(final ProcessorSupplier<? super K, ? super V, Void, Void> processorSupplier,
+                 final Named named,
+                 final String... stateStoreNames);
+}
diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/KTable.java b/streams/src/main/java/org/apache/kafka/streams/kstream/KTable.java
new file mode 100644
index 0000000..d1c8623
--- /dev/null
+++ b/streams/src/main/java/org/apache/kafka/streams/kstream/KTable.java
@@ -0,0 +1,2369 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.streams.kstream;
+
+import org.apache.kafka.common.serialization.Serde;
+import org.apache.kafka.common.utils.Bytes;
+import org.apache.kafka.streams.KafkaStreams;
+import org.apache.kafka.streams.KeyValue;
+import org.apache.kafka.streams.StoreQueryParameters;
+import org.apache.kafka.streams.StreamsBuilder;
+import org.apache.kafka.streams.StreamsConfig;
+import org.apache.kafka.streams.Topology;
+import org.apache.kafka.streams.processor.ProcessorContext;
+import org.apache.kafka.streams.processor.StateStore;
+import org.apache.kafka.streams.processor.StreamPartitioner;
+import org.apache.kafka.streams.state.KeyValueBytesStoreSupplier;
+import org.apache.kafka.streams.state.KeyValueStore;
+import org.apache.kafka.streams.state.ReadOnlyKeyValueStore;
+
+import java.util.function.Function;
+
+/**
+ * {@code KTable} is an abstraction of a changelog stream from a primary-keyed table.
+ * Each record in this changelog stream is an update on the primary-keyed table with the record key as the primary key.
+ * <p>
                + * A {@code KTable} is either {@link StreamsBuilder#table(String) defined from a single Kafka topic} that is + * consumed message by message or the result of a {@code KTable} transformation. + * An aggregation of a {@link KStream} also yields a {@code KTable}. + *

                + * A {@code KTable} can be transformed record by record, joined with another {@code KTable} or {@link KStream}, or + * can be re-partitioned and aggregated into a new {@code KTable}. + *

                + * Some {@code KTable}s have an internal state (a {@link ReadOnlyKeyValueStore}) and are therefore queryable via the + * interactive queries API. + * For example: + *

+ * <pre>{@code
+ *     final KTable<K, V> table = ...
+ *     ...
+ *     final KafkaStreams streams = ...;
+ *     streams.start();
+ *     ...
+ *     final String queryableStoreName = table.queryableStoreName(); // returns null if KTable is not queryable
+ *     ReadOnlyKeyValueStore<K, ValueAndTimestamp<V>> view = streams.store(queryableStoreName, QueryableStoreTypes.timestampedKeyValueStore());
+ *     view.get(key);
+ * }</pre>
                + *

+ * Records from the source topic that have null keys are dropped.
+ *
+ * @param <K> Type of primary keys
+ * @param <V> Type of value changes
+ * @see KStream
+ * @see KGroupedTable
+ * @see GlobalKTable
+ * @see StreamsBuilder#table(String)
+ */
+public interface KTable<K, V> {
+
+    /**
+     * Create a new {@code KTable} that consists of all records of this {@code KTable} which satisfy the given
+     * predicate, with default serializers, deserializers, and state store.
+     * All records that do not satisfy the predicate are dropped.
+     * For each {@code KTable} update, the filter is evaluated based on the current update
+     * record and then an update record is produced for the result {@code KTable}.
+     * This is a stateless record-by-record operation.
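+     * <p>
+     * For example, a minimal sketch (the {@code String}/{@code Long} key and value types below are illustrative
+     * assumptions, not part of the operator contract):
+     * <pre>{@code
+     * KTable<String, Long> table = builder.table("topic");
+     * // keep only positive values; records that are dropped are turned into tombstones downstream
+     * KTable<String, Long> positives = table.filter((key, value) -> value > 0);
+     * }</pre>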

                + * Note that {@code filter} for a changelog stream works differently than {@link KStream#filter(Predicate) + * record stream filters}, because {@link KeyValue records} with {@code null} values (so-called tombstone records) + * have delete semantics. + * Thus, for tombstones the provided filter predicate is not evaluated but the tombstone record is forwarded + * directly if required (i.e., if there is anything to be deleted). + * Furthermore, for each record that gets dropped (i.e., does not satisfy the given predicate) a tombstone record + * is forwarded. + * + * @param predicate a filter {@link Predicate} that is applied to each record + * @return a {@code KTable} that contains only those records that satisfy the given predicate + * @see #filterNot(Predicate) + */ + KTable filter(final Predicate predicate); + + /** + * Create a new {@code KTable} that consists of all records of this {@code KTable} which satisfy the given + * predicate, with default serializers, deserializers, and state store. + * All records that do not satisfy the predicate are dropped. + * For each {@code KTable} update, the filter is evaluated based on the current update + * record and then an update record is produced for the result {@code KTable}. + * This is a stateless record-by-record operation. + *

                + * Note that {@code filter} for a changelog stream works differently than {@link KStream#filter(Predicate) + * record stream filters}, because {@link KeyValue records} with {@code null} values (so-called tombstone records) + * have delete semantics. + * Thus, for tombstones the provided filter predicate is not evaluated but the tombstone record is forwarded + * directly if required (i.e., if there is anything to be deleted). + * Furthermore, for each record that gets dropped (i.e., does not satisfy the given predicate) a tombstone record + * is forwarded. + * + * @param predicate a filter {@link Predicate} that is applied to each record + * @param named a {@link Named} config used to name the processor in the topology + * @return a {@code KTable} that contains only those records that satisfy the given predicate + * @see #filterNot(Predicate) + */ + KTable filter(final Predicate predicate, final Named named); + + /** + * Create a new {@code KTable} that consists of all records of this {@code KTable} which satisfy the given + * predicate, with the {@link Serde key serde}, {@link Serde value serde}, and the underlying + * {@link KeyValueStore materialized state storage} configured in the {@link Materialized} instance. + * All records that do not satisfy the predicate are dropped. + * For each {@code KTable} update, the filter is evaluated based on the current update + * record and then an update record is produced for the result {@code KTable}. + * This is a stateless record-by-record operation. + *

                + * Note that {@code filter} for a changelog stream works differently than {@link KStream#filter(Predicate) + * record stream filters}, because {@link KeyValue records} with {@code null} values (so-called tombstone records) + * have delete semantics. + * Thus, for tombstones the provided filter predicate is not evaluated but the tombstone record is forwarded + * directly if required (i.e., if there is anything to be deleted). + * Furthermore, for each record that gets dropped (i.e., does not satisfy the given predicate) a tombstone record + * is forwarded. + *

                + * To query the local {@link ReadOnlyKeyValueStore} it must be obtained via + * {@link KafkaStreams#store(StoreQueryParameters) KafkaStreams#store(...)}: + *

+     * <pre>{@code
+     * KafkaStreams streams = ... // filtering words
+     * ReadOnlyKeyValueStore<K, ValueAndTimestamp<V>> localStore = streams.store(queryableStoreName, QueryableStoreTypes.<K, ValueAndTimestamp<V>>timestampedKeyValueStore());
+     * K key = "some-word";
+     * ValueAndTimestamp<V> valueForKey = localStore.get(key); // key must be local (application state is shared over all running Kafka Streams instances)
+     * }</pre>
+     * For non-local keys, a custom RPC mechanism must be implemented using {@link KafkaStreams#metadataForAllStreamsClients()} to
+     * query the value of the key on a parallel running instance of your Kafka Streams application.
+     * The store name to query with is specified by {@link Materialized#as(String)} or {@link Materialized#as(KeyValueBytesStoreSupplier)}.
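+     * <p>
+     * One possible shape of the lookup side of such an RPC layer (a sketch only; it assumes the store's key
+     * serializer is at hand and uses {@link KafkaStreams#queryMetadataForKey(String, Object, org.apache.kafka.common.serialization.Serializer)}
+     * to find the instance that currently hosts the key):
+     * <pre>{@code
+     * KeyQueryMetadata keyMetadata = streams.queryMetadataForKey(queryableStoreName, key, keySerializer);
+     * HostInfo activeHost = keyMetadata.activeHost();
+     * // if activeHost is not this instance, forward the request to it via your own RPC mechanism,
+     * // e.g. an HTTP endpoint exposed by every instance of the application
+     * }</pre>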

                + * + * @param predicate a filter {@link Predicate} that is applied to each record + * @param materialized a {@link Materialized} that describes how the {@link StateStore} for the resulting {@code KTable} + * should be materialized. Cannot be {@code null} + * @return a {@code KTable} that contains only those records that satisfy the given predicate + * @see #filterNot(Predicate, Materialized) + */ + KTable filter(final Predicate predicate, + final Materialized> materialized); + + /** + * Create a new {@code KTable} that consists of all records of this {@code KTable} which satisfy the given + * predicate, with the {@link Serde key serde}, {@link Serde value serde}, and the underlying + * {@link KeyValueStore materialized state storage} configured in the {@link Materialized} instance. + * All records that do not satisfy the predicate are dropped. + * For each {@code KTable} update, the filter is evaluated based on the current update + * record and then an update record is produced for the result {@code KTable}. + * This is a stateless record-by-record operation. + *

                + * Note that {@code filter} for a changelog stream works differently than {@link KStream#filter(Predicate) + * record stream filters}, because {@link KeyValue records} with {@code null} values (so-called tombstone records) + * have delete semantics. + * Thus, for tombstones the provided filter predicate is not evaluated but the tombstone record is forwarded + * directly if required (i.e., if there is anything to be deleted). + * Furthermore, for each record that gets dropped (i.e., does not satisfy the given predicate) a tombstone record + * is forwarded. + *

                + * To query the local {@link ReadOnlyKeyValueStore} it must be obtained via + * {@link KafkaStreams#store(StoreQueryParameters) KafkaStreams#store(...)}: + *

+     * <pre>{@code
+     * KafkaStreams streams = ... // filtering words
+     * ReadOnlyKeyValueStore<K, ValueAndTimestamp<V>> localStore = streams.store(queryableStoreName, QueryableStoreTypes.<K, ValueAndTimestamp<V>>timestampedKeyValueStore());
+     * K key = "some-word";
+     * ValueAndTimestamp<V> valueForKey = localStore.get(key); // key must be local (application state is shared over all running Kafka Streams instances)
+     * }</pre>
                + * For non-local keys, a custom RPC mechanism must be implemented using {@link KafkaStreams#metadataForAllStreamsClients()} to + * query the value of the key on a parallel running instance of your Kafka Streams application. + * The store name to query with is specified by {@link Materialized#as(String)} or {@link Materialized#as(KeyValueBytesStoreSupplier)}. + *

                + * + * @param predicate a filter {@link Predicate} that is applied to each record + * @param named a {@link Named} config used to name the processor in the topology + * @param materialized a {@link Materialized} that describes how the {@link StateStore} for the resulting {@code KTable} + * should be materialized. Cannot be {@code null} + * @return a {@code KTable} that contains only those records that satisfy the given predicate + * @see #filterNot(Predicate, Materialized) + */ + KTable filter(final Predicate predicate, + final Named named, + final Materialized> materialized); + + /** + * Create a new {@code KTable} that consists all records of this {@code KTable} which do not satisfy the + * given predicate, with default serializers, deserializers, and state store. + * All records that do satisfy the predicate are dropped. + * For each {@code KTable} update, the filter is evaluated based on the current update + * record and then an update record is produced for the result {@code KTable}. + * This is a stateless record-by-record operation. + *

                + * Note that {@code filterNot} for a changelog stream works differently than {@link KStream#filterNot(Predicate) + * record stream filters}, because {@link KeyValue records} with {@code null} values (so-called tombstone records) + * have delete semantics. + * Thus, for tombstones the provided filter predicate is not evaluated but the tombstone record is forwarded + * directly if required (i.e., if there is anything to be deleted). + * Furthermore, for each record that gets dropped (i.e., does satisfy the given predicate) a tombstone record is + * forwarded. + * + * @param predicate a filter {@link Predicate} that is applied to each record + * @return a {@code KTable} that contains only those records that do not satisfy the given predicate + * @see #filter(Predicate) + */ + KTable filterNot(final Predicate predicate); + + /** + * Create a new {@code KTable} that consists all records of this {@code KTable} which do not satisfy the + * given predicate, with default serializers, deserializers, and state store. + * All records that do satisfy the predicate are dropped. + * For each {@code KTable} update, the filter is evaluated based on the current update + * record and then an update record is produced for the result {@code KTable}. + * This is a stateless record-by-record operation. + *

                + * Note that {@code filterNot} for a changelog stream works differently than {@link KStream#filterNot(Predicate) + * record stream filters}, because {@link KeyValue records} with {@code null} values (so-called tombstone records) + * have delete semantics. + * Thus, for tombstones the provided filter predicate is not evaluated but the tombstone record is forwarded + * directly if required (i.e., if there is anything to be deleted). + * Furthermore, for each record that gets dropped (i.e., does satisfy the given predicate) a tombstone record is + * forwarded. + * + * @param predicate a filter {@link Predicate} that is applied to each record + * @param named a {@link Named} config used to name the processor in the topology + * @return a {@code KTable} that contains only those records that do not satisfy the given predicate + * @see #filter(Predicate) + */ + KTable filterNot(final Predicate predicate, final Named named); + + /** + * Create a new {@code KTable} that consists all records of this {@code KTable} which do not satisfy the + * given predicate, with the {@link Serde key serde}, {@link Serde value serde}, and the underlying + * {@link KeyValueStore materialized state storage} configured in the {@link Materialized} instance. + * All records that do satisfy the predicate are dropped. + * For each {@code KTable} update, the filter is evaluated based on the current update + * record and then an update record is produced for the result {@code KTable}. + * This is a stateless record-by-record operation. + *

                + * Note that {@code filterNot} for a changelog stream works differently than {@link KStream#filterNot(Predicate) + * record stream filters}, because {@link KeyValue records} with {@code null} values (so-called tombstone records) + * have delete semantics. + * Thus, for tombstones the provided filter predicate is not evaluated but the tombstone record is forwarded + * directly if required (i.e., if there is anything to be deleted). + * Furthermore, for each record that gets dropped (i.e., does satisfy the given predicate) a tombstone record is + * forwarded. + *

                + * To query the local {@link ReadOnlyKeyValueStore} it must be obtained via + * {@link KafkaStreams#store(StoreQueryParameters) KafkaStreams#store(...)}: + *

+     * <pre>{@code
+     * KafkaStreams streams = ... // filtering words
+     * ReadOnlyKeyValueStore<K, ValueAndTimestamp<V>> localStore = streams.store(queryableStoreName, QueryableStoreTypes.<K, ValueAndTimestamp<V>>timestampedKeyValueStore());
+     * K key = "some-word";
+     * ValueAndTimestamp<V> valueForKey = localStore.get(key); // key must be local (application state is shared over all running Kafka Streams instances)
+     * }</pre>
                + * For non-local keys, a custom RPC mechanism must be implemented using {@link KafkaStreams#metadataForAllStreamsClients()} to + * query the value of the key on a parallel running instance of your Kafka Streams application. + * The store name to query with is specified by {@link Materialized#as(String)} or {@link Materialized#as(KeyValueBytesStoreSupplier)}. + *

                + * @param predicate a filter {@link Predicate} that is applied to each record + * @param materialized a {@link Materialized} that describes how the {@link StateStore} for the resulting {@code KTable} + * should be materialized. Cannot be {@code null} + * @return a {@code KTable} that contains only those records that do not satisfy the given predicate + * @see #filter(Predicate, Materialized) + */ + KTable filterNot(final Predicate predicate, + final Materialized> materialized); + + /** + * Create a new {@code KTable} that consists all records of this {@code KTable} which do not satisfy the + * given predicate, with the {@link Serde key serde}, {@link Serde value serde}, and the underlying + * {@link KeyValueStore materialized state storage} configured in the {@link Materialized} instance. + * All records that do satisfy the predicate are dropped. + * For each {@code KTable} update, the filter is evaluated based on the current update + * record and then an update record is produced for the result {@code KTable}. + * This is a stateless record-by-record operation. + *

                + * Note that {@code filterNot} for a changelog stream works differently than {@link KStream#filterNot(Predicate) + * record stream filters}, because {@link KeyValue records} with {@code null} values (so-called tombstone records) + * have delete semantics. + * Thus, for tombstones the provided filter predicate is not evaluated but the tombstone record is forwarded + * directly if required (i.e., if there is anything to be deleted). + * Furthermore, for each record that gets dropped (i.e., does satisfy the given predicate) a tombstone record is + * forwarded. + *

                + * To query the local {@link ReadOnlyKeyValueStore} it must be obtained via + * {@link KafkaStreams#store(StoreQueryParameters) KafkaStreams#store(...)}: + *

+     * <pre>{@code
+     * KafkaStreams streams = ... // filtering words
+     * ReadOnlyKeyValueStore<K, ValueAndTimestamp<V>> localStore = streams.store(queryableStoreName, QueryableStoreTypes.<K, ValueAndTimestamp<V>>timestampedKeyValueStore());
+     * K key = "some-word";
+     * ValueAndTimestamp<V> valueForKey = localStore.get(key); // key must be local (application state is shared over all running Kafka Streams instances)
+     * }</pre>
                + * For non-local keys, a custom RPC mechanism must be implemented using {@link KafkaStreams#metadataForAllStreamsClients()} to + * query the value of the key on a parallel running instance of your Kafka Streams application. + * The store name to query with is specified by {@link Materialized#as(String)} or {@link Materialized#as(KeyValueBytesStoreSupplier)}. + *

                + * @param predicate a filter {@link Predicate} that is applied to each record + * @param named a {@link Named} config used to name the processor in the topology + * @param materialized a {@link Materialized} that describes how the {@link StateStore} for the resulting {@code KTable} + * should be materialized. Cannot be {@code null} + * @return a {@code KTable} that contains only those records that do not satisfy the given predicate + * @see #filter(Predicate, Materialized) + */ + KTable filterNot(final Predicate predicate, + final Named named, + final Materialized> materialized); + + /** + * Create a new {@code KTable} by transforming the value of each record in this {@code KTable} into a new value + * (with possibly a new type) in the new {@code KTable}, with default serializers, deserializers, and state store. + * For each {@code KTable} update the provided {@link ValueMapper} is applied to the value of the updated record and + * computes a new value for it, resulting in an updated record for the result {@code KTable}. + * Thus, an input record {@code } can be transformed into an output record {@code }. + * This is a stateless record-by-record operation. + *

                + * The example below counts the number of token of the value string. + *

+     * <pre>{@code
+     * KTable<String, String> inputTable = builder.table("topic");
+     * KTable<String, Integer> outputTable = inputTable.mapValues(value -> value.split(" ").length);
+     * }</pre>
                + *

                + * This operation preserves data co-location with respect to the key. + * Thus, no internal data redistribution is required if a key based operator (like a join) is applied to + * the result {@code KTable}. + *

                + * Note that {@code mapValues} for a changelog stream works differently than {@link KStream#mapValues(ValueMapper) + * record stream filters}, because {@link KeyValue records} with {@code null} values (so-called tombstone records) + * have delete semantics. + * Thus, for tombstones the provided value-mapper is not evaluated but the tombstone record is forwarded directly to + * delete the corresponding record in the result {@code KTable}. + * + * @param mapper a {@link ValueMapper} that computes a new output value + * @param the value type of the result {@code KTable} + * @return a {@code KTable} that contains records with unmodified keys and new values (possibly of different type) + */ + KTable mapValues(final ValueMapper mapper); + + /** + * Create a new {@code KTable} by transforming the value of each record in this {@code KTable} into a new value + * (with possibly a new type) in the new {@code KTable}, with default serializers, deserializers, and state store. + * For each {@code KTable} update the provided {@link ValueMapper} is applied to the value of the updated record and + * computes a new value for it, resulting in an updated record for the result {@code KTable}. + * Thus, an input record {@code } can be transformed into an output record {@code }. + * This is a stateless record-by-record operation. + *

                + * The example below counts the number of token of the value string. + *

+     * <pre>{@code
+     * KTable<String, String> inputTable = builder.table("topic");
+     * KTable<String, Integer> outputTable = inputTable.mapValues(value -> value.split(" ").length, Named.as("countTokenValue"));
+     * }</pre>
                + *

                + * This operation preserves data co-location with respect to the key. + * Thus, no internal data redistribution is required if a key based operator (like a join) is applied to + * the result {@code KTable}. + *

                + * Note that {@code mapValues} for a changelog stream works differently than {@link KStream#mapValues(ValueMapper) + * record stream filters}, because {@link KeyValue records} with {@code null} values (so-called tombstone records) + * have delete semantics. + * Thus, for tombstones the provided value-mapper is not evaluated but the tombstone record is forwarded directly to + * delete the corresponding record in the result {@code KTable}. + * + * @param mapper a {@link ValueMapper} that computes a new output value + * @param named a {@link Named} config used to name the processor in the topology + * @param the value type of the result {@code KTable} + * @return a {@code KTable} that contains records with unmodified keys and new values (possibly of different type) + */ + KTable mapValues(final ValueMapper mapper, + final Named named); + + /** + * Create a new {@code KTable} by transforming the value of each record in this {@code KTable} into a new value + * (with possibly a new type) in the new {@code KTable}, with default serializers, deserializers, and state store. + * For each {@code KTable} update the provided {@link ValueMapperWithKey} is applied to the value of the update + * record and computes a new value for it, resulting in an updated record for the result {@code KTable}. + * Thus, an input record {@code } can be transformed into an output record {@code }. + * This is a stateless record-by-record operation. + *

                + * The example below counts the number of token of value and key strings. + *

+     * <pre>{@code
+     * KTable<String, String> inputTable = builder.table("topic");
+     * KTable<String, Integer> outputTable =
+     *  inputTable.mapValues((readOnlyKey, value) -> readOnlyKey.split(" ").length + value.split(" ").length);
+     * }</pre>
                + *

                + * Note that the key is read-only and should not be modified, as this can lead to corrupt partitioning. + * This operation preserves data co-location with respect to the key. + * Thus, no internal data redistribution is required if a key based operator (like a join) is applied to + * the result {@code KTable}. + *

                + * Note that {@code mapValues} for a changelog stream works differently than {@link KStream#mapValues(ValueMapperWithKey) + * record stream filters}, because {@link KeyValue records} with {@code null} values (so-called tombstone records) + * have delete semantics. + * Thus, for tombstones the provided value-mapper is not evaluated but the tombstone record is forwarded directly to + * delete the corresponding record in the result {@code KTable}. + * + * @param mapper a {@link ValueMapperWithKey} that computes a new output value + * @param the value type of the result {@code KTable} + * @return a {@code KTable} that contains records with unmodified keys and new values (possibly of different type) + */ + KTable mapValues(final ValueMapperWithKey mapper); + + /** + * Create a new {@code KTable} by transforming the value of each record in this {@code KTable} into a new value + * (with possibly a new type) in the new {@code KTable}, with default serializers, deserializers, and state store. + * For each {@code KTable} update the provided {@link ValueMapperWithKey} is applied to the value of the update + * record and computes a new value for it, resulting in an updated record for the result {@code KTable}. + * Thus, an input record {@code } can be transformed into an output record {@code }. + * This is a stateless record-by-record operation. + *

                + * The example below counts the number of token of value and key strings. + *

+     * <pre>{@code
+     * KTable<String, String> inputTable = builder.table("topic");
+     * KTable<String, Integer> outputTable =
+     *  inputTable.mapValues((readOnlyKey, value) -> readOnlyKey.split(" ").length + value.split(" ").length, Named.as("countTokenValueAndKey"));
+     * }</pre>
                + *

                + * Note that the key is read-only and should not be modified, as this can lead to corrupt partitioning. + * This operation preserves data co-location with respect to the key. + * Thus, no internal data redistribution is required if a key based operator (like a join) is applied to + * the result {@code KTable}. + *

                + * Note that {@code mapValues} for a changelog stream works differently than {@link KStream#mapValues(ValueMapperWithKey) + * record stream filters}, because {@link KeyValue records} with {@code null} values (so-called tombstone records) + * have delete semantics. + * Thus, for tombstones the provided value-mapper is not evaluated but the tombstone record is forwarded directly to + * delete the corresponding record in the result {@code KTable}. + * + * @param mapper a {@link ValueMapperWithKey} that computes a new output value + * @param named a {@link Named} config used to name the processor in the topology + * @param the value type of the result {@code KTable} + * @return a {@code KTable} that contains records with unmodified keys and new values (possibly of different type) + */ + KTable mapValues(final ValueMapperWithKey mapper, + final Named named); + + /** + * Create a new {@code KTable} by transforming the value of each record in this {@code KTable} into a new value + * (with possibly a new type) in the new {@code KTable}, with the {@link Serde key serde}, {@link Serde value serde}, + * and the underlying {@link KeyValueStore materialized state storage} configured in the {@link Materialized} + * instance. + * For each {@code KTable} update the provided {@link ValueMapper} is applied to the value of the updated record and + * computes a new value for it, resulting in an updated record for the result {@code KTable}. + * Thus, an input record {@code } can be transformed into an output record {@code }. + * This is a stateless record-by-record operation. + *

                + * The example below counts the number of token of the value string. + *

+     * <pre>{@code
+     * KTable<String, String> inputTable = builder.table("topic");
+     * KTable<String, Integer> outputTable = inputTable.mapValues(new ValueMapper<String, Integer>() {
+     *     public Integer apply(String value) {
+     *         return value.split(" ").length;
+     *     }
+     * });
+     * }</pre>
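+     * <p>
+     * To make the result queryable, the state store can be named and configured through the {@link Materialized}
+     * argument of this overload (a sketch; the store name {@code "token-count-store"} and the serdes are
+     * illustrative assumptions):
+     * <pre>{@code
+     * KTable<String, Integer> outputTable = inputTable.mapValues(
+     *     value -> value.split(" ").length,
+     *     Materialized.<String, Integer, KeyValueStore<Bytes, byte[]>>as("token-count-store")
+     *                 .withKeySerde(Serdes.String())
+     *                 .withValueSerde(Serdes.Integer()));
+     * }</pre>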
                + *

                + * To query the local {@link KeyValueStore} representing outputTable above it must be obtained via + * {@link KafkaStreams#store(StoreQueryParameters) KafkaStreams#store(...)}: + * For non-local keys, a custom RPC mechanism must be implemented using {@link KafkaStreams#metadataForAllStreamsClients()} to + * query the value of the key on a parallel running instance of your Kafka Streams application. + * The store name to query with is specified by {@link Materialized#as(String)} or {@link Materialized#as(KeyValueBytesStoreSupplier)}. + *

                + * This operation preserves data co-location with respect to the key. + * Thus, no internal data redistribution is required if a key based operator (like a join) is applied to + * the result {@code KTable}. + *

                + * Note that {@code mapValues} for a changelog stream works differently than {@link KStream#mapValues(ValueMapper) + * record stream filters}, because {@link KeyValue records} with {@code null} values (so-called tombstone records) + * have delete semantics. + * Thus, for tombstones the provided value-mapper is not evaluated but the tombstone record is forwarded directly to + * delete the corresponding record in the result {@code KTable}. + * + * @param mapper a {@link ValueMapper} that computes a new output value + * @param materialized a {@link Materialized} that describes how the {@link StateStore} for the resulting {@code KTable} + * should be materialized. Cannot be {@code null} + * @param the value type of the result {@code KTable} + * + * @return a {@code KTable} that contains records with unmodified keys and new values (possibly of different type) + */ + KTable mapValues(final ValueMapper mapper, + final Materialized> materialized); + + /** + * Create a new {@code KTable} by transforming the value of each record in this {@code KTable} into a new value + * (with possibly a new type) in the new {@code KTable}, with the {@link Serde key serde}, {@link Serde value serde}, + * and the underlying {@link KeyValueStore materialized state storage} configured in the {@link Materialized} + * instance. + * For each {@code KTable} update the provided {@link ValueMapper} is applied to the value of the updated record and + * computes a new value for it, resulting in an updated record for the result {@code KTable}. + * Thus, an input record {@code } can be transformed into an output record {@code }. + * This is a stateless record-by-record operation. + *

                + * The example below counts the number of token of the value string. + *

+     * <pre>{@code
+     * KTable<String, String> inputTable = builder.table("topic");
+     * KTable<String, Integer> outputTable = inputTable.mapValues(new ValueMapper<String, Integer>() {
+     *     public Integer apply(String value) {
+     *         return value.split(" ").length;
+     *     }
+     * });
+     * }</pre>
                + *

                + * To query the local {@link KeyValueStore} representing outputTable above it must be obtained via + * {@link KafkaStreams#store(StoreQueryParameters) KafkaStreams#store(...)}: + * For non-local keys, a custom RPC mechanism must be implemented using {@link KafkaStreams#metadataForAllStreamsClients()} to + * query the value of the key on a parallel running instance of your Kafka Streams application. + * The store name to query with is specified by {@link Materialized#as(String)} or {@link Materialized#as(KeyValueBytesStoreSupplier)}. + *

                + * This operation preserves data co-location with respect to the key. + * Thus, no internal data redistribution is required if a key based operator (like a join) is applied to + * the result {@code KTable}. + *

                + * Note that {@code mapValues} for a changelog stream works differently than {@link KStream#mapValues(ValueMapper) + * record stream filters}, because {@link KeyValue records} with {@code null} values (so-called tombstone records) + * have delete semantics. + * Thus, for tombstones the provided value-mapper is not evaluated but the tombstone record is forwarded directly to + * delete the corresponding record in the result {@code KTable}. + * + * @param mapper a {@link ValueMapper} that computes a new output value + * @param named a {@link Named} config used to name the processor in the topology + * @param materialized a {@link Materialized} that describes how the {@link StateStore} for the resulting {@code KTable} + * should be materialized. Cannot be {@code null} + * @param the value type of the result {@code KTable} + * + * @return a {@code KTable} that contains records with unmodified keys and new values (possibly of different type) + */ + KTable mapValues(final ValueMapper mapper, + final Named named, + final Materialized> materialized); + + /** + * Create a new {@code KTable} by transforming the value of each record in this {@code KTable} into a new value + * (with possibly a new type) in the new {@code KTable}, with the {@link Serde key serde}, {@link Serde value serde}, + * and the underlying {@link KeyValueStore materialized state storage} configured in the {@link Materialized} + * instance. + * For each {@code KTable} update the provided {@link ValueMapperWithKey} is applied to the value of the update + * record and computes a new value for it, resulting in an updated record for the result {@code KTable}. + * Thus, an input record {@code } can be transformed into an output record {@code }. + * This is a stateless record-by-record operation. + *

                + * The example below counts the number of token of value and key strings. + *

+     * <pre>{@code
+     * KTable<String, String> inputTable = builder.table("topic");
+     * KTable<String, Integer> outputTable = inputTable.mapValues(new ValueMapperWithKey<String, String, Integer>() {
+     *     public Integer apply(String readOnlyKey, String value) {
+     *          return readOnlyKey.split(" ").length + value.split(" ").length;
+     *     }
+     * });
+     * }</pre>
                + *

+     * To query the local {@link KeyValueStore} representing outputTable above it must be obtained via
+     * {@link KafkaStreams#store(StoreQueryParameters) KafkaStreams#store(...)}:
+     * For non-local keys, a custom RPC mechanism must be implemented using {@link KafkaStreams#metadataForAllStreamsClients()} to
+     * query the value of the key on a parallel running instance of your Kafka Streams application.
+     * The store name to query with is specified by {@link Materialized#as(String)} or {@link Materialized#as(KeyValueBytesStoreSupplier)}.
+     * <p>
                + * Note that the key is read-only and should not be modified, as this can lead to corrupt partitioning. + * This operation preserves data co-location with respect to the key. + * Thus, no internal data redistribution is required if a key based operator (like a join) is applied to + * the result {@code KTable}. + *

                + * Note that {@code mapValues} for a changelog stream works differently than {@link KStream#mapValues(ValueMapper) + * record stream filters}, because {@link KeyValue records} with {@code null} values (so-called tombstone records) + * have delete semantics. + * Thus, for tombstones the provided value-mapper is not evaluated but the tombstone record is forwarded directly to + * delete the corresponding record in the result {@code KTable}. + * + * @param mapper a {@link ValueMapperWithKey} that computes a new output value + * @param materialized a {@link Materialized} that describes how the {@link StateStore} for the resulting {@code KTable} + * should be materialized. Cannot be {@code null} + * @param the value type of the result {@code KTable} + * + * @return a {@code KTable} that contains records with unmodified keys and new values (possibly of different type) + */ + KTable mapValues(final ValueMapperWithKey mapper, + final Materialized> materialized); + + /** + * Create a new {@code KTable} by transforming the value of each record in this {@code KTable} into a new value + * (with possibly a new type) in the new {@code KTable}, with the {@link Serde key serde}, {@link Serde value serde}, + * and the underlying {@link KeyValueStore materialized state storage} configured in the {@link Materialized} + * instance. + * For each {@code KTable} update the provided {@link ValueMapperWithKey} is applied to the value of the update + * record and computes a new value for it, resulting in an updated record for the result {@code KTable}. + * Thus, an input record {@code } can be transformed into an output record {@code }. + * This is a stateless record-by-record operation. + *

                + * The example below counts the number of token of value and key strings. + *

+     * <pre>{@code
+     * KTable<String, String> inputTable = builder.table("topic");
+     * KTable<String, Integer> outputTable = inputTable.mapValues(new ValueMapperWithKey<String, String, Integer>() {
+     *     public Integer apply(String readOnlyKey, String value) {
+     *          return readOnlyKey.split(" ").length + value.split(" ").length;
+     *     }
+     * });
+     * }</pre>
                + *

                + * To query the local {@link KeyValueStore} representing outputTable above it must be obtained via + * {@link KafkaStreams#store(StoreQueryParameters) KafkaStreams#store(...)}: + * For non-local keys, a custom RPC mechanism must be implemented using {@link KafkaStreams#metadataForAllStreamsClients()} to + * query the value of the key on a parallel running instance of your Kafka Streams application. + * The store name to query with is specified by {@link Materialized#as(String)} or {@link Materialized#as(KeyValueBytesStoreSupplier)}. + *

                + * Note that the key is read-only and should not be modified, as this can lead to corrupt partitioning. + * This operation preserves data co-location with respect to the key. + * Thus, no internal data redistribution is required if a key based operator (like a join) is applied to + * the result {@code KTable}. + *

                + * Note that {@code mapValues} for a changelog stream works differently than {@link KStream#mapValues(ValueMapper) + * record stream filters}, because {@link KeyValue records} with {@code null} values (so-called tombstone records) + * have delete semantics. + * Thus, for tombstones the provided value-mapper is not evaluated but the tombstone record is forwarded directly to + * delete the corresponding record in the result {@code KTable}. + * + * @param mapper a {@link ValueMapperWithKey} that computes a new output value + * @param named a {@link Named} config used to name the processor in the topology + * @param materialized a {@link Materialized} that describes how the {@link StateStore} for the resulting {@code KTable} + * should be materialized. Cannot be {@code null} + * @param the value type of the result {@code KTable} + * + * @return a {@code KTable} that contains records with unmodified keys and new values (possibly of different type) + */ + KTable mapValues(final ValueMapperWithKey mapper, + final Named named, + final Materialized> materialized); + + /** + * Convert this changelog stream to a {@link KStream}. + *

                + * Note that this is a logical operation and only changes the "interpretation" of the stream, i.e., each record of + * this changelog stream is no longer treated as an updated record (cf. {@link KStream} vs {@code KTable}). + * + * @return a {@link KStream} that contains the same records as this {@code KTable} + */ + KStream toStream(); + + /** + * Convert this changelog stream to a {@link KStream}. + *

                + * Note that this is a logical operation and only changes the "interpretation" of the stream, i.e., each record of + * this changelog stream is no longer treated as an updated record (cf. {@link KStream} vs {@code KTable}). + * + * @param named a {@link Named} config used to name the processor in the topology + * + * @return a {@link KStream} that contains the same records as this {@code KTable} + */ + KStream toStream(final Named named); + + /** + * Convert this changelog stream to a {@link KStream} using the given {@link KeyValueMapper} to select the new key. + *

                + * For example, you can compute the new key as the length of the value string. + *

                {@code
+     * KTable<String, String> table = builder.table("topic");
+     * KStream<Integer, String> keyedStream = table.toStream(new KeyValueMapper<String, String, Integer>() {
+     *     public Integer apply(String key, String value) {
+     *         return value.length();
+     *     }
+     * });
                +     * }
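Because KeyValueMapper is a functional interface, the same example can be written more compactly as a lambda; this sketch is equivalent to the anonymous class above:

    KStream<Integer, String> keyedStream = table.toStream((key, value) -> value.length());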
                + * Setting a new key might result in an internal data redistribution if a key based operator (like an aggregation or + * join) is applied to the result {@link KStream}. + *

                + * This operation is equivalent to calling + * {@code table.}{@link #toStream() toStream}{@code ().}{@link KStream#selectKey(KeyValueMapper) selectKey(KeyValueMapper)}. + *

                + * Note that {@link #toStream()} is a logical operation and only changes the "interpretation" of the stream, i.e., + * each record of this changelog stream is no longer treated as an updated record (cf. {@link KStream} vs {@code KTable}). + * + * @param mapper a {@link KeyValueMapper} that computes a new key for each record + * @param the new key type of the result stream + * @return a {@link KStream} that contains the same records as this {@code KTable} + */ + KStream toStream(final KeyValueMapper mapper); + + /** + * Convert this changelog stream to a {@link KStream} using the given {@link KeyValueMapper} to select the new key. + *

                + * For example, you can compute the new key as the length of the value string. + *

                {@code
+     * KTable<String, String> table = builder.table("topic");
+     * KStream<Integer, String> keyedStream = table.toStream(new KeyValueMapper<String, String, Integer>() {
+     *     public Integer apply(String key, String value) {
+     *         return value.length();
+     *     }
+     * });
                +     * }
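A lambda sketch of this overload; the processor name passed to Named.as(...) is illustrative:

    KStream<Integer, String> keyedStream =
        table.toStream((key, value) -> value.length(), Named.as("to-stream-by-length"));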
                + * Setting a new key might result in an internal data redistribution if a key based operator (like an aggregation or + * join) is applied to the result {@link KStream}. + *

                + * This operation is equivalent to calling + * {@code table.}{@link #toStream() toStream}{@code ().}{@link KStream#selectKey(KeyValueMapper) selectKey(KeyValueMapper)}. + *

                + * Note that {@link #toStream()} is a logical operation and only changes the "interpretation" of the stream, i.e., + * each record of this changelog stream is no longer treated as an updated record (cf. {@link KStream} vs {@code KTable}). + * + * @param mapper a {@link KeyValueMapper} that computes a new key for each record + * @param named a {@link Named} config used to name the processor in the topology + * @param the new key type of the result stream + * @return a {@link KStream} that contains the same records as this {@code KTable} + */ + KStream toStream(final KeyValueMapper mapper, + final Named named); + + /** + * Suppress some updates from this changelog stream, determined by the supplied {@link Suppressed} configuration. + * + * This controls what updates downstream table and stream operations will receive. + * + * @param suppressed Configuration object determining what, if any, updates to suppress + * @return A new KTable with the desired suppression characteristics. + */ + KTable suppress(final Suppressed suppressed); + + /** + * Create a new {@code KTable} by transforming the value of each record in this {@code KTable} into a new value + * (with possibly a new type), with default serializers, deserializers, and state store. + * A {@link ValueTransformerWithKey} (provided by the given {@link ValueTransformerWithKeySupplier}) is applied to each input + * record value and computes a new value for it. + * Thus, an input record {@code } can be transformed into an output record {@code }. + * This is similar to {@link #mapValues(ValueMapperWithKey)}, but more flexible, allowing access to additional state-stores, + * and access to the {@link ProcessorContext}. + * Furthermore, via {@link org.apache.kafka.streams.processor.Punctuator#punctuate(long)} the processing progress can be observed and additional + * periodic actions can be performed. + *

                + * If the downstream topology uses aggregation functions, (e.g. {@link KGroupedTable#reduce}, {@link KGroupedTable#aggregate}, etc), + * care must be taken when dealing with state, (either held in state-stores or transformer instances), to ensure correct aggregate results. + * In contrast, if the resulting KTable is materialized, (cf. {@link #transformValues(ValueTransformerWithKeySupplier, Materialized, String...)}), + * such concerns are handled for you. + *

                + * In order to assign a state, the state must be created and registered beforehand: + *

                {@code
                +     * // create store
+     * StoreBuilder<KeyValueStore<String, String>> keyValueStoreBuilder =
                +     *         Stores.keyValueStoreBuilder(Stores.persistentKeyValueStore("myValueTransformState"),
                +     *                 Serdes.String(),
                +     *                 Serdes.String());
                +     * // register store
                +     * builder.addStateStore(keyValueStoreBuilder);
                +     *
+     * KTable<String, String> outputTable = inputTable.transformValues(new ValueTransformerWithKeySupplier() { ... }, "myValueTransformState");
                +     * }
                + *

                + * Within the {@link ValueTransformerWithKey}, the state is obtained via the + * {@link ProcessorContext}. + * To trigger periodic actions via {@link org.apache.kafka.streams.processor.Punctuator#punctuate(long) punctuate()}, + * a schedule must be registered. + *

                {@code
                +     * new ValueTransformerWithKeySupplier() {
                +     *     ValueTransformerWithKey get() {
                +     *         return new ValueTransformerWithKey() {
+     *             private KeyValueStore<String, String> state;
                +     *
                +     *             void init(ProcessorContext context) {
+     *                 this.state = (KeyValueStore<String, String>)context.getStateStore("myValueTransformState");
                +     *                 context.schedule(Duration.ofSeconds(1), PunctuationType.WALL_CLOCK_TIME, new Punctuator(..)); // punctuate each 1000ms, can access this.state
                +     *             }
                +     *
                +     *             NewValueType transform(K readOnlyKey, V value) {
                +     *                 // can access this.state and use read-only key
                +     *                 return new NewValueType(readOnlyKey); // or null
                +     *             }
                +     *
                +     *             void close() {
                +     *                 // can access this.state
                +     *             }
                +     *         }
                +     *     }
                +     * }
                +     * }
                + *
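The following sketch ties the registration and transformer pieces together; the store name matches the example above, while the supplier lambda and the transform logic (returning the previously stored value) are illustrative assumptions:

    KTable<String, String> outputTable = inputTable.transformValues(
        () -> new ValueTransformerWithKey<String, String, String>() {
            private KeyValueStore<String, String> state;

            @Override
            public void init(final ProcessorContext context) {
                // the store registered via builder.addStateStore(...) above
                state = (KeyValueStore<String, String>) context.getStateStore("myValueTransformState");
            }

            @Override
            public String transform(final String readOnlyKey, final String value) {
                final String previous = state.get(readOnlyKey); // read per-key state
                state.put(readOnlyKey, value);                  // remember the latest value
                return previous != null ? previous : value;     // illustrative output
            }

            @Override
            public void close() { }
        },
        "myValueTransformState");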

                + * Note that the key is read-only and should not be modified, as this can lead to corrupt partitioning. + * Setting a new value preserves data co-location with respect to the key. + * + * @param transformerSupplier a instance of {@link ValueTransformerWithKeySupplier} that generates a + * {@link ValueTransformerWithKey}. + * At least one transformer instance will be created per streaming task. + * Transformers do not need to be thread-safe. + * @param stateStoreNames the names of the state stores used by the processor + * @param the value type of the result table + * @return a {@code KTable} that contains records with unmodified key and new values (possibly of different type) + * @see #mapValues(ValueMapper) + * @see #mapValues(ValueMapperWithKey) + */ + KTable transformValues(final ValueTransformerWithKeySupplier transformerSupplier, + final String... stateStoreNames); + + /** + * Create a new {@code KTable} by transforming the value of each record in this {@code KTable} into a new value + * (with possibly a new type), with default serializers, deserializers, and state store. + * A {@link ValueTransformerWithKey} (provided by the given {@link ValueTransformerWithKeySupplier}) is applied to each input + * record value and computes a new value for it. + * Thus, an input record {@code } can be transformed into an output record {@code }. + * This is similar to {@link #mapValues(ValueMapperWithKey)}, but more flexible, allowing access to additional state-stores, + * and access to the {@link ProcessorContext}. + * Furthermore, via {@link org.apache.kafka.streams.processor.Punctuator#punctuate(long)} the processing progress can be observed and additional + * periodic actions can be performed. + *

                + * If the downstream topology uses aggregation functions, (e.g. {@link KGroupedTable#reduce}, {@link KGroupedTable#aggregate}, etc), + * care must be taken when dealing with state, (either held in state-stores or transformer instances), to ensure correct aggregate results. + * In contrast, if the resulting KTable is materialized, (cf. {@link #transformValues(ValueTransformerWithKeySupplier, Materialized, String...)}), + * such concerns are handled for you. + *

                + * In order to assign a state, the state must be created and registered beforehand: + *

                {@code
                +     * // create store
+     * StoreBuilder<KeyValueStore<String, String>> keyValueStoreBuilder =
                +     *         Stores.keyValueStoreBuilder(Stores.persistentKeyValueStore("myValueTransformState"),
                +     *                 Serdes.String(),
                +     *                 Serdes.String());
                +     * // register store
                +     * builder.addStateStore(keyValueStoreBuilder);
                +     *
+     * KTable<String, String> outputTable = inputTable.transformValues(new ValueTransformerWithKeySupplier() { ... }, "myValueTransformState");
                +     * }
                + *

                + * Within the {@link ValueTransformerWithKey}, the state is obtained via the + * {@link ProcessorContext}. + * To trigger periodic actions via {@link org.apache.kafka.streams.processor.Punctuator#punctuate(long) punctuate()}, + * a schedule must be registered. + *

                {@code
                +     * new ValueTransformerWithKeySupplier() {
                +     *     ValueTransformerWithKey get() {
                +     *         return new ValueTransformerWithKey() {
+     *             private KeyValueStore<String, String> state;
                +     *
                +     *             void init(ProcessorContext context) {
+     *                 this.state = (KeyValueStore<String, String>)context.getStateStore("myValueTransformState");
                +     *                 context.schedule(Duration.ofSeconds(1), PunctuationType.WALL_CLOCK_TIME, new Punctuator(..)); // punctuate each 1000ms, can access this.state
                +     *             }
                +     *
                +     *             NewValueType transform(K readOnlyKey, V value) {
                +     *                 // can access this.state and use read-only key
                +     *                 return new NewValueType(readOnlyKey); // or null
                +     *             }
                +     *
                +     *             void close() {
                +     *                 // can access this.state
                +     *             }
                +     *         }
                +     *     }
                +     * }
                +     * }
                + *

                + * Note that the key is read-only and should not be modified, as this can lead to corrupt partitioning. + * Setting a new value preserves data co-location with respect to the key. + * + * @param transformerSupplier a instance of {@link ValueTransformerWithKeySupplier} that generates a + * {@link ValueTransformerWithKey}. + * At least one transformer instance will be created per streaming task. + * Transformers do not need to be thread-safe. + * @param named a {@link Named} config used to name the processor in the topology + * @param stateStoreNames the names of the state stores used by the processor + * @param the value type of the result table + * @return a {@code KTable} that contains records with unmodified key and new values (possibly of different type) + * @see #mapValues(ValueMapper) + * @see #mapValues(ValueMapperWithKey) + */ + KTable transformValues(final ValueTransformerWithKeySupplier transformerSupplier, + final Named named, + final String... stateStoreNames); + + /** + * Create a new {@code KTable} by transforming the value of each record in this {@code KTable} into a new value + * (with possibly a new type), with the {@link Serde key serde}, {@link Serde value serde}, and the underlying + * {@link KeyValueStore materialized state storage} configured in the {@link Materialized} instance. + * A {@link ValueTransformerWithKey} (provided by the given {@link ValueTransformerWithKeySupplier}) is applied to each input + * record value and computes a new value for it. + * This is similar to {@link #mapValues(ValueMapperWithKey)}, but more flexible, allowing stateful, rather than stateless, + * record-by-record operation, access to additional state-stores, and access to the {@link ProcessorContext}. + * Furthermore, via {@link org.apache.kafka.streams.processor.Punctuator#punctuate(long)} the processing progress can be observed and additional + * periodic actions can be performed. + * The resulting {@code KTable} is materialized into another state store (additional to the provided state store names) + * as specified by the user via {@link Materialized} parameter, and is queryable through its given name. + *

                + * In order to assign a state, the state must be created and registered beforehand: + *

                {@code
                +     * // create store
+     * StoreBuilder<KeyValueStore<String, String>> keyValueStoreBuilder =
                +     *         Stores.keyValueStoreBuilder(Stores.persistentKeyValueStore("myValueTransformState"),
                +     *                 Serdes.String(),
                +     *                 Serdes.String());
                +     * // register store
                +     * builder.addStateStore(keyValueStoreBuilder);
                +     *
+     * KTable<String, String> outputTable = inputTable.transformValues(
+     *     new ValueTransformerWithKeySupplier() { ... },
+     *     Materialized.<String, String, KeyValueStore<Bytes, byte[]>>as("outputTable")
                +     *                                 .withKeySerde(Serdes.String())
                +     *                                 .withValueSerde(Serdes.String()),
                +     *     "myValueTransformState");
                +     * }
                + *

                + * Within the {@link ValueTransformerWithKey}, the state is obtained via the + * {@link ProcessorContext}. + * To trigger periodic actions via {@link org.apache.kafka.streams.processor.Punctuator#punctuate(long) punctuate()}, + * a schedule must be registered. + *

                {@code
                +     * new ValueTransformerWithKeySupplier() {
                +     *     ValueTransformerWithKey get() {
                +     *         return new ValueTransformerWithKey() {
+     *             private KeyValueStore<String, String> state;
                +     *
                +     *             void init(ProcessorContext context) {
+     *                 this.state = (KeyValueStore<String, String>)context.getStateStore("myValueTransformState");
                +     *                 context.schedule(Duration.ofSeconds(1), PunctuationType.WALL_CLOCK_TIME, new Punctuator(..)); // punctuate each 1000ms, can access this.state
                +     *             }
                +     *
                +     *             NewValueType transform(K readOnlyKey, V value) {
                +     *                 // can access this.state and use read-only key
                +     *                 return new NewValueType(readOnlyKey); // or null
                +     *             }
                +     *
                +     *             void close() {
                +     *                 // can access this.state
                +     *             }
                +     *         }
                +     *     }
                +     * }
                +     * }
                + *
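As a side note, the new Punctuator(..) placeholder used above can also be supplied as a lambda, since Punctuator declares a single punctuate(long) method; the logged action below is illustrative and assumes the context and state fields from the example:

    context.schedule(Duration.ofSeconds(1), PunctuationType.WALL_CLOCK_TIME,
        timestamp -> System.out.println("entries at " + timestamp + ": " + state.approximateNumEntries()));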

                + * Note that the key is read-only and should not be modified, as this can lead to corrupt partitioning. + * Setting a new value preserves data co-location with respect to the key. + * + * @param transformerSupplier a instance of {@link ValueTransformerWithKeySupplier} that generates a + * {@link ValueTransformerWithKey}. + * At least one transformer instance will be created per streaming task. + * Transformers do not need to be thread-safe. + * @param materialized an instance of {@link Materialized} used to describe how the state store of the + * resulting table should be materialized. + * Cannot be {@code null} + * @param stateStoreNames the names of the state stores used by the processor + * @param the value type of the result table + * @return a {@code KTable} that contains records with unmodified key and new values (possibly of different type) + * @see #mapValues(ValueMapper) + * @see #mapValues(ValueMapperWithKey) + */ + KTable transformValues(final ValueTransformerWithKeySupplier transformerSupplier, + final Materialized> materialized, + final String... stateStoreNames); + + /** + * Create a new {@code KTable} by transforming the value of each record in this {@code KTable} into a new value + * (with possibly a new type), with the {@link Serde key serde}, {@link Serde value serde}, and the underlying + * {@link KeyValueStore materialized state storage} configured in the {@link Materialized} instance. + * A {@link ValueTransformerWithKey} (provided by the given {@link ValueTransformerWithKeySupplier}) is applied to each input + * record value and computes a new value for it. + * This is similar to {@link #mapValues(ValueMapperWithKey)}, but more flexible, allowing stateful, rather than stateless, + * record-by-record operation, access to additional state-stores, and access to the {@link ProcessorContext}. + * Furthermore, via {@link org.apache.kafka.streams.processor.Punctuator#punctuate(long)} the processing progress can be observed and additional + * periodic actions can be performed. + * The resulting {@code KTable} is materialized into another state store (additional to the provided state store names) + * as specified by the user via {@link Materialized} parameter, and is queryable through its given name. + *

                + * In order to assign a state, the state must be created and registered beforehand: + *

                {@code
                +     * // create store
+     * StoreBuilder<KeyValueStore<String, String>> keyValueStoreBuilder =
                +     *         Stores.keyValueStoreBuilder(Stores.persistentKeyValueStore("myValueTransformState"),
                +     *                 Serdes.String(),
                +     *                 Serdes.String());
                +     * // register store
                +     * builder.addStateStore(keyValueStoreBuilder);
                +     *
+     * KTable<String, String> outputTable = inputTable.transformValues(
+     *     new ValueTransformerWithKeySupplier() { ... },
+     *     Materialized.<String, String, KeyValueStore<Bytes, byte[]>>as("outputTable")
                +     *                                 .withKeySerde(Serdes.String())
                +     *                                 .withValueSerde(Serdes.String()),
                +     *     "myValueTransformState");
                +     * }
                + *

                + * Within the {@link ValueTransformerWithKey}, the state is obtained via the + * {@link ProcessorContext}. + * To trigger periodic actions via {@link org.apache.kafka.streams.processor.Punctuator#punctuate(long) punctuate()}, + * a schedule must be registered. + *

                {@code
                +     * new ValueTransformerWithKeySupplier() {
                +     *     ValueTransformerWithKey get() {
                +     *         return new ValueTransformerWithKey() {
+     *             private KeyValueStore<String, String> state;
                +     *
                +     *             void init(ProcessorContext context) {
+     *                 this.state = (KeyValueStore<String, String>)context.getStateStore("myValueTransformState");
                +     *                 context.schedule(Duration.ofSeconds(1), PunctuationType.WALL_CLOCK_TIME, new Punctuator(..)); // punctuate each 1000ms, can access this.state
                +     *             }
                +     *
                +     *             NewValueType transform(K readOnlyKey, V value) {
                +     *                 // can access this.state and use read-only key
                +     *                 return new NewValueType(readOnlyKey); // or null
                +     *             }
                +     *
                +     *             void close() {
                +     *                 // can access this.state
                +     *             }
                +     *         }
                +     *     }
                +     * }
                +     * }
                + *

                + * Note that the key is read-only and should not be modified, as this can lead to corrupt partitioning. + * Setting a new value preserves data co-location with respect to the key. + * + * @param transformerSupplier a instance of {@link ValueTransformerWithKeySupplier} that generates a + * {@link ValueTransformerWithKey}. + * At least one transformer instance will be created per streaming task. + * Transformers do not need to be thread-safe. + * @param materialized an instance of {@link Materialized} used to describe how the state store of the + * resulting table should be materialized. + * Cannot be {@code null} + * @param named a {@link Named} config used to name the processor in the topology + * @param stateStoreNames the names of the state stores used by the processor + * @param the value type of the result table + * @return a {@code KTable} that contains records with unmodified key and new values (possibly of different type) + * @see #mapValues(ValueMapper) + * @see #mapValues(ValueMapperWithKey) + */ + KTable transformValues(final ValueTransformerWithKeySupplier transformerSupplier, + final Materialized> materialized, + final Named named, + final String... stateStoreNames); + + /** + * Re-groups the records of this {@code KTable} using the provided {@link KeyValueMapper} and default serializers + * and deserializers. + * Each {@link KeyValue} pair of this {@code KTable} is mapped to a new {@link KeyValue} pair by applying the + * provided {@link KeyValueMapper}. + * Re-grouping a {@code KTable} is required before an aggregation operator can be applied to the data + * (cf. {@link KGroupedTable}). + * The {@link KeyValueMapper} selects a new key and value (with should both have unmodified type). + * If the new record key is {@code null} the record will not be included in the resulting {@link KGroupedTable} + *

                + * Because a new key is selected, an internal repartitioning topic will be created in Kafka. + * This topic will be named "${applicationId}-<name>-repartition", where "applicationId" is user-specified in + * {@link StreamsConfig} via parameter {@link StreamsConfig#APPLICATION_ID_CONFIG APPLICATION_ID_CONFIG}, "<name>" is + * an internally generated name, and "-repartition" is a fixed suffix. + * + * You can retrieve all generated internal topic names via {@link Topology#describe()}. + * + *

                + * All data of this {@code KTable} will be redistributed through the repartitioning topic by writing all update + * records to and rereading all updated records from it, such that the resulting {@link KGroupedTable} is partitioned + * on the new key. + *
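An illustrative sketch (topic name and grouping logic assumed) that re-groups by the record value and counts how many keys currently map to each value:

    KTable<String, String> inputTable = builder.table("topic");
    KTable<String, Long> countsPerValue = inputTable
        .groupBy((key, value) -> KeyValue.pair(value, value)) // new key and value keep their types
        .count();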

                + * If the key or value type is changed, it is recommended to use {@link #groupBy(KeyValueMapper, Grouped)} + * instead. + * + * @param selector a {@link KeyValueMapper} that computes a new grouping key and value to be aggregated + * @param the key type of the result {@link KGroupedTable} + * @param the value type of the result {@link KGroupedTable} + * @return a {@link KGroupedTable} that contains the re-grouped records of the original {@code KTable} + */ + KGroupedTable groupBy(final KeyValueMapper> selector); + + /** + * Re-groups the records of this {@code KTable} using the provided {@link KeyValueMapper} + * and {@link Serde}s as specified by {@link Grouped}. + * Each {@link KeyValue} pair of this {@code KTable} is mapped to a new {@link KeyValue} pair by applying the + * provided {@link KeyValueMapper}. + * Re-grouping a {@code KTable} is required before an aggregation operator can be applied to the data + * (cf. {@link KGroupedTable}). + * The {@link KeyValueMapper} selects a new key and value (where both could the same type or a new type). + * If the new record key is {@code null} the record will not be included in the resulting {@link KGroupedTable} + *

                + * Because a new key is selected, an internal repartitioning topic will be created in Kafka. + * This topic will be named "${applicationId}-<name>-repartition", where "applicationId" is user-specified in + * {@link StreamsConfig} via parameter {@link StreamsConfig#APPLICATION_ID_CONFIG APPLICATION_ID_CONFIG}, "<name>" is + * either provided via {@link org.apache.kafka.streams.kstream.Grouped#as(String)} or an internally generated name. + * + *

                + * You can retrieve all generated internal topic names via {@link Topology#describe()}. + * + *
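A sketch of this overload with explicit serdes; the name passed to Grouped.with(...), which becomes part of the repartition topic name, is illustrative:

    KGroupedTable<String, String> regrouped = inputTable.groupBy(
        (key, value) -> KeyValue.pair(value, key),
        Grouped.with("group-by-value", Serdes.String(), Serdes.String()));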

                + * All data of this {@code KTable} will be redistributed through the repartitioning topic by writing all update + * records to and rereading all updated records from it, such that the resulting {@link KGroupedTable} is partitioned + * on the new key. + * + * @param selector a {@link KeyValueMapper} that computes a new grouping key and value to be aggregated + * @param grouped the {@link Grouped} instance used to specify {@link org.apache.kafka.common.serialization.Serdes} + * and the name for a repartition topic if repartitioning is required. + * @param the key type of the result {@link KGroupedTable} + * @param the value type of the result {@link KGroupedTable} + * @return a {@link KGroupedTable} that contains the re-grouped records of the original {@code KTable} + */ + KGroupedTable groupBy(final KeyValueMapper> selector, + final Grouped grouped); + + /** + * Join records of this {@code KTable} with another {@code KTable}'s records using non-windowed inner equi join, + * with default serializers, deserializers, and state store. + * The join is a primary key join with join attribute {@code thisKTable.key == otherKTable.key}. + * The result is an ever updating {@code KTable} that represents the current (i.e., processing time) result + * of the join. + *

                + * The join is computed by (1) updating the internal state of one {@code KTable} and (2) performing a lookup for a + * matching record in the current (i.e., processing time) internal state of the other {@code KTable}. + * This happens in a symmetric way, i.e., for each update of either {@code this} or the {@code other} input + * {@code KTable} the result gets updated. + *

                + * For each {@code KTable} record that finds a corresponding record in the other {@code KTable} the provided + * {@link ValueJoiner} will be called to compute a value (with arbitrary type) for the result record. + * The key of the result record is the same as for both joining input records. + *

                + * Note that {@link KeyValue records} with {@code null} values (so-called tombstone records) have delete semantics. + * Thus, for input tombstones the provided value-joiner is not called but a tombstone record is forwarded + * directly to delete a record in the result {@code KTable} if required (i.e., if there is anything to be deleted). + *

                + * Input records with {@code null} key will be dropped and no join computation is performed. + *

+ * Example:
+ * | thisKTable | thisState | otherKTable | otherState | result updated record |
+ * |------------|-----------|-------------|------------|-----------------------|
+ * | <K1:A>     | <K1:A>    |             |            |                       |
+ * |            | <K1:A>    | <K1:b>      | <K1:b>     | <K1:ValueJoiner(A,b)> |
+ * | <K1:C>     | <K1:C>    |             | <K1:b>     | <K1:ValueJoiner(C,b)> |
+ * |            | <K1:C>    | <K1:null>   |            | <K1:null>             |
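A minimal usage sketch; the topic names, value types, and joiner logic are illustrative assumptions:

    KTable<String, String> left = builder.table("left-topic");
    KTable<String, Long> right = builder.table("right-topic", Consumed.with(Serdes.String(), Serdes.Long()));
    KTable<String, String> joined = left.join(right, (leftValue, rightValue) -> leftValue + "/" + rightValue);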
                + * Both input streams (or to be more precise, their underlying source topics) need to have the same number of + * partitions. + * + * @param other the other {@code KTable} to be joined with this {@code KTable} + * @param joiner a {@link ValueJoiner} that computes the join result for a pair of matching records + * @param the value type of the other {@code KTable} + * @param the value type of the result {@code KTable} + * @return a {@code KTable} that contains join-records for each key and values computed by the given + * {@link ValueJoiner}, one for each matched record-pair with the same key + * @see #leftJoin(KTable, ValueJoiner) + * @see #outerJoin(KTable, ValueJoiner) + */ + KTable join(final KTable other, + final ValueJoiner joiner); + + /** + * Join records of this {@code KTable} with another {@code KTable}'s records using non-windowed inner equi join, + * with default serializers, deserializers, and state store. + * The join is a primary key join with join attribute {@code thisKTable.key == otherKTable.key}. + * The result is an ever updating {@code KTable} that represents the current (i.e., processing time) result + * of the join. + *

                + * The join is computed by (1) updating the internal state of one {@code KTable} and (2) performing a lookup for a + * matching record in the current (i.e., processing time) internal state of the other {@code KTable}. + * This happens in a symmetric way, i.e., for each update of either {@code this} or the {@code other} input + * {@code KTable} the result gets updated. + *

                + * For each {@code KTable} record that finds a corresponding record in the other {@code KTable} the provided + * {@link ValueJoiner} will be called to compute a value (with arbitrary type) for the result record. + * The key of the result record is the same as for both joining input records. + *

                + * Note that {@link KeyValue records} with {@code null} values (so-called tombstone records) have delete semantics. + * Thus, for input tombstones the provided value-joiner is not called but a tombstone record is forwarded + * directly to delete a record in the result {@code KTable} if required (i.e., if there is anything to be deleted). + *

                + * Input records with {@code null} key will be dropped and no join computation is performed. + *

+ * Example:
+ * | thisKTable | thisState | otherKTable | otherState | result updated record |
+ * |------------|-----------|-------------|------------|-----------------------|
+ * | <K1:A>     | <K1:A>    |             |            |                       |
+ * |            | <K1:A>    | <K1:b>      | <K1:b>     | <K1:ValueJoiner(A,b)> |
+ * | <K1:C>     | <K1:C>    |             | <K1:b>     | <K1:ValueJoiner(C,b)> |
+ * |            | <K1:C>    | <K1:null>   |            | <K1:null>             |
                + * Both input streams (or to be more precise, their underlying source topics) need to have the same number of + * partitions. + * + * @param other the other {@code KTable} to be joined with this {@code KTable} + * @param joiner a {@link ValueJoiner} that computes the join result for a pair of matching records + * @param named a {@link Named} config used to name the processor in the topology + * @param the value type of the other {@code KTable} + * @param the value type of the result {@code KTable} + * @return a {@code KTable} that contains join-records for each key and values computed by the given + * {@link ValueJoiner}, one for each matched record-pair with the same key + * @see #leftJoin(KTable, ValueJoiner) + * @see #outerJoin(KTable, ValueJoiner) + */ + KTable join(final KTable other, + final ValueJoiner joiner, + final Named named); + + /** + * Join records of this {@code KTable} with another {@code KTable}'s records using non-windowed inner equi join, + * with the {@link Materialized} instance for configuration of the {@link Serde key serde}, + * {@link Serde the result table's value serde}, and {@link KeyValueStore state store}. + * The join is a primary key join with join attribute {@code thisKTable.key == otherKTable.key}. + * The result is an ever updating {@code KTable} that represents the current (i.e., processing time) result + * of the join. + *

                + * The join is computed by (1) updating the internal state of one {@code KTable} and (2) performing a lookup for a + * matching record in the current (i.e., processing time) internal state of the other {@code KTable}. + * This happens in a symmetric way, i.e., for each update of either {@code this} or the {@code other} input + * {@code KTable} the result gets updated. + *

                + * For each {@code KTable} record that finds a corresponding record in the other {@code KTable} the provided + * {@link ValueJoiner} will be called to compute a value (with arbitrary type) for the result record. + * The key of the result record is the same as for both joining input records. + *

                + * Note that {@link KeyValue records} with {@code null} values (so-called tombstone records) have delete semantics. + * Thus, for input tombstones the provided value-joiner is not called but a tombstone record is forwarded + * directly to delete a record in the result {@code KTable} if required (i.e., if there is anything to be deleted). + *

                + * Input records with {@code null} key will be dropped and no join computation is performed. + *

+ * Example:
+ * | thisKTable | thisState | otherKTable | otherState | result updated record |
+ * |------------|-----------|-------------|------------|-----------------------|
+ * | <K1:A>     | <K1:A>    |             |            |                       |
+ * |            | <K1:A>    | <K1:b>      | <K1:b>     | <K1:ValueJoiner(A,b)> |
+ * | <K1:C>     | <K1:C>    |             | <K1:b>     | <K1:ValueJoiner(C,b)> |
+ * |            | <K1:C>    | <K1:null>   |            | <K1:null>             |
                + * Both input streams (or to be more precise, their underlying source topics) need to have the same number of + * partitions. + * + * @param other the other {@code KTable} to be joined with this {@code KTable} + * @param joiner a {@link ValueJoiner} that computes the join result for a pair of matching records + * @param materialized an instance of {@link Materialized} used to describe how the state store should be materialized. + * Cannot be {@code null} + * @param the value type of the other {@code KTable} + * @param the value type of the result {@code KTable} + * @return a {@code KTable} that contains join-records for each key and values computed by the given + * {@link ValueJoiner}, one for each matched record-pair with the same key + * @see #leftJoin(KTable, ValueJoiner, Materialized) + * @see #outerJoin(KTable, ValueJoiner, Materialized) + */ + KTable join(final KTable other, + final ValueJoiner joiner, + final Materialized> materialized); + + /** + * Join records of this {@code KTable} with another {@code KTable}'s records using non-windowed inner equi join, + * with the {@link Materialized} instance for configuration of the {@link Serde key serde}, + * {@link Serde the result table's value serde}, and {@link KeyValueStore state store}. + * The join is a primary key join with join attribute {@code thisKTable.key == otherKTable.key}. + * The result is an ever updating {@code KTable} that represents the current (i.e., processing time) result + * of the join. + *

                + * The join is computed by (1) updating the internal state of one {@code KTable} and (2) performing a lookup for a + * matching record in the current (i.e., processing time) internal state of the other {@code KTable}. + * This happens in a symmetric way, i.e., for each update of either {@code this} or the {@code other} input + * {@code KTable} the result gets updated. + *

                + * For each {@code KTable} record that finds a corresponding record in the other {@code KTable} the provided + * {@link ValueJoiner} will be called to compute a value (with arbitrary type) for the result record. + * The key of the result record is the same as for both joining input records. + *

                + * Note that {@link KeyValue records} with {@code null} values (so-called tombstone records) have delete semantics. + * Thus, for input tombstones the provided value-joiner is not called but a tombstone record is forwarded + * directly to delete a record in the result {@code KTable} if required (i.e., if there is anything to be deleted). + *

                + * Input records with {@code null} key will be dropped and no join computation is performed. + *

+ * Example:
+ * | thisKTable | thisState | otherKTable | otherState | result updated record |
+ * |------------|-----------|-------------|------------|-----------------------|
+ * | <K1:A>     | <K1:A>    |             |            |                       |
+ * |            | <K1:A>    | <K1:b>      | <K1:b>     | <K1:ValueJoiner(A,b)> |
+ * | <K1:C>     | <K1:C>    |             | <K1:b>     | <K1:ValueJoiner(C,b)> |
+ * |            | <K1:C>    | <K1:null>   |            | <K1:null>             |
                + * Both input streams (or to be more precise, their underlying source topics) need to have the same number of + * partitions. + * + * @param other the other {@code KTable} to be joined with this {@code KTable} + * @param joiner a {@link ValueJoiner} that computes the join result for a pair of matching records + * @param named a {@link Named} config used to name the processor in the topology + * @param materialized an instance of {@link Materialized} used to describe how the state store should be materialized. + * Cannot be {@code null} + * @param the value type of the other {@code KTable} + * @param the value type of the result {@code KTable} + * @return a {@code KTable} that contains join-records for each key and values computed by the given + * {@link ValueJoiner}, one for each matched record-pair with the same key + * @see #leftJoin(KTable, ValueJoiner, Materialized) + * @see #outerJoin(KTable, ValueJoiner, Materialized) + */ + KTable join(final KTable other, + final ValueJoiner joiner, + final Named named, + final Materialized> materialized); + + /** + * Join records of this {@code KTable} (left input) with another {@code KTable}'s (right input) records using + * non-windowed left equi join, with default serializers, deserializers, and state store. + * The join is a primary key join with join attribute {@code thisKTable.key == otherKTable.key}. + * In contrast to {@link #join(KTable, ValueJoiner) inner-join}, all records from left {@code KTable} will produce + * an output record (cf. below). + * The result is an ever updating {@code KTable} that represents the current (i.e., processing time) result + * of the join. + *

                + * The join is computed by (1) updating the internal state of one {@code KTable} and (2) performing a lookup for a + * matching record in the current (i.e., processing time) internal state of the other {@code KTable}. + * This happens in a symmetric way, i.e., for each update of either {@code this} or the {@code other} input + * {@code KTable} the result gets updated. + *

                + * For each {@code KTable} record that finds a corresponding record in the other {@code KTable}'s state the + * provided {@link ValueJoiner} will be called to compute a value (with arbitrary type) for the result record. + * Additionally, for each record of left {@code KTable} that does not find a corresponding record in the + * right {@code KTable}'s state the provided {@link ValueJoiner} will be called with {@code rightValue = + * null} to compute a value (with arbitrary type) for the result record. + * The key of the result record is the same as for both joining input records. + *

                + * Note that {@link KeyValue records} with {@code null} values (so-called tombstone records) have delete semantics. + * For example, for left input tombstones the provided value-joiner is not called but a tombstone record is + * forwarded directly to delete a record in the result {@code KTable} if required (i.e., if there is anything to be + * deleted). + *

                + * Input records with {@code null} key will be dropped and no join computation is performed. + *

+ * Example:
+ * | thisKTable | thisState | otherKTable | otherState | result updated record    |
+ * |------------|-----------|-------------|------------|--------------------------|
+ * | <K1:A>     | <K1:A>    |             |            | <K1:ValueJoiner(A,null)> |
+ * |            | <K1:A>    | <K1:b>      | <K1:b>     | <K1:ValueJoiner(A,b)>    |
+ * | <K1:null>  |           |             | <K1:b>     | <K1:null>                |
+ * |            |           | <K1:null>   |            |                          |
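A sketch of the left join, reusing the illustrative left and right tables from the inner-join sketch earlier; note the null check for the right-hand value:

    KTable<String, String> leftJoined = left.leftJoin(
        right,
        (leftValue, rightValue) -> leftValue + "/" + (rightValue == null ? "n/a" : rightValue));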
                + * Both input streams (or to be more precise, their underlying source topics) need to have the same number of + * partitions. + * + * @param other the other {@code KTable} to be joined with this {@code KTable} + * @param joiner a {@link ValueJoiner} that computes the join result for a pair of matching records + * @param the value type of the other {@code KTable} + * @param the value type of the result {@code KTable} + * @return a {@code KTable} that contains join-records for each key and values computed by the given + * {@link ValueJoiner}, one for each matched record-pair with the same key plus one for each non-matching record of + * left {@code KTable} + * @see #join(KTable, ValueJoiner) + * @see #outerJoin(KTable, ValueJoiner) + */ + KTable leftJoin(final KTable other, + final ValueJoiner joiner); + + /** + * Join records of this {@code KTable} (left input) with another {@code KTable}'s (right input) records using + * non-windowed left equi join, with default serializers, deserializers, and state store. + * The join is a primary key join with join attribute {@code thisKTable.key == otherKTable.key}. + * In contrast to {@link #join(KTable, ValueJoiner) inner-join}, all records from left {@code KTable} will produce + * an output record (cf. below). + * The result is an ever updating {@code KTable} that represents the current (i.e., processing time) result + * of the join. + *

                + * The join is computed by (1) updating the internal state of one {@code KTable} and (2) performing a lookup for a + * matching record in the current (i.e., processing time) internal state of the other {@code KTable}. + * This happens in a symmetric way, i.e., for each update of either {@code this} or the {@code other} input + * {@code KTable} the result gets updated. + *

                + * For each {@code KTable} record that finds a corresponding record in the other {@code KTable}'s state the + * provided {@link ValueJoiner} will be called to compute a value (with arbitrary type) for the result record. + * Additionally, for each record of left {@code KTable} that does not find a corresponding record in the + * right {@code KTable}'s state the provided {@link ValueJoiner} will be called with {@code rightValue = + * null} to compute a value (with arbitrary type) for the result record. + * The key of the result record is the same as for both joining input records. + *

                + * Note that {@link KeyValue records} with {@code null} values (so-called tombstone records) have delete semantics. + * For example, for left input tombstones the provided value-joiner is not called but a tombstone record is + * forwarded directly to delete a record in the result {@code KTable} if required (i.e., if there is anything to be + * deleted). + *

                + * Input records with {@code null} key will be dropped and no join computation is performed. + *

+ * Example:
+ * | thisKTable | thisState | otherKTable | otherState | result updated record    |
+ * |------------|-----------|-------------|------------|--------------------------|
+ * | <K1:A>     | <K1:A>    |             |            | <K1:ValueJoiner(A,null)> |
+ * |            | <K1:A>    | <K1:b>      | <K1:b>     | <K1:ValueJoiner(A,b)>    |
+ * | <K1:null>  |           |             | <K1:b>     | <K1:null>                |
+ * |            |           | <K1:null>   |            |                          |
                + * Both input streams (or to be more precise, their underlying source topics) need to have the same number of + * partitions. + * + * @param other the other {@code KTable} to be joined with this {@code KTable} + * @param joiner a {@link ValueJoiner} that computes the join result for a pair of matching records + * @param named a {@link Named} config used to name the processor in the topology + * @param the value type of the other {@code KTable} + * @param the value type of the result {@code KTable} + * @return a {@code KTable} that contains join-records for each key and values computed by the given + * {@link ValueJoiner}, one for each matched record-pair with the same key plus one for each non-matching record of + * left {@code KTable} + * @see #join(KTable, ValueJoiner) + * @see #outerJoin(KTable, ValueJoiner) + */ + KTable leftJoin(final KTable other, + final ValueJoiner joiner, + final Named named); + + /** + * Join records of this {@code KTable} (left input) with another {@code KTable}'s (right input) records using + * non-windowed left equi join, with the {@link Materialized} instance for configuration of the {@link Serde key serde}, + * {@link Serde the result table's value serde}, and {@link KeyValueStore state store}. + * The join is a primary key join with join attribute {@code thisKTable.key == otherKTable.key}. + * In contrast to {@link #join(KTable, ValueJoiner) inner-join}, all records from left {@code KTable} will produce + * an output record (cf. below). + * The result is an ever updating {@code KTable} that represents the current (i.e., processing time) result + * of the join. + *

                + * The join is computed by (1) updating the internal state of one {@code KTable} and (2) performing a lookup for a + * matching record in the current (i.e., processing time) internal state of the other {@code KTable}. + * This happens in a symmetric way, i.e., for each update of either {@code this} or the {@code other} input + * {@code KTable} the result gets updated. + *

                + * For each {@code KTable} record that finds a corresponding record in the other {@code KTable}'s state the + * provided {@link ValueJoiner} will be called to compute a value (with arbitrary type) for the result record. + * Additionally, for each record of left {@code KTable} that does not find a corresponding record in the + * right {@code KTable}'s state the provided {@link ValueJoiner} will be called with {@code rightValue = + * null} to compute a value (with arbitrary type) for the result record. + * The key of the result record is the same as for both joining input records. + *

                + * Note that {@link KeyValue records} with {@code null} values (so-called tombstone records) have delete semantics. + * For example, for left input tombstones the provided value-joiner is not called but a tombstone record is + * forwarded directly to delete a record in the result {@code KTable} if required (i.e., if there is anything to be + * deleted). + *

                + * Input records with {@code null} key will be dropped and no join computation is performed. + *

+ * Example:
+ * | thisKTable | thisState | otherKTable | otherState | result updated record    |
+ * |------------|-----------|-------------|------------|--------------------------|
+ * | <K1:A>     | <K1:A>    |             |            | <K1:ValueJoiner(A,null)> |
+ * |            | <K1:A>    | <K1:b>      | <K1:b>     | <K1:ValueJoiner(A,b)>    |
+ * | <K1:null>  |           |             | <K1:b>     | <K1:null>                |
+ * |            |           | <K1:null>   |            |                          |
                + * Both input streams (or to be more precise, their underlying source topics) need to have the same number of + * partitions. + * + * @param other the other {@code KTable} to be joined with this {@code KTable} + * @param joiner a {@link ValueJoiner} that computes the join result for a pair of matching records + * @param materialized an instance of {@link Materialized} used to describe how the state store should be materialized. + * Cannot be {@code null} + * @param the value type of the other {@code KTable} + * @param the value type of the result {@code KTable} + * @return a {@code KTable} that contains join-records for each key and values computed by the given + * {@link ValueJoiner}, one for each matched record-pair with the same key plus one for each non-matching record of + * left {@code KTable} + * @see #join(KTable, ValueJoiner, Materialized) + * @see #outerJoin(KTable, ValueJoiner, Materialized) + */ + KTable leftJoin(final KTable other, + final ValueJoiner joiner, + final Materialized> materialized); + + /** + * Join records of this {@code KTable} (left input) with another {@code KTable}'s (right input) records using + * non-windowed left equi join, with the {@link Materialized} instance for configuration of the {@link Serde key serde}, + * {@link Serde the result table's value serde}, and {@link KeyValueStore state store}. + * The join is a primary key join with join attribute {@code thisKTable.key == otherKTable.key}. + * In contrast to {@link #join(KTable, ValueJoiner) inner-join}, all records from left {@code KTable} will produce + * an output record (cf. below). + * The result is an ever updating {@code KTable} that represents the current (i.e., processing time) result + * of the join. + *

                + * The join is computed by (1) updating the internal state of one {@code KTable} and (2) performing a lookup for a + * matching record in the current (i.e., processing time) internal state of the other {@code KTable}. + * This happens in a symmetric way, i.e., for each update of either {@code this} or the {@code other} input + * {@code KTable} the result gets updated. + *

                + * For each {@code KTable} record that finds a corresponding record in the other {@code KTable}'s state the + * provided {@link ValueJoiner} will be called to compute a value (with arbitrary type) for the result record. + * Additionally, for each record of left {@code KTable} that does not find a corresponding record in the + * right {@code KTable}'s state the provided {@link ValueJoiner} will be called with {@code rightValue = + * null} to compute a value (with arbitrary type) for the result record. + * The key of the result record is the same as for both joining input records. + *

                + * Note that {@link KeyValue records} with {@code null} values (so-called tombstone records) have delete semantics. + * For example, for left input tombstones the provided value-joiner is not called but a tombstone record is + * forwarded directly to delete a record in the result {@code KTable} if required (i.e., if there is anything to be + * deleted). + *

                + * Input records with {@code null} key will be dropped and no join computation is performed. + *

+ * Example:
+ * | thisKTable | thisState | otherKTable | otherState | result updated record    |
+ * |------------|-----------|-------------|------------|--------------------------|
+ * | <K1:A>     | <K1:A>    |             |            | <K1:ValueJoiner(A,null)> |
+ * |            | <K1:A>    | <K1:b>      | <K1:b>     | <K1:ValueJoiner(A,b)>    |
+ * | <K1:null>  |           |             | <K1:b>     | <K1:null>                |
+ * |            |           | <K1:null>   |            |                          |
                + * Both input streams (or to be more precise, their underlying source topics) need to have the same number of + * partitions. + * + * @param other the other {@code KTable} to be joined with this {@code KTable} + * @param joiner a {@link ValueJoiner} that computes the join result for a pair of matching records + * @param named a {@link Named} config used to name the processor in the topology + * @param materialized an instance of {@link Materialized} used to describe how the state store should be materialized. + * Cannot be {@code null} + * @param the value type of the other {@code KTable} + * @param the value type of the result {@code KTable} + * @return a {@code KTable} that contains join-records for each key and values computed by the given + * {@link ValueJoiner}, one for each matched record-pair with the same key plus one for each non-matching record of + * left {@code KTable} + * @see #join(KTable, ValueJoiner, Materialized) + * @see #outerJoin(KTable, ValueJoiner, Materialized) + */ + KTable leftJoin(final KTable other, + final ValueJoiner joiner, + final Named named, + final Materialized> materialized); + + /** + * Join records of this {@code KTable} (left input) with another {@code KTable}'s (right input) records using + * non-windowed outer equi join, with default serializers, deserializers, and state store. + * The join is a primary key join with join attribute {@code thisKTable.key == otherKTable.key}. + * In contrast to {@link #join(KTable, ValueJoiner) inner-join} or {@link #leftJoin(KTable, ValueJoiner) left-join}, + * all records from both input {@code KTable}s will produce an output record (cf. below). + * The result is an ever updating {@code KTable} that represents the current (i.e., processing time) result + * of the join. + *

                + * The join is computed by (1) updating the internal state of one {@code KTable} and (2) performing a lookup for a + * matching record in the current (i.e., processing time) internal state of the other {@code KTable}. + * This happens in a symmetric way, i.e., for each update of either {@code this} or the {@code other} input + * {@code KTable} the result gets updated. + *

                + * For each {@code KTable} record that finds a corresponding record in the other {@code KTable}'s state the + * provided {@link ValueJoiner} will be called to compute a value (with arbitrary type) for the result record. + * Additionally, for each record that does not find a corresponding record in the corresponding other + * {@code KTable}'s state the provided {@link ValueJoiner} will be called with {@code null} value for the + * corresponding other value to compute a value (with arbitrary type) for the result record. + * The key of the result record is the same as for both joining input records. + *

                + * Note that {@link KeyValue records} with {@code null} values (so-called tombstone records) have delete semantics. + * Thus, for input tombstones the provided value-joiner is not called but a tombstone record is forwarded directly + * to delete a record in the result {@code KTable} if required (i.e., if there is anything to be deleted). + *

                + * Input records with {@code null} key will be dropped and no join computation is performed. + *

+ * Example:
+ * thisKTable    thisState    otherKTable    otherState    result updated record
+ * <K1:A>        <K1:A>                                    <K1:ValueJoiner(A,null)>
+ *               <K1:A>       <K1:b>         <K1:b>        <K1:ValueJoiner(A,b)>
+ * <K1:null>                                 <K1:b>        <K1:ValueJoiner(null,b)>
+ *                            <K1:null>                    <K1:null>
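A minimal sketch of the join shown in the table above, assuming two String-keyed, String-valued tables read with default serdes (topic names are illustrative, not from this file):

StreamsBuilder builder = new StreamsBuilder();
KTable<String, String> left = builder.table("left-topic");
KTable<String, String> right = builder.table("right-topic");

// Outer join: either argument may be null when the corresponding side has no
// matching record, so the joiner must tolerate null on both sides.
KTable<String, String> joined = left.outerJoin(
    right,
    (leftValue, rightValue) -> leftValue + "/" + rightValue);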
                + * Both input streams (or to be more precise, their underlying source topics) need to have the same number of + * partitions. + * + * @param other the other {@code KTable} to be joined with this {@code KTable} + * @param joiner a {@link ValueJoiner} that computes the join result for a pair of matching records + * @param the value type of the other {@code KTable} + * @param the value type of the result {@code KTable} + * @return a {@code KTable} that contains join-records for each key and values computed by the given + * {@link ValueJoiner}, one for each matched record-pair with the same key plus one for each non-matching record of + * both {@code KTable}s + * @see #join(KTable, ValueJoiner) + * @see #leftJoin(KTable, ValueJoiner) + */ + KTable outerJoin(final KTable other, + final ValueJoiner joiner); + + + /** + * Join records of this {@code KTable} (left input) with another {@code KTable}'s (right input) records using + * non-windowed outer equi join, with default serializers, deserializers, and state store. + * The join is a primary key join with join attribute {@code thisKTable.key == otherKTable.key}. + * In contrast to {@link #join(KTable, ValueJoiner) inner-join} or {@link #leftJoin(KTable, ValueJoiner) left-join}, + * all records from both input {@code KTable}s will produce an output record (cf. below). + * The result is an ever updating {@code KTable} that represents the current (i.e., processing time) result + * of the join. + *

                + * The join is computed by (1) updating the internal state of one {@code KTable} and (2) performing a lookup for a + * matching record in the current (i.e., processing time) internal state of the other {@code KTable}. + * This happens in a symmetric way, i.e., for each update of either {@code this} or the {@code other} input + * {@code KTable} the result gets updated. + *

                + * For each {@code KTable} record that finds a corresponding record in the other {@code KTable}'s state the + * provided {@link ValueJoiner} will be called to compute a value (with arbitrary type) for the result record. + * Additionally, for each record that does not find a corresponding record in the corresponding other + * {@code KTable}'s state the provided {@link ValueJoiner} will be called with {@code null} value for the + * corresponding other value to compute a value (with arbitrary type) for the result record. + * The key of the result record is the same as for both joining input records. + *

                + * Note that {@link KeyValue records} with {@code null} values (so-called tombstone records) have delete semantics. + * Thus, for input tombstones the provided value-joiner is not called but a tombstone record is forwarded directly + * to delete a record in the result {@code KTable} if required (i.e., if there is anything to be deleted). + *

                + * Input records with {@code null} key will be dropped and no join computation is performed. + *

+ * Example:
+ * thisKTable    thisState    otherKTable    otherState    result updated record
+ * <K1:A>        <K1:A>                                    <K1:ValueJoiner(A,null)>
+ *               <K1:A>       <K1:b>         <K1:b>        <K1:ValueJoiner(A,b)>
+ * <K1:null>                                 <K1:b>        <K1:ValueJoiner(null,b)>
+ *                            <K1:null>                    <K1:null>
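The same sketch as above, using the Named overload documented here; the processor name is an assumption:

// Naming the join processor makes the topology description easier to read.
KTable<String, String> namedJoin = left.outerJoin(
    right,
    (leftValue, rightValue) -> leftValue + "/" + rightValue,
    Named.as("my-outer-join"));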
                + * Both input streams (or to be more precise, their underlying source topics) need to have the same number of + * partitions. + * + * @param other the other {@code KTable} to be joined with this {@code KTable} + * @param joiner a {@link ValueJoiner} that computes the join result for a pair of matching records + * @param named a {@link Named} config used to name the processor in the topology + * @param the value type of the other {@code KTable} + * @param the value type of the result {@code KTable} + * @return a {@code KTable} that contains join-records for each key and values computed by the given + * {@link ValueJoiner}, one for each matched record-pair with the same key plus one for each non-matching record of + * both {@code KTable}s + * @see #join(KTable, ValueJoiner) + * @see #leftJoin(KTable, ValueJoiner) + */ + KTable outerJoin(final KTable other, + final ValueJoiner joiner, + final Named named); + + /** + * Join records of this {@code KTable} (left input) with another {@code KTable}'s (right input) records using + * non-windowed outer equi join, with the {@link Materialized} instance for configuration of the {@link Serde key serde}, + * {@link Serde the result table's value serde}, and {@link KeyValueStore state store}. + * The join is a primary key join with join attribute {@code thisKTable.key == otherKTable.key}. + * In contrast to {@link #join(KTable, ValueJoiner) inner-join} or {@link #leftJoin(KTable, ValueJoiner) left-join}, + * all records from both input {@code KTable}s will produce an output record (cf. below). + * The result is an ever updating {@code KTable} that represents the current (i.e., processing time) result + * of the join. + *

                + * The join is computed by (1) updating the internal state of one {@code KTable} and (2) performing a lookup for a + * matching record in the current (i.e., processing time) internal state of the other {@code KTable}. + * This happens in a symmetric way, i.e., for each update of either {@code this} or the {@code other} input + * {@code KTable} the result gets updated. + *

                + * For each {@code KTable} record that finds a corresponding record in the other {@code KTable}'s state the + * provided {@link ValueJoiner} will be called to compute a value (with arbitrary type) for the result record. + * Additionally, for each record that does not find a corresponding record in the corresponding other + * {@code KTable}'s state the provided {@link ValueJoiner} will be called with {@code null} value for the + * corresponding other value to compute a value (with arbitrary type) for the result record. + * The key of the result record is the same as for both joining input records. + *

                + * Note that {@link KeyValue records} with {@code null} values (so-called tombstone records) have delete semantics. + * Thus, for input tombstones the provided value-joiner is not called but a tombstone record is forwarded directly + * to delete a record in the result {@code KTable} if required (i.e., if there is anything to be deleted). + *

                + * Input records with {@code null} key will be dropped and no join computation is performed. + *

+ * Example:
+ * thisKTable    thisState    otherKTable    otherState    result updated record
+ * <K1:A>        <K1:A>                                    <K1:ValueJoiner(A,null)>
+ *               <K1:A>       <K1:b>         <K1:b>        <K1:ValueJoiner(A,b)>
+ * <K1:null>                                 <K1:b>        <K1:ValueJoiner(null,b)>
+ *                            <K1:null>                    <K1:null>
                + * Both input streams (or to be more precise, their underlying source topics) need to have the same number of + * partitions. + * + * @param other the other {@code KTable} to be joined with this {@code KTable} + * @param joiner a {@link ValueJoiner} that computes the join result for a pair of matching records + * @param materialized an instance of {@link Materialized} used to describe how the state store should be materialized. + * Cannot be {@code null} + * @param the value type of the other {@code KTable} + * @param the value type of the result {@code KTable} + * @return a {@code KTable} that contains join-records for each key and values computed by the given + * {@link ValueJoiner}, one for each matched record-pair with the same key plus one for each non-matching record of + * both {@code KTable}s + * @see #join(KTable, ValueJoiner) + * @see #leftJoin(KTable, ValueJoiner) + */ + KTable outerJoin(final KTable other, + final ValueJoiner joiner, + final Materialized> materialized); + + + /** + * Join records of this {@code KTable} (left input) with another {@code KTable}'s (right input) records using + * non-windowed outer equi join, with the {@link Materialized} instance for configuration of the {@link Serde key serde}, + * {@link Serde the result table's value serde}, and {@link KeyValueStore state store}. + * The join is a primary key join with join attribute {@code thisKTable.key == otherKTable.key}. + * In contrast to {@link #join(KTable, ValueJoiner) inner-join} or {@link #leftJoin(KTable, ValueJoiner) left-join}, + * all records from both input {@code KTable}s will produce an output record (cf. below). + * The result is an ever updating {@code KTable} that represents the current (i.e., processing time) result + * of the join. + *

                + * The join is computed by (1) updating the internal state of one {@code KTable} and (2) performing a lookup for a + * matching record in the current (i.e., processing time) internal state of the other {@code KTable}. + * This happens in a symmetric way, i.e., for each update of either {@code this} or the {@code other} input + * {@code KTable} the result gets updated. + *

                + * For each {@code KTable} record that finds a corresponding record in the other {@code KTable}'s state the + * provided {@link ValueJoiner} will be called to compute a value (with arbitrary type) for the result record. + * Additionally, for each record that does not find a corresponding record in the corresponding other + * {@code KTable}'s state the provided {@link ValueJoiner} will be called with {@code null} value for the + * corresponding other value to compute a value (with arbitrary type) for the result record. + * The key of the result record is the same as for both joining input records. + *

                + * Note that {@link KeyValue records} with {@code null} values (so-called tombstone records) have delete semantics. + * Thus, for input tombstones the provided value-joiner is not called but a tombstone record is forwarded directly + * to delete a record in the result {@code KTable} if required (i.e., if there is anything to be deleted). + *

                + * Input records with {@code null} key will be dropped and no join computation is performed. + *

+ * Example:
+ * thisKTable    thisState    otherKTable    otherState    result updated record
+ * <K1:A>        <K1:A>                                    <K1:ValueJoiner(A,null)>
+ *               <K1:A>       <K1:b>         <K1:b>        <K1:ValueJoiner(A,b)>
+ * <K1:null>                                 <K1:b>        <K1:ValueJoiner(null,b)>
+ *                            <K1:null>                    <K1:null>
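A sketch of the Named plus Materialized overload, continuing the example above; the store name and serdes are assumptions, and the usual Bytes, KeyValueStore, and Serdes imports are assumed:

// Materializing the join result makes it queryable via interactive queries
// under the given store name.
KTable<String, String> queryableJoin = left.outerJoin(
    right,
    (leftValue, rightValue) -> leftValue + "/" + rightValue,
    Named.as("queryable-outer-join"),
    Materialized.<String, String, KeyValueStore<Bytes, byte[]>>as("outer-join-store")
        .withKeySerde(Serdes.String())
        .withValueSerde(Serdes.String()));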
                + * Both input streams (or to be more precise, their underlying source topics) need to have the same number of + * partitions. + * + * @param other the other {@code KTable} to be joined with this {@code KTable} + * @param joiner a {@link ValueJoiner} that computes the join result for a pair of matching records + * @param named a {@link Named} config used to name the processor in the topology + * @param materialized an instance of {@link Materialized} used to describe how the state store should be materialized. + * Cannot be {@code null} + * @param the value type of the other {@code KTable} + * @param the value type of the result {@code KTable} + * @return a {@code KTable} that contains join-records for each key and values computed by the given + * {@link ValueJoiner}, one for each matched record-pair with the same key plus one for each non-matching record of + * both {@code KTable}s + * @see #join(KTable, ValueJoiner) + * @see #leftJoin(KTable, ValueJoiner) + */ + KTable outerJoin(final KTable other, + final ValueJoiner joiner, + final Named named, + final Materialized> materialized); + + /** + * Join records of this {@code KTable} with another {@code KTable} using non-windowed inner join. + *

                + * This is a foreign key join, where the joining key is determined by the {@code foreignKeyExtractor}. + * + * @param other the other {@code KTable} to be joined with this {@code KTable}. Keyed by KO. + * @param foreignKeyExtractor a {@link Function} that extracts the key (KO) from this table's value (V). If the + * result is null, the update is ignored as invalid. + * @param joiner a {@link ValueJoiner} that computes the join result for a pair of matching records + * @param the value type of the result {@code KTable} + * @param the key type of the other {@code KTable} + * @param the value type of the other {@code KTable} + * @return a {@code KTable} that contains the result of joining this table with {@code other} + */ + KTable join(final KTable other, + final Function foreignKeyExtractor, + final ValueJoiner joiner); + + /** + * Join records of this {@code KTable} with another {@code KTable} using non-windowed inner join. + *
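A minimal sketch of the foreign-key inner join, kept to String values for brevity: the value of the orders table is assumed to be just the id of the customer it refers to, and the topic names are illustrative:

KTable<String, String> orders = builder.table("orders");        // orderId -> customerId
KTable<String, String> customers = builder.table("customers");  // customerId -> customerName

// The extractor derives the customers-table key from the orders-table value.
KTable<String, String> customerNameByOrder = orders.join(
    customers,
    orderValue -> orderValue,                      // order value is the customer id
    (orderValue, customerName) -> customerName);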

                + * This is a foreign key join, where the joining key is determined by the {@code foreignKeyExtractor}. + * + * @param other the other {@code KTable} to be joined with this {@code KTable}. Keyed by KO. + * @param foreignKeyExtractor a {@link Function} that extracts the key (KO) from this table's value (V). If the + * result is null, the update is ignored as invalid. + * @param joiner a {@link ValueJoiner} that computes the join result for a pair of matching records + * @param named a {@link Named} config used to name the processor in the topology + * @param the value type of the result {@code KTable} + * @param the key type of the other {@code KTable} + * @param the value type of the other {@code KTable} + * @return a {@code KTable} that contains the result of joining this table with {@code other} + * + * @deprecated since 3.1, removal planned for 4.0. Use {@link #join(KTable, Function, ValueJoiner, TableJoined)} instead. + */ + @Deprecated + KTable join(final KTable other, + final Function foreignKeyExtractor, + final ValueJoiner joiner, + final Named named); + + /** + * Join records of this {@code KTable} with another {@code KTable} using non-windowed inner join, + * using the {@link TableJoined} instance for optional configurations including + * {@link StreamPartitioner partitioners} when the tables being joined use non-default partitioning, + * and also the base name for components of the join. + *

                + * This is a foreign key join, where the joining key is determined by the {@code foreignKeyExtractor}. + * + * @param other the other {@code KTable} to be joined with this {@code KTable}. Keyed by KO. + * @param foreignKeyExtractor a {@link Function} that extracts the key (KO) from this table's value (V). If the + * result is null, the update is ignored as invalid. + * @param joiner a {@link ValueJoiner} that computes the join result for a pair of matching records + * @param tableJoined a {@link TableJoined} used to configure partitioners and names of internal topics and stores + * @param the value type of the result {@code KTable} + * @param the key type of the other {@code KTable} + * @param the value type of the other {@code KTable} + * @return a {@code KTable} that contains the result of joining this table with {@code other} + */ + KTable join(final KTable other, + final Function foreignKeyExtractor, + final ValueJoiner joiner, + final TableJoined tableJoined); + + /** + * Join records of this {@code KTable} with another {@code KTable} using non-windowed inner join. + *

                + * This is a foreign key join, where the joining key is determined by the {@code foreignKeyExtractor}. + * + * @param other the other {@code KTable} to be joined with this {@code KTable}. Keyed by KO. + * @param foreignKeyExtractor a {@link Function} that extracts the key (KO) from this table's value (V). If the + * result is null, the update is ignored as invalid. + * @param joiner a {@link ValueJoiner} that computes the join result for a pair of matching records + * @param materialized a {@link Materialized} that describes how the {@link StateStore} for the resulting {@code KTable} + * should be materialized. Cannot be {@code null} + * @param the value type of the result {@code KTable} + * @param the key type of the other {@code KTable} + * @param the value type of the other {@code KTable} + * @return a {@code KTable} that contains the result of joining this table with {@code other} + */ + KTable join(final KTable other, + final Function foreignKeyExtractor, + final ValueJoiner joiner, + final Materialized> materialized); + + /** + * Join records of this {@code KTable} with another {@code KTable} using non-windowed inner join. + *

                + * This is a foreign key join, where the joining key is determined by the {@code foreignKeyExtractor}. + * + * @param other the other {@code KTable} to be joined with this {@code KTable}. Keyed by KO. + * @param foreignKeyExtractor a {@link Function} that extracts the key (KO) from this table's value (V). If the + * result is null, the update is ignored as invalid. + * @param joiner a {@link ValueJoiner} that computes the join result for a pair of matching records + * @param named a {@link Named} config used to name the processor in the topology + * @param materialized a {@link Materialized} that describes how the {@link StateStore} for the resulting {@code KTable} + * should be materialized. Cannot be {@code null} + * @param the value type of the result {@code KTable} + * @param the key type of the other {@code KTable} + * @param the value type of the other {@code KTable} + * @return a {@code KTable} that contains the result of joining this table with {@code other} + * + * @deprecated since 3.1, removal planned for 4.0. Use {@link #join(KTable, Function, ValueJoiner, TableJoined, Materialized)} instead. + */ + @Deprecated + KTable join(final KTable other, + final Function foreignKeyExtractor, + final ValueJoiner joiner, + final Named named, + final Materialized> materialized); + + /** + * Join records of this {@code KTable} with another {@code KTable} using non-windowed inner join, + * using the {@link TableJoined} instance for optional configurations including + * {@link StreamPartitioner partitioners} when the tables being joined use non-default partitioning, + * and also the base name for components of the join. + *

                + * This is a foreign key join, where the joining key is determined by the {@code foreignKeyExtractor}. + * + * @param other the other {@code KTable} to be joined with this {@code KTable}. Keyed by KO. + * @param foreignKeyExtractor a {@link Function} that extracts the key (KO) from this table's value (V). If the + * result is null, the update is ignored as invalid. + * @param joiner a {@link ValueJoiner} that computes the join result for a pair of matching records + * @param tableJoined a {@link TableJoined} used to configure partitioners and names of internal topics and stores + * @param materialized a {@link Materialized} that describes how the {@link StateStore} for the resulting {@code KTable} + * should be materialized. Cannot be {@code null} + * @param the value type of the result {@code KTable} + * @param the key type of the other {@code KTable} + * @param the value type of the other {@code KTable} + * @return a {@code KTable} that contains the result of joining this table with {@code other} + */ + KTable join(final KTable other, + final Function foreignKeyExtractor, + final ValueJoiner joiner, + final TableJoined tableJoined, + final Materialized> materialized); + + /** + * Join records of this {@code KTable} with another {@code KTable} using non-windowed left join. + *
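Continuing the sketch above with the TableJoined and Materialized overload; the base name and store name are assumptions, and Materialized.as relies on target-type inference here:

KTable<String, String> namedFkJoin = orders.join(
    customers,
    orderValue -> orderValue,
    (orderValue, customerName) -> customerName,
    TableJoined.as("orders-to-customers"),
    Materialized.as("customer-name-by-order-store"));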

                + * This is a foreign key join, where the joining key is determined by the {@code foreignKeyExtractor}. + * + * @param other the other {@code KTable} to be joined with this {@code KTable}. Keyed by KO. + * @param foreignKeyExtractor a {@link Function} that extracts the key (KO) from this table's value (V). If the + * result is null, the update is ignored as invalid. + * @param joiner a {@link ValueJoiner} that computes the join result for a pair of matching records + * @param the value type of the result {@code KTable} + * @param the key type of the other {@code KTable} + * @param the value type of the other {@code KTable} + * @return a {@code KTable} that contains only those records that satisfy the given predicate + */ + KTable leftJoin(final KTable other, + final Function foreignKeyExtractor, + final ValueJoiner joiner); + + /** + * Join records of this {@code KTable} with another {@code KTable} using non-windowed left join. + *
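The left variant of the same sketch: orders whose customer id has no match in the customers table still produce a result, with a null customer name passed to the joiner:

KTable<String, String> customerNameOrUnknown = orders.leftJoin(
    customers,
    orderValue -> orderValue,
    (orderValue, customerName) -> customerName != null ? customerName : "unknown");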

                + * This is a foreign key join, where the joining key is determined by the {@code foreignKeyExtractor}. + * + * @param other the other {@code KTable} to be joined with this {@code KTable}. Keyed by KO. + * @param foreignKeyExtractor a {@link Function} that extracts the key (KO) from this table's value (V) If the + * result is null, the update is ignored as invalid. + * @param joiner a {@link ValueJoiner} that computes the join result for a pair of matching records + * @param named a {@link Named} config used to name the processor in the topology + * @param the value type of the result {@code KTable} + * @param the key type of the other {@code KTable} + * @param the value type of the other {@code KTable} + * @return a {@code KTable} that contains the result of joining this table with {@code other} + * + * @deprecated since 3.1, removal planned for 4.0. Use {@link #leftJoin(KTable, Function, ValueJoiner, TableJoined)} instead. + */ + @Deprecated + KTable leftJoin(final KTable other, + final Function foreignKeyExtractor, + final ValueJoiner joiner, + final Named named); + + /** + * Join records of this {@code KTable} with another {@code KTable} using non-windowed left join, + * using the {@link TableJoined} instance for optional configurations including + * {@link StreamPartitioner partitioners} when the tables being joined use non-default partitioning, + * and also the base name for components of the join. + *

                + * This is a foreign key join, where the joining key is determined by the {@code foreignKeyExtractor}. + * + * @param other the other {@code KTable} to be joined with this {@code KTable}. Keyed by KO. + * @param foreignKeyExtractor a {@link Function} that extracts the key (KO) from this table's value (V) If the + * result is null, the update is ignored as invalid. + * @param joiner a {@link ValueJoiner} that computes the join result for a pair of matching records + * @param tableJoined a {@link TableJoined} used to configure partitioners and names of internal topics and stores + * @param the value type of the result {@code KTable} + * @param the key type of the other {@code KTable} + * @param the value type of the other {@code KTable} + * @return a {@code KTable} that contains the result of joining this table with {@code other} + */ + KTable leftJoin(final KTable other, + final Function foreignKeyExtractor, + final ValueJoiner joiner, + final TableJoined tableJoined); + + /** + * Join records of this {@code KTable} with another {@code KTable} using non-windowed left join. + *

                + * This is a foreign key join, where the joining key is determined by the {@code foreignKeyExtractor}. + * + * @param other the other {@code KTable} to be joined with this {@code KTable}. Keyed by KO. + * @param foreignKeyExtractor a {@link Function} that extracts the key (KO) from this table's value (V). If the + * result is null, the update is ignored as invalid. + * @param joiner a {@link ValueJoiner} that computes the join result for a pair of matching records + * @param materialized a {@link Materialized} that describes how the {@link StateStore} for the resulting {@code KTable} + * should be materialized. Cannot be {@code null} + * @param the value type of the result {@code KTable} + * @param the key type of the other {@code KTable} + * @param the value type of the other {@code KTable} + * @return a {@code KTable} that contains the result of joining this table with {@code other} + */ + KTable leftJoin(final KTable other, + final Function foreignKeyExtractor, + final ValueJoiner joiner, + final Materialized> materialized); + + /** + * Join records of this {@code KTable} with another {@code KTable} using non-windowed left join. + *

                + * This is a foreign key join, where the joining key is determined by the {@code foreignKeyExtractor}. + * + * @param other the other {@code KTable} to be joined with this {@code KTable}. Keyed by KO. + * @param foreignKeyExtractor a {@link Function} that extracts the key (KO) from this table's value (V) If the + * result is null, the update is ignored as invalid. + * @param joiner a {@link ValueJoiner} that computes the join result for a pair of matching records + * @param named a {@link Named} config used to name the processor in the topology + * @param materialized a {@link Materialized} that describes how the {@link StateStore} for the resulting {@code KTable} + * should be materialized. Cannot be {@code null} + * @param the value type of the result {@code KTable} + * @param the key type of the other {@code KTable} + * @param the value type of the other {@code KTable} + * @return a {@code KTable} that contains the result of joining this table with {@code other} + * + * @deprecated since 3.1, removal planned for 4.0. Use {@link #leftJoin(KTable, Function, ValueJoiner, TableJoined, Materialized)} instead. + */ + @Deprecated + KTable leftJoin(final KTable other, + final Function foreignKeyExtractor, + final ValueJoiner joiner, + final Named named, + final Materialized> materialized); + + /** + * Join records of this {@code KTable} with another {@code KTable} using non-windowed left join, + * using the {@link TableJoined} instance for optional configurations including + * {@link StreamPartitioner partitioners} when the tables being joined use non-default partitioning, + * and also the base name for components of the join. + *

                + * This is a foreign key join, where the joining key is determined by the {@code foreignKeyExtractor}. + * + * @param other the other {@code KTable} to be joined with this {@code KTable}. Keyed by KO. + * @param foreignKeyExtractor a {@link Function} that extracts the key (KO) from this table's value (V) If the + * result is null, the update is ignored as invalid. + * @param joiner a {@link ValueJoiner} that computes the join result for a pair of matching records + * @param tableJoined a {@link TableJoined} used to configure partitioners and names of internal topics and stores + * @param materialized a {@link Materialized} that describes how the {@link StateStore} for the resulting {@code KTable} + * should be materialized. Cannot be {@code null} + * @param the value type of the result {@code KTable} + * @param the key type of the other {@code KTable} + * @param the value type of the other {@code KTable} + * @return a {@code KTable} that contains the result of joining this table with {@code other} + */ + KTable leftJoin(final KTable other, + final Function foreignKeyExtractor, + final ValueJoiner joiner, + final TableJoined tableJoined, + final Materialized> materialized); + + /** + * Get the name of the local state store used that can be used to query this {@code KTable}. + * + * @return the underlying state store name, or {@code null} if this {@code KTable} cannot be queried. + */ + String queryableStoreName(); +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/KeyValueMapper.java b/streams/src/main/java/org/apache/kafka/streams/kstream/KeyValueMapper.java new file mode 100644 index 0000000..1112fbb --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/KeyValueMapper.java @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.kstream; + +import org.apache.kafka.streams.KeyValue; + +/** + * The {@code KeyValueMapper} interface for mapping a {@link KeyValue key-value pair} to a new value of arbitrary type. + * For example, it can be used to + *

+ * <ul>
+ * <li>map from an input {@link KeyValue} pair to an output {@link KeyValue} pair with different key and/or value type
+ * (for this case output type {@code VR == }{@link KeyValue KeyValue<NewKeyType,NewValueType>})</li>
+ * <li>map from an input record to a new key (with arbitrary key type as specified by {@code VR})</li>
+ * </ul>
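A sketch of the first use case in the list above, assuming a KStream<String, String> named words:

// Map each record to a KeyValue pair with a different key and value type.
KStream<String, Integer> wordLengths = words.map(
    (key, word) -> KeyValue.pair(word, word.length()));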
                + * This is a stateless record-by-record operation, i.e, {@link #apply(Object, Object)} is invoked individually for each + * record of a stream (cf. {@link Transformer} for stateful record transformation). + * {@code KeyValueMapper} is a generalization of {@link ValueMapper}. + * + * @param key type + * @param value type + * @param mapped value type + * @see ValueMapper + * @see Transformer + * @see KStream#map(KeyValueMapper) + * @see KStream#flatMap(KeyValueMapper) + * @see KStream#selectKey(KeyValueMapper) + * @see KStream#groupBy(KeyValueMapper) + * @see KStream#groupBy(KeyValueMapper, Grouped) + * @see KTable#groupBy(KeyValueMapper) + * @see KTable#groupBy(KeyValueMapper, Grouped) + * @see KTable#toStream(KeyValueMapper) + */ +public interface KeyValueMapper { + + /** + * Map a record with the given key and value to a new value. + * + * @param key the key of the record + * @param value the value of the record + * @return the new value + */ + VR apply(final K key, final V value); +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/Materialized.java b/streams/src/main/java/org/apache/kafka/streams/kstream/Materialized.java new file mode 100644 index 0000000..82b3800 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/Materialized.java @@ -0,0 +1,262 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.kstream; + +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.streams.processor.StateStore; +import org.apache.kafka.streams.state.KeyValueBytesStoreSupplier; +import org.apache.kafka.streams.state.KeyValueStore; +import org.apache.kafka.streams.state.SessionBytesStoreSupplier; +import org.apache.kafka.streams.state.SessionStore; +import org.apache.kafka.streams.state.StoreSupplier; +import org.apache.kafka.streams.state.WindowBytesStoreSupplier; +import org.apache.kafka.streams.state.WindowStore; + +import java.time.Duration; +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; + +import static org.apache.kafka.streams.internals.ApiUtils.prepareMillisCheckFailMsgPrefix; +import static org.apache.kafka.streams.internals.ApiUtils.validateMillisecondDuration; + +/** + * Used to describe how a {@link StateStore} should be materialized. + * You can either provide a custom {@link StateStore} backend through one of the provided methods accepting a supplier + * or use the default RocksDB backends by providing just a store name. + *

                + * For example, you can read a topic as {@link KTable} and force a state store materialization to access the content + * via Interactive Queries API: + *

                {@code
                + * StreamsBuilder builder = new StreamsBuilder();
+ * KTable<Integer, Integer> table = builder.table(
                + *   "topicName",
                + *   Materialized.as("queryable-store-name"));
                + * }
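Beyond the snippet above, instances are typically configured fluently; this sketch assumes a windowed store, explicit serdes, and the usual Bytes, WindowStore, Serdes, and Duration imports:

Materialized<String, Long, WindowStore<Bytes, byte[]>> config =
    Materialized.<String, Long, WindowStore<Bytes, byte[]>>as("counts-store")
        .withKeySerde(Serdes.String())
        .withValueSerde(Serdes.Long())
        .withRetention(Duration.ofDays(1))   // honoured by window/session stores only
        .withCachingDisabled();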
                + * + * @param type of record key + * @param type of record value + * @param type of state store (note: state stores always have key/value types {@code } + * + * @see org.apache.kafka.streams.state.Stores + */ +public class Materialized { + protected StoreSupplier storeSupplier; + protected String storeName; + protected Serde valueSerde; + protected Serde keySerde; + protected boolean loggingEnabled = true; + protected boolean cachingEnabled = true; + protected Map topicConfig = new HashMap<>(); + protected Duration retention; + + private Materialized(final StoreSupplier storeSupplier) { + this.storeSupplier = storeSupplier; + } + + private Materialized(final String storeName) { + this.storeName = storeName; + } + + /** + * Copy constructor. + * @param materialized the {@link Materialized} instance to copy. + */ + protected Materialized(final Materialized materialized) { + this.storeSupplier = materialized.storeSupplier; + this.storeName = materialized.storeName; + this.keySerde = materialized.keySerde; + this.valueSerde = materialized.valueSerde; + this.loggingEnabled = materialized.loggingEnabled; + this.cachingEnabled = materialized.cachingEnabled; + this.topicConfig = materialized.topicConfig; + this.retention = materialized.retention; + } + + /** + * Materialize a {@link StateStore} with the given name. + * + * @param storeName the name of the underlying {@link KTable} state store; valid characters are ASCII + * alphanumerics, '.', '_' and '-'. + * @param key type of the store + * @param value type of the store + * @param type of the {@link StateStore} + * @return a new {@link Materialized} instance with the given storeName + */ + public static Materialized as(final String storeName) { + Named.validate(storeName); + return new Materialized<>(storeName); + } + + /** + * Materialize a {@link WindowStore} using the provided {@link WindowBytesStoreSupplier}. + * + * Important: Custom subclasses are allowed here, but they should respect the retention contract: + * Window stores are required to retain windows at least as long as (window size + window grace period). + * Stores constructed via {@link org.apache.kafka.streams.state.Stores} already satisfy this contract. + * + * @param supplier the {@link WindowBytesStoreSupplier} used to materialize the store + * @param key type of the store + * @param value type of the store + * @return a new {@link Materialized} instance with the given supplier + */ + public static Materialized> as(final WindowBytesStoreSupplier supplier) { + Objects.requireNonNull(supplier, "supplier can't be null"); + return new Materialized<>(supplier); + } + + /** + * Materialize a {@link SessionStore} using the provided {@link SessionBytesStoreSupplier}. + * + * Important: Custom subclasses are allowed here, but they should respect the retention contract: + * Session stores are required to retain windows at least as long as (session inactivity gap + session grace period). + * Stores constructed via {@link org.apache.kafka.streams.state.Stores} already satisfy this contract. 
+ * + * @param supplier the {@link SessionBytesStoreSupplier} used to materialize the store + * @param key type of the store + * @param value type of the store + * @return a new {@link Materialized} instance with the given sup + * plier + */ + public static Materialized> as(final SessionBytesStoreSupplier supplier) { + Objects.requireNonNull(supplier, "supplier can't be null"); + return new Materialized<>(supplier); + } + + /** + * Materialize a {@link KeyValueStore} using the provided {@link KeyValueBytesStoreSupplier}. + * + * @param supplier the {@link KeyValueBytesStoreSupplier} used to materialize the store + * @param key type of the store + * @param value type of the store + * @return a new {@link Materialized} instance with the given supplier + */ + public static Materialized> as(final KeyValueBytesStoreSupplier supplier) { + Objects.requireNonNull(supplier, "supplier can't be null"); + return new Materialized<>(supplier); + } + + /** + * Materialize a {@link StateStore} with the provided key and value {@link Serde}s. + * An internal name will be used for the store. + * + * @param keySerde the key {@link Serde} to use. If the {@link Serde} is null, then the default key + * serde from configs will be used + * @param valueSerde the value {@link Serde} to use. If the {@link Serde} is null, then the default value + * serde from configs will be used + * @param key type + * @param value type + * @param store type + * @return a new {@link Materialized} instance with the given key and value serdes + */ + public static Materialized with(final Serde keySerde, + final Serde valueSerde) { + return new Materialized((String) null).withKeySerde(keySerde).withValueSerde(valueSerde); + } + + /** + * Set the valueSerde the materialized {@link StateStore} will use. + * + * @param valueSerde the value {@link Serde} to use. If the {@link Serde} is null, then the default value + * serde from configs will be used. If the serialized bytes is null for put operations, + * it is treated as delete operation + * @return itself + */ + public Materialized withValueSerde(final Serde valueSerde) { + this.valueSerde = valueSerde; + return this; + } + + /** + * Set the keySerde the materialize {@link StateStore} will use. + * @param keySerde the key {@link Serde} to use. If the {@link Serde} is null, then the default key + * serde from configs will be used + * @return itself + */ + public Materialized withKeySerde(final Serde keySerde) { + this.keySerde = keySerde; + return this; + } + + /** + * Indicates that a changelog should be created for the store. The changelog will be created + * with the provided configs. + *

                + * Note: Any unrecognized configs will be ignored. + * @param config any configs that should be applied to the changelog + * @return itself + */ + public Materialized withLoggingEnabled(final Map config) { + loggingEnabled = true; + this.topicConfig = config; + return this; + } + + /** + * Disable change logging for the materialized {@link StateStore}. + * @return itself + */ + public Materialized withLoggingDisabled() { + loggingEnabled = false; + this.topicConfig.clear(); + return this; + } + + /** + * Enable caching for the materialized {@link StateStore}. + * @return itself + */ + public Materialized withCachingEnabled() { + cachingEnabled = true; + return this; + } + + /** + * Disable caching for the materialized {@link StateStore}. + * @return itself + */ + public Materialized withCachingDisabled() { + cachingEnabled = false; + return this; + } + + /** + * Configure retention period for window and session stores. Ignored for key/value stores. + * + * Overridden by pre-configured store suppliers + * ({@link Materialized#as(SessionBytesStoreSupplier)} or {@link Materialized#as(WindowBytesStoreSupplier)}). + * + * Note that the retention period must be at least long enough to contain the windowed data's entire life cycle, + * from window-start through window-end, and for the entire grace period. If not specified, the retention + * period would be set as the window length (from window-start through window-end) plus the grace period. + * + * @param retention the retention time + * @return itself + * @throws IllegalArgumentException if retention is negative or can't be represented as {@code long milliseconds} + */ + public Materialized withRetention(final Duration retention) throws IllegalArgumentException { + final String msgPrefix = prepareMillisCheckFailMsgPrefix(retention, "retention"); + final long retenationMs = validateMillisecondDuration(retention, msgPrefix); + + if (retenationMs < 0) { + throw new IllegalArgumentException("Retention must not be negative."); + } + this.retention = retention; + return this; + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/Merger.java b/streams/src/main/java/org/apache/kafka/streams/kstream/Merger.java new file mode 100644 index 0000000..6e6b01a --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/Merger.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.kstream; + + +/** + * The interface for merging aggregate values for {@link SessionWindows} with the given key. + * + * @param key type + * @param aggregate value type + */ +public interface Merger { + + /** + * Compute a new aggregate from the key and two aggregates. 
+ * + * @param aggKey the key of the record + * @param aggOne the first aggregate + * @param aggTwo the second aggregate + * @return the new aggregate value + */ + V apply(final K aggKey, final V aggOne, final V aggTwo); +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/Named.java b/streams/src/main/java/org/apache/kafka/streams/kstream/Named.java new file mode 100644 index 0000000..84bb819 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/Named.java @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.kstream; + +import org.apache.kafka.streams.errors.TopologyException; + +import java.util.Objects; + +public class Named implements NamedOperation { + + private static final int MAX_NAME_LENGTH = 249; + + protected String name; + + protected Named(final Named named) { + this(Objects.requireNonNull(named, "named can't be null").name); + } + + protected Named(final String name) { + this.name = name; + if (name != null) { + validate(name); + } + } + + /** + * Create a Named instance with provided name. + * + * @param name the processor name to be used. If {@code null} a default processor name will be generated. + * @return A new {@link Named} instance configured with name + * + * @throws TopologyException if an invalid name is specified; valid characters are ASCII alphanumerics, '.', '_' and '-'. + */ + public static Named as(final String name) { + Objects.requireNonNull(name, "name can't be null"); + return new Named(name); + } + + @Override + public Named withName(final String name) { + return new Named(name); + } + + protected static void validate(final String name) { + if (name.isEmpty()) + throw new TopologyException("Name is illegal, it can't be empty"); + if (name.equals(".") || name.equals("..")) + throw new TopologyException("Name cannot be \".\" or \"..\""); + if (name.length() > MAX_NAME_LENGTH) + throw new TopologyException("Name is illegal, it can't be longer than " + MAX_NAME_LENGTH + + " characters, name: " + name); + if (!containsValidPattern(name)) + throw new TopologyException("Name \"" + name + "\" is illegal, it contains a character other than " + + "ASCII alphanumerics, '.', '_' and '-'"); + } + + /** + * Valid characters for Kafka topics are the ASCII alphanumerics, '.', '_', and '-' + */ + private static boolean containsValidPattern(final String topic) { + for (int i = 0; i < topic.length(); ++i) { + final char c = topic.charAt(i); + + // We don't use Character.isLetterOrDigit(c) because it's slower + final boolean validLetterOrDigit = (c >= 'a' && c <= 'z') || (c >= '0' && c <= '9') || (c >= 'A' && c <= 'Z'); + final boolean validChar = validLetterOrDigit || c == '.' 
|| c == '_' || c == '-'; + if (!validChar) { + return false; + } + } + return true; + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/NamedOperation.java b/streams/src/main/java/org/apache/kafka/streams/kstream/NamedOperation.java new file mode 100644 index 0000000..9a2c40b --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/NamedOperation.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.kstream; + +/** + * Default interface which can be used to personalized the named of operations, internal topics or store. + */ +interface NamedOperation> { + + /** + * Sets the name to be used for an operation. + * + * @param name the name to use. + * @return an instance of {@link NamedOperation} + */ + T withName(final String name); + +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/Predicate.java b/streams/src/main/java/org/apache/kafka/streams/kstream/Predicate.java new file mode 100644 index 0000000..8721c05 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/Predicate.java @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.kstream; + +import org.apache.kafka.streams.KeyValue; + +/** + * The {@code Predicate} interface represents a predicate (boolean-valued function) of a {@link KeyValue} pair. + * This is a stateless record-by-record operation, i.e, {@link #test(Object, Object)} is invoked individually for each + * record of a stream. + * + * @param key type + * @param value type + * @see KStream#filter(Predicate) + * @see KStream#filterNot(Predicate) + * @see BranchedKStream#branch(Predicate) + * @see KTable#filter(Predicate) + * @see KTable#filterNot(Predicate) + */ +public interface Predicate { + + /** + * Test if the record with the given key and value satisfies the predicate. 
+ * + * @param key the key of the record + * @param value the value of the record + * @return {@code true} if the {@link KeyValue} pair satisfies the predicate—{@code false} otherwise + */ + boolean test(final K key, final V value); +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/Printed.java b/streams/src/main/java/org/apache/kafka/streams/kstream/Printed.java new file mode 100644 index 0000000..6a3d1e5 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/Printed.java @@ -0,0 +1,136 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.kstream; + +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.streams.errors.TopologyException; + +import java.io.IOException; +import java.io.OutputStream; +import java.nio.file.Files; +import java.nio.file.Paths; +import java.util.Objects; + +/** + * An object to define the options used when printing a {@link KStream}. + * + * @param key type + * @param value type + * @see KStream#print(Printed) + */ +public class Printed implements NamedOperation> { + protected final OutputStream outputStream; + protected String label; + protected String processorName; + protected KeyValueMapper mapper = + (KeyValueMapper) (key, value) -> String.format("%s, %s", key, value); + + private Printed(final OutputStream outputStream) { + this.outputStream = outputStream; + } + + /** + * Copy constructor. + * @param printed instance of {@link Printed} to copy + */ + protected Printed(final Printed printed) { + this.outputStream = printed.outputStream; + this.label = printed.label; + this.mapper = printed.mapper; + this.processorName = printed.processorName; + } + + /** + * Print the records of a {@link KStream} to a file. + * + * @param filePath path of the file + * @param key type + * @param value type + * @return a new Printed instance + */ + public static Printed toFile(final String filePath) { + Objects.requireNonNull(filePath, "filePath can't be null"); + if (Utils.isBlank(filePath)) { + throw new TopologyException("filePath can't be an empty string"); + } + try { + return new Printed<>(Files.newOutputStream(Paths.get(filePath))); + } catch (final IOException e) { + throw new TopologyException("Unable to write stream to file at [" + filePath + "] " + e.getMessage()); + } + } + + /** + * Print the records of a {@link KStream} to system out. + * + * @param key type + * @param value type + * @return a new Printed instance + */ + public static Printed toSysOut() { + return new Printed<>(System.out); + } + + /** + * Print the records of a {@link KStream} with the provided label. 
+ * + * @param label label to use + * @return this + */ + public Printed withLabel(final String label) { + Objects.requireNonNull(label, "label can't be null"); + this.label = label; + return this; + } + + /** + * Print the records of a {@link KStream} with the provided {@link KeyValueMapper} + * The provided KeyValueMapper's mapped value type must be {@code String}. + *

                + * The example below shows how to customize output data. + *

                {@code
+     * final KeyValueMapper<Integer, String, String> mapper = new KeyValueMapper<Integer, String, String>() {
                +     *     public String apply(Integer key, String value) {
                +     *         return String.format("(%d, %s)", key, value);
                +     *     }
                +     * };
                +     * }
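A sketch of wiring these options into KStream#print, assuming a KStream<Integer, String> named stream:

stream.print(Printed.<Integer, String>toSysOut()
    .withLabel("debug")
    .withKeyValueMapper((key, value) -> String.format("(%d, %s)", key, value)));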
                + * + * Implementors will need to override {@code toString()} for keys and values that are not of type {@link String}, + * {@link Integer} etc. to get meaningful information. + * + * @param mapper mapper to use + * @return this + */ + public Printed withKeyValueMapper(final KeyValueMapper mapper) { + Objects.requireNonNull(mapper, "mapper can't be null"); + this.mapper = mapper; + return this; + } + + /** + * Print the records of a {@link KStream} with provided processor name. + * + * @param processorName the processor name to be used. If {@code null} a default processor name will be generated + ** @return this + */ + @Override + public Printed withName(final String processorName) { + this.processorName = processorName; + return this; + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/Produced.java b/streams/src/main/java/org/apache/kafka/streams/kstream/Produced.java new file mode 100644 index 0000000..b14c846 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/Produced.java @@ -0,0 +1,197 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.kstream; + +import org.apache.kafka.clients.producer.internals.DefaultPartitioner; +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.streams.kstream.internals.WindowedSerializer; +import org.apache.kafka.streams.kstream.internals.WindowedStreamPartitioner; +import org.apache.kafka.streams.processor.StreamPartitioner; + +import java.util.Objects; + +/** + * This class is used to provide the optional parameters when producing to new topics + * using {@link KStream#to(String, Produced)}. + * + * @param key type + * @param value type + */ +public class Produced implements NamedOperation> { + + protected Serde keySerde; + protected Serde valueSerde; + protected StreamPartitioner partitioner; + protected String processorName; + + private Produced(final Serde keySerde, + final Serde valueSerde, + final StreamPartitioner partitioner, + final String processorName) { + this.keySerde = keySerde; + this.valueSerde = valueSerde; + this.partitioner = partitioner; + this.processorName = processorName; + } + + protected Produced(final Produced produced) { + this.keySerde = produced.keySerde; + this.valueSerde = produced.valueSerde; + this.partitioner = produced.partitioner; + this.processorName = produced.processorName; + } + + /** + * Create a Produced instance with provided keySerde and valueSerde. 
+ * @param keySerde Serde to use for serializing the key + * @param valueSerde Serde to use for serializing the value + * @param key type + * @param value type + * @return A new {@link Produced} instance configured with keySerde and valueSerde + * @see KStream#to(String, Produced) + */ + public static Produced with(final Serde keySerde, + final Serde valueSerde) { + return new Produced<>(keySerde, valueSerde, null, null); + } + + /** + * Create a Produced instance with provided keySerde, valueSerde, and partitioner. + * @param keySerde Serde to use for serializing the key + * @param valueSerde Serde to use for serializing the value + * @param partitioner the function used to determine how records are distributed among partitions of the topic, + * if not specified and {@code keySerde} provides a {@link WindowedSerializer} for the key + * {@link WindowedStreamPartitioner} will be used—otherwise {@link DefaultPartitioner} + * will be used + * @param key type + * @param value type + * @return A new {@link Produced} instance configured with keySerde, valueSerde, and partitioner + * @see KStream#to(String, Produced) + */ + public static Produced with(final Serde keySerde, + final Serde valueSerde, + final StreamPartitioner partitioner) { + return new Produced<>(keySerde, valueSerde, partitioner, null); + } + + /** + * Create an instance of {@link Produced} with provided processor name. + * + * @param processorName the processor name to be used. If {@code null} a default processor name will be generated + * @param key type + * @param value type + * @return a new instance of {@link Produced} + */ + public static Produced as(final String processorName) { + return new Produced<>(null, null, null, processorName); + } + + /** + * Create a Produced instance with provided keySerde. + * @param keySerde Serde to use for serializing the key + * @param key type + * @param value type + * @return A new {@link Produced} instance configured with keySerde + * @see KStream#to(String, Produced) + */ + public static Produced keySerde(final Serde keySerde) { + return new Produced<>(keySerde, null, null, null); + } + + /** + * Create a Produced instance with provided valueSerde. + * @param valueSerde Serde to use for serializing the key + * @param key type + * @param value type + * @return A new {@link Produced} instance configured with valueSerde + * @see KStream#to(String, Produced) + */ + public static Produced valueSerde(final Serde valueSerde) { + return new Produced<>(null, valueSerde, null, null); + } + + /** + * Create a Produced instance with provided partitioner. + * @param partitioner the function used to determine how records are distributed among partitions of the topic, + * if not specified and the key serde provides a {@link WindowedSerializer} for the key + * {@link WindowedStreamPartitioner} will be used—otherwise {@link DefaultPartitioner} will be used + * @param key type + * @param value type + * @return A new {@link Produced} instance configured with partitioner + * @see KStream#to(String, Produced) + */ + public static Produced streamPartitioner(final StreamPartitioner partitioner) { + return new Produced<>(null, null, partitioner, null); + } + + /** + * Produce records using the provided partitioner. 
+ * @param partitioner the function used to determine how records are distributed among partitions of the topic, + * if not specified and the key serde provides a {@link WindowedSerializer} for the key + * {@link WindowedStreamPartitioner} will be used—otherwise {@link DefaultPartitioner} wil be used + * @return this + */ + public Produced withStreamPartitioner(final StreamPartitioner partitioner) { + this.partitioner = partitioner; + return this; + } + + /** + * Produce records using the provided valueSerde. + * @param valueSerde Serde to use for serializing the value + * @return this + */ + public Produced withValueSerde(final Serde valueSerde) { + this.valueSerde = valueSerde; + return this; + } + + /** + * Produce records using the provided keySerde. + * @param keySerde Serde to use for serializing the key + * @return this + */ + public Produced withKeySerde(final Serde keySerde) { + this.keySerde = keySerde; + return this; + } + + @Override + public boolean equals(final Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + final Produced produced = (Produced) o; + return Objects.equals(keySerde, produced.keySerde) && + Objects.equals(valueSerde, produced.valueSerde) && + Objects.equals(partitioner, produced.partitioner); + } + + @Override + public int hashCode() { + return Objects.hash(keySerde, valueSerde, partitioner); + } + + @Override + public Produced withName(final String name) { + this.processorName = name; + return this; + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/Reducer.java b/streams/src/main/java/org/apache/kafka/streams/kstream/Reducer.java new file mode 100644 index 0000000..1acd587 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/Reducer.java @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.kstream; + +import org.apache.kafka.streams.KeyValue; + +/** + * The {@code Reducer} interface for combining two values of the same type into a new value. + * In contrast to {@link Aggregator} the result type must be the same as the input type. + *

                + * The provided values can be either original values from input {@link KeyValue} pair records or be a previously + * computed result from {@link Reducer#apply(Object, Object)}. + *

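As an illustration of the Reducer contract described here, together with the Produced options defined above, the following sketch keeps a running sum of Long values per key and writes the result to an output topic. The topic names and serdes are assumptions made for the example, not values taken from this patch.

import org.apache.kafka.common.serialization.Serdes;
import org.apache.kafka.streams.StreamsBuilder;
import org.apache.kafka.streams.kstream.Consumed;
import org.apache.kafka.streams.kstream.KTable;
import org.apache.kafka.streams.kstream.Produced;
import org.apache.kafka.streams.kstream.Reducer;

public class ReducerSketch {
    public static void main(final String[] args) {
        final StreamsBuilder builder = new StreamsBuilder();
        // The Reducer sees either two input values or an input value and the previously computed sum.
        final Reducer<Long> sum = (value1, value2) -> value1 + value2;
        final KTable<String, Long> totals = builder
            .stream("amounts", Consumed.with(Serdes.String(), Serdes.Long()))
            .groupByKey()
            .reduce(sum);
        // Produced bundles the optional key/value serdes (and, if needed, a StreamPartitioner) for the sink topic.
        totals.toStream().to("totals", Produced.with(Serdes.String(), Serdes.Long()));
    }
}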
                + * {@code Reducer} can be used to implement aggregation functions like sum, min, or max. + * + * @param value type + * @see KGroupedStream#reduce(Reducer) + * @see KGroupedStream#reduce(Reducer, Materialized) + * @see TimeWindowedKStream#reduce(Reducer) + * @see TimeWindowedKStream#reduce(Reducer, Materialized) + * @see SessionWindowedKStream#reduce(Reducer) + * @see SessionWindowedKStream#reduce(Reducer, Materialized) + * @see Aggregator + */ +public interface Reducer { + + /** + * Aggregate the two given values into a single one. + * + * @param value1 the first value for the aggregation + * @param value2 the second value for the aggregation + * @return the aggregated value + */ + V apply(final V value1, final V value2); +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/Repartitioned.java b/streams/src/main/java/org/apache/kafka/streams/kstream/Repartitioned.java new file mode 100644 index 0000000..40f66f0 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/Repartitioned.java @@ -0,0 +1,171 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.kstream; + +import org.apache.kafka.clients.producer.internals.DefaultPartitioner; +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.streams.kstream.internals.WindowedSerializer; +import org.apache.kafka.streams.kstream.internals.WindowedStreamPartitioner; +import org.apache.kafka.streams.processor.StreamPartitioner; + +/** + * This class is used to provide the optional parameters for internal repartition topics. + * + * @param key type + * @param value type + * @see KStream#repartition() + * @see KStream#repartition(Repartitioned) + */ +public class Repartitioned implements NamedOperation> { + + protected final String name; + protected final Serde keySerde; + protected final Serde valueSerde; + protected final Integer numberOfPartitions; + protected final StreamPartitioner partitioner; + + private Repartitioned(final String name, + final Serde keySerde, + final Serde valueSerde, + final Integer numberOfPartitions, + final StreamPartitioner partitioner) { + this.name = name; + this.keySerde = keySerde; + this.valueSerde = valueSerde; + this.numberOfPartitions = numberOfPartitions; + this.partitioner = partitioner; + } + + protected Repartitioned(final Repartitioned repartitioned) { + this( + repartitioned.name, + repartitioned.keySerde, + repartitioned.valueSerde, + repartitioned.numberOfPartitions, + repartitioned.partitioner + ); + } + + /** + * Create a {@code Repartitioned} instance with the provided name used as part of the repartition topic. + * + * @param name the name used as a processor named and part of the repartition topic name. 
+ * @param key type + * @param value type + * @return A new {@code Repartitioned} instance configured with processor name and repartition topic name + * @see KStream#repartition(Repartitioned) + */ + public static Repartitioned as(final String name) { + return new Repartitioned<>(name, null, null, null, null); + } + + /** + * Create a {@code Repartitioned} instance with provided key serde and value serde. + * + * @param keySerde Serde to use for serializing the key + * @param valueSerde Serde to use for serializing the value + * @param key type + * @param value type + * @return A new {@code Repartitioned} instance configured with key serde and value serde + * @see KStream#repartition(Repartitioned) + */ + public static Repartitioned with(final Serde keySerde, + final Serde valueSerde) { + return new Repartitioned<>(null, keySerde, valueSerde, null, null); + } + + /** + * Create a {@code Repartitioned} instance with provided partitioner. + * + * @param partitioner the function used to determine how records are distributed among partitions of the topic, + * if not specified and the key serde provides a {@link WindowedSerializer} for the key + * {@link WindowedStreamPartitioner} will be used—otherwise {@link DefaultPartitioner} will be used + * @param key type + * @param value type + * @return A new {@code Repartitioned} instance configured with partitioner + * @see KStream#repartition(Repartitioned) + */ + public static Repartitioned streamPartitioner(final StreamPartitioner partitioner) { + return new Repartitioned<>(null, null, null, null, partitioner); + } + + /** + * Create a {@code Repartitioned} instance with provided number of partitions for repartition topic. + * + * @param numberOfPartitions number of partitions used when creating repartition topic + * @param key type + * @param value type + * @return A new {@code Repartitioned} instance configured number of partitions + * @see KStream#repartition(Repartitioned) + */ + public static Repartitioned numberOfPartitions(final int numberOfPartitions) { + return new Repartitioned<>(null, null, null, numberOfPartitions, null); + } + + /** + * Create a new instance of {@code Repartitioned} with the provided name used as part of repartition topic and processor name. + * + * @param name the name used for the processor name and as part of the repartition topic + * @return a new {@code Repartitioned} instance configured with the name + */ + @Override + public Repartitioned withName(final String name) { + return new Repartitioned<>(name, keySerde, valueSerde, numberOfPartitions, partitioner); + } + + /** + * Create a new instance of {@code Repartitioned} with the provided number of partitions for repartition topic. + * + * @param numberOfPartitions the name used for the processor name and as part of the repartition topic name + * @return a new {@code Repartitioned} instance configured with the number of partitions + */ + public Repartitioned withNumberOfPartitions(final int numberOfPartitions) { + return new Repartitioned<>(name, keySerde, valueSerde, numberOfPartitions, partitioner); + } + + /** + * Create a new instance of {@code Repartitioned} with the provided key serde. 
+ * + * @param keySerde Serde to use for serializing the key + * @return a new {@code Repartitioned} instance configured with the key serde + */ + public Repartitioned withKeySerde(final Serde keySerde) { + return new Repartitioned<>(name, keySerde, valueSerde, numberOfPartitions, partitioner); + } + + /** + * Create a new instance of {@code Repartitioned} with the provided value serde. + * + * @param valueSerde Serde to use for serializing the value + * @return a new {@code Repartitioned} instance configured with the value serde + */ + public Repartitioned withValueSerde(final Serde valueSerde) { + return new Repartitioned<>(name, keySerde, valueSerde, numberOfPartitions, partitioner); + } + + /** + * Create a new instance of {@code Repartitioned} with the provided partitioner. + * + * @param partitioner the function used to determine how records are distributed among partitions of the topic, + * if not specified and the key serde provides a {@link WindowedSerializer} for the key + * {@link WindowedStreamPartitioner} will be used—otherwise {@link DefaultPartitioner} wil be used + * @return a new {@code Repartitioned} instance configured with provided partitioner + */ + public Repartitioned withStreamPartitioner(final StreamPartitioner partitioner) { + return new Repartitioned<>(name, keySerde, valueSerde, numberOfPartitions, partitioner); + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/SessionWindowedCogroupedKStream.java b/streams/src/main/java/org/apache/kafka/streams/kstream/SessionWindowedCogroupedKStream.java new file mode 100644 index 0000000..b7e3b07 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/SessionWindowedCogroupedKStream.java @@ -0,0 +1,265 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.kstream; + +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.streams.KafkaStreams; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.StoreQueryParameters; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.Topology; +import org.apache.kafka.streams.state.SessionStore; + +import java.time.Duration; + +/** + * {@code SessionWindowedCogroupKStream} is an abstraction of a windowed record stream of {@link KeyValue} pairs. + * It is an intermediate representation of a {@link CogroupedKStream} in order to apply a windowed aggregation operation + * on the original {@link KGroupedStream} records resulting in a windowed {@link KTable} (a windowed + * {@code KTable} is a {@link KTable} with key type {@link Windowed Windowed}). + *

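For the Repartitioned class defined just above, a typical call site looks like the following sketch; the topic name, serdes, and partition count are illustrative assumptions rather than values taken from this patch.

import org.apache.kafka.common.serialization.Serdes;
import org.apache.kafka.streams.StreamsBuilder;
import org.apache.kafka.streams.kstream.Consumed;
import org.apache.kafka.streams.kstream.KStream;
import org.apache.kafka.streams.kstream.Repartitioned;

public class RepartitionedSketch {
    public static void main(final String[] args) {
        final StreamsBuilder builder = new StreamsBuilder();
        final KStream<String, Long> counts =
            builder.stream("counts", Consumed.with(Serdes.String(), Serdes.Long()));
        // "counts-by-key" becomes the processor name and part of the internal repartition topic name.
        final KStream<String, Long> repartitioned = counts.repartition(
            Repartitioned.<String, Long>as("counts-by-key")
                .withKeySerde(Serdes.String())
                .withValueSerde(Serdes.Long())
                .withNumberOfPartitions(6));
    }
}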
                + * {@link SessionWindows} are dynamic data driven windows. + * They have no fixed time boundaries, rather the size of the window is determined by the records. + *

                + * The result is written into a local {@link SessionStore} (which is basically an ever-updating + * materialized view) that can be queried using the name provided in the {@link Materialized} instance. + * Furthermore, updates to the store are sent downstream into a windowed {@link KTable} changelog stream, where + * "windowed" implies that the {@link KTable} key is a combined key of the original record key and a window ID. + * New events are added to sessions until their grace period ends (see {@link SessionWindows#grace(Duration)}). + *

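Before the formal method descriptions below, here is a minimal end-to-end sketch of how such a cogrouped, session-windowed aggregation is assembled; the topic names, the 5-minute inactivity gap, and the Long arithmetic are assumptions for illustration only.

import java.time.Duration;

import org.apache.kafka.common.serialization.Serdes;
import org.apache.kafka.streams.StreamsBuilder;
import org.apache.kafka.streams.kstream.Aggregator;
import org.apache.kafka.streams.kstream.Consumed;
import org.apache.kafka.streams.kstream.Grouped;
import org.apache.kafka.streams.kstream.KGroupedStream;
import org.apache.kafka.streams.kstream.KTable;
import org.apache.kafka.streams.kstream.SessionWindows;
import org.apache.kafka.streams.kstream.Windowed;

public class CogroupedSessionSketch {
    public static void main(final String[] args) {
        final StreamsBuilder builder = new StreamsBuilder();
        final KGroupedStream<String, Long> purchases = builder
            .stream("purchases", Consumed.with(Serdes.String(), Serdes.Long()))
            .groupByKey(Grouped.with(Serdes.String(), Serdes.Long()));
        final KGroupedStream<String, Long> refunds = builder
            .stream("refunds", Consumed.with(Serdes.String(), Serdes.Long()))
            .groupByKey(Grouped.with(Serdes.String(), Serdes.Long()));

        final Aggregator<String, Long, Long> add = (key, value, aggregate) -> aggregate + value;
        final Aggregator<String, Long, Long> subtract = (key, value, aggregate) -> aggregate - value;

        final KTable<Windowed<String>, Long> netPerSession = purchases
            .cogroup(add)                           // per-stream Aggregator for the first input
            .cogroup(refunds, subtract)             // per-stream Aggregator for the second input
            .windowedBy(SessionWindows.with(Duration.ofMinutes(5)))
            .aggregate(
                () -> 0L,                           // Initializer for a new session
                (key, left, right) -> left + right  // Merger for overlapping sessions
            );
    }
}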
                + * A {@code SessionWindowedCogroupedKStream} must be obtained from a {@link CogroupedKStream} via + * {@link CogroupedKStream#windowedBy(SessionWindows)}. + * + * @param Type of keys + * @param Type of values + * @see KStream + * @see KGroupedStream + * @see SessionWindows + * @see CogroupedKStream + */ +public interface SessionWindowedCogroupedKStream { + + /** + * Aggregate the values of records in these streams by the grouped key and defined sessions. + * Note that sessions are generated on a per-key basis and records with different keys create independent sessions. + * Records with {@code null} key or value are ignored. + * The result is written into a local {@link SessionStore} (which is basically an ever-updating materialized view). + * Furthermore, updates to the store are sent downstream into a {@link KTable} changelog stream. + *

                + * The specified {@link Initializer} is applied directly before the first input record per session is processed to + * provide an initial intermediate aggregation result that is used to process the first record per session. + * The specified {@link Aggregator} (as specified in {@link KGroupedStream#cogroup(Aggregator)} or + * {@link CogroupedKStream#cogroup(KGroupedStream, Aggregator)}) is applied for each input record and computes a new + * aggregate using the current aggregate (or for the very first record using the intermediate aggregation result + * provided via the {@link Initializer}) and the record's value. + * The specified {@link Merger} is used to merge two existing sessions into one, i.e., when the windows overlap, + * they are merged into a single session and the old sessions are discarded. + * Thus, {@code aggregate()} can be used to compute aggregate functions like count or sum etc. + *

                + * The default key and value serde from the config will be used for serializing the result. + * If a different serde is required then you should use {@link #aggregate(Initializer, Merger, Materialized)}. + *

                + * Not all updates might get sent downstream, as an internal cache is used to deduplicate consecutive updates to + * the same window and key. + * The rate of propagated updates depends on your input data rate, the number of distinct keys, the number of + * parallel running Kafka Streams instances, and the {@link StreamsConfig configuration} parameters for + * {@link StreamsConfig#CACHE_MAX_BYTES_BUFFERING_CONFIG cache size}, and + * {@link StreamsConfig#COMMIT_INTERVAL_MS_CONFIG commit interval}. + *

                + * For failure and recovery the store will be backed by an internal changelog topic that will be created in Kafka. + * The changelog topic will be named "${applicationId}-${internalStoreName}-changelog", where "applicationId" is + * user-specified in {@link StreamsConfig} via parameter + * {@link StreamsConfig#APPLICATION_ID_CONFIG APPLICATION_ID_CONFIG}, "internalStoreName" is an internal name + * and "-changelog" is a fixed suffix. + * Note that the internal store name may not be queryable through Interactive Queries. + *

                + * You can retrieve all generated internal topic names via {@link Topology#describe()}. + * + * @param initializer an {@link Initializer} that computes an initial intermediate aggregation result. Cannot be {@code null}. + * @param sessionMerger a {@link Merger} that combines two aggregation results. Cannot be {@code null}. + * @return a windowed {@link KTable} that contains "update" records with unmodified keys, and values that represent + * the latest (rolling) aggregate for each key per session + */ + KTable, V> aggregate(final Initializer initializer, + final Merger sessionMerger); + + /** + * Aggregate the values of records in these streams by the grouped key and defined sessions. + * Note that sessions are generated on a per-key basis and records with different keys create independent sessions. + * Records with {@code null} key or value are ignored. + * The result is written into a local {@link SessionStore} (which is basically an ever-updating materialized view). + * Furthermore, updates to the store are sent downstream into a {@link KTable} changelog stream. + *

                + * The specified {@link Initializer} is applied directly before the first input record per session is processed to + * provide an initial intermediate aggregation result that is used to process the first record per session. + * The specified {@link Aggregator} (as specified in {@link KGroupedStream#cogroup(Aggregator)} or + * {@link CogroupedKStream#cogroup(KGroupedStream, Aggregator)}) is applied for each input record and computes a new + * aggregate using the current aggregate (or for the very first record using the intermediate aggregation result + * provided via the {@link Initializer}) and the record's value. + * The specified {@link Merger} is used to merge two existing sessions into one, i.e., when the windows overlap, + * they are merged into a single session and the old sessions are discarded. + * Thus, {@code aggregate()} can be used to compute aggregate functions like count or sum etc. + *

                + * The default key and value serde from the config will be used for serializing the result. + * If a different serde is required then you should use + * {@link #aggregate(Initializer, Merger, Named, Materialized)}. + *

                + * Not all updates might get sent downstream, as an internal cache is used to deduplicate consecutive updates to + * the same window and key. + * The rate of propagated updates depends on your input data rate, the number of distinct + * keys, the number of parallel running Kafka Streams instances, and the {@link StreamsConfig configuration} + * parameters for {@link StreamsConfig#CACHE_MAX_BYTES_BUFFERING_CONFIG cache size}, and + * {@link StreamsConfig#COMMIT_INTERVAL_MS_CONFIG commit interval}. + *

                + * For failure and recovery the store will be backed by an internal changelog topic that will be created in Kafka. + * The changelog topic will be named "${applicationId}-${internalStoreName}-changelog", where "applicationId" is + * user-specified in {@link StreamsConfig} via parameter + * {@link StreamsConfig#APPLICATION_ID_CONFIG APPLICATION_ID_CONFIG}, "internalStoreName" is an internal name + * and "-changelog" is a fixed suffix. + * Note that the internal store name may not be queryable through Interactive Queries. + *

                + * You can retrieve all generated internal topic names via {@link Topology#describe()}. + * + * @param initializer an {@link Initializer} that computes an initial intermediate aggregation result. Cannot be {@code null}. + * @param sessionMerger a {@link Merger} that combines two aggregation results. Cannot be {@code null}. + * @param named a {@link Named} config used to name the processor in the topology. Cannot be {@code null}. + * @return a windowed {@link KTable} that contains "update" records with unmodified keys, and values that represent + * the latest (rolling) aggregate for each key per session + */ + KTable, V> aggregate(final Initializer initializer, + final Merger sessionMerger, + final Named named); + + /** + * Aggregate the values of records in these streams by the grouped key and defined sessions. + * Records with {@code null} key or value are ignored. + * The result is written into a local {@link SessionStore} (which is basically an ever-updating materialized view) + * that can be queried using the store name as provided with {@link Materialized}. + * Furthermore, updates to the store are sent downstream into a {@link KTable} changelog stream. + *

                + * The specified {@link Initializer} is applied directly before the first input record (per key) in each window is + * processed to provide an initial intermediate aggregation result that is used to process the first record for + * the session (per key). + * The specified {@link Aggregator} (as specified in {@link KGroupedStream#cogroup(Aggregator)} or + * {@link CogroupedKStream#cogroup(KGroupedStream, Aggregator)}) is applied for each input record and computes a new + * aggregate using the current aggregate (or for the very first record using the intermediate aggregation result + * provided via the {@link Initializer}) and the record's value. + * The specified {@link Merger} is used to merge two existing sessions into one, i.e., when the windows overlap, + * they are merged into a single session and the old sessions are discarded. + * Thus, {@code aggregate()} can be used to compute aggregate functions like count or sum etc. + *

                + * Not all updates might get sent downstream, as an internal cache is used to deduplicate consecutive updates to + * the same window and key if caching is enabled on the {@link Materialized} instance. + * When caching is enabled the rate of propagated updates depends on your input data rate, the number of distinct keys, the number of + * parallel running Kafka Streams instances, and the {@link StreamsConfig configuration} parameters for + * {@link StreamsConfig#CACHE_MAX_BYTES_BUFFERING_CONFIG cache size}, and + * {@link StreamsConfig#COMMIT_INTERVAL_MS_CONFIG commit interval}. + *

                + * To query the local {@link SessionStore} it must be obtained via + * {@link KafkaStreams#store(StoreQueryParameters) KafkaStreams#store(...)}: + *

                {@code
                +     * KafkaStreams streams = ... // counting words
+     * String queryableStoreName = ... // the queryableStoreName should be the name of the store as defined by the Materialized instance
+     * ReadOnlySessionStore<String, Long> localWindowStore = streams.store(queryableStoreName, QueryableStoreTypes.sessionStore());
+     *
+     * String key = "some-word";
+     * KeyValueIterator<Windowed<String>, Long> aggregateStore = localWindowStore.fetch(key); // key must be local (application state is shared over all running Kafka Streams instances)
                +     * }
                + * For non-local keys, a custom RPC mechanism must be implemented using {@link KafkaStreams#metadataForAllStreamsClients()} to + * query the value of the key on a parallel running instance of your Kafka Streams application. + *

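A hedged sketch of the discovery half of such an RPC layer, using only the metadata call mentioned above; the store name is assumed and StreamsMetadata is the metadata type returned by that call. The serving side of the RPC is left out entirely.

// Assumes `streams` is a started KafkaStreams instance and "clicks-per-session" is an existing store name.
for (final StreamsMetadata metadata : streams.metadataForAllStreamsClients()) {
    if (metadata.stateStoreNames().contains("clicks-per-session")) {
        // Each matching host would then be contacted via the application's own RPC mechanism.
        System.out.println("store hosted on " + metadata.host() + ":" + metadata.port());
    }
}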
                + * For failure and recovery the store will be backed by an internal changelog topic that will be created in Kafka. + * Therefore, the store name defined by the {@link Materialized} instance must be a valid Kafka topic name and + * cannot contain characters other than ASCII alphanumerics, '.', '_' and '-'. + * The changelog topic will be named "${applicationId}-${storeName}-changelog", where "applicationId" is + * user-specified in {@link StreamsConfig} via parameter + * {@link StreamsConfig#APPLICATION_ID_CONFIG APPLICATION_ID_CONFIG}, "storeName" is the + * provide store name defined in {@link Materialized}, and "-changelog" is a fixed suffix. + *

                + * You can retrieve all generated internal topic names via {@link Topology#describe()}. + * + * @param initializer an {@link Initializer} that computes an initial intermediate aggregation result. Cannot be {@code null}. + * @param sessionMerger a {@link Merger} that combines two aggregation results. Cannot be {@code null}. + * @param materialized a {@link Materialized} config used to materialize a state store. Cannot be {@code null}. + * @return a windowed {@link KTable} that contains "update" records with unmodified keys, and values that represent + * the latest (rolling) aggregate for each key within a window + */ + KTable, V> aggregate(final Initializer initializer, + final Merger sessionMerger, + final Materialized> materialized); + + /** + * Aggregate the values of records in these streams by the grouped key and defined sessions. + * Records with {@code null} key or value are ignored. + * The result is written into a local {@link SessionStore} (which is basically an ever-updating materialized view) + * that can be queried using the store name as provided with {@link Materialized}. + * Furthermore, updates to the store are sent downstream into a {@link KTable} changelog stream. + *

                + * The specified {@link Initializer} is applied directly before the first input record (per key) in each window is + * processed to provide an initial intermediate aggregation result that is used to process the first record for + * the session (per key). + * The specified {@link Aggregator} (as specified in {@link KGroupedStream#cogroup(Aggregator)} or + * {@link CogroupedKStream#cogroup(KGroupedStream, Aggregator)}) is applied for each input record and computes a new + * aggregate using the current aggregate (or for the very first record using the intermediate aggregation result + * provided via the {@link Initializer}) and the record's value. + * The specified {@link Merger} is used to merge two existing sessions into one, i.e., when the windows overlap, + * they are merged into a single session and the old sessions are discarded. + * Thus, {@code aggregate()} can be used to compute aggregate functions like count or sum etc. + *

                + * Not all updates might get sent downstream, as an internal cache will be used to deduplicate consecutive updates + * to the same window and key if caching is enabled on the {@link Materialized} instance. + * When caching is enabled the rate of propagated updates depends on your input data rate, the number of distinct + * keys, the number of parallel running Kafka Streams instances, and the {@link StreamsConfig configuration} + * parameters for {@link StreamsConfig#CACHE_MAX_BYTES_BUFFERING_CONFIG cache size}, and + * {@link StreamsConfig#COMMIT_INTERVAL_MS_CONFIG commit interval}. + *

+ * To query the local {@link SessionStore} it must be obtained via + * {@link KafkaStreams#store(StoreQueryParameters) KafkaStreams#store(...)}: + *

                {@code
+     * KafkaStreams streams = ... // some session-windowed aggregation on value type Long
+     * String queryableStoreName = ... // the queryableStoreName should be the name of the store as defined by the Materialized instance
+     * ReadOnlySessionStore<String, Long> sessionStore = streams.store(queryableStoreName, QueryableStoreTypes.sessionStore());
+     * String key = "some-key";
+     * KeyValueIterator<Windowed<String>, Long> aggForKeyForSession = sessionStore.fetch(key); // key must be local (application state is shared over all running Kafka Streams instances)
                +     * }
                + * For non-local keys, a custom RPC mechanism must be implemented using {@link KafkaStreams#metadataForAllStreamsClients()} to + * query the value of the key on a parallel running instance of your Kafka Streams application. + *

                + * For failure and recovery the store will be backed by an internal changelog topic that will be created in Kafka. + * Therefore, the store name defined by the {@link Materialized} instance must be a valid Kafka topic name and + * cannot contain characters other than ASCII alphanumerics, '.', '_' and '-'. + * The changelog topic will be named "${applicationId}-${storeName}-changelog", where "applicationId" is + * user-specified in {@link StreamsConfig} via parameter + * {@link StreamsConfig#APPLICATION_ID_CONFIG APPLICATION_ID_CONFIG}, "storeName" is the + * provide store name defined in {@link Materialized}, and "-changelog" is a fixed suffix. + *

                + * You can retrieve all generated internal topic names via {@link Topology#describe()}. + * + * @param initializer an {@link Initializer} that computes an initial intermediate aggregation result. Cannot be {@code null}. + * @param sessionMerger a {@link Merger} that combines two aggregation results. Cannot be {@code null}. + * @param named a {@link Named} config used to name the processor in the topology. Cannot be {@code null}. + * @param materialized a {@link Materialized} config used to materialize a state store. Cannot be {@code null}. + * @return a windowed {@link KTable} that contains "update" records with unmodified keys, and values that represent + * the latest (rolling) aggregate for each key per session + */ + KTable, V> aggregate(final Initializer initializer, + final Merger sessionMerger, + final Named named, + final Materialized> materialized); +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/SessionWindowedDeserializer.java b/streams/src/main/java/org/apache/kafka/streams/kstream/SessionWindowedDeserializer.java new file mode 100644 index 0000000..e77efe7 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/SessionWindowedDeserializer.java @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.kstream; + +import org.apache.kafka.common.config.ConfigException; +import org.apache.kafka.common.serialization.Deserializer; +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.state.internals.SessionKeySchema; + +import java.util.Map; + +public class SessionWindowedDeserializer implements Deserializer> { + + private Deserializer inner; + + // Default constructor needed by Kafka + public SessionWindowedDeserializer() {} + + public SessionWindowedDeserializer(final Deserializer inner) { + this.inner = inner; + } + + @SuppressWarnings("unchecked") + @Override + public void configure(final Map configs, final boolean isKey) { + final String windowedInnerClassSerdeConfig = (String) configs.get(StreamsConfig.WINDOWED_INNER_CLASS_SERDE); + + Serde windowInnerClassSerde = null; + + if (windowedInnerClassSerdeConfig != null) { + try { + windowInnerClassSerde = Utils.newInstance(windowedInnerClassSerdeConfig, Serde.class); + } catch (final ClassNotFoundException e) { + throw new ConfigException(StreamsConfig.WINDOWED_INNER_CLASS_SERDE, windowedInnerClassSerdeConfig, + "Serde class " + windowedInnerClassSerdeConfig + " could not be found."); + } + } + + if (inner != null && windowedInnerClassSerdeConfig != null) { + if (!inner.getClass().getName().equals(windowInnerClassSerde.deserializer().getClass().getName())) { + throw new IllegalArgumentException("Inner class deserializer set using constructor " + + "(" + inner.getClass().getName() + ")" + + " is different from the one set in windowed.inner.class.serde config " + + "(" + windowInnerClassSerde.deserializer().getClass().getName() + ")."); + } + } else if (inner == null && windowedInnerClassSerdeConfig == null) { + throw new IllegalArgumentException("Inner class deserializer should be set either via constructor " + + "or via the windowed.inner.class.serde config"); + } else if (inner == null) + inner = windowInnerClassSerde.deserializer(); + } + + @Override + public Windowed deserialize(final String topic, final byte[] data) { + WindowedSerdes.verifyInnerDeserializerNotNull(inner, this); + + if (data == null || data.length == 0) { + return null; + } + + // for either key or value, their schema is the same hence we will just use session key schema + return SessionKeySchema.from(data, inner, topic); + } + + @Override + public void close() { + if (inner != null) { + inner.close(); + } + } + + // Only for testing + Deserializer innerDeserializer() { + return inner; + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/SessionWindowedKStream.java b/streams/src/main/java/org/apache/kafka/streams/kstream/SessionWindowedKStream.java new file mode 100644 index 0000000..1b7a363 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/SessionWindowedKStream.java @@ -0,0 +1,646 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.kstream; + +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.streams.KafkaStreams; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.StoreQueryParameters; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.Topology; +import org.apache.kafka.streams.state.SessionStore; + +import java.time.Duration; + +/** + * {@code SessionWindowedKStream} is an abstraction of a windowed record stream of {@link KeyValue} pairs. + * It is an intermediate representation after a grouping and windowing of a {@link KStream} before an aggregation is + * applied to the new (partitioned) windows resulting in a windowed {@link KTable} (a windowed + * {@code KTable} is a {@link KTable} with key type {@link Windowed Windowed}). + *

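For the SessionWindowedDeserializer introduced above, the inner deserializer can be supplied either through the constructor or through the windowed.inner.class.serde config that its configure() method reads; a sketch of the config-driven path, with the String inner type chosen purely for illustration:

import java.util.HashMap;
import java.util.Map;

import org.apache.kafka.common.serialization.Serdes;
import org.apache.kafka.streams.StreamsConfig;
import org.apache.kafka.streams.kstream.SessionWindowedDeserializer;

public class SessionWindowedDeserializerSketch {
    public static void main(final String[] args) {
        final Map<String, Object> configs = new HashMap<>();
        // configure() instantiates this serde reflectively and uses its deserializer for the inner (key) type.
        configs.put(StreamsConfig.WINDOWED_INNER_CLASS_SERDE, Serdes.StringSerde.class.getName());

        final SessionWindowedDeserializer<String> deserializer = new SessionWindowedDeserializer<>();
        deserializer.configure(configs, true); // isKey = true
        // deserializer.deserialize(topic, bytes) now returns Windowed<String> keys.
    }
}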
                + * {@link SessionWindows} are dynamic data driven windows. + * They have no fixed time boundaries, rather the size of the window is determined by the records. + *

                + * The result is written into a local {@link SessionStore} (which is basically an ever-updating + * materialized view) that can be queried using the name provided in the {@link Materialized} instance. + * Furthermore, updates to the store are sent downstream into a windowed {@link KTable} changelog stream, where + * "windowed" implies that the {@link KTable} key is a combined key of the original record key and a window ID. + * New events are added to sessions until their grace period ends (see {@link SessionWindows#grace(Duration)}). + *

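A compact sketch of how a SessionWindowedKStream is obtained and counted; the topic name and the 5-minute inactivity gap are illustrative assumptions.

import java.time.Duration;

import org.apache.kafka.common.serialization.Serdes;
import org.apache.kafka.streams.StreamsBuilder;
import org.apache.kafka.streams.kstream.Consumed;
import org.apache.kafka.streams.kstream.KTable;
import org.apache.kafka.streams.kstream.SessionWindows;
import org.apache.kafka.streams.kstream.Windowed;

public class SessionCountSketch {
    public static void main(final String[] args) {
        final StreamsBuilder builder = new StreamsBuilder();
        final KTable<Windowed<String>, Long> clicksPerSession = builder
            .stream("clicks", Consumed.with(Serdes.String(), Serdes.String()))
            .groupByKey()
            .windowedBy(SessionWindows.with(Duration.ofMinutes(5))) // a session closes after 5 minutes of inactivity
            .count();
    }
}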
                + * A {@code SessionWindowedKStream} must be obtained from a {@link KGroupedStream} via + * {@link KGroupedStream#windowedBy(SessionWindows)}. + * + * @param Type of keys + * @param Type of values + * @see KStream + * @see KGroupedStream + * @see SessionWindows + */ +public interface SessionWindowedKStream { + + /** + * Count the number of records in this stream by the grouped key and defined sessions. + * Note that sessions are generated on a per-key basis and records with different keys create independent sessions. + * Records with {@code null} key or value are ignored. + *

                + * The result is written into a local {@link SessionStore} (which is basically an ever-updating materialized view). + * The default key serde from the config will be used for serializing the result. + * If a different serde is required then you should use {@link #count(Materialized)}. + * Furthermore, updates to the store are sent downstream into a {@link KTable} changelog stream. + * Not all updates might get sent downstream, as an internal cache is used to deduplicate consecutive updates to + * the same session and key. + * The rate of propagated updates depends on your input data rate, the number of distinct keys, the number of + * parallel running Kafka Streams instances, and the {@link StreamsConfig configuration} parameters for + * {@link StreamsConfig#CACHE_MAX_BYTES_BUFFERING_CONFIG cache size}, and + * {@link StreamsConfig#COMMIT_INTERVAL_MS_CONFIG commit interval}. + *

                + * For failure and recovery the store will be backed by an internal changelog topic that will be created in Kafka. + * The changelog topic will be named "${applicationId}-${internalStoreName}-changelog", where "applicationId" is + * user-specified in {@link StreamsConfig StreamsConfig} via parameter + * {@link StreamsConfig#APPLICATION_ID_CONFIG APPLICATION_ID_CONFIG}, "internalStoreName" is an internal name + * and "-changelog" is a fixed suffix. + * Note that the internal store name may not be queryable through Interactive Queries. + *

                + * You can retrieve all generated internal topic names via {@link Topology#describe()}. + * + * @return a windowed {@link KTable} that contains "update" records with unmodified keys and {@link Long} values + * that represent the latest (rolling) count (i.e., number of records) for each key per session + */ + KTable, Long> count(); + + /** + * Count the number of records in this stream by the grouped key and defined sessions. + * Note that sessions are generated on a per-key basis and records with different keys create independent sessions. + * Records with {@code null} key or value are ignored. + *

                + * The result is written into a local {@link SessionStore} (which is basically an ever-updating materialized view). + * The default key serde from the config will be used for serializing the result. + * If a different serde is required then you should use {@link #count(Named, Materialized)}. + * Furthermore, updates to the store are sent downstream into a {@link KTable} changelog stream. + * Not all updates might get sent downstream, as an internal cache is used to deduplicate consecutive updates to + * the same session and key. + * The rate of propagated updates depends on your input data rate, the number of distinct keys, the number of + * parallel running Kafka Streams instances, and the {@link StreamsConfig configuration} parameters for + * {@link StreamsConfig#CACHE_MAX_BYTES_BUFFERING_CONFIG cache size}, and + * {@link StreamsConfig#COMMIT_INTERVAL_MS_CONFIG commit interval}. + *

                + * For failure and recovery the store will be backed by an internal changelog topic that will be created in Kafka. + * The changelog topic will be named "${applicationId}-${internalStoreName}-changelog", where "applicationId" is + * user-specified in {@link StreamsConfig StreamsConfig} via parameter + * {@link StreamsConfig#APPLICATION_ID_CONFIG APPLICATION_ID_CONFIG}, "internalStoreName" is an internal name + * and "-changelog" is a fixed suffix. + * Note that the internal store name may not be queryable through Interactive Queries. + *

                + * You can retrieve all generated internal topic names via {@link Topology#describe()}. + * + * @param named a {@link Named} config used to name the processor in the topology. Cannot be {@code null}. + * @return a windowed {@link KTable} that contains "update" records with unmodified keys and {@link Long} values + * that represent the latest (rolling) count (i.e., number of records) for each key per session + */ + KTable, Long> count(final Named named); + + /** + * Count the number of records in this stream by the grouped key and defined sessions. + * Note that sessions are generated on a per-key basis and records with different keys create independent sessions. + * Records with {@code null} key or value are ignored. + *

                + * The result is written into a local {@link SessionStore} (which is basically an ever-updating materialized view) + * that can be queried using the name provided with {@link Materialized}. + * Furthermore, updates to the store are sent downstream into a {@link KTable} changelog stream. + *

                + * Not all updates might get sent downstream, as an internal cache will be used to deduplicate consecutive updates + * to the same window and key if caching is enabled on the {@link Materialized} instance. + * When caching is enabled the rate of propagated updates depends on your input data rate, the number of distinct + * keys, the number of parallel running Kafka Streams instances, and the {@link StreamsConfig configuration} + * parameters for {@link StreamsConfig#CACHE_MAX_BYTES_BUFFERING_CONFIG cache size}, and + * {@link StreamsConfig#COMMIT_INTERVAL_MS_CONFIG commit interval}. + *

                + * To query the local {@link SessionStore} it must be obtained via + * {@link KafkaStreams#store(StoreQueryParameters) KafkaStreams#store(...)}: + *

                {@code
+     * KafkaStreams streams = ... // counting records per session
+     * String queryableStoreName = ... // the queryableStoreName should be the name of the store as defined by the Materialized instance
+     * ReadOnlySessionStore<String, Long> localWindowStore = streams.store(queryableStoreName, QueryableStoreTypes.sessionStore());
+     * String key = "some-key";
+     * KeyValueIterator<Windowed<String>, Long> countForKeyForSessions = localWindowStore.fetch(key); // key must be local (application state is shared over all running Kafka Streams instances)
                +     * }
                + * For non-local keys, a custom RPC mechanism must be implemented using {@link KafkaStreams#metadataForAllStreamsClients()} to + * query the value of the key on a parallel running instance of your Kafka Streams application. + *

                + * For failure and recovery the store will be backed by an internal changelog topic that will be created in Kafka. + * Therefore, the store name defined by the Materialized instance must be a valid Kafka topic name and cannot + * contain characters other than ASCII alphanumerics, '.', '_' and '-'. + * The changelog topic will be named "${applicationId}-${storeName}-changelog", where "applicationId" is + * user-specified in {@link StreamsConfig} via parameter + * {@link StreamsConfig#APPLICATION_ID_CONFIG APPLICATION_ID_CONFIG}, "storeName" is the provide store name defined + * in {@code Materialized}, and "-changelog" is a fixed suffix. + *

                + * You can retrieve all generated internal topic names via {@link Topology#describe()}. + * + * @param materialized an instance of {@link Materialized} used to materialize a state store. Cannot be {@code null}. + * Note: the valueSerde will be automatically set to {@link org.apache.kafka.common.serialization.Serdes#Long() Serdes#Long()} + * if there is no valueSerde provided + * @return a windowed {@link KTable} that contains "update" records with unmodified keys and {@link Long} values + * that represent the latest (rolling) count (i.e., number of records) for each key per session + */ + KTable, Long> count(final Materialized> materialized); + + /** + * Count the number of records in this stream by the grouped key and defined sessions. + * Note that sessions are generated on a per-key basis and records with different keys create independent sessions. + * Records with {@code null} key or value are ignored. + *

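To connect the Materialized store name used by these count() overloads with the query pattern shown in the surrounding examples, here is a sketch split into a topology side and a query side; the store and topic names, serdes, and inactivity gap are assumptions.

import java.time.Duration;

import org.apache.kafka.common.serialization.Serdes;
import org.apache.kafka.common.utils.Bytes;
import org.apache.kafka.streams.KafkaStreams;
import org.apache.kafka.streams.StoreQueryParameters;
import org.apache.kafka.streams.StreamsBuilder;
import org.apache.kafka.streams.kstream.Consumed;
import org.apache.kafka.streams.kstream.Materialized;
import org.apache.kafka.streams.kstream.SessionWindows;
import org.apache.kafka.streams.kstream.Windowed;
import org.apache.kafka.streams.state.KeyValueIterator;
import org.apache.kafka.streams.state.QueryableStoreTypes;
import org.apache.kafka.streams.state.ReadOnlySessionStore;
import org.apache.kafka.streams.state.SessionStore;

public class QueryableSessionCountSketch {

    // Topology side: materialize the session-windowed count under an explicit store name.
    static void buildCount(final StreamsBuilder builder) {
        builder.stream("clicks", Consumed.with(Serdes.String(), Serdes.String()))
            .groupByKey()
            .windowedBy(SessionWindows.with(Duration.ofMinutes(5)))
            .count(Materialized.<String, Long, SessionStore<Bytes, byte[]>>as("clicks-per-session")
                .withKeySerde(Serdes.String())); // the value serde defaults to Long for count()
    }

    // Query side: `streams` must be a RUNNING KafkaStreams instance that owns the topology above.
    static void printSessions(final KafkaStreams streams, final String key) {
        final ReadOnlySessionStore<String, Long> store = streams.store(
            StoreQueryParameters.fromNameAndType("clicks-per-session",
                QueryableStoreTypes.<String, Long>sessionStore()));
        try (final KeyValueIterator<Windowed<String>, Long> sessions = store.fetch(key)) {
            sessions.forEachRemaining(kv -> System.out.println(kv.key + " -> " + kv.value));
        }
    }
}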
                + * The result is written into a local {@link SessionStore} (which is basically an ever-updating materialized view) + * that can be queried using the name provided with {@link Materialized}. + * Furthermore, updates to the store are sent downstream into a {@link KTable} changelog stream. + *

                + * Not all updates might get sent downstream, as an internal cache will be used to deduplicate consecutive updates + * to the same window and key if caching is enabled on the {@link Materialized} instance. + * When caching is enabled the rate of propagated updates depends on your input data rate, the number of distinct + * keys, the number of parallel running Kafka Streams instances, and the {@link StreamsConfig configuration} + * parameters for {@link StreamsConfig#CACHE_MAX_BYTES_BUFFERING_CONFIG cache size}, and + * {@link StreamsConfig#COMMIT_INTERVAL_MS_CONFIG commit interval}. + *

                + * To query the local {@link SessionStore} it must be obtained via + * {@link KafkaStreams#store(StoreQueryParameters) KafkaStreams#store(...)}: + *

                {@code
+     * KafkaStreams streams = ... // counting records per session
+     * String queryableStoreName = ... // the queryableStoreName should be the name of the store as defined by the Materialized instance
+     * ReadOnlySessionStore<String, Long> localWindowStore = streams.store(queryableStoreName, QueryableStoreTypes.sessionStore());
+     * String key = "some-key";
+     * KeyValueIterator<Windowed<String>, Long> countForKeyForSessions = localWindowStore.fetch(key); // key must be local (application state is shared over all running Kafka Streams instances)
                +     * }
                + * For non-local keys, a custom RPC mechanism must be implemented using {@link KafkaStreams#metadataForAllStreamsClients()} to + * query the value of the key on a parallel running instance of your Kafka Streams application. + *

                + * For failure and recovery the store will be backed by an internal changelog topic that will be created in Kafka. + * Therefore, the store name defined by the Materialized instance must be a valid Kafka topic name and cannot + * contain characters other than ASCII alphanumerics, '.', '_' and '-'. + * The changelog topic will be named "${applicationId}-${storeName}-changelog", where "applicationId" is + * user-specified in {@link StreamsConfig} via parameter + * {@link StreamsConfig#APPLICATION_ID_CONFIG APPLICATION_ID_CONFIG}, "storeName" is the provide store name defined + * in {@code Materialized}, and "-changelog" is a fixed suffix. + *

                + * You can retrieve all generated internal topic names via {@link Topology#describe()}. + * + * @param named a {@link Named} config used to name the processor in the topology. Cannot be {@code null}. + * @param materialized an instance of {@link Materialized} used to materialize a state store. Cannot be {@code null}. + * Note: the valueSerde will be automatically set to {@link org.apache.kafka.common.serialization.Serdes#Long() Serdes#Long()} + * if there is no valueSerde provided + * @return a windowed {@link KTable} that contains "update" records with unmodified keys and {@link Long} values + * that represent the latest (rolling) count (i.e., number of records) for each key per session + */ + KTable, Long> count(final Named named, + final Materialized> materialized); + + /** + * Aggregate the values of records in this stream by the grouped key and defined sessions. + * Note that sessions are generated on a per-key basis and records with different keys create independent sessions. + * Records with {@code null} key or value are ignored. + * Aggregating is a generalization of {@link #reduce(Reducer) combining via reduce(...)} as it, for example, + * allows the result to have a different type than the input values. + * The result is written into a local {@link SessionStore} (which is basically an ever-updating materialized view). + * Furthermore, updates to the store are sent downstream into a {@link KTable} changelog stream. + *

                + * The specified {@link Initializer} is applied directly before the first input record per session is processed to + * provide an initial intermediate aggregation result that is used to process the first record per session. + * The specified {@link Aggregator} is applied for each input record and computes a new aggregate using the current + * aggregate (or for the very first record using the intermediate aggregation result provided via the + * {@link Initializer}) and the record's value. + * The specified {@link Merger} is used to merge two existing sessions into one, i.e., when the windows overlap, + * they are merged into a single session and the old sessions are discarded. + * Thus, {@code aggregate()} can be used to compute aggregate functions like count (c.f. {@link #count()}). + *

                + * The default key and value serde from the config will be used for serializing the result. + * If a different serde is required then you should use + * {@link #aggregate(Initializer, Aggregator, Merger, Materialized)}. + *

                + * Not all updates might get sent downstream, as an internal cache is used to deduplicate consecutive updates to + * the same window and key. + * The rate of propagated updates depends on your input data rate, the number of distinct keys, the number of + * parallel running Kafka Streams instances, and the {@link StreamsConfig configuration} parameters for + * {@link StreamsConfig#CACHE_MAX_BYTES_BUFFERING_CONFIG cache size}, and + * {@link StreamsConfig#COMMIT_INTERVAL_MS_CONFIG commit interval}. + *

                + * For failure and recovery the store will be backed by an internal changelog topic that will be created in Kafka. + * The changelog topic will be named "${applicationId}-${internalStoreName}-changelog", where "applicationId" is + * user-specified in {@link StreamsConfig} via parameter + * {@link StreamsConfig#APPLICATION_ID_CONFIG APPLICATION_ID_CONFIG}, "internalStoreName" is an internal name + * and "-changelog" is a fixed suffix. + * Note that the internal store name may not be queryable through Interactive Queries. + *

                + * You can retrieve all generated internal topic names via {@link Topology#describe()}. + * + * @param initializer an {@link Initializer} that computes an initial intermediate aggregation result. Cannot be {@code null}. + * @param aggregator an {@link Aggregator} that computes a new aggregate result. Cannot be {@code null}. + * @param sessionMerger a {@link Merger} that combines two aggregation results. Cannot be {@code null}. + * @param the value type of the resulting {@link KTable} + * @return a windowed {@link KTable} that contains "update" records with unmodified keys, and values that represent + * the latest (rolling) aggregate for each key per session + */ + KTable, VR> aggregate(final Initializer initializer, + final Aggregator aggregator, + final Merger sessionMerger); + + /** + * Aggregate the values of records in this stream by the grouped key and defined sessions. + * Note that sessions are generated on a per-key basis and records with different keys create independent sessions. + * Records with {@code null} key or value are ignored. + * Aggregating is a generalization of {@link #reduce(Reducer) combining via reduce(...)} as it, for example, + * allows the result to have a different type than the input values. + * The result is written into a local {@link SessionStore} (which is basically an ever-updating materialized view). + * Furthermore, updates to the store are sent downstream into a {@link KTable} changelog stream. + *

                + * The specified {@link Initializer} is applied directly before the first input record per session is processed to + * provide an initial intermediate aggregation result that is used to process the first record per session. + * The specified {@link Aggregator} is applied for each input record and computes a new aggregate using the current + * aggregate (or for the very first record using the intermediate aggregation result provided via the + * {@link Initializer}) and the record's value. + * The specified {@link Merger} is used to merge two existing sessions into one, i.e., when the windows overlap, + * they are merged into a single session and the old sessions are discarded. + * Thus, {@code aggregate()} can be used to compute aggregate functions like count (c.f. {@link #count()}). + *

                + * The default key and value serde from the config will be used for serializing the result. + * If a different serde is required then you should use + * {@link #aggregate(Initializer, Aggregator, Merger, Named, Materialized)}. + *

                + * Not all updates might get sent downstream, as an internal cache is used to deduplicate consecutive updates to + * the same window and key. + * The rate of propagated updates depends on your input data rate, the number of distinct + * keys, the number of parallel running Kafka Streams instances, and the {@link StreamsConfig configuration} + * parameters for {@link StreamsConfig#CACHE_MAX_BYTES_BUFFERING_CONFIG cache size}, and + * {@link StreamsConfig#COMMIT_INTERVAL_MS_CONFIG commit interval}. + *

                + * For failure and recovery the store will be backed by an internal changelog topic that will be created in Kafka. + * The changelog topic will be named "${applicationId}-${internalStoreName}-changelog", where "applicationId" is + * user-specified in {@link StreamsConfig} via parameter + * {@link StreamsConfig#APPLICATION_ID_CONFIG APPLICATION_ID_CONFIG}, "internalStoreName" is an internal name + * and "-changelog" is a fixed suffix. + * Note that the internal store name may not be queryable through Interactive Queries. + *

                + * You can retrieve all generated internal topic names via {@link Topology#describe()}. + * + * @param initializer an {@link Initializer} that computes an initial intermediate aggregation result. Cannot be {@code null}. + * @param aggregator an {@link Aggregator} that computes a new aggregate result. Cannot be {@code null}. + * @param sessionMerger a {@link Merger} that combines two aggregation results. Cannot be {@code null}. + * @param named a {@link Named} config used to name the processor in the topology. Cannot be {@code null}. + * @param the value type of the resulting {@link KTable} + * @return a windowed {@link KTable} that contains "update" records with unmodified keys, and values that represent + * the latest (rolling) aggregate for each key per session + */ + KTable, VR> aggregate(final Initializer initializer, + final Aggregator aggregator, + final Merger sessionMerger, + final Named named); + + /** + * Aggregate the values of records in this stream by the grouped key and defined sessions. + * Note that sessions are generated on a per-key basis and records with different keys create independent sessions. + * Records with {@code null} key or value are ignored. + * Aggregating is a generalization of {@link #reduce(Reducer) combining via reduce(...)} as it, for example, + * allows the result to have a different type than the input values. + * The result is written into a local {@link SessionStore} (which is basically an ever-updating materialized view) + * that can be queried using the store name as provided with {@link Materialized}. + * Furthermore, updates to the store are sent downstream into a {@link KTable} changelog stream. + *

                + * The specified {@link Initializer} is applied directly before the first input record per session is processed to + * provide an initial intermediate aggregation result that is used to process the first record per session. + * The specified {@link Aggregator} is applied for each input record and computes a new aggregate using the current + * aggregate (or for the very first record using the intermediate aggregation result provided via the + * {@link Initializer}) and the record's value. + * The specified {@link Merger} is used to merge two existing sessions into one, i.e., when the windows overlap, + * they are merged into a single session and the old sessions are discarded. + * Thus, {@code aggregate()} can be used to compute aggregate functions like count (c.f. {@link #count()}). + *

                + * Not all updates might get sent downstream, as an internal cache is used to deduplicate consecutive updates to + * the same window and key if caching is enabled on the {@link Materialized} instance. + * When caching is enabled the rate of propagated updates depends on your input data rate, the number of distinct keys, the number of + * parallel running Kafka Streams instances, and the {@link StreamsConfig configuration} parameters for + * {@link StreamsConfig#CACHE_MAX_BYTES_BUFFERING_CONFIG cache size}, and + * {@link StreamsConfig#COMMIT_INTERVAL_MS_CONFIG commit interval}. + *

                + * To query the local {@link SessionStore} it must be obtained via + * {@link KafkaStreams#store(StoreQueryParameters) KafkaStreams#store(...)}: + *

                {@code
+     * KafkaStreams streams = ... // some windowed aggregation on value type Long
+     * String queryableStoreName = ... // the queryableStoreName should be the name of the store as defined by the Materialized instance
+     * ReadOnlySessionStore<String, Long> sessionStore = streams.store(StoreQueryParameters.fromNameAndType(queryableStoreName, QueryableStoreTypes.sessionStore()));
+     * String key = "some-key";
+     * KeyValueIterator<Windowed<String>, Long> aggForKeyForSession = sessionStore.fetch(key); // key must be local (application state is shared over all running Kafka Streams instances)
                +     * }
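For completeness, a small hedged sketch of consuming the iterator returned by {@code fetch(key)}; the store type and variable names simply mirror the snippet above and are not prescribed by the API.

import org.apache.kafka.streams.KeyValue;
import org.apache.kafka.streams.kstream.Windowed;
import org.apache.kafka.streams.state.KeyValueIterator;
import org.apache.kafka.streams.state.ReadOnlySessionStore;

public final class SessionStoreDump {
    // Prints every session aggregate held locally for the given key, together with the session boundaries.
    static void printSessions(final ReadOnlySessionStore<String, Long> sessionStore, final String key) {
        try (KeyValueIterator<Windowed<String>, Long> sessions = sessionStore.fetch(key)) {
            while (sessions.hasNext()) {
                final KeyValue<Windowed<String>, Long> entry = sessions.next();
                System.out.println(key
                    + " session [" + entry.key.window().start()
                    + ", " + entry.key.window().end()
                    + "] -> " + entry.value);
            }
        }
    }
}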
                + * For non-local keys, a custom RPC mechanism must be implemented using {@link KafkaStreams#metadataForAllStreamsClients()} to + * query the value of the key on a parallel running instance of your Kafka Streams application. + *

+ * For failure and recovery the store will be backed by an internal changelog topic that will be created in Kafka. + * Therefore, the store name defined by the {@link Materialized} instance must be a valid Kafka topic name and + * cannot contain characters other than ASCII alphanumerics, '.', '_' and '-'. + * The changelog topic will be named "${applicationId}-${storeName}-changelog", where "applicationId" is + * user-specified in {@link StreamsConfig} via parameter + * {@link StreamsConfig#APPLICATION_ID_CONFIG APPLICATION_ID_CONFIG}, "storeName" is the + * provided store name defined in {@link Materialized}, and "-changelog" is a fixed suffix. + *

                + * You can retrieve all generated internal topic names via {@link Topology#describe()}. + * + * @param initializer an {@link Initializer} that computes an initial intermediate aggregation result. Cannot be {@code null}. + * @param aggregator an {@link Aggregator} that computes a new aggregate result. Cannot be {@code null}. + * @param sessionMerger a {@link Merger} that combines two aggregation results. Cannot be {@code null}. + * @param materialized a {@link Materialized} config used to materialize a state store. Cannot be {@code null}. + * @param the value type of the resulting {@link KTable} + * @return a windowed {@link KTable} that contains "update" records with unmodified keys, and values that represent + * the latest (rolling) aggregate for each key per session + */ + KTable, VR> aggregate(final Initializer initializer, + final Aggregator aggregator, + final Merger sessionMerger, + final Materialized> materialized); + + /** + * Aggregate the values of records in this stream by the grouped key and defined sessions. + * Note that sessions are generated on a per-key basis and records with different keys create independent sessions. + * Records with {@code null} key or value are ignored. + * Aggregating is a generalization of {@link #reduce(Reducer) combining via reduce(...)} as it, for example, + * allows the result to have a different type than the input values. + * The result is written into a local {@link SessionStore} (which is basically an ever-updating materialized view) + * that can be queried using the store name as provided with {@link Materialized}. + * Furthermore, updates to the store are sent downstream into a {@link KTable} changelog stream. + *

                + * The specified {@link Initializer} is applied directly before the first input record per session is processed to + * provide an initial intermediate aggregation result that is used to process the first record per session. + * The specified {@link Aggregator} is applied for each input record and computes a new aggregate using the current + * aggregate (or for the very first record using the intermediate aggregation result provided via the + * {@link Initializer}) and the record's value. + * The specified {@link Merger} is used to merge two existing sessions into one, i.e., when the windows overlap, + * they are merged into a single session and the old sessions are discarded. + * Thus, {@code aggregate()} can be used to compute aggregate functions like count (c.f. {@link #count()}). + *

                + * Not all updates might get sent downstream, as an internal cache will be used to deduplicate consecutive updates + * to the same window and key if caching is enabled on the {@link Materialized} instance. + * When caching is enabled the rate of propagated updates depends on your input data rate, the number of distinct + * keys, the number of parallel running Kafka Streams instances, and the {@link StreamsConfig configuration} + * parameters for {@link StreamsConfig#CACHE_MAX_BYTES_BUFFERING_CONFIG cache size}, and + * {@link StreamsConfig#COMMIT_INTERVAL_MS_CONFIG commit interval}. + *

                + * To query the local {@link SessionStore} it must be obtained via + * {@link KafkaStreams#store(StoreQueryParameters) KafkaStreams#store(...)}: + *

                {@code
+     * KafkaStreams streams = ... // some windowed aggregation on value type Long
+     * String queryableStoreName = ... // the queryableStoreName should be the name of the store as defined by the Materialized instance
+     * ReadOnlySessionStore<String, Long> sessionStore = streams.store(StoreQueryParameters.fromNameAndType(queryableStoreName, QueryableStoreTypes.sessionStore()));
+     * String key = "some-key";
+     * KeyValueIterator<Windowed<String>, Long> aggForKeyForSession = sessionStore.fetch(key); // key must be local (application state is shared over all running Kafka Streams instances)
                +     * }
                + * For non-local keys, a custom RPC mechanism must be implemented using {@link KafkaStreams#metadataForAllStreamsClients()} to + * query the value of the key on a parallel running instance of your Kafka Streams application. + *

+ * For failure and recovery the store will be backed by an internal changelog topic that will be created in Kafka. + * Therefore, the store name defined by the {@link Materialized} instance must be a valid Kafka topic name and + * cannot contain characters other than ASCII alphanumerics, '.', '_' and '-'. + * The changelog topic will be named "${applicationId}-${storeName}-changelog", where "applicationId" is + * user-specified in {@link StreamsConfig} via parameter + * {@link StreamsConfig#APPLICATION_ID_CONFIG APPLICATION_ID_CONFIG}, "storeName" is the + * provided store name defined in {@link Materialized}, and "-changelog" is a fixed suffix. + *

                + * You can retrieve all generated internal topic names via {@link Topology#describe()}. + * + * @param initializer an {@link Initializer} that computes an initial intermediate aggregation result. Cannot be {@code null}. + * @param aggregator an {@link Aggregator} that computes a new aggregate result. Cannot be {@code null}. + * @param sessionMerger a {@link Merger} that combines two aggregation results. Cannot be {@code null}. + * @param named a {@link Named} config used to name the processor in the topology. Cannot be {@code null}. + * @param materialized a {@link Materialized} config used to materialize a state store. Cannot be {@code null}. + * @param the value type of the resulting {@link KTable} + * @return a windowed {@link KTable} that contains "update" records with unmodified keys, and values that represent + * the latest (rolling) aggregate for each key per session + */ + KTable, VR> aggregate(final Initializer initializer, + final Aggregator aggregator, + final Merger sessionMerger, + final Named named, + final Materialized> materialized); + + /** + * Combine the values of records in this stream by the grouped key and defined sessions. + * Note that sessions are generated on a per-key basis and records with different keys create independent sessions. + * Records with {@code null} key or value are ignored. + * Combining implies that the type of the aggregate result is the same as the type of the input value + * (c.f. {@link #aggregate(Initializer, Aggregator, Merger)}). + * The result is written into a local {@link SessionStore} (which is basically an ever-updating materialized view). + * Furthermore, updates to the store are sent downstream into a {@link KTable} changelog stream. + * The default key and value serde from the config will be used for serializing the result. + * If a different serde is required then you should use {@link #reduce(Reducer, Materialized)} . + *

+ * The value of the first record per session initializes the session result. + * The specified {@link Reducer} is applied for each additional input record per session and computes a new + * aggregate using the current aggregate (first argument) and the record's value (second argument): + *

                {@code
+     * // Example of a Reducer
+     * new Reducer<Long>() {
                +     *   public Long apply(Long aggValue, Long currValue) {
                +     *     return aggValue + currValue;
                +     *   }
                +     * }
                +     * }
                + * Thus, {@code reduce()} can be used to compute aggregate functions like sum, min, or max. + *
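As a rough sketch of such a sum over sessions (the topic name "purchase-amounts", the serdes, and the 30-minute inactivity gap are assumptions of this example, not part of the API contract):

import java.time.Duration;

import org.apache.kafka.common.serialization.Serdes;
import org.apache.kafka.streams.StreamsBuilder;
import org.apache.kafka.streams.kstream.Consumed;
import org.apache.kafka.streams.kstream.Grouped;
import org.apache.kafka.streams.kstream.KTable;
import org.apache.kafka.streams.kstream.SessionWindows;
import org.apache.kafka.streams.kstream.Windowed;

public class SessionReduceSketch {
    public static void main(final String[] args) {
        final StreamsBuilder builder = new StreamsBuilder();
        // "purchase-amounts" is an assumed topic carrying <userId, amount> records.
        final KTable<Windowed<String>, Long> totalPerSession = builder
            .stream("purchase-amounts", Consumed.with(Serdes.String(), Serdes.Long()))
            .groupByKey(Grouped.with(Serdes.String(), Serdes.Long()))
            .windowedBy(SessionWindows.ofInactivityGapWithNoGrace(Duration.ofMinutes(30)))
            .reduce(Long::sum); // the Reducer from the snippet above, written as a method reference
    }
}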

                + * Not all updates might get sent downstream, as an internal cache is used to deduplicate consecutive updates to + * the same window and key. + * The rate of propagated updates depends on your input data rate, the number of distinct keys, the number of + * parallel running Kafka Streams instances, and the {@link StreamsConfig configuration} parameters for + * {@link StreamsConfig#CACHE_MAX_BYTES_BUFFERING_CONFIG cache size}, and + * {@link StreamsConfig#COMMIT_INTERVAL_MS_CONFIG commit interval}. + *

                + * For failure and recovery the store will be backed by an internal changelog topic that will be created in Kafka. + * The changelog topic will be named "${applicationId}-${internalStoreName}-changelog", where "applicationId" is + * user-specified in {@link StreamsConfig} via parameter + * {@link StreamsConfig#APPLICATION_ID_CONFIG APPLICATION_ID_CONFIG}, "internalStoreName" is an internal name + * and "-changelog" is a fixed suffix. + *

                + * You can retrieve all generated internal topic names via {@link Topology#describe()}. + * + * @param reducer a {@link Reducer} that computes a new aggregate result. Cannot be {@code null}. + * @return a windowed {@link KTable} that contains "update" records with unmodified keys, and values that represent + * the latest (rolling) aggregate for each key per session + */ + KTable, V> reduce(final Reducer reducer); + + /** + * Combine the values of records in this stream by the grouped key and defined sessions. + * Note that sessions are generated on a per-key basis and records with different keys create independent sessions. + * Records with {@code null} key or value are ignored. + * Combining implies that the type of the aggregate result is the same as the type of the input value + * (c.f. {@link #aggregate(Initializer, Aggregator, Merger)}). + * The result is written into a local {@link SessionStore} (which is basically an ever-updating materialized view). + * Furthermore, updates to the store are sent downstream into a {@link KTable} changelog stream. + * The default key and value serde from the config will be used for serializing the result. + * If a different serde is required then you should use {@link #reduce(Reducer, Named, Materialized)} . + *

+ * The value of the first record per session initializes the session result. + * The specified {@link Reducer} is applied for each additional input record per session and computes a new + * aggregate using the current aggregate (first argument) and the record's value (second argument): + *

                {@code
+     * // Example of a Reducer
+     * new Reducer<Long>() {
                +     *   public Long apply(Long aggValue, Long currValue) {
                +     *     return aggValue + currValue;
                +     *   }
                +     * }
                +     * }
                + * Thus, {@code reduce()} can be used to compute aggregate functions like sum, min, or max. + *

                + * Not all updates might get sent downstream, as an internal cache is used to deduplicate consecutive updates to + * the same window and key. + * The rate of propagated updates depends on your input data rate, the number of distinct keys, the number of + * parallel running Kafka Streams instances, and the {@link StreamsConfig configuration} parameters for + * {@link StreamsConfig#CACHE_MAX_BYTES_BUFFERING_CONFIG cache size}, and + * {@link StreamsConfig#COMMIT_INTERVAL_MS_CONFIG commit interval}. + *

                + * For failure and recovery the store will be backed by an internal changelog topic that will be created in Kafka. + * The changelog topic will be named "${applicationId}-${internalStoreName}-changelog", where "applicationId" is + * user-specified in {@link StreamsConfig} via parameter + * {@link StreamsConfig#APPLICATION_ID_CONFIG APPLICATION_ID_CONFIG}, "internalStoreName" is an internal name + * and "-changelog" is a fixed suffix. + *

                + * You can retrieve all generated internal topic names via {@link Topology#describe()}. + * + * @param reducer a {@link Reducer} that computes a new aggregate result. Cannot be {@code null}. + * @param named a {@link Named} config used to name the processor in the topology. Cannot be {@code null}. + * @return a windowed {@link KTable} that contains "update" records with unmodified keys, and values that represent + * the latest (rolling) aggregate for each key per session + */ + KTable, V> reduce(final Reducer reducer, final Named named); + + /** + * Combine the values of records in this stream by the grouped key and defined sessions. + * Note that sessions are generated on a per-key basis and records with different keys create independent sessions. + * Records with {@code null} key or value are ignored. + * Combining implies that the type of the aggregate result is the same as the type of the input value + * (c.f. {@link #aggregate(Initializer, Aggregator, Merger)}). + * The result is written into a local {@link SessionStore} (which is basically an ever-updating materialized view) + * that can be queried using the store name as provided with {@link Materialized}. + * Furthermore, updates to the store are sent downstream into a {@link KTable} changelog stream. + *

+ * The value of the first record per session initializes the session result. + * The specified {@link Reducer} is applied for each additional input record per session and computes a new + * aggregate using the current aggregate (first argument) and the record's value (second argument): + *

                {@code
+     * // Example of a Reducer
+     * new Reducer<Long>() {
                +     *   public Long apply(Long aggValue, Long currValue) {
                +     *     return aggValue + currValue;
                +     *   }
                +     * }
                +     * }
                + * Thus, {@code reduce()} can be used to compute aggregate functions like sum, min, or max. + *

                + * Not all updates might get sent downstream, as an internal cache will be used to deduplicate consecutive updates + * to the same window and key if caching is enabled on the {@link Materialized} instance. + * When caching is enabled the rate of propagated updates depends on your input data rate, the number of distinct + * keys, the number of parallel running Kafka Streams instances, and the {@link StreamsConfig configuration} + * parameters for {@link StreamsConfig#CACHE_MAX_BYTES_BUFFERING_CONFIG cache size}, and + * {@link StreamsConfig#COMMIT_INTERVAL_MS_CONFIG commit interval}. + *

                + * To query the local {@link SessionStore} it must be obtained via + * {@link KafkaStreams#store(StoreQueryParameters) KafkaStreams#store(...)}: + *

                {@code
                +     * KafkaStreams streams = ... // compute sum
+     * String queryableStoreName = ... // the queryableStoreName should be the name of the store as defined by the Materialized instance
+     * ReadOnlySessionStore<String, Long> localWindowStore = streams.store(StoreQueryParameters.fromNameAndType(queryableStoreName, QueryableStoreTypes.sessionStore()));
+     * String key = "some-key";
+     * KeyValueIterator<Windowed<String>, Long> sumForKeyForWindows = localWindowStore.fetch(key); // key must be local (application state is shared over all running Kafka Streams instances)
                +     * }
                + * For non-local keys, a custom RPC mechanism must be implemented using {@link KafkaStreams#metadataForAllStreamsClients()} to + * query the value of the key on a parallel running instance of your Kafka Streams application. + *

+ * For failure and recovery the store will be backed by an internal changelog topic that will be created in Kafka. + * Therefore, the store name defined by the Materialized instance must be a valid Kafka topic name and cannot + * contain characters other than ASCII alphanumerics, '.', '_' and '-'. + * The changelog topic will be named "${applicationId}-${storeName}-changelog", where "applicationId" is + * user-specified in {@link StreamsConfig} via parameter + * {@link StreamsConfig#APPLICATION_ID_CONFIG APPLICATION_ID_CONFIG}, "storeName" is the provided store name defined + * in {@code Materialized}, and "-changelog" is a fixed suffix. + *

                + * You can retrieve all generated internal topic names via {@link Topology#describe()}. + * + * @param reducer a {@link Reducer} that computes a new aggregate result. Cannot be {@code null}. + * @param materialized a {@link Materialized} config used to materialize a state store. Cannot be {@code null}. + * @return a windowed {@link KTable} that contains "update" records with unmodified keys, and values that represent + * the latest (rolling) aggregate for each key per session + */ + KTable, V> reduce(final Reducer reducer, + final Materialized> materialized); + + /** + * Combine the values of records in this stream by the grouped key and defined sessions. + * Note that sessions are generated on a per-key basis and records with different keys create independent sessions. + * Records with {@code null} key or value are ignored. + * Combining implies that the type of the aggregate result is the same as the type of the input value + * (c.f. {@link #aggregate(Initializer, Aggregator, Merger)}). + * The result is written into a local {@link SessionStore} (which is basically an ever-updating materialized view) + * that can be queried using the store name as provided with {@link Materialized}. + * Furthermore, updates to the store are sent downstream into a {@link KTable} changelog stream. + *

+ * The value of the first record per session initializes the session result. + * The specified {@link Reducer} is applied for each additional input record per session and computes a new + * aggregate using the current aggregate (first argument) and the record's value (second argument): + *

                {@code
+     * // Example of a Reducer
+     * new Reducer<Long>() {
                +     *   public Long apply(Long aggValue, Long currValue) {
                +     *     return aggValue + currValue;
                +     *   }
                +     * }
                +     * }
                + * Thus, {@code reduce()} can be used to compute aggregate functions like sum, min, or max. + *

                + * Not all updates might get sent downstream, as an internal cache will be used to deduplicate consecutive updates + * to the same window and key if caching is enabled on the {@link Materialized} instance. + * When caching is enabled the rate of propagated updates depends on your input data rate, the number of distinct + * keys, the number of parallel running Kafka Streams instances, and the {@link StreamsConfig configuration} + * parameters for {@link StreamsConfig#CACHE_MAX_BYTES_BUFFERING_CONFIG cache size}, and + * {@link StreamsConfig#COMMIT_INTERVAL_MS_CONFIG commit interval}. + *

+ * To query the local {@link SessionStore} it must be obtained via + * {@link KafkaStreams#store(StoreQueryParameters) KafkaStreams#store(...)}: + *

                {@code
                +     * KafkaStreams streams = ... // compute sum
+     * String queryableStoreName = ... // the queryableStoreName should be the name of the store as defined by the Materialized instance
+     * ReadOnlySessionStore<String, Long> localWindowStore = streams.store(StoreQueryParameters.fromNameAndType(queryableStoreName, QueryableStoreTypes.sessionStore()));
+     * String key = "some-key";
+     * KeyValueIterator<Windowed<String>, Long> sumForKeyForWindows = localWindowStore.fetch(key); // key must be local (application state is shared over all running Kafka Streams instances)
                +     * }
                + * For non-local keys, a custom RPC mechanism must be implemented using {@link KafkaStreams#metadataForAllStreamsClients()} to + * query the value of the key on a parallel running instance of your Kafka Streams application. + *

+ * For failure and recovery the store will be backed by an internal changelog topic that will be created in Kafka. + * Therefore, the store name defined by the Materialized instance must be a valid Kafka topic name and cannot + * contain characters other than ASCII alphanumerics, '.', '_' and '-'. + * The changelog topic will be named "${applicationId}-${storeName}-changelog", where "applicationId" is + * user-specified in {@link StreamsConfig} via parameter + * {@link StreamsConfig#APPLICATION_ID_CONFIG APPLICATION_ID_CONFIG}, "storeName" is the provided store name defined + * in {@link Materialized}, and "-changelog" is a fixed suffix. + *

                + * You can retrieve all generated internal topic names via {@link Topology#describe()}. + * + * @param reducer a {@link Reducer} that computes a new aggregate result. Cannot be {@code null}. + * @param named a {@link Named} config used to name the processor in the topology. Cannot be {@code null}. + * @param materialized a {@link Materialized} config used to materialize a state store. Cannot be {@code null}. + * @return a windowed {@link KTable} that contains "update" records with unmodified keys, and values that represent + * the latest (rolling) aggregate for each key per session + */ + KTable, V> reduce(final Reducer reducer, + final Named named, + final Materialized> materialized); +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/SessionWindowedSerializer.java b/streams/src/main/java/org/apache/kafka/streams/kstream/SessionWindowedSerializer.java new file mode 100644 index 0000000..6ec10bf --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/SessionWindowedSerializer.java @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.kstream; + +import org.apache.kafka.common.config.ConfigException; +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.common.serialization.Serializer; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.kstream.internals.WindowedSerializer; +import org.apache.kafka.streams.state.internals.SessionKeySchema; + +import java.util.Map; + +public class SessionWindowedSerializer implements WindowedSerializer { + + private Serializer inner; + + // Default constructor needed by Kafka + public SessionWindowedSerializer() {} + + public SessionWindowedSerializer(final Serializer inner) { + this.inner = inner; + } + + @SuppressWarnings("unchecked") + @Override + public void configure(final Map configs, final boolean isKey) { + final String windowedInnerClassSerdeConfig = (String) configs.get(StreamsConfig.WINDOWED_INNER_CLASS_SERDE); + Serde windowInnerClassSerde = null; + if (windowedInnerClassSerdeConfig != null) { + try { + windowInnerClassSerde = Utils.newInstance(windowedInnerClassSerdeConfig, Serde.class); + } catch (final ClassNotFoundException e) { + throw new ConfigException(StreamsConfig.WINDOWED_INNER_CLASS_SERDE, windowedInnerClassSerdeConfig, + "Serde class " + windowedInnerClassSerdeConfig + " could not be found."); + } + } + + if (inner != null && windowedInnerClassSerdeConfig != null) { + if (!inner.getClass().getName().equals(windowInnerClassSerde.serializer().getClass().getName())) { + throw new IllegalArgumentException("Inner class serializer set using constructor " + + "(" + inner.getClass().getName() + ")" + + " is different from the one set in windowed.inner.class.serde config " + + "(" + windowInnerClassSerde.serializer().getClass().getName() + ")."); + } + } else if (inner == null && windowedInnerClassSerdeConfig == null) { + throw new IllegalArgumentException("Inner class serializer should be set either via constructor " + + "or via the windowed.inner.class.serde config"); + } else if (inner == null) + inner = windowInnerClassSerde.serializer(); + } + + @Override + public byte[] serialize(final String topic, final Windowed data) { + WindowedSerdes.verifyInnerSerializerNotNull(inner, this); + + if (data == null) { + return null; + } + // for either key or value, their schema is the same hence we will just use session key schema + return SessionKeySchema.toBinary(data, inner, topic); + } + + @Override + public void close() { + if (inner != null) { + inner.close(); + } + } + + @Override + public byte[] serializeBaseKey(final String topic, final Windowed data) { + WindowedSerdes.verifyInnerSerializerNotNull(inner, this); + + return inner.serialize(topic, data.key()); + } + + // Only for testing + Serializer innerSerializer() { + return inner; + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/SessionWindows.java b/streams/src/main/java/org/apache/kafka/streams/kstream/SessionWindows.java new file mode 100644 index 0000000..35128bd --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/SessionWindows.java @@ -0,0 +1,223 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.kstream; + +import org.apache.kafka.streams.processor.TimestampExtractor; + +import java.time.Duration; +import java.util.Objects; + +import static org.apache.kafka.streams.internals.ApiUtils.prepareMillisCheckFailMsgPrefix; +import static org.apache.kafka.streams.internals.ApiUtils.validateMillisecondDuration; +import static org.apache.kafka.streams.kstream.Windows.DEPRECATED_DEFAULT_24_HR_GRACE_PERIOD; +import static org.apache.kafka.streams.kstream.Windows.NO_GRACE_PERIOD; +import static java.time.Duration.ofMillis; + +/** + * A session based window specification used for aggregating events into sessions. + *

                + * Sessions represent a period of activity separated by a defined gap of inactivity. + * Any events processed that fall within the inactivity gap of any existing sessions are merged into the existing sessions. + * If the event falls outside of the session gap then a new session will be created. + *

                + * For example, if we have a session gap of 5 and the following data arrives: + *

                + * +--------------------------------------+
                + * |    key    |    value    |    time    |
                + * +-----------+-------------+------------+
                + * |    A      |     1       |     10     |
                + * +-----------+-------------+------------+
                + * |    A      |     2       |     12     |
                + * +-----------+-------------+------------+
                + * |    A      |     3       |     20     |
                + * +-----------+-------------+------------+
                + * 
+ * We'd have 2 sessions for key A: one starting at time 10 and ending at time 12, and another starting and ending at time 20. + * The length of the session is driven by the timestamps of the data within the session. + * Thus, session windows are not fixed-size windows (c.f. {@link TimeWindows} and {@link JoinWindows}). + *

                + * If we then received another record: + *

                + * +--------------------------------------+
                + * |    key    |    value    |    time    |
                + * +-----------+-------------+------------+
                + * |    A      |     4       |     16     |
                + * +-----------+-------------+------------+
                + * 
                + * The previous 2 sessions would be merged into a single session with start time 10 and end time 20. + * The aggregate value for this session would be the result of aggregating all 4 values. + *
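A tiny sketch of declaring that gap in code; interpreting the example's timestamps as milliseconds and adding a 1-second grace period are assumptions made here for illustration, not part of the example above.

import java.time.Duration;

import org.apache.kafka.streams.kstream.SessionWindows;

public class SessionWindowsSketch {
    public static void main(final String[] args) {
        // 5 ms inactivity gap as in the tables above, plus a grace period for out-of-order records.
        final SessionWindows windows =
            SessionWindows.ofInactivityGapAndGrace(Duration.ofMillis(5), Duration.ofSeconds(1));
        System.out.println(windows); // prints SessionWindows{gapMs=5, graceMs=1000}
    }
}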

                + * For time semantics, see {@link TimestampExtractor}. + * + * @see TimeWindows + * @see UnlimitedWindows + * @see JoinWindows + * @see KGroupedStream#windowedBy(SessionWindows) + * @see TimestampExtractor + */ +public final class SessionWindows { + + private final long gapMs; + + private final long graceMs; + + // flag to check if the grace is already set via ofInactivityGapAndGrace or ofInactivityGapWithNoGrace + private final boolean hasSetGrace; + + private SessionWindows(final long gapMs, final long graceMs, final boolean hasSetGrace) { + this.gapMs = gapMs; + this.graceMs = graceMs; + this.hasSetGrace = hasSetGrace; + + if (gapMs <= 0) { + throw new IllegalArgumentException("Gap time cannot be zero or negative."); + } + + if (graceMs < 0) { + throw new IllegalArgumentException("Grace period must not be negative."); + } + } + + /** + * Creates a new window specification with the specified inactivity gap. + *

                + * Note that new events may change the boundaries of session windows, so aggressive + * close times can lead to surprising results in which an out-of-order event is rejected and then + * a subsequent event moves the window boundary forward. + *

                + * CAUTION: Using this method implicitly sets the grace period to zero, which means that any out-of-order + * records arriving after the window ends are considered late and will be dropped. + * + * @param inactivityGap the gap of inactivity between sessions + * @return a window definition with the window size and no grace period. Note that this means out-of-order records arriving after the window end will be dropped + * @throws IllegalArgumentException if {@code inactivityGap} is zero or negative or can't be represented as {@code long milliseconds} + */ + public static SessionWindows ofInactivityGapWithNoGrace(final Duration inactivityGap) { + return ofInactivityGapAndGrace(inactivityGap, ofMillis(NO_GRACE_PERIOD)); + } + + /** + * Creates a new window specification with the specified inactivity gap. + *

                + * Note that new events may change the boundaries of session windows, so aggressive + * close times can lead to surprising results in which an out-of-order event is rejected and then + * a subsequent event moves the window boundary forward. + *

                + * Using this method explicitly sets the grace period to the duration specified by {@code afterWindowEnd}, which + * means that only out-of-order records arriving more than the grace period after the window end will be dropped. + * The window close, after which any incoming records are considered late and will be rejected, is defined as + * {@code windowEnd + afterWindowEnd} + * + * @param inactivityGap the gap of inactivity between sessions + * @param afterWindowEnd The grace period to admit out-of-order events to a window. + * @return A SessionWindows object with the specified inactivity gap and grace period + * @throws IllegalArgumentException if {@code inactivityGap} is zero or negative or can't be represented as {@code long milliseconds} + * if {@code afterWindowEnd} is negative or can't be represented as {@code long milliseconds} + */ + public static SessionWindows ofInactivityGapAndGrace(final Duration inactivityGap, final Duration afterWindowEnd) { + final String inactivityGapMsgPrefix = prepareMillisCheckFailMsgPrefix(inactivityGap, "inactivityGap"); + final long inactivityGapMs = validateMillisecondDuration(inactivityGap, inactivityGapMsgPrefix); + + final String afterWindowEndMsgPrefix = prepareMillisCheckFailMsgPrefix(afterWindowEnd, "afterWindowEnd"); + final long afterWindowEndMs = validateMillisecondDuration(afterWindowEnd, afterWindowEndMsgPrefix); + + return new SessionWindows(inactivityGapMs, afterWindowEndMs, true); + } + + /** + * Create a new window specification with the specified inactivity gap. + * + * @param inactivityGap the gap of inactivity between sessions + * @return a new window specification without specifying a grace period (default to 24 hours minus {@code inactivityGap}) + * @throws IllegalArgumentException if {@code inactivityGap} is zero or negative or can't be represented as {@code long milliseconds} + * @deprecated since 3.0. Use {@link #ofInactivityGapWithNoGrace(Duration)} instead + */ + @Deprecated + public static SessionWindows with(final Duration inactivityGap) { + final String msgPrefix = prepareMillisCheckFailMsgPrefix(inactivityGap, "inactivityGap"); + final long inactivityGapMs = validateMillisecondDuration(inactivityGap, msgPrefix); + + return new SessionWindows(inactivityGapMs, Math.max(DEPRECATED_DEFAULT_24_HR_GRACE_PERIOD - inactivityGapMs, 0), false); + } + + /** + * Reject out-of-order events that arrive more than {@code afterWindowEnd} + * after the end of its window. + *

                + * Note that new events may change the boundaries of session windows, so aggressive + * close times can lead to surprising results in which an out-of-order event is rejected and then + * a subsequent event moves the window boundary forward. + * + * @param afterWindowEnd The grace period to admit out-of-order events to a window. + * @return this updated builder + * @throws IllegalArgumentException if the {@code afterWindowEnd} is negative or can't be represented as {@code long milliseconds} + * @throws IllegalStateException if {@link #grace(Duration)} is called after {@link #ofInactivityGapAndGrace(Duration, Duration)} or {@link #ofInactivityGapWithNoGrace(Duration)} + * @deprecated since 3.0. Use {@link #ofInactivityGapAndGrace(Duration, Duration)} instead + */ + @Deprecated + public SessionWindows grace(final Duration afterWindowEnd) throws IllegalArgumentException { + if (this.hasSetGrace) { + throw new IllegalStateException( + "Cannot call grace() after setting grace value via ofInactivityGapAndGrace or ofInactivityGapWithNoGrace."); + } + + final String msgPrefix = prepareMillisCheckFailMsgPrefix(afterWindowEnd, "afterWindowEnd"); + final long afterWindowEndMs = validateMillisecondDuration(afterWindowEnd, msgPrefix); + + return new SessionWindows(gapMs, afterWindowEndMs, false); + } + + public long gracePeriodMs() { + return graceMs; + } + + /** + * Return the specified gap for the session windows in milliseconds. + * + * @return the inactivity gap of the specified windows + */ + public long inactivityGap() { + return gapMs; + } + + @Override + public boolean equals(final Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + final SessionWindows that = (SessionWindows) o; + return gapMs == that.gapMs && + graceMs == that.graceMs; + } + + @Override + public int hashCode() { + return Objects.hash(gapMs, graceMs); + } + + @Override + public String toString() { + return "SessionWindows{" + + "gapMs=" + gapMs + + ", graceMs=" + graceMs + + '}'; + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/SlidingWindows.java b/streams/src/main/java/org/apache/kafka/streams/kstream/SlidingWindows.java new file mode 100644 index 0000000..159fea7 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/SlidingWindows.java @@ -0,0 +1,185 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.kstream; + +import org.apache.kafka.streams.processor.TimestampExtractor; + +import java.time.Duration; +import java.util.Objects; + +import static java.time.Duration.ofMillis; +import static org.apache.kafka.streams.internals.ApiUtils.prepareMillisCheckFailMsgPrefix; +import static org.apache.kafka.streams.internals.ApiUtils.validateMillisecondDuration; +import static org.apache.kafka.streams.kstream.Windows.NO_GRACE_PERIOD; + +/** + * A sliding window used for aggregating events. + *

+ * Sliding windows are defined based on a record's timestamp, a window size given by the maximum time difference (inclusive) between + * records in the same window, and a window grace period. While the window slides over the input data stream, a new window is + * created each time a record enters the sliding window or drops out of it. + *

+ * Records that arrive after the grace period has passed are ignored, i.e., a window is closed when + * {@code stream-time > window-end + grace-period}. + *

                + * For example, if we have a time difference of 5000ms and the following data arrives: + *

                + * +--------------------------------------+
                + * |    key    |    value    |    time    |
                + * +-----------+-------------+------------+
                + * |    A      |     1       |    8000    |
                + * +-----------+-------------+------------+
                + * |    A      |     2       |    9200    |
                + * +-----------+-------------+------------+
                + * |    A      |     3       |    12400   |
                + * +-----------+-------------+------------+
                + * 
                + * We'd have the following 5 windows: + *
                  + *
+ *   • window {@code [3000;8000]} contains [1] (created when first record enters the window)
+ *   • window {@code [4200;9200]} contains [1,2] (created when second record enters the window)
+ *   • window {@code [7400;12400]} contains [1,2,3] (created when third record enters the window)
+ *   • window {@code [8001;13001]} contains [2,3] (created when the first record drops out of the window)
+ *   • window {@code [9201;14201]} contains [3] (created when the second record drops out of the window)
                + *

                + * Note that while SlidingWindows are of a fixed size, as are {@link TimeWindows}, the start and end points of the window + * depend on when events occur in the stream (i.e., event timestamps), similar to {@link SessionWindows}. + *
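A brief usage sketch; the topic name "events" and the 500 ms grace period are assumptions, while the 5-second time difference mirrors the 5000 ms example above.

import java.time.Duration;

import org.apache.kafka.common.serialization.Serdes;
import org.apache.kafka.streams.StreamsBuilder;
import org.apache.kafka.streams.kstream.Consumed;
import org.apache.kafka.streams.kstream.KTable;
import org.apache.kafka.streams.kstream.SlidingWindows;
import org.apache.kafka.streams.kstream.Windowed;

public class SlidingWindowCountSketch {
    public static void main(final String[] args) {
        final StreamsBuilder builder = new StreamsBuilder();
        // Count records per key within each 5-second sliding window.
        final KTable<Windowed<String>, Long> counts = builder
            .stream("events", Consumed.with(Serdes.String(), Serdes.String()))
            .groupByKey()
            .windowedBy(SlidingWindows.ofTimeDifferenceAndGrace(Duration.ofSeconds(5), Duration.ofMillis(500)))
            .count();
    }
}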

                + * For time semantics, see {@link TimestampExtractor}. + * + * @see TimeWindows + * @see SessionWindows + * @see UnlimitedWindows + * @see JoinWindows + * @see KGroupedStream#windowedBy(SlidingWindows) + * @see CogroupedKStream#windowedBy(SlidingWindows) + * @see TimestampExtractor + */ + +public final class SlidingWindows { + + /** The size of the windows in milliseconds, defined by the max time difference between records. */ + private final long timeDifferenceMs; + + /** The grace period in milliseconds. */ + private final long graceMs; + + private SlidingWindows(final long timeDifferenceMs, final long graceMs) { + this.timeDifferenceMs = timeDifferenceMs; + this.graceMs = graceMs; + + if (timeDifferenceMs < 0) { + throw new IllegalArgumentException("Window time difference must not be negative."); + } + + if (graceMs < 0) { + throw new IllegalArgumentException("Window grace period must not be negative."); + } + } + + /** + * Return a window definition with the window size based on the given maximum time difference (inclusive) between + * records in the same window and given window grace period. Reject out-of-order events that arrive after {@code grace}. + * A window is closed when {@code stream-time > window-end + grace-period}. + *

                + * CAUTION: Using this method implicitly sets the grace period to zero, which means that any out-of-order + * records arriving after the window ends are considered late and will be dropped. + * + * @param timeDifference the max time difference (inclusive) between two records in a window + * @return a new window definition with no grace period. Note that this means out-of-order records arriving after the window end will be dropped + * @throws IllegalArgumentException if the timeDifference is negative or can't be represented as {@code long milliseconds} + */ + public static SlidingWindows ofTimeDifferenceWithNoGrace(final Duration timeDifference) throws IllegalArgumentException { + return ofTimeDifferenceAndGrace(timeDifference, ofMillis(NO_GRACE_PERIOD)); + } + + /** + * Return a window definition with the window size based on the given maximum time difference (inclusive) between + * records in the same window and given window grace period. Reject out-of-order events that arrive after {@code afterWindowEnd}. + * A window is closed when {@code stream-time > window-end + grace-period}. + * + * @param timeDifference the max time difference (inclusive) between two records in a window + * @param afterWindowEnd the grace period to admit out-of-order events to a window + * @return a new window definition with the specified grace period + * @throws IllegalArgumentException if the timeDifference or afterWindowEnd (grace period) is negative or can't be represented as {@code long milliseconds} + */ + public static SlidingWindows ofTimeDifferenceAndGrace(final Duration timeDifference, final Duration afterWindowEnd) throws IllegalArgumentException { + final String timeDifferenceMsgPrefix = prepareMillisCheckFailMsgPrefix(timeDifference, "timeDifference"); + final long timeDifferenceMs = validateMillisecondDuration(timeDifference, timeDifferenceMsgPrefix); + final String afterWindowEndMsgPrefix = prepareMillisCheckFailMsgPrefix(afterWindowEnd, "afterWindowEnd"); + final long afterWindowEndMs = validateMillisecondDuration(afterWindowEnd, afterWindowEndMsgPrefix); + + return new SlidingWindows(timeDifferenceMs, afterWindowEndMs); + } + + /** + * Return a window definition with the window size based on the given maximum time difference (inclusive) between + * records in the same window and given window grace period. Reject out-of-order events that arrive after {@code grace}. + * A window is closed when {@code stream-time > window-end + grace-period}. + * + * @param timeDifference the max time difference (inclusive) between two records in a window + * @param grace the grace period to admit out-of-order events to a window + * @return a new window definition + * @throws IllegalArgumentException if the specified window size is < 0 or grace < 0, or either can't be represented as {@code long milliseconds} + * @deprecated since 3.0. 
Use {@link #ofTimeDifferenceWithNoGrace(Duration)} or {@link #ofTimeDifferenceAndGrace(Duration, Duration)} instead + */ + @Deprecated + public static SlidingWindows withTimeDifferenceAndGrace(final Duration timeDifference, final Duration grace) throws IllegalArgumentException { + final String msgPrefixSize = prepareMillisCheckFailMsgPrefix(timeDifference, "timeDifference"); + final long timeDifferenceMs = validateMillisecondDuration(timeDifference, msgPrefixSize); + + final String msgPrefixGrace = prepareMillisCheckFailMsgPrefix(grace, "grace"); + final long graceMs = validateMillisecondDuration(grace, msgPrefixGrace); + + return new SlidingWindows(timeDifferenceMs, graceMs); + } + + public long timeDifferenceMs() { + return timeDifferenceMs; + } + + public long gracePeriodMs() { + return graceMs; + } + + @Override + public boolean equals(final Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + final SlidingWindows that = (SlidingWindows) o; + return timeDifferenceMs == that.timeDifferenceMs && + graceMs == that.graceMs; + } + + @Override + public int hashCode() { + return Objects.hash(timeDifferenceMs, graceMs); + } + + @Override + public String toString() { + return "SlidingWindows{" + + ", sizeMs=" + timeDifferenceMs + + ", graceMs=" + graceMs + + '}'; + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/StreamJoined.java b/streams/src/main/java/org/apache/kafka/streams/kstream/StreamJoined.java new file mode 100644 index 0000000..72d0922 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/StreamJoined.java @@ -0,0 +1,363 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.streams.kstream; + +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.streams.state.WindowBytesStoreSupplier; + +import java.util.HashMap; +import java.util.Map; + +/** + * Class used to configure the name of the join processor, the repartition topic name, + * state stores or state store names in Stream-Stream join. 
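For orientation, a minimal sketch of wiring a windowed stream-stream join through StreamJoined; the topic names, serdes, value types, and the 5-minute join window are assumptions made for illustration only.

import java.time.Duration;

import org.apache.kafka.common.serialization.Serdes;
import org.apache.kafka.streams.StreamsBuilder;
import org.apache.kafka.streams.kstream.Consumed;
import org.apache.kafka.streams.kstream.JoinWindows;
import org.apache.kafka.streams.kstream.KStream;
import org.apache.kafka.streams.kstream.StreamJoined;

public class StreamJoinedSketch {
    public static void main(final String[] args) {
        final StreamsBuilder builder = new StreamsBuilder();
        final KStream<String, String> clicks =
            builder.stream("clicks", Consumed.with(Serdes.String(), Serdes.String()));
        final KStream<String, Long> impressions =
            builder.stream("impressions", Consumed.with(Serdes.String(), Serdes.Long()));

        final KStream<String, String> joined = clicks.join(
            impressions,
            (clickValue, impressionCount) -> clickValue + "/" + impressionCount, // ValueJoiner
            JoinWindows.ofTimeDifferenceWithNoGrace(Duration.ofMinutes(5)),
            StreamJoined.with(Serdes.String(), Serdes.String(), Serdes.Long())
                        .withName("click-impression-join")   // names the join processors and any repartition topics
                        .withStoreName("click-impression")); // base name for the two join state stores
    }
}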
+ * @param the key type + * @param this value type + * @param other value type + */ +public class StreamJoined implements NamedOperation> { + + protected final Serde keySerde; + protected final Serde valueSerde; + protected final Serde otherValueSerde; + protected final WindowBytesStoreSupplier thisStoreSupplier; + protected final WindowBytesStoreSupplier otherStoreSupplier; + protected final String name; + protected final String storeName; + protected final boolean loggingEnabled; + protected final Map topicConfig; + + protected StreamJoined(final StreamJoined streamJoined) { + this(streamJoined.keySerde, + streamJoined.valueSerde, + streamJoined.otherValueSerde, + streamJoined.thisStoreSupplier, + streamJoined.otherStoreSupplier, + streamJoined.name, + streamJoined.storeName, + streamJoined.loggingEnabled, + streamJoined.topicConfig); + } + + private StreamJoined(final Serde keySerde, + final Serde valueSerde, + final Serde otherValueSerde, + final WindowBytesStoreSupplier thisStoreSupplier, + final WindowBytesStoreSupplier otherStoreSupplier, + final String name, + final String storeName, + final boolean loggingEnabled, + final Map topicConfig) { + this.keySerde = keySerde; + this.valueSerde = valueSerde; + this.otherValueSerde = otherValueSerde; + this.thisStoreSupplier = thisStoreSupplier; + this.otherStoreSupplier = otherStoreSupplier; + this.name = name; + this.storeName = storeName; + this.loggingEnabled = loggingEnabled; + this.topicConfig = topicConfig; + } + + /** + * Creates a StreamJoined instance with the provided store suppliers. The store suppliers must implement + * the {@link WindowBytesStoreSupplier} interface. The store suppliers must provide unique names or a + * {@link org.apache.kafka.streams.errors.StreamsException} is thrown. + * + * @param storeSupplier this store supplier + * @param otherStoreSupplier other store supplier + * @param the key type + * @param this value type + * @param other value type + * @return {@link StreamJoined} instance + */ + public static StreamJoined with(final WindowBytesStoreSupplier storeSupplier, + final WindowBytesStoreSupplier otherStoreSupplier) { + return new StreamJoined<>( + null, + null, + null, + storeSupplier, + otherStoreSupplier, + null, + null, + true, + new HashMap<>() + ); + } + + /** + * Creates a {@link StreamJoined} instance using the provided name for the state stores and hence the changelog + * topics for the join stores. The name for the stores will be ${applicationId}-<storeName>-this-join and ${applicationId}-<storeName>-other-join + * or ${applicationId}-<storeName>-outer-this-join and ${applicationId}-<storeName>-outer-other-join depending if the join is an inner-join + * or an outer join. The changelog topics will have the -changelog suffix. The user should note that even though the join stores will have a + * specified name, the stores will remain unavailable for querying. 
+ * + * Please note that if you are using {@link StreamJoined} to replace deprecated {@link KStream#join} functions with + * {@link Joined} parameters in order to set the name for the join processors, you would need to create the {@link StreamJoined} + * object first and then call {@link StreamJoined#withName} + * + * @param storeName The name to use for the store + * @param The key type + * @param This value type + * @param Other value type + * @return {@link StreamJoined} instance + */ + public static StreamJoined as(final String storeName) { + return new StreamJoined<>( + null, + null, + null, + null, + null, + null, + storeName, + true, + new HashMap<>() + ); + } + + + /** + * Creates a {@link StreamJoined} instance with the provided serdes to configure the stores + * for the join. + * @param keySerde The key serde + * @param valueSerde This value serde + * @param otherValueSerde Other value serde + * @param The key type + * @param This value type + * @param Other value type + * @return {@link StreamJoined} instance + */ + public static StreamJoined with(final Serde keySerde, + final Serde valueSerde, + final Serde otherValueSerde + ) { + return new StreamJoined<>( + keySerde, + valueSerde, + otherValueSerde, + null, + null, + null, + null, + true, + new HashMap<>() + ); + } + + /** + * Set the name to use for the join processor and the repartition topic(s) if required. + * @param name the name to use + * @return a new {@link StreamJoined} instance + */ + @Override + public StreamJoined withName(final String name) { + return new StreamJoined<>( + keySerde, + valueSerde, + otherValueSerde, + thisStoreSupplier, + otherStoreSupplier, + name, + storeName, + loggingEnabled, + topicConfig + ); + } + + /** + * Sets the base store name to use for both sides of the join. The name for the state stores and hence the changelog + * topics for the join stores. The name for the stores will be ${applicationId}-<storeName>-this-join and ${applicationId}-<storeName>-other-join + * or ${applicationId}-<storeName>-outer-this-join and ${applicationId}-<storeName>-outer-other-join depending if the join is an inner-join + * or an outer join. The changelog topics will have the -changelog suffix. The user should note that even though the join stores will have a + * specified name, the stores will remain unavailable for querying. 
+ * + * @param storeName the storeName to use + * @return a new {@link StreamJoined} instance + */ + public StreamJoined withStoreName(final String storeName) { + return new StreamJoined<>( + keySerde, + valueSerde, + otherValueSerde, + thisStoreSupplier, + otherStoreSupplier, + name, + storeName, + loggingEnabled, + topicConfig + ); + } + + /** + * Configure with the provided {@link Serde Serde} for the key + * @param keySerde the serde to use for the key + * @return a new {@link StreamJoined} configured with the keySerde + */ + public StreamJoined withKeySerde(final Serde keySerde) { + return new StreamJoined<>( + keySerde, + valueSerde, + otherValueSerde, + thisStoreSupplier, + otherStoreSupplier, + name, + storeName, + loggingEnabled, + topicConfig + ); + } + + /** + * Configure with the provided {@link Serde Serde} for this value + * @param valueSerde the serde to use for this value (calling or left side of the join) + * @return a new {@link StreamJoined} configured with the valueSerde + */ + public StreamJoined withValueSerde(final Serde valueSerde) { + return new StreamJoined<>( + keySerde, + valueSerde, + otherValueSerde, + thisStoreSupplier, + otherStoreSupplier, + name, + storeName, + loggingEnabled, + topicConfig + ); + } + + /** + * Configure with the provided {@link Serde Serde} for the other value + * @param otherValueSerde the serde to use for the other value (other or right side of the join) + * @return a new {@link StreamJoined} configured with the otherValueSerde + */ + public StreamJoined withOtherValueSerde(final Serde otherValueSerde) { + return new StreamJoined<>( + keySerde, + valueSerde, + otherValueSerde, + thisStoreSupplier, + otherStoreSupplier, + name, + storeName, + loggingEnabled, + topicConfig + ); + } + + /** + * Configure with the provided {@link WindowBytesStoreSupplier} for this store supplier. Please note + * this method only provides the store supplier for the left side of the join. If you wish to also provide a + * store supplier for the right (i.e., other) side you must use the {@link StreamJoined#withOtherStoreSupplier(WindowBytesStoreSupplier)} + * method + * @param thisStoreSupplier the store supplier to use for this store supplier (calling or left side of the join) + * @return a new {@link StreamJoined} configured with thisStoreSupplier + */ + public StreamJoined withThisStoreSupplier(final WindowBytesStoreSupplier thisStoreSupplier) { + return new StreamJoined<>( + keySerde, + valueSerde, + otherValueSerde, + thisStoreSupplier, + otherStoreSupplier, + name, + storeName, + loggingEnabled, + topicConfig + ); + } + + /** + * Configure with the provided {@link WindowBytesStoreSupplier} for the other store supplier. Please note + * this method only provides the store supplier for the right side of the join. If you wish to also provide a + * store supplier for the left side you must use the {@link StreamJoined#withThisStoreSupplier(WindowBytesStoreSupplier)} + * method + * @param otherStoreSupplier the store supplier to use for the other store supplier (other or right side of the join) + * @return a new {@link StreamJoined} configured with otherStoreSupplier + */ + public StreamJoined withOtherStoreSupplier(final WindowBytesStoreSupplier otherStoreSupplier) { + return new StreamJoined<>( + keySerde, + valueSerde, + otherValueSerde, + thisStoreSupplier, + otherStoreSupplier, + name, + storeName, + loggingEnabled, + topicConfig + ); + } + + /** + * Configures logging for both state stores. The changelog will be created with the provided configs. + *

                + * Note: Any unrecognized configs will be ignored + * @param config configs applied to the changelog topic + * @return a new {@link StreamJoined} configured with logging enabled + */ + public StreamJoined withLoggingEnabled(final Map config) { + + return new StreamJoined<>( + keySerde, + valueSerde, + otherValueSerde, + thisStoreSupplier, + otherStoreSupplier, + name, + storeName, + true, + config + ); + } + + /** + * Disable change logging for both state stores. + * @return a new {@link StreamJoined} configured with logging disabled + */ + public StreamJoined withLoggingDisabled() { + return new StreamJoined<>( + keySerde, + valueSerde, + otherValueSerde, + thisStoreSupplier, + otherStoreSupplier, + name, + storeName, + false, + new HashMap<>() + ); + } + + @Override + public String toString() { + return "StreamJoin{" + + "keySerde=" + keySerde + + ", valueSerde=" + valueSerde + + ", otherValueSerde=" + otherValueSerde + + ", thisStoreSupplier=" + thisStoreSupplier + + ", otherStoreSupplier=" + otherStoreSupplier + + ", name='" + name + '\'' + + ", storeName='" + storeName + '\'' + + ", loggingEnabled=" + loggingEnabled + + ", topicConfig=" + topicConfig + + '}'; + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/Suppressed.java b/streams/src/main/java/org/apache/kafka/streams/kstream/Suppressed.java new file mode 100644 index 0000000..31a53ce --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/Suppressed.java @@ -0,0 +1,198 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.kstream; + +import org.apache.kafka.streams.kstream.internals.suppress.EagerBufferConfigImpl; +import org.apache.kafka.streams.kstream.internals.suppress.FinalResultsSuppressionBuilder; +import org.apache.kafka.streams.kstream.internals.suppress.StrictBufferConfigImpl; +import org.apache.kafka.streams.kstream.internals.suppress.SuppressedInternal; + +import java.time.Duration; +import java.util.Collections; +import java.util.Map; + +public interface Suppressed extends NamedOperation> { + + /** + * Marker interface for a buffer configuration that is "strict" in the sense that it will strictly + * enforce the time bound and never emit early. + */ + interface StrictBufferConfig extends BufferConfig { + + } + + /** + * Marker interface for a buffer configuration that will strictly enforce size constraints + * (bytes and/or number of records) on the buffer, so it is suitable for reducing duplicate + * results downstream, but does not promise to eliminate them entirely. + */ + interface EagerBufferConfig extends BufferConfig { + + } + + interface BufferConfig> { + /** + * Create a size-constrained buffer in terms of the maximum number of keys it will store. 
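// Illustrative only, not part of this patch: a sketch of how the buffer configs declared in this
// interface combine with the suppression factories further below. Window size, grace period, the
// record limit, and the source KTable are assumptions.
KTable<Windowed<String>, Long> counts = events
    .groupByKey()
    .windowedBy(TimeWindows.ofSizeAndGrace(Duration.ofMinutes(1), Duration.ofSeconds(30)))
    .count();

// Emit exactly one "final" result per window and key, buffering without a size bound:
KTable<Windowed<String>, Long> finalCounts =
    counts.suppress(Suppressed.untilWindowCloses(Suppressed.BufferConfig.unbounded()));

// Or rate-limit updates, emitting the oldest buffered record early if the key bound is exceeded:
KTable<Windowed<String>, Long> throttled = counts.suppress(
    Suppressed.untilTimeLimit(Duration.ofSeconds(10),
        Suppressed.BufferConfig.maxRecords(1_000).emitEarlyWhenFull()));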
+ */ + static EagerBufferConfig maxRecords(final long recordLimit) { + return new EagerBufferConfigImpl(recordLimit, Long.MAX_VALUE, Collections.emptyMap()); + } + + /** + * Set a size constraint on the buffer in terms of the maximum number of keys it will store. + */ + BC withMaxRecords(final long recordLimit); + + /** + * Create a size-constrained buffer in terms of the maximum number of bytes it will use. + */ + static EagerBufferConfig maxBytes(final long byteLimit) { + return new EagerBufferConfigImpl(Long.MAX_VALUE, byteLimit, Collections.emptyMap()); + } + + /** + * Set a size constraint on the buffer, the maximum number of bytes it will use. + */ + BC withMaxBytes(final long byteLimit); + + /** + * Create a buffer unconstrained by size (either keys or bytes). + * + * As a result, the buffer will consume as much memory as it needs, dictated by the time bound. + * + * If there isn't enough heap available to meet the demand, the application will encounter an + * {@link OutOfMemoryError} and shut down (not guaranteed to be a graceful exit). Also, note that + * JVM processes under extreme memory pressure may exhibit poor GC behavior. + * + * This is a convenient option if you doubt that your buffer will be that large, but also don't + * wish to pick particular constraints, such as in testing. + * + * This buffer is "strict" in the sense that it will enforce the time bound or crash. + * It will never emit early. + */ + static StrictBufferConfig unbounded() { + return new StrictBufferConfigImpl(); + } + + /** + * Set the buffer to be unconstrained by size (either keys or bytes). + * + * As a result, the buffer will consume as much memory as it needs, dictated by the time bound. + * + * If there isn't enough heap available to meet the demand, the application will encounter an + * {@link OutOfMemoryError} and shut down (not guaranteed to be a graceful exit). Also, note that + * JVM processes under extreme memory pressure may exhibit poor GC behavior. + * + * This is a convenient option if you doubt that your buffer will be that large, but also don't + * wish to pick particular constraints, such as in testing. + * + * This buffer is "strict" in the sense that it will enforce the time bound or crash. + * It will never emit early. + */ + StrictBufferConfig withNoBound(); + + /** + * Set the buffer to gracefully shut down the application when any of its constraints are violated + * + * This buffer is "strict" in the sense that it will enforce the time bound or shut down. + * It will never emit early. + */ + StrictBufferConfig shutDownWhenFull(); + + /** + * Set the buffer to just emit the oldest records when any of its constraints are violated. + * + * This buffer is "not strict" in the sense that it may emit early, so it is suitable for reducing + * duplicate results downstream, but does not promise to eliminate them. + */ + EagerBufferConfig emitEarlyWhenFull(); + + /** + * Disable the changelog for this suppression's internal buffer. + * This will turn off fault-tolerance for the suppression, and will result in data loss in the event of a rebalance. + * By default the changelog is enabled. + * @return this + */ + BC withLoggingDisabled(); + + /** + * Indicates that a changelog topic should be created containing the currently suppressed + * records. Due to the short-lived nature of records in this topic it is likely more + * compactable than changelog topics for KTables. + * + * @param config Configs that should be applied to the changelog. Note: Any unrecognized + * configs will be ignored. 
+ * @return this + */ + BC withLoggingEnabled(final Map config); + } + + /** + * Configure the suppression to emit only the "final results" from the window. + * + * By default all Streams operators emit results whenever new results are available. + * This includes windowed operations. + * + * This configuration will instead emit just one result per key for each window, guaranteeing + * to deliver only the final result. This option is suitable for use cases in which the business logic + * requires a hard guarantee that only the final result is propagated. For example, sending alerts. + * + * To accomplish this, the operator will buffer events from the window until the window close (that is, + * until the end-time passes, and additionally until the grace period expires). Since windowed operators + * are required to reject out-of-order events for a window whose grace period is expired, there is an additional + * guarantee that the final results emitted from this suppression will match any queryable state upstream. + * + * @param bufferConfig A configuration specifying how much space to use for buffering intermediate results. + * This is required to be a "strict" config, since it would violate the "final results" + * property to emit early and then issue an update later. + * @return a "final results" mode suppression configuration + */ + static Suppressed untilWindowCloses(final StrictBufferConfig bufferConfig) { + return new FinalResultsSuppressionBuilder<>(null, bufferConfig); + } + + /** + * Configure the suppression to wait {@code timeToWaitForMoreEvents} amount of time after receiving a record + * before emitting it further downstream. If another record for the same key arrives in the mean time, it replaces + * the first record in the buffer but does not re-start the timer. + * + * @param timeToWaitForMoreEvents The amount of time to wait, per record, for new events. + * @param bufferConfig A configuration specifying how much space to use for buffering intermediate results. + * @param The key type for the KTable to apply this suppression to. + * @return a suppression configuration + */ + static Suppressed untilTimeLimit(final Duration timeToWaitForMoreEvents, final BufferConfig bufferConfig) { + return new SuppressedInternal<>(null, timeToWaitForMoreEvents, bufferConfig, null, false); + } + + /** + * Use the specified name for the suppression node in the topology. + *

<p>
+     * This can be used to insert a suppression without changing the rest of the topology names
+     * (and therefore not requiring an application reset).
+     * <p>
+     * Note however, that once a suppression has buffered some records, removing it from the topology would cause
+     * the loss of those records.
+     * <p>
                + * A suppression can be "disabled" with the configuration {@code untilTimeLimit(Duration.ZERO, ...}. + * + * @param name The name to be used for the suppression node and changelog topic + * @return The same configuration with the addition of the given {@code name}. + */ + @Override + Suppressed withName(final String name); +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/TableJoined.java b/streams/src/main/java/org/apache/kafka/streams/kstream/TableJoined.java new file mode 100644 index 0000000..70a3630 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/TableJoined.java @@ -0,0 +1,135 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.kstream; + +import org.apache.kafka.streams.processor.StreamPartitioner; + +import java.util.function.Function; + +/** + * The {@code TableJoined} class represents optional parameters that can be passed to + * {@link KTable#join(KTable, Function, ValueJoiner, TableJoined) KTable#join(KTable,Function,...)} and + * {@link KTable#leftJoin(KTable, Function, ValueJoiner, TableJoined) KTable#leftJoin(KTable,Function,...)} + * operations, for foreign key joins. + * @param this key type ; key type for the left (primary) table + * @param other key type ; key type for the right (foreign key) table + */ +public class TableJoined implements NamedOperation> { + + protected final StreamPartitioner partitioner; + protected final StreamPartitioner otherPartitioner; + protected final String name; + + private TableJoined(final StreamPartitioner partitioner, + final StreamPartitioner otherPartitioner, + final String name) { + this.partitioner = partitioner; + this.otherPartitioner = otherPartitioner; + this.name = name; + } + + protected TableJoined(final TableJoined tableJoined) { + this(tableJoined.partitioner, tableJoined.otherPartitioner, tableJoined.name); + } + + /** + * Create an instance of {@code TableJoined} with partitioner and otherPartitioner {@link StreamPartitioner} instances. + * {@code null} values are accepted and will result in the default partitioner being used. + * + * @param partitioner a {@link StreamPartitioner} that captures the partitioning strategy for the left (primary) + * table of the foreign key join. Specifying this option does not repartition or otherwise + * affect the source table; rather, this option informs the foreign key join on how internal + * topics should be partitioned in order to be co-partitioned with the left join table. + * The partitioning strategy must depend only on the message key and not the message value, + * else the source table is not supported with foreign key joins. This option may be left + * {@code null} if the source table uses the default partitioner. 
+ * @param otherPartitioner a {@link StreamPartitioner} that captures the partitioning strategy for the right (foreign + * key) table of the foreign key join. Specifying this option does not repartition or otherwise + * affect the source table; rather, this option informs the foreign key join on how internal + * topics should be partitioned in order to be co-partitioned with the right join table. + * The partitioning strategy must depend only on the message key and not the message value, + * else the source table is not supported with foreign key joins. This option may be left + * {@code null} if the source table uses the default partitioner. + * @param this key type ; key type for the left (primary) table + * @param other key type ; key type for the right (foreign key) table + * @return new {@code TableJoined} instance with the provided partitioners + */ + public static TableJoined with(final StreamPartitioner partitioner, + final StreamPartitioner otherPartitioner) { + return new TableJoined<>(partitioner, otherPartitioner, null); + } + + /** + * Create an instance of {@code TableJoined} with base name for all components of the join, including internal topics + * created to complete the join. + * + * @param name the name used as the base for naming components of the join including internal topics + * @param this key type ; key type for the left (primary) table + * @param other key type ; key type for the right (foreign key) table + * @return new {@code TableJoined} instance configured with the {@code name} + * + */ + public static TableJoined as(final String name) { + return new TableJoined<>(null, null, name); + } + + /** + * Set the custom {@link StreamPartitioner} to be used as part of computing the join. + * {@code null} values are accepted and will result in the default partitioner being used. + * + * @param partitioner a {@link StreamPartitioner} that captures the partitioning strategy for the left (primary) + * table of the foreign key join. Specifying this option does not repartition or otherwise + * affect the source table; rather, this option informs the foreign key join on how internal + * topics should be partitioned in order to be co-partitioned with the left join table. + * The partitioning strategy must depend only on the message key and not the message value, + * else the source table is not supported with foreign key joins. This option may be left + * {@code null} if the source table uses the default partitioner. + * @return new {@code TableJoined} instance configured with the {@code partitioner} + */ + public TableJoined withPartitioner(final StreamPartitioner partitioner) { + return new TableJoined<>(partitioner, otherPartitioner, name); + } + + /** + * Set the custom other {@link StreamPartitioner} to be used as part of computing the join. + * {@code null} values are accepted and will result in the default partitioner being used. + * + * @param otherPartitioner a {@link StreamPartitioner} that captures the partitioning strategy for the right (foreign + * key) table of the foreign key join. Specifying this option does not repartition or otherwise + * affect the source table; rather, this option informs the foreign key join on how internal + * topics should be partitioned in order to be co-partitioned with the right join table. + * The partitioning strategy must depend only on the message key and not the message value, + * else the source table is not supported with foreign key joins. 
This option may be left + * {@code null} if the source table uses the default partitioner. + * @return new {@code TableJoined} instance configured with the {@code otherPartitioner} + */ + public TableJoined withOtherPartitioner(final StreamPartitioner otherPartitioner) { + return new TableJoined<>(partitioner, otherPartitioner, name); + } + + /** + * Set the base name used for all components of the join, including internal topics + * created to complete the join. + * + * @param name the name used as the base for naming components of the join including internal topics + * @return new {@code TableJoined} instance configured with the {@code name} + */ + @Override + public TableJoined withName(final String name) { + return new TableJoined<>(partitioner, otherPartitioner, name); + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/TimeWindowedCogroupedKStream.java b/streams/src/main/java/org/apache/kafka/streams/kstream/TimeWindowedCogroupedKStream.java new file mode 100644 index 0000000..e4178bc --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/TimeWindowedCogroupedKStream.java @@ -0,0 +1,254 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.kstream; + +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.streams.KafkaStreams; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.StoreQueryParameters; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.Topology; +import org.apache.kafka.streams.state.ReadOnlyWindowStore; +import org.apache.kafka.streams.state.WindowStore; +import org.apache.kafka.streams.state.TimestampedWindowStore; + +import java.time.Duration; + +/** + * {@code TimeWindowedCogroupKStream} is an abstraction of a windowed record stream of {@link KeyValue} pairs. + * It is an intermediate representation of a {@link CogroupedKStream} in order to apply a windowed aggregation operation + * on the original {@link KGroupedStream} records resulting in a windowed {@link KTable} (a windowed + * {@code KTable} is a {@link KTable} with key type {@link Windowed Windowed}). + *

                + * The specified {@code windows} define either hopping time windows that can be overlapping or tumbling (c.f. + * {@link TimeWindows}) or they define landmark windows (c.f. {@link UnlimitedWindows}). + *

                + * The result is written into a local {@link WindowStore} (which is basically an ever-updating + * materialized view) that can be queried using the name provided in the {@link Materialized} instance. + * Furthermore, updates to the store are sent downstream into a windowed {@link KTable} changelog stream, where + * "windowed" implies that the {@link KTable} key is a combined key of the original record key and a window ID. + * New events are added to windows until their grace period ends (see {@link TimeWindows#grace(Duration)}). + *

                + * A {@code TimeWindowedCogroupedKStream} must be obtained from a {@link CogroupedKStream} via + * {@link CogroupedKStream#windowedBy(Windows)}. + * + * @param Type of keys + * @param Type of values + * @see KStream + * @see KGroupedStream + * @see CogroupedKStream + */ +public interface TimeWindowedCogroupedKStream { + + /** + * Aggregate the values of records in this stream by the grouped key and defined windows. + * Records with {@code null} key or value are ignored. + * The result is written into a local {@link WindowStore} (which is basically an ever-updating materialized view). + * Furthermore, updates to the store are sent downstream into a {@link KTable} changelog stream. + *
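// Illustrative only, not part of this patch: obtaining a TimeWindowedCogroupedKStream as described
// above and applying the aggregate(Initializer) overload declared below. The input streams and the
// UserActivity aggregate class (whose add* methods return the updated aggregate) are assumptions.
KGroupedStream<String, Long> clicksByUser = clicks.groupByKey();
KGroupedStream<String, Double> purchasesByUser = purchases.groupByKey();

KTable<Windowed<String>, UserActivity> activityPerHour = clicksByUser
    .cogroup((userId, click, activity) -> activity.addClick(click))
    .cogroup(purchasesByUser, (userId, purchase, activity) -> activity.addPurchase(purchase))
    .windowedBy(TimeWindows.ofSizeWithNoGrace(Duration.ofHours(1)))
    .aggregate(UserActivity::new);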

                + * The specified {@link Initializer} is applied directly before the first input record (per key) in each window is + * processed to provide an initial intermediate aggregation result that is used to process the first record for + * the window (per key). + * The specified {@link Aggregator} (as specified in {@link KGroupedStream#cogroup(Aggregator)} or + * {@link CogroupedKStream#cogroup(KGroupedStream, Aggregator)}) is applied for each input record and computes a new + * aggregate using the current aggregate (or for the very first record using the intermediate aggregation result + * provided via the {@link Initializer}) and the record's value. + * Thus, {@code aggregate()} can be used to compute aggregate functions like count or sum etc. + *

                + * The default key and value serde from the config will be used for serializing the result. + * If a different serde is required then you should use {@link #aggregate(Initializer, Materialized)}. + * Not all updates might get sent downstream, as an internal cache is used to deduplicate consecutive updates to + * the same window and key. + * The rate of propagated updates depends on your input data rate, the number of distinct keys, the number of + * parallel running Kafka Streams instances, and the {@link StreamsConfig configuration} parameters for + * {@link StreamsConfig#CACHE_MAX_BYTES_BUFFERING_CONFIG cache size}, and + * {@link StreamsConfig#COMMIT_INTERVAL_MS_CONFIG commit interval}. + *

                + * For failure and recovery the store (which always will be of type {@link TimestampedWindowStore}) will be backed by + * an internal changelog topic that will be created in Kafka. + * The changelog topic will be named "${applicationId}-${internalStoreName}-changelog", where "applicationId" is + * user-specified in {@link StreamsConfig} via parameter + * {@link StreamsConfig#APPLICATION_ID_CONFIG APPLICATION_ID_CONFIG}, "internalStoreName" is an internal name + * and "-changelog" is a fixed suffix. + * Note that the internal store name may not be queryable through Interactive Queries. + *

                + * You can retrieve all generated internal topic names via {@link Topology#describe()}. + * + * @param initializer an {@link Initializer} that computes an initial intermediate aggregation result. Cannot be {@code null}. + * @return a windowed {@link KTable} that contains "update" records with unmodified keys, and values that represent + * the latest (rolling) aggregate for each key within a window + */ + KTable, V> aggregate(final Initializer initializer); + + /** + * Aggregate the values of records in this stream by the grouped key and defined windows. + * Records with {@code null} key or value are ignored. + * The result is written into a local {@link WindowStore} (which is basically an ever-updating materialized view). + * Furthermore, updates to the store are sent downstream into a {@link KTable} changelog stream. + *

                + * The specified {@link Initializer} is applied directly before the first input record (per key) in each window is + * processed to provide an initial intermediate aggregation result that is used to process the first record for + * the window (per key). + * The specified {@link Aggregator} (as specified in {@link KGroupedStream#cogroup(Aggregator)} or + * {@link CogroupedKStream#cogroup(KGroupedStream, Aggregator)}) is applied for each input record and computes a new + * aggregate using the current aggregate (or for the very first record using the intermediate aggregation result + * provided via the {@link Initializer}) and the record's value. + * Thus, {@code aggregate()} can be used to compute aggregate functions like count or sum etc. + *

                + * The default key and value serde from the config will be used for serializing the result. + * If a different serde is required then you should use {@link #aggregate(Initializer, Named, Materialized)}. + * Not all updates might get sent downstream, as an internal cache is used to deduplicate consecutive updates to + * the same window and key. + * The rate of propagated updates depends on your input data rate, the number of distinct + * keys, the number of parallel running Kafka Streams instances, and the {@link StreamsConfig configuration} + * parameters for {@link StreamsConfig#CACHE_MAX_BYTES_BUFFERING_CONFIG cache size}, and + * {@link StreamsConfig#COMMIT_INTERVAL_MS_CONFIG commit interval}. + *

                + * For failure and recovery the store (which always will be of type {@link TimestampedWindowStore}) will be backed by + * an internal changelog topic that will be created in Kafka. + * The changelog topic will be named "${applicationId}-${internalStoreName}-changelog", where "applicationId" is + * user-specified in {@link StreamsConfig} via parameter + * {@link StreamsConfig#APPLICATION_ID_CONFIG APPLICATION_ID_CONFIG}, "internalStoreName" is an internal name + * and "-changelog" is a fixed suffix. + * Note that the internal store name may not be queryable through Interactive Queries. + *

                + * You can retrieve all generated internal topic names via {@link Topology#describe()}. + * + * @param initializer an {@link Initializer} that computes an initial intermediate aggregation result. Cannot be {@code null}. + * @param named a {@link Named} config used to name the processor in the topology. Cannot be {@code null}. + * @return a windowed {@link KTable} that contains "update" records with unmodified keys, and values that represent + * the latest (rolling) aggregate for each key within a window + */ + KTable, V> aggregate(final Initializer initializer, + final Named named); + + /** + * Aggregate the values of records in this stream by the grouped key and defined windows. + * Records with {@code null} key or value are ignored. + * The result is written into a local {@link WindowStore} (which is basically an ever-updating materialized view) + * that can be queried using the store name as provided with {@link Materialized}. + * Furthermore, updates to the store are sent downstream into a {@link KTable} changelog stream. + *

                + * The specified {@link Initializer} is applied directly before the first input record (per key) in each window is + * processed to provide an initial intermediate aggregation result that is used to process the first record for + * the window (per key). + * The specified {@link Aggregator} (as specified in {@link KGroupedStream#cogroup(Aggregator)} or + * {@link CogroupedKStream#cogroup(KGroupedStream, Aggregator)}) is applied for each input record and computes a new + * aggregate using the current aggregate (or for the very first record using the intermediate aggregation result + * provided via the {@link Initializer}) and the record's value. + * Thus, {@code aggregate()} can be used to compute aggregate functions like count or sum etc. + *

                + * Not all updates might get sent downstream, as an internal cache is used to deduplicate consecutive updates to + * the same window and key if caching is enabled on the {@link Materialized} instance. + * When caching is enabled the rate of propagated updates depends on your input data rate, the number of distinct + * keys, the number of parallel running Kafka Streams instances, and the {@link StreamsConfig configuration} + * parameters for {@link StreamsConfig#CACHE_MAX_BYTES_BUFFERING_CONFIG cache size}, and + * {@link StreamsConfig#COMMIT_INTERVAL_MS_CONFIG commit interval}. + *

                + * To query the local {@link ReadOnlyWindowStore} it must be obtained via + * {@link KafkaStreams#store(StoreQueryParameters) KafkaStreams#store(...)}: + *

<pre>{@code
+     * KafkaStreams streams = ... // counting words
+     * Store queryableStoreName = ... // the queryableStoreName should be the name of the store as defined by the Materialized instance
+     * ReadOnlyWindowStore<K, ValueAndTimestamp<V>> localWindowStore = streams.store(queryableStoreName, QueryableStoreTypes.<K, ValueAndTimestamp<V>>timestampedWindowStore());
+     *
+     * K key = "some-word";
+     * long fromTime = ...;
+     * long toTime = ...;
+     * WindowStoreIterator<ValueAndTimestamp<V>> aggregateStore = localWindowStore.fetch(key, fromTime, toTime); // key must be local (application state is shared over all running Kafka Streams instances)
+     * }</pre>
                + * For non-local keys, a custom RPC mechanism must be implemented using {@link KafkaStreams#metadataForAllStreamsClients()} to + * query the value of the key on a parallel running instance of your Kafka Streams application. + *
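// Illustrative only, not part of this patch: one way to locate the instance that hosts a given key
// before issuing the application-specific RPC mentioned above; queryMetadataForKey is used here
// instead of scanning metadataForAllStreamsClients(). Store name and key serializer are assumptions.
KeyQueryMetadata keyMetadata =
    streams.queryMetadataForKey("aggregate-store", "some-word", Serdes.String().serializer());
HostInfo activeHost = keyMetadata.activeHost();
// ... issue an HTTP (or other) request to activeHost.host() + ":" + activeHost.port()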

                + * For failure and recovery the store (which always will be of type {@link TimestampedWindowStore} -- regardless of what + * is specified in the parameter {@code materialized}) will be backed by an internal changelog topic that will be created in Kafka. + * Therefore, the store name defined by the {@link Materialized} instance must be a valid Kafka topic name and + * cannot contain characters other than ASCII alphanumerics, '.', '_' and '-'. + * The changelog topic will be named "${applicationId}-${storeName}-changelog", where "applicationId" is + * user-specified in {@link StreamsConfig} via parameter + * {@link StreamsConfig#APPLICATION_ID_CONFIG APPLICATION_ID_CONFIG}, "storeName" is the + * provide store name defined in {@link Materialized}, and "-changelog" is a fixed suffix. + *

                + * You can retrieve all generated internal topic names via {@link Topology#describe()}. + * + * @param initializer an {@link Initializer} that computes an initial intermediate aggregation result. Cannot be {@code null}. + * @param materialized a {@link Materialized} config used to materialize a state store. Cannot be {@code null}. + * @return a windowed {@link KTable} that contains "update" records with unmodified keys, and values that represent + * the latest (rolling) aggregate for each key within a window + */ + KTable, V> aggregate(final Initializer initializer, + final Materialized> materialized); + + /** + * Aggregate the values of records in this stream by the grouped key and defined windows. + * Records with {@code null} key or value are ignored. + * The result is written into a local {@link WindowStore} (which is basically an ever-updating materialized view) + * that can be queried using the store name as provided with {@link Materialized}. + * Furthermore, updates to the store are sent downstream into a {@link KTable} changelog stream. + *

                + * The specified {@link Initializer} is applied directly before the first input record (per key) in each window is + * processed to provide an initial intermediate aggregation result that is used to process the first record for + * the window (per key). + * The specified {@link Aggregator} (as specified in {@link KGroupedStream#cogroup(Aggregator)} or + * {@link CogroupedKStream#cogroup(KGroupedStream, Aggregator)}) is applied for each input record and computes a new + * aggregate using the current aggregate (or for the very first record using the intermediate aggregation result + * provided via the {@link Initializer}) and the record's value. + * Thus, {@code aggregate()} can be used to compute aggregate functions like count or sum etc. + *

                + * Not all updates might get sent downstream, as an internal cache will be used to deduplicate consecutive updates + * to the same window and key if caching is enabled on the {@link Materialized} instance. + * When caching is enabled the rate of propagated updates depends on your input data rate, the number of distinct + * keys, the number of parallel running Kafka Streams instances, and the {@link StreamsConfig configuration} + * parameters for {@link StreamsConfig#CACHE_MAX_BYTES_BUFFERING_CONFIG cache size}, and + * {@link StreamsConfig#COMMIT_INTERVAL_MS_CONFIG commit interval}. + *

                + * To query the local {@link ReadOnlyWindowStore} it must be obtained via + * {@link KafkaStreams#store(StoreQueryParameters) KafkaStreams#store(...)}: + *

<pre>{@code
+     * KafkaStreams streams = ... // counting words
+     * Store queryableStoreName = ... // the queryableStoreName should be the name of the store as defined by the Materialized instance
+     * ReadOnlyWindowStore<K, ValueAndTimestamp<V>> localWindowStore = streams.store(queryableStoreName, QueryableStoreTypes.<K, ValueAndTimestamp<V>>timestampedWindowStore());
+     *
+     * K key = "some-word";
+     * long fromTime = ...;
+     * long toTime = ...;
+     * WindowStoreIterator<ValueAndTimestamp<V>> aggregateStore = localWindowStore.fetch(key, fromTime, toTime); // key must be local (application state is shared over all running Kafka Streams instances)
+     * }</pre>
                + * For non-local keys, a custom RPC mechanism must be implemented using {@link KafkaStreams#metadataForAllStreamsClients()} to + * query the value of the key on a parallel running instance of your Kafka Streams application. + *

                + * For failure and recovery the store (which always will be of type {@link TimestampedWindowStore} -- regardless of what + * is specified in the parameter {@code materialized}) will be backed by an internal changelog topic that will be created in Kafka. + * Therefore, the store name defined by the {@link Materialized} instance must be a valid Kafka topic name and + * cannot contain characters other than ASCII alphanumerics, '.', '_' and '-'. + * The changelog topic will be named "${applicationId}-${storeName}-changelog", where "applicationId" is + * user-specified in {@link StreamsConfig} via parameter + * {@link StreamsConfig#APPLICATION_ID_CONFIG APPLICATION_ID_CONFIG}, "storeName" is the + * provide store name defined in {@link Materialized}, and "-changelog" is a fixed suffix. + *

                + * You can retrieve all generated internal topic names via {@link Topology#describe()}. + * + * @param initializer an {@link Initializer} that computes an initial intermediate aggregation result. Cannot be {@code null}. + * @param named a {@link Named} config used to name the processor in the topology. Cannot be {@code null}. + * @param materialized a {@link Materialized} config used to materialize a state store. Cannot be {@code null}. + * @return a windowed {@link KTable} that contains "update" records with unmodified keys, and values that represent + * the latest (rolling) aggregate for each key within a window + */ + KTable, V> aggregate(final Initializer initializer, + final Named named, + final Materialized> materialized); +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/TimeWindowedDeserializer.java b/streams/src/main/java/org/apache/kafka/streams/kstream/TimeWindowedDeserializer.java new file mode 100644 index 0000000..0be750e --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/TimeWindowedDeserializer.java @@ -0,0 +1,133 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.kstream; + +import org.apache.kafka.common.config.ConfigException; +import org.apache.kafka.common.serialization.Deserializer; +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.state.internals.WindowKeySchema; + +import java.util.Map; + +public class TimeWindowedDeserializer implements Deserializer> { + + private Long windowSize; + private boolean isChangelogTopic; + + private Deserializer inner; + + // Default constructor needed by Kafka + public TimeWindowedDeserializer() { + this(null, null); + } + + @Deprecated + public TimeWindowedDeserializer(final Deserializer inner) { + this(inner, Long.MAX_VALUE); + } + + public TimeWindowedDeserializer(final Deserializer inner, final Long windowSize) { + this.inner = inner; + this.windowSize = windowSize; + this.isChangelogTopic = false; + } + + public Long getWindowSize() { + return this.windowSize; + } + + @SuppressWarnings("unchecked") + @Override + public void configure(final Map configs, final boolean isKey) { + //check to see if the window size config is set and the window size is already set from the constructor + final Long configWindowSize; + if (configs.get(StreamsConfig.WINDOW_SIZE_MS_CONFIG) instanceof String) { + configWindowSize = Long.parseLong((String) configs.get(StreamsConfig.WINDOW_SIZE_MS_CONFIG)); + } else { + configWindowSize = (Long) configs.get(StreamsConfig.WINDOW_SIZE_MS_CONFIG); + } + if (windowSize != null && configWindowSize != null) { + throw new IllegalArgumentException("Window size should not be set in both the time windowed deserializer constructor and the window.size.ms config"); + } else if (windowSize == null && configWindowSize == null) { + throw new IllegalArgumentException("Window size needs to be set either through the time windowed deserializer " + + "constructor or the window.size.ms config but not both"); + } else { + windowSize = windowSize == null ? 
configWindowSize : windowSize; + } + + final String windowedInnerClassSerdeConfig = (String) configs.get(StreamsConfig.WINDOWED_INNER_CLASS_SERDE); + + Serde windowInnerClassSerde = null; + + if (windowedInnerClassSerdeConfig != null) { + try { + windowInnerClassSerde = Utils.newInstance(windowedInnerClassSerdeConfig, Serde.class); + } catch (final ClassNotFoundException e) { + throw new ConfigException(StreamsConfig.WINDOWED_INNER_CLASS_SERDE, windowedInnerClassSerdeConfig, + "Serde class " + windowedInnerClassSerdeConfig + " could not be found."); + } + } + + if (inner != null && windowedInnerClassSerdeConfig != null) { + if (!inner.getClass().getName().equals(windowInnerClassSerde.deserializer().getClass().getName())) { + throw new IllegalArgumentException("Inner class deserializer set using constructor " + + "(" + inner.getClass().getName() + ")" + + " is different from the one set in windowed.inner.class.serde config " + + "(" + windowInnerClassSerde.deserializer().getClass().getName() + ")."); + } + } else if (inner == null && windowedInnerClassSerdeConfig == null) { + throw new IllegalArgumentException("Inner class deserializer should be set either via constructor " + + "or via the windowed.inner.class.serde config"); + } else if (inner == null) + inner = windowInnerClassSerde.deserializer(); + } + + @Override + public Windowed deserialize(final String topic, final byte[] data) { + WindowedSerdes.verifyInnerDeserializerNotNull(inner, this); + + if (data == null || data.length == 0) { + return null; + } + + // toStoreKeyBinary was used to serialize the data. + if (this.isChangelogTopic) { + return WindowKeySchema.fromStoreKey(data, windowSize, inner, topic); + } + + // toBinary was used to serialize the data + return WindowKeySchema.from(data, windowSize, inner, topic); + } + + @Override + public void close() { + if (inner != null) { + inner.close(); + } + } + + public void setIsChangelogTopic(final boolean isChangelogTopic) { + this.isChangelogTopic = isChangelogTopic; + } + + // Only for testing + Deserializer innerDeserializer() { + return inner; + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/TimeWindowedKStream.java b/streams/src/main/java/org/apache/kafka/streams/kstream/TimeWindowedKStream.java new file mode 100644 index 0000000..c015e79 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/TimeWindowedKStream.java @@ -0,0 +1,651 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.kstream; + +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.streams.KafkaStreams; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.StoreQueryParameters; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.Topology; +import org.apache.kafka.streams.state.ReadOnlyWindowStore; +import org.apache.kafka.streams.state.WindowStore; +import org.apache.kafka.streams.state.TimestampedWindowStore; + +import java.time.Duration; + +/** + * {@code TimeWindowedKStream} is an abstraction of a windowed record stream of {@link KeyValue} pairs. + * It is an intermediate representation after a grouping and windowing of a {@link KStream} before an aggregation is + * applied to the new (partitioned) windows resulting in a windowed {@link KTable} (a windowed + * {@code KTable} is a {@link KTable} with key type {@link Windowed Windowed}). + *

                + * The specified {@code windows} define either hopping time windows that can be overlapping or tumbling (c.f. + * {@link TimeWindows}) or they define landmark windows (c.f. {@link UnlimitedWindows}). + *

                + * The result is written into a local {@link WindowStore} (which is basically an ever-updating + * materialized view) that can be queried using the name provided in the {@link Materialized} instance. + * Furthermore, updates to the store are sent downstream into a windowed {@link KTable} changelog stream, where + * "windowed" implies that the {@link KTable} key is a combined key of the original record key and a window ID. + * New events are added to {@link TimeWindows} until their grace period ends (see {@link TimeWindows#grace(Duration)}). + *

                + * A {@code TimeWindowedKStream} must be obtained from a {@link KGroupedStream} via + * {@link KGroupedStream#windowedBy(Windows)}. + * + * @param Type of keys + * @param Type of values + * @see KStream + * @see KGroupedStream + */ +public interface TimeWindowedKStream { + + /** + * Count the number of records in this stream by the grouped key and defined windows. + * Records with {@code null} key or value are ignored. + *
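// Illustrative only, not part of this patch: obtaining a TimeWindowedKStream as described above and
// applying the count() overload declared below. Topic name, serdes, and window size are assumptions.
KStream<String, String> words = builder.stream("words", Consumed.with(Serdes.String(), Serdes.String()));

KTable<Windowed<String>, Long> wordCountsPerMinute = words
    .groupByKey()
    .windowedBy(TimeWindows.ofSizeWithNoGrace(Duration.ofMinutes(1)))
    .count();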

                + * The result is written into a local {@link WindowStore} (which is basically an ever-updating materialized view). + * The default key serde from the config will be used for serializing the result. + * If a different serde is required then you should use {@link #count(Materialized)}. + * Furthermore, updates to the store are sent downstream into a {@link KTable} changelog stream. + * Not all updates might get sent downstream, as an internal cache is used to deduplicate consecutive updates to + * the same window and key. + * The rate of propagated updates depends on your input data rate, the number of distinct keys, the number of + * parallel running Kafka Streams instances, and the {@link StreamsConfig configuration} parameters for + * {@link StreamsConfig#CACHE_MAX_BYTES_BUFFERING_CONFIG cache size}, and + * {@link StreamsConfig#COMMIT_INTERVAL_MS_CONFIG commit interval}. + *

                + * For failure and recovery the store (which always will be of type {@link TimestampedWindowStore}) will be backed by + * an internal changelog topic that will be created in Kafka. + * The changelog topic will be named "${applicationId}-${internalStoreName}-changelog", where "applicationId" is + * user-specified in {@link StreamsConfig StreamsConfig} via parameter + * {@link StreamsConfig#APPLICATION_ID_CONFIG APPLICATION_ID_CONFIG}, "internalStoreName" is an internal name + * and "-changelog" is a fixed suffix. + * Note that the internal store name may not be queryable through Interactive Queries. + *

                + * You can retrieve all generated internal topic names via {@link Topology#describe()}. + * + * @return a windowed {@link KTable} that contains "update" records with unmodified keys and {@link Long} values + * that represent the latest (rolling) count (i.e., number of records) for each key within a window + */ + KTable, Long> count(); + + /** + * Count the number of records in this stream by the grouped key and defined windows. + * Records with {@code null} key or value are ignored. + *
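// Illustrative only, not part of this patch: printing the topology description to inspect the
// generated internal store and changelog topic names mentioned above. The builder variable is an
// assumption.
Topology topology = builder.build();
System.out.println(topology.describe());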

                + * The result is written into a local {@link WindowStore} (which is basically an ever-updating materialized view). + * The default key serde from the config will be used for serializing the result. + * If a different serde is required then you should use {@link #count(Named, Materialized)}. + * Furthermore, updates to the store are sent downstream into a {@link KTable} changelog stream. + * Not all updates might get sent downstream, as an internal cache is used to deduplicate consecutive updates to + * the same window and key. + * The rate of propagated updates depends on your input data rate, the number of distinct keys, the number of + * parallel running Kafka Streams instances, and the {@link StreamsConfig configuration} parameters for + * {@link StreamsConfig#CACHE_MAX_BYTES_BUFFERING_CONFIG cache size}, and + * {@link StreamsConfig#COMMIT_INTERVAL_MS_CONFIG commit interval}. + *

                + * For failure and recovery the store (which always will be of type {@link TimestampedWindowStore}) will be backed by + * an internal changelog topic that will be created in Kafka. + * The changelog topic will be named "${applicationId}-${internalStoreName}-changelog", where "applicationId" is + * user-specified in {@link StreamsConfig StreamsConfig} via parameter + * {@link StreamsConfig#APPLICATION_ID_CONFIG APPLICATION_ID_CONFIG}, "internalStoreName" is an internal name + * and "-changelog" is a fixed suffix. + * Note that the internal store name may not be queryable through Interactive Queries. + *

                + * You can retrieve all generated internal topic names via {@link Topology#describe()}. + * + * @param named a {@link Named} config used to name the processor in the topology. Cannot be {@code null}. + * @return a windowed {@link KTable} that contains "update" records with unmodified keys and {@link Long} values + * that represent the latest (rolling) count (i.e., number of records) for each key within a window + */ + KTable, Long> count(final Named named); + + /** + * Count the number of records in this stream by the grouped key and defined windows. + * Records with {@code null} key or value are ignored. + *

                + * The result is written into a local {@link WindowStore} (which is basically an ever-updating materialized view) + * that can be queried using the name provided with {@link Materialized}. + * Furthermore, updates to the store are sent downstream into a {@link KTable} changelog stream. + *

                + * Not all updates might get sent downstream, as an internal cache will be used to deduplicate consecutive updates + * to the same window and key if caching is enabled on the {@link Materialized} instance. + * When caching is enabled the rate of propagated updates depends on your input data rate, the number of distinct + * keys, the number of parallel running Kafka Streams instances, and the {@link StreamsConfig configuration} + * parameters for {@link StreamsConfig#CACHE_MAX_BYTES_BUFFERING_CONFIG cache size}, and + * {@link StreamsConfig#COMMIT_INTERVAL_MS_CONFIG commit interval}. + *

                + * To query the local {@link ReadOnlyWindowStore} it must be obtained via + * {@link KafkaStreams#store(StoreQueryParameters) KafkaStreams#store(...)}: + *

<pre>{@code
+     * KafkaStreams streams = ... // counting words
+     * Store queryableStoreName = ... // the queryableStoreName should be the name of the store as defined by the Materialized instance
+     * ReadOnlyWindowStore<K, ValueAndTimestamp<Long>> localWindowStore = streams.store(queryableStoreName, QueryableStoreTypes.<K, ValueAndTimestamp<Long>>timestampedWindowStore());
+     *
+     * K key = "some-word";
+     * long fromTime = ...;
+     * long toTime = ...;
+     * WindowStoreIterator<ValueAndTimestamp<Long>> countForWordsForWindows = localWindowStore.fetch(key, fromTime, toTime); // key must be local (application state is shared over all running Kafka Streams instances)
+     * }</pre>
                + * For non-local keys, a custom RPC mechanism must be implemented using {@link KafkaStreams#metadataForAllStreamsClients()} to + * query the value of the key on a parallel running instance of your Kafka Streams application. + *

                + * For failure and recovery the store (which always will be of type {@link TimestampedWindowStore} -- regardless of what + * is specified in the parameter {@code materialized}) will be backed by an internal changelog topic that will be created in Kafka. + * Therefore, the store name defined by the Materialized instance must be a valid Kafka topic name and cannot + * contain characters other than ASCII alphanumerics, '.', '_' and '-'. + * The changelog topic will be named "${applicationId}-${storeName}-changelog", where "applicationId" is + * user-specified in {@link StreamsConfig} via parameter + * {@link StreamsConfig#APPLICATION_ID_CONFIG APPLICATION_ID_CONFIG}, "storeName" is the provide store name defined + * in {@code Materialized}, and "-changelog" is a fixed suffix. + *

                + * You can retrieve all generated internal topic names via {@link Topology#describe()}. + * + * @param materialized an instance of {@link Materialized} used to materialize a state store. Cannot be {@code null}. + * Note: the valueSerde will be automatically set to {@link org.apache.kafka.common.serialization.Serdes#Long() Serdes#Long()} + * if there is no valueSerde provided + * @return a windowed {@link KTable} that contains "update" records with unmodified keys and {@link Long} values + * that represent the latest (rolling) count (i.e., number of records) for each key within a window + */ + KTable, Long> count(final Materialized> materialized); + + /** + * Count the number of records in this stream by the grouped key and defined windows. + * Records with {@code null} key or value are ignored. + *
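// Illustrative only, not part of this patch: the count(Materialized) overload declared above with an
// explicitly named, queryable store. The application id and store name are assumptions; with
// application.id "word-count-app" the changelog topic would be named
// "word-count-app-windowed-word-counts-changelog", following the rule described above.
KTable<Windowed<String>, Long> counts = words
    .groupByKey()
    .windowedBy(TimeWindows.ofSizeWithNoGrace(Duration.ofMinutes(1)))
    .count(Materialized.<String, Long, WindowStore<Bytes, byte[]>>as("windowed-word-counts")
        .withKeySerde(Serdes.String())); // value serde defaults to Serdes.Long(), as noted above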

                + * The result is written into a local {@link WindowStore} (which is basically an ever-updating materialized view) + * that can be queried using the name provided with {@link Materialized}. + * Furthermore, updates to the store are sent downstream into a {@link KTable} changelog stream. + *

                + * Not all updates might get sent downstream, as an internal cache will be used to deduplicate consecutive updates + * to the same window and key if caching is enabled on the {@link Materialized} instance. + * When caching is enabled the rate of propagated updates depends on your input data rate, the number of distinct + * keys, the number of parallel running Kafka Streams instances, and the {@link StreamsConfig configuration} + * parameters for {@link StreamsConfig#CACHE_MAX_BYTES_BUFFERING_CONFIG cache size}, and + * {@link StreamsConfig#COMMIT_INTERVAL_MS_CONFIG commit interval} + *

                + * To query the local {@link ReadOnlyWindowStore} it must be obtained via + * {@link KafkaStreams#store(StoreQueryParameters) KafkaStreams#store(...)}: + *

<pre>{@code
+     * KafkaStreams streams = ... // counting words
+     * Store queryableStoreName = ... // the queryableStoreName should be the name of the store as defined by the Materialized instance
+     * ReadOnlyWindowStore<K, ValueAndTimestamp<Long>> localWindowStore = streams.store(queryableStoreName, QueryableStoreTypes.<K, ValueAndTimestamp<Long>>timestampedWindowStore());
+     *
+     * K key = "some-word";
+     * long fromTime = ...;
+     * long toTime = ...;
+     * WindowStoreIterator<ValueAndTimestamp<Long>> countForWordsForWindows = localWindowStore.fetch(key, fromTime, toTime); // key must be local (application state is shared over all running Kafka Streams instances)
+     * }</pre>
                + * For non-local keys, a custom RPC mechanism must be implemented using {@link KafkaStreams#metadataForAllStreamsClients()} to + * query the value of the key on a parallel running instance of your Kafka Streams application. + *

                + * For failure and recovery the store (which always will be of type {@link TimestampedWindowStore} -- regardless of what + * is specified in the parameter {@code materialized}) will be backed by an internal changelog topic that will be created in Kafka. + * Therefore, the store name defined by the Materialized instance must be a valid Kafka topic name and cannot + * contain characters other than ASCII alphanumerics, '.', '_' and '-'. + * The changelog topic will be named "${applicationId}-${storeName}-changelog", where "applicationId" is + * user-specified in {@link StreamsConfig} via parameter + * {@link StreamsConfig#APPLICATION_ID_CONFIG APPLICATION_ID_CONFIG}, "storeName" is the provide store name defined + * in {@code Materialized}, and "-changelog" is a fixed suffix. + *

                + * You can retrieve all generated internal topic names via {@link Topology#describe()}. + * + * @param named a {@link Named} config used to name the processor in the topology. Cannot be {@code null}. + * @param materialized an instance of {@link Materialized} used to materialize a state store. Cannot be {@code null}. + * Note: the valueSerde will be automatically set to {@link org.apache.kafka.common.serialization.Serdes#Long() Serdes#Long()} + * if there is no valueSerde provided + * @return a windowed {@link KTable} that contains "update" records with unmodified keys and {@link Long} values + * that represent the latest (rolling) count (i.e., number of records) for each key within a window + */ + KTable, Long> count(final Named named, + final Materialized> materialized); + + /** + * Aggregate the values of records in this stream by the grouped key and defined windows. + * Records with {@code null} key or value are ignored. + * Aggregating is a generalization of {@link #reduce(Reducer) combining via reduce(...)} as it, for example, + * allows the result to have a different type than the input values. + * The result is written into a local {@link WindowStore} (which is basically an ever-updating materialized view). + * Furthermore, updates to the store are sent downstream into a {@link KTable} changelog stream. + *

                + * The specified {@link Initializer} is applied directly before the first input record (per key) in each window is + * processed to provide an initial intermediate aggregation result that is used to process the first record for + * the window (per key). + * The specified {@link Aggregator} is applied for each input record and computes a new aggregate using the current + * aggregate (or for the very first record using the intermediate aggregation result provided via the + * {@link Initializer}) and the record's value. + * Thus, {@code aggregate()} can be used to compute aggregate functions like count (c.f. {@link #count()}). + *
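
For orientation, here is a minimal usage sketch of this overload. It assumes a KGroupedStream<String, String> named groupedStream; the window size and aggregation logic are illustrative and not part of the patch:

    import java.time.Duration;
    import org.apache.kafka.streams.kstream.KGroupedStream;
    import org.apache.kafka.streams.kstream.KTable;
    import org.apache.kafka.streams.kstream.TimeWindows;
    import org.apache.kafka.streams.kstream.Windowed;

    // assuming: KGroupedStream<String, String> groupedStream = ...;
    // Sum up the lengths of all values per key and per 5-minute window.
    KTable<Windowed<String>, Long> charCounts = groupedStream
        .windowedBy(TimeWindows.ofSizeWithNoGrace(Duration.ofMinutes(5)))
        .aggregate(
            () -> 0L,                                               // Initializer: starting aggregate per window
            (key, value, aggregate) -> aggregate + value.length()); // Aggregator: fold each record into the aggregate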

                + * The default key and value serde from the config will be used for serializing the result. + * If a different serde is required then you should use {@link #aggregate(Initializer, Aggregator, Materialized)}. + * Not all updates might get sent downstream, as an internal cache is used to deduplicate consecutive updates to + * the same window and key. + * The rate of propagated updates depends on your input data rate, the number of distinct keys, the number of + * parallel running Kafka Streams instances, and the {@link StreamsConfig configuration} parameters for + * {@link StreamsConfig#CACHE_MAX_BYTES_BUFFERING_CONFIG cache size}, and + * {@link StreamsConfig#COMMIT_INTERVAL_MS_CONFIG commit interval}. + *

                + * For failure and recovery the store (which always will be of type {@link TimestampedWindowStore}) will be backed by + * an internal changelog topic that will be created in Kafka. + * The changelog topic will be named "${applicationId}-${internalStoreName}-changelog", where "applicationId" is + * user-specified in {@link StreamsConfig} via parameter + * {@link StreamsConfig#APPLICATION_ID_CONFIG APPLICATION_ID_CONFIG}, "internalStoreName" is an internal name + * and "-changelog" is a fixed suffix. + * Note that the internal store name may not be queryable through Interactive Queries. + *

                + * You can retrieve all generated internal topic names via {@link Topology#describe()}. + * + * @param initializer an {@link Initializer} that computes an initial intermediate aggregation result. Cannot be {@code null}. + * @param aggregator an {@link Aggregator} that computes a new aggregate result. Cannot be {@code null}. + * @param the value type of the resulting {@link KTable} + * @return a windowed {@link KTable} that contains "update" records with unmodified keys, and values that represent + * the latest (rolling) aggregate for each key within a window + */ + KTable, VR> aggregate(final Initializer initializer, + final Aggregator aggregator); + + /** + * Aggregate the values of records in this stream by the grouped key and defined windows. + * Records with {@code null} key or value are ignored. + * Aggregating is a generalization of {@link #reduce(Reducer) combining via reduce(...)} as it, for example, + * allows the result to have a different type than the input values. + * The result is written into a local {@link WindowStore} (which is basically an ever-updating materialized view). + * Furthermore, updates to the store are sent downstream into a {@link KTable} changelog stream. + *

                + * The specified {@link Initializer} is applied directly before the first input record (per key) in each window is + * processed to provide an initial intermediate aggregation result that is used to process the first record for + * the window (per key). + * The specified {@link Aggregator} is applied for each input record and computes a new aggregate using the current + * aggregate (or for the very first record using the intermediate aggregation result provided via the + * {@link Initializer}) and the record's value. + * Thus, {@code aggregate()} can be used to compute aggregate functions like count (c.f. {@link #count()}). + *

                + * The default key and value serde from the config will be used for serializing the result. + * If a different serde is required then you should use + * {@link #aggregate(Initializer, Aggregator, Named, Materialized)}. + * Not all updates might get sent downstream, as an internal cache is used to deduplicate consecutive updates to + * the same window and key. + * The rate of propagated updates depends on your input data rate, the number of distinct + * keys, the number of parallel running Kafka Streams instances, and the {@link StreamsConfig configuration} + * parameters for {@link StreamsConfig#CACHE_MAX_BYTES_BUFFERING_CONFIG cache size}, and + * {@link StreamsConfig#COMMIT_INTERVAL_MS_CONFIG commit interval}. + *

                + * For failure and recovery the store (which always will be of type {@link TimestampedWindowStore}) will be backed by + * an internal changelog topic that will be created in Kafka. + * The changelog topic will be named "${applicationId}-${internalStoreName}-changelog", where "applicationId" is + * user-specified in {@link StreamsConfig} via parameter + * {@link StreamsConfig#APPLICATION_ID_CONFIG APPLICATION_ID_CONFIG}, "internalStoreName" is an internal name + * and "-changelog" is a fixed suffix. + * Note that the internal store name may not be queryable through Interactive Queries. + *

                + * You can retrieve all generated internal topic names via {@link Topology#describe()}. + * + * @param initializer an {@link Initializer} that computes an initial intermediate aggregation result. Cannot be {@code null}. + * @param aggregator an {@link Aggregator} that computes a new aggregate result. Cannot be {@code null}. + * @param named a {@link Named} config used to name the processor in the topology. Cannot be {@code null}. + * @param the value type of the resulting {@link KTable} + * @return a windowed {@link KTable} that contains "update" records with unmodified keys, and values that represent + * the latest (rolling) aggregate for each key within a window + */ + KTable, VR> aggregate(final Initializer initializer, + final Aggregator aggregator, + final Named named); + + /** + * Aggregate the values of records in this stream by the grouped key and defined windows. + * Records with {@code null} key or value are ignored. + * Aggregating is a generalization of {@link #reduce(Reducer) combining via reduce(...)} as it, for example, + * allows the result to have a different type than the input values. + * The result is written into a local {@link WindowStore} (which is basically an ever-updating materialized view) + * that can be queried using the store name as provided with {@link Materialized}. + * Furthermore, updates to the store are sent downstream into a {@link KTable} changelog stream. + *

                + * The specified {@link Initializer} is applied directly before the first input record (per key) in each window is + * processed to provide an initial intermediate aggregation result that is used to process the first record for + * the window (per key). + * The specified {@link Aggregator} is applied for each input record and computes a new aggregate using the current + * aggregate (or for the very first record using the intermediate aggregation result provided via the + * {@link Initializer}) and the record's value. + * Thus, {@code aggregate()} can be used to compute aggregate functions like count (c.f. {@link #count()}). + *

                + * Not all updates might get sent downstream, as an internal cache is used to deduplicate consecutive updates to + * the same window and key if caching is enabled on the {@link Materialized} instance. + * When caching is enabled the rate of propagated updates depends on your input data rate, the number of distinct + * keys, the number of parallel running Kafka Streams instances, and the {@link StreamsConfig configuration} + * parameters for {@link StreamsConfig#CACHE_MAX_BYTES_BUFFERING_CONFIG cache size}, and + * {@link StreamsConfig#COMMIT_INTERVAL_MS_CONFIG commit interval}. + *

                + * To query the local {@link ReadOnlyWindowStore} it must be obtained via + * {@link KafkaStreams#store(StoreQueryParameters) KafkaStreams#store(...)}: + *

                {@code
+     * KafkaStreams streams = ... // counting words
+     * String queryableStoreName = ... // the queryableStoreName should be the name of the store as defined by the Materialized instance
+     * ReadOnlyWindowStore<K, ValueAndTimestamp<VR>> localWindowStore = streams.store(StoreQueryParameters.fromNameAndType(queryableStoreName, QueryableStoreTypes.<K, ValueAndTimestamp<VR>>timestampedWindowStore()));
+     *
+     * K key = "some-word";
+     * long fromTime = ...;
+     * long toTime = ...;
+     * WindowStoreIterator<ValueAndTimestamp<VR>> aggregateStore = localWindowStore.fetch(key, fromTime, toTime); // key must be local (application state is shared over all running Kafka Streams instances)
                +     * }
                + * For non-local keys, a custom RPC mechanism must be implemented using {@link KafkaStreams#metadataForAllStreamsClients()} to + * query the value of the key on a parallel running instance of your Kafka Streams application. + *

+ * For failure and recovery the store (which always will be of type {@link TimestampedWindowStore} -- regardless of what
+ * is specified in the parameter {@code materialized}) will be backed by an internal changelog topic that will be created in Kafka.
+ * Therefore, the store name defined by the {@link Materialized} instance must be a valid Kafka topic name and
+ * cannot contain characters other than ASCII alphanumerics, '.', '_' and '-'.
+ * The changelog topic will be named "${applicationId}-${storeName}-changelog", where "applicationId" is
+ * user-specified in {@link StreamsConfig} via parameter
+ * {@link StreamsConfig#APPLICATION_ID_CONFIG APPLICATION_ID_CONFIG}, "storeName" is the
+ * provided store name defined in {@link Materialized}, and "-changelog" is a fixed suffix.
+ *
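
Because the store name passed through Materialized also becomes part of the changelog topic name, a typical (purely illustrative) way to build such a Materialized instance looks roughly like this; the store name "windowed-agg-store" is made up:

    import org.apache.kafka.common.serialization.Serdes;
    import org.apache.kafka.common.utils.Bytes;
    import org.apache.kafka.streams.kstream.Materialized;
    import org.apache.kafka.streams.state.WindowStore;

    // The name must be a valid Kafka topic name: ASCII alphanumerics, '.', '_' and '-' only.
    Materialized<String, Long, WindowStore<Bytes, byte[]>> materialized =
        Materialized.<String, Long, WindowStore<Bytes, byte[]>>as("windowed-agg-store")
            .withKeySerde(Serdes.String())
            .withValueSerde(Serdes.Long());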

                + * You can retrieve all generated internal topic names via {@link Topology#describe()}. + * + * @param initializer an {@link Initializer} that computes an initial intermediate aggregation result. Cannot be {@code null}. + * @param aggregator an {@link Aggregator} that computes a new aggregate result. Cannot be {@code null}. + * @param materialized a {@link Materialized} config used to materialize a state store. Cannot be {@code null}. + * @param the value type of the resulting {@link KTable} + * @return a windowed {@link KTable} that contains "update" records with unmodified keys, and values that represent + * the latest (rolling) aggregate for each key within a window + */ + KTable, VR> aggregate(final Initializer initializer, + final Aggregator aggregator, + final Materialized> materialized); + + /** + * Aggregate the values of records in this stream by the grouped key and defined windows. + * Records with {@code null} key or value are ignored. + * Aggregating is a generalization of {@link #reduce(Reducer) combining via reduce(...)} as it, for example, + * allows the result to have a different type than the input values. + * The result is written into a local {@link WindowStore} (which is basically an ever-updating materialized view) + * that can be queried using the store name as provided with {@link Materialized}. + * Furthermore, updates to the store are sent downstream into a {@link KTable} changelog stream. + *

                + * The specified {@link Initializer} is applied directly before the first input record (per key) in each window is + * processed to provide an initial intermediate aggregation result that is used to process the first record for + * the window (per key). + * The specified {@link Aggregator} is applied for each input record and computes a new aggregate using the current + * aggregate (or for the very first record using the intermediate aggregation result provided via the + * {@link Initializer}) and the record's value. + * Thus, {@code aggregate()} can be used to compute aggregate functions like count (c.f. {@link #count()}). + *

+ * Not all updates might get sent downstream, as an internal cache will be used to deduplicate consecutive updates
+ * to the same window and key if caching is enabled on the {@link Materialized} instance.
+ * When caching is enabled the rate of propagated updates depends on your input data rate, the number of distinct
+ * keys, the number of parallel running Kafka Streams instances, and the {@link StreamsConfig configuration}
+ * parameters for {@link StreamsConfig#CACHE_MAX_BYTES_BUFFERING_CONFIG cache size}, and
+ * {@link StreamsConfig#COMMIT_INTERVAL_MS_CONFIG commit interval}.
+ *

                + * To query the local {@link ReadOnlyWindowStore} it must be obtained via + * {@link KafkaStreams#store(StoreQueryParameters) KafkaStreams#store(...)}: + *

                {@code
+     * KafkaStreams streams = ... // counting words
+     * String queryableStoreName = ... // the queryableStoreName should be the name of the store as defined by the Materialized instance
+     * ReadOnlyWindowStore<K, ValueAndTimestamp<VR>> localWindowStore = streams.store(StoreQueryParameters.fromNameAndType(queryableStoreName, QueryableStoreTypes.<K, ValueAndTimestamp<VR>>timestampedWindowStore()));
+     *
+     * K key = "some-word";
+     * long fromTime = ...;
+     * long toTime = ...;
+     * WindowStoreIterator<ValueAndTimestamp<VR>> aggregateStore = localWindowStore.fetch(key, fromTime, toTime); // key must be local (application state is shared over all running Kafka Streams instances)
                +     * }
                + * For non-local keys, a custom RPC mechanism must be implemented using {@link KafkaStreams#metadataForAllStreamsClients()} to + * query the value of the key on a parallel running instance of your Kafka Streams application. + *

+ * For failure and recovery the store (which always will be of type {@link TimestampedWindowStore} -- regardless of what
+ * is specified in the parameter {@code materialized}) will be backed by an internal changelog topic that will be created in Kafka.
+ * Therefore, the store name defined by the {@link Materialized} instance must be a valid Kafka topic name and
+ * cannot contain characters other than ASCII alphanumerics, '.', '_' and '-'.
+ * The changelog topic will be named "${applicationId}-${storeName}-changelog", where "applicationId" is
+ * user-specified in {@link StreamsConfig} via parameter
+ * {@link StreamsConfig#APPLICATION_ID_CONFIG APPLICATION_ID_CONFIG}, "storeName" is the
+ * provided store name defined in {@link Materialized}, and "-changelog" is a fixed suffix.
+ *

                + * You can retrieve all generated internal topic names via {@link Topology#describe()}. + * + * @param initializer an {@link Initializer} that computes an initial intermediate aggregation result. Cannot be {@code null}. + * @param aggregator an {@link Aggregator} that computes a new aggregate result. Cannot be {@code null}. + * @param named a {@link Named} config used to name the processor in the topology. Cannot be {@code null}. + * @param materialized a {@link Materialized} config used to materialize a state store. Cannot be {@code null}. + * @param the value type of the resulting {@link KTable} + * @return a windowed {@link KTable} that contains "update" records with unmodified keys, and values that represent + * the latest (rolling) aggregate for each key within a window + */ + KTable, VR> aggregate(final Initializer initializer, + final Aggregator aggregator, + final Named named, + final Materialized> materialized); + + /** + * Combine the values of records in this stream by the grouped key and defined windows. + * Records with {@code null} key or value are ignored. + * Combining implies that the type of the aggregate result is the same as the type of the input value + * (c.f. {@link #aggregate(Initializer, Aggregator)}). + *

                + * The result is written into a local {@link WindowStore} (which is basically an ever-updating materialized view). + * Furthermore, updates to the store are sent downstream into a {@link KTable} changelog stream. + * The default key and value serde from the config will be used for serializing the result. + * If a different serde is required then you should use {@link #reduce(Reducer, Materialized)} . + *

+ * The value of the first record per window initializes the aggregation result.
+ * The specified {@link Reducer} is applied for each additional input record per window and computes a new
+ * aggregate using the current aggregate (first argument) and the record's value (second argument):
+ *

                {@code
+     * // A Reducer<Long> that sums values:
+     * new Reducer<Long>() {
                +     *   public Long apply(Long aggValue, Long currValue) {
                +     *     return aggValue + currValue;
                +     *   }
                +     * }
                +     * }
                + * Thus, {@code reduce()} can be used to compute aggregate functions like sum, min, or max. + *
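
As a purely illustrative sketch (assuming a KGroupedStream<String, Long> named groupedStream and an arbitrary window size), the anonymous class above can also be written as a method reference:

    import java.time.Duration;
    import org.apache.kafka.streams.kstream.KGroupedStream;
    import org.apache.kafka.streams.kstream.KTable;
    import org.apache.kafka.streams.kstream.TimeWindows;
    import org.apache.kafka.streams.kstream.Windowed;

    // assuming: KGroupedStream<String, Long> groupedStream = ...;
    KTable<Windowed<String>, Long> windowedSums = groupedStream
        .windowedBy(TimeWindows.ofSizeWithNoGrace(Duration.ofMinutes(1)))
        .reduce(Long::sum); // Reducer<Long>: (aggValue, currValue) -> aggValue + currValue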

                + * Not all updates might get sent downstream, as an internal cache is used to deduplicate consecutive updates to + * the same window and key. + * The rate of propagated updates depends on your input data rate, the number of distinct keys, the number of + * parallel running Kafka Streams instances, and the {@link StreamsConfig configuration} parameters for + * {@link StreamsConfig#CACHE_MAX_BYTES_BUFFERING_CONFIG cache size}, and + * {@link StreamsConfig#COMMIT_INTERVAL_MS_CONFIG commit interval}. + *

                + * For failure and recovery the store (which always will be of type {@link TimestampedWindowStore}) will be backed by + * an internal changelog topic that will be created in Kafka. + * The changelog topic will be named "${applicationId}-${internalStoreName}-changelog", where "applicationId" is + * user-specified in {@link StreamsConfig} via parameter + * {@link StreamsConfig#APPLICATION_ID_CONFIG APPLICATION_ID_CONFIG}, "internalStoreName" is an internal name + * and "-changelog" is a fixed suffix. + *

                + * You can retrieve all generated internal topic names via {@link Topology#describe()}. + * + * @param reducer a {@link Reducer} that computes a new aggregate result. Cannot be {@code null}. + * @return a windowed {@link KTable} that contains "update" records with unmodified keys, and values that represent + * the latest (rolling) aggregate for each key within a window + */ + KTable, V> reduce(final Reducer reducer); + + /** + * Combine the values of records in this stream by the grouped key and defined windows. + * Records with {@code null} key or value are ignored. + * Combining implies that the type of the aggregate result is the same as the type of the input value. + *

                + * The result is written into a local {@link WindowStore} (which is basically an ever-updating materialized view). + * Furthermore, updates to the store are sent downstream into a {@link KTable} changelog stream. + * The default key and value serde from the config will be used for serializing the result. + * If a different serde is required then you should use {@link #reduce(Reducer, Named, Materialized)} . + *

+ * The value of the first record per window initializes the aggregation result.
+ * The specified {@link Reducer} is applied for each additional input record per window and computes a new
+ * aggregate using the current aggregate (first argument) and the record's value (second argument):
+ *

                {@code
+     * // A Reducer<Long> that sums values:
+     * new Reducer<Long>() {
                +     *   public Long apply(Long aggValue, Long currValue) {
                +     *     return aggValue + currValue;
                +     *   }
                +     * }
                +     * }
                + * Thus, {@code reduce()} can be used to compute aggregate functions like sum, min, or max. + *

                + * Not all updates might get sent downstream, as an internal cache is used to deduplicate consecutive updates to + * the same window and key. + * The rate of propagated updates depends on your input data rate, the number of distinct keys, the number of + * parallel running Kafka Streams instances, and the {@link StreamsConfig configuration} parameters for + * {@link StreamsConfig#CACHE_MAX_BYTES_BUFFERING_CONFIG cache size}, and + * {@link StreamsConfig#COMMIT_INTERVAL_MS_CONFIG commit interval}. + *

                + * For failure and recovery the store (which always will be of type {@link TimestampedWindowStore}) will be backed by + * an internal changelog topic that will be created in Kafka. + * The changelog topic will be named "${applicationId}-${internalStoreName}-changelog", where "applicationId" is + * user-specified in {@link StreamsConfig} via parameter + * {@link StreamsConfig#APPLICATION_ID_CONFIG APPLICATION_ID_CONFIG}, "internalStoreName" is an internal name + * and "-changelog" is a fixed suffix. + *

                + * You can retrieve all generated internal topic names via {@link Topology#describe()}. + * + * @param reducer a {@link Reducer} that computes a new aggregate result. Cannot be {@code null}. + * @param named a {@link Named} config used to name the processor in the topology. Cannot be {@code null}. + * @return a windowed {@link KTable} that contains "update" records with unmodified keys, and values that represent + * the latest (rolling) aggregate for each key within a window + */ + KTable, V> reduce(final Reducer reducer, final Named named); + + /** + * Combine the values of records in this stream by the grouped key and defined windows. + * Records with {@code null} key or value are ignored. + * Combining implies that the type of the aggregate result is the same as the type of the input value. + *

                + * The result is written into a local {@link WindowStore} (which is basically an ever-updating materialized view) + * that can be queried using the store name as provided with {@link Materialized}. + * Furthermore, updates to the store are sent downstream into a {@link KTable} changelog stream. + *

+ * The value of the first record per window initializes the aggregation result.
+ * The specified {@link Reducer} is applied for each additional input record per window and computes a new
+ * aggregate using the current aggregate (first argument) and the record's value (second argument):
+ *

                {@code
+     * // A Reducer<Long> that sums values:
+     * new Reducer<Long>() {
                +     *   public Long apply(Long aggValue, Long currValue) {
                +     *     return aggValue + currValue;
                +     *   }
                +     * }
                +     * }
                + * Thus, {@code reduce()} can be used to compute aggregate functions like sum, min, or max. + *

                + * Not all updates might get sent downstream, as an internal cache will be used to deduplicate consecutive updates + * to the same window and key if caching is enabled on the {@link Materialized} instance. + * When caching is enabled the rate of propagated updates depends on your input data rate, the number of distinct + * keys, the number of parallel running Kafka Streams instances, and the {@link StreamsConfig configuration} + * parameters for {@link StreamsConfig#CACHE_MAX_BYTES_BUFFERING_CONFIG cache size}, and + * {@link StreamsConfig#COMMIT_INTERVAL_MS_CONFIG commit interval}. + *

                + * To query the local {@link ReadOnlyWindowStore} it must be obtained via + * {@link KafkaStreams#store(StoreQueryParameters) KafkaStreams#store(...)}: + *

                {@code
+     * KafkaStreams streams = ... // counting words
+     * String queryableStoreName = ... // the queryableStoreName should be the name of the store as defined by the Materialized instance
+     * ReadOnlyWindowStore<K, ValueAndTimestamp<V>> localWindowStore = streams.store(StoreQueryParameters.fromNameAndType(queryableStoreName, QueryableStoreTypes.<K, ValueAndTimestamp<V>>timestampedWindowStore()));
+     *
+     * K key = "some-word";
+     * long fromTime = ...;
+     * long toTime = ...;
+     * WindowStoreIterator<ValueAndTimestamp<V>> reduceStore = localWindowStore.fetch(key, fromTime, toTime); // key must be local (application state is shared over all running Kafka Streams instances)
                +     * }
                + * For non-local keys, a custom RPC mechanism must be implemented using {@link KafkaStreams#metadataForAllStreamsClients()} to + * query the value of the key on a parallel running instance of your Kafka Streams application. + *

+ * For failure and recovery the store (which always will be of type {@link TimestampedWindowStore} -- regardless of what
+ * is specified in the parameter {@code materialized}) will be backed by an internal changelog topic that will be created in Kafka.
+ * Therefore, the store name defined by the Materialized instance must be a valid Kafka topic name and cannot
+ * contain characters other than ASCII alphanumerics, '.', '_' and '-'.
+ * The changelog topic will be named "${applicationId}-${storeName}-changelog", where "applicationId" is
+ * user-specified in {@link StreamsConfig} via parameter
+ * {@link StreamsConfig#APPLICATION_ID_CONFIG APPLICATION_ID_CONFIG}, "storeName" is the provided store name defined
+ * in {@code Materialized}, and "-changelog" is a fixed suffix.
+ *

                + * You can retrieve all generated internal topic names via {@link Topology#describe()}. + * + * @param reducer a {@link Reducer} that computes a new aggregate result. Cannot be {@code null}. + * @param materialized a {@link Materialized} config used to materialize a state store. Cannot be {@code null}. + * @return a windowed {@link KTable} that contains "update" records with unmodified keys, and values that represent + * the latest (rolling) aggregate for each key within a window + */ + KTable, V> reduce(final Reducer reducer, + final Materialized> materialized); + + /** + * Combine the values of records in this stream by the grouped key and defined windows. + * Records with {@code null} key or value are ignored. + * Combining implies that the type of the aggregate result is the same as the type of the input value. + *

                + * The result is written into a local {@link WindowStore} (which is basically an ever-updating materialized view) + * that can be queried using the store name as provided with {@link Materialized}. + * Furthermore, updates to the store are sent downstream into a {@link KTable} changelog stream. + *

+ * The value of the first record per window initializes the aggregation result.
+ * The specified {@link Reducer} is applied for each additional input record per window and computes a new
+ * aggregate using the current aggregate (first argument) and the record's value (second argument):
+ *

                {@code
+     * // A Reducer<Long> that sums values:
+     * new Reducer<Long>() {
                +     *   public Long apply(Long aggValue, Long currValue) {
                +     *     return aggValue + currValue;
                +     *   }
                +     * }
                +     * }
                + * Thus, {@code reduce()} can be used to compute aggregate functions like sum, min, or max. + *

                + * Not all updates might get sent downstream, as an internal cache will be used to deduplicate consecutive updates + * to the same window and key if caching is enabled on the {@link Materialized} instance. + * When caching is enabled the rate of propagated updates depends on your input data rate, the number of distinct + * keys, the number of parallel running Kafka Streams instances, and the {@link StreamsConfig configuration} + * parameters for {@link StreamsConfig#CACHE_MAX_BYTES_BUFFERING_CONFIG cache size}, and + * {@link StreamsConfig#COMMIT_INTERVAL_MS_CONFIG commit interval}. + *

                + * To query the local {@link ReadOnlyWindowStore} it must be obtained via + * {@link KafkaStreams#store(StoreQueryParameters) KafkaStreams#store(...)}: + *

                {@code
+     * KafkaStreams streams = ... // counting words
+     * String queryableStoreName = ... // the queryableStoreName should be the name of the store as defined by the Materialized instance
+     * ReadOnlyWindowStore<K, ValueAndTimestamp<V>> localWindowStore = streams.store(StoreQueryParameters.fromNameAndType(queryableStoreName, QueryableStoreTypes.<K, ValueAndTimestamp<V>>timestampedWindowStore()));
+     *
+     * K key = "some-word";
+     * long fromTime = ...;
+     * long toTime = ...;
+     * WindowStoreIterator<ValueAndTimestamp<V>> reduceStore = localWindowStore.fetch(key, fromTime, toTime); // key must be local (application state is shared over all running Kafka Streams instances)
                +     * }
                + * For non-local keys, a custom RPC mechanism must be implemented using {@link KafkaStreams#metadataForAllStreamsClients()} to + * query the value of the key on a parallel running instance of your Kafka Streams application. + *

+ * For failure and recovery the store (which always will be of type {@link TimestampedWindowStore} -- regardless of what
+ * is specified in the parameter {@code materialized}) will be backed by an internal changelog topic that will be created in Kafka.
+ * Therefore, the store name defined by the Materialized instance must be a valid Kafka topic name and cannot
+ * contain characters other than ASCII alphanumerics, '.', '_' and '-'.
+ * The changelog topic will be named "${applicationId}-${storeName}-changelog", where "applicationId" is
+ * user-specified in {@link StreamsConfig} via parameter
+ * {@link StreamsConfig#APPLICATION_ID_CONFIG APPLICATION_ID_CONFIG}, "storeName" is the provided store name defined
+ * in {@link Materialized}, and "-changelog" is a fixed suffix.
+ *

                + * You can retrieve all generated internal topic names via {@link Topology#describe()}. + * + * @param reducer a {@link Reducer} that computes a new aggregate result. Cannot be {@code null}. + * @param named a {@link Named} config used to name the processor in the topology. Cannot be {@code null}. + * @param materialized a {@link Materialized} config used to materialize a state store. Cannot be {@code null}. + * @return a windowed {@link KTable} that contains "update" records with unmodified keys, and values that represent + * the latest (rolling) aggregate for each key within a window + */ + KTable, V> reduce(final Reducer reducer, + final Named named, + final Materialized> materialized); +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/TimeWindowedSerializer.java b/streams/src/main/java/org/apache/kafka/streams/kstream/TimeWindowedSerializer.java new file mode 100644 index 0000000..54bdd6a --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/TimeWindowedSerializer.java @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.kstream; + +import org.apache.kafka.common.config.ConfigException; +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.common.serialization.Serializer; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.kstream.internals.WindowedSerializer; +import org.apache.kafka.streams.state.internals.WindowKeySchema; + +import java.util.Map; + +public class TimeWindowedSerializer implements WindowedSerializer { + + private Serializer inner; + + // Default constructor needed by Kafka + @SuppressWarnings("WeakerAccess") + public TimeWindowedSerializer() {} + + public TimeWindowedSerializer(final Serializer inner) { + this.inner = inner; + } + + @SuppressWarnings("unchecked") + @Override + public void configure(final Map configs, final boolean isKey) { + final String windowedInnerClassSerdeConfig = (String) configs.get(StreamsConfig.WINDOWED_INNER_CLASS_SERDE); + Serde windowInnerClassSerde = null; + if (windowedInnerClassSerdeConfig != null) { + try { + windowInnerClassSerde = Utils.newInstance(windowedInnerClassSerdeConfig, Serde.class); + } catch (final ClassNotFoundException e) { + throw new ConfigException(StreamsConfig.WINDOWED_INNER_CLASS_SERDE, windowedInnerClassSerdeConfig, + "Serde class " + windowedInnerClassSerdeConfig + " could not be found."); + } + } + + if (inner != null && windowedInnerClassSerdeConfig != null) { + if (!inner.getClass().getName().equals(windowInnerClassSerde.serializer().getClass().getName())) { + throw new IllegalArgumentException("Inner class serializer set using constructor " + + "(" + inner.getClass().getName() + ")" + + " is different from the one set in windowed.inner.class.serde config " + + "(" + windowInnerClassSerde.serializer().getClass().getName() + ")."); + } + } else if (inner == null && windowedInnerClassSerdeConfig == null) { + throw new IllegalArgumentException("Inner class serializer should be set either via constructor " + + "or via the windowed.inner.class.serde config"); + } else if (inner == null) + inner = windowInnerClassSerde.serializer(); + } + + @Override + public byte[] serialize(final String topic, final Windowed data) { + WindowedSerdes.verifyInnerSerializerNotNull(inner, this); + + if (data == null) { + return null; + } + + return WindowKeySchema.toBinary(data, inner, topic); + } + + @Override + public void close() { + if (inner != null) { + inner.close(); + } + } + + @Override + public byte[] serializeBaseKey(final String topic, final Windowed data) { + WindowedSerdes.verifyInnerSerializerNotNull(inner, this); + + return inner.serialize(topic, data.key()); + } + + // Only for testing + Serializer innerSerializer() { + return inner; + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/TimeWindows.java b/streams/src/main/java/org/apache/kafka/streams/kstream/TimeWindows.java new file mode 100644 index 0000000..adae2ae --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/TimeWindows.java @@ -0,0 +1,254 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.kstream; + +import org.apache.kafka.streams.kstream.internals.TimeWindow; +import org.apache.kafka.streams.processor.TimestampExtractor; + +import java.time.Duration; +import java.util.LinkedHashMap; +import java.util.Map; +import java.util.Objects; + +import static java.time.Duration.ofMillis; +import static org.apache.kafka.streams.internals.ApiUtils.prepareMillisCheckFailMsgPrefix; +import static org.apache.kafka.streams.internals.ApiUtils.validateMillisecondDuration; + +/** + * The fixed-size time-based window specifications used for aggregations. + *

                + * The semantics of time-based aggregation windows are: Every T1 (advance) milliseconds, compute the aggregate total for + * T2 (size) milliseconds. + *

+ * <ul>
+ *     <li>If {@code advance < size} hopping windows are defined:
+ *         they discretize a stream into overlapping windows, which implies that a record may be contained in one or
+ *         more "adjacent" windows.</li>
+ *     <li>If {@code advance == size} a tumbling window is defined:
+ *         it discretizes a stream into non-overlapping windows, which implies that a record is contained in one and
+ *         only one tumbling window.</li>
+ * </ul>
+ * Thus, the specified {@link TimeWindow}s are aligned to the epoch.
+ * Aligned to the epoch means that the first window starts at timestamp zero.
+ * For example, hopping windows with a size of 5000ms and an advance of 3000ms have window boundaries
+ * [0;5000),[3000;8000),... and not [1000;6000),[4000;9000),... or even something "random" like [1452;6452),[4452;9452),...
+ *
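
As an illustrative sketch of the two cases just described (the 5000ms/3000ms values match the example above; the grace period is arbitrary and not part of the patch):

    import java.time.Duration;
    import org.apache.kafka.streams.kstream.TimeWindows;

    // Tumbling: advance == size, non-overlapping windows [0;5000), [5000;10000), ...
    TimeWindows tumbling = TimeWindows.ofSizeWithNoGrace(Duration.ofMillis(5000));

    // Hopping: advance < size, overlapping windows [0;5000), [3000;8000), [6000;11000), ...
    TimeWindows hopping = TimeWindows
        .ofSizeAndGrace(Duration.ofMillis(5000), Duration.ofMillis(1000))
        .advanceBy(Duration.ofMillis(3000));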

                + * For time semantics, see {@link TimestampExtractor}. + * + * @see SessionWindows + * @see UnlimitedWindows + * @see JoinWindows + * @see KGroupedStream#windowedBy(Windows) + * @see TimestampExtractor + */ +public final class TimeWindows extends Windows { + + /** The size of the windows in milliseconds. */ + @SuppressWarnings("WeakerAccess") + public final long sizeMs; + + /** + * The size of the window's advance interval in milliseconds, i.e., by how much a window moves forward relative to + * the previous one. + */ + @SuppressWarnings("WeakerAccess") + public final long advanceMs; + + private final long graceMs; + + // flag to check if the grace is already set via ofSizeAndGrace or ofSizeWithNoGrace + private final boolean hasSetGrace; + + private TimeWindows(final long sizeMs, final long advanceMs, final long graceMs, final boolean hasSetGrace) { + this.sizeMs = sizeMs; + this.advanceMs = advanceMs; + this.graceMs = graceMs; + this.hasSetGrace = hasSetGrace; + + if (sizeMs <= 0) { + throw new IllegalArgumentException("Window size (sizeMs) must be larger than zero."); + } + + if (advanceMs <= 0 || advanceMs > sizeMs) { + throw new IllegalArgumentException(String.format("Window advancement interval should be more than zero " + + "and less than window duration which is %d ms, but given advancement interval is: %d ms", sizeMs, advanceMs)); + } + + if (graceMs < 0) { + throw new IllegalArgumentException("Grace period must not be negative."); + } + } + + /** + * Return a window definition with the given window size, and with the advance interval being equal to the window + * size. + * The time interval represented by the N-th window is: {@code [N * size, N * size + size)}. + *

                + * This provides the semantics of tumbling windows, which are fixed-sized, gap-less, non-overlapping windows. + * Tumbling windows are a special case of hopping windows with {@code advance == size}. + *

                + * CAUTION: Using this method implicitly sets the grace period to zero, which means that any out-of-order + * records arriving after the window ends are considered late and will be dropped. + * + * @param size The size of the window + * @return a new window definition with default no grace period. Note that this means out-of-order records arriving after the window end will be dropped + * @throws IllegalArgumentException if the specified window size is zero or negative or can't be represented as {@code long milliseconds} + */ + public static TimeWindows ofSizeWithNoGrace(final Duration size) throws IllegalArgumentException { + return ofSizeAndGrace(size, ofMillis(NO_GRACE_PERIOD)); + } + + /** + * Return a window definition with the given window size, and with the advance interval being equal to the window + * size. + * The time interval represented by the N-th window is: {@code [N * size, N * size + size)}. + *

                + * This provides the semantics of tumbling windows, which are fixed-sized, gap-less, non-overlapping windows. + * Tumbling windows are a special case of hopping windows with {@code advance == size}. + *

                + * Using this method explicitly sets the grace period to the duration specified by {@code afterWindowEnd}, which + * means that only out-of-order records arriving more than the grace period after the window end will be dropped. + * The window close, after which any incoming records are considered late and will be rejected, is defined as + * {@code windowEnd + afterWindowEnd} + * + * @param size The size of the window. Must be larger than zero + * @param afterWindowEnd The grace period to admit out-of-order events to a window. Must be non-negative. + * @return a TimeWindows object with the specified size and the specified grace period + * @throws IllegalArgumentException if {@code afterWindowEnd} is negative or can't be represented as {@code long milliseconds} + */ + public static TimeWindows ofSizeAndGrace(final Duration size, final Duration afterWindowEnd) throws IllegalArgumentException { + final String sizeMsgPrefix = prepareMillisCheckFailMsgPrefix(size, "size"); + final long sizeMs = validateMillisecondDuration(size, sizeMsgPrefix); + + final String afterWindowEndMsgPrefix = prepareMillisCheckFailMsgPrefix(afterWindowEnd, "afterWindowEnd"); + final long afterWindowEndMs = validateMillisecondDuration(afterWindowEnd, afterWindowEndMsgPrefix); + + return new TimeWindows(sizeMs, sizeMs, afterWindowEndMs, true); + } + + /** + * Return a window definition with the given window size, and with the advance interval being equal to the window + * size. + * The time interval represented by the N-th window is: {@code [N * size, N * size + size)}. + *

                + * This provides the semantics of tumbling windows, which are fixed-sized, gap-less, non-overlapping windows. + * Tumbling windows are a special case of hopping windows with {@code advance == size}. + * + * @param size The size of the window + * @return a new window definition without specifying the grace period (default to 24 hours minus window {@code size}) + * @throws IllegalArgumentException if the specified window size is zero or negative or can't be represented as {@code long milliseconds} + * @deprecated since 3.0. Use {@link #ofSizeWithNoGrace(Duration)} } instead + */ + @Deprecated + public static TimeWindows of(final Duration size) throws IllegalArgumentException { + final String msgPrefix = prepareMillisCheckFailMsgPrefix(size, "size"); + final long sizeMs = validateMillisecondDuration(size, msgPrefix); + + return new TimeWindows(sizeMs, sizeMs, Math.max(DEPRECATED_DEFAULT_24_HR_GRACE_PERIOD - sizeMs, 0), false); + } + + /** + * Return a window definition with the original size, but advance ("hop") the window by the given interval, which + * specifies by how much a window moves forward relative to the previous one. + * The time interval represented by the N-th window is: {@code [N * advance, N * advance + size)}. + *

                + * This provides the semantics of hopping windows, which are fixed-sized, overlapping windows. + * + * @param advance The advance interval ("hop") of the window, with the requirement that {@code 0 < advance.toMillis() <= sizeMs}. + * @return a new window definition with default maintain duration of 1 day + * @throws IllegalArgumentException if the advance interval is negative, zero, or larger than the window size + */ + public TimeWindows advanceBy(final Duration advance) { + final String msgPrefix = prepareMillisCheckFailMsgPrefix(advance, "advance"); + final long advanceMs = validateMillisecondDuration(advance, msgPrefix); + return new TimeWindows(sizeMs, advanceMs, graceMs, false); + } + + @Override + public Map windowsFor(final long timestamp) { + long windowStart = (Math.max(0, timestamp - sizeMs + advanceMs) / advanceMs) * advanceMs; + final Map windows = new LinkedHashMap<>(); + while (windowStart <= timestamp) { + final TimeWindow window = new TimeWindow(windowStart, windowStart + sizeMs); + windows.put(windowStart, window); + windowStart += advanceMs; + } + return windows; + } + + @Override + public long size() { + return sizeMs; + } + + /** + * Reject out-of-order events that arrive more than {@code millisAfterWindowEnd} + * after the end of its window. + *

                + * Delay is defined as (stream_time - record_timestamp). + * + * @param afterWindowEnd The grace period to admit out-of-order events to a window. + * @return this updated builder + * @throws IllegalArgumentException if {@code afterWindowEnd} is negative or can't be represented as {@code long milliseconds} + * @throws IllegalStateException if {@link #grace(Duration)} is called after {@link #ofSizeAndGrace(Duration, Duration)} or {@link #ofSizeWithNoGrace(Duration)} + * @deprecated since 3.0. Use {@link #ofSizeAndGrace(Duration, Duration)} instead + */ + @Deprecated + public TimeWindows grace(final Duration afterWindowEnd) throws IllegalArgumentException { + if (this.hasSetGrace) { + throw new IllegalStateException( + "Cannot call grace() after setting grace value via ofSizeAndGrace or ofSizeWithNoGrace."); + } + + final String msgPrefix = prepareMillisCheckFailMsgPrefix(afterWindowEnd, "afterWindowEnd"); + final long afterWindowEndMs = validateMillisecondDuration(afterWindowEnd, msgPrefix); + + return new TimeWindows(sizeMs, advanceMs, afterWindowEndMs, false); + } + + @Override + public long gracePeriodMs() { + return graceMs; + } + + @Override + public boolean equals(final Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + final TimeWindows that = (TimeWindows) o; + return sizeMs == that.sizeMs && + advanceMs == that.advanceMs && + graceMs == that.graceMs; + } + + @Override + public int hashCode() { + return Objects.hash(sizeMs, advanceMs, graceMs); + } + + @Override + public String toString() { + return "TimeWindows{" + + ", sizeMs=" + sizeMs + + ", advanceMs=" + advanceMs + + ", graceMs=" + graceMs + + '}'; + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/Transformer.java b/streams/src/main/java/org/apache/kafka/streams/kstream/Transformer.java new file mode 100644 index 0000000..af8e87e --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/Transformer.java @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.kstream; + +import java.time.Duration; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.processor.ProcessorContext; +import org.apache.kafka.streams.processor.PunctuationType; +import org.apache.kafka.streams.processor.Punctuator; +import org.apache.kafka.streams.processor.StateStore; +import org.apache.kafka.streams.processor.To; + +/** + * The {@code Transformer} interface is for stateful mapping of an input record to zero, one, or multiple new output + * records (both key and value type can be altered arbitrarily). 
+ * This is a stateful record-by-record operation, i.e., {@link #transform(Object, Object)} is invoked individually for
+ * each record of a stream and can access and modify a state that is available beyond a single call of
+ * {@link #transform(Object, Object)} (cf. {@link KeyValueMapper} for stateless record transformation).
+ * Additionally, this {@code Transformer} can {@link ProcessorContext#schedule(Duration, PunctuationType, Punctuator) schedule}
+ * a method to be {@link Punctuator#punctuate(long) called periodically} with the provided context.
+ *

+ * Use {@link TransformerSupplier} to provide new instances of {@code Transformer} to the Kafka Streams runtime.
+ *
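
To make the contract concrete, here is a hedged sketch (not part of the patch) of a Transformer that counts records per key in an attached state store; the class name and the store name "counts-store" are made up, and the store has to be created in the topology and connected to the operator:

    import org.apache.kafka.streams.KeyValue;
    import org.apache.kafka.streams.kstream.Transformer;
    import org.apache.kafka.streams.processor.ProcessorContext;
    import org.apache.kafka.streams.state.KeyValueStore;

    public class CountingTransformer implements Transformer<String, String, KeyValue<String, Long>> {

        private KeyValueStore<String, Long> countStore;

        @Override
        public void init(final ProcessorContext context) {
            // The store must have been attached to this operator, e.g. via
            // stream.transform(CountingTransformer::new, "counts-store");
            countStore = context.getStateStore("counts-store");
        }

        @Override
        public KeyValue<String, Long> transform(final String key, final String value) {
            final Long previous = countStore.get(key);
            final long updated = (previous == null ? 0L : previous) + 1L;
            countStore.put(key, updated);
            return KeyValue.pair(key, updated); // forward exactly one record downstream
        }

        @Override
        public void close() {
            // nothing to clean up; the state store itself is managed by the Streams runtime
        }
    }

Because KStream#transform(TransformerSupplier, String...) takes a supplier, passing the constructor reference shown in the comment hands the runtime a fresh Transformer instance per task, which is exactly what TransformerSupplier requires.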

                + * If only a record's value should be modified {@link ValueTransformer} can be used. + * + * @param key type + * @param value type + * @param {@link KeyValue} return type (both key and value type can be set + * arbitrarily) + * @see TransformerSupplier + * @see KStream#transform(TransformerSupplier, String...) + * @see ValueTransformer + * @see KStream#map(KeyValueMapper) + * @see KStream#flatMap(KeyValueMapper) + */ +public interface Transformer { + + /** + * Initialize this transformer. + * This is called once per instance when the topology gets initialized. + * When the framework is done with the transformer, {@link #close()} will be called on it; the + * framework may later re-use the transformer by calling {@link #init(ProcessorContext)} again. + *

                + * The provided {@link ProcessorContext context} can be used to access topology and record meta data, to + * {@link ProcessorContext#schedule(Duration, PunctuationType, Punctuator) schedule} a method to be + * {@link Punctuator#punctuate(long) called periodically} and to access attached {@link StateStore}s. + *

                + * Note, that {@link ProcessorContext} is updated in the background with the current record's meta data. + * Thus, it only contains valid record meta data when accessed within {@link #transform(Object, Object)}. + * + * @param context the context + */ + void init(final ProcessorContext context); + + /** + * Transform the record with the given key and value. + * Additionally, any {@link StateStore state} that is {@link KStream#transform(TransformerSupplier, String...) + * attached} to this operator can be accessed and modified + * arbitrarily (cf. {@link ProcessorContext#getStateStore(String)}). + *

+ * If only one record should be forwarded downstream, {@code transform} can return a new {@link KeyValue}. If
+ * more than one output record should be forwarded downstream, {@link ProcessorContext#forward(Object, Object)}
+ * and {@link ProcessorContext#forward(Object, Object, To)} can be used.
+ * If no record should be forwarded downstream, {@code transform} can return {@code null}.
+ *
+ * Note that returning a new {@link KeyValue} is merely for convenience. The same can be achieved by using
+ * {@link ProcessorContext#forward(Object, Object)} and returning {@code null}.
+ *
+ * @param key the key for the record
+ * @param value the value for the record
+ * @return new {@link KeyValue} pair; if {@code null} no key-value pair will
+ * be forwarded downstream
+ */
+    R transform(final K key, final V value);
+
+    /**
+     * Close this transformer and clean up any resources. The framework may
+     * later re-use this transformer by calling {@link #init(ProcessorContext)} on it again.
+     *

                + * To generate new {@link KeyValue} pairs {@link ProcessorContext#forward(Object, Object)} and + * {@link ProcessorContext#forward(Object, Object, To)} can be used. + */ + void close(); + +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/TransformerSupplier.java b/streams/src/main/java/org/apache/kafka/streams/kstream/TransformerSupplier.java new file mode 100644 index 0000000..80c1dcf --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/TransformerSupplier.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.kstream; + + +import org.apache.kafka.streams.processor.ConnectedStoreProvider; + +import java.util.function.Supplier; + +/** + * A {@code TransformerSupplier} interface which can create one or more {@link Transformer} instances. + *

                + * The supplier should always generate a new instance each time {@link TransformerSupplier#get()} gets called. Creating + * a single {@link Transformer} object and returning the same object reference in {@link TransformerSupplier#get()} would be + * a violation of the supplier pattern and leads to runtime exceptions. + * + * @param key type + * @param value type + * @param {@link org.apache.kafka.streams.KeyValue KeyValue} return type (both key and value type can be set + * arbitrarily) + * @see Transformer + * @see KStream#transform(TransformerSupplier, String...) + * @see ValueTransformer + * @see ValueTransformerSupplier + * @see KStream#transformValues(ValueTransformerSupplier, String...) + */ +public interface TransformerSupplier extends ConnectedStoreProvider, Supplier> { + + /** + * Return a newly constructed {@link Transformer} instance. + * The supplier should always generate a new instance each time {@link TransformerSupplier#get() gets called}. + *

                + * Creating a single {@link Transformer} object and returning the same object reference in {@link TransformerSupplier#get()} + * is a violation of the supplier pattern and leads to runtime exceptions. + * + * @return a newly constructed {@link Transformer} instance + */ + Transformer get(); +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/UnlimitedWindows.java b/streams/src/main/java/org/apache/kafka/streams/kstream/UnlimitedWindows.java new file mode 100644 index 0000000..513f19e --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/UnlimitedWindows.java @@ -0,0 +1,131 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.kstream; + +import org.apache.kafka.streams.internals.ApiUtils; +import org.apache.kafka.streams.kstream.internals.UnlimitedWindow; +import org.apache.kafka.streams.processor.TimestampExtractor; + +import java.time.Instant; +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; + +import static org.apache.kafka.streams.internals.ApiUtils.prepareMillisCheckFailMsgPrefix; + +/** + * The unlimited window specifications used for aggregations. + *

+ * An unlimited time window is also called a landmark window.
+ * It has a fixed starting point while its window end is defined as infinite.
+ * In this regard, it is a fixed-size window with infinite window size.
+ *
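
A minimal, illustrative use of the two factory paths (the start timestamp is chosen arbitrarily):

    import java.time.Instant;
    import org.apache.kafka.streams.kstream.UnlimitedWindows;

    // A landmark window covering everything from timestamp zero onwards ...
    UnlimitedWindows fromEpoch = UnlimitedWindows.of();

    // ... or everything from a chosen start time onwards.
    UnlimitedWindows fromStart = UnlimitedWindows.of().startOn(Instant.parse("2020-01-01T00:00:00Z"));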

                + * For time semantics, see {@link TimestampExtractor}. + * + * @see TimeWindows + * @see SessionWindows + * @see JoinWindows + * @see KGroupedStream#windowedBy(Windows) + * @see TimestampExtractor + */ +public final class UnlimitedWindows extends Windows { + + private static final long DEFAULT_START_TIMESTAMP_MS = 0L; + + /** The start timestamp of the window. */ + @SuppressWarnings("WeakerAccess") + public final long startMs; + + private UnlimitedWindows(final long startMs) { + this.startMs = startMs; + } + + /** + * Return an unlimited window starting at timestamp zero. + */ + public static UnlimitedWindows of() { + return new UnlimitedWindows(DEFAULT_START_TIMESTAMP_MS); + } + + /** + * Return a new unlimited window for the specified start timestamp. + * + * @param start the window start time + * @return a new unlimited window that starts at {@code start} + * @throws IllegalArgumentException if the start time is negative or can't be represented as {@code long milliseconds} + */ + public UnlimitedWindows startOn(final Instant start) throws IllegalArgumentException { + final String msgPrefix = prepareMillisCheckFailMsgPrefix(start, "start"); + final long startMs = ApiUtils.validateMillisecondInstant(start, msgPrefix); + if (startMs < 0) { + throw new IllegalArgumentException("Window start time (startMs) cannot be negative."); + } + return new UnlimitedWindows(startMs); + } + + @Override + public Map windowsFor(final long timestamp) { + // always return the single unlimited window + + // we cannot use Collections.singleMap since it does not support remove() + final Map windows = new HashMap<>(); + if (timestamp >= startMs) { + windows.put(startMs, new UnlimitedWindow(startMs)); + } + return windows; + } + + /** + * {@inheritDoc} + * As unlimited windows have conceptually infinite size, this methods just returns {@link Long#MAX_VALUE}. + * + * @return the size of the specified windows which is {@link Long#MAX_VALUE} + */ + @Override + public long size() { + return Long.MAX_VALUE; + } + + @Override + public long gracePeriodMs() { + return 0L; + } + + @Override + public boolean equals(final Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + final UnlimitedWindows that = (UnlimitedWindows) o; + return startMs == that.startMs; + } + + @Override + public int hashCode() { + return Objects.hash(startMs); + } + + @Override + public String toString() { + return "UnlimitedWindows{" + + "startMs=" + startMs + + '}'; + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/ValueJoiner.java b/streams/src/main/java/org/apache/kafka/streams/kstream/ValueJoiner.java new file mode 100644 index 0000000..0e57375 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/ValueJoiner.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.kstream; + + +/** + * The {@code ValueJoiner} interface for joining two values into a new value of arbitrary type. + * This is a stateless operation, i.e, {@link #apply(Object, Object)} is invoked individually for each joining + * record-pair of a {@link KStream}-{@link KStream}, {@link KStream}-{@link KTable}, or {@link KTable}-{@link KTable} + * join. + * + * @param first value type + * @param second value type + * @param joined value type + * @see KStream#join(KStream, ValueJoiner, JoinWindows) + * @see KStream#join(KStream, ValueJoiner, JoinWindows, StreamJoined) + * @see KStream#leftJoin(KStream, ValueJoiner, JoinWindows) + * @see KStream#leftJoin(KStream, ValueJoiner, JoinWindows, StreamJoined) + * @see KStream#outerJoin(KStream, ValueJoiner, JoinWindows) + * @see KStream#outerJoin(KStream, ValueJoiner, JoinWindows, StreamJoined) + * @see KStream#join(KTable, ValueJoiner) + * @see KStream#join(KTable, ValueJoiner, Joined) + * @see KStream#leftJoin(KTable, ValueJoiner) + * @see KStream#leftJoin(KTable, ValueJoiner, Joined) + * @see KTable#join(KTable, ValueJoiner) + * @see KTable#leftJoin(KTable, ValueJoiner) + * @see KTable#outerJoin(KTable, ValueJoiner) + */ +public interface ValueJoiner { + + /** + * Return a joined value consisting of {@code value1} and {@code value2}. + * + * @param value1 the first value for joining + * @param value2 the second value for joining + * @return the joined value + */ + VR apply(final V1 value1, final V2 value2); +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/ValueJoinerWithKey.java b/streams/src/main/java/org/apache/kafka/streams/kstream/ValueJoinerWithKey.java new file mode 100644 index 0000000..57f76c8 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/ValueJoinerWithKey.java @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.kstream; + + +/** + * The {@code ValueJoinerWithKey} interface for joining two values into a new value of arbitrary type. + * This interface provides access to a read-only key that the user should not modify as this would lead to + * undefined behavior + * This is a stateless operation, i.e, {@link #apply(Object, Object, Object)} is invoked individually for each joining + * record-pair of a {@link KStream}-{@link KStream}, {@link KStream}-{@link KTable}, or {@link KTable}-{@link KTable} + * join. 
+ * + * @param key value type + * @param first value type + * @param second value type + * @param joined value type + * @see KStream#join(KStream, ValueJoinerWithKey, JoinWindows) + * @see KStream#join(KStream, ValueJoinerWithKey, JoinWindows, StreamJoined) + * @see KStream#leftJoin(KStream, ValueJoinerWithKey, JoinWindows) + * @see KStream#leftJoin(KStream, ValueJoinerWithKey, JoinWindows, StreamJoined) + * @see KStream#outerJoin(KStream, ValueJoinerWithKey, JoinWindows) + * @see KStream#outerJoin(KStream, ValueJoinerWithKey, JoinWindows, StreamJoined) + * @see KStream#join(KTable, ValueJoinerWithKey) + * @see KStream#join(KTable, ValueJoinerWithKey, Joined) + * @see KStream#leftJoin(KTable, ValueJoinerWithKey) + * @see KStream#leftJoin(KTable, ValueJoinerWithKey, Joined) + * @see KStream#join(GlobalKTable, KeyValueMapper, ValueJoinerWithKey) + * @see KStream#join(GlobalKTable, KeyValueMapper, ValueJoinerWithKey, Named) + * @see KStream#leftJoin(GlobalKTable, KeyValueMapper, ValueJoinerWithKey) + * @see KStream#leftJoin(GlobalKTable, KeyValueMapper, ValueJoinerWithKey, Named) + */ +public interface ValueJoinerWithKey { + + /** + * Return a joined value consisting of {@code readOnlyKey}, {@code value1} and {@code value2}. + * + * @param readOnlyKey the key + * @param value1 the first value for joining + * @param value2 the second value for joining + * @return the joined value + */ + VR apply(final K1 readOnlyKey, final V1 value1, final V2 value2); +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/ValueMapper.java b/streams/src/main/java/org/apache/kafka/streams/kstream/ValueMapper.java new file mode 100644 index 0000000..be550a1 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/ValueMapper.java @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.kstream; + + +/** + * The {@code ValueMapper} interface for mapping a value to a new value of arbitrary type. + * This is a stateless record-by-record operation, i.e, {@link #apply(Object)} is invoked individually for each record + * of a stream (cf. {@link ValueTransformer} for stateful value transformation). + * If {@code ValueMapper} is applied to a {@link org.apache.kafka.streams.KeyValue key-value pair} record the record's + * key is preserved. + * If a record's key and value should be modified {@link KeyValueMapper} can be used. 
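// [Editor's example, not part of this patch] A minimal sketch of supplying ValueJoiner and
// ValueJoinerWithKey as lambdas in a stream-table join. Topic names and value types are
// assumptions for illustration only.
import org.apache.kafka.streams.StreamsBuilder;
import org.apache.kafka.streams.kstream.KStream;
import org.apache.kafka.streams.kstream.KTable;
import org.apache.kafka.streams.kstream.ValueJoiner;
import org.apache.kafka.streams.kstream.ValueJoinerWithKey;

public class JoinerExample {
    public static void main(final String[] args) {
        final StreamsBuilder builder = new StreamsBuilder();
        final KStream<String, Long> orders = builder.stream("orders");        // hypothetical topic
        final KTable<String, String> customers = builder.table("customers");  // hypothetical topic

        // Joins the order amount with the customer name, ignoring the key.
        final ValueJoiner<Long, String, String> joiner =
            (amount, name) -> name + " spent " + amount;

        // Same join, but the read-only key is available as well (and must not be modified).
        final ValueJoinerWithKey<String, Long, String, String> joinerWithKey =
            (customerId, amount, name) -> customerId + "/" + name + " spent " + amount;

        orders.join(customers, joiner).to("enriched-orders");
        orders.join(customers, joinerWithKey).to("enriched-orders-with-key");
    }
}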
+ * + * @param value type + * @param mapped value type + * @see KeyValueMapper + * @see ValueTransformer + * @see ValueTransformerWithKey + * @see KStream#mapValues(ValueMapper) + * @see KStream#mapValues(ValueMapperWithKey) + * @see KStream#flatMapValues(ValueMapper) + * @see KStream#flatMapValues(ValueMapperWithKey) + * @see KTable#mapValues(ValueMapper) + * @see KTable#mapValues(ValueMapperWithKey) + */ +public interface ValueMapper { + + /** + * Map the given value to a new value. + * + * @param value the value to be mapped + * @return the new value + */ + VR apply(final V value); +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/ValueMapperWithKey.java b/streams/src/main/java/org/apache/kafka/streams/kstream/ValueMapperWithKey.java new file mode 100644 index 0000000..b20c61a --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/ValueMapperWithKey.java @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.kstream; + +/** + * The {@code ValueMapperWithKey} interface for mapping a value to a new value of arbitrary type. + * This is a stateless record-by-record operation, i.e, {@link #apply(Object, Object)} is invoked individually for each + * record of a stream (cf. {@link ValueTransformer} for stateful value transformation). + * If {@code ValueMapperWithKey} is applied to a {@link org.apache.kafka.streams.KeyValue key-value pair} record the + * record's key is preserved. + * Note that the key is read-only and should not be modified, as this can lead to corrupt partitioning. + * If a record's key and value should be modified {@link KeyValueMapper} can be used. + * + * @param key type + * @param value type + * @param mapped value type + * @see KeyValueMapper + * @see ValueTransformer + * @see ValueTransformerWithKey + * @see KStream#mapValues(ValueMapper) + * @see KStream#mapValues(ValueMapperWithKey) + * @see KStream#flatMapValues(ValueMapper) + * @see KStream#flatMapValues(ValueMapperWithKey) + * @see KTable#mapValues(ValueMapper) + * @see KTable#mapValues(ValueMapperWithKey) + */ + +public interface ValueMapperWithKey { + + /** + * Map the given [key and ]value to a new value. 
+ * + * @param readOnlyKey the read-only key + * @param value the value to be mapped + * @return the new value + */ + VR apply(final K readOnlyKey, final V value); +} \ No newline at end of file diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/ValueTransformer.java b/streams/src/main/java/org/apache/kafka/streams/kstream/ValueTransformer.java new file mode 100644 index 0000000..987cae5 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/ValueTransformer.java @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.kstream; + +import java.time.Duration; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.errors.StreamsException; +import org.apache.kafka.streams.processor.ProcessorContext; +import org.apache.kafka.streams.processor.PunctuationType; +import org.apache.kafka.streams.processor.Punctuator; +import org.apache.kafka.streams.processor.StateStore; +import org.apache.kafka.streams.processor.To; + +/** + * The {@code ValueTransformer} interface for stateful mapping of a value to a new value (with possible new type). + * This is a stateful record-by-record operation, i.e, {@link #transform(Object)} is invoked individually for each + * record of a stream and can access and modify a state that is available beyond a single call of + * {@link #transform(Object)} (cf. {@link ValueMapper} for stateless value transformation). + * Additionally, this {@code ValueTransformer} can {@link ProcessorContext#schedule(Duration, PunctuationType, Punctuator) schedule} + * a method to be {@link Punctuator#punctuate(long) called periodically} with the provided context. + * If {@code ValueTransformer} is applied to a {@link KeyValue} pair record the record's key is preserved. + *
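// [Editor's example, not part of this patch] A minimal sketch of mapValues(): ValueMapper sees
// only the value, ValueMapperWithKey additionally receives the read-only key. Topic names and
// types are illustrative assumptions.
import org.apache.kafka.streams.StreamsBuilder;
import org.apache.kafka.streams.kstream.KStream;
import org.apache.kafka.streams.kstream.ValueMapper;
import org.apache.kafka.streams.kstream.ValueMapperWithKey;

public class MapValuesExample {
    public static void main(final String[] args) {
        final StreamsBuilder builder = new StreamsBuilder();
        final KStream<String, String> sentences = builder.stream("sentences"); // hypothetical topic

        final ValueMapper<String, Integer> length = String::length;
        final ValueMapperWithKey<String, String, String> tagWithKey =
            (readOnlyKey, value) -> readOnlyKey + ": " + value; // key is read, never modified

        sentences.mapValues(length).to("sentence-lengths");
        sentences.mapValues(tagWithKey).to("tagged-sentences");
    }
}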
<p>
                + * Use {@link ValueTransformerSupplier} to provide new instances of {@code ValueTransformer} to the Kafka Streams runtime. + *
<p>
                + * If a record's key and value should be modified {@link Transformer} can be used. + * + * @param value type + * @param transformed value type + * @see ValueTransformerSupplier + * @see ValueTransformerWithKeySupplier + * @see KStream#transformValues(ValueTransformerSupplier, String...) + * @see KStream#transformValues(ValueTransformerWithKeySupplier, String...) + * @see Transformer + */ +public interface ValueTransformer { + + /** + * Initialize this transformer. + * This is called once per instance when the topology gets initialized. + * When the framework is done with the transformer, {@link #close()} will be called on it; the + * framework may later re-use the transformer by calling {@link #init(ProcessorContext)} again. + *
<p>
                + * The provided {@link ProcessorContext context} can be used to access topology and record meta data, to + * {@link ProcessorContext#schedule(Duration, PunctuationType, Punctuator) schedule} a method to be + * {@link Punctuator#punctuate(long) called periodically} and to access attached {@link StateStore}s. + *
<p>
                + * Note that {@link ProcessorContext} is updated in the background with the current record's meta data. + * Thus, it only contains valid record meta data when accessed within {@link #transform(Object)}. + *
<p>
                + * Note that using {@link ProcessorContext#forward(Object, Object)} or + * {@link ProcessorContext#forward(Object, Object, To)} is not allowed within any method of + * {@code ValueTransformer} and will result in an {@link StreamsException exception}. + * + * @param context the context + * @throws IllegalStateException If store gets registered after initialization is already finished + * @throws StreamsException if the store's change log does not contain the partition + */ + void init(final ProcessorContext context); + + /** + * Transform the given value to a new value. + * Additionally, any {@link StateStore} that is {@link KStream#transformValues(ValueTransformerSupplier, String...) + * attached} to this operator can be accessed and modified arbitrarily (cf. + * {@link ProcessorContext#getStateStore(String)}). + *
<p>
                + * Note that using {@link ProcessorContext#forward(Object, Object)} or + * {@link ProcessorContext#forward(Object, Object, To)} is not allowed within {@code transform} and + * will result in an {@link StreamsException exception}. + * + * @param value the value to be transformed + * @return the new value + */ + VR transform(final V value); + + /** + * Close this transformer and clean up any resources. The framework may + * later re-use this transformer by calling {@link #init(ProcessorContext)} on it again. + *
<p>
                + * It is not possible to return any new output records within {@code close()}. + * Using {@link ProcessorContext#forward(Object, Object)} or {@link ProcessorContext#forward(Object, Object, To)} + * will result in an {@link StreamsException exception}. + */ + void close(); + +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/ValueTransformerSupplier.java b/streams/src/main/java/org/apache/kafka/streams/kstream/ValueTransformerSupplier.java new file mode 100644 index 0000000..b0c18db --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/ValueTransformerSupplier.java @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.kstream; + +import org.apache.kafka.streams.processor.ConnectedStoreProvider; + +/** + * A {@code ValueTransformerSupplier} interface which can create one or more {@link ValueTransformer} instances. + *
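// [Editor's example, not part of this patch] A minimal sketch of a stateful ValueTransformer
// wired up through a ValueTransformerSupplier. The store name "counts-store" and the topic
// names are assumptions; the store is registered on the builder and named in transformValues()
// so that getStateStore() can find it.
import org.apache.kafka.common.serialization.Serdes;
import org.apache.kafka.streams.StreamsBuilder;
import org.apache.kafka.streams.kstream.KStream;
import org.apache.kafka.streams.kstream.ValueTransformer;
import org.apache.kafka.streams.kstream.ValueTransformerSupplier;
import org.apache.kafka.streams.processor.ProcessorContext;
import org.apache.kafka.streams.state.KeyValueStore;
import org.apache.kafka.streams.state.Stores;

public class CountingTransformerExample {

    static class CountingTransformer implements ValueTransformer<String, Long> {
        private KeyValueStore<String, Long> store;

        @Override
        @SuppressWarnings("unchecked")
        public void init(final ProcessorContext context) {
            store = (KeyValueStore<String, Long>) context.getStateStore("counts-store");
        }

        @Override
        public Long transform(final String value) {
            final Long previous = store.get(value);
            final long updated = (previous == null ? 0L : previous) + 1L;
            store.put(value, updated);
            return updated; // how often this value has been seen so far
        }

        @Override
        public void close() { } // nothing to clean up; forward() is not allowed here anyway
    }

    public static void main(final String[] args) {
        final StreamsBuilder builder = new StreamsBuilder();
        builder.addStateStore(Stores.keyValueStoreBuilder(
            Stores.inMemoryKeyValueStore("counts-store"), Serdes.String(), Serdes.Long()));

        final KStream<String, String> words = builder.stream("words"); // hypothetical topic
        // The supplier hands out a NEW transformer instance on every get() call, as required.
        final ValueTransformerSupplier<String, Long> supplier = CountingTransformer::new;
        words.transformValues(supplier, "counts-store").to("word-counts");
    }
}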
<p>
                + * The supplier should always generate a new instance each time {@link ValueTransformerSupplier#get()} gets called. Creating + * a single {@link ValueTransformer} object and returning the same object reference in {@link ValueTransformerSupplier#get()} would be + * a violation of the supplier pattern and leads to runtime exceptions. + * + * @param value type + * @param transformed value type + * @see ValueTransformer + * @see ValueTransformerWithKey + * @see ValueTransformerWithKeySupplier + * @see KStream#transformValues(ValueTransformerSupplier, String...) + * @see KStream#transformValues(ValueTransformerWithKeySupplier, String...) + * @see Transformer + * @see TransformerSupplier + * @see KStream#transform(TransformerSupplier, String...) + */ +public interface ValueTransformerSupplier extends ConnectedStoreProvider { + + /** + * Return a newly constructed {@link ValueTransformer} instance. + * The supplier should always generate a new instance each time {@link ValueTransformerSupplier#get()} gets called. + *
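// [Editor's example, not part of this patch] A small sketch of the supplier contract described
// above: get() must construct a fresh ValueTransformer on every call, because Kafka Streams
// creates one transformer per stream task. The class and field names are illustrative.
import org.apache.kafka.streams.kstream.ValueTransformer;
import org.apache.kafka.streams.kstream.ValueTransformerSupplier;
import org.apache.kafka.streams.processor.ProcessorContext;

public class SupplierContractExample {

    static class UpperCaseTransformer implements ValueTransformer<String, String> {
        @Override public void init(final ProcessorContext context) { }
        @Override public String transform(final String value) { return value == null ? null : value.toUpperCase(); }
        @Override public void close() { }
    }

    // Correct: every call to get() constructs a fresh transformer instance.
    static final ValueTransformerSupplier<String, String> GOOD = UpperCaseTransformer::new;

    // Broken: the same instance is handed out repeatedly and would be shared across tasks,
    // which is the supplier-pattern violation the Javadoc above warns about.
    static final UpperCaseTransformer SHARED = new UpperCaseTransformer();
    static final ValueTransformerSupplier<String, String> BAD = () -> SHARED;
}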
<p>
                + * Creating a single {@link ValueTransformer} object and returning the same object reference in {@link ValueTransformerSupplier#get()} + * is a violation of the supplier pattern and leads to runtime exceptions. + * + * @return a new {@link ValueTransformer} instance + * @return a newly constructed {@link ValueTransformer} instance + */ + ValueTransformer get(); +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/ValueTransformerWithKey.java b/streams/src/main/java/org/apache/kafka/streams/kstream/ValueTransformerWithKey.java new file mode 100644 index 0000000..fecd96f --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/ValueTransformerWithKey.java @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.kstream; + +import java.time.Duration; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.errors.StreamsException; +import org.apache.kafka.streams.processor.ProcessorContext; +import org.apache.kafka.streams.processor.PunctuationType; +import org.apache.kafka.streams.processor.Punctuator; +import org.apache.kafka.streams.processor.StateStore; +import org.apache.kafka.streams.processor.To; + +/** + * The {@code ValueTransformerWithKey} interface for stateful mapping of a value to a new value (with possible new type). + * This is a stateful record-by-record operation, i.e, {@link #transform(Object, Object)} is invoked individually for each + * record of a stream and can access and modify a state that is available beyond a single call of + * {@link #transform(Object, Object)} (cf. {@link ValueMapper} for stateless value transformation). + * Additionally, this {@code ValueTransformerWithKey} can + * {@link ProcessorContext#schedule(Duration, PunctuationType, Punctuator) schedule} a method to be + * {@link Punctuator#punctuate(long) called periodically} with the provided context. + * Note that the key is read-only and should not be modified, as this can lead to corrupt partitioning. + * If {@code ValueTransformerWithKey} is applied to a {@link KeyValue} pair record the record's key is preserved. + *
<p>
                + * Use {@link ValueTransformerWithKeySupplier} to provide new instances of {@link ValueTransformerWithKey} to + * the Kafka Streams runtime. + *
<p>
                + * If a record's key and value should be modified {@link Transformer} can be used. + * + * @param key type + * @param value type + * @param transformed value type + * @see ValueTransformer + * @see ValueTransformerWithKeySupplier + * @see KStream#transformValues(ValueTransformerSupplier, String...) + * @see KStream#transformValues(ValueTransformerWithKeySupplier, String...) + * @see Transformer + */ + +public interface ValueTransformerWithKey { + + /** + * Initialize this transformer. + * This is called once per instance when the topology gets initialized. + *
<p>
                + * The provided {@link ProcessorContext context} can be used to access topology and record meta data, to + * {@link ProcessorContext#schedule(Duration, PunctuationType, Punctuator) schedule} a method to be + * {@link Punctuator#punctuate(long) called periodically} and to access attached {@link StateStore}s. + *
<p>
                + * Note that {@link ProcessorContext} is updated in the background with the current record's meta data. + * Thus, it only contains valid record meta data when accessed within {@link #transform(Object, Object)}. + *
<p>
                + * Note that using {@link ProcessorContext#forward(Object, Object)} or + * {@link ProcessorContext#forward(Object, Object, To)} is not allowed within any method of + * {@code ValueTransformerWithKey} and will result in an {@link StreamsException exception}. + * + * @param context the context + * @throws IllegalStateException If store gets registered after initialization is already finished + * @throws StreamsException if the store's change log does not contain the partition + */ + void init(final ProcessorContext context); + + /** + * Transform the given [key and] value to a new value. + * Additionally, any {@link StateStore} that is {@link KStream#transformValues(ValueTransformerWithKeySupplier, String...) + * attached} to this operator can be accessed and modified arbitrarily (cf. + * {@link ProcessorContext#getStateStore(String)}). + *
<p>
                + * Note that using {@link ProcessorContext#forward(Object, Object)} or + * {@link ProcessorContext#forward(Object, Object, To)} is not allowed within {@code transform} and + * will result in an {@link StreamsException exception}. + *
<p>
                + * Note that if a {@code ValueTransformerWithKey} is used in a {@link KTable#transformValues(ValueTransformerWithKeySupplier, String...)} + * (or any other overload of {@code KTable#transformValues(...)}) operation, + * then the provided {@link ProcessorContext} from {@link #init(ProcessorContext)} + * does not guarantee that all context information will be available when {@code transform()} + * is executed, as it might be executed "out-of-band" due to some internal optimizations + * applied by the Kafka Streams DSL. + * + * @param readOnlyKey the read-only key + * @param value the value to be transformed + * @return the new value + */ + VR transform(final K readOnlyKey, final V value); + + /** + * Close this processor and clean up any resources. + *
<p>
                + * It is not possible to return any new output records within {@code close()}. + * Using {@link ProcessorContext#forward(Object, Object)} or {@link ProcessorContext#forward(Object, Object, To)}, + * will result in an {@link StreamsException exception}. + */ + void close(); +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/ValueTransformerWithKeySupplier.java b/streams/src/main/java/org/apache/kafka/streams/kstream/ValueTransformerWithKeySupplier.java new file mode 100644 index 0000000..9aae791 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/ValueTransformerWithKeySupplier.java @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.kstream; + +import org.apache.kafka.streams.processor.ConnectedStoreProvider; + +import java.util.function.Supplier; + +/** + * A {@code ValueTransformerWithKeySupplier} interface which can create one or more {@link ValueTransformerWithKey} instances. + *
<p>
                + * The supplier should always generate a new instance each time {@link ValueTransformerWithKeySupplier#get()} gets called. Creating + * a single {@link ValueTransformerWithKey} object and returning the same object reference in {@link ValueTransformerWithKeySupplier#get()} would be + * a violation of the supplier pattern and leads to runtime exceptions. + * + * @param key type + * @param value type + * @param transformed value type + * @see ValueTransformer + * @see ValueTransformerWithKey + * @see KStream#transformValues(ValueTransformerSupplier, String...) + * @see KStream#transformValues(ValueTransformerWithKeySupplier, String...) + * @see Transformer + * @see TransformerSupplier + * @see KStream#transform(TransformerSupplier, String...) + */ +public interface ValueTransformerWithKeySupplier extends ConnectedStoreProvider, Supplier> { + + /** + * Return a newly constructed {@link ValueTransformerWithKey} instance. + * The supplier should always generate a new instance each time {@link ValueTransformerWithKeySupplier#get()} gets called. + *
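// [Editor's example, not part of this patch] A minimal sketch of a ValueTransformerWithKeySupplier
// that also overrides ConnectedStoreProvider#stores(), so the state store travels with the
// supplier and does not need to be added to the StreamsBuilder or named in transformValues()
// separately. Store name and types are assumptions for illustration.
import java.util.Collections;
import java.util.Set;
import org.apache.kafka.common.serialization.Serdes;
import org.apache.kafka.streams.kstream.ValueTransformerWithKey;
import org.apache.kafka.streams.kstream.ValueTransformerWithKeySupplier;
import org.apache.kafka.streams.processor.ProcessorContext;
import org.apache.kafka.streams.state.KeyValueStore;
import org.apache.kafka.streams.state.StoreBuilder;
import org.apache.kafka.streams.state.Stores;

public class LastValuePerKeySupplier implements ValueTransformerWithKeySupplier<String, String, String> {

    private static final String STORE_NAME = "last-value-store";

    @Override
    public ValueTransformerWithKey<String, String, String> get() {
        // A fresh instance per call, as required by the supplier contract.
        return new ValueTransformerWithKey<String, String, String>() {
            private KeyValueStore<String, String> store;

            @Override
            @SuppressWarnings("unchecked")
            public void init(final ProcessorContext context) {
                store = (KeyValueStore<String, String>) context.getStateStore(STORE_NAME);
            }

            @Override
            public String transform(final String readOnlyKey, final String value) {
                final String previous = store.get(readOnlyKey); // the key is only read, never changed
                store.put(readOnlyKey, value);
                return previous;
            }

            @Override
            public void close() { }
        };
    }

    @Override
    public Set<StoreBuilder<?>> stores() {
        final StoreBuilder<?> storeBuilder = Stores.keyValueStoreBuilder(
            Stores.inMemoryKeyValueStore(STORE_NAME), Serdes.String(), Serdes.String());
        return Collections.singleton(storeBuilder);
    }
}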
<p>
                + * Creating a single {@link ValueTransformerWithKey} object and returning the same object reference in {@link ValueTransformerWithKeySupplier#get()} + * is a violation of the supplier pattern and leads to runtime exceptions. + * + * @return a new {@link ValueTransformerWithKey} instance + * @return a newly constructed {@link ValueTransformerWithKey} instance + */ + ValueTransformerWithKey get(); +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/Window.java b/streams/src/main/java/org/apache/kafka/streams/kstream/Window.java new file mode 100644 index 0000000..432bb45 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/Window.java @@ -0,0 +1,140 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.kstream; + +import org.apache.kafka.streams.processor.TimestampExtractor; + +import java.time.Instant; + +/** + * A single window instance, defined by its start and end timestamp. + * {@code Window} is agnostic if start/end boundaries are inclusive or exclusive; this is defined by concrete + * window implementations. + *
<p>
                + * To specify how {@code Window} boundaries are defined use {@link Windows}. + * For time semantics, see {@link TimestampExtractor}. + * + * @see Windows + * @see org.apache.kafka.streams.kstream.internals.TimeWindow + * @see org.apache.kafka.streams.kstream.internals.SessionWindow + * @see org.apache.kafka.streams.kstream.internals.UnlimitedWindow + * @see TimestampExtractor + */ +public abstract class Window { + + protected final long startMs; + protected final long endMs; + private final Instant startTime; + private final Instant endTime; + + + /** + * Create a new window for the given start and end time. + * + * @param startMs the start timestamp of the window + * @param endMs the end timestamp of the window + * @throws IllegalArgumentException if {@code startMs} is negative or if {@code endMs} is smaller than {@code startMs} + */ + public Window(final long startMs, final long endMs) throws IllegalArgumentException { + if (startMs < 0) { + throw new IllegalArgumentException("Window startMs time cannot be negative."); + } + if (endMs < startMs) { + throw new IllegalArgumentException("Window endMs time cannot be smaller than window startMs time."); + } + this.startMs = startMs; + this.endMs = endMs; + + this.startTime = Instant.ofEpochMilli(startMs); + this.endTime = Instant.ofEpochMilli(endMs); + } + + /** + * Return the start timestamp of this window. + * + * @return The start timestamp of this window. + */ + public long start() { + return startMs; + } + + /** + * Return the end timestamp of this window. + * + * @return The end timestamp of this window. + */ + public long end() { + return endMs; + } + + /** + * Return the start time of this window. + * + * @return The start time of this window. + */ + public Instant startTime() { + return startTime; + } + + /** + * Return the end time of this window. + * + * @return The end time of this window. + */ + public Instant endTime() { + return endTime; + } + + /** + * Check if the given window overlaps with this window. + * Should throw an {@link IllegalArgumentException} if the {@code other} window has a different type than {@code + * this} window. + * + * @param other another window of the same type + * @return {@code true} if {@code other} overlaps with this window—{@code false} otherwise + */ + public abstract boolean overlap(final Window other); + + @Override + public boolean equals(final Object obj) { + if (obj == this) { + return true; + } + if (obj == null) { + return false; + } + if (getClass() != obj.getClass()) { + return false; + } + + final Window other = (Window) obj; + return startMs == other.startMs && endMs == other.endMs; + } + + @Override + public int hashCode() { + return (int) (((startMs << 32) | endMs) % 0xFFFFFFFFL); + } + + @Override + public String toString() { + return "Window{" + + "startMs=" + startMs + + ", endMs=" + endMs + + '}'; + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/Windowed.java b/streams/src/main/java/org/apache/kafka/streams/kstream/Windowed.java new file mode 100644 index 0000000..d830f58 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/Windowed.java @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
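// [Editor's example, not part of this patch] A minimal concrete Window subclass, only to show
// how the constructor validation, start()/end() and overlap() fit together. Kafka's real
// implementations are the internal TimeWindow, SessionWindow and UnlimitedWindow classes.
import org.apache.kafka.streams.kstream.Window;

public class InclusiveWindow extends Window {

    public InclusiveWindow(final long startMs, final long endMs) {
        super(startMs, endMs); // rejects a negative start and an end smaller than the start
    }

    @Override
    public boolean overlap(final Window other) {
        if (getClass() != other.getClass()) {
            throw new IllegalArgumentException("Cannot compare windows of different type.");
        }
        // This illustrative window type treats both boundaries as inclusive.
        return start() <= other.end() && other.start() <= end();
    }
}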
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.kstream; + + +/** + * The result key type of a windowed stream aggregation. + *
<p>
                + * If a {@link KStream} gets grouped and aggregated using a window-aggregation the resulting {@link KTable} is a + * so-called "windowed {@link KTable}" with a combined key type that encodes the corresponding aggregation window and + * the original record key. + * Thus, a windowed {@link KTable} has type {@code ,V>}. + * + * @param type of the key + * @see KGroupedStream#windowedBy(Windows) + * @see KGroupedStream#windowedBy(SessionWindows) + */ +public class Windowed { + + private final K key; + + private final Window window; + + public Windowed(final K key, final Window window) { + this.key = key; + this.window = window; + } + + /** + * Return the key of the window. + * + * @return the key of the window + */ + public K key() { + return key; + } + + /** + * Return the window containing the values associated with this key. + * + * @return the window containing the values + */ + public Window window() { + return window; + } + + @Override + public String toString() { + return "[" + key + "@" + window.start() + "/" + window.end() + "]"; + } + + @Override + public boolean equals(final Object obj) { + if (obj == this) { + return true; + } + if (!(obj instanceof Windowed)) { + return false; + } + final Windowed that = (Windowed) obj; + return window.equals(that.window) && key.equals(that.key); + } + + @Override + public int hashCode() { + final long n = ((long) window.hashCode() << 32) | key.hashCode(); + return (int) (n % 0xFFFFFFFFL); + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/WindowedSerdes.java b/streams/src/main/java/org/apache/kafka/streams/kstream/WindowedSerdes.java new file mode 100644 index 0000000..07e1cae --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/WindowedSerdes.java @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
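// [Editor's example, not part of this patch] A minimal sketch of where Windowed keys come from:
// a windowed count yields a KTable<Windowed<String>, Long>, and each key exposes its window
// bounds. Topic names, the 5-minute window size, and the use of TimeWindows.ofSizeWithNoGrace()
// (assuming a release that includes KIP-633) are illustrative; the WindowedSerdes helpers added
// further below are used to write the windowed keys to an output topic.
import java.time.Duration;
import org.apache.kafka.common.serialization.Serdes;
import org.apache.kafka.streams.StreamsBuilder;
import org.apache.kafka.streams.kstream.KStream;
import org.apache.kafka.streams.kstream.KTable;
import org.apache.kafka.streams.kstream.Produced;
import org.apache.kafka.streams.kstream.TimeWindows;
import org.apache.kafka.streams.kstream.Windowed;
import org.apache.kafka.streams.kstream.WindowedSerdes;

public class WindowedKeyExample {
    public static void main(final String[] args) {
        final Duration windowSize = Duration.ofMinutes(5);
        final StreamsBuilder builder = new StreamsBuilder();
        final KStream<String, String> clicks = builder.stream("clicks"); // hypothetical topic

        final KTable<Windowed<String>, Long> counts = clicks
            .groupByKey()
            .windowedBy(TimeWindows.ofSizeWithNoGrace(windowSize))
            .count();

        counts.toStream()
              // Windowed#key() and Windowed#window() expose the original key and the window bounds.
              .peek((windowedKey, count) -> System.out.println(
                  windowedKey.key() + " @ [" + windowedKey.window().startTime() + ", "
                      + windowedKey.window().endTime() + ") -> " + count))
              .to("click-counts", Produced.with(
                  WindowedSerdes.timeWindowedSerdeFrom(String.class, windowSize.toMillis()),
                  Serdes.Long()));
    }
}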
+ */ +package org.apache.kafka.streams.kstream; + +import org.apache.kafka.common.serialization.Deserializer; +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.serialization.Serializer; + +public class WindowedSerdes { + + static public class TimeWindowedSerde extends Serdes.WrapperSerde> { + // Default constructor needed for reflection object creation + public TimeWindowedSerde() { + super(new TimeWindowedSerializer<>(), new TimeWindowedDeserializer<>()); + } + + @Deprecated + public TimeWindowedSerde(final Serde inner) { + super(new TimeWindowedSerializer<>(inner.serializer()), new TimeWindowedDeserializer<>(inner.deserializer())); + } + + // This constructor can be used for serialize/deserialize a windowed topic + public TimeWindowedSerde(final Serde inner, final long windowSize) { + super(new TimeWindowedSerializer<>(inner.serializer()), new TimeWindowedDeserializer<>(inner.deserializer(), windowSize)); + } + + // Helper method for users to specify whether the input topic is a changelog topic for deserializing the key properly. + public TimeWindowedSerde forChangelog(final boolean isChangelogTopic) { + final TimeWindowedDeserializer deserializer = (TimeWindowedDeserializer) this.deserializer(); + deserializer.setIsChangelogTopic(isChangelogTopic); + return this; + } + } + + static public class SessionWindowedSerde extends Serdes.WrapperSerde> { + // Default constructor needed for reflection object creation + public SessionWindowedSerde() { + super(new SessionWindowedSerializer<>(), new SessionWindowedDeserializer<>()); + } + + public SessionWindowedSerde(final Serde inner) { + super(new SessionWindowedSerializer<>(inner.serializer()), new SessionWindowedDeserializer<>(inner.deserializer())); + } + } + + /** + * Construct a {@code TimeWindowedSerde} object for the specified inner class type. + */ + @Deprecated + static public Serde> timeWindowedSerdeFrom(final Class type) { + return new TimeWindowedSerde<>(Serdes.serdeFrom(type)); + } + + /** + * Construct a {@code TimeWindowedSerde} object to deserialize changelog topic + * for the specified inner class type and window size. + */ + static public Serde> timeWindowedSerdeFrom(final Class type, final long windowSize) { + return new TimeWindowedSerde<>(Serdes.serdeFrom(type), windowSize); + } + + /** + * Construct a {@code SessionWindowedSerde} object for the specified inner class type. + */ + static public Serde> sessionWindowedSerdeFrom(final Class type) { + return new SessionWindowedSerde<>(Serdes.serdeFrom(type)); + } + + static void verifyInnerSerializerNotNull(final Serializer inner, + final Serializer wrapper) { + if (inner == null) { + throw new NullPointerException("Inner serializer is `null`. " + + "User code must use constructor `" + wrapper.getClass().getSimpleName() + "(final Serializer inner)` " + + "instead of the no-arg constructor."); + } + } + + static void verifyInnerDeserializerNotNull(final Deserializer inner, + final Deserializer wrapper) { + if (inner == null) { + throw new NullPointerException("Inner deserializer is `null`. 
" + + "User code must use constructor `" + wrapper.getClass().getSimpleName() + "(final Deserializer inner)` " + + "instead of the no-arg constructor."); + } + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/Windows.java b/streams/src/main/java/org/apache/kafka/streams/kstream/Windows.java new file mode 100644 index 0000000..cd8a286 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/Windows.java @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.kstream; + +import org.apache.kafka.streams.processor.TimestampExtractor; + +import java.util.Map; + +/** + * The window specification for fixed size windows that is used to define window boundaries and grace period. + *
<p>
                + * The grace period defines how long to wait for out-of-order events. That is, a window will continue to accept new records until {@code stream_time >= window_end + grace_period}. + * Records that arrive after the grace period has passed are considered late and are dropped rather than processed. + *
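// [Editor's example, not part of this patch] A small sketch of the grace-period rule above,
// assuming a release that includes KIP-633's TimeWindows.ofSizeAndGrace(): a 1-minute window
// keeps accepting records until stream time passes window_end + 30 seconds; later records are
// dropped as late.
import java.time.Duration;
import org.apache.kafka.streams.kstream.TimeWindows;

public class GracePeriodExample {
    public static void main(final String[] args) {
        final TimeWindows windows =
            TimeWindows.ofSizeAndGrace(Duration.ofMinutes(1), Duration.ofSeconds(30));
        System.out.println(windows.size());          // 60000 (window size in milliseconds)
        System.out.println(windows.gracePeriodMs()); // 30000 (grace period in milliseconds)
    }
}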
<p>
                + * Warning: It may be unsafe to use objects of this class in set- or map-like collections, + * since the equals and hashCode methods depend on mutable fields. + * + * @param type of the window instance + * @see TimeWindows + * @see UnlimitedWindows + * @see JoinWindows + * @see SessionWindows + * @see TimestampExtractor + */ +public abstract class Windows { + + /** + * By default grace period is 24 hours for all windows in other words we allow out-of-order data for up to a day + * This behavior is now deprecated and additional details are available in the motivation for the KIP + * Check out KIP-633 for more details + */ + protected static final long DEPRECATED_DEFAULT_24_HR_GRACE_PERIOD = 24 * 60 * 60 * 1000L; + + /** + * This constant is used as the specified grace period where we do not have any grace periods instead of magic constants + */ + protected static final long NO_GRACE_PERIOD = 0L; + + protected Windows() {} + + /** + * Create all windows that contain the provided timestamp, indexed by non-negative window start timestamps. + * + * @param timestamp the timestamp window should get created for + * @return a map of {@code windowStartTimestamp -> Window} entries + */ + public abstract Map windowsFor(final long timestamp); + + /** + * Return the size of the specified windows in milliseconds. + * + * @return the size of the specified windows + */ + public abstract long size(); + + /** + * Return the window grace period (the time to admit + * out-of-order events after the end of the window.) + * + * Delay is defined as (stream_time - record_timestamp). + */ + public abstract long gracePeriodMs(); +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/AbstractStream.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/AbstractStream.java new file mode 100644 index 0000000..b145741 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/AbstractStream.java @@ -0,0 +1,160 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.kstream.internals; + +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.streams.internals.ApiUtils; +import org.apache.kafka.streams.kstream.ValueJoiner; +import org.apache.kafka.streams.kstream.ValueJoinerWithKey; +import org.apache.kafka.streams.kstream.ValueMapper; +import org.apache.kafka.streams.kstream.ValueMapperWithKey; +import org.apache.kafka.streams.kstream.ValueTransformer; +import org.apache.kafka.streams.kstream.ValueTransformerSupplier; +import org.apache.kafka.streams.kstream.ValueTransformerWithKey; +import org.apache.kafka.streams.kstream.ValueTransformerWithKeySupplier; +import org.apache.kafka.streams.kstream.internals.graph.GraphNode; +import org.apache.kafka.streams.processor.ProcessorContext; +import org.apache.kafka.streams.processor.internals.InternalTopologyBuilder; +import org.apache.kafka.streams.state.StoreBuilder; + +import java.util.Collection; +import java.util.HashSet; +import java.util.Objects; +import java.util.Set; + +/* + * Any classes (KTable, KStream, etc) extending this class should follow the serde specification precedence ordering as: + * + * 1) Overridden values via control objects (e.g. Materialized, Serialized, Consumed, etc) + * 2) Serdes that can be inferred from the operator itself (e.g. groupBy().count(), where value serde can default to `LongSerde`). + * 3) Serde inherited from parent operator if possible (note if the key / value types have been changed, then the corresponding serde cannot be inherited). + * 4) Default serde specified in the config. + */ +public abstract class AbstractStream { + + protected final String name; + protected final Serde keySerde; + protected final Serde valueSerde; + protected final Set subTopologySourceNodes; + protected final GraphNode graphNode; + protected final InternalStreamsBuilder builder; + + // This copy-constructor will allow to extend KStream + // and KTable APIs with new methods without impacting the public interface. + public AbstractStream(final AbstractStream stream) { + this.name = stream.name; + this.builder = stream.builder; + this.keySerde = stream.keySerde; + this.valueSerde = stream.valueSerde; + this.subTopologySourceNodes = stream.subTopologySourceNodes; + this.graphNode = stream.graphNode; + } + + AbstractStream(final String name, + final Serde keySerde, + final Serde valueSerde, + final Set subTopologySourceNodes, + final GraphNode graphNode, + final InternalStreamsBuilder builder) { + if (subTopologySourceNodes == null || subTopologySourceNodes.isEmpty()) { + throw new IllegalArgumentException("parameter must not be null or empty"); + } + + this.name = name; + this.builder = builder; + this.keySerde = keySerde; + this.valueSerde = valueSerde; + this.subTopologySourceNodes = subTopologySourceNodes; + this.graphNode = graphNode; + } + + // This method allows to expose the InternalTopologyBuilder instance + // to subclasses that extend AbstractStream class. 
+ protected InternalTopologyBuilder internalTopologyBuilder() { + return builder.internalTopologyBuilder; + } + + Set ensureCopartitionWith(final Collection> otherStreams) { + final Set allSourceNodes = new HashSet<>(subTopologySourceNodes); + for (final AbstractStream other: otherStreams) { + allSourceNodes.addAll(other.subTopologySourceNodes); + } + builder.internalTopologyBuilder.copartitionSources(allSourceNodes); + + return allSourceNodes; + } + + static ValueJoiner reverseJoiner(final ValueJoiner joiner) { + return (value2, value1) -> joiner.apply(value1, value2); + } + + static ValueJoinerWithKey reverseJoinerWithKey(final ValueJoinerWithKey joiner) { + return (key, value2, value1) -> joiner.apply(key, value1, value2); + } + + static ValueMapperWithKey withKey(final ValueMapper valueMapper) { + Objects.requireNonNull(valueMapper, "valueMapper can't be null"); + return (readOnlyKey, value) -> valueMapper.apply(value); + } + + static ValueTransformerWithKeySupplier toValueTransformerWithKeySupplier( + final ValueTransformerSupplier valueTransformerSupplier) { + Objects.requireNonNull(valueTransformerSupplier, "valueTransformerSupplier can't be null"); + ApiUtils.checkSupplier(valueTransformerSupplier); + return new ValueTransformerWithKeySupplier() { + @Override + public ValueTransformerWithKey get() { + final ValueTransformer valueTransformer = valueTransformerSupplier.get(); + return new ValueTransformerWithKey() { + @Override + public void init(final ProcessorContext context) { + valueTransformer.init(context); + } + + @Override + public VR transform(final K readOnlyKey, final V value) { + return valueTransformer.transform(value); + } + + @Override + public void close() { + valueTransformer.close(); + } + }; + } + + @Override + public Set> stores() { + return valueTransformerSupplier.stores(); + } + }; + } + + static ValueJoinerWithKey toValueJoinerWithKey(final ValueJoiner valueJoiner) { + Objects.requireNonNull(valueJoiner, "joiner can't be null"); + return (readOnlyKey, value1, value2) -> valueJoiner.apply(value1, value2); + } + + // for testing only + public Serde keySerde() { + return keySerde; + } + + public Serde valueSerde() { + return valueSerde; + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/BranchedInternal.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/BranchedInternal.java new file mode 100644 index 0000000..a1d7552 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/BranchedInternal.java @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
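// [Editor's example, not part of this patch] A minimal sketch of the serde-precedence rules
// listed in the AbstractStream comment above, using only public DSL classes. Topic names are
// illustrative: (1) explicit Consumed/Produced serdes win, (2) an operator like count() can
// infer its Long value serde itself, (3) an unchanged key type lets the key serde be inherited,
// and (4) anything else falls back to the configured defaults.
import org.apache.kafka.common.serialization.Serdes;
import org.apache.kafka.streams.StreamsBuilder;
import org.apache.kafka.streams.kstream.Consumed;
import org.apache.kafka.streams.kstream.KStream;
import org.apache.kafka.streams.kstream.Produced;

public class SerdePrecedenceExample {
    public static void main(final String[] args) {
        final StreamsBuilder builder = new StreamsBuilder();

        // (1) Overridden via a control object: these serdes beat the config defaults.
        final KStream<String, Long> amounts =
            builder.stream("amounts", Consumed.with(Serdes.String(), Serdes.Long()));

        // (3) mapValues() keeps the key type, so the key serde could be inherited downstream,
        //     but the value type changed, so the value serde cannot be.
        amounts.mapValues(v -> "value=" + v)
               // (1) again: spell out the new value serde instead of relying on (4) defaults.
               .to("amounts-as-text", Produced.with(Serdes.String(), Serdes.String()));
    }
}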
+ */ +package org.apache.kafka.streams.kstream.internals; + +import org.apache.kafka.streams.kstream.Branched; +import org.apache.kafka.streams.kstream.KStream; + +import java.util.function.Consumer; +import java.util.function.Function; + +class BranchedInternal extends Branched { + BranchedInternal(final Branched branched) { + super(branched); + } + + BranchedInternal() { + super(null, null, null); + } + + static BranchedInternal empty() { + return new BranchedInternal<>(); + } + + String name() { + return name; + } + + public Function, ? extends KStream> chainFunction() { + return chainFunction; + } + + public Consumer> chainConsumer() { + return chainConsumer; + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/BranchedKStreamImpl.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/BranchedKStreamImpl.java new file mode 100644 index 0000000..e553991 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/BranchedKStreamImpl.java @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.kstream.internals; + +import org.apache.kafka.streams.kstream.Branched; +import org.apache.kafka.streams.kstream.BranchedKStream; +import org.apache.kafka.streams.kstream.KStream; +import org.apache.kafka.streams.kstream.Predicate; +import org.apache.kafka.streams.kstream.internals.graph.ProcessorGraphNode; +import org.apache.kafka.streams.kstream.internals.graph.ProcessorParameters; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +public class BranchedKStreamImpl implements BranchedKStream { + + private static final String BRANCH_NAME = "KSTREAM-BRANCH-"; + + private final KStreamImpl source; + private final boolean repartitionRequired; + private final String splitterName; + private final Map> outputBranches = new HashMap<>(); + + private final List> predicates = new ArrayList<>(); + private final List childNames = new ArrayList<>(); + private final ProcessorGraphNode splitterNode; + + BranchedKStreamImpl(final KStreamImpl source, final boolean repartitionRequired, final NamedInternal named) { + this.source = source; + this.repartitionRequired = repartitionRequired; + this.splitterName = named.orElseGenerateWithPrefix(source.builder, BRANCH_NAME); + + // predicates and childNames are passed by reference so when the user adds a branch they get added to + final ProcessorParameters processorParameters = + new ProcessorParameters<>(new KStreamBranch<>(predicates, childNames), splitterName); + splitterNode = new ProcessorGraphNode<>(splitterName, processorParameters); + source.builder.addGraphNode(source.graphNode, splitterNode); + } + + @Override + public BranchedKStream branch(final Predicate predicate) { + return branch(predicate, BranchedInternal.empty()); + } + + @Override + public BranchedKStream branch(final Predicate predicate, final Branched branched) { + predicates.add(predicate); + createBranch(branched, predicates.size()); + return this; + } + + @Override + public Map> defaultBranch() { + return defaultBranch(BranchedInternal.empty()); + } + + @Override + public Map> defaultBranch(final Branched branched) { + createBranch(branched, 0); + return outputBranches; + } + + private void createBranch(final Branched branched, final int index) { + final BranchedInternal branchedInternal = new BranchedInternal<>(branched); + final String branchChildName = getBranchChildName(index, branchedInternal); + childNames.add(branchChildName); + source.builder.newProcessorName(branchChildName); + final ProcessorParameters parameters = new ProcessorParameters<>(new PassThrough<>(), branchChildName); + final ProcessorGraphNode branchChildNode = new ProcessorGraphNode<>(branchChildName, parameters); + source.builder.addGraphNode(splitterNode, branchChildNode); + final KStreamImpl branch = new KStreamImpl<>(branchChildName, source.keySerde, + source.valueSerde, source.subTopologySourceNodes, + repartitionRequired, branchChildNode, source.builder); + process(branch, branchChildName, branchedInternal); + } + + private String getBranchChildName(final int index, final BranchedInternal branchedInternal) { + if (branchedInternal.name() == null) { + return splitterName + index; + } else { + return splitterName + branchedInternal.name(); + } + } + + private void process(final KStreamImpl branch, final String branchChildName, + final BranchedInternal branchedInternal) { + if (branchedInternal.chainFunction() != null) { + final KStream transformedStream = branchedInternal.chainFunction().apply(branch); + if 
(transformedStream != null) { + outputBranches.put(branchChildName, transformedStream); + } + } else if (branchedInternal.chainConsumer() != null) { + branchedInternal.chainConsumer().accept(branch); + } else { + outputBranches.put(branchChildName, branch); + } + } + + @Override + public Map> noDefaultBranch() { + return outputBranches; + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/Change.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/Change.java new file mode 100644 index 0000000..c9a18de --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/Change.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.kstream.internals; + +import java.util.Objects; + +public class Change { + + public final T newValue; + public final T oldValue; + + public Change(final T newValue, final T oldValue) { + this.newValue = newValue; + this.oldValue = oldValue; + } + + @Override + public String toString() { + return "(" + newValue + "<-" + oldValue + ")"; + } + + @Override + public boolean equals(final Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + final Change change = (Change) o; + return Objects.equals(newValue, change.newValue) && + Objects.equals(oldValue, change.oldValue); + } + + @Override + public int hashCode() { + return Objects.hash(newValue, oldValue); + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/ChangedDeserializer.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/ChangedDeserializer.java new file mode 100644 index 0000000..ac58591 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/ChangedDeserializer.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
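// [Editor's example, not part of this patch] A sketch of the public API that BranchedKStreamImpl
// above backs: KStream#split() with Branched (the method-chaining branching from KIP-418).
// The splitter prefix "split-", the branch names, and the topics are illustrative assumptions;
// the resulting map keys are prefix + branch name, mirroring getBranchChildName() above.
import java.util.Map;
import org.apache.kafka.streams.StreamsBuilder;
import org.apache.kafka.streams.kstream.Branched;
import org.apache.kafka.streams.kstream.KStream;
import org.apache.kafka.streams.kstream.Named;

public class SplitExample {
    public static void main(final String[] args) {
        final StreamsBuilder builder = new StreamsBuilder();
        final KStream<String, Long> numbers = builder.stream("numbers"); // hypothetical topic

        final Map<String, KStream<String, Long>> branches = numbers
            .split(Named.as("split-"))
            .branch((key, value) -> value != null && value % 2 == 0, Branched.as("even"))
            .branch((key, value) -> value != null && value % 2 != 0, Branched.as("odd"))
            .noDefaultBranch();

        branches.get("split-even").to("even-numbers");
        branches.get("split-odd").to("odd-numbers");
    }
}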
+package org.apache.kafka.streams.kstream.internals;
+
+import org.apache.kafka.common.header.Headers;
+import org.apache.kafka.common.serialization.Deserializer;
+import org.apache.kafka.streams.processor.internals.SerdeGetter;
+
+import java.nio.ByteBuffer;
+
+public class ChangedDeserializer<T> implements Deserializer<Change<T>>, WrappingNullableDeserializer<Change<T>, Void, T> {
+
+    private static final int NEWFLAG_SIZE = 1;
+
+    private Deserializer<T> inner;
+
+    public ChangedDeserializer(final Deserializer<T> inner) {
+        this.inner = inner;
+    }
+
+    public Deserializer<T> inner() {
+        return inner;
+    }
+
+    @SuppressWarnings("unchecked")
+    @Override
+    public void setIfUnset(final SerdeGetter getter) {
+        if (inner == null) {
+            inner = (Deserializer<T>) getter.valueSerde().deserializer();
+        }
+    }
+
+    @Override
+    public Change<T> deserialize(final String topic, final Headers headers, final byte[] data) {
+
+        final byte[] bytes = new byte[data.length - NEWFLAG_SIZE];
+
+        System.arraycopy(data, 0, bytes, 0, bytes.length);
+
+        if (ByteBuffer.wrap(data).get(data.length - NEWFLAG_SIZE) != 0) {
+            return new Change<>(inner.deserialize(topic, headers, bytes), null);
+        } else {
+            return new Change<>(null, inner.deserialize(topic, headers, bytes));
+        }
+    }
+
+    @Override
+    public Change<T> deserialize(final String topic, final byte[] data) {
+        return deserialize(topic, null, data);
+    }
+
+    @Override
+    public void close() {
+        inner.close();
+    }
+}
diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/ChangedSerializer.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/ChangedSerializer.java
new file mode 100644
index 0000000..9cb43b0
--- /dev/null
+++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/ChangedSerializer.java
@@ -0,0 +1,88 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.streams.kstream.internals;
+
+import org.apache.kafka.common.header.Headers;
+import org.apache.kafka.common.serialization.Serializer;
+import org.apache.kafka.streams.errors.StreamsException;
+import org.apache.kafka.streams.processor.internals.SerdeGetter;
+
+import java.nio.ByteBuffer;
+
+public class ChangedSerializer<T> implements Serializer<Change<T>>, WrappingNullableSerializer<Change<T>, Void, T> {
+
+    private static final int NEWFLAG_SIZE = 1;
+
+    private Serializer<T> inner;
+
+    public ChangedSerializer(final Serializer<T> inner) {
+        this.inner = inner;
+    }
+
+    public Serializer<T> inner() {
+        return inner;
+    }
+
+    @SuppressWarnings("unchecked")
+    @Override
+    public void setIfUnset(final SerdeGetter getter) {
+        if (inner == null) {
+            inner = (Serializer<T>) getter.valueSerde().serializer();
+        }
+    }
+
+    /**
+     * @throws StreamsException if both old and new values of data are null, or if
+     * both values are not null
+     */
+    @Override
+    public byte[] serialize(final String topic, final Headers headers, final Change<T> data) {
+        final byte[] serializedKey;
+
+        // only one of the old / new values would be not null
+        if (data.newValue != null) {
+            if (data.oldValue != null) {
+                throw new StreamsException("Both old and new values are not null (" + data.oldValue +
+                    " : " + data.newValue + ") in ChangeSerializer, which is not allowed.");
+            }
+
+            serializedKey = inner.serialize(topic, headers, data.newValue);
+        } else {
+            if (data.oldValue == null) {
+                throw new StreamsException("Both old and new values are null in ChangeSerializer, which is not allowed.");
+            }
+
+            serializedKey = inner.serialize(topic, headers, data.oldValue);
+        }
+
+        final ByteBuffer buf = ByteBuffer.allocate(serializedKey.length + NEWFLAG_SIZE);
+        buf.put(serializedKey);
+        buf.put((byte) (data.newValue != null ? 1 : 0));
+
+        return buf.array();
+    }
+
+    @Override
+    public byte[] serialize(final String topic, final Change<T> data) {
+        return serialize(topic, null, data);
+    }
+
+    @Override
+    public void close() {
+        inner.close();
+    }
+}
diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/CogroupedKStreamImpl.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/CogroupedKStreamImpl.java
new file mode 100644
index 0000000..a4a0351
--- /dev/null
+++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/CogroupedKStreamImpl.java
@@ -0,0 +1,148 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
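// Editor's note: an illustrative round-trip (not part of this patch) of the Change wire format
// defined by ChangedSerializer/ChangedDeserializer above: the inner-serialized value is followed
// by a single flag byte, 1 meaning "bytes are the new value" and 0 meaning "bytes are the old
// value". The topic name below is an arbitrary placeholder.
package org.apache.kafka.streams.kstream.internals;

import org.apache.kafka.common.serialization.Serdes;

class ChangedFormatSketch {
    void example() {
        final ChangedSerializer<String> serializer = new ChangedSerializer<>(Serdes.String().serializer());
        final ChangedDeserializer<String> deserializer = new ChangedDeserializer<>(Serdes.String().deserializer());

        // Serialize an "update to the new value": UTF-8 bytes of "v2" plus a trailing 0x01 flag byte.
        final byte[] bytes = serializer.serialize("some-topic", new Change<>("v2", null));

        // The deserializer strips the flag byte and routes the payload to newValue or oldValue.
        final Change<String> change = deserializer.deserialize("some-topic", bytes);
        // change.newValue equals "v2"; change.oldValue is null.
    }
}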
+ */ +package org.apache.kafka.streams.kstream.internals; + +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.streams.kstream.Aggregator; +import org.apache.kafka.streams.kstream.CogroupedKStream; +import org.apache.kafka.streams.kstream.Initializer; +import org.apache.kafka.streams.kstream.KGroupedStream; +import org.apache.kafka.streams.kstream.KTable; +import org.apache.kafka.streams.kstream.Materialized; +import org.apache.kafka.streams.kstream.Named; +import org.apache.kafka.streams.kstream.SessionWindowedCogroupedKStream; +import org.apache.kafka.streams.kstream.SessionWindows; +import org.apache.kafka.streams.kstream.SlidingWindows; +import org.apache.kafka.streams.kstream.TimeWindowedCogroupedKStream; +import org.apache.kafka.streams.kstream.Window; +import org.apache.kafka.streams.kstream.Windows; +import org.apache.kafka.streams.kstream.internals.graph.GraphNode; +import org.apache.kafka.streams.state.KeyValueStore; + +import java.util.LinkedHashMap; +import java.util.Map; +import java.util.Objects; +import java.util.Set; + +public class CogroupedKStreamImpl extends AbstractStream implements CogroupedKStream { + + static final String AGGREGATE_NAME = "COGROUPKSTREAM-AGGREGATE-"; + static final String MERGE_NAME = "COGROUPKSTREAM-MERGE-"; + + final private Map, Aggregator> groupPatterns; + final private CogroupedStreamAggregateBuilder aggregateBuilder; + + CogroupedKStreamImpl(final String name, + final Set subTopologySourceNodes, + final GraphNode graphNode, + final InternalStreamsBuilder builder) { + super(name, null, null, subTopologySourceNodes, graphNode, builder); + groupPatterns = new LinkedHashMap<>(); + aggregateBuilder = new CogroupedStreamAggregateBuilder<>(builder); + } + + @SuppressWarnings("unchecked") + @Override + public CogroupedKStream cogroup(final KGroupedStream groupedStream, + final Aggregator aggregator) { + Objects.requireNonNull(groupedStream, "groupedStream can't be null"); + Objects.requireNonNull(aggregator, "aggregator can't be null"); + groupPatterns.put((KGroupedStreamImpl) groupedStream, + (Aggregator) aggregator); + return this; + } + + @Override + public KTable aggregate(final Initializer initializer, + final Materialized> materialized) { + return aggregate(initializer, NamedInternal.empty(), materialized); + } + + @Override + public KTable aggregate(final Initializer initializer, final Named named) { + return aggregate(initializer, named, Materialized.with(keySerde, null)); + } + + @Override + public KTable aggregate(final Initializer initializer, + final Named named, + final Materialized> materialized) { + Objects.requireNonNull(initializer, "initializer can't be null"); + Objects.requireNonNull(named, "named can't be null"); + Objects.requireNonNull(materialized, "materialized can't be null"); + return doAggregate( + initializer, + new NamedInternal(named), + new MaterializedInternal<>(materialized, builder, AGGREGATE_NAME)); + } + + @Override + public KTable aggregate(final Initializer initializer) { + return aggregate(initializer, Materialized.with(keySerde, null)); + } + + @Override + public TimeWindowedCogroupedKStream windowedBy(final Windows windows) { + Objects.requireNonNull(windows, "windows can't be null"); + return new TimeWindowedCogroupedKStreamImpl<>( + windows, + builder, + subTopologySourceNodes, + name, + aggregateBuilder, + graphNode, + groupPatterns); + } + + @Override + public TimeWindowedCogroupedKStream windowedBy(final SlidingWindows slidingWindows) { + Objects.requireNonNull(slidingWindows, 
"slidingWindows can't be null"); + return new SlidingWindowedCogroupedKStreamImpl<>( + slidingWindows, + builder, + subTopologySourceNodes, + name, + aggregateBuilder, + graphNode, + groupPatterns); + } + + @Override + public SessionWindowedCogroupedKStream windowedBy(final SessionWindows sessionWindows) { + Objects.requireNonNull(sessionWindows, "sessionWindows can't be null"); + return new SessionWindowedCogroupedKStreamImpl<>(sessionWindows, + builder, + subTopologySourceNodes, + name, + aggregateBuilder, + graphNode, + groupPatterns); + } + + private KTable doAggregate(final Initializer initializer, + final NamedInternal named, + final MaterializedInternal> materializedInternal) { + return aggregateBuilder.build( + groupPatterns, + initializer, + named, + new TimestampedKeyValueStoreMaterializer<>(materializedInternal).materialize(), + materializedInternal.keySerde(), + materializedInternal.valueSerde(), + materializedInternal.queryableStoreName()); + } +} \ No newline at end of file diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/CogroupedStreamAggregateBuilder.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/CogroupedStreamAggregateBuilder.java new file mode 100644 index 0000000..3630454 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/CogroupedStreamAggregateBuilder.java @@ -0,0 +1,295 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.kstream.internals; + +import static org.apache.kafka.streams.kstream.internals.graph.OptimizableRepartitionNode.optimizableRepartitionNodeBuilder; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.LinkedHashMap; +import java.util.Map; +import java.util.Map.Entry; +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.streams.kstream.Aggregator; +import org.apache.kafka.streams.kstream.Initializer; +import org.apache.kafka.streams.kstream.KTable; +import org.apache.kafka.streams.kstream.Merger; +import org.apache.kafka.streams.kstream.SessionWindows; +import org.apache.kafka.streams.kstream.SlidingWindows; +import org.apache.kafka.streams.kstream.Window; +import org.apache.kafka.streams.kstream.Windows; +import org.apache.kafka.streams.kstream.internals.graph.OptimizableRepartitionNode.OptimizableRepartitionNodeBuilder; +import org.apache.kafka.streams.kstream.internals.graph.ProcessorGraphNode; +import org.apache.kafka.streams.kstream.internals.graph.ProcessorParameters; +import org.apache.kafka.streams.kstream.internals.graph.StatefulProcessorNode; +import org.apache.kafka.streams.kstream.internals.graph.GraphNode; +import org.apache.kafka.streams.processor.api.ProcessorSupplier; +import org.apache.kafka.streams.state.StoreBuilder; + +class CogroupedStreamAggregateBuilder { + private final InternalStreamsBuilder builder; + private final Map, GraphNode> parentNodes = new LinkedHashMap<>(); + + CogroupedStreamAggregateBuilder(final InternalStreamsBuilder builder) { + this.builder = builder; + } + KTable build(final Map, Aggregator> groupPatterns, + final Initializer initializer, + final NamedInternal named, + final StoreBuilder storeBuilder, + final Serde keySerde, + final Serde valueSerde, + final String queryableName) { + processRepartitions(groupPatterns, storeBuilder); + final Collection processors = new ArrayList<>(); + final Collection parentProcessors = new ArrayList<>(); + boolean stateCreated = false; + int counter = 0; + for (final Entry, Aggregator> kGroupedStream : groupPatterns.entrySet()) { + final KStreamAggProcessorSupplier parentProcessor = + new KStreamAggregate<>(storeBuilder.name(), initializer, kGroupedStream.getValue()); + parentProcessors.add(parentProcessor); + final StatefulProcessorNode statefulProcessorNode = getStatefulProcessorNode( + named.suffixWithOrElseGet( + "-cogroup-agg-" + counter++, + builder, + CogroupedKStreamImpl.AGGREGATE_NAME), + stateCreated, + storeBuilder, + parentProcessor); + stateCreated = true; + processors.add(statefulProcessorNode); + builder.addGraphNode(parentNodes.get(kGroupedStream.getKey()), statefulProcessorNode); + } + return createTable(processors, parentProcessors, named, keySerde, valueSerde, queryableName, storeBuilder.name()); + } + + @SuppressWarnings("unchecked") + KTable build(final Map, Aggregator> groupPatterns, + final Initializer initializer, + final NamedInternal named, + final StoreBuilder storeBuilder, + final Serde keySerde, + final Serde valueSerde, + final String queryableName, + final Windows windows) { + processRepartitions(groupPatterns, storeBuilder); + + final Collection processors = new ArrayList<>(); + final Collection parentProcessors = new ArrayList<>(); + boolean stateCreated = false; + int counter = 0; + for (final Entry, Aggregator> kGroupedStream : groupPatterns.entrySet()) { + final KStreamAggProcessorSupplier parentProcessor = + (KStreamAggProcessorSupplier) new 
KStreamWindowAggregate( + windows, + storeBuilder.name(), + initializer, + kGroupedStream.getValue()); + parentProcessors.add(parentProcessor); + final StatefulProcessorNode statefulProcessorNode = getStatefulProcessorNode( + named.suffixWithOrElseGet( + "-cogroup-agg-" + counter++, + builder, + CogroupedKStreamImpl.AGGREGATE_NAME), + stateCreated, + storeBuilder, + parentProcessor); + stateCreated = true; + processors.add(statefulProcessorNode); + builder.addGraphNode(parentNodes.get(kGroupedStream.getKey()), statefulProcessorNode); + } + return createTable(processors, parentProcessors, named, keySerde, valueSerde, queryableName, storeBuilder.name()); + } + + @SuppressWarnings("unchecked") + KTable build(final Map, Aggregator> groupPatterns, + final Initializer initializer, + final NamedInternal named, + final StoreBuilder storeBuilder, + final Serde keySerde, + final Serde valueSerde, + final String queryableName, + final SessionWindows sessionWindows, + final Merger sessionMerger) { + processRepartitions(groupPatterns, storeBuilder); + final Collection processors = new ArrayList<>(); + final Collection parentProcessors = new ArrayList<>(); + boolean stateCreated = false; + int counter = 0; + for (final Entry, Aggregator> kGroupedStream : groupPatterns.entrySet()) { + final KStreamAggProcessorSupplier parentProcessor = + (KStreamAggProcessorSupplier) new KStreamSessionWindowAggregate( + sessionWindows, + storeBuilder.name(), + initializer, + kGroupedStream.getValue(), + sessionMerger); + parentProcessors.add(parentProcessor); + final StatefulProcessorNode statefulProcessorNode = getStatefulProcessorNode( + named.suffixWithOrElseGet( + "-cogroup-agg-" + counter++, + builder, + CogroupedKStreamImpl.AGGREGATE_NAME), + stateCreated, + storeBuilder, + parentProcessor); + stateCreated = true; + processors.add(statefulProcessorNode); + builder.addGraphNode(parentNodes.get(kGroupedStream.getKey()), statefulProcessorNode); + } + return createTable(processors, parentProcessors, named, keySerde, valueSerde, queryableName, storeBuilder.name()); + } + + @SuppressWarnings("unchecked") + KTable build(final Map, Aggregator> groupPatterns, + final Initializer initializer, + final NamedInternal named, + final StoreBuilder storeBuilder, + final Serde keySerde, + final Serde valueSerde, + final String queryableName, + final SlidingWindows slidingWindows) { + processRepartitions(groupPatterns, storeBuilder); + final Collection parentProcessors = new ArrayList<>(); + final Collection processors = new ArrayList<>(); + boolean stateCreated = false; + int counter = 0; + for (final Entry, Aggregator> kGroupedStream : groupPatterns.entrySet()) { + final KStreamAggProcessorSupplier parentProcessor = + (KStreamAggProcessorSupplier) new KStreamSlidingWindowAggregate( + slidingWindows, + storeBuilder.name(), + initializer, + kGroupedStream.getValue()); + parentProcessors.add(parentProcessor); + final StatefulProcessorNode statefulProcessorNode = getStatefulProcessorNode( + named.suffixWithOrElseGet( + "-cogroup-agg-" + counter++, + builder, + CogroupedKStreamImpl.AGGREGATE_NAME), + stateCreated, + storeBuilder, + parentProcessor); + stateCreated = true; + processors.add(statefulProcessorNode); + builder.addGraphNode(parentNodes.get(kGroupedStream.getKey()), statefulProcessorNode); + } + return createTable(processors, parentProcessors, named, keySerde, valueSerde, queryableName, storeBuilder.name()); + } + + private void processRepartitions(final Map, Aggregator> groupPatterns, + final StoreBuilder storeBuilder) { + 
for (final KGroupedStreamImpl repartitionReqs : groupPatterns.keySet()) { + + if (repartitionReqs.repartitionRequired) { + + final OptimizableRepartitionNodeBuilder repartitionNodeBuilder = optimizableRepartitionNodeBuilder(); + + final String repartitionNamePrefix = repartitionReqs.userProvidedRepartitionTopicName != null ? + repartitionReqs.userProvidedRepartitionTopicName : storeBuilder.name(); + + createRepartitionSource(repartitionNamePrefix, repartitionNodeBuilder, repartitionReqs.keySerde, repartitionReqs.valueSerde); + + if (!parentNodes.containsKey(repartitionReqs)) { + final GraphNode repartitionNode = repartitionNodeBuilder.build(); + builder.addGraphNode(repartitionReqs.graphNode, repartitionNode); + parentNodes.put(repartitionReqs, repartitionNode); + } + } else { + parentNodes.put(repartitionReqs, repartitionReqs.graphNode); + } + } + + final Collection> groupedStreams = new ArrayList<>(parentNodes.keySet()); + final AbstractStream kGrouped = groupedStreams.iterator().next(); + groupedStreams.remove(kGrouped); + kGrouped.ensureCopartitionWith(groupedStreams); + + } + + @SuppressWarnings("unchecked") + KTable createTable(final Collection processors, + final Collection parentProcessors, + final NamedInternal named, + final Serde keySerde, + final Serde valueSerde, + final String queryableName, + final String storeName) { + + final String mergeProcessorName = named.suffixWithOrElseGet( + "-cogroup-merge", + builder, + CogroupedKStreamImpl.MERGE_NAME); + final KTableNewProcessorSupplier passThrough = new KTablePassThrough<>(parentProcessors, storeName); + final ProcessorParameters processorParameters = new ProcessorParameters(passThrough, mergeProcessorName); + final ProcessorGraphNode mergeNode = + new ProcessorGraphNode<>(mergeProcessorName, processorParameters); + + builder.addGraphNode(processors, mergeNode); + + return new KTableImpl( + mergeProcessorName, + keySerde, + valueSerde, + Collections.singleton(mergeNode.nodeName()), + queryableName, + passThrough, + mergeNode, + builder); + } + + private StatefulProcessorNode getStatefulProcessorNode(final String processorName, + final boolean stateCreated, + final StoreBuilder storeBuilder, + final ProcessorSupplier kStreamAggregate) { + final StatefulProcessorNode statefulProcessorNode; + if (!stateCreated) { + statefulProcessorNode = + new StatefulProcessorNode<>( + processorName, + new ProcessorParameters<>(kStreamAggregate, processorName), + storeBuilder + ); + } else { + statefulProcessorNode = + new StatefulProcessorNode<>( + processorName, + new ProcessorParameters<>(kStreamAggregate, processorName), + new String[]{storeBuilder.name()} + ); + } + + return statefulProcessorNode; + } + + @SuppressWarnings("unchecked") + private void createRepartitionSource(final String repartitionTopicNamePrefix, + final OptimizableRepartitionNodeBuilder optimizableRepartitionNodeBuilder, + final Serde keySerde, + final Serde valueSerde) { + + KStreamImpl.createRepartitionedSource(builder, + keySerde, + (Serde) valueSerde, + repartitionTopicNamePrefix, + null, + (OptimizableRepartitionNodeBuilder) optimizableRepartitionNodeBuilder); + + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/ConsumedInternal.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/ConsumedInternal.java new file mode 100644 index 0000000..0aa4820 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/ConsumedInternal.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software 
Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.kstream.internals; + +import org.apache.kafka.common.serialization.Deserializer; +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.streams.kstream.Consumed; +import org.apache.kafka.streams.Topology; +import org.apache.kafka.streams.processor.TimestampExtractor; + +public class ConsumedInternal extends Consumed { + + public ConsumedInternal(final Consumed consumed) { + super(consumed); + } + + + public ConsumedInternal(final Serde keySerde, + final Serde valueSerde, + final TimestampExtractor timestampExtractor, + final Topology.AutoOffsetReset offsetReset) { + this(Consumed.with(keySerde, valueSerde, timestampExtractor, offsetReset)); + } + + public ConsumedInternal() { + this(Consumed.with(null, null)); + } + + public Serde keySerde() { + return keySerde; + } + + public Deserializer keyDeserializer() { + return keySerde == null ? null : keySerde.deserializer(); + } + + public Serde valueSerde() { + return valueSerde; + } + + public Deserializer valueDeserializer() { + return valueSerde == null ? null : valueSerde.deserializer(); + } + + public TimestampExtractor timestampExtractor() { + return timestampExtractor; + } + + public Topology.AutoOffsetReset offsetResetPolicy() { + return resetPolicy; + } + + public String name() { + return processorName; + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/FullChangeSerde.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/FullChangeSerde.java new file mode 100644 index 0000000..3a34394 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/FullChangeSerde.java @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
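// Editor's note: a brief sketch (not part of this patch) of the public Consumed options that
// ConsumedInternal above unwraps for the DSL internals (key/value serdes, timestamp extractor,
// offset-reset policy, and source name); the topic and serdes are assumptions for illustration.
import org.apache.kafka.common.serialization.Serdes;
import org.apache.kafka.streams.StreamsBuilder;
import org.apache.kafka.streams.Topology;
import org.apache.kafka.streams.kstream.Consumed;
import org.apache.kafka.streams.kstream.KStream;
import org.apache.kafka.streams.processor.WallclockTimestampExtractor;

class ConsumedUsageSketch {
    void example() {
        final StreamsBuilder builder = new StreamsBuilder();
        final KStream<String, String> stream = builder.stream(
            "input-topic",
            Consumed.with(Serdes.String(), Serdes.String(),
                          new WallclockTimestampExtractor(),
                          Topology.AutoOffsetReset.EARLIEST)
                .withName("input-source"));
        stream.foreach((key, value) -> { });
    }
}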
+ */ +package org.apache.kafka.streams.kstream.internals; + +import org.apache.kafka.common.serialization.Deserializer; +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.common.serialization.Serializer; + +import java.nio.ByteBuffer; + +import static java.util.Objects.requireNonNull; +import static org.apache.kafka.common.utils.Utils.getNullableSizePrefixedArray; + +public final class FullChangeSerde { + private final Serde inner; + + public static FullChangeSerde wrap(final Serde serde) { + if (serde == null) { + return null; + } else { + return new FullChangeSerde<>(serde); + } + } + + private FullChangeSerde(final Serde inner) { + this.inner = requireNonNull(inner); + } + + public Serde innerSerde() { + return inner; + } + + public Change serializeParts(final String topic, final Change data) { + if (data == null) { + return null; + } + final Serializer innerSerializer = innerSerde().serializer(); + final byte[] oldBytes = data.oldValue == null ? null : innerSerializer.serialize(topic, data.oldValue); + final byte[] newBytes = data.newValue == null ? null : innerSerializer.serialize(topic, data.newValue); + return new Change<>(newBytes, oldBytes); + } + + + public Change deserializeParts(final String topic, final Change serialChange) { + if (serialChange == null) { + return null; + } + final Deserializer innerDeserializer = innerSerde().deserializer(); + + final T oldValue = + serialChange.oldValue == null ? null : innerDeserializer.deserialize(topic, serialChange.oldValue); + final T newValue = + serialChange.newValue == null ? null : innerDeserializer.deserialize(topic, serialChange.newValue); + + return new Change<>(newValue, oldValue); + } + + /** + * We used to serialize a Change into a single byte[]. Now, we don't anymore, but we still + * need to be able to read it (so that we can load the state store from previously-written changelog records). + */ + public static Change decomposeLegacyFormattedArrayIntoChangeArrays(final byte[] data) { + if (data == null) { + return null; + } + final ByteBuffer buffer = ByteBuffer.wrap(data); + final byte[] oldBytes = getNullableSizePrefixedArray(buffer); + final byte[] newBytes = getNullableSizePrefixedArray(buffer); + return new Change<>(newBytes, oldBytes); + } + +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/FullTimeWindowedSerde.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/FullTimeWindowedSerde.java new file mode 100644 index 0000000..a69002f --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/FullTimeWindowedSerde.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
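// Editor's note: an illustrative decode (not part of this patch) of the legacy single-array layout
// still read by FullChangeSerde#decomposeLegacyFormattedArrayIntoChangeArrays above: a size-prefixed
// old value followed by a size-prefixed new value, with a length of -1 standing in for null
// (per Utils#getNullableSizePrefixedArray). The value bytes below are placeholders.
package org.apache.kafka.streams.kstream.internals;

import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;

class LegacyChangeLayoutSketch {
    void example() {
        final byte[] newValue = "v2".getBytes(StandardCharsets.UTF_8);
        final ByteBuffer legacy = ByteBuffer.allocate(2 * Integer.BYTES + newValue.length);
        legacy.putInt(-1);              // old value: null
        legacy.putInt(newValue.length); // new value: length prefix ...
        legacy.put(newValue);           // ... followed by the bytes

        final Change<byte[]> change =
            FullChangeSerde.decomposeLegacyFormattedArrayIntoChangeArrays(legacy.array());
        // change.oldValue is null; change.newValue holds the "v2" bytes.
    }
}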
+ */ +package org.apache.kafka.streams.kstream.internals; + +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.streams.kstream.TimeWindowedDeserializer; +import org.apache.kafka.streams.kstream.TimeWindowedSerializer; +import org.apache.kafka.streams.kstream.Windowed; + +class FullTimeWindowedSerde extends Serdes.WrapperSerde> { + FullTimeWindowedSerde(final Serde inner, final long windowSize) { + super( + new TimeWindowedSerializer<>(inner.serializer()), + new TimeWindowedDeserializer<>(inner.deserializer(), windowSize) + ); + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/GlobalKTableImpl.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/GlobalKTableImpl.java new file mode 100644 index 0000000..734ff4a --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/GlobalKTableImpl.java @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.kstream.internals; + +import org.apache.kafka.streams.kstream.GlobalKTable; + +public class GlobalKTableImpl implements GlobalKTable { + + private final KTableValueGetterSupplier valueGetterSupplier; + private final String queryableStoreName; + + GlobalKTableImpl(final KTableValueGetterSupplier valueGetterSupplier, + final String queryableStoreName) { + this.valueGetterSupplier = valueGetterSupplier; + this.queryableStoreName = queryableStoreName; + } + + KTableValueGetterSupplier valueGetterSupplier() { + return valueGetterSupplier; + } + + @Override + public String queryableStoreName() { + return queryableStoreName; + } + +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/GroupedInternal.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/GroupedInternal.java new file mode 100644 index 0000000..3569caa --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/GroupedInternal.java @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.streams.kstream.internals; + +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.streams.kstream.Grouped; + +public class GroupedInternal extends Grouped { + + public GroupedInternal(final Grouped grouped) { + super(grouped); + } + + public Serde keySerde() { + return keySerde; + } + + public Serde valueSerde() { + return valueSerde; + } + + public String name() { + return name; + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/GroupedStreamAggregateBuilder.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/GroupedStreamAggregateBuilder.java new file mode 100644 index 0000000..dfcf63d --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/GroupedStreamAggregateBuilder.java @@ -0,0 +1,133 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.kstream.internals; + +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.streams.kstream.Aggregator; +import org.apache.kafka.streams.kstream.Initializer; +import org.apache.kafka.streams.kstream.KTable; +import org.apache.kafka.streams.kstream.internals.graph.ProcessorParameters; +import org.apache.kafka.streams.kstream.internals.graph.StatefulProcessorNode; +import org.apache.kafka.streams.kstream.internals.graph.GraphNode; +import org.apache.kafka.streams.state.StoreBuilder; + +import java.util.Collections; +import java.util.Set; + +import static org.apache.kafka.streams.kstream.internals.graph.OptimizableRepartitionNode.OptimizableRepartitionNodeBuilder; +import static org.apache.kafka.streams.kstream.internals.graph.OptimizableRepartitionNode.optimizableRepartitionNodeBuilder; + +class GroupedStreamAggregateBuilder { + + private final InternalStreamsBuilder builder; + private final Serde keySerde; + private final Serde valueSerde; + private final boolean repartitionRequired; + private final String userProvidedRepartitionTopicName; + private final Set subTopologySourceNodes; + private final String name; + private final GraphNode graphNode; + private GraphNode repartitionNode; + + final Initializer countInitializer = () -> 0L; + + final Aggregator countAggregator = (aggKey, value, aggregate) -> aggregate + 1; + + final Initializer reduceInitializer = () -> null; + + GroupedStreamAggregateBuilder(final InternalStreamsBuilder builder, + final GroupedInternal groupedInternal, + final boolean repartitionRequired, + final Set subTopologySourceNodes, + final String name, + final GraphNode graphNode) { + + this.builder = builder; + this.keySerde = groupedInternal.keySerde(); + this.valueSerde = 
groupedInternal.valueSerde(); + this.repartitionRequired = repartitionRequired; + this.subTopologySourceNodes = subTopologySourceNodes; + this.name = name; + this.graphNode = graphNode; + this.userProvidedRepartitionTopicName = groupedInternal.name(); + } + + KTable build(final NamedInternal functionName, + final StoreBuilder storeBuilder, + final KStreamAggProcessorSupplier aggregateSupplier, + final String queryableStoreName, + final Serde keySerde, + final Serde valueSerde) { + assert queryableStoreName == null || queryableStoreName.equals(storeBuilder.name()); + + final String aggFunctionName = functionName.name(); + + String sourceName = this.name; + GraphNode parentNode = graphNode; + + if (repartitionRequired) { + final OptimizableRepartitionNodeBuilder repartitionNodeBuilder = optimizableRepartitionNodeBuilder(); + final String repartitionTopicPrefix = userProvidedRepartitionTopicName != null ? userProvidedRepartitionTopicName : storeBuilder.name(); + sourceName = createRepartitionSource(repartitionTopicPrefix, repartitionNodeBuilder); + + // First time through we need to create a repartition node. + // Any subsequent calls to GroupedStreamAggregateBuilder#build we check if + // the user has provided a name for the repartition topic, is so we re-use + // the existing repartition node, otherwise we create a new one. + if (repartitionNode == null || userProvidedRepartitionTopicName == null) { + repartitionNode = repartitionNodeBuilder.build(); + } + + builder.addGraphNode(parentNode, repartitionNode); + parentNode = repartitionNode; + } + + final StatefulProcessorNode statefulProcessorNode = + new StatefulProcessorNode<>( + aggFunctionName, + new ProcessorParameters<>(aggregateSupplier, aggFunctionName), + storeBuilder + ); + + builder.addGraphNode(parentNode, statefulProcessorNode); + + return new KTableImpl<>(aggFunctionName, + keySerde, + valueSerde, + sourceName.equals(this.name) ? subTopologySourceNodes : Collections.singleton(sourceName), + queryableStoreName, + aggregateSupplier, + statefulProcessorNode, + builder); + + } + + /** + * @return the new sourceName of the repartitioned source + */ + private String createRepartitionSource(final String repartitionTopicNamePrefix, + final OptimizableRepartitionNodeBuilder optimizableRepartitionNodeBuilder) { + + return KStreamImpl.createRepartitionedSource(builder, + keySerde, + valueSerde, + repartitionTopicNamePrefix, + null, + optimizableRepartitionNodeBuilder); + + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/InternalNameProvider.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/InternalNameProvider.java new file mode 100644 index 0000000..bc35d68 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/InternalNameProvider.java @@ -0,0 +1,23 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
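// Editor's note: a minimal sketch (not part of this patch) of a grouped aggregation that is wired
// through GroupedStreamAggregateBuilder above; because groupBy() changes the key, the builder
// inserts a repartition topic named after either Grouped.as(...) or the state store. Topic, store,
// and serde choices here are assumptions for illustration.
import org.apache.kafka.common.serialization.Serdes;
import org.apache.kafka.common.utils.Bytes;
import org.apache.kafka.streams.StreamsBuilder;
import org.apache.kafka.streams.kstream.*;
import org.apache.kafka.streams.state.KeyValueStore;

class GroupedAggregateUsageSketch {
    void example() {
        final StreamsBuilder builder = new StreamsBuilder();
        final KTable<String, Long> counts = builder
            .stream("words", Consumed.with(Serdes.String(), Serdes.String()))
            .groupBy((key, word) -> word, Grouped.with("by-word", Serdes.String(), Serdes.String()))
            .count(Materialized.<String, Long, KeyValueStore<Bytes, byte[]>>as("word-counts"));

        counts.toStream().to("word-counts-output", Produced.with(Serdes.String(), Serdes.Long()));
    }
}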
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.kstream.internals; + +public interface InternalNameProvider { + String newProcessorName(final String prefix); + + String newStoreName(final String prefix); +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/InternalStreamsBuilder.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/InternalStreamsBuilder.java new file mode 100644 index 0000000..7f75f72 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/InternalStreamsBuilder.java @@ -0,0 +1,550 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.kstream.internals; + +import java.util.TreeMap; +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.errors.StreamsException; +import org.apache.kafka.streams.errors.TopologyException; +import org.apache.kafka.streams.kstream.GlobalKTable; +import org.apache.kafka.streams.kstream.Grouped; +import org.apache.kafka.streams.kstream.KStream; +import org.apache.kafka.streams.kstream.KTable; +import org.apache.kafka.streams.kstream.internals.graph.GlobalStoreNode; +import org.apache.kafka.streams.kstream.internals.graph.OptimizableRepartitionNode; +import org.apache.kafka.streams.kstream.internals.graph.ProcessorParameters; +import org.apache.kafka.streams.kstream.internals.graph.StateStoreNode; +import org.apache.kafka.streams.kstream.internals.graph.StreamSourceNode; +import org.apache.kafka.streams.kstream.internals.graph.GraphNode; +import org.apache.kafka.streams.kstream.internals.graph.TableSourceNode; +import org.apache.kafka.streams.processor.internals.InternalTopologyBuilder; +import org.apache.kafka.streams.state.KeyValueStore; +import org.apache.kafka.streams.state.StoreBuilder; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Collection; +import java.util.Collections; +import java.util.Comparator; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Iterator; +import java.util.LinkedHashMap; +import java.util.LinkedHashSet; +import java.util.Map; +import java.util.Map.Entry; +import java.util.Objects; +import java.util.PriorityQueue; +import java.util.Properties; +import java.util.Set; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Predicate; +import java.util.regex.Pattern; + +public class InternalStreamsBuilder implements InternalNameProvider { + + private static final String TABLE_SOURCE_SUFFIX = "-source"; + + final InternalTopologyBuilder internalTopologyBuilder; + private final 
AtomicInteger index = new AtomicInteger(0); + + private final AtomicInteger buildPriorityIndex = new AtomicInteger(0); + private final LinkedHashMap>> keyChangingOperationsToOptimizableRepartitionNodes = new LinkedHashMap<>(); + private final LinkedHashSet mergeNodes = new LinkedHashSet<>(); + private final LinkedHashSet tableSourceNodes = new LinkedHashSet<>(); + + private static final String TOPOLOGY_ROOT = "root"; + private static final Logger LOG = LoggerFactory.getLogger(InternalStreamsBuilder.class); + + protected final GraphNode root = new GraphNode(TOPOLOGY_ROOT) { + @Override + public void writeToTopology(final InternalTopologyBuilder topologyBuilder) { + // no-op for root node + } + }; + + public InternalStreamsBuilder(final InternalTopologyBuilder internalTopologyBuilder) { + this.internalTopologyBuilder = internalTopologyBuilder; + } + + public KStream stream(final Collection topics, + final ConsumedInternal consumed) { + + final String name = new NamedInternal(consumed.name()).orElseGenerateWithPrefix(this, KStreamImpl.SOURCE_NAME); + final StreamSourceNode streamSourceNode = new StreamSourceNode<>(name, topics, consumed); + + addGraphNode(root, streamSourceNode); + + return new KStreamImpl<>(name, + consumed.keySerde(), + consumed.valueSerde(), + Collections.singleton(name), + false, + streamSourceNode, + this); + } + + public KStream stream(final Pattern topicPattern, + final ConsumedInternal consumed) { + final String name = new NamedInternal(consumed.name()).orElseGenerateWithPrefix(this, KStreamImpl.SOURCE_NAME); + final StreamSourceNode streamPatternSourceNode = new StreamSourceNode<>(name, topicPattern, consumed); + + addGraphNode(root, streamPatternSourceNode); + + return new KStreamImpl<>(name, + consumed.keySerde(), + consumed.valueSerde(), + Collections.singleton(name), + false, + streamPatternSourceNode, + this); + } + + public KTable table(final String topic, + final ConsumedInternal consumed, + final MaterializedInternal> materialized) { + + final NamedInternal named = new NamedInternal(consumed.name()); + + final String sourceName = named + .suffixWithOrElseGet(TABLE_SOURCE_SUFFIX, this, KStreamImpl.SOURCE_NAME); + + final String tableSourceName = named + .orElseGenerateWithPrefix(this, KTableImpl.SOURCE_NAME); + + final KTableSource tableSource = new KTableSource<>(materialized.storeName(), materialized.queryableStoreName()); + final ProcessorParameters processorParameters = new ProcessorParameters<>(tableSource, tableSourceName); + + final TableSourceNode tableSourceNode = TableSourceNode.tableSourceNodeBuilder() + .withTopic(topic) + .withSourceName(sourceName) + .withNodeName(tableSourceName) + .withConsumedInternal(consumed) + .withMaterializedInternal(materialized) + .withProcessorParameters(processorParameters) + .build(); + + addGraphNode(root, tableSourceNode); + + return new KTableImpl<>(tableSourceName, + consumed.keySerde(), + consumed.valueSerde(), + Collections.singleton(sourceName), + materialized.queryableStoreName(), + tableSource, + tableSourceNode, + this); + } + + public GlobalKTable globalTable(final String topic, + final ConsumedInternal consumed, + final MaterializedInternal> materialized) { + Objects.requireNonNull(consumed, "consumed can't be null"); + Objects.requireNonNull(materialized, "materialized can't be null"); + // explicitly disable logging for global stores + materialized.withLoggingDisabled(); + + final NamedInternal named = new NamedInternal(consumed.name()); + + final String sourceName = named + 
.suffixWithOrElseGet(TABLE_SOURCE_SUFFIX, this, KStreamImpl.SOURCE_NAME); + + final String processorName = named + .orElseGenerateWithPrefix(this, KTableImpl.SOURCE_NAME); + + // enforce store name as queryable name to always materialize global table stores + final String storeName = materialized.storeName(); + final KTableSource tableSource = new KTableSource<>(storeName, storeName); + + final ProcessorParameters processorParameters = new ProcessorParameters<>(tableSource, processorName); + + final TableSourceNode tableSourceNode = TableSourceNode.tableSourceNodeBuilder() + .withTopic(topic) + .isGlobalKTable(true) + .withSourceName(sourceName) + .withConsumedInternal(consumed) + .withMaterializedInternal(materialized) + .withProcessorParameters(processorParameters) + .build(); + + addGraphNode(root, tableSourceNode); + + return new GlobalKTableImpl<>(new KTableSourceValueGetterSupplier<>(storeName), materialized.queryableStoreName()); + } + + @Override + public String newProcessorName(final String prefix) { + return prefix + String.format("%010d", index.getAndIncrement()); + } + + @Override + public String newStoreName(final String prefix) { + return prefix + String.format(KTableImpl.STATE_STORE_NAME + "%010d", index.getAndIncrement()); + } + + public synchronized void addStateStore(final StoreBuilder builder) { + addGraphNode(root, new StateStoreNode<>(builder)); + } + + public synchronized void addGlobalStore(final StoreBuilder storeBuilder, + final String topic, + final ConsumedInternal consumed, + final org.apache.kafka.streams.processor.api.ProcessorSupplier stateUpdateSupplier) { + // explicitly disable logging for global stores + storeBuilder.withLoggingDisabled(); + final String sourceName = newProcessorName(KStreamImpl.SOURCE_NAME); + final String processorName = newProcessorName(KTableImpl.SOURCE_NAME); + + final GraphNode globalStoreNode = new GlobalStoreNode<>( + storeBuilder, + sourceName, + topic, + consumed, + processorName, + stateUpdateSupplier + ); + + addGraphNode(root, globalStoreNode); + } + + void addGraphNode(final GraphNode parent, + final GraphNode child) { + Objects.requireNonNull(parent, "parent node can't be null"); + Objects.requireNonNull(child, "child node can't be null"); + parent.addChild(child); + maybeAddNodeForOptimizationMetadata(child); + } + + void addGraphNode(final Collection parents, + final GraphNode child) { + Objects.requireNonNull(parents, "parent node can't be null"); + Objects.requireNonNull(child, "child node can't be null"); + + if (parents.isEmpty()) { + throw new StreamsException("Parent node collection can't be empty"); + } + + for (final GraphNode parent : parents) { + addGraphNode(parent, child); + } + } + + private void maybeAddNodeForOptimizationMetadata(final GraphNode node) { + node.setBuildPriority(buildPriorityIndex.getAndIncrement()); + + if (node.parentNodes().isEmpty() && !node.nodeName().equals(TOPOLOGY_ROOT)) { + throw new IllegalStateException( + "Nodes should not have a null parent node. 
Name: " + node.nodeName() + " Type: " + + node.getClass().getSimpleName()); + } + + if (node.isKeyChangingOperation()) { + keyChangingOperationsToOptimizableRepartitionNodes.put(node, new LinkedHashSet<>()); + } else if (node instanceof OptimizableRepartitionNode) { + final GraphNode parentNode = getKeyChangingParentNode(node); + if (parentNode != null) { + keyChangingOperationsToOptimizableRepartitionNodes.get(parentNode).add((OptimizableRepartitionNode) node); + } + } else if (node.isMergeNode()) { + mergeNodes.add(node); + } else if (node instanceof TableSourceNode) { + tableSourceNodes.add(node); + } + } + + // use this method for testing only + public void buildAndOptimizeTopology() { + buildAndOptimizeTopology(null); + } + + public void buildAndOptimizeTopology(final Properties props) { + + mergeDuplicateSourceNodes(); + maybePerformOptimizations(props); + + final PriorityQueue graphNodePriorityQueue = new PriorityQueue<>(5, Comparator.comparing(GraphNode::buildPriority)); + + graphNodePriorityQueue.offer(root); + + while (!graphNodePriorityQueue.isEmpty()) { + final GraphNode streamGraphNode = graphNodePriorityQueue.remove(); + + if (LOG.isDebugEnabled()) { + LOG.debug("Adding nodes to topology {} child nodes {}", streamGraphNode, streamGraphNode.children()); + } + + if (streamGraphNode.allParentsWrittenToTopology() && !streamGraphNode.hasWrittenToTopology()) { + streamGraphNode.writeToTopology(internalTopologyBuilder); + streamGraphNode.setHasWrittenToTopology(true); + } + + for (final GraphNode graphNode : streamGraphNode.children()) { + graphNodePriorityQueue.offer(graphNode); + } + } + internalTopologyBuilder.validateCopartition(); + } + + private void mergeDuplicateSourceNodes() { + final Map> topicsToSourceNodes = new HashMap<>(); + + // We don't really care about the order here, but since Pattern does not implement equals() we can't rely on + // a regular HashMap and containsKey(Pattern). But for our purposes it's sufficient to compare the compiled + // string and flags to determine if two pattern subscriptions can be merged into a single source node + final Map> patternsToSourceNodes = + new TreeMap<>(Comparator.comparing(Pattern::pattern).thenComparing(Pattern::flags)); + + for (final GraphNode graphNode : root.children()) { + if (graphNode instanceof StreamSourceNode) { + final StreamSourceNode currentSourceNode = (StreamSourceNode) graphNode; + + if (currentSourceNode.topicPattern().isPresent()) { + final Pattern topicPattern = currentSourceNode.topicPattern().get(); + if (!patternsToSourceNodes.containsKey(topicPattern)) { + patternsToSourceNodes.put(topicPattern, currentSourceNode); + } else { + final StreamSourceNode mainSourceNode = patternsToSourceNodes.get(topicPattern); + mainSourceNode.merge(currentSourceNode); + root.removeChild(graphNode); + } + } else { + for (final String topic : currentSourceNode.topicNames().get()) { + if (!topicsToSourceNodes.containsKey(topic)) { + topicsToSourceNodes.put(topic, currentSourceNode); + } else { + final StreamSourceNode mainSourceNode = topicsToSourceNodes.get(topic); + // TODO we only merge source nodes if the subscribed topic(s) are an exact match, so it's still not + // possible to subscribe to topicA in one KStream and topicA + topicB in another. 
We could achieve + // this by splitting these source nodes into one topic per node and routing to the subscribed children + if (!mainSourceNode.topicNames().equals(currentSourceNode.topicNames())) { + LOG.error("Topic {} was found in subscription for non-equal source nodes {} and {}", + topic, mainSourceNode, currentSourceNode); + throw new TopologyException("Two source nodes are subscribed to overlapping but not equal input topics"); + } + mainSourceNode.merge(currentSourceNode); + root.removeChild(graphNode); + } + } + } + } + } + } + + private void maybePerformOptimizations(final Properties props) { + + if (props != null && StreamsConfig.OPTIMIZE.equals(props.getProperty(StreamsConfig.TOPOLOGY_OPTIMIZATION_CONFIG))) { + LOG.debug("Optimizing the Kafka Streams graph for repartition nodes"); + optimizeKTableSourceTopics(); + maybeOptimizeRepartitionOperations(); + } + } + + private void optimizeKTableSourceTopics() { + LOG.debug("Marking KTable source nodes to optimize using source topic for changelogs "); + tableSourceNodes.forEach(node -> ((TableSourceNode) node).reuseSourceTopicForChangeLog(true)); + } + + private void maybeOptimizeRepartitionOperations() { + maybeUpdateKeyChangingRepartitionNodeMap(); + final Iterator>>> entryIterator = + keyChangingOperationsToOptimizableRepartitionNodes.entrySet().iterator(); + + while (entryIterator.hasNext()) { + final Map.Entry>> entry = entryIterator.next(); + + final GraphNode keyChangingNode = entry.getKey(); + + if (entry.getValue().isEmpty()) { + continue; + } + + final GroupedInternal groupedInternal = new GroupedInternal<>(getRepartitionSerdes(entry.getValue())); + + final String repartitionTopicName = getFirstRepartitionTopicName(entry.getValue()); + //passing in the name of the first repartition topic, re-used to create the optimized repartition topic + final GraphNode optimizedSingleRepartition = createRepartitionNode(repartitionTopicName, + groupedInternal.keySerde(), + groupedInternal.valueSerde()); + + // re-use parent buildPriority to make sure the single repartition graph node is evaluated before downstream nodes + optimizedSingleRepartition.setBuildPriority(keyChangingNode.buildPriority()); + + for (final OptimizableRepartitionNode repartitionNodeToBeReplaced : entry.getValue()) { + + final GraphNode keyChangingNodeChild = findParentNodeMatching(repartitionNodeToBeReplaced, gn -> gn.parentNodes().contains(keyChangingNode)); + + if (keyChangingNodeChild == null) { + throw new StreamsException(String.format("Found a null keyChangingChild node for %s", repartitionNodeToBeReplaced)); + } + + LOG.debug("Found the child node of the key changer {} from the repartition {}.", keyChangingNodeChild, repartitionNodeToBeReplaced); + + // need to add children of key-changing node as children of optimized repartition + // in order to process records from re-partitioning + optimizedSingleRepartition.addChild(keyChangingNodeChild); + + LOG.debug("Removing {} from {} children {}", keyChangingNodeChild, keyChangingNode, keyChangingNode.children()); + // now remove children from key-changing node + keyChangingNode.removeChild(keyChangingNodeChild); + + // now need to get children of repartition node so we can remove repartition node + final Collection repartitionNodeToBeReplacedChildren = repartitionNodeToBeReplaced.children(); + final Collection parentsOfRepartitionNodeToBeReplaced = repartitionNodeToBeReplaced.parentNodes(); + + for (final GraphNode repartitionNodeToBeReplacedChild : repartitionNodeToBeReplacedChildren) { + for (final GraphNode 
parentNode : parentsOfRepartitionNodeToBeReplaced) { + parentNode.addChild(repartitionNodeToBeReplacedChild); + } + } + + for (final GraphNode parentNode : parentsOfRepartitionNodeToBeReplaced) { + parentNode.removeChild(repartitionNodeToBeReplaced); + } + repartitionNodeToBeReplaced.clearChildren(); + + // if replaced repartition node is part of any copartition group, + // we need to update it with the new node name so that co-partitioning won't break. + internalTopologyBuilder.maybeUpdateCopartitionSourceGroups(repartitionNodeToBeReplaced.nodeName(), + optimizedSingleRepartition.nodeName()); + + LOG.debug("Updated node {} children {}", optimizedSingleRepartition, optimizedSingleRepartition.children()); + } + + keyChangingNode.addChild(optimizedSingleRepartition); + entryIterator.remove(); + } + } + + private void maybeUpdateKeyChangingRepartitionNodeMap() { + final Map> mergeNodesToKeyChangers = new HashMap<>(); + final Set mergeNodeKeyChangingParentsToRemove = new HashSet<>(); + for (final GraphNode mergeNode : mergeNodes) { + mergeNodesToKeyChangers.put(mergeNode, new LinkedHashSet<>()); + final Set>>> entrySet = keyChangingOperationsToOptimizableRepartitionNodes.entrySet(); + for (final Map.Entry>> entry : entrySet) { + if (mergeNodeHasRepartitionChildren(mergeNode, entry.getValue())) { + final GraphNode maybeParentKey = findParentNodeMatching(mergeNode, node -> node.parentNodes().contains(entry.getKey())); + if (maybeParentKey != null) { + mergeNodesToKeyChangers.get(mergeNode).add(entry.getKey()); + } + } + } + } + + for (final Map.Entry> entry : mergeNodesToKeyChangers.entrySet()) { + final GraphNode mergeKey = entry.getKey(); + final Collection keyChangingParents = entry.getValue(); + final LinkedHashSet> repartitionNodes = new LinkedHashSet<>(); + for (final GraphNode keyChangingParent : keyChangingParents) { + repartitionNodes.addAll(keyChangingOperationsToOptimizableRepartitionNodes.get(keyChangingParent)); + mergeNodeKeyChangingParentsToRemove.add(keyChangingParent); + } + keyChangingOperationsToOptimizableRepartitionNodes.put(mergeKey, repartitionNodes); + } + + for (final GraphNode mergeNodeKeyChangingParent : mergeNodeKeyChangingParentsToRemove) { + keyChangingOperationsToOptimizableRepartitionNodes.remove(mergeNodeKeyChangingParent); + } + } + + private boolean mergeNodeHasRepartitionChildren(final GraphNode mergeNode, + final LinkedHashSet> repartitionNodes) { + return repartitionNodes.stream().allMatch(n -> findParentNodeMatching(n, gn -> gn.parentNodes().contains(mergeNode)) != null); + } + + private OptimizableRepartitionNode createRepartitionNode(final String repartitionTopicName, + final Serde keySerde, + final Serde valueSerde) { + + final OptimizableRepartitionNode.OptimizableRepartitionNodeBuilder repartitionNodeBuilder = + OptimizableRepartitionNode.optimizableRepartitionNodeBuilder(); + KStreamImpl.createRepartitionedSource( + this, + keySerde, + valueSerde, + repartitionTopicName, + null, + repartitionNodeBuilder + ); + + // ensures setting the repartition topic to the name of the + // first repartition topic to get merged + // this may be an auto-generated name or a user specified name + repartitionNodeBuilder.withRepartitionTopic(repartitionTopicName); + + return repartitionNodeBuilder.build(); + + } + + private GraphNode getKeyChangingParentNode(final GraphNode repartitionNode) { + final GraphNode shouldBeKeyChangingNode = findParentNodeMatching(repartitionNode, n -> n.isKeyChangingOperation() || n.isValueChangingOperation()); + + final GraphNode 
keyChangingNode = findParentNodeMatching(repartitionNode, GraphNode::isKeyChangingOperation); + if (shouldBeKeyChangingNode != null && shouldBeKeyChangingNode.equals(keyChangingNode)) { + return keyChangingNode; + } + return null; + } + + private String getFirstRepartitionTopicName(final Collection> repartitionNodes) { + return repartitionNodes.iterator().next().repartitionTopic(); + } + + @SuppressWarnings("unchecked") + private GroupedInternal getRepartitionSerdes(final Collection> repartitionNodes) { + Serde keySerde = null; + Serde valueSerde = null; + + for (final OptimizableRepartitionNode repartitionNode : repartitionNodes) { + if (keySerde == null && repartitionNode.keySerde() != null) { + keySerde = (Serde) repartitionNode.keySerde(); + } + + if (valueSerde == null && repartitionNode.valueSerde() != null) { + valueSerde = (Serde) repartitionNode.valueSerde(); + } + + if (keySerde != null && valueSerde != null) { + break; + } + } + + return new GroupedInternal<>(Grouped.with(keySerde, valueSerde)); + } + + private GraphNode findParentNodeMatching(final GraphNode startSeekingNode, + final Predicate parentNodePredicate) { + if (parentNodePredicate.test(startSeekingNode)) { + return startSeekingNode; + } + GraphNode foundParentNode = null; + + for (final GraphNode parentNode : startSeekingNode.parentNodes()) { + if (parentNodePredicate.test(parentNode)) { + return parentNode; + } + foundParentNode = findParentNodeMatching(parentNode, parentNodePredicate); + } + return foundParentNode; + } + + public GraphNode root() { + return root; + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/JoinWindowsInternal.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/JoinWindowsInternal.java new file mode 100644 index 0000000..60061ad --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/JoinWindowsInternal.java @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.kstream.internals; + +import org.apache.kafka.streams.kstream.JoinWindows; + +public class JoinWindowsInternal extends JoinWindows { + + public JoinWindowsInternal(final JoinWindows joinWindows) { + super(joinWindows); + } + + public boolean spuriousResultFixEnabled() { + return enableSpuriousResultFix; + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/JoinedInternal.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/JoinedInternal.java new file mode 100644 index 0000000..eb58840 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/JoinedInternal.java @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.kstream.internals; + +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.streams.kstream.Joined; + +public class JoinedInternal extends Joined { + + JoinedInternal(final Joined joined) { + super(joined); + } + + public Serde keySerde() { + return keySerde; + } + + public Serde valueSerde() { + return valueSerde; + } + + public Serde otherValueSerde() { + return otherValueSerde; + } + + public String name() { + return name; + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KGroupedStreamImpl.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KGroupedStreamImpl.java new file mode 100644 index 0000000..d56caed --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KGroupedStreamImpl.java @@ -0,0 +1,254 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.kstream.internals; + +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.streams.kstream.Aggregator; +import org.apache.kafka.streams.kstream.CogroupedKStream; +import org.apache.kafka.streams.kstream.Initializer; +import org.apache.kafka.streams.kstream.KGroupedStream; +import org.apache.kafka.streams.kstream.KTable; +import org.apache.kafka.streams.kstream.Materialized; +import org.apache.kafka.streams.kstream.Named; +import org.apache.kafka.streams.kstream.Reducer; +import org.apache.kafka.streams.kstream.SessionWindowedKStream; +import org.apache.kafka.streams.kstream.SessionWindows; +import org.apache.kafka.streams.kstream.SlidingWindows; +import org.apache.kafka.streams.kstream.TimeWindowedKStream; +import org.apache.kafka.streams.kstream.Window; +import org.apache.kafka.streams.kstream.Windows; +import org.apache.kafka.streams.kstream.internals.graph.GraphNode; +import org.apache.kafka.streams.state.KeyValueStore; + +import java.util.Objects; +import java.util.Set; + +class KGroupedStreamImpl extends AbstractStream implements KGroupedStream { + + static final String REDUCE_NAME = "KSTREAM-REDUCE-"; + static final String AGGREGATE_NAME = "KSTREAM-AGGREGATE-"; + + private final GroupedStreamAggregateBuilder aggregateBuilder; + final boolean repartitionRequired; + final String userProvidedRepartitionTopicName; + + KGroupedStreamImpl(final String name, + final Set subTopologySourceNodes, + final GroupedInternal groupedInternal, + final boolean repartitionRequired, + final GraphNode graphNode, + final InternalStreamsBuilder builder) { + super(name, groupedInternal.keySerde(), groupedInternal.valueSerde(), subTopologySourceNodes, graphNode, builder); + this.repartitionRequired = repartitionRequired; + this.userProvidedRepartitionTopicName = groupedInternal.name(); + this.aggregateBuilder = new GroupedStreamAggregateBuilder<>( + builder, + groupedInternal, + repartitionRequired, + subTopologySourceNodes, + name, + graphNode + ); + } + + @Override + public KTable reduce(final Reducer reducer) { + return reduce(reducer, Materialized.with(keySerde, valueSerde)); + } + + @Override + public KTable reduce(final Reducer reducer, + final Materialized> materialized) { + return reduce(reducer, NamedInternal.empty(), materialized); + } + + @Override + public KTable reduce(final Reducer reducer, + final Named named, + final Materialized> materialized) { + Objects.requireNonNull(reducer, "reducer can't be null"); + Objects.requireNonNull(materialized, "materialized can't be null"); + Objects.requireNonNull(named, "name can't be null"); + + final MaterializedInternal> materializedInternal = + new MaterializedInternal<>(materialized, builder, REDUCE_NAME); + + if (materializedInternal.keySerde() == null) { + materializedInternal.withKeySerde(keySerde); + } + if (materializedInternal.valueSerde() == null) { + materializedInternal.withValueSerde(valueSerde); + } + + final String name = new NamedInternal(named).orElseGenerateWithPrefix(builder, REDUCE_NAME); + return doAggregate( + new KStreamReduce<>(materializedInternal.storeName(), reducer), + name, + materializedInternal + ); + } + + @Override + public KTable aggregate(final Initializer initializer, + final Aggregator aggregator, + final Materialized> materialized) { + return aggregate(initializer, aggregator, NamedInternal.empty(), materialized); + } + + @Override + public KTable aggregate(final Initializer initializer, + final Aggregator 
aggregator, + final Named named, + final Materialized> materialized) { + Objects.requireNonNull(initializer, "initializer can't be null"); + Objects.requireNonNull(aggregator, "aggregator can't be null"); + Objects.requireNonNull(materialized, "materialized can't be null"); + Objects.requireNonNull(named, "named can't be null"); + + final MaterializedInternal> materializedInternal = + new MaterializedInternal<>(materialized, builder, AGGREGATE_NAME); + + if (materializedInternal.keySerde() == null) { + materializedInternal.withKeySerde(keySerde); + } + + final String name = new NamedInternal(named).orElseGenerateWithPrefix(builder, AGGREGATE_NAME); + return doAggregate( + new KStreamAggregate<>(materializedInternal.storeName(), initializer, aggregator), + name, + materializedInternal + ); + } + + @Override + public KTable aggregate(final Initializer initializer, + final Aggregator aggregator) { + return aggregate(initializer, aggregator, Materialized.with(keySerde, null)); + } + + @Override + public KTable count() { + return doCount(NamedInternal.empty(), Materialized.with(keySerde, Serdes.Long())); + } + + @Override + public KTable count(final Named named) { + Objects.requireNonNull(named, "named can't be null"); + return doCount(named, Materialized.with(keySerde, Serdes.Long())); + } + + @Override + public KTable count(final Materialized> materialized) { + return count(NamedInternal.empty(), materialized); + } + + @Override + public KTable count(final Named named, final Materialized> materialized) { + Objects.requireNonNull(materialized, "materialized can't be null"); + + // TODO: remove this when we do a topology-incompatible release + // we used to burn a topology name here, so we have to keep doing it for compatibility + if (new MaterializedInternal<>(materialized).storeName() == null) { + builder.newStoreName(AGGREGATE_NAME); + } + + return doCount(named, materialized); + } + + private KTable doCount(final Named named, final Materialized> materialized) { + final MaterializedInternal> materializedInternal = + new MaterializedInternal<>(materialized, builder, AGGREGATE_NAME); + + if (materializedInternal.keySerde() == null) { + materializedInternal.withKeySerde(keySerde); + } + if (materializedInternal.valueSerde() == null) { + materializedInternal.withValueSerde(Serdes.Long()); + } + + final String name = new NamedInternal(named).orElseGenerateWithPrefix(builder, AGGREGATE_NAME); + return doAggregate( + new KStreamAggregate<>(materializedInternal.storeName(), aggregateBuilder.countInitializer, aggregateBuilder.countAggregator), + name, + materializedInternal); + } + + @Override + public TimeWindowedKStream windowedBy(final Windows windows) { + + return new TimeWindowedKStreamImpl<>( + windows, + builder, + subTopologySourceNodes, + name, + keySerde, + valueSerde, + aggregateBuilder, + graphNode + ); + } + + @Override + public TimeWindowedKStream windowedBy(final SlidingWindows windows) { + + return new SlidingWindowedKStreamImpl<>( + windows, + builder, + subTopologySourceNodes, + name, + keySerde, + valueSerde, + aggregateBuilder, + graphNode + ); + } + + @Override + public SessionWindowedKStream windowedBy(final SessionWindows windows) { + + return new SessionWindowedKStreamImpl<>( + windows, + builder, + subTopologySourceNodes, + name, + keySerde, + valueSerde, + aggregateBuilder, + graphNode + ); + } + + private KTable doAggregate(final KStreamAggProcessorSupplier aggregateSupplier, + final String functionName, + final MaterializedInternal> materializedInternal) { + return 
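// A minimal usage sketch of the KGroupedStream#aggregate() overloads implemented above,
// which all funnel into doAggregate(). The topic name "purchases" and the summing
// aggregator are illustrative assumptions.
import org.apache.kafka.common.serialization.Serdes;
import org.apache.kafka.streams.StreamsBuilder;
import org.apache.kafka.streams.kstream.Consumed;
import org.apache.kafka.streams.kstream.Grouped;
import org.apache.kafka.streams.kstream.KTable;
import org.apache.kafka.streams.kstream.Materialized;

class GroupedStreamAggregateSketch {
    static KTable<String, Long> totalsPerKey(final StreamsBuilder builder) {
        return builder.stream("purchases", Consumed.with(Serdes.String(), Serdes.Long()))
                      .groupByKey(Grouped.with(Serdes.String(), Serdes.Long()))
                      .aggregate(
                          () -> 0L,                                      // Initializer
                          (key, value, aggregate) -> aggregate + value,  // Aggregator
                          Materialized.with(Serdes.String(), Serdes.Long()));
    }
}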
aggregateBuilder.build( + new NamedInternal(functionName), + new TimestampedKeyValueStoreMaterializer<>(materializedInternal).materialize(), + aggregateSupplier, + materializedInternal.queryableStoreName(), + materializedInternal.keySerde(), + materializedInternal.valueSerde()); + } + + @Override + public CogroupedKStream cogroup(final Aggregator aggregator) { + Objects.requireNonNull(aggregator, "aggregator can't be null"); + return new CogroupedKStreamImpl(name, subTopologySourceNodes, graphNode, builder) + .cogroup(this, aggregator); + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KGroupedTableImpl.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KGroupedTableImpl.java new file mode 100644 index 0000000..da35dac --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KGroupedTableImpl.java @@ -0,0 +1,247 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.kstream.internals; + +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.streams.kstream.Aggregator; +import org.apache.kafka.streams.kstream.Initializer; +import org.apache.kafka.streams.kstream.KGroupedTable; +import org.apache.kafka.streams.kstream.KTable; +import org.apache.kafka.streams.kstream.Materialized; +import org.apache.kafka.streams.kstream.Named; +import org.apache.kafka.streams.kstream.Reducer; +import org.apache.kafka.streams.kstream.internals.graph.GroupedTableOperationRepartitionNode; +import org.apache.kafka.streams.kstream.internals.graph.ProcessorParameters; +import org.apache.kafka.streams.kstream.internals.graph.StatefulProcessorNode; +import org.apache.kafka.streams.kstream.internals.graph.GraphNode; +import org.apache.kafka.streams.processor.api.ProcessorSupplier; +import org.apache.kafka.streams.state.KeyValueStore; + +import java.util.Collections; +import java.util.Objects; +import java.util.Set; + +/** + * The implementation class of {@link KGroupedTable}. 
+ * + * @param the key type + * @param the value type + */ +public class KGroupedTableImpl extends AbstractStream implements KGroupedTable { + + private static final String AGGREGATE_NAME = "KTABLE-AGGREGATE-"; + + private static final String REDUCE_NAME = "KTABLE-REDUCE-"; + + private final String userProvidedRepartitionTopicName; + + private final Initializer countInitializer = () -> 0L; + + private final Aggregator countAdder = (aggKey, value, aggregate) -> aggregate + 1L; + + private final Aggregator countSubtractor = (aggKey, value, aggregate) -> aggregate - 1L; + + private GraphNode repartitionGraphNode; + + KGroupedTableImpl(final InternalStreamsBuilder builder, + final String name, + final Set subTopologySourceNodes, + final GroupedInternal groupedInternal, + final GraphNode graphNode) { + super(name, groupedInternal.keySerde(), groupedInternal.valueSerde(), subTopologySourceNodes, graphNode, builder); + + this.userProvidedRepartitionTopicName = groupedInternal.name(); + } + + private KTable doAggregate(final ProcessorSupplier, K, Change> aggregateSupplier, + final NamedInternal named, + final String functionName, + final MaterializedInternal> materialized) { + + final String sinkName = named.suffixWithOrElseGet("-sink", builder, KStreamImpl.SINK_NAME); + final String sourceName = named.suffixWithOrElseGet("-source", builder, KStreamImpl.SOURCE_NAME); + final String funcName = named.orElseGenerateWithPrefix(builder, functionName); + final String repartitionTopic = (userProvidedRepartitionTopicName != null ? userProvidedRepartitionTopicName : materialized.storeName()) + + KStreamImpl.REPARTITION_TOPIC_SUFFIX; + + if (repartitionGraphNode == null || userProvidedRepartitionTopicName == null) { + repartitionGraphNode = createRepartitionNode(sinkName, sourceName, repartitionTopic); + } + + + // the passed in StreamsGraphNode must be the parent of the repartition node + builder.addGraphNode(this.graphNode, repartitionGraphNode); + + final StatefulProcessorNode statefulProcessorNode = new StatefulProcessorNode<>( + funcName, + new ProcessorParameters<>(aggregateSupplier, funcName), + new TimestampedKeyValueStoreMaterializer<>(materialized).materialize() + ); + + // now the repartition node must be the parent of the StateProcessorNode + builder.addGraphNode(repartitionGraphNode, statefulProcessorNode); + + // return the KTable representation with the intermediate topic as the sources + return new KTableImpl<>(funcName, + materialized.keySerde(), + materialized.valueSerde(), + Collections.singleton(sourceName), + materialized.queryableStoreName(), + aggregateSupplier, + statefulProcessorNode, + builder); + } + + private GroupedTableOperationRepartitionNode createRepartitionNode(final String sinkName, + final String sourceName, + final String topic) { + + return GroupedTableOperationRepartitionNode.groupedTableOperationNodeBuilder() + .withRepartitionTopic(topic) + .withSinkName(sinkName) + .withSourceName(sourceName) + .withKeySerde(keySerde) + .withValueSerde(valueSerde) + .withNodeName(sourceName).build(); + } + + @Override + public KTable reduce(final Reducer adder, + final Reducer subtractor, + final Materialized> materialized) { + return reduce(adder, subtractor, NamedInternal.empty(), materialized); + } + + @Override + public KTable reduce(final Reducer adder, + final Reducer subtractor, + final Named named, + final Materialized> materialized) { + Objects.requireNonNull(adder, "adder can't be null"); + Objects.requireNonNull(subtractor, "subtractor can't be null"); + 
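// A minimal usage sketch of the adder/subtractor pair validated above: an update to the
// upstream KTable both retracts the old value and applies the new one, so every
// KGroupedTable aggregation needs both callbacks. The topic name "balances" and the
// two-character region prefix are illustrative assumptions.
import org.apache.kafka.common.serialization.Serdes;
import org.apache.kafka.streams.KeyValue;
import org.apache.kafka.streams.StreamsBuilder;
import org.apache.kafka.streams.kstream.Consumed;
import org.apache.kafka.streams.kstream.Grouped;
import org.apache.kafka.streams.kstream.KTable;

class GroupedTableReduceSketch {
    static KTable<String, Long> totalsPerRegion(final StreamsBuilder builder) {
        final KTable<String, Long> balances =
            builder.table("balances", Consumed.with(Serdes.String(), Serdes.Long()));

        return balances
            .groupBy((account, amount) -> KeyValue.pair(account.substring(0, 2), amount),
                     Grouped.with(Serdes.String(), Serdes.Long()))
            .reduce(Long::sum,                                      // adder: value joins the group
                    (aggregate, oldValue) -> aggregate - oldValue); // subtractor: old value leaves it
    }
}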
Objects.requireNonNull(named, "named can't be null"); + Objects.requireNonNull(materialized, "materialized can't be null"); + final MaterializedInternal> materializedInternal = + new MaterializedInternal<>(materialized, builder, AGGREGATE_NAME); + + if (materializedInternal.keySerde() == null) { + materializedInternal.withKeySerde(keySerde); + } + if (materializedInternal.valueSerde() == null) { + materializedInternal.withValueSerde(valueSerde); + } + final ProcessorSupplier, K, Change> aggregateSupplier = new KTableReduce<>( + materializedInternal.storeName(), + adder, + subtractor); + return doAggregate(aggregateSupplier, new NamedInternal(named), REDUCE_NAME, materializedInternal); + } + + @Override + public KTable reduce(final Reducer adder, + final Reducer subtractor) { + return reduce(adder, subtractor, Materialized.with(keySerde, valueSerde)); + } + + + @Override + public KTable count(final Materialized> materialized) { + return count(NamedInternal.empty(), materialized); + } + + @Override + public KTable count(final Named named, final Materialized> materialized) { + final MaterializedInternal> materializedInternal = + new MaterializedInternal<>(materialized, builder, AGGREGATE_NAME); + + if (materializedInternal.keySerde() == null) { + materializedInternal.withKeySerde(keySerde); + } + if (materializedInternal.valueSerde() == null) { + materializedInternal.withValueSerde(Serdes.Long()); + } + + final ProcessorSupplier, K, Change> aggregateSupplier = new KTableAggregate<>( + materializedInternal.storeName(), + countInitializer, + countAdder, + countSubtractor); + + return doAggregate(aggregateSupplier, new NamedInternal(named), AGGREGATE_NAME, materializedInternal); + } + + @Override + public KTable count() { + return count(Materialized.with(keySerde, Serdes.Long())); + } + + @Override + public KTable count(final Named named) { + return count(named, Materialized.with(keySerde, Serdes.Long())); + } + + @Override + public KTable aggregate(final Initializer initializer, + final Aggregator adder, + final Aggregator subtractor, + final Materialized> materialized) { + return aggregate(initializer, adder, subtractor, NamedInternal.empty(), materialized); + } + + @Override + public KTable aggregate(final Initializer initializer, + final Aggregator adder, + final Aggregator subtractor, + final Named named, + final Materialized> materialized) { + Objects.requireNonNull(initializer, "initializer can't be null"); + Objects.requireNonNull(adder, "adder can't be null"); + Objects.requireNonNull(subtractor, "subtractor can't be null"); + Objects.requireNonNull(named, "named can't be null"); + Objects.requireNonNull(materialized, "materialized can't be null"); + + final MaterializedInternal> materializedInternal = + new MaterializedInternal<>(materialized, builder, AGGREGATE_NAME); + + if (materializedInternal.keySerde() == null) { + materializedInternal.withKeySerde(keySerde); + } + final ProcessorSupplier, K, Change> aggregateSupplier = new KTableAggregate<>( + materializedInternal.storeName(), + initializer, + adder, + subtractor); + return doAggregate(aggregateSupplier, new NamedInternal(named), AGGREGATE_NAME, materializedInternal); + } + + @Override + public KTable aggregate(final Initializer initializer, + final Aggregator adder, + final Aggregator subtractor, + final Named named) { + return aggregate(initializer, adder, subtractor, named, Materialized.with(keySerde, null)); + } + + @Override + public KTable aggregate(final Initializer initializer, + final Aggregator adder, + final 
Aggregator subtractor) { + return aggregate(initializer, adder, subtractor, Materialized.with(keySerde, null)); + } + +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KStreamAggProcessorSupplier.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KStreamAggProcessorSupplier.java new file mode 100644 index 0000000..7d8ff94 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KStreamAggProcessorSupplier.java @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.kstream.internals; + +import org.apache.kafka.streams.processor.api.ProcessorSupplier; + +public interface KStreamAggProcessorSupplier extends ProcessorSupplier> { + + KTableValueGetterSupplier view(); + + void enableSendingOldValues(); +} + diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KStreamAggregate.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KStreamAggregate.java new file mode 100644 index 0000000..8e0a910 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KStreamAggregate.java @@ -0,0 +1,154 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.kstream.internals; + +import org.apache.kafka.common.metrics.Sensor; +import org.apache.kafka.streams.kstream.Aggregator; +import org.apache.kafka.streams.kstream.Initializer; +import org.apache.kafka.streams.processor.api.ContextualProcessor; +import org.apache.kafka.streams.processor.api.Processor; +import org.apache.kafka.streams.processor.api.ProcessorContext; +import org.apache.kafka.streams.processor.api.Record; +import org.apache.kafka.streams.processor.api.RecordMetadata; +import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl; +import org.apache.kafka.streams.state.TimestampedKeyValueStore; +import org.apache.kafka.streams.state.ValueAndTimestamp; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import static org.apache.kafka.streams.processor.internals.metrics.TaskMetrics.droppedRecordsSensor; +import static org.apache.kafka.streams.state.ValueAndTimestamp.getValueOrNull; + +public class KStreamAggregate implements KStreamAggProcessorSupplier { + + private static final Logger LOG = LoggerFactory.getLogger(KStreamAggregate.class); + + private final String storeName; + private final Initializer initializer; + private final Aggregator aggregator; + + private boolean sendOldValues = false; + + KStreamAggregate(final String storeName, + final Initializer initializer, + final Aggregator aggregator) { + this.storeName = storeName; + this.initializer = initializer; + this.aggregator = aggregator; + } + + @Override + public Processor> get() { + return new KStreamAggregateProcessor(); + } + + @Override + public void enableSendingOldValues() { + sendOldValues = true; + } + + + private class KStreamAggregateProcessor extends ContextualProcessor> { + private TimestampedKeyValueStore store; + private Sensor droppedRecordsSensor; + private TimestampedTupleForwarder tupleForwarder; + + @Override + public void init(final ProcessorContext> context) { + super.init(context); + droppedRecordsSensor = droppedRecordsSensor( + Thread.currentThread().getName(), + context.taskId().toString(), + (StreamsMetricsImpl) context.metrics()); + store = context.getStateStore(storeName); + tupleForwarder = new TimestampedTupleForwarder<>( + store, + context, + new TimestampedCacheFlushListener<>(context), + sendOldValues); + } + + @Override + public void process(final Record record) { + // If the key or value is null we don't need to proceed + if (record.key() == null || record.value() == null) { + if (context().recordMetadata().isPresent()) { + final RecordMetadata recordMetadata = context().recordMetadata().get(); + LOG.warn( + "Skipping record due to null key or value. " + + "topic=[{}] partition=[{}] offset=[{}]", + recordMetadata.topic(), recordMetadata.partition(), recordMetadata.offset() + ); + } else { + LOG.warn( + "Skipping record due to null key or value. Topic, partition, and offset not known." 
+ ); + } + droppedRecordsSensor.record(); + return; + } + + final ValueAndTimestamp oldAggAndTimestamp = store.get(record.key()); + VAgg oldAgg = getValueOrNull(oldAggAndTimestamp); + + final VAgg newAgg; + final long newTimestamp; + + if (oldAgg == null) { + oldAgg = initializer.apply(); + newTimestamp = record.timestamp(); + } else { + oldAgg = oldAggAndTimestamp.value(); + newTimestamp = Math.max(record.timestamp(), oldAggAndTimestamp.timestamp()); + } + + newAgg = aggregator.apply(record.key(), record.value(), oldAgg); + + store.put(record.key(), ValueAndTimestamp.make(newAgg, newTimestamp)); + tupleForwarder.maybeForward(record.key(), newAgg, sendOldValues ? oldAgg : null, newTimestamp); + } + } + + @Override + public KTableValueGetterSupplier view() { + return new KTableValueGetterSupplier() { + + public KTableValueGetter get() { + return new KStreamAggregateValueGetter(); + } + + @Override + public String[] storeNames() { + return new String[]{storeName}; + } + }; + } + + private class KStreamAggregateValueGetter implements KTableValueGetter { + private TimestampedKeyValueStore store; + + @Override + public void init(final org.apache.kafka.streams.processor.ProcessorContext context) { + store = context.getStateStore(storeName); + } + + @Override + public ValueAndTimestamp get(final KIn key) { + return store.get(key); + } + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KStreamBranch.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KStreamBranch.java new file mode 100644 index 0000000..2d3fc76 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KStreamBranch.java @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.kstream.internals; + +import java.util.List; +import org.apache.kafka.streams.kstream.Predicate; +import org.apache.kafka.streams.processor.api.ContextualProcessor; +import org.apache.kafka.streams.processor.api.Processor; +import org.apache.kafka.streams.processor.api.ProcessorSupplier; +import org.apache.kafka.streams.processor.api.Record; + +class KStreamBranch implements ProcessorSupplier { + + private final List> predicates; + private final List childNodes; + + KStreamBranch(final List> predicates, + final List childNodes) { + this.predicates = predicates; + this.childNodes = childNodes; + } + + @Override + public Processor get() { + return new KStreamBranchProcessor(); + } + + private class KStreamBranchProcessor extends ContextualProcessor { + + @Override + public void process(final Record record) { + for (int i = 0; i < predicates.size(); i++) { + if (predicates.get(i).test(record.key(), record.value())) { + // use forward with child here and then break the loop + // so that no record is going to be piped to multiple streams + context().forward(record, childNodes.get(i)); + return; + } + } + // using default child node if supplied + if (childNodes.size() > predicates.size()) { + context().forward(record, childNodes.get(predicates.size())); + } + } + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KStreamFilter.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KStreamFilter.java new file mode 100644 index 0000000..ffafd10 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KStreamFilter.java @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
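// A minimal usage sketch of the branching behaviour implemented by KStreamBranchProcessor
// above: each record is forwarded to the first matching predicate only, and an optional
// trailing child receives the records that match none. Topic and branch names are
// illustrative assumptions.
import java.util.Map;

import org.apache.kafka.streams.StreamsBuilder;
import org.apache.kafka.streams.kstream.Branched;
import org.apache.kafka.streams.kstream.KStream;
import org.apache.kafka.streams.kstream.Named;

class BranchSketch {
    static Map<String, KStream<String, Long>> bySize(final StreamsBuilder builder) {
        final KStream<String, Long> orders = builder.stream("orders");
        return orders
            .split(Named.as("size-"))
            .branch((key, value) -> value > 1_000L, Branched.as("large"))  // first match wins
            .branch((key, value) -> value > 100L, Branched.as("medium"))   // only checked if not "large"
            .defaultBranch(Branched.as("other"));                          // the trailing default child
    }
}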
+ */ +package org.apache.kafka.streams.kstream.internals; + +import org.apache.kafka.streams.kstream.Predicate; +import org.apache.kafka.streams.processor.api.ContextualProcessor; +import org.apache.kafka.streams.processor.api.Processor; +import org.apache.kafka.streams.processor.api.ProcessorSupplier; +import org.apache.kafka.streams.processor.api.Record; + +class KStreamFilter implements ProcessorSupplier { + + private final Predicate predicate; + private final boolean filterNot; + + public KStreamFilter(final Predicate predicate, final boolean filterNot) { + this.predicate = predicate; + this.filterNot = filterNot; + } + + @Override + public Processor get() { + return new KStreamFilterProcessor(); + } + + private class KStreamFilterProcessor extends ContextualProcessor { + @Override + public void process(final Record record) { + if (filterNot ^ predicate.test(record.key(), record.value())) { + context().forward(record); + } + } + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KStreamFlatMap.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KStreamFlatMap.java new file mode 100644 index 0000000..501e951 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KStreamFlatMap.java @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.kstream.internals; + +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.kstream.KeyValueMapper; +import org.apache.kafka.streams.processor.api.ContextualProcessor; +import org.apache.kafka.streams.processor.api.Processor; +import org.apache.kafka.streams.processor.api.ProcessorSupplier; +import org.apache.kafka.streams.processor.api.Record; + +import java.util.Objects; + +class KStreamFlatMap implements ProcessorSupplier { + + private final KeyValueMapper>> mapper; + + KStreamFlatMap(final KeyValueMapper>> mapper) { + this.mapper = mapper; + } + + @Override + public Processor get() { + return new KStreamFlatMapProcessor(); + } + + private class KStreamFlatMapProcessor extends ContextualProcessor { + @Override + public void process(final Record record) { + final Iterable> newKeyValues = + mapper.apply(record.key(), record.value()); + Objects.requireNonNull(newKeyValues, "The provided KeyValueMapper returned null which is not allowed."); + for (final KeyValue newPair : newKeyValues) { + context().forward(record.withKey(newPair.key).withValue(newPair.value)); + } + } + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KStreamFlatMapValues.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KStreamFlatMapValues.java new file mode 100644 index 0000000..1008b29 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KStreamFlatMapValues.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.kstream.internals; + +import org.apache.kafka.streams.kstream.ValueMapperWithKey; +import org.apache.kafka.streams.processor.api.ContextualProcessor; +import org.apache.kafka.streams.processor.api.Processor; +import org.apache.kafka.streams.processor.api.ProcessorSupplier; +import org.apache.kafka.streams.processor.api.Record; + +class KStreamFlatMapValues implements ProcessorSupplier { + + private final ValueMapperWithKey> mapper; + + KStreamFlatMapValues(final ValueMapperWithKey> mapper) { + this.mapper = mapper; + } + + @Override + public Processor get() { + return new KStreamFlatMapValuesProcessor(); + } + + private class KStreamFlatMapValuesProcessor extends ContextualProcessor { + @Override + public void process(final Record record) { + final Iterable newValues = mapper.apply(record.key(), record.value()); + for (final VOut v : newValues) { + context().forward(record.withValue(v)); + } + } + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KStreamFlatTransform.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KStreamFlatTransform.java new file mode 100644 index 0000000..4d4fd2b --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KStreamFlatTransform.java @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.kstream.internals; + +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.kstream.Transformer; +import org.apache.kafka.streams.kstream.TransformerSupplier; +import org.apache.kafka.streams.state.StoreBuilder; + +import java.util.Set; + +@SuppressWarnings("deprecation") // Old PAPI. Needs to be migrated. 
+public class KStreamFlatTransform implements org.apache.kafka.streams.processor.ProcessorSupplier { + + private final TransformerSupplier>> transformerSupplier; + + public KStreamFlatTransform(final TransformerSupplier>> transformerSupplier) { + this.transformerSupplier = transformerSupplier; + } + + @Override + public org.apache.kafka.streams.processor.Processor get() { + return new KStreamFlatTransformProcessor<>(transformerSupplier.get()); + } + + @Override + public Set> stores() { + return transformerSupplier.stores(); + } + + public static class KStreamFlatTransformProcessor extends org.apache.kafka.streams.processor.AbstractProcessor { + + private final Transformer>> transformer; + + public KStreamFlatTransformProcessor(final Transformer>> transformer) { + this.transformer = transformer; + } + + @Override + public void init(final org.apache.kafka.streams.processor.ProcessorContext context) { + super.init(context); + transformer.init(context); + } + + @Override + public void process(final KIn key, final VIn value) { + final Iterable> pairs = transformer.transform(key, value); + if (pairs != null) { + for (final KeyValue pair : pairs) { + context().forward(pair.key, pair.value); + } + } + } + + @Override + public void close() { + transformer.close(); + } + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KStreamFlatTransformValues.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KStreamFlatTransformValues.java new file mode 100644 index 0000000..baf44c7 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KStreamFlatTransformValues.java @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.kstream.internals; + +import org.apache.kafka.streams.kstream.ValueTransformerWithKey; +import org.apache.kafka.streams.kstream.ValueTransformerWithKeySupplier; +import org.apache.kafka.streams.processor.internals.ForwardingDisabledProcessorContext; +import org.apache.kafka.streams.state.StoreBuilder; + +import java.util.Set; + +@SuppressWarnings("deprecation") // Old PAPI. Needs to be migrated. 
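// A minimal sketch of a TransformerSupplier that the KStreamFlatTransform adapter above can
// wrap via KStream#flatTransform(): it fans one CSV-valued record out into one record per
// field. The class name and the CSV format are illustrative assumptions.
import java.util.ArrayList;
import java.util.List;

import org.apache.kafka.streams.KeyValue;
import org.apache.kafka.streams.kstream.Transformer;
import org.apache.kafka.streams.kstream.TransformerSupplier;
import org.apache.kafka.streams.processor.ProcessorContext;

class CsvSplittingTransformerSupplier implements TransformerSupplier<String, String, Iterable<KeyValue<String, String>>> {
    @Override
    public Transformer<String, String, Iterable<KeyValue<String, String>>> get() {
        return new Transformer<String, String, Iterable<KeyValue<String, String>>>() {
            @Override
            public void init(final ProcessorContext context) { }

            @Override
            public Iterable<KeyValue<String, String>> transform(final String key, final String value) {
                final List<KeyValue<String, String>> records = new ArrayList<>();
                for (final String field : value.split(",")) {
                    records.add(KeyValue.pair(key, field));
                }
                return records;
            }

            @Override
            public void close() { }
        };
    }
}
// Illustrative call site: stream.flatTransform(new CsvSplittingTransformerSupplier());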
+public class KStreamFlatTransformValues implements org.apache.kafka.streams.processor.ProcessorSupplier { + + private final ValueTransformerWithKeySupplier> valueTransformerSupplier; + + public KStreamFlatTransformValues(final ValueTransformerWithKeySupplier> valueTransformerWithKeySupplier) { + this.valueTransformerSupplier = valueTransformerWithKeySupplier; + } + + @Override + public org.apache.kafka.streams.processor.Processor get() { + return new KStreamFlatTransformValuesProcessor<>(valueTransformerSupplier.get()); + } + + @Override + public Set> stores() { + return valueTransformerSupplier.stores(); + } + + public static class KStreamFlatTransformValuesProcessor extends org.apache.kafka.streams.processor.AbstractProcessor { + + private final ValueTransformerWithKey> valueTransformer; + + KStreamFlatTransformValuesProcessor(final ValueTransformerWithKey> valueTransformer) { + this.valueTransformer = valueTransformer; + } + + @Override + public void init(final org.apache.kafka.streams.processor.ProcessorContext context) { + super.init(context); + valueTransformer.init(new ForwardingDisabledProcessorContext(context)); + } + + @Override + public void process(final KIn key, final VIn value) { + final Iterable transformedValues = valueTransformer.transform(key, value); + if (transformedValues != null) { + for (final VOut transformedValue : transformedValues) { + context.forward(key, transformedValue); + } + } + } + + @Override + public void close() { + super.close(); + valueTransformer.close(); + } + } + +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KStreamGlobalKTableJoin.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KStreamGlobalKTableJoin.java new file mode 100644 index 0000000..be3fe6b --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KStreamGlobalKTableJoin.java @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.kstream.internals; + +import org.apache.kafka.streams.kstream.KeyValueMapper; +import org.apache.kafka.streams.kstream.ValueJoinerWithKey; + +@SuppressWarnings("deprecation") // Old PAPI. Needs to be migrated. 
+class KStreamGlobalKTableJoin implements org.apache.kafka.streams.processor.ProcessorSupplier { + + private final KTableValueGetterSupplier valueGetterSupplier; + private final ValueJoinerWithKey joiner; + private final KeyValueMapper mapper; + private final boolean leftJoin; + + KStreamGlobalKTableJoin(final KTableValueGetterSupplier valueGetterSupplier, + final ValueJoinerWithKey joiner, + final KeyValueMapper mapper, + final boolean leftJoin) { + this.valueGetterSupplier = valueGetterSupplier; + this.joiner = joiner; + this.mapper = mapper; + this.leftJoin = leftJoin; + } + + @Override + public org.apache.kafka.streams.processor.Processor get() { + return new KStreamKTableJoinProcessor<>(valueGetterSupplier.get(), mapper, joiner, leftJoin); + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KStreamImpl.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KStreamImpl.java new file mode 100644 index 0000000..f7075f6 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KStreamImpl.java @@ -0,0 +1,1527 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.kstream.internals; + +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.internals.ApiUtils; +import org.apache.kafka.streams.kstream.BranchedKStream; +import org.apache.kafka.streams.kstream.ForeachAction; +import org.apache.kafka.streams.kstream.GlobalKTable; +import org.apache.kafka.streams.kstream.Grouped; +import org.apache.kafka.streams.kstream.JoinWindows; +import org.apache.kafka.streams.kstream.Joined; +import org.apache.kafka.streams.kstream.KGroupedStream; +import org.apache.kafka.streams.kstream.KStream; +import org.apache.kafka.streams.kstream.KTable; +import org.apache.kafka.streams.kstream.KeyValueMapper; +import org.apache.kafka.streams.kstream.Materialized; +import org.apache.kafka.streams.kstream.Named; +import org.apache.kafka.streams.kstream.Predicate; +import org.apache.kafka.streams.kstream.Printed; +import org.apache.kafka.streams.kstream.Produced; +import org.apache.kafka.streams.kstream.Repartitioned; +import org.apache.kafka.streams.kstream.StreamJoined; +import org.apache.kafka.streams.kstream.TransformerSupplier; +import org.apache.kafka.streams.kstream.ValueJoiner; +import org.apache.kafka.streams.kstream.ValueJoinerWithKey; +import org.apache.kafka.streams.kstream.ValueMapper; +import org.apache.kafka.streams.kstream.ValueMapperWithKey; +import org.apache.kafka.streams.kstream.ValueTransformerSupplier; +import org.apache.kafka.streams.kstream.ValueTransformerWithKeySupplier; +import org.apache.kafka.streams.kstream.internals.graph.BaseRepartitionNode; +import org.apache.kafka.streams.kstream.internals.graph.BaseRepartitionNode.BaseRepartitionNodeBuilder; +import org.apache.kafka.streams.kstream.internals.graph.OptimizableRepartitionNode; +import org.apache.kafka.streams.kstream.internals.graph.OptimizableRepartitionNode.OptimizableRepartitionNodeBuilder; +import org.apache.kafka.streams.kstream.internals.graph.ProcessorGraphNode; +import org.apache.kafka.streams.kstream.internals.graph.ProcessorParameters; +import org.apache.kafka.streams.kstream.internals.graph.StatefulProcessorNode; +import org.apache.kafka.streams.kstream.internals.graph.StreamSinkNode; +import org.apache.kafka.streams.kstream.internals.graph.StreamTableJoinNode; +import org.apache.kafka.streams.kstream.internals.graph.StreamToTableNode; +import org.apache.kafka.streams.kstream.internals.graph.GraphNode; +import org.apache.kafka.streams.kstream.internals.graph.UnoptimizableRepartitionNode; +import org.apache.kafka.streams.kstream.internals.graph.UnoptimizableRepartitionNode.UnoptimizableRepartitionNodeBuilder; +import org.apache.kafka.streams.processor.FailOnInvalidTimestamp; +import org.apache.kafka.streams.processor.api.ProcessorSupplier; +import org.apache.kafka.streams.processor.StreamPartitioner; +import org.apache.kafka.streams.processor.TopicNameExtractor; +import org.apache.kafka.streams.kstream.ForeachProcessor; +import org.apache.kafka.streams.processor.internals.InternalTopicProperties; +import org.apache.kafka.streams.processor.internals.StaticTopicNameExtractor; +import org.apache.kafka.streams.state.KeyValueStore; + +import java.lang.reflect.Array; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; +import java.util.Objects; +import java.util.Set; + +import static org.apache.kafka.streams.kstream.internals.graph.OptimizableRepartitionNode.optimizableRepartitionNodeBuilder; + +public 
class KStreamImpl extends AbstractStream implements KStream { + + static final String JOINTHIS_NAME = "KSTREAM-JOINTHIS-"; + + static final String JOINOTHER_NAME = "KSTREAM-JOINOTHER-"; + + static final String JOIN_NAME = "KSTREAM-JOIN-"; + + static final String LEFTJOIN_NAME = "KSTREAM-LEFTJOIN-"; + + static final String MERGE_NAME = "KSTREAM-MERGE-"; + + static final String OUTERTHIS_NAME = "KSTREAM-OUTERTHIS-"; + + static final String OUTEROTHER_NAME = "KSTREAM-OUTEROTHER-"; + + static final String WINDOWED_NAME = "KSTREAM-WINDOWED-"; + + static final String OUTERSHARED_NAME = "KSTREAM-OUTERSHARED-"; + + static final String SOURCE_NAME = "KSTREAM-SOURCE-"; + + static final String SINK_NAME = "KSTREAM-SINK-"; + + static final String REPARTITION_TOPIC_SUFFIX = "-repartition"; + + private static final String BRANCH_NAME = "KSTREAM-BRANCH-"; + + private static final String BRANCHCHILD_NAME = "KSTREAM-BRANCHCHILD-"; + + private static final String FILTER_NAME = "KSTREAM-FILTER-"; + + private static final String PEEK_NAME = "KSTREAM-PEEK-"; + + private static final String FLATMAP_NAME = "KSTREAM-FLATMAP-"; + + private static final String FLATMAPVALUES_NAME = "KSTREAM-FLATMAPVALUES-"; + + private static final String MAP_NAME = "KSTREAM-MAP-"; + + private static final String MAPVALUES_NAME = "KSTREAM-MAPVALUES-"; + + private static final String PROCESSOR_NAME = "KSTREAM-PROCESSOR-"; + + private static final String PRINTING_NAME = "KSTREAM-PRINTER-"; + + private static final String KEY_SELECT_NAME = "KSTREAM-KEY-SELECT-"; + + private static final String TRANSFORM_NAME = "KSTREAM-TRANSFORM-"; + + private static final String TRANSFORMVALUES_NAME = "KSTREAM-TRANSFORMVALUES-"; + + private static final String FOREACH_NAME = "KSTREAM-FOREACH-"; + + private static final String TO_KTABLE_NAME = "KSTREAM-TOTABLE-"; + + private static final String REPARTITION_NAME = "KSTREAM-REPARTITION-"; + + private final boolean repartitionRequired; + + private OptimizableRepartitionNode repartitionNode; + + KStreamImpl(final String name, + final Serde keySerde, + final Serde valueSerde, + final Set subTopologySourceNodes, + final boolean repartitionRequired, + final GraphNode graphNode, + final InternalStreamsBuilder builder) { + super(name, keySerde, valueSerde, subTopologySourceNodes, graphNode, builder); + this.repartitionRequired = repartitionRequired; + } + + @Override + public KStream filter(final Predicate predicate) { + return filter(predicate, NamedInternal.empty()); + } + + @Override + public KStream filter(final Predicate predicate, + final Named named) { + Objects.requireNonNull(predicate, "predicate can't be null"); + Objects.requireNonNull(named, "named can't be null"); + + final String name = new NamedInternal(named).orElseGenerateWithPrefix(builder, FILTER_NAME); + final ProcessorParameters processorParameters = + new ProcessorParameters<>(new KStreamFilter<>(predicate, false), name); + final ProcessorGraphNode filterProcessorNode = + new ProcessorGraphNode<>(name, processorParameters); + + builder.addGraphNode(graphNode, filterProcessorNode); + + return new KStreamImpl<>( + name, + keySerde, + valueSerde, + subTopologySourceNodes, + repartitionRequired, + filterProcessorNode, + builder); + } + + @Override + public KStream filterNot(final Predicate predicate) { + return filterNot(predicate, NamedInternal.empty()); + } + + @Override + public KStream filterNot(final Predicate predicate, + final Named named) { + Objects.requireNonNull(predicate, "predicate can't be null"); + Objects.requireNonNull(named, 
"named can't be null"); + + final String name = new NamedInternal(named).orElseGenerateWithPrefix(builder, FILTER_NAME); + final ProcessorParameters processorParameters = + new ProcessorParameters<>(new KStreamFilter<>(predicate, true), name); + final ProcessorGraphNode filterNotProcessorNode = + new ProcessorGraphNode<>(name, processorParameters); + + builder.addGraphNode(graphNode, filterNotProcessorNode); + + return new KStreamImpl<>( + name, + keySerde, + valueSerde, + subTopologySourceNodes, + repartitionRequired, + filterNotProcessorNode, + builder); + } + + @Override + public KStream selectKey(final KeyValueMapper mapper) { + return selectKey(mapper, NamedInternal.empty()); + } + + @Override + public KStream selectKey(final KeyValueMapper mapper, + final Named named) { + Objects.requireNonNull(mapper, "mapper can't be null"); + Objects.requireNonNull(named, "named can't be null"); + + final ProcessorGraphNode selectKeyProcessorNode = internalSelectKey(mapper, new NamedInternal(named)); + selectKeyProcessorNode.keyChangingOperation(true); + + builder.addGraphNode(graphNode, selectKeyProcessorNode); + + // key serde cannot be preserved + return new KStreamImpl<>( + selectKeyProcessorNode.nodeName(), + null, + valueSerde, + subTopologySourceNodes, + true, + selectKeyProcessorNode, + builder); + } + + private ProcessorGraphNode internalSelectKey(final KeyValueMapper mapper, + final NamedInternal named) { + final String name = named.orElseGenerateWithPrefix(builder, KEY_SELECT_NAME); + final KStreamMap kStreamMap = + new KStreamMap<>((key, value) -> new KeyValue<>(mapper.apply(key, value), value)); + final ProcessorParameters processorParameters = new ProcessorParameters<>(kStreamMap, name); + + return new ProcessorGraphNode<>(name, processorParameters); + } + + @Override + public KStream map(final KeyValueMapper> mapper) { + return map(mapper, NamedInternal.empty()); + } + + @Override + public KStream map(final KeyValueMapper> mapper, + final Named named) { + Objects.requireNonNull(mapper, "mapper can't be null"); + Objects.requireNonNull(named, "named can't be null"); + + final String name = new NamedInternal(named).orElseGenerateWithPrefix(builder, MAP_NAME); + final ProcessorParameters processorParameters = + new ProcessorParameters<>(new KStreamMap<>(mapper), name); + final ProcessorGraphNode mapProcessorNode = + new ProcessorGraphNode<>(name, processorParameters); + mapProcessorNode.keyChangingOperation(true); + + builder.addGraphNode(graphNode, mapProcessorNode); + + // key and value serde cannot be preserved + return new KStreamImpl<>( + name, + null, + null, + subTopologySourceNodes, + true, + mapProcessorNode, + builder); + } + + @Override + public KStream mapValues(final ValueMapper valueMapper) { + return mapValues(withKey(valueMapper)); + } + + @Override + public KStream mapValues(final ValueMapper mapper, + final Named named) { + return mapValues(withKey(mapper), named); + } + + @Override + public KStream mapValues(final ValueMapperWithKey valueMapperWithKey) { + return mapValues(valueMapperWithKey, NamedInternal.empty()); + } + + @Override + public KStream mapValues(final ValueMapperWithKey valueMapperWithKey, + final Named named) { + Objects.requireNonNull(valueMapperWithKey, "valueMapperWithKey can't be null"); + Objects.requireNonNull(named, "named can't be null"); + + final String name = new NamedInternal(named).orElseGenerateWithPrefix(builder, MAPVALUES_NAME); + final ProcessorParameters processorParameters = + new ProcessorParameters<>(new 
KStreamMapValues<>(valueMapperWithKey), name); + final ProcessorGraphNode mapValuesProcessorNode = + new ProcessorGraphNode<>(name, processorParameters); + mapValuesProcessorNode.setValueChangingOperation(true); + + builder.addGraphNode(graphNode, mapValuesProcessorNode); + + // value serde cannot be preserved + return new KStreamImpl<>( + name, + keySerde, + null, + subTopologySourceNodes, + repartitionRequired, + mapValuesProcessorNode, + builder); + } + + @Override + public KStream flatMap(final KeyValueMapper>> mapper) { + return flatMap(mapper, NamedInternal.empty()); + } + + @Override + public KStream flatMap(final KeyValueMapper>> mapper, + final Named named) { + Objects.requireNonNull(mapper, "mapper can't be null"); + Objects.requireNonNull(named, "named can't be null"); + + final String name = new NamedInternal(named).orElseGenerateWithPrefix(builder, FLATMAP_NAME); + final ProcessorParameters processorParameters = + new ProcessorParameters<>(new KStreamFlatMap<>(mapper), name); + final ProcessorGraphNode flatMapNode = + new ProcessorGraphNode<>(name, processorParameters); + flatMapNode.keyChangingOperation(true); + + builder.addGraphNode(graphNode, flatMapNode); + + // key and value serde cannot be preserved + return new KStreamImpl<>(name, null, null, subTopologySourceNodes, true, flatMapNode, builder); + } + + @Override + public KStream flatMapValues(final ValueMapper> mapper) { + return flatMapValues(withKey(mapper)); + } + + @Override + public KStream flatMapValues(final ValueMapper> mapper, + final Named named) { + return flatMapValues(withKey(mapper), named); + } + + @Override + public KStream flatMapValues(final ValueMapperWithKey> mapper) { + return flatMapValues(mapper, NamedInternal.empty()); + } + + @Override + public KStream flatMapValues(final ValueMapperWithKey> valueMapper, + final Named named) { + Objects.requireNonNull(valueMapper, "valueMapper can't be null"); + Objects.requireNonNull(named, "named can't be null"); + + final String name = new NamedInternal(named).orElseGenerateWithPrefix(builder, FLATMAPVALUES_NAME); + final ProcessorParameters processorParameters = + new ProcessorParameters<>(new KStreamFlatMapValues<>(valueMapper), name); + final ProcessorGraphNode flatMapValuesNode = + new ProcessorGraphNode<>(name, processorParameters); + flatMapValuesNode.setValueChangingOperation(true); + + builder.addGraphNode(graphNode, flatMapValuesNode); + + // value serde cannot be preserved + return new KStreamImpl<>( + name, + keySerde, + null, + subTopologySourceNodes, + repartitionRequired, + flatMapValuesNode, + builder); + } + + @Override + public void print(final Printed printed) { + Objects.requireNonNull(printed, "printed can't be null"); + + final PrintedInternal printedInternal = new PrintedInternal<>(printed); + final String name = new NamedInternal(printedInternal.name()).orElseGenerateWithPrefix(builder, PRINTING_NAME); + final ProcessorParameters processorParameters = + new ProcessorParameters<>(printedInternal.build(this.name), name); + final ProcessorGraphNode printNode = + new ProcessorGraphNode<>(name, processorParameters); + + builder.addGraphNode(graphNode, printNode); + } + + @Override + public void foreach(final ForeachAction action) { + foreach(action, NamedInternal.empty()); + } + + @Override + public void foreach(final ForeachAction action, + final Named named) { + Objects.requireNonNull(action, "action can't be null"); + Objects.requireNonNull(named, "named can't be null"); + + final String name = new 
NamedInternal(named).orElseGenerateWithPrefix(builder, FOREACH_NAME); + final ProcessorParameters processorParameters = + new ProcessorParameters<>(() -> new ForeachProcessor<>(action), name); + final ProcessorGraphNode foreachNode = + new ProcessorGraphNode<>(name, processorParameters); + + builder.addGraphNode(graphNode, foreachNode); + } + + @Override + public KStream peek(final ForeachAction action) { + return peek(action, NamedInternal.empty()); + } + + @Override + public KStream peek(final ForeachAction action, + final Named named) { + Objects.requireNonNull(action, "action can't be null"); + Objects.requireNonNull(named, "named can't be null"); + + final String name = new NamedInternal(named).orElseGenerateWithPrefix(builder, PEEK_NAME); + final ProcessorParameters processorParameters = + new ProcessorParameters<>(new KStreamPeek<>(action), name); + final ProcessorGraphNode peekNode = + new ProcessorGraphNode<>(name, processorParameters); + + builder.addGraphNode(graphNode, peekNode); + + return new KStreamImpl<>( + name, + keySerde, + valueSerde, + subTopologySourceNodes, + repartitionRequired, + peekNode, + builder); + } + + @Deprecated + @Override + @SuppressWarnings("unchecked") + public KStream[] branch(final Predicate... predicates) { + return doBranch(NamedInternal.empty(), predicates); + } + + @Deprecated + @Override + @SuppressWarnings("unchecked") + public KStream[] branch(final Named named, + final Predicate... predicates) { + Objects.requireNonNull(named, "named can't be null"); + return doBranch(new NamedInternal(named), predicates); + } + + @SuppressWarnings({"unchecked", "rawtypes"}) + private KStream[] doBranch(final NamedInternal named, + final Predicate... predicates) { + Objects.requireNonNull(predicates, "predicates can't be a null array"); + if (predicates.length == 0) { + throw new IllegalArgumentException("branch() requires at least one predicate"); + } + for (final Predicate predicate : predicates) { + Objects.requireNonNull(predicate, "predicates can't be null"); + } + + final String branchName = named.orElseGenerateWithPrefix(builder, BRANCH_NAME); + final String[] childNames = new String[predicates.length]; + for (int i = 0; i < predicates.length; i++) { + childNames[i] = named.suffixWithOrElseGet("-predicate-" + i, builder, BRANCHCHILD_NAME); + } + + final ProcessorParameters processorParameters = + new ProcessorParameters<>(new KStreamBranch(Arrays.asList(predicates.clone()), + Arrays.asList(childNames)), branchName); + final ProcessorGraphNode branchNode = + new ProcessorGraphNode<>(branchName, processorParameters); + + builder.addGraphNode(graphNode, branchNode); + + final KStream[] branchChildren = (KStream[]) Array.newInstance(KStream.class, predicates.length); + for (int i = 0; i < predicates.length; i++) { + final ProcessorParameters innerProcessorParameters = + new ProcessorParameters<>(new PassThrough(), childNames[i]); + final ProcessorGraphNode branchChildNode = + new ProcessorGraphNode<>(childNames[i], innerProcessorParameters); + + builder.addGraphNode(branchNode, branchChildNode); + branchChildren[i] = new KStreamImpl<>(childNames[i], keySerde, valueSerde, subTopologySourceNodes, repartitionRequired, branchChildNode, builder); + } + + return branchChildren; + } + + @Override + public BranchedKStream split() { + return new BranchedKStreamImpl<>(this, repartitionRequired, NamedInternal.empty()); + } + + @Override + public BranchedKStream split(final Named named) { + Objects.requireNonNull(named, "named can't be null"); + return new 
BranchedKStreamImpl<>(this, repartitionRequired, new NamedInternal(named)); + } + + @Override + public KStream merge(final KStream stream) { + return merge(stream, NamedInternal.empty()); + } + + @Override + public KStream merge(final KStream stream, + final Named named) { + Objects.requireNonNull(stream, "stream can't be null"); + Objects.requireNonNull(named, "named can't be null"); + + return merge(builder, stream, new NamedInternal(named)); + } + + private KStream merge(final InternalStreamsBuilder builder, + final KStream stream, + final NamedInternal named) { + final KStreamImpl streamImpl = (KStreamImpl) stream; + final boolean requireRepartitioning = streamImpl.repartitionRequired || repartitionRequired; + final String name = named.orElseGenerateWithPrefix(builder, MERGE_NAME); + final Set allSubTopologySourceNodes = new HashSet<>(); + allSubTopologySourceNodes.addAll(subTopologySourceNodes); + allSubTopologySourceNodes.addAll(streamImpl.subTopologySourceNodes); + + final ProcessorParameters processorParameters = + new ProcessorParameters<>(new PassThrough<>(), name); + final ProcessorGraphNode mergeNode = + new ProcessorGraphNode<>(name, processorParameters); + mergeNode.setMergeNode(true); + + builder.addGraphNode(Arrays.asList(graphNode, streamImpl.graphNode), mergeNode); + + // drop the serde as we cannot safely use either one to represent both streams + return new KStreamImpl<>( + name, + null, + null, + allSubTopologySourceNodes, + requireRepartitioning, + mergeNode, + builder); + } + + @Deprecated + @Override + public KStream through(final String topic) { + return through(topic, Produced.with(keySerde, valueSerde, null)); + } + + @Deprecated + @Override + public KStream through(final String topic, + final Produced produced) { + Objects.requireNonNull(topic, "topic can't be null"); + Objects.requireNonNull(produced, "produced can't be null"); + + final ProducedInternal producedInternal = new ProducedInternal<>(produced); + if (producedInternal.keySerde() == null) { + producedInternal.withKeySerde(keySerde); + } + if (producedInternal.valueSerde() == null) { + producedInternal.withValueSerde(valueSerde); + } + to(topic, producedInternal); + + return builder.stream( + Collections.singleton(topic), + new ConsumedInternal<>( + producedInternal.keySerde(), + producedInternal.valueSerde(), + new FailOnInvalidTimestamp(), + null + ) + ); + } + + @Override + public KStream repartition() { + return doRepartition(Repartitioned.as(null)); + } + + @Override + public KStream repartition(final Repartitioned repartitioned) { + return doRepartition(repartitioned); + } + + private KStream doRepartition(final Repartitioned repartitioned) { + Objects.requireNonNull(repartitioned, "repartitioned can't be null"); + + final RepartitionedInternal repartitionedInternal = new RepartitionedInternal<>(repartitioned); + + final String name = repartitionedInternal.name() != null ? repartitionedInternal.name() : builder + .newProcessorName(REPARTITION_NAME); + + final Serde valueSerde = repartitionedInternal.valueSerde() == null ? this.valueSerde : repartitionedInternal.valueSerde(); + final Serde keySerde = repartitionedInternal.keySerde() == null ? 
this.keySerde : repartitionedInternal.keySerde(); + + final UnoptimizableRepartitionNodeBuilder unoptimizableRepartitionNodeBuilder = UnoptimizableRepartitionNode + .unoptimizableRepartitionNodeBuilder(); + + final InternalTopicProperties internalTopicProperties = repartitionedInternal.toInternalTopicProperties(); + + final String repartitionSourceName = createRepartitionedSource( + builder, + repartitionedInternal.keySerde(), + valueSerde, + name, + repartitionedInternal.streamPartitioner(), + unoptimizableRepartitionNodeBuilder.withInternalTopicProperties(internalTopicProperties) + ); + + final UnoptimizableRepartitionNode unoptimizableRepartitionNode = unoptimizableRepartitionNodeBuilder.build(); + + builder.addGraphNode(graphNode, unoptimizableRepartitionNode); + + final Set sourceNodes = new HashSet<>(); + sourceNodes.add(unoptimizableRepartitionNode.nodeName()); + + return new KStreamImpl<>( + repartitionSourceName, + keySerde, + valueSerde, + Collections.unmodifiableSet(sourceNodes), + false, + unoptimizableRepartitionNode, + builder + ); + } + + @Override + public void to(final String topic) { + to(topic, Produced.with(keySerde, valueSerde, null)); + } + + @Override + public void to(final String topic, + final Produced produced) { + Objects.requireNonNull(topic, "topic can't be null"); + Objects.requireNonNull(produced, "produced can't be null"); + + final ProducedInternal producedInternal = new ProducedInternal<>(produced); + if (producedInternal.keySerde() == null) { + producedInternal.withKeySerde(keySerde); + } + if (producedInternal.valueSerde() == null) { + producedInternal.withValueSerde(valueSerde); + } + to(new StaticTopicNameExtractor<>(topic), producedInternal); + } + + @Override + public void to(final TopicNameExtractor topicExtractor) { + to(topicExtractor, Produced.with(keySerde, valueSerde, null)); + } + + @Override + public void to(final TopicNameExtractor topicExtractor, + final Produced produced) { + Objects.requireNonNull(topicExtractor, "topicExtractor can't be null"); + Objects.requireNonNull(produced, "produced can't be null"); + + final ProducedInternal producedInternal = new ProducedInternal<>(produced); + if (producedInternal.keySerde() == null) { + producedInternal.withKeySerde(keySerde); + } + if (producedInternal.valueSerde() == null) { + producedInternal.withValueSerde(valueSerde); + } + to(topicExtractor, producedInternal); + } + + private void to(final TopicNameExtractor topicExtractor, + final ProducedInternal produced) { + final String name = new NamedInternal(produced.name()).orElseGenerateWithPrefix(builder, SINK_NAME); + final StreamSinkNode sinkNode = new StreamSinkNode<>( + name, + topicExtractor, + produced + ); + + builder.addGraphNode(graphNode, sinkNode); + } + + @Override + public KTable toTable() { + return toTable(NamedInternal.empty(), Materialized.with(keySerde, valueSerde)); + } + + @Override + public KTable toTable(final Named named) { + return toTable(named, Materialized.with(keySerde, valueSerde)); + } + + @Override + public KTable toTable(final Materialized> materialized) { + return toTable(NamedInternal.empty(), materialized); + } + + @Override + public KTable toTable(final Named named, + final Materialized> materialized) { + Objects.requireNonNull(named, "named can't be null"); + Objects.requireNonNull(materialized, "materialized can't be null"); + + final NamedInternal namedInternal = new NamedInternal(named); + final String name = namedInternal.orElseGenerateWithPrefix(builder, TO_KTABLE_NAME); + + final MaterializedInternal> 
materializedInternal = + new MaterializedInternal<>(materialized, builder, TO_KTABLE_NAME); + + final Serde keySerdeOverride = materializedInternal.keySerde() == null + ? keySerde + : materializedInternal.keySerde(); + final Serde valueSerdeOverride = materializedInternal.valueSerde() == null + ? valueSerde + : materializedInternal.valueSerde(); + + final Set subTopologySourceNodes; + final GraphNode tableParentNode; + + if (repartitionRequired) { + final OptimizableRepartitionNodeBuilder repartitionNodeBuilder = optimizableRepartitionNodeBuilder(); + final String sourceName = createRepartitionedSource( + builder, + keySerdeOverride, + valueSerdeOverride, + name, + null, + repartitionNodeBuilder + ); + + tableParentNode = repartitionNodeBuilder.build(); + builder.addGraphNode(graphNode, tableParentNode); + subTopologySourceNodes = Collections.singleton(sourceName); + } else { + tableParentNode = graphNode; + subTopologySourceNodes = this.subTopologySourceNodes; + } + + final KTableSource tableSource = new KTableSource<>( + materializedInternal.storeName(), + materializedInternal.queryableStoreName() + ); + final ProcessorParameters processorParameters = new ProcessorParameters<>(tableSource, name); + final GraphNode tableNode = new StreamToTableNode<>( + name, + processorParameters, + materializedInternal + ); + + builder.addGraphNode(tableParentNode, tableNode); + + return new KTableImpl( + name, + keySerdeOverride, + valueSerdeOverride, + subTopologySourceNodes, + materializedInternal.queryableStoreName(), + tableSource, + tableNode, + builder + ); + } + + @Override + public KGroupedStream groupBy(final KeyValueMapper keySelector) { + return groupBy(keySelector, Grouped.with(null, valueSerde)); + } + + @Override + public KGroupedStream groupBy(final KeyValueMapper keySelector, + final Grouped grouped) { + Objects.requireNonNull(keySelector, "keySelector can't be null"); + Objects.requireNonNull(grouped, "grouped can't be null"); + + final GroupedInternal groupedInternal = new GroupedInternal<>(grouped); + final ProcessorGraphNode selectKeyMapNode = internalSelectKey(keySelector, new NamedInternal(groupedInternal.name())); + selectKeyMapNode.keyChangingOperation(true); + + builder.addGraphNode(graphNode, selectKeyMapNode); + + return new KGroupedStreamImpl<>( + selectKeyMapNode.nodeName(), + subTopologySourceNodes, + groupedInternal, + true, + selectKeyMapNode, + builder); + } + + @Override + public KGroupedStream groupByKey() { + return groupByKey(Grouped.with(keySerde, valueSerde)); + } + + @Override + public KGroupedStream groupByKey(final Grouped grouped) { + Objects.requireNonNull(grouped, "grouped can't be null"); + + final GroupedInternal groupedInternal = new GroupedInternal<>(grouped); + + return new KGroupedStreamImpl<>( + name, + subTopologySourceNodes, + groupedInternal, + repartitionRequired, + graphNode, + builder); + } + + @Override + public KStream join(final KStream otherStream, + final ValueJoiner joiner, + final JoinWindows windows) { + return join(otherStream, toValueJoinerWithKey(joiner), windows); + } + + @Override + public KStream join(final KStream otherStream, + final ValueJoinerWithKey joiner, + final JoinWindows windows) { + return join(otherStream, joiner, windows, StreamJoined.with(null, null, null)); + } + + @Override + public KStream join(final KStream otherStream, + final ValueJoiner joiner, + final JoinWindows windows, + final StreamJoined streamJoined) { + + return join(otherStream, toValueJoinerWithKey(joiner), windows, streamJoined); + } + + @Override + 
public KStream join(final KStream otherStream, + final ValueJoinerWithKey joiner, + final JoinWindows windows, + final StreamJoined streamJoined) { + + return doJoin( + otherStream, + joiner, + windows, + streamJoined, + new KStreamImplJoin(builder, false, false)); + } + + @Override + public KStream leftJoin(final KStream otherStream, + final ValueJoiner joiner, + final JoinWindows windows) { + return leftJoin(otherStream, toValueJoinerWithKey(joiner), windows); + } + + @Override + public KStream leftJoin(final KStream otherStream, + final ValueJoinerWithKey joiner, + final JoinWindows windows) { + return leftJoin(otherStream, joiner, windows, StreamJoined.with(null, null, null)); + } + + @Override + public KStream leftJoin(final KStream otherStream, + final ValueJoiner joiner, + final JoinWindows windows, + final StreamJoined streamJoined) { + return doJoin( + otherStream, + toValueJoinerWithKey(joiner), + windows, + streamJoined, + new KStreamImplJoin(builder, true, false)); + } + + @Override + public KStream leftJoin(final KStream otherStream, + final ValueJoinerWithKey joiner, + final JoinWindows windows, + final StreamJoined streamJoined) { + return doJoin( + otherStream, + joiner, + windows, + streamJoined, + new KStreamImplJoin(builder, true, false)); + } + + @Override + public KStream outerJoin(final KStream otherStream, + final ValueJoiner joiner, + final JoinWindows windows) { + return outerJoin(otherStream, toValueJoinerWithKey(joiner), windows); + } + + @Override + public KStream outerJoin(final KStream otherStream, + final ValueJoinerWithKey joiner, + final JoinWindows windows) { + return outerJoin(otherStream, joiner, windows, StreamJoined.with(null, null, null)); + } + + @Override + public KStream outerJoin(final KStream otherStream, + final ValueJoiner joiner, + final JoinWindows windows, + final StreamJoined streamJoined) { + + return outerJoin(otherStream, toValueJoinerWithKey(joiner), windows, streamJoined); + } + + @Override + public KStream outerJoin(final KStream otherStream, + final ValueJoinerWithKey joiner, + final JoinWindows windows, + final StreamJoined streamJoined) { + + return doJoin(otherStream, joiner, windows, streamJoined, new KStreamImplJoin(builder, true, true)); + } + + private KStream doJoin(final KStream otherStream, + final ValueJoinerWithKey joiner, + final JoinWindows windows, + final StreamJoined streamJoined, + final KStreamImplJoin join) { + Objects.requireNonNull(otherStream, "otherStream can't be null"); + Objects.requireNonNull(joiner, "joiner can't be null"); + Objects.requireNonNull(windows, "windows can't be null"); + Objects.requireNonNull(streamJoined, "streamJoined can't be null"); + + KStreamImpl joinThis = this; + KStreamImpl joinOther = (KStreamImpl) otherStream; + + final StreamJoinedInternal streamJoinedInternal = new StreamJoinedInternal<>(streamJoined); + final NamedInternal name = new NamedInternal(streamJoinedInternal.name()); + if (joinThis.repartitionRequired) { + final String joinThisName = joinThis.name; + final String leftJoinRepartitionTopicName = name.suffixWithOrElseGet("-left", joinThisName); + joinThis = joinThis.repartitionForJoin(leftJoinRepartitionTopicName, streamJoinedInternal.keySerde(), streamJoinedInternal.valueSerde()); + } + + if (joinOther.repartitionRequired) { + final String joinOtherName = joinOther.name; + final String rightJoinRepartitionTopicName = name.suffixWithOrElseGet("-right", joinOtherName); + joinOther = joinOther.repartitionForJoin(rightJoinRepartitionTopicName, 
streamJoinedInternal.keySerde(), streamJoinedInternal.otherValueSerde()); + } + + joinThis.ensureCopartitionWith(Collections.singleton(joinOther)); + + return join.join( + joinThis, + joinOther, + joiner, + windows, + streamJoined); + } + + /** + * Repartition a stream. This is required on join operations occurring after + * an operation that changes the key, i.e, selectKey, map(..), flatMap(..). + */ + private KStreamImpl repartitionForJoin(final String repartitionName, + final Serde keySerdeOverride, + final Serde valueSerdeOverride) { + final Serde repartitionKeySerde = keySerdeOverride != null ? keySerdeOverride : keySerde; + final Serde repartitionValueSerde = valueSerdeOverride != null ? valueSerdeOverride : valueSerde; + final OptimizableRepartitionNodeBuilder optimizableRepartitionNodeBuilder = + OptimizableRepartitionNode.optimizableRepartitionNodeBuilder(); + // we still need to create the repartitioned source each time + // as it increments the counter which + // is needed to maintain topology compatibility + final String repartitionedSourceName = createRepartitionedSource( + builder, + repartitionKeySerde, + repartitionValueSerde, + repartitionName, + null, + optimizableRepartitionNodeBuilder); + + if (repartitionNode == null || !name.equals(repartitionName)) { + repartitionNode = optimizableRepartitionNodeBuilder.build(); + builder.addGraphNode(graphNode, repartitionNode); + } + + return new KStreamImpl<>( + repartitionedSourceName, + repartitionKeySerde, + repartitionValueSerde, + Collections.singleton(repartitionedSourceName), + false, + repartitionNode, + builder); + } + + static > String createRepartitionedSource(final InternalStreamsBuilder builder, + final Serde keySerde, + final Serde valueSerde, + final String repartitionTopicNamePrefix, + final StreamPartitioner streamPartitioner, + final BaseRepartitionNodeBuilder baseRepartitionNodeBuilder) { + + final String repartitionTopicName = repartitionTopicNamePrefix.endsWith(REPARTITION_TOPIC_SUFFIX) ? 
+ repartitionTopicNamePrefix : + repartitionTopicNamePrefix + REPARTITION_TOPIC_SUFFIX; + + // Always need to generate the names to burn index counter for compatibility + final String genSinkName = builder.newProcessorName(SINK_NAME); + final String genNullKeyFilterProcessorName = builder.newProcessorName(FILTER_NAME); + final String genSourceName = builder.newProcessorName(SOURCE_NAME); + + final String sinkName; + final String sourceName; + final String nullKeyFilterProcessorName; + + if (repartitionTopicNamePrefix.matches("KSTREAM.*-[0-9]{10}")) { + sinkName = genSinkName; + sourceName = genSourceName; + nullKeyFilterProcessorName = genNullKeyFilterProcessorName; + } else { + sinkName = repartitionTopicName + "-sink"; + sourceName = repartitionTopicName + "-source"; + nullKeyFilterProcessorName = repartitionTopicName + "-filter"; + } + + final Predicate notNullKeyPredicate = (k, v) -> k != null; + final ProcessorParameters processorParameters = new ProcessorParameters<>( + new KStreamFilter<>(notNullKeyPredicate, false), + nullKeyFilterProcessorName + ); + + baseRepartitionNodeBuilder.withKeySerde(keySerde) + .withValueSerde(valueSerde) + .withSourceName(sourceName) + .withRepartitionTopic(repartitionTopicName) + .withSinkName(sinkName) + .withProcessorParameters(processorParameters) + .withStreamPartitioner(streamPartitioner) + // reusing the source name for the graph node name + // adding explicit variable as it simplifies logic + .withNodeName(sourceName); + + return sourceName; + } + + @Override + public KStream join(final KTable table, + final ValueJoiner joiner) { + return join(table, toValueJoinerWithKey(joiner)); + } + + @Override + public KStream join(final KTable table, + final ValueJoinerWithKey joiner) { + return join(table, joiner, Joined.with(null, null, null)); + } + + @Override + public KStream join(final KTable table, + final ValueJoiner joiner, + final Joined joined) { + Objects.requireNonNull(table, "table can't be null"); + Objects.requireNonNull(joiner, "joiner can't be null"); + Objects.requireNonNull(joined, "joined can't be null"); + return join(table, toValueJoinerWithKey(joiner), joined); + } + + @Override + public KStream join(final KTable table, + final ValueJoinerWithKey joiner, + final Joined joined) { + Objects.requireNonNull(table, "table can't be null"); + Objects.requireNonNull(joiner, "joiner can't be null"); + Objects.requireNonNull(joined, "joined can't be null"); + + final JoinedInternal joinedInternal = new JoinedInternal<>(joined); + final String name = joinedInternal.name(); + + if (repartitionRequired) { + final KStreamImpl thisStreamRepartitioned = repartitionForJoin( + name != null ? 
name : this.name, + joined.keySerde(), + joined.valueSerde() + ); + return thisStreamRepartitioned.doStreamTableJoin(table, joiner, joined, false); + } else { + return doStreamTableJoin(table, joiner, joined, false); + } + } + + @Override + public KStream leftJoin(final KTable table, final ValueJoiner joiner) { + return leftJoin(table, toValueJoinerWithKey(joiner)); + } + + @Override + public KStream leftJoin(final KTable table, final ValueJoinerWithKey joiner) { + return leftJoin(table, joiner, Joined.with(null, null, null)); + } + + @Override + public KStream leftJoin(final KTable table, + final ValueJoiner joiner, + final Joined joined) { + Objects.requireNonNull(table, "table can't be null"); + Objects.requireNonNull(joiner, "joiner can't be null"); + Objects.requireNonNull(joined, "joined can't be null"); + + return leftJoin(table, toValueJoinerWithKey(joiner), joined); + } + + @Override + public KStream leftJoin(final KTable table, + final ValueJoinerWithKey joiner, + final Joined joined) { + Objects.requireNonNull(table, "table can't be null"); + Objects.requireNonNull(joiner, "joiner can't be null"); + Objects.requireNonNull(joined, "joined can't be null"); + final JoinedInternal joinedInternal = new JoinedInternal<>(joined); + final String name = joinedInternal.name(); + + if (repartitionRequired) { + final KStreamImpl thisStreamRepartitioned = repartitionForJoin( + name != null ? name : this.name, + joined.keySerde(), + joined.valueSerde() + ); + return thisStreamRepartitioned.doStreamTableJoin(table, joiner, joined, true); + } else { + return doStreamTableJoin(table, joiner, joined, true); + } + } + + @Override + public KStream join(final GlobalKTable globalTable, + final KeyValueMapper keySelector, + final ValueJoiner joiner) { + return join(globalTable, keySelector, toValueJoinerWithKey(joiner)); + } + + @Override + public KStream join(final GlobalKTable globalTable, + final KeyValueMapper keySelector, + final ValueJoinerWithKey joiner) { + return globalTableJoin(globalTable, keySelector, joiner, false, NamedInternal.empty()); + } + + @Override + public KStream join(final GlobalKTable globalTable, + final KeyValueMapper keySelector, + final ValueJoiner joiner, + final Named named) { + return join(globalTable, keySelector, toValueJoinerWithKey(joiner), named); + } + + @Override + public KStream join(final GlobalKTable globalTable, + final KeyValueMapper keySelector, + final ValueJoinerWithKey joiner, + final Named named) { + return globalTableJoin(globalTable, keySelector, joiner, false, named); + } + + @Override + public KStream leftJoin(final GlobalKTable globalTable, + final KeyValueMapper keySelector, + final ValueJoiner joiner) { + return leftJoin(globalTable, keySelector, toValueJoinerWithKey(joiner)); + } + + @Override + public KStream leftJoin(final GlobalKTable globalTable, + final KeyValueMapper keySelector, + final ValueJoinerWithKey joiner) { + return globalTableJoin(globalTable, keySelector, joiner, true, NamedInternal.empty()); + } + + @Override + public KStream leftJoin(final GlobalKTable globalTable, + final KeyValueMapper keySelector, + final ValueJoiner joiner, + final Named named) { + return leftJoin(globalTable, keySelector, toValueJoinerWithKey(joiner), named); + } + + @Override + public KStream leftJoin(final GlobalKTable globalTable, + final KeyValueMapper keySelector, + final ValueJoinerWithKey joiner, + final Named named) { + return globalTableJoin(globalTable, keySelector, joiner, true, named); + } + + private KStream globalTableJoin(final GlobalKTable 
globalTable, + final KeyValueMapper keySelector, + final ValueJoinerWithKey joiner, + final boolean leftJoin, + final Named named) { + Objects.requireNonNull(globalTable, "globalTable can't be null"); + Objects.requireNonNull(keySelector, "keySelector can't be null"); + Objects.requireNonNull(joiner, "joiner can't be null"); + Objects.requireNonNull(named, "named can't be null"); + + final KTableValueGetterSupplier valueGetterSupplier = + ((GlobalKTableImpl) globalTable).valueGetterSupplier(); + final String name = new NamedInternal(named).orElseGenerateWithPrefix(builder, LEFTJOIN_NAME); + // Old PAPI. Needs to be migrated. + @SuppressWarnings("deprecation") + final org.apache.kafka.streams.processor.ProcessorSupplier processorSupplier = new KStreamGlobalKTableJoin<>( + valueGetterSupplier, + joiner, + keySelector, + leftJoin); + final ProcessorParameters processorParameters = new ProcessorParameters<>(processorSupplier, name); + final StreamTableJoinNode streamTableJoinNode = + new StreamTableJoinNode<>(name, processorParameters, new String[] {}, null); + + builder.addGraphNode(graphNode, streamTableJoinNode); + + // do not have serde for joined result + return new KStreamImpl<>( + name, + keySerde, + null, + subTopologySourceNodes, + repartitionRequired, + streamTableJoinNode, + builder); + } + + @SuppressWarnings("unchecked") + private KStream doStreamTableJoin(final KTable table, + final ValueJoinerWithKey joiner, + final Joined joined, + final boolean leftJoin) { + Objects.requireNonNull(table, "table can't be null"); + Objects.requireNonNull(joiner, "joiner can't be null"); + + final Set allSourceNodes = ensureCopartitionWith(Collections.singleton((AbstractStream) table)); + + final JoinedInternal joinedInternal = new JoinedInternal<>(joined); + final NamedInternal renamed = new NamedInternal(joinedInternal.name()); + + final String name = renamed.orElseGenerateWithPrefix(builder, leftJoin ? LEFTJOIN_NAME : JOIN_NAME); + // Old PAPI. Needs to be migrated. + @SuppressWarnings("deprecation") + final org.apache.kafka.streams.processor.ProcessorSupplier processorSupplier = new KStreamKTableJoin<>( + ((KTableImpl) table).valueGetterSupplier(), + joiner, + leftJoin); + + final ProcessorParameters processorParameters = new ProcessorParameters<>(processorSupplier, name); + final StreamTableJoinNode streamTableJoinNode = new StreamTableJoinNode<>( + name, + processorParameters, + ((KTableImpl) table).valueGetterSupplier().storeNames(), + this.name + ); + + builder.addGraphNode(graphNode, streamTableJoinNode); + + // do not have serde for joined result + return new KStreamImpl<>( + name, + joined.keySerde() != null ? joined.keySerde() : keySerde, + null, + allSourceNodes, + false, + streamTableJoinNode, + builder); + } + + @Override + public KStream transform(final TransformerSupplier> transformerSupplier, + final String... stateStoreNames) { + Objects.requireNonNull(transformerSupplier, "transformerSupplier can't be null"); + final String name = builder.newProcessorName(TRANSFORM_NAME); + return flatTransform(new TransformerSupplierAdapter<>(transformerSupplier), Named.as(name), stateStoreNames); + } + + @Override + public KStream transform(final TransformerSupplier> transformerSupplier, + final Named named, + final String... 
stateStoreNames) { + Objects.requireNonNull(transformerSupplier, "transformerSupplier can't be null"); + return flatTransform(new TransformerSupplierAdapter<>(transformerSupplier), named, stateStoreNames); + } + + @Override + public KStream flatTransform(final TransformerSupplier>> transformerSupplier, + final String... stateStoreNames) { + Objects.requireNonNull(transformerSupplier, "transformerSupplier can't be null"); + final String name = builder.newProcessorName(TRANSFORM_NAME); + return flatTransform(transformerSupplier, Named.as(name), stateStoreNames); + } + + @Override + public KStream flatTransform(final TransformerSupplier>> transformerSupplier, + final Named named, + final String... stateStoreNames) { + Objects.requireNonNull(transformerSupplier, "transformerSupplier can't be null"); + Objects.requireNonNull(named, "named can't be null"); + Objects.requireNonNull(stateStoreNames, "stateStoreNames can't be a null array"); + ApiUtils.checkSupplier(transformerSupplier); + for (final String stateStoreName : stateStoreNames) { + Objects.requireNonNull(stateStoreName, "stateStoreNames can't contain `null` as store name"); + } + + final String name = new NamedInternal(named).name(); + final StatefulProcessorNode transformNode = new StatefulProcessorNode<>( + name, + new ProcessorParameters<>(new KStreamFlatTransform<>(transformerSupplier), name), + stateStoreNames); + transformNode.keyChangingOperation(true); + + builder.addGraphNode(graphNode, transformNode); + + // cannot inherit key and value serde + return new KStreamImpl<>( + name, + null, + null, + subTopologySourceNodes, + true, + transformNode, + builder); + } + + @Override + public KStream transformValues(final ValueTransformerSupplier valueTransformerSupplier, + final String... stateStoreNames) { + Objects.requireNonNull(valueTransformerSupplier, "valueTransformerSupplier can't be null"); + return doTransformValues( + toValueTransformerWithKeySupplier(valueTransformerSupplier), + NamedInternal.empty(), + stateStoreNames); + } + + @Override + public KStream transformValues(final ValueTransformerSupplier valueTransformerSupplier, + final Named named, + final String... stateStoreNames) { + Objects.requireNonNull(valueTransformerSupplier, "valueTransformerSupplier can't be null"); + Objects.requireNonNull(named, "named can't be null"); + return doTransformValues( + toValueTransformerWithKeySupplier(valueTransformerSupplier), + new NamedInternal(named), + stateStoreNames); + } + + @Override + public KStream transformValues(final ValueTransformerWithKeySupplier valueTransformerSupplier, + final String... stateStoreNames) { + Objects.requireNonNull(valueTransformerSupplier, "valueTransformerSupplier can't be null"); + return doTransformValues(valueTransformerSupplier, NamedInternal.empty(), stateStoreNames); + } + + @Override + public KStream transformValues(final ValueTransformerWithKeySupplier valueTransformerSupplier, + final Named named, + final String... stateStoreNames) { + Objects.requireNonNull(valueTransformerSupplier, "valueTransformerSupplier can't be null"); + Objects.requireNonNull(named, "named can't be null"); + return doTransformValues(valueTransformerSupplier, new NamedInternal(named), stateStoreNames); + } + + private KStream doTransformValues(final ValueTransformerWithKeySupplier valueTransformerWithKeySupplier, + final NamedInternal named, + final String... 
stateStoreNames) { + Objects.requireNonNull(stateStoreNames, "stateStoreNames can't be a null array"); + for (final String stateStoreName : stateStoreNames) { + Objects.requireNonNull(stateStoreName, "stateStoreNames can't contain `null` as store name"); + } + ApiUtils.checkSupplier(valueTransformerWithKeySupplier); + + final String name = named.orElseGenerateWithPrefix(builder, TRANSFORMVALUES_NAME); + final StatefulProcessorNode transformNode = new StatefulProcessorNode<>( + name, + new ProcessorParameters<>(new KStreamTransformValues<>(valueTransformerWithKeySupplier), name), + stateStoreNames); + transformNode.setValueChangingOperation(true); + + builder.addGraphNode(graphNode, transformNode); + + // cannot inherit value serde + return new KStreamImpl<>( + name, + keySerde, + null, + subTopologySourceNodes, + repartitionRequired, + transformNode, + builder); + } + + @Override + public KStream flatTransformValues(final ValueTransformerSupplier> valueTransformerSupplier, + final String... stateStoreNames) { + Objects.requireNonNull(valueTransformerSupplier, "valueTransformerSupplier can't be null"); + return doFlatTransformValues( + toValueTransformerWithKeySupplier(valueTransformerSupplier), + NamedInternal.empty(), + stateStoreNames); + } + + @Override + public KStream flatTransformValues(final ValueTransformerSupplier> valueTransformerSupplier, + final Named named, + final String... stateStoreNames) { + Objects.requireNonNull(valueTransformerSupplier, "valueTransformerSupplier can't be null"); + return doFlatTransformValues( + toValueTransformerWithKeySupplier(valueTransformerSupplier), + named, + stateStoreNames); + } + + @Override + public KStream flatTransformValues(final ValueTransformerWithKeySupplier> valueTransformerSupplier, + final String... stateStoreNames) { + Objects.requireNonNull(valueTransformerSupplier, "valueTransformerSupplier can't be null"); + return doFlatTransformValues(valueTransformerSupplier, NamedInternal.empty(), stateStoreNames); + } + + @Override + public KStream flatTransformValues(final ValueTransformerWithKeySupplier> valueTransformerSupplier, + final Named named, + final String... stateStoreNames) { + Objects.requireNonNull(valueTransformerSupplier, "valueTransformerSupplier can't be null"); + return doFlatTransformValues(valueTransformerSupplier, named, stateStoreNames); + } + + private KStream doFlatTransformValues(final ValueTransformerWithKeySupplier> valueTransformerWithKeySupplier, + final Named named, + final String... 
stateStoreNames) { + Objects.requireNonNull(stateStoreNames, "stateStoreNames can't be a null array"); + for (final String stateStoreName : stateStoreNames) { + Objects.requireNonNull(stateStoreName, "stateStoreNames can't contain `null` as store name"); + } + ApiUtils.checkSupplier(valueTransformerWithKeySupplier); + + final String name = new NamedInternal(named).orElseGenerateWithPrefix(builder, TRANSFORMVALUES_NAME); + final StatefulProcessorNode transformNode = new StatefulProcessorNode<>( + name, + new ProcessorParameters<>(new KStreamFlatTransformValues<>(valueTransformerWithKeySupplier), name), + stateStoreNames); + transformNode.setValueChangingOperation(true); + + builder.addGraphNode(graphNode, transformNode); + + // cannot inherit value serde + return new KStreamImpl<>( + name, + keySerde, + null, + subTopologySourceNodes, + repartitionRequired, + transformNode, + builder); + } + + @Override + @Deprecated + public void process(final org.apache.kafka.streams.processor.ProcessorSupplier processorSupplier, + final String... stateStoreNames) { + process(processorSupplier, Named.as(builder.newProcessorName(PROCESSOR_NAME)), stateStoreNames); + } + + @Override + public void process(final ProcessorSupplier processorSupplier, + final String... stateStoreNames) { + process(processorSupplier, Named.as(builder.newProcessorName(PROCESSOR_NAME)), stateStoreNames); + } + + @Override + @Deprecated + public void process(final org.apache.kafka.streams.processor.ProcessorSupplier processorSupplier, + final Named named, + final String... stateStoreNames) { + Objects.requireNonNull(processorSupplier, "processorSupplier can't be null"); + Objects.requireNonNull(named, "named can't be null"); + Objects.requireNonNull(stateStoreNames, "stateStoreNames can't be a null array"); + ApiUtils.checkSupplier(processorSupplier); + for (final String stateStoreName : stateStoreNames) { + Objects.requireNonNull(stateStoreName, "stateStoreNames can't be null"); + } + + final String name = new NamedInternal(named).name(); + final StatefulProcessorNode processNode = new StatefulProcessorNode<>( + name, + new ProcessorParameters<>(processorSupplier, name), + stateStoreNames); + + builder.addGraphNode(graphNode, processNode); + } + + @Override + public void process(final ProcessorSupplier processorSupplier, + final Named named, + final String... stateStoreNames) { + Objects.requireNonNull(processorSupplier, "processorSupplier can't be null"); + Objects.requireNonNull(named, "named can't be null"); + Objects.requireNonNull(stateStoreNames, "stateStoreNames can't be a null array"); + ApiUtils.checkSupplier(processorSupplier); + for (final String stateStoreName : stateStoreNames) { + Objects.requireNonNull(stateStoreName, "stateStoreNames can't be null"); + } + + final String name = new NamedInternal(named).name(); + final StatefulProcessorNode processNode = new StatefulProcessorNode<>( + name, + new ProcessorParameters<>(processorSupplier, name), + stateStoreNames); + + builder.addGraphNode(graphNode, processNode); + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KStreamImplJoin.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KStreamImplJoin.java new file mode 100644 index 0000000..a7496ce --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KStreamImplJoin.java @@ -0,0 +1,338 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.streams.kstream.internals; + +import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.streams.errors.StreamsException; +import org.apache.kafka.streams.kstream.JoinWindows; +import org.apache.kafka.streams.kstream.KStream; +import org.apache.kafka.streams.kstream.StreamJoined; +import org.apache.kafka.streams.kstream.ValueJoinerWithKey; +import org.apache.kafka.streams.kstream.internals.graph.GraphNode; +import org.apache.kafka.streams.kstream.internals.graph.ProcessorGraphNode; +import org.apache.kafka.streams.kstream.internals.graph.ProcessorParameters; +import org.apache.kafka.streams.kstream.internals.graph.StreamStreamJoinNode; +import org.apache.kafka.streams.state.KeyValueStore; +import org.apache.kafka.streams.state.internals.TimestampedKeyAndJoinSide; +import org.apache.kafka.streams.state.StoreBuilder; +import org.apache.kafka.streams.state.Stores; +import org.apache.kafka.streams.state.internals.LeftOrRightValue; +import org.apache.kafka.streams.state.WindowBytesStoreSupplier; +import org.apache.kafka.streams.state.WindowStore; +import org.apache.kafka.streams.state.internals.TimestampedKeyAndJoinSideSerde; +import org.apache.kafka.streams.state.internals.ListValueStoreBuilder; +import org.apache.kafka.streams.state.internals.LeftOrRightValueSerde; + +import java.time.Duration; +import java.util.Arrays; +import java.util.HashSet; +import java.util.Map; +import java.util.Optional; +import java.util.Set; + +import static org.apache.kafka.streams.internals.ApiUtils.prepareMillisCheckFailMsgPrefix; +import static org.apache.kafka.streams.internals.ApiUtils.validateMillisecondDuration; + +class KStreamImplJoin { + + private final InternalStreamsBuilder builder; + private final boolean leftOuter; + private final boolean rightOuter; + + static class TimeTracker { + private long emitIntervalMs = 1000L; + long streamTime = ConsumerRecord.NO_TIMESTAMP; + long minTime = Long.MAX_VALUE; + long nextTimeToEmit; + + public void setEmitInterval(final long emitIntervalMs) { + this.emitIntervalMs = emitIntervalMs; + } + + public void advanceStreamTime(final long recordTimestamp) { + streamTime = Math.max(recordTimestamp, streamTime); + } + + public void updatedMinTime(final long recordTimestamp) { + minTime = Math.min(recordTimestamp, minTime); + } + + public void advanceNextTimeToEmit() { + nextTimeToEmit += emitIntervalMs; + } + } + + KStreamImplJoin(final InternalStreamsBuilder builder, + final boolean leftOuter, + final boolean rightOuter) { + this.builder = builder; + this.leftOuter = leftOuter; + this.rightOuter = rightOuter; + } + + public KStream join(final KStream lhs, + final KStream other, + final ValueJoinerWithKey joiner, + final JoinWindows 
windows, + final StreamJoined streamJoined) { + + final StreamJoinedInternal streamJoinedInternal = new StreamJoinedInternal<>(streamJoined); + final NamedInternal renamed = new NamedInternal(streamJoinedInternal.name()); + final String joinThisSuffix = rightOuter ? "-outer-this-join" : "-this-join"; + final String joinOtherSuffix = leftOuter ? "-outer-other-join" : "-other-join"; + + final String thisWindowStreamProcessorName = renamed.suffixWithOrElseGet( + "-this-windowed", builder, KStreamImpl.WINDOWED_NAME); + final String otherWindowStreamProcessorName = renamed.suffixWithOrElseGet( + "-other-windowed", builder, KStreamImpl.WINDOWED_NAME); + + final String joinThisGeneratedName = rightOuter ? builder.newProcessorName(KStreamImpl.OUTERTHIS_NAME) : builder.newProcessorName(KStreamImpl.JOINTHIS_NAME); + final String joinOtherGeneratedName = leftOuter ? builder.newProcessorName(KStreamImpl.OUTEROTHER_NAME) : builder.newProcessorName(KStreamImpl.JOINOTHER_NAME); + + final String joinThisName = renamed.suffixWithOrElseGet(joinThisSuffix, joinThisGeneratedName); + final String joinOtherName = renamed.suffixWithOrElseGet(joinOtherSuffix, joinOtherGeneratedName); + + final String joinMergeName = renamed.suffixWithOrElseGet( + "-merge", builder, KStreamImpl.MERGE_NAME); + + final GraphNode thisGraphNode = ((AbstractStream) lhs).graphNode; + final GraphNode otherGraphNode = ((AbstractStream) other).graphNode; + + final StoreBuilder> thisWindowStore; + final StoreBuilder> otherWindowStore; + final String userProvidedBaseStoreName = streamJoinedInternal.storeName(); + + final WindowBytesStoreSupplier thisStoreSupplier = streamJoinedInternal.thisStoreSupplier(); + final WindowBytesStoreSupplier otherStoreSupplier = streamJoinedInternal.otherStoreSupplier(); + + assertUniqueStoreNames(thisStoreSupplier, otherStoreSupplier); + + if (thisStoreSupplier == null) { + final String thisJoinStoreName = userProvidedBaseStoreName == null ? joinThisGeneratedName : userProvidedBaseStoreName + joinThisSuffix; + thisWindowStore = joinWindowStoreBuilder(thisJoinStoreName, windows, streamJoinedInternal.keySerde(), streamJoinedInternal.valueSerde(), streamJoinedInternal.loggingEnabled(), streamJoinedInternal.logConfig()); + } else { + assertWindowSettings(thisStoreSupplier, windows); + thisWindowStore = joinWindowStoreBuilderFromSupplier(thisStoreSupplier, streamJoinedInternal.keySerde(), streamJoinedInternal.valueSerde()); + } + + if (otherStoreSupplier == null) { + final String otherJoinStoreName = userProvidedBaseStoreName == null ? 
joinOtherGeneratedName : userProvidedBaseStoreName + joinOtherSuffix; + otherWindowStore = joinWindowStoreBuilder(otherJoinStoreName, windows, streamJoinedInternal.keySerde(), streamJoinedInternal.otherValueSerde(), streamJoinedInternal.loggingEnabled(), streamJoinedInternal.logConfig()); + } else { + assertWindowSettings(otherStoreSupplier, windows); + otherWindowStore = joinWindowStoreBuilderFromSupplier(otherStoreSupplier, streamJoinedInternal.keySerde(), streamJoinedInternal.otherValueSerde()); + } + + final KStreamJoinWindow thisWindowedStream = new KStreamJoinWindow<>(thisWindowStore.name()); + + final ProcessorParameters thisWindowStreamProcessorParams = new ProcessorParameters<>(thisWindowedStream, thisWindowStreamProcessorName); + final ProcessorGraphNode thisWindowedStreamsNode = new ProcessorGraphNode<>(thisWindowStreamProcessorName, thisWindowStreamProcessorParams); + builder.addGraphNode(thisGraphNode, thisWindowedStreamsNode); + + final KStreamJoinWindow otherWindowedStream = new KStreamJoinWindow<>(otherWindowStore.name()); + + final ProcessorParameters otherWindowStreamProcessorParams = new ProcessorParameters<>(otherWindowedStream, otherWindowStreamProcessorName); + final ProcessorGraphNode otherWindowedStreamsNode = new ProcessorGraphNode<>(otherWindowStreamProcessorName, otherWindowStreamProcessorParams); + builder.addGraphNode(otherGraphNode, otherWindowedStreamsNode); + + Optional, LeftOrRightValue>>> outerJoinWindowStore = Optional.empty(); + if (leftOuter) { + outerJoinWindowStore = Optional.of(sharedOuterJoinWindowStoreBuilder(windows, streamJoinedInternal, joinThisGeneratedName)); + } + + // Time-shared between joins to keep track of the maximum stream time + final TimeTracker sharedTimeTracker = new TimeTracker(); + + final JoinWindowsInternal internalWindows = new JoinWindowsInternal(windows); + final KStreamKStreamJoin joinThis = new KStreamKStreamJoin<>( + true, + otherWindowStore.name(), + internalWindows, + joiner, + leftOuter, + outerJoinWindowStore.map(StoreBuilder::name), + sharedTimeTracker + ); + + final KStreamKStreamJoin joinOther = new KStreamKStreamJoin<>( + false, + thisWindowStore.name(), + internalWindows, + AbstractStream.reverseJoinerWithKey(joiner), + rightOuter, + outerJoinWindowStore.map(StoreBuilder::name), + sharedTimeTracker + ); + + final PassThrough joinMerge = new PassThrough<>(); + + final StreamStreamJoinNode.StreamStreamJoinNodeBuilder joinBuilder = StreamStreamJoinNode.streamStreamJoinNodeBuilder(); + + final ProcessorParameters joinThisProcessorParams = new ProcessorParameters<>(joinThis, joinThisName); + final ProcessorParameters joinOtherProcessorParams = new ProcessorParameters<>(joinOther, joinOtherName); + final ProcessorParameters joinMergeProcessorParams = new ProcessorParameters<>(joinMerge, joinMergeName); + + joinBuilder.withJoinMergeProcessorParameters(joinMergeProcessorParams) + .withJoinThisProcessorParameters(joinThisProcessorParams) + .withJoinOtherProcessorParameters(joinOtherProcessorParams) + .withThisWindowStoreBuilder(thisWindowStore) + .withOtherWindowStoreBuilder(otherWindowStore) + .withThisWindowedStreamProcessorParameters(thisWindowStreamProcessorParams) + .withOtherWindowedStreamProcessorParameters(otherWindowStreamProcessorParams) + .withOuterJoinWindowStoreBuilder(outerJoinWindowStore) + .withValueJoiner(joiner) + .withNodeName(joinMergeName); + + if (internalWindows.spuriousResultFixEnabled()) { + joinBuilder.withSpuriousResultFixEnabled(); + } + + final GraphNode joinGraphNode = joinBuilder.build(); + 
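+        // Illustrative only (hypothetical user code, not part of this change): a DSL call such as
+        //   left.join(right,
+        //             (v1, v2) -> v1 + v2,
+        //             JoinWindows.of(Duration.ofMinutes(5)),
+        //             StreamJoined.with(Serdes.String(), Serdes.String(), Serdes.String()))
+        // is assembled at this point into a single StreamStreamJoinNode: the two windowed-stream
+        // processors and their window stores built above, the two KStreamKStreamJoin processors,
+        // the optional shared outer-join store, and the PassThrough merge processor registered
+        // under joinMergeName.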
+ builder.addGraphNode(Arrays.asList(thisGraphNode, otherGraphNode), joinGraphNode); + + final Set allSourceNodes = new HashSet<>(((KStreamImpl) lhs).subTopologySourceNodes); + allSourceNodes.addAll(((KStreamImpl) other).subTopologySourceNodes); + + // do not have serde for joined result; + // also for key serde we do not inherit from either since we cannot tell if these two serdes are different + return new KStreamImpl<>(joinMergeName, streamJoinedInternal.keySerde(), null, allSourceNodes, false, joinGraphNode, builder); + } + + private void assertWindowSettings(final WindowBytesStoreSupplier supplier, final JoinWindows joinWindows) { + if (!supplier.retainDuplicates()) { + throw new StreamsException("The StoreSupplier must set retainDuplicates=true, found retainDuplicates=false"); + } + final boolean allMatch = supplier.retentionPeriod() == (joinWindows.size() + joinWindows.gracePeriodMs()) && + supplier.windowSize() == joinWindows.size(); + if (!allMatch) { + throw new StreamsException(String.format("Window settings mismatch. WindowBytesStoreSupplier settings %s must match JoinWindows settings %s" + + " for the window size and retention period", supplier, joinWindows)); + } + } + + private void assertUniqueStoreNames(final WindowBytesStoreSupplier supplier, + final WindowBytesStoreSupplier otherSupplier) { + + if (supplier != null + && otherSupplier != null + && supplier.name().equals(otherSupplier.name())) { + throw new StreamsException("Both StoreSuppliers have the same name. StoreSuppliers must provide unique names"); + } + } + + private static StoreBuilder> joinWindowStoreBuilder(final String storeName, + final JoinWindows windows, + final Serde keySerde, + final Serde valueSerde, + final boolean loggingEnabled, + final Map logConfig) { + final StoreBuilder> builder = Stores.windowStoreBuilder( + Stores.persistentWindowStore( + storeName + "-store", + Duration.ofMillis(windows.size() + windows.gracePeriodMs()), + Duration.ofMillis(windows.size()), + true + ), + keySerde, + valueSerde + ); + if (loggingEnabled) { + builder.withLoggingEnabled(logConfig); + } else { + builder.withLoggingDisabled(); + } + + return builder; + } + + private String buildOuterJoinWindowStoreName(final StreamJoinedInternal streamJoinedInternal, final String joinThisGeneratedName) { + final String outerJoinSuffix = rightOuter ? "-outer-shared-join" : "-left-shared-join"; + + if (streamJoinedInternal.thisStoreSupplier() != null && !streamJoinedInternal.thisStoreSupplier().name().isEmpty()) { + return streamJoinedInternal.thisStoreSupplier().name() + outerJoinSuffix; + } else if (streamJoinedInternal.storeName() != null) { + return streamJoinedInternal.storeName() + outerJoinSuffix; + } else { + return KStreamImpl.OUTERSHARED_NAME + + joinThisGeneratedName.substring( + rightOuter + ? 
KStreamImpl.OUTERTHIS_NAME.length() + : KStreamImpl.JOINTHIS_NAME.length()); + } + } + + private StoreBuilder, LeftOrRightValue>> sharedOuterJoinWindowStoreBuilder(final JoinWindows windows, + final StreamJoinedInternal streamJoinedInternal, + final String joinThisGeneratedName) { + final boolean persistent = streamJoinedInternal.thisStoreSupplier() == null || streamJoinedInternal.thisStoreSupplier().get().persistent(); + final String storeName = buildOuterJoinWindowStoreName(streamJoinedInternal, joinThisGeneratedName) + "-store"; + + // we are using a key-value store with list-values for the shared store, and have the window retention / grace period + // handled totally on the processor node level, and hence here we are only validating these values but not using them at all + final Duration retentionPeriod = Duration.ofMillis(windows.size() + windows.gracePeriodMs()); + final Duration windowSize = Duration.ofMillis(windows.size()); + final String rpMsgPrefix = prepareMillisCheckFailMsgPrefix(retentionPeriod, "retentionPeriod"); + final long retentionMs = validateMillisecondDuration(retentionPeriod, rpMsgPrefix); + final String wsMsgPrefix = prepareMillisCheckFailMsgPrefix(windowSize, "windowSize"); + final long windowSizeMs = validateMillisecondDuration(windowSize, wsMsgPrefix); + + if (retentionMs < 0L) { + throw new IllegalArgumentException("retentionPeriod cannot be negative"); + } + if (windowSizeMs < 0L) { + throw new IllegalArgumentException("windowSize cannot be negative"); + } + if (windowSizeMs > retentionMs) { + throw new IllegalArgumentException("The retention period of the window store " + + storeName + " must be no smaller than its window size. Got size=[" + + windowSizeMs + "], retention=[" + retentionMs + "]"); + } + + final TimestampedKeyAndJoinSideSerde timestampedKeyAndJoinSideSerde = new TimestampedKeyAndJoinSideSerde<>(streamJoinedInternal.keySerde()); + final LeftOrRightValueSerde leftOrRightValueSerde = new LeftOrRightValueSerde<>(streamJoinedInternal.valueSerde(), streamJoinedInternal.otherValueSerde()); + + final StoreBuilder, LeftOrRightValue>> builder = + new ListValueStoreBuilder<>( + persistent ? Stores.persistentKeyValueStore(storeName) : Stores.inMemoryKeyValueStore(storeName), + timestampedKeyAndJoinSideSerde, + leftOrRightValueSerde, + Time.SYSTEM + ); + + if (streamJoinedInternal.loggingEnabled()) { + builder.withLoggingEnabled(streamJoinedInternal.logConfig()); + } else { + builder.withLoggingDisabled(); + } + + return builder; + } + + private static StoreBuilder> joinWindowStoreBuilderFromSupplier(final WindowBytesStoreSupplier storeSupplier, + final Serde keySerde, + final Serde valueSerde) { + return Stores.windowStoreBuilder( + storeSupplier, + keySerde, + valueSerde + ); + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KStreamJoinWindow.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KStreamJoinWindow.java new file mode 100644 index 0000000..317943a --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KStreamJoinWindow.java @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.streams.kstream.internals;
+
+import org.apache.kafka.streams.processor.api.ContextualProcessor;
+import org.apache.kafka.streams.processor.api.Processor;
+import org.apache.kafka.streams.processor.api.ProcessorContext;
+import org.apache.kafka.streams.processor.api.ProcessorSupplier;
+import org.apache.kafka.streams.processor.api.Record;
+import org.apache.kafka.streams.state.WindowStore;
+
+class KStreamJoinWindow<K, V> implements ProcessorSupplier<K, V, K, V> {
+
+    private final String windowName;
+
+    KStreamJoinWindow(final String windowName) {
+        this.windowName = windowName;
+    }
+
+    @Override
+    public Processor<K, V, K, V> get() {
+        return new KStreamJoinWindowProcessor();
+    }
+
+    private class KStreamJoinWindowProcessor extends ContextualProcessor<K, V, K, V> {
+
+        private WindowStore<K, V> window;
+
+        @Override
+        public void init(final ProcessorContext<K, V> context) {
+            super.init(context);
+
+            window = context.getStateStore(windowName);
+        }
+
+        @Override
+        public void process(final Record<K, V> record) {
+            // if the key is null, we do not need to put the record into window store
+            // since it will never be considered for join operations
+            if (record.key() != null) {
+                context().forward(record);
+                // Every record basically starts a new window. We're using a window store mostly for the retention.
+                window.put(record.key(), record.value(), record.timestamp());
+            }
+        }
+    }
+}
diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KStreamKStreamJoin.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KStreamKStreamJoin.java
new file mode 100644
index 0000000..305cb38
--- /dev/null
+++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KStreamKStreamJoin.java
@@ -0,0 +1,280 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +package org.apache.kafka.streams.kstream.internals; + +import org.apache.kafka.common.metrics.Sensor; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.kstream.ValueJoinerWithKey; +import org.apache.kafka.streams.kstream.internals.KStreamImplJoin.TimeTracker; +import org.apache.kafka.streams.processor.api.ContextualProcessor; +import org.apache.kafka.streams.processor.api.Processor; +import org.apache.kafka.streams.processor.api.ProcessorContext; +import org.apache.kafka.streams.processor.api.ProcessorSupplier; +import org.apache.kafka.streams.processor.api.Record; +import org.apache.kafka.streams.processor.api.RecordMetadata; +import org.apache.kafka.streams.processor.internals.InternalProcessorContext; +import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl; +import org.apache.kafka.streams.state.KeyValueIterator; +import org.apache.kafka.streams.state.KeyValueStore; +import org.apache.kafka.streams.state.WindowStore; +import org.apache.kafka.streams.state.WindowStoreIterator; +import org.apache.kafka.streams.state.internals.TimestampedKeyAndJoinSide; +import org.apache.kafka.streams.state.internals.LeftOrRightValue; +import org.apache.kafka.streams.StreamsConfig; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Optional; + +import static org.apache.kafka.streams.StreamsConfig.InternalConfig.EMIT_INTERVAL_MS_KSTREAMS_OUTER_JOIN_SPURIOUS_RESULTS_FIX; +import static org.apache.kafka.streams.processor.internals.metrics.TaskMetrics.droppedRecordsSensor; + +class KStreamKStreamJoin implements ProcessorSupplier { + private static final Logger LOG = LoggerFactory.getLogger(KStreamKStreamJoin.class); + + private final String otherWindowName; + private final long joinBeforeMs; + private final long joinAfterMs; + private final long joinGraceMs; + private final boolean enableSpuriousResultFix; + + private final boolean outer; + private final boolean isLeftSide; + private final Optional outerJoinWindowName; + private final ValueJoinerWithKey joiner; + + private final TimeTracker sharedTimeTracker; + + KStreamKStreamJoin(final boolean isLeftSide, + final String otherWindowName, + final JoinWindowsInternal windows, + final ValueJoinerWithKey joiner, + final boolean outer, + final Optional outerJoinWindowName, + final TimeTracker sharedTimeTracker) { + this.isLeftSide = isLeftSide; + this.otherWindowName = otherWindowName; + if (isLeftSide) { + this.joinBeforeMs = windows.beforeMs; + this.joinAfterMs = windows.afterMs; + } else { + this.joinBeforeMs = windows.afterMs; + this.joinAfterMs = windows.beforeMs; + } + this.joinGraceMs = windows.gracePeriodMs(); + this.enableSpuriousResultFix = windows.spuriousResultFixEnabled(); + this.joiner = joiner; + this.outer = outer; + this.outerJoinWindowName = outerJoinWindowName; + this.sharedTimeTracker = sharedTimeTracker; + } + + @Override + public Processor get() { + return new KStreamKStreamJoinProcessor(); + } + + private class KStreamKStreamJoinProcessor extends ContextualProcessor { + private WindowStore otherWindowStore; + private Sensor droppedRecordsSensor; + private Optional, LeftOrRightValue>> outerJoinStore = Optional.empty(); + private InternalProcessorContext internalProcessorContext; + + @Override + public void init(final ProcessorContext context) { + super.init(context); + internalProcessorContext = (InternalProcessorContext) context; + + final StreamsMetricsImpl metrics = (StreamsMetricsImpl) context.metrics(); + droppedRecordsSensor = 
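// Illustrative sketch (not part of this patch): the effect of the before/after swap in the
// KStreamKStreamJoin constructor above. For the same JoinWindows, the two sides of the join
// fetch mirrored time ranges from the other side's window store (see the timeFrom/timeTo
// computation that follows). The helper and the numbers are invented for the example.
public class JoinWindowBoundsSketch {
    // mirrors: timeFrom = max(0, ts - joinBeforeMs), timeTo = max(0, ts + joinAfterMs)
    static long[] fetchRange(final boolean isLeftSide, final long ts, final long beforeMs, final long afterMs) {
        final long joinBeforeMs = isLeftSide ? beforeMs : afterMs;
        final long joinAfterMs = isLeftSide ? afterMs : beforeMs;
        return new long[]{Math.max(0L, ts - joinBeforeMs), Math.max(0L, ts + joinAfterMs)};
    }

    public static void main(final String[] args) {
        // JoinWindows with beforeMs = 5 and afterMs = 10, record timestamp 100:
        // the left side fetches the right store over [95, 110] ...
        System.out.println(java.util.Arrays.toString(fetchRange(true, 100L, 5L, 10L)));
        // ... while the right side fetches the left store over [90, 105]
        System.out.println(java.util.Arrays.toString(fetchRange(false, 100L, 5L, 10L)));
    }
}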
droppedRecordsSensor(Thread.currentThread().getName(), context.taskId().toString(), metrics); + otherWindowStore = context.getStateStore(otherWindowName); + + if (enableSpuriousResultFix) { + outerJoinStore = outerJoinWindowName.map(context::getStateStore); + + sharedTimeTracker.setEmitInterval( + StreamsConfig.InternalConfig.getLong( + context.appConfigs(), + EMIT_INTERVAL_MS_KSTREAMS_OUTER_JOIN_SPURIOUS_RESULTS_FIX, + 1000L + ) + ); + } + } + + @SuppressWarnings("unchecked") + @Override + public void process(final Record record) { + // we do join iff keys are equal, thus, if key is null we cannot join and just ignore the record + // + // we also ignore the record if value is null, because in a key-value data model a null-value indicates + // an empty message (ie, there is nothing to be joined) -- this contrast SQL NULL semantics + // furthermore, on left/outer joins 'null' in ValueJoiner#apply() indicates a missing record -- + // thus, to be consistent and to avoid ambiguous null semantics, null values are ignored + if (record.key() == null || record.value() == null) { + if (context().recordMetadata().isPresent()) { + final RecordMetadata recordMetadata = context().recordMetadata().get(); + LOG.warn( + "Skipping record due to null key or value. " + + "topic=[{}] partition=[{}] offset=[{}]", + recordMetadata.topic(), recordMetadata.partition(), recordMetadata.offset() + ); + } else { + LOG.warn( + "Skipping record due to null key or value. Topic, partition, and offset not known." + ); + } + droppedRecordsSensor.record(); + return; + } + + boolean needOuterJoin = outer; + + final long inputRecordTimestamp = record.timestamp(); + final long timeFrom = Math.max(0L, inputRecordTimestamp - joinBeforeMs); + final long timeTo = Math.max(0L, inputRecordTimestamp + joinAfterMs); + + sharedTimeTracker.advanceStreamTime(inputRecordTimestamp); + + // Emit all non-joined records which window has closed + if (inputRecordTimestamp == sharedTimeTracker.streamTime) { + outerJoinStore.ifPresent(this::emitNonJoinedOuterRecords); + } + + try (final WindowStoreIterator iter = otherWindowStore.fetch(record.key(), timeFrom, timeTo)) { + while (iter.hasNext()) { + needOuterJoin = false; + final KeyValue otherRecord = iter.next(); + final long otherRecordTimestamp = otherRecord.key; + + outerJoinStore.ifPresent(store -> { + // use putIfAbsent to first read and see if there's any values for the key, + // if yes delete the key, otherwise do not issue a put; + // we may delete some values with the same key early but since we are going + // range over all values of the same key even after failure, since the other window-store + // is only cleaned up by stream time, so this is okay for at-least-once. + store.putIfAbsent(TimestampedKeyAndJoinSide.make(!isLeftSide, record.key(), otherRecordTimestamp), null); + }); + + context().forward( + record.withValue(joiner.apply(record.key(), record.value(), otherRecord.value)) + .withTimestamp(Math.max(inputRecordTimestamp, otherRecordTimestamp))); + } + + if (needOuterJoin) { + // The maxStreamTime contains the max time observed in both sides of the join. + // Having access to the time observed in the other join side fixes the following + // problem: + // + // Say we have a window size of 5 seconds + // 1. A non-joined record wth time T10 is seen in the left-topic (maxLeftStreamTime: 10) + // The record is not processed yet, and is added to the outer-join store + // 2. 
A non-joined record with time T2 is seen in the right-topic (maxRightStreamTime: 2)
+                    //    The record is not processed yet, and is added to the outer-join store
+                    // 3. A joined record with time T11 is seen in the left-topic (maxLeftStreamTime: 11)
+                    //    It is time to look at the expired records. T10 and T2 should be emitted, but
+                    //    because T2 was late, it is not fetched by the window store, so it is not processed
+                    //
+                    // See KStreamKStreamLeftJoinTest.testLowerWindowBound() tests
+                    //
+                    // This condition below allows us to process the out-of-order records without the need
+                    // to hold them in the temporary outer store
+                    if (!outerJoinStore.isPresent() || timeTo < sharedTimeTracker.streamTime) {
+                        context().forward(record.withValue(joiner.apply(record.key(), record.value(), null)));
+                    } else {
+                        sharedTimeTracker.updatedMinTime(inputRecordTimestamp);
+                        outerJoinStore.ifPresent(store -> store.put(
+                            TimestampedKeyAndJoinSide.make(isLeftSide, record.key(), inputRecordTimestamp),
+                            LeftOrRightValue.make(isLeftSide, record.value())));
+                    }
+                }
+            }
+        }
+
+        @SuppressWarnings("unchecked")
+        private void emitNonJoinedOuterRecords(final KeyValueStore<TimestampedKeyAndJoinSide<K>, LeftOrRightValue<V1, V2>> store) {
+            // calling `store.all()` creates an iterator, which is an expensive operation on RocksDB;
+            // to reduce runtime cost, we try to avoid paying those costs
+
+            // only try to emit left/outer join results if there _might_ be any result records
+            if (sharedTimeTracker.minTime >= sharedTimeTracker.streamTime - joinAfterMs - joinGraceMs) {
+                return;
+            }
+            // throttle the emit frequency to a (configurable) interval;
+            // we use processing time to decouple from data properties,
+            // as throttling is a non-functional performance optimization
+            if (internalProcessorContext.currentSystemTimeMs() < sharedTimeTracker.nextTimeToEmit) {
+                return;
+            }
+            if (sharedTimeTracker.nextTimeToEmit == 0) {
+                sharedTimeTracker.nextTimeToEmit = internalProcessorContext.currentSystemTimeMs();
+            }
+            sharedTimeTracker.advanceNextTimeToEmit();
+
+            // reset to MAX_VALUE in case the store is empty
+            sharedTimeTracker.minTime = Long.MAX_VALUE;
+
+            try (final KeyValueIterator<TimestampedKeyAndJoinSide<K>, LeftOrRightValue<V1, V2>> it = store.all()) {
+                TimestampedKeyAndJoinSide<K> prevKey = null;
+
+                while (it.hasNext()) {
+                    final KeyValue<TimestampedKeyAndJoinSide<K>, LeftOrRightValue<V1, V2>> record = it.next();
+
+                    final TimestampedKeyAndJoinSide<K> timestampedKeyAndJoinSide = record.key;
+                    final LeftOrRightValue<V1, V2> value = record.value;
+                    final K key = timestampedKeyAndJoinSide.getKey();
+                    final long timestamp = timestampedKeyAndJoinSide.getTimestamp();
+                    sharedTimeTracker.minTime = timestamp;
+
+                    // Skip next records if window has not closed
+                    if (timestamp + joinAfterMs + joinGraceMs >= sharedTimeTracker.streamTime) {
+                        break;
+                    }
+
+                    final VOut nullJoinedValue;
+                    if (isLeftSide) {
+                        nullJoinedValue = joiner.apply(key,
+                            value.getLeftValue(),
+                            value.getRightValue());
+                    } else {
+                        nullJoinedValue = joiner.apply(key,
+                            (V1) value.getRightValue(),
+                            (V2) value.getLeftValue());
+                    }
+
+                    context().forward(new Record<>(key, nullJoinedValue, timestamp));
+
+                    if (prevKey != null && !prevKey.equals(timestampedKeyAndJoinSide)) {
+                        // blind-delete the previous key from the outer window store now that it is emitted;
+                        // we do this because this delete would remove the whole list of values of the same key,
+                        // and hence if we delete eagerly and then fail, we would miss emitting join results of the later
+                        // values in the list.
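// Illustrative sketch (not part of this patch): the two gates applied by
// emitNonJoinedOuterRecords() above before a buffered non-joined record is emitted as a
// left/outer join result. Method names and numbers are invented for the example.
public class OuterJoinEmitGatesSketch {
    // a buffered candidate at timestamp t is only emitted once its window plus grace has
    // closed relative to stream time (mirrors: t + joinAfterMs + joinGraceMs >= streamTime -> keep buffering)
    static boolean windowClosed(final long t, final long joinAfterMs, final long joinGraceMs, final long streamTime) {
        return t + joinAfterMs + joinGraceMs < streamTime;
    }

    // the store scan itself is throttled on wall-clock time (mirrors the nextTimeToEmit check)
    static boolean emitIntervalElapsed(final long nowMs, final long nextTimeToEmitMs) {
        return nowMs >= nextTimeToEmitMs;
    }

    public static void main(final String[] args) {
        // with after = 10 and grace = 5, a candidate at t = 80 closes only once stream time passes 95
        System.out.println(windowClosed(80L, 10L, 5L, 96L)); // true
        System.out.println(windowClosed(80L, 10L, 5L, 95L)); // false
    }
}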
+ // we do not use delete() calls since it would incur extra get() + store.put(prevKey, null); + } + + prevKey = timestampedKeyAndJoinSide; + } + + // at the end of the iteration, we need to delete the last key + if (prevKey != null) { + store.put(prevKey, null); + } + } + } + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KStreamKTableJoin.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KStreamKTableJoin.java new file mode 100644 index 0000000..deca86d --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KStreamKTableJoin.java @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.kstream.internals; + +import org.apache.kafka.streams.kstream.KeyValueMapper; +import org.apache.kafka.streams.kstream.ValueJoinerWithKey; + +@SuppressWarnings("deprecation") // Old PAPI. Needs to be migrated. +class KStreamKTableJoin implements org.apache.kafka.streams.processor.ProcessorSupplier { + + private final KeyValueMapper keyValueMapper = (key, value) -> key; + private final KTableValueGetterSupplier valueGetterSupplier; + private final ValueJoinerWithKey joiner; + private final boolean leftJoin; + + KStreamKTableJoin(final KTableValueGetterSupplier valueGetterSupplier, + final ValueJoinerWithKey joiner, + final boolean leftJoin) { + this.valueGetterSupplier = valueGetterSupplier; + this.joiner = joiner; + this.leftJoin = leftJoin; + } + + @Override + public org.apache.kafka.streams.processor.Processor get() { + return new KStreamKTableJoinProcessor<>(valueGetterSupplier.get(), keyValueMapper, joiner, leftJoin); + } + +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KStreamKTableJoinProcessor.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KStreamKTableJoinProcessor.java new file mode 100644 index 0000000..38e839c --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KStreamKTableJoinProcessor.java @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.kstream.internals; + +import org.apache.kafka.common.metrics.Sensor; +import org.apache.kafka.streams.kstream.KeyValueMapper; +import org.apache.kafka.streams.kstream.ValueJoinerWithKey; +import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import static org.apache.kafka.streams.processor.internals.metrics.TaskMetrics.droppedRecordsSensor; +import static org.apache.kafka.streams.state.ValueAndTimestamp.getValueOrNull; + +@SuppressWarnings("deprecation") // Old PAPI. Needs to be migrated. +class KStreamKTableJoinProcessor extends org.apache.kafka.streams.processor.AbstractProcessor { + private static final Logger LOG = LoggerFactory.getLogger(KStreamKTableJoin.class); + + private final KTableValueGetter valueGetter; + private final KeyValueMapper keyMapper; + private final ValueJoinerWithKey joiner; + private final boolean leftJoin; + private Sensor droppedRecordsSensor; + + KStreamKTableJoinProcessor(final KTableValueGetter valueGetter, + final KeyValueMapper keyMapper, + final ValueJoinerWithKey joiner, + final boolean leftJoin) { + this.valueGetter = valueGetter; + this.keyMapper = keyMapper; + this.joiner = joiner; + this.leftJoin = leftJoin; + } + + @Override + public void init(final org.apache.kafka.streams.processor.ProcessorContext context) { + super.init(context); + final StreamsMetricsImpl metrics = (StreamsMetricsImpl) context.metrics(); + droppedRecordsSensor = droppedRecordsSensor(Thread.currentThread().getName(), context.taskId().toString(), metrics); + valueGetter.init(context); + } + + @Override + public void process(final K1 key, final V1 value) { + // we do join iff the join keys are equal, thus, if {@code keyMapper} returns {@code null} we + // cannot join and just ignore the record. Note for KTables, this is the same as having a null key + // since keyMapper just returns the key, but for GlobalKTables we can have other keyMappers + // + // we also ignore the record if value is null, because in a key-value data model a null-value indicates + // an empty message (ie, there is nothing to be joined) -- this contrast SQL NULL semantics + // furthermore, on left/outer joins 'null' in ValueJoiner#apply() indicates a missing record -- + // thus, to be consistent and to avoid ambiguous null semantics, null values are ignored + final K2 mappedKey = keyMapper.apply(key, value); + if (mappedKey == null || value == null) { + LOG.warn( + "Skipping record due to null join key or value. key=[{}] value=[{}] topic=[{}] partition=[{}] offset=[{}]", + key, value, context().topic(), context().partition(), context().offset() + ); + droppedRecordsSensor.record(); + } else { + final V2 value2 = getValueOrNull(valueGetter.get(mappedKey)); + if (leftJoin || value2 != null) { + context().forward(key, joiner.apply(key, value, value2)); + } + } + } + + @Override + public void close() { + valueGetter.close(); + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KStreamMap.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KStreamMap.java new file mode 100644 index 0000000..59087f7 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KStreamMap.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
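// Illustrative sketch (not part of this patch): the DSL calls that end up in
// KStreamKTableJoinProcessor above. An inner stream-table join only forwards when a table row
// exists, while leftJoin also forwards with a null table value; records with a null join key
// or null stream value are dropped either way. Topic names and joiners are invented.
import org.apache.kafka.streams.StreamsBuilder;
import org.apache.kafka.streams.kstream.KStream;
import org.apache.kafka.streams.kstream.KTable;

public class StreamTableJoinSketch {
    public static void main(final String[] args) {
        final StreamsBuilder builder = new StreamsBuilder();
        final KStream<String, String> orders = builder.stream("orders");
        final KTable<String, String> customers = builder.table("customers");

        // inner join: dropped when no customer row is present for the key
        orders.join(customers, (order, customer) -> order + "@" + customer)
              .to("orders-with-customer");

        // left join: the customer side may be null inside the joiner
        orders.leftJoin(customers, (order, customer) ->
                order + "@" + (customer == null ? "unknown" : customer))
              .to("orders-with-optional-customer");
        builder.build();
    }
}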
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.kstream.internals; + +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.kstream.KeyValueMapper; +import org.apache.kafka.streams.processor.api.ContextualProcessor; +import org.apache.kafka.streams.processor.api.Processor; +import org.apache.kafka.streams.processor.api.ProcessorSupplier; +import org.apache.kafka.streams.processor.api.Record; + +import java.util.Objects; + +class KStreamMap implements ProcessorSupplier { + + private final KeyValueMapper> mapper; + + public KStreamMap(final KeyValueMapper> mapper) { + this.mapper = mapper; + } + + @Override + public Processor get() { + return new KStreamMapProcessor(); + } + + private class KStreamMapProcessor extends ContextualProcessor { + + @Override + public void process(final Record record) { + final KeyValue newPair = + mapper.apply(record.key(), record.value()); + Objects.requireNonNull(newPair, "The provided KeyValueMapper returned null which is not allowed."); + context().forward(record.withKey(newPair.key).withValue(newPair.value)); + } + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KStreamMapValues.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KStreamMapValues.java new file mode 100644 index 0000000..f73bfdd --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KStreamMapValues.java @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
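// Illustrative sketch (not part of this patch): KStream#map is the DSL entry point for the
// KStreamMap processor above; the KeyValueMapper must return a non-null KeyValue, otherwise
// the Objects.requireNonNull check fires at runtime. Topic names are invented for the example.
import org.apache.kafka.streams.KeyValue;
import org.apache.kafka.streams.StreamsBuilder;
import org.apache.kafka.streams.kstream.KStream;

public class MapSketch {
    public static void main(final String[] args) {
        final StreamsBuilder builder = new StreamsBuilder();
        final KStream<String, String> words = builder.stream("words");

        // re-key by the upper-cased value; returning null here would fail the non-null check
        final KStream<String, Integer> lengths =
            words.map((key, value) -> KeyValue.pair(value.toUpperCase(), value.length()));

        lengths.to("word-lengths");
        builder.build();
    }
}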
+ */ +package org.apache.kafka.streams.kstream.internals; + +import org.apache.kafka.streams.kstream.ValueMapperWithKey; +import org.apache.kafka.streams.processor.api.ContextualProcessor; +import org.apache.kafka.streams.processor.api.Processor; +import org.apache.kafka.streams.processor.api.ProcessorSupplier; +import org.apache.kafka.streams.processor.api.Record; + +class KStreamMapValues implements ProcessorSupplier { + + private final ValueMapperWithKey mapper; + + public KStreamMapValues(final ValueMapperWithKey mapper) { + this.mapper = mapper; + } + + @Override + public Processor get() { + return new KStreamMapProcessor(); + } + + private class KStreamMapProcessor extends ContextualProcessor { + @Override + public void process(final Record record) { + final VOut newValue = mapper.apply(record.key(), record.value()); + context().forward(record.withValue(newValue)); + } + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KStreamPeek.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KStreamPeek.java new file mode 100644 index 0000000..69b5e7f --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KStreamPeek.java @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.kstream.internals; + +import org.apache.kafka.streams.kstream.ForeachAction; +import org.apache.kafka.streams.processor.api.ContextualProcessor; +import org.apache.kafka.streams.processor.api.Processor; +import org.apache.kafka.streams.processor.api.ProcessorSupplier; +import org.apache.kafka.streams.processor.api.Record; + +class KStreamPeek implements ProcessorSupplier { + + private final ForeachAction action; + + public KStreamPeek(final ForeachAction action) { + this.action = action; + } + + @Override + public Processor get() { + return new KStreamPeekProcessor(); + } + + private class KStreamPeekProcessor extends ContextualProcessor { + @Override + public void process(final Record record) { + action.apply(record.key(), record.value()); + context().forward(record); + } + } + +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KStreamPrint.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KStreamPrint.java new file mode 100644 index 0000000..a04662c --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KStreamPrint.java @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
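// Illustrative sketch (not part of this patch): KStream#mapValues and KStream#peek are the DSL
// entry points for the two processors above (KStreamMapValues and KStreamPeek). mapValues only
// rewrites the value, and peek applies a side-effecting action while forwarding the record
// unchanged. Topic names are invented for the example.
import org.apache.kafka.streams.StreamsBuilder;
import org.apache.kafka.streams.kstream.KStream;

public class MapValuesPeekSketch {
    public static void main(final String[] args) {
        final StreamsBuilder builder = new StreamsBuilder();
        final KStream<String, String> events = builder.stream("events");

        events.peek((key, value) -> System.out.println("saw " + key + " -> " + value))
              .mapValues((key, value) -> value.trim().toLowerCase())
              .to("normalized-events");
        builder.build();
    }
}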
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.kstream.internals; + +import org.apache.kafka.streams.kstream.ForeachAction; +import org.apache.kafka.streams.processor.api.Processor; +import org.apache.kafka.streams.processor.api.ProcessorSupplier; +import org.apache.kafka.streams.processor.api.Record; + +public class KStreamPrint implements ProcessorSupplier { + + private final ForeachAction action; + + public KStreamPrint(final ForeachAction action) { + this.action = action; + } + + @Override + public Processor get() { + return new KStreamPrintProcessor(); + } + + private class KStreamPrintProcessor implements Processor { + + @Override + public void process(final Record record) { + action.apply(record.key(), record.value()); + } + + @Override + public void close() { + if (action instanceof PrintForeachAction) { + ((PrintForeachAction) action).close(); + } + } + } + +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KStreamReduce.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KStreamReduce.java new file mode 100644 index 0000000..080f9a4 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KStreamReduce.java @@ -0,0 +1,150 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.kstream.internals; + +import org.apache.kafka.common.metrics.Sensor; +import org.apache.kafka.streams.kstream.Reducer; +import org.apache.kafka.streams.processor.api.ContextualProcessor; +import org.apache.kafka.streams.processor.api.Processor; +import org.apache.kafka.streams.processor.api.ProcessorContext; +import org.apache.kafka.streams.processor.api.Record; +import org.apache.kafka.streams.processor.api.RecordMetadata; +import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl; +import org.apache.kafka.streams.state.TimestampedKeyValueStore; +import org.apache.kafka.streams.state.ValueAndTimestamp; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import static org.apache.kafka.streams.processor.internals.metrics.TaskMetrics.droppedRecordsSensor; +import static org.apache.kafka.streams.state.ValueAndTimestamp.getValueOrNull; + +public class KStreamReduce implements KStreamAggProcessorSupplier { + + private static final Logger LOG = LoggerFactory.getLogger(KStreamReduce.class); + + private final String storeName; + private final Reducer reducer; + + private boolean sendOldValues = false; + + KStreamReduce(final String storeName, final Reducer reducer) { + this.storeName = storeName; + this.reducer = reducer; + } + + @Override + public Processor> get() { + return new KStreamReduceProcessor(); + } + + @Override + public void enableSendingOldValues() { + sendOldValues = true; + } + + + private class KStreamReduceProcessor extends ContextualProcessor> { + private TimestampedKeyValueStore store; + private TimestampedTupleForwarder tupleForwarder; + private Sensor droppedRecordsSensor; + + @Override + public void init(final ProcessorContext> context) { + super.init(context); + droppedRecordsSensor = droppedRecordsSensor( + Thread.currentThread().getName(), + context.taskId().toString(), + (StreamsMetricsImpl) context.metrics() + ); + store = context.getStateStore(storeName); + tupleForwarder = new TimestampedTupleForwarder<>( + store, + context, + new TimestampedCacheFlushListener<>(context), + sendOldValues); + } + + @Override + public void process(final Record record) { + // If the key or value is null we don't need to proceed + if (record.key() == null || record.value() == null) { + if (context().recordMetadata().isPresent()) { + final RecordMetadata recordMetadata = context().recordMetadata().get(); + LOG.warn( + "Skipping record due to null key or value. " + + "topic=[{}] partition=[{}] offset=[{}]", + recordMetadata.topic(), recordMetadata.partition(), recordMetadata.offset() + ); + } else { + LOG.warn( + "Skipping record due to null key. Topic, partition, and offset not known." + ); + } + droppedRecordsSensor.record(); + return; + } + + final ValueAndTimestamp oldAggAndTimestamp = store.get(record.key()); + final V oldAgg = getValueOrNull(oldAggAndTimestamp); + + final V newAgg; + final long newTimestamp; + + if (oldAgg == null) { + newAgg = record.value(); + newTimestamp = record.timestamp(); + } else { + newAgg = reducer.apply(oldAgg, record.value()); + newTimestamp = Math.max(record.timestamp(), oldAggAndTimestamp.timestamp()); + } + + store.put(record.key(), ValueAndTimestamp.make(newAgg, newTimestamp)); + tupleForwarder.maybeForward(record.key(), newAgg, sendOldValues ? 
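// Illustrative sketch (not part of this patch): the DSL call that ends up in KStreamReduce
// above. The first value seen for a key becomes the aggregate as-is; later values are folded
// in with the Reducer, and the stored timestamp is the max of the contributing record
// timestamps. Topic, store name, and serdes are invented for the example.
import org.apache.kafka.common.serialization.Serdes;
import org.apache.kafka.streams.StreamsBuilder;
import org.apache.kafka.streams.kstream.Grouped;
import org.apache.kafka.streams.kstream.KStream;
import org.apache.kafka.streams.kstream.KTable;
import org.apache.kafka.streams.kstream.Materialized;

public class ReduceSketch {
    public static void main(final String[] args) {
        final StreamsBuilder builder = new StreamsBuilder();
        final KStream<String, Long> amounts = builder.stream("amounts");

        final KTable<String, Long> totals = amounts
            .groupByKey(Grouped.with(Serdes.String(), Serdes.Long()))
            .reduce(Long::sum, Materialized.as("amount-totals"));

        totals.toStream().to("amount-totals-output");
        builder.build();
    }
}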
oldAgg : null, newTimestamp); + } + } + + @Override + public KTableValueGetterSupplier view() { + return new KTableValueGetterSupplier() { + + public KTableValueGetter get() { + return new KStreamReduceValueGetter(); + } + + @Override + public String[] storeNames() { + return new String[]{storeName}; + } + }; + } + + + private class KStreamReduceValueGetter implements KTableValueGetter { + private TimestampedKeyValueStore store; + + @Override + public void init(final org.apache.kafka.streams.processor.ProcessorContext context) { + store = context.getStateStore(storeName); + } + + @Override + public ValueAndTimestamp get(final K key) { + return store.get(key); + } + } +} + diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KStreamSessionWindowAggregate.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KStreamSessionWindowAggregate.java new file mode 100644 index 0000000..00b6959 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KStreamSessionWindowAggregate.java @@ -0,0 +1,238 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.kstream.internals; + +import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.common.metrics.Sensor; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.kstream.Aggregator; +import org.apache.kafka.streams.kstream.Initializer; +import org.apache.kafka.streams.kstream.Merger; +import org.apache.kafka.streams.kstream.SessionWindows; +import org.apache.kafka.streams.kstream.Windowed; +import org.apache.kafka.streams.processor.api.ContextualProcessor; +import org.apache.kafka.streams.processor.api.Processor; +import org.apache.kafka.streams.processor.api.ProcessorContext; +import org.apache.kafka.streams.processor.api.Record; +import org.apache.kafka.streams.processor.api.RecordMetadata; +import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl; +import org.apache.kafka.streams.state.KeyValueIterator; +import org.apache.kafka.streams.state.SessionStore; +import org.apache.kafka.streams.state.ValueAndTimestamp; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.ArrayList; +import java.util.List; + +import static org.apache.kafka.streams.processor.internals.metrics.TaskMetrics.droppedRecordsSensor; + +public class KStreamSessionWindowAggregate implements KStreamAggProcessorSupplier, VAgg> { + + private static final Logger LOG = LoggerFactory.getLogger(KStreamSessionWindowAggregate.class); + + private final String storeName; + private final SessionWindows windows; + private final Initializer initializer; + private final Aggregator aggregator; + private final Merger sessionMerger; + + private boolean sendOldValues = false; + + public KStreamSessionWindowAggregate(final SessionWindows windows, + final String storeName, + final Initializer initializer, + final Aggregator aggregator, + final Merger sessionMerger) { + this.windows = windows; + this.storeName = storeName; + this.initializer = initializer; + this.aggregator = aggregator; + this.sessionMerger = sessionMerger; + } + + @Override + public Processor, Change> get() { + return new KStreamSessionWindowAggregateProcessor(); + } + + public SessionWindows windows() { + return windows; + } + + @Override + public void enableSendingOldValues() { + sendOldValues = true; + } + + private class KStreamSessionWindowAggregateProcessor extends + ContextualProcessor, Change> { + + private SessionStore store; + private SessionTupleForwarder tupleForwarder; + private Sensor droppedRecordsSensor; + private long observedStreamTime = ConsumerRecord.NO_TIMESTAMP; + + @Override + public void init(final ProcessorContext, Change> context) { + super.init(context); + final StreamsMetricsImpl metrics = (StreamsMetricsImpl) context.metrics(); + final String threadId = Thread.currentThread().getName(); + droppedRecordsSensor = droppedRecordsSensor(threadId, context.taskId().toString(), + metrics); + store = context.getStateStore(storeName); + tupleForwarder = new SessionTupleForwarder<>( + store, + context, + new SessionCacheFlushListener<>(context), + sendOldValues + ); + } + + @Override + public void process(final Record record) { + // if the key is null, we do not need proceed aggregating + // the record with the table + if (record.key() == null) { + if (context().recordMetadata().isPresent()) { + final RecordMetadata recordMetadata = context().recordMetadata().get(); + LOG.warn( + "Skipping record due to null key. 
" + + "topic=[{}] partition=[{}] offset=[{}]", + recordMetadata.topic(), recordMetadata.partition(), recordMetadata.offset() + ); + } else { + LOG.warn( + "Skipping record due to null key. Topic, partition, and offset not known." + ); + } + droppedRecordsSensor.record(); + return; + } + + final long timestamp = record.timestamp(); + observedStreamTime = Math.max(observedStreamTime, timestamp); + final long closeTime = observedStreamTime - windows.gracePeriodMs() - windows.inactivityGap(); + + final List, VAgg>> merged = new ArrayList<>(); + final SessionWindow newSessionWindow = new SessionWindow(timestamp, timestamp); + SessionWindow mergedWindow = newSessionWindow; + VAgg agg = initializer.apply(); + + try ( + final KeyValueIterator, VAgg> iterator = store.findSessions( + record.key(), + timestamp - windows.inactivityGap(), + timestamp + windows.inactivityGap() + ) + ) { + while (iterator.hasNext()) { + final KeyValue, VAgg> next = iterator.next(); + merged.add(next); + agg = sessionMerger.apply(record.key(), agg, next.value); + mergedWindow = mergeSessionWindow(mergedWindow, (SessionWindow) next.key.window()); + } + } + + if (mergedWindow.end() < closeTime) { + if (context().recordMetadata().isPresent()) { + final RecordMetadata recordMetadata = context().recordMetadata().get(); + LOG.warn( + "Skipping record for expired window. " + + "topic=[{}] " + + "partition=[{}] " + + "offset=[{}] " + + "timestamp=[{}] " + + "window=[{},{}] " + + "expiration=[{}] " + + "streamTime=[{}]", + recordMetadata.topic(), recordMetadata.partition(), recordMetadata.offset(), + timestamp, + mergedWindow.start(), mergedWindow.end(), + closeTime, + observedStreamTime + ); + } else { + LOG.warn( + "Skipping record for expired window. Topic, partition, and offset not known. " + + "timestamp=[{}] " + + "window=[{},{}] " + + "expiration=[{}] " + + "streamTime=[{}]", + timestamp, + mergedWindow.start(), mergedWindow.end(), + closeTime, + observedStreamTime + ); + } + droppedRecordsSensor.record(); + } else { + if (!mergedWindow.equals(newSessionWindow)) { + for (final KeyValue, VAgg> session : merged) { + store.remove(session.key); + tupleForwarder.maybeForward(session.key, null, + sendOldValues ? 
session.value : null); + } + } + + agg = aggregator.apply(record.key(), record.value(), agg); + final Windowed sessionKey = new Windowed<>(record.key(), mergedWindow); + store.put(sessionKey, agg); + tupleForwarder.maybeForward(sessionKey, agg, null); + } + } + } + + private SessionWindow mergeSessionWindow(final SessionWindow one, final SessionWindow two) { + final long start = Math.min(one.start(), two.start()); + final long end = Math.max(one.end(), two.end()); + return new SessionWindow(start, end); + } + + @Override + public KTableValueGetterSupplier, VAgg> view() { + return new KTableValueGetterSupplier, VAgg>() { + @Override + public KTableValueGetter, VAgg> get() { + return new KTableSessionWindowValueGetter(); + } + + @Override + public String[] storeNames() { + return new String[]{storeName}; + } + }; + } + + private class KTableSessionWindowValueGetter implements KTableValueGetter, VAgg> { + + private SessionStore store; + + @Override + public void init(final org.apache.kafka.streams.processor.ProcessorContext context) { + store = context.getStateStore(storeName); + } + + @Override + public ValueAndTimestamp get(final Windowed key) { + return ValueAndTimestamp.make( + store.fetchSession(key.key(), key.window().start(), key.window().end()), + key.window().end()); + } + } + +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KStreamSlidingWindowAggregate.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KStreamSlidingWindowAggregate.java new file mode 100644 index 0000000..ac4710e --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KStreamSlidingWindowAggregate.java @@ -0,0 +1,559 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.kstream.internals; + +import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.common.metrics.Sensor; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.kstream.Aggregator; +import org.apache.kafka.streams.kstream.Initializer; +import org.apache.kafka.streams.kstream.Window; +import org.apache.kafka.streams.kstream.Windowed; +import org.apache.kafka.streams.kstream.SlidingWindows; +import org.apache.kafka.streams.processor.api.ContextualProcessor; +import org.apache.kafka.streams.processor.api.Processor; +import org.apache.kafka.streams.processor.api.ProcessorContext; +import org.apache.kafka.streams.processor.api.Record; +import org.apache.kafka.streams.processor.api.RecordMetadata; +import org.apache.kafka.streams.processor.internals.InternalProcessorContext; +import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl; +import org.apache.kafka.streams.state.KeyValueIterator; +import org.apache.kafka.streams.state.TimestampedWindowStore; +import org.apache.kafka.streams.state.ValueAndTimestamp; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import java.util.HashSet; +import java.util.Set; + +import static org.apache.kafka.streams.processor.internals.metrics.TaskMetrics.droppedRecordsSensor; +import static org.apache.kafka.streams.state.ValueAndTimestamp.getValueOrNull; + +public class KStreamSlidingWindowAggregate implements KStreamAggProcessorSupplier, VAgg> { + + private final Logger log = LoggerFactory.getLogger(getClass()); + + private final String storeName; + private final SlidingWindows windows; + private final Initializer initializer; + private final Aggregator aggregator; + + private boolean sendOldValues = false; + + public KStreamSlidingWindowAggregate(final SlidingWindows windows, + final String storeName, + final Initializer initializer, + final Aggregator aggregator) { + this.windows = windows; + this.storeName = storeName; + this.initializer = initializer; + this.aggregator = aggregator; + } + + @Override + public Processor, Change> get() { + return new KStreamSlidingWindowAggregateProcessor(); + } + + public SlidingWindows windows() { + return windows; + } + + @Override + public void enableSendingOldValues() { + sendOldValues = true; + } + + private class KStreamSlidingWindowAggregateProcessor extends ContextualProcessor, Change> { + private TimestampedWindowStore windowStore; + private TimestampedTupleForwarder, VAgg> tupleForwarder; + private Sensor droppedRecordsSensor; + private long observedStreamTime = ConsumerRecord.NO_TIMESTAMP; + private Boolean reverseIteratorPossible = null; + + @Override + public void init(final ProcessorContext, Change> context) { + super.init(context); + final InternalProcessorContext, Change> internalProcessorContext = + (InternalProcessorContext, Change>) context; + final StreamsMetricsImpl metrics = internalProcessorContext.metrics(); + final String threadId = Thread.currentThread().getName(); + droppedRecordsSensor = droppedRecordsSensor(threadId, context.taskId().toString(), metrics); + windowStore = context.getStateStore(storeName); + tupleForwarder = new TimestampedTupleForwarder<>( + windowStore, + context, + new TimestampedCacheFlushListener<>(context), + sendOldValues); + } + + @Override + public void process(final Record record) { + if (record.key() == null || record.value() == null) { + if (context().recordMetadata().isPresent()) { + final RecordMetadata recordMetadata = context().recordMetadata().get(); + 
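// Illustrative sketch (not part of this patch): the DSL call that builds the sliding-window
// aggregation implemented in this file. Every record defines a "left" window
// [ts - timeDifference, ts]; a later record also creates the earlier record's "right" window
// [ts + 1, ts + 1 + timeDifference] when something falls into it, and windows whose end plus
// grace has fallen behind stream time are dropped as expired. Topic names and sizes are
// invented; it assumes a Streams version that has SlidingWindows.ofTimeDifferenceAndGrace.
import java.time.Duration;

import org.apache.kafka.streams.StreamsBuilder;
import org.apache.kafka.streams.kstream.KStream;
import org.apache.kafka.streams.kstream.KTable;
import org.apache.kafka.streams.kstream.SlidingWindows;
import org.apache.kafka.streams.kstream.Windowed;

public class SlidingWindowSketch {
    public static void main(final String[] args) {
        final StreamsBuilder builder = new StreamsBuilder();
        final KStream<String, Long> readings = builder.stream("readings");

        final KTable<Windowed<String>, Long> counts = readings
            .groupByKey()
            .windowedBy(SlidingWindows.ofTimeDifferenceAndGrace(Duration.ofSeconds(10), Duration.ofSeconds(5)))
            .count();

        counts.toStream().to("reading-counts");
        builder.build();
    }
}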
log.warn( + "Skipping record due to null key or value. " + + "topic=[{}] partition=[{}] offset=[{}]", + recordMetadata.topic(), recordMetadata.partition(), recordMetadata.offset() + ); + } else { + log.warn( + "Skipping record due to null key or value. Topic, partition, and offset not known." + ); + } + droppedRecordsSensor.record(); + return; + } + + final long inputRecordTimestamp = record.timestamp(); + observedStreamTime = Math.max(observedStreamTime, inputRecordTimestamp); + final long closeTime = observedStreamTime - windows.gracePeriodMs(); + + if (inputRecordTimestamp + 1L + windows.timeDifferenceMs() <= closeTime) { + if (context().recordMetadata().isPresent()) { + final RecordMetadata recordMetadata = context().recordMetadata().get(); + log.warn( + "Skipping record for expired window. " + + "topic=[{}] " + + "partition=[{}] " + + "offset=[{}] " + + "timestamp=[{}] " + + "window=[{},{}] " + + "expiration=[{}] " + + "streamTime=[{}]", + recordMetadata.topic(), recordMetadata.partition(), recordMetadata.offset(), + record.timestamp(), + inputRecordTimestamp - windows.timeDifferenceMs(), inputRecordTimestamp, + closeTime, + observedStreamTime + ); + } else { + log.warn( + "Skipping record for expired window. Topic, partition, and offset not known. " + + "timestamp=[{}] " + + "window=[{},{}] " + + "expiration=[{}] " + + "streamTime=[{}]", + record.timestamp(), + inputRecordTimestamp - windows.timeDifferenceMs(), inputRecordTimestamp, + closeTime, + observedStreamTime + ); + } + droppedRecordsSensor.record(); + return; + } + + if (inputRecordTimestamp < windows.timeDifferenceMs()) { + processEarly(record.key(), record.value(), inputRecordTimestamp, closeTime); + return; + } + + if (reverseIteratorPossible == null) { + try { + windowStore.backwardFetch(record.key(), 0L, 0L); + reverseIteratorPossible = true; + log.debug("Sliding Windows aggregate using a reverse iterator"); + } catch (final UnsupportedOperationException e) { + reverseIteratorPossible = false; + log.debug("Sliding Windows aggregate using a forward iterator"); + } + } + + if (reverseIteratorPossible) { + processReverse(record.key(), record.value(), inputRecordTimestamp, closeTime); + } else { + processInOrder(record.key(), record.value(), inputRecordTimestamp, closeTime); + } + } + + public void processInOrder(final KIn key, final VIn value, final long inputRecordTimestamp, final long closeTime) { + + final Set windowStartTimes = new HashSet<>(); + + // aggregate that will go in the current record’s left/right window (if needed) + ValueAndTimestamp leftWinAgg = null; + ValueAndTimestamp rightWinAgg = null; + + //if current record's left/right windows already exist + boolean leftWinAlreadyCreated = false; + boolean rightWinAlreadyCreated = false; + + Long previousRecordTimestamp = null; + + try ( + final KeyValueIterator, ValueAndTimestamp> iterator = windowStore.fetch( + key, + key, + Math.max(0, inputRecordTimestamp - 2 * windows.timeDifferenceMs()), + // add 1 to upper bound to catch the current record's right window, if it exists, without more calls to the store + inputRecordTimestamp + 1) + ) { + while (iterator.hasNext()) { + final KeyValue, ValueAndTimestamp> windowBeingProcessed = iterator.next(); + final long startTime = windowBeingProcessed.key.window().start(); + windowStartTimes.add(startTime); + final long endTime = startTime + windows.timeDifferenceMs(); + final long windowMaxRecordTimestamp = windowBeingProcessed.value.timestamp(); + + if (endTime < inputRecordTimestamp) { + leftWinAgg = 
windowBeingProcessed.value; + previousRecordTimestamp = windowMaxRecordTimestamp; + } else if (endTime == inputRecordTimestamp) { + leftWinAlreadyCreated = true; + if (windowMaxRecordTimestamp < inputRecordTimestamp) { + previousRecordTimestamp = windowMaxRecordTimestamp; + } + updateWindowAndForward(windowBeingProcessed.key.window(), windowBeingProcessed.value, key, value, closeTime, inputRecordTimestamp); + } else if (endTime > inputRecordTimestamp && startTime <= inputRecordTimestamp) { + rightWinAgg = windowBeingProcessed.value; + updateWindowAndForward(windowBeingProcessed.key.window(), windowBeingProcessed.value, key, value, closeTime, inputRecordTimestamp); + } else if (startTime == inputRecordTimestamp + 1) { + rightWinAlreadyCreated = true; + } else { + log.error( + "Unexpected window with start {} found when processing record at {} in `KStreamSlidingWindowAggregate`.", + startTime, inputRecordTimestamp + ); + throw new IllegalStateException("Unexpected window found when processing sliding windows"); + } + } + } + createWindows(key, value, inputRecordTimestamp, closeTime, windowStartTimes, rightWinAgg, leftWinAgg, leftWinAlreadyCreated, rightWinAlreadyCreated, previousRecordTimestamp); + } + + public void processReverse(final KIn key, final VIn value, final long inputRecordTimestamp, final long closeTime) { + + final Set windowStartTimes = new HashSet<>(); + + // aggregate that will go in the current record’s left/right window (if needed) + ValueAndTimestamp leftWinAgg = null; + ValueAndTimestamp rightWinAgg = null; + + //if current record's left/right windows already exist + boolean leftWinAlreadyCreated = false; + boolean rightWinAlreadyCreated = false; + + Long previousRecordTimestamp = null; + + try ( + final KeyValueIterator, ValueAndTimestamp> iterator = windowStore.backwardFetch( + key, + key, + Math.max(0, inputRecordTimestamp - 2 * windows.timeDifferenceMs()), + // add 1 to upper bound to catch the current record's right window, if it exists, without more calls to the store + inputRecordTimestamp + 1) + ) { + while (iterator.hasNext()) { + final KeyValue, ValueAndTimestamp> windowBeingProcessed = iterator.next(); + final long startTime = windowBeingProcessed.key.window().start(); + windowStartTimes.add(startTime); + final long endTime = startTime + windows.timeDifferenceMs(); + final long windowMaxRecordTimestamp = windowBeingProcessed.value.timestamp(); + if (startTime == inputRecordTimestamp + 1) { + rightWinAlreadyCreated = true; + } else if (endTime > inputRecordTimestamp) { + if (rightWinAgg == null) { + rightWinAgg = windowBeingProcessed.value; + } + updateWindowAndForward(windowBeingProcessed.key.window(), windowBeingProcessed.value, key, value, closeTime, inputRecordTimestamp); + } else if (endTime == inputRecordTimestamp) { + leftWinAlreadyCreated = true; + updateWindowAndForward(windowBeingProcessed.key.window(), windowBeingProcessed.value, key, value, closeTime, inputRecordTimestamp); + if (windowMaxRecordTimestamp < inputRecordTimestamp) { + previousRecordTimestamp = windowMaxRecordTimestamp; + } else { + return; + } + } else if (endTime < inputRecordTimestamp) { + leftWinAgg = windowBeingProcessed.value; + previousRecordTimestamp = windowMaxRecordTimestamp; + break; + } else { + log.error( + "Unexpected window with start {} found when processing record at {} in `KStreamSlidingWindowAggregate`.", + startTime, inputRecordTimestamp + ); + throw new IllegalStateException("Unexpected window found when processing sliding windows"); + } + } + } + createWindows(key, 
value, inputRecordTimestamp, closeTime, windowStartTimes, rightWinAgg, leftWinAgg, leftWinAlreadyCreated, rightWinAlreadyCreated, previousRecordTimestamp); + } + + /** + * Created to handle records where 0 < inputRecordTimestamp < timeDifferenceMs. These records would create + * windows with negative start times, which is not supported. Instead, we will put them into the [0, timeDifferenceMs] + * window as a "workaround", and we will update or create their right windows as new records come in later + */ + private void processEarly(final KIn key, final VIn value, final long inputRecordTimestamp, final long closeTime) { + if (inputRecordTimestamp < 0 || inputRecordTimestamp >= windows.timeDifferenceMs()) { + log.error( + "Early record for sliding windows must fall between fall between 0 <= inputRecordTimestamp. Timestamp {} does not fall between 0 <= {}", + inputRecordTimestamp, windows.timeDifferenceMs() + ); + throw new IllegalArgumentException("Early record for sliding windows must fall between fall between 0 <= inputRecordTimestamp"); + } + + // A window from [0, timeDifferenceMs] that holds all early records + KeyValue, ValueAndTimestamp> combinedWindow = null; + ValueAndTimestamp rightWinAgg = null; + boolean rightWinAlreadyCreated = false; + final Set windowStartTimes = new HashSet<>(); + + Long previousRecordTimestamp = null; + + try ( + final KeyValueIterator, ValueAndTimestamp> iterator = windowStore.fetch( + key, + key, + 0, + // add 1 to upper bound to catch the current record's right window, if it exists, without more calls to the store + inputRecordTimestamp + 1) + ) { + while (iterator.hasNext()) { + final KeyValue, ValueAndTimestamp> windowBeingProcessed = iterator.next(); + final long startTime = windowBeingProcessed.key.window().start(); + windowStartTimes.add(startTime); + final long windowMaxRecordTimestamp = windowBeingProcessed.value.timestamp(); + + if (startTime == 0) { + combinedWindow = windowBeingProcessed; + // We don't need to store previousRecordTimestamp if maxRecordTimestamp >= timestamp + // because the previous record's right window (if there is a previous record) + // would have already been created by maxRecordTimestamp + if (windowMaxRecordTimestamp < inputRecordTimestamp) { + previousRecordTimestamp = windowMaxRecordTimestamp; + } + + } else if (startTime <= inputRecordTimestamp) { + rightWinAgg = windowBeingProcessed.value; + updateWindowAndForward(windowBeingProcessed.key.window(), windowBeingProcessed.value, key, value, closeTime, inputRecordTimestamp); + } else if (startTime == inputRecordTimestamp + 1) { + rightWinAlreadyCreated = true; + } else { + log.error( + "Unexpected window with start {} found when processing record at {} in `KStreamSlidingWindowAggregate`.", + startTime, inputRecordTimestamp + ); + throw new IllegalStateException("Unexpected window found when processing sliding windows"); + } + } + } + + // If there wasn't a right window agg found and we need a right window for our new record, + // the current aggregate in the combined window will go in the new record's right window. We can be sure that the combined + // window only holds records that fall into the current record's right window for two reasons: + // 1. If there were records earlier than the current record AND later than the current record, there would be a right window found + // when we looked for right window agg. + // 2. 
If there was only a record before the current record, we wouldn't need a right window for the current record and wouldn't update the + // rightWinAgg value here, as the combinedWindow.value.timestamp() < inputRecordTimestamp + if (rightWinAgg == null && combinedWindow != null && combinedWindow.value.timestamp() > inputRecordTimestamp) { + rightWinAgg = combinedWindow.value; + } + + if (!rightWinAlreadyCreated && rightWindowIsNotEmpty(rightWinAgg, inputRecordTimestamp)) { + createCurrentRecordRightWindow(inputRecordTimestamp, rightWinAgg, key); + } + + //create the right window for the previous record if the previous record exists and the window hasn't already been created + if (previousRecordTimestamp != null && !windowStartTimes.contains(previousRecordTimestamp + 1)) { + createPreviousRecordRightWindow(previousRecordTimestamp + 1, inputRecordTimestamp, key, value, closeTime); + } + + if (combinedWindow == null) { + final TimeWindow window = new TimeWindow(0, windows.timeDifferenceMs()); + final ValueAndTimestamp valueAndTime = ValueAndTimestamp.make(initializer.apply(), inputRecordTimestamp); + updateWindowAndForward(window, valueAndTime, key, value, closeTime, inputRecordTimestamp); + + } else { + //update the combined window with the new aggregate + updateWindowAndForward(combinedWindow.key.window(), combinedWindow.value, key, value, closeTime, inputRecordTimestamp); + } + + } + + private void createWindows(final KIn key, + final VIn value, + final long inputRecordTimestamp, + final long closeTime, + final Set windowStartTimes, + final ValueAndTimestamp rightWinAgg, + final ValueAndTimestamp leftWinAgg, + final boolean leftWinAlreadyCreated, + final boolean rightWinAlreadyCreated, + final Long previousRecordTimestamp) { + //create right window for previous record + if (previousRecordTimestamp != null) { + final long previousRightWinStart = previousRecordTimestamp + 1; + if (previousRecordRightWindowDoesNotExistAndIsNotEmpty(windowStartTimes, previousRightWinStart, inputRecordTimestamp)) { + createPreviousRecordRightWindow(previousRightWinStart, inputRecordTimestamp, key, value, closeTime); + } + } + + //create left window for new record + if (!leftWinAlreadyCreated) { + final ValueAndTimestamp valueAndTime; + if (leftWindowNotEmpty(previousRecordTimestamp, inputRecordTimestamp)) { + valueAndTime = ValueAndTimestamp.make(leftWinAgg.value(), inputRecordTimestamp); + } else { + valueAndTime = ValueAndTimestamp.make(initializer.apply(), inputRecordTimestamp); + } + final TimeWindow window = new TimeWindow(inputRecordTimestamp - windows.timeDifferenceMs(), inputRecordTimestamp); + updateWindowAndForward(window, valueAndTime, key, value, closeTime, inputRecordTimestamp); + } + + // create right window for new record, if necessary + if (!rightWinAlreadyCreated && rightWindowIsNotEmpty(rightWinAgg, inputRecordTimestamp)) { + createCurrentRecordRightWindow(inputRecordTimestamp, rightWinAgg, key); + } + } + + private void createCurrentRecordRightWindow(final long inputRecordTimestamp, + final ValueAndTimestamp rightWinAgg, + final KIn key) { + final TimeWindow window = new TimeWindow(inputRecordTimestamp + 1, inputRecordTimestamp + 1 + windows.timeDifferenceMs()); + windowStore.put( + key, + rightWinAgg, + window.start()); + tupleForwarder.maybeForward( + new Windowed<>(key, window), + rightWinAgg.value(), + null, + rightWinAgg.timestamp()); + } + + private void createPreviousRecordRightWindow(final long windowStart, + final long inputRecordTimestamp, + final KIn key, + final VIn value, + final 
long closeTime) { + final TimeWindow window = new TimeWindow(windowStart, windowStart + windows.timeDifferenceMs()); + final ValueAndTimestamp valueAndTime = ValueAndTimestamp.make(initializer.apply(), inputRecordTimestamp); + updateWindowAndForward(window, valueAndTime, key, value, closeTime, inputRecordTimestamp); + } + + // checks if the previous record falls into the current records left window; if yes, the left window is not empty, otherwise it is empty + private boolean leftWindowNotEmpty(final Long previousRecordTimestamp, final long inputRecordTimestamp) { + return previousRecordTimestamp != null && inputRecordTimestamp - windows.timeDifferenceMs() <= previousRecordTimestamp; + } + + // checks if the previous record's right window does not already exist and the current record falls within previous record's right window + private boolean previousRecordRightWindowDoesNotExistAndIsNotEmpty(final Set windowStartTimes, + final long previousRightWindowStart, + final long inputRecordTimestamp) { + return !windowStartTimes.contains(previousRightWindowStart) && previousRightWindowStart + windows.timeDifferenceMs() >= inputRecordTimestamp; + } + + // checks if the aggregate we found has records that fall into the current record's right window; if yes, the right window is not empty + private boolean rightWindowIsNotEmpty(final ValueAndTimestamp rightWinAgg, final long inputRecordTimestamp) { + return rightWinAgg != null && rightWinAgg.timestamp() > inputRecordTimestamp; + } + + private void updateWindowAndForward(final Window window, + final ValueAndTimestamp valueAndTime, + final KIn key, + final VIn value, + final long closeTime, + final long inputRecordTimestamp) { + final long windowStart = window.start(); + final long windowEnd = window.end(); + if (windowEnd > closeTime) { + //get aggregate from existing window + final VAgg oldAgg = getValueOrNull(valueAndTime); + final VAgg newAgg = aggregator.apply(key, value, oldAgg); + + final long newTimestamp = oldAgg == null ? inputRecordTimestamp : Math.max(inputRecordTimestamp, valueAndTime.timestamp()); + windowStore.put( + key, + ValueAndTimestamp.make(newAgg, newTimestamp), + windowStart); + tupleForwarder.maybeForward( + new Windowed<>(key, window), + newAgg, + sendOldValues ? oldAgg : null, + newTimestamp); + } else { + if (context().recordMetadata().isPresent()) { + final RecordMetadata recordMetadata = context().recordMetadata().get(); + log.warn( + "Skipping record for expired window. " + + "topic=[{}] " + + "partition=[{}] " + + "offset=[{}] " + + "timestamp=[{}] " + + "window=[{},{}] " + + "expiration=[{}] " + + "streamTime=[{}]", + recordMetadata.topic(), recordMetadata.partition(), recordMetadata.offset(), + inputRecordTimestamp, + windowStart, windowEnd, + closeTime, + observedStreamTime + ); + } else { + log.warn( + "Skipping record for expired window. Topic, partition, and offset not known. 
" + + "timestamp=[{}] " + + "window=[{},{}] " + + "expiration=[{}] " + + "streamTime=[{}]", + inputRecordTimestamp, + windowStart, windowEnd, + closeTime, + observedStreamTime + ); + } + droppedRecordsSensor.record(); + } + } + } + + @Override + public KTableValueGetterSupplier, VAgg> view() { + return new KTableValueGetterSupplier, VAgg>() { + + public KTableValueGetter, VAgg> get() { + return new KStreamWindowAggregateValueGetter(); + } + + @Override + public String[] storeNames() { + return new String[] {storeName}; + } + }; + } + + private class KStreamWindowAggregateValueGetter implements KTableValueGetter, VAgg> { + private TimestampedWindowStore windowStore; + + @Override + public void init(final org.apache.kafka.streams.processor.ProcessorContext context) { + windowStore = context.getStateStore(storeName); + } + + @Override + public ValueAndTimestamp get(final Windowed windowedKey) { + final KIn key = windowedKey.key(); + return windowStore.fetch(key, windowedKey.window().start()); + } + + @Override + public void close() {} + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KStreamTransformValues.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KStreamTransformValues.java new file mode 100644 index 0000000..468bf8d --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KStreamTransformValues.java @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.kstream.internals; + +import org.apache.kafka.streams.kstream.ValueTransformerWithKey; +import org.apache.kafka.streams.kstream.ValueTransformerWithKeySupplier; +import org.apache.kafka.streams.processor.internals.ForwardingDisabledProcessorContext; +import org.apache.kafka.streams.state.StoreBuilder; + +import java.util.Set; + +@SuppressWarnings("deprecation") // Old PAPI. Needs to be migrated. 
+public class KStreamTransformValues implements org.apache.kafka.streams.processor.ProcessorSupplier { + + private final ValueTransformerWithKeySupplier valueTransformerSupplier; + + KStreamTransformValues(final ValueTransformerWithKeySupplier valueTransformerSupplier) { + this.valueTransformerSupplier = valueTransformerSupplier; + } + + @Override + public org.apache.kafka.streams.processor.Processor get() { + return new KStreamTransformValuesProcessor<>(valueTransformerSupplier.get()); + } + + @Override + public Set> stores() { + return valueTransformerSupplier.stores(); + } + + public static class KStreamTransformValuesProcessor extends org.apache.kafka.streams.processor.AbstractProcessor { + + private final ValueTransformerWithKey valueTransformer; + + KStreamTransformValuesProcessor(final ValueTransformerWithKey valueTransformer) { + this.valueTransformer = valueTransformer; + } + + @Override + public void init(final org.apache.kafka.streams.processor.ProcessorContext context) { + super.init(context); + valueTransformer.init(new ForwardingDisabledProcessorContext(context)); + } + + @Override + public void process(final K key, final V value) { + context.forward(key, valueTransformer.transform(key, value)); + } + + @Override + public void close() { + valueTransformer.close(); + } + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KStreamWindowAggregate.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KStreamWindowAggregate.java new file mode 100644 index 0000000..5730ae6 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KStreamWindowAggregate.java @@ -0,0 +1,222 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
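KStreamTransformValues simply wraps a user-supplied ValueTransformerWithKey and hands it a forwarding-disabled context. A hedged sketch of how such a transformer is typically attached through the DSL; the topic and store names ("words", "seen-store", "annotated-words") are placeholders:

```java
import org.apache.kafka.common.serialization.Serdes;
import org.apache.kafka.streams.StreamsBuilder;
import org.apache.kafka.streams.kstream.ValueTransformerWithKey;
import org.apache.kafka.streams.processor.ProcessorContext;
import org.apache.kafka.streams.state.KeyValueStore;
import org.apache.kafka.streams.state.Stores;

public class TransformValuesExample {
    public static void main(final String[] args) {
        final StreamsBuilder builder = new StreamsBuilder();

        // state store the transformer reads and updates
        builder.addStateStore(Stores.keyValueStoreBuilder(
            Stores.inMemoryKeyValueStore("seen-store"), Serdes.String(), Serdes.Long()));

        builder.<String, String>stream("words")
            .transformValues(() -> new ValueTransformerWithKey<String, String, String>() {
                private KeyValueStore<String, Long> seen;

                @SuppressWarnings("unchecked")
                @Override
                public void init(final ProcessorContext context) {
                    seen = (KeyValueStore<String, Long>) context.getStateStore("seen-store");
                }

                @Override
                public String transform(final String key, final String value) {
                    final Long count = seen.get(key) == null ? 1L : seen.get(key) + 1;
                    seen.put(key, count);
                    return value + " (seen " + count + " times for key " + key + ")";
                }

                @Override
                public void close() { }
            }, "seen-store")
            .to("annotated-words");
    }
}
```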
+ */ +package org.apache.kafka.streams.kstream.internals; + +import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.common.metrics.Sensor; +import org.apache.kafka.streams.kstream.Aggregator; +import org.apache.kafka.streams.kstream.Initializer; +import org.apache.kafka.streams.kstream.Window; +import org.apache.kafka.streams.kstream.Windowed; +import org.apache.kafka.streams.kstream.Windows; +import org.apache.kafka.streams.processor.api.ContextualProcessor; +import org.apache.kafka.streams.processor.api.Processor; +import org.apache.kafka.streams.processor.api.ProcessorContext; +import org.apache.kafka.streams.processor.api.Record; +import org.apache.kafka.streams.processor.api.RecordMetadata; +import org.apache.kafka.streams.processor.internals.InternalProcessorContext; +import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl; +import org.apache.kafka.streams.state.TimestampedWindowStore; +import org.apache.kafka.streams.state.ValueAndTimestamp; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Map; + +import static org.apache.kafka.streams.processor.internals.metrics.TaskMetrics.droppedRecordsSensor; +import static org.apache.kafka.streams.state.ValueAndTimestamp.getValueOrNull; + +public class KStreamWindowAggregate implements KStreamAggProcessorSupplier, VAgg> { + + private final Logger log = LoggerFactory.getLogger(getClass()); + + private final String storeName; + private final Windows windows; + private final Initializer initializer; + private final Aggregator aggregator; + + private boolean sendOldValues = false; + + public KStreamWindowAggregate(final Windows windows, + final String storeName, + final Initializer initializer, + final Aggregator aggregator) { + this.windows = windows; + this.storeName = storeName; + this.initializer = initializer; + this.aggregator = aggregator; + } + + @Override + public Processor, Change> get() { + return new KStreamWindowAggregateProcessor(); + } + + public Windows windows() { + return windows; + } + + @Override + public void enableSendingOldValues() { + sendOldValues = true; + } + + + private class KStreamWindowAggregateProcessor extends ContextualProcessor, Change> { + private TimestampedWindowStore windowStore; + private TimestampedTupleForwarder, VAgg> tupleForwarder; + private Sensor droppedRecordsSensor; + private long observedStreamTime = ConsumerRecord.NO_TIMESTAMP; + + @Override + public void init(final ProcessorContext, Change> context) { + super.init(context); + final InternalProcessorContext, Change> internalProcessorContext = + (InternalProcessorContext, Change>) context; + final StreamsMetricsImpl metrics = internalProcessorContext.metrics(); + final String threadId = Thread.currentThread().getName(); + droppedRecordsSensor = droppedRecordsSensor(threadId, context.taskId().toString(), metrics); + windowStore = context.getStateStore(storeName); + tupleForwarder = new TimestampedTupleForwarder<>( + windowStore, + context, + new TimestampedCacheFlushListener<>(context), + sendOldValues); + } + + @Override + public void process(final Record record) { + if (record.key() == null) { + if (context().recordMetadata().isPresent()) { + final RecordMetadata recordMetadata = context().recordMetadata().get(); + log.warn( + "Skipping record due to null key. " + + "topic=[{}] partition=[{}] offset=[{}]", + recordMetadata.topic(), recordMetadata.partition(), recordMetadata.offset() + ); + } else { + log.warn( + "Skipping record due to null key. 
Topic, partition, and offset not known." + ); + } + droppedRecordsSensor.record(); + return; + } + + // first get the matching windows + final long timestamp = record.timestamp(); + observedStreamTime = Math.max(observedStreamTime, timestamp); + final long closeTime = observedStreamTime - windows.gracePeriodMs(); + + final Map matchedWindows = windows.windowsFor(timestamp); + + // try update the window, and create the new window for the rest of unmatched window that do not exist yet + for (final Map.Entry entry : matchedWindows.entrySet()) { + final Long windowStart = entry.getKey(); + final long windowEnd = entry.getValue().end(); + if (windowEnd > closeTime) { + final ValueAndTimestamp oldAggAndTimestamp = windowStore.fetch(record.key(), windowStart); + VAgg oldAgg = getValueOrNull(oldAggAndTimestamp); + + final VAgg newAgg; + final long newTimestamp; + + if (oldAgg == null) { + oldAgg = initializer.apply(); + newTimestamp = record.timestamp(); + } else { + newTimestamp = Math.max(record.timestamp(), oldAggAndTimestamp.timestamp()); + } + + newAgg = aggregator.apply(record.key(), record.value(), oldAgg); + + // update the store with the new value + windowStore.put(record.key(), ValueAndTimestamp.make(newAgg, newTimestamp), windowStart); + tupleForwarder.maybeForward( + new Windowed<>(record.key(), entry.getValue()), + newAgg, + sendOldValues ? oldAgg : null, + newTimestamp); + } else { + if (context().recordMetadata().isPresent()) { + final RecordMetadata recordMetadata = context().recordMetadata().get(); + log.warn( + "Skipping record for expired window. " + + "topic=[{}] " + + "partition=[{}] " + + "offset=[{}] " + + "timestamp=[{}] " + + "window=[{},{}) " + + "expiration=[{}] " + + "streamTime=[{}]", + recordMetadata.topic(), recordMetadata.partition(), recordMetadata.offset(), + record.timestamp(), + windowStart, windowEnd, + closeTime, + observedStreamTime + ); + } else { + log.warn( + "Skipping record for expired window. Topic, partition, and offset not known. " + + "timestamp=[{}] " + + "window=[{},{}] " + + "expiration=[{}] " + + "streamTime=[{}]", + record.timestamp(), + windowStart, windowEnd, + closeTime, + observedStreamTime + ); + } + droppedRecordsSensor.record(); + } + } + } + } + + @Override + public KTableValueGetterSupplier, VAgg> view() { + return new KTableValueGetterSupplier, VAgg>() { + + public KTableValueGetter, VAgg> get() { + return new KStreamWindowAggregateValueGetter(); + } + + @Override + public String[] storeNames() { + return new String[] {storeName}; + } + }; + } + + private class KStreamWindowAggregateValueGetter implements KTableValueGetter, VAgg> { + private TimestampedWindowStore windowStore; + + @Override + public void init(final org.apache.kafka.streams.processor.ProcessorContext context) { + windowStore = context.getStateStore(storeName); + } + + @SuppressWarnings("unchecked") + @Override + public ValueAndTimestamp get(final Windowed windowedKey) { + final KIn key = windowedKey.key(); + final W window = (W) windowedKey.window(); + return windowStore.fetch(key, window.start()); + } + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KTableAggregate.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KTableAggregate.java new file mode 100644 index 0000000..410a49b --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KTableAggregate.java @@ -0,0 +1,128 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
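In the time-window aggregation above, a window is updated only while its end lies beyond closeTime = observedStreamTime - gracePeriodMs; anything later is dropped and recorded by the dropped-records sensor. A minimal DSL sketch that exercises this path, assuming default serdes are configured and using a placeholder topic name:

```java
import java.time.Duration;
import org.apache.kafka.streams.StreamsBuilder;
import org.apache.kafka.streams.kstream.KTable;
import org.apache.kafka.streams.kstream.TimeWindows;
import org.apache.kafka.streams.kstream.Windowed;

public class WindowedCountExample {
    public static void main(final String[] args) {
        final StreamsBuilder builder = new StreamsBuilder();

        // 5-minute windows; records arriving more than 1 minute after a window's end
        // (in stream time) are dropped as expired
        final KTable<Windowed<String>, Long> counts = builder
            .<String, String>stream("page-views")
            .groupByKey()
            .windowedBy(TimeWindows.ofSizeAndGrace(Duration.ofMinutes(5), Duration.ofMinutes(1)))
            .count();

        counts.toStream()
            .foreach((windowedKey, count) ->
                System.out.printf("%s @ [%d,%d): %d%n",
                    windowedKey.key(),
                    windowedKey.window().start(),
                    windowedKey.window().end(),
                    count));
    }
}
```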
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.kstream.internals; + +import org.apache.kafka.streams.errors.StreamsException; +import org.apache.kafka.streams.kstream.Aggregator; +import org.apache.kafka.streams.kstream.Initializer; +import org.apache.kafka.streams.processor.api.Processor; +import org.apache.kafka.streams.processor.api.ProcessorContext; +import org.apache.kafka.streams.processor.api.Record; +import org.apache.kafka.streams.state.TimestampedKeyValueStore; +import org.apache.kafka.streams.state.ValueAndTimestamp; + +import static org.apache.kafka.streams.state.ValueAndTimestamp.getValueOrNull; + +public class KTableAggregate implements KTableNewProcessorSupplier { + + private final String storeName; + private final Initializer initializer; + private final Aggregator add; + private final Aggregator remove; + + private boolean sendOldValues = false; + + KTableAggregate(final String storeName, + final Initializer initializer, + final Aggregator add, + final Aggregator remove) { + this.storeName = storeName; + this.initializer = initializer; + this.add = add; + this.remove = remove; + } + + @Override + public boolean enableSendingOldValues(final boolean forceMaterialization) { + // Aggregates are always materialized: + sendOldValues = true; + return true; + } + + @Override + public Processor, KIn, Change> get() { + return new KTableAggregateProcessor(); + } + + private class KTableAggregateProcessor implements Processor, KIn, Change> { + private TimestampedKeyValueStore store; + private TimestampedTupleForwarder tupleForwarder; + + @SuppressWarnings("unchecked") + @Override + public void init(final ProcessorContext> context) { + store = (TimestampedKeyValueStore) context.getStateStore(storeName); + tupleForwarder = new TimestampedTupleForwarder<>( + store, + context, + new TimestampedCacheFlushListener<>(context), + sendOldValues); + } + + /** + * @throws StreamsException if key is null + */ + @Override + public void process(final Record> record) { + // the keys should never be null + if (record.key() == null) { + throw new StreamsException("Record key for KTable aggregate operator with state " + storeName + " should not be null."); + } + + final ValueAndTimestamp oldAggAndTimestamp = store.get(record.key()); + final VAgg oldAgg = getValueOrNull(oldAggAndTimestamp); + final VAgg intermediateAgg; + long newTimestamp = record.timestamp(); + + // first try to remove the old value + if (record.value().oldValue != null && oldAgg != null) { + intermediateAgg = remove.apply(record.key(), record.value().oldValue, oldAgg); + newTimestamp = Math.max(record.timestamp(), oldAggAndTimestamp.timestamp()); + } else { + intermediateAgg = oldAgg; + } + + // then try to add the new value + final VAgg newAgg; + if (record.value().newValue != null) { + final VAgg initializedAgg; + if (intermediateAgg == null) { + 
initializedAgg = initializer.apply(); + } else { + initializedAgg = intermediateAgg; + } + + newAgg = add.apply(record.key(), record.value().newValue, initializedAgg); + if (oldAggAndTimestamp != null) { + newTimestamp = Math.max(record.timestamp(), oldAggAndTimestamp.timestamp()); + } + } else { + newAgg = intermediateAgg; + } + + // update the store with the new value + store.put(record.key(), ValueAndTimestamp.make(newAgg, newTimestamp)); + tupleForwarder.maybeForward(record.key(), newAgg, sendOldValues ? oldAgg : null, newTimestamp); + } + + } + + @Override + public KTableValueGetterSupplier view() { + return new KTableMaterializedValueGetterSupplier<>(storeName); + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KTableFilter.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KTableFilter.java new file mode 100644 index 0000000..3d97408 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KTableFilter.java @@ -0,0 +1,186 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
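KTableAggregate applies the subtractor to the old value before the adder handles the new one, so an update to an upstream row is reflected exactly once in the downstream aggregate. A sketch of the corresponding KGroupedTable usage at the DSL level; topic and store names are placeholders:

```java
import org.apache.kafka.common.serialization.Serdes;
import org.apache.kafka.common.utils.Bytes;
import org.apache.kafka.streams.KeyValue;
import org.apache.kafka.streams.StreamsBuilder;
import org.apache.kafka.streams.kstream.Consumed;
import org.apache.kafka.streams.kstream.Grouped;
import org.apache.kafka.streams.kstream.KTable;
import org.apache.kafka.streams.kstream.Materialized;
import org.apache.kafka.streams.state.KeyValueStore;

public class UsersPerRegionExample {
    public static void main(final String[] args) {
        final StreamsBuilder builder = new StreamsBuilder();

        // table of userId -> region
        final KTable<String, String> userRegions =
            builder.table("user-regions", Consumed.with(Serdes.String(), Serdes.String()));

        // re-keyed by region: when a user moves, the subtractor removes them from the
        // old region's count and the adder adds them to the new region's count
        userRegions
            .groupBy((userId, region) -> KeyValue.pair(region, userId),
                     Grouped.with(Serdes.String(), Serdes.String()))
            .aggregate(
                () -> 0L,                                // initializer
                (region, userId, count) -> count + 1,    // adder
                (region, userId, count) -> count - 1,    // subtractor
                Materialized.<String, Long, KeyValueStore<Bytes, byte[]>>as("users-per-region")
                    .withKeySerde(Serdes.String())
                    .withValueSerde(Serdes.Long()))
            .toStream()
            .foreach((region, count) -> System.out.println(region + " has " + count + " users"));
    }
}
```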
+ */ +package org.apache.kafka.streams.kstream.internals; + +import org.apache.kafka.streams.kstream.Predicate; +import org.apache.kafka.streams.processor.api.Processor; +import org.apache.kafka.streams.processor.api.ProcessorContext; +import org.apache.kafka.streams.processor.api.Record; +import org.apache.kafka.streams.state.TimestampedKeyValueStore; +import org.apache.kafka.streams.state.ValueAndTimestamp; + +import static org.apache.kafka.streams.state.ValueAndTimestamp.getValueOrNull; + +class KTableFilter implements KTableNewProcessorSupplier { + private final KTableImpl parent; + private final Predicate predicate; + private final boolean filterNot; + private final String queryableName; + private boolean sendOldValues; + + KTableFilter(final KTableImpl parent, + final Predicate predicate, + final boolean filterNot, + final String queryableName) { + this.parent = parent; + this.predicate = predicate; + this.filterNot = filterNot; + this.queryableName = queryableName; + // If upstream is already materialized, enable sending old values to avoid sending unnecessary tombstones: + this.sendOldValues = parent.enableSendingOldValues(false); + } + + @Override + public Processor, KIn, Change> get() { + return new KTableFilterProcessor(); + } + + @Override + public boolean enableSendingOldValues(final boolean forceMaterialization) { + if (queryableName != null) { + sendOldValues = true; + return true; + } + + if (parent.enableSendingOldValues(forceMaterialization)) { + sendOldValues = true; + } + return sendOldValues; + } + + private VIn computeValue(final KIn key, final VIn value) { + VIn newValue = null; + + if (value != null && (filterNot ^ predicate.test(key, value))) { + newValue = value; + } + + return newValue; + } + + private ValueAndTimestamp computeValue(final KIn key, final ValueAndTimestamp valueAndTimestamp) { + ValueAndTimestamp newValueAndTimestamp = null; + + if (valueAndTimestamp != null) { + final VIn value = valueAndTimestamp.value(); + if (filterNot ^ predicate.test(key, value)) { + newValueAndTimestamp = valueAndTimestamp; + } + } + + return newValueAndTimestamp; + } + + + private class KTableFilterProcessor implements Processor, KIn, Change> { + private ProcessorContext> context; + private TimestampedKeyValueStore store; + private TimestampedTupleForwarder tupleForwarder; + + @Override + public void init(final ProcessorContext> context) { + this.context = context; + if (queryableName != null) { + store = context.getStateStore(queryableName); + tupleForwarder = new TimestampedTupleForwarder<>( + store, + context, + new TimestampedCacheFlushListener<>(context), + sendOldValues); + } + } + + @Override + public void process(final Record> record) { + final KIn key = record.key(); + final Change change = record.value(); + + final VIn newValue = computeValue(key, change.newValue); + final VIn oldValue = computeOldValue(key, change); + + if (sendOldValues && oldValue == null && newValue == null) { + return; // unnecessary to forward here. + } + + if (queryableName != null) { + store.put(key, ValueAndTimestamp.make(newValue, record.timestamp())); + tupleForwarder.maybeForward(record.withValue(new Change<>(newValue, oldValue))); + } else { + context.forward(record.withValue(new Change<>(newValue, oldValue))); + } + } + + private VIn computeOldValue(final KIn key, final Change change) { + if (!sendOldValues) { + return null; + } + + return queryableName != null + ? 
getValueOrNull(store.get(key)) + : computeValue(key, change.oldValue); + } + } + + @Override + public KTableValueGetterSupplier view() { + // if the KTable is materialized, use the materialized store to return getter value; + // otherwise rely on the parent getter and apply filter on-the-fly + if (queryableName != null) { + return new KTableMaterializedValueGetterSupplier<>(queryableName); + } else { + return new KTableValueGetterSupplier() { + final KTableValueGetterSupplier parentValueGetterSupplier = parent.valueGetterSupplier(); + + public KTableValueGetter get() { + return new KTableFilterValueGetter(parentValueGetterSupplier.get()); + } + + @Override + public String[] storeNames() { + return parentValueGetterSupplier.storeNames(); + } + }; + } + } + + + private class KTableFilterValueGetter implements KTableValueGetter { + private final KTableValueGetter parentGetter; + + KTableFilterValueGetter(final KTableValueGetter parentGetter) { + this.parentGetter = parentGetter; + } + + @Override + public void init(final org.apache.kafka.streams.processor.ProcessorContext context) { + // This is the old processor context for compatibility with the other KTable processors. + // Once we migrte them all, we can swap this out. + parentGetter.init(context); + } + + @Override + public ValueAndTimestamp get(final KIn key) { + return computeValue(key, parentGetter.get(key)); + } + + @Override + public void close() { + parentGetter.close(); + } + } + +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KTableImpl.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KTableImpl.java new file mode 100644 index 0000000..8c2fb69 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KTableImpl.java @@ -0,0 +1,1294 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
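KTableFilter forwards a tombstone when a row stops matching the predicate, and it materializes the result only when a queryable store name is provided; otherwise the predicate is re-evaluated on the fly by the value getter. A sketch with placeholder topic and store names:

```java
import org.apache.kafka.common.serialization.Serdes;
import org.apache.kafka.common.utils.Bytes;
import org.apache.kafka.streams.StreamsBuilder;
import org.apache.kafka.streams.kstream.Consumed;
import org.apache.kafka.streams.kstream.KTable;
import org.apache.kafka.streams.kstream.Materialized;
import org.apache.kafka.streams.state.KeyValueStore;

public class FilterExample {
    public static void main(final String[] args) {
        final StreamsBuilder builder = new StreamsBuilder();

        final KTable<String, Long> balances =
            builder.table("account-balances", Consumed.with(Serdes.String(), Serdes.Long()));

        // Because a store name is given, the filtered table is materialized and
        // queryable; rows that drop below the threshold produce tombstones downstream.
        final KTable<String, Long> highValue = balances.filter(
            (account, balance) -> balance != null && balance >= 10_000L,
            Materialized.<String, Long, KeyValueStore<Bytes, byte[]>>as("high-value-accounts")
                .withKeySerde(Serdes.String())
                .withValueSerde(Serdes.Long()));
    }
}
```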
+ */ +package org.apache.kafka.streams.kstream.internals; + +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.kstream.Consumed; +import org.apache.kafka.streams.kstream.Grouped; +import org.apache.kafka.streams.kstream.KGroupedTable; +import org.apache.kafka.streams.kstream.KStream; +import org.apache.kafka.streams.kstream.KTable; +import org.apache.kafka.streams.kstream.KeyValueMapper; +import org.apache.kafka.streams.kstream.Materialized; +import org.apache.kafka.streams.kstream.Named; +import org.apache.kafka.streams.kstream.Predicate; +import org.apache.kafka.streams.kstream.Produced; +import org.apache.kafka.streams.kstream.Suppressed; +import org.apache.kafka.streams.kstream.TableJoined; +import org.apache.kafka.streams.kstream.ValueJoiner; +import org.apache.kafka.streams.kstream.ValueMapper; +import org.apache.kafka.streams.kstream.ValueMapperWithKey; +import org.apache.kafka.streams.kstream.ValueTransformerWithKeySupplier; +import org.apache.kafka.streams.kstream.internals.foreignkeyjoin.CombinedKey; +import org.apache.kafka.streams.kstream.internals.foreignkeyjoin.CombinedKeySchema; +import org.apache.kafka.streams.kstream.internals.foreignkeyjoin.ForeignJoinSubscriptionProcessorSupplier; +import org.apache.kafka.streams.kstream.internals.foreignkeyjoin.ForeignJoinSubscriptionSendProcessorSupplier; +import org.apache.kafka.streams.kstream.internals.foreignkeyjoin.SubscriptionJoinForeignProcessorSupplier; +import org.apache.kafka.streams.kstream.internals.foreignkeyjoin.SubscriptionResolverJoinProcessorSupplier; +import org.apache.kafka.streams.kstream.internals.foreignkeyjoin.SubscriptionResponseWrapper; +import org.apache.kafka.streams.kstream.internals.foreignkeyjoin.SubscriptionResponseWrapperSerde; +import org.apache.kafka.streams.kstream.internals.foreignkeyjoin.SubscriptionStoreReceiveProcessorSupplier; +import org.apache.kafka.streams.kstream.internals.foreignkeyjoin.SubscriptionWrapper; +import org.apache.kafka.streams.kstream.internals.foreignkeyjoin.SubscriptionWrapperSerde; +import org.apache.kafka.streams.kstream.internals.graph.KTableKTableJoinNode; +import org.apache.kafka.streams.kstream.internals.graph.ProcessorGraphNode; +import org.apache.kafka.streams.kstream.internals.graph.ProcessorParameters; +import org.apache.kafka.streams.kstream.internals.graph.StatefulProcessorNode; +import org.apache.kafka.streams.kstream.internals.graph.StreamSinkNode; +import org.apache.kafka.streams.kstream.internals.graph.StreamSourceNode; +import org.apache.kafka.streams.kstream.internals.graph.GraphNode; +import org.apache.kafka.streams.kstream.internals.graph.TableProcessorNode; +import org.apache.kafka.streams.kstream.internals.suppress.FinalResultsSuppressionBuilder; +import org.apache.kafka.streams.kstream.internals.suppress.KTableSuppressProcessorSupplier; +import org.apache.kafka.streams.kstream.internals.suppress.NamedSuppressed; +import org.apache.kafka.streams.kstream.internals.suppress.SuppressedInternal; +import org.apache.kafka.streams.processor.StreamPartitioner; +import org.apache.kafka.streams.processor.api.ProcessorSupplier; +import org.apache.kafka.streams.processor.internals.InternalTopicProperties; +import org.apache.kafka.streams.processor.internals.StaticTopicNameExtractor; +import org.apache.kafka.streams.state.KeyValueStore; +import org.apache.kafka.streams.state.StoreBuilder; 
+import org.apache.kafka.streams.state.Stores; +import org.apache.kafka.streams.state.TimestampedKeyValueStore; +import org.apache.kafka.streams.state.ValueAndTimestamp; +import org.apache.kafka.streams.state.internals.InMemoryTimeOrderedKeyValueBuffer; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.time.Duration; +import java.util.Collections; +import java.util.HashSet; +import java.util.Map; +import java.util.Objects; +import java.util.Set; +import java.util.function.Function; +import java.util.function.Supplier; + +import static org.apache.kafka.streams.kstream.internals.graph.GraphGraceSearchUtil.findAndVerifyWindowGrace; + +/** + * The implementation class of {@link KTable}. + * + * @param the key type + * @param the source's (parent's) value type + * @param the value type + */ +public class KTableImpl extends AbstractStream implements KTable { + private static final Logger LOG = LoggerFactory.getLogger(KTableImpl.class); + + static final String SOURCE_NAME = "KTABLE-SOURCE-"; + + static final String STATE_STORE_NAME = "STATE-STORE-"; + + private static final String FILTER_NAME = "KTABLE-FILTER-"; + + private static final String JOINTHIS_NAME = "KTABLE-JOINTHIS-"; + + private static final String JOINOTHER_NAME = "KTABLE-JOINOTHER-"; + + private static final String MAPVALUES_NAME = "KTABLE-MAPVALUES-"; + + private static final String MERGE_NAME = "KTABLE-MERGE-"; + + private static final String SELECT_NAME = "KTABLE-SELECT-"; + + private static final String SUPPRESS_NAME = "KTABLE-SUPPRESS-"; + + private static final String TOSTREAM_NAME = "KTABLE-TOSTREAM-"; + + private static final String TRANSFORMVALUES_NAME = "KTABLE-TRANSFORMVALUES-"; + + private static final String FK_JOIN = "KTABLE-FK-JOIN-"; + private static final String FK_JOIN_STATE_STORE_NAME = FK_JOIN + "SUBSCRIPTION-STATE-STORE-"; + private static final String SUBSCRIPTION_REGISTRATION = FK_JOIN + "SUBSCRIPTION-REGISTRATION-"; + private static final String SUBSCRIPTION_RESPONSE = FK_JOIN + "SUBSCRIPTION-RESPONSE-"; + private static final String SUBSCRIPTION_PROCESSOR = FK_JOIN + "SUBSCRIPTION-PROCESSOR-"; + private static final String SUBSCRIPTION_RESPONSE_RESOLVER_PROCESSOR = FK_JOIN + "SUBSCRIPTION-RESPONSE-RESOLVER-PROCESSOR-"; + private static final String FK_JOIN_OUTPUT_NAME = FK_JOIN + "OUTPUT-"; + + private static final String TOPIC_SUFFIX = "-topic"; + private static final String SINK_NAME = "KTABLE-SINK-"; + + // Temporarily setting the processorSupplier to type Object so that we can transition from the + // old ProcessorSupplier to the new api.ProcessorSupplier. This works because all accesses to + // this field are guarded by typechecks anyway. + private final Object processorSupplier; + + private final String queryableStoreName; + + private boolean sendOldValues = false; + + @SuppressWarnings("deprecation") // Old PAPI compatibility. 
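The name prefixes defined in KTableImpl (KTABLE-FILTER-, KTABLE-MAPVALUES-, and so on) are used to auto-generate processor names unless the caller supplies a Named. A small sketch showing generated versus explicit names in a topology description; "prices" is a placeholder topic:

```java
import org.apache.kafka.streams.StreamsBuilder;
import org.apache.kafka.streams.kstream.Named;

public class NamingExample {
    public static void main(final String[] args) {
        final StreamsBuilder builder = new StreamsBuilder();

        builder.table("prices")
            .filter((key, value) -> value != null)                  // gets a generated KTABLE-FILTER-... name
            .mapValues(value -> value, Named.as("identity-map"));   // explicit processor name

        // Printing the topology shows the generated and explicit processor names side by side.
        System.out.println(builder.build().describe());
    }
}
```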
+ public KTableImpl(final String name, + final Serde keySerde, + final Serde valueSerde, + final Set subTopologySourceNodes, + final String queryableStoreName, + final org.apache.kafka.streams.processor.ProcessorSupplier processorSupplier, + final GraphNode graphNode, + final InternalStreamsBuilder builder) { + super(name, keySerde, valueSerde, subTopologySourceNodes, graphNode, builder); + this.processorSupplier = processorSupplier; + this.queryableStoreName = queryableStoreName; + } + + public KTableImpl(final String name, + final Serde keySerde, + final Serde valueSerde, + final Set subTopologySourceNodes, + final String queryableStoreName, + final org.apache.kafka.streams.processor.api.ProcessorSupplier newProcessorSupplier, + final GraphNode graphNode, + final InternalStreamsBuilder builder) { + super(name, keySerde, valueSerde, subTopologySourceNodes, graphNode, builder); + this.processorSupplier = newProcessorSupplier; + this.queryableStoreName = queryableStoreName; + } + + @Override + public String queryableStoreName() { + return queryableStoreName; + } + + private KTable doFilter(final Predicate predicate, + final Named named, + final MaterializedInternal> materializedInternal, + final boolean filterNot) { + final Serde keySerde; + final Serde valueSerde; + final String queryableStoreName; + final StoreBuilder> storeBuilder; + + if (materializedInternal != null) { + // we actually do not need to generate store names at all since if it is not specified, we will not + // materialize the store; but we still need to burn one index BEFORE generating the processor to keep compatibility. + if (materializedInternal.storeName() == null) { + builder.newStoreName(FILTER_NAME); + } + // we can inherit parent key and value serde if user do not provide specific overrides, more specifically: + // we preserve the key following the order of 1) materialized, 2) parent + keySerde = materializedInternal.keySerde() != null ? materializedInternal.keySerde() : this.keySerde; + // we preserve the value following the order of 1) materialized, 2) parent + valueSerde = materializedInternal.valueSerde() != null ? materializedInternal.valueSerde() : this.valueSerde; + queryableStoreName = materializedInternal.queryableStoreName(); + // only materialize if materialized is specified and it has queryable name + storeBuilder = queryableStoreName != null ? 
(new TimestampedKeyValueStoreMaterializer<>(materializedInternal)).materialize() : null; + } else { + keySerde = this.keySerde; + valueSerde = this.valueSerde; + queryableStoreName = null; + storeBuilder = null; + } + final String name = new NamedInternal(named).orElseGenerateWithPrefix(builder, FILTER_NAME); + + final KTableNewProcessorSupplier processorSupplier = + new KTableFilter<>(this, predicate, filterNot, queryableStoreName); + + final ProcessorParameters processorParameters = unsafeCastProcessorParametersToCompletelyDifferentType( + new ProcessorParameters<>(processorSupplier, name) + ); + + final GraphNode tableNode = new TableProcessorNode<>( + name, + processorParameters, + storeBuilder + ); + + builder.addGraphNode(this.graphNode, tableNode); + + return new KTableImpl( + name, + keySerde, + valueSerde, + subTopologySourceNodes, + queryableStoreName, + processorSupplier, + tableNode, + builder); + } + + @Override + public KTable filter(final Predicate predicate) { + Objects.requireNonNull(predicate, "predicate can't be null"); + return doFilter(predicate, NamedInternal.empty(), null, false); + } + + @Override + public KTable filter(final Predicate predicate, final Named named) { + Objects.requireNonNull(predicate, "predicate can't be null"); + return doFilter(predicate, named, null, false); + } + + @Override + public KTable filter(final Predicate predicate, + final Named named, + final Materialized> materialized) { + Objects.requireNonNull(predicate, "predicate can't be null"); + Objects.requireNonNull(materialized, "materialized can't be null"); + final MaterializedInternal> materializedInternal = new MaterializedInternal<>(materialized); + + return doFilter(predicate, named, materializedInternal, false); + } + + @Override + public KTable filter(final Predicate predicate, + final Materialized> materialized) { + return filter(predicate, NamedInternal.empty(), materialized); + } + + @Override + public KTable filterNot(final Predicate predicate) { + Objects.requireNonNull(predicate, "predicate can't be null"); + return doFilter(predicate, NamedInternal.empty(), null, true); + } + + @Override + public KTable filterNot(final Predicate predicate, + final Named named) { + Objects.requireNonNull(predicate, "predicate can't be null"); + return doFilter(predicate, named, null, true); + } + + @Override + public KTable filterNot(final Predicate predicate, + final Materialized> materialized) { + return filterNot(predicate, NamedInternal.empty(), materialized); + } + + @Override + public KTable filterNot(final Predicate predicate, + final Named named, + final Materialized> materialized) { + Objects.requireNonNull(predicate, "predicate can't be null"); + Objects.requireNonNull(materialized, "materialized can't be null"); + final MaterializedInternal> materializedInternal = new MaterializedInternal<>(materialized); + final NamedInternal renamed = new NamedInternal(named); + return doFilter(predicate, renamed, materializedInternal, true); + } + + private KTable doMapValues(final ValueMapperWithKey mapper, + final Named named, + final MaterializedInternal> materializedInternal) { + final Serde keySerde; + final Serde valueSerde; + final String queryableStoreName; + final StoreBuilder> storeBuilder; + + if (materializedInternal != null) { + // we actually do not need to generate store names at all since if it is not specified, we will not + // materialize the store; but we still need to burn one index BEFORE generating the processor to keep compatibility. 
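The doMapValues path here deliberately does not inherit the parent's value serde, since the mapper may change the value type; a queryable result therefore needs its value serde spelled out via Materialized. A sketch with placeholder topic and store names:

```java
import org.apache.kafka.common.serialization.Serdes;
import org.apache.kafka.common.utils.Bytes;
import org.apache.kafka.streams.StreamsBuilder;
import org.apache.kafka.streams.kstream.Consumed;
import org.apache.kafka.streams.kstream.KTable;
import org.apache.kafka.streams.kstream.Materialized;
import org.apache.kafka.streams.state.KeyValueStore;

public class MapValuesExample {
    public static void main(final String[] args) {
        final StreamsBuilder builder = new StreamsBuilder();

        final KTable<String, String> orders =
            builder.table("orders", Consumed.with(Serdes.String(), Serdes.String()));

        // The new value type (Long) differs from the parent's (String), so the value
        // serde is supplied explicitly rather than inherited.
        final KTable<String, Long> totals = orders.mapValues(
            value -> Long.parseLong(value.split(",")[1]),
            Materialized.<String, Long, KeyValueStore<Bytes, byte[]>>as("order-totals")
                .withKeySerde(Serdes.String())
                .withValueSerde(Serdes.Long()));
    }
}
```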
+ if (materializedInternal.storeName() == null) { + builder.newStoreName(MAPVALUES_NAME); + } + keySerde = materializedInternal.keySerde() != null ? materializedInternal.keySerde() : this.keySerde; + valueSerde = materializedInternal.valueSerde(); + queryableStoreName = materializedInternal.queryableStoreName(); + // only materialize if materialized is specified and it has queryable name + storeBuilder = queryableStoreName != null ? (new TimestampedKeyValueStoreMaterializer<>(materializedInternal)).materialize() : null; + } else { + keySerde = this.keySerde; + valueSerde = null; + queryableStoreName = null; + storeBuilder = null; + } + + final String name = new NamedInternal(named).orElseGenerateWithPrefix(builder, MAPVALUES_NAME); + + final KTableNewProcessorSupplier processorSupplier = new KTableMapValues<>(this, mapper, queryableStoreName); + + // leaving in calls to ITB until building topology with graph + + final ProcessorParameters processorParameters = unsafeCastProcessorParametersToCompletelyDifferentType( + new ProcessorParameters<>(processorSupplier, name) + ); + final GraphNode tableNode = new TableProcessorNode<>( + name, + processorParameters, + storeBuilder + ); + + builder.addGraphNode(this.graphNode, tableNode); + + // don't inherit parent value serde, since this operation may change the value type, more specifically: + // we preserve the key following the order of 1) materialized, 2) parent, 3) null + // we preserve the value following the order of 1) materialized, 2) null + return new KTableImpl<>( + name, + keySerde, + valueSerde, + subTopologySourceNodes, + queryableStoreName, + processorSupplier, + tableNode, + builder + ); + } + + @Override + public KTable mapValues(final ValueMapper mapper) { + Objects.requireNonNull(mapper, "mapper can't be null"); + return doMapValues(withKey(mapper), NamedInternal.empty(), null); + } + + @Override + public KTable mapValues(final ValueMapper mapper, + final Named named) { + Objects.requireNonNull(mapper, "mapper can't be null"); + return doMapValues(withKey(mapper), named, null); + } + + @Override + public KTable mapValues(final ValueMapperWithKey mapper) { + Objects.requireNonNull(mapper, "mapper can't be null"); + return doMapValues(mapper, NamedInternal.empty(), null); + } + + @Override + public KTable mapValues(final ValueMapperWithKey mapper, + final Named named) { + Objects.requireNonNull(mapper, "mapper can't be null"); + return doMapValues(mapper, named, null); + } + + @Override + public KTable mapValues(final ValueMapper mapper, + final Materialized> materialized) { + return mapValues(mapper, NamedInternal.empty(), materialized); + } + + @Override + public KTable mapValues(final ValueMapper mapper, + final Named named, + final Materialized> materialized) { + Objects.requireNonNull(mapper, "mapper can't be null"); + Objects.requireNonNull(materialized, "materialized can't be null"); + + final MaterializedInternal> materializedInternal = new MaterializedInternal<>(materialized); + + return doMapValues(withKey(mapper), named, materializedInternal); + } + + @Override + public KTable mapValues(final ValueMapperWithKey mapper, + final Materialized> materialized) { + return mapValues(mapper, NamedInternal.empty(), materialized); + } + + @Override + public KTable mapValues(final ValueMapperWithKey mapper, + final Named named, + final Materialized> materialized) { + Objects.requireNonNull(mapper, "mapper can't be null"); + Objects.requireNonNull(materialized, "materialized can't be null"); + + final MaterializedInternal> 
materializedInternal = new MaterializedInternal<>(materialized); + + return doMapValues(mapper, named, materializedInternal); + } + + @Override + public KTable transformValues(final ValueTransformerWithKeySupplier transformerSupplier, + final String... stateStoreNames) { + return doTransformValues(transformerSupplier, null, NamedInternal.empty(), stateStoreNames); + } + + @Override + public KTable transformValues(final ValueTransformerWithKeySupplier transformerSupplier, + final Named named, + final String... stateStoreNames) { + Objects.requireNonNull(named, "processorName can't be null"); + return doTransformValues(transformerSupplier, null, new NamedInternal(named), stateStoreNames); + } + + @Override + public KTable transformValues(final ValueTransformerWithKeySupplier transformerSupplier, + final Materialized> materialized, + final String... stateStoreNames) { + return transformValues(transformerSupplier, materialized, NamedInternal.empty(), stateStoreNames); + } + + @Override + public KTable transformValues(final ValueTransformerWithKeySupplier transformerSupplier, + final Materialized> materialized, + final Named named, + final String... stateStoreNames) { + Objects.requireNonNull(materialized, "materialized can't be null"); + Objects.requireNonNull(named, "named can't be null"); + final MaterializedInternal> materializedInternal = new MaterializedInternal<>(materialized); + + return doTransformValues(transformerSupplier, materializedInternal, new NamedInternal(named), stateStoreNames); + } + + private KTable doTransformValues(final ValueTransformerWithKeySupplier transformerSupplier, + final MaterializedInternal> materializedInternal, + final NamedInternal namedInternal, + final String... stateStoreNames) { + Objects.requireNonNull(stateStoreNames, "stateStoreNames"); + final Serde keySerde; + final Serde valueSerde; + final String queryableStoreName; + final StoreBuilder> storeBuilder; + + if (materializedInternal != null) { + // don't inherit parent value serde, since this operation may change the value type, more specifically: + // we preserve the key following the order of 1) materialized, 2) parent, 3) null + keySerde = materializedInternal.keySerde() != null ? materializedInternal.keySerde() : this.keySerde; + // we preserve the value following the order of 1) materialized, 2) null + valueSerde = materializedInternal.valueSerde(); + queryableStoreName = materializedInternal.queryableStoreName(); + // only materialize if materialized is specified and it has queryable name + storeBuilder = queryableStoreName != null ? 
(new TimestampedKeyValueStoreMaterializer<>(materializedInternal)).materialize() : null; + } else { + keySerde = this.keySerde; + valueSerde = null; + queryableStoreName = null; + storeBuilder = null; + } + + final String name = namedInternal.orElseGenerateWithPrefix(builder, TRANSFORMVALUES_NAME); + + final KTableProcessorSupplier processorSupplier = new KTableTransformValues<>( + this, + transformerSupplier, + queryableStoreName); + + final ProcessorParameters processorParameters = unsafeCastProcessorParametersToCompletelyDifferentType( + new ProcessorParameters<>(processorSupplier, name) + ); + + final GraphNode tableNode = new TableProcessorNode<>( + name, + processorParameters, + storeBuilder, + stateStoreNames + ); + + builder.addGraphNode(this.graphNode, tableNode); + + return new KTableImpl<>( + name, + keySerde, + valueSerde, + subTopologySourceNodes, + queryableStoreName, + processorSupplier, + tableNode, + builder); + } + + @Override + public KStream toStream() { + return toStream(NamedInternal.empty()); + } + + @Override + public KStream toStream(final Named named) { + Objects.requireNonNull(named, "named can't be null"); + + final String name = new NamedInternal(named).orElseGenerateWithPrefix(builder, TOSTREAM_NAME); + final KStreamMapValues, V> kStreamMapValues = new KStreamMapValues<>((key, change) -> change.newValue); + final ProcessorParameters processorParameters = unsafeCastProcessorParametersToCompletelyDifferentType( + new ProcessorParameters<>(kStreamMapValues, name) + ); + + final ProcessorGraphNode toStreamNode = new ProcessorGraphNode<>( + name, + processorParameters + ); + + builder.addGraphNode(this.graphNode, toStreamNode); + + // we can inherit parent key and value serde + return new KStreamImpl<>(name, keySerde, valueSerde, subTopologySourceNodes, false, toStreamNode, builder); + } + + @Override + public KStream toStream(final KeyValueMapper mapper) { + return toStream().selectKey(mapper); + } + + @Override + public KStream toStream(final KeyValueMapper mapper, + final Named named) { + return toStream(named).selectKey(mapper); + } + + @Override + public KTable suppress(final Suppressed suppressed) { + final String name; + if (suppressed instanceof NamedSuppressed) { + final String givenName = ((NamedSuppressed) suppressed).name(); + name = givenName != null ? givenName : builder.newProcessorName(SUPPRESS_NAME); + } else { + throw new IllegalArgumentException("Custom subclasses of Suppressed are not supported."); + } + + final SuppressedInternal suppressedInternal = buildSuppress(suppressed, name); + + final String storeName = + suppressedInternal.name() != null ? 
suppressedInternal.name() + "-store" : builder.newStoreName(SUPPRESS_NAME); + + final ProcessorSupplier, K, Change> suppressionSupplier = new KTableSuppressProcessorSupplier<>( + suppressedInternal, + storeName, + this + ); + + final StoreBuilder> storeBuilder; + + if (suppressedInternal.bufferConfig().isLoggingEnabled()) { + final Map topicConfig = suppressedInternal.bufferConfig().getLogConfig(); + storeBuilder = new InMemoryTimeOrderedKeyValueBuffer.Builder<>( + storeName, + keySerde, + valueSerde) + .withLoggingEnabled(topicConfig); + } else { + storeBuilder = new InMemoryTimeOrderedKeyValueBuffer.Builder<>( + storeName, + keySerde, + valueSerde) + .withLoggingDisabled(); + } + + final ProcessorGraphNode> node = new StatefulProcessorNode<>( + name, + new ProcessorParameters<>(suppressionSupplier, name), + storeBuilder + ); + + builder.addGraphNode(graphNode, node); + + return new KTableImpl( + name, + keySerde, + valueSerde, + Collections.singleton(this.name), + null, + suppressionSupplier, + node, + builder + ); + } + + @SuppressWarnings("unchecked") + private SuppressedInternal buildSuppress(final Suppressed suppress, final String name) { + if (suppress instanceof FinalResultsSuppressionBuilder) { + final long grace = findAndVerifyWindowGrace(graphNode); + LOG.info("Using grace period of [{}] as the suppress duration for node [{}].", + Duration.ofMillis(grace), name); + + final FinalResultsSuppressionBuilder builder = (FinalResultsSuppressionBuilder) suppress; + + final SuppressedInternal finalResultsSuppression = + builder.buildFinalResultsSuppression(Duration.ofMillis(grace)); + + return (SuppressedInternal) finalResultsSuppression; + } else if (suppress instanceof SuppressedInternal) { + return (SuppressedInternal) suppress; + } else { + throw new IllegalArgumentException("Custom subclasses of Suppressed are not allowed."); + } + } + + @Override + public KTable join(final KTable other, + final ValueJoiner joiner) { + return doJoin(other, joiner, NamedInternal.empty(), null, false, false); + } + + @Override + public KTable join(final KTable other, + final ValueJoiner joiner, + final Named named) { + return doJoin(other, joiner, named, null, false, false); + } + + @Override + public KTable join(final KTable other, + final ValueJoiner joiner, + final Materialized> materialized) { + return join(other, joiner, NamedInternal.empty(), materialized); + } + + @Override + public KTable join(final KTable other, + final ValueJoiner joiner, + final Named named, + final Materialized> materialized) { + Objects.requireNonNull(materialized, "materialized can't be null"); + final MaterializedInternal> materializedInternal = + new MaterializedInternal<>(materialized, builder, MERGE_NAME); + + return doJoin(other, joiner, named, materializedInternal, false, false); + } + + @Override + public KTable outerJoin(final KTable other, + final ValueJoiner joiner) { + return outerJoin(other, joiner, NamedInternal.empty()); + } + + @Override + public KTable outerJoin(final KTable other, + final ValueJoiner joiner, + final Named named) { + return doJoin(other, joiner, named, null, true, true); + } + + @Override + public KTable outerJoin(final KTable other, + final ValueJoiner joiner, + final Materialized> materialized) { + return outerJoin(other, joiner, NamedInternal.empty(), materialized); + } + + @Override + public KTable outerJoin(final KTable other, + final ValueJoiner joiner, + final Named named, + final Materialized> materialized) { + Objects.requireNonNull(materialized, "materialized can't be null"); + 
final MaterializedInternal> materializedInternal = + new MaterializedInternal<>(materialized, builder, MERGE_NAME); + + return doJoin(other, joiner, named, materializedInternal, true, true); + } + + @Override + public KTable leftJoin(final KTable other, + final ValueJoiner joiner) { + return leftJoin(other, joiner, NamedInternal.empty()); + } + + @Override + public KTable leftJoin(final KTable other, + final ValueJoiner joiner, + final Named named) { + return doJoin(other, joiner, named, null, true, false); + } + + @Override + public KTable leftJoin(final KTable other, + final ValueJoiner joiner, + final Materialized> materialized) { + return leftJoin(other, joiner, NamedInternal.empty(), materialized); + } + + @Override + public KTable leftJoin(final KTable other, + final ValueJoiner joiner, + final Named named, + final Materialized> materialized) { + Objects.requireNonNull(materialized, "materialized can't be null"); + final MaterializedInternal> materializedInternal = + new MaterializedInternal<>(materialized, builder, MERGE_NAME); + + return doJoin(other, joiner, named, materializedInternal, true, false); + } + + @SuppressWarnings("unchecked") + private KTable doJoin(final KTable other, + final ValueJoiner joiner, + final Named joinName, + final MaterializedInternal> materializedInternal, + final boolean leftOuter, + final boolean rightOuter) { + Objects.requireNonNull(other, "other can't be null"); + Objects.requireNonNull(joiner, "joiner can't be null"); + Objects.requireNonNull(joinName, "joinName can't be null"); + + final NamedInternal renamed = new NamedInternal(joinName); + final String joinMergeName = renamed.orElseGenerateWithPrefix(builder, MERGE_NAME); + final Set allSourceNodes = ensureCopartitionWith(Collections.singleton((AbstractStream) other)); + + if (leftOuter) { + enableSendingOldValues(true); + } + if (rightOuter) { + ((KTableImpl) other).enableSendingOldValues(true); + } + + final KTableKTableAbstractJoin joinThis; + final KTableKTableAbstractJoin joinOther; + + if (!leftOuter) { // inner + joinThis = new KTableKTableInnerJoin<>(this, (KTableImpl) other, joiner); + joinOther = new KTableKTableInnerJoin<>((KTableImpl) other, this, reverseJoiner(joiner)); + } else if (!rightOuter) { // left + joinThis = new KTableKTableLeftJoin<>(this, (KTableImpl) other, joiner); + joinOther = new KTableKTableRightJoin<>((KTableImpl) other, this, reverseJoiner(joiner)); + } else { // outer + joinThis = new KTableKTableOuterJoin<>(this, (KTableImpl) other, joiner); + joinOther = new KTableKTableOuterJoin<>((KTableImpl) other, this, reverseJoiner(joiner)); + } + + final String joinThisName = renamed.suffixWithOrElseGet("-join-this", builder, JOINTHIS_NAME); + final String joinOtherName = renamed.suffixWithOrElseGet("-join-other", builder, JOINOTHER_NAME); + + final ProcessorParameters, ?, ?> joinThisProcessorParameters = new ProcessorParameters<>(joinThis, joinThisName); + final ProcessorParameters, ?, ?> joinOtherProcessorParameters = new ProcessorParameters<>(joinOther, joinOtherName); + + final Serde keySerde; + final Serde valueSerde; + final String queryableStoreName; + final StoreBuilder> storeBuilder; + + if (materializedInternal != null) { + if (materializedInternal.keySerde() == null) { + materializedInternal.withKeySerde(this.keySerde); + } + keySerde = materializedInternal.keySerde(); + valueSerde = materializedInternal.valueSerde(); + queryableStoreName = materializedInternal.storeName(); + storeBuilder = new 
TimestampedKeyValueStoreMaterializer<>(materializedInternal).materialize(); + } else { + keySerde = this.keySerde; + valueSerde = null; + queryableStoreName = null; + storeBuilder = null; + } + + final KTableKTableJoinNode kTableKTableJoinNode = + KTableKTableJoinNode.kTableKTableJoinNodeBuilder() + .withNodeName(joinMergeName) + .withJoinThisProcessorParameters(joinThisProcessorParameters) + .withJoinOtherProcessorParameters(joinOtherProcessorParameters) + .withThisJoinSideNodeName(name) + .withOtherJoinSideNodeName(((KTableImpl) other).name) + .withJoinThisStoreNames(valueGetterSupplier().storeNames()) + .withJoinOtherStoreNames(((KTableImpl) other).valueGetterSupplier().storeNames()) + .withKeySerde(keySerde) + .withValueSerde(valueSerde) + .withQueryableStoreName(queryableStoreName) + .withStoreBuilder(storeBuilder) + .build(); + builder.addGraphNode(this.graphNode, kTableKTableJoinNode); + + // we can inherit parent key serde if user do not provide specific overrides + return new KTableImpl, VR>( + kTableKTableJoinNode.nodeName(), + kTableKTableJoinNode.keySerde(), + kTableKTableJoinNode.valueSerde(), + allSourceNodes, + kTableKTableJoinNode.queryableStoreName(), + kTableKTableJoinNode.joinMerger(), + kTableKTableJoinNode, + builder + ); + } + + @Override + public KGroupedTable groupBy(final KeyValueMapper> selector) { + return groupBy(selector, Grouped.with(null, null)); + } + + @Override + public KGroupedTable groupBy(final KeyValueMapper> selector, + final Grouped grouped) { + Objects.requireNonNull(selector, "selector can't be null"); + Objects.requireNonNull(grouped, "grouped can't be null"); + final GroupedInternal groupedInternal = new GroupedInternal<>(grouped); + final String selectName = new NamedInternal(groupedInternal.name()).orElseGenerateWithPrefix(builder, SELECT_NAME); + + final KTableProcessorSupplier> selectSupplier = new KTableRepartitionMap<>(this, selector); + final ProcessorParameters, ?, ?> processorParameters = new ProcessorParameters<>(selectSupplier, selectName); + + // select the aggregate key and values (old and new), it would require parent to send old values + final ProcessorGraphNode> groupByMapNode = new ProcessorGraphNode<>(selectName, processorParameters); + + builder.addGraphNode(this.graphNode, groupByMapNode); + + this.enableSendingOldValues(true); + return new KGroupedTableImpl<>( + builder, + selectName, + subTopologySourceNodes, + groupedInternal, + groupByMapNode + ); + } + + @SuppressWarnings("unchecked") + public KTableValueGetterSupplier valueGetterSupplier() { + if (processorSupplier instanceof KTableSource) { + final KTableSource source = (KTableSource) processorSupplier; + // whenever a source ktable is required for getter, it should be materialized + source.materialize(); + return new KTableSourceValueGetterSupplier<>(source.queryableName()); + } else if (processorSupplier instanceof KStreamAggProcessorSupplier) { + return ((KStreamAggProcessorSupplier) processorSupplier).view(); + } else if (processorSupplier instanceof KTableNewProcessorSupplier) { + return ((KTableNewProcessorSupplier) processorSupplier).view(); + } else { + return ((KTableProcessorSupplier) processorSupplier).view(); + } + } + + @SuppressWarnings("unchecked") + public boolean enableSendingOldValues(final boolean forceMaterialization) { + if (!sendOldValues) { + if (processorSupplier instanceof KTableSource) { + final KTableSource source = (KTableSource) processorSupplier; + if (!forceMaterialization && !source.materialized()) { + return false; + } + 
source.enableSendingOldValues(); + } else if (processorSupplier instanceof KStreamAggProcessorSupplier) { + ((KStreamAggProcessorSupplier) processorSupplier).enableSendingOldValues(); + } else if (processorSupplier instanceof KTableNewProcessorSupplier) { + final KTableNewProcessorSupplier tableProcessorSupplier = + (KTableNewProcessorSupplier) processorSupplier; + if (!tableProcessorSupplier.enableSendingOldValues(forceMaterialization)) { + return false; + } + } else { + final KTableProcessorSupplier tableProcessorSupplier = (KTableProcessorSupplier) processorSupplier; + if (!tableProcessorSupplier.enableSendingOldValues(forceMaterialization)) { + return false; + } + } + sendOldValues = true; + } + return true; + } + + boolean sendingOldValueEnabled() { + return sendOldValues; + } + + /** + * We conflate V with Change in many places. This will get fixed in the implementation of KIP-478. + * For now, I'm just explicitly lying about the parameterized type. + */ + @SuppressWarnings("unchecked") + private ProcessorParameters unsafeCastProcessorParametersToCompletelyDifferentType(final ProcessorParameters, ?, ?> kObjectProcessorParameters) { + return (ProcessorParameters) kObjectProcessorParameters; + } + + @Override + public KTable join(final KTable other, + final Function foreignKeyExtractor, + final ValueJoiner joiner) { + return doJoinOnForeignKey( + other, + foreignKeyExtractor, + joiner, + TableJoined.with(null, null), + Materialized.with(null, null), + false + ); + } + + @SuppressWarnings("deprecation") + @Override + public KTable join(final KTable other, + final Function foreignKeyExtractor, + final ValueJoiner joiner, + final Named named) { + return doJoinOnForeignKey( + other, + foreignKeyExtractor, + joiner, + TableJoined.as(new NamedInternal(named).name()), + Materialized.with(null, null), + false + ); + } + + @Override + public KTable join(final KTable other, + final Function foreignKeyExtractor, + final ValueJoiner joiner, + final TableJoined tableJoined) { + return doJoinOnForeignKey( + other, + foreignKeyExtractor, + joiner, + tableJoined, + Materialized.with(null, null), + false + ); + } + + @Override + public KTable join(final KTable other, + final Function foreignKeyExtractor, + final ValueJoiner joiner, + final Materialized> materialized) { + return doJoinOnForeignKey(other, foreignKeyExtractor, joiner, TableJoined.with(null, null), materialized, false); + } + + @SuppressWarnings("deprecation") + @Override + public KTable join(final KTable other, + final Function foreignKeyExtractor, + final ValueJoiner joiner, + final Named named, + final Materialized> materialized) { + return doJoinOnForeignKey( + other, + foreignKeyExtractor, + joiner, + TableJoined.as(new NamedInternal(named).name()), + materialized, + false + ); + } + + @Override + public KTable join(final KTable other, + final Function foreignKeyExtractor, + final ValueJoiner joiner, + final TableJoined tableJoined, + final Materialized> materialized) { + return doJoinOnForeignKey( + other, + foreignKeyExtractor, + joiner, + tableJoined, + materialized, + false + ); + } + + @Override + public KTable leftJoin(final KTable other, + final Function foreignKeyExtractor, + final ValueJoiner joiner) { + return doJoinOnForeignKey( + other, + foreignKeyExtractor, + joiner, + TableJoined.with(null, null), + Materialized.with(null, null), + true + ); + } + + @SuppressWarnings("deprecation") + @Override + public KTable leftJoin(final KTable other, + final Function foreignKeyExtractor, + final ValueJoiner joiner, + final Named 
named) { + return doJoinOnForeignKey( + other, + foreignKeyExtractor, + joiner, + TableJoined.as(new NamedInternal(named).name()), + Materialized.with(null, null), + true + ); + } + + @Override + public KTable leftJoin(final KTable other, + final Function foreignKeyExtractor, + final ValueJoiner joiner, + final TableJoined tableJoined) { + return doJoinOnForeignKey( + other, + foreignKeyExtractor, + joiner, + tableJoined, + Materialized.with(null, null), + true + ); + } + + @SuppressWarnings("deprecation") + @Override + public KTable leftJoin(final KTable other, + final Function foreignKeyExtractor, + final ValueJoiner joiner, + final Named named, + final Materialized> materialized) { + return doJoinOnForeignKey( + other, + foreignKeyExtractor, + joiner, + TableJoined.as(new NamedInternal(named).name()), + materialized, + true); + } + + @Override + public KTable leftJoin(final KTable other, + final Function foreignKeyExtractor, + final ValueJoiner joiner, + final TableJoined tableJoined, + final Materialized> materialized) { + return doJoinOnForeignKey( + other, + foreignKeyExtractor, + joiner, + tableJoined, + materialized, + true); + } + + @Override + public KTable leftJoin(final KTable other, + final Function foreignKeyExtractor, + final ValueJoiner joiner, + final Materialized> materialized) { + return doJoinOnForeignKey(other, foreignKeyExtractor, joiner, TableJoined.with(null, null), materialized, true); + } + + @SuppressWarnings("unchecked") + private KTable doJoinOnForeignKey(final KTable foreignKeyTable, + final Function foreignKeyExtractor, + final ValueJoiner joiner, + final TableJoined tableJoined, + final Materialized> materialized, + final boolean leftJoin) { + Objects.requireNonNull(foreignKeyTable, "foreignKeyTable can't be null"); + Objects.requireNonNull(foreignKeyExtractor, "foreignKeyExtractor can't be null"); + Objects.requireNonNull(joiner, "joiner can't be null"); + Objects.requireNonNull(tableJoined, "tableJoined can't be null"); + Objects.requireNonNull(materialized, "materialized can't be null"); + + //Old values are a useful optimization. The old values from the foreignKeyTable table are compared to the new values, + //such that identical values do not cause a prefixScan. PrefixScan and propagation can be expensive and should + //not be done needlessly. + ((KTableImpl) foreignKeyTable).enableSendingOldValues(true); + + //Old values must be sent such that the ForeignJoinSubscriptionSendProcessorSupplier can propagate deletions to the correct node. + //This occurs whenever the extracted foreignKey changes values. 
+ enableSendingOldValues(true); + + final TableJoinedInternal tableJoinedInternal = new TableJoinedInternal<>(tableJoined); + final NamedInternal renamed = new NamedInternal(tableJoinedInternal.name()); + + final String subscriptionTopicName = renamed.suffixWithOrElseGet( + "-subscription-registration", + builder, + SUBSCRIPTION_REGISTRATION + ) + TOPIC_SUFFIX; + + // the decoration can't be performed until we have the configuration available when the app runs, + // so we pass Suppliers into the components, which they can call at run time + + final Supplier subscriptionPrimaryKeySerdePseudoTopic = + () -> internalTopologyBuilder().decoratePseudoTopic(subscriptionTopicName + "-pk"); + + final Supplier subscriptionForeignKeySerdePseudoTopic = + () -> internalTopologyBuilder().decoratePseudoTopic(subscriptionTopicName + "-fk"); + + final Supplier valueHashSerdePseudoTopic = + () -> internalTopologyBuilder().decoratePseudoTopic(subscriptionTopicName + "-vh"); + + builder.internalTopologyBuilder.addInternalTopic(subscriptionTopicName, InternalTopicProperties.empty()); + + final Serde foreignKeySerde = ((KTableImpl) foreignKeyTable).keySerde; + final Serde> subscriptionWrapperSerde = new SubscriptionWrapperSerde<>(subscriptionPrimaryKeySerdePseudoTopic, keySerde); + final SubscriptionResponseWrapperSerde responseWrapperSerde = + new SubscriptionResponseWrapperSerde<>(((KTableImpl) foreignKeyTable).valueSerde); + + final CombinedKeySchema combinedKeySchema = new CombinedKeySchema<>( + subscriptionForeignKeySerdePseudoTopic, + foreignKeySerde, + subscriptionPrimaryKeySerdePseudoTopic, + keySerde + ); + + final ProcessorGraphNode> subscriptionNode = new ProcessorGraphNode<>( + new ProcessorParameters<>( + new ForeignJoinSubscriptionSendProcessorSupplier<>( + foreignKeyExtractor, + subscriptionForeignKeySerdePseudoTopic, + valueHashSerdePseudoTopic, + foreignKeySerde, + valueSerde == null ? null : valueSerde.serializer(), + leftJoin + ), + renamed.suffixWithOrElseGet("-subscription-registration-processor", builder, SUBSCRIPTION_REGISTRATION) + ) + ); + builder.addGraphNode(graphNode, subscriptionNode); + + + final StreamPartitioner> subscriptionSinkPartitioner = + tableJoinedInternal.otherPartitioner() == null + ? null + : (topic, key, val, numPartitions) -> + tableJoinedInternal.otherPartitioner().partition(topic, key, null, numPartitions); + + final StreamSinkNode> subscriptionSink = new StreamSinkNode<>( + renamed.suffixWithOrElseGet("-subscription-registration-sink", builder, SINK_NAME), + new StaticTopicNameExtractor<>(subscriptionTopicName), + new ProducedInternal<>(Produced.with(foreignKeySerde, subscriptionWrapperSerde, subscriptionSinkPartitioner)) + ); + builder.addGraphNode(subscriptionNode, subscriptionSink); + + final StreamSourceNode> subscriptionSource = new StreamSourceNode<>( + renamed.suffixWithOrElseGet("-subscription-registration-source", builder, SOURCE_NAME), + Collections.singleton(subscriptionTopicName), + new ConsumedInternal<>(Consumed.with(foreignKeySerde, subscriptionWrapperSerde)) + ); + builder.addGraphNode(subscriptionSink, subscriptionSource); + + // The subscription source is the source node on the *receiving* end *after* the repartition. + // This topic needs to be copartitioned with the Foreign Key table. 
+ final Set copartitionedRepartitionSources = + new HashSet<>(((KTableImpl) foreignKeyTable).subTopologySourceNodes); + copartitionedRepartitionSources.add(subscriptionSource.nodeName()); + builder.internalTopologyBuilder.copartitionSources(copartitionedRepartitionSources); + + + final StoreBuilder>> subscriptionStore = + Stores.timestampedKeyValueStoreBuilder( + Stores.persistentTimestampedKeyValueStore( + renamed.suffixWithOrElseGet("-subscription-store", builder, FK_JOIN_STATE_STORE_NAME) + ), + new Serdes.BytesSerde(), + subscriptionWrapperSerde + ); + builder.addStateStore(subscriptionStore); + + final StatefulProcessorNode> subscriptionReceiveNode = + new StatefulProcessorNode<>( + new ProcessorParameters<>( + new SubscriptionStoreReceiveProcessorSupplier<>(subscriptionStore, combinedKeySchema), + renamed.suffixWithOrElseGet("-subscription-receive", builder, SUBSCRIPTION_PROCESSOR) + ), + Collections.singleton(subscriptionStore), + Collections.emptySet() + ); + builder.addGraphNode(subscriptionSource, subscriptionReceiveNode); + + final StatefulProcessorNode, Change>>> subscriptionJoinForeignNode = + new StatefulProcessorNode<>( + new ProcessorParameters<>( + new SubscriptionJoinForeignProcessorSupplier<>( + ((KTableImpl) foreignKeyTable).valueGetterSupplier() + ), + renamed.suffixWithOrElseGet("-subscription-join-foreign", builder, SUBSCRIPTION_PROCESSOR) + ), + Collections.emptySet(), + Collections.singleton(((KTableImpl) foreignKeyTable).valueGetterSupplier()) + ); + builder.addGraphNode(subscriptionReceiveNode, subscriptionJoinForeignNode); + + final StatefulProcessorNode> foreignJoinSubscriptionNode = new StatefulProcessorNode<>( + new ProcessorParameters<>( + new ForeignJoinSubscriptionProcessorSupplier<>(subscriptionStore, combinedKeySchema), + renamed.suffixWithOrElseGet("-foreign-join-subscription", builder, SUBSCRIPTION_PROCESSOR) + ), + Collections.singleton(subscriptionStore), + Collections.emptySet() + ); + builder.addGraphNode(((KTableImpl) foreignKeyTable).graphNode, foreignJoinSubscriptionNode); + + + final String finalRepartitionTopicName = renamed.suffixWithOrElseGet("-subscription-response", builder, SUBSCRIPTION_RESPONSE) + TOPIC_SUFFIX; + builder.internalTopologyBuilder.addInternalTopic(finalRepartitionTopicName, InternalTopicProperties.empty()); + + final StreamPartitioner> foreignResponseSinkPartitioner = + tableJoinedInternal.partitioner() == null + ? 
null + : (topic, key, val, numPartitions) -> + tableJoinedInternal.partitioner().partition(topic, key, null, numPartitions); + + final StreamSinkNode> foreignResponseSink = + new StreamSinkNode<>( + renamed.suffixWithOrElseGet("-subscription-response-sink", builder, SINK_NAME), + new StaticTopicNameExtractor<>(finalRepartitionTopicName), + new ProducedInternal<>(Produced.with(keySerde, responseWrapperSerde, foreignResponseSinkPartitioner)) + ); + builder.addGraphNode(subscriptionJoinForeignNode, foreignResponseSink); + builder.addGraphNode(foreignJoinSubscriptionNode, foreignResponseSink); + + final StreamSourceNode> foreignResponseSource = new StreamSourceNode<>( + renamed.suffixWithOrElseGet("-subscription-response-source", builder, SOURCE_NAME), + Collections.singleton(finalRepartitionTopicName), + new ConsumedInternal<>(Consumed.with(keySerde, responseWrapperSerde)) + ); + builder.addGraphNode(foreignResponseSink, foreignResponseSource); + + // the response topic has to be copartitioned with the left (primary) side of the join + final Set resultSourceNodes = new HashSet<>(this.subTopologySourceNodes); + resultSourceNodes.add(foreignResponseSource.nodeName()); + builder.internalTopologyBuilder.copartitionSources(resultSourceNodes); + + final KTableValueGetterSupplier primaryKeyValueGetter = valueGetterSupplier(); + final SubscriptionResolverJoinProcessorSupplier resolverProcessorSupplier = new SubscriptionResolverJoinProcessorSupplier<>( + primaryKeyValueGetter, + valueSerde == null ? null : valueSerde.serializer(), + valueHashSerdePseudoTopic, + joiner, + leftJoin + ); + final StatefulProcessorNode> resolverNode = new StatefulProcessorNode<>( + new ProcessorParameters<>( + resolverProcessorSupplier, + renamed.suffixWithOrElseGet("-subscription-response-resolver", builder, SUBSCRIPTION_RESPONSE_RESOLVER_PROCESSOR) + ), + Collections.emptySet(), + Collections.singleton(primaryKeyValueGetter) + ); + builder.addGraphNode(foreignResponseSource, resolverNode); + + final String resultProcessorName = renamed.suffixWithOrElseGet("-result", builder, FK_JOIN_OUTPUT_NAME); + + final MaterializedInternal> materializedInternal = + new MaterializedInternal<>( + materialized, + builder, + FK_JOIN_OUTPUT_NAME + ); + + // If we have a key serde, it's still valid, but we don't know the value serde, since it's the result + // of the joiner (VR). 
+ if (materializedInternal.keySerde() == null) { + materializedInternal.withKeySerde(keySerde); + } + + final KTableSource resultProcessorSupplier = new KTableSource<>( + materializedInternal.storeName(), + materializedInternal.queryableStoreName() + ); + + final StoreBuilder> resultStore = + new TimestampedKeyValueStoreMaterializer<>(materializedInternal).materialize(); + + final TableProcessorNode resultNode = new TableProcessorNode<>( + resultProcessorName, + new ProcessorParameters<>( + resultProcessorSupplier, + resultProcessorName + ), + resultStore + ); + builder.addGraphNode(resolverNode, resultNode); + + return new KTableImpl( + resultProcessorName, + keySerde, + materializedInternal.valueSerde(), + resultSourceNodes, + materializedInternal.storeName(), + resultProcessorSupplier, + resultNode, + builder + ); + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KTableKTableAbstractJoin.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KTableKTableAbstractJoin.java new file mode 100644 index 0000000..ecaaf45 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KTableKTableAbstractJoin.java @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+package org.apache.kafka.streams.kstream.internals;
+
+import org.apache.kafka.streams.kstream.ValueJoiner;
+
+abstract class KTableKTableAbstractJoin<K, R, V1, V2> implements KTableProcessorSupplier<K, V1, R> {
+
+    private final KTableImpl<K, ?, V1> table1;
+    private final KTableImpl<K, ?, V2> table2;
+    final KTableValueGetterSupplier<K, V1> valueGetterSupplier1;
+    final KTableValueGetterSupplier<K, V2> valueGetterSupplier2;
+    final ValueJoiner<? super V1, ? super V2, ? extends R> joiner;
+
+    boolean sendOldValues = false;
+
+    KTableKTableAbstractJoin(final KTableImpl<K, ?, V1> table1,
+                             final KTableImpl<K, ?, V2> table2,
+                             final ValueJoiner<? super V1, ? super V2, ? extends R> joiner) {
+        this.table1 = table1;
+        this.table2 = table2;
+        this.valueGetterSupplier1 = table1.valueGetterSupplier();
+        this.valueGetterSupplier2 = table2.valueGetterSupplier();
+        this.joiner = joiner;
+    }
+
+    @Override
+    public final boolean enableSendingOldValues(final boolean forceMaterialization) {
+        // Table-table joins require upstream materialization:
+        table1.enableSendingOldValues(true);
+        table2.enableSendingOldValues(true);
+        sendOldValues = true;
+        return true;
+    }
+}
diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KTableKTableAbstractJoinValueGetterSupplier.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KTableKTableAbstractJoinValueGetterSupplier.java
new file mode 100644
index 0000000..924452d
--- /dev/null
+++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KTableKTableAbstractJoinValueGetterSupplier.java
@@ -0,0 +1,43 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.streams.kstream.internals;
+
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.Set;
+
+public abstract class KTableKTableAbstractJoinValueGetterSupplier<K, R, V1, V2> implements KTableValueGetterSupplier<K, R> {
+    final KTableValueGetterSupplier<K, V1> valueGetterSupplier1;
+    final KTableValueGetterSupplier<K, V2> valueGetterSupplier2;
+
+    KTableKTableAbstractJoinValueGetterSupplier(final KTableValueGetterSupplier<K, V1> valueGetterSupplier1,
+                                                final KTableValueGetterSupplier<K, V2> valueGetterSupplier2) {
+        this.valueGetterSupplier1 = valueGetterSupplier1;
+        this.valueGetterSupplier2 = valueGetterSupplier2;
+    }
+
+    @Override
+    public String[] storeNames() {
+        final String[] storeNames1 = valueGetterSupplier1.storeNames();
+        final String[] storeNames2 = valueGetterSupplier2.storeNames();
+        final Set<String> stores = new HashSet<>(storeNames1.length + storeNames2.length);
+        Collections.addAll(stores, storeNames1);
+        Collections.addAll(stores, storeNames2);
+        return stores.toArray(new String[0]);
+    }
+
+}
diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KTableKTableInnerJoin.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KTableKTableInnerJoin.java
new file mode 100644
index 0000000..e448aef
--- /dev/null
+++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KTableKTableInnerJoin.java
@@ -0,0 +1,170 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.streams.kstream.internals;
+
+import org.apache.kafka.common.metrics.Sensor;
+import org.apache.kafka.streams.kstream.KeyValueMapper;
+import org.apache.kafka.streams.kstream.ValueJoiner;
+import org.apache.kafka.streams.processor.To;
+import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl;
+import org.apache.kafka.streams.state.ValueAndTimestamp;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import static org.apache.kafka.streams.processor.internals.metrics.TaskMetrics.droppedRecordsSensor;
+import static org.apache.kafka.streams.state.ValueAndTimestamp.getValueOrNull;
+
+@SuppressWarnings("deprecation") // Old PAPI. Needs to be migrated.
+class KTableKTableInnerJoin extends KTableKTableAbstractJoin { + private static final Logger LOG = LoggerFactory.getLogger(KTableKTableInnerJoin.class); + + private final KeyValueMapper keyValueMapper = (key, value) -> key; + + KTableKTableInnerJoin(final KTableImpl table1, + final KTableImpl table2, + final ValueJoiner joiner) { + super(table1, table2, joiner); + } + + @Override + public org.apache.kafka.streams.processor.Processor> get() { + return new KTableKTableJoinProcessor(valueGetterSupplier2.get()); + } + + @Override + public KTableValueGetterSupplier view() { + return new KTableKTableInnerJoinValueGetterSupplier(valueGetterSupplier1, valueGetterSupplier2); + } + + private class KTableKTableInnerJoinValueGetterSupplier extends KTableKTableAbstractJoinValueGetterSupplier { + + KTableKTableInnerJoinValueGetterSupplier(final KTableValueGetterSupplier valueGetterSupplier1, + final KTableValueGetterSupplier valueGetterSupplier2) { + super(valueGetterSupplier1, valueGetterSupplier2); + } + + public KTableValueGetter get() { + return new KTableKTableInnerJoinValueGetter(valueGetterSupplier1.get(), valueGetterSupplier2.get()); + } + } + + private class KTableKTableJoinProcessor extends org.apache.kafka.streams.processor.AbstractProcessor> { + + private final KTableValueGetter valueGetter; + private Sensor droppedRecordsSensor; + + KTableKTableJoinProcessor(final KTableValueGetter valueGetter) { + this.valueGetter = valueGetter; + } + + @Override + public void init(final org.apache.kafka.streams.processor.ProcessorContext context) { + super.init(context); + droppedRecordsSensor = droppedRecordsSensor( + Thread.currentThread().getName(), + context.taskId().toString(), + (StreamsMetricsImpl) context.metrics() + ); + valueGetter.init(context); + } + + @Override + public void process(final K key, final Change change) { + // we do join iff keys are equal, thus, if key is null we cannot join and just ignore the record + if (key == null) { + LOG.warn( + "Skipping record due to null key. 
change=[{}] topic=[{}] partition=[{}] offset=[{}]", + change, context().topic(), context().partition(), context().offset() + ); + droppedRecordsSensor.record(); + return; + } + + R newValue = null; + final long resultTimestamp; + R oldValue = null; + + final ValueAndTimestamp valueAndTimestampRight = valueGetter.get(key); + final V2 valueRight = getValueOrNull(valueAndTimestampRight); + if (valueRight == null) { + return; + } + + resultTimestamp = Math.max(context().timestamp(), valueAndTimestampRight.timestamp()); + + if (change.newValue != null) { + newValue = joiner.apply(change.newValue, valueRight); + } + + if (sendOldValues && change.oldValue != null) { + oldValue = joiner.apply(change.oldValue, valueRight); + } + + context().forward(key, new Change<>(newValue, oldValue), To.all().withTimestamp(resultTimestamp)); + } + + @Override + public void close() { + valueGetter.close(); + } + } + + private class KTableKTableInnerJoinValueGetter implements KTableValueGetter { + + private final KTableValueGetter valueGetter1; + private final KTableValueGetter valueGetter2; + + KTableKTableInnerJoinValueGetter(final KTableValueGetter valueGetter1, + final KTableValueGetter valueGetter2) { + this.valueGetter1 = valueGetter1; + this.valueGetter2 = valueGetter2; + } + + @Override + public void init(final org.apache.kafka.streams.processor.ProcessorContext context) { + valueGetter1.init(context); + valueGetter2.init(context); + } + + @Override + public ValueAndTimestamp get(final K key) { + final ValueAndTimestamp valueAndTimestamp1 = valueGetter1.get(key); + final V1 value1 = getValueOrNull(valueAndTimestamp1); + + if (value1 != null) { + final ValueAndTimestamp valueAndTimestamp2 = valueGetter2.get(keyValueMapper.apply(key, value1)); + final V2 value2 = getValueOrNull(valueAndTimestamp2); + + if (value2 != null) { + return ValueAndTimestamp.make( + joiner.apply(value1, value2), + Math.max(valueAndTimestamp1.timestamp(), valueAndTimestamp2.timestamp())); + } else { + return null; + } + } else { + return null; + } + } + + @Override + public void close() { + valueGetter1.close(); + valueGetter2.close(); + } + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KTableKTableJoinMerger.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KTableKTableJoinMerger.java new file mode 100644 index 0000000..36c94b5 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KTableKTableJoinMerger.java @@ -0,0 +1,129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.kstream.internals; + +import org.apache.kafka.streams.state.TimestampedKeyValueStore; +import org.apache.kafka.streams.state.ValueAndTimestamp; + +import java.util.Collections; +import java.util.HashSet; +import java.util.Set; + +@SuppressWarnings("deprecation") // Old PAPI. Needs to be migrated. +public class KTableKTableJoinMerger implements KTableProcessorSupplier { + + private final KTableProcessorSupplier parent1; + private final KTableProcessorSupplier parent2; + private final String queryableName; + private boolean sendOldValues = false; + + KTableKTableJoinMerger(final KTableProcessorSupplier parent1, + final KTableProcessorSupplier parent2, + final String queryableName) { + this.parent1 = parent1; + this.parent2 = parent2; + this.queryableName = queryableName; + } + + public String getQueryableName() { + return queryableName; + } + + @Override + public org.apache.kafka.streams.processor.Processor> get() { + return new KTableKTableJoinMergeProcessor(); + } + + @Override + public KTableValueGetterSupplier view() { + // if the result KTable is materialized, use the materialized store to return getter value; + // otherwise rely on the parent getter and apply join on-the-fly + if (queryableName != null) { + return new KTableMaterializedValueGetterSupplier<>(queryableName); + } else { + return new KTableValueGetterSupplier() { + + public KTableValueGetter get() { + return parent1.view().get(); + } + + @Override + public String[] storeNames() { + final String[] storeNames1 = parent1.view().storeNames(); + final String[] storeNames2 = parent2.view().storeNames(); + final Set stores = new HashSet<>(storeNames1.length + storeNames2.length); + Collections.addAll(stores, storeNames1); + Collections.addAll(stores, storeNames2); + return stores.toArray(new String[0]); + } + }; + } + } + + @Override + public boolean enableSendingOldValues(final boolean forceMaterialization) { + // Table-table joins require upstream materialization: + parent1.enableSendingOldValues(true); + parent2.enableSendingOldValues(true); + sendOldValues = true; + return true; + } + + public static KTableKTableJoinMerger of(final KTableProcessorSupplier parent1, + final KTableProcessorSupplier parent2) { + return of(parent1, parent2, null); + } + + public static KTableKTableJoinMerger of(final KTableProcessorSupplier parent1, + final KTableProcessorSupplier parent2, + final String queryableName) { + return new KTableKTableJoinMerger<>(parent1, parent2, queryableName); + } + + private class KTableKTableJoinMergeProcessor extends org.apache.kafka.streams.processor.AbstractProcessor> { + private TimestampedKeyValueStore store; + private TimestampedTupleForwarder tupleForwarder; + + @SuppressWarnings("unchecked") + @Override + public void init(final org.apache.kafka.streams.processor.ProcessorContext context) { + super.init(context); + if (queryableName != null) { + store = (TimestampedKeyValueStore) context.getStateStore(queryableName); + tupleForwarder = new TimestampedTupleForwarder<>( + store, + context, + new TimestampedCacheFlushListener<>(context), + sendOldValues); + } + } + + @Override + public void process(final K key, final Change value) { + if (queryableName != null) { + store.put(key, ValueAndTimestamp.make(value.newValue, context().timestamp())); + tupleForwarder.maybeForward(key, value.newValue, sendOldValues ? 
value.oldValue : null); + } else { + if (sendOldValues) { + context().forward(key, value); + } else { + context().forward(key, new Change<>(value.newValue, null)); + } + } + } + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KTableKTableLeftJoin.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KTableKTableLeftJoin.java new file mode 100644 index 0000000..dc274cd --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KTableKTableLeftJoin.java @@ -0,0 +1,176 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.kstream.internals; + +import org.apache.kafka.common.metrics.Sensor; +import org.apache.kafka.streams.kstream.ValueJoiner; +import org.apache.kafka.streams.processor.To; +import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl; +import org.apache.kafka.streams.state.ValueAndTimestamp; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import static org.apache.kafka.streams.processor.internals.metrics.TaskMetrics.droppedRecordsSensor; +import static org.apache.kafka.streams.processor.internals.RecordQueue.UNKNOWN; +import static org.apache.kafka.streams.state.ValueAndTimestamp.getValueOrNull; + +@SuppressWarnings("deprecation") // Old PAPI. Needs to be migrated. 
+class KTableKTableLeftJoin extends KTableKTableAbstractJoin { + private static final Logger LOG = LoggerFactory.getLogger(KTableKTableLeftJoin.class); + + KTableKTableLeftJoin(final KTableImpl table1, + final KTableImpl table2, + final ValueJoiner joiner) { + super(table1, table2, joiner); + } + + @Override + public org.apache.kafka.streams.processor.Processor> get() { + return new KTableKTableLeftJoinProcessor(valueGetterSupplier2.get()); + } + + @Override + public KTableValueGetterSupplier view() { + return new KTableKTableLeftJoinValueGetterSupplier(valueGetterSupplier1, valueGetterSupplier2); + } + + private class KTableKTableLeftJoinValueGetterSupplier extends KTableKTableAbstractJoinValueGetterSupplier { + + KTableKTableLeftJoinValueGetterSupplier(final KTableValueGetterSupplier valueGetterSupplier1, + final KTableValueGetterSupplier valueGetterSupplier2) { + super(valueGetterSupplier1, valueGetterSupplier2); + } + + public KTableValueGetter get() { + return new KTableKTableLeftJoinValueGetter(valueGetterSupplier1.get(), valueGetterSupplier2.get()); + } + } + + + private class KTableKTableLeftJoinProcessor extends org.apache.kafka.streams.processor.AbstractProcessor> { + + private final KTableValueGetter valueGetter; + private Sensor droppedRecordsSensor; + + KTableKTableLeftJoinProcessor(final KTableValueGetter valueGetter) { + this.valueGetter = valueGetter; + } + + @Override + public void init(final org.apache.kafka.streams.processor.ProcessorContext context) { + super.init(context); + droppedRecordsSensor = droppedRecordsSensor( + Thread.currentThread().getName(), + context.taskId().toString(), + (StreamsMetricsImpl) context.metrics() + ); + valueGetter.init(context); + } + + @Override + public void process(final K key, final Change change) { + // we do join iff keys are equal, thus, if key is null we cannot join and just ignore the record + if (key == null) { + LOG.warn( + "Skipping record due to null key. 
change=[{}] topic=[{}] partition=[{}] offset=[{}]", + change, context().topic(), context().partition(), context().offset() + ); + droppedRecordsSensor.record(); + return; + } + + R newValue = null; + final long resultTimestamp; + R oldValue = null; + + final ValueAndTimestamp valueAndTimestampRight = valueGetter.get(key); + final V2 value2 = getValueOrNull(valueAndTimestampRight); + final long timestampRight; + + if (value2 == null) { + if (change.newValue == null && change.oldValue == null) { + return; + } + timestampRight = UNKNOWN; + } else { + timestampRight = valueAndTimestampRight.timestamp(); + } + + resultTimestamp = Math.max(context().timestamp(), timestampRight); + + if (change.newValue != null) { + newValue = joiner.apply(change.newValue, value2); + } + + if (sendOldValues && change.oldValue != null) { + oldValue = joiner.apply(change.oldValue, value2); + } + + context().forward(key, new Change<>(newValue, oldValue), To.all().withTimestamp(resultTimestamp)); + } + + @Override + public void close() { + valueGetter.close(); + } + } + + private class KTableKTableLeftJoinValueGetter implements KTableValueGetter { + + private final KTableValueGetter valueGetter1; + private final KTableValueGetter valueGetter2; + + KTableKTableLeftJoinValueGetter(final KTableValueGetter valueGetter1, + final KTableValueGetter valueGetter2) { + this.valueGetter1 = valueGetter1; + this.valueGetter2 = valueGetter2; + } + + @Override + public void init(final org.apache.kafka.streams.processor.ProcessorContext context) { + valueGetter1.init(context); + valueGetter2.init(context); + } + + @Override + public ValueAndTimestamp get(final K key) { + final ValueAndTimestamp valueAndTimestamp1 = valueGetter1.get(key); + final V1 value1 = getValueOrNull(valueAndTimestamp1); + + if (value1 != null) { + final ValueAndTimestamp valueAndTimestamp2 = valueGetter2.get(key); + final V2 value2 = getValueOrNull(valueAndTimestamp2); + final long resultTimestamp; + if (valueAndTimestamp2 == null) { + resultTimestamp = valueAndTimestamp1.timestamp(); + } else { + resultTimestamp = Math.max(valueAndTimestamp1.timestamp(), valueAndTimestamp2.timestamp()); + } + return ValueAndTimestamp.make(joiner.apply(value1, value2), resultTimestamp); + } else { + return null; + } + } + + @Override + public void close() { + valueGetter1.close(); + valueGetter2.close(); + } + } + +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KTableKTableOuterJoin.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KTableKTableOuterJoin.java new file mode 100644 index 0000000..6b2017a --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KTableKTableOuterJoin.java @@ -0,0 +1,184 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.kstream.internals; + +import org.apache.kafka.common.metrics.Sensor; +import org.apache.kafka.streams.kstream.ValueJoiner; +import org.apache.kafka.streams.processor.To; +import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl; +import org.apache.kafka.streams.state.ValueAndTimestamp; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import static org.apache.kafka.streams.processor.internals.metrics.TaskMetrics.droppedRecordsSensor; +import static org.apache.kafka.streams.processor.internals.RecordQueue.UNKNOWN; +import static org.apache.kafka.streams.state.ValueAndTimestamp.getValueOrNull; + +@SuppressWarnings("deprecation") // Old PAPI. Needs to be migrated. +class KTableKTableOuterJoin extends KTableKTableAbstractJoin { + private static final Logger LOG = LoggerFactory.getLogger(KTableKTableOuterJoin.class); + + KTableKTableOuterJoin(final KTableImpl table1, + final KTableImpl table2, + final ValueJoiner joiner) { + super(table1, table2, joiner); + } + + @Override + public org.apache.kafka.streams.processor.Processor> get() { + return new KTableKTableOuterJoinProcessor(valueGetterSupplier2.get()); + } + + @Override + public KTableValueGetterSupplier view() { + return new KTableKTableOuterJoinValueGetterSupplier(valueGetterSupplier1, valueGetterSupplier2); + } + + private class KTableKTableOuterJoinValueGetterSupplier extends KTableKTableAbstractJoinValueGetterSupplier { + + KTableKTableOuterJoinValueGetterSupplier(final KTableValueGetterSupplier valueGetterSupplier1, + final KTableValueGetterSupplier valueGetterSupplier2) { + super(valueGetterSupplier1, valueGetterSupplier2); + } + + public KTableValueGetter get() { + return new KTableKTableOuterJoinValueGetter(valueGetterSupplier1.get(), valueGetterSupplier2.get()); + } + } + + private class KTableKTableOuterJoinProcessor extends org.apache.kafka.streams.processor.AbstractProcessor> { + + private final KTableValueGetter valueGetter; + private Sensor droppedRecordsSensor; + + KTableKTableOuterJoinProcessor(final KTableValueGetter valueGetter) { + this.valueGetter = valueGetter; + } + + @Override + public void init(final org.apache.kafka.streams.processor.ProcessorContext context) { + super.init(context); + droppedRecordsSensor = droppedRecordsSensor( + Thread.currentThread().getName(), + context.taskId().toString(), + (StreamsMetricsImpl) context.metrics() + ); + valueGetter.init(context); + } + + @Override + public void process(final K key, final Change change) { + // we do join iff keys are equal, thus, if key is null we cannot join and just ignore the record + if (key == null) { + LOG.warn( + "Skipping record due to null key. 
change=[{}] topic=[{}] partition=[{}] offset=[{}]", + change, context().topic(), context().partition(), context().offset() + ); + droppedRecordsSensor.record(); + return; + } + + R newValue = null; + final long resultTimestamp; + R oldValue = null; + + final ValueAndTimestamp valueAndTimestamp2 = valueGetter.get(key); + final V2 value2 = getValueOrNull(valueAndTimestamp2); + if (value2 == null) { + if (change.newValue == null && change.oldValue == null) { + return; + } + resultTimestamp = context().timestamp(); + } else { + resultTimestamp = Math.max(context().timestamp(), valueAndTimestamp2.timestamp()); + } + + if (value2 != null || change.newValue != null) { + newValue = joiner.apply(change.newValue, value2); + } + + if (sendOldValues && (value2 != null || change.oldValue != null)) { + oldValue = joiner.apply(change.oldValue, value2); + } + + context().forward(key, new Change<>(newValue, oldValue), To.all().withTimestamp(resultTimestamp)); + } + + @Override + public void close() { + valueGetter.close(); + } + } + + private class KTableKTableOuterJoinValueGetter implements KTableValueGetter { + + private final KTableValueGetter valueGetter1; + private final KTableValueGetter valueGetter2; + + KTableKTableOuterJoinValueGetter(final KTableValueGetter valueGetter1, + final KTableValueGetter valueGetter2) { + this.valueGetter1 = valueGetter1; + this.valueGetter2 = valueGetter2; + } + + @Override + public void init(final org.apache.kafka.streams.processor.ProcessorContext context) { + valueGetter1.init(context); + valueGetter2.init(context); + } + + @Override + public ValueAndTimestamp get(final K key) { + R newValue = null; + + final ValueAndTimestamp valueAndTimestamp1 = valueGetter1.get(key); + final V1 value1; + final long timestamp1; + if (valueAndTimestamp1 == null) { + value1 = null; + timestamp1 = UNKNOWN; + } else { + value1 = valueAndTimestamp1.value(); + timestamp1 = valueAndTimestamp1.timestamp(); + } + + final ValueAndTimestamp valueAndTimestamp2 = valueGetter2.get(key); + final V2 value2; + final long timestamp2; + if (valueAndTimestamp2 == null) { + value2 = null; + timestamp2 = UNKNOWN; + } else { + value2 = valueAndTimestamp2.value(); + timestamp2 = valueAndTimestamp2.timestamp(); + } + + if (value1 != null || value2 != null) { + newValue = joiner.apply(value1, value2); + } + + return ValueAndTimestamp.make(newValue, Math.max(timestamp1, timestamp2)); + } + + @Override + public void close() { + valueGetter1.close(); + valueGetter2.close(); + } + } + +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KTableKTableRightJoin.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KTableKTableRightJoin.java new file mode 100644 index 0000000..f948cfe --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KTableKTableRightJoin.java @@ -0,0 +1,167 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.kstream.internals; + +import org.apache.kafka.common.metrics.Sensor; +import org.apache.kafka.streams.kstream.ValueJoiner; +import org.apache.kafka.streams.processor.To; +import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl; +import org.apache.kafka.streams.state.ValueAndTimestamp; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import static org.apache.kafka.streams.processor.internals.metrics.TaskMetrics.droppedRecordsSensor; +import static org.apache.kafka.streams.state.ValueAndTimestamp.getValueOrNull; + +@SuppressWarnings("deprecation") // Old PAPI. Needs to be migrated. +class KTableKTableRightJoin extends KTableKTableAbstractJoin { + private static final Logger LOG = LoggerFactory.getLogger(KTableKTableRightJoin.class); + + KTableKTableRightJoin(final KTableImpl table1, + final KTableImpl table2, + final ValueJoiner joiner) { + super(table1, table2, joiner); + } + + @Override + public org.apache.kafka.streams.processor.Processor> get() { + return new KTableKTableRightJoinProcessor(valueGetterSupplier2.get()); + } + + @Override + public KTableValueGetterSupplier view() { + return new KTableKTableRightJoinValueGetterSupplier(valueGetterSupplier1, valueGetterSupplier2); + } + + private class KTableKTableRightJoinValueGetterSupplier extends KTableKTableAbstractJoinValueGetterSupplier { + + KTableKTableRightJoinValueGetterSupplier(final KTableValueGetterSupplier valueGetterSupplier1, + final KTableValueGetterSupplier valueGetterSupplier2) { + super(valueGetterSupplier1, valueGetterSupplier2); + } + + public KTableValueGetter get() { + return new KTableKTableRightJoinValueGetter(valueGetterSupplier1.get(), valueGetterSupplier2.get()); + } + } + + private class KTableKTableRightJoinProcessor extends org.apache.kafka.streams.processor.AbstractProcessor> { + + private final KTableValueGetter valueGetter; + private Sensor droppedRecordsSensor; + + KTableKTableRightJoinProcessor(final KTableValueGetter valueGetter) { + this.valueGetter = valueGetter; + } + + @Override + public void init(final org.apache.kafka.streams.processor.ProcessorContext context) { + super.init(context); + droppedRecordsSensor = droppedRecordsSensor( + Thread.currentThread().getName(), + context.taskId().toString(), + (StreamsMetricsImpl) context.metrics() + ); + valueGetter.init(context); + } + + @Override + public void process(final K key, final Change change) { + // we do join iff keys are equal, thus, if key is null we cannot join and just ignore the record + if (key == null) { + LOG.warn( + "Skipping record due to null key. 
change=[{}] topic=[{}] partition=[{}] offset=[{}]", + change, context().topic(), context().partition(), context().offset() + ); + droppedRecordsSensor.record(); + return; + } + + final R newValue; + final long resultTimestamp; + R oldValue = null; + + final ValueAndTimestamp valueAndTimestampLeft = valueGetter.get(key); + final V2 valueLeft = getValueOrNull(valueAndTimestampLeft); + if (valueLeft == null) { + return; + } + + resultTimestamp = Math.max(context().timestamp(), valueAndTimestampLeft.timestamp()); + + // joiner == "reverse joiner" + newValue = joiner.apply(change.newValue, valueLeft); + + if (sendOldValues) { + // joiner == "reverse joiner" + oldValue = joiner.apply(change.oldValue, valueLeft); + } + + context().forward(key, new Change<>(newValue, oldValue), To.all().withTimestamp(resultTimestamp)); + } + + @Override + public void close() { + valueGetter.close(); + } + } + + private class KTableKTableRightJoinValueGetter implements KTableValueGetter { + + private final KTableValueGetter valueGetter1; + private final KTableValueGetter valueGetter2; + + KTableKTableRightJoinValueGetter(final KTableValueGetter valueGetter1, + final KTableValueGetter valueGetter2) { + this.valueGetter1 = valueGetter1; + this.valueGetter2 = valueGetter2; + } + + @Override + public void init(final org.apache.kafka.streams.processor.ProcessorContext context) { + valueGetter1.init(context); + valueGetter2.init(context); + } + + @Override + public ValueAndTimestamp get(final K key) { + final ValueAndTimestamp valueAndTimestamp2 = valueGetter2.get(key); + final V2 value2 = getValueOrNull(valueAndTimestamp2); + + if (value2 != null) { + final ValueAndTimestamp valueAndTimestamp1 = valueGetter1.get(key); + final V1 value1 = getValueOrNull(valueAndTimestamp1); + final long resultTimestamp; + if (valueAndTimestamp1 == null) { + resultTimestamp = valueAndTimestamp2.timestamp(); + } else { + resultTimestamp = Math.max(valueAndTimestamp1.timestamp(), valueAndTimestamp2.timestamp()); + } + return ValueAndTimestamp.make(joiner.apply(value1, value2), resultTimestamp); + } else { + return null; + } + } + + @Override + public void close() { + valueGetter1.close(); + valueGetter2.close(); + } + } + +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KTableMapValues.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KTableMapValues.java new file mode 100644 index 0000000..221f986 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KTableMapValues.java @@ -0,0 +1,172 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.kstream.internals; + +import org.apache.kafka.streams.kstream.ValueMapperWithKey; +import org.apache.kafka.streams.processor.api.Processor; +import org.apache.kafka.streams.processor.api.ProcessorContext; +import org.apache.kafka.streams.processor.api.Record; +import org.apache.kafka.streams.state.TimestampedKeyValueStore; +import org.apache.kafka.streams.state.ValueAndTimestamp; + +import static org.apache.kafka.streams.state.ValueAndTimestamp.getValueOrNull; + + +class KTableMapValues implements KTableNewProcessorSupplier { + private final KTableImpl parent; + private final ValueMapperWithKey mapper; + private final String queryableName; + private boolean sendOldValues = false; + + KTableMapValues(final KTableImpl parent, + final ValueMapperWithKey mapper, + final String queryableName) { + this.parent = parent; + this.mapper = mapper; + this.queryableName = queryableName; + } + + @Override + public Processor, KIn, Change> get() { + return new KTableMapValuesProcessor(); + } + + @Override + public KTableValueGetterSupplier view() { + // if the KTable is materialized, use the materialized store to return getter value; + // otherwise rely on the parent getter and apply map-values on-the-fly + if (queryableName != null) { + return new KTableMaterializedValueGetterSupplier<>(queryableName); + } else { + return new KTableValueGetterSupplier() { + final KTableValueGetterSupplier parentValueGetterSupplier = parent.valueGetterSupplier(); + + public KTableValueGetter get() { + return new KTableMapValuesValueGetter(parentValueGetterSupplier.get()); + } + + @Override + public String[] storeNames() { + return parentValueGetterSupplier.storeNames(); + } + }; + } + } + + @Override + public boolean enableSendingOldValues(final boolean forceMaterialization) { + if (queryableName != null) { + sendOldValues = true; + return true; + } + + if (parent.enableSendingOldValues(forceMaterialization)) { + sendOldValues = true; + } + + return sendOldValues; + } + + private VOut computeValue(final KIn key, final VIn value) { + VOut newValue = null; + + if (value != null) { + newValue = mapper.apply(key, value); + } + + return newValue; + } + + private ValueAndTimestamp computeValueAndTimestamp(final KIn key, final ValueAndTimestamp valueAndTimestamp) { + VOut newValue = null; + long timestamp = 0; + + if (valueAndTimestamp != null) { + newValue = mapper.apply(key, valueAndTimestamp.value()); + timestamp = valueAndTimestamp.timestamp(); + } + + return ValueAndTimestamp.make(newValue, timestamp); + } + + + private class KTableMapValuesProcessor implements Processor, KIn, Change> { + private ProcessorContext> context; + private TimestampedKeyValueStore store; + private TimestampedTupleForwarder tupleForwarder; + + @Override + public void init(final ProcessorContext> context) { + this.context = context; + if (queryableName != null) { + store = context.getStateStore(queryableName); + tupleForwarder = new TimestampedTupleForwarder<>( + store, + context, + new TimestampedCacheFlushListener<>(context), + sendOldValues); + } + } + + @Override + public void process(final Record> record) { + final VOut newValue = computeValue(record.key(), record.value().newValue); + final VOut oldValue = computeOldValue(record.key(), record.value()); + + if (queryableName != null) { + store.put(record.key(), ValueAndTimestamp.make(newValue, record.timestamp())); + tupleForwarder.maybeForward(record.key(), newValue, oldValue); + } else { + context.forward(record.withValue(new Change<>(newValue, 
oldValue))); + } + } + + private VOut computeOldValue(final KIn key, final Change change) { + if (!sendOldValues) { + return null; + } + + return queryableName != null + ? getValueOrNull(store.get(key)) + : computeValue(key, change.oldValue); + } + } + + + private class KTableMapValuesValueGetter implements KTableValueGetter { + private final KTableValueGetter parentGetter; + + KTableMapValuesValueGetter(final KTableValueGetter parentGetter) { + this.parentGetter = parentGetter; + } + + @Override + public void init(final org.apache.kafka.streams.processor.ProcessorContext context) { + parentGetter.init(context); + } + + @Override + public ValueAndTimestamp get(final KIn key) { + return computeValueAndTimestamp(key, parentGetter.get(key)); + } + + @Override + public void close() { + parentGetter.close(); + } + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KTableMaterializedValueGetterSupplier.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KTableMaterializedValueGetterSupplier.java new file mode 100644 index 0000000..351e001 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KTableMaterializedValueGetterSupplier.java @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.kstream.internals; + +import org.apache.kafka.streams.processor.ProcessorContext; +import org.apache.kafka.streams.state.TimestampedKeyValueStore; +import org.apache.kafka.streams.state.ValueAndTimestamp; + +public class KTableMaterializedValueGetterSupplier implements KTableValueGetterSupplier { + private final String storeName; + + KTableMaterializedValueGetterSupplier(final String storeName) { + this.storeName = storeName; + } + + public KTableValueGetter get() { + return new KTableMaterializedValueGetter(); + } + + @Override + public String[] storeNames() { + return new String[]{storeName}; + } + + private class KTableMaterializedValueGetter implements KTableValueGetter { + private TimestampedKeyValueStore store; + + @Override + public void init(final ProcessorContext context) { + store = context.getStateStore(storeName); + } + + @Override + public ValueAndTimestamp get(final K key) { + return store.get(key); + } + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KTableNewProcessorSupplier.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KTableNewProcessorSupplier.java new file mode 100644 index 0000000..4cf195a --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KTableNewProcessorSupplier.java @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.kstream.internals; + +import org.apache.kafka.streams.processor.api.ProcessorSupplier; + +public interface KTableNewProcessorSupplier extends ProcessorSupplier, KOut, Change> { + + KTableValueGetterSupplier view(); + + /** + * Potentially enables sending old values. + *

<p> + * If {@code forceMaterialization} is {@code true}, the method will force the materialization of upstream nodes to + * enable sending old values. + * <p>
                + * If {@code forceMaterialization} is {@code false}, the method will only enable the sending of old values if + * an upstream node is already materialized. + * + * @param forceMaterialization indicates if an upstream node should be forced to materialize to enable sending old + * values. + * @return {@code true} if sending old values is enabled, i.e. either because {@code forceMaterialization} was + * {@code true} or some upstream node is materialized. + */ + boolean enableSendingOldValues(boolean forceMaterialization); +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KTablePassThrough.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KTablePassThrough.java new file mode 100644 index 0000000..6e60866 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KTablePassThrough.java @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.kstream.internals; + +import org.apache.kafka.streams.processor.api.Processor; +import org.apache.kafka.streams.processor.api.ProcessorContext; +import org.apache.kafka.streams.processor.api.Record; +import org.apache.kafka.streams.state.TimestampedKeyValueStore; +import org.apache.kafka.streams.state.ValueAndTimestamp; + +import java.util.Collection; + +public class KTablePassThrough implements KTableNewProcessorSupplier { + private final Collection parents; + private final String storeName; + + + KTablePassThrough(final Collection parents, final String storeName) { + this.parents = parents; + this.storeName = storeName; + } + + @Override + public Processor, KIn, Change> get() { + return new KTablePassThroughProcessor(); + } + + @Override + public boolean enableSendingOldValues(final boolean forceMaterialization) { + // Aggregation requires materialization so we will always enable sending old values + for (final KStreamAggProcessorSupplier parent : parents) { + parent.enableSendingOldValues(); + } + return true; + } + + @Override + public KTableValueGetterSupplier view() { + + return new KTableValueGetterSupplier() { + + public KTableValueGetter get() { + return new KTablePassThroughValueGetter(); + } + + @Override + public String[] storeNames() { + return new String[]{storeName}; + } + }; + } + + private class KTablePassThroughProcessor implements Processor, KIn, Change> { + private ProcessorContext> context; + + @Override + public void init(final ProcessorContext> context) { + this.context = context; + } + + @Override + public void process(final Record> record) { + context.forward(record); + } + } + + private class KTablePassThroughValueGetter implements KTableValueGetter { + private TimestampedKeyValueStore store; + + @Override + public void init(final 
org.apache.kafka.streams.processor.ProcessorContext context) { + store = context.getStateStore(storeName); + } + + @Override + public ValueAndTimestamp get(final KIn key) { + return store.get(key); + } + + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KTableProcessorSupplier.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KTableProcessorSupplier.java new file mode 100644 index 0000000..6f30dbb --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KTableProcessorSupplier.java @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.kstream.internals; + +@SuppressWarnings("deprecation") // Old PAPI. Needs to be migrated. +public interface KTableProcessorSupplier extends org.apache.kafka.streams.processor.ProcessorSupplier> { + + KTableValueGetterSupplier view(); + + /** + * Potentially enables sending old values. + *

<p> + * If {@code forceMaterialization} is {@code true}, the method will force the materialization of upstream nodes to + * enable sending old values. + * <p>
                + * If {@code forceMaterialization} is {@code false}, the method will only enable the sending of old values if + * an upstream node is already materialized. + * + * @param forceMaterialization indicates if an upstream node should be forced to materialize to enable sending old + * values. + * @return {@code true} is sending old values is enabled, i.e. either because {@code forceMaterialization} was + * {@code true} or some upstream node is materialized. + */ + boolean enableSendingOldValues(boolean forceMaterialization); +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KTableReduce.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KTableReduce.java new file mode 100644 index 0000000..d43f6df --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KTableReduce.java @@ -0,0 +1,118 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.kstream.internals; + +import org.apache.kafka.streams.errors.StreamsException; +import org.apache.kafka.streams.kstream.Reducer; +import org.apache.kafka.streams.processor.api.Processor; +import org.apache.kafka.streams.processor.api.ProcessorContext; +import org.apache.kafka.streams.processor.api.Record; +import org.apache.kafka.streams.state.TimestampedKeyValueStore; +import org.apache.kafka.streams.state.ValueAndTimestamp; + +import static org.apache.kafka.streams.state.ValueAndTimestamp.getValueOrNull; + +public class KTableReduce implements KTableNewProcessorSupplier { + + private final String storeName; + private final Reducer addReducer; + private final Reducer removeReducer; + + private boolean sendOldValues = false; + + KTableReduce(final String storeName, final Reducer addReducer, final Reducer removeReducer) { + this.storeName = storeName; + this.addReducer = addReducer; + this.removeReducer = removeReducer; + } + + @Override + public boolean enableSendingOldValues(final boolean forceMaterialization) { + // Reduce is always materialized: + sendOldValues = true; + return true; + } + + @Override + public Processor, K, Change> get() { + return new KTableReduceProcessor(); + } + + private class KTableReduceProcessor implements Processor, K, Change> { + + private TimestampedKeyValueStore store; + private TimestampedTupleForwarder tupleForwarder; + + @SuppressWarnings("unchecked") + @Override + public void init(final ProcessorContext> context) { + store = (TimestampedKeyValueStore) context.getStateStore(storeName); + tupleForwarder = new TimestampedTupleForwarder<>( + store, + context, + new TimestampedCacheFlushListener<>(context), + sendOldValues); + } + + /** + * @throws StreamsException if key is null + */ + @Override + public void process(final Record> record) { + // the keys 
should never be null + if (record.key() == null) { + throw new StreamsException("Record key for KTable reduce operator with state " + storeName + " should not be null."); + } + + final ValueAndTimestamp oldAggAndTimestamp = store.get(record.key()); + final V oldAgg = getValueOrNull(oldAggAndTimestamp); + final V intermediateAgg; + long newTimestamp; + + // first try to remove the old value + if (record.value().oldValue != null && oldAgg != null) { + intermediateAgg = removeReducer.apply(oldAgg, record.value().oldValue); + newTimestamp = Math.max(record.timestamp(), oldAggAndTimestamp.timestamp()); + } else { + intermediateAgg = oldAgg; + newTimestamp = record.timestamp(); + } + + // then try to add the new value + final V newAgg; + if (record.value().newValue != null) { + if (intermediateAgg == null) { + newAgg = record.value().newValue; + } else { + newAgg = addReducer.apply(intermediateAgg, record.value().newValue); + newTimestamp = Math.max(record.timestamp(), oldAggAndTimestamp.timestamp()); + } + } else { + newAgg = intermediateAgg; + } + + // update the store with the new value + store.put(record.key(), ValueAndTimestamp.make(newAgg, newTimestamp)); + tupleForwarder.maybeForward(record.key(), newAgg, sendOldValues ? oldAgg : null, newTimestamp); + } + } + + @Override + public KTableValueGetterSupplier view() { + return new KTableMaterializedValueGetterSupplier<>(storeName); + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KTableRepartitionMap.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KTableRepartitionMap.java new file mode 100644 index 0000000..6df6ce5 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KTableRepartitionMap.java @@ -0,0 +1,130 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.kstream.internals; + +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.errors.StreamsException; +import org.apache.kafka.streams.kstream.KeyValueMapper; +import org.apache.kafka.streams.state.ValueAndTimestamp; + +import static org.apache.kafka.streams.state.ValueAndTimestamp.getValueOrNull; + +/** + * KTable repartition map functions are not exposed to public APIs, but only used for keyed aggregations. + *
<p>
                + * Given the input, it can output at most two records (one mapped from old value and one mapped from new value). + */ +@SuppressWarnings("deprecation") // Old PAPI. Needs to be migrated. +public class KTableRepartitionMap implements KTableProcessorSupplier> { + + private final KTableImpl parent; + private final KeyValueMapper> mapper; + + KTableRepartitionMap(final KTableImpl parent, final KeyValueMapper> mapper) { + this.parent = parent; + this.mapper = mapper; + } + + @Override + public org.apache.kafka.streams.processor.Processor> get() { + return new KTableMapProcessor(); + } + + @Override + public KTableValueGetterSupplier> view() { + final KTableValueGetterSupplier parentValueGetterSupplier = parent.valueGetterSupplier(); + + return new KTableValueGetterSupplier>() { + + public KTableValueGetter> get() { + return new KTableMapValueGetter(parentValueGetterSupplier.get()); + } + + @Override + public String[] storeNames() { + throw new StreamsException("Underlying state store not accessible due to repartitioning."); + } + }; + } + + /** + * @throws IllegalStateException since this method should never be called + */ + @Override + public boolean enableSendingOldValues(final boolean forceMaterialization) { + // this should never be called + throw new IllegalStateException("KTableRepartitionMap should always require sending old values."); + } + + private class KTableMapProcessor extends org.apache.kafka.streams.processor.AbstractProcessor> { + + /** + * @throws StreamsException if key is null + */ + @Override + public void process(final K key, final Change change) { + // the original key should never be null + if (key == null) { + throw new StreamsException("Record key for the grouping KTable should not be null."); + } + + // if the value is null, we do not need to forward its selected key-value further + final KeyValue newPair = change.newValue == null ? null : mapper.apply(key, change.newValue); + final KeyValue oldPair = change.oldValue == null ? null : mapper.apply(key, change.oldValue); + + // if the selected repartition key or value is null, skip + // forward oldPair first, to be consistent with reduce and aggregate + if (oldPair != null && oldPair.key != null && oldPair.value != null) { + context().forward(oldPair.key, new Change<>(null, oldPair.value)); + } + + if (newPair != null && newPair.key != null && newPair.value != null) { + context().forward(newPair.key, new Change<>(newPair.value, null)); + } + + } + } + + private class KTableMapValueGetter implements KTableValueGetter> { + private final KTableValueGetter parentGetter; + private org.apache.kafka.streams.processor.ProcessorContext context; + + KTableMapValueGetter(final KTableValueGetter parentGetter) { + this.parentGetter = parentGetter; + } + + @Override + public void init(final org.apache.kafka.streams.processor.ProcessorContext context) { + this.context = context; + parentGetter.init(context); + } + + @Override + public ValueAndTimestamp> get(final K key) { + final ValueAndTimestamp valueAndTimestamp = parentGetter.get(key); + return ValueAndTimestamp.make( + mapper.apply(key, getValueOrNull(valueAndTimestamp)), + valueAndTimestamp == null ? 
context.timestamp() : valueAndTimestamp.timestamp()); + } + + @Override + public void close() { + parentGetter.close(); + } + } + +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KTableSource.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KTableSource.java new file mode 100644 index 0000000..f15780d --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KTableSource.java @@ -0,0 +1,156 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.kstream.internals; + +import static org.apache.kafka.streams.processor.internals.metrics.TaskMetrics.droppedRecordsSensor; + +import java.util.Objects; +import org.apache.kafka.common.metrics.Sensor; +import org.apache.kafka.streams.processor.api.Processor; +import org.apache.kafka.streams.processor.api.ProcessorContext; +import org.apache.kafka.streams.processor.api.ProcessorSupplier; +import org.apache.kafka.streams.processor.api.Record; +import org.apache.kafka.streams.processor.api.RecordMetadata; +import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl; +import org.apache.kafka.streams.state.TimestampedKeyValueStore; +import org.apache.kafka.streams.state.ValueAndTimestamp; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class KTableSource implements ProcessorSupplier> { + + private static final Logger LOG = LoggerFactory.getLogger(KTableSource.class); + + private final String storeName; + private String queryableName; + private boolean sendOldValues; + + public KTableSource(final String storeName, final String queryableName) { + Objects.requireNonNull(storeName, "storeName can't be null"); + + this.storeName = storeName; + this.queryableName = queryableName; + this.sendOldValues = false; + } + + public String queryableName() { + return queryableName; + } + + @Override + public Processor> get() { + return new KTableSourceProcessor(); + } + + // when source ktable requires sending old values, we just + // need to set the queryable name as the store name to enforce materialization + public void enableSendingOldValues() { + this.sendOldValues = true; + this.queryableName = storeName; + } + + // when the source ktable requires materialization from downstream, we just + // need to set the queryable name as the store name to enforce materialization + public void materialize() { + this.queryableName = storeName; + } + + public boolean materialized() { + return queryableName != null; + } + + private class KTableSourceProcessor implements Processor> { + + private ProcessorContext> context; + private TimestampedKeyValueStore store; + private TimestampedTupleForwarder tupleForwarder; + private Sensor droppedRecordsSensor; + + 
@SuppressWarnings("unchecked") + @Override + public void init(final ProcessorContext> context) { + this.context = context; + final StreamsMetricsImpl metrics = (StreamsMetricsImpl) context.metrics(); + droppedRecordsSensor = droppedRecordsSensor(Thread.currentThread().getName(), + context.taskId().toString(), metrics); + if (queryableName != null) { + store = context.getStateStore(queryableName); + tupleForwarder = new TimestampedTupleForwarder<>( + store, + context, + new TimestampedCacheFlushListener<>(context), + sendOldValues); + } + } + + @Override + public void process(final Record record) { + // if the key is null, then ignore the record + if (record.key() == null) { + if (context.recordMetadata().isPresent()) { + final RecordMetadata recordMetadata = context.recordMetadata().get(); + LOG.warn( + "Skipping record due to null key. " + + "topic=[{}] partition=[{}] offset=[{}]", + recordMetadata.topic(), recordMetadata.partition(), recordMetadata.offset() + ); + } else { + LOG.warn( + "Skipping record due to null key. Topic, partition, and offset not known." + ); + } + droppedRecordsSensor.record(); + return; + } + + if (queryableName != null) { + final ValueAndTimestamp oldValueAndTimestamp = store.get(record.key()); + final VIn oldValue; + if (oldValueAndTimestamp != null) { + oldValue = oldValueAndTimestamp.value(); + if (record.timestamp() < oldValueAndTimestamp.timestamp()) { + if (context.recordMetadata().isPresent()) { + final RecordMetadata recordMetadata = context.recordMetadata().get(); + LOG.warn( + "Detected out-of-order KTable update for {}, " + + "old timestamp=[{}] new timestamp=[{}]. " + + "topic=[{}] partition=[{}] offset=[{}].", + store.name(), + oldValueAndTimestamp.timestamp(), record.timestamp(), + recordMetadata.topic(), recordMetadata.offset(), recordMetadata.partition() + ); + } else { + LOG.warn( + "Detected out-of-order KTable update for {}, " + + "old timestamp=[{}] new timestamp=[{}]. " + + "Topic, partition and offset not known.", + store.name(), + oldValueAndTimestamp.timestamp(), record.timestamp() + ); + } + } + } else { + oldValue = null; + } + store.put(record.key(), ValueAndTimestamp.make(record.value(), record.timestamp())); + tupleForwarder.maybeForward(record.key(), record.value(), oldValue); + } else { + context.forward(record.withValue(new Change<>(record.value(), null))); + } + } + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KTableSourceValueGetterSupplier.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KTableSourceValueGetterSupplier.java new file mode 100644 index 0000000..740fbf6 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KTableSourceValueGetterSupplier.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.streams.kstream.internals; + +import org.apache.kafka.streams.processor.ProcessorContext; +import org.apache.kafka.streams.state.TimestampedKeyValueStore; +import org.apache.kafka.streams.state.ValueAndTimestamp; + +public class KTableSourceValueGetterSupplier implements KTableValueGetterSupplier { + private final String storeName; + + public KTableSourceValueGetterSupplier(final String storeName) { + this.storeName = storeName; + } + + public KTableValueGetter get() { + return new KTableSourceValueGetter(); + } + + @Override + public String[] storeNames() { + return new String[]{storeName}; + } + + private class KTableSourceValueGetter implements KTableValueGetter { + private TimestampedKeyValueStore store = null; + + public void init(final ProcessorContext context) { + store = context.getStateStore(storeName); + } + + public ValueAndTimestamp get(final K key) { + return store.get(key); + } + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KTableTransformValues.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KTableTransformValues.java new file mode 100644 index 0000000..94a1c0e --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KTableTransformValues.java @@ -0,0 +1,183 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.kstream.internals; + +import org.apache.kafka.common.header.internals.RecordHeaders; +import org.apache.kafka.streams.kstream.ValueTransformerWithKey; +import org.apache.kafka.streams.kstream.ValueTransformerWithKeySupplier; +import org.apache.kafka.streams.processor.internals.ForwardingDisabledProcessorContext; +import org.apache.kafka.streams.processor.internals.InternalProcessorContext; +import org.apache.kafka.streams.processor.internals.ProcessorRecordContext; +import org.apache.kafka.streams.state.TimestampedKeyValueStore; +import org.apache.kafka.streams.state.ValueAndTimestamp; + +import java.util.Objects; + +import static org.apache.kafka.streams.processor.internals.RecordQueue.UNKNOWN; +import static org.apache.kafka.streams.state.ValueAndTimestamp.getValueOrNull; + +@SuppressWarnings("deprecation") // Old PAPI. Needs to be migrated. 
+class KTableTransformValues implements KTableProcessorSupplier { + private final KTableImpl parent; + private final ValueTransformerWithKeySupplier transformerSupplier; + private final String queryableName; + private boolean sendOldValues = false; + + KTableTransformValues(final KTableImpl parent, + final ValueTransformerWithKeySupplier transformerSupplier, + final String queryableName) { + this.parent = Objects.requireNonNull(parent, "parent"); + this.transformerSupplier = Objects.requireNonNull(transformerSupplier, "transformerSupplier"); + this.queryableName = queryableName; + } + + @Override + public org.apache.kafka.streams.processor.Processor> get() { + return new KTableTransformValuesProcessor(transformerSupplier.get()); + } + + @Override + public KTableValueGetterSupplier view() { + if (queryableName != null) { + return new KTableMaterializedValueGetterSupplier<>(queryableName); + } + + return new KTableValueGetterSupplier() { + final KTableValueGetterSupplier parentValueGetterSupplier = parent.valueGetterSupplier(); + + public KTableValueGetter get() { + return new KTableTransformValuesGetter( + parentValueGetterSupplier.get(), + transformerSupplier.get()); + } + + @Override + public String[] storeNames() { + return parentValueGetterSupplier.storeNames(); + } + }; + } + + @Override + public boolean enableSendingOldValues(final boolean forceMaterialization) { + if (queryableName != null) { + sendOldValues = true; + return true; + } + + if (parent.enableSendingOldValues(forceMaterialization)) { + sendOldValues = true; + } + return sendOldValues; + } + + private class KTableTransformValuesProcessor extends org.apache.kafka.streams.processor.AbstractProcessor> { + private final ValueTransformerWithKey valueTransformer; + private TimestampedKeyValueStore store; + private TimestampedTupleForwarder tupleForwarder; + + private KTableTransformValuesProcessor(final ValueTransformerWithKey valueTransformer) { + this.valueTransformer = Objects.requireNonNull(valueTransformer, "valueTransformer"); + } + + @Override + public void init(final org.apache.kafka.streams.processor.ProcessorContext context) { + super.init(context); + valueTransformer.init(new ForwardingDisabledProcessorContext(context)); + if (queryableName != null) { + store = context.getStateStore(queryableName); + tupleForwarder = new TimestampedTupleForwarder<>( + store, + context, + new TimestampedCacheFlushListener<>(context), + sendOldValues); + } + } + + @Override + public void process(final K key, final Change change) { + final V1 newValue = valueTransformer.transform(key, change.newValue); + + if (queryableName == null) { + final V1 oldValue = sendOldValues ? valueTransformer.transform(key, change.oldValue) : null; + context().forward(key, new Change<>(newValue, oldValue)); + } else { + final V1 oldValue = sendOldValues ? 
getValueOrNull(store.get(key)) : null; + store.put(key, ValueAndTimestamp.make(newValue, context().timestamp())); + tupleForwarder.maybeForward(key, newValue, oldValue); + } + } + + @Override + public void close() { + valueTransformer.close(); + } + } + + + private class KTableTransformValuesGetter implements KTableValueGetter { + private final KTableValueGetter parentGetter; + private InternalProcessorContext internalProcessorContext; + private final ValueTransformerWithKey valueTransformer; + + KTableTransformValuesGetter(final KTableValueGetter parentGetter, + final ValueTransformerWithKey valueTransformer) { + this.parentGetter = Objects.requireNonNull(parentGetter, "parentGetter"); + this.valueTransformer = Objects.requireNonNull(valueTransformer, "valueTransformer"); + } + + @Override + public void init(final org.apache.kafka.streams.processor.ProcessorContext context) { + internalProcessorContext = (InternalProcessorContext) context; + parentGetter.init(context); + valueTransformer.init(new ForwardingDisabledProcessorContext(context)); + } + + @Override + public ValueAndTimestamp get(final K key) { + final ValueAndTimestamp valueAndTimestamp = parentGetter.get(key); + + final ProcessorRecordContext currentContext = internalProcessorContext.recordContext(); + + internalProcessorContext.setRecordContext(new ProcessorRecordContext( + valueAndTimestamp == null ? UNKNOWN : valueAndTimestamp.timestamp(), + -1L, // we don't know the original offset + // technically, we know the partition, but in the new `api.Processor` class, + // we move to `RecordMetadata` than would be `null` for this case and thus + // we won't have the partition information, so it's better to not provide it + // here either, to not introduce a regression later on + -1, + null, // we don't know the upstream input topic + new RecordHeaders() + )); + + final ValueAndTimestamp result = ValueAndTimestamp.make( + valueTransformer.transform(key, getValueOrNull(valueAndTimestamp)), + valueAndTimestamp == null ? UNKNOWN : valueAndTimestamp.timestamp()); + + internalProcessorContext.setRecordContext(currentContext); + + return result; + } + + @Override + public void close() { + parentGetter.close(); + valueTransformer.close(); + } + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KTableValueGetter.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KTableValueGetter.java new file mode 100644 index 0000000..12145fa --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KTableValueGetter.java @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
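// Editor's illustrative sketch (not part of this changeset): a stateless use of the public
// KTable#transformValues API that is backed by the KTableTransformValues processor above.
// The topic name and the upper-casing transformer are assumptions made up for the example.
import org.apache.kafka.streams.StreamsBuilder;
import org.apache.kafka.streams.kstream.KTable;
import org.apache.kafka.streams.kstream.ValueTransformerWithKey;
import org.apache.kafka.streams.processor.ProcessorContext;

public class TransformValuesSketch {
    public static void main(final String[] args) {
        final StreamsBuilder builder = new StreamsBuilder();
        final KTable<String, String> names = builder.table("names");
        final KTable<String, String> upper = names.transformValues(
            () -> new ValueTransformerWithKey<String, String, String>() {
                @Override
                public void init(final ProcessorContext context) { }

                @Override
                public String transform(final String readOnlyKey, final String value) {
                    return value == null ? null : value.toUpperCase();
                }

                @Override
                public void close() { }
            });
        upper.toStream().to("upper-names");
    }
}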
+ */ +package org.apache.kafka.streams.kstream.internals; + +import org.apache.kafka.streams.processor.ProcessorContext; +import org.apache.kafka.streams.state.ValueAndTimestamp; + +public interface KTableValueGetter { + + void init(ProcessorContext context); + + ValueAndTimestamp get(K key); + + default void close() {} +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KTableValueGetterSupplier.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KTableValueGetterSupplier.java new file mode 100644 index 0000000..aa28e9a --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/KTableValueGetterSupplier.java @@ -0,0 +1,24 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.kstream.internals; + +public interface KTableValueGetterSupplier { + + KTableValueGetter get(); + + String[] storeNames(); +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/MaterializedInternal.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/MaterializedInternal.java new file mode 100644 index 0000000..4a3cbb2 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/MaterializedInternal.java @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.kstream.internals; + +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.streams.kstream.Materialized; +import org.apache.kafka.streams.processor.StateStore; +import org.apache.kafka.streams.state.StoreSupplier; + +import java.time.Duration; +import java.util.Map; + +public class MaterializedInternal extends Materialized { + + private final boolean queryable; + + public MaterializedInternal(final Materialized materialized) { + this(materialized, null, null); + } + + public MaterializedInternal(final Materialized materialized, + final InternalNameProvider nameProvider, + final String generatedStorePrefix) { + super(materialized); + + // if storeName is not provided, the corresponding KTable would never be queryable; + // but we still need to provide an internal name for it in case we materialize. + queryable = storeName() != null; + if (!queryable && nameProvider != null) { + storeName = nameProvider.newStoreName(generatedStorePrefix); + } + } + + public String queryableStoreName() { + return queryable ? storeName() : null; + } + + public String storeName() { + if (storeSupplier != null) { + return storeSupplier.name(); + } + return storeName; + } + + public StoreSupplier storeSupplier() { + return storeSupplier; + } + + public Serde keySerde() { + return keySerde; + } + + public Serde valueSerde() { + return valueSerde; + } + + public boolean loggingEnabled() { + return loggingEnabled; + } + + Map logConfig() { + return topicConfig; + } + + public boolean cachingEnabled() { + return cachingEnabled; + } + + Duration retention() { + return retention; + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/NamedInternal.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/NamedInternal.java new file mode 100644 index 0000000..532928a --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/NamedInternal.java @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.kstream.internals; + +import org.apache.kafka.streams.kstream.Named; + +public class NamedInternal extends Named { + + public static NamedInternal empty() { + return new NamedInternal((String) null); + } + + public static NamedInternal with(final String name) { + return new NamedInternal(name); + } + + /** + * Creates a new {@link NamedInternal} instance. + * + * @param internal the internal name. + */ + NamedInternal(final Named internal) { + super(internal); + } + + /** + * Creates a new {@link NamedInternal} instance. + * + * @param internal the internal name. + */ + NamedInternal(final String internal) { + super(internal); + } + + /** + * @return a string name. 
+ */ + public String name() { + return name; + } + + @Override + public NamedInternal withName(final String name) { + return new NamedInternal(name); + } + + String suffixWithOrElseGet(final String suffix, final String other) { + if (name != null) { + return name + suffix; + } else { + return other; + } + } + + String suffixWithOrElseGet(final String suffix, final InternalNameProvider provider, final String prefix) { + // We actually do not need to generate processor names for operation if a name is specified. + // But before returning, we still need to burn index for the operation to keep topology backward compatibility. + if (name != null) { + provider.newProcessorName(prefix); + + final String suffixed = name + suffix; + // Re-validate generated name as suffixed string could be too large. + Named.validate(suffixed); + + return suffixed; + } else { + return provider.newProcessorName(prefix); + } + } + + String orElseGenerateWithPrefix(final InternalNameProvider provider, final String prefix) { + // We actually do not need to generate processor names for operation if a name is specified. + // But before returning, we still need to burn index for the operation to keep topology backward compatibility. + if (name != null) { + provider.newProcessorName(prefix); + return name; + } else { + return provider.newProcessorName(prefix); + } + } + +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/PassThrough.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/PassThrough.java new file mode 100644 index 0000000..f357a46 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/PassThrough.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
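// Editor's illustrative sketch (not part of this changeset): how the NamedInternal helpers
// above are fed from the public API. Supplying Named.as(...) fixes the processor name
// (while still burning a generated index for topology compatibility) instead of using a
// generated "KSTREAM-FILTER-0000000001"-style name. Topic names are assumptions.
import org.apache.kafka.streams.StreamsBuilder;
import org.apache.kafka.streams.kstream.KStream;
import org.apache.kafka.streams.kstream.Named;

public class NamedSketch {
    public static void main(final String[] args) {
        final StreamsBuilder builder = new StreamsBuilder();
        final KStream<String, String> input = builder.stream("input");
        input.filter((key, value) -> value != null, Named.as("drop-tombstones"))
             .to("output");
        // The fixed name shows up in the topology description.
        System.out.println(builder.build().describe());
    }
}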
+ */ +package org.apache.kafka.streams.kstream.internals; + +import org.apache.kafka.streams.processor.api.ContextualProcessor; +import org.apache.kafka.streams.processor.api.Processor; +import org.apache.kafka.streams.processor.api.ProcessorSupplier; +import org.apache.kafka.streams.processor.api.Record; + +class PassThrough implements ProcessorSupplier { + + @Override + public Processor get() { + return new PassThroughProcessor<>(); + } + + private static final class PassThroughProcessor extends ContextualProcessor { + @Override + public void process(final Record record) { + context().forward(record); + } + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/PrintForeachAction.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/PrintForeachAction.java new file mode 100644 index 0000000..861dfd3 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/PrintForeachAction.java @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.kstream.internals; + +import org.apache.kafka.streams.kstream.ForeachAction; +import org.apache.kafka.streams.kstream.KeyValueMapper; + +import java.io.OutputStream; +import java.io.OutputStreamWriter; +import java.io.PrintWriter; +import java.nio.charset.StandardCharsets; + +public class PrintForeachAction implements ForeachAction { + + private final String label; + private final PrintWriter printWriter; + private final boolean closable; + private final KeyValueMapper mapper; + + /** + * Print customized output with given writer. The {@link OutputStream} can be {@link System#out} or the others. + * + * @param outputStream The output stream to write to. + * @param mapper The mapper which can allow user to customize output will be printed. + * @param label The given name will be printed. 
+ */ + PrintForeachAction(final OutputStream outputStream, + final KeyValueMapper mapper, + final String label) { + this.printWriter = new PrintWriter(new OutputStreamWriter(outputStream, StandardCharsets.UTF_8)); + this.closable = outputStream != System.out && outputStream != System.err; + this.mapper = mapper; + this.label = label; + } + + @Override + public void apply(final K key, final V value) { + final String data = String.format("[%s]: %s", label, mapper.apply(key, value)); + printWriter.println(data); + if (!closable) { + printWriter.flush(); + } + } + + public void close() { + if (closable) { + printWriter.close(); + } else { + printWriter.flush(); + } + } + +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/PrintedInternal.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/PrintedInternal.java new file mode 100644 index 0000000..0cd1760 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/PrintedInternal.java @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.kstream.internals; + +import org.apache.kafka.streams.kstream.Printed; +import org.apache.kafka.streams.processor.api.ProcessorSupplier; + +public class PrintedInternal extends Printed { + public PrintedInternal(final Printed printed) { + super(printed); + } + + public ProcessorSupplier build(final String processorName) { + return new KStreamPrint<>(new PrintForeachAction<>(outputStream, mapper, label != null ? label : processorName)); + } + + public String name() { + return processorName; + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/ProducedInternal.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/ProducedInternal.java new file mode 100644 index 0000000..0f0620c --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/ProducedInternal.java @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
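// Editor's illustrative sketch (not part of this changeset): PrintForeachAction above backs
// the KStream#print operator. The label and the key-value mapper below are assumptions made
// up for the example.
import org.apache.kafka.streams.StreamsBuilder;
import org.apache.kafka.streams.kstream.KStream;
import org.apache.kafka.streams.kstream.Printed;

public class PrintSketch {
    public static void main(final String[] args) {
        final StreamsBuilder builder = new StreamsBuilder();
        final KStream<String, Long> counts = builder.stream("counts");
        // Prints lines of the form "[counts]: key=..., value=..."; because the target is
        // System.out, the action flushes after each record instead of closing the stream.
        counts.print(Printed.<String, Long>toSysOut()
                            .withLabel("counts")
                            .withKeyValueMapper((key, value) -> "key=" + key + ", value=" + value));
    }
}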
+ */ +package org.apache.kafka.streams.kstream.internals; + +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.streams.kstream.Produced; +import org.apache.kafka.streams.processor.StreamPartitioner; + +public class ProducedInternal extends Produced { + + public ProducedInternal(final Produced produced) { + super(produced); + } + + public Serde keySerde() { + return keySerde; + } + + public Serde valueSerde() { + return valueSerde; + } + + public StreamPartitioner streamPartitioner() { + return partitioner; + } + + public String name() { + return processorName; + } + +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/RepartitionedInternal.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/RepartitionedInternal.java new file mode 100644 index 0000000..bd66d73 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/RepartitionedInternal.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.kstream.internals; + +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.streams.kstream.Repartitioned; +import org.apache.kafka.streams.processor.StreamPartitioner; +import org.apache.kafka.streams.processor.internals.InternalTopicProperties; + +public class RepartitionedInternal extends Repartitioned { + + public RepartitionedInternal(final Repartitioned repartitioned) { + super(repartitioned); + } + + InternalTopicProperties toInternalTopicProperties() { + return new InternalTopicProperties(numberOfPartitions()); + } + + public String name() { + return name; + } + + public Serde keySerde() { + return keySerde; + } + + public Serde valueSerde() { + return valueSerde; + } + + public StreamPartitioner streamPartitioner() { + return partitioner; + } + + public Integer numberOfPartitions() { + return numberOfPartitions; + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/SessionCacheFlushListener.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/SessionCacheFlushListener.java new file mode 100644 index 0000000..a2c95bf --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/SessionCacheFlushListener.java @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
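// Editor's illustrative sketch (not part of this changeset): the RepartitionedInternal
// accessors above unwrap options supplied through the public Repartitioned config object.
// The repartition name, serdes, and partition count are assumptions made up for the example.
import org.apache.kafka.common.serialization.Serdes;
import org.apache.kafka.streams.StreamsBuilder;
import org.apache.kafka.streams.kstream.KStream;
import org.apache.kafka.streams.kstream.Repartitioned;

public class RepartitionSketch {
    public static void main(final String[] args) {
        final StreamsBuilder builder = new StreamsBuilder();
        final KStream<String, String> input = builder.stream("input");
        // Creates an internal "<application.id>-rekeyed-repartition" topic with 6 partitions.
        final KStream<String, String> rekeyed = input
            .selectKey((key, value) -> value)
            .repartition(Repartitioned.<String, String>as("rekeyed")
                                      .withKeySerde(Serdes.String())
                                      .withValueSerde(Serdes.String())
                                      .withNumberOfPartitions(6));
        rekeyed.to("output");
    }
}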
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.kstream.internals; + +import org.apache.kafka.streams.kstream.Windowed; +import org.apache.kafka.streams.processor.api.ProcessorContext; +import org.apache.kafka.streams.processor.api.Record; +import org.apache.kafka.streams.processor.internals.InternalProcessorContext; +import org.apache.kafka.streams.processor.internals.ProcessorNode; +import org.apache.kafka.streams.state.internals.CacheFlushListener; + +class SessionCacheFlushListener implements CacheFlushListener, VOut> { + private final InternalProcessorContext, Change> context; + + @SuppressWarnings("rawtypes") + private final ProcessorNode myNode; + + SessionCacheFlushListener(final ProcessorContext, Change> context) { + this.context = (InternalProcessorContext, Change>) context; + myNode = this.context.currentNode(); + } + + @Override + public void apply(final Windowed key, + final VOut newValue, + final VOut oldValue, + final long timestamp) { + @SuppressWarnings("rawtypes") final ProcessorNode prev = context.currentNode(); + context.setCurrentNode(myNode); + try { + context.forward(new Record<>(key, new Change<>(newValue, oldValue), key.window().end())); + } finally { + context.setCurrentNode(prev); + } + } + + @Override + public void apply(final Record, Change> record) { + @SuppressWarnings("rawtypes") final ProcessorNode prev = context.currentNode(); + context.setCurrentNode(myNode); + try { + context.forward(record.withTimestamp(record.key().window().end())); + } finally { + context.setCurrentNode(prev); + } + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/SessionTupleForwarder.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/SessionTupleForwarder.java new file mode 100644 index 0000000..ac475a4 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/SessionTupleForwarder.java @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
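// Editor's illustrative sketch (not part of this changeset): a session-windowed count whose
// cached results are flushed through SessionCacheFlushListener above; note that the listener
// forwards records with the session's window end as the timestamp. The topic name and the
// five-minute inactivity gap are assumptions; SessionWindows.with(...) is the older factory
// (newer releases also offer ofInactivityGapWithNoGrace).
import java.time.Duration;
import org.apache.kafka.streams.StreamsBuilder;
import org.apache.kafka.streams.kstream.KTable;
import org.apache.kafka.streams.kstream.SessionWindows;
import org.apache.kafka.streams.kstream.Windowed;

public class SessionCountSketch {
    public static void main(final String[] args) {
        final StreamsBuilder builder = new StreamsBuilder();
        final KTable<Windowed<String>, Long> counts = builder.<String, String>stream("clicks")
            .groupByKey()
            .windowedBy(SessionWindows.with(Duration.ofMinutes(5)))
            .count();
        counts.toStream().to("click-session-counts");
    }
}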
+ */ +package org.apache.kafka.streams.kstream.internals; + +import org.apache.kafka.streams.kstream.Windowed; +import org.apache.kafka.streams.processor.StateStore; +import org.apache.kafka.streams.processor.api.ProcessorContext; +import org.apache.kafka.streams.processor.api.Record; +import org.apache.kafka.streams.state.internals.CacheFlushListener; +import org.apache.kafka.streams.state.internals.WrappedStateStore; + +/** + * This class is used to determine if a processor should forward values to child nodes. + * Forwarding by this class only occurs when caching is not enabled. If caching is enabled, + * forwarding occurs in the flush listener when the cached store flushes. + * + * @param + * @param + */ +class SessionTupleForwarder { + private final ProcessorContext, Change> context; + private final boolean sendOldValues; + private final boolean cachingEnabled; + + @SuppressWarnings("unchecked") + SessionTupleForwarder(final StateStore store, + final ProcessorContext, Change> context, + final CacheFlushListener, V> flushListener, + final boolean sendOldValues) { + this.context = context; + this.sendOldValues = sendOldValues; + cachingEnabled = ((WrappedStateStore) store).setFlushListener(flushListener, sendOldValues); + } + + public void maybeForward(final Windowed key, + final V newValue, + final V oldValue) { + if (!cachingEnabled) { + context.forward(new Record<>( + key, + new Change<>(newValue, sendOldValues ? oldValue : null), + key.window().end())); + } + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/SessionWindow.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/SessionWindow.java new file mode 100644 index 0000000..3057e32 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/SessionWindow.java @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.kstream.internals; + +import org.apache.kafka.streams.kstream.Window; + +/** + * A session window covers a closed time interval with its start and end timestamp both being an inclusive boundary. + *
* <p>
                + * For time semantics, see {@link org.apache.kafka.streams.processor.TimestampExtractor TimestampExtractor}. + * + * @see TimeWindow + * @see UnlimitedWindow + * @see org.apache.kafka.streams.kstream.SessionWindows + * @see org.apache.kafka.streams.processor.TimestampExtractor + */ +public final class SessionWindow extends Window { + + /** + * Create a new window for the given start time and end time (both inclusive). + * + * @param startMs the start timestamp of the window + * @param endMs the end timestamp of the window + * @throws IllegalArgumentException if {@code startMs} is negative or if {@code endMs} is smaller than {@code startMs} + */ + public SessionWindow(final long startMs, final long endMs) throws IllegalArgumentException { + super(startMs, endMs); + } + + /** + * Check if the given window overlaps with this window. + * + * @param other another window + * @return {@code true} if {@code other} overlaps with this window—{@code false} otherwise + * @throws IllegalArgumentException if the {@code other} window has a different type than {@code this} window + */ + public boolean overlap(final Window other) throws IllegalArgumentException { + if (getClass() != other.getClass()) { + throw new IllegalArgumentException("Cannot compare windows of different type. Other window has type " + + other.getClass() + "."); + } + final SessionWindow otherWindow = (SessionWindow) other; + return !(otherWindow.endMs < startMs || endMs < otherWindow.startMs); + } + +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/SessionWindowedCogroupedKStreamImpl.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/SessionWindowedCogroupedKStreamImpl.java new file mode 100644 index 0000000..a78bcd3 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/SessionWindowedCogroupedKStreamImpl.java @@ -0,0 +1,147 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
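A small standalone illustration (not in the patch) of the inclusive-bound overlap rule implemented above. SessionWindow is an internal class; it is constructed directly here purely to demonstrate that two session windows overlap even when they only share a single boundary timestamp.

import org.apache.kafka.streams.kstream.internals.SessionWindow;

public class SessionWindowOverlapSketch {
    public static void main(final String[] args) {
        final SessionWindow first = new SessionWindow(10, 20);   // both bounds inclusive
        final SessionWindow second = new SessionWindow(20, 30);  // starts exactly at first's end

        // overlap() is !(other.end < start || end < other.start),
        // so touching at a single timestamp already counts as an overlap.
        System.out.println(first.overlap(second));                    // true
        System.out.println(first.overlap(new SessionWindow(21, 30))); // false
    }
}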
+ */ + +package org.apache.kafka.streams.kstream.internals; + +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.streams.kstream.Aggregator; +import org.apache.kafka.streams.kstream.Initializer; +import org.apache.kafka.streams.kstream.KTable; +import org.apache.kafka.streams.kstream.Materialized; +import org.apache.kafka.streams.kstream.Merger; +import org.apache.kafka.streams.kstream.Named; +import org.apache.kafka.streams.kstream.SessionWindowedCogroupedKStream; +import org.apache.kafka.streams.kstream.SessionWindows; +import org.apache.kafka.streams.kstream.Windowed; +import org.apache.kafka.streams.kstream.WindowedSerdes; +import org.apache.kafka.streams.kstream.internals.graph.GraphNode; +import org.apache.kafka.streams.state.SessionBytesStoreSupplier; +import org.apache.kafka.streams.state.SessionStore; +import org.apache.kafka.streams.state.StoreBuilder; +import org.apache.kafka.streams.state.Stores; + +import java.time.Duration; +import java.util.Map; +import java.util.Objects; +import java.util.Set; + +public class SessionWindowedCogroupedKStreamImpl extends + AbstractStream implements SessionWindowedCogroupedKStream { + + private final SessionWindows sessionWindows; + private final CogroupedStreamAggregateBuilder aggregateBuilder; + private final Map, Aggregator> groupPatterns; + + SessionWindowedCogroupedKStreamImpl(final SessionWindows sessionWindows, + final InternalStreamsBuilder builder, + final Set subTopologySourceNodes, + final String name, + final CogroupedStreamAggregateBuilder aggregateBuilder, + final GraphNode graphNode, + final Map, Aggregator> groupPatterns) { + super(name, null, null, subTopologySourceNodes, graphNode, builder); + //keySerde and valueSerde are null because there are many different groupStreams that they could be from + this.sessionWindows = sessionWindows; + this.aggregateBuilder = aggregateBuilder; + this.groupPatterns = groupPatterns; + } + + @Override + public KTable, V> aggregate(final Initializer initializer, + final Merger sessionMerger) { + return aggregate(initializer, sessionMerger, Materialized.with(null, null)); + } + + @Override + public KTable, V> aggregate(final Initializer initializer, + final Merger sessionMerger, + final Materialized> materialized) { + return aggregate(initializer, sessionMerger, NamedInternal.empty(), materialized); + } + + @Override + public KTable, V> aggregate(final Initializer initializer, + final Merger sessionMerger, final Named named) { + return aggregate(initializer, sessionMerger, named, Materialized.with(null, null)); + } + + @Override + public KTable, V> aggregate(final Initializer initializer, + final Merger sessionMerger, final Named named, + final Materialized> materialized) { + Objects.requireNonNull(initializer, "initializer can't be null"); + Objects.requireNonNull(sessionMerger, "sessionMerger can't be null"); + Objects.requireNonNull(materialized, "materialized can't be null"); + Objects.requireNonNull(named, "named can't be null"); + final MaterializedInternal> materializedInternal = new MaterializedInternal<>( + materialized, + builder, + CogroupedKStreamImpl.AGGREGATE_NAME); + return aggregateBuilder.build( + groupPatterns, + initializer, + new NamedInternal(named), + materialize(materializedInternal), + materializedInternal.keySerde() != null ? 
+ new WindowedSerdes.SessionWindowedSerde<>( + materializedInternal.keySerde()) : + null, + materializedInternal.valueSerde(), + materializedInternal.queryableStoreName(), + sessionWindows, + sessionMerger); + } + + private StoreBuilder> materialize(final MaterializedInternal> materialized) { + SessionBytesStoreSupplier supplier = (SessionBytesStoreSupplier) materialized.storeSupplier(); + if (supplier == null) { + final long retentionPeriod = materialized.retention() != null ? + materialized.retention().toMillis() : sessionWindows.inactivityGap() + sessionWindows.gracePeriodMs(); + + if ((sessionWindows.inactivityGap() + sessionWindows.gracePeriodMs()) > retentionPeriod) { + throw new IllegalArgumentException("The retention period of the session store " + + materialized.storeName() + + " must be no smaller than the session inactivity gap plus the" + + " grace period." + + " Got gap=[" + sessionWindows.inactivityGap() + "]," + + " grace=[" + sessionWindows.gracePeriodMs() + "]," + + " retention=[" + retentionPeriod + "]"); + } + supplier = Stores.persistentSessionStore( + materialized.storeName(), + Duration.ofMillis(retentionPeriod) + ); + } + final StoreBuilder> builder = Stores.sessionStoreBuilder( + supplier, + materialized.keySerde(), + materialized.valueSerde() + ); + + if (materialized.loggingEnabled()) { + builder.withLoggingEnabled(materialized.logConfig()); + } else { + builder.withLoggingDisabled(); + } + + if (materialized.cachingEnabled()) { + builder.withCachingEnabled(); + } + return builder; + } + +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/SessionWindowedKStreamImpl.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/SessionWindowedKStreamImpl.java new file mode 100644 index 0000000..fe9a3a1 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/SessionWindowedKStreamImpl.java @@ -0,0 +1,270 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
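An illustrative, hypothetical use of the cogrouped session aggregation wired up by the implementation above: two grouped streams contribute their own adders, are windowed by session, and are combined with a shared initializer plus a session merger. Topic names, serdes, and the gap/grace values are assumptions.

import java.time.Duration;

import org.apache.kafka.common.serialization.Serdes;
import org.apache.kafka.streams.StreamsBuilder;
import org.apache.kafka.streams.kstream.Aggregator;
import org.apache.kafka.streams.kstream.Consumed;
import org.apache.kafka.streams.kstream.Grouped;
import org.apache.kafka.streams.kstream.KGroupedStream;
import org.apache.kafka.streams.kstream.KTable;
import org.apache.kafka.streams.kstream.SessionWindows;
import org.apache.kafka.streams.kstream.Windowed;

public class SessionCogroupSketch {
    public static void main(final String[] args) {
        final StreamsBuilder builder = new StreamsBuilder();

        final KGroupedStream<String, String> views = builder
            .stream("page-views", Consumed.with(Serdes.String(), Serdes.String()))
            .groupByKey(Grouped.with(Serdes.String(), Serdes.String()));
        final KGroupedStream<String, String> clicks = builder
            .stream("clicks", Consumed.with(Serdes.String(), Serdes.String()))
            .groupByKey(Grouped.with(Serdes.String(), Serdes.String()));

        // Each cogrouped input supplies its own adder; merged sessions combine via the session merger.
        final Aggregator<String, String, Long> plusOne = (key, value, aggregate) -> aggregate + 1L;

        final KTable<Windowed<String>, Long> activity = views
            .cogroup(plusOne)
            .cogroup(clicks, plusOne)
            .windowedBy(SessionWindows.ofInactivityGapAndGrace(Duration.ofMinutes(5), Duration.ofMinutes(1)))
            .aggregate(() -> 0L, (key, aggOne, aggTwo) -> aggOne + aggTwo);

        builder.build();
    }
}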
+ */ +package org.apache.kafka.streams.kstream.internals; + +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.streams.kstream.Aggregator; +import org.apache.kafka.streams.kstream.Initializer; +import org.apache.kafka.streams.kstream.KTable; +import org.apache.kafka.streams.kstream.Materialized; +import org.apache.kafka.streams.kstream.Merger; +import org.apache.kafka.streams.kstream.Named; +import org.apache.kafka.streams.kstream.Reducer; +import org.apache.kafka.streams.kstream.SessionWindowedKStream; +import org.apache.kafka.streams.kstream.SessionWindows; +import org.apache.kafka.streams.kstream.Windowed; +import org.apache.kafka.streams.kstream.WindowedSerdes; +import org.apache.kafka.streams.kstream.internals.graph.GraphNode; +import org.apache.kafka.streams.state.SessionBytesStoreSupplier; +import org.apache.kafka.streams.state.SessionStore; +import org.apache.kafka.streams.state.StoreBuilder; +import org.apache.kafka.streams.state.Stores; + +import java.time.Duration; +import java.util.Objects; +import java.util.Set; + +import static org.apache.kafka.streams.kstream.internals.KGroupedStreamImpl.AGGREGATE_NAME; +import static org.apache.kafka.streams.kstream.internals.KGroupedStreamImpl.REDUCE_NAME; + +public class SessionWindowedKStreamImpl extends AbstractStream implements SessionWindowedKStream { + private final SessionWindows windows; + private final GroupedStreamAggregateBuilder aggregateBuilder; + private final Merger countMerger = (aggKey, aggOne, aggTwo) -> aggOne + aggTwo; + + SessionWindowedKStreamImpl(final SessionWindows windows, + final InternalStreamsBuilder builder, + final Set subTopologySourceNodes, + final String name, + final Serde keySerde, + final Serde valueSerde, + final GroupedStreamAggregateBuilder aggregateBuilder, + final GraphNode graphNode) { + super(name, keySerde, valueSerde, subTopologySourceNodes, graphNode, builder); + Objects.requireNonNull(windows, "windows can't be null"); + this.windows = windows; + this.aggregateBuilder = aggregateBuilder; + } + + @Override + public KTable, Long> count() { + return count(NamedInternal.empty()); + } + + @Override + public KTable, Long> count(final Named named) { + return doCount(named, Materialized.with(keySerde, Serdes.Long())); + } + + @Override + public KTable, Long> count(final Materialized> materialized) { + return count(NamedInternal.empty(), materialized); + } + + @Override + public KTable, Long> count(final Named named, final Materialized> materialized) { + Objects.requireNonNull(materialized, "materialized can't be null"); + + // TODO: remove this when we do a topology-incompatible release + // we used to burn a topology name here, so we have to keep doing it for compatibility + if (new MaterializedInternal<>(materialized).storeName() == null) { + builder.newStoreName(AGGREGATE_NAME); + } + + return doCount(named, materialized); + } + + private KTable, Long> doCount(final Named named, + final Materialized> materialized) { + final MaterializedInternal> materializedInternal = + new MaterializedInternal<>(materialized, builder, AGGREGATE_NAME); + + if (materializedInternal.keySerde() == null) { + materializedInternal.withKeySerde(keySerde); + } + if (materializedInternal.valueSerde() == null) { + materializedInternal.withValueSerde(Serdes.Long()); + } + + final String aggregateName = new NamedInternal(named).orElseGenerateWithPrefix(builder, AGGREGATE_NAME); + return 
aggregateBuilder.build( + new NamedInternal(aggregateName), + materialize(materializedInternal), + new KStreamSessionWindowAggregate<>( + windows, + materializedInternal.storeName(), + aggregateBuilder.countInitializer, + aggregateBuilder.countAggregator, + countMerger), + materializedInternal.queryableStoreName(), + materializedInternal.keySerde() != null ? new WindowedSerdes.SessionWindowedSerde<>(materializedInternal.keySerde()) : null, + materializedInternal.valueSerde()); + } + + @Override + public KTable, V> reduce(final Reducer reducer) { + return reduce(reducer, NamedInternal.empty()); + } + + @Override + public KTable, V> reduce(final Reducer reducer, final Named named) { + return reduce(reducer, named, Materialized.with(keySerde, valueSerde)); + } + + @Override + public KTable, V> reduce(final Reducer reducer, + final Materialized> materialized) { + return reduce(reducer, NamedInternal.empty(), materialized); + } + + @Override + public KTable, V> reduce(final Reducer reducer, + final Named named, + final Materialized> materialized) { + Objects.requireNonNull(reducer, "reducer can't be null"); + Objects.requireNonNull(named, "named can't be null"); + Objects.requireNonNull(materialized, "materialized can't be null"); + final Aggregator reduceAggregator = aggregatorForReducer(reducer); + final MaterializedInternal> materializedInternal = + new MaterializedInternal<>(materialized, builder, REDUCE_NAME); + if (materializedInternal.keySerde() == null) { + materializedInternal.withKeySerde(keySerde); + } + if (materializedInternal.valueSerde() == null) { + materializedInternal.withValueSerde(valueSerde); + } + + final String reduceName = new NamedInternal(named).orElseGenerateWithPrefix(builder, REDUCE_NAME); + return aggregateBuilder.build( + new NamedInternal(reduceName), + materialize(materializedInternal), + new KStreamSessionWindowAggregate<>( + windows, + materializedInternal.storeName(), + aggregateBuilder.reduceInitializer, + reduceAggregator, + mergerForAggregator(reduceAggregator) + ), + materializedInternal.queryableStoreName(), + materializedInternal.keySerde() != null ? 
new WindowedSerdes.SessionWindowedSerde<>(materializedInternal.keySerde()) : null, + materializedInternal.valueSerde()); + } + + @Override + public KTable, T> aggregate(final Initializer initializer, + final Aggregator aggregator, + final Merger sessionMerger) { + return aggregate(initializer, aggregator, sessionMerger, NamedInternal.empty()); + } + + @Override + public KTable, T> aggregate(final Initializer initializer, + final Aggregator aggregator, + final Merger sessionMerger, + final Named named) { + return aggregate(initializer, aggregator, sessionMerger, named, Materialized.with(keySerde, null)); + } + + @Override + public KTable, VR> aggregate(final Initializer initializer, + final Aggregator aggregator, + final Merger sessionMerger, + final Materialized> materialized) { + return aggregate(initializer, aggregator, sessionMerger, NamedInternal.empty(), materialized); + } + + @Override + public KTable, VR> aggregate(final Initializer initializer, + final Aggregator aggregator, + final Merger sessionMerger, + final Named named, + final Materialized> materialized) { + Objects.requireNonNull(initializer, "initializer can't be null"); + Objects.requireNonNull(aggregator, "aggregator can't be null"); + Objects.requireNonNull(sessionMerger, "sessionMerger can't be null"); + Objects.requireNonNull(materialized, "materialized can't be null"); + final MaterializedInternal> materializedInternal = + new MaterializedInternal<>(materialized, builder, AGGREGATE_NAME); + + if (materializedInternal.keySerde() == null) { + materializedInternal.withKeySerde(keySerde); + } + + final String aggregateName = new NamedInternal(named).orElseGenerateWithPrefix(builder, AGGREGATE_NAME); + + return aggregateBuilder.build( + new NamedInternal(aggregateName), + materialize(materializedInternal), + new KStreamSessionWindowAggregate<>( + windows, + materializedInternal.storeName(), + initializer, + aggregator, + sessionMerger), + materializedInternal.queryableStoreName(), + materializedInternal.keySerde() != null ? new WindowedSerdes.SessionWindowedSerde<>(materializedInternal.keySerde()) : null, + materializedInternal.valueSerde()); + } + + private StoreBuilder> materialize(final MaterializedInternal> materialized) { + SessionBytesStoreSupplier supplier = (SessionBytesStoreSupplier) materialized.storeSupplier(); + if (supplier == null) { + final long retentionPeriod = materialized.retention() != null ? + materialized.retention().toMillis() : windows.inactivityGap() + windows.gracePeriodMs(); + + if ((windows.inactivityGap() + windows.gracePeriodMs()) > retentionPeriod) { + throw new IllegalArgumentException("The retention period of the session store " + + materialized.storeName() + + " must be no smaller than the session inactivity gap plus the" + + " grace period." 
+ + " Got gap=[" + windows.inactivityGap() + "]," + + " grace=[" + windows.gracePeriodMs() + "]," + + " retention=[" + retentionPeriod + "]"); + } + supplier = Stores.persistentSessionStore( + materialized.storeName(), + Duration.ofMillis(retentionPeriod) + ); + } + final StoreBuilder> builder = Stores.sessionStoreBuilder( + supplier, + materialized.keySerde(), + materialized.valueSerde() + ); + + if (materialized.loggingEnabled()) { + builder.withLoggingEnabled(materialized.logConfig()); + } else { + builder.withLoggingDisabled(); + } + + if (materialized.cachingEnabled()) { + builder.withCachingEnabled(); + } + return builder; + } + + private Merger mergerForAggregator(final Aggregator aggregator) { + return (aggKey, aggOne, aggTwo) -> aggregator.apply(aggKey, aggTwo, aggOne); + } + + private Aggregator aggregatorForReducer(final Reducer reducer) { + return (aggKey, value, aggregate) -> aggregate == null ? value : reducer.apply(aggregate, value); + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/SlidingWindowedCogroupedKStreamImpl.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/SlidingWindowedCogroupedKStreamImpl.java new file mode 100644 index 0000000..a432b1f --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/SlidingWindowedCogroupedKStreamImpl.java @@ -0,0 +1,144 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.streams.kstream.internals; + +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.streams.kstream.Aggregator; +import org.apache.kafka.streams.kstream.Initializer; +import org.apache.kafka.streams.kstream.KTable; +import org.apache.kafka.streams.kstream.Materialized; +import org.apache.kafka.streams.kstream.Named; +import org.apache.kafka.streams.kstream.SlidingWindows; +import org.apache.kafka.streams.kstream.TimeWindowedCogroupedKStream; +import org.apache.kafka.streams.kstream.Windowed; +import org.apache.kafka.streams.kstream.internals.graph.GraphNode; +import org.apache.kafka.streams.state.StoreBuilder; +import org.apache.kafka.streams.state.Stores; +import org.apache.kafka.streams.state.TimestampedWindowStore; +import org.apache.kafka.streams.state.WindowBytesStoreSupplier; +import org.apache.kafka.streams.state.WindowStore; +import java.time.Duration; +import java.util.Map; +import java.util.Objects; +import java.util.Set; + +public class SlidingWindowedCogroupedKStreamImpl extends AbstractStream implements TimeWindowedCogroupedKStream { + private final SlidingWindows windows; + private final CogroupedStreamAggregateBuilder aggregateBuilder; + private final Map, Aggregator> groupPatterns; + + SlidingWindowedCogroupedKStreamImpl(final SlidingWindows windows, + final InternalStreamsBuilder builder, + final Set subTopologySourceNodes, + final String name, + final CogroupedStreamAggregateBuilder aggregateBuilder, + final GraphNode graphNode, + final Map, Aggregator> groupPatterns) { + super(name, null, null, subTopologySourceNodes, graphNode, builder); + //keySerde and valueSerde are null because there are many different groupStreams that they could be from + this.windows = windows; + this.aggregateBuilder = aggregateBuilder; + this.groupPatterns = groupPatterns; + } + + @Override + public KTable, V> aggregate(final Initializer initializer) { + return aggregate(initializer, Materialized.with(null, null)); + } + + @Override + public KTable, V> aggregate(final Initializer initializer, + final Materialized> materialized) { + return aggregate(initializer, NamedInternal.empty(), materialized); + } + + @Override + public KTable, V> aggregate(final Initializer initializer, + final Named named) { + return aggregate(initializer, named, Materialized.with(null, null)); + } + + @Override + public KTable, V> aggregate(final Initializer initializer, + final Named named, + final Materialized> materialized) { + Objects.requireNonNull(initializer, "initializer can't be null"); + Objects.requireNonNull(named, "named can't be null"); + Objects.requireNonNull(materialized, "materialized can't be null"); + final MaterializedInternal> materializedInternal = new MaterializedInternal<>( + materialized, + builder, + CogroupedKStreamImpl.AGGREGATE_NAME); + return aggregateBuilder.build( + groupPatterns, + initializer, + new NamedInternal(named), + materialize(materializedInternal), + materializedInternal.keySerde() != null ? + new FullTimeWindowedSerde<>(materializedInternal.keySerde(), windows.timeDifferenceMs()) + : null, + materializedInternal.valueSerde(), + materializedInternal.queryableStoreName(), + windows); + } + + private StoreBuilder> materialize(final MaterializedInternal> materialized) { + WindowBytesStoreSupplier supplier = (WindowBytesStoreSupplier) materialized.storeSupplier(); + if (supplier == null) { + final long retentionPeriod = materialized.retention() != null ? 
materialized.retention().toMillis() : windows.gracePeriodMs() + 2 * windows.timeDifferenceMs(); + + if ((windows.timeDifferenceMs() * 2 + windows.gracePeriodMs()) > retentionPeriod) { + throw new IllegalArgumentException("The retention period of the window store " + + name + + " must be no smaller than 2 * time difference plus the grace period." + + " Got time difference=[" + windows.timeDifferenceMs() + "]," + + " grace=[" + windows.gracePeriodMs() + + "]," + + " retention=[" + retentionPeriod + + "]"); + } + + supplier = Stores.persistentTimestampedWindowStore( + materialized.storeName(), + Duration.ofMillis(retentionPeriod), + Duration.ofMillis(windows.timeDifferenceMs()), + false + ); + + } + final StoreBuilder> builder = Stores + .timestampedWindowStoreBuilder( + supplier, + materialized.keySerde(), + materialized.valueSerde() + ); + + if (materialized.loggingEnabled()) { + builder.withLoggingEnabled(materialized.logConfig()); + } else { + builder.withLoggingDisabled(); + } + if (materialized.cachingEnabled()) { + builder.withCachingEnabled(); + } else { + builder.withCachingDisabled(); + } + return builder; + } + +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/SlidingWindowedKStreamImpl.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/SlidingWindowedKStreamImpl.java new file mode 100644 index 0000000..ddfe9ab --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/SlidingWindowedKStreamImpl.java @@ -0,0 +1,236 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
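Not part of the patch: a sketch of feeding the sliding-window cogroup path above. The default store retention it derives is grace + 2 * timeDifference (for a 10-minute time difference and 2-minute grace, 22 minutes), since a record may still need windows that start up to twice the time difference before its own timestamp. Topics, serdes, and values are assumptions; SlidingWindows.ofTimeDifferenceAndGrace assumes a version that provides that factory.

import java.time.Duration;

import org.apache.kafka.common.serialization.Serdes;
import org.apache.kafka.streams.StreamsBuilder;
import org.apache.kafka.streams.kstream.Aggregator;
import org.apache.kafka.streams.kstream.Consumed;
import org.apache.kafka.streams.kstream.Grouped;
import org.apache.kafka.streams.kstream.KGroupedStream;
import org.apache.kafka.streams.kstream.SlidingWindows;

public class SlidingCogroupSketch {
    public static void main(final String[] args) {
        final StreamsBuilder builder = new StreamsBuilder();

        final KGroupedStream<String, String> adds = builder
            .stream("cart-adds", Consumed.with(Serdes.String(), Serdes.String()))
            .groupByKey(Grouped.with(Serdes.String(), Serdes.String()));
        final KGroupedStream<String, String> removes = builder
            .stream("cart-removes", Consumed.with(Serdes.String(), Serdes.String()))
            .groupByKey(Grouped.with(Serdes.String(), Serdes.String()));

        final Aggregator<String, String, Long> plusOne = (key, value, aggregate) -> aggregate + 1L;
        final Aggregator<String, String, Long> minusOne = (key, value, aggregate) -> aggregate - 1L;

        adds.cogroup(plusOne)
            .cogroup(removes, minusOne)
            // timeDifference = 10 min, grace = 2 min => default store retention 2 + 2 * 10 = 22 min
            .windowedBy(SlidingWindows.ofTimeDifferenceAndGrace(Duration.ofMinutes(10), Duration.ofMinutes(2)))
            .aggregate(() -> 0L);

        builder.build();
    }
}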
+ */ +package org.apache.kafka.streams.kstream.internals; + +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.streams.kstream.Aggregator; +import org.apache.kafka.streams.kstream.Initializer; +import org.apache.kafka.streams.kstream.KTable; +import org.apache.kafka.streams.kstream.Materialized; +import org.apache.kafka.streams.kstream.Named; +import org.apache.kafka.streams.kstream.Reducer; +import org.apache.kafka.streams.kstream.SlidingWindows; +import org.apache.kafka.streams.kstream.TimeWindowedKStream; +import org.apache.kafka.streams.kstream.Windowed; +import org.apache.kafka.streams.kstream.internals.graph.GraphNode; +import org.apache.kafka.streams.state.StoreBuilder; +import org.apache.kafka.streams.state.Stores; +import org.apache.kafka.streams.state.TimestampedWindowStore; +import org.apache.kafka.streams.state.WindowBytesStoreSupplier; +import org.apache.kafka.streams.state.WindowStore; +import java.time.Duration; +import java.util.Objects; +import java.util.Set; +import static org.apache.kafka.streams.kstream.internals.KGroupedStreamImpl.AGGREGATE_NAME; +import static org.apache.kafka.streams.kstream.internals.KGroupedStreamImpl.REDUCE_NAME; + +public class SlidingWindowedKStreamImpl extends AbstractStream implements TimeWindowedKStream { + private final SlidingWindows windows; + private final GroupedStreamAggregateBuilder aggregateBuilder; + + SlidingWindowedKStreamImpl(final SlidingWindows windows, + final InternalStreamsBuilder builder, + final Set subTopologySourceNodes, + final String name, + final Serde keySerde, + final Serde valueSerde, + final GroupedStreamAggregateBuilder aggregateBuilder, + final GraphNode graphNode) { + super(name, keySerde, valueSerde, subTopologySourceNodes, graphNode, builder); + this.windows = Objects.requireNonNull(windows, "windows can't be null"); + this.aggregateBuilder = aggregateBuilder; + } + + @Override + public KTable, Long> count() { + return count(NamedInternal.empty()); + } + + @Override + public KTable, Long> count(final Named named) { + return doCount(named, Materialized.with(keySerde, Serdes.Long())); + } + + @Override + public KTable, Long> count(final Materialized> materialized) { + return count(NamedInternal.empty(), materialized); + } + + @Override + public KTable, Long> count(final Named named, final Materialized> materialized) { + Objects.requireNonNull(materialized, "materialized can't be null"); + return doCount(named, materialized); + } + + private KTable, Long> doCount(final Named named, + final Materialized> materialized) { + final MaterializedInternal> materializedInternal = + new MaterializedInternal<>(materialized, builder, AGGREGATE_NAME); + + if (materializedInternal.keySerde() == null) { + materializedInternal.withKeySerde(keySerde); + } + if (materializedInternal.valueSerde() == null) { + materializedInternal.withValueSerde(Serdes.Long()); + } + + final String aggregateName = new NamedInternal(named).orElseGenerateWithPrefix(builder, AGGREGATE_NAME); + + return aggregateBuilder.build( + new NamedInternal(aggregateName), + materialize(materializedInternal), + new KStreamSlidingWindowAggregate<>(windows, materializedInternal.storeName(), aggregateBuilder.countInitializer, aggregateBuilder.countAggregator), + materializedInternal.queryableStoreName(), + materializedInternal.keySerde() != null ? 
new FullTimeWindowedSerde<>(materializedInternal.keySerde(), windows.timeDifferenceMs()) : null, + materializedInternal.valueSerde()); + } + + @Override + public KTable, VR> aggregate(final Initializer initializer, + final Aggregator aggregator) { + return aggregate(initializer, aggregator, Materialized.with(keySerde, null)); + } + + @Override + public KTable, VR> aggregate(final Initializer initializer, + final Aggregator aggregator, + final Named named) { + return aggregate(initializer, aggregator, named, Materialized.with(keySerde, null)); + } + + @Override + public KTable, VR> aggregate(final Initializer initializer, + final Aggregator aggregator, + final Materialized> materialized) { + return aggregate(initializer, aggregator, NamedInternal.empty(), materialized); + } + + @Override + public KTable, VR> aggregate(final Initializer initializer, + final Aggregator aggregator, + final Named named, + final Materialized> materialized) { + Objects.requireNonNull(initializer, "initializer can't be null"); + Objects.requireNonNull(aggregator, "aggregator can't be null"); + Objects.requireNonNull(materialized, "materialized can't be null"); + final MaterializedInternal> materializedInternal = + new MaterializedInternal<>(materialized, builder, AGGREGATE_NAME); + if (materializedInternal.keySerde() == null) { + materializedInternal.withKeySerde(keySerde); + } + final String aggregateName = new NamedInternal(named).orElseGenerateWithPrefix(builder, AGGREGATE_NAME); + + return aggregateBuilder.build( + new NamedInternal(aggregateName), + materialize(materializedInternal), + new KStreamSlidingWindowAggregate<>(windows, materializedInternal.storeName(), initializer, aggregator), + materializedInternal.queryableStoreName(), + materializedInternal.keySerde() != null ? new FullTimeWindowedSerde<>(materializedInternal.keySerde(), windows.timeDifferenceMs()) : null, + materializedInternal.valueSerde()); + } + + @Override + public KTable, V> reduce(final Reducer reducer) { + return reduce(reducer, NamedInternal.empty()); + } + + @Override + public KTable, V> reduce(final Reducer reducer, final Named named) { + return reduce(reducer, named, Materialized.with(keySerde, valueSerde)); + } + + @Override + public KTable, V> reduce(final Reducer reducer, + final Materialized> materialized) { + return reduce(reducer, NamedInternal.empty(), materialized); + } + + @Override + public KTable, V> reduce(final Reducer reducer, + final Named named, + final Materialized> materialized) { + Objects.requireNonNull(reducer, "reducer can't be null"); + Objects.requireNonNull(named, "named can't be null"); + Objects.requireNonNull(materialized, "materialized can't be null"); + + final MaterializedInternal> materializedInternal = + new MaterializedInternal<>(materialized, builder, REDUCE_NAME); + + if (materializedInternal.keySerde() == null) { + materializedInternal.withKeySerde(keySerde); + } + if (materializedInternal.valueSerde() == null) { + materializedInternal.withValueSerde(valueSerde); + } + + final String reduceName = new NamedInternal(named).orElseGenerateWithPrefix(builder, REDUCE_NAME); + + return aggregateBuilder.build( + new NamedInternal(reduceName), + materialize(materializedInternal), + new KStreamSlidingWindowAggregate<>(windows, materializedInternal.storeName(), aggregateBuilder.reduceInitializer, aggregatorForReducer(reducer)), + materializedInternal.queryableStoreName(), + materializedInternal.keySerde() != null ? 
new FullTimeWindowedSerde<>(materializedInternal.keySerde(), windows.timeDifferenceMs()) : null, + materializedInternal.valueSerde()); + } + + private StoreBuilder> materialize(final MaterializedInternal> materialized) { + WindowBytesStoreSupplier supplier = (WindowBytesStoreSupplier) materialized.storeSupplier(); + if (supplier == null) { + final long retentionPeriod = materialized.retention() != null ? materialized.retention().toMillis() : windows.gracePeriodMs() + 2 * windows.timeDifferenceMs(); + + // large retention time to ensure that all existing windows needed to create new sliding windows can be accessed + // earliest window start time we could need to create corresponding right window would be recordTime - 2 * timeDifference + if ((windows.timeDifferenceMs() * 2 + windows.gracePeriodMs()) > retentionPeriod) { + throw new IllegalArgumentException("The retention period of the window store " + + name + " must be no smaller than 2 * time difference plus the grace period." + + " Got time difference=[" + windows.timeDifferenceMs() + "]," + + " grace=[" + windows.gracePeriodMs() + "]," + + " retention=[" + retentionPeriod + "]"); + } + supplier = Stores.persistentTimestampedWindowStore( + materialized.storeName(), + Duration.ofMillis(retentionPeriod), + Duration.ofMillis(windows.timeDifferenceMs()), + false + ); + } + final StoreBuilder> builder = Stores.timestampedWindowStoreBuilder( + supplier, + materialized.keySerde(), + materialized.valueSerde() + ); + + if (materialized.loggingEnabled()) { + builder.withLoggingEnabled(materialized.logConfig()); + } else { + builder.withLoggingDisabled(); + } + if (materialized.cachingEnabled()) { + builder.withCachingEnabled(); + } else { + builder.withCachingDisabled(); + } + return builder; + } + + private Aggregator aggregatorForReducer(final Reducer reducer) { + return (aggKey, value, aggregate) -> aggregate == null ? value : reducer.apply(aggregate, value); + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/StreamJoinedInternal.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/StreamJoinedInternal.java new file mode 100644 index 0000000..670783e --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/StreamJoinedInternal.java @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
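A hypothetical single-stream counterpart for the sliding-window aggregation above. The same constraint applies: any materialized retention must be at least 2 * timeDifference + grace, because a late record may need the "right" window of a neighbour that started up to twice the time difference earlier. Names and values are assumptions.

import java.time.Duration;

import org.apache.kafka.common.serialization.Serdes;
import org.apache.kafka.common.utils.Bytes;
import org.apache.kafka.streams.StreamsBuilder;
import org.apache.kafka.streams.kstream.Consumed;
import org.apache.kafka.streams.kstream.Grouped;
import org.apache.kafka.streams.kstream.Materialized;
import org.apache.kafka.streams.kstream.SlidingWindows;
import org.apache.kafka.streams.state.WindowStore;

public class SlidingCountSketch {
    public static void main(final String[] args) {
        final StreamsBuilder builder = new StreamsBuilder();
        builder.stream("clicks", Consumed.with(Serdes.String(), Serdes.String()))
               .groupByKey(Grouped.with(Serdes.String(), Serdes.String()))
               .windowedBy(SlidingWindows.ofTimeDifferenceAndGrace(Duration.ofMinutes(10), Duration.ofMinutes(2)))
               .count(Materialized.<String, Long, WindowStore<Bytes, byte[]>>as("sliding-click-counts")
                   // must be at least 2 * 10 + 2 = 22 minutes, or materialize() above rejects it
                   .withRetention(Duration.ofMinutes(30)));
        builder.build();
    }
}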
+ */ + +package org.apache.kafka.streams.kstream.internals; + +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.streams.kstream.StreamJoined; +import org.apache.kafka.streams.state.WindowBytesStoreSupplier; + +import java.util.Map; + +public class StreamJoinedInternal extends StreamJoined { + + //Needs to be public for testing + public StreamJoinedInternal(final StreamJoined streamJoined) { + super(streamJoined); + } + + public Serde keySerde() { + return keySerde; + } + + public Serde valueSerde() { + return valueSerde; + } + + public Serde otherValueSerde() { + return otherValueSerde; + } + + public String name() { + return name; + } + + public String storeName() { + return storeName; + } + + public WindowBytesStoreSupplier thisStoreSupplier() { + return thisStoreSupplier; + } + + public WindowBytesStoreSupplier otherStoreSupplier() { + return otherStoreSupplier; + } + + public boolean loggingEnabled() { + return loggingEnabled; + } + + Map logConfig() { + return topicConfig; + } + +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/TableJoinedInternal.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/TableJoinedInternal.java new file mode 100644 index 0000000..fe16552 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/TableJoinedInternal.java @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.kstream.internals; + +import org.apache.kafka.streams.kstream.TableJoined; +import org.apache.kafka.streams.processor.StreamPartitioner; + +public class TableJoinedInternal extends TableJoined { + + TableJoinedInternal(final TableJoined tableJoined) { + super(tableJoined); + } + + public StreamPartitioner partitioner() { + return partitioner; + } + + public StreamPartitioner otherPartitioner() { + return otherPartitioner; + } + + public String name() { + return name; + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/TimeWindow.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/TimeWindow.java new file mode 100644 index 0000000..4037d8f --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/TimeWindow.java @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.kstream.internals; + +import org.apache.kafka.streams.kstream.Window; + +/** + * A {@link TimeWindow} covers a half-open time interval with its start timestamp as an inclusive boundary and its end + * timestamp as exclusive boundary. + * It is a fixed size window, i.e., all instances (of a single {@link org.apache.kafka.streams.kstream.TimeWindows + * window specification}) will have the same size. + *
* <p>
                + * For time semantics, see {@link org.apache.kafka.streams.processor.TimestampExtractor TimestampExtractor}. + * + * @see SessionWindow + * @see UnlimitedWindow + * @see org.apache.kafka.streams.kstream.TimeWindows + * @see org.apache.kafka.streams.processor.TimestampExtractor + */ +public class TimeWindow extends Window { + + /** + * Create a new window for the given start time (inclusive) and end time (exclusive). + * + * @param startMs the start timestamp of the window (inclusive) + * @param endMs the end timestamp of the window (exclusive) + * @throws IllegalArgumentException if {@code startMs} is negative or if {@code endMs} is smaller than or equal to + * {@code startMs} + */ + public TimeWindow(final long startMs, final long endMs) throws IllegalArgumentException { + super(startMs, endMs); + if (startMs == endMs) { + throw new IllegalArgumentException("Window endMs must be greater than window startMs."); + } + } + + /** + * Check if the given window overlaps with this window. + * + * @param other another window + * @return {@code true} if {@code other} overlaps with this window—{@code false} otherwise + * @throws IllegalArgumentException if the {@code other} window has a different type than {@code this} window + */ + @Override + public boolean overlap(final Window other) throws IllegalArgumentException { + if (getClass() != other.getClass()) { + throw new IllegalArgumentException("Cannot compare windows of different type. Other window has type " + + other.getClass() + "."); + } + final TimeWindow otherWindow = (TimeWindow) other; + return startMs < otherWindow.endMs && otherWindow.startMs < endMs; + } + +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/TimeWindowedCogroupedKStreamImpl.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/TimeWindowedCogroupedKStreamImpl.java new file mode 100644 index 0000000..8cef89f --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/TimeWindowedCogroupedKStreamImpl.java @@ -0,0 +1,149 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
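A small standalone illustration (not in the patch) of the half-open interval semantics above, contrasting with SessionWindow: a TimeWindow ending at 20 does not overlap one starting at 20. TimeWindow is an internal class and is constructed directly only to demonstrate the rule.

import org.apache.kafka.streams.kstream.internals.TimeWindow;

public class TimeWindowOverlapSketch {
    public static void main(final String[] args) {
        final TimeWindow first = new TimeWindow(10, 20);   // [10, 20): end is exclusive
        final TimeWindow second = new TimeWindow(20, 30);  // [20, 30)

        // overlap() requires start < other.end && other.start < end, so adjacent windows do not overlap.
        System.out.println(first.overlap(second));                 // false
        System.out.println(first.overlap(new TimeWindow(19, 25))); // true
    }
}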
+ */ + +package org.apache.kafka.streams.kstream.internals; + +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.streams.kstream.Aggregator; +import org.apache.kafka.streams.kstream.Initializer; +import org.apache.kafka.streams.kstream.KTable; +import org.apache.kafka.streams.kstream.Materialized; +import org.apache.kafka.streams.kstream.Named; +import org.apache.kafka.streams.kstream.TimeWindowedCogroupedKStream; +import org.apache.kafka.streams.kstream.Window; +import org.apache.kafka.streams.kstream.Windowed; +import org.apache.kafka.streams.kstream.Windows; +import org.apache.kafka.streams.kstream.internals.graph.GraphNode; +import org.apache.kafka.streams.state.StoreBuilder; +import org.apache.kafka.streams.state.Stores; +import org.apache.kafka.streams.state.TimestampedWindowStore; +import org.apache.kafka.streams.state.WindowBytesStoreSupplier; +import org.apache.kafka.streams.state.WindowStore; + +import java.time.Duration; +import java.util.Map; +import java.util.Objects; +import java.util.Set; + +public class TimeWindowedCogroupedKStreamImpl extends AbstractStream + implements TimeWindowedCogroupedKStream { + + private final Windows windows; + private final CogroupedStreamAggregateBuilder aggregateBuilder; + private final Map, Aggregator> groupPatterns; + + TimeWindowedCogroupedKStreamImpl(final Windows windows, + final InternalStreamsBuilder builder, + final Set subTopologySourceNodes, + final String name, + final CogroupedStreamAggregateBuilder aggregateBuilder, + final GraphNode graphNode, + final Map, Aggregator> groupPatterns) { + super(name, null, null, subTopologySourceNodes, graphNode, builder); + //keySerde and valueSerde are null because there are many different groupStreams that they could be from + this.windows = windows; + this.aggregateBuilder = aggregateBuilder; + this.groupPatterns = groupPatterns; + } + + + @Override + public KTable, V> aggregate(final Initializer initializer) { + return aggregate(initializer, Materialized.with(null, null)); + } + + @Override + public KTable, V> aggregate(final Initializer initializer, + final Materialized> materialized) { + return aggregate(initializer, NamedInternal.empty(), materialized); + } + + @Override + public KTable, V> aggregate(final Initializer initializer, + final Named named) { + return aggregate(initializer, named, Materialized.with(null, null)); + } + + @Override + public KTable, V> aggregate(final Initializer initializer, + final Named named, + final Materialized> materialized) { + Objects.requireNonNull(initializer, "initializer can't be null"); + Objects.requireNonNull(named, "named can't be null"); + Objects.requireNonNull(materialized, "materialized can't be null"); + final MaterializedInternal> materializedInternal = new MaterializedInternal<>( + materialized, + builder, + CogroupedKStreamImpl.AGGREGATE_NAME); + return aggregateBuilder.build( + groupPatterns, + initializer, + new NamedInternal(named), + materialize(materializedInternal), + materializedInternal.keySerde() != null ? + new FullTimeWindowedSerde<>(materializedInternal.keySerde(), windows.size()) + : null, + materializedInternal.valueSerde(), + materializedInternal.queryableStoreName(), + windows); + } + + private StoreBuilder> materialize( + final MaterializedInternal> materialized) { + WindowBytesStoreSupplier supplier = (WindowBytesStoreSupplier) materialized.storeSupplier(); + if (supplier == null) { + final long retentionPeriod = materialized.retention() != null ? 
+ materialized.retention().toMillis() : windows.size() + windows.gracePeriodMs(); + + if ((windows.size() + windows.gracePeriodMs()) > retentionPeriod) { + throw new IllegalArgumentException("The retention period of the window store " + + name + + " must be no smaller than its window size plus the grace period." + + " Got size=[" + windows.size() + "]," + + " grace=[" + windows.gracePeriodMs() + + "]," + + " retention=[" + retentionPeriod + + "]"); + } + + supplier = Stores.persistentTimestampedWindowStore( + materialized.storeName(), + Duration.ofMillis(retentionPeriod), + Duration.ofMillis(windows.size()), + false + ); + } + + final StoreBuilder> builder = Stores + .timestampedWindowStoreBuilder( + supplier, + materialized.keySerde(), + materialized.valueSerde() + ); + + if (materialized.loggingEnabled()) { + builder.withLoggingEnabled(materialized.logConfig()); + } else { + builder.withLoggingDisabled(); + } + + if (materialized.cachingEnabled()) { + builder.withCachingEnabled(); + } + return builder; + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/TimeWindowedKStreamImpl.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/TimeWindowedKStreamImpl.java new file mode 100644 index 0000000..2282672 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/TimeWindowedKStreamImpl.java @@ -0,0 +1,256 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
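An illustrative, hypothetical driver for the fixed-size-window cogroup path above; its default store retention is window size plus grace (here 5 + 1 = 6 minutes). Topics, serdes, and values are assumptions; TimeWindows.ofSizeAndGrace assumes a version that provides that factory.

import java.time.Duration;

import org.apache.kafka.common.serialization.Serdes;
import org.apache.kafka.streams.StreamsBuilder;
import org.apache.kafka.streams.kstream.Aggregator;
import org.apache.kafka.streams.kstream.Consumed;
import org.apache.kafka.streams.kstream.Grouped;
import org.apache.kafka.streams.kstream.KGroupedStream;
import org.apache.kafka.streams.kstream.TimeWindows;

public class TimeWindowedCogroupSketch {
    public static void main(final String[] args) {
        final StreamsBuilder builder = new StreamsBuilder();

        final KGroupedStream<String, String> logins = builder
            .stream("logins", Consumed.with(Serdes.String(), Serdes.String()))
            .groupByKey(Grouped.with(Serdes.String(), Serdes.String()));
        final KGroupedStream<String, String> logouts = builder
            .stream("logouts", Consumed.with(Serdes.String(), Serdes.String()))
            .groupByKey(Grouped.with(Serdes.String(), Serdes.String()));

        final Aggregator<String, String, Long> plusOne = (key, value, aggregate) -> aggregate + 1L;

        logins.cogroup(plusOne)
              .cogroup(logouts, plusOne)
              // size = 5 min, grace = 1 min => default store retention 5 + 1 = 6 min
              .windowedBy(TimeWindows.ofSizeAndGrace(Duration.ofMinutes(5), Duration.ofMinutes(1)))
              .aggregate(() -> 0L);

        builder.build();
    }
}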
+ */ +package org.apache.kafka.streams.kstream.internals; + +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.streams.kstream.Aggregator; +import org.apache.kafka.streams.kstream.Initializer; +import org.apache.kafka.streams.kstream.KTable; +import org.apache.kafka.streams.kstream.Materialized; +import org.apache.kafka.streams.kstream.Named; +import org.apache.kafka.streams.kstream.Reducer; +import org.apache.kafka.streams.kstream.TimeWindowedKStream; +import org.apache.kafka.streams.kstream.Window; +import org.apache.kafka.streams.kstream.Windowed; +import org.apache.kafka.streams.kstream.Windows; +import org.apache.kafka.streams.kstream.internals.graph.GraphNode; +import org.apache.kafka.streams.state.StoreBuilder; +import org.apache.kafka.streams.state.Stores; +import org.apache.kafka.streams.state.TimestampedWindowStore; +import org.apache.kafka.streams.state.WindowBytesStoreSupplier; +import org.apache.kafka.streams.state.WindowStore; + +import java.time.Duration; +import java.util.Objects; +import java.util.Set; + +import static org.apache.kafka.streams.kstream.internals.KGroupedStreamImpl.AGGREGATE_NAME; +import static org.apache.kafka.streams.kstream.internals.KGroupedStreamImpl.REDUCE_NAME; + +public class TimeWindowedKStreamImpl extends AbstractStream implements TimeWindowedKStream { + + private final Windows windows; + private final GroupedStreamAggregateBuilder aggregateBuilder; + + TimeWindowedKStreamImpl(final Windows windows, + final InternalStreamsBuilder builder, + final Set subTopologySourceNodes, + final String name, + final Serde keySerde, + final Serde valueSerde, + final GroupedStreamAggregateBuilder aggregateBuilder, + final GraphNode graphNode) { + super(name, keySerde, valueSerde, subTopologySourceNodes, graphNode, builder); + this.windows = Objects.requireNonNull(windows, "windows can't be null"); + this.aggregateBuilder = aggregateBuilder; + } + + @Override + public KTable, Long> count() { + return count(NamedInternal.empty()); + } + + @Override + public KTable, Long> count(final Named named) { + return doCount(named, Materialized.with(keySerde, Serdes.Long())); + } + + + @Override + public KTable, Long> count(final Materialized> materialized) { + return count(NamedInternal.empty(), materialized); + } + + @Override + public KTable, Long> count(final Named named, final Materialized> materialized) { + Objects.requireNonNull(materialized, "materialized can't be null"); + + // TODO: remove this when we do a topology-incompatible release + // we used to burn a topology name here, so we have to keep doing it for compatibility + if (new MaterializedInternal<>(materialized).storeName() == null) { + builder.newStoreName(AGGREGATE_NAME); + } + + return doCount(named, materialized); + } + + private KTable, Long> doCount(final Named named, + final Materialized> materialized) { + final MaterializedInternal> materializedInternal = + new MaterializedInternal<>(materialized, builder, AGGREGATE_NAME); + + if (materializedInternal.keySerde() == null) { + materializedInternal.withKeySerde(keySerde); + } + if (materializedInternal.valueSerde() == null) { + materializedInternal.withValueSerde(Serdes.Long()); + } + + final String aggregateName = new NamedInternal(named).orElseGenerateWithPrefix(builder, AGGREGATE_NAME); + + return aggregateBuilder.build( + new NamedInternal(aggregateName), + materialize(materializedInternal), + new 
KStreamWindowAggregate<>(windows, materializedInternal.storeName(), aggregateBuilder.countInitializer, aggregateBuilder.countAggregator), + materializedInternal.queryableStoreName(), + materializedInternal.keySerde() != null ? new FullTimeWindowedSerde<>(materializedInternal.keySerde(), windows.size()) : null, + materializedInternal.valueSerde()); + + + } + + @Override + public KTable, VR> aggregate(final Initializer initializer, + final Aggregator aggregator) { + return aggregate(initializer, aggregator, Materialized.with(keySerde, null)); + } + + @Override + public KTable, VR> aggregate(final Initializer initializer, + final Aggregator aggregator, + final Named named) { + return aggregate(initializer, aggregator, named, Materialized.with(keySerde, null)); + } + + + @Override + public KTable, VR> aggregate(final Initializer initializer, + final Aggregator aggregator, + final Materialized> materialized) { + return aggregate(initializer, aggregator, NamedInternal.empty(), materialized); + } + + @Override + public KTable, VR> aggregate(final Initializer initializer, + final Aggregator aggregator, + final Named named, + final Materialized> materialized) { + Objects.requireNonNull(initializer, "initializer can't be null"); + Objects.requireNonNull(aggregator, "aggregator can't be null"); + Objects.requireNonNull(materialized, "materialized can't be null"); + final MaterializedInternal> materializedInternal = + new MaterializedInternal<>(materialized, builder, AGGREGATE_NAME); + if (materializedInternal.keySerde() == null) { + materializedInternal.withKeySerde(keySerde); + } + + final String aggregateName = new NamedInternal(named).orElseGenerateWithPrefix(builder, AGGREGATE_NAME); + + return aggregateBuilder.build( + new NamedInternal(aggregateName), + materialize(materializedInternal), + new KStreamWindowAggregate<>(windows, materializedInternal.storeName(), initializer, aggregator), + materializedInternal.queryableStoreName(), + materializedInternal.keySerde() != null ? 
new FullTimeWindowedSerde<>(materializedInternal.keySerde(), windows.size()) : null, + materializedInternal.valueSerde()); + + + } + + @Override + public KTable, V> reduce(final Reducer reducer) { + return reduce(reducer, NamedInternal.empty()); + } + + @Override + public KTable, V> reduce(final Reducer reducer, final Named named) { + return reduce(reducer, named, Materialized.with(keySerde, valueSerde)); + } + + @Override + public KTable, V> reduce(final Reducer reducer, + final Materialized> materialized) { + return reduce(reducer, NamedInternal.empty(), materialized); + } + + @Override + public KTable, V> reduce(final Reducer reducer, + final Named named, + final Materialized> materialized) { + Objects.requireNonNull(reducer, "reducer can't be null"); + Objects.requireNonNull(named, "named can't be null"); + Objects.requireNonNull(materialized, "materialized can't be null"); + + final MaterializedInternal> materializedInternal = + new MaterializedInternal<>(materialized, builder, REDUCE_NAME); + + if (materializedInternal.keySerde() == null) { + materializedInternal.withKeySerde(keySerde); + } + if (materializedInternal.valueSerde() == null) { + materializedInternal.withValueSerde(valueSerde); + } + + final String reduceName = new NamedInternal(named).orElseGenerateWithPrefix(builder, REDUCE_NAME); + + return aggregateBuilder.build( + new NamedInternal(reduceName), + materialize(materializedInternal), + new KStreamWindowAggregate<>(windows, materializedInternal.storeName(), aggregateBuilder.reduceInitializer, aggregatorForReducer(reducer)), + materializedInternal.queryableStoreName(), + materializedInternal.keySerde() != null ? new FullTimeWindowedSerde<>(materializedInternal.keySerde(), windows.size()) : null, + materializedInternal.valueSerde()); + + + } + + private StoreBuilder> materialize(final MaterializedInternal> materialized) { + WindowBytesStoreSupplier supplier = (WindowBytesStoreSupplier) materialized.storeSupplier(); + if (supplier == null) { + final long retentionPeriod = materialized.retention() != null ? + materialized.retention().toMillis() : windows.size() + windows.gracePeriodMs(); + + if ((windows.size() + windows.gracePeriodMs()) > retentionPeriod) { + throw new IllegalArgumentException("The retention period of the window store " + + name + " must be no smaller than its window size plus the grace period." + + " Got size=[" + windows.size() + "]," + + " grace=[" + windows.gracePeriodMs() + "]," + + " retention=[" + retentionPeriod + "]"); + } + + supplier = Stores.persistentTimestampedWindowStore( + materialized.storeName(), + Duration.ofMillis(retentionPeriod), + Duration.ofMillis(windows.size()), + false + ); + } + + final StoreBuilder> builder = Stores.timestampedWindowStoreBuilder( + supplier, + materialized.keySerde(), + materialized.valueSerde() + ); + + if (materialized.loggingEnabled()) { + builder.withLoggingEnabled(materialized.logConfig()); + } else { + builder.withLoggingDisabled(); + } + + if (materialized.cachingEnabled()) { + builder.withCachingEnabled(); + } + return builder; + } + + private Aggregator aggregatorForReducer(final Reducer reducer) { + return (aggKey, value, aggregate) -> aggregate == null ? 
value : reducer.apply(aggregate, value); + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/TimestampedCacheFlushListener.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/TimestampedCacheFlushListener.java new file mode 100644 index 0000000..4034414 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/TimestampedCacheFlushListener.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.kstream.internals; + +import org.apache.kafka.streams.processor.To; +import org.apache.kafka.streams.processor.api.ProcessorContext; +import org.apache.kafka.streams.processor.api.Record; +import org.apache.kafka.streams.processor.internals.InternalProcessorContext; +import org.apache.kafka.streams.processor.internals.ProcessorNode; +import org.apache.kafka.streams.state.ValueAndTimestamp; +import org.apache.kafka.streams.state.internals.CacheFlushListener; + +import static org.apache.kafka.streams.state.ValueAndTimestamp.getValueOrNull; + +class TimestampedCacheFlushListener implements CacheFlushListener> { + private final InternalProcessorContext> context; + + @SuppressWarnings("rawtypes") + private final ProcessorNode myNode; + + TimestampedCacheFlushListener(final ProcessorContext> context) { + this.context = (InternalProcessorContext>) context; + myNode = this.context.currentNode(); + } + + @SuppressWarnings("unchecked") + TimestampedCacheFlushListener(final org.apache.kafka.streams.processor.ProcessorContext context) { + this.context = (InternalProcessorContext>) context; + myNode = this.context.currentNode(); + } + + @Override + public void apply(final KOut key, + final ValueAndTimestamp newValue, + final ValueAndTimestamp oldValue, + final long timestamp) { + final ProcessorNode prev = context.currentNode(); + context.setCurrentNode(myNode); + try { + context.forward( + key, + new Change<>(getValueOrNull(newValue), getValueOrNull(oldValue)), + To.all().withTimestamp(newValue != null ? 
newValue.timestamp() : timestamp)); + } finally { + context.setCurrentNode(prev); + } + } + + @Override + public void apply(final Record>> record) { + @SuppressWarnings("rawtypes") final ProcessorNode prev = context.currentNode(); + context.setCurrentNode(myNode); + try { + context.forward( + record.withValue( + new Change<>( + getValueOrNull(record.value().newValue), + getValueOrNull(record.value().oldValue) + ) + ) + ); + } finally { + context.setCurrentNode(prev); + } + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/TimestampedKeyValueStoreMaterializer.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/TimestampedKeyValueStoreMaterializer.java new file mode 100644 index 0000000..fb40b46 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/TimestampedKeyValueStoreMaterializer.java @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.kstream.internals; + +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.streams.state.KeyValueBytesStoreSupplier; +import org.apache.kafka.streams.state.KeyValueStore; +import org.apache.kafka.streams.state.StoreBuilder; +import org.apache.kafka.streams.state.Stores; +import org.apache.kafka.streams.state.TimestampedKeyValueStore; + +public class TimestampedKeyValueStoreMaterializer { + private final MaterializedInternal> materialized; + + public TimestampedKeyValueStoreMaterializer(final MaterializedInternal> materialized) { + this.materialized = materialized; + } + + /** + * @return StoreBuilder + */ + public StoreBuilder> materialize() { + KeyValueBytesStoreSupplier supplier = (KeyValueBytesStoreSupplier) materialized.storeSupplier(); + if (supplier == null) { + final String name = materialized.storeName(); + supplier = Stores.persistentTimestampedKeyValueStore(name); + } + final StoreBuilder> builder = Stores.timestampedKeyValueStoreBuilder( + supplier, + materialized.keySerde(), + materialized.valueSerde()); + + if (materialized.loggingEnabled()) { + builder.withLoggingEnabled(materialized.logConfig()); + } else { + builder.withLoggingDisabled(); + } + + if (materialized.cachingEnabled()) { + builder.withCachingEnabled(); + } + return builder; + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/TimestampedTupleForwarder.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/TimestampedTupleForwarder.java new file mode 100644 index 0000000..6411b35 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/TimestampedTupleForwarder.java @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
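The windowed-aggregation materialize() logic above requires the window store's retention to be at least the window size plus the grace period. Below is a minimal DSL sketch of how that constraint is supplied from user code; the "clicks" topic, "click-counts" store name, and serdes are illustrative only and not part of the patch.

import java.time.Duration;

import org.apache.kafka.common.serialization.Serdes;
import org.apache.kafka.common.utils.Bytes;
import org.apache.kafka.streams.StreamsBuilder;
import org.apache.kafka.streams.kstream.Consumed;
import org.apache.kafka.streams.kstream.Materialized;
import org.apache.kafka.streams.kstream.TimeWindows;
import org.apache.kafka.streams.state.WindowStore;

public class WindowedCountRetentionExample {
    public static void main(final String[] args) {
        final StreamsBuilder builder = new StreamsBuilder();
        builder.stream("clicks", Consumed.with(Serdes.String(), Serdes.Long()))
               .groupByKey()
               .windowedBy(TimeWindows.of(Duration.ofMinutes(5)).grace(Duration.ofMinutes(1)))
               // retention must be >= size (5 min) + grace (1 min); smaller values make
               // the materialize() check above throw an IllegalArgumentException
               .count(Materialized.<String, Long, WindowStore<Bytes, byte[]>>as("click-counts")
                                  .withRetention(Duration.ofMinutes(10)));
        builder.build();
    }
}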
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.kstream.internals; + +import org.apache.kafka.streams.processor.api.ProcessorContext; +import org.apache.kafka.streams.processor.api.Record; +import org.apache.kafka.streams.processor.StateStore; +import org.apache.kafka.streams.processor.To; +import org.apache.kafka.streams.processor.internals.InternalProcessorContext; +import org.apache.kafka.streams.state.internals.WrappedStateStore; + +/** + * This class is used to determine if a processor should forward values to child nodes. + * Forwarding by this class only occurs when caching is not enabled. If caching is enabled, + * forwarding occurs in the flush listener when the cached store flushes. + * + * @param the type of the key + * @param the type of the value + */ +class TimestampedTupleForwarder { + private final InternalProcessorContext> context; + private final boolean sendOldValues; + private final boolean cachingEnabled; + + @SuppressWarnings({"unchecked", "rawtypes"}) + TimestampedTupleForwarder(final StateStore store, + final ProcessorContext> context, + final TimestampedCacheFlushListener flushListener, + final boolean sendOldValues) { + this.context = (InternalProcessorContext>) context; + this.sendOldValues = sendOldValues; + cachingEnabled = ((WrappedStateStore) store).setFlushListener(flushListener, sendOldValues); + } + + @SuppressWarnings({"unchecked", "rawtypes"}) + TimestampedTupleForwarder(final StateStore store, + final org.apache.kafka.streams.processor.ProcessorContext context, + final TimestampedCacheFlushListener flushListener, + final boolean sendOldValues) { + this.context = (InternalProcessorContext) context; + this.sendOldValues = sendOldValues; + cachingEnabled = ((WrappedStateStore) store).setFlushListener(flushListener, sendOldValues); + } + + public void maybeForward(final Record> record) { + if (!cachingEnabled) { + if (sendOldValues) { + context.forward(record); + } else { + context.forward(record.withValue(new Change<>(record.value().newValue, null))); + } + } + } + + public void maybeForward(final K key, + final V newValue, + final V oldValue) { + if (!cachingEnabled) { + context.forward(key, new Change<>(newValue, sendOldValues ? oldValue : null)); + } + } + + public void maybeForward(final K key, + final V newValue, + final V oldValue, + final long timestamp) { + if (!cachingEnabled) { + context.forward(key, new Change<>(newValue, sendOldValues ? 
oldValue : null), To.all().withTimestamp(timestamp)); + } + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/TransformerSupplierAdapter.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/TransformerSupplierAdapter.java new file mode 100644 index 0000000..93d2f55 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/TransformerSupplierAdapter.java @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.kstream.internals; + +import java.util.Collections; +import java.util.Set; + +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.kstream.Transformer; +import org.apache.kafka.streams.kstream.TransformerSupplier; +import org.apache.kafka.streams.processor.ProcessorContext; +import org.apache.kafka.streams.state.StoreBuilder; + +public class TransformerSupplierAdapter implements TransformerSupplier>> { + + private final TransformerSupplier> transformerSupplier; + + public TransformerSupplierAdapter(final TransformerSupplier> transformerSupplier) { + this.transformerSupplier = transformerSupplier; + } + + @Override + public Transformer>> get() { + return new Transformer>>() { + + private final Transformer> transformer = transformerSupplier.get(); + + @Override + public void init(final ProcessorContext context) { + transformer.init(context); + } + + @Override + public Iterable> transform(final KIn key, final VIn value) { + final KeyValue pair = transformer.transform(key, value); + if (pair != null) { + return Collections.singletonList(pair); + } + return Collections.emptyList(); + } + + @Override + public void close() { + transformer.close(); + } + }; + } + + @Override + public Set> stores() { + return transformerSupplier.stores(); + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/UnlimitedWindow.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/UnlimitedWindow.java new file mode 100644 index 0000000..9b29b9e --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/UnlimitedWindow.java @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
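TransformerSupplierAdapter above widens a Transformer that returns a single, possibly null KeyValue into the Iterable contract used by flatTransform, where null means "emit nothing". A tiny standalone sketch of that convention (class and method names are illustrative, not part of the patch):

import java.util.Collections;

import org.apache.kafka.streams.KeyValue;

public final class SingleToIterable {
    private SingleToIterable() { }

    // mirrors the adapter: null from the wrapped transformer drops the record,
    // a non-null pair is emitted exactly once
    public static <K, V> Iterable<KeyValue<K, V>> wrap(final KeyValue<K, V> pair) {
        if (pair == null) {
            return Collections.emptyList();
        }
        return Collections.singletonList(pair);
    }
}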
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.kstream.internals; + +import org.apache.kafka.streams.kstream.Window; + +/** + * {@link UnlimitedWindow} is an "infinite" large window with a fixed (inclusive) start time. + * All windows of the same {@link org.apache.kafka.streams.kstream.UnlimitedWindows window specification} will have the + * same start time. + * To make the window size "infinite" end time is set to {@link Long#MAX_VALUE}. + *
* <p>
                + * For time semantics, see {@link org.apache.kafka.streams.processor.TimestampExtractor TimestampExtractor}. + * + * @see TimeWindow + * @see SessionWindow + * @see org.apache.kafka.streams.kstream.UnlimitedWindows + * @see org.apache.kafka.streams.processor.TimestampExtractor + */ +public class UnlimitedWindow extends Window { + + /** + * Create a new window for the given start time (inclusive). + * + * @param startMs the start timestamp of the window (inclusive) + * @throws IllegalArgumentException if {@code start} is negative + */ + public UnlimitedWindow(final long startMs) { + super(startMs, Long.MAX_VALUE); + } + + /** + * Returns {@code true} if the given window is of the same type, because all unlimited windows overlap with each + * other due to their infinite size. + * + * @param other another window + * @return {@code true} + * @throws IllegalArgumentException if the {@code other} window has a different type than {@code this} window + */ + @Override + public boolean overlap(final Window other) { + if (getClass() != other.getClass()) { + throw new IllegalArgumentException("Cannot compare windows of different type. Other window has type " + + other.getClass() + "."); + } + return true; + } + +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/WindowedSerializer.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/WindowedSerializer.java new file mode 100644 index 0000000..09185b2 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/WindowedSerializer.java @@ -0,0 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.kstream.internals; + +import org.apache.kafka.common.serialization.Serializer; +import org.apache.kafka.streams.kstream.Windowed; + +public interface WindowedSerializer extends Serializer> { + + byte[] serializeBaseKey(String topic, Windowed data); +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/WindowedStreamPartitioner.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/WindowedStreamPartitioner.java new file mode 100644 index 0000000..8e1476a --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/WindowedStreamPartitioner.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
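As a usage sketch of the unlimited-window semantics described above (topic name illustrative, not part of the patch): every window of the same specification shares one fixed, inclusive start time and never closes, so the aggregate keeps growing for the lifetime of the application.

import java.time.Instant;

import org.apache.kafka.common.serialization.Serdes;
import org.apache.kafka.streams.StreamsBuilder;
import org.apache.kafka.streams.kstream.Consumed;
import org.apache.kafka.streams.kstream.UnlimitedWindows;

public class UnlimitedWindowExample {
    public static void main(final String[] args) {
        final StreamsBuilder builder = new StreamsBuilder();
        builder.stream("events", Consumed.with(Serdes.String(), Serdes.String()))
               .groupByKey()
               // all windows start at the epoch here; the window end stays at Long.MAX_VALUE
               .windowedBy(UnlimitedWindows.of().startOn(Instant.ofEpochMilli(0L)))
               .count();
        builder.build();
    }
}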
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.kstream.internals; + +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.streams.kstream.Windowed; +import org.apache.kafka.streams.processor.StreamPartitioner; + +import static org.apache.kafka.common.utils.Utils.toPositive; + +public class WindowedStreamPartitioner implements StreamPartitioner, V> { + + private final WindowedSerializer serializer; + + public WindowedStreamPartitioner(final WindowedSerializer serializer) { + this.serializer = serializer; + } + + /** + * WindowedStreamPartitioner determines the partition number for a record with the given windowed key and value + * and the current number of partitions. The partition number id determined by the original key of the windowed key + * using the same logic as DefaultPartitioner so that the topic is partitioned by the original key. + * + * @param topic the topic name this record is sent to + * @param windowedKey the key of the record + * @param value the value of the record + * @param numPartitions the total number of partitions + * @return an integer between 0 and {@code numPartitions-1}, or {@code null} if the default partitioning logic should be used + */ + @Override + public Integer partition(final String topic, final Windowed windowedKey, final V value, final int numPartitions) { + final byte[] keyBytes = serializer.serializeBaseKey(topic, windowedKey); + + // hash the keyBytes to choose a partition + return toPositive(Utils.murmur2(keyBytes)) % numPartitions; + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/WrappingNullableDeserializer.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/WrappingNullableDeserializer.java new file mode 100644 index 0000000..ebe975d --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/WrappingNullableDeserializer.java @@ -0,0 +1,24 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
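A minimal sketch of the partitioning rule documented for WindowedStreamPartitioner above: only the serialized underlying key is hashed with murmur2, made non-negative, and taken modulo the partition count, so records for the same key land on the same partition regardless of which window they belong to. The helper class is illustrative, not part of the patch.

import java.nio.charset.StandardCharsets;

import org.apache.kafka.common.utils.Utils;

public final class BaseKeyPartitioning {
    private BaseKeyPartitioning() { }

    public static int partitionFor(final byte[] baseKeyBytes, final int numPartitions) {
        // same formula as DefaultPartitioner uses for keyed records
        return Utils.toPositive(Utils.murmur2(baseKeyBytes)) % numPartitions;
    }

    public static void main(final String[] args) {
        final byte[] key = "user-42".getBytes(StandardCharsets.UTF_8);
        System.out.println(partitionFor(key, 6));
    }
}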
+ */ +package org.apache.kafka.streams.kstream.internals; + +import org.apache.kafka.common.serialization.Deserializer; +import org.apache.kafka.streams.processor.internals.SerdeGetter; + +public interface WrappingNullableDeserializer extends Deserializer { + void setIfUnset(final SerdeGetter getter); +} \ No newline at end of file diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/WrappingNullableSerde.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/WrappingNullableSerde.java new file mode 100644 index 0000000..27f1704 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/WrappingNullableSerde.java @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.streams.kstream.internals; + +import org.apache.kafka.common.serialization.Deserializer; +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.common.serialization.Serializer; +import org.apache.kafka.streams.processor.internals.SerdeGetter; + +import java.util.Map; +import java.util.Objects; + +public abstract class WrappingNullableSerde implements Serde { + private final WrappingNullableSerializer serializer; + private final WrappingNullableDeserializer deserializer; + + protected WrappingNullableSerde(final WrappingNullableSerializer serializer, + final WrappingNullableDeserializer deserializer) { + Objects.requireNonNull(serializer, "serializer can't be null"); + Objects.requireNonNull(deserializer, "deserializer can't be null"); + this.serializer = serializer; + this.deserializer = deserializer; + } + + @Override + public Serializer serializer() { + return serializer; + } + + @Override + public Deserializer deserializer() { + return deserializer; + } + + @Override + public void configure(final Map configs, + final boolean isKey) { + serializer.configure(configs, isKey); + deserializer.configure(configs, isKey); + } + + @Override + public void close() { + serializer.close(); + deserializer.close(); + } + + public void setIfUnset(final SerdeGetter getter) { + serializer.setIfUnset(getter); + deserializer.setIfUnset(getter); + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/WrappingNullableSerializer.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/WrappingNullableSerializer.java new file mode 100644 index 0000000..6840b32 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/WrappingNullableSerializer.java @@ -0,0 +1,24 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
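An illustrative sketch (class name hypothetical, not part of the patch) of the pattern behind the WrappingNullable* types above: a delegate serializer may be null at construction time and is filled in later from the configured defaults, but only if it is still unset.

import java.util.function.Supplier;

import org.apache.kafka.common.serialization.Serde;
import org.apache.kafka.common.serialization.Serializer;

public class DeferredSerializer<T> implements Serializer<T> {
    private Serializer<T> delegate;

    public DeferredSerializer(final Serializer<T> delegate) {
        this.delegate = delegate; // may be null until setIfUnset is called
    }

    // analogous to WrappingNullableSerializer#setIfUnset: fill in only when still unset
    public void setIfUnset(final Supplier<Serde<T>> defaultSerde) {
        if (delegate == null) {
            delegate = defaultSerde.get().serializer();
        }
    }

    @Override
    public byte[] serialize(final String topic, final T data) {
        return delegate.serialize(topic, data);
    }
}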
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.kstream.internals; + +import org.apache.kafka.common.serialization.Serializer; +import org.apache.kafka.streams.processor.internals.SerdeGetter; + +public interface WrappingNullableSerializer extends Serializer { + void setIfUnset(final SerdeGetter getter); +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/WrappingNullableUtils.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/WrappingNullableUtils.java new file mode 100644 index 0000000..b904608 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/WrappingNullableUtils.java @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.kstream.internals; + +import org.apache.kafka.common.serialization.Deserializer; +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.common.serialization.Serializer; +import org.apache.kafka.streams.processor.api.ProcessorContext; +import org.apache.kafka.streams.processor.internals.SerdeGetter; + +/** + * If a component's serdes are Wrapping serdes, then they require a little extra setup + * to be fully initialized at run time. + */ +public class WrappingNullableUtils { + + @SuppressWarnings("unchecked") + private static Deserializer prepareDeserializer(final Deserializer specificDeserializer, final ProcessorContext context, final boolean isKey, final String name) { + final Deserializer deserializerToUse; + + if (specificDeserializer == null) { + final Deserializer contextKeyDeserializer = context.keySerde().deserializer(); + final Deserializer contextValueDeserializer = context.valueSerde().deserializer(); + deserializerToUse = (Deserializer) (isKey ? 
contextKeyDeserializer : contextValueDeserializer); + } else { + deserializerToUse = specificDeserializer; + initNullableDeserializer(deserializerToUse, new SerdeGetter(context)); + } + return deserializerToUse; + } + @SuppressWarnings("unchecked") + private static Serializer prepareSerializer(final Serializer specificSerializer, final ProcessorContext context, final boolean isKey, final String name) { + final Serializer serializerToUse; + if (specificSerializer == null) { + final Serializer contextKeySerializer = context.keySerde().serializer(); + final Serializer contextValueSerializer = context.valueSerde().serializer(); + serializerToUse = (Serializer) (isKey ? contextKeySerializer : contextValueSerializer); + } else { + serializerToUse = specificSerializer; + initNullableSerializer(serializerToUse, new SerdeGetter(context)); + } + return serializerToUse; + } + + @SuppressWarnings({"rawtypes", "unchecked"}) + private static Serde prepareSerde(final Serde specificSerde, final SerdeGetter getter, final boolean isKey) { + final Serde serdeToUse; + if (specificSerde == null) { + serdeToUse = (Serde) (isKey ? getter.keySerde() : getter.valueSerde()); + } else { + serdeToUse = specificSerde; + } + if (serdeToUse instanceof WrappingNullableSerde) { + ((WrappingNullableSerde) serdeToUse).setIfUnset(getter); + } + return serdeToUse; + } + + public static Deserializer prepareKeyDeserializer(final Deserializer specificDeserializer, final ProcessorContext context, final String name) { + return prepareDeserializer(specificDeserializer, context, true, name); + } + + public static Deserializer prepareValueDeserializer(final Deserializer specificDeserializer, final ProcessorContext context, final String name) { + return prepareDeserializer(specificDeserializer, context, false, name); + } + + public static Serializer prepareKeySerializer(final Serializer specificSerializer, final ProcessorContext context, final String name) { + return prepareSerializer(specificSerializer, context, true, name); + } + + public static Serializer prepareValueSerializer(final Serializer specificSerializer, final ProcessorContext context, final String name) { + return prepareSerializer(specificSerializer, context, false, name); + } + + public static Serde prepareKeySerde(final Serde specificSerde, final SerdeGetter getter) { + return prepareSerde(specificSerde, getter, true); + } + + public static Serde prepareValueSerde(final Serde specificSerde, final SerdeGetter getter) { + return prepareSerde(specificSerde, getter, false); + } + @SuppressWarnings({"rawtypes", "unchecked"}) + public static void initNullableSerializer(final Serializer specificSerializer, final SerdeGetter getter) { + if (specificSerializer instanceof WrappingNullableSerializer) { + ((WrappingNullableSerializer) specificSerializer).setIfUnset(getter); + } + } + @SuppressWarnings({"rawtypes", "unchecked"}) + public static void initNullableDeserializer(final Deserializer specificDeserializer, final SerdeGetter getter) { + if (specificDeserializer instanceof WrappingNullableDeserializer) { + ((WrappingNullableDeserializer) specificDeserializer).setIfUnset(getter); + } + } + +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/foreignkeyjoin/CombinedKey.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/foreignkeyjoin/CombinedKey.java new file mode 100644 index 0000000..f196b18 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/foreignkeyjoin/CombinedKey.java @@ -0,0 +1,52 @@ +/* 
+ * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.streams.kstream.internals.foreignkeyjoin; + +import java.util.Objects; + +public class CombinedKey { + private final KF foreignKey; + private final KP primaryKey; + + CombinedKey(final KF foreignKey, final KP primaryKey) { + Objects.requireNonNull(foreignKey, "foreignKey can't be null"); + Objects.requireNonNull(primaryKey, "primaryKey can't be null"); + this.foreignKey = foreignKey; + this.primaryKey = primaryKey; + } + + public KF getForeignKey() { + return foreignKey; + } + + public KP getPrimaryKey() { + return primaryKey; + } + + public boolean equals(final KF foreignKey, final KP primaryKey) { + return this.foreignKey.equals(foreignKey) && this.primaryKey.equals(primaryKey); + } + + @Override + public String toString() { + return "CombinedKey{" + + "foreignKey=" + foreignKey + + ", primaryKey=" + primaryKey + + '}'; + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/foreignkeyjoin/CombinedKeySchema.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/foreignkeyjoin/CombinedKeySchema.java new file mode 100644 index 0000000..57bc646 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/foreignkeyjoin/CombinedKeySchema.java @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.kstream.internals.foreignkeyjoin; + +import org.apache.kafka.common.serialization.Deserializer; +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.common.serialization.Serializer; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.streams.processor.ProcessorContext; + +import java.nio.ByteBuffer; +import java.util.function.Supplier; + +/** + * Factory for creating CombinedKey serializers / deserializers. 
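A small sketch (hypothetical helper, not part of WrappingNullableUtils) of the fallback convention used above: a component-specific serializer wins when present, otherwise the key or value default taken from the processor context is used.

import org.apache.kafka.common.serialization.Serializer;

public final class SerializerFallback {
    private SerializerFallback() { }

    @SuppressWarnings("unchecked")
    public static <T> Serializer<T> resolve(final Serializer<T> specific,
                                            final Serializer<?> contextKeySerializer,
                                            final Serializer<?> contextValueSerializer,
                                            final boolean isKey) {
        if (specific != null) {
            return specific;
        }
        // fall back to the configured default, as prepareSerializer does above
        return (Serializer<T>) (isKey ? contextKeySerializer : contextValueSerializer);
    }
}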
+ */ +public class CombinedKeySchema { + private final Supplier undecoratedPrimaryKeySerdeTopicSupplier; + private final Supplier undecoratedForeignKeySerdeTopicSupplier; + private String primaryKeySerdeTopic; + private String foreignKeySerdeTopic; + private Serializer primaryKeySerializer; + private Deserializer primaryKeyDeserializer; + private Serializer foreignKeySerializer; + private Deserializer foreignKeyDeserializer; + + public CombinedKeySchema(final Supplier foreignKeySerdeTopicSupplier, + final Serde foreignKeySerde, + final Supplier primaryKeySerdeTopicSupplier, + final Serde primaryKeySerde) { + undecoratedPrimaryKeySerdeTopicSupplier = primaryKeySerdeTopicSupplier; + undecoratedForeignKeySerdeTopicSupplier = foreignKeySerdeTopicSupplier; + primaryKeySerializer = primaryKeySerde == null ? null : primaryKeySerde.serializer(); + primaryKeyDeserializer = primaryKeySerde == null ? null : primaryKeySerde.deserializer(); + foreignKeyDeserializer = foreignKeySerde == null ? null : foreignKeySerde.deserializer(); + foreignKeySerializer = foreignKeySerde == null ? null : foreignKeySerde.serializer(); + } + + @SuppressWarnings("unchecked") + public void init(final ProcessorContext context) { + primaryKeySerdeTopic = undecoratedPrimaryKeySerdeTopicSupplier.get(); + foreignKeySerdeTopic = undecoratedForeignKeySerdeTopicSupplier.get(); + primaryKeySerializer = primaryKeySerializer == null ? (Serializer) context.keySerde().serializer() : primaryKeySerializer; + primaryKeyDeserializer = primaryKeyDeserializer == null ? (Deserializer) context.keySerde().deserializer() : primaryKeyDeserializer; + foreignKeySerializer = foreignKeySerializer == null ? (Serializer) context.keySerde().serializer() : foreignKeySerializer; + foreignKeyDeserializer = foreignKeyDeserializer == null ? (Deserializer) context.keySerde().deserializer() : foreignKeyDeserializer; + } + + Bytes toBytes(final KO foreignKey, final K primaryKey) { + //The serialization format - note that primaryKeySerialized may be null, such as when a prefixScan + //key is being created. + //{Integer.BYTES foreignKeyLength}{foreignKeySerialized}{Optional-primaryKeySerialized} + final byte[] foreignKeySerializedData = foreignKeySerializer.serialize(foreignKeySerdeTopic, + foreignKey); + + //? 
bytes + final byte[] primaryKeySerializedData = primaryKeySerializer.serialize(primaryKeySerdeTopic, + primaryKey); + + final ByteBuffer buf = ByteBuffer.allocate(Integer.BYTES + foreignKeySerializedData.length + primaryKeySerializedData.length); + buf.putInt(foreignKeySerializedData.length); + buf.put(foreignKeySerializedData); + buf.put(primaryKeySerializedData); + return Bytes.wrap(buf.array()); + } + + + public CombinedKey fromBytes(final Bytes data) { + //{Integer.BYTES foreignKeyLength}{foreignKeySerialized}{Optional-primaryKeySerialized} + final byte[] dataArray = data.get(); + final ByteBuffer dataBuffer = ByteBuffer.wrap(dataArray); + final int foreignKeyLength = dataBuffer.getInt(); + final byte[] foreignKeyRaw = new byte[foreignKeyLength]; + dataBuffer.get(foreignKeyRaw, 0, foreignKeyLength); + final KO foreignKey = foreignKeyDeserializer.deserialize(foreignKeySerdeTopic, foreignKeyRaw); + + final byte[] primaryKeyRaw = new byte[dataArray.length - foreignKeyLength - Integer.BYTES]; + dataBuffer.get(primaryKeyRaw, 0, primaryKeyRaw.length); + final K primaryKey = primaryKeyDeserializer.deserialize(primaryKeySerdeTopic, primaryKeyRaw); + return new CombinedKey<>(foreignKey, primaryKey); + } + + Bytes prefixBytes(final KO key) { + //The serialization format. Note that primaryKeySerialized is not required/used in this function. + //{Integer.BYTES foreignKeyLength}{foreignKeySerialized}{Optional-primaryKeySerialized} + + final byte[] foreignKeySerializedData = foreignKeySerializer.serialize(foreignKeySerdeTopic, key); + + final ByteBuffer buf = ByteBuffer.allocate(Integer.BYTES + foreignKeySerializedData.length); + buf.putInt(foreignKeySerializedData.length); + buf.put(foreignKeySerializedData); + return Bytes.wrap(buf.array()); + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/foreignkeyjoin/ForeignJoinSubscriptionProcessorSupplier.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/foreignkeyjoin/ForeignJoinSubscriptionProcessorSupplier.java new file mode 100644 index 0000000..f0114b1 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/foreignkeyjoin/ForeignJoinSubscriptionProcessorSupplier.java @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
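A self-contained sketch of the byte layout documented in CombinedKeySchema above, with raw byte arrays standing in for the serialized keys: {Integer.BYTES foreignKeyLength}{foreignKeySerialized}{primaryKeySerialized}. The class is illustrative, not part of the patch.

import java.nio.ByteBuffer;
import java.util.Arrays;

public final class CombinedKeyLayout {
    private CombinedKeyLayout() { }

    static byte[] encode(final byte[] foreignKey, final byte[] primaryKey) {
        final ByteBuffer buf = ByteBuffer.allocate(Integer.BYTES + foreignKey.length + primaryKey.length);
        buf.putInt(foreignKey.length);
        buf.put(foreignKey);
        buf.put(primaryKey);
        return buf.array();
    }

    static byte[][] decode(final byte[] combined) {
        final ByteBuffer buf = ByteBuffer.wrap(combined);
        final byte[] foreignKey = new byte[buf.getInt()];
        buf.get(foreignKey);
        final byte[] primaryKey = new byte[combined.length - Integer.BYTES - foreignKey.length];
        buf.get(primaryKey);
        return new byte[][] {foreignKey, primaryKey};
    }

    public static void main(final String[] args) {
        final byte[][] parts = decode(encode(new byte[] {1, 2}, new byte[] {3, 4, 5}));
        System.out.println(Arrays.toString(parts[0]) + " / " + Arrays.toString(parts[1]));
    }
}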
+ */ + +package org.apache.kafka.streams.kstream.internals.foreignkeyjoin; + +import org.apache.kafka.common.metrics.Sensor; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.kstream.internals.Change; +import org.apache.kafka.streams.processor.internals.InternalProcessorContext; +import org.apache.kafka.streams.processor.internals.metrics.TaskMetrics; +import org.apache.kafka.streams.state.KeyValueIterator; +import org.apache.kafka.streams.state.StoreBuilder; +import org.apache.kafka.streams.state.TimestampedKeyValueStore; +import org.apache.kafka.streams.state.ValueAndTimestamp; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.nio.ByteBuffer; + +@SuppressWarnings("deprecation") // Old PAPI. Needs to be migrated. +public class ForeignJoinSubscriptionProcessorSupplier implements org.apache.kafka.streams.processor.ProcessorSupplier> { + private static final Logger LOG = LoggerFactory.getLogger(ForeignJoinSubscriptionProcessorSupplier.class); + private final StoreBuilder>> storeBuilder; + private final CombinedKeySchema keySchema; + + public ForeignJoinSubscriptionProcessorSupplier( + final StoreBuilder>> storeBuilder, + final CombinedKeySchema keySchema) { + + this.storeBuilder = storeBuilder; + this.keySchema = keySchema; + } + + @Override + public org.apache.kafka.streams.processor.Processor> get() { + return new KTableKTableJoinProcessor(); + } + + + private final class KTableKTableJoinProcessor extends org.apache.kafka.streams.processor.AbstractProcessor> { + private Sensor droppedRecordsSensor; + private TimestampedKeyValueStore> store; + + @Override + public void init(final org.apache.kafka.streams.processor.ProcessorContext context) { + super.init(context); + final InternalProcessorContext internalProcessorContext = (InternalProcessorContext) context; + droppedRecordsSensor = TaskMetrics.droppedRecordsSensor( + Thread.currentThread().getName(), + internalProcessorContext.taskId().toString(), + internalProcessorContext.metrics() + ); + store = internalProcessorContext.getStateStore(storeBuilder); + } + + @Override + public void process(final KO key, final Change value) { + // if the key is null, we do not need proceed aggregating + // the record with the table + if (key == null) { + LOG.warn( + "Skipping record due to null key. 
value=[{}] topic=[{}] partition=[{}] offset=[{}]", + value, context().topic(), context().partition(), context().offset() + ); + droppedRecordsSensor.record(); + return; + } + + final Bytes prefixBytes = keySchema.prefixBytes(key); + + //Perform the prefixScan and propagate the results + try (final KeyValueIterator>> prefixScanResults = + store.range(prefixBytes, Bytes.increment(prefixBytes))) { + + while (prefixScanResults.hasNext()) { + final KeyValue>> next = prefixScanResults.next(); + // have to check the prefix because the range end is inclusive :( + if (prefixEquals(next.key.get(), prefixBytes.get())) { + final CombinedKey combinedKey = keySchema.fromBytes(next.key); + context().forward( + combinedKey.getPrimaryKey(), + new SubscriptionResponseWrapper<>(next.value.value().getHash(), value.newValue) + ); + } + } + } + } + + private boolean prefixEquals(final byte[] x, final byte[] y) { + final int min = Math.min(x.length, y.length); + final ByteBuffer xSlice = ByteBuffer.wrap(x, 0, min); + final ByteBuffer ySlice = ByteBuffer.wrap(y, 0, min); + return xSlice.equals(ySlice); + } + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/foreignkeyjoin/ForeignJoinSubscriptionSendProcessorSupplier.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/foreignkeyjoin/ForeignJoinSubscriptionSendProcessorSupplier.java new file mode 100644 index 0000000..8f1e1f9 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/foreignkeyjoin/ForeignJoinSubscriptionSendProcessorSupplier.java @@ -0,0 +1,164 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
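A minimal sketch of the prefix-scan idea used in the processor above, with illustrative inputs: scan from the prefix up to Bytes.increment(prefix) and re-check the prefix on every hit, because the upper bound of range() is inclusive.

import java.nio.ByteBuffer;

import org.apache.kafka.common.utils.Bytes;

public final class PrefixScanBounds {
    private PrefixScanBounds() { }

    static boolean hasPrefix(final byte[] candidate, final byte[] prefix) {
        if (candidate.length < prefix.length) {
            return false;
        }
        return ByteBuffer.wrap(candidate, 0, prefix.length).equals(ByteBuffer.wrap(prefix));
    }

    public static void main(final String[] args) {
        final Bytes prefix = Bytes.wrap(new byte[] {0, 0, 0, 2, 10, 20});
        final Bytes upper = Bytes.increment(prefix);   // exclusive "next prefix" upper bound
        System.out.println(prefix + " .. " + upper);
        System.out.println(hasPrefix(new byte[] {0, 0, 0, 2, 10, 20, 99}, prefix.get()));
    }
}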
+ */ + +package org.apache.kafka.streams.kstream.internals.foreignkeyjoin; + +import org.apache.kafka.common.metrics.Sensor; +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.common.serialization.Serializer; +import org.apache.kafka.streams.kstream.internals.Change; +import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl; +import org.apache.kafka.streams.processor.internals.metrics.TaskMetrics; +import org.apache.kafka.streams.state.internals.Murmur3; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Arrays; +import java.util.function.Function; +import java.util.function.Supplier; + +import static org.apache.kafka.streams.kstream.internals.foreignkeyjoin.SubscriptionWrapper.Instruction.DELETE_KEY_AND_PROPAGATE; +import static org.apache.kafka.streams.kstream.internals.foreignkeyjoin.SubscriptionWrapper.Instruction.DELETE_KEY_NO_PROPAGATE; +import static org.apache.kafka.streams.kstream.internals.foreignkeyjoin.SubscriptionWrapper.Instruction.PROPAGATE_NULL_IF_NO_FK_VAL_AVAILABLE; +import static org.apache.kafka.streams.kstream.internals.foreignkeyjoin.SubscriptionWrapper.Instruction.PROPAGATE_ONLY_IF_FK_VAL_AVAILABLE; + +@SuppressWarnings("deprecation") // Old PAPI. Needs to be migrated. +public class ForeignJoinSubscriptionSendProcessorSupplier implements org.apache.kafka.streams.processor.ProcessorSupplier> { + private static final Logger LOG = LoggerFactory.getLogger(ForeignJoinSubscriptionSendProcessorSupplier.class); + + private final Function foreignKeyExtractor; + private final Supplier foreignKeySerdeTopicSupplier; + private final Supplier valueSerdeTopicSupplier; + private final boolean leftJoin; + private Serializer foreignKeySerializer; + private Serializer valueSerializer; + + public ForeignJoinSubscriptionSendProcessorSupplier(final Function foreignKeyExtractor, + final Supplier foreignKeySerdeTopicSupplier, + final Supplier valueSerdeTopicSupplier, + final Serde foreignKeySerde, + final Serializer valueSerializer, + final boolean leftJoin) { + this.foreignKeyExtractor = foreignKeyExtractor; + this.foreignKeySerdeTopicSupplier = foreignKeySerdeTopicSupplier; + this.valueSerdeTopicSupplier = valueSerdeTopicSupplier; + this.valueSerializer = valueSerializer; + this.leftJoin = leftJoin; + foreignKeySerializer = foreignKeySerde == null ? 
null : foreignKeySerde.serializer(); + } + + @Override + public org.apache.kafka.streams.processor.Processor> get() { + return new UnbindChangeProcessor(); + } + + private class UnbindChangeProcessor extends org.apache.kafka.streams.processor.AbstractProcessor> { + + private Sensor droppedRecordsSensor; + private String foreignKeySerdeTopic; + private String valueSerdeTopic; + + @SuppressWarnings("unchecked") + @Override + public void init(final org.apache.kafka.streams.processor.ProcessorContext context) { + super.init(context); + foreignKeySerdeTopic = foreignKeySerdeTopicSupplier.get(); + valueSerdeTopic = valueSerdeTopicSupplier.get(); + // get default key serde if it wasn't supplied directly at construction + if (foreignKeySerializer == null) { + foreignKeySerializer = (Serializer) context.keySerde().serializer(); + } + if (valueSerializer == null) { + valueSerializer = (Serializer) context.valueSerde().serializer(); + } + droppedRecordsSensor = TaskMetrics.droppedRecordsSensor( + Thread.currentThread().getName(), + context.taskId().toString(), + (StreamsMetricsImpl) context.metrics() + ); + } + + @Override + public void process(final K key, final Change change) { + final long[] currentHash = change.newValue == null ? + null : + Murmur3.hash128(valueSerializer.serialize(valueSerdeTopic, change.newValue)); + + if (change.oldValue != null) { + final KO oldForeignKey = foreignKeyExtractor.apply(change.oldValue); + if (oldForeignKey == null) { + LOG.warn( + "Skipping record due to null foreign key. value=[{}] topic=[{}] partition=[{}] offset=[{}]", + change.oldValue, context().topic(), context().partition(), context().offset() + ); + droppedRecordsSensor.record(); + return; + } + if (change.newValue != null) { + final KO newForeignKey = foreignKeyExtractor.apply(change.newValue); + if (newForeignKey == null) { + LOG.warn( + "Skipping record due to null foreign key. value=[{}] topic=[{}] partition=[{}] offset=[{}]", + change.newValue, context().topic(), context().partition(), context().offset() + ); + droppedRecordsSensor.record(); + return; + } + + final byte[] serialOldForeignKey = + foreignKeySerializer.serialize(foreignKeySerdeTopic, oldForeignKey); + final byte[] serialNewForeignKey = + foreignKeySerializer.serialize(foreignKeySerdeTopic, newForeignKey); + if (!Arrays.equals(serialNewForeignKey, serialOldForeignKey)) { + //Different Foreign Key - delete the old key value and propagate the new one. + //Delete it from the oldKey's state store + context().forward(oldForeignKey, new SubscriptionWrapper<>(currentHash, DELETE_KEY_NO_PROPAGATE, key)); + //Add to the newKey's state store. Additionally, propagate null if no FK is found there, + //since we must "unset" any output set by the previous FK-join. This is true for both INNER + //and LEFT join. + } + context().forward(newForeignKey, new SubscriptionWrapper<>(currentHash, PROPAGATE_NULL_IF_NO_FK_VAL_AVAILABLE, key)); + } else { + //A simple propagatable delete. Delete from the state store and propagate the delete onwards. + context().forward(oldForeignKey, new SubscriptionWrapper<>(currentHash, DELETE_KEY_AND_PROPAGATE, key)); + } + } else if (change.newValue != null) { + //change.oldValue is null, which means it was deleted at least once before, or it is brand new. + //In either case, we only need to propagate if the FK_VAL is available, as the null from the delete would + //have been propagated otherwise. + + final SubscriptionWrapper.Instruction instruction; + if (leftJoin) { + //Want to send info even if RHS is null. 
+ instruction = PROPAGATE_NULL_IF_NO_FK_VAL_AVAILABLE; + } else { + instruction = PROPAGATE_ONLY_IF_FK_VAL_AVAILABLE; + } + final KO newForeignKey = foreignKeyExtractor.apply(change.newValue); + if (newForeignKey == null) { + LOG.warn( + "Skipping record due to null foreign key. value=[{}] topic=[{}] partition=[{}] offset=[{}]", + change.newValue, context().topic(), context().partition(), context().offset() + ); + droppedRecordsSensor.record(); + } else { + context().forward(newForeignKey, new SubscriptionWrapper<>(currentHash, instruction, key)); + } + } + } + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/foreignkeyjoin/SubscriptionJoinForeignProcessorSupplier.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/foreignkeyjoin/SubscriptionJoinForeignProcessorSupplier.java new file mode 100644 index 0000000..7c53a68 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/foreignkeyjoin/SubscriptionJoinForeignProcessorSupplier.java @@ -0,0 +1,121 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.streams.kstream.internals.foreignkeyjoin; + +import org.apache.kafka.common.errors.UnsupportedVersionException; +import org.apache.kafka.streams.kstream.internals.Change; +import org.apache.kafka.streams.kstream.internals.KTableValueGetter; +import org.apache.kafka.streams.kstream.internals.KTableValueGetterSupplier; +import org.apache.kafka.streams.processor.To; +import org.apache.kafka.streams.state.ValueAndTimestamp; + +import java.util.Objects; + +/** + * Receives {@code SubscriptionWrapper} events and processes them according to their Instruction. + * Depending on the results, {@code SubscriptionResponseWrapper}s are created, which will be propagated to + * the {@code SubscriptionResolverJoinProcessorSupplier} instance. + * + * @param Type of primary keys + * @param Type of foreign key + * @param Type of foreign value + */ +@SuppressWarnings("deprecation") // Old PAPI. Needs to be migrated. 
+public class SubscriptionJoinForeignProcessorSupplier + implements org.apache.kafka.streams.processor.ProcessorSupplier, Change>>> { + + private final KTableValueGetterSupplier foreignValueGetterSupplier; + + public SubscriptionJoinForeignProcessorSupplier(final KTableValueGetterSupplier foreignValueGetterSupplier) { + this.foreignValueGetterSupplier = foreignValueGetterSupplier; + } + + @Override + public org.apache.kafka.streams.processor.Processor, Change>>> get() { + + return new org.apache.kafka.streams.processor.AbstractProcessor, Change>>>() { + + private KTableValueGetter foreignValues; + + @Override + public void init(final org.apache.kafka.streams.processor.ProcessorContext context) { + super.init(context); + foreignValues = foreignValueGetterSupplier.get(); + foreignValues.init(context); + } + + @Override + public void process(final CombinedKey combinedKey, final Change>> change) { + Objects.requireNonNull(combinedKey, "This processor should never see a null key."); + Objects.requireNonNull(change, "This processor should never see a null value."); + final ValueAndTimestamp> valueAndTimestamp = change.newValue; + Objects.requireNonNull(valueAndTimestamp, "This processor should never see a null newValue."); + final SubscriptionWrapper value = valueAndTimestamp.value(); + + if (value.getVersion() != SubscriptionWrapper.CURRENT_VERSION) { + //Guard against modifications to SubscriptionWrapper. Need to ensure that there is compatibility + //with previous versions to enable rolling upgrades. Must develop a strategy for upgrading + //from older SubscriptionWrapper versions to newer versions. + throw new UnsupportedVersionException("SubscriptionWrapper is of an incompatible version."); + } + + final ValueAndTimestamp foreignValueAndTime = foreignValues.get(combinedKey.getForeignKey()); + + final long resultTimestamp = + foreignValueAndTime == null ? + valueAndTimestamp.timestamp() : + Math.max(valueAndTimestamp.timestamp(), foreignValueAndTime.timestamp()); + + switch (value.getInstruction()) { + case DELETE_KEY_AND_PROPAGATE: + context().forward( + combinedKey.getPrimaryKey(), + new SubscriptionResponseWrapper(value.getHash(), null), + To.all().withTimestamp(resultTimestamp) + ); + break; + case PROPAGATE_NULL_IF_NO_FK_VAL_AVAILABLE: + //This one needs to go through regardless of LEFT or INNER join, since the extracted FK was + //changed and there is no match for it. We must propagate the (key, null) to ensure that the + //downstream consumers are alerted to this fact. + final VO valueToSend = foreignValueAndTime == null ? 
null : foreignValueAndTime.value(); + + context().forward( + combinedKey.getPrimaryKey(), + new SubscriptionResponseWrapper<>(value.getHash(), valueToSend), + To.all().withTimestamp(resultTimestamp) + ); + break; + case PROPAGATE_ONLY_IF_FK_VAL_AVAILABLE: + if (foreignValueAndTime != null) { + context().forward( + combinedKey.getPrimaryKey(), + new SubscriptionResponseWrapper<>(value.getHash(), foreignValueAndTime.value()), + To.all().withTimestamp(resultTimestamp) + ); + } + break; + case DELETE_KEY_NO_PROPAGATE: + break; + default: + throw new IllegalStateException("Unhandled instruction: " + value.getInstruction()); + } + } + }; + } +} \ No newline at end of file diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/foreignkeyjoin/SubscriptionResolverJoinProcessorSupplier.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/foreignkeyjoin/SubscriptionResolverJoinProcessorSupplier.java new file mode 100644 index 0000000..652adbd --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/foreignkeyjoin/SubscriptionResolverJoinProcessorSupplier.java @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.streams.kstream.internals.foreignkeyjoin; + +import org.apache.kafka.common.errors.UnsupportedVersionException; +import org.apache.kafka.common.serialization.Serializer; +import org.apache.kafka.streams.kstream.ValueJoiner; +import org.apache.kafka.streams.kstream.internals.KTableValueGetter; +import org.apache.kafka.streams.kstream.internals.KTableValueGetterSupplier; +import org.apache.kafka.streams.state.ValueAndTimestamp; +import org.apache.kafka.streams.state.internals.Murmur3; + +import java.util.function.Supplier; + +/** + * Receives {@code SubscriptionResponseWrapper} events and filters out events which do not match the current hash + * of the primary key. This eliminates race-condition results for rapidly-changing foreign-keys for a given primary key. + * Applies the join and emits nulls according to LEFT/INNER rules. + * + * @param Type of primary keys + * @param Type of primary values + * @param Type of foreign values + * @param Type of joined result of primary and foreign values + */ +@SuppressWarnings("deprecation") // Old PAPI. Needs to be migrated. 
+public class SubscriptionResolverJoinProcessorSupplier implements org.apache.kafka.streams.processor.ProcessorSupplier> { + private final KTableValueGetterSupplier valueGetterSupplier; + private final Serializer constructionTimeValueSerializer; + private final Supplier valueHashSerdePseudoTopicSupplier; + private final ValueJoiner joiner; + private final boolean leftJoin; + + public SubscriptionResolverJoinProcessorSupplier(final KTableValueGetterSupplier valueGetterSupplier, + final Serializer valueSerializer, + final Supplier valueHashSerdePseudoTopicSupplier, + final ValueJoiner joiner, + final boolean leftJoin) { + this.valueGetterSupplier = valueGetterSupplier; + constructionTimeValueSerializer = valueSerializer; + this.valueHashSerdePseudoTopicSupplier = valueHashSerdePseudoTopicSupplier; + this.joiner = joiner; + this.leftJoin = leftJoin; + } + + @Override + public org.apache.kafka.streams.processor.Processor> get() { + return new org.apache.kafka.streams.processor.AbstractProcessor>() { + private String valueHashSerdePseudoTopic; + private Serializer runtimeValueSerializer = constructionTimeValueSerializer; + + private KTableValueGetter valueGetter; + + @SuppressWarnings("unchecked") + @Override + public void init(final org.apache.kafka.streams.processor.ProcessorContext context) { + super.init(context); + valueHashSerdePseudoTopic = valueHashSerdePseudoTopicSupplier.get(); + valueGetter = valueGetterSupplier.get(); + valueGetter.init(context); + if (runtimeValueSerializer == null) { + runtimeValueSerializer = (Serializer) context.valueSerde().serializer(); + } + } + + @Override + public void process(final K key, final SubscriptionResponseWrapper value) { + if (value.getVersion() != SubscriptionResponseWrapper.CURRENT_VERSION) { + //Guard against modifications to SubscriptionResponseWrapper. Need to ensure that there is + //compatibility with previous versions to enable rolling upgrades. Must develop a strategy for + //upgrading from older SubscriptionWrapper versions to newer versions. + throw new UnsupportedVersionException("SubscriptionResponseWrapper is of an incompatible version."); + } + final ValueAndTimestamp currentValueWithTimestamp = valueGetter.get(key); + + final long[] currentHash = currentValueWithTimestamp == null ? + null : + Murmur3.hash128(runtimeValueSerializer.serialize(valueHashSerdePseudoTopic, currentValueWithTimestamp.value())); + + final long[] messageHash = value.getOriginalValueHash(); + + //If this value doesn't match the current value from the original table, it is stale and should be discarded. + if (java.util.Arrays.equals(messageHash, currentHash)) { + final VR result; + + if (value.getForeignValue() == null && (!leftJoin || currentValueWithTimestamp == null)) { + result = null; //Emit tombstone + } else { + result = joiner.apply(currentValueWithTimestamp == null ? null : currentValueWithTimestamp.value(), value.getForeignValue()); + } + context().forward(key, result); + } + } + }; + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/foreignkeyjoin/SubscriptionResponseWrapper.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/foreignkeyjoin/SubscriptionResponseWrapper.java new file mode 100644 index 0000000..9c79e46 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/foreignkeyjoin/SubscriptionResponseWrapper.java @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
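The resolver above guards against out-of-order responses by comparing the 128-bit Murmur3 hash carried in each SubscriptionResponseWrapper with a hash of the current value read from the value getter; only matching responses are joined and forwarded. A minimal standalone sketch of just that comparison, using only java.util.Arrays (the class name below is illustrative and not part of the patch):

import java.util.Arrays;

// Standalone illustration of the staleness check above: the response's
// original-value hash must equal the hash of the *current* value. When the
// current value is null its hash is null too, and Arrays.equals(null, null)
// is true, so a tombstone response still matches a deleted current value.
final class HashFilterSketch {
    static boolean isCurrent(final long[] messageHash, final long[] currentHash) {
        return Arrays.equals(messageHash, currentHash);
    }

    public static void main(final String[] args) {
        final long[] h = {0x1234L, 0x5678L};
        System.out.println(isCurrent(h, new long[] {0x1234L, 0x5678L})); // true  -> apply the join
        System.out.println(isCurrent(h, new long[] {0x1234L, 0x9999L})); // false -> drop as stale
        System.out.println(isCurrent(null, null));                       // true  -> deleted on both sides
    }
}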
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.kstream.internals.foreignkeyjoin; + +import org.apache.kafka.common.errors.UnsupportedVersionException; + +import java.util.Arrays; + +public class SubscriptionResponseWrapper { + final static byte CURRENT_VERSION = 0x00; + private final long[] originalValueHash; + private final FV foreignValue; + private final byte version; + + public SubscriptionResponseWrapper(final long[] originalValueHash, final FV foreignValue) { + this(originalValueHash, foreignValue, CURRENT_VERSION); + } + + public SubscriptionResponseWrapper(final long[] originalValueHash, final FV foreignValue, final byte version) { + if (version != CURRENT_VERSION) { + throw new UnsupportedVersionException("SubscriptionWrapper does not support version " + version); + } + this.originalValueHash = originalValueHash; + this.foreignValue = foreignValue; + this.version = version; + } + + public long[] getOriginalValueHash() { + return originalValueHash; + } + + public FV getForeignValue() { + return foreignValue; + } + + public byte getVersion() { + return version; + } + + @Override + public String toString() { + return "SubscriptionResponseWrapper{" + + "version=" + version + + ", foreignValue=" + foreignValue + + ", originalValueHash=" + Arrays.toString(originalValueHash) + + '}'; + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/foreignkeyjoin/SubscriptionResponseWrapperSerde.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/foreignkeyjoin/SubscriptionResponseWrapperSerde.java new file mode 100644 index 0000000..8910ff8 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/foreignkeyjoin/SubscriptionResponseWrapperSerde.java @@ -0,0 +1,149 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.kstream.internals.foreignkeyjoin; + +import org.apache.kafka.common.errors.UnsupportedVersionException; +import org.apache.kafka.common.serialization.Deserializer; +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.common.serialization.Serializer; +import org.apache.kafka.streams.kstream.internals.WrappingNullableDeserializer; +import org.apache.kafka.streams.kstream.internals.WrappingNullableSerializer; +import org.apache.kafka.streams.processor.internals.SerdeGetter; + +import java.nio.ByteBuffer; + +public class SubscriptionResponseWrapperSerde implements Serde> { + private final SubscriptionResponseWrapperSerializer serializer; + private final SubscriptionResponseWrapperDeserializer deserializer; + + public SubscriptionResponseWrapperSerde(final Serde foreignValueSerde) { + serializer = new SubscriptionResponseWrapperSerializer<>(foreignValueSerde == null ? null : foreignValueSerde.serializer()); + deserializer = new SubscriptionResponseWrapperDeserializer<>(foreignValueSerde == null ? null : foreignValueSerde.deserializer()); + } + + @Override + public Serializer> serializer() { + return serializer; + } + + @Override + public Deserializer> deserializer() { + return deserializer; + } + + private static final class SubscriptionResponseWrapperSerializer + implements Serializer>, WrappingNullableSerializer, Void, V> { + + private Serializer serializer; + + private SubscriptionResponseWrapperSerializer(final Serializer serializer) { + this.serializer = serializer; + } + + @SuppressWarnings("unchecked") + @Override + public void setIfUnset(final SerdeGetter getter) { + if (serializer == null) { + serializer = (Serializer) getter.valueSerde().serializer(); + } + } + + @Override + public byte[] serialize(final String topic, final SubscriptionResponseWrapper data) { + //{1-bit-isHashNull}{7-bits-version}{Optional-16-byte-Hash}{n-bytes serialized data} + + //7-bit (0x7F) maximum for data version. + if (Byte.compare((byte) 0x7F, data.getVersion()) < 0) { + throw new UnsupportedVersionException("SubscriptionResponseWrapper version is larger than maximum supported 0x7F"); + } + + final byte[] serializedData = data.getForeignValue() == null ? null : serializer.serialize(topic, data.getForeignValue()); + final int serializedDataLength = serializedData == null ? 0 : serializedData.length; + final long[] originalHash = data.getOriginalValueHash(); + final int hashLength = originalHash == null ? 0 : 2 * Long.BYTES; + + final ByteBuffer buf = ByteBuffer.allocate(1 + hashLength + serializedDataLength); + + if (originalHash != null) { + buf.put(data.getVersion()); + buf.putLong(originalHash[0]); + buf.putLong(originalHash[1]); + } else { + //Don't store hash as it's null. 
+ buf.put((byte) (data.getVersion() | (byte) 0x80)); + } + + if (serializedData != null) + buf.put(serializedData); + return buf.array(); + } + + } + + private static final class SubscriptionResponseWrapperDeserializer + implements Deserializer>, WrappingNullableDeserializer, Void, V> { + + private Deserializer deserializer; + + private SubscriptionResponseWrapperDeserializer(final Deserializer deserializer) { + this.deserializer = deserializer; + } + + @SuppressWarnings("unchecked") + @Override + public void setIfUnset(final SerdeGetter getter) { + if (deserializer == null) { + deserializer = (Deserializer) getter.valueSerde().deserializer(); + } + } + + @Override + public SubscriptionResponseWrapper deserialize(final String topic, final byte[] data) { + //{1-bit-isHashNull}{7-bits-version}{Optional-16-byte-Hash}{n-bytes serialized data} + + final ByteBuffer buf = ByteBuffer.wrap(data); + final byte versionAndIsHashNull = buf.get(); + final byte version = (byte) (0x7F & versionAndIsHashNull); + final boolean isHashNull = (0x80 & versionAndIsHashNull) == 0x80; + + final long[] hash; + int lengthSum = 1; //The first byte + if (isHashNull) { + hash = null; + } else { + hash = new long[2]; + hash[0] = buf.getLong(); + hash[1] = buf.getLong(); + lengthSum += 2 * Long.BYTES; + } + + final V value; + if (data.length - lengthSum > 0) { + final byte[] serializedValue; + serializedValue = new byte[data.length - lengthSum]; + buf.get(serializedValue, 0, serializedValue.length); + value = deserializer.deserialize(topic, serializedValue); + } else { + value = null; + } + + return new SubscriptionResponseWrapper<>(hash, value, version); + } + + } + +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/foreignkeyjoin/SubscriptionStoreReceiveProcessorSupplier.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/foreignkeyjoin/SubscriptionStoreReceiveProcessorSupplier.java new file mode 100644 index 0000000..928bd48 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/foreignkeyjoin/SubscriptionStoreReceiveProcessorSupplier.java @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
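The response serde above packs a 1-bit "hash is null" flag and a 7-bit version into the leading byte, followed by an optional 16-byte hash and the serialized foreign value. A minimal standalone sketch of packing and unpacking just that leading byte, using the same 0x80 and 0x7F masks as above (the class name is illustrative, not from the patch):

// Standalone illustration of the leading byte used by the serde above:
// high bit = "hash is null", low 7 bits = wrapper version.
final class ResponseHeaderByteSketch {
    static byte pack(final int version, final boolean hashIsNull) {
        if (version < 0 || version > 0x7F) {       // only 7 bits are available for the version
            throw new IllegalArgumentException("version must fit in 7 bits");
        }
        return (byte) (hashIsNull ? (version | 0x80) : version);
    }

    static boolean isHashNull(final byte header) {
        return (0x80 & header) == 0x80;
    }

    static byte version(final byte header) {
        return (byte) (0x7F & header);
    }

    public static void main(final String[] args) {
        final byte header = pack(0, true);
        System.out.println(isHashNull(header)); // true
        System.out.println(version(header));    // 0
    }
}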
+ */ + +package org.apache.kafka.streams.kstream.internals.foreignkeyjoin; + +import org.apache.kafka.common.errors.UnsupportedVersionException; +import org.apache.kafka.common.metrics.Sensor; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.streams.kstream.internals.Change; +import org.apache.kafka.streams.processor.To; +import org.apache.kafka.streams.processor.internals.InternalProcessorContext; +import org.apache.kafka.streams.processor.internals.metrics.TaskMetrics; +import org.apache.kafka.streams.state.StoreBuilder; +import org.apache.kafka.streams.state.TimestampedKeyValueStore; +import org.apache.kafka.streams.state.ValueAndTimestamp; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +@SuppressWarnings("deprecation") // Old PAPI. Needs to be migrated. +public class SubscriptionStoreReceiveProcessorSupplier + implements org.apache.kafka.streams.processor.ProcessorSupplier> { + private static final Logger LOG = LoggerFactory.getLogger(SubscriptionStoreReceiveProcessorSupplier.class); + + private final StoreBuilder>> storeBuilder; + private final CombinedKeySchema keySchema; + + public SubscriptionStoreReceiveProcessorSupplier( + final StoreBuilder>> storeBuilder, + final CombinedKeySchema keySchema) { + + this.storeBuilder = storeBuilder; + this.keySchema = keySchema; + } + + @Override + public org.apache.kafka.streams.processor.Processor> get() { + + return new org.apache.kafka.streams.processor.AbstractProcessor>() { + + private TimestampedKeyValueStore> store; + private Sensor droppedRecordsSensor; + + @Override + public void init(final org.apache.kafka.streams.processor.ProcessorContext context) { + super.init(context); + final InternalProcessorContext internalProcessorContext = (InternalProcessorContext) context; + + droppedRecordsSensor = TaskMetrics.droppedRecordsSensor( + Thread.currentThread().getName(), + internalProcessorContext.taskId().toString(), + internalProcessorContext.metrics() + ); + store = internalProcessorContext.getStateStore(storeBuilder); + + keySchema.init(context); + } + + @Override + public void process(final KO key, final SubscriptionWrapper value) { + if (key == null) { + LOG.warn( + "Skipping record due to null foreign key. value=[{}] topic=[{}] partition=[{}] offset=[{}]", + value, context().topic(), context().partition(), context().offset() + ); + droppedRecordsSensor.record(); + return; + } + if (value.getVersion() != SubscriptionWrapper.CURRENT_VERSION) { + //Guard against modifications to SubscriptionWrapper. Need to ensure that there is compatibility + //with previous versions to enable rolling upgrades. Must develop a strategy for upgrading + //from older SubscriptionWrapper versions to newer versions. 
+ throw new UnsupportedVersionException("SubscriptionWrapper is of an incompatible version."); + } + + final Bytes subscriptionKey = keySchema.toBytes(key, value.getPrimaryKey()); + + final ValueAndTimestamp> newValue = ValueAndTimestamp.make(value, context().timestamp()); + final ValueAndTimestamp> oldValue = store.get(subscriptionKey); + + //This store is used by the prefix scanner in ForeignJoinSubscriptionProcessorSupplier + if (value.getInstruction().equals(SubscriptionWrapper.Instruction.DELETE_KEY_AND_PROPAGATE) || + value.getInstruction().equals(SubscriptionWrapper.Instruction.DELETE_KEY_NO_PROPAGATE)) { + store.delete(subscriptionKey); + } else { + store.put(subscriptionKey, newValue); + } + final Change>> change = new Change<>(newValue, oldValue); + // note: key is non-nullable + // note: newValue is non-nullable + context().forward( + new CombinedKey<>(key, value.getPrimaryKey()), + change, + To.all().withTimestamp(newValue.timestamp()) + ); + } + }; + } +} \ No newline at end of file diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/foreignkeyjoin/SubscriptionWrapper.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/foreignkeyjoin/SubscriptionWrapper.java new file mode 100644 index 0000000..a757895 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/foreignkeyjoin/SubscriptionWrapper.java @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.kstream.internals.foreignkeyjoin; + +import org.apache.kafka.common.errors.UnsupportedVersionException; + +import java.util.Arrays; +import java.util.Objects; + + +public class SubscriptionWrapper { + static final byte CURRENT_VERSION = 0; + + private final long[] hash; + private final Instruction instruction; + private final byte version; + private final K primaryKey; + + public enum Instruction { + //Send nothing. Do not propagate. + DELETE_KEY_NO_PROPAGATE((byte) 0x00), + + //Send (k, null) + DELETE_KEY_AND_PROPAGATE((byte) 0x01), + + //(changing foreign key, but FK+Val may not exist) + //Send (k, fk-val) OR + //Send (k, null) if fk-val does not exist + PROPAGATE_NULL_IF_NO_FK_VAL_AVAILABLE((byte) 0x02), + + //(first time ever sending key) + //Send (k, fk-val) only if fk-val exists. 
+ PROPAGATE_ONLY_IF_FK_VAL_AVAILABLE((byte) 0x03); + + private final byte value; + Instruction(final byte value) { + this.value = value; + } + + public byte getValue() { + return value; + } + + public static Instruction fromValue(final byte value) { + for (final Instruction i: values()) { + if (i.value == value) { + return i; + } + } + throw new IllegalArgumentException("Unknown instruction byte value = " + value); + } + } + + public SubscriptionWrapper(final long[] hash, final Instruction instruction, final K primaryKey) { + this(hash, instruction, primaryKey, CURRENT_VERSION); + } + + public SubscriptionWrapper(final long[] hash, final Instruction instruction, final K primaryKey, final byte version) { + Objects.requireNonNull(instruction, "instruction cannot be null. Required by downstream processor."); + Objects.requireNonNull(primaryKey, "primaryKey cannot be null. Required by downstream processor."); + if (version != CURRENT_VERSION) { + throw new UnsupportedVersionException("SubscriptionWrapper does not support version " + version); + } + + this.instruction = instruction; + this.hash = hash; + this.primaryKey = primaryKey; + this.version = version; + } + + public Instruction getInstruction() { + return instruction; + } + + public long[] getHash() { + return hash; + } + + public K getPrimaryKey() { + return primaryKey; + } + + public byte getVersion() { + return version; + } + + @Override + public String toString() { + return "SubscriptionWrapper{" + + "version=" + version + + ", primaryKey=" + primaryKey + + ", instruction=" + instruction + + ", hash=" + Arrays.toString(hash) + + '}'; + } +} + diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/foreignkeyjoin/SubscriptionWrapperSerde.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/foreignkeyjoin/SubscriptionWrapperSerde.java new file mode 100644 index 0000000..e713762 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/foreignkeyjoin/SubscriptionWrapperSerde.java @@ -0,0 +1,159 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.kstream.internals.foreignkeyjoin; + +import org.apache.kafka.common.errors.UnsupportedVersionException; +import org.apache.kafka.common.serialization.Deserializer; +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.common.serialization.Serializer; +import org.apache.kafka.streams.kstream.internals.WrappingNullableDeserializer; +import org.apache.kafka.streams.kstream.internals.WrappingNullableSerde; +import org.apache.kafka.streams.kstream.internals.WrappingNullableSerializer; +import org.apache.kafka.streams.processor.internals.SerdeGetter; + +import java.nio.ByteBuffer; +import java.util.function.Supplier; + +public class SubscriptionWrapperSerde extends WrappingNullableSerde, K, Void> { + public SubscriptionWrapperSerde(final Supplier primaryKeySerializationPseudoTopicSupplier, + final Serde primaryKeySerde) { + super( + new SubscriptionWrapperSerializer<>(primaryKeySerializationPseudoTopicSupplier, + primaryKeySerde == null ? null : primaryKeySerde.serializer()), + new SubscriptionWrapperDeserializer<>(primaryKeySerializationPseudoTopicSupplier, + primaryKeySerde == null ? null : primaryKeySerde.deserializer()) + ); + } + + private static class SubscriptionWrapperSerializer + implements Serializer>, WrappingNullableSerializer, K, Void> { + + private final Supplier primaryKeySerializationPseudoTopicSupplier; + private String primaryKeySerializationPseudoTopic = null; + private Serializer primaryKeySerializer; + + SubscriptionWrapperSerializer(final Supplier primaryKeySerializationPseudoTopicSupplier, + final Serializer primaryKeySerializer) { + this.primaryKeySerializationPseudoTopicSupplier = primaryKeySerializationPseudoTopicSupplier; + this.primaryKeySerializer = primaryKeySerializer; + } + + @SuppressWarnings("unchecked") + @Override + public void setIfUnset(final SerdeGetter getter) { + if (primaryKeySerializer == null) { + primaryKeySerializer = (Serializer) getter.keySerde().serializer(); + } + } + + @Override + public byte[] serialize(final String ignored, final SubscriptionWrapper data) { + //{1-bit-isHashNull}{7-bits-version}{1-byte-instruction}{Optional-16-byte-Hash}{PK-serialized} + + //7-bit (0x7F) maximum for data version. + if (Byte.compare((byte) 0x7F, data.getVersion()) < 0) { + throw new UnsupportedVersionException("SubscriptionWrapper version is larger than maximum supported 0x7F"); + } + + if (primaryKeySerializationPseudoTopic == null) { + primaryKeySerializationPseudoTopic = primaryKeySerializationPseudoTopicSupplier.get(); + } + + final byte[] primaryKeySerializedData = primaryKeySerializer.serialize( + primaryKeySerializationPseudoTopic, + data.getPrimaryKey() + ); + + final ByteBuffer buf; + if (data.getHash() != null) { + buf = ByteBuffer.allocate(2 + 2 * Long.BYTES + primaryKeySerializedData.length); + buf.put(data.getVersion()); + } else { + //Don't store hash as it's null. 
+ buf = ByteBuffer.allocate(2 + primaryKeySerializedData.length); + buf.put((byte) (data.getVersion() | (byte) 0x80)); + } + + buf.put(data.getInstruction().getValue()); + final long[] elem = data.getHash(); + if (data.getHash() != null) { + buf.putLong(elem[0]); + buf.putLong(elem[1]); + } + buf.put(primaryKeySerializedData); + return buf.array(); + } + + } + + private static class SubscriptionWrapperDeserializer + implements Deserializer>, WrappingNullableDeserializer, K, Void> { + + private final Supplier primaryKeySerializationPseudoTopicSupplier; + private String primaryKeySerializationPseudoTopic = null; + private Deserializer primaryKeyDeserializer; + + SubscriptionWrapperDeserializer(final Supplier primaryKeySerializationPseudoTopicSupplier, + final Deserializer primaryKeyDeserializer) { + this.primaryKeySerializationPseudoTopicSupplier = primaryKeySerializationPseudoTopicSupplier; + this.primaryKeyDeserializer = primaryKeyDeserializer; + } + + @SuppressWarnings("unchecked") + @Override + public void setIfUnset(final SerdeGetter getter) { + if (primaryKeyDeserializer == null) { + primaryKeyDeserializer = (Deserializer) getter.keySerde().deserializer(); + } + } + + @Override + public SubscriptionWrapper deserialize(final String ignored, final byte[] data) { + //{7-bits-version}{1-bit-isHashNull}{1-byte-instruction}{Optional-16-byte-Hash}{PK-serialized} + final ByteBuffer buf = ByteBuffer.wrap(data); + final byte versionAndIsHashNull = buf.get(); + final byte version = (byte) (0x7F & versionAndIsHashNull); + final boolean isHashNull = (0x80 & versionAndIsHashNull) == 0x80; + final SubscriptionWrapper.Instruction inst = SubscriptionWrapper.Instruction.fromValue(buf.get()); + + final long[] hash; + int lengthSum = 2; //The first 2 bytes + if (isHashNull) { + hash = null; + } else { + hash = new long[2]; + hash[0] = buf.getLong(); + hash[1] = buf.getLong(); + lengthSum += 2 * Long.BYTES; + } + + final byte[] primaryKeyRaw = new byte[data.length - lengthSum]; //The remaining data is the serialized pk + buf.get(primaryKeyRaw, 0, primaryKeyRaw.length); + + if (primaryKeySerializationPseudoTopic == null) { + primaryKeySerializationPseudoTopic = primaryKeySerializationPseudoTopicSupplier.get(); + } + + final K primaryKey = primaryKeyDeserializer.deserialize(primaryKeySerializationPseudoTopic, + primaryKeyRaw); + + return new SubscriptionWrapper<>(hash, inst, primaryKey, version); + } + + } + +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/graph/BaseJoinProcessorNode.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/graph/BaseJoinProcessorNode.java new file mode 100644 index 0000000..121db53 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/graph/BaseJoinProcessorNode.java @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
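The subscription serde above uses the same flag-plus-version byte, then one instruction byte, an optional 16-byte (two long) hash, and finally the serialized primary key; on the read side the key length is the total length minus that fixed prefix. A minimal standalone sketch of the layout, with a dummy byte payload standing in for the real primary-key serializer (all names here are illustrative, not from the patch):

import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;

// Standalone illustration of the byte layout used by the serde above:
// {flag+version byte}{instruction byte}{optional 16-byte hash}{primary-key bytes}.
final class SubscriptionLayoutSketch {
    public static void main(final String[] args) {
        final byte version = 0;
        final byte instruction = 0x02;                 // e.g. PROPAGATE_NULL_IF_NO_FK_VAL_AVAILABLE
        final long[] hash = {1L, 2L};                  // may also be null
        final byte[] pk = "pk-1".getBytes(StandardCharsets.UTF_8);

        final int hashLength = hash == null ? 0 : 2 * Long.BYTES;
        final ByteBuffer buf = ByteBuffer.allocate(2 + hashLength + pk.length);
        buf.put(hash == null ? (byte) (version | 0x80) : version);
        buf.put(instruction);
        if (hash != null) {
            buf.putLong(hash[0]);
            buf.putLong(hash[1]);
        }
        buf.put(pk);
        final byte[] wire = buf.array();

        // Deserialization side: everything after the fixed prefix is the primary key.
        final int prefix = 2 + ((0x80 & wire[0]) == 0x80 ? 0 : 2 * Long.BYTES);
        System.out.println("total=" + wire.length + " pkBytes=" + (wire.length - prefix)); // pkBytes=4
    }
}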
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.streams.kstream.internals.graph; + +import org.apache.kafka.streams.kstream.ValueJoinerWithKey; + +/** + * Utility base class containing the common fields between + * a Stream-Stream join and a Table-Table join + */ +abstract class BaseJoinProcessorNode extends GraphNode { + + private final ProcessorParameters joinThisProcessorParameters; + private final ProcessorParameters joinOtherProcessorParameters; + private final ProcessorParameters joinMergeProcessorParameters; + private final ValueJoinerWithKey valueJoiner; + private final String thisJoinSideNodeName; + private final String otherJoinSideNodeName; + + + BaseJoinProcessorNode(final String nodeName, + final ValueJoinerWithKey valueJoiner, + final ProcessorParameters joinThisProcessorParameters, + final ProcessorParameters joinOtherProcessorParameters, + final ProcessorParameters joinMergeProcessorParameters, + final String thisJoinSideNodeName, + final String otherJoinSideNodeName) { + + super(nodeName); + + this.valueJoiner = valueJoiner; + this.joinThisProcessorParameters = joinThisProcessorParameters; + this.joinOtherProcessorParameters = joinOtherProcessorParameters; + this.joinMergeProcessorParameters = joinMergeProcessorParameters; + this.thisJoinSideNodeName = thisJoinSideNodeName; + this.otherJoinSideNodeName = otherJoinSideNodeName; + } + + ProcessorParameters thisProcessorParameters() { + return joinThisProcessorParameters; + } + + ProcessorParameters otherProcessorParameters() { + return joinOtherProcessorParameters; + } + + ProcessorParameters mergeProcessorParameters() { + return joinMergeProcessorParameters; + } + + String thisJoinSideNodeName() { + return thisJoinSideNodeName; + } + + String otherJoinSideNodeName() { + return otherJoinSideNodeName; + } + + @Override + public String toString() { + return "BaseJoinProcessorNode{" + + "joinThisProcessorParameters=" + joinThisProcessorParameters + + ", joinOtherProcessorParameters=" + joinOtherProcessorParameters + + ", joinMergeProcessorParameters=" + joinMergeProcessorParameters + + ", valueJoiner=" + valueJoiner + + ", thisJoinSideNodeName='" + thisJoinSideNodeName + '\'' + + ", otherJoinSideNodeName='" + otherJoinSideNodeName + '\'' + + "} " + super.toString(); + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/graph/BaseRepartitionNode.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/graph/BaseRepartitionNode.java new file mode 100644 index 0000000..533d820 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/graph/BaseRepartitionNode.java @@ -0,0 +1,147 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.streams.kstream.internals.graph; + +import org.apache.kafka.common.serialization.Deserializer; +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.common.serialization.Serializer; +import org.apache.kafka.streams.processor.StreamPartitioner; +import org.apache.kafka.streams.processor.internals.InternalTopicProperties; + +public abstract class BaseRepartitionNode extends GraphNode { + + protected final Serde keySerde; + protected final Serde valueSerde; + protected final String sinkName; + protected final String sourceName; + protected final String repartitionTopic; + protected final ProcessorParameters processorParameters; + protected final StreamPartitioner partitioner; + protected final InternalTopicProperties internalTopicProperties; + + BaseRepartitionNode(final String nodeName, + final String sourceName, + final ProcessorParameters processorParameters, + final Serde keySerde, + final Serde valueSerde, + final String sinkName, + final String repartitionTopic, + final StreamPartitioner partitioner, + final InternalTopicProperties internalTopicProperties) { + + super(nodeName); + + this.keySerde = keySerde; + this.valueSerde = valueSerde; + this.sinkName = sinkName; + this.sourceName = sourceName; + this.repartitionTopic = repartitionTopic; + this.processorParameters = processorParameters; + this.partitioner = partitioner; + this.internalTopicProperties = internalTopicProperties; + } + + Serializer valueSerializer() { + return valueSerde != null ? valueSerde.serializer() : null; + } + + Deserializer valueDeserializer() { + return valueSerde != null ? valueSerde.deserializer() : null; + } + + Serializer keySerializer() { + return keySerde != null ? keySerde.serializer() : null; + } + + Deserializer keyDeserializer() { + return keySerde != null ? 
keySerde.deserializer() : null; + } + + @Override + public String toString() { + return "BaseRepartitionNode{" + + "keySerde=" + keySerde + + ", valueSerde=" + valueSerde + + ", sinkName='" + sinkName + '\'' + + ", sourceName='" + sourceName + '\'' + + ", repartitionTopic='" + repartitionTopic + '\'' + + ", processorParameters=" + processorParameters + '\'' + + ", partitioner=" + partitioner + + ", internalTopicProperties=" + internalTopicProperties + + "} " + super.toString(); + } + + public abstract static class BaseRepartitionNodeBuilder> { + protected String nodeName; + protected ProcessorParameters processorParameters; + protected Serde keySerde; + protected Serde valueSerde; + protected String sinkName; + protected String sourceName; + protected String repartitionTopic; + protected StreamPartitioner partitioner; + protected InternalTopicProperties internalTopicProperties = InternalTopicProperties.empty(); + + public BaseRepartitionNodeBuilder withProcessorParameters(final ProcessorParameters processorParameters) { + this.processorParameters = processorParameters; + return this; + } + + public BaseRepartitionNodeBuilder withKeySerde(final Serde keySerde) { + this.keySerde = keySerde; + return this; + } + + public BaseRepartitionNodeBuilder withValueSerde(final Serde valueSerde) { + this.valueSerde = valueSerde; + return this; + } + + public BaseRepartitionNodeBuilder withSinkName(final String sinkName) { + this.sinkName = sinkName; + return this; + } + + public BaseRepartitionNodeBuilder withSourceName(final String sourceName) { + this.sourceName = sourceName; + return this; + } + + public BaseRepartitionNodeBuilder withRepartitionTopic(final String repartitionTopic) { + this.repartitionTopic = repartitionTopic; + return this; + } + + public BaseRepartitionNodeBuilder withStreamPartitioner(final StreamPartitioner partitioner) { + this.partitioner = partitioner; + return this; + } + + public BaseRepartitionNodeBuilder withNodeName(final String nodeName) { + this.nodeName = nodeName; + return this; + } + + public BaseRepartitionNodeBuilder withInternalTopicProperties(final InternalTopicProperties internalTopicProperties) { + this.internalTopicProperties = internalTopicProperties; + return this; + } + + public abstract T build(); + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/graph/GlobalStoreNode.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/graph/GlobalStoreNode.java new file mode 100644 index 0000000..753e076 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/graph/GlobalStoreNode.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.kstream.internals.graph; + +import org.apache.kafka.streams.kstream.internals.ConsumedInternal; +import org.apache.kafka.streams.processor.api.ProcessorSupplier; +import org.apache.kafka.streams.processor.StateStore; +import org.apache.kafka.streams.processor.internals.InternalTopologyBuilder; +import org.apache.kafka.streams.state.StoreBuilder; + +public class GlobalStoreNode extends StateStoreNode { + + private final String sourceName; + private final String topic; + private final ConsumedInternal consumed; + private final String processorName; + private final ProcessorSupplier stateUpdateSupplier; + + + public GlobalStoreNode(final StoreBuilder storeBuilder, + final String sourceName, + final String topic, + final ConsumedInternal consumed, + final String processorName, + final ProcessorSupplier stateUpdateSupplier) { + + super(storeBuilder); + this.sourceName = sourceName; + this.topic = topic; + this.consumed = consumed; + this.processorName = processorName; + this.stateUpdateSupplier = stateUpdateSupplier; + } + + @Override + public void writeToTopology(final InternalTopologyBuilder topologyBuilder) { + storeBuilder.withLoggingDisabled(); + topologyBuilder.addGlobalStore(storeBuilder, + sourceName, + consumed.timestampExtractor(), + consumed.keyDeserializer(), + consumed.valueDeserializer(), + topic, + processorName, + stateUpdateSupplier); + + } + + @Override + public String toString() { + return "GlobalStoreNode{" + + "sourceName='" + sourceName + '\'' + + ", topic='" + topic + '\'' + + ", processorName='" + processorName + '\'' + + "} "; + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/graph/GraphGraceSearchUtil.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/graph/GraphGraceSearchUtil.java new file mode 100644 index 0000000..66ffdc0 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/graph/GraphGraceSearchUtil.java @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.kstream.internals.graph; + +import org.apache.kafka.streams.errors.TopologyException; +import org.apache.kafka.streams.kstream.SessionWindows; +import org.apache.kafka.streams.kstream.SlidingWindows; +import org.apache.kafka.streams.kstream.Windows; +import org.apache.kafka.streams.kstream.internals.KStreamSessionWindowAggregate; +import org.apache.kafka.streams.kstream.internals.KStreamSlidingWindowAggregate; +import org.apache.kafka.streams.kstream.internals.KStreamWindowAggregate; +import org.apache.kafka.streams.processor.api.ProcessorSupplier; + +public final class GraphGraceSearchUtil { + private GraphGraceSearchUtil() {} + + public static long findAndVerifyWindowGrace(final GraphNode graphNode) { + return findAndVerifyWindowGrace(graphNode, ""); + } + + private static long findAndVerifyWindowGrace(final GraphNode graphNode, final String chain) { + // error base case: we traversed off the end of the graph without finding a window definition + if (graphNode == null) { + throw new TopologyException( + "Window close time is only defined for windowed computations. Got [" + chain + "]." + ); + } + // base case: return if this node defines a grace period. + { + final Long gracePeriod = extractGracePeriod(graphNode); + if (gracePeriod != null) { + return gracePeriod; + } + } + + final String newChain = chain.equals("") ? graphNode.nodeName() : graphNode.nodeName() + "->" + chain; + + if (graphNode.parentNodes().isEmpty()) { + // error base case: we traversed to the end of the graph without finding a window definition + throw new TopologyException( + "Window close time is only defined for windowed computations. Got [" + newChain + "]." + ); + } + + // recursive case: all parents must define a grace period, and we use the max of our parents' graces. 
+ long inheritedGrace = -1; + for (final GraphNode parentNode : graphNode.parentNodes()) { + final long parentGrace = findAndVerifyWindowGrace(parentNode, newChain); + inheritedGrace = Math.max(inheritedGrace, parentGrace); + } + + if (inheritedGrace == -1) { + throw new IllegalStateException(); // shouldn't happen, and it's not a legal grace period + } + + return inheritedGrace; + } + + @SuppressWarnings("rawtypes") + private static Long extractGracePeriod(final GraphNode node) { + if (node instanceof StatefulProcessorNode) { + final ProcessorSupplier processorSupplier = ((StatefulProcessorNode) node).processorParameters().processorSupplier(); + if (processorSupplier instanceof KStreamWindowAggregate) { + final KStreamWindowAggregate kStreamWindowAggregate = (KStreamWindowAggregate) processorSupplier; + final Windows windows = kStreamWindowAggregate.windows(); + return windows.gracePeriodMs(); + } else if (processorSupplier instanceof KStreamSessionWindowAggregate) { + final KStreamSessionWindowAggregate kStreamSessionWindowAggregate = (KStreamSessionWindowAggregate) processorSupplier; + final SessionWindows windows = kStreamSessionWindowAggregate.windows(); + return windows.gracePeriodMs() + windows.inactivityGap(); + } else if (processorSupplier instanceof KStreamSlidingWindowAggregate) { + final KStreamSlidingWindowAggregate kStreamSlidingWindowAggregate = (KStreamSlidingWindowAggregate) processorSupplier; + final SlidingWindows windows = kStreamSlidingWindowAggregate.windows(); + return windows.gracePeriodMs(); + } else { + return null; + } + } else { + return null; + } + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/graph/GraphNode.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/graph/GraphNode.java new file mode 100644 index 0000000..76c2b5c --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/graph/GraphNode.java @@ -0,0 +1,142 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
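The traversal above walks parent links upward: it fails if it runs off the graph before finding a windowed aggregation, and otherwise inherits the maximum grace period over all parent chains (for session windows the effective close time is the grace period plus the inactivity gap). A minimal standalone sketch of that max-over-parents recursion, assuming a toy node type that is not part of the patch:

import java.util.Arrays;
import java.util.List;

// Toy model of the recursion above: a node either defines a grace period or
// inherits the maximum grace period of its parents; running out of parents is an error.
final class GraceSearchSketch {

    static final class ToyNode {
        final String name;
        final Long grace;            // null means "no window defined on this node"
        final List<ToyNode> parents;

        ToyNode(final String name, final Long grace, final List<ToyNode> parents) {
            this.name = name;
            this.grace = grace;
            this.parents = parents;
        }
    }

    static long findGrace(final ToyNode node) {
        if (node.grace != null) {
            return node.grace;
        }
        if (node.parents.isEmpty()) {
            throw new IllegalStateException("no windowed computation upstream of " + node.name);
        }
        long inherited = -1;
        for (final ToyNode parent : node.parents) {
            inherited = Math.max(inherited, findGrace(parent));
        }
        return inherited;
    }

    public static void main(final String[] args) {
        final ToyNode winA = new ToyNode("window-a", 1_000L, Arrays.asList());
        final ToyNode winB = new ToyNode("window-b", 5_000L, Arrays.asList());
        final ToyNode merge = new ToyNode("merge", null, Arrays.asList(winA, winB));
        System.out.println(findGrace(merge)); // 5000, the maximum of the parents' grace periods
    }
}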
+ */ + +package org.apache.kafka.streams.kstream.internals.graph; + + +import org.apache.kafka.streams.processor.internals.InternalTopologyBuilder; + +import java.util.Arrays; +import java.util.Collection; +import java.util.LinkedHashSet; + +public abstract class GraphNode { + + private final Collection childNodes = new LinkedHashSet<>(); + private final Collection parentNodes = new LinkedHashSet<>(); + private final String nodeName; + private boolean keyChangingOperation; + private boolean valueChangingOperation; + private boolean mergeNode; + private Integer buildPriority; + private boolean hasWrittenToTopology = false; + + public GraphNode(final String nodeName) { + this.nodeName = nodeName; + } + + public Collection parentNodes() { + return parentNodes; + } + + String[] parentNodeNames() { + final String[] parentNames = new String[parentNodes.size()]; + int index = 0; + for (final GraphNode parentNode : parentNodes) { + parentNames[index++] = parentNode.nodeName(); + } + return parentNames; + } + + public boolean allParentsWrittenToTopology() { + for (final GraphNode parentNode : parentNodes) { + if (!parentNode.hasWrittenToTopology()) { + return false; + } + } + return true; + } + + public Collection children() { + return new LinkedHashSet<>(childNodes); + } + + public void clearChildren() { + for (final GraphNode childNode : childNodes) { + childNode.parentNodes.remove(this); + } + childNodes.clear(); + } + + public boolean removeChild(final GraphNode child) { + return childNodes.remove(child) && child.parentNodes.remove(this); + } + + public void addChild(final GraphNode childNode) { + this.childNodes.add(childNode); + childNode.parentNodes.add(this); + } + + public String nodeName() { + return nodeName; + } + + public boolean isKeyChangingOperation() { + return keyChangingOperation; + } + + public boolean isValueChangingOperation() { + return valueChangingOperation; + } + + public boolean isMergeNode() { + return mergeNode; + } + + public void setMergeNode(final boolean mergeNode) { + this.mergeNode = mergeNode; + } + + public void setValueChangingOperation(final boolean valueChangingOperation) { + this.valueChangingOperation = valueChangingOperation; + } + + public void keyChangingOperation(final boolean keyChangingOperation) { + this.keyChangingOperation = keyChangingOperation; + } + + public void setBuildPriority(final int buildPriority) { + this.buildPriority = buildPriority; + } + + public Integer buildPriority() { + return this.buildPriority; + } + + public abstract void writeToTopology(final InternalTopologyBuilder topologyBuilder); + + public boolean hasWrittenToTopology() { + return hasWrittenToTopology; + } + + public void setHasWrittenToTopology(final boolean hasWrittenToTopology) { + this.hasWrittenToTopology = hasWrittenToTopology; + } + + @Override + public String toString() { + final String[] parentNames = parentNodeNames(); + return "StreamsGraphNode{" + + "nodeName='" + nodeName + '\'' + + ", buildPriority=" + buildPriority + + ", hasWrittenToTopology=" + hasWrittenToTopology + + ", keyChangingOperation=" + keyChangingOperation + + ", valueChangingOperation=" + valueChangingOperation + + ", mergeNode=" + mergeNode + + ", parentNodes=" + Arrays.toString(parentNames) + '}'; + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/graph/GroupedTableOperationRepartitionNode.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/graph/GroupedTableOperationRepartitionNode.java new file mode 100644 index 0000000..a7ba30d --- 
/dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/graph/GroupedTableOperationRepartitionNode.java @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.streams.kstream.internals.graph; + +import org.apache.kafka.common.serialization.Deserializer; +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.common.serialization.Serializer; +import org.apache.kafka.streams.kstream.internals.ChangedDeserializer; +import org.apache.kafka.streams.kstream.internals.ChangedSerializer; +import org.apache.kafka.streams.processor.FailOnInvalidTimestamp; +import org.apache.kafka.streams.processor.internals.InternalTopicProperties; +import org.apache.kafka.streams.processor.internals.InternalTopologyBuilder; + +public class GroupedTableOperationRepartitionNode extends BaseRepartitionNode { + + + private GroupedTableOperationRepartitionNode(final String nodeName, + final Serde keySerde, + final Serde valueSerde, + final String sinkName, + final String sourceName, + final String repartitionTopic, + final ProcessorParameters processorParameters) { + + super( + nodeName, + sourceName, + processorParameters, + keySerde, + valueSerde, + sinkName, + repartitionTopic, + null, + InternalTopicProperties.empty() + ); + } + + @Override + Serializer valueSerializer() { + final Serializer valueSerializer = super.valueSerializer(); + return unsafeCastChangedToValueSerializer(valueSerializer); + } + + @SuppressWarnings("unchecked") + private Serializer unsafeCastChangedToValueSerializer(final Serializer valueSerializer) { + return (Serializer) new ChangedSerializer<>(valueSerializer); + } + + @Override + Deserializer valueDeserializer() { + final Deserializer valueDeserializer = super.valueDeserializer(); + return unsafeCastChangedToValueDeserializer(valueDeserializer); + } + + @SuppressWarnings("unchecked") + private Deserializer unsafeCastChangedToValueDeserializer(final Deserializer valueDeserializer) { + return (Deserializer) new ChangedDeserializer<>(valueDeserializer); + } + + @Override + public String toString() { + return "GroupedTableOperationRepartitionNode{} " + super.toString(); + } + + @Override + public void writeToTopology(final InternalTopologyBuilder topologyBuilder) { + topologyBuilder.addInternalTopic(repartitionTopic, internalTopicProperties); + + topologyBuilder.addSink( + sinkName, + repartitionTopic, + keySerializer(), + valueSerializer(), + null, + parentNodeNames() + ); + + topologyBuilder.addSource( + null, + sourceName, + new FailOnInvalidTimestamp(), + keyDeserializer(), + valueDeserializer(), + repartitionTopic + ); + + } + + public static GroupedTableOperationRepartitionNodeBuilder groupedTableOperationNodeBuilder() { + return new 
GroupedTableOperationRepartitionNodeBuilder<>(); + } + + public static final class GroupedTableOperationRepartitionNodeBuilder extends BaseRepartitionNodeBuilder> { + + @Override + public GroupedTableOperationRepartitionNode build() { + return new GroupedTableOperationRepartitionNode<>( + nodeName, + keySerde, + valueSerde, + sinkName, + sourceName, + repartitionTopic, + processorParameters + ); + } + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/graph/KTableKTableJoinNode.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/graph/KTableKTableJoinNode.java new file mode 100644 index 0000000..e6bd49d --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/graph/KTableKTableJoinNode.java @@ -0,0 +1,232 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.streams.kstream.internals.graph; + +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.streams.kstream.internals.Change; +import org.apache.kafka.streams.kstream.internals.KTableKTableJoinMerger; +import org.apache.kafka.streams.processor.internals.InternalTopologyBuilder; +import org.apache.kafka.streams.state.StoreBuilder; +import org.apache.kafka.streams.state.TimestampedKeyValueStore; + +import java.util.Arrays; + +/** + * Too much specific information to generalize so the KTable-KTable join requires a specific node. 
+ */ +public class KTableKTableJoinNode extends BaseJoinProcessorNode, Change, Change> { + + private final Serde keySerde; + private final Serde valueSerde; + private final String[] joinThisStoreNames; + private final String[] joinOtherStoreNames; + private final StoreBuilder> storeBuilder; + + KTableKTableJoinNode(final String nodeName, + final ProcessorParameters, ?, ?> joinThisProcessorParameters, + final ProcessorParameters, ?, ?> joinOtherProcessorParameters, + final ProcessorParameters, ?, ?> joinMergeProcessorParameters, + final String thisJoinSide, + final String otherJoinSide, + final Serde keySerde, + final Serde valueSerde, + final String[] joinThisStoreNames, + final String[] joinOtherStoreNames, + final StoreBuilder> storeBuilder) { + + super(nodeName, + null, + joinThisProcessorParameters, + joinOtherProcessorParameters, + joinMergeProcessorParameters, + thisJoinSide, + otherJoinSide); + + this.keySerde = keySerde; + this.valueSerde = valueSerde; + this.joinThisStoreNames = joinThisStoreNames; + this.joinOtherStoreNames = joinOtherStoreNames; + this.storeBuilder = storeBuilder; + } + + public Serde keySerde() { + return keySerde; + } + + public Serde valueSerde() { + return valueSerde; + } + + public String[] joinThisStoreNames() { + return joinThisStoreNames; + } + + public String[] joinOtherStoreNames() { + return joinOtherStoreNames; + } + + public String queryableStoreName() { + return mergeProcessorParameters().kTableKTableJoinMergerProcessorSupplier().getQueryableName(); + } + + /** + * The supplier which provides processor with KTable-KTable join merge functionality. + */ + @SuppressWarnings("unchecked") + public KTableKTableJoinMerger joinMerger() { + final KTableKTableJoinMerger> merger = + mergeProcessorParameters().kTableKTableJoinMergerProcessorSupplier(); + // this incorrect cast should be corrected by the end of the KIP-478 implementation + return (KTableKTableJoinMerger) merger; + } + + @Override + public void writeToTopology(final InternalTopologyBuilder topologyBuilder) { + final String thisProcessorName = thisProcessorParameters().processorName(); + final String otherProcessorName = otherProcessorParameters().processorName(); + final String mergeProcessorName = mergeProcessorParameters().processorName(); + + topologyBuilder.addProcessor( + thisProcessorName, + thisProcessorParameters().processorSupplier(), + thisJoinSideNodeName()); + + topologyBuilder.addProcessor( + otherProcessorName, + otherProcessorParameters().processorSupplier(), + otherJoinSideNodeName()); + + topologyBuilder.addProcessor( + mergeProcessorName, + mergeProcessorParameters().processorSupplier(), + thisProcessorName, + otherProcessorName); + + topologyBuilder.connectProcessorAndStateStores(thisProcessorName, joinOtherStoreNames); + topologyBuilder.connectProcessorAndStateStores(otherProcessorName, joinThisStoreNames); + + if (storeBuilder != null) { + topologyBuilder.addStateStore(storeBuilder, mergeProcessorName); + } + } + + @Override + public String toString() { + return "KTableKTableJoinNode{" + + "joinThisStoreNames=" + Arrays.toString(joinThisStoreNames()) + + ", joinOtherStoreNames=" + Arrays.toString(joinOtherStoreNames()) + + "} " + super.toString(); + } + + public static KTableKTableJoinNodeBuilder kTableKTableJoinNodeBuilder() { + return new KTableKTableJoinNodeBuilder<>(); + } + + public static final class KTableKTableJoinNodeBuilder { + private String nodeName; + private ProcessorParameters, ?, ?> joinThisProcessorParameters; + private ProcessorParameters, ?, ?> 
joinOtherProcessorParameters; + private String thisJoinSide; + private String otherJoinSide; + private Serde keySerde; + private Serde valueSerde; + private String[] joinThisStoreNames; + private String[] joinOtherStoreNames; + private String queryableStoreName; + private StoreBuilder> storeBuilder; + + private KTableKTableJoinNodeBuilder() { + } + + public KTableKTableJoinNodeBuilder withNodeName(final String nodeName) { + this.nodeName = nodeName; + return this; + } + + public KTableKTableJoinNodeBuilder withJoinThisProcessorParameters(final ProcessorParameters, ?, ?> joinThisProcessorParameters) { + this.joinThisProcessorParameters = joinThisProcessorParameters; + return this; + } + + public KTableKTableJoinNodeBuilder withJoinOtherProcessorParameters(final ProcessorParameters, ?, ?> joinOtherProcessorParameters) { + this.joinOtherProcessorParameters = joinOtherProcessorParameters; + return this; + } + + public KTableKTableJoinNodeBuilder withThisJoinSideNodeName(final String thisJoinSide) { + this.thisJoinSide = thisJoinSide; + return this; + } + + public KTableKTableJoinNodeBuilder withOtherJoinSideNodeName(final String otherJoinSide) { + this.otherJoinSide = otherJoinSide; + return this; + } + + public KTableKTableJoinNodeBuilder withKeySerde(final Serde keySerde) { + this.keySerde = keySerde; + return this; + } + + public KTableKTableJoinNodeBuilder withValueSerde(final Serde valueSerde) { + this.valueSerde = valueSerde; + return this; + } + + public KTableKTableJoinNodeBuilder withJoinThisStoreNames(final String[] joinThisStoreNames) { + this.joinThisStoreNames = joinThisStoreNames; + return this; + } + + public KTableKTableJoinNodeBuilder withJoinOtherStoreNames(final String[] joinOtherStoreNames) { + this.joinOtherStoreNames = joinOtherStoreNames; + return this; + } + + public KTableKTableJoinNodeBuilder withQueryableStoreName(final String queryableStoreName) { + this.queryableStoreName = queryableStoreName; + return this; + } + + public KTableKTableJoinNodeBuilder withStoreBuilder(final StoreBuilder> storeBuilder) { + this.storeBuilder = storeBuilder; + return this; + } + + public KTableKTableJoinNode build() { + return new KTableKTableJoinNode<>( + nodeName, + joinThisProcessorParameters, + joinOtherProcessorParameters, + new ProcessorParameters<>( + KTableKTableJoinMerger.of( + joinThisProcessorParameters.kTableProcessorSupplier(), + joinOtherProcessorParameters.kTableProcessorSupplier(), + queryableStoreName), + nodeName), + thisJoinSide, + otherJoinSide, + keySerde, + valueSerde, + joinThisStoreNames, + joinOtherStoreNames, + storeBuilder + ); + } + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/graph/OptimizableRepartitionNode.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/graph/OptimizableRepartitionNode.java new file mode 100644 index 0000000..a9693ec --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/graph/OptimizableRepartitionNode.java @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.streams.kstream.internals.graph; + +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.streams.processor.FailOnInvalidTimestamp; +import org.apache.kafka.streams.processor.StreamPartitioner; +import org.apache.kafka.streams.processor.internals.InternalTopicProperties; +import org.apache.kafka.streams.processor.internals.InternalTopologyBuilder; + +public class OptimizableRepartitionNode extends BaseRepartitionNode { + + private OptimizableRepartitionNode(final String nodeName, + final String sourceName, + final ProcessorParameters processorParameters, + final Serde keySerde, + final Serde valueSerde, + final String sinkName, + final String repartitionTopic, + final StreamPartitioner partitioner) { + super( + nodeName, + sourceName, + processorParameters, + keySerde, + valueSerde, + sinkName, + repartitionTopic, + partitioner, + InternalTopicProperties.empty() + ); + } + + public Serde keySerde() { + return keySerde; + } + + public Serde valueSerde() { + return valueSerde; + } + + public String repartitionTopic() { + return repartitionTopic; + } + + @Override + public String toString() { + return "OptimizableRepartitionNode{ " + super.toString() + " }"; + } + + @Override + public void writeToTopology(final InternalTopologyBuilder topologyBuilder) { + topologyBuilder.addInternalTopic(repartitionTopic, internalTopicProperties); + + topologyBuilder.addProcessor( + processorParameters.processorName(), + processorParameters.processorSupplier(), + parentNodeNames() + ); + + topologyBuilder.addSink( + sinkName, + repartitionTopic, + keySerializer(), + valueSerializer(), + partitioner, + processorParameters.processorName() + ); + + topologyBuilder.addSource( + null, + sourceName, + new FailOnInvalidTimestamp(), + keyDeserializer(), + valueDeserializer(), + repartitionTopic + ); + + } + + public static OptimizableRepartitionNodeBuilder optimizableRepartitionNodeBuilder() { + return new OptimizableRepartitionNodeBuilder<>(); + } + + + public static final class OptimizableRepartitionNodeBuilder extends BaseRepartitionNodeBuilder> { + + @Override + public OptimizableRepartitionNode build() { + + return new OptimizableRepartitionNode<>( + nodeName, + sourceName, + processorParameters, + keySerde, + valueSerde, + sinkName, + repartitionTopic, + partitioner + ); + + } + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/graph/ProcessorGraphNode.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/graph/ProcessorGraphNode.java new file mode 100644 index 0000000..a38f516 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/graph/ProcessorGraphNode.java @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kafka.streams.kstream.internals.graph;
+
+import org.apache.kafka.streams.processor.internals.InternalTopologyBuilder;
+
+/**
+ * Used to represent any type of stateless operation:
+ *
+ * map, mapValues, flatMap, flatMapValues, filter, filterNot, branch
+ */
+public class ProcessorGraphNode<K, V> extends GraphNode {
+
+    private final ProcessorParameters<K, V, ?, ?> processorParameters;
+
+    public ProcessorGraphNode(final ProcessorParameters<K, V, ?, ?> processorParameters) {
+
+        super(processorParameters.processorName());
+
+        this.processorParameters = processorParameters;
+    }
+
+    public ProcessorGraphNode(final String nodeName,
+                              final ProcessorParameters<K, V, ?, ?> processorParameters) {
+
+        super(nodeName);
+
+        this.processorParameters = processorParameters;
+    }
+
+    public ProcessorParameters<K, V, ?, ?> processorParameters() {
+        return processorParameters;
+    }
+
+    @Override
+    public String toString() {
+        return "ProcessorNode{" +
+               "processorParameters=" + processorParameters +
+               "} " + super.toString();
+    }
+
+    @Override
+    public void writeToTopology(final InternalTopologyBuilder topologyBuilder) {
+
+        topologyBuilder.addProcessor(processorParameters.processorName(), processorParameters.processorSupplier(), parentNodeNames());
+    }
+}
diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/graph/ProcessorParameters.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/graph/ProcessorParameters.java
new file mode 100644
index 0000000..da59ef6
--- /dev/null
+++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/graph/ProcessorParameters.java
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +package org.apache.kafka.streams.kstream.internals.graph; + +import org.apache.kafka.streams.kstream.internals.KTableKTableJoinMerger; +import org.apache.kafka.streams.kstream.internals.KTableProcessorSupplier; +import org.apache.kafka.streams.kstream.internals.KTableSource; +import org.apache.kafka.streams.processor.api.ProcessorSupplier; +import org.apache.kafka.streams.processor.internals.ProcessorAdapter; + +/** + * Class used to represent a {@link ProcessorSupplier} and the name + * used to register it with the {@link org.apache.kafka.streams.processor.internals.InternalTopologyBuilder} + * + * Used by the Join nodes as there are several parameters, this abstraction helps + * keep the number of arguments more reasonable. + */ +public class ProcessorParameters { + + // During the transition to KIP-478, we capture arguments passed from the old API to simplify + // the performance of casts that we still need to perform. This will eventually be removed. + @SuppressWarnings("deprecation") // Old PAPI. Needs to be migrated. + private final org.apache.kafka.streams.processor.ProcessorSupplier oldProcessorSupplier; + private final ProcessorSupplier processorSupplier; + private final String processorName; + + @SuppressWarnings("deprecation") // Old PAPI compatibility. + public ProcessorParameters(final org.apache.kafka.streams.processor.ProcessorSupplier processorSupplier, + final String processorName) { + oldProcessorSupplier = processorSupplier; + this.processorSupplier = () -> ProcessorAdapter.adapt(processorSupplier.get()); + this.processorName = processorName; + } + + public ProcessorParameters(final ProcessorSupplier processorSupplier, + final String processorName) { + oldProcessorSupplier = null; + this.processorSupplier = processorSupplier; + this.processorName = processorName; + } + + public ProcessorSupplier processorSupplier() { + return processorSupplier; + } + + @SuppressWarnings("deprecation") // Old PAPI. Needs to be migrated. + public org.apache.kafka.streams.processor.ProcessorSupplier oldProcessorSupplier() { + return oldProcessorSupplier; + } + + @SuppressWarnings("unchecked") + KTableSource kTableSourceSupplier() { + return processorSupplier instanceof KTableSource ? (KTableSource) processorSupplier : null; + } + + @SuppressWarnings("unchecked") + KTableProcessorSupplier kTableProcessorSupplier() { + // This cast always works because KTableProcessorSupplier hasn't been converted yet. + return (KTableProcessorSupplier) oldProcessorSupplier; + } + + @SuppressWarnings("unchecked") + KTableKTableJoinMerger kTableKTableJoinMergerProcessorSupplier() { + return (KTableKTableJoinMerger) oldProcessorSupplier; + } + + public String processorName() { + return processorName; + } + + @Override + public String toString() { + return "ProcessorParameters{" + + "processor class=" + processorSupplier.get().getClass() + + ", processor name='" + processorName + '\'' + + '}'; + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/graph/SourceGraphNode.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/graph/SourceGraphNode.java new file mode 100644 index 0000000..0505227 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/graph/SourceGraphNode.java @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.streams.kstream.internals.graph;
+
+import java.util.Collection;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.Optional;
+import java.util.Set;
+import java.util.regex.Pattern;
+import org.apache.kafka.common.serialization.Serde;
+import org.apache.kafka.streams.kstream.internals.ConsumedInternal;
+
+abstract public class SourceGraphNode<K, V> extends GraphNode {
+
+    private final Set<String> topicNames;
+    private final Pattern topicPattern;
+    private final ConsumedInternal<K, V> consumedInternal;
+
+    public SourceGraphNode(final String nodeName,
+                           final Collection<String> topicNames,
+                           final ConsumedInternal<K, V> consumedInternal) {
+        super(nodeName);
+
+        this.topicNames = new HashSet<>(topicNames);
+        this.topicPattern = null;
+        this.consumedInternal = consumedInternal;
+    }
+
+    public SourceGraphNode(final String nodeName,
+                           final Pattern topicPattern,
+                           final ConsumedInternal<K, V> consumedInternal) {
+
+        super(nodeName);
+
+        this.topicNames = null;
+        this.topicPattern = topicPattern;
+        this.consumedInternal = consumedInternal;
+    }
+
+    public Optional<Set<String>> topicNames() {
+        return topicNames == null ? Optional.empty() : Optional.of(Collections.unmodifiableSet(topicNames));
+    }
+
+    public Optional<Pattern> topicPattern() {
+        return Optional.ofNullable(topicPattern);
+    }
+
+    public ConsumedInternal<K, V> consumedInternal() {
+        return consumedInternal;
+    }
+
+    public Serde<K> keySerde() {
+        return consumedInternal.keySerde();
+    }
+
+    public Serde<V> valueSerde() {
+        return consumedInternal.valueSerde();
+    }
+
+}
\ No newline at end of file
diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/graph/StateStoreNode.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/graph/StateStoreNode.java
new file mode 100644
index 0000000..32dc93d
--- /dev/null
+++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/graph/StateStoreNode.java
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +package org.apache.kafka.streams.kstream.internals.graph; + +import org.apache.kafka.streams.processor.StateStore; +import org.apache.kafka.streams.processor.internals.InternalTopologyBuilder; +import org.apache.kafka.streams.state.StoreBuilder; + +public class StateStoreNode extends GraphNode { + + protected final StoreBuilder storeBuilder; + + public StateStoreNode(final StoreBuilder storeBuilder) { + super(storeBuilder.name()); + + this.storeBuilder = storeBuilder; + } + + @Override + public void writeToTopology(final InternalTopologyBuilder topologyBuilder) { + + topologyBuilder.addStateStore(storeBuilder); + } + + @Override + public String toString() { + return "StateStoreNode{" + + " name='" + storeBuilder.name() + '\'' + + ", logConfig=" + storeBuilder.logConfig() + + ", loggingEnabled='" + storeBuilder.loggingEnabled() + '\'' + + "} "; + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/graph/StatefulProcessorNode.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/graph/StatefulProcessorNode.java new file mode 100644 index 0000000..381a88a --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/graph/StatefulProcessorNode.java @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.kstream.internals.graph; + +import org.apache.kafka.streams.kstream.internals.KTableValueGetterSupplier; +import org.apache.kafka.streams.processor.api.ProcessorSupplier; +import org.apache.kafka.streams.processor.internals.InternalTopologyBuilder; +import org.apache.kafka.streams.state.StoreBuilder; + +import java.util.Arrays; +import java.util.Set; +import java.util.stream.Stream; + +public class StatefulProcessorNode extends ProcessorGraphNode { + + private final String[] storeNames; + private final StoreBuilder storeBuilder; + + /** + * Create a node representing a stateful processor, where the named stores have already been registered. + */ + public StatefulProcessorNode(final ProcessorParameters processorParameters, + final Set> preRegisteredStores, + final Set> valueGetterSuppliers) { + super(processorParameters.processorName(), processorParameters); + final Stream registeredStoreNames = preRegisteredStores.stream().map(StoreBuilder::name); + final Stream valueGetterStoreNames = valueGetterSuppliers.stream().flatMap(s -> Arrays.stream(s.storeNames())); + storeNames = Stream.concat(registeredStoreNames, valueGetterStoreNames).toArray(String[]::new); + storeBuilder = null; + } + + /** + * Create a node representing a stateful processor, where the named stores have already been registered. 
+ */ + public StatefulProcessorNode(final String nodeName, + final ProcessorParameters processorParameters, + final String[] storeNames) { + super(nodeName, processorParameters); + + this.storeNames = storeNames; + this.storeBuilder = null; + } + + + /** + * Create a node representing a stateful processor, + * where the store needs to be built and registered as part of building this node. + */ + public StatefulProcessorNode(final String nodeName, + final ProcessorParameters processorParameters, + final StoreBuilder materializedKTableStoreBuilder) { + super(nodeName, processorParameters); + + this.storeNames = null; + this.storeBuilder = materializedKTableStoreBuilder; + } + + @Override + public String toString() { + return "StatefulProcessorNode{" + + "storeNames=" + Arrays.toString(storeNames) + + ", storeBuilder=" + storeBuilder + + "} " + super.toString(); + } + + @Override + public void writeToTopology(final InternalTopologyBuilder topologyBuilder) { + + final String processorName = processorParameters().processorName(); + final ProcessorSupplier processorSupplier = processorParameters().processorSupplier(); + + topologyBuilder.addProcessor(processorName, processorSupplier, parentNodeNames()); + + if (storeNames != null && storeNames.length > 0) { + topologyBuilder.connectProcessorAndStateStores(processorName, storeNames); + } + + if (storeBuilder != null) { + topologyBuilder.addStateStore(storeBuilder, processorName); + } + + if (processorSupplier.stores() != null) { + for (final StoreBuilder storeBuilder : processorSupplier.stores()) { + topologyBuilder.addStateStore(storeBuilder, processorName); + } + } + + // temporary hack until KIP-478 is fully implemented + @SuppressWarnings("deprecation") // Old PAPI. Needs to be migrated. + final org.apache.kafka.streams.processor.ProcessorSupplier oldProcessorSupplier = + processorParameters().oldProcessorSupplier(); + if (oldProcessorSupplier != null && oldProcessorSupplier.stores() != null) { + for (final StoreBuilder storeBuilder : oldProcessorSupplier.stores()) { + topologyBuilder.addStateStore(storeBuilder, processorName); + } + } + + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/graph/StreamSinkNode.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/graph/StreamSinkNode.java new file mode 100644 index 0000000..f12a9e5 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/graph/StreamSinkNode.java @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.streams.kstream.internals.graph; + +import org.apache.kafka.common.serialization.Serializer; +import org.apache.kafka.streams.kstream.internals.ProducedInternal; +import org.apache.kafka.streams.kstream.internals.WindowedSerializer; +import org.apache.kafka.streams.kstream.internals.WindowedStreamPartitioner; +import org.apache.kafka.streams.processor.StreamPartitioner; +import org.apache.kafka.streams.processor.TopicNameExtractor; +import org.apache.kafka.streams.processor.internals.InternalTopologyBuilder; +import org.apache.kafka.streams.processor.internals.StaticTopicNameExtractor; + +public class StreamSinkNode extends GraphNode { + + private final TopicNameExtractor topicNameExtractor; + private final ProducedInternal producedInternal; + + public StreamSinkNode(final String nodeName, + final TopicNameExtractor topicNameExtractor, + final ProducedInternal producedInternal) { + + super(nodeName); + + this.topicNameExtractor = topicNameExtractor; + this.producedInternal = producedInternal; + } + + + @Override + public String toString() { + return "StreamSinkNode{" + + "topicNameExtractor=" + topicNameExtractor + + ", producedInternal=" + producedInternal + + "} " + super.toString(); + } + + @Override + @SuppressWarnings("unchecked") + public void writeToTopology(final InternalTopologyBuilder topologyBuilder) { + final Serializer keySerializer = producedInternal.keySerde() == null ? null : producedInternal.keySerde().serializer(); + final Serializer valSerializer = producedInternal.valueSerde() == null ? null : producedInternal.valueSerde().serializer(); + final String[] parentNames = parentNodeNames(); + + final StreamPartitioner partitioner; + if (producedInternal.streamPartitioner() == null && keySerializer instanceof WindowedSerializer) { + partitioner = (StreamPartitioner) new WindowedStreamPartitioner((WindowedSerializer) keySerializer); + } else { + partitioner = producedInternal.streamPartitioner(); + } + + if (topicNameExtractor instanceof StaticTopicNameExtractor) { + final String topicName = ((StaticTopicNameExtractor) topicNameExtractor).topicName; + topologyBuilder.addSink(nodeName(), topicName, keySerializer, valSerializer, partitioner, parentNames); + } else { + topologyBuilder.addSink(nodeName(), topicNameExtractor, keySerializer, valSerializer, partitioner, parentNames); + } + } + +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/graph/StreamSourceNode.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/graph/StreamSourceNode.java new file mode 100644 index 0000000..81cd569 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/graph/StreamSourceNode.java @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.streams.kstream.internals.graph; + +import org.apache.kafka.streams.Topology.AutoOffsetReset; +import org.apache.kafka.streams.errors.TopologyException; +import org.apache.kafka.streams.kstream.internals.ConsumedInternal; +import org.apache.kafka.streams.processor.internals.InternalTopologyBuilder; + +import java.util.Collection; +import java.util.regex.Pattern; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class StreamSourceNode extends SourceGraphNode { + + private final Logger log = LoggerFactory.getLogger(StreamSourceNode.class); + + public StreamSourceNode(final String nodeName, + final Collection topicNames, + final ConsumedInternal consumedInternal) { + super(nodeName, topicNames, consumedInternal); + } + + public StreamSourceNode(final String nodeName, + final Pattern topicPattern, + final ConsumedInternal consumedInternal) { + + super(nodeName, topicPattern, consumedInternal); + } + + public void merge(final StreamSourceNode other) { + final AutoOffsetReset resetPolicy = consumedInternal().offsetResetPolicy(); + final AutoOffsetReset otherResetPolicy = other.consumedInternal().offsetResetPolicy(); + if (resetPolicy != null && !resetPolicy.equals(otherResetPolicy) + || otherResetPolicy != null && !otherResetPolicy.equals(resetPolicy)) { + log.error("Tried to merge source nodes {} and {} which are subscribed to the same topic/pattern, but " + + "the offset reset policies do not match", this, other); + throw new TopologyException("Can't configure different offset reset policies on the same input topic(s)"); + } + for (final GraphNode otherChild : other.children()) { + other.removeChild(otherChild); + addChild(otherChild); + } + } + + @Override + public String toString() { + return "StreamSourceNode{" + + "topicNames=" + (topicNames().isPresent() ? topicNames().get() : null) + + ", topicPattern=" + (topicPattern().isPresent() ? topicPattern().get() : null) + + ", consumedInternal=" + consumedInternal() + + "} " + super.toString(); + } + + @Override + public void writeToTopology(final InternalTopologyBuilder topologyBuilder) { + + if (topicPattern().isPresent()) { + topologyBuilder.addSource(consumedInternal().offsetResetPolicy(), + nodeName(), + consumedInternal().timestampExtractor(), + consumedInternal().keyDeserializer(), + consumedInternal().valueDeserializer(), + topicPattern().get()); + } else { + topologyBuilder.addSource(consumedInternal().offsetResetPolicy(), + nodeName(), + consumedInternal().timestampExtractor(), + consumedInternal().keyDeserializer(), + consumedInternal().valueDeserializer(), + topicNames().get().toArray(new String[0])); + } + } + +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/graph/StreamStreamJoinNode.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/graph/StreamStreamJoinNode.java new file mode 100644 index 0000000..48f9587 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/graph/StreamStreamJoinNode.java @@ -0,0 +1,207 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.streams.kstream.internals.graph; + +import org.apache.kafka.streams.kstream.Joined; +import org.apache.kafka.streams.kstream.ValueJoinerWithKey; +import org.apache.kafka.streams.processor.internals.InternalTopologyBuilder; +import org.apache.kafka.streams.state.KeyValueStore; +import org.apache.kafka.streams.state.StoreBuilder; +import org.apache.kafka.streams.state.WindowStore; +import org.apache.kafka.streams.state.internals.TimestampedKeyAndJoinSide; +import org.apache.kafka.streams.state.internals.LeftOrRightValue; + +import java.util.Optional; + +/** + * Too much information to generalize, so Stream-Stream joins are represented by a specific node. + */ +public class StreamStreamJoinNode extends BaseJoinProcessorNode { + private final ProcessorParameters thisWindowedStreamProcessorParameters; + private final ProcessorParameters otherWindowedStreamProcessorParameters; + private final StoreBuilder> thisWindowStoreBuilder; + private final StoreBuilder> otherWindowStoreBuilder; + private final Optional, LeftOrRightValue>>> outerJoinWindowStoreBuilder; + private final Joined joined; + private final boolean enableSpuriousResultFix; + + private StreamStreamJoinNode(final String nodeName, + final ValueJoinerWithKey valueJoiner, + final ProcessorParameters joinThisProcessorParameters, + final ProcessorParameters joinOtherProcessParameters, + final ProcessorParameters joinMergeProcessorParameters, + final ProcessorParameters thisWindowedStreamProcessorParameters, + final ProcessorParameters otherWindowedStreamProcessorParameters, + final StoreBuilder> thisWindowStoreBuilder, + final StoreBuilder> otherWindowStoreBuilder, + final Optional, LeftOrRightValue>>> outerJoinWindowStoreBuilder, + final Joined joined, + final boolean enableSpuriousResultFix) { + + super(nodeName, + valueJoiner, + joinThisProcessorParameters, + joinOtherProcessParameters, + joinMergeProcessorParameters, + null, + null); + + this.thisWindowStoreBuilder = thisWindowStoreBuilder; + this.otherWindowStoreBuilder = otherWindowStoreBuilder; + this.joined = joined; + this.thisWindowedStreamProcessorParameters = thisWindowedStreamProcessorParameters; + this.otherWindowedStreamProcessorParameters = otherWindowedStreamProcessorParameters; + this.outerJoinWindowStoreBuilder = outerJoinWindowStoreBuilder; + this.enableSpuriousResultFix = enableSpuriousResultFix; + } + + + @Override + public String toString() { + return "StreamStreamJoinNode{" + + "thisWindowedStreamProcessorParameters=" + thisWindowedStreamProcessorParameters + + ", otherWindowedStreamProcessorParameters=" + otherWindowedStreamProcessorParameters + + ", thisWindowStoreBuilder=" + thisWindowStoreBuilder + + ", otherWindowStoreBuilder=" + otherWindowStoreBuilder + + ", outerJoinWindowStoreBuilder=" + outerJoinWindowStoreBuilder + + ", joined=" + joined + + "} " + super.toString(); + } + + @SuppressWarnings("unchecked") + @Override + public void writeToTopology(final InternalTopologyBuilder topologyBuilder) { + + final String thisProcessorName = thisProcessorParameters().processorName(); + final String otherProcessorName = 
otherProcessorParameters().processorName(); + final String thisWindowedStreamProcessorName = thisWindowedStreamProcessorParameters.processorName(); + final String otherWindowedStreamProcessorName = otherWindowedStreamProcessorParameters.processorName(); + + topologyBuilder.addProcessor(thisProcessorName, thisProcessorParameters().processorSupplier(), thisWindowedStreamProcessorName); + topologyBuilder.addProcessor(otherProcessorName, otherProcessorParameters().processorSupplier(), otherWindowedStreamProcessorName); + topologyBuilder.addProcessor(mergeProcessorParameters().processorName(), mergeProcessorParameters().processorSupplier(), thisProcessorName, otherProcessorName); + topologyBuilder.addStateStore(thisWindowStoreBuilder, thisWindowedStreamProcessorName, otherProcessorName); + topologyBuilder.addStateStore(otherWindowStoreBuilder, otherWindowedStreamProcessorName, thisProcessorName); + + if (enableSpuriousResultFix) { + outerJoinWindowStoreBuilder.ifPresent(builder -> topologyBuilder.addStateStore(builder, thisProcessorName, otherProcessorName)); + } + } + + public static StreamStreamJoinNodeBuilder streamStreamJoinNodeBuilder() { + return new StreamStreamJoinNodeBuilder<>(); + } + + public static final class StreamStreamJoinNodeBuilder { + + private String nodeName; + private ValueJoinerWithKey valueJoiner; + private ProcessorParameters joinThisProcessorParameters; + private ProcessorParameters joinOtherProcessorParameters; + private ProcessorParameters joinMergeProcessorParameters; + private ProcessorParameters thisWindowedStreamProcessorParameters; + private ProcessorParameters otherWindowedStreamProcessorParameters; + private StoreBuilder> thisWindowStoreBuilder; + private StoreBuilder> otherWindowStoreBuilder; + private Optional, LeftOrRightValue>>> outerJoinWindowStoreBuilder; + private Joined joined; + private boolean enableSpuriousResultFix = false; + + private StreamStreamJoinNodeBuilder() { + } + + public StreamStreamJoinNodeBuilder withValueJoiner(final ValueJoinerWithKey valueJoiner) { + this.valueJoiner = valueJoiner; + return this; + } + + public StreamStreamJoinNodeBuilder withJoinThisProcessorParameters(final ProcessorParameters joinThisProcessorParameters) { + this.joinThisProcessorParameters = joinThisProcessorParameters; + return this; + } + + public StreamStreamJoinNodeBuilder withNodeName(final String nodeName) { + this.nodeName = nodeName; + return this; + } + + public StreamStreamJoinNodeBuilder withJoinOtherProcessorParameters(final ProcessorParameters joinOtherProcessParameters) { + this.joinOtherProcessorParameters = joinOtherProcessParameters; + return this; + } + + public StreamStreamJoinNodeBuilder withJoinMergeProcessorParameters(final ProcessorParameters joinMergeProcessorParameters) { + this.joinMergeProcessorParameters = joinMergeProcessorParameters; + return this; + } + + public StreamStreamJoinNodeBuilder withThisWindowedStreamProcessorParameters(final ProcessorParameters thisWindowedStreamProcessorParameters) { + this.thisWindowedStreamProcessorParameters = thisWindowedStreamProcessorParameters; + return this; + } + + public StreamStreamJoinNodeBuilder withOtherWindowedStreamProcessorParameters( + final ProcessorParameters otherWindowedStreamProcessorParameters) { + this.otherWindowedStreamProcessorParameters = otherWindowedStreamProcessorParameters; + return this; + } + + public StreamStreamJoinNodeBuilder withThisWindowStoreBuilder(final StoreBuilder> thisWindowStoreBuilder) { + this.thisWindowStoreBuilder = thisWindowStoreBuilder; + return 
this; + } + + public StreamStreamJoinNodeBuilder withOtherWindowStoreBuilder(final StoreBuilder> otherWindowStoreBuilder) { + this.otherWindowStoreBuilder = otherWindowStoreBuilder; + return this; + } + + public StreamStreamJoinNodeBuilder withOuterJoinWindowStoreBuilder(final Optional, LeftOrRightValue>>> outerJoinWindowStoreBuilder) { + this.outerJoinWindowStoreBuilder = outerJoinWindowStoreBuilder; + return this; + } + + public StreamStreamJoinNodeBuilder withJoined(final Joined joined) { + this.joined = joined; + return this; + } + + public StreamStreamJoinNodeBuilder withSpuriousResultFixEnabled() { + this.enableSpuriousResultFix = true; + return this; + } + + public StreamStreamJoinNode build() { + + return new StreamStreamJoinNode<>(nodeName, + valueJoiner, + joinThisProcessorParameters, + joinOtherProcessorParameters, + joinMergeProcessorParameters, + thisWindowedStreamProcessorParameters, + otherWindowedStreamProcessorParameters, + thisWindowStoreBuilder, + otherWindowStoreBuilder, + outerJoinWindowStoreBuilder, + joined, + enableSpuriousResultFix); + + + } + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/graph/StreamTableJoinNode.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/graph/StreamTableJoinNode.java new file mode 100644 index 0000000..a4db1ba --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/graph/StreamTableJoinNode.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.kafka.streams.kstream.internals.graph;
+
+import org.apache.kafka.streams.processor.api.ProcessorSupplier;
+import org.apache.kafka.streams.processor.internals.InternalTopologyBuilder;
+
+import java.util.Arrays;
+
+/**
+ * Represents a join between a KStream and a KTable or GlobalKTable
+ */
+
+public class StreamTableJoinNode<K, V> extends GraphNode {
+
+    private final String[] storeNames;
+    private final ProcessorParameters<K, V, ?, ?> processorParameters;
+    private final String otherJoinSideNodeName;
+
+    public StreamTableJoinNode(final String nodeName,
+                               final ProcessorParameters<K, V, ?, ?> processorParameters,
+                               final String[] storeNames,
+                               final String otherJoinSideNodeName) {
+        super(nodeName);
+
+        // in the case of a Stream-Table join, these are the state stores associated with the KTable
+        this.storeNames = storeNames;
+        this.processorParameters = processorParameters;
+        this.otherJoinSideNodeName = otherJoinSideNodeName;
+    }
+
+    @Override
+    public String toString() {
+        return "StreamTableJoinNode{" +
+               "storeNames=" + Arrays.toString(storeNames) +
+               ", processorParameters=" + processorParameters +
+               ", otherJoinSideNodeName='" + otherJoinSideNodeName + '\'' +
+               "} " + super.toString();
+    }
+
+    @Override
+    public void writeToTopology(final InternalTopologyBuilder topologyBuilder) {
+        final String processorName = processorParameters.processorName();
+        final ProcessorSupplier<K, V, ?, ?> processorSupplier = processorParameters.processorSupplier();
+
+        // Stream - Table join (Global or KTable)
+        topologyBuilder.addProcessor(processorName, processorSupplier, parentNodeNames());
+
+        // Stream - KTable join only
+        if (otherJoinSideNodeName != null) {
+            topologyBuilder.connectProcessorAndStateStores(processorName, storeNames);
+        }
+
+    }
+}
diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/graph/StreamToTableNode.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/graph/StreamToTableNode.java
new file mode 100644
index 0000000..3b0a572
--- /dev/null
+++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/graph/StreamToTableNode.java
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kafka.streams.kstream.internals.graph;
+
+import org.apache.kafka.common.utils.Bytes;
+import org.apache.kafka.streams.kstream.internals.KTableSource;
+import org.apache.kafka.streams.kstream.internals.MaterializedInternal;
+import org.apache.kafka.streams.kstream.internals.TimestampedKeyValueStoreMaterializer;
+import org.apache.kafka.streams.processor.internals.InternalTopologyBuilder;
+import org.apache.kafka.streams.state.KeyValueStore;
+import org.apache.kafka.streams.state.StoreBuilder;
+import org.apache.kafka.streams.state.TimestampedKeyValueStore;
+
+/**
+ * Represents a KTable that is converted from a KStream
+ */
+public class StreamToTableNode<K, V> extends GraphNode {
+
+    private final ProcessorParameters<K, V, ?, ?> processorParameters;
+    private final MaterializedInternal<K, V, ?> materializedInternal;
+
+    public StreamToTableNode(final String nodeName,
+                             final ProcessorParameters<K, V, ?, ?> processorParameters,
+                             final MaterializedInternal<K, V, ?> materializedInternal) {
+        super(nodeName);
+        this.processorParameters = processorParameters;
+        this.materializedInternal = materializedInternal;
+    }
+
+    @Override
+    public String toString() {
+        return "StreamToTableNode{" +
+            "processorParameters=" + processorParameters +
+            ", materializedInternal=" + materializedInternal +
+            "} " + super.toString();
+    }
+
+    @SuppressWarnings("unchecked")
+    @Override
+    public void writeToTopology(final InternalTopologyBuilder topologyBuilder) {
+        final StoreBuilder<TimestampedKeyValueStore<K, V>> storeBuilder =
+            new TimestampedKeyValueStoreMaterializer<>((MaterializedInternal<K, V, KeyValueStore<Bytes, byte[]>>) materializedInternal).materialize();
+
+        final String processorName = processorParameters.processorName();
+        final KTableSource<K, V> ktableSource = processorParameters.kTableSourceSupplier();
+        topologyBuilder.addProcessor(processorName, processorParameters.processorSupplier(), parentNodeNames());
+
+        if (storeBuilder != null && ktableSource.materialized()) {
+            topologyBuilder.addStateStore(storeBuilder, processorName);
+        }
+    }
+}
diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/graph/TableProcessorNode.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/graph/TableProcessorNode.java
new file mode 100644
index 0000000..ff90a20
--- /dev/null
+++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/graph/TableProcessorNode.java
@@ -0,0 +1,77 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +package org.apache.kafka.streams.kstream.internals.graph; + +import org.apache.kafka.streams.processor.internals.InternalTopologyBuilder; +import org.apache.kafka.streams.state.StoreBuilder; +import org.apache.kafka.streams.state.TimestampedKeyValueStore; + +import java.util.Arrays; +import java.util.Objects; + +public class TableProcessorNode extends GraphNode { + + private final ProcessorParameters processorParameters; + private final StoreBuilder> storeBuilder; + private final String[] storeNames; + + public TableProcessorNode(final String nodeName, + final ProcessorParameters processorParameters, + final StoreBuilder> storeBuilder) { + this(nodeName, processorParameters, storeBuilder, null); + } + + public TableProcessorNode(final String nodeName, + final ProcessorParameters processorParameters, + // TODO KIP-300: we are enforcing this as a keyvalue store, but it should go beyond any type of stores + final StoreBuilder> storeBuilder, + final String[] storeNames) { + super(nodeName); + this.processorParameters = processorParameters; + this.storeBuilder = storeBuilder; + this.storeNames = storeNames != null ? storeNames : new String[] {}; + } + + @Override + public String toString() { + return "TableProcessorNode{" + + ", processorParameters=" + processorParameters + + ", storeBuilder=" + (storeBuilder == null ? "null" : storeBuilder.name()) + + ", storeNames=" + Arrays.toString(storeNames) + + "} " + super.toString(); + } + + @Override + public void writeToTopology(final InternalTopologyBuilder topologyBuilder) { + final String processorName = processorParameters.processorName(); + topologyBuilder.addProcessor(processorName, processorParameters.processorSupplier(), parentNodeNames()); + + if (storeNames.length > 0) { + topologyBuilder.connectProcessorAndStateStores(processorName, storeNames); + } + + if (processorParameters.kTableSourceSupplier() != null) { + if (processorParameters.kTableSourceSupplier().materialized()) { + topologyBuilder.addStateStore(Objects.requireNonNull(storeBuilder, "storeBuilder was null"), + processorName); + } + } else if (storeBuilder != null) { + topologyBuilder.addStateStore(storeBuilder, processorName); + } + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/graph/TableSourceNode.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/graph/TableSourceNode.java new file mode 100644 index 0000000..3b35673 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/graph/TableSourceNode.java @@ -0,0 +1,195 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.streams.kstream.internals.graph; + +import java.util.Iterator; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.streams.kstream.internals.ConsumedInternal; +import org.apache.kafka.streams.kstream.internals.KTableSource; +import org.apache.kafka.streams.kstream.internals.MaterializedInternal; +import org.apache.kafka.streams.kstream.internals.TimestampedKeyValueStoreMaterializer; +import org.apache.kafka.streams.processor.api.ProcessorSupplier; +import org.apache.kafka.streams.processor.internals.InternalTopologyBuilder; +import org.apache.kafka.streams.state.KeyValueStore; +import org.apache.kafka.streams.state.StoreBuilder; +import org.apache.kafka.streams.state.TimestampedKeyValueStore; + +import java.util.Collections; + +/** + * Used to represent either a KTable source or a GlobalKTable source. A boolean flag is used to indicate if this represents a GlobalKTable a {@link + * org.apache.kafka.streams.kstream.GlobalKTable} + */ +public class TableSourceNode extends SourceGraphNode { + + private final MaterializedInternal materializedInternal; + private final ProcessorParameters processorParameters; + private final String sourceName; + private final boolean isGlobalKTable; + private boolean shouldReuseSourceTopicForChangelog = false; + + private TableSourceNode(final String nodeName, + final String sourceName, + final String topic, + final ConsumedInternal consumedInternal, + final MaterializedInternal materializedInternal, + final ProcessorParameters processorParameters, + final boolean isGlobalKTable) { + + super(nodeName, + Collections.singletonList(topic), + consumedInternal); + + this.sourceName = sourceName; + this.isGlobalKTable = isGlobalKTable; + this.processorParameters = processorParameters; + this.materializedInternal = materializedInternal; + } + + + public void reuseSourceTopicForChangeLog(final boolean shouldReuseSourceTopicForChangelog) { + this.shouldReuseSourceTopicForChangelog = shouldReuseSourceTopicForChangelog; + } + + @Override + public String toString() { + return "TableSourceNode{" + + "materializedInternal=" + materializedInternal + + ", processorParameters=" + processorParameters + + ", sourceName='" + sourceName + '\'' + + ", isGlobalKTable=" + isGlobalKTable + + "} " + super.toString(); + } + + public static TableSourceNodeBuilder tableSourceNodeBuilder() { + return new TableSourceNodeBuilder<>(); + } + + @Override + @SuppressWarnings("unchecked") + public void writeToTopology(final InternalTopologyBuilder topologyBuilder) { + final String topicName; + if (topicNames().isPresent()) { + final Iterator topicNames = topicNames().get().iterator(); + topicName = topicNames.next(); + if (topicNames.hasNext()) { + throw new IllegalStateException("A table source node must have a single topic as input"); + } + } else { + throw new IllegalStateException("A table source node must have a single topic as input"); + } + + // TODO: we assume source KTables can only be timestamped-key-value stores for now. + // should be expanded for other types of stores as well. 
+ final StoreBuilder> storeBuilder = + new TimestampedKeyValueStoreMaterializer<>((MaterializedInternal>) materializedInternal).materialize(); + + if (isGlobalKTable) { + topologyBuilder.addGlobalStore( + storeBuilder, + sourceName, + consumedInternal().timestampExtractor(), + consumedInternal().keyDeserializer(), + consumedInternal().valueDeserializer(), + topicName, + processorParameters.processorName(), + (ProcessorSupplier) processorParameters.processorSupplier() + ); + } else { + topologyBuilder.addSource(consumedInternal().offsetResetPolicy(), + sourceName, + consumedInternal().timestampExtractor(), + consumedInternal().keyDeserializer(), + consumedInternal().valueDeserializer(), + topicName); + + topologyBuilder.addProcessor(processorParameters.processorName(), processorParameters.processorSupplier(), sourceName); + + // only add state store if the source KTable should be materialized + final KTableSource ktableSource = processorParameters.kTableSourceSupplier(); + if (ktableSource.materialized()) { + topologyBuilder.addStateStore(storeBuilder, nodeName()); + + if (shouldReuseSourceTopicForChangelog) { + storeBuilder.withLoggingDisabled(); + topologyBuilder.connectSourceStoreAndTopic(storeBuilder.name(), topicName); + } + } + } + + } + + public static final class TableSourceNodeBuilder { + + private String nodeName; + private String sourceName; + private String topic; + private ConsumedInternal consumedInternal; + private MaterializedInternal materializedInternal; + private ProcessorParameters processorParameters; + private boolean isGlobalKTable = false; + + private TableSourceNodeBuilder() { + } + + public TableSourceNodeBuilder withSourceName(final String sourceName) { + this.sourceName = sourceName; + return this; + } + + public TableSourceNodeBuilder withTopic(final String topic) { + this.topic = topic; + return this; + } + + public TableSourceNodeBuilder withMaterializedInternal(final MaterializedInternal materializedInternal) { + this.materializedInternal = materializedInternal; + return this; + } + + public TableSourceNodeBuilder withConsumedInternal(final ConsumedInternal consumedInternal) { + this.consumedInternal = consumedInternal; + return this; + } + + public TableSourceNodeBuilder withProcessorParameters(final ProcessorParameters processorParameters) { + this.processorParameters = processorParameters; + return this; + } + + public TableSourceNodeBuilder withNodeName(final String nodeName) { + this.nodeName = nodeName; + return this; + } + + public TableSourceNodeBuilder isGlobalKTable(final boolean isGlobaKTable) { + this.isGlobalKTable = isGlobaKTable; + return this; + } + + public TableSourceNode build() { + return new TableSourceNode<>(nodeName, + sourceName, + topic, + consumedInternal, + materializedInternal, + processorParameters, + isGlobalKTable); + } + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/graph/UnoptimizableRepartitionNode.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/graph/UnoptimizableRepartitionNode.java new file mode 100644 index 0000000..daac9bd --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/graph/UnoptimizableRepartitionNode.java @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.kstream.internals.graph; + +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.streams.processor.FailOnInvalidTimestamp; +import org.apache.kafka.streams.processor.StreamPartitioner; +import org.apache.kafka.streams.processor.internals.InternalTopicProperties; +import org.apache.kafka.streams.processor.internals.InternalTopologyBuilder; + +/** + * Repartition node that is not subject of optimization algorithm + */ +public class UnoptimizableRepartitionNode extends BaseRepartitionNode { + + private UnoptimizableRepartitionNode(final String nodeName, + final String sourceName, + final ProcessorParameters processorParameters, + final Serde keySerde, + final Serde valueSerde, + final String sinkName, + final String repartitionTopic, + final StreamPartitioner partitioner, + final InternalTopicProperties internalTopicProperties) { + super( + nodeName, + sourceName, + processorParameters, + keySerde, + valueSerde, + sinkName, + repartitionTopic, + partitioner, + internalTopicProperties + ); + } + + @Override + public void writeToTopology(final InternalTopologyBuilder topologyBuilder) { + topologyBuilder.addInternalTopic(repartitionTopic, internalTopicProperties); + + topologyBuilder.addProcessor( + processorParameters.processorName(), + processorParameters.processorSupplier(), + parentNodeNames() + ); + + topologyBuilder.addSink( + sinkName, + repartitionTopic, + keySerializer(), + valueSerializer(), + partitioner, + processorParameters.processorName() + ); + + topologyBuilder.addSource( + null, + sourceName, + new FailOnInvalidTimestamp(), + keyDeserializer(), + valueDeserializer(), + repartitionTopic + ); + } + + @Override + public String toString() { + return "UnoptimizableRepartitionNode{" + super.toString() + " }"; + } + + public static UnoptimizableRepartitionNodeBuilder unoptimizableRepartitionNodeBuilder() { + return new UnoptimizableRepartitionNodeBuilder<>(); + } + + public static final class UnoptimizableRepartitionNodeBuilder extends BaseRepartitionNodeBuilder> { + + @Override + public UnoptimizableRepartitionNode build() { + return new UnoptimizableRepartitionNode<>(nodeName, + sourceName, + processorParameters, + keySerde, + valueSerde, + sinkName, + repartitionTopic, + partitioner, + internalTopicProperties); + } + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/suppress/BufferConfigInternal.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/suppress/BufferConfigInternal.java new file mode 100644 index 0000000..800a2a5 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/suppress/BufferConfigInternal.java @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.streams.kstream.internals.suppress;
+
+import org.apache.kafka.streams.kstream.Suppressed;
+
+import java.util.Map;
+
+import static org.apache.kafka.streams.kstream.internals.suppress.BufferFullStrategy.SHUT_DOWN;
+
+public abstract class BufferConfigInternal<BC extends Suppressed.BufferConfig<BC>> implements Suppressed.BufferConfig<BC> {
+    public abstract long maxRecords();
+
+    public abstract long maxBytes();
+
+    @SuppressWarnings("unused")
+    public abstract BufferFullStrategy bufferFullStrategy();
+
+    @Override
+    public Suppressed.StrictBufferConfig withNoBound() {
+        return new StrictBufferConfigImpl(
+            Long.MAX_VALUE,
+            Long.MAX_VALUE,
+            SHUT_DOWN, // doesn't matter, given the bounds
+            getLogConfig()
+        );
+    }
+
+    @Override
+    public Suppressed.StrictBufferConfig shutDownWhenFull() {
+        return new StrictBufferConfigImpl(maxRecords(), maxBytes(), SHUT_DOWN, getLogConfig());
+    }
+
+    @Override
+    public Suppressed.EagerBufferConfig emitEarlyWhenFull() {
+        return new EagerBufferConfigImpl(maxRecords(), maxBytes(), getLogConfig());
+    }
+
+    public abstract boolean isLoggingEnabled();
+
+    public abstract Map<String, String> getLogConfig();
+}
diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/suppress/BufferFullStrategy.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/suppress/BufferFullStrategy.java
new file mode 100644
index 0000000..870a3d1
--- /dev/null
+++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/suppress/BufferFullStrategy.java
@@ -0,0 +1,22 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.streams.kstream.internals.suppress;
+
+public enum BufferFullStrategy {
+    EMIT,
+    SHUT_DOWN
+}
diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/suppress/EagerBufferConfigImpl.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/suppress/EagerBufferConfigImpl.java
new file mode 100644
index 0000000..7665e66
--- /dev/null
+++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/suppress/EagerBufferConfigImpl.java
@@ -0,0 +1,110 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.kstream.internals.suppress; + +import org.apache.kafka.streams.kstream.Suppressed; + +import java.util.Collections; +import java.util.Map; +import java.util.Objects; + +public class EagerBufferConfigImpl extends BufferConfigInternal implements Suppressed.EagerBufferConfig { + + private final long maxRecords; + private final long maxBytes; + private final Map logConfig; + + public EagerBufferConfigImpl(final long maxRecords, + final long maxBytes, + final Map logConfig) { + this.maxRecords = maxRecords; + this.maxBytes = maxBytes; + this.logConfig = logConfig; + } + + @Override + public Suppressed.EagerBufferConfig withMaxRecords(final long recordLimit) { + return new EagerBufferConfigImpl(recordLimit, maxBytes, logConfig); + } + + @Override + public Suppressed.EagerBufferConfig withMaxBytes(final long byteLimit) { + return new EagerBufferConfigImpl(maxRecords, byteLimit, logConfig); + } + + @Override + public long maxRecords() { + return maxRecords; + } + + @Override + public long maxBytes() { + return maxBytes; + } + + @Override + public BufferFullStrategy bufferFullStrategy() { + return BufferFullStrategy.EMIT; + } + + @Override + public Suppressed.EagerBufferConfig withLoggingDisabled() { + return new EagerBufferConfigImpl(maxRecords, maxBytes, null); + } + + @Override + public Suppressed.EagerBufferConfig withLoggingEnabled(final Map config) { + return new EagerBufferConfigImpl(maxRecords, maxBytes, config); + } + + @Override + public boolean isLoggingEnabled() { + return logConfig != null; + } + + @Override + public Map getLogConfig() { + return isLoggingEnabled() ? 
logConfig : Collections.emptyMap(); + } + + @Override + public boolean equals(final Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + final EagerBufferConfigImpl that = (EagerBufferConfigImpl) o; + return maxRecords == that.maxRecords && + maxBytes == that.maxBytes && + Objects.equals(getLogConfig(), that.getLogConfig()); + } + + @Override + public int hashCode() { + return Objects.hash(maxRecords, maxBytes, getLogConfig()); + } + + @Override + public String toString() { + return "EagerBufferConfigImpl{maxRecords=" + maxRecords + + ", maxBytes=" + maxBytes + + ", logConfig=" + getLogConfig() + + "}"; + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/suppress/FinalResultsSuppressionBuilder.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/suppress/FinalResultsSuppressionBuilder.java new file mode 100644 index 0000000..e917556 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/suppress/FinalResultsSuppressionBuilder.java @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.kstream.internals.suppress; + +import org.apache.kafka.streams.kstream.Suppressed; +import org.apache.kafka.streams.kstream.Windowed; + +import java.time.Duration; +import java.util.Objects; + +public class FinalResultsSuppressionBuilder implements Suppressed, NamedSuppressed { + private final String name; + private final StrictBufferConfig bufferConfig; + + public FinalResultsSuppressionBuilder(final String name, final Suppressed.StrictBufferConfig bufferConfig) { + this.name = name; + this.bufferConfig = bufferConfig; + } + + public SuppressedInternal buildFinalResultsSuppression(final Duration gracePeriod) { + return new SuppressedInternal<>( + name, + gracePeriod, + bufferConfig, + TimeDefinitions.WindowEndTimeDefinition.instance(), + true + ); + } + + @Override + public Suppressed withName(final String name) { + return new FinalResultsSuppressionBuilder<>(name, bufferConfig); + } + + @Override + public boolean equals(final Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + final FinalResultsSuppressionBuilder that = (FinalResultsSuppressionBuilder) o; + return Objects.equals(name, that.name) && + Objects.equals(bufferConfig, that.bufferConfig); + } + + @Override + public String name() { + return name; + } + + @Override + public int hashCode() { + return Objects.hash(name, bufferConfig); + } + + @Override + public String toString() { + return "FinalResultsSuppressionBuilder{" + + "name='" + name + '\'' + + ", bufferConfig=" + bufferConfig + + '}'; + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/suppress/KTableSuppressProcessorSupplier.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/suppress/KTableSuppressProcessorSupplier.java new file mode 100644 index 0000000..ef7943b --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/suppress/KTableSuppressProcessorSupplier.java @@ -0,0 +1,218 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.kstream.internals.suppress; + +import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.common.metrics.Sensor; +import org.apache.kafka.streams.errors.StreamsException; +import org.apache.kafka.streams.kstream.internals.Change; +import org.apache.kafka.streams.kstream.internals.KTableImpl; +import org.apache.kafka.streams.kstream.internals.KTableNewProcessorSupplier; +import org.apache.kafka.streams.kstream.internals.KTableValueGetter; +import org.apache.kafka.streams.kstream.internals.KTableValueGetterSupplier; +import org.apache.kafka.streams.kstream.internals.suppress.TimeDefinitions.TimeDefinition; +import org.apache.kafka.streams.processor.api.ContextualProcessor; +import org.apache.kafka.streams.processor.api.Processor; +import org.apache.kafka.streams.processor.api.ProcessorContext; +import org.apache.kafka.streams.processor.api.Record; +import org.apache.kafka.streams.processor.internals.InternalProcessorContext; +import org.apache.kafka.streams.processor.internals.ProcessorRecordContext; +import org.apache.kafka.streams.processor.internals.SerdeGetter; +import org.apache.kafka.streams.processor.internals.metrics.ProcessorNodeMetrics; +import org.apache.kafka.streams.state.ValueAndTimestamp; +import org.apache.kafka.streams.state.internals.Maybe; +import org.apache.kafka.streams.state.internals.TimeOrderedKeyValueBuffer; + +import static java.util.Objects.requireNonNull; + +public class KTableSuppressProcessorSupplier implements KTableNewProcessorSupplier { + private final SuppressedInternal suppress; + private final String storeName; + private final KTableImpl parentKTable; + + public KTableSuppressProcessorSupplier(final SuppressedInternal suppress, + final String storeName, + final KTableImpl parentKTable) { + this.suppress = suppress; + this.storeName = storeName; + this.parentKTable = parentKTable; + // The suppress buffer requires seeing the old values, to support the prior value view. 
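+ // Sending old values also lets the buffer remember each key's prior value, which is what
+ // priorValueForBuffered() returns in the value getter built by view() below.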
+ parentKTable.enableSendingOldValues(true); + } + + @Override + public Processor, K, Change> get() { + return new KTableSuppressProcessor<>(suppress, storeName); + } + + @Override + public KTableValueGetterSupplier view() { + final KTableValueGetterSupplier parentValueGetterSupplier = parentKTable.valueGetterSupplier(); + return new KTableValueGetterSupplier() { + + @Override + public KTableValueGetter get() { + final KTableValueGetter parentGetter = parentValueGetterSupplier.get(); + return new KTableValueGetter() { + private TimeOrderedKeyValueBuffer buffer; + + @Override + public void init(final org.apache.kafka.streams.processor.ProcessorContext context) { + parentGetter.init(context); + // the main processor is responsible for the buffer's lifecycle + buffer = requireNonNull(context.getStateStore(storeName)); + } + + @Override + public ValueAndTimestamp get(final K key) { + final Maybe> maybeValue = buffer.priorValueForBuffered(key); + if (maybeValue.isDefined()) { + return maybeValue.getNullableValue(); + } else { + // not buffered, so the suppressed view is equal to the parent view + return parentGetter.get(key); + } + } + + @Override + public void close() { + // the main processor is responsible for the buffer's lifecycle + parentGetter.close(); + } + }; + } + + @Override + public String[] storeNames() { + final String[] parentStores = parentValueGetterSupplier.storeNames(); + final String[] stores = new String[1 + parentStores.length]; + System.arraycopy(parentStores, 0, stores, 1, parentStores.length); + stores[0] = storeName; + return stores; + } + }; + } + + @Override + public boolean enableSendingOldValues(final boolean forceMaterialization) { + return parentKTable.enableSendingOldValues(forceMaterialization); + } + + private static final class KTableSuppressProcessor extends ContextualProcessor, K, Change> { + private final long maxRecords; + private final long maxBytes; + private final long suppressDurationMillis; + private final TimeDefinition bufferTimeDefinition; + private final BufferFullStrategy bufferFullStrategy; + private final boolean safeToDropTombstones; + private final String storeName; + + private TimeOrderedKeyValueBuffer buffer; + private InternalProcessorContext> internalProcessorContext; + private Sensor suppressionEmitSensor; + private long observedStreamTime = ConsumerRecord.NO_TIMESTAMP; + + private KTableSuppressProcessor(final SuppressedInternal suppress, final String storeName) { + this.storeName = storeName; + requireNonNull(suppress); + maxRecords = suppress.bufferConfig().maxRecords(); + maxBytes = suppress.bufferConfig().maxBytes(); + suppressDurationMillis = suppress.timeToWaitForMoreEvents().toMillis(); + bufferTimeDefinition = suppress.timeDefinition(); + bufferFullStrategy = suppress.bufferConfig().bufferFullStrategy(); + safeToDropTombstones = suppress.safeToDropTombstones(); + } + + @Override + public void init(final ProcessorContext> context) { + super.init(context); + internalProcessorContext = (InternalProcessorContext>) context; + suppressionEmitSensor = ProcessorNodeMetrics.suppressionEmitSensor( + Thread.currentThread().getName(), + context.taskId().toString(), + internalProcessorContext.currentNode().name(), + internalProcessorContext.metrics() + ); + + buffer = requireNonNull(context.getStateStore(storeName)); + buffer.setSerdesIfNull(new SerdeGetter(context)); + } + + @Override + public void process(final Record> record) { + observedStreamTime = Math.max(observedStreamTime, record.timestamp()); + buffer(record); + 
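// enforceConstraints() then evicts and emits every buffered record whose time is at or before
+ // observed stream time minus the suppression interval (tombstones may be dropped in final-results
+ // mode), and applies the buffer-full strategy (EMIT early or SHUT_DOWN) if capacity is exceeded.
+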
enforceConstraints(); + } + + private void buffer(final Record> record) { + final long bufferTime = bufferTimeDefinition.time(internalProcessorContext, record.key()); + + buffer.put(bufferTime, record, internalProcessorContext.recordContext()); + } + + private void enforceConstraints() { + final long streamTime = observedStreamTime; + final long expiryTime = streamTime - suppressDurationMillis; + + buffer.evictWhile(() -> buffer.minTimestamp() <= expiryTime, this::emit); + + if (overCapacity()) { + switch (bufferFullStrategy) { + case EMIT: + buffer.evictWhile(this::overCapacity, this::emit); + return; + case SHUT_DOWN: + throw new StreamsException(String.format( + "%s buffer exceeded its max capacity. Currently [%d/%d] records and [%d/%d] bytes.", + internalProcessorContext.currentNode().name(), + buffer.numRecords(), maxRecords, + buffer.bufferSize(), maxBytes + )); + default: + throw new UnsupportedOperationException( + "The bufferFullStrategy [" + bufferFullStrategy + + "] is not implemented. This is a bug in Kafka Streams." + ); + } + } + } + + private boolean overCapacity() { + return buffer.numRecords() > maxRecords || buffer.bufferSize() > maxBytes; + } + + private void emit(final TimeOrderedKeyValueBuffer.Eviction toEmit) { + if (shouldForward(toEmit.value())) { + final ProcessorRecordContext prevRecordContext = internalProcessorContext.recordContext(); + internalProcessorContext.setRecordContext(toEmit.recordContext()); + try { + internalProcessorContext.forward(toEmit.record() + .withTimestamp(toEmit.recordContext().timestamp()) + .withHeaders(toEmit.recordContext().headers())); + suppressionEmitSensor.record(1.0d, internalProcessorContext.currentSystemTimeMs()); + } finally { + internalProcessorContext.setRecordContext(prevRecordContext); + } + } + } + + private boolean shouldForward(final Change value) { + return value.newValue != null || !safeToDropTombstones; + } + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/suppress/NamedSuppressed.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/suppress/NamedSuppressed.java new file mode 100644 index 0000000..78f6bd6 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/suppress/NamedSuppressed.java @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.kstream.internals.suppress; + +import org.apache.kafka.streams.kstream.Suppressed; + +/** + * Internally-facing interface to work around the fact that all Suppressed config objects + * are name-able, but do not present a getter (for consistency with other config objects). + * If we allow getters on config objects in the future, we can delete this interface. 
+ */ +public interface NamedSuppressed extends Suppressed { + String name(); +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/suppress/StrictBufferConfigImpl.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/suppress/StrictBufferConfigImpl.java new file mode 100644 index 0000000..2ca5ef9 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/suppress/StrictBufferConfigImpl.java @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.kstream.internals.suppress; + +import org.apache.kafka.streams.kstream.Suppressed; + +import java.util.Collections; +import java.util.Map; +import java.util.Objects; + +import static org.apache.kafka.streams.kstream.internals.suppress.BufferFullStrategy.SHUT_DOWN; + +public class StrictBufferConfigImpl extends BufferConfigInternal implements Suppressed.StrictBufferConfig { + + private final long maxRecords; + private final long maxBytes; + private final BufferFullStrategy bufferFullStrategy; + private final Map logConfig; + + public StrictBufferConfigImpl(final long maxRecords, + final long maxBytes, + final BufferFullStrategy bufferFullStrategy, + final Map logConfig) { + this.maxRecords = maxRecords; + this.maxBytes = maxBytes; + this.bufferFullStrategy = bufferFullStrategy; + this.logConfig = logConfig; + } + + + public StrictBufferConfigImpl() { + this.maxRecords = Long.MAX_VALUE; + this.maxBytes = Long.MAX_VALUE; + this.bufferFullStrategy = SHUT_DOWN; + this.logConfig = Collections.emptyMap(); + } + + @Override + public Suppressed.StrictBufferConfig withMaxRecords(final long recordLimit) { + return new StrictBufferConfigImpl(recordLimit, maxBytes, bufferFullStrategy, getLogConfig()); + } + + @Override + public Suppressed.StrictBufferConfig withMaxBytes(final long byteLimit) { + return new StrictBufferConfigImpl(maxRecords, byteLimit, bufferFullStrategy, getLogConfig()); + } + + @Override + public long maxRecords() { + return maxRecords; + } + + @Override + public long maxBytes() { + return maxBytes; + } + + @Override + public BufferFullStrategy bufferFullStrategy() { + return bufferFullStrategy; + } + + @Override + public Suppressed.StrictBufferConfig withLoggingDisabled() { + return new StrictBufferConfigImpl(maxRecords, maxBytes, bufferFullStrategy, null); + } + + @Override + public Suppressed.StrictBufferConfig withLoggingEnabled(final Map config) { + return new StrictBufferConfigImpl(maxRecords, maxBytes, bufferFullStrategy, config); + } + + @Override + public boolean isLoggingEnabled() { + return logConfig != null; + } + + @Override + public Map getLogConfig() { + return isLoggingEnabled() ? 
logConfig : Collections.emptyMap(); + } + + @Override + public boolean equals(final Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + final StrictBufferConfigImpl that = (StrictBufferConfigImpl) o; + return maxRecords == that.maxRecords && + maxBytes == that.maxBytes && + bufferFullStrategy == that.bufferFullStrategy && + Objects.equals(getLogConfig(), ((StrictBufferConfigImpl) o).getLogConfig()); + } + + @Override + public int hashCode() { + return Objects.hash(maxRecords, maxBytes, bufferFullStrategy, getLogConfig()); + } + + @Override + public String toString() { + return "StrictBufferConfigImpl{maxKeys=" + maxRecords + + ", maxBytes=" + maxBytes + + ", bufferFullStrategy=" + bufferFullStrategy + + ", logConfig=" + getLogConfig().toString() + + '}'; + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/suppress/SuppressedInternal.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/suppress/SuppressedInternal.java new file mode 100644 index 0000000..51307bb --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/suppress/SuppressedInternal.java @@ -0,0 +1,118 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.kstream.internals.suppress; + +import org.apache.kafka.streams.kstream.Suppressed; +import org.apache.kafka.streams.kstream.internals.suppress.TimeDefinitions.TimeDefinition; + +import java.time.Duration; +import java.util.Objects; + +public class SuppressedInternal implements Suppressed, NamedSuppressed { + private static final Duration DEFAULT_SUPPRESSION_TIME = Duration.ofMillis(Long.MAX_VALUE); + private static final StrictBufferConfigImpl DEFAULT_BUFFER_CONFIG = (StrictBufferConfigImpl) BufferConfig.unbounded(); + + private final String name; + private final BufferConfigInternal bufferConfig; + private final Duration timeToWaitForMoreEvents; + private final TimeDefinition timeDefinition; + private final boolean safeToDropTombstones; + + /** + * @param safeToDropTombstones Note: it's *only* safe to drop tombstones for windowed KTables in "final results" mode. + * In that case, we have a priori knowledge that we have never before emitted any + * results for a given key, and therefore the tombstone is unnecessary (albeit + * idempotent and correct). We decided that the unnecessary tombstones would not be + * desirable in the output stream, though, hence the ability to drop them. 
+ * + * A alternative is to remember whether a result has previously been emitted + * for a key and drop tombstones in that case, but it would be a little complicated to + * figure out when to forget the fact that we have emitted some result (currently, the + * buffer immediately forgets all about a key when we emit, which helps to keep it + * compact). + */ + public SuppressedInternal(final String name, + final Duration suppressionTime, + final BufferConfig bufferConfig, + final TimeDefinition timeDefinition, + final boolean safeToDropTombstones) { + this.name = name; + this.timeToWaitForMoreEvents = suppressionTime == null ? DEFAULT_SUPPRESSION_TIME : suppressionTime; + this.timeDefinition = timeDefinition == null ? TimeDefinitions.RecordTimeDefinition.instance() : timeDefinition; + this.bufferConfig = bufferConfig == null ? DEFAULT_BUFFER_CONFIG : (BufferConfigInternal) bufferConfig; + this.safeToDropTombstones = safeToDropTombstones; + } + + @Override + public Suppressed withName(final String name) { + return new SuppressedInternal<>(name, timeToWaitForMoreEvents, bufferConfig, timeDefinition, safeToDropTombstones); + } + + @Override + public String name() { + return name; + } + + @SuppressWarnings("unchecked") + public > BufferConfigInternal bufferConfig() { + return bufferConfig; + } + + TimeDefinition timeDefinition() { + return timeDefinition; + } + + Duration timeToWaitForMoreEvents() { + return timeToWaitForMoreEvents == null ? Duration.ZERO : timeToWaitForMoreEvents; + } + + boolean safeToDropTombstones() { + return safeToDropTombstones; + } + + @Override + public boolean equals(final Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + final SuppressedInternal that = (SuppressedInternal) o; + return safeToDropTombstones == that.safeToDropTombstones && + Objects.equals(name, that.name) && + Objects.equals(bufferConfig, that.bufferConfig) && + Objects.equals(timeToWaitForMoreEvents, that.timeToWaitForMoreEvents) && + Objects.equals(timeDefinition, that.timeDefinition); + } + + @Override + public int hashCode() { + return Objects.hash(name, bufferConfig, timeToWaitForMoreEvents, timeDefinition, safeToDropTombstones); + } + + @Override + public String toString() { + return "SuppressedInternal{" + + "name='" + name + '\'' + + ", bufferConfig=" + bufferConfig + + ", timeToWaitForMoreEvents=" + timeToWaitForMoreEvents + + ", timeDefinition=" + timeDefinition + + ", safeToDropTombstones=" + safeToDropTombstones + + '}'; + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/kstream/internals/suppress/TimeDefinitions.java b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/suppress/TimeDefinitions.java new file mode 100644 index 0000000..640965f --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/kstream/internals/suppress/TimeDefinitions.java @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.kstream.internals.suppress; + +import org.apache.kafka.streams.kstream.Windowed; +import org.apache.kafka.streams.processor.ProcessorContext; + +final class TimeDefinitions { + private TimeDefinitions() {} + + enum TimeDefinitionType { + RECORD_TIME, WINDOW_END_TIME + } + + /** + * This interface should never be instantiated outside of this class. + */ + interface TimeDefinition { + long time(final ProcessorContext context, final K key); + + TimeDefinitionType type(); + } + + public static class RecordTimeDefinition implements TimeDefinition { + private static final RecordTimeDefinition INSTANCE = new RecordTimeDefinition(); + + private RecordTimeDefinition() {} + + @SuppressWarnings("unchecked") + public static RecordTimeDefinition instance() { + return RecordTimeDefinition.INSTANCE; + } + + @Override + public long time(final ProcessorContext context, final K key) { + return context.timestamp(); + } + + @Override + public TimeDefinitionType type() { + return TimeDefinitionType.RECORD_TIME; + } + } + + public static class WindowEndTimeDefinition implements TimeDefinition { + private static final WindowEndTimeDefinition INSTANCE = new WindowEndTimeDefinition(); + + private WindowEndTimeDefinition() {} + + @SuppressWarnings("unchecked") + public static WindowEndTimeDefinition instance() { + return WindowEndTimeDefinition.INSTANCE; + } + + @Override + public long time(final ProcessorContext context, final K key) { + return key.window().end(); + } + + @Override + public TimeDefinitionType type() { + return TimeDefinitionType.WINDOW_END_TIME; + } + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/AbstractProcessor.java b/streams/src/main/java/org/apache/kafka/streams/processor/AbstractProcessor.java new file mode 100644 index 0000000..52a213d --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/processor/AbstractProcessor.java @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.processor; + +/** + * An abstract implementation of {@link Processor} that manages the {@link ProcessorContext} instance and provides default no-op + * implementation of {@link #close()}. + * + * @param the type of keys + * @param the type of values + * @deprecated Since 3.0. 
Use {@link org.apache.kafka.streams.processor.api.Processor} or + * {@link org.apache.kafka.streams.processor.api.ContextualProcessor} instead. + */ +@Deprecated +public abstract class AbstractProcessor implements Processor { + + protected ProcessorContext context; + + protected AbstractProcessor() {} + + @Override + public void init(final ProcessorContext context) { + this.context = context; + } + + /** + * Close this processor and clean up any resources. + *

                + * This method does nothing by default; if desired, subclasses should override it with custom functionality. + *

                + */ + @Override + public void close() { + // do nothing + } + + /** + * Get the processor's context set during {@link #init(ProcessorContext) initialization}. + * + * @return the processor context; null only when called prior to {@link #init(ProcessorContext) initialization}. + */ + protected final ProcessorContext context() { + return context; + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/BatchingStateRestoreCallback.java b/streams/src/main/java/org/apache/kafka/streams/processor/BatchingStateRestoreCallback.java new file mode 100644 index 0000000..d29c7ba --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/processor/BatchingStateRestoreCallback.java @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.streams.processor; + + +import org.apache.kafka.streams.KeyValue; + +import java.util.Collection; + +/** + * Interface for batching restoration of a {@link StateStore} + * + * It is expected that implementations of this class will not call the {@link StateRestoreCallback#restore(byte[], + * byte[])} method. + */ +public interface BatchingStateRestoreCallback extends StateRestoreCallback { + + /** + * Called to restore a number of records. This method is called repeatedly until the {@link StateStore} is fulled + * restored. + * + * @param records the records to restore. + */ + void restoreAll(Collection> records); + + @Override + default void restore(byte[] key, byte[] value) { + throw new UnsupportedOperationException(); + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/Cancellable.java b/streams/src/main/java/org/apache/kafka/streams/processor/Cancellable.java new file mode 100644 index 0000000..2acb762 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/processor/Cancellable.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.processor; + +import java.time.Duration; + +/** + * Cancellable interface returned in {@link ProcessorContext#schedule(Duration, PunctuationType, Punctuator)}. + * + * @see Punctuator + */ +public interface Cancellable { + + /** + * Cancel the scheduled operation to avoid future calls. + */ + void cancel(); +} diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/ConnectedStoreProvider.java b/streams/src/main/java/org/apache/kafka/streams/processor/ConnectedStoreProvider.java new file mode 100644 index 0000000..49c029e --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/processor/ConnectedStoreProvider.java @@ -0,0 +1,117 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.processor; + +import org.apache.kafka.streams.Topology; +import org.apache.kafka.streams.kstream.KStream; +import org.apache.kafka.streams.kstream.Named; +import org.apache.kafka.streams.kstream.TransformerSupplier; +import org.apache.kafka.streams.kstream.ValueTransformerSupplier; +import org.apache.kafka.streams.kstream.ValueTransformerWithKeySupplier; +import org.apache.kafka.streams.state.StoreBuilder; + +import java.util.Set; + +/** + * Provides a set of {@link StoreBuilder}s that will be automatically added to the topology and connected to the + * associated processor. + *

<p>
+ * Implementing this interface is recommended when the associated processor wants to encapsulate its usage of its state
+ * stores, rather than exposing them to the user building the topology.
+ * <p>
+ * In the event that separate but related processors may want to share the same store, different {@link ConnectedStoreProvider}s
+ * may provide the same instance of {@link StoreBuilder}, as shown below.
+ * <pre>{@code
+ * class StateSharingProcessors {
+ *     StoreBuilder<KeyValueStore<String, String>> storeBuilder =
+ *         Stores.keyValueStoreBuilder(Stores.persistentKeyValueStore("myStore"), Serdes.String(), Serdes.String());
+ *
+ *     class SupplierA implements ProcessorSupplier<String, Integer> {
+ *         Processor<String, Integer> get() {
+ *             return new Processor<String, Integer>() {
+ *                 private StateStore store;
+ *
+ *                 void init(ProcessorContext context) {
+ *                     this.store = context.getStateStore("myStore");
+ *                 }
+ *
+ *                 void process(String key, Integer value) {
+ *                     // can access this.store
+ *                 }
+ *
+ *                 void close() {
+ *                     // can access this.store
+ *                 }
+ *             }
+ *         }
+ *
+ *         Set<StoreBuilder<?>> stores() {
+ *             return Collections.singleton(storeBuilder);
+ *         }
+ *     }
+ *
+ *     class SupplierB implements ProcessorSupplier<String, String> {
+ *         Processor<String, String> get() {
+ *             return new Processor<String, String>() {
+ *                 private StateStore store;
+ *
+ *                 void init(ProcessorContext context) {
+ *                     this.store = context.getStateStore("myStore");
+ *                 }
+ *
+ *                 void process(String key, String value) {
+ *                     // can access this.store
+ *                 }
+ *
+ *                 void close() {
+ *                     // can access this.store
+ *                 }
+ *             }
+ *         }
+ *
+ *         Set<StoreBuilder<?>> stores() {
+ *             return Collections.singleton(storeBuilder);
+ *         }
+ *     }
+ * }
+ * }</pre>
                + * + * @see Topology#addProcessor(String, org.apache.kafka.streams.processor.api.ProcessorSupplier, String...) + * @see KStream#process(org.apache.kafka.streams.processor.api.ProcessorSupplier, String...) + * @see KStream#process(org.apache.kafka.streams.processor.api.ProcessorSupplier, Named, String...) + * @see KStream#transform(TransformerSupplier, String...) + * @see KStream#transform(TransformerSupplier, Named, String...) + * @see KStream#transformValues(ValueTransformerSupplier, String...) + * @see KStream#transformValues(ValueTransformerSupplier, Named, String...) + * @see KStream#transformValues(ValueTransformerWithKeySupplier, String...) + * @see KStream#transformValues(ValueTransformerWithKeySupplier, Named, String...) + * @see KStream#flatTransform(TransformerSupplier, String...) + * @see KStream#flatTransform(TransformerSupplier, Named, String...) + * @see KStream#flatTransformValues(ValueTransformerSupplier, String...) + * @see KStream#flatTransformValues(ValueTransformerSupplier, Named, String...) + * @see KStream#flatTransformValues(ValueTransformerWithKeySupplier, String...) + * @see KStream#flatTransformValues(ValueTransformerWithKeySupplier, Named, String...) + */ +public interface ConnectedStoreProvider { + + /** + * @return the state stores to be connected and added, or null if no stores should be automatically connected and added. + */ + default Set> stores() { + return null; + } +} \ No newline at end of file diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/ExtractRecordMetadataTimestamp.java b/streams/src/main/java/org/apache/kafka/streams/processor/ExtractRecordMetadataTimestamp.java new file mode 100644 index 0000000..bd28bb7 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/processor/ExtractRecordMetadataTimestamp.java @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.processor; + +import org.apache.kafka.clients.consumer.ConsumerRecord; + +/** + * Retrieves embedded metadata timestamps from Kafka messages. + * If a record has a negative (invalid) timestamp value, an error handler method is called. + *

                + * Embedded metadata timestamp was introduced in "KIP-32: Add timestamps to Kafka message" for the new + * 0.10+ Kafka message format. + *

                + * Here, "embedded metadata" refers to the fact that compatible Kafka producer clients automatically and + * transparently embed such timestamps into message metadata they send to Kafka, which can then be retrieved + * via this timestamp extractor. + *

                + * If the embedded metadata timestamp represents CreateTime (cf. Kafka broker setting + * {@code message.timestamp.type} and Kafka topic setting {@code log.message.timestamp.type}), + * this extractor effectively provides event-time semantics. + * If LogAppendTime is used as broker/topic setting to define the embedded metadata timestamps, + * using this extractor effectively provides ingestion-time semantics. + *

                + * If you need processing-time semantics, use {@link WallclockTimestampExtractor}. + * + * @see FailOnInvalidTimestamp + * @see LogAndSkipOnInvalidTimestamp + * @see UsePartitionTimeOnInvalidTimestamp + * @see WallclockTimestampExtractor + */ +abstract class ExtractRecordMetadataTimestamp implements TimestampExtractor { + + /** + * Extracts the embedded metadata timestamp from the given {@link ConsumerRecord}. + * + * @param record a data record + * @param partitionTime the highest extracted valid timestamp of the current record's partition˙ (could be -1 if unknown) + * @return the embedded metadata timestamp of the given {@link ConsumerRecord} + */ + @Override + public long extract(final ConsumerRecord record, final long partitionTime) { + final long timestamp = record.timestamp(); + + if (timestamp < 0) { + return onInvalidTimestamp(record, timestamp, partitionTime); + } + + return timestamp; + } + + /** + * Called if no valid timestamp is embedded in the record meta data. + * + * @param record a data record + * @param recordTimestamp the timestamp extractor from the record + * @param partitionTime the highest extracted valid timestamp of the current record's partition˙ (could be -1 if unknown) + * @return a new timestamp for the record (if negative, record will not be processed but dropped silently) + */ + public abstract long onInvalidTimestamp(final ConsumerRecord record, + final long recordTimestamp, + final long partitionTime); +} diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/FailOnInvalidTimestamp.java b/streams/src/main/java/org/apache/kafka/streams/processor/FailOnInvalidTimestamp.java new file mode 100644 index 0000000..f4541da --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/processor/FailOnInvalidTimestamp.java @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.processor; + +import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.streams.errors.StreamsException; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Retrieves embedded metadata timestamps from Kafka messages. + * If a record has a negative (invalid) timestamp value, this extractor raises an exception. + *

                + * Embedded metadata timestamp was introduced in "KIP-32: Add timestamps to Kafka message" for the new + * 0.10+ Kafka message format. + *

                + * Here, "embedded metadata" refers to the fact that compatible Kafka producer clients automatically and + * transparently embed such timestamps into message metadata they send to Kafka, which can then be retrieved + * via this timestamp extractor. + *

                + * If the embedded metadata timestamp represents CreateTime (cf. Kafka broker setting + * {@code message.timestamp.type} and Kafka topic setting {@code log.message.timestamp.type}), + * this extractor effectively provides event-time semantics. + * If LogAppendTime is used as broker/topic setting to define the embedded metadata timestamps, + * using this extractor effectively provides ingestion-time semantics. + *

                + * If you need processing-time semantics, use {@link WallclockTimestampExtractor}. + * + * @see LogAndSkipOnInvalidTimestamp + * @see UsePartitionTimeOnInvalidTimestamp + * @see WallclockTimestampExtractor + */ +public class FailOnInvalidTimestamp extends ExtractRecordMetadataTimestamp { + private static final Logger log = LoggerFactory.getLogger(FailOnInvalidTimestamp.class); + + /** + * Raises an exception on every call. + * + * @param record a data record + * @param recordTimestamp the timestamp extractor from the record + * @param partitionTime the highest extracted valid timestamp of the current record's partition˙ (could be -1 if unknown) + * @return nothing; always raises an exception + * @throws StreamsException on every invocation + */ + @Override + public long onInvalidTimestamp(final ConsumerRecord record, + final long recordTimestamp, + final long partitionTime) + throws StreamsException { + + final String message = "Input record " + record + " has invalid (negative) timestamp. " + + "Possibly because a pre-0.10 producer client was used to write this record to Kafka without embedding " + + "a timestamp, or because the input topic was created before upgrading the Kafka cluster to 0.10+. " + + "Use a different TimestampExtractor to process this data."; + + log.error(message); + throw new StreamsException(message); + } + +} diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/LogAndSkipOnInvalidTimestamp.java b/streams/src/main/java/org/apache/kafka/streams/processor/LogAndSkipOnInvalidTimestamp.java new file mode 100644 index 0000000..f2095c1 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/processor/LogAndSkipOnInvalidTimestamp.java @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.processor; + +import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Retrieves embedded metadata timestamps from Kafka messages. + * If a record has a negative (invalid) timestamp value the timestamp is returned as-is; + * in addition, a WARN message is logged in your application. + * Returning the timestamp as-is results in dropping the record, i.e., the record will not be processed. + *

                + * Embedded metadata timestamp was introduced in "KIP-32: Add timestamps to Kafka message" for the new + * 0.10+ Kafka message format. + *

                + * Here, "embedded metadata" refers to the fact that compatible Kafka producer clients automatically and + * transparently embed such timestamps into message metadata they send to Kafka, which can then be retrieved + * via this timestamp extractor. + *

                + * If the embedded metadata timestamp represents CreateTime (cf. Kafka broker setting + * {@code message.timestamp.type} and Kafka topic setting {@code log.message.timestamp.type}), + * this extractor effectively provides event-time semantics. + * If LogAppendTime is used as broker/topic setting to define the embedded metadata timestamps, + * using this extractor effectively provides ingestion-time semantics. + *

                + * If you need processing-time semantics, use {@link WallclockTimestampExtractor}. + * + * @see FailOnInvalidTimestamp + * @see UsePartitionTimeOnInvalidTimestamp + * @see WallclockTimestampExtractor + */ +public class LogAndSkipOnInvalidTimestamp extends ExtractRecordMetadataTimestamp { + private static final Logger log = LoggerFactory.getLogger(LogAndSkipOnInvalidTimestamp.class); + + /** + * Writes a log WARN message when the extracted timestamp is invalid (negative) but returns the invalid timestamp as-is, + * which ultimately causes the record to be skipped and not to be processed. + * + * @param record a data record + * @param recordTimestamp the timestamp extractor from the record + * @param partitionTime the highest extracted valid timestamp of the current record's partition˙ (could be -1 if unknown) + * @return the originally extracted timestamp of the record + */ + @Override + public long onInvalidTimestamp(final ConsumerRecord record, + final long recordTimestamp, + final long partitionTime) { + log.warn("Input record {} will be dropped because it has an invalid (negative) timestamp.", record); + return recordTimestamp; + } + +} diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/Processor.java b/streams/src/main/java/org/apache/kafka/streams/processor/Processor.java new file mode 100644 index 0000000..9d724ec --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/processor/Processor.java @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.processor; + +import java.time.Duration; + +/** + * A processor of key-value pair records. + * + * @param the type of keys + * @param the type of values + * @deprecated Since 3.0. Use {@link org.apache.kafka.streams.processor.api.Processor} instead. + */ +@Deprecated +public interface Processor { + + /** + * Initialize this processor with the given context. The framework ensures this is called once per processor when the topology + * that contains it is initialized. When the framework is done with the processor, {@link #close()} will be called on it; the + * framework may later re-use the processor by calling {@code #init()} again. + *

                + * The provided {@link ProcessorContext context} can be used to access topology and record meta data, to + * {@link ProcessorContext#schedule(Duration, PunctuationType, Punctuator) schedule} a method to be + * {@link Punctuator#punctuate(long) called periodically} and to access attached {@link StateStore}s. + * + * @param context the context; may not be null + */ + void init(ProcessorContext context); + + /** + * Process the record with the given key and value. + * + * @param key the key for the record + * @param value the value for the record + */ + void process(K key, V value); + + /** + * Close this processor and clean up any resources. Be aware that {@code #close()} is called after an internal cleanup. + * Thus, it is not possible to write anything to Kafka as underlying clients are already closed. The framework may + * later re-use this processor by calling {@code #init()} on it again. + *

                + * Note: Do not close any streams managed resources, like {@link StateStore}s here, as they are managed by the library. + */ + void close(); +} diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/ProcessorContext.java b/streams/src/main/java/org/apache/kafka/streams/processor/ProcessorContext.java new file mode 100644 index 0000000..f438a88 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/processor/ProcessorContext.java @@ -0,0 +1,303 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.processor; + +import org.apache.kafka.common.header.Headers; +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.StreamsMetrics; +import org.apache.kafka.streams.Topology; +import org.apache.kafka.streams.errors.StreamsException; +import org.apache.kafka.streams.kstream.ValueTransformerWithKeySupplier; + +import java.io.File; +import java.time.Duration; +import java.util.Map; + +/** + * Processor context interface. + */ +@SuppressWarnings("deprecation") // Not deprecating the old context, since it is used by Transformers. See KAFKA-10603. +public interface ProcessorContext { + + /** + * Return the application id. + * + * @return the application id + */ + String applicationId(); + + /** + * Return the task id. + * + * @return the task id + */ + TaskId taskId(); + + /** + * Return the default key serde. + * + * @return the key serializer + */ + Serde keySerde(); + + /** + * Return the default value serde. + * + * @return the value serializer + */ + Serde valueSerde(); + + /** + * Return the state directory for the partition. + * + * @return the state directory + */ + File stateDir(); + + /** + * Return Metrics instance. + * + * @return StreamsMetrics + */ + StreamsMetrics metrics(); + + /** + * Register and possibly restores the specified storage engine. + * + * @param store the storage engine + * @param stateRestoreCallback the restoration callback logic for log-backed state stores upon restart + * + * @throws IllegalStateException If store gets registered after initialized is already finished + * @throws StreamsException if the store's change log does not contain the partition + */ + void register(final StateStore store, + final StateRestoreCallback stateRestoreCallback); + + /** + * Get the state store given the store name. + * + * @param name The store name + * @param The type or interface of the store to return + * @return The state store instance + * + * @throws ClassCastException if the return type isn't a type or interface of the actual returned store. + */ + S getStateStore(final String name); + + /** + * Schedule a periodic operation for processors. 
A processor may call this method during {@link Processor#init(ProcessorContext) initialization} or
+ * {@link Processor#process(Object, Object) processing} to schedule a periodic callback,
+ * called a punctuation, to {@link Punctuator#punctuate(long)}.
+ * The type parameter controls what notion of time is used for punctuation:
+ *
+ * - {@link PunctuationType#STREAM_TIME}: uses "stream time", which is advanced by the processing of messages
+ *   in accordance with the timestamp as extracted by the {@link TimestampExtractor} in use.
+ *   The first punctuation will be triggered by the first record that is processed.
+ *   NOTE: Only advanced if messages arrive.
+ * - {@link PunctuationType#WALL_CLOCK_TIME}: uses system time (the wall-clock time),
+ *   which is advanced independent of whether new messages arrive.
+ *   The first punctuation will be triggered after the interval has elapsed.
+ *   NOTE: This is best effort only, as its granularity is limited by how long an iteration of the
+ *   processing loop takes to complete.
+ *
+ * Skipping punctuations: punctuations will not be triggered more than once at any given timestamp.
+ * This means that "missed" punctuations will be skipped.
+ * It is possible to "miss" a punctuation if:
+ *
+ * - with {@link PunctuationType#STREAM_TIME}, stream time advances by more than the interval
+ * - with {@link PunctuationType#WALL_CLOCK_TIME}, a GC pause occurs, the interval is too short, etc.
+ *
+ * @param interval the time interval between punctuations (supported minimum is 1 millisecond)
+ * @param type one of: {@link PunctuationType#STREAM_TIME}, {@link PunctuationType#WALL_CLOCK_TIME}
+ * @param callback a function consuming timestamps representing the current stream or system time
+ * @return a handle allowing cancellation of the punctuation schedule established by this method
+ * @throws IllegalArgumentException if the interval is not representable in milliseconds
+ */
+ Cancellable schedule(final Duration interval,
+                      final PunctuationType type,
+                      final Punctuator callback);
+
+ /**
+ * Forward a key/value pair to all downstream processors.
+ * Uses the input record's timestamp as the timestamp for the output record.
+ *
+ *
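Stepping back to the schedule() contract above for a moment, the following is a rough sketch of how a processor typically registers a punctuation. It uses the deprecated Processor API defined in this file; the class name WordCountProcessor and the store name "counts-store" are made up for the example, and the store is assumed to be connected to the processor when the topology is built.

    import java.time.Duration;

    import org.apache.kafka.streams.KeyValue;
    import org.apache.kafka.streams.processor.Cancellable;
    import org.apache.kafka.streams.processor.Processor;
    import org.apache.kafka.streams.processor.ProcessorContext;
    import org.apache.kafka.streams.processor.PunctuationType;
    import org.apache.kafka.streams.state.KeyValueIterator;
    import org.apache.kafka.streams.state.KeyValueStore;

    // Sketch: counts records per key and emits all counts once per minute of wall-clock time.
    public class WordCountProcessor implements Processor<String, String> {

        private ProcessorContext context;
        private KeyValueStore<String, Long> store;
        private Cancellable punctuation;

        @Override
        @SuppressWarnings("unchecked")
        public void init(final ProcessorContext context) {
            this.context = context;
            // "counts-store" is assumed to be attached to this processor in the topology
            this.store = (KeyValueStore<String, Long>) context.getStateStore("counts-store");
            // schedule a wall-clock punctuation; keep the handle so it can be cancelled in close()
            this.punctuation = context.schedule(
                Duration.ofMinutes(1),
                PunctuationType.WALL_CLOCK_TIME,
                timestamp -> {
                    try (KeyValueIterator<String, Long> it = store.all()) {
                        while (it.hasNext()) {
                            final KeyValue<String, Long> entry = it.next();
                            context.forward(entry.key, entry.value);
                        }
                    }
                });
        }

        @Override
        public void process(final String key, final String value) {
            final Long oldCount = store.get(key);
            store.put(key, oldCount == null ? 1L : oldCount + 1);
        }

        @Override
        public void close() {
            // the store itself is managed by Kafka Streams; only cancel what we scheduled ourselves
            punctuation.cancel();
        }
    }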

If this method is called within {@link Punctuator#punctuate(long)}, the record that + * is sent downstream won't have any associated record metadata like topic, partition, or offset. + * + * @param key key + * @param value value + */ + void forward(final K key, final V value); + + /** + * Forward a key/value pair to the specified downstream processors. + * Can be used to set the timestamp of the output record. + * + *

If this method is called within {@link Punctuator#punctuate(long)}, the record that + * is sent downstream won't have any associated record metadata like topic, partition, or offset. + * + * @param key key + * @param value value + * @param to the options to use when forwarding + */ + void forward(final K key, final V value, final To to); + + /** + * Request a commit. + */ + void commit(); + + /** + * Return the topic name of the current input record; could be {@code null} if it is not + * available. + * + *
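Before moving on, a short sketch of the forward(key, value, To) variant defined above. The child name "sink-node" is a made-up example and would have to match a processor or sink added downstream in the topology; the sketch uses the deprecated AbstractProcessor helper class for brevity.

    import org.apache.kafka.streams.processor.AbstractProcessor;
    import org.apache.kafka.streams.processor.To;

    // Sketch: routes every record to a child node named "sink-node" and pins the output timestamp
    // to the input record's timestamp.
    public class RoutingProcessor extends AbstractProcessor<String, String> {

        @Override
        public void process(final String key, final String value) {
            // context() is populated by AbstractProcessor once init(...) has been called
            context().forward(key, value, To.child("sink-node").withTimestamp(context().timestamp()));
            // request a commit; the actual commit happens when Streams decides it is safe
            context().commit();
        }
    }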

                For example, if this method is invoked within a {@link Punctuator#punctuate(long) + * punctuation callback}, or while processing a record that was forwarded by a punctuation + * callback, the record won't have an associated topic. + * Another example is + * {@link org.apache.kafka.streams.kstream.KTable#transformValues(ValueTransformerWithKeySupplier, String...)} + * (and siblings), that do not always guarantee to provide a valid topic name, as they might be + * executed "out-of-band" due to some internal optimizations applied by the Kafka Streams DSL. + * + * @return the topic name + */ + String topic(); + + /** + * Return the partition id of the current input record; could be {@code -1} if it is not + * available. + * + *

                For example, if this method is invoked within a {@link Punctuator#punctuate(long) + * punctuation callback}, or while processing a record that was forwarded by a punctuation + * callback, the record won't have an associated partition id. + * Another example is + * {@link org.apache.kafka.streams.kstream.KTable#transformValues(ValueTransformerWithKeySupplier, String...)} + * (and siblings), that do not always guarantee to provide a valid partition id, as they might be + * executed "out-of-band" due to some internal optimizations applied by the Kafka Streams DSL. + * + * @return the partition id + */ + int partition(); + + /** + * Return the offset of the current input record; could be {@code -1} if it is not + * available. + * + *

                For example, if this method is invoked within a {@link Punctuator#punctuate(long) + * punctuation callback}, or while processing a record that was forwarded by a punctuation + * callback, the record won't have an associated offset. + * Another example is + * {@link org.apache.kafka.streams.kstream.KTable#transformValues(ValueTransformerWithKeySupplier, String...)} + * (and siblings), that do not always guarantee to provide a valid offset, as they might be + * executed "out-of-band" due to some internal optimizations applied by the Kafka Streams DSL. + * + * @return the offset + */ + long offset(); + + /** + * Return the headers of the current input record; could be an empty header if it is not + * available. + * + *

                For example, if this method is invoked within a {@link Punctuator#punctuate(long) + * punctuation callback}, or while processing a record that was forwarded by a punctuation + * callback, the record might not have any associated headers. + * Another example is + * {@link org.apache.kafka.streams.kstream.KTable#transformValues(ValueTransformerWithKeySupplier, String...)} + * (and siblings), that do not always guarantee to provide valid headers, as they might be + * executed "out-of-band" due to some internal optimizations applied by the Kafka Streams DSL. + * + * @return the headers + */ + Headers headers(); + + /** + * Return the current timestamp. + * + *

                If it is triggered while processing a record streamed from the source processor, + * timestamp is defined as the timestamp of the current input record; the timestamp is extracted from + * {@link org.apache.kafka.clients.consumer.ConsumerRecord ConsumerRecord} by {@link TimestampExtractor}. + * Note, that an upstream {@link Processor} might have set a new timestamp by calling + * {@link ProcessorContext#forward(Object, Object, To) forward(..., To.all().withTimestamp(...))}. + * In particular, some Kafka Streams DSL operators set result record timestamps explicitly, + * to guarantee deterministic results. + * + *

                If it is triggered while processing a record generated not from the source processor (for example, + * if this method is invoked from the punctuate call), timestamp is defined as the current + * task's stream time, which is defined as the largest timestamp of any record processed by the task. + * + * @return the timestamp + */ + long timestamp(); + + /** + * Return all the application config properties as key/value pairs. + * + *

                The config properties are defined in the {@link org.apache.kafka.streams.StreamsConfig} + * object and associated to the ProcessorContext. + * + *

                The type of the values is dependent on the {@link org.apache.kafka.common.config.ConfigDef.Type type} of the property + * (e.g. the value of {@link org.apache.kafka.streams.StreamsConfig#DEFAULT_KEY_SERDE_CLASS_CONFIG DEFAULT_KEY_SERDE_CLASS_CONFIG} + * will be of type {@link Class}, even if it was specified as a String to + * {@link org.apache.kafka.streams.StreamsConfig#StreamsConfig(Map) StreamsConfig(Map)}). + * + * @return all the key/values from the StreamsConfig properties + */ + Map appConfigs(); + + /** + * Return all the application config properties with the given key prefix, as key/value pairs + * stripping the prefix. + * + *

                The config properties are defined in the {@link org.apache.kafka.streams.StreamsConfig} + * object and associated to the ProcessorContext. + * + * @param prefix the properties prefix + * @return the key/values matching the given prefix from the StreamsConfig properties. + */ + Map appConfigsWithPrefix(final String prefix); + + /** + * Return the current system timestamp (also called wall-clock time) in milliseconds. + * + *
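As a rough usage sketch of the two config accessors above (the "consumer." prefix lookup and the printout are only illustrative choices):

    import java.util.Map;

    import org.apache.kafka.streams.StreamsConfig;
    import org.apache.kafka.streams.processor.AbstractProcessor;
    import org.apache.kafka.streams.processor.ProcessorContext;

    // Sketch: a processor that inspects the application configuration during init().
    public class ConfigAwareProcessor extends AbstractProcessor<String, String> {

        @Override
        public void init(final ProcessorContext context) {
            super.init(context);
            // all StreamsConfig properties; values are already converted to their ConfigDef.Type,
            // e.g. the default key serde entry is a Class rather than a String
            final Map<String, Object> allConfigs = context.appConfigs();
            final Object defaultKeySerde = allConfigs.get(StreamsConfig.DEFAULT_KEY_SERDE_CLASS_CONFIG);

            // only the properties that were given with the "consumer." prefix, prefix stripped
            final Map<String, Object> consumerOverrides =
                context.appConfigsWithPrefix(StreamsConfig.CONSUMER_PREFIX);

            System.out.println("default key serde: " + defaultKeySerde
                + ", consumer overrides: " + consumerOverrides.keySet());
        }

        @Override
        public void process(final String key, final String value) {
            context().forward(key, value);
        }
    }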

Note: this method returns the internally cached system timestamp from the Kafka Streams runtime. + * Thus, it may return a different value compared to {@code System.currentTimeMillis()}. + * + * @return the current system timestamp in milliseconds + */ + long currentSystemTimeMs(); + + /** + * Return the current stream-time in milliseconds. + * + *

                Stream-time is the maximum observed {@link TimestampExtractor record timestamp} so far + * (including the currently processed record), i.e., it can be considered a high-watermark. + * Stream-time is tracked on a per-task basis and is preserved across restarts and during task migration. + * + *

                Note: this method is not supported for global processors (cf. {@link Topology#addGlobalStore} (...) + * and {@link StreamsBuilder#addGlobalStore} (...), + * because there is no concept of stream-time for this case. + * Calling this method in a global processor will result in an {@link UnsupportedOperationException}. + * + * @return the current stream-time in milliseconds + */ + long currentStreamTimeMs(); +} diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/ProcessorSupplier.java b/streams/src/main/java/org/apache/kafka/streams/processor/ProcessorSupplier.java new file mode 100644 index 0000000..e53a63a --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/processor/ProcessorSupplier.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.processor; + +import org.apache.kafka.streams.Topology; + +import java.util.function.Supplier; + +/** + * A processor supplier that can create one or more {@link Processor} instances. + *
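Before the ProcessorSupplier description continues, a sketch contrasting the three time notions exposed by ProcessorContext above: the record timestamp, stream time, and the cached system time. The one-hour lateness threshold is an arbitrary example policy, not anything mandated by the API.

    import java.time.Duration;

    import org.apache.kafka.streams.processor.AbstractProcessor;

    // Sketch: drops records that are more than one hour older than the task's stream time.
    public class DropLateRecordsProcessor extends AbstractProcessor<String, String> {

        private static final long MAX_LATENESS_MS = Duration.ofHours(1).toMillis();

        @Override
        public void process(final String key, final String value) {
            final long recordTime = context().timestamp();           // timestamp of the current input record
            final long streamTime = context().currentStreamTimeMs(); // max record timestamp seen by this task

            if (streamTime - recordTime > MAX_LATENESS_MS) {
                // currentSystemTimeMs() is the runtime's cached wall-clock time, handy for logging
                System.out.println("Skipping late record at wall-clock time " + context().currentSystemTimeMs());
                return;
            }
            context().forward(key, value);
        }
    }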

+ * It is used in {@link Topology} for adding new processor operators, whose generated + * topology can then be replicated (thus creating one or more {@link Processor} instances) + * and distributed to multiple stream threads. + *

                + * The supplier should always generate a new instance each time {@link ProcessorSupplier#get()} gets called. Creating + * a single {@link Processor} object and returning the same object reference in {@link ProcessorSupplier#get()} would be + * a violation of the supplier pattern and leads to runtime exceptions. + * + * @param the type of keys + * @param the type of values + * @deprecated Since 3.0. Use {@link org.apache.kafka.streams.processor.api.ProcessorSupplier} instead. + */ +@Deprecated +public interface ProcessorSupplier extends ConnectedStoreProvider, Supplier> { + + /** + * Return a newly constructed {@link Processor} instance. + * The supplier should always generate a new instance each time {@link ProcessorSupplier#get()} gets called. + *
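A sketch of that contract (see also the warning that follows). WordCountProcessor is the hypothetical processor from the earlier sketch.

    import org.apache.kafka.streams.processor.Processor;
    import org.apache.kafka.streams.processor.ProcessorSupplier;

    // Correct: a fresh Processor is created on every get() call, so every stream task gets its own instance.
    public class WordCountProcessorSupplier implements ProcessorSupplier<String, String> {

        @Override
        public Processor<String, String> get() {
            return new WordCountProcessor();
        }
    }

    // Anti-pattern (do NOT do this): caching one instance and handing the same reference to every task.
    // public class SharedInstanceSupplier implements ProcessorSupplier<String, String> {
    //     private final Processor<String, String> shared = new WordCountProcessor();
    //     @Override
    //     public Processor<String, String> get() {
    //         return shared; // violates the supplier contract and leads to runtime exceptions
    //     }
    // }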

                + * Creating a single {@link Processor} object and returning the same object reference in {@link ProcessorSupplier#get()} + * is a violation of the supplier pattern and leads to runtime exceptions. + * + * @return a newly constructed {@link Processor} instance + */ + Processor get(); +} diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/PunctuationType.java b/streams/src/main/java/org/apache/kafka/streams/processor/PunctuationType.java new file mode 100644 index 0000000..32965e8 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/processor/PunctuationType.java @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.processor; + +import java.time.Duration; + +/** + * Controls what notion of time is used for punctuation scheduled via {@link ProcessorContext#schedule(Duration, PunctuationType, Punctuator)} schedule}: + *

                  + *
- STREAM_TIME: uses "stream time", which is advanced by the processing of messages
+ *   in accordance with the timestamp as extracted by the {@link TimestampExtractor} in use.
+ *   NOTE: Only advanced if messages arrive.
+ * - WALL_CLOCK_TIME: uses system time (the wall-clock time),
+ *   which is advanced at the polling interval ({@link org.apache.kafka.streams.StreamsConfig#POLL_MS_CONFIG})
+ *   independent of whether new messages arrive.
+ *   NOTE: This is best effort only, as its granularity is limited
+ *   by how long an iteration of the processing loop takes to complete.
                + */ +public enum PunctuationType { + STREAM_TIME, + WALL_CLOCK_TIME, +} diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/Punctuator.java b/streams/src/main/java/org/apache/kafka/streams/processor/Punctuator.java new file mode 100644 index 0000000..1cbde6d --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/processor/Punctuator.java @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.processor; + +import java.time.Duration; +import org.apache.kafka.streams.processor.api.Record; + +/** + * A functional interface used as an argument to {@link ProcessorContext#schedule(Duration, PunctuationType, Punctuator)}. + * + * @see Cancellable + */ +public interface Punctuator { + + /** + * Perform the scheduled periodic operation. + * + *

                If this method accesses {@link ProcessorContext} or + * {@link org.apache.kafka.streams.processor.api.ProcessorContext}, record metadata like topic, + * partition, and offset or {@link org.apache.kafka.streams.processor.api.RecordMetadata} won't + * be available. + * + *

                Furthermore, for any record that is sent downstream via {@link ProcessorContext#forward(Object, Object)} + * or {@link org.apache.kafka.streams.processor.api.ProcessorContext#forward(Record)}, there + * won't be any record metadata. If {@link ProcessorContext#forward(Object, Object)} is used, + * it's also not possible to set records headers. + * + * @param timestamp when the operation is being called, depending on {@link PunctuationType} + */ + void punctuate(long timestamp); +} diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/RecordContext.java b/streams/src/main/java/org/apache/kafka/streams/processor/RecordContext.java new file mode 100644 index 0000000..9b21df8 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/processor/RecordContext.java @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.processor; + +import org.apache.kafka.common.header.Headers; +import org.apache.kafka.streams.kstream.ValueTransformerWithKeySupplier; + +/** + * The context associated with the current record being processed by + * an {@link org.apache.kafka.streams.processor.api.Processor} + */ +public interface RecordContext { + + /** + * Return the topic name of the current input record; could be {@code null} if it is not + * available. + * + *

                For example, if this method is invoked within a {@link Punctuator#punctuate(long) + * punctuation callback}, or while processing a record that was forwarded by a punctuation + * callback, the record won't have an associated topic. + * Another example is + * {@link org.apache.kafka.streams.kstream.KTable#transformValues(ValueTransformerWithKeySupplier, String...)} + * (and siblings), that do not always guarantee to provide a valid topic name, as they might be + * executed "out-of-band" due to some internal optimizations applied by the Kafka Streams DSL. + * + * @return the topic name + */ + String topic(); + + /** + * Return the partition id of the current input record; could be {@code -1} if it is not + * available. + * + *

                For example, if this method is invoked within a {@link Punctuator#punctuate(long) + * punctuation callback}, or while processing a record that was forwarded by a punctuation + * callback, the record won't have an associated partition id. + * Another example is + * {@link org.apache.kafka.streams.kstream.KTable#transformValues(ValueTransformerWithKeySupplier, String...)} + * (and siblings), that do not always guarantee to provide a valid partition id, as they might be + * executed "out-of-band" due to some internal optimizations applied by the Kafka Streams DSL. + * + * @return the partition id + */ + int partition(); + + /** + * Return the offset of the current input record; could be {@code -1} if it is not + * available. + * + *

                For example, if this method is invoked within a {@link Punctuator#punctuate(long) + * punctuation callback}, or while processing a record that was forwarded by a punctuation + * callback, the record won't have an associated offset. + * Another example is + * {@link org.apache.kafka.streams.kstream.KTable#transformValues(ValueTransformerWithKeySupplier, String...)} + * (and siblings), that do not always guarantee to provide a valid offset, as they might be + * executed "out-of-band" due to some internal optimizations applied by the Kafka Streams DSL. + * + * @return the offset + */ + long offset(); + + /** + * Return the current timestamp. + * + *

                If it is triggered while processing a record streamed from the source processor, + * timestamp is defined as the timestamp of the current input record; the timestamp is extracted from + * {@link org.apache.kafka.clients.consumer.ConsumerRecord ConsumerRecord} by {@link TimestampExtractor}. + * Note, that an upstream {@link org.apache.kafka.streams.processor.api.Processor} + * might have set a new timestamp by calling + * {@link org.apache.kafka.streams.processor.api.ProcessorContext#forward(org.apache.kafka.streams.processor.api.Record) + * forward(..., To.all().withTimestamp(...))}. + * In particular, some Kafka Streams DSL operators set result record timestamps explicitly, + * to guarantee deterministic results. + * + *

                If it is triggered while processing a record generated not from the source processor (for example, + * if this method is invoked from the punctuate call), timestamp is defined as the current + * task's stream time, which is defined as the largest timestamp of any record processed by the task. + * + * @return the timestamp + */ + long timestamp(); + + /** + * Return the headers of the current input record; could be an empty header if it is not + * available. + * + *

                For example, if this method is invoked within a {@link Punctuator#punctuate(long) + * punctuation callback}, or while processing a record that was forwarded by a punctuation + * callback, the record might not have any associated headers. + * Another example is + * {@link org.apache.kafka.streams.kstream.KTable#transformValues(ValueTransformerWithKeySupplier, String...)} + * (and siblings), that do not always guarantee to provide a valid headers, as they might be + * executed "out-of-band" due to some internal optimizations applied by the Kafka Streams DSL. + * + * @return the headers + */ + Headers headers(); + +} diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/StateRestoreCallback.java b/streams/src/main/java/org/apache/kafka/streams/processor/StateRestoreCallback.java new file mode 100644 index 0000000..2e896c8 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/processor/StateRestoreCallback.java @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.processor; + +/** + * Restoration logic for log-backed state stores upon restart, + * it takes one record at a time from the logs to apply to the restoring state. + */ +public interface StateRestoreCallback { + + void restore(byte[] key, byte[] value); +} diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/StateRestoreListener.java b/streams/src/main/java/org/apache/kafka/streams/processor/StateRestoreListener.java new file mode 100644 index 0000000..210a5de --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/processor/StateRestoreListener.java @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.streams.processor; + +import org.apache.kafka.common.TopicPartition; + +/** + * Class for listening to various states of the restoration process of a StateStore. 
+ * + * When calling {@link org.apache.kafka.streams.KafkaStreams#setGlobalStateRestoreListener(StateRestoreListener)} + * the passed instance is expected to be stateless since the {@code StateRestoreListener} is shared + * across all {@link org.apache.kafka.streams.processor.internals.StreamThread} instances. + * + * Users desiring stateful operations will need to provide synchronization internally in + * the {@code StateRestorerListener} implementation. + * + * Note that this listener is only registered at the per-client level and users can base on the {@code storeName} + * parameter to define specific monitoring for different {@link StateStore}s. There is another + * {@link StateRestoreCallback} interface which is registered via the {@link ProcessorContext#register(StateStore, StateRestoreCallback)} + * function per-store, and it is used to apply the fetched changelog records into the local state store during restoration. + * These two interfaces serve different restoration purposes and users should not try to implement both of them in a single + * class during state store registration. + * + * Incremental updates are exposed so users can estimate how much progress has been made. + */ +public interface StateRestoreListener { + + /** + * Method called at the very beginning of {@link StateStore} restoration. + * + * @param topicPartition the TopicPartition containing the values to restore + * @param storeName the name of the store undergoing restoration + * @param startingOffset the starting offset of the entire restoration process for this TopicPartition + * @param endingOffset the exclusive ending offset of the entire restoration process for this TopicPartition + */ + void onRestoreStart(final TopicPartition topicPartition, + final String storeName, + final long startingOffset, + final long endingOffset); + + /** + * Method called after restoring a batch of records. In this case the maximum size of the batch is whatever + * the value of the MAX_POLL_RECORDS is set to. + * + * This method is called after restoring each batch and it is advised to keep processing to a minimum. + * Any heavy processing will hold up recovering the next batch, hence slowing down the restore process as a + * whole. + * + * If you need to do any extended processing or connecting to an external service consider doing so asynchronously. + * + * @param topicPartition the TopicPartition containing the values to restore + * @param storeName the name of the store undergoing restoration + * @param batchEndOffset the inclusive ending offset for the current restored batch for this TopicPartition + * @param numRestored the total number of records restored in this batch for this TopicPartition + */ + void onBatchRestored(final TopicPartition topicPartition, + final String storeName, + final long batchEndOffset, + final long numRestored); + + /** + * Method called when restoring the {@link StateStore} is complete. 
+ * + * @param topicPartition the TopicPartition containing the values to restore + * @param storeName the name of the store just restored + * @param totalRestored the total number of records restored for this TopicPartition + */ + void onRestoreEnd(final TopicPartition topicPartition, + final String storeName, + final long totalRestored); + +} diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/StateStore.java b/streams/src/main/java/org/apache/kafka/streams/processor/StateStore.java new file mode 100644 index 0000000..76d1ab4 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/processor/StateStore.java @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.processor; + +import org.apache.kafka.streams.errors.StreamsException; +import org.apache.kafka.streams.processor.api.ProcessorContext; +import org.apache.kafka.streams.processor.internals.StoreToProcessorContextAdapter; + +/** + * A storage engine for managing state maintained by a stream processor. + *
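Before the StateStore description continues, a sketch of a listener implementing the three callbacks just defined. It keeps no mutable state, so it can safely be shared across all stream threads as required; the class name LoggingRestoreListener is made up. It would typically be registered once per client, e.g. streams.setGlobalStateRestoreListener(new LoggingRestoreListener()).

    import org.apache.kafka.common.TopicPartition;
    import org.apache.kafka.streams.processor.StateRestoreListener;
    import org.slf4j.Logger;
    import org.slf4j.LoggerFactory;

    // Sketch: logs restoration progress and nothing more.
    public class LoggingRestoreListener implements StateRestoreListener {

        private static final Logger LOG = LoggerFactory.getLogger(LoggingRestoreListener.class);

        @Override
        public void onRestoreStart(final TopicPartition topicPartition,
                                   final String storeName,
                                   final long startingOffset,
                                   final long endingOffset) {
            LOG.info("Restoration of {} ({}) started: offsets {} to {}",
                     storeName, topicPartition, startingOffset, endingOffset);
        }

        @Override
        public void onBatchRestored(final TopicPartition topicPartition,
                                    final String storeName,
                                    final long batchEndOffset,
                                    final long numRestored) {
            // keep this cheap: it runs between restore batches and can slow down recovery
            LOG.debug("Restored {} records of {} ({}), up to offset {}",
                      numRestored, storeName, topicPartition, batchEndOffset);
        }

        @Override
        public void onRestoreEnd(final TopicPartition topicPartition,
                                 final String storeName,
                                 final long totalRestored) {
            LOG.info("Restoration of {} ({}) finished: {} records restored",
                     storeName, topicPartition, totalRestored);
        }
    }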

+ * If the store is implemented as a persistent store, it must use the store name as its directory name and write + * all data into this store directory. + * The store directory must be created within the state directory. + * The state directory can be obtained via {@link ProcessorContext#stateDir() #stateDir()} using the + * {@link ProcessorContext} provided via {@link #init(StateStoreContext, StateStore) init(...)}. + *

+ * Using nested store directories within the state directory isolates different state stores. + * If a state store were to write into the state directory directly, it might conflict with other state stores and thus, + * data might get corrupted and/or Streams might fail with an error. + * Furthermore, Kafka Streams relies on using the store name as the store directory name to perform internal cleanup tasks. + *

                + * This interface does not specify any query capabilities, which, of course, + * would be query engine specific. Instead it just specifies the minimum + * functionality required to reload a storage engine from its changelog as well + * as basic lifecycle management. + */ +public interface StateStore { + + /** + * The name of this store. + * @return the storage name + */ + String name(); + + /** + * Initializes this state store. + *

+ * The implementation of this function must register the root store in the context via the + * {@link org.apache.kafka.streams.processor.ProcessorContext#register(StateStore, StateRestoreCallback)} function, + * where the first {@link StateStore} parameter should always be the passed-in {@code root} object, and + * the second parameter should be an object of the user's implementation + * of the {@link StateRestoreCallback} interface used for restoring the state store from the changelog. + *

                + * Note that if the state store engine itself supports bulk writes, users can implement another + * interface {@link BatchingStateRestoreCallback} which extends {@link StateRestoreCallback} to + * let users implement bulk-load restoration logic instead of restoring one record at a time. + *

                + * This method is not called if {@link StateStore#init(StateStoreContext, StateStore)} + * is implemented. + * + * @throws IllegalStateException If store gets registered after initialized is already finished + * @throws StreamsException if the store's change log does not contain the partition + * @deprecated Since 2.7.0. Callers should invoke {@link #init(StateStoreContext, StateStore)} instead. + * Implementers may choose to implement this method for backward compatibility or to throw an + * informative exception instead. + */ + @Deprecated + void init(org.apache.kafka.streams.processor.ProcessorContext context, StateStore root); + + /** + * Initializes this state store. + *

+ * The implementation of this function must register the root store in the context via the + * {@link StateStoreContext#register(StateStore, StateRestoreCallback)} function, where the + * first {@link StateStore} parameter should always be the passed-in {@code root} object, and + * the second parameter should be an object of the user's implementation + * of the {@link StateRestoreCallback} interface used for restoring the state store from the changelog. + *

                + * Note that if the state store engine itself supports bulk writes, users can implement another + * interface {@link BatchingStateRestoreCallback} which extends {@link StateRestoreCallback} to + * let users implement bulk-load restoration logic instead of restoring one record at a time. + * + * @throws IllegalStateException If store gets registered after initialized is already finished + * @throws StreamsException if the store's change log does not contain the partition + */ + default void init(final StateStoreContext context, final StateStore root) { + init(StoreToProcessorContextAdapter.adapt(context), root); + } + + /** + * Flush any cached data + */ + void flush(); + + /** + * Close the storage engine. + * Note that this function needs to be idempotent since it may be called + * several times on the same state store. + *
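A bare-bones, non-persistent store illustrating the registration contract described above. This is only a sketch under assumed requirements: the class name SimpleInMemoryStore is made up, there is no read/write API or serde handling, and a real store would delegate the deprecated init(...) rather than reject it.

    import java.util.Map;
    import java.util.TreeMap;

    import org.apache.kafka.common.utils.Bytes;
    import org.apache.kafka.streams.processor.StateStore;
    import org.apache.kafka.streams.processor.StateStoreContext;

    public class SimpleInMemoryStore implements StateStore {

        private final String name;
        private final Map<Bytes, byte[]> data = new TreeMap<>();
        private volatile boolean open = false;

        public SimpleInMemoryStore(final String name) {
            this.name = name;
        }

        @Override
        public String name() {
            return name;
        }

        @Deprecated
        @Override
        public void init(final org.apache.kafka.streams.processor.ProcessorContext context,
                         final StateStore root) {
            // never invoked by the framework because the StateStoreContext variant below is overridden
            throw new UnsupportedOperationException("use init(StateStoreContext, StateStore) instead");
        }

        @Override
        public void init(final StateStoreContext context, final StateStore root) {
            // register the *root* store and provide the restore callback that replays the changelog
            context.register(root, (key, value) -> {
                if (value == null) {
                    data.remove(Bytes.wrap(key)); // tombstone: delete the key
                } else {
                    data.put(Bytes.wrap(key), value);
                }
            });
            open = true;
        }

        @Override
        public void flush() { /* nothing is buffered in this sketch */ }

        @Override
        public void close() {
            data.clear();
            open = false;
        }

        @Override
        public boolean persistent() {
            return false;
        }

        @Override
        public boolean isOpen() {
            return open;
        }
    }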

                + * Users only need to implement this function but should NEVER need to call this api explicitly + * as it will be called by the library automatically when necessary + */ + void close(); + + /** + * Return if the storage is persistent or not. + * + * @return {@code true} if the storage is persistent—{@code false} otherwise + */ + boolean persistent(); + + /** + * Is this store open for reading and writing + * @return {@code true} if the store is open + */ + boolean isOpen(); +} diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/StateStoreContext.java b/streams/src/main/java/org/apache/kafka/streams/processor/StateStoreContext.java new file mode 100644 index 0000000..43810a2 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/processor/StateStoreContext.java @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.processor; + +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.streams.StreamsMetrics; +import org.apache.kafka.streams.errors.StreamsException; + +import java.io.File; +import java.util.Map; + +/** + * State store context interface. + */ +public interface StateStoreContext { + + /** + * Returns the application id. + * + * @return the application id + */ + String applicationId(); + + /** + * Returns the task id. + * + * @return the task id + */ + TaskId taskId(); + + /** + * Returns the default key serde. + * + * @return the key serializer + */ + Serde keySerde(); + + /** + * Returns the default value serde. + * + * @return the value serializer + */ + Serde valueSerde(); + + /** + * Returns the state directory for the partition. + * + * @return the state directory + */ + File stateDir(); + + /** + * Returns Metrics instance. + * + * @return StreamsMetrics + */ + StreamsMetrics metrics(); + + /** + * Registers and possibly restores the specified storage engine. + * + * @param store the storage engine + * @param stateRestoreCallback the restoration callback logic for log-backed state stores upon restart + * + * @throws IllegalStateException If store gets registered after initialized is already finished + * @throws StreamsException if the store's change log does not contain the partition + */ + void register(final StateStore store, + final StateRestoreCallback stateRestoreCallback); + + /** + * Returns all the application config properties as key/value pairs. + * + *

                The config properties are defined in the {@link org.apache.kafka.streams.StreamsConfig} + * object and associated to the StateStoreContext. + * + *

                The type of the values is dependent on the {@link org.apache.kafka.common.config.ConfigDef.Type type} of the property + * (e.g. the value of {@link org.apache.kafka.streams.StreamsConfig#DEFAULT_KEY_SERDE_CLASS_CONFIG DEFAULT_KEY_SERDE_CLASS_CONFIG} + * will be of type {@link Class}, even if it was specified as a String to + * {@link org.apache.kafka.streams.StreamsConfig#StreamsConfig(Map) StreamsConfig(Map)}). + * + * @return all the key/values from the StreamsConfig properties + */ + Map appConfigs(); + + /** + * Returns all the application config properties with the given key prefix, as key/value pairs + * stripping the prefix. + * + *

                The config properties are defined in the {@link org.apache.kafka.streams.StreamsConfig} + * object and associated to the StateStoreContext. + * + * @param prefix the properties prefix + * @return the key/values matching the given prefix from the StreamsConfig properties. + */ + Map appConfigsWithPrefix(final String prefix); + +} diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/StreamPartitioner.java b/streams/src/main/java/org/apache/kafka/streams/processor/StreamPartitioner.java new file mode 100644 index 0000000..a435caf --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/processor/StreamPartitioner.java @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.processor; + +import org.apache.kafka.clients.producer.internals.DefaultPartitioner; +import org.apache.kafka.streams.Topology; + +/** + * Determine how records are distributed among the partitions in a Kafka topic. If not specified, the underlying producer's + * {@link DefaultPartitioner} will be used to determine the partition. + *

+ * Kafka topics are divided into one or more partitions. Since each partition must fit on the servers that host it, + * using multiple partitions allows the topic to scale beyond a size that will fit on a single machine. Partitions also enable you + * to use multiple instances of your topology to process all of the records on the topology's source topics in parallel. + *

+ * When a topology is instantiated, each of its sources is assigned a subset of that topic's partitions. That means that only + * those processors in that topology instance will consume the records from those partitions. In many cases, Kafka Streams will + * automatically manage these instances, and adjust when new topology instances are added or removed. + *

                + * Some topologies, though, need more control over which records appear in each partition. For example, some topologies that have + * stateful processors may want all records within a range of keys to always be delivered to and handled by the same topology instance. + * An upstream topology producing records to that topic can use a custom stream partitioner to precisely and consistently + * determine to which partition each record should be written. + *

                + * To do this, create a StreamPartitioner implementation, and when you build your topology specify that custom partitioner + * when {@link Topology#addSink(String, String, org.apache.kafka.common.serialization.Serializer, org.apache.kafka.common.serialization.Serializer, StreamPartitioner, String...) adding a sink} + * for that topic. + *
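As a sketch of such a partitioner, assuming keys follow a made-up "tenantId:entityId" convention; it would be passed when adding a sink via one of the Topology#addSink overloads referenced above.

    import org.apache.kafka.streams.processor.StreamPartitioner;

    // Sketch: keeps all records of the same tenant in the same partition.
    public class TenantPartitioner implements StreamPartitioner<String, Object> {

        @Override
        public Integer partition(final String topic, final String key, final Object value, final int numPartitions) {
            if (key == null) {
                return null; // fall back to the default partitioning logic
            }
            final String tenantId = key.split(":", 2)[0];
            // stable, non-negative partition number in [0, numPartitions)
            return Math.floorMod(tenantId.hashCode(), numPartitions);
        }
    }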

                + * All StreamPartitioner implementations should be stateless and a pure function so they can be shared across topic and sink nodes. + * + * @param the type of keys + * @param the type of values + * @see Topology#addSink(String, String, org.apache.kafka.common.serialization.Serializer, + * org.apache.kafka.common.serialization.Serializer, StreamPartitioner, String...) + * @see Topology#addSink(String, String, StreamPartitioner, String...) + */ +public interface StreamPartitioner { + + /** + * Determine the partition number for a record with the given key and value and the current number of partitions. + * + * @param topic the topic name this record is sent to + * @param key the key of the record + * @param value the value of the record + * @param numPartitions the total number of partitions + * @return an integer between 0 and {@code numPartitions-1}, or {@code null} if the default partitioning logic should be used + */ + Integer partition(String topic, K key, V value, int numPartitions); +} diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/TaskId.java b/streams/src/main/java/org/apache/kafka/streams/processor/TaskId.java new file mode 100644 index 0000000..a5f48d9 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/processor/TaskId.java @@ -0,0 +1,191 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.processor; + +import org.apache.kafka.streams.errors.TaskIdFormatException; + +import java.io.DataInputStream; +import java.io.DataOutputStream; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.Objects; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import static org.apache.kafka.streams.processor.internals.assignment.ConsumerProtocolUtils.readTaskIdFrom; +import static org.apache.kafka.streams.processor.internals.assignment.ConsumerProtocolUtils.writeTaskIdTo; + +/** + * The task ID representation composed as subtopology (aka topicGroupId) plus the assigned partition ID. + */ +public class TaskId implements Comparable { + + private static final Logger LOG = LoggerFactory.getLogger(TaskId.class); + + public static final String NAMED_TOPOLOGY_DELIMITER = "__"; + + /** The ID of the subtopology, aka topicGroupId. */ + @Deprecated + public final int topicGroupId; + /** The ID of the partition. 
*/ + @Deprecated + public final int partition; + + /** The namedTopology that this task belongs to, or null if it does not belong to one */ + private final String topologyName; + + public TaskId(final int topicGroupId, final int partition) { + this(topicGroupId, partition, null); + } + + public TaskId(final int topicGroupId, final int partition, final String topologyName) { + this.topicGroupId = topicGroupId; + this.partition = partition; + if (topologyName != null && topologyName.length() == 0) { + LOG.warn("Empty string passed in for task's namedTopology, since NamedTopology name cannot be empty, we " + + "assume this task does not belong to a NamedTopology and downgrade this to null"); + this.topologyName = null; + } else { + this.topologyName = topologyName; + } + } + + public int subtopology() { + return topicGroupId; + } + + public int partition() { + return partition; + } + + /** + * Experimental feature -- will return null + */ + public String topologyName() { + return topologyName; + } + + @Override + public String toString() { + return topologyName != null ? topologyName + NAMED_TOPOLOGY_DELIMITER + topicGroupId + "_" + partition : topicGroupId + "_" + partition; + } + + /** + * @throws TaskIdFormatException if the taskIdStr is not a valid {@link TaskId} + */ + public static TaskId parse(final String taskIdStr) { + try { + final int namedTopologyDelimiterIndex = taskIdStr.indexOf(NAMED_TOPOLOGY_DELIMITER); + // If there is no copy of the NamedTopology delimiter, this task has no named topology and only one `_` char + if (namedTopologyDelimiterIndex < 0) { + final int index = taskIdStr.indexOf('_'); + + final int topicGroupId = Integer.parseInt(taskIdStr.substring(0, index)); + final int partition = Integer.parseInt(taskIdStr.substring(index + 1)); + + return new TaskId(topicGroupId, partition); + } else { + final int topicGroupIdIndex = namedTopologyDelimiterIndex + 2; + final int subtopologyPartitionDelimiterIndex = taskIdStr.indexOf('_', topicGroupIdIndex); + + final String namedTopology = taskIdStr.substring(0, namedTopologyDelimiterIndex); + final int topicGroupId = Integer.parseInt(taskIdStr.substring(topicGroupIdIndex, subtopologyPartitionDelimiterIndex)); + final int partition = Integer.parseInt(taskIdStr.substring(subtopologyPartitionDelimiterIndex + 1)); + + return new TaskId(topicGroupId, partition, namedTopology); + } + } catch (final Exception e) { + throw new TaskIdFormatException(taskIdStr); + } + } + + /** + * @throws IOException if cannot write to output stream + * @deprecated since 3.0, for internal use, will be removed + */ + @Deprecated + public void writeTo(final DataOutputStream out, final int version) throws IOException { + writeTaskIdTo(this, out, version); + } + + /** + * @throws IOException if cannot read from input stream + * @deprecated since 3.0, for internal use, will be removed + */ + @Deprecated + public static TaskId readFrom(final DataInputStream in, final int version) throws IOException { + return readTaskIdFrom(in, version); + } + + /** + * @deprecated since 3.0, for internal use, will be removed + */ + @Deprecated + public void writeTo(final ByteBuffer buf, final int version) { + writeTaskIdTo(this, buf, version); + } + + /** + * @deprecated since 3.0, for internal use, will be removed + */ + @Deprecated + public static TaskId readFrom(final ByteBuffer buf, final int version) { + return readTaskIdFrom(buf, version); + } + + @Override + public boolean equals(final Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != 
o.getClass()) { + return false; + } + final TaskId taskId = (TaskId) o; + + if (topicGroupId != taskId.topicGroupId || partition != taskId.partition) { + return false; + } + + if (topologyName != null && taskId.topologyName != null) { + return topologyName.equals(taskId.topologyName); + } else { + return topologyName == null && taskId.topologyName == null; + } + } + + @Override + public int hashCode() { + return Objects.hash(topicGroupId, partition, topologyName); + } + + @Override + public int compareTo(final TaskId other) { + if (topologyName != null && other.topologyName != null) { + final int comparingNamedTopologies = topologyName.compareTo(other.topologyName); + if (comparingNamedTopologies != 0) { + return comparingNamedTopologies; + } + } else if (topologyName != null || other.topologyName != null) { + LOG.error("Tried to compare this = {} with other = {}, but only one had a valid named topology", this, other); + throw new IllegalStateException("Can't compare a TaskId with a namedTopology to one without"); + } + final int comparingTopicGroupId = Integer.compare(this.topicGroupId, other.topicGroupId); + return comparingTopicGroupId != 0 ? comparingTopicGroupId : Integer.compare(this.partition, other.partition); + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/TaskMetadata.java b/streams/src/main/java/org/apache/kafka/streams/processor/TaskMetadata.java new file mode 100644 index 0000000..b9f5d91 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/processor/TaskMetadata.java @@ -0,0 +1,117 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.processor; + +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.streams.KafkaStreams; + +import java.util.Collections; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.Set; + +/** + * Represents the state of a single task running within a {@link KafkaStreams} application. + * @deprecated since 3.0, use {@link org.apache.kafka.streams.TaskMetadata} instead. 
+ */ +@Deprecated +public class TaskMetadata { + + private final String taskId; + + private final Set topicPartitions; + + private final Map committedOffsets; + + private final Map endOffsets; + + private final Optional timeCurrentIdlingStarted; + + public TaskMetadata(final String taskId, + final Set topicPartitions, + final Map committedOffsets, + final Map endOffsets, + final Optional timeCurrentIdlingStarted) { + this.taskId = taskId; + this.topicPartitions = Collections.unmodifiableSet(topicPartitions); + this.committedOffsets = Collections.unmodifiableMap(committedOffsets); + this.endOffsets = Collections.unmodifiableMap(endOffsets); + this.timeCurrentIdlingStarted = timeCurrentIdlingStarted; + } + + /** + * @return the basic task metadata such as subtopology and partition id + */ + public String taskId() { + return taskId; + } + + public Set topicPartitions() { + return topicPartitions; + } + + /** + * This function will return a map of TopicPartitions and the highest committed offset seen so far + */ + public Map committedOffsets() { + return committedOffsets; + } + + /** + * This function will return a map of TopicPartitions and the highest offset seen so far in the Topic + */ + public Map endOffsets() { + return endOffsets; + } + + /** + * This function will return the time task idling started, if the task is not currently idling it will return empty + */ + public Optional timeCurrentIdlingStarted() { + return timeCurrentIdlingStarted; + } + + @Override + public boolean equals(final Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + final TaskMetadata that = (TaskMetadata) o; + return Objects.equals(taskId, that.taskId) && + Objects.equals(topicPartitions, that.topicPartitions); + } + + @Override + public int hashCode() { + return Objects.hash(taskId, topicPartitions); + } + + @Override + public String toString() { + return "TaskMetadata{" + + "taskId=" + taskId + + ", topicPartitions=" + topicPartitions + + ", committedOffsets=" + committedOffsets + + ", endOffsets=" + endOffsets + + ", timeCurrentIdlingStarted=" + timeCurrentIdlingStarted + + '}'; + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/ThreadMetadata.java b/streams/src/main/java/org/apache/kafka/streams/processor/ThreadMetadata.java new file mode 100644 index 0000000..68ac663 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/processor/ThreadMetadata.java @@ -0,0 +1,145 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.processor; + +import org.apache.kafka.streams.KafkaStreams; + +import java.util.Collections; +import java.util.Objects; +import java.util.Set; + +/** + * Represents the state of a single thread running within a {@link KafkaStreams} application. + * @deprecated since 3.0 use {@link org.apache.kafka.streams.ThreadMetadata} instead + */ +@Deprecated +public class ThreadMetadata { + + private final String threadName; + + private final String threadState; + + private final Set activeTasks; + + private final Set standbyTasks; + + private final String mainConsumerClientId; + + private final String restoreConsumerClientId; + + private final Set producerClientIds; + + // the admin client should be shared among all threads, so the client id should be the same; + // we keep it at the thread-level for user's convenience and possible extensions in the future + private final String adminClientId; + + public ThreadMetadata(final String threadName, + final String threadState, + final String mainConsumerClientId, + final String restoreConsumerClientId, + final Set producerClientIds, + final String adminClientId, + final Set activeTasks, + final Set standbyTasks) { + this.mainConsumerClientId = mainConsumerClientId; + this.restoreConsumerClientId = restoreConsumerClientId; + this.producerClientIds = producerClientIds; + this.adminClientId = adminClientId; + this.threadName = threadName; + this.threadState = threadState; + this.activeTasks = Collections.unmodifiableSet(activeTasks); + this.standbyTasks = Collections.unmodifiableSet(standbyTasks); + } + + public String threadState() { + return threadState; + } + + public String threadName() { + return threadName; + } + + public Set activeTasks() { + return activeTasks; + } + + public Set standbyTasks() { + return standbyTasks; + } + + public String consumerClientId() { + return mainConsumerClientId; + } + + public String restoreConsumerClientId() { + return restoreConsumerClientId; + } + + public Set producerClientIds() { + return producerClientIds; + } + + public String adminClientId() { + return adminClientId; + } + + @Override + public boolean equals(final Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + final ThreadMetadata that = (ThreadMetadata) o; + return Objects.equals(threadName, that.threadName) && + Objects.equals(threadState, that.threadState) && + Objects.equals(activeTasks, that.activeTasks) && + Objects.equals(standbyTasks, that.standbyTasks) && + mainConsumerClientId.equals(that.mainConsumerClientId) && + restoreConsumerClientId.equals(that.restoreConsumerClientId) && + Objects.equals(producerClientIds, that.producerClientIds) && + adminClientId.equals(that.adminClientId); + } + + @Override + public int hashCode() { + return Objects.hash( + threadName, + threadState, + activeTasks, + standbyTasks, + mainConsumerClientId, + restoreConsumerClientId, + producerClientIds, + adminClientId); + } + + @Override + public String toString() { + return "ThreadMetadata{" + + "threadName=" + threadName + + ", threadState=" + threadState + + ", activeTasks=" + activeTasks + + ", standbyTasks=" + standbyTasks + + ", consumerClientId=" + mainConsumerClientId + + ", restoreConsumerClientId=" + restoreConsumerClientId + + ", producerClientIds=" + producerClientIds + + ", adminClientId=" + adminClientId + + '}'; + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/TimestampExtractor.java 
b/streams/src/main/java/org/apache/kafka/streams/processor/TimestampExtractor.java new file mode 100644 index 0000000..30d8208 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/processor/TimestampExtractor.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.processor; + +import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.streams.kstream.KTable; + +/** + * An interface that allows the Kafka Streams framework to extract a timestamp from an instance of {@link ConsumerRecord}. + * The extracted timestamp is defined as milliseconds. + */ +public interface TimestampExtractor { + + /** + * Extracts a timestamp from a record. The timestamp must be positive to be considered a valid timestamp. + * Returning a negative timestamp will cause the record not to be processed but rather silently skipped. + * In case the record contains a negative timestamp and this is considered a fatal error for the application, + * throwing a {@link RuntimeException} instead of returning the timestamp is a valid option too. + * For this case, Streams will stop processing and shut down to allow you investigate in the root cause of the + * negative timestamp. + *

+ * The timestamp extractor implementation must be stateless. + *
+ * The extracted timestamp MUST represent the milliseconds since midnight, January 1, 1970 UTC. + *
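As a hedged aside (not part of this patch), an extractor honoring the contract above might prefer the embedded record timestamp and fall back to the partition time; the class name and the fallback policy are choices made for this sketch only.

    import org.apache.kafka.clients.consumer.ConsumerRecord;
    import org.apache.kafka.streams.processor.TimestampExtractor;

    // Sketch: prefer the embedded record timestamp, otherwise fall back to partition time.
    public class EmbeddedOrPartitionTimeExtractor implements TimestampExtractor {

        @Override
        public long extract(final ConsumerRecord<Object, Object> record, final long partitionTime) {
            final long embedded = record.timestamp(); // producer- or broker-assigned timestamp
            if (embedded >= 0) {
                return embedded;
            }
            // Fall back to the highest valid timestamp seen so far on this partition. If that is
            // also unknown (-1), the negative return value makes Streams skip this record silently.
            return partitionTime;
        }
    }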
                + * It is important to note that this timestamp may become the message timestamp for any messages sent to changelogs + * updated by {@link KTable}s and joins. + * The message timestamp is used for log retention and log rolling, so using nonsensical values may result in + * excessive log rolling and therefore broker performance degradation. + * + * + * @param record a data record + * @param partitionTime the highest extracted valid timestamp of the current record's partition˙ (could be -1 if unknown) + * @return the timestamp of the record + */ + long extract(ConsumerRecord record, long partitionTime); +} diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/To.java b/streams/src/main/java/org/apache/kafka/streams/processor/To.java new file mode 100644 index 0000000..69c0c5b --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/processor/To.java @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.processor; + +import java.util.Objects; + +/** + * This class is used to provide the optional parameters when sending output records to downstream processor + * using {@link ProcessorContext#forward(Object, Object, To)}. + */ +public class To { + protected String childName; + protected long timestamp; + + private To(final String childName, + final long timestamp) { + this.childName = childName; + this.timestamp = timestamp; + } + + protected To(final To to) { + this(to.childName, to.timestamp); + } + + protected void update(final To to) { + childName = to.childName; + timestamp = to.timestamp; + } + + /** + * Forward the key/value pair to one of the downstream processors designated by the downstream processor name. + * @param childName name of downstream processor + * @return a new {@link To} instance configured with {@code childName} + */ + public static To child(final String childName) { + return new To(childName, -1); + } + + /** + * Forward the key/value pair to all downstream processors + * @return a new {@link To} instance configured for all downstream processor + */ + public static To all() { + return new To(null, -1); + } + + /** + * Set the timestamp of the output record. + * @param timestamp the output record timestamp + * @return itself (i.e., {@code this}) + */ + public To withTimestamp(final long timestamp) { + this.timestamp = timestamp; + return this; + } + + @Override + public boolean equals(final Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + final To to = (To) o; + return timestamp == to.timestamp && + Objects.equals(childName, to.childName); + } + + /** + * Equality is implemented in support of tests, *not* for use in Hash collections, since this class is mutable. 
+ */ + @Override + public int hashCode() { + throw new UnsupportedOperationException("To is unsafe for use in Hash collections"); + } + + @Override + public String toString() { + return "To{" + + "childName='" + childName + '\'' + + ", timestamp=" + timestamp + + '}'; + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/TopicNameExtractor.java b/streams/src/main/java/org/apache/kafka/streams/processor/TopicNameExtractor.java new file mode 100644 index 0000000..0faa610 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/processor/TopicNameExtractor.java @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.processor; + +/** + * An interface that allows to dynamically determine the name of the Kafka topic to send at the sink node of the topology. + */ +public interface TopicNameExtractor { + + /** + * Extracts the topic name to send to. The topic name must already exist, since the Kafka Streams library will not + * try to automatically create the topic with the extracted name. + * + * @param key the record key + * @param value the record value + * @param recordContext current context metadata of the record + * @return the topic name this record should be sent to + */ + String extract(final K key, final V value, final RecordContext recordContext); +} diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/UsePartitionTimeOnInvalidTimestamp.java b/streams/src/main/java/org/apache/kafka/streams/processor/UsePartitionTimeOnInvalidTimestamp.java new file mode 100644 index 0000000..92c4022 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/processor/UsePartitionTimeOnInvalidTimestamp.java @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.processor; + +import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.streams.errors.StreamsException; + +/** + * Retrieves embedded metadata timestamps from Kafka messages. 
+ * If a record has a negative (invalid) timestamp, a new timestamp will be inferred from the current stream-time. + *

+ * Embedded metadata timestamp was introduced in "KIP-32: Add timestamps to Kafka message" for the new + * 0.10+ Kafka message format. + *
+ * Here, "embedded metadata" refers to the fact that compatible Kafka producer clients automatically and + * transparently embed such timestamps into message metadata they send to Kafka, which can then be retrieved + * via this timestamp extractor. + *
+ * If the embedded metadata timestamp represents CreateTime (cf. Kafka broker setting + * {@code message.timestamp.type} and Kafka topic setting {@code log.message.timestamp.type}), + * this extractor effectively provides event-time semantics. + * If LogAppendTime is used as broker/topic setting to define the embedded metadata timestamps, + * using this extractor effectively provides ingestion-time semantics. + *
                + * If you need processing-time semantics, use {@link WallclockTimestampExtractor}. + * + * @see FailOnInvalidTimestamp + * @see LogAndSkipOnInvalidTimestamp + * @see WallclockTimestampExtractor + */ + +public class UsePartitionTimeOnInvalidTimestamp extends ExtractRecordMetadataTimestamp { + /** + * Returns the current stream-time as new timestamp for the record. + * + * @param record a data record + * @param recordTimestamp the timestamp extractor from the record + * @param partitionTime the highest extracted valid timestamp of the current record's partition˙ (could be -1 if unknown) + * @return the provided highest extracted valid timestamp as new timestamp for the record + * @throws StreamsException if highest extracted valid timestamp is unknown + */ + @Override + public long onInvalidTimestamp(final ConsumerRecord record, + final long recordTimestamp, + final long partitionTime) + throws StreamsException { + if (partitionTime < 0) { + throw new StreamsException("Could not infer new timestamp for input record " + record + + " because partition time is unknown."); + } + return partitionTime; + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/WallclockTimestampExtractor.java b/streams/src/main/java/org/apache/kafka/streams/processor/WallclockTimestampExtractor.java new file mode 100644 index 0000000..799cdf4 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/processor/WallclockTimestampExtractor.java @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.processor; + +import org.apache.kafka.clients.consumer.ConsumerRecord; + +/** + * Retrieves current wall clock timestamps as {@link System#currentTimeMillis()}. + *

+ * Using this extractor effectively provides processing-time semantics. + *
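For orientation only, a hedged sketch of selecting one of these extractors globally through StreamsConfig; the application id and bootstrap server below are placeholders.

    import java.util.Properties;

    import org.apache.kafka.streams.StreamsConfig;
    import org.apache.kafka.streams.processor.WallclockTimestampExtractor;

    public class ExtractorConfigSketch {
        public static Properties streamsProperties() {
            final Properties props = new Properties();
            props.put(StreamsConfig.APPLICATION_ID_CONFIG, "timestamp-extractor-demo"); // placeholder
            props.put(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9092");        // placeholder
            // Processing-time semantics; a custom TimestampExtractor class could be set here instead.
            props.put(StreamsConfig.DEFAULT_TIMESTAMP_EXTRACTOR_CLASS_CONFIG, WallclockTimestampExtractor.class);
            return props;
        }
    }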

                + * If you need event-time semantics, use {@link FailOnInvalidTimestamp} with + * built-in CreateTime or LogAppendTime timestamp (see KIP-32: Add timestamps to Kafka message for details). + * + * @see FailOnInvalidTimestamp + * @see LogAndSkipOnInvalidTimestamp + * @see UsePartitionTimeOnInvalidTimestamp + */ +public class WallclockTimestampExtractor implements TimestampExtractor { + + /** + * Return the current wall clock time as timestamp. + * + * @param record a data record + * @param partitionTime the highest extracted valid timestamp of the current record's partition˙ (could be -1 if unknown) + * @return the current wall clock time, expressed in milliseconds since midnight, January 1, 1970 UTC + */ + @Override + public long extract(final ConsumerRecord record, final long partitionTime) { + return System.currentTimeMillis(); + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/api/ContextualProcessor.java b/streams/src/main/java/org/apache/kafka/streams/processor/api/ContextualProcessor.java new file mode 100644 index 0000000..96cc278 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/processor/api/ContextualProcessor.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.processor.api; + +/** + * An abstract implementation of {@link Processor} that manages the {@link ProcessorContext} instance and provides default no-op + * implementation of {@link #close()}. + * + * @param the type of input keys + * @param the type of input values + * @param the type of output keys + * @param the type of output values + */ +public abstract class ContextualProcessor implements Processor { + + private ProcessorContext context; + + protected ContextualProcessor() {} + + @Override + public void init(final ProcessorContext context) { + this.context = context; + } + + /** + * Get the processor's context set during {@link #init(ProcessorContext) initialization}. + * + * @return the processor context; null only when called prior to {@link #init(ProcessorContext) initialization}. + */ + protected final ProcessorContext context() { + return context; + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/api/Processor.java b/streams/src/main/java/org/apache/kafka/streams/processor/api/Processor.java new file mode 100644 index 0000000..167976b --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/processor/api/Processor.java @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.processor.api; + +import org.apache.kafka.streams.processor.PunctuationType; +import org.apache.kafka.streams.processor.Punctuator; +import org.apache.kafka.streams.processor.StateStore; + +import java.time.Duration; + +/** + * A processor of key-value pair records. + * + * @param the type of input keys + * @param the type of input values + * @param the type of output keys + * @param the type of output values + */ +public interface Processor { + + /** + * Initialize this processor with the given context. The framework ensures this is called once per processor when the topology + * that contains it is initialized. When the framework is done with the processor, {@link #close()} will be called on it; the + * framework may later re-use the processor by calling {@code #init()} again. + *

                + * The provided {@link ProcessorContext context} can be used to access topology and record meta data, to + * {@link ProcessorContext#schedule(Duration, PunctuationType, Punctuator) schedule} a method to be + * {@link Punctuator#punctuate(long) called periodically} and to access attached {@link StateStore}s. + * + * @param context the context; may not be null + */ + default void init(final ProcessorContext context) {} + + /** + * Process the record. Note that record metadata is undefined in cases such as a forward call from a punctuator. + * + * @param record the record to process + */ + void process(Record record); + + /** + * Close this processor and clean up any resources. Be aware that {@code #close()} is called after an internal cleanup. + * Thus, it is not possible to write anything to Kafka as underlying clients are already closed. The framework may + * later re-use this processor by calling {@code #init()} on it again. + *
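To make the init/process/close lifecycle concrete, a hedged sketch of a trivial processor follows; the class name and the uppercasing logic are invented for illustration and are not part of this patch.

    import org.apache.kafka.streams.processor.api.Processor;
    import org.apache.kafka.streams.processor.api.ProcessorContext;
    import org.apache.kafka.streams.processor.api.Record;

    public class UppercaseProcessor implements Processor<String, String, String, String> {

        private ProcessorContext<String, String> context;

        @Override
        public void init(final ProcessorContext<String, String> context) {
            this.context = context; // keep the context so process() can forward records
        }

        @Override
        public void process(final Record<String, String> record) {
            // topic/partition/offset are only present for records coming from an input topic
            context.recordMetadata().ifPresent(metadata ->
                System.out.println("processing offset " + metadata.offset() + " from " + metadata.topic()));
            // withValue() copies the record, so the input record itself is left untouched
            context.forward(record.withValue(record.value().toUpperCase()));
        }

        @Override
        public void close() {
            // nothing to release here; state stores are managed and closed by the library
        }
    }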

                + * Note: Do not close any streams managed resources, like {@link StateStore}s here, as they are managed by the library. + */ + default void close() {} +} diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/api/ProcessorContext.java b/streams/src/main/java/org/apache/kafka/streams/processor/api/ProcessorContext.java new file mode 100644 index 0000000..d110a76 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/processor/api/ProcessorContext.java @@ -0,0 +1,250 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.processor.api; + +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.streams.StreamsMetrics; +import org.apache.kafka.streams.processor.Cancellable; +import org.apache.kafka.streams.processor.PunctuationType; +import org.apache.kafka.streams.processor.Punctuator; +import org.apache.kafka.streams.processor.StateStore; +import org.apache.kafka.streams.processor.TaskId; +import org.apache.kafka.streams.processor.TimestampExtractor; + +import java.io.File; +import java.time.Duration; +import java.util.Map; +import java.util.Optional; + +/** + * Processor context interface. + * + * @param a bound on the types of keys that may be forwarded + * @param a bound on the types of values that may be forwarded + */ +public interface ProcessorContext { + + /** + * Return the application id. + * + * @return the application id + */ + String applicationId(); + + /** + * Return the task id. + * + * @return the task id + */ + TaskId taskId(); + + /** + * Return the metadata of the current record if available. Processors may be invoked to + * process a source record from an input topic, to run a scheduled punctuation + * (see {@link ProcessorContext#schedule(Duration, PunctuationType, Punctuator)}), + * or because a parent processor called {@link ProcessorContext#forward(Record)}. + *

+ * In the case of a punctuation, there is no source record, so this metadata would be + * undefined. Note that when a punctuator invokes {@link ProcessorContext#forward(Record)}, + * downstream processors will receive the forwarded record as a regular + * {@link Processor#process(Record)} invocation. In other words, it wouldn't be apparent to + * downstream processors whether or not the record being processed came from an input topic + * or punctuation and therefore whether or not this metadata is defined. This is why + * the return type of this method is {@link Optional}. + *
                + * If there is any possibility of punctuators upstream, any access + * to this field should consider the case of + * "recordMetadata().isPresent() == false". + * Of course, it would be safest to always guard this condition. + */ + Optional recordMetadata(); + + /** + * Return the default key serde. + * + * @return the key serializer + */ + Serde keySerde(); + + /** + * Return the default value serde. + * + * @return the value serializer + */ + Serde valueSerde(); + + /** + * Return the state directory for the partition. + * + * @return the state directory + */ + File stateDir(); + + /** + * Return Metrics instance. + * + * @return StreamsMetrics + */ + StreamsMetrics metrics(); + + /** + * Get the state store given the store name. + * + * @param name The store name + * @param The type or interface of the store to return + * @return The state store instance + * + * @throws ClassCastException if the return type isn't a type or interface of the actual returned store. + */ + S getStateStore(final String name); + + /** + * Schedule a periodic operation for processors. A processor may call this method during + * {@link Processor#init(ProcessorContext) initialization} or + * {@link Processor#process(Record)} processing} to + * schedule a periodic callback — called a punctuation — to {@link Punctuator#punctuate(long)}. + * The type parameter controls what notion of time is used for punctuation: + *

+ * • {@link PunctuationType#STREAM_TIME} — uses "stream time", which is advanced by the processing of messages + * in accordance with the timestamp as extracted by the {@link TimestampExtractor} in use. + * The first punctuation will be triggered by the first record that is processed. + * NOTE: Only advanced if messages arrive
+ * • {@link PunctuationType#WALL_CLOCK_TIME} — uses system time (the wall-clock time), + * which is advanced independent of whether new messages arrive. + * The first punctuation will be triggered after interval has elapsed. + * NOTE: This is best effort only as its granularity is limited by how long an iteration of the + * processing loop takes to complete
+ * + * Skipping punctuations: Punctuations will not be triggered more than once at any given timestamp. + * This means that "missed" punctuation will be skipped. + * It's possible to "miss" a punctuation if: + *
+ * • with {@link PunctuationType#STREAM_TIME}, when stream time advances more than interval
+ * • with {@link PunctuationType#WALL_CLOCK_TIME}, on GC pause, too short interval, ...
                + * + * @param interval the time interval between punctuations (supported minimum is 1 millisecond) + * @param type one of: {@link PunctuationType#STREAM_TIME}, {@link PunctuationType#WALL_CLOCK_TIME} + * @param callback a function consuming timestamps representing the current stream or system time + * @return a handle allowing cancellation of the punctuation schedule established by this method + * @throws IllegalArgumentException if the interval is not representable in milliseconds + */ + Cancellable schedule(final Duration interval, + final PunctuationType type, + final Punctuator callback); + + /** + * Forward a record to all child processors. + *
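A hedged sketch of the schedule() usage described above; the 30-second interval, the emitted key, and the counting logic are arbitrary choices for illustration.

    import java.time.Duration;

    import org.apache.kafka.streams.processor.Cancellable;
    import org.apache.kafka.streams.processor.PunctuationType;
    import org.apache.kafka.streams.processor.api.Processor;
    import org.apache.kafka.streams.processor.api.ProcessorContext;
    import org.apache.kafka.streams.processor.api.Record;

    public class HeartbeatProcessor implements Processor<String, String, String, Long> {

        private Cancellable punctuation;
        private long seen = 0L;

        @Override
        public void init(final ProcessorContext<String, Long> context) {
            // emit the running count every 30 seconds of wall-clock time, independent of traffic
            punctuation = context.schedule(Duration.ofSeconds(30), PunctuationType.WALL_CLOCK_TIME,
                timestamp -> context.forward(new Record<>("seen", seen, timestamp)));
        }

        @Override
        public void process(final Record<String, String> record) {
            seen++;
        }

        @Override
        public void close() {
            if (punctuation != null) {
                punctuation.cancel(); // stop punctuating once the processor is closed
            }
        }
    }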

+ * Note that the forwarded {@link Record} is shared between the parent and child + * processors. And of course, the parent may forward the same object to multiple children, + * and the child may forward it to grandchildren, etc. Therefore, you should be mindful + * of mutability. + *
+ * The {@link Record} class itself is immutable (all the setter-style methods return an + * independent copy of the instance). However, the key, value, and headers referenced by + * the Record may themselves be mutable. + *
+ * Some programs may opt to make use of this mutability for high performance, in which case + * the input record may be mutated and then forwarded by each {@link Processor}. However, + * most applications should instead favor safety. + *
                + * Forwarding records safely simply means to make a copy of the record before you mutate it. + * This is trivial when using the {@link Record#withKey(Object)}, {@link Record#withValue(Object)}, + * and {@link Record#withTimestamp(long)} methods, as each of these methods make a copy of the + * record as a matter of course. But a little extra care must be taken with headers, since + * the {@link org.apache.kafka.common.header.Header} class is mutable. The easiest way to + * safely handle headers is to use the {@link Record} constructors to make a copy before + * modifying headers. + *

                + * In other words, this would be considered unsafe: + * + * process(Record inputRecord) { + * inputRecord.headers().add(...); + * context.forward(inputRecord); + * } + * + * This is unsafe because the parent, and potentially siblings, grandparents, etc., + * all will see this modification to their shared Headers reference. This is a violation + * of causality and could lead to undefined behavior. + *

                + * A safe usage would look like this: + * + * process(Record inputRecord) { + * // makes a copy of the headers + * Record toForward = inputRecord.withHeaders(inputRecord.headers()); + * // Other options to create a safe copy are: + * // * use any copy-on-write method, which makes a copy of all fields: + * // toForward = inputRecord.withValue(); + * // * explicitly copy all fields: + * // toForward = new Record(inputRecord.key(), inputRecord.value(), inputRecord.timestamp(), inputRecord.headers()); + * // * create a fresh, empty Headers: + * // toForward = new Record(inputRecord.key(), inputRecord.value(), inputRecord.timestamp()); + * // * etc. + * + * // now, we are modifying our own independent copy of the headers. + * toForward.headers().add(...); + * context.forward(toForward); + * } + * + * @param record The record to forward to all children + */ + void forward(Record record); + + /** + * Forward a record to the specified child processor. + * See {@link ProcessorContext#forward(Record)} for considerations. + * + * @param record The record to forward + * @param childName The name of the child processor to receive the record + * @see ProcessorContext#forward(Record) + */ + void forward(Record record, final String childName); + + /** + * Request a commit. + */ + void commit(); + + /** + * Returns all the application config properties as key/value pairs. + * + *

                The config properties are defined in the {@link org.apache.kafka.streams.StreamsConfig} + * object and associated to the ProcessorContext. + * + *
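As a small, hedged illustration of reading these properties from within a processor, the sketch below looks up an invented "custom.verbose" application property during initialization; the prefix and key are not Streams configs.

    import java.util.Map;

    import org.apache.kafka.streams.processor.api.Processor;
    import org.apache.kafka.streams.processor.api.ProcessorContext;
    import org.apache.kafka.streams.processor.api.Record;

    public class ConfigAwareProcessor implements Processor<String, String, String, String> {

        private boolean verbose;

        @Override
        public void init(final ProcessorContext<String, String> context) {
            // appConfigsWithPrefix() strips the prefix from the returned keys
            final Map<String, Object> custom = context.appConfigsWithPrefix("custom.");
            verbose = Boolean.parseBoolean(String.valueOf(custom.getOrDefault("verbose", "false")));
        }

        @Override
        public void process(final Record<String, String> record) {
            if (verbose) {
                System.out.println("saw key " + record.key());
            }
        }
    }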

                The type of the values is dependent on the {@link org.apache.kafka.common.config.ConfigDef.Type type} of the property + * (e.g. the value of {@link org.apache.kafka.streams.StreamsConfig#DEFAULT_KEY_SERDE_CLASS_CONFIG DEFAULT_KEY_SERDE_CLASS_CONFIG} + * will be of type {@link Class}, even if it was specified as a String to + * {@link org.apache.kafka.streams.StreamsConfig#StreamsConfig(Map) StreamsConfig(Map)}). + * + * @return all the key/values from the StreamsConfig properties + */ + Map appConfigs(); + + /** + * Return all the application config properties with the given key prefix, as key/value pairs + * stripping the prefix. + * + *

                The config properties are defined in the {@link org.apache.kafka.streams.StreamsConfig} + * object and associated to the ProcessorContext. + * + * @param prefix the properties prefix + * @return the key/values matching the given prefix from the StreamsConfig properties. + */ + Map appConfigsWithPrefix(final String prefix); +} diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/api/ProcessorSupplier.java b/streams/src/main/java/org/apache/kafka/streams/processor/api/ProcessorSupplier.java new file mode 100644 index 0000000..7ee1a5f --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/processor/api/ProcessorSupplier.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.processor.api; + +import org.apache.kafka.streams.Topology; +import org.apache.kafka.streams.processor.ConnectedStoreProvider; + +import java.util.function.Supplier; + +/** + * A processor supplier that can create one or more {@link Processor} instances. + *

                + * It is used in {@link Topology} for adding new processor operators, whose generated + * topology can then be replicated (and thus creating one or more {@link Processor} instances) + * and distributed to multiple stream threads. + * + * The supplier should always generate a new instance each time {@link ProcessorSupplier#get()} gets called. Creating + * a single {@link Processor} object and returning the same object reference in {@link ProcessorSupplier#get()} would be + * a violation of the supplier pattern and leads to runtime exceptions. + * + * @param the type of input keys + * @param the type of input values + * @param the type of output keys + * @param the type of output values + */ +@FunctionalInterface +public interface ProcessorSupplier extends ConnectedStoreProvider, Supplier> { + + /** + * Return a newly constructed {@link Processor} instance. + * The supplier should always generate a new instance each time {@link ProcessorSupplier#get()} gets called. + *
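A hedged sketch of a conforming supplier; it assumes the UppercaseProcessor sketched earlier in these notes, and a method reference such as UppercaseProcessor::new would satisfy the same contract.

    import org.apache.kafka.streams.processor.api.Processor;
    import org.apache.kafka.streams.processor.api.ProcessorSupplier;

    public class UppercaseProcessorSupplier implements ProcessorSupplier<String, String, String, String> {

        @Override
        public Processor<String, String, String, String> get() {
            // never cache and hand out the same instance: every task needs its own processor
            return new UppercaseProcessor();
        }
    }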

                + * Creating a single {@link Processor} object and returning the same object reference in {@link ProcessorSupplier#get()} + * is a violation of the supplier pattern and leads to runtime exceptions. + * + * @return a new {@link Processor} instance + */ + Processor get(); +} diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/api/Record.java b/streams/src/main/java/org/apache/kafka/streams/processor/api/Record.java new file mode 100644 index 0000000..225b95f --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/processor/api/Record.java @@ -0,0 +1,193 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.processor.api; + +import org.apache.kafka.common.header.Headers; +import org.apache.kafka.common.header.internals.RecordHeaders; +import org.apache.kafka.streams.errors.StreamsException; + +import java.util.Objects; + +/** + * A data class representing an incoming record for processing in a {@link Processor} + * or a record to forward to downstream processors via {@link ProcessorContext}. + * + * This class encapsulates all the data attributes of a record: the key and value, but + * also the timestamp of the record and any record headers. + * + * This class is immutable, though the objects referenced in the attributes of this class + * may themselves be mutable. + * + * @param The type of the key + * @param The type of the value + */ +public class Record { + private final K key; + private final V value; + private final long timestamp; + private final Headers headers; + + /** + * The full constructor, specifying all the attributes of the record. + * + * Note: this constructor makes a copy of the headers argument. + * See {@link ProcessorContext#forward(Record)} for + * considerations around mutability of keys, values, and headers. + * + * @param key The key of the record. May be null. + * @param value The value of the record. May be null. + * @param timestamp The timestamp of the record. May not be negative. + * @param headers The headers of the record. May be null, which will cause subsequent calls + * to {@link #headers()} to return a non-null, empty, {@link Headers} collection. + * @throws IllegalArgumentException if the timestamp is negative. + * @see ProcessorContext#forward(Record) + */ + public Record(final K key, final V value, final long timestamp, final Headers headers) { + this.key = key; + this.value = value; + if (timestamp < 0) { + throw new StreamsException( + "Malformed Record", + new IllegalArgumentException("Timestamp may not be negative. Got: " + timestamp) + ); + } + this.timestamp = timestamp; + this.headers = new RecordHeaders(headers); + } + + /** + * Convenience constructor in case you do not wish to specify any headers. 
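For illustration only, a hedged sketch of constructing records with and without headers; the key, value, and header literals are arbitrary.

    import org.apache.kafka.common.header.internals.RecordHeaders;
    import org.apache.kafka.streams.processor.api.Record;

    public class RecordSketch {
        public static void main(final String[] args) {
            // convenience constructor: no headers supplied, yet headers() is non-null and empty
            final Record<String, Long> base = new Record<>("user-1", 42L, System.currentTimeMillis());

            // full constructor: the supplied headers are copied into the record
            final RecordHeaders headers = new RecordHeaders();
            headers.add("source", "example".getBytes());
            final Record<String, Long> withHeaders = new Record<>("user-1", 42L, base.timestamp(), headers);

            // the with* methods return copies; base itself is left unchanged
            final Record<String, String> renamed = base.withValue("forty-two");
            System.out.println(withHeaders + " / " + renamed);
        }
    }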
+ * Subsequent calls to {@link #headers()} will return a non-null, empty, {@link Headers} collection. + * + * @param key The key of the record. May be null. + * @param value The value of the record. May be null. + * @param timestamp The timestamp of the record. May not be negative. + * + * @throws IllegalArgumentException if the timestamp is negative. + */ + public Record(final K key, final V value, final long timestamp) { + this(key, value, timestamp, null); + } + + /** + * The key of the record. May be null. + */ + public K key() { + return key; + } + + /** + * The value of the record. May be null. + */ + public V value() { + return value; + } + + /** + * The timestamp of the record. Will never be negative. + */ + public long timestamp() { + return timestamp; + } + + /** + * The headers of the record. Never null. + */ + public Headers headers() { + return headers; + } + + /** + * A convenient way to produce a new record if you only need to change the key. + * + * Copies the attributes of this record with the key replaced. + * + * @param key The key of the result record. May be null. + * @param The type of the new record's key. + * @return A new Record instance with all the same attributes (except that the key is replaced). + */ + public Record withKey(final NewK key) { + return new Record<>(key, value, timestamp, headers); + } + + /** + * A convenient way to produce a new record if you only need to change the value. + * + * Copies the attributes of this record with the value replaced. + * + * @param value The value of the result record. + * @param The type of the new record's value. + * @return A new Record instance with all the same attributes (except that the value is replaced). + */ + public Record withValue(final NewV value) { + return new Record<>(key, value, timestamp, headers); + } + + /** + * A convenient way to produce a new record if you only need to change the timestamp. + * + * Copies the attributes of this record with the timestamp replaced. + * + * @param timestamp The timestamp of the result record. + * @return A new Record instance with all the same attributes (except that the timestamp is replaced). + */ + public Record withTimestamp(final long timestamp) { + return new Record<>(key, value, timestamp, headers); + } + + /** + * A convenient way to produce a new record if you only need to change the headers. + * + * Copies the attributes of this record with the headers replaced. + * Also makes a copy of the provided headers. + * + * See {@link ProcessorContext#forward(Record)} for + * considerations around mutability of keys, values, and headers. + * + * @param headers The headers of the result record. + * @return A new Record instance with all the same attributes (except that the headers are replaced). 
+ */ + public Record withHeaders(final Headers headers) { + return new Record<>(key, value, timestamp, headers); + } + + @Override + public String toString() { + return "Record{" + + "key=" + key + + ", value=" + value + + ", timestamp=" + timestamp + + ", headers=" + headers + + '}'; + } + + @Override + public boolean equals(final Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + final Record record = (Record) o; + return timestamp == record.timestamp && + Objects.equals(key, record.key) && + Objects.equals(value, record.value) && + Objects.equals(headers, record.headers); + } + + @Override + public int hashCode() { + return Objects.hash(key, value, timestamp, headers); + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/api/RecordMetadata.java b/streams/src/main/java/org/apache/kafka/streams/processor/api/RecordMetadata.java new file mode 100644 index 0000000..ab88b89 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/processor/api/RecordMetadata.java @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.processor.api; + +import org.apache.kafka.streams.kstream.ValueTransformerWithKeySupplier; + +public interface RecordMetadata { + /** + * Return the topic name of the current input record; could be {@code null} if it is not + * available. + * + *

For example, if this method is invoked within a {@link Punctuator#punctuate(long) + * punctuation callback}, or while processing a record that was forwarded by a punctuation + * callback, the record won't have an associated topic. + * Another example is + * {@link org.apache.kafka.streams.kstream.KTable#transformValues(ValueTransformerWithKeySupplier, String...)} + * (and siblings), that do not always guarantee to provide a valid topic name, as they might be + * executed "out-of-band" due to some internal optimizations applied by the Kafka Streams DSL.

For example, if this method is invoked within a {@link Punctuator#punctuate(long) + * punctuation callback}, or while processing a record that was forwarded by a punctuation + * callback, the record won't have an associated partition id. + * Another example is + * {@link org.apache.kafka.streams.kstream.KTable#transformValues(ValueTransformerWithKeySupplier, String...)} + * (and siblings), that do not always guarantee to provide a valid partition id, as they might be + * executed "out-of-band" due to some internal optimizations applied by the Kafka Streams DSL.

                For example, if this method is invoked within a @link Punctuator#punctuate(long) + * punctuation callback}, or while processing a record that was forwarded by a punctuation + * callback, the record won't have an associated offset. + * Another example is + * {@link org.apache.kafka.streams.kstream.KTable#transformValues(ValueTransformerWithKeySupplier, String...)} + * (and siblings), that do not always guarantee to provide a valid offset, as they might be + * executed "out-of-band" due to some internal optimizations applied by the Kafka Streams DSL. + * + * @return the offset + */ + long offset(); +} diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/AbstractProcessorContext.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/AbstractProcessorContext.java new file mode 100644 index 0000000..ad53a6e --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/AbstractProcessorContext.java @@ -0,0 +1,249 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.processor.internals; + +import org.apache.kafka.common.header.Headers; +import org.apache.kafka.common.header.internals.RecordHeaders; +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.processor.StateRestoreCallback; +import org.apache.kafka.streams.processor.StateStore; +import org.apache.kafka.streams.processor.TaskId; +import org.apache.kafka.streams.processor.api.RecordMetadata; +import org.apache.kafka.streams.processor.internals.Task.TaskType; +import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl; +import org.apache.kafka.streams.state.internals.ThreadCache; + +import java.io.File; +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; + +public abstract class AbstractProcessorContext implements InternalProcessorContext { + + private final TaskId taskId; + private final String applicationId; + private final StreamsConfig config; + private final StreamsMetricsImpl metrics; + private final Serde keySerde; + private final Serde valueSerde; + private boolean initialized; + protected ProcessorRecordContext recordContext; + protected ProcessorNode currentNode; + private long cachedSystemTimeMs; + protected ThreadCache cache; + + public AbstractProcessorContext(final TaskId taskId, + final StreamsConfig config, + final StreamsMetricsImpl metrics, + final ThreadCache cache) { + this.taskId = taskId; + this.applicationId = config.getString(StreamsConfig.APPLICATION_ID_CONFIG); + this.config = config; + this.metrics = metrics; + valueSerde = null; + keySerde = null; + this.cache = cache; + } + + protected abstract StateManager stateManager(); + + @Override + public void setSystemTimeMs(final long timeMs) { + cachedSystemTimeMs = timeMs; + } + + @Override + public long currentSystemTimeMs() { + return cachedSystemTimeMs; + } + + @Override + public String applicationId() { + return applicationId; + } + + @Override + public TaskId taskId() { + return taskId; + } + + @Override + public Serde keySerde() { + if (keySerde == null) { + return config.defaultKeySerde(); + } + return keySerde; + } + + @Override + public Serde valueSerde() { + if (valueSerde == null) { + return config.defaultValueSerde(); + } + return valueSerde; + } + + @Override + public File stateDir() { + return stateManager().baseDir(); + } + + @Override + public StreamsMetricsImpl metrics() { + return metrics; + } + + @Override + public void register(final StateStore store, + final StateRestoreCallback stateRestoreCallback) { + if (initialized) { + throw new IllegalStateException("Can only create state stores during initialization."); + } + Objects.requireNonNull(store, "store must not be null"); + stateManager().registerStore(store, stateRestoreCallback); + } + + @Override + public String topic() { + if (recordContext == null) { + // This is only exposed via the deprecated ProcessorContext, + // in which case, we're preserving the pre-existing behavior + // of returning dummy values when the record context is undefined. + // For topic, the dummy value is `null`. + return null; + } else { + return recordContext.topic(); + } + } + + @Override + public int partition() { + if (recordContext == null) { + // This is only exposed via the deprecated ProcessorContext, + // in which case, we're preserving the pre-existing behavior + // of returning dummy values when the record context is undefined. + // For partition, the dummy value is `-1`. 
+ return -1; + } else { + return recordContext.partition(); + } + } + + @Override + public long offset() { + if (recordContext == null) { + // This is only exposed via the deprecated ProcessorContext, + // in which case, we're preserving the pre-existing behavior + // of returning dummy values when the record context is undefined. + // For offset, the dummy value is `-1L`. + return -1L; + } else { + return recordContext.offset(); + } + } + + @Override + public Headers headers() { + if (recordContext == null) { + // This is only exposed via the deprecated ProcessorContext, + // in which case, we're preserving the pre-existing behavior + // of returning dummy values when the record context is undefined. + // For headers, the dummy value is an empty headers collection. + return new RecordHeaders(); + } else { + return recordContext.headers(); + } + } + + @Override + public long timestamp() { + if (recordContext == null) { + // This is only exposed via the deprecated ProcessorContext, + // in which case, we're preserving the pre-existing behavior + // of returning dummy values when the record context is undefined. + // For timestamp, the dummy value is `0L`. + return 0L; + } else { + return recordContext.timestamp(); + } + } + + @Override + public Map appConfigs() { + final Map combined = new HashMap<>(); + combined.putAll(config.originals()); + combined.putAll(config.values()); + return combined; + } + + @Override + public Map appConfigsWithPrefix(final String prefix) { + return config.originalsWithPrefix(prefix); + } + + @Override + public void setRecordContext(final ProcessorRecordContext recordContext) { + this.recordContext = recordContext; + } + + @Override + public ProcessorRecordContext recordContext() { + return recordContext; + } + + @Override + public Optional recordMetadata() { + return Optional.ofNullable(recordContext); + } + + @Override + public void setCurrentNode(final ProcessorNode currentNode) { + this.currentNode = currentNode; + } + + @Override + public ProcessorNode currentNode() { + return currentNode; + } + + @Override + public ThreadCache cache() { + return cache; + } + + @Override + public void initialize() { + initialized = true; + } + + @Override + public void uninitialize() { + initialized = false; + } + + @Override + public TaskType taskType() { + return stateManager().taskType(); + } + + @Override + public String changelogFor(final String storeName) { + return stateManager().changelogFor(storeName); + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/AbstractReadOnlyDecorator.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/AbstractReadOnlyDecorator.java new file mode 100644 index 0000000..3ec8d7f --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/AbstractReadOnlyDecorator.java @@ -0,0 +1,298 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.processor.internals; + +import org.apache.kafka.common.serialization.Serializer; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.kstream.Windowed; +import org.apache.kafka.streams.processor.ProcessorContext; +import org.apache.kafka.streams.processor.StateStore; +import org.apache.kafka.streams.processor.StateStoreContext; +import org.apache.kafka.streams.state.KeyValueIterator; +import org.apache.kafka.streams.state.KeyValueStore; +import org.apache.kafka.streams.state.SessionStore; +import org.apache.kafka.streams.state.TimestampedKeyValueStore; +import org.apache.kafka.streams.state.TimestampedWindowStore; +import org.apache.kafka.streams.state.ValueAndTimestamp; +import org.apache.kafka.streams.state.WindowStore; +import org.apache.kafka.streams.state.WindowStoreIterator; +import org.apache.kafka.streams.state.internals.WrappedStateStore; + +import java.util.List; + +abstract class AbstractReadOnlyDecorator extends WrappedStateStore { + + static final String ERROR_MESSAGE = "Global store is read only"; + + private AbstractReadOnlyDecorator(final T inner) { + super(inner); + } + + @Override + public void flush() { + throw new UnsupportedOperationException(ERROR_MESSAGE); + } + + @Deprecated + @Override + public void init(final ProcessorContext context, + final StateStore root) { + throw new UnsupportedOperationException(ERROR_MESSAGE); + } + + @Override + public void init(final StateStoreContext context, + final StateStore root) { + throw new UnsupportedOperationException(ERROR_MESSAGE); + } + + @Override + public void close() { + throw new UnsupportedOperationException(ERROR_MESSAGE); + } + + static StateStore getReadOnlyStore(final StateStore global) { + if (global instanceof TimestampedKeyValueStore) { + return new TimestampedKeyValueStoreReadOnlyDecorator<>((TimestampedKeyValueStore) global); + } else if (global instanceof KeyValueStore) { + return new KeyValueStoreReadOnlyDecorator<>((KeyValueStore) global); + } else if (global instanceof TimestampedWindowStore) { + return new TimestampedWindowStoreReadOnlyDecorator<>((TimestampedWindowStore) global); + } else if (global instanceof WindowStore) { + return new WindowStoreReadOnlyDecorator<>((WindowStore) global); + } else if (global instanceof SessionStore) { + return new SessionStoreReadOnlyDecorator<>((SessionStore) global); + } else { + return global; + } + } + + static class KeyValueStoreReadOnlyDecorator + extends AbstractReadOnlyDecorator, K, V> + implements KeyValueStore { + + private KeyValueStoreReadOnlyDecorator(final KeyValueStore inner) { + super(inner); + } + + @Override + public V get(final K key) { + return wrapped().get(key); + } + + @Override + public KeyValueIterator range(final K from, + final K to) { + return wrapped().range(from, to); + } + + @Override + public KeyValueIterator reverseRange(final K from, + final K to) { + return wrapped().reverseRange(from, to); + } + + @Override + public KeyValueIterator all() { + return wrapped().all(); + } + + @Override + public KeyValueIterator reverseAll() { + return wrapped().reverseAll(); + } 
+ + @Override + public , P> KeyValueIterator prefixScan(final P prefix, + final PS prefixKeySerializer) { + return wrapped().prefixScan(prefix, prefixKeySerializer); + } + + @Override + public long approximateNumEntries() { + return wrapped().approximateNumEntries(); + } + + @Override + public void put(final K key, + final V value) { + throw new UnsupportedOperationException(ERROR_MESSAGE); + } + + @Override + public V putIfAbsent(final K key, + final V value) { + throw new UnsupportedOperationException(ERROR_MESSAGE); + } + + @Override + public void putAll(final List> entries) { + throw new UnsupportedOperationException(ERROR_MESSAGE); + } + + @Override + public V delete(final K key) { + throw new UnsupportedOperationException(ERROR_MESSAGE); + } + } + + static class TimestampedKeyValueStoreReadOnlyDecorator + extends KeyValueStoreReadOnlyDecorator> + implements TimestampedKeyValueStore { + + private TimestampedKeyValueStoreReadOnlyDecorator(final TimestampedKeyValueStore inner) { + super(inner); + } + } + + static class WindowStoreReadOnlyDecorator + extends AbstractReadOnlyDecorator, K, V> + implements WindowStore { + + private WindowStoreReadOnlyDecorator(final WindowStore inner) { + super(inner); + } + + @Override + public void put(final K key, + final V value, + final long windowStartTimestamp) { + throw new UnsupportedOperationException(ERROR_MESSAGE); + } + + @Override + public V fetch(final K key, + final long time) { + return wrapped().fetch(key, time); + } + + @Override + @Deprecated + public WindowStoreIterator fetch(final K key, + final long timeFrom, + final long timeTo) { + return wrapped().fetch(key, timeFrom, timeTo); + } + + @Override + public WindowStoreIterator backwardFetch(final K key, + final long timeFrom, + final long timeTo) { + return wrapped().backwardFetch(key, timeFrom, timeTo); + } + + @Override + @Deprecated + public KeyValueIterator, V> fetch(final K keyFrom, + final K keyTo, + final long timeFrom, + final long timeTo) { + return wrapped().fetch(keyFrom, keyTo, timeFrom, timeTo); + } + + @Override + public KeyValueIterator, V> backwardFetch(final K keyFrom, + final K keyTo, + final long timeFrom, + final long timeTo) { + return wrapped().backwardFetch(keyFrom, keyTo, timeFrom, timeTo); + } + + @Override + public KeyValueIterator, V> all() { + return wrapped().all(); + } + + @Override + public KeyValueIterator, V> backwardAll() { + return wrapped().backwardAll(); + } + + @Override + @Deprecated + public KeyValueIterator, V> fetchAll(final long timeFrom, + final long timeTo) { + return wrapped().fetchAll(timeFrom, timeTo); + } + + @Override + public KeyValueIterator, V> backwardFetchAll(final long timeFrom, + final long timeTo) { + return wrapped().backwardFetchAll(timeFrom, timeTo); + } + } + + static class TimestampedWindowStoreReadOnlyDecorator + extends WindowStoreReadOnlyDecorator> + implements TimestampedWindowStore { + + private TimestampedWindowStoreReadOnlyDecorator(final TimestampedWindowStore inner) { + super(inner); + } + } + + static class SessionStoreReadOnlyDecorator + extends AbstractReadOnlyDecorator, K, AGG> + implements SessionStore { + + private SessionStoreReadOnlyDecorator(final SessionStore inner) { + super(inner); + } + + @Override + public KeyValueIterator, AGG> findSessions(final K key, + final long earliestSessionEndTime, + final long latestSessionStartTime) { + return wrapped().findSessions(key, earliestSessionEndTime, latestSessionStartTime); + } + + @Override + public KeyValueIterator, AGG> findSessions(final K keyFrom, + final K 
keyTo, + final long earliestSessionEndTime, + final long latestSessionStartTime) { + return wrapped().findSessions(keyFrom, keyTo, earliestSessionEndTime, latestSessionStartTime); + } + + @Override + public void remove(final Windowed sessionKey) { + throw new UnsupportedOperationException(ERROR_MESSAGE); + } + + @Override + public void put(final Windowed sessionKey, + final AGG aggregate) { + throw new UnsupportedOperationException(ERROR_MESSAGE); + } + + @Override + public AGG fetchSession(final K key, final long earliestSessionEndTime, final long latestSessionStartTime) { + return wrapped().fetchSession(key, earliestSessionEndTime, latestSessionStartTime); + } + + @Override + public KeyValueIterator, AGG> fetch(final K key) { + return wrapped().fetch(key); + } + + @Override + public KeyValueIterator, AGG> fetch(final K keyFrom, + final K keyTo) { + return wrapped().fetch(keyFrom, keyTo); + } + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/AbstractReadWriteDecorator.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/AbstractReadWriteDecorator.java new file mode 100644 index 0000000..aff099a --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/AbstractReadWriteDecorator.java @@ -0,0 +1,291 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.processor.internals; + +import org.apache.kafka.common.serialization.Serializer; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.kstream.Windowed; +import org.apache.kafka.streams.processor.ProcessorContext; +import org.apache.kafka.streams.processor.StateStore; +import org.apache.kafka.streams.processor.StateStoreContext; +import org.apache.kafka.streams.state.KeyValueIterator; +import org.apache.kafka.streams.state.KeyValueStore; +import org.apache.kafka.streams.state.SessionStore; +import org.apache.kafka.streams.state.TimestampedKeyValueStore; +import org.apache.kafka.streams.state.TimestampedWindowStore; +import org.apache.kafka.streams.state.ValueAndTimestamp; +import org.apache.kafka.streams.state.WindowStore; +import org.apache.kafka.streams.state.WindowStoreIterator; +import org.apache.kafka.streams.state.internals.WrappedStateStore; + +import java.util.List; + +abstract class AbstractReadWriteDecorator extends WrappedStateStore { + static final String ERROR_MESSAGE = "This method may only be called by Kafka Streams"; + + private AbstractReadWriteDecorator(final T inner) { + super(inner); + } + + @Deprecated + @Override + public void init(final ProcessorContext context, + final StateStore root) { + throw new UnsupportedOperationException(ERROR_MESSAGE); + } + + @Override + public void init(final StateStoreContext context, + final StateStore root) { + throw new UnsupportedOperationException(ERROR_MESSAGE); + } + + @Override + public void close() { + throw new UnsupportedOperationException(ERROR_MESSAGE); + } + + static StateStore getReadWriteStore(final StateStore store) { + if (store instanceof TimestampedKeyValueStore) { + return new TimestampedKeyValueStoreReadWriteDecorator<>((TimestampedKeyValueStore) store); + } else if (store instanceof KeyValueStore) { + return new KeyValueStoreReadWriteDecorator<>((KeyValueStore) store); + } else if (store instanceof TimestampedWindowStore) { + return new TimestampedWindowStoreReadWriteDecorator<>((TimestampedWindowStore) store); + } else if (store instanceof WindowStore) { + return new WindowStoreReadWriteDecorator<>((WindowStore) store); + } else if (store instanceof SessionStore) { + return new SessionStoreReadWriteDecorator<>((SessionStore) store); + } else { + return store; + } + } + + static class KeyValueStoreReadWriteDecorator + extends AbstractReadWriteDecorator, K, V> + implements KeyValueStore { + + KeyValueStoreReadWriteDecorator(final KeyValueStore inner) { + super(inner); + } + + @Override + public V get(final K key) { + return wrapped().get(key); + } + + @Override + public KeyValueIterator range(final K from, + final K to) { + return wrapped().range(from, to); + } + + @Override + public KeyValueIterator reverseRange(final K from, + final K to) { + return wrapped().reverseRange(from, to); + } + + @Override + public KeyValueIterator all() { + return wrapped().all(); + } + + @Override + public KeyValueIterator reverseAll() { + return wrapped().reverseAll(); + } + + @Override + public , P> KeyValueIterator prefixScan(final P prefix, + final PS prefixKeySerializer) { + return wrapped().prefixScan(prefix, prefixKeySerializer); + } + + @Override + public long approximateNumEntries() { + return wrapped().approximateNumEntries(); + } + + @Override + public void put(final K key, + final V value) { + wrapped().put(key, value); + } + + @Override + public V putIfAbsent(final K key, + final V value) { + return wrapped().putIfAbsent(key, value); + } + + @Override + public 
void putAll(final List> entries) { + wrapped().putAll(entries); + } + + @Override + public V delete(final K key) { + return wrapped().delete(key); + } + } + + static class TimestampedKeyValueStoreReadWriteDecorator + extends KeyValueStoreReadWriteDecorator> + implements TimestampedKeyValueStore { + + TimestampedKeyValueStoreReadWriteDecorator(final TimestampedKeyValueStore inner) { + super(inner); + } + } + + static class WindowStoreReadWriteDecorator + extends AbstractReadWriteDecorator, K, V> + implements WindowStore { + + WindowStoreReadWriteDecorator(final WindowStore inner) { + super(inner); + } + + @Override + public void put(final K key, + final V value, + final long windowStartTimestamp) { + wrapped().put(key, value, windowStartTimestamp); + } + + @Override + public V fetch(final K key, + final long time) { + return wrapped().fetch(key, time); + } + + @Override + public WindowStoreIterator fetch(final K key, + final long timeFrom, + final long timeTo) { + return wrapped().fetch(key, timeFrom, timeTo); + } + + @Override + public WindowStoreIterator backwardFetch(final K key, + final long timeFrom, + final long timeTo) { + return wrapped().backwardFetch(key, timeFrom, timeTo); + } + + @Override + public KeyValueIterator, V> fetch(final K keyFrom, + final K keyTo, + final long timeFrom, + final long timeTo) { + return wrapped().fetch(keyFrom, keyTo, timeFrom, timeTo); + } + + @Override + public KeyValueIterator, V> backwardFetch(final K keyFrom, + final K keyTo, + final long timeFrom, + final long timeTo) { + return wrapped().backwardFetch(keyFrom, keyTo, timeFrom, timeTo); + } + + @Override + public KeyValueIterator, V> fetchAll(final long timeFrom, + final long timeTo) { + return wrapped().fetchAll(timeFrom, timeTo); + } + + @Override + public KeyValueIterator, V> backwardFetchAll(final long timeFrom, + final long timeTo) { + return wrapped().backwardFetchAll(timeFrom, timeTo); + } + + @Override + public KeyValueIterator, V> all() { + return wrapped().all(); + } + + @Override + public KeyValueIterator, V> backwardAll() { + return wrapped().backwardAll(); + } + } + + static class TimestampedWindowStoreReadWriteDecorator + extends WindowStoreReadWriteDecorator> + implements TimestampedWindowStore { + + TimestampedWindowStoreReadWriteDecorator(final TimestampedWindowStore inner) { + super(inner); + } + } + + static class SessionStoreReadWriteDecorator + extends AbstractReadWriteDecorator, K, AGG> + implements SessionStore { + + SessionStoreReadWriteDecorator(final SessionStore inner) { + super(inner); + } + + @Override + public KeyValueIterator, AGG> findSessions(final K key, + final long earliestSessionEndTime, + final long latestSessionStartTime) { + return wrapped().findSessions(key, earliestSessionEndTime, latestSessionStartTime); + } + + @Override + public KeyValueIterator, AGG> findSessions(final K keyFrom, + final K keyTo, + final long earliestSessionEndTime, + final long latestSessionStartTime) { + return wrapped().findSessions(keyFrom, keyTo, earliestSessionEndTime, latestSessionStartTime); + } + + @Override + public void remove(final Windowed sessionKey) { + wrapped().remove(sessionKey); + } + + @Override + public void put(final Windowed sessionKey, + final AGG aggregate) { + wrapped().put(sessionKey, aggregate); + } + + @Override + public AGG fetchSession(final K key, + final long earliestSessionEndTime, + final long latestSessionStartTime) { + return wrapped().fetchSession(key, earliestSessionEndTime, latestSessionStartTime); + } + + @Override + public KeyValueIterator, 
AGG> fetch(final K key) { + return wrapped().fetch(key); + } + + @Override + public KeyValueIterator, AGG> fetch(final K keyFrom, + final K keyTo) { + return wrapped().fetch(keyFrom, keyTo); + } + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/AbstractTask.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/AbstractTask.java new file mode 100644 index 0000000..4e652a6 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/AbstractTask.java @@ -0,0 +1,202 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.processor.internals; + +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.errors.TimeoutException; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.errors.StreamsException; +import org.apache.kafka.streams.errors.TaskMigratedException; +import org.apache.kafka.streams.processor.StateStore; +import org.apache.kafka.streams.processor.TaskId; +import org.slf4j.Logger; + +import java.util.Collection; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import static org.apache.kafka.streams.processor.internals.Task.State.CLOSED; +import static org.apache.kafka.streams.processor.internals.Task.State.CREATED; + +public abstract class AbstractTask implements Task { + private final static long NO_DEADLINE = -1L; + + private Task.State state = CREATED; + private long deadlineMs = NO_DEADLINE; + + protected Set inputPartitions; + protected final Logger log; + protected final LogContext logContext; + protected final String logPrefix; + + /** + * If the checkpoint has not been loaded from the file yet (null), then we should not overwrite the checkpoint; + * If the checkpoint has been loaded from the file and has never been re-written (empty map), then we should re-write the checkpoint; + * If the checkpoint has been loaded from the file but has not been updated since, then we do not need to checkpoint; + * If the checkpoint has been loaded from the file and has been updated since, then we could overwrite the checkpoint; + */ + protected Map offsetSnapshotSinceLastFlush = null; + + protected final TaskId id; + protected final ProcessorTopology topology; + protected final StateDirectory stateDirectory; + protected final ProcessorStateManager stateMgr; + private final long taskTimeoutMs; + + AbstractTask(final TaskId id, + final ProcessorTopology topology, + final StateDirectory stateDirectory, + final ProcessorStateManager stateMgr, + final Set inputPartitions, + final long taskTimeoutMs, + final String taskType, + final Class clazz) { + this.id = id; + this.stateMgr = 
stateMgr; + this.topology = topology; + this.inputPartitions = inputPartitions; + this.stateDirectory = stateDirectory; + this.taskTimeoutMs = taskTimeoutMs; + + final String threadIdPrefix = String.format("stream-thread [%s] ", Thread.currentThread().getName()); + logPrefix = threadIdPrefix + String.format("%s [%s] ", taskType, id); + logContext = new LogContext(logPrefix); + log = logContext.logger(clazz); + } + + /** + * The following exceptions maybe thrown from the state manager flushing call + * + * @throws TaskMigratedException recoverable error sending changelog records that would cause the task to be removed + * @throws StreamsException fatal error when flushing the state store, for example sending changelog records failed + * or flushing state store get IO errors; such error should cause the thread to die + */ + protected void maybeWriteCheckpoint(final boolean enforceCheckpoint) { + final Map offsetSnapshot = stateMgr.changelogOffsets(); + if (StateManagerUtil.checkpointNeeded(enforceCheckpoint, offsetSnapshotSinceLastFlush, offsetSnapshot)) { + // the state's current offset would be used to checkpoint + stateMgr.flush(); + stateMgr.checkpoint(); + offsetSnapshotSinceLastFlush = new HashMap<>(offsetSnapshot); + } + } + + + @Override + public TaskId id() { + return id; + } + + @Override + public Set inputPartitions() { + return inputPartitions; + } + + @Override + public Collection changelogPartitions() { + return stateMgr.changelogPartitions(); + } + + @Override + public void markChangelogAsCorrupted(final Collection partitions) { + stateMgr.markChangelogAsCorrupted(partitions); + } + + @Override + public StateStore getStore(final String name) { + return stateMgr.getStore(name); + } + + @Override + public final Task.State state() { + return state; + } + + @Override + public void revive() { + if (state == CLOSED) { + clearTaskTimeout(); + transitionTo(CREATED); + } else { + throw new IllegalStateException("Illegal state " + state() + " while reviving task " + id); + } + } + + final void transitionTo(final Task.State newState) { + final State oldState = state(); + + if (oldState.isValidTransition(newState)) { + state = newState; + } else { + throw new IllegalStateException("Invalid transition from " + oldState + " to " + newState); + } + } + + @Override + public void updateInputPartitions(final Set topicPartitions, final Map> allTopologyNodesToSourceTopics) { + this.inputPartitions = topicPartitions; + topology.updateSourceTopics(allTopologyNodesToSourceTopics); + } + + @Override + public void maybeInitTaskTimeoutOrThrow(final long currentWallClockMs, + final Exception cause) { + if (deadlineMs == NO_DEADLINE) { + deadlineMs = currentWallClockMs + taskTimeoutMs; + } else if (currentWallClockMs > deadlineMs) { + final String errorMessage = String.format( + "Task %s did not make progress within %d ms. Adjust `%s` if needed.", + id, + currentWallClockMs - deadlineMs + taskTimeoutMs, + StreamsConfig.TASK_TIMEOUT_MS_CONFIG + ); + + if (cause != null) { + throw new StreamsException(new TimeoutException(errorMessage, cause), id); + } else { + throw new StreamsException(new TimeoutException(errorMessage), id); + } + } + + if (cause != null) { + log.debug( + String.format( + "Task did not make progress. Remaining time to deadline %d; retrying.", + deadlineMs - currentWallClockMs + ), + cause + ); + } else { + log.debug( + "Task did not make progress. 
Remaining time to deadline {}; retrying.", + deadlineMs - currentWallClockMs + ); + } + + } + + @Override + public void clearTaskTimeout() { + if (deadlineMs != NO_DEADLINE) { + log.debug("Clearing task timeout."); + deadlineMs = NO_DEADLINE; + } + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/ActiveTaskCreator.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ActiveTaskCreator.java new file mode 100644 index 0000000..75ec24c --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ActiveTaskCreator.java @@ -0,0 +1,339 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.processor.internals; + +import org.apache.kafka.clients.consumer.Consumer; +import org.apache.kafka.common.Metric; +import org.apache.kafka.common.MetricName; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.metrics.Sensor; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.streams.KafkaClientSupplier; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.errors.StreamsException; +import org.apache.kafka.streams.processor.TaskId; +import org.apache.kafka.streams.processor.internals.Task.TaskType; +import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl; +import org.apache.kafka.streams.processor.internals.metrics.ThreadMetrics; +import org.apache.kafka.streams.state.internals.ThreadCache; +import org.slf4j.Logger; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.UUID; +import java.util.stream.Collectors; + +import static org.apache.kafka.common.utils.Utils.filterMap; +import static org.apache.kafka.streams.processor.internals.ClientUtils.getTaskProducerClientId; +import static org.apache.kafka.streams.processor.internals.ClientUtils.getThreadProducerClientId; +import static org.apache.kafka.streams.processor.internals.StreamThread.ProcessingMode.EXACTLY_ONCE_ALPHA; +import static org.apache.kafka.streams.processor.internals.StreamThread.ProcessingMode.EXACTLY_ONCE_V2; + +class ActiveTaskCreator { + private final TopologyMetadata topologyMetadata; + private final StreamsConfig config; + private final StreamsMetricsImpl streamsMetrics; + private final StateDirectory stateDirectory; + private final ChangelogReader storeChangelogReader; + private final ThreadCache cache; + private final Time time; + private final KafkaClientSupplier clientSupplier; + private final String threadId; + private final Logger log; + private final Sensor createTaskSensor; + 
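+    // Descriptive note (added for clarity, grounded in the constructor below): exactly one of the next two
+    // fields is populated, depending on the processing mode. Under EXACTLY_ONCE_ALPHA each task gets its own
+    // producer (kept in taskProducers, threadProducer stays null); for at-least-once and EXACTLY_ONCE_V2 all
+    // tasks share the single thread-level producer and taskProducers stays empty.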
private final StreamsProducer threadProducer; + private final Map taskProducers; + private final StreamThread.ProcessingMode processingMode; + + // Tasks may have been assigned for a NamedTopology that is not yet known by this host. When that occurs we stash + // these unknown tasks until either the corresponding NamedTopology is added and we can create them at last, or + // we receive a new assignment and they are revoked from the thread. + private final Map> unknownTasksToBeCreated = new HashMap<>(); + + ActiveTaskCreator(final TopologyMetadata topologyMetadata, + final StreamsConfig config, + final StreamsMetricsImpl streamsMetrics, + final StateDirectory stateDirectory, + final ChangelogReader storeChangelogReader, + final ThreadCache cache, + final Time time, + final KafkaClientSupplier clientSupplier, + final String threadId, + final UUID processId, + final Logger log) { + this.topologyMetadata = topologyMetadata; + this.config = config; + this.streamsMetrics = streamsMetrics; + this.stateDirectory = stateDirectory; + this.storeChangelogReader = storeChangelogReader; + this.cache = cache; + this.time = time; + this.clientSupplier = clientSupplier; + this.threadId = threadId; + this.log = log; + + createTaskSensor = ThreadMetrics.createTaskSensor(threadId, streamsMetrics); + processingMode = StreamThread.processingMode(config); + + if (processingMode == EXACTLY_ONCE_ALPHA) { + threadProducer = null; + taskProducers = new HashMap<>(); + } else { // non-eos and eos-v2 + log.info("Creating thread producer client"); + + final String threadIdPrefix = String.format("stream-thread [%s] ", Thread.currentThread().getName()); + final LogContext logContext = new LogContext(threadIdPrefix); + + threadProducer = new StreamsProducer( + config, + threadId, + clientSupplier, + null, + processId, + logContext, + time); + taskProducers = Collections.emptyMap(); + } + } + + public void reInitializeThreadProducer() { + threadProducer.resetProducer(); + } + + StreamsProducer streamsProducerForTask(final TaskId taskId) { + if (processingMode != EXACTLY_ONCE_ALPHA) { + throw new IllegalStateException("Expected EXACTLY_ONCE to be enabled, but the processing mode was " + processingMode); + } + + final StreamsProducer taskProducer = taskProducers.get(taskId); + if (taskProducer == null) { + throw new IllegalStateException("Unknown TaskId: " + taskId); + } + return taskProducer; + } + + StreamsProducer threadProducer() { + if (processingMode != EXACTLY_ONCE_V2) { + throw new IllegalStateException("Expected EXACTLY_ONCE_V2 to be enabled, but the processing mode was " + processingMode); + } + return threadProducer; + } + + void removeRevokedUnknownTasks(final Set assignedTasks) { + unknownTasksToBeCreated.keySet().retainAll(assignedTasks); + } + + Map> uncreatedTasksForTopologies(final Set currentTopologies) { + return filterMap(unknownTasksToBeCreated, t -> currentTopologies.contains(t.getKey().topologyName())); + } + + // TODO: change return type to `StreamTask` + Collection createTasks(final Consumer consumer, + final Map> tasksToBeCreated) { + // TODO: change type to `StreamTask` + final List createdTasks = new ArrayList<>(); + final Map> newUnknownTasks = new HashMap<>(); + + for (final Map.Entry> newTaskAndPartitions : tasksToBeCreated.entrySet()) { + final TaskId taskId = newTaskAndPartitions.getKey(); + final Set partitions = newTaskAndPartitions.getValue(); + + final LogContext logContext = getLogContext(taskId); + + final ProcessorTopology topology = topologyMetadata.buildSubtopology(taskId); + if 
(topology == null) { + // task belongs to a named topology that hasn't been added yet, wait until it has to create this + newUnknownTasks.put(taskId, partitions); + continue; + } + + final ProcessorStateManager stateManager = new ProcessorStateManager( + taskId, + Task.TaskType.ACTIVE, + StreamThread.eosEnabled(config), + logContext, + stateDirectory, + storeChangelogReader, + topology.storeToChangelogTopic(), + partitions + ); + + final InternalProcessorContext context = new ProcessorContextImpl( + taskId, + config, + stateManager, + streamsMetrics, + cache + ); + + createdTasks.add( + createActiveTask( + taskId, + partitions, + consumer, + logContext, + topology, + stateManager, + context + ) + ); + unknownTasksToBeCreated.remove(taskId); + } + if (!newUnknownTasks.isEmpty()) { + log.info("Delaying creation of tasks not yet known by this instance: {}", newUnknownTasks.keySet()); + unknownTasksToBeCreated.putAll(newUnknownTasks); + } + return createdTasks; + } + + + StreamTask createActiveTaskFromStandby(final StandbyTask standbyTask, + final Set inputPartitions, + final Consumer consumer) { + final InternalProcessorContext context = standbyTask.processorContext(); + final ProcessorStateManager stateManager = standbyTask.stateMgr; + final LogContext logContext = getLogContext(standbyTask.id); + + standbyTask.closeCleanAndRecycleState(); + stateManager.transitionTaskType(TaskType.ACTIVE, logContext); + + return createActiveTask( + standbyTask.id, + inputPartitions, + consumer, + logContext, + topologyMetadata.buildSubtopology(standbyTask.id), + stateManager, + context + ); + } + + private StreamTask createActiveTask(final TaskId taskId, + final Set inputPartitions, + final Consumer consumer, + final LogContext logContext, + final ProcessorTopology topology, + final ProcessorStateManager stateManager, + final InternalProcessorContext context) { + final StreamsProducer streamsProducer; + if (processingMode == StreamThread.ProcessingMode.EXACTLY_ONCE_ALPHA) { + log.info("Creating producer client for task {}", taskId); + streamsProducer = new StreamsProducer( + config, + threadId, + clientSupplier, + taskId, + null, + logContext, + time); + taskProducers.put(taskId, streamsProducer); + } else { + streamsProducer = threadProducer; + } + + final RecordCollector recordCollector = new RecordCollectorImpl( + logContext, + taskId, + streamsProducer, + config.defaultProductionExceptionHandler(), + streamsMetrics + ); + + final StreamTask task = new StreamTask( + taskId, + inputPartitions, + topology, + consumer, + config, + streamsMetrics, + stateDirectory, + cache, + time, + stateManager, + recordCollector, + context, + logContext + ); + + log.trace("Created task {} with assigned partitions {}", taskId, inputPartitions); + createTaskSensor.record(); + return task; + } + + void closeThreadProducerIfNeeded() { + if (threadProducer != null) { + try { + threadProducer.close(); + } catch (final RuntimeException e) { + throw new StreamsException("Thread producer encounter error trying to close.", e); + } + } + } + + void closeAndRemoveTaskProducerIfNeeded(final TaskId id) { + final StreamsProducer taskProducer = taskProducers.remove(id); + if (taskProducer != null) { + try { + taskProducer.close(); + } catch (final RuntimeException e) { + throw new StreamsException("[" + id + "] task producer encounter error trying to close.", e, id); + } + } + } + + Map producerMetrics() { + // When EOS is turned on, each task will have its own producer client + // and the producer object passed in here will be null. 
We would then iterate through + // all the active tasks and add their metrics to the output metrics map. + final Collection producers = threadProducer != null ? + Collections.singleton(threadProducer) : + taskProducers.values(); + return ClientUtils.producerMetrics(producers); + } + + Set producerClientIds() { + if (threadProducer != null) { + return Collections.singleton(getThreadProducerClientId(threadId)); + } else { + return taskProducers.keySet() + .stream() + .map(taskId -> getTaskProducerClientId(threadId, taskId)) + .collect(Collectors.toSet()); + } + } + + private LogContext getLogContext(final TaskId taskId) { + final String threadIdPrefix = String.format("stream-thread [%s] ", Thread.currentThread().getName()); + final String logPrefix = threadIdPrefix + String.format("%s [%s] ", "task", taskId); + return new LogContext(logPrefix); + } + + public double totalProducerBlockedTime() { + if (threadProducer != null) { + return threadProducer.totalBlockedTime(); + } + return taskProducers.values().stream() + .mapToDouble(StreamsProducer::totalBlockedTime) + .sum(); + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/ChangelogReader.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ChangelogReader.java new file mode 100644 index 0000000..9c62dd1 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ChangelogReader.java @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.processor.internals; + +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.streams.processor.TaskId; + +import java.util.Map; +import java.util.Set; + +/** + * See {@link StoreChangelogReader}. 
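+ * A ChangelogReader restores the registered state stores by consuming records from their changelog
+ * partitions, and can switch between restoring the changelogs of active tasks and keeping the
+ * changelogs of standby tasks up to date.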
+ */
+public interface ChangelogReader extends ChangelogRegister {
+    /**
+     * Restore all registered state stores by reading from their changelogs.
+     */
+    void restore(final Map<TaskId, Task> tasks);
+
+    /**
+     * Transition to the mode of restoring active changelogs.
+     */
+    void enforceRestoreActive();
+
+    /**
+     * Transition to the mode of updating standby changelogs.
+     */
+    void transitToUpdateStandby();
+
+    /**
+     * @return the changelog partitions that have completed restoration
+     */
+    Set<TopicPartition> completedChangelogs();
+
+    /**
+     * Clear all partitions
+     */
+    void clear();
+
+    /**
+     * @return whether the changelog reader has just been cleared or is uninitialized
+     */
+    boolean isEmpty();
+}
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/ChangelogRegister.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ChangelogRegister.java
new file mode 100644
index 0000000..3f5f977
--- /dev/null
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ChangelogRegister.java
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.streams.processor.internals;
+
+import java.util.Collection;
+import org.apache.kafka.common.TopicPartition;
+
+/**
+ * See {@link StoreChangelogReader}.
+ */
+public interface ChangelogRegister {
+    /**
+     * Register a state store for restoration.
+     *
+     * @param partition the state store's changelog partition for restoring
+     * @param stateManager the state manager used for restoring (one per task)
+     */
+    void register(final TopicPartition partition, final ProcessorStateManager stateManager);
+
+    /**
+     * Unregisters and removes the passed-in partitions from the set of changelogs.
+     * @param removedPartitions the set of partitions to remove
+     */
+    void unregister(final Collection<TopicPartition> removedPartitions);
+}
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/ChangelogTopics.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ChangelogTopics.java
new file mode 100644
index 0000000..a2ad123
--- /dev/null
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ChangelogTopics.java
@@ -0,0 +1,134 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.processor.internals; + +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.streams.processor.TaskId; +import org.apache.kafka.streams.processor.internals.InternalTopologyBuilder.TopicsInfo; +import org.apache.kafka.streams.processor.internals.TopologyMetadata.Subtopology; + +import org.slf4j.Logger; + +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; + +import static org.apache.kafka.streams.processor.internals.assignment.StreamsAssignmentProtocolVersions.UNKNOWN; + +public class ChangelogTopics { + + private final InternalTopicManager internalTopicManager; + private final Map topicGroups; + private final Map> tasksForTopicGroup; + private final Map> changelogPartitionsForStatefulTask = new HashMap<>(); + private final Map> preExistingChangelogPartitionsForTask = new HashMap<>(); + private final Set preExistingNonSourceTopicBasedChangelogPartitions = new HashSet<>(); + private final Set sourceTopicBasedChangelogTopics = new HashSet<>(); + private final Set preExsitingSourceTopicBasedChangelogPartitions = new HashSet<>(); + private final Logger log; + + public ChangelogTopics(final InternalTopicManager internalTopicManager, + final Map topicGroups, + final Map> tasksForTopicGroup, + final String logPrefix) { + this.internalTopicManager = internalTopicManager; + this.topicGroups = topicGroups; + this.tasksForTopicGroup = tasksForTopicGroup; + final LogContext logContext = new LogContext(logPrefix); + log = logContext.logger(getClass()); + } + + public void setup() { + // add tasks to state change log topic subscribers + final Map changelogTopicMetadata = new HashMap<>(); + for (final Map.Entry entry : topicGroups.entrySet()) { + final Subtopology subtopology = entry.getKey(); + final TopicsInfo topicsInfo = entry.getValue(); + + final Set topicGroupTasks = tasksForTopicGroup.get(subtopology); + if (topicGroupTasks == null) { + log.debug("No tasks found for subtopology {}", subtopology); + continue; + } else if (topicsInfo.stateChangelogTopics.isEmpty()) { + continue; + } + + for (final TaskId task : topicGroupTasks) { + final Set changelogTopicPartitions = topicsInfo.stateChangelogTopics + .keySet() + .stream() + .map(topic -> new TopicPartition(topic, task.partition())) + .collect(Collectors.toSet()); + changelogPartitionsForStatefulTask.put(task, changelogTopicPartitions); + } + + for (final InternalTopicConfig topicConfig : topicsInfo.nonSourceChangelogTopics()) { + // the expected number of partitions is the max value of TaskId.partition + 1 + int numPartitions = UNKNOWN; + for (final TaskId task : topicGroupTasks) { + if (numPartitions < task.partition() + 1) { + numPartitions = task.partition() + 1; + } + } + topicConfig.setNumberOfPartitions(numPartitions); + changelogTopicMetadata.put(topicConfig.name(), topicConfig); + } + sourceTopicBasedChangelogTopics.addAll(topicsInfo.sourceTopicChangelogs()); + } + + final Set newlyCreatedChangelogTopics = 
internalTopicManager.makeReady(changelogTopicMetadata); + log.debug("Created state changelog topics {} from the parsed topology.", changelogTopicMetadata.values()); + + for (final Map.Entry> entry : changelogPartitionsForStatefulTask.entrySet()) { + final TaskId taskId = entry.getKey(); + final Set topicPartitions = entry.getValue(); + for (final TopicPartition topicPartition : topicPartitions) { + if (!newlyCreatedChangelogTopics.contains(topicPartition.topic())) { + preExistingChangelogPartitionsForTask.computeIfAbsent(taskId, task -> new HashSet<>()).add(topicPartition); + if (!sourceTopicBasedChangelogTopics.contains(topicPartition.topic())) { + preExistingNonSourceTopicBasedChangelogPartitions.add(topicPartition); + } else { + preExsitingSourceTopicBasedChangelogPartitions.add(topicPartition); + } + } + } + } + } + + public Set preExistingNonSourceTopicBasedPartitions() { + return Collections.unmodifiableSet(preExistingNonSourceTopicBasedChangelogPartitions); + } + + public Set preExistingPartitionsFor(final TaskId taskId) { + if (preExistingChangelogPartitionsForTask.containsKey(taskId)) { + return Collections.unmodifiableSet(preExistingChangelogPartitionsForTask.get(taskId)); + } + return Collections.emptySet(); + } + + public Set preExistingSourceTopicBasedPartitions() { + return Collections.unmodifiableSet(preExsitingSourceTopicBasedChangelogPartitions); + } + + public Set statefulTaskIds() { + return Collections.unmodifiableSet(changelogPartitionsForStatefulTask.keySet()); + } +} \ No newline at end of file diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/ClientUtils.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ClientUtils.java new file mode 100644 index 0000000..a5807b5 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ClientUtils.java @@ -0,0 +1,164 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.processor.internals; + +import org.apache.kafka.clients.admin.Admin; +import org.apache.kafka.clients.admin.ListOffsetsResult.ListOffsetsResultInfo; +import org.apache.kafka.clients.admin.OffsetSpec; +import org.apache.kafka.clients.consumer.Consumer; +import org.apache.kafka.clients.consumer.ConsumerConfig; +import org.apache.kafka.common.KafkaException; +import org.apache.kafka.common.KafkaFuture; +import org.apache.kafka.common.Metric; +import org.apache.kafka.common.MetricName; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.errors.TimeoutException; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.errors.StreamsException; +import org.apache.kafka.streams.processor.TaskId; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Collection; +import java.util.Collections; +import java.util.LinkedHashMap; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.ExecutionException; +import java.util.function.Function; +import java.util.stream.Collectors; + +public class ClientUtils { + private static final Logger LOG = LoggerFactory.getLogger(ClientUtils.class); + + public static final class QuietStreamsConfig extends StreamsConfig { + public QuietStreamsConfig(final Map props) { + super(props, false); + } + } + + public static final class QuietConsumerConfig extends ConsumerConfig { + public QuietConsumerConfig(final Map props) { + super(props, false); + } + } + + + // currently admin client is shared among all threads + public static String getSharedAdminClientId(final String clientId) { + return clientId + "-admin"; + } + + public static String getConsumerClientId(final String threadClientId) { + return threadClientId + "-consumer"; + } + + public static String getRestoreConsumerClientId(final String threadClientId) { + return threadClientId + "-restore-consumer"; + } + + public static String getThreadProducerClientId(final String threadClientId) { + return threadClientId + "-producer"; + } + + public static String getTaskProducerClientId(final String threadClientId, final TaskId taskId) { + return threadClientId + "-" + taskId + "-producer"; + } + + public static Map consumerMetrics(final Consumer mainConsumer, + final Consumer restoreConsumer) { + final Map consumerMetrics = mainConsumer.metrics(); + final Map restoreConsumerMetrics = restoreConsumer.metrics(); + final LinkedHashMap result = new LinkedHashMap<>(); + result.putAll(consumerMetrics); + result.putAll(restoreConsumerMetrics); + return result; + } + + public static Map adminClientMetrics(final Admin adminClient) { + final Map adminClientMetrics = adminClient.metrics(); + return new LinkedHashMap<>(adminClientMetrics); + } + + public static Map producerMetrics(final Collection producers) { + final Map result = new LinkedHashMap<>(); + for (final StreamsProducer producer : producers) { + final Map producerMetrics = producer.metrics(); + if (producerMetrics != null) { + result.putAll(producerMetrics); + } + } + return result; + } + + /** + * @throws StreamsException if the consumer throws an exception + * @throws org.apache.kafka.common.errors.TimeoutException if the request times out + */ + public static Map fetchCommittedOffsets(final Set partitions, + final Consumer consumer) { + if (partitions.isEmpty()) { + return Collections.emptyMap(); + } + + final Map committedOffsets; + try { + // those which do not have a committed offset would default to 0 + committedOffsets = 
consumer.committed(partitions).entrySet().stream() + .collect(Collectors.toMap(Map.Entry::getKey, e -> e.getValue() == null ? 0L : e.getValue().offset())); + } catch (final TimeoutException timeoutException) { + LOG.warn("The committed offsets request timed out, try increasing the consumer client's default.api.timeout.ms", timeoutException); + throw timeoutException; + } catch (final KafkaException fatal) { + LOG.warn("The committed offsets request failed.", fatal); + throw new StreamsException(String.format("Failed to retrieve end offsets for %s", partitions), fatal); + } + + return committedOffsets; + } + + public static KafkaFuture> fetchEndOffsetsFuture(final Collection partitions, + final Admin adminClient) { + return adminClient.listOffsets( + partitions.stream().collect(Collectors.toMap(Function.identity(), tp -> OffsetSpec.latest())) + ).all(); + } + + /** + * A helper method that wraps the {@code Future#get} call and rethrows any thrown exception as a StreamsException + * @throws StreamsException if the admin client request throws an exception + */ + public static Map getEndOffsets(final KafkaFuture> endOffsetsFuture) { + try { + return endOffsetsFuture.get(); + } catch (final RuntimeException | InterruptedException | ExecutionException e) { + LOG.warn("The listOffsets request failed.", e); + throw new StreamsException("Unable to obtain end offsets from kafka", e); + } + } + + /** + * @throws StreamsException if the admin client request throws an exception + */ + public static Map fetchEndOffsets(final Collection partitions, + final Admin adminClient) { + if (partitions.isEmpty()) { + return Collections.emptyMap(); + } + return getEndOffsets(fetchEndOffsetsFuture(partitions, adminClient)); + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/DefaultKafkaClientSupplier.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/DefaultKafkaClientSupplier.java new file mode 100644 index 0000000..15fec76 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/DefaultKafkaClientSupplier.java @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.processor.internals; + +import java.util.Map; + +import org.apache.kafka.clients.admin.Admin; +import org.apache.kafka.clients.consumer.Consumer; +import org.apache.kafka.clients.consumer.KafkaConsumer; +import org.apache.kafka.clients.producer.KafkaProducer; +import org.apache.kafka.clients.producer.Producer; +import org.apache.kafka.common.serialization.ByteArrayDeserializer; +import org.apache.kafka.common.serialization.ByteArraySerializer; +import org.apache.kafka.streams.KafkaClientSupplier; + +public class DefaultKafkaClientSupplier implements KafkaClientSupplier { + @Override + public Admin getAdmin(final Map config) { + // create a new client upon each call; but expect this call to be only triggered once so this should be fine + return Admin.create(config); + } + + @Override + public Producer getProducer(final Map config) { + return new KafkaProducer<>(config, new ByteArraySerializer(), new ByteArraySerializer()); + } + + @Override + public Consumer getConsumer(final Map config) { + return new KafkaConsumer<>(config, new ByteArrayDeserializer(), new ByteArrayDeserializer()); + } + + @Override + public Consumer getRestoreConsumer(final Map config) { + return new KafkaConsumer<>(config, new ByteArrayDeserializer(), new ByteArrayDeserializer()); + } + + @Override + public Consumer getGlobalConsumer(final Map config) { + return new KafkaConsumer<>(config, new ByteArrayDeserializer(), new ByteArrayDeserializer()); + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/DefaultStreamPartitioner.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/DefaultStreamPartitioner.java new file mode 100644 index 0000000..a90a028 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/DefaultStreamPartitioner.java @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.processor.internals; + +import org.apache.kafka.clients.producer.internals.DefaultPartitioner; +import org.apache.kafka.common.Cluster; +import org.apache.kafka.common.serialization.Serializer; +import org.apache.kafka.streams.processor.StreamPartitioner; + +public class DefaultStreamPartitioner implements StreamPartitioner { + + private final Cluster cluster; + private final Serializer keySerializer; + private final DefaultPartitioner defaultPartitioner; + + public DefaultStreamPartitioner(final Serializer keySerializer, final Cluster cluster) { + this.cluster = cluster; + this.keySerializer = keySerializer; + this.defaultPartitioner = new DefaultPartitioner(); + } + + @Override + public Integer partition(final String topic, final K key, final V value, final int numPartitions) { + final byte[] keyBytes = keySerializer.serialize(topic, key); + return defaultPartitioner.partition(topic, key, keyBytes, value, null, cluster, numPartitions); + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/ForwardingDisabledProcessorContext.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ForwardingDisabledProcessorContext.java new file mode 100644 index 0000000..a2a2f09 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ForwardingDisabledProcessorContext.java @@ -0,0 +1,159 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.processor.internals; + +import org.apache.kafka.common.header.Headers; +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.streams.StreamsMetrics; +import org.apache.kafka.streams.errors.StreamsException; +import org.apache.kafka.streams.processor.Cancellable; +import org.apache.kafka.streams.processor.ProcessorContext; +import org.apache.kafka.streams.processor.PunctuationType; +import org.apache.kafka.streams.processor.Punctuator; +import org.apache.kafka.streams.processor.StateRestoreCallback; +import org.apache.kafka.streams.processor.StateStore; +import org.apache.kafka.streams.processor.TaskId; +import org.apache.kafka.streams.processor.To; + +import java.io.File; +import java.time.Duration; +import java.util.Map; +import java.util.Objects; + +/** + * {@code ProcessorContext} implementation that will throw on any forward call. + */ +public final class ForwardingDisabledProcessorContext implements ProcessorContext { + private final ProcessorContext delegate; + + private static final String EXPLANATION = "ProcessorContext#forward() is not supported from this context, " + + "as the framework must ensure the key is not changed (#forward allows changing the key on " + + "messages which are sent). 
Try another function, which doesn't allow the key to be changed " + + "(for example - #tranformValues)."; + + public ForwardingDisabledProcessorContext(final ProcessorContext delegate) { + this.delegate = Objects.requireNonNull(delegate, "delegate"); + } + + @Override + public String applicationId() { + return delegate.applicationId(); + } + + @Override + public TaskId taskId() { + return delegate.taskId(); + } + + @Override + public Serde keySerde() { + return delegate.keySerde(); + } + + @Override + public Serde valueSerde() { + return delegate.valueSerde(); + } + + @Override + public File stateDir() { + return delegate.stateDir(); + } + + @Override + public StreamsMetrics metrics() { + return delegate.metrics(); + } + + @Override + public void register(final StateStore store, + final StateRestoreCallback stateRestoreCallback) { + delegate.register(store, stateRestoreCallback); + } + + @Override + public S getStateStore(final String name) { + return delegate.getStateStore(name); + } + + @Override + public Cancellable schedule(final Duration interval, + final PunctuationType type, + final Punctuator callback) throws IllegalArgumentException { + return delegate.schedule(interval, type, callback); + } + + @Override + public void forward(final K key, final V value) { + throw new StreamsException(EXPLANATION); + } + + @Override + public void forward(final K key, final V value, final To to) { + throw new StreamsException(EXPLANATION); + } + + @Override + public void commit() { + delegate.commit(); + } + + @Override + public String topic() { + return delegate.topic(); + } + + @Override + public int partition() { + return delegate.partition(); + } + + @Override + public long offset() { + return delegate.offset(); + } + + @Override + public Headers headers() { + return delegate.headers(); + } + + @Override + public long timestamp() { + return delegate.timestamp(); + } + + @Override + public Map appConfigs() { + return delegate.appConfigs(); + } + + @Override + public Map appConfigsWithPrefix(final String prefix) { + return delegate.appConfigsWithPrefix(prefix); + } + + @Override + public long currentSystemTimeMs() { + return delegate.currentSystemTimeMs(); + } + + @Override + public long currentStreamTimeMs() { + return delegate.currentStreamTimeMs(); + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/GlobalProcessorContextImpl.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/GlobalProcessorContextImpl.java new file mode 100644 index 0000000..dbdd6a2 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/GlobalProcessorContextImpl.java @@ -0,0 +1,143 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.processor.internals; + +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.processor.Cancellable; +import org.apache.kafka.streams.processor.PunctuationType; +import org.apache.kafka.streams.processor.Punctuator; +import org.apache.kafka.streams.processor.StateStore; +import org.apache.kafka.streams.processor.TaskId; +import org.apache.kafka.streams.processor.To; +import org.apache.kafka.streams.processor.api.Record; +import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl; +import org.apache.kafka.streams.state.internals.ThreadCache; +import org.apache.kafka.streams.state.internals.ThreadCache.DirtyEntryFlushListener; + +import java.time.Duration; + +import static org.apache.kafka.streams.processor.internals.AbstractReadWriteDecorator.getReadWriteStore; + +public class GlobalProcessorContextImpl extends AbstractProcessorContext { + + private final GlobalStateManager stateManager; + private final Time time; + + public GlobalProcessorContextImpl(final StreamsConfig config, + final GlobalStateManager stateMgr, + final StreamsMetricsImpl metrics, + final ThreadCache cache, + final Time time) { + super(new TaskId(-1, -1), config, metrics, cache); + stateManager = stateMgr; + this.time = time; + } + + @Override + protected StateManager stateManager() { + return stateManager; + } + + @SuppressWarnings("unchecked") + @Override + public S getStateStore(final String name) { + final StateStore store = stateManager.getGlobalStore(name); + return (S) getReadWriteStore(store); + } + + @SuppressWarnings("unchecked") + @Override + public void forward(final Record record) { + final ProcessorNode previousNode = currentNode(); + try { + for (final ProcessorNode child : currentNode().children()) { + setCurrentNode(child); + ((ProcessorNode) child).process(record); + } + } finally { + setCurrentNode(previousNode); + } + } + + @Override + public void forward(final Record record, final String childName) { + throw new UnsupportedOperationException("this should not happen: forward() not supported in global processor context."); + } + + @Override + public void forward(final KIn key, final VIn value) { + forward(new Record<>(key, value, timestamp(), headers())); + } + + /** + * No-op. 
This should only be called on GlobalStateStore#flush and there should be no child nodes + */ + @Override + public void forward(final K key, final V value, final To to) { + if (!currentNode().children().isEmpty()) { + throw new IllegalStateException("This method should only be called on 'GlobalStateStore.flush' that should not have any children."); + } + } + + @Override + public void commit() { + //no-op + } + + @Override + public long currentSystemTimeMs() { + return time.milliseconds(); + } + + @Override + public long currentStreamTimeMs() { + throw new UnsupportedOperationException("There is no concept of stream-time for a global processor."); + } + + /** + * @throws UnsupportedOperationException on every invocation + */ + @Override + public Cancellable schedule(final Duration interval, final PunctuationType type, final Punctuator callback) { + throw new UnsupportedOperationException("this should not happen: schedule() not supported in global processor context."); + } + + @Override + public void logChange(final String storeName, + final Bytes key, + final byte[] value, + final long timestamp) { + throw new UnsupportedOperationException("this should not happen: logChange() not supported in global processor context."); + } + + @Override + public void transitionToActive(final StreamTask streamTask, final RecordCollector recordCollector, final ThreadCache newCache) { + throw new UnsupportedOperationException("this should not happen: transitionToActive() not supported in global processor context."); + } + + @Override + public void transitionToStandby(final ThreadCache newCache) { + throw new UnsupportedOperationException("this should not happen: transitionToStandby() not supported in global processor context."); + } + + @Override + public void registerCacheFlushListener(final String namespace, final DirtyEntryFlushListener listener) { + cache.addDirtyEntryFlushListener(namespace, listener); + } +} \ No newline at end of file diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/GlobalStateMaintainer.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/GlobalStateMaintainer.java new file mode 100644 index 0000000..9a8aab6 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/GlobalStateMaintainer.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.processor.internals; + +import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.common.TopicPartition; + +import java.io.IOException; +import java.util.Map; + +/** + * Interface for maintaining global state stores. 
see {@link GlobalStateUpdateTask} + */ +interface GlobalStateMaintainer { + + Map initialize(); + + void flushState(); + + void close(final boolean wipeStateStore) throws IOException; + + void update(ConsumerRecord record); +} diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/GlobalStateManager.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/GlobalStateManager.java new file mode 100644 index 0000000..479fd1f --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/GlobalStateManager.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.processor.internals; + +import org.apache.kafka.streams.errors.StreamsException; + +import java.util.Set; + +public interface GlobalStateManager extends StateManager { + + void setGlobalProcessorContext(final InternalProcessorContext processorContext); + + /** + * @throws IllegalStateException If store gets registered after initialized is already finished + * @throws StreamsException if the store's change log does not contain the partition + */ + Set initialize(); +} diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/GlobalStateManagerImpl.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/GlobalStateManagerImpl.java new file mode 100644 index 0000000..090621d --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/GlobalStateManagerImpl.java @@ -0,0 +1,438 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.processor.internals; + +import org.apache.kafka.clients.consumer.Consumer; +import org.apache.kafka.clients.consumer.ConsumerConfig; +import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.clients.consumer.ConsumerRecords; +import org.apache.kafka.common.PartitionInfo; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.errors.TimeoutException; +import org.apache.kafka.common.serialization.ByteArrayDeserializer; +import org.apache.kafka.common.utils.FixedOrderMap; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.errors.ProcessorStateException; +import org.apache.kafka.streams.errors.StreamsException; +import org.apache.kafka.streams.processor.StateRestoreCallback; +import org.apache.kafka.streams.processor.StateRestoreListener; +import org.apache.kafka.streams.processor.StateStore; +import org.apache.kafka.streams.processor.StateStoreContext; +import org.apache.kafka.streams.processor.internals.Task.TaskType; +import org.apache.kafka.streams.state.internals.OffsetCheckpoint; +import org.apache.kafka.streams.state.internals.RecordConverter; +import org.slf4j.Logger; + +import java.io.File; +import java.io.IOException; +import java.time.Duration; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.function.Supplier; + +import static org.apache.kafka.streams.processor.internals.StateManagerUtil.CHECKPOINT_FILE_NAME; +import static org.apache.kafka.streams.processor.internals.StateManagerUtil.converterForStore; + +/** + * This class is responsible for the initialization, restoration, closing, flushing etc + * of Global State Stores. There is only ever 1 instance of this class per Application Instance. 
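+ * Restoration is driven by the dedicated global consumer: for every registered global store
+ * the manager seeks to the last checkpointed offset of the store's source topic (or to the
+ * beginning if no checkpoint exists) and replays records until the end offsets captured at
+ * registration time are reached, recording the restored offsets for the checkpoint file.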
+ */ +public class GlobalStateManagerImpl implements GlobalStateManager { + private final static long NO_DEADLINE = -1L; + + private final Time time; + private final Logger log; + private final File baseDir; + private final long taskTimeoutMs; + private final ProcessorTopology topology; + private final OffsetCheckpoint checkpointFile; + private final Duration pollMsPlusRequestTimeout; + private final Consumer globalConsumer; + private final StateRestoreListener stateRestoreListener; + private final Map checkpointFileCache; + private final Map storeToChangelogTopic; + private final Set globalStoreNames = new HashSet<>(); + private final Set globalNonPersistentStoresTopics = new HashSet<>(); + private final FixedOrderMap> globalStores = new FixedOrderMap<>(); + + private InternalProcessorContext globalProcessorContext; + + public GlobalStateManagerImpl(final LogContext logContext, + final Time time, + final ProcessorTopology topology, + final Consumer globalConsumer, + final StateDirectory stateDirectory, + final StateRestoreListener stateRestoreListener, + final StreamsConfig config) { + this.time = time; + this.topology = topology; + baseDir = stateDirectory.globalStateDir(); + storeToChangelogTopic = topology.storeToChangelogTopic(); + checkpointFile = new OffsetCheckpoint(new File(baseDir, CHECKPOINT_FILE_NAME)); + checkpointFileCache = new HashMap<>(); + + // Find non persistent store's topics + for (final StateStore store : topology.globalStateStores()) { + globalStoreNames.add(store.name()); + if (!store.persistent()) { + globalNonPersistentStoresTopics.add(changelogFor(store.name())); + } + } + + log = logContext.logger(GlobalStateManagerImpl.class); + this.globalConsumer = globalConsumer; + this.stateRestoreListener = stateRestoreListener; + + final Map consumerProps = config.getGlobalConsumerConfigs("dummy"); + // need to add mandatory configs; otherwise `QuietConsumerConfig` throws + consumerProps.put(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, ByteArrayDeserializer.class); + consumerProps.put(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, ByteArrayDeserializer.class); + final int requestTimeoutMs = new ClientUtils.QuietConsumerConfig(consumerProps) + .getInt(ConsumerConfig.REQUEST_TIMEOUT_MS_CONFIG); + pollMsPlusRequestTimeout = Duration.ofMillis( + config.getLong(StreamsConfig.POLL_MS_CONFIG) + requestTimeoutMs + ); + taskTimeoutMs = config.getLong(StreamsConfig.TASK_TIMEOUT_MS_CONFIG); + } + + @Override + public void setGlobalProcessorContext(final InternalProcessorContext globalProcessorContext) { + this.globalProcessorContext = globalProcessorContext; + } + + @Override + public Set initialize() { + try { + checkpointFileCache.putAll(checkpointFile.read()); + } catch (final IOException e) { + throw new StreamsException("Failed to read checkpoints for global state globalStores", e); + } + + final Set changelogTopics = new HashSet<>(); + for (final StateStore stateStore : topology.globalStateStores()) { + final String sourceTopic = storeToChangelogTopic.get(stateStore.name()); + changelogTopics.add(sourceTopic); + stateStore.init((StateStoreContext) globalProcessorContext, stateStore); + } + + // make sure each topic-partition from checkpointFileCache is associated with a global state store + checkpointFileCache.keySet().forEach(tp -> { + if (!changelogTopics.contains(tp.topic())) { + log.error( + "Encountered a topic-partition in the global checkpoint file not associated with any global" + + " state store, topic-partition: {}, checkpoint file: {}. 
If this topic-partition is no longer valid," + + " an application reset and state store directory cleanup will be required.", + tp.topic(), + checkpointFile.toString() + ); + throw new StreamsException("Encountered a topic-partition not associated with any global state store"); + } + }); + return Collections.unmodifiableSet(globalStoreNames); + } + + public StateStore getGlobalStore(final String name) { + return globalStores.getOrDefault(name, Optional.empty()).orElse(null); + } + + @Override + public StateStore getStore(final String name) { + return getGlobalStore(name); + } + + public File baseDir() { + return baseDir; + } + + @Override + public void registerStore(final StateStore store, final StateRestoreCallback stateRestoreCallback) { + log.info("Restoring state for global store {}", store.name()); + + // TODO (KAFKA-12887): we should not trigger user's exception handler for illegal-argument but always + // fail-crash; in this case we would not need to immediately close the state store before throwing + if (globalStores.containsKey(store.name())) { + store.close(); + throw new IllegalArgumentException(String.format("Global Store %s has already been registered", store.name())); + } + + if (!globalStoreNames.contains(store.name())) { + store.close(); + throw new IllegalArgumentException(String.format("Trying to register store %s that is not a known global store", store.name())); + } + + // register the store first, so that if later an exception is thrown then eventually while we call `close` + // on the state manager this state store would be closed as well + globalStores.put(store.name(), Optional.of(store)); + + if (stateRestoreCallback == null) { + throw new IllegalArgumentException(String.format("The stateRestoreCallback provided for store %s was null", store.name())); + } + + final List topicPartitions = topicPartitionsForStore(store); + final Map highWatermarks = retryUntilSuccessOrThrowOnTaskTimeout( + () -> globalConsumer.endOffsets(topicPartitions), + String.format( + "Failed to get offsets for partitions %s. The broker may be transiently unavailable at the moment.", + topicPartitions + ) + ); + + try { + restoreState( + stateRestoreCallback, + topicPartitions, + highWatermarks, + store.name(), + converterForStore(store) + ); + } finally { + globalConsumer.unsubscribe(); + } + } + + private List topicPartitionsForStore(final StateStore store) { + final String sourceTopic = storeToChangelogTopic.get(store.name()); + + final List partitionInfos = retryUntilSuccessOrThrowOnTaskTimeout( + () -> globalConsumer.partitionsFor(sourceTopic), + String.format( + "Failed to get partitions for topic %s. 
The broker may be transiently unavailable at the moment.", + sourceTopic + ) + ); + + if (partitionInfos == null || partitionInfos.isEmpty()) { + throw new StreamsException(String.format("There are no partitions available for topic %s when initializing global store %s", sourceTopic, store.name())); + } + + final List topicPartitions = new ArrayList<>(); + for (final PartitionInfo partition : partitionInfos) { + topicPartitions.add(new TopicPartition(partition.topic(), partition.partition())); + } + return topicPartitions; + } + + private void restoreState(final StateRestoreCallback stateRestoreCallback, + final List topicPartitions, + final Map highWatermarks, + final String storeName, + final RecordConverter recordConverter) { + for (final TopicPartition topicPartition : topicPartitions) { + long currentDeadline = NO_DEADLINE; + + globalConsumer.assign(Collections.singletonList(topicPartition)); + long offset; + final Long checkpoint = checkpointFileCache.get(topicPartition); + if (checkpoint != null) { + globalConsumer.seek(topicPartition, checkpoint); + offset = checkpoint; + } else { + globalConsumer.seekToBeginning(Collections.singletonList(topicPartition)); + offset = getGlobalConsumerOffset(topicPartition); + } + + final Long highWatermark = highWatermarks.get(topicPartition); + final RecordBatchingStateRestoreCallback stateRestoreAdapter = + StateRestoreCallbackAdapter.adapt(stateRestoreCallback); + + stateRestoreListener.onRestoreStart(topicPartition, storeName, offset, highWatermark); + long restoreCount = 0L; + + while (offset < highWatermark) { + // we add `request.timeout.ms` to `poll.ms` because `poll.ms` might be too short + // to give a fetch request a fair chance to actually complete and we don't want to + // start `task.timeout.ms` too early + // + // TODO with https://issues.apache.org/jira/browse/KAFKA-10315 we can just call + // `poll(pollMS)` without adding the request timeout and do a more precise + // timeout handling + final ConsumerRecords records = globalConsumer.poll(pollMsPlusRequestTimeout); + if (records.isEmpty()) { + currentDeadline = maybeUpdateDeadlineOrThrow(currentDeadline); + } else { + currentDeadline = NO_DEADLINE; + } + + final List> restoreRecords = new ArrayList<>(); + for (final ConsumerRecord record : records.records(topicPartition)) { + if (record.key() != null) { + restoreRecords.add(recordConverter.convert(record)); + } + } + + offset = getGlobalConsumerOffset(topicPartition); + + stateRestoreAdapter.restoreBatch(restoreRecords); + stateRestoreListener.onBatchRestored(topicPartition, storeName, offset, restoreRecords.size()); + restoreCount += restoreRecords.size(); + } + stateRestoreListener.onRestoreEnd(topicPartition, storeName, restoreCount); + checkpointFileCache.put(topicPartition, offset); + } + } + + private long getGlobalConsumerOffset(final TopicPartition topicPartition) { + return retryUntilSuccessOrThrowOnTaskTimeout( + () -> globalConsumer.position(topicPartition), + String.format( + "Failed to get position for partition %s. The broker may be transiently unavailable at the moment.", + topicPartition + ) + ); + } + + private R retryUntilSuccessOrThrowOnTaskTimeout(final Supplier supplier, + final String errorMessage) { + long deadlineMs = NO_DEADLINE; + + do { + try { + return supplier.get(); + } catch (final TimeoutException retriableException) { + if (taskTimeoutMs == 0L) { + throw new StreamsException( + String.format( + "Retrying is disabled. 
You can enable it by setting `%s` to a value larger than zero.", + StreamsConfig.TASK_TIMEOUT_MS_CONFIG + ), + retriableException + ); + } + + deadlineMs = maybeUpdateDeadlineOrThrow(deadlineMs); + + log.warn(errorMessage, retriableException); + } + } while (true); + } + + private long maybeUpdateDeadlineOrThrow(final long currentDeadlineMs) { + final long currentWallClockMs = time.milliseconds(); + + if (currentDeadlineMs == NO_DEADLINE) { + final long newDeadlineMs = currentWallClockMs + taskTimeoutMs; + return newDeadlineMs < 0L ? Long.MAX_VALUE : newDeadlineMs; + } else if (currentWallClockMs >= currentDeadlineMs) { + throw new TimeoutException(String.format( + "Global task did not make progress to restore state within %d ms. Adjust `%s` if needed.", + currentWallClockMs - currentDeadlineMs + taskTimeoutMs, + StreamsConfig.TASK_TIMEOUT_MS_CONFIG + )); + } + + return currentDeadlineMs; + } + + @Override + public void flush() { + log.debug("Flushing all global globalStores registered in the state manager"); + for (final Map.Entry> entry : globalStores.entrySet()) { + if (entry.getValue().isPresent()) { + final StateStore store = entry.getValue().get(); + try { + log.trace("Flushing global store={}", store.name()); + store.flush(); + } catch (final RuntimeException e) { + throw new ProcessorStateException( + String.format("Failed to flush global state store %s", store.name()), + e + ); + } + } else { + throw new IllegalStateException("Expected " + entry.getKey() + " to have been initialized"); + } + } + } + + @Override + public void close() { + if (globalStores.isEmpty()) { + return; + } + final StringBuilder closeFailed = new StringBuilder(); + for (final Map.Entry> entry : globalStores.entrySet()) { + if (entry.getValue().isPresent()) { + log.debug("Closing global storage engine {}", entry.getKey()); + try { + entry.getValue().get().close(); + } catch (final RuntimeException e) { + log.error("Failed to close global state store {}", entry.getKey(), e); + closeFailed.append("Failed to close global state store:") + .append(entry.getKey()) + .append(". Reason: ") + .append(e) + .append("\n"); + } + globalStores.put(entry.getKey(), Optional.empty()); + } else { + log.info("Skipping to close non-initialized store {}", entry.getKey()); + } + } + + if (closeFailed.length() > 0) { + throw new ProcessorStateException("Exceptions caught during close of 1 or more global state globalStores\n" + closeFailed); + } + } + + @Override + public void updateChangelogOffsets(final Map offsets) { + checkpointFileCache.putAll(offsets); + } + + @Override + public void checkpoint() { + final Map filteredOffsets = new HashMap<>(); + + // Skip non persistent store + for (final Map.Entry topicPartitionOffset : checkpointFileCache.entrySet()) { + final String topic = topicPartitionOffset.getKey().topic(); + if (!globalNonPersistentStoresTopics.contains(topic)) { + filteredOffsets.put(topicPartitionOffset.getKey(), topicPartitionOffset.getValue()); + } + } + + try { + checkpointFile.write(filteredOffsets); + } catch (final IOException e) { + log.warn("Failed to write offset checkpoint file to {} for global stores: {}." + + " This may occur if OS cleaned the state.dir in case when it is located in the (default) ${java.io.tmpdir}/kafka-streams directory." 
+ + " Changing the location of state.dir may resolve the problem", checkpointFile, e); + } + } + + @Override + public TaskType taskType() { + return TaskType.GLOBAL; + } + + @Override + public Map changelogOffsets() { + return Collections.unmodifiableMap(checkpointFileCache); + } + + public String changelogFor(final String storeName) { + return storeToChangelogTopic.get(storeName); + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/GlobalStateUpdateTask.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/GlobalStateUpdateTask.java new file mode 100644 index 0000000..5232285 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/GlobalStateUpdateTask.java @@ -0,0 +1,154 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.processor.internals; + +import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.streams.errors.DeserializationExceptionHandler; +import org.apache.kafka.streams.errors.StreamsException; +import org.apache.kafka.streams.processor.api.Record; +import org.slf4j.Logger; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; +import java.util.Set; + +import static org.apache.kafka.streams.processor.internals.metrics.TaskMetrics.droppedRecordsSensor; + +/** + * Updates the state for all Global State Stores. 
+ */ +public class GlobalStateUpdateTask implements GlobalStateMaintainer { + private final Logger log; + private final LogContext logContext; + + private final ProcessorTopology topology; + private final InternalProcessorContext processorContext; + private final Map offsets = new HashMap<>(); + private final Map deserializers = new HashMap<>(); + private final GlobalStateManager stateMgr; + private final DeserializationExceptionHandler deserializationExceptionHandler; + + public GlobalStateUpdateTask(final LogContext logContext, + final ProcessorTopology topology, + final InternalProcessorContext processorContext, + final GlobalStateManager stateMgr, + final DeserializationExceptionHandler deserializationExceptionHandler) { + this.logContext = logContext; + this.log = logContext.logger(getClass()); + this.topology = topology; + this.stateMgr = stateMgr; + this.processorContext = processorContext; + this.deserializationExceptionHandler = deserializationExceptionHandler; + } + + /** + * @throws IllegalStateException If store gets registered after initialized is already finished + * @throws StreamsException If the store's change log does not contain the partition + */ + @Override + public Map initialize() { + final Set storeNames = stateMgr.initialize(); + final Map storeNameToTopic = topology.storeToChangelogTopic(); + for (final String storeName : storeNames) { + final String sourceTopic = storeNameToTopic.get(storeName); + final SourceNode source = topology.source(sourceTopic); + deserializers.put( + sourceTopic, + new RecordDeserializer( + source, + deserializationExceptionHandler, + logContext, + droppedRecordsSensor( + Thread.currentThread().getName(), + processorContext.taskId().toString(), + processorContext.metrics() + ) + ) + ); + } + initTopology(); + processorContext.initialize(); + return stateMgr.changelogOffsets(); + } + + @SuppressWarnings("unchecked") + @Override + public void update(final ConsumerRecord record) { + final RecordDeserializer sourceNodeAndDeserializer = deserializers.get(record.topic()); + final ConsumerRecord deserialized = sourceNodeAndDeserializer.deserialize(processorContext, record); + + if (deserialized != null) { + final ProcessorRecordContext recordContext = + new ProcessorRecordContext( + deserialized.timestamp(), + deserialized.offset(), + deserialized.partition(), + deserialized.topic(), + deserialized.headers()); + processorContext.setRecordContext(recordContext); + processorContext.setCurrentNode(sourceNodeAndDeserializer.sourceNode()); + final Record toProcess = new Record<>( + deserialized.key(), + deserialized.value(), + processorContext.timestamp(), + processorContext.headers() + ); + ((SourceNode) sourceNodeAndDeserializer.sourceNode()).process(toProcess); + } + + offsets.put(new TopicPartition(record.topic(), record.partition()), record.offset() + 1); + } + + public void flushState() { + // this could theoretically throw a ProcessorStateException caused by a ProducerFencedException, + // but in practice this shouldn't happen for global state update tasks, since the stores are not + // logged and there are no downstream operators after global stores. 
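+        // Note on ordering: flush() persists the store contents first; only afterwards are
+        // the consumed changelog offsets handed to the state manager and written out by
+        // checkpoint(). A failure between the two steps can therefore only cause already
+        // applied records to be re-read on restart, never a checkpoint that points past
+        // unflushed data.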
+ stateMgr.flush(); + stateMgr.updateChangelogOffsets(offsets); + stateMgr.checkpoint(); + } + + public void close(final boolean wipeStateStore) throws IOException { + stateMgr.close(); + if (wipeStateStore) { + try { + log.info("Deleting global task directory after detecting corruption."); + Utils.delete(stateMgr.baseDir()); + } catch (final IOException e) { + log.error("Failed to delete global task directory after detecting corruption.", e); + } + } + } + + @SuppressWarnings("unchecked") + private void initTopology() { + for (final ProcessorNode node : this.topology.processors()) { + processorContext.setCurrentNode(node); + try { + node.init(this.processorContext); + } finally { + processorContext.setCurrentNode(null); + } + } + } + + +} diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/GlobalStreamThread.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/GlobalStreamThread.java new file mode 100644 index 0000000..f45350d --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/GlobalStreamThread.java @@ -0,0 +1,456 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.processor.internals; + +import org.apache.kafka.clients.consumer.Consumer; +import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.clients.consumer.ConsumerRecords; +import org.apache.kafka.clients.consumer.InvalidOffsetException; +import org.apache.kafka.common.Metric; +import org.apache.kafka.common.MetricName; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.errors.StreamsException; +import org.apache.kafka.streams.processor.StateRestoreListener; +import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl; +import org.apache.kafka.streams.state.internals.ThreadCache; +import org.slf4j.Logger; + +import java.io.IOException; +import java.time.Duration; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.atomic.AtomicLong; + +import static org.apache.kafka.streams.processor.internals.GlobalStreamThread.State.CREATED; +import static org.apache.kafka.streams.processor.internals.GlobalStreamThread.State.DEAD; +import static org.apache.kafka.streams.processor.internals.GlobalStreamThread.State.PENDING_SHUTDOWN; +import static org.apache.kafka.streams.processor.internals.GlobalStreamThread.State.RUNNING; + +/** + * This is the thread responsible for keeping all Global State Stores updated. 
+ * It delegates most of the responsibility to the internal class StateConsumer + */ +public class GlobalStreamThread extends Thread { + + private final Logger log; + private final LogContext logContext; + private final StreamsConfig config; + private final Consumer globalConsumer; + private final StateDirectory stateDirectory; + private final Time time; + private final ThreadCache cache; + private final StreamsMetricsImpl streamsMetrics; + private final ProcessorTopology topology; + private final AtomicLong cacheSize; + private volatile StreamsException startupException; + private java.util.function.Consumer streamsUncaughtExceptionHandler; + + /** + * The states that the global stream thread can be in + * + *

+     * <pre>
+     *                +-------------+
+     *          +<--- | Created (0) |
+     *          |     +-----+-------+
+     *          |           |
+     *          |           v
+     *          |     +-----+-------+
+     *          +<--- | Running (1) |
+     *          |     +-----+-------+
+     *          |           |
+     *          |           v
+     *          |     +-----+-------+
+     *          +---> | Pending     |
+     *                | Shutdown (2)|
+     *                +-----+-------+
+     *                      |
+     *                      v
+     *                +-----+-------+
+     *                | Dead (3)    |
+     *                +-------------+
+     * </pre>
+     *
+     * Note the following:
+     * <ul>
+     *     <li>Any state can go to PENDING_SHUTDOWN. That is because streams can be closed at any time.</li>
+     *     <li>State PENDING_SHUTDOWN may want to transit itself. In this case we will forbid the transition but will not treat as an error.</li>
+     * </ul>
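+     *
+     * The integer arguments to the enum constants below are the ordinals of the states that a
+     * given state may transition to, e.g. {@code CREATED(1, 2)} may move to {@code RUNNING}
+     * (ordinal 1) or {@code PENDING_SHUTDOWN} (ordinal 2), while {@code DEAD} has no outgoing
+     * transitions.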
                + */ + public enum State implements ThreadStateTransitionValidator { + CREATED(1, 2), RUNNING(2), PENDING_SHUTDOWN(3), DEAD; + + private final Set validTransitions = new HashSet<>(); + + State(final Integer... validTransitions) { + this.validTransitions.addAll(Arrays.asList(validTransitions)); + } + + public boolean isRunning() { + return equals(RUNNING); + } + + public boolean inErrorState() { + return equals(DEAD) || equals(PENDING_SHUTDOWN); + } + + @Override + public boolean isValidTransition(final ThreadStateTransitionValidator newState) { + final State tmpState = (State) newState; + return validTransitions.contains(tmpState.ordinal()); + } + } + + private volatile State state = State.CREATED; + private final Object stateLock = new Object(); + private StreamThread.StateListener stateListener = null; + private final String logPrefix; + private final StateRestoreListener stateRestoreListener; + + /** + * Set the {@link StreamThread.StateListener} to be notified when state changes. Note this API is internal to + * Kafka Streams and is not intended to be used by an external application. + */ + public void setStateListener(final StreamThread.StateListener listener) { + stateListener = listener; + } + + /** + * @return The state this instance is in + */ + public State state() { + // we do not need to use the stat lock since the variable is volatile + return state; + } + + /** + * Sets the state + * + * @param newState New state + */ + private void setState(final State newState) { + final State oldState = state; + + synchronized (stateLock) { + if (state == State.PENDING_SHUTDOWN && newState == State.PENDING_SHUTDOWN) { + // when the state is already in PENDING_SHUTDOWN, its transition to itself + // will be refused but we do not throw exception here + return; + } else if (state == State.DEAD) { + // when the state is already in NOT_RUNNING, all its transitions + // will be refused but we do not throw exception here + return; + } else if (!state.isValidTransition(newState)) { + log.error("Unexpected state transition from {} to {}", oldState, newState); + throw new StreamsException(logPrefix + "Unexpected state transition from " + oldState + " to " + newState); + } else { + log.info("State transition from {} to {}", oldState, newState); + } + + state = newState; + } + + if (stateListener != null) { + stateListener.onChange(this, state, oldState); + } + } + + public boolean stillRunning() { + synchronized (stateLock) { + return state.isRunning(); + } + } + + public boolean inErrorState() { + synchronized (stateLock) { + return state.inErrorState(); + } + } + + public boolean stillInitializing() { + synchronized (stateLock) { + return state.equals(CREATED); + } + } + + public GlobalStreamThread(final ProcessorTopology topology, + final StreamsConfig config, + final Consumer globalConsumer, + final StateDirectory stateDirectory, + final long cacheSizeBytes, + final StreamsMetricsImpl streamsMetrics, + final Time time, + final String threadClientId, + final StateRestoreListener stateRestoreListener, + final java.util.function.Consumer streamsUncaughtExceptionHandler) { + super(threadClientId); + this.time = time; + this.config = config; + this.topology = topology; + this.globalConsumer = globalConsumer; + this.stateDirectory = stateDirectory; + this.streamsMetrics = streamsMetrics; + this.logPrefix = String.format("global-stream-thread [%s] ", threadClientId); + this.logContext = new LogContext(logPrefix); + this.log = logContext.logger(getClass()); + this.cache = new 
ThreadCache(logContext, cacheSizeBytes, this.streamsMetrics); + this.stateRestoreListener = stateRestoreListener; + this.streamsUncaughtExceptionHandler = streamsUncaughtExceptionHandler; + this.cacheSize = new AtomicLong(-1L); + } + + static class StateConsumer { + private final Consumer globalConsumer; + private final GlobalStateMaintainer stateMaintainer; + private final Time time; + private final Duration pollTime; + private final long flushInterval; + private final Logger log; + + private long lastFlush; + + StateConsumer(final LogContext logContext, + final Consumer globalConsumer, + final GlobalStateMaintainer stateMaintainer, + final Time time, + final Duration pollTime, + final long flushInterval) { + this.log = logContext.logger(getClass()); + this.globalConsumer = globalConsumer; + this.stateMaintainer = stateMaintainer; + this.time = time; + this.pollTime = pollTime; + this.flushInterval = flushInterval; + } + + /** + * @throws IllegalStateException If store gets registered after initialized is already finished + * @throws StreamsException if the store's change log does not contain the partition + */ + void initialize() { + final Map partitionOffsets = stateMaintainer.initialize(); + globalConsumer.assign(partitionOffsets.keySet()); + for (final Map.Entry entry : partitionOffsets.entrySet()) { + globalConsumer.seek(entry.getKey(), entry.getValue()); + } + lastFlush = time.milliseconds(); + } + + void pollAndUpdate() { + final ConsumerRecords received = globalConsumer.poll(pollTime); + for (final ConsumerRecord record : received) { + stateMaintainer.update(record); + } + final long now = time.milliseconds(); + if (now - flushInterval >= lastFlush) { + stateMaintainer.flushState(); + lastFlush = now; + } + } + + public void close(final boolean wipeStateStore) throws IOException { + try { + globalConsumer.close(); + } catch (final RuntimeException e) { + // just log an error if the consumer throws an exception during close + // so we can always attempt to close the state stores. + log.error("Failed to close global consumer due to the following error:", e); + } + + stateMaintainer.close(wipeStateStore); + } + } + + @Override + public void run() { + final StateConsumer stateConsumer = initialize(); + + if (stateConsumer == null) { + // during initialization, the caller thread would wait for the state consumer + // to restore the global state store before transiting to RUNNING state and return; + // if an error happens during the restoration process, the stateConsumer will be null + // and in this case we will transit the state to PENDING_SHUTDOWN and DEAD immediately. + // the exception will be thrown in the caller thread during start() function. + setState(State.PENDING_SHUTDOWN); + setState(State.DEAD); + + log.warn("Error happened during initialization of the global state store; this thread has shutdown"); + streamsMetrics.removeAllThreadLevelSensors(getName()); + streamsMetrics.removeAllThreadLevelMetrics(getName()); + + return; + } + setState(RUNNING); + + boolean wipeStateStore = false; + try { + while (stillRunning()) { + final long size = cacheSize.getAndSet(-1L); + if (size != -1L) { + cache.resize(size); + } + stateConsumer.pollAndUpdate(); + } + } catch (final InvalidOffsetException recoverableException) { + wipeStateStore = true; + log.error( + "Updating global state failed due to inconsistent local state. Will attempt to clean up the local state. 
You can restart KafkaStreams to recover from this error.", + recoverableException + ); + final StreamsException e = new StreamsException( + "Updating global state failed. You can restart KafkaStreams to launch a new GlobalStreamThread to recover from this error.", + recoverableException + ); + this.streamsUncaughtExceptionHandler.accept(e); + } catch (final Exception e) { + this.streamsUncaughtExceptionHandler.accept(e); + } finally { + // set the state to pending shutdown first as it may be called due to error; + // its state may already be PENDING_SHUTDOWN so it will return false but we + // intentionally do not check the returned flag + setState(State.PENDING_SHUTDOWN); + + log.info("Shutting down"); + + try { + stateConsumer.close(wipeStateStore); + } catch (final IOException e) { + log.error("Failed to close state maintainer due to the following error:", e); + } + + streamsMetrics.removeAllThreadLevelSensors(getName()); + streamsMetrics.removeAllThreadLevelMetrics(getName()); + + setState(DEAD); + + log.info("Shutdown complete"); + } + } + + public void setUncaughtExceptionHandler(final java.util.function.Consumer streamsUncaughtExceptionHandler) { + this.streamsUncaughtExceptionHandler = streamsUncaughtExceptionHandler; + } + + public void resize(final long cacheSize) { + this.cacheSize.set(cacheSize); + } + + private StateConsumer initialize() { + StateConsumer stateConsumer = null; + try { + final GlobalStateManager stateMgr = new GlobalStateManagerImpl( + logContext, + time, + topology, + globalConsumer, + stateDirectory, + stateRestoreListener, + config + ); + + final GlobalProcessorContextImpl globalProcessorContext = new GlobalProcessorContextImpl( + config, + stateMgr, + streamsMetrics, + cache, + time + ); + stateMgr.setGlobalProcessorContext(globalProcessorContext); + + stateConsumer = new StateConsumer( + logContext, + globalConsumer, + new GlobalStateUpdateTask( + logContext, + topology, + globalProcessorContext, + stateMgr, + config.defaultDeserializationExceptionHandler() + ), + time, + Duration.ofMillis(config.getLong(StreamsConfig.POLL_MS_CONFIG)), + config.getLong(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG) + ); + + try { + stateConsumer.initialize(); + } catch (final InvalidOffsetException recoverableException) { + log.error( + "Bootstrapping global state failed due to inconsistent local state. Will attempt to clean up the local state. You can restart KafkaStreams to recover from this error.", + recoverableException + ); + + closeStateConsumer(stateConsumer, true); + + throw new StreamsException( + "Bootstrapping global state failed. 
You can restart KafkaStreams to recover from this error.", + recoverableException + ); + } + + return stateConsumer; + } catch (final StreamsException fatalException) { + closeStateConsumer(stateConsumer, false); + startupException = fatalException; + } catch (final Exception fatalException) { + closeStateConsumer(stateConsumer, false); + startupException = new StreamsException("Exception caught during initialization of GlobalStreamThread", fatalException); + } + return null; + } + + private void closeStateConsumer(final StateConsumer stateConsumer, final boolean wipeStateStore) { + if (stateConsumer != null) { + try { + stateConsumer.close(wipeStateStore); + } catch (final IOException e) { + log.error("Failed to close state consumer due to the following error:", e); + } + } + } + + @Override + public synchronized void start() { + super.start(); + while (stillInitializing()) { + Utils.sleep(1); + if (startupException != null) { + throw startupException; + } + } + + if (inErrorState()) { + throw new IllegalStateException("Initialization for the global stream thread failed"); + } + } + + public void shutdown() { + // one could call shutdown() multiple times, so ignore subsequent calls + // if already shutting down or dead + setState(PENDING_SHUTDOWN); + } + + public Map consumerMetrics() { + return Collections.unmodifiableMap(globalConsumer.metrics()); + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/InternalProcessorContext.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/InternalProcessorContext.java new file mode 100644 index 0000000..88e47e3 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/InternalProcessorContext.java @@ -0,0 +1,121 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.processor.internals; + +import org.apache.kafka.common.serialization.ByteArraySerializer; +import org.apache.kafka.common.serialization.BytesSerializer; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.streams.processor.ProcessorContext; +import org.apache.kafka.streams.processor.RecordContext; +import org.apache.kafka.streams.processor.StateStore; +import org.apache.kafka.streams.processor.StateStoreContext; +import org.apache.kafka.streams.processor.internals.Task.TaskType; +import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl; +import org.apache.kafka.streams.state.StoreBuilder; +import org.apache.kafka.streams.state.internals.ThreadCache; +import org.apache.kafka.streams.state.internals.ThreadCache.DirtyEntryFlushListener; + +/** + * For internal use so we can update the {@link RecordContext} and current + * {@link ProcessorNode} when we are forwarding items that have been evicted or flushed from + * {@link ThreadCache} + */ +public interface InternalProcessorContext + extends ProcessorContext, org.apache.kafka.streams.processor.api.ProcessorContext, StateStoreContext { + + BytesSerializer BYTES_KEY_SERIALIZER = new BytesSerializer(); + ByteArraySerializer BYTEARRAY_VALUE_SERIALIZER = new ByteArraySerializer(); + + @Override + StreamsMetricsImpl metrics(); + + /** + * @param timeMs current wall-clock system timestamp in milliseconds + */ + void setSystemTimeMs(long timeMs); + + /** + * Returns the current {@link RecordContext} + * @return the current {@link RecordContext} + */ + ProcessorRecordContext recordContext(); + + /** + * @param recordContext the {@link ProcessorRecordContext} for the record about to be processes + */ + void setRecordContext(ProcessorRecordContext recordContext); + + /** + * @param currentNode the current {@link ProcessorNode} + */ + void setCurrentNode(ProcessorNode currentNode); + + /** + * Get the current {@link ProcessorNode} + */ + ProcessorNode currentNode(); + + /** + * Get the thread-global cache + */ + ThreadCache cache(); + + /** + * Mark this context as being initialized + */ + void initialize(); + + /** + * Mark this context as being uninitialized + */ + void uninitialize(); + + /** + * @return the type of task (active/standby/global) that this context corresponds to + */ + TaskType taskType(); + + /** + * Transition to active task and register a new task and cache to this processor context + */ + void transitionToActive(final StreamTask streamTask, final RecordCollector recordCollector, final ThreadCache newCache); + + /** + * Transition to standby task and register a dummy cache to this processor context + */ + void transitionToStandby(final ThreadCache newCache); + + /** + * Register a dirty entry flush listener for a particular namespace + */ + void registerCacheFlushListener(final String namespace, final DirtyEntryFlushListener listener); + + /** + * Get a correctly typed state store, given a handle on the original builder. 
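+     * <p>
+     * A rough usage sketch (illustrative only; {@code storeBuilder} stands in for whatever
+     * {@link StoreBuilder} the caller registered with the topology):
+     * <pre>{@code
+     * final KeyValueStore<String, Long> store = context.getStateStore(storeBuilder);
+     * }</pre>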
+ */ + @SuppressWarnings("unchecked") + default T getStateStore(final StoreBuilder builder) { + return (T) getStateStore(builder.name()); + } + + void logChange(final String storeName, + final Bytes key, + final byte[] value, + final long timestamp); + + String changelogFor(final String storeName); +} diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/InternalTopicConfig.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/InternalTopicConfig.java new file mode 100644 index 0000000..0b5a327 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/InternalTopicConfig.java @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.processor.internals; + +import org.apache.kafka.common.config.TopicConfig; +import org.apache.kafka.common.internals.Topic; + +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; + +/** + * InternalTopicConfig captures the properties required for configuring + * the internal topics we create for change-logs and repartitioning etc. + */ +public abstract class InternalTopicConfig { + final String name; + final Map topicConfigs; + final boolean enforceNumberOfPartitions; + + private Optional numberOfPartitions = Optional.empty(); + + static final Map INTERNAL_TOPIC_DEFAULT_OVERRIDES = new HashMap<>(); + static { + INTERNAL_TOPIC_DEFAULT_OVERRIDES.put(TopicConfig.MESSAGE_TIMESTAMP_TYPE_CONFIG, "CreateTime"); + } + + InternalTopicConfig(final String name, final Map topicConfigs) { + this.name = Objects.requireNonNull(name, "name can't be null"); + Topic.validate(name); + this.topicConfigs = Objects.requireNonNull(topicConfigs, "topicConfigs can't be null"); + this.enforceNumberOfPartitions = false; + } + + InternalTopicConfig(final String name, + final Map topicConfigs, + final int numberOfPartitions, + final boolean enforceNumberOfPartitions) { + this.name = Objects.requireNonNull(name, "name can't be null"); + Topic.validate(name); + validateNumberOfPartitions(numberOfPartitions); + this.topicConfigs = Objects.requireNonNull(topicConfigs, "topicConfigs can't be null"); + this.numberOfPartitions = Optional.of(numberOfPartitions); + this.enforceNumberOfPartitions = enforceNumberOfPartitions; + } + + /** + * Get the configured properties for this topic. 
If retentionMs is set then + * we add additionalRetentionMs to work out the desired retention when cleanup.policy=compact,delete + * + * @param additionalRetentionMs - added to retention to allow for clock drift etc + * @return Properties to be used when creating the topic + */ + public abstract Map getProperties(final Map defaultProperties, final long additionalRetentionMs); + + public boolean hasEnforcedNumberOfPartitions() { + return enforceNumberOfPartitions; + } + + public String name() { + return name; + } + + public Optional numberOfPartitions() { + return numberOfPartitions; + } + + public void setNumberOfPartitions(final int numberOfPartitions) { + if (hasEnforcedNumberOfPartitions()) { + throw new UnsupportedOperationException("number of partitions are enforced on topic " + + "" + name() + " and can't be altered."); + } + + validateNumberOfPartitions(numberOfPartitions); + + this.numberOfPartitions = Optional.of(numberOfPartitions); + } + + private static void validateNumberOfPartitions(final int numberOfPartitions) { + if (numberOfPartitions < 1) { + throw new IllegalArgumentException("Number of partitions must be at least 1."); + } + } + + @Override + public String toString() { + return "InternalTopicConfig(" + + "name=" + name + + ", topicConfigs=" + topicConfigs + + ", enforceNumberOfPartitions=" + enforceNumberOfPartitions + + ")"; + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/InternalTopicManager.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/InternalTopicManager.java new file mode 100644 index 0000000..6954921 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/InternalTopicManager.java @@ -0,0 +1,807 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.processor.internals; + +import org.apache.kafka.clients.admin.Admin; +import org.apache.kafka.clients.admin.Config; +import org.apache.kafka.clients.admin.ConfigEntry; +import org.apache.kafka.clients.admin.CreateTopicsResult; +import org.apache.kafka.clients.admin.DeleteTopicsResult; +import org.apache.kafka.clients.admin.DescribeConfigsResult; +import org.apache.kafka.clients.admin.DescribeTopicsResult; +import org.apache.kafka.clients.admin.NewTopic; +import org.apache.kafka.clients.admin.TopicDescription; +import org.apache.kafka.clients.consumer.ConsumerConfig; +import org.apache.kafka.common.KafkaFuture; +import org.apache.kafka.common.config.ConfigResource; +import org.apache.kafka.common.config.ConfigResource.Type; +import org.apache.kafka.common.config.TopicConfig; +import org.apache.kafka.common.errors.InterruptException; +import org.apache.kafka.common.errors.LeaderNotAvailableException; +import org.apache.kafka.common.errors.TimeoutException; +import org.apache.kafka.common.errors.TopicExistsException; +import org.apache.kafka.common.errors.UnknownTopicOrPartitionException; +import org.apache.kafka.common.errors.UnsupportedVersionException; +import org.apache.kafka.common.serialization.ByteArrayDeserializer; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.errors.StreamsException; +import org.apache.kafka.streams.processor.internals.ClientUtils.QuietConsumerConfig; +import org.slf4j.Logger; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.Set; +import java.util.concurrent.ExecutionException; +import java.util.function.BiConsumer; +import java.util.stream.Collectors; + +public class InternalTopicManager { + private final static String BUG_ERROR_MESSAGE = "This indicates a bug. " + + "Please report at https://issues.apache.org/jira/projects/KAFKA/issues or to the dev-mailing list (https://kafka.apache.org/contact)."; + private final static String INTERRUPTED_ERROR_MESSAGE = "Thread got interrupted. 
" + BUG_ERROR_MESSAGE; + + private final Logger log; + + private final Time time; + private final Admin adminClient; + + private final short replicationFactor; + private final long windowChangeLogAdditionalRetention; + private final long retryBackOffMs; + private final long retryTimeoutMs; + + private final Map defaultTopicConfigs = new HashMap<>(); + + public InternalTopicManager(final Time time, + final Admin adminClient, + final StreamsConfig streamsConfig) { + this.time = time; + this.adminClient = adminClient; + + final LogContext logContext = new LogContext(String.format("stream-thread [%s] ", Thread.currentThread().getName())); + log = logContext.logger(getClass()); + + replicationFactor = streamsConfig.getInt(StreamsConfig.REPLICATION_FACTOR_CONFIG).shortValue(); + windowChangeLogAdditionalRetention = streamsConfig.getLong(StreamsConfig.WINDOW_STORE_CHANGE_LOG_ADDITIONAL_RETENTION_MS_CONFIG); + retryBackOffMs = streamsConfig.getLong(StreamsConfig.RETRY_BACKOFF_MS_CONFIG); + final Map consumerConfig = streamsConfig.getMainConsumerConfigs("dummy", "dummy", -1); + // need to add mandatory configs; otherwise `QuietConsumerConfig` throws + consumerConfig.put(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, ByteArrayDeserializer.class); + consumerConfig.put(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, ByteArrayDeserializer.class); + retryTimeoutMs = new QuietConsumerConfig(consumerConfig).getInt(ConsumerConfig.MAX_POLL_INTERVAL_MS_CONFIG) / 2L; + + log.debug("Configs:" + Utils.NL + + "\t{} = {}" + Utils.NL + + "\t{} = {}", + StreamsConfig.REPLICATION_FACTOR_CONFIG, replicationFactor, + StreamsConfig.WINDOW_STORE_CHANGE_LOG_ADDITIONAL_RETENTION_MS_CONFIG, windowChangeLogAdditionalRetention); + + for (final Map.Entry entry : streamsConfig.originalsWithPrefix(StreamsConfig.TOPIC_PREFIX).entrySet()) { + if (entry.getValue() != null) { + defaultTopicConfigs.put(entry.getKey(), entry.getValue().toString()); + } + } + } + + static class ValidationResult { + private final Set missingTopics = new HashSet<>(); + private final Map> misconfigurationsForTopics = new HashMap<>(); + + public void addMissingTopic(final String topic) { + missingTopics.add(topic); + } + + public Set missingTopics() { + return Collections.unmodifiableSet(missingTopics); + } + + public void addMisconfiguration(final String topic, final String message) { + misconfigurationsForTopics.computeIfAbsent(topic, ignored -> new ArrayList<>()) + .add(message); + } + + public Map> misconfigurationsForTopics() { + return Collections.unmodifiableMap(misconfigurationsForTopics); + } + } + + /** + * Validates the internal topics passed. 
+ * + * The validation of the internal topics verifies if the topics: + * - are missing on the brokers + * - have the expected number of partitions + * - have configured a clean-up policy that avoids data loss + * + * @param topicConfigs internal topics to validate + * + * @return validation results that contains + * - the set of missing internal topics on the brokers + * - descriptions of misconfigurations per topic + */ + public ValidationResult validate(final Map topicConfigs) { + log.info("Starting to validate internal topics {}.", topicConfigs.keySet()); + + final long now = time.milliseconds(); + final long deadline = now + retryTimeoutMs; + + final ValidationResult validationResult = new ValidationResult(); + final Set topicDescriptionsStillToValidate = new HashSet<>(topicConfigs.keySet()); + final Set topicConfigsStillToValidate = new HashSet<>(topicConfigs.keySet()); + while (!topicDescriptionsStillToValidate.isEmpty() || !topicConfigsStillToValidate.isEmpty()) { + Map> descriptionsForTopic = Collections.emptyMap(); + if (!topicDescriptionsStillToValidate.isEmpty()) { + final DescribeTopicsResult describeTopicsResult = adminClient.describeTopics(topicDescriptionsStillToValidate); + descriptionsForTopic = describeTopicsResult.topicNameValues(); + } + Map> configsForTopic = Collections.emptyMap(); + if (!topicConfigsStillToValidate.isEmpty()) { + final DescribeConfigsResult describeConfigsResult = adminClient.describeConfigs( + topicConfigsStillToValidate.stream() + .map(topic -> new ConfigResource(Type.TOPIC, topic)) + .collect(Collectors.toSet()) + ); + configsForTopic = describeConfigsResult.values().entrySet().stream() + .collect(Collectors.toMap(entry -> entry.getKey().name(), Map.Entry::getValue)); + } + + while (!descriptionsForTopic.isEmpty() || !configsForTopic.isEmpty()) { + if (!descriptionsForTopic.isEmpty()) { + doValidateTopic( + validationResult, + descriptionsForTopic, + topicConfigs, + topicDescriptionsStillToValidate, + (streamsSide, brokerSide) -> validatePartitionCount(validationResult, streamsSide, brokerSide) + ); + } + if (!configsForTopic.isEmpty()) { + doValidateTopic( + validationResult, + configsForTopic, + topicConfigs, + topicConfigsStillToValidate, + (streamsSide, brokerSide) -> validateCleanupPolicy(validationResult, streamsSide, brokerSide) + ); + } + + maybeThrowTimeoutException( + Arrays.asList(topicDescriptionsStillToValidate, topicConfigsStillToValidate), + deadline, + String.format("Could not validate internal topics within %d milliseconds. " + + "This can happen if the Kafka cluster is temporarily not available.", retryTimeoutMs) + ); + + if (!descriptionsForTopic.isEmpty() || !configsForTopic.isEmpty()) { + Utils.sleep(100); + } + } + + maybeSleep( + Arrays.asList(topicDescriptionsStillToValidate, topicConfigsStillToValidate), + deadline, + "validated" + ); + } + + log.info("Completed validation of internal topics {}.", topicConfigs.keySet()); + return validationResult; + } + + private void doValidateTopic(final ValidationResult validationResult, + final Map> futuresForTopic, + final Map topicsConfigs, + final Set topicsStillToValidate, + final BiConsumer validator) { + for (final String topicName : new HashSet<>(topicsStillToValidate)) { + if (!futuresForTopic.containsKey(topicName)) { + throw new IllegalStateException("Description results do not contain topics to validate. 
" + BUG_ERROR_MESSAGE); + } + final KafkaFuture future = futuresForTopic.get(topicName); + if (future.isDone()) { + try { + final V brokerSideTopicConfig = future.get(); + final InternalTopicConfig streamsSideTopicConfig = topicsConfigs.get(topicName); + validator.accept(streamsSideTopicConfig, brokerSideTopicConfig); + topicsStillToValidate.remove(topicName); + } catch (final ExecutionException executionException) { + final Throwable cause = executionException.getCause(); + if (cause instanceof UnknownTopicOrPartitionException) { + log.info("Internal topic {} is missing", topicName); + validationResult.addMissingTopic(topicName); + topicsStillToValidate.remove(topicName); + } else if (cause instanceof LeaderNotAvailableException) { + log.info("The leader of internal topic {} is not available.", topicName); + } else if (cause instanceof TimeoutException) { + log.info("Retrieving data for internal topic {} timed out.", topicName); + } else { + log.error("Unexpected error during internal topic validation: ", cause); + throw new StreamsException( + String.format("Could not validate internal topic %s for the following reason: ", topicName), + cause + ); + } + } catch (final InterruptedException interruptedException) { + throw new InterruptException(interruptedException); + } finally { + futuresForTopic.remove(topicName); + } + } + } + } + + private void validatePartitionCount(final ValidationResult validationResult, + final InternalTopicConfig topicConfig, + final TopicDescription topicDescription) { + final String topicName = topicConfig.name(); + final int requiredPartitionCount = topicConfig.numberOfPartitions() + .orElseThrow(() -> new IllegalStateException("No partition count is specified for internal topic " + + topicName + ". " + BUG_ERROR_MESSAGE)); + final int actualPartitionCount = topicDescription.partitions().size(); + if (actualPartitionCount != requiredPartitionCount) { + validationResult.addMisconfiguration( + topicName, + "Internal topic " + topicName + " requires " + requiredPartitionCount + " partitions, " + + "but the existing topic on the broker has " + actualPartitionCount + " partitions." 
+ ); + } + } + + private void validateCleanupPolicy(final ValidationResult validationResult, + final InternalTopicConfig topicConfig, + final Config brokerSideTopicConfig) { + if (topicConfig instanceof UnwindowedChangelogTopicConfig) { + validateCleanupPolicyForUnwindowedChangelogs(validationResult, topicConfig, brokerSideTopicConfig); + } else if (topicConfig instanceof WindowedChangelogTopicConfig) { + validateCleanupPolicyForWindowedChangelogs(validationResult, topicConfig, brokerSideTopicConfig); + } else if (topicConfig instanceof RepartitionTopicConfig) { + validateCleanupPolicyForRepartitionTopic(validationResult, topicConfig, brokerSideTopicConfig); + } else { + throw new IllegalStateException("Internal topic " + topicConfig.name() + " has unknown type."); + } + } + + private void validateCleanupPolicyForUnwindowedChangelogs(final ValidationResult validationResult, + final InternalTopicConfig topicConfig, + final Config brokerSideTopicConfig) { + final String topicName = topicConfig.name(); + final String cleanupPolicy = getBrokerSideConfigValue(brokerSideTopicConfig, TopicConfig.CLEANUP_POLICY_CONFIG, topicName); + if (cleanupPolicy.contains(TopicConfig.CLEANUP_POLICY_DELETE)) { + validationResult.addMisconfiguration( + topicName, + "Cleanup policy (" + TopicConfig.CLEANUP_POLICY_CONFIG + ") of existing internal topic " + + topicName + " should not contain \"" + + TopicConfig.CLEANUP_POLICY_DELETE + "\"." + ); + } + } + + private void validateCleanupPolicyForWindowedChangelogs(final ValidationResult validationResult, + final InternalTopicConfig topicConfig, + final Config brokerSideTopicConfig) { + final String topicName = topicConfig.name(); + final String cleanupPolicy = getBrokerSideConfigValue(brokerSideTopicConfig, TopicConfig.CLEANUP_POLICY_CONFIG, topicName); + if (cleanupPolicy.contains(TopicConfig.CLEANUP_POLICY_DELETE)) { + final long brokerSideRetentionMs = + Long.parseLong(getBrokerSideConfigValue(brokerSideTopicConfig, TopicConfig.RETENTION_MS_CONFIG, topicName)); + final Map streamsSideConfig = + topicConfig.getProperties(defaultTopicConfigs, windowChangeLogAdditionalRetention); + final long streamsSideRetentionMs = Long.parseLong(streamsSideConfig.get(TopicConfig.RETENTION_MS_CONFIG)); + if (brokerSideRetentionMs < streamsSideRetentionMs) { + validationResult.addMisconfiguration( + topicName, + "Retention time (" + TopicConfig.RETENTION_MS_CONFIG + ") of existing internal topic " + + topicName + " is " + brokerSideRetentionMs + " but should be " + streamsSideRetentionMs + " or larger." + ); + } + final String brokerSideRetentionBytes = + getBrokerSideConfigValue(brokerSideTopicConfig, TopicConfig.RETENTION_BYTES_CONFIG, topicName); + if (brokerSideRetentionBytes != null) { + validationResult.addMisconfiguration( + topicName, + "Retention byte (" + TopicConfig.RETENTION_BYTES_CONFIG + ") of existing internal topic " + + topicName + " is set but it should be unset." 
+                );
+            }
+        }
+    }
+
+    private void validateCleanupPolicyForRepartitionTopic(final ValidationResult validationResult,
+                                                          final InternalTopicConfig topicConfig,
+                                                          final Config brokerSideTopicConfig) {
+        final String topicName = topicConfig.name();
+        final String cleanupPolicy = getBrokerSideConfigValue(brokerSideTopicConfig, TopicConfig.CLEANUP_POLICY_CONFIG, topicName);
+        if (cleanupPolicy.contains(TopicConfig.CLEANUP_POLICY_COMPACT)) {
+            validationResult.addMisconfiguration(
+                topicName,
+                "Cleanup policy (" + TopicConfig.CLEANUP_POLICY_CONFIG + ") of existing internal topic "
+                    + topicName + " should not contain \"" + TopicConfig.CLEANUP_POLICY_COMPACT + "\"."
+            );
+        } else if (cleanupPolicy.contains(TopicConfig.CLEANUP_POLICY_DELETE)) {
+            final long brokerSideRetentionMs =
+                Long.parseLong(getBrokerSideConfigValue(brokerSideTopicConfig, TopicConfig.RETENTION_MS_CONFIG, topicName));
+            if (brokerSideRetentionMs != -1) {
+                validationResult.addMisconfiguration(
+                    topicName,
+                    "Retention time (" + TopicConfig.RETENTION_MS_CONFIG + ") of existing internal topic "
+                        + topicName + " is " + brokerSideRetentionMs + " but should be -1."
+                );
+            }
+            final String brokerSideRetentionBytes =
+                getBrokerSideConfigValue(brokerSideTopicConfig, TopicConfig.RETENTION_BYTES_CONFIG, topicName);
+            if (brokerSideRetentionBytes != null) {
+                validationResult.addMisconfiguration(
+                    topicName,
+                    "Retention byte (" + TopicConfig.RETENTION_BYTES_CONFIG + ") of existing internal topic "
+                        + topicName + " is set but it should be unset."
+                );
+            }
+        }
+    }
+
+    private String getBrokerSideConfigValue(final Config brokerSideTopicConfig,
+                                            final String configName,
+                                            final String topicName) {
+        final ConfigEntry brokerSideConfigEntry = brokerSideTopicConfig.get(configName);
+        if (brokerSideConfigEntry == null) {
+            throw new IllegalStateException("The config " + configName + " for topic " + topicName +
+                " could not be retrieved from the brokers. " + BUG_ERROR_MESSAGE);
+        }
+        return brokerSideConfigEntry.value();
+    }
+
+    /**
+     * Prepares the given set of internal topics.
+     *
+     * If a topic does not exist, it is created.
+     * If a topic already exists with the expected number of partitions, it is left untouched.
+     * If a topic already exists with a different number of partitions, an exception is thrown asking the user
+     * to reset the application before restarting it.
+     *
+     * @return the set of topics that had to be newly created
+     */
+    public Set<String> makeReady(final Map<String, InternalTopicConfig> topics) {
+        // we will do the validation / topic-creation in a loop, until we have confirmed that all topics
+        // exist with the expected number of partitions, or some topic creation returns a fatal error.
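+        // In each round we describe the topics that are not ready yet, create the ones that turn out to be
+        // missing, and keep topics that could not be described or created (e.g. because they are marked for
+        // deletion) for the next round, sleeping retry.backoff.ms between rounds until the retry timeout
+        // derived from max.poll.interval.ms is exceeded.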
+ log.debug("Starting to validate internal topics {} in partition assignor.", topics); + + long currentWallClockMs = time.milliseconds(); + final long deadlineMs = currentWallClockMs + retryTimeoutMs; + + Set topicsNotReady = new HashSet<>(topics.keySet()); + final Set newlyCreatedTopics = new HashSet<>(); + + while (!topicsNotReady.isEmpty()) { + final Set tempUnknownTopics = new HashSet<>(); + topicsNotReady = validateTopics(topicsNotReady, topics, tempUnknownTopics); + newlyCreatedTopics.addAll(topicsNotReady); + + if (!topicsNotReady.isEmpty()) { + final Set newTopics = new HashSet<>(); + + for (final String topicName : topicsNotReady) { + if (tempUnknownTopics.contains(topicName)) { + // for the tempUnknownTopics, don't create topic for them + // we'll check again later if remaining retries > 0 + continue; + } + final InternalTopicConfig internalTopicConfig = Objects.requireNonNull(topics.get(topicName)); + final Map topicConfig = internalTopicConfig.getProperties(defaultTopicConfigs, windowChangeLogAdditionalRetention); + + log.debug("Going to create topic {} with {} partitions and config {}.", + internalTopicConfig.name(), + internalTopicConfig.numberOfPartitions(), + topicConfig); + + newTopics.add( + new NewTopic( + internalTopicConfig.name(), + internalTopicConfig.numberOfPartitions(), + Optional.of(replicationFactor)) + .configs(topicConfig)); + } + + // it's possible that although some topics are not ready yet because they + // are temporarily not available, not that they do not exist; in this case + // the new topics to create may be empty and hence we can skip here + if (!newTopics.isEmpty()) { + final CreateTopicsResult createTopicsResult = adminClient.createTopics(newTopics); + + for (final Map.Entry> createTopicResult : createTopicsResult.values().entrySet()) { + final String topicName = createTopicResult.getKey(); + try { + createTopicResult.getValue().get(); + topicsNotReady.remove(topicName); + } catch (final InterruptedException fatalException) { + // this should not happen; if it ever happens it indicate a bug + Thread.currentThread().interrupt(); + log.error(INTERRUPTED_ERROR_MESSAGE, fatalException); + throw new IllegalStateException(INTERRUPTED_ERROR_MESSAGE, fatalException); + } catch (final ExecutionException executionException) { + final Throwable cause = executionException.getCause(); + if (cause instanceof TopicExistsException) { + // This topic didn't exist earlier or its leader not known before; just retain it for next round of validation. + log.info( + "Could not create topic {}. Topic is probably marked for deletion (number of partitions is unknown).\n" + + + "Will retry to create this topic in {} ms (to let broker finish async delete operation first).\n" + + + "Error message was: {}", topicName, retryBackOffMs, + cause.toString()); + } else { + log.error("Unexpected error during topic creation for {}.\n" + + "Error message was: {}", topicName, cause.toString()); + + if (cause instanceof UnsupportedVersionException) { + final String errorMessage = cause.getMessage(); + if (errorMessage != null && + errorMessage.startsWith("Creating topics with default partitions/replication factor are only supported in CreateTopicRequest version 4+")) { + + throw new StreamsException(String.format( + "Could not create topic %s, because brokers don't support configuration replication.factor=-1." 
+ + " You can change the replication.factor config or upgrade your brokers to version 2.4 or newer to avoid this error.", + topicName) + ); + } + } else { + throw new StreamsException( + String.format("Could not create topic %s.", topicName), + cause + ); + } + } + } catch (final TimeoutException retriableException) { + log.error("Creating topic {} timed out.\n" + + "Error message was: {}", topicName, retriableException.toString()); + } + } + } + } + + if (!topicsNotReady.isEmpty()) { + currentWallClockMs = time.milliseconds(); + + if (currentWallClockMs >= deadlineMs) { + final String timeoutError = String.format("Could not create topics within %d milliseconds. " + + "This can happen if the Kafka cluster is temporarily not available.", retryTimeoutMs); + log.error(timeoutError); + throw new TimeoutException(timeoutError); + } + log.info( + "Topics {} could not be made ready. Will retry in {} milliseconds. Remaining time in milliseconds: {}", + topicsNotReady, + retryBackOffMs, + deadlineMs - currentWallClockMs + ); + Utils.sleep(retryBackOffMs); + } + } + log.debug("Completed validating internal topics and created {}", newlyCreatedTopics); + + return newlyCreatedTopics; + } + + /** + * Try to get the number of partitions for the given topics; return the number of partitions for topics that already exists. + * + * Topics that were not able to get its description will simply not be returned + */ + // visible for testing + protected Map getNumPartitions(final Set topics, + final Set tempUnknownTopics) { + log.debug("Trying to check if topics {} have been created with expected number of partitions.", topics); + + final DescribeTopicsResult describeTopicsResult = adminClient.describeTopics(topics); + final Map> futures = describeTopicsResult.topicNameValues(); + + final Map existedTopicPartition = new HashMap<>(); + for (final Map.Entry> topicFuture : futures.entrySet()) { + final String topicName = topicFuture.getKey(); + try { + final TopicDescription topicDescription = topicFuture.getValue().get(); + existedTopicPartition.put(topicName, topicDescription.partitions().size()); + } catch (final InterruptedException fatalException) { + // this should not happen; if it ever happens it indicate a bug + Thread.currentThread().interrupt(); + log.error(INTERRUPTED_ERROR_MESSAGE, fatalException); + throw new IllegalStateException(INTERRUPTED_ERROR_MESSAGE, fatalException); + } catch (final ExecutionException couldNotDescribeTopicException) { + final Throwable cause = couldNotDescribeTopicException.getCause(); + if (cause instanceof UnknownTopicOrPartitionException) { + // This topic didn't exist, proceed to try to create it + log.debug("Topic {} is unknown or not found, hence not existed yet.\n" + + "Error message was: {}", topicName, cause.toString()); + } else if (cause instanceof LeaderNotAvailableException) { + tempUnknownTopics.add(topicName); + log.debug("The leader of topic {} is not available.\n" + + "Error message was: {}", topicName, cause.toString()); + } else { + log.error("Unexpected error during topic description for {}.\n" + + "Error message was: {}", topicName, cause.toString()); + throw new StreamsException(String.format("Could not create topic %s.", topicName), cause); + } + } catch (final TimeoutException retriableException) { + tempUnknownTopics.add(topicName); + log.debug("Describing topic {} (to get number of partitions) timed out.\n" + + "Error message was: {}", topicName, retriableException.toString()); + } + } + + return existedTopicPartition; + } + + /** + * Check the 
existing topics to have correct number of partitions; and return the remaining topics that needs to be created + */ + private Set validateTopics(final Set topicsToValidate, + final Map topicsMap, + final Set tempUnknownTopics) { + if (!topicsMap.keySet().containsAll(topicsToValidate)) { + throw new IllegalStateException("The topics map " + topicsMap.keySet() + " does not contain all the topics " + + topicsToValidate + " trying to validate."); + } + + final Map existedTopicPartition = getNumPartitions(topicsToValidate, tempUnknownTopics); + + final Set topicsToCreate = new HashSet<>(); + for (final String topicName : topicsToValidate) { + final Optional numberOfPartitions = topicsMap.get(topicName).numberOfPartitions(); + if (!numberOfPartitions.isPresent()) { + log.error("Found undefined number of partitions for topic {}", topicName); + throw new StreamsException("Topic " + topicName + " number of partitions not defined"); + } + if (existedTopicPartition.containsKey(topicName)) { + if (!existedTopicPartition.get(topicName).equals(numberOfPartitions.get())) { + final String errorMsg = String.format("Existing internal topic %s has invalid partitions: " + + "expected: %d; actual: %d. " + + "Use 'kafka.tools.StreamsResetter' tool to clean up invalid topics before processing.", + topicName, numberOfPartitions.get(), existedTopicPartition.get(topicName)); + log.error(errorMsg); + throw new StreamsException(errorMsg); + } + } else { + topicsToCreate.add(topicName); + } + } + + return topicsToCreate; + } + + /** + * Sets up internal topics. + * + * Either the given topic are all created or the method fails with an exception. + * + * @param topicConfigs internal topics to setup + */ + public void setup(final Map topicConfigs) { + log.info("Starting to setup internal topics {}.", topicConfigs.keySet()); + + final long now = time.milliseconds(); + final long deadline = now + retryTimeoutMs; + + final Map> streamsSideTopicConfigs = topicConfigs.values().stream() + .collect(Collectors.toMap( + InternalTopicConfig::name, + topicConfig -> topicConfig.getProperties(defaultTopicConfigs, windowChangeLogAdditionalRetention) + )); + final Set createdTopics = new HashSet<>(); + final Set topicStillToCreate = new HashSet<>(topicConfigs.keySet()); + while (!topicStillToCreate.isEmpty()) { + final Set newTopics = topicStillToCreate.stream() + .map(topicName -> new NewTopic( + topicName, + topicConfigs.get(topicName).numberOfPartitions(), + Optional.of(replicationFactor) + ).configs(streamsSideTopicConfigs.get(topicName)) + ).collect(Collectors.toSet()); + + log.info("Going to create internal topics: " + newTopics); + final CreateTopicsResult createTopicsResult = adminClient.createTopics(newTopics); + + processCreateTopicResults(createTopicsResult, topicStillToCreate, createdTopics, deadline); + + maybeSleep(Collections.singletonList(topicStillToCreate), deadline, "created"); + } + + log.info("Completed setup of internal topics {}.", topicConfigs.keySet()); + } + + private void processCreateTopicResults(final CreateTopicsResult createTopicsResult, + final Set topicStillToCreate, + final Set createdTopics, + final long deadline) { + final Map lastErrorsSeenForTopic = new HashMap<>(); + final Map> createResultForTopic = createTopicsResult.values(); + while (!createResultForTopic.isEmpty()) { + for (final String topicName : new HashSet<>(topicStillToCreate)) { + if (!createResultForTopic.containsKey(topicName)) { + cleanUpCreatedTopics(createdTopics); + throw new IllegalStateException("Create topic results do not 
contain internal topic " + topicName + + " to setup. " + BUG_ERROR_MESSAGE); + } + final KafkaFuture createResult = createResultForTopic.get(topicName); + if (createResult.isDone()) { + try { + createResult.get(); + createdTopics.add(topicName); + topicStillToCreate.remove(topicName); + } catch (final ExecutionException executionException) { + final Throwable cause = executionException.getCause(); + if (cause instanceof TopicExistsException) { + lastErrorsSeenForTopic.put(topicName, cause); + log.info("Internal topic {} already exists. Topic is probably marked for deletion. " + + "Will retry to create this topic later (to let broker complete async delete operation first)", + topicName); + } else if (cause instanceof TimeoutException) { + lastErrorsSeenForTopic.put(topicName, cause); + log.info("Creating internal topic {} timed out.", topicName); + } else { + cleanUpCreatedTopics(createdTopics); + log.error("Unexpected error during creation of internal topic: ", cause); + throw new StreamsException( + String.format("Could not create internal topic %s for the following reason: ", topicName), + cause + ); + } + } catch (final InterruptedException interruptedException) { + throw new InterruptException(interruptedException); + } finally { + createResultForTopic.remove(topicName); + } + } + } + + maybeThrowTimeoutExceptionDuringSetup( + topicStillToCreate, + createdTopics, + lastErrorsSeenForTopic, + deadline + ); + + if (!createResultForTopic.isEmpty()) { + Utils.sleep(100); + } + } + } + + private void cleanUpCreatedTopics(final Set topicsToCleanUp) { + log.info("Starting to clean up internal topics {}.", topicsToCleanUp); + + final long now = time.milliseconds(); + final long deadline = now + retryTimeoutMs; + + final Set topicsStillToCleanup = new HashSet<>(topicsToCleanUp); + while (!topicsStillToCleanup.isEmpty()) { + log.info("Going to cleanup internal topics: " + topicsStillToCleanup); + final DeleteTopicsResult deleteTopicsResult = adminClient.deleteTopics(topicsStillToCleanup); + final Map> deleteResultForTopic = deleteTopicsResult.topicNameValues(); + while (!deleteResultForTopic.isEmpty()) { + for (final String topicName : new HashSet<>(topicsStillToCleanup)) { + if (!deleteResultForTopic.containsKey(topicName)) { + throw new IllegalStateException("Delete topic results do not contain internal topic " + topicName + + " to clean up. 
" + BUG_ERROR_MESSAGE); + } + final KafkaFuture deleteResult = deleteResultForTopic.get(topicName); + if (deleteResult.isDone()) { + try { + deleteResult.get(); + topicsStillToCleanup.remove(topicName); + } catch (final ExecutionException executionException) { + final Throwable cause = executionException.getCause(); + if (cause instanceof UnknownTopicOrPartitionException) { + log.info("Internal topic {} to clean up is missing", topicName); + } else if (cause instanceof LeaderNotAvailableException) { + log.info("The leader of internal topic {} to clean up is not available.", topicName); + } else if (cause instanceof TimeoutException) { + log.info("Cleaning up internal topic {} timed out.", topicName); + } else { + log.error("Unexpected error during cleanup of internal topics: ", cause); + throw new StreamsException( + String.format("Could not clean up internal topics %s, because during the cleanup " + + "of topic %s the following error occurred: ", + topicsStillToCleanup, topicName), + cause + ); + } + } catch (final InterruptedException interruptedException) { + throw new InterruptException(interruptedException); + } finally { + deleteResultForTopic.remove(topicName); + } + } + } + + maybeThrowTimeoutException( + Collections.singletonList(topicsStillToCleanup), + deadline, + String.format("Could not cleanup internal topics within %d milliseconds. This can happen if the " + + "Kafka cluster is temporarily not available or the broker did not complete topic creation " + + "before the cleanup. The following internal topics could not be cleaned up: %s", + retryTimeoutMs, topicsStillToCleanup) + ); + + if (!deleteResultForTopic.isEmpty()) { + Utils.sleep(100); + } + } + + maybeSleep( + Collections.singletonList(topicsStillToCleanup), + deadline, + "validated" + ); + } + + log.info("Completed cleanup of internal topics {}.", topicsToCleanUp); + } + + private void maybeThrowTimeoutException(final List> topicStillToProcess, + final long deadline, + final String errorMessage) { + if (topicStillToProcess.stream().anyMatch(resultSet -> !resultSet.isEmpty())) { + final long now = time.milliseconds(); + if (now >= deadline) { + log.error(errorMessage); + throw new TimeoutException(errorMessage); + } + } + } + + private void maybeThrowTimeoutExceptionDuringSetup(final Set topicStillToProcess, + final Set createdTopics, + final Map lastErrorsSeenForTopic, + final long deadline) { + if (topicStillToProcess.stream().anyMatch(resultSet -> !resultSet.isEmpty())) { + final long now = time.milliseconds(); + if (now >= deadline) { + cleanUpCreatedTopics(createdTopics); + final String errorMessage = String.format("Could not create internal topics within %d milliseconds. This can happen if the " + + "Kafka cluster is temporarily not available or a topic is marked for deletion and the broker " + + "did not complete its deletion within the timeout. The last errors seen per topic are: %s", + retryTimeoutMs, lastErrorsSeenForTopic); + log.error(errorMessage); + throw new TimeoutException(errorMessage); + } + } + } + + private void maybeSleep(final List> resultSetsStillToValidate, + final long deadline, + final String action) { + if (resultSetsStillToValidate.stream().anyMatch(resultSet -> !resultSet.isEmpty())) { + final long now = time.milliseconds(); + log.info( + "Internal topics {} could not be {}. Will retry in {} milliseconds. 
Remaining time in milliseconds: {}", + resultSetsStillToValidate.stream().flatMap(Collection::stream).collect(Collectors.toSet()), + action, + retryBackOffMs, + deadline - now + ); + Utils.sleep(retryBackOffMs); + } + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/InternalTopicProperties.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/InternalTopicProperties.java new file mode 100644 index 0000000..b98780c --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/InternalTopicProperties.java @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.processor.internals; + +import java.util.Optional; + +public class InternalTopicProperties { + private final Integer numberOfPartitions; + + public InternalTopicProperties(final Integer numberOfPartitions) { + this.numberOfPartitions = numberOfPartitions; + } + + public Optional getNumberOfPartitions() { + return Optional.ofNullable(numberOfPartitions); + } + + public static InternalTopicProperties empty() { + return new InternalTopicProperties(null); + } + + @Override + public String toString() { + return "InternalTopicProperties{" + + "numberOfPartitions=" + numberOfPartitions + + '}'; + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/InternalTopologyBuilder.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/InternalTopologyBuilder.java new file mode 100644 index 0000000..dd07c10 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/InternalTopologyBuilder.java @@ -0,0 +1,2129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.processor.internals; + +import org.apache.kafka.clients.consumer.OffsetResetStrategy; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.serialization.Deserializer; +import org.apache.kafka.common.serialization.Serializer; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.Topology; +import org.apache.kafka.streams.errors.TopologyException; +import org.apache.kafka.streams.internals.ApiUtils; +import org.apache.kafka.streams.processor.StateStore; +import org.apache.kafka.streams.processor.StreamPartitioner; +import org.apache.kafka.streams.processor.TimestampExtractor; +import org.apache.kafka.streams.processor.TopicNameExtractor; +import org.apache.kafka.streams.processor.api.ProcessorSupplier; +import org.apache.kafka.streams.processor.internals.TopologyMetadata.Subtopology; +import org.apache.kafka.streams.processor.internals.namedtopology.NamedTopology; +import org.apache.kafka.streams.state.StoreBuilder; +import org.apache.kafka.streams.state.internals.SessionStoreBuilder; +import org.apache.kafka.streams.state.internals.TimestampedWindowStoreBuilder; +import org.apache.kafka.streams.state.internals.WindowStoreBuilder; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.Serializable; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.Comparator; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Iterator; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.Set; +import java.util.TreeMap; +import java.util.TreeSet; +import java.util.regex.Pattern; +import java.util.stream.Collectors; + +import static org.apache.kafka.clients.consumer.OffsetResetStrategy.EARLIEST; +import static org.apache.kafka.clients.consumer.OffsetResetStrategy.LATEST; +import static org.apache.kafka.clients.consumer.OffsetResetStrategy.NONE; + +public class InternalTopologyBuilder { + + private static final Logger log = LoggerFactory.getLogger(InternalTopologyBuilder.class); + private static final String[] NO_PREDECESSORS = {}; + + // node factories in a topological order + private final Map> nodeFactories = new LinkedHashMap<>(); + + private final Map> stateFactories = new HashMap<>(); + + private final Map> globalStateBuilders = new LinkedHashMap<>(); + + // built global state stores + private final Map globalStateStores = new LinkedHashMap<>(); + + // all topics subscribed from source processors (without application-id prefix for internal topics) + private final Set sourceTopicNames = new HashSet<>(); + + // all internal topics with their corresponding properties auto-created by the topology builder and used in source / sink processors + private final Map internalTopicNamesWithProperties = new HashMap<>(); + + // groups of source processors that need to be copartitioned + private final List> copartitionSourceGroups = new ArrayList<>(); + + // map from source processor names to subscribed topics (without application-id prefix for internal topics) + private final Map> nodeToSourceTopics = new HashMap<>(); + + // map from source processor names to regex subscription patterns + private final Map nodeToSourcePatterns = new LinkedHashMap<>(); + + // map from sink processor names to sink topic (without application-id prefix for internal topics) + private final Map nodeToSinkTopic = new 
HashMap<>(); + + // map from state store names to all the topics subscribed from source processors that + // are connected to these state stores + private final Map> stateStoreNameToSourceTopics = new HashMap<>(); + + // map from state store names to all the regex subscribed topics from source processors that + // are connected to these state stores + private final Map> stateStoreNameToSourceRegex = new HashMap<>(); + + // map from state store names to this state store's corresponding changelog topic if possible + private final Map storeToChangelogTopic = new HashMap<>(); + + // map from changelog topic name to its corresponding state store. + private final Map changelogTopicToStore = new HashMap<>(); + + // all global topics + private final Set globalTopics = new HashSet<>(); + + private final Set earliestResetTopics = new HashSet<>(); + + private final Set latestResetTopics = new HashSet<>(); + + private final Set earliestResetPatterns = new HashSet<>(); + + private final Set latestResetPatterns = new HashSet<>(); + + private final QuickUnion nodeGrouper = new QuickUnion<>(); + + // Used to capture subscribed topics via Patterns discovered during the partition assignment process. + private final Set subscriptionUpdates = new HashSet<>(); + + private String applicationId = null; + + private String sourceTopicPatternString = null; + + private List sourceTopicCollection = null; + + private Map> nodeGroups = null; + + private StreamsConfig config = null; + + // The name of the topology this builder belongs to, or null if none + private String topologyName; + private NamedTopology namedTopology; + + private boolean hasPersistentStores = false; + + public static class StateStoreFactory { + private final StoreBuilder builder; + private final Set users = new HashSet<>(); + + private StateStoreFactory(final StoreBuilder builder) { + this.builder = builder; + } + + public S build() { + return builder.build(); + } + + long retentionPeriod() { + if (builder instanceof WindowStoreBuilder) { + return ((WindowStoreBuilder) builder).retentionPeriod(); + } else if (builder instanceof TimestampedWindowStoreBuilder) { + return ((TimestampedWindowStoreBuilder) builder).retentionPeriod(); + } else if (builder instanceof SessionStoreBuilder) { + return ((SessionStoreBuilder) builder).retentionPeriod(); + } else { + throw new IllegalStateException("retentionPeriod is not supported when not a window store"); + } + } + + private Set users() { + return users; + } + + public boolean loggingEnabled() { + return builder.loggingEnabled(); + } + + private String name() { + return builder.name(); + } + + private boolean isWindowStore() { + return builder instanceof WindowStoreBuilder + || builder instanceof TimestampedWindowStoreBuilder + || builder instanceof SessionStoreBuilder; + } + + // Apparently Java strips the generics from this method because we're using the raw type for builder, + // even though this method doesn't use builder's (missing) type parameter. Our usage seems obviously + // correct, though, hence the suppression. 
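+        // The returned map contains the changelog topic settings that were attached to the store
+        // via its StoreBuilder (e.g. through withLoggingEnabled), if any.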
+ private Map logConfig() { + return builder.logConfig(); + } + } + + private static abstract class NodeFactory { + final String name; + final String[] predecessors; + + NodeFactory(final String name, + final String[] predecessors) { + this.name = name; + this.predecessors = predecessors; + } + + public abstract ProcessorNode build(); + + abstract AbstractNode describe(); + } + + private static class ProcessorNodeFactory extends NodeFactory { + private final ProcessorSupplier supplier; + private final Set stateStoreNames = new HashSet<>(); + + ProcessorNodeFactory(final String name, + final String[] predecessors, + final ProcessorSupplier supplier) { + super(name, predecessors.clone()); + this.supplier = supplier; + } + + public void addStateStore(final String stateStoreName) { + stateStoreNames.add(stateStoreName); + } + + @Override + public ProcessorNode build() { + return new ProcessorNode<>(name, supplier.get(), stateStoreNames); + } + + @Override + Processor describe() { + return new Processor(name, new HashSet<>(stateStoreNames)); + } + } + + // Map from topics to their matched regex patterns, this is to ensure one topic is passed through on source node + // even if it can be matched by multiple regex patterns. Only used by SourceNodeFactory + private final Map topicToPatterns = new HashMap<>(); + + private class SourceNodeFactory extends NodeFactory { + private final List topics; + private final Pattern pattern; + private final Deserializer keyDeserializer; + private final Deserializer valDeserializer; + private final TimestampExtractor timestampExtractor; + + private SourceNodeFactory(final String name, + final String[] topics, + final Pattern pattern, + final TimestampExtractor timestampExtractor, + final Deserializer keyDeserializer, + final Deserializer valDeserializer) { + super(name, NO_PREDECESSORS); + this.topics = topics != null ? Arrays.asList(topics) : new ArrayList<>(); + this.pattern = pattern; + this.keyDeserializer = keyDeserializer; + this.valDeserializer = valDeserializer; + this.timestampExtractor = timestampExtractor; + } + + List getTopics(final Collection subscribedTopics) { + // if it is subscribed via patterns, it is possible that the topic metadata has not been updated + // yet and hence the map from source node to topics is stale, in this case we put the pattern as a place holder; + // this should only happen for debugging since during runtime this function should always be called after the metadata has updated. 
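+            // For example (hypothetical values): given the pattern "input-.*" and subscribedTopics
+            // ["input-1", "other"], only "input-1" is returned and remembered in topicToPatterns;
+            // if "input-1" had already been claimed by a different pattern, a TopologyException is thrown.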
+ if (subscribedTopics.isEmpty()) { + return Collections.singletonList(String.valueOf(pattern)); + } + + final List matchedTopics = new ArrayList<>(); + for (final String update : subscribedTopics) { + if (pattern == topicToPatterns.get(update)) { + matchedTopics.add(update); + } else if (topicToPatterns.containsKey(update) && isMatch(update)) { + // the same topic cannot be matched to more than one pattern + // TODO: we should lift this requirement in the future + throw new TopologyException("Topic " + update + + " is already matched for another regex pattern " + topicToPatterns.get(update) + + " and hence cannot be matched to this regex pattern " + pattern + " any more."); + } else if (isMatch(update)) { + topicToPatterns.put(update, pattern); + matchedTopics.add(update); + } + } + return matchedTopics; + } + + @Override + public ProcessorNode build() { + return new SourceNode<>(name, timestampExtractor, keyDeserializer, valDeserializer); + } + + private boolean isMatch(final String topic) { + return pattern.matcher(topic).matches(); + } + + @Override + Source describe() { + return new Source(name, topics.size() == 0 ? null : new HashSet<>(topics), pattern); + } + } + + private class SinkNodeFactory extends NodeFactory { + private final Serializer keySerializer; + private final Serializer valSerializer; + private final StreamPartitioner partitioner; + private final TopicNameExtractor topicExtractor; + + private SinkNodeFactory(final String name, + final String[] predecessors, + final TopicNameExtractor topicExtractor, + final Serializer keySerializer, + final Serializer valSerializer, + final StreamPartitioner partitioner) { + super(name, predecessors.clone()); + this.topicExtractor = topicExtractor; + this.keySerializer = keySerializer; + this.valSerializer = valSerializer; + this.partitioner = partitioner; + } + + @Override + public ProcessorNode build() { + if (topicExtractor instanceof StaticTopicNameExtractor) { + final String topic = ((StaticTopicNameExtractor) topicExtractor).topicName; + if (internalTopicNamesWithProperties.containsKey(topic)) { + // prefix the internal topic name with the application id + return new SinkNode<>(name, new StaticTopicNameExtractor<>(decorateTopic(topic)), keySerializer, valSerializer, partitioner); + } else { + return new SinkNode<>(name, topicExtractor, keySerializer, valSerializer, partitioner); + } + } else { + return new SinkNode<>(name, topicExtractor, keySerializer, valSerializer, partitioner); + } + } + + @Override + Sink describe() { + return new Sink<>(name, topicExtractor); + } + } + + + public void setNamedTopology(final NamedTopology topology) { + final String topologyName = topology.name(); + Objects.requireNonNull(topologyName, "topology name can't be null"); + Objects.requireNonNull(topology, "named topology can't be null"); + if (this.topologyName != null) { + log.error("Tried to reset the topologyName to {} but it was already set to {}", topologyName, this.topologyName); + throw new IllegalStateException("The topologyName has already been set to " + this.topologyName); + } + this.namedTopology = topology; + this.topologyName = topologyName; + } + + // public for testing only + public final InternalTopologyBuilder setApplicationId(final String applicationId) { + Objects.requireNonNull(applicationId, "applicationId can't be null"); + this.applicationId = applicationId; + + return this; + } + + public synchronized final InternalTopologyBuilder setStreamsConfig(final StreamsConfig config) { + Objects.requireNonNull(config, "config 
can't be null"); + this.config = config; + + return this; + } + + public synchronized final StreamsConfig getStreamsConfig() { + return config; + } + + public String topologyName() { + return topologyName; + } + + public NamedTopology namedTopology() { + return namedTopology; + } + + public synchronized final InternalTopologyBuilder rewriteTopology(final StreamsConfig config) { + Objects.requireNonNull(config, "config can't be null"); + + // set application id + setApplicationId(config.getString(StreamsConfig.APPLICATION_ID_CONFIG)); + + // maybe strip out caching layers + if (config.getLong(StreamsConfig.CACHE_MAX_BYTES_BUFFERING_CONFIG) == 0L) { + for (final StateStoreFactory storeFactory : stateFactories.values()) { + storeFactory.builder.withCachingDisabled(); + } + + for (final StoreBuilder storeBuilder : globalStateBuilders.values()) { + storeBuilder.withCachingDisabled(); + } + } + + // build global state stores + for (final StoreBuilder storeBuilder : globalStateBuilders.values()) { + globalStateStores.put(storeBuilder.name(), storeBuilder.build()); + } + + // set streams config + setStreamsConfig(config); + + return this; + } + + public final void addSource(final Topology.AutoOffsetReset offsetReset, + final String name, + final TimestampExtractor timestampExtractor, + final Deserializer keyDeserializer, + final Deserializer valDeserializer, + final String... topics) { + if (topics.length == 0) { + throw new TopologyException("You must provide at least one topic"); + } + Objects.requireNonNull(name, "name must not be null"); + if (nodeFactories.containsKey(name)) { + throw new TopologyException("Processor " + name + " is already added."); + } + + for (final String topic : topics) { + Objects.requireNonNull(topic, "topic names cannot be null"); + validateTopicNotAlreadyRegistered(topic); + maybeAddToResetList(earliestResetTopics, latestResetTopics, offsetReset, topic); + sourceTopicNames.add(topic); + } + + nodeFactories.put(name, new SourceNodeFactory<>(name, topics, null, timestampExtractor, keyDeserializer, valDeserializer)); + nodeToSourceTopics.put(name, Arrays.asList(topics)); + nodeGrouper.add(name); + nodeGroups = null; + } + + public final void addSource(final Topology.AutoOffsetReset offsetReset, + final String name, + final TimestampExtractor timestampExtractor, + final Deserializer keyDeserializer, + final Deserializer valDeserializer, + final Pattern topicPattern) { + Objects.requireNonNull(topicPattern, "topicPattern can't be null"); + Objects.requireNonNull(name, "name can't be null"); + + if (nodeFactories.containsKey(name)) { + throw new TopologyException("Processor " + name + " is already added."); + } + + for (final String sourceTopicName : sourceTopicNames) { + if (topicPattern.matcher(sourceTopicName).matches()) { + throw new TopologyException("Pattern " + topicPattern + " will match a topic that has already been registered by another source."); + } + } + + maybeAddToResetList(earliestResetPatterns, latestResetPatterns, offsetReset, topicPattern); + + nodeFactories.put(name, new SourceNodeFactory<>(name, null, topicPattern, timestampExtractor, keyDeserializer, valDeserializer)); + nodeToSourcePatterns.put(name, topicPattern); + nodeGrouper.add(name); + nodeGroups = null; + } + + public final void addSink(final String name, + final String topic, + final Serializer keySerializer, + final Serializer valSerializer, + final StreamPartitioner partitioner, + final String... 
predecessorNames) { + Objects.requireNonNull(name, "name must not be null"); + Objects.requireNonNull(topic, "topic must not be null"); + Objects.requireNonNull(predecessorNames, "predecessor names must not be null"); + if (predecessorNames.length == 0) { + throw new TopologyException("Sink " + name + " must have at least one parent"); + } + + addSink(name, new StaticTopicNameExtractor<>(topic), keySerializer, valSerializer, partitioner, predecessorNames); + nodeToSinkTopic.put(name, topic); + nodeGroups = null; + } + + public final void addSink(final String name, + final TopicNameExtractor topicExtractor, + final Serializer keySerializer, + final Serializer valSerializer, + final StreamPartitioner partitioner, + final String... predecessorNames) { + Objects.requireNonNull(name, "name must not be null"); + Objects.requireNonNull(topicExtractor, "topic extractor must not be null"); + Objects.requireNonNull(predecessorNames, "predecessor names must not be null"); + if (nodeFactories.containsKey(name)) { + throw new TopologyException("Processor " + name + " is already added."); + } + if (predecessorNames.length == 0) { + throw new TopologyException("Sink " + name + " must have at least one parent"); + } + + for (final String predecessor : predecessorNames) { + Objects.requireNonNull(predecessor, "predecessor name can't be null"); + if (predecessor.equals(name)) { + throw new TopologyException("Processor " + name + " cannot be a predecessor of itself."); + } + if (!nodeFactories.containsKey(predecessor)) { + throw new TopologyException("Predecessor processor " + predecessor + " is not added yet."); + } + if (nodeToSinkTopic.containsKey(predecessor)) { + throw new TopologyException("Sink " + predecessor + " cannot be used a parent."); + } + } + + nodeFactories.put(name, new SinkNodeFactory<>(name, predecessorNames, topicExtractor, keySerializer, valSerializer, partitioner)); + nodeGrouper.add(name); + nodeGrouper.unite(name, predecessorNames); + nodeGroups = null; + } + + public final void addProcessor(final String name, + final ProcessorSupplier supplier, + final String... predecessorNames) { + Objects.requireNonNull(name, "name must not be null"); + Objects.requireNonNull(supplier, "supplier must not be null"); + Objects.requireNonNull(predecessorNames, "predecessor names must not be null"); + ApiUtils.checkSupplier(supplier); + if (nodeFactories.containsKey(name)) { + throw new TopologyException("Processor " + name + " is already added."); + } + if (predecessorNames.length == 0) { + throw new TopologyException("Processor " + name + " must have at least one parent"); + } + + for (final String predecessor : predecessorNames) { + Objects.requireNonNull(predecessor, "predecessor name must not be null"); + if (predecessor.equals(name)) { + throw new TopologyException("Processor " + name + " cannot be a predecessor of itself."); + } + if (!nodeFactories.containsKey(predecessor)) { + throw new TopologyException("Predecessor processor " + predecessor + " is not added yet for " + name); + } + } + + nodeFactories.put(name, new ProcessorNodeFactory<>(name, predecessorNames, supplier)); + nodeGrouper.add(name); + nodeGrouper.unite(name, predecessorNames); + nodeGroups = null; + } + + public final void addStateStore(final StoreBuilder storeBuilder, + final String... processorNames) { + addStateStore(storeBuilder, false, processorNames); + } + + public final void addStateStore(final StoreBuilder storeBuilder, + final boolean allowOverride, + final String... 
processorNames) { + Objects.requireNonNull(storeBuilder, "storeBuilder can't be null"); + final StateStoreFactory stateFactory = stateFactories.get(storeBuilder.name()); + if (!allowOverride && stateFactory != null && !stateFactory.builder.equals(storeBuilder)) { + throw new TopologyException("A different StateStore has already been added with the name " + storeBuilder.name()); + } + if (globalStateBuilders.containsKey(storeBuilder.name())) { + throw new TopologyException("A different GlobalStateStore has already been added with the name " + storeBuilder.name()); + } + + stateFactories.put(storeBuilder.name(), new StateStoreFactory<>(storeBuilder)); + + if (processorNames != null) { + for (final String processorName : processorNames) { + Objects.requireNonNull(processorName, "processor name must not be null"); + connectProcessorAndStateStore(processorName, storeBuilder.name()); + } + } + nodeGroups = null; + } + + public final void addGlobalStore(final StoreBuilder storeBuilder, + final String sourceName, + final TimestampExtractor timestampExtractor, + final Deserializer keyDeserializer, + final Deserializer valueDeserializer, + final String topic, + final String processorName, + final ProcessorSupplier stateUpdateSupplier) { + Objects.requireNonNull(storeBuilder, "store builder must not be null"); + ApiUtils.checkSupplier(stateUpdateSupplier); + validateGlobalStoreArguments(sourceName, + topic, + processorName, + stateUpdateSupplier, + storeBuilder.name(), + storeBuilder.loggingEnabled()); + validateTopicNotAlreadyRegistered(topic); + + final String[] topics = {topic}; + final String[] predecessors = {sourceName}; + + final ProcessorNodeFactory nodeFactory = new ProcessorNodeFactory<>( + processorName, + predecessors, + stateUpdateSupplier + ); + + globalTopics.add(topic); + nodeFactories.put(sourceName, new SourceNodeFactory<>( + sourceName, + topics, + null, + timestampExtractor, + keyDeserializer, + valueDeserializer) + ); + nodeToSourceTopics.put(sourceName, Arrays.asList(topics)); + nodeGrouper.add(sourceName); + nodeFactory.addStateStore(storeBuilder.name()); + nodeFactories.put(processorName, nodeFactory); + nodeGrouper.add(processorName); + nodeGrouper.unite(processorName, predecessors); + globalStateBuilders.put(storeBuilder.name(), storeBuilder); + connectSourceStoreAndTopic(storeBuilder.name(), topic); + nodeGroups = null; + } + + private void validateTopicNotAlreadyRegistered(final String topic) { + if (sourceTopicNames.contains(topic) || globalTopics.contains(topic)) { + throw new TopologyException("Topic " + topic + " has already been registered by another source."); + } + + for (final Pattern pattern : nodeToSourcePatterns.values()) { + if (pattern.matcher(topic).matches()) { + throw new TopologyException("Topic " + topic + " matches a Pattern already registered by another source."); + } + } + } + + public final void connectProcessorAndStateStores(final String processorName, + final String... 
stateStoreNames) { + Objects.requireNonNull(processorName, "processorName can't be null"); + Objects.requireNonNull(stateStoreNames, "state store list must not be null"); + if (stateStoreNames.length == 0) { + throw new TopologyException("Must provide at least one state store name."); + } + for (final String stateStoreName : stateStoreNames) { + Objects.requireNonNull(stateStoreName, "state store name must not be null"); + connectProcessorAndStateStore(processorName, stateStoreName); + } + nodeGroups = null; + } + + public String getStoreForChangelogTopic(final String topicName) { + return changelogTopicToStore.get(topicName); + } + + public void connectSourceStoreAndTopic(final String sourceStoreName, + final String topic) { + if (storeToChangelogTopic.containsKey(sourceStoreName)) { + throw new TopologyException("Source store " + sourceStoreName + " is already added."); + } + storeToChangelogTopic.put(sourceStoreName, topic); + changelogTopicToStore.put(topic, sourceStoreName); + } + + public final void addInternalTopic(final String topicName, + final InternalTopicProperties internalTopicProperties) { + Objects.requireNonNull(topicName, "topicName can't be null"); + Objects.requireNonNull(internalTopicProperties, "internalTopicProperties can't be null"); + + internalTopicNamesWithProperties.put(topicName, internalTopicProperties); + } + + public final void copartitionSources(final Collection sourceNodes) { + copartitionSourceGroups.add(new HashSet<>(sourceNodes)); + } + + public final void maybeUpdateCopartitionSourceGroups(final String replacedNodeName, + final String optimizedNodeName) { + for (final Set copartitionSourceGroup : copartitionSourceGroups) { + if (copartitionSourceGroup.contains(replacedNodeName)) { + copartitionSourceGroup.remove(replacedNodeName); + copartitionSourceGroup.add(optimizedNodeName); + } + } + } + + public void validateCopartition() { + // allCopartitionedSourceTopics take the list of co-partitioned nodes and + // replaces each processor name with the corresponding source topic name + final List> allCopartitionedSourceTopics = + copartitionSourceGroups + .stream() + .map(sourceGroup -> sourceGroup + .stream() + .flatMap(sourceNodeName -> nodeToSourceTopics.getOrDefault(sourceNodeName, + Collections.emptyList()).stream()) + .collect(Collectors.toSet()) + ).collect(Collectors.toList()); + for (final Set copartition : allCopartitionedSourceTopics) { + final Map numberOfPartitionsPerTopic = new HashMap<>(); + copartition.forEach(topic -> { + final InternalTopicProperties prop = internalTopicNamesWithProperties.get(topic); + if (prop != null && prop.getNumberOfPartitions().isPresent()) { + numberOfPartitionsPerTopic.put(topic, prop.getNumberOfPartitions().get()); + } + }); + if (!numberOfPartitionsPerTopic.isEmpty() && copartition.equals(numberOfPartitionsPerTopic.keySet())) { + final Collection partitionNumbers = numberOfPartitionsPerTopic.values(); + final Integer first = partitionNumbers.iterator().next(); + for (final Integer partitionNumber : partitionNumbers) { + if (!partitionNumber.equals(first)) { + final String msg = String.format("Following topics do not have the same number of " + + "partitions: [%s]", new TreeMap<>(numberOfPartitionsPerTopic)); + throw new TopologyException(msg); + + } + } + } + } + } + + private void validateGlobalStoreArguments(final String sourceName, + final String topic, + final String processorName, + final ProcessorSupplier stateUpdateSupplier, + final String storeName, + final boolean loggingEnabled) { + 
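+        // A global store needs its own source and processor names, must not clash with any store or
+        // processor that was already registered, and must have logging disabled because the source
+        // topic itself serves as the store's changelog.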
Objects.requireNonNull(sourceName, "sourceName must not be null"); + Objects.requireNonNull(topic, "topic must not be null"); + Objects.requireNonNull(stateUpdateSupplier, "supplier must not be null"); + Objects.requireNonNull(processorName, "processorName must not be null"); + if (nodeFactories.containsKey(sourceName)) { + throw new TopologyException("Processor " + sourceName + " is already added."); + } + if (nodeFactories.containsKey(processorName)) { + throw new TopologyException("Processor " + processorName + " is already added."); + } + if (stateFactories.containsKey(storeName)) { + throw new TopologyException("A different StateStore has already been added with the name " + storeName); + } + if (globalStateBuilders.containsKey(storeName)) { + throw new TopologyException("A different GlobalStateStore has already been added with the name " + storeName); + } + if (loggingEnabled) { + throw new TopologyException("StateStore " + storeName + " for global table must not have logging enabled."); + } + if (sourceName.equals(processorName)) { + throw new TopologyException("sourceName and processorName must be different."); + } + } + + private void connectProcessorAndStateStore(final String processorName, + final String stateStoreName) { + if (globalStateBuilders.containsKey(stateStoreName)) { + throw new TopologyException("Global StateStore " + stateStoreName + + " can be used by a Processor without being specified; it should not be explicitly passed."); + } + if (!stateFactories.containsKey(stateStoreName)) { + throw new TopologyException("StateStore " + stateStoreName + " is not added yet."); + } + if (!nodeFactories.containsKey(processorName)) { + throw new TopologyException("Processor " + processorName + " is not added yet."); + } + + final StateStoreFactory stateStoreFactory = stateFactories.get(stateStoreName); + final Iterator iter = stateStoreFactory.users().iterator(); + if (iter.hasNext()) { + final String user = iter.next(); + nodeGrouper.unite(user, processorName); + } + stateStoreFactory.users().add(processorName); + + final NodeFactory nodeFactory = nodeFactories.get(processorName); + if (nodeFactory instanceof ProcessorNodeFactory) { + final ProcessorNodeFactory processorNodeFactory = (ProcessorNodeFactory) nodeFactory; + processorNodeFactory.addStateStore(stateStoreName); + connectStateStoreNameToSourceTopicsOrPattern(stateStoreName, processorNodeFactory); + } else { + throw new TopologyException("cannot connect a state store " + stateStoreName + " to a source node or a sink node."); + } + } + + private Set> findSourcesForProcessorPredecessors(final String[] predecessors) { + final Set> sourceNodes = new HashSet<>(); + for (final String predecessor : predecessors) { + final NodeFactory nodeFactory = nodeFactories.get(predecessor); + if (nodeFactory instanceof SourceNodeFactory) { + sourceNodes.add((SourceNodeFactory) nodeFactory); + } else if (nodeFactory instanceof ProcessorNodeFactory) { + sourceNodes.addAll(findSourcesForProcessorPredecessors(((ProcessorNodeFactory) nodeFactory).predecessors)); + } + } + return sourceNodes; + } + + private void connectStateStoreNameToSourceTopicsOrPattern(final String stateStoreName, + final ProcessorNodeFactory processorNodeFactory) { + // we should never update the mapping from state store names to source topics if the store name already exists + // in the map; this scenario is possible, for example, that a state store underlying a source KTable is + // connecting to a join operator whose source topic is not the original KTable's source 
topic but an internal repartition topic. + + if (stateStoreNameToSourceTopics.containsKey(stateStoreName) + || stateStoreNameToSourceRegex.containsKey(stateStoreName)) { + return; + } + + final Set sourceTopics = new HashSet<>(); + final Set sourcePatterns = new HashSet<>(); + final Set> sourceNodesForPredecessor = + findSourcesForProcessorPredecessors(processorNodeFactory.predecessors); + + for (final SourceNodeFactory sourceNodeFactory : sourceNodesForPredecessor) { + if (sourceNodeFactory.pattern != null) { + sourcePatterns.add(sourceNodeFactory.pattern); + } else { + sourceTopics.addAll(sourceNodeFactory.topics); + } + } + + if (!sourceTopics.isEmpty()) { + stateStoreNameToSourceTopics.put( + stateStoreName, + Collections.unmodifiableSet(sourceTopics) + ); + } + + if (!sourcePatterns.isEmpty()) { + stateStoreNameToSourceRegex.put( + stateStoreName, + Collections.unmodifiableSet(sourcePatterns) + ); + } + + } + + private void maybeAddToResetList(final Collection earliestResets, + final Collection latestResets, + final Topology.AutoOffsetReset offsetReset, + final T item) { + if (offsetReset != null) { + switch (offsetReset) { + case EARLIEST: + earliestResets.add(item); + break; + case LATEST: + latestResets.add(item); + break; + default: + throw new TopologyException(String.format("Unrecognized reset format %s", offsetReset)); + } + } + } + + public synchronized Map> nodeGroups() { + if (nodeGroups == null) { + nodeGroups = makeNodeGroups(); + } + return nodeGroups; + } + + // Order node groups by their position in the actual topology, ie upstream subtopologies come before downstream + private Map> makeNodeGroups() { + final Map> nodeGroups = new LinkedHashMap<>(); + final Map> rootToNodeGroup = new HashMap<>(); + + int nodeGroupId = 0; + + // Traverse in topological order + for (final String nodeName : nodeFactories.keySet()) { + nodeGroupId = putNodeGroupName(nodeName, nodeGroupId, nodeGroups, rootToNodeGroup); + } + + return nodeGroups; + } + + private int putNodeGroupName(final String nodeName, + final int nodeGroupId, + final Map> nodeGroups, + final Map> rootToNodeGroup) { + int newNodeGroupId = nodeGroupId; + final String root = nodeGrouper.root(nodeName); + Set nodeGroup = rootToNodeGroup.get(root); + if (nodeGroup == null) { + nodeGroup = new HashSet<>(); + rootToNodeGroup.put(root, nodeGroup); + nodeGroups.put(newNodeGroupId++, nodeGroup); + } + nodeGroup.add(nodeName); + return newNodeGroupId; + } + + /** + * @return the full topology minus any global state + */ + public synchronized ProcessorTopology buildTopology() { + final Set nodeGroup = new HashSet<>(); + for (final Set value : nodeGroups().values()) { + nodeGroup.addAll(value); + } + nodeGroup.removeAll(globalNodeGroups()); + + initializeSubscription(); + return build(nodeGroup); + } + + /** + * @param topicGroupId group of topics corresponding to a single subtopology + * @return subset of the full topology + */ + public synchronized ProcessorTopology buildSubtopology(final int topicGroupId) { + final Set nodeGroup = nodeGroups().get(topicGroupId); + return build(nodeGroup); + } + + /** + * Builds the topology for any global state stores + * @return ProcessorTopology of global state + */ + public synchronized ProcessorTopology buildGlobalStateTopology() { + Objects.requireNonNull(applicationId, "topology has not completed optimization"); + + final Set globalGroups = globalNodeGroups(); + if (globalGroups.isEmpty()) { + return null; + } + return build(globalGroups); + } + + private Set globalNodeGroups() { + final Set 
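`putNodeGroupName` above maps each node's union-find root to a numbered group, handing out ids in the order in which roots are first seen during the topological traversal. The standalone sketch below restates that step; the root lookup is passed in as a plain function standing in for the real grouper, so the names and signature are assumptions for illustration only.

```java
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.Map;
import java.util.Set;
import java.util.function.Function;

final class NodeGroupingSketch {

    // Walk the nodes in topological order and assign group ids in the order in which
    // union-find roots are first seen, so upstream subtopologies receive lower ids.
    static Map<Integer, Set<String>> groupByRoot(final Iterable<String> nodesInTopologicalOrder,
                                                 final Function<String, String> rootOf) {
        final Map<String, Integer> rootToGroupId = new LinkedHashMap<>();
        final Map<Integer, Set<String>> groups = new LinkedHashMap<>();
        for (final String node : nodesInTopologicalOrder) {
            final String root = rootOf.apply(node);
            Integer groupId = rootToGroupId.get(root);
            if (groupId == null) {
                groupId = rootToGroupId.size();
                rootToGroupId.put(root, groupId);
            }
            groups.computeIfAbsent(groupId, id -> new LinkedHashSet<>()).add(node);
        }
        return groups;
    }
}
```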
globalGroups = new HashSet<>(); + for (final Map.Entry> nodeGroup : nodeGroups().entrySet()) { + final Set nodes = nodeGroup.getValue(); + for (final String node : nodes) { + if (isGlobalSource(node)) { + globalGroups.addAll(nodes); + } + } + } + return globalGroups; + } + + @SuppressWarnings("unchecked") + private ProcessorTopology build(final Set nodeGroup) { + Objects.requireNonNull(applicationId, "topology has not completed optimization"); + + final Map> processorMap = new LinkedHashMap<>(); + final Map> topicSourceMap = new HashMap<>(); + final Map> topicSinkMap = new HashMap<>(); + final Map stateStoreMap = new LinkedHashMap<>(); + final Set repartitionTopics = new HashSet<>(); + + // create processor nodes in a topological order ("nodeFactories" is already topologically sorted) + // also make sure the state store map values following the insertion ordering + for (final NodeFactory factory : nodeFactories.values()) { + if (nodeGroup == null || nodeGroup.contains(factory.name)) { + final ProcessorNode node = factory.build(); + processorMap.put(node.name(), node); + + if (factory instanceof ProcessorNodeFactory) { + buildProcessorNode(processorMap, + stateStoreMap, + (ProcessorNodeFactory) factory, + (ProcessorNode) node); + + } else if (factory instanceof SourceNodeFactory) { + buildSourceNode(topicSourceMap, + repartitionTopics, + (SourceNodeFactory) factory, + (SourceNode) node); + + } else if (factory instanceof SinkNodeFactory) { + buildSinkNode(processorMap, + topicSinkMap, + repartitionTopics, + (SinkNodeFactory) factory, + (SinkNode) node); + } else { + throw new TopologyException("Unknown definition class: " + factory.getClass().getName()); + } + } + } + + return new ProcessorTopology(new ArrayList<>(processorMap.values()), + topicSourceMap, + topicSinkMap, + new ArrayList<>(stateStoreMap.values()), + new ArrayList<>(globalStateStores.values()), + storeToChangelogTopic, + repartitionTopics); + } + + private void buildSinkNode(final Map> processorMap, + final Map> topicSinkMap, + final Set repartitionTopics, + final SinkNodeFactory sinkNodeFactory, + final SinkNode node) { + @SuppressWarnings("unchecked") final ProcessorNode sinkNode = + (ProcessorNode) node; + + for (final String predecessorName : sinkNodeFactory.predecessors) { + final ProcessorNode processor = getProcessor(processorMap, predecessorName); + processor.addChild(sinkNode); + if (sinkNodeFactory.topicExtractor instanceof StaticTopicNameExtractor) { + final String topic = ((StaticTopicNameExtractor) sinkNodeFactory.topicExtractor).topicName; + + if (internalTopicNamesWithProperties.containsKey(topic)) { + // prefix the internal topic name with the application id + final String decoratedTopic = decorateTopic(topic); + topicSinkMap.put(decoratedTopic, node); + repartitionTopics.add(decoratedTopic); + } else { + topicSinkMap.put(topic, node); + } + + } + } + } + + @SuppressWarnings("unchecked") + private static ProcessorNode getProcessor( + final Map> processorMap, + final String predecessor) { + + return (ProcessorNode) processorMap.get(predecessor); + } + + private void buildSourceNode(final Map> topicSourceMap, + final Set repartitionTopics, + final SourceNodeFactory sourceNodeFactory, + final SourceNode node) { + + final List topics = (sourceNodeFactory.pattern != null) ? 
+ sourceNodeFactory.getTopics(subscriptionUpdates()) : + sourceNodeFactory.topics; + + for (final String topic : topics) { + if (internalTopicNamesWithProperties.containsKey(topic)) { + // prefix the internal topic name with the application id + final String decoratedTopic = decorateTopic(topic); + topicSourceMap.put(decoratedTopic, node); + repartitionTopics.add(decoratedTopic); + } else { + topicSourceMap.put(topic, node); + } + } + } + + private void buildProcessorNode(final Map> processorMap, + final Map stateStoreMap, + final ProcessorNodeFactory factory, + final ProcessorNode node) { + + for (final String predecessor : factory.predecessors) { + final ProcessorNode predecessorNode = getProcessor(processorMap, predecessor); + predecessorNode.addChild(node); + } + for (final String stateStoreName : factory.stateStoreNames) { + if (!stateStoreMap.containsKey(stateStoreName)) { + final StateStore store; + if (stateFactories.containsKey(stateStoreName)) { + final StateStoreFactory stateStoreFactory = stateFactories.get(stateStoreName); + + // remember the changelog topic if this state store is change-logging enabled + if (stateStoreFactory.loggingEnabled() && !storeToChangelogTopic.containsKey(stateStoreName)) { + final String changelogTopic = + ProcessorStateManager.storeChangelogTopic(applicationId, stateStoreName, topologyName); + storeToChangelogTopic.put(stateStoreName, changelogTopic); + changelogTopicToStore.put(changelogTopic, stateStoreName); + } + store = stateStoreFactory.build(); + stateStoreMap.put(stateStoreName, store); + } else { + store = globalStateStores.get(stateStoreName); + stateStoreMap.put(stateStoreName, store); + } + + if (store.persistent()) { + hasPersistentStores = true; + } + } + } + } + + /** + * Get any global {@link StateStore}s that are part of the + * topology + * @return map containing all global {@link StateStore}s + */ + public Map globalStateStores() { + Objects.requireNonNull(applicationId, "topology has not completed optimization"); + + return Collections.unmodifiableMap(globalStateStores); + } + + public Set allStateStoreNames() { + Objects.requireNonNull(applicationId, "topology has not completed optimization"); + + final Set allNames = new HashSet<>(stateFactories.keySet()); + allNames.addAll(globalStateStores.keySet()); + return Collections.unmodifiableSet(allNames); + } + + public boolean hasStore(final String name) { + return stateFactories.containsKey(name) || globalStateStores.containsKey(name); + } + + public boolean hasPersistentStores() { + return hasPersistentStores; + } + + /** + * Returns the map of topic groups keyed by the group id. + * A topic group is a group of topics in the same task. 
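The changelog topic name assembled in `buildProcessorNode` follows the usual Kafka Streams convention of `<application.id>-<storeName>-changelog`, with the topology name additionally spliced in when named topologies are used. The helper below is only an illustrative restatement of that convention under those assumptions, not the actual `ProcessorStateManager.storeChangelogTopic` implementation.

```java
final class ChangelogNamingSketch {

    // Illustrative only: "<appId>-<store>-changelog", or, assuming named topologies prefix
    // their name as well, "<appId>-<topology>-<store>-changelog".
    static String changelogTopicFor(final String applicationId, final String topologyName, final String storeName) {
        final String prefix = topologyName == null ? applicationId : applicationId + "-" + topologyName;
        return prefix + "-" + storeName + "-changelog";
    }

    public static void main(final String[] args) {
        System.out.println(changelogTopicFor("word-count-app", null, "counts-store"));
        // word-count-app-counts-store-changelog
    }
}
```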
+ * + * @return groups of topic names + */ + public synchronized Map topicGroups() { + final Map topicGroups = new LinkedHashMap<>(); + + if (nodeGroups == null) { + nodeGroups = makeNodeGroups(); + } + + for (final Map.Entry> entry : nodeGroups.entrySet()) { + final Set sinkTopics = new HashSet<>(); + final Set sourceTopics = new HashSet<>(); + final Map repartitionTopics = new HashMap<>(); + final Map stateChangelogTopics = new HashMap<>(); + for (final String node : entry.getValue()) { + // if the node is a source node, add to the source topics + final List topics = nodeToSourceTopics.get(node); + if (topics != null) { + // if some of the topics are internal, add them to the internal topics + for (final String topic : topics) { + // skip global topic as they don't need partition assignment + if (globalTopics.contains(topic)) { + continue; + } + if (internalTopicNamesWithProperties.containsKey(topic)) { + // prefix the internal topic name with the application id + final String internalTopic = decorateTopic(topic); + + final RepartitionTopicConfig repartitionTopicConfig = buildRepartitionTopicConfig( + internalTopic, + internalTopicNamesWithProperties.get(topic).getNumberOfPartitions() + ); + + repartitionTopics.put(repartitionTopicConfig.name(), repartitionTopicConfig); + sourceTopics.add(repartitionTopicConfig.name()); + } else { + sourceTopics.add(topic); + } + } + } + + // if the node is a sink node, add to the sink topics + final String topic = nodeToSinkTopic.get(node); + if (topic != null) { + if (internalTopicNamesWithProperties.containsKey(topic)) { + // prefix the change log topic name with the application id + sinkTopics.add(decorateTopic(topic)); + } else { + sinkTopics.add(topic); + } + } + + // if the node is connected to a state store whose changelog topics are not predefined, + // add to the changelog topics + for (final StateStoreFactory stateFactory : stateFactories.values()) { + if (stateFactory.users().contains(node) && storeToChangelogTopic.containsKey(stateFactory.name())) { + final String topicName = storeToChangelogTopic.get(stateFactory.name()); + if (!stateChangelogTopics.containsKey(topicName)) { + final InternalTopicConfig internalTopicConfig = + createChangelogTopicConfig(stateFactory, topicName); + stateChangelogTopics.put(topicName, internalTopicConfig); + } + } + } + } + if (!sourceTopics.isEmpty()) { + topicGroups.put(new Subtopology(entry.getKey(), topologyName), new TopicsInfo( + Collections.unmodifiableSet(sinkTopics), + Collections.unmodifiableSet(sourceTopics), + Collections.unmodifiableMap(repartitionTopics), + Collections.unmodifiableMap(stateChangelogTopics))); + } + } + + return Collections.unmodifiableMap(topicGroups); + } + + public Map> nodeToSourceTopics() { + return Collections.unmodifiableMap(nodeToSourceTopics); + } + + private RepartitionTopicConfig buildRepartitionTopicConfig(final String internalTopic, + final Optional numberOfPartitions) { + return numberOfPartitions + .map(partitions -> new RepartitionTopicConfig(internalTopic, + Collections.emptyMap(), + partitions, + true)) + .orElse(new RepartitionTopicConfig(internalTopic, Collections.emptyMap())); + } + + private void setRegexMatchedTopicsToSourceNodes() { + if (hasSubscriptionUpdates()) { + for (final String nodeName : nodeToSourcePatterns.keySet()) { + final SourceNodeFactory sourceNode = (SourceNodeFactory) nodeFactories.get(nodeName); + final List sourceTopics = sourceNode.getTopics(subscriptionUpdates); + //need to update nodeToSourceTopics and sourceTopicNames with topics 
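`buildRepartitionTopicConfig` above picks between two constructors depending on whether a partition count was fixed when the topology was built. The neutral sketch below restates that shape with hypothetical types; the only deliberate difference is `orElseGet`, which creates the fallback object lazily, whereas the original's `orElse` constructs it eagerly (a negligible allocation in practice).

```java
import java.util.Optional;

final class RepartitionConfigSketch {

    // Hypothetical stand-in for the internal RepartitionTopicConfig.
    static final class TopicConfig {
        final String name;
        final Optional<Integer> numberOfPartitions;

        TopicConfig(final String name, final Optional<Integer> numberOfPartitions) {
            this.name = name;
            this.numberOfPartitions = numberOfPartitions;
        }
    }

    // Carry the partition count through only when one was enforced at topology-build time;
    // otherwise leave it empty so the partition assignor can decide later.
    static TopicConfig repartitionConfig(final String topic, final Optional<Integer> enforcedPartitions) {
        return enforcedPartitions
            .map(partitions -> new TopicConfig(topic, Optional.of(partitions)))
            .orElseGet(() -> new TopicConfig(topic, Optional.empty()));
    }
}
```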
matched from given regex + nodeToSourceTopics.put(nodeName, sourceTopics); + sourceTopicNames.addAll(sourceTopics); + } + log.debug("Updated nodeToSourceTopics: {}", nodeToSourceTopics); + } + } + + private void setRegexMatchedTopicToStateStore() { + if (hasSubscriptionUpdates()) { + for (final Map.Entry> storePattern : stateStoreNameToSourceRegex.entrySet()) { + final Set updatedTopicsForStateStore = new HashSet<>(); + for (final String subscriptionUpdateTopic : subscriptionUpdates()) { + for (final Pattern pattern : storePattern.getValue()) { + if (pattern.matcher(subscriptionUpdateTopic).matches()) { + updatedTopicsForStateStore.add(subscriptionUpdateTopic); + } + } + } + if (!updatedTopicsForStateStore.isEmpty()) { + final Collection storeTopics = stateStoreNameToSourceTopics.get(storePattern.getKey()); + if (storeTopics != null) { + updatedTopicsForStateStore.addAll(storeTopics); + } + stateStoreNameToSourceTopics.put( + storePattern.getKey(), + Collections.unmodifiableSet(updatedTopicsForStateStore)); + } + } + } + } + + private InternalTopicConfig createChangelogTopicConfig(final StateStoreFactory factory, + final String name) { + if (factory.isWindowStore()) { + final WindowedChangelogTopicConfig config = new WindowedChangelogTopicConfig(name, factory.logConfig()); + config.setRetentionMs(factory.retentionPeriod()); + return config; + } else { + return new UnwindowedChangelogTopicConfig(name, factory.logConfig()); + } + } + + public boolean hasOffsetResetOverrides() { + return !(earliestResetTopics.isEmpty() && earliestResetPatterns.isEmpty() + && latestResetTopics.isEmpty() && latestResetPatterns.isEmpty()); + } + + public OffsetResetStrategy offsetResetStrategy(final String topic) { + if (maybeDecorateInternalSourceTopics(earliestResetTopics).contains(topic) || + earliestResetPatterns.stream().anyMatch(p -> p.matcher(topic).matches())) { + return EARLIEST; + } else if (maybeDecorateInternalSourceTopics(latestResetTopics).contains(topic) || + latestResetPatterns.stream().anyMatch(p -> p.matcher(topic).matches())) { + return LATEST; + } else if (maybeDecorateInternalSourceTopics(sourceTopicNames).contains(topic) + || !hasNamedTopology() + || (usesPatternSubscription() && Pattern.compile(sourceTopicPatternString).matcher(topic).matches())) { + return NONE; + } else { + // return null if the topic wasn't found at all while using NamedTopologies as it's likely in another + return null; + } + } + + public Map> stateStoreNameToSourceTopics() { + final Map> results = new HashMap<>(); + for (final Map.Entry> entry : stateStoreNameToSourceTopics.entrySet()) { + results.put(entry.getKey(), maybeDecorateInternalSourceTopics(entry.getValue())); + } + return results; + } + + public Collection sourceTopicsForStore(final String storeName) { + return maybeDecorateInternalSourceTopics(stateStoreNameToSourceTopics.get(storeName)); + } + + public synchronized Collection> copartitionGroups() { + // compute transitive closures of copartitionGroups to relieve registering code to know all members + // of a copartitionGroup at the same time + final List> copartitionSourceTopics = + copartitionSourceGroups + .stream() + .map(sourceGroup -> + sourceGroup + .stream() + .flatMap(node -> maybeDecorateInternalSourceTopics(nodeToSourceTopics.get(node)).stream()) + .collect(Collectors.toSet()) + ).collect(Collectors.toList()); + + final Map> topicsToCopartitionGroup = new LinkedHashMap<>(); + for (final Set topics : copartitionSourceTopics) { + if (topics != null) { + Set coparititonGroup = null; + for (final 
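The matching loop in `setRegexMatchedTopicToStateStore` boils down to "keep every subscribed topic that matches at least one of the store's source patterns". A compact restatement with `java.util.regex`, using placeholder topic names, is sketched below.

```java
import java.util.Set;
import java.util.regex.Pattern;
import java.util.stream.Collectors;

final class PatternMatchSketch {

    // Keep every subscribed topic that matches at least one of the given source patterns.
    static Set<String> topicsMatching(final Set<String> subscribedTopics, final Set<Pattern> sourcePatterns) {
        return subscribedTopics.stream()
            .filter(topic -> sourcePatterns.stream().anyMatch(p -> p.matcher(topic).matches()))
            .collect(Collectors.toSet());
    }

    public static void main(final String[] args) {
        final Set<String> matched = topicsMatching(
            Set.of("orders-2024", "audit-log"),
            Set.of(Pattern.compile("orders-.*")));
        System.out.println(matched); // [orders-2024]
    }
}
```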
String topic : topics) { + coparititonGroup = topicsToCopartitionGroup.get(topic); + if (coparititonGroup != null) { + break; + } + } + if (coparititonGroup == null) { + coparititonGroup = new HashSet<>(); + } + coparititonGroup.addAll(maybeDecorateInternalSourceTopics(topics)); + for (final String topic : topics) { + topicsToCopartitionGroup.put(topic, coparititonGroup); + } + } + } + final Set> uniqueCopartitionGroups = new HashSet<>(topicsToCopartitionGroup.values()); + return Collections.unmodifiableList(new ArrayList<>(uniqueCopartitionGroups)); + } + + private List maybeDecorateInternalSourceTopics(final Collection sourceTopics) { + if (sourceTopics == null) { + return Collections.emptyList(); + } + final List decoratedTopics = new ArrayList<>(); + for (final String topic : sourceTopics) { + if (internalTopicNamesWithProperties.containsKey(topic)) { + decoratedTopics.add(decorateTopic(topic)); + } else { + decoratedTopics.add(topic); + } + } + return decoratedTopics; + } + + public String decoratePseudoTopic(final String topic) { + return decorateTopic(topic); + } + + private String decorateTopic(final String topic) { + if (applicationId == null) { + throw new TopologyException("there are internal topics and " + + "applicationId hasn't been set. Call " + + "setApplicationId first"); + } + if (hasNamedTopology()) { + return applicationId + "-" + topologyName + "-" + topic; + } else { + return applicationId + "-" + topic; + } + } + + void initializeSubscription() { + if (usesPatternSubscription()) { + log.debug("Found pattern subscribed source topics, initializing consumer's subscription pattern."); + sourceTopicPatternString = buildSourceTopicsPatternString(); + } else { + log.debug("No source topics using pattern subscription found, initializing consumer's subscription collection."); + sourceTopicCollection = maybeDecorateInternalSourceTopics(sourceTopicNames); + Collections.sort(sourceTopicCollection); + } + } + + private String buildSourceTopicsPatternString() { + final List allSourceTopics = maybeDecorateInternalSourceTopics(sourceTopicNames); + Collections.sort(allSourceTopics); + + final StringBuilder builder = new StringBuilder(); + + for (final String topic : allSourceTopics) { + builder.append(topic).append("|"); + } + + for (final Pattern sourcePattern : nodeToSourcePatterns.values()) { + builder.append(sourcePattern.pattern()).append("|"); + } + + if (builder.length() > 0) { + builder.setLength(builder.length() - 1); + } + + return builder.toString(); + } + + boolean usesPatternSubscription() { + return !nodeToSourcePatterns.isEmpty(); + } + + synchronized Collection sourceTopicCollection() { + return sourceTopicCollection; + } + + synchronized String sourceTopicsPatternString() { + // With a NamedTopology, it may be that this topology does not use pattern subscription but another one does + // in which case we would need to initialize the pattern string where we would otherwise have not + if (sourceTopicPatternString == null && hasNamedTopology()) { + sourceTopicPatternString = buildSourceTopicsPatternString(); + } + return sourceTopicPatternString; + } + + public boolean hasNoLocalTopology() { + return nodeToSourcePatterns.isEmpty() && sourceTopicNames.isEmpty(); + } + + public boolean hasGlobalStores() { + return !globalStateStores.isEmpty(); + } + + private boolean isGlobalSource(final String nodeName) { + final NodeFactory nodeFactory = nodeFactories.get(nodeName); + + if (nodeFactory instanceof SourceNodeFactory) { + final List topics = ((SourceNodeFactory) 
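`decorateTopic` above prefixes internal topic names with the application id (and, for named topologies, the topology name), so several applications can share one cluster without their repartition and changelog topics clashing. Purely for illustration, the same rule as a standalone function with hypothetical arguments:

```java
final class TopicDecorationSketch {

    // Internal (repartition/changelog) topics live in a namespace derived from the application id.
    static String decorate(final String applicationId, final String topologyName, final String topic) {
        if (applicationId == null) {
            throw new IllegalStateException("applicationId must be set before internal topics are decorated");
        }
        return topologyName == null
            ? applicationId + "-" + topic
            : applicationId + "-" + topologyName + "-" + topic;
    }

    public static void main(final String[] args) {
        System.out.println(decorate("word-count-app", null, "counts-store-repartition"));
        // word-count-app-counts-store-repartition
    }
}
```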
nodeFactory).topics; + return topics != null && topics.size() == 1 && globalTopics.contains(topics.get(0)); + } + return false; + } + + public TopologyDescription describe() { + final TopologyDescription description = new TopologyDescription(topologyName); + + for (final Map.Entry> nodeGroup : makeNodeGroups().entrySet()) { + + final Set allNodesOfGroups = nodeGroup.getValue(); + final boolean isNodeGroupOfGlobalStores = nodeGroupContainsGlobalSourceNode(allNodesOfGroups); + + if (!isNodeGroupOfGlobalStores) { + describeSubtopology(description, nodeGroup.getKey(), allNodesOfGroups); + } else { + describeGlobalStore(description, allNodesOfGroups, nodeGroup.getKey()); + } + } + + return description; + } + + private void describeGlobalStore(final TopologyDescription description, + final Set nodes, + final int id) { + final Iterator it = nodes.iterator(); + while (it.hasNext()) { + final String node = it.next(); + + if (isGlobalSource(node)) { + // we found a GlobalStore node group; those contain exactly two node: {sourceNode,processorNode} + it.remove(); // remove sourceNode from group + final String processorNode = nodes.iterator().next(); // get remaining processorNode + + description.addGlobalStore(new GlobalStore( + node, + processorNode, + ((ProcessorNodeFactory) nodeFactories.get(processorNode)).stateStoreNames.iterator().next(), + nodeToSourceTopics.get(node).get(0), + id + )); + break; + } + } + } + + private boolean nodeGroupContainsGlobalSourceNode(final Set allNodesOfGroups) { + for (final String node : allNodesOfGroups) { + if (isGlobalSource(node)) { + return true; + } + } + return false; + } + + private static class NodeComparator implements Comparator, Serializable { + + @Override + public int compare(final TopologyDescription.Node node1, + final TopologyDescription.Node node2) { + if (node1.equals(node2)) { + return 0; + } + final int size1 = ((AbstractNode) node1).size; + final int size2 = ((AbstractNode) node2).size; + + // it is possible that two nodes have the same sub-tree size (think two nodes connected via state stores) + // in this case default to processor name string + if (size1 != size2) { + return size2 - size1; + } else { + return node1.name().compareTo(node2.name()); + } + } + } + + private final static NodeComparator NODE_COMPARATOR = new NodeComparator(); + + private static void updateSize(final AbstractNode node, + final int delta) { + node.size += delta; + + for (final TopologyDescription.Node predecessor : node.predecessors()) { + updateSize((AbstractNode) predecessor, delta); + } + } + + private void describeSubtopology(final TopologyDescription description, + final Integer subtopologyId, + final Set nodeNames) { + + final Map nodesByName = new HashMap<>(); + + // add all nodes + for (final String nodeName : nodeNames) { + nodesByName.put(nodeName, nodeFactories.get(nodeName).describe()); + } + + // connect each node to its predecessors and successors + for (final AbstractNode node : nodesByName.values()) { + for (final String predecessorName : nodeFactories.get(node.name()).predecessors) { + final AbstractNode predecessor = nodesByName.get(predecessorName); + node.addPredecessor(predecessor); + predecessor.addSuccessor(node); + updateSize(predecessor, node.size); + } + } + + description.addSubtopology(new SubtopologyDescription( + subtopologyId, + new HashSet<>(nodesByName.values()))); + } + + public final static class GlobalStore implements TopologyDescription.GlobalStore { + private final Source source; + private final Processor processor; + private 
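For context, the `TopologyDescription` assembled by `describe()` here is what surfaces through the public `Topology#describe()` API. A minimal usage example with the public DSL follows; the topic names are placeholders.

```java
import org.apache.kafka.streams.StreamsBuilder;
import org.apache.kafka.streams.Topology;
import org.apache.kafka.streams.TopologyDescription;

public class DescribeExample {
    public static void main(final String[] args) {
        final StreamsBuilder builder = new StreamsBuilder();
        builder.stream("input-topic").to("output-topic");

        final Topology topology = builder.build();
        final TopologyDescription description = topology.describe();

        // Prints the sub-topologies with their sources, processors and sinks.
        System.out.println(description);
    }
}
```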
final int id; + + public GlobalStore(final String sourceName, + final String processorName, + final String storeName, + final String topicName, + final int id) { + source = new Source(sourceName, Collections.singleton(topicName), null); + processor = new Processor(processorName, Collections.singleton(storeName)); + source.successors.add(processor); + processor.predecessors.add(source); + this.id = id; + } + + @Override + public int id() { + return id; + } + + @Override + public TopologyDescription.Source source() { + return source; + } + + @Override + public TopologyDescription.Processor processor() { + return processor; + } + + @Override + public String toString() { + return "Sub-topology: " + id + " for global store (will not generate tasks)\n" + + " " + source.toString() + "\n" + + " " + processor.toString() + "\n"; + } + + @Override + public boolean equals(final Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + + final GlobalStore that = (GlobalStore) o; + return source.equals(that.source) + && processor.equals(that.processor); + } + + @Override + public int hashCode() { + return Objects.hash(source, processor); + } + } + + public abstract static class AbstractNode implements TopologyDescription.Node { + final String name; + final Set predecessors = new TreeSet<>(NODE_COMPARATOR); + final Set successors = new TreeSet<>(NODE_COMPARATOR); + + // size of the sub-topology rooted at this node, including the node itself + int size; + + AbstractNode(final String name) { + Objects.requireNonNull(name, "name cannot be null"); + this.name = name; + this.size = 1; + } + + @Override + public String name() { + return name; + } + + @Override + public Set predecessors() { + return Collections.unmodifiableSet(predecessors); + } + + @Override + public Set successors() { + return Collections.unmodifiableSet(successors); + } + + public void addPredecessor(final TopologyDescription.Node predecessor) { + predecessors.add(predecessor); + } + + public void addSuccessor(final TopologyDescription.Node successor) { + successors.add(successor); + } + } + + public final static class Source extends AbstractNode implements TopologyDescription.Source { + private final Set topics; + private final Pattern topicPattern; + + public Source(final String name, + final Set topics, + final Pattern pattern) { + super(name); + if (topics == null && pattern == null) { + throw new IllegalArgumentException("Either topics or pattern must be not-null, but both are null."); + } + if (topics != null && pattern != null) { + throw new IllegalArgumentException("Either topics or pattern must be null, but both are not null."); + } + + this.topics = topics; + this.topicPattern = pattern; + } + + @Override + public Set topicSet() { + return topics; + } + + @Override + public Pattern topicPattern() { + return topicPattern; + } + + @Override + public void addPredecessor(final TopologyDescription.Node predecessor) { + throw new UnsupportedOperationException("Sources don't have predecessors."); + } + + @Override + public String toString() { + final String topicsString = topics == null ? 
topicPattern.toString() : topics.toString(); + + return "Source: " + name + " (topics: " + topicsString + ")\n --> " + nodeNames(successors); + } + + @Override + public boolean equals(final Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + + final Source source = (Source) o; + // omit successor to avoid infinite loops + return name.equals(source.name) + && Objects.equals(topics, source.topics) + && (topicPattern == null ? + source.topicPattern == null : + topicPattern.pattern().equals(source.topicPattern.pattern())); + } + + @Override + public int hashCode() { + // omit successor as it might change and alter the hash code + return Objects.hash(name, topics, topicPattern); + } + } + + public final static class Processor extends AbstractNode implements TopologyDescription.Processor { + private final Set stores; + + public Processor(final String name, + final Set stores) { + super(name); + this.stores = stores; + } + + @Override + public Set stores() { + return Collections.unmodifiableSet(stores); + } + + @Override + public String toString() { + return "Processor: " + name + " (stores: " + stores + ")\n --> " + + nodeNames(successors) + "\n <-- " + nodeNames(predecessors); + } + + @Override + public boolean equals(final Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + + final Processor processor = (Processor) o; + // omit successor to avoid infinite loops + return name.equals(processor.name) + && stores.equals(processor.stores) + && predecessors.equals(processor.predecessors); + } + + @Override + public int hashCode() { + // omit successor as it might change and alter the hash code + return Objects.hash(name, stores); + } + } + + public final static class Sink extends AbstractNode implements TopologyDescription.Sink { + private final TopicNameExtractor topicNameExtractor; + public Sink(final String name, + final TopicNameExtractor topicNameExtractor) { + super(name); + this.topicNameExtractor = topicNameExtractor; + } + + public Sink(final String name, + final String topic) { + super(name); + this.topicNameExtractor = new StaticTopicNameExtractor<>(topic); + } + + @Override + public String topic() { + if (topicNameExtractor instanceof StaticTopicNameExtractor) { + return ((StaticTopicNameExtractor) topicNameExtractor).topicName; + } else { + return null; + } + } + + @Override + public TopicNameExtractor topicNameExtractor() { + if (topicNameExtractor instanceof StaticTopicNameExtractor) { + return null; + } else { + return topicNameExtractor; + } + } + + @Override + public void addSuccessor(final TopologyDescription.Node successor) { + throw new UnsupportedOperationException("Sinks don't have successors."); + } + + @Override + public String toString() { + if (topicNameExtractor instanceof StaticTopicNameExtractor) { + return "Sink: " + name + " (topic: " + topic() + ")\n <-- " + nodeNames(predecessors); + } + return "Sink: " + name + " (extractor class: " + topicNameExtractor + ")\n <-- " + + nodeNames(predecessors); + } + + @SuppressWarnings("unchecked") + @Override + public boolean equals(final Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + + final Sink sink = (Sink) o; + return name.equals(sink.name) + && topicNameExtractor.equals(sink.topicNameExtractor) + && predecessors.equals(sink.predecessors); + } + + @Override + public int hashCode() { + // omit predecessors as it might 
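The `Sink` description above distinguishes a static topic name from a dynamic one chosen by a `TopicNameExtractor`, in which case `topic()` returns null and only the extractor is reported. For context, a dynamic sink is created through the public DSL as shown below; the topic names and routing logic are placeholders.

```java
import org.apache.kafka.common.serialization.Serdes;
import org.apache.kafka.streams.StreamsBuilder;
import org.apache.kafka.streams.kstream.Consumed;
import org.apache.kafka.streams.kstream.KStream;
import org.apache.kafka.streams.kstream.Produced;

public class DynamicSinkExample {
    public static void main(final String[] args) {
        final StreamsBuilder builder = new StreamsBuilder();
        final KStream<String, String> events =
            builder.stream("events", Consumed.with(Serdes.String(), Serdes.String()));

        // The extractor picks the destination per record, so the described Sink reports
        // the extractor class instead of a fixed topic name.
        events.to(
            (key, value, recordContext) -> value.startsWith("error") ? "errors" : "regular-events",
            Produced.with(Serdes.String(), Serdes.String()));

        System.out.println(builder.build().describe());
    }
}
```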
change and alter the hash code + return Objects.hash(name, topicNameExtractor); + } + } + + public final static class SubtopologyDescription implements TopologyDescription.Subtopology { + private final int id; + private final Set nodes; + + public SubtopologyDescription(final int id, final Set nodes) { + this.id = id; + this.nodes = new TreeSet<>(NODE_COMPARATOR); + this.nodes.addAll(nodes); + } + + @Override + public int id() { + return id; + } + + @Override + public Set nodes() { + return Collections.unmodifiableSet(nodes); + } + + // visible for testing + Iterator nodesInOrder() { + return nodes.iterator(); + } + + @Override + public String toString() { + return "Sub-topology: " + id + "\n" + nodesAsString() + "\n"; + } + + private String nodesAsString() { + final StringBuilder sb = new StringBuilder(); + for (final TopologyDescription.Node node : nodes) { + sb.append(" "); + sb.append(node); + sb.append('\n'); + } + return sb.toString(); + } + + @Override + public boolean equals(final Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + + final SubtopologyDescription that = (SubtopologyDescription) o; + return id == that.id + && nodes.equals(that.nodes); + } + + @Override + public int hashCode() { + return Objects.hash(id, nodes); + } + } + + public static class TopicsInfo { + public final Set sinkTopics; + public final Set sourceTopics; + public final Map stateChangelogTopics; + public final Map repartitionSourceTopics; + + TopicsInfo(final Set sinkTopics, + final Set sourceTopics, + final Map repartitionSourceTopics, + final Map stateChangelogTopics) { + this.sinkTopics = sinkTopics; + this.sourceTopics = sourceTopics; + this.stateChangelogTopics = stateChangelogTopics; + this.repartitionSourceTopics = repartitionSourceTopics; + } + + /** + * Returns the config for any changelogs that must be prepared for this topic group, ie excluding any source + * topics that are reused as a changelog + */ + public Set nonSourceChangelogTopics() { + final Set topicConfigs = new HashSet<>(); + for (final Map.Entry changelogTopicEntry : stateChangelogTopics.entrySet()) { + if (!sourceTopics.contains(changelogTopicEntry.getKey())) { + topicConfigs.add(changelogTopicEntry.getValue()); + } + } + return topicConfigs; + } + + /** + * Returns the topic names for any optimized source changelogs + */ + public Set sourceTopicChangelogs() { + return sourceTopics.stream().filter(stateChangelogTopics::containsKey).collect(Collectors.toSet()); + } + + @Override + public boolean equals(final Object o) { + if (o instanceof TopicsInfo) { + final TopicsInfo other = (TopicsInfo) o; + return other.sourceTopics.equals(sourceTopics) && other.stateChangelogTopics.equals(stateChangelogTopics); + } else { + return false; + } + } + + @Override + public int hashCode() { + final long n = ((long) sourceTopics.hashCode() << 32) | (long) stateChangelogTopics.hashCode(); + return (int) (n % 0xFFFFFFFFL); + } + + @Override + public String toString() { + return "TopicsInfo{" + + "sinkTopics=" + sinkTopics + + ", sourceTopics=" + sourceTopics + + ", repartitionSourceTopics=" + repartitionSourceTopics + + ", stateChangelogTopics=" + stateChangelogTopics + + '}'; + } + } + + private static class GlobalStoreComparator implements Comparator, Serializable { + @Override + public int compare(final TopologyDescription.GlobalStore globalStore1, + final TopologyDescription.GlobalStore globalStore2) { + if (globalStore1.equals(globalStore2)) { + return 0; + } + return 
globalStore1.id() - globalStore2.id(); + } + } + + private final static GlobalStoreComparator GLOBALSTORE_COMPARATOR = new GlobalStoreComparator(); + + private static class SubtopologyComparator implements Comparator, Serializable { + @Override + public int compare(final TopologyDescription.Subtopology subtopology1, + final TopologyDescription.Subtopology subtopology2) { + if (subtopology1.equals(subtopology2)) { + return 0; + } + return subtopology1.id() - subtopology2.id(); + } + } + + private final static SubtopologyComparator SUBTOPOLOGY_COMPARATOR = new SubtopologyComparator(); + + public final static class TopologyDescription implements org.apache.kafka.streams.TopologyDescription { + private final TreeSet subtopologies = new TreeSet<>(SUBTOPOLOGY_COMPARATOR); + private final TreeSet globalStores = new TreeSet<>(GLOBALSTORE_COMPARATOR); + private final String namedTopology; + + public TopologyDescription() { + this(null); + } + + public TopologyDescription(final String namedTopology) { + this.namedTopology = namedTopology; + } + + public void addSubtopology(final TopologyDescription.Subtopology subtopology) { + subtopologies.add(subtopology); + } + + public void addGlobalStore(final TopologyDescription.GlobalStore globalStore) { + globalStores.add(globalStore); + } + + @Override + public Set subtopologies() { + return Collections.unmodifiableSet(subtopologies); + } + + @Override + public Set globalStores() { + return Collections.unmodifiableSet(globalStores); + } + + @Override + public String toString() { + final StringBuilder sb = new StringBuilder(); + + if (namedTopology == null) { + sb.append("Topologies:\n "); + } else { + sb.append("Topology - ").append(namedTopology).append(":\n "); + } + final TopologyDescription.Subtopology[] sortedSubtopologies = + subtopologies.descendingSet().toArray(new TopologyDescription.Subtopology[0]); + final TopologyDescription.GlobalStore[] sortedGlobalStores = + globalStores.descendingSet().toArray(new GlobalStore[0]); + int expectedId = 0; + int subtopologiesIndex = sortedSubtopologies.length - 1; + int globalStoresIndex = sortedGlobalStores.length - 1; + while (subtopologiesIndex != -1 && globalStoresIndex != -1) { + sb.append(" "); + final TopologyDescription.Subtopology subtopology = sortedSubtopologies[subtopologiesIndex]; + final TopologyDescription.GlobalStore globalStore = sortedGlobalStores[globalStoresIndex]; + if (subtopology.id() == expectedId) { + sb.append(subtopology); + subtopologiesIndex--; + } else { + sb.append(globalStore); + globalStoresIndex--; + } + expectedId++; + } + while (subtopologiesIndex != -1) { + final TopologyDescription.Subtopology subtopology = sortedSubtopologies[subtopologiesIndex]; + sb.append(" "); + sb.append(subtopology); + subtopologiesIndex--; + } + while (globalStoresIndex != -1) { + final TopologyDescription.GlobalStore globalStore = sortedGlobalStores[globalStoresIndex]; + sb.append(" "); + sb.append(globalStore); + globalStoresIndex--; + } + return sb.toString(); + } + + @Override + public boolean equals(final Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + + final TopologyDescription that = (TopologyDescription) o; + return subtopologies.equals(that.subtopologies) + && globalStores.equals(that.globalStores); + } + + @Override + public int hashCode() { + return Objects.hash(subtopologies, globalStores); + } + + } + + private static String nodeNames(final Set nodes) { + final StringBuilder sb = new StringBuilder(); + if 
(!nodes.isEmpty()) { + for (final TopologyDescription.Node n : nodes) { + sb.append(n.name()); + sb.append(", "); + } + sb.deleteCharAt(sb.length() - 1); + sb.deleteCharAt(sb.length() - 1); + } else { + return "none"; + } + return sb.toString(); + } + + private Set subscriptionUpdates() { + return Collections.unmodifiableSet(subscriptionUpdates); + } + + private boolean hasSubscriptionUpdates() { + return !subscriptionUpdates.isEmpty(); + } + + synchronized void addSubscribedTopicsFromAssignment(final List partitions, final String logPrefix) { + if (usesPatternSubscription()) { + final Set assignedTopics = new HashSet<>(); + for (final TopicPartition topicPartition : partitions) { + assignedTopics.add(topicPartition.topic()); + } + + final Collection existingTopics = subscriptionUpdates(); + + if (!existingTopics.equals(assignedTopics)) { + assignedTopics.addAll(existingTopics); + updateSubscribedTopics(assignedTopics, logPrefix); + } + } + } + + synchronized void addSubscribedTopicsFromMetadata(final Set topics, final String logPrefix) { + if (usesPatternSubscription() && !subscriptionUpdates().equals(topics)) { + updateSubscribedTopics(topics, logPrefix); + } + } + + private void updateSubscribedTopics(final Set topics, final String logPrefix) { + subscriptionUpdates.clear(); + subscriptionUpdates.addAll(topics); + + log.debug("{}found {} topics possibly matching subscription", logPrefix, topics.size()); + + setRegexMatchedTopicsToSourceNodes(); + setRegexMatchedTopicToStateStore(); + } + + /** + * @return a copy of all source topic names, including the application id and named topology prefix if applicable + */ + public synchronized List fullSourceTopicNames() { + return new ArrayList<>(maybeDecorateInternalSourceTopics(sourceTopicNames)); + } + + /** + * @return a copy of the string representation of any pattern subscribed source nodes + */ + public synchronized List allSourcePatternStrings() { + return nodeToSourcePatterns.values().stream().map(Pattern::pattern).collect(Collectors.toList()); + } + + public boolean hasNamedTopology() { + return topologyName != null; + } + + // following functions are for test only + public synchronized Set sourceTopicNames() { + return sourceTopicNames; + } + + public synchronized Map> stateStores() { + return stateFactories; + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/PartitionGroup.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/PartitionGroup.java new file mode 100644 index 0000000..199bc0e --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/PartitionGroup.java @@ -0,0 +1,373 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.processor.internals; + +import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.metrics.Sensor; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.streams.StreamsConfig; +import org.slf4j.Logger; + +import java.util.Collections; +import java.util.Comparator; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Iterator; +import java.util.Map; +import java.util.OptionalLong; +import java.util.PriorityQueue; +import java.util.Set; +import java.util.function.Function; + +/** + * PartitionGroup is used to buffer all co-partitioned records for processing. + * + * In other words, it represents the "same" partition over multiple co-partitioned topics, and it is used + * to buffer records from that partition in each of the contained topic-partitions. + * Each StreamTask has exactly one PartitionGroup. + * + * PartitionGroup implements the algorithm that determines in what order buffered records are selected for processing. + * + * Specifically, when polled, it returns the record from the topic-partition with the lowest stream-time. + * Stream-time for a topic-partition is defined as the highest timestamp + * yet observed at the head of that topic-partition. + * + * PartitionGroup also maintains a stream-time for the group as a whole. + * This is defined as the highest timestamp of any record yet polled from the PartitionGroup. + * Note however that any computation that depends on stream-time should track it on a per-operator basis to obtain an + * accurate view of the local time as seen by that processor. + * + * The PartitionGroups's stream-time is initially UNKNOWN (-1), and it set to a known value upon first poll. + * As a consequence of the definition, the PartitionGroup's stream-time is non-decreasing + * (i.e., it increases or stays the same over time). 
+ */ +public class PartitionGroup { + + private final Logger logger; + private final Map partitionQueues; + private final Function lagProvider; + private final Sensor enforcedProcessingSensor; + private final long maxTaskIdleMs; + private final Sensor recordLatenessSensor; + private final PriorityQueue nonEmptyQueuesByTime; + + private long streamTime; + private int totalBuffered; + private boolean allBuffered; + private final Map idlePartitionDeadlines = new HashMap<>(); + + static class RecordInfo { + RecordQueue queue; + + ProcessorNode node() { + return queue.source(); + } + + TopicPartition partition() { + return queue.partition(); + } + + RecordQueue queue() { + return queue; + } + } + + PartitionGroup(final LogContext logContext, + final Map partitionQueues, + final Function lagProvider, + final Sensor recordLatenessSensor, + final Sensor enforcedProcessingSensor, + final long maxTaskIdleMs) { + this.logger = logContext.logger(PartitionGroup.class); + nonEmptyQueuesByTime = new PriorityQueue<>(partitionQueues.size(), Comparator.comparingLong(RecordQueue::headRecordTimestamp)); + this.partitionQueues = partitionQueues; + this.lagProvider = lagProvider; + this.enforcedProcessingSensor = enforcedProcessingSensor; + this.maxTaskIdleMs = maxTaskIdleMs; + this.recordLatenessSensor = recordLatenessSensor; + totalBuffered = 0; + allBuffered = false; + streamTime = RecordQueue.UNKNOWN; + } + + public boolean readyToProcess(final long wallClockTime) { + if (maxTaskIdleMs == StreamsConfig.MAX_TASK_IDLE_MS_DISABLED) { + if (logger.isTraceEnabled() && !allBuffered && totalBuffered > 0) { + final Set bufferedPartitions = new HashSet<>(); + final Set emptyPartitions = new HashSet<>(); + for (final Map.Entry entry : partitionQueues.entrySet()) { + if (entry.getValue().isEmpty()) { + emptyPartitions.add(entry.getKey()); + } else { + bufferedPartitions.add(entry.getKey()); + } + } + logger.trace("Ready for processing because max.task.idle.ms is disabled." + + "\n\tThere may be out-of-order processing for this task as a result." + + "\n\tBuffered partitions: {}" + + "\n\tNon-buffered partitions: {}", + bufferedPartitions, + emptyPartitions); + } + return true; + } + + final Set queued = new HashSet<>(); + Map enforced = null; + + for (final Map.Entry entry : partitionQueues.entrySet()) { + final TopicPartition partition = entry.getKey(); + final RecordQueue queue = entry.getValue(); + + + if (!queue.isEmpty()) { + // this partition is ready for processing + idlePartitionDeadlines.remove(partition); + queued.add(partition); + } else { + final OptionalLong fetchedLag = lagProvider.apply(partition); + + if (!fetchedLag.isPresent()) { + // must wait to fetch metadata for the partition + idlePartitionDeadlines.remove(partition); + logger.trace("Waiting to fetch data for {}", partition); + return false; + } else if (fetchedLag.getAsLong() > 0L) { + // must wait to poll the data we know to be on the broker + idlePartitionDeadlines.remove(partition); + logger.trace( + "Lag for {} is currently {}, but no data is buffered locally. Waiting to buffer some records.", + partition, + fetchedLag.getAsLong() + ); + return false; + } else { + // p is known to have zero lag. wait for maxTaskIdleMs to see if more data shows up. + // One alternative would be to set the deadline to nullableMetadata.receivedTimestamp + maxTaskIdleMs + // instead. 
That way, we would start the idling timer as of the freshness of our knowledge about the zero + // lag instead of when we happen to run this method, but realistically it's probably a small difference + // and using wall clock time seems more intuitive for users, + // since the log message will be as of wallClockTime. + idlePartitionDeadlines.putIfAbsent(partition, wallClockTime + maxTaskIdleMs); + final long deadline = idlePartitionDeadlines.get(partition); + if (wallClockTime < deadline) { + logger.trace( + "Lag for {} is currently 0 and current time is {}. Waiting for new data to be produced for configured idle time {} (deadline is {}).", + partition, + wallClockTime, + maxTaskIdleMs, + deadline + ); + return false; + } else { + // this partition is ready for processing due to the task idling deadline passing + if (enforced == null) { + enforced = new HashMap<>(); + } + enforced.put(partition, deadline); + } + } + } + } + if (enforced == null) { + logger.trace("All partitions were buffered locally, so this task is ready for processing."); + return true; + } else if (queued.isEmpty()) { + logger.trace("No partitions were buffered locally, so this task is not ready for processing."); + return false; + } else { + enforcedProcessingSensor.record(1.0d, wallClockTime); + logger.trace("Continuing to process although some partitions are empty on the broker." + + "\n\tThere may be out-of-order processing for this task as a result." + + "\n\tPartitions with local data: {}." + + "\n\tPartitions we gave up waiting for, with their corresponding deadlines: {}." + + "\n\tConfigured max.task.idle.ms: {}." + + "\n\tCurrent wall-clock time: {}.", + queued, + enforced, + maxTaskIdleMs, + wallClockTime); + return true; + } + } + + // visible for testing + long partitionTimestamp(final TopicPartition partition) { + final RecordQueue queue = partitionQueues.get(partition); + if (queue == null) { + throw new IllegalStateException("Partition " + partition + " not found."); + } + return queue.partitionTime(); + } + + // creates queues for new partitions, removes old queues, saves cached records for previously assigned partitions + void updatePartitions(final Set newInputPartitions, final Function recordQueueCreator) { + final Set removedPartitions = new HashSet<>(); + final Iterator> queuesIterator = partitionQueues.entrySet().iterator(); + while (queuesIterator.hasNext()) { + final Map.Entry queueEntry = queuesIterator.next(); + final TopicPartition topicPartition = queueEntry.getKey(); + if (!newInputPartitions.contains(topicPartition)) { + // if partition is removed should delete its queue + totalBuffered -= queueEntry.getValue().size(); + queuesIterator.remove(); + removedPartitions.add(topicPartition); + } + newInputPartitions.remove(topicPartition); + } + for (final TopicPartition newInputPartition : newInputPartitions) { + partitionQueues.put(newInputPartition, recordQueueCreator.apply(newInputPartition)); + } + nonEmptyQueuesByTime.removeIf(q -> removedPartitions.contains(q.partition())); + allBuffered = allBuffered && newInputPartitions.isEmpty(); + } + + void setPartitionTime(final TopicPartition partition, final long partitionTime) { + final RecordQueue queue = partitionQueues.get(partition); + if (queue == null) { + throw new IllegalStateException("Partition " + partition + " not found."); + } + if (streamTime < partitionTime) { + streamTime = partitionTime; + } + queue.setPartitionTime(partitionTime); + } + + /** + * Get the next record and queue + * + * @return StampedRecord + */ + StampedRecord 
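`readyToProcess` above is the enforcement point for `max.task.idle.ms`: when some partitions are empty but known to have zero broker-side lag, the task waits up to that long for more data before accepting possible out-of-order processing. The trade-off is configured through the regular Streams properties; a minimal sketch follows, with placeholder application id and bootstrap address.

```java
import java.util.Properties;

import org.apache.kafka.streams.StreamsConfig;

public class TaskIdleConfigExample {
    public static void main(final String[] args) {
        final Properties props = new Properties();
        props.put(StreamsConfig.APPLICATION_ID_CONFIG, "example-app");
        props.put(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9092");

        // Wait up to 5 seconds for empty-but-caught-up partitions before enforcing processing.
        // A value of 0 waits only for data already on the broker; MAX_TASK_IDLE_MS_DISABLED never waits.
        props.put(StreamsConfig.MAX_TASK_IDLE_MS_CONFIG, 5000L);

        // These properties would then be passed to the KafkaStreams constructor as usual.
    }
}
```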
nextRecord(final RecordInfo info, final long wallClockTime) { + StampedRecord record = null; + + final RecordQueue queue = nonEmptyQueuesByTime.poll(); + info.queue = queue; + + if (queue != null) { + // get the first record from this queue. + record = queue.poll(); + + if (record != null) { + --totalBuffered; + + if (queue.isEmpty()) { + // if a certain queue has been drained, reset the flag + allBuffered = false; + } else { + nonEmptyQueuesByTime.offer(queue); + } + + // always update the stream-time to the record's timestamp yet to be processed if it is larger + if (record.timestamp > streamTime) { + streamTime = record.timestamp; + recordLatenessSensor.record(0, wallClockTime); + } else { + recordLatenessSensor.record(streamTime - record.timestamp, wallClockTime); + } + } + } + + return record; + } + + /** + * Adds raw records to this partition group + * + * @param partition the partition + * @param rawRecords the raw records + * @return the queue size for the partition + */ + int addRawRecords(final TopicPartition partition, final Iterable> rawRecords) { + final RecordQueue recordQueue = partitionQueues.get(partition); + + if (recordQueue == null) { + throw new IllegalStateException("Partition " + partition + " not found."); + } + + final int oldSize = recordQueue.size(); + final int newSize = recordQueue.addRawRecords(rawRecords); + + // add this record queue to be considered for processing in the future if it was empty before + if (oldSize == 0 && newSize > 0) { + nonEmptyQueuesByTime.offer(recordQueue); + + // if all partitions now are non-empty, set the flag + // we do not need to update the stream-time here since this task will definitely be + // processed next, and hence the stream-time will be updated when we retrieved records by then + if (nonEmptyQueuesByTime.size() == this.partitionQueues.size()) { + allBuffered = true; + } + } + + totalBuffered += newSize - oldSize; + + return newSize; + } + + Set partitions() { + return Collections.unmodifiableSet(partitionQueues.keySet()); + } + + /** + * Return the stream-time of this partition group defined as the largest timestamp seen across all partitions + */ + long streamTime() { + return streamTime; + } + + Long headRecordOffset(final TopicPartition partition) { + final RecordQueue recordQueue = partitionQueues.get(partition); + + if (recordQueue == null) { + throw new IllegalStateException("Partition " + partition + " not found."); + } + + return recordQueue.headRecordOffset(); + } + + /** + * @throws IllegalStateException if the record's partition does not belong to this partition group + */ + int numBuffered(final TopicPartition partition) { + final RecordQueue recordQueue = partitionQueues.get(partition); + + if (recordQueue == null) { + throw new IllegalStateException("Partition " + partition + " not found."); + } + + return recordQueue.size(); + } + + int numBuffered() { + return totalBuffered; + } + + boolean allPartitionsBufferedLocally() { + return allBuffered; + } + + void clear() { + for (final RecordQueue queue : partitionQueues.values()) { + queue.clear(); + } + nonEmptyQueuesByTime.clear(); + totalBuffered = 0; + streamTime = RecordQueue.UNKNOWN; + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/PartitionGrouper.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/PartitionGrouper.java new file mode 100644 index 0000000..a664756 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/PartitionGrouper.java @@ -0,0 +1,101 @@ +/* + 
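The record-selection rule documented for `PartitionGroup`, namely "poll the partition whose head record has the lowest timestamp", is implemented with a priority queue keyed on head timestamps. The toy sketch below illustrates only that selection, independent of the real `RecordQueue`; the queue type and names are hypothetical.

```java
import java.util.ArrayDeque;
import java.util.Comparator;
import java.util.Deque;
import java.util.PriorityQueue;

final class StreamTimeSelectionSketch {

    // Hypothetical stand-in for a partition's record queue: only timestamps are modeled.
    static final class Queue {
        final String partition;
        final Deque<Long> timestamps = new ArrayDeque<>();

        Queue(final String partition) {
            this.partition = partition;
        }

        long headTimestamp() {
            return timestamps.peekFirst();
        }
    }

    public static void main(final String[] args) {
        final Queue p0 = new Queue("topic-0");
        p0.timestamps.add(100L);
        final Queue p1 = new Queue("topic-1");
        p1.timestamps.add(42L);

        final PriorityQueue<Queue> byTime = new PriorityQueue<>(Comparator.comparingLong(Queue::headTimestamp));
        byTime.add(p0);
        byTime.add(p1);

        // topic-1 is polled first because its head record carries the lower timestamp (42 < 100).
        System.out.println(byTime.poll().partition);
    }
}
```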
* Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.processor.internals; + +import org.apache.kafka.common.Cluster; +import org.apache.kafka.common.PartitionInfo; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.streams.errors.StreamsException; +import org.apache.kafka.streams.processor.TaskId; +import org.apache.kafka.streams.processor.internals.TopologyMetadata.Subtopology; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; + +/** + * Groups partitions by the partition id. + * + * Join operations requires that topics of the joining entities are copartitioned, i.e., being partitioned by the same key and having the same + * number of partitions. Copartitioning is ensured by having the same number of partitions on + * joined topics, and by using the serialization and Producer's default partitioner. + */ +public class PartitionGrouper { + + private static final Logger log = LoggerFactory.getLogger(PartitionGrouper.class); + /** + * Generate tasks with the assigned topic partitions. 
+ * + * @param topicGroups group of topics that need to be joined together + * @param metadata metadata of the consuming cluster + * @return The map from generated task ids to the assigned partitions + */ + public Map> partitionGroups(final Map> topicGroups, final Cluster metadata) { + final Map> groups = new HashMap<>(); + + for (final Map.Entry> entry : topicGroups.entrySet()) { + final Subtopology subtopology = entry.getKey(); + final Set topicGroup = entry.getValue(); + + final int maxNumPartitions = maxNumPartitions(metadata, topicGroup); + + for (int partitionId = 0; partitionId < maxNumPartitions; partitionId++) { + final Set group = new HashSet<>(topicGroup.size()); + + for (final String topic : topicGroup) { + final List partitions = metadata.partitionsForTopic(topic); + if (partitionId < partitions.size()) { + group.add(new TopicPartition(topic, partitionId)); + } + } + groups.put(new TaskId(subtopology.nodeGroupId, partitionId, subtopology.namedTopology), Collections.unmodifiableSet(group)); + } + } + + return Collections.unmodifiableMap(groups); + } + + /** + * @throws StreamsException if no metadata can be received for a topic + */ + protected int maxNumPartitions(final Cluster metadata, final Set topics) { + int maxNumPartitions = 0; + for (final String topic : topics) { + final List partitions = metadata.partitionsForTopic(topic); + if (partitions.isEmpty()) { + log.error("Empty partitions for topic {}", topic); + throw new RuntimeException("Empty partitions for topic " + topic); + } + + final int numPartitions = partitions.size(); + if (numPartitions > maxNumPartitions) { + maxNumPartitions = numPartitions; + } + } + return maxNumPartitions; + } + +} + + + diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/ProcessorAdapter.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ProcessorAdapter.java new file mode 100644 index 0000000..79db384 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ProcessorAdapter.java @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
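`partitionGroups` above crosses each subtopology with the partition ids of its source topics, so a subtopology reading a 3-partition topic yields tasks 0_0, 0_1 and 0_2. The toy restatement below uses plain strings and a map of partition counts instead of `TaskId` and cluster metadata; all names are placeholders.

```java
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

final class TaskIdSketch {

    // Task ids pair the subtopology (node group) id with a partition id: "<group>_<partition>".
    static Map<String, List<String>> tasksFor(final int subtopologyId,
                                              final Map<String, Integer> partitionsPerSourceTopic) {
        int maxPartitions = 0;
        for (final int partitions : partitionsPerSourceTopic.values()) {
            maxPartitions = Math.max(maxPartitions, partitions);
        }

        final Map<String, List<String>> tasks = new LinkedHashMap<>();
        for (int partition = 0; partition < maxPartitions; partition++) {
            final List<String> topicPartitions = new ArrayList<>();
            for (final Map.Entry<String, Integer> entry : partitionsPerSourceTopic.entrySet()) {
                if (partition < entry.getValue()) {
                    topicPartitions.add(entry.getKey() + "-" + partition);
                }
            }
            tasks.put(subtopologyId + "_" + partition, topicPartitions);
        }
        return tasks;
    }

    public static void main(final String[] args) {
        // A subtopology over two copartitioned 3-partition topics yields tasks 0_0, 0_1 and 0_2.
        System.out.println(tasksFor(0, Map.of("orders", 3, "customers", 3)));
    }
}
```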
+ */
+package org.apache.kafka.streams.processor.internals;
+
+
+import org.apache.kafka.streams.processor.api.Processor;
+import org.apache.kafka.streams.processor.api.ProcessorContext;
+import org.apache.kafka.streams.processor.api.Record;
+
+@SuppressWarnings("deprecation") // Old PAPI compatibility
+public final class ProcessorAdapter<KIn, VIn, KOut, VOut> implements Processor<KIn, VIn, KOut, VOut> {
+    private final org.apache.kafka.streams.processor.Processor<KIn, VIn> delegate;
+    private InternalProcessorContext context;
+
+    public static <KIn, VIn, KOut, VOut> Processor<KIn, VIn, KOut, VOut> adapt(final org.apache.kafka.streams.processor.Processor<KIn, VIn> delegate) {
+        if (delegate == null) {
+            return null;
+        } else {
+            return new ProcessorAdapter<>(delegate);
+        }
+    }
+
+    @SuppressWarnings({"rawtypes", "unchecked"})
+    public static <KIn, VIn, KOut, VOut> Processor<KIn, VIn, KOut, VOut> adaptRaw(final org.apache.kafka.streams.processor.Processor delegate) {
+        if (delegate == null) {
+            return null;
+        } else {
+            return new ProcessorAdapter<>(delegate);
+        }
+    }
+
+    private ProcessorAdapter(final org.apache.kafka.streams.processor.Processor<KIn, VIn> delegate) {
+        this.delegate = delegate;
+    }
+
+    @Override
+    public void init(final ProcessorContext<KOut, VOut> context) {
+        // It only makes sense to use this adapter internally to Streams, in which case
+        // all contexts are implementations of InternalProcessorContext.
+        // This would fail if someone were to use this adapter in a unit test where
+        // the context only implements api.ProcessorContext.
+        this.context = (InternalProcessorContext) context;
+        delegate.init((org.apache.kafka.streams.processor.ProcessorContext) context);
+    }
+
+    @Override
+    public void process(final Record<KIn, VIn> record) {
+        final ProcessorRecordContext processorRecordContext = context.recordContext();
+        try {
+            context.setRecordContext(new ProcessorRecordContext(
+                record.timestamp(),
+                context.offset(),
+                context.partition(),
+                context.topic(),
+                record.headers()
+            ));
+            delegate.process(record.key(), record.value());
+        } finally {
+            context.setRecordContext(processorRecordContext);
+        }
+    }
+
+    @Override
+    public void close() {
+        delegate.close();
+    }
+}
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/ProcessorContextImpl.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ProcessorContextImpl.java
new file mode 100644
index 0000000..ce06cb1
--- /dev/null
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ProcessorContextImpl.java
@@ -0,0 +1,332 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +package org.apache.kafka.streams.processor.internals; + +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.errors.StreamsException; +import org.apache.kafka.streams.processor.Cancellable; +import org.apache.kafka.streams.processor.PunctuationType; +import org.apache.kafka.streams.processor.Punctuator; +import org.apache.kafka.streams.processor.StateStore; +import org.apache.kafka.streams.processor.TaskId; +import org.apache.kafka.streams.processor.To; +import org.apache.kafka.streams.processor.api.Record; +import org.apache.kafka.streams.processor.internals.Task.TaskType; +import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl; +import org.apache.kafka.streams.state.internals.ThreadCache; +import org.apache.kafka.streams.state.internals.ThreadCache.DirtyEntryFlushListener; + +import java.time.Duration; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.apache.kafka.streams.internals.ApiUtils.prepareMillisCheckFailMsgPrefix; +import static org.apache.kafka.streams.internals.ApiUtils.validateMillisecondDuration; +import static org.apache.kafka.streams.processor.internals.AbstractReadOnlyDecorator.getReadOnlyStore; +import static org.apache.kafka.streams.processor.internals.AbstractReadWriteDecorator.getReadWriteStore; + +public class ProcessorContextImpl extends AbstractProcessorContext implements RecordCollector.Supplier { + // the below are null for standby tasks + private StreamTask streamTask; + private RecordCollector collector; + + private final ProcessorStateManager stateManager; + + final Map cacheNameToFlushListener = new HashMap<>(); + + public ProcessorContextImpl(final TaskId id, + final StreamsConfig config, + final ProcessorStateManager stateMgr, + final StreamsMetricsImpl metrics, + final ThreadCache cache) { + super(id, config, metrics, cache); + stateManager = stateMgr; + } + + @Override + public void transitionToActive(final StreamTask streamTask, final RecordCollector recordCollector, final ThreadCache newCache) { + if (stateManager.taskType() != TaskType.ACTIVE) { + throw new IllegalStateException("Tried to transition processor context to active but the state manager's " + + "type was " + stateManager.taskType()); + } + this.streamTask = streamTask; + this.collector = recordCollector; + this.cache = newCache; + addAllFlushListenersToNewCache(); + } + + @Override + public void transitionToStandby(final ThreadCache newCache) { + if (stateManager.taskType() != TaskType.STANDBY) { + throw new IllegalStateException("Tried to transition processor context to standby but the state manager's " + + "type was " + stateManager.taskType()); + } + this.streamTask = null; + this.collector = null; + this.cache = newCache; + addAllFlushListenersToNewCache(); + } + + @Override + public void registerCacheFlushListener(final String namespace, final DirtyEntryFlushListener listener) { + cacheNameToFlushListener.put(namespace, listener); + cache.addDirtyEntryFlushListener(namespace, listener); + } + + private void addAllFlushListenersToNewCache() { + for (final Map.Entry cacheEntry : cacheNameToFlushListener.entrySet()) { + cache.addDirtyEntryFlushListener(cacheEntry.getKey(), cacheEntry.getValue()); + } + } + + @Override + public ProcessorStateManager stateManager() { + return stateManager; + } + + @Override + public RecordCollector recordCollector() { + return collector; + } + + @Override + public 
void logChange(final String storeName, + final Bytes key, + final byte[] value, + final long timestamp) { + throwUnsupportedOperationExceptionIfStandby("logChange"); + + final TopicPartition changelogPartition = stateManager().registeredChangelogPartitionFor(storeName); + + // Sending null headers to changelog topics (KIP-244) + collector.send( + changelogPartition.topic(), + key, + value, + null, + changelogPartition.partition(), + timestamp, + BYTES_KEY_SERIALIZER, + BYTEARRAY_VALUE_SERIALIZER + ); + } + + /** + * @throws StreamsException if an attempt is made to access this state store from an unknown node + * @throws UnsupportedOperationException if the current streamTask type is standby + */ + @SuppressWarnings("unchecked") + @Override + public S getStateStore(final String name) { + throwUnsupportedOperationExceptionIfStandby("getStateStore"); + if (currentNode() == null) { + throw new StreamsException("Accessing from an unknown node"); + } + + final StateStore globalStore = stateManager.getGlobalStore(name); + if (globalStore != null) { + return (S) getReadOnlyStore(globalStore); + } + + if (!currentNode().stateStores.contains(name)) { + throw new StreamsException("Processor " + currentNode().name() + " has no access to StateStore " + name + + " as the store is not connected to the processor. If you add stores manually via '.addStateStore()' " + + "make sure to connect the added store to the processor by providing the processor name to " + + "'.addStateStore()' or connect them via '.connectProcessorAndStateStores()'. " + + "DSL users need to provide the store name to '.process()', '.transform()', or '.transformValues()' " + + "to connect the store to the corresponding operator, or they can provide a StoreBuilder by implementing " + + "the stores() method on the Supplier itself. If you do not add stores manually, " + + "please file a bug report at https://issues.apache.org/jira/projects/KAFKA."); + } + + final StateStore store = stateManager.getStore(name); + return (S) getReadWriteStore(store); + } + + @Override + public void forward(final K key, + final V value) { + final Record toForward = new Record<>( + key, + value, + timestamp(), + headers() + ); + forward(toForward); + } + + @Override + public void forward(final K key, + final V value, + final To to) { + final ToInternal toInternal = new ToInternal(to); + final Record toForward = new Record<>( + key, + value, + toInternal.hasTimestamp() ? toInternal.timestamp() : timestamp(), + headers() + ); + forward(toForward, toInternal.child()); + } + + @Override + public void forward(final Record record) { + forward(record, null); + } + + @SuppressWarnings("unchecked") + @Override + public void forward(final Record record, final String childName) { + throwUnsupportedOperationExceptionIfStandby("forward"); + + final ProcessorNode previousNode = currentNode(); + if (previousNode == null) { + throw new StreamsException("Current node is unknown. This can happen if 'forward()' is called " + + "in an illegal scope. The root cause could be that a 'Processor' or 'Transformer' instance" + + " is shared. To avoid this error, make sure that your suppliers return new instances " + + "each time 'get()' of Supplier is called and do not return the same object reference " + + "multiple times."); + } + + final ProcessorRecordContext previousContext = recordContext; + + try { + // we don't want to set the recordContext if it's null, since that means that + // the context itself is undefined. 
this isn't perfect, since downstream + // old API processors wouldn't see the timestamps or headers of upstream + // new API processors. But then again, from the perspective of those old-API + // processors, even consulting the timestamp or headers when the record context + // is undefined is itself not well defined. Plus, I don't think we need to worry + // too much about heterogeneous applications, in which the upstream processor is + // implementing the new API and the downstream one is implementing the old API. + // So, this seems like a fine compromise for now. + if (recordContext != null && (record.timestamp() != timestamp() || record.headers() != headers())) { + recordContext = new ProcessorRecordContext( + record.timestamp(), + recordContext.offset(), + recordContext.partition(), + recordContext.topic(), + record.headers()); + } + + if (childName == null) { + final List> children = currentNode().children(); + for (final ProcessorNode child : children) { + forwardInternal((ProcessorNode) child, record); + } + } else { + final ProcessorNode child = currentNode().getChild(childName); + if (child == null) { + throw new StreamsException("Unknown downstream node: " + childName + + " either does not exist or is not connected to this processor."); + } + forwardInternal((ProcessorNode) child, record); + } + + } finally { + recordContext = previousContext; + setCurrentNode(previousNode); + } + } + + private void forwardInternal(final ProcessorNode child, + final Record record) { + setCurrentNode(child); + + child.process(record); + + if (child.isTerminalNode()) { + streamTask.maybeRecordE2ELatency(record.timestamp(), currentSystemTimeMs(), child.name()); + } + } + + @Override + public void commit() { + throwUnsupportedOperationExceptionIfStandby("commit"); + streamTask.requestCommit(); + } + + @Override + public Cancellable schedule(final Duration interval, + final PunctuationType type, + final Punctuator callback) throws IllegalArgumentException { + throwUnsupportedOperationExceptionIfStandby("schedule"); + final String msgPrefix = prepareMillisCheckFailMsgPrefix(interval, "interval"); + final long intervalMs = validateMillisecondDuration(interval, msgPrefix); + if (intervalMs < 1) { + throw new IllegalArgumentException("The minimum supported scheduling interval is 1 millisecond."); + } + return streamTask.schedule(intervalMs, type, callback); + } + + @Override + public String topic() { + throwUnsupportedOperationExceptionIfStandby("topic"); + return super.topic(); + } + + @Override + public int partition() { + throwUnsupportedOperationExceptionIfStandby("partition"); + return super.partition(); + } + + @Override + public long offset() { + throwUnsupportedOperationExceptionIfStandby("offset"); + return super.offset(); + } + + @Override + public long timestamp() { + throwUnsupportedOperationExceptionIfStandby("timestamp"); + return super.timestamp(); + } + + @Override + public long currentStreamTimeMs() { + return streamTask.streamTime(); + } + + @Override + public ProcessorNode currentNode() { + throwUnsupportedOperationExceptionIfStandby("currentNode"); + return super.currentNode(); + } + + @Override + public void setRecordContext(final ProcessorRecordContext recordContext) { + throwUnsupportedOperationExceptionIfStandby("setRecordContext"); + super.setRecordContext(recordContext); + } + + @Override + public ProcessorRecordContext recordContext() { + throwUnsupportedOperationExceptionIfStandby("recordContext"); + return super.recordContext(); + } + + private void 
throwUnsupportedOperationExceptionIfStandby(final String operationName) { + if (taskType() == TaskType.STANDBY) { + throw new UnsupportedOperationException( + "this should not happen: " + operationName + "() is not supported in standby tasks."); + } + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/ProcessorContextUtils.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ProcessorContextUtils.java new file mode 100644 index 0000000..4de059e --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ProcessorContextUtils.java @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.processor.internals; + +import org.apache.kafka.streams.processor.ProcessorContext; +import org.apache.kafka.streams.processor.StateStoreContext; +import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl; + +/** + * This class bridges the gap for components that _should_ be compatible with + * the public ProcessorContext interface, but have come to depend on features + * in InternalProcessorContext. In theory, all the features adapted here could + * migrate to the public interface, so each method in this class should reference + * the ticket that would ultimately obviate it. + */ +public final class ProcessorContextUtils { + + private ProcessorContextUtils() {} + + /** + * Note that KIP-622 would move currentSystemTimeMs to ProcessorContext, + * removing the need for this method. + */ + public static long currentSystemTime(final ProcessorContext context) { + return context.currentSystemTimeMs(); + } + + /** + * Should be removed as part of KAFKA-10217 + */ + public static StreamsMetricsImpl getMetricsImpl(final ProcessorContext context) { + return (StreamsMetricsImpl) context.metrics(); + } + + /** + * Should be removed as part of KAFKA-10217 + */ + public static StreamsMetricsImpl getMetricsImpl(final StateStoreContext context) { + return (StreamsMetricsImpl) context.metrics(); + } + + public static String changelogFor(final ProcessorContext context, final String storeName) { + return context instanceof InternalProcessorContext + ? ((InternalProcessorContext) context).changelogFor(storeName) + : ProcessorStateManager.storeChangelogTopic(context.applicationId(), storeName, context.taskId().topologyName()); + } + + public static String changelogFor(final StateStoreContext context, final String storeName) { + return context instanceof InternalProcessorContext + ? 
((InternalProcessorContext) context).changelogFor(storeName) + : ProcessorStateManager.storeChangelogTopic(context.applicationId(), storeName, context.taskId().topologyName()); + } + + public static InternalProcessorContext asInternalProcessorContext(final ProcessorContext context) { + if (context instanceof InternalProcessorContext) { + return (InternalProcessorContext) context; + } else { + throw new IllegalArgumentException( + "This component requires internal features of Kafka Streams and must be disabled for unit tests." + ); + } + } + + public static InternalProcessorContext asInternalProcessorContext(final StateStoreContext context) { + if (context instanceof InternalProcessorContext) { + return (InternalProcessorContext) context; + } else { + throw new IllegalArgumentException( + "This component requires internal features of Kafka Streams and must be disabled for unit tests." + ); + } + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/ProcessorNode.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ProcessorNode.java new file mode 100644 index 0000000..48c95f1 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ProcessorNode.java @@ -0,0 +1,197 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.processor.internals; + +import org.apache.kafka.common.utils.SystemTime; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.streams.errors.StreamsException; +import org.apache.kafka.streams.processor.Punctuator; +import org.apache.kafka.streams.processor.api.Processor; +import org.apache.kafka.streams.processor.api.Record; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; + +public class ProcessorNode { + + private final List> children; + private final Map> childByName; + + private final Processor processor; + private final String name; + private final Time time; + + public final Set stateStores; + + private InternalProcessorContext internalProcessorContext; + private String threadId; + + private boolean closed = true; + + public ProcessorNode(final String name) { + this(name, (Processor) null, null); + } + + public ProcessorNode(final String name, + final Processor processor, + final Set stateStores) { + + this.name = name; + this.processor = processor; + this.children = new ArrayList<>(); + this.childByName = new HashMap<>(); + this.stateStores = stateStores; + this.time = new SystemTime(); + } + + public ProcessorNode(final String name, + @SuppressWarnings("deprecation") final org.apache.kafka.streams.processor.Processor processor, + final Set stateStores) { + + this.name = name; + this.processor = ProcessorAdapter.adapt(processor); + this.children = new ArrayList<>(); + this.childByName = new HashMap<>(); + this.stateStores = stateStores; + this.time = new SystemTime(); + } + + public final String name() { + return name; + } + + public final Processor processor() { + return processor; + } + + public List> children() { + return children; + } + + ProcessorNode getChild(final String childName) { + return childByName.get(childName); + } + + public void addChild(final ProcessorNode child) { + children.add(child); + childByName.put(child.name, child); + } + + public void init(final InternalProcessorContext context) { + if (!closed) + throw new IllegalStateException("The processor is not closed"); + + try { + threadId = Thread.currentThread().getName(); + internalProcessorContext = context; + if (processor != null) { + processor.init(context); + } + } catch (final Exception e) { + throw new StreamsException(String.format("failed to initialize processor %s", name), e); + } + + // revived tasks could re-initialize the topology, + // in which case we should reset the flag + closed = false; + } + + public void close() { + throwIfClosed(); + + try { + if (processor != null) { + processor.close(); + } + internalProcessorContext.metrics().removeAllNodeLevelSensors( + threadId, + internalProcessorContext.taskId().toString(), + name + ); + } catch (final Exception e) { + throw new StreamsException(String.format("failed to close processor %s", name), e); + } + + closed = true; + } + + protected void throwIfClosed() { + if (closed) { + throw new IllegalStateException("The processor is already closed"); + } + } + + + public void process(final Record record) { + throwIfClosed(); + + try { + processor.process(record); + } catch (final ClassCastException e) { + final String keyClass = record.key() == null ? "unknown because key is null" : record.key().getClass().getName(); + final String valueClass = record.value() == null ? 
"unknown because value is null" : record.value().getClass().getName(); + throw new StreamsException(String.format("ClassCastException invoking processor: %s. Do the Processor's " + + "input types match the deserialized types? Check the Serde setup and change the default Serdes in " + + "StreamConfig or provide correct Serdes via method parameters. Make sure the Processor can accept " + + "the deserialized input of type key: %s, and value: %s.%n" + + "Note that although incorrect Serdes are a common cause of error, the cast exception might have " + + "another cause (in user code, for example). For example, if a processor wires in a store, but casts " + + "the generics incorrectly, a class cast exception could be raised during processing, but the " + + "cause would not be wrong Serdes.", + this.name(), + keyClass, + valueClass), + e); + } + } + + public void punctuate(final long timestamp, final Punctuator punctuator) { + punctuator.punctuate(timestamp); + } + + public boolean isTerminalNode() { + return children.isEmpty(); + } + + /** + * @return a string representation of this node, useful for debugging. + */ + @Override + public String toString() { + return toString(""); + } + + /** + * @return a string representation of this node starting with the given indent, useful for debugging. + */ + public String toString(final String indent) { + final StringBuilder sb = new StringBuilder(indent + name + ":\n"); + if (stateStores != null && !stateStores.isEmpty()) { + sb.append(indent).append("\tstates:\t\t["); + for (final String store : stateStores) { + sb.append(store); + sb.append(", "); + } + sb.setLength(sb.length() - 2); // remove the last comma + sb.append("]\n"); + } + return sb.toString(); + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/ProcessorNodePunctuator.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ProcessorNodePunctuator.java new file mode 100644 index 0000000..5544adc --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ProcessorNodePunctuator.java @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.processor.internals; + +import org.apache.kafka.streams.processor.PunctuationType; +import org.apache.kafka.streams.processor.Punctuator; + +public interface ProcessorNodePunctuator { + + void punctuate(ProcessorNode node, long timestamp, PunctuationType type, Punctuator punctuator); + +} diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/ProcessorRecordContext.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ProcessorRecordContext.java new file mode 100644 index 0000000..07e9ab3 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ProcessorRecordContext.java @@ -0,0 +1,214 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.processor.internals; + +import org.apache.kafka.common.header.Header; +import org.apache.kafka.common.header.Headers; +import org.apache.kafka.common.header.internals.RecordHeader; +import org.apache.kafka.common.header.internals.RecordHeaders; +import org.apache.kafka.streams.processor.RecordContext; +import org.apache.kafka.streams.processor.api.RecordMetadata; + +import java.nio.ByteBuffer; +import java.util.Objects; + +import static java.nio.charset.StandardCharsets.UTF_8; +import static java.util.Objects.requireNonNull; +import static org.apache.kafka.common.utils.Utils.getNullableSizePrefixedArray; + +public class ProcessorRecordContext implements RecordContext, RecordMetadata { + + private final long timestamp; + private final long offset; + private final String topic; + private final int partition; + private final Headers headers; + + public ProcessorRecordContext(final long timestamp, + final long offset, + final int partition, + final String topic, + final Headers headers) { + this.timestamp = timestamp; + this.offset = offset; + this.topic = topic; + this.partition = partition; + this.headers = Objects.requireNonNull(headers); + } + + @Override + public long offset() { + return offset; + } + + @Override + public long timestamp() { + return timestamp; + } + + @Override + public String topic() { + return topic; + } + + @Override + public int partition() { + return partition; + } + + @Override + public Headers headers() { + return headers; + } + + public long residentMemorySizeEstimate() { + long size = 0; + size += Long.BYTES; // value.context.timestamp + size += Long.BYTES; // value.context.offset + if (topic != null) { + size += topic.toCharArray().length; + } + size += Integer.BYTES; // partition + for (final Header header : headers) { + size += header.key().toCharArray().length; + final byte[] value = header.value(); + if (value != null) { + size += value.length; + } + } + return size; + } + + public byte[] serialize() { + final byte[] topicBytes = 
topic.getBytes(UTF_8); + final byte[][] headerKeysBytes; + final byte[][] headerValuesBytes; + + int size = 0; + size += Long.BYTES; // value.context.timestamp + size += Long.BYTES; // value.context.offset + size += Integer.BYTES; // size of topic + size += topicBytes.length; + size += Integer.BYTES; // partition + size += Integer.BYTES; // number of headers + + final Header[] headers = this.headers.toArray(); + headerKeysBytes = new byte[headers.length][]; + headerValuesBytes = new byte[headers.length][]; + + for (int i = 0; i < headers.length; i++) { + size += 2 * Integer.BYTES; // sizes of key and value + + final byte[] keyBytes = headers[i].key().getBytes(UTF_8); + size += keyBytes.length; + final byte[] valueBytes = headers[i].value(); + if (valueBytes != null) { + size += valueBytes.length; + } + + headerKeysBytes[i] = keyBytes; + headerValuesBytes[i] = valueBytes; + } + + final ByteBuffer buffer = ByteBuffer.allocate(size); + buffer.putLong(timestamp); + buffer.putLong(offset); + + // not handling the null condition because we believe topic will never be null in cases where we serialize + buffer.putInt(topicBytes.length); + buffer.put(topicBytes); + + buffer.putInt(partition); + buffer.putInt(headerKeysBytes.length); + for (int i = 0; i < headerKeysBytes.length; i++) { + buffer.putInt(headerKeysBytes[i].length); + buffer.put(headerKeysBytes[i]); + + if (headerValuesBytes[i] != null) { + buffer.putInt(headerValuesBytes[i].length); + buffer.put(headerValuesBytes[i]); + } else { + buffer.putInt(-1); + } + } + + return buffer.array(); + } + + public static ProcessorRecordContext deserialize(final ByteBuffer buffer) { + final long timestamp = buffer.getLong(); + final long offset = buffer.getLong(); + final String topic; + { + // we believe the topic will never be null when we serialize + final byte[] topicBytes = requireNonNull(getNullableSizePrefixedArray(buffer)); + topic = new String(topicBytes, UTF_8); + } + final int partition = buffer.getInt(); + final int headerCount = buffer.getInt(); + final Headers headers; + if (headerCount == -1) { // keep for backward compatibilty + headers = new RecordHeaders(); + } else { + final Header[] headerArr = new Header[headerCount]; + for (int i = 0; i < headerCount; i++) { + final byte[] keyBytes = requireNonNull(getNullableSizePrefixedArray(buffer)); + final byte[] valueBytes = getNullableSizePrefixedArray(buffer); + headerArr[i] = new RecordHeader(new String(keyBytes, UTF_8), valueBytes); + } + headers = new RecordHeaders(headerArr); + } + + return new ProcessorRecordContext(timestamp, offset, partition, topic, headers); + } + + @Override + public boolean equals(final Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + final ProcessorRecordContext that = (ProcessorRecordContext) o; + return timestamp == that.timestamp && + offset == that.offset && + partition == that.partition && + Objects.equals(topic, that.topic) && + Objects.equals(headers, that.headers); + } + + /** + * Equality is implemented in support of tests, *not* for use in Hash collections, since this class is mutable. 
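+     * <p>
+     * An illustrative consequence (hypothetical caller, not part of this patch): code such as
+     * {@code new HashSet<ProcessorRecordContext>().add(recordContext)} fails fast with an
+     * {@code UnsupportedOperationException}, rather than silently misbehaving once the mutable
+     * context changes after insertion.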
+ */ + @Deprecated + @Override + public int hashCode() { + throw new UnsupportedOperationException("ProcessorRecordContext is unsafe for use in Hash collections"); + } + + @Override + public String toString() { + return "ProcessorRecordContext{" + + "topic='" + topic + '\'' + + ", partition=" + partition + + ", offset=" + offset + + ", timestamp=" + timestamp + + ", headers=" + headers + + '}'; + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/ProcessorStateManager.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ProcessorStateManager.java new file mode 100644 index 0000000..b507aea --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ProcessorStateManager.java @@ -0,0 +1,685 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.processor.internals; + +import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.utils.FixedOrderMap; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.streams.errors.ProcessorStateException; +import org.apache.kafka.streams.errors.StreamsException; +import org.apache.kafka.streams.errors.TaskCorruptedException; +import org.apache.kafka.streams.errors.TaskMigratedException; +import org.apache.kafka.streams.processor.StateRestoreCallback; +import org.apache.kafka.streams.processor.StateRestoreListener; +import org.apache.kafka.streams.processor.StateStore; +import org.apache.kafka.streams.processor.StateStoreContext; +import org.apache.kafka.streams.processor.TaskId; +import org.apache.kafka.streams.processor.internals.Task.TaskType; +import org.apache.kafka.streams.state.internals.CachedStateStore; +import org.apache.kafka.streams.state.internals.OffsetCheckpoint; +import org.apache.kafka.streams.state.internals.RecordConverter; +import org.apache.kafka.streams.state.internals.TimeOrderedKeyValueBuffer; +import org.slf4j.Logger; + +import java.io.File; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +import static java.lang.String.format; +import static org.apache.kafka.streams.processor.internals.StateManagerUtil.CHECKPOINT_FILE_NAME; +import static org.apache.kafka.streams.processor.internals.StateManagerUtil.converterForStore; +import static org.apache.kafka.streams.processor.internals.StateRestoreCallbackAdapter.adapt; +import static org.apache.kafka.streams.state.internals.OffsetCheckpoint.OFFSET_UNKNOWN; + +/** + * ProcessorStateManager is the source of truth for the current offset for each state 
store,
+ * which is either the offset read during restoration or the offset written during normal processing.
+ *
+ * The offset is initialized as null when the state store is registered, and can then be updated by
+ * loading the checkpoint file, by restoring the state store, or from the record collector's written offsets.
+ *
+ * When checkpointing, the offset is written to the checkpoint file only if it is not null.
+ *
+ * The manager is also responsible for restoring state stores via their registered restore callbacks;
+ * this is used both for updating standby tasks and for restoring active tasks.
+ */
+public class ProcessorStateManager implements StateManager {
+
+    public static class StateStoreMetadata {
+        private final StateStore stateStore;
+
+        // the corresponding changelog partition of the store; this and the following two fields
+        // are only non-null if the state store is logged (i.e. a changelog partition and restorer are provided)
+        private final TopicPartition changelogPartition;
+
+        // could be used for both active restoration and standby
+        private final StateRestoreCallback restoreCallback;
+
+        // record converter used for restoration and standby
+        private final RecordConverter recordConverter;
+
+        // the current snapshot of the store, i.e. the offset of the last changelog record that has been
+        // applied to the store; used both for restoration (active and standby tasks' restored offset) and
+        // for normal processing that updates the store (written offset); could be null (when initialized)
+        //
+        // the offset is updated in three ways:
+        //   1. when loading from the checkpoint file, once the corresponding task has acquired the state
+        //      directory lock and has registered all of its state stores; this happens only once
+        //   2. when applying restore records (for both restoring active tasks and standby tasks),
+        //      it is updated to the last restored record's offset
+        //   3.
when checkpointing with the given written offsets from record collector, + // update blindly with the given offset + private Long offset; + + // corrupted state store should not be included in checkpointing + private boolean corrupted; + + private StateStoreMetadata(final StateStore stateStore) { + this.stateStore = stateStore; + this.restoreCallback = null; + this.recordConverter = null; + this.changelogPartition = null; + this.corrupted = false; + this.offset = null; + } + + private StateStoreMetadata(final StateStore stateStore, + final TopicPartition changelogPartition, + final StateRestoreCallback restoreCallback, + final RecordConverter recordConverter) { + if (restoreCallback == null) { + throw new IllegalStateException("Log enabled store should always provide a restore callback upon registration"); + } + + this.stateStore = stateStore; + this.changelogPartition = changelogPartition; + this.restoreCallback = restoreCallback; + this.recordConverter = recordConverter; + this.offset = null; + } + + private void setOffset(final Long offset) { + this.offset = offset; + } + + // the offset is exposed to the changelog reader to determine if restoration is completed + Long offset() { + return this.offset; + } + + TopicPartition changelogPartition() { + return this.changelogPartition; + } + + StateStore store() { + return this.stateStore; + } + + @Override + public String toString() { + return "StateStoreMetadata (" + stateStore.name() + " : " + changelogPartition + " @ " + offset; + } + } + + private static final String STATE_CHANGELOG_TOPIC_SUFFIX = "-changelog"; + + private Logger log; + private String logPrefix; + + private final TaskId taskId; + private final boolean eosEnabled; + private final ChangelogRegister changelogReader; + private final Collection sourcePartitions; + private final Map storeToChangelogTopic; + + // must be maintained in topological order + private final FixedOrderMap stores = new FixedOrderMap<>(); + private final FixedOrderMap globalStores = new FixedOrderMap<>(); + + private final File baseDir; + private final OffsetCheckpoint checkpointFile; + + private TaskType taskType; + + public static String storeChangelogTopic(final String applicationId, final String storeName, final String namedTopology) { + if (namedTopology == null) { + return applicationId + "-" + storeName + STATE_CHANGELOG_TOPIC_SUFFIX; + } else { + return applicationId + "-" + namedTopology + "-" + storeName + STATE_CHANGELOG_TOPIC_SUFFIX; + } + } + + /** + * @throws ProcessorStateException if the task directory does not exist and could not be created + */ + public ProcessorStateManager(final TaskId taskId, + final TaskType taskType, + final boolean eosEnabled, + final LogContext logContext, + final StateDirectory stateDirectory, + final ChangelogRegister changelogReader, + final Map storeToChangelogTopic, + final Collection sourcePartitions) throws ProcessorStateException { + this.storeToChangelogTopic = storeToChangelogTopic; + this.log = logContext.logger(ProcessorStateManager.class); + this.logPrefix = logContext.logPrefix(); + this.taskId = taskId; + this.taskType = taskType; + this.eosEnabled = eosEnabled; + this.changelogReader = changelogReader; + this.sourcePartitions = sourcePartitions; + + this.baseDir = stateDirectory.getOrCreateDirectoryForTask(taskId); + this.checkpointFile = new OffsetCheckpoint(stateDirectory.checkpointFileFor(taskId)); + + log.debug("Created state store manager for task {}", taskId); + } + + void registerStateStores(final List allStores, final 
InternalProcessorContext processorContext) { + processorContext.uninitialize(); + for (final StateStore store : allStores) { + if (stores.containsKey(store.name())) { + maybeRegisterStoreWithChangelogReader(store.name()); + } else { + store.init((StateStoreContext) processorContext, store); + } + log.trace("Registered state store {}", store.name()); + } + } + + void registerGlobalStateStores(final List stateStores) { + log.debug("Register global stores {}", stateStores); + for (final StateStore stateStore : stateStores) { + globalStores.put(stateStore.name(), stateStore); + } + } + + @Override + public StateStore getGlobalStore(final String name) { + return globalStores.get(name); + } + + // package-private for test only + void initializeStoreOffsetsFromCheckpoint(final boolean storeDirIsEmpty) { + try { + final Map loadedCheckpoints = checkpointFile.read(); + + log.trace("Loaded offsets from the checkpoint file: {}", loadedCheckpoints); + + for (final StateStoreMetadata store : stores.values()) { + if (store.corrupted) { + log.error("Tried to initialize store offsets for corrupted store {}", store); + throw new IllegalStateException("Should not initialize offsets for a corrupted task"); + } + + if (store.changelogPartition == null) { + log.info("State store {} is not logged and hence would not be restored", store.stateStore.name()); + } else if (!store.stateStore.persistent()) { + log.info("Initializing to the starting offset for changelog {} of in-memory state store {}", + store.changelogPartition, store.stateStore.name()); + } else if (store.offset() == null) { + if (loadedCheckpoints.containsKey(store.changelogPartition)) { + final Long offset = changelogOffsetFromCheckpointedOffset(loadedCheckpoints.remove(store.changelogPartition)); + store.setOffset(offset); + + log.debug("State store {} initialized from checkpoint with offset {} at changelog {}", + store.stateStore.name(), store.offset, store.changelogPartition); + } else { + // with EOS, if the previous run did not shutdown gracefully, we may lost the checkpoint file + // and hence we are uncertain that the current local state only contains committed data; + // in that case we need to treat it as a task-corrupted exception + if (eosEnabled && !storeDirIsEmpty) { + log.warn("State store {} did not find checkpoint offsets while stores are not empty, " + + "since under EOS it has the risk of getting uncommitted data in stores we have to " + + "treat it as a task corruption error and wipe out the local state of task {} " + + "before re-bootstrapping", store.stateStore.name(), taskId); + + throw new TaskCorruptedException(Collections.singleton(taskId)); + } else { + log.info("State store {} did not find checkpoint offset, hence would " + + "default to the starting offset at changelog {}", + store.stateStore.name(), store.changelogPartition); + } + } + } else { + loadedCheckpoints.remove(store.changelogPartition); + log.debug("Skipping re-initialization of offset from checkpoint for recycled store {}", + store.stateStore.name()); + } + } + + if (!loadedCheckpoints.isEmpty()) { + log.warn("Some loaded checkpoint offsets cannot find their corresponding state stores: {}", loadedCheckpoints); + } + + if (eosEnabled) { + checkpointFile.delete(); + } + } catch (final TaskCorruptedException e) { + throw e; + } catch (final IOException | RuntimeException e) { + // both IOException or runtime exception like number parsing can throw + throw new ProcessorStateException(format("%sError loading and deleting checkpoint file when creating the state 
manager", + logPrefix), e); + } + } + + private void maybeRegisterStoreWithChangelogReader(final String storeName) { + if (isLoggingEnabled(storeName)) { + changelogReader.register(getStorePartition(storeName), this); + } + } + + private List getAllChangelogTopicPartitions() { + final List allChangelogPartitions = new ArrayList<>(); + for (final StateStoreMetadata storeMetadata : stores.values()) { + if (storeMetadata.changelogPartition != null) { + allChangelogPartitions.add(storeMetadata.changelogPartition); + } + } + return allChangelogPartitions; + } + + @Override + public File baseDir() { + return baseDir; + } + + @Override + public void registerStore(final StateStore store, final StateRestoreCallback stateRestoreCallback) { + final String storeName = store.name(); + + // TODO (KAFKA-12887): we should not trigger user's exception handler for illegal-argument but always + // fail-crash; in this case we would not need to immediately close the state store before throwing + if (CHECKPOINT_FILE_NAME.equals(storeName)) { + store.close(); + throw new IllegalArgumentException(format("%sIllegal store name: %s, which collides with the pre-defined " + + "checkpoint file name", logPrefix, storeName)); + } + + if (stores.containsKey(storeName)) { + store.close(); + throw new IllegalArgumentException(format("%sStore %s has already been registered.", logPrefix, storeName)); + } + + if (stateRestoreCallback instanceof StateRestoreListener) { + log.warn("The registered state restore callback is also implementing the state restore listener interface, " + + "which is not expected and would be ignored"); + } + + final StateStoreMetadata storeMetadata = isLoggingEnabled(storeName) ? + new StateStoreMetadata( + store, + getStorePartition(storeName), + stateRestoreCallback, + converterForStore(store)) : + new StateStoreMetadata(store); + + // register the store first, so that if later an exception is thrown then eventually while we call `close` + // on the state manager this state store would be closed as well + stores.put(storeName, storeMetadata); + + maybeRegisterStoreWithChangelogReader(storeName); + + log.debug("Registered state store {} to its state manager", storeName); + } + + @Override + public StateStore getStore(final String name) { + if (stores.containsKey(name)) { + return stores.get(name).stateStore; + } else { + return null; + } + } + + Collection changelogPartitions() { + return changelogOffsets().keySet(); + } + + void markChangelogAsCorrupted(final Collection partitions) { + for (final StateStoreMetadata storeMetadata : stores.values()) { + if (partitions.contains(storeMetadata.changelogPartition)) { + storeMetadata.corrupted = true; + partitions.remove(storeMetadata.changelogPartition); + } + } + + if (!partitions.isEmpty()) { + throw new IllegalStateException("Some partitions " + partitions + " are not contained in the store list of task " + + taskId + " marking as corrupted, this is not expected"); + } + } + + @Override + public Map changelogOffsets() { + // return the current offsets for those logged stores + final Map changelogOffsets = new HashMap<>(); + for (final StateStoreMetadata storeMetadata : stores.values()) { + if (storeMetadata.changelogPartition != null) { + // for changelog whose offset is unknown, use 0L indicating earliest offset + // otherwise return the current offset + 1 as the next offset to fetch + changelogOffsets.put( + storeMetadata.changelogPartition, + storeMetadata.offset == null ? 
0L : storeMetadata.offset + 1L); + } + } + return changelogOffsets; + } + + TaskId taskId() { + return taskId; + } + + // used by the changelog reader only + boolean changelogAsSource(final TopicPartition partition) { + return sourcePartitions.contains(partition); + } + + @Override + public TaskType taskType() { + return taskType; + } + + // used by the changelog reader only + StateStoreMetadata storeMetadata(final TopicPartition partition) { + for (final StateStoreMetadata storeMetadata : stores.values()) { + if (partition.equals(storeMetadata.changelogPartition)) { + return storeMetadata; + } + } + return null; + } + + // used by the changelog reader only + void restore(final StateStoreMetadata storeMetadata, final List> restoreRecords) { + if (!stores.containsValue(storeMetadata)) { + throw new IllegalStateException("Restoring " + storeMetadata + " which is not registered in this state manager, " + + "this should not happen."); + } + + if (!restoreRecords.isEmpty()) { + // restore states from changelog records and update the snapshot offset as the batch end record's offset + final Long batchEndOffset = restoreRecords.get(restoreRecords.size() - 1).offset(); + final RecordBatchingStateRestoreCallback restoreCallback = adapt(storeMetadata.restoreCallback); + final List> convertedRecords = restoreRecords.stream() + .map(storeMetadata.recordConverter::convert) + .collect(Collectors.toList()); + + try { + restoreCallback.restoreBatch(convertedRecords); + } catch (final RuntimeException e) { + throw new ProcessorStateException( + format("%sException caught while trying to restore state from %s", logPrefix, storeMetadata.changelogPartition), + e + ); + } + + storeMetadata.setOffset(batchEndOffset); + } + } + + /** + * @throws TaskMigratedException recoverable error sending changelog records that would cause the task to be removed + * @throws StreamsException fatal error when flushing the state store, for example sending changelog records failed + * or flushing state store get IO errors; such error should cause the thread to die + */ + @Override + public void flush() { + RuntimeException firstException = null; + // attempting to flush the stores + if (!stores.isEmpty()) { + log.debug("Flushing all stores registered in the state manager: {}", stores); + for (final StateStoreMetadata metadata : stores.values()) { + final StateStore store = metadata.stateStore; + log.trace("Flushing store {}", store.name()); + try { + store.flush(); + } catch (final RuntimeException exception) { + if (firstException == null) { + // do NOT wrap the error if it is actually caused by Streams itself + if (exception instanceof StreamsException) + firstException = exception; + else + firstException = new ProcessorStateException( + format("%sFailed to flush state store %s", logPrefix, store.name()), exception); + } + log.error("Failed to flush state store {}: ", store.name(), exception); + } + } + } + + if (firstException != null) { + throw firstException; + } + } + + public void flushCache() { + RuntimeException firstException = null; + // attempting to flush the stores + if (!stores.isEmpty()) { + log.debug("Flushing all store caches registered in the state manager: {}", stores); + for (final StateStoreMetadata metadata : stores.values()) { + final StateStore store = metadata.stateStore; + + try { + // buffer should be flushed to send all records to changelog + if (store instanceof TimeOrderedKeyValueBuffer) { + store.flush(); + } else if (store instanceof CachedStateStore) { + ((CachedStateStore) store).flushCache(); 
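+                        // Note: flushCache() only evicts the dirty entries of the caching layer (which forwards
+                        // them downstream and, for logged stores, into the changelog); flushing the underlying
+                        // store itself is deferred to flush() above, so store.flush() is deliberately not called here.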
+ } + log.trace("Flushed cache or buffer {}", store.name()); + } catch (final RuntimeException exception) { + if (firstException == null) { + // do NOT wrap the error if it is actually caused by Streams itself + if (exception instanceof StreamsException) { + firstException = exception; + } else { + firstException = new ProcessorStateException( + format("%sFailed to flush cache of store %s", logPrefix, store.name()), + exception + ); + } + } + log.error("Failed to flush cache of store {}: ", store.name(), exception); + } + } + } + + if (firstException != null) { + throw firstException; + } + } + + /** + * {@link StateStore#close() Close} all stores (even in case of failure). + * Log all exceptions and re-throw the first exception that occurred at the end. + * + * @throws ProcessorStateException if any error happens when closing the state stores + */ + @Override + public void close() throws ProcessorStateException { + log.debug("Closing its state manager and all the registered state stores: {}", stores); + + changelogReader.unregister(getAllChangelogTopicPartitions()); + + RuntimeException firstException = null; + // attempting to close the stores, just in case they + // are not closed by a ProcessorNode yet + if (!stores.isEmpty()) { + for (final Map.Entry entry : stores.entrySet()) { + final StateStore store = entry.getValue().stateStore; + log.trace("Closing store {}", store.name()); + try { + store.close(); + } catch (final RuntimeException exception) { + if (firstException == null) { + // do NOT wrap the error if it is actually caused by Streams itself + if (exception instanceof StreamsException) + firstException = exception; + else + firstException = new ProcessorStateException( + format("%sFailed to close state store %s", logPrefix, store.name()), exception); + } + log.error("Failed to close state store {}: ", store.name(), exception); + } + } + + stores.clear(); + } + + if (firstException != null) { + throw firstException; + } + } + + /** + * Alternative to {@link #close()} that just resets the changelogs without closing any of the underlying state + * or unregistering the stores themselves + */ + void recycle() { + log.debug("Recycling state for {} task {}.", taskType, taskId); + + final List allChangelogs = getAllChangelogTopicPartitions(); + changelogReader.unregister(allChangelogs); + } + + void transitionTaskType(final TaskType newType, final LogContext logContext) { + if (taskType.equals(newType)) { + throw new IllegalStateException("Tried to recycle state for task type conversion but new type was the same."); + } + + final TaskType oldType = taskType; + taskType = newType; + log = logContext.logger(ProcessorStateManager.class); + logPrefix = logContext.logPrefix(); + + log.debug("Transitioning state manager for {} task {} to {}", oldType, taskId, newType); + } + + @Override + public void updateChangelogOffsets(final Map writtenOffsets) { + for (final Map.Entry entry : writtenOffsets.entrySet()) { + final StateStoreMetadata store = findStore(entry.getKey()); + + if (store != null) { + store.setOffset(entry.getValue()); + + log.debug("State store {} updated to written offset {} at changelog {}", + store.stateStore.name(), store.offset, store.changelogPartition); + } + } + } + + @Override + public void checkpoint() { + // checkpoint those stores that are only logged and persistent to the checkpoint file + final Map checkpointingOffsets = new HashMap<>(); + for (final StateStoreMetadata storeMetadata : stores.values()) { + // store is logged, persistent, not corrupted, and has a 
valid current offset + if (storeMetadata.changelogPartition != null && + storeMetadata.stateStore.persistent() && + !storeMetadata.corrupted) { + + final long checkpointableOffset = checkpointableOffsetFromChangelogOffset(storeMetadata.offset); + checkpointingOffsets.put(storeMetadata.changelogPartition, checkpointableOffset); + } + } + + log.debug("Writing checkpoint: {}", checkpointingOffsets); + try { + checkpointFile.write(checkpointingOffsets); + } catch (final IOException e) { + log.warn("Failed to write offset checkpoint file to [{}]." + + " This may occur if OS cleaned the state.dir in case when it located in ${java.io.tmpdir} directory." + + " This may also occur due to running multiple instances on the same machine using the same state dir." + + " Changing the location of state.dir may resolve the problem.", + checkpointFile, e); + } + } + + private TopicPartition getStorePartition(final String storeName) { + // NOTE we assume the partition of the topic can always be inferred from the task id; + // if user ever use a custom partition grouper (deprecated in KIP-528) this would break and + // it is not a regression (it would always break anyways) + return new TopicPartition(changelogFor(storeName), taskId.partition()); + } + + private boolean isLoggingEnabled(final String storeName) { + // if the store name does not exist in the changelog map, it means the underlying store + // is not log enabled (including global stores) + return changelogFor(storeName) != null; + } + + private StateStoreMetadata findStore(final TopicPartition changelogPartition) { + final List found = stores.values().stream() + .filter(metadata -> changelogPartition.equals(metadata.changelogPartition)) + .collect(Collectors.toList()); + + if (found.size() > 1) { + throw new IllegalStateException("Multiple state stores are found for changelog partition " + changelogPartition + + ", this should never happen: " + found); + } + + return found.isEmpty() ? null : found.get(0); + } + + // Pass in a sentinel value to checkpoint when the changelog offset is not yet initialized/known + private long checkpointableOffsetFromChangelogOffset(final Long offset) { + return offset != null ? offset : OFFSET_UNKNOWN; + } + + // Convert the written offsets in the checkpoint file back to the changelog offset + private Long changelogOffsetFromCheckpointedOffset(final long offset) { + return offset != OFFSET_UNKNOWN ? offset : null; + } + + public TopicPartition registeredChangelogPartitionFor(final String storeName) { + final StateStoreMetadata storeMetadata = stores.get(storeName); + if (storeMetadata == null) { + throw new IllegalStateException("State store " + storeName + + " for which the registered changelog partition should be" + + " retrieved has not been registered" + ); + } + if (storeMetadata.changelogPartition == null) { + throw new IllegalStateException("Registered state store " + storeName + + " does not have a registered changelog partition." + + " This may happen if logging is disabled for the state store." 
+ ); + } + return storeMetadata.changelogPartition; + } + + public String changelogFor(final String storeName) { + return storeToChangelogTopic.get(storeName); + } + + public void deleteCheckPointFileIfEOSEnabled() throws IOException { + if (eosEnabled) { + checkpointFile.delete(); + } + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/ProcessorTopology.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ProcessorTopology.java new file mode 100644 index 0000000..d2383c7 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ProcessorTopology.java @@ -0,0 +1,250 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.processor.internals; + +import org.apache.kafka.streams.processor.StateStore; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; + +public class ProcessorTopology { + private final Logger log = LoggerFactory.getLogger(ProcessorTopology.class); + + private final List> processorNodes; + private final Map> sourceNodesByName; + private final Map> sourceNodesByTopic; + private final Map> sinksByTopic; + private final Set terminalNodes; + private final List stateStores; + private final Set repartitionTopics; + + // the following contains entries for the entire topology, eg stores that do not belong to this ProcessorTopology + private final List globalStateStores; + private final Map storeToChangelogTopic; + + public ProcessorTopology(final List> processorNodes, + final Map> sourceNodesByTopic, + final Map> sinksByTopic, + final List stateStores, + final List globalStateStores, + final Map storeToChangelogTopic, + final Set repartitionTopics) { + this.processorNodes = Collections.unmodifiableList(processorNodes); + this.sourceNodesByTopic = new HashMap<>(sourceNodesByTopic); + this.sinksByTopic = Collections.unmodifiableMap(sinksByTopic); + this.stateStores = Collections.unmodifiableList(stateStores); + this.globalStateStores = Collections.unmodifiableList(globalStateStores); + this.storeToChangelogTopic = Collections.unmodifiableMap(storeToChangelogTopic); + this.repartitionTopics = Collections.unmodifiableSet(repartitionTopics); + + this.terminalNodes = new HashSet<>(); + for (final ProcessorNode node : processorNodes) { + if (node.isTerminalNode()) { + terminalNodes.add(node.name()); + } + } + + this.sourceNodesByName = new HashMap<>(); + for (final SourceNode source : sourceNodesByTopic.values()) { + sourceNodesByName.put(source.name(), source); + } + } + + public Set sourceTopics() { + return sourceNodesByTopic.keySet(); + 
} + + public SourceNode source(final String topic) { + return sourceNodesByTopic.get(topic); + } + + public Set> sources() { + return new HashSet<>(sourceNodesByTopic.values()); + } + + public Set sinkTopics() { + return sinksByTopic.keySet(); + } + + public SinkNode sink(final String topic) { + return sinksByTopic.get(topic); + } + + public Set terminalNodes() { + return terminalNodes; + } + + public List> processors() { + return processorNodes; + } + + public List stateStores() { + return stateStores; + } + + public List globalStateStores() { + return Collections.unmodifiableList(globalStateStores); + } + + public Map storeToChangelogTopic() { + return Collections.unmodifiableMap(storeToChangelogTopic); + } + + boolean isRepartitionTopic(final String topic) { + return repartitionTopics.contains(topic); + } + + boolean hasStateWithChangelogs() { + for (final StateStore stateStore : stateStores) { + if (storeToChangelogTopic.containsKey(stateStore.name())) { + return true; + } + } + return false; + } + + public boolean hasPersistentLocalStore() { + for (final StateStore store : stateStores) { + if (store.persistent()) { + return true; + } + } + return false; + } + + public boolean hasPersistentGlobalStore() { + for (final StateStore store : globalStateStores) { + if (store.persistent()) { + return true; + } + } + return false; + } + + public void updateSourceTopics(final Map> allSourceTopicsByNodeName) { + sourceNodesByTopic.clear(); + for (final Map.Entry> sourceNodeEntry : sourceNodesByName.entrySet()) { + final String sourceNodeName = sourceNodeEntry.getKey(); + final SourceNode sourceNode = sourceNodeEntry.getValue(); + + final List updatedSourceTopics = allSourceTopicsByNodeName.get(sourceNodeName); + if (updatedSourceTopics == null) { + log.error("Unable to find source node {} in updated topics map {}", + sourceNodeName, allSourceTopicsByNodeName); + throw new IllegalStateException("Node " + sourceNodeName + " not found in full topology"); + } + + log.trace("Updating source node {} with new topics {}", sourceNodeName, updatedSourceTopics); + for (final String topic : updatedSourceTopics) { + if (sourceNodesByTopic.containsKey(topic)) { + log.error("Tried to subscribe topic {} to two nodes when updating topics from {}", + topic, allSourceTopicsByNodeName); + throw new IllegalStateException("Topic " + topic + " was already registered to source node " + + sourceNodesByTopic.get(topic).name()); + } + sourceNodesByTopic.put(topic, sourceNode); + } + } + } + + private String childrenToString(final String indent, final List> children) { + if (children == null || children.isEmpty()) { + return ""; + } + + final StringBuilder sb = new StringBuilder(indent + "\tchildren:\t["); + for (final ProcessorNode child : children) { + sb.append(child.name()); + sb.append(", "); + } + sb.setLength(sb.length() - 2); // remove the last comma + sb.append("]\n"); + + // recursively print children + for (final ProcessorNode child : children) { + sb.append(child.toString(indent)).append(childrenToString(indent, child.children())); + } + return sb.toString(); + } + + /** + * Produces a string representation containing useful information this topology starting with the given indent. + * This is useful in debugging scenarios. + * @return A string representation of this instance. + */ + @Override + public String toString() { + return toString(""); + } + + /** + * Produces a string representation containing useful information this topology. + * This is useful in debugging scenarios. 
+ * @return A string representation of this instance. + */ + public String toString(final String indent) { + final Map, List> sourceToTopics = new HashMap<>(); + for (final Map.Entry> sourceNodeEntry : sourceNodesByTopic.entrySet()) { + final String topic = sourceNodeEntry.getKey(); + final SourceNode source = sourceNodeEntry.getValue(); + sourceToTopics.computeIfAbsent(source, s -> new ArrayList<>()); + sourceToTopics.get(source).add(topic); + } + + final StringBuilder sb = new StringBuilder(indent + "ProcessorTopology:\n"); + + // start from sources + for (final Map.Entry, List> sourceNodeEntry : sourceToTopics.entrySet()) { + final SourceNode source = sourceNodeEntry.getKey(); + final List topics = sourceNodeEntry.getValue(); + sb.append(source.toString(indent + "\t")) + .append(topicsToString(indent + "\t", topics)) + .append(childrenToString(indent + "\t", source.children())); + } + return sb.toString(); + } + + private static String topicsToString(final String indent, final List topics) { + final StringBuilder sb = new StringBuilder(); + sb.append(indent).append("\ttopics:\t\t["); + for (final String topic : topics) { + sb.append(topic); + sb.append(", "); + } + sb.setLength(sb.length() - 2); // remove the last comma + sb.append("]\n"); + return sb.toString(); + } + + // for testing only + public Set processorConnectedStateStores(final String processorName) { + for (final ProcessorNode node : processorNodes) { + if (node.name().equals(processorName)) { + return node.stateStores; + } + } + + return Collections.emptySet(); + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/PunctuationQueue.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/PunctuationQueue.java new file mode 100644 index 0000000..eca9ae5 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/PunctuationQueue.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.processor.internals; + +import org.apache.kafka.streams.errors.TaskMigratedException; +import org.apache.kafka.streams.processor.Cancellable; +import org.apache.kafka.streams.processor.PunctuationType; + +import java.util.PriorityQueue; + +public class PunctuationQueue { + + private final PriorityQueue pq = new PriorityQueue<>(); + + public Cancellable schedule(final PunctuationSchedule sched) { + synchronized (pq) { + pq.add(sched); + } + return sched.cancellable(); + } + + public void close() { + synchronized (pq) { + pq.clear(); + } + } + + /** + * @throws TaskMigratedException if the task producer got fenced (EOS only) + */ + boolean mayPunctuate(final long timestamp, final PunctuationType type, final ProcessorNodePunctuator processorNodePunctuator) { + synchronized (pq) { + boolean punctuated = false; + PunctuationSchedule top = pq.peek(); + while (top != null && top.timestamp <= timestamp) { + final PunctuationSchedule sched = top; + pq.poll(); + + if (!sched.isCancelled()) { + processorNodePunctuator.punctuate(sched.node(), timestamp, type, sched.punctuator()); + // sched can be cancelled from within the punctuator + if (!sched.isCancelled()) { + pq.add(sched.next(timestamp)); + } + punctuated = true; + } + + + top = pq.peek(); + } + + return punctuated; + } + } + +} diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/PunctuationSchedule.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/PunctuationSchedule.java new file mode 100644 index 0000000..bd8a150 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/PunctuationSchedule.java @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
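The mayPunctuate() loop above pops every schedule whose timestamp has come due, fires it, and re-enqueues its successor; the successor skips any intervals that were missed, as PunctuationSchedule.next() below does. A minimal standalone sketch of that drain-and-reschedule pattern, using plain longs instead of the internal PunctuationSchedule/ProcessorNodePunctuator types (all names here are illustrative only):

import java.util.PriorityQueue;

public class PunctuationDrainSketch {

    static final class Sched implements Comparable<Sched> {
        final long timestamp;
        final long interval;

        Sched(final long timestamp, final long interval) {
            this.timestamp = timestamp;
            this.interval = interval;
        }

        @Override
        public int compareTo(final Sched other) {
            return Long.compare(timestamp, other.timestamp);
        }
    }

    public static void main(final String[] args) {
        final PriorityQueue<Sched> pq = new PriorityQueue<>();
        pq.add(new Sched(100L, 30L));

        final long now = 175L;
        Sched top = pq.peek();
        while (top != null && top.timestamp <= now) {
            pq.poll();
            System.out.println("punctuate at " + top.timestamp);

            // catch-up arithmetic: skip intervals that were missed while punctuation was delayed
            final long intervalsMissed = (now - top.timestamp) / top.interval;
            pq.add(new Sched(top.timestamp + (intervalsMissed + 1) * top.interval, top.interval));

            top = pq.peek();
        }
        // prints "punctuate at 100"; the successor is scheduled for 190 rather than 130
    }
}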
+ */ +package org.apache.kafka.streams.processor.internals; + +import org.apache.kafka.streams.processor.Cancellable; +import org.apache.kafka.streams.processor.Punctuator; + +public class PunctuationSchedule extends Stamped { + + private final long interval; + private final Punctuator punctuator; + private boolean isCancelled = false; + // this Cancellable will be re-pointed at the successor schedule in next() + private final RepointableCancellable cancellable; + + PunctuationSchedule(final ProcessorNode node, + final long time, + final long interval, + final Punctuator punctuator) { + this(node, time, interval, punctuator, new RepointableCancellable()); + cancellable.setSchedule(this); + } + + private PunctuationSchedule(final ProcessorNode node, + final long time, + final long interval, + final Punctuator punctuator, + final RepointableCancellable cancellable) { + super(node, time); + this.interval = interval; + this.punctuator = punctuator; + this.cancellable = cancellable; + } + + public ProcessorNode node() { + return value; + } + + public Punctuator punctuator() { + return punctuator; + } + + public Cancellable cancellable() { + return cancellable; + } + + void markCancelled() { + isCancelled = true; + } + + boolean isCancelled() { + return isCancelled; + } + + public PunctuationSchedule next(final long currTimestamp) { + long nextPunctuationTime = timestamp + interval; + if (currTimestamp >= nextPunctuationTime) { + // we missed one ore more punctuations + // avoid scheduling a new punctuations immediately, this can happen: + // - when using STREAM_TIME punctuation and there was a gap i.e., no data was + // received for at least 2*interval + // - when using WALL_CLOCK_TIME and there was a gap i.e., punctuation was delayed for at least 2*interval (GC pause, overload, ...) + final long intervalsMissed = (currTimestamp - timestamp) / interval; + nextPunctuationTime = timestamp + (intervalsMissed + 1) * interval; + } + + final PunctuationSchedule nextSchedule = new PunctuationSchedule(value, nextPunctuationTime, interval, punctuator, cancellable); + + cancellable.setSchedule(nextSchedule); + + return nextSchedule; + } + + @Override + public boolean equals(final Object other) { + return super.equals(other); + } + + @Override + public int hashCode() { + return super.hashCode(); + } + + private static class RepointableCancellable implements Cancellable { + private PunctuationSchedule schedule; + + synchronized void setSchedule(final PunctuationSchedule schedule) { + this.schedule = schedule; + } + + @Override + synchronized public void cancel() { + schedule.markCancelled(); + } + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/QuickUnion.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/QuickUnion.java new file mode 100644 index 0000000..f91bdcc --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/QuickUnion.java @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.processor.internals; + +import java.util.HashMap; +import java.util.NoSuchElementException; + +public class QuickUnion { + + private final HashMap ids = new HashMap<>(); + + public void add(final T id) { + ids.put(id, id); + } + + public boolean exists(final T id) { + return ids.containsKey(id); + } + + /** + * @throws NoSuchElementException if the parent of this node is null + */ + public T root(final T id) { + T current = id; + T parent = ids.get(current); + + if (parent == null) { + throw new NoSuchElementException("id: " + id.toString()); + } + + while (!parent.equals(current)) { + // do the path splitting + final T grandparent = ids.get(parent); + ids.put(current, grandparent); + + current = parent; + parent = grandparent; + } + return current; + } + + @SuppressWarnings("unchecked") + void unite(final T id1, final T... idList) { + for (final T id2 : idList) { + unitePair(id1, id2); + } + } + + private void unitePair(final T id1, final T id2) { + final T root1 = root(id1); + final T root2 = root(id2); + + if (!root1.equals(root2)) { + ids.put(root1, root2); + } + } + +} diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/RecordBatchingStateRestoreCallback.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/RecordBatchingStateRestoreCallback.java new file mode 100644 index 0000000..300b60d --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/RecordBatchingStateRestoreCallback.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
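QuickUnion above is a small union-find with path splitting: two ids belong to the same group exactly when their roots are equal. A hedged usage sketch, assuming same-package access (unite() is package-private) and made-up topic names:

// illustrative snippet, e.g. inside a test in the same package
final QuickUnion<String> grouping = new QuickUnion<>();
grouping.add("topic-A");
grouping.add("topic-B");
grouping.add("topic-C");

grouping.unite("topic-A", "topic-B");

final boolean sameGroup = grouping.root("topic-A").equals(grouping.root("topic-B")); // true
final boolean separate = grouping.root("topic-A").equals(grouping.root("topic-C"));  // false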
+ */ +package org.apache.kafka.streams.processor.internals; + +import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.processor.BatchingStateRestoreCallback; + +import java.util.Collection; + +public interface RecordBatchingStateRestoreCallback extends BatchingStateRestoreCallback { + void restoreBatch(final Collection> records); + + @Override + default void restoreAll(final Collection> records) { + throw new UnsupportedOperationException(); + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/RecordCollector.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/RecordCollector.java new file mode 100644 index 0000000..8b22f22 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/RecordCollector.java @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.processor.internals; + +import org.apache.kafka.clients.producer.Producer; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.header.Headers; +import org.apache.kafka.common.serialization.Serializer; +import org.apache.kafka.streams.processor.StreamPartitioner; + +import java.util.Map; + +public interface RecordCollector { + + void send(final String topic, + final K key, + final V value, + final Headers headers, + final Integer partition, + final Long timestamp, + final Serializer keySerializer, + final Serializer valueSerializer); + + void send(final String topic, + final K key, + final V value, + final Headers headers, + final Long timestamp, + final Serializer keySerializer, + final Serializer valueSerializer, + final StreamPartitioner partitioner); + + /** + * Initialize the internal {@link Producer}; note this function should be made idempotent + * + * @throws org.apache.kafka.common.errors.TimeoutException if producer initializing txn id timed out + */ + void initialize(); + + /** + * Flush the internal {@link Producer}. + */ + void flush(); + + /** + * Clean close the internal {@link Producer}. + */ + void closeClean(); + + /** + * Dirty close the internal {@link Producer}. + */ + void closeDirty(); + + /** + * The last acked offsets from the internal {@link Producer}. + * + * @return an immutable map from TopicPartition to offset + */ + Map offsets(); + + /** + * A supplier of a {@link RecordCollectorImpl} instance. + */ + // TODO: after we have done KAFKA-9088 we should just add this function + // to InternalProcessorContext interface + interface Supplier { + /** + * Get the record collector. 
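The RecordBatchingStateRestoreCallback above deliberately disables the KeyValue-based restoreAll() and leaves restoreBatch() as the single abstract method, so restore code can see the full ConsumerRecord (headers, timestamp, offset). A minimal sketch of implementing it with a lambda; the logic shown is illustrative only:

import org.apache.kafka.clients.consumer.ConsumerRecord;

final RecordBatchingStateRestoreCallback callback = records -> {
    long maxOffset = -1L;
    for (final ConsumerRecord<byte[], byte[]> record : records) {
        maxOffset = Math.max(maxOffset, record.offset());
        // record.headers() and record.timestamp() are also available here,
        // which the KeyValue-based restoreAll() cannot expose
    }
    System.out.println("restored a batch up to changelog offset " + maxOffset);
};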
+ * @return the record collector + */ + RecordCollector recordCollector(); + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/RecordCollectorImpl.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/RecordCollectorImpl.java new file mode 100644 index 0000000..f8c9cf9 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/RecordCollectorImpl.java @@ -0,0 +1,311 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.processor.internals; + +import org.apache.kafka.clients.producer.Producer; +import org.apache.kafka.clients.producer.ProducerRecord; +import org.apache.kafka.common.KafkaException; +import org.apache.kafka.common.PartitionInfo; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.errors.AuthenticationException; +import org.apache.kafka.common.errors.AuthorizationException; +import org.apache.kafka.common.errors.InvalidProducerEpochException; +import org.apache.kafka.common.errors.InvalidTopicException; +import org.apache.kafka.common.errors.OffsetMetadataTooLarge; +import org.apache.kafka.common.errors.OutOfOrderSequenceException; +import org.apache.kafka.common.errors.ProducerFencedException; +import org.apache.kafka.common.errors.RetriableException; +import org.apache.kafka.common.errors.SecurityDisabledException; +import org.apache.kafka.common.errors.SerializationException; +import org.apache.kafka.common.errors.TimeoutException; +import org.apache.kafka.common.errors.UnknownServerException; +import org.apache.kafka.common.header.Headers; +import org.apache.kafka.common.metrics.Sensor; +import org.apache.kafka.common.serialization.Serializer; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.streams.errors.ProductionExceptionHandler; +import org.apache.kafka.streams.errors.ProductionExceptionHandler.ProductionExceptionHandlerResponse; +import org.apache.kafka.streams.errors.StreamsException; +import org.apache.kafka.streams.errors.TaskCorruptedException; +import org.apache.kafka.streams.errors.TaskMigratedException; +import org.apache.kafka.streams.processor.StreamPartitioner; +import org.apache.kafka.streams.processor.TaskId; +import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl; +import org.apache.kafka.streams.processor.internals.metrics.TaskMetrics; +import org.slf4j.Logger; + +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.atomic.AtomicReference; + +public class RecordCollectorImpl implements RecordCollector { + private final static String SEND_EXCEPTION_MESSAGE = "Error encountered sending record to topic %s for task %s due to:%n%s"; + + 
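The RecordCollector interface above decouples processors from the underlying Producer: each send() takes explicit serializers and either a fixed partition or a StreamPartitioner. A hedged sketch of a sink-style send; the topic name, the payload, and the way the collector is obtained (from something exposing RecordCollector.Supplier) are all assumptions for illustration:

import org.apache.kafka.common.serialization.StringSerializer;

// hypothetical wiring: 'supplier' is assumed to implement RecordCollector.Supplier
final RecordCollector collector = supplier.recordCollector();
collector.send(
    "output-topic",               // hypothetical topic
    "some-key",
    "some-value",
    null,                         // no headers
    null,                         // no fixed partition: producer-side partitioning applies
    System.currentTimeMillis(),   // record timestamp
    new StringSerializer(),
    new StringSerializer());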
private final Logger log; + private final TaskId taskId; + private final StreamsProducer streamsProducer; + private final ProductionExceptionHandler productionExceptionHandler; + private final Sensor droppedRecordsSensor; + private final boolean eosEnabled; + private final Map offsets; + + private final AtomicReference sendException = new AtomicReference<>(null); + + /** + * @throws StreamsException fatal error that should cause the thread to die (from producer.initTxn) + */ + public RecordCollectorImpl(final LogContext logContext, + final TaskId taskId, + final StreamsProducer streamsProducer, + final ProductionExceptionHandler productionExceptionHandler, + final StreamsMetricsImpl streamsMetrics) { + this.log = logContext.logger(getClass()); + this.taskId = taskId; + this.streamsProducer = streamsProducer; + this.productionExceptionHandler = productionExceptionHandler; + this.eosEnabled = streamsProducer.eosEnabled(); + + final String threadId = Thread.currentThread().getName(); + this.droppedRecordsSensor = TaskMetrics.droppedRecordsSensor(threadId, taskId.toString(), streamsMetrics); + + this.offsets = new HashMap<>(); + } + + @Override + public void initialize() { + if (eosEnabled) { + streamsProducer.initTransaction(); + } + } + + /** + * @throws StreamsException fatal error that should cause the thread to die + * @throws TaskMigratedException recoverable error that would cause the task to be removed + */ + @Override + public void send(final String topic, + final K key, + final V value, + final Headers headers, + final Long timestamp, + final Serializer keySerializer, + final Serializer valueSerializer, + final StreamPartitioner partitioner) { + final Integer partition; + + if (partitioner != null) { + final List partitions; + try { + partitions = streamsProducer.partitionsFor(topic); + } catch (final TimeoutException timeoutException) { + log.warn("Could not get partitions for topic {}, will retry", topic); + + // re-throw to trigger `task.timeout.ms` + throw timeoutException; + } catch (final KafkaException fatal) { + // here we cannot drop the message on the floor even if it is a transient timeout exception, + // so we treat everything the same as a fatal exception + throw new StreamsException("Could not determine the number of partitions for topic '" + topic + + "' for task " + taskId + " due to " + fatal.toString(), + fatal + ); + } + if (partitions.size() > 0) { + partition = partitioner.partition(topic, key, value, partitions.size()); + } else { + throw new StreamsException("Could not get partition information for topic " + topic + " for task " + taskId + + ". This can happen if the topic does not exist."); + } + } else { + partition = null; + } + + send(topic, key, value, headers, partition, timestamp, keySerializer, valueSerializer); + } + + @Override + public void send(final String topic, + final K key, + final V value, + final Headers headers, + final Integer partition, + final Long timestamp, + final Serializer keySerializer, + final Serializer valueSerializer) { + checkForException(); + + final byte[] keyBytes; + final byte[] valBytes; + try { + keyBytes = keySerializer.serialize(topic, headers, key); + valBytes = valueSerializer.serialize(topic, headers, value); + } catch (final ClassCastException exception) { + final String keyClass = key == null ? "unknown because key is null" : key.getClass().getName(); + final String valueClass = value == null ? 
"unknown because value is null" : value.getClass().getName(); + throw new StreamsException( + String.format( + "ClassCastException while producing data to topic %s. " + + "A serializer (key: %s / value: %s) is not compatible to the actual key or value type " + + "(key type: %s / value type: %s). " + + "Change the default Serdes in StreamConfig or provide correct Serdes via method parameters " + + "(for example if using the DSL, `#to(String topic, Produced produced)` with " + + "`Produced.keySerde(WindowedSerdes.timeWindowedSerdeFrom(String.class))`).", + topic, + keySerializer.getClass().getName(), + valueSerializer.getClass().getName(), + keyClass, + valueClass), + exception); + } catch (final RuntimeException exception) { + final String errorMessage = String.format(SEND_EXCEPTION_MESSAGE, topic, taskId, exception.toString()); + throw new StreamsException(errorMessage, exception); + } + + final ProducerRecord serializedRecord = new ProducerRecord<>(topic, partition, timestamp, keyBytes, valBytes, headers); + + streamsProducer.send(serializedRecord, (metadata, exception) -> { + // if there's already an exception record, skip logging offsets or new exceptions + if (sendException.get() != null) { + return; + } + + if (exception == null) { + final TopicPartition tp = new TopicPartition(metadata.topic(), metadata.partition()); + if (metadata.offset() >= 0L) { + offsets.put(tp, metadata.offset()); + } else { + log.warn("Received offset={} in produce response for {}", metadata.offset(), tp); + } + } else { + recordSendError(topic, exception, serializedRecord); + + // KAFKA-7510 only put message key and value in TRACE level log so we don't leak data by default + log.trace("Failed record: (key {} value {} timestamp {}) topic=[{}] partition=[{}]", key, value, timestamp, topic, partition); + } + }); + } + + private void recordSendError(final String topic, final Exception exception, final ProducerRecord serializedRecord) { + String errorMessage = String.format(SEND_EXCEPTION_MESSAGE, topic, taskId, exception.toString()); + + if (isFatalException(exception)) { + errorMessage += "\nWritten offsets would not be recorded and no more records would be sent since this is a fatal error."; + sendException.set(new StreamsException(errorMessage, exception)); + } else if (exception instanceof ProducerFencedException || + exception instanceof InvalidProducerEpochException || + exception instanceof OutOfOrderSequenceException) { + errorMessage += "\nWritten offsets would not be recorded and no more records would be sent since the producer is fenced, " + + "indicating the task may be migrated out"; + sendException.set(new TaskMigratedException(errorMessage, exception)); + } else { + if (exception instanceof RetriableException) { + errorMessage += "\nThe broker is either slow or in bad state (like not having enough replicas) in responding the request, " + + "or the connection to broker was interrupted sending the request or receiving the response. 
" + + "\nConsider overwriting `max.block.ms` and /or " + + "`delivery.timeout.ms` to a larger value to wait longer for such scenarios and avoid timeout errors"; + sendException.set(new TaskCorruptedException(Collections.singleton(taskId))); + } else { + if (productionExceptionHandler.handle(serializedRecord, exception) == ProductionExceptionHandlerResponse.FAIL) { + errorMessage += "\nException handler choose to FAIL the processing, no more records would be sent."; + sendException.set(new StreamsException(errorMessage, exception)); + } else { + errorMessage += "\nException handler choose to CONTINUE processing in spite of this error but written offsets would not be recorded."; + droppedRecordsSensor.record(); + } + } + } + + log.error(errorMessage, exception); + } + + private boolean isFatalException(final Exception exception) { + final boolean securityException = exception instanceof AuthenticationException || + exception instanceof AuthorizationException || + exception instanceof SecurityDisabledException; + + final boolean communicationException = exception instanceof InvalidTopicException || + exception instanceof UnknownServerException || + exception instanceof SerializationException || + exception instanceof OffsetMetadataTooLarge || + exception instanceof IllegalStateException; + + return securityException || communicationException; + } + + /** + * @throws StreamsException fatal error that should cause the thread to die + * @throws TaskMigratedException recoverable error that would cause the task to be removed + */ + @Override + public void flush() { + log.debug("Flushing record collector"); + streamsProducer.flush(); + checkForException(); + } + + /** + * @throws StreamsException fatal error that should cause the thread to die + * @throws TaskMigratedException recoverable error that would cause the task to be removed + */ + @Override + public void closeClean() { + log.info("Closing record collector clean"); + + // No need to abort transaction during a clean close: either we have successfully committed the ongoing + // transaction during handleRevocation and thus there is no transaction in flight, or else none of the revoked + // tasks had any data in the current transaction and therefore there is no need to commit or abort it. 
+ + checkForException(); + } + + /** + * @throws StreamsException fatal error that should cause the thread to die + * @throws TaskMigratedException recoverable error that would cause the task to be removed + */ + @Override + public void closeDirty() { + log.info("Closing record collector dirty"); + + if (eosEnabled) { + // We may be closing dirty because the commit failed, so we must abort the transaction to be safe + streamsProducer.abortTransaction(); + } + + checkForException(); + } + + @Override + public Map offsets() { + return Collections.unmodifiableMap(new HashMap<>(offsets)); + } + + private void checkForException() { + final KafkaException exception = sendException.get(); + + if (exception != null) { + sendException.set(null); + throw exception; + } + } + + // for testing only + Producer producer() { + return streamsProducer.kafkaProducer(); + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/RecordDeserializer.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/RecordDeserializer.java new file mode 100644 index 0000000..b5c821a --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/RecordDeserializer.java @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
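The recordSendError() path above only consults the configured ProductionExceptionHandler for errors that are neither fatal nor fencing/retriable; a CONTINUE response drops the record and bumps the dropped-records sensor instead of killing the thread. A sketch of a handler that keeps going on oversized records (class name and policy are illustrative, not part of this patch):

import org.apache.kafka.clients.producer.ProducerRecord;
import org.apache.kafka.common.errors.RecordTooLargeException;
import org.apache.kafka.streams.errors.ProductionExceptionHandler;

import java.util.Map;

public class ContinueOnTooLargeHandler implements ProductionExceptionHandler {

    @Override
    public ProductionExceptionHandlerResponse handle(final ProducerRecord<byte[], byte[]> record,
                                                     final Exception exception) {
        // drop oversized records, fail on everything else
        return exception instanceof RecordTooLargeException
            ? ProductionExceptionHandlerResponse.CONTINUE
            : ProductionExceptionHandlerResponse.FAIL;
    }

    @Override
    public void configure(final Map<String, ?> configs) { }
}

Such a handler would be registered through StreamsConfig.DEFAULT_PRODUCTION_EXCEPTION_HANDLER_CLASS_CONFIG.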
+ */ +package org.apache.kafka.streams.processor.internals; + +import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.common.metrics.Sensor; +import org.apache.kafka.common.record.TimestampType; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.streams.errors.DeserializationExceptionHandler; +import org.apache.kafka.streams.errors.StreamsException; +import org.apache.kafka.streams.processor.ProcessorContext; +import org.slf4j.Logger; + +import java.util.Optional; + +import static org.apache.kafka.streams.StreamsConfig.DEFAULT_DESERIALIZATION_EXCEPTION_HANDLER_CLASS_CONFIG; + +class RecordDeserializer { + private final Logger log; + private final SourceNode sourceNode; + private final Sensor droppedRecordsSensor; + private final DeserializationExceptionHandler deserializationExceptionHandler; + + RecordDeserializer(final SourceNode sourceNode, + final DeserializationExceptionHandler deserializationExceptionHandler, + final LogContext logContext, + final Sensor droppedRecordsSensor) { + this.sourceNode = sourceNode; + this.deserializationExceptionHandler = deserializationExceptionHandler; + this.log = logContext.logger(RecordDeserializer.class); + this.droppedRecordsSensor = droppedRecordsSensor; + } + + /** + * @throws StreamsException if a deserialization error occurs and the deserialization callback returns + * {@link DeserializationExceptionHandler.DeserializationHandlerResponse#FAIL FAIL} + * or throws an exception itself + */ + ConsumerRecord deserialize(final ProcessorContext processorContext, + final ConsumerRecord rawRecord) { + + try { + return new ConsumerRecord<>( + rawRecord.topic(), + rawRecord.partition(), + rawRecord.offset(), + rawRecord.timestamp(), + TimestampType.CREATE_TIME, + rawRecord.serializedKeySize(), + rawRecord.serializedValueSize(), + sourceNode.deserializeKey(rawRecord.topic(), rawRecord.headers(), rawRecord.key()), + sourceNode.deserializeValue(rawRecord.topic(), rawRecord.headers(), rawRecord.value()), + rawRecord.headers(), + Optional.empty() + ); + } catch (final Exception deserializationException) { + final DeserializationExceptionHandler.DeserializationHandlerResponse response; + try { + response = deserializationExceptionHandler.handle(processorContext, rawRecord, deserializationException); + } catch (final Exception fatalUserException) { + log.error( + "Deserialization error callback failed after deserialization error for record {}", + rawRecord, + deserializationException); + throw new StreamsException("Fatal user code error in deserialization error callback", fatalUserException); + } + + if (response == DeserializationExceptionHandler.DeserializationHandlerResponse.FAIL) { + throw new StreamsException("Deserialization exception handler is set to fail upon" + + " a deserialization error. If you would rather have the streaming pipeline" + + " continue after a deserialization error, please set the " + + DEFAULT_DESERIALIZATION_EXCEPTION_HANDLER_CLASS_CONFIG + " appropriately.", + deserializationException); + } else { + log.warn( + "Skipping record due to deserialization error. 
topic=[{}] partition=[{}] offset=[{}]", + rawRecord.topic(), + rawRecord.partition(), + rawRecord.offset(), + deserializationException + ); + droppedRecordsSensor.record(); + return null; + } + } + } + + SourceNode sourceNode() { + return sourceNode; + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/RecordQueue.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/RecordQueue.java new file mode 100644 index 0000000..418a7b0 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/RecordQueue.java @@ -0,0 +1,214 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.processor.internals; + +import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.metrics.Sensor; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.streams.errors.DeserializationExceptionHandler; +import org.apache.kafka.streams.errors.StreamsException; +import org.apache.kafka.streams.processor.internals.metrics.TaskMetrics; +import org.apache.kafka.streams.processor.ProcessorContext; +import org.apache.kafka.streams.processor.TimestampExtractor; +import org.slf4j.Logger; + +import java.util.ArrayDeque; + +/** + * RecordQueue is a FIFO queue of {@link StampedRecord} (ConsumerRecord + timestamp). It also keeps track of the + * partition timestamp defined as the largest timestamp seen on the partition so far; this is passed to the + * timestamp extractor. 
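The deserialize() method above routes any deserialization failure to the configured DeserializationExceptionHandler: CONTINUE skips the record and records it in the dropped-records sensor, while FAIL surfaces a StreamsException. A small sketch of opting into the skip-and-log behaviour via the built-in handler (property wiring shown for illustration):

import java.util.Properties;

import org.apache.kafka.streams.StreamsConfig;
import org.apache.kafka.streams.errors.LogAndContinueExceptionHandler;

final Properties props = new Properties();
props.put(
    StreamsConfig.DEFAULT_DESERIALIZATION_EXCEPTION_HANDLER_CLASS_CONFIG,
    LogAndContinueExceptionHandler.class);
// the default is LogAndFailExceptionHandler, which corresponds to the FAIL branch above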
+ */ +public class RecordQueue { + + public static final long UNKNOWN = ConsumerRecord.NO_TIMESTAMP; + + private final Logger log; + private final SourceNode source; + private final TopicPartition partition; + private final ProcessorContext processorContext; + private final TimestampExtractor timestampExtractor; + private final RecordDeserializer recordDeserializer; + private final ArrayDeque> fifoQueue; + + private StampedRecord headRecord = null; + private long partitionTime = UNKNOWN; + + private final Sensor droppedRecordsSensor; + + RecordQueue(final TopicPartition partition, + final SourceNode source, + final TimestampExtractor timestampExtractor, + final DeserializationExceptionHandler deserializationExceptionHandler, + final InternalProcessorContext processorContext, + final LogContext logContext) { + this.source = source; + this.partition = partition; + this.fifoQueue = new ArrayDeque<>(); + this.timestampExtractor = timestampExtractor; + this.processorContext = processorContext; + droppedRecordsSensor = TaskMetrics.droppedRecordsSensor( + Thread.currentThread().getName(), + processorContext.taskId().toString(), + processorContext.metrics() + ); + recordDeserializer = new RecordDeserializer( + source, + deserializationExceptionHandler, + logContext, + droppedRecordsSensor + ); + this.log = logContext.logger(RecordQueue.class); + } + + void setPartitionTime(final long partitionTime) { + this.partitionTime = partitionTime; + } + + /** + * Returns the corresponding source node in the topology + * + * @return SourceNode + */ + public SourceNode source() { + return source; + } + + /** + * Returns the partition with which this queue is associated + * + * @return TopicPartition + */ + public TopicPartition partition() { + return partition; + } + + /** + * Add a batch of {@link ConsumerRecord} into the queue + * + * @param rawRecords the raw records + * @return the size of this queue + */ + int addRawRecords(final Iterable> rawRecords) { + for (final ConsumerRecord rawRecord : rawRecords) { + fifoQueue.addLast(rawRecord); + } + + updateHead(); + + return size(); + } + + /** + * Get the next {@link StampedRecord} from the queue + * + * @return StampedRecord + */ + public StampedRecord poll() { + final StampedRecord recordToReturn = headRecord; + headRecord = null; + partitionTime = Math.max(partitionTime, recordToReturn.timestamp); + + updateHead(); + + return recordToReturn; + } + + /** + * Returns the number of records in the queue + * + * @return the number of records + */ + public int size() { + // plus one deserialized head record for timestamp tracking + return fifoQueue.size() + (headRecord == null ? 0 : 1); + } + + /** + * Tests if the queue is empty + * + * @return true if the queue is empty, otherwise false + */ + public boolean isEmpty() { + return fifoQueue.isEmpty() && headRecord == null; + } + + /** + * Returns the head record's timestamp + * + * @return timestamp + */ + public long headRecordTimestamp() { + return headRecord == null ? UNKNOWN : headRecord.timestamp; + } + + public Long headRecordOffset() { + return headRecord == null ? 
null : headRecord.offset(); + } + + /** + * Clear the fifo queue of its elements + */ + public void clear() { + fifoQueue.clear(); + headRecord = null; + partitionTime = UNKNOWN; + } + + private void updateHead() { + while (headRecord == null && !fifoQueue.isEmpty()) { + final ConsumerRecord raw = fifoQueue.pollFirst(); + final ConsumerRecord deserialized = recordDeserializer.deserialize(processorContext, raw); + + if (deserialized == null) { + // this only happens if the deserializer decides to skip. It has already logged the reason. + continue; + } + + final long timestamp; + try { + timestamp = timestampExtractor.extract(deserialized, partitionTime); + } catch (final StreamsException internalFatalExtractorException) { + throw internalFatalExtractorException; + } catch (final Exception fatalUserException) { + throw new StreamsException( + String.format("Fatal user code error in TimestampExtractor callback for record %s.", deserialized), + fatalUserException); + } + log.trace("Source node {} extracted timestamp {} for record {}", source.name(), timestamp, deserialized); + + // drop message if TS is invalid, i.e., negative + if (timestamp < 0) { + log.warn( + "Skipping record due to negative extracted timestamp. topic=[{}] partition=[{}] offset=[{}] extractedTimestamp=[{}] extractor=[{}]", + deserialized.topic(), deserialized.partition(), deserialized.offset(), timestamp, timestampExtractor.getClass().getCanonicalName() + ); + droppedRecordsSensor.record(); + continue; + } + headRecord = new StampedRecord(deserialized, timestamp); + } + } + + /** + * @return the local partitionTime for this particular RecordQueue + */ + long partitionTime() { + return partitionTime; + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/RepartitionTopicConfig.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/RepartitionTopicConfig.java new file mode 100644 index 0000000..098bb9f --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/RepartitionTopicConfig.java @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.processor.internals; + +import org.apache.kafka.common.config.TopicConfig; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; + +/** + * RepartitionTopicConfig captures the properties required for configuring + * the repartition topics. 
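updateHead() above calls the configured TimestampExtractor with the queue's current partitionTime and silently drops any record whose extracted timestamp is negative. A hedged sketch of an extractor that falls back to the previous partition time when a record carries no usable timestamp (the class name and fallback policy are illustrative):

import org.apache.kafka.clients.consumer.ConsumerRecord;
import org.apache.kafka.streams.processor.TimestampExtractor;

public class FallbackTimestampExtractor implements TimestampExtractor {

    @Override
    public long extract(final ConsumerRecord<Object, Object> record, final long partitionTime) {
        final long timestamp = record.timestamp();
        // returning a negative value here would make RecordQueue drop the record,
        // so fall back to the highest timestamp seen so far on this partition
        return timestamp >= 0 ? timestamp : partitionTime;
    }
}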
+ */ +public class RepartitionTopicConfig extends InternalTopicConfig { + + private static final Map REPARTITION_TOPIC_DEFAULT_OVERRIDES; + static { + final Map tempTopicDefaultOverrides = new HashMap<>(INTERNAL_TOPIC_DEFAULT_OVERRIDES); + tempTopicDefaultOverrides.put(TopicConfig.CLEANUP_POLICY_CONFIG, TopicConfig.CLEANUP_POLICY_DELETE); + tempTopicDefaultOverrides.put(TopicConfig.SEGMENT_BYTES_CONFIG, "52428800"); // 50 MB + tempTopicDefaultOverrides.put(TopicConfig.RETENTION_MS_CONFIG, String.valueOf(-1)); // Infinity + REPARTITION_TOPIC_DEFAULT_OVERRIDES = Collections.unmodifiableMap(tempTopicDefaultOverrides); + } + + RepartitionTopicConfig(final String name, final Map topicConfigs) { + super(name, topicConfigs); + } + + RepartitionTopicConfig(final String name, + final Map topicConfigs, + final int numberOfPartitions, + final boolean enforceNumberOfPartitions) { + super(name, topicConfigs, numberOfPartitions, enforceNumberOfPartitions); + } + + /** + * Get the configured properties for this topic. If retentionMs is set then + * we add additionalRetentionMs to work out the desired retention when cleanup.policy=compact,delete + * + * @param additionalRetentionMs - added to retention to allow for clock drift etc + * @return Properties to be used when creating the topic + */ + @Override + public Map getProperties(final Map defaultProperties, final long additionalRetentionMs) { + // internal topic config overridden rule: library overrides < global config overrides < per-topic config overrides + final Map topicConfig = new HashMap<>(REPARTITION_TOPIC_DEFAULT_OVERRIDES); + + topicConfig.putAll(defaultProperties); + + topicConfig.putAll(topicConfigs); + + return topicConfig; + } + + @Override + public boolean equals(final Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + final RepartitionTopicConfig that = (RepartitionTopicConfig) o; + return Objects.equals(name, that.name) && + Objects.equals(topicConfigs, that.topicConfigs) && + Objects.equals(enforceNumberOfPartitions, that.enforceNumberOfPartitions); + } + + @Override + public int hashCode() { + return Objects.hash(name, topicConfigs, enforceNumberOfPartitions); + } + + @Override + public String toString() { + return "RepartitionTopicConfig(" + + "name=" + name + + ", topicConfigs=" + topicConfigs + + ", enforceNumberOfPartitions=" + enforceNumberOfPartitions + + ")"; + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/RepartitionTopics.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/RepartitionTopics.java new file mode 100644 index 0000000..801e6c1 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/RepartitionTopics.java @@ -0,0 +1,218 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
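getProperties() above layers configuration as library defaults, then the application-wide topic overrides passed in as defaultProperties, then the per-topic overrides, because each later putAll() call wins. A tiny sketch of that precedence with made-up values:

import java.util.Collections;
import java.util.HashMap;
import java.util.Map;

final Map<String, String> libraryDefaults = new HashMap<>();
libraryDefaults.put("cleanup.policy", "delete");
libraryDefaults.put("segment.bytes", "52428800");

final Map<String, String> globalOverrides = Collections.singletonMap("segment.bytes", "104857600");
final Map<String, String> perTopicOverrides = Collections.singletonMap("retention.ms", "86400000");

final Map<String, String> effective = new HashMap<>(libraryDefaults);
effective.putAll(globalOverrides);    // application-wide topic config overrides the library defaults
effective.putAll(perTopicOverrides);  // per-topic config wins over both
// effective: {cleanup.policy=delete, segment.bytes=104857600, retention.ms=86400000}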
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.processor.internals; + +import org.apache.kafka.common.Cluster; +import org.apache.kafka.common.Node; +import org.apache.kafka.common.PartitionInfo; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.streams.errors.MissingSourceTopicException; +import org.apache.kafka.streams.errors.TaskAssignmentException; +import org.apache.kafka.streams.processor.internals.InternalTopologyBuilder.TopicsInfo; +import org.apache.kafka.streams.processor.internals.TopologyMetadata.Subtopology; +import org.apache.kafka.streams.processor.internals.assignment.CopartitionedTopicsEnforcer; +import org.slf4j.Logger; + +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.stream.Collectors; + +public class RepartitionTopics { + + private final InternalTopicManager internalTopicManager; + private final TopologyMetadata topologyMetadata; + private final Cluster clusterMetadata; + private final CopartitionedTopicsEnforcer copartitionedTopicsEnforcer; + private final Logger log; + private final Map topicPartitionInfos = new HashMap<>(); + + public RepartitionTopics(final TopologyMetadata topologyMetadata, + final InternalTopicManager internalTopicManager, + final CopartitionedTopicsEnforcer copartitionedTopicsEnforcer, + final Cluster clusterMetadata, + final String logPrefix) { + this.topologyMetadata = topologyMetadata; + this.internalTopicManager = internalTopicManager; + this.clusterMetadata = clusterMetadata; + this.copartitionedTopicsEnforcer = copartitionedTopicsEnforcer; + final LogContext logContext = new LogContext(logPrefix); + log = logContext.logger(getClass()); + } + + public void setup() { + final Map topicGroups = topologyMetadata.topicGroups(); + final Map repartitionTopicMetadata = computeRepartitionTopicConfig(topicGroups, clusterMetadata); + + // ensure the co-partitioning topics within the group have the same number of partitions, + // and enforce the number of partitions for those repartition topics to be the same if they + // are co-partitioned as well. 
+ ensureCopartitioning(topologyMetadata.copartitionGroups(), repartitionTopicMetadata, clusterMetadata); + + // make sure the repartition source topics exist with the right number of partitions, + // create these topics if necessary + internalTopicManager.makeReady(repartitionTopicMetadata); + + // augment the metadata with the newly computed number of partitions for all the + // repartition source topics + for (final Map.Entry entry : repartitionTopicMetadata.entrySet()) { + final String topic = entry.getKey(); + final int numPartitions = entry.getValue().numberOfPartitions().orElse(-1); + + for (int partition = 0; partition < numPartitions; partition++) { + topicPartitionInfos.put( + new TopicPartition(topic, partition), + new PartitionInfo(topic, partition, null, new Node[0], new Node[0]) + ); + } + } + } + + public Map topicPartitionsInfo() { + return Collections.unmodifiableMap(topicPartitionInfos); + } + + private Map computeRepartitionTopicConfig(final Map topicGroups, + final Cluster clusterMetadata) { + + final Map repartitionTopicConfigs = new HashMap<>(); + for (final TopicsInfo topicsInfo : topicGroups.values()) { + checkIfExternalSourceTopicsExist(topicsInfo, clusterMetadata); + repartitionTopicConfigs.putAll(topicsInfo.repartitionSourceTopics.values().stream() + .collect(Collectors.toMap(InternalTopicConfig::name, topicConfig -> topicConfig))); + } + setRepartitionSourceTopicPartitionCount(repartitionTopicConfigs, topicGroups, clusterMetadata); + + return repartitionTopicConfigs; + } + + private void ensureCopartitioning(final Collection> copartitionGroups, + final Map repartitionTopicMetadata, + final Cluster clusterMetadata) { + for (final Set copartitionGroup : copartitionGroups) { + copartitionedTopicsEnforcer.enforce(copartitionGroup, repartitionTopicMetadata, clusterMetadata); + } + } + + private void checkIfExternalSourceTopicsExist(final TopicsInfo topicsInfo, + final Cluster clusterMetadata) { + final Set missingExternalSourceTopics = new HashSet<>(topicsInfo.sourceTopics); + missingExternalSourceTopics.removeAll(topicsInfo.repartitionSourceTopics.keySet()); + missingExternalSourceTopics.removeAll(clusterMetadata.topics()); + if (!missingExternalSourceTopics.isEmpty()) { + log.error("The following source topics are missing/unknown: {}. Please make sure all source topics " + + "have been pre-created before starting the Streams application. 
", + missingExternalSourceTopics); + throw new MissingSourceTopicException("Missing source topics."); + } + } + + /** + * Computes the number of partitions and sets it for each repartition topic in repartitionTopicMetadata + */ + private void setRepartitionSourceTopicPartitionCount(final Map repartitionTopicMetadata, + final Map topicGroups, + final Cluster clusterMetadata) { + boolean partitionCountNeeded; + do { + partitionCountNeeded = false; + boolean progressMadeThisIteration = false; // avoid infinitely looping without making any progress on unknown repartitions + + for (final TopicsInfo topicsInfo : topicGroups.values()) { + for (final String repartitionSourceTopic : topicsInfo.repartitionSourceTopics.keySet()) { + final Optional repartitionSourceTopicPartitionCount = + repartitionTopicMetadata.get(repartitionSourceTopic).numberOfPartitions(); + + if (!repartitionSourceTopicPartitionCount.isPresent()) { + final Integer numPartitions = computePartitionCount( + repartitionTopicMetadata, + topicGroups, + clusterMetadata, + repartitionSourceTopic + ); + + if (numPartitions == null) { + partitionCountNeeded = true; + log.trace("Unable to determine number of partitions for {}, another iteration is needed", + repartitionSourceTopic); + } else { + log.trace("Determined number of partitions for {} to be {}", repartitionSourceTopic, numPartitions); + repartitionTopicMetadata.get(repartitionSourceTopic).setNumberOfPartitions(numPartitions); + progressMadeThisIteration = true; + } + } + } + } + if (!progressMadeThisIteration && partitionCountNeeded) { + log.error("Unable to determine the number of partitions of all repartition topics, most likely a source topic is missing or pattern doesn't match any topics\n" + + "topic groups: {}\n" + + "cluster topics: {}.", topicGroups, clusterMetadata.topics()); + throw new TaskAssignmentException("Failed to compute number of partitions for all repartition topics, " + + "make sure all user input topics are created and all Pattern subscriptions match at least one topic in the cluster"); + } + } while (partitionCountNeeded); + } + + private Integer computePartitionCount(final Map repartitionTopicMetadata, + final Map topicGroups, + final Cluster clusterMetadata, + final String repartitionSourceTopic) { + Integer partitionCount = null; + // try set the number of partitions for this repartition topic if it is not set yet + for (final TopicsInfo topicsInfo : topicGroups.values()) { + final Set sinkTopics = topicsInfo.sinkTopics; + + if (sinkTopics.contains(repartitionSourceTopic)) { + // if this topic is one of the sink topics of this topology, + // use the maximum of all its source topic partitions as the number of partitions + for (final String upstreamSourceTopic : topicsInfo.sourceTopics) { + Integer numPartitionsCandidate = null; + // It is possible the sourceTopic is another internal topic, i.e, + // map().join().join(map()) + if (repartitionTopicMetadata.containsKey(upstreamSourceTopic)) { + if (repartitionTopicMetadata.get(upstreamSourceTopic).numberOfPartitions().isPresent()) { + numPartitionsCandidate = + repartitionTopicMetadata.get(upstreamSourceTopic).numberOfPartitions().get(); + } + } else { + final Integer count = clusterMetadata.partitionCountForTopic(upstreamSourceTopic); + if (count == null) { + throw new TaskAssignmentException( + "No partition count found for source topic " + + upstreamSourceTopic + + ", but it should have been." 
+ ); + } + numPartitionsCandidate = count; + } + + if (numPartitionsCandidate != null) { + if (partitionCount == null || numPartitionsCandidate > partitionCount) { + partitionCount = numPartitionsCandidate; + } + } + } + } + } + return partitionCount; + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/RestoringTasks.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/RestoringTasks.java new file mode 100644 index 0000000..36bddcc --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/RestoringTasks.java @@ -0,0 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.processor.internals; + +import org.apache.kafka.common.TopicPartition; + +public interface RestoringTasks { + + StreamTask restoringTaskFor(final TopicPartition partition); + +} diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/SerdeGetter.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/SerdeGetter.java new file mode 100644 index 0000000..72bfc99 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/SerdeGetter.java @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.processor.internals; + +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.streams.processor.StateStoreContext; + +/** + * Allows serde access across different context types. 
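+ * Exactly one of the wrapped contexts (the old {@code ProcessorContext}, the new {@code api.ProcessorContext},
+ * or the {@code StateStoreContext}) is non-null, and it is the one consulted for the key and value serdes.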
+ */ +public class SerdeGetter { + + private final org.apache.kafka.streams.processor.ProcessorContext oldProcessorContext; + private final org.apache.kafka.streams.processor.api.ProcessorContext newProcessorContext; + private final StateStoreContext stateStorecontext; + public SerdeGetter(final org.apache.kafka.streams.processor.ProcessorContext context) { + oldProcessorContext = context; + newProcessorContext = null; + stateStorecontext = null; + } + public SerdeGetter(final org.apache.kafka.streams.processor.api.ProcessorContext context) { + oldProcessorContext = null; + newProcessorContext = context; + stateStorecontext = null; + } + public SerdeGetter(final StateStoreContext context) { + oldProcessorContext = null; + newProcessorContext = null; + stateStorecontext = context; + } + public Serde keySerde() { + return oldProcessorContext != null ? oldProcessorContext.keySerde() : + newProcessorContext != null ? newProcessorContext.keySerde() : stateStorecontext.keySerde(); + } + public Serde valueSerde() { + return oldProcessorContext != null ? oldProcessorContext.valueSerde() : + newProcessorContext != null ? newProcessorContext.valueSerde() : stateStorecontext.valueSerde(); + } + +} diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/SinkNode.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/SinkNode.java new file mode 100644 index 0000000..f30e2d2 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/SinkNode.java @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.processor.internals; + +import org.apache.kafka.common.serialization.Serializer; +import org.apache.kafka.streams.processor.StreamPartitioner; +import org.apache.kafka.streams.processor.TopicNameExtractor; +import org.apache.kafka.streams.processor.api.Record; + +import static org.apache.kafka.streams.kstream.internals.WrappingNullableUtils.prepareKeySerializer; +import static org.apache.kafka.streams.kstream.internals.WrappingNullableUtils.prepareValueSerializer; + +public class SinkNode extends ProcessorNode { + + private Serializer keySerializer; + private Serializer valSerializer; + private final TopicNameExtractor topicExtractor; + private final StreamPartitioner partitioner; + + private InternalProcessorContext context; + + SinkNode(final String name, + final TopicNameExtractor topicExtractor, + final Serializer keySerializer, + final Serializer valSerializer, + final StreamPartitioner partitioner) { + super(name); + + this.topicExtractor = topicExtractor; + this.keySerializer = keySerializer; + this.valSerializer = valSerializer; + this.partitioner = partitioner; + } + + /** + * @throws UnsupportedOperationException if this method adds a child to a sink node + */ + @Override + public void addChild(final ProcessorNode child) { + throw new UnsupportedOperationException("sink node does not allow addChild"); + } + + @Override + public void init(final InternalProcessorContext context) { + super.init(context); + this.context = context; + keySerializer = prepareKeySerializer(keySerializer, context, this.name()); + valSerializer = prepareValueSerializer(valSerializer, context, this.name()); + } + + @Override + public void process(final Record record) { + final RecordCollector collector = ((RecordCollector.Supplier) context).recordCollector(); + + final KIn key = record.key(); + final VIn value = record.value(); + + final long timestamp = record.timestamp(); + + final ProcessorRecordContext contextForExtraction = + new ProcessorRecordContext( + timestamp, + context.offset(), + context.partition(), + context.topic(), + record.headers() + ); + + final String topic = topicExtractor.extract(key, value, contextForExtraction); + + collector.send(topic, key, value, record.headers(), timestamp, keySerializer, valSerializer, partitioner); + } + + /** + * @return a string representation of this node, useful for debugging. + */ + @Override + public String toString() { + return toString(""); + } + + /** + * @return a string representation of this node starting with the given indent, useful for debugging. + */ + @Override + public String toString(final String indent) { + final StringBuilder sb = new StringBuilder(super.toString(indent)); + sb.append(indent).append("\ttopic:\t\t"); + sb.append(topicExtractor); + sb.append("\n"); + return sb.toString(); + } + +} diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/SourceNode.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/SourceNode.java new file mode 100644 index 0000000..5d0c04b --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/SourceNode.java @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.processor.internals; + +import org.apache.kafka.common.header.Headers; +import org.apache.kafka.common.metrics.Sensor; +import org.apache.kafka.common.serialization.Deserializer; +import org.apache.kafka.streams.processor.TimestampExtractor; +import org.apache.kafka.streams.processor.api.Record; +import org.apache.kafka.streams.processor.internals.metrics.ProcessorNodeMetrics; + +import static org.apache.kafka.streams.kstream.internals.WrappingNullableUtils.prepareKeyDeserializer; +import static org.apache.kafka.streams.kstream.internals.WrappingNullableUtils.prepareValueDeserializer; + +public class SourceNode extends ProcessorNode { + + private InternalProcessorContext context; + private Deserializer keyDeserializer; + private Deserializer valDeserializer; + private final TimestampExtractor timestampExtractor; + private Sensor processAtSourceSensor; + + public SourceNode(final String name, + final TimestampExtractor timestampExtractor, + final Deserializer keyDeserializer, + final Deserializer valDeserializer) { + super(name); + this.timestampExtractor = timestampExtractor; + this.keyDeserializer = keyDeserializer; + this.valDeserializer = valDeserializer; + } + + public SourceNode(final String name, + final Deserializer keyDeserializer, + final Deserializer valDeserializer) { + this(name, null, keyDeserializer, valDeserializer); + } + + KIn deserializeKey(final String topic, final Headers headers, final byte[] data) { + return keyDeserializer.deserialize(topic, headers, data); + } + + VIn deserializeValue(final String topic, final Headers headers, final byte[] data) { + return valDeserializer.deserialize(topic, headers, data); + } + + @Override + public void init(final InternalProcessorContext context) { + // It is important to first create the sensor before calling init on the + // parent object. Otherwise due to backwards compatibility an empty sensor + // without parent is created with the same name. + // Once the backwards compatibility is not needed anymore it might be possible to + // change this. + processAtSourceSensor = ProcessorNodeMetrics.processAtSourceSensor( + Thread.currentThread().getName(), + context.taskId().toString(), + context.currentNode().name(), + context.metrics() + ); + super.init(context); + this.context = context; + + keyDeserializer = prepareKeyDeserializer(keyDeserializer, context, name()); + valDeserializer = prepareValueDeserializer(valDeserializer, context, name()); + } + + + @Override + public void process(final Record record) { + context.forward(record); + processAtSourceSensor.record(1.0d, context.currentSystemTimeMs()); + } + + /** + * @return a string representation of this node, useful for debugging. 
+ */ + @Override + public String toString() { + return toString(""); + } + + public TimestampExtractor getTimestampExtractor() { + return timestampExtractor; + } +} \ No newline at end of file diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/Stamped.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/Stamped.java new file mode 100644 index 0000000..cbbb244 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/Stamped.java @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.processor.internals; + +import java.util.Objects; + +public class Stamped implements Comparable { + + public final V value; + public final long timestamp; + + Stamped(final V value, final long timestamp) { + this.value = value; + this.timestamp = timestamp; + } + + @Override + public int compareTo(final Object other) { + final long otherTimestamp = ((Stamped) other).timestamp; + + if (timestamp < otherTimestamp) { + return -1; + } else if (timestamp > otherTimestamp) { + return 1; + } + return 0; + } + + @Override + public boolean equals(final Object other) { + if (other == null || getClass() != other.getClass()) { + return false; + } + final long otherTimestamp = ((Stamped) other).timestamp; + return timestamp == otherTimestamp; + } + + @Override + public int hashCode() { + return Objects.hash(timestamp); + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StampedRecord.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StampedRecord.java new file mode 100644 index 0000000..3c6df2a --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StampedRecord.java @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+package org.apache.kafka.streams.processor.internals;
+
+import org.apache.kafka.clients.consumer.ConsumerRecord;
+import org.apache.kafka.common.header.Headers;
+
+public class StampedRecord extends Stamped<ConsumerRecord<Object, Object>> {
+
+    public StampedRecord(final ConsumerRecord<Object, Object> record, final long timestamp) {
+        super(record, timestamp);
+    }
+
+    public String topic() {
+        return value.topic();
+    }
+
+    public int partition() {
+        return value.partition();
+    }
+
+    public Object key() {
+        return value.key();
+    }
+
+    public Object value() {
+        return value.value();
+    }
+
+    public long offset() {
+        return value.offset();
+    }
+
+    public Headers headers() {
+        return value.headers();
+    }
+
+    @Override
+    public String toString() {
+        return value.toString() + ", timestamp = " + timestamp;
+    }
+}
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StandbyTask.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StandbyTask.java
new file mode 100644
index 0000000..802bca1
--- /dev/null
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StandbyTask.java
@@ -0,0 +1,340 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +package org.apache.kafka.streams.processor.internals; + +import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.clients.consumer.OffsetAndMetadata; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.metrics.Sensor; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.StreamsMetrics; +import org.apache.kafka.streams.errors.StreamsException; +import org.apache.kafka.streams.errors.TaskMigratedException; +import org.apache.kafka.streams.processor.TaskId; +import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl; +import org.apache.kafka.streams.processor.internals.metrics.ThreadMetrics; +import org.apache.kafka.streams.state.internals.ThreadCache; + +import java.util.Collections; +import java.util.Map; +import java.util.Optional; +import java.util.Set; + +/** + * A StandbyTask + */ +public class StandbyTask extends AbstractTask implements Task { + private final Sensor closeTaskSensor; + private final boolean eosEnabled; + private final InternalProcessorContext processorContext; + private final StreamsMetricsImpl streamsMetrics; + + /** + * @param id the ID of this task + * @param inputPartitions input topic partitions, used for thread metadata only + * @param topology the instance of {@link ProcessorTopology} + * @param config the {@link StreamsConfig} specified by the user + * @param streamsMetrics the {@link StreamsMetrics} created by the thread + * @param stateMgr the {@link ProcessorStateManager} for this task + * @param stateDirectory the {@link StateDirectory} created by the thread + */ + StandbyTask(final TaskId id, + final Set inputPartitions, + final ProcessorTopology topology, + final StreamsConfig config, + final StreamsMetricsImpl streamsMetrics, + final ProcessorStateManager stateMgr, + final StateDirectory stateDirectory, + final ThreadCache cache, + final InternalProcessorContext processorContext) { + super( + id, + topology, + stateDirectory, + stateMgr, + inputPartitions, + config.getLong(StreamsConfig.TASK_TIMEOUT_MS_CONFIG), + "standby-task", + StandbyTask.class + ); + this.processorContext = processorContext; + this.streamsMetrics = streamsMetrics; + processorContext.transitionToStandby(cache); + + closeTaskSensor = ThreadMetrics.closeTaskSensor(Thread.currentThread().getName(), streamsMetrics); + eosEnabled = StreamThread.eosEnabled(config); + } + + @Override + public boolean isActive() { + return false; + } + + /** + * @throws StreamsException fatal error, should close the thread + */ + @Override + public void initializeIfNeeded() { + if (state() == State.CREATED) { + StateManagerUtil.registerStateStores(log, logPrefix, topology, stateMgr, stateDirectory, processorContext); + + // with and without EOS we would check for checkpointing at each commit during running, + // and the file may be deleted in which case we should checkpoint immediately, + // therefore we initialize the snapshot as empty + offsetSnapshotSinceLastFlush = Collections.emptyMap(); + + // no topology needs initialized, we can transit to RUNNING + // right after registered the stores + transitionTo(State.RESTORING); + transitionTo(State.RUNNING); + + processorContext.initialize(); + + log.info("Initialized"); + } else if (state() == State.RESTORING) { + throw new IllegalStateException("Illegal state " + state() + " while initializing standby task " + id); + } + } + + @Override + public void completeRestoration(final java.util.function.Consumer> offsetResetter) { + throw new 
IllegalStateException("Standby task " + id + " should never be completing restoration"); + } + + @Override + public void suspend() { + switch (state()) { + case CREATED: + log.info("Suspended created"); + transitionTo(State.SUSPENDED); + + break; + + case RUNNING: + log.info("Suspended running"); + transitionTo(State.SUSPENDED); + + break; + + case SUSPENDED: + log.info("Skip suspending since state is {}", state()); + + break; + + case RESTORING: + case CLOSED: + throw new IllegalStateException("Illegal state " + state() + " while suspending standby task " + id); + + default: + throw new IllegalStateException("Unknown state " + state() + " while suspending standby task " + id); + } + } + + @Override + public void resume() { + if (state() == State.RESTORING) { + throw new IllegalStateException("Illegal state " + state() + " while resuming standby task " + id); + } + log.trace("No-op resume with state {}", state()); + } + + /** + * Flush stores before a commit; the following exceptions maybe thrown from the state manager flushing call + * + * @throws TaskMigratedException recoverable error sending changelog records that would cause the task to be removed + * @throws StreamsException fatal error when flushing the state store, for example sending changelog records failed + * or flushing state store get IO errors; such error should cause the thread to die + */ + @Override + public Map prepareCommit() { + switch (state()) { + case CREATED: + log.debug("Skipped preparing created task for commit"); + + break; + + case RUNNING: + case SUSPENDED: + // do not need to flush state store caches in pre-commit since nothing would be sent for standby tasks + log.debug("Prepared {} task for committing", state()); + + break; + + default: + throw new IllegalStateException("Illegal state " + state() + " while preparing standby task " + id + " for committing "); + } + + return Collections.emptyMap(); + } + + @Override + public void postCommit(final boolean enforceCheckpoint) { + switch (state()) { + case CREATED: + // We should never write a checkpoint for a CREATED task as we may overwrite an existing checkpoint + // with empty uninitialized offsets + log.debug("Skipped writing checkpoint for created task"); + + break; + + case RUNNING: + case SUSPENDED: + maybeWriteCheckpoint(enforceCheckpoint); + + log.debug("Finalized commit for {} task", state()); + + break; + + default: + throw new IllegalStateException("Illegal state " + state() + " while post committing standby task " + id); + } + } + + @Override + public void closeClean() { + streamsMetrics.removeAllTaskLevelSensors(Thread.currentThread().getName(), id.toString()); + close(true); + log.info("Closed clean"); + } + + @Override + public void closeDirty() { + streamsMetrics.removeAllTaskLevelSensors(Thread.currentThread().getName(), id.toString()); + close(false); + log.info("Closed dirty"); + } + + @Override + public void closeCleanAndRecycleState() { + streamsMetrics.removeAllTaskLevelSensors(Thread.currentThread().getName(), id.toString()); + if (state() == State.SUSPENDED) { + stateMgr.recycle(); + } else { + throw new IllegalStateException("Illegal state " + state() + " while closing standby task " + id); + } + + closeTaskSensor.record(); + transitionTo(State.CLOSED); + + log.info("Closed clean and recycled state"); + } + + private void close(final boolean clean) { + switch (state()) { + case SUSPENDED: + TaskManager.executeAndMaybeSwallow( + clean, + () -> StateManagerUtil.closeStateManager( + log, + logPrefix, + clean, + eosEnabled, + stateMgr, + 
stateDirectory, + TaskType.STANDBY + ), + "state manager close", + log + ); + + break; + + case CLOSED: + log.trace("Skip closing since state is {}", state()); + return; + + case CREATED: + case RESTORING: // a StandbyTask is never in RESTORING state + case RUNNING: + throw new IllegalStateException("Illegal state " + state() + " while closing standby task " + id); + + default: + throw new IllegalStateException("Unknown state " + state() + " while closing standby task " + id); + } + + closeTaskSensor.record(); + transitionTo(State.CLOSED); + } + + @Override + public boolean commitNeeded() { + // for standby tasks committing is the same as checkpointing, + // so we only need to commit if we want to checkpoint + return StateManagerUtil.checkpointNeeded(false, offsetSnapshotSinceLastFlush, stateMgr.changelogOffsets()); + } + + @Override + public Map changelogOffsets() { + return Collections.unmodifiableMap(stateMgr.changelogOffsets()); + } + + @Override + public Map committedOffsets() { + return Collections.emptyMap(); + } + + @Override + public Map highWaterMark() { + return Collections.emptyMap(); + } + + @Override + public Optional timeCurrentIdlingStarted() { + return Optional.empty(); + } + + @Override + public void addRecords(final TopicPartition partition, final Iterable> records) { + throw new IllegalStateException("Attempted to add records to task " + id() + " for invalid input partition " + partition); + } + + InternalProcessorContext processorContext() { + return processorContext; + } + + /** + * Produces a string representation containing useful information about a Task. + * This is useful in debugging scenarios. + * + * @return A string representation of the StreamTask instance. + */ + @Override + public String toString() { + return toString(""); + } + + /** + * Produces a string representation containing useful information about a Task starting with the given indent. + * This is useful in debugging scenarios. + * + * @return A string representation of the Task instance. + */ + public String toString(final String indent) { + final StringBuilder sb = new StringBuilder(); + sb.append(indent); + sb.append("TaskId: "); + sb.append(id); + sb.append("\n"); + + // print topology + if (topology != null) { + sb.append(indent).append(topology.toString(indent + "\t")); + } + + return sb.toString(); + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StandbyTaskCreator.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StandbyTaskCreator.java new file mode 100644 index 0000000..ee2a3e1 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StandbyTaskCreator.java @@ -0,0 +1,181 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.processor.internals; + +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.metrics.Sensor; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.processor.TaskId; +import org.apache.kafka.streams.processor.internals.Task.TaskType; +import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl; +import org.apache.kafka.streams.processor.internals.metrics.ThreadMetrics; +import org.apache.kafka.streams.state.internals.ThreadCache; +import org.slf4j.Logger; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import static org.apache.kafka.common.utils.Utils.filterMap; + +class StandbyTaskCreator { + private final TopologyMetadata topologyMetadata; + private final StreamsConfig config; + private final StreamsMetricsImpl streamsMetrics; + private final StateDirectory stateDirectory; + private final ChangelogReader storeChangelogReader; + private final ThreadCache dummyCache; + private final Logger log; + private final Sensor createTaskSensor; + + // tasks may be assigned for a NamedTopology that is not yet known by this host, and saved for later creation + private final Map> unknownTasksToBeCreated = new HashMap<>(); + + StandbyTaskCreator(final TopologyMetadata topologyMetadata, + final StreamsConfig config, + final StreamsMetricsImpl streamsMetrics, + final StateDirectory stateDirectory, + final ChangelogReader storeChangelogReader, + final String threadId, + final Logger log) { + this.topologyMetadata = topologyMetadata; + this.config = config; + this.streamsMetrics = streamsMetrics; + this.stateDirectory = stateDirectory; + this.storeChangelogReader = storeChangelogReader; + this.log = log; + + createTaskSensor = ThreadMetrics.createTaskSensor(threadId, streamsMetrics); + + dummyCache = new ThreadCache( + new LogContext(String.format("stream-thread [%s] ", Thread.currentThread().getName())), + 0, + streamsMetrics + ); + } + + void removeRevokedUnknownTasks(final Set assignedTasks) { + unknownTasksToBeCreated.keySet().retainAll(assignedTasks); + } + + Map> uncreatedTasksForTopologies(final Set currentTopologies) { + return filterMap(unknownTasksToBeCreated, t -> currentTopologies.contains(t.getKey().topologyName())); + } + + // TODO: change return type to `StandbyTask` + Collection createTasks(final Map> tasksToBeCreated) { + // TODO: change type to `StandbyTask` + final List createdTasks = new ArrayList<>(); + final Map> newUnknownTasks = new HashMap<>(); + + for (final Map.Entry> newTaskAndPartitions : tasksToBeCreated.entrySet()) { + final TaskId taskId = newTaskAndPartitions.getKey(); + final Set partitions = newTaskAndPartitions.getValue(); + + final ProcessorTopology topology = topologyMetadata.buildSubtopology(taskId); + if (topology == null) { + // task belongs to a named topology that hasn't been added yet, wait until it has to create this + newUnknownTasks.put(taskId, partitions); + continue; + } + + if (topology.hasStateWithChangelogs()) { + final ProcessorStateManager stateManager = new ProcessorStateManager( + taskId, + Task.TaskType.STANDBY, + StreamThread.eosEnabled(config), + getLogContext(taskId), + stateDirectory, + storeChangelogReader, + topology.storeToChangelogTopic(), + partitions + ); + + final InternalProcessorContext context = new ProcessorContextImpl( + taskId, + config, + stateManager, + 
streamsMetrics, + dummyCache + ); + + createdTasks.add(createStandbyTask(taskId, partitions, topology, stateManager, context)); + } else { + log.trace( + "Skipped standby task {} with assigned partitions {} " + + "since it does not have any state stores to materialize", + taskId, partitions + ); + } + unknownTasksToBeCreated.remove(taskId); + } + if (!newUnknownTasks.isEmpty()) { + log.info("Delaying creation of tasks not yet known by this instance: {}", newUnknownTasks.keySet()); + unknownTasksToBeCreated.putAll(newUnknownTasks); + } + return createdTasks; + } + + StandbyTask createStandbyTaskFromActive(final StreamTask streamTask, + final Set inputPartitions) { + final InternalProcessorContext context = streamTask.processorContext(); + final ProcessorStateManager stateManager = streamTask.stateMgr; + + streamTask.closeCleanAndRecycleState(); + stateManager.transitionTaskType(TaskType.STANDBY, getLogContext(streamTask.id())); + + return createStandbyTask( + streamTask.id(), + inputPartitions, + topologyMetadata.buildSubtopology(streamTask.id), + stateManager, + context + ); + } + + StandbyTask createStandbyTask(final TaskId taskId, + final Set inputPartitions, + final ProcessorTopology topology, + final ProcessorStateManager stateManager, + final InternalProcessorContext context) { + final StandbyTask task = new StandbyTask( + taskId, + inputPartitions, + topology, + config, + streamsMetrics, + stateManager, + stateDirectory, + dummyCache, + context + ); + + log.trace("Created task {} with assigned partitions {}", taskId, inputPartitions); + createTaskSensor.record(); + return task; + } + + private LogContext getLogContext(final TaskId taskId) { + final String threadIdPrefix = String.format("stream-thread [%s] ", Thread.currentThread().getName()); + final String logPrefix = threadIdPrefix + String.format("%s [%s] ", "standby-task", taskId); + return new LogContext(logPrefix); + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StateDirectory.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StateDirectory.java new file mode 100644 index 0000000..f242f55 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StateDirectory.java @@ -0,0 +1,678 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.processor.internals; + +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.errors.ProcessorStateException; +import org.apache.kafka.streams.errors.StreamsException; +import org.apache.kafka.streams.processor.TaskId; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.databind.ObjectMapper; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.File; +import java.io.FileFilter; +import java.io.IOException; +import java.nio.channels.FileChannel; +import java.nio.channels.FileLock; +import java.nio.channels.OverlappingFileLockException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.StandardOpenOption; +import java.nio.file.attribute.PosixFilePermission; +import java.nio.file.attribute.PosixFilePermissions; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Objects; +import java.util.Set; +import java.util.UUID; +import java.util.concurrent.atomic.AtomicReference; +import java.util.regex.Pattern; +import java.util.stream.Collectors; + +import static org.apache.kafka.streams.processor.internals.StateManagerUtil.CHECKPOINT_FILE_NAME; +import static org.apache.kafka.streams.processor.internals.StateManagerUtil.parseTaskDirectoryName; + +/** + * Manages the directories where the state of Tasks owned by a {@link StreamThread} are + * stored. Handles creation/locking/unlocking/cleaning of the Task Directories. This class is not + * thread-safe. + */ +public class StateDirectory { + + private static final Pattern TASK_DIR_PATH_NAME = Pattern.compile("\\d+_\\d+"); + private static final Pattern NAMED_TOPOLOGY_DIR_PATH_NAME = Pattern.compile("__.+__"); // named topology dirs follow '__Topology-Name__' + private static final Logger log = LoggerFactory.getLogger(StateDirectory.class); + static final String LOCK_FILE_NAME = ".lock"; + + /* The process file is used to persist the process id across restarts. + * For compatibility reasons you should only ever add fields to the json schema + */ + static final String PROCESS_FILE_NAME = "kafka-streams-process-metadata"; + + @JsonIgnoreProperties(ignoreUnknown = true) + static class StateDirectoryProcessFile { + @JsonProperty + private final UUID processId; + + // required by jackson -- do not remove, your IDE may be warning that this is unused but it's lying to you + public StateDirectoryProcessFile() { + this.processId = null; + } + + StateDirectoryProcessFile(final UUID processId) { + this.processId = processId; + } + } + + private final Object taskDirCreationLock = new Object(); + private final Time time; + private final String appId; + private final File stateDir; + private final boolean hasPersistentStores; + private final boolean hasNamedTopologies; + + private final HashMap lockedTasksToOwner = new HashMap<>(); + + private FileChannel stateDirLockChannel; + private FileLock stateDirLock; + + /** + * Ensures that the state base directory as well as the application's sub-directory are created. 
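+ * When persistent stores are enabled, the created directories are also restricted to "rwxr-x---"
+ * so that local state is not world-readable.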
+ * + * @param config streams application configuration to read the root state directory path + * @param time system timer used to execute periodic cleanup procedure + * @param hasPersistentStores only when the application's topology does have stores persisted on local file + * system, we would go ahead and auto-create the corresponding application / task / store + * directories whenever necessary; otherwise no directories would be created. + * @param hasNamedTopologies whether this application is composed of independent named topologies + * + * @throws ProcessorStateException if the base state directory or application state directory does not exist + * and could not be created when hasPersistentStores is enabled. + */ + public StateDirectory(final StreamsConfig config, final Time time, final boolean hasPersistentStores, final boolean hasNamedTopologies) { + this.time = time; + this.hasPersistentStores = hasPersistentStores; + this.hasNamedTopologies = hasNamedTopologies; + this.appId = config.getString(StreamsConfig.APPLICATION_ID_CONFIG); + final String stateDirName = config.getString(StreamsConfig.STATE_DIR_CONFIG); + final File baseDir = new File(stateDirName); + stateDir = new File(baseDir, appId); + + if (this.hasPersistentStores) { + if (!baseDir.exists() && !baseDir.mkdirs()) { + throw new ProcessorStateException( + String.format("base state directory [%s] doesn't exist and couldn't be created", stateDirName)); + } + if (!stateDir.exists() && !stateDir.mkdir()) { + throw new ProcessorStateException( + String.format("state directory [%s] doesn't exist and couldn't be created", stateDir.getPath())); + } else if (stateDir.exists() && !stateDir.isDirectory()) { + throw new ProcessorStateException( + String.format("state directory [%s] can't be created as there is an existing file with the same name", stateDir.getPath())); + } + + if (stateDirName.startsWith(System.getProperty("java.io.tmpdir"))) { + log.warn("Using an OS temp directory in the state.dir property can cause failures with writing" + + " the checkpoint file due to the fact that this directory can be cleared by the OS." 
+ + " Resolved state.dir: [" + stateDirName + "]"); + + } + // change the dir permission to "rwxr-x---" to avoid world readable + configurePermissions(baseDir); + configurePermissions(stateDir); + } + } + + private void configurePermissions(final File file) { + final Path path = file.toPath(); + if (path.getFileSystem().supportedFileAttributeViews().contains("posix")) { + final Set perms = PosixFilePermissions.fromString("rwxr-x---"); + try { + Files.setPosixFilePermissions(path, perms); + } catch (final IOException e) { + log.error("Error changing permissions for the directory {} ", path, e); + } + } else { + boolean set = file.setReadable(true, true); + set &= file.setWritable(true, true); + set &= file.setExecutable(true, true); + if (!set) { + log.error("Failed to change permissions for the directory {}", file); + } + } + } + + /** + * @return true if the state directory was successfully locked + */ + private boolean lockStateDirectory() { + final File lockFile = new File(stateDir, LOCK_FILE_NAME); + try { + stateDirLockChannel = FileChannel.open(lockFile.toPath(), StandardOpenOption.CREATE, StandardOpenOption.WRITE); + stateDirLock = tryLock(stateDirLockChannel); + } catch (final IOException e) { + log.error("Unable to lock the state directory due to unexpected exception", e); + throw new ProcessorStateException("Failed to lock the state directory during startup", e); + } + + return stateDirLock != null; + } + + public UUID initializeProcessId() { + if (!hasPersistentStores) { + return UUID.randomUUID(); + } + + if (!lockStateDirectory()) { + log.error("Unable to obtain lock as state directory is already locked by another process"); + throw new StreamsException("Unable to initialize state, this can happen if multiple instances of " + + "Kafka Streams are running in the same state directory"); + } + + final File processFile = new File(stateDir, PROCESS_FILE_NAME); + final ObjectMapper mapper = new ObjectMapper(); + + try { + if (processFile.exists()) { + try { + final StateDirectoryProcessFile processFileData = mapper.readValue(processFile, StateDirectoryProcessFile.class); + log.info("Reading UUID from process file: {}", processFileData.processId); + if (processFileData.processId != null) { + return processFileData.processId; + } + } catch (final Exception e) { + log.warn("Failed to read json process file", e); + } + } + + final StateDirectoryProcessFile processFileData = new StateDirectoryProcessFile(UUID.randomUUID()); + log.info("No process id found on disk, got fresh process id {}", processFileData.processId); + + mapper.writeValue(processFile, processFileData); + return processFileData.processId; + } catch (final IOException e) { + log.error("Unable to read/write process file due to unexpected exception", e); + throw new ProcessorStateException(e); + } + } + + /** + * Get or create the directory for the provided {@link TaskId}. 
+ * @return directory for the {@link TaskId} + * @throws ProcessorStateException if the task directory does not exist and could not be created + */ + public File getOrCreateDirectoryForTask(final TaskId taskId) { + final File taskParentDir = getTaskDirectoryParentName(taskId); + final File taskDir = new File(taskParentDir, StateManagerUtil.toTaskDirString(taskId)); + if (hasPersistentStores) { + if (!taskDir.exists()) { + synchronized (taskDirCreationLock) { + // to avoid a race condition, we need to check again if the directory does not exist: + // otherwise, two threads might pass the outer `if` (and enter the `then` block), + // one blocks on `synchronized` while the other creates the directory, + // and the blocking one fails when trying to create it after it's unblocked + if (!taskParentDir.exists() && !taskParentDir.mkdir()) { + throw new ProcessorStateException( + String.format("Parent [%s] of task directory [%s] doesn't exist and couldn't be created", + taskParentDir.getPath(), taskDir.getPath())); + } + if (!taskDir.exists() && !taskDir.mkdir()) { + throw new ProcessorStateException( + String.format("task directory [%s] doesn't exist and couldn't be created", taskDir.getPath())); + } + } + } else if (!taskDir.isDirectory()) { + throw new ProcessorStateException( + String.format("state directory [%s] can't be created as there is an existing file with the same name", taskDir.getPath())); + } + } + return taskDir; + } + + private File getTaskDirectoryParentName(final TaskId taskId) { + final String namedTopology = taskId.topologyName(); + if (namedTopology != null) { + if (!hasNamedTopologies) { + throw new IllegalStateException("Tried to lookup taskId with named topology, but StateDirectory thinks hasNamedTopologies = false"); + } + return new File(stateDir, getNamedTopologyDirName(namedTopology)); + } else { + return stateDir; + } + } + + private String getNamedTopologyDirName(final String topologyName) { + return "__" + topologyName + "__"; + } + + /** + * @return The File handle for the checkpoint in the given task's directory + */ + File checkpointFileFor(final TaskId taskId) { + return new File(getOrCreateDirectoryForTask(taskId), StateManagerUtil.CHECKPOINT_FILE_NAME); + } + + /** + * Decide if the directory of the task is empty or not + */ + boolean directoryForTaskIsEmpty(final TaskId taskId) { + final File taskDir = getOrCreateDirectoryForTask(taskId); + + return taskDirIsEmpty(taskDir); + } + + private boolean taskDirIsEmpty(final File taskDir) { + final File[] storeDirs = taskDir.listFiles(pathname -> + !pathname.getName().equals(CHECKPOINT_FILE_NAME)); + + boolean taskDirEmpty = true; + + // if the task is stateless, storeDirs would be null + if (storeDirs != null && storeDirs.length > 0) { + for (final File file : storeDirs) { + // We removed the task directory locking but some upgrading applications may still have old lock files on disk, + // we just lazily delete those in this method since it's the only thing that would be affected by these + if (file.getName().endsWith(LOCK_FILE_NAME)) { + if (!file.delete()) { + // If we hit an error deleting this just ignore it and move on, we'll retry again at some point + log.warn("Error encountered deleting lock file in {}", taskDir); + } + } else { + // If it's not a lock file then the directory is not empty, + // but finish up the loop in case there's a lock file left to delete + log.trace("TaskDir {} was not empty, found {}", taskDir, file); + taskDirEmpty = false; + } + } + } + return taskDirEmpty; + } + + /** + * Get or 
create the directory for the global stores. + * @return directory for the global stores + * @throws ProcessorStateException if the global store directory does not exists and could not be created + */ + File globalStateDir() { + final File dir = new File(stateDir, "global"); + if (hasPersistentStores) { + if (!dir.exists() && !dir.mkdir()) { + throw new ProcessorStateException( + String.format("global state directory [%s] doesn't exist and couldn't be created", dir.getPath())); + } else if (dir.exists() && !dir.isDirectory()) { + throw new ProcessorStateException( + String.format("global state directory [%s] can't be created as there is an existing file with the same name", dir.getPath())); + } + } + return dir; + } + + private String logPrefix() { + return String.format("stream-thread [%s]", Thread.currentThread().getName()); + } + + /** + * Get the lock for the {@link TaskId}s directory if it is available + * @param taskId task id + * @return true if successful + */ + synchronized boolean lock(final TaskId taskId) { + if (!hasPersistentStores) { + return true; + } + + final Thread lockOwner = lockedTasksToOwner.get(taskId); + if (lockOwner != null) { + if (lockOwner.equals(Thread.currentThread())) { + log.trace("{} Found cached state dir lock for task {}", logPrefix(), taskId); + // we already own the lock + return true; + } else { + // another thread owns the lock + return false; + } + } else if (!stateDir.exists()) { + log.error("Tried to lock task directory for {} but the state directory does not exist", taskId); + throw new IllegalStateException("The state directory has been deleted"); + } else { + lockedTasksToOwner.put(taskId, Thread.currentThread()); + // make sure the task directory actually exists, and create it if not + getOrCreateDirectoryForTask(taskId); + return true; + } + } + + /** + * Unlock the state directory for the given {@link TaskId}. 
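+ * This is a no-op if the lock is not held by the calling thread.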
+ */ + synchronized void unlock(final TaskId taskId) { + final Thread lockOwner = lockedTasksToOwner.get(taskId); + if (lockOwner != null && lockOwner.equals(Thread.currentThread())) { + lockedTasksToOwner.remove(taskId); + log.debug("{} Released state dir lock for task {}", logPrefix(), taskId); + } + } + + public void close() { + if (hasPersistentStores) { + try { + stateDirLock.release(); + stateDirLockChannel.close(); + + stateDirLock = null; + stateDirLockChannel = null; + } catch (final IOException e) { + log.error("Unexpected exception while unlocking the state dir", e); + throw new StreamsException("Failed to release the lock on the state directory", e); + } + + // all threads should be stopped and cleaned up by now, so none should remain holding a lock + if (!lockedTasksToOwner.isEmpty()) { + log.error("Some task directories still locked while closing state, this indicates unclean shutdown: {}", lockedTasksToOwner); + } + } + } + + public synchronized void clean() { + try { + cleanStateAndTaskDirectoriesCalledByUser(); + } catch (final Exception e) { + throw new StreamsException(e); + } + + try { + if (stateDir.exists()) { + Utils.delete(globalStateDir().getAbsoluteFile()); + } + } catch (final IOException exception) { + log.error( + String.format("%s Failed to delete global state directory of %s due to an unexpected exception", + logPrefix(), appId), + exception + ); + throw new StreamsException(exception); + } + + try { + if (hasPersistentStores && stateDir.exists() && !stateDir.delete()) { + log.warn( + String.format("%s Failed to delete state store directory of %s for it is not empty", + logPrefix(), stateDir.getAbsolutePath()) + ); + } + } catch (final SecurityException exception) { + log.error( + String.format("%s Failed to delete state store directory of %s due to an unexpected exception", + logPrefix(), stateDir.getAbsolutePath()), + exception + ); + throw new StreamsException(exception); + } + } + + /** + * Remove the directories for any {@link TaskId}s that are no-longer + * owned by this {@link StreamThread} and aren't locked by either + * another process or another {@link StreamThread} + * @param cleanupDelayMs only remove directories if they haven't been modified for at least + * this amount of time (milliseconds) + */ + public synchronized void cleanRemovedTasks(final long cleanupDelayMs) { + try { + cleanRemovedTasksCalledByCleanerThread(cleanupDelayMs); + } catch (final Exception cannotHappen) { + throw new IllegalStateException("Should have swallowed exception.", cannotHappen); + } + } + + private void cleanRemovedTasksCalledByCleanerThread(final long cleanupDelayMs) { + for (final TaskDirectory taskDir : listAllTaskDirectories()) { + final String dirName = taskDir.file().getName(); + final TaskId id = parseTaskDirectoryName(dirName, taskDir.namedTopology()); + if (!lockedTasksToOwner.containsKey(id)) { + try { + if (lock(id)) { + final long now = time.milliseconds(); + final long lastModifiedMs = taskDir.file().lastModified(); + if (now - cleanupDelayMs > lastModifiedMs) { + log.info("{} Deleting obsolete state directory {} for task {} as {}ms has elapsed (cleanup delay is {}ms).", + logPrefix(), dirName, id, now - lastModifiedMs, cleanupDelayMs); + Utils.delete(taskDir.file()); + } + } + } catch (final IOException exception) { + log.warn( + String.format("%s Swallowed the following exception during deletion of obsolete state directory %s for task %s:", + logPrefix(), dirName, id), + exception + ); + } finally { + unlock(id); + } + } + } + // Ok to ignore 
returned exception as it should be swallowed + maybeCleanEmptyNamedTopologyDirs(true); + } + + /** + * Cleans up any leftover named topology directories that are empty, if any exist + * @param logExceptionAsWarn if true, an exception will be logged as a warning + * if false, an exception will be logged as error + * @return the first IOException to be encountered + */ + private IOException maybeCleanEmptyNamedTopologyDirs(final boolean logExceptionAsWarn) { + if (!hasNamedTopologies) { + return null; + } + + final AtomicReference firstException = new AtomicReference<>(null); + final File[] namedTopologyDirs = stateDir.listFiles(pathname -> + pathname.isDirectory() && NAMED_TOPOLOGY_DIR_PATH_NAME.matcher(pathname.getName()).matches() + ); + if (namedTopologyDirs != null) { + for (final File namedTopologyDir : namedTopologyDirs) { + final File[] contents = namedTopologyDir.listFiles(); + if (contents != null && contents.length == 0) { + try { + Utils.delete(namedTopologyDir); + } catch (final IOException exception) { + if (logExceptionAsWarn) { + log.warn( + String.format("%sSwallowed the following exception during deletion of named topology directory %s", + logPrefix(), namedTopologyDir.getName()), + exception + ); + } else { + log.error( + String.format("%s Failed to delete named topology directory %s with exception:", + logPrefix(), namedTopologyDir.getName()), + exception + ); + } + firstException.compareAndSet(null, exception); + } + } + } + } + return firstException.get(); + } + + /** + * Clears out any local state found for the given NamedTopology after it was removed + * + * @throws StreamsException if cleanup failed + */ + public void clearLocalStateForNamedTopology(final String topologyName) { + final File namedTopologyDir = new File(stateDir, getNamedTopologyDirName(topologyName)); + if (!namedTopologyDir.exists() || !namedTopologyDir.isDirectory()) { + log.debug("Tried to clear out the local state for NamedTopology {} but none was found", topologyName); + } + try { + Utils.delete(namedTopologyDir); + } catch (final IOException e) { + log.error("Hit an unexpected error while clearing local state for NamedTopology {}", topologyName); + throw new StreamsException("Unable to delete state for the named topology " + topologyName); + } + } + + private void cleanStateAndTaskDirectoriesCalledByUser() throws Exception { + if (!lockedTasksToOwner.isEmpty()) { + log.warn("Found some still-locked task directories when user requested to cleaning up the state, " + + "since Streams is not running any more these will be ignored to complete the cleanup"); + } + final AtomicReference firstException = new AtomicReference<>(); + for (final TaskDirectory taskDir : listAllTaskDirectories()) { + final String dirName = taskDir.file().getName(); + final TaskId id = parseTaskDirectoryName(dirName, taskDir.namedTopology()); + try { + log.info("{} Deleting task directory {} for {} as user calling cleanup.", + logPrefix(), dirName, id); + + if (lockedTasksToOwner.containsKey(id)) { + log.warn("{} Task {} in state directory {} was still locked by {}", + logPrefix(), dirName, id, lockedTasksToOwner.get(id)); + } + Utils.delete(taskDir.file()); + } catch (final IOException exception) { + log.error( + String.format("%s Failed to delete task directory %s for %s with exception:", + logPrefix(), dirName, id), + exception + ); + firstException.compareAndSet(null, exception); + } + } + + firstException.compareAndSet(null, maybeCleanEmptyNamedTopologyDirs(false)); + + final Exception exception = firstException.get(); + 
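+        // rethrow the first exception encountered so that a user-initiated cleanup surfaces the failure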
if (exception != null) { + throw exception; + } + } + + /** + * List all of the task directories that are non-empty + * @return The list of all the non-empty local directories for stream tasks + */ + List listNonEmptyTaskDirectories() { + return listTaskDirectories(pathname -> { + if (!pathname.isDirectory() || !TASK_DIR_PATH_NAME.matcher(pathname.getName()).matches()) { + return false; + } else { + return !taskDirIsEmpty(pathname); + } + }); + } + + /** + * List all of the task directories along with their parent directory if they belong to a named topology + * @return The list of all the existing local directories for stream tasks + */ + List listAllTaskDirectories() { + return listTaskDirectories(pathname -> pathname.isDirectory() && TASK_DIR_PATH_NAME.matcher(pathname.getName()).matches()); + } + + private List listTaskDirectories(final FileFilter filter) { + final List taskDirectories = new ArrayList<>(); + if (hasPersistentStores && stateDir.exists()) { + if (hasNamedTopologies) { + for (final File namedTopologyDir : listNamedTopologyDirs()) { + final String namedTopology = parseNamedTopologyFromDirectory(namedTopologyDir.getName()); + final File[] taskDirs = namedTopologyDir.listFiles(filter); + if (taskDirs != null) { + taskDirectories.addAll(Arrays.stream(taskDirs) + .map(f -> new TaskDirectory(f, namedTopology)).collect(Collectors.toList())); + } + } + } else { + final File[] taskDirs = + stateDir.listFiles(filter); + if (taskDirs != null) { + taskDirectories.addAll(Arrays.stream(taskDirs) + .map(f -> new TaskDirectory(f, null)).collect(Collectors.toList())); + } + } + } + + return taskDirectories; + } + + private List listNamedTopologyDirs() { + final File[] namedTopologyDirectories = stateDir.listFiles(f -> f.getName().startsWith("__") && f.getName().endsWith("__")); + return namedTopologyDirectories != null ? Arrays.asList(namedTopologyDirectories) : Collections.emptyList(); + } + + private String parseNamedTopologyFromDirectory(final String dirName) { + return dirName.substring(2, dirName.length() - 2); + } + + private FileLock tryLock(final FileChannel channel) throws IOException { + try { + return channel.tryLock(); + } catch (final OverlappingFileLockException e) { + return null; + } + } + + public static class TaskDirectory { + private final File file; + private final String namedTopology; // may be null if hasNamedTopologies = false + + TaskDirectory(final File file, final String namedTopology) { + this.file = file; + this.namedTopology = namedTopology; + } + + public File file() { + return file; + } + + public String namedTopology() { + return namedTopology; + } + + @Override + public boolean equals(final Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + final TaskDirectory that = (TaskDirectory) o; + return file.equals(that.file) && + Objects.equals(namedTopology, that.namedTopology); + } + + @Override + public int hashCode() { + return Objects.hash(file, namedTopology); + } + } + +} diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StateManager.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StateManager.java new file mode 100644 index 0000000..ad5c3cb --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StateManager.java @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.processor.internals; + +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.streams.errors.StreamsException; +import org.apache.kafka.streams.processor.StateRestoreCallback; +import org.apache.kafka.streams.processor.StateStore; + +import java.io.File; +import java.io.IOException; +import java.util.Map; +import org.apache.kafka.streams.processor.internals.Task.TaskType; + +public interface StateManager { + File baseDir(); + + /** + * @throws IllegalArgumentException if the store name has already been registered or if it is not a valid name + * (e.g., when it conflicts with the names of internal topics, like the checkpoint file name) + * @throws StreamsException if the store's change log does not contain the partition + */ + void registerStore(final StateStore store, final StateRestoreCallback stateRestoreCallback); + + StateStore getStore(final String name); + + void flush(); + + void updateChangelogOffsets(final Map writtenOffsets); + + void checkpoint(); + + Map changelogOffsets(); + + void close() throws IOException; + + TaskType taskType(); + + String changelogFor(final String storeName); + + // TODO: we can remove this when consolidating global state manager into processor state manager + StateStore getGlobalStore(final String name); +} diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StateManagerUtil.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StateManagerUtil.java new file mode 100644 index 0000000..a1c8ca8 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StateManagerUtil.java @@ -0,0 +1,186 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.processor.internals; + +import java.io.IOException; +import java.util.Map; +import java.util.concurrent.atomic.AtomicReference; + +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.streams.errors.LockException; +import org.apache.kafka.streams.errors.ProcessorStateException; +import org.apache.kafka.streams.errors.StreamsException; +import org.apache.kafka.streams.errors.TaskIdFormatException; +import org.apache.kafka.streams.processor.StateStore; +import org.apache.kafka.streams.processor.TaskId; +import org.apache.kafka.streams.processor.internals.Task.TaskType; +import org.apache.kafka.streams.state.internals.RecordConverter; +import org.slf4j.Logger; + +import static org.apache.kafka.streams.state.internals.RecordConverters.identity; +import static org.apache.kafka.streams.state.internals.RecordConverters.rawValueToTimestampedValue; +import static org.apache.kafka.streams.state.internals.WrappedStateStore.isTimestamped; + +/** + * Shared functions to handle state store registration and cleanup between + * active and standby tasks. + */ +final class StateManagerUtil { + static final String CHECKPOINT_FILE_NAME = ".checkpoint"; + static final long OFFSET_DELTA_THRESHOLD_FOR_CHECKPOINT = 10_000L; + + private StateManagerUtil() {} + + static RecordConverter converterForStore(final StateStore store) { + return isTimestamped(store) ? rawValueToTimestampedValue() : identity(); + } + + static boolean checkpointNeeded(final boolean enforceCheckpoint, + final Map oldOffsetSnapshot, + final Map newOffsetSnapshot) { + // we should always have the old snapshot post completing the register state stores; + // if it is null it means the registration is not done and hence we should not overwrite the checkpoint + if (oldOffsetSnapshot == null) { + return false; + } + + if (enforceCheckpoint) + return true; + + // we can checkpoint if the difference between the current and the previous snapshot is large enough + long totalOffsetDelta = 0L; + for (final Map.Entry entry : newOffsetSnapshot.entrySet()) { + totalOffsetDelta += entry.getValue() - oldOffsetSnapshot.getOrDefault(entry.getKey(), 0L); + } + + // when enforcing checkpoint is required, we should overwrite the checkpoint if it is different from the old one; + // otherwise, we only overwrite the checkpoint if it is largely different from the old one + return totalOffsetDelta > OFFSET_DELTA_THRESHOLD_FOR_CHECKPOINT; + } + + /** + * @throws StreamsException If the store's changelog does not contain the partition + */ + static void registerStateStores(final Logger log, + final String logPrefix, + final ProcessorTopology topology, + final ProcessorStateManager stateMgr, + final StateDirectory stateDirectory, + final InternalProcessorContext processorContext) { + if (topology.stateStores().isEmpty()) { + return; + } + + final TaskId id = stateMgr.taskId(); + if (!stateDirectory.lock(id)) { + throw new LockException(String.format("%sFailed to lock the state directory for task %s", logPrefix, id)); + } + log.debug("Acquired state directory lock"); + + final boolean storeDirsEmpty = stateDirectory.directoryForTaskIsEmpty(id); + + stateMgr.registerStateStores(topology.stateStores(), processorContext); + log.debug("Registered state stores"); + + // We should only load checkpoint AFTER the corresponding state directory lock has been acquired and + // the state stores have been registered; we should not try to load at the state manager construction time. 
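+ // (i.e. initializeStoreOffsetsFromCheckpoint is invoked below, only after lock(id) has succeeded and
+ // the stores have been registered, never from the state manager's constructor)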
+ // See https://issues.apache.org/jira/browse/KAFKA-8574 + stateMgr.initializeStoreOffsetsFromCheckpoint(storeDirsEmpty); + log.debug("Initialized state stores"); + } + + /** + * @throws ProcessorStateException if there is an error while closing the state manager + */ + static void closeStateManager(final Logger log, + final String logPrefix, + final boolean closeClean, + final boolean eosEnabled, + final ProcessorStateManager stateMgr, + final StateDirectory stateDirectory, + final TaskType taskType) { + // if EOS is enabled, wipe out the whole state store for unclean close since it is now invalid + final boolean wipeStateStore = !closeClean && eosEnabled; + + final TaskId id = stateMgr.taskId(); + log.trace("Closing state manager for {} task {}", taskType, id); + + final AtomicReference firstException = new AtomicReference<>(null); + try { + if (stateDirectory.lock(id)) { + try { + stateMgr.close(); + } catch (final ProcessorStateException e) { + firstException.compareAndSet(null, e); + } finally { + try { + if (wipeStateStore) { + log.debug("Wiping state stores for {} task {}", taskType, id); + // we can just delete the whole dir of the task, including the state store images and the checkpoint files, + // and then we write an empty checkpoint file indicating that the previous close is graceful and we just + // need to re-bootstrap the restoration from the beginning + Utils.delete(stateMgr.baseDir()); + } + } finally { + stateDirectory.unlock(id); + } + } + } + } catch (final IOException e) { + final ProcessorStateException exception = new ProcessorStateException( + String.format("%sFatal error while trying to close the state manager for task %s", logPrefix, id), e + ); + firstException.compareAndSet(null, exception); + } + + final ProcessorStateException exception = firstException.get(); + if (exception != null) { + throw exception; + } + } + + /** + * Parse the task directory name (of the form topicGroupId_partition) and construct the TaskId with the + * optional namedTopology (may be null) + * + * @throws TaskIdFormatException if the taskIdStr is not a valid {@link TaskId} + */ + static TaskId parseTaskDirectoryName(final String taskIdStr, final String namedTopology) { + final int index = taskIdStr.indexOf('_'); + if (index <= 0 || index + 1 >= taskIdStr.length()) { + throw new TaskIdFormatException(taskIdStr); + } + + try { + final int topicGroupId = Integer.parseInt(taskIdStr.substring(0, index)); + final int partition = Integer.parseInt(taskIdStr.substring(index + 1)); + + return new TaskId(topicGroupId, partition, namedTopology); + } catch (final Exception e) { + throw new TaskIdFormatException(taskIdStr); + } + } + + /** + * @return The string representation of the subtopology and partition metadata, ie the task id string without + * the named topology, which defines the innermost task directory name of this task's state + */ + static String toTaskDirString(final TaskId taskId) { + return taskId.subtopology() + "_" + taskId.partition(); + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StateRestoreCallbackAdapter.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StateRestoreCallbackAdapter.java new file mode 100644 index 0000000..fce3f80 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StateRestoreCallbackAdapter.java @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.streams.processor.internals; + +import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.processor.BatchingStateRestoreCallback; +import org.apache.kafka.streams.processor.StateRestoreCallback; + +import java.util.ArrayList; +import java.util.List; +import java.util.Objects; + +public final class StateRestoreCallbackAdapter { + private StateRestoreCallbackAdapter() {} + + public static RecordBatchingStateRestoreCallback adapt(final StateRestoreCallback restoreCallback) { + Objects.requireNonNull(restoreCallback, "stateRestoreCallback must not be null"); + if (restoreCallback instanceof RecordBatchingStateRestoreCallback) { + return (RecordBatchingStateRestoreCallback) restoreCallback; + } else if (restoreCallback instanceof BatchingStateRestoreCallback) { + return records -> { + final List> keyValues = new ArrayList<>(); + for (final ConsumerRecord record : records) { + keyValues.add(new KeyValue<>(record.key(), record.value())); + } + ((BatchingStateRestoreCallback) restoreCallback).restoreAll(keyValues); + }; + } else { + return records -> { + for (final ConsumerRecord record : records) { + restoreCallback.restore(record.key(), record.value()); + } + }; + } + } +} \ No newline at end of file diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StaticTopicNameExtractor.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StaticTopicNameExtractor.java new file mode 100644 index 0000000..f38c547 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StaticTopicNameExtractor.java @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.processor.internals; + +import org.apache.kafka.streams.processor.RecordContext; +import org.apache.kafka.streams.processor.TopicNameExtractor; + +import java.util.Objects; + +/** + * Static topic name extractor + */ +public class StaticTopicNameExtractor implements TopicNameExtractor { + + public final String topicName; + + public StaticTopicNameExtractor(final String topicName) { + this.topicName = topicName; + } + + public String extract(final K key, final V value, final RecordContext recordContext) { + return topicName; + } + + @Override + public String toString() { + return "StaticTopicNameExtractor(" + topicName + ")"; + } + + @Override + public boolean equals(final Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + final StaticTopicNameExtractor that = (StaticTopicNameExtractor) o; + return Objects.equals(topicName, that.topicName); + } + + @Override + public int hashCode() { + return Objects.hash(topicName); + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StoreChangelogReader.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StoreChangelogReader.java new file mode 100644 index 0000000..fdf027f --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StoreChangelogReader.java @@ -0,0 +1,956 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.processor.internals; + +import org.apache.kafka.clients.admin.Admin; +import org.apache.kafka.clients.admin.ListOffsetsOptions; +import org.apache.kafka.clients.admin.ListOffsetsResult; +import org.apache.kafka.clients.admin.OffsetSpec; +import org.apache.kafka.clients.consumer.Consumer; +import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.clients.consumer.ConsumerRecords; +import org.apache.kafka.clients.consumer.InvalidOffsetException; +import org.apache.kafka.common.IsolationLevel; +import org.apache.kafka.common.KafkaException; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.errors.TimeoutException; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.errors.StreamsException; +import org.apache.kafka.streams.errors.TaskCorruptedException; +import org.apache.kafka.streams.processor.StateRestoreListener; +import org.apache.kafka.streams.processor.TaskId; +import org.apache.kafka.streams.processor.internals.ProcessorStateManager.StateStoreMetadata; +import org.slf4j.Logger; + +import java.time.Duration; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.ExecutionException; +import java.util.function.Function; +import java.util.stream.Collectors; + +import static org.apache.kafka.streams.processor.internals.ClientUtils.fetchCommittedOffsets; + +/** + * ChangelogReader is created and maintained by the stream thread and used for both updating standby tasks and + * restoring active tasks. It manages the restore consumer, including its assigned partitions, when to pause / resume + * these partitions, etc. + *
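+ * A simplified, illustrative call sequence from the stream thread (variable names are placeholders,
+ * and the surrounding task bookkeeping is omitted):
+ *   changelogReader.register(changelogPartition, processorStateManager);  // when a task's changelog is registered
+ *   changelogReader.restore(tasks);                                       // once per run loop
+ *   if (changelogReader.completedChangelogs().containsAll(activeChangelogPartitions)) {
+ *       changelogReader.transitToUpdateStandby();                         // only once all active tasks are restored
+ *   }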

                + * The reader also maintains the source of truth for restoration state: only active tasks restoring changelog could + * be completed, while standby tasks updating changelog would always be in restoring state after being initialized. + */ +public class StoreChangelogReader implements ChangelogReader { + private static final long RESTORE_LOG_INTERVAL_MS = 10_000L; + private long lastRestoreLogTime = 0L; + + enum ChangelogState { + // registered but need to be initialized (i.e. set its starting, end, limit offsets) + REGISTERED("REGISTERED"), + + // initialized and restoring + RESTORING("RESTORING", 0), + + // completed restoring (only for active restoring task, standby task should never be completed) + COMPLETED("COMPLETED", 1); + + public final String name; + private final List prevStates; + + ChangelogState(final String name, final Integer... prevStates) { + this.name = name; + this.prevStates = Arrays.asList(prevStates); + } + } + + // NOTE we assume that the changelog reader is used only for either + // 1) restoring active task or + // 2) updating standby task at a given time, + // but never doing both + enum ChangelogReaderState { + ACTIVE_RESTORING("ACTIVE_RESTORING"), + + STANDBY_UPDATING("STANDBY_UPDATING"); + + public final String name; + + ChangelogReaderState(final String name) { + this.name = name; + } + } + + static class ChangelogMetadata { + + private final StateStoreMetadata storeMetadata; + + private final ProcessorStateManager stateManager; + + private ChangelogState changelogState; + + private long totalRestored; + + // the end offset beyond which records should not be applied (yet) to restore the states + // + // for both active restoring tasks and standby updating tasks, it is defined as: + // * log-end-offset if the changelog is not piggy-backed with source topic + // * min(log-end-offset, committed-offset) if the changelog is piggy-backed with source topic + // + // the log-end-offset only needs to be updated once and only need to be for active tasks since for standby + // tasks it would never "complete" based on the end-offset; + // + // the committed-offset needs to be updated periodically for those standby tasks + // + // NOTE we do not book keep the current offset since we leverage state manager as its source of truth + private Long restoreEndOffset; + + // buffer records polled by the restore consumer; + private final List> bufferedRecords; + + // the limit index (exclusive) inside the buffered records beyond which should not be used to restore + // either due to limit offset (standby) or committed end offset (active) + private int bufferedLimitIndex; + + private ChangelogMetadata(final StateStoreMetadata storeMetadata, final ProcessorStateManager stateManager) { + this.changelogState = ChangelogState.REGISTERED; + this.storeMetadata = storeMetadata; + this.stateManager = stateManager; + this.restoreEndOffset = null; + this.totalRestored = 0L; + + this.bufferedRecords = new ArrayList<>(); + this.bufferedLimitIndex = 0; + } + + private void clear() { + this.bufferedRecords.clear(); + } + + private void transitTo(final ChangelogState newState) { + if (newState.prevStates.contains(changelogState.ordinal())) { + changelogState = newState; + } else { + throw new IllegalStateException("Invalid transition from " + changelogState + " to " + newState); + } + } + + @Override + public String toString() { + final Long currentOffset = storeMetadata.offset(); + return changelogState + " " + stateManager.taskType() + + " (currentOffset " + currentOffset + ", 
endOffset " + restoreEndOffset + ")"; + } + + // for testing only below + ChangelogState state() { + return changelogState; + } + + long totalRestored() { + return totalRestored; + } + + Long endOffset() { + return restoreEndOffset; + } + + List> bufferedRecords() { + return bufferedRecords; + } + + int bufferedLimitIndex() { + return bufferedLimitIndex; + } + } + + private final static long DEFAULT_OFFSET_UPDATE_MS = Duration.ofMinutes(5L).toMillis(); + + private ChangelogReaderState state; + + private final Time time; + private final Logger log; + private final Duration pollTime; + private final long updateOffsetIntervalMs; + + // 1) we keep adding partitions to restore consumer whenever new tasks are registered with the state manager; + // 2) we do not unassign partitions when we switch between standbys and actives, we just pause / resume them; + // 3) we only remove an assigned partition when the corresponding task is being removed from the thread. + private final Consumer restoreConsumer; + private final StateRestoreListener stateRestoreListener; + + // source of the truth of the current registered changelogs; + // NOTE a changelog would only be removed when its corresponding task + // is being removed from the thread; otherwise it would stay in this map even after completed + private final Map changelogs; + + // the changelog reader only need the main consumer to get committed offsets for source changelog partitions + // to update offset limit for standby tasks; + private Consumer mainConsumer; + + // the changelog reader needs the admin client to list end offsets + private final Admin adminClient; + + private long lastUpdateOffsetTime; + + void setMainConsumer(final Consumer consumer) { + this.mainConsumer = consumer; + } + + public StoreChangelogReader(final Time time, + final StreamsConfig config, + final LogContext logContext, + final Admin adminClient, + final Consumer restoreConsumer, + final StateRestoreListener stateRestoreListener) { + this.time = time; + this.log = logContext.logger(StoreChangelogReader.class); + this.state = ChangelogReaderState.ACTIVE_RESTORING; + this.adminClient = adminClient; + this.restoreConsumer = restoreConsumer; + this.stateRestoreListener = stateRestoreListener; + + this.pollTime = Duration.ofMillis(config.getLong(StreamsConfig.POLL_MS_CONFIG)); + this.updateOffsetIntervalMs = config.getLong(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG) == Long.MAX_VALUE ? + DEFAULT_OFFSET_UPDATE_MS : config.getLong(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG); + this.lastUpdateOffsetTime = 0L; + + this.changelogs = new HashMap<>(); + } + + private static String recordEndOffset(final Long endOffset) { + return endOffset == null ? 
"UNKNOWN (since it is for standby task)" : endOffset.toString(); + } + + private boolean hasRestoredToEnd(final ChangelogMetadata metadata) { + final Long endOffset = metadata.restoreEndOffset; + if (endOffset == null) { + // end offset is not initialized meaning that it is from a standby task, + // this should never happen since we only call this function for active task in restoring phase + throw new IllegalStateException("End offset for changelog " + metadata + " is unknown when deciding " + + "if it has completed restoration, this should never happen."); + } else if (endOffset == 0) { + // this is a special case, meaning there's nothing to be restored since the changelog has no data + // OR the changelog is a source topic and there's no committed offset + return true; + } else if (metadata.bufferedRecords.isEmpty()) { + // NOTE there are several corner cases that we need to consider: + // 1) the end / committed offset returned from the consumer is the last offset + 1 + // 2) there could be txn markers as the last record if EOS is enabled at the producer + // + // It is possible that: the last record's offset == last txn marker offset - 1 == end / committed offset - 2 + // + // So we make the following decision: + // 1) if all the buffered records have been applied, then we compare the end offset with the + // current consumer's position, which is the "next" record to fetch, bypassing the txn marker already + // 2) if not all the buffered records have been applied, then it means we are restricted by the end offset, + // and the consumer's position is likely already ahead of that end offset. Then we just need to check + // the first record in the remaining buffer and see if that record is no smaller than the end offset. + final TopicPartition partition = metadata.storeMetadata.changelogPartition(); + try { + return restoreConsumer.position(partition) >= endOffset; + } catch (final TimeoutException timeoutException) { + // re-throw to trigger `task.timeout.ms` + throw timeoutException; + } catch (final KafkaException e) { + // this also includes InvalidOffsetException, which should not happen under normal + // execution, hence it is also okay to wrap it as fatal StreamsException + throw new StreamsException("Restore consumer get unexpected error trying to get the position " + + " of " + partition, e); + } + } else { + return metadata.bufferedRecords.get(0).offset() >= endOffset; + } + } + + // Once some new tasks are created, we transit to restore them and pause on the existing standby tasks. It is + // possible that when newly created tasks are created the changelog reader are still restoring existing + // active tasks, and hence this function is idempotent and can be called multiple times. + // + // NOTE: even if the newly created tasks do not need any restoring, we still first transit to this state and then + // immediately transit back -- there's no overhead of transiting back and forth but simplifies the logic a lot. + @Override + public void enforceRestoreActive() { + if (state != ChangelogReaderState.ACTIVE_RESTORING) { + log.debug("Transiting to restore active tasks: {}", changelogs); + lastRestoreLogTime = 0L; + + // pause all partitions that are for standby tasks from the restore consumer + pauseChangelogsFromRestoreConsumer(standbyRestoringChangelogs()); + + state = ChangelogReaderState.ACTIVE_RESTORING; + } + } + + // Only after we've completed restoring all active tasks we'll then move back to resume updating standby tasks. 
+ // This function is NOT idempotent: if it is already in updating standby tasks mode, we should not call it again. + // + // NOTE: we do not clear completed active restoring changelogs or remove partitions from restore consumer either + // upon completing them but only pause the corresponding partitions; the changelog metadata / partitions would only + // be cleared when the corresponding task is being removed from the thread. In other words, the restore consumer + // should contain all changelogs that are RESTORING or COMPLETED + @Override + public void transitToUpdateStandby() { + if (state != ChangelogReaderState.ACTIVE_RESTORING) { + throw new IllegalStateException( + "The changelog reader is not restoring active tasks (is " + state + ") while trying to " + + "transit to update standby tasks: " + changelogs + ); + } + + log.debug("Transiting to update standby tasks: {}", changelogs); + + // resume all standby restoring changelogs from the restore consumer + resumeChangelogsFromRestoreConsumer(standbyRestoringChangelogs()); + + state = ChangelogReaderState.STANDBY_UPDATING; + } + + /** + * Since it is shared for multiple tasks and hence multiple state managers, the registration would take its + * corresponding state manager as well for restoring. + */ + @Override + public void register(final TopicPartition partition, final ProcessorStateManager stateManager) { + final StateStoreMetadata storeMetadata = stateManager.storeMetadata(partition); + if (storeMetadata == null) { + throw new IllegalStateException("Cannot find the corresponding state store metadata for changelog " + + partition); + } + + final ChangelogMetadata changelogMetadata = new ChangelogMetadata(storeMetadata, stateManager); + + // initializing limit offset to 0L for standby changelog to effectively disable any restoration until it is updated + if (stateManager.taskType() == Task.TaskType.STANDBY && stateManager.changelogAsSource(partition)) { + changelogMetadata.restoreEndOffset = 0L; + } + + if (changelogs.putIfAbsent(partition, changelogMetadata) != null) { + throw new IllegalStateException("There is already a changelog registered for " + partition + + ", this should not happen: " + changelogs); + } + } + + private ChangelogMetadata restoringChangelogByPartition(final TopicPartition partition) { + final ChangelogMetadata changelogMetadata = changelogs.get(partition); + if (changelogMetadata == null) { + throw new IllegalStateException("The corresponding changelog restorer for " + partition + + " does not exist, this should not happen."); + } + if (changelogMetadata.changelogState != ChangelogState.RESTORING) { + throw new IllegalStateException("The corresponding changelog restorer for " + partition + + " has already transited to completed state, this should not happen."); + } + + return changelogMetadata; + } + + private Set registeredChangelogs() { + return changelogs.values().stream() + .filter(metadata -> metadata.changelogState == ChangelogState.REGISTERED) + .collect(Collectors.toSet()); + } + + private Set restoringChangelogs() { + return changelogs.values().stream() + .filter(metadata -> metadata.changelogState == ChangelogState.RESTORING) + .map(metadata -> metadata.storeMetadata.changelogPartition()) + .collect(Collectors.toSet()); + } + + private Set activeRestoringChangelogs() { + return changelogs.values().stream() + .filter(metadata -> metadata.changelogState == ChangelogState.RESTORING && + metadata.stateManager.taskType() == Task.TaskType.ACTIVE) + .map(metadata -> 
metadata.storeMetadata.changelogPartition()) + .collect(Collectors.toSet()); + } + + private Set standbyRestoringChangelogs() { + return changelogs.values().stream() + .filter(metadata -> metadata.changelogState == ChangelogState.RESTORING && + metadata.stateManager.taskType() == Task.TaskType.STANDBY) + .map(metadata -> metadata.storeMetadata.changelogPartition()) + .collect(Collectors.toSet()); + } + + private boolean allChangelogsCompleted() { + return changelogs.values().stream() + .allMatch(metadata -> metadata.changelogState == ChangelogState.COMPLETED); + } + + @Override + public Set completedChangelogs() { + return changelogs.values().stream() + .filter(metadata -> metadata.changelogState == ChangelogState.COMPLETED) + .map(metadata -> metadata.storeMetadata.changelogPartition()) + .collect(Collectors.toSet()); + } + + // 1. if there are any registered changelogs that needs initialization, try to initialize them first; + // 2. if all changelogs have finished, return early; + // 3. if there are any restoring changelogs, try to read from the restore consumer and process them. + @Override + public void restore(final Map tasks) { + initializeChangelogs(tasks, registeredChangelogs()); + + if (!activeRestoringChangelogs().isEmpty() && state == ChangelogReaderState.STANDBY_UPDATING) { + throw new IllegalStateException("Should not be in standby updating state if there are still un-completed active changelogs"); + } + + if (allChangelogsCompleted()) { + log.debug("Finished restoring all changelogs {}", changelogs.keySet()); + return; + } + + final Set restoringChangelogs = restoringChangelogs(); + if (!restoringChangelogs.isEmpty()) { + final ConsumerRecords polledRecords; + + try { + // for restoring active and updating standby we may prefer different poll time + // in order to make sure we call the main consumer#poll in time. + // TODO: once we move ChangelogReader to a separate thread this may no longer be a concern + polledRecords = restoreConsumer.poll(state == ChangelogReaderState.STANDBY_UPDATING ? Duration.ZERO : pollTime); + + // TODO (?) If we cannot fetch records during restore, should we trigger `task.timeout.ms` ? + // TODO (?) If we cannot fetch records for standby task, should we trigger `task.timeout.ms` ? + } catch (final InvalidOffsetException e) { + log.warn("Encountered " + e.getClass().getName() + + " fetching records from restore consumer for partitions " + e.partitions() + ", it is likely that " + + "the consumer's position has fallen out of the topic partition offset range because the topic was " + + "truncated or compacted on the broker, marking the corresponding tasks as corrupted and re-initializing" + + " it later.", e); + + final Set corruptedTasks = new HashSet<>(); + e.partitions().forEach(partition -> corruptedTasks.add(changelogs.get(partition).stateManager.taskId())); + throw new TaskCorruptedException(corruptedTasks, e); + } catch (final KafkaException e) { + throw new StreamsException("Restore consumer get unexpected error polling records.", e); + } + + for (final TopicPartition partition : polledRecords.partitions()) { + bufferChangelogRecords(restoringChangelogByPartition(partition), polledRecords.records(partition)); + } + + for (final TopicPartition partition : restoringChangelogs) { + // even if some partition do not have any accumulated data, we still trigger + // restoring since some changelog may not need to restore any at all, and the + // restore to end check needs to be executed still. 
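+ // restoreChangelog(...) returns true only when it made progress (applied buffered records and/or
+ // completed the changelog), which is what allows the task timeout to be cleared below.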
+ // TODO: we always try to restore as a batch when some records are accumulated, which may result in + // small batches; this can be optimized in the future, e.g. wait longer for larger batches. + final TaskId taskId = changelogs.get(partition).stateManager.taskId(); + try { + if (restoreChangelog(changelogs.get(partition))) { + tasks.get(taskId).clearTaskTimeout(); + } + } catch (final TimeoutException timeoutException) { + tasks.get(taskId).maybeInitTaskTimeoutOrThrow( + time.milliseconds(), + timeoutException + ); + } + } + + maybeUpdateLimitOffsetsForStandbyChangelogs(tasks); + + maybeLogRestorationProgress(); + } + } + + private void maybeLogRestorationProgress() { + if (state == ChangelogReaderState.ACTIVE_RESTORING) { + if (time.milliseconds() - lastRestoreLogTime > RESTORE_LOG_INTERVAL_MS) { + final Set topicPartitions = activeRestoringChangelogs(); + if (!topicPartitions.isEmpty()) { + final StringBuilder builder = new StringBuilder().append("Restoration in progress for ") + .append(topicPartitions.size()) + .append(" partitions."); + for (final TopicPartition partition : topicPartitions) { + final ChangelogMetadata changelogMetadata = restoringChangelogByPartition(partition); + builder.append(" {") + .append(partition) + .append(": ") + .append("position=") + .append(getPositionString(partition, changelogMetadata)) + .append(", end=") + .append(changelogMetadata.restoreEndOffset) + .append(", totalRestored=") + .append(changelogMetadata.totalRestored) + .append("}"); + } + log.info(builder.toString()); + lastRestoreLogTime = time.milliseconds(); + } + } + } else { + lastRestoreLogTime = 0L; + } + } + + private static String getPositionString(final TopicPartition partition, + final ChangelogMetadata changelogMetadata) { + final ProcessorStateManager stateManager = changelogMetadata.stateManager; + final Long offsets = stateManager.changelogOffsets().get(partition); + return offsets == null ? 
"unknown" : String.valueOf(offsets); + } + + private void maybeUpdateLimitOffsetsForStandbyChangelogs(final Map tasks) { + // we only consider updating the limit offset for standbys if we are not restoring active tasks + if (state == ChangelogReaderState.STANDBY_UPDATING && + updateOffsetIntervalMs < time.milliseconds() - lastUpdateOffsetTime) { + + // when the interval has elapsed we should try to update the limit offset for standbys reading from + // a source changelog with the new committed offset, unless there are no buffered records since + // we only need the limit when processing new records + // for other changelog partitions we do not need to update limit offset at all since we never need to + // check when it completes based on limit offset anyways: the end offset would keep increasing and the + // standby never need to stop + final Set changelogsWithLimitOffsets = changelogs.entrySet().stream() + .filter(entry -> entry.getValue().stateManager.taskType() == Task.TaskType.STANDBY && + entry.getValue().stateManager.changelogAsSource(entry.getKey())) + .map(Map.Entry::getKey).collect(Collectors.toSet()); + + for (final TopicPartition partition : changelogsWithLimitOffsets) { + if (!changelogs.get(partition).bufferedRecords().isEmpty()) { + updateLimitOffsetsForStandbyChangelogs(committedOffsetForChangelogs(tasks, changelogsWithLimitOffsets)); + break; + } + } + } + } + + private void bufferChangelogRecords(final ChangelogMetadata changelogMetadata, final List> records) { + // update the buffered records and limit index with the fetched records + for (final ConsumerRecord record : records) { + // filter polled records for null-keys and also possibly update buffer limit index + if (record.key() == null) { + log.warn("Read changelog record with null key from changelog {} at offset {}, " + + "skipping it for restoration", changelogMetadata.storeMetadata.changelogPartition(), record.offset()); + } else { + changelogMetadata.bufferedRecords.add(record); + final long offset = record.offset(); + if (changelogMetadata.restoreEndOffset == null || offset < changelogMetadata.restoreEndOffset) { + changelogMetadata.bufferedLimitIndex = changelogMetadata.bufferedRecords.size(); + } + } + } + } + + /** + * restore a changelog with its buffered records if there's any; for active changelogs also check if + * it has completed the restoration and can transit to COMPLETED state and trigger restore callbacks + */ + private boolean restoreChangelog(final ChangelogMetadata changelogMetadata) { + final ProcessorStateManager stateManager = changelogMetadata.stateManager; + final StateStoreMetadata storeMetadata = changelogMetadata.storeMetadata; + final TopicPartition partition = storeMetadata.changelogPartition(); + final String storeName = storeMetadata.store().name(); + final int numRecords = changelogMetadata.bufferedLimitIndex; + + boolean madeProgress = false; + + if (numRecords != 0) { + madeProgress = true; + + final List> records = changelogMetadata.bufferedRecords.subList(0, numRecords); + stateManager.restore(storeMetadata, records); + + // NOTE here we use removeRange of ArrayList in order to achieve efficiency with range shifting, + // otherwise one-at-a-time removal or addition would be very costly; if all records are restored + // then we can further optimize to save the array-shift but just set array elements to null + if (numRecords < changelogMetadata.bufferedRecords.size()) { + records.clear(); + } else { + changelogMetadata.bufferedRecords.clear(); + } + + final Long currentOffset = 
storeMetadata.offset(); + log.trace("Restored {} records from changelog {} to store {}, end offset is {}, current offset is {}", + partition, storeName, numRecords, recordEndOffset(changelogMetadata.restoreEndOffset), currentOffset); + + changelogMetadata.bufferedLimitIndex = 0; + changelogMetadata.totalRestored += numRecords; + + // do not trigger restore listener if we are processing standby tasks + if (changelogMetadata.stateManager.taskType() == Task.TaskType.ACTIVE) { + try { + stateRestoreListener.onBatchRestored(partition, storeName, currentOffset, numRecords); + } catch (final Exception e) { + throw new StreamsException("State restore listener failed on batch restored", e); + } + } + } + + // we should check even if there's nothing restored, but do not check completed if we are processing standby tasks + if (changelogMetadata.stateManager.taskType() == Task.TaskType.ACTIVE && hasRestoredToEnd(changelogMetadata)) { + madeProgress = true; + + log.info("Finished restoring changelog {} to store {} with a total number of {} records", + partition, storeName, changelogMetadata.totalRestored); + + changelogMetadata.transitTo(ChangelogState.COMPLETED); + pauseChangelogsFromRestoreConsumer(Collections.singleton(partition)); + + try { + stateRestoreListener.onRestoreEnd(partition, storeName, changelogMetadata.totalRestored); + } catch (final Exception e) { + throw new StreamsException("State restore listener failed on restore completed", e); + } + } + + return madeProgress; + } + + private Set getTasksFromPartitions(final Map tasks, + final Set partitions) { + final Set result = new HashSet<>(); + for (final TopicPartition partition : partitions) { + result.add(tasks.get(changelogs.get(partition).stateManager.taskId())); + } + return result; + } + + private void clearTaskTimeout(final Set tasks) { + tasks.forEach(Task::clearTaskTimeout); + } + + private void maybeInitTaskTimeoutOrThrow(final Set tasks, + final Exception cause) { + final long now = time.milliseconds(); + tasks.forEach(t -> t.maybeInitTaskTimeoutOrThrow(now, cause)); + } + + private Map committedOffsetForChangelogs(final Map tasks, + final Set partitions) { + final Map committedOffsets; + try { + committedOffsets = fetchCommittedOffsets(partitions, mainConsumer); + clearTaskTimeout(getTasksFromPartitions(tasks, partitions)); + } catch (final TimeoutException timeoutException) { + log.debug("Could not fetch all committed offsets for {}, will retry in the next run loop", partitions); + maybeInitTaskTimeoutOrThrow(getTasksFromPartitions(tasks, partitions), timeoutException); + return Collections.emptyMap(); + } + lastUpdateOffsetTime = time.milliseconds(); + return committedOffsets; + } + + private Map endOffsetForChangelogs(final Map tasks, + final Set partitions) { + if (partitions.isEmpty()) { + return Collections.emptyMap(); + } + + try { + final ListOffsetsResult result = adminClient.listOffsets( + partitions.stream().collect(Collectors.toMap(Function.identity(), tp -> OffsetSpec.latest())), + new ListOffsetsOptions(IsolationLevel.READ_UNCOMMITTED) + ); + + final Map resultPerPartition = result.all().get(); + clearTaskTimeout(getTasksFromPartitions(tasks, partitions)); + + return resultPerPartition.entrySet().stream().collect( + Collectors.toMap(Map.Entry::getKey, entry -> entry.getValue().offset()) + ); + } catch (final TimeoutException | InterruptedException | ExecutionException retriableException) { + log.debug("Could not fetch all end offsets for {}, will retry in the next run loop", partitions); + 
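+ // Kick off (or escalate) the per-task timeout via maybeInitTaskTimeoutOrThrow and fall back to an
+ // empty result, so the end-offset lookup is simply retried on the next run loop.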
maybeInitTaskTimeoutOrThrow(getTasksFromPartitions(tasks, partitions), retriableException); + return Collections.emptyMap(); + } catch (final KafkaException e) { + throw new StreamsException(String.format("Failed to retrieve end offsets for %s", partitions), e); + } + } + + private void updateLimitOffsetsForStandbyChangelogs(final Map committedOffsets) { + for (final ChangelogMetadata metadata : changelogs.values()) { + final TopicPartition partition = metadata.storeMetadata.changelogPartition(); + if (metadata.stateManager.taskType() == Task.TaskType.STANDBY && + metadata.stateManager.changelogAsSource(partition) && + committedOffsets.containsKey(partition)) { + + final Long newLimit = committedOffsets.get(partition); + final Long previousLimit = metadata.restoreEndOffset; + + if (previousLimit != null && previousLimit > newLimit) { + throw new IllegalStateException("Offset limit should monotonically increase, but was reduced for partition " + + partition + ". New limit: " + newLimit + ". Previous limit: " + previousLimit); + } + + metadata.restoreEndOffset = newLimit; + + // update the limit index for buffered records + while (metadata.bufferedLimitIndex < metadata.bufferedRecords.size() && + metadata.bufferedRecords.get(metadata.bufferedLimitIndex).offset() < metadata.restoreEndOffset) + metadata.bufferedLimitIndex++; + } + } + } + + private void initializeChangelogs(final Map tasks, + final Set newPartitionsToRestore) { + if (newPartitionsToRestore.isEmpty()) { + return; + } + + // for active changelogs, we need to find their end offset before transit to restoring + // if the changelog is on source topic, then its end offset should be the minimum of + // its committed offset and its end offset; for standby tasks that use source topics + // as changelogs, we want to initialize their limit offsets as committed offsets as well + final Set newPartitionsToFindEndOffset = new HashSet<>(); + final Set newPartitionsToFindCommittedOffset = new HashSet<>(); + + for (final ChangelogMetadata metadata : newPartitionsToRestore) { + final TopicPartition partition = metadata.storeMetadata.changelogPartition(); + + // TODO K9113: when TaskType.GLOBAL is added we need to modify this + if (metadata.stateManager.taskType() == Task.TaskType.ACTIVE) { + newPartitionsToFindEndOffset.add(partition); + } + + if (metadata.stateManager.changelogAsSource(partition)) { + newPartitionsToFindCommittedOffset.add(partition); + } + } + + // NOTE we assume that all requested partitions will be included in the returned map for both end/committed + // offsets, i.e., it would not return partial result and would timeout if some of the results cannot be found + final Map endOffsets = endOffsetForChangelogs(tasks, newPartitionsToFindEndOffset); + final Map committedOffsets = committedOffsetForChangelogs(tasks, newPartitionsToFindCommittedOffset); + + for (final TopicPartition partition : newPartitionsToFindEndOffset) { + final ChangelogMetadata changelogMetadata = changelogs.get(partition); + final Long endOffset = endOffsets.get(partition); + final Long committedOffset = newPartitionsToFindCommittedOffset.contains(partition) ? + committedOffsets.get(partition) : Long.valueOf(Long.MAX_VALUE); + + if (endOffset != null && committedOffset != null) { + if (changelogMetadata.restoreEndOffset != null) { + throw new IllegalStateException("End offset for " + partition + + " should only be initialized once. 
Existing value: " + changelogMetadata.restoreEndOffset + + ", new value: (" + endOffset + ", " + committedOffset + ")"); + } + + changelogMetadata.restoreEndOffset = Math.min(endOffset, committedOffset); + + log.debug("End offset for changelog {} initialized as {}.", partition, changelogMetadata.restoreEndOffset); + } else { + if (!newPartitionsToRestore.remove(changelogMetadata)) { + throw new IllegalStateException("New changelogs to restore " + newPartitionsToRestore + + " does not contain the one looking for end offset: " + partition + ", this should not happen."); + } + + log.info("End offset for changelog {} cannot be found; will retry in the next time.", partition); + } + } + + // try initialize limit offsets for standby tasks for the first time + if (!committedOffsets.isEmpty()) { + updateLimitOffsetsForStandbyChangelogs(committedOffsets); + } + + // add new partitions to the restore consumer and transit them to restoring state + addChangelogsToRestoreConsumer(newPartitionsToRestore.stream().map(metadata -> metadata.storeMetadata.changelogPartition()) + .collect(Collectors.toSet())); + + newPartitionsToRestore.forEach(metadata -> metadata.transitTo(ChangelogState.RESTORING)); + + // if it is in the active restoring mode, we immediately pause those standby changelogs + // here we just blindly pause all (including the existing and newly added) + if (state == ChangelogReaderState.ACTIVE_RESTORING) { + pauseChangelogsFromRestoreConsumer(standbyRestoringChangelogs()); + } + + // prepare newly added partitions of the restore consumer by setting their starting position + prepareChangelogs(newPartitionsToRestore); + } + + private void addChangelogsToRestoreConsumer(final Set partitions) { + final Set assignment = new HashSet<>(restoreConsumer.assignment()); + + // the current assignment should not contain any of the new partitions + if (assignment.removeAll(partitions)) { + throw new IllegalStateException("The current assignment " + restoreConsumer.assignment() + " " + + "already contains some of the new partitions " + partitions); + } + assignment.addAll(partitions); + restoreConsumer.assign(assignment); + + log.debug("Added partitions {} to the restore consumer, current assignment is {}", partitions, assignment); + } + + private void pauseChangelogsFromRestoreConsumer(final Collection partitions) { + final Set assignment = new HashSet<>(restoreConsumer.assignment()); + + // the current assignment should contain all the partitions to pause + if (!assignment.containsAll(partitions)) { + throw new IllegalStateException("The current assignment " + assignment + " " + + "does not contain some of the partitions " + partitions + " for pausing."); + } + restoreConsumer.pause(partitions); + + log.debug("Paused partitions {} from the restore consumer", partitions); + } + + private void removeChangelogsFromRestoreConsumer(final Collection partitions) { + final Set assignment = new HashSet<>(restoreConsumer.assignment()); + + // the current assignment should contain all the partitions to remove + if (!assignment.containsAll(partitions)) { + throw new IllegalStateException("The current assignment " + assignment + " " + + "does not contain some of the partitions " + partitions + " for removing."); + } + assignment.removeAll(partitions); + restoreConsumer.assign(assignment); + } + + private void resumeChangelogsFromRestoreConsumer(final Collection partitions) { + final Set assignment = new HashSet<>(restoreConsumer.assignment()); + + // the current assignment should contain all the partitions to resume + 
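+ // (resume, like pause, only toggles fetching for these partitions; the assignment itself is left
+ // untouched and is only shrunk via removeChangelogsFromRestoreConsumer when a task is removed)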
if (!assignment.containsAll(partitions)) { + throw new IllegalStateException("The current assignment " + assignment + " " + + "does not contain some of the partitions " + partitions + " for resuming."); + } + restoreConsumer.resume(partitions); + + log.debug("Resumed partitions {} from the restore consumer", partitions); + } + + private void prepareChangelogs(final Set newPartitionsToRestore) { + // separate those who do not have the current offset loaded from checkpoint + final Set newPartitionsWithoutStartOffset = new HashSet<>(); + + for (final ChangelogMetadata changelogMetadata : newPartitionsToRestore) { + final StateStoreMetadata storeMetadata = changelogMetadata.storeMetadata; + final TopicPartition partition = storeMetadata.changelogPartition(); + final Long currentOffset = storeMetadata.offset(); + final Long endOffset = changelogs.get(partition).restoreEndOffset; + + if (currentOffset != null) { + // the current offset is the offset of the last record, so we should set the position + // as that offset + 1 as the "next" record to fetch; seek is not a blocking call so + // there's nothing to capture + restoreConsumer.seek(partition, currentOffset + 1); + + log.debug("Start restoring changelog partition {} from current offset {} to end offset {}.", + partition, currentOffset, recordEndOffset(endOffset)); + } else { + log.debug("Start restoring changelog partition {} from the beginning offset to end offset {} " + + "since we cannot find current offset.", partition, recordEndOffset(endOffset)); + + newPartitionsWithoutStartOffset.add(partition); + } + } + + // optimization: batch all seek-to-beginning offsets in a single request + // seek is not a blocking call so there's nothing to capture + if (!newPartitionsWithoutStartOffset.isEmpty()) { + restoreConsumer.seekToBeginning(newPartitionsWithoutStartOffset); + } + + // do not trigger restore listener if we are processing standby tasks + for (final ChangelogMetadata changelogMetadata : newPartitionsToRestore) { + if (changelogMetadata.stateManager.taskType() == Task.TaskType.ACTIVE) { + final StateStoreMetadata storeMetadata = changelogMetadata.storeMetadata; + final TopicPartition partition = storeMetadata.changelogPartition(); + final String storeName = storeMetadata.store().name(); + + long startOffset = 0L; + try { + startOffset = restoreConsumer.position(partition); + } catch (final TimeoutException swallow) { + // if we cannot find the starting position at the beginning, just use the default 0L + } catch (final KafkaException e) { + // this also includes InvalidOffsetException, which should not happen under normal + // execution, hence it is also okay to wrap it as fatal StreamsException + throw new StreamsException("Restore consumer get unexpected error trying to get the position " + + " of " + partition, e); + } + + try { + stateRestoreListener.onRestoreStart(partition, storeName, startOffset, changelogMetadata.restoreEndOffset); + } catch (final Exception e) { + throw new StreamsException("State restore listener failed on batch restored", e); + } + } + } + } + + @Override + public void unregister(final Collection revokedChangelogs) { + // Only changelogs that are initialized have been added to the restore consumer's assignment + final List revokedInitializedChangelogs = new ArrayList<>(); + + for (final TopicPartition partition : revokedChangelogs) { + final ChangelogMetadata changelogMetadata = changelogs.remove(partition); + if (changelogMetadata != null) { + if (!changelogMetadata.state().equals(ChangelogState.REGISTERED)) 
{ + revokedInitializedChangelogs.add(partition); + } + + changelogMetadata.clear(); + } else { + log.debug("Changelog partition {} could not be found," + + " it could be already cleaned up during the handling" + + " of task corruption and never restore again", partition); + } + } + + removeChangelogsFromRestoreConsumer(revokedInitializedChangelogs); + } + + @Override + public void clear() { + for (final ChangelogMetadata changelogMetadata : changelogs.values()) { + changelogMetadata.clear(); + } + changelogs.clear(); + + try { + restoreConsumer.unsubscribe(); + } catch (final KafkaException e) { + throw new StreamsException("Restore consumer get unexpected error unsubscribing", e); + } + } + + @Override + public boolean isEmpty() { + return changelogs.isEmpty(); + } + + @Override + public String toString() { + return "StoreChangelogReader: " + changelogs + "\n"; + } + + // for testing only + ChangelogMetadata changelogMetadata(final TopicPartition partition) { + return changelogs.get(partition); + } + + ChangelogReaderState state() { + return state; + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StoreToProcessorContextAdapter.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StoreToProcessorContextAdapter.java new file mode 100644 index 0000000..1cc54b5 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StoreToProcessorContextAdapter.java @@ -0,0 +1,155 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.processor.internals; + +import org.apache.kafka.common.header.Headers; +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.streams.StreamsMetrics; +import org.apache.kafka.streams.processor.Cancellable; +import org.apache.kafka.streams.processor.ProcessorContext; +import org.apache.kafka.streams.processor.PunctuationType; +import org.apache.kafka.streams.processor.Punctuator; +import org.apache.kafka.streams.processor.StateRestoreCallback; +import org.apache.kafka.streams.processor.StateStore; +import org.apache.kafka.streams.processor.StateStoreContext; +import org.apache.kafka.streams.processor.TaskId; +import org.apache.kafka.streams.processor.To; + +import java.io.File; +import java.time.Duration; +import java.util.Map; + +public final class StoreToProcessorContextAdapter implements ProcessorContext { + private final StateStoreContext delegate; + + public static ProcessorContext adapt(final StateStoreContext delegate) { + if (delegate instanceof ProcessorContext) { + return (ProcessorContext) delegate; + } else { + return new StoreToProcessorContextAdapter(delegate); + } + } + + private StoreToProcessorContextAdapter(final StateStoreContext delegate) { + this.delegate = delegate; + } + + @Override + public String applicationId() { + return delegate.applicationId(); + } + + @Override + public TaskId taskId() { + return delegate.taskId(); + } + + @Override + public Serde keySerde() { + return delegate.keySerde(); + } + + @Override + public Serde valueSerde() { + return delegate.valueSerde(); + } + + @Override + public File stateDir() { + return delegate.stateDir(); + } + + @Override + public StreamsMetrics metrics() { + return delegate.metrics(); + } + + @Override + public void register(final StateStore store, final StateRestoreCallback stateRestoreCallback) { + delegate.register(store, stateRestoreCallback); + } + + @Override + public S getStateStore(final String name) { + throw new UnsupportedOperationException("StateStores can't access getStateStore."); + } + + @Override + public Cancellable schedule(final Duration interval, final PunctuationType type, final Punctuator callback) { + throw new UnsupportedOperationException("StateStores can't access schedule."); + } + + @Override + public void forward(final K key, final V value) { + throw new UnsupportedOperationException("StateStores can't access forward."); + } + + @Override + public void forward(final K key, final V value, final To to) { + throw new UnsupportedOperationException("StateStores can't access forward."); + } + + @Override + public void commit() { + throw new UnsupportedOperationException("StateStores can't access commit."); + } + + @Override + public String topic() { + throw new UnsupportedOperationException("StateStores can't access topic."); + } + + @Override + public int partition() { + throw new UnsupportedOperationException("StateStores can't access partition."); + } + + @Override + public long offset() { + throw new UnsupportedOperationException("StateStores can't access offset."); + } + + @Override + public Headers headers() { + throw new UnsupportedOperationException("StateStores can't access headers."); + } + + @Override + public long timestamp() { + throw new UnsupportedOperationException("StateStores can't access timestamp."); + } + + @Override + public Map appConfigs() { + return delegate.appConfigs(); + } + + @Override + public Map appConfigsWithPrefix(final String prefix) { + return delegate.appConfigsWithPrefix(prefix); + } + + @Override + 
public long currentSystemTimeMs() { + throw new UnsupportedOperationException("StateStores can't access system time."); + } + + @Override + public long currentStreamTimeMs() { + throw new UnsupportedOperationException("StateStores can't access stream time."); + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamTask.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamTask.java new file mode 100644 index 0000000..7c85869 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamTask.java @@ -0,0 +1,1288 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.processor.internals; + +import org.apache.kafka.clients.consumer.Consumer; +import org.apache.kafka.clients.consumer.ConsumerConfig; +import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.clients.consumer.OffsetAndMetadata; +import org.apache.kafka.common.KafkaException; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.errors.TimeoutException; +import org.apache.kafka.common.header.internals.RecordHeaders; +import org.apache.kafka.common.metrics.Sensor; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.errors.DeserializationExceptionHandler; +import org.apache.kafka.streams.errors.LockException; +import org.apache.kafka.streams.errors.StreamsException; +import org.apache.kafka.streams.errors.TaskCorruptedException; +import org.apache.kafka.streams.errors.TaskMigratedException; +import org.apache.kafka.streams.errors.TopologyException; +import org.apache.kafka.streams.processor.Cancellable; +import org.apache.kafka.streams.processor.PunctuationType; +import org.apache.kafka.streams.processor.Punctuator; +import org.apache.kafka.streams.processor.TaskId; +import org.apache.kafka.streams.processor.TimestampExtractor; +import org.apache.kafka.streams.processor.api.Record; +import org.apache.kafka.streams.processor.internals.metrics.ProcessorNodeMetrics; +import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl; +import org.apache.kafka.streams.processor.internals.metrics.TaskMetrics; +import org.apache.kafka.streams.processor.internals.metrics.ThreadMetrics; +import org.apache.kafka.streams.state.internals.ThreadCache; + +import java.io.IOException; +import java.io.PrintWriter; +import java.io.StringWriter; +import java.nio.ByteBuffer; +import java.util.Base64; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import 
java.util.function.Function; +import java.util.stream.Collectors; + +import static java.util.Collections.singleton; +import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.maybeMeasureLatency; + +/** + * A StreamTask is associated with a {@link PartitionGroup}, and is assigned to a StreamThread for processing. + */ +public class StreamTask extends AbstractTask implements ProcessorNodePunctuator, Task { + + // visible for testing + static final byte LATEST_MAGIC_BYTE = 1; + + private final Time time; + private final Consumer mainConsumer; + + // we want to abstract eos logic out of StreamTask, however + // there's still an optimization that requires this info to be + // leaked into this class, which is to checkpoint after committing if EOS is not enabled. + private final boolean eosEnabled; + + private final int maxBufferedSize; + private final PartitionGroup partitionGroup; + private final RecordCollector recordCollector; + private final PartitionGroup.RecordInfo recordInfo; + private final Map consumedOffsets; + private final Map committedOffsets; + private final Map highWatermark; + private final Set resetOffsetsForPartitions; + private Optional timeCurrentIdlingStarted; + private final PunctuationQueue streamTimePunctuationQueue; + private final PunctuationQueue systemTimePunctuationQueue; + private final StreamsMetricsImpl streamsMetrics; + + private long processTimeMs = 0L; + + private final Sensor closeTaskSensor; + private final Sensor processRatioSensor; + private final Sensor processLatencySensor; + private final Sensor punctuateLatencySensor; + private final Sensor bufferedRecordsSensor; + private final Map e2eLatencySensors = new HashMap<>(); + + @SuppressWarnings("rawtypes") + private final InternalProcessorContext processorContext; + + private final RecordQueueCreator recordQueueCreator; + + private StampedRecord record; + private boolean commitNeeded = false; + private boolean commitRequested = false; + private boolean hasPendingTxCommit = false; + + @SuppressWarnings("rawtypes") + public StreamTask(final TaskId id, + final Set inputPartitions, + final ProcessorTopology topology, + final Consumer mainConsumer, + final StreamsConfig config, + final StreamsMetricsImpl streamsMetrics, + final StateDirectory stateDirectory, + final ThreadCache cache, + final Time time, + final ProcessorStateManager stateMgr, + final RecordCollector recordCollector, + final InternalProcessorContext processorContext, + final LogContext logContext) { + super( + id, + topology, + stateDirectory, + stateMgr, + inputPartitions, + config.getLong(StreamsConfig.TASK_TIMEOUT_MS_CONFIG), + "task", + StreamTask.class + ); + this.mainConsumer = mainConsumer; + + this.processorContext = processorContext; + processorContext.transitionToActive(this, recordCollector, cache); + + this.time = time; + this.recordCollector = recordCollector; + eosEnabled = StreamThread.eosEnabled(config); + + final String threadId = Thread.currentThread().getName(); + this.streamsMetrics = streamsMetrics; + closeTaskSensor = ThreadMetrics.closeTaskSensor(threadId, streamsMetrics); + final String taskId = id.toString(); + processRatioSensor = TaskMetrics.activeProcessRatioSensor(threadId, taskId, streamsMetrics); + processLatencySensor = TaskMetrics.processLatencySensor(threadId, taskId, streamsMetrics); + punctuateLatencySensor = TaskMetrics.punctuateSensor(threadId, taskId, streamsMetrics); + bufferedRecordsSensor = TaskMetrics.activeBufferedRecordsSensor(threadId, taskId, streamsMetrics); + + for 
(final String terminalNodeName : topology.terminalNodes()) { + e2eLatencySensors.put( + terminalNodeName, + ProcessorNodeMetrics.e2ELatencySensor(threadId, taskId, terminalNodeName, streamsMetrics) + ); + } + + for (final ProcessorNode sourceNode : topology.sources()) { + final String sourceNodeName = sourceNode.name(); + e2eLatencySensors.put( + sourceNodeName, + ProcessorNodeMetrics.e2ELatencySensor(threadId, taskId, sourceNodeName, streamsMetrics) + ); + } + + streamTimePunctuationQueue = new PunctuationQueue(); + systemTimePunctuationQueue = new PunctuationQueue(); + maxBufferedSize = config.getInt(StreamsConfig.BUFFERED_RECORDS_PER_PARTITION_CONFIG); + + // initialize the consumed and committed offset cache + consumedOffsets = new HashMap<>(); + resetOffsetsForPartitions = new HashSet<>(); + + recordQueueCreator = new RecordQueueCreator(this.logContext, config.defaultTimestampExtractor(), config.defaultDeserializationExceptionHandler()); + + recordInfo = new PartitionGroup.RecordInfo(); + + final Sensor enforcedProcessingSensor; + enforcedProcessingSensor = TaskMetrics.enforcedProcessingSensor(threadId, taskId, streamsMetrics); + final long maxTaskIdleMs = config.getLong(StreamsConfig.MAX_TASK_IDLE_MS_CONFIG); + partitionGroup = new PartitionGroup( + logContext, + createPartitionQueues(), + mainConsumer::currentLag, + TaskMetrics.recordLatenessSensor(threadId, taskId, streamsMetrics), + enforcedProcessingSensor, + maxTaskIdleMs + ); + + stateMgr.registerGlobalStateStores(topology.globalStateStores()); + committedOffsets = new HashMap<>(); + highWatermark = new HashMap<>(); + for (final TopicPartition topicPartition: inputPartitions) { + committedOffsets.put(topicPartition, -1L); + highWatermark.put(topicPartition, -1L); + } + timeCurrentIdlingStarted = Optional.empty(); + } + + // create queues for each assigned partition and associate them + // to corresponding source nodes in the processor topology + private Map createPartitionQueues() { + final Map partitionQueues = new HashMap<>(); + for (final TopicPartition partition : inputPartitions()) { + partitionQueues.put(partition, recordQueueCreator.createQueue(partition)); + } + return partitionQueues; + } + + @Override + public boolean isActive() { + return true; + } + + /** + * @throws LockException could happen when multi-threads within the single instance, could retry + * @throws TimeoutException if initializing record collector timed out + * @throws StreamsException fatal error, should close the thread + */ + @Override + public void initializeIfNeeded() { + if (state() == State.CREATED) { + recordCollector.initialize(); + + StateManagerUtil.registerStateStores(log, logPrefix, topology, stateMgr, stateDirectory, processorContext); + + // without EOS the checkpoint file would not be deleted after loading, and + // with EOS we would not checkpoint ever during running state anyways. 
+ // therefore we can initialize the snapshot as empty so that we would checkpoint right after loading + offsetSnapshotSinceLastFlush = Collections.emptyMap(); + + transitionTo(State.RESTORING); + + log.info("Initialized"); + } + } + + public void addPartitionsForOffsetReset(final Set partitionsForOffsetReset) { + mainConsumer.pause(partitionsForOffsetReset); + resetOffsetsForPartitions.addAll(partitionsForOffsetReset); + } + + /** + * @throws TimeoutException if fetching committed offsets timed out + */ + @Override + public void completeRestoration(final java.util.function.Consumer> offsetResetter) { + switch (state()) { + case RUNNING: + return; + + case RESTORING: + resetOffsetsIfNeededAndInitializeMetadata(offsetResetter); + initializeTopology(); + processorContext.initialize(); + + transitionTo(State.RUNNING); + + log.info("Restored and ready to run"); + + break; + + case CREATED: + case SUSPENDED: + case CLOSED: + throw new IllegalStateException("Illegal state " + state() + " while completing restoration for active task " + id); + + default: + throw new IllegalStateException("Unknown state " + state() + " while completing restoration for active task " + id); + } + } + + @Override + public void suspend() { + switch (state()) { + case CREATED: + transitToSuspend(); + break; + + case RESTORING: + transitToSuspend(); + break; + + case RUNNING: + try { + // use try-catch to ensure state transition to SUSPENDED even if user code throws in `Processor#close()` + closeTopology(); + + // we must clear the buffered records when suspending because upon resuming the consumer would + // re-fetch those records starting from the committed position + partitionGroup.clear(); + } finally { + transitToSuspend(); + log.info("Suspended running"); + } + + break; + + case SUSPENDED: + log.info("Skip suspending since state is {}", state()); + + break; + + case CLOSED: + throw new IllegalStateException("Illegal state " + state() + " while suspending active task " + id); + + default: + throw new IllegalStateException("Unknown state " + state() + " while suspending active task " + id); + } + } + + @SuppressWarnings("unchecked") + private void closeTopology() { + log.trace("Closing processor topology"); + + // close the processors + // make sure close() is called for each node even when there is a RuntimeException + RuntimeException exception = null; + for (final ProcessorNode node : topology.processors()) { + processorContext.setCurrentNode(node); + try { + node.close(); + } catch (final RuntimeException e) { + exception = e; + } finally { + processorContext.setCurrentNode(null); + } + } + + if (exception != null) { + throw exception; + } + } + + /** + *

                +     * <pre>
                +     * - resume the task
                +     * </pre>
                + */ + @Override + public void resume() { + switch (state()) { + case CREATED: + case RUNNING: + case RESTORING: + // no need to do anything, just let them continue running / restoring / closing + log.trace("Skip resuming since state is {}", state()); + break; + + case SUSPENDED: + // just transit the state without any logical changes: suspended and restoring states + // are not actually any different for inner modules + + // Deleting checkpoint file before transition to RESTORING state (KAFKA-10362) + try { + stateMgr.deleteCheckPointFileIfEOSEnabled(); + log.debug("Deleted check point file upon resuming with EOS enabled"); + } catch (final IOException ioe) { + log.error("Encountered error while deleting the checkpoint file due to this exception", ioe); + } + + transitionTo(State.RESTORING); + log.info("Resumed to restoring state"); + + break; + + case CLOSED: + throw new IllegalStateException("Illegal state " + state() + " while resuming active task " + id); + + default: + throw new IllegalStateException("Unknown state " + state() + " while resuming active task " + id); + } + timeCurrentIdlingStarted = Optional.empty(); + } + + /** + * @throws StreamsException fatal error that should cause the thread to die + * @throws TaskMigratedException recoverable error that would cause the task to be removed + * @return offsets that should be committed for this task + */ + @Override + public Map prepareCommit() { + switch (state()) { + case CREATED: + case RESTORING: + case RUNNING: + case SUSPENDED: + // the commitNeeded flag just indicates whether we have reached RUNNING and processed any new data, + // so it only indicates whether the record collector should be flushed or not, whereas the state + // manager should always be flushed; either there is newly restored data or the flush will be a no-op + if (commitNeeded) { + // we need to flush the store caches before flushing the record collector since it may cause some + // cached records to be processed and hence generate more records to be sent out + // + // TODO: this should be removed after we decouple caching with emitting + stateMgr.flushCache(); + recordCollector.flush(); + hasPendingTxCommit = eosEnabled; + + log.debug("Prepared {} task for committing", state()); + return committableOffsetsAndMetadata(); + } else { + log.debug("Skipped preparing {} task for commit since there is nothing to commit", state()); + return Collections.emptyMap(); + } + + case CLOSED: + throw new IllegalStateException("Illegal state " + state() + " while preparing active task " + id + " for committing"); + + default: + throw new IllegalStateException("Unknown state " + state() + " while preparing active task " + id + " for committing"); + } + } + + private Map committableOffsetsAndMetadata() { + final Map committableOffsets; + + switch (state()) { + case CREATED: + case RESTORING: + committableOffsets = Collections.emptyMap(); + + break; + + case RUNNING: + case SUSPENDED: + final Map partitionTimes = extractPartitionTimes(); + + committableOffsets = new HashMap<>(consumedOffsets.size()); + for (final Map.Entry entry : consumedOffsets.entrySet()) { + final TopicPartition partition = entry.getKey(); + Long offset = partitionGroup.headRecordOffset(partition); + if (offset == null) { + try { + offset = mainConsumer.position(partition); + } catch (final TimeoutException error) { + // the `consumer.position()` call should never block, because we know that we did process data + // for the requested partition and thus the consumer should have a valid 
local position + // that it can return immediately + + // hence, a `TimeoutException` indicates a bug and thus we rethrow it as fatal `IllegalStateException` + throw new IllegalStateException(error); + } catch (final KafkaException fatal) { + throw new StreamsException(fatal); + } + } + final long partitionTime = partitionTimes.get(partition); + committableOffsets.put(partition, new OffsetAndMetadata(offset, encodeTimestamp(partitionTime))); + } + + break; + + case CLOSED: + throw new IllegalStateException("Illegal state " + state() + " while getting committable offsets for active task " + id); + + default: + throw new IllegalStateException("Unknown state " + state() + " while post committing active task " + id); + } + + return committableOffsets; + } + + @Override + public void postCommit(final boolean enforceCheckpoint) { + switch (state()) { + case CREATED: + // We should never write a checkpoint for a CREATED task as we may overwrite an existing checkpoint + // with empty uninitialized offsets + log.debug("Skipped writing checkpoint for {} task", state()); + + break; + + case RESTORING: + case SUSPENDED: + maybeWriteCheckpoint(enforceCheckpoint); + log.debug("Finalized commit for {} task with enforce checkpoint {}", state(), enforceCheckpoint); + + break; + + case RUNNING: + if (enforceCheckpoint || !eosEnabled) { + maybeWriteCheckpoint(enforceCheckpoint); + } + log.debug("Finalized commit for {} task with eos {} enforce checkpoint {}", state(), eosEnabled, enforceCheckpoint); + + break; + + case CLOSED: + throw new IllegalStateException("Illegal state " + state() + " while post committing active task " + id); + + default: + throw new IllegalStateException("Unknown state " + state() + " while post committing active task " + id); + } + + clearCommitStatuses(); + } + + private void clearCommitStatuses() { + commitNeeded = false; + commitRequested = false; + hasPendingTxCommit = false; + } + + private Map extractPartitionTimes() { + final Map partitionTimes = new HashMap<>(); + for (final TopicPartition partition : partitionGroup.partitions()) { + partitionTimes.put(partition, partitionGroup.partitionTimestamp(partition)); + } + return partitionTimes; + } + + @Override + public void closeClean() { + validateClean(); + removeAllSensors(); + clearCommitStatuses(); + close(true); + log.info("Closed clean"); + } + + @Override + public void closeDirty() { + removeAllSensors(); + clearCommitStatuses(); + close(false); + log.info("Closed dirty"); + } + + @Override + public void updateInputPartitions(final Set topicPartitions, final Map> allTopologyNodesToSourceTopics) { + super.updateInputPartitions(topicPartitions, allTopologyNodesToSourceTopics); + partitionGroup.updatePartitions(topicPartitions, recordQueueCreator::createQueue); + } + + @Override + public void closeCleanAndRecycleState() { + validateClean(); + removeAllSensors(); + clearCommitStatuses(); + switch (state()) { + case SUSPENDED: + stateMgr.recycle(); + recordCollector.closeClean(); + + break; + + case CREATED: + case RESTORING: + case RUNNING: + case CLOSED: + throw new IllegalStateException("Illegal state " + state() + " while recycling active task " + id); + default: + throw new IllegalStateException("Unknown state " + state() + " while recycling active task " + id); + } + + closeTaskSensor.record(); + + transitionTo(State.CLOSED); + + log.info("Closed clean and recycled state"); + } + + /** + * The following exceptions maybe thrown from the state manager flushing call + * + * @throws TaskMigratedException recoverable error 
sending changelog records that would cause the task to be removed + * @throws StreamsException fatal error when flushing the state store, for example sending changelog records failed + * or flushing state store get IO errors; such error should cause the thread to die + */ + @Override + protected void maybeWriteCheckpoint(final boolean enforceCheckpoint) { + // commitNeeded indicates we may have processed some records since last commit + // and hence we need to refresh checkpointable offsets regardless whether we should checkpoint or not + if (commitNeeded || enforceCheckpoint) { + stateMgr.updateChangelogOffsets(checkpointableOffsets()); + } + + super.maybeWriteCheckpoint(enforceCheckpoint); + } + + private void validateClean() { + // It may be that we failed to commit a task during handleRevocation, but "forgot" this and tried to + // closeClean in handleAssignment. We should throw if we detect this to force the TaskManager to closeDirty + if (commitNeeded) { + log.debug("Tried to close clean but there was pending uncommitted data, this means we failed to" + + " commit and should close as dirty instead"); + throw new TaskMigratedException("Tried to close dirty task as clean"); + } + } + + private void removeAllSensors() { + streamsMetrics.removeAllTaskLevelSensors(Thread.currentThread().getName(), id.toString()); + for (final String nodeName : e2eLatencySensors.keySet()) { + streamsMetrics.removeAllNodeLevelSensors(Thread.currentThread().getName(), id.toString(), nodeName); + } + } + + /** + * You must commit a task and checkpoint the state manager before closing as this will release the state dir lock + */ + private void close(final boolean clean) { + switch (state()) { + case SUSPENDED: + // first close state manager (which is idempotent) then close the record collector + // if the latter throws and we re-close dirty which would close the state manager again. + TaskManager.executeAndMaybeSwallow( + clean, + () -> StateManagerUtil.closeStateManager( + log, + logPrefix, + clean, + eosEnabled, + stateMgr, + stateDirectory, + TaskType.ACTIVE + ), + "state manager close", + log); + + TaskManager.executeAndMaybeSwallow( + clean, + clean ? recordCollector::closeClean : recordCollector::closeDirty, + "record collector close", + log + ); + + break; + + case CLOSED: + log.trace("Skip closing since state is {}", state()); + return; + + case CREATED: + case RESTORING: + case RUNNING: + throw new IllegalStateException("Illegal state " + state() + " while closing active task " + id); + + default: + throw new IllegalStateException("Unknown state " + state() + " while closing active task " + id); + } + + record = null; + partitionGroup.clear(); + closeTaskSensor.record(); + + transitionTo(State.CLOSED); + } + + /** + * An active task is processable if its buffer contains data for all of its input + * source topic partitions, or if it is enforced to be processable. + */ + public boolean isProcessable(final long wallClockTime) { + if (state() == State.CLOSED) { + // a task is only closing / closed when 1) task manager is closing, 2) a rebalance is undergoing; + // in either case we can just log it and move on without notifying the thread since the consumer + // would soon be updated to not return any records for this task anymore. 
+ log.info("Stream task {} is already in {} state, skip processing it.", id(), state()); + + return false; + } + + if (hasPendingTxCommit) { + // if the task has a pending TX commit, we should just retry the commit but not process any records + // thus, the task is not processable, even if there is available data in the record queue + return false; + } + final boolean readyToProcess = partitionGroup.readyToProcess(wallClockTime); + if (!readyToProcess) { + if (!timeCurrentIdlingStarted.isPresent()) { + timeCurrentIdlingStarted = Optional.of(wallClockTime); + } + } else { + timeCurrentIdlingStarted = Optional.empty(); + } + return readyToProcess; + } + + /** + * Process one record. + * + * @return true if this method processes a record, false if it does not process a record. + * @throws TaskMigratedException if the task producer got fenced (EOS only) + */ + @SuppressWarnings("unchecked") + public boolean process(final long wallClockTime) { + if (record == null) { + if (!isProcessable(wallClockTime)) { + return false; + } + + // get the next record to process + record = partitionGroup.nextRecord(recordInfo, wallClockTime); + + // if there is no record to process, return immediately + if (record == null) { + return false; + } + } + + + try { + // process the record by passing to the source node of the topology + final ProcessorNode currNode = (ProcessorNode) recordInfo.node(); + final TopicPartition partition = recordInfo.partition(); + + log.trace("Start processing one record [{}]", record); + + final ProcessorRecordContext recordContext = new ProcessorRecordContext( + record.timestamp, + record.offset(), + record.partition(), + record.topic(), + record.headers() + ); + updateProcessorContext(currNode, wallClockTime, recordContext); + + maybeRecordE2ELatency(record.timestamp, wallClockTime, currNode.name()); + final Record toProcess = new Record<>( + record.key(), + record.value(), + processorContext.timestamp(), + processorContext.headers() + ); + maybeMeasureLatency(() -> currNode.process(toProcess), time, processLatencySensor); + + log.trace("Completed processing one record [{}]", record); + + // update the consumed offset map after processing is done + consumedOffsets.put(partition, record.offset()); + commitNeeded = true; + + // after processing this record, if its partition queue's buffered size has been + // decreased to the threshold, we can then resume the consumption on this partition + if (recordInfo.queue().size() == maxBufferedSize) { + mainConsumer.resume(singleton(partition)); + } + + record = null; + } catch (final TimeoutException timeoutException) { + if (!eosEnabled) { + throw timeoutException; + } else { + record = null; + throw new TaskCorruptedException(Collections.singleton(id)); + } + } catch (final StreamsException exception) { + record = null; + throw exception; + } catch (final RuntimeException e) { + final StreamsException error = new StreamsException( + String.format( + "Exception caught in process. 
taskId=%s, processor=%s, topic=%s, partition=%d, offset=%d, stacktrace=%s", + id(), + processorContext.currentNode().name(), + record.topic(), + record.partition(), + record.offset(), + getStacktraceString(e) + ), + e + ); + record = null; + + throw error; + } finally { + processorContext.setCurrentNode(null); + } + + return true; + } + + @Override + public void recordProcessBatchTime(final long processBatchTime) { + processTimeMs += processBatchTime; + } + + @Override + public void recordProcessTimeRatioAndBufferSize(final long allTaskProcessMs, final long now) { + bufferedRecordsSensor.record(partitionGroup.numBuffered()); + processRatioSensor.record((double) processTimeMs / allTaskProcessMs, now); + processTimeMs = 0L; + } + + private String getStacktraceString(final RuntimeException e) { + String stacktrace = null; + try (final StringWriter stringWriter = new StringWriter(); + final PrintWriter printWriter = new PrintWriter(stringWriter)) { + e.printStackTrace(printWriter); + stacktrace = stringWriter.toString(); + } catch (final IOException ioe) { + log.error("Encountered error extracting stacktrace from this exception", ioe); + } + return stacktrace; + } + + /** + * @throws IllegalStateException if the current node is not null + * @throws TaskMigratedException if the task producer got fenced (EOS only) + */ + @SuppressWarnings("unchecked") + @Override + public void punctuate(final ProcessorNode node, + final long timestamp, + final PunctuationType type, + final Punctuator punctuator) { + if (processorContext.currentNode() != null) { + throw new IllegalStateException(String.format("%sCurrent node is not null", logPrefix)); + } + + // when punctuating, we need to preserve the timestamp (this can be either system time or event time) + // while other record context are set as dummy: null topic, -1 partition, -1 offset and empty header + final ProcessorRecordContext recordContext = new ProcessorRecordContext( + timestamp, + -1L, + -1, + null, + new RecordHeaders() + ); + updateProcessorContext(node, time.milliseconds(), recordContext); + + if (log.isTraceEnabled()) { + log.trace("Punctuating processor {} with timestamp {} and punctuation type {}", node.name(), timestamp, type); + } + + try { + maybeMeasureLatency(() -> node.punctuate(timestamp, punctuator), time, punctuateLatencySensor); + } catch (final StreamsException e) { + throw e; + } catch (final RuntimeException e) { + throw new StreamsException(String.format("%sException caught while punctuating processor '%s'", logPrefix, node.name()), e); + } finally { + processorContext.setCurrentNode(null); + } + } + + @SuppressWarnings("unchecked") + private void updateProcessorContext(final ProcessorNode currNode, + final long wallClockTime, + final ProcessorRecordContext recordContext) { + processorContext.setRecordContext(recordContext); + processorContext.setCurrentNode(currNode); + processorContext.setSystemTimeMs(wallClockTime); + } + + /** + * Return all the checkpointable offsets(written + consumed) to the state manager. + * Currently only changelog topic offsets need to be checkpointed. 
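+     *
+     * For example, if the record collector reports offset 42 for a changelog partition and this task
+     * has consumed up to offset 17 on one of its input partitions (values illustrative), the returned
+     * map contains both entries; consumed offsets are added via putIfAbsent, so an offset already
+     * reported by the record collector for the same partition is never overwritten.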
+ */ + private Map checkpointableOffsets() { + final Map checkpointableOffsets = new HashMap<>(recordCollector.offsets()); + for (final Map.Entry entry : consumedOffsets.entrySet()) { + checkpointableOffsets.putIfAbsent(entry.getKey(), entry.getValue()); + } + + log.debug("Checkpointable offsets {}", checkpointableOffsets); + + return checkpointableOffsets; + } + + private void resetOffsetsIfNeededAndInitializeMetadata(final java.util.function.Consumer> offsetResetter) { + try { + final Map offsetsAndMetadata = mainConsumer.committed(inputPartitions()); + + for (final Map.Entry committedEntry : offsetsAndMetadata.entrySet()) { + if (resetOffsetsForPartitions.contains(committedEntry.getKey())) { + final OffsetAndMetadata offsetAndMetadata = committedEntry.getValue(); + if (offsetAndMetadata != null) { + mainConsumer.seek(committedEntry.getKey(), offsetAndMetadata); + resetOffsetsForPartitions.remove(committedEntry.getKey()); + } + } + } + + offsetResetter.accept(resetOffsetsForPartitions); + resetOffsetsForPartitions.clear(); + + initializeTaskTime(offsetsAndMetadata.entrySet().stream() + .filter(e -> e.getValue() != null) + .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)) + ); + } catch (final TimeoutException timeoutException) { + log.warn( + "Encountered {} while trying to fetch committed offsets, will retry initializing the metadata in the next loop." + + "\nConsider overwriting consumer config {} to a larger value to avoid timeout errors", + time.toString(), + ConsumerConfig.DEFAULT_API_TIMEOUT_MS_CONFIG); + + // re-throw to trigger `task.timeout.ms` + throw timeoutException; + } catch (final KafkaException e) { + throw new StreamsException(String.format("task [%s] Failed to initialize offsets for %s", id, inputPartitions()), e); + } + } + + private void initializeTaskTime(final Map offsetsAndMetadata) { + for (final Map.Entry entry : offsetsAndMetadata.entrySet()) { + final TopicPartition partition = entry.getKey(); + final OffsetAndMetadata metadata = entry.getValue(); + + if (metadata != null) { + final long committedTimestamp = decodeTimestamp(metadata.metadata()); + partitionGroup.setPartitionTime(partition, committedTimestamp); + log.debug("A committed timestamp was detected: setting the partition time of partition {}" + + " to {} in stream task {}", partition, committedTimestamp, id); + } else { + log.debug("No committed timestamp was found in metadata for partition {}", partition); + } + } + + final Set nonCommitted = new HashSet<>(inputPartitions()); + nonCommitted.removeAll(offsetsAndMetadata.keySet()); + for (final TopicPartition partition : nonCommitted) { + log.debug("No committed offset for partition {}, therefore no timestamp can be found for this partition", partition); + } + } + + @Override + public Map purgeableOffsets() { + final Map purgeableConsumedOffsets = new HashMap<>(); + for (final Map.Entry entry : consumedOffsets.entrySet()) { + final TopicPartition tp = entry.getKey(); + if (topology.isRepartitionTopic(tp.topic())) { + purgeableConsumedOffsets.put(tp, entry.getValue() + 1); + } + } + + return purgeableConsumedOffsets; + } + + @SuppressWarnings("unchecked") + private void initializeTopology() { + // initialize the task by initializing all its processor nodes in the topology + log.trace("Initializing processor nodes of the topology"); + for (final ProcessorNode node : topology.processors()) { + processorContext.setCurrentNode(node); + try { + node.init(processorContext); + } finally { + processorContext.setCurrentNode(null); + } + } + } + + 
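+    // Back-pressure note: addRecords() below pauses a partition on the main consumer once its queue
+    // grows beyond buffered.records.per.partition (maxBufferedSize), and process() resumes that
+    // partition once the queue has drained back down to the threshold, so a paused partition is only
+    // re-fetched after the task has caught up on its buffered records.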
/** + * Adds records to queues. If a record has an invalid (i.e., negative) timestamp, the record is skipped + * and not added to the queue for processing + * + * @param partition the partition + * @param records the records + */ + @Override + public void addRecords(final TopicPartition partition, final Iterable> records) { + final int newQueueSize = partitionGroup.addRawRecords(partition, records); + + if (log.isTraceEnabled()) { + log.trace("Added records into the buffered queue of partition {}, new queue size is {}", partition, newQueueSize); + } + + // if after adding these records, its partition queue's buffered size has been + // increased beyond the threshold, we can then pause the consumption for this partition + if (newQueueSize > maxBufferedSize) { + mainConsumer.pause(singleton(partition)); + } + } + + /** + * Schedules a punctuation for the processor + * + * @param interval the interval in milliseconds + * @param type the punctuation type + * @throws IllegalStateException if the current node is not null + */ + public Cancellable schedule(final long interval, final PunctuationType type, final Punctuator punctuator) { + switch (type) { + case STREAM_TIME: + // align punctuation to 0L, punctuate as soon as we have data + return schedule(0L, interval, type, punctuator); + case WALL_CLOCK_TIME: + // align punctuation to now, punctuate after interval has elapsed + return schedule(time.milliseconds() + interval, interval, type, punctuator); + default: + throw new IllegalArgumentException("Unrecognized PunctuationType: " + type); + } + } + + /** + * Schedules a punctuation for the processor + * + * @param startTime time of the first punctuation + * @param interval the interval in milliseconds + * @param type the punctuation type + * @throws IllegalStateException if the current node is not null + */ + private Cancellable schedule(final long startTime, final long interval, final PunctuationType type, final Punctuator punctuator) { + if (processorContext.currentNode() == null) { + throw new IllegalStateException(String.format("%sCurrent node is null", logPrefix)); + } + + final PunctuationSchedule schedule = new PunctuationSchedule(processorContext.currentNode(), startTime, interval, punctuator); + + switch (type) { + case STREAM_TIME: + // STREAM_TIME punctuation is data driven, will first punctuate as soon as stream-time is known and >= time, + // stream-time is known when we have received at least one record from each input topic + return streamTimePunctuationQueue.schedule(schedule); + case WALL_CLOCK_TIME: + // WALL_CLOCK_TIME is driven by the wall clock time, will first punctuate when now >= time + return systemTimePunctuationQueue.schedule(schedule); + default: + throw new IllegalArgumentException("Unrecognized PunctuationType: " + type); + } + } + + /** + * Possibly trigger registered stream-time punctuation functions if + * current partition group timestamp has reached the defined stamp + * Note, this is only called in the presence of new records + * + * @throws TaskMigratedException if the task producer got fenced (EOS only) + */ + public boolean maybePunctuateStreamTime() { + final long streamTime = partitionGroup.streamTime(); + + // if the timestamp is not known yet, meaning there is not enough data accumulated + // to reason stream partition time, then skip. 
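+        // RecordQueue.UNKNOWN is the sentinel timestamp (ConsumerRecord.NO_TIMESTAMP) used while the
+        // task has not yet observed any record timestamps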
+ if (streamTime == RecordQueue.UNKNOWN) { + return false; + } else { + final boolean punctuated = streamTimePunctuationQueue.mayPunctuate(streamTime, PunctuationType.STREAM_TIME, this); + + if (punctuated) { + commitNeeded = true; + } + + return punctuated; + } + } + + /** + * Possibly trigger registered system-time punctuation functions if + * current system timestamp has reached the defined stamp + * Note, this is called irrespective of the presence of new records + * + * @throws TaskMigratedException if the task producer got fenced (EOS only) + */ + public boolean maybePunctuateSystemTime() { + final long systemTime = time.milliseconds(); + + final boolean punctuated = systemTimePunctuationQueue.mayPunctuate(systemTime, PunctuationType.WALL_CLOCK_TIME, this); + + if (punctuated) { + commitNeeded = true; + } + + return punctuated; + } + + void maybeRecordE2ELatency(final long recordTimestamp, final long now, final String nodeName) { + final Sensor e2eLatencySensor = e2eLatencySensors.get(nodeName); + if (e2eLatencySensor == null) { + throw new IllegalStateException("Requested to record e2e latency but could not find sensor for node " + nodeName); + } else if (e2eLatencySensor.shouldRecord() && e2eLatencySensor.hasMetrics()) { + e2eLatencySensor.record(now - recordTimestamp, now); + } + } + + /** + * Request committing the current task's state + */ + void requestCommit() { + commitRequested = true; + } + + /** + * Whether or not a request has been made to commit the current state + */ + @Override + public boolean commitRequested() { + return commitRequested; + } + + static String encodeTimestamp(final long partitionTime) { + final ByteBuffer buffer = ByteBuffer.allocate(9); + buffer.put(LATEST_MAGIC_BYTE); + buffer.putLong(partitionTime); + return Base64.getEncoder().encodeToString(buffer.array()); + } + + long decodeTimestamp(final String encryptedString) { + if (encryptedString.isEmpty()) { + return RecordQueue.UNKNOWN; + } + final ByteBuffer buffer = ByteBuffer.wrap(Base64.getDecoder().decode(encryptedString)); + final byte version = buffer.get(); + switch (version) { + case LATEST_MAGIC_BYTE: + return buffer.getLong(); + default: + log.warn("Unsupported offset metadata version found. Supported version {}. Found version {}.", + LATEST_MAGIC_BYTE, version); + return RecordQueue.UNKNOWN; + } + } + + @SuppressWarnings("rawtypes") + public InternalProcessorContext processorContext() { + return processorContext; + } + + /** + * Produces a string representation containing useful information about a Task. + * This is useful in debugging scenarios. + * + * @return A string representation of the StreamTask instance. + */ + @Override + public String toString() { + return toString(""); + } + + /** + * Produces a string representation containing useful information about a Task starting with the given indent. + * This is useful in debugging scenarios. + * + * @return A string representation of the Task instance. 
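+     *         For example (illustrative), a task with id 0_1 produces output that starts with
+     *         "TaskId: 0_1", followed by the topology description and the assigned input partitions.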
+ */ + public String toString(final String indent) { + final StringBuilder sb = new StringBuilder(); + sb.append(indent); + sb.append("TaskId: "); + sb.append(id); + sb.append("\n"); + + // print topology + if (topology != null) { + sb.append(indent).append(topology.toString(indent + "\t")); + } + + // print assigned partitions + final Set partitions = inputPartitions(); + if (partitions != null && !partitions.isEmpty()) { + sb.append(indent).append("Partitions ["); + for (final TopicPartition topicPartition : partitions) { + sb.append(topicPartition).append(", "); + } + sb.setLength(sb.length() - 2); + sb.append("]\n"); + } + return sb.toString(); + } + + @Override + public boolean commitNeeded() { + // we need to do an extra check if the flag was false, that + // if the consumer position has been updated; this is because + // there may be non data records such as control markers bypassed + if (commitNeeded) { + return true; + } else { + for (final Map.Entry entry : consumedOffsets.entrySet()) { + final TopicPartition partition = entry.getKey(); + try { + final long offset = mainConsumer.position(partition); + + // note the position in consumer is the "next" record to fetch, + // so it should be larger than the consumed offset by 1; if it is + // more than 1 it means there are control records, which the consumer skips over silently + if (offset > entry.getValue() + 1) { + commitNeeded = true; + entry.setValue(offset - 1); + } + } catch (final TimeoutException error) { + // the `consumer.position()` call should never block, because we know that we did process data + // for the requested partition and thus the consumer should have a valid local position + // that it can return immediately + + // hence, a `TimeoutException` indicates a bug and thus we rethrow it as fatal `IllegalStateException` + throw new IllegalStateException(error); + } catch (final KafkaException fatal) { + throw new StreamsException(fatal); + } + } + + return commitNeeded; + } + } + + @Override + public Map changelogOffsets() { + if (state() == State.RUNNING) { + // if we are in running state, just return the latest offset sentinel indicating + // we should be at the end of the changelog + return changelogPartitions().stream() + .collect(Collectors.toMap(Function.identity(), tp -> Task.LATEST_OFFSET)); + } else { + return Collections.unmodifiableMap(stateMgr.changelogOffsets()); + } + } + + @Override + public Map committedOffsets() { + return Collections.unmodifiableMap(committedOffsets); + } + + @Override + public Map highWaterMark() { + return Collections.unmodifiableMap(highWatermark); + } + + private void transitToSuspend() { + log.info("Suspended {}", state()); + transitionTo(State.SUSPENDED); + timeCurrentIdlingStarted = Optional.of(System.currentTimeMillis()); + } + + @Override + public Optional timeCurrentIdlingStarted() { + return timeCurrentIdlingStarted; + } + + public void updateCommittedOffsets(final TopicPartition topicPartition, final Long offset) { + committedOffsets.put(topicPartition, offset); + } + + public void updateEndOffsets(final TopicPartition topicPartition, final Long offset) { + highWatermark.put(topicPartition, offset); + } + + public boolean hasRecordsQueued() { + return numBuffered() > 0; + } + + RecordCollector recordCollector() { + return recordCollector; + } + + // below are visible for testing only + int numBuffered() { + return partitionGroup.numBuffered(); + } + + long streamTime() { + return partitionGroup.streamTime(); + } + + private class RecordQueueCreator { + private final 
LogContext logContext; + private final TimestampExtractor defaultTimestampExtractor; + private final DeserializationExceptionHandler defaultDeserializationExceptionHandler; + + private RecordQueueCreator(final LogContext logContext, + final TimestampExtractor defaultTimestampExtractor, + final DeserializationExceptionHandler defaultDeserializationExceptionHandler) { + this.logContext = logContext; + this.defaultTimestampExtractor = defaultTimestampExtractor; + this.defaultDeserializationExceptionHandler = defaultDeserializationExceptionHandler; + } + + public RecordQueue createQueue(final TopicPartition partition) { + final SourceNode source = topology.source(partition.topic()); + if (source == null) { + throw new TopologyException( + "Topic is unknown to the topology. " + + "This may happen if different KafkaStreams instances of the same application execute different Topologies. " + + "Note that Topologies are only identical if all operators are added in the same order." + ); + } + + final TimestampExtractor sourceTimestampExtractor = source.getTimestampExtractor(); + final TimestampExtractor timestampExtractor = sourceTimestampExtractor != null ? sourceTimestampExtractor : defaultTimestampExtractor; + return new RecordQueue( + partition, + source, + timestampExtractor, + defaultDeserializationExceptionHandler, + processorContext, + logContext + ); + } + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java new file mode 100644 index 0000000..88814b0 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThread.java @@ -0,0 +1,1310 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.processor.internals; + +import org.apache.kafka.clients.admin.Admin; +import org.apache.kafka.clients.consumer.Consumer; +import org.apache.kafka.clients.consumer.ConsumerConfig; +import org.apache.kafka.clients.consumer.ConsumerRebalanceListener; +import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.clients.consumer.ConsumerRecords; +import org.apache.kafka.clients.consumer.InvalidOffsetException; +import org.apache.kafka.common.KafkaException; +import org.apache.kafka.common.Metric; +import org.apache.kafka.common.MetricName; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.errors.UnsupportedVersionException; +import org.apache.kafka.common.metrics.Sensor; +import org.apache.kafka.common.serialization.ByteArrayDeserializer; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.streams.KafkaClientSupplier; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.TaskMetadata; +import org.apache.kafka.streams.ThreadMetadata; +import org.apache.kafka.streams.errors.StreamsException; +import org.apache.kafka.streams.errors.TaskCorruptedException; +import org.apache.kafka.streams.errors.TaskMigratedException; +import org.apache.kafka.streams.internals.metrics.ClientMetrics; +import org.apache.kafka.streams.processor.StateRestoreListener; +import org.apache.kafka.streams.processor.TaskId; +import org.apache.kafka.streams.processor.internals.assignment.AssignorError; +import org.apache.kafka.streams.processor.internals.assignment.ReferenceContainer; +import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl; +import org.apache.kafka.streams.processor.internals.metrics.ThreadMetrics; +import org.apache.kafka.streams.state.internals.ThreadCache; +import org.slf4j.Logger; + +import java.time.Duration; +import java.util.Arrays; +import java.util.Collections; +import java.util.Comparator; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.UUID; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicLong; +import java.util.stream.Collectors; + +import static org.apache.kafka.streams.processor.internals.ClientUtils.getConsumerClientId; +import static org.apache.kafka.streams.processor.internals.ClientUtils.getRestoreConsumerClientId; +import static org.apache.kafka.streams.processor.internals.ClientUtils.getSharedAdminClientId; + +public class StreamThread extends Thread { + + /** + * Stream thread states are the possible states that a stream thread can be in. + * A thread must only be in one state at a time + * The expected state transitions with the following defined states is: + * + *
                +     *                 +-------------+
                +     *          +<---- | Created (0) |
                +     *          |      +-----+-------+
                +     *          |            |
                +     *          |            v
                +     *          |      +-----+-------+
                +     *          +<---- | Starting (1)|----->+
                +     *          |      +-----+-------+      |
                +     *          |                           |
                +     *          |            +<----------+  |
                +     *          |            |           |  |
                +     *          |            v           |  |
                +     *          |      +-----+-------+   |  |
                +     *          +<---- | Partitions  | --+  |
                +     *          |      | Revoked (2) | <----+
                +     *          |      +-----+-------+      |
                +     *          |           |  ^            |
                +     *          |           v  |            |
                +     *          |      +-----+-------+      |
                +     *          +<---- | Partitions  |      |
                +     *          |      | Assigned (3)| <----+
                +     *          |      +-----+-------+      |
                +     *          |            |              |
                +     *          |            +<----------+  |
                +     *          |            |           |  |
                +     *          |            v           |  |
                +     *          |      +-----+-------+   |  |
                +     *          |      |             | --+  |
                +     *          |      | Running (4) | ---->+
                +     *          |      +-----+-------+
                +     *          |            |
                +     *          |            v
                +     *          |      +-----+-------+
                +     *          +----> | Pending     |
                +     *                 | Shutdown (5)|
                +     *                 +-----+-------+
                +     *                       |
                +     *                       v
                +     *                 +-----+-------+
                +     *                 | Dead (6)    |
                +     *                 +-------------+
                +     * 
                +     *
                +     * Note the following:
                +     * - Any state can go to PENDING_SHUTDOWN. That is because streams can be closed at any time.
                +     * - State PENDING_SHUTDOWN may want to transit to some other states other than DEAD,
                +     *   in the corner case when the shutdown is triggered while the thread is still in the rebalance loop.
                +     *   In this case we will forbid the transition but will not treat it as an error.
                +     * - State PARTITIONS_REVOKED may want to transit to itself indefinitely, in the corner case when
                +     *   the coordinator repeatedly fails in-between revoking partitions and assigning new partitions.
                +     *   Also, during streams instance start-up, PARTITIONS_REVOKED may want to transit to itself as well.
                +     *   In this case we will allow the transition, but it will be a no-op as the set of revoked partitions
                +     *   should be empty.
                + */ + public enum State implements ThreadStateTransitionValidator { + + CREATED(1, 5), // 0 + STARTING(2, 3, 5), // 1 + PARTITIONS_REVOKED(2, 3, 5), // 2 + PARTITIONS_ASSIGNED(2, 3, 4, 5), // 3 + RUNNING(2, 3, 4, 5), // 4 + PENDING_SHUTDOWN(6), // 5 + DEAD; // 6 + + private final Set validTransitions = new HashSet<>(); + + State(final Integer... validTransitions) { + this.validTransitions.addAll(Arrays.asList(validTransitions)); + } + + public boolean isAlive() { + return equals(RUNNING) || equals(STARTING) || equals(PARTITIONS_REVOKED) || equals(PARTITIONS_ASSIGNED); + } + + @Override + public boolean isValidTransition(final ThreadStateTransitionValidator newState) { + final State tmpState = (State) newState; + return validTransitions.contains(tmpState.ordinal()); + } + } + + /** + * Listen to state change events + */ + public interface StateListener { + + /** + * Called when state changes + * + * @param thread thread changing state + * @param newState current state + * @param oldState previous state + */ + void onChange(final Thread thread, final ThreadStateTransitionValidator newState, final ThreadStateTransitionValidator oldState); + } + + /** + * Set the {@link StreamThread.StateListener} to be notified when state changes. Note this API is internal to + * Kafka Streams and is not intended to be used by an external application. + */ + public void setStateListener(final StreamThread.StateListener listener) { + stateListener = listener; + } + + /** + * @return The state this instance is in + */ + public State state() { + // we do not need to use the state lock since the variable is volatile + return state; + } + + void setPartitionAssignedTime(final long lastPartitionAssignedMs) { + this.lastPartitionAssignedMs = lastPartitionAssignedMs; + } + + /** + * Sets the state + * + * @param newState New state + * @return The state prior to the call to setState, or null if the transition is invalid + */ + State setState(final State newState) { + final State oldState; + + synchronized (stateLock) { + oldState = state; + + if (state == State.PENDING_SHUTDOWN && newState != State.DEAD) { + log.debug("Ignoring request to transit from PENDING_SHUTDOWN to {}: " + + "only DEAD state is a valid next state", newState); + // when the state is already in PENDING_SHUTDOWN, all other transitions will be + // refused but we do not throw exception here + return null; + } else if (state == State.DEAD) { + log.debug("Ignoring request to transit from DEAD to {}: " + + "no valid next state after DEAD", newState); + // when the state is already in NOT_RUNNING, all its transitions + // will be refused but we do not throw exception here + return null; + } else if (!state.isValidTransition(newState)) { + log.error("Unexpected state transition from {} to {}", oldState, newState); + throw new StreamsException(logPrefix + "Unexpected state transition from " + oldState + " to " + newState); + } else { + log.info("State transition from {} to {}", oldState, newState); + } + + state = newState; + if (newState == State.RUNNING) { + updateThreadMetadata(taskManager.activeTaskMap(), taskManager.standbyTaskMap()); + } + + stateLock.notifyAll(); + } + + if (stateListener != null) { + stateListener.onChange(this, state, oldState); + } + + return oldState; + } + + public boolean isRunning() { + synchronized (stateLock) { + return state.isAlive(); + } + } + + private final Time time; + private final Logger log; + private final String logPrefix; + public final Object stateLock; + private final Duration pollTime; + 
private final long commitTimeMs; + private final int maxPollTimeMs; + private final String originalReset; + private final TaskManager taskManager; + + private final StreamsMetricsImpl streamsMetrics; + private final Sensor commitSensor; + private final Sensor pollSensor; + private final Sensor pollRecordsSensor; + private final Sensor punctuateSensor; + private final Sensor processRecordsSensor; + private final Sensor processLatencySensor; + private final Sensor processRateSensor; + private final Sensor pollRatioSensor; + private final Sensor processRatioSensor; + private final Sensor punctuateRatioSensor; + private final Sensor commitRatioSensor; + private final Sensor failedStreamThreadSensor; + + private static final long LOG_SUMMARY_INTERVAL_MS = 2 * 60 * 1000L; // log a summary of processing every 2 minutes + private long lastLogSummaryMs = -1L; + private long totalRecordsProcessedSinceLastSummary = 0L; + private long totalPunctuatorsSinceLastSummary = 0L; + private long totalCommittedSinceLastSummary = 0L; + + private long now; + private long lastPollMs; + private long lastCommitMs; + private long lastPartitionAssignedMs = -1L; + private int numIterations; + private volatile State state = State.CREATED; + private volatile ThreadMetadata threadMetadata; + private StreamThread.StateListener stateListener; + private final Optional getGroupInstanceID; + + private final ChangelogReader changelogReader; + private final ConsumerRebalanceListener rebalanceListener; + private final Consumer mainConsumer; + private final Consumer restoreConsumer; + private final Admin adminClient; + private final TopologyMetadata topologyMetadata; + private final java.util.function.Consumer cacheResizer; + + private java.util.function.Consumer streamsUncaughtExceptionHandler; + private final Runnable shutdownErrorHook; + + private long lastSeenTopologyVersion = 0L; + + // These must be Atomic references as they are shared and used to signal between the assignor and the stream thread + private final AtomicInteger assignmentErrorCode; + private final AtomicLong nextProbingRebalanceMs; + + // These are used to signal from outside the stream thread, but the variables themselves are internal to the thread + private final AtomicLong cacheResizeSize = new AtomicLong(-1L); + private final AtomicBoolean leaveGroupRequested = new AtomicBoolean(false); + private final boolean eosEnabled; + + public static StreamThread create(final TopologyMetadata topologyMetadata, + final StreamsConfig config, + final KafkaClientSupplier clientSupplier, + final Admin adminClient, + final UUID processId, + final String clientId, + final StreamsMetricsImpl streamsMetrics, + final Time time, + final StreamsMetadataState streamsMetadataState, + final long cacheSizeBytes, + final StateDirectory stateDirectory, + final StateRestoreListener userStateRestoreListener, + final int threadIdx, + final Runnable shutdownErrorHook, + final java.util.function.Consumer streamsUncaughtExceptionHandler) { + final String threadId = clientId + "-StreamThread-" + threadIdx; + + final String logPrefix = String.format("stream-thread [%s] ", threadId); + final LogContext logContext = new LogContext(logPrefix); + final Logger log = logContext.logger(StreamThread.class); + + final ReferenceContainer referenceContainer = new ReferenceContainer(); + referenceContainer.adminClient = adminClient; + referenceContainer.streamsMetadataState = streamsMetadataState; + referenceContainer.time = time; + + log.info("Creating restore consumer client"); + final Map 
restoreConsumerConfigs = config.getRestoreConsumerConfigs(getRestoreConsumerClientId(threadId)); + final Consumer restoreConsumer = clientSupplier.getRestoreConsumer(restoreConsumerConfigs); + + final StoreChangelogReader changelogReader = new StoreChangelogReader( + time, + config, + logContext, + adminClient, + restoreConsumer, + userStateRestoreListener + ); + + final ThreadCache cache = new ThreadCache(logContext, cacheSizeBytes, streamsMetrics); + + final ActiveTaskCreator activeTaskCreator = new ActiveTaskCreator( + topologyMetadata, + config, + streamsMetrics, + stateDirectory, + changelogReader, + cache, + time, + clientSupplier, + threadId, + processId, + log + ); + final StandbyTaskCreator standbyTaskCreator = new StandbyTaskCreator( + topologyMetadata, + config, + streamsMetrics, + stateDirectory, + changelogReader, + threadId, + log + ); + final TaskManager taskManager = new TaskManager( + time, + changelogReader, + processId, + logPrefix, + streamsMetrics, + activeTaskCreator, + standbyTaskCreator, + topologyMetadata, + adminClient, + stateDirectory, + processingMode(config) + ); + referenceContainer.taskManager = taskManager; + + log.info("Creating consumer client"); + final String applicationId = config.getString(StreamsConfig.APPLICATION_ID_CONFIG); + final Map consumerConfigs = config.getMainConsumerConfigs(applicationId, getConsumerClientId(threadId), threadIdx); + consumerConfigs.put(StreamsConfig.InternalConfig.REFERENCE_CONTAINER_PARTITION_ASSIGNOR, referenceContainer); + + final String originalReset = (String) consumerConfigs.get(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG); + // If there are any overrides, we never fall through to the consumer, but only handle offset management ourselves. + if (topologyMetadata.hasOffsetResetOverrides()) { + consumerConfigs.put(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "none"); + } + + final Consumer mainConsumer = clientSupplier.getConsumer(consumerConfigs); + changelogReader.setMainConsumer(mainConsumer); + taskManager.setMainConsumer(mainConsumer); + referenceContainer.mainConsumer = mainConsumer; + + final StreamThread streamThread = new StreamThread( + time, + config, + adminClient, + mainConsumer, + restoreConsumer, + changelogReader, + originalReset, + taskManager, + streamsMetrics, + topologyMetadata, + threadId, + logContext, + referenceContainer.assignmentErrorCode, + referenceContainer.nextScheduledRebalanceMs, + shutdownErrorHook, + streamsUncaughtExceptionHandler, + cache::resize + ); + + return streamThread.updateThreadMetadata(getSharedAdminClientId(clientId)); + } + + public enum ProcessingMode { + AT_LEAST_ONCE("AT_LEAST_ONCE"), + + EXACTLY_ONCE_ALPHA("EXACTLY_ONCE_ALPHA"), + + EXACTLY_ONCE_V2("EXACTLY_ONCE_V2"); + + public final String name; + + ProcessingMode(final String name) { + this.name = name; + } + } + + // Note: the below two methods are static methods here instead of methods on StreamsConfig because it's a public API + + @SuppressWarnings("deprecation") + public static ProcessingMode processingMode(final StreamsConfig config) { + if (StreamsConfig.EXACTLY_ONCE.equals(config.getString(StreamsConfig.PROCESSING_GUARANTEE_CONFIG))) { + return StreamThread.ProcessingMode.EXACTLY_ONCE_ALPHA; + } else if (StreamsConfig.EXACTLY_ONCE_BETA.equals(config.getString(StreamsConfig.PROCESSING_GUARANTEE_CONFIG))) { + return StreamThread.ProcessingMode.EXACTLY_ONCE_V2; + } else if (StreamsConfig.EXACTLY_ONCE_V2.equals(config.getString(StreamsConfig.PROCESSING_GUARANTEE_CONFIG))) { + return 
StreamThread.ProcessingMode.EXACTLY_ONCE_V2; + } else { + return StreamThread.ProcessingMode.AT_LEAST_ONCE; + } + } + + public static boolean eosEnabled(final StreamsConfig config) { + return eosEnabled(processingMode(config)); + } + + public static boolean eosEnabled(final ProcessingMode processingMode) { + return processingMode == ProcessingMode.EXACTLY_ONCE_ALPHA || + processingMode == ProcessingMode.EXACTLY_ONCE_V2; + } + + public StreamThread(final Time time, + final StreamsConfig config, + final Admin adminClient, + final Consumer mainConsumer, + final Consumer restoreConsumer, + final ChangelogReader changelogReader, + final String originalReset, + final TaskManager taskManager, + final StreamsMetricsImpl streamsMetrics, + final TopologyMetadata topologyMetadata, + final String threadId, + final LogContext logContext, + final AtomicInteger assignmentErrorCode, + final AtomicLong nextProbingRebalanceMs, + final Runnable shutdownErrorHook, + final java.util.function.Consumer streamsUncaughtExceptionHandler, + final java.util.function.Consumer cacheResizer) { + super(threadId); + this.stateLock = new Object(); + this.adminClient = adminClient; + this.streamsMetrics = streamsMetrics; + this.commitSensor = ThreadMetrics.commitSensor(threadId, streamsMetrics); + this.pollSensor = ThreadMetrics.pollSensor(threadId, streamsMetrics); + this.pollRecordsSensor = ThreadMetrics.pollRecordsSensor(threadId, streamsMetrics); + this.pollRatioSensor = ThreadMetrics.pollRatioSensor(threadId, streamsMetrics); + this.processLatencySensor = ThreadMetrics.processLatencySensor(threadId, streamsMetrics); + this.processRecordsSensor = ThreadMetrics.processRecordsSensor(threadId, streamsMetrics); + this.processRateSensor = ThreadMetrics.processRateSensor(threadId, streamsMetrics); + this.processRatioSensor = ThreadMetrics.processRatioSensor(threadId, streamsMetrics); + this.punctuateSensor = ThreadMetrics.punctuateSensor(threadId, streamsMetrics); + this.punctuateRatioSensor = ThreadMetrics.punctuateRatioSensor(threadId, streamsMetrics); + this.commitRatioSensor = ThreadMetrics.commitRatioSensor(threadId, streamsMetrics); + this.failedStreamThreadSensor = ClientMetrics.failedStreamThreadSensor(streamsMetrics); + this.assignmentErrorCode = assignmentErrorCode; + this.shutdownErrorHook = shutdownErrorHook; + this.streamsUncaughtExceptionHandler = streamsUncaughtExceptionHandler; + this.cacheResizer = cacheResizer; + + // The following sensors are created here but their references are not stored in this object, since within + // this object they are not recorded. The sensors are created here so that the stream thread starts with all + // its metrics initialised. Otherwise, those sensors would have been created during processing, which could + // lead to missing metrics. If no tasks were created, the metrics for created and closed + // tasks would never be added to the metrics.
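For orientation, the ProcessingMode helpers defined earlier in this class can be exercised roughly as follows; this is a sketch, the property values are hypothetical, and StreamsConfig needs at least an application id and bootstrap servers.

    // Illustrative only: "exactly_once_v2" maps to ProcessingMode.EXACTLY_ONCE_V2, so EOS is treated as enabled.
    final Map<String, Object> props = new HashMap<>();
    props.put(StreamsConfig.APPLICATION_ID_CONFIG, "example-app");        // assumed application id
    props.put(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9092");  // assumed broker address
    props.put(StreamsConfig.PROCESSING_GUARANTEE_CONFIG, StreamsConfig.EXACTLY_ONCE_V2);
    final StreamsConfig config = new StreamsConfig(props);
    final StreamThread.ProcessingMode mode = StreamThread.processingMode(config);
    System.out.println(mode + ", eosEnabled=" + StreamThread.eosEnabled(config));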
+ ThreadMetrics.createTaskSensor(threadId, streamsMetrics); + ThreadMetrics.closeTaskSensor(threadId, streamsMetrics); + + ThreadMetrics.addThreadStartTimeMetric( + threadId, + streamsMetrics, + time.milliseconds() + ); + ThreadMetrics.addThreadBlockedTimeMetric( + threadId, + new StreamThreadTotalBlockedTime( + mainConsumer, + restoreConsumer, + taskManager::totalProducerBlockedTime + ), + streamsMetrics + ); + + this.time = time; + this.topologyMetadata = topologyMetadata; + this.logPrefix = logContext.logPrefix(); + this.log = logContext.logger(StreamThread.class); + this.rebalanceListener = new StreamsRebalanceListener(time, taskManager, this, this.log, this.assignmentErrorCode); + this.taskManager = taskManager; + this.restoreConsumer = restoreConsumer; + this.mainConsumer = mainConsumer; + this.changelogReader = changelogReader; + this.originalReset = originalReset; + this.nextProbingRebalanceMs = nextProbingRebalanceMs; + this.getGroupInstanceID = mainConsumer.groupMetadata().groupInstanceId(); + + this.pollTime = Duration.ofMillis(config.getLong(StreamsConfig.POLL_MS_CONFIG)); + final int dummyThreadIdx = 1; + this.maxPollTimeMs = new InternalConsumerConfig(config.getMainConsumerConfigs("dummyGroupId", "dummyClientId", dummyThreadIdx)) + .getInt(ConsumerConfig.MAX_POLL_INTERVAL_MS_CONFIG); + this.commitTimeMs = config.getLong(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG); + + this.numIterations = 1; + this.eosEnabled = eosEnabled(config); + } + + private static final class InternalConsumerConfig extends ConsumerConfig { + private InternalConsumerConfig(final Map props) { + super(ConsumerConfig.appendDeserializerToConfig(props, new ByteArrayDeserializer(), + new ByteArrayDeserializer()), false); + } + } + + /** + * Execute the stream processors + * + * @throws KafkaException for any Kafka-related exceptions + * @throws RuntimeException for any other non-Kafka exceptions + */ + @Override + public void run() { + log.info("Starting"); + if (setState(State.STARTING) == null) { + log.info("StreamThread already shutdown. Not running"); + return; + } + boolean cleanRun = false; + try { + cleanRun = runLoop(); + } catch (final Throwable e) { + failedStreamThreadSensor.record(); + this.streamsUncaughtExceptionHandler.accept(e); + } finally { + completeShutdown(cleanRun); + } + } + + /** + * Main event loop for polling, and processing records through topologies. + * + * @throws IllegalStateException If store gets registered after initialized is already finished + * @throws StreamsException if the store's change log does not contain the partition + */ + @SuppressWarnings("deprecation") // Needed to include StreamsConfig.EXACTLY_ONCE_BETA in error log for UnsupportedVersionException + boolean runLoop() { + subscribeConsumer(); + + // if the thread is still in the middle of a rebalance, we should keep polling + // until the rebalance is completed before we close and commit the tasks + while (isRunning() || taskManager.isRebalanceInProgress()) { + try { + maybeSendShutdown(); + final long size = cacheResizeSize.getAndSet(-1L); + if (size != -1L) { + cacheResizer.accept(size); + } + runOnce(); + if (nextProbingRebalanceMs.get() < time.milliseconds()) { + log.info("Triggering the followup rebalance scheduled for {} ms.", nextProbingRebalanceMs.get()); + mainConsumer.enforceRebalance(); + nextProbingRebalanceMs.set(Long.MAX_VALUE); + } + } catch (final TaskCorruptedException e) { + log.warn("Detected the states of tasks " + e.corruptedTasks() + " are corrupted. 
" + + "Will close the task as dirty and re-create and bootstrap from scratch.", e); + try { + // check if any active task got corrupted. We will trigger a rebalance in that case. + // once the task corruptions have been handled + final boolean enforceRebalance = taskManager.handleCorruption(e.corruptedTasks()); + if (enforceRebalance && eosEnabled) { + log.info("Active task(s) got corrupted. Triggering a rebalance."); + mainConsumer.enforceRebalance(); + } + } catch (final TaskMigratedException taskMigrated) { + handleTaskMigrated(taskMigrated); + } + } catch (final TaskMigratedException e) { + handleTaskMigrated(e); + } catch (final UnsupportedVersionException e) { + final String errorMessage = e.getMessage(); + if (errorMessage != null && + errorMessage.startsWith("Broker unexpectedly doesn't support requireStable flag on version ")) { + + log.error("Shutting down because the Kafka cluster seems to be on a too old version. " + + "Setting {}=\"{}\"/\"{}\" requires broker version 2.5 or higher.", + StreamsConfig.PROCESSING_GUARANTEE_CONFIG, + StreamsConfig.EXACTLY_ONCE_V2, StreamsConfig.EXACTLY_ONCE_BETA); + } + failedStreamThreadSensor.record(); + this.streamsUncaughtExceptionHandler.accept(new StreamsException(e)); + return false; + } catch (final StreamsException e) { + throw e; + } catch (final Exception e) { + throw new StreamsException(e); + } + } + return true; + } + + /** + * Sets the streams uncaught exception handler. + * + * @param streamsUncaughtExceptionHandler the user handler wrapped in shell to execute the action + */ + public void setStreamsUncaughtExceptionHandler(final java.util.function.Consumer streamsUncaughtExceptionHandler) { + this.streamsUncaughtExceptionHandler = streamsUncaughtExceptionHandler; + } + + public void maybeSendShutdown() { + if (assignmentErrorCode.get() == AssignorError.SHUTDOWN_REQUESTED.code()) { + log.warn("Detected that shutdown was requested. " + + "All clients in this app will now begin to shutdown"); + mainConsumer.enforceRebalance(); + } + } + + public boolean waitOnThreadState(final StreamThread.State targetState, final long timeoutMs) { + final long begin = time.milliseconds(); + synchronized (stateLock) { + boolean interrupted = false; + long elapsedMs = 0L; + try { + while (state != targetState) { + if (timeoutMs >= elapsedMs) { + final long remainingMs = timeoutMs - elapsedMs; + try { + stateLock.wait(remainingMs); + } catch (final InterruptedException e) { + interrupted = true; + } + } else { + log.debug("Cannot transit to {} within {}ms", targetState, timeoutMs); + return false; + } + elapsedMs = time.milliseconds() - begin; + } + return true; + } finally { + // Make sure to restore the interruption status before returning. + // We do not always own the current thread that executes this method, i.e., we do not know the + // interruption policy of the thread. The least we can do is restore the interruption status before + // the current thread exits this method. + if (interrupted) { + Thread.currentThread().interrupt(); + } + } + } + } + + public void shutdownToError() { + shutdownErrorHook.run(); + } + + public void sendShutdownRequest(final AssignorError assignorError) { + assignmentErrorCode.set(assignorError.code()); + } + + private void handleTaskMigrated(final TaskMigratedException e) { + log.warn("Detected that the thread is being fenced. " + + "This implies that this thread missed a rebalance and dropped out of the consumer group. 
" + + "Will close out all assigned tasks and rejoin the consumer group.", e); + + taskManager.handleLostAll(); + mainConsumer.unsubscribe(); + subscribeConsumer(); + } + + private void subscribeConsumer() { + if (topologyMetadata.usesPatternSubscription()) { + mainConsumer.subscribe(topologyMetadata.sourceTopicPattern(), rebalanceListener); + } else { + mainConsumer.subscribe(topologyMetadata.sourceTopicCollection(), rebalanceListener); + } + } + + public void resizeCache(final long size) { + cacheResizeSize.set(size); + } + + /** + * One iteration of a thread includes the following steps: + * + * 1. poll records from main consumer and add to buffer; + * 2. restore from restore consumer and update standby tasks if necessary; + * 3. process active tasks from the buffers; + * 4. punctuate active tasks if necessary; + * 5. commit all tasks if necessary; + * + * Among them, step 3/4/5 is done in batches in which we try to process as much as possible while trying to + * stop iteration to call the next iteration when it's close to the next main consumer's poll deadline + * + * @throws IllegalStateException If store gets registered after initialized is already finished + * @throws StreamsException If the store's change log does not contain the partition + * @throws TaskMigratedException If another thread wrote to the changelog topic that is currently restored + * or if committing offsets failed (non-EOS) + * or if the task producer got fenced (EOS) + */ + // Visible for testing + void runOnce() { + final long startMs = time.milliseconds(); + now = startMs; + + final long pollLatency = pollPhase(); + + // Shutdown hook could potentially be triggered and transit the thread state to PENDING_SHUTDOWN during #pollRequests(). + // The task manager internal states could be uninitialized if the state transition happens during #onPartitionsAssigned(). + // Should only proceed when the thread is still running after #pollRequests(), because no external state mutation + // could affect the task manager state beyond this point within #runOnce(). + if (!isRunning()) { + log.debug("Thread state is already {}, skipping the run once call after poll request", state); + return; + } + + initializeAndRestorePhase(); + + // TODO: we should record the restore latency and its relative time spent ratio after + // we figure out how to move this method out of the stream thread + advanceNowAndComputeLatency(); + + int totalProcessed = 0; + long totalCommitLatency = 0L; + long totalProcessLatency = 0L; + long totalPunctuateLatency = 0L; + if (state == State.RUNNING) { + /* + * Within an iteration, after processing up to N (N initialized as 1 upon start up) records for each applicable tasks, check the current time: + * 1. If it is time to punctuate, do it; + * 2. If it is time to commit, do it, this should be after 1) since punctuate may trigger commit; + * 3. If there's no records processed, end the current iteration immediately; + * 4. If we are close to consumer's next poll deadline, end the current iteration immediately; + * 5. If any of 1), 2) and 4) happens, half N for next iteration; + * 6. Otherwise, increment N. 
+ */ + do { + log.debug("Processing tasks with {} iterations.", numIterations); + final int processed = taskManager.process(numIterations, time); + final long processLatency = advanceNowAndComputeLatency(); + totalProcessLatency += processLatency; + if (processed > 0) { + // It makes no difference to the outcome of these metrics when we record "0", + // so we can just avoid the method call when we didn't process anything. + processRateSensor.record(processed, now); + + // This metric is scaled to represent the _average_ processing time of _each_ + // task. Note, it's hard to interpret this as defined; the per-task process-ratio + // as well as total time ratio spent on processing compared with polling / committing etc + // are reported on other metrics. + processLatencySensor.record(processLatency / (double) processed, now); + + totalProcessed += processed; + totalRecordsProcessedSinceLastSummary += processed; + } + + log.debug("Processed {} records with {} iterations; invoking punctuators if necessary", + processed, + numIterations); + + final int punctuated = taskManager.punctuate(); + totalPunctuatorsSinceLastSummary += punctuated; + final long punctuateLatency = advanceNowAndComputeLatency(); + totalPunctuateLatency += punctuateLatency; + if (punctuated > 0) { + punctuateSensor.record(punctuateLatency / (double) punctuated, now); + } + + log.debug("{} punctuators ran.", punctuated); + + final long beforeCommitMs = now; + final int committed = maybeCommit(); + totalCommittedSinceLastSummary += committed; + final long commitLatency = Math.max(now - beforeCommitMs, 0); + totalCommitLatency += commitLatency; + if (committed > 0) { + commitSensor.record(commitLatency / (double) committed, now); + + if (log.isDebugEnabled()) { + log.debug("Committed all active tasks {} and standby tasks {} in {}ms", + taskManager.activeTaskIds(), taskManager.standbyTaskIds(), commitLatency); + } + } + + if (processed == 0) { + // if there are no records to be processed, exit after punctuate / commit + break; + } else if (Math.max(now - lastPollMs, 0) > maxPollTimeMs / 2) { + numIterations = numIterations > 1 ? numIterations / 2 : numIterations; + break; + } else if (punctuated > 0 || committed > 0) { + numIterations = numIterations > 1 ? 
numIterations / 2 : numIterations; + } else { + numIterations++; + } + } while (true); + + // we record the ratio out of the while loop so that the accumulated latency spans over + // multiple iterations with reasonably large max.num.records and hence is less vulnerable to outliers + taskManager.recordTaskProcessRatio(totalProcessLatency, now); + } + + now = time.milliseconds(); + final long runOnceLatency = now - startMs; + processRecordsSensor.record(totalProcessed, now); + processRatioSensor.record((double) totalProcessLatency / runOnceLatency, now); + punctuateRatioSensor.record((double) totalPunctuateLatency / runOnceLatency, now); + pollRatioSensor.record((double) pollLatency / runOnceLatency, now); + commitRatioSensor.record((double) totalCommitLatency / runOnceLatency, now); + + final boolean logProcessingSummary = now - lastLogSummaryMs > LOG_SUMMARY_INTERVAL_MS; + if (logProcessingSummary) { + log.info("Processed {} total records, ran {} punctuators, and committed {} total tasks since the last update", + totalRecordsProcessedSinceLastSummary, totalPunctuatorsSinceLastSummary, totalCommittedSinceLastSummary); + + totalRecordsProcessedSinceLastSummary = 0L; + totalPunctuatorsSinceLastSummary = 0L; + totalCommittedSinceLastSummary = 0L; + lastLogSummaryMs = now; + } + } + + private void initializeAndRestorePhase() { + // only try to initialize the assigned tasks + // if the state is still in PARTITION_ASSIGNED after the poll call + final State stateSnapshot = state; + if (stateSnapshot == State.PARTITIONS_ASSIGNED + || stateSnapshot == State.RUNNING && taskManager.needsInitializationOrRestoration()) { + + log.debug("State is {}; initializing tasks if necessary", stateSnapshot); + + // transit to restore active is idempotent so we can call it multiple times + changelogReader.enforceRestoreActive(); + + if (taskManager.tryToCompleteRestoration(now, partitions -> resetOffsets(partitions, null))) { + changelogReader.transitToUpdateStandby(); + log.info("Restoration took {} ms for all tasks {}", time.milliseconds() - lastPartitionAssignedMs, + taskManager.tasks().keySet()); + setState(State.RUNNING); + } + + if (log.isDebugEnabled()) { + log.debug("Initialization call done. State is {}", state); + } + } + + if (log.isDebugEnabled()) { + log.debug("Idempotently invoking restoration logic in state {}", state); + } + // we can always let changelog reader try restoring in order to initialize the changelogs; + // if there's no active restoring or standby updating it would not try to fetch any data + changelogReader.restore(taskManager.tasks()); + log.debug("Idempotent restore call done. 
Thread state has not changed."); + } + + // Check if the topology has been updated since we last checked, ie via #addNamedTopology or #removeNamedTopology + private void checkForTopologyUpdates() { + if (lastSeenTopologyVersion < topologyMetadata.topologyVersion() || topologyMetadata.isEmpty()) { + lastSeenTopologyVersion = topologyMetadata.topologyVersion(); + taskManager.handleTopologyUpdates(); + + topologyMetadata.maybeWaitForNonEmptyTopology(() -> state); + + // TODO KAFKA-12648 Pt.4: optimize to avoid always triggering a rebalance for each thread on every update + log.info("StreamThread has detected an update to the topology, triggering a rebalance to refresh the assignment"); + subscribeConsumer(); + mainConsumer.enforceRebalance(); + } + } + + private long pollPhase() { + checkForTopologyUpdates(); + + final ConsumerRecords records; + log.debug("Invoking poll on main Consumer"); + + if (state == State.PARTITIONS_ASSIGNED) { + // try to fetch some records with zero poll millis + // to unblock the restoration as soon as possible + records = pollRequests(Duration.ZERO); + } else if (state == State.PARTITIONS_REVOKED) { + // try to fetch some records with zero poll millis to unblock + // other useful work while waiting for the join response + records = pollRequests(Duration.ZERO); + } else if (state == State.RUNNING || state == State.STARTING) { + // try to fetch some records with normal poll time + // in order to get long polling + records = pollRequests(pollTime); + } else if (state == State.PENDING_SHUTDOWN) { + // we are only here because there's rebalance in progress, + // just poll with zero to complete it + records = pollRequests(Duration.ZERO); + } else { + // any other state should not happen + log.error("Unexpected state {} during normal iteration", state); + throw new StreamsException(logPrefix + "Unexpected state " + state + " during normal iteration"); + } + + final long pollLatency = advanceNowAndComputeLatency(); + + final int numRecords = records.count(); + + for (final TopicPartition topicPartition: records.partitions()) { + records + .records(topicPartition) + .stream() + .max(Comparator.comparing(ConsumerRecord::offset)) + .ifPresent(t -> taskManager.updateTaskEndMetadata(topicPartition, t.offset())); + } + + log.debug("Main Consumer poll completed in {} ms and fetched {} records", pollLatency, numRecords); + + pollSensor.record(pollLatency, now); + + if (!records.isEmpty()) { + pollRecordsSensor.record(numRecords, now); + taskManager.addRecordsToTasks(records); + } + return pollLatency; + } + + /** + * Get the next batch of records by polling. + * + * @param pollTime how long to block in Consumer#poll + * @return Next batch of records or null if no records available. 
+ * @throws TaskMigratedException if the task producer got fenced (EOS only) + */ + private ConsumerRecords pollRequests(final Duration pollTime) { + ConsumerRecords records = ConsumerRecords.empty(); + + lastPollMs = now; + + try { + records = mainConsumer.poll(pollTime); + } catch (final InvalidOffsetException e) { + resetOffsets(e.partitions(), e); + } + + return records; + } + + private void resetOffsets(final Set partitions, final Exception cause) { + final Set loggedTopics = new HashSet<>(); + final Set seekToBeginning = new HashSet<>(); + final Set seekToEnd = new HashSet<>(); + final Set notReset = new HashSet<>(); + + for (final TopicPartition partition : partitions) { + switch (topologyMetadata.offsetResetStrategy(partition.topic())) { + case EARLIEST: + addToResetList(partition, seekToBeginning, "Setting topic '{}' to consume from {} offset", "earliest", loggedTopics); + break; + case LATEST: + addToResetList(partition, seekToEnd, "Setting topic '{}' to consume from {} offset", "latest", loggedTopics); + break; + case NONE: + if ("earliest".equals(originalReset)) { + addToResetList(partition, seekToBeginning, "No custom setting defined for topic '{}' using original config '{}' for offset reset", "earliest", loggedTopics); + } else if ("latest".equals(originalReset)) { + addToResetList(partition, seekToEnd, "No custom setting defined for topic '{}' using original config '{}' for offset reset", "latest", loggedTopics); + } else { + notReset.add(partition); + } + break; + default: + throw new IllegalStateException("Unable to locate topic " + partition.topic() + " in the topology"); + } + } + + if (notReset.isEmpty()) { + if (!seekToBeginning.isEmpty()) { + mainConsumer.seekToBeginning(seekToBeginning); + } + + if (!seekToEnd.isEmpty()) { + mainConsumer.seekToEnd(seekToEnd); + } + } else { + final String notResetString = + notReset.stream() + .map(TopicPartition::topic) + .distinct() + .collect(Collectors.joining(",")); + + final String format = String.format( + "No valid committed offset found for input [%s] and no valid reset policy configured." + + " You need to set configuration parameter \"auto.offset.reset\" or specify a topic specific reset " + + "policy via StreamsBuilder#stream(..., Consumed.with(Topology.AutoOffsetReset)) or " + + "StreamsBuilder#table(..., Consumed.with(Topology.AutoOffsetReset))", + notResetString + ); + + if (cause == null) { + throw new StreamsException(format); + } else { + throw new StreamsException(format, cause); + } + } + } + + private void addToResetList(final TopicPartition partition, final Set partitions, final String logMessage, final String resetPolicy, final Set loggedTopics) { + final String topic = partition.topic(); + if (loggedTopics.add(topic)) { + log.info(logMessage, topic, resetPolicy); + } + partitions.add(partition); + } + + /** + * Try to commit all active tasks owned by this thread. + * + * Visible for testing. 
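The reset-policy error message above points users at a per-topic override; in application code that override is typically declared along the lines of the following sketch, where the topic name and serdes are assumptions for illustration.

    // Illustrative only: pin the reset policy for one input topic to "earliest".
    final StreamsBuilder builder = new StreamsBuilder();
    final KStream<String, String> input = builder.stream(
        "input-topic",                                             // hypothetical topic
        Consumed.with(Serdes.String(), Serdes.String())
                .withOffsetResetPolicy(Topology.AutoOffsetReset.EARLIEST));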
+ * + * @throws TaskMigratedException if committing offsets failed (non-EOS) + * or if the task producer got fenced (EOS) + */ + int maybeCommit() { + final int committed; + if (now - lastCommitMs > commitTimeMs) { + if (log.isDebugEnabled()) { + log.debug("Committing all active tasks {} and standby tasks {} since {}ms has elapsed (commit interval is {}ms)", + taskManager.activeTaskIds(), taskManager.standbyTaskIds(), now - lastCommitMs, commitTimeMs); + } + + committed = taskManager.commit( + taskManager.tasks() + .values() + .stream() + .filter(t -> t.state() == Task.State.RUNNING || t.state() == Task.State.RESTORING) + .collect(Collectors.toSet()) + ); + + if (committed > 0) { + // try to purge the committed records for repartition topics if possible + taskManager.maybePurgeCommittedRecords(); + } + + if (committed == -1) { + log.debug("Unable to commit as we are in the middle of a rebalance, will try again when it completes."); + } else { + now = time.milliseconds(); + lastCommitMs = now; + } + } else { + committed = taskManager.maybeCommitActiveTasksPerUserRequested(); + } + + return committed; + } + + /** + * Compute the latency based on the current marked timestamp, and update the marked timestamp + * with the current system timestamp. + * + * @return latency + */ + private long advanceNowAndComputeLatency() { + final long previous = now; + now = time.milliseconds(); + + return Math.max(now - previous, 0); + } + + /** + * Shutdown this stream thread. + *
                + * Note that there is nothing to prevent this function from being called multiple times + * (e.g., in testing), hence the state is set only the first time + */ + public void shutdown() { + log.info("Informed to shut down"); + final State oldState = setState(State.PENDING_SHUTDOWN); + if (oldState == State.CREATED) { + // The thread may not have been started. Take responsibility for shutting down + completeShutdown(true); + } + } + + private void completeShutdown(final boolean cleanRun) { + // set the state to pending shutdown first as it may be called due to error; + // its state may already be PENDING_SHUTDOWN so it will return false but we + // intentionally do not check the returned flag + setState(State.PENDING_SHUTDOWN); + + log.info("Shutting down"); + + try { + taskManager.shutdown(cleanRun); + } catch (final Throwable e) { + log.error("Failed to close task manager due to the following error:", e); + } + try { + changelogReader.clear(); + } catch (final Throwable e) { + log.error("Failed to close changelog reader due to the following error:", e); + } + if (leaveGroupRequested.get()) { + mainConsumer.unsubscribe(); + } + try { + mainConsumer.close(); + } catch (final Throwable e) { + log.error("Failed to close consumer due to the following error:", e); + } + try { + restoreConsumer.close(); + } catch (final Throwable e) { + log.error("Failed to close restore consumer due to the following error:", e); + } + streamsMetrics.removeAllThreadLevelSensors(getName()); + streamsMetrics.removeAllThreadLevelMetrics(getName()); + + setState(State.DEAD); + + log.info("Shutdown complete"); + } + + /** + * Return information about the current {@link StreamThread}. + * + * @return {@link ThreadMetadata}. + */ + public final ThreadMetadata threadMetadata() { + return threadMetadata; + } + + // package-private for testing only + StreamThread updateThreadMetadata(final String adminClientId) { + + threadMetadata = new ThreadMetadataImpl( + getName(), + state().name(), + getConsumerClientId(getName()), + getRestoreConsumerClientId(getName()), + taskManager.producerClientIds(), + adminClientId, + Collections.emptySet(), + Collections.emptySet()); + + return this; + } + + private void updateThreadMetadata(final Map activeTasks, + final Map standbyTasks) { + final Set activeTasksMetadata = new HashSet<>(); + for (final Map.Entry task : activeTasks.entrySet()) { + activeTasksMetadata.add(new TaskMetadataImpl( + task.getValue().id(), + task.getValue().inputPartitions(), + task.getValue().committedOffsets(), + task.getValue().highWaterMark(), + task.getValue().timeCurrentIdlingStarted() + )); + } + final Set standbyTasksMetadata = new HashSet<>(); + for (final Map.Entry task : standbyTasks.entrySet()) { + standbyTasksMetadata.add(new TaskMetadataImpl( + task.getValue().id(), + task.getValue().inputPartitions(), + task.getValue().committedOffsets(), + task.getValue().highWaterMark(), + task.getValue().timeCurrentIdlingStarted() + )); + } + + final String adminClientId = threadMetadata.adminClientId(); + threadMetadata = new ThreadMetadataImpl( + getName(), + state().name(), + getConsumerClientId(getName()), + getRestoreConsumerClientId(getName()), + taskManager.producerClientIds(), + adminClientId, + activeTasksMetadata, + standbyTasksMetadata + ); + } + + public Map activeTaskMap() { + return taskManager.activeTaskMap(); + } + + public List activeTasks() { + return taskManager.activeTaskIterable(); + } + + public Map allTasks() { + return taskManager.tasks(); + } + + /** + * Produces a string 
representation containing useful information about a StreamThread. + * This is useful in debugging scenarios. + * + * @return A string representation of the StreamThread instance. + */ + @Override + public String toString() { + return toString(""); + } + + /** + * Produces a string representation containing useful information about a StreamThread, starting with the given indent. + * This is useful in debugging scenarios. + * + * @return A string representation of the StreamThread instance. + */ + public String toString(final String indent) { + return indent + "\tStreamsThread threadId: " + getName() + "\n" + taskManager.toString(indent); + } + + public Optional getGroupInstanceID() { + return getGroupInstanceID; + } + + public void requestLeaveGroupDuringShutdown() { + this.leaveGroupRequested.set(true); + } + + public Map producerMetrics() { + return taskManager.producerMetrics(); + } + + public Map consumerMetrics() { + return ClientUtils.consumerMetrics(mainConsumer, restoreConsumer); + } + + public Map adminClientMetrics() { + return ClientUtils.adminClientMetrics(adminClient); + } + + public Object getStateLock() { + return stateLock; + } + + // the following are for testing only + void setNow(final long now) { + this.now = now; + } + + TaskManager taskManager() { + return taskManager; + } + + int currentNumIterations() { + return numIterations; + } + + ConsumerRebalanceListener rebalanceListener() { + return rebalanceListener; + } + + Consumer mainConsumer() { + return mainConsumer; + } + + Consumer restoreConsumer() { + return restoreConsumer; + } + + Admin adminClient() { + return adminClient; + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThreadTotalBlockedTime.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThreadTotalBlockedTime.java new file mode 100644 index 0000000..dc21615 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamThreadTotalBlockedTime.java @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.streams.processor.internals; + +import java.util.Map; +import java.util.function.Supplier; +import org.apache.kafka.clients.consumer.Consumer; +import org.apache.kafka.common.Metric; +import org.apache.kafka.common.MetricName; + +public class StreamThreadTotalBlockedTime { + private final Consumer consumer; + private final Consumer restoreConsumer; + private final Supplier producerTotalBlockedTime; + + StreamThreadTotalBlockedTime( + final Consumer consumer, + final Consumer restoreConsumer, + final Supplier producerTotalBlockedTime) { + this.consumer = consumer; + this.restoreConsumer = restoreConsumer; + this.producerTotalBlockedTime = producerTotalBlockedTime; + } + + private double metricValue( + final Map metrics, + final String name) { + return metrics.keySet().stream() + .filter(n -> n.name().equals(name)) + .findFirst() + .map(n -> (Double) metrics.get(n).metricValue()) + .orElse(0.0); + } + + public double compute() { + return metricValue(consumer.metrics(), "io-wait-time-ns-total") + + metricValue(consumer.metrics(), "io-time-ns-total") + + metricValue(consumer.metrics(), "committed-time-ns-total") + + metricValue(consumer.metrics(), "commit-sync-time-ns-total") + + metricValue(restoreConsumer.metrics(), "io-wait-time-ns-total") + + metricValue(restoreConsumer.metrics(), "io-time-ns-total") + + producerTotalBlockedTime.get(); + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamsMetadataState.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamsMetadataState.java new file mode 100644 index 0000000..8cc1016 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamsMetadataState.java @@ -0,0 +1,344 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
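The compute() method above returns a cumulative count of nanoseconds the thread has spent blocked. One hedged sketch of turning such a counter into a utilization-style ratio over a sampling window follows; the blockedTime variable and the sampling approach are assumptions, not part of this patch.

    // Illustrative only: fraction of a sampling window spent blocked.
    final double blockedAtStartNs = blockedTime.compute();   // blockedTime: an existing StreamThreadTotalBlockedTime (assumed)
    final long windowStartNs = System.nanoTime();
    // ... let the stream thread run for the sampling window ...
    final double blockedRatio =
        (blockedTime.compute() - blockedAtStartNs) / (System.nanoTime() - windowStartNs);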
+ */ +package org.apache.kafka.streams.processor.internals; + +import java.util.concurrent.atomic.AtomicReference; +import java.util.stream.Stream; +import org.apache.kafka.common.Cluster; +import org.apache.kafka.common.PartitionInfo; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.serialization.Serializer; +import org.apache.kafka.streams.KafkaStreams; +import org.apache.kafka.streams.KeyQueryMetadata; +import org.apache.kafka.streams.processor.StateStore; +import org.apache.kafka.streams.processor.StreamPartitioner; +import org.apache.kafka.streams.state.HostInfo; +import org.apache.kafka.streams.StreamsMetadata; +import org.apache.kafka.streams.state.internals.StreamsMetadataImpl; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Set; + + +/** + * Provides access to the {@link StreamsMetadata} in a KafkaStreams application. This can be used + * to discover the locations of {@link org.apache.kafka.streams.processor.StateStore}s + * in a KafkaStreams application + */ +public class StreamsMetadataState { + public static final HostInfo UNKNOWN_HOST = HostInfo.unavailable(); + private final TopologyMetadata topologyMetadata; + private final Set globalStores; + private final HostInfo thisHost; + private List allMetadata = Collections.emptyList(); + private Cluster clusterMetadata; + private final AtomicReference localMetadata = new AtomicReference<>(null); + + public StreamsMetadataState(final TopologyMetadata topologyMetadata, final HostInfo thisHost) { + this.topologyMetadata = topologyMetadata; + this.globalStores = this.topologyMetadata.globalStateStores().keySet(); + this.thisHost = thisHost; + } + + @Override + public String toString() { + return toString(""); + } + + public String toString(final String indent) { + final StringBuilder builder = new StringBuilder(); + builder.append(indent).append("GlobalMetadata: ").append(allMetadata).append("\n"); + builder.append(indent).append("GlobalStores: ").append(globalStores).append("\n"); + builder.append(indent).append("My HostInfo: ").append(thisHost).append("\n"); + builder.append(indent).append(clusterMetadata).append("\n"); + + return builder.toString(); + } + + /** + * Get the {@link StreamsMetadata}s for the local instance in a {@link KafkaStreams application} + * + * @return the {@link StreamsMetadata}s for the local instance in a {@link KafkaStreams} application + */ + public StreamsMetadata getLocalMetadata() { + return localMetadata.get(); + } + + /** + * Find all of the {@link StreamsMetadata}s in a + * {@link KafkaStreams application} + * + * @return all the {@link StreamsMetadata}s in a {@link KafkaStreams} application + */ + public Collection getAllMetadata() { + return Collections.unmodifiableList(allMetadata); + } + + /** + * Find all of the {@link StreamsMetadata}s for a given storeName + * + * @param storeName the storeName to find metadata for + * @return A collection of {@link StreamsMetadata} that have the provided storeName + */ + public synchronized Collection getAllMetadataForStore(final String storeName) { + Objects.requireNonNull(storeName, "storeName cannot be null"); + + if (!isInitialized()) { + return Collections.emptyList(); + } + + if (globalStores.contains(storeName)) { + return allMetadata; + } + + final Collection sourceTopics = topologyMetadata.sourceTopicsForStore(storeName); + if (sourceTopics.isEmpty()) { + 
return Collections.emptyList(); + } + + final ArrayList results = new ArrayList<>(); + for (final StreamsMetadata metadata : allMetadata) { + if (metadata.stateStoreNames().contains(storeName) || metadata.standbyStateStoreNames().contains(storeName)) { + results.add(metadata); + } + } + return results; + } + + /** + * Find the {@link KeyQueryMetadata}s for a given storeName and key. This method will use the + * {@link DefaultStreamPartitioner} to locate the store. If a custom partitioner has been used + * please use {@link StreamsMetadataState#getKeyQueryMetadataForKey(String, Object, StreamPartitioner)} instead. + * + * Note: the key may not exist in the {@link org.apache.kafka.streams.processor.StateStore}, + * this method provides a way of finding which {@link KeyQueryMetadata} it would exist on. + * + * @param storeName Name of the store + * @param key Key to use + * @param keySerializer Serializer for the key + * @param key type + * @return The {@link KeyQueryMetadata} for the storeName and key or {@link KeyQueryMetadata#NOT_AVAILABLE} + * if streams is (re-)initializing or {@code null} if the corresponding topic cannot be found, + * or null if no matching metadata could be found. + */ + public synchronized KeyQueryMetadata getKeyQueryMetadataForKey(final String storeName, + final K key, + final Serializer keySerializer) { + Objects.requireNonNull(keySerializer, "keySerializer can't be null"); + return getKeyQueryMetadataForKey(storeName, + key, + new DefaultStreamPartitioner<>(keySerializer, clusterMetadata)); + } + + /** + * Find the {@link KeyQueryMetadata}s for a given storeName and key + * + * Note: the key may not exist in the {@link StateStore},this method provides a way of finding which + * {@link StreamsMetadata} it would exist on. + * + * @param storeName Name of the store + * @param key Key to use + * @param partitioner partitioner to use to find correct partition for key + * @param key type + * @return The {@link KeyQueryMetadata} for the storeName and key or {@link KeyQueryMetadata#NOT_AVAILABLE} + * if streams is (re-)initializing, or {@code null} if no matching metadata could be found. + */ + public synchronized KeyQueryMetadata getKeyQueryMetadataForKey(final String storeName, + final K key, + final StreamPartitioner partitioner) { + Objects.requireNonNull(storeName, "storeName can't be null"); + Objects.requireNonNull(key, "key can't be null"); + Objects.requireNonNull(partitioner, "partitioner can't be null"); + + if (!isInitialized()) { + return KeyQueryMetadata.NOT_AVAILABLE; + } + + if (globalStores.contains(storeName)) { + // global stores are on every node. if we don't have the host info + // for this host then just pick the first metadata + if (thisHost.equals(UNKNOWN_HOST)) { + return new KeyQueryMetadata(allMetadata.get(0).hostInfo(), Collections.emptySet(), -1); + } + return new KeyQueryMetadata(localMetadata.get().hostInfo(), Collections.emptySet(), -1); + } + + final SourceTopicsInfo sourceTopicsInfo = getSourceTopicsInfo(storeName); + if (sourceTopicsInfo == null) { + return null; + } + return getKeyQueryMetadataForKey(storeName, key, partitioner, sourceTopicsInfo); + } + + /** + * Respond to changes to the HostInfo -> TopicPartition mapping. 
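For illustration, the key-to-host lookup described above could be exercised roughly as follows; the store name and key are hypothetical, and the serializer comes from the usual Serdes helpers.

    // Illustrative only: find the active and standby hosts for a key, using getKeyQueryMetadataForKey defined above.
    final KeyQueryMetadata metadata = streamsMetadataState.getKeyQueryMetadataForKey(
        "word-counts",                      // hypothetical store name
        "hello",                            // hypothetical key
        Serdes.String().serializer());
    if (metadata != null && !metadata.equals(KeyQueryMetadata.NOT_AVAILABLE)) {
        final HostInfo active = metadata.activeHost();
        final Set<HostInfo> standbys = metadata.standbyHosts();
        // route the interactive query to 'active', falling back to 'standbys' if needed
    }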
Will rebuild the + * metadata + * + * @param activePartitionHostMap the current mapping of {@link HostInfo} -> {@link TopicPartition}s for active partitions + * @param standbyPartitionHostMap the current mapping of {@link HostInfo} -> {@link TopicPartition}s for standby partitions + * @param clusterMetadata the current clusterMetadata {@link Cluster} + */ + synchronized void onChange(final Map> activePartitionHostMap, + final Map> standbyPartitionHostMap, + final Cluster clusterMetadata) { + this.clusterMetadata = clusterMetadata; + rebuildMetadata(activePartitionHostMap, standbyPartitionHostMap); + } + + private boolean hasPartitionsForAnyTopics(final List topicNames, final Set partitionForHost) { + for (final TopicPartition topicPartition : partitionForHost) { + if (topicNames.contains(topicPartition.topic())) { + return true; + } + } + return false; + } + + private Set getStoresOnHost(final Map> storeToSourceTopics, final Set sourceTopicPartitions) { + final Set storesOnHost = new HashSet<>(); + for (final Map.Entry> storeTopicEntry : storeToSourceTopics.entrySet()) { + final List topicsForStore = storeTopicEntry.getValue(); + if (hasPartitionsForAnyTopics(topicsForStore, sourceTopicPartitions)) { + storesOnHost.add(storeTopicEntry.getKey()); + } + } + return storesOnHost; + } + + + private void rebuildMetadata(final Map> activePartitionHostMap, + final Map> standbyPartitionHostMap) { + if (activePartitionHostMap.isEmpty() && standbyPartitionHostMap.isEmpty()) { + allMetadata = Collections.emptyList(); + localMetadata.set(new StreamsMetadataImpl( + thisHost, + Collections.emptySet(), + Collections.emptySet(), + Collections.emptySet(), + Collections.emptySet() + )); + return; + } + + final List rebuiltMetadata = new ArrayList<>(); + final Map> storeToSourceTopics = topologyMetadata.stateStoreNameToSourceTopics(); + Stream.concat(activePartitionHostMap.keySet().stream(), standbyPartitionHostMap.keySet().stream()) + .distinct() + .forEach(hostInfo -> { + final Set activePartitionsOnHost = new HashSet<>(); + final Set activeStoresOnHost = new HashSet<>(); + if (activePartitionHostMap.containsKey(hostInfo)) { + activePartitionsOnHost.addAll(activePartitionHostMap.get(hostInfo)); + activeStoresOnHost.addAll(getStoresOnHost(storeToSourceTopics, activePartitionsOnHost)); + } + activeStoresOnHost.addAll(globalStores); + + final Set standbyPartitionsOnHost = new HashSet<>(); + final Set standbyStoresOnHost = new HashSet<>(); + if (standbyPartitionHostMap.containsKey(hostInfo)) { + standbyPartitionsOnHost.addAll(standbyPartitionHostMap.get(hostInfo)); + standbyStoresOnHost.addAll(getStoresOnHost(storeToSourceTopics, standbyPartitionsOnHost)); + } + + final StreamsMetadata metadata = new StreamsMetadataImpl( + hostInfo, + activeStoresOnHost, + activePartitionsOnHost, + standbyStoresOnHost, + standbyPartitionsOnHost); + rebuiltMetadata.add(metadata); + if (hostInfo.equals(thisHost)) { + localMetadata.set(metadata); + } + }); + + allMetadata = rebuiltMetadata; + } + + private KeyQueryMetadata getKeyQueryMetadataForKey(final String storeName, + final K key, + final StreamPartitioner partitioner, + final SourceTopicsInfo sourceTopicsInfo) { + + final Integer partition = partitioner.partition(sourceTopicsInfo.topicWithMostPartitions, key, null, sourceTopicsInfo.maxPartitions); + final Set matchingPartitions = new HashSet<>(); + for (final String sourceTopic : sourceTopicsInfo.sourceTopics) { + matchingPartitions.add(new TopicPartition(sourceTopic, partition)); + } + + HostInfo activeHost = UNKNOWN_HOST; 
+ final Set standbyHosts = new HashSet<>(); + for (final StreamsMetadata streamsMetadata : allMetadata) { + final Set activeStateStoreNames = streamsMetadata.stateStoreNames(); + final Set topicPartitions = new HashSet<>(streamsMetadata.topicPartitions()); + final Set standbyStateStoreNames = streamsMetadata.standbyStateStoreNames(); + final Set standbyTopicPartitions = new HashSet<>(streamsMetadata.standbyTopicPartitions()); + + topicPartitions.retainAll(matchingPartitions); + if (activeStateStoreNames.contains(storeName) && !topicPartitions.isEmpty()) { + activeHost = streamsMetadata.hostInfo(); + } + + standbyTopicPartitions.retainAll(matchingPartitions); + if (standbyStateStoreNames.contains(storeName) && !standbyTopicPartitions.isEmpty()) { + standbyHosts.add(streamsMetadata.hostInfo()); + } + } + + return new KeyQueryMetadata(activeHost, standbyHosts, partition); + } + + private SourceTopicsInfo getSourceTopicsInfo(final String storeName) { + final List sourceTopics = new ArrayList<>(topologyMetadata.sourceTopicsForStore(storeName)); + if (sourceTopics.isEmpty()) { + return null; + } + return new SourceTopicsInfo(sourceTopics); + } + + private boolean isInitialized() { + + return clusterMetadata != null && !clusterMetadata.topics().isEmpty() && localMetadata.get() != null; + } + + public String getStoreForChangelogTopic(final String topicName) { + return topologyMetadata.getStoreForChangelogTopic(topicName); + } + + private class SourceTopicsInfo { + private final List sourceTopics; + private int maxPartitions; + private String topicWithMostPartitions; + + private SourceTopicsInfo(final List sourceTopics) { + this.sourceTopics = sourceTopics; + for (final String topic : sourceTopics) { + final List partitions = clusterMetadata.partitionsForTopic(topic); + if (partitions.size() > maxPartitions) { + maxPartitions = partitions.size(); + topicWithMostPartitions = partitions.get(0).topic(); + } + } + } + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamsPartitionAssignor.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamsPartitionAssignor.java new file mode 100644 index 0000000..2ae381d --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamsPartitionAssignor.java @@ -0,0 +1,1451 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.processor.internals; + +import org.apache.kafka.clients.admin.Admin; +import org.apache.kafka.clients.admin.ListOffsetsResult.ListOffsetsResultInfo; +import org.apache.kafka.clients.consumer.Consumer; +import org.apache.kafka.clients.consumer.ConsumerGroupMetadata; +import org.apache.kafka.clients.consumer.ConsumerPartitionAssignor; +import org.apache.kafka.common.Cluster; +import org.apache.kafka.common.Configurable; +import org.apache.kafka.common.KafkaException; +import org.apache.kafka.common.KafkaFuture; +import org.apache.kafka.common.Node; +import org.apache.kafka.common.PartitionInfo; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.errors.TimeoutException; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.streams.errors.MissingSourceTopicException; +import org.apache.kafka.streams.errors.StreamsException; +import org.apache.kafka.streams.errors.TaskAssignmentException; +import org.apache.kafka.streams.processor.TaskId; +import org.apache.kafka.streams.processor.internals.InternalTopologyBuilder.TopicsInfo; +import org.apache.kafka.streams.processor.internals.TopologyMetadata.Subtopology; +import org.apache.kafka.streams.processor.internals.assignment.AssignmentInfo; +import org.apache.kafka.streams.processor.internals.assignment.AssignorConfiguration; +import org.apache.kafka.streams.processor.internals.assignment.AssignorConfiguration.AssignmentConfigs; +import org.apache.kafka.streams.processor.internals.assignment.AssignorConfiguration.AssignmentListener; +import org.apache.kafka.streams.processor.internals.assignment.AssignorError; +import org.apache.kafka.streams.processor.internals.assignment.ClientState; +import org.apache.kafka.streams.processor.internals.assignment.CopartitionedTopicsEnforcer; +import org.apache.kafka.streams.processor.internals.assignment.FallbackPriorTaskAssignor; +import org.apache.kafka.streams.processor.internals.assignment.ReferenceContainer; +import org.apache.kafka.streams.processor.internals.assignment.StickyTaskAssignor; +import org.apache.kafka.streams.processor.internals.assignment.SubscriptionInfo; +import org.apache.kafka.streams.processor.internals.assignment.TaskAssignor; +import org.apache.kafka.streams.state.HostInfo; +import org.slf4j.Logger; + +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.Comparator; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Iterator; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.PriorityQueue; +import java.util.Queue; +import java.util.Set; +import java.util.SortedSet; +import java.util.TreeMap; +import java.util.TreeSet; +import java.util.UUID; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicLong; +import java.util.function.Supplier; +import java.util.stream.Collectors; + +import static java.util.UUID.randomUUID; + +import static org.apache.kafka.common.utils.Utils.filterMap; +import static org.apache.kafka.streams.processor.internals.ClientUtils.fetchCommittedOffsets; +import static org.apache.kafka.streams.processor.internals.ClientUtils.fetchEndOffsetsFuture; +import static org.apache.kafka.streams.processor.internals.assignment.StreamsAssignmentProtocolVersions.EARLIEST_PROBEABLE_VERSION; 
+import static org.apache.kafka.streams.processor.internals.assignment.StreamsAssignmentProtocolVersions.LATEST_SUPPORTED_VERSION; +import static org.apache.kafka.streams.processor.internals.assignment.StreamsAssignmentProtocolVersions.UNKNOWN; +import static org.apache.kafka.streams.processor.internals.assignment.SubscriptionInfo.UNKNOWN_OFFSET_SUM; + +public class StreamsPartitionAssignor implements ConsumerPartitionAssignor, Configurable { + + private Logger log; + private String logPrefix; + + private static class AssignedPartition implements Comparable { + + private final TaskId taskId; + private final TopicPartition partition; + + AssignedPartition(final TaskId taskId, final TopicPartition partition) { + this.taskId = taskId; + this.partition = partition; + } + + @Override + public int compareTo(final AssignedPartition that) { + return PARTITION_COMPARATOR.compare(partition, that.partition); + } + + @Override + public boolean equals(final Object o) { + if (!(o instanceof AssignedPartition)) { + return false; + } + final AssignedPartition other = (AssignedPartition) o; + return compareTo(other) == 0; + } + + @Override + public int hashCode() { + // Only partition is important for compareTo, equals and hashCode. + return partition.hashCode(); + } + } + + private static class ClientMetadata { + + private final HostInfo hostInfo; + private final ClientState state; + private final SortedSet consumers; + + ClientMetadata(final String endPoint) { + + // get the host info, or null if no endpoint is configured (ie endPoint == null) + hostInfo = HostInfo.buildFromEndpoint(endPoint); + + // initialize the consumer memberIds + consumers = new TreeSet<>(); + + // initialize the client state + state = new ClientState(); + } + + void addConsumer(final String consumerMemberId, final List ownedPartitions) { + consumers.add(consumerMemberId); + state.incrementCapacity(); + state.addOwnedPartitions(ownedPartitions, consumerMemberId); + } + + void addPreviousTasksAndOffsetSums(final String consumerId, final Map taskOffsetSums) { + state.addPreviousTasksAndOffsetSums(consumerId, taskOffsetSums); + } + + @Override + public String toString() { + return "ClientMetadata{" + + "hostInfo=" + hostInfo + + ", consumers=" + consumers + + ", state=" + state + + '}'; + } + } + + // keep track of any future consumers in a "dummy" Client since we can't decipher their subscription + private static final UUID FUTURE_ID = randomUUID(); + + protected static final Comparator PARTITION_COMPARATOR = + Comparator.comparing(TopicPartition::topic).thenComparingInt(TopicPartition::partition); + + private String userEndPoint; + private AssignmentConfigs assignmentConfigs; + + // for the main consumer, we need to use a supplier to break a cyclic setup dependency + private Supplier> mainConsumerSupplier; + private Admin adminClient; + private TaskManager taskManager; + private StreamsMetadataState streamsMetadataState; + private PartitionGrouper partitionGrouper; + private AtomicInteger assignmentErrorCode; + private AtomicLong nextScheduledRebalanceMs; + private Time time; + + protected int usedSubscriptionMetadataVersion = LATEST_SUPPORTED_VERSION; + + private InternalTopicManager internalTopicManager; + private CopartitionedTopicsEnforcer copartitionedTopicsEnforcer; + private RebalanceProtocol rebalanceProtocol; + private AssignmentListener assignmentListener; + + private Supplier taskAssignorSupplier; + private byte uniqueField; + + /** + * We need to have the PartitionAssignor and its StreamThread to be mutually accessible 
since the former needs + * latter's cached metadata while sending subscriptions, and the latter needs former's returned assignment when + * adding tasks. + * + * @throws KafkaException if the stream thread is not specified + */ + @Override + public void configure(final Map configs) { + final AssignorConfiguration assignorConfiguration = new AssignorConfiguration(configs); + + logPrefix = assignorConfiguration.logPrefix(); + log = new LogContext(logPrefix).logger(getClass()); + usedSubscriptionMetadataVersion = assignorConfiguration.configuredMetadataVersion(usedSubscriptionMetadataVersion); + + final ReferenceContainer referenceContainer = assignorConfiguration.referenceContainer(); + mainConsumerSupplier = () -> Objects.requireNonNull(referenceContainer.mainConsumer, "Main consumer was not specified"); + adminClient = Objects.requireNonNull(referenceContainer.adminClient, "Admin client was not specified"); + taskManager = Objects.requireNonNull(referenceContainer.taskManager, "TaskManager was not specified"); + streamsMetadataState = Objects.requireNonNull(referenceContainer.streamsMetadataState, "StreamsMetadataState was not specified"); + assignmentErrorCode = referenceContainer.assignmentErrorCode; + nextScheduledRebalanceMs = referenceContainer.nextScheduledRebalanceMs; + time = Objects.requireNonNull(referenceContainer.time, "Time was not specified"); + assignmentConfigs = assignorConfiguration.assignmentConfigs(); + partitionGrouper = new PartitionGrouper(); + userEndPoint = assignorConfiguration.userEndPoint(); + internalTopicManager = assignorConfiguration.internalTopicManager(); + copartitionedTopicsEnforcer = assignorConfiguration.copartitionedTopicsEnforcer(); + rebalanceProtocol = assignorConfiguration.rebalanceProtocol(); + taskAssignorSupplier = assignorConfiguration::taskAssignor; + assignmentListener = assignorConfiguration.assignmentListener(); + uniqueField = 0; + } + + @Override + public String name() { + return "stream"; + } + + @Override + public List supportedProtocols() { + final List supportedProtocols = new ArrayList<>(); + supportedProtocols.add(RebalanceProtocol.EAGER); + if (rebalanceProtocol == RebalanceProtocol.COOPERATIVE) { + supportedProtocols.add(rebalanceProtocol); + } + return supportedProtocols; + } + + @Override + public ByteBuffer subscriptionUserData(final Set topics) { + // Adds the following information to subscription + // 1. Client UUID (a unique id assigned to an instance of KafkaStreams) + // 2. Map from task id to its overall lag + // 3. Unique Field to ensure a rebalance when a thread rejoins by forcing the user data to be different + + handleRebalanceStart(topics); + uniqueField++; + + final Set currentNamedTopologies = taskManager.topologyMetadata().namedTopologiesView(); + + // If using NamedTopologies, filter out any that are no longer recognized/have been removed + final Map taskOffsetSums = taskManager.topologyMetadata().hasNamedTopologies() ? 
+ filterMap(taskManager.getTaskOffsetSums(), t -> currentNamedTopologies.contains(t.getKey().topologyName())) : + taskManager.getTaskOffsetSums(); + + return new SubscriptionInfo( + usedSubscriptionMetadataVersion, + LATEST_SUPPORTED_VERSION, + taskManager.processId(), + userEndPoint, + taskOffsetSums, + uniqueField, + assignmentErrorCode.get() + ).encode(); + } + + private Map errorAssignment(final Map clientsMetadata, + final int errorCode) { + final Map assignment = new HashMap<>(); + for (final ClientMetadata clientMetadata : clientsMetadata.values()) { + for (final String consumerId : clientMetadata.consumers) { + assignment.put(consumerId, new Assignment( + Collections.emptyList(), + new AssignmentInfo(LATEST_SUPPORTED_VERSION, + Collections.emptyList(), + Collections.emptyMap(), + Collections.emptyMap(), + Collections.emptyMap(), + errorCode).encode() + )); + } + } + return assignment; + } + + /* + * This assigns tasks to consumer clients in the following steps. + * + * 0. decode the subscriptions to assemble the metadata for each client and check for version probing + * + * 1. check all repartition source topics and use internal topic manager to make sure + * they have been created with the right number of partitions. Also verify and/or create + * any changelog topics with the correct number of partitions. + * + * 2. use the partition grouper to generate tasks along with their assigned partitions, then use + * the configured TaskAssignor to construct the mapping of tasks to clients. + * + * 3. construct the global mapping of host to partitions to enable query routing. + * + * 4. within each client, assign tasks to consumer clients. + */ + @Override + public GroupAssignment assign(final Cluster metadata, final GroupSubscription groupSubscription) { + final Map subscriptions = groupSubscription.groupSubscription(); + + // ---------------- Step Zero ---------------- // + + // construct the client metadata from the decoded subscription info + + final Map clientMetadataMap = new HashMap<>(); + final Set allOwnedPartitions = new HashSet<>(); + + int minReceivedMetadataVersion = LATEST_SUPPORTED_VERSION; + int minSupportedMetadataVersion = LATEST_SUPPORTED_VERSION; + + boolean shutdownRequested = false; + boolean assignementErrorFound = false; + int futureMetadataVersion = UNKNOWN; + for (final Map.Entry entry : subscriptions.entrySet()) { + final String consumerId = entry.getKey(); + final Subscription subscription = entry.getValue(); + final SubscriptionInfo info = SubscriptionInfo.decode(subscription.userData()); + final int usedVersion = info.version(); + if (info.errorCode() == AssignorError.SHUTDOWN_REQUESTED.code()) { + shutdownRequested = true; + } + + minReceivedMetadataVersion = updateMinReceivedVersion(usedVersion, minReceivedMetadataVersion); + minSupportedMetadataVersion = updateMinSupportedVersion(info.latestSupportedVersion(), minSupportedMetadataVersion); + + final UUID processId; + if (usedVersion > LATEST_SUPPORTED_VERSION) { + futureMetadataVersion = usedVersion; + processId = FUTURE_ID; + if (!clientMetadataMap.containsKey(FUTURE_ID)) { + clientMetadataMap.put(FUTURE_ID, new ClientMetadata(null)); + } + } else { + processId = info.processId(); + } + + ClientMetadata clientMetadata = clientMetadataMap.get(processId); + + // create the new client metadata if necessary + if (clientMetadata == null) { + clientMetadata = new ClientMetadata(info.userEndPoint()); + clientMetadataMap.put(info.processId(), clientMetadata); + } + + // add the consumer and any info in its 
subscription to the client + clientMetadata.addConsumer(consumerId, subscription.ownedPartitions()); + final int prevSize = allOwnedPartitions.size(); + allOwnedPartitions.addAll(subscription.ownedPartitions()); + if (allOwnedPartitions.size() < prevSize + subscription.ownedPartitions().size()) { + assignementErrorFound = true; + } + clientMetadata.addPreviousTasksAndOffsetSums(consumerId, info.taskOffsetSums()); + } + + if (assignementErrorFound) { + log.warn("The previous assignment contains a partition more than once. " + + "\t Mapping: {}", subscriptions); + } + + try { + final boolean versionProbing = + checkMetadataVersions(minReceivedMetadataVersion, minSupportedMetadataVersion, futureMetadataVersion); + + log.debug("Constructed client metadata {} from the member subscriptions.", clientMetadataMap); + + // ---------------- Step One ---------------- // + + if (shutdownRequested) { + return new GroupAssignment(errorAssignment(clientMetadataMap, AssignorError.SHUTDOWN_REQUESTED.code())); + } + + // parse the topology to determine the repartition source topics, + // making sure they are created with the number of partitions as + // the maximum of the depending sub-topologies source topics' number of partitions + final Map allRepartitionTopicPartitions = prepareRepartitionTopics(metadata); + final Cluster fullMetadata = metadata.withPartitions(allRepartitionTopicPartitions); + log.debug("Created repartition topics {} from the parsed topology.", allRepartitionTopicPartitions.values()); + + // ---------------- Step Two ---------------- // + + // construct the assignment of tasks to clients + + final Map topicGroups = taskManager.topologyMetadata().topicGroups(); + + final Set allSourceTopics = new HashSet<>(); + final Map> sourceTopicsByGroup = new HashMap<>(); + for (final Map.Entry entry : topicGroups.entrySet()) { + allSourceTopics.addAll(entry.getValue().sourceTopics); + sourceTopicsByGroup.put(entry.getKey(), entry.getValue().sourceTopics); + } + + // get the tasks as partition groups from the partition grouper + final Map> partitionsForTask = + partitionGrouper.partitionGroups(sourceTopicsByGroup, fullMetadata); + + final Set statefulTasks = new HashSet<>(); + + final boolean probingRebalanceNeeded = assignTasksToClients(fullMetadata, allSourceTopics, topicGroups, clientMetadataMap, partitionsForTask, statefulTasks); + + // ---------------- Step Three ---------------- // + + // construct the global partition assignment per host map + + final Map> partitionsByHost = new HashMap<>(); + final Map> standbyPartitionsByHost = new HashMap<>(); + if (minReceivedMetadataVersion >= 2) { + populatePartitionsByHostMaps(partitionsByHost, standbyPartitionsByHost, partitionsForTask, clientMetadataMap); + } + streamsMetadataState.onChange(partitionsByHost, standbyPartitionsByHost, fullMetadata); + + // ---------------- Step Four ---------------- // + + // compute the assignment of tasks to threads within each client and build the final group assignment + + final Map assignment = computeNewAssignment( + statefulTasks, + clientMetadataMap, + partitionsForTask, + partitionsByHost, + standbyPartitionsByHost, + allOwnedPartitions, + minReceivedMetadataVersion, + minSupportedMetadataVersion, + versionProbing, + probingRebalanceNeeded + ); + + return new GroupAssignment(assignment); + } catch (final MissingSourceTopicException e) { + log.error("Caught an error in the task assignment. 
Returning an error assignment.", e); + return new GroupAssignment( + errorAssignment(clientMetadataMap, AssignorError.INCOMPLETE_SOURCE_TOPIC_METADATA.code()) + ); + } catch (final TaskAssignmentException e) { + log.error("Caught an error in the task assignment. Returning an error assignment.", e); + return new GroupAssignment( + errorAssignment(clientMetadataMap, AssignorError.ASSIGNMENT_ERROR.code()) + ); + } + } + + /** + * Verify the subscription versions are within the expected bounds and check for version probing. + * + * @return whether this was a version probing rebalance + */ + private boolean checkMetadataVersions(final int minReceivedMetadataVersion, + final int minSupportedMetadataVersion, + final int futureMetadataVersion) { + final boolean versionProbing; + + if (futureMetadataVersion == UNKNOWN) { + versionProbing = false; + } else if (minReceivedMetadataVersion >= EARLIEST_PROBEABLE_VERSION) { + versionProbing = true; + log.info("Received a future (version probing) subscription (version: {})." + + " Sending assignment back (with supported version {}).", + futureMetadataVersion, + minSupportedMetadataVersion); + + } else { + throw new TaskAssignmentException( + "Received a future (version probing) subscription (version: " + futureMetadataVersion + + ") and an incompatible pre Kafka 2.0 subscription (version: " + minReceivedMetadataVersion + + ") at the same time." + ); + } + + if (minReceivedMetadataVersion < LATEST_SUPPORTED_VERSION) { + log.info("Downgrade metadata to version {}. Latest supported version is {}.", + minReceivedMetadataVersion, + LATEST_SUPPORTED_VERSION); + } + if (minSupportedMetadataVersion < LATEST_SUPPORTED_VERSION) { + log.info("Downgrade latest supported metadata to version {}. Latest supported version is {}.", + minSupportedMetadataVersion, + LATEST_SUPPORTED_VERSION); + } + return versionProbing; + } + + /** + * Computes and assembles all repartition topic metadata then creates the topics if necessary. + * + * @return map from repartition topic to its partition info + */ + private Map prepareRepartitionTopics(final Cluster metadata) { + + final RepartitionTopics repartitionTopics = new RepartitionTopics( + taskManager.topologyMetadata(), + internalTopicManager, + copartitionedTopicsEnforcer, + metadata, + logPrefix + ); + repartitionTopics.setup(); + return repartitionTopics.topicPartitionsInfo(); + } + + /** + * Populates the taskForPartition and tasksForTopicGroup maps, and checks that partitions are assigned to exactly + * one task. + * + * @param taskForPartition a map from partition to the corresponding task. Populated here. + * @param tasksForTopicGroup a map from the topicGroupId to the set of corresponding tasks. Populated here. 
+ * @param allSourceTopics a set of all source topics in the topology + * @param partitionsForTask a map from task to the set of input partitions + * @param fullMetadata the cluster metadata + */ + private void populateTasksForMaps(final Map taskForPartition, + final Map> tasksForTopicGroup, + final Set allSourceTopics, + final Map> partitionsForTask, + final Cluster fullMetadata) { + // check if all partitions are assigned, and there are no duplicates of partitions in multiple tasks + final Set allAssignedPartitions = new HashSet<>(); + for (final Map.Entry> entry : partitionsForTask.entrySet()) { + final TaskId id = entry.getKey(); + final Set partitions = entry.getValue(); + + for (final TopicPartition partition : partitions) { + taskForPartition.put(partition, id); + if (allAssignedPartitions.contains(partition)) { + log.warn("Partition {} is assigned to more than one tasks: {}", partition, partitionsForTask); + } + } + allAssignedPartitions.addAll(partitions); + + tasksForTopicGroup.computeIfAbsent(new Subtopology(id.subtopology(), id.topologyName()), k -> new HashSet<>()).add(id); + } + + checkAllPartitions(allSourceTopics, partitionsForTask, allAssignedPartitions, fullMetadata); + } + + // Logs a warning if any partitions are not assigned to a task, or a task has no assigned partitions + private void checkAllPartitions(final Set allSourceTopics, + final Map> partitionsForTask, + final Set allAssignedPartitions, + final Cluster fullMetadata) { + for (final String topic : allSourceTopics) { + final List partitionInfoList = fullMetadata.partitionsForTopic(topic); + if (partitionInfoList.isEmpty()) { + log.warn("No partitions found for topic {}", topic); + } else { + for (final PartitionInfo partitionInfo : partitionInfoList) { + final TopicPartition partition = new TopicPartition(partitionInfo.topic(), + partitionInfo.partition()); + if (!allAssignedPartitions.contains(partition)) { + log.warn("Partition {} is not assigned to any tasks: {}" + + " Possible causes of a partition not getting assigned" + + " is that another topic defined in the topology has not been" + + " created when starting your streams application," + + " resulting in no tasks created for this topology at all.", partition, + partitionsForTask); + } + } + } + } + } + + /** + * Assigns a set of tasks to each client (Streams instance) using the configured task assignor, and also + * populate the stateful tasks that have been assigned to the clients + * @return true if a probing rebalance should be triggered + */ + private boolean assignTasksToClients(final Cluster fullMetadata, + final Set allSourceTopics, + final Map topicGroups, + final Map clientMetadataMap, + final Map> partitionsForTask, + final Set statefulTasks) { + if (!statefulTasks.isEmpty()) { + throw new TaskAssignmentException("The stateful tasks should not be populated before assigning tasks to clients"); + } + + final Map taskForPartition = new HashMap<>(); + final Map> tasksForTopicGroup = new HashMap<>(); + populateTasksForMaps(taskForPartition, tasksForTopicGroup, allSourceTopics, partitionsForTask, fullMetadata); + + final ChangelogTopics changelogTopics = new ChangelogTopics( + internalTopicManager, + topicGroups, + tasksForTopicGroup, + logPrefix + ); + changelogTopics.setup(); + + final Map clientStates = new HashMap<>(); + final boolean lagComputationSuccessful = + populateClientStatesMap(clientStates, clientMetadataMap, taskForPartition, changelogTopics); + + log.info("All members participating in this rebalance: \n{}.", + 
clientStates.entrySet().stream() + .map(entry -> entry.getKey() + ": " + entry.getValue().consumers()) + .collect(Collectors.joining(Utils.NL))); + + final Set allTasks = partitionsForTask.keySet(); + statefulTasks.addAll(changelogTopics.statefulTaskIds()); + + log.debug("Assigning tasks {} including stateful {} to clients {} with number of replicas {}", + allTasks, statefulTasks, clientStates, numStandbyReplicas()); + + final TaskAssignor taskAssignor = createTaskAssignor(lagComputationSuccessful); + + final boolean probingRebalanceNeeded = taskAssignor.assign(clientStates, + allTasks, + statefulTasks, + assignmentConfigs); + + log.info("Assigned tasks {} including stateful {} to clients as: \n{}.", + allTasks, statefulTasks, clientStates.entrySet().stream() + .map(entry -> entry.getKey() + "=" + entry.getValue().currentAssignment()) + .collect(Collectors.joining(Utils.NL))); + + return probingRebalanceNeeded; + } + + private TaskAssignor createTaskAssignor(final boolean lagComputationSuccessful) { + final TaskAssignor taskAssignor = taskAssignorSupplier.get(); + if (taskAssignor instanceof StickyTaskAssignor) { + // special case: to preserve pre-existing behavior, we invoke the StickyTaskAssignor + // whether or not lag computation failed. + return taskAssignor; + } else if (lagComputationSuccessful) { + return taskAssignor; + } else { + log.info("Failed to fetch end offsets for changelogs, will return previous assignment to clients and " + + "trigger another rebalance to retry."); + return new FallbackPriorTaskAssignor(); + } + } + + /** + * Builds a map from client to state, and readies each ClientState for assignment by adding any missing prev tasks + * and computing the per-task overall lag based on the fetched end offsets for each changelog. + * + * @param clientStates a map from each client to its state, including offset lags. Populated by this method. 
+ * @param clientMetadataMap a map from each client to its full metadata + * @param taskForPartition map from topic partition to its corresponding task + * @param changelogTopics object that manages changelog topics + * + * @return whether we were able to successfully fetch the changelog end offsets and compute each client's lag + */ + private boolean populateClientStatesMap(final Map clientStates, + final Map clientMetadataMap, + final Map taskForPartition, + final ChangelogTopics changelogTopics) { + boolean fetchEndOffsetsSuccessful; + Map allTaskEndOffsetSums; + try { + // Make the listOffsets request first so it can fetch the offsets for non-source changelogs + // asynchronously while we use the blocking Consumer#committed call to fetch source-changelog offsets + final KafkaFuture> endOffsetsFuture = + fetchEndOffsetsFuture(changelogTopics.preExistingNonSourceTopicBasedPartitions(), adminClient); + + final Map sourceChangelogEndOffsets = + fetchCommittedOffsets(changelogTopics.preExistingSourceTopicBasedPartitions(), mainConsumerSupplier.get()); + + final Map endOffsets = ClientUtils.getEndOffsets(endOffsetsFuture); + + allTaskEndOffsetSums = computeEndOffsetSumsByTask( + endOffsets, + sourceChangelogEndOffsets, + changelogTopics + ); + fetchEndOffsetsSuccessful = true; + } catch (final StreamsException | TimeoutException e) { + allTaskEndOffsetSums = changelogTopics.statefulTaskIds().stream().collect(Collectors.toMap(t -> t, t -> UNKNOWN_OFFSET_SUM)); + fetchEndOffsetsSuccessful = false; + } + + for (final Map.Entry entry : clientMetadataMap.entrySet()) { + final UUID uuid = entry.getKey(); + final ClientState state = entry.getValue().state; + state.initializePrevTasks(taskForPartition); + + state.computeTaskLags(uuid, allTaskEndOffsetSums); + clientStates.put(uuid, state); + } + + return fetchEndOffsetsSuccessful; + } + + /** + * @param endOffsets the listOffsets result from the adminClient + * @param sourceChangelogEndOffsets the end (committed) offsets of optimized source changelogs + * @param changelogTopics object that manages changelog topics + * + * @return Map from stateful task to its total end offset summed across all changelog partitions + */ + private Map computeEndOffsetSumsByTask(final Map endOffsets, + final Map sourceChangelogEndOffsets, + final ChangelogTopics changelogTopics) { + + final Map taskEndOffsetSums = new HashMap<>(); + for (final TaskId taskId : changelogTopics.statefulTaskIds()) { + taskEndOffsetSums.put(taskId, 0L); + for (final TopicPartition changelogPartition : changelogTopics.preExistingPartitionsFor(taskId)) { + final long changelogPartitionEndOffset; + if (sourceChangelogEndOffsets.containsKey(changelogPartition)) { + changelogPartitionEndOffset = sourceChangelogEndOffsets.get(changelogPartition); + } else if (endOffsets.containsKey(changelogPartition)) { + changelogPartitionEndOffset = endOffsets.get(changelogPartition).offset(); + } else { + log.debug("Fetched offsets did not contain the changelog {} of task {}", changelogPartition, taskId); + throw new IllegalStateException("Could not get end offset for " + changelogPartition); + } + final long newEndOffsetSum = taskEndOffsetSums.get(taskId) + changelogPartitionEndOffset; + if (newEndOffsetSum < 0) { + taskEndOffsetSums.put(taskId, Long.MAX_VALUE); + break; + } else { + taskEndOffsetSums.put(taskId, newEndOffsetSum); + } + } + } + return taskEndOffsetSums; + } + + /** + * Populates the global partitionsByHost and standbyPartitionsByHost maps that are sent to each member + * + * @param 
partitionsByHost a map from host to the set of partitions hosted there. Populated here. + * @param standbyPartitionsByHost a map from host to the set of standby partitions hosted there. Populated here. + * @param partitionsForTask a map from task to its set of assigned partitions + * @param clientMetadataMap a map from client to its metadata and state + */ + private void populatePartitionsByHostMaps(final Map> partitionsByHost, + final Map> standbyPartitionsByHost, + final Map> partitionsForTask, + final Map clientMetadataMap) { + for (final Map.Entry entry : clientMetadataMap.entrySet()) { + final HostInfo hostInfo = entry.getValue().hostInfo; + + // if application server is configured, also include host state map + if (hostInfo != null) { + final Set topicPartitions = new HashSet<>(); + final Set standbyPartitions = new HashSet<>(); + final ClientState state = entry.getValue().state; + + for (final TaskId id : state.activeTasks()) { + topicPartitions.addAll(partitionsForTask.get(id)); + } + + for (final TaskId id : state.standbyTasks()) { + standbyPartitions.addAll(partitionsForTask.get(id)); + } + + partitionsByHost.put(hostInfo, topicPartitions); + standbyPartitionsByHost.put(hostInfo, standbyPartitions); + } + } + } + + /** + * Computes the assignment of tasks to threads within each client and assembles the final assignment to send out. + * + * @return the final assignment for each StreamThread consumer + */ + private Map computeNewAssignment(final Set statefulTasks, + final Map clientsMetadata, + final Map> partitionsForTask, + final Map> partitionsByHostState, + final Map> standbyPartitionsByHost, + final Set allOwnedPartitions, + final int minUserMetadataVersion, + final int minSupportedMetadataVersion, + final boolean versionProbing, + final boolean shouldTriggerProbingRebalance) { + boolean rebalanceRequired = shouldTriggerProbingRebalance || versionProbing; + final Map assignment = new HashMap<>(); + + // within the client, distribute tasks to its owned consumers + for (final Map.Entry clientEntry : clientsMetadata.entrySet()) { + final UUID clientId = clientEntry.getKey(); + final ClientMetadata clientMetadata = clientEntry.getValue(); + final ClientState state = clientMetadata.state; + final SortedSet consumers = clientMetadata.consumers; + + final Map> activeTaskAssignment = assignTasksToThreads( + state.statefulActiveTasks(), + state.statelessActiveTasks(), + consumers, + state + ); + + final Map> standbyTaskAssignment = assignTasksToThreads( + state.standbyTasks(), + Collections.emptySet(), + consumers, + state + ); + + // Arbitrarily choose the leader's client to be responsible for triggering the probing rebalance, + // note once we pick the first consumer within the process to trigger probing rebalance, other consumer + // would not set to trigger any more. 
+ final boolean encodeNextProbingRebalanceTime = shouldTriggerProbingRebalance && clientId.equals(taskManager.processId()); + + final boolean tasksRevoked = addClientAssignments( + statefulTasks, + assignment, + clientMetadata, + partitionsForTask, + partitionsByHostState, + standbyPartitionsByHost, + allOwnedPartitions, + activeTaskAssignment, + standbyTaskAssignment, + minUserMetadataVersion, + minSupportedMetadataVersion, + encodeNextProbingRebalanceTime + ); + + if (tasksRevoked || encodeNextProbingRebalanceTime) { + rebalanceRequired = true; + log.debug("Requested client {} to schedule a followup rebalance", clientId); + } + + log.info("Client {} per-consumer assignment:\n" + + "\tprev owned active {}\n" + + "\tprev owned standby {}\n" + + "\tassigned active {}\n" + + "\trevoking active {}\n" + + "\tassigned standby {}\n", + clientId, + clientMetadata.state.prevOwnedActiveTasksByConsumer(), + clientMetadata.state.prevOwnedStandbyByConsumer(), + clientMetadata.state.assignedActiveTasksByConsumer(), + clientMetadata.state.revokingActiveTasksByConsumer(), + clientMetadata.state.assignedStandbyTasksByConsumer()); + } + + if (rebalanceRequired) { + assignmentListener.onAssignmentComplete(false); + log.info("Finished unstable assignment of tasks, a followup rebalance will be scheduled."); + } else { + assignmentListener.onAssignmentComplete(true); + log.info("Finished stable assignment of tasks, no followup rebalances required."); + } + + return assignment; + } + + /** + * Adds the encoded assignment for each StreamThread consumer in the client to the overall assignment map + * @return true if a followup rebalance will be required due to revoked tasks + */ + private boolean addClientAssignments(final Set statefulTasks, + final Map assignment, + final ClientMetadata clientMetadata, + final Map> partitionsForTask, + final Map> partitionsByHostState, + final Map> standbyPartitionsByHost, + final Set allOwnedPartitions, + final Map> activeTaskAssignments, + final Map> standbyTaskAssignments, + final int minUserMetadataVersion, + final int minSupportedMetadataVersion, + final boolean probingRebalanceNeeded) { + boolean followupRebalanceRequiredForRevokedTasks = false; + + // We only want to encode a scheduled probing rebalance for a single member in this client + boolean shouldEncodeProbingRebalance = probingRebalanceNeeded; + + // Loop through the consumers and build their assignment + for (final String consumer : clientMetadata.consumers) { + final List activeTasksForConsumer = activeTaskAssignments.get(consumer); + + // These will be filled in by populateActiveTaskAndPartitionsLists below + final List activePartitionsList = new ArrayList<>(); + final List assignedActiveList = new ArrayList<>(); + + final Set activeTasksRemovedPendingRevokation = populateActiveTaskAndPartitionsLists( + activePartitionsList, + assignedActiveList, + consumer, + clientMetadata.state, + activeTasksForConsumer, + partitionsForTask, + allOwnedPartitions + ); + + final Map> standbyTaskMap = buildStandbyTaskMap( + consumer, + standbyTaskAssignments.get(consumer), + activeTasksRemovedPendingRevokation, + statefulTasks, + partitionsForTask, + clientMetadata.state + ); + + final AssignmentInfo info = new AssignmentInfo( + minUserMetadataVersion, + minSupportedMetadataVersion, + assignedActiveList, + standbyTaskMap, + partitionsByHostState, + standbyPartitionsByHost, + AssignorError.NONE.code() + ); + + if (!activeTasksRemovedPendingRevokation.isEmpty()) { + // TODO: once KAFKA-10078 is resolved we can leave it to the 
client to trigger this rebalance + log.info("Requesting followup rebalance be scheduled immediately by {} due to tasks changing ownership.", consumer); + info.setNextRebalanceTime(0L); + followupRebalanceRequiredForRevokedTasks = true; + // Don't bother to schedule a probing rebalance if an immediate one is already scheduled + shouldEncodeProbingRebalance = false; + } else if (shouldEncodeProbingRebalance) { + final long nextRebalanceTimeMs = time.milliseconds() + probingRebalanceIntervalMs(); + log.info("Requesting followup rebalance be scheduled by {} for {} ms to probe for caught-up replica tasks.", + consumer, nextRebalanceTimeMs); + info.setNextRebalanceTime(nextRebalanceTimeMs); + shouldEncodeProbingRebalance = false; + } + + assignment.put( + consumer, + new Assignment( + activePartitionsList, + info.encode() + ) + ); + } + return followupRebalanceRequiredForRevokedTasks; + } + + /** + * Populates the lists of active tasks and active task partitions for the consumer with a 1:1 mapping between them + * such that the nth task corresponds to the nth partition in the list. This means tasks with multiple partitions + * will be repeated in the list. + */ + private Set populateActiveTaskAndPartitionsLists(final List activePartitionsList, + final List assignedActiveList, + final String consumer, + final ClientState clientState, + final List activeTasksForConsumer, + final Map> partitionsForTask, + final Set allOwnedPartitions) { + final List assignedPartitions = new ArrayList<>(); + final Set removedActiveTasks = new TreeSet<>(); + + for (final TaskId taskId : activeTasksForConsumer) { + // Populate the consumer for assigned tasks without considering revocation, + // this is for debugging purposes only + clientState.assignActiveToConsumer(taskId, consumer); + + final List assignedPartitionsForTask = new ArrayList<>(); + for (final TopicPartition partition : partitionsForTask.get(taskId)) { + final String oldOwner = clientState.previousOwnerForPartition(partition); + final boolean newPartitionForConsumer = oldOwner == null || !oldOwner.equals(consumer); + + // If the partition is new to this consumer but is still owned by another, remove from the assignment + // until it has been revoked and can safely be reassigned according to the COOPERATIVE protocol + if (newPartitionForConsumer && allOwnedPartitions.contains(partition)) { + log.info( + "Removing task {} from {} active assignment until it is safely revoked in followup rebalance", + taskId, + consumer + ); + removedActiveTasks.add(taskId); + + clientState.revokeActiveFromConsumer(taskId, consumer); + + // Clear the assigned partitions list for this task if any partition can not safely be assigned, + // so as not to encode a partial task + assignedPartitionsForTask.clear(); + + // This has no effect on the assignment, as we'll never consult the ClientState again, but + // it does perform a useful assertion that the task was actually assigned. 
+ clientState.unassignActive(taskId); + break; + } else { + assignedPartitionsForTask.add(new AssignedPartition(taskId, partition)); + } + } + // assignedPartitionsForTask will either contain all partitions for the task or be empty, so just add all + assignedPartitions.addAll(assignedPartitionsForTask); + } + + // Add one copy of a task for each corresponding partition, so the receiver can determine the task <-> tp mapping + Collections.sort(assignedPartitions); + for (final AssignedPartition partition : assignedPartitions) { + assignedActiveList.add(partition.taskId); + activePartitionsList.add(partition.partition); + } + return removedActiveTasks; + } + + /** + * @return map from task id to its assigned partitions for all standby tasks + */ + private Map> buildStandbyTaskMap(final String consumer, + final Iterable standbyTasks, + final Iterable revokedTasks, + final Set allStatefulTasks, + final Map> partitionsForTask, + final ClientState clientState) { + final Map> standbyTaskMap = new HashMap<>(); + + for (final TaskId task : standbyTasks) { + clientState.assignStandbyToConsumer(task, consumer); + standbyTaskMap.put(task, partitionsForTask.get(task)); + } + + for (final TaskId task : revokedTasks) { + if (allStatefulTasks.contains(task)) { + log.info("Adding removed stateful active task {} as a standby for {} before it is revoked in followup rebalance", + task, consumer); + + // This has no effect on the assignment, as we'll never consult the ClientState again, but + // it does perform a useful assertion that the it's legal to assign this task as a standby to this instance + clientState.assignStandbyToConsumer(task, consumer); + clientState.assignStandby(task); + + standbyTaskMap.put(task, partitionsForTask.get(task)); + } + } + return standbyTaskMap; + } + + /** + * Generate an assignment that tries to preserve thread-level stickiness of stateful tasks without violating + * balance. The stateful and total task load are both balanced across threads. Tasks without previous owners + * will be interleaved by group id to spread subtopologies across threads and further balance the workload. 
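 + * <p>
 + * A worked example of the strategy implemented below: with 5 stateful tasks and 2 consumers, each consumer is first
 + * given up to floor(5 / 2) = 2 of its previously owned stateful tasks, any remaining stateful tasks are interleaved
 + * over consumers still below that minimum, and the single leftover task goes to one consumer (preferring its
 + * previous owner); stateless tasks are then handed out starting with the consumers still at the minimum so that the
 + * total load evens out.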
+ */ + static Map> assignTasksToThreads(final Collection statefulTasksToAssign, + final Collection statelessTasksToAssign, + final SortedSet consumers, + final ClientState state) { + final Map> assignment = new HashMap<>(); + for (final String consumer : consumers) { + assignment.put(consumer, new ArrayList<>()); + } + + final List unassignedStatelessTasks = new ArrayList<>(statelessTasksToAssign); + Collections.sort(unassignedStatelessTasks); + + final Iterator unassignedStatelessTasksIter = unassignedStatelessTasks.iterator(); + + final int minStatefulTasksPerThread = (int) Math.floor(((double) statefulTasksToAssign.size()) / consumers.size()); + final PriorityQueue unassignedStatefulTasks = new PriorityQueue<>(statefulTasksToAssign); + + final Queue consumersToFill = new LinkedList<>(); + // keep track of tasks that we have to skip during the first pass in case we can reassign them later + // using tree-map to make sure the iteration ordering over keys are preserved + final Map unassignedTaskToPreviousOwner = new TreeMap<>(); + + if (!unassignedStatefulTasks.isEmpty()) { + // First assign stateful tasks to previous owner, up to the min expected tasks/thread + for (final String consumer : consumers) { + final List threadAssignment = assignment.get(consumer); + + for (final TaskId task : state.prevTasksByLag(consumer)) { + if (unassignedStatefulTasks.contains(task)) { + if (threadAssignment.size() < minStatefulTasksPerThread) { + threadAssignment.add(task); + unassignedStatefulTasks.remove(task); + } else { + unassignedTaskToPreviousOwner.put(task, consumer); + } + } + } + + if (threadAssignment.size() < minStatefulTasksPerThread) { + consumersToFill.offer(consumer); + } + } + + // Next interleave remaining unassigned tasks amongst unfilled consumers + while (!consumersToFill.isEmpty()) { + final TaskId task = unassignedStatefulTasks.poll(); + if (task != null) { + final String consumer = consumersToFill.poll(); + final List threadAssignment = assignment.get(consumer); + threadAssignment.add(task); + if (threadAssignment.size() < minStatefulTasksPerThread) { + consumersToFill.offer(consumer); + } + } else { + throw new TaskAssignmentException("Ran out of unassigned stateful tasks but some members were not at capacity"); + } + } + + // At this point all consumers are at the min capacity, so there may be up to N - 1 unassigned + // stateful tasks still remaining that should now be distributed over the consumers + if (!unassignedStatefulTasks.isEmpty()) { + consumersToFill.addAll(consumers); + + // Go over the tasks we skipped earlier and assign them to their previous owner when possible + for (final Map.Entry taskEntry : unassignedTaskToPreviousOwner.entrySet()) { + final TaskId task = taskEntry.getKey(); + final String consumer = taskEntry.getValue(); + if (consumersToFill.contains(consumer) && unassignedStatefulTasks.contains(task)) { + assignment.get(consumer).add(task); + unassignedStatefulTasks.remove(task); + // Remove this consumer since we know it is now at minCapacity + 1 + consumersToFill.remove(consumer); + } + } + + // Now just distribute the remaining unassigned stateful tasks over the consumers still at min capacity + for (final TaskId task : unassignedStatefulTasks) { + final String consumer = consumersToFill.poll(); + final List threadAssignment = assignment.get(consumer); + threadAssignment.add(task); + } + + + // There must be at least one consumer still at min capacity while all the others are at min + // capacity + 1, so start distributing stateless tasks to get all 
consumers back to the same count + while (unassignedStatelessTasksIter.hasNext()) { + final String consumer = consumersToFill.poll(); + if (consumer != null) { + final TaskId task = unassignedStatelessTasksIter.next(); + unassignedStatelessTasksIter.remove(); + assignment.get(consumer).add(task); + } else { + break; + } + } + } + } + + // Now just distribute tasks while circling through all the consumers + consumersToFill.addAll(consumers); + + while (unassignedStatelessTasksIter.hasNext()) { + final TaskId task = unassignedStatelessTasksIter.next(); + final String consumer = consumersToFill.poll(); + assignment.get(consumer).add(task); + consumersToFill.offer(consumer); + } + + return assignment; + } + + private void validateMetadataVersions(final int receivedAssignmentMetadataVersion, + final int latestCommonlySupportedVersion) { + + if (receivedAssignmentMetadataVersion > usedSubscriptionMetadataVersion) { + log.error("Leader sent back an assignment with version {} which was greater than our used version {}", + receivedAssignmentMetadataVersion, usedSubscriptionMetadataVersion); + throw new TaskAssignmentException( + "Sent a version " + usedSubscriptionMetadataVersion + + " subscription but got an assignment with higher version " + + receivedAssignmentMetadataVersion + "." + ); + } + + if (latestCommonlySupportedVersion > LATEST_SUPPORTED_VERSION) { + log.error("Leader sent back assignment with commonly supported version {} that is greater than our " + + "actual latest supported version {}", latestCommonlySupportedVersion, LATEST_SUPPORTED_VERSION); + throw new TaskAssignmentException("Can't upgrade to metadata version greater than we support"); + } + } + + // Returns true if subscription version was changed, indicating version probing and need to rebalance again + protected boolean maybeUpdateSubscriptionVersion(final int receivedAssignmentMetadataVersion, + final int latestCommonlySupportedVersion) { + if (receivedAssignmentMetadataVersion >= EARLIEST_PROBEABLE_VERSION) { + // If the latest commonly supported version is now greater than our used version, this indicates we have just + // completed the rolling upgrade and can now update our subscription version for the final rebalance + if (latestCommonlySupportedVersion > usedSubscriptionMetadataVersion) { + log.info( + "Sent a version {} subscription and group's latest commonly supported version is {} (successful " + + + "version probing and end of rolling upgrade). Upgrading subscription metadata version to " + + "{} for next rebalance.", + usedSubscriptionMetadataVersion, + latestCommonlySupportedVersion, + latestCommonlySupportedVersion + ); + usedSubscriptionMetadataVersion = latestCommonlySupportedVersion; + return true; + } + + // If we received a lower version than we sent, someone else in the group still hasn't upgraded. We + // should downgrade our subscription until everyone is on the latest version + if (receivedAssignmentMetadataVersion < usedSubscriptionMetadataVersion) { + log.info( + "Sent a version {} subscription and got version {} assignment back (successful version probing). " + + + "Downgrade subscription metadata to commonly supported version {} and trigger new rebalance.", + usedSubscriptionMetadataVersion, + receivedAssignmentMetadataVersion, + latestCommonlySupportedVersion + ); + usedSubscriptionMetadataVersion = latestCommonlySupportedVersion; + return true; + } + } else { + log.debug("Received an assignment version {} that is less than the earliest version that allows version " + + "probing {}. 
If this is not during a rolling upgrade from version 2.0 or below, this is an error.", + receivedAssignmentMetadataVersion, EARLIEST_PROBEABLE_VERSION); + } + + return false; + } + + @Override + public void onAssignment(final Assignment assignment, final ConsumerGroupMetadata metadata) { + final List partitions = new ArrayList<>(assignment.partitions()); + partitions.sort(PARTITION_COMPARATOR); + + final AssignmentInfo info = AssignmentInfo.decode(assignment.userData()); + if (info.errCode() != AssignorError.NONE.code()) { + // set flag to shutdown streams app + assignmentErrorCode.set(info.errCode()); + return; + } + /* + * latestCommonlySupportedVersion belongs to [usedSubscriptionMetadataVersion, LATEST_SUPPORTED_VERSION] + * receivedAssignmentMetadataVersion belongs to [EARLIEST_PROBEABLE_VERSION, usedSubscriptionMetadataVersion] + * + * usedSubscriptionMetadataVersion will be downgraded to receivedAssignmentMetadataVersion during a rolling + * bounce upgrade with version probing. + * + * usedSubscriptionMetadataVersion will be upgraded to latestCommonlySupportedVersion when all members have + * been bounced and it is safe to use the latest version. + */ + final int receivedAssignmentMetadataVersion = info.version(); + final int latestCommonlySupportedVersion = info.commonlySupportedVersion(); + + validateMetadataVersions(receivedAssignmentMetadataVersion, latestCommonlySupportedVersion); + + // version 1 field + final Map> activeTasks; + // version 2 fields + final Map topicToPartitionInfo; + final Map> partitionsByHost; + final Map> standbyPartitionsByHost; + final long encodedNextScheduledRebalanceMs; + + switch (receivedAssignmentMetadataVersion) { + case 1: + validateActiveTaskEncoding(partitions, info, logPrefix); + + activeTasks = getActiveTasks(partitions, info); + partitionsByHost = Collections.emptyMap(); + standbyPartitionsByHost = Collections.emptyMap(); + topicToPartitionInfo = Collections.emptyMap(); + encodedNextScheduledRebalanceMs = Long.MAX_VALUE; + break; + case 2: + case 3: + case 4: + case 5: + validateActiveTaskEncoding(partitions, info, logPrefix); + + activeTasks = getActiveTasks(partitions, info); + partitionsByHost = info.partitionsByHost(); + standbyPartitionsByHost = Collections.emptyMap(); + topicToPartitionInfo = getTopicPartitionInfo(partitionsByHost); + encodedNextScheduledRebalanceMs = Long.MAX_VALUE; + break; + case 6: + validateActiveTaskEncoding(partitions, info, logPrefix); + + activeTasks = getActiveTasks(partitions, info); + partitionsByHost = info.partitionsByHost(); + standbyPartitionsByHost = info.standbyPartitionByHost(); + topicToPartitionInfo = getTopicPartitionInfo(partitionsByHost); + encodedNextScheduledRebalanceMs = Long.MAX_VALUE; + break; + case 7: + case 8: + case 9: + case 10: + validateActiveTaskEncoding(partitions, info, logPrefix); + + activeTasks = getActiveTasks(partitions, info); + partitionsByHost = info.partitionsByHost(); + standbyPartitionsByHost = info.standbyPartitionByHost(); + topicToPartitionInfo = getTopicPartitionInfo(partitionsByHost); + encodedNextScheduledRebalanceMs = info.nextRebalanceMs(); + break; + default: + throw new IllegalStateException( + "This code should never be reached." 
+ + " Please file a bug report at https://issues.apache.org/jira/projects/KAFKA/" + ); + } + + maybeScheduleFollowupRebalance( + encodedNextScheduledRebalanceMs, + receivedAssignmentMetadataVersion, + latestCommonlySupportedVersion, + partitionsByHost.keySet() + ); + + final Cluster fakeCluster = Cluster.empty().withPartitions(topicToPartitionInfo); + streamsMetadataState.onChange(partitionsByHost, standbyPartitionsByHost, fakeCluster); + + // we do not capture any exceptions but just let the exception thrown from consumer.poll directly + // since when stream thread captures it, either we close all tasks as dirty or we close thread + taskManager.handleAssignment(activeTasks, info.standbyTasks()); + } + + private void maybeScheduleFollowupRebalance(final long encodedNextScheduledRebalanceMs, + final int receivedAssignmentMetadataVersion, + final int latestCommonlySupportedVersion, + final Set groupHostInfo) { + if (maybeUpdateSubscriptionVersion(receivedAssignmentMetadataVersion, latestCommonlySupportedVersion)) { + log.info("Requested to schedule immediate rebalance due to version probing."); + nextScheduledRebalanceMs.set(0L); + } else if (!verifyHostInfo(groupHostInfo)) { + log.info("Requested to schedule immediate rebalance to update group with new host endpoint = {}.", userEndPoint); + nextScheduledRebalanceMs.set(0L); + } else if (encodedNextScheduledRebalanceMs == 0L) { + log.info("Requested to schedule immediate rebalance for new tasks to be safely revoked from current owner."); + nextScheduledRebalanceMs.set(0L); + } else if (encodedNextScheduledRebalanceMs < Long.MAX_VALUE) { + log.info("Requested to schedule probing rebalance for {} ms.", encodedNextScheduledRebalanceMs); + nextScheduledRebalanceMs.set(encodedNextScheduledRebalanceMs); + } else { + log.info("No followup rebalance was requested, resetting the rebalance schedule."); + nextScheduledRebalanceMs.set(Long.MAX_VALUE); + } + } + + /** + * Verify that this client's host info was included in the map returned in the assignment, and trigger a + * rebalance if not. This may be necessary when using static membership, as a rejoining client will be handed + * back its original assignment to avoid an unnecessary rebalance. If the client's endpoint has changed, we need + * to force a rebalance for the other members in the group to get the updated host info for this client. 
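 + * For example, with static membership, an instance restarted behind a new {@code application.server} endpoint may be
 + * handed back its old assignment on rejoining; without this check the rest of the group would keep routing
 + * interactive queries to the stale endpoint until some later rebalance happened to occur.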
+ * + * @param groupHostInfo the HostInfo of all clients in the group + * @return false if the current host info does not match that in the group assignment + */ + private boolean verifyHostInfo(final Set groupHostInfo) { + if (userEndPoint != null && !groupHostInfo.isEmpty()) { + final HostInfo myHostInfo = HostInfo.buildFromEndpoint(userEndPoint); + + return groupHostInfo.contains(myHostInfo); + } else { + return true; + } + } + + // protected for upgrade test + protected static Map> getActiveTasks(final List partitions, final AssignmentInfo info) { + final Map> activeTasks = new HashMap<>(); + for (int i = 0; i < partitions.size(); i++) { + final TopicPartition partition = partitions.get(i); + final TaskId id = info.activeTasks().get(i); + activeTasks.computeIfAbsent(id, k1 -> new HashSet<>()).add(partition); + } + return activeTasks; + } + + static Map getTopicPartitionInfo(final Map> partitionsByHost) { + final Map topicToPartitionInfo = new HashMap<>(); + for (final Set value : partitionsByHost.values()) { + for (final TopicPartition topicPartition : value) { + topicToPartitionInfo.put( + topicPartition, + new PartitionInfo( + topicPartition.topic(), + topicPartition.partition(), + null, + new Node[0], + new Node[0] + ) + ); + } + } + return topicToPartitionInfo; + } + + private static void validateActiveTaskEncoding(final List partitions, final AssignmentInfo info, final String logPrefix) { + // the number of assigned partitions should be the same as number of active tasks, which + // could be duplicated if one task has more than one assigned partitions + if (partitions.size() != info.activeTasks().size()) { + throw new TaskAssignmentException( + String.format( + "%sNumber of assigned partitions %d is not equal to " + + "the number of active taskIds %d, assignmentInfo=%s", + logPrefix, partitions.size(), + info.activeTasks().size(), info.toString() + ) + ); + } + } + + private int updateMinReceivedVersion(final int usedVersion, final int minReceivedMetadataVersion) { + return Math.min(usedVersion, minReceivedMetadataVersion); + } + + private int updateMinSupportedVersion(final int supportedVersion, final int minSupportedMetadataVersion) { + if (supportedVersion < minSupportedMetadataVersion) { + log.debug("Downgrade the current minimum supported version {} to the smaller seen supported version {}", + minSupportedMetadataVersion, supportedVersion); + return supportedVersion; + } else { + log.debug("Current minimum supported version remains at {}, last seen supported version was {}", + minSupportedMetadataVersion, supportedVersion); + return minSupportedMetadataVersion; + } + } + + // following functions are for test only + void setInternalTopicManager(final InternalTopicManager internalTopicManager) { + this.internalTopicManager = internalTopicManager; + } + + RebalanceProtocol rebalanceProtocol() { + return rebalanceProtocol; + } + + protected String userEndPoint() { + return userEndPoint; + } + + protected TaskManager taskManager() { + return taskManager; + } + + protected byte uniqueField() { + return uniqueField; + } + + protected void handleRebalanceStart(final Set topics) { + taskManager.handleRebalanceStart(topics); + } + + long acceptableRecoveryLag() { + return assignmentConfigs.acceptableRecoveryLag; + } + + int maxWarmupReplicas() { + return assignmentConfigs.maxWarmupReplicas; + } + + int numStandbyReplicas() { + return assignmentConfigs.numStandbyReplicas; + } + + long probingRebalanceIntervalMs() { + return assignmentConfigs.probingRebalanceIntervalMs; + } + +} diff 
--git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamsProducer.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamsProducer.java new file mode 100644 index 0000000..fe4f363 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamsProducer.java @@ -0,0 +1,373 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.processor.internals; + +import java.util.stream.Collectors; +import org.apache.kafka.clients.consumer.CommitFailedException; +import org.apache.kafka.clients.consumer.ConsumerGroupMetadata; +import org.apache.kafka.clients.consumer.OffsetAndMetadata; +import org.apache.kafka.clients.producer.Callback; +import org.apache.kafka.clients.producer.KafkaProducer; +import org.apache.kafka.clients.producer.Producer; +import org.apache.kafka.clients.producer.ProducerConfig; +import org.apache.kafka.clients.producer.ProducerRecord; +import org.apache.kafka.clients.producer.RecordMetadata; +import org.apache.kafka.common.KafkaException; +import org.apache.kafka.common.Metric; +import org.apache.kafka.common.MetricName; +import org.apache.kafka.common.PartitionInfo; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.errors.InvalidProducerEpochException; +import org.apache.kafka.common.errors.ProducerFencedException; +import org.apache.kafka.common.errors.TimeoutException; +import org.apache.kafka.common.errors.UnknownProducerIdException; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.streams.KafkaClientSupplier; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.errors.StreamsException; +import org.apache.kafka.streams.errors.TaskMigratedException; +import org.apache.kafka.streams.processor.TaskId; +import org.slf4j.Logger; + +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.UUID; +import java.util.concurrent.Future; + +import static org.apache.kafka.streams.processor.internals.ClientUtils.getTaskProducerClientId; +import static org.apache.kafka.streams.processor.internals.ClientUtils.getThreadProducerClientId; +import static org.apache.kafka.streams.processor.internals.StreamThread.ProcessingMode.EXACTLY_ONCE_V2; + +/** + * {@code StreamsProducer} manages the producers within a Kafka Streams application. + *
<p>
 + * If EOS is enabled, it is responsible for initializing and beginning transactions when necessary.
 + * It also tracks the transaction status, i.e., whether a transaction is currently in-flight.
 + * <p>
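 + * For example (as configured in the constructor below), exactly-once v2 uses a single thread-level producer whose
 + * {@code transactional.id} is {@code applicationId + "-" + processId + "-" + threadIdx} (the numeric suffix of the
 + * thread name), exactly-once (alpha) uses one producer per task with {@code applicationId + "-" + taskId}, and
 + * at-least-once does not set a {@code transactional.id} at all.
 + * <p>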
                + * For non-EOS, the user should not call transaction related methods. + */ +public class StreamsProducer { + private final Logger log; + private final String logPrefix; + + private final Map eosV2ProducerConfigs; + private final KafkaClientSupplier clientSupplier; + private final StreamThread.ProcessingMode processingMode; + private final Time time; + + private Producer producer; + private boolean transactionInFlight = false; + private boolean transactionInitialized = false; + private double oldProducerTotalBlockedTime = 0; + + public StreamsProducer(final StreamsConfig config, + final String threadId, + final KafkaClientSupplier clientSupplier, + final TaskId taskId, + final UUID processId, + final LogContext logContext, + final Time time) { + Objects.requireNonNull(config, "config cannot be null"); + Objects.requireNonNull(threadId, "threadId cannot be null"); + this.clientSupplier = Objects.requireNonNull(clientSupplier, "clientSupplier cannot be null"); + log = Objects.requireNonNull(logContext, "logContext cannot be null").logger(getClass()); + logPrefix = logContext.logPrefix().trim(); + this.time = Objects.requireNonNull(time, "time"); + + processingMode = StreamThread.processingMode(config); + + final Map producerConfigs; + switch (processingMode) { + case AT_LEAST_ONCE: { + producerConfigs = config.getProducerConfigs(getThreadProducerClientId(threadId)); + eosV2ProducerConfigs = null; + + break; + } + case EXACTLY_ONCE_ALPHA: { + producerConfigs = config.getProducerConfigs( + getTaskProducerClientId( + threadId, + Objects.requireNonNull(taskId, "taskId cannot be null for exactly-once alpha") + ) + ); + + final String applicationId = config.getString(StreamsConfig.APPLICATION_ID_CONFIG); + producerConfigs.put(ProducerConfig.TRANSACTIONAL_ID_CONFIG, applicationId + "-" + taskId); + + eosV2ProducerConfigs = null; + + break; + } + case EXACTLY_ONCE_V2: { + producerConfigs = config.getProducerConfigs(getThreadProducerClientId(threadId)); + + final String applicationId = config.getString(StreamsConfig.APPLICATION_ID_CONFIG); + producerConfigs.put( + ProducerConfig.TRANSACTIONAL_ID_CONFIG, + applicationId + "-" + + Objects.requireNonNull(processId, "processId cannot be null for exactly-once v2") + + "-" + threadId.split("-StreamThread-")[1]); + + eosV2ProducerConfigs = producerConfigs; + + break; + } + default: + throw new IllegalArgumentException("Unknown processing mode: " + processingMode); + } + + producer = clientSupplier.getProducer(producerConfigs); + } + + private String formatException(final String message) { + return message + " [" + logPrefix + "]"; + } + + boolean eosEnabled() { + return StreamThread.eosEnabled(processingMode); + } + + /** + * @throws IllegalStateException if EOS is disabled + */ + void initTransaction() { + if (!eosEnabled()) { + throw new IllegalStateException(formatException("Exactly-once is not enabled")); + } + if (!transactionInitialized) { + // initialize transactions if eos is turned on, which will block if the previous transaction has not + // completed yet; do not start the first transaction until the topology has been initialized later + try { + producer.initTransactions(); + transactionInitialized = true; + } catch (final TimeoutException timeoutException) { + log.warn( + "Timeout exception caught trying to initialize transactions. 
" + + "The broker is either slow or in bad state (like not having enough replicas) in " + + "responding to the request, or the connection to broker was interrupted sending " + + "the request or receiving the response. " + + "Will retry initializing the task in the next loop. " + + "Consider overwriting {} to a larger value to avoid timeout errors", + ProducerConfig.MAX_BLOCK_MS_CONFIG + ); + + // re-throw to trigger `task.timeout.ms` + throw timeoutException; + } catch (final KafkaException exception) { + throw new StreamsException( + formatException("Error encountered trying to initialize transactions"), + exception + ); + } + } + } + + public void resetProducer() { + if (processingMode != EXACTLY_ONCE_V2) { + throw new IllegalStateException("Expected eos-v2 to be enabled, but the processing mode was " + processingMode); + } + + oldProducerTotalBlockedTime += totalBlockedTime(producer); + final long start = time.nanoseconds(); + producer.close(); + final long closeTime = time.nanoseconds() - start; + oldProducerTotalBlockedTime += closeTime; + + producer = clientSupplier.getProducer(eosV2ProducerConfigs); + transactionInitialized = false; + } + + private double getMetricValue(final Map metrics, + final String name) { + final List found = metrics.keySet().stream() + .filter(n -> n.name().equals(name)) + .collect(Collectors.toList()); + if (found.isEmpty()) { + return 0.0; + } + if (found.size() > 1) { + final String err = String.format( + "found %d values for metric %s. total blocked time computation may be incorrect", + found.size(), + name + ); + log.error(err); + throw new IllegalStateException(err); + } + return (Double) metrics.get(found.get(0)).metricValue(); + } + + private double totalBlockedTime(final Producer producer) { + return getMetricValue(producer.metrics(), "bufferpool-wait-time-ns-total") + + getMetricValue(producer.metrics(), "flush-time-ns-total") + + getMetricValue(producer.metrics(), "txn-init-time-ns-total") + + getMetricValue(producer.metrics(), "txn-begin-time-ns-total") + + getMetricValue(producer.metrics(), "txn-send-offsets-time-ns-total") + + getMetricValue(producer.metrics(), "txn-commit-time-ns-total") + + getMetricValue(producer.metrics(), "txn-abort-time-ns-total"); + } + + public double totalBlockedTime() { + return oldProducerTotalBlockedTime + totalBlockedTime(producer); + } + + private void maybeBeginTransaction() { + if (eosEnabled() && !transactionInFlight) { + try { + producer.beginTransaction(); + transactionInFlight = true; + } catch (final ProducerFencedException | InvalidProducerEpochException error) { + throw new TaskMigratedException( + formatException("Producer got fenced trying to begin a new transaction"), + error + ); + } catch (final KafkaException error) { + throw new StreamsException( + formatException("Error encountered trying to begin a new transaction"), + error + ); + } + } + } + + Future send(final ProducerRecord record, + final Callback callback) { + maybeBeginTransaction(); + try { + return producer.send(record, callback); + } catch (final KafkaException uncaughtException) { + if (isRecoverable(uncaughtException)) { + // producer.send() call may throw a KafkaException which wraps a FencedException, + // in this case we should throw its wrapped inner cause so that it can be + // captured and re-wrapped as TaskMigratedException + throw new TaskMigratedException( + formatException("Producer got fenced trying to send a record"), + uncaughtException.getCause() + ); + } else { + throw new StreamsException( + 
formatException(String.format("Error encountered trying to send record to topic %s", record.topic())), + uncaughtException + ); + } + } + } + + private static boolean isRecoverable(final KafkaException uncaughtException) { + return uncaughtException.getCause() instanceof ProducerFencedException || + uncaughtException.getCause() instanceof InvalidProducerEpochException || + uncaughtException.getCause() instanceof UnknownProducerIdException; + } + + /** + * @throws IllegalStateException if EOS is disabled + * @throws TaskMigratedException + */ + protected void commitTransaction(final Map offsets, + final ConsumerGroupMetadata consumerGroupMetadata) { + if (!eosEnabled()) { + throw new IllegalStateException(formatException("Exactly-once is not enabled")); + } + maybeBeginTransaction(); + try { + // EOS-v2 assumes brokers are on version 2.5+ and thus can understand the full set of consumer group metadata + // Thus if we are using EOS-v1 and can't make this assumption, we must downgrade the request to include only the group id metadata + final ConsumerGroupMetadata maybeDowngradedGroupMetadata = processingMode == EXACTLY_ONCE_V2 ? consumerGroupMetadata : new ConsumerGroupMetadata(consumerGroupMetadata.groupId()); + producer.sendOffsetsToTransaction(offsets, maybeDowngradedGroupMetadata); + producer.commitTransaction(); + transactionInFlight = false; + } catch (final ProducerFencedException | InvalidProducerEpochException | CommitFailedException error) { + throw new TaskMigratedException( + formatException("Producer got fenced trying to commit a transaction"), + error + ); + } catch (final TimeoutException timeoutException) { + // re-throw to trigger `task.timeout.ms` + throw timeoutException; + } catch (final KafkaException error) { + throw new StreamsException( + formatException("Error encountered trying to commit a transaction"), + error + ); + } + } + + /** + * @throws IllegalStateException if EOS is disabled + */ + void abortTransaction() { + if (!eosEnabled()) { + throw new IllegalStateException(formatException("Exactly-once is not enabled")); + } + if (transactionInFlight) { + try { + producer.abortTransaction(); + } catch (final TimeoutException logAndSwallow) { + // no need to re-throw because we abort a TX only if we close a task dirty, + // and thus `task.timeout.ms` does not apply + log.warn( + "Aborting transaction failed due to timeout." + + " Will rely on broker to eventually abort the transaction after the transaction timeout passed.", + logAndSwallow + ); + } catch (final ProducerFencedException | InvalidProducerEpochException error) { + // The producer is aborting the txn when there's still an ongoing one, + // which means that we did not commit the task while closing it, which + // means that it is a dirty close. Therefore it is possible that the dirty + // close is due to an fenced exception already thrown previously, and hence + // when calling abortTxn here the same exception would be thrown again. + // Even if the dirty close was not due to an observed fencing exception but + // something else (e.g. 
task corrupted) we can still ignore the exception here + // since transaction already got aborted by brokers/transactional-coordinator if this happens + log.debug("Encountered {} while aborting the transaction; this is expected and hence swallowed", error.getMessage()); + } catch (final KafkaException error) { + throw new StreamsException( + formatException("Error encountered trying to abort a transaction"), + error + ); + } + transactionInFlight = false; + } + } + + /** + * Cf {@link KafkaProducer#partitionsFor(String)} + */ + List partitionsFor(final String topic) { + return producer.partitionsFor(topic); + } + + Map metrics() { + return producer.metrics(); + } + + void flush() { + producer.flush(); + } + + void close() { + producer.close(); + } + + // for testing only + Producer kafkaProducer() { + return producer; + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamsRebalanceListener.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamsRebalanceListener.java new file mode 100644 index 0000000..ba2883b --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/StreamsRebalanceListener.java @@ -0,0 +1,120 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.processor.internals; + +import java.util.concurrent.atomic.AtomicInteger; +import org.apache.kafka.clients.consumer.ConsumerRebalanceListener; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.streams.errors.MissingSourceTopicException; +import org.apache.kafka.streams.errors.TaskAssignmentException; +import org.apache.kafka.streams.processor.internals.StreamThread.State; +import org.apache.kafka.streams.processor.internals.assignment.AssignorError; +import org.slf4j.Logger; + +import java.util.Collection; + +public class StreamsRebalanceListener implements ConsumerRebalanceListener { + + private final Time time; + private final TaskManager taskManager; + private final StreamThread streamThread; + private final Logger log; + private final AtomicInteger assignmentErrorCode; + + StreamsRebalanceListener(final Time time, + final TaskManager taskManager, + final StreamThread streamThread, + final Logger log, + final AtomicInteger assignmentErrorCode) { + this.time = time; + this.taskManager = taskManager; + this.streamThread = streamThread; + this.log = log; + this.assignmentErrorCode = assignmentErrorCode; + } + + @Override + public void onPartitionsAssigned(final Collection partitions) { + // NB: all task management is already handled by: + // org.apache.kafka.streams.processor.internals.StreamsPartitionAssignor.onAssignment + if (assignmentErrorCode.get() == AssignorError.INCOMPLETE_SOURCE_TOPIC_METADATA.code()) { + log.error("Received error code {}", AssignorError.INCOMPLETE_SOURCE_TOPIC_METADATA); + taskManager.handleRebalanceComplete(); + throw new MissingSourceTopicException("One or more source topics were missing during rebalance"); + } else if (assignmentErrorCode.get() == AssignorError.VERSION_PROBING.code()) { + log.info("Received version probing code {}", AssignorError.VERSION_PROBING); + } else if (assignmentErrorCode.get() == AssignorError.ASSIGNMENT_ERROR.code()) { + log.error("Received error code {}", AssignorError.ASSIGNMENT_ERROR); + taskManager.handleRebalanceComplete(); + throw new TaskAssignmentException("Hit an unexpected exception during task assignment phase of rebalance"); + } else if (assignmentErrorCode.get() == AssignorError.SHUTDOWN_REQUESTED.code()) { + log.error("A Kafka Streams client in this Kafka Streams application is requesting to shutdown the application"); + taskManager.handleRebalanceComplete(); + streamThread.shutdownToError(); + return; + } else if (assignmentErrorCode.get() != AssignorError.NONE.code()) { + log.error("Received unknown error code {}", assignmentErrorCode.get()); + throw new TaskAssignmentException("Hit an unrecognized exception during rebalance"); + } + + streamThread.setState(State.PARTITIONS_ASSIGNED); + streamThread.setPartitionAssignedTime(time.milliseconds()); + taskManager.handleRebalanceComplete(); + } + + @Override + public void onPartitionsRevoked(final Collection partitions) { + log.debug("Current state {}: revoked partitions {} because of consumer rebalance.\n" + + "\tcurrently assigned active tasks: {}\n" + + "\tcurrently assigned standby tasks: {}\n", + streamThread.state(), + partitions, + taskManager.activeTaskIds(), + taskManager.standbyTaskIds()); + + // We need to still invoke handleRevocation if the thread has been told to shut down, but we shouldn't ever + // transition away from PENDING_SHUTDOWN once it's been initiated (to anything other than DEAD) + if ((streamThread.setState(State.PARTITIONS_REVOKED) != null 
|| streamThread.state() == State.PENDING_SHUTDOWN) && !partitions.isEmpty()) { + final long start = time.milliseconds(); + try { + taskManager.handleRevocation(partitions); + } finally { + log.info("partition revocation took {} ms.", time.milliseconds() - start); + } + } + } + + @Override + public void onPartitionsLost(final Collection partitions) { + log.info("at state {}: partitions {} lost due to missed rebalance.\n" + + "\tlost active tasks: {}\n" + + "\tlost assigned standby tasks: {}\n", + streamThread.state(), + partitions, + taskManager.activeTaskIds(), + taskManager.standbyTaskIds()); + + final long start = time.milliseconds(); + try { + // close all active tasks as lost but don't try to commit offsets as we no longer own them + taskManager.handleLostAll(); + } finally { + log.info("partitions lost took {} ms.", time.milliseconds() - start); + } + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/Task.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/Task.java new file mode 100644 index 0000000..3549ba2 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/Task.java @@ -0,0 +1,249 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.processor.internals; + +import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.clients.consumer.OffsetAndMetadata; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.streams.errors.LockException; +import org.apache.kafka.streams.errors.StreamsException; +import org.apache.kafka.streams.processor.StateStore; +import org.apache.kafka.streams.processor.TaskId; + +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; + +public interface Task { + + // this must be negative to distinguish a running active task from other kinds of tasks + // which may be caught up to the same offsets + long LATEST_OFFSET = -2L; + + /* + *

                +     *                 +---------------+
                +     *          +----- |  Created (0)  | <---------+
                +     *          |      +-------+-------+           |
                +     *          |              |                   |
                +     *          |              v                   |
                +     *          |      +-------+-------+           |
                +     *          +----- | Restoring (1) | <---+     |
                +     *          |      +-------+-------+     |     |
                +     *          |              |             |     |
                +     *          |              v             |     |
                +     *          |      +-------+-------+     |     |
                +     *          |      |  Running (2)  |     |     |
                +     *          |      +-------+-------+     |     |
                +     *          |              |             |     |
                +     *          |              v             |     |
                +     *          |     +--------+-------+     |     |
                +     *          +---> |  Suspended (3) | ----+     |    //TODO Suspended(3) could be removed after we're stable on KIP-429
                +     *                +--------+-------+           |
                +     *                         |                   |
                +     *                         v                   |
                +     *                +--------+-------+           |
                +     *                |   Closed (4)   | ----------+
                +     *                +----------------+
                +     * 
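                +     * 
                +     * For illustration, each edge in the diagram corresponds to an entry in the ordinal sets passed
                +     * to the State constructors below, so a sketch of what isValidTransition is expected to answer is:
                +     * 
                +     *     State.CREATED.isValidTransition(State.RESTORING)   // true:  0 -> 1 is an edge above
                +     *     State.CREATED.isValidTransition(State.RUNNING)     // false: there is no direct 0 -> 2 edge
                +     *     State.CLOSED.isValidTransition(State.CREATED)      // true:  corrupted tasks get revived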
                + */ + enum State { + CREATED(1, 3), // 0 + RESTORING(2, 3), // 1 + RUNNING(3), // 2 + SUSPENDED(1, 4), // 3 + CLOSED(0); // 4, we allow CLOSED to transit to CREATED to handle corrupted tasks + + private final Set validTransitions = new HashSet<>(); + + State(final Integer... validTransitions) { + this.validTransitions.addAll(Arrays.asList(validTransitions)); + } + + public boolean isValidTransition(final State newState) { + return validTransitions.contains(newState.ordinal()); + } + } + + enum TaskType { + ACTIVE("ACTIVE"), + + STANDBY("STANDBY"), + + GLOBAL("GLOBAL"); + + public final String name; + + TaskType(final String name) { + this.name = name; + } + } + + + // idempotent life-cycle methods + + /** + * @throws LockException could happen when multi-threads within the single instance, could retry + * @throws StreamsException fatal error, should close the thread + */ + void initializeIfNeeded(); + + default void addPartitionsForOffsetReset(final Set partitionsForOffsetReset) { + throw new UnsupportedOperationException(); + } + + /** + * @throws StreamsException fatal error, should close the thread + */ + void completeRestoration(final java.util.function.Consumer> offsetResetter); + + void suspend(); + + /** + * @throws StreamsException fatal error, should close the thread + */ + void resume(); + + /** + * Must be idempotent. + */ + void closeDirty(); + + /** + * Must be idempotent. + */ + void closeClean(); + + + // non-idempotent life-cycle methods + + /** + * Updates input partitions and topology after rebalance + */ + void updateInputPartitions(final Set topicPartitions, final Map> allTopologyNodesToSourceTopics); + + void markChangelogAsCorrupted(final Collection partitions); + + /** + * Revive a closed task to a created one; should never throw an exception + */ + void revive(); + + /** + * Attempt a clean close but do not close the underlying state + */ + void closeCleanAndRecycleState(); + + + // runtime methods (using in RUNNING state) + + void addRecords(TopicPartition partition, Iterable> records); + + default boolean process(final long wallClockTime) { + return false; + } + + default void recordProcessBatchTime(final long processBatchTime) { + } + + default void recordProcessTimeRatioAndBufferSize(final long allTaskProcessMs, final long now) { + } + + default boolean maybePunctuateStreamTime() { + return false; + } + + default boolean maybePunctuateSystemTime() { + return false; + } + + /** + * @throws StreamsException fatal error, should close the thread + */ + Map prepareCommit(); + + void postCommit(boolean enforceCheckpoint); + + default Map purgeableOffsets() { + return Collections.emptyMap(); + } + + /** + * @throws StreamsException if {@code currentWallClockMs > task-timeout-deadline} + */ + void maybeInitTaskTimeoutOrThrow(final long currentWallClockMs, + final Exception cause); + + void clearTaskTimeout(); + + // task status inquiry + + TaskId id(); + + boolean isActive(); + + Set inputPartitions(); + + /** + * @return any changelog partitions associated with this task + */ + Collection changelogPartitions(); + + State state(); + + default boolean needsInitializationOrRestoration() { + return state() == State.CREATED || state() == State.RESTORING; + } + + boolean commitNeeded(); + + default boolean commitRequested() { + return false; + } + + + // IQ related methods + + StateStore getStore(final String name); + + /** + * @return the offsets of all the changelog partitions associated with this task, + * indicating the current positions of the logged 
state stores of the task. + */ + Map changelogOffsets(); + + /** + * @return the offsets that each TopicPartition has committed so far in this task, + * indicating how far the processing has committed + */ + Map committedOffsets(); + + /** + * @return the highest offsets that each TopicPartition has seen so far in this task + */ + Map highWaterMark(); + + /** + * @return This returns the time the task started idling. If it is not idling it returns empty. + */ + Optional timeCurrentIdlingStarted(); +} diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskAction.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskAction.java new file mode 100644 index 0000000..da5f325 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskAction.java @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.processor.internals; + +interface TaskAction { + String name(); + void apply(final T task); +} diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java new file mode 100644 index 0000000..c5ba0c1 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java @@ -0,0 +1,1478 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.processor.internals; + +import org.apache.kafka.clients.admin.Admin; +import org.apache.kafka.clients.admin.DeleteRecordsResult; +import org.apache.kafka.clients.admin.RecordsToDelete; +import org.apache.kafka.clients.consumer.CommitFailedException; +import org.apache.kafka.clients.consumer.Consumer; +import org.apache.kafka.clients.consumer.ConsumerRecords; +import org.apache.kafka.clients.consumer.OffsetAndMetadata; +import org.apache.kafka.common.KafkaException; +import org.apache.kafka.common.Metric; +import org.apache.kafka.common.MetricName; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.errors.TimeoutException; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.streams.errors.LockException; +import org.apache.kafka.streams.errors.StreamsException; +import org.apache.kafka.streams.errors.TaskCorruptedException; +import org.apache.kafka.streams.errors.TaskIdFormatException; +import org.apache.kafka.streams.errors.TaskMigratedException; +import org.apache.kafka.streams.processor.TaskId; +import org.apache.kafka.streams.processor.internals.StateDirectory.TaskDirectory; +import org.apache.kafka.streams.processor.internals.Task.State; +import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl; +import org.apache.kafka.streams.state.internals.OffsetCheckpoint; + +import org.slf4j.Logger; + +import java.io.File; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.Comparator; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Iterator; +import java.util.LinkedHashMap; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; +import java.util.Set; +import java.util.TreeSet; +import java.util.UUID; +import java.util.concurrent.atomic.AtomicReference; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import static org.apache.kafka.common.utils.Utils.intersection; +import static org.apache.kafka.common.utils.Utils.union; +import static org.apache.kafka.streams.processor.internals.StateManagerUtil.parseTaskDirectoryName; +import static org.apache.kafka.streams.processor.internals.StreamThread.ProcessingMode.EXACTLY_ONCE_ALPHA; +import static org.apache.kafka.streams.processor.internals.StreamThread.ProcessingMode.EXACTLY_ONCE_V2; + +public class TaskManager { + // initialize the task list + // activeTasks needs to be concurrent as it can be accessed + // by QueryableState + private final Logger log; + private final Time time; + private final ChangelogReader changelogReader; + private final UUID processId; + private final String logPrefix; + private final TopologyMetadata topologyMetadata; + private final Admin adminClient; + private final StateDirectory stateDirectory; + private final StreamThread.ProcessingMode processingMode; + private final Tasks tasks; + + private Consumer mainConsumer; + + private DeleteRecordsResult deleteRecordsResult; + + private boolean rebalanceInProgress = false; // if we are in the middle of a rebalance, it is not safe to commit + + // includes assigned & initialized tasks and unassigned tasks we locked temporarily during rebalance + private final Set lockedTaskDirectories = new HashSet<>(); + + TaskManager(final Time time, + final ChangelogReader changelogReader, + final UUID processId, + final String logPrefix, + final 
StreamsMetricsImpl streamsMetrics, + final ActiveTaskCreator activeTaskCreator, + final StandbyTaskCreator standbyTaskCreator, + final TopologyMetadata topologyMetadata, + final Admin adminClient, + final StateDirectory stateDirectory, + final StreamThread.ProcessingMode processingMode) { + this.time = time; + this.changelogReader = changelogReader; + this.processId = processId; + this.logPrefix = logPrefix; + this.topologyMetadata = topologyMetadata; + this.adminClient = adminClient; + this.stateDirectory = stateDirectory; + this.processingMode = processingMode; + this.tasks = new Tasks(logPrefix, topologyMetadata, streamsMetrics, activeTaskCreator, standbyTaskCreator); + + final LogContext logContext = new LogContext(logPrefix); + log = logContext.logger(getClass()); + } + + void setMainConsumer(final Consumer mainConsumer) { + this.mainConsumer = mainConsumer; + tasks.setMainConsumer(mainConsumer); + } + + public double totalProducerBlockedTime() { + return tasks.totalProducerBlockedTime(); + } + + public UUID processId() { + return processId; + } + + public TopologyMetadata topologyMetadata() { + return topologyMetadata; + } + + boolean isRebalanceInProgress() { + return rebalanceInProgress; + } + + void handleRebalanceStart(final Set subscribedTopics) { + topologyMetadata.addSubscribedTopicsFromMetadata(subscribedTopics, logPrefix); + + tryToLockAllNonEmptyTaskDirectories(); + + rebalanceInProgress = true; + } + + void handleRebalanceComplete() { + // we should pause consumer only within the listener since + // before then the assignment has not been updated yet. + mainConsumer.pause(mainConsumer.assignment()); + + releaseLockedUnassignedTaskDirectories(); + + rebalanceInProgress = false; + } + + /** + * Stop all tasks and consuming after the last named topology is removed to prevent further processing + */ + void handleEmptyTopology() { + log.info("Closing all tasks and unsubscribing the consumer due to empty topology"); + mainConsumer.unsubscribe(); + shutdown(true); + } + + /** + * @throws TaskMigratedException + */ + boolean handleCorruption(final Set corruptedTasks) { + final Set corruptedActiveTasks = new HashSet<>(); + final Set corruptedStandbyTasks = new HashSet<>(); + + for (final TaskId taskId : corruptedTasks) { + final Task task = tasks.task(taskId); + if (task.isActive()) { + corruptedActiveTasks.add(task); + } else { + corruptedStandbyTasks.add(task); + } + } + + // Make sure to clean up any corrupted standby tasks in their entirety before committing + // since TaskMigrated can be thrown and the resulting handleLostAll will only clean up active tasks + closeDirtyAndRevive(corruptedStandbyTasks, true); + + // We need to commit before closing the corrupted active tasks since this will force the ongoing txn to abort + try { + final Collection tasksToCommit = tasks() + .values() + .stream() + .filter(t -> t.state() == Task.State.RUNNING || t.state() == Task.State.RESTORING) + .filter(t -> !corruptedTasks.contains(t.id())) + .collect(Collectors.toSet()); + commitTasksAndMaybeUpdateCommittableOffsets(tasksToCommit, new HashMap<>()); + } catch (final TaskCorruptedException e) { + log.info("Some additional tasks were found corrupted while trying to commit, these will be added to the " + + "tasks to clean and revive: {}", e.corruptedTasks()); + corruptedActiveTasks.addAll(tasks.tasks(e.corruptedTasks())); + } catch (final TimeoutException e) { + log.info("Hit TimeoutException when committing all non-corrupted tasks, these will be closed and revived"); + final Collection 
uncorruptedTasks = new HashSet<>(tasks.activeTasks()); + uncorruptedTasks.removeAll(corruptedActiveTasks); + // Those tasks which just timed out can just be closed dirty without marking changelogs as corrupted + closeDirtyAndRevive(uncorruptedTasks, false); + } + + closeDirtyAndRevive(corruptedActiveTasks, true); + return !corruptedActiveTasks.isEmpty(); + } + + private void closeDirtyAndRevive(final Collection taskWithChangelogs, final boolean markAsCorrupted) { + for (final Task task : taskWithChangelogs) { + final Collection corruptedPartitions = task.changelogPartitions(); + + // mark corrupted partitions to not be checkpointed, and then close the task as dirty + if (markAsCorrupted) { + task.markChangelogAsCorrupted(corruptedPartitions); + } + + try { + // we do not need to take the returned offsets since we are not going to commit anyways; + // this call is only used for active tasks to flush the cache before suspending and + // closing the topology + task.prepareCommit(); + } catch (final RuntimeException swallow) { + log.error("Error flushing cache for corrupted task {} ", task.id(), swallow); + } + + try { + task.suspend(); + + // we need to enforce a checkpoint that removes the corrupted partitions + if (markAsCorrupted) { + task.postCommit(true); + } + } catch (final RuntimeException swallow) { + log.error("Error suspending corrupted task {} ", task.id(), swallow); + } + task.closeDirty(); + + // For active tasks pause their input partitions so we won't poll any more records + // for this task until it has been re-initialized; + // Note, closeDirty already clears the partition-group for the task. + if (task.isActive()) { + final Set currentAssignment = mainConsumer.assignment(); + final Set taskInputPartitions = task.inputPartitions(); + final Set assignedToPauseAndReset = + intersection(HashSet::new, currentAssignment, taskInputPartitions); + if (!assignedToPauseAndReset.equals(taskInputPartitions)) { + log.warn( + "Expected the current consumer assignment {} to contain the input partitions {}. 
" + + "Will proceed to recover.", + currentAssignment, + taskInputPartitions + ); + } + + task.addPartitionsForOffsetReset(assignedToPauseAndReset); + } + task.revive(); + } + } + + /** + * @throws TaskMigratedException if the task producer got fenced (EOS only) + * @throws StreamsException fatal error while creating / initializing the task + * + * public for upgrade testing only + */ + public void handleAssignment(final Map> activeTasks, + final Map> standbyTasks) { + log.info("Handle new assignment with:\n" + + "\tNew active tasks: {}\n" + + "\tNew standby tasks: {}\n" + + "\tExisting active tasks: {}\n" + + "\tExisting standby tasks: {}", + activeTasks.keySet(), standbyTasks.keySet(), activeTaskIds(), standbyTaskIds()); + + topologyMetadata.addSubscribedTopicsFromAssignment( + activeTasks.values().stream().flatMap(Collection::stream).collect(Collectors.toList()), + logPrefix + ); + + final LinkedHashMap taskCloseExceptions = new LinkedHashMap<>(); + final Map> activeTasksToCreate = new HashMap<>(activeTasks); + final Map> standbyTasksToCreate = new HashMap<>(standbyTasks); + final Comparator byId = Comparator.comparing(Task::id); + final Set tasksToRecycle = new TreeSet<>(byId); + final Set tasksToCloseClean = new TreeSet<>(byId); + final Set tasksToCloseDirty = new TreeSet<>(byId); + + // first rectify all existing tasks + for (final Task task : tasks.allTasks()) { + if (activeTasks.containsKey(task.id()) && task.isActive()) { + tasks.updateInputPartitionsAndResume(task, activeTasks.get(task.id())); + activeTasksToCreate.remove(task.id()); + } else if (standbyTasks.containsKey(task.id()) && !task.isActive()) { + tasks.updateInputPartitionsAndResume(task, standbyTasks.get(task.id())); + standbyTasksToCreate.remove(task.id()); + } else if (activeTasks.containsKey(task.id()) || standbyTasks.containsKey(task.id())) { + // check for tasks that were owned previously but have changed active/standby status + tasksToRecycle.add(task); + } else { + tasksToCloseClean.add(task); + } + } + + // close and recycle those tasks + handleCloseAndRecycle( + tasksToRecycle, + tasksToCloseClean, + tasksToCloseDirty, + activeTasksToCreate, + standbyTasksToCreate, + taskCloseExceptions + ); + + if (!taskCloseExceptions.isEmpty()) { + log.error("Hit exceptions while closing / recycling tasks: {}", taskCloseExceptions); + + for (final Map.Entry entry : taskCloseExceptions.entrySet()) { + if (!(entry.getValue() instanceof TaskMigratedException)) { + final TaskId taskId = entry.getKey(); + final RuntimeException exception = entry.getValue(); + if (exception instanceof StreamsException) { + ((StreamsException) exception).setTaskId(taskId); + throw exception; + } else if (exception instanceof KafkaException) { + throw new StreamsException(exception, taskId); + } else { + throw new StreamsException( + "Unexpected failure to close " + taskCloseExceptions.size() + + " task(s) [" + taskCloseExceptions.keySet() + "]. " + + "First unexpected exception (for task " + taskId + ") follows.", + exception, + taskId + ); + } + } + } + + // If all exceptions are task-migrated, we would just throw the first one. 
No need to wrap with a + // StreamsException since TaskMigrated is handled explicitly by the StreamThread + final Map.Entry first = taskCloseExceptions.entrySet().iterator().next(); + throw first.getValue(); + } + + tasks.handleNewAssignmentAndCreateTasks(activeTasksToCreate, standbyTasksToCreate, activeTasks.keySet(), standbyTasks.keySet()); + } + + private void handleCloseAndRecycle(final Set tasksToRecycle, + final Set tasksToCloseClean, + final Set tasksToCloseDirty, + final Map> activeTasksToCreate, + final Map> standbyTasksToCreate, + final LinkedHashMap taskCloseExceptions) { + if (!tasksToCloseDirty.isEmpty()) { + throw new IllegalArgumentException("Tasks to close-dirty should be empty"); + } + + // for all tasks to close or recycle, we should first write a checkpoint as in post-commit + final List tasksToCheckpoint = new ArrayList<>(tasksToCloseClean); + tasksToCheckpoint.addAll(tasksToRecycle); + for (final Task task : tasksToCheckpoint) { + try { + // Note that we are not actually committing here but just check if we need to write checkpoint file: + // 1) for active tasks prepareCommit should return empty if it has committed during suspension successfully, + // and their changelog positions should not change at all postCommit would not write the checkpoint again. + // 2) for standby tasks prepareCommit should always return empty, and then in postCommit we would probably + // write the checkpoint file. + final Map offsets = task.prepareCommit(); + if (!offsets.isEmpty()) { + log.error("Task {} should have been committed when it was suspended, but it reports non-empty " + + "offsets {} to commit; this means it failed during last commit and hence should be closed dirty", + task.id(), offsets); + + tasksToCloseDirty.add(task); + } else if (!task.isActive()) { + // For standby tasks, always try to first suspend before committing (checkpointing) it; + // Since standby tasks do not actually need to commit offsets but only need to + // flush / checkpoint state stores, so we only need to call postCommit here. + task.suspend(); + + task.postCommit(true); + } + } catch (final RuntimeException e) { + final String uncleanMessage = String.format( + "Failed to checkpoint task %s. Attempting to close remaining tasks before re-throwing:", + task.id()); + log.error(uncleanMessage, e); + taskCloseExceptions.putIfAbsent(task.id(), e); + // We've already recorded the exception (which is the point of clean). + // Now, we should go ahead and complete the close because a half-closed task is no good to anyone. + tasksToCloseDirty.add(task); + } + } + + tasksToCloseClean.removeAll(tasksToCloseDirty); + for (final Task task : tasksToCloseClean) { + try { + completeTaskCloseClean(task); + if (task.isActive()) { + tasks.cleanUpTaskProducerAndRemoveTask(task.id(), taskCloseExceptions); + } + } catch (final RuntimeException e) { + final String uncleanMessage = String.format( + "Failed to close task %s cleanly. 
Attempting to close remaining tasks before re-throwing:", + task.id()); + log.error(uncleanMessage, e); + taskCloseExceptions.putIfAbsent(task.id(), e); + tasksToCloseDirty.add(task); + } + } + + tasksToRecycle.removeAll(tasksToCloseDirty); + for (final Task oldTask : tasksToRecycle) { + try { + if (oldTask.isActive()) { + final Set partitions = standbyTasksToCreate.remove(oldTask.id()); + tasks.convertActiveToStandby((StreamTask) oldTask, partitions, taskCloseExceptions); + } else { + final Set partitions = activeTasksToCreate.remove(oldTask.id()); + tasks.convertStandbyToActive((StandbyTask) oldTask, partitions); + } + } catch (final RuntimeException e) { + final String uncleanMessage = String.format("Failed to recycle task %s cleanly. Attempting to close remaining tasks before re-throwing:", oldTask.id()); + log.error(uncleanMessage, e); + taskCloseExceptions.putIfAbsent(oldTask.id(), e); + tasksToCloseDirty.add(oldTask); + } + } + + // for tasks that cannot be cleanly closed or recycled, close them dirty + for (final Task task : tasksToCloseDirty) { + closeTaskDirty(task); + tasks.cleanUpTaskProducerAndRemoveTask(task.id(), taskCloseExceptions); + } + } + + /** + * Tries to initialize any new or still-uninitialized tasks, then checks if they can/have completed restoration. + * + * @throws IllegalStateException If store gets registered after initialized is already finished + * @throws StreamsException if the store's change log does not contain the partition + * @return {@code true} if all tasks are fully restored + */ + boolean tryToCompleteRestoration(final long now, final java.util.function.Consumer> offsetResetter) { + boolean allRunning = true; + + final List activeTasks = new LinkedList<>(); + for (final Task task : tasks.allTasks()) { + try { + task.initializeIfNeeded(); + task.clearTaskTimeout(); + } catch (final LockException lockException) { + // it is possible that if there are multiple threads within the instance that one thread + // trying to grab the task from the other, while the other has not released the lock since + // it did not participate in the rebalance. In this case we can just retry in the next iteration + log.debug("Could not initialize task {} since: {}; will retry", task.id(), lockException.getMessage()); + allRunning = false; + } catch (final TimeoutException timeoutException) { + task.maybeInitTaskTimeoutOrThrow(now, timeoutException); + allRunning = false; + } + + if (task.isActive()) { + activeTasks.add(task); + } + } + + if (allRunning && !activeTasks.isEmpty()) { + + final Set restored = changelogReader.completedChangelogs(); + + for (final Task task : activeTasks) { + if (restored.containsAll(task.changelogPartitions())) { + try { + task.completeRestoration(offsetResetter); + task.clearTaskTimeout(); + } catch (final TimeoutException timeoutException) { + task.maybeInitTaskTimeoutOrThrow(now, timeoutException); + log.debug( + String.format( + "Could not complete restoration for %s due to the following exception; will retry", + task.id()), + timeoutException + ); + + allRunning = false; + } + } else { + // we found a restoring task that isn't done restoring, which is evidence that + // not all tasks are running + allRunning = false; + } + } + } + + if (allRunning) { + // we can call resume multiple times since it is idempotent. 
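+ // As an illustration (assuming a plain KafkaConsumer): the partitions were paused in
+ // handleRebalanceComplete(), and Consumer#resume is a no-op for partitions that are not paused,
+ // so repeating the call below on every pass through this method is harmless, e.g.
+ //     mainConsumer.resume(mainConsumer.assignment());
+ //     mainConsumer.resume(mainConsumer.assignment());   // the second call changes nothing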
+ mainConsumer.resume(mainConsumer.assignment()); + } + + return allRunning; + } + + /** + * Handle the revoked partitions and prepare for closing the associated tasks in {@link #handleAssignment(Map, Map)} + * We should commit the revoking tasks first before suspending them as we will not officially own them anymore when + * {@link #handleAssignment(Map, Map)} is called. Note that only active task partitions are passed in from the + * rebalance listener, so we only need to consider/commit active tasks here + * + * If eos-v2 is used, we must commit ALL tasks. Otherwise, we can just commit those (active) tasks which are revoked + * + * @throws TaskMigratedException if the task producer got fenced (EOS only) + */ + void handleRevocation(final Collection revokedPartitions) { + final Set remainingRevokedPartitions = new HashSet<>(revokedPartitions); + + final Set revokedActiveTasks = new HashSet<>(); + final Set commitNeededActiveTasks = new HashSet<>(); + final Map> consumedOffsetsPerTask = new HashMap<>(); + final AtomicReference firstException = new AtomicReference<>(null); + + for (final Task task : activeTaskIterable()) { + if (remainingRevokedPartitions.containsAll(task.inputPartitions())) { + // when the task input partitions are included in the revoked list, + // this is an active task and should be revoked + revokedActiveTasks.add(task); + remainingRevokedPartitions.removeAll(task.inputPartitions()); + } else if (task.commitNeeded()) { + commitNeededActiveTasks.add(task); + } + } + + if (!remainingRevokedPartitions.isEmpty()) { + log.debug("The following revoked partitions {} are missing from the current task partitions. It could " + + "potentially be due to race condition of consumer detecting the heartbeat failure, or the tasks " + + "have been cleaned up by the handleAssignment callback.", remainingRevokedPartitions); + } + + prepareCommitAndAddOffsetsToMap(revokedActiveTasks, consumedOffsetsPerTask); + + // if we need to commit any revoking task then we just commit all of those needed committing together + final boolean shouldCommitAdditionalTasks = !consumedOffsetsPerTask.isEmpty(); + if (shouldCommitAdditionalTasks) { + prepareCommitAndAddOffsetsToMap(commitNeededActiveTasks, consumedOffsetsPerTask); + } + + // even if commit failed, we should still continue and complete suspending those tasks, so we would capture + // any exception and rethrow it at the end. 
some exceptions may be handled immediately and then swallowed, + // as such we just need to skip those dirty tasks in the checkpoint + final Set dirtyTasks = new HashSet<>(); + try { + // in handleRevocation we must call commitOffsetsOrTransaction() directly rather than + // commitAndFillInConsumedOffsetsAndMetadataPerTaskMap() to make sure we don't skip the + // offset commit because we are in a rebalance + commitOffsetsOrTransaction(consumedOffsetsPerTask); + } catch (final TaskCorruptedException e) { + log.warn("Some tasks were corrupted when trying to commit offsets, these will be cleaned and revived: {}", + e.corruptedTasks()); + + // If we hit a TaskCorruptedException it must be EOS, just handle the cleanup for those corrupted tasks right here + dirtyTasks.addAll(tasks.tasks(e.corruptedTasks())); + closeDirtyAndRevive(dirtyTasks, true); + } catch (final TimeoutException e) { + log.warn("Timed out while trying to commit all tasks during revocation, these will be cleaned and revived"); + + // If we hit a TimeoutException it must be ALOS, just close dirty and revive without wiping the state + dirtyTasks.addAll(consumedOffsetsPerTask.keySet()); + closeDirtyAndRevive(dirtyTasks, false); + } catch (final RuntimeException e) { + log.error("Exception caught while committing those revoked tasks " + revokedActiveTasks, e); + firstException.compareAndSet(null, e); + dirtyTasks.addAll(consumedOffsetsPerTask.keySet()); + } + + // we enforce checkpointing upon suspending a task: if it is resumed later we just proceed normally, if it is + // going to be closed we would checkpoint by then + for (final Task task : revokedActiveTasks) { + if (!dirtyTasks.contains(task)) { + try { + task.postCommit(true); + } catch (final RuntimeException e) { + log.error("Exception caught while post-committing task " + task.id(), e); + maybeWrapAndSetFirstException(firstException, e, task.id()); + } + } + } + + if (shouldCommitAdditionalTasks) { + for (final Task task : commitNeededActiveTasks) { + if (!dirtyTasks.contains(task)) { + try { + // for non-revoking active tasks, we should not enforce checkpoint + // since if it is EOS enabled, no checkpoint should be written while + // the task is in RUNNING tate + task.postCommit(false); + } catch (final RuntimeException e) { + log.error("Exception caught while post-committing task " + task.id(), e); + maybeWrapAndSetFirstException(firstException, e, task.id()); + } + } + } + } + + for (final Task task : revokedActiveTasks) { + try { + task.suspend(); + } catch (final RuntimeException e) { + log.error("Caught the following exception while trying to suspend revoked task " + task.id(), e); + maybeWrapAndSetFirstException(firstException, e, task.id()); + } + } + + if (firstException.get() != null) { + throw firstException.get(); + } + } + + private void prepareCommitAndAddOffsetsToMap(final Set tasksToPrepare, + final Map> consumedOffsetsPerTask) { + for (final Task task : tasksToPrepare) { + try { + final Map committableOffsets = task.prepareCommit(); + if (!committableOffsets.isEmpty()) { + consumedOffsetsPerTask.put(task, committableOffsets); + } + } catch (final StreamsException e) { + e.setTaskId(task.id()); + throw e; + } catch (final Exception e) { + throw new StreamsException(e, task.id()); + } + } + } + + /** + * Closes active tasks as zombies, as these partitions have been lost and are no longer owned. + * NOTE this method assumes that when it is called, EVERY task/partition has been lost and must + * be closed as a zombie. 
+ * + * @throws TaskMigratedException if the task producer got fenced (EOS only) + */ + void handleLostAll() { + log.debug("Closing lost active tasks as zombies."); + + final Set allTask = new HashSet<>(tasks.allTasks()); + for (final Task task : allTask) { + // Even though we've apparently dropped out of the group, we can continue safely to maintain our + // standby tasks while we rejoin. + if (task.isActive()) { + closeTaskDirty(task); + + tasks.cleanUpTaskProducerAndRemoveTask(task.id(), new HashMap<>()); + } + } + + if (processingMode == EXACTLY_ONCE_V2) { + tasks.reInitializeThreadProducer(); + } + } + + /** + * Compute the offset total summed across all stores in a task. Includes offset sum for any tasks we own the + * lock for, which includes assigned and unassigned tasks we locked in {@link #tryToLockAllNonEmptyTaskDirectories()}. + * Does not include stateless or non-logged tasks. + */ + public Map getTaskOffsetSums() { + final Map taskOffsetSums = new HashMap<>(); + + // Not all tasks will create directories, and there may be directories for tasks we don't currently own, + // so we consider all tasks that are either owned or on disk. This includes stateless tasks, which should + // just have an empty changelogOffsets map. + for (final TaskId id : union(HashSet::new, lockedTaskDirectories, tasks.tasksPerId().keySet())) { + final Task task = tasks.owned(id) ? tasks.task(id) : null; + // Closed and uninitialized tasks don't have any offsets so we should read directly from the checkpoint + if (task != null && task.state() != State.CREATED && task.state() != State.CLOSED) { + final Map changelogOffsets = task.changelogOffsets(); + if (changelogOffsets.isEmpty()) { + log.debug("Skipping to encode apparently stateless (or non-logged) offset sum for task {}", id); + } else { + taskOffsetSums.put(id, sumOfChangelogOffsets(id, changelogOffsets)); + } + } else { + final File checkpointFile = stateDirectory.checkpointFileFor(id); + try { + if (checkpointFile.exists()) { + taskOffsetSums.put(id, sumOfChangelogOffsets(id, new OffsetCheckpoint(checkpointFile).read())); + } + } catch (final IOException e) { + log.warn(String.format("Exception caught while trying to read checkpoint for task %s:", id), e); + } + } + } + + return taskOffsetSums; + } + + /** + * Makes a weak attempt to lock all non-empty task directories in the state dir. We are responsible for computing and + * reporting the offset sum for any unassigned tasks we obtain the lock for in the upcoming rebalance. Tasks + * that we locked but didn't own will be released at the end of the rebalance (unless of course we were + * assigned the task as a result of the rebalance). This method should be idempotent. + */ + private void tryToLockAllNonEmptyTaskDirectories() { + // Always clear the set at the beginning as we're always dealing with the + // current set of actually-locked tasks. 
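+ // (A sketch of the directory-to-task mapping this relies on, assuming the standard on-disk layout
+ //  <state.dir>/<application.id>/<subtopology>_<partition>: for a directory named "1_2" the call to
+ //  parseTaskDirectoryName(dir.getName(), namedTopology) below resolves to the TaskId of sub-topology 1,
+ //  partition 2, while names that do not match the pattern are skipped via TaskIdFormatException.)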
+ lockedTaskDirectories.clear(); + + for (final TaskDirectory taskDir : stateDirectory.listNonEmptyTaskDirectories()) { + final File dir = taskDir.file(); + final String namedTopology = taskDir.namedTopology(); + try { + final TaskId id = parseTaskDirectoryName(dir.getName(), namedTopology); + if (stateDirectory.lock(id)) { + lockedTaskDirectories.add(id); + if (!tasks.owned(id)) { + log.debug("Temporarily locked unassigned task {} for the upcoming rebalance", id); + } + } + } catch (final TaskIdFormatException e) { + // ignore any unknown files that sit in the same directory + } + } + } + + /** + * Clean up after closed or removed tasks by making sure to unlock any remaining locked directories for them, for + * example unassigned tasks or those in the CREATED state when closed, since Task#close will not unlock them + */ + private void releaseLockedDirectoriesForTasks(final Set tasksToUnlock) { + final Iterator taskIdIterator = lockedTaskDirectories.iterator(); + while (taskIdIterator.hasNext()) { + final TaskId id = taskIdIterator.next(); + if (tasksToUnlock.contains(id)) { + stateDirectory.unlock(id); + taskIdIterator.remove(); + } + } + } + + /** + * We must release the lock for any unassigned tasks that we temporarily locked in preparation for a + * rebalance in {@link #tryToLockAllNonEmptyTaskDirectories()}. + */ + private void releaseLockedUnassignedTaskDirectories() { + final Iterator taskIdIterator = lockedTaskDirectories.iterator(); + while (taskIdIterator.hasNext()) { + final TaskId id = taskIdIterator.next(); + if (!tasks.owned(id)) { + stateDirectory.unlock(id); + taskIdIterator.remove(); + } + } + } + + private long sumOfChangelogOffsets(final TaskId id, final Map changelogOffsets) { + long offsetSum = 0L; + for (final Map.Entry changelogEntry : changelogOffsets.entrySet()) { + final long offset = changelogEntry.getValue(); + + + if (offset == Task.LATEST_OFFSET) { + // this condition can only be true for active tasks; never for standby + // for this case, the offset of all partitions is set to `LATEST_OFFSET` + // and we "forward" the sentinel value directly + return Task.LATEST_OFFSET; + } else if (offset != OffsetCheckpoint.OFFSET_UNKNOWN) { + if (offset < 0) { + throw new StreamsException( + new IllegalStateException("Expected not to get a sentinel offset, but got: " + changelogEntry), + id); + } + offsetSum += offset; + if (offsetSum < 0) { + log.warn("Sum of changelog offsets for task {} overflowed, pinning to Long.MAX_VALUE", id); + return Long.MAX_VALUE; + } + } + } + + return offsetSum; + } + + private void closeTaskDirty(final Task task) { + try { + // we call this function only to flush the case if necessary + // before suspending and closing the topology + task.prepareCommit(); + } catch (final RuntimeException swallow) { + log.error("Error flushing caches of dirty task {} ", task.id(), swallow); + } + + try { + task.suspend(); + } catch (final RuntimeException swallow) { + log.error("Error suspending dirty task {} ", task.id(), swallow); + } + tasks.removeTaskBeforeClosing(task.id()); + task.closeDirty(); + } + + private void completeTaskCloseClean(final Task task) { + tasks.removeTaskBeforeClosing(task.id()); + task.closeClean(); + } + + void shutdown(final boolean clean) { + final AtomicReference firstException = new AtomicReference<>(null); + + // TODO: change type to `StreamTask` + final Set activeTasks = new TreeSet<>(Comparator.comparing(Task::id)); + activeTasks.addAll(tasks.activeTasks()); + + executeAndMaybeSwallow( + clean, + () -> 
closeAndCleanUpTasks(activeTasks, standbyTaskIterable(), clean), + e -> firstException.compareAndSet(null, e), + e -> log.warn("Ignoring an exception while unlocking remaining task directories.", e) + ); + + executeAndMaybeSwallow( + clean, + tasks::closeThreadProducerIfNeeded, + e -> firstException.compareAndSet(null, e), + e -> log.warn("Ignoring an exception while closing thread producer.", e) + ); + + tasks.clear(); + + // this should be called after closing all tasks and clearing them from `tasks` to make sure we unlock the dir + // for any tasks that may have still been in CREATED at the time of shutdown, since Task#close will not do so + executeAndMaybeSwallow( + clean, + this::releaseLockedUnassignedTaskDirectories, + e -> firstException.compareAndSet(null, e), + e -> log.warn("Ignoring an exception while unlocking remaining task directories.", e) + ); + + final RuntimeException fatalException = firstException.get(); + if (fatalException != null) { + throw new RuntimeException("Unexpected exception while closing task", fatalException); + } + } + + /** + * Closes and cleans up after the provided tasks, including closing their corresponding task producers + */ + void closeAndCleanUpTasks(final Collection activeTasks, final Collection standbyTasks, final boolean clean) { + final AtomicReference firstException = new AtomicReference<>(null); + + final Set tasksToCloseDirty = new HashSet<>(); + tasksToCloseDirty.addAll(tryCloseCleanActiveTasks(activeTasks, clean, firstException)); + tasksToCloseDirty.addAll(tryCloseCleanStandbyTasks(standbyTasks, clean, firstException)); + + for (final Task task : tasksToCloseDirty) { + closeTaskDirty(task); + } + + // TODO: change type to `StreamTask` + for (final Task activeTask : activeTasks) { + executeAndMaybeSwallow( + clean, + () -> tasks.closeAndRemoveTaskProducerIfNeeded(activeTask), + e -> firstException.compareAndSet(null, e), + e -> log.warn("Ignoring an exception while closing task " + activeTask.id() + " producer.", e) + ); + } + + final RuntimeException exception = firstException.get(); + if (exception != null) { + throw exception; + } + } + + // Returns the set of active tasks that must be closed dirty + private Collection tryCloseCleanActiveTasks(final Collection activeTasksToClose, + final boolean clean, + final AtomicReference firstException) { + if (!clean) { + return activeTaskIterable(); + } + final Comparator byId = Comparator.comparing(Task::id); + final Set tasksToCommit = new TreeSet<>(byId); + final Set tasksToCloseDirty = new TreeSet<>(byId); + final Set tasksToCloseClean = new TreeSet<>(byId); + final Map> consumedOffsetsAndMetadataPerTask = new HashMap<>(); + + // first committing all tasks and then suspend and close them clean + for (final Task task : activeTasksToClose) { + try { + final Map committableOffsets = task.prepareCommit(); + tasksToCommit.add(task); + if (!committableOffsets.isEmpty()) { + consumedOffsetsAndMetadataPerTask.put(task, committableOffsets); + } + tasksToCloseClean.add(task); + } catch (final TaskMigratedException e) { + // just ignore the exception as it doesn't matter during shutdown + tasksToCloseDirty.add(task); + } catch (final StreamsException e) { + e.setTaskId(task.id()); + firstException.compareAndSet(null, e); + tasksToCloseDirty.add(task); + } catch (final RuntimeException e) { + firstException.compareAndSet(null, new StreamsException(e, task.id())); + tasksToCloseDirty.add(task); + } + } + + // If any active tasks can't be committed, none of them can be, and all that need a commit must be 
closed dirty + if (processingMode == EXACTLY_ONCE_V2 && !tasksToCloseDirty.isEmpty()) { + tasksToCloseClean.removeAll(tasksToCommit); + tasksToCloseDirty.addAll(tasksToCommit); + } else { + try { + commitOffsetsOrTransaction(consumedOffsetsAndMetadataPerTask); + + for (final Task task : activeTaskIterable()) { + try { + task.postCommit(true); + } catch (final RuntimeException e) { + log.error("Exception caught while post-committing task " + task.id(), e); + maybeWrapAndSetFirstException(firstException, e, task.id()); + tasksToCloseDirty.add(task); + tasksToCloseClean.remove(task); + } + } + } catch (final TimeoutException timeoutException) { + firstException.compareAndSet(null, timeoutException); + + tasksToCloseClean.removeAll(tasksToCommit); + tasksToCloseDirty.addAll(tasksToCommit); + } catch (final TaskCorruptedException taskCorruptedException) { + firstException.compareAndSet(null, taskCorruptedException); + + final Set corruptedTaskIds = taskCorruptedException.corruptedTasks(); + final Set corruptedTasks = tasksToCommit + .stream() + .filter(task -> corruptedTaskIds.contains(task.id())) + .collect(Collectors.toSet()); + + tasksToCloseClean.removeAll(corruptedTasks); + tasksToCloseDirty.addAll(corruptedTasks); + } catch (final RuntimeException e) { + log.error("Exception caught while committing tasks during shutdown", e); + firstException.compareAndSet(null, e); + + // If the commit fails, everyone who participated in it must be closed dirty + tasksToCloseClean.removeAll(tasksToCommit); + tasksToCloseDirty.addAll(tasksToCommit); + } + } + + for (final Task task : tasksToCloseClean) { + try { + task.suspend(); + completeTaskCloseClean(task); + } catch (final StreamsException e) { + log.error("Exception caught while clean-closing task " + task.id(), e); + e.setTaskId(task.id()); + firstException.compareAndSet(null, e); + tasksToCloseDirty.add(task); + } catch (final RuntimeException e) { + log.error("Exception caught while clean-closing task " + task.id(), e); + firstException.compareAndSet(null, new StreamsException(e, task.id())); + tasksToCloseDirty.add(task); + } + } + + return tasksToCloseDirty; + } + + // Returns the set of standby tasks that must be closed dirty + private Collection tryCloseCleanStandbyTasks(final Collection standbyTasksToClose, + final boolean clean, + final AtomicReference firstException) { + if (!clean) { + return standbyTaskIterable(); + } + final Set tasksToCloseDirty = new HashSet<>(); + + // first committing and then suspend / close clean + for (final Task task : standbyTasksToClose) { + try { + task.prepareCommit(); + task.postCommit(true); + task.suspend(); + completeTaskCloseClean(task); + } catch (final TaskMigratedException e) { + // just ignore the exception as it doesn't matter during shutdown + tasksToCloseDirty.add(task); + } catch (final RuntimeException e) { + maybeWrapAndSetFirstException(firstException, e, task.id()); + tasksToCloseDirty.add(task); + } + } + return tasksToCloseDirty; + } + + Set activeTaskIds() { + return activeTaskStream() + .map(Task::id) + .collect(Collectors.toSet()); + } + + Set standbyTaskIds() { + return standbyTaskStream() + .map(Task::id) + .collect(Collectors.toSet()); + } + + Map tasks() { + // not bothering with an unmodifiable map, since the tasks themselves are mutable, but + // if any outside code modifies the map or the tasks, it would be a severe transgression. 
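+ // For illustration, the intended read-only usage mirrors handleCorruption() above, e.g.
+ //     final Set<TaskId> running = tasks().values().stream()
+ //         .filter(t -> t.state() == Task.State.RUNNING)
+ //         .map(Task::id)
+ //         .collect(Collectors.toSet());
+ // callers should never modify the returned map or the tasks it contains.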
+ return tasks.tasksPerId(); + } + + Map activeTaskMap() { + return activeTaskStream().collect(Collectors.toMap(Task::id, t -> t)); + } + + List activeTaskIterable() { + return activeTaskStream().collect(Collectors.toList()); + } + + private Stream activeTaskStream() { + return tasks.allTasks().stream().filter(Task::isActive); + } + + Map standbyTaskMap() { + return standbyTaskStream().collect(Collectors.toMap(Task::id, t -> t)); + } + + private List standbyTaskIterable() { + return standbyTaskStream().collect(Collectors.toList()); + } + + private Stream standbyTaskStream() { + return tasks.allTasks().stream().filter(t -> !t.isActive()); + } + + // For testing only. + int commitAll() { + return commit(new HashSet<>(tasks.allTasks())); + } + + /** + * Take records and add them to each respective task + * + * @param records Records, can be null + */ + void addRecordsToTasks(final ConsumerRecords records) { + for (final TopicPartition partition : records.partitions()) { + final Task activeTask = tasks.activeTasksForInputPartition(partition); + + if (activeTask == null) { + log.error("Unable to locate active task for received-record partition {}. Current tasks: {}", + partition, toString(">")); + throw new NullPointerException("Task was unexpectedly missing for partition " + partition); + } + + activeTask.addRecords(partition, records.records(partition)); + } + } + + /** + * @throws TaskMigratedException if committing offsets failed (non-EOS) + * or if the task producer got fenced (EOS) + * @throws TimeoutException if task.timeout.ms has been exceeded (non-EOS) + * @throws TaskCorruptedException if committing offsets failed due to TimeoutException (EOS) + * @return number of committed offsets, or -1 if we are in the middle of a rebalance and cannot commit + */ + int commit(final Collection tasksToCommit) { + int committed = 0; + + final Map> consumedOffsetsAndMetadataPerTask = new HashMap<>(); + try { + committed = commitTasksAndMaybeUpdateCommittableOffsets(tasksToCommit, consumedOffsetsAndMetadataPerTask); + } catch (final TimeoutException timeoutException) { + consumedOffsetsAndMetadataPerTask + .keySet() + .forEach(t -> t.maybeInitTaskTimeoutOrThrow(time.milliseconds(), timeoutException)); + } + + return committed; + } + + /** + * @throws TaskMigratedException if committing offsets failed (non-EOS) + * or if the task producer got fenced (EOS) + * @throws TimeoutException if committing offsets failed due to TimeoutException (non-EOS) + * @throws TaskCorruptedException if committing offsets failed due to TimeoutException (EOS) + * @param consumedOffsetsAndMetadata an empty map that will be filled in with the prepared offsets + * @return number of committed offsets, or -1 if we are in the middle of a rebalance and cannot commit + */ + private int commitTasksAndMaybeUpdateCommittableOffsets(final Collection tasksToCommit, + final Map> consumedOffsetsAndMetadata) { + if (rebalanceInProgress) { + return -1; + } + + int committed = 0; + for (final Task task : tasksToCommit) { + // we need to call commitNeeded first since we need to update committable offsets + if (task.commitNeeded()) { + final Map offsetAndMetadata = task.prepareCommit(); + if (!offsetAndMetadata.isEmpty()) { + consumedOffsetsAndMetadata.put(task, offsetAndMetadata); + } + } + } + + commitOffsetsOrTransaction(consumedOffsetsAndMetadata); + + for (final Task task : tasksToCommit) { + if (task.commitNeeded()) { + task.clearTaskTimeout(); + ++committed; + task.postCommit(false); + } + } + return committed; + } + + /** + * @throws 
TaskMigratedException if committing offsets failed (non-EOS) + * or if the task producer got fenced (EOS) + */ + int maybeCommitActiveTasksPerUserRequested() { + if (rebalanceInProgress) { + return -1; + } else { + for (final Task task : activeTaskIterable()) { + if (task.commitRequested() && task.commitNeeded()) { + return commit(activeTaskIterable()); + } + } + return 0; + } + } + + /** + * Caution: do not invoke this directly if it's possible a rebalance is occurring, as the commit will fail. If + * this is a possibility, prefer the {@link #commitTasksAndMaybeUpdateCommittableOffsets} instead. + * + * @throws TaskMigratedException if committing offsets failed due to CommitFailedException (non-EOS) + * @throws TimeoutException if committing offsets failed due to TimeoutException (non-EOS) + * @throws TaskCorruptedException if committing offsets failed due to TimeoutException (EOS) + */ + private void commitOffsetsOrTransaction(final Map> offsetsPerTask) { + log.debug("Committing task offsets {}", offsetsPerTask.entrySet().stream().collect(Collectors.toMap(t -> t.getKey().id(), Entry::getValue))); // avoid logging actual Task objects + + final Set corruptedTasks = new HashSet<>(); + + if (!offsetsPerTask.isEmpty()) { + if (processingMode == EXACTLY_ONCE_ALPHA) { + for (final Map.Entry> taskToCommit : offsetsPerTask.entrySet()) { + final Task task = taskToCommit.getKey(); + try { + tasks.streamsProducerForTask(task.id()) + .commitTransaction(taskToCommit.getValue(), mainConsumer.groupMetadata()); + updateTaskCommitMetadata(taskToCommit.getValue()); + } catch (final TimeoutException timeoutException) { + log.error( + String.format("Committing task %s failed.", task.id()), + timeoutException + ); + corruptedTasks.add(task.id()); + } + } + } else { + final Map allOffsets = offsetsPerTask.values().stream() + .flatMap(e -> e.entrySet().stream()).collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); + + if (processingMode == EXACTLY_ONCE_V2) { + try { + tasks.threadProducer().commitTransaction(allOffsets, mainConsumer.groupMetadata()); + updateTaskCommitMetadata(allOffsets); + } catch (final TimeoutException timeoutException) { + log.error( + String.format("Committing task(s) %s failed.", + offsetsPerTask + .keySet() + .stream() + .map(t -> t.id().toString()) + .collect(Collectors.joining(", "))), + timeoutException + ); + offsetsPerTask + .keySet() + .forEach(task -> corruptedTasks.add(task.id())); + } + } else { + try { + mainConsumer.commitSync(allOffsets); + updateTaskCommitMetadata(allOffsets); + } catch (final CommitFailedException error) { + throw new TaskMigratedException("Consumer committing offsets failed, " + + "indicating the corresponding thread is no longer part of the group", error); + } catch (final TimeoutException timeoutException) { + log.error( + String.format("Committing task(s) %s failed.", + offsetsPerTask + .keySet() + .stream() + .map(t -> t.id().toString()) + .collect(Collectors.joining(", "))), + timeoutException + ); + throw timeoutException; + } catch (final KafkaException error) { + throw new StreamsException("Error encountered committing offsets via consumer", error); + } + } + } + + if (!corruptedTasks.isEmpty()) { + throw new TaskCorruptedException(corruptedTasks); + } + } + } + + private void updateTaskCommitMetadata(final Map allOffsets) { + for (final Task task: tasks.activeTasks()) { + if (task instanceof StreamTask) { + for (final TopicPartition topicPartition : task.inputPartitions()) { + if (allOffsets.containsKey(topicPartition)) { + 
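+                        // record the newly committed offset on the task so it can report it (e.g. via its TaskMetadata)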
((StreamTask) task).updateCommittedOffsets(topicPartition, allOffsets.get(topicPartition).offset()); + } + } + } + } + } + + public void updateTaskEndMetadata(final TopicPartition topicPartition, final Long offset) { + for (final Task task: tasks.activeTasks()) { + if (task instanceof StreamTask) { + if (task.inputPartitions().contains(topicPartition)) { + ((StreamTask) task).updateEndOffsets(topicPartition, offset); + } + } + } + } + + /** + * Handle any added or removed NamedTopologies. Check if any uncreated assigned tasks belong to a newly + * added NamedTopology and create them if so, then freeze any tasks whose named topology no longer exists + */ + void handleTopologyUpdates() { + tasks.maybeCreateTasksFromNewTopologies(); + + try { + final Set activeTasksToRemove = new HashSet<>(); + final Set standbyTasksToRemove = new HashSet<>(); + for (final Task task : tasks.allTasks()) { + if (!topologyMetadata.namedTopologiesView().contains(task.id().topologyName())) { + if (task.isActive()) { + activeTasksToRemove.add(task); + } else { + standbyTasksToRemove.add(task); + } + } + } + + final Set allRemovedTasks = + union(HashSet::new, activeTasksToRemove, standbyTasksToRemove).stream().map(Task::id).collect(Collectors.toSet()); + closeAndCleanUpTasks(activeTasksToRemove, standbyTasksToRemove, true); + allRemovedTasks.forEach(tasks::removeTaskBeforeClosing); + releaseLockedDirectoriesForTasks(allRemovedTasks); + } catch (final Exception e) { + // TODO KAFKA-12648: for now just swallow the exception to avoid interfering with the other topologies + // that are running alongside, but eventually we should be able to rethrow up to the handler to inform + // the user of an error in this named topology without killing the thread and delaying the others + log.error("Caught the following exception while closing tasks from a removed topology:", e); + } + } + + /** + * @throws TaskMigratedException if the task producer got fenced (EOS only) + * @throws StreamsException if any task threw an exception while processing + */ + int process(final int maxNumRecords, final Time time) { + int totalProcessed = 0; + + long now = time.milliseconds(); + for (final Task task : activeTaskIterable()) { + int processed = 0; + final long then = now; + try { + while (processed < maxNumRecords && task.process(now)) { + task.clearTaskTimeout(); + processed++; + } + } catch (final TimeoutException timeoutException) { + task.maybeInitTaskTimeoutOrThrow(now, timeoutException); + log.debug( + String.format( + "Could not complete processing records for %s due to the following exception; will move to next task and retry later", + task.id()), + timeoutException + ); + } catch (final TaskMigratedException e) { + log.info("Failed to process stream task {} since it got migrated to another thread already. 
" + + "Will trigger a new rebalance and close all tasks as zombies together.", task.id()); + throw e; + } catch (final StreamsException e) { + log.error("Failed to process stream task {} due to the following error:", task.id(), e); + e.setTaskId(task.id()); + throw e; + } catch (final RuntimeException e) { + log.error("Failed to process stream task {} due to the following error:", task.id(), e); + throw new StreamsException(e, task.id()); + } finally { + now = time.milliseconds(); + totalProcessed += processed; + task.recordProcessBatchTime(now - then); + } + } + + return totalProcessed; + } + + void recordTaskProcessRatio(final long totalProcessLatencyMs, final long now) { + for (final Task task : activeTaskIterable()) { + task.recordProcessTimeRatioAndBufferSize(totalProcessLatencyMs, now); + } + } + + /** + * @throws TaskMigratedException if the task producer got fenced (EOS only) + */ + int punctuate() { + int punctuated = 0; + + for (final Task task : activeTaskIterable()) { + try { + if (task.maybePunctuateStreamTime()) { + punctuated++; + } + if (task.maybePunctuateSystemTime()) { + punctuated++; + } + } catch (final TaskMigratedException e) { + log.info("Failed to punctuate stream task {} since it got migrated to another thread already. " + + "Will trigger a new rebalance and close all tasks as zombies together.", task.id()); + throw e; + } catch (final StreamsException e) { + log.error("Failed to punctuate stream task {} due to the following error:", task.id(), e); + e.setTaskId(task.id()); + throw e; + } catch (final KafkaException e) { + log.error("Failed to punctuate stream task {} due to the following error:", task.id(), e); + throw new StreamsException(e, task.id()); + } + } + + return punctuated; + } + + void maybePurgeCommittedRecords() { + // we do not check any possible exceptions since none of them are fatal + // that should cause the application to fail, and we will try delete with + // newer offsets anyways. + if (deleteRecordsResult == null || deleteRecordsResult.all().isDone()) { + + if (deleteRecordsResult != null && deleteRecordsResult.all().isCompletedExceptionally()) { + log.debug("Previous delete-records request has failed: {}. Try sending the new request now", + deleteRecordsResult.lowWatermarks()); + } + + final Map recordsToDelete = new HashMap<>(); + for (final Task task : activeTaskIterable()) { + for (final Map.Entry entry : task.purgeableOffsets().entrySet()) { + recordsToDelete.put(entry.getKey(), RecordsToDelete.beforeOffset(entry.getValue())); + } + } + if (!recordsToDelete.isEmpty()) { + deleteRecordsResult = adminClient.deleteRecords(recordsToDelete); + log.trace("Sent delete-records request: {}", recordsToDelete); + } + } + } + + /** + * Produces a string representation containing useful information about the TaskManager. + * This is useful in debugging scenarios. + * + * @return A string representation of the TaskManager instance. + */ + @Override + public String toString() { + return toString(""); + } + + public String toString(final String indent) { + final StringBuilder stringBuilder = new StringBuilder(); + stringBuilder.append("TaskManager\n"); + stringBuilder.append(indent).append("\tMetadataState:\n"); + stringBuilder.append(indent).append("\tTasks:\n"); + for (final Task task : tasks.allTasks()) { + stringBuilder.append(indent) + .append("\t\t") + .append(task.id()) + .append(" ") + .append(task.state()) + .append(" ") + .append(task.getClass().getSimpleName()) + .append('(').append(task.isActive() ? 
"active" : "standby").append(')'); + } + return stringBuilder.toString(); + } + + Map producerMetrics() { + return tasks.producerMetrics(); + } + + Set producerClientIds() { + return tasks.producerClientIds(); + } + + Set lockedTaskDirectories() { + return Collections.unmodifiableSet(lockedTaskDirectories); + } + + private void maybeWrapAndSetFirstException(final AtomicReference firstException, + final RuntimeException exception, + final TaskId taskId) { + if (exception instanceof StreamsException) { + ((StreamsException) exception).setTaskId(taskId); + firstException.compareAndSet(null, exception); + } else { + firstException.compareAndSet(null, new StreamsException(exception, taskId)); + } + } + + public static void executeAndMaybeSwallow(final boolean clean, + final Runnable runnable, + final java.util.function.Consumer actionIfClean, + final java.util.function.Consumer actionIfNotClean) { + try { + runnable.run(); + } catch (final RuntimeException e) { + if (clean) { + actionIfClean.accept(e); + } else { + actionIfNotClean.accept(e); + } + } + } + + public static void executeAndMaybeSwallow(final boolean clean, + final Runnable runnable, + final String name, + final Logger log) { + executeAndMaybeSwallow( + clean, + runnable, + e -> { + throw e; + }, + e -> log.debug("Ignoring error in unclean {}", name)); + } + + boolean needsInitializationOrRestoration() { + return tasks().values().stream().anyMatch(Task::needsInitializationOrRestoration); + } + + // for testing only + void addTask(final Task task) { + tasks.addTask(task); + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskMetadataImpl.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskMetadataImpl.java new file mode 100644 index 0000000..95b14d0 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskMetadataImpl.java @@ -0,0 +1,106 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.processor.internals; + +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.streams.TaskMetadata; +import org.apache.kafka.streams.processor.TaskId; + +import java.util.Collections; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.Set; + +public class TaskMetadataImpl implements TaskMetadata { + + private final TaskId taskId; + + private final Set topicPartitions; + + private final Map committedOffsets; + + private final Map endOffsets; + + private final Optional timeCurrentIdlingStarted; + + public TaskMetadataImpl(final TaskId taskId, + final Set topicPartitions, + final Map committedOffsets, + final Map endOffsets, + final Optional timeCurrentIdlingStarted) { + this.taskId = taskId; + this.topicPartitions = Collections.unmodifiableSet(topicPartitions); + this.committedOffsets = Collections.unmodifiableMap(committedOffsets); + this.endOffsets = Collections.unmodifiableMap(endOffsets); + this.timeCurrentIdlingStarted = timeCurrentIdlingStarted; + } + + @Override + public TaskId taskId() { + return taskId; + } + + @Override + public Set topicPartitions() { + return topicPartitions; + } + + @Override + public Map committedOffsets() { + return committedOffsets; + } + + @Override + public Map endOffsets() { + return endOffsets; + } + + @Override + public Optional timeCurrentIdlingStarted() { + return timeCurrentIdlingStarted; + } + + @Override + public boolean equals(final Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + final TaskMetadataImpl that = (TaskMetadataImpl) o; + return Objects.equals(taskId, that.taskId) && + Objects.equals(topicPartitions, that.topicPartitions); + } + + @Override + public int hashCode() { + return Objects.hash(taskId, topicPartitions); + } + + @Override + public String toString() { + return "TaskMetadata{" + + "taskId=" + taskId + + ", topicPartitions=" + topicPartitions + + ", committedOffsets=" + committedOffsets + + ", endOffsets=" + endOffsets + + ", timeCurrentIdlingStarted=" + timeCurrentIdlingStarted + + '}'; + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/Tasks.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/Tasks.java new file mode 100644 index 0000000..96c0ee1 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/Tasks.java @@ -0,0 +1,328 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.processor.internals; + +import org.apache.kafka.clients.consumer.Consumer; +import org.apache.kafka.common.Metric; +import org.apache.kafka.common.MetricName; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.streams.processor.TaskId; +import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl; + +import java.util.HashSet; +import org.slf4j.Logger; + +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.Set; +import java.util.TreeMap; +import java.util.stream.Collectors; + +class Tasks { + private final Logger log; + private final TopologyMetadata topologyMetadata; + private final StreamsMetricsImpl streamsMetrics; + + private final Map allTasksPerId = new TreeMap<>(); + private final Map readOnlyTasksPerId = Collections.unmodifiableMap(allTasksPerId); + private final Collection readOnlyTasks = Collections.unmodifiableCollection(allTasksPerId.values()); + + // TODO: change type to `StreamTask` + private final Map activeTasksPerId = new TreeMap<>(); + // TODO: change type to `StreamTask` + private final Map activeTasksPerPartition = new HashMap<>(); + // TODO: change type to `StreamTask` + private final Map readOnlyActiveTasksPerId = Collections.unmodifiableMap(activeTasksPerId); + private final Set readOnlyActiveTaskIds = Collections.unmodifiableSet(activeTasksPerId.keySet()); + // TODO: change type to `StreamTask` + private final Collection readOnlyActiveTasks = Collections.unmodifiableCollection(activeTasksPerId.values()); + + // TODO: change type to `StandbyTask` + private final Map standbyTasksPerId = new TreeMap<>(); + // TODO: change type to `StandbyTask` + private final Map readOnlyStandbyTasksPerId = Collections.unmodifiableMap(standbyTasksPerId); + private final Set readOnlyStandbyTaskIds = Collections.unmodifiableSet(standbyTasksPerId.keySet()); + + private final ActiveTaskCreator activeTaskCreator; + private final StandbyTaskCreator standbyTaskCreator; + + private Consumer mainConsumer; + + Tasks(final String logPrefix, + final TopologyMetadata topologyMetadata, + final StreamsMetricsImpl streamsMetrics, + final ActiveTaskCreator activeTaskCreator, + final StandbyTaskCreator standbyTaskCreator) { + + final LogContext logContext = new LogContext(logPrefix); + log = logContext.logger(getClass()); + + this.topologyMetadata = topologyMetadata; + this.streamsMetrics = streamsMetrics; + this.activeTaskCreator = activeTaskCreator; + this.standbyTaskCreator = standbyTaskCreator; + } + + void setMainConsumer(final Consumer mainConsumer) { + this.mainConsumer = mainConsumer; + } + + void handleNewAssignmentAndCreateTasks(final Map> activeTasksToCreate, + final Map> standbyTasksToCreate, + final Set assignedActiveTasks, + final Set assignedStandbyTasks) { + activeTaskCreator.removeRevokedUnknownTasks(assignedActiveTasks); + standbyTaskCreator.removeRevokedUnknownTasks(assignedStandbyTasks); + createTasks(activeTasksToCreate, standbyTasksToCreate); + } + + void maybeCreateTasksFromNewTopologies() { + final Set currentNamedTopologies = topologyMetadata.namedTopologiesView(); + createTasks( + activeTaskCreator.uncreatedTasksForTopologies(currentNamedTopologies), + standbyTaskCreator.uncreatedTasksForTopologies(currentNamedTopologies) + ); + } + + double totalProducerBlockedTime() { + return activeTaskCreator.totalProducerBlockedTime(); + } + + void createTasks(final Map> 
activeTasksToCreate, + final Map> standbyTasksToCreate) { + for (final Map.Entry> taskToBeCreated : activeTasksToCreate.entrySet()) { + final TaskId taskId = taskToBeCreated.getKey(); + + if (activeTasksPerId.containsKey(taskId)) { + throw new IllegalStateException("Attempted to create an active task that we already own: " + taskId); + } + } + + for (final Map.Entry> taskToBeCreated : standbyTasksToCreate.entrySet()) { + final TaskId taskId = taskToBeCreated.getKey(); + + if (standbyTasksPerId.containsKey(taskId)) { + throw new IllegalStateException("Attempted to create a standby task that we already own: " + taskId); + } + } + + // keep this check to simplify testing (ie, no need to mock `activeTaskCreator`) + if (!activeTasksToCreate.isEmpty()) { + // TODO: change type to `StreamTask` + for (final Task activeTask : activeTaskCreator.createTasks(mainConsumer, activeTasksToCreate)) { + activeTasksPerId.put(activeTask.id(), activeTask); + allTasksPerId.put(activeTask.id(), activeTask); + for (final TopicPartition topicPartition : activeTask.inputPartitions()) { + activeTasksPerPartition.put(topicPartition, activeTask); + } + } + } + + // keep this check to simplify testing (ie, no need to mock `standbyTaskCreator`) + if (!standbyTasksToCreate.isEmpty()) { + // TODO: change type to `StandbyTask` + for (final Task standbyTask : standbyTaskCreator.createTasks(standbyTasksToCreate)) { + standbyTasksPerId.put(standbyTask.id(), standbyTask); + allTasksPerId.put(standbyTask.id(), standbyTask); + } + } + } + + void convertActiveToStandby(final StreamTask activeTask, + final Set partitions, + final Map taskCloseExceptions) { + if (activeTasksPerId.remove(activeTask.id()) == null) { + throw new IllegalStateException("Attempted to convert unknown active task to standby task: " + activeTask.id()); + } + final Set toBeRemoved = activeTasksPerPartition.entrySet().stream() + .filter(e -> e.getValue().id().equals(activeTask.id())) + .map(Map.Entry::getKey) + .collect(Collectors.toSet()); + toBeRemoved.forEach(activeTasksPerPartition::remove); + + cleanUpTaskProducerAndRemoveTask(activeTask.id(), taskCloseExceptions); + + final StandbyTask standbyTask = standbyTaskCreator.createStandbyTaskFromActive(activeTask, partitions); + standbyTasksPerId.put(standbyTask.id(), standbyTask); + allTasksPerId.put(standbyTask.id(), standbyTask); + } + + void convertStandbyToActive(final StandbyTask standbyTask, final Set partitions) { + if (standbyTasksPerId.remove(standbyTask.id()) == null) { + throw new IllegalStateException("Attempted to convert unknown standby task to stream task: " + standbyTask.id()); + } + + final StreamTask activeTask = activeTaskCreator.createActiveTaskFromStandby(standbyTask, partitions, mainConsumer); + activeTasksPerId.put(activeTask.id(), activeTask); + for (final TopicPartition topicPartition : activeTask.inputPartitions()) { + activeTasksPerPartition.put(topicPartition, activeTask); + } + allTasksPerId.put(activeTask.id(), activeTask); + } + + void updateInputPartitionsAndResume(final Task task, final Set topicPartitions) { + final boolean requiresUpdate = !task.inputPartitions().equals(topicPartitions); + if (requiresUpdate) { + log.debug("Update task {} inputPartitions: current {}, new {}", task, task.inputPartitions(), topicPartitions); + for (final TopicPartition inputPartition : task.inputPartitions()) { + activeTasksPerPartition.remove(inputPartition); + } + if (task.isActive()) { + for (final TopicPartition topicPartition : topicPartitions) { + 
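+                    // re-register this active task under each of its new input partitions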
activeTasksPerPartition.put(topicPartition, task); + } + } + task.updateInputPartitions(topicPartitions, topologyMetadata.nodeToSourceTopics(task.id())); + } + task.resume(); + } + + void cleanUpTaskProducerAndRemoveTask(final TaskId taskId, + final Map taskCloseExceptions) { + try { + activeTaskCreator.closeAndRemoveTaskProducerIfNeeded(taskId); + } catch (final RuntimeException e) { + final String uncleanMessage = String.format("Failed to close task %s cleanly. Attempting to close remaining tasks before re-throwing:", taskId); + log.error(uncleanMessage, e); + taskCloseExceptions.putIfAbsent(taskId, e); + } + removeTaskBeforeClosing(taskId); + } + + void reInitializeThreadProducer() { + activeTaskCreator.reInitializeThreadProducer(); + } + + void closeThreadProducerIfNeeded() { + activeTaskCreator.closeThreadProducerIfNeeded(); + } + + // TODO: change type to `StreamTask` + void closeAndRemoveTaskProducerIfNeeded(final Task activeTask) { + activeTaskCreator.closeAndRemoveTaskProducerIfNeeded(activeTask.id()); + } + + void removeTaskBeforeClosing(final TaskId taskId) { + activeTasksPerId.remove(taskId); + final Set toBeRemoved = activeTasksPerPartition.entrySet().stream() + .filter(e -> e.getValue().id().equals(taskId)) + .map(Map.Entry::getKey) + .collect(Collectors.toSet()); + toBeRemoved.forEach(activeTasksPerPartition::remove); + standbyTasksPerId.remove(taskId); + allTasksPerId.remove(taskId); + } + + void clear() { + activeTasksPerId.clear(); + activeTasksPerPartition.clear(); + standbyTasksPerId.clear(); + allTasksPerId.clear(); + } + + // TODO: change return type to `StreamTask` + Task activeTasksForInputPartition(final TopicPartition partition) { + return activeTasksPerPartition.get(partition); + } + + // TODO: change return type to `StandbyTask` + Task standbyTask(final TaskId taskId) { + if (!standbyTasksPerId.containsKey(taskId)) { + throw new IllegalStateException("Standby task unknown: " + taskId); + } + return standbyTasksPerId.get(taskId); + } + + Task task(final TaskId taskId) { + if (!allTasksPerId.containsKey(taskId)) { + throw new IllegalStateException("Task unknown: " + taskId); + } + return allTasksPerId.get(taskId); + } + + Collection tasks(final Collection taskIds) { + final Set tasks = new HashSet<>(); + for (final TaskId taskId : taskIds) { + tasks.add(task(taskId)); + } + return tasks; + } + + // TODO: change return type to `StreamTask` + Collection activeTasks() { + return readOnlyActiveTasks; + } + + Collection allTasks() { + return readOnlyTasks; + } + + Set activeTaskIds() { + return readOnlyActiveTaskIds; + } + + Set standbyTaskIds() { + return readOnlyStandbyTaskIds; + } + + // TODO: change return type to `StreamTask` + Map activeTaskMap() { + return readOnlyActiveTasksPerId; + } + + // TODO: change return type to `StandbyTask` + Map standbyTaskMap() { + return readOnlyStandbyTasksPerId; + } + + Map tasksPerId() { + return readOnlyTasksPerId; + } + + boolean owned(final TaskId taskId) { + return allTasksPerId.containsKey(taskId); + } + + StreamsProducer streamsProducerForTask(final TaskId taskId) { + return activeTaskCreator.streamsProducerForTask(taskId); + } + + StreamsProducer threadProducer() { + return activeTaskCreator.threadProducer(); + } + + Map producerMetrics() { + return activeTaskCreator.producerMetrics(); + } + + Set producerClientIds() { + return activeTaskCreator.producerClientIds(); + } + + // for testing only + void addTask(final Task task) { + if (task.isActive()) { + activeTasksPerId.put(task.id(), task); + } else { + 
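+            // standby tasks are only tracked in the standby and all-tasks maps, never per partition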
standbyTasksPerId.put(task.id(), task); + } + allTasksPerId.put(task.id(), task); + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/ThreadMetadataImpl.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ThreadMetadataImpl.java new file mode 100644 index 0000000..7a0188c --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ThreadMetadataImpl.java @@ -0,0 +1,147 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.processor.internals; + +import org.apache.kafka.streams.KafkaStreams; +import org.apache.kafka.streams.TaskMetadata; +import org.apache.kafka.streams.ThreadMetadata; + +import java.util.Collections; +import java.util.Objects; +import java.util.Set; + +/** + * Represents the state of a single thread running within a {@link KafkaStreams} application. + */ +public class ThreadMetadataImpl implements ThreadMetadata { + + private final String threadName; + + private final String threadState; + + private final Set activeTasks; + + private final Set standbyTasks; + + private final String mainConsumerClientId; + + private final String restoreConsumerClientId; + + private final Set producerClientIds; + + // the admin client should be shared among all threads, so the client id should be the same; + // we keep it at the thread-level for user's convenience and possible extensions in the future + private final String adminClientId; + + public ThreadMetadataImpl(final String threadName, + final String threadState, + final String mainConsumerClientId, + final String restoreConsumerClientId, + final Set producerClientIds, + final String adminClientId, + final Set activeTasks, + final Set standbyTasks) { + this.mainConsumerClientId = mainConsumerClientId; + this.restoreConsumerClientId = restoreConsumerClientId; + this.producerClientIds = Collections.unmodifiableSet(producerClientIds); + this.adminClientId = adminClientId; + this.threadName = threadName; + this.threadState = threadState; + this.activeTasks = Collections.unmodifiableSet(activeTasks); + this.standbyTasks = Collections.unmodifiableSet(standbyTasks); + } + + + public String threadState() { + return threadState; + } + + public String threadName() { + return threadName; + } + + + public Set activeTasks() { + return activeTasks; + } + + public Set standbyTasks() { + return standbyTasks; + } + + public String consumerClientId() { + return mainConsumerClientId; + } + + public String restoreConsumerClientId() { + return restoreConsumerClientId; + } + + public Set producerClientIds() { + return producerClientIds; + } + + public String adminClientId() { + return adminClientId; + } + + @Override + public boolean equals(final Object o) { + if (this == o) { + return true; + 
} + if (o == null || getClass() != o.getClass()) { + return false; + } + final ThreadMetadataImpl that = (ThreadMetadataImpl) o; + return Objects.equals(threadName, that.threadName) && + Objects.equals(threadState, that.threadState) && + Objects.equals(activeTasks, that.activeTasks) && + Objects.equals(standbyTasks, that.standbyTasks) && + mainConsumerClientId.equals(that.mainConsumerClientId) && + restoreConsumerClientId.equals(that.restoreConsumerClientId) && + Objects.equals(producerClientIds, that.producerClientIds) && + adminClientId.equals(that.adminClientId); + } + + @Override + public int hashCode() { + return Objects.hash( + threadName, + threadState, + activeTasks, + standbyTasks, + mainConsumerClientId, + restoreConsumerClientId, + producerClientIds, + adminClientId); + } + + @Override + public String toString() { + return "ThreadMetadata{" + + "threadName=" + threadName + + ", threadState=" + threadState + + ", activeTasks=" + activeTasks + + ", standbyTasks=" + standbyTasks + + ", consumerClientId=" + mainConsumerClientId + + ", restoreConsumerClientId=" + restoreConsumerClientId + + ", producerClientIds=" + producerClientIds + + ", adminClientId=" + adminClientId + + '}'; + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/ThreadStateTransitionValidator.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ThreadStateTransitionValidator.java new file mode 100644 index 0000000..4197c71 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ThreadStateTransitionValidator.java @@ -0,0 +1,24 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.processor.internals; + +/** + * Basic interface for keeping track of the state of a thread. + */ +public interface ThreadStateTransitionValidator { + boolean isValidTransition(final ThreadStateTransitionValidator newState); +} \ No newline at end of file diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/ToInternal.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ToInternal.java new file mode 100644 index 0000000..8865846 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/ToInternal.java @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.processor.internals; + +import org.apache.kafka.streams.processor.To; + +public class ToInternal extends To { + public ToInternal() { + super(To.all()); + } + + public ToInternal(final To to) { + super(to); + } + + public void update(final To to) { + super.update(to); + } + + public boolean hasTimestamp() { + return timestamp != -1; + } + + public long timestamp() { + return timestamp; + } + + public String child() { + return childName; + } +} \ No newline at end of file diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/TopologyMetadata.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/TopologyMetadata.java new file mode 100644 index 0000000..129ca09 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/TopologyMetadata.java @@ -0,0 +1,445 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.processor.internals; + +import org.apache.kafka.clients.consumer.OffsetResetStrategy; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.errors.TopologyException; +import org.apache.kafka.streams.processor.StateStore; +import org.apache.kafka.streams.processor.TaskId; +import org.apache.kafka.streams.processor.internals.InternalTopologyBuilder.TopicsInfo; +import org.apache.kafka.streams.processor.internals.StreamThread.State; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Set; +import java.util.concurrent.ConcurrentNavigableMap; +import java.util.concurrent.ConcurrentSkipListMap; +import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.locks.Condition; +import java.util.concurrent.locks.ReentrantLock; +import java.util.function.Consumer; +import java.util.function.Function; +import java.util.function.Supplier; +import java.util.regex.Pattern; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import static java.util.Collections.emptySet; + +public class TopologyMetadata { + private final Logger log = LoggerFactory.getLogger(TopologyMetadata.class); + + // the "__" (double underscore) string is not allowed for topology names, so it's safe to use to indicate + // that it's not a named topology + private static final String UNNAMED_TOPOLOGY = "__UNNAMED_TOPOLOGY__"; + private static final Pattern EMPTY_ZERO_LENGTH_PATTERN = Pattern.compile(""); + + private final StreamsConfig config; + private final TopologyVersion version; + + private final ConcurrentNavigableMap builders; // Keep sorted by topology name for readability + + private ProcessorTopology globalTopology; + private final Map globalStateStores = new HashMap<>(); + private final Set allInputTopics = new HashSet<>(); + + public static class TopologyVersion { + public AtomicLong topologyVersion = new AtomicLong(0L); // the local topology version + public ReentrantLock topologyLock = new ReentrantLock(); + public Condition topologyCV = topologyLock.newCondition(); + } + + public TopologyMetadata(final InternalTopologyBuilder builder, + final StreamsConfig config) { + version = new TopologyVersion(); + this.config = config; + builders = new ConcurrentSkipListMap<>(); + if (builder.hasNamedTopology()) { + builders.put(builder.topologyName(), builder); + } else { + builders.put(UNNAMED_TOPOLOGY, builder); + } + } + + public TopologyMetadata(final ConcurrentNavigableMap builders, + final StreamsConfig config) { + version = new TopologyVersion(); + this.config = config; + + this.builders = builders; + if (builders.isEmpty()) { + log.debug("Starting up empty KafkaStreams app with no topology"); + } + } + + public long topologyVersion() { + return version.topologyVersion.get(); + } + + private void lock() { + version.topologyLock.lock(); + } + + private void unlock() { + version.topologyLock.unlock(); + } + + public void wakeupThreads() { + try { + lock(); + version.topologyCV.signalAll(); + } finally { + unlock(); + } + } + + public void maybeWaitForNonEmptyTopology(final Supplier threadState) { + if (isEmpty() && threadState.get().isAlive()) { + try { + lock(); + while (isEmpty() && threadState.get().isAlive()) { + try { + log.debug("Detected that the topology is currently empty, waiting for something to 
process"); + version.topologyCV.await(); + } catch (final InterruptedException e) { + log.debug("StreamThread was interrupted while waiting on empty topology", e); + } + } + } finally { + unlock(); + } + } + } + + public void registerAndBuildNewTopology(final InternalTopologyBuilder newTopologyBuilder) { + try { + lock(); + version.topologyVersion.incrementAndGet(); + log.info("Adding NamedTopology {}, latest topology version is {}", newTopologyBuilder.topologyName(), version.topologyVersion.get()); + builders.put(newTopologyBuilder.topologyName(), newTopologyBuilder); + buildAndVerifyTopology(newTopologyBuilder); + version.topologyCV.signalAll(); + } finally { + unlock(); + } + } + + public void unregisterTopology(final String topologyName) { + try { + lock(); + version.topologyVersion.incrementAndGet(); + log.info("Removing NamedTopology {}, latest topology version is {}", topologyName, version.topologyVersion.get()); + final InternalTopologyBuilder removedBuilder = builders.remove(topologyName); + removedBuilder.fullSourceTopicNames().forEach(allInputTopics::remove); + removedBuilder.allSourcePatternStrings().forEach(allInputTopics::remove); + version.topologyCV.signalAll(); + } finally { + unlock(); + } + } + + public void buildAndRewriteTopology() { + applyToEachBuilder(this::buildAndVerifyTopology); + } + + private void buildAndVerifyTopology(final InternalTopologyBuilder builder) { + builder.rewriteTopology(config); + builder.buildTopology(); + + // As we go, check each topology for overlap in the set of input topics/patterns + final int numInputTopics = allInputTopics.size(); + final List inputTopics = builder.fullSourceTopicNames(); + final Collection inputPatterns = builder.allSourcePatternStrings(); + + final int numNewInputTopics = inputTopics.size() + inputPatterns.size(); + allInputTopics.addAll(inputTopics); + allInputTopics.addAll(inputPatterns); + if (allInputTopics.size() != numInputTopics + numNewInputTopics) { + inputTopics.retainAll(allInputTopics); + inputPatterns.retainAll(allInputTopics); + log.error("Tried to add the NamedTopology {} but it had overlap with other input topics {} or patterns {}", + builder.topologyName(), inputTopics, inputPatterns); + throw new TopologyException("Named Topologies may not subscribe to the same input topics or patterns"); + } + + final ProcessorTopology globalTopology = builder.buildGlobalStateTopology(); + if (globalTopology != null) { + if (builder.topologyName() != null) { + throw new IllegalStateException("Global state stores are not supported with Named Topologies"); + } else if (this.globalTopology == null) { + this.globalTopology = globalTopology; + } else { + throw new IllegalStateException("Topology builder had global state, but global topology has already been set"); + } + } + globalStateStores.putAll(builder.globalStateStores()); + } + + public int getNumStreamThreads(final StreamsConfig config) { + final int configuredNumStreamThreads = config.getInt(StreamsConfig.NUM_STREAM_THREADS_CONFIG); + + // If there are named topologies but some are empty, this indicates a bug in user code + if (hasNamedTopologies()) { + if (hasNoLocalTopology()) { + log.error("Detected a named topology with no input topics, a named topology may not be empty."); + throw new TopologyException("Topology has no stream threads and no global threads, " + + "must subscribe to at least one source topic or pattern."); + } + } else { + // If both the global and non-global topologies are empty, this indicates a bug in user code + if (hasNoLocalTopology() && 
!hasGlobalTopology()) { + log.error("Topology with no input topics will create no stream threads and no global thread."); + throw new TopologyException("Topology has no stream threads and no global threads, " + + "must subscribe to at least one source topic or global table."); + } + } + + // Lastly we check for an empty non-global topology and override the threads to zero if set otherwise + if (configuredNumStreamThreads != 0 && hasNoLocalTopology()) { + log.info("Overriding number of StreamThreads to zero for global-only topology"); + return 0; + } + + return configuredNumStreamThreads; + } + + /** + * @return true iff the app is using named topologies, or was started up with no topology at all + */ + public boolean hasNamedTopologies() { + return !builders.containsKey(UNNAMED_TOPOLOGY); + } + + Set namedTopologiesView() { + return hasNamedTopologies() ? Collections.unmodifiableSet(builders.keySet()) : emptySet(); + } + + /** + * @return true iff any of the topologies have a global topology + */ + public boolean hasGlobalTopology() { + return evaluateConditionIsTrueForAnyBuilders(InternalTopologyBuilder::hasGlobalStores); + } + + /** + * @return true iff any of the topologies have no local (aka non-global) topology + */ + public boolean hasNoLocalTopology() { + return evaluateConditionIsTrueForAnyBuilders(InternalTopologyBuilder::hasNoLocalTopology); + } + + public boolean hasPersistentStores() { + // If the app is using named topologies, there may not be any persistent state when it first starts up + // but a new NamedTopology may introduce it later, so we must return true + if (hasNamedTopologies()) { + return true; + } + return evaluateConditionIsTrueForAnyBuilders(InternalTopologyBuilder::hasPersistentStores); + } + + public boolean hasStore(final String name) { + return evaluateConditionIsTrueForAnyBuilders(b -> b.hasStore(name)); + } + + public boolean hasOffsetResetOverrides() { + // Return true if using named topologies, as there may be named topologies added later which do have overrides + return hasNamedTopologies() || evaluateConditionIsTrueForAnyBuilders(InternalTopologyBuilder::hasOffsetResetOverrides); + } + + public OffsetResetStrategy offsetResetStrategy(final String topic) { + for (final InternalTopologyBuilder builder : builders.values()) { + final OffsetResetStrategy resetStrategy = builder.offsetResetStrategy(topic); + if (resetStrategy != null) { + return resetStrategy; + } + } + return null; + } + + Collection sourceTopicCollection() { + final List sourceTopics = new ArrayList<>(); + applyToEachBuilder(b -> sourceTopics.addAll(b.sourceTopicCollection())); + return sourceTopics; + } + + Pattern sourceTopicPattern() { + final StringBuilder patternBuilder = new StringBuilder(); + + applyToEachBuilder(b -> { + final String patternString = b.sourceTopicsPatternString(); + if (patternString.length() > 0) { + patternBuilder.append(patternString).append("|"); + } + }); + + if (patternBuilder.length() > 0) { + patternBuilder.setLength(patternBuilder.length() - 1); + return Pattern.compile(patternBuilder.toString()); + } else { + return EMPTY_ZERO_LENGTH_PATTERN; + } + } + + public boolean usesPatternSubscription() { + return evaluateConditionIsTrueForAnyBuilders(InternalTopologyBuilder::usesPatternSubscription); + } + + // Can be empty if app is started up with no Named Topologies, in order to add them on later + public boolean isEmpty() { + return builders.isEmpty(); + } + + public String topologyDescriptionString() { + if (isEmpty()) { + return ""; + } + final StringBuilder 
sb = new StringBuilder(); + + applyToEachBuilder(b -> sb.append(b.describe().toString())); + + return sb.toString(); + } + + /** + * @return the subtopology built for this task, or null if the corresponding NamedTopology does not (yet) exist + */ + public ProcessorTopology buildSubtopology(final TaskId task) { + final InternalTopologyBuilder builder = lookupBuilderForTask(task); + return builder == null ? null : builder.buildSubtopology(task.subtopology()); + } + + public ProcessorTopology globalTaskTopology() { + if (hasNamedTopologies()) { + throw new IllegalStateException("Global state stores are not supported with Named Topologies"); + } + return globalTopology; + } + + public Map globalStateStores() { + return globalStateStores; + } + + public Map> stateStoreNameToSourceTopics() { + final Map> stateStoreNameToSourceTopics = new HashMap<>(); + applyToEachBuilder(b -> stateStoreNameToSourceTopics.putAll(b.stateStoreNameToSourceTopics())); + return stateStoreNameToSourceTopics; + } + + public String getStoreForChangelogTopic(final String topicName) { + for (final InternalTopologyBuilder builder : builders.values()) { + final String store = builder.getStoreForChangelogTopic(topicName); + if (store != null) { + return store; + } + } + log.warn("Unable to locate any store for topic {}", topicName); + return ""; + } + + public Collection sourceTopicsForStore(final String storeName) { + final List sourceTopics = new ArrayList<>(); + applyToEachBuilder(b -> sourceTopics.addAll(b.sourceTopicsForStore(storeName))); + return sourceTopics; + } + + public Map topicGroups() { + final Map topicGroups = new HashMap<>(); + applyToEachBuilder(b -> topicGroups.putAll(b.topicGroups())); + return topicGroups; + } + + public Map> nodeToSourceTopics(final TaskId task) { + return lookupBuilderForTask(task).nodeToSourceTopics(); + } + + void addSubscribedTopicsFromMetadata(final Set topics, final String logPrefix) { + applyToEachBuilder(b -> b.addSubscribedTopicsFromMetadata(topics, logPrefix)); + } + + void addSubscribedTopicsFromAssignment(final List partitions, final String logPrefix) { + applyToEachBuilder(b -> b.addSubscribedTopicsFromAssignment(partitions, logPrefix)); + } + + public Collection> copartitionGroups() { + final List> copartitionGroups = new ArrayList<>(); + applyToEachBuilder(b -> copartitionGroups.addAll(b.copartitionGroups())); + return copartitionGroups; + } + + private InternalTopologyBuilder lookupBuilderForTask(final TaskId task) { + return task.topologyName() == null ? 
builders.get(UNNAMED_TOPOLOGY) : builders.get(task.topologyName()); + } + + /** + * @return the InternalTopologyBuilder for a NamedTopology, or null if no such NamedTopology exists + */ + public InternalTopologyBuilder lookupBuilderForNamedTopology(final String name) { + return builders.get(name); + } + + private boolean evaluateConditionIsTrueForAnyBuilders(final Function condition) { + for (final InternalTopologyBuilder builder : builders.values()) { + if (condition.apply(builder)) { + return true; + } + } + return false; + } + + private void applyToEachBuilder(final Consumer function) { + for (final InternalTopologyBuilder builder : builders.values()) { + function.accept(builder); + } + } + + public static class Subtopology { + final int nodeGroupId; + final String namedTopology; + + public Subtopology(final int nodeGroupId, final String namedTopology) { + this.nodeGroupId = nodeGroupId; + this.namedTopology = namedTopology; + } + + @Override + public boolean equals(final Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + final Subtopology that = (Subtopology) o; + return nodeGroupId == that.nodeGroupId && + Objects.equals(namedTopology, that.namedTopology); + } + + @Override + public int hashCode() { + return Objects.hash(nodeGroupId, namedTopology); + } + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/UnwindowedChangelogTopicConfig.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/UnwindowedChangelogTopicConfig.java new file mode 100644 index 0000000..b0c2fda --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/UnwindowedChangelogTopicConfig.java @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.processor.internals; + +import org.apache.kafka.common.config.TopicConfig; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; + +/** + * UnwindowedChangelogTopicConfig captures the properties required for configuring + * the un-windowed store changelog topics. 
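+ * These changelog topics default to {@code cleanup.policy=compact}.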
+ */ +public class UnwindowedChangelogTopicConfig extends InternalTopicConfig { + private static final Map UNWINDOWED_STORE_CHANGELOG_TOPIC_DEFAULT_OVERRIDES; + static { + final Map tempTopicDefaultOverrides = new HashMap<>(INTERNAL_TOPIC_DEFAULT_OVERRIDES); + tempTopicDefaultOverrides.put(TopicConfig.CLEANUP_POLICY_CONFIG, TopicConfig.CLEANUP_POLICY_COMPACT); + UNWINDOWED_STORE_CHANGELOG_TOPIC_DEFAULT_OVERRIDES = Collections.unmodifiableMap(tempTopicDefaultOverrides); + } + + UnwindowedChangelogTopicConfig(final String name, final Map topicConfigs) { + super(name, topicConfigs); + } + + /** + * Get the configured properties for this topic. If retentionMs is set then + * we add additionalRetentionMs to work out the desired retention when cleanup.policy=compact,delete + * + * @param additionalRetentionMs - added to retention to allow for clock drift etc + * @return Properties to be used when creating the topic + */ + @Override + public Map getProperties(final Map defaultProperties, final long additionalRetentionMs) { + // internal topic config overridden rule: library overrides < global config overrides < per-topic config overrides + final Map topicConfig = new HashMap<>(UNWINDOWED_STORE_CHANGELOG_TOPIC_DEFAULT_OVERRIDES); + + topicConfig.putAll(defaultProperties); + + topicConfig.putAll(topicConfigs); + + return topicConfig; + } + + @Override + public boolean equals(final Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + final UnwindowedChangelogTopicConfig that = (UnwindowedChangelogTopicConfig) o; + return Objects.equals(name, that.name) && + Objects.equals(topicConfigs, that.topicConfigs) && + Objects.equals(enforceNumberOfPartitions, that.enforceNumberOfPartitions); + } + + @Override + public int hashCode() { + return Objects.hash(name, topicConfigs, enforceNumberOfPartitions); + } + + @Override + public String toString() { + return "UnwindowedChangelogTopicConfig(" + + "name=" + name + + ", topicConfigs=" + topicConfigs + + ", enforceNumberOfPartitions=" + enforceNumberOfPartitions + + ")"; + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/WindowedChangelogTopicConfig.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/WindowedChangelogTopicConfig.java new file mode 100644 index 0000000..6e1dc86 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/WindowedChangelogTopicConfig.java @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.processor.internals; + +import org.apache.kafka.common.config.TopicConfig; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; + +/** + * WindowedChangelogTopicConfig captures the properties required for configuring + * the windowed store changelog topics. + */ +public class WindowedChangelogTopicConfig extends InternalTopicConfig { + private static final Map WINDOWED_STORE_CHANGELOG_TOPIC_DEFAULT_OVERRIDES; + static { + final Map tempTopicDefaultOverrides = new HashMap<>(INTERNAL_TOPIC_DEFAULT_OVERRIDES); + tempTopicDefaultOverrides.put(TopicConfig.CLEANUP_POLICY_CONFIG, TopicConfig.CLEANUP_POLICY_COMPACT + "," + TopicConfig.CLEANUP_POLICY_DELETE); + WINDOWED_STORE_CHANGELOG_TOPIC_DEFAULT_OVERRIDES = Collections.unmodifiableMap(tempTopicDefaultOverrides); + } + + private Long retentionMs; + + WindowedChangelogTopicConfig(final String name, final Map topicConfigs) { + super(name, topicConfigs); + } + + /** + * Get the configured properties for this topic. If retentionMs is set then + * we add additionalRetentionMs to work out the desired retention when cleanup.policy=compact,delete + * + * @param additionalRetentionMs - added to retention to allow for clock drift etc + * @return Properties to be used when creating the topic + */ + @Override + public Map getProperties(final Map defaultProperties, final long additionalRetentionMs) { + // internal topic config overridden rule: library overrides < global config overrides < per-topic config overrides + final Map topicConfig = new HashMap<>(WINDOWED_STORE_CHANGELOG_TOPIC_DEFAULT_OVERRIDES); + + topicConfig.putAll(defaultProperties); + + topicConfig.putAll(topicConfigs); + + if (retentionMs != null) { + long retentionValue; + try { + retentionValue = Math.addExact(retentionMs, additionalRetentionMs); + } catch (final ArithmeticException swallow) { + retentionValue = Long.MAX_VALUE; + } + topicConfig.put(TopicConfig.RETENTION_MS_CONFIG, String.valueOf(retentionValue)); + } + + return topicConfig; + } + + void setRetentionMs(final long retentionMs) { + if (!topicConfigs.containsKey(TopicConfig.RETENTION_MS_CONFIG)) { + this.retentionMs = retentionMs; + } + } + + @Override + public boolean equals(final Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + final WindowedChangelogTopicConfig that = (WindowedChangelogTopicConfig) o; + return Objects.equals(name, that.name) && + Objects.equals(topicConfigs, that.topicConfigs) && + Objects.equals(retentionMs, that.retentionMs) && + Objects.equals(enforceNumberOfPartitions, that.enforceNumberOfPartitions); + } + + @Override + public int hashCode() { + return Objects.hash(name, topicConfigs, retentionMs, enforceNumberOfPartitions); + } + + @Override + public String toString() { + return "WindowedChangelogTopicConfig(" + + "name=" + name + + ", topicConfigs=" + topicConfigs + + ", retentionMs=" + retentionMs + + ", enforceNumberOfPartitions=" + enforceNumberOfPartitions + + ")"; + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/AssignmentInfo.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/AssignmentInfo.java new file mode 100644 index 0000000..ecf0b25 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/AssignmentInfo.java @@ -0,0 +1,499 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * 
contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.processor.internals.assignment; + +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.utils.ByteBufferInputStream; +import org.apache.kafka.streams.errors.TaskAssignmentException; +import org.apache.kafka.streams.processor.TaskId; +import org.apache.kafka.streams.state.HostInfo; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.ByteArrayOutputStream; +import java.io.DataInputStream; +import java.io.DataOutputStream; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import static org.apache.kafka.streams.processor.internals.assignment.ConsumerProtocolUtils.readTaskIdFrom; +import static org.apache.kafka.streams.processor.internals.assignment.ConsumerProtocolUtils.writeTaskIdTo; +import static org.apache.kafka.streams.processor.internals.assignment.StreamsAssignmentProtocolVersions.LATEST_SUPPORTED_VERSION; +import static org.apache.kafka.streams.processor.internals.assignment.StreamsAssignmentProtocolVersions.UNKNOWN; + +public class AssignmentInfo { + private static final Logger log = LoggerFactory.getLogger(AssignmentInfo.class); + + private final int usedVersion; + private final int commonlySupportedVersion; + private List activeTasks; + private Map> standbyTasks; + private Map> partitionsByHost; + private Map> standbyPartitionsByHost; + private int errCode; + private Long nextRebalanceMs = Long.MAX_VALUE; + + // used for decoding and "future consumer" assignments during version probing + public AssignmentInfo(final int version, + final int commonlySupportedVersion) { + this(version, + commonlySupportedVersion, + Collections.emptyList(), + Collections.emptyMap(), + Collections.emptyMap(), + Collections.emptyMap(), + 0); + } + + public AssignmentInfo(final int version, + final List activeTasks, + final Map> standbyTasks, + final Map> partitionsByHost, + final Map> standbyPartitionsByHost, + final int errCode) { + this(version, LATEST_SUPPORTED_VERSION, activeTasks, standbyTasks, partitionsByHost, standbyPartitionsByHost, errCode); + } + + public AssignmentInfo(final int version, + final int commonlySupportedVersion, + final List activeTasks, + final Map> standbyTasks, + final Map> partitionsByHost, + final Map> standbyPartitionsByHost, + final int errCode) { + this.usedVersion = version; + this.commonlySupportedVersion = commonlySupportedVersion; + this.activeTasks = activeTasks; + this.standbyTasks = standbyTasks; + this.partitionsByHost = partitionsByHost; + this.standbyPartitionsByHost = 
standbyPartitionsByHost; + this.errCode = errCode; + + if (version < 1 || version > LATEST_SUPPORTED_VERSION) { + throw new IllegalArgumentException("version must be between 1 and " + LATEST_SUPPORTED_VERSION + + "; was: " + version); + } + } + + public void setNextRebalanceTime(final long nextRebalanceTimeMs) { + this.nextRebalanceMs = nextRebalanceTimeMs; + } + + public int version() { + return usedVersion; + } + + public int errCode() { + return errCode; + } + + public int commonlySupportedVersion() { + return commonlySupportedVersion; + } + + public List activeTasks() { + return activeTasks; + } + + public Map> standbyTasks() { + return standbyTasks; + } + + public Map> partitionsByHost() { + return partitionsByHost; + } + + public Map> standbyPartitionByHost() { + return standbyPartitionsByHost; + } + + public long nextRebalanceMs() { + return nextRebalanceMs; + } + + /** + * @throws TaskAssignmentException if method fails to encode the data, e.g., if there is an + * IO exception during encoding + */ + public ByteBuffer encode() { + final ByteArrayOutputStream baos = new ByteArrayOutputStream(); + + try (final DataOutputStream out = new DataOutputStream(baos)) { + switch (usedVersion) { + case 1: + out.writeInt(usedVersion); // version + encodeActiveAndStandbyTaskAssignment(out); + break; + case 2: + out.writeInt(usedVersion); // version + encodeActiveAndStandbyTaskAssignment(out); + encodePartitionsByHost(out); + break; + case 3: + out.writeInt(usedVersion); + out.writeInt(commonlySupportedVersion); + encodeActiveAndStandbyTaskAssignment(out); + encodePartitionsByHost(out); + break; + case 4: + out.writeInt(usedVersion); + out.writeInt(commonlySupportedVersion); + encodeActiveAndStandbyTaskAssignment(out); + encodePartitionsByHost(out); + out.writeInt(errCode); + break; + case 5: + out.writeInt(usedVersion); + out.writeInt(commonlySupportedVersion); + encodeActiveAndStandbyTaskAssignment(out); + encodePartitionsByHostAsDictionary(out); + out.writeInt(errCode); + break; + case 6: + out.writeInt(usedVersion); + out.writeInt(commonlySupportedVersion); + encodeActiveAndStandbyTaskAssignment(out); + encodeActiveAndStandbyHostPartitions(out); + out.writeInt(errCode); + break; + case 7: + case 8: + case 9: + case 10: + out.writeInt(usedVersion); + out.writeInt(commonlySupportedVersion); + encodeActiveAndStandbyTaskAssignment(out); + encodeActiveAndStandbyHostPartitions(out); + out.writeInt(errCode); + out.writeLong(nextRebalanceMs); + break; + default: + throw new IllegalStateException("Unknown metadata version: " + usedVersion + + "; latest commonly supported version: " + commonlySupportedVersion); + } + + out.flush(); + out.close(); + + return ByteBuffer.wrap(baos.toByteArray()); + } catch (final IOException ex) { + throw new TaskAssignmentException("Failed to encode AssignmentInfo", ex); + } + } + + private void encodeActiveAndStandbyTaskAssignment(final DataOutputStream out) throws IOException { + // encode active tasks + out.writeInt(activeTasks.size()); + for (final TaskId id : activeTasks) { + writeTaskIdTo(id, out, usedVersion); + } + + // encode standby tasks + out.writeInt(standbyTasks.size()); + for (final Map.Entry> entry : standbyTasks.entrySet()) { + final TaskId id = entry.getKey(); + writeTaskIdTo(id, out, usedVersion); + + final Set partitions = entry.getValue(); + writeTopicPartitions(out, partitions); + } + } + + private void encodePartitionsByHost(final DataOutputStream out) throws IOException { + // encode partitions by host + out.writeInt(partitionsByHost.size()); + 
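+        // A sketch of what the loop below emits for each host entry, mirroring writeHostInfo and
+        // writeTopicPartitions:
+        //   host       : UTF string
+        //   port       : int32
+        //   partitions : int32 count, then one <UTF topic, int32 partition> pair per partition
+        // Versions 5+ replace the per-partition topic name with a dictionary index; see
+        // encodePartitionsByHostAsDictionary.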
for (final Map.Entry> entry : partitionsByHost.entrySet()) { + writeHostInfo(out, entry.getKey()); + writeTopicPartitions(out, entry.getValue()); + } + } + + private void encodeHostPartitionMapUsingDictionary(final DataOutputStream out, + final Map topicNameDict, + final Map> hostPartitionMap) throws IOException { + // encode partitions by host + out.writeInt(hostPartitionMap.size()); + + // Write the topic index, partition + for (final Map.Entry> entry : hostPartitionMap.entrySet()) { + writeHostInfo(out, entry.getKey()); + out.writeInt(entry.getValue().size()); + for (final TopicPartition partition : entry.getValue()) { + out.writeInt(topicNameDict.get(partition.topic())); + out.writeInt(partition.partition()); + } + } + } + + private Map encodeTopicDictionaryAndGet(final DataOutputStream out, + final Set topicPartitions) throws IOException { + // Build a dictionary to encode topicNames + int topicIndex = 0; + final Map topicNameDict = new HashMap<>(); + for (final TopicPartition topicPartition : topicPartitions) { + if (!topicNameDict.containsKey(topicPartition.topic())) { + topicNameDict.put(topicPartition.topic(), topicIndex++); + } + } + + // write the topic name dictionary out + out.writeInt(topicNameDict.size()); + for (final Map.Entry entry : topicNameDict.entrySet()) { + out.writeInt(entry.getValue()); + out.writeUTF(entry.getKey()); + } + + return topicNameDict; + } + + private void encodePartitionsByHostAsDictionary(final DataOutputStream out) throws IOException { + final Set allTopicPartitions = partitionsByHost.values().stream() + .flatMap(Collection::stream).collect(Collectors.toSet()); + final Map topicNameDict = encodeTopicDictionaryAndGet(out, allTopicPartitions); + encodeHostPartitionMapUsingDictionary(out, topicNameDict, partitionsByHost); + } + + private void encodeActiveAndStandbyHostPartitions(final DataOutputStream out) throws IOException { + final Set allTopicPartitions = Stream + .concat(partitionsByHost.values().stream(), standbyPartitionsByHost.values().stream()) + .flatMap(Collection::stream).collect(Collectors.toSet()); + final Map topicNameDict = encodeTopicDictionaryAndGet(out, allTopicPartitions); + encodeHostPartitionMapUsingDictionary(out, topicNameDict, partitionsByHost); + encodeHostPartitionMapUsingDictionary(out, topicNameDict, standbyPartitionsByHost); + } + + private void writeHostInfo(final DataOutputStream out, final HostInfo hostInfo) throws IOException { + out.writeUTF(hostInfo.host()); + out.writeInt(hostInfo.port()); + } + + private void writeTopicPartitions(final DataOutputStream out, + final Set partitions) throws IOException { + out.writeInt(partitions.size()); + for (final TopicPartition partition : partitions) { + out.writeUTF(partition.topic()); + out.writeInt(partition.partition()); + } + } + + /** + * @throws TaskAssignmentException if method fails to decode the data or if the data version is unknown + */ + public static AssignmentInfo decode(final ByteBuffer data) { + // ensure we are at the beginning of the ByteBuffer + data.rewind(); + + try (final DataInputStream in = new DataInputStream(new ByteBufferInputStream(data))) { + final AssignmentInfo assignmentInfo; + + final int usedVersion = in.readInt(); + final int commonlySupportedVersion; + switch (usedVersion) { + case 1: + assignmentInfo = new AssignmentInfo(usedVersion, UNKNOWN); + decodeActiveTasks(assignmentInfo, in); + decodeStandbyTasks(assignmentInfo, in); + assignmentInfo.partitionsByHost = new HashMap<>(); + break; + case 2: + assignmentInfo = new 
AssignmentInfo(usedVersion, UNKNOWN); + decodeActiveTasks(assignmentInfo, in); + decodeStandbyTasks(assignmentInfo, in); + decodePartitionsByHost(assignmentInfo, in); + break; + case 3: + commonlySupportedVersion = in.readInt(); + assignmentInfo = new AssignmentInfo(usedVersion, commonlySupportedVersion); + decodeActiveTasks(assignmentInfo, in); + decodeStandbyTasks(assignmentInfo, in); + decodePartitionsByHost(assignmentInfo, in); + break; + case 4: + commonlySupportedVersion = in.readInt(); + assignmentInfo = new AssignmentInfo(usedVersion, commonlySupportedVersion); + decodeActiveTasks(assignmentInfo, in); + decodeStandbyTasks(assignmentInfo, in); + decodePartitionsByHost(assignmentInfo, in); + assignmentInfo.errCode = in.readInt(); + break; + case 5: + commonlySupportedVersion = in.readInt(); + assignmentInfo = new AssignmentInfo(usedVersion, commonlySupportedVersion); + decodeActiveTasks(assignmentInfo, in); + decodeStandbyTasks(assignmentInfo, in); + decodePartitionsByHostUsingDictionary(assignmentInfo, in); + assignmentInfo.errCode = in.readInt(); + break; + case 6: + commonlySupportedVersion = in.readInt(); + assignmentInfo = new AssignmentInfo(usedVersion, commonlySupportedVersion); + decodeActiveTasks(assignmentInfo, in); + decodeStandbyTasks(assignmentInfo, in); + decodeActiveAndStandbyHostPartitions(assignmentInfo, in); + assignmentInfo.errCode = in.readInt(); + break; + case 7: + case 8: + case 9: + case 10: + commonlySupportedVersion = in.readInt(); + assignmentInfo = new AssignmentInfo(usedVersion, commonlySupportedVersion); + decodeActiveTasks(assignmentInfo, in); + decodeStandbyTasks(assignmentInfo, in); + decodeActiveAndStandbyHostPartitions(assignmentInfo, in); + assignmentInfo.errCode = in.readInt(); + assignmentInfo.nextRebalanceMs = in.readLong(); + break; + default: + final TaskAssignmentException fatalException = new TaskAssignmentException("Unable to decode assignment data: " + + "used version: " + usedVersion + "; latest supported version: " + LATEST_SUPPORTED_VERSION); + log.error(fatalException.getMessage(), fatalException); + throw fatalException; + } + + return assignmentInfo; + } catch (final IOException ex) { + throw new TaskAssignmentException("Failed to decode AssignmentInfo", ex); + } + } + + private static void decodeActiveTasks(final AssignmentInfo assignmentInfo, + final DataInputStream in) throws IOException { + final int count = in.readInt(); + assignmentInfo.activeTasks = new ArrayList<>(count); + for (int i = 0; i < count; i++) { + assignmentInfo.activeTasks.add(readTaskIdFrom(in, assignmentInfo.usedVersion)); + } + } + + private static void decodeStandbyTasks(final AssignmentInfo assignmentInfo, + final DataInputStream in) throws IOException { + final int count = in.readInt(); + assignmentInfo.standbyTasks = new HashMap<>(count); + for (int i = 0; i < count; i++) { + final TaskId id = readTaskIdFrom(in, assignmentInfo.usedVersion); + assignmentInfo.standbyTasks.put(id, readTopicPartitions(in)); + } + } + + private static void decodePartitionsByHost(final AssignmentInfo assignmentInfo, + final DataInputStream in) throws IOException { + assignmentInfo.partitionsByHost = new HashMap<>(); + final int numEntries = in.readInt(); + for (int i = 0; i < numEntries; i++) { + final HostInfo hostInfo = new HostInfo(in.readUTF(), in.readInt()); + assignmentInfo.partitionsByHost.put(hostInfo, readTopicPartitions(in)); + } + } + + private static Set readTopicPartitions(final DataInputStream in) throws IOException { + final int numPartitions = in.readInt(); + 
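+        // Mirrors writeTopicPartitions: each of the numPartitions entries that follows is a UTF
+        // topic name and an int32 partition index.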
final Set partitions = new HashSet<>(numPartitions); + for (int j = 0; j < numPartitions; j++) { + partitions.add(new TopicPartition(in.readUTF(), in.readInt())); + } + return partitions; + } + + private static Map decodeTopicIndexAndGet(final DataInputStream in) throws IOException { + final int dictSize = in.readInt(); + final Map topicIndexDict = new HashMap<>(dictSize); + for (int i = 0; i < dictSize; i++) { + topicIndexDict.put(in.readInt(), in.readUTF()); + } + return topicIndexDict; + } + + private static Map> decodeHostPartitionMapUsingDictionary(final DataInputStream in, + final Map topicIndexDict) throws IOException { + final Map> hostPartitionMap = new HashMap<>(); + final int numEntries = in.readInt(); + for (int i = 0; i < numEntries; i++) { + final HostInfo hostInfo = new HostInfo(in.readUTF(), in.readInt()); + hostPartitionMap.put(hostInfo, readTopicPartitions(in, topicIndexDict)); + } + return hostPartitionMap; + } + + private static void decodePartitionsByHostUsingDictionary(final AssignmentInfo assignmentInfo, + final DataInputStream in) throws IOException { + final Map topicIndexDict = decodeTopicIndexAndGet(in); + assignmentInfo.partitionsByHost = decodeHostPartitionMapUsingDictionary(in, topicIndexDict); + } + + private static void decodeActiveAndStandbyHostPartitions(final AssignmentInfo assignmentInfo, + final DataInputStream in) throws IOException { + final Map topicIndexDict = decodeTopicIndexAndGet(in); + assignmentInfo.partitionsByHost = decodeHostPartitionMapUsingDictionary(in, topicIndexDict); + assignmentInfo.standbyPartitionsByHost = decodeHostPartitionMapUsingDictionary(in, topicIndexDict); + } + + private static Set readTopicPartitions(final DataInputStream in, + final Map topicIndexDict) throws IOException { + final int numPartitions = in.readInt(); + final Set partitions = new HashSet<>(numPartitions); + for (int j = 0; j < numPartitions; j++) { + partitions.add(new TopicPartition(topicIndexDict.get(in.readInt()), in.readInt())); + } + return partitions; + } + + @Override + public int hashCode() { + final int hostMapHashCode = partitionsByHost.hashCode() ^ standbyPartitionsByHost.hashCode(); + return usedVersion ^ commonlySupportedVersion ^ activeTasks.hashCode() ^ standbyTasks.hashCode() + ^ hostMapHashCode ^ errCode; + } + + @Override + public boolean equals(final Object o) { + if (o instanceof AssignmentInfo) { + final AssignmentInfo other = (AssignmentInfo) o; + return usedVersion == other.usedVersion && + commonlySupportedVersion == other.commonlySupportedVersion && + errCode == other.errCode && + activeTasks.equals(other.activeTasks) && + standbyTasks.equals(other.standbyTasks) && + partitionsByHost.equals(other.partitionsByHost) && + standbyPartitionsByHost.equals(other.standbyPartitionsByHost); + } else { + return false; + } + } + + @Override + public String toString() { + return "[version=" + usedVersion + + ", supported version=" + commonlySupportedVersion + + ", active tasks=" + activeTasks + + ", standby tasks=" + standbyTasks + + ", partitions by host=" + partitionsByHost + + ", standbyPartitions by host=" + standbyPartitionsByHost + + "]"; + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/AssignorConfiguration.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/AssignorConfiguration.java new file mode 100644 index 0000000..c14ab70 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/AssignorConfiguration.java @@ -0,0 
+1,280 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.processor.internals.assignment; + +import org.apache.kafka.clients.CommonClientConfigs; +import org.apache.kafka.clients.consumer.ConsumerPartitionAssignor.RebalanceProtocol; +import org.apache.kafka.common.KafkaException; +import org.apache.kafka.common.config.ConfigDef; +import org.apache.kafka.common.config.ConfigException; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.StreamsConfig.InternalConfig; +import org.apache.kafka.streams.processor.internals.ClientUtils; +import org.apache.kafka.streams.processor.internals.InternalTopicManager; +import org.slf4j.Logger; + +import java.util.Map; + +import static org.apache.kafka.common.utils.Utils.getHost; +import static org.apache.kafka.common.utils.Utils.getPort; +import static org.apache.kafka.streams.StreamsConfig.InternalConfig.INTERNAL_TASK_ASSIGNOR_CLASS; +import static org.apache.kafka.streams.processor.internals.assignment.StreamsAssignmentProtocolVersions.LATEST_SUPPORTED_VERSION; + +public final class AssignorConfiguration { + private final String taskAssignorClass; + + private final String logPrefix; + private final Logger log; + private final ReferenceContainer referenceContainer; + + private final StreamsConfig streamsConfig; + private final Map internalConfigs; + + public AssignorConfiguration(final Map configs) { + // NOTE: If you add a new config to pass through to here, be sure to test it in a real + // application. Since we filter out some configurations, we may have to explicitly copy + // them over when we construct the Consumer. 
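+        // QuietStreamsConfig is a StreamsConfig variant that suppresses the usual logging of config
+        // values, so wrapping here gives typed access to the settings without dumping the full
+        // configuration a second time.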
+ streamsConfig = new ClientUtils.QuietStreamsConfig(configs); + internalConfigs = configs; + + // Setting the logger with the passed in client thread name + logPrefix = String.format("stream-thread [%s] ", streamsConfig.getString(CommonClientConfigs.CLIENT_ID_CONFIG)); + final LogContext logContext = new LogContext(logPrefix); + log = logContext.logger(getClass()); + + { + final Object o = configs.get(InternalConfig.REFERENCE_CONTAINER_PARTITION_ASSIGNOR); + if (o == null) { + final KafkaException fatalException = new KafkaException("ReferenceContainer is not specified"); + log.error(fatalException.getMessage(), fatalException); + throw fatalException; + } + + if (!(o instanceof ReferenceContainer)) { + final KafkaException fatalException = new KafkaException( + String.format("%s is not an instance of %s", o.getClass().getName(), ReferenceContainer.class.getName()) + ); + log.error(fatalException.getMessage(), fatalException); + throw fatalException; + } + + referenceContainer = (ReferenceContainer) o; + } + + { + final String o = (String) configs.get(INTERNAL_TASK_ASSIGNOR_CLASS); + if (o == null) { + taskAssignorClass = HighAvailabilityTaskAssignor.class.getName(); + } else { + taskAssignorClass = o; + } + } + } + + public ReferenceContainer referenceContainer() { + return referenceContainer; + } + + public RebalanceProtocol rebalanceProtocol() { + final String upgradeFrom = streamsConfig.getString(StreamsConfig.UPGRADE_FROM_CONFIG); + if (upgradeFrom != null) { + switch (upgradeFrom) { + case StreamsConfig.UPGRADE_FROM_0100: + case StreamsConfig.UPGRADE_FROM_0101: + case StreamsConfig.UPGRADE_FROM_0102: + case StreamsConfig.UPGRADE_FROM_0110: + case StreamsConfig.UPGRADE_FROM_10: + case StreamsConfig.UPGRADE_FROM_11: + case StreamsConfig.UPGRADE_FROM_20: + case StreamsConfig.UPGRADE_FROM_21: + case StreamsConfig.UPGRADE_FROM_22: + case StreamsConfig.UPGRADE_FROM_23: + // ATTENTION: The following log messages is used for verification in system test + // streams/streams_cooperative_rebalance_upgrade_test.py::StreamsCooperativeRebalanceUpgradeTest.test_upgrade_to_cooperative_rebalance + // If you change it, please do also change the system test accordingly and + // verify whether the test passes. + log.info("Eager rebalancing protocol is enabled now for upgrade from {}.x", upgradeFrom); + log.warn("The eager rebalancing protocol is deprecated and will stop being supported in a future release." + + " Please be prepared to remove the 'upgrade.from' config soon."); + return RebalanceProtocol.EAGER; + default: + throw new IllegalArgumentException("Unknown configuration value for parameter 'upgrade.from': " + upgradeFrom); + } + } + // ATTENTION: The following log messages is used for verification in system test + // streams/streams_cooperative_rebalance_upgrade_test.py::StreamsCooperativeRebalanceUpgradeTest.test_upgrade_to_cooperative_rebalance + // If you change it, please do also change the system test accordingly and + // verify whether the test passes. 
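+        // No upgrade.from is set, so every member is assumed to understand cooperative rebalancing
+        // and the assignor can avoid the full stop-the-world revocation of the EAGER protocol.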
+ log.info("Cooperative rebalancing protocol is enabled now"); + return RebalanceProtocol.COOPERATIVE; + } + + public String logPrefix() { + return logPrefix; + } + + public int configuredMetadataVersion(final int priorVersion) { + final String upgradeFrom = streamsConfig.getString(StreamsConfig.UPGRADE_FROM_CONFIG); + if (upgradeFrom != null) { + switch (upgradeFrom) { + case StreamsConfig.UPGRADE_FROM_0100: + log.info( + "Downgrading metadata version from {} to 1 for upgrade from 0.10.0.x.", + LATEST_SUPPORTED_VERSION + ); + return 1; + case StreamsConfig.UPGRADE_FROM_0101: + case StreamsConfig.UPGRADE_FROM_0102: + case StreamsConfig.UPGRADE_FROM_0110: + case StreamsConfig.UPGRADE_FROM_10: + case StreamsConfig.UPGRADE_FROM_11: + log.info( + "Downgrading metadata version from {} to 2 for upgrade from {}.x.", + LATEST_SUPPORTED_VERSION, + upgradeFrom + ); + return 2; + case StreamsConfig.UPGRADE_FROM_20: + case StreamsConfig.UPGRADE_FROM_21: + case StreamsConfig.UPGRADE_FROM_22: + case StreamsConfig.UPGRADE_FROM_23: + // These configs are for cooperative rebalancing and should not affect the metadata version + break; + default: + throw new IllegalArgumentException( + "Unknown configuration value for parameter 'upgrade.from': " + upgradeFrom + ); + } + } + return priorVersion; + } + + public String userEndPoint() { + final String configuredUserEndpoint = streamsConfig.getString(StreamsConfig.APPLICATION_SERVER_CONFIG); + if (configuredUserEndpoint != null && !configuredUserEndpoint.isEmpty()) { + try { + final String host = getHost(configuredUserEndpoint); + final Integer port = getPort(configuredUserEndpoint); + + if (host == null || port == null) { + throw new ConfigException( + String.format( + "%s Config %s isn't in the correct format. Expected a host:port pair but received %s", + logPrefix, StreamsConfig.APPLICATION_SERVER_CONFIG, configuredUserEndpoint + ) + ); + } + } catch (final NumberFormatException nfe) { + throw new ConfigException( + String.format("%s Invalid port supplied in %s for config %s: %s", + logPrefix, configuredUserEndpoint, StreamsConfig.APPLICATION_SERVER_CONFIG, nfe) + ); + } + return configuredUserEndpoint; + } else { + return null; + } + } + + public InternalTopicManager internalTopicManager() { + return new InternalTopicManager(referenceContainer.time, referenceContainer.adminClient, streamsConfig); + } + + public CopartitionedTopicsEnforcer copartitionedTopicsEnforcer() { + return new CopartitionedTopicsEnforcer(logPrefix); + } + + public AssignmentConfigs assignmentConfigs() { + return new AssignmentConfigs(streamsConfig); + } + + public TaskAssignor taskAssignor() { + try { + return Utils.newInstance(taskAssignorClass, TaskAssignor.class); + } catch (final ClassNotFoundException e) { + throw new IllegalArgumentException( + "Expected an instantiable class name for " + INTERNAL_TASK_ASSIGNOR_CLASS, + e + ); + } + } + + public AssignmentListener assignmentListener() { + final Object o = internalConfigs.get(InternalConfig.ASSIGNMENT_LISTENER); + if (o == null) { + return stable -> { }; + } + + if (!(o instanceof AssignmentListener)) { + final KafkaException fatalException = new KafkaException( + String.format("%s is not an instance of %s", o.getClass().getName(), AssignmentListener.class.getName()) + ); + log.error(fatalException.getMessage(), fatalException); + throw fatalException; + } + + return (AssignmentListener) o; + } + + public interface AssignmentListener { + void onAssignmentComplete(final boolean stable); + } + + public static class AssignmentConfigs 
{ + public final long acceptableRecoveryLag; + public final int maxWarmupReplicas; + public final int numStandbyReplicas; + public final long probingRebalanceIntervalMs; + + private AssignmentConfigs(final StreamsConfig configs) { + acceptableRecoveryLag = configs.getLong(StreamsConfig.ACCEPTABLE_RECOVERY_LAG_CONFIG); + maxWarmupReplicas = configs.getInt(StreamsConfig.MAX_WARMUP_REPLICAS_CONFIG); + numStandbyReplicas = configs.getInt(StreamsConfig.NUM_STANDBY_REPLICAS_CONFIG); + probingRebalanceIntervalMs = configs.getLong(StreamsConfig.PROBING_REBALANCE_INTERVAL_MS_CONFIG); + } + + AssignmentConfigs(final Long acceptableRecoveryLag, + final Integer maxWarmupReplicas, + final Integer numStandbyReplicas, + final Long probingRebalanceIntervalMs) { + this.acceptableRecoveryLag = validated(StreamsConfig.ACCEPTABLE_RECOVERY_LAG_CONFIG, acceptableRecoveryLag); + this.maxWarmupReplicas = validated(StreamsConfig.MAX_WARMUP_REPLICAS_CONFIG, maxWarmupReplicas); + this.numStandbyReplicas = validated(StreamsConfig.NUM_STANDBY_REPLICAS_CONFIG, numStandbyReplicas); + this.probingRebalanceIntervalMs = validated(StreamsConfig.PROBING_REBALANCE_INTERVAL_MS_CONFIG, probingRebalanceIntervalMs); + } + + private static T validated(final String configKey, final T value) { + final ConfigDef.Validator validator = StreamsConfig.configDef().configKeys().get(configKey).validator; + if (validator != null) { + validator.ensureValid(configKey, value); + } + return value; + } + + @Override + public String toString() { + return "AssignmentConfigs{" + + "\n acceptableRecoveryLag=" + acceptableRecoveryLag + + "\n maxWarmupReplicas=" + maxWarmupReplicas + + "\n numStandbyReplicas=" + numStandbyReplicas + + "\n probingRebalanceIntervalMs=" + probingRebalanceIntervalMs + + "\n}"; + } + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/AssignorError.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/AssignorError.java new file mode 100644 index 0000000..2104144 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/AssignorError.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.processor.internals.assignment; + +public enum AssignorError { + // Note: this error code should be reserved for fatal errors, as the receiving clients are future-proofed + // to throw an exception upon an unrecognized error code. 
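+    // The numeric code is what actually travels in the assignment userdata (AssignmentInfo#errCode);
+    // receivers treat any unrecognized value as fatal, hence the note above about reserving new
+    // codes for fatal errors.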
+ NONE(0), + INCOMPLETE_SOURCE_TOPIC_METADATA(1), + VERSION_PROBING(2), // not actually used anymore, but we may hit it during a rolling upgrade from earlier versions + ASSIGNMENT_ERROR(3), + SHUTDOWN_REQUESTED(4); + + private final int code; + + AssignorError(final int code) { + this.code = code; + } + + public int code() { + return code; + } + +} diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/ClientState.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/ClientState.java new file mode 100644 index 0000000..d828f1e --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/ClientState.java @@ -0,0 +1,448 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.processor.internals.assignment; + +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.streams.processor.TaskId; +import org.apache.kafka.streams.processor.internals.Task; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Collection; +import java.util.Comparator; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; +import java.util.SortedSet; +import java.util.stream.Collectors; +import java.util.TreeMap; +import java.util.TreeSet; +import java.util.UUID; + +import static java.util.Collections.emptyMap; +import static java.util.Collections.unmodifiableMap; +import static java.util.Collections.unmodifiableSet; +import static java.util.Comparator.comparing; +import static java.util.Comparator.comparingLong; + +import static org.apache.kafka.common.utils.Utils.union; +import static org.apache.kafka.streams.processor.internals.assignment.SubscriptionInfo.UNKNOWN_OFFSET_SUM; + +public class ClientState { + private static final Logger LOG = LoggerFactory.getLogger(ClientState.class); + public static final Comparator TOPIC_PARTITION_COMPARATOR = comparing(TopicPartition::topic).thenComparing(TopicPartition::partition); + + private final Map taskOffsetSums; // contains only stateful tasks we previously owned + private final Map taskLagTotals; // contains lag for all stateful tasks in the app topology + private final Map ownedPartitions = new TreeMap<>(TOPIC_PARTITION_COMPARATOR); + private final Map> consumerToPreviousStatefulTaskIds = new TreeMap<>(); + + private final ClientStateTask assignedActiveTasks = new ClientStateTask(new TreeSet<>(), new TreeMap<>()); + private final ClientStateTask assignedStandbyTasks = new ClientStateTask(new TreeSet<>(), new TreeMap<>()); + private final ClientStateTask previousActiveTasks = new ClientStateTask(null, new TreeMap<>()); + private final ClientStateTask previousStandbyTasks = new ClientStateTask(null, null); + private final 
ClientStateTask revokingActiveTasks = new ClientStateTask(null, new TreeMap<>()); + + private int capacity; + + public ClientState() { + this(0); + } + + ClientState(final int capacity) { + previousStandbyTasks.taskIds(new TreeSet<>()); + previousActiveTasks.taskIds(new TreeSet<>()); + + taskOffsetSums = new TreeMap<>(); + taskLagTotals = new TreeMap<>(); + this.capacity = capacity; + } + + // For testing only + public ClientState(final Set previousActiveTasks, + final Set previousStandbyTasks, + final Map taskLagTotals, + final int capacity) { + this.previousStandbyTasks.taskIds(unmodifiableSet(new TreeSet<>(previousStandbyTasks))); + this.previousActiveTasks.taskIds(unmodifiableSet(new TreeSet<>(previousActiveTasks))); + taskOffsetSums = emptyMap(); + this.taskLagTotals = unmodifiableMap(taskLagTotals); + this.capacity = capacity; + } + + int capacity() { + return capacity; + } + + public void incrementCapacity() { + capacity++; + } + + boolean reachedCapacity() { + return assignedTaskCount() >= capacity; + } + + public Set activeTasks() { + return unmodifiableSet(assignedActiveTasks.taskIds()); + } + + public int activeTaskCount() { + return assignedActiveTasks.taskIds().size(); + } + + double activeTaskLoad() { + return ((double) activeTaskCount()) / capacity; + } + + public void assignActiveTasks(final Collection tasks) { + assignedActiveTasks.taskIds().addAll(tasks); + } + + public void assignActiveToConsumer(final TaskId task, final String consumer) { + if (!assignedActiveTasks.taskIds().contains(task)) { + throw new IllegalStateException("added not assign active task " + task + " to this client state."); + } + assignedActiveTasks.consumerToTaskIds() + .computeIfAbsent(consumer, k -> new HashSet<>()).add(task); + } + + public void assignStandbyToConsumer(final TaskId task, final String consumer) { + assignedStandbyTasks.consumerToTaskIds().computeIfAbsent(consumer, k -> new HashSet<>()).add(task); + } + + public void revokeActiveFromConsumer(final TaskId task, final String consumer) { + revokingActiveTasks.consumerToTaskIds().computeIfAbsent(consumer, k -> new HashSet<>()).add(task); + } + + public Map> prevOwnedActiveTasksByConsumer() { + return previousActiveTasks.consumerToTaskIds(); + } + + public Map> prevOwnedStandbyByConsumer() { + // standbys are just those stateful tasks minus active tasks + final Map> consumerToPreviousStandbyTaskIds = new TreeMap<>(); + final Map> consumerToPreviousActiveTaskIds = previousActiveTasks.consumerToTaskIds(); + + for (final Map.Entry> entry: consumerToPreviousStatefulTaskIds.entrySet()) { + final Set standbyTaskIds = new HashSet<>(entry.getValue()); + if (consumerToPreviousActiveTaskIds.containsKey(entry.getKey())) + standbyTaskIds.removeAll(consumerToPreviousActiveTaskIds.get(entry.getKey())); + consumerToPreviousStandbyTaskIds.put(entry.getKey(), standbyTaskIds); + } + + return consumerToPreviousStandbyTaskIds; + } + + // including both active and standby tasks + public Set prevOwnedStatefulTasksByConsumer(final String memberId) { + return consumerToPreviousStatefulTaskIds.get(memberId); + } + + public Map> assignedActiveTasksByConsumer() { + return assignedActiveTasks.consumerToTaskIds(); + } + + public Map> revokingActiveTasksByConsumer() { + return revokingActiveTasks.consumerToTaskIds(); + } + + public Map> assignedStandbyTasksByConsumer() { + return assignedStandbyTasks.consumerToTaskIds(); + } + + public void assignActive(final TaskId task) { + assertNotAssigned(task); + assignedActiveTasks.taskIds().add(task); + } + + public void 
unassignActive(final TaskId task) { + final Set taskIds = assignedActiveTasks.taskIds(); + if (!taskIds.contains(task)) { + throw new IllegalArgumentException("Tried to unassign active task " + task + ", but it is not currently assigned: " + this); + } + taskIds.remove(task); + } + + public Set standbyTasks() { + return unmodifiableSet(assignedStandbyTasks.taskIds()); + } + + boolean hasStandbyTask(final TaskId taskId) { + return assignedStandbyTasks.taskIds().contains(taskId); + } + + int standbyTaskCount() { + return assignedStandbyTasks.taskIds().size(); + } + + public void assignStandby(final TaskId task) { + assertNotAssigned(task); + assignedStandbyTasks.taskIds().add(task); + } + + void unassignStandby(final TaskId task) { + final Set taskIds = assignedStandbyTasks.taskIds(); + if (!taskIds.contains(task)) { + throw new IllegalArgumentException("Tried to unassign standby task " + task + ", but it is not currently assigned: " + this); + } + taskIds.remove(task); + } + + Set assignedTasks() { + final Set assignedActiveTaskIds = assignedActiveTasks.taskIds(); + final Set assignedStandbyTaskIds = assignedStandbyTasks.taskIds(); + // Since we're copying it, it's not strictly necessary to make it unmodifiable also. + // I'm just trying to prevent subtle bugs if we write code that thinks it can update + // the assignment by updating the returned set. + return unmodifiableSet( + union( + () -> new HashSet<>(assignedActiveTaskIds.size() + assignedStandbyTaskIds.size()), + assignedActiveTaskIds, + assignedStandbyTaskIds + ) + ); + } + + public int assignedTaskCount() { + return activeTaskCount() + standbyTaskCount(); + } + + double assignedTaskLoad() { + return ((double) assignedTaskCount()) / capacity; + } + + boolean hasAssignedTask(final TaskId taskId) { + return assignedActiveTasks.taskIds().contains(taskId) || assignedStandbyTasks.taskIds().contains(taskId); + } + + Set prevActiveTasks() { + return unmodifiableSet(previousActiveTasks.taskIds()); + } + + private void addPreviousActiveTask(final TaskId task) { + previousActiveTasks.taskIds().add(task); + } + + void addPreviousActiveTasks(final Set prevTasks) { + previousActiveTasks.taskIds().addAll(prevTasks); + } + + Set prevStandbyTasks() { + return unmodifiableSet(previousStandbyTasks.taskIds()); + } + + private void addPreviousStandbyTask(final TaskId task) { + previousStandbyTasks.taskIds().add(task); + } + + void addPreviousStandbyTasks(final Set standbyTasks) { + previousStandbyTasks.taskIds().addAll(standbyTasks); + } + + Set previousAssignedTasks() { + final Set previousActiveTaskIds = previousActiveTasks.taskIds(); + final Set previousStandbyTaskIds = previousStandbyTasks.taskIds(); + return union(() -> new HashSet<>(previousActiveTaskIds.size() + previousStandbyTaskIds.size()), + previousActiveTaskIds, + previousStandbyTaskIds); + } + + // May return null + public String previousOwnerForPartition(final TopicPartition partition) { + return ownedPartitions.get(partition); + } + + public void addOwnedPartitions(final Collection ownedPartitions, final String consumer) { + for (final TopicPartition tp : ownedPartitions) { + this.ownedPartitions.put(tp, consumer); + } + } + + public void addPreviousTasksAndOffsetSums(final String consumerId, final Map taskOffsetSums) { + this.taskOffsetSums.putAll(taskOffsetSums); + consumerToPreviousStatefulTaskIds.put(consumerId, taskOffsetSums.keySet()); + } + + public void initializePrevTasks(final Map taskForPartitionMap) { + if (!previousActiveTasks.taskIds().isEmpty() || 
!previousStandbyTasks.taskIds().isEmpty()) { + throw new IllegalStateException("Already added previous tasks to this client state."); + } + initializePrevActiveTasksFromOwnedPartitions(taskForPartitionMap); + initializeRemainingPrevTasksFromTaskOffsetSums(); + } + + /** + * Compute the lag for each stateful task, including tasks this client did not previously have. + */ + public void computeTaskLags(final UUID uuid, final Map allTaskEndOffsetSums) { + if (!taskLagTotals.isEmpty()) { + throw new IllegalStateException("Already computed task lags for this client."); + } + + for (final Map.Entry taskEntry : allTaskEndOffsetSums.entrySet()) { + final TaskId task = taskEntry.getKey(); + final Long endOffsetSum = taskEntry.getValue(); + final Long offsetSum = taskOffsetSums.getOrDefault(task, 0L); + + if (offsetSum == Task.LATEST_OFFSET) { + taskLagTotals.put(task, Task.LATEST_OFFSET); + } else if (offsetSum == UNKNOWN_OFFSET_SUM) { + taskLagTotals.put(task, UNKNOWN_OFFSET_SUM); + } else if (endOffsetSum < offsetSum) { + LOG.warn("Task " + task + " had endOffsetSum=" + endOffsetSum + " smaller than offsetSum=" + + offsetSum + " on member " + uuid + ". This probably means the task is corrupted," + + " which in turn indicates that it will need to restore from scratch if it gets assigned." + + " The assignor will de-prioritize returning this task to this member in the hopes that" + + " some other member may be able to re-use its state."); + taskLagTotals.put(task, endOffsetSum); + } else { + taskLagTotals.put(task, endOffsetSum - offsetSum); + } + } + } + + /** + * Returns the total lag across all logged stores in the task. Equal to the end offset sum if this client + * did not have any state for this task on disk. + * + * @return end offset sum - offset sum + * Task.LATEST_OFFSET if this was previously an active running task on this client + */ + public long lagFor(final TaskId task) { + final Long totalLag = taskLagTotals.get(task); + if (totalLag == null) { + throw new IllegalStateException("Tried to lookup lag for unknown task " + task); + } + return totalLag; + } + + /** + * @return the previous tasks assigned to this consumer ordered by lag, filtered for any tasks that don't exist in this assignment + */ + public SortedSet prevTasksByLag(final String consumer) { + final SortedSet prevTasksByLag = new TreeSet<>(comparingLong(this::lagFor).thenComparing(TaskId::compareTo)); + for (final TaskId task : prevOwnedStatefulTasksByConsumer(consumer)) { + if (taskLagTotals.containsKey(task)) { + prevTasksByLag.add(task); + } else { + LOG.debug("Skipping previous task {} since it's not part of the current assignment", task); + } + } + return prevTasksByLag; + } + + public Set statefulActiveTasks() { + return assignedActiveTasks.taskIds().stream().filter(this::isStateful).collect(Collectors.toSet()); + } + + public Set statelessActiveTasks() { + return assignedActiveTasks.taskIds().stream().filter(task -> !isStateful(task)).collect(Collectors.toSet()); + } + + boolean hasUnfulfilledQuota(final int tasksPerThread) { + return assignedActiveTasks.taskIds().size() < capacity * tasksPerThread; + } + + boolean hasMoreAvailableCapacityThan(final ClientState other) { + if (capacity <= 0) { + throw new IllegalStateException("Capacity of this ClientState must be greater than 0."); + } + + if (other.capacity <= 0) { + throw new IllegalStateException("Capacity of other ClientState must be greater than 0"); + } + + final double otherLoad = (double) other.assignedTaskCount() / other.capacity; + final double thisLoad = 
(double) assignedTaskCount() / capacity; + + if (thisLoad < otherLoad) { + return true; + } else if (thisLoad > otherLoad) { + return false; + } else { + return capacity > other.capacity; + } + } + + public String consumers() { + return consumerToPreviousStatefulTaskIds.keySet().toString(); + } + + public String currentAssignment() { + return "[activeTasks: (" + assignedActiveTasks.taskIds() + + ") standbyTasks: (" + assignedStandbyTasks.taskIds() + ")]"; + } + + @Override + public String toString() { + return "[activeTasks: (" + assignedActiveTasks.taskIds() + + ") standbyTasks: (" + assignedStandbyTasks.taskIds() + + ") prevActiveTasks: (" + previousActiveTasks.taskIds() + + ") prevStandbyTasks: (" + previousStandbyTasks.taskIds() + + ") changelogOffsetTotalsByTask: (" + taskOffsetSums.entrySet() + + ") taskLagTotals: (" + taskLagTotals.entrySet() + + ") capacity: " + capacity + + " assigned: " + assignedTaskCount() + + "]"; + } + + private boolean isStateful(final TaskId task) { + return taskLagTotals.containsKey(task); + } + + private void initializePrevActiveTasksFromOwnedPartitions(final Map taskForPartitionMap) { + // there are three cases where we need to construct some or all of the prevTasks from the ownedPartitions: + // 1) COOPERATIVE clients on version 2.4-2.5 do not encode active tasks at all and rely on ownedPartitions + // 2) future client during version probing, when we can't decode the future subscription info's prev tasks + // 3) stateless tasks are not encoded in the task lags, and must be figured out from the ownedPartitions + for (final Map.Entry partitionEntry : ownedPartitions.entrySet()) { + final TopicPartition tp = partitionEntry.getKey(); + final TaskId task = taskForPartitionMap.get(tp); + if (task != null) { + addPreviousActiveTask(task); + previousActiveTasks.consumerToTaskIds().computeIfAbsent(partitionEntry.getValue(), k -> new HashSet<>()).add(task); + } else { + LOG.error("No task found for topic partition {}", tp); + } + } + } + + private void initializeRemainingPrevTasksFromTaskOffsetSums() { + final Set previousActiveTaskIds = previousActiveTasks.taskIds(); + if (previousActiveTaskIds.isEmpty() && !ownedPartitions.isEmpty()) { + LOG.error("Tried to process tasks in offset sum map before processing tasks from ownedPartitions = {}", ownedPartitions); + throw new IllegalStateException("Must initialize prevActiveTasks from ownedPartitions before initializing remaining tasks."); + } + for (final Map.Entry taskEntry : taskOffsetSums.entrySet()) { + final TaskId task = taskEntry.getKey(); + if (!previousActiveTaskIds.contains(task)) { + final long offsetSum = taskEntry.getValue(); + if (offsetSum == Task.LATEST_OFFSET) { + addPreviousActiveTask(task); + } else { + addPreviousStandbyTask(task); + } + } + } + } + + private void assertNotAssigned(final TaskId task) { + if (assignedStandbyTasks.taskIds().contains(task) || assignedActiveTasks.taskIds().contains(task)) { + throw new IllegalArgumentException("Tried to assign task " + task + ", but it is already assigned: " + this); + } + } +} \ No newline at end of file diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/ClientStateTask.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/ClientStateTask.java new file mode 100644 index 0000000..9276969 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/ClientStateTask.java @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) 
under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.processor.internals.assignment; + +import java.util.Map; +import java.util.Set; + +import org.apache.kafka.streams.processor.TaskId; + +class ClientStateTask { + private final Map> consumerToTaskIds; + private Set taskIds; + + ClientStateTask(final Set taskIds, + final Map> consumerToTaskIds) { + this.taskIds = taskIds; + this.consumerToTaskIds = consumerToTaskIds; + } + + void taskIds(final Set clientToTaskIds) { + taskIds = clientToTaskIds; + } + + Set taskIds() { + return taskIds; + } + + Map> consumerToTaskIds() { + return consumerToTaskIds; + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/ConstrainedPrioritySet.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/ConstrainedPrioritySet.java new file mode 100644 index 0000000..1de9dfc --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/ConstrainedPrioritySet.java @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.processor.internals.assignment; + +import org.apache.kafka.streams.processor.TaskId; + +import java.util.Collection; +import java.util.Comparator; +import java.util.HashSet; +import java.util.PriorityQueue; +import java.util.Set; +import java.util.UUID; +import java.util.function.BiFunction; +import java.util.function.Function; + +/** + * Wraps a priority queue of clients and returns the next valid candidate(s) based on the current task assignment + */ +class ConstrainedPrioritySet { + + private final PriorityQueue clientsByTaskLoad; + private final BiFunction constraint; + private final Set uniqueClients = new HashSet<>(); + + ConstrainedPrioritySet(final BiFunction constraint, + final Function weight) { + this.constraint = constraint; + clientsByTaskLoad = new PriorityQueue<>(Comparator.comparing(weight).thenComparing(clientId -> clientId)); + } + + /** + * @return the next least loaded client that satisfies the given criteria, or null if none do + */ + UUID poll(final TaskId task, final Function extraConstraint) { + final Set invalidPolledClients = new HashSet<>(); + while (!clientsByTaskLoad.isEmpty()) { + final UUID candidateClient = pollNextClient(); + if (constraint.apply(candidateClient, task) && extraConstraint.apply(candidateClient)) { + // then we found the lightest, valid client + offerAll(invalidPolledClients); + return candidateClient; + } else { + // remember this client and try again later + invalidPolledClients.add(candidateClient); + } + } + // we tried all the clients, and none met the constraint (or there are no clients) + offerAll(invalidPolledClients); + return null; + } + + /** + * @return the next least loaded client that satisfies the given criteria, or null if none do + */ + UUID poll(final TaskId task) { + return poll(task, client -> true); + } + + void offerAll(final Collection clients) { + for (final UUID client : clients) { + offer(client); + } + } + + void offer(final UUID client) { + if (uniqueClients.contains(client)) { + clientsByTaskLoad.remove(client); + } else { + uniqueClients.add(client); + } + clientsByTaskLoad.offer(client); + } + + private UUID pollNextClient() { + final UUID client = clientsByTaskLoad.remove(); + uniqueClients.remove(client); + return client; + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/ConsumerProtocolUtils.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/ConsumerProtocolUtils.java new file mode 100644 index 0000000..8dbedb8 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/ConsumerProtocolUtils.java @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.processor.internals.assignment; + +import org.apache.kafka.streams.errors.TaskAssignmentException; +import org.apache.kafka.streams.processor.TaskId; + +import java.io.DataInputStream; +import java.io.DataOutputStream; +import java.io.IOException; +import java.nio.ByteBuffer; + +import static org.apache.kafka.streams.processor.internals.assignment.StreamsAssignmentProtocolVersions.MIN_NAMED_TOPOLOGY_VERSION; + +/** + * Utility class for common assignment or consumer protocol utility methods such as de/serialization + */ +public class ConsumerProtocolUtils { + + /** + * @throws IOException if cannot write to output stream + */ + public static void writeTaskIdTo(final TaskId taskId, final DataOutputStream out, final int version) throws IOException { + out.writeInt(taskId.subtopology()); + out.writeInt(taskId.partition()); + if (version >= MIN_NAMED_TOPOLOGY_VERSION) { + if (taskId.topologyName() != null) { + out.writeInt(taskId.topologyName().length()); + out.writeChars(taskId.topologyName()); + } else { + out.writeInt(0); + } + } else if (taskId.topologyName() != null) { + throw new TaskAssignmentException("Named topologies are not compatible with protocol version " + version); + } + } + + /** + * @throws IOException if cannot read from input stream + */ + public static TaskId readTaskIdFrom(final DataInputStream in, final int version) throws IOException { + final int subtopology = in.readInt(); + final int partition = in.readInt(); + final String namedTopology; + if (version >= MIN_NAMED_TOPOLOGY_VERSION) { + final int numNamedTopologyChars = in.readInt(); + final StringBuilder namedTopologyBuilder = new StringBuilder(); + for (int i = 0; i < numNamedTopologyChars; ++i) { + namedTopologyBuilder.append(in.readChar()); + } + namedTopology = namedTopologyBuilder.toString(); + } else { + namedTopology = null; + } + return new TaskId(subtopology, partition, getNamedTopologyOrElseNull(namedTopology)); + } + + public static void writeTaskIdTo(final TaskId taskId, final ByteBuffer buf, final int version) { + buf.putInt(taskId.subtopology()); + buf.putInt(taskId.partition()); + if (version >= MIN_NAMED_TOPOLOGY_VERSION) { + if (taskId.topologyName() != null) { + buf.putInt(taskId.topologyName().length()); + for (final char c : taskId.topologyName().toCharArray()) { + buf.putChar(c); + } + } else { + buf.putInt(0); + } + } else if (taskId.topologyName() != null) { + throw new TaskAssignmentException("Named topologies are not compatible with protocol version " + version); + } + } + + public static TaskId readTaskIdFrom(final ByteBuffer buf, final int version) { + final int subtopology = buf.getInt(); + final int partition = buf.getInt(); + final String namedTopology; + if (version >= MIN_NAMED_TOPOLOGY_VERSION) { + final int numNamedTopologyChars = buf.getInt(); + final StringBuilder namedTopologyBuilder = new StringBuilder(); + for (int i = 0; i < numNamedTopologyChars; ++i) { + namedTopologyBuilder.append(buf.getChar()); + } + namedTopology = namedTopologyBuilder.toString(); + } else { + namedTopology = null; + } + return new TaskId(subtopology, partition, getNamedTopologyOrElseNull(namedTopology)); + } + + /** + * @return the namedTopology name, or null if the passed in namedTopology is null or the empty string + */ + private static String getNamedTopologyOrElseNull(final String namedTopology) { + return (namedTopology == null || namedTopology.length() == 0) ? 
+ null : + namedTopology; + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/CopartitionedTopicsEnforcer.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/CopartitionedTopicsEnforcer.java new file mode 100644 index 0000000..8324532 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/CopartitionedTopicsEnforcer.java @@ -0,0 +1,182 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.processor.internals.assignment; + +import org.apache.kafka.common.Cluster; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.streams.errors.TopologyException; +import org.apache.kafka.streams.processor.internals.InternalTopicConfig; +import org.slf4j.Logger; + +import java.util.Collection; +import java.util.Map; +import java.util.Map.Entry; +import java.util.Optional; +import java.util.Set; +import java.util.TreeMap; +import java.util.function.Supplier; +import java.util.stream.Collectors; + +public class CopartitionedTopicsEnforcer { + private final String logPrefix; + private final Logger log; + + public CopartitionedTopicsEnforcer(final String logPrefix) { + this.logPrefix = logPrefix; + final LogContext logContext = new LogContext(logPrefix); + log = logContext.logger(getClass()); + } + + public void enforce(final Set copartitionGroup, + final Map allRepartitionTopicsNumPartitions, + final Cluster metadata) { + if (copartitionGroup.isEmpty()) { + return; + } + + final Map repartitionTopicConfigs = + copartitionGroup.stream() + .filter(allRepartitionTopicsNumPartitions::containsKey) + .collect(Collectors.toMap(topic -> topic, allRepartitionTopicsNumPartitions::get)); + + final Map nonRepartitionTopicPartitions = + copartitionGroup.stream().filter(topic -> !allRepartitionTopicsNumPartitions.containsKey(topic)) + .collect(Collectors.toMap(topic -> topic, topic -> { + final Integer partitions = metadata.partitionCountForTopic(topic); + if (partitions == null) { + final String str = String.format("%sTopic not found: %s", logPrefix, topic); + log.error(str); + throw new IllegalStateException(str); + } else { + return partitions; + } + })); + + final int numPartitionsToUseForRepartitionTopics; + final Collection internalTopicConfigs = repartitionTopicConfigs.values(); + + if (copartitionGroup.equals(repartitionTopicConfigs.keySet())) { + final Collection internalTopicConfigsWithEnforcedNumberOfPartitions = internalTopicConfigs + .stream() + .filter(InternalTopicConfig::hasEnforcedNumberOfPartitions) + .collect(Collectors.toList()); + + // if there's at least one repartition topic with enforced number of partitions + // validate that they all have same number of partitions + if 
(!internalTopicConfigsWithEnforcedNumberOfPartitions.isEmpty()) { + numPartitionsToUseForRepartitionTopics = validateAndGetNumOfPartitions( + repartitionTopicConfigs, + internalTopicConfigsWithEnforcedNumberOfPartitions + ); + } else { + // If all topics for this co-partition group are repartition topics, + // then set the number of partitions to be the maximum of the number of partitions. + numPartitionsToUseForRepartitionTopics = getMaxPartitions(repartitionTopicConfigs); + } + } else { + // Otherwise, use the number of partitions from external topics (which must all be the same) + numPartitionsToUseForRepartitionTopics = getSamePartitions(nonRepartitionTopicPartitions); + } + + // coerce all the repartition topics to use the decided number of partitions. + for (final InternalTopicConfig config : internalTopicConfigs) { + maybeSetNumberOfPartitionsForInternalTopic(numPartitionsToUseForRepartitionTopics, config); + + final int numberOfPartitionsOfInternalTopic = config + .numberOfPartitions() + .orElseThrow(emptyNumberOfPartitionsExceptionSupplier(config.name())); + + if (numberOfPartitionsOfInternalTopic != numPartitionsToUseForRepartitionTopics) { + final String msg = String.format("%sNumber of partitions [%s] of repartition topic [%s] " + + "doesn't match number of partitions [%s] of the source topic.", + logPrefix, + numberOfPartitionsOfInternalTopic, + config.name(), + numPartitionsToUseForRepartitionTopics); + throw new TopologyException(msg); + } + } + } + + private static void maybeSetNumberOfPartitionsForInternalTopic(final int numPartitionsToUseForRepartitionTopics, + final InternalTopicConfig config) { + if (!config.hasEnforcedNumberOfPartitions()) { + config.setNumberOfPartitions(numPartitionsToUseForRepartitionTopics); + } + } + + private int validateAndGetNumOfPartitions(final Map repartitionTopicConfigs, + final Collection internalTopicConfigs) { + final InternalTopicConfig firstInternalTopicConfig = internalTopicConfigs.iterator().next(); + + final int firstNumberOfPartitionsOfInternalTopic = firstInternalTopicConfig + .numberOfPartitions() + .orElseThrow(emptyNumberOfPartitionsExceptionSupplier(firstInternalTopicConfig.name())); + + for (final InternalTopicConfig internalTopicConfig : internalTopicConfigs) { + final Integer numberOfPartitions = internalTopicConfig + .numberOfPartitions() + .orElseThrow(emptyNumberOfPartitionsExceptionSupplier(internalTopicConfig.name())); + + if (numberOfPartitions != firstNumberOfPartitionsOfInternalTopic) { + final Map repartitionTopics = repartitionTopicConfigs + .entrySet() + .stream() + .collect(Collectors.toMap(Entry::getKey, entry -> entry.getValue().numberOfPartitions().get())); + + final String msg = String.format("%sFollowing topics do not have the same number of partitions: [%s]", + logPrefix, + new TreeMap<>(repartitionTopics)); + throw new TopologyException(msg); + } + } + + return firstNumberOfPartitionsOfInternalTopic; + } + + private static Supplier emptyNumberOfPartitionsExceptionSupplier(final String topic) { + return () -> new TopologyException("Number of partitions is not set for topic: " + topic); + } + + private int getSamePartitions(final Map nonRepartitionTopicsInCopartitionGroup) { + final int partitions = nonRepartitionTopicsInCopartitionGroup.values().iterator().next(); + for (final Entry entry : nonRepartitionTopicsInCopartitionGroup.entrySet()) { + if (entry.getValue() != partitions) { + final TreeMap sorted = new TreeMap<>(nonRepartitionTopicsInCopartitionGroup); + throw new TopologyException( + 
String.format("%sTopics not co-partitioned: [%s]", + logPrefix, sorted) + ); + } + } + return partitions; + } + + private int getMaxPartitions(final Map repartitionTopicsInCopartitionGroup) { + int maxPartitions = 0; + + for (final InternalTopicConfig config : repartitionTopicsInCopartitionGroup.values()) { + final Optional partitions = config.numberOfPartitions(); + maxPartitions = Integer.max(maxPartitions, partitions.orElse(maxPartitions)); + } + if (maxPartitions <= 0) { + throw new IllegalStateException(logPrefix + "Could not validate the copartitioning of topics: " + repartitionTopicsInCopartitionGroup.keySet()); + } + return maxPartitions; + } + +} diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/FallbackPriorTaskAssignor.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/FallbackPriorTaskAssignor.java new file mode 100644 index 0000000..58456ac --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/FallbackPriorTaskAssignor.java @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.processor.internals.assignment; + +import org.apache.kafka.streams.processor.TaskId; +import org.apache.kafka.streams.processor.internals.assignment.AssignorConfiguration.AssignmentConfigs; + +import java.util.Map; +import java.util.Set; +import java.util.UUID; + +/** + * A special task assignor implementation to be used as a fallback in case the + * configured assignor couldn't be invoked. + * + * Specifically, this assignor must: + * 1. ignore the task lags in the ClientState map + * 2. always return true, indicating that a follow-up rebalance is needed + */ +public class FallbackPriorTaskAssignor implements TaskAssignor { + private final StickyTaskAssignor delegate; + + public FallbackPriorTaskAssignor() { + delegate = new StickyTaskAssignor(true); + } + + @Override + public boolean assign(final Map clients, + final Set allTaskIds, + final Set statefulTaskIds, + final AssignmentConfigs configs) { + delegate.assign(clients, allTaskIds, statefulTaskIds, configs); + return true; + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/HighAvailabilityTaskAssignor.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/HighAvailabilityTaskAssignor.java new file mode 100644 index 0000000..f6464f8 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/HighAvailabilityTaskAssignor.java @@ -0,0 +1,266 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.processor.internals.assignment; + +import org.apache.kafka.streams.processor.TaskId; +import org.apache.kafka.streams.processor.internals.Task; +import org.apache.kafka.streams.processor.internals.assignment.AssignorConfiguration.AssignmentConfigs; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.HashMap; +import java.util.Iterator; +import java.util.Map; +import java.util.Set; +import java.util.SortedMap; +import java.util.SortedSet; +import java.util.TreeMap; +import java.util.TreeSet; +import java.util.UUID; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.BiConsumer; +import java.util.function.Function; +import java.util.stream.Collectors; + +import static org.apache.kafka.common.utils.Utils.diff; +import static org.apache.kafka.streams.processor.internals.assignment.TaskMovement.assignActiveTaskMovements; +import static org.apache.kafka.streams.processor.internals.assignment.TaskMovement.assignStandbyTaskMovements; + +public class HighAvailabilityTaskAssignor implements TaskAssignor { + private static final Logger log = LoggerFactory.getLogger(HighAvailabilityTaskAssignor.class); + + @Override + public boolean assign(final Map clients, + final Set allTaskIds, + final Set statefulTaskIds, + final AssignmentConfigs configs) { + final SortedSet statefulTasks = new TreeSet<>(statefulTaskIds); + final TreeMap clientStates = new TreeMap<>(clients); + + assignActiveStatefulTasks(clientStates, statefulTasks); + + assignStandbyReplicaTasks( + clientStates, + statefulTasks, + configs.numStandbyReplicas + ); + + final AtomicInteger remainingWarmupReplicas = new AtomicInteger(configs.maxWarmupReplicas); + + final Map> tasksToCaughtUpClients = tasksToCaughtUpClients( + statefulTasks, + clientStates, + configs.acceptableRecoveryLag + ); + + // We temporarily need to know which standby tasks were intended as warmups + // for active tasks, so that we don't move them (again) when we plan standby + // task movements. We can then immediately treat warmups exactly the same as + // hot-standby replicas, so we just track it right here as metadata, rather + // than add "warmup" assignments to ClientState, for example. 
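+        // (Illustrative example, not part of the protocol: a warmups entry such as
+        //  {processId-of-A -> [0_0, 1_2]} would mean client A was handed standby copies of tasks
+        //  0_0 and 1_2 only so it can catch up and take those actives over in a later rebalance.)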
+ final Map> warmups = new TreeMap<>(); + + final int neededActiveTaskMovements = assignActiveTaskMovements( + tasksToCaughtUpClients, + clientStates, + warmups, + remainingWarmupReplicas + ); + + final int neededStandbyTaskMovements = assignStandbyTaskMovements( + tasksToCaughtUpClients, + clientStates, + remainingWarmupReplicas, + warmups + ); + + assignStatelessActiveTasks(clientStates, diff(TreeSet::new, allTaskIds, statefulTasks)); + + final boolean probingRebalanceNeeded = neededActiveTaskMovements + neededStandbyTaskMovements > 0; + + log.info("Decided on assignment: " + + clientStates + + " with" + + (probingRebalanceNeeded ? "" : " no") + + " followup probing rebalance."); + + return probingRebalanceNeeded; + } + + private static void assignActiveStatefulTasks(final SortedMap clientStates, + final SortedSet statefulTasks) { + Iterator clientStateIterator = null; + for (final TaskId task : statefulTasks) { + if (clientStateIterator == null || !clientStateIterator.hasNext()) { + clientStateIterator = clientStates.values().iterator(); + } + clientStateIterator.next().assignActive(task); + } + + balanceTasksOverThreads( + clientStates, + ClientState::activeTasks, + ClientState::unassignActive, + ClientState::assignActive + ); + } + + private static void assignStandbyReplicaTasks(final TreeMap clientStates, + final Set statefulTasks, + final int numStandbyReplicas) { + final Map tasksToRemainingStandbys = + statefulTasks.stream().collect(Collectors.toMap(task -> task, t -> numStandbyReplicas)); + + final ConstrainedPrioritySet standbyTaskClientsByTaskLoad = new ConstrainedPrioritySet( + (client, task) -> !clientStates.get(client).hasAssignedTask(task), + client -> clientStates.get(client).assignedTaskLoad() + ); + standbyTaskClientsByTaskLoad.offerAll(clientStates.keySet()); + + for (final TaskId task : statefulTasks) { + int numRemainingStandbys = tasksToRemainingStandbys.get(task); + while (numRemainingStandbys > 0) { + final UUID client = standbyTaskClientsByTaskLoad.poll(task); + if (client == null) { + break; + } + clientStates.get(client).assignStandby(task); + numRemainingStandbys--; + standbyTaskClientsByTaskLoad.offer(client); + } + + if (numRemainingStandbys > 0) { + log.warn("Unable to assign {} of {} standby tasks for task [{}]. " + + "There is not enough available capacity. 
You should " + + "increase the number of application instances " + + "to maintain the requested number of standby replicas.", + numRemainingStandbys, numStandbyReplicas, task); + } + } + + balanceTasksOverThreads( + clientStates, + ClientState::standbyTasks, + ClientState::unassignStandby, + ClientState::assignStandby + ); + } + + private static void balanceTasksOverThreads(final SortedMap clientStates, + final Function> currentAssignmentAccessor, + final BiConsumer taskUnassignor, + final BiConsumer taskAssignor) { + boolean keepBalancing = true; + while (keepBalancing) { + keepBalancing = false; + for (final Map.Entry sourceEntry : clientStates.entrySet()) { + final UUID sourceClient = sourceEntry.getKey(); + final ClientState sourceClientState = sourceEntry.getValue(); + + for (final Map.Entry destinationEntry : clientStates.entrySet()) { + final UUID destinationClient = destinationEntry.getKey(); + final ClientState destinationClientState = destinationEntry.getValue(); + if (sourceClient.equals(destinationClient)) { + continue; + } + + final Set sourceTasks = new TreeSet<>(currentAssignmentAccessor.apply(sourceClientState)); + final Iterator sourceIterator = sourceTasks.iterator(); + while (shouldMoveATask(sourceClientState, destinationClientState) && sourceIterator.hasNext()) { + final TaskId taskToMove = sourceIterator.next(); + final boolean canMove = !destinationClientState.hasAssignedTask(taskToMove); + if (canMove) { + taskUnassignor.accept(sourceClientState, taskToMove); + taskAssignor.accept(destinationClientState, taskToMove); + keepBalancing = true; + } + } + } + } + } + } + + private static boolean shouldMoveATask(final ClientState sourceClientState, + final ClientState destinationClientState) { + final double skew = sourceClientState.assignedTaskLoad() - destinationClientState.assignedTaskLoad(); + + if (skew <= 0) { + return false; + } + + final double proposedAssignedTasksPerStreamThreadAtDestination = + (destinationClientState.assignedTaskCount() + 1.0) / destinationClientState.capacity(); + final double proposedAssignedTasksPerStreamThreadAtSource = + (sourceClientState.assignedTaskCount() - 1.0) / sourceClientState.capacity(); + final double proposedSkew = proposedAssignedTasksPerStreamThreadAtSource - proposedAssignedTasksPerStreamThreadAtDestination; + + if (proposedSkew < 0) { + // then the move would only create an imbalance in the other direction. + return false; + } + // we should only move a task if doing so would actually improve the skew. 
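+        // (purely illustrative numbers: source load 4 tasks / 2 threads = 2.0 vs. destination load
+        //  2 tasks / 2 threads = 1.0 gives skew = 1.0; after moving one task both sides sit at 1.5,
+        //  so proposedSkew = 0.0 < 1.0 and the move is accepted)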
+ return proposedSkew < skew; + } + + private static void assignStatelessActiveTasks(final TreeMap clientStates, + final Iterable statelessTasks) { + final ConstrainedPrioritySet statelessActiveTaskClientsByTaskLoad = new ConstrainedPrioritySet( + (client, task) -> true, + client -> clientStates.get(client).activeTaskLoad() + ); + statelessActiveTaskClientsByTaskLoad.offerAll(clientStates.keySet()); + + for (final TaskId task : statelessTasks) { + final UUID client = statelessActiveTaskClientsByTaskLoad.poll(task); + final ClientState state = clientStates.get(client); + state.assignActive(task); + statelessActiveTaskClientsByTaskLoad.offer(client); + } + } + + private static Map> tasksToCaughtUpClients(final Set statefulTasks, + final Map clientStates, + final long acceptableRecoveryLag) { + final Map> taskToCaughtUpClients = new HashMap<>(); + + for (final TaskId task : statefulTasks) { + final TreeSet caughtUpClients = new TreeSet<>(); + for (final Map.Entry clientEntry : clientStates.entrySet()) { + final UUID client = clientEntry.getKey(); + final long taskLag = clientEntry.getValue().lagFor(task); + if (activeRunning(taskLag) || unbounded(acceptableRecoveryLag) || acceptable(acceptableRecoveryLag, taskLag)) { + caughtUpClients.add(client); + } + } + taskToCaughtUpClients.put(task, caughtUpClients); + } + + return taskToCaughtUpClients; + } + + private static boolean unbounded(final long acceptableRecoveryLag) { + return acceptableRecoveryLag == Long.MAX_VALUE; + } + + private static boolean acceptable(final long acceptableRecoveryLag, final long taskLag) { + return taskLag >= 0 && taskLag <= acceptableRecoveryLag; + } + + private static boolean activeRunning(final long taskLag) { + return taskLag == Task.LATEST_OFFSET; + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/ReferenceContainer.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/ReferenceContainer.java new file mode 100644 index 0000000..fbf65e5 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/ReferenceContainer.java @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.processor.internals.assignment; + +import org.apache.kafka.clients.admin.Admin; +import org.apache.kafka.clients.consumer.Consumer; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.streams.processor.internals.StreamsMetadataState; +import org.apache.kafka.streams.processor.internals.TaskManager; + +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicLong; + +public class ReferenceContainer { + public Consumer mainConsumer; + public Admin adminClient; + public TaskManager taskManager; + public StreamsMetadataState streamsMetadataState; + public final AtomicInteger assignmentErrorCode = new AtomicInteger(); + public final AtomicLong nextScheduledRebalanceMs = new AtomicLong(Long.MAX_VALUE); + public Time time; +} diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/StickyTaskAssignor.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/StickyTaskAssignor.java new file mode 100644 index 0000000..d9f7efa --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/StickyTaskAssignor.java @@ -0,0 +1,337 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.processor.internals.assignment; + +import org.apache.kafka.streams.processor.TaskId; +import org.apache.kafka.streams.processor.internals.assignment.AssignorConfiguration.AssignmentConfigs; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Set; +import java.util.UUID; + +public class StickyTaskAssignor implements TaskAssignor { + + private static final Logger log = LoggerFactory.getLogger(StickyTaskAssignor.class); + private Map clients; + private Set allTaskIds; + private Set standbyTaskIds; + private final Map previousActiveTaskAssignment = new HashMap<>(); + private final Map> previousStandbyTaskAssignment = new HashMap<>(); + private TaskPairs taskPairs; + + private final boolean mustPreserveActiveTaskAssignment; + + public StickyTaskAssignor() { + this(false); + } + + StickyTaskAssignor(final boolean mustPreserveActiveTaskAssignment) { + this.mustPreserveActiveTaskAssignment = mustPreserveActiveTaskAssignment; + } + + @Override + public boolean assign(final Map clients, + final Set allTaskIds, + final Set statefulTaskIds, + final AssignmentConfigs configs) { + this.clients = clients; + this.allTaskIds = allTaskIds; + this.standbyTaskIds = statefulTaskIds; + + final int maxPairs = allTaskIds.size() * (allTaskIds.size() - 1) / 2; + taskPairs = new TaskPairs(maxPairs); + mapPreviousTaskAssignment(clients); + + assignActive(); + assignStandby(configs.numStandbyReplicas); + return false; + } + + private void assignStandby(final int numStandbyReplicas) { + for (final TaskId taskId : standbyTaskIds) { + for (int i = 0; i < numStandbyReplicas; i++) { + final Set ids = findClientsWithoutAssignedTask(taskId); + if (ids.isEmpty()) { + log.warn("Unable to assign {} of {} standby tasks for task [{}]. " + + "There is not enough available capacity. You should " + + "increase the number of threads and/or application instances " + + "to maintain the requested number of standby replicas.", + numStandbyReplicas - i, + numStandbyReplicas, taskId); + break; + } + allocateTaskWithClientCandidates(taskId, ids, false); + } + } + } + + private void assignActive() { + final int totalCapacity = sumCapacity(clients.values()); + final int tasksPerThread = allTaskIds.size() / totalCapacity; + final Set assigned = new HashSet<>(); + + // first try and re-assign existing active tasks to clients that previously had + // the same active task + for (final Map.Entry entry : previousActiveTaskAssignment.entrySet()) { + final TaskId taskId = entry.getKey(); + if (allTaskIds.contains(taskId)) { + final ClientState client = clients.get(entry.getValue()); + if (mustPreserveActiveTaskAssignment || client.hasUnfulfilledQuota(tasksPerThread)) { + assignTaskToClient(assigned, taskId, client); + } + } + } + + final Set unassigned = new HashSet<>(allTaskIds); + unassigned.removeAll(assigned); + + // try and assign any remaining unassigned tasks to clients that previously + // have seen the task. 
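+        // ("previously seen" here means the client held the task as a standby in the previous
+        //  assignment, so its local state is likely already close to caught up)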
+ for (final Iterator iterator = unassigned.iterator(); iterator.hasNext(); ) { + final TaskId taskId = iterator.next(); + final Set clientIds = previousStandbyTaskAssignment.get(taskId); + if (clientIds != null) { + for (final UUID clientId : clientIds) { + final ClientState client = clients.get(clientId); + if (client.hasUnfulfilledQuota(tasksPerThread)) { + assignTaskToClient(assigned, taskId, client); + iterator.remove(); + break; + } + } + } + } + + // assign any remaining unassigned tasks + final List sortedTasks = new ArrayList<>(unassigned); + Collections.sort(sortedTasks); + for (final TaskId taskId : sortedTasks) { + allocateTaskWithClientCandidates(taskId, clients.keySet(), true); + } + } + + private void allocateTaskWithClientCandidates(final TaskId taskId, final Set clientsWithin, final boolean active) { + final ClientState client = findClient(taskId, clientsWithin); + taskPairs.addPairs(taskId, client.assignedTasks()); + if (active) { + client.assignActive(taskId); + } else { + client.assignStandby(taskId); + } + } + + private void assignTaskToClient(final Set assigned, final TaskId taskId, final ClientState client) { + taskPairs.addPairs(taskId, client.assignedTasks()); + client.assignActive(taskId); + assigned.add(taskId); + } + + private Set findClientsWithoutAssignedTask(final TaskId taskId) { + final Set clientIds = new HashSet<>(); + for (final Map.Entry client : clients.entrySet()) { + if (!client.getValue().hasAssignedTask(taskId)) { + clientIds.add(client.getKey()); + } + } + return clientIds; + } + + + private ClientState findClient(final TaskId taskId, final Set clientsWithin) { + + // optimize the case where there is only 1 id to search within. + if (clientsWithin.size() == 1) { + return clients.get(clientsWithin.iterator().next()); + } + + final ClientState previous = findClientsWithPreviousAssignedTask(taskId, clientsWithin); + if (previous == null) { + return leastLoaded(taskId, clientsWithin); + } + + if (shouldBalanceLoad(previous)) { + final ClientState standby = findLeastLoadedClientWithPreviousStandByTask(taskId, clientsWithin); + if (standby == null || shouldBalanceLoad(standby)) { + return leastLoaded(taskId, clientsWithin); + } + return standby; + } + + return previous; + } + + private boolean shouldBalanceLoad(final ClientState client) { + return client.reachedCapacity() && hasClientsWithMoreAvailableCapacity(client); + } + + private boolean hasClientsWithMoreAvailableCapacity(final ClientState client) { + for (final ClientState clientState : clients.values()) { + if (clientState.hasMoreAvailableCapacityThan(client)) { + return true; + } + } + return false; + } + + private ClientState findClientsWithPreviousAssignedTask(final TaskId taskId, final Set clientsWithin) { + final UUID previous = previousActiveTaskAssignment.get(taskId); + if (previous != null && clientsWithin.contains(previous)) { + return clients.get(previous); + } + return findLeastLoadedClientWithPreviousStandByTask(taskId, clientsWithin); + } + + private ClientState findLeastLoadedClientWithPreviousStandByTask(final TaskId taskId, final Set clientsWithin) { + final Set ids = previousStandbyTaskAssignment.get(taskId); + if (ids == null) { + return null; + } + final HashSet constrainTo = new HashSet<>(ids); + constrainTo.retainAll(clientsWithin); + return leastLoaded(taskId, constrainTo); + } + + private ClientState leastLoaded(final TaskId taskId, final Set clientIds) { + final ClientState leastLoaded = findLeastLoaded(taskId, clientIds, true); + if (leastLoaded == null) { + return 
findLeastLoaded(taskId, clientIds, false); + } + return leastLoaded; + } + + private ClientState findLeastLoaded(final TaskId taskId, + final Set clientIds, + final boolean checkTaskPairs) { + ClientState leastLoaded = null; + for (final UUID id : clientIds) { + final ClientState client = clients.get(id); + if (client.assignedTaskCount() == 0) { + return client; + } + + if (leastLoaded == null || client.hasMoreAvailableCapacityThan(leastLoaded)) { + if (!checkTaskPairs) { + leastLoaded = client; + } else if (taskPairs.hasNewPair(taskId, client.assignedTasks())) { + leastLoaded = client; + } + } + + } + return leastLoaded; + + } + + private void mapPreviousTaskAssignment(final Map clients) { + for (final Map.Entry clientState : clients.entrySet()) { + for (final TaskId activeTask : clientState.getValue().prevActiveTasks()) { + previousActiveTaskAssignment.put(activeTask, clientState.getKey()); + } + + for (final TaskId prevAssignedTask : clientState.getValue().prevStandbyTasks()) { + previousStandbyTaskAssignment.computeIfAbsent(prevAssignedTask, t -> new HashSet<>()); + previousStandbyTaskAssignment.get(prevAssignedTask).add(clientState.getKey()); + } + } + + } + + private int sumCapacity(final Collection values) { + int capacity = 0; + for (final ClientState client : values) { + capacity += client.capacity(); + } + return capacity; + } + + private static class TaskPairs { + private final Set pairs; + private final int maxPairs; + + TaskPairs(final int maxPairs) { + this.maxPairs = maxPairs; + this.pairs = new HashSet<>(maxPairs); + } + + boolean hasNewPair(final TaskId task1, + final Set taskIds) { + if (pairs.size() == maxPairs) { + return false; + } + for (final TaskId taskId : taskIds) { + if (!pairs.contains(pair(task1, taskId))) { + return true; + } + } + return false; + } + + void addPairs(final TaskId taskId, final Set assigned) { + for (final TaskId id : assigned) { + pairs.add(pair(id, taskId)); + } + } + + Pair pair(final TaskId task1, final TaskId task2) { + if (task1.compareTo(task2) < 0) { + return new Pair(task1, task2); + } + return new Pair(task2, task1); + } + + private static class Pair { + private final TaskId task1; + private final TaskId task2; + + Pair(final TaskId task1, final TaskId task2) { + this.task1 = task1; + this.task2 = task2; + } + + @Override + public boolean equals(final Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + final Pair pair = (Pair) o; + return Objects.equals(task1, pair.task1) && + Objects.equals(task2, pair.task2); + } + + @Override + public int hashCode() { + return Objects.hash(task1, task2); + } + } + + + } + +} diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/StreamsAssignmentProtocolVersions.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/StreamsAssignmentProtocolVersions.java new file mode 100644 index 0000000..9bd7142 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/StreamsAssignmentProtocolVersions.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.processor.internals.assignment; + +public final class StreamsAssignmentProtocolVersions { + public static final int UNKNOWN = -1; + public static final int EARLIEST_PROBEABLE_VERSION = 3; + public static final int MIN_NAMED_TOPOLOGY_VERSION = 10; + public static final int LATEST_SUPPORTED_VERSION = 10; + /* + * Any time you modify the subscription or assignment info, you need to bump the latest supported version, unless + * the version has already been bumped within the current release cycle. + * + * Last version bump: May 2021, before 3.0 + * + * When changing the version: + * 1) Update variable highest_version in streams_upgrade_test.py::StreamsUpgradeTest.test_version_probing_upgrade + * 2) Add a unit test in SubscriptionInfoTest and/or AssignmentInfoTest + * 3) Note the date & corresponding Kafka version of this bump + */ + + private StreamsAssignmentProtocolVersions() {} +} diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/SubscriptionInfo.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/SubscriptionInfo.java new file mode 100644 index 0000000..58d2dbe --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/SubscriptionInfo.java @@ -0,0 +1,350 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.processor.internals.assignment; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; + +import org.apache.kafka.common.Uuid; +import org.apache.kafka.common.protocol.ByteBufferAccessor; +import org.apache.kafka.common.protocol.MessageUtil; +import org.apache.kafka.streams.errors.TaskAssignmentException; +import org.apache.kafka.streams.internals.generated.SubscriptionInfoData; +import org.apache.kafka.streams.internals.generated.SubscriptionInfoData.PartitionToOffsetSum; +import org.apache.kafka.streams.internals.generated.SubscriptionInfoData.TaskOffsetSum; +import org.apache.kafka.streams.processor.TaskId; +import org.apache.kafka.streams.processor.internals.Task; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; +import java.util.Collections; +import java.util.Set; +import java.util.UUID; +import java.util.stream.Collectors; + +import static org.apache.kafka.streams.processor.internals.assignment.StreamsAssignmentProtocolVersions.LATEST_SUPPORTED_VERSION; +import static org.apache.kafka.streams.processor.internals.assignment.StreamsAssignmentProtocolVersions.MIN_NAMED_TOPOLOGY_VERSION; + +public class SubscriptionInfo { + private static final Logger LOG = LoggerFactory.getLogger(SubscriptionInfo.class); + + static final int UNKNOWN = -1; + static final int MIN_VERSION_OFFSET_SUM_SUBSCRIPTION = 7; + public static final long UNKNOWN_OFFSET_SUM = -3L; + + private final SubscriptionInfoData data; + private Set prevTasksCache = null; + private Set standbyTasksCache = null; + private Map taskOffsetSumsCache = null; + + static { + // Just statically check to make sure that the generated code always stays in sync with the overall protocol + final int subscriptionInfoLatestVersion = SubscriptionInfoData.SCHEMAS.length - 1; + if (subscriptionInfoLatestVersion != LATEST_SUPPORTED_VERSION) { + throw new IllegalArgumentException( + "streams/src/main/resources/common/message/SubscriptionInfo.json needs to be updated to match the " + + "latest assignment protocol version. SubscriptionInfo only supports up to [" + + subscriptionInfoLatestVersion + "] but needs to support up to [" + LATEST_SUPPORTED_VERSION + "]."); + } + } + + private static void validateVersions(final int version, final int latestSupportedVersion) { + if (latestSupportedVersion == UNKNOWN && (version < 1 || version > 2)) { + throw new IllegalArgumentException( + "Only versions 1 and 2 are expected to use an UNKNOWN (-1) latest supported version. " + + "Got " + version + "." + ); + } else if (latestSupportedVersion != UNKNOWN && (version < 1 || version > latestSupportedVersion)) { + throw new IllegalArgumentException( + "version must be between 1 and " + latestSupportedVersion + "; was: " + version + ); + } + } + + public SubscriptionInfo(final int version, + final int latestSupportedVersion, + final UUID processId, + final String userEndPoint, + final Map taskOffsetSums, + final byte uniqueField, + final int errorCode) { + validateVersions(version, latestSupportedVersion); + final SubscriptionInfoData data = new SubscriptionInfoData(); + data.setVersion(version); + data.setProcessId(new Uuid(processId.getMostSignificantBits(), + processId.getLeastSignificantBits())); + + if (version >= 2) { + data.setUserEndPoint(userEndPoint == null + ? 
new byte[0] + : userEndPoint.getBytes(StandardCharsets.UTF_8)); + } + if (version >= 3) { + data.setLatestSupportedVersion(latestSupportedVersion); + } + if (version >= 8) { + data.setUniqueField(uniqueField); + } + if (version >= 9) { + data.setErrorCode(errorCode); + } + + this.data = data; + + if (version >= MIN_NAMED_TOPOLOGY_VERSION) { + setTaskOffsetSumDataWithNamedTopologiesFromTaskOffsetSumMap(taskOffsetSums); + } else if (version >= MIN_VERSION_OFFSET_SUM_SUBSCRIPTION) { + setTaskOffsetSumDataFromTaskOffsetSumMap(taskOffsetSums); + } else { + setPrevAndStandbySetsFromParsedTaskOffsetSumMap(taskOffsetSums); + } + } + + private SubscriptionInfo(final SubscriptionInfoData subscriptionInfoData) { + validateVersions(subscriptionInfoData.version(), subscriptionInfoData.latestSupportedVersion()); + this.data = subscriptionInfoData; + } + + public int errorCode() { + return data.errorCode(); + } + + // For version > MIN_NAMED_TOPOLOGY_VERSION + private void setTaskOffsetSumDataWithNamedTopologiesFromTaskOffsetSumMap(final Map taskOffsetSums) { + data.setTaskOffsetSums(taskOffsetSums.entrySet().stream().map(t -> { + final SubscriptionInfoData.TaskOffsetSum taskOffsetSum = new SubscriptionInfoData.TaskOffsetSum(); + final TaskId task = t.getKey(); + taskOffsetSum.setTopicGroupId(task.subtopology()); + taskOffsetSum.setPartition(task.partition()); + taskOffsetSum.setNamedTopology(task.topologyName()); + taskOffsetSum.setOffsetSum(t.getValue()); + return taskOffsetSum; + }).collect(Collectors.toList())); + } + + // For MIN_NAMED_TOPOLOGY_VERSION > version > MIN_VERSION_OFFSET_SUM_SUBSCRIPTION + private void setTaskOffsetSumDataFromTaskOffsetSumMap(final Map taskOffsetSums) { + final Map> topicGroupIdToPartitionOffsetSum = new HashMap<>(); + for (final Map.Entry taskEntry : taskOffsetSums.entrySet()) { + final TaskId task = taskEntry.getKey(); + if (task.topologyName() != null) { + throw new TaskAssignmentException("Named topologies are not compatible with older protocol versions"); + } + topicGroupIdToPartitionOffsetSum.computeIfAbsent(task.subtopology(), t -> new ArrayList<>()).add( + new SubscriptionInfoData.PartitionToOffsetSum() + .setPartition(task.partition()) + .setOffsetSum(taskEntry.getValue())); + } + + data.setTaskOffsetSums(topicGroupIdToPartitionOffsetSum.entrySet().stream().map(t -> { + final SubscriptionInfoData.TaskOffsetSum taskOffsetSum = new SubscriptionInfoData.TaskOffsetSum(); + taskOffsetSum.setTopicGroupId(t.getKey()); + taskOffsetSum.setPartitionToOffsetSum(t.getValue()); + return taskOffsetSum; + }).collect(Collectors.toList())); + } + + // For MIN_VERSION_OFFSET_SUM_SUBSCRIPTION > version + private void setPrevAndStandbySetsFromParsedTaskOffsetSumMap(final Map taskOffsetSums) { + final Set prevTasks = new HashSet<>(); + final Set standbyTasks = new HashSet<>(); + + for (final Map.Entry taskOffsetSum : taskOffsetSums.entrySet()) { + if (taskOffsetSum.getKey().topologyName() != null) { + throw new TaskAssignmentException("Named topologies are not compatible with older protocol versions"); + } + if (taskOffsetSum.getValue() == Task.LATEST_OFFSET) { + prevTasks.add(taskOffsetSum.getKey()); + } else { + standbyTasks.add(taskOffsetSum.getKey()); + } + } + + data.setPrevTasks(prevTasks.stream().map(t -> { + final SubscriptionInfoData.TaskId taskId = new SubscriptionInfoData.TaskId(); + taskId.setTopicGroupId(t.subtopology()); + taskId.setPartition(t.partition()); + return taskId; + }).collect(Collectors.toList())); + data.setStandbyTasks(standbyTasks.stream().map(t -> { + 
final SubscriptionInfoData.TaskId taskId = new SubscriptionInfoData.TaskId(); + taskId.setTopicGroupId(t.subtopology()); + taskId.setPartition(t.partition()); + return taskId; + }).collect(Collectors.toList())); + } + + public int version() { + return data.version(); + } + + public int latestSupportedVersion() { + return data.latestSupportedVersion(); + } + + public UUID processId() { + return new UUID(data.processId().getMostSignificantBits(), data.processId().getLeastSignificantBits()); + } + + public Set prevTasks() { + if (prevTasksCache == null) { + if (data.version() >= MIN_VERSION_OFFSET_SUM_SUBSCRIPTION) { + prevTasksCache = getActiveTasksFromTaskOffsetSumMap(taskOffsetSums()); + } else { + prevTasksCache = Collections.unmodifiableSet( + data.prevTasks() + .stream() + .map(t -> new TaskId(t.topicGroupId(), t.partition())) + .collect(Collectors.toSet()) + ); + } + } + return prevTasksCache; + } + + public Set standbyTasks() { + if (standbyTasksCache == null) { + if (data.version() >= MIN_VERSION_OFFSET_SUM_SUBSCRIPTION) { + standbyTasksCache = getStandbyTasksFromTaskOffsetSumMap(taskOffsetSums()); + } else { + standbyTasksCache = Collections.unmodifiableSet( + data.standbyTasks() + .stream() + .map(t -> new TaskId(t.topicGroupId(), t.partition())) + .collect(Collectors.toSet()) + ); + } + } + return standbyTasksCache; + } + + public Map taskOffsetSums() { + if (taskOffsetSumsCache == null) { + taskOffsetSumsCache = new HashMap<>(); + if (data.version() >= MIN_VERSION_OFFSET_SUM_SUBSCRIPTION) { + for (final TaskOffsetSum taskOffsetSum : data.taskOffsetSums()) { + if (data.version() >= MIN_NAMED_TOPOLOGY_VERSION) { + taskOffsetSumsCache.put( + new TaskId(taskOffsetSum.topicGroupId(), + taskOffsetSum.partition(), + taskOffsetSum.namedTopology()), + taskOffsetSum.offsetSum()); + } else { + for (final PartitionToOffsetSum partitionOffsetSum : taskOffsetSum.partitionToOffsetSum()) { + taskOffsetSumsCache.put( + new TaskId(taskOffsetSum.topicGroupId(), + partitionOffsetSum.partition()), + partitionOffsetSum.offsetSum() + ); + } + } + } + } else { + prevTasks().forEach(taskId -> taskOffsetSumsCache.put(taskId, Task.LATEST_OFFSET)); + standbyTasks().forEach(taskId -> taskOffsetSumsCache.put(taskId, UNKNOWN_OFFSET_SUM)); + } + } + return taskOffsetSumsCache; + } + + public String userEndPoint() { + return data.userEndPoint() == null || data.userEndPoint().length == 0 + ? 
null + : new String(data.userEndPoint(), StandardCharsets.UTF_8); + } + + public static Set getActiveTasksFromTaskOffsetSumMap(final Map taskOffsetSums) { + return taskOffsetSumMapToTaskSet(taskOffsetSums, true); + } + + public static Set getStandbyTasksFromTaskOffsetSumMap(final Map taskOffsetSums) { + return taskOffsetSumMapToTaskSet(taskOffsetSums, false); + } + + private static Set taskOffsetSumMapToTaskSet(final Map taskOffsetSums, + final boolean getActiveTasks) { + return taskOffsetSums.entrySet().stream() + .filter(t -> getActiveTasks == (t.getValue() == Task.LATEST_OFFSET)) + .map(Map.Entry::getKey) + .collect(Collectors.toSet()); + } + + /** + * @throws TaskAssignmentException if method fails to encode the data + */ + public ByteBuffer encode() { + if (data.version() > LATEST_SUPPORTED_VERSION) { + throw new IllegalStateException( + "Should never try to encode a SubscriptionInfo with version [" + + data.version() + "] > LATEST_SUPPORTED_VERSION [" + LATEST_SUPPORTED_VERSION + "]" + ); + } else return MessageUtil.toByteBuffer(data, (short) data.version()); + } + + /** + * @throws TaskAssignmentException if method fails to decode the data + */ + public static SubscriptionInfo decode(final ByteBuffer data) { + data.rewind(); + final int version = data.getInt(); + + if (version > LATEST_SUPPORTED_VERSION) { + // in this special case, we only rely on the version and latest version, + // + final int latestSupportedVersion = data.getInt(); + final SubscriptionInfoData subscriptionInfoData = new SubscriptionInfoData(); + subscriptionInfoData.setVersion(version); + subscriptionInfoData.setLatestSupportedVersion(latestSupportedVersion); + LOG.info("Unable to decode subscription data: used version: {}; latest supported version: {}", + version, + latestSupportedVersion + ); + return new SubscriptionInfo(subscriptionInfoData); + } else { + data.rewind(); + final ByteBufferAccessor accessor = new ByteBufferAccessor(data); + final SubscriptionInfoData subscriptionInfoData = new SubscriptionInfoData(accessor, (short) version); + return new SubscriptionInfo(subscriptionInfoData); + } + } + + @Override + public int hashCode() { + return data.hashCode(); + } + + @Override + public boolean equals(final Object o) { + if (o instanceof SubscriptionInfo) { + final SubscriptionInfo other = (SubscriptionInfo) o; + return data.equals(other.data); + } else { + return false; + } + } + + @Override + public String toString() { + return data.toString(); + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/TaskAssignor.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/TaskAssignor.java new file mode 100644 index 0000000..aeb2192 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/TaskAssignor.java @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.processor.internals.assignment; + +import org.apache.kafka.streams.processor.TaskId; + +import java.util.Map; +import java.util.Set; +import java.util.UUID; + +public interface TaskAssignor { + /** + * @return whether the generated assignment requires a followup probing rebalance to satisfy all conditions + */ + boolean assign(Map clients, + Set allTaskIds, + Set statefulTaskIds, + AssignorConfiguration.AssignmentConfigs configs); +} diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/TaskMovement.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/TaskMovement.java new file mode 100644 index 0000000..cbfa3da --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/assignment/TaskMovement.java @@ -0,0 +1,238 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.processor.internals.assignment; + +import org.apache.kafka.streams.processor.TaskId; + +import java.util.Collections; +import java.util.Comparator; +import java.util.Map; +import java.util.PriorityQueue; +import java.util.Queue; +import java.util.Set; +import java.util.SortedSet; +import java.util.TreeSet; +import java.util.UUID; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.BiFunction; + +import static java.util.Arrays.asList; +import static java.util.Objects.requireNonNull; + +final class TaskMovement { + private final TaskId task; + private final UUID destination; + private final SortedSet caughtUpClients; + + private TaskMovement(final TaskId task, final UUID destination, final SortedSet caughtUpClients) { + this.task = task; + this.destination = destination; + this.caughtUpClients = caughtUpClients; + + if (caughtUpClients == null || caughtUpClients.isEmpty()) { + throw new IllegalStateException("Should not attempt to move a task if no caught up clients exist"); + } + } + + private TaskId task() { + return task; + } + + private int numCaughtUpClients() { + return caughtUpClients.size(); + } + + private static boolean taskIsNotCaughtUpOnClientAndOtherCaughtUpClientsExist(final TaskId task, + final UUID client, + final Map> tasksToCaughtUpClients) { + return !taskIsCaughtUpOnClientOrNoCaughtUpClientsExist(task, client, tasksToCaughtUpClients); + } + + private static boolean taskIsCaughtUpOnClientOrNoCaughtUpClientsExist(final TaskId task, + final UUID client, + final Map> tasksToCaughtUpClients) { + final Set caughtUpClients = requireNonNull(tasksToCaughtUpClients.get(task), "uninitialized set"); + return caughtUpClients.isEmpty() || caughtUpClients.contains(client); + } + + static int assignActiveTaskMovements(final Map> tasksToCaughtUpClients, + final Map clientStates, + final Map> warmups, + final AtomicInteger remainingWarmupReplicas) { + final BiFunction caughtUpPredicate = + (client, task) -> taskIsCaughtUpOnClientOrNoCaughtUpClientsExist(task, client, tasksToCaughtUpClients); + + final ConstrainedPrioritySet caughtUpClientsByTaskLoad = new ConstrainedPrioritySet( + caughtUpPredicate, + client -> clientStates.get(client).assignedTaskLoad() + ); + + final Queue taskMovements = new PriorityQueue<>( + Comparator.comparing(TaskMovement::numCaughtUpClients).thenComparing(TaskMovement::task) + ); + + for (final Map.Entry clientStateEntry : clientStates.entrySet()) { + final UUID client = clientStateEntry.getKey(); + final ClientState state = clientStateEntry.getValue(); + for (final TaskId task : state.activeTasks()) { + // if the desired client is not caught up, and there is another client that _is_ caught up, then + // we schedule a movement, so we can move the active task to the caught-up client. We'll try to + // assign a warm-up to the desired client so that we can move it later on. 
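+                // (illustrative: if task 0_0 is currently active on a client that is not within
+                //  acceptable.recovery.lag of its changelog end but some other client is, the active
+                //  is handed to the caught-up client now, and this client may receive a warmup standby
+                //  so that a later probing rebalance can hand 0_0 back once it has caught up)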
+ if (taskIsNotCaughtUpOnClientAndOtherCaughtUpClientsExist(task, client, tasksToCaughtUpClients)) { + taskMovements.add(new TaskMovement(task, client, tasksToCaughtUpClients.get(task))); + } + } + caughtUpClientsByTaskLoad.offer(client); + } + + final int movementsNeeded = taskMovements.size(); + + for (final TaskMovement movement : taskMovements) { + final UUID standbySourceClient = caughtUpClientsByTaskLoad.poll( + movement.task, + c -> clientStates.get(c).hasStandbyTask(movement.task) + ); + if (standbySourceClient == null) { + // there's not a caught-up standby available to take over the task, so we'll schedule a warmup instead + final UUID sourceClient = requireNonNull( + caughtUpClientsByTaskLoad.poll(movement.task), + "Tried to move task to caught-up client but none exist" + ); + + moveActiveAndTryToWarmUp( + remainingWarmupReplicas, + movement.task, + clientStates.get(sourceClient), + clientStates.get(movement.destination), + warmups.computeIfAbsent(movement.destination, x -> new TreeSet<>()) + ); + caughtUpClientsByTaskLoad.offerAll(asList(sourceClient, movement.destination)); + } else { + // we found a candidate to trade standby/active state with our destination, so we don't need a warmup + swapStandbyAndActive( + movement.task, + clientStates.get(standbySourceClient), + clientStates.get(movement.destination) + ); + caughtUpClientsByTaskLoad.offerAll(asList(standbySourceClient, movement.destination)); + } + } + + return movementsNeeded; + } + + static int assignStandbyTaskMovements(final Map> tasksToCaughtUpClients, + final Map clientStates, + final AtomicInteger remainingWarmupReplicas, + final Map> warmups) { + final BiFunction caughtUpPredicate = + (client, task) -> taskIsCaughtUpOnClientOrNoCaughtUpClientsExist(task, client, tasksToCaughtUpClients); + + final ConstrainedPrioritySet caughtUpClientsByTaskLoad = new ConstrainedPrioritySet( + caughtUpPredicate, + client -> clientStates.get(client).assignedTaskLoad() + ); + + final Queue taskMovements = new PriorityQueue<>( + Comparator.comparing(TaskMovement::numCaughtUpClients).thenComparing(TaskMovement::task) + ); + + for (final Map.Entry clientStateEntry : clientStates.entrySet()) { + final UUID destination = clientStateEntry.getKey(); + final ClientState state = clientStateEntry.getValue(); + for (final TaskId task : state.standbyTasks()) { + if (warmups.getOrDefault(destination, Collections.emptySet()).contains(task)) { + // this is a warmup, so we won't move it. + } else if (taskIsNotCaughtUpOnClientAndOtherCaughtUpClientsExist(task, destination, tasksToCaughtUpClients)) { + // if the desired client is not caught up, and there is another client that _is_ caught up, then + // we schedule a movement, so we can move the active task to the caught-up client. We'll try to + // assign a warm-up to the desired client so that we can move it later on. + taskMovements.add(new TaskMovement(task, destination, tasksToCaughtUpClients.get(task))); + } + } + caughtUpClientsByTaskLoad.offer(destination); + } + + int movementsNeeded = 0; + + for (final TaskMovement movement : taskMovements) { + final UUID sourceClient = caughtUpClientsByTaskLoad.poll( + movement.task, + clientId -> !clientStates.get(clientId).hasAssignedTask(movement.task) + ); + + if (sourceClient == null) { + // then there's no caught-up client that doesn't already have a copy of this task, so there's + // nowhere to move it. 
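+                // (In that case the standby is simply left where the balanced assignment put it; the
+                // not-caught-up client presumably restores it from the changelog on its own, and no
+                // movement is counted towards movementsNeeded.)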
+ } else { + moveStandbyAndTryToWarmUp( + remainingWarmupReplicas, + movement.task, + clientStates.get(sourceClient), + clientStates.get(movement.destination) + ); + caughtUpClientsByTaskLoad.offerAll(asList(sourceClient, movement.destination)); + movementsNeeded++; + } + } + + return movementsNeeded; + } + + private static void moveActiveAndTryToWarmUp(final AtomicInteger remainingWarmupReplicas, + final TaskId task, + final ClientState sourceClientState, + final ClientState destinationClientState, + final Set warmups) { + sourceClientState.assignActive(task); + + if (remainingWarmupReplicas.getAndDecrement() > 0) { + destinationClientState.unassignActive(task); + destinationClientState.assignStandby(task); + warmups.add(task); + } else { + // we have no more standbys or warmups to hand out, so we have to try and move it + // to the destination in a follow-on rebalance + destinationClientState.unassignActive(task); + } + } + + private static void moveStandbyAndTryToWarmUp(final AtomicInteger remainingWarmupReplicas, + final TaskId task, + final ClientState sourceClientState, + final ClientState destinationClientState) { + sourceClientState.assignStandby(task); + + if (remainingWarmupReplicas.getAndDecrement() > 0) { + // Then we can leave it also assigned to the destination as a warmup + } else { + // we have no more warmups to hand out, so we have to try and move it + // to the destination in a follow-on rebalance + destinationClientState.unassignStandby(task); + } + } + + private static void swapStandbyAndActive(final TaskId task, + final ClientState sourceClientState, + final ClientState destinationClientState) { + sourceClientState.unassignStandby(task); + sourceClientState.assignActive(task); + destinationClientState.unassignActive(task); + destinationClientState.assignStandby(task); + } + +} diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/metrics/ProcessorNodeMetrics.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/metrics/ProcessorNodeMetrics.java new file mode 100644 index 0000000..231d9a6 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/metrics/ProcessorNodeMetrics.java @@ -0,0 +1,210 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.processor.internals.metrics; + +import org.apache.kafka.common.metrics.Sensor; +import org.apache.kafka.common.metrics.Sensor.RecordingLevel; + +import java.util.Map; + +import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.PROCESSOR_NODE_LEVEL_GROUP; +import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.RECORD_E2E_LATENCY; +import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.RECORD_E2E_LATENCY_AVG_DESCRIPTION; +import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.RECORD_E2E_LATENCY_MAX_DESCRIPTION; +import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.RECORD_E2E_LATENCY_MIN_DESCRIPTION; +import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.ROLLUP_VALUE; +import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.TASK_LEVEL_GROUP; +import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.TOTAL_DESCRIPTION; +import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.addAvgAndMinAndMaxToSensor; +import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.addInvocationRateAndCountToSensor; + +public class ProcessorNodeMetrics { + private ProcessorNodeMetrics() {} + + private static final String RATE_DESCRIPTION_PREFIX = "The average number of "; + private static final String RATE_DESCRIPTION_SUFFIX = " per second"; + + private static final String SUPPRESSION_EMIT = "suppression-emit"; + private static final String SUPPRESSION_EMIT_DESCRIPTION = "emitted records from the suppression buffer"; + private static final String SUPPRESSION_EMIT_TOTAL_DESCRIPTION = TOTAL_DESCRIPTION + SUPPRESSION_EMIT_DESCRIPTION; + private static final String SUPPRESSION_EMIT_RATE_DESCRIPTION = + RATE_DESCRIPTION_PREFIX + SUPPRESSION_EMIT_DESCRIPTION + RATE_DESCRIPTION_SUFFIX; + + private static final String IDEMPOTENT_UPDATE_SKIP = "idempotent-update-skip"; + private static final String IDEMPOTENT_UPDATE_SKIP_DESCRIPTION = "skipped idempotent updates"; + private static final String IDEMPOTENT_UPDATE_SKIP_TOTAL_DESCRIPTION = TOTAL_DESCRIPTION + IDEMPOTENT_UPDATE_SKIP_DESCRIPTION; + private static final String IDEMPOTENT_UPDATE_SKIP_RATE_DESCRIPTION = + RATE_DESCRIPTION_PREFIX + IDEMPOTENT_UPDATE_SKIP_DESCRIPTION + RATE_DESCRIPTION_SUFFIX; + + private static final String PROCESS = "process"; + private static final String PROCESS_DESCRIPTION = "calls to process"; + private static final String PROCESS_TOTAL_DESCRIPTION = TOTAL_DESCRIPTION + PROCESS_DESCRIPTION; + private static final String PROCESS_RATE_DESCRIPTION = + RATE_DESCRIPTION_PREFIX + PROCESS_DESCRIPTION + RATE_DESCRIPTION_SUFFIX; + + private static final String FORWARD = "forward"; + private static final String FORWARD_DESCRIPTION = "calls to forward"; + private static final String FORWARD_TOTAL_DESCRIPTION = TOTAL_DESCRIPTION + FORWARD_DESCRIPTION; + private static final String FORWARD_RATE_DESCRIPTION = + RATE_DESCRIPTION_PREFIX + FORWARD_DESCRIPTION + RATE_DESCRIPTION_SUFFIX; + + public static Sensor suppressionEmitSensor(final String threadId, + final String taskId, + final String processorNodeId, + final StreamsMetricsImpl streamsMetrics) { + return throughputSensor( + threadId, + taskId, + processorNodeId, + SUPPRESSION_EMIT, + SUPPRESSION_EMIT_RATE_DESCRIPTION, + SUPPRESSION_EMIT_TOTAL_DESCRIPTION, + 
RecordingLevel.DEBUG, + streamsMetrics + ); + } + + public static Sensor skippedIdempotentUpdatesSensor(final String threadId, + final String taskId, + final String processorNodeId, + final StreamsMetricsImpl streamsMetrics) { + return throughputSensor( + threadId, + taskId, + processorNodeId, + IDEMPOTENT_UPDATE_SKIP, + IDEMPOTENT_UPDATE_SKIP_RATE_DESCRIPTION, + IDEMPOTENT_UPDATE_SKIP_TOTAL_DESCRIPTION, + RecordingLevel.DEBUG, + streamsMetrics + ); + } + + public static Sensor processAtSourceSensor(final String threadId, + final String taskId, + final String processorNodeId, + final StreamsMetricsImpl streamsMetrics) { + final Sensor parentSensor = streamsMetrics.taskLevelSensor(threadId, taskId, PROCESS, RecordingLevel.DEBUG); + addInvocationRateAndCountToSensor( + parentSensor, + TASK_LEVEL_GROUP, + streamsMetrics.taskLevelTagMap(threadId, taskId), + PROCESS, + PROCESS_RATE_DESCRIPTION, + PROCESS_TOTAL_DESCRIPTION + ); + return throughputSensor( + threadId, + taskId, + processorNodeId, + PROCESS, + PROCESS_RATE_DESCRIPTION, + PROCESS_TOTAL_DESCRIPTION, + RecordingLevel.DEBUG, + streamsMetrics, + parentSensor + ); + } + + public static Sensor forwardSensor(final String threadId, + final String taskId, + final String processorNodeId, + final StreamsMetricsImpl streamsMetrics) { + final Sensor parentSensor = throughputParentSensor( + threadId, + taskId, + FORWARD, + FORWARD_RATE_DESCRIPTION, + FORWARD_TOTAL_DESCRIPTION, + RecordingLevel.DEBUG, + streamsMetrics + ); + return throughputSensor( + threadId, + taskId, + processorNodeId, + FORWARD, + FORWARD_RATE_DESCRIPTION, + FORWARD_TOTAL_DESCRIPTION, + RecordingLevel.DEBUG, + streamsMetrics, + parentSensor + ); + } + + public static Sensor e2ELatencySensor(final String threadId, + final String taskId, + final String processorNodeId, + final StreamsMetricsImpl streamsMetrics) { + final String sensorName = processorNodeId + "-" + RECORD_E2E_LATENCY; + final Sensor sensor = streamsMetrics.nodeLevelSensor(threadId, taskId, processorNodeId, sensorName, RecordingLevel.INFO); + final Map tagMap = streamsMetrics.nodeLevelTagMap(threadId, taskId, processorNodeId); + addAvgAndMinAndMaxToSensor( + sensor, + PROCESSOR_NODE_LEVEL_GROUP, + tagMap, + RECORD_E2E_LATENCY, + RECORD_E2E_LATENCY_AVG_DESCRIPTION, + RECORD_E2E_LATENCY_MIN_DESCRIPTION, + RECORD_E2E_LATENCY_MAX_DESCRIPTION + ); + return sensor; + } + + private static Sensor throughputParentSensor(final String threadId, + final String taskId, + final String metricNamePrefix, + final String descriptionOfRate, + final String descriptionOfCount, + final RecordingLevel recordingLevel, + final StreamsMetricsImpl streamsMetrics) { + final Sensor sensor = streamsMetrics.taskLevelSensor(threadId, taskId, metricNamePrefix, recordingLevel); + final Map parentTagMap = streamsMetrics.nodeLevelTagMap(threadId, taskId, ROLLUP_VALUE); + addInvocationRateAndCountToSensor( + sensor, + PROCESSOR_NODE_LEVEL_GROUP, + parentTagMap, + metricNamePrefix, + descriptionOfRate, + descriptionOfCount + ); + return sensor; + } + + private static Sensor throughputSensor(final String threadId, + final String taskId, + final String processorNodeId, + final String metricNamePrefix, + final String descriptionOfRate, + final String descriptionOfCount, + final RecordingLevel recordingLevel, + final StreamsMetricsImpl streamsMetrics, + final Sensor... 
parentSensors) { + final Sensor sensor = + streamsMetrics.nodeLevelSensor(threadId, taskId, processorNodeId, metricNamePrefix, recordingLevel, parentSensors); + final Map tagMap = streamsMetrics.nodeLevelTagMap(threadId, taskId, processorNodeId); + addInvocationRateAndCountToSensor( + sensor, + PROCESSOR_NODE_LEVEL_GROUP, + tagMap, + metricNamePrefix, + descriptionOfRate, + descriptionOfCount + ); + return sensor; + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/metrics/StreamsMetricsImpl.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/metrics/StreamsMetricsImpl.java new file mode 100644 index 0000000..dea2399 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/metrics/StreamsMetricsImpl.java @@ -0,0 +1,866 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.processor.internals.metrics; + +import org.apache.kafka.common.Metric; +import org.apache.kafka.common.MetricName; +import org.apache.kafka.common.metrics.Gauge; +import org.apache.kafka.common.metrics.MetricConfig; +import org.apache.kafka.common.metrics.Metrics; +import org.apache.kafka.common.metrics.Sensor; +import org.apache.kafka.common.metrics.Sensor.RecordingLevel; +import org.apache.kafka.common.metrics.stats.Avg; +import org.apache.kafka.common.metrics.stats.CumulativeCount; +import org.apache.kafka.common.metrics.stats.CumulativeSum; +import org.apache.kafka.common.metrics.stats.Max; +import org.apache.kafka.common.metrics.stats.Min; +import org.apache.kafka.common.metrics.stats.Rate; +import org.apache.kafka.common.metrics.stats.Value; +import org.apache.kafka.common.metrics.stats.WindowedCount; +import org.apache.kafka.common.metrics.stats.WindowedSum; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.streams.StreamsMetrics; +import org.apache.kafka.streams.state.internals.metrics.RocksDBMetricsRecordingTrigger; + +import java.util.Collections; +import java.util.Deque; +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.LinkedList; +import java.util.Map; +import java.util.Objects; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.TimeUnit; +import java.util.function.Supplier; + +public class StreamsMetricsImpl implements StreamsMetrics { + + public enum Version { + LATEST + } + + static class ImmutableMetricValue implements Gauge { + private final T value; + + public ImmutableMetricValue(final T value) { + this.value = value; + } + + @Override + public T value(final MetricConfig config, final long now) { + return value; + } + + @Override + public boolean equals(final Object o) { + if (this == o) { + return true; + } + if (o == 
null || getClass() != o.getClass()) { + return false; + } + final ImmutableMetricValue that = (ImmutableMetricValue) o; + return Objects.equals(value, that.value); + } + + @Override + public int hashCode() { + return Objects.hash(value); + } + } + + private final Metrics metrics; + private final Map parentSensors; + private final String clientId; + + private final Version version; + private final Deque clientLevelMetrics = new LinkedList<>(); + private final Deque clientLevelSensors = new LinkedList<>(); + private final Map> threadLevelMetrics = new HashMap<>(); + private final Map> threadLevelSensors = new HashMap<>(); + private final Map> taskLevelSensors = new HashMap<>(); + private final Map> nodeLevelSensors = new HashMap<>(); + private final Map> cacheLevelSensors = new HashMap<>(); + private final ConcurrentMap> storeLevelSensors = new ConcurrentHashMap<>(); + private final ConcurrentMap> storeLevelMetrics = new ConcurrentHashMap<>(); + + private final RocksDBMetricsRecordingTrigger rocksDBMetricsRecordingTrigger; + + private static final String SENSOR_PREFIX_DELIMITER = "."; + private static final String SENSOR_NAME_DELIMITER = ".s."; + private static final String SENSOR_TASK_LABEL = "task"; + private static final String SENSOR_NODE_LABEL = "node"; + private static final String SENSOR_CACHE_LABEL = "cache"; + private static final String SENSOR_STORE_LABEL = "store"; + private static final String SENSOR_ENTITY_LABEL = "entity"; + private static final String SENSOR_EXTERNAL_LABEL = "external"; + private static final String SENSOR_INTERNAL_LABEL = "internal"; + + public static final String CLIENT_ID_TAG = "client-id"; + public static final String THREAD_ID_TAG = "thread-id"; + public static final String TASK_ID_TAG = "task-id"; + public static final String PROCESSOR_NODE_ID_TAG = "processor-node-id"; + public static final String STORE_ID_TAG = "state-id"; + public static final String RECORD_CACHE_ID_TAG = "record-cache-id"; + + public static final String ROLLUP_VALUE = "all"; + + public static final String LATENCY_SUFFIX = "-latency"; + public static final String RECORDS_SUFFIX = "-records"; + public static final String AVG_SUFFIX = "-avg"; + public static final String MAX_SUFFIX = "-max"; + public static final String MIN_SUFFIX = "-min"; + public static final String RATE_SUFFIX = "-rate"; + public static final String TOTAL_SUFFIX = "-total"; + public static final String RATIO_SUFFIX = "-ratio"; + + public static final String GROUP_PREFIX_WO_DELIMITER = "stream"; + public static final String GROUP_PREFIX = GROUP_PREFIX_WO_DELIMITER + "-"; + public static final String GROUP_SUFFIX = "-metrics"; + public static final String CLIENT_LEVEL_GROUP = GROUP_PREFIX_WO_DELIMITER + GROUP_SUFFIX; + public static final String THREAD_LEVEL_GROUP = GROUP_PREFIX + "thread" + GROUP_SUFFIX; + public static final String TASK_LEVEL_GROUP = GROUP_PREFIX + "task" + GROUP_SUFFIX; + public static final String PROCESSOR_NODE_LEVEL_GROUP = GROUP_PREFIX + "processor-node" + GROUP_SUFFIX; + public static final String STATE_STORE_LEVEL_GROUP = GROUP_PREFIX + "state" + GROUP_SUFFIX; + public static final String CACHE_LEVEL_GROUP = GROUP_PREFIX + "record-cache" + GROUP_SUFFIX; + + public static final String OPERATIONS = " operations"; + public static final String TOTAL_DESCRIPTION = "The total number of "; + public static final String RATE_DESCRIPTION = "The average per-second number of "; + public static final String AVG_LATENCY_DESCRIPTION = "The average latency of "; + public static final String 
MAX_LATENCY_DESCRIPTION = "The maximum latency of "; + public static final String RATE_DESCRIPTION_PREFIX = "The average number of "; + public static final String RATE_DESCRIPTION_SUFFIX = " per second"; + + public static final String RECORD_E2E_LATENCY = "record-e2e-latency"; + public static final String RECORD_E2E_LATENCY_DESCRIPTION_SUFFIX = + "end-to-end latency of a record, measuring by comparing the record timestamp with the " + + "system time when it has been fully processed by the node"; + public static final String RECORD_E2E_LATENCY_AVG_DESCRIPTION = "The average " + RECORD_E2E_LATENCY_DESCRIPTION_SUFFIX; + public static final String RECORD_E2E_LATENCY_MIN_DESCRIPTION = "The minimum " + RECORD_E2E_LATENCY_DESCRIPTION_SUFFIX; + public static final String RECORD_E2E_LATENCY_MAX_DESCRIPTION = "The maximum " + RECORD_E2E_LATENCY_DESCRIPTION_SUFFIX; + + public StreamsMetricsImpl(final Metrics metrics, + final String clientId, + final String builtInMetricsVersion, + final Time time) { + Objects.requireNonNull(metrics, "Metrics cannot be null"); + Objects.requireNonNull(builtInMetricsVersion, "Built-in metrics version cannot be null"); + this.metrics = metrics; + this.clientId = clientId; + version = Version.LATEST; + rocksDBMetricsRecordingTrigger = new RocksDBMetricsRecordingTrigger(time); + + this.parentSensors = new HashMap<>(); + } + + public Version version() { + return version; + } + + public RocksDBMetricsRecordingTrigger rocksDBMetricsRecordingTrigger() { + return rocksDBMetricsRecordingTrigger; + } + + public void addClientLevelImmutableMetric(final String name, + final String description, + final RecordingLevel recordingLevel, + final T value) { + final MetricName metricName = metrics.metricName(name, CLIENT_LEVEL_GROUP, description, clientLevelTagMap()); + final MetricConfig metricConfig = new MetricConfig().recordLevel(recordingLevel); + synchronized (clientLevelMetrics) { + metrics.addMetric(metricName, metricConfig, new ImmutableMetricValue<>(value)); + clientLevelMetrics.push(metricName); + } + } + + public void addClientLevelMutableMetric(final String name, + final String description, + final RecordingLevel recordingLevel, + final Gauge valueProvider) { + final MetricName metricName = metrics.metricName(name, CLIENT_LEVEL_GROUP, description, clientLevelTagMap()); + final MetricConfig metricConfig = new MetricConfig().recordLevel(recordingLevel); + synchronized (clientLevelMetrics) { + metrics.addMetric(metricName, metricConfig, valueProvider); + clientLevelMetrics.push(metricName); + } + } + + public void addThreadLevelImmutableMetric(final String name, + final String description, + final String threadId, + final T value) { + final MetricName metricName = metrics.metricName( + name, THREAD_LEVEL_GROUP, description, threadLevelTagMap(threadId)); + synchronized (threadLevelMetrics) { + threadLevelMetrics.computeIfAbsent( + threadSensorPrefix(threadId), + tid -> new LinkedList<>() + ).add(metricName); + metrics.addMetric(metricName, new ImmutableMetricValue<>(value)); + } + } + + public void addThreadLevelMutableMetric(final String name, + final String description, + final String threadId, + final Gauge valueProvider) { + final MetricName metricName = metrics.metricName( + name, THREAD_LEVEL_GROUP, description, threadLevelTagMap(threadId)); + synchronized (threadLevelMetrics) { + threadLevelMetrics.computeIfAbsent( + threadSensorPrefix(threadId), + tid -> new LinkedList<>() + ).add(metricName); + metrics.addMetric(metricName, valueProvider); + } + } + + public final Sensor 
clientLevelSensor(final String sensorName, + final RecordingLevel recordingLevel, + final Sensor... parents) { + synchronized (clientLevelSensors) { + final String fullSensorName = CLIENT_LEVEL_GROUP + SENSOR_NAME_DELIMITER + sensorName; + final Sensor sensor = metrics.getSensor(fullSensorName); + if (sensor == null) { + clientLevelSensors.push(fullSensorName); + return metrics.sensor(fullSensorName, recordingLevel, parents); + } + return sensor; + } + } + + public final Sensor threadLevelSensor(final String threadId, + final String sensorName, + final RecordingLevel recordingLevel, + final Sensor... parents) { + final String key = threadSensorPrefix(threadId); + synchronized (threadLevelSensors) { + return getSensors(threadLevelSensors, sensorName, key, recordingLevel, parents); + } + } + + private String threadSensorPrefix(final String threadId) { + return SENSOR_INTERNAL_LABEL + SENSOR_PREFIX_DELIMITER + threadId; + } + + public Map clientLevelTagMap() { + final Map tagMap = new LinkedHashMap<>(); + tagMap.put(CLIENT_ID_TAG, clientId); + return tagMap; + } + + public Map threadLevelTagMap(final String threadId) { + final Map tagMap = new LinkedHashMap<>(); + tagMap.put(THREAD_ID_TAG, threadId); + return tagMap; + } + + public final void removeAllClientLevelSensorsAndMetrics() { + removeAllClientLevelSensors(); + removeAllClientLevelMetrics(); + } + + private void removeAllClientLevelMetrics() { + synchronized (clientLevelMetrics) { + while (!clientLevelMetrics.isEmpty()) { + metrics.removeMetric(clientLevelMetrics.pop()); + } + } + } + + private void removeAllClientLevelSensors() { + synchronized (clientLevelSensors) { + while (!clientLevelSensors.isEmpty()) { + metrics.removeSensor(clientLevelSensors.pop()); + } + } + } + + public final void removeAllThreadLevelSensors(final String threadId) { + final String key = threadSensorPrefix(threadId); + synchronized (threadLevelSensors) { + final Deque sensors = threadLevelSensors.remove(key); + while (sensors != null && !sensors.isEmpty()) { + metrics.removeSensor(sensors.pop()); + } + } + } + + public final void removeAllThreadLevelMetrics(final String threadId) { + synchronized (threadLevelMetrics) { + final Deque names = threadLevelMetrics.remove(threadSensorPrefix(threadId)); + while (names != null && !names.isEmpty()) { + metrics.removeMetric(names.pop()); + } + } + } + + public Map taskLevelTagMap(final String threadId, final String taskId) { + final Map tagMap = threadLevelTagMap(threadId); + tagMap.put(TASK_ID_TAG, taskId); + return tagMap; + } + + public Map nodeLevelTagMap(final String threadId, + final String taskName, + final String processorNodeName) { + final Map tagMap = taskLevelTagMap(threadId, taskName); + tagMap.put(PROCESSOR_NODE_ID_TAG, processorNodeName); + return tagMap; + } + + public Map storeLevelTagMap(final String taskName, + final String storeType, + final String storeName) { + final Map tagMap = taskLevelTagMap(Thread.currentThread().getName(), taskName); + tagMap.put(storeType + "-" + STORE_ID_TAG, storeName); + return tagMap; + } + + public final Sensor taskLevelSensor(final String threadId, + final String taskId, + final String sensorName, + final RecordingLevel recordingLevel, + final Sensor... 
parents) { + final String key = taskSensorPrefix(threadId, taskId); + synchronized (taskLevelSensors) { + return getSensors(taskLevelSensors, sensorName, key, recordingLevel, parents); + } + } + + public final void removeAllTaskLevelSensors(final String threadId, final String taskId) { + final String key = taskSensorPrefix(threadId, taskId); + synchronized (taskLevelSensors) { + final Deque sensors = taskLevelSensors.remove(key); + while (sensors != null && !sensors.isEmpty()) { + metrics.removeSensor(sensors.pop()); + } + } + } + + private String taskSensorPrefix(final String threadId, final String taskId) { + return threadSensorPrefix(threadId) + SENSOR_PREFIX_DELIMITER + SENSOR_TASK_LABEL + SENSOR_PREFIX_DELIMITER + + taskId; + } + + public Sensor nodeLevelSensor(final String threadId, + final String taskId, + final String processorNodeName, + final String sensorName, + final Sensor.RecordingLevel recordingLevel, + final Sensor... parents) { + final String key = nodeSensorPrefix(threadId, taskId, processorNodeName); + synchronized (nodeLevelSensors) { + return getSensors(nodeLevelSensors, sensorName, key, recordingLevel, parents); + } + } + + public final void removeAllNodeLevelSensors(final String threadId, + final String taskId, + final String processorNodeName) { + final String key = nodeSensorPrefix(threadId, taskId, processorNodeName); + synchronized (nodeLevelSensors) { + final Deque sensors = nodeLevelSensors.remove(key); + while (sensors != null && !sensors.isEmpty()) { + metrics.removeSensor(sensors.pop()); + } + } + } + + private String nodeSensorPrefix(final String threadId, final String taskId, final String processorNodeName) { + return taskSensorPrefix(threadId, taskId) + + SENSOR_PREFIX_DELIMITER + SENSOR_NODE_LABEL + SENSOR_PREFIX_DELIMITER + processorNodeName; + } + + public Sensor cacheLevelSensor(final String threadId, + final String taskName, + final String storeName, + final String sensorName, + final Sensor.RecordingLevel recordingLevel, + final Sensor... parents) { + final String key = cacheSensorPrefix(threadId, taskName, storeName); + synchronized (cacheLevelSensors) { + return getSensors(cacheLevelSensors, sensorName, key, recordingLevel, parents); + } + } + + public Map cacheLevelTagMap(final String threadId, + final String taskId, + final String storeName) { + final Map tagMap = new LinkedHashMap<>(); + tagMap.put(THREAD_ID_TAG, threadId); + tagMap.put(TASK_ID_TAG, taskId); + tagMap.put(RECORD_CACHE_ID_TAG, storeName); + return tagMap; + } + + public final void removeAllCacheLevelSensors(final String threadId, final String taskId, final String cacheName) { + final String key = cacheSensorPrefix(threadId, taskId, cacheName); + synchronized (cacheLevelSensors) { + final Deque strings = cacheLevelSensors.remove(key); + while (strings != null && !strings.isEmpty()) { + metrics.removeSensor(strings.pop()); + } + } + } + + private String cacheSensorPrefix(final String threadId, final String taskId, final String cacheName) { + return taskSensorPrefix(threadId, taskId) + + SENSOR_PREFIX_DELIMITER + SENSOR_CACHE_LABEL + SENSOR_PREFIX_DELIMITER + cacheName; + } + + public final Sensor storeLevelSensor(final String taskId, + final String storeName, + final String sensorName, + final RecordingLevel recordingLevel, + final Sensor... 
parents) { + final String key = storeSensorPrefix(Thread.currentThread().getName(), taskId, storeName); + // since the keys in the map storeLevelSensors contain the name of the current thread and threads only + // access keys in which their name is contained, the value in the maps do not need to be thread safe + // and we can use a LinkedList here. + // TODO: In future, we could use thread local maps since each thread will exclusively access the set of keys + // that contain its name. Similar is true for the other metric levels. Thread-level metrics need some + // special attention, since they are created before the thread is constructed. The creation of those + // metrics could be moved into the run() method of the thread. + return getSensors(storeLevelSensors, sensorName, key, recordingLevel, parents); + } + + public void addStoreLevelMutableMetric(final String taskId, + final String metricsScope, + final String storeName, + final String name, + final String description, + final RecordingLevel recordingLevel, + final Gauge valueProvider) { + final MetricName metricName = metrics.metricName( + name, + STATE_STORE_LEVEL_GROUP, + description, + storeLevelTagMap(taskId, metricsScope, storeName) + ); + if (metrics.metric(metricName) == null) { + final MetricConfig metricConfig = new MetricConfig().recordLevel(recordingLevel); + final String key = storeSensorPrefix(Thread.currentThread().getName(), taskId, storeName); + metrics.addMetric(metricName, metricConfig, valueProvider); + storeLevelMetrics.computeIfAbsent(key, ignored -> new LinkedList<>()).push(metricName); + } + } + + public final void removeAllStoreLevelSensorsAndMetrics(final String taskId, + final String storeName) { + final String threadId = Thread.currentThread().getName(); + removeAllStoreLevelSensors(threadId, taskId, storeName); + removeAllStoreLevelMetrics(threadId, taskId, storeName); + } + + private void removeAllStoreLevelSensors(final String threadId, + final String taskId, + final String storeName) { + final String key = storeSensorPrefix(threadId, taskId, storeName); + final Deque sensors = storeLevelSensors.remove(key); + while (sensors != null && !sensors.isEmpty()) { + metrics.removeSensor(sensors.pop()); + } + } + + private void removeAllStoreLevelMetrics(final String threadId, + final String taskId, + final String storeName) { + final String key = storeSensorPrefix(threadId, taskId, storeName); + final Deque metricNames = storeLevelMetrics.remove(key); + while (metricNames != null && !metricNames.isEmpty()) { + metrics.removeMetric(metricNames.pop()); + } + } + + private String storeSensorPrefix(final String threadId, + final String taskId, + final String storeName) { + return taskSensorPrefix(threadId, taskId) + + SENSOR_PREFIX_DELIMITER + SENSOR_STORE_LABEL + SENSOR_PREFIX_DELIMITER + storeName; + } + + @Override + public Sensor addSensor(final String name, final Sensor.RecordingLevel recordingLevel) { + return metrics.sensor(name, recordingLevel); + } + + @Override + public Sensor addSensor(final String name, final Sensor.RecordingLevel recordingLevel, final Sensor... parents) { + return metrics.sensor(name, recordingLevel, parents); + } + + @Override + public Map metrics() { + return Collections.unmodifiableMap(this.metrics.metrics()); + } + + private Map customizedTags(final String threadId, + final String scopeName, + final String entityName, + final String... 
tags) { + final Map tagMap = threadLevelTagMap(threadId); + tagMap.put(scopeName + "-id", entityName); + if (tags != null) { + if ((tags.length % 2) != 0) { + throw new IllegalArgumentException("Tags needs to be specified in key-value pairs"); + } + for (int i = 0; i < tags.length; i += 2) { + tagMap.put(tags[i], tags[i + 1]); + } + } + return tagMap; + } + + private Sensor customInvocationRateAndCountSensor(final String threadId, + final String groupName, + final String entityName, + final String operationName, + final Map tags, + final Sensor.RecordingLevel recordingLevel) { + final Sensor sensor = metrics.sensor(externalChildSensorName(threadId, operationName, entityName), recordingLevel); + addInvocationRateAndCountToSensor( + sensor, + groupName, + tags, + operationName, + RATE_DESCRIPTION_PREFIX + operationName + OPERATIONS + RATE_DESCRIPTION_SUFFIX, + TOTAL_DESCRIPTION + operationName + OPERATIONS + ); + return sensor; + } + + @Override + public Sensor addLatencyRateTotalSensor(final String scopeName, + final String entityName, + final String operationName, + final Sensor.RecordingLevel recordingLevel, + final String... tags) { + final String threadId = Thread.currentThread().getName(); + final String group = groupNameFromScope(scopeName); + final Map tagMap = customizedTags(threadId, scopeName, entityName, tags); + final Sensor sensor = + customInvocationRateAndCountSensor(threadId, group, entityName, operationName, tagMap, recordingLevel); + addAvgAndMaxToSensor( + sensor, + group, + tagMap, + operationName + LATENCY_SUFFIX, + AVG_LATENCY_DESCRIPTION + operationName, + MAX_LATENCY_DESCRIPTION + operationName + ); + + return sensor; + } + + @Override + public Sensor addRateTotalSensor(final String scopeName, + final String entityName, + final String operationName, + final Sensor.RecordingLevel recordingLevel, + final String... 
tags) { + final String threadId = Thread.currentThread().getName(); + final Map tagMap = customizedTags(threadId, scopeName, entityName, tags); + return customInvocationRateAndCountSensor( + threadId, + groupNameFromScope(scopeName), + entityName, + operationName, + tagMap, + recordingLevel + ); + } + + private String externalChildSensorName(final String threadId, final String operationName, final String entityName) { + return SENSOR_EXTERNAL_LABEL + SENSOR_PREFIX_DELIMITER + threadId + + SENSOR_PREFIX_DELIMITER + SENSOR_ENTITY_LABEL + SENSOR_PREFIX_DELIMITER + entityName + + SENSOR_NAME_DELIMITER + operationName; + } + + public static void addAvgAndMaxToSensor(final Sensor sensor, + final String group, + final Map tags, + final String operation, + final String descriptionOfAvg, + final String descriptionOfMax) { + sensor.add( + new MetricName( + operation + AVG_SUFFIX, + group, + descriptionOfAvg, + tags), + new Avg() + ); + sensor.add( + new MetricName( + operation + MAX_SUFFIX, + group, + descriptionOfMax, + tags), + new Max() + ); + } + + public static void addMinAndMaxToSensor(final Sensor sensor, + final String group, + final Map tags, + final String operation, + final String descriptionOfMin, + final String descriptionOfMax) { + sensor.add( + new MetricName( + operation + MIN_SUFFIX, + group, + descriptionOfMin, + tags), + new Min() + ); + + sensor.add( + new MetricName( + operation + MAX_SUFFIX, + group, + descriptionOfMax, + tags), + new Max() + ); + } + + public static void addAvgAndMaxLatencyToSensor(final Sensor sensor, + final String group, + final Map tags, + final String operation) { + sensor.add( + new MetricName( + operation + "-latency-avg", + group, + AVG_LATENCY_DESCRIPTION + operation + " operation.", + tags), + new Avg() + ); + sensor.add( + new MetricName( + operation + "-latency-max", + group, + MAX_LATENCY_DESCRIPTION + operation + " operation.", + tags), + new Max() + ); + } + + public static void addAvgAndMinAndMaxToSensor(final Sensor sensor, + final String group, + final Map tags, + final String operation, + final String descriptionOfAvg, + final String descriptionOfMin, + final String descriptionOfMax) { + addAvgAndMaxToSensor(sensor, group, tags, operation, descriptionOfAvg, descriptionOfMax); + sensor.add( + new MetricName( + operation + MIN_SUFFIX, + group, + descriptionOfMin, + tags), + new Min() + ); + } + + public static void addInvocationRateAndCountToSensor(final Sensor sensor, + final String group, + final Map tags, + final String operation, + final String descriptionOfRate, + final String descriptionOfCount) { + addInvocationRateToSensor(sensor, group, tags, operation, descriptionOfRate); + sensor.add( + new MetricName( + operation + TOTAL_SUFFIX, + group, + descriptionOfCount, + tags + ), + new CumulativeCount() + ); + } + + public static void addInvocationRateToSensor(final Sensor sensor, + final String group, + final Map tags, + final String operation, + final String descriptionOfRate) { + sensor.add( + new MetricName( + operation + RATE_SUFFIX, + group, + descriptionOfRate, + tags + ), + new Rate(TimeUnit.SECONDS, new WindowedCount()) + ); + } + + public static void addInvocationRateAndCountToSensor(final Sensor sensor, + final String group, + final Map tags, + final String operation) { + addInvocationRateAndCountToSensor( + sensor, + group, + tags, + operation, + RATE_DESCRIPTION + operation, + TOTAL_DESCRIPTION + operation + ); + } + + public static void addRateOfSumAndSumMetricsToSensor(final Sensor sensor, + final String group, + final Map 
tags, + final String operation, + final String descriptionOfRate, + final String descriptionOfTotal) { + addRateOfSumMetricToSensor(sensor, group, tags, operation, descriptionOfRate); + addSumMetricToSensor(sensor, group, tags, operation, descriptionOfTotal); + } + + public static void addRateOfSumMetricToSensor(final Sensor sensor, + final String group, + final Map tags, + final String operation, + final String description) { + sensor.add(new MetricName(operation + RATE_SUFFIX, group, description, tags), + new Rate(TimeUnit.SECONDS, new WindowedSum())); + } + + public static void addSumMetricToSensor(final Sensor sensor, + final String group, + final Map tags, + final String operation, + final String description) { + addSumMetricToSensor(sensor, group, tags, operation, true, description); + } + + public static void addSumMetricToSensor(final Sensor sensor, + final String group, + final Map tags, + final String operation, + final boolean withSuffix, + final String description) { + sensor.add( + new MetricName( + withSuffix ? operation + TOTAL_SUFFIX : operation, + group, + description, + tags + ), + new CumulativeSum() + ); + } + + public static void addValueMetricToSensor(final Sensor sensor, + final String group, + final Map tags, + final String name, + final String description) { + sensor.add(new MetricName(name, group, description, tags), new Value()); + } + + public static void addAvgAndSumMetricsToSensor(final Sensor sensor, + final String group, + final Map tags, + final String metricNamePrefix, + final String descriptionOfAvg, + final String descriptionOfTotal) { + sensor.add(new MetricName(metricNamePrefix + AVG_SUFFIX, group, descriptionOfAvg, tags), new Avg()); + sensor.add( + new MetricName(metricNamePrefix + TOTAL_SUFFIX, group, descriptionOfTotal, tags), + new CumulativeSum() + ); + } + + public static void maybeMeasureLatency(final Runnable actionToMeasure, + final Time time, + final Sensor sensor) { + if (sensor.shouldRecord() && sensor.hasMetrics()) { + final long startNs = time.nanoseconds(); + try { + actionToMeasure.run(); + } finally { + sensor.record(time.nanoseconds() - startNs); + } + } else { + actionToMeasure.run(); + } + } + + public static T maybeMeasureLatency(final Supplier actionToMeasure, + final Time time, + final Sensor sensor) { + if (sensor.shouldRecord() && sensor.hasMetrics()) { + final long startNs = time.nanoseconds(); + try { + return actionToMeasure.get(); + } finally { + sensor.record(time.nanoseconds() - startNs); + } + } else { + return actionToMeasure.get(); + } + } + + private Sensor getSensors(final Map> sensors, + final String sensorName, + final String key, + final RecordingLevel recordingLevel, + final Sensor... 
parents) { + final String fullSensorName = key + SENSOR_NAME_DELIMITER + sensorName; + final Sensor sensor = metrics.getSensor(fullSensorName); + if (sensor == null) { + sensors.computeIfAbsent(key, ignored -> new LinkedList<>()).push(fullSensorName); + return metrics.sensor(fullSensorName, recordingLevel, parents); + } + return sensor; + } + + /** + * Deletes a sensor and its parents, if any + */ + @Override + public void removeSensor(final Sensor sensor) { + Objects.requireNonNull(sensor, "Sensor is null"); + metrics.removeSensor(sensor.name()); + + final Sensor parent = parentSensors.remove(sensor); + if (parent != null) { + metrics.removeSensor(parent.name()); + } + } + + /** + * Visible for testing + */ + Map parentSensors() { + return Collections.unmodifiableMap(parentSensors); + } + + private static String groupNameFromScope(final String scopeName) { + return "stream-" + scopeName + "-metrics"; + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/metrics/TaskMetrics.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/metrics/TaskMetrics.java new file mode 100644 index 0000000..cfa1ac6 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/metrics/TaskMetrics.java @@ -0,0 +1,278 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.processor.internals.metrics; + +import org.apache.kafka.common.metrics.Sensor; +import org.apache.kafka.common.metrics.Sensor.RecordingLevel; + +import java.util.Map; + +import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.LATENCY_SUFFIX; +import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.RATIO_SUFFIX; +import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.TASK_LEVEL_GROUP; +import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.TOTAL_DESCRIPTION; +import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.addAvgAndMaxToSensor; +import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.addInvocationRateAndCountToSensor; +import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.addValueMetricToSensor; + +public class TaskMetrics { + private TaskMetrics() {} + + private static final String AVG_LATENCY_DESCRIPTION = "The average latency of "; + private static final String MAX_LATENCY_DESCRIPTION = "The maximum latency of "; + private static final String RATE_DESCRIPTION_PREFIX = "The average number of "; + private static final String RATE_DESCRIPTION_SUFFIX = " per second"; + private static final String ACTIVE_TASK_PREFIX = "active-"; + + private static final String COMMIT = "commit"; + private static final String COMMIT_DESCRIPTION = "calls to commit"; + private static final String COMMIT_TOTAL_DESCRIPTION = TOTAL_DESCRIPTION + COMMIT_DESCRIPTION; + private static final String COMMIT_RATE_DESCRIPTION = + RATE_DESCRIPTION_PREFIX + COMMIT_DESCRIPTION + RATE_DESCRIPTION_SUFFIX; + + private static final String PUNCTUATE = "punctuate"; + private static final String PUNCTUATE_DESCRIPTION = "calls to punctuate"; + private static final String PUNCTUATE_TOTAL_DESCRIPTION = TOTAL_DESCRIPTION + PUNCTUATE_DESCRIPTION; + private static final String PUNCTUATE_RATE_DESCRIPTION = + RATE_DESCRIPTION_PREFIX + PUNCTUATE_DESCRIPTION + RATE_DESCRIPTION_SUFFIX; + private static final String PUNCTUATE_AVG_LATENCY_DESCRIPTION = AVG_LATENCY_DESCRIPTION + PUNCTUATE_DESCRIPTION; + private static final String PUNCTUATE_MAX_LATENCY_DESCRIPTION = MAX_LATENCY_DESCRIPTION + PUNCTUATE_DESCRIPTION; + + private static final String ENFORCED_PROCESSING = "enforced-processing"; + private static final String ENFORCED_PROCESSING_TOTAL_DESCRIPTION = + "The total number of occurrences of enforced-processing operations"; + private static final String ENFORCED_PROCESSING_RATE_DESCRIPTION = + "The average number of occurrences of enforced-processing operations per second"; + + private static final String RECORD_LATENESS = "record-lateness"; + private static final String RECORD_LATENESS_MAX_DESCRIPTION = + "The observed maximum lateness of records in milliseconds, measured by comparing the record timestamp with the " + + "current stream time"; + private static final String RECORD_LATENESS_AVG_DESCRIPTION = + "The observed average lateness of records in milliseconds, measured by comparing the record timestamp with the " + + "current stream time"; + + private static final String DROPPED_RECORDS = "dropped-records"; + private static final String DROPPED_RECORDS_DESCRIPTION = "dropped records"; + private static final String DROPPED_RECORDS_TOTAL_DESCRIPTION = TOTAL_DESCRIPTION + DROPPED_RECORDS_DESCRIPTION; + private static final String DROPPED_RECORDS_RATE_DESCRIPTION = + 
RATE_DESCRIPTION_PREFIX + DROPPED_RECORDS_DESCRIPTION + RATE_DESCRIPTION_SUFFIX; + + private static final String PROCESS = "process"; + private static final String PROCESS_LATENCY = PROCESS + LATENCY_SUFFIX; + private static final String PROCESS_DESCRIPTION = "calls to process"; + private static final String PROCESS_AVG_LATENCY_DESCRIPTION = AVG_LATENCY_DESCRIPTION + PROCESS_DESCRIPTION; + private static final String PROCESS_MAX_LATENCY_DESCRIPTION = MAX_LATENCY_DESCRIPTION + PROCESS_DESCRIPTION; + private static final String PROCESS_RATIO_DESCRIPTION = "The fraction of time the thread spent " + + "on processing this task among all assigned active tasks"; + + private static final String BUFFER_COUNT = "buffer-count"; + private static final String NUM_BUFFERED_RECORDS_DESCRIPTION = "The count of buffered records that are polled " + + "from consumer and not yet processed for this active task"; + + public static Sensor processLatencySensor(final String threadId, + final String taskId, + final StreamsMetricsImpl streamsMetrics) { + return avgAndMaxSensor( + threadId, + taskId, + PROCESS_LATENCY, + PROCESS_AVG_LATENCY_DESCRIPTION, + PROCESS_MAX_LATENCY_DESCRIPTION, + RecordingLevel.DEBUG, + streamsMetrics + ); + } + + public static Sensor activeProcessRatioSensor(final String threadId, + final String taskId, + final StreamsMetricsImpl streamsMetrics) { + final String name = ACTIVE_TASK_PREFIX + PROCESS + RATIO_SUFFIX; + final Sensor sensor = streamsMetrics.taskLevelSensor(threadId, taskId, name, Sensor.RecordingLevel.INFO); + addValueMetricToSensor( + sensor, + TASK_LEVEL_GROUP, + streamsMetrics.taskLevelTagMap(threadId, taskId), + name, + PROCESS_RATIO_DESCRIPTION + ); + return sensor; + } + + public static Sensor activeBufferedRecordsSensor(final String threadId, + final String taskId, + final StreamsMetricsImpl streamsMetrics) { + final String name = ACTIVE_TASK_PREFIX + BUFFER_COUNT; + final Sensor sensor = streamsMetrics.taskLevelSensor(threadId, taskId, name, Sensor.RecordingLevel.DEBUG); + addValueMetricToSensor( + sensor, + TASK_LEVEL_GROUP, + streamsMetrics.taskLevelTagMap(threadId, taskId), + name, + NUM_BUFFERED_RECORDS_DESCRIPTION + ); + return sensor; + } + + public static Sensor punctuateSensor(final String threadId, + final String taskId, + final StreamsMetricsImpl streamsMetrics) { + return invocationRateAndCountAndAvgAndMaxLatencySensor( + threadId, + taskId, + PUNCTUATE, + PUNCTUATE_RATE_DESCRIPTION, + PUNCTUATE_TOTAL_DESCRIPTION, + PUNCTUATE_AVG_LATENCY_DESCRIPTION, + PUNCTUATE_MAX_LATENCY_DESCRIPTION, + Sensor.RecordingLevel.DEBUG, + streamsMetrics + ); + } + + public static Sensor commitSensor(final String threadId, + final String taskId, + final StreamsMetricsImpl streamsMetrics, + final Sensor... parentSensor) { + return invocationRateAndCountSensor( + threadId, + taskId, + COMMIT, + COMMIT_RATE_DESCRIPTION, + COMMIT_TOTAL_DESCRIPTION, + Sensor.RecordingLevel.DEBUG, + streamsMetrics, + parentSensor + ); + } + + public static Sensor enforcedProcessingSensor(final String threadId, + final String taskId, + final StreamsMetricsImpl streamsMetrics, + final Sensor... 
parentSensors) { + return invocationRateAndCountSensor( + threadId, + taskId, + ENFORCED_PROCESSING, + ENFORCED_PROCESSING_RATE_DESCRIPTION, + ENFORCED_PROCESSING_TOTAL_DESCRIPTION, + RecordingLevel.DEBUG, + streamsMetrics, + parentSensors + ); + } + + public static Sensor recordLatenessSensor(final String threadId, + final String taskId, + final StreamsMetricsImpl streamsMetrics) { + return avgAndMaxSensor( + threadId, + taskId, + RECORD_LATENESS, + RECORD_LATENESS_AVG_DESCRIPTION, + RECORD_LATENESS_MAX_DESCRIPTION, + RecordingLevel.DEBUG, + streamsMetrics + ); + } + + public static Sensor droppedRecordsSensor(final String threadId, + final String taskId, + final StreamsMetricsImpl streamsMetrics) { + return invocationRateAndCountSensor( + threadId, + taskId, + DROPPED_RECORDS, + DROPPED_RECORDS_RATE_DESCRIPTION, + DROPPED_RECORDS_TOTAL_DESCRIPTION, + RecordingLevel.INFO, + streamsMetrics + ); + } + + private static Sensor invocationRateAndCountSensor(final String threadId, + final String taskId, + final String metricName, + final String descriptionOfRate, + final String descriptionOfCount, + final RecordingLevel recordingLevel, + final StreamsMetricsImpl streamsMetrics, + final Sensor... parentSensors) { + final Sensor sensor = streamsMetrics.taskLevelSensor(threadId, taskId, metricName, recordingLevel, parentSensors); + addInvocationRateAndCountToSensor( + sensor, + TASK_LEVEL_GROUP, + streamsMetrics.taskLevelTagMap(threadId, taskId), + metricName, + descriptionOfRate, + descriptionOfCount + ); + return sensor; + } + + private static Sensor avgAndMaxSensor(final String threadId, + final String taskId, + final String metricName, + final String descriptionOfAvg, + final String descriptionOfMax, + final RecordingLevel recordingLevel, + final StreamsMetricsImpl streamsMetrics, + final Sensor... parentSensors) { + final Sensor sensor = streamsMetrics.taskLevelSensor(threadId, taskId, metricName, recordingLevel, parentSensors); + final Map tagMap = streamsMetrics.taskLevelTagMap(threadId, taskId); + addAvgAndMaxToSensor( + sensor, + TASK_LEVEL_GROUP, + tagMap, + metricName, + descriptionOfAvg, + descriptionOfMax + ); + return sensor; + } + + private static Sensor invocationRateAndCountAndAvgAndMaxLatencySensor(final String threadId, + final String taskId, + final String metricName, + final String descriptionOfRate, + final String descriptionOfCount, + final String descriptionOfAvg, + final String descriptionOfMax, + final RecordingLevel recordingLevel, + final StreamsMetricsImpl streamsMetrics, + final Sensor... parentSensors) { + final Sensor sensor = streamsMetrics.taskLevelSensor(threadId, taskId, metricName, recordingLevel, parentSensors); + final Map tagMap = streamsMetrics.taskLevelTagMap(threadId, taskId); + addAvgAndMaxToSensor( + sensor, + TASK_LEVEL_GROUP, + tagMap, + metricName + LATENCY_SUFFIX, + descriptionOfAvg, + descriptionOfMax + ); + addInvocationRateAndCountToSensor( + sensor, + TASK_LEVEL_GROUP, + tagMap, + metricName, + descriptionOfRate, + descriptionOfCount + ); + return sensor; + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/metrics/ThreadMetrics.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/metrics/ThreadMetrics.java new file mode 100644 index 0000000..9c3e809 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/metrics/ThreadMetrics.java @@ -0,0 +1,388 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.processor.internals.metrics; + +import org.apache.kafka.common.metrics.Sensor; +import org.apache.kafka.common.metrics.Sensor.RecordingLevel; +import org.apache.kafka.streams.processor.internals.StreamThreadTotalBlockedTime; + +import java.util.Map; + +import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.LATENCY_SUFFIX; +import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.RATE_DESCRIPTION; +import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.RATE_SUFFIX; +import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.RATIO_SUFFIX; +import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.RECORDS_SUFFIX; +import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.ROLLUP_VALUE; +import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.TASK_LEVEL_GROUP; +import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.THREAD_LEVEL_GROUP; +import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.TOTAL_DESCRIPTION; +import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.addAvgAndMaxToSensor; +import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.addInvocationRateAndCountToSensor; +import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.addRateOfSumAndSumMetricsToSensor; +import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.addValueMetricToSensor; + +public class ThreadMetrics { + private ThreadMetrics() {} + + private static final String COMMIT = "commit"; + private static final String POLL = "poll"; + private static final String PROCESS = "process"; + private static final String PUNCTUATE = "punctuate"; + private static final String CREATE_TASK = "task-created"; + private static final String CLOSE_TASK = "task-closed"; + private static final String SKIP_RECORD = "skipped-records"; + private static final String BLOCKED_TIME = "blocked-time-ns-total"; + private static final String THREAD_START_TIME = "thread-start-time"; + + private static final String COMMIT_DESCRIPTION = "calls to commit"; + private static final String COMMIT_TOTAL_DESCRIPTION = TOTAL_DESCRIPTION + COMMIT_DESCRIPTION; + private static final String COMMIT_RATE_DESCRIPTION = RATE_DESCRIPTION + COMMIT_DESCRIPTION; + private static final String COMMIT_AVG_LATENCY_DESCRIPTION = "The average commit latency"; + private static final String COMMIT_MAX_LATENCY_DESCRIPTION = "The maximum commit latency"; + private static final String CREATE_TASK_DESCRIPTION = "newly created tasks"; + private static final String 
CREATE_TASK_TOTAL_DESCRIPTION = TOTAL_DESCRIPTION + CREATE_TASK_DESCRIPTION; + private static final String CREATE_TASK_RATE_DESCRIPTION = RATE_DESCRIPTION + CREATE_TASK_DESCRIPTION; + private static final String CLOSE_TASK_DESCRIPTION = "closed tasks"; + private static final String CLOSE_TASK_TOTAL_DESCRIPTION = TOTAL_DESCRIPTION + CLOSE_TASK_DESCRIPTION; + private static final String CLOSE_TASK_RATE_DESCRIPTION = RATE_DESCRIPTION + CLOSE_TASK_DESCRIPTION; + private static final String POLL_DESCRIPTION = "calls to poll"; + private static final String POLL_TOTAL_DESCRIPTION = TOTAL_DESCRIPTION + POLL_DESCRIPTION; + private static final String POLL_RATE_DESCRIPTION = RATE_DESCRIPTION + POLL_DESCRIPTION; + private static final String POLL_AVG_LATENCY_DESCRIPTION = "The average poll latency"; + private static final String POLL_MAX_LATENCY_DESCRIPTION = "The maximum poll latency"; + private static final String POLL_AVG_RECORDS_DESCRIPTION = "The average number of records polled from consumer within an iteration"; + private static final String POLL_MAX_RECORDS_DESCRIPTION = "The maximum number of records polled from consumer within an iteration"; + private static final String PROCESS_DESCRIPTION = "calls to process"; + private static final String PROCESS_TOTAL_DESCRIPTION = TOTAL_DESCRIPTION + PROCESS_DESCRIPTION; + private static final String PROCESS_RATE_DESCRIPTION = RATE_DESCRIPTION + PROCESS_DESCRIPTION; + private static final String PROCESS_AVG_LATENCY_DESCRIPTION = "The average process latency"; + private static final String PROCESS_MAX_LATENCY_DESCRIPTION = "The maximum process latency"; + private static final String PROCESS_AVG_RECORDS_DESCRIPTION = "The average number of records processed within an iteration"; + private static final String PROCESS_MAX_RECORDS_DESCRIPTION = "The maximum number of records processed within an iteration"; + private static final String PUNCTUATE_DESCRIPTION = "calls to punctuate"; + private static final String PUNCTUATE_TOTAL_DESCRIPTION = TOTAL_DESCRIPTION + PUNCTUATE_DESCRIPTION; + private static final String PUNCTUATE_RATE_DESCRIPTION = RATE_DESCRIPTION + PUNCTUATE_DESCRIPTION; + private static final String PUNCTUATE_AVG_LATENCY_DESCRIPTION = "The average punctuate latency"; + private static final String PUNCTUATE_MAX_LATENCY_DESCRIPTION = "The maximum punctuate latency"; + private static final String SKIP_RECORDS_DESCRIPTION = "skipped records"; + private static final String SKIP_RECORD_TOTAL_DESCRIPTION = TOTAL_DESCRIPTION + SKIP_RECORDS_DESCRIPTION; + private static final String SKIP_RECORD_RATE_DESCRIPTION = RATE_DESCRIPTION + SKIP_RECORDS_DESCRIPTION; + private static final String COMMIT_OVER_TASKS_DESCRIPTION = + "calls to commit over all tasks assigned to one stream thread"; + private static final String COMMIT_OVER_TASKS_TOTAL_DESCRIPTION = TOTAL_DESCRIPTION + COMMIT_OVER_TASKS_DESCRIPTION; + private static final String COMMIT_OVER_TASKS_RATE_DESCRIPTION = RATE_DESCRIPTION + COMMIT_OVER_TASKS_DESCRIPTION; + private static final String PROCESS_RATIO_DESCRIPTION = + "The fraction of time the thread spent on processing active tasks"; + private static final String PUNCTUATE_RATIO_DESCRIPTION = + "The fraction of time the thread spent on punctuating active tasks"; + private static final String POLL_RATIO_DESCRIPTION = + "The fraction of time the thread spent on polling records from consumer"; + private static final String COMMIT_RATIO_DESCRIPTION = + "The fraction of time the thread spent on committing all tasks"; + private static final String 
BLOCKED_TIME_DESCRIPTION = + "The total time the thread spent blocked on kafka in nanoseconds"; + private static final String THREAD_START_TIME_DESCRIPTION = + "The time that the thread was started"; + + public static Sensor createTaskSensor(final String threadId, + final StreamsMetricsImpl streamsMetrics) { + return invocationRateAndCountSensor( + threadId, + CREATE_TASK, + CREATE_TASK_RATE_DESCRIPTION, + CREATE_TASK_TOTAL_DESCRIPTION, + RecordingLevel.INFO, + streamsMetrics + ); + } + + public static Sensor closeTaskSensor(final String threadId, + final StreamsMetricsImpl streamsMetrics) { + return invocationRateAndCountSensor( + threadId, + CLOSE_TASK, + CLOSE_TASK_RATE_DESCRIPTION, + CLOSE_TASK_TOTAL_DESCRIPTION, + RecordingLevel.INFO, + streamsMetrics + ); + } + + public static Sensor skipRecordSensor(final String threadId, + final StreamsMetricsImpl streamsMetrics) { + return invocationRateAndCountSensor( + threadId, + SKIP_RECORD, + SKIP_RECORD_RATE_DESCRIPTION, + SKIP_RECORD_TOTAL_DESCRIPTION, + RecordingLevel.INFO, + streamsMetrics + ); + } + + public static Sensor commitSensor(final String threadId, + final StreamsMetricsImpl streamsMetrics) { + return invocationRateAndCountAndAvgAndMaxLatencySensor( + threadId, + COMMIT, + COMMIT_RATE_DESCRIPTION, + COMMIT_TOTAL_DESCRIPTION, + COMMIT_AVG_LATENCY_DESCRIPTION, + COMMIT_MAX_LATENCY_DESCRIPTION, + Sensor.RecordingLevel.INFO, + streamsMetrics + ); + } + + public static Sensor pollSensor(final String threadId, + final StreamsMetricsImpl streamsMetrics) { + return invocationRateAndCountAndAvgAndMaxLatencySensor( + threadId, + POLL, + POLL_RATE_DESCRIPTION, + POLL_TOTAL_DESCRIPTION, + POLL_AVG_LATENCY_DESCRIPTION, + POLL_MAX_LATENCY_DESCRIPTION, + Sensor.RecordingLevel.INFO, + streamsMetrics + ); + } + + public static Sensor processLatencySensor(final String threadId, + final StreamsMetricsImpl streamsMetrics) { + final Sensor sensor = + streamsMetrics.threadLevelSensor(threadId, PROCESS + LATENCY_SUFFIX, RecordingLevel.INFO); + final Map tagMap = streamsMetrics.threadLevelTagMap(threadId); + addAvgAndMaxToSensor( + sensor, + THREAD_LEVEL_GROUP, + tagMap, + PROCESS + LATENCY_SUFFIX, + PROCESS_AVG_LATENCY_DESCRIPTION, + PROCESS_MAX_LATENCY_DESCRIPTION + ); + return sensor; + } + + public static Sensor pollRecordsSensor(final String threadId, + final StreamsMetricsImpl streamsMetrics) { + final Sensor sensor = + streamsMetrics.threadLevelSensor(threadId, POLL + RECORDS_SUFFIX, RecordingLevel.INFO); + final Map tagMap = streamsMetrics.threadLevelTagMap(threadId); + addAvgAndMaxToSensor( + sensor, + THREAD_LEVEL_GROUP, + tagMap, + POLL + RECORDS_SUFFIX, + POLL_AVG_RECORDS_DESCRIPTION, + POLL_MAX_RECORDS_DESCRIPTION + ); + return sensor; + } + + public static Sensor processRecordsSensor(final String threadId, + final StreamsMetricsImpl streamsMetrics) { + final Sensor sensor = + streamsMetrics.threadLevelSensor(threadId, PROCESS + RECORDS_SUFFIX, RecordingLevel.INFO); + final Map tagMap = streamsMetrics.threadLevelTagMap(threadId); + addAvgAndMaxToSensor( + sensor, + THREAD_LEVEL_GROUP, + tagMap, + PROCESS + RECORDS_SUFFIX, + PROCESS_AVG_RECORDS_DESCRIPTION, + PROCESS_MAX_RECORDS_DESCRIPTION + ); + return sensor; + } + + public static Sensor processRateSensor(final String threadId, + final StreamsMetricsImpl streamsMetrics) { + final Sensor sensor = + streamsMetrics.threadLevelSensor(threadId, PROCESS + RATE_SUFFIX, RecordingLevel.INFO); + final Map tagMap = streamsMetrics.threadLevelTagMap(threadId); + addRateOfSumAndSumMetricsToSensor( + 
sensor, + THREAD_LEVEL_GROUP, + tagMap, + PROCESS, + PROCESS_RATE_DESCRIPTION, + PROCESS_TOTAL_DESCRIPTION + ); + return sensor; + } + + public static Sensor punctuateSensor(final String threadId, + final StreamsMetricsImpl streamsMetrics) { + return invocationRateAndCountAndAvgAndMaxLatencySensor( + threadId, + PUNCTUATE, + PUNCTUATE_RATE_DESCRIPTION, + PUNCTUATE_TOTAL_DESCRIPTION, + PUNCTUATE_AVG_LATENCY_DESCRIPTION, + PUNCTUATE_MAX_LATENCY_DESCRIPTION, + Sensor.RecordingLevel.INFO, + streamsMetrics + ); + } + + public static Sensor commitOverTasksSensor(final String threadId, + final StreamsMetricsImpl streamsMetrics) { + final Sensor commitOverTasksSensor = + streamsMetrics.threadLevelSensor(threadId, COMMIT, Sensor.RecordingLevel.DEBUG); + final Map tagMap = streamsMetrics.taskLevelTagMap(threadId, ROLLUP_VALUE); + addInvocationRateAndCountToSensor( + commitOverTasksSensor, + TASK_LEVEL_GROUP, + tagMap, + COMMIT, + COMMIT_OVER_TASKS_RATE_DESCRIPTION, + COMMIT_OVER_TASKS_TOTAL_DESCRIPTION + ); + return commitOverTasksSensor; + } + + public static Sensor processRatioSensor(final String threadId, + final StreamsMetricsImpl streamsMetrics) { + final Sensor sensor = + streamsMetrics.threadLevelSensor(threadId, PROCESS + RATIO_SUFFIX, Sensor.RecordingLevel.INFO); + final Map tagMap = streamsMetrics.threadLevelTagMap(threadId); + addValueMetricToSensor( + sensor, + THREAD_LEVEL_GROUP, + tagMap, + PROCESS + RATIO_SUFFIX, + PROCESS_RATIO_DESCRIPTION + ); + return sensor; + } + + public static Sensor punctuateRatioSensor(final String threadId, + final StreamsMetricsImpl streamsMetrics) { + final Sensor sensor = + streamsMetrics.threadLevelSensor(threadId, PUNCTUATE + RATIO_SUFFIX, Sensor.RecordingLevel.INFO); + final Map tagMap = streamsMetrics.threadLevelTagMap(threadId); + addValueMetricToSensor( + sensor, + THREAD_LEVEL_GROUP, + tagMap, + PUNCTUATE + RATIO_SUFFIX, + PUNCTUATE_RATIO_DESCRIPTION + ); + return sensor; + } + + public static Sensor pollRatioSensor(final String threadId, + final StreamsMetricsImpl streamsMetrics) { + final Sensor sensor = + streamsMetrics.threadLevelSensor(threadId, POLL + RATIO_SUFFIX, Sensor.RecordingLevel.INFO); + final Map tagMap = streamsMetrics.threadLevelTagMap(threadId); + addValueMetricToSensor( + sensor, + THREAD_LEVEL_GROUP, + tagMap, + POLL + RATIO_SUFFIX, + POLL_RATIO_DESCRIPTION + ); + return sensor; + } + + public static Sensor commitRatioSensor(final String threadId, + final StreamsMetricsImpl streamsMetrics) { + final Sensor sensor = + streamsMetrics.threadLevelSensor(threadId, COMMIT + RATIO_SUFFIX, Sensor.RecordingLevel.INFO); + final Map tagMap = streamsMetrics.threadLevelTagMap(threadId); + addValueMetricToSensor( + sensor, + THREAD_LEVEL_GROUP, + tagMap, + COMMIT + RATIO_SUFFIX, + COMMIT_RATIO_DESCRIPTION + ); + return sensor; + } + + public static void addThreadStartTimeMetric(final String threadId, + final StreamsMetricsImpl streamsMetrics, + final long startTime) { + streamsMetrics.addThreadLevelImmutableMetric( + THREAD_START_TIME, + THREAD_START_TIME_DESCRIPTION, + threadId, + startTime + ); + } + + public static void addThreadBlockedTimeMetric(final String threadId, + final StreamThreadTotalBlockedTime blockedTime, + final StreamsMetricsImpl streamsMetrics) { + streamsMetrics.addThreadLevelMutableMetric( + BLOCKED_TIME, + BLOCKED_TIME_DESCRIPTION, + threadId, + (config, now) -> blockedTime.compute() + ); + } + + private static Sensor invocationRateAndCountSensor(final String threadId, + final String metricName, + final String 
descriptionOfRate, + final String descriptionOfCount, + final RecordingLevel recordingLevel, + final StreamsMetricsImpl streamsMetrics) { + final Sensor sensor = streamsMetrics.threadLevelSensor(threadId, metricName, recordingLevel); + addInvocationRateAndCountToSensor( + sensor, + THREAD_LEVEL_GROUP, + streamsMetrics.threadLevelTagMap(threadId), + metricName, + descriptionOfRate, + descriptionOfCount + ); + return sensor; + } + + private static Sensor invocationRateAndCountAndAvgAndMaxLatencySensor(final String threadId, + final String metricName, + final String descriptionOfRate, + final String descriptionOfCount, + final String descriptionOfAvg, + final String descriptionOfMax, + final RecordingLevel recordingLevel, + final StreamsMetricsImpl streamsMetrics) { + final Sensor sensor = streamsMetrics.threadLevelSensor(threadId, metricName, recordingLevel); + final Map tagMap = streamsMetrics.threadLevelTagMap(threadId); + addAvgAndMaxToSensor( + sensor, + THREAD_LEVEL_GROUP, + tagMap, + metricName + LATENCY_SUFFIX, + descriptionOfAvg, + descriptionOfMax + ); + addInvocationRateAndCountToSensor( + sensor, + THREAD_LEVEL_GROUP, + tagMap, + metricName, + descriptionOfRate, + descriptionOfCount + ); + return sensor; + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/namedtopology/KafkaStreamsNamedTopologyWrapper.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/namedtopology/KafkaStreamsNamedTopologyWrapper.java new file mode 100644 index 0000000..d4170d4 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/namedtopology/KafkaStreamsNamedTopologyWrapper.java @@ -0,0 +1,183 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.processor.internals.namedtopology; + +import org.apache.kafka.common.annotation.InterfaceStability.Unstable; +import org.apache.kafka.streams.KafkaClientSupplier; +import org.apache.kafka.streams.KafkaStreams; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.errors.StreamsException; +import org.apache.kafka.streams.errors.TopologyException; +import org.apache.kafka.streams.processor.StateStore; +import org.apache.kafka.streams.processor.internals.DefaultKafkaClientSupplier; +import org.apache.kafka.streams.processor.internals.InternalTopologyBuilder; +import org.apache.kafka.streams.processor.internals.TopologyMetadata; + +import java.util.Collection; +import java.util.Collections; +import java.util.Optional; +import java.util.Properties; +import java.util.concurrent.ConcurrentSkipListMap; +import java.util.stream.Collectors; + +/** + * This is currently an internal and experimental feature for enabling certain kinds of topology upgrades. Use at + * your own risk. 
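 + * <p>
 + * A minimal usage sketch, for illustration only; the topic names and config values below are placeholders, and
 + * {@code props} is an assumed {@link java.util.Properties} with the usual Streams configs:
 + * <pre>{@code
 + * // illustration only: topic names and config values are placeholders
 + * Properties props = new Properties();
 + * props.put(StreamsConfig.APPLICATION_ID_CONFIG, "named-topology-app");
 + * props.put(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9092");
 + *
 + * NamedTopologyStreamsBuilder builder1 = new NamedTopologyStreamsBuilder("topology-1");
 + * builder1.stream("input-a").to("output-a");
 + *
 + * KafkaStreamsNamedTopologyWrapper app =
 + *     new KafkaStreamsNamedTopologyWrapper(builder1.buildNamedTopology(props), props);
 + * app.start();
 + *
 + * // later, add a second NamedTopology to the running application
 + * NamedTopologyStreamsBuilder builder2 = new NamedTopologyStreamsBuilder("topology-2");
 + * builder2.stream("input-b").to("output-b");
 + * app.addNamedTopology(builder2.buildNamedTopology(props));
 + * }</pre>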
+ * + * Status: additive upgrades possible, removal of NamedTopologies not yet supported + * + * Note: some standard features of Kafka Streams are not yet supported with NamedTopologies. These include: + * - global state stores + * - interactive queries (IQ) + * - TopologyTestDriver (TTD) + */ +@Unstable +public class KafkaStreamsNamedTopologyWrapper extends KafkaStreams { + + /** + * A Kafka Streams application with a single initial NamedTopology + */ + public KafkaStreamsNamedTopologyWrapper(final NamedTopology topology, final Properties props) { + this(Collections.singleton(topology), new StreamsConfig(props), new DefaultKafkaClientSupplier()); + } + + /** + * A Kafka Streams application with a single initial NamedTopology + */ + public KafkaStreamsNamedTopologyWrapper(final NamedTopology topology, final Properties props, final KafkaClientSupplier clientSupplier) { + this(Collections.singleton(topology), new StreamsConfig(props), clientSupplier); + } + + /** + * An empty Kafka Streams application that allows NamedTopologies to be added at a later point + */ + public KafkaStreamsNamedTopologyWrapper(final Properties props) { + this(Collections.emptyList(), new StreamsConfig(props), new DefaultKafkaClientSupplier()); + } + + /** + * An empty Kafka Streams application that allows NamedTopologies to be added at a later point + */ + public KafkaStreamsNamedTopologyWrapper(final Properties props, final KafkaClientSupplier clientSupplier) { + this(Collections.emptyList(), new StreamsConfig(props), clientSupplier); + } + + /** + * A Kafka Streams application with a multiple initial NamedTopologies + * + * @throws IllegalArgumentException if any of the named topologies have the same name + * @throws TopologyException if multiple NamedTopologies subscribe to the same input topics or pattern + */ + public KafkaStreamsNamedTopologyWrapper(final Collection topologies, final Properties props) { + this(topologies, new StreamsConfig(props), new DefaultKafkaClientSupplier()); + } + + /** + * A Kafka Streams application with a multiple initial NamedTopologies + * + * @throws IllegalArgumentException if any of the named topologies have the same name + * @throws TopologyException if multiple NamedTopologies subscribe to the same input topics or pattern + */ + public KafkaStreamsNamedTopologyWrapper(final Collection topologies, final Properties props, final KafkaClientSupplier clientSupplier) { + this(topologies, new StreamsConfig(props), clientSupplier); + } + + private KafkaStreamsNamedTopologyWrapper(final Collection topologies, final StreamsConfig config, final KafkaClientSupplier clientSupplier) { + super( + new TopologyMetadata( + topologies.stream().collect(Collectors.toMap( + NamedTopology::name, + NamedTopology::internalTopologyBuilder, + (v1, v2) -> { + throw new IllegalArgumentException("Topology names must be unique"); + }, + () -> new ConcurrentSkipListMap<>())), + config), + config, + clientSupplier + ); + } + + /** + * @return the NamedTopology for the specific name, or Optional.empty() if the application has no NamedTopology of that name + */ + public Optional getTopologyByName(final String name) { + return Optional.ofNullable(topologyMetadata.lookupBuilderForNamedTopology(name)).map(InternalTopologyBuilder::namedTopology); + } + + /** + * Add a new NamedTopology to a running Kafka Streams app. 
If multiple instances of the application are running, + * you should inform all of them by calling {@link #addNamedTopology(NamedTopology)} on each client in order for + * it to begin processing the new topology. + * + * @throws IllegalArgumentException if this topology name is already in use + * @throws IllegalStateException if streams has not been started or has already shut down + * @throws TopologyException if this topology subscribes to any input topics or pattern already in use + */ + public void addNamedTopology(final NamedTopology newTopology) { + if (hasStartedOrFinishedShuttingDown()) { + throw new IllegalStateException("Cannot add a NamedTopology while the state is " + super.state); + } else if (getTopologyByName(newTopology.name()).isPresent()) { + throw new IllegalArgumentException("Unable to add the new NamedTopology " + newTopology.name() + + " as another of the same name already exists"); + } + topologyMetadata.registerAndBuildNewTopology(newTopology.internalTopologyBuilder()); + } + + /** + * Remove an existing NamedTopology from a running Kafka Streams app. If multiple instances of the application are + * running, you should inform all of them by calling {@link #removeNamedTopology(String)} on each client to ensure + * it stops processing the old topology. + * + * @throws IllegalArgumentException if this topology name cannot be found + * @throws IllegalStateException if streams has not been started or has already shut down + * @throws TopologyException if this topology subscribes to any input topics or pattern already in use + */ + public void removeNamedTopology(final String topologyToRemove) { + if (!isRunningOrRebalancing()) { + throw new IllegalStateException("Cannot remove a NamedTopology while the state is " + super.state); + } else if (!getTopologyByName(topologyToRemove).isPresent()) { + throw new IllegalArgumentException("Unable to locate for removal a NamedTopology called " + topologyToRemove); + } + + topologyMetadata.unregisterTopology(topologyToRemove); + } + + /** + * Do a clean up of the local state directory for this NamedTopology by deleting all data with regard to the + * @link StreamsConfig#APPLICATION_ID_CONFIG application ID} in the ({@link StreamsConfig#STATE_DIR_CONFIG}) + *
<p>
 + * May be called while the Kafka Streams application is in any state, but only on a {@link NamedTopology} that has
 + * already been removed via {@link #removeNamedTopology(String)}.
 + *
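 + * <p>
 + * For illustration, a removed topology can be cleaned up as sketched below; {@code app} is an assumed
 + * {@code KafkaStreamsNamedTopologyWrapper} and the topology name is a placeholder:
 + * <pre>{@code
 + * app.removeNamedTopology("topology-2");   // stop processing the topology first
 + * app.cleanUpNamedTopology("topology-2");  // then wipe its local state
 + * }</pre>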
<p>
                + * Calling this method triggers a restore of local {@link StateStore}s for this {@link NamedTopology} if it is + * ever re-added via {@link #addNamedTopology(NamedTopology)}. + * + * @throws IllegalStateException if this {@code NamedTopology} hasn't been removed + * @throws StreamsException if cleanup failed + */ + public void cleanUpNamedTopology(final String name) { + if (getTopologyByName(name).isPresent()) { + throw new IllegalStateException("Can't clean up local state for an active NamedTopology: " + name); + } + stateDirectory.clearLocalStateForNamedTopology(name); + } + + public String getFullTopologyDescription() { + return topologyMetadata.topologyDescriptionString(); + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/namedtopology/NamedTopology.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/namedtopology/NamedTopology.java new file mode 100644 index 0000000..0cdfd5b --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/namedtopology/NamedTopology.java @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.processor.internals.namedtopology; + +import org.apache.kafka.streams.Topology; +import org.apache.kafka.streams.processor.internals.InternalTopologyBuilder; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.List; + +public class NamedTopology extends Topology { + + private final Logger log = LoggerFactory.getLogger(NamedTopology.class); + private String name; + + void setTopologyName(final String newTopologyName) { + if (name != null) { + log.error("Unable to set topologyName = {} because the name is already set to {}", newTopologyName, name); + throw new IllegalStateException("Tried to set topologyName but the name was already set"); + } + name = newTopologyName; + internalTopologyBuilder.setNamedTopology(this); + } + + public String name() { + return name; + } + + public List sourceTopics() { + return super.internalTopologyBuilder.fullSourceTopicNames(); + } + + InternalTopologyBuilder internalTopologyBuilder() { + return internalTopologyBuilder; + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/namedtopology/NamedTopologyStreamsBuilder.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/namedtopology/NamedTopologyStreamsBuilder.java new file mode 100644 index 0000000..5d3fad8 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/namedtopology/NamedTopologyStreamsBuilder.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.processor.internals.namedtopology; + +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.Topology; +import org.apache.kafka.streams.processor.TaskId; + +import java.util.Properties; + +public class NamedTopologyStreamsBuilder extends StreamsBuilder { + final String topologyName; + + /** + * @param topologyName any string representing your NamedTopology, all characters allowed except for "__" + * @throws IllegalArgumentException if the name contains the character sequence "__" + */ + public NamedTopologyStreamsBuilder(final String topologyName) { + super(); + this.topologyName = topologyName; + if (topologyName.contains(TaskId.NAMED_TOPOLOGY_DELIMITER)) { + throw new IllegalArgumentException("The character sequence '__' is not allowed in a NamedTopology, please select a new name"); + } + } + + public synchronized NamedTopology buildNamedTopology(final Properties props) { + super.build(props); + final NamedTopology namedTopology = (NamedTopology) super.topology; + namedTopology.setTopologyName(topologyName); + return namedTopology; + } + + @Override + public Topology getNewTopology() { + return new NamedTopology(); + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/state/HostInfo.java b/streams/src/main/java/org/apache/kafka/streams/state/HostInfo.java new file mode 100644 index 0000000..c25f184 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/HostInfo.java @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.state; + +import static org.apache.kafka.common.utils.Utils.getHost; +import static org.apache.kafka.common.utils.Utils.getPort; + +import org.apache.kafka.common.config.ConfigException; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.streams.KafkaStreams; +import org.apache.kafka.streams.processor.internals.StreamsPartitionAssignor; + +/** + * Represents a user defined endpoint in a {@link org.apache.kafka.streams.KafkaStreams} application. 
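 + * <p>
 + * For illustration, an endpoint supplied via {@code application.server} can be parsed into a {@code HostInfo} as
 + * sketched below; {@code props} and the endpoint value are placeholders, not part of this class:
 + * <pre>{@code
 + * props.put(StreamsConfig.APPLICATION_SERVER_CONFIG, "myhost:8080"); // props is an assumed Properties object
 + * HostInfo hostInfo = HostInfo.buildFromEndpoint("myhost:8080");
 + * String host = hostInfo.host();  // "myhost"
 + * int port = hostInfo.port();     // 8080
 + * }</pre>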
+ * Instances of this class can be obtained by calling one of: + * {@link KafkaStreams#metadataForAllStreamsClients()} + * {@link KafkaStreams#streamsMetadataForStore(String)} + * + * The HostInfo is constructed during Partition Assignment + * see {@link StreamsPartitionAssignor} + * It is extracted from the config {@link org.apache.kafka.streams.StreamsConfig#APPLICATION_SERVER_CONFIG} + * + * If developers wish to expose an endpoint in their KafkaStreams applications they should provide the above + * config. + */ +public class HostInfo { + private final String host; + private final int port; + + public HostInfo(final String host, + final int port) { + this.host = host; + this.port = port; + } + + /** + * @throws ConfigException if the host or port cannot be parsed from the given endpoint string + * @return a new HostInfo or null if endPoint is null or has no characters + */ + public static HostInfo buildFromEndpoint(final String endPoint) { + if (Utils.isBlank(endPoint)) { + return null; + } + + final String host = getHost(endPoint); + final Integer port = getPort(endPoint); + + if (host == null || port == null) { + throw new ConfigException( + String.format("Error parsing host address %s. Expected format host:port.", endPoint) + ); + } + return new HostInfo(host, port); + } + + /** + * @return a sentinel for cases where the host metadata is currently unavailable, eg during rebalance operations. + */ + public static HostInfo unavailable() { + return new HostInfo("unavailable", -1); + } + + @Override + public boolean equals(final Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + + final HostInfo hostInfo = (HostInfo) o; + return port == hostInfo.port && host.equals(hostInfo.host); + } + + @Override + public int hashCode() { + int result = host.hashCode(); + result = 31 * result + port; + return result; + } + + public String host() { + return host; + } + + public int port() { + return port; + } + + @Override + public String toString() { + return "HostInfo{" + + "host=\'" + host + '\'' + + ", port=" + port + + '}'; + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/state/KeyValueBytesStoreSupplier.java b/streams/src/main/java/org/apache/kafka/streams/state/KeyValueBytesStoreSupplier.java new file mode 100644 index 0000000..9855be3 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/KeyValueBytesStoreSupplier.java @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.state; + +import org.apache.kafka.common.utils.Bytes; + +/** + * A store supplier that can be used to create one or more {@link KeyValueStore KeyValueStore<Bytes, byte[]>} instances of type <Bytes, byte[]>. 
+ * + * For any stores implementing the {@link KeyValueStore KeyValueStore<Bytes, byte[]>} interface, null value bytes are considered as "not exist". This means: + * + *

 + * <ol>
 + *     <li>Null value bytes in put operations should be treated as delete.</li>
 + *     <li>If the key does not exist, get operations should return null value bytes.</li>
 + * </ol>
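 + * <p>
 + * For illustration, a supplier is typically obtained from the built-in {@code Stores} factory and handed to the
 + * DSL; the store name, topic name, and {@code builder} below are placeholders, not part of this interface:
 + * <pre>{@code
 + * KeyValueBytesStoreSupplier supplier = Stores.persistentKeyValueStore("counts-store");
 + * builder.table("counts-topic", Materialized.<String, Long>as(supplier)); // builder is an assumed StreamsBuilder
 + * }</pre>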
                + */ +public interface KeyValueBytesStoreSupplier extends StoreSupplier> { + +} diff --git a/streams/src/main/java/org/apache/kafka/streams/state/KeyValueIterator.java b/streams/src/main/java/org/apache/kafka/streams/state/KeyValueIterator.java new file mode 100644 index 0000000..b1f5e2c --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/KeyValueIterator.java @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.state; + +import org.apache.kafka.streams.KeyValue; + +import java.io.Closeable; +import java.util.Iterator; + +/** + * Iterator interface of {@link KeyValue}. + *
<p>
                + * Users must call its {@code close} method explicitly upon completeness to release resources, + * or use try-with-resources statement (available since JDK7) for this {@link Closeable} class. + * Note that {@code remove()} is not supported. + * + * @param Type of keys + * @param Type of values + */ +public interface KeyValueIterator extends Iterator>, Closeable { + + @Override + void close(); + + /** + * Peek at the next key without advancing the iterator + * @return the key of the next value that would be returned from the next call to next + */ + K peekNextKey(); +} diff --git a/streams/src/main/java/org/apache/kafka/streams/state/KeyValueStore.java b/streams/src/main/java/org/apache/kafka/streams/state/KeyValueStore.java new file mode 100644 index 0000000..3af8d90 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/KeyValueStore.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.state; + +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.processor.StateStore; + +import java.util.List; + +/** + * A key-value store that supports put/get/delete and range queries. + * + * @param The key type + * @param The value type + */ +public interface KeyValueStore extends StateStore, ReadOnlyKeyValueStore { + + /** + * Update the value associated with this key. + * + * @param key The key to associate the value to + * @param value The value to update, it can be {@code null}; + * if the serialized bytes are also {@code null} it is interpreted as deletes + * @throws NullPointerException If {@code null} is used for key. + */ + void put(K key, V value); + + /** + * Update the value associated with this key, unless a value is already associated with the key. + * + * @param key The key to associate the value to + * @param value The value to update, it can be {@code null}; + * if the serialized bytes are also {@code null} it is interpreted as deletes + * @return The old value or {@code null} if there is no such key. + * @throws NullPointerException If {@code null} is used for key. + */ + V putIfAbsent(K key, V value); + + /** + * Update all the given key/value pairs. + * + * @param entries A list of entries to put into the store; + * if the serialized bytes are also {@code null} it is interpreted as deletes + * @throws NullPointerException If {@code null} is used for key. + */ + void putAll(List> entries); + + /** + * Delete the value from the store (if there is one). + * + * @param key The key + * @return The old value or {@code null} if there is no such key. + * @throws NullPointerException If {@code null} is used for key. 
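 + * <p>
 + * For illustration ({@code store} is an assumed {@code KeyValueStore<String, Long>} and the key is a placeholder):
 + * <pre>{@code
 + * // store is an assumed, already-initialized KeyValueStore<String, Long>
 + * store.put("a", 1L);                 // associate "a" with 1
 + * Long previous = store.delete("a");  // returns 1L and removes the mapping
 + * Long missing = store.delete("a");   // returns null, since the key no longer exists
 + * }</pre>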
+ */ + V delete(K key); +} \ No newline at end of file diff --git a/streams/src/main/java/org/apache/kafka/streams/state/QueryableStoreType.java b/streams/src/main/java/org/apache/kafka/streams/state/QueryableStoreType.java new file mode 100644 index 0000000..9771553 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/QueryableStoreType.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.state; + +import org.apache.kafka.streams.KafkaStreams; +import org.apache.kafka.streams.processor.StateStore; +import org.apache.kafka.streams.state.internals.StateStoreProvider; + +/** + * Used to enable querying of custom {@link StateStore} types via the {@link KafkaStreams} API. + * + * @param The store type + * @see QueryableStoreTypes + */ +public interface QueryableStoreType { + + /** + * Called when searching for {@link StateStore}s to see if they + * match the type expected by implementors of this interface. + * + * @param stateStore The stateStore + * @return true if it is a match + */ + boolean accepts(final StateStore stateStore); + + /** + * Create an instance of {@code T} (usually a facade) that developers can use + * to query the underlying {@link StateStore}s. + * + * @param storeProvider provides access to all the underlying StateStore instances + * @param storeName The name of the Store + * @return a read-only interface over a {@code StateStore} + * (cf. {@link QueryableStoreTypes.KeyValueStoreType}) + */ + T create(final StateStoreProvider storeProvider, + final String storeName); +} \ No newline at end of file diff --git a/streams/src/main/java/org/apache/kafka/streams/state/QueryableStoreTypes.java b/streams/src/main/java/org/apache/kafka/streams/state/QueryableStoreTypes.java new file mode 100644 index 0000000..343d274 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/QueryableStoreTypes.java @@ -0,0 +1,187 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.state; + +import org.apache.kafka.streams.KafkaStreams; +import org.apache.kafka.streams.StoreQueryParameters; +import org.apache.kafka.streams.Topology; +import org.apache.kafka.streams.processor.StateStore; +import org.apache.kafka.streams.state.internals.CompositeReadOnlyKeyValueStore; +import org.apache.kafka.streams.state.internals.CompositeReadOnlySessionStore; +import org.apache.kafka.streams.state.internals.CompositeReadOnlyWindowStore; +import org.apache.kafka.streams.state.internals.StateStoreProvider; + +import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; +import java.util.Set; + +/** + * Provides access to the {@link QueryableStoreType}s provided with {@link KafkaStreams}. + * These can be used with {@link KafkaStreams#store(StoreQueryParameters)}. + * To access and query the {@link StateStore}s that are part of a {@link Topology}. + */ +public final class QueryableStoreTypes { + + /** + * A {@link QueryableStoreType} that accepts {@link ReadOnlyKeyValueStore}. + * + * @param key type of the store + * @param value type of the store + * @return {@link QueryableStoreTypes.KeyValueStoreType} + */ + public static QueryableStoreType> keyValueStore() { + return new KeyValueStoreType<>(); + } + + /** + * A {@link QueryableStoreType} that accepts {@link ReadOnlyKeyValueStore ReadOnlyKeyValueStore>}. + * + * @param key type of the store + * @param value type of the store + * @return {@link QueryableStoreTypes.TimestampedKeyValueStoreType} + */ + public static QueryableStoreType>> timestampedKeyValueStore() { + return new TimestampedKeyValueStoreType<>(); + } + + /** + * A {@link QueryableStoreType} that accepts {@link ReadOnlyWindowStore}. + * + * @param key type of the store + * @param value type of the store + * @return {@link QueryableStoreTypes.WindowStoreType} + */ + public static QueryableStoreType> windowStore() { + return new WindowStoreType<>(); + } + + /** + * A {@link QueryableStoreType} that accepts {@link ReadOnlyWindowStore ReadOnlyWindowStore>}. + * + * @param key type of the store + * @param value type of the store + * @return {@link QueryableStoreTypes.TimestampedWindowStoreType} + */ + public static QueryableStoreType>> timestampedWindowStore() { + return new TimestampedWindowStoreType<>(); + } + + /** + * A {@link QueryableStoreType} that accepts {@link ReadOnlySessionStore}. 
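 + * <p>
 + * For illustration ({@code streams} is an assumed running {@link KafkaStreams} instance and the store name is a
 + * placeholder):
 + * <pre>{@code
 + * // query a session store of an assumed topology by name
 + * ReadOnlySessionStore<String, Long> sessions = streams.store(
 + *     StoreQueryParameters.fromNameAndType("session-store", QueryableStoreTypes.<String, Long>sessionStore()));
 + * }</pre>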
+ * + * @param key type of the store + * @param value type of the store + * @return {@link QueryableStoreTypes.SessionStoreType} + */ + public static QueryableStoreType> sessionStore() { + return new SessionStoreType<>(); + } + + private static abstract class QueryableStoreTypeMatcher implements QueryableStoreType { + + private final Set matchTo; + + QueryableStoreTypeMatcher(final Set matchTo) { + this.matchTo = matchTo; + } + + @SuppressWarnings("unchecked") + @Override + public boolean accepts(final StateStore stateStore) { + for (final Class matchToClass : matchTo) { + if (!matchToClass.isAssignableFrom(stateStore.getClass())) { + return false; + } + } + return true; + } + } + + public static class KeyValueStoreType extends QueryableStoreTypeMatcher> { + + KeyValueStoreType() { + super(Collections.singleton(ReadOnlyKeyValueStore.class)); + } + + @Override + public ReadOnlyKeyValueStore create(final StateStoreProvider storeProvider, + final String storeName) { + return new CompositeReadOnlyKeyValueStore<>(storeProvider, this, storeName); + } + + } + + private static class TimestampedKeyValueStoreType + extends QueryableStoreTypeMatcher>> { + + TimestampedKeyValueStoreType() { + super(new HashSet<>(Arrays.asList( + TimestampedKeyValueStore.class, + ReadOnlyKeyValueStore.class))); + } + + @Override + public ReadOnlyKeyValueStore> create(final StateStoreProvider storeProvider, + final String storeName) { + return new CompositeReadOnlyKeyValueStore<>(storeProvider, this, storeName); + } + } + + public static class WindowStoreType extends QueryableStoreTypeMatcher> { + + WindowStoreType() { + super(Collections.singleton(ReadOnlyWindowStore.class)); + } + + @Override + public ReadOnlyWindowStore create(final StateStoreProvider storeProvider, + final String storeName) { + return new CompositeReadOnlyWindowStore<>(storeProvider, this, storeName); + } + } + + private static class TimestampedWindowStoreType + extends QueryableStoreTypeMatcher>> { + + TimestampedWindowStoreType() { + super(new HashSet<>(Arrays.asList( + TimestampedWindowStore.class, + ReadOnlyWindowStore.class))); + } + + @Override + public ReadOnlyWindowStore> create(final StateStoreProvider storeProvider, + final String storeName) { + return new CompositeReadOnlyWindowStore<>(storeProvider, this, storeName); + } + } + + public static class SessionStoreType extends QueryableStoreTypeMatcher> { + + SessionStoreType() { + super(Collections.singleton(ReadOnlySessionStore.class)); + } + + @Override + public ReadOnlySessionStore create(final StateStoreProvider storeProvider, + final String storeName) { + return new CompositeReadOnlySessionStore<>(storeProvider, this, storeName); + } + } + +} \ No newline at end of file diff --git a/streams/src/main/java/org/apache/kafka/streams/state/ReadOnlyKeyValueStore.java b/streams/src/main/java/org/apache/kafka/streams/state/ReadOnlyKeyValueStore.java new file mode 100644 index 0000000..1244d0a --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/ReadOnlyKeyValueStore.java @@ -0,0 +1,135 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.state; + +import org.apache.kafka.common.serialization.Serializer; +import org.apache.kafka.streams.errors.InvalidStateStoreException; + +/** + * A key-value store that only supports read operations. + * Implementations should be thread-safe as concurrent reads and writes are expected. + *
<p>
                + * Please note that this contract defines the thread-safe read functionality only; it does not + * guarantee anything about whether the actual instance is writable by another thread, or + * whether it uses some locking mechanism under the hood. For this reason, making dependencies + * between the read and write operations on different StateStore instances can cause concurrency + * problems like deadlock. + * + * @param the key type + * @param the value type + */ +public interface ReadOnlyKeyValueStore { + + /** + * Get the value corresponding to this key. + * + * @param key The key to fetch + * @return The value or null if no value is found. + * @throws NullPointerException If null is used for key. + * @throws InvalidStateStoreException if the store is not initialized + */ + V get(K key); + + /** + * Get an iterator over a given range of keys. This iterator must be closed after use. + * The returned iterator must be safe from {@link java.util.ConcurrentModificationException}s + * and must not return null values. + * Order is not guaranteed as bytes lexicographical ordering might not represent key order. + * + * @param from The first key that could be in the range, where iteration starts from. + * A null value indicates that the range starts with the first element in the store. + * @param to The last key that could be in the range, where iteration ends. + * A null value indicates that the range ends with the last element in the store. + * @return The iterator for this range, from smallest to largest bytes. + * @throws InvalidStateStoreException if the store is not initialized + */ + KeyValueIterator range(K from, K to); + + /** + * Get a reverse iterator over a given range of keys. This iterator must be closed after use. + * The returned iterator must be safe from {@link java.util.ConcurrentModificationException}s + * and must not return null values. + * Order is not guaranteed as bytes lexicographical ordering might not represent key order. + * + * @param from The first key that could be in the range, where iteration ends. + * A null value indicates that the range starts with the first element in the store. + * @param to The last key that could be in the range, where iteration starts from. + * A null value indicates that the range ends with the last element in the store. + * @return The reverse iterator for this range, from largest to smallest key bytes. + * @throws InvalidStateStoreException if the store is not initialized + */ + default KeyValueIterator reverseRange(K from, K to) { + throw new UnsupportedOperationException(); + } + + /** + * Return an iterator over all keys in this store. This iterator must be closed after use. + * The returned iterator must be safe from {@link java.util.ConcurrentModificationException}s + * and must not return null values. + * Order is not guaranteed as bytes lexicographical ordering might not represent key order. + * + * @return An iterator of all key/value pairs in the store, from smallest to largest bytes. + * @throws InvalidStateStoreException if the store is not initialized + */ + KeyValueIterator all(); + + /** + * Return a reverse iterator over all keys in this store. This iterator must be closed after use. + * The returned iterator must be safe from {@link java.util.ConcurrentModificationException}s + * and must not return null values. + * Order is not guaranteed as bytes lexicographical ordering might not represent key order. 
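 + * <p>
 + * For illustration, the returned iterator should be closed, e.g. with try-with-resources; {@code store} is an
 + * assumed {@code ReadOnlyKeyValueStore<String, Long>}:
 + * <pre>{@code
 + * try (KeyValueIterator<String, Long> iter = store.reverseAll()) {
 + *     while (iter.hasNext()) {
 + *         KeyValue<String, Long> entry = iter.next();
 + *         // consume entry.key and entry.value, largest key bytes first
 + *     }
 + * }
 + * }</pre>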
 + *
 + * @return A reverse iterator of all key/value pairs in the store, from largest to smallest key bytes.
 + * @throws InvalidStateStoreException if the store is not initialized
 + */
 + default KeyValueIterator<K, V> reverseAll() {
 +     throw new UnsupportedOperationException();
 + }
 +
 + /**
 + * Return an iterator over all keys with the specified prefix.
 + * Since the type of the prefix can be different from that of the key, a serializer to convert the
 + * prefix into the format in which the keys are stored in the stores needs to be passed to this method.
 + * The returned iterator must be safe from {@link java.util.ConcurrentModificationException}s
 + * and must not return null values.
 + * Since {@code prefixScan()} relies on byte lexicographical ordering and not on the ordering of the key type, results for some types might be unexpected.
 + * For example, if the key type is {@code Integer}, and the store contains keys [1, 2, 11, 13],
 + * then running {@code store.prefixScan(1, new IntegerSerializer())} will return [1] and not [1,11,13].
 + * In contrast, if the key type is {@code String} the keys will be sorted [1, 11, 13, 2] in the store and {@code store.prefixScan(1, new StringSerializer())} will return [1,11,13].
 + * In both cases {@code prefixScan()} starts the scan at 1 and stops at 2.
 + *
 + * @param prefix The prefix.
 + * @param prefixKeySerializer Serializer for the Prefix key type
 + * @param <PS> Prefix Serializer type
 + * @param <P> Prefix Type.
 + * @return The iterator for keys having the specified prefix.
 + */
 + default <PS extends Serializer<P>, P> KeyValueIterator<K, V> prefixScan(P prefix, PS prefixKeySerializer) {
 +     throw new UnsupportedOperationException();
 + }
 +
 + /**
 + * Return an approximate count of key-value mappings in this store.
 + *
<p>
                + * The count is not guaranteed to be exact in order to accommodate stores + * where an exact count is expensive to calculate. + * + * @return an approximate count of key-value mappings in the store. + * @throws InvalidStateStoreException if the store is not initialized + */ + long approximateNumEntries(); +} diff --git a/streams/src/main/java/org/apache/kafka/streams/state/ReadOnlySessionStore.java b/streams/src/main/java/org/apache/kafka/streams/state/ReadOnlySessionStore.java new file mode 100644 index 0000000..049c1a4 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/ReadOnlySessionStore.java @@ -0,0 +1,322 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.state; + + +import org.apache.kafka.streams.kstream.Windowed; + +import java.time.Instant; + +/** + * A session store that only supports read operations. Implementations should be thread-safe as + * concurrent reads and writes are expected. + * + * @param the key type + * @param the aggregated value type + */ +public interface ReadOnlySessionStore { + + /** + * Fetch any sessions with the matching key and the sessions end is ≥ earliestSessionEndTime + * and the sessions start is ≤ latestSessionStartTime iterating from earliest to latest. + *
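 + * <p>
 + * For illustration ({@code sessionStore} is an assumed {@code ReadOnlySessionStore<String, Long>}; the key and
 + * timestamps are placeholders):
 + * <pre>{@code
 + * try (KeyValueIterator<Windowed<String>, Long> iter =
 + *          sessionStore.findSessions("alice", 0L, System.currentTimeMillis())) {
 + *     while (iter.hasNext()) {
 + *         KeyValue<Windowed<String>, Long> session = iter.next();
 + *         // session.key.window() gives the session bounds, session.value the aggregate
 + *     }
 + * }
 + * }</pre>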
<p>
                + * This iterator must be closed after use. + * + * @param key the key to return sessions for + * @param earliestSessionEndTime the end timestamp of the earliest session to search for, where + * iteration starts. + * @param latestSessionStartTime the end timestamp of the latest session to search for, where + * iteration ends. + * @return iterator of sessions with the matching key and aggregated values, from earliest to + * latest session time. + * @throws NullPointerException If null is used for key. + */ + default KeyValueIterator, AGG> findSessions(final K key, + final long earliestSessionEndTime, + final long latestSessionStartTime) { + throw new UnsupportedOperationException( + "This API is not supported by this implementation of ReadOnlySessionStore."); + } + + /** + * Fetch any sessions with the matching key and the sessions end is ≥ earliestSessionEndTime + * and the sessions start is ≤ latestSessionStartTime iterating from earliest to latest. + *
<p>
                + * This iterator must be closed after use. + * + * @param key the key to return sessions for + * @param earliestSessionEndTime the end timestamp of the earliest session to search for, where + * iteration starts. + * @param latestSessionStartTime the end timestamp of the latest session to search for, where + * iteration ends. + * @return iterator of sessions with the matching key and aggregated values, from earliest to + * latest session time. + * @throws NullPointerException If null is used for key. + */ + default KeyValueIterator, AGG> findSessions(final K key, + final Instant earliestSessionEndTime, + final Instant latestSessionStartTime) { + throw new UnsupportedOperationException( + "This API is not supported by this implementation of ReadOnlySessionStore."); + } + + /** + * Fetch any sessions with the matching key and the sessions end is ≥ earliestSessionEndTime + * and the sessions start is ≤ latestSessionStartTime iterating from latest to earliest. + *
<p>
                + * This iterator must be closed after use. + * + * @param key the key to return sessions for + * @param earliestSessionEndTime the end timestamp of the earliest session to search for, where + * iteration ends. + * @param latestSessionStartTime the end timestamp of the latest session to search for, where + * iteration starts. + * @return backward iterator of sessions with the matching key and aggregated values, from + * latest to earliest session time. + * @throws NullPointerException If null is used for key. + */ + default KeyValueIterator, AGG> backwardFindSessions(final K key, + final long earliestSessionEndTime, + final long latestSessionStartTime) { + throw new UnsupportedOperationException( + "This API is not supported by this implementation of ReadOnlySessionStore."); + } + + /** + * Fetch any sessions with the matching key and the sessions end is ≥ earliestSessionEndTime + * and the sessions start is ≤ latestSessionStartTime iterating from latest to earliest. + *
<p>
                + * This iterator must be closed after use. + * + * @param key the key to return sessions for + * @param earliestSessionEndTime the end timestamp of the earliest session to search for, where + * iteration ends. + * @param latestSessionStartTime the end timestamp of the latest session to search for, where + * iteration starts. + * @return backward iterator of sessions with the matching key and aggregated values, from + * latest to earliest session time. + * @throws NullPointerException If null is used for key. + */ + default KeyValueIterator, AGG> backwardFindSessions(final K key, + final Instant earliestSessionEndTime, + final Instant latestSessionStartTime) { + throw new UnsupportedOperationException( + "This API is not supported by this implementation of ReadOnlySessionStore."); + } + + /** + * Fetch any sessions in the given range of keys and the sessions end is ≥ + * earliestSessionEndTime and the sessions start is ≤ latestSessionStartTime iterating from + * earliest to latest. + *

                + * This iterator must be closed after use. + * + * @param keyFrom The first key that could be in the range + * A null value indicates a starting position from the first element in the store. + * @param keyTo The last key that could be in the range + * A null value indicates that the range ends with the last element in the store. + * @param earliestSessionEndTime the end timestamp of the earliest session to search for, where + * iteration starts. + * @param latestSessionStartTime the end timestamp of the latest session to search for, where + * iteration ends. + * @return iterator of sessions with the matching keys and aggregated values, from earliest to + * latest session time. + */ + default KeyValueIterator, AGG> findSessions(final K keyFrom, + final K keyTo, + final long earliestSessionEndTime, + final long latestSessionStartTime) { + throw new UnsupportedOperationException( + "This API is not supported by this implementation of ReadOnlySessionStore."); + } + + /** + * Fetch any sessions in the given range of keys and the sessions end is ≥ + * earliestSessionEndTime and the sessions start is ≤ latestSessionStartTime iterating from + * earliest to latest. + *

                + * This iterator must be closed after use. + * + * @param keyFrom The first key that could be in the range + * A null value indicates a starting position from the first element in the store. + * @param keyTo The last key that could be in the range + * A null value indicates that the range ends with the last element in the store. + * @param earliestSessionEndTime the end timestamp of the earliest session to search for, where + * iteration starts. + * @param latestSessionStartTime the end timestamp of the latest session to search for, where + * iteration ends. + * @return iterator of sessions with the matching keys and aggregated values, from earliest to + * latest session time. + */ + default KeyValueIterator, AGG> findSessions(final K keyFrom, + final K keyTo, + final Instant earliestSessionEndTime, + final Instant latestSessionStartTime) { + throw new UnsupportedOperationException( + "This API is not supported by this implementation of ReadOnlySessionStore."); + } + + /** + * Fetch any sessions in the given range of keys and the sessions end is ≥ + * earliestSessionEndTime and the sessions start is ≤ latestSessionStartTime iterating from + * latest to earliest. + *

                + * This iterator must be closed after use. + * + * @param keyFrom The first key that could be in the range + * A null value indicates a starting position from the first element in the store. + * @param keyTo The last key that could be in the range + * A null value indicates that the range ends with the last element in the store. + * @param earliestSessionEndTime the end timestamp of the earliest session to search for, where + * iteration ends. + * @param latestSessionStartTime the end timestamp of the latest session to search for, where + * iteration starts. + * @return backward iterator of sessions with the matching keys and aggregated values, from + * latest to earliest session time. + */ + default KeyValueIterator, AGG> backwardFindSessions(final K keyFrom, + final K keyTo, + final long earliestSessionEndTime, + final long latestSessionStartTime) { + throw new UnsupportedOperationException( + "This API is not supported by this implementation of ReadOnlySessionStore."); + } + + /** + * Fetch any sessions in the given range of keys and the sessions end is ≥ + * earliestSessionEndTime and the sessions start is ≤ latestSessionStartTime iterating from + * latest to earliest. + *

                + * This iterator must be closed after use. + * + * @param keyFrom The first key that could be in the range + * A null value indicates a starting position from the first element in the store. + * @param keyTo The last key that could be in the range + * A null value indicates that the range ends with the last element in the store. + * @param earliestSessionEndTime the end timestamp of the earliest session to search for, where + * iteration ends. + * @param latestSessionStartTime the end timestamp of the latest session to search for, where + * iteration starts. + * @return backward iterator of sessions with the matching keys and aggregated values, from + * latest to earliest session time. + */ + default KeyValueIterator, AGG> backwardFindSessions(final K keyFrom, + final K keyTo, + final Instant earliestSessionEndTime, + final Instant latestSessionStartTime) { + throw new UnsupportedOperationException( + "This API is not supported by this implementation of ReadOnlySessionStore."); + } + + /** + * Get the value of key from a single session. + * + * @param key the key to fetch + * @param earliestSessionEndTime start timestamp of the session + * @param latestSessionStartTime end timestamp of the session + * @return The value or {@code null} if no session associated with the key can be found + * @throws NullPointerException If {@code null} is used for any key. + */ + default AGG fetchSession(final K key, + final long earliestSessionEndTime, + final long latestSessionStartTime) { + throw new UnsupportedOperationException( + "This API is not supported by this implementation of ReadOnlySessionStore."); + } + + /** + * Get the value of key from a single session. + * + * @param key the key to fetch + * @param earliestSessionEndTime start timestamp of the session + * @param latestSessionStartTime end timestamp of the session + * @return The value or {@code null} if no session associated with the key can be found + * @throws NullPointerException If {@code null} is used for any key. + */ + default AGG fetchSession(final K key, + final Instant earliestSessionEndTime, + final Instant latestSessionStartTime) { + throw new UnsupportedOperationException( + "This API is not supported by this implementation of ReadOnlySessionStore."); + } + + /** + * Retrieve all aggregated sessions for the provided key. This iterator must be closed after + * use. + *

                + * For each key, the iterator guarantees ordering of sessions, starting from the oldest/earliest + * available session to the newest/latest session. + * + * @param key record key to find aggregated session values for + * @return KeyValueIterator containing all sessions for the provided key, from oldest to newest + * session. + * @throws NullPointerException If null is used for key. + */ + KeyValueIterator, AGG> fetch(final K key); + + /** + * Retrieve all aggregated sessions for the provided key. This iterator must be closed after + * use. + *
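The per-key fetch declared above is what interactive queries typically expose. A minimal sketch of querying it from outside the topology, assuming a running KafkaStreams instance, a session store registered under the hypothetical name "user-sessions", and String/Long key and aggregate types (all assumptions, not part of this interface):

import org.apache.kafka.streams.KafkaStreams;
import org.apache.kafka.streams.KeyValue;
import org.apache.kafka.streams.StoreQueryParameters;
import org.apache.kafka.streams.kstream.Windowed;
import org.apache.kafka.streams.state.KeyValueIterator;
import org.apache.kafka.streams.state.QueryableStoreTypes;
import org.apache.kafka.streams.state.ReadOnlySessionStore;

public class SessionStoreQuerySketch {

    // Prints every session recorded for the given key, oldest first.
    static void printSessions(final KafkaStreams streams, final String key) {
        final ReadOnlySessionStore<String, Long> store = streams.store(
            StoreQueryParameters.fromNameAndType("user-sessions",
                QueryableStoreTypes.<String, Long>sessionStore()));

        // fetch(key) returns an iterator that must be closed; try-with-resources handles that.
        try (final KeyValueIterator<Windowed<String>, Long> sessions = store.fetch(key)) {
            while (sessions.hasNext()) {
                final KeyValue<Windowed<String>, Long> session = sessions.next();
                System.out.println(session.key.window().start() + "-" + session.key.window().end()
                    + " => " + session.value);
            }
        }
    }
}

The try-with-resources block satisfies the "must be closed after use" requirement stated throughout this interface.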

                + * For each key, the iterator guarantees ordering of sessions, starting from the newest/latest + * available session to the oldest/earliest session. + * + * @param key record key to find aggregated session values for + * @return backward KeyValueIterator containing all sessions for the provided key, from newest + * to oldest session. + * @throws NullPointerException If null is used for key. + */ + default KeyValueIterator, AGG> backwardFetch(final K key) { + throw new UnsupportedOperationException( + "This API is not supported by this implementation of ReadOnlySessionStore."); + } + + /** + * Retrieve all aggregated sessions for the given range of keys. This iterator must be closed + * after use. + *

                + * For each key, the iterator guarantees ordering of sessions, starting from the oldest/earliest + * available session to the newest/latest session. + * + * @param keyFrom first key in the range to find aggregated session values for + * A null value indicates a starting position from the first element in the store. + * @param keyTo last key in the range to find aggregated session values for + * A null value indicates that the range ends with the last element in the store. + * @return KeyValueIterator containing all sessions for the provided key, from oldest to newest + * session. + */ + KeyValueIterator, AGG> fetch(final K keyFrom, final K keyTo); + + /** + * Retrieve all aggregated sessions for the given range of keys. This iterator must be closed + * after use. + *

                + * For each key, the iterator guarantees ordering of sessions, starting from the newest/latest + * available session to the oldest/earliest session. + * + * @param keyFrom first key in the range to find aggregated session values for + * A null value indicates a starting position from the first element in the store. + * @param keyTo last key in the range to find aggregated session values for + * A null value indicates that the range ends with the last element in the store. + * @return backward KeyValueIterator containing all sessions for the provided key, from newest + * to oldest session. + */ + default KeyValueIterator, AGG> backwardFetch(final K keyFrom, final K keyTo) { + throw new UnsupportedOperationException( + "This API is not supported by this implementation of ReadOnlySessionStore."); + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/state/ReadOnlyWindowStore.java b/streams/src/main/java/org/apache/kafka/streams/state/ReadOnlyWindowStore.java new file mode 100644 index 0000000..3df170d --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/ReadOnlyWindowStore.java @@ -0,0 +1,210 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.state; + +import org.apache.kafka.streams.errors.InvalidStateStoreException; +import org.apache.kafka.streams.kstream.Windowed; + +import java.time.Instant; + +/** + * A window store that only supports read operations. + * Implementations should be thread-safe as concurrent reads and writes are expected. + * + *

                Note: The current implementation of either forward or backward fetches on range-key-range-time does not + * obey the ordering when there are multiple local stores hosted on that instance. For example, + * if there are two stores from two tasks hosting keys {1,3} and {2,4}, then a range query of key [1,4] + * would return in the order of [1,3,2,4] but not [1,2,3,4] since it is just looping over the stores only. + * + * @param Type of keys + * @param Type of values + */ +public interface ReadOnlyWindowStore { + + /** + * Get the value of key from a window. + * + * @param key the key to fetch + * @param time start timestamp (inclusive) of the window + * @return The value or {@code null} if no value is found in the window + * @throws InvalidStateStoreException if the store is not initialized + * @throws NullPointerException if {@code null} is used for any key. + */ + V fetch(K key, long time); + + /** + * Get all the key-value pairs with the given key and the time range from all the existing windows. + *

                + * This iterator must be closed after use. + *

                + * The time range is inclusive and applies to the starting timestamp of the window. + * For example, if we have the following windows: + *

                +     * +-------------------------------+
                +     * |  key  | start time | end time |
                +     * +-------+------------+----------+
                +     * |   A   |     10     |    20    |
                +     * +-------+------------+----------+
                +     * |   A   |     15     |    25    |
                +     * +-------+------------+----------+
                +     * |   A   |     20     |    30    |
                +     * +-------+------------+----------+
                +     * |   A   |     25     |    35    |
                 +     * +-------+------------+----------+
                +     * 
                + * And we call {@code store.fetch("A", Instant.ofEpochMilli(10), Instant.ofEpochMilli(20))} then the results will contain the first + * three windows from the table above, i.e., all those where 10 <= start time <= 20. + *

                + * For each key, the iterator guarantees ordering of windows, starting from the oldest/earliest + * available window to the newest/latest window. + * + * @param key the key to fetch + * @param timeFrom time range start (inclusive), where iteration starts. + * @param timeTo time range end (inclusive), where iteration ends. + * @return an iterator over key-value pairs {@code }, from beginning to end of time. + * @throws InvalidStateStoreException if the store is not initialized + * @throws NullPointerException if {@code null} is used for key. + * @throws IllegalArgumentException if duration is negative or can't be represented as {@code long milliseconds} + */ + WindowStoreIterator fetch(K key, Instant timeFrom, Instant timeTo) throws IllegalArgumentException; + + /** + * Get all the key-value pairs with the given key and the time range from all the existing windows + * in backward order with respect to time (from end to beginning of time). + *
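To make the range semantics concrete, here is a rough sketch against the example table above, assuming a ReadOnlyWindowStore<String, Long> obtained elsewhere (for example via interactive queries); the store type and the helper name are assumptions:

import org.apache.kafka.streams.KeyValue;
import org.apache.kafka.streams.state.ReadOnlyWindowStore;
import org.apache.kafka.streams.state.WindowStoreIterator;

import java.time.Instant;

public class WindowFetchSketch {

    // Sums the values of all windows of "A" whose start time lies in [10, 20],
    // i.e. the first three windows of the table in the Javadoc above.
    static long sumWindows(final ReadOnlyWindowStore<String, Long> store) {
        long sum = 0L;
        // WindowStoreIterator must be closed; try-with-resources takes care of that.
        try (final WindowStoreIterator<Long> iter =
                 store.fetch("A", Instant.ofEpochMilli(10L), Instant.ofEpochMilli(20L))) {
            while (iter.hasNext()) {
                final KeyValue<Long, Long> next = iter.next(); // key is the window start timestamp
                sum += next.value;
            }
        }
        return sum;
    }
}

With the windows listed above, only the windows starting at 10, 15, and 20 are visited, since the range applies to window start timestamps.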

                + * This iterator must be closed after use. + *

                + * The time range is inclusive and applies to the starting timestamp of the window. + * For example, if we have the following windows: + *

                +     * +-------------------------------+
                +     * |  key  | start time | end time |
                +     * +-------+------------+----------+
                +     * |   A   |     10     |    20    |
                +     * +-------+------------+----------+
                +     * |   A   |     15     |    25    |
                +     * +-------+------------+----------+
                +     * |   A   |     20     |    30    |
                +     * +-------+------------+----------+
                +     * |   A   |     25     |    35    |
                 +     * +-------+------------+----------+
                +     * 
                + * And we call {@code store.backwardFetch("A", Instant.ofEpochMilli(10), Instant.ofEpochMilli(20))} then the + * results will contain the first three windows from the table above in backward order, + * i.e., all those where 10 <= start time <= 20. + *

                + * For each key, the iterator guarantees ordering of windows, starting from the newest/latest + * available window to the oldest/earliest window. + * + * @param key the key to fetch + * @param timeFrom time range start (inclusive), where iteration ends. + * @param timeTo time range end (inclusive), where iteration starts. + * @return an iterator over key-value pairs {@code }, from end to beginning of time. + * @throws InvalidStateStoreException if the store is not initialized + * @throws NullPointerException if {@code null} is used for key. + * @throws IllegalArgumentException if duration is negative or can't be represented as {@code long milliseconds} + */ + default WindowStoreIterator backwardFetch(K key, Instant timeFrom, Instant timeTo) throws IllegalArgumentException { + throw new UnsupportedOperationException(); + } + + /** + * Get all the key-value pairs in the given key range and time range from all the existing windows. + *

                + * This iterator must be closed after use. + * + * @param keyFrom the first key in the range + * A null value indicates a starting position from the first element in the store. + * @param keyTo the last key in the range + * A null value indicates that the range ends with the last element in the store. + * @param timeFrom time range start (inclusive), where iteration starts. + * @param timeTo time range end (inclusive), where iteration ends. + * @return an iterator over windowed key-value pairs {@code , value>}, from beginning to end of time. + * @throws InvalidStateStoreException if the store is not initialized + * @throws IllegalArgumentException if duration is negative or can't be represented as {@code long milliseconds} + */ + KeyValueIterator, V> fetch(K keyFrom, K keyTo, Instant timeFrom, Instant timeTo) + throws IllegalArgumentException; + + /** + * Get all the key-value pairs in the given key range and time range from all the existing windows + * in backward order with respect to time (from end to beginning of time). + *
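The key-range variant defined above returns Windowed keys rather than plain keys. A minimal sketch, assuming a ReadOnlyWindowStore<Integer, String> (an assumption); note the class-level caveat that ordering across keys is not guaranteed when several local stores serve the range:

import org.apache.kafka.streams.KeyValue;
import org.apache.kafka.streams.kstream.Windowed;
import org.apache.kafka.streams.state.KeyValueIterator;
import org.apache.kafka.streams.state.ReadOnlyWindowStore;

import java.time.Duration;
import java.time.Instant;

public class WindowRangeFetchSketch {

    // Prints every window for keys in [keyFrom, keyTo] that started in the last hour.
    static void printRange(final ReadOnlyWindowStore<Integer, String> store,
                           final Integer keyFrom,
                           final Integer keyTo) {
        final Instant now = Instant.now();
        try (final KeyValueIterator<Windowed<Integer>, String> iter =
                 store.fetch(keyFrom, keyTo, now.minus(Duration.ofHours(1)), now)) {
            while (iter.hasNext()) {
                final KeyValue<Windowed<Integer>, String> entry = iter.next();
                System.out.println(entry.key.key() + "@" + entry.key.window().start()
                    + " = " + entry.value);
            }
        }
    }
}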

                + * This iterator must be closed after use. + * + * @param keyFrom the first key in the range + * A null value indicates a starting position from the first element in the store. + * @param keyTo the last key in the range + * A null value indicates that the range ends with the last element in the store. + * @param timeFrom time range start (inclusive), where iteration ends. + * @param timeTo time range end (inclusive), where iteration starts. + * @return an iterator over windowed key-value pairs {@code , value>}, from end to beginning of time. + * @throws InvalidStateStoreException if the store is not initialized + * @throws IllegalArgumentException if duration is negative or can't be represented as {@code long milliseconds} + */ + default KeyValueIterator, V> backwardFetch(K keyFrom, K keyTo, Instant timeFrom, Instant timeTo) + throws IllegalArgumentException { + throw new UnsupportedOperationException(); + } + + + /** + * Gets all the key-value pairs in the existing windows. + * + * @return an iterator over windowed key-value pairs {@code , value>}, from beginning to end of time. + * @throws InvalidStateStoreException if the store is not initialized + */ + KeyValueIterator, V> all(); + + /** + * Gets all the key-value pairs in the existing windows in backward order + * with respect to time (from end to beginning of time). + * + * @return an backward iterator over windowed key-value pairs {@code , value>}, from the end to beginning of time. + * @throws InvalidStateStoreException if the store is not initialized + */ + default KeyValueIterator, V> backwardAll() { + throw new UnsupportedOperationException(); + } + + /** + * Gets all the key-value pairs that belong to the windows within in the given time range. + * + * @param timeFrom the beginning of the time slot from which to search (inclusive), where iteration starts. + * @param timeTo the end of the time slot from which to search (inclusive), where iteration ends. + * @return an iterator over windowed key-value pairs {@code , value>}, from beginning to end of time. + * @throws InvalidStateStoreException if the store is not initialized + * @throws NullPointerException if {@code null} is used for any key + * @throws IllegalArgumentException if duration is negative or can't be represented as {@code long milliseconds} + */ + KeyValueIterator, V> fetchAll(Instant timeFrom, Instant timeTo) throws IllegalArgumentException; + + /** + * Gets all the key-value pairs that belong to the windows within in the given time range in backward order + * with respect to time (from end to beginning of time). + * + * @param timeFrom the beginning of the time slot from which to search (inclusive), where iteration ends. + * @param timeTo the end of the time slot from which to search (inclusive), where iteration starts. + * @return an backward iterator over windowed key-value pairs {@code , value>}, from end to beginning of time. 
+ * @throws InvalidStateStoreException if the store is not initialized + * @throws NullPointerException if {@code null} is used for any key + * @throws IllegalArgumentException if duration is negative or can't be represented as {@code long milliseconds} + */ + default KeyValueIterator, V> backwardFetchAll(Instant timeFrom, Instant timeTo) throws IllegalArgumentException { + throw new UnsupportedOperationException(); + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/state/RocksDBConfigSetter.java b/streams/src/main/java/org/apache/kafka/streams/state/RocksDBConfigSetter.java new file mode 100644 index 0000000..7da3d8f --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/RocksDBConfigSetter.java @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.state; + +import org.rocksdb.Options; + +import java.util.Map; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * An interface to that allows developers to customize the RocksDB settings for a given Store. + * Please read the RocksDB Tuning Guide. + * + * Note: if you choose to modify the {@code org.rocksdb.BlockBasedTableConfig} you should retrieve a reference to + * the existing one (rather than create a new BlockBasedTableConfig object) so as to not lose the other default settings. + * This can be done as {@code BlockBasedTableConfig tableConfig = (BlockBasedTableConfig) options.tableFormatConfig();} + */ +public interface RocksDBConfigSetter { + + Logger LOG = LoggerFactory.getLogger(RocksDBConfigSetter.class); + + /** + * Set the rocks db options for the provided storeName. + * + * @param storeName the name of the store being configured + * @param options the RocksDB options + * @param configs the configuration supplied to {@link org.apache.kafka.streams.StreamsConfig} + */ + void setConfig(final String storeName, final Options options, final Map configs); + + /** + * Close any user-constructed objects that inherit from {@code org.rocksdb.RocksObject}. + *

                + * Any object created with {@code new} in {@link RocksDBConfigSetter#setConfig setConfig()} and that inherits + * from {@code org.rocksdb.RocksObject} should have {@code org.rocksdb.RocksObject#close()} + * called on it here to avoid leaking off-heap memory. Objects to be closed can be saved by the user or retrieved + * back from {@code options} using its getter methods. + *
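A sketch of a config setter that follows this guidance: it reuses the existing BlockBasedTableConfig, installs a block cache that it created itself, and closes that cache in close(). The cache size is illustrative only:

import org.apache.kafka.streams.state.RocksDBConfigSetter;
import org.rocksdb.BlockBasedTableConfig;
import org.rocksdb.Cache;
import org.rocksdb.LRUCache;
import org.rocksdb.Options;

import java.util.Map;

public class CustomRocksDBConfig implements RocksDBConfigSetter {

    // Created in setConfig(), so it must be closed in close() to free off-heap memory.
    private Cache cache;

    @Override
    public void setConfig(final String storeName, final Options options, final Map<String, Object> configs) {
        // Reuse the existing table config instead of creating a new one, as advised above.
        final BlockBasedTableConfig tableConfig = (BlockBasedTableConfig) options.tableFormatConfig();
        cache = new LRUCache(16 * 1024L * 1024L); // 16 MiB block cache; illustrative value
        tableConfig.setBlockCache(cache);
        options.setTableFormatConfig(tableConfig);
    }

    @Override
    public void close(final String storeName, final Options options) {
        // The cache is not owned by the Options object, so close it here.
        if (cache != null) {
            cache.close();
        }
    }
}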

                + * Example objects needing to be closed include {@code org.rocksdb.Filter} and {@code org.rocksdb.Cache}. + * + * @param storeName the name of the store being configured + * @param options the RocksDB options + */ + void close(final String storeName, final Options options); +} diff --git a/streams/src/main/java/org/apache/kafka/streams/state/SessionBytesStoreSupplier.java b/streams/src/main/java/org/apache/kafka/streams/state/SessionBytesStoreSupplier.java new file mode 100644 index 0000000..5c7bc25 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/SessionBytesStoreSupplier.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.state; + +import org.apache.kafka.common.utils.Bytes; + +/** + * A store supplier that can be used to create one or more {@link SessionStore SessionStore<Byte, byte[]>} instances. + * + * For any stores implementing the {@link SessionStore SessionStore<Byte, byte[]>} interface, {@code null} value + * bytes are considered as "not exist". This means: + *

                 + * 1. {@code null} value bytes in put operations should be treated as delete.
                 + * 2. {@code null} value bytes should never be returned in range query results.
                + */ +public interface SessionBytesStoreSupplier extends StoreSupplier> { + + /** + * The size of a segment, in milliseconds. Used when caching is enabled to segment the cache + * and reduce the amount of data that needs to be scanned when performing range queries. + * + * @return segmentInterval in milliseconds + */ + long segmentIntervalMs(); + + /** + * The time period for which the {@link SessionStore} will retain historic data. + * + * @return retentionPeriod + */ + long retentionPeriod(); +} \ No newline at end of file diff --git a/streams/src/main/java/org/apache/kafka/streams/state/SessionStore.java b/streams/src/main/java/org/apache/kafka/streams/state/SessionStore.java new file mode 100644 index 0000000..926cddc --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/SessionStore.java @@ -0,0 +1,117 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.state; + +import org.apache.kafka.streams.internals.ApiUtils; +import org.apache.kafka.streams.kstream.Window; +import org.apache.kafka.streams.kstream.Windowed; +import org.apache.kafka.streams.processor.StateStore; + +import java.time.Instant; + +import static org.apache.kafka.streams.internals.ApiUtils.prepareMillisCheckFailMsgPrefix; + +/** + * Interface for storing the aggregated values of sessions. + *

                + * The key is internally represented as {@link Windowed Windowed<K>} that comprises the plain + * key and the {@link Window} that represents window start- and end-timestamp. + *
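Because the windowed key carries both timestamps, the merge rule described just below (delete the old sessions, then insert the merged one) can be sketched roughly as follows. This is an illustration only: it assumes a SessionStore<String, Long> with an additive aggregate, and it uses the internal SessionWindow class to build the merged window, which is an implementation detail rather than public API:

import org.apache.kafka.streams.KeyValue;
import org.apache.kafka.streams.kstream.Windowed;
import org.apache.kafka.streams.kstream.internals.SessionWindow;
import org.apache.kafka.streams.state.KeyValueIterator;
import org.apache.kafka.streams.state.SessionStore;

import java.util.ArrayList;
import java.util.List;

public class SessionMergeSketch {

    // Collapses all sessions of `key` overlapping [earliestEnd, latestStart] into one session.
    static void merge(final SessionStore<String, Long> store,
                      final String key,
                      final long earliestEnd,
                      final long latestStart) {
        long mergedStart = Long.MAX_VALUE;
        long mergedEnd = Long.MIN_VALUE;
        long mergedAggregate = 0L;
        final List<Windowed<String>> oldSessions = new ArrayList<>();

        try (final KeyValueIterator<Windowed<String>, Long> iter =
                 store.findSessions(key, earliestEnd, latestStart)) {
            while (iter.hasNext()) {
                final KeyValue<Windowed<String>, Long> session = iter.next();
                mergedStart = Math.min(mergedStart, session.key.window().start());
                mergedEnd = Math.max(mergedEnd, session.key.window().end());
                mergedAggregate += session.value;
                oldSessions.add(session.key);
            }
        }

        for (final Windowed<String> oldKey : oldSessions) {
            store.remove(oldKey);                     // the old sessions must be deleted ...
        }
        if (!oldSessions.isEmpty()) {
            store.put(new Windowed<>(key, new SessionWindow(mergedStart, mergedEnd)),
                      mergedAggregate);               // ... and the merged session inserted
        }
    }
}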

                + * If two sessions are merged, a new session with new start- and end-timestamp must be inserted into + * the store while the two old sessions must be deleted. + * + * @param type of the record keys + * @param type of the aggregated values + */ +public interface SessionStore extends StateStore, ReadOnlySessionStore { + + @Override + default KeyValueIterator, AGG> findSessions(final K key, + final Instant earliestSessionEndTime, + final Instant latestSessionStartTime) { + return findSessions( + key, + ApiUtils.validateMillisecondInstant(earliestSessionEndTime, + prepareMillisCheckFailMsgPrefix(earliestSessionEndTime, "earliestSessionEndTime")), + ApiUtils.validateMillisecondInstant(latestSessionStartTime, + prepareMillisCheckFailMsgPrefix(latestSessionStartTime, "latestSessionStartTime"))); + } + + @Override + default KeyValueIterator, AGG> backwardFindSessions(final K key, + final Instant earliestSessionEndTime, + final Instant latestSessionStartTime) { + return backwardFindSessions( + key, + ApiUtils.validateMillisecondInstant(earliestSessionEndTime, + prepareMillisCheckFailMsgPrefix(earliestSessionEndTime, "earliestSessionEndTime")), + ApiUtils.validateMillisecondInstant(latestSessionStartTime, + prepareMillisCheckFailMsgPrefix(latestSessionStartTime, "latestSessionStartTime"))); + } + + default KeyValueIterator, AGG> findSessions(final K keyFrom, + final K keyTo, + final Instant earliestSessionEndTime, + final Instant latestSessionStartTime) { + return findSessions( + keyFrom, + keyTo, + ApiUtils.validateMillisecondInstant(earliestSessionEndTime, + prepareMillisCheckFailMsgPrefix(earliestSessionEndTime, "earliestSessionEndTime")), + ApiUtils.validateMillisecondInstant(latestSessionStartTime, + prepareMillisCheckFailMsgPrefix(latestSessionStartTime, "latestSessionStartTime"))); + } + + default KeyValueIterator, AGG> backwardFindSessions(final K keyFrom, + final K keyTo, + final Instant earliestSessionEndTime, + final Instant latestSessionStartTime) { + return backwardFindSessions( + keyFrom, + keyTo, + ApiUtils.validateMillisecondInstant(earliestSessionEndTime, + prepareMillisCheckFailMsgPrefix(earliestSessionEndTime, "earliestSessionEndTime")), + ApiUtils.validateMillisecondInstant(latestSessionStartTime, + prepareMillisCheckFailMsgPrefix(latestSessionStartTime, "latestSessionStartTime"))); + } + + default AGG fetchSession(final K key, final Instant earliestSessionEndTime, final Instant latestSessionStartTime) { + return fetchSession(key, + ApiUtils.validateMillisecondInstant(earliestSessionEndTime, + prepareMillisCheckFailMsgPrefix(earliestSessionEndTime, "startTime")), + ApiUtils.validateMillisecondInstant(latestSessionStartTime, + prepareMillisCheckFailMsgPrefix(latestSessionStartTime, "endTime"))); + } + + /** + * Remove the session aggregated with provided {@link Windowed} key from the store + * + * @param sessionKey key of the session to remove + * @throws NullPointerException If null is used for sessionKey. + */ + void remove(final Windowed sessionKey); + + /** + * Write the aggregated value for the provided key to the store + * + * @param sessionKey key of the session to write + * @param aggregate the aggregated value for the session, it can be null; if the serialized + * bytes are also null it is interpreted as deletes + * @throws NullPointerException If null is used for sessionKey. 
+ */ + void put(final Windowed sessionKey, final AGG aggregate); +} diff --git a/streams/src/main/java/org/apache/kafka/streams/state/StateSerdes.java b/streams/src/main/java/org/apache/kafka/streams/state/StateSerdes.java new file mode 100644 index 0000000..f9f0bdc --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/StateSerdes.java @@ -0,0 +1,215 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.state; + +import org.apache.kafka.common.serialization.Deserializer; +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.serialization.Serializer; +import org.apache.kafka.streams.errors.StreamsException; +import org.apache.kafka.streams.state.internals.ValueAndTimestampSerializer; + +import java.util.Objects; + +/** + * Factory for creating serializers / deserializers for state stores in Kafka Streams. + * + * @param key type of serde + * @param value type of serde + */ +public final class StateSerdes { + + public static final int TIMESTAMP_SIZE = 8; + public static final int BOOLEAN_SIZE = 1; + + /** + * Create a new instance of {@link StateSerdes} for the given state name and key-/value-type classes. + * + * @param topic the topic name + * @param keyClass the class of the key type + * @param valueClass the class of the value type + * @param the key type + * @param the value type + * @return a new instance of {@link StateSerdes} + */ + public static StateSerdes withBuiltinTypes( + final String topic, + final Class keyClass, + final Class valueClass) { + return new StateSerdes<>(topic, Serdes.serdeFrom(keyClass), Serdes.serdeFrom(valueClass)); + } + + private final String topic; + private final Serde keySerde; + private final Serde valueSerde; + + /** + * Create a context for serialization using the specified serializers and deserializers which + * must match the key and value types used as parameters for this object; the state changelog topic + * is provided to bind this serde factory to, so that future calls for serialize / deserialize do not + * need to provide the topic name any more. + * + * @param topic the topic name + * @param keySerde the serde for keys; cannot be null + * @param valueSerde the serde for values; cannot be null + * @throws IllegalArgumentException if key or value serde is null + */ + public StateSerdes(final String topic, + final Serde keySerde, + final Serde valueSerde) { + Objects.requireNonNull(topic, "topic cannot be null"); + Objects.requireNonNull(keySerde, "key serde cannot be null"); + Objects.requireNonNull(valueSerde, "value serde cannot be null"); + + this.topic = topic; + this.keySerde = keySerde; + this.valueSerde = valueSerde; + } + + /** + * Return the key serde. 
+ * + * @return the key serde + */ + public Serde keySerde() { + return keySerde; + } + + /** + * Return the value serde. + * + * @return the value serde + */ + public Serde valueSerde() { + return valueSerde; + } + + /** + * Return the key deserializer. + * + * @return the key deserializer + */ + public Deserializer keyDeserializer() { + return keySerde.deserializer(); + } + + /** + * Return the key serializer. + * + * @return the key serializer + */ + public Serializer keySerializer() { + return keySerde.serializer(); + } + + /** + * Return the value deserializer. + * + * @return the value deserializer + */ + public Deserializer valueDeserializer() { + return valueSerde.deserializer(); + } + + /** + * Return the value serializer. + * + * @return the value serializer + */ + public Serializer valueSerializer() { + return valueSerde.serializer(); + } + + /** + * Return the topic. + * + * @return the topic + */ + public String topic() { + return topic; + } + + /** + * Deserialize the key from raw bytes. + * + * @param rawKey the key as raw bytes + * @return the key as typed object + */ + public K keyFrom(final byte[] rawKey) { + return keySerde.deserializer().deserialize(topic, rawKey); + } + + /** + * Deserialize the value from raw bytes. + * + * @param rawValue the value as raw bytes + * @return the value as typed object + */ + public V valueFrom(final byte[] rawValue) { + return valueSerde.deserializer().deserialize(topic, rawValue); + } + + /** + * Serialize the given key. + * + * @param key the key to be serialized + * @return the serialized key + */ + public byte[] rawKey(final K key) { + try { + return keySerde.serializer().serialize(topic, key); + } catch (final ClassCastException e) { + final String keyClass = key == null ? "unknown because key is null" : key.getClass().getName(); + throw new StreamsException( + String.format("A serializer (%s) is not compatible to the actual key type " + + "(key type: %s). Change the default Serdes in StreamConfig or " + + "provide correct Serdes via method parameters.", + keySerializer().getClass().getName(), + keyClass), + e); + } + } + + /** + * Serialize the given value. + * + * @param value the value to be serialized + * @return the serialized value + */ + public byte[] rawValue(final V value) { + try { + return valueSerde.serializer().serialize(topic, value); + } catch (final ClassCastException e) { + final String valueClass; + final Class serializerClass; + if (valueSerializer() instanceof ValueAndTimestampSerializer) { + serializerClass = ((ValueAndTimestampSerializer) valueSerializer()).valueSerializer.getClass(); + valueClass = value == null ? "unknown because value is null" : ((ValueAndTimestamp) value).value().getClass().getName(); + } else { + serializerClass = valueSerializer().getClass(); + valueClass = value == null ? "unknown because value is null" : value.getClass().getName(); + } + throw new StreamsException( + String.format("A serializer (%s) is not compatible to the actual value type " + + "(value type: %s). 
Change the default Serdes in StreamConfig or " + + "provide correct Serdes via method parameters.", + serializerClass.getName(), + valueClass), + e); + } + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/state/StoreBuilder.java b/streams/src/main/java/org/apache/kafka/streams/state/StoreBuilder.java new file mode 100644 index 0000000..430ba27 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/StoreBuilder.java @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.state; + +import org.apache.kafka.streams.processor.StateStore; + +import java.util.Map; + +/** + * Build a {@link StateStore} wrapped with optional caching and logging. + * @param the type of store to build + */ +public interface StoreBuilder { + + /** + * Enable caching on the store. + * @return this + */ + StoreBuilder withCachingEnabled(); + + /** + * Disable caching on the store. + * @return this + */ + StoreBuilder withCachingDisabled(); + + /** + * Maintain a changelog for any changes made to the store. + * Use the provided config to set the config of the changelog topic. + * @param config config applied to the changelog topic + * @return this + */ + StoreBuilder withLoggingEnabled(final Map config); + + /** + * Disable the changelog for store built by this {@link StoreBuilder}. + * This will turn off fault-tolerance for your store. + * By default the changelog is enabled. + * @return this + */ + StoreBuilder withLoggingDisabled(); + + /** + * Build the store as defined by the builder. + * + * @return the built {@link StateStore} + */ + T build(); + + /** + * Returns a Map containing any log configs that will be used when creating the changelog for the {@link StateStore}. + *

                + * Note: any unrecognized configs will be ignored by the Kafka brokers. + * + * @return Map containing any log configs to be used when creating the changelog for the {@link StateStore} + * If {@code loggingEnabled} returns false, this function will always return an empty map + */ + Map logConfig(); + + /** + * @return {@code true} if the {@link StateStore} should have logging enabled + */ + boolean loggingEnabled(); + + /** + * Return the name of this state store builder. + * This must be a valid Kafka topic name; valid characters are ASCII alphanumerics, '.', '_' and '-'. + * + * @return the name of this state store builder + */ + String name(); + +} diff --git a/streams/src/main/java/org/apache/kafka/streams/state/StoreSupplier.java b/streams/src/main/java/org/apache/kafka/streams/state/StoreSupplier.java new file mode 100644 index 0000000..10e6f2d --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/StoreSupplier.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.state; + +import org.apache.kafka.streams.processor.StateStore; + +/** + * A state store supplier which can create one or more {@link StateStore} instances. + * + * @param State store type + */ +public interface StoreSupplier { + /** + * Return the name of this state store supplier. + * This must be a valid Kafka topic name; valid characters are ASCII alphanumerics, '.', '_' and '-'. + * + * @return the name of this state store supplier + */ + String name(); + + /** + * Return a new {@link StateStore} instance. + * + * @return a new {@link StateStore} instance of type T + */ + T get(); + + /** + * Return a String that is used as the scope for metrics recorded by Metered stores. + * @return metricsScope + */ + String metricsScope(); +} diff --git a/streams/src/main/java/org/apache/kafka/streams/state/Stores.java b/streams/src/main/java/org/apache/kafka/streams/state/Stores.java new file mode 100644 index 0000000..bf4e5aa --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/Stores.java @@ -0,0 +1,466 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.state; + +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.streams.state.internals.InMemoryKeyValueStore; +import org.apache.kafka.streams.state.internals.InMemorySessionBytesStoreSupplier; +import org.apache.kafka.streams.state.internals.InMemoryWindowBytesStoreSupplier; +import org.apache.kafka.streams.state.internals.KeyValueStoreBuilder; +import org.apache.kafka.streams.state.internals.MemoryNavigableLRUCache; +import org.apache.kafka.streams.state.internals.RocksDbKeyValueBytesStoreSupplier; +import org.apache.kafka.streams.state.internals.RocksDbSessionBytesStoreSupplier; +import org.apache.kafka.streams.state.internals.RocksDbWindowBytesStoreSupplier; +import org.apache.kafka.streams.state.internals.SessionStoreBuilder; +import org.apache.kafka.streams.state.internals.TimestampedKeyValueStoreBuilder; +import org.apache.kafka.streams.state.internals.TimestampedWindowStoreBuilder; +import org.apache.kafka.streams.state.internals.WindowStoreBuilder; + +import java.time.Duration; +import java.util.Objects; + +import static org.apache.kafka.streams.internals.ApiUtils.prepareMillisCheckFailMsgPrefix; +import static org.apache.kafka.streams.internals.ApiUtils.validateMillisecondDuration; + +/** + * Factory for creating state stores in Kafka Streams. + *

                + * When using the high-level DSL, i.e., {@link org.apache.kafka.streams.StreamsBuilder StreamsBuilder}, users create + * {@link StoreSupplier}s that can be further customized via + * {@link org.apache.kafka.streams.kstream.Materialized Materialized}. + * For example, a topic read as {@link org.apache.kafka.streams.kstream.KTable KTable} can be materialized into an + * in-memory store with custom key/value serdes and caching disabled: + *

                {@code
                + * StreamsBuilder builder = new StreamsBuilder();
                + * KeyValueBytesStoreSupplier storeSupplier = Stores.inMemoryKeyValueStore("queryable-store-name");
                 + * KTable<Long, String> table = builder.table(
                + *   "topicName",
                 + *   Materialized.<Long, String>as(storeSupplier)
                + *               .withKeySerde(Serdes.Long())
                + *               .withValueSerde(Serdes.String())
                + *               .withCachingDisabled());
                + * }
                + * When using the Processor API, i.e., {@link org.apache.kafka.streams.Topology Topology}, users create + * {@link StoreBuilder}s that can be attached to {@link org.apache.kafka.streams.processor.api.Processor Processor}s. + * For example, you can create a {@link org.apache.kafka.streams.kstream.Windowed windowed} RocksDB store with custom + * changelog topic configuration like: + *
                {@code
                + * Topology topology = new Topology();
                + * topology.addProcessor("processorName", ...);
                + *
                 + * Map<String, String> topicConfig = new HashMap<>();
                 + * StoreBuilder<WindowStore<Integer, Long>> storeBuilder = Stores
                + *   .windowStoreBuilder(
                + *     Stores.persistentWindowStore("queryable-store-name", ...),
                + *     Serdes.Integer(),
                + *     Serdes.Long())
                + *   .withLoggingEnabled(topicConfig);
                + *
                + * topology.addStateStore(storeBuilder, "processorName");
                + * }
                + */ +public final class Stores { + + /** + * Create a persistent {@link KeyValueBytesStoreSupplier}. + *

                + * This store supplier can be passed into a {@link #keyValueStoreBuilder(KeyValueBytesStoreSupplier, Serde, Serde)}. + * If you want to create a {@link TimestampedKeyValueStore} you should use + * {@link #persistentTimestampedKeyValueStore(String)} to create a store supplier instead. + * + * @param name name of the store (cannot be {@code null}) + * @return an instance of a {@link KeyValueBytesStoreSupplier} that can be used + * to build a persistent key-value store + */ + public static KeyValueBytesStoreSupplier persistentKeyValueStore(final String name) { + Objects.requireNonNull(name, "name cannot be null"); + return new RocksDbKeyValueBytesStoreSupplier(name, false); + } + + /** + * Create a persistent {@link KeyValueBytesStoreSupplier}. + *
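A minimal sketch of pairing the supplier returned by persistentKeyValueStore with keyValueStoreBuilder; the store name "counts-store" and the String/Long serdes are placeholders:

import org.apache.kafka.common.serialization.Serdes;
import org.apache.kafka.streams.state.KeyValueStore;
import org.apache.kafka.streams.state.StoreBuilder;
import org.apache.kafka.streams.state.Stores;

public class PersistentKeyValueStoreSketch {

    // Builds a RocksDB-backed key-value StoreBuilder with caching turned on.
    static StoreBuilder<KeyValueStore<String, Long>> countsStoreBuilder() {
        return Stores.keyValueStoreBuilder(
                Stores.persistentKeyValueStore("counts-store"),
                Serdes.String(),
                Serdes.Long())
            .withCachingEnabled();
    }
}

The resulting builder can then be attached to a Topology and connected to processors, as the second class-level example above shows.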

                + * This store supplier can be passed into a + * {@link #timestampedKeyValueStoreBuilder(KeyValueBytesStoreSupplier, Serde, Serde)}. + * If you want to create a {@link KeyValueStore} you should use + * {@link #persistentKeyValueStore(String)} to create a store supplier instead. + * + * @param name name of the store (cannot be {@code null}) + * @return an instance of a {@link KeyValueBytesStoreSupplier} that can be used + * to build a persistent key-(timestamp/value) store + */ + public static KeyValueBytesStoreSupplier persistentTimestampedKeyValueStore(final String name) { + Objects.requireNonNull(name, "name cannot be null"); + return new RocksDbKeyValueBytesStoreSupplier(name, true); + } + + /** + * Create an in-memory {@link KeyValueBytesStoreSupplier}. + *

                + * This store supplier can be passed into a {@link #keyValueStoreBuilder(KeyValueBytesStoreSupplier, Serde, Serde)} + * or {@link #timestampedKeyValueStoreBuilder(KeyValueBytesStoreSupplier, Serde, Serde)}. + * + * @param name name of the store (cannot be {@code null}) + * @return an instance of a {@link KeyValueBytesStoreSupplier} than can be used to + * build an in-memory store + */ + public static KeyValueBytesStoreSupplier inMemoryKeyValueStore(final String name) { + Objects.requireNonNull(name, "name cannot be null"); + return new KeyValueBytesStoreSupplier() { + @Override + public String name() { + return name; + } + + @Override + public KeyValueStore get() { + return new InMemoryKeyValueStore(name); + } + + @Override + public String metricsScope() { + return "in-memory"; + } + }; + } + + /** + * Create a LRU Map {@link KeyValueBytesStoreSupplier}. + *

                + * This store supplier can be passed into a {@link #keyValueStoreBuilder(KeyValueBytesStoreSupplier, Serde, Serde)} + * or {@link #timestampedKeyValueStoreBuilder(KeyValueBytesStoreSupplier, Serde, Serde)}. + * + * @param name name of the store (cannot be {@code null}) + * @param maxCacheSize maximum number of items in the LRU (cannot be negative) + * @return an instance of a {@link KeyValueBytesStoreSupplier} that can be used to build + * an LRU Map based store + * @throws IllegalArgumentException if {@code maxCacheSize} is negative + */ + public static KeyValueBytesStoreSupplier lruMap(final String name, final int maxCacheSize) { + Objects.requireNonNull(name, "name cannot be null"); + if (maxCacheSize < 0) { + throw new IllegalArgumentException("maxCacheSize cannot be negative"); + } + return new KeyValueBytesStoreSupplier() { + @Override + public String name() { + return name; + } + + @Override + public KeyValueStore get() { + return new MemoryNavigableLRUCache(name, maxCacheSize); + } + + @Override + public String metricsScope() { + return "in-memory-lru"; + } + }; + } + + /** + * Create a persistent {@link WindowBytesStoreSupplier}. + *

                + * This store supplier can be passed into a {@link #windowStoreBuilder(WindowBytesStoreSupplier, Serde, Serde)}. + * If you want to create a {@link TimestampedWindowStore} you should use + * {@link #persistentTimestampedWindowStore(String, Duration, Duration, boolean)} to create a store supplier instead. + * + * @param name name of the store (cannot be {@code null}) + * @param retentionPeriod length of time to retain data in the store (cannot be negative) + * (note that the retention period must be at least long enough to contain the + * windowed data's entire life cycle, from window-start through window-end, + * and for the entire grace period) + * @param windowSize size of the windows (cannot be negative) + * @param retainDuplicates whether or not to retain duplicates. Turning this on will automatically disable + * caching and means that null values will be ignored. + * @return an instance of {@link WindowBytesStoreSupplier} + * @throws IllegalArgumentException if {@code retentionPeriod} or {@code windowSize} can't be represented as {@code long milliseconds} + * @throws IllegalArgumentException if {@code retentionPeriod} is smaller than {@code windowSize} + */ + public static WindowBytesStoreSupplier persistentWindowStore(final String name, + final Duration retentionPeriod, + final Duration windowSize, + final boolean retainDuplicates) throws IllegalArgumentException { + return persistentWindowStore(name, retentionPeriod, windowSize, retainDuplicates, false); + } + + /** + * Create a persistent {@link WindowBytesStoreSupplier}. + *
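A minimal sketch of combining persistentWindowStore with windowStoreBuilder; the name and durations are placeholders, and the retention period is chosen to be comfortably larger than the window size, as required above:

import org.apache.kafka.common.serialization.Serdes;
import org.apache.kafka.streams.state.StoreBuilder;
import org.apache.kafka.streams.state.Stores;
import org.apache.kafka.streams.state.WindowStore;

import java.time.Duration;

public class PersistentWindowStoreSketch {

    // One-minute windows kept for one day; retention must cover window size plus grace.
    static StoreBuilder<WindowStore<String, Long>> windowedCountsBuilder() {
        return Stores.windowStoreBuilder(
            Stores.persistentWindowStore(
                "windowed-counts",          // placeholder store name
                Duration.ofDays(1),         // retentionPeriod
                Duration.ofMinutes(1),      // windowSize
                false),                     // retainDuplicates
            Serdes.String(),
            Serdes.Long());
    }
}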

                + * This store supplier can be passed into a + * {@link #timestampedWindowStoreBuilder(WindowBytesStoreSupplier, Serde, Serde)}. + * If you want to create a {@link WindowStore} you should use + * {@link #persistentWindowStore(String, Duration, Duration, boolean)} to create a store supplier instead. + * + * @param name name of the store (cannot be {@code null}) + * @param retentionPeriod length of time to retain data in the store (cannot be negative) + * (note that the retention period must be at least long enough to contain the + * windowed data's entire life cycle, from window-start through window-end, + * and for the entire grace period) + * @param windowSize size of the windows (cannot be negative) + * @param retainDuplicates whether or not to retain duplicates. Turning this on will automatically disable + * caching and means that null values will be ignored. + * @return an instance of {@link WindowBytesStoreSupplier} + * @throws IllegalArgumentException if {@code retentionPeriod} or {@code windowSize} can't be represented as {@code long milliseconds} + * @throws IllegalArgumentException if {@code retentionPeriod} is smaller than {@code windowSize} + */ + public static WindowBytesStoreSupplier persistentTimestampedWindowStore(final String name, + final Duration retentionPeriod, + final Duration windowSize, + final boolean retainDuplicates) throws IllegalArgumentException { + return persistentWindowStore(name, retentionPeriod, windowSize, retainDuplicates, true); + } + + private static WindowBytesStoreSupplier persistentWindowStore(final String name, + final Duration retentionPeriod, + final Duration windowSize, + final boolean retainDuplicates, + final boolean timestampedStore) { + Objects.requireNonNull(name, "name cannot be null"); + final String rpMsgPrefix = prepareMillisCheckFailMsgPrefix(retentionPeriod, "retentionPeriod"); + final long retentionMs = validateMillisecondDuration(retentionPeriod, rpMsgPrefix); + final String wsMsgPrefix = prepareMillisCheckFailMsgPrefix(windowSize, "windowSize"); + final long windowSizeMs = validateMillisecondDuration(windowSize, wsMsgPrefix); + + final long defaultSegmentInterval = Math.max(retentionMs / 2, 60_000L); + + return persistentWindowStore(name, retentionMs, windowSizeMs, retainDuplicates, defaultSegmentInterval, timestampedStore); + } + + private static WindowBytesStoreSupplier persistentWindowStore(final String name, + final long retentionPeriod, + final long windowSize, + final boolean retainDuplicates, + final long segmentInterval, + final boolean timestampedStore) { + Objects.requireNonNull(name, "name cannot be null"); + if (retentionPeriod < 0L) { + throw new IllegalArgumentException("retentionPeriod cannot be negative"); + } + if (windowSize < 0L) { + throw new IllegalArgumentException("windowSize cannot be negative"); + } + if (segmentInterval < 1L) { + throw new IllegalArgumentException("segmentInterval cannot be zero or negative"); + } + if (windowSize > retentionPeriod) { + throw new IllegalArgumentException("The retention period of the window store " + + name + " must be no smaller than its window size. Got size=[" + + windowSize + "], retention=[" + retentionPeriod + "]"); + } + + return new RocksDbWindowBytesStoreSupplier( + name, + retentionPeriod, + segmentInterval, + windowSize, + retainDuplicates, + timestampedStore); + } + + /** + * Create an in-memory {@link WindowBytesStoreSupplier}. + *

                + * This store supplier can be passed into a {@link #windowStoreBuilder(WindowBytesStoreSupplier, Serde, Serde)} or + * {@link #timestampedWindowStoreBuilder(WindowBytesStoreSupplier, Serde, Serde)}. + * + * @param name name of the store (cannot be {@code null}) + * @param retentionPeriod length of time to retain data in the store (cannot be negative) + * Note that the retention period must be at least long enough to contain the + * windowed data's entire life cycle, from window-start through window-end, + * and for the entire grace period. + * @param windowSize size of the windows (cannot be negative) + * @param retainDuplicates whether or not to retain duplicates. Turning this on will automatically disable + * caching and means that null values will be ignored. + * @return an instance of {@link WindowBytesStoreSupplier} + * @throws IllegalArgumentException if {@code retentionPeriod} or {@code windowSize} can't be represented as {@code long milliseconds} + * @throws IllegalArgumentException if {@code retentionPeriod} is smaller than {@code windowSize} + */ + public static WindowBytesStoreSupplier inMemoryWindowStore(final String name, + final Duration retentionPeriod, + final Duration windowSize, + final boolean retainDuplicates) throws IllegalArgumentException { + Objects.requireNonNull(name, "name cannot be null"); + + final String repartitionPeriodErrorMessagePrefix = prepareMillisCheckFailMsgPrefix(retentionPeriod, "retentionPeriod"); + final long retentionMs = validateMillisecondDuration(retentionPeriod, repartitionPeriodErrorMessagePrefix); + if (retentionMs < 0L) { + throw new IllegalArgumentException("retentionPeriod cannot be negative"); + } + + final String windowSizeErrorMessagePrefix = prepareMillisCheckFailMsgPrefix(windowSize, "windowSize"); + final long windowSizeMs = validateMillisecondDuration(windowSize, windowSizeErrorMessagePrefix); + if (windowSizeMs < 0L) { + throw new IllegalArgumentException("windowSize cannot be negative"); + } + + if (windowSizeMs > retentionMs) { + throw new IllegalArgumentException("The retention period of the window store " + + name + " must be no smaller than its window size. Got size=[" + + windowSize + "], retention=[" + retentionPeriod + "]"); + } + + return new InMemoryWindowBytesStoreSupplier(name, retentionMs, windowSizeMs, retainDuplicates); + } + + /** + * Create a persistent {@link SessionBytesStoreSupplier}. + * + * @param name name of the store (cannot be {@code null}) + * @param retentionPeriod length of time to retain data in the store (cannot be negative) + * (note that the retention period must be at least as long enough to + * contain the inactivity gap of the session and the entire grace period.) + * @return an instance of a {@link SessionBytesStoreSupplier} + */ + public static SessionBytesStoreSupplier persistentSessionStore(final String name, + final Duration retentionPeriod) { + Objects.requireNonNull(name, "name cannot be null"); + final String msgPrefix = prepareMillisCheckFailMsgPrefix(retentionPeriod, "retentionPeriod"); + final long retentionPeriodMs = validateMillisecondDuration(retentionPeriod, msgPrefix); + if (retentionPeriodMs < 0) { + throw new IllegalArgumentException("retentionPeriod cannot be negative"); + } + return new RocksDbSessionBytesStoreSupplier(name, retentionPeriodMs); + } + + /** + * Create an in-memory {@link SessionBytesStoreSupplier}. 
+ * + * @param name name of the store (cannot be {@code null}) + * @param retentionPeriod length of time to retain data in the store (cannot be negative) + * (note that the retention period must be at least long enough to + * contain the inactivity gap of the session and the entire grace period.) + * @return an instance of a {@link SessionBytesStoreSupplier} + */ + public static SessionBytesStoreSupplier inMemorySessionStore(final String name, final Duration retentionPeriod) { + Objects.requireNonNull(name, "name cannot be null"); + + final String msgPrefix = prepareMillisCheckFailMsgPrefix(retentionPeriod, "retentionPeriod"); + final long retentionPeriodMs = validateMillisecondDuration(retentionPeriod, msgPrefix); + if (retentionPeriodMs < 0) { + throw new IllegalArgumentException("retentionPeriod cannot be negative"); + } + return new InMemorySessionBytesStoreSupplier(name, retentionPeriodMs); + } + + /** + * Creates a {@link StoreBuilder} that can be used to build a {@link KeyValueStore}. + *

+ * The provided supplier should not be a supplier for + * {@link TimestampedKeyValueStore TimestampedKeyValueStores}. + * + * @param supplier a {@link KeyValueBytesStoreSupplier} (cannot be {@code null}) + * @param keySerde the key serde to use + * @param valueSerde the value serde to use; if the serialized bytes are {@code null} for put operations, + * it is treated as delete + * @param <K> key type + * @param <V> value type + * @return an instance of a {@link StoreBuilder} that can build a {@link KeyValueStore} + */ + public static <K, V> StoreBuilder<KeyValueStore<K, V>> keyValueStoreBuilder(final KeyValueBytesStoreSupplier supplier, + final Serde<K> keySerde, + final Serde<V> valueSerde) { + Objects.requireNonNull(supplier, "supplier cannot be null"); + return new KeyValueStoreBuilder<>(supplier, keySerde, valueSerde, Time.SYSTEM); + } + + /** + * Creates a {@link StoreBuilder} that can be used to build a {@link TimestampedKeyValueStore}. + *
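// Illustrative sketch (not part of the patch): a key-value store builder like the one
// above is typically registered on the topology and later retrieved by name from a
// processor. The store name and serdes are hypothetical; persistentKeyValueStore is the
// supplier factory defined earlier in this class.
import org.apache.kafka.common.serialization.Serdes;
import org.apache.kafka.streams.StreamsBuilder;
import org.apache.kafka.streams.state.Stores;

class KeyValueStoreRegistrationExample {
    static StreamsBuilder registerStore(final StreamsBuilder builder) {
        builder.addStateStore(
            Stores.keyValueStoreBuilder(
                Stores.persistentKeyValueStore("example-kv-store"),
                Serdes.String(),
                Serdes.Long()));
        return builder;
    }
}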

+ * The provided supplier should not be a supplier for + * {@link KeyValueStore KeyValueStores}. In that case, the passed-in timestamps would be dropped and not stored in the + * key-value store, and reads would return a dummy timestamp rather than a valid one. + * + * @param supplier a {@link KeyValueBytesStoreSupplier} (cannot be {@code null}) + * @param keySerde the key serde to use + * @param valueSerde the value serde to use; if the serialized bytes are {@code null} for put operations, + * it is treated as delete + * @param <K> key type + * @param <V> value type + * @return an instance of a {@link StoreBuilder} that can build a {@link TimestampedKeyValueStore} + */ + public static <K, V> StoreBuilder<TimestampedKeyValueStore<K, V>> timestampedKeyValueStoreBuilder(final KeyValueBytesStoreSupplier supplier, + final Serde<K> keySerde, + final Serde<V> valueSerde) { + Objects.requireNonNull(supplier, "supplier cannot be null"); + return new TimestampedKeyValueStoreBuilder<>(supplier, keySerde, valueSerde, Time.SYSTEM); + } + + /** + * Creates a {@link StoreBuilder} that can be used to build a {@link WindowStore}. + *

+ * The provided supplier should not be a supplier for + * {@link TimestampedWindowStore TimestampedWindowStores}. + * + * @param supplier a {@link WindowBytesStoreSupplier} (cannot be {@code null}) + * @param keySerde the key serde to use + * @param valueSerde the value serde to use; if the serialized bytes are {@code null} for put operations, + * it is treated as delete + * @param <K> key type + * @param <V> value type + * @return an instance of {@link StoreBuilder} that can build a {@link WindowStore} + */ + public static <K, V> StoreBuilder<WindowStore<K, V>> windowStoreBuilder(final WindowBytesStoreSupplier supplier, + final Serde<K> keySerde, + final Serde<V> valueSerde) { + Objects.requireNonNull(supplier, "supplier cannot be null"); + return new WindowStoreBuilder<>(supplier, keySerde, valueSerde, Time.SYSTEM); + } + + /** + * Creates a {@link StoreBuilder} that can be used to build a {@link TimestampedWindowStore}. + *

                + * The provided supplier should not be a supplier for + * {@link WindowStore WindowStores}. For this case, passed in timestamps will be dropped and not stored in the + * window-store. On read, no valid timestamp but a dummy timestamp will be returned. + * + * @param supplier a {@link WindowBytesStoreSupplier} (cannot be {@code null}) + * @param keySerde the key serde to use + * @param valueSerde the value serde to use; if the serialized bytes is {@code null} for put operations, + * it is treated as delete + * @param key type + * @param value type + * @return an instance of {@link StoreBuilder} that can build a {@link TimestampedWindowStore} + */ + public static StoreBuilder> timestampedWindowStoreBuilder(final WindowBytesStoreSupplier supplier, + final Serde keySerde, + final Serde valueSerde) { + Objects.requireNonNull(supplier, "supplier cannot be null"); + return new TimestampedWindowStoreBuilder<>(supplier, keySerde, valueSerde, Time.SYSTEM); + } + + /** + * Creates a {@link StoreBuilder} that can be used to build a {@link SessionStore}. + * + * @param supplier a {@link SessionBytesStoreSupplier} (cannot be {@code null}) + * @param keySerde the key serde to use + * @param valueSerde the value serde to use; if the serialized bytes is {@code null} for put operations, + * it is treated as delete + * @param key type + * @param value type + * @return an instance of {@link StoreBuilder} than can build a {@link SessionStore} + */ + public static StoreBuilder> sessionStoreBuilder(final SessionBytesStoreSupplier supplier, + final Serde keySerde, + final Serde valueSerde) { + Objects.requireNonNull(supplier, "supplier cannot be null"); + return new SessionStoreBuilder<>(supplier, keySerde, valueSerde, Time.SYSTEM); + } +} \ No newline at end of file diff --git a/streams/src/main/java/org/apache/kafka/streams/state/StreamsMetadata.java b/streams/src/main/java/org/apache/kafka/streams/state/StreamsMetadata.java new file mode 100644 index 0000000..131d16f --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/StreamsMetadata.java @@ -0,0 +1,156 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.state; + +import java.util.Objects; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.streams.KafkaStreams; + +import java.util.Collections; +import java.util.Set; + +/** + * Represents the state of an instance (process) in a {@link KafkaStreams} application. + * It contains the user supplied {@link HostInfo} that can be used by developers to build + * APIs and services to connect to other instances, the Set of state stores available on + * the instance and the Set of {@link TopicPartition}s available on the instance. + * NOTE: This is a point in time view. 
It may change when rebalances happen. + * @deprecated since 3.0.0 use {@link org.apache.kafka.streams.StreamsMetadata} + */ +@Deprecated +public class StreamsMetadata { + /** + * Sentinel to indicate that the StreamsMetadata is currently unavailable. This can occur during rebalance + * operations. + */ + public final static StreamsMetadata NOT_AVAILABLE = new StreamsMetadata(HostInfo.unavailable(), + Collections.emptySet(), + Collections.emptySet(), + Collections.emptySet(), + Collections.emptySet()); + + private final HostInfo hostInfo; + + private final Set stateStoreNames; + + private final Set topicPartitions; + + private final Set standbyStateStoreNames; + + private final Set standbyTopicPartitions; + + public StreamsMetadata(final HostInfo hostInfo, + final Set stateStoreNames, + final Set topicPartitions, + final Set standbyStateStoreNames, + final Set standbyTopicPartitions) { + + this.hostInfo = hostInfo; + this.stateStoreNames = stateStoreNames; + this.topicPartitions = topicPartitions; + this.standbyTopicPartitions = standbyTopicPartitions; + this.standbyStateStoreNames = standbyStateStoreNames; + } + + /** + * The value of {@link org.apache.kafka.streams.StreamsConfig#APPLICATION_SERVER_CONFIG} configured for the streams + * instance, which is typically host/port + * + * @return {@link HostInfo} corresponding to the streams instance + */ + public HostInfo hostInfo() { + return hostInfo; + } + + /** + * State stores owned by the instance as an active replica + * + * @return set of active state store names + */ + public Set stateStoreNames() { + return Collections.unmodifiableSet(stateStoreNames); + } + + /** + * Topic partitions consumed by the instance as an active replica + * + * @return set of active topic partitions + */ + public Set topicPartitions() { + return Collections.unmodifiableSet(topicPartitions); + } + + /** + * (Source) Topic partitions for which the instance acts as standby. 
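// Illustrative sketch (not part of the patch): how this (deprecated) metadata type is
// typically consumed for interactive queries, assuming the KafkaStreams#allMetadataForStore
// accessor from the same API generation. The store name is hypothetical.
import java.util.Collection;
import org.apache.kafka.streams.KafkaStreams;
import org.apache.kafka.streams.state.StreamsMetadata;

class StreamsMetadataExample {
    @SuppressWarnings("deprecation")
    static void printHostsForStore(final KafkaStreams streams) {
        final Collection<StreamsMetadata> metadata = streams.allMetadataForStore("example-kv-store");
        for (final StreamsMetadata instance : metadata) {
            // Each entry describes one application instance hosting an active replica of the store.
            System.out.println(instance.host() + ":" + instance.port());
        }
    }
}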
+ * + * @return set of standby topic partitions + */ + public Set standbyTopicPartitions() { + return Collections.unmodifiableSet(standbyTopicPartitions); + } + + /** + * State stores owned by the instance as a standby replica + * + * @return set of standby state store names + */ + public Set standbyStateStoreNames() { + return Collections.unmodifiableSet(standbyStateStoreNames); + } + + public String host() { + return hostInfo.host(); + } + + @SuppressWarnings("unused") + public int port() { + return hostInfo.port(); + } + + @Override + public boolean equals(final Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + + final StreamsMetadata that = (StreamsMetadata) o; + return Objects.equals(hostInfo, that.hostInfo) + && Objects.equals(stateStoreNames, that.stateStoreNames) + && Objects.equals(topicPartitions, that.topicPartitions) + && Objects.equals(standbyStateStoreNames, that.standbyStateStoreNames) + && Objects.equals(standbyTopicPartitions, that.standbyTopicPartitions); + } + + @Override + public int hashCode() { + return Objects.hash(hostInfo, stateStoreNames, topicPartitions, standbyStateStoreNames, standbyTopicPartitions); + } + + @Override + public String toString() { + return "StreamsMetadata {" + + "hostInfo=" + hostInfo + + ", stateStoreNames=" + stateStoreNames + + ", topicPartitions=" + topicPartitions + + ", standbyStateStoreNames=" + standbyStateStoreNames + + ", standbyTopicPartitions=" + standbyTopicPartitions + + '}'; + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/state/TimestampedBytesStore.java b/streams/src/main/java/org/apache/kafka/streams/state/TimestampedBytesStore.java new file mode 100644 index 0000000..e609b70 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/TimestampedBytesStore.java @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.state; + +import java.nio.ByteBuffer; + +import static org.apache.kafka.clients.consumer.ConsumerRecord.NO_TIMESTAMP; + +public interface TimestampedBytesStore { + static byte[] convertToTimestampedFormat(final byte[] plainValue) { + if (plainValue == null) { + return null; + } + return ByteBuffer + .allocate(8 + plainValue.length) + .putLong(NO_TIMESTAMP) + .put(plainValue) + .array(); + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/state/TimestampedKeyValueStore.java b/streams/src/main/java/org/apache/kafka/streams/state/TimestampedKeyValueStore.java new file mode 100644 index 0000000..ef5ef57 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/TimestampedKeyValueStore.java @@ -0,0 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.state; + +/** + * A key-(value/timestamp) store that supports put/get/delete and range queries. + * + * @param The key type + * @param The value type + */ +public interface TimestampedKeyValueStore extends KeyValueStore> { } \ No newline at end of file diff --git a/streams/src/main/java/org/apache/kafka/streams/state/TimestampedWindowStore.java b/streams/src/main/java/org/apache/kafka/streams/state/TimestampedWindowStore.java new file mode 100644 index 0000000..7d52c12 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/TimestampedWindowStore.java @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.state; + +import org.apache.kafka.streams.kstream.Windowed; + +/** + * Interface for storing the aggregated values of fixed-size time windows. + *
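// Illustrative sketch (not part of the patch): the timestamped format produced by
// TimestampedBytesStore#convertToTimestampedFormat above is an 8-byte (big-endian)
// timestamp followed by the original value bytes, so it can be split back apart with a
// ByteBuffer. The sample value is hypothetical.
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import org.apache.kafka.streams.state.TimestampedBytesStore;

class TimestampedFormatExample {
    public static void main(final String[] args) {
        final byte[] plain = "some-value".getBytes(StandardCharsets.UTF_8);
        final byte[] withTimestamp = TimestampedBytesStore.convertToTimestampedFormat(plain);

        final ByteBuffer buffer = ByteBuffer.wrap(withTimestamp);
        final long timestamp = buffer.getLong();       // NO_TIMESTAMP (-1) for converted legacy values
        final byte[] value = new byte[buffer.remaining()];
        buffer.get(value);

        System.out.println(timestamp);                                   // -1
        System.out.println(new String(value, StandardCharsets.UTF_8));   // some-value
    }
}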

+ * Note that the store's physical key type is {@link Windowed Windowed<K>}. + * In contrast to a {@link WindowStore} that stores plain windowedKeys-value pairs, + * a {@code TimestampedWindowStore} stores windowedKeys-(value/timestamp) pairs. + *

                + * While the window start- and end-timestamp are fixed per window, the value-side timestamp is used + * to store the last update timestamp of the corresponding window. + * + * @param Type of keys + * @param Type of values + */ +public interface TimestampedWindowStore extends WindowStore> { } \ No newline at end of file diff --git a/streams/src/main/java/org/apache/kafka/streams/state/ValueAndTimestamp.java b/streams/src/main/java/org/apache/kafka/streams/state/ValueAndTimestamp.java new file mode 100644 index 0000000..f5fc7a2 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/ValueAndTimestamp.java @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.state; + +import org.apache.kafka.streams.KeyValue; + +import java.util.Objects; + +/** + * Combines a value from a {@link KeyValue} with a timestamp. + * + * @param + */ +public final class ValueAndTimestamp { + private final V value; + private final long timestamp; + + private ValueAndTimestamp(final V value, + final long timestamp) { + Objects.requireNonNull(value); + this.value = value; + this.timestamp = timestamp; + } + + /** + * Create a new {@link ValueAndTimestamp} instance if the provide {@code value} is not {@code null}. + * + * @param value the value + * @param timestamp the timestamp + * @param the type of the value + * @return a new {@link ValueAndTimestamp} instance if the provide {@code value} is not {@code null}; + * otherwise {@code null} is returned + */ + public static ValueAndTimestamp make(final V value, + final long timestamp) { + return value == null ? null : new ValueAndTimestamp<>(value, timestamp); + } + + /** + * Return the wrapped {@code value} of the given {@code valueAndTimestamp} parameter + * if the parameter is not {@code null}. + * + * @param valueAndTimestamp a {@link ValueAndTimestamp} instance; can be {@code null} + * @param the type of the value + * @return the wrapped {@code value} of {@code valueAndTimestamp} if not {@code null}; otherwise {@code null} + */ + public static V getValueOrNull(final ValueAndTimestamp valueAndTimestamp) { + return valueAndTimestamp == null ? 
null : valueAndTimestamp.value(); + } + + public V value() { + return value; + } + + public long timestamp() { + return timestamp; + } + + @Override + public String toString() { + return "<" + value + "," + timestamp + ">"; + } + + @Override + public boolean equals(final Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + final ValueAndTimestamp that = (ValueAndTimestamp) o; + return timestamp == that.timestamp && + Objects.equals(value, that.value); + } + + @Override + public int hashCode() { + return Objects.hash(value, timestamp); + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/state/WindowBytesStoreSupplier.java b/streams/src/main/java/org/apache/kafka/streams/state/WindowBytesStoreSupplier.java new file mode 100644 index 0000000..9ced28c --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/WindowBytesStoreSupplier.java @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.state; + +import org.apache.kafka.common.utils.Bytes; + +/** + * A store supplier that can be used to create one or more {@link WindowStore WindowStore<Byte, byte[]>} instances of type <Byte, byte[]>. + * + * For any stores implementing the {@link WindowStore WindowStore<Byte, byte[]>} interface, null value bytes are considered as "not exist". This means: + * + * 1. Null value bytes in put operations should be treated as delete. + * 2. Null value bytes should never be returned in range query results. + */ +public interface WindowBytesStoreSupplier extends StoreSupplier> { + + /** + * The size of the segments (in milliseconds) the store has. + * If your store is segmented then this should be the size of segments in the underlying store. + * It is also used to reduce the amount of data that is scanned when caching is enabled. + * + * @return size of the segments (in milliseconds) + */ + long segmentIntervalMs(); + + /** + * The size of the windows (in milliseconds) any store created from this supplier is creating. + * + * @return window size + */ + long windowSize(); + + /** + * Whether or not this store is retaining duplicate keys. + * Usually only true if the store is being used for joins. + * Note this should return false if caching is enabled. + * + * @return true if duplicates should be retained + */ + boolean retainDuplicates(); + + /** + * The time period for which the {@link WindowStore} will retain historic data. 
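// Illustrative sketch (not part of the patch): the null-safety contract of the
// ValueAndTimestamp helper defined above. make() returns null for a null value, and
// getValueOrNull() unwraps without forcing callers to null-check the wrapper first.
import org.apache.kafka.streams.state.ValueAndTimestamp;

class ValueAndTimestampExample {
    public static void main(final String[] args) {
        final ValueAndTimestamp<String> present = ValueAndTimestamp.make("v1", 1_000L);
        final ValueAndTimestamp<String> absent = ValueAndTimestamp.make(null, 1_000L); // null, not a wrapper

        System.out.println(ValueAndTimestamp.getValueOrNull(present)); // v1
        System.out.println(ValueAndTimestamp.getValueOrNull(absent));  // null
        System.out.println(present.timestamp());                       // 1000
    }
}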
+ * + * @return retentionPeriod + */ + long retentionPeriod(); +} diff --git a/streams/src/main/java/org/apache/kafka/streams/state/WindowStore.java b/streams/src/main/java/org/apache/kafka/streams/state/WindowStore.java new file mode 100644 index 0000000..86c82fa --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/WindowStore.java @@ -0,0 +1,192 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.state; + +import org.apache.kafka.streams.errors.InvalidStateStoreException; +import org.apache.kafka.streams.internals.ApiUtils; +import org.apache.kafka.streams.kstream.Windowed; +import org.apache.kafka.streams.processor.StateStore; + +import java.time.Instant; + +import static org.apache.kafka.streams.internals.ApiUtils.prepareMillisCheckFailMsgPrefix; + +/** + * Interface for storing the aggregated values of fixed-size time windows. + *

+ * Note that the store's physical key type is {@link Windowed Windowed<K>}. + * + * @param <K> Type of keys + * @param <V> Type of values + */ +public interface WindowStore<K, V> extends StateStore, ReadOnlyWindowStore<K, V> { + + /** + * Put a key-value pair into the window with the given window start timestamp. + *

+ * If the serialized value bytes are {@code null}, the put is interpreted as a delete. Note that deletes are + * ignored when the underlying store retains duplicates. + * + * @param key The key to associate the value to + * @param value The value; can be null + * @param windowStartTimestamp The timestamp of the beginning of the window to put the key/value into + * @throws NullPointerException if the given key is {@code null} + */ + void put(K key, V value, long windowStartTimestamp); + + /** + * Get all the key-value pairs with the given key and the time range from all the existing windows. + *
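// Illustrative sketch (not part of the patch): callers of put() supply the window start
// timestamp themselves. A common convention (hypothetical helper, not part of this API)
// is to align the record timestamp down to the window size.
import org.apache.kafka.streams.state.WindowStore;

class WindowPutExample {
    static void putAligned(final WindowStore<String, Long> store,
                           final String key,
                           final Long value,
                           final long recordTimestampMs,
                           final long windowSizeMs) {
        // Align the record timestamp to the start of its window,
        // e.g. 17:42:13 -> 17:40:00 for 5-minute windows.
        final long windowStart = recordTimestampMs - (recordTimestampMs % windowSizeMs);
        store.put(key, value, windowStart);
    }
}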

                + * This iterator must be closed after use. + *

                + * The time range is inclusive and applies to the starting timestamp of the window. + * For example, if we have the following windows: + *

                +     * +-------------------------------+
                +     * |  key  | start time | end time |
                +     * +-------+------------+----------+
                +     * |   A   |     10     |    20    |
                +     * +-------+------------+----------+
                +     * |   A   |     15     |    25    |
                +     * +-------+------------+----------+
                +     * |   A   |     20     |    30    |
                +     * +-------+------------+----------+
                +     * |   A   |     25     |    35    |
+     * +-------------------------------+
                +     * 
+ * If we then call {@code store.fetch("A", 10, 20)}, the results will contain the first + * three windows from the table above, i.e., all those where 10 <= start time <= 20. + *

                + * For each key, the iterator guarantees ordering of windows, starting from the oldest/earliest + * available window to the newest/latest window. + * + * @param key the key to fetch + * @param timeFrom time range start (inclusive) + * @param timeTo time range end (inclusive) + * @return an iterator over key-value pairs {@code } + * @throws InvalidStateStoreException if the store is not initialized + * @throws NullPointerException if the given key is {@code null} + */ + // WindowStore keeps a long-based implementation of ReadOnlyWindowStore#fetch Instant-based + // if super#fetch is removed, keep this implementation as it serves PAPI Stores. + WindowStoreIterator fetch(K key, long timeFrom, long timeTo); + + @Override + default WindowStoreIterator fetch(final K key, + final Instant timeFrom, + final Instant timeTo) throws IllegalArgumentException { + return fetch( + key, + ApiUtils.validateMillisecondInstant(timeFrom, prepareMillisCheckFailMsgPrefix(timeFrom, "timeFrom")), + ApiUtils.validateMillisecondInstant(timeTo, prepareMillisCheckFailMsgPrefix(timeTo, "timeTo"))); + } + + default WindowStoreIterator backwardFetch(final K key, + final long timeFrom, + final long timeTo) { + throw new UnsupportedOperationException(); + } + + @Override + default WindowStoreIterator backwardFetch(final K key, + final Instant timeFrom, + final Instant timeTo) throws IllegalArgumentException { + return backwardFetch( + key, + ApiUtils.validateMillisecondInstant(timeFrom, prepareMillisCheckFailMsgPrefix(timeFrom, "timeFrom")), + ApiUtils.validateMillisecondInstant(timeTo, prepareMillisCheckFailMsgPrefix(timeTo, "timeTo"))); + } + + /** + * Get all the key-value pairs in the given key range and time range from all the existing windows. + *
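// Illustrative sketch (not part of the patch): fetch() returns a WindowStoreIterator that
// must be closed, so try-with-resources is the usual pattern. The key type and time range
// here are hypothetical.
import java.time.Duration;
import java.time.Instant;
import org.apache.kafka.streams.KeyValue;
import org.apache.kafka.streams.state.WindowStore;
import org.apache.kafka.streams.state.WindowStoreIterator;

class WindowFetchExample {
    static long sumLastHour(final WindowStore<String, Long> store, final String key, final Instant now) {
        long sum = 0L;
        try (final WindowStoreIterator<Long> iterator = store.fetch(key, now.minus(Duration.ofHours(1)), now)) {
            while (iterator.hasNext()) {
                final KeyValue<Long, Long> entry = iterator.next(); // entry.key is the window start timestamp
                sum += entry.value;
            }
        }
        return sum;
    }
}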

                + * This iterator must be closed after use. + * + * @param keyFrom the first key in the range + * A null value indicates a starting position from the first element in the store. + * @param keyTo the last key in the range + * A null value indicates that the range ends with the last element in the store. + * @param timeFrom time range start (inclusive) + * @param timeTo time range end (inclusive) + * @return an iterator over windowed key-value pairs {@code , value>} + * @throws InvalidStateStoreException if the store is not initialized + */ + // WindowStore keeps a long-based implementation of ReadOnlyWindowStore#fetch Instant-based + // if super#fetch is removed, keep this implementation as it serves PAPI Stores. + KeyValueIterator, V> fetch(K keyFrom, K keyTo, long timeFrom, long timeTo); + + @Override + default KeyValueIterator, V> fetch(final K keyFrom, + final K keyTo, + final Instant timeFrom, + final Instant timeTo) throws IllegalArgumentException { + return fetch( + keyFrom, + keyTo, + ApiUtils.validateMillisecondInstant(timeFrom, prepareMillisCheckFailMsgPrefix(timeFrom, "timeFrom")), + ApiUtils.validateMillisecondInstant(timeTo, prepareMillisCheckFailMsgPrefix(timeTo, "timeTo"))); + } + + default KeyValueIterator, V> backwardFetch(final K keyFrom, + final K keyTo, + final long timeFrom, + final long timeTo) { + throw new UnsupportedOperationException(); + } + + @Override + default KeyValueIterator, V> backwardFetch(final K keyFrom, + final K keyTo, + final Instant timeFrom, + final Instant timeTo) throws IllegalArgumentException { + return backwardFetch( + keyFrom, + keyTo, + ApiUtils.validateMillisecondInstant(timeFrom, prepareMillisCheckFailMsgPrefix(timeFrom, "timeFrom")), + ApiUtils.validateMillisecondInstant(timeTo, prepareMillisCheckFailMsgPrefix(timeTo, "timeTo"))); + } + + /** + * Gets all the key-value pairs that belong to the windows within in the given time range. + * + * @param timeFrom the beginning of the time slot from which to search (inclusive) + * @param timeTo the end of the time slot from which to search (inclusive) + * @return an iterator over windowed key-value pairs {@code , value>} + * @throws InvalidStateStoreException if the store is not initialized + */ + // WindowStore keeps a long-based implementation of ReadOnlyWindowStore#fetch Instant-based + // if super#fetch is removed, keep this implementation as it serves PAPI Stores. 
+ KeyValueIterator, V> fetchAll(long timeFrom, long timeTo); + + @Override + default KeyValueIterator, V> fetchAll(final Instant timeFrom, final Instant timeTo) throws IllegalArgumentException { + return fetchAll( + ApiUtils.validateMillisecondInstant(timeFrom, prepareMillisCheckFailMsgPrefix(timeFrom, "timeFrom")), + ApiUtils.validateMillisecondInstant(timeTo, prepareMillisCheckFailMsgPrefix(timeTo, "timeTo"))); + } + + default KeyValueIterator, V> backwardFetchAll(final long timeFrom, final long timeTo) { + throw new UnsupportedOperationException(); + } + + @Override + default KeyValueIterator, V> backwardFetchAll(final Instant timeFrom, final Instant timeTo) throws IllegalArgumentException { + return backwardFetchAll( + ApiUtils.validateMillisecondInstant(timeFrom, prepareMillisCheckFailMsgPrefix(timeFrom, "timeFrom")), + ApiUtils.validateMillisecondInstant(timeTo, prepareMillisCheckFailMsgPrefix(timeTo, "timeTo"))); + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/state/WindowStoreIterator.java b/streams/src/main/java/org/apache/kafka/streams/state/WindowStoreIterator.java new file mode 100644 index 0000000..1416351 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/WindowStoreIterator.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.state; + +import java.time.Instant; +import org.apache.kafka.streams.KeyValue; + +import java.io.Closeable; + +/** + * Iterator interface of {@link KeyValue} with key typed {@link Long} used for {@link WindowStore#fetch(Object, long, long)} + * and {@link WindowStore#fetch(Object, Instant, Instant)} + * + * Users must call its {@code close} method explicitly upon completeness to release resources, + * or use try-with-resources statement (available since JDK7) for this {@link Closeable} class. + * + * @param Type of values + */ +public interface WindowStoreIterator extends KeyValueIterator, Closeable { + + @Override + void close(); +} diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/AbstractMergedSortedCacheStoreIterator.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/AbstractMergedSortedCacheStoreIterator.java new file mode 100644 index 0000000..819c58c --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/AbstractMergedSortedCacheStoreIterator.java @@ -0,0 +1,194 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.state.KeyValueIterator; + +import java.util.NoSuchElementException; + +/** + * Merges two iterators. Assumes each of them is sorted by key + * + * @param + * @param + */ +abstract class AbstractMergedSortedCacheStoreIterator implements KeyValueIterator { + private final PeekingKeyValueIterator cacheIterator; + private final KeyValueIterator storeIterator; + private final boolean forward; + + AbstractMergedSortedCacheStoreIterator(final PeekingKeyValueIterator cacheIterator, + final KeyValueIterator storeIterator, + final boolean forward) { + this.cacheIterator = cacheIterator; + this.storeIterator = storeIterator; + this.forward = forward; + } + + abstract int compare(final Bytes cacheKey, final KS storeKey); + + abstract K deserializeStoreKey(final KS key); + + abstract KeyValue deserializeStorePair(final KeyValue pair); + + abstract K deserializeCacheKey(final Bytes cacheKey); + + abstract V deserializeCacheValue(final LRUCacheEntry cacheEntry); + + private boolean isDeletedCacheEntry(final KeyValue nextFromCache) { + return nextFromCache.value.value() == null; + } + + @Override + public boolean hasNext() { + // skip over items deleted from cache, and corresponding store items if they have the same key + while (cacheIterator.hasNext() && isDeletedCacheEntry(cacheIterator.peekNext())) { + if (storeIterator.hasNext()) { + final KS nextStoreKey = storeIterator.peekNextKey(); + // advance the store iterator if the key is the same as the deleted cache key + if (compare(cacheIterator.peekNextKey(), nextStoreKey) == 0) { + storeIterator.next(); + } + } + cacheIterator.next(); + } + + return cacheIterator.hasNext() || storeIterator.hasNext(); + } + + @Override + public KeyValue next() { + if (!hasNext()) { + throw new NoSuchElementException(); + } + + final Bytes nextCacheKey = cacheIterator.hasNext() ? cacheIterator.peekNextKey() : null; + final KS nextStoreKey = storeIterator.hasNext() ? 
storeIterator.peekNextKey() : null; + + if (nextCacheKey == null) { + return nextStoreValue(nextStoreKey); + } + + if (nextStoreKey == null) { + return nextCacheValue(nextCacheKey); + } + + final int comparison = compare(nextCacheKey, nextStoreKey); + return chooseNextValue(nextCacheKey, nextStoreKey, comparison); + } + + private KeyValue chooseNextValue(final Bytes nextCacheKey, + final KS nextStoreKey, + final int comparison) { + if (forward) { + if (comparison > 0) { + return nextStoreValue(nextStoreKey); + } else if (comparison < 0) { + return nextCacheValue(nextCacheKey); + } else { + // skip the same keyed element + storeIterator.next(); + return nextCacheValue(nextCacheKey); + } + } else { + if (comparison < 0) { + return nextStoreValue(nextStoreKey); + } else if (comparison > 0) { + return nextCacheValue(nextCacheKey); + } else { + // skip the same keyed element + storeIterator.next(); + return nextCacheValue(nextCacheKey); + } + } + } + + private KeyValue nextStoreValue(final KS nextStoreKey) { + final KeyValue next = storeIterator.next(); + + if (!next.key.equals(nextStoreKey)) { + throw new IllegalStateException("Next record key is not the peeked key value; this should not happen"); + } + + return deserializeStorePair(next); + } + + private KeyValue nextCacheValue(final Bytes nextCacheKey) { + final KeyValue next = cacheIterator.next(); + + if (!next.key.equals(nextCacheKey)) { + throw new IllegalStateException("Next record key is not the peeked key value; this should not happen"); + } + + return KeyValue.pair(deserializeCacheKey(next.key), deserializeCacheValue(next.value)); + } + + @Override + public K peekNextKey() { + if (!hasNext()) { + throw new NoSuchElementException(); + } + + final Bytes nextCacheKey = cacheIterator.hasNext() ? cacheIterator.peekNextKey() : null; + final KS nextStoreKey = storeIterator.hasNext() ? storeIterator.peekNextKey() : null; + + if (nextCacheKey == null) { + return deserializeStoreKey(nextStoreKey); + } + + if (nextStoreKey == null) { + return deserializeCacheKey(nextCacheKey); + } + + final int comparison = compare(nextCacheKey, nextStoreKey); + return chooseNextKey(nextCacheKey, nextStoreKey, comparison); + } + + private K chooseNextKey(final Bytes nextCacheKey, + final KS nextStoreKey, + final int comparison) { + if (forward) { + if (comparison > 0) { + return deserializeStoreKey(nextStoreKey); + } else if (comparison < 0) { + return deserializeCacheKey(nextCacheKey); + } else { + // skip the same keyed element + storeIterator.next(); + return deserializeCacheKey(nextCacheKey); + } + } else { + if (comparison < 0) { + return deserializeStoreKey(nextStoreKey); + } else if (comparison > 0) { + return deserializeCacheKey(nextCacheKey); + } else { + // skip the same keyed element + storeIterator.next(); + return deserializeCacheKey(nextCacheKey); + } + } + } + + @Override + public void close() { + cacheIterator.close(); + storeIterator.close(); + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/AbstractRocksDBSegmentedBytesStore.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/AbstractRocksDBSegmentedBytesStore.java new file mode 100644 index 0000000..bfee6b2 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/AbstractRocksDBSegmentedBytesStore.java @@ -0,0 +1,334 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
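// Illustrative sketch (not part of the patch): the tie-breaking rule implemented by
// AbstractMergedSortedCacheStoreIterator above, reduced to plain sorted maps. When the
// same key exists in both the cache and the underlying store, the cached entry wins and
// the store entry is skipped; lazy iteration and tombstone handling are omitted here.
import java.util.TreeMap;

class MergePreferenceExample {
    static <K extends Comparable<K>, V> TreeMap<K, V> mergeCacheOverStore(final TreeMap<K, V> store,
                                                                          final TreeMap<K, V> cache) {
        final TreeMap<K, V> merged = new TreeMap<>(store);
        merged.putAll(cache); // on equal keys the cached entry overwrites the store entry
        return merged;
    }

    public static void main(final String[] args) {
        final TreeMap<String, String> store = new TreeMap<>();
        store.put("a", "store-a");
        store.put("b", "store-b");
        final TreeMap<String, String> cache = new TreeMap<>();
        cache.put("b", "cache-b");
        cache.put("c", "cache-c");
        System.out.println(mergeCacheOverStore(store, cache)); // {a=store-a, b=cache-b, c=cache-c}
    }
}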
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.common.metrics.Sensor; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.errors.ProcessorStateException; +import org.apache.kafka.streams.processor.BatchingStateRestoreCallback; +import org.apache.kafka.streams.processor.ProcessorContext; +import org.apache.kafka.streams.processor.StateStore; +import org.apache.kafka.streams.processor.internals.ProcessorContextUtils; +import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl; +import org.apache.kafka.streams.processor.internals.metrics.TaskMetrics; +import org.apache.kafka.streams.state.KeyValueIterator; +import org.rocksdb.RocksDBException; +import org.rocksdb.WriteBatch; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Collection; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +public class AbstractRocksDBSegmentedBytesStore implements SegmentedBytesStore { + private static final Logger LOG = LoggerFactory.getLogger(AbstractRocksDBSegmentedBytesStore.class); + + private final String name; + private final AbstractSegments segments; + private final String metricScope; + private final KeySchema keySchema; + + private ProcessorContext context; + private Sensor expiredRecordSensor; + private long observedStreamTime = ConsumerRecord.NO_TIMESTAMP; + + private volatile boolean open; + + AbstractRocksDBSegmentedBytesStore(final String name, + final String metricScope, + final KeySchema keySchema, + final AbstractSegments segments) { + this.name = name; + this.metricScope = metricScope; + this.keySchema = keySchema; + this.segments = segments; + } + + @Override + public KeyValueIterator fetch(final Bytes key, + final long from, + final long to) { + return fetch(key, from, to, true); + } + + @Override + public KeyValueIterator backwardFetch(final Bytes key, + final long from, + final long to) { + return fetch(key, from, to, false); + } + + KeyValueIterator fetch(final Bytes key, + final long from, + final long to, + final boolean forward) { + final List searchSpace = keySchema.segmentsToSearch(segments, from, to, forward); + + final Bytes binaryFrom = keySchema.lowerRangeFixedSize(key, from); + final Bytes binaryTo = keySchema.upperRangeFixedSize(key, to); + + return new SegmentIterator<>( + searchSpace.iterator(), + keySchema.hasNextCondition(key, key, from, to), + binaryFrom, + binaryTo, + forward); + } + + @Override + public KeyValueIterator fetch(final Bytes keyFrom, + final Bytes keyTo, + final long from, + final long to) { + return fetch(keyFrom, keyTo, from, to, true); + } + + @Override + public KeyValueIterator backwardFetch(final Bytes keyFrom, + final Bytes keyTo, + 
final long from, + final long to) { + return fetch(keyFrom, keyTo, from, to, false); + } + + KeyValueIterator fetch(final Bytes keyFrom, + final Bytes keyTo, + final long from, + final long to, + final boolean forward) { + if (keyFrom != null && keyTo != null && keyFrom.compareTo(keyTo) > 0) { + LOG.warn("Returning empty iterator for fetch with invalid key range: from > to. " + + "This may be due to range arguments set in the wrong order, " + + "or serdes that don't preserve ordering when lexicographically comparing the serialized bytes. " + + "Note that the built-in numerical serdes do not follow this for negative numbers"); + return KeyValueIterators.emptyIterator(); + } + + final List searchSpace = keySchema.segmentsToSearch(segments, from, to, forward); + + final Bytes binaryFrom = keyFrom == null ? null : keySchema.lowerRange(keyFrom, from); + final Bytes binaryTo = keyTo == null ? null : keySchema.upperRange(keyTo, to); + + return new SegmentIterator<>( + searchSpace.iterator(), + keySchema.hasNextCondition(keyFrom, keyTo, from, to), + binaryFrom, + binaryTo, + forward); + } + + @Override + public KeyValueIterator all() { + final List searchSpace = segments.allSegments(true); + + return new SegmentIterator<>( + searchSpace.iterator(), + keySchema.hasNextCondition(null, null, 0, Long.MAX_VALUE), + null, + null, + true); + } + + @Override + public KeyValueIterator backwardAll() { + final List searchSpace = segments.allSegments(false); + + return new SegmentIterator<>( + searchSpace.iterator(), + keySchema.hasNextCondition(null, null, 0, Long.MAX_VALUE), + null, + null, + false); + } + + @Override + public KeyValueIterator fetchAll(final long timeFrom, + final long timeTo) { + final List searchSpace = segments.segments(timeFrom, timeTo, true); + + return new SegmentIterator<>( + searchSpace.iterator(), + keySchema.hasNextCondition(null, null, timeFrom, timeTo), + null, + null, + true); + } + + @Override + public KeyValueIterator backwardFetchAll(final long timeFrom, + final long timeTo) { + final List searchSpace = segments.segments(timeFrom, timeTo, false); + + return new SegmentIterator<>( + searchSpace.iterator(), + keySchema.hasNextCondition(null, null, timeFrom, timeTo), + null, + null, + false); + } + + @Override + public void remove(final Bytes key) { + final long timestamp = keySchema.segmentTimestamp(key); + observedStreamTime = Math.max(observedStreamTime, timestamp); + final S segment = segments.getSegmentForTimestamp(timestamp); + if (segment == null) { + return; + } + segment.delete(key); + } + + @Override + public void remove(final Bytes key, final long timestamp) { + final Bytes keyBytes = keySchema.toStoreBinaryKeyPrefix(key, timestamp); + final S segment = segments.getSegmentForTimestamp(timestamp); + if (segment != null) { + segment.deleteRange(keyBytes, keyBytes); + } + } + + @Override + public void put(final Bytes key, + final byte[] value) { + final long timestamp = keySchema.segmentTimestamp(key); + observedStreamTime = Math.max(observedStreamTime, timestamp); + final long segmentId = segments.segmentId(timestamp); + final S segment = segments.getOrCreateSegmentIfLive(segmentId, context, observedStreamTime); + if (segment == null) { + expiredRecordSensor.record(1.0d, ProcessorContextUtils.currentSystemTime(context)); + LOG.warn("Skipping record for expired segment."); + } else { + segment.put(key, value); + } + } + + @Override + public byte[] get(final Bytes key) { + final S segment = segments.getSegmentForTimestamp(keySchema.segmentTimestamp(key)); + if (segment 
== null) { + return null; + } + return segment.get(key); + } + + @Override + public String name() { + return name; + } + + @Deprecated + @Override + public void init(final ProcessorContext context, + final StateStore root) { + this.context = context; + + final StreamsMetricsImpl metrics = ProcessorContextUtils.getMetricsImpl(context); + final String threadId = Thread.currentThread().getName(); + final String taskName = context.taskId().toString(); + + expiredRecordSensor = TaskMetrics.droppedRecordsSensor( + threadId, + taskName, + metrics + ); + + segments.openExisting(this.context, observedStreamTime); + + // register and possibly restore the state from the logs + context.register(root, new RocksDBSegmentsBatchingRestoreCallback()); + + open = true; + } + + @Override + public void flush() { + segments.flush(); + } + + @Override + public void close() { + open = false; + segments.close(); + } + + @Override + public boolean persistent() { + return true; + } + + @Override + public boolean isOpen() { + return open; + } + + // Visible for testing + List getSegments() { + return segments.allSegments(false); + } + + // Visible for testing + void restoreAllInternal(final Collection> records) { + try { + final Map writeBatchMap = getWriteBatches(records); + for (final Map.Entry entry : writeBatchMap.entrySet()) { + final S segment = entry.getKey(); + final WriteBatch batch = entry.getValue(); + segment.write(batch); + batch.close(); + } + } catch (final RocksDBException e) { + throw new ProcessorStateException("Error restoring batch to store " + this.name, e); + } + } + + // Visible for testing + Map getWriteBatches(final Collection> records) { + // advance stream time to the max timestamp in the batch + for (final KeyValue record : records) { + final long timestamp = keySchema.segmentTimestamp(Bytes.wrap(record.key)); + observedStreamTime = Math.max(observedStreamTime, timestamp); + } + + final Map writeBatchMap = new HashMap<>(); + for (final KeyValue record : records) { + final long timestamp = keySchema.segmentTimestamp(Bytes.wrap(record.key)); + final long segmentId = segments.segmentId(timestamp); + final S segment = segments.getOrCreateSegmentIfLive(segmentId, context, observedStreamTime); + if (segment != null) { + try { + final WriteBatch batch = writeBatchMap.computeIfAbsent(segment, s -> new WriteBatch()); + segment.addToBatch(record, batch); + } catch (final RocksDBException e) { + throw new ProcessorStateException("Error restoring batch to store " + this.name, e); + } + } + } + return writeBatchMap; + } + + private class RocksDBSegmentsBatchingRestoreCallback implements BatchingStateRestoreCallback { + + @Override + public void restoreAll(final Collection> records) { + restoreAllInternal(records); + } + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/AbstractSegments.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/AbstractSegments.java new file mode 100644 index 0000000..4b59c95 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/AbstractSegments.java @@ -0,0 +1,239 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.streams.errors.ProcessorStateException; +import org.apache.kafka.streams.processor.ProcessorContext; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.File; +import java.io.IOException; +import java.text.ParseException; +import java.text.SimpleDateFormat; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.NavigableMap; +import java.util.SimpleTimeZone; +import java.util.TreeMap; + +abstract class AbstractSegments implements Segments { + private static final Logger log = LoggerFactory.getLogger(AbstractSegments.class); + + final TreeMap segments = new TreeMap<>(); + final String name; + private final long retentionPeriod; + private final long segmentInterval; + private final SimpleDateFormat formatter; + + AbstractSegments(final String name, final long retentionPeriod, final long segmentInterval) { + this.name = name; + this.segmentInterval = segmentInterval; + this.retentionPeriod = retentionPeriod; + // Create a date formatter. Formatted timestamps are used as segment name suffixes + this.formatter = new SimpleDateFormat("yyyyMMddHHmm"); + this.formatter.setTimeZone(new SimpleTimeZone(0, "UTC")); + } + + @Override + public long segmentId(final long timestamp) { + return timestamp / segmentInterval; + } + + @Override + public String segmentName(final long segmentId) { + // (1) previous format used - as a separator so if this changes in the future + // then we should use something different. + // (2) previous format used : as a separator (which did break KafkaStreams on Windows OS) + // so if this changes in the future then we should use something different. + return name + "." + segmentId * segmentInterval; + } + + @Override + public S getSegmentForTimestamp(final long timestamp) { + return segments.get(segmentId(timestamp)); + } + + @Override + public S getOrCreateSegmentIfLive(final long segmentId, + final ProcessorContext context, + final long streamTime) { + final long minLiveTimestamp = streamTime - retentionPeriod; + final long minLiveSegment = segmentId(minLiveTimestamp); + + final S toReturn; + if (segmentId >= minLiveSegment) { + // The segment is live. get it, ensure it's open, and return it. 
+ toReturn = getOrCreateSegment(segmentId, context); + } else { + toReturn = null; + } + + cleanupEarlierThan(minLiveSegment); + return toReturn; + } + + @Override + public void openExisting(final ProcessorContext context, final long streamTime) { + try { + final File dir = new File(context.stateDir(), name); + if (dir.exists()) { + final String[] list = dir.list(); + if (list != null) { + Arrays.stream(list) + .map(segment -> segmentIdFromSegmentName(segment, dir)) + .sorted() // open segments in the id order + .filter(segmentId -> segmentId >= 0) + .forEach(segmentId -> getOrCreateSegment(segmentId, context)); + } + } else { + if (!dir.mkdir()) { + throw new ProcessorStateException(String.format("dir %s doesn't exist and cannot be created for segments %s", dir, name)); + } + } + } catch (final Exception ex) { + // ignore + } + + final long minLiveSegment = segmentId(streamTime - retentionPeriod); + cleanupEarlierThan(minLiveSegment); + } + + @Override + public List segments(final long timeFrom, final long timeTo, final boolean forward) { + final List result = new ArrayList<>(); + final NavigableMap segmentsInRange; + if (forward) { + segmentsInRange = segments.subMap( + segmentId(timeFrom), true, + segmentId(timeTo), true + ); + } else { + segmentsInRange = segments.subMap( + segmentId(timeFrom), true, + segmentId(timeTo), true + ).descendingMap(); + } + for (final S segment : segmentsInRange.values()) { + if (segment.isOpen()) { + result.add(segment); + } + } + return result; + } + + @Override + public List allSegments(final boolean forward) { + final List result = new ArrayList<>(); + final Collection values; + if (forward) { + values = segments.values(); + } else { + values = segments.descendingMap().values(); + } + for (final S segment : values) { + if (segment.isOpen()) { + result.add(segment); + } + } + return result; + } + + @Override + public void flush() { + for (final S segment : segments.values()) { + segment.flush(); + } + } + + @Override + public void close() { + for (final S segment : segments.values()) { + segment.close(); + } + segments.clear(); + } + + private void cleanupEarlierThan(final long minLiveSegment) { + final Iterator> toRemove = + segments.headMap(minLiveSegment, false).entrySet().iterator(); + + while (toRemove.hasNext()) { + final Map.Entry next = toRemove.next(); + toRemove.remove(); + final S segment = next.getValue(); + segment.close(); + try { + segment.destroy(); + } catch (final IOException e) { + log.error("Error destroying {}", segment, e); + } + } + } + + private long segmentIdFromSegmentName(final String segmentName, + final File parent) { + final int segmentSeparatorIndex = name.length(); + final char segmentSeparator = segmentName.charAt(segmentSeparatorIndex); + final String segmentIdString = segmentName.substring(segmentSeparatorIndex + 1); + final long segmentId; + + // old style segment name with date + if (segmentSeparator == '-') { + try { + segmentId = formatter.parse(segmentIdString).getTime() / segmentInterval; + } catch (final ParseException e) { + log.warn("Unable to parse segmentName {} to a date. This segment will be skipped", segmentName); + return -1L; + } + renameSegmentFile(parent, segmentName, segmentId); + } else { + // for both new formats (with : or .) 
parse segment ID identically + try { + segmentId = Long.parseLong(segmentIdString) / segmentInterval; + } catch (final NumberFormatException e) { + throw new ProcessorStateException("Unable to parse segment id as long from segmentName: " + segmentName); + } + + // intermediate segment name with : breaks KafkaStreams on Windows OS -> rename segment file to new name with . + if (segmentSeparator == ':') { + renameSegmentFile(parent, segmentName, segmentId); + } + } + + return segmentId; + + } + + private void renameSegmentFile(final File parent, + final String segmentName, + final long segmentId) { + final File newName = new File(parent, segmentName(segmentId)); + final File oldName = new File(parent, segmentName); + if (!oldName.renameTo(newName)) { + throw new ProcessorStateException("Unable to rename old style segment from: " + + oldName + + " to new name: " + + newName); + } + } + +} diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/AbstractStoreBuilder.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/AbstractStoreBuilder.java new file mode 100644 index 0000000..4dde3bd --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/AbstractStoreBuilder.java @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
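// Illustrative sketch (not part of the patch): the segment arithmetic used by
// AbstractSegments above, shown with concrete (hypothetical) numbers. Each segment covers
// segmentInterval milliseconds, and the on-disk name embeds the segment's start timestamp
// rather than its raw id.
class SegmentNamingExample {
    public static void main(final String[] args) {
        final String storeName = "example-window-store";
        final long segmentIntervalMs = 60_000L;           // 1-minute segments
        final long recordTimestampMs = 123_456_789L;

        final long segmentId = recordTimestampMs / segmentIntervalMs;            // 2057
        final String segmentName = storeName + "." + segmentId * segmentIntervalMs;

        System.out.println(segmentId);    // 2057
        System.out.println(segmentName);  // example-window-store.123420000
    }
}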
+ */
+package org.apache.kafka.streams.state.internals;
+
+import org.apache.kafka.common.serialization.Serde;
+import org.apache.kafka.common.utils.Time;
+import org.apache.kafka.streams.processor.StateStore;
+import org.apache.kafka.streams.state.StoreBuilder;
+
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Objects;
+
+abstract public class AbstractStoreBuilder<K, V, T extends StateStore> implements StoreBuilder<T> {
+    protected Map<String, String> logConfig = new HashMap<>();
+    protected final String name;
+    final Serde<K> keySerde;
+    final Serde<V> valueSerde;
+    final Time time;
+    boolean enableCaching;
+    boolean enableLogging = true;
+
+    public AbstractStoreBuilder(final String name,
+                                final Serde<K> keySerde,
+                                final Serde<V> valueSerde,
+                                final Time time) {
+        Objects.requireNonNull(name, "name cannot be null");
+        Objects.requireNonNull(time, "time cannot be null");
+        this.name = name;
+        this.keySerde = keySerde;
+        this.valueSerde = valueSerde;
+        this.time = time;
+    }
+
+    @Override
+    public StoreBuilder<T> withCachingEnabled() {
+        enableCaching = true;
+        return this;
+    }
+
+    @Override
+    public StoreBuilder<T> withCachingDisabled() {
+        enableCaching = false;
+        return this;
+    }
+
+    @Override
+    public StoreBuilder<T> withLoggingEnabled(final Map<String, String> config) {
+        Objects.requireNonNull(config, "config can't be null");
+        enableLogging = true;
+        logConfig.putAll(config);
+        return this;
+    }
+
+    @Override
+    public StoreBuilder<T> withLoggingDisabled() {
+        enableLogging = false;
+        logConfig.clear();
+        return this;
+    }
+
+    @Override
+    public Map<String, String> logConfig() {
+        return logConfig;
+    }
+
+    @Override
+    public boolean loggingEnabled() {
+        return enableLogging;
+    }
+
+    @Override
+    public String name() {
+        return name;
+    }
+}
diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/BatchWritingStore.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/BatchWritingStore.java
new file mode 100644
index 0000000..2ac1e3b
--- /dev/null
+++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/BatchWritingStore.java
@@ -0,0 +1,27 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.streams.state.internals;
+
+import org.apache.kafka.streams.KeyValue;
+import org.rocksdb.RocksDBException;
+import org.rocksdb.WriteBatch;
+
+public interface BatchWritingStore {
+    void addToBatch(final KeyValue<byte[], byte[]> record,
+                    final WriteBatch batch) throws RocksDBException;
+    void write(final WriteBatch batch) throws RocksDBException;
+}
diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/BlockBasedTableConfigWithAccessibleCache.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/BlockBasedTableConfigWithAccessibleCache.java
new file mode 100644
index 0000000..5a87c7b
--- /dev/null
+++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/BlockBasedTableConfigWithAccessibleCache.java
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.streams.state.internals;
+
+import org.rocksdb.BlockBasedTableConfig;
+import org.rocksdb.Cache;
+
+public class BlockBasedTableConfigWithAccessibleCache extends BlockBasedTableConfig {
+
+    private Cache blockCache = null;
+
+    @Override
+    public BlockBasedTableConfig setBlockCache(final Cache cache) {
+        blockCache = cache;
+        return super.setBlockCache(cache);
+    }
+
+    public Cache blockCache() {
+        return blockCache;
+    }
+}
diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/BufferKey.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/BufferKey.java
new file mode 100644
index 0000000..9a13aa0
--- /dev/null
+++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/BufferKey.java
@@ -0,0 +1,72 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.common.utils.Bytes; + +import java.util.Objects; + +public final class BufferKey implements Comparable { + private final long time; + private final Bytes key; + + BufferKey(final long time, final Bytes key) { + this.time = time; + this.key = key; + } + + long time() { + return time; + } + + Bytes key() { + return key; + } + + @Override + public boolean equals(final Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + final BufferKey bufferKey = (BufferKey) o; + return time == bufferKey.time && + Objects.equals(key, bufferKey.key); + } + + @Override + public int hashCode() { + return Objects.hash(time, key); + } + + @Override + public int compareTo(final BufferKey o) { + // ordering of keys within a time uses hashCode. + final int timeComparison = Long.compare(time, o.time); + return timeComparison == 0 ? key.compareTo(o.key) : timeComparison; + } + + @Override + public String toString() { + return "BufferKey{" + + "key=" + key + + ", time=" + time + + '}'; + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/BufferValue.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/BufferValue.java new file mode 100644 index 0000000..f27ab19 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/BufferValue.java @@ -0,0 +1,169 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.streams.processor.internals.ProcessorRecordContext; + +import java.nio.ByteBuffer; +import java.util.Arrays; +import java.util.Objects; + +import static org.apache.kafka.common.utils.Utils.getNullableArray; +import static org.apache.kafka.common.utils.Utils.getNullableSizePrefixedArray; + +public final class BufferValue { + private static final int NULL_VALUE_SENTINEL = -1; + private static final int OLD_PREV_DUPLICATE_VALUE_SENTINEL = -2; + private final byte[] priorValue; + private final byte[] oldValue; + private final byte[] newValue; + private final ProcessorRecordContext recordContext; + + BufferValue(final byte[] priorValue, + final byte[] oldValue, + final byte[] newValue, + final ProcessorRecordContext recordContext) { + this.oldValue = oldValue; + this.newValue = newValue; + this.recordContext = recordContext; + + // This de-duplicates the prior and old references. + // If they were already the same reference, the comparison is trivially fast, so we don't specifically check + // for that case. 
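+        // Keeping a single shared reference when the contents match also means residentMemorySizeEstimate()
+        // counts the shared array only once, and serialize() (which checks Arrays.equals on prior/old) writes
+        // the OLD_PREV_DUPLICATE_VALUE_SENTINEL instead of repeating the same bytes.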
+ if (Arrays.equals(priorValue, oldValue)) { + this.priorValue = oldValue; + } else { + this.priorValue = priorValue; + } + } + + byte[] priorValue() { + return priorValue; + } + + byte[] oldValue() { + return oldValue; + } + + byte[] newValue() { + return newValue; + } + + ProcessorRecordContext context() { + return recordContext; + } + + static BufferValue deserialize(final ByteBuffer buffer) { + final ProcessorRecordContext context = ProcessorRecordContext.deserialize(buffer); + + final byte[] priorValue = getNullableSizePrefixedArray(buffer); + + final byte[] oldValue; + final int oldValueLength = buffer.getInt(); + if (oldValueLength == OLD_PREV_DUPLICATE_VALUE_SENTINEL) { + oldValue = priorValue; + } else { + oldValue = getNullableArray(buffer, oldValueLength); + } + + final byte[] newValue = getNullableSizePrefixedArray(buffer); + + return new BufferValue(priorValue, oldValue, newValue, context); + } + + ByteBuffer serialize(final int endPadding) { + + final int sizeOfValueLength = Integer.BYTES; + + final int sizeOfPriorValue = priorValue == null ? 0 : priorValue.length; + final int sizeOfOldValue = oldValue == null || priorValue == oldValue ? 0 : oldValue.length; + final int sizeOfNewValue = newValue == null ? 0 : newValue.length; + + final byte[] serializedContext = recordContext.serialize(); + + final ByteBuffer buffer = ByteBuffer.allocate( + serializedContext.length + + sizeOfValueLength + sizeOfPriorValue + + sizeOfValueLength + sizeOfOldValue + + sizeOfValueLength + sizeOfNewValue + + endPadding + ); + + buffer.put(serializedContext); + + addValue(buffer, priorValue); + + if (oldValue == null) { + buffer.putInt(NULL_VALUE_SENTINEL); + } else if (Arrays.equals(priorValue, oldValue)) { + buffer.putInt(OLD_PREV_DUPLICATE_VALUE_SENTINEL); + } else { + buffer.putInt(sizeOfOldValue); + buffer.put(oldValue); + } + + addValue(buffer, newValue); + + return buffer; + } + + private static void addValue(final ByteBuffer buffer, final byte[] value) { + if (value == null) { + buffer.putInt(NULL_VALUE_SENTINEL); + } else { + buffer.putInt(value.length); + buffer.put(value); + } + } + + long residentMemorySizeEstimate() { + return (priorValue == null ? 0 : priorValue.length) + + (oldValue == null || priorValue == oldValue ? 0 : oldValue.length) + + (newValue == null ? 
0 : newValue.length) + + recordContext.residentMemorySizeEstimate(); + } + + @Override + public boolean equals(final Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + final BufferValue that = (BufferValue) o; + return Arrays.equals(priorValue, that.priorValue) && + Arrays.equals(oldValue, that.oldValue) && + Arrays.equals(newValue, that.newValue) && + Objects.equals(recordContext, that.recordContext); + } + + @Override + public int hashCode() { + int result = Objects.hash(recordContext); + result = 31 * result + Arrays.hashCode(priorValue); + result = 31 * result + Arrays.hashCode(oldValue); + result = 31 * result + Arrays.hashCode(newValue); + return result; + } + + @Override + public String toString() { + return "BufferValue{" + + "priorValue=" + Arrays.toString(priorValue) + + ", oldValue=" + Arrays.toString(oldValue) + + ", newValue=" + Arrays.toString(newValue) + + ", recordContext=" + recordContext + + '}'; + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/CacheFlushListener.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/CacheFlushListener.java new file mode 100644 index 0000000..c86d216 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/CacheFlushListener.java @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.streams.kstream.internals.Change; +import org.apache.kafka.streams.processor.api.Record; + +/** + * Listen to cache flush events + * @param key type + * @param value type + */ +public interface CacheFlushListener { + + /** + * Called when records are flushed from the {@link ThreadCache} + * @param key key of the entry + * @param newValue current value + * @param oldValue previous value + * @param timestamp timestamp of new value + */ + void apply(final K key, final V newValue, final V oldValue, final long timestamp); + + /** + * Called when records are flushed from the {@link ThreadCache} + */ + void apply(final Record> record); +} diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/CacheFunction.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/CacheFunction.java new file mode 100644 index 0000000..66ef2d7 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/CacheFunction.java @@ -0,0 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kafka.streams.state.internals;
+
+import org.apache.kafka.common.utils.Bytes;
+
+interface CacheFunction {
+    Bytes key(Bytes cacheKey);
+    Bytes cacheKey(Bytes cacheKey);
+}
diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/CachedStateStore.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/CachedStateStore.java
new file mode 100644
index 0000000..37758e1
--- /dev/null
+++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/CachedStateStore.java
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.streams.state.internals;
+
+public interface CachedStateStore<K, V> {
+    /**
+     * Set the {@link CacheFlushListener} to be notified when entries are flushed from the
+     * cache to the underlying {@link org.apache.kafka.streams.processor.StateStore}
+     * @param listener
+     * @param sendOldValues
+     */
+    boolean setFlushListener(final CacheFlushListener<K, V> listener,
+                             final boolean sendOldValues);
+
+    /**
+     * Flush only the cache but not the underlying state stores
+     *
+     * TODO: this is a hacky workaround for now; it should be removed once caching is decoupled from emitting
+     */
+    void flushCache();
+}
diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/CachingKeyValueStore.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/CachingKeyValueStore.java
new file mode 100644
index 0000000..6129034
--- /dev/null
+++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/CachingKeyValueStore.java
@@ -0,0 +1,368 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.common.serialization.Serializer; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.processor.ProcessorContext; +import org.apache.kafka.streams.processor.StateStore; +import org.apache.kafka.streams.processor.StateStoreContext; +import org.apache.kafka.streams.processor.internals.InternalProcessorContext; +import org.apache.kafka.streams.processor.internals.ProcessorRecordContext; +import org.apache.kafka.streams.state.KeyValueIterator; +import org.apache.kafka.streams.state.KeyValueStore; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.LinkedList; +import java.util.List; +import java.util.Objects; +import java.util.concurrent.locks.Lock; +import java.util.concurrent.locks.ReadWriteLock; +import java.util.concurrent.locks.ReentrantReadWriteLock; + +import static org.apache.kafka.streams.processor.internals.ProcessorContextUtils.asInternalProcessorContext; +import static org.apache.kafka.streams.state.internals.ExceptionUtils.executeAll; +import static org.apache.kafka.streams.state.internals.ExceptionUtils.throwSuppressed; + +public class CachingKeyValueStore + extends WrappedStateStore, byte[], byte[]> + implements KeyValueStore, CachedStateStore { + + private static final Logger LOG = LoggerFactory.getLogger(CachingKeyValueStore.class); + + private CacheFlushListener flushListener; + private boolean sendOldValues; + private String cacheName; + private InternalProcessorContext context; + private Thread streamThread; + private final ReadWriteLock lock = new ReentrantReadWriteLock(); + + CachingKeyValueStore(final KeyValueStore underlying) { + super(underlying); + } + + @Deprecated + @Override + public void init(final ProcessorContext context, + final StateStore root) { + initInternal(asInternalProcessorContext(context)); + super.init(context, root); + // save the stream thread as we only ever want to trigger a flush + // when the stream thread is the current thread. + streamThread = Thread.currentThread(); + } + + @Override + public void init(final StateStoreContext context, + final StateStore root) { + initInternal(asInternalProcessorContext(context)); + super.init(context, root); + // save the stream thread as we only ever want to trigger a flush + // when the stream thread is the current thread. + streamThread = Thread.currentThread(); + } + + private void initInternal(final InternalProcessorContext context) { + this.context = context; + + this.cacheName = ThreadCache.nameSpaceFromTaskIdAndStore(context.taskId().toString(), name()); + this.context.registerCacheFlushListener(cacheName, entries -> { + for (final ThreadCache.DirtyEntry entry : entries) { + putAndMaybeForward(entry, context); + } + }); + } + + private void putAndMaybeForward(final ThreadCache.DirtyEntry entry, + final InternalProcessorContext context) { + if (flushListener != null) { + final byte[] rawNewValue = entry.newValue(); + final byte[] rawOldValue = rawNewValue == null || sendOldValues ? 
wrapped().get(entry.key()) : null; + + // this is an optimization: if this key did not exist in underlying store and also not in the cache, + // we can skip flushing to downstream as well as writing to underlying store + if (rawNewValue != null || rawOldValue != null) { + // we need to get the old values if needed, and then put to store, and then flush + wrapped().put(entry.key(), entry.newValue()); + + final ProcessorRecordContext current = context.recordContext(); + context.setRecordContext(entry.entry().context()); + try { + flushListener.apply( + entry.key().get(), + rawNewValue, + sendOldValues ? rawOldValue : null, + entry.entry().context().timestamp()); + } finally { + context.setRecordContext(current); + } + } + } else { + wrapped().put(entry.key(), entry.newValue()); + } + } + + @Override + public boolean setFlushListener(final CacheFlushListener flushListener, + final boolean sendOldValues) { + this.flushListener = flushListener; + this.sendOldValues = sendOldValues; + + return true; + } + + @Override + public void put(final Bytes key, + final byte[] value) { + Objects.requireNonNull(key, "key cannot be null"); + validateStoreOpen(); + lock.writeLock().lock(); + try { + validateStoreOpen(); + // for null bytes, we still put it into cache indicating tombstones + putInternal(key, value); + } finally { + lock.writeLock().unlock(); + } + } + + private void putInternal(final Bytes key, + final byte[] value) { + context.cache().put( + cacheName, + key, + new LRUCacheEntry( + value, + context.headers(), + true, + context.offset(), + context.timestamp(), + context.partition(), + context.topic())); + } + + @Override + public byte[] putIfAbsent(final Bytes key, + final byte[] value) { + Objects.requireNonNull(key, "key cannot be null"); + validateStoreOpen(); + lock.writeLock().lock(); + try { + validateStoreOpen(); + final byte[] v = getInternal(key); + if (v == null) { + putInternal(key, value); + } + return v; + } finally { + lock.writeLock().unlock(); + } + } + + @Override + public void putAll(final List> entries) { + validateStoreOpen(); + lock.writeLock().lock(); + try { + validateStoreOpen(); + for (final KeyValue entry : entries) { + Objects.requireNonNull(entry.key, "key cannot be null"); + put(entry.key, entry.value); + } + } finally { + lock.writeLock().unlock(); + } + } + + @Override + public byte[] delete(final Bytes key) { + Objects.requireNonNull(key, "key cannot be null"); + validateStoreOpen(); + lock.writeLock().lock(); + try { + validateStoreOpen(); + return deleteInternal(key); + } finally { + lock.writeLock().unlock(); + } + } + + private byte[] deleteInternal(final Bytes key) { + final byte[] v = getInternal(key); + putInternal(key, null); + return v; + } + + @Override + public byte[] get(final Bytes key) { + Objects.requireNonNull(key, "key cannot be null"); + validateStoreOpen(); + final Lock theLock; + if (Thread.currentThread().equals(streamThread)) { + theLock = lock.writeLock(); + } else { + theLock = lock.readLock(); + } + theLock.lock(); + try { + validateStoreOpen(); + return getInternal(key); + } finally { + theLock.unlock(); + } + } + + private byte[] getInternal(final Bytes key) { + LRUCacheEntry entry = null; + if (context.cache() != null) { + entry = context.cache().get(cacheName, key); + } + if (entry == null) { + final byte[] rawValue = wrapped().get(key); + if (rawValue == null) { + return null; + } + // only update the cache if this call is on the streamThread + // as we don't want other threads to trigger an eviction/flush + if 
(Thread.currentThread().equals(streamThread)) { + context.cache().put(cacheName, key, new LRUCacheEntry(rawValue)); + } + return rawValue; + } else { + return entry.value(); + } + } + + @Override + public KeyValueIterator range(final Bytes from, + final Bytes to) { + if (Objects.nonNull(from) && Objects.nonNull(to) && from.compareTo(to) > 0) { + LOG.warn("Returning empty iterator for fetch with invalid key range: from > to. " + + "This may be due to range arguments set in the wrong order, " + + "or serdes that don't preserve ordering when lexicographically comparing the serialized bytes. " + + "Note that the built-in numerical serdes do not follow this for negative numbers"); + return KeyValueIterators.emptyIterator(); + } + + validateStoreOpen(); + final KeyValueIterator storeIterator = wrapped().range(from, to); + final ThreadCache.MemoryLRUCacheBytesIterator cacheIterator = context.cache().range(cacheName, from, to); + return new MergedSortedCacheKeyValueBytesStoreIterator(cacheIterator, storeIterator, true); + } + + @Override + public KeyValueIterator reverseRange(final Bytes from, + final Bytes to) { + if (Objects.nonNull(from) && Objects.nonNull(to) && from.compareTo(to) > 0) { + LOG.warn("Returning empty iterator for fetch with invalid key range: from > to. " + + "This may be due to range arguments set in the wrong order, " + + "or serdes that don't preserve ordering when lexicographically comparing the serialized bytes. " + + "Note that the built-in numerical serdes do not follow this for negative numbers"); + return KeyValueIterators.emptyIterator(); + } + + validateStoreOpen(); + final KeyValueIterator storeIterator = wrapped().reverseRange(from, to); + final ThreadCache.MemoryLRUCacheBytesIterator cacheIterator = context.cache().reverseRange(cacheName, from, to); + return new MergedSortedCacheKeyValueBytesStoreIterator(cacheIterator, storeIterator, false); + } + + @Override + public KeyValueIterator all() { + validateStoreOpen(); + final KeyValueIterator storeIterator = + new DelegatingPeekingKeyValueIterator<>(this.name(), wrapped().all()); + final ThreadCache.MemoryLRUCacheBytesIterator cacheIterator = context.cache().all(cacheName); + return new MergedSortedCacheKeyValueBytesStoreIterator(cacheIterator, storeIterator, true); + } + + @Override + public , P> KeyValueIterator prefixScan(final P prefix, final PS prefixKeySerializer) { + validateStoreOpen(); + final KeyValueIterator storeIterator = wrapped().prefixScan(prefix, prefixKeySerializer); + final Bytes from = Bytes.wrap(prefixKeySerializer.serialize(null, prefix)); + final Bytes to = Bytes.increment(from); + final ThreadCache.MemoryLRUCacheBytesIterator cacheIterator = context.cache().range(cacheName, from, to, false); + return new MergedSortedCacheKeyValueBytesStoreIterator(cacheIterator, storeIterator, true); + } + + @Override + public KeyValueIterator reverseAll() { + validateStoreOpen(); + final KeyValueIterator storeIterator = + new DelegatingPeekingKeyValueIterator<>(this.name(), wrapped().reverseAll()); + final ThreadCache.MemoryLRUCacheBytesIterator cacheIterator = context.cache().reverseAll(cacheName); + return new MergedSortedCacheKeyValueBytesStoreIterator(cacheIterator, storeIterator, false); + } + + @Override + public long approximateNumEntries() { + validateStoreOpen(); + lock.readLock().lock(); + try { + validateStoreOpen(); + return wrapped().approximateNumEntries(); + } finally { + lock.readLock().unlock(); + } + } + + @Override + public void flush() { + validateStoreOpen(); + lock.writeLock().lock(); + 
try { + validateStoreOpen(); + context.cache().flush(cacheName); + wrapped().flush(); + } finally { + lock.writeLock().unlock(); + } + } + + @Override + public void flushCache() { + validateStoreOpen(); + lock.writeLock().lock(); + try { + validateStoreOpen(); + context.cache().flush(cacheName); + } finally { + lock.writeLock().unlock(); + } + } + + @Override + public void close() { + lock.writeLock().lock(); + try { + final LinkedList suppressed = executeAll( + () -> context.cache().flush(cacheName), + () -> context.cache().close(cacheName), + wrapped()::close + ); + if (!suppressed.isEmpty()) { + throwSuppressed("Caught an exception while closing caching key value store for store " + name(), + suppressed); + } + } finally { + lock.writeLock().unlock(); + } + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/CachingSessionStore.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/CachingSessionStore.java new file mode 100644 index 0000000..9b07fe8 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/CachingSessionStore.java @@ -0,0 +1,499 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.kstream.Windowed; +import org.apache.kafka.streams.kstream.internals.SessionWindow; +import org.apache.kafka.streams.processor.ProcessorContext; +import org.apache.kafka.streams.processor.StateStore; +import org.apache.kafka.streams.processor.StateStoreContext; +import org.apache.kafka.streams.processor.internals.InternalProcessorContext; +import org.apache.kafka.streams.processor.internals.ProcessorRecordContext; +import org.apache.kafka.streams.processor.internals.RecordQueue; +import org.apache.kafka.streams.state.KeyValueIterator; +import org.apache.kafka.streams.state.SessionStore; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.LinkedList; +import java.util.NoSuchElementException; +import java.util.Objects; + +import static org.apache.kafka.streams.processor.internals.ProcessorContextUtils.asInternalProcessorContext; +import static org.apache.kafka.streams.state.internals.ExceptionUtils.executeAll; +import static org.apache.kafka.streams.state.internals.ExceptionUtils.throwSuppressed; + +class CachingSessionStore + extends WrappedStateStore, byte[], byte[]> + implements SessionStore, CachedStateStore { + + private static final Logger LOG = LoggerFactory.getLogger(CachingSessionStore.class); + + private final SessionKeySchema keySchema; + private final SegmentedCacheFunction cacheFunction; + private static final String INVALID_RANGE_WARN_MSG = "Returning empty iterator for fetch with invalid key range: from > to. " + + "This may be due to range arguments set in the wrong order, " + + "or serdes that don't preserve ordering when lexicographically comparing the serialized bytes. " + + "Note that the built-in numerical serdes do not follow this for negative numbers"; + + private String cacheName; + private InternalProcessorContext context; + private CacheFlushListener flushListener; + private boolean sendOldValues; + + private long maxObservedTimestamp; // Refers to the window end time (determines segmentId) + + CachingSessionStore(final SessionStore bytesStore, + final long segmentInterval) { + super(bytesStore); + this.keySchema = new SessionKeySchema(); + this.cacheFunction = new SegmentedCacheFunction(keySchema, segmentInterval); + this.maxObservedTimestamp = RecordQueue.UNKNOWN; + } + + @Deprecated + @Override + public void init(final ProcessorContext context, final StateStore root) { + initInternal(asInternalProcessorContext(context)); + super.init(context, root); + } + + @Override + public void init(final StateStoreContext context, final StateStore root) { + initInternal(asInternalProcessorContext(context)); + super.init(context, root); + } + + private void initInternal(final InternalProcessorContext context) { + this.context = context; + + cacheName = context.taskId() + "-" + name(); + context.registerCacheFlushListener(cacheName, entries -> { + for (final ThreadCache.DirtyEntry entry : entries) { + putAndMaybeForward(entry, context); + } + }); + } + + private void putAndMaybeForward(final ThreadCache.DirtyEntry entry, final InternalProcessorContext context) { + final Bytes binaryKey = cacheFunction.key(entry.key()); + final Windowed bytesKey = SessionKeySchema.from(binaryKey); + if (flushListener != null) { + final byte[] newValueBytes = entry.newValue(); + final byte[] oldValueBytes = newValueBytes == null || sendOldValues ? 
+ wrapped().fetchSession(bytesKey.key(), bytesKey.window().start(), bytesKey.window().end()) : null; + + // this is an optimization: if this key did not exist in underlying store and also not in the cache, + // we can skip flushing to downstream as well as writing to underlying store + if (newValueBytes != null || oldValueBytes != null) { + // we need to get the old values if needed, and then put to store, and then flush + wrapped().put(bytesKey, entry.newValue()); + + final ProcessorRecordContext current = context.recordContext(); + context.setRecordContext(entry.entry().context()); + try { + flushListener.apply( + binaryKey.get(), + newValueBytes, + sendOldValues ? oldValueBytes : null, + entry.entry().context().timestamp()); + } finally { + context.setRecordContext(current); + } + } + } else { + wrapped().put(bytesKey, entry.newValue()); + } + } + + @Override + public boolean setFlushListener(final CacheFlushListener flushListener, + final boolean sendOldValues) { + this.flushListener = flushListener; + this.sendOldValues = sendOldValues; + + return true; + } + + @Override + public void put(final Windowed key, final byte[] value) { + validateStoreOpen(); + final Bytes binaryKey = SessionKeySchema.toBinary(key); + final LRUCacheEntry entry = + new LRUCacheEntry( + value, + context.headers(), + true, + context.offset(), + context.timestamp(), + context.partition(), + context.topic()); + context.cache().put(cacheName, cacheFunction.cacheKey(binaryKey), entry); + + maxObservedTimestamp = Math.max(keySchema.segmentTimestamp(binaryKey), maxObservedTimestamp); + } + + @Override + public void remove(final Windowed sessionKey) { + validateStoreOpen(); + put(sessionKey, null); + } + + @Override + public KeyValueIterator, byte[]> findSessions(final Bytes key, + final long earliestSessionEndTime, + final long latestSessionStartTime) { + validateStoreOpen(); + + final PeekingKeyValueIterator cacheIterator = wrapped().persistent() ? + new CacheIteratorWrapper(key, earliestSessionEndTime, latestSessionStartTime, true) : + context.cache().range(cacheName, + cacheFunction.cacheKey(keySchema.lowerRangeFixedSize(key, earliestSessionEndTime)), + cacheFunction.cacheKey(keySchema.upperRangeFixedSize(key, latestSessionStartTime)) + ); + + final KeyValueIterator, byte[]> storeIterator = wrapped().findSessions(key, + earliestSessionEndTime, + latestSessionStartTime); + final HasNextCondition hasNextCondition = keySchema.hasNextCondition(key, + key, + earliestSessionEndTime, + latestSessionStartTime); + final PeekingKeyValueIterator filteredCacheIterator = + new FilteredCacheIterator(cacheIterator, hasNextCondition, cacheFunction); + return new MergedSortedCacheSessionStoreIterator(filteredCacheIterator, storeIterator, cacheFunction, true); + } + + @Override + public KeyValueIterator, byte[]> backwardFindSessions(final Bytes key, + final long earliestSessionEndTime, + final long latestSessionStartTime) { + validateStoreOpen(); + + final PeekingKeyValueIterator cacheIterator = wrapped().persistent() ? 
+ new CacheIteratorWrapper(key, earliestSessionEndTime, latestSessionStartTime, false) : + context.cache().reverseRange( + cacheName, + cacheFunction.cacheKey(keySchema.lowerRangeFixedSize(key, earliestSessionEndTime)), + cacheFunction.cacheKey(keySchema.upperRangeFixedSize(key, latestSessionStartTime) + ) + ); + + final KeyValueIterator, byte[]> storeIterator = wrapped().backwardFindSessions( + key, + earliestSessionEndTime, + latestSessionStartTime + ); + final HasNextCondition hasNextCondition = keySchema.hasNextCondition( + key, + key, + earliestSessionEndTime, + latestSessionStartTime + ); + final PeekingKeyValueIterator filteredCacheIterator = + new FilteredCacheIterator(cacheIterator, hasNextCondition, cacheFunction); + return new MergedSortedCacheSessionStoreIterator(filteredCacheIterator, storeIterator, cacheFunction, false); + } + + @Override + public KeyValueIterator, byte[]> findSessions(final Bytes keyFrom, + final Bytes keyTo, + final long earliestSessionEndTime, + final long latestSessionStartTime) { + if (keyFrom != null && keyTo != null && keyFrom.compareTo(keyTo) > 0) { + LOG.warn(INVALID_RANGE_WARN_MSG); + return KeyValueIterators.emptyIterator(); + } + + validateStoreOpen(); + + final Bytes cacheKeyFrom = keyFrom == null ? null : cacheFunction.cacheKey(keySchema.lowerRange(keyFrom, earliestSessionEndTime)); + final Bytes cacheKeyTo = keyTo == null ? null : cacheFunction.cacheKey(keySchema.upperRange(keyTo, latestSessionStartTime)); + final ThreadCache.MemoryLRUCacheBytesIterator cacheIterator = context.cache().range(cacheName, cacheKeyFrom, cacheKeyTo); + + final KeyValueIterator, byte[]> storeIterator = wrapped().findSessions( + keyFrom, keyTo, earliestSessionEndTime, latestSessionStartTime + ); + final HasNextCondition hasNextCondition = keySchema.hasNextCondition(keyFrom, + keyTo, + earliestSessionEndTime, + latestSessionStartTime); + final PeekingKeyValueIterator filteredCacheIterator = + new FilteredCacheIterator(cacheIterator, hasNextCondition, cacheFunction); + return new MergedSortedCacheSessionStoreIterator(filteredCacheIterator, storeIterator, cacheFunction, true); + } + + @Override + public KeyValueIterator, byte[]> backwardFindSessions(final Bytes keyFrom, + final Bytes keyTo, + final long earliestSessionEndTime, + final long latestSessionStartTime) { + if (keyFrom != null && keyTo != null && keyFrom.compareTo(keyTo) > 0) { + LOG.warn(INVALID_RANGE_WARN_MSG); + return KeyValueIterators.emptyIterator(); + } + + validateStoreOpen(); + + final Bytes cacheKeyFrom = keyFrom == null ? null : cacheFunction.cacheKey(keySchema.lowerRange(keyFrom, earliestSessionEndTime)); + final Bytes cacheKeyTo = keyTo == null ? 
null : cacheFunction.cacheKey(keySchema.upperRange(keyTo, latestSessionStartTime)); + final ThreadCache.MemoryLRUCacheBytesIterator cacheIterator = context.cache().reverseRange(cacheName, cacheKeyFrom, cacheKeyTo); + + final KeyValueIterator, byte[]> storeIterator = + wrapped().backwardFindSessions(keyFrom, keyTo, earliestSessionEndTime, latestSessionStartTime); + final HasNextCondition hasNextCondition = keySchema.hasNextCondition( + keyFrom, + keyTo, + earliestSessionEndTime, + latestSessionStartTime + ); + final PeekingKeyValueIterator filteredCacheIterator = + new FilteredCacheIterator(cacheIterator, hasNextCondition, cacheFunction); + return new MergedSortedCacheSessionStoreIterator(filteredCacheIterator, storeIterator, cacheFunction, false); + } + + @Override + public byte[] fetchSession(final Bytes key, final long earliestSessionEndTime, final long latestSessionStartTime) { + Objects.requireNonNull(key, "key cannot be null"); + validateStoreOpen(); + if (context.cache() == null) { + return wrapped().fetchSession(key, earliestSessionEndTime, latestSessionStartTime); + } else { + final Bytes bytesKey = SessionKeySchema.toBinary(key, earliestSessionEndTime, + latestSessionStartTime); + final Bytes cacheKey = cacheFunction.cacheKey(bytesKey); + final LRUCacheEntry entry = context.cache().get(cacheName, cacheKey); + if (entry == null) { + return wrapped().fetchSession(key, earliestSessionEndTime, latestSessionStartTime); + } else { + return entry.value(); + } + } + } + + @Override + public KeyValueIterator, byte[]> fetch(final Bytes key) { + Objects.requireNonNull(key, "key cannot be null"); + return findSessions(key, 0, Long.MAX_VALUE); + } + + @Override + public KeyValueIterator, byte[]> backwardFetch(final Bytes key) { + Objects.requireNonNull(key, "key cannot be null"); + return backwardFindSessions(key, 0, Long.MAX_VALUE); + } + + @Override + public KeyValueIterator, byte[]> fetch(final Bytes keyFrom, + final Bytes keyTo) { + return findSessions(keyFrom, keyTo, 0, Long.MAX_VALUE); + } + + @Override + public KeyValueIterator, byte[]> backwardFetch(final Bytes keyFrom, + final Bytes keyTo) { + return backwardFindSessions(keyFrom, keyTo, 0, Long.MAX_VALUE); + } + + public void flush() { + context.cache().flush(cacheName); + wrapped().flush(); + } + + @Override + public void flushCache() { + context.cache().flush(cacheName); + } + + public void close() { + final LinkedList suppressed = executeAll( + () -> context.cache().flush(cacheName), + () -> context.cache().close(cacheName), + wrapped()::close + ); + if (!suppressed.isEmpty()) { + throwSuppressed("Caught an exception while closing caching session store for store " + name(), + suppressed); + } + } + + private class CacheIteratorWrapper implements PeekingKeyValueIterator { + + private final long segmentInterval; + + private final Bytes keyFrom; + private final Bytes keyTo; + private final long latestSessionStartTime; + private final boolean forward; + + private long lastSegmentId; + + private long currentSegmentId; + private Bytes cacheKeyFrom; + private Bytes cacheKeyTo; + + private ThreadCache.MemoryLRUCacheBytesIterator current; + + private CacheIteratorWrapper(final Bytes key, + final long earliestSessionEndTime, + final long latestSessionStartTime, + final boolean forward) { + this(key, key, earliestSessionEndTime, latestSessionStartTime, forward); + } + + private CacheIteratorWrapper(final Bytes keyFrom, + final Bytes keyTo, + final long earliestSessionEndTime, + final long latestSessionStartTime, + final boolean forward) { + 
this.keyFrom = keyFrom; + this.keyTo = keyTo; + this.latestSessionStartTime = latestSessionStartTime; + this.segmentInterval = cacheFunction.getSegmentInterval(); + this.forward = forward; + + + if (forward) { + this.currentSegmentId = cacheFunction.segmentId(earliestSessionEndTime); + this.lastSegmentId = cacheFunction.segmentId(maxObservedTimestamp); + + setCacheKeyRange(earliestSessionEndTime, currentSegmentLastTime()); + this.current = context.cache().range(cacheName, cacheKeyFrom, cacheKeyTo); + } else { + this.lastSegmentId = cacheFunction.segmentId(earliestSessionEndTime); + this.currentSegmentId = cacheFunction.segmentId(maxObservedTimestamp); + + setCacheKeyRange(currentSegmentBeginTime(), Math.min(latestSessionStartTime, maxObservedTimestamp)); + this.current = context.cache().reverseRange(cacheName, cacheKeyFrom, cacheKeyTo); + } + } + + @Override + public boolean hasNext() { + if (current == null) { + return false; + } + + if (current.hasNext()) { + return true; + } + + while (!current.hasNext()) { + getNextSegmentIterator(); + if (current == null) { + return false; + } + } + return true; + } + + @Override + public Bytes peekNextKey() { + if (!hasNext()) { + throw new NoSuchElementException(); + } + return current.peekNextKey(); + } + + @Override + public KeyValue peekNext() { + if (!hasNext()) { + throw new NoSuchElementException(); + } + return current.peekNext(); + } + + @Override + public KeyValue next() { + if (!hasNext()) { + throw new NoSuchElementException(); + } + return current.next(); + } + + @Override + public void close() { + current.close(); + } + + private long currentSegmentBeginTime() { + return currentSegmentId * segmentInterval; + } + + private long currentSegmentLastTime() { + return currentSegmentBeginTime() + segmentInterval - 1; + } + + private void getNextSegmentIterator() { + if (forward) { + ++currentSegmentId; + lastSegmentId = cacheFunction.segmentId(maxObservedTimestamp); + + if (currentSegmentId > lastSegmentId) { + current = null; + return; + } + + setCacheKeyRange(currentSegmentBeginTime(), currentSegmentLastTime()); + + current.close(); + + current = context.cache().range(cacheName, cacheKeyFrom, cacheKeyTo); + } else { + --currentSegmentId; + + if (currentSegmentId < lastSegmentId) { + current = null; + return; + } + + setCacheKeyRange(currentSegmentBeginTime(), currentSegmentLastTime()); + + current.close(); + + current = context.cache().reverseRange(cacheName, cacheKeyFrom, cacheKeyTo); + } + + } + + private void setCacheKeyRange(final long lowerRangeEndTime, final long upperRangeEndTime) { + if (cacheFunction.segmentId(lowerRangeEndTime) != cacheFunction.segmentId(upperRangeEndTime)) { + throw new IllegalStateException("Error iterating over segments: segment interval has changed"); + } + + if (keyFrom.equals(keyTo)) { + cacheKeyFrom = cacheFunction.cacheKey(segmentLowerRangeFixedSize(keyFrom, lowerRangeEndTime)); + cacheKeyTo = cacheFunction.cacheKey(segmentUpperRangeFixedSize(keyTo, upperRangeEndTime)); + } else { + cacheKeyFrom = cacheFunction.cacheKey(keySchema.lowerRange(keyFrom, lowerRangeEndTime), currentSegmentId); + cacheKeyTo = cacheFunction.cacheKey(keySchema.upperRange(keyTo, latestSessionStartTime), currentSegmentId); + } + } + + private Bytes segmentLowerRangeFixedSize(final Bytes key, final long segmentBeginTime) { + final Windowed sessionKey = new Windowed<>(key, new SessionWindow(0, Math.max(0, segmentBeginTime))); + return SessionKeySchema.toBinary(sessionKey); + } + + private Bytes segmentUpperRangeFixedSize(final Bytes 
key, final long segmentEndTime) { + final Windowed sessionKey = new Windowed<>(key, new SessionWindow(Math.min(latestSessionStartTime, segmentEndTime), segmentEndTime)); + return SessionKeySchema.toBinary(sessionKey); + } + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/CachingWindowStore.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/CachingWindowStore.java new file mode 100644 index 0000000..fa04ac8 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/CachingWindowStore.java @@ -0,0 +1,595 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.kstream.Windowed; +import org.apache.kafka.streams.processor.ProcessorContext; +import org.apache.kafka.streams.processor.StateStore; +import org.apache.kafka.streams.processor.StateStoreContext; +import org.apache.kafka.streams.processor.internals.InternalProcessorContext; +import org.apache.kafka.streams.processor.internals.ProcessorRecordContext; +import org.apache.kafka.streams.processor.internals.ProcessorStateManager; +import org.apache.kafka.streams.processor.internals.RecordQueue; +import org.apache.kafka.streams.state.KeyValueIterator; +import org.apache.kafka.streams.state.StateSerdes; +import org.apache.kafka.streams.state.WindowStore; +import org.apache.kafka.streams.state.WindowStoreIterator; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.LinkedList; +import java.util.NoSuchElementException; +import java.util.concurrent.atomic.AtomicLong; + +import static org.apache.kafka.streams.processor.internals.ProcessorContextUtils.asInternalProcessorContext; +import static org.apache.kafka.streams.state.internals.ExceptionUtils.executeAll; +import static org.apache.kafka.streams.state.internals.ExceptionUtils.throwSuppressed; + +class CachingWindowStore + extends WrappedStateStore, byte[], byte[]> + implements WindowStore, CachedStateStore { + + private static final Logger LOG = LoggerFactory.getLogger(CachingWindowStore.class); + + private final long windowSize; + private final SegmentedCacheFunction cacheFunction; + private final SegmentedBytesStore.KeySchema keySchema = new WindowKeySchema(); + + private String cacheName; + private boolean sendOldValues; + private InternalProcessorContext context; + private StateSerdes bytesSerdes; + private CacheFlushListener flushListener; + + private final AtomicLong maxObservedTimestamp; + + CachingWindowStore(final WindowStore underlying, + final long windowSize, + final long segmentInterval) { + super(underlying); + 
this.windowSize = windowSize; + this.cacheFunction = new SegmentedCacheFunction(keySchema, segmentInterval); + this.maxObservedTimestamp = new AtomicLong(RecordQueue.UNKNOWN); + } + + @Deprecated + @Override + public void init(final ProcessorContext context, final StateStore root) { + initInternal(asInternalProcessorContext(context)); + super.init(context, root); + } + + @Override + public void init(final StateStoreContext context, final StateStore root) { + initInternal(asInternalProcessorContext(context)); + super.init(context, root); + } + + private void initInternal(final InternalProcessorContext context) { + this.context = context; + final String topic = ProcessorStateManager.storeChangelogTopic(context.applicationId(), name(), context.taskId().topologyName()); + + bytesSerdes = new StateSerdes<>( + topic, + Serdes.Bytes(), + Serdes.ByteArray()); + cacheName = context.taskId() + "-" + name(); + + context.registerCacheFlushListener(cacheName, entries -> { + for (final ThreadCache.DirtyEntry entry : entries) { + putAndMaybeForward(entry, context); + } + }); + } + + private void putAndMaybeForward(final ThreadCache.DirtyEntry entry, + final InternalProcessorContext context) { + final byte[] binaryWindowKey = cacheFunction.key(entry.key()).get(); + final Windowed windowedKeyBytes = WindowKeySchema.fromStoreBytesKey(binaryWindowKey, windowSize); + final long windowStartTimestamp = windowedKeyBytes.window().start(); + final Bytes binaryKey = windowedKeyBytes.key(); + if (flushListener != null) { + final byte[] rawNewValue = entry.newValue(); + final byte[] rawOldValue = rawNewValue == null || sendOldValues ? + wrapped().fetch(binaryKey, windowStartTimestamp) : null; + + // this is an optimization: if this key did not exist in underlying store and also not in the cache, + // we can skip flushing to downstream as well as writing to underlying store + if (rawNewValue != null || rawOldValue != null) { + // we need to get the old values if needed, and then put to store, and then flush + wrapped().put(binaryKey, entry.newValue(), windowStartTimestamp); + + final ProcessorRecordContext current = context.recordContext(); + context.setRecordContext(entry.entry().context()); + try { + flushListener.apply( + binaryWindowKey, + rawNewValue, + sendOldValues ? rawOldValue : null, + entry.entry().context().timestamp()); + } finally { + context.setRecordContext(current); + } + } + } else { + wrapped().put(binaryKey, entry.newValue(), windowStartTimestamp); + } + } + + @Override + public boolean setFlushListener(final CacheFlushListener flushListener, + final boolean sendOldValues) { + this.flushListener = flushListener; + this.sendOldValues = sendOldValues; + + return true; + } + + + @Override + public synchronized void put(final Bytes key, + final byte[] value, + final long windowStartTimestamp) { + // since this function may not access the underlying inner store, we need to validate + // if store is open outside as well. 
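+        // The put below writes only to the ThreadCache; the wrapped store (and any flush listener) sees the
+        // record later, when the cache evicts or flushes this dirty entry and the listener registered in
+        // initInternal() invokes putAndMaybeForward().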
+ validateStoreOpen(); + + final Bytes keyBytes = WindowKeySchema.toStoreKeyBinary(key, windowStartTimestamp, 0); + final LRUCacheEntry entry = + new LRUCacheEntry( + value, + context.headers(), + true, + context.offset(), + context.timestamp(), + context.partition(), + context.topic()); + context.cache().put(cacheName, cacheFunction.cacheKey(keyBytes), entry); + + maxObservedTimestamp.set(Math.max(keySchema.segmentTimestamp(keyBytes), maxObservedTimestamp.get())); + } + + @Override + public byte[] fetch(final Bytes key, + final long timestamp) { + validateStoreOpen(); + final Bytes bytesKey = WindowKeySchema.toStoreKeyBinary(key, timestamp, 0); + final Bytes cacheKey = cacheFunction.cacheKey(bytesKey); + if (context.cache() == null) { + return wrapped().fetch(key, timestamp); + } + final LRUCacheEntry entry = context.cache().get(cacheName, cacheKey); + if (entry == null) { + return wrapped().fetch(key, timestamp); + } else { + return entry.value(); + } + } + + @Override + public synchronized WindowStoreIterator fetch(final Bytes key, + final long timeFrom, + final long timeTo) { + // since this function may not access the underlying inner store, we need to validate + // if store is open outside as well. + validateStoreOpen(); + + final WindowStoreIterator underlyingIterator = wrapped().fetch(key, timeFrom, timeTo); + if (context.cache() == null) { + return underlyingIterator; + } + + final PeekingKeyValueIterator cacheIterator = wrapped().persistent() ? + new CacheIteratorWrapper(key, timeFrom, timeTo, true) : + context.cache().range( + cacheName, + cacheFunction.cacheKey(keySchema.lowerRangeFixedSize(key, timeFrom)), + cacheFunction.cacheKey(keySchema.upperRangeFixedSize(key, timeTo)) + ); + + final HasNextCondition hasNextCondition = keySchema.hasNextCondition(key, key, timeFrom, timeTo); + final PeekingKeyValueIterator filteredCacheIterator = + new FilteredCacheIterator(cacheIterator, hasNextCondition, cacheFunction); + + return new MergedSortedCacheWindowStoreIterator(filteredCacheIterator, underlyingIterator, true); + } + + @Override + public synchronized WindowStoreIterator backwardFetch(final Bytes key, + final long timeFrom, + final long timeTo) { + // since this function may not access the underlying inner store, we need to validate + // if store is open outside as well. + validateStoreOpen(); + + final WindowStoreIterator underlyingIterator = wrapped().backwardFetch(key, timeFrom, timeTo); + if (context.cache() == null) { + return underlyingIterator; + } + + final PeekingKeyValueIterator cacheIterator = wrapped().persistent() ? + new CacheIteratorWrapper(key, timeFrom, timeTo, false) : + context.cache().reverseRange( + cacheName, + cacheFunction.cacheKey(keySchema.lowerRangeFixedSize(key, timeFrom)), + cacheFunction.cacheKey(keySchema.upperRangeFixedSize(key, timeTo)) + ); + + final HasNextCondition hasNextCondition = keySchema.hasNextCondition(key, key, timeFrom, timeTo); + final PeekingKeyValueIterator filteredCacheIterator = + new FilteredCacheIterator(cacheIterator, hasNextCondition, cacheFunction); + + return new MergedSortedCacheWindowStoreIterator(filteredCacheIterator, underlyingIterator, false); + } + + @Override + public KeyValueIterator, byte[]> fetch(final Bytes keyFrom, + final Bytes keyTo, + final long timeFrom, + final long timeTo) { + if (keyFrom != null && keyTo != null && keyFrom.compareTo(keyTo) > 0) { + LOG.warn("Returning empty iterator for fetch with invalid key range: from > to. 
" + + "This may be due to range arguments set in the wrong order, " + + "or serdes that don't preserve ordering when lexicographically comparing the serialized bytes. " + + "Note that the built-in numerical serdes do not follow this for negative numbers"); + return KeyValueIterators.emptyIterator(); + } + + // since this function may not access the underlying inner store, we need to validate + // if store is open outside as well. + validateStoreOpen(); + + final KeyValueIterator, byte[]> underlyingIterator = + wrapped().fetch(keyFrom, keyTo, timeFrom, timeTo); + if (context.cache() == null) { + return underlyingIterator; + } + + final PeekingKeyValueIterator cacheIterator = wrapped().persistent() ? + new CacheIteratorWrapper(keyFrom, keyTo, timeFrom, timeTo, true) : + context.cache().range( + cacheName, + keyFrom == null ? null : cacheFunction.cacheKey(keySchema.lowerRange(keyFrom, timeFrom)), + keyTo == null ? null : cacheFunction.cacheKey(keySchema.upperRange(keyTo, timeTo)) + ); + + final HasNextCondition hasNextCondition = keySchema.hasNextCondition(keyFrom, keyTo, timeFrom, timeTo); + final PeekingKeyValueIterator filteredCacheIterator = + new FilteredCacheIterator(cacheIterator, hasNextCondition, cacheFunction); + + return new MergedSortedCacheWindowStoreKeyValueIterator( + filteredCacheIterator, + underlyingIterator, + bytesSerdes, + windowSize, + cacheFunction, + true + ); + } + + @Override + public KeyValueIterator, byte[]> backwardFetch(final Bytes keyFrom, + final Bytes keyTo, + final long timeFrom, + final long timeTo) { + if (keyFrom != null && keyTo != null && keyFrom.compareTo(keyTo) > 0) { + LOG.warn("Returning empty iterator for fetch with invalid key range: from > to. " + + "This may be due to serdes that don't preserve ordering when lexicographically comparing the serialized bytes. " + + "Note that the built-in numerical serdes do not follow this for negative numbers"); + return KeyValueIterators.emptyIterator(); + } + + // since this function may not access the underlying inner store, we need to validate + // if store is open outside as well. + validateStoreOpen(); + + final KeyValueIterator, byte[]> underlyingIterator = + wrapped().backwardFetch(keyFrom, keyTo, timeFrom, timeTo); + if (context.cache() == null) { + return underlyingIterator; + } + + final PeekingKeyValueIterator cacheIterator = wrapped().persistent() ? + new CacheIteratorWrapper(keyFrom, keyTo, timeFrom, timeTo, false) : + context.cache().reverseRange( + cacheName, + keyFrom == null ? null : cacheFunction.cacheKey(keySchema.lowerRange(keyFrom, timeFrom)), + keyTo == null ? 
null : cacheFunction.cacheKey(keySchema.upperRange(keyTo, timeTo)) + ); + + final HasNextCondition hasNextCondition = keySchema.hasNextCondition(keyFrom, keyTo, timeFrom, timeTo); + final PeekingKeyValueIterator filteredCacheIterator = + new FilteredCacheIterator(cacheIterator, hasNextCondition, cacheFunction); + + return new MergedSortedCacheWindowStoreKeyValueIterator( + filteredCacheIterator, + underlyingIterator, + bytesSerdes, + windowSize, + cacheFunction, + false + ); + } + + @Override + public KeyValueIterator, byte[]> fetchAll(final long timeFrom, + final long timeTo) { + validateStoreOpen(); + + final KeyValueIterator, byte[]> underlyingIterator = wrapped().fetchAll(timeFrom, timeTo); + final ThreadCache.MemoryLRUCacheBytesIterator cacheIterator = context.cache().all(cacheName); + + final HasNextCondition hasNextCondition = keySchema.hasNextCondition(null, null, timeFrom, timeTo); + final PeekingKeyValueIterator filteredCacheIterator = + new FilteredCacheIterator(cacheIterator, hasNextCondition, cacheFunction); + return new MergedSortedCacheWindowStoreKeyValueIterator( + filteredCacheIterator, + underlyingIterator, + bytesSerdes, + windowSize, + cacheFunction, + true + ); + } + + @Override + public KeyValueIterator, byte[]> backwardFetchAll(final long timeFrom, + final long timeTo) { + validateStoreOpen(); + + final KeyValueIterator, byte[]> underlyingIterator = wrapped().backwardFetchAll(timeFrom, timeTo); + final ThreadCache.MemoryLRUCacheBytesIterator cacheIterator = context.cache().reverseAll(cacheName); + + final HasNextCondition hasNextCondition = keySchema.hasNextCondition(null, null, timeFrom, timeTo); + final PeekingKeyValueIterator filteredCacheIterator = + new FilteredCacheIterator(cacheIterator, hasNextCondition, cacheFunction); + + return new MergedSortedCacheWindowStoreKeyValueIterator( + filteredCacheIterator, + underlyingIterator, + bytesSerdes, + windowSize, + cacheFunction, + false + ); + } + + @Override + public KeyValueIterator, byte[]> all() { + validateStoreOpen(); + + final KeyValueIterator, byte[]> underlyingIterator = wrapped().all(); + final ThreadCache.MemoryLRUCacheBytesIterator cacheIterator = context.cache().all(cacheName); + + return new MergedSortedCacheWindowStoreKeyValueIterator( + cacheIterator, + underlyingIterator, + bytesSerdes, + windowSize, + cacheFunction, + true + ); + } + + @Override + public KeyValueIterator, byte[]> backwardAll() { + validateStoreOpen(); + + final KeyValueIterator, byte[]> underlyingIterator = wrapped().backwardAll(); + final ThreadCache.MemoryLRUCacheBytesIterator cacheIterator = context.cache().reverseAll(cacheName); + + return new MergedSortedCacheWindowStoreKeyValueIterator( + cacheIterator, + underlyingIterator, + bytesSerdes, + windowSize, + cacheFunction, + false + ); + } + + @Override + public synchronized void flush() { + context.cache().flush(cacheName); + wrapped().flush(); + } + + @Override + public void flushCache() { + context.cache().flush(cacheName); + } + + @Override + public synchronized void close() { + final LinkedList suppressed = executeAll( + () -> context.cache().flush(cacheName), + () -> context.cache().close(cacheName), + wrapped()::close + ); + if (!suppressed.isEmpty()) { + throwSuppressed("Caught an exception while closing caching window store for store " + name(), + suppressed); + } + } + + + private class CacheIteratorWrapper implements PeekingKeyValueIterator { + + private final long segmentInterval; + private final Bytes keyFrom; + private final Bytes keyTo; + private final long timeTo; 
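+        // This wrapper walks the cache one segment at a time: currentSegmentId and lastSegmentId
+        // bound the walk, and cacheKeyFrom/cacheKeyTo are recomputed per segment so that a
+        // time-range fetch never scans cache entries outside the requested time window.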
+ private final boolean forward; + + private long lastSegmentId; + private long currentSegmentId; + private Bytes cacheKeyFrom; + private Bytes cacheKeyTo; + + private ThreadCache.MemoryLRUCacheBytesIterator current; + + private CacheIteratorWrapper(final Bytes key, + final long timeFrom, + final long timeTo, + final boolean forward) { + this(key, key, timeFrom, timeTo, forward); + } + + private CacheIteratorWrapper(final Bytes keyFrom, + final Bytes keyTo, + final long timeFrom, + final long timeTo, + final boolean forward) { + this.keyFrom = keyFrom; + this.keyTo = keyTo; + this.timeTo = timeTo; + this.forward = forward; + + this.segmentInterval = cacheFunction.getSegmentInterval(); + + if (forward) { + this.lastSegmentId = cacheFunction.segmentId(Math.min(timeTo, maxObservedTimestamp.get())); + this.currentSegmentId = cacheFunction.segmentId(timeFrom); + + setCacheKeyRange(timeFrom, currentSegmentLastTime()); + this.current = context.cache().range(cacheName, cacheKeyFrom, cacheKeyTo); + } else { + this.currentSegmentId = cacheFunction.segmentId(Math.min(timeTo, maxObservedTimestamp.get())); + this.lastSegmentId = cacheFunction.segmentId(timeFrom); + + setCacheKeyRange(currentSegmentBeginTime(), Math.min(timeTo, maxObservedTimestamp.get())); + this.current = context.cache().reverseRange(cacheName, cacheKeyFrom, cacheKeyTo); + } + } + + @Override + public boolean hasNext() { + if (current == null) { + return false; + } + + if (current.hasNext()) { + return true; + } + + while (!current.hasNext()) { + getNextSegmentIterator(); + if (current == null) { + return false; + } + } + return true; + } + + @Override + public Bytes peekNextKey() { + if (!hasNext()) { + throw new NoSuchElementException(); + } + return current.peekNextKey(); + } + + @Override + public KeyValue peekNext() { + if (!hasNext()) { + throw new NoSuchElementException(); + } + return current.peekNext(); + } + + @Override + public KeyValue next() { + if (!hasNext()) { + throw new NoSuchElementException(); + } + return current.next(); + } + + @Override + public void close() { + current.close(); + } + + private long currentSegmentBeginTime() { + return currentSegmentId * segmentInterval; + } + + private long currentSegmentLastTime() { + return Math.min(timeTo, currentSegmentBeginTime() + segmentInterval - 1); + } + + private void getNextSegmentIterator() { + if (forward) { + ++currentSegmentId; + // updating as maxObservedTimestamp can change while iterating + lastSegmentId = cacheFunction.segmentId(Math.min(timeTo, maxObservedTimestamp.get())); + + if (currentSegmentId > lastSegmentId) { + current = null; + return; + } + + setCacheKeyRange(currentSegmentBeginTime(), currentSegmentLastTime()); + + current.close(); + + current = context.cache().range(cacheName, cacheKeyFrom, cacheKeyTo); + } else { + --currentSegmentId; + + // last segment id is stable when iterating backward, therefore no need to update + if (currentSegmentId < lastSegmentId) { + current = null; + return; + } + + setCacheKeyRange(currentSegmentBeginTime(), currentSegmentLastTime()); + + current.close(); + + current = context.cache().reverseRange(cacheName, cacheKeyFrom, cacheKeyTo); + } + } + + private void setCacheKeyRange(final long lowerRangeEndTime, final long upperRangeEndTime) { + if (cacheFunction.segmentId(lowerRangeEndTime) != cacheFunction.segmentId(upperRangeEndTime)) { + throw new IllegalStateException("Error iterating over segments: segment interval has changed"); + } + + if (keyFrom != null && keyTo != null && keyFrom.equals(keyTo)) { + 
cacheKeyFrom = cacheFunction.cacheKey(segmentLowerRangeFixedSize(keyFrom, lowerRangeEndTime)); + cacheKeyTo = cacheFunction.cacheKey(segmentUpperRangeFixedSize(keyTo, upperRangeEndTime)); + } else { + cacheKeyFrom = keyFrom == null ? null : + cacheFunction.cacheKey(keySchema.lowerRange(keyFrom, lowerRangeEndTime), currentSegmentId); + cacheKeyTo = keyTo == null ? null : + cacheFunction.cacheKey(keySchema.upperRange(keyTo, timeTo), currentSegmentId); + } + } + + private Bytes segmentLowerRangeFixedSize(final Bytes key, final long segmentBeginTime) { + return WindowKeySchema.toStoreKeyBinary(key, Math.max(0, segmentBeginTime), 0); + } + + private Bytes segmentUpperRangeFixedSize(final Bytes key, final long segmentEndTime) { + return WindowKeySchema.toStoreKeyBinary(key, segmentEndTime, Integer.MAX_VALUE); + } + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/ChangeLoggingKeyValueBytesStore.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/ChangeLoggingKeyValueBytesStore.java new file mode 100644 index 0000000..88c9292 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/ChangeLoggingKeyValueBytesStore.java @@ -0,0 +1,145 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+package org.apache.kafka.streams.state.internals;
+
+import org.apache.kafka.common.serialization.Serializer;
+import org.apache.kafka.common.utils.Bytes;
+import org.apache.kafka.streams.KeyValue;
+import org.apache.kafka.streams.processor.ProcessorContext;
+import org.apache.kafka.streams.processor.StateStore;
+import org.apache.kafka.streams.processor.StateStoreContext;
+import org.apache.kafka.streams.processor.internals.InternalProcessorContext;
+import org.apache.kafka.streams.state.KeyValueIterator;
+import org.apache.kafka.streams.state.KeyValueStore;
+
+import java.util.List;
+
+import static org.apache.kafka.streams.processor.internals.ProcessorContextUtils.asInternalProcessorContext;
+
+public class ChangeLoggingKeyValueBytesStore
+    extends WrappedStateStore<KeyValueStore<Bytes, byte[]>, byte[], byte[]>
+    implements KeyValueStore<Bytes, byte[]> {
+
+    InternalProcessorContext context;
+
+    ChangeLoggingKeyValueBytesStore(final KeyValueStore<Bytes, byte[]> inner) {
+        super(inner);
+    }
+
+    @Deprecated
+    @Override
+    public void init(final ProcessorContext context,
+                     final StateStore root) {
+        this.context = asInternalProcessorContext(context);
+        super.init(context, root);
+        maybeSetEvictionListener();
+    }
+
+    @Override
+    public void init(final StateStoreContext context,
+                     final StateStore root) {
+        this.context = asInternalProcessorContext(context);
+        super.init(context, root);
+        maybeSetEvictionListener();
+    }
+
+    private void maybeSetEvictionListener() {
+        // if the inner store is an LRU cache, add the eviction listener to log removed record
+        if (wrapped() instanceof MemoryLRUCache) {
+            ((MemoryLRUCache) wrapped()).setWhenEldestRemoved((key, value) -> {
+                // pass null to indicate removal
+                log(key, null);
+            });
+        }
+    }
+
+    @Override
+    public long approximateNumEntries() {
+        return wrapped().approximateNumEntries();
+    }
+
+    @Override
+    public void put(final Bytes key,
+                    final byte[] value) {
+        wrapped().put(key, value);
+        log(key, value);
+    }
+
+    @Override
+    public byte[] putIfAbsent(final Bytes key,
+                              final byte[] value) {
+        final byte[] previous = wrapped().putIfAbsent(key, value);
+        if (previous == null) {
+            // then it was absent
+            log(key, value);
+        }
+        return previous;
+    }
+
+    @Override
+    public void putAll(final List<KeyValue<Bytes, byte[]>> entries) {
+        wrapped().putAll(entries);
+        for (final KeyValue<Bytes, byte[]> entry : entries) {
+            log(entry.key, entry.value);
+        }
+    }
+
+    @Override
+    public <PS extends Serializer<P>, P> KeyValueIterator<Bytes, byte[]> prefixScan(final P prefix,
+                                                                                    final PS prefixKeySerializer) {
+        return wrapped().prefixScan(prefix, prefixKeySerializer);
+    }
+
+    @Override
+    public byte[] delete(final Bytes key) {
+        final byte[] oldValue = wrapped().delete(key);
+        log(key, null);
+        return oldValue;
+    }
+
+    @Override
+    public byte[] get(final Bytes key) {
+        return wrapped().get(key);
+    }
+
+    @Override
+    public KeyValueIterator<Bytes, byte[]> range(final Bytes from,
+                                                 final Bytes to) {
+        return wrapped().range(from, to);
+    }
+
+    @Override
+    public KeyValueIterator<Bytes, byte[]> reverseRange(final Bytes from,
+                                                        final Bytes to) {
+        return wrapped().reverseRange(from, to);
+    }
+
+    @Override
+    public KeyValueIterator<Bytes, byte[]> all() {
+        return wrapped().all();
+    }
+
+    @Override
+    public KeyValueIterator<Bytes, byte[]> reverseAll() {
+        return wrapped().reverseAll();
+    }
+
+    void log(final Bytes key,
+             final byte[] value) {
+        context.logChange(name(), key, value, context.timestamp());
+    }
+}
diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/ChangeLoggingListValueBytesStore.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/ChangeLoggingListValueBytesStore.java
new file mode 100644
index
0000000..c01594b --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/ChangeLoggingListValueBytesStore.java @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.streams.state.KeyValueStore; + +public class ChangeLoggingListValueBytesStore extends ChangeLoggingKeyValueBytesStore { + + ChangeLoggingListValueBytesStore(final KeyValueStore inner) { + super(inner); + } + + @Override + public void put(final Bytes key, final byte[] value) { + wrapped().put(key, value); + // the provided new value will be added to the list in the inner put() + // we need to log the full new list and thus call get() on the inner store below + // if the value is a tombstone, we delete the whole list and thus can save the get call + if (value == null) { + log(key, null); + } else { + log(key, wrapped().get(key)); + } + } + + @Override + public byte[] putIfAbsent(final Bytes key, final byte[] value) { + final byte[] oldValue = wrapped().get(key); + + if (oldValue != null) { + put(key, value); + } + + // TODO: here we always return null so that deser would not fail. + // we only do this since we know the only caller (stream-stream join processor) + // would not need the actual value at all + return null; + } + +} diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/ChangeLoggingSessionBytesStore.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/ChangeLoggingSessionBytesStore.java new file mode 100644 index 0000000..baa9846 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/ChangeLoggingSessionBytesStore.java @@ -0,0 +1,117 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+package org.apache.kafka.streams.state.internals;
+
+import org.apache.kafka.common.utils.Bytes;
+import org.apache.kafka.streams.kstream.Windowed;
+import org.apache.kafka.streams.processor.ProcessorContext;
+import org.apache.kafka.streams.processor.StateStore;
+import org.apache.kafka.streams.processor.StateStoreContext;
+import org.apache.kafka.streams.processor.internals.InternalProcessorContext;
+import org.apache.kafka.streams.state.KeyValueIterator;
+import org.apache.kafka.streams.state.SessionStore;
+
+import static org.apache.kafka.streams.processor.internals.ProcessorContextUtils.asInternalProcessorContext;
+
+/**
+ * Simple wrapper around a {@link SessionStore} to support writing
+ * updates to a changelog
+ */
+class ChangeLoggingSessionBytesStore
+    extends WrappedStateStore<SessionStore<Bytes, byte[]>, byte[], byte[]>
+    implements SessionStore<Bytes, byte[]> {
+
+    private InternalProcessorContext context;
+
+    ChangeLoggingSessionBytesStore(final SessionStore<Bytes, byte[]> bytesStore) {
+        super(bytesStore);
+    }
+
+    @Deprecated
+    @Override
+    public void init(final ProcessorContext context, final StateStore root) {
+        this.context = asInternalProcessorContext(context);
+        super.init(context, root);
+    }
+
+    @Override
+    public void init(final StateStoreContext context, final StateStore root) {
+        this.context = asInternalProcessorContext(context);
+        super.init(context, root);
+    }
+
+    @Override
+    public KeyValueIterator<Windowed<Bytes>, byte[]> findSessions(final Bytes key, final long earliestSessionEndTime, final long latestSessionStartTime) {
+        return wrapped().findSessions(key, earliestSessionEndTime, latestSessionStartTime);
+    }
+
+    @Override
+    public KeyValueIterator<Windowed<Bytes>, byte[]> backwardFindSessions(final Bytes key,
+                                                                          final long earliestSessionEndTime,
+                                                                          final long latestSessionStartTime) {
+        return wrapped().backwardFindSessions(key, earliestSessionEndTime, latestSessionStartTime);
+    }
+
+    @Override
+    public KeyValueIterator<Windowed<Bytes>, byte[]> findSessions(final Bytes keyFrom, final Bytes keyTo, final long earliestSessionEndTime, final long latestSessionStartTime) {
+        return wrapped().findSessions(keyFrom, keyTo, earliestSessionEndTime, latestSessionStartTime);
+    }
+
+    @Override
+    public KeyValueIterator<Windowed<Bytes>, byte[]> backwardFindSessions(final Bytes keyFrom, final Bytes keyTo,
+                                                                          final long earliestSessionEndTime,
+                                                                          final long latestSessionStartTime) {
+        return wrapped().backwardFindSessions(keyFrom, keyTo, earliestSessionEndTime, latestSessionStartTime);
+    }
+
+    @Override
+    public void remove(final Windowed<Bytes> sessionKey) {
+        wrapped().remove(sessionKey);
+        context.logChange(name(), SessionKeySchema.toBinary(sessionKey), null, context.timestamp());
+    }
+
+    @Override
+    public void put(final Windowed<Bytes> sessionKey, final byte[] aggregate) {
+        wrapped().put(sessionKey, aggregate);
+        context.logChange(name(), SessionKeySchema.toBinary(sessionKey), aggregate, context.timestamp());
+    }
+
+    @Override
+    public byte[] fetchSession(final Bytes key, final long earliestSessionEndTime, final long latestSessionStartTime) {
+        return wrapped().fetchSession(key, earliestSessionEndTime, latestSessionStartTime);
+    }
+
+    @Override
+    public KeyValueIterator<Windowed<Bytes>, byte[]> backwardFetch(final Bytes key) {
+        return wrapped().backwardFetch(key);
+    }
+
+    @Override
+    public KeyValueIterator<Windowed<Bytes>, byte[]> fetch(final Bytes key) {
+        return wrapped().fetch(key);
+    }
+
+    @Override
+    public KeyValueIterator<Windowed<Bytes>, byte[]> backwardFetch(final Bytes keyFrom, final Bytes keyTo) {
+        return wrapped().backwardFetch(keyFrom, keyTo);
+    }
+
+    @Override
+    public KeyValueIterator<Windowed<Bytes>, byte[]> fetch(final Bytes keyFrom, final
Bytes keyTo) { + return wrapped().fetch(keyFrom, keyTo); + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/ChangeLoggingTimestampedKeyValueBytesStore.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/ChangeLoggingTimestampedKeyValueBytesStore.java new file mode 100644 index 0000000..7cdac97 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/ChangeLoggingTimestampedKeyValueBytesStore.java @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.streams.state.KeyValueStore; + +import static org.apache.kafka.streams.state.internals.ValueAndTimestampDeserializer.rawValue; +import static org.apache.kafka.streams.state.internals.ValueAndTimestampDeserializer.timestamp; + +public class ChangeLoggingTimestampedKeyValueBytesStore extends ChangeLoggingKeyValueBytesStore { + + ChangeLoggingTimestampedKeyValueBytesStore(final KeyValueStore inner) { + super(inner); + } + + @Override + void log(final Bytes key, + final byte[] valueAndTimestamp) { + if (valueAndTimestamp != null) { + context.logChange(name(), key, rawValue(valueAndTimestamp), timestamp(valueAndTimestamp)); + } else { + context.logChange(name(), key, null, context.timestamp()); + } + } +} \ No newline at end of file diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/ChangeLoggingTimestampedWindowBytesStore.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/ChangeLoggingTimestampedWindowBytesStore.java new file mode 100644 index 0000000..8584616 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/ChangeLoggingTimestampedWindowBytesStore.java @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.streams.state.WindowStore; + +import static org.apache.kafka.streams.state.internals.ValueAndTimestampDeserializer.rawValue; +import static org.apache.kafka.streams.state.internals.ValueAndTimestampDeserializer.timestamp; + +class ChangeLoggingTimestampedWindowBytesStore extends ChangeLoggingWindowBytesStore { + + ChangeLoggingTimestampedWindowBytesStore(final WindowStore bytesStore, + final boolean retainDuplicates) { + super(bytesStore, retainDuplicates, WindowKeySchema::toStoreKeyBinary); + } + + @Override + void log(final Bytes key, + final byte[] valueAndTimestamp) { + if (valueAndTimestamp != null) { + context.logChange(name(), key, rawValue(valueAndTimestamp), timestamp(valueAndTimestamp)); + } else { + context.logChange(name(), key, null, context.timestamp()); + } + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/ChangeLoggingWindowBytesStore.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/ChangeLoggingWindowBytesStore.java new file mode 100644 index 0000000..5ce9654 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/ChangeLoggingWindowBytesStore.java @@ -0,0 +1,150 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.streams.kstream.Windowed; +import org.apache.kafka.streams.processor.ProcessorContext; +import org.apache.kafka.streams.processor.StateStore; +import org.apache.kafka.streams.processor.StateStoreContext; +import org.apache.kafka.streams.processor.internals.InternalProcessorContext; +import org.apache.kafka.streams.state.KeyValueIterator; +import org.apache.kafka.streams.state.WindowStore; +import org.apache.kafka.streams.state.WindowStoreIterator; + +import static java.util.Objects.requireNonNull; +import static org.apache.kafka.streams.processor.internals.ProcessorContextUtils.asInternalProcessorContext; + +/** + * Simple wrapper around a {@link WindowStore} to support writing + * updates to a changelog + */ +class ChangeLoggingWindowBytesStore + extends WrappedStateStore, byte[], byte[]> + implements WindowStore { + + interface ChangeLoggingKeySerializer { + Bytes serialize(final Bytes key, final long timestamp, final int seqnum); + } + + private final boolean retainDuplicates; + InternalProcessorContext context; + private int seqnum = 0; + private final ChangeLoggingKeySerializer keySerializer; + + ChangeLoggingWindowBytesStore(final WindowStore bytesStore, + final boolean retainDuplicates, + final ChangeLoggingKeySerializer keySerializer) { + super(bytesStore); + this.retainDuplicates = retainDuplicates; + this.keySerializer = requireNonNull(keySerializer, "keySerializer"); + } + + @Deprecated + @Override + public void init(final ProcessorContext context, + final StateStore root) { + this.context = asInternalProcessorContext(context); + super.init(context, root); + } + + @Override + public void init(final StateStoreContext context, + final StateStore root) { + this.context = asInternalProcessorContext(context); + super.init(context, root); + } + + @Override + public byte[] fetch(final Bytes key, + final long timestamp) { + return wrapped().fetch(key, timestamp); + } + + @Override + public WindowStoreIterator fetch(final Bytes key, + final long from, + final long to) { + return wrapped().fetch(key, from, to); + } + + @Override + public WindowStoreIterator backwardFetch(final Bytes key, + final long timeFrom, + final long timeTo) { + return wrapped().backwardFetch(key, timeFrom, timeTo); + } + + @Override + public KeyValueIterator, byte[]> fetch(final Bytes keyFrom, + final Bytes keyTo, + final long timeFrom, + final long to) { + return wrapped().fetch(keyFrom, keyTo, timeFrom, to); + } + + @Override + public KeyValueIterator, byte[]> backwardFetch(final Bytes keyFrom, + final Bytes keyTo, + final long timeFrom, + final long timeTo) { + return wrapped().backwardFetch(keyFrom, keyTo, timeFrom, timeTo); + } + + @Override + public KeyValueIterator, byte[]> all() { + return wrapped().all(); + } + + + @Override + public KeyValueIterator, byte[]> backwardAll() { + return wrapped().backwardAll(); + } + + @Override + public KeyValueIterator, byte[]> fetchAll(final long timeFrom, + final long timeTo) { + return wrapped().fetchAll(timeFrom, timeTo); + } + + @Override + public KeyValueIterator, byte[]> backwardFetchAll(final long timeFrom, + final long timeTo) { + return wrapped().backwardFetchAll(timeFrom, timeTo); + } + + @Override + public void put(final Bytes key, + final byte[] value, + final long windowStartTimestamp) { + wrapped().put(key, value, windowStartTimestamp); + log(keySerializer.serialize(key, windowStartTimestamp, maybeUpdateSeqnumForDups()), value); + } 
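+    // When retainDuplicates is true (as used for stream-stream join windows), the sequence number
+    // returned by maybeUpdateSeqnumForDups() is baked into the serialized changelog key, so
+    // duplicate values for the same key and window start survive as distinct changelog entries.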
+ + void log(final Bytes key, + final byte[] value) { + context.logChange(name(), key, value, context.timestamp()); + } + + private int maybeUpdateSeqnumForDups() { + if (retainDuplicates) { + seqnum = (seqnum + 1) & 0x7FFFFFFF; + } + return seqnum; + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/CompositeKeyValueIterator.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/CompositeKeyValueIterator.java new file mode 100644 index 0000000..1614f9f --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/CompositeKeyValueIterator.java @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.state.KeyValueIterator; + +import java.util.Iterator; +import java.util.NoSuchElementException; + +class CompositeKeyValueIterator implements KeyValueIterator { + + private final Iterator storeIterator; + private final NextIteratorFunction nextIteratorFunction; + + private KeyValueIterator current; + + CompositeKeyValueIterator(final Iterator underlying, + final NextIteratorFunction nextIteratorFunction) { + this.storeIterator = underlying; + this.nextIteratorFunction = nextIteratorFunction; + } + + @Override + public void close() { + if (current != null) { + current.close(); + current = null; + } + } + + @Override + public K peekNextKey() { + throw new UnsupportedOperationException("peekNextKey not supported"); + } + + @Override + public boolean hasNext() { + while ((current == null || !current.hasNext()) && storeIterator.hasNext()) { + close(); + current = nextIteratorFunction.apply(storeIterator.next()); + } + return current != null && current.hasNext(); + } + + + @Override + public KeyValue next() { + if (!hasNext()) { + throw new NoSuchElementException(); + } + return current.next(); + } + +} diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/CompositeReadOnlyKeyValueStore.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/CompositeReadOnlyKeyValueStore.java new file mode 100644 index 0000000..f7711e3 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/CompositeReadOnlyKeyValueStore.java @@ -0,0 +1,174 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.common.serialization.Serializer; +import org.apache.kafka.streams.errors.InvalidStateStoreException; +import org.apache.kafka.streams.state.KeyValueIterator; +import org.apache.kafka.streams.state.QueryableStoreType; +import org.apache.kafka.streams.state.ReadOnlyKeyValueStore; + +import java.util.List; +import java.util.Objects; + +/** + * A wrapper over the underlying {@link ReadOnlyKeyValueStore}s found in a {@link + * org.apache.kafka.streams.processor.internals.ProcessorTopology} + * + * @param key type + * @param value type + */ +public class CompositeReadOnlyKeyValueStore implements ReadOnlyKeyValueStore { + + private final StateStoreProvider storeProvider; + private final QueryableStoreType> storeType; + private final String storeName; + + public CompositeReadOnlyKeyValueStore(final StateStoreProvider storeProvider, + final QueryableStoreType> storeType, + final String storeName) { + this.storeProvider = storeProvider; + this.storeType = storeType; + this.storeName = storeName; + } + + + @Override + public V get(final K key) { + Objects.requireNonNull(key); + final List> stores = storeProvider.stores(storeName, storeType); + for (final ReadOnlyKeyValueStore store : stores) { + try { + final V result = store.get(key); + if (result != null) { + return result; + } + } catch (final InvalidStateStoreException e) { + throw new InvalidStateStoreException("State store is not available anymore and may have been migrated to another instance; please re-discover its location from the state metadata."); + } + + } + return null; + } + + @Override + public KeyValueIterator range(final K from, final K to) { + final NextIteratorFunction> nextIteratorFunction = new NextIteratorFunction>() { + @Override + public KeyValueIterator apply(final ReadOnlyKeyValueStore store) { + try { + return store.range(from, to); + } catch (final InvalidStateStoreException e) { + throw new InvalidStateStoreException("State store is not available anymore and may have been migrated to another instance; please re-discover its location from the state metadata."); + } + } + }; + final List> stores = storeProvider.stores(storeName, storeType); + return new DelegatingPeekingKeyValueIterator<>( + storeName, + new CompositeKeyValueIterator<>(stores.iterator(), nextIteratorFunction)); + } + + @Override + public KeyValueIterator reverseRange(final K from, final K to) { + final NextIteratorFunction> nextIteratorFunction = new NextIteratorFunction>() { + @Override + public KeyValueIterator apply(final ReadOnlyKeyValueStore store) { + try { + return store.reverseRange(from, to); + } catch (final InvalidStateStoreException e) { + throw new InvalidStateStoreException("State store is not available anymore and may have been migrated to another instance; please re-discover its location from the state metadata."); + } + } + }; + final List> stores = storeProvider.stores(storeName, storeType); + return new DelegatingPeekingKeyValueIterator<>( + storeName, + new CompositeKeyValueIterator<>(stores.iterator(), nextIteratorFunction)); + } + + 
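CompositeReadOnlyKeyValueStore is the read-only view handed out by the interactive-query API (KafkaStreams#store), fanning each call out over every local task's store. A minimal usage sketch follows; the running KafkaStreams instance, the store name "counts-store", and the String/Long key and value types are assumptions for illustration, not part of this patch.

import org.apache.kafka.streams.KafkaStreams;
import org.apache.kafka.streams.KeyValue;
import org.apache.kafka.streams.StoreQueryParameters;
import org.apache.kafka.streams.state.KeyValueIterator;
import org.apache.kafka.streams.state.QueryableStoreTypes;
import org.apache.kafka.streams.state.ReadOnlyKeyValueStore;

public final class InteractiveQuerySketch {

    // Assumes `streams` is already started and "counts-store" is a (hypothetical) key-value store in its topology.
    public static void printAll(final KafkaStreams streams) {
        final ReadOnlyKeyValueStore<String, Long> store = streams.store(
            StoreQueryParameters.fromNameAndType("counts-store", QueryableStoreTypes.<String, Long>keyValueStore()));

        // get() returns the first non-null hit across the underlying local stores.
        System.out.println("alice -> " + store.get("alice"));

        // all() merges the per-store iterators; close it (here via try-with-resources)
        // so each underlying iterator is released.
        try (final KeyValueIterator<String, Long> all = store.all()) {
            while (all.hasNext()) {
                final KeyValue<String, Long> next = all.next();
                System.out.println(next.key + " -> " + next.value);
            }
        }
    }
}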
@Override + public , P> KeyValueIterator prefixScan(final P prefix, final PS prefixKeySerializer) { + Objects.requireNonNull(prefix); + Objects.requireNonNull(prefixKeySerializer); + final NextIteratorFunction> nextIteratorFunction = new NextIteratorFunction>() { + @Override + public KeyValueIterator apply(final ReadOnlyKeyValueStore store) { + try { + return store.prefixScan(prefix, prefixKeySerializer); + } catch (final InvalidStateStoreException e) { + throw new InvalidStateStoreException("State store is not available anymore and may have been migrated to another instance; please re-discover its location from the state metadata."); + } + } + }; + final List> stores = storeProvider.stores(storeName, storeType); + return new DelegatingPeekingKeyValueIterator<>( + storeName, + new CompositeKeyValueIterator<>(stores.iterator(), nextIteratorFunction)); + } + + @Override + public KeyValueIterator all() { + final NextIteratorFunction> nextIteratorFunction = new NextIteratorFunction>() { + @Override + public KeyValueIterator apply(final ReadOnlyKeyValueStore store) { + try { + return store.all(); + } catch (final InvalidStateStoreException e) { + throw new InvalidStateStoreException("State store is not available anymore and may have been migrated to another instance; please re-discover its location from the state metadata."); + } + } + }; + final List> stores = storeProvider.stores(storeName, storeType); + return new DelegatingPeekingKeyValueIterator<>( + storeName, + new CompositeKeyValueIterator<>(stores.iterator(), nextIteratorFunction)); + } + + @Override + public KeyValueIterator reverseAll() { + final NextIteratorFunction> nextIteratorFunction = new NextIteratorFunction>() { + @Override + public KeyValueIterator apply(final ReadOnlyKeyValueStore store) { + try { + return store.reverseAll(); + } catch (final InvalidStateStoreException e) { + throw new InvalidStateStoreException("State store is not available anymore and may have been migrated to another instance; please re-discover its location from the state metadata."); + } + } + }; + final List> stores = storeProvider.stores(storeName, storeType); + return new DelegatingPeekingKeyValueIterator<>( + storeName, + new CompositeKeyValueIterator<>(stores.iterator(), nextIteratorFunction)); + } + + @Override + public long approximateNumEntries() { + final List> stores = storeProvider.stores(storeName, storeType); + long total = 0; + for (final ReadOnlyKeyValueStore store : stores) { + total += store.approximateNumEntries(); + if (total < 0) { + return Long.MAX_VALUE; + } + } + return total; + } + +} + diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/CompositeReadOnlySessionStore.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/CompositeReadOnlySessionStore.java new file mode 100644 index 0000000..0f153cc --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/CompositeReadOnlySessionStore.java @@ -0,0 +1,240 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.streams.errors.InvalidStateStoreException; +import org.apache.kafka.streams.kstream.Windowed; +import org.apache.kafka.streams.state.KeyValueIterator; +import org.apache.kafka.streams.state.QueryableStoreType; +import org.apache.kafka.streams.state.ReadOnlySessionStore; + +import java.util.List; +import java.util.Objects; + +/** + * Wrapper over the underlying {@link ReadOnlySessionStore}s found in a {@link + * org.apache.kafka.streams.processor.internals.ProcessorTopology} + */ +public class CompositeReadOnlySessionStore implements ReadOnlySessionStore { + private final StateStoreProvider storeProvider; + private final QueryableStoreType> queryableStoreType; + private final String storeName; + + public CompositeReadOnlySessionStore(final StateStoreProvider storeProvider, + final QueryableStoreType> queryableStoreType, + final String storeName) { + this.storeProvider = storeProvider; + this.queryableStoreType = queryableStoreType; + this.storeName = storeName; + } + + @Override + public KeyValueIterator, V> findSessions(final K key, + final long earliestSessionEndTime, + final long latestSessionStartTime) { + Objects.requireNonNull(key, "key can't be null"); + final List> stores = storeProvider.stores(storeName, queryableStoreType); + for (final ReadOnlySessionStore store : stores) { + try { + final KeyValueIterator, V> result = + store.findSessions(key, earliestSessionEndTime, latestSessionStartTime); + + if (!result.hasNext()) { + result.close(); + } else { + return result; + } + } catch (final InvalidStateStoreException ise) { + throw new InvalidStateStoreException( + "State store [" + storeName + "] is not available anymore" + + " and may have been migrated to another instance; " + + "please re-discover its location from the state metadata.", + ise + ); + } + } + return KeyValueIterators.emptyIterator(); + } + + @Override + public KeyValueIterator, V> backwardFindSessions(final K key, + final long earliestSessionEndTime, + final long latestSessionStartTime) { + Objects.requireNonNull(key, "key can't be null"); + final List> stores = storeProvider.stores(storeName, queryableStoreType); + for (final ReadOnlySessionStore store : stores) { + try { + final KeyValueIterator, V> result = store.backwardFindSessions(key, earliestSessionEndTime, latestSessionStartTime); + if (!result.hasNext()) { + result.close(); + } else { + return result; + } + } catch (final InvalidStateStoreException ise) { + throw new InvalidStateStoreException( + "State store [" + storeName + "] is not available anymore" + + " and may have been migrated to another instance; " + + "please re-discover its location from the state metadata.", + ise + ); + } + } + return KeyValueIterators.emptyIterator(); + } + + @Override + public KeyValueIterator, V> findSessions(final K keyFrom, + final K keyTo, + final long earliestSessionEndTime, + final long latestSessionStartTime) { + final List> stores = storeProvider.stores(storeName, queryableStoreType); + for (final ReadOnlySessionStore store : stores) { + try { + final KeyValueIterator, 
V> result = + store.findSessions(keyFrom, keyTo, earliestSessionEndTime, latestSessionStartTime); + if (!result.hasNext()) { + result.close(); + } else { + return result; + } + } catch (final InvalidStateStoreException ise) { + throw new InvalidStateStoreException( + "State store [" + storeName + "] is not available anymore" + + " and may have been migrated to another instance; " + + "please re-discover its location from the state metadata.", + ise + ); + } + } + return KeyValueIterators.emptyIterator(); + } + + @Override + public KeyValueIterator, V> backwardFindSessions(final K keyFrom, + final K keyTo, + final long earliestSessionEndTime, + final long latestSessionStartTime) { + final List> stores = storeProvider.stores(storeName, queryableStoreType); + for (final ReadOnlySessionStore store : stores) { + try { + final KeyValueIterator, V> result = + store.backwardFindSessions(keyFrom, keyTo, earliestSessionEndTime, latestSessionStartTime); + if (!result.hasNext()) { + result.close(); + } else { + return result; + } + } catch (final InvalidStateStoreException ise) { + throw new InvalidStateStoreException( + "State store [" + storeName + "] is not available anymore" + + " and may have been migrated to another instance; " + + "please re-discover its location from the state metadata.", + ise + ); + } + } + return KeyValueIterators.emptyIterator(); + } + + @Override + public V fetchSession(final K key, final long earliestSessionEndTime, final long latestSessionStartTime) { + Objects.requireNonNull(key, "key can't be null"); + final List> stores = storeProvider.stores(storeName, queryableStoreType); + for (final ReadOnlySessionStore store : stores) { + try { + return store.fetchSession(key, earliestSessionEndTime, latestSessionStartTime); + } catch (final InvalidStateStoreException ise) { + throw new InvalidStateStoreException( + "State store [" + storeName + "] is not available anymore" + + " and may have been migrated to another instance; " + + "please re-discover its location from the state metadata.", + ise + ); + } + } + return null; + } + + @Override + public KeyValueIterator, V> fetch(final K key) { + Objects.requireNonNull(key, "key can't be null"); + final List> stores = storeProvider.stores(storeName, queryableStoreType); + for (final ReadOnlySessionStore store : stores) { + try { + final KeyValueIterator, V> result = store.fetch(key); + if (!result.hasNext()) { + result.close(); + } else { + return result; + } + } catch (final InvalidStateStoreException ise) { + throw new InvalidStateStoreException("State store [" + storeName + "] is not available anymore" + + " and may have been migrated to another instance; " + + "please re-discover its location from the state metadata. 
" + + "Original error message: " + ise.toString()); + } + } + return KeyValueIterators.emptyIterator(); + } + + @Override + public KeyValueIterator, V> backwardFetch(final K key) { + Objects.requireNonNull(key, "key can't be null"); + final List> stores = storeProvider.stores(storeName, queryableStoreType); + for (final ReadOnlySessionStore store : stores) { + try { + final KeyValueIterator, V> result = store.backwardFetch(key); + if (!result.hasNext()) { + result.close(); + } else { + return result; + } + } catch (final InvalidStateStoreException ise) { + throw new InvalidStateStoreException( + "State store [" + storeName + "] is not available anymore" + + " and may have been migrated to another instance; " + + "please re-discover its location from the state metadata.", + ise + ); + } + } + return KeyValueIterators.emptyIterator(); + } + + @Override + public KeyValueIterator, V> fetch(final K keyFrom, final K keyTo) { + final NextIteratorFunction, V, ReadOnlySessionStore> nextIteratorFunction = + store -> store.fetch(keyFrom, keyTo); + return new DelegatingPeekingKeyValueIterator<>(storeName, + new CompositeKeyValueIterator<>( + storeProvider.stores(storeName, queryableStoreType).iterator(), + nextIteratorFunction)); + } + + @Override + public KeyValueIterator, V> backwardFetch(final K keyFrom, final K keyTo) { + final NextIteratorFunction, V, ReadOnlySessionStore> nextIteratorFunction = + store -> store.backwardFetch(keyFrom, keyTo); + return new DelegatingPeekingKeyValueIterator<>( + storeName, + new CompositeKeyValueIterator<>( + storeProvider.stores(storeName, queryableStoreType).iterator(), + nextIteratorFunction + ) + ); + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/CompositeReadOnlyWindowStore.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/CompositeReadOnlyWindowStore.java new file mode 100644 index 0000000..c6b0b60 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/CompositeReadOnlyWindowStore.java @@ -0,0 +1,186 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.streams.errors.InvalidStateStoreException; +import org.apache.kafka.streams.kstream.Windowed; +import org.apache.kafka.streams.state.KeyValueIterator; +import org.apache.kafka.streams.state.QueryableStoreType; +import org.apache.kafka.streams.state.ReadOnlyWindowStore; +import org.apache.kafka.streams.state.WindowStoreIterator; + +import java.time.Instant; +import java.util.List; +import java.util.Objects; + +/** + * Wrapper over the underlying {@link ReadOnlyWindowStore}s found in a {@link + * org.apache.kafka.streams.processor.internals.ProcessorTopology} + */ +public class CompositeReadOnlyWindowStore implements ReadOnlyWindowStore { + + private final QueryableStoreType> windowStoreType; + private final String storeName; + private final StateStoreProvider provider; + + public CompositeReadOnlyWindowStore(final StateStoreProvider provider, + final QueryableStoreType> windowStoreType, + final String storeName) { + this.provider = provider; + this.windowStoreType = windowStoreType; + this.storeName = storeName; + } + + @Override + public V fetch(final K key, final long time) { + Objects.requireNonNull(key, "key can't be null"); + final List> stores = provider.stores(storeName, windowStoreType); + for (final ReadOnlyWindowStore windowStore : stores) { + try { + final V result = windowStore.fetch(key, time); + if (result != null) { + return result; + } + } catch (final InvalidStateStoreException e) { + throw new InvalidStateStoreException( + "State store is not available anymore and may have been migrated to another instance; " + + "please re-discover its location from the state metadata."); + } + } + return null; + } + + @Override + public WindowStoreIterator fetch(final K key, + final Instant timeFrom, + final Instant timeTo) { + Objects.requireNonNull(key, "key can't be null"); + final List> stores = provider.stores(storeName, windowStoreType); + for (final ReadOnlyWindowStore windowStore : stores) { + try { + final WindowStoreIterator result = windowStore.fetch(key, timeFrom, timeTo); + if (!result.hasNext()) { + result.close(); + } else { + return result; + } + } catch (final InvalidStateStoreException e) { + throw new InvalidStateStoreException( + "State store is not available anymore and may have been migrated to another instance; " + + "please re-discover its location from the state metadata."); + } + } + return KeyValueIterators.emptyWindowStoreIterator(); + } + + @Override + public WindowStoreIterator backwardFetch(final K key, + final Instant timeFrom, + final Instant timeTo) throws IllegalArgumentException { + Objects.requireNonNull(key, "key can't be null"); + final List> stores = provider.stores(storeName, windowStoreType); + for (final ReadOnlyWindowStore windowStore : stores) { + try { + final WindowStoreIterator result = windowStore.backwardFetch(key, timeFrom, timeTo); + if (!result.hasNext()) { + result.close(); + } else { + return result; + } + } catch (final InvalidStateStoreException e) { + throw new InvalidStateStoreException( + "State store is not available anymore and may have been migrated to another instance; " + + "please re-discover its location from the state metadata."); + } + } + return KeyValueIterators.emptyWindowStoreIterator(); + } + + @Override + public KeyValueIterator, V> fetch(final K keyFrom, + final K keyTo, + final Instant timeFrom, + final Instant timeTo) { + final NextIteratorFunction, V, ReadOnlyWindowStore> nextIteratorFunction = + store -> store.fetch(keyFrom, 
keyTo, timeFrom, timeTo); + return new DelegatingPeekingKeyValueIterator<>( + storeName, + new CompositeKeyValueIterator<>( + provider.stores(storeName, windowStoreType).iterator(), + nextIteratorFunction)); + } + + @Override + public KeyValueIterator, V> backwardFetch(final K keyFrom, + final K keyTo, + final Instant timeFrom, + final Instant timeTo) throws IllegalArgumentException { + final NextIteratorFunction, V, ReadOnlyWindowStore> nextIteratorFunction = + store -> store.backwardFetch(keyFrom, keyTo, timeFrom, timeTo); + return new DelegatingPeekingKeyValueIterator<>( + storeName, + new CompositeKeyValueIterator<>( + provider.stores(storeName, windowStoreType).iterator(), + nextIteratorFunction)); + } + + @Override + public KeyValueIterator, V> all() { + final NextIteratorFunction, V, ReadOnlyWindowStore> nextIteratorFunction = + ReadOnlyWindowStore::all; + return new DelegatingPeekingKeyValueIterator<>( + storeName, + new CompositeKeyValueIterator<>( + provider.stores(storeName, windowStoreType).iterator(), + nextIteratorFunction)); + } + + @Override + public KeyValueIterator, V> backwardAll() { + final NextIteratorFunction, V, ReadOnlyWindowStore> nextIteratorFunction = + ReadOnlyWindowStore::backwardAll; + return new DelegatingPeekingKeyValueIterator<>( + storeName, + new CompositeKeyValueIterator<>( + provider.stores(storeName, windowStoreType).iterator(), + nextIteratorFunction)); + } + + @Override + public KeyValueIterator, V> fetchAll(final Instant timeFrom, + final Instant timeTo) { + final NextIteratorFunction, V, ReadOnlyWindowStore> nextIteratorFunction = + store -> store.fetchAll(timeFrom, timeTo); + return new DelegatingPeekingKeyValueIterator<>( + storeName, + new CompositeKeyValueIterator<>( + provider.stores(storeName, windowStoreType).iterator(), + nextIteratorFunction)); + } + + @Override + public KeyValueIterator, V> backwardFetchAll(final Instant timeFrom, + final Instant timeTo) throws IllegalArgumentException { + final NextIteratorFunction, V, ReadOnlyWindowStore> nextIteratorFunction = + store -> store.backwardFetchAll(timeFrom, timeTo); + return new DelegatingPeekingKeyValueIterator<>( + storeName, + new CompositeKeyValueIterator<>( + provider.stores(storeName, windowStoreType).iterator(), + nextIteratorFunction)); + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/ContextualRecord.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/ContextualRecord.java new file mode 100644 index 0000000..a26b437 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/ContextualRecord.java @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.streams.processor.internals.ProcessorRecordContext; + +import java.nio.ByteBuffer; +import java.util.Arrays; +import java.util.Objects; + +import static org.apache.kafka.common.utils.Utils.getNullableSizePrefixedArray; + +public class ContextualRecord { + private final byte[] value; + private final ProcessorRecordContext recordContext; + + public ContextualRecord(final byte[] value, final ProcessorRecordContext recordContext) { + this.value = value; + this.recordContext = Objects.requireNonNull(recordContext); + } + + public ProcessorRecordContext recordContext() { + return recordContext; + } + + public byte[] value() { + return value; + } + + long residentMemorySizeEstimate() { + return (value == null ? 0 : value.length) + recordContext.residentMemorySizeEstimate(); + } + + static ContextualRecord deserialize(final ByteBuffer buffer) { + final ProcessorRecordContext context = ProcessorRecordContext.deserialize(buffer); + final byte[] value = getNullableSizePrefixedArray(buffer); + return new ContextualRecord(value, context); + } + + @Override + public boolean equals(final Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + final ContextualRecord that = (ContextualRecord) o; + return Arrays.equals(value, that.value) && + Objects.equals(recordContext, that.recordContext); + } + + @Override + public int hashCode() { + return Objects.hash(value, recordContext); + } + + @Override + public String toString() { + return "ContextualRecord{" + + "recordContext=" + recordContext + + ", value=" + Arrays.toString(value) + + '}'; + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/DelegatingPeekingKeyValueIterator.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/DelegatingPeekingKeyValueIterator.java new file mode 100644 index 0000000..245c9e8 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/DelegatingPeekingKeyValueIterator.java @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.errors.InvalidStateStoreException; +import org.apache.kafka.streams.state.KeyValueIterator; + +import java.util.NoSuchElementException; + +/** + * Optimized {@link KeyValueIterator} used when the same element could be peeked multiple times. 
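A minimal sketch of the peek-versus-next contract this wrapper adds; the list-backed underlying iterator and all names below are illustrative and not part of the patch:

    import java.util.Iterator;
    import java.util.List;

    import org.apache.kafka.streams.KeyValue;
    import org.apache.kafka.streams.state.KeyValueIterator;
    import org.apache.kafka.streams.state.internals.DelegatingPeekingKeyValueIterator;

    public class PeekingIteratorSketch {

        // Toy underlying iterator over a fixed list; peekNextKey is unsupported, mirroring the
        // in-memory iterators in this package, which is exactly what the delegating wrapper adds.
        static final class FixedIterator implements KeyValueIterator<String, String> {
            private final Iterator<KeyValue<String, String>> inner;

            FixedIterator(final List<KeyValue<String, String>> entries) {
                this.inner = entries.iterator();
            }

            @Override public boolean hasNext() { return inner.hasNext(); }
            @Override public KeyValue<String, String> next() { return inner.next(); }
            @Override public String peekNextKey() { throw new UnsupportedOperationException(); }
            @Override public void close() { }
        }

        public static void main(final String[] args) {
            final KeyValueIterator<String, String> iterator = new DelegatingPeekingKeyValueIterator<>(
                "toy-store", // illustrative store name
                new FixedIterator(List.of(KeyValue.pair("a", "1"), KeyValue.pair("b", "2"))));

            while (iterator.hasNext()) {
                final String upcoming = iterator.peekNextKey();      // look ahead without consuming
                final KeyValue<String, String> kv = iterator.next(); // consumes the peeked entry
                System.out.println(upcoming + " -> " + kv.value);
            }
            iterator.close(); // later hasNext() calls would now throw InvalidStateStoreException
        }
    }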
+ */
+public class DelegatingPeekingKeyValueIterator<K, V> implements KeyValueIterator<K, V>, PeekingKeyValueIterator<K, V> {
+    private final KeyValueIterator<K, V> underlying;
+    private final String storeName;
+    private KeyValue<K, V> next;
+
+    private volatile boolean open = true;
+
+    public DelegatingPeekingKeyValueIterator(final String storeName, final KeyValueIterator<K, V> underlying) {
+        this.storeName = storeName;
+        this.underlying = underlying;
+    }
+
+    @Override
+    public synchronized K peekNextKey() {
+        if (!hasNext()) {
+            throw new NoSuchElementException();
+        }
+        return next.key;
+    }
+
+    @Override
+    public synchronized void close() {
+        underlying.close();
+        open = false;
+    }
+
+    @Override
+    public synchronized boolean hasNext() {
+        if (!open) {
+            throw new InvalidStateStoreException(String.format("Store %s has closed", storeName));
+        }
+        if (next != null) {
+            return true;
+        }
+
+        if (!underlying.hasNext()) {
+            return false;
+        }
+
+        next = underlying.next();
+        return true;
+    }
+
+    @Override
+    public synchronized KeyValue<K, V> next() {
+        if (!hasNext()) {
+            throw new NoSuchElementException();
+        }
+        final KeyValue<K, V> result = next;
+        next = null;
+        return result;
+    }
+
+    @Override
+    public KeyValue<K, V> peekNext() {
+        if (!hasNext()) {
+            throw new NoSuchElementException();
+        }
+        return next;
+    }
+}
diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/ExceptionUtils.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/ExceptionUtils.java
new file mode 100644
index 0000000..e40b6ad
--- /dev/null
+++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/ExceptionUtils.java
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.streams.state.internals;
+
+import java.util.LinkedList;
+
+final class ExceptionUtils {
+    private ExceptionUtils() {}
+
+    static LinkedList<RuntimeException> executeAll(final Runnable...
actions) { + final LinkedList suppressed = new LinkedList<>(); + for (final Runnable action : actions) { + try { + action.run(); + } catch (final RuntimeException exception) { + suppressed.add(exception); + } + } + return suppressed; + } + + static void throwSuppressed(final String message, final LinkedList suppressed) { + if (!suppressed.isEmpty()) { + final RuntimeException firstCause = suppressed.pollFirst(); + final RuntimeException toThrow = new RuntimeException(message, firstCause); + for (final RuntimeException e : suppressed) { + toThrow.addSuppressed(e); + } + throw toThrow; + } + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/FilteredCacheIterator.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/FilteredCacheIterator.java new file mode 100644 index 0000000..9b1bfd8 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/FilteredCacheIterator.java @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.streams.KeyValue; + +import java.util.NoSuchElementException; + +class FilteredCacheIterator implements PeekingKeyValueIterator { + private final PeekingKeyValueIterator cacheIterator; + private final HasNextCondition hasNextCondition; + private final PeekingKeyValueIterator wrappedIterator; + + FilteredCacheIterator(final PeekingKeyValueIterator cacheIterator, + final HasNextCondition hasNextCondition, + final CacheFunction cacheFunction) { + this.cacheIterator = cacheIterator; + this.hasNextCondition = hasNextCondition; + this.wrappedIterator = new PeekingKeyValueIterator() { + @Override + public KeyValue peekNext() { + return cachedPair(cacheIterator.peekNext()); + } + + @Override + public void close() { + cacheIterator.close(); + } + + @Override + public Bytes peekNextKey() { + return cacheFunction.key(cacheIterator.peekNextKey()); + } + + @Override + public boolean hasNext() { + return cacheIterator.hasNext(); + } + + @Override + public KeyValue next() { + return cachedPair(cacheIterator.next()); + } + + private KeyValue cachedPair(final KeyValue next) { + return KeyValue.pair(cacheFunction.key(next.key), next.value); + } + + }; + } + + @Override + public void close() { + // no-op + } + + @Override + public Bytes peekNextKey() { + if (!hasNext()) { + throw new NoSuchElementException(); + } + return cacheIterator.peekNextKey(); + } + + @Override + public boolean hasNext() { + return hasNextCondition.hasNext(wrappedIterator); + } + + @Override + public KeyValue next() { + if (!hasNext()) { + throw new NoSuchElementException(); + } + return cacheIterator.next(); + + } + + @Override + public KeyValue peekNext() { + if 
(!hasNext()) { + throw new NoSuchElementException(); + } + return cacheIterator.peekNext(); + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/GlobalStateStoreProvider.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/GlobalStateStoreProvider.java new file mode 100644 index 0000000..057a836 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/GlobalStateStoreProvider.java @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.streams.errors.InvalidStateStoreException; +import org.apache.kafka.streams.processor.StateStore; +import org.apache.kafka.streams.state.QueryableStoreType; +import org.apache.kafka.streams.state.QueryableStoreTypes; +import org.apache.kafka.streams.state.TimestampedKeyValueStore; +import org.apache.kafka.streams.state.TimestampedWindowStore; + +import java.util.Collections; +import java.util.List; +import java.util.Map; + +public class GlobalStateStoreProvider implements StateStoreProvider { + private final Map globalStateStores; + + public GlobalStateStoreProvider(final Map globalStateStores) { + this.globalStateStores = globalStateStores; + } + + @SuppressWarnings("unchecked") + @Override + public List stores(final String storeName, final QueryableStoreType queryableStoreType) { + final StateStore store = globalStateStores.get(storeName); + if (store == null || !queryableStoreType.accepts(store)) { + return Collections.emptyList(); + } + if (!store.isOpen()) { + throw new InvalidStateStoreException("the state store, " + storeName + ", is not open."); + } + if (store instanceof TimestampedKeyValueStore && queryableStoreType instanceof QueryableStoreTypes.KeyValueStoreType) { + return (List) Collections.singletonList(new ReadOnlyKeyValueStoreFacade((TimestampedKeyValueStore) store)); + } else if (store instanceof TimestampedWindowStore && queryableStoreType instanceof QueryableStoreTypes.WindowStoreType) { + return (List) Collections.singletonList(new ReadOnlyWindowStoreFacade((TimestampedWindowStore) store)); + } + return (List) Collections.singletonList(store); + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/HasNextCondition.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/HasNextCondition.java new file mode 100644 index 0000000..8784dba --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/HasNextCondition.java @@ -0,0 +1,24 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.streams.state.KeyValueIterator; + +interface HasNextCondition { + boolean hasNext(final KeyValueIterator iterator); +} diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/InMemoryKeyValueStore.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/InMemoryKeyValueStore.java new file mode 100644 index 0000000..f0c6dbe --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/InMemoryKeyValueStore.java @@ -0,0 +1,219 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
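HasNextCondition is the hook the cache and store iterators use to decide when a range is exhausted. A hedged sketch of one such condition expressed as a lambda; the helper class and bound are illustrative and assume same-package placement, since the interface is package-private:

    package org.apache.kafka.streams.state.internals;

    import org.apache.kafka.common.utils.Bytes;

    // Illustrative helper, placed in the same package because HasNextCondition is package-private.
    final class UpperBoundedConditions {

        // Keep iterating only while there is a next entry and its key is strictly below the bound.
        static HasNextCondition below(final Bytes upperBoundExclusive) {
            return iterator -> iterator.hasNext() && iterator.peekNextKey().compareTo(upperBoundExclusive) < 0;
        }
    }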
+ */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.common.serialization.Serializer; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.processor.ProcessorContext; +import org.apache.kafka.streams.processor.StateStore; +import org.apache.kafka.streams.state.KeyValueIterator; +import org.apache.kafka.streams.state.KeyValueStore; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Iterator; +import java.util.List; +import java.util.NavigableMap; +import java.util.Set; +import java.util.TreeMap; +import java.util.TreeSet; + +public class InMemoryKeyValueStore implements KeyValueStore { + + private static final Logger LOG = LoggerFactory.getLogger(InMemoryKeyValueStore.class); + + private final String name; + private final NavigableMap map = new TreeMap<>(); + private volatile boolean open = false; + + public InMemoryKeyValueStore(final String name) { + this.name = name; + } + + @Override + public String name() { + return name; + } + + @Deprecated + @Override + public void init(final ProcessorContext context, + final StateStore root) { + if (root != null) { + // register the store + context.register(root, (key, value) -> put(Bytes.wrap(key), value)); + } + + open = true; + } + + @Override + public boolean persistent() { + return false; + } + + @Override + public boolean isOpen() { + return open; + } + + @Override + public synchronized byte[] get(final Bytes key) { + return map.get(key); + } + + @Override + public synchronized void put(final Bytes key, final byte[] value) { + putInternal(key, value); + } + + @Override + public synchronized byte[] putIfAbsent(final Bytes key, final byte[] value) { + final byte[] originalValue = get(key); + if (originalValue == null) { + put(key, value); + } + return originalValue; + } + + // the unlocked implementation of put method, to avoid multiple lock/unlock cost in `putAll` method + private void putInternal(final Bytes key, final byte[] value) { + if (value == null) { + map.remove(key); + } else { + map.put(key, value); + } + } + + @Override + public synchronized void putAll(final List> entries) { + for (final KeyValue entry : entries) { + putInternal(entry.key, entry.value); + } + } + + @Override + public , P> KeyValueIterator prefixScan(final P prefix, final PS prefixKeySerializer) { + + final Bytes from = Bytes.wrap(prefixKeySerializer.serialize(null, prefix)); + final Bytes to = Bytes.increment(from); + + return new DelegatingPeekingKeyValueIterator<>( + name, + new InMemoryKeyValueIterator(map.subMap(from, true, to, false).keySet(), true) + ); + } + + @Override + public synchronized byte[] delete(final Bytes key) { + return map.remove(key); + } + + @Override + public synchronized KeyValueIterator range(final Bytes from, final Bytes to) { + return range(from, to, true); + } + + @Override + public synchronized KeyValueIterator reverseRange(final Bytes from, final Bytes to) { + return range(from, to, false); + } + + private KeyValueIterator range(final Bytes from, final Bytes to, final boolean forward) { + if (from == null && to == null) { + return getKeyValueIterator(map.keySet(), forward); + } else if (from == null) { + return getKeyValueIterator(map.headMap(to, true).keySet(), forward); + } else if (to == null) { + return getKeyValueIterator(map.tailMap(from, true).keySet(), forward); + } else if (from.compareTo(to) > 0) { + LOG.warn("Returning empty iterator for fetch with invalid key range: from > to. 
" + + "This may be due to range arguments set in the wrong order, " + + "or serdes that don't preserve ordering when lexicographically comparing the serialized bytes. " + + "Note that the built-in numerical serdes do not follow this for negative numbers"); + return KeyValueIterators.emptyIterator(); + } else { + return getKeyValueIterator(map.subMap(from, true, to, true).keySet(), forward); + } + } + + private KeyValueIterator getKeyValueIterator(final Set rangeSet, final boolean forward) { + return new DelegatingPeekingKeyValueIterator<>(name, new InMemoryKeyValueIterator(rangeSet, forward)); + } + + @Override + public synchronized KeyValueIterator all() { + return range(null, null); + } + + @Override + public synchronized KeyValueIterator reverseAll() { + return new DelegatingPeekingKeyValueIterator<>( + name, + new InMemoryKeyValueIterator(map.keySet(), false)); + } + + @Override + public long approximateNumEntries() { + return map.size(); + } + + @Override + public void flush() { + // do-nothing since it is in-memory + } + + @Override + public void close() { + map.clear(); + open = false; + } + + private class InMemoryKeyValueIterator implements KeyValueIterator { + private final Iterator iter; + + private InMemoryKeyValueIterator(final Set keySet, final boolean forward) { + if (forward) { + this.iter = new TreeSet<>(keySet).iterator(); + } else { + this.iter = new TreeSet<>(keySet).descendingIterator(); + } + } + + @Override + public boolean hasNext() { + return iter.hasNext(); + } + + @Override + public KeyValue next() { + final Bytes key = iter.next(); + return new KeyValue<>(key, map.get(key)); + } + + @Override + public void close() { + // do nothing + } + + @Override + public Bytes peekNextKey() { + throw new UnsupportedOperationException("peekNextKey() not supported in " + getClass().getName()); + } + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/InMemorySessionBytesStoreSupplier.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/InMemorySessionBytesStoreSupplier.java new file mode 100644 index 0000000..fa4eddc --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/InMemorySessionBytesStoreSupplier.java @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.streams.state.SessionBytesStoreSupplier; +import org.apache.kafka.streams.state.SessionStore; + +public class InMemorySessionBytesStoreSupplier implements SessionBytesStoreSupplier { + private final String name; + private final long retentionPeriod; + + public InMemorySessionBytesStoreSupplier(final String name, + final long retentionPeriod) { + this.name = name; + this.retentionPeriod = retentionPeriod; + } + + @Override + public String name() { + return name; + } + + @Override + public SessionStore get() { + return new InMemorySessionStore(name, retentionPeriod, metricsScope()); + } + + @Override + public String metricsScope() { + return "in-memory-session"; + } + + // In-memory store is not *really* segmented, so just say it is 1 (for ordering consistency with caching enabled) + @Override + public long segmentIntervalMs() { + return 1; + } + + @Override + public long retentionPeriod() { + return retentionPeriod; + } +} + diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/InMemorySessionStore.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/InMemorySessionStore.java new file mode 100644 index 0000000..bc8cda6 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/InMemorySessionStore.java @@ -0,0 +1,515 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
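A short usage sketch for the supplier above (store name and retention are illustrative; application code would normally reach it through the public Stores/Materialized APIs rather than construct it directly):

    import java.time.Duration;

    import org.apache.kafka.common.utils.Bytes;
    import org.apache.kafka.streams.state.SessionBytesStoreSupplier;
    import org.apache.kafka.streams.state.SessionStore;
    import org.apache.kafka.streams.state.internals.InMemorySessionBytesStoreSupplier;

    public class SessionSupplierSketch {
        public static void main(final String[] args) {
            final SessionBytesStoreSupplier supplier =
                new InMemorySessionBytesStoreSupplier("user-sessions", Duration.ofHours(1).toMillis());

            final SessionStore<Bytes, byte[]> store = supplier.get();

            System.out.println(supplier.name());              // user-sessions
            System.out.println(supplier.retentionPeriod());   // 3600000
            System.out.println(supplier.segmentIntervalMs()); // 1: the in-memory store is not really segmented
            store.close();
        }
    }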
+ */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.common.metrics.Sensor; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.kstream.Windowed; +import org.apache.kafka.streams.kstream.internals.SessionWindow; +import org.apache.kafka.streams.processor.ProcessorContext; +import org.apache.kafka.streams.processor.StateStore; +import org.apache.kafka.streams.processor.internals.InternalProcessorContext; +import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl; +import org.apache.kafka.streams.processor.internals.metrics.TaskMetrics; +import org.apache.kafka.streams.state.KeyValueIterator; +import org.apache.kafka.streams.state.SessionStore; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Iterator; +import java.util.Map; +import java.util.Map.Entry; +import java.util.NoSuchElementException; +import java.util.Objects; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentNavigableMap; +import java.util.concurrent.ConcurrentSkipListMap; + +public class InMemorySessionStore implements SessionStore { + + private static final Logger LOG = LoggerFactory.getLogger(InMemorySessionStore.class); + + private final String name; + private final String metricScope; + private Sensor expiredRecordSensor; + private InternalProcessorContext context; + private long observedStreamTime = ConsumerRecord.NO_TIMESTAMP; + + private final long retentionPeriod; + + private final static String INVALID_RANGE_WARN_MSG = + "Returning empty iterator for fetch with invalid key range: from > to. " + + "This may be due to range arguments set in the wrong order, " + + "or serdes that don't preserve ordering when lexicographically comparing the serialized bytes. " + + "Note that the built-in numerical serdes do not follow this for negative numbers"; + + private final ConcurrentNavigableMap>> endTimeMap = new ConcurrentSkipListMap<>(); + private final Set openIterators = ConcurrentHashMap.newKeySet(); + + private volatile boolean open = false; + + InMemorySessionStore(final String name, + final long retentionPeriod, + final String metricScope) { + this.name = name; + this.retentionPeriod = retentionPeriod; + this.metricScope = metricScope; + } + + @Override + public String name() { + return name; + } + + @Deprecated + @Override + public void init(final ProcessorContext context, final StateStore root) { + final String threadId = Thread.currentThread().getName(); + final String taskName = context.taskId().toString(); + + // The provided context is not required to implement InternalProcessorContext, + // If it doesn't, we can't record this metric. 
+ if (context instanceof InternalProcessorContext) { + this.context = (InternalProcessorContext) context; + final StreamsMetricsImpl metrics = this.context.metrics(); + expiredRecordSensor = TaskMetrics.droppedRecordsSensor( + threadId, + taskName, + metrics + ); + } else { + this.context = null; + expiredRecordSensor = null; + } + + if (root != null) { + context.register(root, (key, value) -> put(SessionKeySchema.from(Bytes.wrap(key)), value)); + } + open = true; + } + + @Override + public void put(final Windowed sessionKey, final byte[] aggregate) { + removeExpiredSegments(); + + final long windowEndTimestamp = sessionKey.window().end(); + observedStreamTime = Math.max(observedStreamTime, windowEndTimestamp); + + if (windowEndTimestamp <= observedStreamTime - retentionPeriod) { + // The provided context is not required to implement InternalProcessorContext, + // If it doesn't, we can't record this metric (in fact, we wouldn't have even initialized it). + if (expiredRecordSensor != null && context != null) { + expiredRecordSensor.record(1.0d, context.currentSystemTimeMs()); + } + LOG.warn("Skipping record for expired segment."); + } else { + if (aggregate != null) { + endTimeMap.computeIfAbsent(windowEndTimestamp, t -> new ConcurrentSkipListMap<>()); + final ConcurrentNavigableMap> keyMap = endTimeMap.get(windowEndTimestamp); + keyMap.computeIfAbsent(sessionKey.key(), t -> new ConcurrentSkipListMap<>()); + keyMap.get(sessionKey.key()).put(sessionKey.window().start(), aggregate); + } else { + remove(sessionKey); + } + } + } + + @Override + public void remove(final Windowed sessionKey) { + final ConcurrentNavigableMap> keyMap = endTimeMap.get(sessionKey.window().end()); + if (keyMap == null) { + return; + } + + final ConcurrentNavigableMap startTimeMap = keyMap.get(sessionKey.key()); + if (startTimeMap == null) { + return; + } + + startTimeMap.remove(sessionKey.window().start()); + + if (startTimeMap.isEmpty()) { + keyMap.remove(sessionKey.key()); + if (keyMap.isEmpty()) { + endTimeMap.remove(sessionKey.window().end()); + } + } + } + + @Override + public byte[] fetchSession(final Bytes key, + final long earliestSessionEndTime, + final long latestSessionStartTime) { + removeExpiredSegments(); + + Objects.requireNonNull(key, "key cannot be null"); + + // Only need to search if the record hasn't expired yet + if (latestSessionStartTime > observedStreamTime - retentionPeriod) { + final ConcurrentNavigableMap> keyMap = endTimeMap.get(latestSessionStartTime); + if (keyMap != null) { + final ConcurrentNavigableMap startTimeMap = keyMap.get(key); + if (startTimeMap != null) { + return startTimeMap.get(earliestSessionEndTime); + } + } + } + return null; + } + + @Override + public KeyValueIterator, byte[]> findSessions(final Bytes key, + final long earliestSessionEndTime, + final long latestSessionStartTime) { + Objects.requireNonNull(key, "key cannot be null"); + + removeExpiredSegments(); + + return registerNewIterator(key, + key, + latestSessionStartTime, + endTimeMap.tailMap(earliestSessionEndTime, true).entrySet().iterator(), + true); + } + + @Override + public KeyValueIterator, byte[]> backwardFindSessions(final Bytes key, + final long earliestSessionEndTime, + final long latestSessionStartTime) { + Objects.requireNonNull(key, "key cannot be null"); + + removeExpiredSegments(); + + return registerNewIterator( + key, + key, + latestSessionStartTime, + endTimeMap.tailMap(earliestSessionEndTime, true).descendingMap().entrySet().iterator(), + false + ); + } + + @Override + public KeyValueIterator, 
byte[]> findSessions(final Bytes keyFrom, + final Bytes keyTo, + final long earliestSessionEndTime, + final long latestSessionStartTime) { + removeExpiredSegments(); + + if (keyFrom != null && keyTo != null && keyFrom.compareTo(keyTo) > 0) { + LOG.warn(INVALID_RANGE_WARN_MSG); + return KeyValueIterators.emptyIterator(); + } + + return registerNewIterator(keyFrom, + keyTo, + latestSessionStartTime, + endTimeMap.tailMap(earliestSessionEndTime, true).entrySet().iterator(), + true); + } + + @Override + public KeyValueIterator, byte[]> backwardFindSessions(final Bytes keyFrom, + final Bytes keyTo, + final long earliestSessionEndTime, + final long latestSessionStartTime) { + removeExpiredSegments(); + + if (keyFrom != null && keyTo != null && keyFrom.compareTo(keyTo) > 0) { + LOG.warn(INVALID_RANGE_WARN_MSG); + return KeyValueIterators.emptyIterator(); + } + + return registerNewIterator( + keyFrom, + keyTo, + latestSessionStartTime, + endTimeMap.tailMap(earliestSessionEndTime, true).descendingMap().entrySet().iterator(), + false + ); + } + + @Override + public KeyValueIterator, byte[]> fetch(final Bytes key) { + + Objects.requireNonNull(key, "key cannot be null"); + + removeExpiredSegments(); + + return registerNewIterator(key, key, Long.MAX_VALUE, endTimeMap.entrySet().iterator(), true); + } + + @Override + public KeyValueIterator, byte[]> backwardFetch(final Bytes key) { + + Objects.requireNonNull(key, "key cannot be null"); + + removeExpiredSegments(); + + return registerNewIterator(key, key, Long.MAX_VALUE, endTimeMap.descendingMap().entrySet().iterator(), false); + } + + @Override + public KeyValueIterator, byte[]> fetch(final Bytes keyFrom, final Bytes keyTo) { + removeExpiredSegments(); + + return registerNewIterator(keyFrom, keyTo, Long.MAX_VALUE, endTimeMap.entrySet().iterator(), true); + } + + @Override + public KeyValueIterator, byte[]> backwardFetch(final Bytes keyFrom, final Bytes keyTo) { + removeExpiredSegments(); + + return registerNewIterator( + keyFrom, keyTo, Long.MAX_VALUE, endTimeMap.descendingMap().entrySet().iterator(), false); + } + + @Override + public boolean persistent() { + return false; + } + + @Override + public boolean isOpen() { + return open; + } + + @Override + public void flush() { + // do-nothing since it is in-memory + } + + @Override + public void close() { + if (openIterators.size() != 0) { + LOG.warn("Closing {} open iterators for store {}", openIterators.size(), name); + for (final InMemorySessionStoreIterator it : openIterators) { + it.close(); + } + } + + endTimeMap.clear(); + openIterators.clear(); + open = false; + } + + private void removeExpiredSegments() { + long minLiveTime = Math.max(0L, observedStreamTime - retentionPeriod + 1); + + for (final InMemorySessionStoreIterator it : openIterators) { + minLiveTime = Math.min(minLiveTime, it.minTime()); + } + + endTimeMap.headMap(minLiveTime, false).clear(); + } + + private InMemorySessionStoreIterator registerNewIterator(final Bytes keyFrom, + final Bytes keyTo, + final long latestSessionStartTime, + final Iterator>>> endTimeIterator, + final boolean forward) { + final InMemorySessionStoreIterator iterator = + new InMemorySessionStoreIterator( + keyFrom, + keyTo, + latestSessionStartTime, + endTimeIterator, + openIterators::remove, + forward + ); + openIterators.add(iterator); + return iterator; + } + + interface ClosingCallback { + void deregisterIterator(final InMemorySessionStoreIterator iterator); + } + + private static class InMemorySessionStoreIterator implements KeyValueIterator, byte[]> { + + 
private final Iterator>>> endTimeIterator; + private Iterator>> keyIterator; + private Iterator> recordIterator; + + private KeyValue, byte[]> next; + private Bytes currentKey; + private long currentEndTime; + + private final Bytes keyFrom; + private final Bytes keyTo; + private final long latestSessionStartTime; + + private final ClosingCallback callback; + + private final boolean forward; + + InMemorySessionStoreIterator(final Bytes keyFrom, + final Bytes keyTo, + final long latestSessionStartTime, + final Iterator>>> endTimeIterator, + final ClosingCallback callback, + final boolean forward) { + this.keyFrom = keyFrom; + this.keyTo = keyTo; + this.latestSessionStartTime = latestSessionStartTime; + + this.endTimeIterator = endTimeIterator; + this.callback = callback; + this.forward = forward; + setAllIterators(); + } + + @Override + public boolean hasNext() { + if (next != null) { + return true; + } else if (recordIterator == null) { + return false; + } else { + next = getNext(); + return next != null; + } + } + + @Override + public Windowed peekNextKey() { + if (!hasNext()) { + throw new NoSuchElementException(); + } + return next.key; + } + + @Override + public KeyValue, byte[]> next() { + if (!hasNext()) { + throw new NoSuchElementException(); + } + + final KeyValue, byte[]> ret = next; + next = null; + return ret; + } + + @Override + public void close() { + next = null; + recordIterator = null; + callback.deregisterIterator(this); + } + + Long minTime() { + return currentEndTime; + } + + // getNext is only called when either recordIterator or segmentIterator has a next + // Note this does not guarantee a next record exists as the next segments may not contain any keys in range + private KeyValue, byte[]> getNext() { + if (!recordIterator.hasNext()) { + getNextIterators(); + } + + if (recordIterator == null) { + return null; + } + + final Map.Entry nextRecord = recordIterator.next(); + final SessionWindow sessionWindow = new SessionWindow(nextRecord.getKey(), currentEndTime); + final Windowed windowedKey = new Windowed<>(currentKey, sessionWindow); + + return new KeyValue<>(windowedKey, nextRecord.getValue()); + } + + // Called when the inner two (key and starttime) iterators are empty to roll to the next endTimestamp + // Rolls all three iterators forward until recordIterator has a next entry + // Sets recordIterator to null if there are no records to return + private void setAllIterators() { + while (endTimeIterator.hasNext()) { + final Entry>> nextEndTimeEntry = endTimeIterator.next(); + currentEndTime = nextEndTimeEntry.getKey(); + + final ConcurrentNavigableMap> subKVMap; + if (keyFrom == null && keyTo == null) { + subKVMap = nextEndTimeEntry.getValue(); + } else if (keyFrom == null) { + subKVMap = nextEndTimeEntry.getValue().headMap(keyTo, true); + } else if (keyTo == null) { + subKVMap = nextEndTimeEntry.getValue().tailMap(keyFrom, true); + } else { + subKVMap = nextEndTimeEntry.getValue().subMap(keyFrom, true, keyTo, true); + } + + if (forward) { + keyIterator = subKVMap.entrySet().iterator(); + } else { + keyIterator = subKVMap.descendingMap().entrySet().iterator(); + } + + if (setInnerIterators()) { + return; + } + } + recordIterator = null; + } + + // Rolls the inner two iterators (key and record) forward until recordIterators has a next entry + // Returns false if no more records are found (for the current end time) + private boolean setInnerIterators() { + while (keyIterator.hasNext()) { + final Entry> nextKeyEntry = keyIterator.next(); + currentKey = 
nextKeyEntry.getKey(); + + if (latestSessionStartTime == Long.MAX_VALUE) { + if (forward) { + recordIterator = nextKeyEntry.getValue().descendingMap().entrySet().iterator(); + } else { + recordIterator = nextKeyEntry.getValue().entrySet().iterator(); + } + } else { + if (forward) { + recordIterator = nextKeyEntry.getValue() + .headMap(latestSessionStartTime, true) + .descendingMap() + .entrySet().iterator(); + } else { + recordIterator = nextKeyEntry.getValue() + .headMap(latestSessionStartTime, true) + .entrySet().iterator(); + } + } + + if (recordIterator.hasNext()) { + return true; + } + } + return false; + } + + // Called when the current recordIterator has no entries left to roll it to the next valid entry + // When there are no more records to return, recordIterator will be set to null + private void getNextIterators() { + if (setInnerIterators()) { + return; + } + + setAllIterators(); + } + } + +} diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/InMemoryTimeOrderedKeyValueBuffer.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/InMemoryTimeOrderedKeyValueBuffer.java new file mode 100644 index 0000000..ba8a745 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/InMemoryTimeOrderedKeyValueBuffer.java @@ -0,0 +1,566 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
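A hedged sketch of putting and reading sessions through the store above; since the store's constructor is package-private, the example obtains it from the public supplier, and all names and values are illustrative:

    import java.time.Duration;

    import org.apache.kafka.common.utils.Bytes;
    import org.apache.kafka.streams.KeyValue;
    import org.apache.kafka.streams.kstream.Windowed;
    import org.apache.kafka.streams.kstream.internals.SessionWindow;
    import org.apache.kafka.streams.state.KeyValueIterator;
    import org.apache.kafka.streams.state.SessionStore;
    import org.apache.kafka.streams.state.internals.InMemorySessionBytesStoreSupplier;

    public class InMemorySessionStoreSketch {
        public static void main(final String[] args) {
            // The store's constructor is package-private, so obtain it via the public supplier.
            final SessionStore<Bytes, byte[]> store =
                new InMemorySessionBytesStoreSupplier("sessions", Duration.ofHours(1).toMillis()).get();

            final Bytes key = Bytes.wrap("alice".getBytes());
            // One entry per (key, SessionWindow); the value is an opaque aggregate.
            store.put(new Windowed<>(key, new SessionWindow(0L, 5_000L)), "3-clicks".getBytes());
            store.put(new Windowed<>(key, new SessionWindow(60_000L, 61_000L)), "1-click".getBytes());

            // Iterate all sessions recorded for the key.
            try (final KeyValueIterator<Windowed<Bytes>, byte[]> sessions = store.fetch(key)) {
                while (sessions.hasNext()) {
                    final KeyValue<Windowed<Bytes>, byte[]> session = sessions.next();
                    System.out.println(session.key.window().start() + ".." + session.key.window().end()
                        + " -> " + new String(session.value));
                }
            }
            store.close();
        }
    }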
+ */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.common.header.Header; +import org.apache.kafka.common.header.internals.RecordHeader; +import org.apache.kafka.common.header.internals.RecordHeaders; +import org.apache.kafka.common.metrics.Sensor; +import org.apache.kafka.common.serialization.ByteArraySerializer; +import org.apache.kafka.common.serialization.BytesSerializer; +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.streams.kstream.internals.Change; +import org.apache.kafka.streams.kstream.internals.FullChangeSerde; +import org.apache.kafka.streams.processor.ProcessorContext; +import org.apache.kafka.streams.processor.StateStore; +import org.apache.kafka.streams.processor.StateStoreContext; +import org.apache.kafka.streams.processor.api.Record; +import org.apache.kafka.streams.processor.internals.InternalProcessorContext; +import org.apache.kafka.streams.processor.internals.ProcessorContextUtils; +import org.apache.kafka.streams.processor.internals.ProcessorRecordContext; +import org.apache.kafka.streams.processor.internals.ProcessorStateManager; +import org.apache.kafka.streams.processor.internals.RecordBatchingStateRestoreCallback; +import org.apache.kafka.streams.processor.internals.RecordCollector; +import org.apache.kafka.streams.processor.internals.RecordQueue; +import org.apache.kafka.streams.processor.internals.SerdeGetter; +import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl; +import org.apache.kafka.streams.state.StoreBuilder; +import org.apache.kafka.streams.state.ValueAndTimestamp; +import org.apache.kafka.streams.state.internals.TimeOrderedKeyValueBufferChangelogDeserializationHelper.DeserializationResult; +import org.apache.kafka.streams.state.internals.metrics.StateStoreMetrics; + +import java.nio.ByteBuffer; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Iterator; +import java.util.Map; +import java.util.NoSuchElementException; +import java.util.Set; +import java.util.TreeMap; +import java.util.function.Consumer; +import java.util.function.Supplier; + +import static java.util.Objects.requireNonNull; +import static org.apache.kafka.streams.state.internals.TimeOrderedKeyValueBufferChangelogDeserializationHelper.deserializeV0; +import static org.apache.kafka.streams.state.internals.TimeOrderedKeyValueBufferChangelogDeserializationHelper.deserializeV1; +import static org.apache.kafka.streams.state.internals.TimeOrderedKeyValueBufferChangelogDeserializationHelper.deserializeV3; +import static org.apache.kafka.streams.state.internals.TimeOrderedKeyValueBufferChangelogDeserializationHelper.duckTypeV2; + +public final class InMemoryTimeOrderedKeyValueBuffer implements TimeOrderedKeyValueBuffer { + private static final BytesSerializer KEY_SERIALIZER = new BytesSerializer(); + private static final ByteArraySerializer VALUE_SERIALIZER = new ByteArraySerializer(); + private static final byte[] V_1_CHANGELOG_HEADER_VALUE = {(byte) 1}; + private static final byte[] V_2_CHANGELOG_HEADER_VALUE = {(byte) 2}; + private static final byte[] V_3_CHANGELOG_HEADER_VALUE = {(byte) 3}; + static final RecordHeaders CHANGELOG_HEADERS = + new RecordHeaders(new Header[] {new RecordHeader("v", V_3_CHANGELOG_HEADER_VALUE)}); + private static final String METRIC_SCOPE = "in-memory-suppression"; + + 
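A small sketch of the versioning convention behind these header constants: the writer stamps each changelog record with a "v" header and the restore path branches on it (class name and payload here are illustrative):

    import org.apache.kafka.common.header.Header;
    import org.apache.kafka.common.header.internals.RecordHeader;
    import org.apache.kafka.common.header.internals.RecordHeaders;

    public class ChangelogVersionHeaderSketch {
        public static void main(final String[] args) {
            // Writer side: tag the record with the encoding version (3 mirrors V_3_CHANGELOG_HEADER_VALUE above).
            final RecordHeaders headers = new RecordHeaders(new Header[] {new RecordHeader("v", new byte[] {(byte) 3})});

            // Reader side: the restore path inspects the last "v" header to pick the right decoder.
            final Header versionHeader = headers.lastHeader("v");
            if (versionHeader == null) {
                System.out.println("no version header -> legacy V0 layout");
            } else {
                System.out.println("version " + versionHeader.value()[0] + " layout");
            }
        }
    }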
private final Map index = new HashMap<>(); + private final TreeMap sortedMap = new TreeMap<>(); + + private final Set dirtyKeys = new HashSet<>(); + private final String storeName; + private final boolean loggingEnabled; + + private Serde keySerde; + private FullChangeSerde valueSerde; + + private long memBufferSize = 0L; + private long minTimestamp = Long.MAX_VALUE; + private InternalProcessorContext context; + private String changelogTopic; + private Sensor bufferSizeSensor; + private Sensor bufferCountSensor; + private StreamsMetricsImpl streamsMetrics; + private String taskId; + + private volatile boolean open; + + private int partition; + + public static class Builder implements StoreBuilder> { + + private final String storeName; + private final Serde keySerde; + private final Serde valueSerde; + private boolean loggingEnabled = true; + private Map logConfig = new HashMap<>(); + + public Builder(final String storeName, final Serde keySerde, final Serde valueSerde) { + this.storeName = storeName; + this.keySerde = keySerde; + this.valueSerde = valueSerde; + } + + /** + * As of 2.1, there's no way for users to directly interact with the buffer, + * so this method is implemented solely to be called by Streams (which + * it will do based on the {@code cache.max.bytes.buffering} config. + *
                + * It's currently a no-op. + */ + @Override + public StoreBuilder> withCachingEnabled() { + return this; + } + + /** + * As of 2.1, there's no way for users to directly interact with the buffer, + * so this method is implemented solely to be called by Streams (which + * it will do based on the {@code cache.max.bytes.buffering} config. + *
                + * It's currently a no-op. + */ + @Override + public StoreBuilder> withCachingDisabled() { + return this; + } + + @Override + public StoreBuilder> withLoggingEnabled(final Map config) { + logConfig = config; + return this; + } + + @Override + public StoreBuilder> withLoggingDisabled() { + loggingEnabled = false; + return this; + } + + @Override + public InMemoryTimeOrderedKeyValueBuffer build() { + return new InMemoryTimeOrderedKeyValueBuffer<>(storeName, loggingEnabled, keySerde, valueSerde); + } + + @Override + public Map logConfig() { + return loggingEnabled() ? Collections.unmodifiableMap(logConfig) : Collections.emptyMap(); + } + + @Override + public boolean loggingEnabled() { + return loggingEnabled; + } + + @Override + public String name() { + return storeName; + } + } + + private InMemoryTimeOrderedKeyValueBuffer(final String storeName, + final boolean loggingEnabled, + final Serde keySerde, + final Serde valueSerde) { + this.storeName = storeName; + this.loggingEnabled = loggingEnabled; + this.keySerde = keySerde; + this.valueSerde = FullChangeSerde.wrap(valueSerde); + } + + @Override + public String name() { + return storeName; + } + + + @Override + public boolean persistent() { + return false; + } + + @SuppressWarnings("unchecked") + @Override + public void setSerdesIfNull(final SerdeGetter getter) { + keySerde = keySerde == null ? (Serde) getter.keySerde() : keySerde; + valueSerde = valueSerde == null ? FullChangeSerde.wrap((Serde) getter.valueSerde()) : valueSerde; + } + + @Deprecated + @Override + public void init(final ProcessorContext context, final StateStore root) { + this.context = ProcessorContextUtils.asInternalProcessorContext(context); + init(root); + } + + @Override + public void init(final StateStoreContext context, final StateStore root) { + this.context = ProcessorContextUtils.asInternalProcessorContext(context); + init(root); + } + + private void init(final StateStore root) { + taskId = context.taskId().toString(); + streamsMetrics = context.metrics(); + + bufferSizeSensor = StateStoreMetrics.suppressionBufferSizeSensor( + taskId, + METRIC_SCOPE, + storeName, + streamsMetrics + ); + bufferCountSensor = StateStoreMetrics.suppressionBufferCountSensor( + taskId, + METRIC_SCOPE, + storeName, + streamsMetrics + ); + + context.register(root, (RecordBatchingStateRestoreCallback) this::restoreBatch); + changelogTopic = ProcessorStateManager.storeChangelogTopic(context.applicationId(), storeName, context.taskId().topologyName()); + updateBufferMetrics(); + open = true; + partition = context.taskId().partition(); + } + + @Override + public boolean isOpen() { + return open; + } + + @Override + public void close() { + open = false; + index.clear(); + sortedMap.clear(); + dirtyKeys.clear(); + memBufferSize = 0; + minTimestamp = Long.MAX_VALUE; + updateBufferMetrics(); + streamsMetrics.removeAllStoreLevelSensorsAndMetrics(taskId, storeName); + } + + @Override + public void flush() { + if (loggingEnabled) { + // counting on this getting called before the record collector's flush + for (final Bytes key : dirtyKeys) { + + final BufferKey bufferKey = index.get(key); + + if (bufferKey == null) { + // The record was evicted from the buffer. Send a tombstone. 
+ logTombstone(key); + } else { + final BufferValue value = sortedMap.get(bufferKey); + + logValue(key, bufferKey, value); + } + } + dirtyKeys.clear(); + } + } + + private void logValue(final Bytes key, final BufferKey bufferKey, final BufferValue value) { + + final int sizeOfBufferTime = Long.BYTES; + final ByteBuffer buffer = value.serialize(sizeOfBufferTime); + buffer.putLong(bufferKey.time()); + final byte[] array = buffer.array(); + ((RecordCollector.Supplier) context).recordCollector().send( + changelogTopic, + key, + array, + CHANGELOG_HEADERS, + partition, + null, + KEY_SERIALIZER, + VALUE_SERIALIZER + ); + } + + private void logTombstone(final Bytes key) { + ((RecordCollector.Supplier) context).recordCollector().send( + changelogTopic, + key, + null, + null, + partition, + null, + KEY_SERIALIZER, + VALUE_SERIALIZER + ); + } + + private void restoreBatch(final Collection> batch) { + for (final ConsumerRecord record : batch) { + if (record.partition() != partition) { + throw new IllegalStateException( + String.format( + "record partition [%d] is being restored by the wrong suppress partition [%d]", + record.partition(), + partition + ) + ); + } + final Bytes key = Bytes.wrap(record.key()); + if (record.value() == null) { + // This was a tombstone. Delete the record. + final BufferKey bufferKey = index.remove(key); + if (bufferKey != null) { + final BufferValue removed = sortedMap.remove(bufferKey); + if (removed != null) { + memBufferSize -= computeRecordSize(bufferKey.key(), removed); + } + if (bufferKey.time() == minTimestamp) { + minTimestamp = sortedMap.isEmpty() ? Long.MAX_VALUE : sortedMap.firstKey().time(); + } + } + } else { + final Header versionHeader = record.headers().lastHeader("v"); + if (versionHeader == null) { + // Version 0: + // value: + // - buffer time + // - old value + // - new value + final byte[] previousBufferedValue = index.containsKey(key) + ? internalPriorValueForBuffered(key) + : null; + final DeserializationResult deserializationResult = deserializeV0(record, key, previousBufferedValue); + cleanPut(deserializationResult.time(), deserializationResult.key(), deserializationResult.bufferValue()); + } else if (Arrays.equals(versionHeader.value(), V_3_CHANGELOG_HEADER_VALUE)) { + // Version 3: + // value: + // - record context + // - prior value + // - old value + // - new value + // - buffer time + final DeserializationResult deserializationResult = deserializeV3(record, key); + cleanPut(deserializationResult.time(), deserializationResult.key(), deserializationResult.bufferValue()); + + } else if (Arrays.equals(versionHeader.value(), V_2_CHANGELOG_HEADER_VALUE)) { + // Version 2: + // value: + // - record context + // - old value + // - new value + // - prior value + // - buffer time + // NOTE: 2.4.0, 2.4.1, and 2.5.0 actually encode Version 3 formatted data, + // but still set the Version 2 flag, so to deserialize, we have to duck type. + final DeserializationResult deserializationResult = duckTypeV2(record, key); + cleanPut(deserializationResult.time(), deserializationResult.key(), deserializationResult.bufferValue()); + } else if (Arrays.equals(versionHeader.value(), V_1_CHANGELOG_HEADER_VALUE)) { + // Version 1: + // value: + // - buffer time + // - record context + // - old value + // - new value + final byte[] previousBufferedValue = index.containsKey(key) + ? 
internalPriorValueForBuffered(key) + : null; + final DeserializationResult deserializationResult = deserializeV1(record, key, previousBufferedValue); + cleanPut(deserializationResult.time(), deserializationResult.key(), deserializationResult.bufferValue()); + } else { + throw new IllegalArgumentException("Restoring apparently invalid changelog record: " + record); + } + } + } + updateBufferMetrics(); + } + + + @Override + public void evictWhile(final Supplier predicate, + final Consumer> callback) { + final Iterator> delegate = sortedMap.entrySet().iterator(); + int evictions = 0; + + if (predicate.get()) { + Map.Entry next = null; + if (delegate.hasNext()) { + next = delegate.next(); + } + + // predicate being true means we read one record, call the callback, and then remove it + while (next != null && predicate.get()) { + if (next.getKey().time() != minTimestamp) { + throw new IllegalStateException( + "minTimestamp [" + minTimestamp + "] did not match the actual min timestamp [" + + next.getKey().time() + "]" + ); + } + final K key = keySerde.deserializer().deserialize(changelogTopic, next.getKey().key().get()); + final BufferValue bufferValue = next.getValue(); + final Change value = valueSerde.deserializeParts( + changelogTopic, + new Change<>(bufferValue.newValue(), bufferValue.oldValue()) + ); + callback.accept(new Eviction<>(key, value, bufferValue.context())); + + delegate.remove(); + index.remove(next.getKey().key()); + + dirtyKeys.add(next.getKey().key()); + + memBufferSize -= computeRecordSize(next.getKey().key(), bufferValue); + + // peek at the next record so we can update the minTimestamp + if (delegate.hasNext()) { + next = delegate.next(); + minTimestamp = next == null ? Long.MAX_VALUE : next.getKey().time(); + } else { + next = null; + minTimestamp = Long.MAX_VALUE; + } + + evictions++; + } + } + if (evictions > 0) { + updateBufferMetrics(); + } + } + + @Override + public Maybe> priorValueForBuffered(final K key) { + final Bytes serializedKey = Bytes.wrap(keySerde.serializer().serialize(changelogTopic, key)); + if (index.containsKey(serializedKey)) { + final byte[] serializedValue = internalPriorValueForBuffered(serializedKey); + + final V deserializedValue = valueSerde.innerSerde().deserializer().deserialize( + changelogTopic, + serializedValue + ); + + // it's unfortunately not possible to know this, unless we materialize the suppressed result, since our only + // knowledge of the prior value is what the upstream processor sends us as the "old value" when we first + // buffer something. 
+ return Maybe.defined(ValueAndTimestamp.make(deserializedValue, RecordQueue.UNKNOWN)); + } else { + return Maybe.undefined(); + } + } + + private byte[] internalPriorValueForBuffered(final Bytes key) { + final BufferKey bufferKey = index.get(key); + if (bufferKey == null) { + throw new NoSuchElementException("Key [" + key + "] is not in the buffer."); + } else { + final BufferValue bufferValue = sortedMap.get(bufferKey); + return bufferValue.priorValue(); + } + } + + @Override + public void put(final long time, + final Record> record, + final ProcessorRecordContext recordContext) { + requireNonNull(record.value(), "value cannot be null"); + requireNonNull(recordContext, "recordContext cannot be null"); + + final Bytes serializedKey = Bytes.wrap(keySerde.serializer().serialize(changelogTopic, record.key())); + final Change serialChange = valueSerde.serializeParts(changelogTopic, record.value()); + + final BufferValue buffered = getBuffered(serializedKey); + final byte[] serializedPriorValue; + if (buffered == null) { + serializedPriorValue = serialChange.oldValue; + } else { + serializedPriorValue = buffered.priorValue(); + } + + cleanPut( + time, + serializedKey, + new BufferValue(serializedPriorValue, serialChange.oldValue, serialChange.newValue, recordContext) + ); + dirtyKeys.add(serializedKey); + updateBufferMetrics(); + } + + private BufferValue getBuffered(final Bytes key) { + final BufferKey bufferKey = index.get(key); + return bufferKey == null ? null : sortedMap.get(bufferKey); + } + + private void cleanPut(final long time, final Bytes key, final BufferValue value) { + // non-resetting semantics: + // if there was a previous version of the same record, + // then insert the new record in the same place in the priority queue + + final BufferKey previousKey = index.get(key); + if (previousKey == null) { + final BufferKey nextKey = new BufferKey(time, key); + index.put(key, nextKey); + sortedMap.put(nextKey, value); + minTimestamp = Math.min(minTimestamp, time); + memBufferSize += computeRecordSize(key, value); + } else { + final BufferValue removedValue = sortedMap.put(previousKey, value); + memBufferSize = + memBufferSize + + computeRecordSize(key, value) + - (removedValue == null ? 
0 : computeRecordSize(key, removedValue)); + } + } + + @Override + public int numRecords() { + return index.size(); + } + + @Override + public long bufferSize() { + return memBufferSize; + } + + @Override + public long minTimestamp() { + return minTimestamp; + } + + private static long computeRecordSize(final Bytes key, final BufferValue value) { + long size = 0L; + size += 8; // buffer time + size += key.get().length; + if (value != null) { + size += value.residentMemorySizeEstimate(); + } + return size; + } + + private void updateBufferMetrics() { + bufferSizeSensor.record(memBufferSize, context.currentSystemTimeMs()); + bufferCountSensor.record(index.size(), context.currentSystemTimeMs()); + } + + @Override + public String toString() { + return "InMemoryTimeOrderedKeyValueBuffer{" + + "storeName='" + storeName + '\'' + + ", changelogTopic='" + changelogTopic + '\'' + + ", open=" + open + + ", loggingEnabled=" + loggingEnabled + + ", minTimestamp=" + minTimestamp + + ", memBufferSize=" + memBufferSize + + ", \n\tdirtyKeys=" + dirtyKeys + + ", \n\tindex=" + index + + ", \n\tsortedMap=" + sortedMap + + '}'; + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/InMemoryWindowBytesStoreSupplier.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/InMemoryWindowBytesStoreSupplier.java new file mode 100644 index 0000000..4ab93ec --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/InMemoryWindowBytesStoreSupplier.java @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
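A hedged sketch of how a caller might drive the eviction loop above, assuming the class keeps its usual <K, V> type parameters; the expiry computation and method name are illustrative:

    import org.apache.kafka.streams.state.internals.InMemoryTimeOrderedKeyValueBuffer;

    public class SuppressionEvictionSketch {

        // Drains every buffered record whose buffer time is at or before expiryTime. The predicate is
        // re-evaluated before each record, so eviction stops at the first record that is still too young.
        public static <K, V> void emitExpired(final InMemoryTimeOrderedKeyValueBuffer<K, V> buffer,
                                              final long expiryTime) {
            buffer.evictWhile(
                () -> buffer.numRecords() > 0 && buffer.minTimestamp() <= expiryTime,
                eviction -> {
                    // forward the evicted key/value/context downstream here
                }
            );
        }
    }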
+ */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.streams.state.WindowBytesStoreSupplier; +import org.apache.kafka.streams.state.WindowStore; + +public class InMemoryWindowBytesStoreSupplier implements WindowBytesStoreSupplier { + private final String name; + private final long retentionPeriod; + private final long windowSize; + private final boolean retainDuplicates; + + public InMemoryWindowBytesStoreSupplier(final String name, + final long retentionPeriod, + final long windowSize, + final boolean retainDuplicates) { + this.name = name; + this.retentionPeriod = retentionPeriod; + this.windowSize = windowSize; + this.retainDuplicates = retainDuplicates; + } + + @Override + public String name() { + return name; + } + + @Override + public WindowStore get() { + return new InMemoryWindowStore(name, + retentionPeriod, + windowSize, + retainDuplicates, + metricsScope()); + } + + @Override + public String metricsScope() { + return "in-memory-window"; + } + + @Override + public long retentionPeriod() { + return retentionPeriod; + } + + + @Override + public long windowSize() { + return windowSize; + } + + // In-memory window store is not *really* segmented, so just say size is 1 ms + @Override + public long segmentIntervalMs() { + return 1; + } + + @Override + public boolean retainDuplicates() { + return retainDuplicates; + } + + @Override + public String toString() { + return "InMemoryWindowBytesStoreSupplier{" + + "name='" + name + '\'' + + ", retentionPeriod=" + retentionPeriod + + ", windowSize=" + windowSize + + ", retainDuplicates=" + retainDuplicates + + '}'; + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/InMemoryWindowStore.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/InMemoryWindowStore.java new file mode 100644 index 0000000..5327e75 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/InMemoryWindowStore.java @@ -0,0 +1,623 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
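InMemoryWindowBytesStoreSupplier above is normally obtained through the public Stores factory rather than constructed directly. A minimal usage sketch, assuming the standard Kafka Streams API (the class name, store name, and durations are illustrative):

    import java.time.Duration;
    import org.apache.kafka.streams.state.Stores;
    import org.apache.kafka.streams.state.WindowBytesStoreSupplier;

    public class InMemoryWindowStoreExample {                      // hypothetical example class
        public static void main(final String[] args) {
            // Retention must be at least as long as the window size.
            final WindowBytesStoreSupplier supplier = Stores.inMemoryWindowStore(
                "example-window-store",    // illustrative store name
                Duration.ofHours(1),       // retention period
                Duration.ofMinutes(5),     // window size
                false);                    // retainDuplicates
            System.out.println(supplier.name() + " retains data for " + supplier.retentionPeriod() + " ms");
        }
    }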
+ */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.common.metrics.Sensor; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.kstream.Windowed; +import org.apache.kafka.streams.kstream.internals.TimeWindow; +import org.apache.kafka.streams.processor.ProcessorContext; +import org.apache.kafka.streams.processor.StateStore; +import org.apache.kafka.streams.processor.internals.ProcessorContextUtils; +import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl; +import org.apache.kafka.streams.processor.internals.metrics.TaskMetrics; +import org.apache.kafka.streams.state.KeyValueIterator; +import org.apache.kafka.streams.state.WindowStore; +import org.apache.kafka.streams.state.WindowStoreIterator; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.nio.ByteBuffer; +import java.util.Iterator; +import java.util.Map; +import java.util.NoSuchElementException; +import java.util.Objects; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentNavigableMap; +import java.util.concurrent.ConcurrentSkipListMap; + +import static org.apache.kafka.streams.state.internals.WindowKeySchema.extractStoreKeyBytes; +import static org.apache.kafka.streams.state.internals.WindowKeySchema.extractStoreTimestamp; + + +public class InMemoryWindowStore implements WindowStore { + + private static final Logger LOG = LoggerFactory.getLogger(InMemoryWindowStore.class); + private static final int SEQNUM_SIZE = 4; + + private final String name; + private final String metricScope; + private final long retentionPeriod; + private final long windowSize; + private final boolean retainDuplicates; + + private final ConcurrentNavigableMap> segmentMap = new ConcurrentSkipListMap<>(); + private final Set openIterators = ConcurrentHashMap.newKeySet(); + + private ProcessorContext context; + private Sensor expiredRecordSensor; + private int seqnum = 0; + private long observedStreamTime = ConsumerRecord.NO_TIMESTAMP; + + private volatile boolean open = false; + + public InMemoryWindowStore(final String name, + final long retentionPeriod, + final long windowSize, + final boolean retainDuplicates, + final String metricScope) { + this.name = name; + this.retentionPeriod = retentionPeriod; + this.windowSize = windowSize; + this.retainDuplicates = retainDuplicates; + this.metricScope = metricScope; + } + + @Override + public String name() { + return name; + } + + @Deprecated + @Override + public void init(final ProcessorContext context, final StateStore root) { + this.context = context; + + final StreamsMetricsImpl metrics = ProcessorContextUtils.getMetricsImpl(context); + final String threadId = Thread.currentThread().getName(); + final String taskName = context.taskId().toString(); + expiredRecordSensor = TaskMetrics.droppedRecordsSensor( + threadId, + taskName, + metrics + ); + + if (root != null) { + context.register(root, (key, value) -> + put(Bytes.wrap(extractStoreKeyBytes(key)), value, extractStoreTimestamp(key))); + } + open = true; + } + + @Override + public void put(final Bytes key, final byte[] value, final long windowStartTimestamp) { + removeExpiredSegments(); + observedStreamTime = Math.max(observedStreamTime, windowStartTimestamp); + + if (windowStartTimestamp <= observedStreamTime - retentionPeriod) { + expiredRecordSensor.record(1.0d, 
ProcessorContextUtils.currentSystemTime(context)); + LOG.warn("Skipping record for expired segment."); + } else { + if (value != null) { + maybeUpdateSeqnumForDups(); + final Bytes keyBytes = retainDuplicates ? wrapForDups(key, seqnum) : key; + segmentMap.computeIfAbsent(windowStartTimestamp, t -> new ConcurrentSkipListMap<>()); + segmentMap.get(windowStartTimestamp).put(keyBytes, value); + } else if (!retainDuplicates) { + // Skip if value is null and duplicates are allowed since this delete is a no-op + segmentMap.computeIfPresent(windowStartTimestamp, (t, kvMap) -> { + kvMap.remove(key); + if (kvMap.isEmpty()) { + segmentMap.remove(windowStartTimestamp); + } + return kvMap; + }); + } + } + } + + @Override + public byte[] fetch(final Bytes key, final long windowStartTimestamp) { + Objects.requireNonNull(key, "key cannot be null"); + + removeExpiredSegments(); + + if (windowStartTimestamp <= observedStreamTime - retentionPeriod) { + return null; + } + + final ConcurrentNavigableMap kvMap = segmentMap.get(windowStartTimestamp); + if (kvMap == null) { + return null; + } else { + return kvMap.get(key); + } + } + + @Deprecated + @Override + public WindowStoreIterator fetch(final Bytes key, final long timeFrom, final long timeTo) { + return fetch(key, timeFrom, timeTo, true); + } + + @Override + public WindowStoreIterator backwardFetch(final Bytes key, final long timeFrom, final long timeTo) { + return fetch(key, timeFrom, timeTo, false); + } + + WindowStoreIterator fetch(final Bytes key, final long timeFrom, final long timeTo, final boolean forward) { + Objects.requireNonNull(key, "key cannot be null"); + + removeExpiredSegments(); + + // add one b/c records expire exactly retentionPeriod ms after created + final long minTime = Math.max(timeFrom, observedStreamTime - retentionPeriod + 1); + + if (timeTo < minTime) { + return WrappedInMemoryWindowStoreIterator.emptyIterator(); + } + + if (forward) { + return registerNewWindowStoreIterator( + key, + segmentMap.subMap(minTime, true, timeTo, true) + .entrySet().iterator(), + true + ); + } else { + return registerNewWindowStoreIterator( + key, + segmentMap.subMap(minTime, true, timeTo, true) + .descendingMap().entrySet().iterator(), + false + ); + } + } + + @Deprecated + @Override + public KeyValueIterator, byte[]> fetch(final Bytes keyFrom, + final Bytes keyTo, + final long timeFrom, + final long timeTo) { + return fetch(keyFrom, keyTo, timeFrom, timeTo, true); + } + + @Override + public KeyValueIterator, byte[]> backwardFetch(final Bytes keyFrom, + final Bytes keyTo, + final long timeFrom, + final long timeTo) { + return fetch(keyFrom, keyTo, timeFrom, timeTo, false); + } + + KeyValueIterator, byte[]> fetch(final Bytes from, + final Bytes to, + final long timeFrom, + final long timeTo, + final boolean forward) { + removeExpiredSegments(); + + if (from != null && to != null && from.compareTo(to) > 0) { + LOG.warn("Returning empty iterator for fetch with invalid key range: from > to. " + + "This may be due to range arguments set in the wrong order, " + + "or serdes that don't preserve ordering when lexicographically comparing the serialized bytes. 
" + + "Note that the built-in numerical serdes do not follow this for negative numbers"); + return KeyValueIterators.emptyIterator(); + } + + // add one b/c records expire exactly retentionPeriod ms after created + final long minTime = Math.max(timeFrom, observedStreamTime - retentionPeriod + 1); + + if (timeTo < minTime) { + return KeyValueIterators.emptyIterator(); + } + + if (forward) { + return registerNewWindowedKeyValueIterator( + from, + to, + segmentMap.subMap(minTime, true, timeTo, true) + .entrySet().iterator(), + true + ); + } else { + return registerNewWindowedKeyValueIterator( + from, + to, + segmentMap.subMap(minTime, true, timeTo, true) + .descendingMap().entrySet().iterator(), + false + ); + } + } + + @Deprecated + @Override + public KeyValueIterator, byte[]> fetchAll(final long timeFrom, final long timeTo) { + return fetchAll(timeFrom, timeTo, true); + } + + @Override + public KeyValueIterator, byte[]> backwardFetchAll(final long timeFrom, final long timeTo) { + return fetchAll(timeFrom, timeTo, false); + } + + KeyValueIterator, byte[]> fetchAll(final long timeFrom, final long timeTo, final boolean forward) { + removeExpiredSegments(); + + // add one b/c records expire exactly retentionPeriod ms after created + final long minTime = Math.max(timeFrom, observedStreamTime - retentionPeriod + 1); + + if (timeTo < minTime) { + return KeyValueIterators.emptyIterator(); + } + + if (forward) { + return registerNewWindowedKeyValueIterator( + null, + null, + segmentMap.subMap(minTime, true, timeTo, true) + .entrySet().iterator(), + true + ); + } else { + return registerNewWindowedKeyValueIterator( + null, + null, + segmentMap.subMap(minTime, true, timeTo, true) + .descendingMap().entrySet().iterator(), + false + ); + } + } + + @Override + public KeyValueIterator, byte[]> all() { + removeExpiredSegments(); + + final long minTime = observedStreamTime - retentionPeriod; + + return registerNewWindowedKeyValueIterator( + null, + null, + segmentMap.tailMap(minTime, false).entrySet().iterator(), + true + ); + } + + @Override + public KeyValueIterator, byte[]> backwardAll() { + removeExpiredSegments(); + + final long minTime = observedStreamTime - retentionPeriod; + + return registerNewWindowedKeyValueIterator( + null, + null, + segmentMap.tailMap(minTime, false).descendingMap().entrySet().iterator(), + false + ); + } + + @Override + public boolean persistent() { + return false; + } + + @Override + public boolean isOpen() { + return open; + } + + @Override + public void flush() { + // do-nothing since it is in-memory + } + + @Override + public void close() { + if (openIterators.size() != 0) { + LOG.warn("Closing {} open iterators for store {}", openIterators.size(), name); + for (final InMemoryWindowStoreIteratorWrapper it : openIterators) { + it.close(); + } + } + + segmentMap.clear(); + open = false; + } + + private void removeExpiredSegments() { + long minLiveTime = Math.max(0L, observedStreamTime - retentionPeriod + 1); + for (final InMemoryWindowStoreIteratorWrapper it : openIterators) { + minLiveTime = Math.min(minLiveTime, it.minTime()); + } + segmentMap.headMap(minLiveTime, false).clear(); + } + + private void maybeUpdateSeqnumForDups() { + if (retainDuplicates) { + seqnum = (seqnum + 1) & 0x7FFFFFFF; + } + } + + private static Bytes wrapForDups(final Bytes key, final int seqnum) { + final ByteBuffer buf = ByteBuffer.allocate(key.get().length + SEQNUM_SIZE); + buf.put(key.get()); + buf.putInt(seqnum); + + return Bytes.wrap(buf.array()); + } + + private static Bytes getKey(final 
Bytes keyBytes) { + final byte[] bytes = new byte[keyBytes.get().length - SEQNUM_SIZE]; + System.arraycopy(keyBytes.get(), 0, bytes, 0, bytes.length); + return Bytes.wrap(bytes); + } + + private WrappedInMemoryWindowStoreIterator registerNewWindowStoreIterator(final Bytes key, + final Iterator>> segmentIterator, + final boolean forward) { + final Bytes keyFrom = retainDuplicates ? wrapForDups(key, 0) : key; + final Bytes keyTo = retainDuplicates ? wrapForDups(key, Integer.MAX_VALUE) : key; + + final WrappedInMemoryWindowStoreIterator iterator = + new WrappedInMemoryWindowStoreIterator(keyFrom, keyTo, segmentIterator, openIterators::remove, retainDuplicates, forward); + + openIterators.add(iterator); + return iterator; + } + + private WrappedWindowedKeyValueIterator registerNewWindowedKeyValueIterator(final Bytes keyFrom, + final Bytes keyTo, + final Iterator>> segmentIterator, + final boolean forward) { + final Bytes from = (retainDuplicates && keyFrom != null) ? wrapForDups(keyFrom, 0) : keyFrom; + final Bytes to = (retainDuplicates && keyTo != null) ? wrapForDups(keyTo, Integer.MAX_VALUE) : keyTo; + + final WrappedWindowedKeyValueIterator iterator = + new WrappedWindowedKeyValueIterator( + from, + to, + segmentIterator, + openIterators::remove, + retainDuplicates, + windowSize, + forward); + openIterators.add(iterator); + return iterator; + } + + + interface ClosingCallback { + void deregisterIterator(final InMemoryWindowStoreIteratorWrapper iterator); + } + + private static abstract class InMemoryWindowStoreIteratorWrapper { + + private Iterator> recordIterator; + private KeyValue next; + private long currentTime; + + private final boolean allKeys; + private final Bytes keyFrom; + private final Bytes keyTo; + private final Iterator>> segmentIterator; + private final ClosingCallback callback; + private final boolean retainDuplicates; + private final boolean forward; + + InMemoryWindowStoreIteratorWrapper(final Bytes keyFrom, + final Bytes keyTo, + final Iterator>> segmentIterator, + final ClosingCallback callback, + final boolean retainDuplicates, + final boolean forward) { + this.keyFrom = keyFrom; + this.keyTo = keyTo; + allKeys = (keyFrom == null) && (keyTo == null); + this.retainDuplicates = retainDuplicates; + this.forward = forward; + + this.segmentIterator = segmentIterator; + this.callback = callback; + recordIterator = segmentIterator == null ? 
null : setRecordIterator(); + } + + public boolean hasNext() { + if (next != null) { + return true; + } + if (recordIterator == null || (!recordIterator.hasNext() && !segmentIterator.hasNext())) { + return false; + } + + next = getNext(); + if (next == null) { + return false; + } + + if (allKeys || !retainDuplicates) { + return true; + } + + final Bytes key = getKey(next.key); + if (isKeyWithinRange(key)) { + return true; + } else { + next = null; + return hasNext(); + } + } + + private boolean isKeyWithinRange(final Bytes key) { + // split all cases for readability and avoid BooleanExpressionComplexity checkstyle warning + if (keyFrom == null && keyTo == null) { + // fetch all + return true; + } else if (keyFrom == null) { + // start from the beginning + return key.compareTo(getKey(keyTo)) <= 0; + } else if (keyTo == null) { + // end to the last + return key.compareTo(getKey(keyFrom)) >= 0; + } else { + // key is within the range + return key.compareTo(getKey(keyFrom)) >= 0 && key.compareTo(getKey(keyTo)) <= 0; + } + } + + public void close() { + next = null; + recordIterator = null; + callback.deregisterIterator(this); + } + + // getNext is only called when either recordIterator or segmentIterator has a next + // Note this does not guarantee a next record exists as the next segments may not contain any keys in range + protected KeyValue getNext() { + while (!recordIterator.hasNext()) { + recordIterator = setRecordIterator(); + if (recordIterator == null) { + return null; + } + } + final Map.Entry nextRecord = recordIterator.next(); + return new KeyValue<>(nextRecord.getKey(), nextRecord.getValue()); + } + + // Resets recordIterator to point to the next segment and returns null if there are no more segments + // Note it may not actually point to anything if no keys in range exist in the next segment + Iterator> setRecordIterator() { + if (!segmentIterator.hasNext()) { + return null; + } + + final Map.Entry> currentSegment = segmentIterator.next(); + currentTime = currentSegment.getKey(); + + final ConcurrentNavigableMap subMap; + if (allKeys) { // keyFrom == null && keyTo == null + subMap = currentSegment.getValue(); + } else if (keyFrom == null) { + subMap = currentSegment.getValue().headMap(keyTo, true); + } else if (keyTo == null) { + subMap = currentSegment.getValue().tailMap(keyFrom, true); + } else { + subMap = currentSegment.getValue().subMap(keyFrom, true, keyTo, true); + } + + if (forward) { + return subMap.entrySet().iterator(); + } else { + return subMap.descendingMap().entrySet().iterator(); + } + } + + Long minTime() { + return currentTime; + } + } + + private static class WrappedInMemoryWindowStoreIterator extends InMemoryWindowStoreIteratorWrapper implements WindowStoreIterator { + + WrappedInMemoryWindowStoreIterator(final Bytes keyFrom, + final Bytes keyTo, + final Iterator>> segmentIterator, + final ClosingCallback callback, + final boolean retainDuplicates, + final boolean forward) { + super(keyFrom, keyTo, segmentIterator, callback, retainDuplicates, forward); + } + + @Override + public Long peekNextKey() { + if (!hasNext()) { + throw new NoSuchElementException(); + } + return super.currentTime; + } + + @Override + public KeyValue next() { + if (!hasNext()) { + throw new NoSuchElementException(); + } + + final KeyValue result = new KeyValue<>(super.currentTime, super.next.value); + super.next = null; + return result; + } + + public static WrappedInMemoryWindowStoreIterator emptyIterator() { + return new WrappedInMemoryWindowStoreIterator(null, null, null, it -> { + }, 
false, true); + } + } + + private static class WrappedWindowedKeyValueIterator + extends InMemoryWindowStoreIteratorWrapper + implements KeyValueIterator, byte[]> { + + private final long windowSize; + + WrappedWindowedKeyValueIterator(final Bytes keyFrom, + final Bytes keyTo, + final Iterator>> segmentIterator, + final ClosingCallback callback, + final boolean retainDuplicates, + final long windowSize, + final boolean forward) { + super(keyFrom, keyTo, segmentIterator, callback, retainDuplicates, forward); + this.windowSize = windowSize; + } + + public Windowed peekNextKey() { + if (!hasNext()) { + throw new NoSuchElementException(); + } + return getWindowedKey(); + } + + public KeyValue, byte[]> next() { + if (!hasNext()) { + throw new NoSuchElementException(); + } + + final KeyValue, byte[]> result = new KeyValue<>(getWindowedKey(), super.next.value); + super.next = null; + return result; + } + + private Windowed getWindowedKey() { + final Bytes key = super.retainDuplicates ? getKey(super.next.key) : super.next.key; + long endTime = super.currentTime + windowSize; + + if (endTime < 0) { + LOG.warn("Warning: window end time was truncated to Long.MAX"); + endTime = Long.MAX_VALUE; + } + + final TimeWindow timeWindow = new TimeWindow(super.currentTime, endTime); + return new Windowed<>(key, timeWindow); + } + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/KeyValueIteratorFacade.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/KeyValueIteratorFacade.java new file mode 100644 index 0000000..f79b6f3 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/KeyValueIteratorFacade.java @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
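One detail worth calling out from the iterator above: getWindowedKey clamps the window end time when currentTime + windowSize overflows a signed long. A small illustration with arbitrary values:

    // Illustrative only, mirroring the end-time clamping in getWindowedKey above.
    final long currentTime = Long.MAX_VALUE - 10L;
    final long windowSize = 100L;
    long endTime = currentTime + windowSize;   // wraps around to a negative value
    if (endTime < 0) {
        endTime = Long.MAX_VALUE;              // truncated, and a warning is logged
    }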
+ */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.state.KeyValueIterator; +import org.apache.kafka.streams.state.ValueAndTimestamp; + +import static org.apache.kafka.streams.state.ValueAndTimestamp.getValueOrNull; + +public class KeyValueIteratorFacade implements KeyValueIterator { + private final KeyValueIterator> innerIterator; + + public KeyValueIteratorFacade(final KeyValueIterator> iterator) { + innerIterator = iterator; + } + + @Override + public boolean hasNext() { + return innerIterator.hasNext(); + } + + @Override + public K peekNextKey() { + return innerIterator.peekNextKey(); + } + + @Override + public KeyValue next() { + final KeyValue> innerKeyValue = innerIterator.next(); + return KeyValue.pair(innerKeyValue.key, getValueOrNull(innerKeyValue.value)); + } + + @Override + public void close() { + innerIterator.close(); + } +} \ No newline at end of file diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/KeyValueIterators.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/KeyValueIterators.java new file mode 100644 index 0000000..29c3009 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/KeyValueIterators.java @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
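KeyValueIteratorFacade above lets callers iterate a timestamped store through the plain KeyValueStore interface by dropping the timestamp from each entry. A minimal sketch of the per-element mapping, assuming illustrative key and value types:

    import org.apache.kafka.streams.KeyValue;
    import org.apache.kafka.streams.state.ValueAndTimestamp;

    // What the facade does for each element returned by the inner iterator.
    final KeyValue<String, ValueAndTimestamp<Long>> inner =
        KeyValue.pair("key", ValueAndTimestamp.make(42L, 1_000L));
    final KeyValue<String, Long> unwrapped =
        KeyValue.pair(inner.key, ValueAndTimestamp.getValueOrNull(inner.value));
    // unwrapped.value == 42L; a null wrapper would map to a null value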
+ */ + +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.state.KeyValueIterator; +import org.apache.kafka.streams.state.WindowStoreIterator; + +import java.util.NoSuchElementException; + +class KeyValueIterators { + + private static class EmptyKeyValueIterator implements KeyValueIterator { + + @Override + public void close() { + } + + @Override + public K peekNextKey() { + throw new NoSuchElementException(); + } + + @Override + public boolean hasNext() { + return false; + } + + @Override + public KeyValue next() { + throw new NoSuchElementException(); + } + + } + + private static class EmptyWindowStoreIterator extends EmptyKeyValueIterator + implements WindowStoreIterator { + } + + private static final KeyValueIterator EMPTY_ITERATOR = new EmptyKeyValueIterator(); + private static final WindowStoreIterator EMPTY_WINDOW_STORE_ITERATOR = new EmptyWindowStoreIterator(); + + + @SuppressWarnings("unchecked") + static KeyValueIterator emptyIterator() { + return (KeyValueIterator) EMPTY_ITERATOR; + } + + @SuppressWarnings("unchecked") + static WindowStoreIterator emptyWindowStoreIterator() { + return (WindowStoreIterator) EMPTY_WINDOW_STORE_ITERATOR; + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/KeyValueSegment.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/KeyValueSegment.java new file mode 100644 index 0000000..66c55fc --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/KeyValueSegment.java @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.streams.state.internals.metrics.RocksDBMetricsRecorder; + +import java.io.File; +import java.io.IOException; +import java.util.Map; +import java.util.Objects; + +class KeyValueSegment extends RocksDBStore implements Comparable, Segment { + public final long id; + + KeyValueSegment(final String segmentName, + final String windowName, + final long id, + final RocksDBMetricsRecorder metricsRecorder) { + super(segmentName, windowName, metricsRecorder); + this.id = id; + } + + @Override + public void destroy() throws IOException { + Utils.delete(dbDir); + } + + @Override + public synchronized void deleteRange(final Bytes keyFrom, final Bytes keyTo) { + super.deleteRange(keyFrom, keyTo); + } + + @Override + public int compareTo(final KeyValueSegment segment) { + return Long.compare(id, segment.id); + } + + @Override + public void openDB(final Map configs, final File stateDir) { + super.openDB(configs, stateDir); + // skip the registering step + } + + @Override + public String toString() { + return "KeyValueSegment(id=" + id + ", name=" + name() + ")"; + } + + @Override + public boolean equals(final Object obj) { + if (obj == null || getClass() != obj.getClass()) { + return false; + } + final KeyValueSegment segment = (KeyValueSegment) obj; + return id == segment.id; + } + + @Override + public int hashCode() { + return Objects.hash(id); + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/KeyValueSegments.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/KeyValueSegments.java new file mode 100644 index 0000000..a17666e --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/KeyValueSegments.java @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.streams.processor.ProcessorContext; +import org.apache.kafka.streams.processor.internals.ProcessorContextUtils; +import org.apache.kafka.streams.state.internals.metrics.RocksDBMetricsRecorder; + +/** + * Manages the {@link KeyValueSegment}s that are used by the {@link RocksDBSegmentedBytesStore} + */ +class KeyValueSegments extends AbstractSegments { + + private final RocksDBMetricsRecorder metricsRecorder; + + KeyValueSegments(final String name, + final String metricsScope, + final long retentionPeriod, + final long segmentInterval) { + super(name, retentionPeriod, segmentInterval); + metricsRecorder = new RocksDBMetricsRecorder(metricsScope, name); + } + + @Override + public KeyValueSegment getOrCreateSegment(final long segmentId, + final ProcessorContext context) { + if (segments.containsKey(segmentId)) { + return segments.get(segmentId); + } else { + final KeyValueSegment newSegment = + new KeyValueSegment(segmentName(segmentId), name, segmentId, metricsRecorder); + + if (segments.put(segmentId, newSegment) != null) { + throw new IllegalStateException("KeyValueSegment already exists. Possible concurrent access."); + } + + newSegment.openDB(context.appConfigs(), context.stateDir()); + return newSegment; + } + } + + @Override + public void openExisting(final ProcessorContext context, final long streamTime) { + metricsRecorder.init(ProcessorContextUtils.getMetricsImpl(context), context.taskId()); + super.openExisting(context, streamTime); + } +} \ No newline at end of file diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/KeyValueStoreBuilder.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/KeyValueStoreBuilder.java new file mode 100644 index 0000000..7888316 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/KeyValueStoreBuilder.java @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.streams.state.KeyValueBytesStoreSupplier; +import org.apache.kafka.streams.state.KeyValueStore; + +import java.util.Objects; + +public class KeyValueStoreBuilder extends AbstractStoreBuilder> { + + private final KeyValueBytesStoreSupplier storeSupplier; + + public KeyValueStoreBuilder(final KeyValueBytesStoreSupplier storeSupplier, + final Serde keySerde, + final Serde valueSerde, + final Time time) { + super(storeSupplier.name(), keySerde, valueSerde, time); + Objects.requireNonNull(storeSupplier, "storeSupplier can't be null"); + Objects.requireNonNull(storeSupplier.metricsScope(), "storeSupplier's metricsScope can't be null"); + this.storeSupplier = storeSupplier; + } + + @Override + public KeyValueStore build() { + return new MeteredKeyValueStore<>( + maybeWrapCaching(maybeWrapLogging(storeSupplier.get())), + storeSupplier.metricsScope(), + time, + keySerde, + valueSerde); + } + + private KeyValueStore maybeWrapCaching(final KeyValueStore inner) { + if (!enableCaching) { + return inner; + } + return new CachingKeyValueStore(inner); + } + + private KeyValueStore maybeWrapLogging(final KeyValueStore inner) { + if (!enableLogging) { + return inner; + } + return new ChangeLoggingKeyValueBytesStore(inner); + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/KeyValueToTimestampedKeyValueByteStoreAdapter.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/KeyValueToTimestampedKeyValueByteStoreAdapter.java new file mode 100644 index 0000000..d9b42c2 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/KeyValueToTimestampedKeyValueByteStoreAdapter.java @@ -0,0 +1,156 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
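KeyValueStoreBuilder above layers the optional wrappers inside out: the supplier's store is wrapped by change logging (if enabled), then caching (if enabled), and finally by the metered store that build() returns. In application code the builder is normally reached through the public Stores factory; a usage sketch assuming that standard API (the store name is illustrative):

    import org.apache.kafka.common.serialization.Serdes;
    import org.apache.kafka.streams.state.KeyValueStore;
    import org.apache.kafka.streams.state.StoreBuilder;
    import org.apache.kafka.streams.state.Stores;

    // Resulting layering: MeteredKeyValueStore -> CachingKeyValueStore -> ChangeLoggingKeyValueBytesStore -> inner store
    final StoreBuilder<KeyValueStore<String, Long>> builder =
        Stores.keyValueStoreBuilder(
                Stores.persistentKeyValueStore("example-counts"),   // illustrative store name
                Serdes.String(),
                Serdes.Long())
            .withCachingEnabled();                                   // change logging is enabled by default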
+ */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.common.serialization.Serializer; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.processor.ProcessorContext; +import org.apache.kafka.streams.processor.StateStore; +import org.apache.kafka.streams.processor.StateStoreContext; +import org.apache.kafka.streams.state.KeyValueBytesStoreSupplier; +import org.apache.kafka.streams.state.KeyValueIterator; +import org.apache.kafka.streams.state.KeyValueStore; + +import java.util.List; + +import static org.apache.kafka.streams.state.TimestampedBytesStore.convertToTimestampedFormat; +import static org.apache.kafka.streams.state.internals.ValueAndTimestampDeserializer.rawValue; + +/** + * This class is used to ensure backward compatibility at DSL level between + * {@link org.apache.kafka.streams.state.TimestampedKeyValueStore} and {@link KeyValueStore}. + *

                + * If a user provides a supplier for plain {@code KeyValueStores} via + * {@link org.apache.kafka.streams.kstream.Materialized#as(KeyValueBytesStoreSupplier)} this adapter is used to + * translate between old a new {@code byte[]} format of the value. + * + * @see KeyValueToTimestampedKeyValueIteratorAdapter + */ +public class KeyValueToTimestampedKeyValueByteStoreAdapter implements KeyValueStore { + final KeyValueStore store; + + KeyValueToTimestampedKeyValueByteStoreAdapter(final KeyValueStore store) { + if (!store.persistent()) { + throw new IllegalArgumentException("Provided store must be a persistent store, but it is not."); + } + this.store = store; + } + + @Override + public void put(final Bytes key, + final byte[] valueWithTimestamp) { + store.put(key, valueWithTimestamp == null ? null : rawValue(valueWithTimestamp)); + } + + @Override + public byte[] putIfAbsent(final Bytes key, + final byte[] valueWithTimestamp) { + return convertToTimestampedFormat(store.putIfAbsent( + key, + valueWithTimestamp == null ? null : rawValue(valueWithTimestamp))); + } + + @Override + public void putAll(final List> entries) { + for (final KeyValue entry : entries) { + final byte[] valueWithTimestamp = entry.value; + store.put(entry.key, valueWithTimestamp == null ? null : rawValue(valueWithTimestamp)); + } + } + + @Override + public byte[] delete(final Bytes key) { + return convertToTimestampedFormat(store.delete(key)); + } + + @Override + public String name() { + return store.name(); + } + + @Deprecated + @Override + public void init(final ProcessorContext context, + final StateStore root) { + store.init(context, root); + } + + @Override + public void init(final StateStoreContext context, final StateStore root) { + store.init(context, root); + } + + @Override + public void flush() { + store.flush(); + } + + @Override + public void close() { + store.close(); + } + + @Override + public boolean persistent() { + return true; + } + + @Override + public boolean isOpen() { + return store.isOpen(); + } + + @Override + public byte[] get(final Bytes key) { + return convertToTimestampedFormat(store.get(key)); + } + + @Override + public KeyValueIterator range(final Bytes from, + final Bytes to) { + return new KeyValueToTimestampedKeyValueIteratorAdapter<>(store.range(from, to)); + } + + @Override + public KeyValueIterator reverseRange(final Bytes from, + final Bytes to) { + return new KeyValueToTimestampedKeyValueIteratorAdapter<>(store.reverseRange(from, to)); + } + + @Override + public KeyValueIterator all() { + return new KeyValueToTimestampedKeyValueIteratorAdapter<>(store.all()); + } + + @Override + public KeyValueIterator reverseAll() { + return new KeyValueToTimestampedKeyValueIteratorAdapter<>(store.reverseAll()); + } + + @Override + public , P> KeyValueIterator prefixScan(final P prefix, + final PS prefixKeySerializer) { + return new KeyValueToTimestampedKeyValueIteratorAdapter<>(store.prefixScan(prefix, prefixKeySerializer)); + } + + @Override + public long approximateNumEntries() { + return store.approximateNumEntries(); + } + +} \ No newline at end of file diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/KeyValueToTimestampedKeyValueIteratorAdapter.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/KeyValueToTimestampedKeyValueIteratorAdapter.java new file mode 100644 index 0000000..7bdcb5b --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/KeyValueToTimestampedKeyValueIteratorAdapter.java @@ -0,0 +1,58 @@ 
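KeyValueToTimestampedKeyValueByteStoreAdapter above converts between the plain and timestamped value formats on every access: writes strip the timestamp prefix via rawValue, and reads re-attach one via convertToTimestampedFormat. Neither helper appears in this diff, so the sketch below only assumes the conventional layout of an 8-byte timestamp followed by the raw value bytes:

    import java.nio.ByteBuffer;

    // Assumed layout only: [ 8-byte timestamp ][ plain value bytes ]. The real helpers
    // (rawValue / convertToTimestampedFormat) are defined elsewhere in the code base.
    final class TimestampedFormatSketch {
        static byte[] toTimestamped(final byte[] plainValue, final long timestamp) {
            return ByteBuffer.allocate(8 + plainValue.length).putLong(timestamp).put(plainValue).array();
        }

        static byte[] toPlain(final byte[] valueWithTimestamp) {
            final byte[] plain = new byte[valueWithTimestamp.length - 8];
            System.arraycopy(valueWithTimestamp, 8, plain, 0, plain.length);
            return plain;
        }
    }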
+/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.state.KeyValueIterator; + +import static org.apache.kafka.streams.state.TimestampedBytesStore.convertToTimestampedFormat; + +/** + * This class is used to ensure backward compatibility at DSL level between + * {@link org.apache.kafka.streams.state.TimestampedKeyValueStore} and + * {@link org.apache.kafka.streams.state.KeyValueStore}. + * + * @see KeyValueToTimestampedKeyValueByteStoreAdapter + */ +class KeyValueToTimestampedKeyValueIteratorAdapter implements KeyValueIterator { + private final KeyValueIterator innerIterator; + + KeyValueToTimestampedKeyValueIteratorAdapter(final KeyValueIterator innerIterator) { + this.innerIterator = innerIterator; + } + + @Override + public void close() { + innerIterator.close(); + } + + @Override + public K peekNextKey() { + return innerIterator.peekNextKey(); + } + + @Override + public boolean hasNext() { + return innerIterator.hasNext(); + } + + @Override + public KeyValue next() { + final KeyValue plainKeyValue = innerIterator.next(); + return KeyValue.pair(plainKeyValue.key, convertToTimestampedFormat(plainKeyValue.value)); + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/LRUCacheEntry.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/LRUCacheEntry.java new file mode 100644 index 0000000..f4233c7 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/LRUCacheEntry.java @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.common.header.Headers; +import org.apache.kafka.common.header.internals.RecordHeaders; +import org.apache.kafka.streams.processor.internals.ProcessorRecordContext; + +import java.util.Objects; + +/** + * A cache entry + */ +class LRUCacheEntry { + private final ContextualRecord record; + private final long sizeBytes; + private boolean isDirty; + + + LRUCacheEntry(final byte[] value) { + this(value, new RecordHeaders(), false, -1, -1, -1, ""); + } + + LRUCacheEntry(final byte[] value, + final Headers headers, + final boolean isDirty, + final long offset, + final long timestamp, + final int partition, + final String topic) { + final ProcessorRecordContext context = new ProcessorRecordContext(timestamp, offset, partition, topic, headers); + + this.record = new ContextualRecord( + value, + context + ); + + this.isDirty = isDirty; + this.sizeBytes = 1 + // isDirty + record.residentMemorySizeEstimate(); + } + + void markClean() { + isDirty = false; + } + + boolean isDirty() { + return isDirty; + } + + long size() { + return sizeBytes; + } + + byte[] value() { + return record.value(); + } + + public ProcessorRecordContext context() { + return record.recordContext(); + } + + @Override + public boolean equals(final Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + final LRUCacheEntry that = (LRUCacheEntry) o; + return sizeBytes == that.sizeBytes && + isDirty() == that.isDirty() && + Objects.equals(record, that.record); + } + + @Override + public int hashCode() { + return Objects.hash(record, sizeBytes, isDirty()); + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/LeftOrRightValue.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/LeftOrRightValue.java new file mode 100644 index 0000000..bb7b516 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/LeftOrRightValue.java @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.state.internals; + +import java.util.Objects; + +/** + * This class is used in combination of {@link TimestampedKeyAndJoinSide}. The {@link TimestampedKeyAndJoinSide} class + * combines a key with a boolean value that specifies if the key is found in the left side of a + * join or on the right side. This {@link LeftOrRightValue} object contains either the V1 value, + * which is found in the left topic, or V2 value if it is found in the right topic. 
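As the Javadoc above describes, a LeftOrRightValue carries exactly one of the two join-side values. A short sketch of the factory methods defined in the class below (the values are illustrative):

    import org.apache.kafka.streams.state.internals.LeftOrRightValue;

    // Exactly one side is non-null; constructing with both (or neither) throws.
    final LeftOrRightValue<String, Long> left = LeftOrRightValue.makeLeftValue("v1-record");
    final LeftOrRightValue<String, Long> right = LeftOrRightValue.makeRightValue(7L);
    // left.getLeftValue()  -> "v1-record"   left.getRightValue()  -> null
    // right.getLeftValue() -> null          right.getRightValue() -> 7L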
+ */ +public class LeftOrRightValue { + private final V1 leftValue; + private final V2 rightValue; + + private LeftOrRightValue(final V1 leftValue, final V2 rightValue) { + if (leftValue != null && rightValue != null) { + throw new IllegalArgumentException("Only one value cannot be null"); + } else if (leftValue == null && rightValue == null) { + throw new NullPointerException("Only one value can be null"); + } + + this.leftValue = leftValue; + this.rightValue = rightValue; + } + + /** + * Create a new {@link LeftOrRightValue} instance with the V1 value as {@code leftValue} and + * V2 value as null. + * + * @param leftValue the left V1 value + * @param the type of the value + * @return a new {@link LeftOrRightValue} instance + */ + public static LeftOrRightValue makeLeftValue(final V1 leftValue) { + return new LeftOrRightValue<>(leftValue, null); + } + + /** + * Create a new {@link LeftOrRightValue} instance with the V2 value as {@code rightValue} and + * V1 value as null. + * + * @param rightValue the right V2 value + * @param the type of the value + * @return a new {@link LeftOrRightValue} instance + */ + public static LeftOrRightValue makeRightValue(final V2 rightValue) { + return new LeftOrRightValue<>(null, rightValue); + } + + /** + * Create a new {@link LeftOrRightValue} instance with the V value as {@code leftValue} if + * {@code isLeftSide} is True; otherwise {@code rightValue} if {@code isLeftSide} is False. + * + * @param value the V value (either V1 or V2 type) + * @param the type of the value + * @return a new {@link LeftOrRightValue} instance + */ + public static LeftOrRightValue make(final boolean isLeftSide, final V value) { + Objects.requireNonNull(value, "value is null"); + return isLeftSide + ? LeftOrRightValue.makeLeftValue(value) + : LeftOrRightValue.makeRightValue(value); + } + + public V1 getLeftValue() { + return leftValue; + } + + public V2 getRightValue() { + return rightValue; + } + + @Override + public String toString() { + return "<" + + ((leftValue != null) ? "left," + leftValue : "right," + rightValue) + + ">"; + } + + @Override + public boolean equals(final Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + final LeftOrRightValue that = (LeftOrRightValue) o; + return Objects.equals(leftValue, that.leftValue) && + Objects.equals(rightValue, that.rightValue); + } + + @Override + public int hashCode() { + return Objects.hash(leftValue, rightValue); + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/LeftOrRightValueDeserializer.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/LeftOrRightValueDeserializer.java new file mode 100644 index 0000000..df45bc6 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/LeftOrRightValueDeserializer.java @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.common.serialization.Deserializer; +import org.apache.kafka.streams.kstream.internals.WrappingNullableDeserializer; +import org.apache.kafka.streams.processor.internals.SerdeGetter; + +import java.util.Map; + +import static org.apache.kafka.streams.kstream.internals.WrappingNullableUtils.initNullableDeserializer; + +public class LeftOrRightValueDeserializer implements WrappingNullableDeserializer, Void, Object> { + public Deserializer leftDeserializer; + public Deserializer rightDeserializer; + + public LeftOrRightValueDeserializer(final Deserializer leftDeserializer, final Deserializer rightDeserializer) { + this.leftDeserializer = leftDeserializer; + this.rightDeserializer = rightDeserializer; + } + + @SuppressWarnings("unchecked") + @Override + public void setIfUnset(final SerdeGetter getter) { + if (leftDeserializer == null) { + leftDeserializer = (Deserializer) getter.valueSerde().deserializer(); + } + + if (rightDeserializer == null) { + rightDeserializer = (Deserializer) getter.valueSerde().deserializer(); + } + + initNullableDeserializer(leftDeserializer, getter); + initNullableDeserializer(rightDeserializer, getter); + } + + @Override + public void configure(final Map configs, + final boolean isKey) { + leftDeserializer.configure(configs, isKey); + rightDeserializer.configure(configs, isKey); + } + + @Override + public LeftOrRightValue deserialize(final String topic, final byte[] data) { + if (data == null || data.length == 0) { + return null; + } + + return (data[0] == 1) + ? LeftOrRightValue.makeLeftValue(leftDeserializer.deserialize(topic, rawValue(data))) + : LeftOrRightValue.makeRightValue(rightDeserializer.deserialize(topic, rawValue(data))); + } + + private byte[] rawValue(final byte[] data) { + final byte[] rawValue = new byte[data.length - 1]; + System.arraycopy(data, 1, rawValue, 0, rawValue.length); + return rawValue; + } +} \ No newline at end of file diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/LeftOrRightValueSerde.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/LeftOrRightValueSerde.java new file mode 100644 index 0000000..cc2d068 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/LeftOrRightValueSerde.java @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.streams.kstream.internals.WrappingNullableSerde; + +public class LeftOrRightValueSerde extends WrappingNullableSerde, Void, Object> { + public LeftOrRightValueSerde(final Serde leftValueSerde, final Serde rightValueSerde) { + super( + new LeftOrRightValueSerializer<>( + leftValueSerde != null ? leftValueSerde.serializer() : null, + rightValueSerde != null ? rightValueSerde.serializer() : null), + new LeftOrRightValueDeserializer<>( + leftValueSerde != null ? leftValueSerde.deserializer() : null, + rightValueSerde != null ? rightValueSerde.deserializer() : null) + ); + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/LeftOrRightValueSerializer.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/LeftOrRightValueSerializer.java new file mode 100644 index 0000000..8f3c47d --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/LeftOrRightValueSerializer.java @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.common.serialization.Serializer; +import org.apache.kafka.streams.kstream.internals.WrappingNullableSerializer; +import org.apache.kafka.streams.processor.internals.SerdeGetter; + +import java.nio.ByteBuffer; +import java.util.Map; + +import static org.apache.kafka.streams.kstream.internals.WrappingNullableUtils.initNullableSerializer; + +/** + * Serializes a {@link LeftOrRightValue}. The serialized bytes starts with a byte that references + * to whether the value is V1 or V2. 
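As the Javadoc above states, the serializer below writes a one-byte side marker followed by the bytes produced by the matching inner serializer. A sketch of the resulting wire layout (the payload is illustrative):

    import java.nio.ByteBuffer;
    import java.nio.charset.StandardCharsets;

    // Layout: [ 1-byte flag: 1 = left (V1), 0 = right (V2) ][ value bytes from the inner serializer ]
    final byte[] innerValue = "ab".getBytes(StandardCharsets.UTF_8);   // e.g. what a String serializer would produce
    final byte[] wireFormat = ByteBuffer.allocate(1 + innerValue.length)
        .put((byte) 1)          // left-hand (V1) value
        .put(innerValue)
        .array();               // -> { 0x01, 'a', 'b' }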
+ */ +public class LeftOrRightValueSerializer implements WrappingNullableSerializer, Void, Object> { + private Serializer leftSerializer; + private Serializer rightSerializer; + + public LeftOrRightValueSerializer(final Serializer leftSerializer, final Serializer rightSerializer) { + this.leftSerializer = leftSerializer; + this.rightSerializer = rightSerializer; + } + + @SuppressWarnings("unchecked") + @Override + public void setIfUnset(final SerdeGetter getter) { + if (leftSerializer == null) { + leftSerializer = (Serializer) getter.valueSerde().serializer(); + } + + if (rightSerializer == null) { + rightSerializer = (Serializer) getter.valueSerde().serializer(); + } + + initNullableSerializer(leftSerializer, getter); + initNullableSerializer(rightSerializer, getter); + } + + @Override + public void configure(final Map configs, final boolean isKey) { + leftSerializer.configure(configs, isKey); + rightSerializer.configure(configs, isKey); + } + + @Override + public byte[] serialize(final String topic, final LeftOrRightValue data) { + if (data == null) { + return null; + } + + final byte[] rawValue = (data.getLeftValue() != null) + ? leftSerializer.serialize(topic, data.getLeftValue()) + : rightSerializer.serialize(topic, data.getRightValue()); + + if (rawValue == null) { + return null; + } + + return ByteBuffer + .allocate(1 + rawValue.length) + .put((byte) (data.getLeftValue() != null ? 1 : 0)) + .put(rawValue) + .array(); + } + + @Override + public void close() { + leftSerializer.close(); + rightSerializer.close(); + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/ListValueStore.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/ListValueStore.java new file mode 100644 index 0000000..13e8997 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/ListValueStore.java @@ -0,0 +1,169 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.state.internals; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.NoSuchElementException; + +import org.apache.kafka.common.utils.AbstractIterator; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.state.KeyValueIterator; +import org.apache.kafka.streams.state.KeyValueStore; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.serialization.Serde; + +/** + * A wrapper key-value store that serializes the record values bytes as a list. + * As a result put calls would be interpreted as a get-append-put to the underlying RocksDB store. + * A put(k,null) will still delete the key, ie, the full list of all values of this key. 
+ * Range iterators would also flatten the value lists and return the values one-by-one. + * + * This store is used for cases where we do not want to de-duplicate values of the same keys but want to retain all such values. + */ +@SuppressWarnings("unchecked") +public class ListValueStore + extends WrappedStateStore, Bytes, byte[]> + implements KeyValueStore { + + static private final Serde> LIST_SERDE = Serdes.ListSerde(ArrayList.class, Serdes.ByteArray()); + + ListValueStore(final KeyValueStore bytesStore) { + super(bytesStore); + } + + @Override + public void put(final Bytes key, final byte[] addedValue) { + // if the value is null we can skip the get and blind delete + if (addedValue == null) { + wrapped().put(key, null); + } else { + final byte[] oldValue = wrapped().get(key); + putInternal(key, addedValue, oldValue); + } + } + + @Override + public byte[] putIfAbsent(final Bytes key, final byte[] addedValue) { + final byte[] oldValue = wrapped().get(key); + + if (oldValue != null) { + // if the value is null we can skip the get and blind delete + if (addedValue == null) { + wrapped().put(key, null); + } else { + putInternal(key, addedValue, oldValue); + } + } + + // TODO: here we always return null so that deser would not fail. + // we only do this since we know the only caller (stream-stream join processor) + // would not need the actual value at all; the changelogging wrapper would not call this function + return null; + } + + // this function assumes the addedValue is not null; callers should check null themselves + private void putInternal(final Bytes key, final byte[] addedValue, final byte[] oldValue) { + if (oldValue == null) { + wrapped().put(key, LIST_SERDE.serializer().serialize(null, Collections.singletonList(addedValue))); + } else { + final List list = LIST_SERDE.deserializer().deserialize(null, oldValue); + list.add(addedValue); + + wrapped().put(key, LIST_SERDE.serializer().serialize(null, list)); + } + } + + @Override + public void putAll(final List> entries) { + throw new UnsupportedOperationException("putAll not supported"); + } + + @Override + public byte[] delete(final Bytes key) { + // we intentionally disable delete calls since the returned bytes would + // represent a list, not a single value; we need to have a new API for delete if we do need it + throw new UnsupportedOperationException("delete not supported"); + } + + @Override + public byte[] get(final Bytes key) { + return wrapped().get(key); + } + + @Override + public KeyValueIterator range(final Bytes from, final Bytes to) { + throw new UnsupportedOperationException("range not supported"); + } + + @Override + public KeyValueIterator all() { + return new ValueListIterator(wrapped().all()); + } + + @Override + public long approximateNumEntries() { + return wrapped().approximateNumEntries(); + } + + private static class ValueListIterator extends AbstractIterator> + implements KeyValueIterator { + + private final KeyValueIterator bytesIterator; + private final List currList = new ArrayList<>(); + private KeyValue next; + private Bytes nextKey; + + ValueListIterator(final KeyValueIterator bytesIterator) { + this.bytesIterator = bytesIterator; + } + + @Override + public Bytes peekNextKey() { + if (!hasNext()) { + throw new NoSuchElementException(); + } + return next.key; + } + + @Override + public KeyValue makeNext() { + while (currList.isEmpty() && bytesIterator.hasNext()) { + final KeyValue next = bytesIterator.next(); + nextKey = next.key; + currList.addAll(LIST_SERDE.deserializer().deserialize(null, 
next.value)); + } + + if (currList.isEmpty()) { + return allDone(); + } else { + next = KeyValue.pair(nextKey, currList.remove(0)); + return next; + } + } + + @Override + public void close() { + bytesIterator.close(); + // also need to clear the current list buffer since + // otherwise even after close the iter can still return data + currList.clear(); + } + } +} \ No newline at end of file diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/ListValueStoreBuilder.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/ListValueStoreBuilder.java new file mode 100644 index 0000000..34e2e8b --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/ListValueStoreBuilder.java @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.streams.state.KeyValueBytesStoreSupplier; +import org.apache.kafka.streams.state.KeyValueStore; + +import java.util.Objects; + +public class ListValueStoreBuilder extends AbstractStoreBuilder> { + private final KeyValueBytesStoreSupplier storeSupplier; + + public ListValueStoreBuilder(final KeyValueBytesStoreSupplier storeSupplier, + final Serde keySerde, + final Serde valueSerde, + final Time time) { + super(storeSupplier.name(), keySerde, valueSerde, time); + Objects.requireNonNull(storeSupplier, "storeSupplier can't be null"); + Objects.requireNonNull(storeSupplier.metricsScope(), "storeSupplier's metricsScope can't be null"); + this.storeSupplier = storeSupplier; + } + + @Override + public KeyValueStore build() { + return new MeteredKeyValueStore<>( + maybeWrapCaching(maybeWrapLogging(new ListValueStore(storeSupplier.get()))), + storeSupplier.metricsScope(), + time, + keySerde, + valueSerde); + } + + private KeyValueStore maybeWrapCaching(final KeyValueStore inner) { + if (!enableCaching) { + return inner; + } + return new CachingKeyValueStore(inner); + } + + private KeyValueStore maybeWrapLogging(final KeyValueStore inner) { + if (!enableLogging) { + return inner; + } + return new ChangeLoggingListValueBytesStore(inner); + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/Maybe.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/Maybe.java new file mode 100644 index 0000000..8f95ecf --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/Maybe.java @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.state.internals; + +import java.util.NoSuchElementException; +import java.util.Objects; + +/** + * A container that may be empty, may contain null, or may contain a value. + * Distinct from {@link java.util.Optional}, since Optional cannot contain null. + * + * @param + */ +public final class Maybe { + private final T nullableValue; + private final boolean defined; + + public static Maybe defined(final T nullableValue) { + return new Maybe<>(nullableValue); + } + + public static Maybe undefined() { + return new Maybe<>(); + } + + private Maybe(final T nullableValue) { + this.nullableValue = nullableValue; + defined = true; + } + + private Maybe() { + nullableValue = null; + defined = false; + } + + public T getNullableValue() { + if (defined) { + return nullableValue; + } else { + throw new NoSuchElementException(); + } + } + + public boolean isDefined() { + return defined; + } + + @Override + public boolean equals(final Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + final Maybe maybe = (Maybe) o; + + // All undefined maybes are equal + // All defined null maybes are equal + return defined == maybe.defined && + (!defined || Objects.equals(nullableValue, maybe.nullableValue)); + } + + @Override + public int hashCode() { + // Since all undefined maybes are equal, we can hard-code their hashCode to -1. + // Since all defined null maybes are equal, we can hard-code their hashCode to 0. + return defined ? nullableValue == null ? 0 : nullableValue.hashCode() : -1; + } + + @Override + public String toString() { + if (defined) { + return "DefinedMaybe{" + nullableValue + "}"; + } else { + return "UndefinedMaybe{}"; + } + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/MemoryLRUCache.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/MemoryLRUCache.java new file mode 100644 index 0000000..22f1215 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/MemoryLRUCache.java @@ -0,0 +1,209 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.common.serialization.Serializer; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.processor.ProcessorContext; +import org.apache.kafka.streams.processor.StateStore; +import org.apache.kafka.streams.processor.StateStoreContext; +import org.apache.kafka.streams.state.KeyValueIterator; +import org.apache.kafka.streams.state.KeyValueStore; + +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; + +/** + * An in-memory LRU cache store based on HashSet and HashMap. + */ +public class MemoryLRUCache implements KeyValueStore { + + public interface EldestEntryRemovalListener { + void apply(Bytes key, byte[] value); + } + + private final String name; + protected final Map map; + + private boolean restoring = false; // TODO: this is a sub-optimal solution to avoid logging during restoration. + // in the future we should augment the StateRestoreCallback with onComplete etc to better resolve this. + private volatile boolean open = true; + + private EldestEntryRemovalListener listener; + + MemoryLRUCache(final String name, final int maxCacheSize) { + this.name = name; + + // leave room for one extra entry to handle adding an entry before the oldest can be removed + this.map = new LinkedHashMap(maxCacheSize + 1, 1.01f, true) { + private static final long serialVersionUID = 1L; + + @Override + protected boolean removeEldestEntry(final Map.Entry eldest) { + final boolean evict = super.size() > maxCacheSize; + if (evict && !restoring && listener != null) { + listener.apply(eldest.getKey(), eldest.getValue()); + } + return evict; + } + }; + } + + void setWhenEldestRemoved(final EldestEntryRemovalListener listener) { + this.listener = listener; + } + + @Override + public String name() { + return this.name; + } + + @Deprecated + @Override + public void init(final ProcessorContext context, final StateStore root) { + + // register the store + context.register(root, (key, value) -> { + restoring = true; + put(Bytes.wrap(key), value); + restoring = false; + }); + } + + @Override + public void init(final StateStoreContext context, final StateStore root) { + // register the store + context.register(root, (key, value) -> { + restoring = true; + put(Bytes.wrap(key), value); + restoring = false; + }); + } + + @Override + public boolean persistent() { + return false; + } + + @Override + public boolean isOpen() { + return open; + } + + @Override + public synchronized byte[] get(final Bytes key) { + Objects.requireNonNull(key); + + return this.map.get(key); + } + + @Override + public synchronized void put(final Bytes key, final byte[] value) { + Objects.requireNonNull(key); + if (value == null) { + delete(key); + } else { + this.map.put(key, value); + } + } + + @Override + public synchronized byte[] putIfAbsent(final Bytes key, final byte[] value) { + Objects.requireNonNull(key); + final byte[] originalValue = get(key); + if (originalValue == null) { + put(key, value); + } + return originalValue; + } + + @Override + public void putAll(final List> entries) { + for (final KeyValue entry : entries) { + put(entry.key, entry.value); + } + } + + @Override + public synchronized byte[] delete(final Bytes key) { + Objects.requireNonNull(key); + return this.map.remove(key); + } + + /** + * @throws 
UnsupportedOperationException at every invocation + */ + @Override + public KeyValueIterator range(final Bytes from, final Bytes to) { + throw new UnsupportedOperationException("MemoryLRUCache does not support range() function."); + } + + /** + * @throws UnsupportedOperationException at every invocation + */ + @Override + public KeyValueIterator reverseRange(final Bytes from, final Bytes to) { + throw new UnsupportedOperationException("MemoryLRUCache does not support reverseRange() function."); + } + + /** + * @throws UnsupportedOperationException at every invocation + */ + @Override + public KeyValueIterator all() { + throw new UnsupportedOperationException("MemoryLRUCache does not support all() function."); + } + + /** + * @throws UnsupportedOperationException at every invocation + */ + @Override + public KeyValueIterator reverseAll() { + throw new UnsupportedOperationException("MemoryLRUCache does not support reverseAll() function."); + } + + /** + * @throws UnsupportedOperationException at every invocation + */ + @Override + public , P> KeyValueIterator prefixScan(final P prefix, + final PS prefixKeySerializer) { + throw new UnsupportedOperationException("MemoryLRUCache does not support prefixScan() function."); + } + + @Override + public long approximateNumEntries() { + return this.map.size(); + } + + @Override + public void flush() { + // do-nothing since it is in-memory + } + + @Override + public void close() { + open = false; + } + + public int size() { + return this.map.size(); + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/MemoryNavigableLRUCache.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/MemoryNavigableLRUCache.java new file mode 100644 index 0000000..84f46ad --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/MemoryNavigableLRUCache.java @@ -0,0 +1,142 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.common.serialization.Serializer; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.state.KeyValueIterator; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Iterator; +import java.util.Map; +import java.util.Objects; +import java.util.TreeMap; + +public class MemoryNavigableLRUCache extends MemoryLRUCache { + + private static final Logger LOG = LoggerFactory.getLogger(MemoryNavigableLRUCache.class); + + public MemoryNavigableLRUCache(final String name, final int maxCacheSize) { + super(name, maxCacheSize); + } + + @Override + public KeyValueIterator range(final Bytes from, final Bytes to) { + if (Objects.nonNull(from) && Objects.nonNull(to) && from.compareTo(to) > 0) { + LOG.warn("Returning empty iterator for fetch with invalid key range: from > to. " + + "This may be due to range arguments set in the wrong order, " + + "or serdes that don't preserve ordering when lexicographically comparing the serialized bytes. " + + "Note that the built-in numerical serdes do not follow this for negative numbers"); + return KeyValueIterators.emptyIterator(); + } else { + final TreeMap treeMap = toTreeMap(); + final Iterator keys = getIterator(treeMap, from, to, true); + return new DelegatingPeekingKeyValueIterator<>(name(), + new MemoryNavigableLRUCache.CacheIterator(keys, treeMap)); + } + } + + @Override + public KeyValueIterator reverseRange(final Bytes from, final Bytes to) { + if (Objects.nonNull(from) && Objects.nonNull(to) && from.compareTo(to) > 0) { + LOG.warn("Returning empty iterator for fetch with invalid key range: from > to. " + + "This may be due to range arguments set in the wrong order, " + + "or serdes that don't preserve ordering when lexicographically comparing the serialized bytes. " + + "Note that the built-in numerical serdes do not follow this for negative numbers"); + return KeyValueIterators.emptyIterator(); + } else { + final TreeMap treeMap = toTreeMap(); + final Iterator keys = getIterator(treeMap, from, to, false); + return new DelegatingPeekingKeyValueIterator<>(name(), + new MemoryNavigableLRUCache.CacheIterator(keys, treeMap)); + } + } + + private Iterator getIterator(final TreeMap treeMap, final Bytes from, final Bytes to, final boolean forward) { + if (from == null && to == null) { + return forward ? treeMap.navigableKeySet().iterator() : treeMap.navigableKeySet().descendingIterator(); + } else if (from == null) { + return forward ? treeMap.navigableKeySet().headSet(to, true).iterator() : treeMap.navigableKeySet().headSet(to, true).descendingIterator(); + } else if (to == null) { + return forward ? treeMap.navigableKeySet().tailSet(from, true).iterator() : treeMap.navigableKeySet().tailSet(from, true).descendingIterator(); + } else { + return forward ? 
treeMap.navigableKeySet().subSet(from, true, to, true).iterator() : treeMap.navigableKeySet().subSet(from, true, to, true).descendingIterator(); + } + } + + @Override + public , P> KeyValueIterator prefixScan(final P prefix, final PS prefixKeySerializer) { + + final Bytes from = Bytes.wrap(prefixKeySerializer.serialize(null, prefix)); + final Bytes to = Bytes.increment(from); + + final TreeMap treeMap = toTreeMap(); + + return new DelegatingPeekingKeyValueIterator<>( + name(), + new MemoryNavigableLRUCache.CacheIterator(treeMap.subMap(from, true, to, false).keySet().iterator(), treeMap) + ); + } + + @Override + public KeyValueIterator all() { + return range(null, null); + } + + @Override + public KeyValueIterator reverseAll() { + return reverseRange(null, null); + } + + private synchronized TreeMap toTreeMap() { + return new TreeMap<>(this.map); + } + + + private static class CacheIterator implements KeyValueIterator { + private final Iterator keys; + private final Map entries; + + private CacheIterator(final Iterator keys, final Map entries) { + this.keys = keys; + this.entries = entries; + } + + @Override + public boolean hasNext() { + return keys.hasNext(); + } + + @Override + public KeyValue next() { + final Bytes lastKey = keys.next(); + return new KeyValue<>(lastKey, entries.get(lastKey)); + } + + @Override + public void close() { + // do nothing + } + + @Override + public Bytes peekNextKey() { + throw new UnsupportedOperationException("peekNextKey not supported"); + } + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/MergedSortedCacheKeyValueBytesStoreIterator.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/MergedSortedCacheKeyValueBytesStoreIterator.java new file mode 100644 index 0000000..701bdd1 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/MergedSortedCacheKeyValueBytesStoreIterator.java @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.state.KeyValueIterator; + +/** + * Merges two iterators. 
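+ * <p>Illustrative merge order over hypothetical keys (values elided), assuming both inputs are
+ * sorted by key and the cached entry takes precedence on duplicate keys:
+ * <pre>{@code
+ * // cache iterator yields keys:  a, c
+ * // store iterator yields keys:  b, c
+ * // merged iterator yields:      a, b, c   (for the duplicate key c, the cached entry is returned)
+ * }</pre>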
Assumes each of them is sorted by key + * + */ +class MergedSortedCacheKeyValueBytesStoreIterator + extends AbstractMergedSortedCacheStoreIterator { + + + MergedSortedCacheKeyValueBytesStoreIterator(final PeekingKeyValueIterator cacheIterator, + final KeyValueIterator storeIterator, + final boolean forward) { + super(cacheIterator, storeIterator, forward); + } + + @Override + public KeyValue deserializeStorePair(final KeyValue pair) { + return pair; + } + + @Override + Bytes deserializeCacheKey(final Bytes cacheKey) { + return cacheKey; + } + + @Override + byte[] deserializeCacheValue(final LRUCacheEntry cacheEntry) { + return cacheEntry.value(); + } + + @Override + public Bytes deserializeStoreKey(final Bytes key) { + return key; + } + + @Override + public int compare(final Bytes cacheKey, final Bytes storeKey) { + return cacheKey.compareTo(storeKey); + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/MergedSortedCacheSessionStoreIterator.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/MergedSortedCacheSessionStoreIterator.java new file mode 100644 index 0000000..cd0c0df --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/MergedSortedCacheSessionStoreIterator.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.kstream.Window; +import org.apache.kafka.streams.kstream.Windowed; +import org.apache.kafka.streams.state.KeyValueIterator; + +/** + * Merges two iterators. 
Assumes each of them is sorted by key + * + */ +class MergedSortedCacheSessionStoreIterator extends AbstractMergedSortedCacheStoreIterator, Windowed, byte[], byte[]> { + + private final SegmentedCacheFunction cacheFunction; + + MergedSortedCacheSessionStoreIterator(final PeekingKeyValueIterator cacheIterator, + final KeyValueIterator, byte[]> storeIterator, + final SegmentedCacheFunction cacheFunction, + final boolean forward) { + super(cacheIterator, storeIterator, forward); + this.cacheFunction = cacheFunction; + } + + @Override + public KeyValue, byte[]> deserializeStorePair(final KeyValue, byte[]> pair) { + return pair; + } + + @Override + Windowed deserializeCacheKey(final Bytes cacheKey) { + final byte[] binaryKey = cacheFunction.key(cacheKey).get(); + final byte[] keyBytes = SessionKeySchema.extractKeyBytes(binaryKey); + final Window window = SessionKeySchema.extractWindow(binaryKey); + return new Windowed<>(Bytes.wrap(keyBytes), window); + } + + + @Override + byte[] deserializeCacheValue(final LRUCacheEntry cacheEntry) { + return cacheEntry.value(); + } + + @Override + public Windowed deserializeStoreKey(final Windowed key) { + return key; + } + + @Override + public int compare(final Bytes cacheKey, final Windowed storeKey) { + final Bytes storeKeyBytes = SessionKeySchema.toBinary(storeKey); + return cacheFunction.compareSegmentedKeys(cacheKey, storeKeyBytes); + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/MergedSortedCacheWindowStoreIterator.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/MergedSortedCacheWindowStoreIterator.java new file mode 100644 index 0000000..46004f5 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/MergedSortedCacheWindowStoreIterator.java @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.state.KeyValueIterator; +import org.apache.kafka.streams.state.WindowStoreIterator; + +import static org.apache.kafka.streams.state.internals.SegmentedCacheFunction.bytesFromCacheKey; + +/** + * Merges two iterators. 
Assumes each of them is sorted by key + * + */ +class MergedSortedCacheWindowStoreIterator extends AbstractMergedSortedCacheStoreIterator implements WindowStoreIterator { + + + MergedSortedCacheWindowStoreIterator(final PeekingKeyValueIterator cacheIterator, + final KeyValueIterator storeIterator, + final boolean forward) { + super(cacheIterator, storeIterator, forward); + } + + @Override + public KeyValue deserializeStorePair(final KeyValue pair) { + return pair; + } + + @Override + Long deserializeCacheKey(final Bytes cacheKey) { + final byte[] binaryKey = bytesFromCacheKey(cacheKey); + return WindowKeySchema.extractStoreTimestamp(binaryKey); + } + + @Override + byte[] deserializeCacheValue(final LRUCacheEntry cacheEntry) { + return cacheEntry.value(); + } + + @Override + public Long deserializeStoreKey(final Long key) { + return key; + } + + @Override + public int compare(final Bytes cacheKey, final Long storeKey) { + final byte[] binaryKey = bytesFromCacheKey(cacheKey); + + final Long cacheTimestamp = WindowKeySchema.extractStoreTimestamp(binaryKey); + return cacheTimestamp.compareTo(storeKey); + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/MergedSortedCacheWindowStoreKeyValueIterator.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/MergedSortedCacheWindowStoreKeyValueIterator.java new file mode 100644 index 0000000..afc6a04 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/MergedSortedCacheWindowStoreKeyValueIterator.java @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.kstream.Windowed; +import org.apache.kafka.streams.state.KeyValueIterator; +import org.apache.kafka.streams.state.StateSerdes; + +class MergedSortedCacheWindowStoreKeyValueIterator + extends AbstractMergedSortedCacheStoreIterator, Windowed, byte[], byte[]> { + + private final StateSerdes serdes; + private final long windowSize; + private final SegmentedCacheFunction cacheFunction; + + MergedSortedCacheWindowStoreKeyValueIterator( + final PeekingKeyValueIterator filteredCacheIterator, + final KeyValueIterator, byte[]> underlyingIterator, + final StateSerdes serdes, + final long windowSize, + final SegmentedCacheFunction cacheFunction, + final boolean forward + ) { + super(filteredCacheIterator, underlyingIterator, forward); + this.serdes = serdes; + this.windowSize = windowSize; + this.cacheFunction = cacheFunction; + } + + @Override + Windowed deserializeStoreKey(final Windowed key) { + return key; + } + + @Override + KeyValue, byte[]> deserializeStorePair(final KeyValue, byte[]> pair) { + return pair; + } + + @Override + Windowed deserializeCacheKey(final Bytes cacheKey) { + final byte[] binaryKey = cacheFunction.key(cacheKey).get(); + return WindowKeySchema.fromStoreKey(binaryKey, windowSize, serdes.keyDeserializer(), serdes.topic()); + } + + @Override + byte[] deserializeCacheValue(final LRUCacheEntry cacheEntry) { + return cacheEntry.value(); + } + + @Override + int compare(final Bytes cacheKey, final Windowed storeKey) { + final Bytes storeKeyBytes = WindowKeySchema.toStoreKeyBinary(storeKey.key(), storeKey.window().start(), 0); + return cacheFunction.compareSegmentedKeys(cacheKey, storeKeyBytes); + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/MeteredKeyValueStore.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/MeteredKeyValueStore.java new file mode 100644 index 0000000..27a760a --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/MeteredKeyValueStore.java @@ -0,0 +1,383 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.common.metrics.Sensor; +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.common.serialization.Serializer; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.errors.ProcessorStateException; +import org.apache.kafka.streams.kstream.internals.Change; +import org.apache.kafka.streams.kstream.internals.WrappingNullableUtils; +import org.apache.kafka.streams.processor.ProcessorContext; +import org.apache.kafka.streams.processor.internals.SerdeGetter; +import org.apache.kafka.streams.processor.StateStore; +import org.apache.kafka.streams.processor.StateStoreContext; +import org.apache.kafka.streams.processor.TaskId; +import org.apache.kafka.streams.processor.api.Record; +import org.apache.kafka.streams.processor.internals.InternalProcessorContext; +import org.apache.kafka.streams.processor.internals.ProcessorContextUtils; +import org.apache.kafka.streams.processor.internals.ProcessorStateManager; +import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl; +import org.apache.kafka.streams.state.KeyValueIterator; +import org.apache.kafka.streams.state.KeyValueStore; +import org.apache.kafka.streams.state.StateSerdes; +import org.apache.kafka.streams.state.internals.metrics.StateStoreMetrics; + +import java.util.ArrayList; +import java.util.List; +import java.util.Objects; + +import static org.apache.kafka.streams.kstream.internals.WrappingNullableUtils.prepareKeySerde; +import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.maybeMeasureLatency; + +/** + * A Metered {@link KeyValueStore} wrapper that is used for recording operation metrics, and hence its + * inner KeyValueStore implementation do not need to provide its own metrics collecting functionality. + * The inner {@link KeyValueStore} of this class is of type <Bytes,byte[]>, hence we use {@link Serde}s + * to convert from <K,V> to <Bytes,byte[]> + * + * @param + * @param + */ +public class MeteredKeyValueStore + extends WrappedStateStore, K, V> + implements KeyValueStore { + + final Serde keySerde; + final Serde valueSerde; + StateSerdes serdes; + + private final String metricsScope; + protected final Time time; + protected Sensor putSensor; + private Sensor putIfAbsentSensor; + protected Sensor getSensor; + private Sensor deleteSensor; + private Sensor putAllSensor; + private Sensor allSensor; + private Sensor rangeSensor; + private Sensor prefixScanSensor; + private Sensor flushSensor; + private Sensor e2eLatencySensor; + private InternalProcessorContext context; + private StreamsMetricsImpl streamsMetrics; + private TaskId taskId; + + MeteredKeyValueStore(final KeyValueStore inner, + final String metricsScope, + final Time time, + final Serde keySerde, + final Serde valueSerde) { + super(inner); + this.metricsScope = metricsScope; + this.time = time != null ? time : Time.SYSTEM; + this.keySerde = keySerde; + this.valueSerde = valueSerde; + } + + @Deprecated + @Override + public void init(final ProcessorContext context, + final StateStore root) { + this.context = context instanceof InternalProcessorContext ? 
(InternalProcessorContext) context : null; + taskId = context.taskId(); + initStoreSerde(context); + streamsMetrics = (StreamsMetricsImpl) context.metrics(); + + registerMetrics(); + final Sensor restoreSensor = + StateStoreMetrics.restoreSensor(taskId.toString(), metricsScope, name(), streamsMetrics); + + // register and possibly restore the state from the logs + maybeMeasureLatency(() -> super.init(context, root), time, restoreSensor); + } + + @Override + public void init(final StateStoreContext context, + final StateStore root) { + this.context = context instanceof InternalProcessorContext ? (InternalProcessorContext) context : null; + taskId = context.taskId(); + initStoreSerde(context); + streamsMetrics = (StreamsMetricsImpl) context.metrics(); + + registerMetrics(); + final Sensor restoreSensor = + StateStoreMetrics.restoreSensor(taskId.toString(), metricsScope, name(), streamsMetrics); + + // register and possibly restore the state from the logs + maybeMeasureLatency(() -> super.init(context, root), time, restoreSensor); + } + + private void registerMetrics() { + putSensor = StateStoreMetrics.putSensor(taskId.toString(), metricsScope, name(), streamsMetrics); + putIfAbsentSensor = StateStoreMetrics.putIfAbsentSensor(taskId.toString(), metricsScope, name(), streamsMetrics); + putAllSensor = StateStoreMetrics.putAllSensor(taskId.toString(), metricsScope, name(), streamsMetrics); + getSensor = StateStoreMetrics.getSensor(taskId.toString(), metricsScope, name(), streamsMetrics); + allSensor = StateStoreMetrics.allSensor(taskId.toString(), metricsScope, name(), streamsMetrics); + rangeSensor = StateStoreMetrics.rangeSensor(taskId.toString(), metricsScope, name(), streamsMetrics); + prefixScanSensor = StateStoreMetrics.prefixScanSensor(taskId.toString(), metricsScope, name(), streamsMetrics); + flushSensor = StateStoreMetrics.flushSensor(taskId.toString(), metricsScope, name(), streamsMetrics); + deleteSensor = StateStoreMetrics.deleteSensor(taskId.toString(), metricsScope, name(), streamsMetrics); + e2eLatencySensor = StateStoreMetrics.e2ELatencySensor(taskId.toString(), metricsScope, name(), streamsMetrics); + } + + protected Serde prepareValueSerdeForStore(final Serde valueSerde, final SerdeGetter getter) { + return WrappingNullableUtils.prepareValueSerde(valueSerde, getter); + } + + + @Deprecated + private void initStoreSerde(final ProcessorContext context) { + final String storeName = name(); + final String changelogTopic = ProcessorContextUtils.changelogFor(context, storeName); + serdes = new StateSerdes<>( + changelogTopic != null ? + changelogTopic : + ProcessorStateManager.storeChangelogTopic(context.applicationId(), storeName, taskId.topologyName()), + prepareKeySerde(keySerde, new SerdeGetter(context)), + prepareValueSerdeForStore(valueSerde, new SerdeGetter(context)) + ); + } + + private void initStoreSerde(final StateStoreContext context) { + final String storeName = name(); + final String changelogTopic = ProcessorContextUtils.changelogFor(context, storeName); + serdes = new StateSerdes<>( + changelogTopic != null ? 
+ changelogTopic : + ProcessorStateManager.storeChangelogTopic(context.applicationId(), storeName, taskId.topologyName()), + prepareKeySerde(keySerde, new SerdeGetter(context)), + prepareValueSerdeForStore(valueSerde, new SerdeGetter(context)) + ); + } + + @SuppressWarnings("unchecked") + @Override + public boolean setFlushListener(final CacheFlushListener listener, + final boolean sendOldValues) { + final KeyValueStore wrapped = wrapped(); + if (wrapped instanceof CachedStateStore) { + return ((CachedStateStore) wrapped).setFlushListener( + new CacheFlushListener() { + @Override + public void apply(final byte[] rawKey, final byte[] rawNewValue, final byte[] rawOldValue, final long timestamp) { + listener.apply( + serdes.keyFrom(rawKey), + rawNewValue != null ? serdes.valueFrom(rawNewValue) : null, + rawOldValue != null ? serdes.valueFrom(rawOldValue) : null, + timestamp + ); + } + + @Override + public void apply(final Record> record) { + listener.apply( + record.withKey(serdes.keyFrom(record.key())) + .withValue(new Change<>( + record.value().newValue != null ? serdes.valueFrom(record.value().newValue) : null, + record.value().oldValue != null ? serdes.valueFrom(record.value().oldValue) : null + )) + ); + } + }, + sendOldValues); + } + return false; + } + + @Override + public V get(final K key) { + Objects.requireNonNull(key, "key cannot be null"); + try { + return maybeMeasureLatency(() -> outerValue(wrapped().get(keyBytes(key))), time, getSensor); + } catch (final ProcessorStateException e) { + final String message = String.format(e.getMessage(), key); + throw new ProcessorStateException(message, e); + } + } + + @Override + public void put(final K key, + final V value) { + Objects.requireNonNull(key, "key cannot be null"); + try { + maybeMeasureLatency(() -> wrapped().put(keyBytes(key), serdes.rawValue(value)), time, putSensor); + maybeRecordE2ELatency(); + } catch (final ProcessorStateException e) { + final String message = String.format(e.getMessage(), key, value); + throw new ProcessorStateException(message, e); + } + } + + @Override + public V putIfAbsent(final K key, + final V value) { + Objects.requireNonNull(key, "key cannot be null"); + final V currentValue = maybeMeasureLatency( + () -> outerValue(wrapped().putIfAbsent(keyBytes(key), serdes.rawValue(value))), + time, + putIfAbsentSensor + ); + maybeRecordE2ELatency(); + return currentValue; + } + + @Override + public void putAll(final List> entries) { + entries.forEach(entry -> Objects.requireNonNull(entry.key, "key cannot be null")); + maybeMeasureLatency(() -> wrapped().putAll(innerEntries(entries)), time, putAllSensor); + } + + @Override + public V delete(final K key) { + Objects.requireNonNull(key, "key cannot be null"); + try { + return maybeMeasureLatency(() -> outerValue(wrapped().delete(keyBytes(key))), time, deleteSensor); + } catch (final ProcessorStateException e) { + final String message = String.format(e.getMessage(), key); + throw new ProcessorStateException(message, e); + } + } + + @Override + public , P> KeyValueIterator prefixScan(final P prefix, final PS prefixKeySerializer) { + Objects.requireNonNull(prefix, "prefix cannot be null"); + Objects.requireNonNull(prefixKeySerializer, "prefixKeySerializer cannot be null"); + return new MeteredKeyValueIterator(wrapped().prefixScan(prefix, prefixKeySerializer), prefixScanSensor); + } + + @Override + public KeyValueIterator range(final K from, + final K to) { + final byte[] serFrom = from == null ? null : serdes.rawKey(from); + final byte[] serTo = to == null ? 
null : serdes.rawKey(to); + return new MeteredKeyValueIterator( + wrapped().range(Bytes.wrap(serFrom), Bytes.wrap(serTo)), + rangeSensor + ); + } + + @Override + public KeyValueIterator reverseRange(final K from, + final K to) { + final byte[] serFrom = from == null ? null : serdes.rawKey(from); + final byte[] serTo = to == null ? null : serdes.rawKey(to); + return new MeteredKeyValueIterator( + wrapped().reverseRange(Bytes.wrap(serFrom), Bytes.wrap(serTo)), + rangeSensor + ); + } + + @Override + public KeyValueIterator all() { + return new MeteredKeyValueIterator(wrapped().all(), allSensor); + } + + @Override + public KeyValueIterator reverseAll() { + return new MeteredKeyValueIterator(wrapped().reverseAll(), allSensor); + } + + @Override + public void flush() { + maybeMeasureLatency(super::flush, time, flushSensor); + } + + @Override + public long approximateNumEntries() { + return wrapped().approximateNumEntries(); + } + + @Override + public void close() { + try { + wrapped().close(); + } finally { + streamsMetrics.removeAllStoreLevelSensorsAndMetrics(taskId.toString(), name()); + } + } + + protected V outerValue(final byte[] value) { + return value != null ? serdes.valueFrom(value) : null; + } + + protected Bytes keyBytes(final K key) { + return Bytes.wrap(serdes.rawKey(key)); + } + + private List> innerEntries(final List> from) { + final List> byteEntries = new ArrayList<>(); + for (final KeyValue entry : from) { + byteEntries.add(KeyValue.pair(Bytes.wrap(serdes.rawKey(entry.key)), serdes.rawValue(entry.value))); + } + return byteEntries; + } + + private void maybeRecordE2ELatency() { + // Context is null if the provided context isn't an implementation of InternalProcessorContext. + // In that case, we _can't_ get the current timestamp, so we don't record anything. + if (e2eLatencySensor.shouldRecord() && context != null) { + final long currentTime = time.milliseconds(); + final long e2eLatency = currentTime - context.timestamp(); + e2eLatencySensor.record(e2eLatency, currentTime); + } + } + + private class MeteredKeyValueIterator implements KeyValueIterator { + + private final KeyValueIterator iter; + private final Sensor sensor; + private final long startNs; + + private MeteredKeyValueIterator(final KeyValueIterator iter, + final Sensor sensor) { + this.iter = iter; + this.sensor = sensor; + this.startNs = time.nanoseconds(); + } + + @Override + public boolean hasNext() { + return iter.hasNext(); + } + + @Override + public KeyValue next() { + final KeyValue keyValue = iter.next(); + return KeyValue.pair( + serdes.keyFrom(keyValue.key.get()), + outerValue(keyValue.value)); + } + + @Override + public void close() { + try { + iter.close(); + } finally { + sensor.record(time.nanoseconds() - startNs); + } + } + + @Override + public K peekNextKey() { + return serdes.keyFrom(iter.peekNextKey().get()); + } + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/MeteredSessionStore.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/MeteredSessionStore.java new file mode 100644 index 0000000..041f391 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/MeteredSessionStore.java @@ -0,0 +1,391 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.common.metrics.Sensor; +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.streams.errors.ProcessorStateException; +import org.apache.kafka.streams.kstream.Windowed; +import org.apache.kafka.streams.kstream.internals.Change; +import org.apache.kafka.streams.kstream.internals.WrappingNullableUtils; +import org.apache.kafka.streams.processor.ProcessorContext; +import org.apache.kafka.streams.processor.internals.SerdeGetter; +import org.apache.kafka.streams.processor.StateStore; +import org.apache.kafka.streams.processor.StateStoreContext; +import org.apache.kafka.streams.processor.TaskId; +import org.apache.kafka.streams.processor.api.Record; +import org.apache.kafka.streams.processor.internals.InternalProcessorContext; +import org.apache.kafka.streams.processor.internals.ProcessorContextUtils; +import org.apache.kafka.streams.processor.internals.ProcessorStateManager; +import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl; +import org.apache.kafka.streams.state.KeyValueIterator; +import org.apache.kafka.streams.state.SessionStore; +import org.apache.kafka.streams.state.StateSerdes; +import org.apache.kafka.streams.state.internals.metrics.StateStoreMetrics; + +import java.util.Objects; + +import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.maybeMeasureLatency; + +public class MeteredSessionStore + extends WrappedStateStore, Windowed, V> + implements SessionStore { + + private final String metricsScope; + private final Serde keySerde; + private final Serde valueSerde; + private final Time time; + private StateSerdes serdes; + private StreamsMetricsImpl streamsMetrics; + private Sensor putSensor; + private Sensor fetchSensor; + private Sensor flushSensor; + private Sensor removeSensor; + private Sensor e2eLatencySensor; + private InternalProcessorContext context; + private TaskId taskId; + + + MeteredSessionStore(final SessionStore inner, + final String metricsScope, + final Serde keySerde, + final Serde valueSerde, + final Time time) { + super(inner); + this.metricsScope = metricsScope; + this.keySerde = keySerde; + this.valueSerde = valueSerde; + this.time = time; + } + + @Deprecated + @Override + public void init(final ProcessorContext context, + final StateStore root) { + this.context = context instanceof InternalProcessorContext ? 
(InternalProcessorContext) context : null; + taskId = context.taskId(); + initStoreSerde(context); + streamsMetrics = (StreamsMetricsImpl) context.metrics(); + + registerMetrics(); + final Sensor restoreSensor = + StateStoreMetrics.restoreSensor(taskId.toString(), metricsScope, name(), streamsMetrics); + + // register and possibly restore the state from the logs + maybeMeasureLatency(() -> super.init(context, root), time, restoreSensor); + } + + @Override + public void init(final StateStoreContext context, + final StateStore root) { + this.context = context instanceof InternalProcessorContext ? (InternalProcessorContext) context : null; + taskId = context.taskId(); + initStoreSerde(context); + streamsMetrics = (StreamsMetricsImpl) context.metrics(); + + registerMetrics(); + final Sensor restoreSensor = + StateStoreMetrics.restoreSensor(taskId.toString(), metricsScope, name(), streamsMetrics); + + // register and possibly restore the state from the logs + maybeMeasureLatency(() -> super.init(context, root), time, restoreSensor); + } + + private void registerMetrics() { + putSensor = StateStoreMetrics.putSensor(taskId.toString(), metricsScope, name(), streamsMetrics); + fetchSensor = StateStoreMetrics.fetchSensor(taskId.toString(), metricsScope, name(), streamsMetrics); + flushSensor = StateStoreMetrics.flushSensor(taskId.toString(), metricsScope, name(), streamsMetrics); + removeSensor = StateStoreMetrics.removeSensor(taskId.toString(), metricsScope, name(), streamsMetrics); + e2eLatencySensor = StateStoreMetrics.e2ELatencySensor(taskId.toString(), metricsScope, name(), streamsMetrics); + } + + + private void initStoreSerde(final ProcessorContext context) { + final String storeName = name(); + final String changelogTopic = ProcessorContextUtils.changelogFor(context, storeName); + serdes = new StateSerdes<>( + changelogTopic != null ? + changelogTopic : + ProcessorStateManager.storeChangelogTopic(context.applicationId(), storeName, taskId.topologyName()), + WrappingNullableUtils.prepareKeySerde(keySerde, new SerdeGetter(context)), + WrappingNullableUtils.prepareValueSerde(valueSerde, new SerdeGetter(context)) + ); + } + + private void initStoreSerde(final StateStoreContext context) { + final String storeName = name(); + final String changelogTopic = ProcessorContextUtils.changelogFor(context, storeName); + serdes = new StateSerdes<>( + changelogTopic != null ? + changelogTopic : + ProcessorStateManager.storeChangelogTopic(context.applicationId(), storeName, taskId.topologyName()), + WrappingNullableUtils.prepareKeySerde(keySerde, new SerdeGetter(context)), + WrappingNullableUtils.prepareValueSerde(valueSerde, new SerdeGetter(context)) + ); + } + + @SuppressWarnings("unchecked") + @Override + public boolean setFlushListener(final CacheFlushListener, V> listener, + final boolean sendOldValues) { + final SessionStore wrapped = wrapped(); + if (wrapped instanceof CachedStateStore) { + return ((CachedStateStore) wrapped).setFlushListener( + new CacheFlushListener() { + @Override + public void apply(final byte[] key, final byte[] newValue, final byte[] oldValue, final long timestamp) { + listener.apply( + SessionKeySchema.from(key, serdes.keyDeserializer(), serdes.topic()), + newValue != null ? serdes.valueFrom(newValue) : null, + oldValue != null ? 
serdes.valueFrom(oldValue) : null, + timestamp + ); + } + + @Override + public void apply(final Record> record) { + listener.apply( + record.withKey(SessionKeySchema.from(record.key(), serdes.keyDeserializer(), serdes.topic())) + .withValue(new Change<>( + record.value().newValue != null ? serdes.valueFrom(record.value().newValue) : null, + record.value().oldValue != null ? serdes.valueFrom(record.value().oldValue) : null + )) + ); + } + }, + sendOldValues); + } + return false; + } + + @Override + public void put(final Windowed sessionKey, + final V aggregate) { + Objects.requireNonNull(sessionKey, "sessionKey can't be null"); + Objects.requireNonNull(sessionKey.key(), "sessionKey.key() can't be null"); + Objects.requireNonNull(sessionKey.window(), "sessionKey.window() can't be null"); + + try { + maybeMeasureLatency( + () -> { + final Bytes key = keyBytes(sessionKey.key()); + wrapped().put(new Windowed<>(key, sessionKey.window()), serdes.rawValue(aggregate)); + }, + time, + putSensor + ); + maybeRecordE2ELatency(); + } catch (final ProcessorStateException e) { + final String message = String.format(e.getMessage(), sessionKey.key(), aggregate); + throw new ProcessorStateException(message, e); + } + } + + @Override + public void remove(final Windowed sessionKey) { + Objects.requireNonNull(sessionKey, "sessionKey can't be null"); + Objects.requireNonNull(sessionKey.key(), "sessionKey.key() can't be null"); + Objects.requireNonNull(sessionKey.window(), "sessionKey.window() can't be null"); + + try { + maybeMeasureLatency( + () -> { + final Bytes key = keyBytes(sessionKey.key()); + wrapped().remove(new Windowed<>(key, sessionKey.window())); + }, + time, + removeSensor + ); + } catch (final ProcessorStateException e) { + final String message = String.format(e.getMessage(), sessionKey.key()); + throw new ProcessorStateException(message, e); + } + } + + @Override + public V fetchSession(final K key, final long earliestSessionEndTime, final long latestSessionStartTime) { + Objects.requireNonNull(key, "key cannot be null"); + return maybeMeasureLatency( + () -> { + final Bytes bytesKey = keyBytes(key); + final byte[] result = wrapped().fetchSession( + bytesKey, + earliestSessionEndTime, + latestSessionStartTime + ); + if (result == null) { + return null; + } + return serdes.valueFrom(result); + }, + time, + fetchSensor + ); + } + + @Override + public KeyValueIterator, V> fetch(final K key) { + Objects.requireNonNull(key, "key cannot be null"); + return new MeteredWindowedKeyValueIterator<>( + wrapped().fetch(keyBytes(key)), + fetchSensor, + streamsMetrics, + serdes, + time); + } + + @Override + public KeyValueIterator, V> backwardFetch(final K key) { + Objects.requireNonNull(key, "key cannot be null"); + return new MeteredWindowedKeyValueIterator<>( + wrapped().backwardFetch(keyBytes(key)), + fetchSensor, + streamsMetrics, + serdes, + time + ); + } + + @Override + public KeyValueIterator, V> fetch(final K keyFrom, + final K keyTo) { + return new MeteredWindowedKeyValueIterator<>( + wrapped().fetch(keyBytes(keyFrom), keyBytes(keyTo)), + fetchSensor, + streamsMetrics, + serdes, + time); + } + + @Override + public KeyValueIterator, V> backwardFetch(final K keyFrom, + final K keyTo) { + return new MeteredWindowedKeyValueIterator<>( + wrapped().backwardFetch(keyBytes(keyFrom), keyBytes(keyTo)), + fetchSensor, + streamsMetrics, + serdes, + time + ); + } + + @Override + public KeyValueIterator, V> findSessions(final K key, + final long earliestSessionEndTime, + final long latestSessionStartTime) { + 
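+        // Delegation sketch: the key is serialized here, the wrapped bytes store performs the
+        // session lookup, and the MeteredWindowedKeyValueIterator deserializes entries lazily
+        // while recording the operation on fetchSensor.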
Objects.requireNonNull(key, "key cannot be null"); + final Bytes bytesKey = keyBytes(key); + return new MeteredWindowedKeyValueIterator<>( + wrapped().findSessions( + bytesKey, + earliestSessionEndTime, + latestSessionStartTime), + fetchSensor, + streamsMetrics, + serdes, + time); + } + + @Override + public KeyValueIterator, V> backwardFindSessions(final K key, + final long earliestSessionEndTime, + final long latestSessionStartTime) { + Objects.requireNonNull(key, "key cannot be null"); + final Bytes bytesKey = keyBytes(key); + return new MeteredWindowedKeyValueIterator<>( + wrapped().backwardFindSessions( + bytesKey, + earliestSessionEndTime, + latestSessionStartTime + ), + fetchSensor, + streamsMetrics, + serdes, + time + ); + } + + @Override + public KeyValueIterator, V> findSessions(final K keyFrom, + final K keyTo, + final long earliestSessionEndTime, + final long latestSessionStartTime) { + final Bytes bytesKeyFrom = keyBytes(keyFrom); + final Bytes bytesKeyTo = keyBytes(keyTo); + return new MeteredWindowedKeyValueIterator<>( + wrapped().findSessions( + bytesKeyFrom, + bytesKeyTo, + earliestSessionEndTime, + latestSessionStartTime), + fetchSensor, + streamsMetrics, + serdes, + time); + } + + @Override + public KeyValueIterator, V> backwardFindSessions(final K keyFrom, + final K keyTo, + final long earliestSessionEndTime, + final long latestSessionStartTime) { + final Bytes bytesKeyFrom = keyBytes(keyFrom); + final Bytes bytesKeyTo = keyBytes(keyTo); + return new MeteredWindowedKeyValueIterator<>( + wrapped().backwardFindSessions( + bytesKeyFrom, + bytesKeyTo, + earliestSessionEndTime, + latestSessionStartTime + ), + fetchSensor, + streamsMetrics, + serdes, + time + ); + } + + @Override + public void flush() { + maybeMeasureLatency(super::flush, time, flushSensor); + } + + @Override + public void close() { + try { + wrapped().close(); + } finally { + streamsMetrics.removeAllStoreLevelSensorsAndMetrics(taskId.toString(), name()); + } + } + + private Bytes keyBytes(final K key) { + return key == null ? null : Bytes.wrap(serdes.rawKey(key)); + } + + private void maybeRecordE2ELatency() { + // Context is null if the provided context isn't an implementation of InternalProcessorContext. + // In that case, we _can't_ get the current timestamp, so we don't record anything. + if (e2eLatencySensor.shouldRecord() && context != null) { + final long currentTime = time.milliseconds(); + final long e2eLatency = currentTime - context.timestamp(); + e2eLatencySensor.record(e2eLatency, currentTime); + } + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/MeteredTimestampedKeyValueStore.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/MeteredTimestampedKeyValueStore.java new file mode 100644 index 0000000..3068cf1 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/MeteredTimestampedKeyValueStore.java @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.streams.errors.ProcessorStateException; +import org.apache.kafka.streams.processor.internals.SerdeGetter; +import org.apache.kafka.streams.state.KeyValueStore; +import org.apache.kafka.streams.state.TimestampedKeyValueStore; +import org.apache.kafka.streams.state.ValueAndTimestamp; + +import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.maybeMeasureLatency; + +/** + * A Metered {@link TimestampedKeyValueStore} wrapper that is used for recording operation metrics, and hence its + * inner KeyValueStore implementation do not need to provide its own metrics collecting functionality. + * The inner {@link KeyValueStore} of this class is of type <Bytes,byte[]>, hence we use {@link Serde}s + * to convert from <K,ValueAndTimestamp<V>> to <Bytes,byte[]> + * @param + * @param + */ +public class MeteredTimestampedKeyValueStore + extends MeteredKeyValueStore> + implements TimestampedKeyValueStore { + + MeteredTimestampedKeyValueStore(final KeyValueStore inner, + final String metricScope, + final Time time, + final Serde keySerde, + final Serde> valueSerde) { + super(inner, metricScope, time, keySerde, valueSerde); + } + + + @SuppressWarnings("unchecked") + @Override + protected Serde> prepareValueSerdeForStore(final Serde> valueSerde, final SerdeGetter getter) { + if (valueSerde == null) { + return new ValueAndTimestampSerde<>((Serde) getter.valueSerde()); + } else { + return super.prepareValueSerdeForStore(valueSerde, getter); + } + } + + + public RawAndDeserializedValue getWithBinary(final K key) { + try { + return maybeMeasureLatency(() -> { + final byte[] serializedValue = wrapped().get(keyBytes(key)); + return new RawAndDeserializedValue<>(serializedValue, outerValue(serializedValue)); + }, time, getSensor); + } catch (final ProcessorStateException e) { + final String message = String.format(e.getMessage(), key); + throw new ProcessorStateException(message, e); + } + } + + public boolean putIfDifferentValues(final K key, + final ValueAndTimestamp newValue, + final byte[] oldSerializedValue) { + try { + return maybeMeasureLatency( + () -> { + final byte[] newSerializedValue = serdes.rawValue(newValue); + if (ValueAndTimestampSerializer.valuesAreSameAndTimeIsIncreasing(oldSerializedValue, newSerializedValue)) { + return false; + } else { + wrapped().put(keyBytes(key), newSerializedValue); + return true; + } + }, + time, + putSensor + ); + } catch (final ProcessorStateException e) { + final String message = String.format(e.getMessage(), key, newValue); + throw new ProcessorStateException(message, e); + } + } + + static class RawAndDeserializedValue { + final byte[] serializedValue; + final ValueAndTimestamp value; + RawAndDeserializedValue(final byte[] serializedValue, final ValueAndTimestamp value) { + this.serializedValue = serializedValue; + this.value = value; + } + } +} diff --git 
a/streams/src/main/java/org/apache/kafka/streams/state/internals/MeteredTimestampedWindowStore.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/MeteredTimestampedWindowStore.java new file mode 100644 index 0000000..dd489f8 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/MeteredTimestampedWindowStore.java @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.streams.processor.internals.SerdeGetter; +import org.apache.kafka.streams.state.TimestampedWindowStore; +import org.apache.kafka.streams.state.ValueAndTimestamp; +import org.apache.kafka.streams.state.WindowStore; + +/** + * A Metered {@link TimestampedWindowStore} wrapper that is used for recording operation metrics, and hence its + * inner WindowStore implementation do not need to provide its own metrics collecting functionality. + * The inner {@link WindowStore} of this class is of type <Bytes,byte[]>, hence we use {@link Serde}s + * to convert from <K,ValueAndTimestamp<V>> to <Bytes,byte[]> + * + * @param + * @param + */ +class MeteredTimestampedWindowStore + extends MeteredWindowStore> + implements TimestampedWindowStore { + + MeteredTimestampedWindowStore(final WindowStore inner, + final long windowSizeMs, + final String metricScope, + final Time time, + final Serde keySerde, + final Serde> valueSerde) { + super(inner, windowSizeMs, metricScope, time, keySerde, valueSerde); + } + + @SuppressWarnings("unchecked") + @Override + protected Serde> prepareValueSerde(final Serde> valueSerde, final SerdeGetter getter) { + if (valueSerde == null) { + return new ValueAndTimestampSerde<>((Serde) getter.valueSerde()); + } else { + return super.prepareValueSerde(valueSerde, getter); + } + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/MeteredWindowStore.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/MeteredWindowStore.java new file mode 100644 index 0000000..1a45b55 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/MeteredWindowStore.java @@ -0,0 +1,336 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
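The MeteredTimestampedKeyValueStore above adds getWithBinary() and putIfDifferentValues() so a caller can hold on to the raw bytes of the previous value and skip writes that would not change anything. A hedged sketch of that call pattern follows; the helper class and its upsert method are hypothetical, and it is placed in the internals package only so the package-private RawAndDeserializedValue is visible.

package org.apache.kafka.streams.state.internals;

import org.apache.kafka.streams.state.ValueAndTimestamp;

final class TimestampedUpsertSketch {

    // Writes newValue (assumed non-null) only if its serialized form differs from what is already
    // stored or the timestamp moves forward, avoiding a redundant store and changelog update.
    static <K, V> boolean upsert(final MeteredTimestampedKeyValueStore<K, V> store,
                                 final K key,
                                 final V newValue,
                                 final long timestamp) {
        final MeteredTimestampedKeyValueStore.RawAndDeserializedValue<V> existing = store.getWithBinary(key);
        return store.putIfDifferentValues(key, ValueAndTimestamp.make(newValue, timestamp), existing.serializedValue);
    }
}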
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.common.metrics.Sensor; +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.streams.errors.ProcessorStateException; +import org.apache.kafka.streams.kstream.Windowed; +import org.apache.kafka.streams.kstream.internals.Change; +import org.apache.kafka.streams.kstream.internals.WrappingNullableUtils; +import org.apache.kafka.streams.processor.ProcessorContext; +import org.apache.kafka.streams.processor.internals.SerdeGetter; +import org.apache.kafka.streams.processor.StateStore; +import org.apache.kafka.streams.processor.StateStoreContext; +import org.apache.kafka.streams.processor.TaskId; +import org.apache.kafka.streams.processor.api.Record; +import org.apache.kafka.streams.processor.internals.InternalProcessorContext; +import org.apache.kafka.streams.processor.internals.ProcessorContextUtils; +import org.apache.kafka.streams.processor.internals.ProcessorStateManager; +import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl; +import org.apache.kafka.streams.state.KeyValueIterator; +import org.apache.kafka.streams.state.StateSerdes; +import org.apache.kafka.streams.state.WindowStore; +import org.apache.kafka.streams.state.WindowStoreIterator; +import org.apache.kafka.streams.state.internals.metrics.StateStoreMetrics; + +import java.util.Objects; + +import static org.apache.kafka.streams.kstream.internals.WrappingNullableUtils.prepareKeySerde; +import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.maybeMeasureLatency; + +public class MeteredWindowStore + extends WrappedStateStore, Windowed, V> + implements WindowStore { + + private final long windowSizeMs; + private final String metricsScope; + private final Time time; + private final Serde keySerde; + private final Serde valueSerde; + private StateSerdes serdes; + private StreamsMetricsImpl streamsMetrics; + private Sensor putSensor; + private Sensor fetchSensor; + private Sensor flushSensor; + private Sensor e2eLatencySensor; + private InternalProcessorContext context; + private TaskId taskId; + + MeteredWindowStore(final WindowStore inner, + final long windowSizeMs, + final String metricsScope, + final Time time, + final Serde keySerde, + final Serde valueSerde) { + super(inner); + this.windowSizeMs = windowSizeMs; + this.metricsScope = metricsScope; + this.time = time; + this.keySerde = keySerde; + this.valueSerde = valueSerde; + } + + @Deprecated + @Override + public void init(final ProcessorContext context, + final StateStore root) { + this.context = context instanceof InternalProcessorContext ? 
(InternalProcessorContext) context : null; + taskId = context.taskId(); + initStoreSerde(context); + streamsMetrics = (StreamsMetricsImpl) context.metrics(); + + registerMetrics(); + final Sensor restoreSensor = + StateStoreMetrics.restoreSensor(taskId.toString(), metricsScope, name(), streamsMetrics); + + // register and possibly restore the state from the logs + maybeMeasureLatency(() -> super.init(context, root), time, restoreSensor); + } + + @Override + public void init(final StateStoreContext context, + final StateStore root) { + this.context = context instanceof InternalProcessorContext ? (InternalProcessorContext) context : null; + taskId = context.taskId(); + initStoreSerde(context); + streamsMetrics = (StreamsMetricsImpl) context.metrics(); + + registerMetrics(); + final Sensor restoreSensor = + StateStoreMetrics.restoreSensor(taskId.toString(), metricsScope, name(), streamsMetrics); + + // register and possibly restore the state from the logs + maybeMeasureLatency(() -> super.init(context, root), time, restoreSensor); + } + protected Serde prepareValueSerde(final Serde valueSerde, final SerdeGetter getter) { + return WrappingNullableUtils.prepareValueSerde(valueSerde, getter); + } + + private void registerMetrics() { + putSensor = StateStoreMetrics.putSensor(taskId.toString(), metricsScope, name(), streamsMetrics); + fetchSensor = StateStoreMetrics.fetchSensor(taskId.toString(), metricsScope, name(), streamsMetrics); + flushSensor = StateStoreMetrics.flushSensor(taskId.toString(), metricsScope, name(), streamsMetrics); + e2eLatencySensor = StateStoreMetrics.e2ELatencySensor(taskId.toString(), metricsScope, name(), streamsMetrics); + } + + @Deprecated + private void initStoreSerde(final ProcessorContext context) { + final String storeName = name(); + final String changelogTopic = ProcessorContextUtils.changelogFor(context, storeName); + serdes = new StateSerdes<>( + changelogTopic != null ? + changelogTopic : + ProcessorStateManager.storeChangelogTopic(context.applicationId(), storeName, taskId.topologyName()), + prepareKeySerde(keySerde, new SerdeGetter(context)), + prepareValueSerde(valueSerde, new SerdeGetter(context))); + } + + private void initStoreSerde(final StateStoreContext context) { + final String storeName = name(); + final String changelogTopic = ProcessorContextUtils.changelogFor(context, storeName); + serdes = new StateSerdes<>( + changelogTopic != null ? + changelogTopic : + ProcessorStateManager.storeChangelogTopic(context.applicationId(), storeName, taskId.topologyName()), + prepareKeySerde(keySerde, new SerdeGetter(context)), + prepareValueSerde(valueSerde, new SerdeGetter(context))); + } + + @SuppressWarnings("unchecked") + @Override + public boolean setFlushListener(final CacheFlushListener, V> listener, + final boolean sendOldValues) { + final WindowStore wrapped = wrapped(); + if (wrapped instanceof CachedStateStore) { + return ((CachedStateStore) wrapped).setFlushListener( + new CacheFlushListener() { + @Override + public void apply(final byte[] key, final byte[] newValue, final byte[] oldValue, final long timestamp) { + listener.apply( + WindowKeySchema.fromStoreKey(key, windowSizeMs, serdes.keyDeserializer(), serdes.topic()), + newValue != null ? serdes.valueFrom(newValue) : null, + oldValue != null ? 
serdes.valueFrom(oldValue) : null, + timestamp + ); + } + + @Override + public void apply(final Record> record) { + listener.apply( + record.withKey(WindowKeySchema.fromStoreKey(record.key(), windowSizeMs, serdes.keyDeserializer(), serdes.topic())) + .withValue(new Change<>( + record.value().newValue != null ? serdes.valueFrom(record.value().newValue) : null, + record.value().oldValue != null ? serdes.valueFrom(record.value().oldValue) : null + )) + ); + } + }, + sendOldValues); + } + return false; + } + + @Override + public void put(final K key, + final V value, + final long windowStartTimestamp) { + Objects.requireNonNull(key, "key cannot be null"); + try { + maybeMeasureLatency( + () -> wrapped().put(keyBytes(key), serdes.rawValue(value), windowStartTimestamp), + time, + putSensor + ); + maybeRecordE2ELatency(); + } catch (final ProcessorStateException e) { + final String message = String.format(e.getMessage(), key, value); + throw new ProcessorStateException(message, e); + } + } + + @Override + public V fetch(final K key, + final long timestamp) { + Objects.requireNonNull(key, "key cannot be null"); + return maybeMeasureLatency( + () -> { + final byte[] result = wrapped().fetch(keyBytes(key), timestamp); + if (result == null) { + return null; + } + return serdes.valueFrom(result); + }, + time, + fetchSensor + ); + } + + @Override + public WindowStoreIterator fetch(final K key, + final long timeFrom, + final long timeTo) { + Objects.requireNonNull(key, "key cannot be null"); + return new MeteredWindowStoreIterator<>( + wrapped().fetch(keyBytes(key), timeFrom, timeTo), + fetchSensor, + streamsMetrics, + serdes, + time + ); + } + + @Override + public WindowStoreIterator backwardFetch(final K key, + final long timeFrom, + final long timeTo) { + Objects.requireNonNull(key, "key cannot be null"); + return new MeteredWindowStoreIterator<>( + wrapped().backwardFetch(keyBytes(key), timeFrom, timeTo), + fetchSensor, + streamsMetrics, + serdes, + time + ); + } + + @Override + public KeyValueIterator, V> fetch(final K keyFrom, + final K keyTo, + final long timeFrom, + final long timeTo) { + return new MeteredWindowedKeyValueIterator<>( + wrapped().fetch( + keyBytes(keyFrom), + keyBytes(keyTo), + timeFrom, + timeTo), + fetchSensor, + streamsMetrics, + serdes, + time); + } + + @Override + public KeyValueIterator, V> backwardFetch(final K keyFrom, + final K keyTo, + final long timeFrom, + final long timeTo) { + return new MeteredWindowedKeyValueIterator<>( + wrapped().backwardFetch( + keyBytes(keyFrom), + keyBytes(keyTo), + timeFrom, + timeTo), + fetchSensor, + streamsMetrics, + serdes, + time); + } + + @Override + public KeyValueIterator, V> fetchAll(final long timeFrom, + final long timeTo) { + return new MeteredWindowedKeyValueIterator<>( + wrapped().fetchAll(timeFrom, timeTo), + fetchSensor, + streamsMetrics, + serdes, + time); + } + + @Override + public KeyValueIterator, V> backwardFetchAll(final long timeFrom, + final long timeTo) { + return new MeteredWindowedKeyValueIterator<>( + wrapped().backwardFetchAll(timeFrom, timeTo), + fetchSensor, + streamsMetrics, + serdes, + time); + } + + @Override + public KeyValueIterator, V> all() { + return new MeteredWindowedKeyValueIterator<>(wrapped().all(), fetchSensor, streamsMetrics, serdes, time); + } + + @Override + public KeyValueIterator, V> backwardAll() { + return new MeteredWindowedKeyValueIterator<>(wrapped().backwardAll(), fetchSensor, streamsMetrics, serdes, time); + } + + @Override + public void flush() { + maybeMeasureLatency(super::flush, 
time, flushSensor); + } + + @Override + public void close() { + try { + wrapped().close(); + } finally { + streamsMetrics.removeAllStoreLevelSensorsAndMetrics(taskId.toString(), name()); + } + } + + private Bytes keyBytes(final K key) { + return Bytes.wrap(serdes.rawKey(key)); + } + + private void maybeRecordE2ELatency() { + // Context is null if the provided context isn't an implementation of InternalProcessorContext. + // In that case, we _can't_ get the current timestamp, so we don't record anything. + if (e2eLatencySensor.shouldRecord() && context != null) { + final long currentTime = time.milliseconds(); + final long e2eLatency = currentTime - context.timestamp(); + e2eLatencySensor.record(e2eLatency, currentTime); + } + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/MeteredWindowStoreIterator.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/MeteredWindowStoreIterator.java new file mode 100644 index 0000000..98bc655 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/MeteredWindowStoreIterator.java @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
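Nearly every operation in MeteredWindowStore above runs through StreamsMetricsImpl.maybeMeasureLatency(...). The helper below is a simplified sketch of that shape (time the action only when the sensor is actually recording); it is an illustration, not the project's implementation, and the real helper may apply additional checks.

import java.util.function.Supplier;

import org.apache.kafka.common.metrics.Sensor;
import org.apache.kafka.common.utils.Time;

final class LatencyMeasurementSketch {

    // Runs the action and, only when the sensor is recording, reports the elapsed nanoseconds.
    static <T> T timed(final Supplier<T> action, final Time time, final Sensor sensor) {
        if (!sensor.shouldRecord()) {
            return action.get();
        }
        final long startNs = time.nanoseconds();
        try {
            return action.get();
        } finally {
            sensor.record(time.nanoseconds() - startNs);
        }
    }
}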
+ */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.common.metrics.Sensor; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.StreamsMetrics; +import org.apache.kafka.streams.state.StateSerdes; +import org.apache.kafka.streams.state.WindowStoreIterator; + +class MeteredWindowStoreIterator implements WindowStoreIterator { + + private final WindowStoreIterator iter; + private final Sensor sensor; + private final StreamsMetrics metrics; + private final StateSerdes serdes; + private final long startNs; + private final Time time; + + MeteredWindowStoreIterator(final WindowStoreIterator iter, + final Sensor sensor, + final StreamsMetrics metrics, + final StateSerdes serdes, + final Time time) { + this.iter = iter; + this.sensor = sensor; + this.metrics = metrics; + this.serdes = serdes; + this.startNs = time.nanoseconds(); + this.time = time; + } + + @Override + public boolean hasNext() { + return iter.hasNext(); + } + + @Override + public KeyValue next() { + final KeyValue next = iter.next(); + return KeyValue.pair(next.key, serdes.valueFrom(next.value)); + } + + @Override + public void close() { + try { + iter.close(); + } finally { + sensor.record(time.nanoseconds() - startNs); + } + } + + @Override + public Long peekNextKey() { + return iter.peekNextKey(); + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/MeteredWindowedKeyValueIterator.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/MeteredWindowedKeyValueIterator.java new file mode 100644 index 0000000..411871d --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/MeteredWindowedKeyValueIterator.java @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
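MeteredWindowStoreIterator above starts its clock in the constructor and records the fetch sensor only in close(), so the measured latency covers the whole lifetime of the iterator. A caller-side sketch under assumed String/Long types:

import org.apache.kafka.streams.KeyValue;
import org.apache.kafka.streams.state.WindowStore;
import org.apache.kafka.streams.state.WindowStoreIterator;

final class WindowFetchSketch {

    // Sums the values of all windows of the key within [timeFrom, timeTo]; each entry's key
    // is the window start timestamp.
    static long sum(final WindowStore<String, Long> store,
                    final String key,
                    final long timeFrom,
                    final long timeTo) {
        long total = 0L;
        // try-with-resources guarantees close(), which is when the fetch latency is recorded
        try (final WindowStoreIterator<Long> iterator = store.fetch(key, timeFrom, timeTo)) {
            while (iterator.hasNext()) {
                final KeyValue<Long, Long> entry = iterator.next();
                total += entry.value;
            }
        }
        return total;
    }
}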
+ */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.common.metrics.Sensor; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.StreamsMetrics; +import org.apache.kafka.streams.kstream.Windowed; +import org.apache.kafka.streams.state.KeyValueIterator; +import org.apache.kafka.streams.state.StateSerdes; + +class MeteredWindowedKeyValueIterator implements KeyValueIterator, V> { + + private final KeyValueIterator, byte[]> iter; + private final Sensor sensor; + private final StreamsMetrics metrics; + private final StateSerdes serdes; + private final long startNs; + private final Time time; + + MeteredWindowedKeyValueIterator(final KeyValueIterator, byte[]> iter, + final Sensor sensor, + final StreamsMetrics metrics, + final StateSerdes serdes, + final Time time) { + this.iter = iter; + this.sensor = sensor; + this.metrics = metrics; + this.serdes = serdes; + this.startNs = time.nanoseconds(); + this.time = time; + } + + @Override + public boolean hasNext() { + return iter.hasNext(); + } + + @Override + public KeyValue, V> next() { + final KeyValue, byte[]> next = iter.next(); + return KeyValue.pair(windowedKey(next.key), serdes.valueFrom(next.value)); + } + + private Windowed windowedKey(final Windowed bytesKey) { + final K key = serdes.keyFrom(bytesKey.key().get()); + return new Windowed<>(key, bytesKey.window()); + } + + @Override + public void close() { + try { + iter.close(); + } finally { + sensor.record(time.nanoseconds() - startNs); + } + } + + @Override + public Windowed peekNextKey() { + return windowedKey(iter.peekNextKey()); + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/Murmur3.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/Murmur3.java new file mode 100644 index 0000000..5581a03 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/Murmur3.java @@ -0,0 +1,548 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.state.internals; + +/** + * This class was taken from Hive org.apache.hive.common.util; + * https://github.com/apache/hive/blob/master/storage-api/src/java/org/apache/hive/common/util/Murmur3.java + * Commit: dffa3a16588bc8e95b9d0ab5af295a74e06ef702 + * + * + * Murmur3 is successor to Murmur2 fast non-crytographic hash algorithms. + * + * Murmur3 32 and 128 bit variants. + * 32-bit Java port of https://code.google.com/p/smhasher/source/browse/trunk/MurmurHash3.cpp#94 + * 128-bit Java port of https://code.google.com/p/smhasher/source/browse/trunk/MurmurHash3.cpp#255 + * + * This is a public domain code with no copyrights. 
+ * From homepage of MurmurHash (https://code.google.com/p/smhasher/), + * "All MurmurHash versions are public domain software, and the author disclaims all copyright + * to their code." + */ +@SuppressWarnings("fallthrough") +public class Murmur3 { + // from 64-bit linear congruential generator + public static final long NULL_HASHCODE = 2862933555777941757L; + + // Constants for 32 bit variant + private static final int C1_32 = 0xcc9e2d51; + private static final int C2_32 = 0x1b873593; + private static final int R1_32 = 15; + private static final int R2_32 = 13; + private static final int M_32 = 5; + private static final int N_32 = 0xe6546b64; + + // Constants for 128 bit variant + private static final long C1 = 0x87c37b91114253d5L; + private static final long C2 = 0x4cf5ad432745937fL; + private static final int R1 = 31; + private static final int R2 = 27; + private static final int R3 = 33; + private static final int M = 5; + private static final int N1 = 0x52dce729; + private static final int N2 = 0x38495ab5; + + public static final int DEFAULT_SEED = 104729; + + public static int hash32(long l0, long l1) { + return hash32(l0, l1, DEFAULT_SEED); + } + + public static int hash32(long l0) { + return hash32(l0, DEFAULT_SEED); + } + + /** + * Murmur3 32-bit variant. + */ + public static int hash32(long l0, int seed) { + int hash = seed; + final long r0 = Long.reverseBytes(l0); + + hash = mix32((int) r0, hash); + hash = mix32((int) (r0 >>> 32), hash); + + return fmix32(Long.BYTES, hash); + } + + /** + * Murmur3 32-bit variant. + */ + public static int hash32(long l0, long l1, int seed) { + int hash = seed; + final long r0 = Long.reverseBytes(l0); + final long r1 = Long.reverseBytes(l1); + + hash = mix32((int) r0, hash); + hash = mix32((int) (r0 >>> 32), hash); + hash = mix32((int) (r1), hash); + hash = mix32((int) (r1 >>> 32), hash); + + return fmix32(Long.BYTES * 2, hash); + } + + /** + * Murmur3 32-bit variant. + * + * @param data - input byte array + * @return - hashcode + */ + public static int hash32(byte[] data) { + return hash32(data, 0, data.length, DEFAULT_SEED); + } + + /** + * Murmur3 32-bit variant. + * + * @param data - input byte array + * @param length - length of array + * @return - hashcode + */ + public static int hash32(byte[] data, int length) { + return hash32(data, 0, length, DEFAULT_SEED); + } + + /** + * Murmur3 32-bit variant. + * + * @param data - input byte array + * @param length - length of array + * @param seed - seed. (default 0) + * @return - hashcode + */ + public static int hash32(byte[] data, int length, int seed) { + return hash32(data, 0, length, seed); + } + + /** + * Murmur3 32-bit variant. + * + * @param data - input byte array + * @param offset - offset of data + * @param length - length of array + * @param seed - seed. 
(default 0) + * @return - hashcode + */ + public static int hash32(byte[] data, int offset, int length, int seed) { + int hash = seed; + final int nblocks = length >> 2; + + // body + for (int i = 0; i < nblocks; i++) { + int i_4 = i << 2; + int k = (data[offset + i_4] & 0xff) + | ((data[offset + i_4 + 1] & 0xff) << 8) + | ((data[offset + i_4 + 2] & 0xff) << 16) + | ((data[offset + i_4 + 3] & 0xff) << 24); + + hash = mix32(k, hash); + } + + // tail + int idx = nblocks << 2; + int k1 = 0; + switch (length - idx) { + case 3: + k1 ^= data[offset + idx + 2] << 16; + case 2: + k1 ^= data[offset + idx + 1] << 8; + case 1: + k1 ^= data[offset + idx]; + + // mix functions + k1 *= C1_32; + k1 = Integer.rotateLeft(k1, R1_32); + k1 *= C2_32; + hash ^= k1; + } + + return fmix32(length, hash); + } + + private static int mix32(int k, int hash) { + k *= C1_32; + k = Integer.rotateLeft(k, R1_32); + k *= C2_32; + hash ^= k; + return Integer.rotateLeft(hash, R2_32) * M_32 + N_32; + } + + private static int fmix32(int length, int hash) { + hash ^= length; + hash ^= (hash >>> 16); + hash *= 0x85ebca6b; + hash ^= (hash >>> 13); + hash *= 0xc2b2ae35; + hash ^= (hash >>> 16); + + return hash; + } + + /** + * Murmur3 64-bit variant. This is essentially MSB 8 bytes of Murmur3 128-bit variant. + * + * @param data - input byte array + * @return - hashcode + */ + public static long hash64(byte[] data) { + return hash64(data, 0, data.length, DEFAULT_SEED); + } + + public static long hash64(long data) { + long hash = DEFAULT_SEED; + long k = Long.reverseBytes(data); + int length = Long.BYTES; + // mix functions + k *= C1; + k = Long.rotateLeft(k, R1); + k *= C2; + hash ^= k; + hash = Long.rotateLeft(hash, R2) * M + N1; + // finalization + hash ^= length; + hash = fmix64(hash); + return hash; + } + + public static long hash64(int data) { + long k1 = Integer.reverseBytes(data) & (-1L >>> 32); + int length = Integer.BYTES; + long hash = DEFAULT_SEED; + k1 *= C1; + k1 = Long.rotateLeft(k1, R1); + k1 *= C2; + hash ^= k1; + // finalization + hash ^= length; + hash = fmix64(hash); + return hash; + } + + public static long hash64(short data) { + long hash = DEFAULT_SEED; + long k1 = 0; + k1 ^= ((long) data & 0xff) << 8; + k1 ^= ((long)((data & 0xFF00) >> 8) & 0xff); + k1 *= C1; + k1 = Long.rotateLeft(k1, R1); + k1 *= C2; + hash ^= k1; + + // finalization + hash ^= Short.BYTES; + hash = fmix64(hash); + return hash; + } + + public static long hash64(byte[] data, int offset, int length) { + return hash64(data, offset, length, DEFAULT_SEED); + } + + /** + * Murmur3 64-bit variant. This is essentially MSB 8 bytes of Murmur3 128-bit variant. + * + * @param data - input byte array + * @param length - length of array + * @param seed - seed. 
(default is 0) + * @return - hashcode + */ + public static long hash64(byte[] data, int offset, int length, int seed) { + long hash = seed; + final int nblocks = length >> 3; + + // body + for (int i = 0; i < nblocks; i++) { + final int i8 = i << 3; + long k = ((long) data[offset + i8] & 0xff) + | (((long) data[offset + i8 + 1] & 0xff) << 8) + | (((long) data[offset + i8 + 2] & 0xff) << 16) + | (((long) data[offset + i8 + 3] & 0xff) << 24) + | (((long) data[offset + i8 + 4] & 0xff) << 32) + | (((long) data[offset + i8 + 5] & 0xff) << 40) + | (((long) data[offset + i8 + 6] & 0xff) << 48) + | (((long) data[offset + i8 + 7] & 0xff) << 56); + + // mix functions + k *= C1; + k = Long.rotateLeft(k, R1); + k *= C2; + hash ^= k; + hash = Long.rotateLeft(hash, R2) * M + N1; + } + + // tail + long k1 = 0; + int tailStart = nblocks << 3; + switch (length - tailStart) { + case 7: + k1 ^= ((long) data[offset + tailStart + 6] & 0xff) << 48; + case 6: + k1 ^= ((long) data[offset + tailStart + 5] & 0xff) << 40; + case 5: + k1 ^= ((long) data[offset + tailStart + 4] & 0xff) << 32; + case 4: + k1 ^= ((long) data[offset + tailStart + 3] & 0xff) << 24; + case 3: + k1 ^= ((long) data[offset + tailStart + 2] & 0xff) << 16; + case 2: + k1 ^= ((long) data[offset + tailStart + 1] & 0xff) << 8; + case 1: + k1 ^= ((long) data[offset + tailStart] & 0xff); + k1 *= C1; + k1 = Long.rotateLeft(k1, R1); + k1 *= C2; + hash ^= k1; + } + + // finalization + hash ^= length; + hash = fmix64(hash); + + return hash; + } + + /** + * Murmur3 128-bit variant. + * + * @param data - input byte array + * @return - hashcode (2 longs) + */ + public static long[] hash128(byte[] data) { + return hash128(data, 0, data.length, DEFAULT_SEED); + } + + /** + * Murmur3 128-bit variant. + * + * @param data - input byte array + * @param offset - the first element of array + * @param length - length of array + * @param seed - seed. 
(default is 0) + * @return - hashcode (2 longs) + */ + public static long[] hash128(byte[] data, int offset, int length, int seed) { + long h1 = seed; + long h2 = seed; + final int nblocks = length >> 4; + + // body + for (int i = 0; i < nblocks; i++) { + final int i16 = i << 4; + long k1 = ((long) data[offset + i16] & 0xff) + | (((long) data[offset + i16 + 1] & 0xff) << 8) + | (((long) data[offset + i16 + 2] & 0xff) << 16) + | (((long) data[offset + i16 + 3] & 0xff) << 24) + | (((long) data[offset + i16 + 4] & 0xff) << 32) + | (((long) data[offset + i16 + 5] & 0xff) << 40) + | (((long) data[offset + i16 + 6] & 0xff) << 48) + | (((long) data[offset + i16 + 7] & 0xff) << 56); + + long k2 = ((long) data[offset + i16 + 8] & 0xff) + | (((long) data[offset + i16 + 9] & 0xff) << 8) + | (((long) data[offset + i16 + 10] & 0xff) << 16) + | (((long) data[offset + i16 + 11] & 0xff) << 24) + | (((long) data[offset + i16 + 12] & 0xff) << 32) + | (((long) data[offset + i16 + 13] & 0xff) << 40) + | (((long) data[offset + i16 + 14] & 0xff) << 48) + | (((long) data[offset + i16 + 15] & 0xff) << 56); + + // mix functions for k1 + k1 *= C1; + k1 = Long.rotateLeft(k1, R1); + k1 *= C2; + h1 ^= k1; + h1 = Long.rotateLeft(h1, R2); + h1 += h2; + h1 = h1 * M + N1; + + // mix functions for k2 + k2 *= C2; + k2 = Long.rotateLeft(k2, R3); + k2 *= C1; + h2 ^= k2; + h2 = Long.rotateLeft(h2, R1); + h2 += h1; + h2 = h2 * M + N2; + } + + // tail + long k1 = 0; + long k2 = 0; + int tailStart = nblocks << 4; + switch (length - tailStart) { + case 15: + k2 ^= (long) (data[offset + tailStart + 14] & 0xff) << 48; + case 14: + k2 ^= (long) (data[offset + tailStart + 13] & 0xff) << 40; + case 13: + k2 ^= (long) (data[offset + tailStart + 12] & 0xff) << 32; + case 12: + k2 ^= (long) (data[offset + tailStart + 11] & 0xff) << 24; + case 11: + k2 ^= (long) (data[offset + tailStart + 10] & 0xff) << 16; + case 10: + k2 ^= (long) (data[offset + tailStart + 9] & 0xff) << 8; + case 9: + k2 ^= (long) (data[offset + tailStart + 8] & 0xff); + k2 *= C2; + k2 = Long.rotateLeft(k2, R3); + k2 *= C1; + h2 ^= k2; + + case 8: + k1 ^= (long) (data[offset + tailStart + 7] & 0xff) << 56; + case 7: + k1 ^= (long) (data[offset + tailStart + 6] & 0xff) << 48; + case 6: + k1 ^= (long) (data[offset + tailStart + 5] & 0xff) << 40; + case 5: + k1 ^= (long) (data[offset + tailStart + 4] & 0xff) << 32; + case 4: + k1 ^= (long) (data[offset + tailStart + 3] & 0xff) << 24; + case 3: + k1 ^= (long) (data[offset + tailStart + 2] & 0xff) << 16; + case 2: + k1 ^= (long) (data[offset + tailStart + 1] & 0xff) << 8; + case 1: + k1 ^= (long) (data[offset + tailStart] & 0xff); + k1 *= C1; + k1 = Long.rotateLeft(k1, R1); + k1 *= C2; + h1 ^= k1; + } + + // finalization + h1 ^= length; + h2 ^= length; + + h1 += h2; + h2 += h1; + + h1 = fmix64(h1); + h2 = fmix64(h2); + + h1 += h2; + h2 += h1; + + return new long[]{h1, h2}; + } + + private static long fmix64(long h) { + h ^= (h >>> 33); + h *= 0xff51afd7ed558ccdL; + h ^= (h >>> 33); + h *= 0xc4ceb9fe1a85ec53L; + h ^= (h >>> 33); + return h; + } + + public static class IncrementalHash32 { + byte[] tail = new byte[3]; + int tailLen; + int totalLen; + int hash; + + public final void start(int hash) { + tailLen = totalLen = 0; + this.hash = hash; + } + + public final void add(byte[] data, int offset, int length) { + if (length == 0) return; + totalLen += length; + if (tailLen + length < 4) { + System.arraycopy(data, offset, tail, tailLen, length); + tailLen += length; + return; + } + int offset2 = 0; + if (tailLen > 0) { + 
offset2 = (4 - tailLen); + int k = -1; + switch (tailLen) { + case 1: + k = orBytes(tail[0], data[offset], data[offset + 1], data[offset + 2]); + break; + case 2: + k = orBytes(tail[0], tail[1], data[offset], data[offset + 1]); + break; + case 3: + k = orBytes(tail[0], tail[1], tail[2], data[offset]); + break; + default: throw new AssertionError(tailLen); + } + // mix functions + k *= C1_32; + k = Integer.rotateLeft(k, R1_32); + k *= C2_32; + hash ^= k; + hash = Integer.rotateLeft(hash, R2_32) * M_32 + N_32; + } + int length2 = length - offset2; + offset += offset2; + final int nblocks = length2 >> 2; + + for (int i = 0; i < nblocks; i++) { + int i_4 = (i << 2) + offset; + int k = orBytes(data[i_4], data[i_4 + 1], data[i_4 + 2], data[i_4 + 3]); + + // mix functions + k *= C1_32; + k = Integer.rotateLeft(k, R1_32); + k *= C2_32; + hash ^= k; + hash = Integer.rotateLeft(hash, R2_32) * M_32 + N_32; + } + + int consumed = (nblocks << 2); + tailLen = length2 - consumed; + if (consumed == length2) return; + System.arraycopy(data, offset + consumed, tail, 0, tailLen); + } + + public final int end() { + int k1 = 0; + switch (tailLen) { + case 3: + k1 ^= tail[2] << 16; + case 2: + k1 ^= tail[1] << 8; + case 1: + k1 ^= tail[0]; + + // mix functions + k1 *= C1_32; + k1 = Integer.rotateLeft(k1, R1_32); + k1 *= C2_32; + hash ^= k1; + } + + // finalization + hash ^= totalLen; + hash ^= (hash >>> 16); + hash *= 0x85ebca6b; + hash ^= (hash >>> 13); + hash *= 0xc2b2ae35; + hash ^= (hash >>> 16); + return hash; + } + } + + private static int orBytes(byte b1, byte b2, byte b3, byte b4) { + return (b1 & 0xff) | ((b2 & 0xff) << 8) | ((b3 & 0xff) << 16) | ((b4 & 0xff) << 24); + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/NamedCache.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/NamedCache.java new file mode 100644 index 0000000..ecf063b --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/NamedCache.java @@ -0,0 +1,397 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
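The Murmur3 port above also provides IncrementalHash32, which buffers a partial tail between add() calls so that feeding the same bytes in chunks is expected to match the one-shot hash32 under the same seed. A small sketch of that property; the class name and input string are illustrative only.

import java.nio.charset.StandardCharsets;

import org.apache.kafka.streams.state.internals.Murmur3;

final class Murmur3Sketch {

    public static void main(final String[] args) {
        final byte[] data = "kafka-streams".getBytes(StandardCharsets.UTF_8);

        // One-shot 32-bit hash with the default seed.
        final int oneShot = Murmur3.hash32(data);

        // The same bytes fed in two chunks through the incremental variant.
        final Murmur3.IncrementalHash32 incremental = new Murmur3.IncrementalHash32();
        incremental.start(Murmur3.DEFAULT_SEED);
        incremental.add(data, 0, 5);
        incremental.add(data, 5, data.length - 5);
        final int chunked = incremental.end();

        System.out.println(oneShot == chunked); // expected to print true
    }
}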
+ */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.common.metrics.Sensor; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl; +import org.apache.kafka.streams.state.internals.metrics.NamedCacheMetrics; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.Iterator; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.NavigableMap; +import java.util.Set; +import java.util.TreeMap; +import java.util.TreeSet; + +class NamedCache { + private static final Logger log = LoggerFactory.getLogger(NamedCache.class); + private final String name; + private final String storeName; + private final String taskName; + private final NavigableMap cache = new TreeMap<>(); + private final Set dirtyKeys = new LinkedHashSet<>(); + private ThreadCache.DirtyEntryFlushListener listener; + private LRUNode tail; + private LRUNode head; + private long currentSizeBytes; + + private final StreamsMetricsImpl streamsMetrics; + private final Sensor hitRatioSensor; + + // internal stats + private long numReadHits = 0; + private long numReadMisses = 0; + private long numOverwrites = 0; + private long numFlushes = 0; + + NamedCache(final String name, final StreamsMetricsImpl streamsMetrics) { + this.name = name; + this.streamsMetrics = streamsMetrics; + storeName = ThreadCache.underlyingStoreNamefromCacheName(name); + taskName = ThreadCache.taskIDfromCacheName(name); + hitRatioSensor = NamedCacheMetrics.hitRatioSensor( + streamsMetrics, + Thread.currentThread().getName(), + taskName, + storeName + ); + } + + synchronized final String name() { + return name; + } + + synchronized long hits() { + return numReadHits; + } + + synchronized long misses() { + return numReadMisses; + } + + synchronized long overwrites() { + return numOverwrites; + } + + synchronized long flushes() { + return numFlushes; + } + + synchronized LRUCacheEntry get(final Bytes key) { + if (key == null) { + return null; + } + + final LRUNode node = getInternal(key); + if (node == null) { + return null; + } + updateLRU(node); + return node.entry; + } + + synchronized void setListener(final ThreadCache.DirtyEntryFlushListener listener) { + this.listener = listener; + } + + synchronized void flush() { + flush(null); + } + + private void flush(final LRUNode evicted) { + numFlushes++; + + if (log.isTraceEnabled()) { + log.trace("Named cache {} stats on flush: #hits={}, #misses={}, #overwrites={}, #flushes={}", + name, hits(), misses(), overwrites(), flushes()); + } + + if (listener == null) { + throw new IllegalArgumentException("No listener for namespace " + name + " registered with cache"); + } + + if (dirtyKeys.isEmpty()) { + return; + } + + final List entries = new ArrayList<>(); + final List deleted = new ArrayList<>(); + + // evicted already been removed from the cache so add it to the list of + // flushed entries and remove from dirtyKeys. 
+ if (evicted != null) { + entries.add(new ThreadCache.DirtyEntry(evicted.key, evicted.entry.value(), evicted.entry)); + dirtyKeys.remove(evicted.key); + } + + for (final Bytes key : dirtyKeys) { + final LRUNode node = getInternal(key); + if (node == null) { + throw new IllegalStateException("Key = " + key + " found in dirty key set, but entry is null"); + } + entries.add(new ThreadCache.DirtyEntry(key, node.entry.value(), node.entry)); + node.entry.markClean(); + if (node.entry.value() == null) { + deleted.add(node.key); + } + } + // clear dirtyKeys before the listener is applied as it may be re-entrant. + dirtyKeys.clear(); + listener.apply(entries); + for (final Bytes key : deleted) { + delete(key); + } + } + + synchronized void put(final Bytes key, final LRUCacheEntry value) { + if (!value.isDirty() && dirtyKeys.contains(key)) { + throw new IllegalStateException( + String.format( + "Attempting to put a clean entry for key [%s] into NamedCache [%s] when it already contains a dirty entry for the same key", + key, name + ) + ); + } + LRUNode node = cache.get(key); + if (node != null) { + numOverwrites++; + + currentSizeBytes -= node.size(); + node.update(value); + updateLRU(node); + } else { + node = new LRUNode(key, value); + // put element + putHead(node); + cache.put(key, node); + } + if (value.isDirty()) { + // first remove and then add so we can maintain ordering as the arrival order of the records. + dirtyKeys.remove(key); + dirtyKeys.add(key); + } + currentSizeBytes += node.size(); + } + + synchronized long sizeInBytes() { + return currentSizeBytes; + } + + private LRUNode getInternal(final Bytes key) { + final LRUNode node = cache.get(key); + if (node == null) { + numReadMisses++; + + return null; + } else { + numReadHits++; + hitRatioSensor.record((double) numReadHits / (double) (numReadHits + numReadMisses)); + } + return node; + } + + private void updateLRU(final LRUNode node) { + remove(node); + + putHead(node); + } + + private void remove(final LRUNode node) { + if (node.previous != null) { + node.previous.next = node.next; + } else { + head = node.next; + } + if (node.next != null) { + node.next.previous = node.previous; + } else { + tail = node.previous; + } + } + + private void putHead(final LRUNode node) { + node.next = head; + node.previous = null; + if (head != null) { + head.previous = node; + } + head = node; + if (tail == null) { + tail = head; + } + } + + synchronized void evict() { + if (tail == null) { + return; + } + final LRUNode eldest = tail; + currentSizeBytes -= eldest.size(); + remove(eldest); + cache.remove(eldest.key); + if (eldest.entry.isDirty()) { + flush(eldest); + } + } + + synchronized LRUCacheEntry putIfAbsent(final Bytes key, final LRUCacheEntry value) { + final LRUCacheEntry originalValue = get(key); + if (originalValue == null) { + put(key, value); + } + return originalValue; + } + + synchronized void putAll(final List> entries) { + for (final KeyValue entry : entries) { + put(Bytes.wrap(entry.key), entry.value); + } + } + + synchronized LRUCacheEntry delete(final Bytes key) { + final LRUNode node = cache.remove(key); + + if (node == null) { + return null; + } + + remove(node); + dirtyKeys.remove(key); + currentSizeBytes -= node.size(); + return node.entry(); + } + + public long size() { + return cache.size(); + } + + public boolean isEmpty() { + return cache.isEmpty(); + } + + synchronized Iterator keyRange(final Bytes from, final Bytes to, final boolean toInclusive) { + final Set rangeSet = computeSubSet(from, to, toInclusive); + return 
keySetIterator(rangeSet, true); + } + + synchronized Iterator reverseKeyRange(final Bytes from, final Bytes to) { + final Set rangeSet = computeSubSet(from, to, true); + return keySetIterator(rangeSet, false); + } + + private Set computeSubSet(final Bytes from, final Bytes to, final boolean toInclusive) { + if (from == null && to == null) { + return cache.navigableKeySet(); + } else if (from == null) { + return cache.headMap(to, toInclusive).keySet(); + } else if (to == null) { + return cache.tailMap(from, true).keySet(); + } else if (from.compareTo(to) > 0) { + return Collections.emptyNavigableSet(); + } else { + return cache.navigableKeySet().subSet(from, true, to, toInclusive); + } + } + + private Iterator keySetIterator(final Set keySet, final boolean forward) { + if (forward) { + return new TreeSet<>(keySet).iterator(); + } else { + return new TreeSet<>(keySet).descendingIterator(); + } + } + + synchronized Iterator allKeys() { + return keySetIterator(cache.navigableKeySet(), true); + } + + synchronized Iterator reverseAllKeys() { + return keySetIterator(cache.navigableKeySet(), false); + } + + synchronized LRUCacheEntry first() { + if (head == null) { + return null; + } + return head.entry; + } + + synchronized LRUCacheEntry last() { + if (tail == null) { + return null; + } + return tail.entry; + } + + synchronized LRUNode head() { + return head; + } + + synchronized LRUNode tail() { + return tail; + } + + synchronized void close() { + head = tail = null; + listener = null; + currentSizeBytes = 0; + dirtyKeys.clear(); + cache.clear(); + streamsMetrics.removeAllCacheLevelSensors(Thread.currentThread().getName(), taskName, storeName); + } + + /** + * A simple wrapper class to implement a doubly-linked list around MemoryLRUCacheBytesEntry + */ + static class LRUNode { + private final Bytes key; + private LRUCacheEntry entry; + private LRUNode previous; + private LRUNode next; + + LRUNode(final Bytes key, final LRUCacheEntry entry) { + this.key = key; + this.entry = entry; + } + + LRUCacheEntry entry() { + return entry; + } + + Bytes key() { + return key; + } + + long size() { + return key.get().length + + 8 + // entry + 8 + // previous + 8 + // next + entry.size(); + } + + LRUNode next() { + return next; + } + + LRUNode previous() { + return previous; + } + + private void update(final LRUCacheEntry entry) { + this.entry = entry; + } + } + +} diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/NextIteratorFunction.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/NextIteratorFunction.java new file mode 100644 index 0000000..9f9ac55 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/NextIteratorFunction.java @@ -0,0 +1,24 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.streams.state.KeyValueIterator; + +interface NextIteratorFunction { + + KeyValueIterator apply(final StoreType store); +} diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/OffsetCheckpoint.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/OffsetCheckpoint.java new file mode 100644 index 0000000..3ec2386 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/OffsetCheckpoint.java @@ -0,0 +1,220 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.utils.Utils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.BufferedReader; +import java.io.BufferedWriter; +import java.io.EOFException; +import java.io.File; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.OutputStreamWriter; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.nio.file.NoSuchFileException; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.regex.Pattern; + +/** + * This class saves out a map of topic/partition=>offsets to a file. The format of the file is UTF-8 text containing the following: + *

+ * <pre>
+ *   <version>
+ *   <n>
+ *   <topic_name_1> <partition_1> <offset_1>
+ *   .
+ *   .
+ *   .
+ *   <topic_name_n> <partition_n> <offset_n>
+ * </pre>
                + * The first line contains a number designating the format version (currently 0), the get line contains + * a number giving the total number of offsets. Each successive line gives a topic/partition/offset triple + * separated by spaces. + */ +public class OffsetCheckpoint { + private static final Logger LOG = LoggerFactory.getLogger(OffsetCheckpoint.class); + + private static final Pattern WHITESPACE_MINIMUM_ONCE = Pattern.compile("\\s+"); + + private static final int VERSION = 0; + + // Use a negative sentinel when we don't know the offset instead of skipping it to distinguish it from dirty state + // and use -4 as the -1 sentinel may be taken by some producer errors and -2 in the + // subscription means that the state is used by an active task and hence caught-up and + // -3 is also used in the subscription. + public static final long OFFSET_UNKNOWN = -4L; + + private final File file; + private final Object lock; + + public OffsetCheckpoint(final File file) { + this.file = file; + lock = new Object(); + } + + /** + * Write the given offsets to the checkpoint file. All offsets should be non-negative. + * + * @throws IOException if any file operation fails with an IO exception + */ + public void write(final Map offsets) throws IOException { + // if there are no offsets, skip writing the file to save disk IOs + // but make sure to delete the existing file if one exists + if (offsets.isEmpty()) { + Utils.delete(file); + return; + } + + synchronized (lock) { + // write to temp file and then swap with the existing file + final File temp = new File(file.getAbsolutePath() + ".tmp"); + LOG.trace("Writing tmp checkpoint file {}", temp.getAbsolutePath()); + + final FileOutputStream fileOutputStream = new FileOutputStream(temp); + try (final BufferedWriter writer = new BufferedWriter( + new OutputStreamWriter(fileOutputStream, StandardCharsets.UTF_8))) { + writeIntLine(writer, VERSION); + writeIntLine(writer, offsets.size()); + + for (final Map.Entry entry : offsets.entrySet()) { + final TopicPartition tp = entry.getKey(); + final Long offset = entry.getValue(); + if (isValid(offset)) { + writeEntry(writer, tp, offset); + } else { + LOG.error("Received offset={} to write to checkpoint file for {}", offset, tp); + throw new IllegalStateException("Attempted to write a negative offset to the checkpoint file"); + } + } + + writer.flush(); + fileOutputStream.getFD().sync(); + } + + LOG.trace("Swapping tmp checkpoint file {} {}", temp.toPath(), file.toPath()); + Utils.atomicMoveWithFallback(temp.toPath(), file.toPath()); + } + } + + /** + * @throws IOException if file write operations failed with any IO exception + */ + static void writeIntLine(final BufferedWriter writer, + final int number) throws IOException { + writer.write(Integer.toString(number)); + writer.newLine(); + } + + /** + * @throws IOException if file write operations failed with any IO exception + */ + static void writeEntry(final BufferedWriter writer, + final TopicPartition part, + final long offset) throws IOException { + writer.write(part.topic()); + writer.write(' '); + writer.write(Integer.toString(part.partition())); + writer.write(' '); + writer.write(Long.toString(offset)); + writer.newLine(); + } + + + /** + * Reads the offsets from the local checkpoint file, skipping any negative offsets it finds. 
+ * + * @throws IOException if any file operation fails with an IO exception + * @throws IllegalArgumentException if the offset checkpoint version is unknown + */ + public Map read() throws IOException { + synchronized (lock) { + try (final BufferedReader reader = Files.newBufferedReader(file.toPath())) { + final int version = readInt(reader); + switch (version) { + case 0: + int expectedSize = readInt(reader); + final Map offsets = new HashMap<>(); + String line = reader.readLine(); + while (line != null) { + final String[] pieces = WHITESPACE_MINIMUM_ONCE.split(line); + if (pieces.length != 3) { + throw new IOException( + String.format("Malformed line in offset checkpoint file: '%s'.", line)); + } + + final String topic = pieces[0]; + final int partition = Integer.parseInt(pieces[1]); + final TopicPartition tp = new TopicPartition(topic, partition); + final long offset = Long.parseLong(pieces[2]); + if (isValid(offset)) { + offsets.put(tp, offset); + } else { + LOG.warn("Read offset={} from checkpoint file for {}", offset, tp); + --expectedSize; + } + + line = reader.readLine(); + } + if (offsets.size() != expectedSize) { + throw new IOException( + String.format("Expected %d entries but found only %d", expectedSize, offsets.size())); + } + return offsets; + + default: + throw new IllegalArgumentException("Unknown offset checkpoint version: " + version); + } + } catch (final NoSuchFileException e) { + return Collections.emptyMap(); + } + } + } + + /** + * @throws IOException if file read ended prematurely + */ + private int readInt(final BufferedReader reader) throws IOException { + final String line = reader.readLine(); + if (line == null) { + throw new EOFException("File ended prematurely."); + } + return Integer.parseInt(line); + } + + /** + * @throws IOException if there is any IO exception during delete + */ + public void delete() throws IOException { + Files.deleteIfExists(file.toPath()); + } + + @Override + public String toString() { + return file.getAbsolutePath(); + } + + private boolean isValid(final long offset) { + return offset >= 0L || offset == OFFSET_UNKNOWN; + } + +} diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/OrderedBytes.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/OrderedBytes.java new file mode 100644 index 0000000..561f24c --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/OrderedBytes.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
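To make the checkpoint format above concrete, the sketch below writes two offsets and reads them back; per the class Javadoc the file then holds the version line (0), the entry count (2), and one topic/partition/offset line per entry. The file path and topic name are made up for illustration.

import java.io.File;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;

import org.apache.kafka.common.TopicPartition;
import org.apache.kafka.streams.state.internals.OffsetCheckpoint;

final class OffsetCheckpointSketch {

    public static void main(final String[] args) throws IOException {
        final OffsetCheckpoint checkpoint = new OffsetCheckpoint(new File("/tmp/example.checkpoint"));

        final Map<TopicPartition, Long> offsets = new HashMap<>();
        offsets.put(new TopicPartition("my-topic", 0), 42L);
        offsets.put(new TopicPartition("my-topic", 1), 17L);
        checkpoint.write(offsets);
        // Resulting file content (entry order may vary):
        //   0
        //   2
        //   my-topic 0 42
        //   my-topic 1 17

        final Map<TopicPartition, Long> readBack = checkpoint.read();
        System.out.println(readBack.size()); // 2
        checkpoint.delete();
    }
}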
+ */ + +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.common.utils.Bytes; + +import java.nio.ByteBuffer; + +class OrderedBytes { + + private static final int MIN_KEY_LENGTH = 1; + /** + * Returns the upper byte range for a key with a given fixed size maximum suffix + * + * Assumes the minimum key length is one byte + */ + static Bytes upperRange(final Bytes key, final byte[] maxSuffix) { + final byte[] bytes = key.get(); + final ByteBuffer rangeEnd = ByteBuffer.allocate(bytes.length + maxSuffix.length); + final int firstTimestampByte = maxSuffix[0] & 0xFF; + + // if firstTimestampByte is 0, we'll put all key bytes into range result because `(bytes[i] & 0xFF) >= firstTimestampByte` + // will always be true (this is a byte to unsigned int conversion comparison) + if (firstTimestampByte == 0) { + return Bytes.wrap( + rangeEnd + .put(bytes) + .put(maxSuffix) + .array() + ); + } else { + int i = 0; + while (i < bytes.length && ( + i < MIN_KEY_LENGTH // assumes keys are at least one byte long + || (bytes[i] & 0xFF) >= firstTimestampByte + )) { + rangeEnd.put(bytes[i++]); + } + + rangeEnd.put(maxSuffix); + rangeEnd.flip(); + + final byte[] res = new byte[rangeEnd.remaining()]; + ByteBuffer.wrap(res).put(rangeEnd); + return Bytes.wrap(res); + } + } + + static Bytes lowerRange(final Bytes key, final byte[] minSuffix) { + final byte[] bytes = key.get(); + final ByteBuffer rangeStart = ByteBuffer.allocate(bytes.length + minSuffix.length); + // any key in the range would start at least with the given prefix to be + // in the range, and have at least SUFFIX_SIZE number of trailing zero bytes. + + // unless there is a maximum key length, you can keep appending more zero bytes + // to keyFrom to create a key that will match the range, yet that would precede + // KeySchema.toBinaryKey(keyFrom, from, 0) in byte order + return Bytes.wrap( + rangeStart + .put(bytes) + .put(minSuffix) + .array() + ); + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/PeekingKeyValueIterator.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/PeekingKeyValueIterator.java new file mode 100644 index 0000000..1554ad2 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/PeekingKeyValueIterator.java @@ -0,0 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
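// [Editor's note] Illustrative sketch, not part of the patch: the lower bound in OrderedBytes
// above is simply key-bytes followed by minSuffix, while the upper bound keeps only the leading
// key bytes that are (unsigned) >= the first byte of maxSuffix before appending maxSuffix.
// Worked example, assuming an 8-byte timestamp suffix:
//   key = {0x01, 0xFF, 0x02}, maxSuffix = eight 0xFF bytes
//   upperRange keeps {0x01} (first byte is always kept) and {0xFF} (0xFF >= 0xFF), drops 0x02
//   (0x02 < 0xFF), then appends the suffix, giving 10 bytes in total.
// The methods below mirror that logic for illustration; they are not the class above.
import java.nio.ByteBuffer;
import java.util.Arrays;

public class OrderedBytesSketch {

    // mirrors OrderedBytes#lowerRange: concatenate the key with the minimum suffix
    static byte[] lowerRange(final byte[] key, final byte[] minSuffix) {
        return ByteBuffer.allocate(key.length + minSuffix.length)
            .put(key)
            .put(minSuffix)
            .array();
    }

    // mirrors OrderedBytes#upperRange for the case where the first suffix byte is non-zero
    static byte[] upperRange(final byte[] key, final byte[] maxSuffix) {
        final int firstSuffixByte = maxSuffix[0] & 0xFF;
        final ByteBuffer rangeEnd = ByteBuffer.allocate(key.length + maxSuffix.length);
        int i = 0;
        while (i < key.length && (i < 1 || (key[i] & 0xFF) >= firstSuffixByte)) {
            rangeEnd.put(key[i++]);
        }
        rangeEnd.put(maxSuffix);
        rangeEnd.flip();
        final byte[] result = new byte[rangeEnd.remaining()];
        rangeEnd.get(result);
        return result;
    }

    public static void main(final String[] args) {
        final byte[] key = {0x01, (byte) 0xFF, 0x02};
        final byte[] maxSuffix = new byte[8];
        Arrays.fill(maxSuffix, (byte) 0xFF);
        System.out.println(upperRange(key, maxSuffix).length); // 10 = 2 kept key bytes + 8 suffix bytes
    }
}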
+ */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.state.KeyValueIterator; + +public interface PeekingKeyValueIterator extends KeyValueIterator { + + KeyValue peekNext(); +} diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/QueryableStoreProvider.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/QueryableStoreProvider.java new file mode 100644 index 0000000..07cf0ee --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/QueryableStoreProvider.java @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.streams.StoreQueryParameters; +import org.apache.kafka.streams.processor.StateStore; +import org.apache.kafka.streams.state.QueryableStoreType; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** + * A wrapper over all of the {@link StateStoreProvider}s in a Topology + * + * The store providers field is a reference + */ +public class QueryableStoreProvider { + + // map of StreamThread.name to StreamThreadStateStoreProvider + private final Map storeProviders; + private final GlobalStateStoreProvider globalStoreProvider; + + public QueryableStoreProvider(final GlobalStateStoreProvider globalStateStoreProvider) { + this.storeProviders = new HashMap<>(); + this.globalStoreProvider = globalStateStoreProvider; + } + + /** + * Get a composite object wrapping the instances of the {@link StateStore} with the provided + * storeName and {@link QueryableStoreType} + * + * @param storeQueryParameters if stateStoresEnabled is used i.e. staleStoresEnabled is true, include standbys and recovering stores; + * if stateStoresDisabled i.e. staleStoresEnabled is false, only include running actives; + * if partition is null then it fetches all local partitions on the instance; + * if partition is set then it fetches a specific partition. + * @param The expected type of the returned store + * @return A composite object that wraps the store instances. 
+ */ + public T getStore(final StoreQueryParameters storeQueryParameters) { + final String storeName = storeQueryParameters.storeName(); + final QueryableStoreType queryableStoreType = storeQueryParameters.queryableStoreType(); + final List globalStore = globalStoreProvider.stores(storeName, queryableStoreType); + if (!globalStore.isEmpty()) { + return queryableStoreType.create(globalStoreProvider, storeName); + } + return queryableStoreType.create( + new WrappingStoreProvider(storeProviders.values(), storeQueryParameters), + storeName + ); + } + + public void addStoreProviderForThread(final String threadName, final StreamThreadStateStoreProvider streamThreadStateStoreProvider) { + this.storeProviders.put(threadName, streamThreadStateStoreProvider); + } + + public void removeStoreProviderForThread(final String threadName) { + this.storeProviders.remove(threadName); + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/ReadOnlyKeyValueStoreFacade.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/ReadOnlyKeyValueStoreFacade.java new file mode 100644 index 0000000..7a03f72 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/ReadOnlyKeyValueStoreFacade.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
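// [Editor's note] Illustrative sketch, not part of the patch: QueryableStoreProvider is the
// internal plumbing behind KafkaStreams#store(StoreQueryParameters). From the public API the
// lookup above is typically reached like this; the store name "counts-store" and the `streams`
// instance are assumptions for the example.
import org.apache.kafka.streams.KafkaStreams;
import org.apache.kafka.streams.StoreQueryParameters;
import org.apache.kafka.streams.state.QueryableStoreTypes;
import org.apache.kafka.streams.state.ReadOnlyKeyValueStore;

public class InteractiveQueryExample {

    static Long lookup(final KafkaStreams streams, final String key) {
        final ReadOnlyKeyValueStore<String, Long> store = streams.store(
            StoreQueryParameters
                .fromNameAndType("counts-store", QueryableStoreTypes.<String, Long>keyValueStore())
                .enableStaleStores()   // also query standby/restoring copies (staleStoresEnabled = true)
                .withPartition(0));    // restrict to one local partition; omit to query all local partitions
        return store.get(key);
    }
}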
+ */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.common.serialization.Serializer; +import org.apache.kafka.streams.state.KeyValueIterator; +import org.apache.kafka.streams.state.ReadOnlyKeyValueStore; +import org.apache.kafka.streams.state.TimestampedKeyValueStore; + +import static org.apache.kafka.streams.state.ValueAndTimestamp.getValueOrNull; + +public class ReadOnlyKeyValueStoreFacade implements ReadOnlyKeyValueStore { + protected final TimestampedKeyValueStore inner; + + protected ReadOnlyKeyValueStoreFacade(final TimestampedKeyValueStore store) { + inner = store; + } + + @Override + public V get(final K key) { + return getValueOrNull(inner.get(key)); + } + + @Override + public KeyValueIterator range(final K from, + final K to) { + return new KeyValueIteratorFacade<>(inner.range(from, to)); + } + + @Override + public KeyValueIterator reverseRange(final K from, + final K to) { + return new KeyValueIteratorFacade<>(inner.reverseRange(from, to)); + } + + @Override + public , P> KeyValueIterator prefixScan(final P prefix, + final PS prefixKeySerializer) { + return new KeyValueIteratorFacade<>(inner.prefixScan(prefix, prefixKeySerializer)); + } + + @Override + public KeyValueIterator all() { + return new KeyValueIteratorFacade<>(inner.all()); + } + + @Override + public KeyValueIterator reverseAll() { + return new KeyValueIteratorFacade<>(inner.reverseAll()); + } + + @Override + public long approximateNumEntries() { + return inner.approximateNumEntries(); + } +} \ No newline at end of file diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/ReadOnlyWindowStoreFacade.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/ReadOnlyWindowStoreFacade.java new file mode 100644 index 0000000..281a1c2 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/ReadOnlyWindowStoreFacade.java @@ -0,0 +1,124 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
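// [Editor's note] Illustrative sketch, not part of the patch: the facade above exposes a
// TimestampedKeyValueStore as a plain ReadOnlyKeyValueStore by dropping the timestamp with
// ValueAndTimestamp.getValueOrNull. The helper behaves like this:
import org.apache.kafka.streams.state.ValueAndTimestamp;

public class ValueAndTimestampExample {
    public static void main(final String[] args) {
        final ValueAndTimestamp<String> valueAndTimestamp = ValueAndTimestamp.make("hello", 42L);

        // getValueOrNull unwraps the value and maps a null container to a null value
        System.out.println(ValueAndTimestamp.getValueOrNull(valueAndTimestamp)); // hello
        System.out.println(ValueAndTimestamp.getValueOrNull(null));              // null
    }
}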
+ */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.kstream.Windowed; +import org.apache.kafka.streams.state.KeyValueIterator; +import org.apache.kafka.streams.state.ReadOnlyWindowStore; +import org.apache.kafka.streams.state.TimestampedWindowStore; +import org.apache.kafka.streams.state.ValueAndTimestamp; +import org.apache.kafka.streams.state.WindowStoreIterator; + +import java.time.Instant; + +import static org.apache.kafka.streams.state.ValueAndTimestamp.getValueOrNull; + +public class ReadOnlyWindowStoreFacade implements ReadOnlyWindowStore { + protected final TimestampedWindowStore inner; + + protected ReadOnlyWindowStoreFacade(final TimestampedWindowStore store) { + inner = store; + } + + @Override + public V fetch(final K key, + final long time) { + return getValueOrNull(inner.fetch(key, time)); + } + + @Override + public WindowStoreIterator fetch(final K key, + final Instant timeFrom, + final Instant timeTo) throws IllegalArgumentException { + return new WindowStoreIteratorFacade<>(inner.fetch(key, timeFrom, timeTo)); + } + + @Override + public WindowStoreIterator backwardFetch(final K key, + final Instant timeFrom, + final Instant timeTo) throws IllegalArgumentException { + return new WindowStoreIteratorFacade<>(inner.backwardFetch(key, timeFrom, timeTo)); + } + + @Override + public KeyValueIterator, V> fetch(final K keyFrom, + final K keyTo, + final Instant timeFrom, + final Instant timeTo) throws IllegalArgumentException { + return new KeyValueIteratorFacade<>(inner.fetch(keyFrom, keyTo, timeFrom, timeTo)); + } + + @Override + public KeyValueIterator, V> backwardFetch(final K keyFrom, + final K keyTo, + final Instant timeFrom, + final Instant timeTo) throws IllegalArgumentException { + return new KeyValueIteratorFacade<>(inner.backwardFetch(keyFrom, keyTo, timeFrom, timeTo)); + } + + @Override + public KeyValueIterator, V> fetchAll(final Instant timeFrom, + final Instant timeTo) throws IllegalArgumentException { + return new KeyValueIteratorFacade<>(inner.fetchAll(timeFrom, timeTo)); + } + + @Override + public KeyValueIterator, V> backwardFetchAll(final Instant timeFrom, + final Instant timeTo) throws IllegalArgumentException { + return new KeyValueIteratorFacade<>(inner.backwardFetchAll(timeFrom, timeTo)); + } + + @Override + public KeyValueIterator, V> all() { + return new KeyValueIteratorFacade<>(inner.all()); + } + + @Override + public KeyValueIterator, V> backwardAll() { + return new KeyValueIteratorFacade<>(inner.backwardAll()); + } + + private static class WindowStoreIteratorFacade implements WindowStoreIterator { + final KeyValueIterator> innerIterator; + + WindowStoreIteratorFacade(final KeyValueIterator> iterator) { + innerIterator = iterator; + } + + @Override + public void close() { + innerIterator.close(); + } + + @Override + public Long peekNextKey() { + return innerIterator.peekNextKey(); + } + + @Override + public boolean hasNext() { + return innerIterator.hasNext(); + } + + @Override + public KeyValue next() { + final KeyValue> innerKeyValue = innerIterator.next(); + return KeyValue.pair(innerKeyValue.key, getValueOrNull(innerKeyValue.value)); + } + } +} \ No newline at end of file diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/RecordConverter.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/RecordConverter.java new file mode 100644 index 0000000..9046e37 --- /dev/null +++ 
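// [Editor's note] Illustrative sketch, not part of the patch: the facade above adapts a
// TimestampedWindowStore to the ReadOnlyWindowStore interface, so callers iterate plain values
// per window; the WindowStoreIterator key is the window start timestamp. The `store` parameter
// and the one-hour range are assumptions for the example.
import org.apache.kafka.streams.KeyValue;
import org.apache.kafka.streams.state.ReadOnlyWindowStore;
import org.apache.kafka.streams.state.WindowStoreIterator;

import java.time.Duration;
import java.time.Instant;

public class WindowFetchExample {

    static void printLastHour(final ReadOnlyWindowStore<String, Long> store, final String key) {
        final Instant now = Instant.now();
        try (final WindowStoreIterator<Long> iterator = store.fetch(key, now.minus(Duration.ofHours(1)), now)) {
            while (iterator.hasNext()) {
                final KeyValue<Long, Long> next = iterator.next();
                // next.key is the window start timestamp (epoch millis), next.value the stored value
                System.out.println(Instant.ofEpochMilli(next.key) + " -> " + next.value);
            }
        }
    }
}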
b/streams/src/main/java/org/apache/kafka/streams/state/internals/RecordConverter.java @@ -0,0 +1,23 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.clients.consumer.ConsumerRecord; + +public interface RecordConverter { + ConsumerRecord convert(final ConsumerRecord record); +} diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/RecordConverters.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/RecordConverters.java new file mode 100644 index 0000000..ad3c91e --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/RecordConverters.java @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.clients.consumer.ConsumerRecord; + +import java.nio.ByteBuffer; + +public final class RecordConverters { + private static final RecordConverter IDENTITY_INSTANCE = record -> record; + + private static final RecordConverter RAW_TO_TIMESTAMED_INSTANCE = record -> { + final byte[] rawValue = record.value(); + final long timestamp = record.timestamp(); + final byte[] recordValue = rawValue == null ? 
null : + ByteBuffer.allocate(8 + rawValue.length) + .putLong(timestamp) + .put(rawValue) + .array(); + return new ConsumerRecord<>( + record.topic(), + record.partition(), + record.offset(), + timestamp, + record.timestampType(), + record.serializedKeySize(), + record.serializedValueSize(), + record.key(), + recordValue, + record.headers(), + record.leaderEpoch() + ); + }; + + // privatize the constructor so the class cannot be instantiated (only used for its static members) + private RecordConverters() {} + + public static RecordConverter rawValueToTimestampedValue() { + return RAW_TO_TIMESTAMED_INSTANCE; + } + + public static RecordConverter identity() { + return IDENTITY_INSTANCE; + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/RocksDBGenericOptionsToDbOptionsColumnFamilyOptionsAdapter.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/RocksDBGenericOptionsToDbOptionsColumnFamilyOptionsAdapter.java new file mode 100644 index 0000000..2f7a70a --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/RocksDBGenericOptionsToDbOptionsColumnFamilyOptionsAdapter.java @@ -0,0 +1,1687 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.state.internals; + +import org.rocksdb.AbstractCompactionFilter; +import org.rocksdb.AbstractCompactionFilterFactory; +import org.rocksdb.AbstractComparator; +import org.rocksdb.AbstractEventListener; +import org.rocksdb.AbstractSlice; +import org.rocksdb.AbstractWalFilter; +import org.rocksdb.AccessHint; +import org.rocksdb.BuiltinComparator; +import org.rocksdb.Cache; +import org.rocksdb.ColumnFamilyOptions; +import org.rocksdb.CompactionOptionsFIFO; +import org.rocksdb.CompactionOptionsUniversal; +import org.rocksdb.CompactionPriority; +import org.rocksdb.CompactionStyle; +import org.rocksdb.CompressionOptions; +import org.rocksdb.CompressionType; +import org.rocksdb.ConcurrentTaskLimiter; +import org.rocksdb.DBOptions; +import org.rocksdb.DbPath; +import org.rocksdb.Env; +import org.rocksdb.InfoLogLevel; +import org.rocksdb.MemTableConfig; +import org.rocksdb.MergeOperator; +import org.rocksdb.Options; +import org.rocksdb.RateLimiter; +import org.rocksdb.SstFileManager; +import org.rocksdb.SstPartitionerFactory; +import org.rocksdb.Statistics; +import org.rocksdb.TableFormatConfig; +import org.rocksdb.WALRecoveryMode; +import org.rocksdb.WalFilter; +import org.rocksdb.WriteBufferManager; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Collection; +import java.util.List; + +/** + * The generic {@link Options} class allows users to set all configs on one object if only default column family + * is used. 
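// [Editor's note] Illustrative sketch, not part of the patch: rawValueToTimestampedValue above
// prepends the 8-byte record timestamp to the raw value, which is the value layout used by
// timestamped state stores. Decoding simply reverses that layout; the sample value and
// timestamp are made up for the example.
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;

public class TimestampedValueLayoutExample {
    public static void main(final String[] args) {
        final byte[] rawValue = "some-value".getBytes(StandardCharsets.UTF_8);
        final long timestamp = 1_600_000_000_000L;

        // encode: [ 8-byte big-endian timestamp | raw value bytes ]
        final byte[] timestampedValue = ByteBuffer.allocate(8 + rawValue.length)
            .putLong(timestamp)
            .put(rawValue)
            .array();

        // decode
        final ByteBuffer buffer = ByteBuffer.wrap(timestampedValue);
        final long decodedTimestamp = buffer.getLong();
        final byte[] decodedValue = new byte[buffer.remaining()];
        buffer.get(decodedValue);

        System.out.println(decodedTimestamp);                                  // 1600000000000
        System.out.println(new String(decodedValue, StandardCharsets.UTF_8));  // some-value
    }
}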
Because we use multiple column families, we need to use {@link DBOptions} and {@link ColumnFamilyOptions} + * that cover a part of all options each. + * + * This class do the translation between generic {@link Options} into {@link DBOptions} and {@link ColumnFamilyOptions}. + */ +public class RocksDBGenericOptionsToDbOptionsColumnFamilyOptionsAdapter extends Options { + + private static final Logger log = LoggerFactory.getLogger(RocksDBGenericOptionsToDbOptionsColumnFamilyOptionsAdapter.class); + + private final DBOptions dbOptions; + private final ColumnFamilyOptions columnFamilyOptions; + + RocksDBGenericOptionsToDbOptionsColumnFamilyOptionsAdapter(final DBOptions dbOptions, + final ColumnFamilyOptions columnFamilyOptions) { + this.dbOptions = dbOptions; + this.columnFamilyOptions = columnFamilyOptions; + } + + @Override + public Options setIncreaseParallelism(final int totalThreads) { + dbOptions.setIncreaseParallelism(totalThreads); + return this; + } + + @Override + public Options setCreateIfMissing(final boolean flag) { + dbOptions.setCreateIfMissing(flag); + return this; + } + + @Override + public Options setCreateMissingColumnFamilies(final boolean flag) { + dbOptions.setCreateMissingColumnFamilies(flag); + return this; + } + + @Override + public Options setEnv(final Env env) { + dbOptions.setEnv(env); + return this; + } + + @Override + public Env getEnv() { + return dbOptions.getEnv(); + } + + @Override + public Options prepareForBulkLoad() { + super.prepareForBulkLoad(); + return this; + } + + @Override + public boolean createIfMissing() { + return dbOptions.createIfMissing(); + } + + @Override + public boolean createMissingColumnFamilies() { + return dbOptions.createMissingColumnFamilies(); + } + + @Override + public Options optimizeForSmallDb() { + dbOptions.optimizeForSmallDb(); + columnFamilyOptions.optimizeForSmallDb(); + return this; + } + + @Override + public Options optimizeForPointLookup(final long blockCacheSizeMb) { + columnFamilyOptions.optimizeForPointLookup(blockCacheSizeMb); + return this; + } + + @Override + public Options optimizeLevelStyleCompaction() { + columnFamilyOptions.optimizeLevelStyleCompaction(); + return this; + } + + @Override + public Options optimizeLevelStyleCompaction(final long memtableMemoryBudget) { + columnFamilyOptions.optimizeLevelStyleCompaction(memtableMemoryBudget); + return this; + } + + @Override + public Options optimizeUniversalStyleCompaction() { + columnFamilyOptions.optimizeUniversalStyleCompaction(); + return this; + } + + @Override + public Options optimizeUniversalStyleCompaction(final long memtableMemoryBudget) { + columnFamilyOptions.optimizeUniversalStyleCompaction(memtableMemoryBudget); + return this; + } + + @Override + public Options setComparator(final BuiltinComparator builtinComparator) { + columnFamilyOptions.setComparator(builtinComparator); + return this; + } + + @Override + public Options setComparator(final AbstractComparator comparator) { + columnFamilyOptions.setComparator(comparator); + return this; + } + + @Override + public Options setMergeOperatorName(final String name) { + columnFamilyOptions.setMergeOperatorName(name); + return this; + } + + @Override + public Options setMergeOperator(final MergeOperator mergeOperator) { + columnFamilyOptions.setMergeOperator(mergeOperator); + return this; + } + + @Override + public Options setWriteBufferSize(final long writeBufferSize) { + columnFamilyOptions.setWriteBufferSize(writeBufferSize); + return this; + } + + @Override + public long writeBufferSize() { + return 
columnFamilyOptions.writeBufferSize(); + } + + @Override + public Options setMaxWriteBufferNumber(final int maxWriteBufferNumber) { + columnFamilyOptions.setMaxWriteBufferNumber(maxWriteBufferNumber); + return this; + } + + @Override + public int maxWriteBufferNumber() { + return columnFamilyOptions.maxWriteBufferNumber(); + } + + @Override + public boolean errorIfExists() { + return dbOptions.errorIfExists(); + } + + @Override + public Options setErrorIfExists(final boolean errorIfExists) { + dbOptions.setErrorIfExists(errorIfExists); + return this; + } + + @Override + public boolean paranoidChecks() { + final boolean columnFamilyParanoidFileChecks = columnFamilyOptions.paranoidFileChecks(); + final boolean dbOptionsParanoidChecks = dbOptions.paranoidChecks(); + + if (columnFamilyParanoidFileChecks != dbOptionsParanoidChecks) { + throw new IllegalStateException("Config for paranoid checks for RockDB and ColumnFamilies should be the same."); + } + + return dbOptionsParanoidChecks; + } + + @Override + public Options setParanoidChecks(final boolean paranoidChecks) { + columnFamilyOptions.paranoidFileChecks(); + dbOptions.setParanoidChecks(paranoidChecks); + return this; + } + + @Override + public int maxOpenFiles() { + return dbOptions.maxOpenFiles(); + } + + @Override + public Options setMaxFileOpeningThreads(final int maxFileOpeningThreads) { + dbOptions.setMaxFileOpeningThreads(maxFileOpeningThreads); + return this; + } + + @Override + public int maxFileOpeningThreads() { + return dbOptions.maxFileOpeningThreads(); + } + + @Override + public Options setMaxTotalWalSize(final long maxTotalWalSize) { + logIgnoreWalOption("maxTotalWalSize"); + return this; + } + + @Override + public long maxTotalWalSize() { + return dbOptions.maxTotalWalSize(); + } + + @Override + public Options setMaxOpenFiles(final int maxOpenFiles) { + dbOptions.setMaxOpenFiles(maxOpenFiles); + return this; + } + + @Override + public boolean useFsync() { + return dbOptions.useFsync(); + } + + @Override + public Options setUseFsync(final boolean useFsync) { + dbOptions.setUseFsync(useFsync); + return this; + } + + @Override + public Options setDbPaths(final Collection dbPaths) { + dbOptions.setDbPaths(dbPaths); + return this; + } + + @Override + public List dbPaths() { + return dbOptions.dbPaths(); + } + + @Override + public String dbLogDir() { + return dbOptions.dbLogDir(); + } + + @Override + public Options setDbLogDir(final String dbLogDir) { + dbOptions.setDbLogDir(dbLogDir); + return this; + } + + @Override + public String walDir() { + return dbOptions.walDir(); + } + + @Override + public Options setWalDir(final String walDir) { + logIgnoreWalOption("walDir"); + return this; + } + + @Override + public long deleteObsoleteFilesPeriodMicros() { + return dbOptions.deleteObsoleteFilesPeriodMicros(); + } + + @Override + public Options setDeleteObsoleteFilesPeriodMicros(final long micros) { + dbOptions.setDeleteObsoleteFilesPeriodMicros(micros); + return this; + } + + @Deprecated + @Override + public int maxBackgroundCompactions() { + return dbOptions.maxBackgroundCompactions(); + } + + @Override + public Options setStatistics(final Statistics statistics) { + dbOptions.setStatistics(statistics); + return this; + } + + @Override + public Statistics statistics() { + return dbOptions.statistics(); + } + + @Deprecated + @Override + public void setBaseBackgroundCompactions(final int baseBackgroundCompactions) { + dbOptions.setBaseBackgroundCompactions(baseBackgroundCompactions); + } + + @Override + public int 
baseBackgroundCompactions() { + return dbOptions.baseBackgroundCompactions(); + } + + @Deprecated + @Override + public Options setMaxBackgroundCompactions(final int maxBackgroundCompactions) { + dbOptions.setMaxBackgroundCompactions(maxBackgroundCompactions); + return this; + } + + @Override + public Options setMaxSubcompactions(final int maxSubcompactions) { + dbOptions.setMaxSubcompactions(maxSubcompactions); + return this; + } + + @Override + public int maxSubcompactions() { + return dbOptions.maxSubcompactions(); + } + + @Deprecated + @Override + public int maxBackgroundFlushes() { + return dbOptions.maxBackgroundFlushes(); + } + + @Deprecated + @Override + public Options setMaxBackgroundFlushes(final int maxBackgroundFlushes) { + dbOptions.setMaxBackgroundFlushes(maxBackgroundFlushes); + return this; + } + + @Override + public int maxBackgroundJobs() { + return dbOptions.maxBackgroundJobs(); + } + + @Override + public Options setMaxBackgroundJobs(final int maxBackgroundJobs) { + dbOptions.setMaxBackgroundJobs(maxBackgroundJobs); + return this; + } + + @Override + public long maxLogFileSize() { + return dbOptions.maxLogFileSize(); + } + + @Override + public Options setMaxLogFileSize(final long maxLogFileSize) { + dbOptions.setMaxLogFileSize(maxLogFileSize); + return this; + } + + @Override + public long logFileTimeToRoll() { + return dbOptions.logFileTimeToRoll(); + } + + @Override + public Options setLogFileTimeToRoll(final long logFileTimeToRoll) { + dbOptions.setLogFileTimeToRoll(logFileTimeToRoll); + return this; + } + + @Override + public long keepLogFileNum() { + return dbOptions.keepLogFileNum(); + } + + @Override + public Options setKeepLogFileNum(final long keepLogFileNum) { + dbOptions.setKeepLogFileNum(keepLogFileNum); + return this; + } + + @Override + public Options setRecycleLogFileNum(final long recycleLogFileNum) { + dbOptions.setRecycleLogFileNum(recycleLogFileNum); + return this; + } + + @Override + public long recycleLogFileNum() { + return dbOptions.recycleLogFileNum(); + } + + @Override + public long maxManifestFileSize() { + return dbOptions.maxManifestFileSize(); + } + + @Override + public Options setMaxManifestFileSize(final long maxManifestFileSize) { + dbOptions.setMaxManifestFileSize(maxManifestFileSize); + return this; + } + + @Override + public Options setMaxTableFilesSizeFIFO(final long maxTableFilesSize) { + columnFamilyOptions.setMaxTableFilesSizeFIFO(maxTableFilesSize); + return this; + } + + @Override + public long maxTableFilesSizeFIFO() { + return columnFamilyOptions.maxTableFilesSizeFIFO(); + } + + @Override + public int tableCacheNumshardbits() { + return dbOptions.tableCacheNumshardbits(); + } + + @Override + public Options setTableCacheNumshardbits(final int tableCacheNumshardbits) { + dbOptions.setTableCacheNumshardbits(tableCacheNumshardbits); + return this; + } + + @Override + public long walTtlSeconds() { + return dbOptions.walTtlSeconds(); + } + + @Override + public Options setWalTtlSeconds(final long walTtlSeconds) { + logIgnoreWalOption("walTtlSeconds"); + return this; + } + + @Override + public long walSizeLimitMB() { + return dbOptions.walSizeLimitMB(); + } + + @Override + public Options setWalSizeLimitMB(final long sizeLimitMB) { + logIgnoreWalOption("walSizeLimitMB"); + return this; + } + + @Override + public long manifestPreallocationSize() { + return dbOptions.manifestPreallocationSize(); + } + + @Override + public Options setManifestPreallocationSize(final long size) { + dbOptions.setManifestPreallocationSize(size); + return this; + 
} + + @Override + public Options setUseDirectReads(final boolean useDirectReads) { + dbOptions.setUseDirectReads(useDirectReads); + return this; + } + + @Override + public boolean useDirectReads() { + return dbOptions.useDirectReads(); + } + + @Override + public Options setUseDirectIoForFlushAndCompaction(final boolean useDirectIoForFlushAndCompaction) { + dbOptions.setUseDirectIoForFlushAndCompaction(useDirectIoForFlushAndCompaction); + return this; + } + + @Override + public boolean useDirectIoForFlushAndCompaction() { + return dbOptions.useDirectIoForFlushAndCompaction(); + } + + @Override + public Options setAllowFAllocate(final boolean allowFAllocate) { + dbOptions.setAllowFAllocate(allowFAllocate); + return this; + } + + @Override + public boolean allowFAllocate() { + return dbOptions.allowFAllocate(); + } + + @Override + public boolean allowMmapReads() { + return dbOptions.allowMmapReads(); + } + + @Override + public Options setAllowMmapReads(final boolean allowMmapReads) { + dbOptions.setAllowMmapReads(allowMmapReads); + return this; + } + + @Override + public boolean allowMmapWrites() { + return dbOptions.allowMmapWrites(); + } + + @Override + public Options setAllowMmapWrites(final boolean allowMmapWrites) { + dbOptions.setAllowMmapWrites(allowMmapWrites); + return this; + } + + @Override + public boolean isFdCloseOnExec() { + return dbOptions.isFdCloseOnExec(); + } + + @Override + public Options setIsFdCloseOnExec(final boolean isFdCloseOnExec) { + dbOptions.setIsFdCloseOnExec(isFdCloseOnExec); + return this; + } + + @Override + public int statsDumpPeriodSec() { + return dbOptions.statsDumpPeriodSec(); + } + + @Override + public Options setStatsDumpPeriodSec(final int statsDumpPeriodSec) { + dbOptions.setStatsDumpPeriodSec(statsDumpPeriodSec); + return this; + } + + @Override + public boolean adviseRandomOnOpen() { + return dbOptions.adviseRandomOnOpen(); + } + + @Override + public Options setAdviseRandomOnOpen(final boolean adviseRandomOnOpen) { + dbOptions.setAdviseRandomOnOpen(adviseRandomOnOpen); + return this; + } + + @Override + public Options setDbWriteBufferSize(final long dbWriteBufferSize) { + dbOptions.setDbWriteBufferSize(dbWriteBufferSize); + return this; + } + + @Override + public long dbWriteBufferSize() { + return dbOptions.dbWriteBufferSize(); + } + + @Override + public Options setAccessHintOnCompactionStart(final AccessHint accessHint) { + dbOptions.setAccessHintOnCompactionStart(accessHint); + return this; + } + + @Override + public AccessHint accessHintOnCompactionStart() { + return dbOptions.accessHintOnCompactionStart(); + } + + @Override + public Options setNewTableReaderForCompactionInputs(final boolean newTableReaderForCompactionInputs) { + dbOptions.setNewTableReaderForCompactionInputs(newTableReaderForCompactionInputs); + return this; + } + + @Override + public boolean newTableReaderForCompactionInputs() { + return dbOptions.newTableReaderForCompactionInputs(); + } + + @Override + public Options setCompactionReadaheadSize(final long compactionReadaheadSize) { + dbOptions.setCompactionReadaheadSize(compactionReadaheadSize); + return this; + } + + @Override + public long compactionReadaheadSize() { + return dbOptions.compactionReadaheadSize(); + } + + @Override + public Options setRandomAccessMaxBufferSize(final long randomAccessMaxBufferSize) { + dbOptions.setRandomAccessMaxBufferSize(randomAccessMaxBufferSize); + return this; + } + + @Override + public long randomAccessMaxBufferSize() { + return dbOptions.randomAccessMaxBufferSize(); + } + + @Override 
+ public Options setWritableFileMaxBufferSize(final long writableFileMaxBufferSize) { + dbOptions.setWritableFileMaxBufferSize(writableFileMaxBufferSize); + return this; + } + + @Override + public long writableFileMaxBufferSize() { + return dbOptions.writableFileMaxBufferSize(); + } + + @Override + public boolean useAdaptiveMutex() { + return dbOptions.useAdaptiveMutex(); + } + + @Override + public Options setUseAdaptiveMutex(final boolean useAdaptiveMutex) { + dbOptions.setUseAdaptiveMutex(useAdaptiveMutex); + return this; + } + + @Override + public long bytesPerSync() { + return dbOptions.bytesPerSync(); + } + + @Override + public Options setBytesPerSync(final long bytesPerSync) { + dbOptions.setBytesPerSync(bytesPerSync); + return this; + } + + @Override + public Options setWalBytesPerSync(final long walBytesPerSync) { + logIgnoreWalOption("walBytesPerSync"); + return this; + } + + @Override + public long walBytesPerSync() { + return dbOptions.walBytesPerSync(); + } + + @Override + public Options setEnableThreadTracking(final boolean enableThreadTracking) { + dbOptions.setEnableThreadTracking(enableThreadTracking); + return this; + } + + @Override + public boolean enableThreadTracking() { + return dbOptions.enableThreadTracking(); + } + + @Override + public Options setDelayedWriteRate(final long delayedWriteRate) { + dbOptions.setDelayedWriteRate(delayedWriteRate); + return this; + } + + @Override + public long delayedWriteRate() { + return dbOptions.delayedWriteRate(); + } + + @Override + public Options setAllowConcurrentMemtableWrite(final boolean allowConcurrentMemtableWrite) { + dbOptions.setAllowConcurrentMemtableWrite(allowConcurrentMemtableWrite); + return this; + } + + @Override + public boolean allowConcurrentMemtableWrite() { + return dbOptions.allowConcurrentMemtableWrite(); + } + + @Override + public Options setEnableWriteThreadAdaptiveYield(final boolean enableWriteThreadAdaptiveYield) { + dbOptions.setEnableWriteThreadAdaptiveYield(enableWriteThreadAdaptiveYield); + return this; + } + + @Override + public boolean enableWriteThreadAdaptiveYield() { + return dbOptions.enableWriteThreadAdaptiveYield(); + } + + @Override + public Options setWriteThreadMaxYieldUsec(final long writeThreadMaxYieldUsec) { + dbOptions.setWriteThreadMaxYieldUsec(writeThreadMaxYieldUsec); + return this; + } + + @Override + public long writeThreadMaxYieldUsec() { + return dbOptions.writeThreadMaxYieldUsec(); + } + + @Override + public Options setWriteThreadSlowYieldUsec(final long writeThreadSlowYieldUsec) { + dbOptions.setWriteThreadSlowYieldUsec(writeThreadSlowYieldUsec); + return this; + } + + @Override + public long writeThreadSlowYieldUsec() { + return dbOptions.writeThreadSlowYieldUsec(); + } + + @Override + public Options setSkipStatsUpdateOnDbOpen(final boolean skipStatsUpdateOnDbOpen) { + dbOptions.setSkipStatsUpdateOnDbOpen(skipStatsUpdateOnDbOpen); + return this; + } + + @Override + public boolean skipStatsUpdateOnDbOpen() { + return dbOptions.skipStatsUpdateOnDbOpen(); + } + + @Override + public Options setWalRecoveryMode(final WALRecoveryMode walRecoveryMode) { + logIgnoreWalOption("walRecoveryMode"); + return this; + } + + @Override + public WALRecoveryMode walRecoveryMode() { + return dbOptions.walRecoveryMode(); + } + + @Override + public Options setAllow2pc(final boolean allow2pc) { + dbOptions.setAllow2pc(allow2pc); + return this; + } + + @Override + public boolean allow2pc() { + return dbOptions.allow2pc(); + } + + @Override + public Options setRowCache(final Cache rowCache) { + 
dbOptions.setRowCache(rowCache); + return this; + } + + @Override + public Cache rowCache() { + return dbOptions.rowCache(); + } + + @Override + public Options setFailIfOptionsFileError(final boolean failIfOptionsFileError) { + dbOptions.setFailIfOptionsFileError(failIfOptionsFileError); + return this; + } + + @Override + public boolean failIfOptionsFileError() { + return dbOptions.failIfOptionsFileError(); + } + + @Override + public Options setDumpMallocStats(final boolean dumpMallocStats) { + dbOptions.setDumpMallocStats(dumpMallocStats); + return this; + } + + @Override + public boolean dumpMallocStats() { + return dbOptions.dumpMallocStats(); + } + + @Override + public Options setAvoidFlushDuringRecovery(final boolean avoidFlushDuringRecovery) { + dbOptions.setAvoidFlushDuringRecovery(avoidFlushDuringRecovery); + return this; + } + + @Override + public boolean avoidFlushDuringRecovery() { + return dbOptions.avoidFlushDuringRecovery(); + } + + @Override + public Options setAvoidFlushDuringShutdown(final boolean avoidFlushDuringShutdown) { + dbOptions.setAvoidFlushDuringShutdown(avoidFlushDuringShutdown); + return this; + } + + @Override + public boolean avoidFlushDuringShutdown() { + return dbOptions.avoidFlushDuringShutdown(); + } + + @Override + public MemTableConfig memTableConfig() { + return columnFamilyOptions.memTableConfig(); + } + + @Override + public Options setMemTableConfig(final MemTableConfig config) { + columnFamilyOptions.setMemTableConfig(config); + return this; + } + + @Override + public Options setRateLimiter(final RateLimiter rateLimiter) { + dbOptions.setRateLimiter(rateLimiter); + return this; + } + + @Override + public Options setSstFileManager(final SstFileManager sstFileManager) { + dbOptions.setSstFileManager(sstFileManager); + return this; + } + + @Override + public Options setLogger(final org.rocksdb.Logger logger) { + dbOptions.setLogger(logger); + return this; + } + + @Override + public Options setInfoLogLevel(final InfoLogLevel infoLogLevel) { + dbOptions.setInfoLogLevel(infoLogLevel); + return this; + } + + @Override + public InfoLogLevel infoLogLevel() { + return dbOptions.infoLogLevel(); + } + + @Override + public String memTableFactoryName() { + return columnFamilyOptions.memTableFactoryName(); + } + + @Override + public TableFormatConfig tableFormatConfig() { + return columnFamilyOptions.tableFormatConfig(); + } + + @Override + public Options setTableFormatConfig(final TableFormatConfig config) { + columnFamilyOptions.setTableFormatConfig(config); + return this; + } + + @Override + public String tableFactoryName() { + return columnFamilyOptions.tableFactoryName(); + } + + @Override + public Options useFixedLengthPrefixExtractor(final int n) { + columnFamilyOptions.useFixedLengthPrefixExtractor(n); + return this; + } + + @Override + public Options useCappedPrefixExtractor(final int n) { + columnFamilyOptions.useCappedPrefixExtractor(n); + return this; + } + + @Override + public CompressionType compressionType() { + return columnFamilyOptions.compressionType(); + } + + @Override + public Options setCompressionPerLevel(final List compressionLevels) { + columnFamilyOptions.setCompressionPerLevel(compressionLevels); + return this; + } + + @Override + public List compressionPerLevel() { + return columnFamilyOptions.compressionPerLevel(); + } + + @Override + public Options setCompressionType(final CompressionType compressionType) { + columnFamilyOptions.setCompressionType(compressionType); + return this; + } + + + @Override + public Options 
setBottommostCompressionType(final CompressionType bottommostCompressionType) { + columnFamilyOptions.setBottommostCompressionType(bottommostCompressionType); + return this; + } + + @Override + public CompressionType bottommostCompressionType() { + return columnFamilyOptions.bottommostCompressionType(); + } + + @Override + public Options setCompressionOptions(final CompressionOptions compressionOptions) { + columnFamilyOptions.setCompressionOptions(compressionOptions); + return this; + } + + @Override + public CompressionOptions compressionOptions() { + return columnFamilyOptions.compressionOptions(); + } + + @Override + public CompactionStyle compactionStyle() { + return columnFamilyOptions.compactionStyle(); + } + + @Override + public Options setCompactionStyle(final CompactionStyle compactionStyle) { + columnFamilyOptions.setCompactionStyle(compactionStyle); + return this; + } + + @Override + public int numLevels() { + return columnFamilyOptions.numLevels(); + } + + @Override + public Options setNumLevels(final int numLevels) { + columnFamilyOptions.setNumLevels(numLevels); + return this; + } + + @Override + public int levelZeroFileNumCompactionTrigger() { + return columnFamilyOptions.levelZeroFileNumCompactionTrigger(); + } + + @Override + public Options setLevelZeroFileNumCompactionTrigger(final int numFiles) { + columnFamilyOptions.setLevelZeroFileNumCompactionTrigger(numFiles); + return this; + } + + @Override + public int levelZeroSlowdownWritesTrigger() { + return columnFamilyOptions.levelZeroSlowdownWritesTrigger(); + } + + @Override + public Options setLevelZeroSlowdownWritesTrigger(final int numFiles) { + columnFamilyOptions.setLevelZeroSlowdownWritesTrigger(numFiles); + return this; + } + + @Override + public int levelZeroStopWritesTrigger() { + return columnFamilyOptions.levelZeroStopWritesTrigger(); + } + + @Override + public Options setLevelZeroStopWritesTrigger(final int numFiles) { + columnFamilyOptions.setLevelZeroStopWritesTrigger(numFiles); + return this; + } + + @Override + public long targetFileSizeBase() { + return columnFamilyOptions.targetFileSizeBase(); + } + + @Override + public Options setTargetFileSizeBase(final long targetFileSizeBase) { + columnFamilyOptions.setTargetFileSizeBase(targetFileSizeBase); + return this; + } + + @Override + public int targetFileSizeMultiplier() { + return columnFamilyOptions.targetFileSizeMultiplier(); + } + + @Override + public Options setTargetFileSizeMultiplier(final int multiplier) { + columnFamilyOptions.setTargetFileSizeMultiplier(multiplier); + return this; + } + + @Override + public Options setMaxBytesForLevelBase(final long maxBytesForLevelBase) { + columnFamilyOptions.setMaxBytesForLevelBase(maxBytesForLevelBase); + return this; + } + + @Override + public long maxBytesForLevelBase() { + return columnFamilyOptions.maxBytesForLevelBase(); + } + + @Override + public Options setLevelCompactionDynamicLevelBytes(final boolean enableLevelCompactionDynamicLevelBytes) { + columnFamilyOptions.setLevelCompactionDynamicLevelBytes(enableLevelCompactionDynamicLevelBytes); + return this; + } + + @Override + public boolean levelCompactionDynamicLevelBytes() { + return columnFamilyOptions.levelCompactionDynamicLevelBytes(); + } + + @Override + public double maxBytesForLevelMultiplier() { + return columnFamilyOptions.maxBytesForLevelMultiplier(); + } + + @Override + public Options setMaxBytesForLevelMultiplier(final double multiplier) { + columnFamilyOptions.setMaxBytesForLevelMultiplier(multiplier); + return this; + } + + @Override + 
public long maxCompactionBytes() { + return columnFamilyOptions.maxCompactionBytes(); + } + + @Override + public Options setMaxCompactionBytes(final long maxCompactionBytes) { + columnFamilyOptions.setMaxCompactionBytes(maxCompactionBytes); + return this; + } + + @Override + public long arenaBlockSize() { + return columnFamilyOptions.arenaBlockSize(); + } + + @Override + public Options setArenaBlockSize(final long arenaBlockSize) { + columnFamilyOptions.setArenaBlockSize(arenaBlockSize); + return this; + } + + @Override + public boolean disableAutoCompactions() { + return columnFamilyOptions.disableAutoCompactions(); + } + + @Override + public Options setDisableAutoCompactions(final boolean disableAutoCompactions) { + columnFamilyOptions.setDisableAutoCompactions(disableAutoCompactions); + return this; + } + + @Override + public long maxSequentialSkipInIterations() { + return columnFamilyOptions.maxSequentialSkipInIterations(); + } + + @Override + public Options setMaxSequentialSkipInIterations(final long maxSequentialSkipInIterations) { + columnFamilyOptions.setMaxSequentialSkipInIterations(maxSequentialSkipInIterations); + return this; + } + + @Override + public boolean inplaceUpdateSupport() { + return columnFamilyOptions.inplaceUpdateSupport(); + } + + @Override + public Options setInplaceUpdateSupport(final boolean inplaceUpdateSupport) { + columnFamilyOptions.setInplaceUpdateSupport(inplaceUpdateSupport); + return this; + } + + @Override + public long inplaceUpdateNumLocks() { + return columnFamilyOptions.inplaceUpdateNumLocks(); + } + + @Override + public Options setInplaceUpdateNumLocks(final long inplaceUpdateNumLocks) { + columnFamilyOptions.setInplaceUpdateNumLocks(inplaceUpdateNumLocks); + return this; + } + + @Override + public double memtablePrefixBloomSizeRatio() { + return columnFamilyOptions.memtablePrefixBloomSizeRatio(); + } + + @Override + public Options setMemtablePrefixBloomSizeRatio(final double memtablePrefixBloomSizeRatio) { + columnFamilyOptions.setMemtablePrefixBloomSizeRatio(memtablePrefixBloomSizeRatio); + return this; + } + + @Override + public int bloomLocality() { + return columnFamilyOptions.bloomLocality(); + } + + @Override + public Options setBloomLocality(final int bloomLocality) { + columnFamilyOptions.setBloomLocality(bloomLocality); + return this; + } + + @Override + public long maxSuccessiveMerges() { + return columnFamilyOptions.maxSuccessiveMerges(); + } + + @Override + public Options setMaxSuccessiveMerges(final long maxSuccessiveMerges) { + columnFamilyOptions.setMaxSuccessiveMerges(maxSuccessiveMerges); + return this; + } + + @Override + public int minWriteBufferNumberToMerge() { + return columnFamilyOptions.minWriteBufferNumberToMerge(); + } + + @Override + public Options setMinWriteBufferNumberToMerge(final int minWriteBufferNumberToMerge) { + columnFamilyOptions.setMinWriteBufferNumberToMerge(minWriteBufferNumberToMerge); + return this; + } + + @Override + public Options setOptimizeFiltersForHits(final boolean optimizeFiltersForHits) { + columnFamilyOptions.setOptimizeFiltersForHits(optimizeFiltersForHits); + return this; + } + + @Override + public boolean optimizeFiltersForHits() { + return columnFamilyOptions.optimizeFiltersForHits(); + } + + @Override + public Options setMemtableHugePageSize(final long memtableHugePageSize) { + columnFamilyOptions.setMemtableHugePageSize(memtableHugePageSize); + return this; + } + + @Override + public long memtableHugePageSize() { + return columnFamilyOptions.memtableHugePageSize(); + } + + @Override + 
public Options setSoftPendingCompactionBytesLimit(final long softPendingCompactionBytesLimit) { + columnFamilyOptions.setSoftPendingCompactionBytesLimit(softPendingCompactionBytesLimit); + return this; + } + + @Override + public long softPendingCompactionBytesLimit() { + return columnFamilyOptions.softPendingCompactionBytesLimit(); + } + + @Override + public Options setHardPendingCompactionBytesLimit(final long hardPendingCompactionBytesLimit) { + columnFamilyOptions.setHardPendingCompactionBytesLimit(hardPendingCompactionBytesLimit); + return this; + } + + @Override + public long hardPendingCompactionBytesLimit() { + return columnFamilyOptions.hardPendingCompactionBytesLimit(); + } + + @Override + public Options setLevel0FileNumCompactionTrigger(final int level0FileNumCompactionTrigger) { + columnFamilyOptions.setLevel0FileNumCompactionTrigger(level0FileNumCompactionTrigger); + return this; + } + + @Override + public int level0FileNumCompactionTrigger() { + return columnFamilyOptions.level0FileNumCompactionTrigger(); + } + + @Override + public Options setLevel0SlowdownWritesTrigger(final int level0SlowdownWritesTrigger) { + columnFamilyOptions.setLevel0SlowdownWritesTrigger(level0SlowdownWritesTrigger); + return this; + } + + @Override + public int level0SlowdownWritesTrigger() { + return columnFamilyOptions.level0SlowdownWritesTrigger(); + } + + @Override + public Options setLevel0StopWritesTrigger(final int level0StopWritesTrigger) { + columnFamilyOptions.setLevel0StopWritesTrigger(level0StopWritesTrigger); + return this; + } + + @Override + public int level0StopWritesTrigger() { + return columnFamilyOptions.level0StopWritesTrigger(); + } + + @Override + public Options setMaxBytesForLevelMultiplierAdditional(final int[] maxBytesForLevelMultiplierAdditional) { + columnFamilyOptions.setMaxBytesForLevelMultiplierAdditional(maxBytesForLevelMultiplierAdditional); + return this; + } + + @Override + public int[] maxBytesForLevelMultiplierAdditional() { + return columnFamilyOptions.maxBytesForLevelMultiplierAdditional(); + } + + @Override + public Options setParanoidFileChecks(final boolean paranoidFileChecks) { + columnFamilyOptions.setParanoidFileChecks(paranoidFileChecks); + return this; + } + + @Override + public boolean paranoidFileChecks() { + return columnFamilyOptions.paranoidFileChecks(); + } + + @Override + public Options setMaxWriteBufferNumberToMaintain(final int maxWriteBufferNumberToMaintain) { + columnFamilyOptions.setMaxWriteBufferNumberToMaintain(maxWriteBufferNumberToMaintain); + return this; + } + + @Override + public int maxWriteBufferNumberToMaintain() { + return columnFamilyOptions.maxWriteBufferNumberToMaintain(); + } + + @Override + public Options setCompactionPriority(final CompactionPriority compactionPriority) { + columnFamilyOptions.setCompactionPriority(compactionPriority); + return this; + } + + @Override + public CompactionPriority compactionPriority() { + return columnFamilyOptions.compactionPriority(); + } + + @Override + public Options setReportBgIoStats(final boolean reportBgIoStats) { + columnFamilyOptions.setReportBgIoStats(reportBgIoStats); + return this; + } + + @Override + public boolean reportBgIoStats() { + return columnFamilyOptions.reportBgIoStats(); + } + + @Override + public Options setCompactionOptionsUniversal(final CompactionOptionsUniversal compactionOptionsUniversal) { + columnFamilyOptions.setCompactionOptionsUniversal(compactionOptionsUniversal); + return this; + } + + @Override + public CompactionOptionsUniversal compactionOptionsUniversal() 
{ + return columnFamilyOptions.compactionOptionsUniversal(); + } + + @Override + public Options setCompactionOptionsFIFO(final CompactionOptionsFIFO compactionOptionsFIFO) { + columnFamilyOptions.setCompactionOptionsFIFO(compactionOptionsFIFO); + return this; + } + + @Override + public CompactionOptionsFIFO compactionOptionsFIFO() { + return columnFamilyOptions.compactionOptionsFIFO(); + } + + @Override + public Options setForceConsistencyChecks(final boolean forceConsistencyChecks) { + columnFamilyOptions.setForceConsistencyChecks(forceConsistencyChecks); + return this; + } + + @Override + public boolean forceConsistencyChecks() { + return columnFamilyOptions.forceConsistencyChecks(); + } + + @Override + public Options setWriteBufferManager(final WriteBufferManager writeBufferManager) { + dbOptions.setWriteBufferManager(writeBufferManager); + return this; + } + + @Override + public WriteBufferManager writeBufferManager() { + return dbOptions.writeBufferManager(); + } + + @Override + public Options setMaxWriteBatchGroupSizeBytes(final long maxWriteBatchGroupSizeBytes) { + dbOptions.setMaxWriteBatchGroupSizeBytes(maxWriteBatchGroupSizeBytes); + return this; + } + + @Override + public long maxWriteBatchGroupSizeBytes() { + return dbOptions.maxWriteBatchGroupSizeBytes(); + } + + @Override + public Options oldDefaults(final int majorVersion, final int minorVersion) { + columnFamilyOptions.oldDefaults(majorVersion, minorVersion); + return this; + } + + @Override + public Options optimizeForSmallDb(final Cache cache) { + return super.optimizeForSmallDb(cache); + } + + @Override + public AbstractCompactionFilter> compactionFilter() { + return columnFamilyOptions.compactionFilter(); + } + + @Override + public AbstractCompactionFilterFactory> compactionFilterFactory() { + return columnFamilyOptions.compactionFilterFactory(); + } + + @Override + public Options setStatsPersistPeriodSec(final int statsPersistPeriodSec) { + dbOptions.setStatsPersistPeriodSec(statsPersistPeriodSec); + return this; + } + + @Override + public int statsPersistPeriodSec() { + return dbOptions.statsPersistPeriodSec(); + } + + @Override + public Options setStatsHistoryBufferSize(final long statsHistoryBufferSize) { + dbOptions.setStatsHistoryBufferSize(statsHistoryBufferSize); + return this; + } + + @Override + public long statsHistoryBufferSize() { + return dbOptions.statsHistoryBufferSize(); + } + + @Override + public Options setStrictBytesPerSync(final boolean strictBytesPerSync) { + dbOptions.setStrictBytesPerSync(strictBytesPerSync); + return this; + } + + @Override + public boolean strictBytesPerSync() { + return dbOptions.strictBytesPerSync(); + } + + @Override + public Options setListeners(final List listeners) { + dbOptions.setListeners(listeners); + return this; + } + + @Override + public List listeners() { + return dbOptions.listeners(); + } + + @Override + public Options setEnablePipelinedWrite(final boolean enablePipelinedWrite) { + dbOptions.setEnablePipelinedWrite(enablePipelinedWrite); + return this; + } + + @Override + public boolean enablePipelinedWrite() { + return dbOptions.enablePipelinedWrite(); + } + + @Override + public Options setUnorderedWrite(final boolean unorderedWrite) { + dbOptions.setUnorderedWrite(unorderedWrite); + return this; + } + + @Override + public boolean unorderedWrite() { + return dbOptions.unorderedWrite(); + } + + @Override + public Options setSkipCheckingSstFileSizesOnDbOpen(final boolean skipCheckingSstFileSizesOnDbOpen) { + 
dbOptions.setSkipCheckingSstFileSizesOnDbOpen(skipCheckingSstFileSizesOnDbOpen); + return this; + } + + @Override + public boolean skipCheckingSstFileSizesOnDbOpen() { + return dbOptions.skipCheckingSstFileSizesOnDbOpen(); + } + + @Override + public Options setWalFilter(final AbstractWalFilter walFilter) { + logIgnoreWalOption("walFilter"); + return this; + } + + @Override + public WalFilter walFilter() { + return dbOptions.walFilter(); + } + + @Override + public Options setAllowIngestBehind(final boolean allowIngestBehind) { + dbOptions.setAllowIngestBehind(allowIngestBehind); + return this; + } + + @Override + public boolean allowIngestBehind() { + return dbOptions.allowIngestBehind(); + } + + @Override + public Options setPreserveDeletes(final boolean preserveDeletes) { + dbOptions.setPreserveDeletes(preserveDeletes); + return this; + } + + @Override + public boolean preserveDeletes() { + return dbOptions.preserveDeletes(); + } + + @Override + public Options setTwoWriteQueues(final boolean twoWriteQueues) { + dbOptions.setTwoWriteQueues(twoWriteQueues); + return this; + } + + @Override + public boolean twoWriteQueues() { + return dbOptions.twoWriteQueues(); + } + + @Override + public Options setManualWalFlush(final boolean manualWalFlush) { + logIgnoreWalOption("manualWalFlush"); + return this; + } + + @Override + public boolean manualWalFlush() { + return dbOptions.manualWalFlush(); + } + + @Override + public Options setCfPaths(final Collection cfPaths) { + columnFamilyOptions.setCfPaths(cfPaths); + return this; + } + + @Override + public List cfPaths() { + return columnFamilyOptions.cfPaths(); + } + + @Override + public Options setBottommostCompressionOptions(final CompressionOptions bottommostCompressionOptions) { + columnFamilyOptions.setBottommostCompressionOptions(bottommostCompressionOptions); + return this; + } + + @Override + public CompressionOptions bottommostCompressionOptions() { + return columnFamilyOptions.bottommostCompressionOptions(); + } + + @Override + public Options setTtl(final long ttl) { + columnFamilyOptions.setTtl(ttl); + return this; + } + + @Override + public long ttl() { + return columnFamilyOptions.ttl(); + } + + @Override + public Options setAtomicFlush(final boolean atomicFlush) { + dbOptions.setAtomicFlush(atomicFlush); + return this; + } + + @Override + public boolean atomicFlush() { + return dbOptions.atomicFlush(); + } + + @Override + public Options setAvoidUnnecessaryBlockingIO(final boolean avoidUnnecessaryBlockingIO) { + dbOptions.setAvoidUnnecessaryBlockingIO(avoidUnnecessaryBlockingIO); + return this; + } + + @Override + public boolean avoidUnnecessaryBlockingIO() { + return dbOptions.avoidUnnecessaryBlockingIO(); + } + + @Override + public Options setPersistStatsToDisk(final boolean persistStatsToDisk) { + dbOptions.setPersistStatsToDisk(persistStatsToDisk); + return this; + } + + @Override + public boolean persistStatsToDisk() { + return dbOptions.persistStatsToDisk(); + } + + @Override + public Options setWriteDbidToManifest(final boolean writeDbidToManifest) { + dbOptions.setWriteDbidToManifest(writeDbidToManifest); + return this; + } + + @Override + public boolean writeDbidToManifest() { + return dbOptions.writeDbidToManifest(); + } + + @Override + public Options setLogReadaheadSize(final long logReadaheadSize) { + dbOptions.setLogReadaheadSize(logReadaheadSize); + return this; + } + + @Override + public long logReadaheadSize() { + return dbOptions.logReadaheadSize(); + } + + @Override + public Options setBestEffortsRecovery(final boolean 
bestEffortsRecovery) { + dbOptions.setBestEffortsRecovery(bestEffortsRecovery); + return this; + } + + @Override + public boolean bestEffortsRecovery() { + return dbOptions.bestEffortsRecovery(); + } + + @Override + public Options setMaxBgErrorResumeCount(final int maxBgerrorResumeCount) { + dbOptions.setMaxBgErrorResumeCount(maxBgerrorResumeCount); + return this; + } + + @Override + public int maxBgerrorResumeCount() { + return dbOptions.maxBgerrorResumeCount(); + } + + @Override + public Options setBgerrorResumeRetryInterval(final long bgerrorResumeRetryInterval) { + dbOptions.setBgerrorResumeRetryInterval(bgerrorResumeRetryInterval); + return this; + } + + @Override + public long bgerrorResumeRetryInterval() { + return dbOptions.bgerrorResumeRetryInterval(); + } + + @Override + public Options setSstPartitionerFactory(final SstPartitionerFactory sstPartitionerFactory) { + columnFamilyOptions.setSstPartitionerFactory(sstPartitionerFactory); + return this; + } + + @Override + public SstPartitionerFactory sstPartitionerFactory() { + return columnFamilyOptions.sstPartitionerFactory(); + } + + @Override + public Options setCompactionThreadLimiter(final ConcurrentTaskLimiter compactionThreadLimiter) { + columnFamilyOptions.setCompactionThreadLimiter(compactionThreadLimiter); + return this; + } + + @Override + public ConcurrentTaskLimiter compactionThreadLimiter() { + return columnFamilyOptions.compactionThreadLimiter(); + } + + public Options setCompactionFilter(final AbstractCompactionFilter> compactionFilter) { + columnFamilyOptions.setCompactionFilter(compactionFilter); + return this; + } + + public Options setCompactionFilterFactory(final AbstractCompactionFilterFactory> compactionFilterFactory) { + columnFamilyOptions.setCompactionFilterFactory(compactionFilterFactory); + return this; + } + + @Override + public void close() { + // ColumnFamilyOptions should be closed after DBOptions + dbOptions.close(); + columnFamilyOptions.close(); + // close super last since we initialized it first + super.close(); + } + + private void logIgnoreWalOption(final String option) { + log.warn("WAL is explicitly disabled by Streams in RocksDB. Setting option '{}' will be ignored", option); + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/RocksDBRangeIterator.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/RocksDBRangeIterator.java new file mode 100644 index 0000000..21e2201 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/RocksDBRangeIterator.java @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.state.KeyValueIterator; +import org.rocksdb.RocksIterator; + +import java.util.Comparator; +import java.util.Set; + +class RocksDBRangeIterator extends RocksDbIterator { + // RocksDB's JNI interface does not expose getters/setters that allow the + // comparator to be pluggable, and the default is lexicographic, so it's + // safe to just force lexicographic comparator here for now. + private final Comparator comparator = Bytes.BYTES_LEXICO_COMPARATOR; + private final byte[] rawLastKey; + private final boolean forward; + private final boolean toInclusive; + + RocksDBRangeIterator(final String storeName, + final RocksIterator iter, + final Set> openIterators, + final Bytes from, + final Bytes to, + final boolean forward, + final boolean toInclusive) { + super(storeName, iter, openIterators, forward); + this.forward = forward; + this.toInclusive = toInclusive; + if (forward) { + if (from == null) { + iter.seekToFirst(); + } else { + iter.seek(from.get()); + } + rawLastKey = to == null ? null : to.get(); + } else { + if (to == null) { + iter.seekToLast(); + } else { + iter.seekForPrev(to.get()); + } + rawLastKey = from == null ? null : from.get(); + } + } + + @Override + public KeyValue makeNext() { + final KeyValue next = super.makeNext(); + if (next == null) { + return allDone(); + } else if (rawLastKey == null) { + //null means range endpoint is open + return next; + + } else { + if (forward) { + if (comparator.compare(next.key.get(), rawLastKey) < 0) { + return next; + } else if (comparator.compare(next.key.get(), rawLastKey) == 0) { + return toInclusive ? next : allDone(); + } else { + return allDone(); + } + } else { + if (comparator.compare(next.key.get(), rawLastKey) >= 0) { + return next; + } else { + return allDone(); + } + } + } + } +} + diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/RocksDBSegmentedBytesStore.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/RocksDBSegmentedBytesStore.java new file mode 100644 index 0000000..6c72fa6 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/RocksDBSegmentedBytesStore.java @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.state.internals; + +public class RocksDBSegmentedBytesStore extends AbstractRocksDBSegmentedBytesStore { + + RocksDBSegmentedBytesStore(final String name, + final String metricsScope, + final long retention, + final long segmentInterval, + final KeySchema keySchema) { + super(name, metricsScope, keySchema, new KeyValueSegments(name, metricsScope, retention, segmentInterval)); + } +} \ No newline at end of file diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/RocksDBSessionStore.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/RocksDBSessionStore.java new file mode 100644 index 0000000..f5d7108 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/RocksDBSessionStore.java @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.streams.kstream.Windowed; +import org.apache.kafka.streams.state.KeyValueIterator; +import org.apache.kafka.streams.state.SessionStore; + + +public class RocksDBSessionStore + extends WrappedStateStore + implements SessionStore { + + RocksDBSessionStore(final SegmentedBytesStore bytesStore) { + super(bytesStore); + } + + @Override + public KeyValueIterator, byte[]> findSessions(final Bytes key, + final long earliestSessionEndTime, + final long latestSessionStartTime) { + final KeyValueIterator bytesIterator = wrapped().fetch( + key, + earliestSessionEndTime, + latestSessionStartTime + ); + return new WrappedSessionStoreIterator(bytesIterator); + } + + @Override + public KeyValueIterator, byte[]> backwardFindSessions(final Bytes key, + final long earliestSessionEndTime, + final long latestSessionStartTime) { + final KeyValueIterator bytesIterator = wrapped().backwardFetch( + key, + earliestSessionEndTime, + latestSessionStartTime + ); + return new WrappedSessionStoreIterator(bytesIterator); + } + + @Override + public KeyValueIterator, byte[]> findSessions(final Bytes keyFrom, + final Bytes keyTo, + final long earliestSessionEndTime, + final long latestSessionStartTime) { + final KeyValueIterator bytesIterator = wrapped().fetch( + keyFrom, + keyTo, + earliestSessionEndTime, + latestSessionStartTime + ); + return new WrappedSessionStoreIterator(bytesIterator); + } + + @Override + public KeyValueIterator, byte[]> backwardFindSessions(final Bytes keyFrom, + final Bytes keyTo, + final long earliestSessionEndTime, + final long latestSessionStartTime) { + final KeyValueIterator bytesIterator = wrapped().backwardFetch( + keyFrom, + keyTo, + earliestSessionEndTime, + latestSessionStartTime + ); + return new WrappedSessionStoreIterator(bytesIterator); + } + + @Override + public byte[] 
fetchSession(final Bytes key, + final long earliestSessionEndTime, + final long latestSessionStartTime) { + return wrapped().get(SessionKeySchema.toBinary( + key, + earliestSessionEndTime, + latestSessionStartTime + )); + } + + @Override + public KeyValueIterator, byte[]> fetch(final Bytes key) { + return findSessions(key, 0, Long.MAX_VALUE); + } + + @Override + public KeyValueIterator, byte[]> backwardFetch(final Bytes key) { + return backwardFindSessions(key, 0, Long.MAX_VALUE); + } + + @Override + public KeyValueIterator, byte[]> fetch(final Bytes keyFrom, final Bytes keyTo) { + return findSessions(keyFrom, keyTo, 0, Long.MAX_VALUE); + } + + @Override + public KeyValueIterator, byte[]> backwardFetch(final Bytes keyFrom, final Bytes keyTo) { + return backwardFindSessions(keyFrom, keyTo, 0, Long.MAX_VALUE); + } + + @Override + public void remove(final Windowed key) { + wrapped().remove(SessionKeySchema.toBinary(key)); + } + + @Override + public void put(final Windowed sessionKey, final byte[] aggregate) { + wrapped().put(SessionKeySchema.toBinary(sessionKey), aggregate); + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/RocksDBStore.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/RocksDBStore.java new file mode 100644 index 0000000..aa1b1ba --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/RocksDBStore.java @@ -0,0 +1,707 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.common.metrics.Sensor.RecordingLevel; +import org.apache.kafka.common.serialization.Serializer; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.errors.InvalidStateStoreException; +import org.apache.kafka.streams.errors.ProcessorStateException; +import org.apache.kafka.streams.processor.BatchingStateRestoreCallback; +import org.apache.kafka.streams.processor.ProcessorContext; +import org.apache.kafka.streams.processor.StateStore; +import org.apache.kafka.streams.processor.StateStoreContext; +import org.apache.kafka.streams.state.KeyValueIterator; +import org.apache.kafka.streams.state.KeyValueStore; +import org.apache.kafka.streams.state.RocksDBConfigSetter; +import org.apache.kafka.streams.state.internals.metrics.RocksDBMetricsRecorder; +import org.rocksdb.BlockBasedTableConfig; +import org.rocksdb.BloomFilter; +import org.rocksdb.Cache; +import org.rocksdb.ColumnFamilyDescriptor; +import org.rocksdb.ColumnFamilyHandle; +import org.rocksdb.ColumnFamilyOptions; +import org.rocksdb.CompactionStyle; +import org.rocksdb.CompressionType; +import org.rocksdb.DBOptions; +import org.rocksdb.FlushOptions; +import org.rocksdb.InfoLogLevel; +import org.rocksdb.LRUCache; +import org.rocksdb.Options; +import org.rocksdb.RocksDB; +import org.rocksdb.RocksDBException; +import org.rocksdb.RocksIterator; +import org.rocksdb.Statistics; +import org.rocksdb.TableFormatConfig; +import org.rocksdb.WriteBatch; +import org.rocksdb.WriteOptions; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.File; +import java.io.IOException; +import java.nio.file.Files; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Set; + +import static org.apache.kafka.streams.StreamsConfig.METRICS_RECORDING_LEVEL_CONFIG; +import static org.apache.kafka.streams.processor.internals.ProcessorContextUtils.getMetricsImpl; + +/** + * A persistent key-value store based on RocksDB. 
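+ * <p>
+ * The store writes to a single (default) column family with RocksDB's write-ahead log disabled;
+ * durability is handled by Kafka Streams itself (changelog topics and explicit flushes) rather
+ * than by the WAL. Each store gets its own LRU block cache and Bloom filter, and the defaults set
+ * below can be overridden through a {@link RocksDBConfigSetter}. As an editor's illustrative
+ * sketch (the class name and the chosen value are hypothetical, not part of this change):
+ * <pre>{@code
+ * public static class CustomRocksDBConfig implements RocksDBConfigSetter {
+ *     public void setConfig(final String storeName, final Options options, final Map<String, Object> configs) {
+ *         options.setMaxWriteBufferNumber(4); // bump the default of three memtables
+ *     }
+ *     public void close(final String storeName, final Options options) {
+ *         // release any RocksObject instances created in setConfig()
+ *     }
+ * }
+ * // registered via StreamsConfig.ROCKSDB_CONFIG_SETTER_CLASS_CONFIG ("rocksdb.config.setter")
+ * }</pre>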
+ */ +public class RocksDBStore implements KeyValueStore, BatchWritingStore { + private static final Logger log = LoggerFactory.getLogger(RocksDBStore.class); + + private static final CompressionType COMPRESSION_TYPE = CompressionType.NO_COMPRESSION; + private static final CompactionStyle COMPACTION_STYLE = CompactionStyle.UNIVERSAL; + private static final long WRITE_BUFFER_SIZE = 16 * 1024 * 1024L; + private static final long BLOCK_CACHE_SIZE = 50 * 1024 * 1024L; + private static final long BLOCK_SIZE = 4096L; + private static final int MAX_WRITE_BUFFERS = 3; + private static final String DB_FILE_DIR = "rocksdb"; + + final String name; + private final String parentDir; + final Set> openIterators = Collections.synchronizedSet(new HashSet<>()); + + File dbDir; + RocksDB db; + RocksDBAccessor dbAccessor; + + // the following option objects will be created in openDB and closed in the close() method + private RocksDBGenericOptionsToDbOptionsColumnFamilyOptionsAdapter userSpecifiedOptions; + WriteOptions wOptions; + FlushOptions fOptions; + private Cache cache; + private BloomFilter filter; + + private RocksDBConfigSetter configSetter; + private boolean userSpecifiedStatistics = false; + + private final RocksDBMetricsRecorder metricsRecorder; + + protected volatile boolean open = false; + + RocksDBStore(final String name, + final String metricsScope) { + this(name, DB_FILE_DIR, new RocksDBMetricsRecorder(metricsScope, name)); + } + + RocksDBStore(final String name, + final String parentDir, + final RocksDBMetricsRecorder metricsRecorder) { + this.name = name; + this.parentDir = parentDir; + this.metricsRecorder = metricsRecorder; + } + + @SuppressWarnings("unchecked") + void openDB(final Map configs, final File stateDir) { + // initialize the default rocksdb options + + final DBOptions dbOptions = new DBOptions(); + final ColumnFamilyOptions columnFamilyOptions = new ColumnFamilyOptions(); + userSpecifiedOptions = new RocksDBGenericOptionsToDbOptionsColumnFamilyOptionsAdapter(dbOptions, columnFamilyOptions); + + final BlockBasedTableConfigWithAccessibleCache tableConfig = new BlockBasedTableConfigWithAccessibleCache(); + cache = new LRUCache(BLOCK_CACHE_SIZE); + tableConfig.setBlockCache(cache); + tableConfig.setBlockSize(BLOCK_SIZE); + + filter = new BloomFilter(); + tableConfig.setFilterPolicy(filter); + + userSpecifiedOptions.optimizeFiltersForHits(); + userSpecifiedOptions.setTableFormatConfig(tableConfig); + userSpecifiedOptions.setWriteBufferSize(WRITE_BUFFER_SIZE); + userSpecifiedOptions.setCompressionType(COMPRESSION_TYPE); + userSpecifiedOptions.setCompactionStyle(COMPACTION_STYLE); + userSpecifiedOptions.setMaxWriteBufferNumber(MAX_WRITE_BUFFERS); + userSpecifiedOptions.setCreateIfMissing(true); + userSpecifiedOptions.setErrorIfExists(false); + userSpecifiedOptions.setInfoLogLevel(InfoLogLevel.ERROR_LEVEL); + // this is the recommended way to increase parallelism in RocksDb + // note that the current implementation of setIncreaseParallelism affects the number + // of compaction threads but not flush threads (the latter remains one). Also + // the parallelism value needs to be at least two because of the code in + // https://github.com/facebook/rocksdb/blob/62ad0a9b19f0be4cefa70b6b32876e764b7f3c11/util/options.cc#L580 + // subtracts one from the value passed to determine the number of compaction threads + // (this could be a bug in the RocksDB code and their devs have been contacted). 
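+ // Editor's note (not part of the original change): given the behavior described above,
+ // max(availableProcessors, 2) guarantees at least one compaction thread; e.g. with 8 logical
+ // processors the parallelism value is 8, yielding 7 compaction threads plus the single flush thread.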
+ userSpecifiedOptions.setIncreaseParallelism(Math.max(Runtime.getRuntime().availableProcessors(), 2)); + + wOptions = new WriteOptions(); + wOptions.setDisableWAL(true); + + fOptions = new FlushOptions(); + fOptions.setWaitForFlush(true); + + final Class configSetterClass = + (Class) configs.get(StreamsConfig.ROCKSDB_CONFIG_SETTER_CLASS_CONFIG); + + if (configSetterClass != null) { + configSetter = Utils.newInstance(configSetterClass); + configSetter.setConfig(name, userSpecifiedOptions, configs); + } + + dbDir = new File(new File(stateDir, parentDir), name); + + try { + Files.createDirectories(dbDir.getParentFile().toPath()); + Files.createDirectories(dbDir.getAbsoluteFile().toPath()); + } catch (final IOException fatal) { + throw new ProcessorStateException(fatal); + } + + // Setup statistics before the database is opened, otherwise the statistics are not updated + // with the measurements from Rocks DB + maybeSetUpStatistics(configs); + + openRocksDB(dbOptions, columnFamilyOptions); + open = true; + + addValueProvidersToMetricsRecorder(); + } + + private void maybeSetUpStatistics(final Map configs) { + if (userSpecifiedOptions.statistics() != null) { + userSpecifiedStatistics = true; + } + if (!userSpecifiedStatistics && + RecordingLevel.forName((String) configs.get(METRICS_RECORDING_LEVEL_CONFIG)) == RecordingLevel.DEBUG) { + + // metrics recorder will clean up statistics object + final Statistics statistics = new Statistics(); + userSpecifiedOptions.setStatistics(statistics); + } + } + + private void addValueProvidersToMetricsRecorder() { + final TableFormatConfig tableFormatConfig = userSpecifiedOptions.tableFormatConfig(); + final Statistics statistics = userSpecifiedStatistics ? null : userSpecifiedOptions.statistics(); + if (tableFormatConfig instanceof BlockBasedTableConfigWithAccessibleCache) { + final Cache cache = ((BlockBasedTableConfigWithAccessibleCache) tableFormatConfig).blockCache(); + metricsRecorder.addValueProviders(name, db, cache, statistics); + } else if (tableFormatConfig instanceof BlockBasedTableConfig) { + throw new ProcessorStateException("The used block-based table format configuration does not expose the " + + "block cache. Use the BlockBasedTableConfig instance provided by Options#tableFormatConfig() to configure " + + "the block-based table format of RocksDB. 
Do not provide a new instance of BlockBasedTableConfig to " + + "the RocksDB options."); + } else { + metricsRecorder.addValueProviders(name, db, null, statistics); + } + } + + void openRocksDB(final DBOptions dbOptions, + final ColumnFamilyOptions columnFamilyOptions) { + final List columnFamilyDescriptors + = Collections.singletonList(new ColumnFamilyDescriptor(RocksDB.DEFAULT_COLUMN_FAMILY, columnFamilyOptions)); + final List columnFamilies = new ArrayList<>(columnFamilyDescriptors.size()); + + try { + db = RocksDB.open(dbOptions, dbDir.getAbsolutePath(), columnFamilyDescriptors, columnFamilies); + dbAccessor = new SingleColumnFamilyAccessor(columnFamilies.get(0)); + } catch (final RocksDBException e) { + throw new ProcessorStateException("Error opening store " + name + " at location " + dbDir.toString(), e); + } + } + + @Deprecated + @Override + public void init(final ProcessorContext context, + final StateStore root) { + // open the DB dir + metricsRecorder.init(getMetricsImpl(context), context.taskId()); + openDB(context.appConfigs(), context.stateDir()); + + // value getter should always read directly from rocksDB + // since it is only for values that are already flushed + context.register(root, new RocksDBBatchingRestoreCallback(this)); + } + + @Override + public void init(final StateStoreContext context, + final StateStore root) { + // open the DB dir + metricsRecorder.init(getMetricsImpl(context), context.taskId()); + openDB(context.appConfigs(), context.stateDir()); + + // value getter should always read directly from rocksDB + // since it is only for values that are already flushed + context.register(root, new RocksDBBatchingRestoreCallback(this)); + } + + @Override + public String name() { + return name; + } + + @Override + public boolean persistent() { + return true; + } + + @Override + public boolean isOpen() { + return open; + } + + private void validateStoreOpen() { + if (!open) { + throw new InvalidStateStoreException("Store " + name + " is currently closed"); + } + } + + @Override + public synchronized void put(final Bytes key, + final byte[] value) { + Objects.requireNonNull(key, "key cannot be null"); + validateStoreOpen(); + dbAccessor.put(key.get(), value); + } + + @Override + public synchronized byte[] putIfAbsent(final Bytes key, + final byte[] value) { + Objects.requireNonNull(key, "key cannot be null"); + final byte[] originalValue = get(key); + if (originalValue == null) { + put(key, value); + } + return originalValue; + } + + @Override + public void putAll(final List> entries) { + try (final WriteBatch batch = new WriteBatch()) { + dbAccessor.prepareBatch(entries, batch); + write(batch); + } catch (final RocksDBException e) { + throw new ProcessorStateException("Error while batch writing to store " + name, e); + } + } + + @Override + public , P> KeyValueIterator prefixScan(final P prefix, + final PS prefixKeySerializer) { + validateStoreOpen(); + Objects.requireNonNull(prefix, "prefix cannot be null"); + Objects.requireNonNull(prefixKeySerializer, "prefixKeySerializer cannot be null"); + final Bytes prefixBytes = Bytes.wrap(prefixKeySerializer.serialize(null, prefix)); + + final KeyValueIterator rocksDbPrefixSeekIterator = dbAccessor.prefixScan(prefixBytes); + openIterators.add(rocksDbPrefixSeekIterator); + + return rocksDbPrefixSeekIterator; + } + + @Override + public synchronized byte[] get(final Bytes key) { + validateStoreOpen(); + try { + return dbAccessor.get(key.get()); + } catch (final RocksDBException e) { + // String format is happening in wrapping 
stores. So formatted message is thrown from wrapping stores. + throw new ProcessorStateException("Error while getting value for key from store " + name, e); + } + } + + @Override + public synchronized byte[] delete(final Bytes key) { + Objects.requireNonNull(key, "key cannot be null"); + final byte[] oldValue; + try { + oldValue = dbAccessor.getOnly(key.get()); + } catch (final RocksDBException e) { + // String format is happening in wrapping stores. So formatted message is thrown from wrapping stores. + throw new ProcessorStateException("Error while getting value for key from store " + name, e); + } + put(key, null); + return oldValue; + } + + void deleteRange(final Bytes keyFrom, final Bytes keyTo) { + Objects.requireNonNull(keyFrom, "keyFrom cannot be null"); + Objects.requireNonNull(keyTo, "keyTo cannot be null"); + + validateStoreOpen(); + + // End of key is exclusive, so we increment it by 1 byte to make keyTo inclusive + dbAccessor.deleteRange(keyFrom.get(), Bytes.increment(keyTo).get()); + } + + @Override + public synchronized KeyValueIterator range(final Bytes from, + final Bytes to) { + return range(from, to, true); + } + + @Override + public synchronized KeyValueIterator reverseRange(final Bytes from, + final Bytes to) { + return range(from, to, false); + } + + KeyValueIterator range(final Bytes from, + final Bytes to, + final boolean forward) { + if (Objects.nonNull(from) && Objects.nonNull(to) && from.compareTo(to) > 0) { + log.warn("Returning empty iterator for fetch with invalid key range: from > to. " + + "This may be due to range arguments set in the wrong order, " + + "or serdes that don't preserve ordering when lexicographically comparing the serialized bytes. " + + "Note that the built-in numerical serdes do not follow this for negative numbers"); + return KeyValueIterators.emptyIterator(); + } + + validateStoreOpen(); + + final KeyValueIterator rocksDBRangeIterator = dbAccessor.range(from, to, forward); + openIterators.add(rocksDBRangeIterator); + + return rocksDBRangeIterator; + } + + @Override + public synchronized KeyValueIterator all() { + return all(true); + } + + @Override + public KeyValueIterator reverseAll() { + return all(false); + } + + KeyValueIterator all(final boolean forward) { + validateStoreOpen(); + final KeyValueIterator rocksDbIterator = dbAccessor.all(forward); + openIterators.add(rocksDbIterator); + return rocksDbIterator; + } + + /** + * Return an approximate count of key-value mappings in this store. + * + * RocksDB cannot return an exact entry count without doing a + * full scan, so this method relies on the rocksdb.estimate-num-keys + * property to get an approximate count. The returned size also includes + * a count of dirty keys in the store's in-memory cache, which may lead to some + * double-counting of entries and inflate the estimate. + * + * @return an approximate count of key-value mappings in the store. + */ + @Override + public long approximateNumEntries() { + validateStoreOpen(); + final long numEntries; + try { + numEntries = dbAccessor.approximateNumEntries(); + } catch (final RocksDBException e) { + throw new ProcessorStateException("Error fetching property from store " + name, e); + } + if (isOverflowing(numEntries)) { + return Long.MAX_VALUE; + } + return numEntries; + } + + private boolean isOverflowing(final long value) { + // RocksDB returns an unsigned 8-byte integer, which could overflow long + // and manifest as a negative value. 
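+ // Editor's note: e.g. an unsigned estimate of 2^63 is read back as Long.MIN_VALUE, which is why a
+ // negative value here is reported as Long.MAX_VALUE ("more entries than can be represented") by the caller.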
+ return value < 0; + } + + @Override + public synchronized void flush() { + if (db == null) { + return; + } + try { + dbAccessor.flush(); + } catch (final RocksDBException e) { + throw new ProcessorStateException("Error while executing flush from store " + name, e); + } + } + + @Override + public void addToBatch(final KeyValue record, + final WriteBatch batch) throws RocksDBException { + dbAccessor.addToBatch(record.key, record.value, batch); + } + + @Override + public void write(final WriteBatch batch) throws RocksDBException { + db.write(wOptions, batch); + } + + @Override + public synchronized void close() { + if (!open) { + return; + } + + open = false; + closeOpenIterators(); + + if (configSetter != null) { + configSetter.close(name, userSpecifiedOptions); + configSetter = null; + } + + metricsRecorder.removeValueProviders(name); + + // Important: do not rearrange the order in which the below objects are closed! + // Order of closing must follow: ColumnFamilyHandle > RocksDB > DBOptions > ColumnFamilyOptions + dbAccessor.close(); + db.close(); + userSpecifiedOptions.close(); + wOptions.close(); + fOptions.close(); + filter.close(); + cache.close(); + + dbAccessor = null; + userSpecifiedOptions = null; + wOptions = null; + fOptions = null; + db = null; + filter = null; + cache = null; + } + + private void closeOpenIterators() { + final HashSet> iterators; + synchronized (openIterators) { + iterators = new HashSet<>(openIterators); + } + if (iterators.size() != 0) { + log.warn("Closing {} open iterators for store {}", iterators.size(), name); + for (final KeyValueIterator iterator : iterators) { + iterator.close(); + } + } + } + + interface RocksDBAccessor { + + void put(final byte[] key, + final byte[] value); + + void prepareBatch(final List> entries, + final WriteBatch batch) throws RocksDBException; + + byte[] get(final byte[] key) throws RocksDBException; + + /** + * In contrast to get(), we don't migrate the key to new CF. + *

                + * Use for get() within delete() -- no need to migrate, as it's deleted anyway + */ + byte[] getOnly(final byte[] key) throws RocksDBException; + + KeyValueIterator range(final Bytes from, + final Bytes to, + final boolean forward); + + /** + * Deletes keys entries in the range ['from', 'to'], including 'from' and excluding 'to'. + */ + void deleteRange(final byte[] from, + final byte[] to); + + KeyValueIterator all(final boolean forward); + + KeyValueIterator prefixScan(final Bytes prefix); + + long approximateNumEntries() throws RocksDBException; + + void flush() throws RocksDBException; + + void prepareBatchForRestore(final Collection> records, + final WriteBatch batch) throws RocksDBException; + + void addToBatch(final byte[] key, + final byte[] value, + final WriteBatch batch) throws RocksDBException; + + void close(); + } + + class SingleColumnFamilyAccessor implements RocksDBAccessor { + private final ColumnFamilyHandle columnFamily; + + SingleColumnFamilyAccessor(final ColumnFamilyHandle columnFamily) { + this.columnFamily = columnFamily; + } + + @Override + public void put(final byte[] key, + final byte[] value) { + if (value == null) { + try { + db.delete(columnFamily, wOptions, key); + } catch (final RocksDBException e) { + // String format is happening in wrapping stores. So formatted message is thrown from wrapping stores. + throw new ProcessorStateException("Error while removing key from store " + name, e); + } + } else { + try { + db.put(columnFamily, wOptions, key, value); + } catch (final RocksDBException e) { + // String format is happening in wrapping stores. So formatted message is thrown from wrapping stores. + throw new ProcessorStateException("Error while putting key/value into store " + name, e); + } + } + } + + @Override + public void prepareBatch(final List> entries, + final WriteBatch batch) throws RocksDBException { + for (final KeyValue entry : entries) { + Objects.requireNonNull(entry.key, "key cannot be null"); + addToBatch(entry.key.get(), entry.value, batch); + } + } + + @Override + public byte[] get(final byte[] key) throws RocksDBException { + return db.get(columnFamily, key); + } + + @Override + public byte[] getOnly(final byte[] key) throws RocksDBException { + return db.get(columnFamily, key); + } + + @Override + public KeyValueIterator range(final Bytes from, + final Bytes to, + final boolean forward) { + return new RocksDBRangeIterator( + name, + db.newIterator(columnFamily), + openIterators, + from, + to, + forward, + true + ); + } + + @Override + public void deleteRange(final byte[] from, final byte[] to) { + try { + db.deleteRange(columnFamily, wOptions, from, to); + } catch (final RocksDBException e) { + // String format is happening in wrapping stores. So formatted message is thrown from wrapping stores. 
+ throw new ProcessorStateException("Error while removing key from store " + name, e); + } + } + + @Override + public KeyValueIterator all(final boolean forward) { + final RocksIterator innerIterWithTimestamp = db.newIterator(columnFamily); + if (forward) { + innerIterWithTimestamp.seekToFirst(); + } else { + innerIterWithTimestamp.seekToLast(); + } + return new RocksDbIterator(name, innerIterWithTimestamp, openIterators, forward); + } + + @Override + public KeyValueIterator prefixScan(final Bytes prefix) { + final Bytes to = Bytes.increment(prefix); + return new RocksDBRangeIterator( + name, + db.newIterator(columnFamily), + openIterators, + prefix, + to, + true, + false + ); + } + + @Override + public long approximateNumEntries() throws RocksDBException { + return db.getLongProperty(columnFamily, "rocksdb.estimate-num-keys"); + } + + @Override + public void flush() throws RocksDBException { + db.flush(fOptions, columnFamily); + } + + @Override + public void prepareBatchForRestore(final Collection> records, + final WriteBatch batch) throws RocksDBException { + for (final KeyValue record : records) { + addToBatch(record.key, record.value, batch); + } + } + + @Override + public void addToBatch(final byte[] key, + final byte[] value, + final WriteBatch batch) throws RocksDBException { + if (value == null) { + batch.delete(columnFamily, key); + } else { + batch.put(columnFamily, key, value); + } + } + + @Override + public void close() { + columnFamily.close(); + } + } + + // not private for testing + static class RocksDBBatchingRestoreCallback implements BatchingStateRestoreCallback { + + private final RocksDBStore rocksDBStore; + + RocksDBBatchingRestoreCallback(final RocksDBStore rocksDBStore) { + this.rocksDBStore = rocksDBStore; + } + + @Override + public void restoreAll(final Collection> records) { + try (final WriteBatch batch = new WriteBatch()) { + rocksDBStore.dbAccessor.prepareBatchForRestore(records, batch); + rocksDBStore.write(batch); + } catch (final RocksDBException e) { + throw new ProcessorStateException("Error restoring batch to store " + rocksDBStore.name, e); + } + } + } + + // for testing + public Options getOptions() { + return userSpecifiedOptions; + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/RocksDBTimestampedSegmentedBytesStore.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/RocksDBTimestampedSegmentedBytesStore.java new file mode 100644 index 0000000..7fd958c --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/RocksDBTimestampedSegmentedBytesStore.java @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.state.internals; + +public class RocksDBTimestampedSegmentedBytesStore extends AbstractRocksDBSegmentedBytesStore { + + RocksDBTimestampedSegmentedBytesStore(final String name, + final String metricsScope, + final long retention, + final long segmentInterval, + final KeySchema keySchema) { + super(name, metricsScope, keySchema, new TimestampedSegments(name, metricsScope, retention, segmentInterval)); + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/RocksDBTimestampedStore.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/RocksDBTimestampedStore.java new file mode 100644 index 0000000..dd56e44 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/RocksDBTimestampedStore.java @@ -0,0 +1,477 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.common.utils.AbstractIterator; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.errors.InvalidStateStoreException; +import org.apache.kafka.streams.errors.ProcessorStateException; +import org.apache.kafka.streams.state.KeyValueIterator; +import org.apache.kafka.streams.state.TimestampedBytesStore; +import org.apache.kafka.streams.state.internals.metrics.RocksDBMetricsRecorder; +import org.rocksdb.ColumnFamilyDescriptor; +import org.rocksdb.ColumnFamilyHandle; +import org.rocksdb.ColumnFamilyOptions; +import org.rocksdb.DBOptions; +import org.rocksdb.RocksDB; +import org.rocksdb.RocksDBException; +import org.rocksdb.RocksIterator; +import org.rocksdb.WriteBatch; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Comparator; +import java.util.List; +import java.util.NoSuchElementException; +import java.util.Objects; + +import static java.util.Arrays.asList; +import static org.apache.kafka.streams.state.TimestampedBytesStore.convertToTimestampedFormat; + +/** + * A persistent key-(value-timestamp) store based on RocksDB. 
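+ * <p>
+ * Values live in a dedicated "keyValueWithTimestamp" column family. If an existing store still holds
+ * data in the plain default column family, the store opens in upgrade mode and migrates lazily:
+ * reads fall back to the old column family and convert plain values into the timestamped format,
+ * while writes always go to the new column family and delete the corresponding old entry.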
+ */ +public class RocksDBTimestampedStore extends RocksDBStore implements TimestampedBytesStore { + private static final Logger log = LoggerFactory.getLogger(RocksDBTimestampedStore.class); + + RocksDBTimestampedStore(final String name, + final String metricsScope) { + super(name, metricsScope); + } + + RocksDBTimestampedStore(final String name, + final String parentDir, + final RocksDBMetricsRecorder metricsRecorder) { + super(name, parentDir, metricsRecorder); + } + + @Override + void openRocksDB(final DBOptions dbOptions, + final ColumnFamilyOptions columnFamilyOptions) { + final List columnFamilyDescriptors = asList( + new ColumnFamilyDescriptor(RocksDB.DEFAULT_COLUMN_FAMILY, columnFamilyOptions), + new ColumnFamilyDescriptor("keyValueWithTimestamp".getBytes(StandardCharsets.UTF_8), columnFamilyOptions)); + final List columnFamilies = new ArrayList<>(columnFamilyDescriptors.size()); + + try { + db = RocksDB.open(dbOptions, dbDir.getAbsolutePath(), columnFamilyDescriptors, columnFamilies); + setDbAccessor(columnFamilies.get(0), columnFamilies.get(1)); + } catch (final RocksDBException e) { + if ("Column family not found: keyValueWithTimestamp".equals(e.getMessage())) { + try { + db = RocksDB.open(dbOptions, dbDir.getAbsolutePath(), columnFamilyDescriptors.subList(0, 1), columnFamilies); + columnFamilies.add(db.createColumnFamily(columnFamilyDescriptors.get(1))); + } catch (final RocksDBException fatal) { + throw new ProcessorStateException("Error opening store " + name + " at location " + dbDir.toString(), fatal); + } + setDbAccessor(columnFamilies.get(0), columnFamilies.get(1)); + } else { + throw new ProcessorStateException("Error opening store " + name + " at location " + dbDir.toString(), e); + } + } + } + + private void setDbAccessor(final ColumnFamilyHandle noTimestampColumnFamily, + final ColumnFamilyHandle withTimestampColumnFamily) { + final RocksIterator noTimestampsIter = db.newIterator(noTimestampColumnFamily); + noTimestampsIter.seekToFirst(); + if (noTimestampsIter.isValid()) { + log.info("Opening store {} in upgrade mode", name); + dbAccessor = new DualColumnFamilyAccessor(noTimestampColumnFamily, withTimestampColumnFamily); + } else { + log.info("Opening store {} in regular mode", name); + dbAccessor = new SingleColumnFamilyAccessor(withTimestampColumnFamily); + noTimestampColumnFamily.close(); + } + noTimestampsIter.close(); + } + + + private class DualColumnFamilyAccessor implements RocksDBAccessor { + private final ColumnFamilyHandle oldColumnFamily; + private final ColumnFamilyHandle newColumnFamily; + + private DualColumnFamilyAccessor(final ColumnFamilyHandle oldColumnFamily, + final ColumnFamilyHandle newColumnFamily) { + this.oldColumnFamily = oldColumnFamily; + this.newColumnFamily = newColumnFamily; + } + + @Override + public void put(final byte[] key, + final byte[] valueWithTimestamp) { + if (valueWithTimestamp == null) { + try { + db.delete(oldColumnFamily, wOptions, key); + } catch (final RocksDBException e) { + // String format is happening in wrapping stores. So formatted message is thrown from wrapping stores. + throw new ProcessorStateException("Error while removing key from store " + name, e); + } + try { + db.delete(newColumnFamily, wOptions, key); + } catch (final RocksDBException e) { + // String format is happening in wrapping stores. So formatted message is thrown from wrapping stores. 
+ throw new ProcessorStateException("Error while removing key from store " + name, e); + } + } else { + try { + db.delete(oldColumnFamily, wOptions, key); + } catch (final RocksDBException e) { + // String format is happening in wrapping stores. So formatted message is thrown from wrapping stores. + throw new ProcessorStateException("Error while removing key from store " + name, e); + } + try { + db.put(newColumnFamily, wOptions, key, valueWithTimestamp); + } catch (final RocksDBException e) { + // String format is happening in wrapping stores. So formatted message is thrown from wrapping stores. + throw new ProcessorStateException("Error while putting key/value into store " + name, e); + } + } + } + + @Override + public void prepareBatch(final List> entries, + final WriteBatch batch) throws RocksDBException { + for (final KeyValue entry : entries) { + Objects.requireNonNull(entry.key, "key cannot be null"); + addToBatch(entry.key.get(), entry.value, batch); + } + } + + @Override + public byte[] get(final byte[] key) throws RocksDBException { + final byte[] valueWithTimestamp = db.get(newColumnFamily, key); + if (valueWithTimestamp != null) { + return valueWithTimestamp; + } + + final byte[] plainValue = db.get(oldColumnFamily, key); + if (plainValue != null) { + final byte[] valueWithUnknownTimestamp = convertToTimestampedFormat(plainValue); + // this does only work, because the changelog topic contains correct data already + // for other format changes, we cannot take this short cut and can only migrate data + // from old to new store on put() + put(key, valueWithUnknownTimestamp); + return valueWithUnknownTimestamp; + } + + return null; + } + + @Override + public byte[] getOnly(final byte[] key) throws RocksDBException { + final byte[] valueWithTimestamp = db.get(newColumnFamily, key); + if (valueWithTimestamp != null) { + return valueWithTimestamp; + } + + final byte[] plainValue = db.get(oldColumnFamily, key); + if (plainValue != null) { + return convertToTimestampedFormat(plainValue); + } + + return null; + } + + @Override + public KeyValueIterator range(final Bytes from, + final Bytes to, + final boolean forward) { + return new RocksDBDualCFRangeIterator( + name, + db.newIterator(newColumnFamily), + db.newIterator(oldColumnFamily), + from, + to, + forward, + true); + } + + @Override + public void deleteRange(final byte[] from, final byte[] to) { + try { + db.deleteRange(oldColumnFamily, wOptions, from, to); + } catch (final RocksDBException e) { + // String format is happening in wrapping stores. So formatted message is thrown from wrapping stores. + throw new ProcessorStateException("Error while removing key from store " + name, e); + } + try { + db.deleteRange(newColumnFamily, wOptions, from, to); + } catch (final RocksDBException e) { + // String format is happening in wrapping stores. So formatted message is thrown from wrapping stores. 
+ throw new ProcessorStateException("Error while removing key from store " + name, e); + } + } + + @Override + public KeyValueIterator all(final boolean forward) { + final RocksIterator innerIterWithTimestamp = db.newIterator(newColumnFamily); + final RocksIterator innerIterNoTimestamp = db.newIterator(oldColumnFamily); + if (forward) { + innerIterWithTimestamp.seekToFirst(); + innerIterNoTimestamp.seekToFirst(); + } else { + innerIterWithTimestamp.seekToLast(); + innerIterNoTimestamp.seekToLast(); + } + return new RocksDBDualCFIterator(name, innerIterWithTimestamp, innerIterNoTimestamp, forward); + } + + @Override + public KeyValueIterator prefixScan(final Bytes prefix) { + final Bytes to = Bytes.increment(prefix); + return new RocksDBDualCFRangeIterator( + name, + db.newIterator(newColumnFamily), + db.newIterator(oldColumnFamily), + prefix, + to, + true, + false + ); + } + + @Override + public long approximateNumEntries() throws RocksDBException { + return db.getLongProperty(oldColumnFamily, "rocksdb.estimate-num-keys") + + db.getLongProperty(newColumnFamily, "rocksdb.estimate-num-keys"); + } + + @Override + public void flush() throws RocksDBException { + db.flush(fOptions, oldColumnFamily); + db.flush(fOptions, newColumnFamily); + } + + @Override + public void prepareBatchForRestore(final Collection> records, + final WriteBatch batch) throws RocksDBException { + for (final KeyValue record : records) { + addToBatch(record.key, record.value, batch); + } + } + + @Override + public void addToBatch(final byte[] key, + final byte[] value, + final WriteBatch batch) throws RocksDBException { + if (value == null) { + batch.delete(oldColumnFamily, key); + batch.delete(newColumnFamily, key); + } else { + batch.delete(oldColumnFamily, key); + batch.put(newColumnFamily, key, value); + } + } + + @Override + public void close() { + oldColumnFamily.close(); + newColumnFamily.close(); + } + } + + private class RocksDBDualCFIterator extends AbstractIterator> + implements KeyValueIterator { + + // RocksDB's JNI interface does not expose getters/setters that allow the + // comparator to be pluggable, and the default is lexicographic, so it's + // safe to just force lexicographic comparator here for now. 
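+ // Editor's note: this iterator merges the timestamped and plain column families in lexicographic
+ // key order, converting values from the plain (old-format) column family on the fly, so callers
+ // always observe the timestamped value layout regardless of which column family a key came from.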
+ private final Comparator comparator = Bytes.BYTES_LEXICO_COMPARATOR; + + private final String storeName; + private final RocksIterator iterWithTimestamp; + private final RocksIterator iterNoTimestamp; + private final boolean forward; + + private volatile boolean open = true; + + private byte[] nextWithTimestamp; + private byte[] nextNoTimestamp; + private KeyValue next; + + RocksDBDualCFIterator(final String storeName, + final RocksIterator iterWithTimestamp, + final RocksIterator iterNoTimestamp, + final boolean forward) { + this.iterWithTimestamp = iterWithTimestamp; + this.iterNoTimestamp = iterNoTimestamp; + this.storeName = storeName; + this.forward = forward; + } + + @Override + public synchronized boolean hasNext() { + if (!open) { + throw new InvalidStateStoreException(String.format("RocksDB iterator for store %s has closed", storeName)); + } + return super.hasNext(); + } + + @Override + public synchronized KeyValue next() { + return super.next(); + } + + @Override + public KeyValue makeNext() { + if (nextNoTimestamp == null && iterNoTimestamp.isValid()) { + nextNoTimestamp = iterNoTimestamp.key(); + } + + if (nextWithTimestamp == null && iterWithTimestamp.isValid()) { + nextWithTimestamp = iterWithTimestamp.key(); + } + + if (nextNoTimestamp == null && !iterNoTimestamp.isValid()) { + if (nextWithTimestamp == null && !iterWithTimestamp.isValid()) { + return allDone(); + } else { + next = KeyValue.pair(new Bytes(nextWithTimestamp), iterWithTimestamp.value()); + nextWithTimestamp = null; + if (forward) { + iterWithTimestamp.next(); + } else { + iterWithTimestamp.prev(); + } + } + } else { + if (nextWithTimestamp == null) { + next = KeyValue.pair(new Bytes(nextNoTimestamp), convertToTimestampedFormat(iterNoTimestamp.value())); + nextNoTimestamp = null; + if (forward) { + iterNoTimestamp.next(); + } else { + iterNoTimestamp.prev(); + } + } else { + if (forward) { + if (comparator.compare(nextNoTimestamp, nextWithTimestamp) <= 0) { + next = KeyValue.pair(new Bytes(nextNoTimestamp), convertToTimestampedFormat(iterNoTimestamp.value())); + nextNoTimestamp = null; + iterNoTimestamp.next(); + } else { + next = KeyValue.pair(new Bytes(nextWithTimestamp), iterWithTimestamp.value()); + nextWithTimestamp = null; + iterWithTimestamp.next(); + } + } else { + if (comparator.compare(nextNoTimestamp, nextWithTimestamp) >= 0) { + next = KeyValue.pair(new Bytes(nextNoTimestamp), convertToTimestampedFormat(iterNoTimestamp.value())); + nextNoTimestamp = null; + iterNoTimestamp.prev(); + } else { + next = KeyValue.pair(new Bytes(nextWithTimestamp), iterWithTimestamp.value()); + nextWithTimestamp = null; + iterWithTimestamp.prev(); + } + } + } + } + return next; + } + + @Override + public synchronized void close() { + openIterators.remove(this); + iterNoTimestamp.close(); + iterWithTimestamp.close(); + open = false; + } + + @Override + public Bytes peekNextKey() { + if (!hasNext()) { + throw new NoSuchElementException(); + } + return next.key; + } + } + + private class RocksDBDualCFRangeIterator extends RocksDBDualCFIterator { + // RocksDB's JNI interface does not expose getters/setters that allow the + // comparator to be pluggable, and the default is lexicographic, so it's + // safe to just force lexicographic comparator here for now. 
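+ // Editor's note: the range endpoints are enforced on top of the merged dual-column-family stream;
+ // a null 'rawLastKey' denotes an open endpoint, mirroring RocksDBRangeIterator in the single-column-family case.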
+ private final Comparator comparator = Bytes.BYTES_LEXICO_COMPARATOR; + private final byte[] rawLastKey; + private final boolean forward; + private final boolean toInclusive; + + RocksDBDualCFRangeIterator(final String storeName, + final RocksIterator iterWithTimestamp, + final RocksIterator iterNoTimestamp, + final Bytes from, + final Bytes to, + final boolean forward, + final boolean toInclusive) { + super(storeName, iterWithTimestamp, iterNoTimestamp, forward); + this.forward = forward; + this.toInclusive = toInclusive; + if (forward) { + if (from == null) { + iterWithTimestamp.seekToFirst(); + iterNoTimestamp.seekToFirst(); + } else { + iterWithTimestamp.seek(from.get()); + iterNoTimestamp.seek(from.get()); + } + rawLastKey = to == null ? null : to.get(); + } else { + if (to == null) { + iterWithTimestamp.seekToLast(); + iterNoTimestamp.seekToLast(); + } else { + iterWithTimestamp.seekForPrev(to.get()); + iterNoTimestamp.seekForPrev(to.get()); + } + rawLastKey = from == null ? null : from.get(); + } + } + + @Override + public KeyValue makeNext() { + final KeyValue next = super.makeNext(); + + if (next == null) { + return allDone(); + } else if (rawLastKey == null) { + //null means range endpoint is open + return next; + } else { + if (forward) { + if (comparator.compare(next.key.get(), rawLastKey) < 0) { + return next; + } else if (comparator.compare(next.key.get(), rawLastKey) == 0) { + return toInclusive ? next : allDone(); + } else { + return allDone(); + } + } else { + if (comparator.compare(next.key.get(), rawLastKey) >= 0) { + return next; + } else { + return allDone(); + } + } + } + } + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/RocksDBTimestampedWindowStore.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/RocksDBTimestampedWindowStore.java new file mode 100644 index 0000000..b96748e --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/RocksDBTimestampedWindowStore.java @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.streams.state.TimestampedBytesStore; + +class RocksDBTimestampedWindowStore extends RocksDBWindowStore implements TimestampedBytesStore { + + RocksDBTimestampedWindowStore(final SegmentedBytesStore bytesStore, + final boolean retainDuplicates, + final long windowSize) { + super(bytesStore, retainDuplicates, windowSize); + } + +} diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/RocksDBWindowStore.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/RocksDBWindowStore.java new file mode 100644 index 0000000..8f48dca --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/RocksDBWindowStore.java @@ -0,0 +1,115 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.streams.kstream.Windowed; +import org.apache.kafka.streams.state.KeyValueIterator; +import org.apache.kafka.streams.state.WindowStore; +import org.apache.kafka.streams.state.WindowStoreIterator; + +public class RocksDBWindowStore + extends WrappedStateStore + implements WindowStore { + + private final boolean retainDuplicates; + private final long windowSize; + + private int seqnum = 0; + + RocksDBWindowStore(final SegmentedBytesStore bytesStore, + final boolean retainDuplicates, + final long windowSize) { + super(bytesStore); + this.retainDuplicates = retainDuplicates; + this.windowSize = windowSize; + } + + @Override + public void put(final Bytes key, final byte[] value, final long windowStartTimestamp) { + // Skip if value is null and duplicates are allowed since this delete is a no-op + if (!(value == null && retainDuplicates)) { + maybeUpdateSeqnumForDups(); + wrapped().put(WindowKeySchema.toStoreKeyBinary(key, windowStartTimestamp, seqnum), value); + } + } + + @Override + public byte[] fetch(final Bytes key, final long timestamp) { + return wrapped().get(WindowKeySchema.toStoreKeyBinary(key, timestamp, seqnum)); + } + + @Override + public WindowStoreIterator fetch(final Bytes key, final long timeFrom, final long timeTo) { + final KeyValueIterator bytesIterator = wrapped().fetch(key, timeFrom, timeTo); + return new WindowStoreIteratorWrapper(bytesIterator, windowSize).valuesIterator(); + } + + @Override + public WindowStoreIterator backwardFetch(final Bytes key, final long timeFrom, final long timeTo) { + final KeyValueIterator bytesIterator = wrapped().backwardFetch(key, timeFrom, timeTo); + return new WindowStoreIteratorWrapper(bytesIterator, windowSize).valuesIterator(); + } + + @Override + public KeyValueIterator, byte[]> fetch(final Bytes keyFrom, + final Bytes keyTo, + final long timeFrom, + final long 
timeTo) { + final KeyValueIterator bytesIterator = wrapped().fetch(keyFrom, keyTo, timeFrom, timeTo); + return new WindowStoreIteratorWrapper(bytesIterator, windowSize).keyValueIterator(); + } + + @Override + public KeyValueIterator, byte[]> backwardFetch(final Bytes keyFrom, + final Bytes keyTo, + final long timeFrom, + final long timeTo) { + final KeyValueIterator bytesIterator = wrapped().backwardFetch(keyFrom, keyTo, timeFrom, timeTo); + return new WindowStoreIteratorWrapper(bytesIterator, windowSize).keyValueIterator(); + } + + @Override + public KeyValueIterator, byte[]> all() { + final KeyValueIterator bytesIterator = wrapped().all(); + return new WindowStoreIteratorWrapper(bytesIterator, windowSize).keyValueIterator(); + } + + @Override + public KeyValueIterator, byte[]> backwardAll() { + final KeyValueIterator bytesIterator = wrapped().backwardAll(); + return new WindowStoreIteratorWrapper(bytesIterator, windowSize).keyValueIterator(); + } + + @Override + public KeyValueIterator, byte[]> fetchAll(final long timeFrom, final long timeTo) { + final KeyValueIterator bytesIterator = wrapped().fetchAll(timeFrom, timeTo); + return new WindowStoreIteratorWrapper(bytesIterator, windowSize).keyValueIterator(); + } + + @Override + public KeyValueIterator, byte[]> backwardFetchAll(final long timeFrom, final long timeTo) { + final KeyValueIterator bytesIterator = wrapped().backwardFetchAll(timeFrom, timeTo); + return new WindowStoreIteratorWrapper(bytesIterator, windowSize).keyValueIterator(); + } + + private void maybeUpdateSeqnumForDups() { + if (retainDuplicates) { + seqnum = (seqnum + 1) & 0x7FFFFFFF; + } + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/RocksDbIterator.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/RocksDbIterator.java new file mode 100644 index 0000000..388195a --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/RocksDbIterator.java @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.common.utils.AbstractIterator; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.errors.InvalidStateStoreException; +import org.apache.kafka.streams.state.KeyValueIterator; +import org.rocksdb.RocksIterator; + +import java.util.NoSuchElementException; +import java.util.Set; +import java.util.function.Consumer; + +class RocksDbIterator extends AbstractIterator> implements KeyValueIterator { + + private final String storeName; + private final RocksIterator iter; + private final Set> openIterators; + private final Consumer advanceIterator; + + private volatile boolean open = true; + + private KeyValue next; + + RocksDbIterator(final String storeName, + final RocksIterator iter, + final Set> openIterators, + final boolean forward) { + this.storeName = storeName; + this.iter = iter; + this.openIterators = openIterators; + this.advanceIterator = forward ? RocksIterator::next : RocksIterator::prev; + } + + @Override + public synchronized boolean hasNext() { + if (!open) { + throw new InvalidStateStoreException(String.format("RocksDB iterator for store %s has closed", storeName)); + } + return super.hasNext(); + } + + @Override + public KeyValue makeNext() { + if (!iter.isValid()) { + return allDone(); + } else { + next = getKeyValue(); + advanceIterator.accept(iter); + return next; + } + } + + private KeyValue getKeyValue() { + return new KeyValue<>(new Bytes(iter.key()), iter.value()); + } + + @Override + public synchronized void close() { + openIterators.remove(this); + iter.close(); + open = false; + } + + @Override + public Bytes peekNextKey() { + if (!hasNext()) { + throw new NoSuchElementException(); + } + return next.key; + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/RocksDbKeyValueBytesStoreSupplier.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/RocksDbKeyValueBytesStoreSupplier.java new file mode 100644 index 0000000..87be767 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/RocksDbKeyValueBytesStoreSupplier.java @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
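RocksDbIterator above relies on two small tricks: the traversal direction is captured once as a Consumer of RocksIterator (RocksIterator::next or RocksIterator::prev), and AbstractIterator lets it express iteration as a single makeNext() that either returns the next entry or calls allDone(). The sketch below is a minimal, RocksDB-free restatement of that compute-next pattern; LazyIterationSketch and its nested classes are hypothetical names, not Kafka classes.

import java.util.Iterator;
import java.util.NoSuchElementException;

public class LazyIterationSketch {

    // The "compute the next element on demand" shape that AbstractIterator gives RocksDbIterator:
    // makeNext() produces the next element, or calls allDone() when the source is exhausted.
    abstract static class LazyIterator<T> implements Iterator<T> {
        private T next;
        private boolean done;

        protected abstract T makeNext();

        protected final T allDone() {
            done = true;
            return null;
        }

        @Override
        public boolean hasNext() {
            if (next == null && !done) {
                next = makeNext();
            }
            return next != null;
        }

        @Override
        public T next() {
            if (!hasNext()) {
                throw new NoSuchElementException();
            }
            final T result = next;
            next = null;
            return result;
        }
    }

    static class CountingIterator extends LazyIterator<Integer> {
        private final int limit;
        private int i = 0;

        CountingIterator(final int limit) {
            this.limit = limit;
        }

        @Override
        protected Integer makeNext() {
            if (i < limit) {
                return i++;
            }
            return allDone();
        }
    }

    public static void main(final String[] args) {
        final CountingIterator it = new CountingIterator(3);
        while (it.hasNext()) {
            System.out.println(it.next()); // prints 0, 1, 2
        }
    }
}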
+ */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.streams.state.KeyValueBytesStoreSupplier; +import org.apache.kafka.streams.state.KeyValueStore; + +public class RocksDbKeyValueBytesStoreSupplier implements KeyValueBytesStoreSupplier { + + private final String name; + private final boolean returnTimestampedStore; + + public RocksDbKeyValueBytesStoreSupplier(final String name, + final boolean returnTimestampedStore) { + this.name = name; + this.returnTimestampedStore = returnTimestampedStore; + } + + @Override + public String name() { + return name; + } + + @Override + public KeyValueStore get() { + return returnTimestampedStore ? + new RocksDBTimestampedStore(name, metricsScope()) : + new RocksDBStore(name, metricsScope()); + } + + @Override + public String metricsScope() { + return "rocksdb"; + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/RocksDbSessionBytesStoreSupplier.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/RocksDbSessionBytesStoreSupplier.java new file mode 100644 index 0000000..684ebf4 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/RocksDbSessionBytesStoreSupplier.java @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.streams.state.SessionBytesStoreSupplier; +import org.apache.kafka.streams.state.SessionStore; + +public class RocksDbSessionBytesStoreSupplier implements SessionBytesStoreSupplier { + private final String name; + private final long retentionPeriod; + + public RocksDbSessionBytesStoreSupplier(final String name, + final long retentionPeriod) { + this.name = name; + this.retentionPeriod = retentionPeriod; + } + + @Override + public String name() { + return name; + } + + @Override + public SessionStore get() { + final RocksDBSegmentedBytesStore segmented = new RocksDBSegmentedBytesStore( + name, + metricsScope(), + retentionPeriod, + segmentIntervalMs(), + new SessionKeySchema()); + return new RocksDBSessionStore(segmented); + } + + @Override + public String metricsScope() { + return "rocksdb-session"; + } + + @Override + public long segmentIntervalMs() { + // Selected somewhat arbitrarily. Profiling may reveal a different value is preferable. 
+ return Math.max(retentionPeriod / 2, 60_000L); + } + + @Override + public long retentionPeriod() { + return retentionPeriod; + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/RocksDbWindowBytesStoreSupplier.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/RocksDbWindowBytesStoreSupplier.java new file mode 100644 index 0000000..3ee5b88 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/RocksDbWindowBytesStoreSupplier.java @@ -0,0 +1,131 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.streams.state.WindowBytesStoreSupplier; +import org.apache.kafka.streams.state.WindowStore; + +public class RocksDbWindowBytesStoreSupplier implements WindowBytesStoreSupplier { + public enum WindowStoreTypes { + DEFAULT_WINDOW_STORE, + TIMESTAMPED_WINDOW_STORE + } + + private final String name; + private final long retentionPeriod; + private final long segmentInterval; + private final long windowSize; + private final boolean retainDuplicates; + private final WindowStoreTypes windowStoreType; + + public RocksDbWindowBytesStoreSupplier(final String name, + final long retentionPeriod, + final long segmentInterval, + final long windowSize, + final boolean retainDuplicates, + final boolean returnTimestampedStore) { + this(name, retentionPeriod, segmentInterval, windowSize, retainDuplicates, + returnTimestampedStore + ? 
WindowStoreTypes.TIMESTAMPED_WINDOW_STORE + : WindowStoreTypes.DEFAULT_WINDOW_STORE); + } + + public RocksDbWindowBytesStoreSupplier(final String name, + final long retentionPeriod, + final long segmentInterval, + final long windowSize, + final boolean retainDuplicates, + final WindowStoreTypes windowStoreType) { + this.name = name; + this.retentionPeriod = retentionPeriod; + this.segmentInterval = segmentInterval; + this.windowSize = windowSize; + this.retainDuplicates = retainDuplicates; + this.windowStoreType = windowStoreType; + } + + @Override + public String name() { + return name; + } + + @Override + public WindowStore get() { + switch (windowStoreType) { + case DEFAULT_WINDOW_STORE: + return new RocksDBWindowStore( + new RocksDBSegmentedBytesStore( + name, + metricsScope(), + retentionPeriod, + segmentInterval, + new WindowKeySchema()), + retainDuplicates, + windowSize); + case TIMESTAMPED_WINDOW_STORE: + return new RocksDBTimestampedWindowStore( + new RocksDBTimestampedSegmentedBytesStore( + name, + metricsScope(), + retentionPeriod, + segmentInterval, + new WindowKeySchema()), + retainDuplicates, + windowSize); + default: + throw new IllegalArgumentException("invalid window store type: " + windowStoreType); + } + } + + @Override + public String metricsScope() { + return "rocksdb-window"; + } + + @Override + public long segmentIntervalMs() { + return segmentInterval; + } + + @Override + public long windowSize() { + return windowSize; + } + + @Override + public boolean retainDuplicates() { + return retainDuplicates; + } + + @Override + public long retentionPeriod() { + return retentionPeriod; + } + + @Override + public String toString() { + return "RocksDbWindowBytesStoreSupplier{" + + "name='" + name + '\'' + + ", retentionPeriod=" + retentionPeriod + + ", segmentInterval=" + segmentInterval + + ", windowSize=" + windowSize + + ", retainDuplicates=" + retainDuplicates + + ", windowStoreType=" + windowStoreType + + '}'; + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/Segment.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/Segment.java new file mode 100644 index 0000000..ea3b89a --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/Segment.java @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
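Application code normally does not instantiate RocksDbWindowBytesStoreSupplier directly; the usual entry point is the public Stores factory, which chooses retention, window size and duplicate retention and hands back a supplier plus a builder (the timestamped Stores variant is expected to correspond to the TIMESTAMPED_WINDOW_STORE constant above). A minimal usage sketch, assuming the public Stores/StoreBuilder API:

import java.time.Duration;
import org.apache.kafka.common.serialization.Serdes;
import org.apache.kafka.streams.state.StoreBuilder;
import org.apache.kafka.streams.state.Stores;
import org.apache.kafka.streams.state.WindowBytesStoreSupplier;
import org.apache.kafka.streams.state.WindowStore;

public class WindowStoreSupplierExample {
    public static void main(final String[] args) {
        // Retention must be at least the window size; retainDuplicates is only needed
        // when the same key may legitimately appear several times per window (e.g. stream-stream joins).
        final WindowBytesStoreSupplier supplier = Stores.persistentWindowStore(
                "my-window-store",
                Duration.ofHours(1),   // retention period
                Duration.ofMinutes(5), // window size
                false);                // retainDuplicates

        final StoreBuilder<WindowStore<String, Long>> builder =
                Stores.windowStoreBuilder(supplier, Serdes.String(), Serdes.Long());

        // The builder is normally handed to StreamsBuilder#addStateStore; build() wraps the
        // supplier's store in the standard metering (and optional caching/logging) layers.
        System.out.println(builder.name());
    }
}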
+ */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.streams.state.KeyValueStore; + +import java.io.IOException; + +public interface Segment extends KeyValueStore, BatchWritingStore { + + void destroy() throws IOException; + + void deleteRange(Bytes keyFrom, Bytes keyTo); +} diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/SegmentIterator.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/SegmentIterator.java new file mode 100644 index 0000000..6191c49 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/SegmentIterator.java @@ -0,0 +1,106 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.errors.InvalidStateStoreException; +import org.apache.kafka.streams.state.KeyValueIterator; + +import java.util.Iterator; +import java.util.NoSuchElementException; + +/** + * Iterate over multiple KeyValueSegments + */ +class SegmentIterator implements KeyValueIterator { + + private final Bytes from; + private final Bytes to; + private final boolean forward; + protected final Iterator segments; + protected final HasNextCondition hasNextCondition; + + private S currentSegment; + KeyValueIterator currentIterator; + + SegmentIterator(final Iterator segments, + final HasNextCondition hasNextCondition, + final Bytes from, + final Bytes to, + final boolean forward) { + this.segments = segments; + this.hasNextCondition = hasNextCondition; + this.from = from; + this.to = to; + this.forward = forward; + } + + @Override + public void close() { + if (currentIterator != null) { + currentIterator.close(); + currentIterator = null; + } + } + + @Override + public Bytes peekNextKey() { + if (!hasNext()) { + throw new NoSuchElementException(); + } + return currentIterator.peekNextKey(); + } + + @Override + public boolean hasNext() { + boolean hasNext = false; + while ((currentIterator == null || !(hasNext = hasNextConditionHasNext()) || !currentSegment.isOpen()) + && segments.hasNext()) { + close(); + currentSegment = segments.next(); + try { + if (forward) { + currentIterator = currentSegment.range(from, to); + } else { + currentIterator = currentSegment.reverseRange(from, to); + } + } catch (final InvalidStateStoreException e) { + // segment may have been closed so we ignore it. 
+ } + } + return currentIterator != null && hasNext; + } + + private boolean hasNextConditionHasNext() { + boolean hasNext = false; + try { + hasNext = hasNextCondition.hasNext(currentIterator); + } catch (final InvalidStateStoreException e) { + //already closed so ignore + } + return hasNext; + } + + @Override + public KeyValue next() { + if (!hasNext()) { + throw new NoSuchElementException(); + } + return currentIterator.next(); + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/SegmentedBytesStore.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/SegmentedBytesStore.java new file mode 100644 index 0000000..4519929 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/SegmentedBytesStore.java @@ -0,0 +1,227 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.streams.errors.InvalidStateStoreException; +import org.apache.kafka.streams.processor.StateStore; +import org.apache.kafka.streams.state.KeyValueIterator; + +import java.util.List; + +/** + * The interface representing a StateStore that has 1 or more segments that are based + * on time. 
+ * @see RocksDBSegmentedBytesStore + */ +public interface SegmentedBytesStore extends StateStore { + + /** + * Fetch all records from the segmented store with the provided key and time range + * from all existing segments + * @param key the key to match + * @param from earliest time to match + * @param to latest time to match + * @return an iterator over key-value pairs + */ + KeyValueIterator fetch(final Bytes key, final long from, final long to); + + /** + * Fetch all records from the segmented store with the provided key and time range + * from all existing segments in backward order (from latest to earliest) + * @param key the key to match + * @param from earliest time to match + * @param to latest time to match + * @return an iterator over key-value pairs + */ + KeyValueIterator backwardFetch(final Bytes key, final long from, final long to); + + /** + * Fetch all records from the segmented store in the provided key range and time range + * from all existing segments + * @param keyFrom The first key that could be in the range + * @param keyTo The last key that could be in the range + * @param from earliest time to match + * @param to latest time to match + * @return an iterator over key-value pairs + */ + KeyValueIterator fetch(final Bytes keyFrom, final Bytes keyTo, final long from, final long to); + + /** + * Fetch all records from the segmented store in the provided key range and time range + * from all existing segments in backward order (from latest to earliest) + * @param keyFrom The first key that could be in the range + * @param keyTo The last key that could be in the range + * @param from earliest time to match + * @param to latest time to match + * @return an iterator over key-value pairs + */ + KeyValueIterator backwardFetch(final Bytes keyFrom, final Bytes keyTo, final long from, final long to); + + /** + * Gets all the key-value pairs in the existing windows. + * + * @return an iterator over windowed key-value pairs {@code , value>} + * @throws InvalidStateStoreException if the store is not initialized + */ + KeyValueIterator all(); + + /** + * Gets all the key-value pairs in the existing windows in backward order (from latest to earliest). + * + * @return an iterator over windowed key-value pairs {@code , value>} + * @throws InvalidStateStoreException if the store is not initialized + */ + KeyValueIterator backwardAll(); + + /** + * Gets all the key-value pairs that belong to the windows within in the given time range. + * + * @param from the beginning of the time slot from which to search + * @param to the end of the time slot from which to search + * @return an iterator over windowed key-value pairs {@code , value>} + * @throws InvalidStateStoreException if the store is not initialized + * @throws NullPointerException if null is used for any key + */ + KeyValueIterator fetchAll(final long from, final long to); + + KeyValueIterator backwardFetchAll(final long from, final long to); + + /** + * Remove the record with the provided key. The key + * should be a composite of the record key, and the timestamp information etc + * as described by the {@link KeySchema} + * @param key the segmented key to remove + */ + void remove(Bytes key); + + /** + * Remove all duplicated records with the provided key in the specified timestamp. + * + * @param key the segmented key to remove + * @param timestamp the timestamp to match + */ + void remove(Bytes key, long timestamp); + + /** + * Write a new value to the store with the provided key. 
The key + * should be a composite of the record key, and the timestamp information etc + * as described by the {@link KeySchema} + * @param key + * @param value + */ + void put(Bytes key, byte[] value); + + /** + * Get the record from the store with the given key. The key + * should be a composite of the record key, and the timestamp information etc + * as described by the {@link KeySchema} + * @param key + * @return + */ + byte[] get(Bytes key); + + interface KeySchema { + + /** + * Given a range of record keys and a time, construct a Segmented key that represents + * the upper range of keys to search when performing range queries. + * @see SessionKeySchema#upperRange + * @see WindowKeySchema#upperRange + * @param key + * @param to + * @return The key that represents the upper range to search for in the store + */ + Bytes upperRange(final Bytes key, final long to); + + /** + * Given a range of record keys and a time, construct a Segmented key that represents + * the lower range of keys to search when performing range queries. + * @see SessionKeySchema#lowerRange + * @see WindowKeySchema#lowerRange + * @param key + * @param from + * @return The key that represents the lower range to search for in the store + */ + Bytes lowerRange(final Bytes key, final long from); + + /** + * Given a record key and a time, construct a Segmented key to search when performing + * prefixed queries. + * + * @param key + * @param timestamp + * @return The key that represents the prefixed Segmented key in bytes. + */ + default Bytes toStoreBinaryKeyPrefix(final Bytes key, long timestamp) { + throw new UnsupportedOperationException(); + } + + /** + * Given a range of fixed size record keys and a time, construct a Segmented key that represents + * the upper range of keys to search when performing range queries. + * @see SessionKeySchema#upperRange + * @see WindowKeySchema#upperRange + * @param key the last key in the range + * @param to the last timestamp in the range + * @return The key that represents the upper range to search for in the store + */ + Bytes upperRangeFixedSize(final Bytes key, final long to); + + /** + * Given a range of fixed size record keys and a time, construct a Segmented key that represents + * the lower range of keys to search when performing range queries. + * @see SessionKeySchema#lowerRange + * @see WindowKeySchema#lowerRange + * @param key the first key in the range + * @param from the first timestamp in the range + * @return The key that represents the lower range to search for in the store + */ + Bytes lowerRangeFixedSize(final Bytes key, final long from); + + /** + * Extract the timestamp of the segment from the key. The key is a composite of + * the record-key, any timestamps, plus any additional information. + * @see SessionKeySchema#lowerRange + * @see WindowKeySchema#lowerRange + * @param key + * @return + */ + long segmentTimestamp(final Bytes key); + + /** + * Create an implementation of {@link HasNextCondition} that knows when + * to stop iterating over the KeyValueSegments. 
Used during {@link SegmentedBytesStore#fetch(Bytes, Bytes, long, long)} operations + * @param binaryKeyFrom the first key in the range + * @param binaryKeyTo the last key in the range + * @param from starting time range + * @param to ending time range + * @return + */ + HasNextCondition hasNextCondition(final Bytes binaryKeyFrom, final Bytes binaryKeyTo, final long from, final long to); + + /** + * Used during {@link SegmentedBytesStore#fetch(Bytes, long, long)} operations to determine + * which segments should be scanned. + * @param segments + * @param from + * @param to + * @return List of segments to search + */ + List segmentsToSearch(Segments segments, long from, long to, boolean forward); + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/SegmentedCacheFunction.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/SegmentedCacheFunction.java new file mode 100644 index 0000000..68a40b4 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/SegmentedCacheFunction.java @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
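The KeySchema callbacks above exist because a segmented store only scans the segments whose time range overlaps the query: each record lands in the segment with id timestamp / segmentInterval, and segmentsToSearch narrows a fetch to the ids spanning the requested range. The following self-contained sketch shows that arithmetic only; the real selection lives in the Segments implementations, and the class and method names here are hypothetical.

import java.util.ArrayList;
import java.util.List;

public class SegmentSelectionSketch {

    // segmentId = timestamp / segmentInterval, the rule the supplier code above implies.
    static long segmentId(final long timestamp, final long segmentInterval) {
        return timestamp / segmentInterval;
    }

    // All segment ids whose time range overlaps [from, to]; only these need to be scanned.
    static List<Long> segmentsToSearch(final long from, final long to, final long segmentInterval) {
        final List<Long> ids = new ArrayList<>();
        for (long id = segmentId(Math.max(0L, from), segmentInterval); id <= segmentId(to, segmentInterval); id++) {
            ids.add(id);
        }
        return ids;
    }

    public static void main(final String[] args) {
        // With 60s segments, a fetch over [90_000, 250_000] only touches segments 1..4.
        System.out.println(segmentsToSearch(90_000L, 250_000L, 60_000L)); // [1, 2, 3, 4]
    }
}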
+ */ + +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.streams.state.internals.SegmentedBytesStore.KeySchema; + +import java.nio.ByteBuffer; + +class SegmentedCacheFunction implements CacheFunction { + + private static final int SEGMENT_ID_BYTES = 8; + + private final KeySchema keySchema; + private final long segmentInterval; + + SegmentedCacheFunction(final KeySchema keySchema, final long segmentInterval) { + this.keySchema = keySchema; + this.segmentInterval = segmentInterval; + } + + @Override + public Bytes key(final Bytes cacheKey) { + return Bytes.wrap(bytesFromCacheKey(cacheKey)); + } + + @Override + public Bytes cacheKey(final Bytes key) { + return cacheKey(key, segmentId(key)); + } + + Bytes cacheKey(final Bytes key, final long segmentId) { + final byte[] keyBytes = key.get(); + final ByteBuffer buf = ByteBuffer.allocate(SEGMENT_ID_BYTES + keyBytes.length); + buf.putLong(segmentId).put(keyBytes); + return Bytes.wrap(buf.array()); + } + + static byte[] bytesFromCacheKey(final Bytes cacheKey) { + final byte[] binaryKey = new byte[cacheKey.get().length - SEGMENT_ID_BYTES]; + System.arraycopy(cacheKey.get(), SEGMENT_ID_BYTES, binaryKey, 0, binaryKey.length); + return binaryKey; + } + + public long segmentId(final Bytes key) { + return segmentId(keySchema.segmentTimestamp(key)); + } + + long segmentId(final long timestamp) { + return timestamp / segmentInterval; + } + + long getSegmentInterval() { + return segmentInterval; + } + + int compareSegmentedKeys(final Bytes cacheKey, final Bytes storeKey) { + final long storeSegmentId = segmentId(storeKey); + final long cacheSegmentId = ByteBuffer.wrap(cacheKey.get()).getLong(); + + final int segmentCompare = Long.compare(cacheSegmentId, storeSegmentId); + if (segmentCompare == 0) { + final byte[] cacheKeyBytes = cacheKey.get(); + final byte[] storeKeyBytes = storeKey.get(); + return Bytes.BYTES_LEXICO_COMPARATOR.compare( + cacheKeyBytes, SEGMENT_ID_BYTES, cacheKeyBytes.length - SEGMENT_ID_BYTES, + storeKeyBytes, 0, storeKeyBytes.length + ); + } else { + return segmentCompare; + } + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/Segments.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/Segments.java new file mode 100644 index 0000000..7e50b98 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/Segments.java @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
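SegmentedCacheFunction above makes cached entries group and sort per segment by prepending an 8-byte segment id to the store key; bytesFromCacheKey strips that prefix again, and compareSegmentedKeys compares the id first and the remaining key bytes second. A standalone sketch of that key layout (class and helper names are hypothetical, the layout follows the code above):

import java.nio.ByteBuffer;
import java.util.Arrays;

public class SegmentedCacheKeySketch {
    private static final int SEGMENT_ID_BYTES = 8;

    // Prepend the 8-byte segment id, as cacheKey(key, segmentId) does above.
    static byte[] toCacheKey(final byte[] storeKey, final long segmentId) {
        return ByteBuffer.allocate(SEGMENT_ID_BYTES + storeKey.length)
                .putLong(segmentId)
                .put(storeKey)
                .array();
    }

    // Read the prefix back, mirroring compareSegmentedKeys' use of getLong().
    static long segmentIdOf(final byte[] cacheKey) {
        return ByteBuffer.wrap(cacheKey).getLong();
    }

    // Strip the prefix, mirroring bytesFromCacheKey.
    static byte[] storeKeyOf(final byte[] cacheKey) {
        return Arrays.copyOfRange(cacheKey, SEGMENT_ID_BYTES, cacheKey.length);
    }

    public static void main(final String[] args) {
        final long segmentInterval = 60_000L;
        final byte[] cacheKey = toCacheKey("k1".getBytes(), 185_000L / segmentInterval);
        System.out.println(segmentIdOf(cacheKey));            // 3
        System.out.println(new String(storeKeyOf(cacheKey))); // k1
    }
}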
+ */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.streams.processor.ProcessorContext; + +import java.util.List; + +interface Segments<S extends Segment> { + + long segmentId(final long timestamp); + + String segmentName(final long segmentId); + + S getSegmentForTimestamp(final long timestamp); + + S getOrCreateSegmentIfLive(final long segmentId, final ProcessorContext context, final long streamTime); + + S getOrCreateSegment(final long segmentId, final ProcessorContext context); + + void openExisting(final ProcessorContext context, final long streamTime); + + List<S> segments(final long timeFrom, final long timeTo, final boolean forward); + + List<S> allSegments(final boolean forward); + + void flush(); + + void close(); +} \ No newline at end of file diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/SessionKeySchema.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/SessionKeySchema.java new file mode 100644 index 0000000..8bb50e5 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/SessionKeySchema.java @@ -0,0 +1,165 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
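The most consequential method in the Segments contract is getOrCreateSegmentIfLive: a segment is only created, and older ones are retired, relative to the observed stream time and the retention period. The sketch below spells out the liveness rule implied by that signature; it is an assumption about the concrete Segments implementations, which are not part of this hunk, and all names are hypothetical.

public class SegmentLivenessSketch {

    // A segment is "live" if it is not older than streamTime - retentionPeriod.
    static boolean isLive(final long segmentId, final long streamTime,
                          final long retentionPeriod, final long segmentInterval) {
        final long minLiveSegment = Math.max(0L, streamTime - retentionPeriod) / segmentInterval;
        return segmentId >= minLiveSegment;
    }

    public static void main(final String[] args) {
        // 5 min retention, 1 min segments, stream time at 10 min:
        // segment 4 is expired, segment 5 and newer are kept or created on demand.
        System.out.println(isLive(4, 600_000L, 300_000L, 60_000L)); // false
        System.out.println(isLive(5, 600_000L, 300_000L, 60_000L)); // true
    }
}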
+ */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.common.serialization.Deserializer; +import org.apache.kafka.common.serialization.Serializer; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.streams.kstream.Window; +import org.apache.kafka.streams.kstream.Windowed; +import org.apache.kafka.streams.kstream.internals.SessionWindow; + +import static org.apache.kafka.streams.state.StateSerdes.TIMESTAMP_SIZE; + +import java.nio.ByteBuffer; +import java.util.List; + + +public class SessionKeySchema implements SegmentedBytesStore.KeySchema { + + private static final int SUFFIX_SIZE = 2 * TIMESTAMP_SIZE; + private static final byte[] MIN_SUFFIX = new byte[SUFFIX_SIZE]; + + @Override + public Bytes upperRangeFixedSize(final Bytes key, final long to) { + final Windowed sessionKey = new Windowed<>(key, new SessionWindow(to, Long.MAX_VALUE)); + return SessionKeySchema.toBinary(sessionKey); + } + + @Override + public Bytes lowerRangeFixedSize(final Bytes key, final long from) { + final Windowed sessionKey = new Windowed<>(key, new SessionWindow(0, Math.max(0, from))); + return SessionKeySchema.toBinary(sessionKey); + } + + @Override + public Bytes upperRange(final Bytes key, final long to) { + final byte[] maxSuffix = ByteBuffer.allocate(SUFFIX_SIZE) + // the end timestamp can be as large as possible as long as it's larger than start time + .putLong(Long.MAX_VALUE) + // this is the start timestamp + .putLong(to) + .array(); + return OrderedBytes.upperRange(key, maxSuffix); + } + + @Override + public Bytes lowerRange(final Bytes key, final long from) { + return OrderedBytes.lowerRange(key, MIN_SUFFIX); + } + + @Override + public long segmentTimestamp(final Bytes key) { + return SessionKeySchema.extractEndTimestamp(key.get()); + } + + @Override + public HasNextCondition hasNextCondition(final Bytes binaryKeyFrom, final Bytes binaryKeyTo, final long from, final long to) { + return iterator -> { + while (iterator.hasNext()) { + final Bytes bytes = iterator.peekNextKey(); + final Windowed windowedKey = SessionKeySchema.from(bytes); + if ((binaryKeyFrom == null || windowedKey.key().compareTo(binaryKeyFrom) >= 0) + && (binaryKeyTo == null || windowedKey.key().compareTo(binaryKeyTo) <= 0) + && windowedKey.window().end() >= from + && windowedKey.window().start() <= to) { + return true; + } + iterator.next(); + } + return false; + }; + } + + @Override + public List segmentsToSearch(final Segments segments, + final long from, + final long to, + final boolean forward) { + return segments.segments(from, Long.MAX_VALUE, forward); + } + + private static K extractKey(final byte[] binaryKey, + final Deserializer deserializer, + final String topic) { + return deserializer.deserialize(topic, extractKeyBytes(binaryKey)); + } + + static byte[] extractKeyBytes(final byte[] binaryKey) { + final byte[] bytes = new byte[binaryKey.length - 2 * TIMESTAMP_SIZE]; + System.arraycopy(binaryKey, 0, bytes, 0, bytes.length); + return bytes; + } + + static long extractEndTimestamp(final byte[] binaryKey) { + return ByteBuffer.wrap(binaryKey).getLong(binaryKey.length - 2 * TIMESTAMP_SIZE); + } + + static long extractStartTimestamp(final byte[] binaryKey) { + return ByteBuffer.wrap(binaryKey).getLong(binaryKey.length - TIMESTAMP_SIZE); + } + + static Window extractWindow(final byte[] binaryKey) { + final ByteBuffer buffer = ByteBuffer.wrap(binaryKey); + final long start = buffer.getLong(binaryKey.length - TIMESTAMP_SIZE); + final long end = buffer.getLong(binaryKey.length - 2 * 
TIMESTAMP_SIZE); + return new SessionWindow(start, end); + } + + public static Windowed from(final byte[] binaryKey, + final Deserializer keyDeserializer, + final String topic) { + final K key = extractKey(binaryKey, keyDeserializer, topic); + final Window window = extractWindow(binaryKey); + return new Windowed<>(key, window); + } + + public static Windowed from(final Bytes bytesKey) { + final byte[] binaryKey = bytesKey.get(); + final Window window = extractWindow(binaryKey); + return new Windowed<>(Bytes.wrap(extractKeyBytes(binaryKey)), window); + } + + public static Windowed from(final Windowed keyBytes, + final Deserializer keyDeserializer, + final String topic) { + final K key = keyDeserializer.deserialize(topic, keyBytes.key().get()); + return new Windowed<>(key, keyBytes.window()); + } + + public static byte[] toBinary(final Windowed sessionKey, + final Serializer serializer, + final String topic) { + final byte[] bytes = serializer.serialize(topic, sessionKey.key()); + return toBinary(Bytes.wrap(bytes), sessionKey.window().start(), sessionKey.window().end()).get(); + } + + public static Bytes toBinary(final Windowed sessionKey) { + return toBinary(sessionKey.key(), sessionKey.window().start(), sessionKey.window().end()); + } + + public static Bytes toBinary(final Bytes key, + final long startTime, + final long endTime) { + final byte[] bytes = key.get(); + final ByteBuffer buf = ByteBuffer.allocate(bytes.length + 2 * TIMESTAMP_SIZE); + buf.put(bytes); + buf.putLong(endTime); + buf.putLong(startTime); + return Bytes.wrap(buf.array()); + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/SessionStoreBuilder.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/SessionStoreBuilder.java new file mode 100644 index 0000000..d0d0394 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/SessionStoreBuilder.java @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
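For clarity, the binary session key assembled and parsed above is the serialized key followed by the 8-byte window end and then the 8-byte window start; the end timestamp is written before the start timestamp, which is why extractEndTimestamp reads 16 bytes from the end of the array and extractStartTimestamp reads 8. A standalone restatement of that layout (hypothetical class name, same byte order as toBinary above):

import java.nio.ByteBuffer;

public class SessionKeyLayoutSketch {
    private static final int TIMESTAMP_SIZE = 8;

    // Same byte order as SessionKeySchema.toBinary above: key, then end, then start.
    static byte[] toBinary(final byte[] key, final long start, final long end) {
        return ByteBuffer.allocate(key.length + 2 * TIMESTAMP_SIZE)
                .put(key)
                .putLong(end)
                .putLong(start)
                .array();
    }

    static long extractEndTimestamp(final byte[] binaryKey) {
        return ByteBuffer.wrap(binaryKey).getLong(binaryKey.length - 2 * TIMESTAMP_SIZE);
    }

    static long extractStartTimestamp(final byte[] binaryKey) {
        return ByteBuffer.wrap(binaryKey).getLong(binaryKey.length - TIMESTAMP_SIZE);
    }

    public static void main(final String[] args) {
        final byte[] binary = toBinary("user-1".getBytes(), 1_000L, 5_000L);
        System.out.println(extractStartTimestamp(binary)); // 1000
        System.out.println(extractEndTimestamp(binary));   // 5000
    }
}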
+ */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.streams.state.SessionBytesStoreSupplier; +import org.apache.kafka.streams.state.SessionStore; + +import java.util.Objects; + + +public class SessionStoreBuilder extends AbstractStoreBuilder> { + + private final SessionBytesStoreSupplier storeSupplier; + + public SessionStoreBuilder(final SessionBytesStoreSupplier storeSupplier, + final Serde keySerde, + final Serde valueSerde, + final Time time) { + super(Objects.requireNonNull(storeSupplier, "storeSupplier cannot be null").name(), keySerde, valueSerde, time); + Objects.requireNonNull(storeSupplier.metricsScope(), "storeSupplier's metricsScope can't be null"); + this.storeSupplier = storeSupplier; + } + + @Override + public SessionStore build() { + return new MeteredSessionStore<>( + maybeWrapCaching(maybeWrapLogging(storeSupplier.get())), + storeSupplier.metricsScope(), + keySerde, + valueSerde, + time); + } + + private SessionStore maybeWrapCaching(final SessionStore inner) { + if (!enableCaching) { + return inner; + } + return new CachingSessionStore(inner, storeSupplier.segmentIntervalMs()); + } + + private SessionStore maybeWrapLogging(final SessionStore inner) { + if (!enableLogging) { + return inner; + } + return new ChangeLoggingSessionBytesStore(inner); + } + + public long retentionPeriod() { + return storeSupplier.retentionPeriod(); + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/StateStoreProvider.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/StateStoreProvider.java new file mode 100644 index 0000000..bc14091 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/StateStoreProvider.java @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.streams.processor.StateStore; +import org.apache.kafka.streams.state.QueryableStoreType; +import org.apache.kafka.streams.state.QueryableStoreTypes; + +import java.util.List; + +/** + * Provides access to {@link StateStore}s that have been created + * as part of the {@link org.apache.kafka.streams.processor.internals.ProcessorTopology}. + * To get access to custom stores developers should implement {@link QueryableStoreType}. + * @see QueryableStoreTypes + */ +public interface StateStoreProvider { + + /** + * Find instances of StateStore that are accepted by {@link QueryableStoreType#accepts} and + * have the provided storeName. 
+ * + * @param storeName name of the store + * @param queryableStoreType filter stores based on this queryableStoreType + * @param The type of the Store + * @return List of the instances of the store in this topology. Empty List if not found + */ + List stores(String storeName, QueryableStoreType queryableStoreType); +} diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/StreamThreadStateStoreProvider.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/StreamThreadStateStoreProvider.java new file mode 100644 index 0000000..7dd796e --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/StreamThreadStateStoreProvider.java @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.streams.StoreQueryParameters; +import org.apache.kafka.streams.errors.InvalidStateStoreException; +import org.apache.kafka.streams.processor.StateStore; +import org.apache.kafka.streams.processor.TaskId; +import org.apache.kafka.streams.processor.internals.StreamThread; +import org.apache.kafka.streams.processor.internals.Task; +import org.apache.kafka.streams.state.QueryableStoreType; +import org.apache.kafka.streams.state.QueryableStoreTypes; +import org.apache.kafka.streams.state.TimestampedKeyValueStore; +import org.apache.kafka.streams.state.TimestampedWindowStore; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.List; + +public class StreamThreadStateStoreProvider { + + private final StreamThread streamThread; + + public StreamThreadStateStoreProvider(final StreamThread streamThread) { + this.streamThread = streamThread; + } + + @SuppressWarnings("unchecked") + public List stores(final StoreQueryParameters storeQueryParams) { + final String storeName = storeQueryParams.storeName(); + final QueryableStoreType queryableStoreType = storeQueryParams.queryableStoreType(); + if (streamThread.state() == StreamThread.State.DEAD) { + return Collections.emptyList(); + } + final StreamThread.State state = streamThread.state(); + if (storeQueryParams.staleStoresEnabled() ? state.isAlive() : state == StreamThread.State.RUNNING) { + final Collection tasks = storeQueryParams.staleStoresEnabled() ? 
+ streamThread.allTasks().values() : streamThread.activeTasks(); + + if (storeQueryParams.partition() != null) { + for (final Task task : tasks) { + if (task.id().partition() == storeQueryParams.partition() && + task.getStore(storeName) != null && + storeName.equals(task.getStore(storeName).name())) { + final T typedStore = validateAndCastStores(task.getStore(storeName), queryableStoreType, storeName, task.id()); + return Collections.singletonList(typedStore); + } + } + return Collections.emptyList(); + } else { + final List list = new ArrayList<>(); + for (final Task task : tasks) { + final StateStore store = task.getStore(storeName); + if (store == null) { + // then this task doesn't have that store + } else { + final T typedStore = validateAndCastStores(store, queryableStoreType, storeName, task.id()); + list.add(typedStore); + } + } + return list; + } + } else { + throw new InvalidStateStoreException("Cannot get state store " + storeName + " because the stream thread is " + + state + ", not RUNNING" + + (storeQueryParams.staleStoresEnabled() ? " or REBALANCING" : "")); + } + } + + @SuppressWarnings("unchecked") + private static T validateAndCastStores(final StateStore store, + final QueryableStoreType queryableStoreType, + final String storeName, + final TaskId taskId) { + if (store == null) { + throw new NullPointerException("Expected store not to be null at this point."); + } else if (queryableStoreType.accepts(store)) { + if (!store.isOpen()) { + throw new InvalidStateStoreException( + "Cannot get state store " + storeName + " for task " + taskId + + " because the store is not open. " + + "The state store may have migrated to another instance."); + } + if (store instanceof TimestampedKeyValueStore && queryableStoreType instanceof QueryableStoreTypes.KeyValueStoreType) { + return (T) new ReadOnlyKeyValueStoreFacade<>((TimestampedKeyValueStore) store); + } else if (store instanceof TimestampedWindowStore && queryableStoreType instanceof QueryableStoreTypes.WindowStoreType) { + return (T) new ReadOnlyWindowStoreFacade<>((TimestampedWindowStore) store); + } else { + return (T) store; + } + } else { + throw new InvalidStateStoreException( + "Cannot get state store " + storeName + + " because the queryable store type [" + queryableStoreType.getClass() + + "] does not accept the actual store type [" + store.getClass() + "]." + ); + } + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/StreamsMetadataImpl.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/StreamsMetadataImpl.java new file mode 100644 index 0000000..6bd314c --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/StreamsMetadataImpl.java @@ -0,0 +1,164 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
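StreamThreadStateStoreProvider is the internal half of interactive queries; from application code the same path is reached through KafkaStreams#store with StoreQueryParameters, where enableStaleStores() corresponds to the staleStoresEnabled() branch above that also admits standby and not-yet-running copies. A usage sketch, assuming a running KafkaStreams instance and a key-value store named "counts" (both hypothetical):

import org.apache.kafka.streams.KafkaStreams;
import org.apache.kafka.streams.StoreQueryParameters;
import org.apache.kafka.streams.state.QueryableStoreTypes;
import org.apache.kafka.streams.state.ReadOnlyKeyValueStore;

public class InteractiveQueryExample {

    // Illustrative only: assumes a running KafkaStreams instance with a store named "counts".
    static Long lookup(final KafkaStreams streams, final String key) {
        final ReadOnlyKeyValueStore<String, Long> store = streams.store(
                StoreQueryParameters
                        .fromNameAndType("counts", QueryableStoreTypes.<String, Long>keyValueStore())
                        // opt in to standby/restoring copies, i.e. the staleStoresEnabled() path above
                        .enableStaleStores());
        return store.get(key);
    }
}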
+ */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.streams.KafkaStreams; +import org.apache.kafka.streams.StreamsMetadata; +import org.apache.kafka.streams.state.HostInfo; + +import java.util.Collections; +import java.util.Objects; +import java.util.Set; + +/** + * Represents the state of an instance (process) in a {@link KafkaStreams} application. + * It contains the user supplied {@link HostInfo} that can be used by developers to build + * APIs and services to connect to other instances, the Set of state stores available on + * the instance and the Set of {@link TopicPartition}s available on the instance. + * NOTE: This is a point in time view. It may change when rebalances happen. + */ +public class StreamsMetadataImpl implements StreamsMetadata { + /** + * Sentinel to indicate that the StreamsMetadata is currently unavailable. This can occur during rebalance + * operations. + */ + public final static StreamsMetadataImpl NOT_AVAILABLE = new StreamsMetadataImpl( + HostInfo.unavailable(), + Collections.emptySet(), + Collections.emptySet(), + Collections.emptySet(), + Collections.emptySet()); + + private final HostInfo hostInfo; + + private final Set stateStoreNames; + + private final Set topicPartitions; + + private final Set standbyStateStoreNames; + + private final Set standbyTopicPartitions; + + public StreamsMetadataImpl(final HostInfo hostInfo, + final Set stateStoreNames, + final Set topicPartitions, + final Set standbyStateStoreNames, + final Set standbyTopicPartitions) { + + this.hostInfo = hostInfo; + this.stateStoreNames = Collections.unmodifiableSet(stateStoreNames); + this.topicPartitions = Collections.unmodifiableSet(topicPartitions); + this.standbyTopicPartitions = Collections.unmodifiableSet(standbyTopicPartitions); + this.standbyStateStoreNames = Collections.unmodifiableSet(standbyStateStoreNames); + } + + /** + * The value of {@link org.apache.kafka.streams.StreamsConfig#APPLICATION_SERVER_CONFIG} configured for the streams + * instance, which is typically host/port + * + * @return {@link HostInfo} corresponding to the streams instance + */ + @Override + public HostInfo hostInfo() { + return hostInfo; + } + + /** + * State stores owned by the instance as an active replica + * + * @return set of active state store names + */ + @Override + public Set stateStoreNames() { + return stateStoreNames; + } + + /** + * Topic partitions consumed by the instance as an active replica + * + * @return set of active topic partitions + */ + @Override + public Set topicPartitions() { + return topicPartitions; + } + + /** + * (Source) Topic partitions for which the instance acts as standby. 
+ * + * @return set of standby topic partitions + */ + @Override + public Set standbyTopicPartitions() { + return standbyTopicPartitions; + } + + /** + * State stores owned by the instance as a standby replica + * + * @return set of standby state store names + */ + @Override + public Set standbyStateStoreNames() { + return standbyStateStoreNames; + } + + @Override + public String host() { + return hostInfo.host(); + } + + @SuppressWarnings("unused") + @Override + public int port() { + return hostInfo.port(); + } + + @Override + public boolean equals(final Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + + final StreamsMetadataImpl that = (StreamsMetadataImpl) o; + return Objects.equals(hostInfo, that.hostInfo) + && Objects.equals(stateStoreNames, that.stateStoreNames) + && Objects.equals(topicPartitions, that.topicPartitions) + && Objects.equals(standbyStateStoreNames, that.standbyStateStoreNames) + && Objects.equals(standbyTopicPartitions, that.standbyTopicPartitions); + } + + @Override + public int hashCode() { + return Objects.hash(hostInfo, stateStoreNames, topicPartitions, standbyStateStoreNames, standbyTopicPartitions); + } + + @Override + public String toString() { + return "StreamsMetadata {" + + "hostInfo=" + hostInfo + + ", stateStoreNames=" + stateStoreNames + + ", topicPartitions=" + topicPartitions + + ", standbyStateStoreNames=" + standbyStateStoreNames + + ", standbyTopicPartitions=" + standbyTopicPartitions + + '}'; + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/ThreadCache.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/ThreadCache.java new file mode 100644 index 0000000..cf3a39a --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/ThreadCache.java @@ -0,0 +1,381 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
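StreamsMetadataImpl above is what an application sees when it asks a KafkaStreams instance which hosts carry which stores, typically to route interactive queries across instances. A hedged usage sketch, assuming a running KafkaStreams instance whose peers all set APPLICATION_SERVER_CONFIG, with method names taken from the public KafkaStreams/StreamsMetadata API:

import org.apache.kafka.streams.KafkaStreams;
import org.apache.kafka.streams.StreamsMetadata;

public class MetadataDiscoveryExample {

    // Illustrative only: assumes a running KafkaStreams instance and that every
    // instance of the application sets StreamsConfig.APPLICATION_SERVER_CONFIG.
    static void printHostsFor(final KafkaStreams streams, final String storeName) {
        for (final StreamsMetadata metadata : streams.metadataForAllStreamsClients()) {
            if (metadata.stateStoreNames().contains(storeName)
                    || metadata.standbyStateStoreNames().contains(storeName)) {
                System.out.println(metadata.host() + ":" + metadata.port());
            }
        }
    }
}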
+ */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.common.utils.CircularIterator; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl; +import org.slf4j.Logger; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.NoSuchElementException; + +/** + * An in-memory LRU cache store similar to {@link MemoryLRUCache} but byte-based, not + * record based + */ +public class ThreadCache { + private final Logger log; + private volatile long maxCacheSizeBytes; + private final StreamsMetricsImpl metrics; + private final Map caches = new HashMap<>(); + + // internal stats + private long numPuts = 0; + private long numGets = 0; + private long numEvicts = 0; + private long numFlushes = 0; + + public interface DirtyEntryFlushListener { + void apply(final List dirty); + } + + public ThreadCache(final LogContext logContext, final long maxCacheSizeBytes, final StreamsMetricsImpl metrics) { + this.maxCacheSizeBytes = maxCacheSizeBytes; + this.metrics = metrics; + this.log = logContext.logger(getClass()); + } + + public long puts() { + return numPuts; + } + + public long gets() { + return numGets; + } + + public long evicts() { + return numEvicts; + } + + public long flushes() { + return numFlushes; + } + + public synchronized void resize(final long newCacheSizeBytes) { + final boolean shrink = newCacheSizeBytes < maxCacheSizeBytes; + maxCacheSizeBytes = newCacheSizeBytes; + if (shrink) { + log.debug("Cache size was shrunk to {}", newCacheSizeBytes); + if (caches.values().isEmpty()) { + return; + } + final CircularIterator circularIterator = new CircularIterator<>(caches.values()); + while (sizeBytes() > maxCacheSizeBytes) { + final NamedCache cache = circularIterator.next(); + cache.evict(); + numEvicts++; + } + } else { + log.debug("Cache size was expanded to {}", newCacheSizeBytes); + } + } + + /** + * The thread cache maintains a set of {@link NamedCache}s whose names are a concatenation of the task ID and the + * underlying store name. This method creates those names. + * @param taskIDString Task ID + * @param underlyingStoreName Underlying store name + */ + public static String nameSpaceFromTaskIdAndStore(final String taskIDString, final String underlyingStoreName) { + return taskIDString + "-" + underlyingStoreName; + } + + /** + * Given a cache name of the form taskid-storename, return the task ID. + */ + public static String taskIDfromCacheName(final String cacheName) { + final String[] tokens = cacheName.split("-", 2); + return tokens[0]; + } + + /** + * Given a cache name of the form taskid-storename, return the store name. 
+ */ + public static String underlyingStoreNamefromCacheName(final String cacheName) { + final String[] tokens = cacheName.split("-", 2); + return tokens[1]; + } + + + /** + * Add a listener that is called each time an entry is evicted from the cache or an explicit flush is called + */ + public void addDirtyEntryFlushListener(final String namespace, final DirtyEntryFlushListener listener) { + final NamedCache cache = getOrCreateCache(namespace); + cache.setListener(listener); + } + + public void flush(final String namespace) { + numFlushes++; + + final NamedCache cache = getCache(namespace); + if (cache == null) { + return; + } + cache.flush(); + + if (log.isTraceEnabled()) { + log.trace("Cache stats on flush: #puts={}, #gets={}, #evicts={}, #flushes={}", puts(), gets(), evicts(), flushes()); + } + } + + public LRUCacheEntry get(final String namespace, final Bytes key) { + numGets++; + + if (key == null) { + return null; + } + + final NamedCache cache = getCache(namespace); + if (cache == null) { + return null; + } + return cache.get(key); + } + + public void put(final String namespace, final Bytes key, final LRUCacheEntry value) { + numPuts++; + + final NamedCache cache = getOrCreateCache(namespace); + cache.put(key, value); + maybeEvict(namespace); + } + + public LRUCacheEntry putIfAbsent(final String namespace, final Bytes key, final LRUCacheEntry value) { + final NamedCache cache = getOrCreateCache(namespace); + + final LRUCacheEntry result = cache.putIfAbsent(key, value); + maybeEvict(namespace); + + if (result == null) { + numPuts++; + } + return result; + } + + public void putAll(final String namespace, final List> entries) { + for (final KeyValue entry : entries) { + put(namespace, entry.key, entry.value); + } + } + + public LRUCacheEntry delete(final String namespace, final Bytes key) { + final NamedCache cache = getCache(namespace); + if (cache == null) { + return null; + } + + return cache.delete(key); + } + + public MemoryLRUCacheBytesIterator range(final String namespace, final Bytes from, final Bytes to) { + return range(namespace, from, to, true); + } + + public MemoryLRUCacheBytesIterator range(final String namespace, final Bytes from, final Bytes to, final boolean toInclusive) { + final NamedCache cache = getCache(namespace); + if (cache == null) { + return new MemoryLRUCacheBytesIterator(Collections.emptyIterator(), new NamedCache(namespace, this.metrics)); + } + return new MemoryLRUCacheBytesIterator(cache.keyRange(from, to, toInclusive), cache); + } + + public MemoryLRUCacheBytesIterator reverseRange(final String namespace, final Bytes from, final Bytes to) { + final NamedCache cache = getCache(namespace); + if (cache == null) { + return new MemoryLRUCacheBytesIterator(Collections.emptyIterator(), new NamedCache(namespace, this.metrics)); + } + return new MemoryLRUCacheBytesIterator(cache.reverseKeyRange(from, to), cache); + } + + public MemoryLRUCacheBytesIterator all(final String namespace) { + final NamedCache cache = getCache(namespace); + if (cache == null) { + return new MemoryLRUCacheBytesIterator(Collections.emptyIterator(), new NamedCache(namespace, this.metrics)); + } + return new MemoryLRUCacheBytesIterator(cache.allKeys(), cache); + } + + public MemoryLRUCacheBytesIterator reverseAll(final String namespace) { + final NamedCache cache = getCache(namespace); + if (cache == null) { + return new MemoryLRUCacheBytesIterator(Collections.emptyIterator(), new NamedCache(namespace, this.metrics)); + } + return new MemoryLRUCacheBytesIterator(cache.reverseAllKeys(), 
cache); + } + + public long size() { + long size = 0; + for (final NamedCache cache : caches.values()) { + size += cache.size(); + if (isOverflowing(size)) { + return Long.MAX_VALUE; + } + } + return size; + } + + private boolean isOverflowing(final long size) { + return size < 0; + } + + long sizeBytes() { + long sizeInBytes = 0; + for (final NamedCache namedCache : caches.values()) { + sizeInBytes += namedCache.sizeInBytes(); + if (isOverflowing(sizeInBytes)) { + return Long.MAX_VALUE; + } + } + return sizeInBytes; + } + + synchronized void close(final String namespace) { + final NamedCache removed = caches.remove(namespace); + if (removed != null) { + removed.close(); + } + } + + private void maybeEvict(final String namespace) { + int numEvicted = 0; + while (sizeBytes() > maxCacheSizeBytes) { + final NamedCache cache = getOrCreateCache(namespace); + // we abort here as the put on this cache may have triggered + // a put on another cache. So even though the sizeInBytes() is + // still > maxCacheSizeBytes there is nothing to evict from this + // namespaced cache. + if (cache.isEmpty()) { + return; + } + cache.evict(); + numEvicts++; + numEvicted++; + } + if (log.isTraceEnabled()) { + log.trace("Evicted {} entries from cache {}", numEvicted, namespace); + } + } + + private synchronized NamedCache getCache(final String namespace) { + return caches.get(namespace); + } + + private synchronized NamedCache getOrCreateCache(final String name) { + NamedCache cache = caches.get(name); + if (cache == null) { + cache = new NamedCache(name, this.metrics); + caches.put(name, cache); + } + return cache; + } + + static class MemoryLRUCacheBytesIterator implements PeekingKeyValueIterator { + private final Iterator keys; + private final NamedCache cache; + private KeyValue nextEntry; + + MemoryLRUCacheBytesIterator(final Iterator keys, final NamedCache cache) { + this.keys = keys; + this.cache = cache; + } + + public Bytes peekNextKey() { + if (!hasNext()) { + throw new NoSuchElementException(); + } + return nextEntry.key; + } + + + public KeyValue peekNext() { + if (!hasNext()) { + throw new NoSuchElementException(); + } + return nextEntry; + } + + @Override + public boolean hasNext() { + if (nextEntry != null) { + return true; + } + + while (keys.hasNext() && nextEntry == null) { + internalNext(); + } + + return nextEntry != null; + } + + @Override + public KeyValue next() { + if (!hasNext()) { + throw new NoSuchElementException(); + } + final KeyValue result = nextEntry; + nextEntry = null; + return result; + } + + private void internalNext() { + final Bytes cacheKey = keys.next(); + final LRUCacheEntry entry = cache.get(cacheKey); + if (entry == null) { + return; + } + + nextEntry = new KeyValue<>(cacheKey, entry); + } + + @Override + public void close() { + // do nothing + } + } + + static class DirtyEntry { + private final Bytes key; + private final byte[] newValue; + private final LRUCacheEntry recordContext; + + DirtyEntry(final Bytes key, final byte[] newValue, final LRUCacheEntry recordContext) { + this.key = key; + this.newValue = newValue; + this.recordContext = recordContext; + } + + public Bytes key() { + return key; + } + + public byte[] newValue() { + return newValue; + } + + public LRUCacheEntry entry() { + return recordContext; + } + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/TimeOrderedKeyValueBuffer.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/TimeOrderedKeyValueBuffer.java new file mode 100644 index 0000000..e2096ac --- 
/dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/TimeOrderedKeyValueBuffer.java @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.streams.kstream.internals.Change; +import org.apache.kafka.streams.processor.StateStore; +import org.apache.kafka.streams.processor.api.Record; +import org.apache.kafka.streams.processor.internals.ProcessorRecordContext; +import org.apache.kafka.streams.processor.internals.SerdeGetter; +import org.apache.kafka.streams.state.ValueAndTimestamp; + +import java.util.Objects; +import java.util.function.Consumer; +import java.util.function.Supplier; + +public interface TimeOrderedKeyValueBuffer extends StateStore { + + final class Eviction { + private final K key; + private final Change value; + private final ProcessorRecordContext recordContext; + + Eviction(final K key, final Change value, final ProcessorRecordContext recordContext) { + this.key = key; + this.value = value; + this.recordContext = recordContext; + } + + public K key() { + return key; + } + + public Change value() { + return value; + } + + public Record> record() { + return new Record<>(key, value, recordContext.timestamp()); + } + + public ProcessorRecordContext recordContext() { + return recordContext; + } + + @Override + public String toString() { + return "Eviction{key=" + key + ", value=" + value + ", recordContext=" + recordContext + '}'; + } + + @Override + public boolean equals(final Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + final Eviction eviction = (Eviction) o; + return Objects.equals(key, eviction.key) && + Objects.equals(value, eviction.value) && + Objects.equals(recordContext, eviction.recordContext); + } + + @Override + public int hashCode() { + return Objects.hash(key, value, recordContext); + } + + } + + void setSerdesIfNull(final SerdeGetter getter); + + void evictWhile(final Supplier predicate, final Consumer> callback); + + Maybe> priorValueForBuffered(K key); + + void put(long time, Record> record, ProcessorRecordContext recordContext); + + int numRecords(); + + long bufferSize(); + + long minTimestamp(); +} diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/TimeOrderedKeyValueBufferChangelogDeserializationHelper.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/TimeOrderedKeyValueBufferChangelogDeserializationHelper.java new file mode 100644 index 0000000..74489c2 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/TimeOrderedKeyValueBufferChangelogDeserializationHelper.java @@ -0,0 +1,158 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * 
contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.streams.kstream.internals.Change; +import org.apache.kafka.streams.kstream.internals.FullChangeSerde; +import org.apache.kafka.streams.processor.internals.ProcessorRecordContext; + +import java.nio.ByteBuffer; + +import static java.util.Objects.requireNonNull; + +final class TimeOrderedKeyValueBufferChangelogDeserializationHelper { + private TimeOrderedKeyValueBufferChangelogDeserializationHelper() {} + + static final class DeserializationResult { + private final long time; + private final Bytes key; + private final BufferValue bufferValue; + + private DeserializationResult(final long time, final Bytes key, final BufferValue bufferValue) { + this.time = time; + this.key = key; + this.bufferValue = bufferValue; + } + + long time() { + return time; + } + + Bytes key() { + return key; + } + + BufferValue bufferValue() { + return bufferValue; + } + } + + static DeserializationResult deserializeV0(final ConsumerRecord record, + final Bytes key, + final byte[] previousBufferedValue) { + + final ByteBuffer timeAndValue = ByteBuffer.wrap(record.value()); + final long time = timeAndValue.getLong(); + final byte[] changelogValue = new byte[record.value().length - 8]; + timeAndValue.get(changelogValue); + + final Change change = requireNonNull(FullChangeSerde.decomposeLegacyFormattedArrayIntoChangeArrays(changelogValue)); + + final ProcessorRecordContext recordContext = new ProcessorRecordContext( + record.timestamp(), + record.offset(), + record.partition(), + record.topic(), + record.headers() + ); + + return new DeserializationResult( + time, + key, + new BufferValue( + previousBufferedValue == null ? change.oldValue : previousBufferedValue, + change.oldValue, + change.newValue, + recordContext + ) + ); + } + + static DeserializationResult deserializeV1(final ConsumerRecord record, + final Bytes key, + final byte[] previousBufferedValue) { + final ByteBuffer timeAndValue = ByteBuffer.wrap(record.value()); + final long time = timeAndValue.getLong(); + final byte[] changelogValue = new byte[record.value().length - 8]; + timeAndValue.get(changelogValue); + + final ContextualRecord contextualRecord = ContextualRecord.deserialize(ByteBuffer.wrap(changelogValue)); + final Change change = requireNonNull(FullChangeSerde.decomposeLegacyFormattedArrayIntoChangeArrays(contextualRecord.value())); + + return new DeserializationResult( + time, + key, + new BufferValue( + previousBufferedValue == null ? 
change.oldValue : previousBufferedValue, + change.oldValue, + change.newValue, + contextualRecord.recordContext() + ) + ); + } + + static DeserializationResult duckTypeV2(final ConsumerRecord record, final Bytes key) { + DeserializationResult deserializationResult = null; + RuntimeException v2DeserializationException = null; + RuntimeException v3DeserializationException = null; + try { + deserializationResult = deserializeV2(record, key); + } catch (final RuntimeException e) { + v2DeserializationException = e; + } + // versions 2.4.0, 2.4.1, and 2.5.0 would have erroneously encoded a V3 record with the + // V2 header, so we'll try duck-typing to see if this is decodable as V3 + if (deserializationResult == null) { + try { + deserializationResult = deserializeV3(record, key); + } catch (final RuntimeException e) { + v3DeserializationException = e; + } + } + + if (deserializationResult == null) { + // ok, it wasn't V3 either. Throw both exceptions: + final RuntimeException exception = + new RuntimeException("Couldn't deserialize record as v2 or v3: " + record, + v2DeserializationException); + exception.addSuppressed(v3DeserializationException); + throw exception; + } + return deserializationResult; + } + + private static DeserializationResult deserializeV2(final ConsumerRecord record, + final Bytes key) { + final ByteBuffer valueAndTime = ByteBuffer.wrap(record.value()); + final ContextualRecord contextualRecord = ContextualRecord.deserialize(valueAndTime); + final Change change = requireNonNull(FullChangeSerde.decomposeLegacyFormattedArrayIntoChangeArrays(contextualRecord.value())); + final byte[] priorValue = Utils.getNullableSizePrefixedArray(valueAndTime); + final long time = valueAndTime.getLong(); + final BufferValue bufferValue = new BufferValue(priorValue, change.oldValue, change.newValue, contextualRecord.recordContext()); + return new DeserializationResult(time, key, bufferValue); + } + + static DeserializationResult deserializeV3(final ConsumerRecord record, final Bytes key) { + final ByteBuffer valueAndTime = ByteBuffer.wrap(record.value()); + final BufferValue bufferValue = BufferValue.deserialize(valueAndTime); + final long time = valueAndTime.getLong(); + return new DeserializationResult(time, key, bufferValue); + } +} \ No newline at end of file diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/TimestampedKeyAndJoinSide.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/TimestampedKeyAndJoinSide.java new file mode 100644 index 0000000..c0516e1 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/TimestampedKeyAndJoinSide.java @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.streams.KeyValue; + +import java.util.Objects; + +/** + * Combines a timestamped key from a {@link KeyValue} with a boolean value referencing if the key is + * part of the left join (true) or right join (false). This class is only useful when a state + * store needs to be shared between left and right processors, and each processor needs to + * access the key of the other processor. + * + * Note that it might be cleaner to have two layers for such usages: first a KeyAndJoinSide, where the Key + * is in the form of a ; but with the nested structure serdes would need extra byte array copies. + * Since it is only used in a single place today we decided to combine them into a single type / serde. + */ +public class TimestampedKeyAndJoinSide { + private final K key; + private final long timestamp; + private final boolean leftSide; + + private TimestampedKeyAndJoinSide(final boolean leftSide, final K key, final long timestamp) { + this.key = Objects.requireNonNull(key, "key cannot be null"); + this.leftSide = leftSide; + this.timestamp = timestamp; + } + + /** + * Create a new {@link TimestampedKeyAndJoinSide} instance if the provide {@code key} is not {@code null}. + * + * @param leftSide True if the key is part of the left join side; False if it is from the right join side + * @param key the key + * @param the type of the key + * @return a new {@link TimestampedKeyAndJoinSide} instance if the provide {@code key} is not {@code null} + */ + public static TimestampedKeyAndJoinSide make(final boolean leftSide, final K key, final long timestamp) { + return new TimestampedKeyAndJoinSide<>(leftSide, key, timestamp); + } + + public boolean isLeftSide() { + return leftSide; + } + + public K getKey() { + return key; + } + + public long getTimestamp() { + return timestamp; + } + + @Override + public String toString() { + final String joinSide = leftSide ? "left" : "right"; + return "<" + joinSide + "," + key + ":" + timestamp + ">"; + } + + @Override + public boolean equals(final Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + final TimestampedKeyAndJoinSide that = (TimestampedKeyAndJoinSide) o; + return leftSide == that.leftSide && + Objects.equals(key, that.key) && + timestamp == that.timestamp; + } + + @Override + public int hashCode() { + return Objects.hash(leftSide, key, timestamp); + } +} \ No newline at end of file diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/TimestampedKeyAndJoinSideDeserializer.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/TimestampedKeyAndJoinSideDeserializer.java new file mode 100644 index 0000000..9ecea46 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/TimestampedKeyAndJoinSideDeserializer.java @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.common.serialization.Deserializer; +import org.apache.kafka.common.serialization.LongDeserializer; +import org.apache.kafka.streams.kstream.internals.WrappingNullableDeserializer; +import org.apache.kafka.streams.processor.internals.SerdeGetter; +import org.apache.kafka.streams.state.StateSerdes; + +import java.util.Map; + +import static org.apache.kafka.streams.kstream.internals.WrappingNullableUtils.initNullableDeserializer; + +/** + * The deserializer that is used for {@link TimestampedKeyAndJoinSide}, which is a combo key format of + * @param the raw key type + */ +public class TimestampedKeyAndJoinSideDeserializer implements WrappingNullableDeserializer, K, Void> { + private Deserializer keyDeserializer; + private final Deserializer timestampDeserializer = new LongDeserializer(); + + TimestampedKeyAndJoinSideDeserializer(final Deserializer keyDeserializer) { + this.keyDeserializer = keyDeserializer; + } + + @SuppressWarnings("unchecked") + @Override + public void setIfUnset(final SerdeGetter getter) { + if (keyDeserializer == null) { + keyDeserializer = (Deserializer) getter.keySerde().deserializer(); + } + + initNullableDeserializer(keyDeserializer, getter); + } + + @Override + public void configure(final Map configs, final boolean isKey) { + keyDeserializer.configure(configs, isKey); + } + + @Override + public TimestampedKeyAndJoinSide deserialize(final String topic, final byte[] data) { + final boolean bool = data[StateSerdes.TIMESTAMP_SIZE] == 1; + final K key = keyDeserializer.deserialize(topic, rawKey(data)); + final long timestamp = timestampDeserializer.deserialize(topic, rawTimestamp(data)); + + return TimestampedKeyAndJoinSide.make(bool, key, timestamp); + } + + private byte[] rawTimestamp(final byte[] data) { + final byte[] rawTimestamp = new byte[8]; + System.arraycopy(data, 0, rawTimestamp, 0, 8); + return rawTimestamp; + } + + private byte[] rawKey(final byte[] data) { + final byte[] rawKey = new byte[data.length - StateSerdes.TIMESTAMP_SIZE - StateSerdes.BOOLEAN_SIZE]; + System.arraycopy(data, StateSerdes.TIMESTAMP_SIZE + StateSerdes.BOOLEAN_SIZE, rawKey, 0, rawKey.length); + return rawKey; + } + + @Override + public void close() { + keyDeserializer.close(); + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/TimestampedKeyAndJoinSideSerde.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/TimestampedKeyAndJoinSideSerde.java new file mode 100644 index 0000000..6ae1923 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/TimestampedKeyAndJoinSideSerde.java @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.streams.kstream.internals.WrappingNullableSerde; + +public class TimestampedKeyAndJoinSideSerde extends WrappingNullableSerde, K, Void> { + + public TimestampedKeyAndJoinSideSerde(final Serde keySerde) { + super( + new TimestampedKeyAndJoinSideSerializer<>(keySerde != null ? keySerde.serializer() : null), + new TimestampedKeyAndJoinSideDeserializer<>(keySerde != null ? keySerde.deserializer() : null) + ); + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/TimestampedKeyAndJoinSideSerializer.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/TimestampedKeyAndJoinSideSerializer.java new file mode 100644 index 0000000..801c417 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/TimestampedKeyAndJoinSideSerializer.java @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.common.serialization.LongSerializer; +import org.apache.kafka.common.serialization.Serializer; +import org.apache.kafka.streams.kstream.internals.WrappingNullableSerializer; +import org.apache.kafka.streams.processor.internals.SerdeGetter; + +import java.nio.ByteBuffer; +import java.util.Map; + +import static org.apache.kafka.streams.kstream.internals.WrappingNullableUtils.initNullableSerializer; + +/** + * The serializer that is used for {@link TimestampedKeyAndJoinSide}, which is a combo key format of + * @param the raw key type + */ +public class TimestampedKeyAndJoinSideSerializer implements WrappingNullableSerializer, K, Void> { + private Serializer keySerializer; + private final Serializer timestampSerializer = new LongSerializer(); + + TimestampedKeyAndJoinSideSerializer(final Serializer keySerializer) { + this.keySerializer = keySerializer; + } + + @SuppressWarnings("unchecked") + @Override + public void setIfUnset(final SerdeGetter getter) { + if (keySerializer == null) { + keySerializer = (Serializer) getter.keySerde().serializer(); + } + + initNullableSerializer(keySerializer, getter); + } + + @Override + public void configure(final Map configs, final boolean isKey) { + keySerializer.configure(configs, isKey); + } + + @Override + public byte[] serialize(final String topic, final TimestampedKeyAndJoinSide data) { + final byte boolByte = (byte) (data.isLeftSide() ? 1 : 0); + final byte[] keyBytes = keySerializer.serialize(topic, data.getKey()); + final byte[] timestampBytes = timestampSerializer.serialize(topic, data.getTimestamp()); + + return ByteBuffer + .allocate(timestampBytes.length + 1 + keyBytes.length) + .put(timestampBytes) + .put(boolByte) + .put(keyBytes) + .array(); + } + + @Override + public void close() { + keySerializer.close(); + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/TimestampedKeyValueStoreBuilder.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/TimestampedKeyValueStoreBuilder.java new file mode 100644 index 0000000..a249a14 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/TimestampedKeyValueStoreBuilder.java @@ -0,0 +1,197 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.common.serialization.Serializer; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.processor.ProcessorContext; +import org.apache.kafka.streams.processor.StateStore; +import org.apache.kafka.streams.processor.StateStoreContext; +import org.apache.kafka.streams.state.KeyValueBytesStoreSupplier; +import org.apache.kafka.streams.state.KeyValueIterator; +import org.apache.kafka.streams.state.KeyValueStore; +import org.apache.kafka.streams.state.TimestampedBytesStore; +import org.apache.kafka.streams.state.TimestampedKeyValueStore; +import org.apache.kafka.streams.state.ValueAndTimestamp; + +import java.util.List; +import java.util.Objects; + +public class TimestampedKeyValueStoreBuilder + extends AbstractStoreBuilder, TimestampedKeyValueStore> { + + private final KeyValueBytesStoreSupplier storeSupplier; + + public TimestampedKeyValueStoreBuilder(final KeyValueBytesStoreSupplier storeSupplier, + final Serde keySerde, + final Serde valueSerde, + final Time time) { + super( + storeSupplier.name(), + keySerde, + valueSerde == null ? null : new ValueAndTimestampSerde<>(valueSerde), + time); + Objects.requireNonNull(storeSupplier, "storeSupplier can't be null"); + Objects.requireNonNull(storeSupplier.metricsScope(), "storeSupplier's metricsScope can't be null"); + this.storeSupplier = storeSupplier; + } + + @Override + public TimestampedKeyValueStore build() { + KeyValueStore store = storeSupplier.get(); + if (!(store instanceof TimestampedBytesStore)) { + if (store.persistent()) { + store = new KeyValueToTimestampedKeyValueByteStoreAdapter(store); + } else { + store = new InMemoryTimestampedKeyValueStoreMarker(store); + } + } + return new MeteredTimestampedKeyValueStore<>( + maybeWrapCaching(maybeWrapLogging(store)), + storeSupplier.metricsScope(), + time, + keySerde, + valueSerde); + } + + private KeyValueStore maybeWrapCaching(final KeyValueStore inner) { + if (!enableCaching) { + return inner; + } + return new CachingKeyValueStore(inner); + } + + private KeyValueStore maybeWrapLogging(final KeyValueStore inner) { + if (!enableLogging) { + return inner; + } + return new ChangeLoggingTimestampedKeyValueBytesStore(inner); + } + + private final static class InMemoryTimestampedKeyValueStoreMarker + implements KeyValueStore, TimestampedBytesStore { + + final KeyValueStore wrapped; + + private InMemoryTimestampedKeyValueStoreMarker(final KeyValueStore wrapped) { + if (wrapped.persistent()) { + throw new IllegalArgumentException("Provided store must not be a persistent store, but it is."); + } + this.wrapped = wrapped; + } + + @Deprecated + @Override + public void init(final ProcessorContext context, + final StateStore root) { + wrapped.init(context, root); + } + + @Override + public void init(final StateStoreContext context, final StateStore root) { + wrapped.init(context, root); + } + + @Override + public void put(final Bytes key, + final byte[] value) { + wrapped.put(key, value); + } + + @Override + public byte[] putIfAbsent(final Bytes key, + final byte[] value) { + return wrapped.putIfAbsent(key, value); + } + + @Override + public void putAll(final List> entries) { + wrapped.putAll(entries); + } + + @Override + public byte[] delete(final Bytes key) { + return wrapped.delete(key); + } + + @Override + public byte[] get(final Bytes key) { + return 
wrapped.get(key); + } + + @Override + public KeyValueIterator range(final Bytes from, + final Bytes to) { + return wrapped.range(from, to); + } + + @Override + public KeyValueIterator reverseRange(final Bytes from, + final Bytes to) { + return wrapped.reverseRange(from, to); + } + + @Override + public KeyValueIterator all() { + return wrapped.all(); + } + + @Override + public KeyValueIterator reverseAll() { + return wrapped.reverseAll(); + } + + @Override + public , P> KeyValueIterator prefixScan(final P prefix, + final PS prefixKeySerializer) { + return wrapped.prefixScan(prefix, prefixKeySerializer); + } + + @Override + public long approximateNumEntries() { + return wrapped.approximateNumEntries(); + } + + @Override + public void flush() { + wrapped.flush(); + } + + @Override + public void close() { + wrapped.close(); + } + + @Override + public boolean isOpen() { + return wrapped.isOpen(); + } + + @Override + public String name() { + return wrapped.name(); + } + + @Override + public boolean persistent() { + return false; + } + } +} \ No newline at end of file diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/TimestampedSegment.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/TimestampedSegment.java new file mode 100644 index 0000000..f0e4cf6 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/TimestampedSegment.java @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.streams.state.internals.metrics.RocksDBMetricsRecorder; + +import java.io.File; +import java.io.IOException; +import java.util.Map; +import java.util.Objects; + +class TimestampedSegment extends RocksDBTimestampedStore implements Comparable, Segment { + public final long id; + + TimestampedSegment(final String segmentName, + final String windowName, + final long id, + final RocksDBMetricsRecorder metricsRecorder) { + super(segmentName, windowName, metricsRecorder); + this.id = id; + } + + @Override + public void destroy() throws IOException { + Utils.delete(dbDir); + } + + @Override + public void deleteRange(final Bytes keyFrom, final Bytes keyTo) { + throw new UnsupportedOperationException(); + } + + @Override + public int compareTo(final TimestampedSegment segment) { + return Long.compare(id, segment.id); + } + + @Override + public void openDB(final Map configs, final File stateDir) { + super.openDB(configs, stateDir); + // skip the registering step + } + + @Override + public String toString() { + return "TimestampedSegment(id=" + id + ", name=" + name() + ")"; + } + + @Override + public boolean equals(final Object obj) { + if (obj == null || getClass() != obj.getClass()) { + return false; + } + final TimestampedSegment segment = (TimestampedSegment) obj; + return id == segment.id; + } + + @Override + public int hashCode() { + return Objects.hash(id); + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/TimestampedSegments.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/TimestampedSegments.java new file mode 100644 index 0000000..7318208 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/TimestampedSegments.java @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.streams.processor.ProcessorContext; +import org.apache.kafka.streams.processor.internals.ProcessorContextUtils; +import org.apache.kafka.streams.state.internals.metrics.RocksDBMetricsRecorder; + +/** + * Manages the {@link TimestampedSegment}s that are used by the {@link RocksDBTimestampedSegmentedBytesStore} + */ +class TimestampedSegments extends AbstractSegments { + + private final RocksDBMetricsRecorder metricsRecorder; + + TimestampedSegments(final String name, + final String metricsScope, + final long retentionPeriod, + final long segmentInterval) { + super(name, retentionPeriod, segmentInterval); + metricsRecorder = new RocksDBMetricsRecorder(metricsScope, name); + } + + @Override + public TimestampedSegment getOrCreateSegment(final long segmentId, + final ProcessorContext context) { + if (segments.containsKey(segmentId)) { + return segments.get(segmentId); + } else { + final TimestampedSegment newSegment = + new TimestampedSegment(segmentName(segmentId), name, segmentId, metricsRecorder); + + if (segments.put(segmentId, newSegment) != null) { + throw new IllegalStateException("TimestampedSegment already exists. Possible concurrent access."); + } + + newSegment.openDB(context.appConfigs(), context.stateDir()); + return newSegment; + } + } + + @Override + public void openExisting(final ProcessorContext context, final long streamTime) { + metricsRecorder.init(ProcessorContextUtils.getMetricsImpl(context), context.taskId()); + super.openExisting(context, streamTime); + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/TimestampedWindowStoreBuilder.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/TimestampedWindowStoreBuilder.java new file mode 100644 index 0000000..b3727f5 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/TimestampedWindowStoreBuilder.java @@ -0,0 +1,215 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.streams.kstream.Windowed; +import org.apache.kafka.streams.processor.ProcessorContext; +import org.apache.kafka.streams.processor.StateStore; +import org.apache.kafka.streams.processor.StateStoreContext; +import org.apache.kafka.streams.state.KeyValueIterator; +import org.apache.kafka.streams.state.TimestampedBytesStore; +import org.apache.kafka.streams.state.TimestampedWindowStore; +import org.apache.kafka.streams.state.ValueAndTimestamp; +import org.apache.kafka.streams.state.WindowBytesStoreSupplier; +import org.apache.kafka.streams.state.WindowStore; +import org.apache.kafka.streams.state.WindowStoreIterator; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Objects; + +public class TimestampedWindowStoreBuilder + extends AbstractStoreBuilder, TimestampedWindowStore> { + + private static final Logger LOG = LoggerFactory.getLogger(TimestampedWindowStoreBuilder.class); + + private final WindowBytesStoreSupplier storeSupplier; + + public TimestampedWindowStoreBuilder(final WindowBytesStoreSupplier storeSupplier, + final Serde keySerde, + final Serde valueSerde, + final Time time) { + super(storeSupplier.name(), keySerde, valueSerde == null ? null : new ValueAndTimestampSerde<>(valueSerde), time); + Objects.requireNonNull(storeSupplier, "storeSupplier can't be null"); + Objects.requireNonNull(storeSupplier.metricsScope(), "storeSupplier's metricsScope can't be null"); + this.storeSupplier = storeSupplier; + } + + @Override + public TimestampedWindowStore build() { + WindowStore store = storeSupplier.get(); + if (!(store instanceof TimestampedBytesStore)) { + if (store.persistent()) { + store = new WindowToTimestampedWindowByteStoreAdapter(store); + } else { + store = new InMemoryTimestampedWindowStoreMarker(store); + } + } + if (storeSupplier.retainDuplicates() && enableCaching) { + LOG.warn("Disabling caching for {} since store was configured to retain duplicates", storeSupplier.name()); + enableCaching = false; + } + + return new MeteredTimestampedWindowStore<>( + maybeWrapCaching(maybeWrapLogging(store)), + storeSupplier.windowSize(), + storeSupplier.metricsScope(), + time, + keySerde, + valueSerde); + } + + private WindowStore maybeWrapCaching(final WindowStore inner) { + if (!enableCaching) { + return inner; + } + return new CachingWindowStore( + inner, + storeSupplier.windowSize(), + storeSupplier.segmentIntervalMs()); + } + + private WindowStore maybeWrapLogging(final WindowStore inner) { + if (!enableLogging) { + return inner; + } + return new ChangeLoggingTimestampedWindowBytesStore(inner, storeSupplier.retainDuplicates()); + } + + public long retentionPeriod() { + return storeSupplier.retentionPeriod(); + } + + + private final static class InMemoryTimestampedWindowStoreMarker + implements WindowStore, TimestampedBytesStore { + + private final WindowStore wrapped; + + private InMemoryTimestampedWindowStoreMarker(final WindowStore wrapped) { + if (wrapped.persistent()) { + throw new IllegalArgumentException("Provided store must not be a persistent store, but it is."); + } + this.wrapped = wrapped; + } + + @Deprecated + @Override + public void init(final ProcessorContext context, + final StateStore root) { + wrapped.init(context, root); + } + + @Override + public void init(final StateStoreContext context, final StateStore root) { + 
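+            // Delegate initialization to the wrapped non-persistent store; this marker class only adds the TimestampedBytesStore interface on top of the wrapped in-memory window store.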
wrapped.init(context, root); + } + + @Override + public void put(final Bytes key, + final byte[] value, + final long windowStartTimestamp) { + wrapped.put(key, value, windowStartTimestamp); + } + + @Override + public byte[] fetch(final Bytes key, + final long time) { + return wrapped.fetch(key, time); + } + + @Override + public WindowStoreIterator fetch(final Bytes key, + final long timeFrom, + final long timeTo) { + return wrapped.fetch(key, timeFrom, timeTo); + } + + @Override + public WindowStoreIterator backwardFetch(final Bytes key, + final long timeFrom, + final long timeTo) { + return wrapped.backwardFetch(key, timeFrom, timeTo); + } + + @Override + public KeyValueIterator, byte[]> fetch(final Bytes keyFrom, + final Bytes keyTo, + final long timeFrom, + final long timeTo) { + return wrapped.fetch(keyFrom, keyTo, timeFrom, timeTo); + } + + @Override + public KeyValueIterator, byte[]> backwardFetch(final Bytes keyFrom, + final Bytes keyTo, + final long timeFrom, + final long timeTo) { + return wrapped.backwardFetch(keyFrom, keyTo, timeFrom, timeTo); + } + + @Override + public KeyValueIterator, byte[]> fetchAll(final long timeFrom, + final long timeTo) { + return wrapped.fetchAll(timeFrom, timeTo); + } + + @Override + public KeyValueIterator, byte[]> backwardFetchAll(final long timeFrom, + final long timeTo) { + return wrapped.backwardFetchAll(timeFrom, timeTo); + } + + @Override + public KeyValueIterator, byte[]> all() { + return wrapped.all(); + } + + @Override + public KeyValueIterator, byte[]> backwardAll() { + return wrapped.backwardAll(); + } + + @Override + public void flush() { + wrapped.flush(); + } + + @Override + public void close() { + wrapped.close(); + } + + @Override + public boolean isOpen() { + return wrapped.isOpen(); + } + + @Override + public String name() { + return wrapped.name(); + } + + @Override + public boolean persistent() { + return false; + } + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/ValueAndTimestampDeserializer.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/ValueAndTimestampDeserializer.java new file mode 100644 index 0000000..e67d7a6 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/ValueAndTimestampDeserializer.java @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.common.serialization.Deserializer; +import org.apache.kafka.common.serialization.LongDeserializer; +import org.apache.kafka.streams.kstream.internals.WrappingNullableDeserializer; +import org.apache.kafka.streams.processor.internals.SerdeGetter; +import org.apache.kafka.streams.state.ValueAndTimestamp; + +import java.nio.ByteBuffer; +import java.util.Map; +import java.util.Objects; + +import static org.apache.kafka.streams.kstream.internals.WrappingNullableUtils.initNullableDeserializer; + +class ValueAndTimestampDeserializer implements WrappingNullableDeserializer, Void, V> { + private final static LongDeserializer LONG_DESERIALIZER = new LongDeserializer(); + + public final Deserializer valueDeserializer; + private final Deserializer timestampDeserializer; + + ValueAndTimestampDeserializer(final Deserializer valueDeserializer) { + Objects.requireNonNull(valueDeserializer); + this.valueDeserializer = valueDeserializer; + timestampDeserializer = new LongDeserializer(); + } + + @Override + public void configure(final Map configs, + final boolean isKey) { + valueDeserializer.configure(configs, isKey); + timestampDeserializer.configure(configs, isKey); + } + + @Override + public ValueAndTimestamp deserialize(final String topic, + final byte[] valueAndTimestamp) { + if (valueAndTimestamp == null) { + return null; + } + + final long timestamp = timestampDeserializer.deserialize(topic, rawTimestamp(valueAndTimestamp)); + final V value = valueDeserializer.deserialize(topic, rawValue(valueAndTimestamp)); + return ValueAndTimestamp.make(value, timestamp); + } + + @Override + public void close() { + valueDeserializer.close(); + timestampDeserializer.close(); + } + + static byte[] rawValue(final byte[] rawValueAndTimestamp) { + final int rawValueLength = rawValueAndTimestamp.length - 8; + + return ByteBuffer + .allocate(rawValueLength) + .put(rawValueAndTimestamp, 8, rawValueLength) + .array(); + } + + private static byte[] rawTimestamp(final byte[] rawValueAndTimestamp) { + return ByteBuffer + .allocate(8) + .put(rawValueAndTimestamp, 0, 8) + .array(); + } + + static long timestamp(final byte[] rawValueAndTimestamp) { + return LONG_DESERIALIZER.deserialize(null, rawTimestamp(rawValueAndTimestamp)); + } + + @Override + public void setIfUnset(final SerdeGetter getter) { + // ValueAndTimestampDeserializer never wraps a null deserializer (or configure would throw), + // but it may wrap a deserializer that itself wraps a null deserializer. + initNullableDeserializer(valueDeserializer, getter); + } +} \ No newline at end of file diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/ValueAndTimestampSerde.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/ValueAndTimestampSerde.java new file mode 100644 index 0000000..1936d29 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/ValueAndTimestampSerde.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.streams.kstream.internals.WrappingNullableSerde; +import org.apache.kafka.streams.state.ValueAndTimestamp; + +import static java.util.Objects.requireNonNull; + +public class ValueAndTimestampSerde extends WrappingNullableSerde, Void, V> { + public ValueAndTimestampSerde(final Serde valueSerde) { + super( + new ValueAndTimestampSerializer<>(requireNonNull(valueSerde, "valueSerde was null").serializer()), + new ValueAndTimestampDeserializer<>(requireNonNull(valueSerde, "valueSerde was null").deserializer()) + ); + } +} \ No newline at end of file diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/ValueAndTimestampSerializer.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/ValueAndTimestampSerializer.java new file mode 100644 index 0000000..e64e5e1 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/ValueAndTimestampSerializer.java @@ -0,0 +1,133 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.common.serialization.LongSerializer; +import org.apache.kafka.common.serialization.Serializer; +import org.apache.kafka.streams.kstream.internals.WrappingNullableSerializer; +import org.apache.kafka.streams.processor.internals.SerdeGetter; +import org.apache.kafka.streams.state.ValueAndTimestamp; + +import java.nio.ByteBuffer; +import java.util.Map; +import java.util.Objects; + +import static org.apache.kafka.streams.kstream.internals.WrappingNullableUtils.initNullableSerializer; + +public class ValueAndTimestampSerializer implements WrappingNullableSerializer, Void, V> { + public final Serializer valueSerializer; + private final Serializer timestampSerializer; + + ValueAndTimestampSerializer(final Serializer valueSerializer) { + Objects.requireNonNull(valueSerializer); + this.valueSerializer = valueSerializer; + timestampSerializer = new LongSerializer(); + } + + public static boolean valuesAreSameAndTimeIsIncreasing(final byte[] oldRecord, final byte[] newRecord) { + if (oldRecord == newRecord) { + // same reference, so they are trivially the same (might both be null) + return true; + } else if (oldRecord == null || newRecord == null) { + // only one is null, so they cannot be the same + return false; + } else if (newRecord.length != oldRecord.length) { + // they are different length, so they cannot be the same + return false; + } else if (timeIsDecreasing(oldRecord, newRecord)) { + // the record time represents the beginning of the validity interval, so if the time + // moves backwards, we need to do the update regardless of whether the value has changed + return false; + } else { + // all other checks have fallen through, so we actually compare the binary data of the two values + return valuesAreSame(oldRecord, newRecord); + } + } + + @Override + public void configure(final Map configs, + final boolean isKey) { + valueSerializer.configure(configs, isKey); + timestampSerializer.configure(configs, isKey); + } + + @Override + public byte[] serialize(final String topic, + final ValueAndTimestamp data) { + if (data == null) { + return null; + } + return serialize(topic, data.value(), data.timestamp()); + } + + public byte[] serialize(final String topic, + final V data, + final long timestamp) { + if (data == null) { + return null; + } + final byte[] rawValue = valueSerializer.serialize(topic, data); + + // Since we can't control the result of the internal serializer, we make sure that the result + // is not null as well. + // Serializing non-null values to null can be useful when working with Optional-like values + // where the Optional.empty case is serialized to null. 
+ // See the discussion here: https://github.com/apache/kafka/pull/7679 + if (rawValue == null) { + return null; + } + + final byte[] rawTimestamp = timestampSerializer.serialize(topic, timestamp); + return ByteBuffer + .allocate(rawTimestamp.length + rawValue.length) + .put(rawTimestamp) + .put(rawValue) + .array(); + } + + @Override + public void close() { + valueSerializer.close(); + timestampSerializer.close(); + } + + private static boolean timeIsDecreasing(final byte[] oldRecord, final byte[] newRecord) { + return extractTimestamp(newRecord) <= extractTimestamp(oldRecord); + } + + private static long extractTimestamp(final byte[] bytes) { + final byte[] timestampBytes = new byte[Long.BYTES]; + System.arraycopy(bytes, 0, timestampBytes, 0, Long.BYTES); + return ByteBuffer.wrap(timestampBytes).getLong(); + } + + private static boolean valuesAreSame(final byte[] left, final byte[] right) { + for (int i = Long.BYTES; i < left.length; i++) { + if (left[i] != right[i]) { + return false; + } + } + return true; + } + + @Override + public void setIfUnset(final SerdeGetter getter) { + // ValueAndTimestampSerializer never wraps a null serializer (or configure would throw), + // but it may wrap a serializer that itself wraps a null serializer. + initNullableSerializer(valueSerializer, getter); + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/WindowKeySchema.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/WindowKeySchema.java new file mode 100644 index 0000000..5834f94 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/WindowKeySchema.java @@ -0,0 +1,240 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.common.serialization.Deserializer; +import org.apache.kafka.common.serialization.Serializer; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.streams.kstream.Window; +import org.apache.kafka.streams.kstream.Windowed; +import org.apache.kafka.streams.kstream.internals.TimeWindow; +import org.apache.kafka.streams.state.StateSerdes; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.nio.ByteBuffer; +import java.util.List; + +import static org.apache.kafka.streams.state.StateSerdes.TIMESTAMP_SIZE; + +public class WindowKeySchema implements RocksDBSegmentedBytesStore.KeySchema { + + private static final Logger LOG = LoggerFactory.getLogger(WindowKeySchema.class); + + private static final int SEQNUM_SIZE = 4; + private static final int SUFFIX_SIZE = TIMESTAMP_SIZE + SEQNUM_SIZE; + private static final byte[] MIN_SUFFIX = new byte[SUFFIX_SIZE]; + + @Override + public Bytes upperRange(final Bytes key, final long to) { + final byte[] maxSuffix = ByteBuffer.allocate(SUFFIX_SIZE) + .putLong(to) + .putInt(Integer.MAX_VALUE) + .array(); + + return OrderedBytes.upperRange(key, maxSuffix); + } + + @Override + public Bytes lowerRange(final Bytes key, final long from) { + return OrderedBytes.lowerRange(key, MIN_SUFFIX); + } + + @Override + public Bytes lowerRangeFixedSize(final Bytes key, final long from) { + return WindowKeySchema.toStoreKeyBinary(key, Math.max(0, from), 0); + } + + @Override + public Bytes upperRangeFixedSize(final Bytes key, final long to) { + return WindowKeySchema.toStoreKeyBinary(key, to, Integer.MAX_VALUE); + } + + @Override + public long segmentTimestamp(final Bytes key) { + return WindowKeySchema.extractStoreTimestamp(key.get()); + } + + @Override + public HasNextCondition hasNextCondition(final Bytes binaryKeyFrom, + final Bytes binaryKeyTo, + final long from, + final long to) { + return iterator -> { + while (iterator.hasNext()) { + final Bytes bytes = iterator.peekNextKey(); + final Bytes keyBytes = Bytes.wrap(WindowKeySchema.extractStoreKeyBytes(bytes.get())); + final long time = WindowKeySchema.extractStoreTimestamp(bytes.get()); + if ((binaryKeyFrom == null || keyBytes.compareTo(binaryKeyFrom) >= 0) + && (binaryKeyTo == null || keyBytes.compareTo(binaryKeyTo) <= 0) + && time >= from + && time <= to) { + return true; + } + iterator.next(); + } + return false; + }; + } + + @Override + public List segmentsToSearch(final Segments segments, + final long from, + final long to, + final boolean forward) { + return segments.segments(from, to, forward); + } + + /** + * Safely construct a time window of the given size, + * taking care of bounding endMs to Long.MAX_VALUE if necessary + */ + static TimeWindow timeWindowForSize(final long startMs, + final long windowSize) { + long endMs = startMs + windowSize; + + if (endMs < 0) { + LOG.warn("Warning: window end time was truncated to Long.MAX"); + endMs = Long.MAX_VALUE; + } + return new TimeWindow(startMs, endMs); + } + + // for pipe serdes + + public static byte[] toBinary(final Windowed timeKey, + final Serializer serializer, + final String topic) { + final byte[] bytes = serializer.serialize(topic, timeKey.key()); + final ByteBuffer buf = ByteBuffer.allocate(bytes.length + TIMESTAMP_SIZE); + buf.put(bytes); + buf.putLong(timeKey.window().start()); + + return buf.array(); + } + + public static Windowed from(final byte[] binaryKey, + final long windowSize, + final Deserializer deserializer, + final String topic) 
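+        // Binary layout consumed here (as produced by toBinary above): [serialized key bytes][8-byte window start timestamp].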
{ + final byte[] bytes = new byte[binaryKey.length - TIMESTAMP_SIZE]; + System.arraycopy(binaryKey, 0, bytes, 0, bytes.length); + final K key = deserializer.deserialize(topic, bytes); + final Window window = extractWindow(binaryKey, windowSize); + return new Windowed<>(key, window); + } + + private static Window extractWindow(final byte[] binaryKey, + final long windowSize) { + final ByteBuffer buffer = ByteBuffer.wrap(binaryKey); + final long start = buffer.getLong(binaryKey.length - TIMESTAMP_SIZE); + return timeWindowForSize(start, windowSize); + } + + // for store serdes + + public static Bytes toStoreKeyBinary(final Bytes key, + final long timestamp, + final int seqnum) { + final byte[] serializedKey = key.get(); + return toStoreKeyBinary(serializedKey, timestamp, seqnum); + } + + public static Bytes toStoreKeyBinary(final K key, + final long timestamp, + final int seqnum, + final StateSerdes serdes) { + final byte[] serializedKey = serdes.rawKey(key); + return toStoreKeyBinary(serializedKey, timestamp, seqnum); + } + + public static Bytes toStoreKeyBinary(final Windowed timeKey, + final int seqnum) { + final byte[] bytes = timeKey.key().get(); + return toStoreKeyBinary(bytes, timeKey.window().start(), seqnum); + } + + public static Bytes toStoreKeyBinary(final Windowed timeKey, + final int seqnum, + final StateSerdes serdes) { + final byte[] serializedKey = serdes.rawKey(timeKey.key()); + return toStoreKeyBinary(serializedKey, timeKey.window().start(), seqnum); + } + + // package private for testing + static Bytes toStoreKeyBinary(final byte[] serializedKey, + final long timestamp, + final int seqnum) { + final ByteBuffer buf = ByteBuffer.allocate(serializedKey.length + TIMESTAMP_SIZE + SEQNUM_SIZE); + buf.put(serializedKey); + buf.putLong(timestamp); + buf.putInt(seqnum); + + return Bytes.wrap(buf.array()); + } + + static byte[] extractStoreKeyBytes(final byte[] binaryKey) { + final byte[] bytes = new byte[binaryKey.length - TIMESTAMP_SIZE - SEQNUM_SIZE]; + System.arraycopy(binaryKey, 0, bytes, 0, bytes.length); + return bytes; + } + + static K extractStoreKey(final byte[] binaryKey, + final StateSerdes serdes) { + final byte[] bytes = new byte[binaryKey.length - TIMESTAMP_SIZE - SEQNUM_SIZE]; + System.arraycopy(binaryKey, 0, bytes, 0, bytes.length); + return serdes.keyFrom(bytes); + } + + static long extractStoreTimestamp(final byte[] binaryKey) { + return ByteBuffer.wrap(binaryKey).getLong(binaryKey.length - TIMESTAMP_SIZE - SEQNUM_SIZE); + } + + static int extractStoreSequence(final byte[] binaryKey) { + return ByteBuffer.wrap(binaryKey).getInt(binaryKey.length - SEQNUM_SIZE); + } + + public static Windowed fromStoreKey(final byte[] binaryKey, + final long windowSize, + final Deserializer deserializer, + final String topic) { + final K key = deserializer.deserialize(topic, extractStoreKeyBytes(binaryKey)); + final Window window = extractStoreWindow(binaryKey, windowSize); + return new Windowed<>(key, window); + } + + public static Windowed fromStoreKey(final Windowed windowedKey, + final Deserializer deserializer, + final String topic) { + final K key = deserializer.deserialize(topic, windowedKey.key().get()); + return new Windowed<>(key, windowedKey.window()); + } + + public static Windowed fromStoreBytesKey(final byte[] binaryKey, + final long windowSize) { + final Bytes key = Bytes.wrap(extractStoreKeyBytes(binaryKey)); + final Window window = extractStoreWindow(binaryKey, windowSize); + return new Windowed<>(key, window); + } + + static Window extractStoreWindow(final byte[] 
binaryKey, + final long windowSize) { + final ByteBuffer buffer = ByteBuffer.wrap(binaryKey); + final long start = buffer.getLong(binaryKey.length - TIMESTAMP_SIZE - SEQNUM_SIZE); + return timeWindowForSize(start, windowSize); + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/WindowStoreBuilder.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/WindowStoreBuilder.java new file mode 100644 index 0000000..a1b1ead --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/WindowStoreBuilder.java @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.common.config.TopicConfig; +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.streams.state.WindowBytesStoreSupplier; +import org.apache.kafka.streams.state.WindowStore; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Objects; + +public class WindowStoreBuilder extends AbstractStoreBuilder> { + private final Logger log = LoggerFactory.getLogger(WindowStoreBuilder.class); + + private final WindowBytesStoreSupplier storeSupplier; + + public WindowStoreBuilder(final WindowBytesStoreSupplier storeSupplier, + final Serde keySerde, + final Serde valueSerde, + final Time time) { + super(storeSupplier.name(), keySerde, valueSerde, time); + Objects.requireNonNull(storeSupplier, "storeSupplier can't be null"); + Objects.requireNonNull(storeSupplier.metricsScope(), "storeSupplier's metricsScope can't be null"); + this.storeSupplier = storeSupplier; + + if (storeSupplier.retainDuplicates()) { + this.logConfig.put(TopicConfig.CLEANUP_POLICY_CONFIG, TopicConfig.CLEANUP_POLICY_DELETE); + } + } + + @Override + public WindowStore build() { + if (storeSupplier.retainDuplicates() && enableCaching) { + log.warn("Disabling caching for {} since store was configured to retain duplicates", storeSupplier.name()); + enableCaching = false; + } + + return new MeteredWindowStore<>( + maybeWrapCaching(maybeWrapLogging(storeSupplier.get())), + storeSupplier.windowSize(), + storeSupplier.metricsScope(), + time, + keySerde, + valueSerde); + } + + private WindowStore maybeWrapCaching(final WindowStore inner) { + if (!enableCaching) { + return inner; + } + return new CachingWindowStore( + inner, + storeSupplier.windowSize(), + storeSupplier.segmentIntervalMs()); + } + + private WindowStore maybeWrapLogging(final WindowStore inner) { + if (!enableLogging) { + return inner; + } + return new ChangeLoggingWindowBytesStore( + inner, + storeSupplier.retainDuplicates(), + WindowKeySchema::toStoreKeyBinary + ); + } + + public long 
retentionPeriod() { + return storeSupplier.retentionPeriod(); + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/WindowStoreIteratorWrapper.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/WindowStoreIteratorWrapper.java new file mode 100644 index 0000000..14acb13 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/WindowStoreIteratorWrapper.java @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.kstream.Windowed; +import org.apache.kafka.streams.state.KeyValueIterator; +import org.apache.kafka.streams.state.WindowStoreIterator; + +class WindowStoreIteratorWrapper { + + private final KeyValueIterator bytesIterator; + private final long windowSize; + + WindowStoreIteratorWrapper(final KeyValueIterator bytesIterator, + final long windowSize) { + this.bytesIterator = bytesIterator; + this.windowSize = windowSize; + } + + public WindowStoreIterator valuesIterator() { + return new WrappedWindowStoreIterator(bytesIterator); + } + + public KeyValueIterator, byte[]> keyValueIterator() { + return new WrappedKeyValueIterator(bytesIterator, windowSize); + } + + private static class WrappedWindowStoreIterator implements WindowStoreIterator { + final KeyValueIterator bytesIterator; + + WrappedWindowStoreIterator( + final KeyValueIterator bytesIterator) { + this.bytesIterator = bytesIterator; + } + + @Override + public Long peekNextKey() { + return WindowKeySchema.extractStoreTimestamp(bytesIterator.peekNextKey().get()); + } + + @Override + public boolean hasNext() { + return bytesIterator.hasNext(); + } + + @Override + public KeyValue next() { + final KeyValue next = bytesIterator.next(); + final long timestamp = WindowKeySchema.extractStoreTimestamp(next.key.get()); + return KeyValue.pair(timestamp, next.value); + } + + @Override + public void close() { + bytesIterator.close(); + } + } + + private static class WrappedKeyValueIterator implements KeyValueIterator, byte[]> { + final KeyValueIterator bytesIterator; + final long windowSize; + + WrappedKeyValueIterator(final KeyValueIterator bytesIterator, + final long windowSize) { + this.bytesIterator = bytesIterator; + this.windowSize = windowSize; + } + + @Override + public Windowed peekNextKey() { + final byte[] nextKey = bytesIterator.peekNextKey().get(); + return WindowKeySchema.fromStoreBytesKey(nextKey, windowSize); + } + + @Override + public boolean hasNext() { + return bytesIterator.hasNext(); + } + + @Override + public KeyValue, byte[]> next() { + final KeyValue next = bytesIterator.next(); + return 
KeyValue.pair(WindowKeySchema.fromStoreBytesKey(next.key.get(), windowSize), next.value); + } + + @Override + public void close() { + bytesIterator.close(); + } + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/WindowToTimestampedWindowByteStoreAdapter.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/WindowToTimestampedWindowByteStoreAdapter.java new file mode 100644 index 0000000..f7999d3 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/WindowToTimestampedWindowByteStoreAdapter.java @@ -0,0 +1,196 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.streams.kstream.Windowed; +import org.apache.kafka.streams.processor.ProcessorContext; +import org.apache.kafka.streams.processor.StateStore; +import org.apache.kafka.streams.processor.StateStoreContext; +import org.apache.kafka.streams.state.KeyValueIterator; +import org.apache.kafka.streams.state.WindowStore; +import org.apache.kafka.streams.state.WindowStoreIterator; + +import java.time.Instant; + +import static org.apache.kafka.streams.state.TimestampedBytesStore.convertToTimestampedFormat; +import static org.apache.kafka.streams.state.internals.ValueAndTimestampDeserializer.rawValue; + +class WindowToTimestampedWindowByteStoreAdapter implements WindowStore { + final WindowStore store; + + WindowToTimestampedWindowByteStoreAdapter(final WindowStore store) { + if (!store.persistent()) { + throw new IllegalArgumentException("Provided store must be a persistent store, but it is not."); + } + this.store = store; + } + + @Override + public void put(final Bytes key, + final byte[] valueWithTimestamp, + final long windowStartTimestamp) { + store.put(key, valueWithTimestamp == null ? 
null : rawValue(valueWithTimestamp), windowStartTimestamp); + } + + @Override + public byte[] fetch(final Bytes key, + final long time) { + return convertToTimestampedFormat(store.fetch(key, time)); + } + + @Override + public WindowStoreIterator fetch(final Bytes key, + final long timeFrom, + final long timeTo) { + return new WindowToTimestampedWindowIteratorAdapter(store.fetch(key, timeFrom, timeTo)); + } + + @Override + public WindowStoreIterator fetch(final Bytes key, + final Instant timeFrom, + final Instant timeTo) throws IllegalArgumentException { + return new WindowToTimestampedWindowIteratorAdapter(store.fetch(key, timeFrom, timeTo)); + } + + @Override + public WindowStoreIterator backwardFetch(final Bytes key, + final long timeFrom, + final long timeTo) { + return new WindowToTimestampedWindowIteratorAdapter(store.backwardFetch(key, timeFrom, timeTo)); + } + + @Override + public WindowStoreIterator backwardFetch(final Bytes key, + final Instant timeFrom, + final Instant timeTo) throws IllegalArgumentException { + return new WindowToTimestampedWindowIteratorAdapter(store.backwardFetch(key, timeFrom, timeTo)); + } + + @Override + public KeyValueIterator, byte[]> fetch(final Bytes keyFrom, + final Bytes keyTo, + final long timeFrom, + final long timeTo) { + return new KeyValueToTimestampedKeyValueIteratorAdapter<>(store.fetch(keyFrom, keyTo, timeFrom, timeTo)); + } + + @Override + public KeyValueIterator, byte[]> backwardFetch(final Bytes keyFrom, + final Bytes keyTo, + final Instant timeFrom, + final Instant timeTo) throws IllegalArgumentException { + return new KeyValueToTimestampedKeyValueIteratorAdapter<>(store.backwardFetch(keyFrom, keyTo, timeFrom, timeTo)); + } + + @Override + public KeyValueIterator, byte[]> fetch(final Bytes keyFrom, + final Bytes keyTo, + final Instant timeFrom, + final Instant timeTo) throws IllegalArgumentException { + return new KeyValueToTimestampedKeyValueIteratorAdapter<>(store.fetch(keyFrom, keyTo, timeFrom, timeTo)); + } + + @Override + public KeyValueIterator, byte[]> backwardFetch(final Bytes keyFrom, + final Bytes keyTo, + final long timeFrom, + final long timeTo) { + return new KeyValueToTimestampedKeyValueIteratorAdapter<>(store.backwardFetch(keyFrom, keyTo, timeFrom, timeTo)); + } + + @Override + public KeyValueIterator, byte[]> all() { + return new KeyValueToTimestampedKeyValueIteratorAdapter<>(store.all()); + } + + @Override + public KeyValueIterator, byte[]> backwardAll() { + return new KeyValueToTimestampedKeyValueIteratorAdapter<>(store.backwardAll()); + } + + @Override + public KeyValueIterator, byte[]> fetchAll(final long timeFrom, + final long timeTo) { + return new KeyValueToTimestampedKeyValueIteratorAdapter<>(store.fetchAll(timeFrom, timeTo)); + } + + @Override + public KeyValueIterator, byte[]> fetchAll(final Instant timeFrom, + final Instant timeTo) throws IllegalArgumentException { + return new KeyValueToTimestampedKeyValueIteratorAdapter<>(store.fetchAll(timeFrom, timeTo)); + } + + @Override + public KeyValueIterator, byte[]> backwardFetchAll(final long timeFrom, final long timeTo) { + return new KeyValueToTimestampedKeyValueIteratorAdapter<>(store.backwardFetchAll(timeFrom, timeTo)); + } + + @Override + public KeyValueIterator, byte[]> backwardFetchAll(final Instant timeFrom, + final Instant timeTo) throws IllegalArgumentException { + return new KeyValueToTimestampedKeyValueIteratorAdapter<>(store.backwardFetchAll(timeFrom, timeTo)); + } + + @Override + public String name() { + return store.name(); + } + + @Deprecated + 
@Override + public void init(final ProcessorContext context, + final StateStore root) { + store.init(context, root); + } + + @Override + public void init(final StateStoreContext context, final StateStore root) { + store.init(context, root); + } + + @Override + public void flush() { + store.flush(); + } + + @Override + public void close() { + store.close(); + } + + @Override + public boolean persistent() { + return true; + } + + @Override + public boolean isOpen() { + return store.isOpen(); + } + + + private static class WindowToTimestampedWindowIteratorAdapter + extends KeyValueToTimestampedKeyValueIteratorAdapter + implements WindowStoreIterator { + + WindowToTimestampedWindowIteratorAdapter(final KeyValueIterator innerIterator) { + super(innerIterator); + } + } + +} \ No newline at end of file diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/WrappedSessionStoreIterator.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/WrappedSessionStoreIterator.java new file mode 100644 index 0000000..ce26029 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/WrappedSessionStoreIterator.java @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.kstream.Windowed; +import org.apache.kafka.streams.state.KeyValueIterator; + +class WrappedSessionStoreIterator implements KeyValueIterator, byte[]> { + + private final KeyValueIterator bytesIterator; + + WrappedSessionStoreIterator(final KeyValueIterator bytesIterator) { + this.bytesIterator = bytesIterator; + } + + @Override + public void close() { + bytesIterator.close(); + } + + @Override + public Windowed peekNextKey() { + return SessionKeySchema.from(bytesIterator.peekNextKey()); + } + + @Override + public boolean hasNext() { + return bytesIterator.hasNext(); + } + + @Override + public KeyValue, byte[]> next() { + final KeyValue next = bytesIterator.next(); + return KeyValue.pair(SessionKeySchema.from(next.key), next.value); + } +} \ No newline at end of file diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/WrappedStateStore.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/WrappedStateStore.java new file mode 100644 index 0000000..e8244f7 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/WrappedStateStore.java @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.streams.errors.InvalidStateStoreException; +import org.apache.kafka.streams.processor.ProcessorContext; +import org.apache.kafka.streams.processor.StateStore; +import org.apache.kafka.streams.processor.StateStoreContext; +import org.apache.kafka.streams.state.TimestampedBytesStore; + +/** + * A storage engine wrapper for utilities like logging, caching, and metering. + */ +public abstract class WrappedStateStore implements StateStore, CachedStateStore { + + public static boolean isTimestamped(final StateStore stateStore) { + if (stateStore instanceof TimestampedBytesStore) { + return true; + } else if (stateStore instanceof WrappedStateStore) { + return isTimestamped(((WrappedStateStore) stateStore).wrapped()); + } else { + return false; + } + } + + private final S wrapped; + + public WrappedStateStore(final S wrapped) { + this.wrapped = wrapped; + } + + @Deprecated + @Override + public void init(final ProcessorContext context, + final StateStore root) { + wrapped.init(context, root); + } + + @Override + public void init(final StateStoreContext context, final StateStore root) { + wrapped.init(context, root); + } + + @SuppressWarnings("unchecked") + @Override + public boolean setFlushListener(final CacheFlushListener listener, + final boolean sendOldValues) { + if (wrapped instanceof CachedStateStore) { + return ((CachedStateStore) wrapped).setFlushListener(listener, sendOldValues); + } + return false; + } + + @Override + public void flushCache() { + if (wrapped instanceof CachedStateStore) { + ((CachedStateStore) wrapped).flushCache(); + } + } + + @Override + public String name() { + return wrapped.name(); + } + + @Override + public boolean persistent() { + return wrapped.persistent(); + } + + @Override + public boolean isOpen() { + return wrapped.isOpen(); + } + + void validateStoreOpen() { + if (!wrapped.isOpen()) { + throw new InvalidStateStoreException("Store " + wrapped.name() + " is currently closed."); + } + } + + @Override + public void flush() { + wrapped.flush(); + } + + @Override + public void close() { + wrapped.close(); + } + + public S wrapped() { + return wrapped; + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/WrappingStoreProvider.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/WrappingStoreProvider.java new file mode 100644 index 0000000..6b4ae92 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/WrappingStoreProvider.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.streams.StoreQueryParameters; +import org.apache.kafka.streams.errors.InvalidStateStoreException; +import org.apache.kafka.streams.errors.InvalidStateStorePartitionException; +import org.apache.kafka.streams.state.QueryableStoreType; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; + +/** + * Provides a wrapper over multiple underlying {@link StateStoreProvider}s + */ +public class WrappingStoreProvider implements StateStoreProvider { + + private final Collection storeProviders; + private StoreQueryParameters storeQueryParameters; + + WrappingStoreProvider(final Collection storeProviders, + final StoreQueryParameters storeQueryParameters) { + this.storeProviders = storeProviders; + this.storeQueryParameters = storeQueryParameters; + } + + //visible for testing + public void setStoreQueryParameters(final StoreQueryParameters storeQueryParameters) { + this.storeQueryParameters = storeQueryParameters; + } + + @Override + public List stores(final String storeName, + final QueryableStoreType queryableStoreType) { + final List allStores = new ArrayList<>(); + for (final StreamThreadStateStoreProvider storeProvider : storeProviders) { + final List stores = storeProvider.stores(storeQueryParameters); + if (!stores.isEmpty()) { + allStores.addAll(stores); + if (storeQueryParameters.partition() != null) { + break; + } + } + } + if (allStores.isEmpty()) { + if (storeQueryParameters.partition() != null) { + throw new InvalidStateStorePartitionException( + String.format("The specified partition %d for store %s does not exist.", + storeQueryParameters.partition(), + storeName)); + } + throw new InvalidStateStoreException("The state store, " + storeName + ", may have migrated to another instance."); + } + return allStores; + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/metrics/NamedCacheMetrics.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/metrics/NamedCacheMetrics.java new file mode 100644 index 0000000..184c50f --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/metrics/NamedCacheMetrics.java @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.state.internals.metrics; + +import org.apache.kafka.common.metrics.Sensor; +import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl; + +import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.CACHE_LEVEL_GROUP; +import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.addAvgAndMinAndMaxToSensor; + +public class NamedCacheMetrics { + private NamedCacheMetrics() {} + + private static final String HIT_RATIO = "hit-ratio"; + private static final String HIT_RATIO_AVG_DESCRIPTION = "The average cache hit ratio"; + private static final String HIT_RATIO_MIN_DESCRIPTION = "The minimum cache hit ratio"; + private static final String HIT_RATIO_MAX_DESCRIPTION = "The maximum cache hit ratio"; + + + public static Sensor hitRatioSensor(final StreamsMetricsImpl streamsMetrics, + final String threadId, + final String taskName, + final String storeName) { + + final Sensor hitRatioSensor; + final String hitRatioName; + hitRatioName = HIT_RATIO; + hitRatioSensor = streamsMetrics.cacheLevelSensor( + threadId, + taskName, + storeName, + hitRatioName, + Sensor.RecordingLevel.DEBUG + ); + addAvgAndMinAndMaxToSensor( + hitRatioSensor, + CACHE_LEVEL_GROUP, + streamsMetrics.cacheLevelTagMap(threadId, taskName, storeName), + hitRatioName, + HIT_RATIO_AVG_DESCRIPTION, + HIT_RATIO_MIN_DESCRIPTION, + HIT_RATIO_MAX_DESCRIPTION + ); + return hitRatioSensor; + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/metrics/RocksDBMetrics.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/metrics/RocksDBMetrics.java new file mode 100644 index 0000000..a30a891 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/metrics/RocksDBMetrics.java @@ -0,0 +1,803 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.state.internals.metrics; + +import org.apache.kafka.common.metrics.Gauge; +import org.apache.kafka.common.metrics.Sensor; +import org.apache.kafka.common.metrics.Sensor.RecordingLevel; +import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl; + +import java.math.BigInteger; +import java.util.Objects; + +import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.AVG_SUFFIX; +import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.MAX_SUFFIX; +import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.MIN_SUFFIX; +import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.RATIO_SUFFIX; +import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.STATE_STORE_LEVEL_GROUP; +import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.addRateOfSumMetricToSensor; +import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.addRateOfSumAndSumMetricsToSensor; +import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.addAvgAndSumMetricsToSensor; +import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.addSumMetricToSensor; +import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.addValueMetricToSensor; + +public class RocksDBMetrics { + private RocksDBMetrics() {} + + private static final String BYTES_WRITTEN_TO_DB = "bytes-written"; + private static final String BYTES_READ_FROM_DB = "bytes-read"; + private static final String MEMTABLE_BYTES_FLUSHED = "memtable-bytes-flushed"; + private static final String MEMTABLE_HIT_RATIO = "memtable-hit" + RATIO_SUFFIX; + private static final String MEMTABLE_FLUSH_TIME = "memtable-flush-time"; + private static final String MEMTABLE_FLUSH_TIME_AVG = MEMTABLE_FLUSH_TIME + AVG_SUFFIX; + private static final String MEMTABLE_FLUSH_TIME_MIN = MEMTABLE_FLUSH_TIME + MIN_SUFFIX; + private static final String MEMTABLE_FLUSH_TIME_MAX = MEMTABLE_FLUSH_TIME + MAX_SUFFIX; + private static final String WRITE_STALL_DURATION = "write-stall-duration"; + private static final String BLOCK_CACHE_DATA_HIT_RATIO = "block-cache-data-hit" + RATIO_SUFFIX; + private static final String BLOCK_CACHE_INDEX_HIT_RATIO = "block-cache-index-hit" + RATIO_SUFFIX; + private static final String BLOCK_CACHE_FILTER_HIT_RATIO = "block-cache-filter-hit" + RATIO_SUFFIX; + private static final String BYTES_READ_DURING_COMPACTION = "bytes-read-compaction"; + private static final String BYTES_WRITTEN_DURING_COMPACTION = "bytes-written-compaction"; + private static final String COMPACTION_TIME = "compaction-time"; + private static final String COMPACTION_TIME_AVG = COMPACTION_TIME + AVG_SUFFIX; + private static final String COMPACTION_TIME_MIN = COMPACTION_TIME + MIN_SUFFIX; + private static final String COMPACTION_TIME_MAX = COMPACTION_TIME + MAX_SUFFIX; + private static final String NUMBER_OF_OPEN_FILES = "number-open-files"; + private static final String NUMBER_OF_FILE_ERRORS = "number-file-errors"; + static final String NUMBER_OF_ENTRIES_ACTIVE_MEMTABLE = "num-entries-active-mem-table"; + static final String NUMBER_OF_DELETES_ACTIVE_MEMTABLE = "num-deletes-active-mem-table"; + static final String NUMBER_OF_ENTRIES_IMMUTABLE_MEMTABLES = "num-entries-imm-mem-tables"; + static final String NUMBER_OF_DELETES_IMMUTABLE_MEMTABLES = "num-deletes-imm-mem-tables"; + static final 
String NUMBER_OF_IMMUTABLE_MEMTABLES = "num-immutable-mem-table"; + static final String CURRENT_SIZE_OF_ACTIVE_MEMTABLE = "cur-size-active-mem-table"; + static final String CURRENT_SIZE_OF_ALL_MEMTABLES = "cur-size-all-mem-tables"; + static final String SIZE_OF_ALL_MEMTABLES = "size-all-mem-tables"; + static final String MEMTABLE_FLUSH_PENDING = "mem-table-flush-pending"; + static final String NUMBER_OF_RUNNING_FLUSHES = "num-running-flushes"; + static final String COMPACTION_PENDING = "compaction-pending"; + static final String NUMBER_OF_RUNNING_COMPACTIONS = "num-running-compactions"; + static final String ESTIMATED_BYTES_OF_PENDING_COMPACTION = "estimate-pending-compaction-bytes"; + static final String TOTAL_SST_FILES_SIZE = "total-sst-files-size"; + static final String LIVE_SST_FILES_SIZE = "live-sst-files-size"; + static final String NUMBER_OF_LIVE_VERSIONS = "num-live-versions"; + static final String CAPACITY_OF_BLOCK_CACHE = "block-cache-capacity"; + static final String USAGE_OF_BLOCK_CACHE = "block-cache-usage"; + static final String PINNED_USAGE_OF_BLOCK_CACHE = "block-cache-pinned-usage"; + static final String ESTIMATED_NUMBER_OF_KEYS = "estimate-num-keys"; + static final String ESTIMATED_MEMORY_OF_TABLE_READERS = "estimate-table-readers-mem"; + static final String NUMBER_OF_BACKGROUND_ERRORS = "background-errors"; + + private static final String BYTES_WRITTEN_TO_DB_RATE_DESCRIPTION = + "Average number of bytes written per second to the RocksDB state store"; + private static final String BYTES_WRITTEN_TO_DB_TOTAL_DESCRIPTION = + "Total number of bytes written to the RocksDB state store"; + private static final String BYTES_READ_FROM_DB_RATE_DESCRIPTION = + "Average number of bytes read per second from the RocksDB state store"; + private static final String BYTES_READ_FROM_DB_TOTAL_DESCRIPTION = + "Total number of bytes read from the RocksDB state store"; + private static final String MEMTABLE_BYTES_FLUSHED_RATE_DESCRIPTION = + "Average number of bytes flushed per second from the memtable to disk"; + private static final String MEMTABLE_BYTES_FLUSHED_TOTAL_DESCRIPTION = + "Total number of bytes flushed from the memtable to disk"; + private static final String MEMTABLE_HIT_RATIO_DESCRIPTION = + "Ratio of memtable hits relative to all lookups to the memtable"; + private static final String MEMTABLE_FLUSH_TIME_AVG_DESCRIPTION = + "Average time spent on flushing the memtable to disk in ms"; + private static final String MEMTABLE_FLUSH_TIME_MIN_DESCRIPTION = + "Minimum time spent on flushing the memtable to disk in ms"; + private static final String MEMTABLE_FLUSH_TIME_MAX_DESCRIPTION = + "Maximum time spent on flushing the memtable to disk in ms"; + private static final String WRITE_STALL_DURATION_AVG_DESCRIPTION = "Average duration of write stalls in ms"; + private static final String WRITE_STALL_DURATION_TOTAL_DESCRIPTION = "Total duration of write stalls in ms"; + private static final String BLOCK_CACHE_DATA_HIT_RATIO_DESCRIPTION = + "Ratio of block cache hits for data relative to all lookups for data to the block cache"; + private static final String BLOCK_CACHE_INDEX_HIT_RATIO_DESCRIPTION = + "Ratio of block cache hits for indexes relative to all lookups for indexes to the block cache"; + private static final String BLOCK_CACHE_FILTER_HIT_RATIO_DESCRIPTION = + "Ratio of block cache hits for filters relative to all lookups for filters to the block cache"; + private static final String BYTES_READ_DURING_COMPACTION_DESCRIPTION = + "Average number of bytes read per second during 
compaction"; + private static final String BYTES_WRITTEN_DURING_COMPACTION_DESCRIPTION = + "Average number of bytes written per second during compaction"; + private static final String COMPACTION_TIME_AVG_DESCRIPTION = "Average time spent on compaction in ms"; + private static final String COMPACTION_TIME_MIN_DESCRIPTION = "Minimum time spent on compaction in ms"; + private static final String COMPACTION_TIME_MAX_DESCRIPTION = "Maximum time spent on compaction in ms"; + private static final String NUMBER_OF_OPEN_FILES_DESCRIPTION = "Number of currently open files"; + private static final String NUMBER_OF_FILE_ERRORS_DESCRIPTION = "Total number of file errors occurred"; + private static final String NUMBER_OF_ENTRIES_ACTIVE_MEMTABLE_DESCRIPTION = + "Total number of entries in the active memtable"; + private static final String NUMBER_OF_DELETES_ACTIVE_MEMTABLES_DESCRIPTION = + "Total number of delete entries in the active memtable"; + private static final String NUMBER_OF_ENTRIES_IMMUTABLE_MEMTABLES_DESCRIPTION = + "Total number of entries in the unflushed immutable memtables"; + private static final String NUMBER_OF_DELETES_IMMUTABLE_MEMTABLES_DESCRIPTION = + "Total number of delete entries in the unflushed immutable memtables"; + private static final String NUMBER_OF_IMMUTABLE_MEMTABLES_DESCRIPTION = + "Number of immutable memtables that have not yet been flushed"; + private static final String CURRENT_SIZE_OF_ACTIVE_MEMTABLE_DESCRIPTION = + "Approximate size of active memtable in bytes"; + private static final String CURRENT_SIZE_OF_ALL_MEMTABLES_DESCRIPTION = + "Approximate size of active and unflushed immutable memtables in bytes"; + private static final String SIZE_OF_ALL_MEMTABLES_DESCRIPTION = + "Approximate size of active, unflushed immutable, and pinned immutable memtables in bytes"; + private static final String MEMTABLE_FLUSH_PENDING_DESCRIPTION = + "Reports 1 if a memtable flush is pending, otherwise it reports 0"; + private static final String NUMBER_OF_RUNNING_FLUSHES_DESCRIPTION = "Number of currently running flushes"; + private static final String COMPACTION_PENDING_DESCRIPTION = + "Reports 1 if at least one compaction is pending, otherwise it reports 0"; + private static final String NUMBER_OF_RUNNING_COMPACTIONS_DESCRIPTION = "Number of currently running compactions"; + private static final String ESTIMATED_BYTES_OF_PENDING_COMPACTION_DESCRIPTION = + "Estimated total number of bytes a compaction needs to rewrite on disk to get all levels down to under target size"; + private static final String TOTAL_SST_FILE_SIZE_DESCRIPTION = "Total size in bytes of all SST files"; + private static final String LIVE_SST_FILES_SIZE_DESCRIPTION = + "Total size in bytes of all SST files that belong to the latest LSM tree"; + private static final String NUMBER_OF_LIVE_VERSIONS_DESCRIPTION = "Number of live versions of the LSM tree"; + private static final String CAPACITY_OF_BLOCK_CACHE_DESCRIPTION = "Capacity of the block cache in bytes"; + private static final String USAGE_OF_BLOCK_CACHE_DESCRIPTION = + "Memory size of the entries residing in block cache in bytes"; + private static final String PINNED_USAGE_OF_BLOCK_CACHE_DESCRIPTION = + "Memory size for the entries being pinned in the block cache in bytes"; + private static final String ESTIMATED_NUMBER_OF_KEYS_DESCRIPTION = + "Estimated number of keys in the active and unflushed immutable memtables and storage"; + private static final String ESTIMATED_MEMORY_OF_TABLE_READERS_DESCRIPTION = + "Estimated memory in bytes used for reading SST 
tables, excluding memory used in block cache"; + private static final String TOTAL_NUMBER_OF_BACKGROUND_ERRORS_DESCRIPTION = "Total number of background errors"; + + public static class RocksDBMetricContext { + private final String taskName; + private final String metricsScope; + private final String storeName; + + public RocksDBMetricContext(final String taskName, + final String metricsScope, + final String storeName) { + this.taskName = taskName; + this.metricsScope = metricsScope; + this.storeName = storeName; + } + + public String taskName() { + return taskName; + } + public String metricsScope() { + return metricsScope; + } + public String storeName() { + return storeName; + } + + @Override + public boolean equals(final Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + final RocksDBMetricContext that = (RocksDBMetricContext) o; + return Objects.equals(taskName, that.taskName) && + Objects.equals(metricsScope, that.metricsScope) && + Objects.equals(storeName, that.storeName); + } + + @Override + public int hashCode() { + return Objects.hash(taskName, metricsScope, storeName); + } + } + + public static Sensor bytesWrittenToDatabaseSensor(final StreamsMetricsImpl streamsMetrics, + final RocksDBMetricContext metricContext) { + final Sensor sensor = createSensor(streamsMetrics, metricContext, BYTES_WRITTEN_TO_DB); + addRateOfSumAndSumMetricsToSensor( + sensor, + STATE_STORE_LEVEL_GROUP, + streamsMetrics.storeLevelTagMap( + metricContext.taskName(), + metricContext.metricsScope(), + metricContext.storeName() + ), + BYTES_WRITTEN_TO_DB, + BYTES_WRITTEN_TO_DB_RATE_DESCRIPTION, + BYTES_WRITTEN_TO_DB_TOTAL_DESCRIPTION + ); + return sensor; + } + + public static Sensor bytesReadFromDatabaseSensor(final StreamsMetricsImpl streamsMetrics, + final RocksDBMetricContext metricContext) { + final Sensor sensor = createSensor(streamsMetrics, metricContext, BYTES_READ_FROM_DB); + addRateOfSumAndSumMetricsToSensor( + sensor, + STATE_STORE_LEVEL_GROUP, + streamsMetrics.storeLevelTagMap( + metricContext.taskName(), + metricContext.metricsScope(), + metricContext.storeName() + ), + BYTES_READ_FROM_DB, + BYTES_READ_FROM_DB_RATE_DESCRIPTION, + BYTES_READ_FROM_DB_TOTAL_DESCRIPTION + ); + return sensor; + } + + public static Sensor memtableBytesFlushedSensor(final StreamsMetricsImpl streamsMetrics, + final RocksDBMetricContext metricContext) { + final Sensor sensor = createSensor(streamsMetrics, metricContext, MEMTABLE_BYTES_FLUSHED); + addRateOfSumAndSumMetricsToSensor( + sensor, + STATE_STORE_LEVEL_GROUP, + streamsMetrics.storeLevelTagMap( + metricContext.taskName(), + metricContext.metricsScope(), + metricContext.storeName() + ), + MEMTABLE_BYTES_FLUSHED, + MEMTABLE_BYTES_FLUSHED_RATE_DESCRIPTION, + MEMTABLE_BYTES_FLUSHED_TOTAL_DESCRIPTION + ); + return sensor; + } + + public static Sensor memtableHitRatioSensor(final StreamsMetricsImpl streamsMetrics, + final RocksDBMetricContext metricContext) { + final Sensor sensor = createSensor(streamsMetrics, metricContext, MEMTABLE_HIT_RATIO); + addValueMetricToSensor( + sensor, + STATE_STORE_LEVEL_GROUP, + streamsMetrics.storeLevelTagMap( + metricContext.taskName(), + metricContext.metricsScope(), + metricContext.storeName() + ), + MEMTABLE_HIT_RATIO, + MEMTABLE_HIT_RATIO_DESCRIPTION + ); + return sensor; + } + + public static Sensor memtableAvgFlushTimeSensor(final StreamsMetricsImpl streamsMetrics, + final RocksDBMetricContext metricContext) { + final Sensor sensor = createSensor(streamsMetrics, 
metricContext, MEMTABLE_FLUSH_TIME_AVG); + addValueMetricToSensor( + sensor, + STATE_STORE_LEVEL_GROUP, + streamsMetrics.storeLevelTagMap( + metricContext.taskName(), + metricContext.metricsScope(), + metricContext.storeName() + ), + MEMTABLE_FLUSH_TIME_AVG, + MEMTABLE_FLUSH_TIME_AVG_DESCRIPTION + ); + return sensor; + } + + public static Sensor memtableMinFlushTimeSensor(final StreamsMetricsImpl streamsMetrics, + final RocksDBMetricContext metricContext) { + final Sensor sensor = createSensor(streamsMetrics, metricContext, MEMTABLE_FLUSH_TIME_MIN); + addValueMetricToSensor( + sensor, + STATE_STORE_LEVEL_GROUP, + streamsMetrics.storeLevelTagMap( + metricContext.taskName(), + metricContext.metricsScope(), + metricContext.storeName() + ), + MEMTABLE_FLUSH_TIME_MIN, + MEMTABLE_FLUSH_TIME_MIN_DESCRIPTION + ); + return sensor; + } + + public static Sensor memtableMaxFlushTimeSensor(final StreamsMetricsImpl streamsMetrics, + final RocksDBMetricContext metricContext) { + final Sensor sensor = createSensor(streamsMetrics, metricContext, MEMTABLE_FLUSH_TIME_MAX); + addValueMetricToSensor( + sensor, + STATE_STORE_LEVEL_GROUP, + streamsMetrics.storeLevelTagMap( + metricContext.taskName(), + metricContext.metricsScope(), + metricContext.storeName() + ), + MEMTABLE_FLUSH_TIME_MAX, + MEMTABLE_FLUSH_TIME_MAX_DESCRIPTION + ); + return sensor; + } + + public static Sensor writeStallDurationSensor(final StreamsMetricsImpl streamsMetrics, + final RocksDBMetricContext metricContext) { + final Sensor sensor = createSensor(streamsMetrics, metricContext, WRITE_STALL_DURATION); + addAvgAndSumMetricsToSensor( + sensor, + STATE_STORE_LEVEL_GROUP, + streamsMetrics.storeLevelTagMap( + metricContext.taskName(), + metricContext.metricsScope(), + metricContext.storeName() + ), + WRITE_STALL_DURATION, + WRITE_STALL_DURATION_AVG_DESCRIPTION, + WRITE_STALL_DURATION_TOTAL_DESCRIPTION + ); + return sensor; + } + + public static Sensor blockCacheDataHitRatioSensor(final StreamsMetricsImpl streamsMetrics, + final RocksDBMetricContext metricContext) { + final Sensor sensor = createSensor(streamsMetrics, metricContext, BLOCK_CACHE_DATA_HIT_RATIO); + addValueMetricToSensor( + sensor, + STATE_STORE_LEVEL_GROUP, + streamsMetrics.storeLevelTagMap( + metricContext.taskName(), + metricContext.metricsScope(), + metricContext.storeName() + ), + BLOCK_CACHE_DATA_HIT_RATIO, + BLOCK_CACHE_DATA_HIT_RATIO_DESCRIPTION + ); + return sensor; + } + + public static Sensor blockCacheIndexHitRatioSensor(final StreamsMetricsImpl streamsMetrics, + final RocksDBMetricContext metricContext) { + final Sensor sensor = createSensor(streamsMetrics, metricContext, BLOCK_CACHE_INDEX_HIT_RATIO); + addValueMetricToSensor( + sensor, + STATE_STORE_LEVEL_GROUP, + streamsMetrics.storeLevelTagMap( + metricContext.taskName(), + metricContext.metricsScope(), + metricContext.storeName() + ), + BLOCK_CACHE_INDEX_HIT_RATIO, + BLOCK_CACHE_INDEX_HIT_RATIO_DESCRIPTION + ); + return sensor; + } + + public static Sensor blockCacheFilterHitRatioSensor(final StreamsMetricsImpl streamsMetrics, + final RocksDBMetricContext metricContext) { + final Sensor sensor = createSensor(streamsMetrics, metricContext, BLOCK_CACHE_FILTER_HIT_RATIO); + addValueMetricToSensor( + sensor, + STATE_STORE_LEVEL_GROUP, + streamsMetrics.storeLevelTagMap( + metricContext.taskName(), + metricContext.metricsScope(), + metricContext.storeName() + ), + BLOCK_CACHE_FILTER_HIT_RATIO, + BLOCK_CACHE_FILTER_HIT_RATIO_DESCRIPTION + ); + return sensor; + } + + public static Sensor 
bytesReadDuringCompactionSensor(final StreamsMetricsImpl streamsMetrics, + final RocksDBMetricContext metricContext) { + final Sensor sensor = createSensor(streamsMetrics, metricContext, BYTES_READ_DURING_COMPACTION); + addRateOfSumMetricToSensor( + sensor, + STATE_STORE_LEVEL_GROUP, + streamsMetrics.storeLevelTagMap( + metricContext.taskName(), + metricContext.metricsScope(), + metricContext.storeName() + ), + BYTES_READ_DURING_COMPACTION, + BYTES_READ_DURING_COMPACTION_DESCRIPTION + ); + return sensor; + } + + public static Sensor bytesWrittenDuringCompactionSensor(final StreamsMetricsImpl streamsMetrics, + final RocksDBMetricContext metricContext) { + final Sensor sensor = createSensor(streamsMetrics, metricContext, BYTES_WRITTEN_DURING_COMPACTION); + addRateOfSumMetricToSensor( + sensor, + STATE_STORE_LEVEL_GROUP, + streamsMetrics.storeLevelTagMap( + metricContext.taskName(), + metricContext.metricsScope(), + metricContext.storeName() + ), + BYTES_WRITTEN_DURING_COMPACTION, + BYTES_WRITTEN_DURING_COMPACTION_DESCRIPTION + ); + return sensor; + } + + public static Sensor compactionTimeAvgSensor(final StreamsMetricsImpl streamsMetrics, + final RocksDBMetricContext metricContext) { + final Sensor sensor = createSensor(streamsMetrics, metricContext, COMPACTION_TIME_AVG); + addValueMetricToSensor( + sensor, + STATE_STORE_LEVEL_GROUP, + streamsMetrics.storeLevelTagMap( + metricContext.taskName(), + metricContext.metricsScope(), + metricContext.storeName() + ), + COMPACTION_TIME_AVG, + COMPACTION_TIME_AVG_DESCRIPTION + ); + return sensor; + } + + public static Sensor compactionTimeMinSensor(final StreamsMetricsImpl streamsMetrics, + final RocksDBMetricContext metricContext) { + final Sensor sensor = createSensor(streamsMetrics, metricContext, COMPACTION_TIME_MIN); + addValueMetricToSensor( + sensor, + STATE_STORE_LEVEL_GROUP, + streamsMetrics.storeLevelTagMap( + metricContext.taskName(), + metricContext.metricsScope(), + metricContext.storeName() + ), + COMPACTION_TIME_MIN, + COMPACTION_TIME_MIN_DESCRIPTION + ); + return sensor; + } + + public static Sensor compactionTimeMaxSensor(final StreamsMetricsImpl streamsMetrics, + final RocksDBMetricContext metricContext) { + final Sensor sensor = createSensor(streamsMetrics, metricContext, COMPACTION_TIME_MAX); + addValueMetricToSensor( + sensor, + STATE_STORE_LEVEL_GROUP, + streamsMetrics.storeLevelTagMap( + metricContext.taskName(), + metricContext.metricsScope(), + metricContext.storeName() + ), + COMPACTION_TIME_MAX, + COMPACTION_TIME_MAX_DESCRIPTION + ); + return sensor; + } + + public static Sensor numberOfOpenFilesSensor(final StreamsMetricsImpl streamsMetrics, + final RocksDBMetricContext metricContext) { + final Sensor sensor = createSensor(streamsMetrics, metricContext, NUMBER_OF_OPEN_FILES); + addSumMetricToSensor( + sensor, + STATE_STORE_LEVEL_GROUP, + streamsMetrics.storeLevelTagMap( + metricContext.taskName(), + metricContext.metricsScope(), + metricContext.storeName() + ), + NUMBER_OF_OPEN_FILES, + false, + NUMBER_OF_OPEN_FILES_DESCRIPTION + ); + return sensor; + } + + public static Sensor numberOfFileErrorsSensor(final StreamsMetricsImpl streamsMetrics, + final RocksDBMetricContext metricContext) { + final Sensor sensor = createSensor(streamsMetrics, metricContext, NUMBER_OF_FILE_ERRORS); + addSumMetricToSensor( + sensor, + STATE_STORE_LEVEL_GROUP, + streamsMetrics.storeLevelTagMap( + metricContext.taskName(), + metricContext.metricsScope(), + metricContext.storeName() + ), + NUMBER_OF_FILE_ERRORS, + NUMBER_OF_FILE_ERRORS_DESCRIPTION 
+ ); + return sensor; + } + + public static void addNumEntriesActiveMemTableMetric(final StreamsMetricsImpl streamsMetrics, + final RocksDBMetricContext metricContext, + final Gauge valueProvider) { + addMutableMetric( + streamsMetrics, + metricContext, + valueProvider, + NUMBER_OF_ENTRIES_ACTIVE_MEMTABLE, + NUMBER_OF_ENTRIES_ACTIVE_MEMTABLE_DESCRIPTION + ); + } + + public static void addNumEntriesImmMemTablesMetric(final StreamsMetricsImpl streamsMetrics, + final RocksDBMetricContext metricContext, + final Gauge valueProvider) { + addMutableMetric( + streamsMetrics, + metricContext, + valueProvider, + NUMBER_OF_ENTRIES_IMMUTABLE_MEMTABLES, + NUMBER_OF_ENTRIES_IMMUTABLE_MEMTABLES_DESCRIPTION + ); + } + + public static void addNumDeletesImmMemTablesMetric(final StreamsMetricsImpl streamsMetrics, + final RocksDBMetricContext metricContext, + final Gauge valueProvider) { + addMutableMetric( + streamsMetrics, + metricContext, + valueProvider, + NUMBER_OF_DELETES_IMMUTABLE_MEMTABLES, + NUMBER_OF_DELETES_IMMUTABLE_MEMTABLES_DESCRIPTION + ); + } + + public static void addNumDeletesActiveMemTableMetric(final StreamsMetricsImpl streamsMetrics, + final RocksDBMetricContext metricContext, + final Gauge valueProvider) { + addMutableMetric( + streamsMetrics, + metricContext, + valueProvider, + NUMBER_OF_DELETES_ACTIVE_MEMTABLE, + NUMBER_OF_DELETES_ACTIVE_MEMTABLES_DESCRIPTION + ); + } + + public static void addNumImmutableMemTableMetric(final StreamsMetricsImpl streamsMetrics, + final RocksDBMetricContext metricContext, + final Gauge valueProvider) { + addMutableMetric( + streamsMetrics, + metricContext, + valueProvider, + NUMBER_OF_IMMUTABLE_MEMTABLES, + NUMBER_OF_IMMUTABLE_MEMTABLES_DESCRIPTION + ); + } + + public static void addCurSizeActiveMemTable(final StreamsMetricsImpl streamsMetrics, + final RocksDBMetricContext metricContext, + final Gauge valueProvider) { + addMutableMetric( + streamsMetrics, + metricContext, + valueProvider, + CURRENT_SIZE_OF_ACTIVE_MEMTABLE, + CURRENT_SIZE_OF_ACTIVE_MEMTABLE_DESCRIPTION + ); + } + + public static void addCurSizeAllMemTables(final StreamsMetricsImpl streamsMetrics, + final RocksDBMetricContext metricContext, + final Gauge valueProvider) { + addMutableMetric( + streamsMetrics, + metricContext, + valueProvider, + CURRENT_SIZE_OF_ALL_MEMTABLES, + CURRENT_SIZE_OF_ALL_MEMTABLES_DESCRIPTION + ); + } + + public static void addSizeAllMemTables(final StreamsMetricsImpl streamsMetrics, + final RocksDBMetricContext metricContext, + final Gauge valueProvider) { + addMutableMetric( + streamsMetrics, + metricContext, + valueProvider, + SIZE_OF_ALL_MEMTABLES, + SIZE_OF_ALL_MEMTABLES_DESCRIPTION + ); + } + + public static void addMemTableFlushPending(final StreamsMetricsImpl streamsMetrics, + final RocksDBMetricContext metricContext, + final Gauge valueProvider) { + addMutableMetric( + streamsMetrics, + metricContext, + valueProvider, + MEMTABLE_FLUSH_PENDING, + MEMTABLE_FLUSH_PENDING_DESCRIPTION + ); + } + + public static void addNumRunningFlushesMetric(final StreamsMetricsImpl streamsMetrics, + final RocksDBMetricContext metricContext, + final Gauge valueProvider) { + addMutableMetric( + streamsMetrics, + metricContext, + valueProvider, + NUMBER_OF_RUNNING_FLUSHES, + NUMBER_OF_RUNNING_FLUSHES_DESCRIPTION + ); + } + + public static void addCompactionPendingMetric(final StreamsMetricsImpl streamsMetrics, + final RocksDBMetricContext metricContext, + final Gauge valueProvider) { + addMutableMetric( + streamsMetrics, + metricContext, + valueProvider, + COMPACTION_PENDING, + 
COMPACTION_PENDING_DESCRIPTION + ); + } + + public static void addNumRunningCompactionsMetric(final StreamsMetricsImpl streamsMetrics, + final RocksDBMetricContext metricContext, + final Gauge valueProvider) { + addMutableMetric( + streamsMetrics, + metricContext, + valueProvider, + NUMBER_OF_RUNNING_COMPACTIONS, + NUMBER_OF_RUNNING_COMPACTIONS_DESCRIPTION + ); + } + + public static void addEstimatePendingCompactionBytesMetric(final StreamsMetricsImpl streamsMetrics, + final RocksDBMetricContext metricContext, + final Gauge valueProvider) { + addMutableMetric( + streamsMetrics, + metricContext, + valueProvider, + ESTIMATED_BYTES_OF_PENDING_COMPACTION, + ESTIMATED_BYTES_OF_PENDING_COMPACTION_DESCRIPTION + ); + } + + public static void addTotalSstFilesSizeMetric(final StreamsMetricsImpl streamsMetrics, + final RocksDBMetricContext metricContext, + final Gauge valueProvider) { + addMutableMetric( + streamsMetrics, + metricContext, + valueProvider, + TOTAL_SST_FILES_SIZE, + TOTAL_SST_FILE_SIZE_DESCRIPTION + ); + } + + public static void addLiveSstFilesSizeMetric(final StreamsMetricsImpl streamsMetrics, + final RocksDBMetricContext metricContext, + final Gauge valueProvider) { + addMutableMetric( + streamsMetrics, + metricContext, + valueProvider, + LIVE_SST_FILES_SIZE, + LIVE_SST_FILES_SIZE_DESCRIPTION + ); + } + + public static void addNumLiveVersionMetric(final StreamsMetricsImpl streamsMetrics, + final RocksDBMetricContext metricContext, + final Gauge valueProvider) { + addMutableMetric( + streamsMetrics, + metricContext, + valueProvider, + NUMBER_OF_LIVE_VERSIONS, + NUMBER_OF_LIVE_VERSIONS_DESCRIPTION + ); + } + + public static void addBlockCacheCapacityMetric(final StreamsMetricsImpl streamsMetrics, + final RocksDBMetricContext metricContext, + final Gauge valueProvider) { + addMutableMetric( + streamsMetrics, + metricContext, + valueProvider, + CAPACITY_OF_BLOCK_CACHE, + CAPACITY_OF_BLOCK_CACHE_DESCRIPTION + ); + } + + public static void addBlockCacheUsageMetric(final StreamsMetricsImpl streamsMetrics, + final RocksDBMetricContext metricContext, + final Gauge valueProvider) { + addMutableMetric( + streamsMetrics, + metricContext, + valueProvider, + USAGE_OF_BLOCK_CACHE, + USAGE_OF_BLOCK_CACHE_DESCRIPTION + ); + } + + public static void addBlockCachePinnedUsageMetric(final StreamsMetricsImpl streamsMetrics, + final RocksDBMetricContext metricContext, + final Gauge valueProvider) { + addMutableMetric( + streamsMetrics, + metricContext, + valueProvider, + PINNED_USAGE_OF_BLOCK_CACHE, + PINNED_USAGE_OF_BLOCK_CACHE_DESCRIPTION + ); + } + + public static void addEstimateNumKeysMetric(final StreamsMetricsImpl streamsMetrics, + final RocksDBMetricContext metricContext, + final Gauge valueProvider) { + addMutableMetric( + streamsMetrics, + metricContext, + valueProvider, + ESTIMATED_NUMBER_OF_KEYS, + ESTIMATED_NUMBER_OF_KEYS_DESCRIPTION + ); + } + + public static void addEstimateTableReadersMemMetric(final StreamsMetricsImpl streamsMetrics, + final RocksDBMetricContext metricContext, + final Gauge valueProvider) { + addMutableMetric( + streamsMetrics, + metricContext, + valueProvider, + ESTIMATED_MEMORY_OF_TABLE_READERS, + ESTIMATED_MEMORY_OF_TABLE_READERS_DESCRIPTION + ); + } + + public static void addBackgroundErrorsMetric(final StreamsMetricsImpl streamsMetrics, + final RocksDBMetricContext metricContext, + final Gauge valueProvider) { + addMutableMetric( + streamsMetrics, + metricContext, + valueProvider, + NUMBER_OF_BACKGROUND_ERRORS, + TOTAL_NUMBER_OF_BACKGROUND_ERRORS_DESCRIPTION + ); + } 
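+
+    // Illustrative note (not part of the original patch): the addXxxMetric helpers above all delegate
+    // to addMutableMetric below, which registers a RocksDB property as a store-level gauge. A caller
+    // such as a metrics recorder might wire one up roughly like this sketch, which assumes a RocksDB
+    // handle `db`, a StreamsMetricsImpl `streamsMetrics`, and a RocksDBMetricContext `metricContext`
+    // are in scope, and omits handling of the checked RocksDBException from getLongProperty:
+    //
+    //   RocksDBMetrics.addNumEntriesActiveMemTableMetric(
+    //       streamsMetrics,
+    //       metricContext,
+    //       (metricsConfig, now) ->
+    //           BigInteger.valueOf(db.getLongProperty("rocksdb.num-entries-active-mem-table"))
+    //   );
+    //
+    // The lambda shape follows org.apache.kafka.common.metrics.Gauge<BigInteger>, i.e. a value
+    // computed on demand from the metric config and the current time.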
+ + private static void addMutableMetric(final StreamsMetricsImpl streamsMetrics, + final RocksDBMetricContext metricContext, + final Gauge valueProvider, + final String name, + final String description) { + streamsMetrics.addStoreLevelMutableMetric( + metricContext.taskName(), + metricContext.metricsScope(), + metricContext.storeName(), + name, + description, + RecordingLevel.INFO, + valueProvider + ); + } + + private static Sensor createSensor(final StreamsMetricsImpl streamsMetrics, + final RocksDBMetricContext metricContext, + final String sensorName) { + return streamsMetrics.storeLevelSensor( + metricContext.taskName(), + metricContext.storeName(), + sensorName, + RecordingLevel.DEBUG); + } +} \ No newline at end of file diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/metrics/RocksDBMetricsRecorder.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/metrics/RocksDBMetricsRecorder.java new file mode 100644 index 0000000..85412d1 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/metrics/RocksDBMetricsRecorder.java @@ -0,0 +1,475 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.state.internals.metrics; + +import org.apache.kafka.common.metrics.Gauge; +import org.apache.kafka.common.metrics.Sensor; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.streams.errors.ProcessorStateException; +import org.apache.kafka.streams.processor.TaskId; +import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl; +import org.apache.kafka.streams.state.internals.metrics.RocksDBMetrics.RocksDBMetricContext; +import org.rocksdb.Cache; +import org.rocksdb.RocksDB; +import org.rocksdb.RocksDBException; +import org.rocksdb.Statistics; +import org.rocksdb.StatsLevel; +import org.rocksdb.TickerType; +import org.slf4j.Logger; + +import java.math.BigInteger; +import java.nio.ByteBuffer; +import java.util.Locale; +import java.util.Map; +import java.util.Objects; +import java.util.concurrent.ConcurrentHashMap; + +import static org.apache.kafka.streams.state.internals.metrics.RocksDBMetrics.CAPACITY_OF_BLOCK_CACHE; +import static org.apache.kafka.streams.state.internals.metrics.RocksDBMetrics.COMPACTION_PENDING; +import static org.apache.kafka.streams.state.internals.metrics.RocksDBMetrics.CURRENT_SIZE_OF_ACTIVE_MEMTABLE; +import static org.apache.kafka.streams.state.internals.metrics.RocksDBMetrics.CURRENT_SIZE_OF_ALL_MEMTABLES; +import static org.apache.kafka.streams.state.internals.metrics.RocksDBMetrics.ESTIMATED_BYTES_OF_PENDING_COMPACTION; +import static org.apache.kafka.streams.state.internals.metrics.RocksDBMetrics.ESTIMATED_MEMORY_OF_TABLE_READERS; +import static org.apache.kafka.streams.state.internals.metrics.RocksDBMetrics.ESTIMATED_NUMBER_OF_KEYS; +import static org.apache.kafka.streams.state.internals.metrics.RocksDBMetrics.LIVE_SST_FILES_SIZE; +import static org.apache.kafka.streams.state.internals.metrics.RocksDBMetrics.MEMTABLE_FLUSH_PENDING; +import static org.apache.kafka.streams.state.internals.metrics.RocksDBMetrics.NUMBER_OF_DELETES_ACTIVE_MEMTABLE; +import static org.apache.kafka.streams.state.internals.metrics.RocksDBMetrics.NUMBER_OF_DELETES_IMMUTABLE_MEMTABLES; +import static org.apache.kafka.streams.state.internals.metrics.RocksDBMetrics.NUMBER_OF_ENTRIES_ACTIVE_MEMTABLE; +import static org.apache.kafka.streams.state.internals.metrics.RocksDBMetrics.NUMBER_OF_ENTRIES_IMMUTABLE_MEMTABLES; +import static org.apache.kafka.streams.state.internals.metrics.RocksDBMetrics.NUMBER_OF_IMMUTABLE_MEMTABLES; +import static org.apache.kafka.streams.state.internals.metrics.RocksDBMetrics.NUMBER_OF_LIVE_VERSIONS; +import static org.apache.kafka.streams.state.internals.metrics.RocksDBMetrics.NUMBER_OF_RUNNING_COMPACTIONS; +import static org.apache.kafka.streams.state.internals.metrics.RocksDBMetrics.NUMBER_OF_RUNNING_FLUSHES; +import static org.apache.kafka.streams.state.internals.metrics.RocksDBMetrics.PINNED_USAGE_OF_BLOCK_CACHE; +import static org.apache.kafka.streams.state.internals.metrics.RocksDBMetrics.SIZE_OF_ALL_MEMTABLES; +import static org.apache.kafka.streams.state.internals.metrics.RocksDBMetrics.NUMBER_OF_BACKGROUND_ERRORS; +import static org.apache.kafka.streams.state.internals.metrics.RocksDBMetrics.TOTAL_SST_FILES_SIZE; +import static org.apache.kafka.streams.state.internals.metrics.RocksDBMetrics.USAGE_OF_BLOCK_CACHE; + +public class RocksDBMetricsRecorder { + + private static class DbAndCacheAndStatistics { + public final RocksDB db; + public final Cache cache; + public final Statistics statistics; + + public DbAndCacheAndStatistics(final RocksDB db, final Cache cache, final 
Statistics statistics) {
+            Objects.requireNonNull(db, "database instance must not be null");
+            this.db = db;
+            this.cache = cache;
+            if (statistics != null) {
+                statistics.setStatsLevel(StatsLevel.EXCEPT_DETAILED_TIMERS);
+            }
+            this.statistics = statistics;
+        }
+
+        public void maybeCloseStatistics() {
+            if (statistics != null) {
+                statistics.close();
+            }
+        }
+    }
+
+    private static final String ROCKSDB_PROPERTIES_PREFIX = "rocksdb.";
+
+
+    private final Logger logger;
+
+    private Sensor bytesWrittenToDatabaseSensor;
+    private Sensor bytesReadFromDatabaseSensor;
+    private Sensor memtableBytesFlushedSensor;
+    private Sensor memtableHitRatioSensor;
+    private Sensor writeStallDurationSensor;
+    private Sensor blockCacheDataHitRatioSensor;
+    private Sensor blockCacheIndexHitRatioSensor;
+    private Sensor blockCacheFilterHitRatioSensor;
+    private Sensor bytesReadDuringCompactionSensor;
+    private Sensor bytesWrittenDuringCompactionSensor;
+    private Sensor numberOfOpenFilesSensor;
+    private Sensor numberOfFileErrorsSensor;
+
+    private final Map<String, DbAndCacheAndStatistics> storeToValueProviders = new ConcurrentHashMap<>();
+    private final String metricsScope;
+    private final String storeName;
+    private TaskId taskId;
+    private StreamsMetricsImpl streamsMetrics;
+    private boolean singleCache = true;
+
+    public RocksDBMetricsRecorder(final String metricsScope,
+                                  final String storeName) {
+        this.metricsScope = metricsScope;
+        this.storeName = storeName;
+        final LogContext logContext = new LogContext(String.format("[RocksDB Metrics Recorder for %s] ", storeName));
+        logger = logContext.logger(RocksDBMetricsRecorder.class);
+    }
+
+    public String storeName() {
+        return storeName;
+    }
+
+    public TaskId taskId() {
+        return taskId;
+    }
+
+    /**
+     * The initialisation of the metrics recorder is idempotent.
+     */
+    public void init(final StreamsMetricsImpl streamsMetrics,
+                     final TaskId taskId) {
+        Objects.requireNonNull(streamsMetrics, "Streams metrics must not be null");
+        Objects.requireNonNull(taskId, "task ID must not be null");
+        if (this.taskId != null && !this.taskId.equals(taskId)) {
+            throw new IllegalStateException("Metrics recorder is re-initialised with different task: previous task is " +
+                this.taskId + " whereas current task is " + taskId + ". This is a bug in Kafka Streams. " +
+                "Please open a bug report under https://issues.apache.org/jira/projects/KAFKA/issues");
+        }
+        if (this.streamsMetrics != null && this.streamsMetrics != streamsMetrics) {
+            throw new IllegalStateException("Metrics recorder is re-initialised with different Streams metrics. " +
+                "This is a bug in Kafka Streams. " +
+                "Please open a bug report under https://issues.apache.org/jira/projects/KAFKA/issues");
+        }
+        final RocksDBMetricContext metricContext = new RocksDBMetricContext(taskId.toString(), metricsScope, storeName);
+        initSensors(streamsMetrics, metricContext);
+        initGauges(streamsMetrics, metricContext);
+        this.taskId = taskId;
+        this.streamsMetrics = streamsMetrics;
+    }
+
+    public void addValueProviders(final String segmentName,
+                                  final RocksDB db,
+                                  final Cache cache,
+                                  final Statistics statistics) {
+        if (storeToValueProviders.isEmpty()) {
+            logger.debug("Adding metrics recorder of task {} to metrics recording trigger", taskId);
+            streamsMetrics.rocksDBMetricsRecordingTrigger().addMetricsRecorder(this);
+        } else if (storeToValueProviders.containsKey(segmentName)) {
+            throw new IllegalStateException("Value providers for store " + segmentName + " of task " + taskId +
+                " have already been added. This is a bug in Kafka Streams. 
" + + "Please open a bug report under https://issues.apache.org/jira/projects/KAFKA/issues"); + } + verifyDbAndCacheAndStatistics(segmentName, db, cache, statistics); + logger.debug("Adding value providers for store {} of task {}", segmentName, taskId); + storeToValueProviders.put(segmentName, new DbAndCacheAndStatistics(db, cache, statistics)); + } + + private void verifyDbAndCacheAndStatistics(final String segmentName, + final RocksDB db, + final Cache cache, + final Statistics statistics) { + for (final DbAndCacheAndStatistics valueProviders : storeToValueProviders.values()) { + verifyConsistencyOfValueProvidersAcrossSegments(segmentName, statistics, valueProviders.statistics, "statistics"); + verifyConsistencyOfValueProvidersAcrossSegments(segmentName, cache, valueProviders.cache, "cache"); + if (db == valueProviders.db) { + throw new IllegalStateException("DB instance for store " + segmentName + " of task " + taskId + + " was already added for another segment as a value provider. This is a bug in Kafka Streams. " + + "Please open a bug report under https://issues.apache.org/jira/projects/KAFKA/issues"); + } + if (storeToValueProviders.size() == 1 && cache != valueProviders.cache) { + singleCache = false; + } else if (singleCache && cache != valueProviders.cache || !singleCache && cache == valueProviders.cache) { + throw new IllegalStateException("Caches for store " + storeName + " of task " + taskId + + " are either not all distinct or do not all refer to the same cache. This is a bug in Kafka Streams. " + + "Please open a bug report under https://issues.apache.org/jira/projects/KAFKA/issues"); + } + } + } + + private void verifyConsistencyOfValueProvidersAcrossSegments(final String segmentName, + final Object newValueProvider, + final Object oldValueProvider, + final String valueProviderName) { + if (newValueProvider == null && oldValueProvider != null || + newValueProvider != null && oldValueProvider == null) { + + final char capitalizedFirstChar = valueProviderName.toUpperCase(Locale.US).charAt(0); + final StringBuilder capitalizedValueProviderName = new StringBuilder(valueProviderName); + capitalizedValueProviderName.setCharAt(0, capitalizedFirstChar); + throw new IllegalStateException(capitalizedValueProviderName + + " for segment " + segmentName + " of task " + taskId + + " is" + (newValueProvider == null ? " " : " not ") + "null although the " + valueProviderName + + " of another segment in this metrics recorder is" + (newValueProvider != null ? " " : " not ") + "null. " + + "This is a bug in Kafka Streams. 
" + + "Please open a bug report under https://issues.apache.org/jira/projects/KAFKA/issues"); + } + } + + private void initSensors(final StreamsMetricsImpl streamsMetrics, final RocksDBMetricContext metricContext) { + bytesWrittenToDatabaseSensor = RocksDBMetrics.bytesWrittenToDatabaseSensor(streamsMetrics, metricContext); + bytesReadFromDatabaseSensor = RocksDBMetrics.bytesReadFromDatabaseSensor(streamsMetrics, metricContext); + memtableBytesFlushedSensor = RocksDBMetrics.memtableBytesFlushedSensor(streamsMetrics, metricContext); + memtableHitRatioSensor = RocksDBMetrics.memtableHitRatioSensor(streamsMetrics, metricContext); + writeStallDurationSensor = RocksDBMetrics.writeStallDurationSensor(streamsMetrics, metricContext); + blockCacheDataHitRatioSensor = RocksDBMetrics.blockCacheDataHitRatioSensor(streamsMetrics, metricContext); + blockCacheIndexHitRatioSensor = RocksDBMetrics.blockCacheIndexHitRatioSensor(streamsMetrics, metricContext); + blockCacheFilterHitRatioSensor = RocksDBMetrics.blockCacheFilterHitRatioSensor(streamsMetrics, metricContext); + bytesWrittenDuringCompactionSensor = + RocksDBMetrics.bytesWrittenDuringCompactionSensor(streamsMetrics, metricContext); + bytesReadDuringCompactionSensor = RocksDBMetrics.bytesReadDuringCompactionSensor(streamsMetrics, metricContext); + numberOfOpenFilesSensor = RocksDBMetrics.numberOfOpenFilesSensor(streamsMetrics, metricContext); + numberOfFileErrorsSensor = RocksDBMetrics.numberOfFileErrorsSensor(streamsMetrics, metricContext); + } + + private void initGauges(final StreamsMetricsImpl streamsMetrics, + final RocksDBMetricContext metricContext) { + RocksDBMetrics.addNumImmutableMemTableMetric( + streamsMetrics, + metricContext, + gaugeToComputeSumOfProperties(NUMBER_OF_IMMUTABLE_MEMTABLES) + ); + RocksDBMetrics.addCurSizeActiveMemTable( + streamsMetrics, + metricContext, + gaugeToComputeSumOfProperties(CURRENT_SIZE_OF_ACTIVE_MEMTABLE) + ); + RocksDBMetrics.addCurSizeAllMemTables( + streamsMetrics, + metricContext, + gaugeToComputeSumOfProperties(CURRENT_SIZE_OF_ALL_MEMTABLES) + ); + RocksDBMetrics.addSizeAllMemTables( + streamsMetrics, + metricContext, + gaugeToComputeSumOfProperties(SIZE_OF_ALL_MEMTABLES) + ); + RocksDBMetrics.addNumEntriesActiveMemTableMetric( + streamsMetrics, + metricContext, + gaugeToComputeSumOfProperties(NUMBER_OF_ENTRIES_ACTIVE_MEMTABLE) + ); + RocksDBMetrics.addNumDeletesActiveMemTableMetric( + streamsMetrics, + metricContext, + gaugeToComputeSumOfProperties(NUMBER_OF_DELETES_ACTIVE_MEMTABLE) + ); + RocksDBMetrics.addNumEntriesImmMemTablesMetric( + streamsMetrics, + metricContext, + gaugeToComputeSumOfProperties(NUMBER_OF_ENTRIES_IMMUTABLE_MEMTABLES) + ); + RocksDBMetrics.addNumDeletesImmMemTablesMetric( + streamsMetrics, + metricContext, + gaugeToComputeSumOfProperties(NUMBER_OF_DELETES_IMMUTABLE_MEMTABLES) + ); + RocksDBMetrics.addMemTableFlushPending( + streamsMetrics, + metricContext, + gaugeToComputeSumOfProperties(MEMTABLE_FLUSH_PENDING) + ); + RocksDBMetrics.addNumRunningFlushesMetric( + streamsMetrics, + metricContext, + gaugeToComputeSumOfProperties(NUMBER_OF_RUNNING_FLUSHES) + ); + RocksDBMetrics.addCompactionPendingMetric( + streamsMetrics, + metricContext, + gaugeToComputeSumOfProperties(COMPACTION_PENDING) + ); + RocksDBMetrics.addNumRunningCompactionsMetric( + streamsMetrics, + metricContext, + gaugeToComputeSumOfProperties(NUMBER_OF_RUNNING_COMPACTIONS) + ); + RocksDBMetrics.addEstimatePendingCompactionBytesMetric( + streamsMetrics, + metricContext, + 
gaugeToComputeSumOfProperties(ESTIMATED_BYTES_OF_PENDING_COMPACTION) + ); + RocksDBMetrics.addTotalSstFilesSizeMetric( + streamsMetrics, + metricContext, + gaugeToComputeSumOfProperties(TOTAL_SST_FILES_SIZE) + ); + RocksDBMetrics.addLiveSstFilesSizeMetric( + streamsMetrics, + metricContext, + gaugeToComputeSumOfProperties(LIVE_SST_FILES_SIZE) + ); + RocksDBMetrics.addNumLiveVersionMetric( + streamsMetrics, + metricContext, + gaugeToComputeSumOfProperties(NUMBER_OF_LIVE_VERSIONS) + ); + RocksDBMetrics.addEstimateNumKeysMetric( + streamsMetrics, + metricContext, + gaugeToComputeSumOfProperties(ESTIMATED_NUMBER_OF_KEYS) + ); + RocksDBMetrics.addEstimateTableReadersMemMetric( + streamsMetrics, + metricContext, + gaugeToComputeSumOfProperties(ESTIMATED_MEMORY_OF_TABLE_READERS) + ); + RocksDBMetrics.addBackgroundErrorsMetric( + streamsMetrics, + metricContext, + gaugeToComputeSumOfProperties(NUMBER_OF_BACKGROUND_ERRORS) + ); + RocksDBMetrics.addBlockCacheCapacityMetric( + streamsMetrics, + metricContext, + gaugeToComputeBlockCacheMetrics(CAPACITY_OF_BLOCK_CACHE) + ); + RocksDBMetrics.addBlockCacheUsageMetric( + streamsMetrics, + metricContext, + gaugeToComputeBlockCacheMetrics(USAGE_OF_BLOCK_CACHE) + ); + RocksDBMetrics.addBlockCachePinnedUsageMetric( + streamsMetrics, + metricContext, + gaugeToComputeBlockCacheMetrics(PINNED_USAGE_OF_BLOCK_CACHE) + ); + } + + private Gauge gaugeToComputeSumOfProperties(final String propertyName) { + return (metricsConfig, now) -> { + BigInteger result = BigInteger.valueOf(0); + for (final DbAndCacheAndStatistics valueProvider : storeToValueProviders.values()) { + try { + // values of RocksDB properties are of type unsigned long in C++, i.e., in Java we need to use + // BigInteger and construct the object from the byte representation of the value + result = result.add(new BigInteger(1, longToBytes( + valueProvider.db.getAggregatedLongProperty(ROCKSDB_PROPERTIES_PREFIX + propertyName) + ))); + } catch (final RocksDBException e) { + throw new ProcessorStateException("Error recording RocksDB metric " + propertyName, e); + } + } + return result; + }; + } + + private Gauge gaugeToComputeBlockCacheMetrics(final String propertyName) { + return (metricsConfig, now) -> { + BigInteger result = BigInteger.valueOf(0); + for (final DbAndCacheAndStatistics valueProvider : storeToValueProviders.values()) { + try { + if (singleCache) { + // values of RocksDB properties are of type unsigned long in C++, i.e., in Java we need to use + // BigInteger and construct the object from the byte representation of the value + result = new BigInteger(1, longToBytes( + valueProvider.db.getAggregatedLongProperty(ROCKSDB_PROPERTIES_PREFIX + propertyName) + )); + break; + } else { + // values of RocksDB properties are of type unsigned long in C++, i.e., in Java we need to use + // BigInteger and construct the object from the byte representation of the value + result = result.add(new BigInteger(1, longToBytes( + valueProvider.db.getAggregatedLongProperty(ROCKSDB_PROPERTIES_PREFIX + propertyName) + ))); + } + } catch (final RocksDBException e) { + throw new ProcessorStateException("Error recording RocksDB metric " + propertyName, e); + } + } + return result; + }; + } + + private static byte[] longToBytes(final long data) { + final ByteBuffer conversionBuffer = ByteBuffer.allocate(Long.BYTES); + conversionBuffer.putLong(0, data); + return conversionBuffer.array(); + } + + public void removeValueProviders(final String segmentName) { + logger.debug("Removing value providers for store {} of task {}", 
segmentName, taskId); + final DbAndCacheAndStatistics removedValueProviders = storeToValueProviders.remove(segmentName); + if (removedValueProviders == null) { + throw new IllegalStateException("No value providers for store \"" + segmentName + "\" of task " + taskId + + " could be found. This is a bug in Kafka Streams. " + + "Please open a bug report under https://issues.apache.org/jira/projects/KAFKA/issues"); + } + removedValueProviders.maybeCloseStatistics(); + if (storeToValueProviders.isEmpty()) { + logger.debug( + "Removing metrics recorder for store {} of task {} from metrics recording trigger", + storeName, + taskId + ); + streamsMetrics.rocksDBMetricsRecordingTrigger().removeMetricsRecorder(this); + } + } + + public void record(final long now) { + logger.debug("Recording metrics for store {}", storeName); + long bytesWrittenToDatabase = 0; + long bytesReadFromDatabase = 0; + long memtableBytesFlushed = 0; + long memtableHits = 0; + long memtableMisses = 0; + long blockCacheDataHits = 0; + long blockCacheDataMisses = 0; + long blockCacheIndexHits = 0; + long blockCacheIndexMisses = 0; + long blockCacheFilterHits = 0; + long blockCacheFilterMisses = 0; + long writeStallDuration = 0; + long bytesWrittenDuringCompaction = 0; + long bytesReadDuringCompaction = 0; + long numberOfOpenFiles = 0; + long numberOfFileErrors = 0; + boolean shouldRecord = true; + for (final DbAndCacheAndStatistics valueProviders : storeToValueProviders.values()) { + if (valueProviders.statistics == null) { + shouldRecord = false; + break; + } + bytesWrittenToDatabase += valueProviders.statistics.getAndResetTickerCount(TickerType.BYTES_WRITTEN); + bytesReadFromDatabase += valueProviders.statistics.getAndResetTickerCount(TickerType.BYTES_READ); + memtableBytesFlushed += valueProviders.statistics.getAndResetTickerCount(TickerType.FLUSH_WRITE_BYTES); + memtableHits += valueProviders.statistics.getAndResetTickerCount(TickerType.MEMTABLE_HIT); + memtableMisses += valueProviders.statistics.getAndResetTickerCount(TickerType.MEMTABLE_MISS); + blockCacheDataHits += valueProviders.statistics.getAndResetTickerCount(TickerType.BLOCK_CACHE_DATA_HIT); + blockCacheDataMisses += valueProviders.statistics.getAndResetTickerCount(TickerType.BLOCK_CACHE_DATA_MISS); + blockCacheIndexHits += valueProviders.statistics.getAndResetTickerCount(TickerType.BLOCK_CACHE_INDEX_HIT); + blockCacheIndexMisses += valueProviders.statistics.getAndResetTickerCount(TickerType.BLOCK_CACHE_INDEX_MISS); + blockCacheFilterHits += valueProviders.statistics.getAndResetTickerCount(TickerType.BLOCK_CACHE_FILTER_HIT); + blockCacheFilterMisses += valueProviders.statistics.getAndResetTickerCount(TickerType.BLOCK_CACHE_FILTER_MISS); + writeStallDuration += valueProviders.statistics.getAndResetTickerCount(TickerType.STALL_MICROS); + bytesWrittenDuringCompaction += valueProviders.statistics.getAndResetTickerCount(TickerType.COMPACT_WRITE_BYTES); + bytesReadDuringCompaction += valueProviders.statistics.getAndResetTickerCount(TickerType.COMPACT_READ_BYTES); + numberOfOpenFiles += valueProviders.statistics.getAndResetTickerCount(TickerType.NO_FILE_OPENS) + - valueProviders.statistics.getAndResetTickerCount(TickerType.NO_FILE_CLOSES); + numberOfFileErrors += valueProviders.statistics.getAndResetTickerCount(TickerType.NO_FILE_ERRORS); + } + if (shouldRecord) { + bytesWrittenToDatabaseSensor.record(bytesWrittenToDatabase, now); + bytesReadFromDatabaseSensor.record(bytesReadFromDatabase, now); + memtableBytesFlushedSensor.record(memtableBytesFlushed, now); + 
memtableHitRatioSensor.record(computeHitRatio(memtableHits, memtableMisses), now); + blockCacheDataHitRatioSensor.record(computeHitRatio(blockCacheDataHits, blockCacheDataMisses), now); + blockCacheIndexHitRatioSensor.record(computeHitRatio(blockCacheIndexHits, blockCacheIndexMisses), now); + blockCacheFilterHitRatioSensor.record(computeHitRatio(blockCacheFilterHits, blockCacheFilterMisses), now); + writeStallDurationSensor.record(writeStallDuration, now); + bytesWrittenDuringCompactionSensor.record(bytesWrittenDuringCompaction, now); + bytesReadDuringCompactionSensor.record(bytesReadDuringCompaction, now); + numberOfOpenFilesSensor.record(numberOfOpenFiles, now); + numberOfFileErrorsSensor.record(numberOfFileErrors, now); + } + } + + private double computeHitRatio(final long hits, final long misses) { + if (hits == 0) { + return 0; + } + return (double) hits / (hits + misses); + } +} \ No newline at end of file diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/metrics/RocksDBMetricsRecordingTrigger.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/metrics/RocksDBMetricsRecordingTrigger.java new file mode 100644 index 0000000..d93f985 --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/metrics/RocksDBMetricsRecordingTrigger.java @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.state.internals.metrics; + +import org.apache.kafka.common.utils.Time; + +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +public class RocksDBMetricsRecordingTrigger implements Runnable { + + private final Map metricsRecordersToTrigger = new ConcurrentHashMap<>(); + private final Time time; + + public RocksDBMetricsRecordingTrigger(final Time time) { + this.time = time; + } + + public void addMetricsRecorder(final RocksDBMetricsRecorder metricsRecorder) { + final String metricsRecorderName = metricsRecorderName(metricsRecorder); + if (metricsRecordersToTrigger.containsKey(metricsRecorderName)) { + throw new IllegalStateException("RocksDB metrics recorder for store \"" + metricsRecorder.storeName() + + "\" of task " + metricsRecorder.taskId().toString() + " has already been added. 
" + + "This is a bug in Kafka Streams."); + } + metricsRecordersToTrigger.put(metricsRecorderName, metricsRecorder); + } + + public void removeMetricsRecorder(final RocksDBMetricsRecorder metricsRecorder) { + final RocksDBMetricsRecorder removedMetricsRecorder = + metricsRecordersToTrigger.remove(metricsRecorderName(metricsRecorder)); + if (removedMetricsRecorder == null) { + throw new IllegalStateException("No RocksDB metrics recorder for store " + + "\"" + metricsRecorder.storeName() + "\" of task " + metricsRecorder.taskId() + " could be found. " + + "This is a bug in Kafka Streams."); + } + } + + private String metricsRecorderName(final RocksDBMetricsRecorder metricsRecorder) { + return metricsRecorder.taskId().toString() + "-" + metricsRecorder.storeName(); + } + + @Override + public void run() { + final long now = time.milliseconds(); + for (final RocksDBMetricsRecorder metricsRecorder : metricsRecordersToTrigger.values()) { + metricsRecorder.record(now); + } + } +} diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/metrics/StateStoreMetrics.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/metrics/StateStoreMetrics.java new file mode 100644 index 0000000..360cd8d --- /dev/null +++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/metrics/StateStoreMetrics.java @@ -0,0 +1,481 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.state.internals.metrics; + +import org.apache.kafka.common.metrics.Sensor; +import org.apache.kafka.common.metrics.Sensor.RecordingLevel; +import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl; + +import java.util.Map; + +import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.LATENCY_SUFFIX; +import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.RECORD_E2E_LATENCY; +import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.RECORD_E2E_LATENCY_AVG_DESCRIPTION; +import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.RECORD_E2E_LATENCY_MAX_DESCRIPTION; +import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.RECORD_E2E_LATENCY_MIN_DESCRIPTION; +import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.STATE_STORE_LEVEL_GROUP; +import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.TOTAL_DESCRIPTION; +import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.addAvgAndMaxToSensor; +import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.addAvgAndMinAndMaxToSensor; +import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.addInvocationRateAndCountToSensor; +import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.addInvocationRateToSensor; + +public class StateStoreMetrics { + private StateStoreMetrics() {} + + private static final String AVG_DESCRIPTION_PREFIX = "The average "; + private static final String MAX_DESCRIPTION_PREFIX = "The maximum "; + private static final String LATENCY_DESCRIPTION = "latency of "; + private static final String AVG_LATENCY_DESCRIPTION_PREFIX = AVG_DESCRIPTION_PREFIX + LATENCY_DESCRIPTION; + private static final String MAX_LATENCY_DESCRIPTION_PREFIX = MAX_DESCRIPTION_PREFIX + LATENCY_DESCRIPTION; + private static final String RATE_DESCRIPTION_PREFIX = "The average number of "; + private static final String RATE_DESCRIPTION_SUFFIX = " per second"; + private static final String BUFFERED_RECORDS = "buffered records"; + + private static final String PUT = "put"; + private static final String PUT_DESCRIPTION = "calls to put"; + private static final String PUT_RATE_DESCRIPTION = + RATE_DESCRIPTION_PREFIX + PUT_DESCRIPTION + RATE_DESCRIPTION_SUFFIX; + private static final String PUT_AVG_LATENCY_DESCRIPTION = AVG_LATENCY_DESCRIPTION_PREFIX + PUT_DESCRIPTION; + private static final String PUT_MAX_LATENCY_DESCRIPTION = MAX_LATENCY_DESCRIPTION_PREFIX + PUT_DESCRIPTION; + + private static final String PUT_IF_ABSENT = "put-if-absent"; + private static final String PUT_IF_ABSENT_DESCRIPTION = "calls to put-if-absent"; + private static final String PUT_IF_ABSENT_RATE_DESCRIPTION = + RATE_DESCRIPTION_PREFIX + PUT_IF_ABSENT_DESCRIPTION + RATE_DESCRIPTION_SUFFIX; + private static final String PUT_IF_ABSENT_AVG_LATENCY_DESCRIPTION = + AVG_LATENCY_DESCRIPTION_PREFIX + PUT_IF_ABSENT_DESCRIPTION; + private static final String PUT_IF_ABSENT_MAX_LATENCY_DESCRIPTION = + MAX_LATENCY_DESCRIPTION_PREFIX + PUT_IF_ABSENT_DESCRIPTION; + + private static final String PUT_ALL = "put-all"; + private static final String PUT_ALL_DESCRIPTION = "calls to put-all"; + private static final String PUT_ALL_RATE_DESCRIPTION = + RATE_DESCRIPTION_PREFIX + PUT_ALL_DESCRIPTION + RATE_DESCRIPTION_SUFFIX; + private static final 
String PUT_ALL_AVG_LATENCY_DESCRIPTION = AVG_LATENCY_DESCRIPTION_PREFIX + PUT_ALL_DESCRIPTION; + private static final String PUT_ALL_MAX_LATENCY_DESCRIPTION = MAX_LATENCY_DESCRIPTION_PREFIX + PUT_ALL_DESCRIPTION; + + private static final String GET = "get"; + private static final String GET_DESCRIPTION = "calls to get"; + private static final String GET_RATE_DESCRIPTION = + RATE_DESCRIPTION_PREFIX + GET_DESCRIPTION + RATE_DESCRIPTION_SUFFIX; + private static final String GET_AVG_LATENCY_DESCRIPTION = AVG_LATENCY_DESCRIPTION_PREFIX + GET_DESCRIPTION; + private static final String GET_MAX_LATENCY_DESCRIPTION = MAX_LATENCY_DESCRIPTION_PREFIX + GET_DESCRIPTION; + + private static final String FETCH = "fetch"; + private static final String FETCH_DESCRIPTION = "calls to fetch"; + private static final String FETCH_RATE_DESCRIPTION = + RATE_DESCRIPTION_PREFIX + FETCH_DESCRIPTION + RATE_DESCRIPTION_SUFFIX; + private static final String FETCH_AVG_LATENCY_DESCRIPTION = AVG_LATENCY_DESCRIPTION_PREFIX + FETCH_DESCRIPTION; + private static final String FETCH_MAX_LATENCY_DESCRIPTION = MAX_LATENCY_DESCRIPTION_PREFIX + FETCH_DESCRIPTION; + + private static final String ALL = "all"; + private static final String ALL_DESCRIPTION = "calls to all"; + private static final String ALL_RATE_DESCRIPTION = + RATE_DESCRIPTION_PREFIX + ALL_DESCRIPTION + RATE_DESCRIPTION_SUFFIX; + private static final String ALL_AVG_LATENCY_DESCRIPTION = AVG_LATENCY_DESCRIPTION_PREFIX + ALL_DESCRIPTION; + private static final String ALL_MAX_LATENCY_DESCRIPTION = MAX_LATENCY_DESCRIPTION_PREFIX + ALL_DESCRIPTION; + + private static final String RANGE = "range"; + private static final String RANGE_DESCRIPTION = "calls to range"; + private static final String RANGE_RATE_DESCRIPTION = + RATE_DESCRIPTION_PREFIX + RANGE_DESCRIPTION + RATE_DESCRIPTION_SUFFIX; + private static final String RANGE_AVG_LATENCY_DESCRIPTION = AVG_LATENCY_DESCRIPTION_PREFIX + RANGE_DESCRIPTION; + private static final String RANGE_MAX_LATENCY_DESCRIPTION = MAX_LATENCY_DESCRIPTION_PREFIX + RANGE_DESCRIPTION; + + private static final String PREFIX_SCAN = "prefix-scan"; + private static final String PREFIX_SCAN_DESCRIPTION = "calls to prefix-scan"; + private static final String PREFIX_SCAN_RATE_DESCRIPTION = + RATE_DESCRIPTION_PREFIX + PREFIX_SCAN_DESCRIPTION + RATE_DESCRIPTION_SUFFIX; + private static final String PREFIX_SCAN_AVG_LATENCY_DESCRIPTION = AVG_LATENCY_DESCRIPTION_PREFIX + PREFIX_SCAN_DESCRIPTION; + private static final String PREFIX_SCAN_MAX_LATENCY_DESCRIPTION = MAX_LATENCY_DESCRIPTION_PREFIX + PREFIX_SCAN_DESCRIPTION; + + private static final String FLUSH = "flush"; + private static final String FLUSH_DESCRIPTION = "calls to flush"; + private static final String FLUSH_RATE_DESCRIPTION = + RATE_DESCRIPTION_PREFIX + FLUSH_DESCRIPTION + RATE_DESCRIPTION_SUFFIX; + private static final String FLUSH_AVG_LATENCY_DESCRIPTION = AVG_LATENCY_DESCRIPTION_PREFIX + FLUSH_DESCRIPTION; + private static final String FLUSH_MAX_LATENCY_DESCRIPTION = MAX_LATENCY_DESCRIPTION_PREFIX + FLUSH_DESCRIPTION; + + private static final String DELETE = "delete"; + private static final String DELETE_DESCRIPTION = "calls to delete"; + private static final String DELETE_RATE_DESCRIPTION = + RATE_DESCRIPTION_PREFIX + DELETE_DESCRIPTION + RATE_DESCRIPTION_SUFFIX; + private static final String DELETE_AVG_LATENCY_DESCRIPTION = AVG_LATENCY_DESCRIPTION_PREFIX + DELETE_DESCRIPTION; + private static final String DELETE_MAX_LATENCY_DESCRIPTION = MAX_LATENCY_DESCRIPTION_PREFIX + DELETE_DESCRIPTION; 
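Each operation named above (put, put-if-absent, put-all, get, fetch, all, range, prefix-scan, flush, delete, and the remaining ones declared below) pairs a rate description with average/maximum latency descriptions, which the sensor factory methods further down combine into one store-level sensor per operation. As a rough, hypothetical usage sketch only — the task, store-type and store names, the inner store and its put call, and the nanosecond timing convention are assumptions for illustration, not taken from this file:

    final Sensor putSensor =
        StateStoreMetrics.putSensor("0_1", "rocksdb-state", "my-store", streamsMetrics);
    final long startNs = time.nanoseconds();
    try {
        inner.put(key, value);   // the wrapped state store doing the actual work
    } finally {
        // one record() call feeds both the invocation-rate metric and the avg/max latency metrics
        putSensor.record(time.nanoseconds() - startNs);
    }
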
+ + private static final String REMOVE = "remove"; + private static final String REMOVE_DESCRIPTION = "calls to remove"; + private static final String REMOVE_RATE_DESCRIPTION = + RATE_DESCRIPTION_PREFIX + REMOVE_DESCRIPTION + RATE_DESCRIPTION_SUFFIX; + private static final String REMOVE_AVG_LATENCY_DESCRIPTION = AVG_LATENCY_DESCRIPTION_PREFIX + REMOVE_DESCRIPTION; + private static final String REMOVE_MAX_LATENCY_DESCRIPTION = MAX_LATENCY_DESCRIPTION_PREFIX + REMOVE_DESCRIPTION; + + private static final String RESTORE = "restore"; + private static final String RESTORE_DESCRIPTION = "restorations"; + private static final String RESTORE_RATE_DESCRIPTION = + RATE_DESCRIPTION_PREFIX + RESTORE_DESCRIPTION + RATE_DESCRIPTION_SUFFIX; + private static final String RESTORE_AVG_LATENCY_DESCRIPTION = AVG_LATENCY_DESCRIPTION_PREFIX + RESTORE_DESCRIPTION; + private static final String RESTORE_MAX_LATENCY_DESCRIPTION = MAX_LATENCY_DESCRIPTION_PREFIX + RESTORE_DESCRIPTION; + + private static final String SUPPRESSION_BUFFER_COUNT = "suppression-buffer-count"; + private static final String SUPPRESSION_BUFFER_COUNT_DESCRIPTION = "count of " + BUFFERED_RECORDS; + private static final String SUPPRESSION_BUFFER_COUNT_AVG_DESCRIPTION = + AVG_DESCRIPTION_PREFIX + SUPPRESSION_BUFFER_COUNT_DESCRIPTION; + private static final String SUPPRESSION_BUFFER_COUNT_MAX_DESCRIPTION = + MAX_DESCRIPTION_PREFIX + SUPPRESSION_BUFFER_COUNT_DESCRIPTION; + + private static final String SUPPRESSION_BUFFER_SIZE = "suppression-buffer-size"; + private static final String SUPPRESSION_BUFFER_SIZE_DESCRIPTION = "size of " + BUFFERED_RECORDS; + private static final String SUPPRESSION_BUFFER_SIZE_AVG_DESCRIPTION = + AVG_DESCRIPTION_PREFIX + SUPPRESSION_BUFFER_SIZE_DESCRIPTION; + private static final String SUPPRESSION_BUFFER_SIZE_MAX_DESCRIPTION = + MAX_DESCRIPTION_PREFIX + SUPPRESSION_BUFFER_SIZE_DESCRIPTION; + + private static final String EXPIRED_WINDOW_RECORD_DROP = "expired-window-record-drop"; + private static final String EXPIRED_WINDOW_RECORD_DROP_DESCRIPTION = "dropped records due to an expired window"; + private static final String EXPIRED_WINDOW_RECORD_DROP_TOTAL_DESCRIPTION = + TOTAL_DESCRIPTION + EXPIRED_WINDOW_RECORD_DROP_DESCRIPTION; + private static final String EXPIRED_WINDOW_RECORD_DROP_RATE_DESCRIPTION = + RATE_DESCRIPTION_PREFIX + EXPIRED_WINDOW_RECORD_DROP_DESCRIPTION + RATE_DESCRIPTION_SUFFIX; + + public static Sensor putSensor(final String taskId, + final String storeType, + final String storeName, + final StreamsMetricsImpl streamsMetrics) { + return throughputAndLatencySensor( + taskId, + storeType, + storeName, + PUT, + PUT_RATE_DESCRIPTION, + PUT_AVG_LATENCY_DESCRIPTION, + PUT_MAX_LATENCY_DESCRIPTION, + RecordingLevel.DEBUG, + streamsMetrics + ); + } + + public static Sensor putIfAbsentSensor(final String taskId, + final String storeType, + final String storeName, + final StreamsMetricsImpl streamsMetrics) { + return throughputAndLatencySensor( + taskId, + storeType, + storeName, + PUT_IF_ABSENT, + PUT_IF_ABSENT_RATE_DESCRIPTION, + PUT_IF_ABSENT_AVG_LATENCY_DESCRIPTION, + PUT_IF_ABSENT_MAX_LATENCY_DESCRIPTION, + RecordingLevel.DEBUG, + streamsMetrics + ); + } + + public static Sensor putAllSensor(final String taskId, + final String storeType, + final String storeName, + final StreamsMetricsImpl streamsMetrics) { + return throughputAndLatencySensor( + taskId, + storeType, + storeName, + PUT_ALL, + PUT_ALL_RATE_DESCRIPTION, + PUT_ALL_AVG_LATENCY_DESCRIPTION, + PUT_ALL_MAX_LATENCY_DESCRIPTION, + RecordingLevel.DEBUG, 
+ streamsMetrics + ); + } + + public static Sensor getSensor(final String taskId, + final String storeType, + final String storeName, + final StreamsMetricsImpl streamsMetrics) { + return throughputAndLatencySensor( + taskId, + storeType, + storeName, + GET, + GET_RATE_DESCRIPTION, + GET_AVG_LATENCY_DESCRIPTION, + GET_MAX_LATENCY_DESCRIPTION, + RecordingLevel.DEBUG, + streamsMetrics + ); + } + + public static Sensor fetchSensor(final String taskId, + final String storeType, + final String storeName, + final StreamsMetricsImpl streamsMetrics) { + return throughputAndLatencySensor( + taskId, + storeType, + storeName, + FETCH, + FETCH_RATE_DESCRIPTION, + FETCH_AVG_LATENCY_DESCRIPTION, + FETCH_MAX_LATENCY_DESCRIPTION, + RecordingLevel.DEBUG, + streamsMetrics + ); + } + + public static Sensor allSensor(final String taskId, + final String storeType, + final String storeName, + final StreamsMetricsImpl streamsMetrics) { + return throughputAndLatencySensor( + taskId, + storeType, + storeName, + ALL, + ALL_RATE_DESCRIPTION, + ALL_AVG_LATENCY_DESCRIPTION, + ALL_MAX_LATENCY_DESCRIPTION, + RecordingLevel.DEBUG, + streamsMetrics + ); + } + + public static Sensor rangeSensor(final String taskId, + final String storeType, + final String storeName, + final StreamsMetricsImpl streamsMetrics) { + return throughputAndLatencySensor( + taskId, + storeType, + storeName, + RANGE, + RANGE_RATE_DESCRIPTION, + RANGE_AVG_LATENCY_DESCRIPTION, + RANGE_MAX_LATENCY_DESCRIPTION, + RecordingLevel.DEBUG, + streamsMetrics + ); + } + + public static Sensor prefixScanSensor(final String taskId, + final String storeType, + final String storeName, + final StreamsMetricsImpl streamsMetrics) { + + final String latencyMetricName = PREFIX_SCAN + LATENCY_SUFFIX; + final Map tagMap = streamsMetrics.storeLevelTagMap(taskId, storeType, storeName); + + final Sensor sensor = streamsMetrics.storeLevelSensor(taskId, storeName, PREFIX_SCAN, RecordingLevel.DEBUG); + addInvocationRateToSensor( + sensor, + STATE_STORE_LEVEL_GROUP, + tagMap, + PREFIX_SCAN, + PREFIX_SCAN_RATE_DESCRIPTION + ); + addAvgAndMaxToSensor( + sensor, + STATE_STORE_LEVEL_GROUP, + tagMap, + latencyMetricName, + PREFIX_SCAN_AVG_LATENCY_DESCRIPTION, + PREFIX_SCAN_MAX_LATENCY_DESCRIPTION + ); + return sensor; + } + + public static Sensor flushSensor(final String taskId, + final String storeType, + final String storeName, + final StreamsMetricsImpl streamsMetrics) { + return throughputAndLatencySensor( + taskId, + storeType, + storeName, + FLUSH, + FLUSH_RATE_DESCRIPTION, + FLUSH_AVG_LATENCY_DESCRIPTION, + FLUSH_MAX_LATENCY_DESCRIPTION, + RecordingLevel.DEBUG, + streamsMetrics + ); + } + + public static Sensor deleteSensor(final String taskId, + final String storeType, + final String storeName, + final StreamsMetricsImpl streamsMetrics) { + return throughputAndLatencySensor( + taskId, + storeType, + storeName, + DELETE, + DELETE_RATE_DESCRIPTION, + DELETE_AVG_LATENCY_DESCRIPTION, + DELETE_MAX_LATENCY_DESCRIPTION, + RecordingLevel.DEBUG, + streamsMetrics + ); + } + + public static Sensor removeSensor(final String taskId, + final String storeType, + final String storeName, + final StreamsMetricsImpl streamsMetrics) { + return throughputAndLatencySensor( + taskId, + storeType, + storeName, + REMOVE, + REMOVE_RATE_DESCRIPTION, + REMOVE_AVG_LATENCY_DESCRIPTION, + REMOVE_MAX_LATENCY_DESCRIPTION, + RecordingLevel.DEBUG, + streamsMetrics + ); + } + + public static Sensor restoreSensor(final String taskId, + final String storeType, + final String storeName, + final StreamsMetricsImpl 
streamsMetrics) { + return throughputAndLatencySensor( + taskId, storeType, + storeName, + RESTORE, + RESTORE_RATE_DESCRIPTION, + RESTORE_AVG_LATENCY_DESCRIPTION, + RESTORE_MAX_LATENCY_DESCRIPTION, + RecordingLevel.DEBUG, + streamsMetrics + ); + } + + public static Sensor expiredWindowRecordDropSensor(final String taskId, + final String storeType, + final String storeName, + final StreamsMetricsImpl streamsMetrics) { + final Sensor sensor = streamsMetrics.storeLevelSensor( + taskId, + storeName, + EXPIRED_WINDOW_RECORD_DROP, + RecordingLevel.INFO + ); + addInvocationRateAndCountToSensor( + sensor, + "stream-" + storeType + "-metrics", + streamsMetrics.storeLevelTagMap(taskId, storeType, storeName), + EXPIRED_WINDOW_RECORD_DROP, + EXPIRED_WINDOW_RECORD_DROP_RATE_DESCRIPTION, + EXPIRED_WINDOW_RECORD_DROP_TOTAL_DESCRIPTION + ); + return sensor; + } + + public static Sensor suppressionBufferCountSensor(final String taskId, + final String storeType, + final String storeName, + final StreamsMetricsImpl streamsMetrics) { + return sizeOrCountSensor( + taskId, + storeType, + storeName, + SUPPRESSION_BUFFER_COUNT, + SUPPRESSION_BUFFER_COUNT_AVG_DESCRIPTION, + SUPPRESSION_BUFFER_COUNT_MAX_DESCRIPTION, + RecordingLevel.DEBUG, + streamsMetrics + ); + } + + public static Sensor suppressionBufferSizeSensor(final String taskId, + final String storeType, + final String storeName, + final StreamsMetricsImpl streamsMetrics) { + return sizeOrCountSensor( + taskId, + storeType, + storeName, + SUPPRESSION_BUFFER_SIZE, + SUPPRESSION_BUFFER_SIZE_AVG_DESCRIPTION, + SUPPRESSION_BUFFER_SIZE_MAX_DESCRIPTION, + RecordingLevel.DEBUG, + streamsMetrics + ); + } + + public static Sensor e2ELatencySensor(final String taskId, + final String storeType, + final String storeName, + final StreamsMetricsImpl streamsMetrics) { + final Sensor sensor = streamsMetrics.storeLevelSensor(taskId, storeName, RECORD_E2E_LATENCY, RecordingLevel.TRACE); + final Map tagMap = streamsMetrics.storeLevelTagMap(taskId, storeType, storeName); + addAvgAndMinAndMaxToSensor( + sensor, + STATE_STORE_LEVEL_GROUP, + tagMap, + RECORD_E2E_LATENCY, + RECORD_E2E_LATENCY_AVG_DESCRIPTION, + RECORD_E2E_LATENCY_MIN_DESCRIPTION, + RECORD_E2E_LATENCY_MAX_DESCRIPTION + ); + return sensor; + } + + private static Sensor sizeOrCountSensor(final String taskId, + final String storeType, + final String storeName, + final String metricName, + final String descriptionOfAvg, + final String descriptionOfMax, + final RecordingLevel recordingLevel, + final StreamsMetricsImpl streamsMetrics) { + final Sensor sensor = streamsMetrics.storeLevelSensor(taskId, storeName, metricName, recordingLevel); + final String group; + final Map tagMap; + group = STATE_STORE_LEVEL_GROUP; + tagMap = streamsMetrics.storeLevelTagMap(taskId, storeType, storeName); + addAvgAndMaxToSensor(sensor, group, tagMap, metricName, descriptionOfAvg, descriptionOfMax); + return sensor; + } + + private static Sensor throughputAndLatencySensor(final String taskId, + final String storeType, + final String storeName, + final String metricName, + final String descriptionOfRate, + final String descriptionOfAvg, + final String descriptionOfMax, + final RecordingLevel recordingLevel, + final StreamsMetricsImpl streamsMetrics) { + final Sensor sensor; + final String latencyMetricName = metricName + LATENCY_SUFFIX; + final Map tagMap = streamsMetrics.storeLevelTagMap(taskId, storeType, storeName); + sensor = streamsMetrics.storeLevelSensor(taskId, storeName, metricName, recordingLevel); + 
addInvocationRateToSensor(sensor, STATE_STORE_LEVEL_GROUP, tagMap, metricName, descriptionOfRate); + addAvgAndMaxToSensor( + sensor, + STATE_STORE_LEVEL_GROUP, + tagMap, + latencyMetricName, + descriptionOfAvg, + descriptionOfMax + ); + return sensor; + } +} diff --git a/streams/src/main/resources/common/message/SubscriptionInfoData.json b/streams/src/main/resources/common/message/SubscriptionInfoData.json new file mode 100644 index 0000000..f9a830e --- /dev/null +++ b/streams/src/main/resources/common/message/SubscriptionInfoData.json @@ -0,0 +1,142 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +{ + "name": "SubscriptionInfoData", + "validVersions": "1-10", + "flexibleVersions": "none", + "fields": [ + { + "name": "version", + "versions": "1+", + "type": "int32" + }, + { + "name": "latestSupportedVersion", + "versions": "3+", + "default": "-1", + "type": "int32" + }, + { + "name": "processId", + "versions": "1+", + "type": "uuid" + }, + /***** Protocol version 1-6 only (after 6 this is encoded in task offset sum map) *****/ + { + "name": "prevTasks", + "versions": "1-6", + "type": "[]TaskId" + }, + { + "name": "standbyTasks", + "versions": "1-6", + "type": "[]TaskId" + }, + /***************/ + { + "name": "userEndPoint", + "versions": "2+", + "type": "bytes" + }, + { + "name": "taskOffsetSums", + "versions": "7+", + "type": "[]TaskOffsetSum" + }, + { + "name": "uniqueField", + "versions": "8+", + "type": "int8" + }, + { + "name": "errorCode", + "versions": "9+", + "type": "int32" + } + ], + "commonStructs": [ + // TaskId was only used from 1-6, after 6 we encode each field of the TaskId separately along with the other information for that map entry + { + "name": "TaskId", + "versions": "1-6", + "fields": [ + { + "name": "topicGroupId", + "versions": "1-6", + "type": "int32" + }, + { + "name": "partition", + "versions": "1-6", + "type": "int32" + } + ] + }, + { + "name": "TaskOffsetSum", + "versions": "7+", + "fields": [ + { + "name": "topicGroupId", + "versions": "7+", + "type": "int32" + }, + // Prior to version 10, in 7-9, the below fields (partition and offsetSum) were encoded via the nested + // partitionToOffsetSum struct. 
In 10+ all fields are encoded directly in the TaskOffsetSum struct + { + "name": "partition", + "versions": "10+", + "type": "int32" + }, + { + "name": "offsetSum", + "versions": "10+", + "type": "int64" + }, + { + "name": "namedTopology", + "versions": "10+", + "nullableVersions": "10+", + "ignorable": "false", // namedTopology is not ignorable because if you do, a TaskId may not be unique + "type": "string" + }, + { + "name": "partitionToOffsetSum", + "versions": "7-9", + "type": "[]PartitionToOffsetSum" + } + ] + }, + + { + "name": "PartitionToOffsetSum", + "versions": "7-9", + "fields": [ + { + "name": "partition", + "versions": "7-9", + "type": "int32" + }, + { + "name": "offsetSum", + "versions": "7-9", + "type": "int64" + } + ] + } + ], + "type": "data" +} diff --git a/streams/src/test/java/org/apache/kafka/common/metrics/SensorAccessor.java b/streams/src/test/java/org/apache/kafka/common/metrics/SensorAccessor.java new file mode 100644 index 0000000..bcc642f --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/common/metrics/SensorAccessor.java @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.common.metrics; + +import java.util.List; + +/** + * This class allows unit tests to access package-private members in class {@link Sensor}. + */ +public class SensorAccessor { + + public final Sensor sensor; + + public SensorAccessor(final Sensor sensor) { + this.sensor = sensor; + } + + public List parents() { + return sensor.parents(); + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/EqualityCheck.java b/streams/src/test/java/org/apache/kafka/streams/EqualityCheck.java new file mode 100644 index 0000000..4ca1dde --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/EqualityCheck.java @@ -0,0 +1,153 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams; + +public final class EqualityCheck { + private EqualityCheck() {} + + // Inspired by EqualsTester from Guava + public static void verifyEquality(final T o1, final T o2) { + // making sure we don't get an NPE in the test + if (o1 == null && o2 == null) { + return; + } else if (o1 == null) { + throw new AssertionError(String.format("o1 was null, but o2[%s] was not.", o2)); + } else if (o2 == null) { + throw new AssertionError(String.format("o1[%s] was not null, but o2 was.", o1)); + } + verifyGeneralEqualityProperties(o1, o2); + + + // the two objects should equal each other + if (!o1.equals(o2)) { + throw new AssertionError(String.format("o1[%s] was not equal to o2[%s].", o1, o2)); + } + + if (!o2.equals(o1)) { + throw new AssertionError(String.format("o2[%s] was not equal to o1[%s].", o2, o1)); + } + + verifyHashCodeConsistency(o1, o2); + + // since these objects are equal, their hashcode MUST be the same + if (o1.hashCode() != o2.hashCode()) { + throw new AssertionError(String.format("o1[%s].hash[%d] was not equal to o2[%s].hash[%d].", o1, o1.hashCode(), o2, o2.hashCode())); + } + } + + public static void verifyInEquality(final T o1, final T o2) { + // making sure we don't get an NPE in the test + if (o1 == null && o2 == null) { + throw new AssertionError("Both o1 and o2 were null."); + } else if (o1 == null) { + return; + } else if (o2 == null) { + return; + } + + verifyGeneralEqualityProperties(o1, o2); + + // these two objects should NOT equal each other + if (o1.equals(o2)) { + throw new AssertionError(String.format("o1[%s] was equal to o2[%s].", o1, o2)); + } + + if (o2.equals(o1)) { + throw new AssertionError(String.format("o2[%s] was equal to o1[%s].", o2, o1)); + } + verifyHashCodeConsistency(o1, o2); + + + // since these objects are NOT equal, their hashcode SHOULD PROBABLY not be the same + if (o1.hashCode() == o2.hashCode()) { + throw new AssertionError( + String.format( + "o1[%s].hash[%d] was equal to o2[%s].hash[%d], even though !o1.equals(o2). 
" + + "This is NOT A BUG, but it is undesirable for hash collection performance.", + o1, + o1.hashCode(), + o2, + o2.hashCode() + ) + ); + } + } + + + @SuppressWarnings({"EqualsWithItself", "ConstantConditions", "ObjectEqualsNull"}) + private static void verifyGeneralEqualityProperties(final T o1, final T o2) { + // objects should equal themselves + if (!o1.equals(o1)) { + throw new AssertionError(String.format("o1[%s] was not equal to itself.", o1)); + } + + if (!o2.equals(o2)) { + throw new AssertionError(String.format("o2[%s] was not equal to itself.", o2)); + } + + // non-null objects should not equal null + if (o1.equals(null)) { + throw new AssertionError(String.format("o1[%s] was equal to null.", o1)); + } + + if (o2.equals(null)) { + throw new AssertionError(String.format("o2[%s] was equal to null.", o2)); + } + + // objects should not equal some random object + if (o1.equals(new Object())) { + throw new AssertionError(String.format("o1[%s] was equal to an anonymous Object.", o1)); + } + + if (o2.equals(new Object())) { + throw new AssertionError(String.format("o2[%s] was equal to an anonymous Object.", o2)); + } + } + + + private static void verifyHashCodeConsistency(final T o1, final T o2) { + { + final int first = o1.hashCode(); + final int second = o1.hashCode(); + if (first != second) { + throw new AssertionError( + String.format( + "o1[%s]'s hashcode was not consistent: [%d]!=[%d].", + o1, + first, + second + ) + ); + } + } + + { + final int first = o2.hashCode(); + final int second = o2.hashCode(); + if (first != second) { + throw new AssertionError( + String.format( + "o2[%s]'s hashcode was not consistent: [%d]!=[%d].", + o2, + first, + second + ) + ); + } + } + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/KafkaStreamsTest.java b/streams/src/test/java/org/apache/kafka/streams/KafkaStreamsTest.java new file mode 100644 index 0000000..d854910 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/KafkaStreamsTest.java @@ -0,0 +1,1133 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams; + +import org.apache.kafka.clients.admin.Admin; +import org.apache.kafka.clients.admin.ListOffsetsResult; +import org.apache.kafka.clients.admin.ListOffsetsResult.ListOffsetsResultInfo; +import org.apache.kafka.clients.admin.MockAdminClient; +import org.apache.kafka.clients.consumer.Consumer; +import org.apache.kafka.clients.producer.MockProducer; +import org.apache.kafka.common.Cluster; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.internals.KafkaFutureImpl; +import org.apache.kafka.common.metrics.MetricConfig; +import org.apache.kafka.common.metrics.Metrics; +import org.apache.kafka.common.metrics.MetricsContext; +import org.apache.kafka.common.metrics.MetricsReporter; +import org.apache.kafka.common.metrics.Sensor.RecordingLevel; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.serialization.StringSerializer; +import org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.streams.KafkaStreams.State; +import org.apache.kafka.streams.errors.StreamsNotStartedException; +import org.apache.kafka.streams.errors.StreamsUncaughtExceptionHandler; +import org.apache.kafka.streams.errors.TopologyException; +import org.apache.kafka.streams.errors.UnknownStateStoreException; +import org.apache.kafka.streams.internals.metrics.ClientMetrics; +import org.apache.kafka.streams.kstream.Materialized; +import org.apache.kafka.streams.processor.StateRestoreListener; +import org.apache.kafka.streams.processor.api.Processor; +import org.apache.kafka.streams.processor.api.ProcessorContext; +import org.apache.kafka.streams.processor.api.Record; +import org.apache.kafka.streams.processor.internals.GlobalStreamThread; +import org.apache.kafka.streams.processor.internals.ProcessorTopology; +import org.apache.kafka.streams.processor.internals.StateDirectory; +import org.apache.kafka.streams.processor.internals.StreamThread; +import org.apache.kafka.streams.processor.internals.StreamsMetadataState; +import org.apache.kafka.streams.processor.internals.TopologyMetadata; +import org.apache.kafka.streams.processor.internals.ThreadMetadataImpl; +import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl; +import org.apache.kafka.streams.state.KeyValueStore; +import org.apache.kafka.streams.state.StoreBuilder; +import org.apache.kafka.streams.state.Stores; +import org.apache.kafka.streams.state.internals.metrics.RocksDBMetricsRecordingTrigger; +import org.apache.kafka.test.MockClientSupplier; +import org.apache.kafka.test.MockMetricsReporter; +import org.apache.kafka.test.MockProcessorSupplier; +import org.apache.kafka.test.TestUtils; +import org.easymock.Capture; +import org.easymock.EasyMock; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TestName; +import org.junit.runner.RunWith; +import org.powermock.api.easymock.PowerMock; +import org.powermock.api.easymock.annotation.Mock; +import org.powermock.core.classloader.annotations.PrepareForTest; +import org.powermock.modules.junit4.PowerMockRunner; + +import java.net.InetSocketAddress; +import java.time.Duration; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Properties; +import java.util.Set; +import java.util.UUID; +import java.util.concurrent.Executors; +import java.util.concurrent.ScheduledExecutorService; 
+import java.util.concurrent.ThreadFactory; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; + +import static java.util.Collections.emptyList; +import static java.util.Collections.singletonList; +import static org.apache.kafka.streams.state.QueryableStoreTypes.keyValueStore; +import static org.apache.kafka.streams.integration.utils.IntegrationTestUtils.safeUniqueTestName; +import static org.apache.kafka.streams.integration.utils.IntegrationTestUtils.waitForApplicationState; + +import static org.apache.kafka.test.TestUtils.waitForCondition; +import static org.easymock.EasyMock.anyInt; +import static org.easymock.EasyMock.anyLong; +import static org.easymock.EasyMock.anyObject; +import static org.easymock.EasyMock.anyString; +import static org.easymock.EasyMock.capture; +import static org.hamcrest.CoreMatchers.hasItem; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.equalTo; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +@RunWith(PowerMockRunner.class) +@PrepareForTest({KafkaStreams.class, StreamThread.class, ClientMetrics.class}) +public class KafkaStreamsTest { + + private static final int NUM_THREADS = 2; + private final static String APPLICATION_ID = "appId"; + private final static String CLIENT_ID = "test-client"; + private final static Duration DEFAULT_DURATION = Duration.ofSeconds(30); + + @Rule + public TestName testName = new TestName(); + + private MockClientSupplier supplier; + private MockTime time; + + private Properties props; + + @Mock + private StateDirectory stateDirectory; + @Mock + private StreamThread streamThreadOne; + @Mock + private StreamThread streamThreadTwo; + @Mock + private GlobalStreamThread globalStreamThread; + @Mock + private Metrics metrics; + + private StateListenerStub streamsStateListener; + private Capture> metricsReportersCapture; + private Capture threadStatelistenerCapture; + + public static class StateListenerStub implements KafkaStreams.StateListener { + int numChanges = 0; + KafkaStreams.State oldState; + KafkaStreams.State newState; + public Map mapStates = new HashMap<>(); + + @Override + public void onChange(final KafkaStreams.State newState, + final KafkaStreams.State oldState) { + final long prevCount = mapStates.containsKey(newState) ? 
mapStates.get(newState) : 0; + numChanges++; + this.oldState = oldState; + this.newState = newState; + mapStates.put(newState, prevCount + 1); + } + } + + @Before + public void before() throws Exception { + time = new MockTime(); + supplier = new MockClientSupplier(); + supplier.setCluster(Cluster.bootstrap(singletonList(new InetSocketAddress("localhost", 9999)))); + streamsStateListener = new StateListenerStub(); + threadStatelistenerCapture = EasyMock.newCapture(); + metricsReportersCapture = EasyMock.newCapture(); + + props = new Properties(); + props.put(StreamsConfig.APPLICATION_ID_CONFIG, APPLICATION_ID); + props.put(StreamsConfig.CLIENT_ID_CONFIG, CLIENT_ID); + props.put(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:2018"); + props.put(StreamsConfig.METRIC_REPORTER_CLASSES_CONFIG, MockMetricsReporter.class.getName()); + props.put(StreamsConfig.STATE_DIR_CONFIG, TestUtils.tempDirectory().getPath()); + props.put(StreamsConfig.NUM_STREAM_THREADS_CONFIG, NUM_THREADS); + + prepareStreams(); + } + + private void prepareStreams() throws Exception { + // setup metrics + PowerMock.expectNew(Metrics.class, + anyObject(MetricConfig.class), + capture(metricsReportersCapture), + anyObject(Time.class), + anyObject(MetricsContext.class) + ).andAnswer(() -> { + for (final MetricsReporter reporter : metricsReportersCapture.getValue()) { + reporter.init(Collections.emptyList()); + } + return metrics; + }).anyTimes(); + metrics.close(); + EasyMock.expectLastCall().andAnswer(() -> { + for (final MetricsReporter reporter : metricsReportersCapture.getValue()) { + reporter.close(); + } + return null; + }).anyTimes(); + + PowerMock.mockStatic(ClientMetrics.class); + EasyMock.expect(ClientMetrics.version()).andReturn("1.56"); + EasyMock.expect(ClientMetrics.commitId()).andReturn("1a2b3c4d5e"); + ClientMetrics.addVersionMetric(anyObject(StreamsMetricsImpl.class)); + ClientMetrics.addCommitIdMetric(anyObject(StreamsMetricsImpl.class)); + ClientMetrics.addApplicationIdMetric(anyObject(StreamsMetricsImpl.class), EasyMock.eq(APPLICATION_ID)); + ClientMetrics.addTopologyDescriptionMetric(anyObject(StreamsMetricsImpl.class), EasyMock.anyObject()); + ClientMetrics.addStateMetric(anyObject(StreamsMetricsImpl.class), anyObject()); + ClientMetrics.addNumAliveStreamThreadMetric(anyObject(StreamsMetricsImpl.class), anyObject()); + + // setup stream threads + PowerMock.mockStatic(StreamThread.class); + EasyMock.expect(StreamThread.create( + anyObject(TopologyMetadata.class), + anyObject(StreamsConfig.class), + anyObject(KafkaClientSupplier.class), + anyObject(Admin.class), + anyObject(UUID.class), + anyObject(String.class), + anyObject(StreamsMetricsImpl.class), + anyObject(Time.class), + anyObject(StreamsMetadataState.class), + anyLong(), + anyObject(StateDirectory.class), + anyObject(StateRestoreListener.class), + anyInt(), + anyObject(Runnable.class), + anyObject() + )).andReturn(streamThreadOne).andReturn(streamThreadTwo); + + EasyMock.expect(StreamThread.eosEnabled(anyObject(StreamsConfig.class))).andReturn(false).anyTimes(); + EasyMock.expect(StreamThread.processingMode(anyObject(StreamsConfig.class))).andReturn(StreamThread.ProcessingMode.AT_LEAST_ONCE).anyTimes(); + EasyMock.expect(streamThreadOne.getId()).andReturn(1L).anyTimes(); + EasyMock.expect(streamThreadTwo.getId()).andReturn(2L).anyTimes(); + prepareStreamThread(streamThreadOne, 1, true); + prepareStreamThread(streamThreadTwo, 2, false); + + // setup global threads + final AtomicReference globalThreadState = new 
AtomicReference<>(GlobalStreamThread.State.CREATED); + PowerMock.expectNew(GlobalStreamThread.class, + anyObject(ProcessorTopology.class), + anyObject(StreamsConfig.class), + anyObject(Consumer.class), + anyObject(StateDirectory.class), + anyLong(), + anyObject(StreamsMetricsImpl.class), + anyObject(Time.class), + anyString(), + anyObject(StateRestoreListener.class), + anyObject(StreamsUncaughtExceptionHandler.class) + ).andReturn(globalStreamThread).anyTimes(); + EasyMock.expect(globalStreamThread.state()).andAnswer(globalThreadState::get).anyTimes(); + globalStreamThread.setStateListener(capture(threadStatelistenerCapture)); + EasyMock.expectLastCall().anyTimes(); + + globalStreamThread.start(); + EasyMock.expectLastCall().andAnswer(() -> { + globalThreadState.set(GlobalStreamThread.State.RUNNING); + threadStatelistenerCapture.getValue().onChange(globalStreamThread, + GlobalStreamThread.State.RUNNING, + GlobalStreamThread.State.CREATED); + return null; + }).anyTimes(); + globalStreamThread.shutdown(); + EasyMock.expectLastCall().andAnswer(() -> { + supplier.restoreConsumer.close(); + + for (final MockProducer producer : supplier.producers) { + producer.close(); + } + globalThreadState.set(GlobalStreamThread.State.DEAD); + threadStatelistenerCapture.getValue().onChange(globalStreamThread, + GlobalStreamThread.State.PENDING_SHUTDOWN, + GlobalStreamThread.State.RUNNING); + threadStatelistenerCapture.getValue().onChange(globalStreamThread, + GlobalStreamThread.State.DEAD, + GlobalStreamThread.State.PENDING_SHUTDOWN); + return null; + }).anyTimes(); + EasyMock.expect(globalStreamThread.stillRunning()).andReturn(globalThreadState.get() == GlobalStreamThread.State.RUNNING).anyTimes(); + globalStreamThread.join(); + EasyMock.expectLastCall().anyTimes(); + + PowerMock.replay( + StreamThread.class, + Metrics.class, + metrics, + ClientMetrics.class, + streamThreadOne, + streamThreadTwo, + GlobalStreamThread.class, + globalStreamThread + ); + } + + private void prepareStreamThread(final StreamThread thread, + final int threadId, + final boolean terminable) throws Exception { + final AtomicReference state = new AtomicReference<>(StreamThread.State.CREATED); + EasyMock.expect(thread.state()).andAnswer(state::get).anyTimes(); + + thread.setStateListener(capture(threadStatelistenerCapture)); + EasyMock.expectLastCall().anyTimes(); + + EasyMock.expect(thread.getStateLock()).andReturn(new Object()).anyTimes(); + + thread.start(); + EasyMock.expectLastCall().andAnswer(() -> { + state.set(StreamThread.State.STARTING); + threadStatelistenerCapture.getValue().onChange(thread, + StreamThread.State.STARTING, + StreamThread.State.CREATED); + threadStatelistenerCapture.getValue().onChange(thread, + StreamThread.State.PARTITIONS_REVOKED, + StreamThread.State.STARTING); + threadStatelistenerCapture.getValue().onChange(thread, + StreamThread.State.PARTITIONS_ASSIGNED, + StreamThread.State.PARTITIONS_REVOKED); + threadStatelistenerCapture.getValue().onChange(thread, + StreamThread.State.RUNNING, + StreamThread.State.PARTITIONS_ASSIGNED); + return null; + }).anyTimes(); + EasyMock.expect(thread.getGroupInstanceID()).andStubReturn(Optional.empty()); + EasyMock.expect(thread.threadMetadata()).andReturn(new ThreadMetadataImpl( + "processId-StreamThread-" + threadId, + "DEAD", + "", + "", + Collections.emptySet(), + "", + Collections.emptySet(), + Collections.emptySet() + ) + ).anyTimes(); + EasyMock.expect(thread.waitOnThreadState(EasyMock.isA(StreamThread.State.class), anyLong())).andStubReturn(true); + 
EasyMock.expect(thread.isAlive()).andReturn(true).times(0, 1); + thread.resizeCache(EasyMock.anyLong()); + EasyMock.expectLastCall().anyTimes(); + thread.requestLeaveGroupDuringShutdown(); + EasyMock.expectLastCall().anyTimes(); + EasyMock.expect(thread.getName()).andStubReturn("processId-StreamThread-" + threadId); + thread.shutdown(); + EasyMock.expectLastCall().andAnswer(() -> { + supplier.consumer.close(); + supplier.restoreConsumer.close(); + for (final MockProducer producer : supplier.producers) { + producer.close(); + } + state.set(StreamThread.State.DEAD); + + threadStatelistenerCapture.getValue().onChange(thread, StreamThread.State.PENDING_SHUTDOWN, StreamThread.State.RUNNING); + threadStatelistenerCapture.getValue().onChange(thread, StreamThread.State.DEAD, StreamThread.State.PENDING_SHUTDOWN); + return null; + }).anyTimes(); + EasyMock.expect(thread.isRunning()).andReturn(state.get() == StreamThread.State.RUNNING).anyTimes(); + thread.join(); + if (terminable) { + EasyMock.expectLastCall().anyTimes(); + } else { + EasyMock.expectLastCall().andAnswer(() -> { + Thread.sleep(2000L); + return null; + }).anyTimes(); + } + + EasyMock.expect(thread.activeTasks()).andStubReturn(emptyList()); + EasyMock.expect(thread.allTasks()).andStubReturn(Collections.emptyMap()); + } + + @Test + public void testShouldTransitToNotRunningIfCloseRightAfterCreated() { + try (final KafkaStreams streams = new KafkaStreams(getBuilderWithSource().build(), props, supplier, time)) { + streams.close(); + + Assert.assertEquals(KafkaStreams.State.NOT_RUNNING, streams.state()); + } + } + + @Test + public void stateShouldTransitToRunningIfNonDeadThreadsBackToRunning() throws InterruptedException { + try (final KafkaStreams streams = new KafkaStreams(getBuilderWithSource().build(), props, supplier, time)) { + streams.setStateListener(streamsStateListener); + + Assert.assertEquals(0, streamsStateListener.numChanges); + Assert.assertEquals(KafkaStreams.State.CREATED, streams.state()); + + streams.start(); + + waitForCondition( + () -> streamsStateListener.numChanges == 2, + "Streams never started."); + Assert.assertEquals(KafkaStreams.State.RUNNING, streams.state()); + waitForCondition( + () -> streamsStateListener.numChanges == 2, + "Streams never started."); + Assert.assertEquals(KafkaStreams.State.RUNNING, streams.state()); + + for (final StreamThread thread : streams.threads) { + threadStatelistenerCapture.getValue().onChange( + thread, + StreamThread.State.PARTITIONS_REVOKED, + StreamThread.State.RUNNING); + } + + Assert.assertEquals(3, streamsStateListener.numChanges); + Assert.assertEquals(KafkaStreams.State.REBALANCING, streams.state()); + + for (final StreamThread thread : streams.threads) { + threadStatelistenerCapture.getValue().onChange( + thread, + StreamThread.State.PARTITIONS_ASSIGNED, + StreamThread.State.PARTITIONS_REVOKED); + } + + Assert.assertEquals(3, streamsStateListener.numChanges); + Assert.assertEquals(KafkaStreams.State.REBALANCING, streams.state()); + + threadStatelistenerCapture.getValue().onChange( + streams.threads.get(NUM_THREADS - 1), + StreamThread.State.PENDING_SHUTDOWN, + StreamThread.State.PARTITIONS_ASSIGNED); + + threadStatelistenerCapture.getValue().onChange( + streams.threads.get(NUM_THREADS - 1), + StreamThread.State.DEAD, + StreamThread.State.PENDING_SHUTDOWN); + + Assert.assertEquals(3, streamsStateListener.numChanges); + Assert.assertEquals(KafkaStreams.State.REBALANCING, streams.state()); + + for (final StreamThread thread : streams.threads) { + if (thread != 
streams.threads.get(NUM_THREADS - 1)) { + threadStatelistenerCapture.getValue().onChange( + thread, + StreamThread.State.RUNNING, + StreamThread.State.PARTITIONS_ASSIGNED); + } + } + + Assert.assertEquals(4, streamsStateListener.numChanges); + Assert.assertEquals(KafkaStreams.State.RUNNING, streams.state()); + + streams.close(); + + waitForCondition( + () -> streamsStateListener.numChanges == 6, + "Streams never closed."); + Assert.assertEquals(KafkaStreams.State.NOT_RUNNING, streams.state()); + } + } + + @Test + public void shouldCleanupResourcesOnCloseWithoutPreviousStart() throws Exception { + final StreamsBuilder builder = getBuilderWithSource(); + builder.globalTable("anyTopic"); + + try (final KafkaStreams streams = new KafkaStreams(builder.build(), props, supplier, time)) { + streams.close(); + + waitForCondition( + () -> streams.state() == KafkaStreams.State.NOT_RUNNING, + "Streams never stopped."); + } + + assertTrue(supplier.consumer.closed()); + assertTrue(supplier.restoreConsumer.closed()); + for (final MockProducer p : supplier.producers) { + assertTrue(p.closed()); + } + } + + @Test + public void testStateThreadClose() throws Exception { + // make sure we have the global state thread running too + final StreamsBuilder builder = getBuilderWithSource(); + builder.globalTable("anyTopic"); + + try (final KafkaStreams streams = new KafkaStreams(builder.build(), props, supplier, time)) { + assertEquals(NUM_THREADS, streams.threads.size()); + assertEquals(streams.state(), KafkaStreams.State.CREATED); + + streams.start(); + waitForCondition( + () -> streams.state() == KafkaStreams.State.RUNNING, + "Streams never started."); + + for (int i = 0; i < NUM_THREADS; i++) { + final StreamThread tmpThread = streams.threads.get(i); + tmpThread.shutdown(); + waitForCondition(() -> tmpThread.state() == StreamThread.State.DEAD, + "Thread never stopped."); + streams.threads.get(i).join(); + } + waitForCondition( + () -> streams.metadataForLocalThreads().stream().allMatch(t -> t.threadState().equals("DEAD")), + "Streams never stopped" + ); + streams.close(); + + waitForCondition( + () -> streams.state() == KafkaStreams.State.NOT_RUNNING, + "Streams never stopped."); + + assertNull(streams.globalStreamThread); + } + } + + @Test + public void testStateGlobalThreadClose() throws Exception { + // make sure we have the global state thread running too + final StreamsBuilder builder = getBuilderWithSource(); + builder.globalTable("anyTopic"); + + try (final KafkaStreams streams = new KafkaStreams(builder.build(), props, supplier, time)) { + streams.start(); + waitForCondition( + () -> streams.state() == KafkaStreams.State.RUNNING, + "Streams never started."); + + final GlobalStreamThread globalStreamThread = streams.globalStreamThread; + globalStreamThread.shutdown(); + waitForCondition( + () -> globalStreamThread.state() == GlobalStreamThread.State.DEAD, + "Thread never stopped."); + globalStreamThread.join(); + waitForCondition( + () -> streams.state() == KafkaStreams.State.PENDING_ERROR, + "Thread never stopped." + ); + streams.close(); + + waitForCondition( + () -> streams.state() == KafkaStreams.State.ERROR, + "Thread never stopped." 
+ ); + } + } + + @Test + public void testInitializesAndDestroysMetricsReporters() { + final int oldInitCount = MockMetricsReporter.INIT_COUNT.get(); + + try (final KafkaStreams streams = new KafkaStreams(getBuilderWithSource().build(), props, supplier, time)) { + final int newInitCount = MockMetricsReporter.INIT_COUNT.get(); + final int initDiff = newInitCount - oldInitCount; + assertTrue("some reporters should be initialized by calling on construction", initDiff > 0); + + streams.start(); + final int oldCloseCount = MockMetricsReporter.CLOSE_COUNT.get(); + streams.close(); + assertEquals(oldCloseCount + initDiff, MockMetricsReporter.CLOSE_COUNT.get()); + } + } + + @Test + public void testCloseIsIdempotent() { + try (final KafkaStreams streams = new KafkaStreams(getBuilderWithSource().build(), props, supplier, time)) { + streams.close(); + final int closeCount = MockMetricsReporter.CLOSE_COUNT.get(); + + streams.close(); + Assert.assertEquals("subsequent close() calls should do nothing", + closeCount, MockMetricsReporter.CLOSE_COUNT.get()); + } + } + + @Test + public void shouldAddThreadWhenRunning() throws InterruptedException { + props.put(StreamsConfig.NUM_STREAM_THREADS_CONFIG, 1); + try (final KafkaStreams streams = new KafkaStreams(getBuilderWithSource().build(), props, supplier, time)) { + streams.start(); + final int oldSize = streams.threads.size(); + waitForCondition(() -> streams.state() == KafkaStreams.State.RUNNING, 15L, "wait until running"); + assertThat(streams.addStreamThread(), equalTo(Optional.of("processId-StreamThread-" + 2))); + assertThat(streams.threads.size(), equalTo(oldSize + 1)); + } + } + + @Test + public void shouldNotAddThreadWhenCreated() { + try (final KafkaStreams streams = new KafkaStreams(getBuilderWithSource().build(), props, supplier, time)) { + final int oldSize = streams.threads.size(); + assertThat(streams.addStreamThread(), equalTo(Optional.empty())); + assertThat(streams.threads.size(), equalTo(oldSize)); + } + } + + @Test + public void shouldNotAddThreadWhenClosed() { + try (final KafkaStreams streams = new KafkaStreams(getBuilderWithSource().build(), props, supplier, time)) { + final int oldSize = streams.threads.size(); + streams.close(); + assertThat(streams.addStreamThread(), equalTo(Optional.empty())); + assertThat(streams.threads.size(), equalTo(oldSize)); + } + } + + @Test + public void shouldNotAddThreadWhenError() { + try (final KafkaStreams streams = new KafkaStreams(getBuilderWithSource().build(), props, supplier, time)) { + final int oldSize = streams.threads.size(); + streams.start(); + globalStreamThread.shutdown(); + assertThat(streams.addStreamThread(), equalTo(Optional.empty())); + assertThat(streams.threads.size(), equalTo(oldSize)); + } + } + + @Test + public void shouldNotReturnDeadThreads() { + try (final KafkaStreams streams = new KafkaStreams(getBuilderWithSource().build(), props, supplier, time)) { + streams.start(); + streamThreadOne.shutdown(); + final Set threads = streams.metadataForLocalThreads(); + assertThat(threads.size(), equalTo(1)); + assertThat(threads, hasItem(streamThreadTwo.threadMetadata())); + } + } + + @Test + public void shouldRemoveThread() throws InterruptedException { + props.put(StreamsConfig.NUM_STREAM_THREADS_CONFIG, 2); + try (final KafkaStreams streams = new KafkaStreams(getBuilderWithSource().build(), props, supplier, time)) { + streams.start(); + final int oldSize = streams.threads.size(); + waitForCondition(() -> streams.state() == KafkaStreams.State.RUNNING, 15L, + "Kafka Streams client did 
not reach state RUNNING"); + assertThat(streams.removeStreamThread(), equalTo(Optional.of("processId-StreamThread-" + 1))); + assertThat(streams.threads.size(), equalTo(oldSize - 1)); + } + } + + @Test + public void shouldNotRemoveThreadWhenNotRunning() { + props.put(StreamsConfig.NUM_STREAM_THREADS_CONFIG, 1); + try (final KafkaStreams streams = new KafkaStreams(getBuilderWithSource().build(), props, supplier, time)) { + assertThat(streams.removeStreamThread(), equalTo(Optional.empty())); + assertThat(streams.threads.size(), equalTo(1)); + } + } + + @Test + public void testCannotStartOnceClosed() { + try (final KafkaStreams streams = new KafkaStreams(getBuilderWithSource().build(), props, supplier, time)) { + streams.start(); + streams.close(); + try { + streams.start(); + fail("Should have throw IllegalStateException"); + } catch (final IllegalStateException expected) { + // this is ok + } + } + } + + @Test + public void shouldNotSetGlobalRestoreListenerAfterStarting() { + try (final KafkaStreams streams = new KafkaStreams(getBuilderWithSource().build(), props, supplier, time)) { + streams.start(); + try { + streams.setGlobalStateRestoreListener(null); + fail("Should throw an IllegalStateException"); + } catch (final IllegalStateException e) { + // expected + } + } + } + + @Test + public void shouldThrowExceptionSettingUncaughtExceptionHandlerNotInCreateState() { + try (final KafkaStreams streams = new KafkaStreams(getBuilderWithSource().build(), props, supplier, time)) { + streams.start(); + assertThrows(IllegalStateException.class, () -> streams.setUncaughtExceptionHandler((StreamsUncaughtExceptionHandler) null)); + } + } + + @Test + public void shouldThrowExceptionSettingStreamsUncaughtExceptionHandlerNotInCreateState() { + try (final KafkaStreams streams = new KafkaStreams(getBuilderWithSource().build(), props, supplier, time)) { + streams.start(); + assertThrows(IllegalStateException.class, () -> streams.setUncaughtExceptionHandler((StreamsUncaughtExceptionHandler) null)); + } + + } + @Test + public void shouldThrowNullPointerExceptionSettingStreamsUncaughtExceptionHandlerIfNull() { + try (final KafkaStreams streams = new KafkaStreams(getBuilderWithSource().build(), props, supplier, time)) { + assertThrows(NullPointerException.class, () -> streams.setUncaughtExceptionHandler((StreamsUncaughtExceptionHandler) null)); + } + } + + @Test + public void shouldThrowExceptionSettingStateListenerNotInCreateState() { + try (final KafkaStreams streams = new KafkaStreams(getBuilderWithSource().build(), props, supplier, time)) { + streams.start(); + try { + streams.setStateListener(null); + fail("Should throw IllegalStateException"); + } catch (final IllegalStateException e) { + // expected + } + } + } + + @Test + public void shouldAllowCleanupBeforeStartAndAfterClose() { + final KafkaStreams streams = new KafkaStreams(getBuilderWithSource().build(), props, supplier, time); + try { + streams.cleanUp(); + streams.start(); + } finally { + streams.close(); + streams.cleanUp(); + } + } + + @Test + public void shouldThrowOnCleanupWhileRunning() throws InterruptedException { + try (final KafkaStreams streams = new KafkaStreams(getBuilderWithSource().build(), props, supplier, time)) { + streams.start(); + waitForCondition( + () -> streams.state() == KafkaStreams.State.RUNNING, + "Streams never started."); + + try { + streams.cleanUp(); + fail("Should have thrown IllegalStateException"); + } catch (final IllegalStateException expected) { + assertEquals("Cannot clean up while running.", 
expected.getMessage()); + } + } + } + + @Test + public void shouldThrowOnCleanupWhileShuttingDown() throws InterruptedException { + final KafkaStreams streams = new KafkaStreams(getBuilderWithSource().build(), props, supplier, time); + streams.start(); + waitForCondition( + () -> streams.state() == KafkaStreams.State.RUNNING, + "Streams never started."); + + streams.close(Duration.ZERO); + assertThat(streams.state() == State.PENDING_SHUTDOWN, equalTo(true)); + assertThrows(IllegalStateException.class, streams::cleanUp); + assertThat(streams.state() == State.PENDING_SHUTDOWN, equalTo(true)); + } + + @Test + public void shouldNotGetAllTasksWhenNotRunning() throws InterruptedException { + try (final KafkaStreams streams = new KafkaStreams(getBuilderWithSource().build(), props, supplier, time)) { + assertThrows(StreamsNotStartedException.class, streams::metadataForAllStreamsClients); + streams.start(); + waitForApplicationState(Collections.singletonList(streams), KafkaStreams.State.RUNNING, DEFAULT_DURATION); + streams.close(); + waitForApplicationState(Collections.singletonList(streams), KafkaStreams.State.NOT_RUNNING, DEFAULT_DURATION); + assertThrows(IllegalStateException.class, streams::metadataForAllStreamsClients); + } + } + + @Test + public void shouldNotGetAllTasksWithStoreWhenNotRunning() throws InterruptedException { + try (final KafkaStreams streams = new KafkaStreams(getBuilderWithSource().build(), props, supplier, time)) { + assertThrows(StreamsNotStartedException.class, () -> streams.streamsMetadataForStore("store")); + streams.start(); + waitForApplicationState(Collections.singletonList(streams), KafkaStreams.State.RUNNING, DEFAULT_DURATION); + streams.close(); + waitForApplicationState(Collections.singletonList(streams), KafkaStreams.State.NOT_RUNNING, DEFAULT_DURATION); + assertThrows(IllegalStateException.class, () -> streams.streamsMetadataForStore("store")); + } + } + + @Test + public void shouldNotGetQueryMetadataWithSerializerWhenNotRunningOrRebalancing() throws InterruptedException { + try (final KafkaStreams streams = new KafkaStreams(getBuilderWithSource().build(), props, supplier, time)) { + assertThrows(StreamsNotStartedException.class, () -> streams.queryMetadataForKey("store", "key", Serdes.String().serializer())); + streams.start(); + waitForApplicationState(Collections.singletonList(streams), KafkaStreams.State.RUNNING, DEFAULT_DURATION); + streams.close(); + waitForApplicationState(Collections.singletonList(streams), KafkaStreams.State.NOT_RUNNING, DEFAULT_DURATION); + assertThrows(IllegalStateException.class, () -> streams.queryMetadataForKey("store", "key", Serdes.String().serializer())); + } + } + + @Test + public void shouldGetQueryMetadataWithSerializerWhenRunningOrRebalancing() { + try (final KafkaStreams streams = new KafkaStreams(getBuilderWithSource().build(), props, supplier, time)) { + streams.start(); + assertEquals(KeyQueryMetadata.NOT_AVAILABLE, streams.queryMetadataForKey("store", "key", Serdes.String().serializer())); + } + } + + @Test + public void shouldNotGetQueryMetadataWithPartitionerWhenNotRunningOrRebalancing() throws InterruptedException { + try (final KafkaStreams streams = new KafkaStreams(getBuilderWithSource().build(), props, supplier, time)) { + assertThrows(StreamsNotStartedException.class, () -> streams.queryMetadataForKey("store", "key", (topic, key, value, numPartitions) -> 0)); + streams.start(); + waitForApplicationState(Collections.singletonList(streams), KafkaStreams.State.RUNNING, DEFAULT_DURATION); + streams.close(); + 
waitForApplicationState(Collections.singletonList(streams), KafkaStreams.State.NOT_RUNNING, DEFAULT_DURATION); + assertThrows(IllegalStateException.class, () -> streams.queryMetadataForKey("store", "key", (topic, key, value, numPartitions) -> 0)); + } + } + + @Test + public void shouldThrowUnknownStateStoreExceptionWhenStoreNotExist() throws InterruptedException { + try (final KafkaStreams streams = new KafkaStreams(getBuilderWithSource().build(), props, supplier, time)) { + streams.start(); + waitForApplicationState(Collections.singletonList(streams), KafkaStreams.State.RUNNING, DEFAULT_DURATION); + assertThrows(UnknownStateStoreException.class, () -> streams.store(StoreQueryParameters.fromNameAndType("unknown-store", keyValueStore()))); + } + } + + @Test + public void shouldNotGetStoreWhenWhenNotRunningOrRebalancing() throws InterruptedException { + try (final KafkaStreams streams = new KafkaStreams(getBuilderWithSource().build(), props, supplier, time)) { + assertThrows(StreamsNotStartedException.class, () -> streams.store(StoreQueryParameters.fromNameAndType("store", keyValueStore()))); + streams.start(); + waitForApplicationState(Collections.singletonList(streams), KafkaStreams.State.RUNNING, DEFAULT_DURATION); + streams.close(); + waitForApplicationState(Collections.singletonList(streams), KafkaStreams.State.NOT_RUNNING, DEFAULT_DURATION); + assertThrows(IllegalStateException.class, () -> streams.store(StoreQueryParameters.fromNameAndType("store", keyValueStore()))); + } + } + + @Test + public void shouldReturnEmptyLocalStorePartitionLags() { + // Mock all calls made to compute the offset lags, + final ListOffsetsResult result = EasyMock.mock(ListOffsetsResult.class); + final KafkaFutureImpl> allFuture = new KafkaFutureImpl<>(); + allFuture.complete(Collections.emptyMap()); + + EasyMock.expect(result.all()).andReturn(allFuture); + final MockAdminClient mockAdminClient = EasyMock.partialMockBuilder(MockAdminClient.class) + .addMockedMethod("listOffsets", Map.class).createMock(); + EasyMock.expect(mockAdminClient.listOffsets(anyObject())).andStubReturn(result); + final MockClientSupplier mockClientSupplier = EasyMock.partialMockBuilder(MockClientSupplier.class) + .addMockedMethod("getAdmin").createMock(); + EasyMock.expect(mockClientSupplier.getAdmin(anyObject())).andReturn(mockAdminClient); + EasyMock.replay(result, mockAdminClient, mockClientSupplier); + + try (final KafkaStreams streams = new KafkaStreams(getBuilderWithSource().build(), props, mockClientSupplier, time)) { + streams.start(); + assertEquals(0, streams.allLocalStorePartitionLags().size()); + } + } + + @Test + public void shouldReturnFalseOnCloseWhenThreadsHaventTerminated() { + // do not use mock time so that it can really elapse + try (final KafkaStreams streams = new KafkaStreams(getBuilderWithSource().build(), props, supplier)) { + assertFalse(streams.close(Duration.ofMillis(10L))); + } + } + + @Test + public void shouldThrowOnNegativeTimeoutForClose() { + try (final KafkaStreams streams = new KafkaStreams(getBuilderWithSource().build(), props, supplier, time)) { + assertThrows(IllegalArgumentException.class, () -> streams.close(Duration.ofMillis(-1L))); + } + } + + @Test + public void shouldNotBlockInCloseForZeroDuration() { + try (final KafkaStreams streams = new KafkaStreams(getBuilderWithSource().build(), props, supplier, time)) { + // with mock time that does not elapse, close would not return if it ever waits on the state transition + assertFalse(streams.close(Duration.ZERO)); + } + } + + @Test + public void 
shouldTriggerRecordingOfRocksDBMetricsIfRecordingLevelIsDebug() { + PowerMock.mockStatic(Executors.class); + final ScheduledExecutorService cleanupSchedule = EasyMock.niceMock(ScheduledExecutorService.class); + final ScheduledExecutorService rocksDBMetricsRecordingTriggerThread = + EasyMock.mock(ScheduledExecutorService.class); + EasyMock.expect(Executors.newSingleThreadScheduledExecutor( + anyObject(ThreadFactory.class) + )).andReturn(cleanupSchedule); + EasyMock.expect(Executors.newSingleThreadScheduledExecutor( + anyObject(ThreadFactory.class) + )).andReturn(rocksDBMetricsRecordingTriggerThread); + EasyMock.expect(rocksDBMetricsRecordingTriggerThread.scheduleAtFixedRate( + EasyMock.anyObject(RocksDBMetricsRecordingTrigger.class), + EasyMock.eq(0L), + EasyMock.eq(1L), + EasyMock.eq(TimeUnit.MINUTES) + )).andReturn(null); + EasyMock.expect(rocksDBMetricsRecordingTriggerThread.shutdownNow()).andReturn(null); + PowerMock.replay(Executors.class); + PowerMock.replay(rocksDBMetricsRecordingTriggerThread); + PowerMock.replay(cleanupSchedule); + + final StreamsBuilder builder = new StreamsBuilder(); + builder.table("topic", Materialized.as("store")); + props.setProperty(StreamsConfig.METRICS_RECORDING_LEVEL_CONFIG, RecordingLevel.DEBUG.name()); + + try (final KafkaStreams streams = new KafkaStreams(getBuilderWithSource().build(), props, supplier, time)) { + streams.start(); + } + + PowerMock.verify(Executors.class); + PowerMock.verify(rocksDBMetricsRecordingTriggerThread); + } + + @Test + public void shouldNotTriggerRecordingOfRocksDBMetricsIfRecordingLevelIsInfo() { + PowerMock.mockStatic(Executors.class); + final ScheduledExecutorService cleanupSchedule = EasyMock.niceMock(ScheduledExecutorService.class); + final ScheduledExecutorService rocksDBMetricsRecordingTriggerThread = + EasyMock.mock(ScheduledExecutorService.class); + EasyMock.expect(Executors.newSingleThreadScheduledExecutor( + anyObject(ThreadFactory.class) + )).andReturn(cleanupSchedule); + PowerMock.replay(Executors.class, rocksDBMetricsRecordingTriggerThread, cleanupSchedule); + + final StreamsBuilder builder = new StreamsBuilder(); + builder.table("topic", Materialized.as("store")); + props.setProperty(StreamsConfig.METRICS_RECORDING_LEVEL_CONFIG, RecordingLevel.INFO.name()); + + try (final KafkaStreams streams = new KafkaStreams(getBuilderWithSource().build(), props, supplier, time)) { + streams.start(); + } + + PowerMock.verify(Executors.class, rocksDBMetricsRecordingTriggerThread); + } + + @Test + public void shouldCleanupOldStateDirs() throws Exception { + PowerMock.mockStatic(Executors.class); + final ScheduledExecutorService cleanupSchedule = EasyMock.mock(ScheduledExecutorService.class); + EasyMock.expect(Executors.newSingleThreadScheduledExecutor( + anyObject(ThreadFactory.class) + )).andReturn(cleanupSchedule).anyTimes(); + EasyMock.expect(cleanupSchedule.scheduleAtFixedRate( + EasyMock.anyObject(Runnable.class), + EasyMock.eq(1L), + EasyMock.eq(1L), + EasyMock.eq(TimeUnit.MILLISECONDS) + )).andReturn(null); + EasyMock.expect(cleanupSchedule.shutdownNow()).andReturn(null); + PowerMock.expectNew(StateDirectory.class, + anyObject(StreamsConfig.class), + anyObject(Time.class), + EasyMock.eq(true), + EasyMock.eq(false) + ).andReturn(stateDirectory); + EasyMock.expect(stateDirectory.initializeProcessId()).andReturn(UUID.randomUUID()); + stateDirectory.close(); + PowerMock.replayAll(Executors.class, cleanupSchedule, stateDirectory); + + props.setProperty(StreamsConfig.STATE_CLEANUP_DELAY_MS_CONFIG, "1"); + final StreamsBuilder 
builder = new StreamsBuilder(); + builder.table("topic", Materialized.as("store")); + + try (final KafkaStreams streams = new KafkaStreams(builder.build(), props, supplier, time)) { + streams.start(); + } + + PowerMock.verify(Executors.class, cleanupSchedule); + } + + @Test + public void statelessTopologyShouldNotCreateStateDirectory() throws Exception { + final String safeTestName = safeUniqueTestName(getClass(), testName); + final String inputTopic = safeTestName + "-input"; + final String outputTopic = safeTestName + "-output"; + final Topology topology = new Topology(); + topology.addSource("source", Serdes.String().deserializer(), Serdes.String().deserializer(), inputTopic) + .addProcessor("process", () -> new Processor() { + private ProcessorContext context; + + @Override + public void init(final ProcessorContext context) { + this.context = context; + } + + @Override + public void process(final Record record) { + if (record.value().length() % 2 == 0) { + context.forward(record.withValue(record.key() + record.value())); + } + } + }, "source") + .addSink("sink", outputTopic, new StringSerializer(), new StringSerializer(), "process"); + startStreamsAndCheckDirExists(topology, false); + } + + @Test + public void inMemoryStatefulTopologyShouldNotCreateStateDirectory() throws Exception { + final String safeTestName = safeUniqueTestName(getClass(), testName); + final String inputTopic = safeTestName + "-input"; + final String outputTopic = safeTestName + "-output"; + final String globalTopicName = safeTestName + "-global"; + final String storeName = safeTestName + "-counts"; + final String globalStoreName = safeTestName + "-globalStore"; + final Topology topology = getStatefulTopology(inputTopic, outputTopic, globalTopicName, storeName, globalStoreName, false); + startStreamsAndCheckDirExists(topology, false); + } + + @Test + public void statefulTopologyShouldCreateStateDirectory() throws Exception { + final String safeTestName = safeUniqueTestName(getClass(), testName); + final String inputTopic = safeTestName + "-input"; + final String outputTopic = safeTestName + "-output"; + final String globalTopicName = safeTestName + "-global"; + final String storeName = safeTestName + "-counts"; + final String globalStoreName = safeTestName + "-globalStore"; + final Topology topology = getStatefulTopology(inputTopic, outputTopic, globalTopicName, storeName, globalStoreName, true); + startStreamsAndCheckDirExists(topology, true); + } + + @Test + public void shouldThrowTopologyExceptionOnEmptyTopology() { + try { + new KafkaStreams(new StreamsBuilder().build(), props, supplier, time); + fail("Should have thrown TopologyException"); + } catch (final TopologyException e) { + assertThat( + e.getMessage(), + equalTo("Invalid topology: Topology has no stream threads and no global threads, " + + "must subscribe to at least one source topic or global table.")); + } + } + + @Test + public void shouldNotCreateStreamThreadsForGlobalOnlyTopology() { + final StreamsBuilder builder = new StreamsBuilder(); + builder.globalTable("anyTopic"); + try (final KafkaStreams streams = new KafkaStreams(builder.build(), props, supplier, time)) { + + assertThat(streams.threads.size(), equalTo(0)); + } + } + + @Test + public void shouldTransitToRunningWithGlobalOnlyTopology() throws InterruptedException { + final StreamsBuilder builder = new StreamsBuilder(); + builder.globalTable("anyTopic"); + try (final KafkaStreams streams = new KafkaStreams(builder.build(), props, supplier, time)) { + + assertThat(streams.threads.size(), 
equalTo(0)); + assertEquals(streams.state(), KafkaStreams.State.CREATED); + + streams.start(); + waitForCondition( + () -> streams.state() == KafkaStreams.State.RUNNING, + "Streams never started, state is " + streams.state()); + + streams.close(); + + waitForCondition( + () -> streams.state() == KafkaStreams.State.NOT_RUNNING, + "Streams never stopped."); + } + } + + @Deprecated // testing old PAPI + private Topology getStatefulTopology(final String inputTopic, + final String outputTopic, + final String globalTopicName, + final String storeName, + final String globalStoreName, + final boolean isPersistentStore) { + final StoreBuilder> storeBuilder = Stores.keyValueStoreBuilder( + isPersistentStore ? + Stores.persistentKeyValueStore(storeName) + : Stores.inMemoryKeyValueStore(storeName), + Serdes.String(), + Serdes.Long()); + final Topology topology = new Topology(); + topology.addSource("source", Serdes.String().deserializer(), Serdes.String().deserializer(), inputTopic) + .addProcessor("process", () -> new Processor() { + private ProcessorContext context; + + @Override + public void init(final ProcessorContext context) { + this.context = context; + } + + @Override + public void process(final Record record) { + final KeyValueStore kvStore = context.getStateStore(storeName); + kvStore.put(record.key(), 5L); + + context.forward(record.withValue("5")); + context.commit(); + } + }, "source") + .addStateStore(storeBuilder, "process") + .addSink("sink", outputTopic, new StringSerializer(), new StringSerializer(), "process"); + + final StoreBuilder> globalStoreBuilder = Stores.keyValueStoreBuilder( + isPersistentStore ? Stores.persistentKeyValueStore(globalStoreName) : Stores.inMemoryKeyValueStore(globalStoreName), + Serdes.String(), Serdes.String()).withLoggingDisabled(); + topology.addGlobalStore( + globalStoreBuilder, + "global", + Serdes.String().deserializer(), + Serdes.String().deserializer(), + globalTopicName, + globalTopicName + "-processor", + new MockProcessorSupplier<>()); + return topology; + } + + private StreamsBuilder getBuilderWithSource() { + final StreamsBuilder builder = new StreamsBuilder(); + builder.stream("source-topic"); + return builder; + } + + private void startStreamsAndCheckDirExists(final Topology topology, + final boolean shouldFilesExist) throws Exception { + PowerMock.expectNew(StateDirectory.class, + anyObject(StreamsConfig.class), + anyObject(Time.class), + EasyMock.eq(shouldFilesExist), + EasyMock.eq(false) + ).andReturn(stateDirectory); + EasyMock.expect(stateDirectory.initializeProcessId()).andReturn(UUID.randomUUID()); + + PowerMock.replayAll(); + + new KafkaStreams(topology, props, supplier, time); + + PowerMock.verifyAll(); + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/KafkaStreamsWrapper.java b/streams/src/test/java/org/apache/kafka/streams/KafkaStreamsWrapper.java new file mode 100644 index 0000000..9dd1fc1 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/KafkaStreamsWrapper.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.streams;
+
+import java.util.Properties;
+
+import org.apache.kafka.streams.KafkaStreams.State;
+import org.apache.kafka.streams.processor.internals.StreamThread;
+
+/**
+ * This class exposes a {@link KafkaStreams} instance so that tests can register a
+ * {@link StreamThread.StateListener} on its internal stream threads.
+ */
+public class KafkaStreamsWrapper extends KafkaStreams {
+
+    public KafkaStreamsWrapper(final Topology topology,
+                               final Properties props) {
+        super(topology, props);
+    }
+
+    /**
+     * An app can set a single {@link StreamThread.StateListener} so that the app is notified when state changes.
+     *
+     * @param listener a new StreamThread state listener
+     * @throws IllegalStateException if this {@code KafkaStreams} instance is not in state {@link State#CREATED CREATED}.
+     */
+    public void setStreamThreadStateListener(final StreamThread.StateListener listener) {
+        if (state == State.CREATED) {
+            for (final StreamThread thread : threads) {
+                thread.setStateListener(listener);
+            }
+        } else {
+            throw new IllegalStateException("Can only set StateListener in CREATED state. " +
+                "Current state is: " + state);
+        }
+    }
+}
diff --git a/streams/src/test/java/org/apache/kafka/streams/KeyValueTest.java b/streams/src/test/java/org/apache/kafka/streams/KeyValueTest.java
new file mode 100644
index 0000000..24f7d5d
--- /dev/null
+++ b/streams/src/test/java/org/apache/kafka/streams/KeyValueTest.java
@@ -0,0 +1,72 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.streams;
+
+import org.junit.Test;
+
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
+
+public class KeyValueTest {
+
+    @Test
+    public void shouldHaveSameEqualsAndHashCode() {
+        final KeyValue<String, Long> kv = KeyValue.pair("key1", 1L);
+        final KeyValue<String, Long> copyOfKV = KeyValue.pair(kv.key, kv.value);
+
+        // Reflexive
+        assertTrue(kv.equals(kv));
+        assertTrue(kv.hashCode() == kv.hashCode());
+
+        // Symmetric
+        assertTrue(kv.equals(copyOfKV));
+        assertTrue(kv.hashCode() == copyOfKV.hashCode());
+        assertTrue(copyOfKV.hashCode() == kv.hashCode());
+
+        // Transitive
+        final KeyValue<String, Long> copyOfCopyOfKV = KeyValue.pair(copyOfKV.key, copyOfKV.value);
+        assertTrue(copyOfKV.equals(copyOfCopyOfKV));
+        assertTrue(copyOfKV.hashCode() == copyOfCopyOfKV.hashCode());
+        assertTrue(kv.equals(copyOfCopyOfKV));
+        assertTrue(kv.hashCode() == copyOfCopyOfKV.hashCode());
+
+        // Inequality scenarios
+        assertFalse("must be false for null", kv.equals(null));
+        assertFalse("must be false if key is non-null and other key is null", kv.equals(KeyValue.pair(null, kv.value)));
+        assertFalse("must be false if value is non-null and other value is null", kv.equals(KeyValue.pair(kv.key, null)));
+        final KeyValue<Long, Long> differentKeyType = KeyValue.pair(1L, kv.value);
+        assertFalse("must be false for different key types", kv.equals(differentKeyType));
+        final KeyValue<String, String> differentValueType = KeyValue.pair(kv.key, "anyString");
+        assertFalse("must be false for different value types", kv.equals(differentValueType));
+        final KeyValue<Long, String> differentKeyValueTypes = KeyValue.pair(1L, "anyString");
+        assertFalse("must be false for different key and value types", kv.equals(differentKeyValueTypes));
+        assertFalse("must be false for different types of objects", kv.equals(new Object()));
+
+        final KeyValue<String, Long> differentKey = KeyValue.pair(kv.key + "suffix", kv.value);
+        assertFalse("must be false if key is different", kv.equals(differentKey));
+        assertFalse("must be false if key is different", differentKey.equals(kv));
+
+        final KeyValue<String, Long> differentValue = KeyValue.pair(kv.key, kv.value + 1L);
+        assertFalse("must be false if value is different", kv.equals(differentValue));
+        assertFalse("must be false if value is different", differentValue.equals(kv));
+
+        final KeyValue<String, Long> differentKeyAndValue = KeyValue.pair(kv.key + "suffix", kv.value + 1L);
+        assertFalse("must be false if key and value are different", kv.equals(differentKeyAndValue));
+        assertFalse("must be false if key and value are different", differentKeyAndValue.equals(kv));
+    }
+
+}
\ No newline at end of file
diff --git a/streams/src/test/java/org/apache/kafka/streams/KeyValueTimestamp.java b/streams/src/test/java/org/apache/kafka/streams/KeyValueTimestamp.java
new file mode 100644
index 0000000..b578562
--- /dev/null
+++ b/streams/src/test/java/org/apache/kafka/streams/KeyValueTimestamp.java
@@ -0,0 +1,63 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.
You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.streams;
+
+import java.util.Objects;
+
+public class KeyValueTimestamp<K, V> {
+    private final K key;
+    private final V value;
+    private final long timestamp;
+
+    public KeyValueTimestamp(final K key, final V value, final long timestamp) {
+        this.key = key;
+        this.value = value;
+        this.timestamp = timestamp;
+    }
+
+    public K key() {
+        return key;
+    }
+
+    public V value() {
+        return value;
+    }
+
+    public long timestamp() {
+        return timestamp;
+    }
+
+    @Override
+    public String toString() {
+        return "KeyValueTimestamp{key=" + key + ", value=" + value + ", timestamp=" + timestamp + '}';
+    }
+
+    @Override
+    public boolean equals(final Object o) {
+        if (this == o) return true;
+        if (o == null || getClass() != o.getClass()) return false;
+        final KeyValueTimestamp<?, ?> that = (KeyValueTimestamp<?, ?>) o;
+        return timestamp == that.timestamp &&
+            Objects.equals(key, that.key) &&
+            Objects.equals(value, that.value);
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(key, value, timestamp);
+    }
+}
\ No newline at end of file
diff --git a/streams/src/test/java/org/apache/kafka/streams/StreamsBuilderTest.java b/streams/src/test/java/org/apache/kafka/streams/StreamsBuilderTest.java
new file mode 100644
index 0000000..e18b972
--- /dev/null
+++ b/streams/src/test/java/org/apache/kafka/streams/StreamsBuilderTest.java
@@ -0,0 +1,1100 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +package org.apache.kafka.streams; + +import org.apache.kafka.common.serialization.LongSerializer; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.serialization.StringSerializer; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.streams.Topology.AutoOffsetReset; +import org.apache.kafka.streams.errors.TopologyException; +import org.apache.kafka.streams.kstream.Branched; +import org.apache.kafka.streams.kstream.Consumed; +import org.apache.kafka.streams.kstream.ForeachAction; +import org.apache.kafka.streams.kstream.Grouped; +import org.apache.kafka.streams.kstream.JoinWindows; +import org.apache.kafka.streams.kstream.Joined; +import org.apache.kafka.streams.kstream.KStream; +import org.apache.kafka.streams.kstream.KTable; +import org.apache.kafka.streams.kstream.Materialized; +import org.apache.kafka.streams.kstream.Named; +import org.apache.kafka.streams.kstream.Printed; +import org.apache.kafka.streams.kstream.Produced; +import org.apache.kafka.streams.kstream.StreamJoined; +import org.apache.kafka.streams.processor.StateStore; +import org.apache.kafka.streams.processor.api.Processor; +import org.apache.kafka.streams.processor.api.ProcessorContext; +import org.apache.kafka.streams.processor.api.Record; +import org.apache.kafka.streams.processor.internals.InternalTopologyBuilder; +import org.apache.kafka.streams.processor.internals.ProcessorNode; +import org.apache.kafka.streams.processor.internals.ProcessorTopology; +import org.apache.kafka.streams.state.KeyValueStore; +import org.apache.kafka.streams.state.Stores; +import org.apache.kafka.test.MockMapper; +import org.apache.kafka.test.MockPredicate; +import org.apache.kafka.test.MockProcessorSupplier; +import org.apache.kafka.test.MockValueJoiner; +import org.apache.kafka.test.NoopValueTransformer; +import org.apache.kafka.test.NoopValueTransformerWithKey; +import org.apache.kafka.test.StreamsTestUtils; +import org.junit.Test; + +import java.time.Duration; +import java.time.Instant; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Properties; +import java.util.regex.Pattern; + +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.SUBTOPOLOGY_0; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.SUBTOPOLOGY_1; + +import static java.util.Arrays.asList; +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.core.Is.is; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +@SuppressWarnings("deprecation") // Old PAPI. Needs to be migrated. 
+public class StreamsBuilderTest { + + private static final String STREAM_TOPIC = "stream-topic"; + + private static final String STREAM_OPERATION_NAME = "stream-operation"; + + private static final String STREAM_TOPIC_TWO = "stream-topic-two"; + + private static final String TABLE_TOPIC = "table-topic"; + + private final StreamsBuilder builder = new StreamsBuilder(); + + private final Properties props = StreamsTestUtils.getStreamsConfig(Serdes.String(), Serdes.String()); + + @Test + public void shouldAddGlobalStore() { + final StreamsBuilder builder = new StreamsBuilder(); + builder.addGlobalStore( + Stores.keyValueStoreBuilder( + Stores.inMemoryKeyValueStore("store"), + Serdes.String(), + Serdes.String() + ), + "topic", + Consumed.with(Serdes.String(), Serdes.String()), + () -> new Processor() { + private KeyValueStore store; + + @Override + public void init(final ProcessorContext context) { + store = context.getStateStore("store"); + } + + @Override + public void process(final Record record) { + store.put(record.key(), record.value()); + } + } + ); + try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build())) { + final TestInputTopic inputTopic = + driver.createInputTopic("topic", new StringSerializer(), new StringSerializer()); + inputTopic.pipeInput("hey", "there"); + final KeyValueStore store = driver.getKeyValueStore("store"); + final String hey = store.get("hey"); + assertThat(hey, is("there")); + } + } + + @Test + public void shouldNotThrowNullPointerIfOptimizationsNotSpecified() { + final Properties properties = new Properties(); + + final StreamsBuilder builder = new StreamsBuilder(); + builder.build(properties); + } + + @Test + public void shouldAllowJoinUnmaterializedFilteredKTable() { + final KTable filteredKTable = builder + .table(TABLE_TOPIC) + .filter(MockPredicate.allGoodPredicate()); + builder + .stream(STREAM_TOPIC) + .join(filteredKTable, MockValueJoiner.TOSTRING_JOINER); + builder.build(); + + final ProcessorTopology topology = + builder.internalTopologyBuilder.rewriteTopology(new StreamsConfig(props)).buildTopology(); + + assertThat( + topology.stateStores().size(), + equalTo(1)); + assertThat( + topology.processorConnectedStateStores("KSTREAM-JOIN-0000000005"), + equalTo(Collections.singleton(topology.stateStores().get(0).name()))); + assertTrue( + topology.processorConnectedStateStores("KTABLE-FILTER-0000000003").isEmpty()); + } + + @Test + public void shouldAllowJoinMaterializedFilteredKTable() { + final KTable filteredKTable = builder + .table(TABLE_TOPIC) + .filter(MockPredicate.allGoodPredicate(), Materialized.as("store")); + builder + .stream(STREAM_TOPIC) + .join(filteredKTable, MockValueJoiner.TOSTRING_JOINER); + builder.build(); + + final ProcessorTopology topology = + builder.internalTopologyBuilder.rewriteTopology(new StreamsConfig(props)).buildTopology(); + + assertThat( + topology.stateStores().size(), + equalTo(1)); + assertThat( + topology.processorConnectedStateStores("KSTREAM-JOIN-0000000005"), + equalTo(Collections.singleton("store"))); + assertThat( + topology.processorConnectedStateStores("KTABLE-FILTER-0000000003"), + equalTo(Collections.singleton("store"))); + } + + @Test + public void shouldAllowJoinUnmaterializedMapValuedKTable() { + final KTable mappedKTable = builder + .table(TABLE_TOPIC) + .mapValues(MockMapper.noOpValueMapper()); + builder + .stream(STREAM_TOPIC) + .join(mappedKTable, MockValueJoiner.TOSTRING_JOINER); + builder.build(); + + final ProcessorTopology topology = + 
builder.internalTopologyBuilder.rewriteTopology(new StreamsConfig(props)).buildTopology(); + + assertThat( + topology.stateStores().size(), + equalTo(1)); + assertThat( + topology.processorConnectedStateStores("KSTREAM-JOIN-0000000005"), + equalTo(Collections.singleton(topology.stateStores().get(0).name()))); + assertTrue( + topology.processorConnectedStateStores("KTABLE-MAPVALUES-0000000003").isEmpty()); + } + + @Test + public void shouldAllowJoinMaterializedMapValuedKTable() { + final KTable mappedKTable = builder + .table(TABLE_TOPIC) + .mapValues(MockMapper.noOpValueMapper(), Materialized.as("store")); + builder + .stream(STREAM_TOPIC) + .join(mappedKTable, MockValueJoiner.TOSTRING_JOINER); + builder.build(); + + final ProcessorTopology topology = + builder.internalTopologyBuilder.rewriteTopology(new StreamsConfig(props)).buildTopology(); + + assertThat( + topology.stateStores().size(), + equalTo(1)); + assertThat( + topology.processorConnectedStateStores("KSTREAM-JOIN-0000000005"), + equalTo(Collections.singleton("store"))); + assertThat( + topology.processorConnectedStateStores("KTABLE-MAPVALUES-0000000003"), + equalTo(Collections.singleton("store"))); + } + + @Test + public void shouldAllowJoinUnmaterializedJoinedKTable() { + final KTable table1 = builder.table("table-topic1"); + final KTable table2 = builder.table("table-topic2"); + builder + .stream(STREAM_TOPIC) + .join(table1.join(table2, MockValueJoiner.TOSTRING_JOINER), MockValueJoiner.TOSTRING_JOINER); + builder.build(); + + final ProcessorTopology topology = + builder.internalTopologyBuilder.rewriteTopology(new StreamsConfig(props)).buildTopology(); + + assertThat( + topology.stateStores().size(), + equalTo(2)); + assertThat( + topology.processorConnectedStateStores("KSTREAM-JOIN-0000000010"), + equalTo(Utils.mkSet(topology.stateStores().get(0).name(), topology.stateStores().get(1).name()))); + assertTrue( + topology.processorConnectedStateStores("KTABLE-MERGE-0000000007").isEmpty()); + } + + @Test + public void shouldAllowJoinMaterializedJoinedKTable() { + final KTable table1 = builder.table("table-topic1"); + final KTable table2 = builder.table("table-topic2"); + builder + .stream(STREAM_TOPIC) + .join( + table1.join(table2, MockValueJoiner.TOSTRING_JOINER, Materialized.as("store")), + MockValueJoiner.TOSTRING_JOINER); + builder.build(); + + final ProcessorTopology topology = + builder.internalTopologyBuilder.rewriteTopology(new StreamsConfig(props)).buildTopology(); + + assertThat( + topology.stateStores().size(), + equalTo(3)); + assertThat( + topology.processorConnectedStateStores("KSTREAM-JOIN-0000000010"), + equalTo(Collections.singleton("store"))); + assertThat( + topology.processorConnectedStateStores("KTABLE-MERGE-0000000007"), + equalTo(Collections.singleton("store"))); + } + + @Test + public void shouldAllowJoinMaterializedSourceKTable() { + final KTable table = builder.table(TABLE_TOPIC); + builder.stream(STREAM_TOPIC).join(table, MockValueJoiner.TOSTRING_JOINER); + builder.build(); + + final ProcessorTopology topology = + builder.internalTopologyBuilder.rewriteTopology(new StreamsConfig(props)).buildTopology(); + + assertThat( + topology.stateStores().size(), + equalTo(1)); + assertThat( + topology.processorConnectedStateStores("KTABLE-SOURCE-0000000002"), + equalTo(Collections.singleton(topology.stateStores().get(0).name()))); + assertThat( + topology.processorConnectedStateStores("KSTREAM-JOIN-0000000004"), + equalTo(Collections.singleton(topology.stateStores().get(0).name()))); + } + + @Test + public void 
shouldProcessingFromSinkTopic() { + final KStream source = builder.stream("topic-source"); + source.to("topic-sink"); + + final MockProcessorSupplier processorSupplier = new MockProcessorSupplier<>(); + source.process(processorSupplier); + + try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) { + final TestInputTopic inputTopic = + driver.createInputTopic("topic-source", new StringSerializer(), new StringSerializer(), Instant.ofEpochMilli(0L), Duration.ZERO); + inputTopic.pipeInput("A", "aa"); + } + + // no exception was thrown + assertEquals(Collections.singletonList(new KeyValueTimestamp<>("A", "aa", 0)), + processorSupplier.theCapturedProcessor().processed()); + } + + @Deprecated + @Test + public void shouldProcessViaThroughTopic() { + final KStream source = builder.stream("topic-source"); + final KStream through = source.through("topic-sink"); + + final MockProcessorSupplier sourceProcessorSupplier = new MockProcessorSupplier<>(); + source.process(sourceProcessorSupplier); + + final MockProcessorSupplier throughProcessorSupplier = new MockProcessorSupplier<>(); + through.process(throughProcessorSupplier); + + try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) { + final TestInputTopic inputTopic = + driver.createInputTopic("topic-source", new StringSerializer(), new StringSerializer(), Instant.ofEpochMilli(0L), Duration.ZERO); + inputTopic.pipeInput("A", "aa"); + } + + assertEquals(Collections.singletonList(new KeyValueTimestamp<>("A", "aa", 0)), sourceProcessorSupplier.theCapturedProcessor().processed()); + assertEquals(Collections.singletonList(new KeyValueTimestamp<>("A", "aa", 0)), throughProcessorSupplier.theCapturedProcessor().processed()); + } + + @Test + public void shouldProcessViaRepartitionTopic() { + final KStream source = builder.stream("topic-source"); + final KStream through = source.repartition(); + + final MockProcessorSupplier sourceProcessorSupplier = new MockProcessorSupplier<>(); + source.process(sourceProcessorSupplier); + + final MockProcessorSupplier throughProcessorSupplier = new MockProcessorSupplier<>(); + through.process(throughProcessorSupplier); + + try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) { + final TestInputTopic inputTopic = + driver.createInputTopic("topic-source", new StringSerializer(), new StringSerializer(), Instant.ofEpochMilli(0L), Duration.ZERO); + inputTopic.pipeInput("A", "aa"); + } + + assertEquals(Collections.singletonList(new KeyValueTimestamp<>("A", "aa", 0)), sourceProcessorSupplier.theCapturedProcessor().processed()); + assertEquals(Collections.singletonList(new KeyValueTimestamp<>("A", "aa", 0)), throughProcessorSupplier.theCapturedProcessor().processed()); + } + + @Test + public void shouldMergeStreams() { + final String topic1 = "topic-1"; + final String topic2 = "topic-2"; + + final KStream source1 = builder.stream(topic1); + final KStream source2 = builder.stream(topic2); + final KStream merged = source1.merge(source2); + + final MockProcessorSupplier processorSupplier = new MockProcessorSupplier<>(); + merged.process(processorSupplier); + + try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) { + final TestInputTopic inputTopic1 = + driver.createInputTopic(topic1, new StringSerializer(), new StringSerializer(), Instant.ofEpochMilli(0L), Duration.ZERO); + final TestInputTopic inputTopic2 = + driver.createInputTopic(topic2, new StringSerializer(), new StringSerializer(), 
Instant.ofEpochMilli(0L), Duration.ZERO); + + inputTopic1.pipeInput("A", "aa"); + inputTopic2.pipeInput("B", "bb"); + inputTopic2.pipeInput("C", "cc"); + inputTopic1.pipeInput("D", "dd"); + } + + assertEquals(asList(new KeyValueTimestamp<>("A", "aa", 0), + new KeyValueTimestamp<>("B", "bb", 0), + new KeyValueTimestamp<>("C", "cc", 0), + new KeyValueTimestamp<>("D", "dd", 0)), processorSupplier.theCapturedProcessor().processed()); + } + + @Test + public void shouldUseSerdesDefinedInMaterializedToConsumeTable() { + final Map results = new HashMap<>(); + final String topic = "topic"; + final ForeachAction action = results::put; + builder.table(topic, Materialized.>as("store") + .withKeySerde(Serdes.Long()) + .withValueSerde(Serdes.String())) + .toStream().foreach(action); + + try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) { + final TestInputTopic inputTopic = + driver.createInputTopic(topic, new LongSerializer(), new StringSerializer()); + inputTopic.pipeInput(1L, "value1"); + inputTopic.pipeInput(2L, "value2"); + + final KeyValueStore store = driver.getKeyValueStore("store"); + assertThat(store.get(1L), equalTo("value1")); + assertThat(store.get(2L), equalTo("value2")); + assertThat(results.get(1L), equalTo("value1")); + assertThat(results.get(2L), equalTo("value2")); + } + } + + @Test + public void shouldUseSerdesDefinedInMaterializedToConsumeGlobalTable() { + final String topic = "topic"; + builder.globalTable(topic, Materialized.>as("store") + .withKeySerde(Serdes.Long()) + .withValueSerde(Serdes.String())); + + try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) { + final TestInputTopic inputTopic = + driver.createInputTopic(topic, new LongSerializer(), new StringSerializer()); + inputTopic.pipeInput(1L, "value1"); + inputTopic.pipeInput(2L, "value2"); + final KeyValueStore store = driver.getKeyValueStore("store"); + + assertThat(store.get(1L), equalTo("value1")); + assertThat(store.get(2L), equalTo("value2")); + } + } + + @Test + public void shouldNotMaterializeStoresIfNotRequired() { + final String topic = "topic"; + builder.table(topic, Materialized.with(Serdes.Long(), Serdes.String())); + + final ProcessorTopology topology = + builder.internalTopologyBuilder.rewriteTopology(new StreamsConfig(props)).buildTopology(); + + assertThat(topology.stateStores().size(), equalTo(0)); + } + + @Test + public void shouldReuseSourceTopicAsChangelogsWithOptimization20() { + final String topic = "topic"; + builder.table(topic, Materialized.>as("store")); + final Properties props = StreamsTestUtils.getStreamsConfig(); + props.put(StreamsConfig.TOPOLOGY_OPTIMIZATION_CONFIG, StreamsConfig.OPTIMIZE); + final Topology topology = builder.build(props); + + final InternalTopologyBuilder internalTopologyBuilder = TopologyWrapper.getInternalTopologyBuilder(topology); + internalTopologyBuilder.rewriteTopology(new StreamsConfig(props)); + + assertThat( + internalTopologyBuilder.buildTopology().storeToChangelogTopic(), + equalTo(Collections.singletonMap("store", "topic"))); + assertThat( + internalTopologyBuilder.stateStores().keySet(), + equalTo(Collections.singleton("store"))); + assertThat( + internalTopologyBuilder.stateStores().get("store").loggingEnabled(), + equalTo(false)); + assertThat( + internalTopologyBuilder.topicGroups().get(SUBTOPOLOGY_0).nonSourceChangelogTopics().isEmpty(), + equalTo(true)); + } + + @Test + public void shouldNotReuseRepartitionTopicAsChangelogs() { + final String topic = "topic"; + 
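+        // Unlike a source KTable topic, an internal repartition topic must not double as the
+        // store changelog: even with TOPOLOGY_OPTIMIZATION_CONFIG set to OPTIMIZE below, the
+        // store is expected to log to an explicit "appId-store-changelog" topic.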
builder.stream(topic).repartition().toTable(Materialized.as("store")); + final Properties props = StreamsTestUtils.getStreamsConfig("appId"); + props.put(StreamsConfig.TOPOLOGY_OPTIMIZATION_CONFIG, StreamsConfig.OPTIMIZE); + final Topology topology = builder.build(props); + + final InternalTopologyBuilder internalTopologyBuilder = TopologyWrapper.getInternalTopologyBuilder(topology); + internalTopologyBuilder.rewriteTopology(new StreamsConfig(props)); + + assertThat( + internalTopologyBuilder.buildTopology().storeToChangelogTopic(), + equalTo(Collections.singletonMap("store", "appId-store-changelog")) + ); + assertThat( + internalTopologyBuilder.stateStores().keySet(), + equalTo(Collections.singleton("store")) + ); + assertThat( + internalTopologyBuilder.stateStores().get("store").loggingEnabled(), + equalTo(true) + ); + assertThat( + internalTopologyBuilder.topicGroups().get(SUBTOPOLOGY_1).stateChangelogTopics.keySet(), + equalTo(Collections.singleton("appId-store-changelog")) + ); + } + + @Test + public void shouldNotReuseSourceTopicAsChangelogsByDefault() { + final String topic = "topic"; + builder.table(topic, Materialized.>as("store")); + + final InternalTopologyBuilder internalTopologyBuilder = TopologyWrapper.getInternalTopologyBuilder(builder.build()); + internalTopologyBuilder.setApplicationId("appId"); + + assertThat( + internalTopologyBuilder.buildTopology().storeToChangelogTopic(), + equalTo(Collections.singletonMap("store", "appId-store-changelog"))); + assertThat( + internalTopologyBuilder.stateStores().keySet(), + equalTo(Collections.singleton("store"))); + assertThat( + internalTopologyBuilder.stateStores().get("store").loggingEnabled(), + equalTo(true)); + assertThat( + internalTopologyBuilder.topicGroups().get(SUBTOPOLOGY_0).stateChangelogTopics.keySet(), + equalTo(Collections.singleton("appId-store-changelog"))); + } + + @Test + public void shouldThrowExceptionWhenNoTopicPresent() { + builder.stream(Collections.emptyList()); + assertThrows(TopologyException.class, builder::build); + } + + @Test + public void shouldThrowExceptionWhenTopicNamesAreNull() { + builder.stream(Arrays.asList(null, null)); + assertThrows(NullPointerException.class, builder::build); + } + + @Test + public void shouldUseSpecifiedNameForStreamSourceProcessor() { + final String expected = "source-node"; + builder.stream(STREAM_TOPIC, Consumed.as(expected)); + builder.stream(STREAM_TOPIC_TWO); + builder.build(); + final ProcessorTopology topology = builder.internalTopologyBuilder.rewriteTopology(new StreamsConfig(props)).buildTopology(); + assertNamesForOperation(topology, expected, "KSTREAM-SOURCE-0000000001"); + } + + @Test + public void shouldUseSpecifiedNameForTableSourceProcessor() { + final String expected = "source-node"; + builder.table(STREAM_TOPIC, Consumed.as(expected)); + builder.table(STREAM_TOPIC_TWO); + builder.build(); + + final ProcessorTopology topology = builder.internalTopologyBuilder.rewriteTopology(new StreamsConfig(props)).buildTopology(); + + assertNamesForOperation( + topology, + expected + "-source", + expected, + "KSTREAM-SOURCE-0000000004", + "KTABLE-SOURCE-0000000005"); + } + + @Test + public void shouldUseSpecifiedNameForGlobalTableSourceProcessor() { + final String expected = "source-processor"; + builder.globalTable(STREAM_TOPIC, Consumed.as(expected)); + builder.globalTable(STREAM_TOPIC_TWO); + builder.build(); + final ProcessorTopology topology = builder.internalTopologyBuilder.rewriteTopology(new StreamsConfig(props)).buildTopology(); + + assertNamesForStateStore( + 
topology.globalStateStores(), + "stream-topic-STATE-STORE-0000000000", + "stream-topic-two-STATE-STORE-0000000003" + ); + } + + @Test + public void shouldUseSpecifiedNameForSinkProcessor() { + final String expected = "sink-processor"; + final KStream stream = builder.stream(STREAM_TOPIC); + stream.to(STREAM_TOPIC_TWO, Produced.as(expected)); + stream.to(STREAM_TOPIC_TWO); + builder.build(); + final ProcessorTopology topology = builder.internalTopologyBuilder.rewriteTopology(new StreamsConfig(props)).buildTopology(); + assertNamesForOperation(topology, "KSTREAM-SOURCE-0000000000", expected, "KSTREAM-SINK-0000000002"); + } + + @Test + public void shouldUseSpecifiedNameForMapOperation() { + builder.stream(STREAM_TOPIC).map(KeyValue::pair, Named.as(STREAM_OPERATION_NAME)); + builder.build(); + final ProcessorTopology topology = builder.internalTopologyBuilder.rewriteTopology(new StreamsConfig(props)).buildTopology(); + assertNamesForOperation(topology, "KSTREAM-SOURCE-0000000000", STREAM_OPERATION_NAME); + } + + @Test + public void shouldUseSpecifiedNameForMapValuesOperation() { + builder.stream(STREAM_TOPIC).mapValues(v -> v, Named.as(STREAM_OPERATION_NAME)); + builder.build(); + final ProcessorTopology topology = builder.internalTopologyBuilder.rewriteTopology(new StreamsConfig(props)).buildTopology(); + assertNamesForOperation(topology, "KSTREAM-SOURCE-0000000000", STREAM_OPERATION_NAME); + } + + @Test + public void shouldUseSpecifiedNameForMapValuesWithKeyOperation() { + builder.stream(STREAM_TOPIC).mapValues((k, v) -> v, Named.as(STREAM_OPERATION_NAME)); + builder.build(); + final ProcessorTopology topology = builder.internalTopologyBuilder.rewriteTopology(new StreamsConfig(props)).buildTopology(); + assertNamesForOperation(topology, "KSTREAM-SOURCE-0000000000", STREAM_OPERATION_NAME); + } + + @Test + public void shouldUseSpecifiedNameForFilterOperation() { + builder.stream(STREAM_TOPIC).filter((k, v) -> true, Named.as(STREAM_OPERATION_NAME)); + builder.build(); + final ProcessorTopology topology = builder.internalTopologyBuilder.rewriteTopology(new StreamsConfig(props)).buildTopology(); + assertNamesForOperation(topology, "KSTREAM-SOURCE-0000000000", STREAM_OPERATION_NAME); + } + + @Test + public void shouldUseSpecifiedNameForForEachOperation() { + builder.stream(STREAM_TOPIC).foreach((k, v) -> { }, Named.as(STREAM_OPERATION_NAME)); + builder.build(); + final ProcessorTopology topology = builder.internalTopologyBuilder.rewriteTopology(new StreamsConfig(props)).buildTopology(); + assertNamesForOperation(topology, "KSTREAM-SOURCE-0000000000", STREAM_OPERATION_NAME); + } + + @Test + public void shouldUseSpecifiedNameForTransform() { + builder.stream(STREAM_TOPIC).transform(() -> null, Named.as(STREAM_OPERATION_NAME)); + builder.build(); + final ProcessorTopology topology = builder.internalTopologyBuilder.rewriteTopology(new StreamsConfig(props)).buildTopology(); + assertNamesForOperation(topology, "KSTREAM-SOURCE-0000000000", STREAM_OPERATION_NAME); + } + + @Test + public void shouldUseSpecifiedNameForTransformValues() { + builder.stream(STREAM_TOPIC).transformValues(() -> new NoopValueTransformer<>(), Named.as(STREAM_OPERATION_NAME)); + builder.build(); + final ProcessorTopology topology = builder.internalTopologyBuilder.rewriteTopology(new StreamsConfig(props)).buildTopology(); + assertNamesForOperation(topology, "KSTREAM-SOURCE-0000000000", STREAM_OPERATION_NAME); + } + + @Test + public void shouldUseSpecifiedNameForTransformValuesWithKey() { + 
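+        // The Named parameter should replace the generated KSTREAM-TRANSFORMVALUES-... processor
+        // name, so only the source keeps its auto-generated name in the assertion below.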
builder.stream(STREAM_TOPIC).transformValues(() -> new NoopValueTransformerWithKey<>(), Named.as(STREAM_OPERATION_NAME)); + builder.build(); + final ProcessorTopology topology = builder.internalTopologyBuilder.rewriteTopology(new StreamsConfig(props)).buildTopology(); + assertNamesForOperation(topology, "KSTREAM-SOURCE-0000000000", STREAM_OPERATION_NAME); + } + + @Test + @SuppressWarnings({"unchecked", "deprecation"}) + public void shouldUseSpecifiedNameForBranchOperation() { + builder.stream(STREAM_TOPIC) + .branch(Named.as("branch-processor"), (k, v) -> true, (k, v) -> false); + + builder.build(); + final ProcessorTopology topology = builder.internalTopologyBuilder.rewriteTopology(new StreamsConfig(props)).buildTopology(); + assertNamesForOperation(topology, + "KSTREAM-SOURCE-0000000000", + "branch-processor", + "branch-processor-predicate-0", + "branch-processor-predicate-1"); + } + + @Test + public void shouldUseSpecifiedNameForSplitOperation() { + builder.stream(STREAM_TOPIC) + .split(Named.as("branch-processor")) + .branch((k, v) -> true, Branched.as("-1")) + .branch((k, v) -> false, Branched.as("-2")); + builder.build(); + final ProcessorTopology topology = builder.internalTopologyBuilder.rewriteTopology(new StreamsConfig(props)).buildTopology(); + assertNamesForOperation(topology, + "KSTREAM-SOURCE-0000000000", + "branch-processor", + "branch-processor-1", + "branch-processor-2"); + } + + @Test + public void shouldUseSpecifiedNameForJoinOperationBetweenKStreamAndKTable() { + final KStream streamOne = builder.stream(STREAM_TOPIC); + final KTable streamTwo = builder.table("table-topic"); + streamOne.join(streamTwo, (value1, value2) -> value1, Joined.as(STREAM_OPERATION_NAME)); + builder.build(); + + final ProcessorTopology topology = builder.internalTopologyBuilder.rewriteTopology(new StreamsConfig(props)).buildTopology(); + assertNamesForOperation(topology, + "KSTREAM-SOURCE-0000000000", + "KSTREAM-SOURCE-0000000002", + "KTABLE-SOURCE-0000000003", + STREAM_OPERATION_NAME); + } + + @Test + public void shouldUseSpecifiedNameForLeftJoinOperationBetweenKStreamAndKTable() { + final KStream streamOne = builder.stream(STREAM_TOPIC); + final KTable streamTwo = builder.table(STREAM_TOPIC_TWO); + streamOne.leftJoin(streamTwo, (value1, value2) -> value1, Joined.as(STREAM_OPERATION_NAME)); + builder.build(); + + final ProcessorTopology topology = builder.internalTopologyBuilder.rewriteTopology(new StreamsConfig(props)).buildTopology(); + assertNamesForOperation(topology, + "KSTREAM-SOURCE-0000000000", + "KSTREAM-SOURCE-0000000002", + "KTABLE-SOURCE-0000000003", + STREAM_OPERATION_NAME); + } + + @Test + public void shouldNotAddThirdStateStoreIfStreamStreamJoinFixIsDisabledViaOldApi() { + final KStream streamOne = builder.stream(STREAM_TOPIC); + final KStream streamTwo = builder.stream(STREAM_TOPIC_TWO); + + streamOne.leftJoin( + streamTwo, + (value1, value2) -> value1, + JoinWindows.of(Duration.ofHours(1)), + StreamJoined.as(STREAM_OPERATION_NAME) + .withName(STREAM_OPERATION_NAME) + ); + + final Properties properties = new Properties(); + builder.build(properties); + + final ProcessorTopology topology = builder.internalTopologyBuilder.rewriteTopology(new StreamsConfig(props)).buildTopology(); + assertNamesForStateStore(topology.stateStores(), + STREAM_OPERATION_NAME + "-this-join-store", + STREAM_OPERATION_NAME + "-outer-other-join-store" + ); + assertNamesForOperation(topology, + "KSTREAM-SOURCE-0000000000", + "KSTREAM-SOURCE-0000000001", + STREAM_OPERATION_NAME + "-this-windowed", + 
STREAM_OPERATION_NAME + "-other-windowed", + STREAM_OPERATION_NAME + "-this-join", + STREAM_OPERATION_NAME + "-outer-other-join", + STREAM_OPERATION_NAME + "-merge"); + } + + @Test + public void shouldUseSpecifiedNameForLeftJoinOperationBetweenKStreamAndKStream() { + final KStream streamOne = builder.stream(STREAM_TOPIC); + final KStream streamTwo = builder.stream(STREAM_TOPIC_TWO); + + streamOne.leftJoin( + streamTwo, + (value1, value2) -> value1, + JoinWindows.ofTimeDifferenceWithNoGrace(Duration.ofHours(1)), + StreamJoined.as(STREAM_OPERATION_NAME) + .withName(STREAM_OPERATION_NAME) + ); + builder.build(); + + final ProcessorTopology topology = builder.internalTopologyBuilder.rewriteTopology(new StreamsConfig(props)).buildTopology(); + assertNamesForStateStore(topology.stateStores(), + STREAM_OPERATION_NAME + "-this-join-store", + STREAM_OPERATION_NAME + "-outer-other-join-store", + STREAM_OPERATION_NAME + "-left-shared-join-store" + ); + assertNamesForOperation(topology, + "KSTREAM-SOURCE-0000000000", + "KSTREAM-SOURCE-0000000001", + STREAM_OPERATION_NAME + "-this-windowed", + STREAM_OPERATION_NAME + "-other-windowed", + STREAM_OPERATION_NAME + "-this-join", + STREAM_OPERATION_NAME + "-outer-other-join", + STREAM_OPERATION_NAME + "-merge"); + } + + @Test + public void shouldUseGeneratedStoreNamesForLeftJoinOperationBetweenKStreamAndKStream() { + final KStream streamOne = builder.stream(STREAM_TOPIC); + final KStream streamTwo = builder.stream(STREAM_TOPIC_TWO); + + streamOne.leftJoin( + streamTwo, + (value1, value2) -> value1, + JoinWindows.ofTimeDifferenceWithNoGrace(Duration.ofHours(1)), + StreamJoined.with(Serdes.String(), Serdes.String(), Serdes.String()) + .withName(STREAM_OPERATION_NAME) + ); + builder.build(); + + final ProcessorTopology topology = builder.internalTopologyBuilder.rewriteTopology(new StreamsConfig(props)).buildTopology(); + assertNamesForStateStore(topology.stateStores(), + "KSTREAM-JOINTHIS-0000000004-store", + "KSTREAM-OUTEROTHER-0000000005-store", + "KSTREAM-OUTERSHARED-0000000004-store" + ); + assertNamesForOperation(topology, + "KSTREAM-SOURCE-0000000000", + "KSTREAM-SOURCE-0000000001", + STREAM_OPERATION_NAME + "-this-windowed", + STREAM_OPERATION_NAME + "-other-windowed", + STREAM_OPERATION_NAME + "-this-join", + STREAM_OPERATION_NAME + "-outer-other-join", + STREAM_OPERATION_NAME + "-merge"); + } + + @Test + public void shouldUseSpecifiedNameForJoinOperationBetweenKStreamAndKStream() { + final KStream streamOne = builder.stream(STREAM_TOPIC); + final KStream streamTwo = builder.stream(STREAM_TOPIC_TWO); + + streamOne.join(streamTwo, (value1, value2) -> value1, JoinWindows.of(Duration.ofHours(1)), StreamJoined.as(STREAM_OPERATION_NAME).withName(STREAM_OPERATION_NAME)); + builder.build(); + + final ProcessorTopology topology = builder.internalTopologyBuilder.rewriteTopology(new StreamsConfig(props)).buildTopology(); + assertNamesForStateStore(topology.stateStores(), + STREAM_OPERATION_NAME + "-this-join-store", + STREAM_OPERATION_NAME + "-other-join-store" + ); + assertNamesForOperation(topology, + "KSTREAM-SOURCE-0000000000", + "KSTREAM-SOURCE-0000000001", + STREAM_OPERATION_NAME + "-this-windowed", + STREAM_OPERATION_NAME + "-other-windowed", + STREAM_OPERATION_NAME + "-this-join", + STREAM_OPERATION_NAME + "-other-join", + STREAM_OPERATION_NAME + "-merge"); + } + + @Test + public void shouldUseGeneratedNameForJoinOperationBetweenKStreamAndKStream() { + final KStream streamOne = builder.stream(STREAM_TOPIC); + final KStream streamTwo = 
builder.stream(STREAM_TOPIC_TWO); + + streamOne.join(streamTwo, (value1, value2) -> value1, JoinWindows.of(Duration.ofHours(1)), StreamJoined.with(Serdes.String(), Serdes.String(), Serdes.String()).withName(STREAM_OPERATION_NAME)); + builder.build(); + + final ProcessorTopology topology = builder.internalTopologyBuilder.rewriteTopology(new StreamsConfig(props)).buildTopology(); + assertNamesForStateStore(topology.stateStores(), + "KSTREAM-JOINTHIS-0000000004-store", + "KSTREAM-JOINOTHER-0000000005-store" + ); + assertNamesForOperation(topology, + "KSTREAM-SOURCE-0000000000", + "KSTREAM-SOURCE-0000000001", + STREAM_OPERATION_NAME + "-this-windowed", + STREAM_OPERATION_NAME + "-other-windowed", + STREAM_OPERATION_NAME + "-this-join", + STREAM_OPERATION_NAME + "-other-join", + STREAM_OPERATION_NAME + "-merge"); + } + + @Test + public void shouldUseSpecifiedNameForOuterJoinOperationBetweenKStreamAndKStream() { + final KStream streamOne = builder.stream(STREAM_TOPIC); + final KStream streamTwo = builder.stream(STREAM_TOPIC_TWO); + + streamOne.outerJoin( + streamTwo, + (value1, value2) -> value1, + JoinWindows.ofTimeDifferenceWithNoGrace(Duration.ofHours(1)), + StreamJoined.as(STREAM_OPERATION_NAME) + .withName(STREAM_OPERATION_NAME) + ); + builder.build(); + final ProcessorTopology topology = builder.internalTopologyBuilder.rewriteTopology(new StreamsConfig(props)).buildTopology(); + assertNamesForStateStore(topology.stateStores(), + STREAM_OPERATION_NAME + "-outer-this-join-store", + STREAM_OPERATION_NAME + "-outer-other-join-store", + STREAM_OPERATION_NAME + "-outer-shared-join-store"); + assertNamesForOperation(topology, + "KSTREAM-SOURCE-0000000000", + "KSTREAM-SOURCE-0000000001", + STREAM_OPERATION_NAME + "-this-windowed", + STREAM_OPERATION_NAME + "-other-windowed", + STREAM_OPERATION_NAME + "-outer-this-join", + STREAM_OPERATION_NAME + "-outer-other-join", + STREAM_OPERATION_NAME + "-merge"); + + } + + @Test + public void shouldUseGeneratedStoreNamesForOuterJoinOperationBetweenKStreamAndKStream() { + final KStream streamOne = builder.stream(STREAM_TOPIC); + final KStream streamTwo = builder.stream(STREAM_TOPIC_TWO); + + streamOne.outerJoin( + streamTwo, + (value1, value2) -> value1, + JoinWindows.ofTimeDifferenceWithNoGrace(Duration.ofHours(1)), + StreamJoined.with(Serdes.String(), Serdes.String(), Serdes.String()) + .withName(STREAM_OPERATION_NAME) + ); + builder.build(); + + final ProcessorTopology topology = builder.internalTopologyBuilder.rewriteTopology(new StreamsConfig(props)).buildTopology(); + assertNamesForStateStore(topology.stateStores(), + "KSTREAM-OUTERTHIS-0000000004-store", + "KSTREAM-OUTEROTHER-0000000005-store", + "KSTREAM-OUTERSHARED-0000000004-store" + ); + assertNamesForOperation(topology, + "KSTREAM-SOURCE-0000000000", + "KSTREAM-SOURCE-0000000001", + STREAM_OPERATION_NAME + "-this-windowed", + STREAM_OPERATION_NAME + "-other-windowed", + STREAM_OPERATION_NAME + "-outer-this-join", + STREAM_OPERATION_NAME + "-outer-other-join", + STREAM_OPERATION_NAME + "-merge"); + } + + + @Test + public void shouldUseSpecifiedNameForMergeOperation() { + final String topic1 = "topic-1"; + final String topic2 = "topic-2"; + + final KStream source1 = builder.stream(topic1); + final KStream source2 = builder.stream(topic2); + source1.merge(source2, Named.as("merge-processor")); + builder.build(); + final ProcessorTopology topology = builder.internalTopologyBuilder.rewriteTopology(new StreamsConfig(props)).buildTopology(); + assertNamesForOperation(topology, 
"KSTREAM-SOURCE-0000000000", "KSTREAM-SOURCE-0000000001", "merge-processor"); + } + + @Test + public void shouldUseSpecifiedNameForProcessOperation() { + builder.stream(STREAM_TOPIC) + .process(new MockProcessorSupplier<>(), Named.as("test-processor")); + + builder.build(); + final ProcessorTopology topology = builder.internalTopologyBuilder.rewriteTopology(new StreamsConfig(props)).buildTopology(); + assertNamesForOperation(topology, "KSTREAM-SOURCE-0000000000", "test-processor"); + } + + @Test + public void shouldUseSpecifiedNameForPrintOperation() { + builder.stream(STREAM_TOPIC).print(Printed.toSysOut().withName("print-processor")); + builder.build(); + final ProcessorTopology topology = builder.internalTopologyBuilder.rewriteTopology(new StreamsConfig(props)).buildTopology(); + assertNamesForOperation(topology, "KSTREAM-SOURCE-0000000000", "print-processor"); + } + + @Test + public void shouldUseSpecifiedNameForFlatTransformValueOperation() { + builder.stream(STREAM_TOPIC).flatTransformValues(() -> new NoopValueTransformer<>(), Named.as(STREAM_OPERATION_NAME)); + builder.build(); + final ProcessorTopology topology = builder.internalTopologyBuilder.rewriteTopology(new StreamsConfig(props)).buildTopology(); + assertNamesForOperation(topology, "KSTREAM-SOURCE-0000000000", STREAM_OPERATION_NAME); + } + + @Test + @SuppressWarnings({"unchecked", "rawtypes"}) + public void shouldUseSpecifiedNameForFlatTransformValueWithKeyOperation() { + builder.stream(STREAM_TOPIC).flatTransformValues(() -> new NoopValueTransformerWithKey(), Named.as(STREAM_OPERATION_NAME)); + builder.build(); + final ProcessorTopology topology = builder.internalTopologyBuilder.rewriteTopology(new StreamsConfig(props)).buildTopology(); + assertNamesForOperation(topology, "KSTREAM-SOURCE-0000000000", STREAM_OPERATION_NAME); + } + + @Test + public void shouldUseSpecifiedNameForToStream() { + builder.table(STREAM_TOPIC) + .toStream(Named.as("to-stream")); + + builder.build(); + final ProcessorTopology topology = builder.internalTopologyBuilder.rewriteTopology(new StreamsConfig(props)).buildTopology(); + assertNamesForOperation(topology, + "KSTREAM-SOURCE-0000000001", + "KTABLE-SOURCE-0000000002", + "to-stream"); + } + + @Test + public void shouldUseSpecifiedNameForToStreamWithMapper() { + builder.table(STREAM_TOPIC) + .toStream(KeyValue::pair, Named.as("to-stream")); + + builder.build(); + final ProcessorTopology topology = builder.internalTopologyBuilder.rewriteTopology(new StreamsConfig(props)).buildTopology(); + assertNamesForOperation(topology, + "KSTREAM-SOURCE-0000000001", + "KTABLE-SOURCE-0000000002", + "to-stream", + "KSTREAM-KEY-SELECT-0000000004"); + } + + @Test + public void shouldUseSpecifiedNameForAggregateOperationGivenTable() { + builder.table(STREAM_TOPIC).groupBy(KeyValue::pair, Grouped.as("group-operation")).count(Named.as(STREAM_OPERATION_NAME)); + builder.build(); + final ProcessorTopology topology = builder.internalTopologyBuilder.rewriteTopology(new StreamsConfig(props)).buildTopology(); + assertNamesForStateStore( + topology.stateStores(), + STREAM_TOPIC + "-STATE-STORE-0000000000", + "KTABLE-AGGREGATE-STATE-STORE-0000000004"); + + assertNamesForOperation( + topology, + "KSTREAM-SOURCE-0000000001", + "KTABLE-SOURCE-0000000002", + "group-operation", + STREAM_OPERATION_NAME + "-sink", + STREAM_OPERATION_NAME + "-source", + STREAM_OPERATION_NAME); + } + + @Test + public void shouldAllowStreamsFromSameTopic() { + builder.stream("topic"); + builder.stream("topic"); + assertBuildDoesNotThrow(builder); + } + + 
@Test + public void shouldAllowSubscribingToSamePattern() { + builder.stream(Pattern.compile("some-regex")); + builder.stream(Pattern.compile("some-regex")); + assertBuildDoesNotThrow(builder); + } + + @Test + public void shouldAllowReadingFromSameCollectionOfTopics() { + builder.stream(asList("topic1", "topic2")); + builder.stream(asList("topic2", "topic1")); + assertBuildDoesNotThrow(builder); + } + + @Test + public void shouldNotAllowReadingFromOverlappingAndUnequalCollectionOfTopics() { + builder.stream(Collections.singletonList("topic")); + builder.stream(asList("topic", "anotherTopic")); + assertThrows(TopologyException.class, builder::build); + } + + @Test + public void shouldThrowWhenSubscribedToATopicWithDifferentResetPolicies() { + builder.stream("topic", Consumed.with(AutoOffsetReset.EARLIEST)); + builder.stream("topic", Consumed.with(AutoOffsetReset.LATEST)); + assertThrows(TopologyException.class, builder::build); + } + + @Test + public void shouldThrowWhenSubscribedToATopicWithSetAndUnsetResetPolicies() { + builder.stream("topic", Consumed.with(AutoOffsetReset.EARLIEST)); + builder.stream("topic"); + assertThrows(TopologyException.class, builder::build); + } + + @Test + public void shouldThrowWhenSubscribedToATopicWithUnsetAndSetResetPolicies() { + builder.stream("another-topic"); + builder.stream("another-topic", Consumed.with(AutoOffsetReset.LATEST)); + assertThrows(TopologyException.class, builder::build); + } + + @Test + public void shouldThrowWhenSubscribedToAPatternWithDifferentResetPolicies() { + builder.stream(Pattern.compile("some-regex"), Consumed.with(AutoOffsetReset.EARLIEST)); + builder.stream(Pattern.compile("some-regex"), Consumed.with(AutoOffsetReset.LATEST)); + assertThrows(TopologyException.class, builder::build); + } + + @Test + public void shouldThrowWhenSubscribedToAPatternWithSetAndUnsetResetPolicies() { + builder.stream(Pattern.compile("some-regex"), Consumed.with(AutoOffsetReset.EARLIEST)); + builder.stream(Pattern.compile("some-regex")); + assertThrows(TopologyException.class, builder::build); + } + + @Test + public void shouldNotAllowTablesFromSameTopic() { + builder.table("topic"); + builder.table("topic"); + assertThrows(TopologyException.class, builder::build); + } + + @Test + public void shouldNowAllowStreamAndTableFromSameTopic() { + builder.stream("topic"); + builder.table("topic"); + assertThrows(TopologyException.class, builder::build); + } + + private static void assertBuildDoesNotThrow(final StreamsBuilder builder) { + try { + builder.build(); + } catch (final TopologyException topologyException) { + fail("TopologyException not expected"); + } + } + + private static void assertNamesForOperation(final ProcessorTopology topology, final String... expected) { + final List> processors = topology.processors(); + assertEquals("Invalid number of expected processors", expected.length, processors.size()); + for (int i = 0; i < expected.length; i++) { + assertEquals(expected[i], processors.get(i).name()); + } + } + + private static void assertNamesForStateStore(final List stores, final String... 
expected) { + assertEquals("Invalid number of expected state stores", expected.length, stores.size()); + for (int i = 0; i < expected.length; i++) { + assertEquals(expected[i], stores.get(i).name()); + } + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/StreamsConfigTest.java b/streams/src/test/java/org/apache/kafka/streams/StreamsConfigTest.java new file mode 100644 index 0000000..2e1b0d8 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/StreamsConfigTest.java @@ -0,0 +1,1150 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams; + +import org.apache.kafka.clients.CommonClientConfigs; +import org.apache.kafka.clients.admin.AdminClientConfig; +import org.apache.kafka.clients.consumer.ConsumerConfig; +import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.clients.producer.ProducerConfig; +import org.apache.kafka.common.config.ConfigException; +import org.apache.kafka.common.config.TopicConfig; +import org.apache.kafka.common.serialization.Deserializer; +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.serialization.Serializer; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.streams.errors.StreamsException; +import org.apache.kafka.streams.processor.FailOnInvalidTimestamp; +import org.apache.kafka.streams.processor.TimestampExtractor; +import org.apache.kafka.streams.processor.internals.StreamsPartitionAssignor; +import org.apache.kafka.streams.processor.internals.testutil.LogCaptureAppender; +import org.junit.Before; +import org.junit.Test; + +import java.io.File; +import java.nio.charset.StandardCharsets; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Properties; + +import static org.apache.kafka.common.IsolationLevel.READ_COMMITTED; +import static org.apache.kafka.common.IsolationLevel.READ_UNCOMMITTED; +import static org.apache.kafka.streams.StreamsConfig.AT_LEAST_ONCE; +import static org.apache.kafka.streams.StreamsConfig.EXACTLY_ONCE; +import static org.apache.kafka.streams.StreamsConfig.EXACTLY_ONCE_BETA; +import static org.apache.kafka.streams.StreamsConfig.EXACTLY_ONCE_V2; +import static org.apache.kafka.streams.StreamsConfig.STATE_DIR_CONFIG; +import static org.apache.kafka.streams.StreamsConfig.TOPOLOGY_OPTIMIZATION_CONFIG; +import static org.apache.kafka.streams.StreamsConfig.adminClientPrefix; +import static org.apache.kafka.streams.StreamsConfig.consumerPrefix; +import static org.apache.kafka.streams.StreamsConfig.producerPrefix; +import static org.apache.kafka.test.StreamsTestUtils.getStreamsConfig; +import static 
org.hamcrest.CoreMatchers.containsString; +import static org.hamcrest.CoreMatchers.hasItem; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.CoreMatchers.nullValue; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.core.IsEqual.equalTo; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +public class StreamsConfigTest { + + private final Properties props = new Properties(); + private StreamsConfig streamsConfig; + + private final String groupId = "example-application"; + private final String clientId = "client"; + private final int threadIdx = 1; + + @Before + public void setUp() { + props.put(StreamsConfig.APPLICATION_ID_CONFIG, "streams-config-test"); + props.put(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9092"); + props.put(StreamsConfig.DEFAULT_KEY_SERDE_CLASS_CONFIG, Serdes.String().getClass().getName()); + props.put(StreamsConfig.DEFAULT_VALUE_SERDE_CLASS_CONFIG, Serdes.String().getClass().getName()); + props.put("key.deserializer.encoding", StandardCharsets.UTF_8.name()); + props.put("value.deserializer.encoding", StandardCharsets.UTF_16.name()); + streamsConfig = new StreamsConfig(props); + } + + @Test + public void testIllegalMetricsRecordingLevel() { + props.put(StreamsConfig.METRICS_RECORDING_LEVEL_CONFIG, "illegalConfig"); + assertThrows(ConfigException.class, () -> new StreamsConfig(props)); + } + + @Test + public void testOsDefaultSocketBufferSizes() { + props.put(StreamsConfig.SEND_BUFFER_CONFIG, CommonClientConfigs.RECEIVE_BUFFER_LOWER_BOUND); + props.put(StreamsConfig.RECEIVE_BUFFER_CONFIG, CommonClientConfigs.RECEIVE_BUFFER_LOWER_BOUND); + new StreamsConfig(props); + } + + @Test + public void testInvalidSocketSendBufferSize() { + props.put(StreamsConfig.SEND_BUFFER_CONFIG, -2); + assertThrows(ConfigException.class, () -> new StreamsConfig(props)); + } + + @Test + public void testInvalidSocketReceiveBufferSize() { + props.put(StreamsConfig.RECEIVE_BUFFER_CONFIG, -2); + assertThrows(ConfigException.class, () -> new StreamsConfig(props)); + } + + @Test + public void shouldThrowExceptionIfApplicationIdIsNotSet() { + props.remove(StreamsConfig.APPLICATION_ID_CONFIG); + assertThrows(ConfigException.class, () -> new StreamsConfig(props)); + } + + @Test + public void shouldThrowExceptionIfBootstrapServersIsNotSet() { + props.remove(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG); + assertThrows(ConfigException.class, () -> new StreamsConfig(props)); + } + + @Test + public void testGetProducerConfigs() { + final Map returnedProps = streamsConfig.getProducerConfigs(clientId); + assertThat(returnedProps.get(ProducerConfig.CLIENT_ID_CONFIG), equalTo(clientId)); + assertThat(returnedProps.get(ProducerConfig.LINGER_MS_CONFIG), equalTo("100")); + } + + @Test + public void testGetConsumerConfigs() { + final Map returnedProps = streamsConfig.getMainConsumerConfigs(groupId, clientId, threadIdx); + assertThat(returnedProps.get(ConsumerConfig.CLIENT_ID_CONFIG), equalTo(clientId)); + assertThat(returnedProps.get(ConsumerConfig.GROUP_ID_CONFIG), equalTo(groupId)); + assertThat(returnedProps.get(ConsumerConfig.MAX_POLL_RECORDS_CONFIG), equalTo("1000")); + assertNull(returnedProps.get(ConsumerConfig.GROUP_INSTANCE_ID_CONFIG)); + } + + @Test + public void testGetGroupInstanceIdConfigs() { + props.put(ConsumerConfig.GROUP_INSTANCE_ID_CONFIG, "group-instance-id"); + 
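+        // The per-client prefixed overrides set below should win over this plain config: the main
+        // consumer is expected to get the mainConsumerPrefix value suffixed with the thread index,
+        // while the restore and global consumers drop the group instance id entirely.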
props.put(StreamsConfig.mainConsumerPrefix(ConsumerConfig.GROUP_INSTANCE_ID_CONFIG), "group-instance-id-1"); + props.put(StreamsConfig.restoreConsumerPrefix(ConsumerConfig.GROUP_INSTANCE_ID_CONFIG), "group-instance-id-2"); + props.put(StreamsConfig.globalConsumerPrefix(ConsumerConfig.GROUP_INSTANCE_ID_CONFIG), "group-instance-id-3"); + final StreamsConfig streamsConfig = new StreamsConfig(props); + + Map returnedProps = streamsConfig.getMainConsumerConfigs(groupId, clientId, threadIdx); + assertThat( + returnedProps.get(ConsumerConfig.GROUP_INSTANCE_ID_CONFIG), + equalTo("group-instance-id-1-" + threadIdx) + ); + + returnedProps = streamsConfig.getRestoreConsumerConfigs(clientId); + assertNull(returnedProps.get(ConsumerConfig.GROUP_INSTANCE_ID_CONFIG)); + + returnedProps = streamsConfig.getGlobalConsumerConfigs(clientId); + assertNull(returnedProps.get(ConsumerConfig.GROUP_INSTANCE_ID_CONFIG)); + } + + @Test + public void consumerConfigMustContainStreamPartitionAssignorConfig() { + props.put(StreamsConfig.REPLICATION_FACTOR_CONFIG, 42); + props.put(StreamsConfig.NUM_STANDBY_REPLICAS_CONFIG, 1); + props.put(StreamsConfig.ACCEPTABLE_RECOVERY_LAG_CONFIG, 99L); + props.put(StreamsConfig.MAX_WARMUP_REPLICAS_CONFIG, 9); + props.put(StreamsConfig.PROBING_REBALANCE_INTERVAL_MS_CONFIG, 99_999L); + props.put(StreamsConfig.WINDOW_STORE_CHANGE_LOG_ADDITIONAL_RETENTION_MS_CONFIG, 7L); + props.put(StreamsConfig.APPLICATION_SERVER_CONFIG, "dummy:host"); + props.put(StreamsConfig.topicPrefix(TopicConfig.SEGMENT_BYTES_CONFIG), 100); + final StreamsConfig streamsConfig = new StreamsConfig(props); + final Map returnedProps = streamsConfig.getMainConsumerConfigs(groupId, clientId, threadIdx); + + assertEquals(42, returnedProps.get(StreamsConfig.REPLICATION_FACTOR_CONFIG)); + assertEquals(1, returnedProps.get(StreamsConfig.NUM_STANDBY_REPLICAS_CONFIG)); + assertEquals(99L, returnedProps.get(StreamsConfig.ACCEPTABLE_RECOVERY_LAG_CONFIG)); + assertEquals(9, returnedProps.get(StreamsConfig.MAX_WARMUP_REPLICAS_CONFIG)); + assertEquals(99_999L, returnedProps.get(StreamsConfig.PROBING_REBALANCE_INTERVAL_MS_CONFIG)); + assertEquals( + StreamsPartitionAssignor.class.getName(), + returnedProps.get(ConsumerConfig.PARTITION_ASSIGNMENT_STRATEGY_CONFIG) + ); + assertEquals(7L, returnedProps.get(StreamsConfig.WINDOW_STORE_CHANGE_LOG_ADDITIONAL_RETENTION_MS_CONFIG)); + assertEquals("dummy:host", returnedProps.get(StreamsConfig.APPLICATION_SERVER_CONFIG)); + assertEquals(100, returnedProps.get(StreamsConfig.topicPrefix(TopicConfig.SEGMENT_BYTES_CONFIG))); + } + + @Test + public void testGetMainConsumerConfigsWithMainConsumerOverridenPrefix() { + props.put(StreamsConfig.consumerPrefix(ConsumerConfig.MAX_POLL_RECORDS_CONFIG), "5"); + props.put(StreamsConfig.mainConsumerPrefix(ConsumerConfig.MAX_POLL_RECORDS_CONFIG), "50"); + props.put(StreamsConfig.mainConsumerPrefix(ConsumerConfig.GROUP_ID_CONFIG), "another-id"); + final StreamsConfig streamsConfig = new StreamsConfig(props); + final Map returnedProps = streamsConfig.getMainConsumerConfigs(groupId, clientId, threadIdx); + assertEquals(groupId, returnedProps.get(ConsumerConfig.GROUP_ID_CONFIG)); + assertEquals("50", returnedProps.get(ConsumerConfig.MAX_POLL_RECORDS_CONFIG)); + } + + @Test + public void testGetRestoreConsumerConfigs() { + final Map returnedProps = streamsConfig.getRestoreConsumerConfigs(clientId); + assertEquals(returnedProps.get(ConsumerConfig.CLIENT_ID_CONFIG), clientId); + assertNull(returnedProps.get(ConsumerConfig.GROUP_ID_CONFIG)); + } + + @Test + public 
void defaultSerdeShouldBeConfigured() { + final Map serializerConfigs = new HashMap<>(); + serializerConfigs.put("key.serializer.encoding", StandardCharsets.UTF_8.name()); + serializerConfigs.put("value.serializer.encoding", StandardCharsets.UTF_16.name()); + final Serializer serializer = Serdes.String().serializer(); + + final String str = "my string for testing"; + final String topic = "my topic"; + + serializer.configure(serializerConfigs, true); + assertEquals( + "Should get the original string after serialization and deserialization with the configured encoding", + str, + streamsConfig.defaultKeySerde().deserializer().deserialize(topic, serializer.serialize(topic, str)) + ); + + serializer.configure(serializerConfigs, false); + assertEquals( + "Should get the original string after serialization and deserialization with the configured encoding", + str, + streamsConfig.defaultValueSerde().deserializer().deserialize(topic, serializer.serialize(topic, str)) + ); + } + + @Test + public void shouldSupportMultipleBootstrapServers() { + final List expectedBootstrapServers = Arrays.asList("broker1:9092", "broker2:9092"); + final String bootstrapServersString = Utils.join(expectedBootstrapServers, ","); + final Properties props = new Properties(); + props.put(StreamsConfig.APPLICATION_ID_CONFIG, "irrelevant"); + props.put(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, bootstrapServersString); + final StreamsConfig config = new StreamsConfig(props); + + final List actualBootstrapServers = config.getList(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG); + assertEquals(expectedBootstrapServers, actualBootstrapServers); + } + + @Test + public void shouldSupportPrefixedConsumerConfigs() { + props.put(consumerPrefix(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG), "earliest"); + props.put(consumerPrefix(ConsumerConfig.METRICS_NUM_SAMPLES_CONFIG), 1); + final StreamsConfig streamsConfig = new StreamsConfig(props); + final Map consumerConfigs = streamsConfig.getMainConsumerConfigs(groupId, clientId, threadIdx); + assertEquals("earliest", consumerConfigs.get(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG)); + assertEquals(1, consumerConfigs.get(ConsumerConfig.METRICS_NUM_SAMPLES_CONFIG)); + } + + @Test + public void shouldSupportPrefixedRestoreConsumerConfigs() { + props.put(consumerPrefix(ConsumerConfig.METRICS_NUM_SAMPLES_CONFIG), 1); + final StreamsConfig streamsConfig = new StreamsConfig(props); + final Map consumerConfigs = streamsConfig.getRestoreConsumerConfigs(clientId); + assertEquals(1, consumerConfigs.get(ConsumerConfig.METRICS_NUM_SAMPLES_CONFIG)); + } + + @Test + public void shouldSupportPrefixedPropertiesThatAreNotPartOfConsumerConfig() { + props.put(consumerPrefix("interceptor.statsd.host"), "host"); + final StreamsConfig streamsConfig = new StreamsConfig(props); + final Map consumerConfigs = streamsConfig.getMainConsumerConfigs(groupId, clientId, threadIdx); + assertEquals("host", consumerConfigs.get("interceptor.statsd.host")); + } + + @Test + public void shouldSupportPrefixedPropertiesThatAreNotPartOfRestoreConsumerConfig() { + props.put(consumerPrefix("interceptor.statsd.host"), "host"); + final StreamsConfig streamsConfig = new StreamsConfig(props); + final Map consumerConfigs = streamsConfig.getRestoreConsumerConfigs(clientId); + assertEquals("host", consumerConfigs.get("interceptor.statsd.host")); + } + + @Test + public void shouldSupportPrefixedPropertiesThatAreNotPartOfProducerConfig() { + props.put(producerPrefix("interceptor.statsd.host"), "host"); + final StreamsConfig streamsConfig = new 
StreamsConfig(props); + final Map producerConfigs = streamsConfig.getProducerConfigs(clientId); + assertEquals("host", producerConfigs.get("interceptor.statsd.host")); + } + + @Test + public void shouldSupportPrefixedProducerConfigs() { + props.put(producerPrefix(ProducerConfig.BUFFER_MEMORY_CONFIG), 10); + props.put(producerPrefix(ConsumerConfig.METRICS_NUM_SAMPLES_CONFIG), 1); + final StreamsConfig streamsConfig = new StreamsConfig(props); + final Map configs = streamsConfig.getProducerConfigs(clientId); + assertEquals(10, configs.get(ProducerConfig.BUFFER_MEMORY_CONFIG)); + assertEquals(1, configs.get(ProducerConfig.METRICS_NUM_SAMPLES_CONFIG)); + } + + @Test + public void shouldBeSupportNonPrefixedConsumerConfigs() { + props.put(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "earliest"); + props.put(ConsumerConfig.METRICS_NUM_SAMPLES_CONFIG, 1); + final StreamsConfig streamsConfig = new StreamsConfig(props); + final Map consumerConfigs = streamsConfig.getMainConsumerConfigs(groupId, clientId, threadIdx); + assertEquals("earliest", consumerConfigs.get(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG)); + assertEquals(1, consumerConfigs.get(ConsumerConfig.METRICS_NUM_SAMPLES_CONFIG)); + } + + @Test + public void shouldBeSupportNonPrefixedRestoreConsumerConfigs() { + props.put(ConsumerConfig.METRICS_NUM_SAMPLES_CONFIG, 1); + final StreamsConfig streamsConfig = new StreamsConfig(props); + final Map consumerConfigs = streamsConfig.getRestoreConsumerConfigs(groupId); + assertEquals(1, consumerConfigs.get(ConsumerConfig.METRICS_NUM_SAMPLES_CONFIG)); + } + + @Test + public void shouldSupportNonPrefixedProducerConfigs() { + props.put(ProducerConfig.BUFFER_MEMORY_CONFIG, 10); + props.put(ConsumerConfig.METRICS_NUM_SAMPLES_CONFIG, 1); + final StreamsConfig streamsConfig = new StreamsConfig(props); + final Map configs = streamsConfig.getProducerConfigs(clientId); + assertEquals(10, configs.get(ProducerConfig.BUFFER_MEMORY_CONFIG)); + assertEquals(1, configs.get(ProducerConfig.METRICS_NUM_SAMPLES_CONFIG)); + } + + @Test + public void shouldForwardCustomConfigsWithNoPrefixToAllClients() { + props.put("custom.property.host", "host"); + final StreamsConfig streamsConfig = new StreamsConfig(props); + final Map consumerConfigs = streamsConfig.getMainConsumerConfigs(groupId, clientId, threadIdx); + final Map restoreConsumerConfigs = streamsConfig.getRestoreConsumerConfigs(clientId); + final Map producerConfigs = streamsConfig.getProducerConfigs(clientId); + final Map adminConfigs = streamsConfig.getAdminConfigs(clientId); + assertEquals("host", consumerConfigs.get("custom.property.host")); + assertEquals("host", restoreConsumerConfigs.get("custom.property.host")); + assertEquals("host", producerConfigs.get("custom.property.host")); + assertEquals("host", adminConfigs.get("custom.property.host")); + } + + @Test + public void shouldOverrideNonPrefixedCustomConfigsWithPrefixedConfigs() { + props.put("custom.property.host", "host0"); + props.put(consumerPrefix("custom.property.host"), "host1"); + props.put(producerPrefix("custom.property.host"), "host2"); + props.put(adminClientPrefix("custom.property.host"), "host3"); + final StreamsConfig streamsConfig = new StreamsConfig(props); + final Map consumerConfigs = streamsConfig.getMainConsumerConfigs(groupId, clientId, threadIdx); + final Map restoreConsumerConfigs = streamsConfig.getRestoreConsumerConfigs(clientId); + final Map producerConfigs = streamsConfig.getProducerConfigs(clientId); + final Map adminConfigs = streamsConfig.getAdminConfigs(clientId); + 
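+        // Each client type should see its own prefixed value rather than the un-prefixed "host0";
+        // the restore consumer shares the consumerPrefix override with the main consumer.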
assertEquals("host1", consumerConfigs.get("custom.property.host")); + assertEquals("host1", restoreConsumerConfigs.get("custom.property.host")); + assertEquals("host2", producerConfigs.get("custom.property.host")); + assertEquals("host3", adminConfigs.get("custom.property.host")); + } + + @Test + public void shouldSupportNonPrefixedAdminConfigs() { + props.put(AdminClientConfig.DEFAULT_API_TIMEOUT_MS_CONFIG, 10); + final StreamsConfig streamsConfig = new StreamsConfig(props); + final Map configs = streamsConfig.getAdminConfigs(clientId); + assertEquals(10, configs.get(AdminClientConfig.DEFAULT_API_TIMEOUT_MS_CONFIG)); + } + + @Test + public void shouldThrowStreamsExceptionIfKeySerdeConfigFails() { + props.put(StreamsConfig.DEFAULT_KEY_SERDE_CLASS_CONFIG, MisconfiguredSerde.class); + final StreamsConfig streamsConfig = new StreamsConfig(props); + assertThrows(StreamsException.class, streamsConfig::defaultKeySerde); + } + + @Test + public void shouldThrowStreamsExceptionIfValueSerdeConfigFails() { + props.put(StreamsConfig.DEFAULT_VALUE_SERDE_CLASS_CONFIG, MisconfiguredSerde.class); + final StreamsConfig streamsConfig = new StreamsConfig(props); + assertThrows(StreamsException.class, streamsConfig::defaultValueSerde); + } + + @Test + public void shouldOverrideStreamsDefaultConsumerConfigs() { + props.put(StreamsConfig.consumerPrefix(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG), "latest"); + props.put(StreamsConfig.consumerPrefix(ConsumerConfig.MAX_POLL_RECORDS_CONFIG), "10"); + final StreamsConfig streamsConfig = new StreamsConfig(props); + final Map consumerConfigs = streamsConfig.getMainConsumerConfigs(groupId, clientId, threadIdx); + assertEquals("latest", consumerConfigs.get(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG)); + assertEquals("10", consumerConfigs.get(ConsumerConfig.MAX_POLL_RECORDS_CONFIG)); + } + + @Test + public void shouldOverrideStreamsDefaultProducerConfigs() { + props.put(StreamsConfig.producerPrefix(ProducerConfig.LINGER_MS_CONFIG), "10000"); + props.put(StreamsConfig.producerPrefix(ProducerConfig.TRANSACTION_TIMEOUT_CONFIG), "30000"); + final StreamsConfig streamsConfig = new StreamsConfig(props); + final Map producerConfigs = streamsConfig.getProducerConfigs(clientId); + assertEquals("10000", producerConfigs.get(ProducerConfig.LINGER_MS_CONFIG)); + assertEquals("30000", producerConfigs.get(ProducerConfig.TRANSACTION_TIMEOUT_CONFIG)); + } + + @SuppressWarnings("deprecation") + @Test + public void shouldThrowIfTransactionTimeoutSmallerThanCommitIntervalForEOSAlpha() { + assertThrows(IllegalArgumentException.class, + () -> testTransactionTimeoutSmallerThanCommitInterval(EXACTLY_ONCE)); + } + + @SuppressWarnings("deprecation") + @Test + public void shouldThrowIfTransactionTimeoutSmallerThanCommitIntervalForEOSBeta() { + assertThrows(IllegalArgumentException.class, + () -> testTransactionTimeoutSmallerThanCommitInterval(EXACTLY_ONCE_BETA)); + } + + @Test + public void shouldNotThrowIfTransactionTimeoutSmallerThanCommitIntervalForAtLeastOnce() { + testTransactionTimeoutSmallerThanCommitInterval(AT_LEAST_ONCE); + } + + private void testTransactionTimeoutSmallerThanCommitInterval(final String processingGuarantee) { + props.put(StreamsConfig.PROCESSING_GUARANTEE_CONFIG, processingGuarantee); + props.put(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG, 10000L); + props.put(StreamsConfig.producerPrefix(ProducerConfig.TRANSACTION_TIMEOUT_CONFIG), 3000); + new StreamsConfig(props); + } + + @Test + public void shouldOverrideStreamsDefaultConsumerConifgsOnRestoreConsumer() { + 
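+        // A consumerPrefix override should reach the restore consumer as well, replacing the
+        // Streams default max.poll.records of 1000 asserted in testGetConsumerConfigs above.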
props.put(StreamsConfig.consumerPrefix(ConsumerConfig.MAX_POLL_RECORDS_CONFIG), "10"); + final StreamsConfig streamsConfig = new StreamsConfig(props); + final Map consumerConfigs = streamsConfig.getRestoreConsumerConfigs(clientId); + assertEquals("10", consumerConfigs.get(ConsumerConfig.MAX_POLL_RECORDS_CONFIG)); + } + + @Test + public void shouldResetToDefaultIfConsumerAutoCommitIsOverridden() { + props.put(StreamsConfig.consumerPrefix(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG), "true"); + final StreamsConfig streamsConfig = new StreamsConfig(props); + final Map consumerConfigs = streamsConfig.getMainConsumerConfigs("a", "b", threadIdx); + assertEquals("false", consumerConfigs.get(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG)); + } + + @Test + public void shouldResetToDefaultIfRestoreConsumerAutoCommitIsOverridden() { + props.put(StreamsConfig.consumerPrefix(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG), "true"); + final StreamsConfig streamsConfig = new StreamsConfig(props); + final Map consumerConfigs = streamsConfig.getRestoreConsumerConfigs(clientId); + assertEquals("false", consumerConfigs.get(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG)); + } + + @Test + public void testGetRestoreConsumerConfigsWithRestoreConsumerOverridenPrefix() { + props.put(StreamsConfig.consumerPrefix(ConsumerConfig.MAX_POLL_RECORDS_CONFIG), "5"); + props.put(StreamsConfig.restoreConsumerPrefix(ConsumerConfig.MAX_POLL_RECORDS_CONFIG), "50"); + final StreamsConfig streamsConfig = new StreamsConfig(props); + final Map returnedProps = streamsConfig.getRestoreConsumerConfigs(clientId); + assertEquals("50", returnedProps.get(ConsumerConfig.MAX_POLL_RECORDS_CONFIG)); + } + + @Test + public void testGetGlobalConsumerConfigs() { + final Map returnedProps = streamsConfig.getGlobalConsumerConfigs(clientId); + assertEquals(returnedProps.get(ConsumerConfig.CLIENT_ID_CONFIG), clientId + "-global-consumer"); + assertNull(returnedProps.get(ConsumerConfig.GROUP_ID_CONFIG)); + } + + @Test + public void shouldSupportPrefixedGlobalConsumerConfigs() { + props.put(consumerPrefix(ConsumerConfig.METRICS_NUM_SAMPLES_CONFIG), 1); + final StreamsConfig streamsConfig = new StreamsConfig(props); + final Map consumerConfigs = streamsConfig.getGlobalConsumerConfigs(clientId); + assertEquals(1, consumerConfigs.get(ConsumerConfig.METRICS_NUM_SAMPLES_CONFIG)); + } + + @Test + public void shouldSupportPrefixedPropertiesThatAreNotPartOfGlobalConsumerConfig() { + props.put(consumerPrefix("interceptor.statsd.host"), "host"); + final StreamsConfig streamsConfig = new StreamsConfig(props); + final Map consumerConfigs = streamsConfig.getGlobalConsumerConfigs(clientId); + assertEquals("host", consumerConfigs.get("interceptor.statsd.host")); + } + + @Test + public void shouldBeSupportNonPrefixedGlobalConsumerConfigs() { + props.put(ConsumerConfig.METRICS_NUM_SAMPLES_CONFIG, 1); + final StreamsConfig streamsConfig = new StreamsConfig(props); + final Map consumerConfigs = streamsConfig.getGlobalConsumerConfigs(groupId); + assertEquals(1, consumerConfigs.get(ConsumerConfig.METRICS_NUM_SAMPLES_CONFIG)); + } + + @Test + public void shouldResetToDefaultIfGlobalConsumerAutoCommitIsOverridden() { + props.put(StreamsConfig.consumerPrefix(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG), "true"); + final StreamsConfig streamsConfig = new StreamsConfig(props); + final Map consumerConfigs = streamsConfig.getGlobalConsumerConfigs(clientId); + assertEquals("false", consumerConfigs.get(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG)); + } + + @Test + public void 
testGetGlobalConsumerConfigsWithGlobalConsumerOverridenPrefix() { + props.put(StreamsConfig.consumerPrefix(ConsumerConfig.MAX_POLL_RECORDS_CONFIG), "5"); + props.put(StreamsConfig.globalConsumerPrefix(ConsumerConfig.MAX_POLL_RECORDS_CONFIG), "50"); + final StreamsConfig streamsConfig = new StreamsConfig(props); + final Map returnedProps = streamsConfig.getGlobalConsumerConfigs(clientId); + assertEquals("50", returnedProps.get(ConsumerConfig.MAX_POLL_RECORDS_CONFIG)); + } + + @Test + public void shouldSetInternalLeaveGroupOnCloseConfigToFalseInConsumer() { + final StreamsConfig streamsConfig = new StreamsConfig(props); + final Map consumerConfigs = streamsConfig.getMainConsumerConfigs(groupId, clientId, threadIdx); + assertThat(consumerConfigs.get("internal.leave.group.on.close"), is(false)); + } + + @Test + public void shouldNotSetInternalThrowOnFetchStableOffsetUnsupportedConfigToFalseInConsumerForEosDisabled() { + final Map consumerConfigs = streamsConfig.getMainConsumerConfigs(groupId, clientId, threadIdx); + assertThat(consumerConfigs.get("internal.throw.on.fetch.stable.offset.unsupported"), is(nullValue())); + } + + @SuppressWarnings("deprecation") + @Test + public void shouldNotSetInternalThrowOnFetchStableOffsetUnsupportedConfigToFalseInConsumerForEosAlpha() { + props.put(StreamsConfig.PROCESSING_GUARANTEE_CONFIG, EXACTLY_ONCE); + final StreamsConfig streamsConfig = new StreamsConfig(props); + final Map consumerConfigs = streamsConfig.getMainConsumerConfigs(groupId, clientId, threadIdx); + assertThat(consumerConfigs.get("internal.throw.on.fetch.stable.offset.unsupported"), is(nullValue())); + } + + @SuppressWarnings("deprecation") + @Test + public void shouldNotSetInternalThrowOnFetchStableOffsetUnsupportedConfigToFalseInConsumerForEosBeta() { + props.put(StreamsConfig.PROCESSING_GUARANTEE_CONFIG, EXACTLY_ONCE_BETA); + final StreamsConfig streamsConfig = new StreamsConfig(props); + final Map consumerConfigs = streamsConfig.getMainConsumerConfigs(groupId, clientId, threadIdx); + assertThat(consumerConfigs.get("internal.throw.on.fetch.stable.offset.unsupported"), is(true)); + } + + @Test + public void shouldNotSetInternalThrowOnFetchStableOffsetUnsupportedConfigToFalseInConsumerForEosV2() { + props.put(StreamsConfig.PROCESSING_GUARANTEE_CONFIG, EXACTLY_ONCE_V2); + final StreamsConfig streamsConfig = new StreamsConfig(props); + final Map consumerConfigs = streamsConfig.getMainConsumerConfigs(groupId, clientId, threadIdx); + assertThat(consumerConfigs.get("internal.throw.on.fetch.stable.offset.unsupported"), is(true)); + } + + @Test + public void shouldNotSetInternalAutoDowngradeTxnCommitToTrueInProducerForEosDisabled() { + final Map producerConfigs = streamsConfig.getProducerConfigs(clientId); + assertThat(producerConfigs.get("internal.auto.downgrade.txn.commit"), is(nullValue())); + } + + @SuppressWarnings("deprecation") + @Test + public void shouldSetInternalAutoDowngradeTxnCommitToTrueInProducerForEosAlpha() { + props.put(StreamsConfig.PROCESSING_GUARANTEE_CONFIG, EXACTLY_ONCE); + final StreamsConfig streamsConfig = new StreamsConfig(props); + final Map producerConfigs = streamsConfig.getProducerConfigs(clientId); + assertThat(producerConfigs.get("internal.auto.downgrade.txn.commit"), is(true)); + } + + @SuppressWarnings("deprecation") + @Test + public void shouldNotSetInternalAutoDowngradeTxnCommitToTrueInProducerForEosBeta() { + props.put(StreamsConfig.PROCESSING_GUARANTEE_CONFIG, EXACTLY_ONCE_BETA); + final StreamsConfig streamsConfig = new StreamsConfig(props); + final Map 
producerConfigs = streamsConfig.getProducerConfigs(clientId); + assertThat(producerConfigs.get("internal.auto.downgrade.txn.commit"), is(nullValue())); + } + + @Test + public void shouldNotSetInternalAutoDowngradeTxnCommitToTrueInProducerForEosV2() { + props.put(StreamsConfig.PROCESSING_GUARANTEE_CONFIG, EXACTLY_ONCE_V2); + final StreamsConfig streamsConfig = new StreamsConfig(props); + final Map producerConfigs = streamsConfig.getProducerConfigs(clientId); + assertThat(producerConfigs.get("internal.auto.downgrade.txn.commit"), is(nullValue())); + } + + @Test + public void shouldAcceptAtLeastOnce() { + // don't use `StreamsConfig.AT_LEAST_ONCE` to actually do a useful test + props.put(StreamsConfig.PROCESSING_GUARANTEE_CONFIG, "at_least_once"); + new StreamsConfig(props); + } + + @Test + public void shouldAcceptExactlyOnce() { + // don't use `StreamsConfig.EXACTLY_ONCE` to actually do a useful test + props.put(StreamsConfig.PROCESSING_GUARANTEE_CONFIG, "exactly_once"); + new StreamsConfig(props); + } + + @Test + public void shouldAcceptExactlyOnceBeta() { + // don't use `StreamsConfig.EXACTLY_ONCE_BETA` to actually do a useful test + props.put(StreamsConfig.PROCESSING_GUARANTEE_CONFIG, "exactly_once_beta"); + new StreamsConfig(props); + } + + @Test + public void shouldThrowExceptionIfNotAtLeastOnceOrExactlyOnce() { + props.put(StreamsConfig.PROCESSING_GUARANTEE_CONFIG, "bad_value"); + assertThrows(ConfigException.class, () -> new StreamsConfig(props)); + } + + @Test + public void shouldAcceptBuiltInMetricsLatestVersion() { + // don't use `StreamsConfig.METRICS_LATEST` to actually do a useful test + props.put(StreamsConfig.BUILT_IN_METRICS_VERSION_CONFIG, "latest"); + new StreamsConfig(props); + } + + @Test + public void shouldSetDefaultBuiltInMetricsVersionIfNoneIsSpecified() { + final StreamsConfig config = new StreamsConfig(props); + assertThat(config.getString(StreamsConfig.BUILT_IN_METRICS_VERSION_CONFIG), is(StreamsConfig.METRICS_LATEST)); + } + + @Test + public void shouldThrowIfBuiltInMetricsVersionInvalid() { + final String invalidVersion = "0.0.1"; + props.put(StreamsConfig.BUILT_IN_METRICS_VERSION_CONFIG, invalidVersion); + final Exception exception = assertThrows(ConfigException.class, () -> new StreamsConfig(props)); + assertThat( + exception.getMessage(), + containsString("Invalid value " + invalidVersion + " for configuration built.in.metrics.version") + ); + } + + @SuppressWarnings("deprecation") + @Test + public void shouldResetToDefaultIfConsumerIsolationLevelIsOverriddenIfEosAlphaEnabled() { + props.put(StreamsConfig.PROCESSING_GUARANTEE_CONFIG, EXACTLY_ONCE); + shouldResetToDefaultIfConsumerIsolationLevelIsOverriddenIfEosEnabled(); + } + + @SuppressWarnings("deprecation") + @Test + public void shouldResetToDefaultIfConsumerIsolationLevelIsOverriddenIfEosBetaEnabled() { + props.put(StreamsConfig.PROCESSING_GUARANTEE_CONFIG, EXACTLY_ONCE_BETA); + shouldResetToDefaultIfConsumerIsolationLevelIsOverriddenIfEosEnabled(); + } + + @Test + public void shouldResetToDefaultIfConsumerIsolationLevelIsOverriddenIfEosV2Enabled() { + props.put(StreamsConfig.PROCESSING_GUARANTEE_CONFIG, EXACTLY_ONCE_V2); + shouldResetToDefaultIfConsumerIsolationLevelIsOverriddenIfEosEnabled(); + } + + private void shouldResetToDefaultIfConsumerIsolationLevelIsOverriddenIfEosEnabled() { + props.put(ConsumerConfig.ISOLATION_LEVEL_CONFIG, "anyValue"); + final StreamsConfig streamsConfig = new StreamsConfig(props); + final Map consumerConfigs = streamsConfig.getMainConsumerConfigs(groupId, clientId, 
threadIdx); + assertThat( + consumerConfigs.get(ConsumerConfig.ISOLATION_LEVEL_CONFIG), + equalTo(READ_COMMITTED.name().toLowerCase(Locale.ROOT)) + ); + } + + @Test + public void shouldAllowSettingConsumerIsolationLevelIfEosDisabled() { + props.put(ConsumerConfig.ISOLATION_LEVEL_CONFIG, READ_UNCOMMITTED.name().toLowerCase(Locale.ROOT)); + final StreamsConfig streamsConfig = new StreamsConfig(props); + final Map consumerConfigs = streamsConfig.getMainConsumerConfigs(groupId, clientId, threadIdx); + assertThat( + consumerConfigs.get(ConsumerConfig.ISOLATION_LEVEL_CONFIG), + equalTo(READ_UNCOMMITTED.name().toLowerCase(Locale.ROOT)) + ); + } + + @SuppressWarnings("deprecation") + @Test + public void shouldResetToDefaultIfProducerEnableIdempotenceIsOverriddenIfEosAlphaEnabled() { + props.put(StreamsConfig.PROCESSING_GUARANTEE_CONFIG, EXACTLY_ONCE); + shouldResetToDefaultIfProducerEnableIdempotenceIsOverriddenIfEosEnabled(); + } + + @SuppressWarnings("deprecation") + @Test + public void shouldResetToDefaultIfProducerEnableIdempotenceIsOverriddenIfEosBetaEnabled() { + props.put(StreamsConfig.PROCESSING_GUARANTEE_CONFIG, EXACTLY_ONCE_BETA); + shouldResetToDefaultIfProducerEnableIdempotenceIsOverriddenIfEosEnabled(); + } + + @Test + public void shouldResetToDefaultIfProducerEnableIdempotenceIsOverriddenIfEosV2Enabled() { + props.put(StreamsConfig.PROCESSING_GUARANTEE_CONFIG, EXACTLY_ONCE_V2); + shouldResetToDefaultIfProducerEnableIdempotenceIsOverriddenIfEosEnabled(); + } + + private void shouldResetToDefaultIfProducerEnableIdempotenceIsOverriddenIfEosEnabled() { + props.put(ProducerConfig.ENABLE_IDEMPOTENCE_CONFIG, "anyValue"); + final StreamsConfig streamsConfig = new StreamsConfig(props); + final Map producerConfigs = streamsConfig.getProducerConfigs(clientId); + assertTrue((Boolean) producerConfigs.get(ProducerConfig.ENABLE_IDEMPOTENCE_CONFIG)); + } + + @Test + public void shouldAllowSettingProducerEnableIdempotenceIfEosDisabled() { + props.put(ProducerConfig.ENABLE_IDEMPOTENCE_CONFIG, false); + final StreamsConfig streamsConfig = new StreamsConfig(props); + final Map producerConfigs = streamsConfig.getProducerConfigs(clientId); + assertThat(producerConfigs.get(ProducerConfig.ENABLE_IDEMPOTENCE_CONFIG), equalTo(false)); + } + + @SuppressWarnings("deprecation") + @Test + public void shouldSetDifferentDefaultsIfEosAlphaEnabled() { + props.put(StreamsConfig.PROCESSING_GUARANTEE_CONFIG, EXACTLY_ONCE); + shouldSetDifferentDefaultsIfEosEnabled(); + } + + @SuppressWarnings("deprecation") + @Test + public void shouldSetDifferentDefaultsIfEosBetaEnabled() { + props.put(StreamsConfig.PROCESSING_GUARANTEE_CONFIG, EXACTLY_ONCE_BETA); + shouldSetDifferentDefaultsIfEosEnabled(); + } + + @Test + public void shouldSetDifferentDefaultsIfEosV2Enabled() { + props.put(StreamsConfig.PROCESSING_GUARANTEE_CONFIG, EXACTLY_ONCE_V2); + shouldSetDifferentDefaultsIfEosEnabled(); + } + + private void shouldSetDifferentDefaultsIfEosEnabled() { + final StreamsConfig streamsConfig = new StreamsConfig(props); + + final Map consumerConfigs = streamsConfig.getMainConsumerConfigs(groupId, clientId, threadIdx); + final Map producerConfigs = streamsConfig.getProducerConfigs(clientId); + + assertThat( + consumerConfigs.get(ConsumerConfig.ISOLATION_LEVEL_CONFIG), + equalTo(READ_COMMITTED.name().toLowerCase(Locale.ROOT)) + ); + assertTrue((Boolean) producerConfigs.get(ProducerConfig.ENABLE_IDEMPOTENCE_CONFIG)); + assertThat(producerConfigs.get(ProducerConfig.DELIVERY_TIMEOUT_MS_CONFIG), equalTo(Integer.MAX_VALUE)); + 
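+            // with EOS enabled, Streams also lowers the default producer transaction timeout to 10 seconds and the default commit interval to 100 ms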
assertThat(producerConfigs.get(ProducerConfig.TRANSACTION_TIMEOUT_CONFIG), equalTo(10000)); + assertThat(streamsConfig.getLong(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG), equalTo(100L)); + } + + @SuppressWarnings("deprecation") + @Test + public void shouldOverrideUserConfigTransactionalIdIfEosAlphaEnabled() { + props.put(StreamsConfig.PROCESSING_GUARANTEE_CONFIG, EXACTLY_ONCE); + shouldOverrideUserConfigTransactionalIdIfEosEnable(); + } + + @SuppressWarnings("deprecation") + @Test + public void shouldOverrideUserConfigTransactionalIdIfEosBetaEnabled() { + props.put(StreamsConfig.PROCESSING_GUARANTEE_CONFIG, EXACTLY_ONCE_BETA); + shouldOverrideUserConfigTransactionalIdIfEosEnable(); + } + + @Test + public void shouldOverrideUserConfigTransactionalIdIfEosV2Enabled() { + props.put(StreamsConfig.PROCESSING_GUARANTEE_CONFIG, EXACTLY_ONCE_V2); + shouldOverrideUserConfigTransactionalIdIfEosEnable(); + } + + private void shouldOverrideUserConfigTransactionalIdIfEosEnable() { + props.put(ProducerConfig.TRANSACTIONAL_ID_CONFIG, "user-TxId"); + final StreamsConfig streamsConfig = new StreamsConfig(props); + + final Map producerConfigs = streamsConfig.getProducerConfigs(clientId); + + assertThat(producerConfigs.get(ProducerConfig.TRANSACTIONAL_ID_CONFIG), is(nullValue())); + } + + @SuppressWarnings("deprecation") + @Test + public void shouldNotOverrideUserConfigRetriesIfExactlyAlphaOnceEnabled() { + props.put(StreamsConfig.PROCESSING_GUARANTEE_CONFIG, EXACTLY_ONCE); + shouldNotOverrideUserConfigRetriesIfExactlyOnceEnabled(); + } + + @SuppressWarnings("deprecation") + @Test + public void shouldNotOverrideUserConfigRetriesIfExactlyBetaOnceEnabled() { + props.put(StreamsConfig.PROCESSING_GUARANTEE_CONFIG, EXACTLY_ONCE_BETA); + shouldNotOverrideUserConfigRetriesIfExactlyOnceEnabled(); + } + + @Test + public void shouldNotOverrideUserConfigRetriesIfExactlyV2OnceEnabled() { + props.put(StreamsConfig.PROCESSING_GUARANTEE_CONFIG, EXACTLY_ONCE_V2); + shouldNotOverrideUserConfigRetriesIfExactlyOnceEnabled(); + } + + private void shouldNotOverrideUserConfigRetriesIfExactlyOnceEnabled() { + final int numberOfRetries = 42; + props.put(ProducerConfig.RETRIES_CONFIG, numberOfRetries); + final StreamsConfig streamsConfig = new StreamsConfig(props); + + final Map producerConfigs = streamsConfig.getProducerConfigs(clientId); + + assertThat(producerConfigs.get(ProducerConfig.RETRIES_CONFIG), equalTo(numberOfRetries)); + } + + @SuppressWarnings("deprecation") + @Test + public void shouldNotOverrideUserConfigCommitIntervalMsIfExactlyOnceAlphaEnabled() { + props.put(StreamsConfig.PROCESSING_GUARANTEE_CONFIG, EXACTLY_ONCE); + shouldNotOverrideUserConfigCommitIntervalMsIfExactlyOnceEnabled(); + } + + @SuppressWarnings("deprecation") + @Test + public void shouldNotOverrideUserConfigCommitIntervalMsIfExactlyOnceBetaEnabled() { + props.put(StreamsConfig.PROCESSING_GUARANTEE_CONFIG, EXACTLY_ONCE_BETA); + shouldNotOverrideUserConfigCommitIntervalMsIfExactlyOnceEnabled(); + } + + @Test + public void shouldNotOverrideUserConfigCommitIntervalMsIfExactlyOnceV2Enabled() { + props.put(StreamsConfig.PROCESSING_GUARANTEE_CONFIG, EXACTLY_ONCE_V2); + shouldNotOverrideUserConfigCommitIntervalMsIfExactlyOnceEnabled(); + } + + private void shouldNotOverrideUserConfigCommitIntervalMsIfExactlyOnceEnabled() { + final long commitIntervalMs = 73L; + props.put(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG, commitIntervalMs); + final StreamsConfig streamsConfig = new StreamsConfig(props); + + 
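+        // an explicitly configured commit.interval.ms should take precedence over the 100 ms EOS default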
assertThat(streamsConfig.getLong(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG), equalTo(commitIntervalMs)); + } + + @Test + public void shouldThrowExceptionIfCommitIntervalMsIsNegative() { + final long commitIntervalMs = -1; + props.put(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG, commitIntervalMs); + try { + new StreamsConfig(props); + fail("Should throw ConfigException when commitIntervalMs is set to a negative value"); + } catch (final ConfigException e) { + assertEquals( + "Invalid value -1 for configuration commit.interval.ms: Value must be at least 0", + e.getMessage() + ); + } + } + + @Test + public void shouldUseNewConfigsWhenPresent() { + final Properties props = getStreamsConfig(); + props.put(StreamsConfig.DEFAULT_KEY_SERDE_CLASS_CONFIG, Serdes.Long().getClass()); + props.put(StreamsConfig.DEFAULT_VALUE_SERDE_CLASS_CONFIG, Serdes.Long().getClass()); + props.put(StreamsConfig.DEFAULT_TIMESTAMP_EXTRACTOR_CLASS_CONFIG, MockTimestampExtractor.class); + + final StreamsConfig config = new StreamsConfig(props); + assertTrue(config.defaultKeySerde() instanceof Serdes.LongSerde); + assertTrue(config.defaultValueSerde() instanceof Serdes.LongSerde); + assertTrue(config.defaultTimestampExtractor() instanceof MockTimestampExtractor); + } + + @Test + public void shouldUseCorrectDefaultsWhenNoneSpecified() { + final StreamsConfig config = new StreamsConfig(getStreamsConfig()); + + assertTrue(config.defaultTimestampExtractor() instanceof FailOnInvalidTimestamp); + assertThrows(ConfigException.class, config::defaultKeySerde); + assertThrows(ConfigException.class, config::defaultValueSerde); + } + + @Test + public void shouldSpecifyCorrectKeySerdeClassOnError() { + final Properties props = getStreamsConfig(); + props.put(StreamsConfig.DEFAULT_KEY_SERDE_CLASS_CONFIG, MisconfiguredSerde.class); + final StreamsConfig config = new StreamsConfig(props); + try { + config.defaultKeySerde(); + fail("Test should throw a StreamsException"); + } catch (final StreamsException e) { + assertEquals( + "Failed to configure key serde class org.apache.kafka.streams.StreamsConfigTest$MisconfiguredSerde", + e.getMessage() + ); + } + } + + @Test + public void shouldSpecifyCorrectValueSerdeClassOnError() { + final Properties props = getStreamsConfig(); + props.put(StreamsConfig.DEFAULT_VALUE_SERDE_CLASS_CONFIG, MisconfiguredSerde.class); + final StreamsConfig config = new StreamsConfig(props); + try { + config.defaultValueSerde(); + fail("Test should throw a StreamsException"); + } catch (final StreamsException e) { + assertEquals( + "Failed to configure value serde class org.apache.kafka.streams.StreamsConfigTest$MisconfiguredSerde", + e.getMessage() + ); + } + } + + @SuppressWarnings("deprecation") + @Test + public void shouldThrowExceptionIfMaxInFlightRequestsGreaterThanFiveIfEosAlphaEnabled() { + props.put(StreamsConfig.PROCESSING_GUARANTEE_CONFIG, EXACTLY_ONCE); + shouldThrowExceptionIfMaxInFlightRequestsGreaterThanFiveIfEosEnabled(); + } + + @SuppressWarnings("deprecation") + @Test + public void shouldThrowExceptionIfMaxInFlightRequestsGreaterThanFiveIfEosBetaEnabled() { + props.put(StreamsConfig.PROCESSING_GUARANTEE_CONFIG, EXACTLY_ONCE_BETA); + shouldThrowExceptionIfMaxInFlightRequestsGreaterThanFiveIfEosEnabled(); + } + + @Test + public void shouldThrowExceptionIfMaxInFlightRequestsGreaterThanFiveIfEosV2Enabled() { + props.put(StreamsConfig.PROCESSING_GUARANTEE_CONFIG, EXACTLY_ONCE_V2); + shouldThrowExceptionIfMaxInFlightRequestsGreaterThanFiveIfEosEnabled(); + } + + private void 
shouldThrowExceptionIfMaxInFlightRequestsGreaterThanFiveIfEosEnabled() { + props.put(ProducerConfig.MAX_IN_FLIGHT_REQUESTS_PER_CONNECTION, 7); + final StreamsConfig streamsConfig = new StreamsConfig(props); + try { + streamsConfig.getProducerConfigs(clientId); + fail("Should throw ConfigException when EOS is enabled and maxInFlight requests exceed 5"); + } catch (final ConfigException e) { + assertEquals( + "Invalid value 7 for configuration max.in.flight.requests.per.connection:" + + " Can't exceed 5 when exactly-once processing is enabled", + e.getMessage() + ); + } + } + + @SuppressWarnings("deprecation") + @Test + public void shouldAllowToSpecifyMaxInFlightRequestsPerConnectionAsStringIfEosAlphaEnabled() { + props.put(StreamsConfig.PROCESSING_GUARANTEE_CONFIG, EXACTLY_ONCE); + shouldAllowToSpecifyMaxInFlightRequestsPerConnectionAsStringIfEosEnabled(); + } + + @SuppressWarnings("deprecation") + @Test + public void shouldAllowToSpecifyMaxInFlightRequestsPerConnectionAsStringIfEosBetaEnabled() { + props.put(StreamsConfig.PROCESSING_GUARANTEE_CONFIG, EXACTLY_ONCE_BETA); + shouldAllowToSpecifyMaxInFlightRequestsPerConnectionAsStringIfEosEnabled(); + } + + @Test + public void shouldAllowToSpecifyMaxInFlightRequestsPerConnectionAsStringIfEosV2Enabled() { + props.put(StreamsConfig.PROCESSING_GUARANTEE_CONFIG, EXACTLY_ONCE_V2); + shouldAllowToSpecifyMaxInFlightRequestsPerConnectionAsStringIfEosEnabled(); + } + + private void shouldAllowToSpecifyMaxInFlightRequestsPerConnectionAsStringIfEosEnabled() { + props.put(ProducerConfig.MAX_IN_FLIGHT_REQUESTS_PER_CONNECTION, "3"); + + new StreamsConfig(props).getProducerConfigs(clientId); + } + + @SuppressWarnings("deprecation") + @Test + public void shouldThrowConfigExceptionIfMaxInFlightRequestsPerConnectionIsInvalidStringIfEosAlphaEnabled() { + props.put(StreamsConfig.PROCESSING_GUARANTEE_CONFIG, EXACTLY_ONCE); + shouldThrowConfigExceptionIfMaxInFlightRequestsPerConnectionIsInvalidStringIfEosEnabled(); + } + + @SuppressWarnings("deprecation") + @Test + public void shouldThrowConfigExceptionIfMaxInFlightRequestsPerConnectionIsInvalidStringIfEosBetaEnabled() { + props.put(StreamsConfig.PROCESSING_GUARANTEE_CONFIG, EXACTLY_ONCE_BETA); + shouldThrowConfigExceptionIfMaxInFlightRequestsPerConnectionIsInvalidStringIfEosEnabled(); + } + + @Test + public void shouldThrowConfigExceptionIfMaxInFlightRequestsPerConnectionIsInvalidStringIfEosV2Enabled() { + props.put(StreamsConfig.PROCESSING_GUARANTEE_CONFIG, EXACTLY_ONCE_V2); + shouldThrowConfigExceptionIfMaxInFlightRequestsPerConnectionIsInvalidStringIfEosEnabled(); + } + + private void shouldThrowConfigExceptionIfMaxInFlightRequestsPerConnectionIsInvalidStringIfEosEnabled() { + props.put(ProducerConfig.MAX_IN_FLIGHT_REQUESTS_PER_CONNECTION, "not-a-number"); + + try { + new StreamsConfig(props).getProducerConfigs(clientId); + fail("Should throw ConfigException when EOS is enabled and maxInFlight cannot be parsed into an integer"); + } catch (final ConfigException e) { + assertEquals( + "Invalid value not-a-number for configuration max.in.flight.requests.per.connection:" + + " String value could not be parsed as 32-bit integer", + e.getMessage() + ); + } + } + + @Test + public void shouldStateDirStartsWithJavaIOTmpDir() { + final String expectedPrefix = System.getProperty("java.io.tmpdir") + File.separator; + final String actual = streamsConfig.getString(STATE_DIR_CONFIG); + assertTrue(actual.startsWith(expectedPrefix)); + } + + @Test + public void shouldSpecifyNoOptimizationWhenNotExplicitlyAddedToConfigs() { 
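+        // topology.optimization defaults to "none" when it is not set explicitly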
+ final String expectedOptimizeConfig = "none"; + final String actualOptimizedConfig = streamsConfig.getString(TOPOLOGY_OPTIMIZATION_CONFIG); + assertEquals("Optimization should be \"none\"", expectedOptimizeConfig, actualOptimizedConfig); + } + + @Test + public void shouldSpecifyOptimizationWhenExplicitlyAddedToConfigs() { + final String expectedOptimizeConfig = "all"; + props.put(TOPOLOGY_OPTIMIZATION_CONFIG, "all"); + final StreamsConfig config = new StreamsConfig(props); + final String actualOptimizedConfig = config.getString(TOPOLOGY_OPTIMIZATION_CONFIG); + assertEquals("Optimization should be \"all\"", expectedOptimizeConfig, actualOptimizedConfig); + } + + @Test + public void shouldThrowConfigExceptionWhenOptimizationConfigNotValueInRange() { + props.put(TOPOLOGY_OPTIMIZATION_CONFIG, "maybe"); + assertThrows(ConfigException.class, () -> new StreamsConfig(props)); + } + + @SuppressWarnings("deprecation") + @Test + public void shouldLogWarningWhenEosAlphaIsUsed() { + props.put(StreamsConfig.PROCESSING_GUARANTEE_CONFIG, StreamsConfig.EXACTLY_ONCE); + + LogCaptureAppender.setClassLoggerToDebug(StreamsConfig.class); + try (final LogCaptureAppender appender = LogCaptureAppender.createAndRegister(StreamsConfig.class)) { + new StreamsConfig(props); + + assertThat( + appender.getMessages(), + hasItem("Configuration parameter `" + StreamsConfig.EXACTLY_ONCE + + "` is deprecated and will be removed in the 4.0.0 release. " + + "Please use `" + StreamsConfig.EXACTLY_ONCE_V2 + "` instead. " + + "Note that this requires broker version 2.5+ so you should prepare " + + "to upgrade your brokers if necessary.") + ); + } + } + + @SuppressWarnings("deprecation") + @Test + public void shouldLogWarningWhenEosBetaIsUsed() { + props.put(StreamsConfig.PROCESSING_GUARANTEE_CONFIG, StreamsConfig.EXACTLY_ONCE_BETA); + + LogCaptureAppender.setClassLoggerToDebug(StreamsConfig.class); + try (final LogCaptureAppender appender = LogCaptureAppender.createAndRegister(StreamsConfig.class)) { + new StreamsConfig(props); + + assertThat( + appender.getMessages(), + hasItem("Configuration parameter `" + StreamsConfig.EXACTLY_ONCE_BETA + + "` is deprecated and will be removed in the 4.0.0 release. 
" + + "Please use `" + StreamsConfig.EXACTLY_ONCE_V2 + "` instead.") + ); + } + } + + @SuppressWarnings("deprecation") + @Test + public void shouldLogWarningWhenRetriesIsUsed() { + props.put(StreamsConfig.RETRIES_CONFIG, 0); + + LogCaptureAppender.setClassLoggerToDebug(StreamsConfig.class); + try (final LogCaptureAppender appender = LogCaptureAppender.createAndRegister(StreamsConfig.class)) { + new StreamsConfig(props); + + assertThat( + appender.getMessages(), + hasItem("Configuration parameter `" + StreamsConfig.RETRIES_CONFIG + + "` is deprecated and will be removed in the 4.0.0 release.") + ); + } + } + + @Test + public void shouldSetDefaultAcceptableRecoveryLag() { + final StreamsConfig config = new StreamsConfig(props); + assertThat(config.getLong(StreamsConfig.ACCEPTABLE_RECOVERY_LAG_CONFIG), is(10000L)); + } + + @Test + public void shouldThrowConfigExceptionIfAcceptableRecoveryLagIsOutsideBounds() { + props.put(StreamsConfig.ACCEPTABLE_RECOVERY_LAG_CONFIG, -1L); + assertThrows(ConfigException.class, () -> new StreamsConfig(props)); + } + + @Test + public void shouldSetDefaultNumStandbyReplicas() { + final StreamsConfig config = new StreamsConfig(props); + assertThat(config.getInt(StreamsConfig.NUM_STANDBY_REPLICAS_CONFIG), is(0)); + } + + @Test + public void shouldThrowConfigExceptionIfNumStandbyReplicasIsOutsideBounds() { + props.put(StreamsConfig.NUM_STANDBY_REPLICAS_CONFIG, -1L); + assertThrows(ConfigException.class, () -> new StreamsConfig(props)); + } + + @Test + public void shouldSetDefaultMaxWarmupReplicas() { + final StreamsConfig config = new StreamsConfig(props); + assertThat(config.getInt(StreamsConfig.MAX_WARMUP_REPLICAS_CONFIG), is(2)); + } + + @Test + public void shouldThrowConfigExceptionIfMaxWarmupReplicasIsOutsideBounds() { + props.put(StreamsConfig.MAX_WARMUP_REPLICAS_CONFIG, 0L); + assertThrows(ConfigException.class, () -> new StreamsConfig(props)); + } + + @Test + public void shouldSetDefaultProbingRebalanceInterval() { + final StreamsConfig config = new StreamsConfig(props); + assertThat(config.getLong(StreamsConfig.PROBING_REBALANCE_INTERVAL_MS_CONFIG), is(10 * 60 * 1000L)); + } + + @Test + public void shouldThrowConfigExceptionIfProbingRebalanceIntervalIsOutsideBounds() { + props.put(StreamsConfig.PROBING_REBALANCE_INTERVAL_MS_CONFIG, (60 * 1000L) - 1); + assertThrows(ConfigException.class, () -> new StreamsConfig(props)); + } + + static class MisconfiguredSerde implements Serde { + @Override + public void configure(final Map configs, final boolean isKey) { + throw new RuntimeException("boom"); + } + + @Override + public Serializer serializer() { + return null; + } + + @Override + public Deserializer deserializer() { + return null; + } + } + + public static class MockTimestampExtractor implements TimestampExtractor { + @Override + public long extract(final ConsumerRecord record, final long partitionTime) { + return 0; + } + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/TopologyTest.java b/streams/src/test/java/org/apache/kafka/streams/TopologyTest.java new file mode 100644 index 0000000..b332f6c --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/TopologyTest.java @@ -0,0 +1,1685 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams; + +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.streams.errors.StreamsException; +import org.apache.kafka.streams.errors.TopologyException; +import org.apache.kafka.streams.kstream.JoinWindows; +import org.apache.kafka.streams.kstream.KStream; +import org.apache.kafka.streams.kstream.KTable; +import org.apache.kafka.streams.kstream.Materialized; +import org.apache.kafka.streams.kstream.SessionWindows; +import org.apache.kafka.streams.kstream.StreamJoined; +import org.apache.kafka.streams.kstream.TimeWindows; +import org.apache.kafka.streams.processor.RecordContext; +import org.apache.kafka.streams.processor.TopicNameExtractor; +import org.apache.kafka.streams.processor.api.Processor; +import org.apache.kafka.streams.processor.api.ProcessorContext; +import org.apache.kafka.streams.processor.api.ProcessorSupplier; +import org.apache.kafka.streams.processor.api.Record; +import org.apache.kafka.streams.processor.internals.InternalTopologyBuilder; +import org.apache.kafka.streams.processor.internals.InternalTopologyBuilder.SubtopologyDescription; +import org.apache.kafka.streams.state.KeyValueStore; +import org.apache.kafka.streams.state.StoreBuilder; +import org.apache.kafka.streams.state.Stores; +import org.apache.kafka.streams.state.WindowBytesStoreSupplier; +import org.apache.kafka.streams.state.internals.KeyValueStoreBuilder; +import org.apache.kafka.test.MockApiProcessorSupplier; +import org.apache.kafka.test.MockKeyValueStore; +import org.apache.kafka.test.MockProcessorSupplier; +import org.apache.kafka.test.MockValueJoiner; +import org.easymock.EasyMock; +import org.junit.Assert; +import org.junit.Test; + +import java.time.Duration; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; +import java.util.Properties; +import java.util.Set; +import java.util.regex.Pattern; + +import static java.time.Duration.ofMillis; +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.fail; + +@SuppressWarnings("deprecation") +public class TopologyTest { + + private final StoreBuilder storeBuilder = EasyMock.createNiceMock(StoreBuilder.class); + private final KeyValueStoreBuilder globalStoreBuilder = EasyMock.createNiceMock(KeyValueStoreBuilder.class); + private final Topology topology = new Topology(); + private final InternalTopologyBuilder.TopologyDescription expectedDescription = new InternalTopologyBuilder.TopologyDescription(); + + @Test + public void shouldNotAllowNullNameWhenAddingSourceWithTopic() { + assertThrows(NullPointerException.class, () -> topology.addSource((String) null, "topic")); + } + + @Test + public void shouldNotAllowNullNameWhenAddingSourceWithPattern() { + 
assertThrows(NullPointerException.class, () -> topology.addSource(null, Pattern.compile(".*"))); + } + + @Test + public void shouldNotAllowNullTopicsWhenAddingSoureWithTopic() { + assertThrows(NullPointerException.class, () -> topology.addSource("source", (String[]) null)); + } + + @Test + public void shouldNotAllowNullTopicsWhenAddingSourceWithPattern() { + assertThrows(NullPointerException.class, () -> topology.addSource("source", (Pattern) null)); + } + + @Test + public void shouldNotAllowZeroTopicsWhenAddingSource() { + assertThrows(TopologyException.class, () -> topology.addSource("source")); + } + + @Test + public void shouldNotAllowNullNameWhenAddingProcessor() { + assertThrows(NullPointerException.class, () -> topology.addProcessor(null, () -> new MockApiProcessorSupplier<>().get())); + } + + @Test + public void shouldNotAllowNullProcessorSupplierWhenAddingProcessor() { + assertThrows(NullPointerException.class, () -> topology.addProcessor("name", + (ProcessorSupplier) null)); + } + + @Test + public void shouldNotAllowNullNameWhenAddingSink() { + assertThrows(NullPointerException.class, () -> topology.addSink(null, "topic")); + } + + @Test + public void shouldNotAllowNullTopicWhenAddingSink() { + assertThrows(NullPointerException.class, () -> topology.addSink("name", (String) null)); + } + + @Test + public void shouldNotAllowNullTopicChooserWhenAddingSink() { + assertThrows(NullPointerException.class, () -> topology.addSink("name", (TopicNameExtractor) null)); + } + + @Test + public void shouldNotAllowNullProcessorNameWhenConnectingProcessorAndStateStores() { + assertThrows(NullPointerException.class, () -> topology.connectProcessorAndStateStores(null, "store")); + } + + @Test + public void shouldNotAllowNullStoreNameWhenConnectingProcessorAndStateStores() { + assertThrows(NullPointerException.class, () -> topology.connectProcessorAndStateStores("processor", (String[]) null)); + } + + @Test + public void shouldNotAllowZeroStoreNameWhenConnectingProcessorAndStateStores() { + assertThrows(TopologyException.class, () -> topology.connectProcessorAndStateStores("processor")); + } + + @Test + public void shouldNotAddNullStateStoreSupplier() { + assertThrows(NullPointerException.class, () -> topology.addStateStore(null)); + } + + @Test + public void shouldNotAllowToAddSourcesWithSameName() { + topology.addSource("source", "topic-1"); + try { + topology.addSource("source", "topic-2"); + fail("Should throw TopologyException for duplicate source name"); + } catch (final TopologyException expected) { } + } + + @Test + public void shouldNotAllowToAddTopicTwice() { + topology.addSource("source", "topic-1"); + try { + topology.addSource("source-2", "topic-1"); + fail("Should throw TopologyException for already used topic"); + } catch (final TopologyException expected) { } + } + + @Test + public void testPatternMatchesAlreadyProvidedTopicSource() { + topology.addSource("source-1", "foo"); + try { + topology.addSource("source-2", Pattern.compile("f.*")); + fail("Should have thrown TopologyException for overlapping pattern with already registered topic"); + } catch (final TopologyException expected) { } + } + + @Test + public void testNamedTopicMatchesAlreadyProvidedPattern() { + topology.addSource("source-1", Pattern.compile("f.*")); + try { + topology.addSource("source-2", "foo"); + fail("Should have thrown TopologyException for overlapping topic with already registered pattern"); + } catch (final TopologyException expected) { } + } + + @Test + public void 
shouldNotAllowToAddProcessorWithSameName() { + topology.addSource("source", "topic-1"); + topology.addProcessor("processor", new MockApiProcessorSupplier<>(), "source"); + try { + topology.addProcessor("processor", new MockApiProcessorSupplier<>(), "source"); + fail("Should throw TopologyException for duplicate processor name"); + } catch (final TopologyException expected) { } + } + + @Test + public void shouldNotAllowToAddProcessorWithEmptyParents() { + topology.addSource("source", "topic-1"); + try { + topology.addProcessor("processor", new MockApiProcessorSupplier<>()); + fail("Should throw TopologyException for processor without at least one parent node"); + } catch (final TopologyException expected) { } + } + + @Test + public void shouldNotAllowToAddProcessorWithNullParents() { + topology.addSource("source", "topic-1"); + try { + topology.addProcessor("processor", new MockApiProcessorSupplier<>(), (String) null); + fail("Should throw NullPointerException for processor when null parent names are provided"); + } catch (final NullPointerException expected) { } + } + + @Test + public void shouldFailOnUnknownSource() { + assertThrows(TopologyException.class, () -> topology.addProcessor("processor", new MockApiProcessorSupplier<>(), "source")); + } + + @Test + public void shouldFailIfNodeIsItsOwnParent() { + assertThrows(TopologyException.class, () -> topology.addProcessor("processor", new MockApiProcessorSupplier<>(), "processor")); + } + + @Test + public void shouldNotAllowToAddSinkWithSameName() { + topology.addSource("source", "topic-1"); + topology.addSink("sink", "topic-2", "source"); + try { + topology.addSink("sink", "topic-3", "source"); + fail("Should throw TopologyException for duplicate sink name"); + } catch (final TopologyException expected) { } + } + + @Test + public void shouldNotAllowToAddSinkWithEmptyParents() { + topology.addSource("source", "topic-1"); + topology.addProcessor("processor", new MockApiProcessorSupplier<>(), "source"); + try { + topology.addSink("sink", "topic-2"); + fail("Should throw TopologyException for sink without at least one parent node"); + } catch (final TopologyException expected) { } + } + + @Test + public void shouldNotAllowToAddSinkWithNullParents() { + topology.addSource("source", "topic-1"); + topology.addProcessor("processor", new MockApiProcessorSupplier<>(), "source"); + try { + topology.addSink("sink", "topic-2", (String) null); + fail("Should throw NullPointerException for sink when null parent names are provided"); + } catch (final NullPointerException expected) { } + } + + @Test + public void shouldFailWithUnknownParent() { + assertThrows(TopologyException.class, () -> topology.addSink("sink", "topic-2", "source")); + } + + @Test + public void shouldFailIfSinkIsItsOwnParent() { + assertThrows(TopologyException.class, () -> topology.addSink("sink", "topic-2", "sink")); + } + + @Test + public void shouldFailIfSinkIsParent() { + topology.addSource("source", "topic-1"); + topology.addSink("sink-1", "topic-2", "source"); + try { + topology.addSink("sink-2", "topic-3", "sink-1"); + fail("Should throw TopologyException for using sink as parent"); + } catch (final TopologyException expected) { } + } + + @Test + public void shouldNotAllowToAddStateStoreToNonExistingProcessor() { + mockStoreBuilder(); + EasyMock.replay(storeBuilder); + assertThrows(TopologyException.class, () -> topology.addStateStore(storeBuilder, "no-such-processor")); + } + + @Test + public void shouldNotAllowToAddStateStoreToSource() { + mockStoreBuilder(); + 
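+        // state stores may only be connected to processor nodes; attaching one to a source node should fail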
EasyMock.replay(storeBuilder); + topology.addSource("source-1", "topic-1"); + try { + topology.addStateStore(storeBuilder, "source-1"); + fail("Should have thrown TopologyException for adding store to source node"); + } catch (final TopologyException expected) { } + } + + @Test + public void shouldNotAllowToAddStateStoreToSink() { + mockStoreBuilder(); + EasyMock.replay(storeBuilder); + topology.addSource("source-1", "topic-1"); + topology.addSink("sink-1", "topic-1", "source-1"); + try { + topology.addStateStore(storeBuilder, "sink-1"); + fail("Should have thrown TopologyException for adding store to sink node"); + } catch (final TopologyException expected) { } + } + + private void mockStoreBuilder() { + EasyMock.expect(storeBuilder.name()).andReturn("store").anyTimes(); + EasyMock.expect(storeBuilder.logConfig()).andReturn(Collections.emptyMap()); + EasyMock.expect(storeBuilder.loggingEnabled()).andReturn(false); + } + + @Test + public void shouldNotAllowToAddStoreWithSameNameAndDifferentInstance() { + mockStoreBuilder(); + EasyMock.replay(storeBuilder); + topology.addStateStore(storeBuilder); + + final StoreBuilder otherStoreBuilder = EasyMock.createNiceMock(StoreBuilder.class); + EasyMock.expect(otherStoreBuilder.name()).andReturn("store").anyTimes(); + EasyMock.expect(otherStoreBuilder.logConfig()).andReturn(Collections.emptyMap()); + EasyMock.expect(otherStoreBuilder.loggingEnabled()).andReturn(false); + EasyMock.replay(otherStoreBuilder); + try { + topology.addStateStore(otherStoreBuilder); + fail("Should have thrown TopologyException for same store name with different StoreBuilder"); + } catch (final TopologyException expected) { } + } + + @Test + public void shouldAllowToShareStoreUsingSameStoreBuilder() { + mockStoreBuilder(); + EasyMock.replay(storeBuilder); + + topology.addSource("source", "topic-1"); + + topology.addProcessor("processor-1", new MockProcessorSupplierProvidingStore<>(storeBuilder), "source"); + topology.addProcessor("processor-2", new MockProcessorSupplierProvidingStore<>(storeBuilder), "source"); + } + + private static class MockProcessorSupplierProvidingStore extends MockApiProcessorSupplier { + private final StoreBuilder storeBuilder; + + public MockProcessorSupplierProvidingStore(final StoreBuilder storeBuilder) { + this.storeBuilder = storeBuilder; + } + + @Override + public Set> stores() { + return Collections.singleton(storeBuilder); + } + } + + @Test + public void shouldThrowOnUnassignedStateStoreAccess() { + final String sourceNodeName = "source"; + final String goodNodeName = "goodGuy"; + final String badNodeName = "badGuy"; + + mockStoreBuilder(); + EasyMock.expect(storeBuilder.build()).andReturn(new MockKeyValueStore("store", false)); + EasyMock.replay(storeBuilder); + topology + .addSource(sourceNodeName, "topic") + .addProcessor(goodNodeName, new LocalMockProcessorSupplier(), sourceNodeName) + .addStateStore( + storeBuilder, + goodNodeName) + .addProcessor(badNodeName, new LocalMockProcessorSupplier(), sourceNodeName); + + final Properties config = new Properties(); + config.put(StreamsConfig.DEFAULT_KEY_SERDE_CLASS_CONFIG, Serdes.ByteArraySerde.class); + config.put(StreamsConfig.DEFAULT_VALUE_SERDE_CLASS_CONFIG, Serdes.ByteArraySerde.class); + try { + new TopologyTestDriver(topology, config); + fail("Should have thrown StreamsException"); + } catch (final StreamsException e) { + final String error = e.toString(); + final String expectedMessage = "org.apache.kafka.streams.errors.StreamsException: failed to initialize processor " + badNodeName; + + 
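+            // "badGuy" accesses the store in init() without being connected to it, so the driver should fail to initialize that processor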
assertThat(error, equalTo(expectedMessage)); + } + } + + private static class LocalMockProcessorSupplier implements ProcessorSupplier { + final static String STORE_NAME = "store"; + + @Override + public Processor get() { + return new Processor() { + @Override + public void init(final ProcessorContext context) { + context.getStateStore(STORE_NAME); + } + + @Override + public void process(final Record record) { } + }; + } + } + + @Deprecated // testing old PAPI + @Test + public void shouldNotAllowToAddGlobalStoreWithSourceNameEqualsProcessorName() { + EasyMock.expect(globalStoreBuilder.name()).andReturn("anyName").anyTimes(); + EasyMock.replay(globalStoreBuilder); + assertThrows(TopologyException.class, () -> topology.addGlobalStore( + globalStoreBuilder, + "sameName", + null, + null, + "anyTopicName", + "sameName", + new MockProcessorSupplier<>())); + } + + @Test + public void shouldDescribeEmptyTopology() { + assertThat(topology.describe(), equalTo(expectedDescription)); + } + + @Test + public void sinkShouldReturnNullTopicWithDynamicRouting() { + final TopologyDescription.Sink expectedSinkNode = + new InternalTopologyBuilder.Sink<>("sink", (key, value, record) -> record.topic() + "-" + key); + + assertThat(expectedSinkNode.topic(), equalTo(null)); + } + + @Test + public void sinkShouldReturnTopicNameExtractorWithDynamicRouting() { + final TopicNameExtractor topicNameExtractor = (key, value, record) -> record.topic() + "-" + key; + final TopologyDescription.Sink expectedSinkNode = + new InternalTopologyBuilder.Sink<>("sink", topicNameExtractor); + + assertThat(expectedSinkNode.topicNameExtractor(), equalTo(topicNameExtractor)); + } + + @Test + public void singleSourceShouldHaveSingleSubtopology() { + final TopologyDescription.Source expectedSourceNode = addSource("source", "topic"); + + expectedDescription.addSubtopology( + new SubtopologyDescription(0, + Collections.singleton(expectedSourceNode))); + + assertThat(topology.describe(), equalTo(expectedDescription)); + assertThat(topology.describe().hashCode(), equalTo(expectedDescription.hashCode())); + } + + @Test + public void singleSourceWithListOfTopicsShouldHaveSingleSubtopology() { + final TopologyDescription.Source expectedSourceNode = addSource("source", "topic1", "topic2", "topic3"); + + expectedDescription.addSubtopology( + new SubtopologyDescription(0, + Collections.singleton(expectedSourceNode))); + + assertThat(topology.describe(), equalTo(expectedDescription)); + assertThat(topology.describe().hashCode(), equalTo(expectedDescription.hashCode())); + } + + @Test + public void singleSourcePatternShouldHaveSingleSubtopology() { + final TopologyDescription.Source expectedSourceNode = addSource("source", Pattern.compile("topic[0-9]")); + + expectedDescription.addSubtopology( + new SubtopologyDescription(0, + Collections.singleton(expectedSourceNode))); + + assertThat(topology.describe(), equalTo(expectedDescription)); + assertThat(topology.describe().hashCode(), equalTo(expectedDescription.hashCode())); + } + + @Test + public void multipleSourcesShouldHaveDistinctSubtopologies() { + final TopologyDescription.Source expectedSourceNode1 = addSource("source1", "topic1"); + expectedDescription.addSubtopology( + new SubtopologyDescription(0, + Collections.singleton(expectedSourceNode1))); + + final TopologyDescription.Source expectedSourceNode2 = addSource("source2", "topic2"); + expectedDescription.addSubtopology( + new SubtopologyDescription(1, + Collections.singleton(expectedSourceNode2))); + + final TopologyDescription.Source 
expectedSourceNode3 = addSource("source3", "topic3"); + expectedDescription.addSubtopology( + new SubtopologyDescription(2, + Collections.singleton(expectedSourceNode3))); + + assertThat(topology.describe(), equalTo(expectedDescription)); + assertThat(topology.describe().hashCode(), equalTo(expectedDescription.hashCode())); + } + + @Test + public void sourceAndProcessorShouldHaveSingleSubtopology() { + final TopologyDescription.Source expectedSourceNode = addSource("source", "topic"); + final TopologyDescription.Processor expectedProcessorNode = addProcessor("processor", expectedSourceNode); + + final Set allNodes = new HashSet<>(); + allNodes.add(expectedSourceNode); + allNodes.add(expectedProcessorNode); + expectedDescription.addSubtopology(new SubtopologyDescription(0, allNodes)); + + assertThat(topology.describe(), equalTo(expectedDescription)); + assertThat(topology.describe().hashCode(), equalTo(expectedDescription.hashCode())); + } + + @Test + public void sourceAndProcessorWithStateShouldHaveSingleSubtopology() { + final TopologyDescription.Source expectedSourceNode = addSource("source", "topic"); + final String[] store = new String[] {"store"}; + final TopologyDescription.Processor expectedProcessorNode = + addProcessorWithNewStore("processor", store, expectedSourceNode); + + final Set allNodes = new HashSet<>(); + allNodes.add(expectedSourceNode); + allNodes.add(expectedProcessorNode); + expectedDescription.addSubtopology(new SubtopologyDescription(0, allNodes)); + + assertThat(topology.describe(), equalTo(expectedDescription)); + assertThat(topology.describe().hashCode(), equalTo(expectedDescription.hashCode())); + } + + + @Test + public void sourceAndProcessorWithMultipleStatesShouldHaveSingleSubtopology() { + final TopologyDescription.Source expectedSourceNode = addSource("source", "topic"); + final String[] stores = new String[] {"store1", "store2"}; + final TopologyDescription.Processor expectedProcessorNode = + addProcessorWithNewStore("processor", stores, expectedSourceNode); + + final Set allNodes = new HashSet<>(); + allNodes.add(expectedSourceNode); + allNodes.add(expectedProcessorNode); + expectedDescription.addSubtopology(new SubtopologyDescription(0, allNodes)); + + assertThat(topology.describe(), equalTo(expectedDescription)); + assertThat(topology.describe().hashCode(), equalTo(expectedDescription.hashCode())); + } + + @Test + public void sourceWithMultipleProcessorsShouldHaveSingleSubtopology() { + final TopologyDescription.Source expectedSourceNode = addSource("source", "topic"); + final TopologyDescription.Processor expectedProcessorNode1 = addProcessor("processor1", expectedSourceNode); + final TopologyDescription.Processor expectedProcessorNode2 = addProcessor("processor2", expectedSourceNode); + + final Set allNodes = new HashSet<>(); + allNodes.add(expectedSourceNode); + allNodes.add(expectedProcessorNode1); + allNodes.add(expectedProcessorNode2); + expectedDescription.addSubtopology(new SubtopologyDescription(0, allNodes)); + + assertThat(topology.describe(), equalTo(expectedDescription)); + assertThat(topology.describe().hashCode(), equalTo(expectedDescription.hashCode())); + } + + @Test + public void processorWithMultipleSourcesShouldHaveSingleSubtopology() { + final TopologyDescription.Source expectedSourceNode1 = addSource("source1", "topic0"); + final TopologyDescription.Source expectedSourceNode2 = addSource("source2", Pattern.compile("topic[1-9]")); + final TopologyDescription.Processor expectedProcessorNode = addProcessor("processor", 
expectedSourceNode1, expectedSourceNode2); + + final Set allNodes = new HashSet<>(); + allNodes.add(expectedSourceNode1); + allNodes.add(expectedSourceNode2); + allNodes.add(expectedProcessorNode); + expectedDescription.addSubtopology(new SubtopologyDescription(0, allNodes)); + + assertThat(topology.describe(), equalTo(expectedDescription)); + assertThat(topology.describe().hashCode(), equalTo(expectedDescription.hashCode())); + } + + @Test + public void multipleSourcesWithProcessorsShouldHaveDistinctSubtopologies() { + final TopologyDescription.Source expectedSourceNode1 = addSource("source1", "topic1"); + final TopologyDescription.Processor expectedProcessorNode1 = addProcessor("processor1", expectedSourceNode1); + + final TopologyDescription.Source expectedSourceNode2 = addSource("source2", "topic2"); + final TopologyDescription.Processor expectedProcessorNode2 = addProcessor("processor2", expectedSourceNode2); + + final TopologyDescription.Source expectedSourceNode3 = addSource("source3", "topic3"); + final TopologyDescription.Processor expectedProcessorNode3 = addProcessor("processor3", expectedSourceNode3); + + final Set allNodes1 = new HashSet<>(); + allNodes1.add(expectedSourceNode1); + allNodes1.add(expectedProcessorNode1); + expectedDescription.addSubtopology(new SubtopologyDescription(0, allNodes1)); + + final Set allNodes2 = new HashSet<>(); + allNodes2.add(expectedSourceNode2); + allNodes2.add(expectedProcessorNode2); + expectedDescription.addSubtopology(new SubtopologyDescription(1, allNodes2)); + + final Set allNodes3 = new HashSet<>(); + allNodes3.add(expectedSourceNode3); + allNodes3.add(expectedProcessorNode3); + expectedDescription.addSubtopology(new SubtopologyDescription(2, allNodes3)); + + assertThat(topology.describe(), equalTo(expectedDescription)); + assertThat(topology.describe().hashCode(), equalTo(expectedDescription.hashCode())); + } + + @Test + public void multipleSourcesWithSinksShouldHaveDistinctSubtopologies() { + final TopologyDescription.Source expectedSourceNode1 = addSource("source1", "topic1"); + final TopologyDescription.Sink expectedSinkNode1 = addSink("sink1", "sinkTopic1", expectedSourceNode1); + + final TopologyDescription.Source expectedSourceNode2 = addSource("source2", "topic2"); + final TopologyDescription.Sink expectedSinkNode2 = addSink("sink2", "sinkTopic2", expectedSourceNode2); + + final TopologyDescription.Source expectedSourceNode3 = addSource("source3", "topic3"); + final TopologyDescription.Sink expectedSinkNode3 = addSink("sink3", "sinkTopic3", expectedSourceNode3); + + final Set allNodes1 = new HashSet<>(); + allNodes1.add(expectedSourceNode1); + allNodes1.add(expectedSinkNode1); + expectedDescription.addSubtopology(new SubtopologyDescription(0, allNodes1)); + + final Set allNodes2 = new HashSet<>(); + allNodes2.add(expectedSourceNode2); + allNodes2.add(expectedSinkNode2); + expectedDescription.addSubtopology(new SubtopologyDescription(1, allNodes2)); + + final Set allNodes3 = new HashSet<>(); + allNodes3.add(expectedSourceNode3); + allNodes3.add(expectedSinkNode3); + expectedDescription.addSubtopology(new SubtopologyDescription(2, allNodes3)); + + assertThat(topology.describe(), equalTo(expectedDescription)); + assertThat(topology.describe().hashCode(), equalTo(expectedDescription.hashCode())); + } + + @Test + public void processorsWithSameSinkShouldHaveSameSubtopology() { + final TopologyDescription.Source expectedSourceNode1 = addSource("source", "topic"); + final TopologyDescription.Processor expectedProcessorNode1 = 
addProcessor("processor1", expectedSourceNode1); + + final TopologyDescription.Source expectedSourceNode2 = addSource("source2", "topic2"); + final TopologyDescription.Processor expectedProcessorNode2 = addProcessor("processor2", expectedSourceNode2); + + final TopologyDescription.Source expectedSourceNode3 = addSource("source3", "topic3"); + final TopologyDescription.Processor expectedProcessorNode3 = addProcessor("processor3", expectedSourceNode3); + + final TopologyDescription.Sink expectedSinkNode = addSink( + "sink", + "sinkTopic", + expectedProcessorNode1, + expectedProcessorNode2, + expectedProcessorNode3); + + final Set allNodes = new HashSet<>(); + allNodes.add(expectedSourceNode1); + allNodes.add(expectedProcessorNode1); + allNodes.add(expectedSourceNode2); + allNodes.add(expectedProcessorNode2); + allNodes.add(expectedSourceNode3); + allNodes.add(expectedProcessorNode3); + allNodes.add(expectedSinkNode); + expectedDescription.addSubtopology(new SubtopologyDescription(0, allNodes)); + + assertThat(topology.describe(), equalTo(expectedDescription)); + assertThat(topology.describe().hashCode(), equalTo(expectedDescription.hashCode())); + } + + @Test + public void processorsWithSharedStateShouldHaveSameSubtopology() { + final String[] store1 = new String[] {"store1"}; + final String[] store2 = new String[] {"store2"}; + final String[] bothStores = new String[] {store1[0], store2[0]}; + + final TopologyDescription.Source expectedSourceNode1 = addSource("source", "topic"); + final TopologyDescription.Processor expectedProcessorNode1 = + addProcessorWithNewStore("processor1", store1, expectedSourceNode1); + + final TopologyDescription.Source expectedSourceNode2 = addSource("source2", "topic2"); + final TopologyDescription.Processor expectedProcessorNode2 = + addProcessorWithNewStore("processor2", store2, expectedSourceNode2); + + final TopologyDescription.Source expectedSourceNode3 = addSource("source3", "topic3"); + final TopologyDescription.Processor expectedProcessorNode3 = + addProcessorWithExistingStore("processor3", bothStores, expectedSourceNode3); + + final Set allNodes = new HashSet<>(); + allNodes.add(expectedSourceNode1); + allNodes.add(expectedProcessorNode1); + allNodes.add(expectedSourceNode2); + allNodes.add(expectedProcessorNode2); + allNodes.add(expectedSourceNode3); + allNodes.add(expectedProcessorNode3); + expectedDescription.addSubtopology(new SubtopologyDescription(0, allNodes)); + + assertThat(topology.describe(), equalTo(expectedDescription)); + assertThat(topology.describe().hashCode(), equalTo(expectedDescription.hashCode())); + } + + @Test + public void shouldDescribeGlobalStoreTopology() { + addGlobalStoreToTopologyAndExpectedDescription("globalStore", "source", "globalTopic", "processor", 0); + assertThat(topology.describe(), equalTo(expectedDescription)); + assertThat(topology.describe().hashCode(), equalTo(expectedDescription.hashCode())); + } + + @Test + public void shouldDescribeMultipleGlobalStoreTopology() { + addGlobalStoreToTopologyAndExpectedDescription("globalStore1", "source1", "globalTopic1", "processor1", 0); + addGlobalStoreToTopologyAndExpectedDescription("globalStore2", "source2", "globalTopic2", "processor2", 1); + assertThat(topology.describe(), equalTo(expectedDescription)); + assertThat(topology.describe().hashCode(), equalTo(expectedDescription.hashCode())); + } + + @Test + public void streamStreamJoinTopologyWithDefaultStoresNames() { + final StreamsBuilder builder = new StreamsBuilder(); + final KStream stream1; + final KStream 
stream2; + + stream1 = builder.stream("input-topic1"); + stream2 = builder.stream("input-topic2"); + + stream1.join( + stream2, + MockValueJoiner.TOSTRING_JOINER, + JoinWindows.of(ofMillis(100)), + StreamJoined.with(Serdes.Integer(), Serdes.String(), Serdes.String())); + + final TopologyDescription describe = builder.build().describe(); + + assertEquals( + "Topologies:\n" + + " Sub-topology: 0\n" + + " Source: KSTREAM-SOURCE-0000000000 (topics: [input-topic1])\n" + + " --> KSTREAM-WINDOWED-0000000002\n" + + " Source: KSTREAM-SOURCE-0000000001 (topics: [input-topic2])\n" + + " --> KSTREAM-WINDOWED-0000000003\n" + + " Processor: KSTREAM-WINDOWED-0000000002 (stores: [KSTREAM-JOINTHIS-0000000004-store])\n" + + " --> KSTREAM-JOINTHIS-0000000004\n" + + " <-- KSTREAM-SOURCE-0000000000\n" + + " Processor: KSTREAM-WINDOWED-0000000003 (stores: [KSTREAM-JOINOTHER-0000000005-store])\n" + + " --> KSTREAM-JOINOTHER-0000000005\n" + + " <-- KSTREAM-SOURCE-0000000001\n" + + " Processor: KSTREAM-JOINOTHER-0000000005 (stores: [KSTREAM-JOINTHIS-0000000004-store])\n" + + " --> KSTREAM-MERGE-0000000006\n" + + " <-- KSTREAM-WINDOWED-0000000003\n" + + " Processor: KSTREAM-JOINTHIS-0000000004 (stores: [KSTREAM-JOINOTHER-0000000005-store])\n" + + " --> KSTREAM-MERGE-0000000006\n" + + " <-- KSTREAM-WINDOWED-0000000002\n" + + " Processor: KSTREAM-MERGE-0000000006 (stores: [])\n" + + " --> none\n" + + " <-- KSTREAM-JOINTHIS-0000000004, KSTREAM-JOINOTHER-0000000005\n\n", + describe.toString()); + } + + @Test + public void streamStreamJoinTopologyWithCustomStoresNames() { + final StreamsBuilder builder = new StreamsBuilder(); + final KStream stream1; + final KStream stream2; + + stream1 = builder.stream("input-topic1"); + stream2 = builder.stream("input-topic2"); + + stream1.join( + stream2, + MockValueJoiner.TOSTRING_JOINER, + JoinWindows.of(ofMillis(100)), + StreamJoined.with(Serdes.Integer(), Serdes.String(), Serdes.String()) + .withStoreName("custom-name")); + + final TopologyDescription describe = builder.build().describe(); + + assertEquals( + "Topologies:\n" + + " Sub-topology: 0\n" + + " Source: KSTREAM-SOURCE-0000000000 (topics: [input-topic1])\n" + + " --> KSTREAM-WINDOWED-0000000002\n" + + " Source: KSTREAM-SOURCE-0000000001 (topics: [input-topic2])\n" + + " --> KSTREAM-WINDOWED-0000000003\n" + + " Processor: KSTREAM-WINDOWED-0000000002 (stores: [custom-name-this-join-store])\n" + + " --> KSTREAM-JOINTHIS-0000000004\n" + + " <-- KSTREAM-SOURCE-0000000000\n" + + " Processor: KSTREAM-WINDOWED-0000000003 (stores: [custom-name-other-join-store])\n" + + " --> KSTREAM-JOINOTHER-0000000005\n" + + " <-- KSTREAM-SOURCE-0000000001\n" + + " Processor: KSTREAM-JOINOTHER-0000000005 (stores: [custom-name-this-join-store])\n" + + " --> KSTREAM-MERGE-0000000006\n" + + " <-- KSTREAM-WINDOWED-0000000003\n" + + " Processor: KSTREAM-JOINTHIS-0000000004 (stores: [custom-name-other-join-store])\n" + + " --> KSTREAM-MERGE-0000000006\n" + + " <-- KSTREAM-WINDOWED-0000000002\n" + + " Processor: KSTREAM-MERGE-0000000006 (stores: [])\n" + + " --> none\n" + + " <-- KSTREAM-JOINTHIS-0000000004, KSTREAM-JOINOTHER-0000000005\n\n", + describe.toString()); + } + + @Test + public void streamStreamJoinTopologyWithCustomStoresSuppliers() { + final StreamsBuilder builder = new StreamsBuilder(); + final KStream stream1; + final KStream stream2; + + stream1 = builder.stream("input-topic1"); + stream2 = builder.stream("input-topic2"); + + final JoinWindows joinWindows = JoinWindows.of(ofMillis(100)); + + final WindowBytesStoreSupplier 
thisStoreSupplier = Stores.inMemoryWindowStore("in-memory-join-store", + Duration.ofMillis(joinWindows.size() + joinWindows.gracePeriodMs()), + Duration.ofMillis(joinWindows.size()), true); + + final WindowBytesStoreSupplier otherStoreSupplier = Stores.inMemoryWindowStore("in-memory-join-store-other", + Duration.ofMillis(joinWindows.size() + joinWindows.gracePeriodMs()), + Duration.ofMillis(joinWindows.size()), true); + + stream1.join( + stream2, + MockValueJoiner.TOSTRING_JOINER, + joinWindows, + StreamJoined.with(Serdes.Integer(), Serdes.String(), Serdes.String()) + .withThisStoreSupplier(thisStoreSupplier) + .withOtherStoreSupplier(otherStoreSupplier)); + + final TopologyDescription describe = builder.build().describe(); + + assertEquals( + "Topologies:\n" + + " Sub-topology: 0\n" + + " Source: KSTREAM-SOURCE-0000000000 (topics: [input-topic1])\n" + + " --> KSTREAM-WINDOWED-0000000002\n" + + " Source: KSTREAM-SOURCE-0000000001 (topics: [input-topic2])\n" + + " --> KSTREAM-WINDOWED-0000000003\n" + + " Processor: KSTREAM-WINDOWED-0000000002 (stores: [in-memory-join-store])\n" + + " --> KSTREAM-JOINTHIS-0000000004\n" + + " <-- KSTREAM-SOURCE-0000000000\n" + + " Processor: KSTREAM-WINDOWED-0000000003 (stores: [in-memory-join-store-other])\n" + + " --> KSTREAM-JOINOTHER-0000000005\n" + + " <-- KSTREAM-SOURCE-0000000001\n" + + " Processor: KSTREAM-JOINOTHER-0000000005 (stores: [in-memory-join-store])\n" + + " --> KSTREAM-MERGE-0000000006\n" + + " <-- KSTREAM-WINDOWED-0000000003\n" + + " Processor: KSTREAM-JOINTHIS-0000000004 (stores: [in-memory-join-store-other])\n" + + " --> KSTREAM-MERGE-0000000006\n" + + " <-- KSTREAM-WINDOWED-0000000002\n" + + " Processor: KSTREAM-MERGE-0000000006 (stores: [])\n" + + " --> none\n" + + " <-- KSTREAM-JOINTHIS-0000000004, KSTREAM-JOINOTHER-0000000005\n\n", + describe.toString()); + } + + @Test + public void streamStreamLeftJoinTopologyWithDefaultStoresNames() { + final StreamsBuilder builder = new StreamsBuilder(); + final KStream stream1; + final KStream stream2; + + stream1 = builder.stream("input-topic1"); + stream2 = builder.stream("input-topic2"); + + stream1.leftJoin( + stream2, + MockValueJoiner.TOSTRING_JOINER, + JoinWindows.ofTimeDifferenceWithNoGrace(ofMillis(100)), + StreamJoined.with(Serdes.Integer(), Serdes.String(), Serdes.String())); + + final TopologyDescription describe = builder.build().describe(); + + assertEquals( + "Topologies:\n" + + " Sub-topology: 0\n" + + " Source: KSTREAM-SOURCE-0000000000 (topics: [input-topic1])\n" + + " --> KSTREAM-WINDOWED-0000000002\n" + + " Source: KSTREAM-SOURCE-0000000001 (topics: [input-topic2])\n" + + " --> KSTREAM-WINDOWED-0000000003\n" + + " Processor: KSTREAM-WINDOWED-0000000002 (stores: [KSTREAM-JOINTHIS-0000000004-store])\n" + + " --> KSTREAM-JOINTHIS-0000000004\n" + + " <-- KSTREAM-SOURCE-0000000000\n" + + " Processor: KSTREAM-WINDOWED-0000000003 (stores: [KSTREAM-OUTEROTHER-0000000005-store])\n" + + " --> KSTREAM-OUTEROTHER-0000000005\n" + + " <-- KSTREAM-SOURCE-0000000001\n" + + " Processor: KSTREAM-JOINTHIS-0000000004 (stores: [KSTREAM-OUTEROTHER-0000000005-store, KSTREAM-OUTERSHARED-0000000004-store])\n" + + " --> KSTREAM-MERGE-0000000006\n" + + " <-- KSTREAM-WINDOWED-0000000002\n" + + " Processor: KSTREAM-OUTEROTHER-0000000005 (stores: [KSTREAM-JOINTHIS-0000000004-store, KSTREAM-OUTERSHARED-0000000004-store])\n" + + " --> KSTREAM-MERGE-0000000006\n" + + " <-- KSTREAM-WINDOWED-0000000003\n" + + " Processor: KSTREAM-MERGE-0000000006 (stores: [])\n" + + " --> none\n" + + " <-- 
KSTREAM-JOINTHIS-0000000004, KSTREAM-OUTEROTHER-0000000005\n\n", + describe.toString()); + } + + @Test + public void streamStreamLeftJoinTopologyWithCustomStoresNames() { + final StreamsBuilder builder = new StreamsBuilder(); + final KStream stream1; + final KStream stream2; + + stream1 = builder.stream("input-topic1"); + stream2 = builder.stream("input-topic2"); + + stream1.leftJoin( + stream2, + MockValueJoiner.TOSTRING_JOINER, + JoinWindows.ofTimeDifferenceWithNoGrace(ofMillis(100)), + StreamJoined.with(Serdes.Integer(), Serdes.String(), Serdes.String()) + .withStoreName("custom-name")); + + final TopologyDescription describe = builder.build().describe(); + + assertEquals( + "Topologies:\n" + + " Sub-topology: 0\n" + + " Source: KSTREAM-SOURCE-0000000000 (topics: [input-topic1])\n" + + " --> KSTREAM-WINDOWED-0000000002\n" + + " Source: KSTREAM-SOURCE-0000000001 (topics: [input-topic2])\n" + + " --> KSTREAM-WINDOWED-0000000003\n" + + " Processor: KSTREAM-WINDOWED-0000000002 (stores: [custom-name-this-join-store])\n" + + " --> KSTREAM-JOINTHIS-0000000004\n" + + " <-- KSTREAM-SOURCE-0000000000\n" + + " Processor: KSTREAM-WINDOWED-0000000003 (stores: [custom-name-outer-other-join-store])\n" + + " --> KSTREAM-OUTEROTHER-0000000005\n" + + " <-- KSTREAM-SOURCE-0000000001\n" + + " Processor: KSTREAM-JOINTHIS-0000000004 (stores: [custom-name-outer-other-join-store, custom-name-left-shared-join-store])\n" + + " --> KSTREAM-MERGE-0000000006\n" + + " <-- KSTREAM-WINDOWED-0000000002\n" + + " Processor: KSTREAM-OUTEROTHER-0000000005 (stores: [custom-name-this-join-store, custom-name-left-shared-join-store])\n" + + " --> KSTREAM-MERGE-0000000006\n" + + " <-- KSTREAM-WINDOWED-0000000003\n" + + " Processor: KSTREAM-MERGE-0000000006 (stores: [])\n" + + " --> none\n" + + " <-- KSTREAM-JOINTHIS-0000000004, KSTREAM-OUTEROTHER-0000000005\n\n", + describe.toString()); + } + + @Test + public void streamStreamLeftJoinTopologyWithCustomStoresSuppliers() { + final StreamsBuilder builder = new StreamsBuilder(); + final KStream stream1; + final KStream stream2; + + stream1 = builder.stream("input-topic1"); + stream2 = builder.stream("input-topic2"); + + final JoinWindows joinWindows = JoinWindows.ofTimeDifferenceWithNoGrace(ofMillis(100)); + + final WindowBytesStoreSupplier thisStoreSupplier = Stores.inMemoryWindowStore("in-memory-join-store", + Duration.ofMillis(joinWindows.size() + joinWindows.gracePeriodMs()), + Duration.ofMillis(joinWindows.size()), true); + + final WindowBytesStoreSupplier otherStoreSupplier = Stores.inMemoryWindowStore("in-memory-join-store-other", + Duration.ofMillis(joinWindows.size() + joinWindows.gracePeriodMs()), + Duration.ofMillis(joinWindows.size()), true); + + stream1.leftJoin( + stream2, + MockValueJoiner.TOSTRING_JOINER, + joinWindows, + StreamJoined.with(Serdes.Integer(), Serdes.String(), Serdes.String()) + .withThisStoreSupplier(thisStoreSupplier) + .withOtherStoreSupplier(otherStoreSupplier)); + + final TopologyDescription describe = builder.build().describe(); + + assertEquals( + "Topologies:\n" + + " Sub-topology: 0\n" + + " Source: KSTREAM-SOURCE-0000000000 (topics: [input-topic1])\n" + + " --> KSTREAM-WINDOWED-0000000002\n" + + " Source: KSTREAM-SOURCE-0000000001 (topics: [input-topic2])\n" + + " --> KSTREAM-WINDOWED-0000000003\n" + + " Processor: KSTREAM-WINDOWED-0000000002 (stores: [in-memory-join-store])\n" + + " --> KSTREAM-JOINTHIS-0000000004\n" + + " <-- KSTREAM-SOURCE-0000000000\n" + + " Processor: KSTREAM-WINDOWED-0000000003 (stores: [in-memory-join-store-other])\n" 
+ + " --> KSTREAM-OUTEROTHER-0000000005\n" + + " <-- KSTREAM-SOURCE-0000000001\n" + + " Processor: KSTREAM-JOINTHIS-0000000004 (stores: [in-memory-join-store-other, in-memory-join-store-left-shared-join-store])\n" + + " --> KSTREAM-MERGE-0000000006\n" + + " <-- KSTREAM-WINDOWED-0000000002\n" + + " Processor: KSTREAM-OUTEROTHER-0000000005 (stores: [in-memory-join-store, in-memory-join-store-left-shared-join-store])\n" + + " --> KSTREAM-MERGE-0000000006\n" + + " <-- KSTREAM-WINDOWED-0000000003\n" + + " Processor: KSTREAM-MERGE-0000000006 (stores: [])\n" + + " --> none\n" + + " <-- KSTREAM-JOINTHIS-0000000004, KSTREAM-OUTEROTHER-0000000005\n\n", + describe.toString()); + } + + @Test + public void streamStreamOuterJoinTopologyWithDefaultStoresNames() { + final StreamsBuilder builder = new StreamsBuilder(); + final KStream stream1; + final KStream stream2; + + stream1 = builder.stream("input-topic1"); + stream2 = builder.stream("input-topic2"); + + stream1.outerJoin( + stream2, + MockValueJoiner.TOSTRING_JOINER, + JoinWindows.ofTimeDifferenceWithNoGrace(ofMillis(100)), + StreamJoined.with(Serdes.Integer(), Serdes.String(), Serdes.String())); + + final TopologyDescription describe = builder.build().describe(); + + assertEquals( + "Topologies:\n" + + " Sub-topology: 0\n" + + " Source: KSTREAM-SOURCE-0000000000 (topics: [input-topic1])\n" + + " --> KSTREAM-WINDOWED-0000000002\n" + + " Source: KSTREAM-SOURCE-0000000001 (topics: [input-topic2])\n" + + " --> KSTREAM-WINDOWED-0000000003\n" + + " Processor: KSTREAM-WINDOWED-0000000002 (stores: [KSTREAM-OUTERTHIS-0000000004-store])\n" + + " --> KSTREAM-OUTERTHIS-0000000004\n" + + " <-- KSTREAM-SOURCE-0000000000\n" + + " Processor: KSTREAM-WINDOWED-0000000003 (stores: [KSTREAM-OUTEROTHER-0000000005-store])\n" + + " --> KSTREAM-OUTEROTHER-0000000005\n" + + " <-- KSTREAM-SOURCE-0000000001\n" + + " Processor: KSTREAM-OUTEROTHER-0000000005 (stores: [KSTREAM-OUTERTHIS-0000000004-store, KSTREAM-OUTERSHARED-0000000004-store])\n" + + " --> KSTREAM-MERGE-0000000006\n" + + " <-- KSTREAM-WINDOWED-0000000003\n" + + " Processor: KSTREAM-OUTERTHIS-0000000004 (stores: [KSTREAM-OUTEROTHER-0000000005-store, KSTREAM-OUTERSHARED-0000000004-store])\n" + + " --> KSTREAM-MERGE-0000000006\n" + + " <-- KSTREAM-WINDOWED-0000000002\n" + + " Processor: KSTREAM-MERGE-0000000006 (stores: [])\n" + + " --> none\n" + + " <-- KSTREAM-OUTERTHIS-0000000004, KSTREAM-OUTEROTHER-0000000005\n\n", + describe.toString()); + } + + @Test + public void streamStreamOuterJoinTopologyWithCustomStoresNames() { + final StreamsBuilder builder = new StreamsBuilder(); + final KStream stream1; + final KStream stream2; + + stream1 = builder.stream("input-topic1"); + stream2 = builder.stream("input-topic2"); + + stream1.outerJoin( + stream2, + MockValueJoiner.TOSTRING_JOINER, + JoinWindows.ofTimeDifferenceWithNoGrace(ofMillis(100)), + StreamJoined.with(Serdes.Integer(), Serdes.String(), Serdes.String()) + .withStoreName("custom-name")); + + final TopologyDescription describe = builder.build().describe(); + + assertEquals( + "Topologies:\n" + + " Sub-topology: 0\n" + + " Source: KSTREAM-SOURCE-0000000000 (topics: [input-topic1])\n" + + " --> KSTREAM-WINDOWED-0000000002\n" + + " Source: KSTREAM-SOURCE-0000000001 (topics: [input-topic2])\n" + + " --> KSTREAM-WINDOWED-0000000003\n" + + " Processor: KSTREAM-WINDOWED-0000000002 (stores: [custom-name-outer-this-join-store])\n" + + " --> KSTREAM-OUTERTHIS-0000000004\n" + + " <-- KSTREAM-SOURCE-0000000000\n" + + " Processor: KSTREAM-WINDOWED-0000000003 (stores: 
[custom-name-outer-other-join-store])\n" + + " --> KSTREAM-OUTEROTHER-0000000005\n" + + " <-- KSTREAM-SOURCE-0000000001\n" + + " Processor: KSTREAM-OUTEROTHER-0000000005 (stores: [custom-name-outer-this-join-store, custom-name-outer-shared-join-store])\n" + + " --> KSTREAM-MERGE-0000000006\n" + + " <-- KSTREAM-WINDOWED-0000000003\n" + + " Processor: KSTREAM-OUTERTHIS-0000000004 (stores: [custom-name-outer-other-join-store, custom-name-outer-shared-join-store])\n" + + " --> KSTREAM-MERGE-0000000006\n" + + " <-- KSTREAM-WINDOWED-0000000002\n" + + " Processor: KSTREAM-MERGE-0000000006 (stores: [])\n" + + " --> none\n" + + " <-- KSTREAM-OUTERTHIS-0000000004, KSTREAM-OUTEROTHER-0000000005\n\n", + describe.toString()); + } + + @Test + public void streamStreamOuterJoinTopologyWithCustomStoresSuppliers() { + final StreamsBuilder builder = new StreamsBuilder(); + final KStream stream1; + final KStream stream2; + + stream1 = builder.stream("input-topic1"); + stream2 = builder.stream("input-topic2"); + + final JoinWindows joinWindows = JoinWindows.ofTimeDifferenceWithNoGrace(ofMillis(100)); + + final WindowBytesStoreSupplier thisStoreSupplier = Stores.inMemoryWindowStore("in-memory-join-store", + Duration.ofMillis(joinWindows.size() + joinWindows.gracePeriodMs()), + Duration.ofMillis(joinWindows.size()), true); + + final WindowBytesStoreSupplier otherStoreSupplier = Stores.inMemoryWindowStore("in-memory-join-store-other", + Duration.ofMillis(joinWindows.size() + joinWindows.gracePeriodMs()), + Duration.ofMillis(joinWindows.size()), true); + + stream1.outerJoin( + stream2, + MockValueJoiner.TOSTRING_JOINER, + joinWindows, + StreamJoined.with(Serdes.Integer(), Serdes.String(), Serdes.String()) + .withThisStoreSupplier(thisStoreSupplier) + .withOtherStoreSupplier(otherStoreSupplier)); + + final TopologyDescription describe = builder.build().describe(); + + assertEquals( + "Topologies:\n" + + " Sub-topology: 0\n" + + " Source: KSTREAM-SOURCE-0000000000 (topics: [input-topic1])\n" + + " --> KSTREAM-WINDOWED-0000000002\n" + + " Source: KSTREAM-SOURCE-0000000001 (topics: [input-topic2])\n" + + " --> KSTREAM-WINDOWED-0000000003\n" + + " Processor: KSTREAM-WINDOWED-0000000002 (stores: [in-memory-join-store])\n" + + " --> KSTREAM-OUTERTHIS-0000000004\n" + + " <-- KSTREAM-SOURCE-0000000000\n" + + " Processor: KSTREAM-WINDOWED-0000000003 (stores: [in-memory-join-store-other])\n" + + " --> KSTREAM-OUTEROTHER-0000000005\n" + + " <-- KSTREAM-SOURCE-0000000001\n" + + " Processor: KSTREAM-OUTEROTHER-0000000005 (stores: [in-memory-join-store-outer-shared-join-store, in-memory-join-store])\n" + + " --> KSTREAM-MERGE-0000000006\n" + + " <-- KSTREAM-WINDOWED-0000000003\n" + + " Processor: KSTREAM-OUTERTHIS-0000000004 (stores: [in-memory-join-store-other, in-memory-join-store-outer-shared-join-store])\n" + + " --> KSTREAM-MERGE-0000000006\n" + + " <-- KSTREAM-WINDOWED-0000000002\n" + + " Processor: KSTREAM-MERGE-0000000006 (stores: [])\n" + + " --> none\n" + + " <-- KSTREAM-OUTERTHIS-0000000004, KSTREAM-OUTEROTHER-0000000005\n\n", + describe.toString()); + } + + @Test + public void topologyWithDynamicRoutingShouldDescribeExtractorClass() { + final StreamsBuilder builder = new StreamsBuilder(); + + final TopicNameExtractor topicNameExtractor = new TopicNameExtractor() { + @Override + public String extract(final Object key, final Object value, final RecordContext recordContext) { + return recordContext.topic() + "-" + key; + } + + @Override + public String toString() { + return "anonymous topic name extractor. 
topic is [recordContext.topic()]-[key]"; + } + }; + builder.stream("input-topic").to(topicNameExtractor); + final TopologyDescription describe = builder.build().describe(); + + assertEquals( + "Topologies:\n" + + " Sub-topology: 0\n" + + " Source: KSTREAM-SOURCE-0000000000 (topics: [input-topic])\n" + + " --> KSTREAM-SINK-0000000001\n" + + " Sink: KSTREAM-SINK-0000000001 (extractor class: anonymous topic name extractor. topic is [recordContext.topic()]-[key])\n" + + " <-- KSTREAM-SOURCE-0000000000\n\n", + describe.toString()); + } + + @Test + public void kGroupedStreamZeroArgCountShouldPreserveTopologyStructure() { + final StreamsBuilder builder = new StreamsBuilder(); + builder.stream("input-topic") + .groupByKey() + .count(); + final TopologyDescription describe = builder.build().describe(); + assertEquals( + "Topologies:\n" + + " Sub-topology: 0\n" + + " Source: KSTREAM-SOURCE-0000000000 (topics: [input-topic])\n" + + " --> KSTREAM-AGGREGATE-0000000002\n" + + " Processor: KSTREAM-AGGREGATE-0000000002 (stores: [KSTREAM-AGGREGATE-STATE-STORE-0000000001])\n" + + " --> none\n" + + " <-- KSTREAM-SOURCE-0000000000\n\n", + describe.toString() + ); + } + + @Test + public void kGroupedStreamNamedMaterializedCountShouldPreserveTopologyStructure() { + final StreamsBuilder builder = new StreamsBuilder(); + builder.stream("input-topic") + .groupByKey() + .count(Materialized.as("count-store")); + final TopologyDescription describe = builder.build().describe(); + assertEquals( + "Topologies:\n" + + " Sub-topology: 0\n" + + " Source: KSTREAM-SOURCE-0000000000 (topics: [input-topic])\n" + + " --> KSTREAM-AGGREGATE-0000000001\n" + + " Processor: KSTREAM-AGGREGATE-0000000001 (stores: [count-store])\n" + + " --> none\n" + + " <-- KSTREAM-SOURCE-0000000000\n\n", + describe.toString() + ); + } + + @Test + public void kGroupedStreamAnonymousMaterializedCountShouldPreserveTopologyStructure() { + final StreamsBuilder builder = new StreamsBuilder(); + builder.stream("input-topic") + .groupByKey() + .count(Materialized.with(null, Serdes.Long())); + final TopologyDescription describe = builder.build().describe(); + assertEquals( + "Topologies:\n" + + " Sub-topology: 0\n" + + " Source: KSTREAM-SOURCE-0000000000 (topics: [input-topic])\n" + + " --> KSTREAM-AGGREGATE-0000000003\n" + + " Processor: KSTREAM-AGGREGATE-0000000003 (stores: [KSTREAM-AGGREGATE-STATE-STORE-0000000002])\n" + + " --> none\n" + + " <-- KSTREAM-SOURCE-0000000000\n\n", + describe.toString() + ); + } + + @Test + public void timeWindowZeroArgCountShouldPreserveTopologyStructure() { + final StreamsBuilder builder = new StreamsBuilder(); + builder.stream("input-topic") + .groupByKey() + .windowedBy(TimeWindows.of(ofMillis(1))) + .count(); + final TopologyDescription describe = builder.build().describe(); + assertEquals( + "Topologies:\n" + + " Sub-topology: 0\n" + + " Source: KSTREAM-SOURCE-0000000000 (topics: [input-topic])\n" + + " --> KSTREAM-AGGREGATE-0000000002\n" + + " Processor: KSTREAM-AGGREGATE-0000000002 (stores: [KSTREAM-AGGREGATE-STATE-STORE-0000000001])\n" + + " --> none\n" + + " <-- KSTREAM-SOURCE-0000000000\n\n", + describe.toString() + ); + } + + @Test + public void timeWindowNamedMaterializedCountShouldPreserveTopologyStructure() { + final StreamsBuilder builder = new StreamsBuilder(); + builder.stream("input-topic") + .groupByKey() + .windowedBy(TimeWindows.of(ofMillis(1))) + .count(Materialized.as("count-store")); + final TopologyDescription describe = builder.build().describe(); + assertEquals( + "Topologies:\n" + + " 
Sub-topology: 0\n" + + " Source: KSTREAM-SOURCE-0000000000 (topics: [input-topic])\n" + + " --> KSTREAM-AGGREGATE-0000000001\n" + + " Processor: KSTREAM-AGGREGATE-0000000001 (stores: [count-store])\n" + + " --> none\n" + + " <-- KSTREAM-SOURCE-0000000000\n\n", + describe.toString() + ); + } + + @Test + public void timeWindowAnonymousMaterializedCountShouldPreserveTopologyStructure() { + final StreamsBuilder builder = new StreamsBuilder(); + builder.stream("input-topic") + .groupByKey() + .windowedBy(TimeWindows.of(ofMillis(1))) + .count(Materialized.with(null, Serdes.Long())); + final TopologyDescription describe = builder.build().describe(); + assertEquals( + "Topologies:\n" + + " Sub-topology: 0\n" + + " Source: KSTREAM-SOURCE-0000000000 (topics: [input-topic])\n" + + " --> KSTREAM-AGGREGATE-0000000003\n" + + " Processor: KSTREAM-AGGREGATE-0000000003 (stores: [KSTREAM-AGGREGATE-STATE-STORE-0000000002])\n" + + " --> none\n" + + " <-- KSTREAM-SOURCE-0000000000\n\n", + describe.toString() + ); + } + + @Test + public void sessionWindowZeroArgCountShouldPreserveTopologyStructure() { + final StreamsBuilder builder = new StreamsBuilder(); + builder.stream("input-topic") + .groupByKey() + .windowedBy(SessionWindows.with(ofMillis(1))) + .count(); + final TopologyDescription describe = builder.build().describe(); + assertEquals( + "Topologies:\n" + + " Sub-topology: 0\n" + + " Source: KSTREAM-SOURCE-0000000000 (topics: [input-topic])\n" + + " --> KSTREAM-AGGREGATE-0000000002\n" + + " Processor: KSTREAM-AGGREGATE-0000000002 (stores: [KSTREAM-AGGREGATE-STATE-STORE-0000000001])\n" + + " --> none\n" + + " <-- KSTREAM-SOURCE-0000000000\n\n", + describe.toString() + ); + } + + @Test + public void sessionWindowNamedMaterializedCountShouldPreserveTopologyStructure() { + final StreamsBuilder builder = new StreamsBuilder(); + builder.stream("input-topic") + .groupByKey() + .windowedBy(SessionWindows.with(ofMillis(1))) + .count(Materialized.as("count-store")); + final TopologyDescription describe = builder.build().describe(); + assertEquals( + "Topologies:\n" + + " Sub-topology: 0\n" + + " Source: KSTREAM-SOURCE-0000000000 (topics: [input-topic])\n" + + " --> KSTREAM-AGGREGATE-0000000001\n" + + " Processor: KSTREAM-AGGREGATE-0000000001 (stores: [count-store])\n" + + " --> none\n" + + " <-- KSTREAM-SOURCE-0000000000\n\n", + describe.toString() + ); + } + + @Test + public void sessionWindowAnonymousMaterializedCountShouldPreserveTopologyStructure() { + final StreamsBuilder builder = new StreamsBuilder(); + builder.stream("input-topic") + .groupByKey() + .windowedBy(SessionWindows.with(ofMillis(1))) + .count(Materialized.with(null, Serdes.Long())); + final TopologyDescription describe = builder.build().describe(); + assertEquals( + "Topologies:\n" + + " Sub-topology: 0\n" + + " Source: KSTREAM-SOURCE-0000000000 (topics: [input-topic])\n" + + " --> KSTREAM-AGGREGATE-0000000003\n" + + " Processor: KSTREAM-AGGREGATE-0000000003 (stores: [KSTREAM-AGGREGATE-STATE-STORE-0000000002])\n" + + " --> none\n" + + " <-- KSTREAM-SOURCE-0000000000\n\n", + describe.toString() + ); + } + + @Test + public void tableZeroArgCountShouldPreserveTopologyStructure() { + final StreamsBuilder builder = new StreamsBuilder(); + builder.table("input-topic") + .groupBy((key, value) -> null) + .count(); + final TopologyDescription describe = builder.build().describe(); + + assertEquals( + "Topologies:\n" + + " Sub-topology: 0\n" + + " Source: KSTREAM-SOURCE-0000000001 (topics: [input-topic])\n" + + " --> KTABLE-SOURCE-0000000002\n" + + " 
Processor: KTABLE-SOURCE-0000000002 (stores: [input-topic-STATE-STORE-0000000000])\n" + + " --> KTABLE-SELECT-0000000003\n" + + " <-- KSTREAM-SOURCE-0000000001\n" + + " Processor: KTABLE-SELECT-0000000003 (stores: [])\n" + + " --> KSTREAM-SINK-0000000005\n" + + " <-- KTABLE-SOURCE-0000000002\n" + + " Sink: KSTREAM-SINK-0000000005 (topic: KTABLE-AGGREGATE-STATE-STORE-0000000004-repartition)\n" + + " <-- KTABLE-SELECT-0000000003\n" + + "\n" + + " Sub-topology: 1\n" + + " Source: KSTREAM-SOURCE-0000000006 (topics: [KTABLE-AGGREGATE-STATE-STORE-0000000004-repartition])\n" + + " --> KTABLE-AGGREGATE-0000000007\n" + + " Processor: KTABLE-AGGREGATE-0000000007 (stores: [KTABLE-AGGREGATE-STATE-STORE-0000000004])\n" + + " --> none\n" + + " <-- KSTREAM-SOURCE-0000000006\n" + + "\n", + describe.toString() + ); + } + + @Test + public void tableNamedMaterializedCountShouldPreserveTopologyStructure() { + final StreamsBuilder builder = new StreamsBuilder(); + builder.table("input-topic") + .groupBy((key, value) -> null) + .count(Materialized.as("count-store")); + final TopologyDescription describe = builder.build().describe(); + assertEquals( + "Topologies:\n" + + " Sub-topology: 0\n" + + " Source: KSTREAM-SOURCE-0000000001 (topics: [input-topic])\n" + + " --> KTABLE-SOURCE-0000000002\n" + + " Processor: KTABLE-SOURCE-0000000002 (stores: [input-topic-STATE-STORE-0000000000])\n" + + " --> KTABLE-SELECT-0000000003\n" + + " <-- KSTREAM-SOURCE-0000000001\n" + + " Processor: KTABLE-SELECT-0000000003 (stores: [])\n" + + " --> KSTREAM-SINK-0000000004\n" + + " <-- KTABLE-SOURCE-0000000002\n" + + " Sink: KSTREAM-SINK-0000000004 (topic: count-store-repartition)\n" + + " <-- KTABLE-SELECT-0000000003\n" + + "\n" + + " Sub-topology: 1\n" + + " Source: KSTREAM-SOURCE-0000000005 (topics: [count-store-repartition])\n" + + " --> KTABLE-AGGREGATE-0000000006\n" + + " Processor: KTABLE-AGGREGATE-0000000006 (stores: [count-store])\n" + + " --> none\n" + + " <-- KSTREAM-SOURCE-0000000005\n" + + "\n", + describe.toString() + ); + } + + @Test + public void tableAnonymousMaterializedCountShouldPreserveTopologyStructure() { + final StreamsBuilder builder = new StreamsBuilder(); + builder.table("input-topic") + .groupBy((key, value) -> null) + .count(Materialized.with(null, Serdes.Long())); + final TopologyDescription describe = builder.build().describe(); + assertEquals( + "Topologies:\n" + + " Sub-topology: 0\n" + + " Source: KSTREAM-SOURCE-0000000001 (topics: [input-topic])\n" + + " --> KTABLE-SOURCE-0000000002\n" + + " Processor: KTABLE-SOURCE-0000000002 (stores: [input-topic-STATE-STORE-0000000000])\n" + + " --> KTABLE-SELECT-0000000003\n" + + " <-- KSTREAM-SOURCE-0000000001\n" + + " Processor: KTABLE-SELECT-0000000003 (stores: [])\n" + + " --> KSTREAM-SINK-0000000005\n" + + " <-- KTABLE-SOURCE-0000000002\n" + + " Sink: KSTREAM-SINK-0000000005 (topic: KTABLE-AGGREGATE-STATE-STORE-0000000004-repartition)\n" + + " <-- KTABLE-SELECT-0000000003\n" + + "\n" + + " Sub-topology: 1\n" + + " Source: KSTREAM-SOURCE-0000000006 (topics: [KTABLE-AGGREGATE-STATE-STORE-0000000004-repartition])\n" + + " --> KTABLE-AGGREGATE-0000000007\n" + + " Processor: KTABLE-AGGREGATE-0000000007 (stores: [KTABLE-AGGREGATE-STATE-STORE-0000000004])\n" + + " --> none\n" + + " <-- KSTREAM-SOURCE-0000000006\n" + + "\n", + describe.toString() + ); + } + + @Test + public void kTableNonMaterializedMapValuesShouldPreserveTopologyStructure() { + final StreamsBuilder builder = new StreamsBuilder(); + final KTable table = builder.table("input-topic"); + 
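+        // Without a Materialized argument, mapValues() is not materialized, so the expected description below lists "stores: []" for the KTABLE-MAPVALUES processor.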
table.mapValues((readOnlyKey, value) -> null); + final TopologyDescription describe = builder.build().describe(); + Assert.assertEquals( + "Topologies:\n" + + " Sub-topology: 0\n" + + " Source: KSTREAM-SOURCE-0000000001 (topics: [input-topic])\n" + + " --> KTABLE-SOURCE-0000000002\n" + + " Processor: KTABLE-SOURCE-0000000002 (stores: [])\n" + + " --> KTABLE-MAPVALUES-0000000003\n" + + " <-- KSTREAM-SOURCE-0000000001\n" + + " Processor: KTABLE-MAPVALUES-0000000003 (stores: [])\n" + + " --> none\n" + + " <-- KTABLE-SOURCE-0000000002\n\n", + describe.toString()); + } + + @Test + public void kTableAnonymousMaterializedMapValuesShouldPreserveTopologyStructure() { + final StreamsBuilder builder = new StreamsBuilder(); + final KTable table = builder.table("input-topic"); + table.mapValues( + (readOnlyKey, value) -> null, + Materialized.with(null, null)); + final TopologyDescription describe = builder.build().describe(); + Assert.assertEquals( + "Topologies:\n" + + " Sub-topology: 0\n" + + " Source: KSTREAM-SOURCE-0000000001 (topics: [input-topic])\n" + + " --> KTABLE-SOURCE-0000000002\n" + + " Processor: KTABLE-SOURCE-0000000002 (stores: [])\n" + + " --> KTABLE-MAPVALUES-0000000004\n" + + " <-- KSTREAM-SOURCE-0000000001\n" + + // previously, this was + // Processor: KTABLE-MAPVALUES-0000000004 (stores: [KTABLE-MAPVALUES-STATE-STORE-0000000003] + // but we added a change not to materialize non-queryable stores. This change shouldn't break compatibility. + " Processor: KTABLE-MAPVALUES-0000000004 (stores: [])\n" + + " --> none\n" + + " <-- KTABLE-SOURCE-0000000002\n" + + "\n", + describe.toString()); + } + + @Test + public void kTableNamedMaterializedMapValuesShouldPreserveTopologyStructure() { + final StreamsBuilder builder = new StreamsBuilder(); + final KTable table = builder.table("input-topic"); + table.mapValues( + (readOnlyKey, value) -> null, + Materialized.>as("store-name").withKeySerde(null).withValueSerde(null)); + final TopologyDescription describe = builder.build().describe(); + Assert.assertEquals( + "Topologies:\n" + + " Sub-topology: 0\n" + + " Source: KSTREAM-SOURCE-0000000001 (topics: [input-topic])\n" + + " --> KTABLE-SOURCE-0000000002\n" + + " Processor: KTABLE-SOURCE-0000000002 (stores: [])\n" + + " --> KTABLE-MAPVALUES-0000000003\n" + + " <-- KSTREAM-SOURCE-0000000001\n" + + " Processor: KTABLE-MAPVALUES-0000000003 (stores: [store-name])\n" + + " --> none\n" + + " <-- KTABLE-SOURCE-0000000002\n" + + "\n", + describe.toString()); + } + + @Test + public void kTableNonMaterializedFilterShouldPreserveTopologyStructure() { + final StreamsBuilder builder = new StreamsBuilder(); + final KTable table = builder.table("input-topic"); + table.filter((key, value) -> false); + final TopologyDescription describe = builder.build().describe(); + Assert.assertEquals( + "Topologies:\n" + + " Sub-topology: 0\n" + + " Source: KSTREAM-SOURCE-0000000001 (topics: [input-topic])\n" + + " --> KTABLE-SOURCE-0000000002\n" + + " Processor: KTABLE-SOURCE-0000000002 (stores: [])\n" + + " --> KTABLE-FILTER-0000000003\n" + + " <-- KSTREAM-SOURCE-0000000001\n" + + " Processor: KTABLE-FILTER-0000000003 (stores: [])\n" + + " --> none\n" + + " <-- KTABLE-SOURCE-0000000002\n\n", + describe.toString()); + } + + @Test + public void kTableAnonymousMaterializedFilterShouldPreserveTopologyStructure() { + final StreamsBuilder builder = new StreamsBuilder(); + final KTable table = builder.table("input-topic"); + table.filter((key, value) -> false, Materialized.with(null, null)); + final TopologyDescription describe = 
builder.build().describe(); + Assert.assertEquals( + "Topologies:\n" + + " Sub-topology: 0\n" + + " Source: KSTREAM-SOURCE-0000000001 (topics: [input-topic])\n" + + " --> KTABLE-SOURCE-0000000002\n" + + " Processor: KTABLE-SOURCE-0000000002 (stores: [])\n" + + " --> KTABLE-FILTER-0000000004\n" + + " <-- KSTREAM-SOURCE-0000000001\n" + + // Previously, this was + // Processor: KTABLE-FILTER-0000000004 (stores: [KTABLE-FILTER-STATE-STORE-0000000003] + // but we added a change not to materialize non-queryable stores. This change shouldn't break compatibility. + " Processor: KTABLE-FILTER-0000000004 (stores: [])\n" + + " --> none\n" + + " <-- KTABLE-SOURCE-0000000002\n" + + "\n", + describe.toString()); + } + + @Test + public void kTableNamedMaterializedFilterShouldPreserveTopologyStructure() { + final StreamsBuilder builder = new StreamsBuilder(); + final KTable table = builder.table("input-topic"); + table.filter((key, value) -> false, Materialized.as("store-name")); + final TopologyDescription describe = builder.build().describe(); + + Assert.assertEquals( + "Topologies:\n" + + " Sub-topology: 0\n" + + " Source: KSTREAM-SOURCE-0000000001 (topics: [input-topic])\n" + + " --> KTABLE-SOURCE-0000000002\n" + + " Processor: KTABLE-SOURCE-0000000002 (stores: [])\n" + + " --> KTABLE-FILTER-0000000003\n" + + " <-- KSTREAM-SOURCE-0000000001\n" + + " Processor: KTABLE-FILTER-0000000003 (stores: [store-name])\n" + + " --> none\n" + + " <-- KTABLE-SOURCE-0000000002\n" + + "\n", + describe.toString()); + } + + @Test + public void topologyWithStaticTopicNameExtractorShouldRespectEqualHashcodeContract() { + final Topology topologyA = topologyWithStaticTopicName(); + final Topology topologyB = topologyWithStaticTopicName(); + assertThat(topologyA.describe(), equalTo(topologyB.describe())); + assertThat(topologyA.describe().hashCode(), equalTo(topologyB.describe().hashCode())); + } + + private Topology topologyWithStaticTopicName() { + final StreamsBuilder builder = new StreamsBuilder(); + builder.stream("from-topic-name").to("to-topic-name"); + return builder.build(); + } + + private TopologyDescription.Source addSource(final String sourceName, + final String... sourceTopic) { + topology.addSource(null, sourceName, null, null, null, sourceTopic); + final StringBuilder allSourceTopics = new StringBuilder(sourceTopic[0]); + for (int i = 1; i < sourceTopic.length; ++i) { + allSourceTopics.append(", ").append(sourceTopic[i]); + } + return new InternalTopologyBuilder.Source(sourceName, new HashSet<>(Arrays.asList(sourceTopic)), null); + } + + private TopologyDescription.Source addSource(final String sourceName, + final Pattern sourcePattern) { + topology.addSource(null, sourceName, null, null, null, sourcePattern); + return new InternalTopologyBuilder.Source(sourceName, null, sourcePattern); + } + + private TopologyDescription.Processor addProcessor(final String processorName, + final TopologyDescription.Node... parents) { + return addProcessorWithNewStore(processorName, new String[0], parents); + } + + private TopologyDescription.Processor addProcessorWithNewStore(final String processorName, + final String[] storeNames, + final TopologyDescription.Node... parents) { + return addProcessorWithStore(processorName, storeNames, true, parents); + } + + private TopologyDescription.Processor addProcessorWithExistingStore(final String processorName, + final String[] storeNames, + final TopologyDescription.Node... 
parents) { + return addProcessorWithStore(processorName, storeNames, false, parents); + } + + private TopologyDescription.Processor addProcessorWithStore(final String processorName, + final String[] storeNames, + final boolean newStores, + final TopologyDescription.Node... parents) { + final String[] parentNames = new String[parents.length]; + for (int i = 0; i < parents.length; ++i) { + parentNames[i] = parents[i].name(); + } + + topology.addProcessor(processorName, new MockApiProcessorSupplier<>(), parentNames); + if (newStores) { + for (final String store : storeNames) { + final StoreBuilder storeBuilder = EasyMock.createNiceMock(StoreBuilder.class); + EasyMock.expect(storeBuilder.name()).andReturn(store).anyTimes(); + EasyMock.replay(storeBuilder); + topology.addStateStore(storeBuilder, processorName); + } + } else { + topology.connectProcessorAndStateStores(processorName, storeNames); + } + final TopologyDescription.Processor expectedProcessorNode = + new InternalTopologyBuilder.Processor(processorName, new HashSet<>(Arrays.asList(storeNames))); + + for (final TopologyDescription.Node parent : parents) { + ((InternalTopologyBuilder.AbstractNode) parent).addSuccessor(expectedProcessorNode); + ((InternalTopologyBuilder.AbstractNode) expectedProcessorNode).addPredecessor(parent); + } + + return expectedProcessorNode; + } + + private TopologyDescription.Sink addSink(final String sinkName, + final String sinkTopic, + final TopologyDescription.Node... parents) { + final String[] parentNames = new String[parents.length]; + for (int i = 0; i < parents.length; ++i) { + parentNames[i] = parents[i].name(); + } + + topology.addSink(sinkName, sinkTopic, null, null, null, parentNames); + final TopologyDescription.Sink expectedSinkNode = + new InternalTopologyBuilder.Sink(sinkName, sinkTopic); + + for (final TopologyDescription.Node parent : parents) { + ((InternalTopologyBuilder.AbstractNode) parent).addSuccessor(expectedSinkNode); + ((InternalTopologyBuilder.AbstractNode) expectedSinkNode).addPredecessor(parent); + } + + return expectedSinkNode; + } + + @Deprecated // testing old PAPI + private void addGlobalStoreToTopologyAndExpectedDescription(final String globalStoreName, + final String sourceName, + final String globalTopicName, + final String processorName, + final int id) { + final KeyValueStoreBuilder globalStoreBuilder = EasyMock.createNiceMock(KeyValueStoreBuilder.class); + EasyMock.expect(globalStoreBuilder.name()).andReturn(globalStoreName).anyTimes(); + EasyMock.replay(globalStoreBuilder); + topology.addGlobalStore( + globalStoreBuilder, + sourceName, + null, + null, + null, + globalTopicName, + processorName, + new MockProcessorSupplier<>()); + + final TopologyDescription.GlobalStore expectedGlobalStore = new InternalTopologyBuilder.GlobalStore( + sourceName, + processorName, + globalStoreName, + globalTopicName, + id); + + expectedDescription.addGlobalStore(expectedGlobalStore); + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/TopologyTestDriverWrapper.java b/streams/src/test/java/org/apache/kafka/streams/TopologyTestDriverWrapper.java new file mode 100644 index 0000000..e91b007 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/TopologyTestDriverWrapper.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams; + +import org.apache.kafka.streams.errors.StreamsException; +import org.apache.kafka.streams.processor.ProcessorContext; +import org.apache.kafka.streams.processor.internals.ProcessorContextImpl; +import org.apache.kafka.streams.processor.internals.ProcessorNode; + +import java.util.Properties; + +/** + * This class provides access to {@link TopologyTestDriver} protected methods. + * It should only be used for internal testing, in the rare occasions where the + * necessary functionality is not supported by {@link TopologyTestDriver}. + */ +public class TopologyTestDriverWrapper extends TopologyTestDriver { + + + public TopologyTestDriverWrapper(final Topology topology, + final Properties config) { + super(topology, config); + } + + /** + * Get the processor context, setting the processor whose name is given as current node + * + * @param processorName processor name to set as current node + * @return the processor context + */ + public ProcessorContext setCurrentNodeForProcessorContext(final String processorName) { + final ProcessorContext context = task.processorContext(); + ((ProcessorContextImpl) context).setCurrentNode(getProcessor(processorName)); + return context; + } + + /** + * Get a processor by name + * + * @param name the name to search for + * @return the processor matching the search name + */ + public ProcessorNode getProcessor(final String name) { + for (final ProcessorNode node : processorTopology.processors()) { + if (node.name().equals(name)) { + return node; + } + } + for (final ProcessorNode node : globalTopology.processors()) { + if (node.name().equals(name)) { + return node; + } + } + throw new StreamsException("Could not find a processor named '" + name + "'"); + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/TopologyWrapper.java b/streams/src/test/java/org/apache/kafka/streams/TopologyWrapper.java new file mode 100644 index 0000000..e1c7c11 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/TopologyWrapper.java @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams; + +import org.apache.kafka.streams.processor.internals.InternalTopologyBuilder; +import org.apache.kafka.test.StreamsTestUtils; + +/** + * This class allows access to the {@link InternalTopologyBuilder} of a {@link Topology} object. + * + */ +public class TopologyWrapper extends Topology { + + static public InternalTopologyBuilder getInternalTopologyBuilder(final Topology topology) { + return topology.internalTopologyBuilder.rewriteTopology(new StreamsConfig(StreamsTestUtils.getStreamsConfig())); + } + + public InternalTopologyBuilder getInternalBuilder() { + return internalTopologyBuilder.rewriteTopology(new StreamsConfig(StreamsTestUtils.getStreamsConfig())); + } + + public InternalTopologyBuilder getInternalBuilder(final String applicationId) { + return internalTopologyBuilder.rewriteTopology(new StreamsConfig(StreamsTestUtils.getStreamsConfig(applicationId))); + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/errors/AlwaysContinueProductionExceptionHandler.java b/streams/src/test/java/org/apache/kafka/streams/errors/AlwaysContinueProductionExceptionHandler.java new file mode 100644 index 0000000..111874d --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/errors/AlwaysContinueProductionExceptionHandler.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.errors; + +import java.util.Map; +import org.apache.kafka.clients.producer.ProducerRecord; + +/** + * Production exception handler that always instructs streams to continue when an exception + * happens while attempting to produce result records. + */ +public class AlwaysContinueProductionExceptionHandler implements ProductionExceptionHandler { + @Override + public ProductionExceptionHandlerResponse handle(final ProducerRecord record, + final Exception exception) { + return ProductionExceptionHandlerResponse.CONTINUE; + } + + @Override + public void configure(final Map configs) { + // ignore + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/integration/AbstractJoinIntegrationTest.java b/streams/src/test/java/org/apache/kafka/streams/integration/AbstractJoinIntegrationTest.java new file mode 100644 index 0000000..d41cec0 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/integration/AbstractJoinIntegrationTest.java @@ -0,0 +1,232 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.integration; + +import org.apache.kafka.clients.consumer.ConsumerConfig; +import org.apache.kafka.common.serialization.LongDeserializer; +import org.apache.kafka.common.serialization.LongSerializer; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.serialization.StringDeserializer; +import org.apache.kafka.common.serialization.StringSerializer; +import org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.TestInputTopic; +import org.apache.kafka.streams.TestOutputTopic; +import org.apache.kafka.streams.TopologyTestDriver; +import org.apache.kafka.streams.kstream.ValueJoiner; +import org.apache.kafka.streams.state.KeyValueIterator; +import org.apache.kafka.streams.state.ReadOnlyKeyValueStore; +import org.apache.kafka.streams.state.ValueAndTimestamp; +import org.apache.kafka.streams.test.TestRecord; +import org.apache.kafka.test.IntegrationTest; +import org.apache.kafka.test.TestUtils; +import org.junit.BeforeClass; +import org.junit.Rule; +import org.junit.experimental.categories.Category; +import org.junit.rules.TemporaryFolder; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.HashMap; +import java.util.Iterator; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.Properties; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.core.Is.is; +import static org.hamcrest.core.IsEqual.equalTo; + +/** + * Tests all available joins of Kafka Streams DSL. 
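+ * <p>A fixed sequence of input records is piped into {@code INPUT_TOPIC_LEFT} and {@code INPUT_TOPIC_RIGHT} via a {@link TopologyTestDriver} and the emitted records are checked against expected results; every test is run once with caching enabled and once with caching disabled.
+ * <p>A concrete subclass builds the join under test on {@code builder}, for example (illustrative sketch only; the join variant, windows, and store configuration are supplied by each subclass):
+ * <pre>{@code
+ * final KStream<Long, String> left = builder.stream(INPUT_TOPIC_LEFT);
+ * final KStream<Long, String> right = builder.stream(INPUT_TOPIC_RIGHT);
+ * left.join(right, valueJoiner, JoinWindows.ofTimeDifferenceWithNoGrace(Duration.ofMillis(100)))
+ *     .to(OUTPUT_TOPIC);
+ * }</pre>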
+ */ +@Category({IntegrationTest.class}) +@RunWith(value = Parameterized.class) +public abstract class AbstractJoinIntegrationTest { + @Rule + public final TemporaryFolder testFolder = new TemporaryFolder(TestUtils.tempDirectory()); + + @Parameterized.Parameters(name = "caching enabled = {0}") + public static Collection data() { + final List values = new ArrayList<>(); + for (final boolean cacheEnabled : Arrays.asList(true, false)) { + values.add(new Object[]{cacheEnabled}); + } + return values; + } + + static String appID; + + private final MockTime time = new MockTime(); + private static final Long COMMIT_INTERVAL = 100L; + static final Properties STREAMS_CONFIG = new Properties(); + static final String INPUT_TOPIC_RIGHT = "inputTopicRight"; + static final String INPUT_TOPIC_LEFT = "inputTopicLeft"; + static final String OUTPUT_TOPIC = "outputTopic"; + static final long ANY_UNIQUE_KEY = 0L; + + StreamsBuilder builder; + + private final List> input = Arrays.asList( + new Input<>(INPUT_TOPIC_LEFT, null), + new Input<>(INPUT_TOPIC_RIGHT, null), + new Input<>(INPUT_TOPIC_LEFT, "A"), + new Input<>(INPUT_TOPIC_RIGHT, "a"), + new Input<>(INPUT_TOPIC_LEFT, "B"), + new Input<>(INPUT_TOPIC_RIGHT, "b"), + new Input<>(INPUT_TOPIC_LEFT, null), + new Input<>(INPUT_TOPIC_RIGHT, null), + new Input<>(INPUT_TOPIC_LEFT, "C"), + new Input<>(INPUT_TOPIC_RIGHT, "c"), + new Input<>(INPUT_TOPIC_RIGHT, null), + new Input<>(INPUT_TOPIC_LEFT, null), + new Input<>(INPUT_TOPIC_RIGHT, null), + new Input<>(INPUT_TOPIC_RIGHT, "d"), + new Input<>(INPUT_TOPIC_LEFT, "D") + ); + + final ValueJoiner valueJoiner = (value1, value2) -> value1 + "-" + value2; + + final boolean cacheEnabled; + + AbstractJoinIntegrationTest(final boolean cacheEnabled) { + this.cacheEnabled = cacheEnabled; + } + + @BeforeClass + public static void setupConfigsAndUtils() { + STREAMS_CONFIG.put(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "earliest"); + STREAMS_CONFIG.put(StreamsConfig.DEFAULT_KEY_SERDE_CLASS_CONFIG, Serdes.Long().getClass()); + STREAMS_CONFIG.put(StreamsConfig.DEFAULT_VALUE_SERDE_CLASS_CONFIG, Serdes.String().getClass()); + STREAMS_CONFIG.put(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG, COMMIT_INTERVAL); + } + + void prepareEnvironment() throws InterruptedException { + if (!cacheEnabled) { + STREAMS_CONFIG.put(StreamsConfig.CACHE_MAX_BYTES_BUFFERING_CONFIG, 0); + } + + STREAMS_CONFIG.put(StreamsConfig.STATE_DIR_CONFIG, testFolder.getRoot().getPath()); + } + + void runTestWithDriver(final List>> expectedResult) { + runTestWithDriver(expectedResult, null); + } + + void runTestWithDriver(final List>> expectedResult, final String storeName) { + try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(STREAMS_CONFIG), STREAMS_CONFIG)) { + final TestInputTopic right = driver.createInputTopic(INPUT_TOPIC_RIGHT, new LongSerializer(), new StringSerializer()); + final TestInputTopic left = driver.createInputTopic(INPUT_TOPIC_LEFT, new LongSerializer(), new StringSerializer()); + final TestOutputTopic outputTopic = driver.createOutputTopic(OUTPUT_TOPIC, new LongDeserializer(), new StringDeserializer()); + final Map> testInputTopicMap = new HashMap<>(); + + testInputTopicMap.put(INPUT_TOPIC_RIGHT, right); + testInputTopicMap.put(INPUT_TOPIC_LEFT, left); + + TestRecord expectedFinalResult = null; + + final long firstTimestamp = time.milliseconds(); + long eventTimestamp = firstTimestamp; + final Iterator>> resultIterator = expectedResult.iterator(); + for (final Input singleInputRecord : input) { + 
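+                // Pipe every record with a strictly increasing event timestamp (starting from the current mock time) so that join-window assignment is deterministic; after each input, the output topic is drained and compared against the expected records for that step.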
testInputTopicMap.get(singleInputRecord.topic).pipeInput(singleInputRecord.record.key, singleInputRecord.record.value, ++eventTimestamp); + + final List> expected = resultIterator.next(); + if (expected != null) { + final List> updatedExpected = new LinkedList<>(); + for (final TestRecord record : expected) { + updatedExpected.add(new TestRecord<>(record.key(), record.value(), null, firstTimestamp + record.timestamp())); + } + + final List> output = outputTopic.readRecordsToList(); + assertThat(output, equalTo(updatedExpected)); + expectedFinalResult = updatedExpected.get(expected.size() - 1); + } + } + + if (storeName != null) { + checkQueryableStore(storeName, expectedFinalResult, driver); + } + } + } + + void runTestWithDriver(final TestRecord expectedFinalResult, final String storeName) throws InterruptedException { + try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(STREAMS_CONFIG), STREAMS_CONFIG)) { + final TestInputTopic right = driver.createInputTopic(INPUT_TOPIC_RIGHT, new LongSerializer(), new StringSerializer()); + final TestInputTopic left = driver.createInputTopic(INPUT_TOPIC_LEFT, new LongSerializer(), new StringSerializer()); + final TestOutputTopic outputTopic = driver.createOutputTopic(OUTPUT_TOPIC, new LongDeserializer(), new StringDeserializer()); + final Map> testInputTopicMap = new HashMap<>(); + + testInputTopicMap.put(INPUT_TOPIC_RIGHT, right); + testInputTopicMap.put(INPUT_TOPIC_LEFT, left); + + final long firstTimestamp = time.milliseconds(); + long eventTimestamp = firstTimestamp; + + for (final Input singleInputRecord : input) { + testInputTopicMap.get(singleInputRecord.topic).pipeInput(singleInputRecord.record.key, singleInputRecord.record.value, ++eventTimestamp); + } + + final TestRecord updatedExpectedFinalResult = + new TestRecord<>( + expectedFinalResult.key(), + expectedFinalResult.value(), + null, + firstTimestamp + expectedFinalResult.timestamp()); + + final List> output = outputTopic.readRecordsToList(); + + assertThat(output.get(output.size() - 1), equalTo(updatedExpectedFinalResult)); + + if (storeName != null) { + checkQueryableStore(storeName, updatedExpectedFinalResult, driver); + } + } + } + + private void checkQueryableStore(final String queryableName, final TestRecord expectedFinalResult, final TopologyTestDriver driver) { + final ReadOnlyKeyValueStore> store = driver.getTimestampedKeyValueStore(queryableName); + + final KeyValueIterator> all = store.all(); + final KeyValue> onlyEntry = all.next(); + + try { + assertThat(onlyEntry.key, is(expectedFinalResult.key())); + assertThat(onlyEntry.value.value(), is(expectedFinalResult.value())); + assertThat(onlyEntry.value.timestamp(), is(expectedFinalResult.timestamp())); + assertThat(all.hasNext(), is(false)); + } finally { + all.close(); + } + } + + private static final class Input { + String topic; + KeyValue record; + + Input(final String topic, final V value) { + this.topic = topic; + record = KeyValue.pair(ANY_UNIQUE_KEY, value); + } + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/integration/AbstractResetIntegrationTest.java b/streams/src/test/java/org/apache/kafka/streams/integration/AbstractResetIntegrationTest.java new file mode 100644 index 0000000..fd5da12 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/integration/AbstractResetIntegrationTest.java @@ -0,0 +1,447 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.integration; + +import kafka.tools.StreamsResetter; +import org.apache.kafka.clients.CommonClientConfigs; +import org.apache.kafka.clients.admin.Admin; +import org.apache.kafka.clients.consumer.ConsumerConfig; +import org.apache.kafka.clients.producer.ProducerConfig; +import org.apache.kafka.common.config.SslConfigs; +import org.apache.kafka.common.config.types.Password; +import org.apache.kafka.common.internals.Topic; +import org.apache.kafka.common.serialization.LongDeserializer; +import org.apache.kafka.common.serialization.LongSerializer; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.serialization.StringDeserializer; +import org.apache.kafka.common.serialization.StringSerializer; +import org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.streams.KafkaStreams; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.Topology; +import org.apache.kafka.streams.integration.utils.EmbeddedKafkaCluster; +import org.apache.kafka.streams.integration.utils.IntegrationTestUtils; +import org.apache.kafka.streams.kstream.KStream; +import org.apache.kafka.streams.kstream.Produced; +import org.apache.kafka.streams.kstream.TimeWindows; +import org.apache.kafka.test.IntegrationTest; +import org.apache.kafka.test.TestUtils; +import org.junit.AfterClass; +import org.junit.Assert; +import org.junit.Rule; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.rules.TemporaryFolder; +import org.junit.rules.TestName; + +import java.io.BufferedWriter; +import java.io.File; +import java.io.FileWriter; +import java.time.Duration; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Properties; +import java.util.stream.Collectors; + +import static java.time.Duration.ofMillis; +import static org.apache.kafka.streams.integration.utils.IntegrationTestUtils.waitForEmptyConsumerGroup; +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.MatcherAssert.assertThat; + +@Category({IntegrationTest.class}) +public abstract class AbstractResetIntegrationTest { + static EmbeddedKafkaCluster cluster; + + private static MockTime mockTime; + protected static KafkaStreams streams; + protected static Admin adminClient; + + abstract Map getClientSslConfig(); + + @Rule + public final TestName testName = new TestName(); + + @AfterClass + public static void afterClassCleanup() { + if (adminClient != null) { + adminClient.close(Duration.ofSeconds(10)); + adminClient = null; + } + } + + protected Properties commonClientConfig; + protected Properties streamsConfig; + private Properties 
producerConfig; + protected Properties resultConsumerConfig; + + private void prepareEnvironment() { + if (adminClient == null) { + adminClient = Admin.create(commonClientConfig); + } + + boolean timeSet = false; + while (!timeSet) { + timeSet = setCurrentTime(); + } + } + + private boolean setCurrentTime() { + boolean currentTimeSet = false; + try { + mockTime = cluster.time; + // we align time to seconds to get clean window boundaries and thus ensure the same result for each run + // otherwise, input records could fall into different windows for different runs depending on the initial mock time + final long alignedTime = (System.currentTimeMillis() / 1000 + 1) * 1000; + mockTime.setCurrentTimeMs(alignedTime); + currentTimeSet = true; + } catch (final IllegalArgumentException e) { + // don't care will retry until set + } + return currentTimeSet; + } + + private void prepareConfigs(final String appID) { + commonClientConfig = new Properties(); + commonClientConfig.put(CommonClientConfigs.BOOTSTRAP_SERVERS_CONFIG, cluster.bootstrapServers()); + + final Map sslConfig = getClientSslConfig(); + if (sslConfig != null) { + commonClientConfig.put(SslConfigs.SSL_TRUSTSTORE_LOCATION_CONFIG, sslConfig.get(SslConfigs.SSL_TRUSTSTORE_LOCATION_CONFIG)); + commonClientConfig.put(SslConfigs.SSL_TRUSTSTORE_PASSWORD_CONFIG, ((Password) sslConfig.get(SslConfigs.SSL_TRUSTSTORE_PASSWORD_CONFIG)).value()); + commonClientConfig.put(CommonClientConfigs.SECURITY_PROTOCOL_CONFIG, "SSL"); + } + + producerConfig = new Properties(); + producerConfig.put(ProducerConfig.ACKS_CONFIG, "all"); + producerConfig.put(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG, LongSerializer.class); + producerConfig.put(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG, StringSerializer.class); + producerConfig.putAll(commonClientConfig); + + resultConsumerConfig = new Properties(); + resultConsumerConfig.put(ConsumerConfig.GROUP_ID_CONFIG, appID + "-result-consumer"); + resultConsumerConfig.put(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "earliest"); + resultConsumerConfig.put(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, LongDeserializer.class); + resultConsumerConfig.put(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, LongDeserializer.class); + resultConsumerConfig.putAll(commonClientConfig); + + streamsConfig = new Properties(); + streamsConfig.put(StreamsConfig.STATE_DIR_CONFIG, testFolder.getRoot().getPath()); + streamsConfig.put(StreamsConfig.DEFAULT_KEY_SERDE_CLASS_CONFIG, Serdes.Long().getClass()); + streamsConfig.put(StreamsConfig.DEFAULT_VALUE_SERDE_CLASS_CONFIG, Serdes.String().getClass()); + streamsConfig.put(StreamsConfig.CACHE_MAX_BYTES_BUFFERING_CONFIG, 0); + streamsConfig.put(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG, 100L); + streamsConfig.put(ConsumerConfig.HEARTBEAT_INTERVAL_MS_CONFIG, 100); + streamsConfig.put(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "earliest"); + streamsConfig.put(ConsumerConfig.SESSION_TIMEOUT_MS_CONFIG, Integer.toString(STREAMS_CONSUMER_TIMEOUT)); + streamsConfig.putAll(commonClientConfig); + } + + @Rule + public final TemporaryFolder testFolder = new TemporaryFolder(TestUtils.tempDirectory()); + + protected static final String INPUT_TOPIC = "inputTopic"; + protected static final String OUTPUT_TOPIC = "outputTopic"; + private static final String OUTPUT_TOPIC_2 = "outputTopic2"; + private static final String OUTPUT_TOPIC_2_RERUN = "outputTopic2_rerun"; + private static final String INTERMEDIATE_USER_TOPIC = "userTopic"; + + protected static final int STREAMS_CONSUMER_TIMEOUT = 2000; + protected static final int 
CLEANUP_CONSUMER_TIMEOUT = 2000; + protected static final int TIMEOUT_MULTIPLIER = 15; + + void prepareTest() throws Exception { + final String appID = IntegrationTestUtils.safeUniqueTestName(getClass(), testName); + prepareConfigs(appID); + prepareEnvironment(); + + waitForEmptyConsumerGroup(adminClient, appID, TIMEOUT_MULTIPLIER * CLEANUP_CONSUMER_TIMEOUT); + + cluster.deleteAllTopicsAndWait(120000); + cluster.createTopics(INPUT_TOPIC, OUTPUT_TOPIC, OUTPUT_TOPIC_2, OUTPUT_TOPIC_2_RERUN); + + add10InputElements(); + } + + void cleanupTest() throws Exception { + if (streams != null) { + streams.close(Duration.ofSeconds(30)); + } + IntegrationTestUtils.purgeLocalStreamsState(streamsConfig); + } + + private void add10InputElements() { + final List> records = Arrays.asList(KeyValue.pair(0L, "aaa"), + KeyValue.pair(1L, "bbb"), + KeyValue.pair(0L, "ccc"), + KeyValue.pair(1L, "ddd"), + KeyValue.pair(0L, "eee"), + KeyValue.pair(1L, "fff"), + KeyValue.pair(0L, "ggg"), + KeyValue.pair(1L, "hhh"), + KeyValue.pair(0L, "iii"), + KeyValue.pair(1L, "jjj")); + + for (final KeyValue record : records) { + mockTime.sleep(10); + IntegrationTestUtils.produceKeyValuesSynchronouslyWithTimestamp(INPUT_TOPIC, Collections.singleton(record), producerConfig, mockTime.milliseconds()); + } + } + + @Test + public void testResetWhenInternalTopicsAreSpecified() throws Exception { + final String appID = IntegrationTestUtils.safeUniqueTestName(getClass(), testName); + streamsConfig.put(StreamsConfig.APPLICATION_ID_CONFIG, appID); + + // RUN + streams = new KafkaStreams(setupTopologyWithIntermediateTopic(true, OUTPUT_TOPIC_2), streamsConfig); + streams.start(); + IntegrationTestUtils.waitUntilMinKeyValueRecordsReceived(resultConsumerConfig, OUTPUT_TOPIC, 10); + + streams.close(); + waitForEmptyConsumerGroup(adminClient, appID, TIMEOUT_MULTIPLIER * STREAMS_CONSUMER_TIMEOUT); + + // RESET + streams.cleanUp(); + + final List internalTopics = cluster.getAllTopicsInCluster().stream() + .filter(StreamsResetter::matchesInternalTopicFormat) + .collect(Collectors.toList()); + cleanGlobal(false, + "--internal-topics", + String.join(",", internalTopics.subList(1, internalTopics.size())), + appID); + waitForEmptyConsumerGroup(adminClient, appID, TIMEOUT_MULTIPLIER * STREAMS_CONSUMER_TIMEOUT); + + assertInternalTopicsGotDeleted(internalTopics.get(0)); + } + + @Test + public void testReprocessingFromScratchAfterResetWithoutIntermediateUserTopic() throws Exception { + final String appID = IntegrationTestUtils.safeUniqueTestName(getClass(), testName); + streamsConfig.put(StreamsConfig.APPLICATION_ID_CONFIG, appID); + + // RUN + streams = new KafkaStreams(setupTopologyWithoutIntermediateUserTopic(), streamsConfig); + streams.start(); + final List> result = IntegrationTestUtils.waitUntilMinKeyValueRecordsReceived(resultConsumerConfig, OUTPUT_TOPIC, 10); + + streams.close(); + waitForEmptyConsumerGroup(adminClient, appID, TIMEOUT_MULTIPLIER * STREAMS_CONSUMER_TIMEOUT); + + // RESET + streams = new KafkaStreams(setupTopologyWithoutIntermediateUserTopic(), streamsConfig); + streams.cleanUp(); + cleanGlobal(false, null, null, appID); + waitForEmptyConsumerGroup(adminClient, appID, TIMEOUT_MULTIPLIER * STREAMS_CONSUMER_TIMEOUT); + + assertInternalTopicsGotDeleted(null); + + // RE-RUN + streams.start(); + final List> resultRerun = IntegrationTestUtils.waitUntilMinKeyValueRecordsReceived(resultConsumerConfig, OUTPUT_TOPIC, 10); + streams.close(); + + assertThat(resultRerun, equalTo(result)); + + waitForEmptyConsumerGroup(adminClient, appID, 
+        waitForEmptyConsumerGroup(adminClient, appID, TIMEOUT_MULTIPLIER * STREAMS_CONSUMER_TIMEOUT);
+        cleanGlobal(false, null, null, appID);
+    }
+
+    @Test
+    public void testReprocessingFromScratchAfterResetWithIntermediateUserTopic() throws Exception {
+        testReprocessingFromScratchAfterResetWithIntermediateUserTopic(false);
+    }
+
+    @Test
+    public void testReprocessingFromScratchAfterResetWithIntermediateInternalTopic() throws Exception {
+        testReprocessingFromScratchAfterResetWithIntermediateUserTopic(true);
+    }
+
+    private void testReprocessingFromScratchAfterResetWithIntermediateUserTopic(final boolean useRepartitioned) throws Exception {
+        if (!useRepartitioned) {
+            cluster.createTopic(INTERMEDIATE_USER_TOPIC);
+        }
+
+        final String appID = IntegrationTestUtils.safeUniqueTestName(getClass(), testName);
+        streamsConfig.put(StreamsConfig.APPLICATION_ID_CONFIG, appID);
+
+        // RUN
+        streams = new KafkaStreams(setupTopologyWithIntermediateTopic(useRepartitioned, OUTPUT_TOPIC_2), streamsConfig);
+        streams.start();
+        final List<KeyValue<Long, Long>> result = IntegrationTestUtils.waitUntilMinKeyValueRecordsReceived(resultConsumerConfig, OUTPUT_TOPIC, 10);
+        // receive only first values to make sure intermediate user topic is not consumed completely
+        // => required to test "seekToEnd" for intermediate topics
+        final List<KeyValue<Long, Long>> result2 = IntegrationTestUtils.waitUntilMinKeyValueRecordsReceived(resultConsumerConfig, OUTPUT_TOPIC_2, 40);
+
+        streams.close();
+        waitForEmptyConsumerGroup(adminClient, appID, TIMEOUT_MULTIPLIER * STREAMS_CONSUMER_TIMEOUT);
+
+        // insert bad record to make sure intermediate user topic gets seekToEnd()
+        mockTime.sleep(1);
+        final KeyValue<Long, String> badMessage = new KeyValue<>(-1L, "badRecord-ShouldBeSkipped");
+        if (!useRepartitioned) {
+            IntegrationTestUtils.produceKeyValuesSynchronouslyWithTimestamp(
+                INTERMEDIATE_USER_TOPIC,
+                Collections.singleton(badMessage),
+                producerConfig,
+                mockTime.milliseconds());
+        }
+
+        // RESET
+        streams = new KafkaStreams(setupTopologyWithIntermediateTopic(useRepartitioned, OUTPUT_TOPIC_2_RERUN), streamsConfig);
+        streams.cleanUp();
+        cleanGlobal(!useRepartitioned, null, null, appID);
+        waitForEmptyConsumerGroup(adminClient, appID, TIMEOUT_MULTIPLIER * STREAMS_CONSUMER_TIMEOUT);
+
+        assertInternalTopicsGotDeleted(useRepartitioned ? null : INTERMEDIATE_USER_TOPIC);
+
+        // RE-RUN
+        streams.start();
+        final List<KeyValue<Long, Long>> resultRerun = IntegrationTestUtils.waitUntilMinKeyValueRecordsReceived(resultConsumerConfig, OUTPUT_TOPIC, 10);
+        final List<KeyValue<Long, Long>> resultRerun2 = IntegrationTestUtils.waitUntilMinKeyValueRecordsReceived(resultConsumerConfig, OUTPUT_TOPIC_2_RERUN, 40);
+        streams.close();
+
+        assertThat(resultRerun, equalTo(result));
+        assertThat(resultRerun2, equalTo(result2));
+
+        if (!useRepartitioned) {
+            final Properties props = TestUtils.consumerConfig(cluster.bootstrapServers(), appID + "-result-consumer", LongDeserializer.class, StringDeserializer.class, commonClientConfig);
+            final List<KeyValue<Long, String>> resultIntermediate = IntegrationTestUtils.waitUntilMinKeyValueRecordsReceived(props, INTERMEDIATE_USER_TOPIC, 21);
+
+            for (int i = 0; i < 10; i++) {
+                assertThat(resultIntermediate.get(i), equalTo(resultIntermediate.get(i + 11)));
+            }
+            assertThat(resultIntermediate.get(10), equalTo(badMessage));
+        }
+
+        waitForEmptyConsumerGroup(adminClient, appID, TIMEOUT_MULTIPLIER * STREAMS_CONSUMER_TIMEOUT);
+        cleanGlobal(!useRepartitioned, null, null, appID);
+
+        if (!useRepartitioned) {
+            cluster.deleteTopicAndWait(INTERMEDIATE_USER_TOPIC);
+        }
+    }
+
+    @SuppressWarnings("deprecation")
+    private Topology setupTopologyWithIntermediateTopic(final boolean useRepartitioned,
+                                                        final String outputTopic2) {
+        final StreamsBuilder builder = new StreamsBuilder();
+
+        final KStream<Long, String> input = builder.stream(INPUT_TOPIC);
+
+        // use map to trigger internal re-partitioning before groupByKey
+        input.map(KeyValue::new)
+            .groupByKey()
+            .count()
+            .toStream()
+            .to(OUTPUT_TOPIC, Produced.with(Serdes.Long(), Serdes.Long()));
+
+        final KStream<Long, String> stream;
+        if (useRepartitioned) {
+            stream = input.repartition();
+        } else {
+            input.to(INTERMEDIATE_USER_TOPIC);
+            stream = builder.stream(INTERMEDIATE_USER_TOPIC);
+        }
+        stream.groupByKey()
+            .windowedBy(TimeWindows.of(ofMillis(35)).advanceBy(ofMillis(10)))
+            .count()
+            .toStream()
+            .map((key, value) -> new KeyValue<>(key.window().start() + key.window().end(), value))
+            .to(outputTopic2, Produced.with(Serdes.Long(), Serdes.Long()));
+
+        return builder.build();
+    }
+
+    protected Topology setupTopologyWithoutIntermediateUserTopic() {
+        final StreamsBuilder builder = new StreamsBuilder();
+
+        final KStream<Long, String> input = builder.stream(INPUT_TOPIC);
+
+        // use map to trigger internal re-partitioning before groupByKey
+        input.map((key, value) -> new KeyValue<>(key, key))
+            .to(OUTPUT_TOPIC, Produced.with(Serdes.Long(), Serdes.Long()));
+
+        return builder.build();
+    }
+
+    protected boolean tryCleanGlobal(final boolean withIntermediateTopics,
+                                     final String resetScenario,
+                                     final String resetScenarioArg,
+                                     final String appID) throws Exception {
+        final List<String> parameterList = new ArrayList<>(
+            Arrays.asList("--application-id", appID,
+                "--bootstrap-servers", cluster.bootstrapServers(),
+                "--input-topics", INPUT_TOPIC
+            ));
+        if (withIntermediateTopics) {
+            parameterList.add("--intermediate-topics");
+            parameterList.add(INTERMEDIATE_USER_TOPIC);
+        }
+
+        final Map<String, Object> sslConfig = getClientSslConfig();
+        if (sslConfig != null) {
+            final File configFile = TestUtils.tempFile();
+            final BufferedWriter writer = new BufferedWriter(new FileWriter(configFile));
+            writer.write(CommonClientConfigs.SECURITY_PROTOCOL_CONFIG + "=SSL\n");
+            writer.write(SslConfigs.SSL_TRUSTSTORE_LOCATION_CONFIG + "=" + sslConfig.get(SslConfigs.SSL_TRUSTSTORE_LOCATION_CONFIG) + "\n");
+            writer.write(SslConfigs.SSL_TRUSTSTORE_PASSWORD_CONFIG + "=" + ((Password)
sslConfig.get(SslConfigs.SSL_TRUSTSTORE_PASSWORD_CONFIG)).value() + "\n"); + writer.close(); + + parameterList.add("--config-file"); + parameterList.add(configFile.getAbsolutePath()); + } + if (resetScenario != null) { + parameterList.add(resetScenario); + } + if (resetScenarioArg != null) { + parameterList.add(resetScenarioArg); + } + + final String[] parameters = parameterList.toArray(new String[0]); + + final Properties cleanUpConfig = new Properties(); + cleanUpConfig.put(ConsumerConfig.HEARTBEAT_INTERVAL_MS_CONFIG, 100); + cleanUpConfig.put(ConsumerConfig.SESSION_TIMEOUT_MS_CONFIG, Integer.toString(CLEANUP_CONSUMER_TIMEOUT)); + + return new StreamsResetter().run(parameters, cleanUpConfig) == 0; + } + + protected void cleanGlobal(final boolean withIntermediateTopics, + final String resetScenario, + final String resetScenarioArg, + final String appID) throws Exception { + final boolean cleanResult = tryCleanGlobal(withIntermediateTopics, resetScenario, resetScenarioArg, appID); + Assert.assertTrue(cleanResult); + } + + protected void assertInternalTopicsGotDeleted(final String additionalExistingTopic) throws Exception { + // do not use list topics request, but read from the embedded cluster's zookeeper path directly to confirm + if (additionalExistingTopic != null) { + cluster.waitForRemainingTopics(30000, INPUT_TOPIC, OUTPUT_TOPIC, OUTPUT_TOPIC_2, OUTPUT_TOPIC_2_RERUN, + Topic.GROUP_METADATA_TOPIC_NAME, additionalExistingTopic); + } else { + cluster.waitForRemainingTopics(30000, INPUT_TOPIC, OUTPUT_TOPIC, OUTPUT_TOPIC_2, OUTPUT_TOPIC_2_RERUN, + Topic.GROUP_METADATA_TOPIC_NAME); + } + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/integration/AdjustStreamThreadCountTest.java b/streams/src/test/java/org/apache/kafka/streams/integration/AdjustStreamThreadCountTest.java new file mode 100644 index 0000000..26edd69 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/integration/AdjustStreamThreadCountTest.java @@ -0,0 +1,453 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.integration; + +import org.apache.kafka.clients.consumer.ConsumerConfig; +import org.apache.kafka.common.errors.TimeoutException; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.streams.KafkaStreams; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.ThreadMetadata; +import org.apache.kafka.streams.errors.StreamsUncaughtExceptionHandler.StreamThreadExceptionResponse; +import org.apache.kafka.streams.integration.utils.EmbeddedKafkaCluster; +import org.apache.kafka.streams.integration.utils.IntegrationTestUtils; +import org.apache.kafka.streams.kstream.KStream; +import org.apache.kafka.streams.kstream.Transformer; +import org.apache.kafka.streams.processor.ProcessorContext; +import org.apache.kafka.streams.processor.PunctuationType; +import org.apache.kafka.streams.processor.internals.testutil.LogCaptureAppender; +import org.apache.kafka.test.IntegrationTest; +import org.apache.kafka.test.TestUtils; + +import java.util.concurrent.atomic.AtomicBoolean; +import org.junit.After; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Rule; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.rules.TestName; +import java.io.IOException; +import java.time.Duration; +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; +import java.util.Properties; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; + +import static org.apache.kafka.common.utils.Utils.mkEntry; +import static org.apache.kafka.common.utils.Utils.mkMap; +import static org.apache.kafka.common.utils.Utils.mkObjectProperties; +import static org.apache.kafka.streams.integration.utils.IntegrationTestUtils.purgeLocalStreamsState; +import static org.apache.kafka.streams.integration.utils.IntegrationTestUtils.safeUniqueTestName; +import static org.apache.kafka.test.TestUtils.waitForCondition; + +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.CoreMatchers.not; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +@Category(IntegrationTest.class) +public class AdjustStreamThreadCountTest { + + public static final EmbeddedKafkaCluster CLUSTER = new EmbeddedKafkaCluster(1); + + @BeforeClass + public static void startCluster() throws IOException { + CLUSTER.start(); + } + + @AfterClass + public static void closeCluster() { + CLUSTER.stop(); + } + + + @Rule + public TestName testName = new TestName(); + + private final List stateTransitionHistory = new ArrayList<>(); + private static String inputTopic; + private static StreamsBuilder builder; + private static Properties properties; + private static String appId = ""; + public static final Duration DEFAULT_DURATION = Duration.ofSeconds(30); + + @Before + public void setup() { + final String testId = safeUniqueTestName(getClass(), testName); + appId = "appId_" + testId; + inputTopic = "input" + testId; + 
IntegrationTestUtils.cleanStateBeforeTest(CLUSTER, inputTopic); + + builder = new StreamsBuilder(); + builder.stream(inputTopic); + + properties = mkObjectProperties( + mkMap( + mkEntry(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, CLUSTER.bootstrapServers()), + mkEntry(StreamsConfig.APPLICATION_ID_CONFIG, appId), + mkEntry(StreamsConfig.STATE_DIR_CONFIG, TestUtils.tempDirectory().getPath()), + mkEntry(StreamsConfig.NUM_STREAM_THREADS_CONFIG, 2), + mkEntry(StreamsConfig.DEFAULT_KEY_SERDE_CLASS_CONFIG, Serdes.StringSerde.class), + mkEntry(StreamsConfig.DEFAULT_VALUE_SERDE_CLASS_CONFIG, Serdes.StringSerde.class), + mkEntry(ConsumerConfig.SESSION_TIMEOUT_MS_CONFIG, 10000) + ) + ); + } + + private void startStreamsAndWaitForRunning(final KafkaStreams kafkaStreams) throws InterruptedException { + kafkaStreams.start(); + waitForRunning(); + } + + @After + public void teardown() throws IOException { + purgeLocalStreamsState(properties); + } + + private void addStreamStateChangeListener(final KafkaStreams kafkaStreams) { + kafkaStreams.setStateListener( + (newState, oldState) -> stateTransitionHistory.add(newState) + ); + } + + private void waitForRunning() throws InterruptedException { + waitForCondition( + () -> !stateTransitionHistory.isEmpty() && + stateTransitionHistory.get(stateTransitionHistory.size() - 1).equals(KafkaStreams.State.RUNNING), + DEFAULT_DURATION.toMillis(), + () -> String.format("Client did not transit to state %s in %d seconds", + KafkaStreams.State.RUNNING, DEFAULT_DURATION.toMillis() / 1000) + ); + } + + private void waitForTransitionFromRebalancingToRunning() throws InterruptedException { + waitForRunning(); + + final int historySize = stateTransitionHistory.size(); + assertThat("Client did not transit from REBALANCING to RUNNING. The observed state transitions are: " + stateTransitionHistory, + historySize >= 2 && + stateTransitionHistory.get(historySize - 2).equals(KafkaStreams.State.REBALANCING) && + stateTransitionHistory.get(historySize - 1).equals(KafkaStreams.State.RUNNING), is(true)); + } + + @Test + public void shouldAddStreamThread() throws Exception { + try (final KafkaStreams kafkaStreams = new KafkaStreams(builder.build(), properties)) { + addStreamStateChangeListener(kafkaStreams); + startStreamsAndWaitForRunning(kafkaStreams); + + final int oldThreadCount = kafkaStreams.metadataForLocalThreads().size(); + assertThat(kafkaStreams.metadataForLocalThreads().stream().map(t -> t.threadName().split("-StreamThread-")[1]).sorted().toArray(), equalTo(new String[] {"1", "2"})); + + stateTransitionHistory.clear(); + final Optional name = kafkaStreams.addStreamThread(); + + assertThat(name, not(Optional.empty())); + TestUtils.waitForCondition( + () -> kafkaStreams.metadataForLocalThreads().stream().sequential() + .map(ThreadMetadata::threadName).anyMatch(t -> t.equals(name.orElse(""))), + "Wait for the thread to be added" + ); + assertThat(kafkaStreams.metadataForLocalThreads().size(), equalTo(oldThreadCount + 1)); + assertThat( + kafkaStreams + .metadataForLocalThreads() + .stream() + .map(t -> t.threadName().split("-StreamThread-")[1]) + .sorted().toArray(), + equalTo(new String[] {"1", "2", "3"}) + ); + + waitForTransitionFromRebalancingToRunning(); + } + } + + @Test + public void shouldRemoveStreamThread() throws Exception { + try (final KafkaStreams kafkaStreams = new KafkaStreams(builder.build(), properties)) { + addStreamStateChangeListener(kafkaStreams); + startStreamsAndWaitForRunning(kafkaStreams); + + final int oldThreadCount = 
kafkaStreams.metadataForLocalThreads().size(); + stateTransitionHistory.clear(); + assertThat(kafkaStreams.removeStreamThread().get().split("-")[0], equalTo(appId)); + assertThat(kafkaStreams.metadataForLocalThreads().size(), equalTo(oldThreadCount - 1)); + + waitForTransitionFromRebalancingToRunning(); + } + } + + @Test + public void shouldRemoveStreamThreadWithStaticMembership() throws Exception { + properties.put(ConsumerConfig.GROUP_INSTANCE_ID_CONFIG, "member-A"); + try (final KafkaStreams kafkaStreams = new KafkaStreams(builder.build(), properties)) { + addStreamStateChangeListener(kafkaStreams); + startStreamsAndWaitForRunning(kafkaStreams); + + final int oldThreadCount = kafkaStreams.metadataForLocalThreads().size(); + stateTransitionHistory.clear(); + assertThat(kafkaStreams.removeStreamThread().get().split("-")[0], equalTo(appId)); + assertThat(kafkaStreams.metadataForLocalThreads().size(), equalTo(oldThreadCount - 1)); + + waitForTransitionFromRebalancingToRunning(); + } + } + + @Test + public void shouldnNotRemoveStreamThreadWithinTimeout() throws Exception { + try (final KafkaStreams kafkaStreams = new KafkaStreams(builder.build(), properties)) { + addStreamStateChangeListener(kafkaStreams); + startStreamsAndWaitForRunning(kafkaStreams); + assertThrows(TimeoutException.class, () -> kafkaStreams.removeStreamThread(Duration.ZERO.minus(DEFAULT_DURATION))); + } + } + + @Test + public void shouldAddAndRemoveThreadsMultipleTimes() throws InterruptedException { + try (final KafkaStreams kafkaStreams = new KafkaStreams(builder.build(), properties)) { + addStreamStateChangeListener(kafkaStreams); + startStreamsAndWaitForRunning(kafkaStreams); + + final int oldThreadCount = kafkaStreams.metadataForLocalThreads().size(); + stateTransitionHistory.clear(); + + final CountDownLatch latch = new CountDownLatch(2); + final Thread one = adjustCountHelperThread(kafkaStreams, 4, latch); + final Thread two = adjustCountHelperThread(kafkaStreams, 6, latch); + two.start(); + one.start(); + latch.await(30, TimeUnit.SECONDS); + assertThat(kafkaStreams.metadataForLocalThreads().size(), equalTo(oldThreadCount)); + + waitForTransitionFromRebalancingToRunning(); + } + } + + private Thread adjustCountHelperThread(final KafkaStreams kafkaStreams, final int count, final CountDownLatch latch) { + return new Thread(() -> { + for (int i = 0; i < count; i++) { + kafkaStreams.addStreamThread(); + kafkaStreams.removeStreamThread(); + } + latch.countDown(); + }); + } + + @Test + public void shouldAddAndRemoveStreamThreadsWhileKeepingNamesCorrect() throws Exception { + try (final KafkaStreams kafkaStreams = new KafkaStreams(builder.build(), properties)) { + addStreamStateChangeListener(kafkaStreams); + startStreamsAndWaitForRunning(kafkaStreams); + + int oldThreadCount = kafkaStreams.metadataForLocalThreads().size(); + stateTransitionHistory.clear(); + + assertThat( + kafkaStreams.metadataForLocalThreads() + .stream() + .map(t -> t.threadName().split("-StreamThread-")[1]) + .sorted() + .toArray(), + equalTo(new String[] {"1", "2"}) + ); + + final Optional name = kafkaStreams.addStreamThread(); + + assertThat("New thread has index 3", "3".equals(name.get().split("-StreamThread-")[1])); + TestUtils.waitForCondition( + () -> kafkaStreams + .metadataForLocalThreads() + .stream().sequential() + .map(ThreadMetadata::threadName) + .anyMatch(t -> t.equals(name.get())), + "Stream thread has not been added" + ); + assertThat(kafkaStreams.metadataForLocalThreads().size(), equalTo(oldThreadCount + 1)); + assertThat( + 
kafkaStreams + .metadataForLocalThreads() + .stream() + .map(t -> t.threadName().split("-StreamThread-")[1]) + .sorted() + .toArray(), + equalTo(new String[] {"1", "2", "3"}) + ); + waitForTransitionFromRebalancingToRunning(); + + oldThreadCount = kafkaStreams.metadataForLocalThreads().size(); + stateTransitionHistory.clear(); + + final Optional removedThread = kafkaStreams.removeStreamThread(); + + assertThat(removedThread, not(Optional.empty())); + assertThat(kafkaStreams.metadataForLocalThreads().size(), equalTo(oldThreadCount - 1)); + waitForTransitionFromRebalancingToRunning(); + + stateTransitionHistory.clear(); + + final Optional name2 = kafkaStreams.addStreamThread(); + + assertThat(name2, not(Optional.empty())); + TestUtils.waitForCondition( + () -> kafkaStreams.metadataForLocalThreads().stream().sequential() + .map(ThreadMetadata::threadName).anyMatch(t -> t.equals(name2.orElse(""))), + "Wait for the thread to be added" + ); + assertThat(kafkaStreams.metadataForLocalThreads().size(), equalTo(oldThreadCount)); + assertThat( + kafkaStreams + .metadataForLocalThreads() + .stream() + .map(t -> t.threadName().split("-StreamThread-")[1]) + .sorted() + .toArray(), + equalTo(new String[] {"1", "2", "3"}) + ); + + assertThat("the new thread should have received the old threads name", name2.equals(removedThread)); + waitForTransitionFromRebalancingToRunning(); + } + } + + @Test + public void testConcurrentlyAccessThreads() throws InterruptedException { + try (final KafkaStreams kafkaStreams = new KafkaStreams(builder.build(), properties)) { + addStreamStateChangeListener(kafkaStreams); + startStreamsAndWaitForRunning(kafkaStreams); + final int oldThreadCount = kafkaStreams.metadataForLocalThreads().size(); + final int threadCount = 5; + final int loop = 3; + final AtomicReference lastException = new AtomicReference<>(); + final ExecutorService executor = Executors.newFixedThreadPool(threadCount); + for (int threadIndex = 0; threadIndex < threadCount; ++threadIndex) { + executor.execute(() -> { + try { + for (int i = 0; i < loop + 1; i++) { + if (!kafkaStreams.addStreamThread().isPresent()) + throw new RuntimeException("failed to create stream thread"); + kafkaStreams.metadataForLocalThreads(); + if (!kafkaStreams.removeStreamThread().isPresent()) + throw new RuntimeException("failed to delete a stream thread"); + } + } catch (final Exception e) { + lastException.set(e); + } + }); + } + executor.shutdown(); + assertTrue(executor.awaitTermination(60, TimeUnit.SECONDS)); + assertNull(lastException.get()); + assertEquals(oldThreadCount, kafkaStreams.metadataForLocalThreads().size()); + } + } + + @Test + public void shouldResizeCacheAfterThreadRemovalTimesOut() throws InterruptedException { + final long totalCacheBytes = 10L; + final Properties props = new Properties(); + props.putAll(properties); + props.put(StreamsConfig.NUM_STREAM_THREADS_CONFIG, 2); + props.put(StreamsConfig.CACHE_MAX_BYTES_BUFFERING_CONFIG, totalCacheBytes); + + try (final KafkaStreams kafkaStreams = new KafkaStreams(builder.build(), props)) { + addStreamStateChangeListener(kafkaStreams); + startStreamsAndWaitForRunning(kafkaStreams); + + try (final LogCaptureAppender appender = LogCaptureAppender.createAndRegister(KafkaStreams.class)) { + assertThrows(TimeoutException.class, () -> kafkaStreams.removeStreamThread(Duration.ofSeconds(0))); + + for (final String log : appender.getMessages()) { + // all 10 bytes should be available for remaining thread + if (log.endsWith("Resizing thread cache due to thread removal, new cache 
size per thread is 10")) { + return; + } + } + } + } + fail(); + } + + @Test + public void shouldResizeCacheAfterThreadReplacement() throws InterruptedException { + final long totalCacheBytes = 10L; + final Properties props = new Properties(); + props.putAll(properties); + props.put(StreamsConfig.NUM_STREAM_THREADS_CONFIG, 2); + props.put(StreamsConfig.CACHE_MAX_BYTES_BUFFERING_CONFIG, totalCacheBytes); + + final AtomicBoolean injectError = new AtomicBoolean(false); + + final StreamsBuilder builder = new StreamsBuilder(); + final KStream stream = builder.stream(inputTopic); + stream.transform(() -> new Transformer>() { + @Override + public void init(final ProcessorContext context) { + context.schedule(Duration.ofSeconds(1), PunctuationType.WALL_CLOCK_TIME, timestamp -> { + if (Thread.currentThread().getName().endsWith("StreamThread-1") && injectError.get()) { + injectError.set(false); + throw new RuntimeException("BOOM"); + } + }); + } + + @Override + public KeyValue transform(final String key, final String value) { + return new KeyValue<>(key, value); + } + + @Override + public void close() { + } + }); + + try (final KafkaStreams kafkaStreams = new KafkaStreams(builder.build(), props)) { + addStreamStateChangeListener(kafkaStreams); + kafkaStreams.setUncaughtExceptionHandler(e -> StreamThreadExceptionResponse.REPLACE_THREAD); + startStreamsAndWaitForRunning(kafkaStreams); + + stateTransitionHistory.clear(); + try (final LogCaptureAppender appender = LogCaptureAppender.createAndRegister()) { + injectError.set(true); + waitForCondition(() -> !injectError.get(), "StreamThread did not hit and reset the injected error"); + + waitForTransitionFromRebalancingToRunning(); + + for (final String log : appender.getMessages()) { + // after we replace the thread there should be two remaining threads with 5 bytes each + if (log.endsWith("Adding StreamThread-3, there will now be 2 live threads and the new cache size per thread is 5")) { + return; + } + } + } + } + fail(); + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/integration/EOSUncleanShutdownIntegrationTest.java b/streams/src/test/java/org/apache/kafka/streams/integration/EOSUncleanShutdownIntegrationTest.java new file mode 100644 index 0000000..d3e991d --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/integration/EOSUncleanShutdownIntegrationTest.java @@ -0,0 +1,181 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.integration; + +import org.apache.kafka.clients.consumer.ConsumerConfig; +import org.apache.kafka.clients.producer.ProducerConfig; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.serialization.Serializer; +import org.apache.kafka.common.serialization.StringSerializer; +import org.apache.kafka.streams.KafkaStreams; +import org.apache.kafka.streams.KafkaStreams.State; +import org.apache.kafka.streams.KeyValueTimestamp; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.integration.utils.EmbeddedKafkaCluster; +import org.apache.kafka.streams.integration.utils.IntegrationTestUtils; +import org.apache.kafka.streams.kstream.KStream; +import org.apache.kafka.streams.kstream.KTable; +import org.apache.kafka.streams.kstream.Materialized; +import org.apache.kafka.test.IntegrationTest; +import org.apache.kafka.test.TestUtils; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.ClassRule; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.rules.TemporaryFolder; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import java.io.File; +import java.io.IOException; +import java.util.Arrays; +import java.util.Collection; +import java.util.Optional; +import java.util.Properties; +import java.util.concurrent.atomic.AtomicInteger; + +import static java.util.Arrays.asList; +import static java.util.Collections.singletonList; +import static org.apache.kafka.common.utils.Utils.mkEntry; +import static org.apache.kafka.common.utils.Utils.mkMap; +import static org.apache.kafka.common.utils.Utils.mkProperties; +import static org.apache.kafka.streams.integration.utils.IntegrationTestUtils.cleanStateBeforeTest; +import static org.apache.kafka.streams.integration.utils.IntegrationTestUtils.quietlyCleanStateAfterTest; +import static org.junit.Assert.assertTrue; + + +/** + * Test the unclean shutdown behavior around state store cleanup. 
+ */ +@RunWith(Parameterized.class) +@Category(IntegrationTest.class) +public class EOSUncleanShutdownIntegrationTest { + + @SuppressWarnings("deprecation") + @Parameterized.Parameters(name = "{0}") + public static Collection data() { + return Arrays.asList(new String[][] { + {StreamsConfig.EXACTLY_ONCE}, + {StreamsConfig.EXACTLY_ONCE_V2} + }); + } + + @Parameterized.Parameter + public String eosConfig; + + public static final EmbeddedKafkaCluster CLUSTER = new EmbeddedKafkaCluster(3); + + @BeforeClass + public static void startCluster() throws IOException { + CLUSTER.start(); + STREAMS_CONFIG.put(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "earliest"); + STREAMS_CONFIG.put(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, CLUSTER.bootstrapServers()); + STREAMS_CONFIG.put(StreamsConfig.DEFAULT_KEY_SERDE_CLASS_CONFIG, Serdes.String().getClass()); + STREAMS_CONFIG.put(StreamsConfig.DEFAULT_VALUE_SERDE_CLASS_CONFIG, Serdes.String().getClass()); + STREAMS_CONFIG.put(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG, COMMIT_INTERVAL); + STREAMS_CONFIG.put(StreamsConfig.STATE_DIR_CONFIG, TEST_FOLDER.getRoot().getPath()); + } + + @AfterClass + public static void closeCluster() { + CLUSTER.stop(); + } + + @ClassRule + public static final TemporaryFolder TEST_FOLDER = new TemporaryFolder(TestUtils.tempDirectory()); + + private static final Properties STREAMS_CONFIG = new Properties(); + private static final StringSerializer STRING_SERIALIZER = new StringSerializer(); + private static final Long COMMIT_INTERVAL = 100L; + + private static final int RECORD_TOTAL = 3; + + @Test + public void shouldWorkWithUncleanShutdownWipeOutStateStore() throws InterruptedException { + final String appId = "shouldWorkWithUncleanShutdownWipeOutStateStore"; + STREAMS_CONFIG.put(StreamsConfig.APPLICATION_ID_CONFIG, appId); + STREAMS_CONFIG.put(StreamsConfig.PROCESSING_GUARANTEE_CONFIG, eosConfig); + + final String input = "input-topic"; + cleanStateBeforeTest(CLUSTER, input); + + final StreamsBuilder builder = new StreamsBuilder(); + + final KStream inputStream = builder.stream(input); + + final AtomicInteger recordCount = new AtomicInteger(0); + + final KTable valueCounts = inputStream + .groupByKey() + .aggregate( + () -> "()", + (key, value, aggregate) -> aggregate + ",(" + key + ": " + value + ")", + Materialized.as("aggregated_value")); + + valueCounts.toStream().peek((key, value) -> { + if (recordCount.incrementAndGet() >= RECORD_TOTAL) { + throw new IllegalStateException("Crash on the " + RECORD_TOTAL + " record"); + } + }); + + final Properties producerConfig = mkProperties(mkMap( + mkEntry(ProducerConfig.CLIENT_ID_CONFIG, "anything"), + mkEntry(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG, ((Serializer) STRING_SERIALIZER).getClass().getName()), + mkEntry(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG, ((Serializer) STRING_SERIALIZER).getClass().getName()), + mkEntry(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, CLUSTER.bootstrapServers()) + )); + final KafkaStreams driver = new KafkaStreams(builder.build(), STREAMS_CONFIG); + driver.cleanUp(); + driver.start(); + + // Task's StateDir + final File taskStateDir = new File(String.join("/", TEST_FOLDER.getRoot().getPath(), appId, "0_0")); + final File taskCheckpointFile = new File(taskStateDir, ".checkpoint"); + + try { + IntegrationTestUtils.produceSynchronously(producerConfig, false, input, Optional.empty(), + singletonList(new KeyValueTimestamp<>("k1", "v1", 0L))); + + // wait until the first request is processed and some files are created in it + TestUtils.waitForCondition(() -> 
taskStateDir.exists() && taskStateDir.isDirectory() && taskStateDir.list().length > 0, + "Failed awaiting CreateTopics first request failure"); + IntegrationTestUtils.produceSynchronously(producerConfig, false, input, Optional.empty(), + asList(new KeyValueTimestamp<>("k2", "v2", 1L), + new KeyValueTimestamp<>("k3", "v3", 2L))); + + TestUtils.waitForCondition(() -> recordCount.get() == RECORD_TOTAL, + "Expected " + RECORD_TOTAL + " records processed but only got " + recordCount.get()); + } finally { + TestUtils.waitForCondition(() -> driver.state().equals(State.ERROR), + "Expected ERROR state but driver is on " + driver.state()); + + driver.close(); + + // Although there is an uncaught exception, + // case 1: the state directory is cleaned up without any problems. + // case 2: The state directory is not cleaned up, for it does not include any checkpoint file. + // case 3: The state directory is not cleaned up, for it includes a checkpoint file but it is empty. + assertTrue(!taskStateDir.exists() + || (taskStateDir.exists() && taskStateDir.list().length > 0 && !taskCheckpointFile.exists()) + || (taskCheckpointFile.exists() && taskCheckpointFile.length() == 0L)); + + quietlyCleanStateAfterTest(CLUSTER, driver); + } + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/integration/EmitOnChangeIntegrationTest.java b/streams/src/test/java/org/apache/kafka/streams/integration/EmitOnChangeIntegrationTest.java new file mode 100644 index 0000000..63e0f27 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/integration/EmitOnChangeIntegrationTest.java @@ -0,0 +1,145 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.integration; + +import org.apache.kafka.clients.consumer.ConsumerConfig; +import org.apache.kafka.common.serialization.IntegerDeserializer; +import org.apache.kafka.common.serialization.IntegerSerializer; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.serialization.StringDeserializer; +import org.apache.kafka.common.serialization.StringSerializer; +import org.apache.kafka.streams.KafkaStreams; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.errors.StreamsUncaughtExceptionHandler.StreamThreadExceptionResponse; +import org.apache.kafka.streams.integration.utils.EmbeddedKafkaCluster; +import org.apache.kafka.streams.integration.utils.IntegrationTestUtils; +import org.apache.kafka.streams.kstream.Materialized; +import org.apache.kafka.test.IntegrationTest; +import org.apache.kafka.test.StreamsTestUtils; +import org.apache.kafka.test.TestUtils; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Rule; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.rules.TestName; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Properties; +import java.util.concurrent.atomic.AtomicBoolean; + +import static org.apache.kafka.common.utils.Utils.mkEntry; +import static org.apache.kafka.common.utils.Utils.mkMap; +import static org.apache.kafka.common.utils.Utils.mkObjectProperties; +import static org.apache.kafka.streams.integration.utils.IntegrationTestUtils.safeUniqueTestName; + +@Category(IntegrationTest.class) +public class EmitOnChangeIntegrationTest { + + private static final EmbeddedKafkaCluster CLUSTER = new EmbeddedKafkaCluster(1); + + @BeforeClass + public static void startCluster() throws IOException { + CLUSTER.start(); + } + + @AfterClass + public static void closeCluster() { + CLUSTER.stop(); + } + + @Rule + public TestName testName = new TestName(); + + private static String inputTopic; + private static String outputTopic; + private static String appId = ""; + + @Before + public void setup() { + final String testId = safeUniqueTestName(getClass(), testName); + appId = "appId_" + testId; + inputTopic = "input" + testId; + outputTopic = "output" + testId; + IntegrationTestUtils.cleanStateBeforeTest(CLUSTER, inputTopic, outputTopic); + } + + @Test + public void shouldEmitSameRecordAfterFailover() throws Exception { + final Properties properties = mkObjectProperties( + mkMap( + mkEntry(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, CLUSTER.bootstrapServers()), + mkEntry(StreamsConfig.APPLICATION_ID_CONFIG, appId), + mkEntry(StreamsConfig.STATE_DIR_CONFIG, TestUtils.tempDirectory().getPath()), + mkEntry(StreamsConfig.NUM_STREAM_THREADS_CONFIG, 1), + mkEntry(StreamsConfig.CACHE_MAX_BYTES_BUFFERING_CONFIG, 0), + mkEntry(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG, 300000L), + mkEntry(StreamsConfig.DEFAULT_KEY_SERDE_CLASS_CONFIG, Serdes.IntegerSerde.class), + mkEntry(StreamsConfig.DEFAULT_VALUE_SERDE_CLASS_CONFIG, Serdes.StringSerde.class), + mkEntry(ConsumerConfig.SESSION_TIMEOUT_MS_CONFIG, 10000) + ) + ); + + final AtomicBoolean shouldThrow = new AtomicBoolean(true); + final StreamsBuilder builder = new StreamsBuilder(); + builder.table(inputTopic, Materialized.as("test-store")) + .toStream() + .map((key, value) -> { + if (shouldThrow.compareAndSet(true, false)) { + throw new 
RuntimeException("Kaboom"); + } else { + return new KeyValue<>(key, value); + } + }) + .to(outputTopic); + + try (final KafkaStreams kafkaStreams = new KafkaStreams(builder.build(), properties)) { + kafkaStreams.setUncaughtExceptionHandler(exception -> StreamThreadExceptionResponse.REPLACE_THREAD); + StreamsTestUtils.startKafkaStreamsAndWaitForRunningState(kafkaStreams); + + IntegrationTestUtils.produceKeyValuesSynchronouslyWithTimestamp( + inputTopic, + Arrays.asList( + new KeyValue<>(1, "A"), + new KeyValue<>(1, "B") + ), + TestUtils.producerConfig( + CLUSTER.bootstrapServers(), + IntegerSerializer.class, + StringSerializer.class, + new Properties()), + 0L); + + IntegrationTestUtils.waitUntilFinalKeyValueRecordsReceived( + TestUtils.consumerConfig( + CLUSTER.bootstrapServers(), + IntegerDeserializer.class, + StringDeserializer.class + ), + outputTopic, + Arrays.asList( + new KeyValue<>(1, "A"), + new KeyValue<>(1, "B") + ) + ); + } + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/integration/EosIntegrationTest.java b/streams/src/test/java/org/apache/kafka/streams/integration/EosIntegrationTest.java new file mode 100644 index 0000000..08cfc4b --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/integration/EosIntegrationTest.java @@ -0,0 +1,1105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.integration; + +import org.apache.kafka.clients.admin.Admin; +import org.apache.kafka.clients.admin.AdminClientConfig; +import org.apache.kafka.clients.consumer.Consumer; +import org.apache.kafka.clients.consumer.ConsumerConfig; +import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.clients.consumer.KafkaConsumer; +import org.apache.kafka.clients.producer.ProducerConfig; +import org.apache.kafka.common.IsolationLevel; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.serialization.ByteArrayDeserializer; +import org.apache.kafka.common.serialization.LongDeserializer; +import org.apache.kafka.common.serialization.LongSerializer; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.streams.KafkaStreams; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.integration.utils.EmbeddedKafkaCluster; +import org.apache.kafka.streams.integration.utils.IntegrationTestUtils; +import org.apache.kafka.streams.kstream.KStream; +import org.apache.kafka.streams.kstream.Transformer; +import org.apache.kafka.streams.kstream.TransformerSupplier; +import org.apache.kafka.streams.processor.ProcessorContext; +import org.apache.kafka.streams.processor.StateStoreContext; +import org.apache.kafka.streams.processor.TaskId; +import org.apache.kafka.streams.processor.internals.StreamThread; +import org.apache.kafka.streams.state.KeyValueIterator; +import org.apache.kafka.streams.state.KeyValueStore; +import org.apache.kafka.streams.state.QueryableStoreTypes; +import org.apache.kafka.streams.state.ReadOnlyKeyValueStore; +import org.apache.kafka.streams.state.StoreBuilder; +import org.apache.kafka.streams.state.Stores; +import org.apache.kafka.streams.state.internals.OffsetCheckpoint; +import org.apache.kafka.streams.state.internals.RocksDBStore; +import org.apache.kafka.streams.state.internals.RocksDbKeyValueBytesStoreSupplier; +import org.apache.kafka.test.IntegrationTest; +import org.apache.kafka.test.MockInternalProcessorContext; +import org.apache.kafka.test.MockKeyValueStore; +import org.apache.kafka.test.StreamsTestUtils; +import org.apache.kafka.test.TestUtils; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import org.junit.runners.Parameterized.Parameter; +import org.junit.runners.Parameterized.Parameters; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.File; +import java.io.IOException; +import java.math.BigInteger; +import java.time.Duration; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Properties; +import java.util.Set; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; + +import static org.apache.kafka.common.utils.Utils.mkEntry; +import static org.apache.kafka.common.utils.Utils.mkMap; +import static org.apache.kafka.streams.integration.utils.IntegrationTestUtils.waitForEmptyConsumerGroup; 
+import static org.apache.kafka.test.StreamsTestUtils.startKafkaStreamsAndWaitForRunningState; +import static org.apache.kafka.test.TestUtils.consumerConfig; +import static org.apache.kafka.test.TestUtils.waitForCondition; +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; + +@RunWith(Parameterized.class) +@Category({IntegrationTest.class}) +public class EosIntegrationTest { + private static final Logger LOG = LoggerFactory.getLogger(EosIntegrationTest.class); + private static final int NUM_BROKERS = 3; + private static final int MAX_POLL_INTERVAL_MS = 5 * 1000; + private static final int MAX_WAIT_TIME_MS = 60 * 1000; + + public static final EmbeddedKafkaCluster CLUSTER = new EmbeddedKafkaCluster( + NUM_BROKERS, + Utils.mkProperties(Collections.singletonMap("auto.create.topics.enable", "true")) + ); + + @BeforeClass + public static void startCluster() throws IOException { + CLUSTER.start(); + } + + @AfterClass + public static void closeCluster() { + CLUSTER.stop(); + } + + + private String applicationId; + private final static int NUM_TOPIC_PARTITIONS = 2; + private final static String CONSUMER_GROUP_ID = "readCommitted"; + private final static String SINGLE_PARTITION_INPUT_TOPIC = "singlePartitionInputTopic"; + private final static String SINGLE_PARTITION_THROUGH_TOPIC = "singlePartitionThroughTopic"; + private final static String SINGLE_PARTITION_OUTPUT_TOPIC = "singlePartitionOutputTopic"; + private final static String MULTI_PARTITION_INPUT_TOPIC = "multiPartitionInputTopic"; + private final static String MULTI_PARTITION_THROUGH_TOPIC = "multiPartitionThroughTopic"; + private final static String MULTI_PARTITION_OUTPUT_TOPIC = "multiPartitionOutputTopic"; + private final String storeName = "store"; + + private AtomicBoolean errorInjected; + private AtomicBoolean stallInjected; + private AtomicReference stallingHost; + private volatile boolean doStall = true; + private AtomicInteger commitRequested; + private Throwable uncaughtException; + + private static final AtomicInteger TEST_NUMBER = new AtomicInteger(0); + + private volatile boolean hasUnexpectedError = false; + + private String stateTmpDir; + + @SuppressWarnings("deprecation") + @Parameters(name = "{0}") + public static Collection data() { + return Arrays.asList(new String[][]{ + {StreamsConfig.AT_LEAST_ONCE}, + {StreamsConfig.EXACTLY_ONCE}, + {StreamsConfig.EXACTLY_ONCE_V2} + }); + } + + @Parameter + public String eosConfig; + + @Before + public void createTopics() throws Exception { + applicationId = "appId-" + TEST_NUMBER.getAndIncrement(); + CLUSTER.deleteTopicsAndWait( + SINGLE_PARTITION_INPUT_TOPIC, MULTI_PARTITION_INPUT_TOPIC, + SINGLE_PARTITION_THROUGH_TOPIC, MULTI_PARTITION_THROUGH_TOPIC, + SINGLE_PARTITION_OUTPUT_TOPIC, MULTI_PARTITION_OUTPUT_TOPIC); + + CLUSTER.createTopics(SINGLE_PARTITION_INPUT_TOPIC, SINGLE_PARTITION_THROUGH_TOPIC, SINGLE_PARTITION_OUTPUT_TOPIC); + CLUSTER.createTopic(MULTI_PARTITION_INPUT_TOPIC, NUM_TOPIC_PARTITIONS, 1); + CLUSTER.createTopic(MULTI_PARTITION_THROUGH_TOPIC, NUM_TOPIC_PARTITIONS, 1); + CLUSTER.createTopic(MULTI_PARTITION_OUTPUT_TOPIC, NUM_TOPIC_PARTITIONS, 1); + } + + @Test + public void shouldBeAbleToRunWithEosEnabled() throws Exception { + runSimpleCopyTest(1, SINGLE_PARTITION_INPUT_TOPIC, null, SINGLE_PARTITION_OUTPUT_TOPIC, false, eosConfig); + } + + 
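
[Editor's note, not part of the patch: the EOS tests above and below verify their output through a consumer configured with isolation.level=read_committed, so that records from aborted transactions are never observed (the same setting is built inline in runSimpleCopyTest further down). As a minimal stand-alone sketch of such a verification consumer, using the cluster and topic constants from this class and a hypothetical group id "eos-verification":]

    final Properties verificationConfig = new Properties();
    verificationConfig.put(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, CLUSTER.bootstrapServers());
    verificationConfig.put(ConsumerConfig.GROUP_ID_CONFIG, "eos-verification");
    verificationConfig.put(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "earliest");
    verificationConfig.put(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, LongDeserializer.class);
    verificationConfig.put(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, LongDeserializer.class);
    // only records from committed transactions are returned by poll() under read_committed
    verificationConfig.put(ConsumerConfig.ISOLATION_LEVEL_CONFIG,
        IsolationLevel.READ_COMMITTED.name().toLowerCase(Locale.ROOT));
    try (final Consumer<Long, Long> consumer = new KafkaConsumer<>(verificationConfig)) {
        consumer.subscribe(Collections.singleton(SINGLE_PARTITION_OUTPUT_TOPIC));
        consumer.poll(Duration.ofSeconds(1));
    }
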
@Test + public void shouldCommitCorrectOffsetIfInputTopicIsTransactional() throws Exception { + runSimpleCopyTest(1, SINGLE_PARTITION_INPUT_TOPIC, null, SINGLE_PARTITION_OUTPUT_TOPIC, true, eosConfig); + + try (final Admin adminClient = Admin.create(mkMap(mkEntry(AdminClientConfig.BOOTSTRAP_SERVERS_CONFIG, CLUSTER.bootstrapServers()))); + final Consumer consumer = new KafkaConsumer<>(mkMap( + mkEntry(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, CLUSTER.bootstrapServers()), + mkEntry(ConsumerConfig.GROUP_ID_CONFIG, applicationId), + mkEntry(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, ByteArrayDeserializer.class), + mkEntry(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, ByteArrayDeserializer.class)))) { + + waitForEmptyConsumerGroup(adminClient, applicationId, 5 * MAX_POLL_INTERVAL_MS); + + final TopicPartition topicPartition = new TopicPartition(SINGLE_PARTITION_INPUT_TOPIC, 0); + final Collection topicPartitions = Collections.singleton(topicPartition); + + final long committedOffset = adminClient.listConsumerGroupOffsets(applicationId).partitionsToOffsetAndMetadata().get().get(topicPartition).offset(); + + consumer.assign(topicPartitions); + final long consumerPosition = consumer.position(topicPartition); + final long endOffset = consumer.endOffsets(topicPartitions).get(topicPartition); + + assertThat(committedOffset, equalTo(consumerPosition)); + assertThat(committedOffset, equalTo(endOffset)); + } + } + + @Test + public void shouldBeAbleToRestartAfterClose() throws Exception { + runSimpleCopyTest(2, SINGLE_PARTITION_INPUT_TOPIC, null, SINGLE_PARTITION_OUTPUT_TOPIC, false, eosConfig); + } + + @Test + public void shouldBeAbleToCommitToMultiplePartitions() throws Exception { + runSimpleCopyTest(1, SINGLE_PARTITION_INPUT_TOPIC, null, MULTI_PARTITION_OUTPUT_TOPIC, false, eosConfig); + } + + @Test + public void shouldBeAbleToCommitMultiplePartitionOffsets() throws Exception { + runSimpleCopyTest(1, MULTI_PARTITION_INPUT_TOPIC, null, SINGLE_PARTITION_OUTPUT_TOPIC, false, eosConfig); + } + + @Test + public void shouldBeAbleToRunWithTwoSubtopologies() throws Exception { + runSimpleCopyTest(1, SINGLE_PARTITION_INPUT_TOPIC, SINGLE_PARTITION_THROUGH_TOPIC, SINGLE_PARTITION_OUTPUT_TOPIC, false, eosConfig); + } + + @Test + public void shouldBeAbleToRunWithTwoSubtopologiesAndMultiplePartitions() throws Exception { + runSimpleCopyTest(1, MULTI_PARTITION_INPUT_TOPIC, MULTI_PARTITION_THROUGH_TOPIC, MULTI_PARTITION_OUTPUT_TOPIC, false, eosConfig); + } + + private void runSimpleCopyTest(final int numberOfRestarts, + final String inputTopic, + final String throughTopic, + final String outputTopic, + final boolean inputTopicTransactional, + final String eosConfig) throws Exception { + final StreamsBuilder builder = new StreamsBuilder(); + final KStream input = builder.stream(inputTopic); + KStream output = input; + if (throughTopic != null) { + input.to(throughTopic); + output = builder.stream(throughTopic); + } + output.to(outputTopic); + + final Properties properties = new Properties(); + properties.put(StreamsConfig.PROCESSING_GUARANTEE_CONFIG, eosConfig); + properties.put(StreamsConfig.CACHE_MAX_BYTES_BUFFERING_CONFIG, 0); + properties.put(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG, 100L); + properties.put(StreamsConfig.consumerPrefix(ConsumerConfig.MAX_POLL_RECORDS_CONFIG), 1); + properties.put(StreamsConfig.consumerPrefix(ConsumerConfig.METADATA_MAX_AGE_CONFIG), "1000"); + properties.put(StreamsConfig.consumerPrefix(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG), "earliest"); + 
properties.put(StreamsConfig.consumerPrefix(ConsumerConfig.SESSION_TIMEOUT_MS_CONFIG), MAX_POLL_INTERVAL_MS - 1); + properties.put(StreamsConfig.consumerPrefix(ConsumerConfig.MAX_POLL_INTERVAL_MS_CONFIG), MAX_POLL_INTERVAL_MS); + + for (int i = 0; i < numberOfRestarts; ++i) { + final Properties config = StreamsTestUtils.getStreamsConfig( + applicationId, + CLUSTER.bootstrapServers(), + Serdes.LongSerde.class.getName(), + Serdes.LongSerde.class.getName(), + properties); + + try (final KafkaStreams streams = new KafkaStreams(builder.build(), config)) { + startKafkaStreamsAndWaitForRunningState(streams, MAX_WAIT_TIME_MS); + + final List> inputData = prepareData(i * 100, i * 100 + 10L, 0L, 1L); + + final Properties producerConfigs = new Properties(); + if (inputTopicTransactional) { + producerConfigs.setProperty(ProducerConfig.TRANSACTIONAL_ID_CONFIG, applicationId + "-input-producer"); + } + + IntegrationTestUtils.produceKeyValuesSynchronously( + inputTopic, + inputData, + TestUtils.producerConfig(CLUSTER.bootstrapServers(), LongSerializer.class, LongSerializer.class, producerConfigs), + CLUSTER.time, + inputTopicTransactional + ); + + final List> committedRecords = + IntegrationTestUtils.waitUntilMinKeyValueRecordsReceived( + TestUtils.consumerConfig( + CLUSTER.bootstrapServers(), + CONSUMER_GROUP_ID, + LongDeserializer.class, + LongDeserializer.class, + Utils.mkProperties(Collections.singletonMap( + ConsumerConfig.ISOLATION_LEVEL_CONFIG, + IsolationLevel.READ_COMMITTED.name().toLowerCase(Locale.ROOT))) + ), + outputTopic, + inputData.size() + ); + + checkResultPerKey(committedRecords, inputData, "The committed records do not match what expected"); + } + } + } + + private void checkResultPerKey(final List> result, + final List> expectedResult, + final String reason) { + final Set allKeys = new HashSet<>(); + addAllKeys(allKeys, result); + addAllKeys(allKeys, expectedResult); + + for (final Long key : allKeys) { + assertThat(reason, getAllRecordPerKey(key, result), equalTo(getAllRecordPerKey(key, expectedResult))); + } + } + + private void addAllKeys(final Set allKeys, final List> records) { + for (final KeyValue record : records) { + allKeys.add(record.key); + } + } + + private List> getAllRecordPerKey(final Long key, final List> records) { + final List> recordsPerKey = new ArrayList<>(records.size()); + + for (final KeyValue record : records) { + if (record.key.equals(key)) { + recordsPerKey.add(record); + } + } + + return recordsPerKey; + } + + @Test + public void shouldBeAbleToPerformMultipleTransactions() throws Exception { + final StreamsBuilder builder = new StreamsBuilder(); + builder.stream(SINGLE_PARTITION_INPUT_TOPIC).to(SINGLE_PARTITION_OUTPUT_TOPIC); + + final Properties properties = new Properties(); + properties.put(StreamsConfig.PROCESSING_GUARANTEE_CONFIG, eosConfig); + properties.put(StreamsConfig.CACHE_MAX_BYTES_BUFFERING_CONFIG, 0); + properties.put(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG, 100L); + properties.put(ConsumerConfig.METADATA_MAX_AGE_CONFIG, "1000"); + properties.put(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "earliest"); + + final Properties config = StreamsTestUtils.getStreamsConfig( + applicationId, + CLUSTER.bootstrapServers(), + Serdes.LongSerde.class.getName(), + Serdes.LongSerde.class.getName(), + properties); + + try (final KafkaStreams streams = new KafkaStreams(builder.build(), config)) { + startKafkaStreamsAndWaitForRunningState(streams, MAX_WAIT_TIME_MS); + + final List> firstBurstOfData = prepareData(0L, 5L, 0L); + final List> secondBurstOfData = 
prepareData(5L, 8L, 0L); + + IntegrationTestUtils.produceKeyValuesSynchronously( + SINGLE_PARTITION_INPUT_TOPIC, + firstBurstOfData, + TestUtils.producerConfig(CLUSTER.bootstrapServers(), LongSerializer.class, LongSerializer.class), + CLUSTER.time + ); + + final List> firstCommittedRecords = + IntegrationTestUtils.waitUntilMinKeyValueRecordsReceived( + TestUtils.consumerConfig( + CLUSTER.bootstrapServers(), + CONSUMER_GROUP_ID, + LongDeserializer.class, + LongDeserializer.class, + Utils.mkProperties(Collections.singletonMap( + ConsumerConfig.ISOLATION_LEVEL_CONFIG, + IsolationLevel.READ_COMMITTED.name().toLowerCase(Locale.ROOT))) + ), + SINGLE_PARTITION_OUTPUT_TOPIC, + firstBurstOfData.size() + ); + + assertThat(firstCommittedRecords, equalTo(firstBurstOfData)); + + IntegrationTestUtils.produceKeyValuesSynchronously( + SINGLE_PARTITION_INPUT_TOPIC, + secondBurstOfData, + TestUtils.producerConfig(CLUSTER.bootstrapServers(), LongSerializer.class, LongSerializer.class), + CLUSTER.time + ); + + final List> secondCommittedRecords = + IntegrationTestUtils.waitUntilMinKeyValueRecordsReceived( + TestUtils.consumerConfig( + CLUSTER.bootstrapServers(), + CONSUMER_GROUP_ID, + LongDeserializer.class, + LongDeserializer.class, + Utils.mkProperties(Collections.singletonMap( + ConsumerConfig.ISOLATION_LEVEL_CONFIG, + IsolationLevel.READ_COMMITTED.name().toLowerCase(Locale.ROOT))) + ), + SINGLE_PARTITION_OUTPUT_TOPIC, + secondBurstOfData.size() + ); + + assertThat(secondCommittedRecords, equalTo(secondBurstOfData)); + } + } + + @Test + public void shouldNotViolateEosIfOneTaskFails() throws Exception { + if (eosConfig.equals(StreamsConfig.AT_LEAST_ONCE)) return; + + // this test writes 10 + 5 + 5 records per partition (running with 2 partitions) + // the app is supposed to copy all 40 records into the output topic + // + // the app first commits after each 10 records per partition(total 20 records), and thus will have 2 * 5 uncommitted writes + // + // the failure gets inject after 20 committed and 30 uncommitted records got received + // -> the failure only kills one thread + // after fail over, we should read 40 committed records (even if 50 record got written) + + try (final KafkaStreams streams = getKafkaStreams("dummy", false, "appDir", 2, eosConfig, MAX_POLL_INTERVAL_MS)) { + startKafkaStreamsAndWaitForRunningState(streams, MAX_WAIT_TIME_MS); + + final List> committedDataBeforeFailure = prepareData(0L, 10L, 0L, 1L); + final List> uncommittedDataBeforeFailure = prepareData(10L, 15L, 0L, 1L); + + final List> dataBeforeFailure = new ArrayList<>( + committedDataBeforeFailure.size() + uncommittedDataBeforeFailure.size()); + dataBeforeFailure.addAll(committedDataBeforeFailure); + dataBeforeFailure.addAll(uncommittedDataBeforeFailure); + + final List> dataAfterFailure = prepareData(15L, 20L, 0L, 1L); + + writeInputData(committedDataBeforeFailure); + + waitForCondition( + () -> commitRequested.get() == 2, MAX_WAIT_TIME_MS, + "StreamsTasks did not request commit."); + + // expected end state per output partition (C == COMMIT; A == ABORT; ---> indicate the changes): + // + // p-0: ---> 10 rec + C + // p-1: ---> 10 rec + C + + final List> committedRecords = readResult(committedDataBeforeFailure.size(), CONSUMER_GROUP_ID); + checkResultPerKey( + committedRecords, + committedDataBeforeFailure, + "The committed records before failure do not match what expected"); + + writeInputData(uncommittedDataBeforeFailure); + + // expected end state per output partition (C == COMMIT; A == ABORT; ---> indicate the changes): + 
// + // p-0: ---> 10 rec + C + 5 rec (pending) + // p-1: ---> 10 rec + C + 5 rec (pending) + + final List> uncommittedRecords = readResult(dataBeforeFailure.size(), null); + checkResultPerKey( + uncommittedRecords, + dataBeforeFailure, + "The uncommitted records before failure do not match what expected"); + + errorInjected.set(true); + writeInputData(dataAfterFailure); + + waitForCondition( + () -> uncaughtException != null, MAX_WAIT_TIME_MS, + "Should receive uncaught exception from one StreamThread."); + + // expected end state per output partition (C == COMMIT; A == ABORT; ---> indicate the changes): + // + // p-0: ---> 10 rec + C + 5 rec + C + 5 rec + C + // p-1: ---> 10 rec + C + 5 rec + C + 5 rec + C + + final List> allCommittedRecords = readResult( + committedDataBeforeFailure.size() + uncommittedDataBeforeFailure.size() + dataAfterFailure.size(), + CONSUMER_GROUP_ID + "_ALL"); + + final List> committedRecordsAfterFailure = readResult( + uncommittedDataBeforeFailure.size() + dataAfterFailure.size(), + CONSUMER_GROUP_ID); + + final int allCommittedRecordsAfterRecoverySize = committedDataBeforeFailure.size() + + uncommittedDataBeforeFailure.size() + dataAfterFailure.size(); + final List> allExpectedCommittedRecordsAfterRecovery = new ArrayList<>(allCommittedRecordsAfterRecoverySize); + allExpectedCommittedRecordsAfterRecovery.addAll(committedDataBeforeFailure); + allExpectedCommittedRecordsAfterRecovery.addAll(uncommittedDataBeforeFailure); + allExpectedCommittedRecordsAfterRecovery.addAll(dataAfterFailure); + + final int committedRecordsAfterRecoverySize = uncommittedDataBeforeFailure.size() + dataAfterFailure.size(); + final List> expectedCommittedRecordsAfterRecovery = new ArrayList<>(committedRecordsAfterRecoverySize); + expectedCommittedRecordsAfterRecovery.addAll(uncommittedDataBeforeFailure); + expectedCommittedRecordsAfterRecovery.addAll(dataAfterFailure); + + checkResultPerKey( + allCommittedRecords, + allExpectedCommittedRecordsAfterRecovery, + "The all committed records after recovery do not match what expected"); + checkResultPerKey( + committedRecordsAfterFailure, + expectedCommittedRecordsAfterRecovery, + "The committed records after recovery do not match what expected"); + + assertThat("Should only get one uncaught exception from Streams.", hasUnexpectedError, is(false)); + } + } + + @Test + public void shouldNotViolateEosIfOneTaskFailsWithState() throws Exception { + if (eosConfig.equals(StreamsConfig.AT_LEAST_ONCE)) return; + + // this test updates a store with 10 + 5 + 5 records per partition (running with 2 partitions) + // the app is supposed to emit all 40 update records into the output topic + // + // the app first commits after each 10 records per partition (total 20 records), and thus will have 2 * 5 uncommitted writes + // and store updates (ie, another 5 uncommitted writes to a changelog topic per partition) + // in the uncommitted batch, sending some data for the new key to validate that upon resuming they will not be shown up in the store + // + // the failure gets inject after 20 committed and 30 uncommitted records got received + // -> the failure only kills one thread + // after fail over, we should read 40 committed records and the state stores should contain the correct sums + // per key (even if some records got processed twice) + + // We need more processing time under "with state" situation, so increasing the max.poll.interval.ms + // to avoid unexpected rebalance during test, which will cause unexpected fail over triggered + try (final 
KafkaStreams streams = getKafkaStreams("dummy", true, "appDir", 2, eosConfig, 3 * MAX_POLL_INTERVAL_MS)) { + startKafkaStreamsAndWaitForRunningState(streams, MAX_WAIT_TIME_MS); + + final List> committedDataBeforeFailure = prepareData(0L, 10L, 0L, 1L); + final List> uncommittedDataBeforeFailure = prepareData(10L, 15L, 0L, 1L, 2L, 3L); + + final List> dataBeforeFailure = new ArrayList<>( + committedDataBeforeFailure.size() + uncommittedDataBeforeFailure.size()); + dataBeforeFailure.addAll(committedDataBeforeFailure); + dataBeforeFailure.addAll(uncommittedDataBeforeFailure); + + final List> dataAfterFailure = prepareData(15L, 20L, 0L, 1L); + + writeInputData(committedDataBeforeFailure); + + waitForCondition( + () -> commitRequested.get() == 2, MAX_WAIT_TIME_MS, + "SteamsTasks did not request commit."); + + // expected end state per output partition (C == COMMIT; A == ABORT; ---> indicate the changes): + // + // p-0: ---> 10 rec + C + // p-1: ---> 10 rec + C + + final List> committedRecords = readResult(committedDataBeforeFailure.size(), CONSUMER_GROUP_ID); + checkResultPerKey( + committedRecords, + computeExpectedResult(committedDataBeforeFailure), + "The committed records before failure do not match what expected"); + + writeInputData(uncommittedDataBeforeFailure); + + // expected end state per output partition (C == COMMIT; A == ABORT; ---> indicate the changes): + // + // p-0: ---> 10 rec + C + 5 rec (pending) + // p-1: ---> 10 rec + C + 5 rec (pending) + + final List> uncommittedRecords = readResult(dataBeforeFailure.size(), null); + final List> expectedResultBeforeFailure = computeExpectedResult(dataBeforeFailure); + + + checkResultPerKey( + uncommittedRecords, + expectedResultBeforeFailure, + "The uncommitted records before failure do not match what expected"); + verifyStateStore( + streams, + getMaxPerKey(expectedResultBeforeFailure), + "The state store content before failure do not match what expected"); + + errorInjected.set(true); + writeInputData(dataAfterFailure); + + waitForCondition( + () -> uncaughtException != null, MAX_WAIT_TIME_MS, + "Should receive uncaught exception from one StreamThread."); + + // expected end state per output partition (C == COMMIT; A == ABORT; ---> indicate the changes): + // + // p-0: ---> 10 rec + C + 5 rec + C + 5 rec + C + // p-1: ---> 10 rec + C + 5 rec + C + 5 rec + C + + final List> allCommittedRecords = readResult( + committedDataBeforeFailure.size() + uncommittedDataBeforeFailure.size() + dataAfterFailure.size(), + CONSUMER_GROUP_ID + "_ALL"); + + final List> committedRecordsAfterFailure = readResult( + uncommittedDataBeforeFailure.size() + dataAfterFailure.size(), + CONSUMER_GROUP_ID); + + final int allCommittedRecordsAfterRecoverySize = committedDataBeforeFailure.size() + + uncommittedDataBeforeFailure.size() + dataAfterFailure.size(); + final List> allExpectedCommittedRecordsAfterRecovery = new ArrayList<>(allCommittedRecordsAfterRecoverySize); + allExpectedCommittedRecordsAfterRecovery.addAll(committedDataBeforeFailure); + allExpectedCommittedRecordsAfterRecovery.addAll(uncommittedDataBeforeFailure); + allExpectedCommittedRecordsAfterRecovery.addAll(dataAfterFailure); + + final List> expectedResult = computeExpectedResult(allExpectedCommittedRecordsAfterRecovery); + + checkResultPerKey( + allCommittedRecords, + expectedResult, + "The all committed records after recovery do not match what expected"); + + checkResultPerKey( + committedRecordsAfterFailure, + expectedResult.subList(committedDataBeforeFailure.size(), expectedResult.size()), + 
"The committed records after recovery do not match what expected"); + + verifyStateStore( + streams, + getMaxPerKey(expectedResult), + "The state store content after recovery do not match what expected"); + + assertThat("Should only get one uncaught exception from Streams.", hasUnexpectedError, is(false)); + } + } + + @Test + public void shouldNotViolateEosIfOneTaskGetsFencedUsingIsolatedAppInstances() throws Exception { + if (eosConfig.equals(StreamsConfig.AT_LEAST_ONCE)) return; + + // this test writes 10 + 5 + 5 + 10 records per partition (running with 2 partitions) + // the app is supposed to copy all 60 records into the output topic + // + // the app first commits after each 10 records per partition, and thus will have 2 * 5 uncommitted writes + // + // Then, a stall gets injected after 20 committed and 30 uncommitted records got received + // -> the stall only affects one thread and should trigger a rebalance + // after rebalancing, we should read 40 committed records (even if 50 record got written) + // + // afterwards, the "stalling" thread resumes, and another rebalance should get triggered + // we write the remaining 20 records and verify to read 60 result records + + try ( + final KafkaStreams streams1 = getKafkaStreams("streams1", false, "appDir1", 1, eosConfig, MAX_POLL_INTERVAL_MS); + final KafkaStreams streams2 = getKafkaStreams("streams2", false, "appDir2", 1, eosConfig, MAX_POLL_INTERVAL_MS) + ) { + startKafkaStreamsAndWaitForRunningState(streams1, MAX_WAIT_TIME_MS); + startKafkaStreamsAndWaitForRunningState(streams2, MAX_WAIT_TIME_MS); + + final List> committedDataBeforeStall = prepareData(0L, 10L, 0L, 1L); + final List> uncommittedDataBeforeStall = prepareData(10L, 15L, 0L, 1L); + + final List> dataBeforeStall = new ArrayList<>( + committedDataBeforeStall.size() + uncommittedDataBeforeStall.size()); + dataBeforeStall.addAll(committedDataBeforeStall); + dataBeforeStall.addAll(uncommittedDataBeforeStall); + + final List> dataToTriggerFirstRebalance = prepareData(15L, 20L, 0L, 1L); + + final List> dataAfterSecondRebalance = prepareData(20L, 30L, 0L, 1L); + + writeInputData(committedDataBeforeStall); + + waitForCondition( + () -> commitRequested.get() == 2, MAX_WAIT_TIME_MS, + "SteamsTasks did not request commit."); + + // expected end state per output partition (C == COMMIT; A == ABORT; ---> indicate the changes): + // + // p-0: ---> 10 rec + C + // p-1: ---> 10 rec + C + + final List> committedRecords = readResult(committedDataBeforeStall.size(), CONSUMER_GROUP_ID); + checkResultPerKey( + committedRecords, + committedDataBeforeStall, + "The committed records before stall do not match what expected"); + + writeInputData(uncommittedDataBeforeStall); + + // expected end state per output partition (C == COMMIT; A == ABORT; ---> indicate the changes): + // + // p-0: ---> 10 rec + C + 5 rec (pending) + // p-1: ---> 10 rec + C + 5 rec (pending) + + final List> uncommittedRecords = readResult(dataBeforeStall.size(), null); + checkResultPerKey( + uncommittedRecords, + dataBeforeStall, + "The uncommitted records before stall do not match what expected"); + + LOG.info("Injecting Stall"); + stallInjected.set(true); + writeInputData(dataToTriggerFirstRebalance); + LOG.info("Input Data Written"); + waitForCondition( + () -> stallingHost.get() != null, + MAX_WAIT_TIME_MS, + "Expected a host to start stalling" + ); + final String observedStallingHost = stallingHost.get(); + final KafkaStreams stallingInstance; + final KafkaStreams remainingInstance; + if 
("streams1".equals(observedStallingHost)) { + stallingInstance = streams1; + remainingInstance = streams2; + } else if ("streams2".equals(observedStallingHost)) { + stallingInstance = streams2; + remainingInstance = streams1; + } else { + throw new IllegalArgumentException("unexpected host name: " + observedStallingHost); + } + + // the stalling instance won't have an updated view, and it doesn't matter what it thinks + // the assignment is. We only really care that the remaining instance only sees one host + // that owns both partitions. + waitForCondition( + () -> stallingInstance.metadataForAllStreamsClients().size() == 2 + && remainingInstance.metadataForAllStreamsClients().size() == 1 + && remainingInstance.metadataForAllStreamsClients().iterator().next().topicPartitions().size() == 2, + MAX_WAIT_TIME_MS, + () -> "Should have rebalanced.\n" + + "Streams1[" + streams1.metadataForAllStreamsClients() + "]\n" + + "Streams2[" + streams2.metadataForAllStreamsClients() + "]"); + + // expected end state per output partition (C == COMMIT; A == ABORT; ---> indicate the changes): + // + // p-0: ---> 10 rec + C + 5 rec + C + 5 rec + C + // p-1: ---> 10 rec + C + 5 rec + C + 5 rec + C + + final List> committedRecordsAfterRebalance = readResult( + uncommittedDataBeforeStall.size() + dataToTriggerFirstRebalance.size(), + CONSUMER_GROUP_ID); + + final List> expectedCommittedRecordsAfterRebalance = new ArrayList<>( + uncommittedDataBeforeStall.size() + dataToTriggerFirstRebalance.size()); + expectedCommittedRecordsAfterRebalance.addAll(uncommittedDataBeforeStall); + expectedCommittedRecordsAfterRebalance.addAll(dataToTriggerFirstRebalance); + + checkResultPerKey( + committedRecordsAfterRebalance, + expectedCommittedRecordsAfterRebalance, + "The all committed records after rebalance do not match what expected"); + + LOG.info("Releasing Stall"); + doStall = false; + // Once the stalling host rejoins the group, we expect both instances to see both instances. 
+ // It doesn't really matter what the assignment is, but we might as well also assert that they + // both see both partitions assigned exactly once + waitForCondition( + () -> streams1.metadataForAllStreamsClients().size() == 2 + && streams2.metadataForAllStreamsClients().size() == 2 + && streams1.metadataForAllStreamsClients().stream().mapToLong(meta -> meta.topicPartitions().size()).sum() == 2 + && streams2.metadataForAllStreamsClients().stream().mapToLong(meta -> meta.topicPartitions().size()).sum() == 2, + MAX_WAIT_TIME_MS, + () -> "Should have rebalanced.\n" + + "Streams1[" + streams1.metadataForAllStreamsClients() + "]\n" + + "Streams2[" + streams2.metadataForAllStreamsClients() + "]"); + + writeInputData(dataAfterSecondRebalance); + + // expected end state per output partition (C == COMMIT; A == ABORT; ---> indicate the changes): + // + // p-0: ---> 10 rec + C + 5 rec + C + 5 rec + C + 10 rec + C + // p-1: ---> 10 rec + C + 5 rec + C + 5 rec + C + 10 rec + C + + final List> allCommittedRecords = readResult( + committedDataBeforeStall.size() + uncommittedDataBeforeStall.size() + + dataToTriggerFirstRebalance.size() + dataAfterSecondRebalance.size(), + CONSUMER_GROUP_ID + "_ALL"); + + final int allCommittedRecordsAfterRecoverySize = committedDataBeforeStall.size() + + uncommittedDataBeforeStall.size() + dataToTriggerFirstRebalance.size() + dataAfterSecondRebalance.size(); + final List> allExpectedCommittedRecordsAfterRecovery = new ArrayList<>(allCommittedRecordsAfterRecoverySize); + allExpectedCommittedRecordsAfterRecovery.addAll(committedDataBeforeStall); + allExpectedCommittedRecordsAfterRecovery.addAll(uncommittedDataBeforeStall); + allExpectedCommittedRecordsAfterRecovery.addAll(dataToTriggerFirstRebalance); + allExpectedCommittedRecordsAfterRecovery.addAll(dataAfterSecondRebalance); + + checkResultPerKey( + allCommittedRecords, + allExpectedCommittedRecordsAfterRecovery, + "The all committed records after recovery do not match what expected"); + } + } + + @Test + public void shouldWriteLatestOffsetsToCheckpointOnShutdown() throws Exception { + final List> writtenData = prepareData(0L, 10, 0L, 1L); + final List> expectedResult = computeExpectedResult(writtenData); + + try (final KafkaStreams streams = getKafkaStreams("streams", true, "appDir", 1, eosConfig, MAX_POLL_INTERVAL_MS)) { + + startKafkaStreamsAndWaitForRunningState(streams, MAX_WAIT_TIME_MS); + + writeInputData(writtenData); + + waitForCondition( + () -> commitRequested.get() == 2, MAX_WAIT_TIME_MS, + "SteamsTasks did not request commit."); + + final List> committedRecords = readResult(writtenData.size(), CONSUMER_GROUP_ID); + + checkResultPerKey( + committedRecords, + expectedResult, + "The committed records do not match what expected"); + + verifyStateStore( + streams, + getMaxPerKey(expectedResult), + "The state store content do not match what expected"); + } + + final Set> expectedState = getMaxPerKey(expectedResult); + verifyStateIsInStoreAndOffsetsAreInCheckpoint(0, expectedState); + verifyStateIsInStoreAndOffsetsAreInCheckpoint(1, expectedState); + + assertThat("Not all expected state values were found in the state stores", expectedState.isEmpty()); + } + + private void verifyStateIsInStoreAndOffsetsAreInCheckpoint(final int partition, final Set> expectedState) throws IOException { + final String stateStoreDir = stateTmpDir + File.separator + "appDir" + File.separator + applicationId + File.separator + "0_" + partition + File.separator; + + // Verify that the data in the state store on disk is fully up-to-date + 
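+        // The store is opened directly from the task directory on disk, outside of any running KafkaStreams
+        // instance, via the RocksDbKeyValueBytesStoreSupplier; its raw byte[] keys and values are decoded back
+        // into longs and removed from the expected state. Afterwards the ".checkpoint" file is read and each
+        // checkpointed offset is compared against the last record offset of the corresponding changelog partition.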
final StateStoreContext context = new MockInternalProcessorContext(new Properties(), new TaskId(0, 0), new File(stateStoreDir)); + final MockKeyValueStore stateStore = new MockKeyValueStore("store", false); + final RocksDBStore store = (RocksDBStore) new RocksDbKeyValueBytesStoreSupplier(storeName, false).get(); + store.init(context, stateStore); + + store.all().forEachRemaining(kv -> { + final KeyValue kv2 = new KeyValue<>(new BigInteger(kv.key.get()).longValue(), new BigInteger(kv.value).longValue()); + expectedState.remove(kv2); + }); + + // Verify that the checkpointed offsets match exactly with max offset of the records in the changelog + final OffsetCheckpoint checkpoint = new OffsetCheckpoint(new File(stateStoreDir + ".checkpoint")); + final Map checkpointedOffsets = checkpoint.read(); + checkpointedOffsets.forEach(this::verifyChangelogMaxRecordOffsetMatchesCheckpointedOffset); + } + + private void verifyChangelogMaxRecordOffsetMatchesCheckpointedOffset(final TopicPartition tp, final long checkpointedOffset) { + final KafkaConsumer consumer = new KafkaConsumer<>(consumerConfig(CLUSTER.bootstrapServers(), Serdes.ByteArray().deserializer().getClass(), Serdes.ByteArray().deserializer().getClass())); + final List partitions = Collections.singletonList(tp); + consumer.assign(partitions); + consumer.seekToEnd(partitions); + final long topicEndOffset = consumer.position(tp); + + assertTrue("changelog topic end " + topicEndOffset + " is less than checkpointed offset " + checkpointedOffset, + topicEndOffset >= checkpointedOffset); + + consumer.seekToBeginning(partitions); + + Long maxRecordOffset = null; + while (consumer.position(tp) != topicEndOffset) { + final List> records = consumer.poll(Duration.ofMillis(0)).records(tp); + if (!records.isEmpty()) { + maxRecordOffset = records.get(records.size() - 1).offset(); + } + } + + assertEquals("Checkpointed offset does not match end of changelog", maxRecordOffset, (Long) checkpointedOffset); + } + + private List> prepareData(final long fromInclusive, + final long toExclusive, + final Long... 
keys) { + final Long dataSize = keys.length * (toExclusive - fromInclusive); + final List> data = new ArrayList<>(dataSize.intValue()); + + for (final Long k : keys) { + for (long v = fromInclusive; v < toExclusive; ++v) { + data.add(new KeyValue<>(k, v)); + } + } + + return data; + } + + @SuppressWarnings("deprecation") //the threads should no longer fail one thread one at a time + private KafkaStreams getKafkaStreams(final String dummyHostName, + final boolean withState, + final String appDir, + final int numberOfStreamsThreads, + final String eosConfig, + final int maxPollIntervalMs) { + commitRequested = new AtomicInteger(0); + errorInjected = new AtomicBoolean(false); + stallInjected = new AtomicBoolean(false); + stallingHost = new AtomicReference<>(); + final StreamsBuilder builder = new StreamsBuilder(); + + String[] storeNames = new String[0]; + if (withState) { + storeNames = new String[] {storeName}; + final StoreBuilder> storeBuilder = Stores + .keyValueStoreBuilder(Stores.persistentKeyValueStore(storeName), Serdes.Long(), Serdes.Long()) + .withCachingEnabled(); + + builder.addStateStore(storeBuilder); + } + + final KStream input = builder.stream(MULTI_PARTITION_INPUT_TOPIC); + input.transform(new TransformerSupplier>() { + @SuppressWarnings("unchecked") + @Override + public Transformer> get() { + return new Transformer>() { + ProcessorContext context; + KeyValueStore state = null; + + @Override + public void init(final ProcessorContext context) { + this.context = context; + + if (withState) { + state = (KeyValueStore) context.getStateStore(storeName); + } + } + + @Override + public KeyValue transform(final Long key, final Long value) { + if (stallInjected.compareAndSet(true, false)) { + LOG.info(dummyHostName + " is executing the injected stall"); + stallingHost.set(dummyHostName); + while (doStall) { + final StreamThread thread = (StreamThread) Thread.currentThread(); + if (thread.isInterrupted() || !thread.isRunning()) { + throw new RuntimeException("Detected we've been interrupted."); + } + try { + Thread.sleep(100); + } catch (final InterruptedException e) { + throw new RuntimeException(e); + } + } + } + + if ((value + 1) % 10 == 0) { + context.commit(); + commitRequested.incrementAndGet(); + } + + if (state != null) { + Long sum = state.get(key); + + if (sum == null) { + sum = value; + } else { + sum += value; + } + state.put(key, sum); + state.flush(); + } + + + if (errorInjected.compareAndSet(true, false)) { + // only tries to fail once on one of the task + throw new RuntimeException("Injected test exception."); + } + + if (state != null) { + return new KeyValue<>(key, state.get(key)); + } else { + return new KeyValue<>(key, value); + } + } + + @Override + public void close() { } + }; + } }, storeNames) + .to(SINGLE_PARTITION_OUTPUT_TOPIC); + + stateTmpDir = TestUtils.tempDirectory().getPath() + File.separator; + + final Properties properties = new Properties(); + // Set commit interval to a larger value to avoid affection of controlled stream commit, + // but not too large as we need to have a relatively low transaction timeout such + // that it should help trigger the timed out transaction in time. 
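+        // The same value is reused as the producer's transaction.timeout.ms below, so a transaction left hanging
+        // (e.g. by the injected stall) is proactively aborted by the broker after roughly 20 seconds.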
+ final long commitIntervalMs = 20_000L; + properties.put(StreamsConfig.PROCESSING_GUARANTEE_CONFIG, eosConfig); + properties.put(StreamsConfig.NUM_STREAM_THREADS_CONFIG, numberOfStreamsThreads); + properties.put(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG, commitIntervalMs); + properties.put(StreamsConfig.producerPrefix(ProducerConfig.TRANSACTION_TIMEOUT_CONFIG), (int) commitIntervalMs); + properties.put(StreamsConfig.consumerPrefix(ConsumerConfig.METADATA_MAX_AGE_CONFIG), "1000"); + properties.put(StreamsConfig.consumerPrefix(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG), "earliest"); + properties.put(StreamsConfig.consumerPrefix(ConsumerConfig.REQUEST_TIMEOUT_MS_CONFIG), maxPollIntervalMs); + properties.put(StreamsConfig.consumerPrefix(ConsumerConfig.SESSION_TIMEOUT_MS_CONFIG), maxPollIntervalMs - 1); + properties.put(StreamsConfig.consumerPrefix(ConsumerConfig.MAX_POLL_INTERVAL_MS_CONFIG), maxPollIntervalMs); + properties.put(StreamsConfig.CACHE_MAX_BYTES_BUFFERING_CONFIG, 0); + properties.put(StreamsConfig.STATE_DIR_CONFIG, stateTmpDir + appDir); + properties.put(StreamsConfig.APPLICATION_SERVER_CONFIG, dummyHostName + ":2142"); + + final Properties config = StreamsTestUtils.getStreamsConfig( + applicationId, + CLUSTER.bootstrapServers(), + Serdes.LongSerde.class.getName(), + Serdes.LongSerde.class.getName(), + properties); + + final KafkaStreams streams = new KafkaStreams(builder.build(), config); + + streams.setUncaughtExceptionHandler((t, e) -> { + if (uncaughtException != null || !e.getMessage().contains("Injected test exception")) { + e.printStackTrace(System.err); + hasUnexpectedError = true; + } + uncaughtException = e; + }); + + return streams; + } + + private void writeInputData(final List> records) { + IntegrationTestUtils.produceKeyValuesSynchronously( + MULTI_PARTITION_INPUT_TOPIC, + records, + TestUtils.producerConfig(CLUSTER.bootstrapServers(), LongSerializer.class, LongSerializer.class), + CLUSTER.time + ); + } + + private List> readResult(final int numberOfRecords, + final String groupId) throws Exception { + if (groupId != null) { + return IntegrationTestUtils.waitUntilMinKeyValueRecordsReceived( + TestUtils.consumerConfig( + CLUSTER.bootstrapServers(), + groupId, + LongDeserializer.class, + LongDeserializer.class, + Utils.mkProperties(Collections.singletonMap( + ConsumerConfig.ISOLATION_LEVEL_CONFIG, + IsolationLevel.READ_COMMITTED.name().toLowerCase(Locale.ROOT)))), + SINGLE_PARTITION_OUTPUT_TOPIC, + numberOfRecords + ); + } + + // read uncommitted + return IntegrationTestUtils.waitUntilMinKeyValueRecordsReceived( + TestUtils.consumerConfig(CLUSTER.bootstrapServers(), LongDeserializer.class, LongDeserializer.class), + SINGLE_PARTITION_OUTPUT_TOPIC, + numberOfRecords + ); + } + + private List> computeExpectedResult(final List> input) { + final List> expectedResult = new ArrayList<>(input.size()); + + final HashMap sums = new HashMap<>(); + + for (final KeyValue record : input) { + Long sum = sums.get(record.key); + if (sum == null) { + sum = record.value; + } else { + sum += record.value; + } + sums.put(record.key, sum); + expectedResult.add(new KeyValue<>(record.key, sum)); + } + + return expectedResult; + } + + private Set> getMaxPerKey(final List> input) { + final Set> expectedResult = new HashSet<>(input.size()); + + final HashMap maxPerKey = new HashMap<>(); + + for (final KeyValue record : input) { + final Long max = maxPerKey.get(record.key); + if (max == null || record.value > max) { + maxPerKey.put(record.key, record.value); + } + + } + + for (final Map.Entry max : 
maxPerKey.entrySet()) {
+            expectedResult.add(new KeyValue<>(max.getKey(), max.getValue()));
+        }
+
+        return expectedResult;
+    }
+
+    private void verifyStateStore(final KafkaStreams streams,
+                                  final Set<KeyValue<Long, Long>> expectedStoreContent,
+                                  final String reason) throws Exception {
+        final ReadOnlyKeyValueStore<Long, Long> store = IntegrationTestUtils
+            .getStore(300_000L, storeName, streams, QueryableStoreTypes.keyValueStore());
+        assertNotNull(store);
+
+        try (final KeyValueIterator<Long, Long> it = store.all()) {
+            while (it.hasNext()) {
+                assertTrue(reason, expectedStoreContent.remove(it.next()));
+            }
+
+            assertTrue(reason, expectedStoreContent.isEmpty());
+        }
+    }
+}
diff --git a/streams/src/test/java/org/apache/kafka/streams/integration/EosV2UpgradeIntegrationTest.java b/streams/src/test/java/org/apache/kafka/streams/integration/EosV2UpgradeIntegrationTest.java
new file mode 100644
index 0000000..b6aab86
--- /dev/null
+++ b/streams/src/test/java/org/apache/kafka/streams/integration/EosV2UpgradeIntegrationTest.java
@@ -0,0 +1,1208 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +package org.apache.kafka.streams.integration; + +import org.apache.kafka.clients.consumer.ConsumerConfig; +import org.apache.kafka.clients.producer.KafkaProducer; +import org.apache.kafka.clients.producer.Partitioner; +import org.apache.kafka.clients.producer.Producer; +import org.apache.kafka.clients.producer.ProducerConfig; +import org.apache.kafka.common.Cluster; +import org.apache.kafka.common.IsolationLevel; +import org.apache.kafka.common.serialization.ByteArraySerializer; +import org.apache.kafka.common.serialization.LongDeserializer; +import org.apache.kafka.common.serialization.LongSerializer; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.streams.KafkaStreams; +import org.apache.kafka.streams.KafkaStreams.State; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.StoreQueryParameters; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.StreamsConfig.InternalConfig; +import org.apache.kafka.streams.errors.StreamsUncaughtExceptionHandler; +import org.apache.kafka.streams.integration.utils.EmbeddedKafkaCluster; +import org.apache.kafka.streams.integration.utils.IntegrationTestUtils; +import org.apache.kafka.streams.integration.utils.IntegrationTestUtils.StableAssignmentListener; +import org.apache.kafka.streams.kstream.KStream; +import org.apache.kafka.streams.kstream.Transformer; +import org.apache.kafka.streams.kstream.TransformerSupplier; +import org.apache.kafka.streams.processor.ProcessorContext; +import org.apache.kafka.streams.processor.internals.DefaultKafkaClientSupplier; +import org.apache.kafka.streams.state.KeyValueIterator; +import org.apache.kafka.streams.state.KeyValueStore; +import org.apache.kafka.streams.state.QueryableStoreTypes; +import org.apache.kafka.streams.state.ReadOnlyKeyValueStore; +import org.apache.kafka.streams.state.StoreBuilder; +import org.apache.kafka.streams.state.Stores; +import org.apache.kafka.test.IntegrationTest; +import org.apache.kafka.test.StreamsTestUtils; +import org.apache.kafka.test.TestUtils; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import java.io.File; +import java.io.IOException; +import java.time.Duration; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Iterator; +import java.util.LinkedList; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Properties; +import java.util.Set; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.stream.Collectors; + +import static org.apache.kafka.common.utils.Utils.mkSet; +import static org.apache.kafka.test.TestUtils.waitForCondition; +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.Assert.assertFalse; + +@RunWith(Parameterized.class) +@Category({IntegrationTest.class}) +public class EosV2UpgradeIntegrationTest { + + @Parameterized.Parameters(name = "{0}") + public static Collection data() { + return Arrays.asList(new Boolean[][] { + {false}, + {true} + }); + } + + @Parameterized.Parameter + public 
boolean injectError; + + private static final int NUM_BROKERS = 3; + private static final int MAX_POLL_INTERVAL_MS = (int) Duration.ofSeconds(100L).toMillis(); + private static final long MAX_WAIT_TIME_MS = Duration.ofMinutes(1L).toMillis(); + + private static final List> CLOSE = + Collections.unmodifiableList( + Arrays.asList( + KeyValue.pair(KafkaStreams.State.RUNNING, KafkaStreams.State.PENDING_SHUTDOWN), + KeyValue.pair(KafkaStreams.State.PENDING_SHUTDOWN, KafkaStreams.State.NOT_RUNNING) + ) + ); + private static final List> CRASH = + Collections.unmodifiableList( + Collections.singletonList( + KeyValue.pair(State.PENDING_ERROR, State.ERROR) + ) + ); + + public static final EmbeddedKafkaCluster CLUSTER = new EmbeddedKafkaCluster( + NUM_BROKERS, + Utils.mkProperties(Collections.singletonMap("auto.create.topics.enable", "false")) + ); + + + @BeforeClass + public static void startCluster() throws IOException { + CLUSTER.start(); + } + + @AfterClass + public static void closeCluster() { + CLUSTER.stop(); + } + + private static String applicationId; + private final static int NUM_TOPIC_PARTITIONS = 4; + private final static String CONSUMER_GROUP_ID = "readCommitted"; + private final static String MULTI_PARTITION_INPUT_TOPIC = "multiPartitionInputTopic"; + private final static String MULTI_PARTITION_OUTPUT_TOPIC = "multiPartitionOutputTopic"; + private final static String APP_DIR_1 = "appDir1"; + private final static String APP_DIR_2 = "appDir2"; + private final static String UNEXPECTED_EXCEPTION_MSG = "Fail the test since we got an unexpected exception, or " + + "there are too many exceptions thrown, please check standard error log for more info."; + private final String storeName = "store"; + + private final StableAssignmentListener assignmentListener = new StableAssignmentListener(); + + private final AtomicBoolean errorInjectedClient1 = new AtomicBoolean(false); + private final AtomicBoolean errorInjectedClient2 = new AtomicBoolean(false); + private final AtomicBoolean commitErrorInjectedClient1 = new AtomicBoolean(false); + private final AtomicBoolean commitErrorInjectedClient2 = new AtomicBoolean(false); + private final AtomicInteger commitCounterClient1 = new AtomicInteger(-1); + private final AtomicInteger commitCounterClient2 = new AtomicInteger(-1); + private final AtomicInteger commitRequested = new AtomicInteger(0); + + private int testNumber = 0; + private Map exceptionCounts = new HashMap() { + { + put(APP_DIR_1, 0); + put(APP_DIR_2, 0); + } + }; + + private volatile boolean hasUnexpectedError = false; + + @Before + public void createTopics() throws Exception { + applicationId = "appId-" + ++testNumber; + CLUSTER.deleteTopicsAndWait( + MULTI_PARTITION_INPUT_TOPIC, + MULTI_PARTITION_OUTPUT_TOPIC, + applicationId + "-" + storeName + "-changelog" + ); + + CLUSTER.createTopic(MULTI_PARTITION_INPUT_TOPIC, NUM_TOPIC_PARTITIONS, 1); + CLUSTER.createTopic(MULTI_PARTITION_OUTPUT_TOPIC, NUM_TOPIC_PARTITIONS, 1); + } + + @SuppressWarnings("deprecation") + @Test + public void shouldUpgradeFromEosAlphaToEosV2() throws Exception { + // We use two KafkaStreams clients that we upgrade from eos-alpha to eos-V2. During the upgrade, + // we ensure that there are pending transaction and verify that data is processed correctly. + // + // We either close clients cleanly (`injectError = false`) or let them crash (`injectError = true`) during + // the upgrade. For both cases, EOS should not be violated. 
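+        // (Background: eos-alpha uses one producer per task, while eos-V2 uses a single producer per stream
+        //  thread that is fenced via the consumer group metadata; this is why the number of pending transactions
+        //  per client differs between the two modes in the steps below.)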
+ // + // Additionally, we inject errors while one client is on eos-alpha while the other client is on eos-V2: + // For this case, we inject the error during task commit phase, i.e., after offsets are appended to a TX, + // and before the TX is committed. The goal is to verify that the written but uncommitted offsets are not + // picked up, i.e., GroupCoordinator fencing works correctly. + // + // The commit interval is set to MAX_VALUE and the used `Processor` request commits manually so we have full + // control when a commit actually happens. We use an input topic with 4 partitions and each task will request + // a commit after processing 10 records. + // + // 1. start both clients and wait until rebalance stabilizes + // 2. write 10 records per input topic partition and verify that the result was committed + // 3. write 5 records per input topic partition to get pending transactions (verified via "read_uncommitted" mode) + // - all 4 pending transactions are based on task producers + // - we will get only 4 pending writes for one partition for the crash case as we crash processing the 5th record + // 4. stop/crash the first client, wait until rebalance stabilizes: + // - stop case: + // * verify that the stopped client did commit its pending transaction during shutdown + // * the second client will still have two pending transaction + // - crash case: + // * the pending transactions of the crashed client got aborted + // * the second client will have four pending transactions + // 5. restart the first client with eos-V2 enabled and wait until rebalance stabilizes + // - the rebalance should result in a commit of all tasks + // 6. write 5 record per input topic partition + // - stop case: + // * verify that the result was committed + // - crash case: + // * fail the second (i.e., eos-alpha) client during commit + // * the eos-V2 client should not pickup the pending offsets + // * verify uncommitted and committed result + // 7. only for crash case: + // 7a. restart the second client in eos-alpha mode and wait until rebalance stabilizes + // 7b. write 10 records per input topic partition + // * fail the first (i.e., eos-V2) client during commit + // * the eos-alpha client should not pickup the pending offsets + // * verify uncommitted and committed result + // 7c. restart the first client in eos-V2 mode and wait until rebalance stabilizes + // 8. write 5 records per input topic partition to get pending transactions (verified via "read_uncommitted" mode) + // - 2 transaction are base on a task producer; one transaction is based on a thread producer + // - we will get 4 pending writes for the crash case as we crash processing the 5th record + // 9. stop/crash the second client and wait until rebalance stabilizes: + // - stop only: + // * verify that the stopped client did commit its pending transaction during shutdown + // * the first client will still have one pending transaction + // - crash case: + // * the pending transactions of the crashed client got aborted + // * the first client will have one pending transactions + // 10. restart the second client with eos-V2 enabled and wait until rebalance stabilizes + // - the rebalance should result in a commit of all tasks + // 11. 
write 5 record per input topic partition and verify that the result was committed + + final List> stateTransitions1 = new LinkedList<>(); + KafkaStreams streams1Alpha = null; + KafkaStreams streams1V2 = null; + KafkaStreams streams1V2Two = null; + + final List> stateTransitions2 = new LinkedList<>(); + KafkaStreams streams2Alpha = null; + KafkaStreams streams2AlphaTwo = null; + KafkaStreams streams2V2 = null; + + try { + // phase 1: start both clients + streams1Alpha = getKafkaStreams(APP_DIR_1, StreamsConfig.EXACTLY_ONCE); + streams1Alpha.setStateListener( + (newState, oldState) -> stateTransitions1.add(KeyValue.pair(oldState, newState)) + ); + + assignmentListener.prepareForRebalance(); + streams1Alpha.cleanUp(); + streams1Alpha.start(); + assignmentListener.waitForNextStableAssignment(MAX_WAIT_TIME_MS); + waitForRunning(stateTransitions1); + + streams2Alpha = getKafkaStreams(APP_DIR_2, StreamsConfig.EXACTLY_ONCE); + streams2Alpha.setStateListener( + (newState, oldState) -> stateTransitions2.add(KeyValue.pair(oldState, newState)) + ); + stateTransitions1.clear(); + + assignmentListener.prepareForRebalance(); + streams2Alpha.cleanUp(); + streams2Alpha.start(); + assignmentListener.waitForNextStableAssignment(MAX_WAIT_TIME_MS); + waitForRunning(stateTransitions1); + waitForRunning(stateTransitions2); + + // in all phases, we write comments that assume that p-0/p-1 are assigned to the first client + // and p-2/p-3 are assigned to the second client (in reality the assignment might be different though) + + // phase 2: (write first batch of data) + // expected end state per output partition (C == COMMIT; A == ABORT; ---> indicate the changes): + // + // p-0: ---> 10 rec + C + // p-1: ---> 10 rec + C + // p-2: ---> 10 rec + C + // p-3: ---> 10 rec + C + final List> committedInputDataBeforeUpgrade = + prepareData(0L, 10L, 0L, 1L, 2L, 3L); + writeInputData(committedInputDataBeforeUpgrade); + + waitForCondition( + () -> commitRequested.get() == 4, + MAX_WAIT_TIME_MS, + "SteamsTasks did not request commit." 
+ ); + + final Map committedState = new HashMap<>(); + final List> expectedUncommittedResult = + computeExpectedResult(committedInputDataBeforeUpgrade, committedState); + verifyCommitted(expectedUncommittedResult); + + // phase 3: (write partial second batch of data) + // expected end state per output partition (C == COMMIT; A == ABORT; ---> indicate the changes): + // + // stop case: + // p-0: 10 rec + C ---> 5 rec (pending) + // p-1: 10 rec + C ---> 5 rec (pending) + // p-2: 10 rec + C ---> 5 rec (pending) + // p-3: 10 rec + C ---> 5 rec (pending) + // crash case: (we just assumes that we inject the error for p-0; in reality it might be a different partition) + // (we don't crash right away and write one record less) + // p-0: 10 rec + C ---> 4 rec (pending) + // p-1: 10 rec + C ---> 5 rec (pending) + // p-2: 10 rec + C ---> 5 rec (pending) + // p-3: 10 rec + C ---> 5 rec (pending) + final Set cleanKeys = mkSet(0L, 1L, 2L, 3L); + final Set keysFirstClientAlpha = keysFromInstance(streams1Alpha); + final long firstFailingKeyForCrashCase = keysFirstClientAlpha.iterator().next(); + cleanKeys.remove(firstFailingKeyForCrashCase); + + final List> uncommittedInputDataBeforeFirstUpgrade = new LinkedList<>(); + final HashMap uncommittedState = new HashMap<>(committedState); + if (!injectError) { + uncommittedInputDataBeforeFirstUpgrade.addAll( + prepareData(10L, 15L, 0L, 1L, 2L, 3L) + ); + writeInputData(uncommittedInputDataBeforeFirstUpgrade); + + expectedUncommittedResult.addAll( + computeExpectedResult(uncommittedInputDataBeforeFirstUpgrade, uncommittedState) + ); + verifyUncommitted(expectedUncommittedResult); + } else { + final List> uncommittedInputDataWithoutFailingKey = new LinkedList<>(); + for (final long key : cleanKeys) { + uncommittedInputDataWithoutFailingKey.addAll(prepareData(10L, 15L, key)); + } + uncommittedInputDataWithoutFailingKey.addAll( + prepareData(10L, 14L, firstFailingKeyForCrashCase) + ); + uncommittedInputDataBeforeFirstUpgrade.addAll(uncommittedInputDataWithoutFailingKey); + writeInputData(uncommittedInputDataWithoutFailingKey); + + expectedUncommittedResult.addAll( + computeExpectedResult(uncommittedInputDataWithoutFailingKey, new HashMap<>(committedState)) + ); + verifyUncommitted(expectedUncommittedResult); + } + + // phase 4: (stop first client) + // expected end state per output partition (C == COMMIT; A == ABORT; ---> indicate the changes): + // + // stop case: (client 1 will commit its two tasks on close()) + // p-0: 10 rec + C + 5 rec ---> C + // p-1: 10 rec + C + 5 rec ---> C + // p-2: 10 rec + C + 5 rec (pending) + // p-3: 10 rec + C + 5 rec (pending) + // crash case: (we write the last record that will trigger the crash; both TX from client 1 will be aborted + // during fail over by client 2 and retried) + // p-0: 10 rec + C + 4 rec + A + 5 rec (pending) + // p-1: 10 rec + C + 5 rec + A + 5 rec (pending) + // p-2: 10 rec + C + 5 rec (pending) + // p-3: 10 rec + C + 5 rec (pending) + stateTransitions2.clear(); + assignmentListener.prepareForRebalance(); + + if (!injectError) { + stateTransitions1.clear(); + streams1Alpha.close(); + waitForStateTransition(stateTransitions1, CLOSE); + } else { + errorInjectedClient1.set(true); + + final List> dataPotentiallyFirstFailingKey = + prepareData(14L, 15L, firstFailingKeyForCrashCase); + uncommittedInputDataBeforeFirstUpgrade.addAll(dataPotentiallyFirstFailingKey); + writeInputData(dataPotentiallyFirstFailingKey); + } + assignmentListener.waitForNextStableAssignment(MAX_WAIT_TIME_MS); + 
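+            // Only the second (eos-alpha) client is left in the group at this point, so it now owns all four
+            // input partitions. In the clean-shutdown case the first client committed its tasks on close(); in
+            // the crash case its pending transactions were aborted and are re-processed by the surviving client.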
waitForRunning(stateTransitions2); + + if (!injectError) { + final List> committedInputDataDuringFirstUpgrade = + uncommittedInputDataBeforeFirstUpgrade + .stream() + .filter(pair -> keysFirstClientAlpha.contains(pair.key)) + .collect(Collectors.toList()); + + final List> expectedCommittedResult = + computeExpectedResult(committedInputDataDuringFirstUpgrade, committedState); + verifyCommitted(expectedCommittedResult); + } else { + // retrying TX + expectedUncommittedResult.addAll(computeExpectedResult( + uncommittedInputDataBeforeFirstUpgrade + .stream() + .filter(pair -> keysFirstClientAlpha.contains(pair.key)) + .collect(Collectors.toList()), + new HashMap<>(committedState) + )); + verifyUncommitted(expectedUncommittedResult); + waitForStateTransitionContains(stateTransitions1, CRASH); + + errorInjectedClient1.set(false); + stateTransitions1.clear(); + streams1Alpha.close(); + assertFalse(UNEXPECTED_EXCEPTION_MSG, hasUnexpectedError); + } + + // phase 5: (restart first client) + // expected end state per output partition (C == COMMIT; A == ABORT; ---> indicate the changes): + // + // stop case: (client 2 (alpha) will commit the two revoked task that migrate back to client 1) + // (note: we may or may not get newly committed data, depending if the already committed tasks + // migrate back to client 1, or different tasks) + // (below we show the case for which we don't get newly committed data) + // p-0: 10 rec + C + 5 rec ---> C + // p-1: 10 rec + C + 5 rec ---> C + // p-2: 10 rec + C + 5 rec (pending) + // p-3: 10 rec + C + 5 rec (pending) + // crash case: (client 2 (alpha) will commit all tasks even only two tasks are revoked and migrate back to client 1) + // (note: because nothing was committed originally, we always get newly committed data) + // p-0: 10 rec + C + 4 rec + A + 5 rec ---> C + // p-1: 10 rec + C + 5 rec + A + 5 rec ---> C + // p-2: 10 rec + C + 5 rec ---> C + // p-3: 10 rec + C + 5 rec ---> C + commitRequested.set(0); + stateTransitions1.clear(); + stateTransitions2.clear(); + streams1V2 = getKafkaStreams(APP_DIR_1, StreamsConfig.EXACTLY_ONCE_V2); + streams1V2.setStateListener((newState, oldState) -> stateTransitions1.add(KeyValue.pair(oldState, newState))); + assignmentListener.prepareForRebalance(); + streams1V2.start(); + assignmentListener.waitForNextStableAssignment(MAX_WAIT_TIME_MS); + waitForRunning(stateTransitions1); + waitForRunning(stateTransitions2); + + final Set newlyCommittedKeys; + if (!injectError) { + newlyCommittedKeys = keysFromInstance(streams1V2); + newlyCommittedKeys.removeAll(keysFirstClientAlpha); + } else { + newlyCommittedKeys = mkSet(0L, 1L, 2L, 3L); + } + + final List> expectedCommittedResultAfterRestartFirstClient = computeExpectedResult( + uncommittedInputDataBeforeFirstUpgrade + .stream() + .filter(pair -> newlyCommittedKeys.contains(pair.key)) + .collect(Collectors.toList()), + committedState + ); + verifyCommitted(expectedCommittedResultAfterRestartFirstClient); + + // phase 6: (complete second batch of data; crash: let second client fail on commit) + // expected end state per output partition (C == COMMIT; A == ABORT; ---> indicate the changes): + // + // stop case: (both client commit regularly) + // (depending on the task movement in phase 5, we may or may not get newly committed data; + // we show the case for which p-2 and p-3 are newly committed below) + // p-0: 10 rec + C + 5 rec + C ---> 5 rec + C + // p-1: 10 rec + C + 5 rec + C ---> 5 rec + C + // p-2: 10 rec + C + 5 rec ---> 5 rec + C + // p-3: 10 rec + C + 5 rec ---> 5 rec + 
C + // crash case: (second/alpha client fails and both TX are aborted) + // (first/V2 client reprocessed the 10 records and commits TX) + // p-0: 10 rec + C + 4 rec + A + 5 rec + C ---> 5 rec + C + // p-1: 10 rec + C + 5 rec + A + 5 rec + C ---> 5 rec + C + // p-2: 10 rec + C + 5 rec + C ---> 5 rec + A + 5 rec + C + // p-3: 10 rec + C + 5 rec + C ---> 5 rec + A + 5 rec + C + commitCounterClient1.set(0); + + if (!injectError) { + final List> finishSecondBatch = prepareData(15L, 20L, 0L, 1L, 2L, 3L); + writeInputData(finishSecondBatch); + + final List> committedInputDataDuringUpgrade = uncommittedInputDataBeforeFirstUpgrade + .stream() + .filter(pair -> !keysFirstClientAlpha.contains(pair.key)) + .filter(pair -> !newlyCommittedKeys.contains(pair.key)) + .collect(Collectors.toList()); + committedInputDataDuringUpgrade.addAll( + finishSecondBatch + ); + + expectedUncommittedResult.addAll( + computeExpectedResult(finishSecondBatch, uncommittedState) + ); + final List> expectedCommittedResult = + computeExpectedResult(committedInputDataDuringUpgrade, committedState); + verifyCommitted(expectedCommittedResult); + } else { + final Set keysFirstClientV2 = keysFromInstance(streams1V2); + final Set keysSecondClientAlpha = keysFromInstance(streams2Alpha); + + final List> committedInputDataAfterFirstUpgrade = + prepareData(15L, 20L, keysFirstClientV2.toArray(new Long[0])); + writeInputData(committedInputDataAfterFirstUpgrade); + + final List> expectedCommittedResultBeforeFailure = + computeExpectedResult(committedInputDataAfterFirstUpgrade, committedState); + verifyCommitted(expectedCommittedResultBeforeFailure); + expectedUncommittedResult.addAll(expectedCommittedResultBeforeFailure); + + commitCounterClient2.set(0); + + final Iterator it = keysSecondClientAlpha.iterator(); + final Long otherKey = it.next(); + final Long failingKey = it.next(); + + final List> uncommittedInputDataAfterFirstUpgrade = + prepareData(15L, 19L, keysSecondClientAlpha.toArray(new Long[0])); + uncommittedInputDataAfterFirstUpgrade.addAll(prepareData(19L, 20L, otherKey)); + writeInputData(uncommittedInputDataAfterFirstUpgrade); + + uncommittedState.putAll(committedState); + expectedUncommittedResult.addAll( + computeExpectedResult(uncommittedInputDataAfterFirstUpgrade, uncommittedState) + ); + verifyUncommitted(expectedUncommittedResult); + + stateTransitions1.clear(); + stateTransitions2.clear(); + assignmentListener.prepareForRebalance(); + + commitCounterClient1.set(0); + commitErrorInjectedClient2.set(true); + + final List> dataFailingKey = prepareData(19L, 20L, failingKey); + uncommittedInputDataAfterFirstUpgrade.addAll(dataFailingKey); + writeInputData(dataFailingKey); + + expectedUncommittedResult.addAll( + computeExpectedResult(dataFailingKey, uncommittedState) + ); + verifyUncommitted(expectedUncommittedResult); + + assignmentListener.waitForNextStableAssignment(MAX_WAIT_TIME_MS); + + waitForStateTransitionContains(stateTransitions2, CRASH); + + commitErrorInjectedClient2.set(false); + stateTransitions2.clear(); + streams2Alpha.close(); + assertFalse(UNEXPECTED_EXCEPTION_MSG, hasUnexpectedError); + + final List> expectedCommittedResultAfterFailure = + computeExpectedResult(uncommittedInputDataAfterFirstUpgrade, committedState); + verifyCommitted(expectedCommittedResultAfterFailure); + expectedUncommittedResult.addAll(expectedCommittedResultAfterFailure); + } + + // 7. only for crash case: + // 7a. restart the failed second client in eos-alpha mode and wait until rebalance stabilizes + // 7b. 
write third batch of input data + // * fail the first (i.e., eos-V2) client during commit + // * the eos-alpha client should not pickup the pending offsets + // * verify uncommitted and committed result + // 7c. restart the first client in eos-V2 mode and wait until rebalance stabilizes + // + // crash case: + // p-0: 10 rec + C + 4 rec + A + 5 rec + C + 5 rec + C ---> 10 rec + A + 10 rec + C + // p-1: 10 rec + C + 5 rec + A + 5 rec + C + 5 rec + C ---> 10 rec + A + 10 rec + C + // p-2: 10 rec + C + 5 rec + C + 5 rec + A + 5 rec + C ---> 10 rec + C + // p-3: 10 rec + C + 5 rec + C + 5 rec + A + 5 rec + C ---> 10 rec + C + if (!injectError) { + streams2AlphaTwo = streams2Alpha; + } else { + // 7a restart the second client in eos-alpha mode and wait until rebalance stabilizes + commitCounterClient1.set(0); + commitCounterClient2.set(-1); + stateTransitions1.clear(); + stateTransitions2.clear(); + streams2AlphaTwo = getKafkaStreams(APP_DIR_2, StreamsConfig.EXACTLY_ONCE); + streams2AlphaTwo.setStateListener( + (newState, oldState) -> stateTransitions2.add(KeyValue.pair(oldState, newState)) + ); + assignmentListener.prepareForRebalance(); + streams2AlphaTwo.start(); + assignmentListener.waitForNextStableAssignment(MAX_WAIT_TIME_MS); + waitForRunning(stateTransitions1); + waitForRunning(stateTransitions2); + + // 7b. write third batch of input data + final Set keysFirstClientV2 = keysFromInstance(streams1V2); + final Set keysSecondClientAlphaTwo = keysFromInstance(streams2AlphaTwo); + + final List> committedInputDataBetweenUpgrades = + prepareData(20L, 30L, keysSecondClientAlphaTwo.toArray(new Long[0])); + writeInputData(committedInputDataBetweenUpgrades); + + final List> expectedCommittedResultBeforeFailure = + computeExpectedResult(committedInputDataBetweenUpgrades, committedState); + verifyCommitted(expectedCommittedResultBeforeFailure); + expectedUncommittedResult.addAll(expectedCommittedResultBeforeFailure); + + commitCounterClient2.set(0); + + final Iterator it = keysFirstClientV2.iterator(); + final Long otherKey = it.next(); + final Long failingKey = it.next(); + + final List> uncommittedInputDataBetweenUpgrade = + prepareData(20L, 29L, keysFirstClientV2.toArray(new Long[0])); + uncommittedInputDataBetweenUpgrade.addAll(prepareData(29L, 30L, otherKey)); + writeInputData(uncommittedInputDataBetweenUpgrade); + + uncommittedState.putAll(committedState); + expectedUncommittedResult.addAll( + computeExpectedResult(uncommittedInputDataBetweenUpgrade, uncommittedState) + ); + verifyUncommitted(expectedUncommittedResult); + + stateTransitions1.clear(); + stateTransitions2.clear(); + assignmentListener.prepareForRebalance(); + commitCounterClient2.set(0); + commitErrorInjectedClient1.set(true); + + final List> dataFailingKey = prepareData(29L, 30L, failingKey); + uncommittedInputDataBetweenUpgrade.addAll(dataFailingKey); + writeInputData(dataFailingKey); + + expectedUncommittedResult.addAll( + computeExpectedResult(dataFailingKey, uncommittedState) + ); + verifyUncommitted(expectedUncommittedResult); + + assignmentListener.waitForNextStableAssignment(MAX_WAIT_TIME_MS); + + waitForStateTransitionContains(stateTransitions1, CRASH); + + commitErrorInjectedClient1.set(false); + stateTransitions1.clear(); + streams1V2.close(); + assertFalse(UNEXPECTED_EXCEPTION_MSG, hasUnexpectedError); + + final List> expectedCommittedResultAfterFailure = + computeExpectedResult(uncommittedInputDataBetweenUpgrade, committedState); + verifyCommitted(expectedCommittedResultAfterFailure); + 
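+                // Records that were committed are also visible to a read_uncommitted consumer, so the newly
+                // committed results are appended to the uncommitted expectation as well.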
expectedUncommittedResult.addAll(expectedCommittedResultAfterFailure); + + // 7c. restart the first client in eos-V2 mode and wait until rebalance stabilizes + stateTransitions1.clear(); + stateTransitions2.clear(); + streams1V2Two = getKafkaStreams(APP_DIR_1, StreamsConfig.EXACTLY_ONCE_V2); + streams1V2Two.setStateListener((newState, oldState) -> stateTransitions1.add(KeyValue.pair(oldState, newState))); + assignmentListener.prepareForRebalance(); + streams1V2Two.start(); + assignmentListener.waitForNextStableAssignment(MAX_WAIT_TIME_MS); + waitForRunning(stateTransitions1); + waitForRunning(stateTransitions2); + } + + // phase 8: (write partial last batch of data) + // expected end state per output partition (C == COMMIT; A == ABORT; ---> indicate the changes): + // + // stop case: + // p-0: 10 rec + C + 5 rec + C + 5 rec + C ---> 5 rec (pending) + // p-1: 10 rec + C + 5 rec + C + 5 rec + C ---> 5 rec (pending) + // p-2: 10 rec + C + 5 rec + C + 5 rec + C ---> 5 rec (pending) + // p-3: 10 rec + C + 5 rec + C + 5 rec + C ---> 5 rec (pending) + // crash case: (we just assumes that we inject the error for p-2; in reality it might be a different partition) + // (we don't crash right away and write one record less) + // p-0: 10 rec + C + 4 rec + A + 5 rec + C + 5 rec + C + 10 rec + A + 10 rec + C ---> 5 rec (pending) + // p-1: 10 rec + C + 5 rec + A + 5 rec + C + 5 rec + C + 10 rec + A + 10 rec + C ---> 5 rec (pending) + // p-2: 10 rec + C + 5 rec + C + 5 rec + A + 5 rec + C + 10 rec + C ---> 4 rec (pending) + // p-3: 10 rec + C + 5 rec + C + 5 rec + A + 5 rec + C + 10 rec + C ---> 5 rec (pending) + cleanKeys.addAll(mkSet(0L, 1L, 2L, 3L)); + final Set keysSecondClientAlphaTwo = keysFromInstance(streams2AlphaTwo); + final long secondFailingKeyForCrashCase = keysSecondClientAlphaTwo.iterator().next(); + cleanKeys.remove(secondFailingKeyForCrashCase); + + final List> uncommittedInputDataBeforeSecondUpgrade = new LinkedList<>(); + if (!injectError) { + uncommittedInputDataBeforeSecondUpgrade.addAll( + prepareData(30L, 35L, 0L, 1L, 2L, 3L) + ); + writeInputData(uncommittedInputDataBeforeSecondUpgrade); + + expectedUncommittedResult.addAll( + computeExpectedResult(uncommittedInputDataBeforeSecondUpgrade, new HashMap<>(committedState)) + ); + verifyUncommitted(expectedUncommittedResult); + } else { + final List> uncommittedInputDataWithoutFailingKey = new LinkedList<>(); + for (final long key : cleanKeys) { + uncommittedInputDataWithoutFailingKey.addAll(prepareData(30L, 35L, key)); + } + uncommittedInputDataWithoutFailingKey.addAll( + prepareData(30L, 34L, secondFailingKeyForCrashCase) + ); + uncommittedInputDataBeforeSecondUpgrade.addAll(uncommittedInputDataWithoutFailingKey); + writeInputData(uncommittedInputDataWithoutFailingKey); + + expectedUncommittedResult.addAll( + computeExpectedResult(uncommittedInputDataWithoutFailingKey, new HashMap<>(committedState)) + ); + verifyUncommitted(expectedUncommittedResult); + } + + // phase 9: (stop/crash second client) + // expected end state per output partition (C == COMMIT; A == ABORT; ---> indicate the changes): + // + // stop case: (client 2 (alpha) will commit its two tasks on close()) + // p-0: 10 rec + C + 5 rec + C + 5 rec + C + 5 rec (pending) + // p-1: 10 rec + C + 5 rec + C + 5 rec + C + 5 rec (pending) + // p-2: 10 rec + C + 5 rec + C + 5 rec + C + 5 rec ---> C + // p-3: 10 rec + C + 5 rec + C + 5 rec + C + 5 rec ---> C + // crash case: (we write the last record that will trigger the crash; both TX from client 2 will be aborted + // during fail 
over by client 1 and retried) + // p-0: 10 rec + C + 4 rec + A + 5 rec + C + 5 rec + C + 10 rec + A + 10 rec + C + 5 rec (pending) + // p-1: 10 rec + C + 5 rec + A + 5 rec + C + 5 rec + C + 10 rec + A + 10 rec + C + 5 rec (pending) + // p-2: 10 rec + C + 5 rec + C + 5 rec + A + 5 rec + C + 10 rec + C + 4 rec ---> A + 5 rec (pending) + // p-3: 10 rec + C + 5 rec + C + 5 rec + A + 5 rec + C + 10 rec + C + 5 rec ---> A + 5 rec (pending) + stateTransitions1.clear(); + assignmentListener.prepareForRebalance(); + if (!injectError) { + stateTransitions2.clear(); + streams2AlphaTwo.close(); + waitForStateTransition(stateTransitions2, CLOSE); + } else { + errorInjectedClient2.set(true); + + final List> dataPotentiallySecondFailingKey = + prepareData(34L, 35L, secondFailingKeyForCrashCase); + uncommittedInputDataBeforeSecondUpgrade.addAll(dataPotentiallySecondFailingKey); + writeInputData(dataPotentiallySecondFailingKey); + } + assignmentListener.waitForNextStableAssignment(MAX_WAIT_TIME_MS); + waitForRunning(stateTransitions1); + + if (!injectError) { + final List> committedInputDataDuringSecondUpgrade = + uncommittedInputDataBeforeSecondUpgrade + .stream() + .filter(pair -> keysSecondClientAlphaTwo.contains(pair.key)) + .collect(Collectors.toList()); + + final List> expectedCommittedResult = + computeExpectedResult(committedInputDataDuringSecondUpgrade, committedState); + verifyCommitted(expectedCommittedResult); + } else { + // retrying TX + expectedUncommittedResult.addAll(computeExpectedResult( + uncommittedInputDataBeforeSecondUpgrade + .stream() + .filter(pair -> keysSecondClientAlphaTwo.contains(pair.key)) + .collect(Collectors.toList()), + new HashMap<>(committedState) + )); + verifyUncommitted(expectedUncommittedResult); + waitForStateTransitionContains(stateTransitions2, CRASH); + + errorInjectedClient2.set(false); + stateTransitions2.clear(); + streams2AlphaTwo.close(); + assertFalse(UNEXPECTED_EXCEPTION_MSG, hasUnexpectedError); + } + + // phase 10: (restart second client) + // expected end state per output partition (C == COMMIT; A == ABORT; ---> indicate the changes): + // + // the state below indicate the case for which the "original" tasks of client2 are migrated back to client2 + // if a task "switch" happens, we might get additional commits (omitted in the comment for brevity) + // + // stop case: (client 1 (V2) will commit all four tasks if at least one revoked and migrate task needs committing back to client 2) + // p-0: 10 rec + C + 5 rec + C + 5 rec + C + 5 rec ---> C + // p-1: 10 rec + C + 5 rec + C + 5 rec + C + 5 rec ---> C + // p-2: 10 rec + C + 5 rec + C + 5 rec + C + 5 rec + C + // p-3: 10 rec + C + 5 rec + C + 5 rec + C + 5 rec + C + // crash case: (client 1 (V2) will commit all four tasks even only two are migrate back to client 2) + // p-0: 10 rec + C + 4 rec + A + 5 rec + C + 5 rec + C + 10 rec + A + 10 rec + C + 5 rec ---> C + // p-1: 10 rec + C + 5 rec + A + 5 rec + C + 5 rec + C + 10 rec + A + 10 rec + C + 5 rec ---> C + // p-2: 10 rec + C + 5 rec + C + 5 rec + A + 5 rec + C + 10 rec + C + 4 rec + A + 5 rec ---> C + // p-3: 10 rec + C + 5 rec + C + 5 rec + A + 5 rec + C + 10 rec + C + 5 rec + A + 5 rec ---> C + commitRequested.set(0); + stateTransitions1.clear(); + stateTransitions2.clear(); + streams2V2 = getKafkaStreams(APP_DIR_1, StreamsConfig.EXACTLY_ONCE_V2); + streams2V2.setStateListener( + (newState, oldState) -> stateTransitions2.add(KeyValue.pair(oldState, newState)) + ); + assignmentListener.prepareForRebalance(); + streams2V2.start(); + 
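+            // Wait for the rebalance triggered by the restarted client to stabilize before verifying that all
+            // previously pending transactions have been committed as part of the rebalance.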
assignmentListener.waitForNextStableAssignment(MAX_WAIT_TIME_MS); + waitForRunning(stateTransitions1); + waitForRunning(stateTransitions2); + + newlyCommittedKeys.clear(); + if (!injectError) { + newlyCommittedKeys.addAll(keysFromInstance(streams2V2)); + newlyCommittedKeys.removeAll(keysSecondClientAlphaTwo); + } else { + newlyCommittedKeys.addAll(mkSet(0L, 1L, 2L, 3L)); + } + + final List> expectedCommittedResultAfterRestartSecondClient = computeExpectedResult( + uncommittedInputDataBeforeSecondUpgrade + .stream() + .filter(pair -> newlyCommittedKeys.contains(pair.key)) + .collect(Collectors.toList()), + committedState + ); + verifyCommitted(expectedCommittedResultAfterRestartSecondClient); + + // phase 11: (complete fourth batch of data) + // expected end state per output partition (C == COMMIT; A == ABORT; ---> indicate the changes): + // + // stop case: + // p-0: 10 rec + C + 5 rec + C + 5 rec + C + 5 rec + C ---> 5 rec + C + // p-1: 10 rec + C + 5 rec + C + 5 rec + C + 5 rec + C ---> 5 rec + C + // p-2: 10 rec + C + 5 rec + C + 5 rec + C + 5 rec + C ---> 5 rec + C + // p-3: 10 rec + C + 5 rec + C + 5 rec + C + 5 rec + C ---> 5 rec + C + // crash case: (we just assumes that we inject the error for p-2; in reality it might be a different partition) + // p-0: 10 rec + C + 4 rec + A + 5 rec + C + 5 rec + C + 10 rec + A + 10 rec + C + 5 rec + C ---> 5 rec + C + // p-1: 10 rec + C + 5 rec + A + 5 rec + C + 5 rec + C + 10 rec + A + 10 rec + C + 5 rec + C ---> 5 rec + C + // p-2: 10 rec + C + 5 rec + C + 5 rec + A + 5 rec + C + 10 rec + C + 4 rec + A + 5 rec + C ---> 5 rec + C + // p-3: 10 rec + C + 5 rec + C + 5 rec + A + 5 rec + C + 10 rec + C + 5 rec + A + 5 rec + C ---> 5 rec + C + commitCounterClient1.set(-1); + commitCounterClient2.set(-1); + + final List> finishLastBatch = + prepareData(35L, 40L, 0L, 1L, 2L, 3L); + writeInputData(finishLastBatch); + + final Set uncommittedKeys = mkSet(0L, 1L, 2L, 3L); + uncommittedKeys.removeAll(keysSecondClientAlphaTwo); + uncommittedKeys.removeAll(newlyCommittedKeys); + final List> committedInputDataDuringUpgrade = uncommittedInputDataBeforeSecondUpgrade + .stream() + .filter(pair -> uncommittedKeys.contains(pair.key)) + .collect(Collectors.toList()); + committedInputDataDuringUpgrade.addAll( + finishLastBatch + ); + + final List> expectedCommittedResult = + computeExpectedResult(committedInputDataDuringUpgrade, committedState); + verifyCommitted(expectedCommittedResult); + } finally { + if (streams1Alpha != null) { + streams1Alpha.close(); + } + if (streams1V2 != null) { + streams1V2.close(); + } + if (streams1V2Two != null) { + streams1V2Two.close(); + } + if (streams2Alpha != null) { + streams2Alpha.close(); + } + if (streams2AlphaTwo != null) { + streams2AlphaTwo.close(); + } + if (streams2V2 != null) { + streams2V2.close(); + } + } + } + + private KafkaStreams getKafkaStreams(final String appDir, + final String processingGuarantee) { + final StreamsBuilder builder = new StreamsBuilder(); + + final String[] storeNames = new String[] {storeName}; + final StoreBuilder> storeBuilder = Stores + .keyValueStoreBuilder(Stores.persistentKeyValueStore(storeName), Serdes.Long(), Serdes.Long()) + .withCachingEnabled(); + + builder.addStateStore(storeBuilder); + + final KStream input = builder.stream(MULTI_PARTITION_INPUT_TOPIC); + input.transform(new TransformerSupplier>() { + @Override + public Transformer> get() { + return new Transformer>() { + ProcessorContext context; + KeyValueStore state = null; + AtomicBoolean crash; + AtomicInteger sharedCommit; + 
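+ // crash and sharedCommit are bound in init() to this client's static error-injection flag and commit counter, letting the test driver trigger a crash or coordinate manual commits from outside the topology.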
+ @Override + public void init(final ProcessorContext context) { + this.context = context; + state = context.getStateStore(storeName); + final String clientId = context.appConfigs().get(StreamsConfig.CLIENT_ID_CONFIG).toString(); + if (APP_DIR_1.equals(clientId)) { + crash = errorInjectedClient1; + sharedCommit = commitCounterClient1; + } else { + crash = errorInjectedClient2; + sharedCommit = commitCounterClient2; + } + } + + @Override + public KeyValue transform(final Long key, final Long value) { + if ((value + 1) % 10 == 0) { + if (sharedCommit.get() < 0 || + sharedCommit.incrementAndGet() == 2) { + + context.commit(); + } + commitRequested.incrementAndGet(); + } + + Long sum = state.get(key); + if (sum == null) { + sum = value; + } else { + sum += value; + } + state.put(key, sum); + state.flush(); + + if (value % 10 == 4 && // potentially crash when processing 5th, 15th, or 25th record (etc.) + crash != null && crash.compareAndSet(true, false)) { + // only crash a single task + throw new RuntimeException("Injected test exception."); + } + + return new KeyValue<>(key, state.get(key)); + } + + @Override + public void close() {} + }; + } }, storeNames) + .to(MULTI_PARTITION_OUTPUT_TOPIC); + + final Properties properties = new Properties(); + properties.put(StreamsConfig.CLIENT_ID_CONFIG, appDir); + properties.put(StreamsConfig.PROCESSING_GUARANTEE_CONFIG, processingGuarantee); + final long commitInterval = Duration.ofMinutes(1L).toMillis(); + properties.put(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG, commitInterval); + properties.put(StreamsConfig.consumerPrefix(ConsumerConfig.METADATA_MAX_AGE_CONFIG), Duration.ofSeconds(1L).toMillis()); + properties.put(StreamsConfig.consumerPrefix(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG), "earliest"); + properties.put(StreamsConfig.consumerPrefix(ConsumerConfig.REQUEST_TIMEOUT_MS_CONFIG), (int) Duration.ofSeconds(5L).toMillis()); + properties.put(StreamsConfig.consumerPrefix(ConsumerConfig.SESSION_TIMEOUT_MS_CONFIG), (int) Duration.ofSeconds(5L).minusMillis(1L).toMillis()); + properties.put(StreamsConfig.consumerPrefix(ConsumerConfig.MAX_POLL_INTERVAL_MS_CONFIG), MAX_POLL_INTERVAL_MS); + properties.put(StreamsConfig.producerPrefix(ProducerConfig.TRANSACTION_TIMEOUT_CONFIG), (int) commitInterval); + properties.put(StreamsConfig.producerPrefix(ProducerConfig.PARTITIONER_CLASS_CONFIG), KeyPartitioner.class); + properties.put(StreamsConfig.CACHE_MAX_BYTES_BUFFERING_CONFIG, 0); + properties.put(StreamsConfig.STATE_DIR_CONFIG, TestUtils.tempDirectory().getPath() + File.separator + appDir); + properties.put(InternalConfig.ASSIGNMENT_LISTENER, assignmentListener); + + final Properties config = StreamsTestUtils.getStreamsConfig( + applicationId, + CLUSTER.bootstrapServers(), + Serdes.LongSerde.class.getName(), + Serdes.LongSerde.class.getName(), + properties + ); + + final KafkaStreams streams = new KafkaStreams(builder.build(), config, new TestKafkaClientSupplier()); + streams.setUncaughtExceptionHandler(e -> { + if (!injectError) { + // we don't expect any exception thrown in stop case + e.printStackTrace(System.err); + hasUnexpectedError = true; + } else { + int exceptionCount = (int) exceptionCounts.get(appDir); + // should only have our injected exception or commit exception, and 2 exceptions for each stream + if (++exceptionCount > 2 || !(e instanceof RuntimeException) || + !(e.getMessage().contains("test exception"))) { + // The exception won't cause the test fail since we actually "expected" exception thrown and failed the stream. 
+ // So, log to stderr for debugging when the exception is not what we expected, and fail in the main thread + e.printStackTrace(System.err); + hasUnexpectedError = true; + } + exceptionCounts.put(appDir, exceptionCount); + } + return StreamsUncaughtExceptionHandler.StreamThreadExceptionResponse.SHUTDOWN_CLIENT; + }); + + return streams; + } + + private void waitForRunning(final List> observed) throws Exception { + waitForCondition( + () -> !observed.isEmpty() && observed.get(observed.size() - 1).value.equals(State.RUNNING), + MAX_WAIT_TIME_MS, + () -> "Client did not startup on time. Observers transitions: " + observed + ); + } + + private void waitForStateTransition(final List> observed, + final List> expected) + throws Exception { + + waitForCondition( + () -> observed.equals(expected), + MAX_WAIT_TIME_MS, + () -> "Client did not have the expected state transition on time. Observers transitions: " + observed + + "Expected transitions: " + expected + ); + } + + private void waitForStateTransitionContains(final List> observed, + final List> expected) + throws Exception { + + waitForCondition( + () -> observed.containsAll(expected), + MAX_WAIT_TIME_MS, + () -> "Client did not have the expected state transition on time. Observers transitions: " + observed + + "Expected transitions: " + expected + ); + } + + private List> prepareData(final long fromInclusive, + final long toExclusive, + final Long... keys) { + final List> data = new ArrayList<>(); + + for (final Long k : keys) { + for (long v = fromInclusive; v < toExclusive; ++v) { + data.add(new KeyValue<>(k, v)); + } + } + + return data; + } + + private void writeInputData(final List> records) { + final Properties config = TestUtils.producerConfig( + CLUSTER.bootstrapServers(), + LongSerializer.class, + LongSerializer.class + ); + config.setProperty(ProducerConfig.PARTITIONER_CLASS_CONFIG, KeyPartitioner.class.getName()); + IntegrationTestUtils.produceKeyValuesSynchronously( + MULTI_PARTITION_INPUT_TOPIC, + records, + config, + CLUSTER.time + ); + } + + private void verifyCommitted(final List> expectedResult) throws Exception { + final List> committedOutput = readResult(expectedResult.size(), true); + checkResultPerKey(committedOutput, expectedResult); + } + + private void verifyUncommitted(final List> expectedResult) throws Exception { + final List> uncommittedOutput = readResult(expectedResult.size(), false); + checkResultPerKey(uncommittedOutput, expectedResult); + } + + private List> readResult(final int numberOfRecords, + final boolean readCommitted) throws Exception { + if (readCommitted) { + return IntegrationTestUtils.waitUntilMinKeyValueRecordsReceived( + TestUtils.consumerConfig( + CLUSTER.bootstrapServers(), + CONSUMER_GROUP_ID, + LongDeserializer.class, + LongDeserializer.class, + Utils.mkProperties(Collections.singletonMap( + ConsumerConfig.ISOLATION_LEVEL_CONFIG, + IsolationLevel.READ_COMMITTED.name().toLowerCase(Locale.ROOT)) + ) + ), + MULTI_PARTITION_OUTPUT_TOPIC, + numberOfRecords, + MAX_WAIT_TIME_MS + ); + } + + // read uncommitted + return IntegrationTestUtils.waitUntilMinKeyValueRecordsReceived( + TestUtils.consumerConfig(CLUSTER.bootstrapServers(), LongDeserializer.class, LongDeserializer.class), + MULTI_PARTITION_OUTPUT_TOPIC, + numberOfRecords + ); + } + + private void checkResultPerKey(final List> result, + final List> expectedResult) { + final Set allKeys = new HashSet<>(); + addAllKeys(allKeys, result); + addAllKeys(allKeys, expectedResult); + + for (final Long key : allKeys) { + try { + 
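+ // Compare the full record sequence per key against the expectation; ordering is only checked within a key, not across keys.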
assertThat(getAllRecordPerKey(key, result), equalTo(getAllRecordPerKey(key, expectedResult))); + } catch (final AssertionError error) { + throw new AssertionError( + "expected result: " + expectedResult.stream().map(KeyValue::toString).collect(Collectors.joining(", ")) + + "\nreceived records: " + result.stream().map(KeyValue::toString).collect(Collectors.joining(", ")), + error + ); + } + } + } + + private void addAllKeys(final Set allKeys, final List> records) { + for (final KeyValue record : records) { + allKeys.add(record.key); + } + } + + private List> getAllRecordPerKey(final Long key, final List> records) { + final List> recordsPerKey = new ArrayList<>(records.size()); + + for (final KeyValue record : records) { + if (record.key.equals(key)) { + recordsPerKey.add(record); + } + } + + return recordsPerKey; + } + + private List> computeExpectedResult(final List> input, + final Map currentState) { + final List> expectedResult = new ArrayList<>(input.size()); + + for (final KeyValue record : input) { + final long sum = currentState.getOrDefault(record.key, 0L); + currentState.put(record.key, sum + record.value); + expectedResult.add(new KeyValue<>(record.key, sum + record.value)); + } + + return expectedResult; + } + + private Set keysFromInstance(final KafkaStreams streams) throws Exception { + final Set keys = new HashSet<>(); + waitForCondition( + () -> { + final ReadOnlyKeyValueStore store = streams.store( + StoreQueryParameters.fromNameAndType(storeName, QueryableStoreTypes.keyValueStore()) + ); + + keys.clear(); + try (final KeyValueIterator it = store.all()) { + while (it.hasNext()) { + final KeyValue row = it.next(); + keys.add(row.key); + } + } + + return true; + }, + MAX_WAIT_TIME_MS, + "Could not get keys from store: " + storeName + ); + + return keys; + } + + // must be public to allow KafkaProducer to instantiate it + public static class KeyPartitioner implements Partitioner { + private final static LongDeserializer LONG_DESERIALIZER = new LongDeserializer(); + + @Override + public int partition(final String topic, + final Object key, + final byte[] keyBytes, + final Object value, + final byte[] valueBytes, + final Cluster cluster) { + return LONG_DESERIALIZER.deserialize(topic, keyBytes).intValue() % NUM_TOPIC_PARTITIONS; + } + + @Override + public void close() {} + + @Override + public void configure(final Map configs) {} + } + + private class TestKafkaClientSupplier extends DefaultKafkaClientSupplier { + @Override + public Producer getProducer(final Map config) { + return new ErrorInjector(config); + } + } + + private class ErrorInjector extends KafkaProducer { + private final AtomicBoolean crash; + + public ErrorInjector(final Map configs) { + super(configs, new ByteArraySerializer(), new ByteArraySerializer()); + final String clientId = configs.get(ProducerConfig.CLIENT_ID_CONFIG).toString(); + if (clientId.contains(APP_DIR_1)) { + crash = commitErrorInjectedClient1; + } else { + crash = commitErrorInjectedClient2; + } + } + + @Override + public void commitTransaction() { + super.flush(); // we flush to ensure that the offsets are written + if (!crash.compareAndSet(true, false)) { + super.commitTransaction(); + } else { + throw new RuntimeException("Injected producer commit test exception."); + } + } + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/integration/FineGrainedAutoResetIntegrationTest.java b/streams/src/test/java/org/apache/kafka/streams/integration/FineGrainedAutoResetIntegrationTest.java new file mode 100644 index 0000000..baaf06c --- 
/dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/integration/FineGrainedAutoResetIntegrationTest.java @@ -0,0 +1,326 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.integration; + + +import kafka.utils.MockTime; +import org.apache.kafka.clients.consumer.ConsumerConfig; +import org.apache.kafka.clients.consumer.KafkaConsumer; +import org.apache.kafka.clients.consumer.OffsetAndMetadata; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.serialization.StringDeserializer; +import org.apache.kafka.common.serialization.StringSerializer; +import org.apache.kafka.streams.KafkaStreams; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.Topology; +import org.apache.kafka.streams.errors.StreamsUncaughtExceptionHandler; +import org.apache.kafka.streams.errors.TopologyException; +import org.apache.kafka.streams.integration.utils.EmbeddedKafkaCluster; +import org.apache.kafka.streams.integration.utils.IntegrationTestUtils; +import org.apache.kafka.streams.kstream.Consumed; +import org.apache.kafka.streams.kstream.KStream; +import org.apache.kafka.streams.kstream.Produced; +import org.apache.kafka.test.IntegrationTest; +import org.apache.kafka.test.StreamsTestUtils; +import org.apache.kafka.test.TestUtils; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; +import org.junit.experimental.categories.Category; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Properties; +import java.util.regex.Pattern; + +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.Assert.fail; + +@Category({IntegrationTest.class}) +public class FineGrainedAutoResetIntegrationTest { + + private static final int NUM_BROKERS = 1; + private static final String DEFAULT_OUTPUT_TOPIC = "outputTopic"; + private static final String OUTPUT_TOPIC_0 = "outputTopic_0"; + private static final String OUTPUT_TOPIC_1 = "outputTopic_1"; + private static final String OUTPUT_TOPIC_2 = "outputTopic_2"; + + public static final EmbeddedKafkaCluster CLUSTER = new EmbeddedKafkaCluster(NUM_BROKERS); + + @BeforeClass + public static void startCluster() throws IOException, InterruptedException { + CLUSTER.start(); + CLUSTER.createTopics( + TOPIC_1_0, + TOPIC_2_0, + 
TOPIC_A_0, + TOPIC_C_0, + TOPIC_Y_0, + TOPIC_Z_0, + TOPIC_1_1, + TOPIC_2_1, + TOPIC_A_1, + TOPIC_C_1, + TOPIC_Y_1, + TOPIC_Z_1, + TOPIC_1_2, + TOPIC_2_2, + TOPIC_A_2, + TOPIC_C_2, + TOPIC_Y_2, + TOPIC_Z_2, + NOOP, + DEFAULT_OUTPUT_TOPIC, + OUTPUT_TOPIC_0, + OUTPUT_TOPIC_1, + OUTPUT_TOPIC_2); + } + + @AfterClass + public static void closeCluster() { + CLUSTER.stop(); + } + + private final MockTime mockTime = CLUSTER.time; + + private static final String TOPIC_1_0 = "topic-1_0"; + private static final String TOPIC_2_0 = "topic-2_0"; + private static final String TOPIC_A_0 = "topic-A_0"; + private static final String TOPIC_C_0 = "topic-C_0"; + private static final String TOPIC_Y_0 = "topic-Y_0"; + private static final String TOPIC_Z_0 = "topic-Z_0"; + private static final String TOPIC_1_1 = "topic-1_1"; + private static final String TOPIC_2_1 = "topic-2_1"; + private static final String TOPIC_A_1 = "topic-A_1"; + private static final String TOPIC_C_1 = "topic-C_1"; + private static final String TOPIC_Y_1 = "topic-Y_1"; + private static final String TOPIC_Z_1 = "topic-Z_1"; + private static final String TOPIC_1_2 = "topic-1_2"; + private static final String TOPIC_2_2 = "topic-2_2"; + private static final String TOPIC_A_2 = "topic-A_2"; + private static final String TOPIC_C_2 = "topic-C_2"; + private static final String TOPIC_Y_2 = "topic-Y_2"; + private static final String TOPIC_Z_2 = "topic-Z_2"; + private static final String NOOP = "noop"; + private final Serde stringSerde = Serdes.String(); + + private static final String STRING_SERDE_CLASSNAME = Serdes.String().getClass().getName(); + private Properties streamsConfiguration; + + private final String topic1TestMessage = "topic-1 test"; + private final String topic2TestMessage = "topic-2 test"; + private final String topicATestMessage = "topic-A test"; + private final String topicCTestMessage = "topic-C test"; + private final String topicYTestMessage = "topic-Y test"; + private final String topicZTestMessage = "topic-Z test"; + + @Before + public void setUp() throws IOException { + + final Properties props = new Properties(); + props.put(StreamsConfig.CACHE_MAX_BYTES_BUFFERING_CONFIG, 0); + props.put(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG, 100L); + props.put(ConsumerConfig.METADATA_MAX_AGE_CONFIG, "1000"); + props.put(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "earliest"); + + streamsConfiguration = StreamsTestUtils.getStreamsConfig( + "testAutoOffsetId", + CLUSTER.bootstrapServers(), + STRING_SERDE_CLASSNAME, + STRING_SERDE_CLASSNAME, + props); + + // Remove any state from previous test runs + IntegrationTestUtils.purgeLocalStreamsState(streamsConfiguration); + } + + @Test + public void shouldOnlyReadRecordsWhereEarliestSpecifiedWithNoCommittedOffsetsWithGlobalAutoOffsetResetLatest() throws Exception { + streamsConfiguration.put(StreamsConfig.consumerPrefix(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG), "latest"); + + final List expectedReceivedValues = Arrays.asList(topic1TestMessage, topic2TestMessage); + shouldOnlyReadForEarliest("_0", TOPIC_1_0, TOPIC_2_0, TOPIC_A_0, TOPIC_C_0, TOPIC_Y_0, TOPIC_Z_0, OUTPUT_TOPIC_0, expectedReceivedValues); + } + + @Test + public void shouldOnlyReadRecordsWhereEarliestSpecifiedWithNoCommittedOffsetsWithDefaultGlobalAutoOffsetResetEarliest() throws Exception { + final List expectedReceivedValues = Arrays.asList(topic1TestMessage, topic2TestMessage, topicYTestMessage, topicZTestMessage); + shouldOnlyReadForEarliest("_1", TOPIC_1_1, TOPIC_2_1, TOPIC_A_1, TOPIC_C_1, TOPIC_Y_1, TOPIC_Z_1, OUTPUT_TOPIC_1, 
expectedReceivedValues); + } + + @Test + public void shouldOnlyReadRecordsWhereEarliestSpecifiedWithInvalidCommittedOffsets() throws Exception { + commitInvalidOffsets(); + + final List expectedReceivedValues = Arrays.asList(topic1TestMessage, topic2TestMessage, topicYTestMessage, topicZTestMessage); + shouldOnlyReadForEarliest("_2", TOPIC_1_2, TOPIC_2_2, TOPIC_A_2, TOPIC_C_2, TOPIC_Y_2, TOPIC_Z_2, OUTPUT_TOPIC_2, expectedReceivedValues); + } + + private void shouldOnlyReadForEarliest( + final String topicSuffix, + final String topic1, + final String topic2, + final String topicA, + final String topicC, + final String topicY, + final String topicZ, + final String outputTopic, + final List expectedReceivedValues) throws Exception { + + final StreamsBuilder builder = new StreamsBuilder(); + + + final KStream pattern1Stream = builder.stream(Pattern.compile("topic-\\d" + topicSuffix), Consumed.with(Topology.AutoOffsetReset.EARLIEST)); + final KStream pattern2Stream = builder.stream(Pattern.compile("topic-[A-D]" + topicSuffix), Consumed.with(Topology.AutoOffsetReset.LATEST)); + final KStream namedTopicsStream = builder.stream(Arrays.asList(topicY, topicZ)); + + pattern1Stream.to(outputTopic, Produced.with(stringSerde, stringSerde)); + pattern2Stream.to(outputTopic, Produced.with(stringSerde, stringSerde)); + namedTopicsStream.to(outputTopic, Produced.with(stringSerde, stringSerde)); + + final Properties producerConfig = TestUtils.producerConfig(CLUSTER.bootstrapServers(), StringSerializer.class, StringSerializer.class); + + IntegrationTestUtils.produceValuesSynchronously(topic1, Collections.singletonList(topic1TestMessage), producerConfig, mockTime); + IntegrationTestUtils.produceValuesSynchronously(topic2, Collections.singletonList(topic2TestMessage), producerConfig, mockTime); + IntegrationTestUtils.produceValuesSynchronously(topicA, Collections.singletonList(topicATestMessage), producerConfig, mockTime); + IntegrationTestUtils.produceValuesSynchronously(topicC, Collections.singletonList(topicCTestMessage), producerConfig, mockTime); + IntegrationTestUtils.produceValuesSynchronously(topicY, Collections.singletonList(topicYTestMessage), producerConfig, mockTime); + IntegrationTestUtils.produceValuesSynchronously(topicZ, Collections.singletonList(topicZTestMessage), producerConfig, mockTime); + + final Properties consumerConfig = TestUtils.consumerConfig(CLUSTER.bootstrapServers(), StringDeserializer.class, StringDeserializer.class); + + final KafkaStreams streams = new KafkaStreams(builder.build(), streamsConfiguration); + streams.start(); + + final List> receivedKeyValues = IntegrationTestUtils.waitUntilMinKeyValueRecordsReceived(consumerConfig, outputTopic, expectedReceivedValues.size()); + final List actualValues = new ArrayList<>(expectedReceivedValues.size()); + + for (final KeyValue receivedKeyValue : receivedKeyValues) { + actualValues.add(receivedKeyValue.value); + } + + streams.close(); + Collections.sort(actualValues); + Collections.sort(expectedReceivedValues); + assertThat(actualValues, equalTo(expectedReceivedValues)); + } + + private void commitInvalidOffsets() { + final KafkaConsumer consumer = new KafkaConsumer<>(TestUtils.consumerConfig( + CLUSTER.bootstrapServers(), + "commit_invalid_offset_app", // Having a separate application id to avoid waiting for last test poll interval timeout. 
+ StringDeserializer.class, + StringDeserializer.class)); + + final Map invalidOffsets = new HashMap<>(); + invalidOffsets.put(new TopicPartition(TOPIC_1_2, 0), new OffsetAndMetadata(5, null)); + invalidOffsets.put(new TopicPartition(TOPIC_2_2, 0), new OffsetAndMetadata(5, null)); + invalidOffsets.put(new TopicPartition(TOPIC_A_2, 0), new OffsetAndMetadata(5, null)); + invalidOffsets.put(new TopicPartition(TOPIC_C_2, 0), new OffsetAndMetadata(5, null)); + invalidOffsets.put(new TopicPartition(TOPIC_Y_2, 0), new OffsetAndMetadata(5, null)); + invalidOffsets.put(new TopicPartition(TOPIC_Z_2, 0), new OffsetAndMetadata(5, null)); + + consumer.commitSync(invalidOffsets); + + consumer.close(); + } + + @Test + public void shouldThrowExceptionOverlappingPattern() { + final StreamsBuilder builder = new StreamsBuilder(); + //NOTE this would realistically get caught when building topology, the test is for completeness + builder.stream(Pattern.compile("topic-[A-D]_1"), Consumed.with(Topology.AutoOffsetReset.EARLIEST)); + + try { + builder.stream(Pattern.compile("topic-[A-D]_1"), Consumed.with(Topology.AutoOffsetReset.LATEST)); + builder.build(); + fail("Should have thrown TopologyException"); + } catch (final TopologyException expected) { + // do nothing + } + } + + @Test + public void shouldThrowExceptionOverlappingTopic() { + final StreamsBuilder builder = new StreamsBuilder(); + //NOTE this would realistically get caught when building topology, the test is for completeness + builder.stream(Pattern.compile("topic-[A-D]_1"), Consumed.with(Topology.AutoOffsetReset.EARLIEST)); + try { + builder.stream(Arrays.asList(TOPIC_A_1, TOPIC_Z_1), Consumed.with(Topology.AutoOffsetReset.LATEST)); + builder.build(); + fail("Should have thrown TopologyException"); + } catch (final TopologyException expected) { + // do nothing + } + } + + @Test + public void shouldThrowStreamsExceptionNoResetSpecified() throws InterruptedException { + final Properties props = new Properties(); + props.put(StreamsConfig.CACHE_MAX_BYTES_BUFFERING_CONFIG, 0); + props.put(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG, 100L); + props.put(ConsumerConfig.METADATA_MAX_AGE_CONFIG, "1000"); + props.put(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "none"); + + final Properties localConfig = StreamsTestUtils.getStreamsConfig( + "testAutoOffsetWithNone", + CLUSTER.bootstrapServers(), + STRING_SERDE_CLASSNAME, + STRING_SERDE_CLASSNAME, + props); + + final StreamsBuilder builder = new StreamsBuilder(); + final KStream exceptionStream = builder.stream(NOOP); + + exceptionStream.to(DEFAULT_OUTPUT_TOPIC, Produced.with(stringSerde, stringSerde)); + + final KafkaStreams streams = new KafkaStreams(builder.build(), localConfig); + + final TestingUncaughtExceptionHandler uncaughtExceptionHandler = new TestingUncaughtExceptionHandler(); + + streams.setUncaughtExceptionHandler(uncaughtExceptionHandler); + streams.start(); + TestUtils.waitForCondition(() -> uncaughtExceptionHandler.correctExceptionThrown, + "The expected NoOffsetForPartitionException was never thrown"); + streams.close(); + } + + + private static final class TestingUncaughtExceptionHandler implements StreamsUncaughtExceptionHandler { + boolean correctExceptionThrown = false; + @Override + public StreamThreadExceptionResponse handle(final Throwable throwable) { + assertThat(throwable.getClass().getSimpleName(), is("StreamsException")); + assertThat(throwable.getCause().getClass().getSimpleName(), is("NoOffsetForPartitionException")); + correctExceptionThrown = true; + return 
StreamThreadExceptionResponse.SHUTDOWN_CLIENT; + } + } + +} diff --git a/streams/src/test/java/org/apache/kafka/streams/integration/ForeignKeyJoinSuite.java b/streams/src/test/java/org/apache/kafka/streams/integration/ForeignKeyJoinSuite.java new file mode 100644 index 0000000..5dd8c05 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/integration/ForeignKeyJoinSuite.java @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.integration; + +import org.apache.kafka.common.utils.BytesTest; +import org.apache.kafka.streams.kstream.internals.KTableKTableForeignKeyJoinScenarioTest; +import org.apache.kafka.streams.kstream.internals.foreignkeyjoin.CombinedKeySchemaTest; +import org.apache.kafka.streams.kstream.internals.foreignkeyjoin.SubscriptionResolverJoinProcessorSupplierTest; +import org.apache.kafka.streams.kstream.internals.foreignkeyjoin.SubscriptionResponseWrapperSerdeTest; +import org.apache.kafka.streams.kstream.internals.foreignkeyjoin.SubscriptionWrapperSerdeTest; +import org.junit.runner.RunWith; +import org.junit.runners.Suite; + +/** + * This suite runs all the tests related to the KTable-KTable foreign key join feature. + * + * It can be used from an IDE to selectively just run these tests when developing code related to KTable-KTable + * foreign key join. + * + * If desired, it can also be added to a Gradle build task, although this isn't strictly necessary, since all + * these tests are already included in the `:streams:test` task. + */ +@RunWith(Suite.class) +@Suite.SuiteClasses({ + BytesTest.class, + KTableKTableForeignKeyInnerJoinMultiIntegrationTest.class, + KTableKTableForeignKeyJoinIntegrationTest.class, + KTableKTableForeignKeyJoinMaterializationIntegrationTest.class, + KTableKTableForeignKeyJoinScenarioTest.class, + CombinedKeySchemaTest.class, + SubscriptionWrapperSerdeTest.class, + SubscriptionResponseWrapperSerdeTest.class, + SubscriptionResolverJoinProcessorSupplierTest.class +}) +public class ForeignKeyJoinSuite { +} + + diff --git a/streams/src/test/java/org/apache/kafka/streams/integration/GlobalKTableEOSIntegrationTest.java b/streams/src/test/java/org/apache/kafka/streams/integration/GlobalKTableEOSIntegrationTest.java new file mode 100644 index 0000000..097a79f --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/integration/GlobalKTableEOSIntegrationTest.java @@ -0,0 +1,570 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.integration; + +import java.io.File; +import java.util.Collections; +import java.util.concurrent.atomic.AtomicReference; +import kafka.utils.MockTime; +import org.apache.kafka.clients.consumer.ConsumerConfig; +import org.apache.kafka.clients.producer.ProducerConfig; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.serialization.LongSerializer; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.serialization.StringSerializer; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.streams.KafkaStreams; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.integration.utils.EmbeddedKafkaCluster; +import org.apache.kafka.streams.integration.utils.IntegrationTestUtils; +import org.apache.kafka.streams.kstream.Consumed; +import org.apache.kafka.streams.kstream.ForeachAction; +import org.apache.kafka.streams.kstream.GlobalKTable; +import org.apache.kafka.streams.kstream.KStream; +import org.apache.kafka.streams.kstream.KeyValueMapper; +import org.apache.kafka.streams.kstream.Materialized; +import org.apache.kafka.streams.kstream.ValueJoiner; +import org.apache.kafka.streams.processor.StateRestoreListener; +import org.apache.kafka.streams.state.KeyValueStore; +import org.apache.kafka.streams.state.QueryableStoreTypes; +import org.apache.kafka.streams.state.ReadOnlyKeyValueStore; +import org.apache.kafka.streams.state.internals.OffsetCheckpoint; +import org.apache.kafka.test.IntegrationTest; +import org.apache.kafka.test.TestUtils; +import org.junit.After; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Rule; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.rules.TestName; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Collection; +import java.util.HashMap; +import java.util.Iterator; +import java.util.Map; +import java.util.Properties; + +import static org.apache.kafka.streams.integration.utils.IntegrationTestUtils.safeUniqueTestName; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; + +@RunWith(Parameterized.class) +@Category({IntegrationTest.class}) +public class GlobalKTableEOSIntegrationTest { + private static final int NUM_BROKERS = 1; + private static final Properties BROKER_CONFIG; + static { + BROKER_CONFIG = new Properties(); + BROKER_CONFIG.put("transaction.state.log.replication.factor", (short) 1); + BROKER_CONFIG.put("transaction.state.log.min.isr", 1); + } + + public static final EmbeddedKafkaCluster CLUSTER = + new EmbeddedKafkaCluster(NUM_BROKERS, BROKER_CONFIG); + + @BeforeClass + public static void startCluster() throws 
IOException { + CLUSTER.start(); + } + + @AfterClass + public static void closeCluster() { + CLUSTER.stop(); + } + + @SuppressWarnings("deprecation") + @Parameterized.Parameters(name = "{0}") + public static Collection data() { + return Arrays.asList(new String[][] { + {StreamsConfig.EXACTLY_ONCE}, + {StreamsConfig.EXACTLY_ONCE_V2} + }); + } + + @Parameterized.Parameter + public String eosConfig; + + private final MockTime mockTime = CLUSTER.time; + private final KeyValueMapper keyMapper = (key, value) -> value; + private final ValueJoiner joiner = (value1, value2) -> value1 + "+" + value2; + private final String globalStore = "globalStore"; + private final Map results = new HashMap<>(); + private StreamsBuilder builder; + private Properties streamsConfiguration; + private KafkaStreams kafkaStreams; + private String globalTableTopic; + private String streamTopic; + private GlobalKTable globalTable; + private KStream stream; + private ForeachAction foreachAction; + + @Rule + public TestName testName = new TestName(); + + @Before + public void before() throws Exception { + builder = new StreamsBuilder(); + createTopics(); + streamsConfiguration = new Properties(); + final String safeTestName = safeUniqueTestName(getClass(), testName); + streamsConfiguration.put(StreamsConfig.APPLICATION_ID_CONFIG, "app-" + safeTestName); + streamsConfiguration.put(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, CLUSTER.bootstrapServers()); + streamsConfiguration.put(StreamsConfig.STATE_DIR_CONFIG, TestUtils.tempDirectory().getPath()); + streamsConfiguration.put(StreamsConfig.CACHE_MAX_BYTES_BUFFERING_CONFIG, 0L); + streamsConfiguration.put(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG, 100L); + streamsConfiguration.put(StreamsConfig.PROCESSING_GUARANTEE_CONFIG, eosConfig); + streamsConfiguration.put(StreamsConfig.TASK_TIMEOUT_MS_CONFIG, 1L); + streamsConfiguration.put(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "earliest"); + streamsConfiguration.put(ConsumerConfig.SESSION_TIMEOUT_MS_CONFIG, 1000); + streamsConfiguration.put(ConsumerConfig.HEARTBEAT_INTERVAL_MS_CONFIG, 300); + streamsConfiguration.put(ConsumerConfig.REQUEST_TIMEOUT_MS_CONFIG, 5000); + globalTable = builder.globalTable( + globalTableTopic, + Consumed.with(Serdes.Long(), Serdes.String()), + Materialized.>as(globalStore) + .withKeySerde(Serdes.Long()) + .withValueSerde(Serdes.String())); + final Consumed stringLongConsumed = Consumed.with(Serdes.String(), Serdes.Long()); + stream = builder.stream(streamTopic, stringLongConsumed); + foreachAction = results::put; + } + + @After + public void after() throws Exception { + if (kafkaStreams != null) { + kafkaStreams.close(); + } + IntegrationTestUtils.purgeLocalStreamsState(streamsConfiguration); + } + + @Test + public void shouldKStreamGlobalKTableLeftJoin() throws Exception { + final KStream streamTableJoin = stream.leftJoin(globalTable, keyMapper, joiner); + streamTableJoin.foreach(foreachAction); + produceInitialGlobalTableValues(); + startStreams(); + produceTopicValues(streamTopic); + + final Map expected = new HashMap<>(); + expected.put("a", "1+A"); + expected.put("b", "2+B"); + expected.put("c", "3+C"); + expected.put("d", "4+D"); + expected.put("e", "5+null"); + + TestUtils.waitForCondition( + () -> results.equals(expected), + 30_000L, + () -> "waiting for initial values;" + + "\n expected: " + expected + + "\n received: " + results + ); + + + produceGlobalTableValues(); + + final ReadOnlyKeyValueStore replicatedStore = IntegrationTestUtils + .getStore(globalStore, kafkaStreams, 
QueryableStoreTypes.keyValueStore()); + assertNotNull(replicatedStore); + + + final Map expectedState = new HashMap<>(); + expectedState.put(1L, "F"); + expectedState.put(2L, "G"); + expectedState.put(3L, "H"); + expectedState.put(4L, "I"); + expectedState.put(5L, "J"); + + final Map globalState = new HashMap<>(); + TestUtils.waitForCondition( + () -> { + globalState.clear(); + replicatedStore.all().forEachRemaining(pair -> globalState.put(pair.key, pair.value)); + return globalState.equals(expectedState); + }, + 30_000L, + () -> "waiting for data in replicated store" + + "\n expected: " + expectedState + + "\n received: " + globalState + ); + + + produceTopicValues(streamTopic); + + expected.put("a", "1+F"); + expected.put("b", "2+G"); + expected.put("c", "3+H"); + expected.put("d", "4+I"); + expected.put("e", "5+J"); + + TestUtils.waitForCondition( + () -> results.equals(expected), + 30_000L, + () -> "waiting for final values" + + "\n expected: " + expected + + "\n received: " + results + ); + } + + @Test + public void shouldKStreamGlobalKTableJoin() throws Exception { + final KStream streamTableJoin = stream.join(globalTable, keyMapper, joiner); + streamTableJoin.foreach(foreachAction); + produceInitialGlobalTableValues(); + startStreams(); + produceTopicValues(streamTopic); + + final Map expected = new HashMap<>(); + expected.put("a", "1+A"); + expected.put("b", "2+B"); + expected.put("c", "3+C"); + expected.put("d", "4+D"); + + TestUtils.waitForCondition( + () -> results.equals(expected), + 30_000L, + () -> "waiting for initial values" + + "\n expected: " + expected + + "\n received: " + results + ); + + + produceGlobalTableValues(); + + final ReadOnlyKeyValueStore replicatedStore = IntegrationTestUtils + .getStore(globalStore, kafkaStreams, QueryableStoreTypes.keyValueStore()); + assertNotNull(replicatedStore); + + + final Map expectedState = new HashMap<>(); + expectedState.put(1L, "F"); + expectedState.put(2L, "G"); + expectedState.put(3L, "H"); + expectedState.put(4L, "I"); + expectedState.put(5L, "J"); + + final Map globalState = new HashMap<>(); + TestUtils.waitForCondition( + () -> { + globalState.clear(); + replicatedStore.all().forEachRemaining(pair -> globalState.put(pair.key, pair.value)); + return globalState.equals(expectedState); + }, + 30_000L, + () -> "waiting for data in replicated store" + + "\n expected: " + expectedState + + "\n received: " + globalState + ); + + + produceTopicValues(streamTopic); + + expected.put("a", "1+F"); + expected.put("b", "2+G"); + expected.put("c", "3+H"); + expected.put("d", "4+I"); + expected.put("e", "5+J"); + + TestUtils.waitForCondition( + () -> results.equals(expected), + 30_000L, + () -> "waiting for final values" + + "\n expected: " + expected + + "\n received: " + results + ); + } + + @Test + public void shouldRestoreTransactionalMessages() throws Exception { + produceInitialGlobalTableValues(); + + startStreams(); + + final Map expected = new HashMap<>(); + expected.put(1L, "A"); + expected.put(2L, "B"); + expected.put(3L, "C"); + expected.put(4L, "D"); + + final ReadOnlyKeyValueStore store = IntegrationTestUtils + .getStore(globalStore, kafkaStreams, QueryableStoreTypes.keyValueStore()); + assertNotNull(store); + + final Map result = new HashMap<>(); + TestUtils.waitForCondition( + () -> { + result.clear(); + final Iterator> it = store.all(); + while (it.hasNext()) { + final KeyValue kv = it.next(); + result.put(kv.key, kv.value); + } + return result.equals(expected); + }, + 30_000L, + () -> "waiting for initial values" + + "\n 
expected: " + expected + + "\n received: " + result + ); + } + + @Test + public void shouldSkipOverTxMarkersOnRestore() throws Exception { + shouldSkipOverTxMarkersAndAbortedMessagesOnRestore(false); + } + + @Test + public void shouldSkipOverAbortedMessagesOnRestore() throws Exception { + shouldSkipOverTxMarkersAndAbortedMessagesOnRestore(true); + } + + private void shouldSkipOverTxMarkersAndAbortedMessagesOnRestore(final boolean appendAbortedMessages) throws Exception { + // records with key 1L, 2L, and 4L are written into partition-0 + // record with key 3L is written into partition-1 + produceInitialGlobalTableValues(); + + final String stateDir = streamsConfiguration.getProperty(StreamsConfig.STATE_DIR_CONFIG); + final File globalStateDir = new File( + stateDir + + File.separator + + streamsConfiguration.getProperty(StreamsConfig.APPLICATION_ID_CONFIG) + + File.separator + + "global"); + assertTrue(globalStateDir.mkdirs()); + final OffsetCheckpoint checkpoint = new OffsetCheckpoint(new File(globalStateDir, ".checkpoint")); + + // set the checkpointed offset to the commit marker of partition-1 + // even if `poll()` won't return any data for partition-1, we should still finish the restore + checkpoint.write(Collections.singletonMap(new TopicPartition(globalTableTopic, 1), 1L)); + + if (appendAbortedMessages) { + final AtomicReference error = new AtomicReference<>(); + startStreams(new StateRestoreListener() { + @Override + public void onRestoreStart(final TopicPartition topicPartition, + final String storeName, + final long startingOffset, + final long endingOffset) { + // we need to write aborted messages only after we init the `highWatermark` + // to move the `endOffset` beyond the `highWatermark + // + // we cannot write committed messages because we want to test the case that + // poll() returns no records + // + // cf. 
GlobalStateManagerImpl#restoreState() + try { + produceAbortedMessages(); + } catch (final Exception fatal) { + error.set(fatal); + } + } + + @Override + public void onBatchRestored(final TopicPartition topicPartition, + final String storeName, + final long batchEndOffset, + final long numRestored) { } + + @Override + public void onRestoreEnd(final TopicPartition topicPartition, + final String storeName, + final long totalRestored) { } + }); + final Exception fatal = error.get(); + if (fatal != null) { + throw fatal; + } + } else { + startStreams(); + } + + final Map expected = new HashMap<>(); + expected.put(1L, "A"); + expected.put(2L, "B"); + // skip record <3L, "C"> because we won't read it (cf checkpoint file above) + expected.put(4L, "D"); + + final ReadOnlyKeyValueStore store = IntegrationTestUtils + .getStore(globalStore, kafkaStreams, QueryableStoreTypes.keyValueStore()); + assertNotNull(store); + + final Map storeContent = new HashMap<>(); + TestUtils.waitForCondition( + () -> { + storeContent.clear(); + final Iterator> it = store.all(); + while (it.hasNext()) { + final KeyValue kv = it.next(); + storeContent.put(kv.key, kv.value); + } + return storeContent.equals(expected); + }, + 30_000L, + () -> "waiting for initial values" + + "\n expected: " + expected + + "\n received: " + storeContent + ); + } + + @Test + public void shouldNotRestoreAbortedMessages() throws Exception { + produceAbortedMessages(); + produceInitialGlobalTableValues(); + produceAbortedMessages(); + + startStreams(); + + final Map expected = new HashMap<>(); + expected.put(1L, "A"); + expected.put(2L, "B"); + expected.put(3L, "C"); + expected.put(4L, "D"); + + final ReadOnlyKeyValueStore store = IntegrationTestUtils + .getStore(globalStore, kafkaStreams, QueryableStoreTypes.keyValueStore()); + assertNotNull(store); + + final Map storeContent = new HashMap<>(); + TestUtils.waitForCondition( + () -> { + storeContent.clear(); + store.all().forEachRemaining(pair -> storeContent.put(pair.key, pair.value)); + return storeContent.equals(expected); + }, + 30_000L, + () -> "waiting for initial values" + + "\n expected: " + expected + + "\n received: " + storeContent + ); + } + + private void createTopics() throws Exception { + final String safeTestName = safeUniqueTestName(getClass(), testName); + streamTopic = "stream-" + safeTestName; + globalTableTopic = "globalTable-" + safeTestName; + CLUSTER.createTopics(streamTopic); + CLUSTER.createTopic(globalTableTopic, 2, 1); + } + + private void startStreams() { + startStreams(null); + } + + private void startStreams(final StateRestoreListener stateRestoreListener) { + kafkaStreams = new KafkaStreams(builder.build(), streamsConfiguration); + kafkaStreams.setGlobalStateRestoreListener(stateRestoreListener); + kafkaStreams.start(); + } + + private void produceTopicValues(final String topic) { + final Properties config = new Properties(); + config.setProperty(ProducerConfig.ENABLE_IDEMPOTENCE_CONFIG, "true"); + + IntegrationTestUtils.produceKeyValuesSynchronously( + topic, + Arrays.asList( + new KeyValue<>("a", 1L), + new KeyValue<>("b", 2L), + new KeyValue<>("c", 3L), + new KeyValue<>("d", 4L), + new KeyValue<>("e", 5L) + ), + TestUtils.producerConfig( + CLUSTER.bootstrapServers(), + StringSerializer.class, + LongSerializer.class, + config + ), + mockTime + ); + } + + private void produceAbortedMessages() throws Exception { + final Properties properties = new Properties(); + properties.put(ProducerConfig.TRANSACTIONAL_ID_CONFIG, "someid"); + + 
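+ // The records below are written in a transaction that is then aborted, so restoring the global store must skip them together with the corresponding transaction markers.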
IntegrationTestUtils.produceAbortedKeyValuesSynchronouslyWithTimestamp( + globalTableTopic, Arrays.asList( + new KeyValue<>(1L, "A"), + new KeyValue<>(2L, "B"), + new KeyValue<>(3L, "C"), + new KeyValue<>(4L, "D") + ), + TestUtils.producerConfig( + CLUSTER.bootstrapServers(), + LongSerializer.class, + StringSerializer.class, + properties + ), + mockTime.milliseconds() + ); + } + + private void produceInitialGlobalTableValues() { + final Properties properties = new Properties(); + properties.put(ProducerConfig.TRANSACTIONAL_ID_CONFIG, "someid"); + + IntegrationTestUtils.produceKeyValuesSynchronously( + globalTableTopic, + Arrays.asList( + new KeyValue<>(1L, "A"), + new KeyValue<>(2L, "B"), + new KeyValue<>(3L, "C"), + new KeyValue<>(4L, "D") + ), + TestUtils.producerConfig( + CLUSTER.bootstrapServers(), + LongSerializer.class, + StringSerializer.class, + properties + ), + mockTime, + true + ); + } + + private void produceGlobalTableValues() { + final Properties config = new Properties(); + config.setProperty(ProducerConfig.ENABLE_IDEMPOTENCE_CONFIG, "true"); + + IntegrationTestUtils.produceKeyValuesSynchronously( + globalTableTopic, + Arrays.asList( + new KeyValue<>(1L, "F"), + new KeyValue<>(2L, "G"), + new KeyValue<>(3L, "H"), + new KeyValue<>(4L, "I"), + new KeyValue<>(5L, "J") + ), + TestUtils.producerConfig( + CLUSTER.bootstrapServers(), + LongSerializer.class, + StringSerializer.class, + config + ), + mockTime + ); + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/integration/GlobalKTableIntegrationTest.java b/streams/src/test/java/org/apache/kafka/streams/integration/GlobalKTableIntegrationTest.java new file mode 100644 index 0000000..4bc69e6 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/integration/GlobalKTableIntegrationTest.java @@ -0,0 +1,406 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.integration; + +import kafka.utils.MockTime; +import org.apache.kafka.clients.consumer.ConsumerConfig; +import org.apache.kafka.common.serialization.LongSerializer; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.serialization.StringSerializer; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.streams.KafkaStreams; +import org.apache.kafka.streams.KafkaStreams.State; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.integration.utils.EmbeddedKafkaCluster; +import org.apache.kafka.streams.integration.utils.IntegrationTestUtils; +import org.apache.kafka.streams.kstream.Consumed; +import org.apache.kafka.streams.kstream.GlobalKTable; +import org.apache.kafka.streams.kstream.KStream; +import org.apache.kafka.streams.kstream.KeyValueMapper; +import org.apache.kafka.streams.kstream.Materialized; +import org.apache.kafka.streams.kstream.ValueJoiner; +import org.apache.kafka.streams.state.KeyValueStore; +import org.apache.kafka.streams.state.QueryableStoreTypes; +import org.apache.kafka.streams.state.ReadOnlyKeyValueStore; +import org.apache.kafka.streams.state.Stores; +import org.apache.kafka.streams.state.ValueAndTimestamp; +import org.apache.kafka.test.IntegrationTest; +import org.apache.kafka.test.MockProcessorSupplier; +import org.apache.kafka.test.TestUtils; +import org.junit.After; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Rule; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.rules.TestName; + +import java.io.IOException; +import java.time.Duration; +import java.util.Arrays; +import java.util.HashMap; +import java.util.Map; +import java.util.Properties; + +import static java.util.Collections.singletonList; +import static org.apache.kafka.streams.integration.utils.IntegrationTestUtils.safeUniqueTestName; +import static org.apache.kafka.streams.integration.utils.IntegrationTestUtils.waitForApplicationState; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.core.IsEqual.equalTo; +import static org.junit.Assert.assertNotNull; + +@SuppressWarnings("deprecation") // Old PAPI. Needs to be migrated. 
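+// Covers KStream-GlobalKTable joins, restoring a global in-memory store on restart, and a global-only topology reaching RUNNING, all against a single embedded broker.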
+@Category({IntegrationTest.class}) +public class GlobalKTableIntegrationTest { + private static final int NUM_BROKERS = 1; + + public static final EmbeddedKafkaCluster CLUSTER = new EmbeddedKafkaCluster(NUM_BROKERS); + + @BeforeClass + public static void startCluster() throws IOException { + CLUSTER.start(); + } + + @AfterClass + public static void closeCluster() { + CLUSTER.stop(); + } + + + private final MockTime mockTime = CLUSTER.time; + private final KeyValueMapper keyMapper = (key, value) -> value; + private final ValueJoiner joiner = (value1, value2) -> value1 + "+" + value2; + private final String globalStore = "globalStore"; + private StreamsBuilder builder; + private Properties streamsConfiguration; + private KafkaStreams kafkaStreams; + private String globalTableTopic; + private String streamTopic; + private GlobalKTable globalTable; + private KStream stream; + private MockProcessorSupplier supplier; + + @Rule + public TestName testName = new TestName(); + + @Before + public void before() throws Exception { + builder = new StreamsBuilder(); + createTopics(); + streamsConfiguration = new Properties(); + final String safeTestName = safeUniqueTestName(getClass(), testName); + streamsConfiguration.put(StreamsConfig.APPLICATION_ID_CONFIG, "app-" + safeTestName); + streamsConfiguration.put(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, CLUSTER.bootstrapServers()); + streamsConfiguration.put(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "earliest"); + streamsConfiguration.put(StreamsConfig.STATE_DIR_CONFIG, TestUtils.tempDirectory().getPath()); + streamsConfiguration.put(StreamsConfig.CACHE_MAX_BYTES_BUFFERING_CONFIG, 0); + streamsConfiguration.put(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG, 100L); + globalTable = builder.globalTable(globalTableTopic, Consumed.with(Serdes.Long(), Serdes.String()), + Materialized.>as(globalStore) + .withKeySerde(Serdes.Long()) + .withValueSerde(Serdes.String())); + final Consumed stringLongConsumed = Consumed.with(Serdes.String(), Serdes.Long()); + stream = builder.stream(streamTopic, stringLongConsumed); + supplier = new MockProcessorSupplier<>(); + } + + @After + public void whenShuttingDown() throws Exception { + if (kafkaStreams != null) { + kafkaStreams.close(); + } + IntegrationTestUtils.purgeLocalStreamsState(streamsConfiguration); + } + + @Test + public void shouldKStreamGlobalKTableLeftJoin() throws Exception { + final KStream streamTableJoin = stream.leftJoin(globalTable, keyMapper, joiner); + streamTableJoin.process(supplier); + produceInitialGlobalTableValues(); + startStreams(); + long firstTimestamp = mockTime.milliseconds(); + produceTopicValues(streamTopic); + + final Map> expected = new HashMap<>(); + expected.put("a", ValueAndTimestamp.make("1+A", firstTimestamp)); + expected.put("b", ValueAndTimestamp.make("2+B", firstTimestamp + 1L)); + expected.put("c", ValueAndTimestamp.make("3+C", firstTimestamp + 2L)); + expected.put("d", ValueAndTimestamp.make("4+D", firstTimestamp + 3L)); + expected.put("e", ValueAndTimestamp.make("5+null", firstTimestamp + 4L)); + + TestUtils.waitForCondition( + () -> { + if (supplier.capturedProcessorsCount() < 2) { + return false; + } + final Map> result = new HashMap<>(); + result.putAll(supplier.capturedProcessors(2).get(0).lastValueAndTimestampPerKey()); + result.putAll(supplier.capturedProcessors(2).get(1).lastValueAndTimestampPerKey()); + return result.equals(expected); + }, + 30000L, + "waiting for initial values"); + + firstTimestamp = mockTime.milliseconds(); + produceGlobalTableValues(); + + final 
ReadOnlyKeyValueStore replicatedStore = IntegrationTestUtils + .getStore(globalStore, kafkaStreams, QueryableStoreTypes.keyValueStore()); + assertNotNull(replicatedStore); + + final Map expectedState = new HashMap<>(); + expectedState.put(1L, "F"); + expectedState.put(2L, "G"); + expectedState.put(3L, "H"); + expectedState.put(4L, "I"); + expectedState.put(5L, "J"); + + final Map globalState = new HashMap<>(); + TestUtils.waitForCondition( + () -> { + globalState.clear(); + replicatedStore.all().forEachRemaining(pair -> globalState.put(pair.key, pair.value)); + return globalState.equals(expectedState); + }, + 30000, + () -> "waiting for data in replicated store" + + "\n expected: " + expectedState + + "\n received: " + globalState); + + final ReadOnlyKeyValueStore> replicatedStoreWithTimestamp = IntegrationTestUtils + .getStore(globalStore, kafkaStreams, QueryableStoreTypes.timestampedKeyValueStore()); + assertNotNull(replicatedStoreWithTimestamp); + assertThat(replicatedStoreWithTimestamp.get(5L), equalTo(ValueAndTimestamp.make("J", firstTimestamp + 4L))); + + firstTimestamp = mockTime.milliseconds(); + produceTopicValues(streamTopic); + + expected.put("a", ValueAndTimestamp.make("1+F", firstTimestamp)); + expected.put("b", ValueAndTimestamp.make("2+G", firstTimestamp + 1L)); + expected.put("c", ValueAndTimestamp.make("3+H", firstTimestamp + 2L)); + expected.put("d", ValueAndTimestamp.make("4+I", firstTimestamp + 3L)); + expected.put("e", ValueAndTimestamp.make("5+J", firstTimestamp + 4L)); + + TestUtils.waitForCondition( + () -> { + if (supplier.capturedProcessorsCount() < 2) { + return false; + } + final Map> result = new HashMap<>(); + result.putAll(supplier.capturedProcessors(2).get(0).lastValueAndTimestampPerKey()); + result.putAll(supplier.capturedProcessors(2).get(1).lastValueAndTimestampPerKey()); + return result.equals(expected); + }, + 30000L, + "waiting for final values"); + } + + @Test + public void shouldKStreamGlobalKTableJoin() throws Exception { + final KStream streamTableJoin = stream.join(globalTable, keyMapper, joiner); + streamTableJoin.process(supplier); + produceInitialGlobalTableValues(); + startStreams(); + long firstTimestamp = mockTime.milliseconds(); + produceTopicValues(streamTopic); + + final Map> expected = new HashMap<>(); + expected.put("a", ValueAndTimestamp.make("1+A", firstTimestamp)); + expected.put("b", ValueAndTimestamp.make("2+B", firstTimestamp + 1L)); + expected.put("c", ValueAndTimestamp.make("3+C", firstTimestamp + 2L)); + expected.put("d", ValueAndTimestamp.make("4+D", firstTimestamp + 3L)); + + TestUtils.waitForCondition( + () -> { + if (supplier.capturedProcessorsCount() < 2) { + return false; + } + final Map> result = new HashMap<>(); + result.putAll(supplier.capturedProcessors(2).get(0).lastValueAndTimestampPerKey()); + result.putAll(supplier.capturedProcessors(2).get(1).lastValueAndTimestampPerKey()); + return result.equals(expected); + }, + 30000L, + "waiting for initial values"); + + + firstTimestamp = mockTime.milliseconds(); + produceGlobalTableValues(); + + final ReadOnlyKeyValueStore replicatedStore = IntegrationTestUtils + .getStore(globalStore, kafkaStreams, QueryableStoreTypes.keyValueStore()); + assertNotNull(replicatedStore); + + final Map expectedState = new HashMap<>(); + expectedState.put(1L, "F"); + expectedState.put(2L, "G"); + expectedState.put(3L, "H"); + expectedState.put(4L, "I"); + expectedState.put(5L, "J"); + + final Map globalState = new HashMap<>(); + TestUtils.waitForCondition( + () -> { + globalState.clear(); + 
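+ // Re-read the whole global store on every retry of the condition and compare it with the expected state after the second batch of table updates.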
+                replicatedStore.all().forEachRemaining(pair -> globalState.put(pair.key, pair.value));
+                return globalState.equals(expectedState);
+            },
+            30000,
+            () -> "waiting for data in replicated store"
+                + "\n expected: " + expectedState
+                + "\n received: " + globalState);
+
+        final ReadOnlyKeyValueStore<Long, ValueAndTimestamp<String>> replicatedStoreWithTimestamp = IntegrationTestUtils
+            .getStore(globalStore, kafkaStreams, QueryableStoreTypes.timestampedKeyValueStore());
+        assertNotNull(replicatedStoreWithTimestamp);
+        assertThat(replicatedStoreWithTimestamp.get(5L), equalTo(ValueAndTimestamp.make("J", firstTimestamp + 4L)));
+
+        firstTimestamp = mockTime.milliseconds();
+        produceTopicValues(streamTopic);
+
+        expected.put("a", ValueAndTimestamp.make("1+F", firstTimestamp));
+        expected.put("b", ValueAndTimestamp.make("2+G", firstTimestamp + 1L));
+        expected.put("c", ValueAndTimestamp.make("3+H", firstTimestamp + 2L));
+        expected.put("d", ValueAndTimestamp.make("4+I", firstTimestamp + 3L));
+        expected.put("e", ValueAndTimestamp.make("5+J", firstTimestamp + 4L));
+
+        TestUtils.waitForCondition(
+            () -> {
+                if (supplier.capturedProcessorsCount() < 2) {
+                    return false;
+                }
+                final Map<String, ValueAndTimestamp<String>> result = new HashMap<>();
+                result.putAll(supplier.capturedProcessors(2).get(0).lastValueAndTimestampPerKey());
+                result.putAll(supplier.capturedProcessors(2).get(1).lastValueAndTimestampPerKey());
+                return result.equals(expected);
+            },
+            30000L,
+            "waiting for final values");
+    }
+
+    @Test
+    public void shouldRestoreGlobalInMemoryKTableOnRestart() throws Exception {
+        builder = new StreamsBuilder();
+        globalTable = builder.globalTable(
+            globalTableTopic,
+            Consumed.with(Serdes.Long(), Serdes.String()),
+            Materialized.as(Stores.inMemoryKeyValueStore(globalStore)));
+
+        produceInitialGlobalTableValues();
+
+        startStreams();
+        ReadOnlyKeyValueStore<Long, String> store = IntegrationTestUtils
+            .getStore(globalStore, kafkaStreams, QueryableStoreTypes.keyValueStore());
+        assertNotNull(store);
+
+        assertThat(store.approximateNumEntries(), equalTo(4L));
+
+        ReadOnlyKeyValueStore<Long, ValueAndTimestamp<String>> timestampedStore = IntegrationTestUtils
+            .getStore(globalStore, kafkaStreams, QueryableStoreTypes.timestampedKeyValueStore());
+        assertNotNull(timestampedStore);
+
+        assertThat(timestampedStore.approximateNumEntries(), equalTo(4L));
+        kafkaStreams.close();
+
+        startStreams();
+        store = IntegrationTestUtils.getStore(globalStore, kafkaStreams, QueryableStoreTypes.keyValueStore());
+        assertThat(store.approximateNumEntries(), equalTo(4L));
+        timestampedStore = IntegrationTestUtils.getStore(globalStore, kafkaStreams, QueryableStoreTypes.timestampedKeyValueStore());
+        assertThat(timestampedStore.approximateNumEntries(), equalTo(4L));
+    }
+
+    @Test
+    public void shouldGetToRunningWithOnlyGlobalTopology() throws Exception {
+        builder = new StreamsBuilder();
+        globalTable = builder.globalTable(
+            globalTableTopic,
+            Consumed.with(Serdes.Long(), Serdes.String()),
+            Materialized.as(Stores.inMemoryKeyValueStore(globalStore)));
+
+        startStreams();
+        waitForApplicationState(singletonList(kafkaStreams), State.RUNNING, Duration.ofSeconds(30));
+
+        kafkaStreams.close();
+    }
+
+    private void createTopics() throws Exception {
+        final String safeTestName = safeUniqueTestName(getClass(), testName);
+        streamTopic = "stream-" + safeTestName;
+        globalTableTopic = "globalTable-" + safeTestName;
+        CLUSTER.createTopics(streamTopic);
+        CLUSTER.createTopic(globalTableTopic, 2, 1);
+    }
+
+    private void startStreams() {
+        kafkaStreams = new KafkaStreams(builder.build(), streamsConfiguration);
+        kafkaStreams.start();
+    }
+
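+    // Helpers below seed the test topics: produceTopicValues writes five String/Long records
+    // to the KStream input topic, while produceInitialGlobalTableValues and
+    // produceGlobalTableValues populate and then overwrite the GlobalKTable topic
+    // with Long/String records.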
private void produceTopicValues(final String topic) { + IntegrationTestUtils.produceKeyValuesSynchronously( + topic, + Arrays.asList( + new KeyValue<>("a", 1L), + new KeyValue<>("b", 2L), + new KeyValue<>("c", 3L), + new KeyValue<>("d", 4L), + new KeyValue<>("e", 5L)), + TestUtils.producerConfig( + CLUSTER.bootstrapServers(), + StringSerializer.class, + LongSerializer.class, + new Properties()), + mockTime); + } + + private void produceInitialGlobalTableValues() { + IntegrationTestUtils.produceKeyValuesSynchronously( + globalTableTopic, + Arrays.asList( + new KeyValue<>(1L, "A"), + new KeyValue<>(2L, "B"), + new KeyValue<>(3L, "C"), + new KeyValue<>(4L, "D") + ), + TestUtils.producerConfig( + CLUSTER.bootstrapServers(), + LongSerializer.class, + StringSerializer.class + ), + mockTime); + } + + private void produceGlobalTableValues() { + IntegrationTestUtils.produceKeyValuesSynchronously( + globalTableTopic, + Arrays.asList( + new KeyValue<>(1L, "F"), + new KeyValue<>(2L, "G"), + new KeyValue<>(3L, "H"), + new KeyValue<>(4L, "I"), + new KeyValue<>(5L, "J")), + TestUtils.producerConfig( + CLUSTER.bootstrapServers(), + LongSerializer.class, + StringSerializer.class, + new Properties()), + mockTime); + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/integration/GlobalThreadShutDownOrderTest.java b/streams/src/test/java/org/apache/kafka/streams/integration/GlobalThreadShutDownOrderTest.java new file mode 100644 index 0000000..6d89325 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/integration/GlobalThreadShutDownOrderTest.java @@ -0,0 +1,234 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.streams.integration; + +import org.apache.kafka.clients.consumer.ConsumerConfig; +import org.apache.kafka.common.serialization.LongSerializer; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.serialization.StringSerializer; +import org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.streams.KafkaStreams; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.integration.utils.EmbeddedKafkaCluster; +import org.apache.kafka.streams.integration.utils.IntegrationTestUtils; +import org.apache.kafka.streams.kstream.Consumed; +import org.apache.kafka.streams.processor.ProcessorContext; +import org.apache.kafka.streams.state.KeyValueStore; +import org.apache.kafka.streams.state.Stores; +import org.apache.kafka.streams.state.internals.KeyValueStoreBuilder; +import org.apache.kafka.test.IntegrationTest; +import org.apache.kafka.test.MockApiProcessorSupplier; +import org.apache.kafka.test.TestUtils; +import org.junit.After; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Rule; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.rules.TestName; + +import java.io.IOException; +import java.time.Duration; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Properties; +import java.util.concurrent.atomic.AtomicInteger; + +import static org.apache.kafka.streams.integration.utils.IntegrationTestUtils.safeUniqueTestName; +import static org.junit.Assert.assertEquals; + + +/** + * This test asserts that when Kafka Streams is closing and shuts + * down a StreamThread the closing of the GlobalStreamThread happens + * after all the StreamThreads are completely stopped. + * + * The test validates the Processor still has access to the GlobalStateStore while closing. + * Otherwise if the GlobalStreamThread were to close underneath the StreamThread + * an exception would be thrown as the GlobalStreamThread closes all global stores on closing. + */ +@Category({IntegrationTest.class}) +public class GlobalThreadShutDownOrderTest { + + private static final int NUM_BROKERS = 1; + private static final Properties BROKER_CONFIG; + + static { + BROKER_CONFIG = new Properties(); + BROKER_CONFIG.put("transaction.state.log.replication.factor", (short) 1); + BROKER_CONFIG.put("transaction.state.log.min.isr", 1); + } + + private final AtomicInteger closeCounter = new AtomicInteger(0); + + public static final EmbeddedKafkaCluster CLUSTER = new EmbeddedKafkaCluster(NUM_BROKERS, BROKER_CONFIG); + + @BeforeClass + public static void startCluster() throws IOException { + CLUSTER.start(); + } + + @AfterClass + public static void closeCluster() { + CLUSTER.stop(); + } + + + private final MockTime mockTime = CLUSTER.time; + private final String globalStore = "globalStore"; + private StreamsBuilder builder; + private Properties streamsConfiguration; + private KafkaStreams kafkaStreams; + private String globalStoreTopic; + private String streamTopic; + private final List retrievedValuesList = new ArrayList<>(); + private boolean firstRecordProcessed; + + @Rule + public TestName testName = new TestName(); + + @SuppressWarnings("deprecation") // Old PAPI. Needs to be migrated. 
+ @Before + public void before() throws Exception { + builder = new StreamsBuilder(); + createTopics(); + streamsConfiguration = new Properties(); + final String safeTestName = safeUniqueTestName(getClass(), testName); + streamsConfiguration.put(StreamsConfig.APPLICATION_ID_CONFIG, "app-" + safeTestName); + streamsConfiguration.put(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, CLUSTER.bootstrapServers()); + streamsConfiguration.put(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "earliest"); + streamsConfiguration.put(StreamsConfig.STATE_DIR_CONFIG, TestUtils.tempDirectory().getPath()); + streamsConfiguration.put(StreamsConfig.CACHE_MAX_BYTES_BUFFERING_CONFIG, 0); + streamsConfiguration.put(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG, 100L); + + final Consumed stringLongConsumed = Consumed.with(Serdes.String(), Serdes.Long()); + + final KeyValueStoreBuilder storeBuilder = new KeyValueStoreBuilder<>( + Stores.persistentKeyValueStore(globalStore), + Serdes.String(), + Serdes.Long(), + mockTime); + + builder.addGlobalStore( + storeBuilder, + globalStoreTopic, + Consumed.with(Serdes.String(), Serdes.Long()), + new MockApiProcessorSupplier<>() + ); + + builder + .stream(streamTopic, stringLongConsumed) + .process(() -> new GlobalStoreProcessor(globalStore)); + + } + + @After + public void after() throws Exception { + if (kafkaStreams != null) { + kafkaStreams.close(); + } + IntegrationTestUtils.purgeLocalStreamsState(streamsConfiguration); + } + + @Test + public void shouldFinishGlobalStoreOperationOnShutDown() throws Exception { + kafkaStreams = new KafkaStreams(builder.build(), streamsConfiguration); + populateTopics(globalStoreTopic); + populateTopics(streamTopic); + + kafkaStreams.start(); + + TestUtils.waitForCondition( + () -> firstRecordProcessed, + 30000, + "Has not processed record within 30 seconds"); + + kafkaStreams.close(Duration.ofSeconds(30)); + + final List expectedRetrievedValues = Arrays.asList(1L, 2L, 3L, 4L); + assertEquals(expectedRetrievedValues, retrievedValuesList); + assertEquals(1, closeCounter.get()); + } + + + private void createTopics() throws Exception { + streamTopic = "stream-topic"; + globalStoreTopic = "global-store-topic"; + CLUSTER.createTopics(streamTopic); + CLUSTER.createTopic(globalStoreTopic); + } + + + private void populateTopics(final String topicName) throws Exception { + IntegrationTestUtils.produceKeyValuesSynchronously( + topicName, + Arrays.asList( + new KeyValue<>("A", 1L), + new KeyValue<>("B", 2L), + new KeyValue<>("C", 3L), + new KeyValue<>("D", 4L)), + TestUtils.producerConfig( + CLUSTER.bootstrapServers(), + StringSerializer.class, + LongSerializer.class, + new Properties()), + mockTime); + } + + + @SuppressWarnings("deprecation") // Old PAPI. Needs to be migrated. 
+    private class GlobalStoreProcessor extends org.apache.kafka.streams.processor.AbstractProcessor<String, Long> {
+
+        private KeyValueStore<String, Long> store;
+        private final String storeName;
+
+        GlobalStoreProcessor(final String storeName) {
+            this.storeName = storeName;
+        }
+
+        @Override
+        @SuppressWarnings("unchecked")
+        public void init(final ProcessorContext context) {
+            super.init(context);
+            store = (KeyValueStore<String, Long>) context.getStateStore(storeName);
+        }
+
+        @Override
+        public void process(final String key, final Long value) {
+            firstRecordProcessed = true;
+        }
+
+
+        @Override
+        public void close() {
+            closeCounter.getAndIncrement();
+            final List<String> keys = Arrays.asList("A", "B", "C", "D");
+            for (final String key : keys) {
+                // simulate a thread that is slow to close
+                Utils.sleep(1000);
+                retrievedValuesList.add(store.get(key));
+            }
+        }
+    }
+
+}
diff --git a/streams/src/test/java/org/apache/kafka/streams/integration/HighAvailabilityTaskAssignorIntegrationTest.java b/streams/src/test/java/org/apache/kafka/streams/integration/HighAvailabilityTaskAssignorIntegrationTest.java
new file mode 100644
index 0000000..2b67a0f
--- /dev/null
+++ b/streams/src/test/java/org/apache/kafka/streams/integration/HighAvailabilityTaskAssignorIntegrationTest.java
@@ -0,0 +1,320 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +package org.apache.kafka.streams.integration; + +import org.apache.kafka.clients.consumer.Consumer; +import org.apache.kafka.clients.consumer.ConsumerConfig; +import org.apache.kafka.clients.consumer.KafkaConsumer; +import org.apache.kafka.clients.producer.KafkaProducer; +import org.apache.kafka.clients.producer.Producer; +import org.apache.kafka.clients.producer.ProducerConfig; +import org.apache.kafka.clients.producer.ProducerRecord; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.serialization.StringDeserializer; +import org.apache.kafka.common.serialization.StringSerializer; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.streams.KafkaStreams; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.Topology; +import org.apache.kafka.streams.integration.utils.EmbeddedKafkaCluster; +import org.apache.kafka.streams.integration.utils.IntegrationTestUtils; +import org.apache.kafka.streams.kstream.Materialized; +import org.apache.kafka.streams.processor.StateRestoreListener; +import org.apache.kafka.streams.processor.internals.assignment.AssignorConfiguration.AssignmentListener; +import org.apache.kafka.streams.processor.internals.assignment.HighAvailabilityTaskAssignor; +import org.apache.kafka.streams.state.KeyValueStore; +import org.apache.kafka.streams.state.Stores; +import org.apache.kafka.test.IntegrationTest; +import org.apache.kafka.test.NoRetryException; +import org.apache.kafka.test.TestUtils; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Rule; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.rules.TestName; + +import java.io.IOException; +import java.util.Collection; +import java.util.Map; +import java.util.Properties; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.locks.ReentrantLock; +import java.util.function.Function; + +import static org.apache.kafka.common.utils.Utils.mkEntry; +import static org.apache.kafka.common.utils.Utils.mkMap; +import static org.apache.kafka.common.utils.Utils.mkObjectProperties; +import static org.apache.kafka.common.utils.Utils.mkProperties; +import static org.apache.kafka.common.utils.Utils.mkSet; +import static org.apache.kafka.streams.integration.utils.IntegrationTestUtils.safeUniqueTestName; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.is; + +@Category(IntegrationTest.class) +public class HighAvailabilityTaskAssignorIntegrationTest { + public static final EmbeddedKafkaCluster CLUSTER = new EmbeddedKafkaCluster(1); + + @BeforeClass + public static void startCluster() throws IOException { + CLUSTER.start(); + } + + @AfterClass + public static void closeCluster() { + CLUSTER.stop(); + } + + + @Rule + public TestName testName = new TestName(); + + @Test + public void shouldScaleOutWithWarmupTasksAndInMemoryStores() throws InterruptedException { + // NB: this test takes at least a minute to run, because it needs a probing rebalance, and the minimum + // value is one minute + shouldScaleOutWithWarmupTasks(storeName -> Materialized.as(Stores.inMemoryKeyValueStore(storeName))); + } + + @Test + public 
void shouldScaleOutWithWarmupTasksAndPersistentStores() throws InterruptedException {
+        // NB: this test takes at least a minute to run, because it needs a probing rebalance, and the minimum
+        // value is one minute
+        shouldScaleOutWithWarmupTasks(storeName -> Materialized.as(Stores.persistentKeyValueStore(storeName)));
+    }
+
+    private void shouldScaleOutWithWarmupTasks(final Function<String, Materialized<String, String, KeyValueStore<Bytes, byte[]>>> materializedFunction) throws InterruptedException {
+        final String testId = safeUniqueTestName(getClass(), testName);
+        final String appId = "appId_" + System.currentTimeMillis() + "_" + testId;
+        final String inputTopic = "input" + testId;
+        final Set<TopicPartition> inputTopicPartitions = mkSet(
+            new TopicPartition(inputTopic, 0),
+            new TopicPartition(inputTopic, 1)
+        );
+
+        final String storeName = "store" + testId;
+        final String storeChangelog = appId + "-store" + testId + "-changelog";
+        final Set<TopicPartition> changelogTopicPartitions = mkSet(
+            new TopicPartition(storeChangelog, 0),
+            new TopicPartition(storeChangelog, 1)
+        );
+
+        IntegrationTestUtils.cleanStateBeforeTest(CLUSTER, 2, inputTopic, storeChangelog);
+
+        final ReentrantLock assignmentLock = new ReentrantLock();
+        final AtomicInteger assignmentsCompleted = new AtomicInteger(0);
+        final Map<Integer, Boolean> assignmentsStable = new ConcurrentHashMap<>();
+        final AtomicBoolean assignmentStable = new AtomicBoolean(false);
+        final AssignmentListener assignmentListener =
+            stable -> {
+                assignmentLock.lock();
+                try {
+                    final int thisAssignmentIndex = assignmentsCompleted.incrementAndGet();
+                    assignmentsStable.put(thisAssignmentIndex, stable);
+                    assignmentStable.set(stable);
+                } finally {
+                    assignmentLock.unlock();
+                }
+            };
+
+        final StreamsBuilder builder = new StreamsBuilder();
+        builder.table(inputTopic, materializedFunction.apply(storeName));
+        final Topology topology = builder.build();
+
+        final int numberOfRecords = 500;
+
+        produceTestData(inputTopic, numberOfRecords);
+
+        try (final KafkaStreams kafkaStreams0 = new KafkaStreams(topology, streamsProperties(appId, assignmentListener));
+             final KafkaStreams kafkaStreams1 = new KafkaStreams(topology, streamsProperties(appId, assignmentListener));
+             final Consumer<String, String> consumer = new KafkaConsumer<>(getConsumerProperties())) {
+            kafkaStreams0.start();
+
+            // sanity check: just make sure we actually wrote all the input records
+            TestUtils.waitForCondition(
+                () -> getEndOffsetSum(inputTopicPartitions, consumer) == numberOfRecords,
+                120_000L,
+                () -> "Input records haven't all been written to the input topic: " + getEndOffsetSum(inputTopicPartitions, consumer)
+            );
+
+            // wait until all the input records are in the changelog
+            TestUtils.waitForCondition(
+                () -> getEndOffsetSum(changelogTopicPartitions, consumer) == numberOfRecords,
+                120_000L,
+                () -> "Input records haven't all been written to the changelog: " + getEndOffsetSum(changelogTopicPartitions, consumer)
+            );
+
+            final AtomicLong instance1TotalRestored = new AtomicLong(-1);
+            final AtomicLong instance1NumRestored = new AtomicLong(-1);
+            final CountDownLatch restoreCompleteLatch = new CountDownLatch(1);
+            kafkaStreams1.setGlobalStateRestoreListener(new StateRestoreListener() {
+                @Override
+                public void onRestoreStart(final TopicPartition topicPartition,
+                                           final String storeName,
+                                           final long startingOffset,
+                                           final long endingOffset) {
+                }
+
+                @Override
+                public void onBatchRestored(final TopicPartition topicPartition,
+                                            final String storeName,
+                                            final long batchEndOffset,
+                                            final long numRestored) {
+                    instance1NumRestored.accumulateAndGet(
+                        numRestored,
+                        (prev, restored) -> prev ==
-1 ? restored : prev + restored + ); + } + + @Override + public void onRestoreEnd(final TopicPartition topicPartition, + final String storeName, + final long totalRestored) { + instance1TotalRestored.accumulateAndGet( + totalRestored, + (prev, restored) -> prev == -1 ? restored : prev + restored + ); + restoreCompleteLatch.countDown(); + } + }); + final int assignmentsBeforeScaleOut = assignmentsCompleted.get(); + kafkaStreams1.start(); + TestUtils.waitForCondition( + () -> { + assignmentLock.lock(); + try { + if (assignmentsCompleted.get() > assignmentsBeforeScaleOut) { + assertFalseNoRetry( + assignmentsStable.get(assignmentsBeforeScaleOut + 1), + "the first assignment after adding a node should be unstable while we warm up the state." + ); + return true; + } else { + return false; + } + } finally { + assignmentLock.unlock(); + } + }, + 120_000L, + "Never saw a first assignment after scale out: " + assignmentsCompleted.get() + ); + + TestUtils.waitForCondition( + assignmentStable::get, + 120_000L, + "Assignment hasn't become stable: " + assignmentsCompleted.get() + + " Note, if this does fail, check and see if the new instance just failed to catch up within" + + " the probing rebalance interval. A full minute should be long enough to read ~500 records" + + " in any test environment, but you never know..." + ); + + restoreCompleteLatch.await(); + // We should finalize the restoration without having restored any records (because they're already in + // the store. Otherwise, we failed to properly re-use the state from the standby. + assertThat(instance1TotalRestored.get(), is(0L)); + // Belt-and-suspenders check that we never even attempt to restore any records. + assertThat(instance1NumRestored.get(), is(-1L)); + } + } + + private void produceTestData(final String inputTopic, final int numberOfRecords) { + final String kilo = getKiloByteValue(); + + final Properties producerProperties = mkProperties( + mkMap( + mkEntry(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, CLUSTER.bootstrapServers()), + mkEntry(ProducerConfig.ACKS_CONFIG, "all"), + mkEntry(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG, StringSerializer.class.getName()), + mkEntry(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG, StringSerializer.class.getName()) + ) + ); + + try (final Producer producer = new KafkaProducer<>(producerProperties)) { + for (int i = 0; i < numberOfRecords; i++) { + producer.send(new ProducerRecord<>(inputTopic, String.valueOf(i), kilo)); + } + } + } + + private static Properties getConsumerProperties() { + return mkProperties( + mkMap( + mkEntry(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, CLUSTER.bootstrapServers()), + mkEntry(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, StringDeserializer.class.getName()), + mkEntry(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, StringDeserializer.class.getName()) + ) + ); + } + + private static String getKiloByteValue() { + final StringBuilder kiloBuilder = new StringBuilder(1000); + for (int i = 0; i < 1000; i++) { + kiloBuilder.append('0'); + } + return kiloBuilder.toString(); + } + + private static void assertFalseNoRetry(final boolean assertion, final String message) { + if (assertion) { + throw new NoRetryException( + new AssertionError( + message + ) + ); + } + } + + private static Properties streamsProperties(final String appId, + final AssignmentListener configuredAssignmentListener) { + return mkObjectProperties( + mkMap( + mkEntry(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, CLUSTER.bootstrapServers()), + mkEntry(StreamsConfig.APPLICATION_ID_CONFIG, appId), + 
mkEntry(StreamsConfig.STATE_DIR_CONFIG, TestUtils.tempDirectory().getPath()), + mkEntry(StreamsConfig.NUM_STANDBY_REPLICAS_CONFIG, "0"), + mkEntry(StreamsConfig.ACCEPTABLE_RECOVERY_LAG_CONFIG, "0"), // make the warmup catch up completely + mkEntry(StreamsConfig.MAX_WARMUP_REPLICAS_CONFIG, "2"), + mkEntry(StreamsConfig.PROBING_REBALANCE_INTERVAL_MS_CONFIG, "60000"), + mkEntry(StreamsConfig.InternalConfig.ASSIGNMENT_LISTENER, configuredAssignmentListener), + mkEntry(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG, 100L), + mkEntry(StreamsConfig.InternalConfig.INTERNAL_TASK_ASSIGNOR_CLASS, HighAvailabilityTaskAssignor.class.getName()), + // Increasing the number of threads to ensure that a rebalance happens each time a consumer sends a rejoin (KAFKA-10455) + mkEntry(StreamsConfig.NUM_STREAM_THREADS_CONFIG, 40), + mkEntry(StreamsConfig.DEFAULT_KEY_SERDE_CLASS_CONFIG, Serdes.StringSerde.class.getName()), + mkEntry(StreamsConfig.DEFAULT_VALUE_SERDE_CLASS_CONFIG, Serdes.StringSerde.class.getName()) + ) + ); + } + + private static long getEndOffsetSum(final Set changelogTopicPartitions, + final Consumer consumer) { + long sum = 0; + final Collection values = consumer.endOffsets(changelogTopicPartitions).values(); + for (final Long value : values) { + sum += value; + } + return sum; + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/integration/InternalTopicIntegrationTest.java b/streams/src/test/java/org/apache/kafka/streams/integration/InternalTopicIntegrationTest.java new file mode 100644 index 0000000..29c61ec --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/integration/InternalTopicIntegrationTest.java @@ -0,0 +1,263 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.integration; + +import java.time.Duration; +import kafka.log.LogConfig; +import kafka.utils.MockTime; +import org.apache.kafka.clients.admin.Admin; +import org.apache.kafka.clients.admin.AdminClientConfig; +import org.apache.kafka.clients.admin.Config; +import org.apache.kafka.clients.admin.ConfigEntry; +import org.apache.kafka.clients.consumer.ConsumerConfig; +import org.apache.kafka.clients.producer.ProducerConfig; +import org.apache.kafka.common.config.ConfigResource; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.serialization.StringSerializer; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.streams.KafkaStreams; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.integration.utils.EmbeddedKafkaCluster; +import org.apache.kafka.streams.integration.utils.IntegrationTestUtils; +import org.apache.kafka.streams.kstream.Grouped; +import org.apache.kafka.streams.kstream.KStream; +import org.apache.kafka.streams.kstream.KTable; +import org.apache.kafka.streams.kstream.Materialized; +import org.apache.kafka.streams.kstream.TimeWindows; +import org.apache.kafka.streams.processor.internals.ProcessorStateManager; +import org.apache.kafka.streams.state.WindowStore; +import org.apache.kafka.test.IntegrationTest; +import org.apache.kafka.test.MockMapper; +import org.apache.kafka.test.TestUtils; +import org.junit.After; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; +import org.junit.experimental.categories.Category; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Locale; +import java.util.Properties; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; + +import static java.time.Duration.ofMillis; +import static java.time.Duration.ofSeconds; +import static java.util.Collections.singletonList; +import static org.apache.kafka.streams.integration.utils.IntegrationTestUtils.startApplicationAndWaitUntilRunning; +import static org.apache.kafka.streams.integration.utils.IntegrationTestUtils.waitForCompletion; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +/** + * Tests related to internal topics in streams + */ +@SuppressWarnings("deprecation") +@Category({IntegrationTest.class}) +public class InternalTopicIntegrationTest { + public static final EmbeddedKafkaCluster CLUSTER = new EmbeddedKafkaCluster(1); + + @BeforeClass + public static void startCluster() throws IOException, InterruptedException { + CLUSTER.start(); + CLUSTER.createTopics(DEFAULT_INPUT_TOPIC, DEFAULT_INPUT_TABLE_TOPIC); + } + + @AfterClass + public static void closeCluster() { + CLUSTER.stop(); + } + + + private static final String APP_ID = "internal-topics-integration-test"; + private static final String DEFAULT_INPUT_TOPIC = "inputTopic"; + private static final String DEFAULT_INPUT_TABLE_TOPIC = "inputTable"; + + private final MockTime mockTime = CLUSTER.time; + + private Properties streamsProp; + + @Before + public void before() { + streamsProp = new Properties(); + streamsProp.put(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, CLUSTER.bootstrapServers()); + streamsProp.put(StreamsConfig.DEFAULT_KEY_SERDE_CLASS_CONFIG, Serdes.String().getClass().getName()); + streamsProp.put(StreamsConfig.DEFAULT_VALUE_SERDE_CLASS_CONFIG, 
Serdes.String().getClass().getName()); + streamsProp.put(StreamsConfig.STATE_DIR_CONFIG, TestUtils.tempDirectory().getPath()); + streamsProp.put(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG, 100L); + streamsProp.put(StreamsConfig.CACHE_MAX_BYTES_BUFFERING_CONFIG, 0); + streamsProp.put(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "earliest"); + } + + @After + public void after() throws IOException { + // Remove any state from previous test runs + IntegrationTestUtils.purgeLocalStreamsState(streamsProp); + } + + private void produceData(final List inputValues) { + final Properties producerProp = new Properties(); + producerProp.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, CLUSTER.bootstrapServers()); + producerProp.put(ProducerConfig.ACKS_CONFIG, "all"); + producerProp.put(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG, StringSerializer.class); + producerProp.put(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG, StringSerializer.class); + + IntegrationTestUtils.produceValuesSynchronously(DEFAULT_INPUT_TOPIC, inputValues, producerProp, mockTime); + } + + private Properties getTopicProperties(final String changelog) { + try (final Admin adminClient = createAdminClient()) { + final ConfigResource configResource = new ConfigResource(ConfigResource.Type.TOPIC, changelog); + try { + final Config config = adminClient.describeConfigs(Collections.singletonList(configResource)).values().get(configResource).get(); + final Properties properties = new Properties(); + for (final ConfigEntry configEntry : config.entries()) { + if (configEntry.source() == ConfigEntry.ConfigSource.DYNAMIC_TOPIC_CONFIG) { + properties.put(configEntry.name(), configEntry.value()); + } + } + return properties; + } catch (final InterruptedException | ExecutionException e) { + throw new RuntimeException(e); + } + } + } + + private Admin createAdminClient() { + final Properties adminClientConfig = new Properties(); + adminClientConfig.put(AdminClientConfig.BOOTSTRAP_SERVERS_CONFIG, CLUSTER.bootstrapServers()); + return Admin.create(adminClientConfig); + } + + /* + * This test just ensures that that the assignor does not get stuck during partition number resolution + * for internal repartition topics. 
See KAFKA-10689.
+     */
+    @Test
+    public void shouldGetToRunningWithWindowedTableInFKJ() throws Exception {
+        final String appID = APP_ID + "-windowed-FKJ";
+        streamsProp.put(StreamsConfig.APPLICATION_ID_CONFIG, appID);
+
+        final StreamsBuilder streamsBuilder = new StreamsBuilder();
+        final KStream<String, String> inputTopic = streamsBuilder.stream(DEFAULT_INPUT_TOPIC);
+        final KTable<String, String> inputTable = streamsBuilder.table(DEFAULT_INPUT_TABLE_TOPIC);
+        inputTopic
+            .groupBy(
+                (k, v) -> k,
+                Grouped.with("GroupName", Serdes.String(), Serdes.String())
+            )
+            .windowedBy(TimeWindows.of(Duration.ofMinutes(10)))
+            .aggregate(
+                () -> "",
+                (k, v, a) -> a + k)
+            .leftJoin(
+                inputTable,
+                v -> v,
+                (x, y) -> x + y
+            );
+
+        final KafkaStreams streams = new KafkaStreams(streamsBuilder.build(), streamsProp);
+        startApplicationAndWaitUntilRunning(singletonList(streams), Duration.ofSeconds(60));
+    }
+
+    @Test
+    public void shouldCompactTopicsForKeyValueStoreChangelogs() {
+        final String appID = APP_ID + "-compact";
+        streamsProp.put(StreamsConfig.APPLICATION_ID_CONFIG, appID);
+
+        //
+        // Step 1: Configure and start a simple word count topology
+        //
+        final StreamsBuilder builder = new StreamsBuilder();
+        final KStream<String, String> textLines = builder.stream(DEFAULT_INPUT_TOPIC);
+
+        textLines.flatMapValues(value -> Arrays.asList(value.toLowerCase(Locale.getDefault()).split("\\W+")))
+            .groupBy(MockMapper.selectValueMapper())
+            .count(Materialized.as("Counts"));
+
+        final KafkaStreams streams = new KafkaStreams(builder.build(), streamsProp);
+        streams.start();
+
+        //
+        // Step 2: Produce some input data to the input topic.
+        //
+        produceData(Arrays.asList("hello", "world", "world", "hello world"));
+
+        //
+        // Step 3: Verify the state changelog topics are compact
+        //
+        waitForCompletion(streams, 2, 30000L);
+        streams.close();
+
+        final Properties changelogProps = getTopicProperties(ProcessorStateManager.storeChangelogTopic(appID, "Counts", null));
+        assertEquals(LogConfig.Compact(), changelogProps.getProperty(LogConfig.CleanupPolicyProp()));
+
+        final Properties repartitionProps = getTopicProperties(appID + "-Counts-repartition");
+        assertEquals(LogConfig.Delete(), repartitionProps.getProperty(LogConfig.CleanupPolicyProp()));
+        assertEquals(4, repartitionProps.size());
+    }
+
+    @Test
+    public void shouldCompactAndDeleteTopicsForWindowStoreChangelogs() {
+        final String appID = APP_ID + "-compact-delete";
+        streamsProp.put(StreamsConfig.APPLICATION_ID_CONFIG, appID);
+
+        //
+        // Step 1: Configure and start a simple word count topology
+        //
+        final StreamsBuilder builder = new StreamsBuilder();
+        final KStream<String, String> textLines = builder.stream(DEFAULT_INPUT_TOPIC);
+
+        final int durationMs = 2000;
+
+        textLines.flatMapValues(value -> Arrays.asList(value.toLowerCase(Locale.getDefault()).split("\\W+")))
+            .groupBy(MockMapper.selectValueMapper())
+            .windowedBy(TimeWindows.of(ofSeconds(1L)).grace(ofMillis(0L)))
+            .count(Materialized.<String, Long, WindowStore<Bytes, byte[]>>as("CountWindows").withRetention(ofSeconds(2L)));
+
+        final KafkaStreams streams = new KafkaStreams(builder.build(), streamsProp);
+        streams.start();
+
+        //
+        // Step 2: Produce some input data to the input topic.
+ // + produceData(Arrays.asList("hello", "world", "world", "hello world")); + + // + // Step 3: Verify the state changelog topics are compact + // + waitForCompletion(streams, 2, 30000L); + streams.close(); + final Properties properties = getTopicProperties(ProcessorStateManager.storeChangelogTopic(appID, "CountWindows", null)); + final List policies = Arrays.asList(properties.getProperty(LogConfig.CleanupPolicyProp()).split(",")); + assertEquals(2, policies.size()); + assertTrue(policies.contains(LogConfig.Compact())); + assertTrue(policies.contains(LogConfig.Delete())); + // retention should be 1 day + the window duration + final long retention = TimeUnit.MILLISECONDS.convert(1, TimeUnit.DAYS) + durationMs; + assertEquals(retention, Long.parseLong(properties.getProperty(LogConfig.RetentionMsProp()))); + + final Properties repartitionProps = getTopicProperties(appID + "-CountWindows-repartition"); + assertEquals(LogConfig.Delete(), repartitionProps.getProperty(LogConfig.CleanupPolicyProp())); + assertEquals(4, repartitionProps.size()); + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/integration/JoinStoreIntegrationTest.java b/streams/src/test/java/org/apache/kafka/streams/integration/JoinStoreIntegrationTest.java new file mode 100644 index 0000000..04d3f7d --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/integration/JoinStoreIntegrationTest.java @@ -0,0 +1,190 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.integration; + +import java.util.Collection; +import java.util.Map; +import java.util.stream.Collectors; +import java.util.stream.Stream; +import org.apache.kafka.clients.admin.Admin; +import org.apache.kafka.clients.admin.AdminClientConfig; +import org.apache.kafka.clients.consumer.ConsumerConfig; +import org.apache.kafka.common.config.ConfigResource; +import org.apache.kafka.common.config.ConfigResource.Type; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.streams.KafkaStreams; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.errors.UnknownStateStoreException; +import org.apache.kafka.streams.integration.utils.EmbeddedKafkaCluster; +import org.apache.kafka.streams.kstream.Consumed; +import org.apache.kafka.streams.kstream.JoinWindows; +import org.apache.kafka.streams.kstream.KStream; +import org.apache.kafka.streams.kstream.StreamJoined; +import org.apache.kafka.test.IntegrationTest; +import org.apache.kafka.test.TestUtils; +import org.junit.After; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Rule; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.rules.TemporaryFolder; + +import java.io.IOException; +import java.util.Properties; +import java.util.concurrent.CountDownLatch; + +import static java.time.Duration.ofMillis; +import static org.apache.kafka.streams.StoreQueryParameters.fromNameAndType; +import static org.apache.kafka.streams.state.QueryableStoreTypes.keyValueStore; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.is; +import static org.junit.Assert.assertThrows; + +@SuppressWarnings("deprecation") +@Category({IntegrationTest.class}) +public class JoinStoreIntegrationTest { + + public static final EmbeddedKafkaCluster CLUSTER = new EmbeddedKafkaCluster(1); + + @BeforeClass + public static void startCluster() throws IOException { + CLUSTER.start(); + STREAMS_CONFIG.put(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "earliest"); + STREAMS_CONFIG.put(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, CLUSTER.bootstrapServers()); + STREAMS_CONFIG.put(StreamsConfig.DEFAULT_KEY_SERDE_CLASS_CONFIG, Serdes.Long().getClass()); + STREAMS_CONFIG.put(StreamsConfig.DEFAULT_VALUE_SERDE_CLASS_CONFIG, Serdes.String().getClass()); + STREAMS_CONFIG.put(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG, COMMIT_INTERVAL); + + ADMIN_CONFIG.put(AdminClientConfig.BOOTSTRAP_SERVERS_CONFIG, CLUSTER.bootstrapServers()); + } + + @AfterClass + public static void closeCluster() { + CLUSTER.stop(); + } + + + @Rule + public final TemporaryFolder testFolder = new TemporaryFolder(TestUtils.tempDirectory()); + + private static final String APP_ID = "join-store-integration-test"; + private static final Long COMMIT_INTERVAL = 100L; + static final Properties STREAMS_CONFIG = new Properties(); + static final String INPUT_TOPIC_RIGHT = "inputTopicRight"; + static final String INPUT_TOPIC_LEFT = "inputTopicLeft"; + static final String OUTPUT_TOPIC = "outputTopic"; + static final Properties ADMIN_CONFIG = new Properties(); + + + StreamsBuilder builder; + + @Before + public void prepareTopology() throws InterruptedException { + CLUSTER.createTopics(INPUT_TOPIC_LEFT, INPUT_TOPIC_RIGHT, OUTPUT_TOPIC); + STREAMS_CONFIG.put(StreamsConfig.STATE_DIR_CONFIG, testFolder.getRoot().getPath()); + + builder = new StreamsBuilder(); + } + + @After + public void 
cleanup() throws InterruptedException { + CLUSTER.deleteAllTopicsAndWait(120000); + } + + @Test + public void providingAJoinStoreNameShouldNotMakeTheJoinResultQueriable() throws InterruptedException { + STREAMS_CONFIG.put(StreamsConfig.APPLICATION_ID_CONFIG, APP_ID + "-no-store-access"); + final StreamsBuilder builder = new StreamsBuilder(); + + final KStream left = builder.stream(INPUT_TOPIC_LEFT, Consumed.with(Serdes.String(), Serdes.Integer())); + final KStream right = builder.stream(INPUT_TOPIC_RIGHT, Consumed.with(Serdes.String(), Serdes.Integer())); + final CountDownLatch latch = new CountDownLatch(1); + + left.join( + right, + Integer::sum, + JoinWindows.of(ofMillis(100)), + StreamJoined.with(Serdes.String(), Serdes.Integer(), Serdes.Integer()).withStoreName("join-store")); + + try (final KafkaStreams kafkaStreams = new KafkaStreams(builder.build(), STREAMS_CONFIG)) { + kafkaStreams.setStateListener((newState, oldState) -> { + if (newState == KafkaStreams.State.RUNNING) { + latch.countDown(); + } + }); + + kafkaStreams.start(); + latch.await(); + final UnknownStateStoreException exception = + assertThrows( + UnknownStateStoreException.class, + () -> kafkaStreams.store(fromNameAndType("join-store", keyValueStore())) + ); + assertThat( + exception.getMessage(), + is("Cannot get state store join-store because no such store is registered in the topology.") + ); + } + } + + @Test + public void streamJoinChangelogTopicShouldBeConfiguredWithDeleteOnlyCleanupPolicy() throws Exception { + STREAMS_CONFIG.put(StreamsConfig.APPLICATION_ID_CONFIG, APP_ID + "-changelog-cleanup-policy"); + final StreamsBuilder builder = new StreamsBuilder(); + + final KStream left = builder.stream(INPUT_TOPIC_LEFT, Consumed.with(Serdes.String(), Serdes.Integer())); + final KStream right = builder.stream(INPUT_TOPIC_RIGHT, Consumed.with(Serdes.String(), Serdes.Integer())); + final CountDownLatch latch = new CountDownLatch(1); + + left.join( + right, + Integer::sum, + JoinWindows.of(ofMillis(100)), + StreamJoined.with(Serdes.String(), Serdes.Integer(), Serdes.Integer()).withStoreName("join-store")); + + try (final KafkaStreams kafkaStreams = new KafkaStreams(builder.build(), STREAMS_CONFIG); + final Admin admin = Admin.create(ADMIN_CONFIG)) { + kafkaStreams.setStateListener((newState, oldState) -> { + if (newState == KafkaStreams.State.RUNNING) { + latch.countDown(); + } + }); + + kafkaStreams.start(); + latch.await(); + + final Collection changelogTopics = Stream.of( + "join-store-integration-test-changelog-cleanup-policy-join-store-this-join-store-changelog", + "join-store-integration-test-changelog-cleanup-policy-join-store-other-join-store-changelog" + ) + .map(name -> new ConfigResource(Type.TOPIC, name)) + .collect(Collectors.toList()); + + final Map topicConfig + = admin.describeConfigs(changelogTopics).all().get(); + topicConfig.values().forEach( + tc -> assertThat( + tc.get("cleanup.policy").value(), + is("delete") + ) + ); + } + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/integration/JoinWithIncompleteMetadataIntegrationTest.java b/streams/src/test/java/org/apache/kafka/streams/integration/JoinWithIncompleteMetadataIntegrationTest.java new file mode 100644 index 0000000..1f6152d --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/integration/JoinWithIncompleteMetadataIntegrationTest.java @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.integration; + +import java.io.IOException; +import java.util.Properties; +import org.apache.kafka.clients.consumer.ConsumerConfig; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.streams.KafkaStreamsWrapper; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.integration.utils.EmbeddedKafkaCluster; +import org.apache.kafka.streams.integration.utils.IntegrationTestUtils; +import org.apache.kafka.streams.kstream.KStream; +import org.apache.kafka.streams.kstream.KTable; +import org.apache.kafka.streams.kstream.ValueJoiner; +import org.apache.kafka.test.IntegrationTest; +import org.apache.kafka.test.TestUtils; +import org.junit.After; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Rule; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.rules.TemporaryFolder; + +import static org.junit.Assert.assertTrue; + +@Category({IntegrationTest.class}) +public class JoinWithIncompleteMetadataIntegrationTest { + public static final EmbeddedKafkaCluster CLUSTER = new EmbeddedKafkaCluster(1); + + @BeforeClass + public static void startCluster() throws IOException { + CLUSTER.start(); + STREAMS_CONFIG.put(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "earliest"); + STREAMS_CONFIG.put(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, CLUSTER.bootstrapServers()); + STREAMS_CONFIG.put(StreamsConfig.DEFAULT_KEY_SERDE_CLASS_CONFIG, Serdes.Long().getClass()); + STREAMS_CONFIG.put(StreamsConfig.DEFAULT_VALUE_SERDE_CLASS_CONFIG, Serdes.String().getClass()); + STREAMS_CONFIG.put(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG, COMMIT_INTERVAL); + } + + @AfterClass + public static void closeCluster() { + CLUSTER.stop(); + } + + + @Rule + public final TemporaryFolder testFolder = new TemporaryFolder(TestUtils.tempDirectory()); + + private static final String APP_ID = "join-incomplete-metadata-integration-test"; + private static final Long COMMIT_INTERVAL = 100L; + static final Properties STREAMS_CONFIG = new Properties(); + static final String INPUT_TOPIC_RIGHT = "inputTopicRight"; + static final String NON_EXISTENT_INPUT_TOPIC_LEFT = "inputTopicLeft-not-exist"; + static final String OUTPUT_TOPIC = "outputTopic"; + + StreamsBuilder builder; + final ValueJoiner valueJoiner = (value1, value2) -> value1 + "-" + value2; + private KTable rightTable; + + @Before + public void prepareTopology() throws InterruptedException { + CLUSTER.createTopics(INPUT_TOPIC_RIGHT, OUTPUT_TOPIC); + STREAMS_CONFIG.put(StreamsConfig.STATE_DIR_CONFIG, testFolder.getRoot().getPath()); + + builder = new StreamsBuilder(); + rightTable = builder.table(INPUT_TOPIC_RIGHT); + } + + @After + public void cleanup() throws InterruptedException { + 
CLUSTER.deleteAllTopicsAndWait(120000); + } + + @Test + public void testShouldAutoShutdownOnJoinWithIncompleteMetadata() throws InterruptedException { + STREAMS_CONFIG.put(StreamsConfig.APPLICATION_ID_CONFIG, APP_ID); + STREAMS_CONFIG.put(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, CLUSTER.bootstrapServers()); + + final KStream notExistStream = builder.stream(NON_EXISTENT_INPUT_TOPIC_LEFT); + + final KTable aggregatedTable = notExistStream.leftJoin(rightTable, valueJoiner) + .groupBy((key, value) -> key) + .reduce((value1, value2) -> value1 + value2); + + // Write the (continuously updating) results to the output topic. + aggregatedTable.toStream().to(OUTPUT_TOPIC); + + final KafkaStreamsWrapper streams = new KafkaStreamsWrapper(builder.build(), STREAMS_CONFIG); + final IntegrationTestUtils.StateListenerStub listener = new IntegrationTestUtils.StateListenerStub(); + streams.setStreamThreadStateListener(listener); + streams.start(); + + TestUtils.waitForCondition(listener::transitToPendingShutdownSeen, "Did not seen thread state transited to PENDING_SHUTDOWN"); + + streams.close(); + assertTrue(listener.transitToPendingShutdownSeen()); + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/integration/KStreamAggregationDedupIntegrationTest.java b/streams/src/test/java/org/apache/kafka/streams/integration/KStreamAggregationDedupIntegrationTest.java new file mode 100644 index 0000000..4fe35a6 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/integration/KStreamAggregationDedupIntegrationTest.java @@ -0,0 +1,270 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.integration; + +import kafka.utils.MockTime; +import org.apache.kafka.clients.consumer.ConsumerConfig; +import org.apache.kafka.common.serialization.Deserializer; +import org.apache.kafka.common.serialization.IntegerSerializer; +import org.apache.kafka.common.serialization.LongDeserializer; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.serialization.StringDeserializer; +import org.apache.kafka.common.serialization.StringSerializer; +import org.apache.kafka.streams.KafkaStreams; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.KeyValueTimestamp; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.integration.utils.EmbeddedKafkaCluster; +import org.apache.kafka.streams.integration.utils.IntegrationTestUtils; +import org.apache.kafka.streams.kstream.Consumed; +import org.apache.kafka.streams.kstream.Grouped; +import org.apache.kafka.streams.kstream.KGroupedStream; +import org.apache.kafka.streams.kstream.KStream; +import org.apache.kafka.streams.kstream.KeyValueMapper; +import org.apache.kafka.streams.kstream.Materialized; +import org.apache.kafka.streams.kstream.Produced; +import org.apache.kafka.streams.kstream.Reducer; +import org.apache.kafka.streams.kstream.TimeWindows; +import org.apache.kafka.test.IntegrationTest; +import org.apache.kafka.test.MockMapper; +import org.apache.kafka.test.TestUtils; +import org.junit.After; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Rule; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.rules.TestName; + +import java.io.IOException; +import java.util.Arrays; +import java.util.List; +import java.util.Properties; + +import static java.time.Duration.ofMillis; +import static org.apache.kafka.streams.integration.utils.IntegrationTestUtils.safeUniqueTestName; + +/** + * Similar to KStreamAggregationIntegrationTest but with dedupping enabled + * by virtue of having a large commit interval + */ +@Category({IntegrationTest.class}) +@SuppressWarnings("deprecation") +public class KStreamAggregationDedupIntegrationTest { + private static final int NUM_BROKERS = 1; + private static final long COMMIT_INTERVAL_MS = 300L; + + public static final EmbeddedKafkaCluster CLUSTER = new EmbeddedKafkaCluster(NUM_BROKERS); + + @BeforeClass + public static void startCluster() throws IOException { + CLUSTER.start(); + } + + @AfterClass + public static void closeCluster() { + CLUSTER.stop(); + } + + + private final MockTime mockTime = CLUSTER.time; + private StreamsBuilder builder; + private Properties streamsConfiguration; + private KafkaStreams kafkaStreams; + private String streamOneInput; + private String outputTopic; + private KGroupedStream groupedStream; + private Reducer reducer; + private KStream stream; + + @Rule + public TestName testName = new TestName(); + + @Before + public void before() throws InterruptedException { + builder = new StreamsBuilder(); + createTopics(); + streamsConfiguration = new Properties(); + final String safeTestName = safeUniqueTestName(getClass(), testName); + streamsConfiguration.put(StreamsConfig.APPLICATION_ID_CONFIG, "app-" + safeTestName); + streamsConfiguration.put(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, CLUSTER.bootstrapServers()); + streamsConfiguration.put(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "earliest"); + 
streamsConfiguration.put(StreamsConfig.STATE_DIR_CONFIG, TestUtils.tempDirectory().getPath()); + streamsConfiguration.put(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG, COMMIT_INTERVAL_MS); + streamsConfiguration.put(StreamsConfig.CACHE_MAX_BYTES_BUFFERING_CONFIG, 10 * 1024 * 1024L); + + final KeyValueMapper mapper = MockMapper.selectValueMapper(); + stream = builder.stream(streamOneInput, Consumed.with(Serdes.Integer(), Serdes.String())); + groupedStream = stream.groupBy(mapper, Grouped.with(Serdes.String(), Serdes.String())); + + reducer = (value1, value2) -> value1 + ":" + value2; + } + + @After + public void whenShuttingDown() throws IOException { + if (kafkaStreams != null) { + kafkaStreams.close(); + } + IntegrationTestUtils.purgeLocalStreamsState(streamsConfiguration); + } + + + @Test + public void shouldReduce() throws Exception { + produceMessages(System.currentTimeMillis()); + groupedStream + .reduce(reducer, Materialized.as("reduce-by-key")) + .toStream() + .to(outputTopic, Produced.with(Serdes.String(), Serdes.String())); + + startStreams(); + + final long timestamp = System.currentTimeMillis(); + produceMessages(timestamp); + + validateReceivedMessages( + new StringDeserializer(), + new StringDeserializer(), + Arrays.asList( + new KeyValueTimestamp<>("A", "A:A", timestamp), + new KeyValueTimestamp<>("B", "B:B", timestamp), + new KeyValueTimestamp<>("C", "C:C", timestamp), + new KeyValueTimestamp<>("D", "D:D", timestamp), + new KeyValueTimestamp<>("E", "E:E", timestamp))); + } + + @Test + public void shouldReduceWindowed() throws Exception { + final long firstBatchTimestamp = System.currentTimeMillis() - 1000; + produceMessages(firstBatchTimestamp); + final long secondBatchTimestamp = System.currentTimeMillis(); + produceMessages(secondBatchTimestamp); + produceMessages(secondBatchTimestamp); + + groupedStream + .windowedBy(TimeWindows.of(ofMillis(500L))) + .reduce(reducer, Materialized.as("reduce-time-windows")) + .toStream((windowedKey, value) -> windowedKey.key() + "@" + windowedKey.window().start()) + .to(outputTopic, Produced.with(Serdes.String(), Serdes.String())); + + startStreams(); + + final long firstBatchWindow = firstBatchTimestamp / 500 * 500; + final long secondBatchWindow = secondBatchTimestamp / 500 * 500; + + validateReceivedMessages( + new StringDeserializer(), + new StringDeserializer(), + Arrays.asList( + new KeyValueTimestamp<>("A@" + firstBatchWindow, "A", firstBatchTimestamp), + new KeyValueTimestamp<>("A@" + secondBatchWindow, "A:A", secondBatchTimestamp), + new KeyValueTimestamp<>("B@" + firstBatchWindow, "B", firstBatchTimestamp), + new KeyValueTimestamp<>("B@" + secondBatchWindow, "B:B", secondBatchTimestamp), + new KeyValueTimestamp<>("C@" + firstBatchWindow, "C", firstBatchTimestamp), + new KeyValueTimestamp<>("C@" + secondBatchWindow, "C:C", secondBatchTimestamp), + new KeyValueTimestamp<>("D@" + firstBatchWindow, "D", firstBatchTimestamp), + new KeyValueTimestamp<>("D@" + secondBatchWindow, "D:D", secondBatchTimestamp), + new KeyValueTimestamp<>("E@" + firstBatchWindow, "E", firstBatchTimestamp), + new KeyValueTimestamp<>("E@" + secondBatchWindow, "E:E", secondBatchTimestamp) + ) + ); + } + + @Test + public void shouldGroupByKey() throws Exception { + final long timestamp = mockTime.milliseconds(); + produceMessages(timestamp); + produceMessages(timestamp); + + stream.groupByKey(Grouped.with(Serdes.Integer(), Serdes.String())) + .windowedBy(TimeWindows.of(ofMillis(500L))) + .count(Materialized.as("count-windows")) + .toStream((windowedKey, value) -> 
windowedKey.key() + "@" + windowedKey.window().start()) + .to(outputTopic, Produced.with(Serdes.String(), Serdes.Long())); + + startStreams(); + + final long window = timestamp / 500 * 500; + + validateReceivedMessages( + new StringDeserializer(), + new LongDeserializer(), + Arrays.asList( + new KeyValueTimestamp<>("1@" + window, 2L, timestamp), + new KeyValueTimestamp<>("2@" + window, 2L, timestamp), + new KeyValueTimestamp<>("3@" + window, 2L, timestamp), + new KeyValueTimestamp<>("4@" + window, 2L, timestamp), + new KeyValueTimestamp<>("5@" + window, 2L, timestamp) + ) + ); + } + + + private void produceMessages(final long timestamp) throws Exception { + IntegrationTestUtils.produceKeyValuesSynchronouslyWithTimestamp( + streamOneInput, + Arrays.asList( + new KeyValue<>(1, "A"), + new KeyValue<>(2, "B"), + new KeyValue<>(3, "C"), + new KeyValue<>(4, "D"), + new KeyValue<>(5, "E")), + TestUtils.producerConfig( + CLUSTER.bootstrapServers(), + IntegerSerializer.class, + StringSerializer.class, + new Properties()), + timestamp); + } + + + private void createTopics() throws InterruptedException { + final String safeTestName = safeUniqueTestName(getClass(), testName); + streamOneInput = "stream-one-" + safeTestName; + outputTopic = "output-" + safeTestName; + CLUSTER.createTopic(streamOneInput, 3, 1); + CLUSTER.createTopic(outputTopic); + } + + private void startStreams() { + kafkaStreams = new KafkaStreams(builder.build(), streamsConfiguration); + kafkaStreams.start(); + } + + + private void validateReceivedMessages(final Deserializer keyDeserializer, + final Deserializer valueDeserializer, + final List> expectedRecords) + throws Exception { + + final String safeTestName = safeUniqueTestName(getClass(), testName); + final Properties consumerProperties = new Properties(); + consumerProperties.setProperty(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, CLUSTER.bootstrapServers()); + consumerProperties.setProperty(ConsumerConfig.GROUP_ID_CONFIG, "group-" + safeTestName); + consumerProperties.setProperty(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "earliest"); + consumerProperties.setProperty(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, keyDeserializer.getClass().getName()); + consumerProperties.setProperty(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, valueDeserializer.getClass().getName()); + + IntegrationTestUtils.waitUntilFinalKeyValueTimestampRecordsReceived( + consumerProperties, + outputTopic, + expectedRecords); + } + +} diff --git a/streams/src/test/java/org/apache/kafka/streams/integration/KStreamAggregationIntegrationTest.java b/streams/src/test/java/org/apache/kafka/streams/integration/KStreamAggregationIntegrationTest.java new file mode 100644 index 0000000..e581903 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/integration/KStreamAggregationIntegrationTest.java @@ -0,0 +1,1142 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.integration; + +import kafka.tools.ConsoleConsumer; +import kafka.utils.MockTime; +import org.apache.kafka.clients.consumer.ConsumerConfig; +import org.apache.kafka.common.serialization.Deserializer; +import org.apache.kafka.common.serialization.IntegerDeserializer; +import org.apache.kafka.common.serialization.IntegerSerializer; +import org.apache.kafka.common.serialization.LongDeserializer; +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.serialization.StringDeserializer; +import org.apache.kafka.common.serialization.StringSerializer; +import org.apache.kafka.streams.KafkaStreams; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.KeyValueTimestamp; +import org.apache.kafka.streams.integration.utils.EmbeddedKafkaCluster; +import org.apache.kafka.streams.integration.utils.IntegrationTestUtils; +import org.apache.kafka.streams.kstream.Aggregator; +import org.apache.kafka.streams.kstream.Consumed; +import org.apache.kafka.streams.kstream.Grouped; +import org.apache.kafka.streams.kstream.Initializer; +import org.apache.kafka.streams.kstream.KGroupedStream; +import org.apache.kafka.streams.kstream.KStream; +import org.apache.kafka.streams.kstream.KeyValueMapper; +import org.apache.kafka.streams.kstream.Materialized; +import org.apache.kafka.streams.kstream.Produced; +import org.apache.kafka.streams.kstream.Reducer; +import org.apache.kafka.streams.kstream.SessionWindowedDeserializer; +import org.apache.kafka.streams.kstream.SessionWindows; +import org.apache.kafka.streams.kstream.SlidingWindows; +import org.apache.kafka.streams.kstream.TimeWindowedDeserializer; +import org.apache.kafka.streams.kstream.TimeWindows; +import org.apache.kafka.streams.kstream.Transformer; +import org.apache.kafka.streams.kstream.UnlimitedWindows; +import org.apache.kafka.streams.kstream.Windowed; +import org.apache.kafka.streams.kstream.WindowedSerdes; +import org.apache.kafka.streams.kstream.internals.SessionWindow; +import org.apache.kafka.streams.kstream.internals.TimeWindow; +import org.apache.kafka.streams.kstream.internals.UnlimitedWindow; +import org.apache.kafka.streams.processor.ProcessorContext; +import org.apache.kafka.streams.state.KeyValueIterator; +import org.apache.kafka.streams.state.QueryableStoreTypes; +import org.apache.kafka.streams.state.ReadOnlySessionStore; +import org.apache.kafka.test.IntegrationTest; +import org.apache.kafka.test.MockMapper; +import org.apache.kafka.test.TestUtils; +import org.junit.After; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Rule; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.rules.TestName; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import java.time.Duration; +import java.util.Arrays; +import java.util.Collections; +import java.util.Comparator; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Properties; +import java.util.Set; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; + +import static java.time.Duration.ofMillis; +import static 
java.time.Duration.ofMinutes; +import static java.time.Instant.ofEpochMilli; +import static org.apache.kafka.streams.integration.utils.IntegrationTestUtils.safeUniqueTestName; +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.core.Is.is; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +@SuppressWarnings({"unchecked", "deprecation"}) +@Category({IntegrationTest.class}) +public class KStreamAggregationIntegrationTest { + private static final int NUM_BROKERS = 1; + + public static final EmbeddedKafkaCluster CLUSTER = new EmbeddedKafkaCluster(NUM_BROKERS); + + @BeforeClass + public static void startCluster() throws IOException { + CLUSTER.start(); + } + + @AfterClass + public static void closeCluster() { + CLUSTER.stop(); + } + + + private final MockTime mockTime = CLUSTER.time; + private StreamsBuilder builder; + private Properties streamsConfiguration; + private KafkaStreams kafkaStreams; + private String streamOneInput; + private String userSessionsStream; + private String outputTopic; + private KGroupedStream groupedStream; + private Reducer reducer; + private Initializer initializer; + private Aggregator aggregator; + private KStream stream; + + @Rule + public TestName testName = new TestName(); + + @Before + public void before() throws InterruptedException { + builder = new StreamsBuilder(); + createTopics(); + streamsConfiguration = new Properties(); + final String safeTestName = safeUniqueTestName(getClass(), testName); + streamsConfiguration.put(StreamsConfig.APPLICATION_ID_CONFIG, "app-" + safeTestName); + streamsConfiguration.put(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, CLUSTER.bootstrapServers()); + streamsConfiguration.put(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "earliest"); + streamsConfiguration.put(StreamsConfig.STATE_DIR_CONFIG, TestUtils.tempDirectory().getPath()); + streamsConfiguration.put(StreamsConfig.CACHE_MAX_BYTES_BUFFERING_CONFIG, 0); + streamsConfiguration.put(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG, 100L); + streamsConfiguration.put(StreamsConfig.DEFAULT_KEY_SERDE_CLASS_CONFIG, Serdes.String().getClass()); + streamsConfiguration.put(StreamsConfig.DEFAULT_VALUE_SERDE_CLASS_CONFIG, Serdes.Integer().getClass()); + + final KeyValueMapper mapper = MockMapper.selectValueMapper(); + stream = builder.stream(streamOneInput, Consumed.with(Serdes.Integer(), Serdes.String())); + groupedStream = stream.groupBy(mapper, Grouped.with(Serdes.String(), Serdes.String())); + + reducer = (value1, value2) -> value1 + ":" + value2; + initializer = () -> 0; + aggregator = (aggKey, value, aggregate) -> aggregate + value.length(); + } + + @After + public void whenShuttingDown() throws IOException { + if (kafkaStreams != null) { + kafkaStreams.close(); + } + IntegrationTestUtils.purgeLocalStreamsState(streamsConfiguration); + } + + @Test + public void shouldReduce() throws Exception { + produceMessages(mockTime.milliseconds()); + groupedStream + .reduce(reducer, Materialized.as("reduce-by-key")) + .toStream() + .to(outputTopic, Produced.with(Serdes.String(), Serdes.String())); + + startStreams(); + + produceMessages(mockTime.milliseconds()); + + final List> results = receiveMessages( + new StringDeserializer(), + new StringDeserializer(), + 10); + + results.sort(KStreamAggregationIntegrationTest::compare); + + assertThat(results, is(Arrays.asList( + new KeyValueTimestamp("A", "A", mockTime.milliseconds()), + new KeyValueTimestamp("A", "A:A", mockTime.milliseconds()), + 
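// caching is disabled in this test (CACHE_MAX_BYTES_BUFFERING_CONFIG = 0), so every reduce update is + // emitted: each key appears once with its initial value and once with the reduced value +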
new KeyValueTimestamp("B", "B", mockTime.milliseconds()), + new KeyValueTimestamp("B", "B:B", mockTime.milliseconds()), + new KeyValueTimestamp("C", "C", mockTime.milliseconds()), + new KeyValueTimestamp("C", "C:C", mockTime.milliseconds()), + new KeyValueTimestamp("D", "D", mockTime.milliseconds()), + new KeyValueTimestamp("D", "D:D", mockTime.milliseconds()), + new KeyValueTimestamp("E", "E", mockTime.milliseconds()), + new KeyValueTimestamp("E", "E:E", mockTime.milliseconds())))); + } + + private static int compare(final KeyValueTimestamp o1, + final KeyValueTimestamp o2) { + final int keyComparison = o1.key().compareTo(o2.key()); + if (keyComparison == 0) { + final int valueComparison = o1.value().compareTo(o2.value()); + if (valueComparison == 0) { + return Long.compare(o1.timestamp(), o2.timestamp()); + } + return valueComparison; + } + return keyComparison; + } + + @SuppressWarnings("deprecation") + @Test + public void shouldReduceWindowed() throws Exception { + final long firstBatchTimestamp = mockTime.milliseconds(); + mockTime.sleep(1000); + produceMessages(firstBatchTimestamp); + final long secondBatchTimestamp = mockTime.milliseconds(); + produceMessages(secondBatchTimestamp); + produceMessages(secondBatchTimestamp); + + final Serde> windowedSerde = WindowedSerdes.timeWindowedSerdeFrom(String.class, 500L); + //noinspection deprecation + groupedStream + .windowedBy(TimeWindows.of(ofMillis(500L))) + .reduce(reducer) + .toStream() + .to(outputTopic, Produced.with(windowedSerde, Serdes.String())); + + startStreams(); + + final List, String>> windowedOutput = receiveMessages( + new TimeWindowedDeserializer<>(), + new StringDeserializer(), + String.class, + 15); + + // read from ConsoleConsumer + final String resultFromConsoleConsumer = readWindowedKeyedMessagesViaConsoleConsumer( + new TimeWindowedDeserializer(), + new StringDeserializer(), + String.class, + 15, + true); + + final Comparator, String>> comparator = + Comparator.comparing((KeyValueTimestamp, String> o) -> o.key().key()) + .thenComparing(KeyValueTimestamp::value); + + windowedOutput.sort(comparator); + final long firstBatchWindowStart = firstBatchTimestamp / 500 * 500; + final long firstBatchWindowEnd = firstBatchWindowStart + 500; + final long secondBatchWindowStart = secondBatchTimestamp / 500 * 500; + final long secondBatchWindowEnd = secondBatchWindowStart + 500; + + final List, String>> expectResult = Arrays.asList( + new KeyValueTimestamp<>(new Windowed<>("A", new TimeWindow(firstBatchWindowStart, firstBatchWindowEnd)), "A", firstBatchTimestamp), + new KeyValueTimestamp<>(new Windowed<>("A", new TimeWindow(secondBatchWindowStart, secondBatchWindowEnd)), "A", secondBatchTimestamp), + new KeyValueTimestamp<>(new Windowed<>("A", new TimeWindow(secondBatchWindowStart, secondBatchWindowEnd)), "A:A", secondBatchTimestamp), + new KeyValueTimestamp<>(new Windowed<>("B", new TimeWindow(firstBatchWindowStart, firstBatchWindowEnd)), "B", firstBatchTimestamp), + new KeyValueTimestamp<>(new Windowed<>("B", new TimeWindow(secondBatchWindowStart, secondBatchWindowEnd)), "B", secondBatchTimestamp), + new KeyValueTimestamp<>(new Windowed<>("B", new TimeWindow(secondBatchWindowStart, secondBatchWindowEnd)), "B:B", secondBatchTimestamp), + new KeyValueTimestamp<>(new Windowed<>("C", new TimeWindow(firstBatchWindowStart, firstBatchWindowEnd)), "C", firstBatchTimestamp), + new KeyValueTimestamp<>(new Windowed<>("C", new TimeWindow(secondBatchWindowStart, secondBatchWindowEnd)), "C", secondBatchTimestamp), + new 
KeyValueTimestamp<>(new Windowed<>("C", new TimeWindow(secondBatchWindowStart, secondBatchWindowEnd)), "C:C", secondBatchTimestamp), + new KeyValueTimestamp<>(new Windowed<>("D", new TimeWindow(firstBatchWindowStart, firstBatchWindowEnd)), "D", firstBatchTimestamp), + new KeyValueTimestamp<>(new Windowed<>("D", new TimeWindow(secondBatchWindowStart, secondBatchWindowEnd)), "D", secondBatchTimestamp), + new KeyValueTimestamp<>(new Windowed<>("D", new TimeWindow(secondBatchWindowStart, secondBatchWindowEnd)), "D:D", secondBatchTimestamp), + new KeyValueTimestamp<>(new Windowed<>("E", new TimeWindow(firstBatchWindowStart, firstBatchWindowEnd)), "E", firstBatchTimestamp), + new KeyValueTimestamp<>(new Windowed<>("E", new TimeWindow(secondBatchWindowStart, secondBatchWindowEnd)), "E", secondBatchTimestamp), + new KeyValueTimestamp<>(new Windowed<>("E", new TimeWindow(secondBatchWindowStart, secondBatchWindowEnd)), "E:E", secondBatchTimestamp) + ); + assertThat(windowedOutput, is(expectResult)); + + final Set expectResultString = new HashSet<>(expectResult.size()); + for (final KeyValueTimestamp, String> eachRecord: expectResult) { + expectResultString.add("CreateTime:" + eachRecord.timestamp() + ", " + + eachRecord.key() + ", " + eachRecord.value()); + } + + // check every message is contained in the expect result + final String[] allRecords = resultFromConsoleConsumer.split("\n"); + for (final String record: allRecords) { + assertTrue(expectResultString.contains(record)); + } + } + + @Test + public void shouldAggregate() throws Exception { + produceMessages(mockTime.milliseconds()); + groupedStream.aggregate( + initializer, + aggregator, + Materialized.as("aggregate-by-selected-key")) + .toStream() + .to(outputTopic, Produced.with(Serdes.String(), Serdes.Integer())); + + startStreams(); + + produceMessages(mockTime.milliseconds()); + + final List> results = receiveMessages( + new StringDeserializer(), + new IntegerDeserializer(), + 10); + + results.sort(KStreamAggregationIntegrationTest::compare); + + assertThat(results, is(Arrays.asList( + new KeyValueTimestamp("A", 1, mockTime.milliseconds()), + new KeyValueTimestamp("A", 2, mockTime.milliseconds()), + new KeyValueTimestamp("B", 1, mockTime.milliseconds()), + new KeyValueTimestamp("B", 2, mockTime.milliseconds()), + new KeyValueTimestamp("C", 1, mockTime.milliseconds()), + new KeyValueTimestamp("C", 2, mockTime.milliseconds()), + new KeyValueTimestamp("D", 1, mockTime.milliseconds()), + new KeyValueTimestamp("D", 2, mockTime.milliseconds()), + new KeyValueTimestamp("E", 1, mockTime.milliseconds()), + new KeyValueTimestamp("E", 2, mockTime.milliseconds()) + ))); + } + + @SuppressWarnings("deprecation") + @Test + public void shouldAggregateWindowed() throws Exception { + final long firstTimestamp = mockTime.milliseconds(); + mockTime.sleep(1000); + produceMessages(firstTimestamp); + final long secondTimestamp = mockTime.milliseconds(); + produceMessages(secondTimestamp); + produceMessages(secondTimestamp); + + final Serde> windowedSerde = WindowedSerdes.timeWindowedSerdeFrom(String.class, 500L); + //noinspection deprecation + groupedStream.windowedBy(TimeWindows.of(ofMillis(500L))) + .aggregate( + initializer, + aggregator, + Materialized.with(null, Serdes.Integer()) + ) + .toStream() + .to(outputTopic, Produced.with(windowedSerde, Serdes.Integer())); + + startStreams(); + + final List, Integer>> windowedMessages = receiveMessagesWithTimestamp( + new TimeWindowedDeserializer<>(new StringDeserializer(), 500L), + new IntegerDeserializer(), + 
String.class, + 15); + + // read from ConsoleConsumer + final String resultFromConsoleConsumer = readWindowedKeyedMessagesViaConsoleConsumer( + new TimeWindowedDeserializer(), + new IntegerDeserializer(), + String.class, + 15, + true); + + final Comparator, Integer>> comparator = + Comparator.comparing((KeyValueTimestamp, Integer> o) -> o.key().key()) + .thenComparingInt(KeyValueTimestamp::value); + windowedMessages.sort(comparator); + + final long firstWindowStart = firstTimestamp / 500 * 500; + final long firstWindowEnd = firstWindowStart + 500; + final long secondWindowStart = secondTimestamp / 500 * 500; + final long secondWindowEnd = secondWindowStart + 500; + + final List, Integer>> expectResult = Arrays.asList( + new KeyValueTimestamp<>(new Windowed<>("A", new TimeWindow(firstWindowStart, firstWindowEnd)), 1, firstTimestamp), + new KeyValueTimestamp<>(new Windowed<>("A", new TimeWindow(secondWindowStart, secondWindowEnd)), 1, secondTimestamp), + new KeyValueTimestamp<>(new Windowed<>("A", new TimeWindow(secondWindowStart, secondWindowEnd)), 2, secondTimestamp), + new KeyValueTimestamp<>(new Windowed<>("B", new TimeWindow(firstWindowStart, firstWindowEnd)), 1, firstTimestamp), + new KeyValueTimestamp<>(new Windowed<>("B", new TimeWindow(secondWindowStart, secondWindowEnd)), 1, secondTimestamp), + new KeyValueTimestamp<>(new Windowed<>("B", new TimeWindow(secondWindowStart, secondWindowEnd)), 2, secondTimestamp), + new KeyValueTimestamp<>(new Windowed<>("C", new TimeWindow(firstWindowStart, firstWindowEnd)), 1, firstTimestamp), + new KeyValueTimestamp<>(new Windowed<>("C", new TimeWindow(secondWindowStart, secondWindowEnd)), 1, secondTimestamp), + new KeyValueTimestamp<>(new Windowed<>("C", new TimeWindow(secondWindowStart, secondWindowEnd)), 2, secondTimestamp), + new KeyValueTimestamp<>(new Windowed<>("D", new TimeWindow(firstWindowStart, firstWindowEnd)), 1, firstTimestamp), + new KeyValueTimestamp<>(new Windowed<>("D", new TimeWindow(secondWindowStart, secondWindowEnd)), 1, secondTimestamp), + new KeyValueTimestamp<>(new Windowed<>("D", new TimeWindow(secondWindowStart, secondWindowEnd)), 2, secondTimestamp), + new KeyValueTimestamp<>(new Windowed<>("E", new TimeWindow(firstWindowStart, firstWindowEnd)), 1, firstTimestamp), + new KeyValueTimestamp<>(new Windowed<>("E", new TimeWindow(secondWindowStart, secondWindowEnd)), 1, secondTimestamp), + new KeyValueTimestamp<>(new Windowed<>("E", new TimeWindow(secondWindowStart, secondWindowEnd)), 2, secondTimestamp)); + + assertThat(windowedMessages, is(expectResult)); + + final Set expectResultString = new HashSet<>(expectResult.size()); + for (final KeyValueTimestamp, Integer> eachRecord: expectResult) { + expectResultString.add("CreateTime:" + eachRecord.timestamp() + ", " + eachRecord.key() + ", " + eachRecord.value()); + } + + // check every message is contained in the expect result + final String[] allRecords = resultFromConsoleConsumer.split("\n"); + for (final String record: allRecords) { + assertTrue(expectResultString.contains(record)); + } + + } + + private void shouldCountHelper() throws Exception { + startStreams(); + + produceMessages(mockTime.milliseconds()); + + final List> results = receiveMessages( + new StringDeserializer(), + new LongDeserializer(), + 10); + results.sort(KStreamAggregationIntegrationTest::compare); + + assertThat(results, is(Arrays.asList( + new KeyValueTimestamp("A", 1L, mockTime.milliseconds()), + new KeyValueTimestamp("A", 2L, mockTime.milliseconds()), + new KeyValueTimestamp("B", 1L, 
mockTime.milliseconds()), + new KeyValueTimestamp("B", 2L, mockTime.milliseconds()), + new KeyValueTimestamp("C", 1L, mockTime.milliseconds()), + new KeyValueTimestamp("C", 2L, mockTime.milliseconds()), + new KeyValueTimestamp("D", 1L, mockTime.milliseconds()), + new KeyValueTimestamp("D", 2L, mockTime.milliseconds()), + new KeyValueTimestamp("E", 1L, mockTime.milliseconds()), + new KeyValueTimestamp("E", 2L, mockTime.milliseconds()) + ))); + } + + @Test + public void shouldCount() throws Exception { + produceMessages(mockTime.milliseconds()); + + groupedStream.count(Materialized.as("count-by-key")) + .toStream() + .to(outputTopic, Produced.with(Serdes.String(), Serdes.Long())); + + shouldCountHelper(); + } + + @Test + public void shouldCountWithInternalStore() throws Exception { + produceMessages(mockTime.milliseconds()); + + groupedStream.count() + .toStream() + .to(outputTopic, Produced.with(Serdes.String(), Serdes.Long())); + + shouldCountHelper(); + } + + @SuppressWarnings("deprecation") + @Test + public void shouldGroupByKey() throws Exception { + final long timestamp = mockTime.milliseconds(); + produceMessages(timestamp); + produceMessages(timestamp); + + //noinspection deprecation + stream.groupByKey(Grouped.with(Serdes.Integer(), Serdes.String())) + .windowedBy(TimeWindows.of(ofMillis(500L))) + .count() + .toStream((windowedKey, value) -> windowedKey.key() + "@" + windowedKey.window().start()).to(outputTopic, Produced.with(Serdes.String(), Serdes.Long())); + + startStreams(); + + final List> results = receiveMessages( + new StringDeserializer(), + new LongDeserializer(), + 10); + results.sort(KStreamAggregationIntegrationTest::compare); + + final long window = timestamp / 500 * 500; + assertThat(results, is(Arrays.asList( + new KeyValueTimestamp("1@" + window, 1L, timestamp), + new KeyValueTimestamp("1@" + window, 2L, timestamp), + new KeyValueTimestamp("2@" + window, 1L, timestamp), + new KeyValueTimestamp("2@" + window, 2L, timestamp), + new KeyValueTimestamp("3@" + window, 1L, timestamp), + new KeyValueTimestamp("3@" + window, 2L, timestamp), + new KeyValueTimestamp("4@" + window, 1L, timestamp), + new KeyValueTimestamp("4@" + window, 2L, timestamp), + new KeyValueTimestamp("5@" + window, 1L, timestamp), + new KeyValueTimestamp("5@" + window, 2L, timestamp) + ))); + } + + @SuppressWarnings("deprecation") + @Test + public void shouldReduceSlidingWindows() throws Exception { + final long firstBatchTimestamp = mockTime.milliseconds(); + final long timeDifference = 500L; + produceMessages(firstBatchTimestamp); + final long secondBatchTimestamp = firstBatchTimestamp + timeDifference / 2; + produceMessages(secondBatchTimestamp); + final long thirdBatchTimestamp = firstBatchTimestamp + timeDifference - 100L; + produceMessages(thirdBatchTimestamp); + + final Serde> windowedSerde = WindowedSerdes.timeWindowedSerdeFrom(String.class, timeDifference); + //noinspection deprecation + groupedStream + .windowedBy(SlidingWindows.withTimeDifferenceAndGrace(ofMillis(timeDifference), ofMillis(2000L))) + .reduce(reducer) + .toStream() + .to(outputTopic, Produced.with(windowedSerde, Serdes.String())); + + startStreams(); + + final List, String>> windowedOutput = receiveMessages( + new TimeWindowedDeserializer<>(new StringDeserializer(), 500L), + new StringDeserializer(), + String.class, + 30); + + final String resultFromConsoleConsumer = readWindowedKeyedMessagesViaConsoleConsumer( + new TimeWindowedDeserializer(), + new StringDeserializer(), + String.class, + 30, + true); + + final Comparator, 
String>> comparator = + Comparator.comparing((KeyValueTimestamp, String> o) -> o.key().key()) + .thenComparing(KeyValueTimestamp::value); + + windowedOutput.sort(comparator); + final long firstBatchLeftWindowStart = firstBatchTimestamp - timeDifference; + final long firstBatchLeftWindowEnd = firstBatchLeftWindowStart + timeDifference; + final long firstBatchRightWindowStart = firstBatchTimestamp + 1; + final long firstBatchRightWindowEnd = firstBatchRightWindowStart + timeDifference; + + final long secondBatchLeftWindowStart = secondBatchTimestamp - timeDifference; + final long secondBatchLeftWindowEnd = secondBatchLeftWindowStart + timeDifference; + final long secondBatchRightWindowStart = secondBatchTimestamp + 1; + final long secondBatchRightWindowEnd = secondBatchRightWindowStart + timeDifference; + + final long thirdBatchLeftWindowStart = thirdBatchTimestamp - timeDifference; + final long thirdBatchLeftWindowEnd = thirdBatchLeftWindowStart + timeDifference; + + final List, String>> expectResult = Arrays.asList( + // A @ firstBatchTimestamp left window created when A @ firstBatchTimestamp processed + new KeyValueTimestamp<>(new Windowed<>("A", new TimeWindow(firstBatchLeftWindowStart, firstBatchLeftWindowEnd)), "A", firstBatchTimestamp), + // A @ firstBatchTimestamp right window created when A @ secondBatchTimestamp processed + new KeyValueTimestamp<>(new Windowed<>("A", new TimeWindow(firstBatchRightWindowStart, firstBatchRightWindowEnd)), "A", secondBatchTimestamp), + // A @ secondBatchTimestamp right window created when A @ thirdBatchTimestamp processed + new KeyValueTimestamp<>(new Windowed<>("A", new TimeWindow(secondBatchRightWindowStart, secondBatchRightWindowEnd)), "A", thirdBatchTimestamp), + // A @ secondBatchTimestamp left window created when A @ secondBatchTimestamp processed + new KeyValueTimestamp<>(new Windowed<>("A", new TimeWindow(secondBatchLeftWindowStart, secondBatchLeftWindowEnd)), "A:A", secondBatchTimestamp), + // A @ firstBatchTimestamp right window updated when A @ thirdBatchTimestamp processed + new KeyValueTimestamp<>(new Windowed<>("A", new TimeWindow(firstBatchRightWindowStart, firstBatchRightWindowEnd)), "A:A", thirdBatchTimestamp), + // A @ thirdBatchTimestamp left window created when A @ thirdBatchTimestamp processed + new KeyValueTimestamp<>(new Windowed<>("A", new TimeWindow(thirdBatchLeftWindowStart, thirdBatchLeftWindowEnd)), "A:A:A", thirdBatchTimestamp), + + new KeyValueTimestamp<>(new Windowed<>("B", new TimeWindow(firstBatchLeftWindowStart, firstBatchLeftWindowEnd)), "B", firstBatchTimestamp), + new KeyValueTimestamp<>(new Windowed<>("B", new TimeWindow(firstBatchRightWindowStart, firstBatchRightWindowEnd)), "B", secondBatchTimestamp), + new KeyValueTimestamp<>(new Windowed<>("B", new TimeWindow(secondBatchRightWindowStart, secondBatchRightWindowEnd)), "B", thirdBatchTimestamp), + new KeyValueTimestamp<>(new Windowed<>("B", new TimeWindow(secondBatchLeftWindowStart, secondBatchLeftWindowEnd)), "B:B", secondBatchTimestamp), + new KeyValueTimestamp<>(new Windowed<>("B", new TimeWindow(firstBatchRightWindowStart, firstBatchRightWindowEnd)), "B:B", thirdBatchTimestamp), + new KeyValueTimestamp<>(new Windowed<>("B", new TimeWindow(thirdBatchLeftWindowStart, thirdBatchLeftWindowEnd)), "B:B:B", thirdBatchTimestamp), + new KeyValueTimestamp<>(new Windowed<>("C", new TimeWindow(firstBatchLeftWindowStart, firstBatchLeftWindowEnd)), "C", firstBatchTimestamp), + new KeyValueTimestamp<>(new Windowed<>("C", new TimeWindow(firstBatchRightWindowStart, 
firstBatchRightWindowEnd)), "C", secondBatchTimestamp), + new KeyValueTimestamp<>(new Windowed<>("C", new TimeWindow(secondBatchRightWindowStart, secondBatchRightWindowEnd)), "C", thirdBatchTimestamp), + new KeyValueTimestamp<>(new Windowed<>("C", new TimeWindow(secondBatchLeftWindowStart, secondBatchLeftWindowEnd)), "C:C", secondBatchTimestamp), + new KeyValueTimestamp<>(new Windowed<>("C", new TimeWindow(firstBatchRightWindowStart, firstBatchRightWindowEnd)), "C:C", thirdBatchTimestamp), + new KeyValueTimestamp<>(new Windowed<>("C", new TimeWindow(thirdBatchLeftWindowStart, thirdBatchLeftWindowEnd)), "C:C:C", thirdBatchTimestamp), + new KeyValueTimestamp<>(new Windowed<>("D", new TimeWindow(firstBatchLeftWindowStart, firstBatchLeftWindowEnd)), "D", firstBatchTimestamp), + new KeyValueTimestamp<>(new Windowed<>("D", new TimeWindow(firstBatchRightWindowStart, firstBatchRightWindowEnd)), "D", secondBatchTimestamp), + new KeyValueTimestamp<>(new Windowed<>("D", new TimeWindow(secondBatchRightWindowStart, secondBatchRightWindowEnd)), "D", thirdBatchTimestamp), + new KeyValueTimestamp<>(new Windowed<>("D", new TimeWindow(secondBatchLeftWindowStart, secondBatchLeftWindowEnd)), "D:D", secondBatchTimestamp), + new KeyValueTimestamp<>(new Windowed<>("D", new TimeWindow(firstBatchRightWindowStart, firstBatchRightWindowEnd)), "D:D", thirdBatchTimestamp), + new KeyValueTimestamp<>(new Windowed<>("D", new TimeWindow(thirdBatchLeftWindowStart, thirdBatchLeftWindowEnd)), "D:D:D", thirdBatchTimestamp), + new KeyValueTimestamp<>(new Windowed<>("E", new TimeWindow(firstBatchLeftWindowStart, firstBatchLeftWindowEnd)), "E", firstBatchTimestamp), + new KeyValueTimestamp<>(new Windowed<>("E", new TimeWindow(firstBatchRightWindowStart, firstBatchRightWindowEnd)), "E", secondBatchTimestamp), + new KeyValueTimestamp<>(new Windowed<>("E", new TimeWindow(secondBatchRightWindowStart, secondBatchRightWindowEnd)), "E", thirdBatchTimestamp), + new KeyValueTimestamp<>(new Windowed<>("E", new TimeWindow(secondBatchLeftWindowStart, secondBatchLeftWindowEnd)), "E:E", secondBatchTimestamp), + new KeyValueTimestamp<>(new Windowed<>("E", new TimeWindow(firstBatchRightWindowStart, firstBatchRightWindowEnd)), "E:E", thirdBatchTimestamp), + new KeyValueTimestamp<>(new Windowed<>("E", new TimeWindow(thirdBatchLeftWindowStart, thirdBatchLeftWindowEnd)), "E:E:E", thirdBatchTimestamp) + ); + assertThat(windowedOutput, is(expectResult)); + + final Set expectResultString = new HashSet<>(expectResult.size()); + for (final KeyValueTimestamp, String> eachRecord: expectResult) { + expectResultString.add("CreateTime:" + eachRecord.timestamp() + ", " + + eachRecord.key() + ", " + eachRecord.value()); + } + + // check every message is contained in the expect result + final String[] allRecords = resultFromConsoleConsumer.split("\n"); + for (final String record: allRecords) { + assertTrue(expectResultString.contains(record)); + } + } + + @SuppressWarnings("deprecation") + @Test + public void shouldAggregateSlidingWindows() throws Exception { + final long firstBatchTimestamp = mockTime.milliseconds(); + final long timeDifference = 500L; + produceMessages(firstBatchTimestamp); + final long secondBatchTimestamp = firstBatchTimestamp + timeDifference / 2; + produceMessages(secondBatchTimestamp); + final long thirdBatchTimestamp = firstBatchTimestamp + timeDifference - 100L; + produceMessages(thirdBatchTimestamp); + + final Serde> windowedSerde = WindowedSerdes.timeWindowedSerdeFrom(String.class, timeDifference); + //noinspection deprecation + 
groupedStream.windowedBy(SlidingWindows.withTimeDifferenceAndGrace(ofMillis(500L), ofMinutes(5))) + .aggregate( + initializer, + aggregator, + Materialized.with(null, Serdes.Integer()) + ) + .toStream() + .to(outputTopic, Produced.with(windowedSerde, Serdes.Integer())); + + startStreams(); + + final List, Integer>> windowedMessages = receiveMessagesWithTimestamp( + new TimeWindowedDeserializer<>(), + new IntegerDeserializer(), + String.class, + 30); + + // read from ConsoleConsumer + final String resultFromConsoleConsumer = readWindowedKeyedMessagesViaConsoleConsumer( + new TimeWindowedDeserializer(), + new IntegerDeserializer(), + String.class, + 30, + true); + + final Comparator, Integer>> comparator = + Comparator.comparing((KeyValueTimestamp, Integer> o) -> o.key().key()) + .thenComparingInt(KeyValueTimestamp::value); + windowedMessages.sort(comparator); + + final long firstBatchLeftWindowStart = firstBatchTimestamp - timeDifference; + final long firstBatchLeftWindowEnd = firstBatchLeftWindowStart + timeDifference; + final long firstBatchRightWindowStart = firstBatchTimestamp + 1; + final long firstBatchRightWindowEnd = firstBatchRightWindowStart + timeDifference; + + final long secondBatchLeftWindowStart = secondBatchTimestamp - timeDifference; + final long secondBatchLeftWindowEnd = secondBatchLeftWindowStart + timeDifference; + final long secondBatchRightWindowStart = secondBatchTimestamp + 1; + final long secondBatchRightWindowEnd = secondBatchRightWindowStart + timeDifference; + + final long thirdBatchLeftWindowStart = thirdBatchTimestamp - timeDifference; + final long thirdBatchLeftWindowEnd = thirdBatchLeftWindowStart + timeDifference; + + final List, Integer>> expectResult = Arrays.asList( + // A @ firstBatchTimestamp left window created when A @ firstBatchTimestamp processed + new KeyValueTimestamp<>(new Windowed<>("A", new TimeWindow(firstBatchLeftWindowStart, firstBatchLeftWindowEnd)), 1, firstBatchTimestamp), + // A @ firstBatchTimestamp right window created when A @ secondBatchTimestamp processed + new KeyValueTimestamp<>(new Windowed<>("A", new TimeWindow(firstBatchRightWindowStart, firstBatchRightWindowEnd)), 1, secondBatchTimestamp), + // A @ secondBatchTimestamp right window created when A @ thirdBatchTimestamp processed + new KeyValueTimestamp<>(new Windowed<>("A", new TimeWindow(secondBatchRightWindowStart, secondBatchRightWindowEnd)), 1, thirdBatchTimestamp), + // A @ secondBatchTimestamp left window created when A @ secondBatchTimestamp processed + new KeyValueTimestamp<>(new Windowed<>("A", new TimeWindow(secondBatchLeftWindowStart, secondBatchLeftWindowEnd)), 2, secondBatchTimestamp), + // A @ firstBatchTimestamp right window updated when A @ thirdBatchTimestamp processed + new KeyValueTimestamp<>(new Windowed<>("A", new TimeWindow(firstBatchRightWindowStart, firstBatchRightWindowEnd)), 2, thirdBatchTimestamp), + // A @ thirdBatchTimestamp left window created when A @ thirdBatchTimestamp processed + new KeyValueTimestamp<>(new Windowed<>("A", new TimeWindow(thirdBatchLeftWindowStart, thirdBatchLeftWindowEnd)), 3, thirdBatchTimestamp), + + new KeyValueTimestamp<>(new Windowed<>("B", new TimeWindow(firstBatchLeftWindowStart, firstBatchLeftWindowEnd)), 1, firstBatchTimestamp), + new KeyValueTimestamp<>(new Windowed<>("B", new TimeWindow(firstBatchRightWindowStart, firstBatchRightWindowEnd)), 1, secondBatchTimestamp), + new KeyValueTimestamp<>(new Windowed<>("B", new TimeWindow(secondBatchRightWindowStart, secondBatchRightWindowEnd)), 1, thirdBatchTimestamp), + new 
KeyValueTimestamp<>(new Windowed<>("B", new TimeWindow(secondBatchLeftWindowStart, secondBatchLeftWindowEnd)), 2, secondBatchTimestamp), + new KeyValueTimestamp<>(new Windowed<>("B", new TimeWindow(firstBatchRightWindowStart, firstBatchRightWindowEnd)), 2, thirdBatchTimestamp), + new KeyValueTimestamp<>(new Windowed<>("B", new TimeWindow(thirdBatchLeftWindowStart, thirdBatchLeftWindowEnd)), 3, thirdBatchTimestamp), + new KeyValueTimestamp<>(new Windowed<>("C", new TimeWindow(firstBatchLeftWindowStart, firstBatchLeftWindowEnd)), 1, firstBatchTimestamp), + new KeyValueTimestamp<>(new Windowed<>("C", new TimeWindow(firstBatchRightWindowStart, firstBatchRightWindowEnd)), 1, secondBatchTimestamp), + new KeyValueTimestamp<>(new Windowed<>("C", new TimeWindow(secondBatchRightWindowStart, secondBatchRightWindowEnd)), 1, thirdBatchTimestamp), + new KeyValueTimestamp<>(new Windowed<>("C", new TimeWindow(secondBatchLeftWindowStart, secondBatchLeftWindowEnd)), 2, secondBatchTimestamp), + new KeyValueTimestamp<>(new Windowed<>("C", new TimeWindow(firstBatchRightWindowStart, firstBatchRightWindowEnd)), 2, thirdBatchTimestamp), + new KeyValueTimestamp<>(new Windowed<>("C", new TimeWindow(thirdBatchLeftWindowStart, thirdBatchLeftWindowEnd)), 3, thirdBatchTimestamp), + new KeyValueTimestamp<>(new Windowed<>("D", new TimeWindow(firstBatchLeftWindowStart, firstBatchLeftWindowEnd)), 1, firstBatchTimestamp), + new KeyValueTimestamp<>(new Windowed<>("D", new TimeWindow(firstBatchRightWindowStart, firstBatchRightWindowEnd)), 1, secondBatchTimestamp), + new KeyValueTimestamp<>(new Windowed<>("D", new TimeWindow(secondBatchRightWindowStart, secondBatchRightWindowEnd)), 1, thirdBatchTimestamp), + new KeyValueTimestamp<>(new Windowed<>("D", new TimeWindow(secondBatchLeftWindowStart, secondBatchLeftWindowEnd)), 2, secondBatchTimestamp), + new KeyValueTimestamp<>(new Windowed<>("D", new TimeWindow(firstBatchRightWindowStart, firstBatchRightWindowEnd)), 2, thirdBatchTimestamp), + new KeyValueTimestamp<>(new Windowed<>("D", new TimeWindow(thirdBatchLeftWindowStart, thirdBatchLeftWindowEnd)), 3, thirdBatchTimestamp), + new KeyValueTimestamp<>(new Windowed<>("E", new TimeWindow(firstBatchLeftWindowStart, firstBatchLeftWindowEnd)), 1, firstBatchTimestamp), + new KeyValueTimestamp<>(new Windowed<>("E", new TimeWindow(firstBatchRightWindowStart, firstBatchRightWindowEnd)), 1, secondBatchTimestamp), + new KeyValueTimestamp<>(new Windowed<>("E", new TimeWindow(secondBatchRightWindowStart, secondBatchRightWindowEnd)), 1, thirdBatchTimestamp), + new KeyValueTimestamp<>(new Windowed<>("E", new TimeWindow(secondBatchLeftWindowStart, secondBatchLeftWindowEnd)), 2, secondBatchTimestamp), + new KeyValueTimestamp<>(new Windowed<>("E", new TimeWindow(firstBatchRightWindowStart, firstBatchRightWindowEnd)), 2, thirdBatchTimestamp), + new KeyValueTimestamp<>(new Windowed<>("E", new TimeWindow(thirdBatchLeftWindowStart, thirdBatchLeftWindowEnd)), 3, thirdBatchTimestamp) + ); + + assertThat(windowedMessages, is(expectResult)); + + final Set expectResultString = new HashSet<>(expectResult.size()); + for (final KeyValueTimestamp, Integer> eachRecord: expectResult) { + expectResultString.add("CreateTime:" + eachRecord.timestamp() + ", " + eachRecord.key() + ", " + eachRecord.value()); + } + + // check every message is contained in the expect result + final String[] allRecords = resultFromConsoleConsumer.split("\n"); + for (final String record: allRecords) { + assertTrue(expectResultString.contains(record)); + } + + } + + 
@SuppressWarnings("deprecation") + @Test + public void shouldCountSessionWindows() throws Exception { + final long sessionGap = 5 * 60 * 1000L; + final List> t1Messages = Arrays.asList(new KeyValue<>("bob", "start"), + new KeyValue<>("penny", "start"), + new KeyValue<>("jo", "pause"), + new KeyValue<>("emily", "pause")); + + final long t1 = mockTime.milliseconds() - TimeUnit.MILLISECONDS.convert(1, TimeUnit.HOURS); + IntegrationTestUtils.produceKeyValuesSynchronouslyWithTimestamp( + userSessionsStream, + t1Messages, + TestUtils.producerConfig( + CLUSTER.bootstrapServers(), + StringSerializer.class, + StringSerializer.class, + new Properties()), + t1); + final long t2 = t1 + (sessionGap / 2); + IntegrationTestUtils.produceKeyValuesSynchronouslyWithTimestamp( + userSessionsStream, + Collections.singletonList( + new KeyValue<>("emily", "resume") + ), + TestUtils.producerConfig( + CLUSTER.bootstrapServers(), + StringSerializer.class, + StringSerializer.class, + new Properties()), + t2); + final long t3 = t1 + sessionGap + 1; + IntegrationTestUtils.produceKeyValuesSynchronouslyWithTimestamp( + userSessionsStream, + Arrays.asList( + new KeyValue<>("bob", "pause"), + new KeyValue<>("penny", "stop") + ), + TestUtils.producerConfig( + CLUSTER.bootstrapServers(), + StringSerializer.class, + StringSerializer.class, + new Properties()), + t3); + final long t4 = t3 + (sessionGap / 2); + IntegrationTestUtils.produceKeyValuesSynchronouslyWithTimestamp( + userSessionsStream, + Arrays.asList( + new KeyValue<>("bob", "resume"), // bobs session continues + new KeyValue<>("jo", "resume") // jo's starts new session + ), + TestUtils.producerConfig( + CLUSTER.bootstrapServers(), + StringSerializer.class, + StringSerializer.class, + new Properties()), + t4); + final long t5 = t4 - 1; + IntegrationTestUtils.produceKeyValuesSynchronouslyWithTimestamp( + userSessionsStream, + Collections.singletonList( + new KeyValue<>("jo", "late") // jo has late arrival + ), + TestUtils.producerConfig( + CLUSTER.bootstrapServers(), + StringSerializer.class, + StringSerializer.class, + new Properties()), + t5); + + final Map, KeyValue> results = new HashMap<>(); + final CountDownLatch latch = new CountDownLatch(13); + + //noinspection deprecation + builder.stream(userSessionsStream, Consumed.with(Serdes.String(), Serdes.String())) + .groupByKey(Grouped.with(Serdes.String(), Serdes.String())) + .windowedBy(SessionWindows.with(ofMillis(sessionGap))) + .count() + .toStream() + .transform(() -> new Transformer, Long, KeyValue>() { + private ProcessorContext context; + + @Override + public void init(final ProcessorContext context) { + this.context = context; + } + + @Override + public KeyValue transform(final Windowed key, final Long value) { + results.put(key, KeyValue.pair(value, context.timestamp())); + latch.countDown(); + return null; + } + + @Override + public void close() {} + }); + + startStreams(); + latch.await(30, TimeUnit.SECONDS); + + assertThat(results.get(new Windowed<>("bob", new SessionWindow(t1, t1))), equalTo(KeyValue.pair(1L, t1))); + assertThat(results.get(new Windowed<>("penny", new SessionWindow(t1, t1))), equalTo(KeyValue.pair(1L, t1))); + assertThat(results.get(new Windowed<>("jo", new SessionWindow(t1, t1))), equalTo(KeyValue.pair(1L, t1))); + assertThat(results.get(new Windowed<>("jo", new SessionWindow(t5, t4))), equalTo(KeyValue.pair(2L, t4))); + assertThat(results.get(new Windowed<>("emily", new SessionWindow(t1, t2))), equalTo(KeyValue.pair(2L, t2))); + assertThat(results.get(new Windowed<>("bob", new 
SessionWindow(t3, t4))), equalTo(KeyValue.pair(2L, t4))); + assertThat(results.get(new Windowed<>("penny", new SessionWindow(t3, t3))), equalTo(KeyValue.pair(1L, t3))); + } + + @SuppressWarnings("deprecation") + @Test + public void shouldReduceSessionWindows() throws Exception { + final long sessionGap = 1000L; // something to do with time + final List> t1Messages = Arrays.asList(new KeyValue<>("bob", "start"), + new KeyValue<>("penny", "start"), + new KeyValue<>("jo", "pause"), + new KeyValue<>("emily", "pause")); + + final long t1 = mockTime.milliseconds(); + IntegrationTestUtils.produceKeyValuesSynchronouslyWithTimestamp( + userSessionsStream, + t1Messages, + TestUtils.producerConfig( + CLUSTER.bootstrapServers(), + StringSerializer.class, + StringSerializer.class, + new Properties()), + t1); + final long t2 = t1 + (sessionGap / 2); + IntegrationTestUtils.produceKeyValuesSynchronouslyWithTimestamp( + userSessionsStream, + Collections.singletonList( + new KeyValue<>("emily", "resume") + ), + TestUtils.producerConfig( + CLUSTER.bootstrapServers(), + StringSerializer.class, + StringSerializer.class, + new Properties()), + t2); + final long t3 = t1 + sessionGap + 1; + IntegrationTestUtils.produceKeyValuesSynchronouslyWithTimestamp( + userSessionsStream, + Arrays.asList( + new KeyValue<>("bob", "pause"), + new KeyValue<>("penny", "stop") + ), + TestUtils.producerConfig( + CLUSTER.bootstrapServers(), + StringSerializer.class, + StringSerializer.class, + new Properties()), + t3); + final long t4 = t3 + (sessionGap / 2); + IntegrationTestUtils.produceKeyValuesSynchronouslyWithTimestamp( + userSessionsStream, + Arrays.asList( + new KeyValue<>("bob", "resume"), // bobs session continues + new KeyValue<>("jo", "resume") // jo's starts new session + ), + TestUtils.producerConfig( + CLUSTER.bootstrapServers(), + StringSerializer.class, + StringSerializer.class, + new Properties()), + t4); + final long t5 = t4 - 1; + IntegrationTestUtils.produceKeyValuesSynchronouslyWithTimestamp( + userSessionsStream, + Collections.singletonList( + new KeyValue<>("jo", "late") // jo has late arrival + ), + TestUtils.producerConfig( + CLUSTER.bootstrapServers(), + StringSerializer.class, + StringSerializer.class, + new Properties()), + t5); + + final Map, KeyValue> results = new HashMap<>(); + final CountDownLatch latch = new CountDownLatch(13); + final String userSessionsStore = "UserSessionsStore"; + //noinspection deprecation + builder.stream(userSessionsStream, Consumed.with(Serdes.String(), Serdes.String())) + .groupByKey(Grouped.with(Serdes.String(), Serdes.String())) + .windowedBy(SessionWindows.with(ofMillis(sessionGap))) + .reduce((value1, value2) -> value1 + ":" + value2, Materialized.as(userSessionsStore)) + .toStream() + .transform(() -> new Transformer, String, KeyValue>() { + private ProcessorContext context; + + @Override + public void init(final ProcessorContext context) { + this.context = context; + } + + @Override + public KeyValue transform(final Windowed key, final String value) { + results.put(key, KeyValue.pair(value, context.timestamp())); + latch.countDown(); + return null; + } + + @Override + public void close() {} + }); + + startStreams(); + latch.await(30, TimeUnit.SECONDS); + + // verify correct data received + assertThat(results.get(new Windowed<>("bob", new SessionWindow(t1, t1))), equalTo(KeyValue.pair("start", t1))); + assertThat(results.get(new Windowed<>("penny", new SessionWindow(t1, t1))), equalTo(KeyValue.pair("start", t1))); + assertThat(results.get(new Windowed<>("jo", new 
SessionWindow(t1, t1))), equalTo(KeyValue.pair("pause", t1))); + assertThat(results.get(new Windowed<>("jo", new SessionWindow(t5, t4))), equalTo(KeyValue.pair("resume:late", t4))); + assertThat(results.get(new Windowed<>("emily", new SessionWindow(t1, t2))), equalTo(KeyValue.pair("pause:resume", t2))); + assertThat(results.get(new Windowed<>("bob", new SessionWindow(t3, t4))), equalTo(KeyValue.pair("pause:resume", t4))); + assertThat(results.get(new Windowed<>("penny", new SessionWindow(t3, t3))), equalTo(KeyValue.pair("stop", t3))); + + // verify can query data via IQ + final ReadOnlySessionStore sessionStore = + IntegrationTestUtils.getStore(userSessionsStore, kafkaStreams, QueryableStoreTypes.sessionStore()); + + try (final KeyValueIterator, String> bob = sessionStore.fetch("bob")) { + assertThat(bob.next(), equalTo(KeyValue.pair(new Windowed<>("bob", new SessionWindow(t1, t1)), "start"))); + assertThat(bob.next(), equalTo(KeyValue.pair(new Windowed<>("bob", new SessionWindow(t3, t4)), "pause:resume"))); + assertFalse(bob.hasNext()); + } + } + + @Test + public void shouldCountUnlimitedWindows() throws Exception { + final long startTime = mockTime.milliseconds() - TimeUnit.MILLISECONDS.convert(1, TimeUnit.HOURS) + 1; + final long incrementTime = Duration.ofDays(1).toMillis(); + + final long t1 = mockTime.milliseconds() - TimeUnit.MILLISECONDS.convert(1, TimeUnit.HOURS); + final List> t1Messages = Arrays.asList(new KeyValue<>("bob", "start"), + new KeyValue<>("penny", "start"), + new KeyValue<>("jo", "pause"), + new KeyValue<>("emily", "pause")); + + final Properties producerConfig = TestUtils.producerConfig( + CLUSTER.bootstrapServers(), + StringSerializer.class, + StringSerializer.class, + new Properties() + ); + + IntegrationTestUtils.produceKeyValuesSynchronouslyWithTimestamp( + userSessionsStream, + t1Messages, + producerConfig, + t1); + + final long t2 = t1 + incrementTime; + IntegrationTestUtils.produceKeyValuesSynchronouslyWithTimestamp( + userSessionsStream, + Collections.singletonList( + new KeyValue<>("emily", "resume") + ), + producerConfig, + t2); + final long t3 = t2 + incrementTime; + IntegrationTestUtils.produceKeyValuesSynchronouslyWithTimestamp( + userSessionsStream, + Arrays.asList( + new KeyValue<>("bob", "pause"), + new KeyValue<>("penny", "stop") + ), + producerConfig, + t3); + + final long t4 = t3 + incrementTime; + IntegrationTestUtils.produceKeyValuesSynchronouslyWithTimestamp( + userSessionsStream, + Arrays.asList( + new KeyValue<>("bob", "resume"), // bobs session continues + new KeyValue<>("jo", "resume") // jo's starts new session + ), + producerConfig, + t4); + + final Map, KeyValue> results = new HashMap<>(); + final CountDownLatch latch = new CountDownLatch(5); + + builder.stream(userSessionsStream, Consumed.with(Serdes.String(), Serdes.String())) + .groupByKey(Grouped.with(Serdes.String(), Serdes.String())) + .windowedBy(UnlimitedWindows.of().startOn(ofEpochMilli(startTime))) + .count() + .toStream() + .transform(() -> new Transformer, Long, KeyValue>() { + private ProcessorContext context; + + @Override + public void init(final ProcessorContext context) { + this.context = context; + } + + @Override + public KeyValue transform(final Windowed key, final Long value) { + results.put(key, KeyValue.pair(value, context.timestamp())); + latch.countDown(); + return null; + } + + @Override + public void close() {} + }); + startStreams(); + assertTrue(latch.await(30, TimeUnit.SECONDS)); + + assertThat(results.get(new Windowed<>("bob", new 
UnlimitedWindow(startTime))), equalTo(KeyValue.pair(2L, t4))); + assertThat(results.get(new Windowed<>("penny", new UnlimitedWindow(startTime))), equalTo(KeyValue.pair(1L, t3))); + assertThat(results.get(new Windowed<>("jo", new UnlimitedWindow(startTime))), equalTo(KeyValue.pair(1L, t4))); + assertThat(results.get(new Windowed<>("emily", new UnlimitedWindow(startTime))), equalTo(KeyValue.pair(1L, t2))); + } + + + private void produceMessages(final long timestamp) throws Exception { + IntegrationTestUtils.produceKeyValuesSynchronouslyWithTimestamp( + streamOneInput, + Arrays.asList( + new KeyValue<>(1, "A"), + new KeyValue<>(2, "B"), + new KeyValue<>(3, "C"), + new KeyValue<>(4, "D"), + new KeyValue<>(5, "E")), + TestUtils.producerConfig( + CLUSTER.bootstrapServers(), + IntegerSerializer.class, + StringSerializer.class, + new Properties()), + timestamp); + } + + + private void createTopics() throws InterruptedException { + final String safeTestName = safeUniqueTestName(getClass(), testName); + streamOneInput = "stream-one-" + safeTestName; + outputTopic = "output-" + safeTestName; + userSessionsStream = "user-sessions-" + safeTestName; + CLUSTER.createTopic(streamOneInput, 3, 1); + CLUSTER.createTopics(userSessionsStream, outputTopic); + } + + private void startStreams() { + kafkaStreams = new KafkaStreams(builder.build(), streamsConfiguration); + kafkaStreams.start(); + } + + private List> receiveMessages(final Deserializer keyDeserializer, + final Deserializer valueDeserializer, + final int numMessages) + throws Exception { + + return receiveMessages(keyDeserializer, valueDeserializer, null, numMessages); + } + + private List> receiveMessages(final Deserializer keyDeserializer, + final Deserializer valueDeserializer, + final Class innerClass, + final int numMessages) + throws Exception { + + final String safeTestName = safeUniqueTestName(getClass(), testName); + final Properties consumerProperties = new Properties(); + consumerProperties.setProperty(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, CLUSTER.bootstrapServers()); + consumerProperties.setProperty(ConsumerConfig.GROUP_ID_CONFIG, "group-" + safeTestName); + consumerProperties.setProperty(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "earliest"); + consumerProperties.setProperty(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, keyDeserializer.getClass().getName()); + consumerProperties.setProperty(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, valueDeserializer.getClass().getName()); + consumerProperties.put(StreamsConfig.WINDOW_SIZE_MS_CONFIG, 500L); + if (keyDeserializer instanceof TimeWindowedDeserializer || keyDeserializer instanceof SessionWindowedDeserializer) { + consumerProperties.setProperty(StreamsConfig.WINDOWED_INNER_CLASS_SERDE, + Serdes.serdeFrom(innerClass).getClass().getName()); + } + return IntegrationTestUtils.waitUntilMinKeyValueWithTimestampRecordsReceived( + consumerProperties, + outputTopic, + numMessages, + 60 * 1000); + } + + private List> receiveMessagesWithTimestamp(final Deserializer keyDeserializer, + final Deserializer valueDeserializer, + final Class innerClass, + final int numMessages) throws Exception { + final String safeTestName = safeUniqueTestName(getClass(), testName); + final Properties consumerProperties = new Properties(); + consumerProperties.setProperty(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, CLUSTER.bootstrapServers()); + consumerProperties.setProperty(ConsumerConfig.GROUP_ID_CONFIG, "group-" + safeTestName); + consumerProperties.setProperty(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "earliest"); + 
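// a windowed key deserializer cannot infer the window size or the inner key type from the topic, so + // WINDOW_SIZE_MS_CONFIG and WINDOWED_INNER_CLASS_SERDE are passed to the consumer below as well +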
consumerProperties.setProperty(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, keyDeserializer.getClass().getName()); + consumerProperties.setProperty(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, valueDeserializer.getClass().getName()); + consumerProperties.put(StreamsConfig.WINDOW_SIZE_MS_CONFIG, 500L); + if (keyDeserializer instanceof TimeWindowedDeserializer || keyDeserializer instanceof SessionWindowedDeserializer) { + consumerProperties.setProperty(StreamsConfig.WINDOWED_INNER_CLASS_SERDE, + Serdes.serdeFrom(innerClass).getClass().getName()); + } + return IntegrationTestUtils.waitUntilMinKeyValueWithTimestampRecordsReceived( + consumerProperties, + outputTopic, + numMessages, + 60 * 1000); + } + + private String readWindowedKeyedMessagesViaConsoleConsumer(final Deserializer keyDeserializer, + final Deserializer valueDeserializer, + final Class innerClass, + final int numMessages, + final boolean printTimestamp) { + final ByteArrayOutputStream newConsole = new ByteArrayOutputStream(); + final PrintStream originalStream = System.out; + try (final PrintStream newStream = new PrintStream(newConsole)) { + System.setOut(newStream); + + final String keySeparator = ", "; + // manually construct the console consumer argument array + final String[] args = new String[] { + "--bootstrap-server", CLUSTER.bootstrapServers(), + "--from-beginning", + "--property", "print.key=true", + "--property", "print.timestamp=" + printTimestamp, + "--topic", outputTopic, + "--max-messages", String.valueOf(numMessages), + "--property", "key.deserializer=" + keyDeserializer.getClass().getName(), + "--property", "value.deserializer=" + valueDeserializer.getClass().getName(), + "--property", "key.separator=" + keySeparator, + "--property", "key.deserializer." + StreamsConfig.WINDOWED_INNER_CLASS_SERDE + "=" + Serdes.serdeFrom(innerClass).getClass().getName(), + "--property", "key.deserializer.window.size.ms=500", + }; + + ConsoleConsumer.messageCount_$eq(0); //reset the message count + ConsoleConsumer.run(new ConsoleConsumer.ConsumerConfig(args)); + newStream.flush(); + System.setOut(originalStream); + return newConsole.toString(); + } + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/integration/KStreamRepartitionIntegrationTest.java b/streams/src/test/java/org/apache/kafka/streams/integration/KStreamRepartitionIntegrationTest.java new file mode 100644 index 0000000..1e7f685 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/integration/KStreamRepartitionIntegrationTest.java @@ -0,0 +1,818 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.integration; + +import org.apache.kafka.clients.admin.AdminClient; +import org.apache.kafka.clients.admin.AdminClientConfig; +import org.apache.kafka.clients.admin.TopicDescription; +import org.apache.kafka.clients.consumer.ConsumerConfig; +import org.apache.kafka.common.serialization.Deserializer; +import org.apache.kafka.common.serialization.IntegerDeserializer; +import org.apache.kafka.common.serialization.IntegerSerializer; +import org.apache.kafka.common.serialization.LongDeserializer; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.serialization.StringDeserializer; +import org.apache.kafka.common.serialization.StringSerializer; +import org.apache.kafka.streams.KafkaStreams; +import org.apache.kafka.streams.KafkaStreams.State; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.integration.utils.EmbeddedKafkaCluster; +import org.apache.kafka.streams.integration.utils.IntegrationTestUtils; +import org.apache.kafka.streams.kstream.Consumed; +import org.apache.kafka.streams.kstream.JoinWindows; +import org.apache.kafka.streams.kstream.KStream; +import org.apache.kafka.streams.kstream.Named; +import org.apache.kafka.streams.kstream.Repartitioned; +import org.apache.kafka.test.IntegrationTest; +import org.apache.kafka.test.TestUtils; +import org.junit.After; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Rule; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.rules.TestName; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import org.junit.runners.Parameterized.Parameter; +import org.junit.runners.Parameterized.Parameters; + +import java.io.IOException; +import java.time.Duration; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.List; +import java.util.Objects; +import java.util.Properties; +import java.util.Set; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +import static org.apache.kafka.streams.KafkaStreams.State.ERROR; +import static org.apache.kafka.streams.KafkaStreams.State.REBALANCING; +import static org.apache.kafka.streams.KafkaStreams.State.RUNNING; +import static org.apache.kafka.streams.integration.utils.IntegrationTestUtils.safeUniqueTestName; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; + +@RunWith(value = Parameterized.class) +@Category({IntegrationTest.class}) +@SuppressWarnings("deprecation") +public class KStreamRepartitionIntegrationTest { + private static final int NUM_BROKERS = 1; + + public static final EmbeddedKafkaCluster CLUSTER = new EmbeddedKafkaCluster(NUM_BROKERS); + + @BeforeClass + public static void startCluster() throws IOException { + CLUSTER.start(); + } + + @AfterClass + public static void closeCluster() { + CLUSTER.stop(); + } + + + private String topicB; + private String inputTopic; + private String outputTopic; + private String applicationId; + + private Properties streamsConfiguration; + private List kafkaStreamsInstances; + + @Parameter + public 
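/* injected by the Parameterized runner with each value returned by topologyOptimization() below */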
String topologyOptimization; + + @Parameters(name = "Optimization = {0}") + public static Collection topologyOptimization() { + return Arrays.asList(new String[][]{ + {StreamsConfig.OPTIMIZE}, + {StreamsConfig.NO_OPTIMIZATION} + }); + } + + @Rule + public TestName testName = new TestName(); + + @Before + public void before() throws InterruptedException { + streamsConfiguration = new Properties(); + kafkaStreamsInstances = new ArrayList<>(); + + final String safeTestName = safeUniqueTestName(getClass(), testName); + + topicB = "topic-b-" + safeTestName; + inputTopic = "input-topic-" + safeTestName; + outputTopic = "output-topic-" + safeTestName; + applicationId = "app-" + safeTestName; + + CLUSTER.createTopic(inputTopic, 4, 1); + CLUSTER.createTopic(outputTopic, 1, 1); + + streamsConfiguration.put(StreamsConfig.APPLICATION_ID_CONFIG, applicationId); + streamsConfiguration.put(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, CLUSTER.bootstrapServers()); + streamsConfiguration.put(StreamsConfig.STATE_DIR_CONFIG, TestUtils.tempDirectory().getPath()); + streamsConfiguration.put(StreamsConfig.CACHE_MAX_BYTES_BUFFERING_CONFIG, 0); + streamsConfiguration.put(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG, 100L); + streamsConfiguration.put(StreamsConfig.DEFAULT_KEY_SERDE_CLASS_CONFIG, Serdes.Integer().getClass()); + streamsConfiguration.put(StreamsConfig.DEFAULT_VALUE_SERDE_CLASS_CONFIG, Serdes.String().getClass()); + streamsConfiguration.put(StreamsConfig.TOPOLOGY_OPTIMIZATION_CONFIG, topologyOptimization); + } + + @After + public void whenShuttingDown() throws IOException { + kafkaStreamsInstances.stream() + .filter(Objects::nonNull) + .forEach(KafkaStreams::close); + + IntegrationTestUtils.purgeLocalStreamsState(streamsConfiguration); + } + + @Test + public void shouldThrowAnExceptionWhenNumberOfPartitionsOfRepartitionOperationDoNotMatchSourceTopicWhenJoining() throws InterruptedException { + final int topicBNumberOfPartitions = 6; + final String inputTopicRepartitionName = "join-repartition-test"; + final AtomicReference expectedThrowable = new AtomicReference<>(); + final int inputTopicRepartitionedNumOfPartitions = 2; + + CLUSTER.createTopic(topicB, topicBNumberOfPartitions, 1); + + final StreamsBuilder builder = new StreamsBuilder(); + + final Repartitioned inputTopicRepartitioned = Repartitioned + .as(inputTopicRepartitionName) + .withNumberOfPartitions(inputTopicRepartitionedNumOfPartitions); + + final KStream topicBStream = builder + .stream(topicB, Consumed.with(Serdes.Integer(), Serdes.String())); + + builder.stream(inputTopic, Consumed.with(Serdes.Integer(), Serdes.String())) + .repartition(inputTopicRepartitioned) + .join(topicBStream, (value1, value2) -> value2, JoinWindows.of(Duration.ofSeconds(10))) + .to(outputTopic); + + builder.build(streamsConfiguration); + + startStreams(builder, REBALANCING, ERROR, (t, e) -> expectedThrowable.set(e)); + + final String expectedMsg = String.format("Number of partitions [%s] of repartition topic [%s] " + + "doesn't match number of partitions [%s] of the source topic.", + inputTopicRepartitionedNumOfPartitions, + toRepartitionTopicName(inputTopicRepartitionName), + topicBNumberOfPartitions); + assertNotNull(expectedThrowable.get()); + assertTrue(expectedThrowable.get().getMessage().contains(expectedMsg)); + } + + @Test + public void shouldDeductNumberOfPartitionsFromRepartitionOperation() throws Exception { + final String topicBMapperName = "topic-b-mapper"; + final int topicBNumberOfPartitions = 6; + final String inputTopicRepartitionName = 
"join-repartition-test"; + final int inputTopicRepartitionedNumOfPartitions = 3; + + final long timestamp = System.currentTimeMillis(); + + CLUSTER.createTopic(topicB, topicBNumberOfPartitions, 1); + + final List> expectedRecords = Arrays.asList( + new KeyValue<>(1, "A"), + new KeyValue<>(2, "B") + ); + + sendEvents(timestamp, expectedRecords); + sendEvents(topicB, timestamp, expectedRecords); + + final StreamsBuilder builder = new StreamsBuilder(); + + final Repartitioned inputTopicRepartitioned = Repartitioned + .as(inputTopicRepartitionName) + .withNumberOfPartitions(inputTopicRepartitionedNumOfPartitions); + + final KStream topicBStream = builder + .stream(topicB, Consumed.with(Serdes.Integer(), Serdes.String())) + .map(KeyValue::new, Named.as(topicBMapperName)); + + builder.stream(inputTopic, Consumed.with(Serdes.Integer(), Serdes.String())) + .repartition(inputTopicRepartitioned) + .join(topicBStream, (value1, value2) -> value2, JoinWindows.of(Duration.ofSeconds(10))) + .to(outputTopic); + + builder.build(streamsConfiguration); + + startStreams(builder); + + assertEquals(inputTopicRepartitionedNumOfPartitions, + getNumberOfPartitionsForTopic(toRepartitionTopicName(inputTopicRepartitionName))); + + assertEquals(inputTopicRepartitionedNumOfPartitions, + getNumberOfPartitionsForTopic(toRepartitionTopicName(topicBMapperName))); + + validateReceivedMessages( + new IntegerDeserializer(), + new StringDeserializer(), + expectedRecords + ); + } + + @Test + public void shouldDoProperJoiningWhenNumberOfPartitionsAreValidWhenUsingRepartitionOperation() throws Exception { + final String topicBRepartitionedName = "topic-b-scale-up"; + final String inputTopicRepartitionedName = "input-topic-scale-up"; + + final long timestamp = System.currentTimeMillis(); + + CLUSTER.createTopic(topicB, 1, 1); + + final List> expectedRecords = Arrays.asList( + new KeyValue<>(1, "A"), + new KeyValue<>(2, "B") + ); + + sendEvents(timestamp, expectedRecords); + sendEvents(topicB, timestamp, expectedRecords); + + final StreamsBuilder builder = new StreamsBuilder(); + + final Repartitioned inputTopicRepartitioned = Repartitioned + .as(inputTopicRepartitionedName) + .withNumberOfPartitions(4); + + final Repartitioned topicBRepartitioned = Repartitioned + .as(topicBRepartitionedName) + .withNumberOfPartitions(4); + + final KStream topicBStream = builder + .stream(topicB, Consumed.with(Serdes.Integer(), Serdes.String())) + .repartition(topicBRepartitioned); + + builder.stream(inputTopic, Consumed.with(Serdes.Integer(), Serdes.String())) + .repartition(inputTopicRepartitioned) + .join(topicBStream, (value1, value2) -> value2, JoinWindows.of(Duration.ofSeconds(10))) + .to(outputTopic); + + startStreams(builder); + + assertEquals(4, getNumberOfPartitionsForTopic(toRepartitionTopicName(topicBRepartitionedName))); + assertEquals(4, getNumberOfPartitionsForTopic(toRepartitionTopicName(inputTopicRepartitionedName))); + + validateReceivedMessages( + new IntegerDeserializer(), + new StringDeserializer(), + expectedRecords + ); + } + + @Test + public void shouldUseStreamPartitionerForRepartitionOperation() throws Exception { + final int partition = 1; + final String repartitionName = "partitioner-test"; + final long timestamp = System.currentTimeMillis(); + final AtomicInteger partitionerInvocation = new AtomicInteger(0); + + final List> expectedRecords = Arrays.asList( + new KeyValue<>(1, "A"), + new KeyValue<>(2, "B") + ); + + sendEvents(timestamp, expectedRecords); + + final StreamsBuilder builder = new StreamsBuilder(); + + 
final Repartitioned repartitioned = Repartitioned + .as(repartitionName) + .withStreamPartitioner((topic, key, value, numPartitions) -> { + partitionerInvocation.incrementAndGet(); + return partition; + }); + + builder.stream(inputTopic, Consumed.with(Serdes.Integer(), Serdes.String())) + .repartition(repartitioned) + .to(outputTopic); + + startStreams(builder); + + final String topic = toRepartitionTopicName(repartitionName); + + validateReceivedMessages( + new IntegerDeserializer(), + new StringDeserializer(), + expectedRecords + ); + + assertTrue(topicExists(topic)); + assertEquals(expectedRecords.size(), partitionerInvocation.get()); + } + + @Test + public void shouldPerformSelectKeyWithRepartitionOperation() throws Exception { + final long timestamp = System.currentTimeMillis(); + + sendEvents( + timestamp, + Arrays.asList( + new KeyValue<>(1, "10"), + new KeyValue<>(2, "20") + ) + ); + + final StreamsBuilder builder = new StreamsBuilder(); + + builder.stream(inputTopic, Consumed.with(Serdes.Integer(), Serdes.String())) + .selectKey((key, value) -> Integer.valueOf(value)) + .repartition() + .to(outputTopic); + + startStreams(builder); + + validateReceivedMessages( + new IntegerDeserializer(), + new StringDeserializer(), + Arrays.asList( + new KeyValue<>(10, "10"), + new KeyValue<>(20, "20") + ) + ); + + final String topology = builder.build().describe().toString(); + + assertEquals(1, countOccurrencesInTopology(topology, "Sink: .*-repartition.*")); + } + + @Test + public void shouldCreateRepartitionTopicIfKeyChangingOperationWasNotPerformed() throws Exception { + final String repartitionName = "dummy"; + final long timestamp = System.currentTimeMillis(); + + sendEvents( + timestamp, + Arrays.asList( + new KeyValue<>(1, "A"), + new KeyValue<>(2, "B") + ) + ); + + final StreamsBuilder builder = new StreamsBuilder(); + + builder.stream(inputTopic, Consumed.with(Serdes.Integer(), Serdes.String())) + .repartition(Repartitioned.as(repartitionName)) + .to(outputTopic); + + startStreams(builder); + + validateReceivedMessages( + new IntegerDeserializer(), + new StringDeserializer(), + Arrays.asList( + new KeyValue<>(1, "A"), + new KeyValue<>(2, "B") + ) + ); + + final String topology = builder.build().describe().toString(); + + assertTrue(topicExists(toRepartitionTopicName(repartitionName))); + assertEquals(1, countOccurrencesInTopology(topology, "Sink: .*dummy-repartition.*")); + } + + @Test + public void shouldPerformKeySelectOperationWhenRepartitionOperationIsUsedWithKeySelector() throws Exception { + final String repartitionedName = "new-key"; + final long timestamp = System.currentTimeMillis(); + + sendEvents( + timestamp, + Arrays.asList( + new KeyValue<>(1, "A"), + new KeyValue<>(2, "B") + ) + ); + + final StreamsBuilder builder = new StreamsBuilder(); + + final Repartitioned repartitioned = Repartitioned.as(repartitionedName) + .withKeySerde(Serdes.String()); + + builder.stream(inputTopic, Consumed.with(Serdes.Integer(), Serdes.String())) + .selectKey((key, value) -> key.toString(), Named.as(repartitionedName)) + .repartition(repartitioned) + .groupByKey() + .count() + .toStream() + .to(outputTopic); + + startStreams(builder); + + validateReceivedMessages( + new StringDeserializer(), + new LongDeserializer(), + Arrays.asList( + new KeyValue<>("1", 1L), + new KeyValue<>("2", 1L) + ) + ); + + final String topology = builder.build().describe().toString(); + final String repartitionTopicName = toRepartitionTopicName(repartitionedName); + + assertTrue(topicExists(repartitionTopicName)); + 
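+        // The topology description should contain exactly one repartition sink and a single downstream reference to the named selectKey node.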
assertEquals(1, countOccurrencesInTopology(topology, "Sink: .*" + repartitionedName + "-repartition.*")); + assertEquals(1, countOccurrencesInTopology(topology, "<-- " + repartitionedName + "\n")); + } + + @Test + public void shouldCreateRepartitionTopicWithSpecifiedNumberOfPartitions() throws Exception { + final String repartitionName = "new-partitions"; + final long timestamp = System.currentTimeMillis(); + + sendEvents( + timestamp, + Arrays.asList( + new KeyValue<>(1, "A"), + new KeyValue<>(2, "B") + ) + ); + + final StreamsBuilder builder = new StreamsBuilder(); + + builder.stream(inputTopic, Consumed.with(Serdes.Integer(), Serdes.String())) + .repartition(Repartitioned.as(repartitionName).withNumberOfPartitions(2)) + .groupByKey() + .count() + .toStream() + .to(outputTopic); + + startStreams(builder); + + validateReceivedMessages( + new IntegerDeserializer(), + new LongDeserializer(), + Arrays.asList( + new KeyValue<>(1, 1L), + new KeyValue<>(2, 1L) + ) + ); + + final String repartitionTopicName = toRepartitionTopicName(repartitionName); + + assertTrue(topicExists(repartitionTopicName)); + assertEquals(2, getNumberOfPartitionsForTopic(repartitionTopicName)); + } + + @Test + public void shouldInheritRepartitionTopicPartitionNumberFromUpstreamTopicWhenNumberOfPartitionsIsNotSpecified() throws Exception { + final String repartitionName = "new-topic"; + final long timestamp = System.currentTimeMillis(); + + sendEvents( + timestamp, + Arrays.asList( + new KeyValue<>(1, "A"), + new KeyValue<>(2, "B") + ) + ); + + final StreamsBuilder builder = new StreamsBuilder(); + + builder.stream(inputTopic, Consumed.with(Serdes.Integer(), Serdes.String())) + .repartition(Repartitioned.as(repartitionName)) + .groupByKey() + .count() + .toStream() + .to(outputTopic); + + startStreams(builder); + + validateReceivedMessages( + new IntegerDeserializer(), + new LongDeserializer(), + Arrays.asList( + new KeyValue<>(1, 1L), + new KeyValue<>(2, 1L) + ) + ); + + final String repartitionTopicName = toRepartitionTopicName(repartitionName); + + assertTrue(topicExists(repartitionTopicName)); + assertEquals(4, getNumberOfPartitionsForTopic(repartitionTopicName)); + } + + @Test + public void shouldCreateOnlyOneRepartitionTopicWhenRepartitionIsFollowedByGroupByKey() throws Exception { + final String repartitionName = "new-partitions"; + final long timestamp = System.currentTimeMillis(); + + sendEvents( + timestamp, + Arrays.asList( + new KeyValue<>(1, "A"), + new KeyValue<>(2, "B") + ) + ); + + final StreamsBuilder builder = new StreamsBuilder(); + + final Repartitioned repartitioned = Repartitioned.as(repartitionName) + .withKeySerde(Serdes.String()) + .withValueSerde(Serdes.String()) + .withNumberOfPartitions(1); + + builder.stream(inputTopic, Consumed.with(Serdes.Integer(), Serdes.String())) + .selectKey((key, value) -> key.toString()) + .repartition(repartitioned) + .groupByKey() + .count() + .toStream() + .to(outputTopic); + + startStreams(builder); + + final String topology = builder.build().describe().toString(); + + validateReceivedMessages( + new StringDeserializer(), + new LongDeserializer(), + Arrays.asList( + new KeyValue<>("1", 1L), + new KeyValue<>("2", 1L) + ) + ); + + assertTrue(topicExists(toRepartitionTopicName(repartitionName))); + assertEquals(1, countOccurrencesInTopology(topology, "Sink: .*-repartition")); + } + + @Test + public void shouldGenerateRepartitionTopicWhenNameIsNotSpecified() throws Exception { + final long timestamp = System.currentTimeMillis(); + + sendEvents( + timestamp, + 
Arrays.asList( + new KeyValue<>(1, "A"), + new KeyValue<>(2, "B") + ) + ); + + final StreamsBuilder builder = new StreamsBuilder(); + + builder.stream(inputTopic, Consumed.with(Serdes.Integer(), Serdes.String())) + .selectKey((key, value) -> key.toString()) + .repartition(Repartitioned.with(Serdes.String(), Serdes.String())) + .to(outputTopic); + + startStreams(builder); + + validateReceivedMessages( + new StringDeserializer(), + new StringDeserializer(), + Arrays.asList( + new KeyValue<>("1", "A"), + new KeyValue<>("2", "B") + ) + ); + + final String topology = builder.build().describe().toString(); + + assertEquals(1, countOccurrencesInTopology(topology, "Sink: .*-repartition")); + } + + @Test + public void shouldGoThroughRebalancingCorrectly() throws Exception { + final String repartitionName = "rebalancing-test"; + final long timestamp = System.currentTimeMillis(); + + sendEvents( + timestamp, + Arrays.asList( + new KeyValue<>(1, "A"), + new KeyValue<>(2, "B") + ) + ); + + final StreamsBuilder builder = new StreamsBuilder(); + + final Repartitioned repartitioned = Repartitioned.as(repartitionName) + .withKeySerde(Serdes.String()) + .withValueSerde(Serdes.String()) + .withNumberOfPartitions(2); + + builder.stream(inputTopic, Consumed.with(Serdes.Integer(), Serdes.String())) + .selectKey((key, value) -> key.toString()) + .repartition(repartitioned) + .groupByKey() + .count() + .toStream() + .to(outputTopic); + + startStreams(builder); + final Properties streamsToCloseConfigs = new Properties(); + streamsToCloseConfigs.putAll(streamsConfiguration); + streamsToCloseConfigs.put(StreamsConfig.STATE_DIR_CONFIG, TestUtils.tempDirectory().getPath() + "-2"); + final KafkaStreams kafkaStreamsToClose = startStreams(builder, streamsToCloseConfigs); + + validateReceivedMessages( + new StringDeserializer(), + new LongDeserializer(), + Arrays.asList( + new KeyValue<>("1", 1L), + new KeyValue<>("2", 1L) + ) + ); + + kafkaStreamsToClose.close(); + + sendEvents( + timestamp, + Arrays.asList( + new KeyValue<>(1, "C"), + new KeyValue<>(2, "D") + ) + ); + + validateReceivedMessages( + new StringDeserializer(), + new LongDeserializer(), + Arrays.asList( + new KeyValue<>("1", 2L), + new KeyValue<>("2", 2L) + ) + ); + + final String repartitionTopicName = toRepartitionTopicName(repartitionName); + + assertTrue(topicExists(repartitionTopicName)); + assertEquals(2, getNumberOfPartitionsForTopic(repartitionTopicName)); + } + + private int getNumberOfPartitionsForTopic(final String topic) throws Exception { + try (final AdminClient adminClient = createAdminClient()) { + final TopicDescription topicDescription = adminClient.describeTopics(Collections.singleton(topic)) + .topicNameValues() + .get(topic) + .get(); + + return topicDescription.partitions().size(); + } + } + + private boolean topicExists(final String topic) throws Exception { + try (final AdminClient adminClient = createAdminClient()) { + final Set topics = adminClient.listTopics() + .names() + .get(); + + return topics.contains(topic); + } + } + + private String toRepartitionTopicName(final String input) { + return applicationId + "-" + input + "-repartition"; + } + + private static AdminClient createAdminClient() { + final Properties properties = new Properties(); + properties.put(AdminClientConfig.BOOTSTRAP_SERVERS_CONFIG, CLUSTER.bootstrapServers()); + + return AdminClient.create(properties); + } + + private static int countOccurrencesInTopology(final String topologyString, + final String searchPattern) { + final Matcher matcher = 
Pattern.compile(searchPattern).matcher(topologyString); + final List repartitionTopicsFound = new ArrayList<>(); + + while (matcher.find()) { + repartitionTopicsFound.add(matcher.group()); + } + + return repartitionTopicsFound.size(); + } + + private void sendEvents(final long timestamp, + final List> events) throws Exception { + sendEvents(inputTopic, timestamp, events); + } + + private void sendEvents(final String topic, + final long timestamp, + final List> events) throws Exception { + IntegrationTestUtils.produceKeyValuesSynchronouslyWithTimestamp( + topic, + events, + TestUtils.producerConfig( + CLUSTER.bootstrapServers(), + IntegerSerializer.class, + StringSerializer.class, + new Properties() + ), + timestamp + ); + } + + private KafkaStreams startStreams(final StreamsBuilder builder) throws InterruptedException { + return startStreams(builder, REBALANCING, RUNNING, streamsConfiguration, null); + } + + private KafkaStreams startStreams(final StreamsBuilder builder, final Properties streamsConfiguration) throws InterruptedException { + return startStreams(builder, REBALANCING, RUNNING, streamsConfiguration, null); + } + + private KafkaStreams startStreams(final StreamsBuilder builder, + final State expectedOldState, + final State expectedNewState, + final Thread.UncaughtExceptionHandler uncaughtExceptionHandler) throws InterruptedException { + return startStreams(builder, expectedOldState, expectedNewState, streamsConfiguration, uncaughtExceptionHandler); + } + + private KafkaStreams startStreams(final StreamsBuilder builder, + final State expectedOldState, + final State expectedNewState, + final Properties streamsConfiguration, + final Thread.UncaughtExceptionHandler uncaughtExceptionHandler) throws InterruptedException { + final CountDownLatch latch; + final KafkaStreams kafkaStreams = new KafkaStreams(builder.build(streamsConfiguration), streamsConfiguration); + + if (uncaughtExceptionHandler == null) { + latch = new CountDownLatch(1); + } else { + latch = new CountDownLatch(2); + kafkaStreams.setUncaughtExceptionHandler(e -> { + uncaughtExceptionHandler.uncaughtException(Thread.currentThread(), e); + latch.countDown(); + if (e instanceof RuntimeException) { + throw (RuntimeException) e; + } else if (e instanceof Error) { + throw (Error) e; + } else { + throw new RuntimeException("Unexpected checked exception caught in the uncaught exception handler", e); + } + }); + } + + kafkaStreams.setStateListener((newState, oldState) -> { + if (expectedOldState == oldState && expectedNewState == newState) { + latch.countDown(); + } + }); + + kafkaStreams.start(); + + latch.await(IntegrationTestUtils.DEFAULT_TIMEOUT, TimeUnit.MILLISECONDS); + kafkaStreamsInstances.add(kafkaStreams); + + return kafkaStreams; + } + + private void validateReceivedMessages(final Deserializer keySerializer, + final Deserializer valueSerializer, + final List> expectedRecords) throws Exception { + + final String safeTestName = safeUniqueTestName(getClass(), testName); + final Properties consumerProperties = new Properties(); + consumerProperties.setProperty(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, CLUSTER.bootstrapServers()); + consumerProperties.setProperty(ConsumerConfig.GROUP_ID_CONFIG, "group-" + safeTestName); + consumerProperties.setProperty(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "earliest"); + consumerProperties.setProperty( + ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, + keySerializer.getClass().getName() + ); + consumerProperties.setProperty( + ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, + 
valueSerializer.getClass().getName() + ); + + IntegrationTestUtils.waitUntilFinalKeyValueRecordsReceived( + consumerProperties, + outputTopic, + expectedRecords + ); + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/integration/KStreamTransformIntegrationTest.java b/streams/src/test/java/org/apache/kafka/streams/integration/KStreamTransformIntegrationTest.java new file mode 100644 index 0000000..7b13d9f --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/integration/KStreamTransformIntegrationTest.java @@ -0,0 +1,562 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.integration; + +import org.apache.kafka.common.serialization.IntegerSerializer; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.streams.TestInputTopic; +import org.apache.kafka.streams.kstream.Consumed; +import org.apache.kafka.streams.kstream.ForeachAction; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.kstream.KStream; +import org.apache.kafka.streams.kstream.Transformer; +import org.apache.kafka.streams.kstream.TransformerSupplier; +import org.apache.kafka.streams.kstream.ValueTransformer; +import org.apache.kafka.streams.kstream.ValueTransformerSupplier; +import org.apache.kafka.streams.kstream.ValueTransformerWithKey; +import org.apache.kafka.streams.kstream.ValueTransformerWithKeySupplier; +import org.apache.kafka.streams.processor.ProcessorContext; +import org.apache.kafka.streams.state.KeyValueStore; +import org.apache.kafka.streams.state.StoreBuilder; +import org.apache.kafka.streams.state.Stores; +import org.apache.kafka.streams.TopologyTestDriver; +import org.apache.kafka.test.IntegrationTest; +import org.apache.kafka.test.StreamsTestUtils; + +import org.junit.Before; +import org.junit.Test; +import org.junit.experimental.categories.Category; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Set; +import java.util.Properties; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.core.IsEqual.equalTo; + +@SuppressWarnings("unchecked") +@Category({IntegrationTest.class}) +public class KStreamTransformIntegrationTest { + + private StreamsBuilder builder; + private final String topic = "stream"; + private final String stateStoreName = "myTransformState"; + private final List> results = new ArrayList<>(); + private final ForeachAction accumulateExpected = (key, value) -> results.add(KeyValue.pair(key, value)); + private KStream stream; + + @Before + public void before() { + builder = new StreamsBuilder(); + stream = builder.stream(topic, Consumed.with(Serdes.Integer(), Serdes.Integer())); + } 
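+    // Persistent key-value store shared by all transformer tests; connected either via addStateStore() or the supplier's stores() override.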
+ + private StoreBuilder> storeBuilder() { + return Stores.keyValueStoreBuilder(Stores.persistentKeyValueStore(stateStoreName), + Serdes.Integer(), + Serdes.Integer()); + } + + private void verifyResult(final List> expected) { + final Properties props = StreamsTestUtils.getStreamsConfig(Serdes.Integer(), Serdes.Integer()); + try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) { + final TestInputTopic inputTopic = + driver.createInputTopic(topic, new IntegerSerializer(), new IntegerSerializer()); + inputTopic.pipeKeyValueList(Arrays.asList( + new KeyValue<>(1, 1), + new KeyValue<>(2, 2), + new KeyValue<>(3, 3), + new KeyValue<>(2, 1), + new KeyValue<>(2, 3), + new KeyValue<>(1, 3))); + } + assertThat(results, equalTo(expected)); + } + + private class TestTransformer implements Transformer> { + private KeyValueStore state; + + @SuppressWarnings("unchecked") + @Override + public void init(final ProcessorContext context) { + state = (KeyValueStore) context.getStateStore(stateStoreName); + } + + @Override + public KeyValue transform(final Integer key, final Integer value) { + state.putIfAbsent(key, 0); + Integer storedValue = state.get(key); + final KeyValue result = new KeyValue<>(key + 1, value + storedValue++); + state.put(key, storedValue); + return result; + } + + @Override + public void close() { + } + } + + @Test + public void shouldTransform() { + builder.addStateStore(storeBuilder()); + + stream + .transform(TestTransformer::new, stateStoreName) + .foreach(accumulateExpected); + + final List> expected = Arrays.asList( + KeyValue.pair(2, 1), + KeyValue.pair(3, 2), + KeyValue.pair(4, 3), + KeyValue.pair(3, 2), + KeyValue.pair(3, 5), + KeyValue.pair(2, 4)); + verifyResult(expected); + } + + @Test + public void shouldTransformWithConnectedStoreProvider() { + stream + .transform(new TransformerSupplier>() { + @Override + public Transformer> get() { + return new TestTransformer(); + } + + @Override + public Set> stores() { + return Collections.singleton(storeBuilder()); + } + }) + .foreach(accumulateExpected); + + final List> expected = Arrays.asList( + KeyValue.pair(2, 1), + KeyValue.pair(3, 2), + KeyValue.pair(4, 3), + KeyValue.pair(3, 2), + KeyValue.pair(3, 5), + KeyValue.pair(2, 4)); + verifyResult(expected); + } + + private class TestFlatTransformer implements Transformer>> { + private KeyValueStore state; + + @SuppressWarnings("unchecked") + @Override + public void init(final ProcessorContext context) { + state = (KeyValueStore) context.getStateStore(stateStoreName); + } + + @Override + public Iterable> transform(final Integer key, final Integer value) { + final List> result = new ArrayList<>(); + state.putIfAbsent(key, 0); + Integer storedValue = state.get(key); + for (int i = 0; i < 3; i++) { + result.add(new KeyValue<>(key + i, value + storedValue++)); + } + state.put(key, storedValue); + return result; + } + + @Override + public void close() { + } + } + + @Test + public void shouldFlatTransform() { + builder.addStateStore(storeBuilder()); + + stream + .flatTransform(TestFlatTransformer::new, stateStoreName) + .foreach(accumulateExpected); + + final List> expected = Arrays.asList( + KeyValue.pair(1, 1), + KeyValue.pair(2, 2), + KeyValue.pair(3, 3), + KeyValue.pair(2, 2), + KeyValue.pair(3, 3), + KeyValue.pair(4, 4), + KeyValue.pair(3, 3), + KeyValue.pair(4, 4), + KeyValue.pair(5, 5), + KeyValue.pair(2, 4), + KeyValue.pair(3, 5), + KeyValue.pair(4, 6), + KeyValue.pair(2, 9), + KeyValue.pair(3, 10), + KeyValue.pair(4, 11), + KeyValue.pair(1, 6), + 
KeyValue.pair(2, 7), + KeyValue.pair(3, 8)); + verifyResult(expected); + } + + @Test + public void shouldFlatTransformWithConnectedStoreProvider() { + stream + .flatTransform(new TransformerSupplier>>() { + @Override + public Transformer>> get() { + return new TestFlatTransformer(); + } + + @Override + public Set> stores() { + return Collections.singleton(storeBuilder()); + } + }) + .foreach(accumulateExpected); + + final List> expected = Arrays.asList( + KeyValue.pair(1, 1), + KeyValue.pair(2, 2), + KeyValue.pair(3, 3), + KeyValue.pair(2, 2), + KeyValue.pair(3, 3), + KeyValue.pair(4, 4), + KeyValue.pair(3, 3), + KeyValue.pair(4, 4), + KeyValue.pair(5, 5), + KeyValue.pair(2, 4), + KeyValue.pair(3, 5), + KeyValue.pair(4, 6), + KeyValue.pair(2, 9), + KeyValue.pair(3, 10), + KeyValue.pair(4, 11), + KeyValue.pair(1, 6), + KeyValue.pair(2, 7), + KeyValue.pair(3, 8)); + verifyResult(expected); + } + + private class TestValueTransformerWithKey implements ValueTransformerWithKey { + private KeyValueStore state; + + @Override + public void init(final ProcessorContext context) { + state = (KeyValueStore) context.getStateStore(stateStoreName); + } + + @Override + public Integer transform(final Integer key, final Integer value) { + state.putIfAbsent(key, 0); + Integer storedValue = state.get(key); + final Integer result = value + storedValue++; + state.put(key, storedValue); + return result; + } + + @Override + public void close() { + } + } + + @Test + public void shouldTransformValuesWithValueTransformerWithKey() { + builder.addStateStore(storeBuilder()); + + stream + .transformValues(TestValueTransformerWithKey::new, stateStoreName) + .foreach(accumulateExpected); + + final List> expected = Arrays.asList( + KeyValue.pair(1, 1), + KeyValue.pair(2, 2), + KeyValue.pair(3, 3), + KeyValue.pair(2, 2), + KeyValue.pair(2, 5), + KeyValue.pair(1, 4)); + verifyResult(expected); + } + + @Test + public void shouldTransformValuesWithValueTransformerWithKeyWithConnectedStoreProvider() { + stream + .transformValues(new ValueTransformerWithKeySupplier() { + @Override + public ValueTransformerWithKey get() { + return new TestValueTransformerWithKey(); + } + + @Override + public Set> stores() { + return Collections.singleton(storeBuilder()); + } + }) + .foreach(accumulateExpected); + } + + private class TestValueTransformer implements ValueTransformer { + private KeyValueStore state; + + @Override + public void init(final ProcessorContext context) { + state = (KeyValueStore) context.getStateStore(stateStoreName); + } + + @Override + public Integer transform(final Integer value) { + state.putIfAbsent(value, 0); + Integer counter = state.get(value); + state.put(value, ++counter); + return counter; + } + + @Override + public void close() { + } + } + + @Test + public void shouldTransformValuesWithValueTransformerWithoutKey() { + builder.addStateStore(storeBuilder()); + + stream + .transformValues(TestValueTransformer::new, stateStoreName) + .foreach(accumulateExpected); + + final List> expected = Arrays.asList( + KeyValue.pair(1, 1), + KeyValue.pair(2, 1), + KeyValue.pair(3, 1), + KeyValue.pair(2, 2), + KeyValue.pair(2, 2), + KeyValue.pair(1, 3)); + verifyResult(expected); + } + + @Test + public void shouldTransformValuesWithValueTransformerWithoutKeyWithConnectedStoreProvider() { + stream + .transformValues(new ValueTransformerSupplier() { + @Override + public ValueTransformer get() { + return new TestValueTransformer(); + } + + @Override + public Set> stores() { + return Collections.singleton(storeBuilder()); + } + }) + 
.foreach(accumulateExpected); + + final List> expected = Arrays.asList( + KeyValue.pair(1, 1), + KeyValue.pair(2, 1), + KeyValue.pair(3, 1), + KeyValue.pair(2, 2), + KeyValue.pair(2, 2), + KeyValue.pair(1, 3)); + verifyResult(expected); + } + + private class TestValueTransformerWithoutKey implements ValueTransformerWithKey> { + private KeyValueStore state; + + @Override + public void init(final ProcessorContext context) { + state = (KeyValueStore) context.getStateStore(stateStoreName); + } + + @Override + public Iterable transform(final Integer key, final Integer value) { + final List result = new ArrayList<>(); + state.putIfAbsent(key, 0); + Integer storedValue = state.get(key); + for (int i = 0; i < 3; i++) { + result.add(value + storedValue++); + } + state.put(key, storedValue); + return result; + } + + @Override + public void close() { + } + } + + @Test + public void shouldFlatTransformValuesWithKey() { + builder.addStateStore(storeBuilder()); + + stream + .flatTransformValues(TestValueTransformerWithoutKey::new, stateStoreName) + .foreach(accumulateExpected); + + final List> expected = Arrays.asList( + KeyValue.pair(1, 1), + KeyValue.pair(1, 2), + KeyValue.pair(1, 3), + KeyValue.pair(2, 2), + KeyValue.pair(2, 3), + KeyValue.pair(2, 4), + KeyValue.pair(3, 3), + KeyValue.pair(3, 4), + KeyValue.pair(3, 5), + KeyValue.pair(2, 4), + KeyValue.pair(2, 5), + KeyValue.pair(2, 6), + KeyValue.pair(2, 9), + KeyValue.pair(2, 10), + KeyValue.pair(2, 11), + KeyValue.pair(1, 6), + KeyValue.pair(1, 7), + KeyValue.pair(1, 8)); + verifyResult(expected); + } + + @Test + public void shouldFlatTransformValuesWithKeyWithConnectedStoreProvider() { + stream + .flatTransformValues(new ValueTransformerWithKeySupplier>() { + @Override + public ValueTransformerWithKey> get() { + return new TestValueTransformerWithoutKey(); + } + + @Override + public Set> stores() { + return Collections.singleton(storeBuilder()); + } + }) + .foreach(accumulateExpected); + + final List> expected = Arrays.asList( + KeyValue.pair(1, 1), + KeyValue.pair(1, 2), + KeyValue.pair(1, 3), + KeyValue.pair(2, 2), + KeyValue.pair(2, 3), + KeyValue.pair(2, 4), + KeyValue.pair(3, 3), + KeyValue.pair(3, 4), + KeyValue.pair(3, 5), + KeyValue.pair(2, 4), + KeyValue.pair(2, 5), + KeyValue.pair(2, 6), + KeyValue.pair(2, 9), + KeyValue.pair(2, 10), + KeyValue.pair(2, 11), + KeyValue.pair(1, 6), + KeyValue.pair(1, 7), + KeyValue.pair(1, 8)); + verifyResult(expected); + } + + private class TestFlatValueTransformer implements ValueTransformer> { + private KeyValueStore state; + + @Override + public void init(final ProcessorContext context) { + state = (KeyValueStore) context.getStateStore(stateStoreName); + } + + @Override + public Iterable transform(final Integer value) { + final List result = new ArrayList<>(); + state.putIfAbsent(value, 0); + Integer counter = state.get(value); + for (int i = 0; i < 3; i++) { + result.add(++counter); + } + state.put(value, counter); + return result; + } + + @Override + public void close() { + } + } + + @Test + public void shouldFlatTransformValuesWithValueTransformerWithoutKey() { + builder.addStateStore(storeBuilder()); + + stream + .flatTransformValues(TestFlatValueTransformer::new, stateStoreName) + .foreach(accumulateExpected); + + final List> expected = Arrays.asList( + KeyValue.pair(1, 1), + KeyValue.pair(1, 2), + KeyValue.pair(1, 3), + KeyValue.pair(2, 1), + KeyValue.pair(2, 2), + KeyValue.pair(2, 3), + KeyValue.pair(3, 1), + KeyValue.pair(3, 2), + KeyValue.pair(3, 3), + KeyValue.pair(2, 4), + KeyValue.pair(2, 
5), + KeyValue.pair(2, 6), + KeyValue.pair(2, 4), + KeyValue.pair(2, 5), + KeyValue.pair(2, 6), + KeyValue.pair(1, 7), + KeyValue.pair(1, 8), + KeyValue.pair(1, 9)); + verifyResult(expected); + } + + @Test + public void shouldFlatTransformValuesWithValueTransformerWithoutKeyWithConnectedStoreProvider() { + stream + .flatTransformValues(new ValueTransformerSupplier>() { + @Override + public ValueTransformer> get() { + return new TestFlatValueTransformer(); + } + + @Override + public Set> stores() { + return Collections.singleton(storeBuilder()); + } + }) + .foreach(accumulateExpected); + + final List> expected = Arrays.asList( + KeyValue.pair(1, 1), + KeyValue.pair(1, 2), + KeyValue.pair(1, 3), + KeyValue.pair(2, 1), + KeyValue.pair(2, 2), + KeyValue.pair(2, 3), + KeyValue.pair(3, 1), + KeyValue.pair(3, 2), + KeyValue.pair(3, 3), + KeyValue.pair(2, 4), + KeyValue.pair(2, 5), + KeyValue.pair(2, 6), + KeyValue.pair(2, 4), + KeyValue.pair(2, 5), + KeyValue.pair(2, 6), + KeyValue.pair(1, 7), + KeyValue.pair(1, 8), + KeyValue.pair(1, 9)); + verifyResult(expected); + } +} \ No newline at end of file diff --git a/streams/src/test/java/org/apache/kafka/streams/integration/KTableEfficientRangeQueryTest.java b/streams/src/test/java/org/apache/kafka/streams/integration/KTableEfficientRangeQueryTest.java new file mode 100644 index 0000000..b0564ba --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/integration/KTableEfficientRangeQueryTest.java @@ -0,0 +1,262 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.streams.integration; + +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.serialization.StringSerializer; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.TestInputTopic; +import org.apache.kafka.streams.Topology; +import org.apache.kafka.streams.TopologyTestDriver; +import org.apache.kafka.streams.kstream.KTable; +import org.apache.kafka.streams.kstream.Materialized; +import org.apache.kafka.streams.state.KeyValueBytesStoreSupplier; +import org.apache.kafka.streams.state.KeyValueIterator; +import org.apache.kafka.streams.state.KeyValueStore; +import org.apache.kafka.streams.state.ReadOnlyKeyValueStore; +import org.apache.kafka.streams.state.Stores; +import org.apache.kafka.test.TestUtils; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TestName; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import java.util.Arrays; +import java.util.Collection; +import java.util.HashMap; +import java.util.Iterator; +import java.util.LinkedList; +import java.util.List; +import java.util.Properties; +import java.util.function.Predicate; +import java.util.function.Supplier; + +import static org.apache.kafka.common.utils.Utils.mkEntry; +import static org.apache.kafka.common.utils.Utils.mkMap; +import static org.apache.kafka.common.utils.Utils.mkProperties; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.MatcherAssert.assertThat; + +@RunWith(Parameterized.class) +public class KTableEfficientRangeQueryTest { + private enum StoreType { InMemory, RocksDB, Timed }; + private static final String TABLE_NAME = "mytable"; + private static final int DATA_SIZE = 5; + + private StoreType storeType; + private boolean enableLogging; + private boolean enableCaching; + private boolean forward; + + private LinkedList> records; + private String low; + private String high; + private String middle; + private String innerLow; + private String innerHigh; + private String innerLowBetween; + private String innerHighBetween; + + private Properties streamsConfig; + + public KTableEfficientRangeQueryTest(final StoreType storeType, final boolean enableLogging, final boolean enableCaching, final boolean forward) { + this.storeType = storeType; + this.enableLogging = enableLogging; + this.enableCaching = enableCaching; + this.forward = forward; + + this.records = new LinkedList<>(); + final int m = DATA_SIZE / 2; + for (int i = 0; i < DATA_SIZE; i++) { + final String key = "key-" + i * 2; + final String value = "val-" + i * 2; + records.add(new KeyValue<>(key, value)); + high = key; + if (low == null) { + low = key; + } + if (i == m) { + middle = key; + } + if (i == 1) { + innerLow = key; + final int index = i * 2 - 1; + innerLowBetween = "key-" + index; + } + if (i == DATA_SIZE - 2) { + innerHigh = key; + final int index = i * 2 + 1; + innerHighBetween = "key-" + index; + } + } + Assert.assertNotNull(low); + Assert.assertNotNull(high); + Assert.assertNotNull(middle); + Assert.assertNotNull(innerLow); + Assert.assertNotNull(innerHigh); + Assert.assertNotNull(innerLowBetween); + Assert.assertNotNull(innerHighBetween); + } + + @Rule + public TestName testName = new TestName(); + + @Parameterized.Parameters(name = "storeType={0}, 
enableLogging={1}, enableCaching={2}, forward={3}") + public static Collection data() { + final List types = Arrays.asList(StoreType.InMemory, StoreType.RocksDB, StoreType.Timed); + final List logging = Arrays.asList(true, false); + final List caching = Arrays.asList(true, false); + final List forward = Arrays.asList(true, false); + return buildParameters(types, logging, caching, forward); + } + + @Before + public void setup() { + streamsConfig = mkProperties(mkMap( + mkEntry(StreamsConfig.STATE_DIR_CONFIG, TestUtils.tempDirectory().getPath()) + )); + } + + @Test + public void testStoreConfig() { + final Materialized> stateStoreConfig = getStoreConfig(storeType, TABLE_NAME, enableLogging, enableCaching); + //Create topology: table from input topic + final StreamsBuilder builder = new StreamsBuilder(); + final KTable table = + builder.table("input", stateStoreConfig); + final Topology topology = builder.build(); + + try (final TopologyTestDriver driver = new TopologyTestDriver(topology)) { + //get input topic and stateStore + final TestInputTopic input = driver + .createInputTopic("input", new StringSerializer(), new StringSerializer()); + final ReadOnlyKeyValueStore stateStore = driver.getKeyValueStore(TABLE_NAME); + + //write some data + for (final KeyValue kv : records) { + input.pipeInput(kv.key, kv.value); + } + + //query the state store + try (final KeyValueIterator scanIterator = forward ? stateStore.range(null, null) : stateStore.reverseRange(null, null)) { + final Iterator> dataIterator = forward ? records.iterator() : records.descendingIterator(); + TestUtils.checkEquals(scanIterator, dataIterator); + } + + try (final KeyValueIterator allIterator = forward ? stateStore.all() : stateStore.reverseAll()) { + final Iterator> dataIterator = forward ? records.iterator() : records.descendingIterator(); + TestUtils.checkEquals(allIterator, dataIterator); + } + + testRange("range", stateStore, innerLow, innerHigh, forward); + testRange("until", stateStore, null, middle, forward); + testRange("from", stateStore, middle, null, forward); + + testRange("untilBetween", stateStore, null, innerHighBetween, forward); + testRange("fromBetween", stateStore, innerLowBetween, null, forward); + } + } + + private List> filterList(final KeyValueIterator iterator, final String from, final String to) { + final Predicate> pred = new Predicate>() { + @Override + public boolean test(final KeyValue elem) { + if (from != null && elem.key.compareTo(from) < 0) { + return false; + } + if (to != null && elem.key.compareTo(to) > 0) { + return false; + } + return elem != null; + } + }; + + return Utils.toList(iterator, pred); + } + + private void testRange(final String name, final ReadOnlyKeyValueStore store, final String from, final String to, final boolean forward) { + try (final KeyValueIterator resultIterator = forward ? store.range(from, to) : store.reverseRange(from, to); + final KeyValueIterator expectedIterator = forward ? store.all() : store.reverseAll()) { + final List> result = Utils.toList(resultIterator); + final List> expected = filterList(expectedIterator, from, to); + assertThat(result, is(expected)); + } + } + + private static Collection buildParameters(final List... 
argOptions) { + List result = new LinkedList<>(); + result.add(new Object[0]); + + for (final List argOption : argOptions) { + result = times(result, argOption); + } + + return result; + } + + private static List times(final List left, final List right) { + final List result = new LinkedList<>(); + for (final Object[] args : left) { + for (final Object rightElem : right) { + final Object[] resArgs = new Object[args.length + 1]; + System.arraycopy(args, 0, resArgs, 0, args.length); + resArgs[args.length] = rightElem; + result.add(resArgs); + } + } + return result; + } + + private Materialized> getStoreConfig(final StoreType type, final String name, final boolean cachingEnabled, final boolean loggingEnabled) { + final Supplier createStore = () -> { + if (type == StoreType.InMemory) { + return Stores.inMemoryKeyValueStore(TABLE_NAME); + } else if (type == StoreType.RocksDB) { + return Stores.persistentKeyValueStore(TABLE_NAME); + } else if (type == StoreType.Timed) { + return Stores.persistentTimestampedKeyValueStore(TABLE_NAME); + } else { + return Stores.inMemoryKeyValueStore(TABLE_NAME); + } + }; + + final KeyValueBytesStoreSupplier stateStoreSupplier = createStore.get(); + final Materialized> stateStoreConfig = Materialized + .as(stateStoreSupplier) + .withKeySerde(Serdes.String()) + .withValueSerde(Serdes.String()); + if (cachingEnabled) { + stateStoreConfig.withCachingEnabled(); + } else { + stateStoreConfig.withCachingDisabled(); + } + if (loggingEnabled) { + stateStoreConfig.withLoggingEnabled(new HashMap()); + } else { + stateStoreConfig.withLoggingDisabled(); + } + return stateStoreConfig; + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/integration/KTableKTableForeignKeyInnerJoinCustomPartitionerIntegrationTest.java b/streams/src/test/java/org/apache/kafka/streams/integration/KTableKTableForeignKeyInnerJoinCustomPartitionerIntegrationTest.java new file mode 100644 index 0000000..c83bbae --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/integration/KTableKTableForeignKeyInnerJoinCustomPartitionerIntegrationTest.java @@ -0,0 +1,255 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.streams.integration; + +import static java.time.Duration.ofSeconds; +import static java.util.Arrays.asList; +import static org.apache.kafka.streams.integration.utils.IntegrationTestUtils.startApplicationAndWaitUntilRunning; +import static org.apache.kafka.streams.integration.utils.IntegrationTestUtils.waitUntilMinKeyValueRecordsReceived; +import static org.junit.Assert.assertEquals; + +import java.io.IOException; +import java.util.HashSet; +import java.util.List; +import java.util.Properties; +import java.util.Set; + +import org.apache.kafka.clients.consumer.ConsumerConfig; +import org.apache.kafka.clients.producer.ProducerConfig; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.serialization.StringDeserializer; +import org.apache.kafka.common.serialization.StringSerializer; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.streams.KafkaStreams; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.integration.utils.EmbeddedKafkaCluster; +import org.apache.kafka.streams.integration.utils.IntegrationTestUtils; +import org.apache.kafka.streams.kstream.Consumed; +import org.apache.kafka.streams.kstream.KTable; +import org.apache.kafka.streams.kstream.Materialized; +import org.apache.kafka.streams.kstream.Named; +import org.apache.kafka.streams.kstream.Produced; +import org.apache.kafka.streams.kstream.Repartitioned; +import org.apache.kafka.streams.kstream.TableJoined; +import org.apache.kafka.streams.kstream.ValueJoiner; +import org.apache.kafka.streams.state.KeyValueStore; +import org.apache.kafka.streams.utils.UniqueTopicSerdeScope; +import org.apache.kafka.test.IntegrationTest; +import org.apache.kafka.test.TestUtils; +import org.junit.After; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; +import org.junit.experimental.categories.Category; + +import kafka.utils.MockTime; + +@Category({IntegrationTest.class}) +public class KTableKTableForeignKeyInnerJoinCustomPartitionerIntegrationTest { + private final static int NUM_BROKERS = 1; + + public final static EmbeddedKafkaCluster CLUSTER = new EmbeddedKafkaCluster(NUM_BROKERS); + private final static MockTime MOCK_TIME = CLUSTER.time; + private final static String TABLE_1 = "table1"; + private final static String TABLE_2 = "table2"; + private final static String OUTPUT = "output-"; + private final Properties streamsConfig = getStreamsConfig(); + private final Properties streamsConfigTwo = getStreamsConfig(); + private final Properties streamsConfigThree = getStreamsConfig(); + private KafkaStreams streams; + private KafkaStreams streamsTwo; + private KafkaStreams streamsThree; + private final static Properties CONSUMER_CONFIG = new Properties(); + + private final static Properties PRODUCER_CONFIG_1 = new Properties(); + private final static Properties PRODUCER_CONFIG_2 = new Properties(); + + @BeforeClass + public static void startCluster() throws IOException, InterruptedException { + CLUSTER.start(); + //Use multiple partitions to ensure distribution of keys. 
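+        // Each input table and the output topic is created with four partitions (replication factor 1).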
+ + CLUSTER.createTopic(TABLE_1, 4, 1); + CLUSTER.createTopic(TABLE_2, 4, 1); + CLUSTER.createTopic(OUTPUT, 4, 1); + + PRODUCER_CONFIG_1.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, CLUSTER.bootstrapServers()); + PRODUCER_CONFIG_1.put(ProducerConfig.ACKS_CONFIG, "all"); + PRODUCER_CONFIG_1.put(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG, StringSerializer.class); + PRODUCER_CONFIG_1.put(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG, StringSerializer.class); + + PRODUCER_CONFIG_2.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, CLUSTER.bootstrapServers()); + PRODUCER_CONFIG_2.put(ProducerConfig.ACKS_CONFIG, "all"); + PRODUCER_CONFIG_2.put(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG, StringSerializer.class); + PRODUCER_CONFIG_2.put(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG, StringSerializer.class); + + final List> table1 = asList( + new KeyValue<>("ID123-1", "ID123-A1"), + new KeyValue<>("ID123-2", "ID123-A2"), + new KeyValue<>("ID123-3", "ID123-A3"), + new KeyValue<>("ID123-4", "ID123-A4") + ); + + final List> table2 = asList( + new KeyValue<>("ID123", "BBB") + ); + + IntegrationTestUtils.produceKeyValuesSynchronously(TABLE_1, table1, PRODUCER_CONFIG_1, MOCK_TIME); + IntegrationTestUtils.produceKeyValuesSynchronously(TABLE_2, table2, PRODUCER_CONFIG_2, MOCK_TIME); + + CONSUMER_CONFIG.put(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, CLUSTER.bootstrapServers()); + CONSUMER_CONFIG.put(ConsumerConfig.GROUP_ID_CONFIG, "ktable-ktable-consumer"); + CONSUMER_CONFIG.put(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, StringDeserializer.class); + CONSUMER_CONFIG.put(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, StringDeserializer.class); + } + + @AfterClass + public static void closeCluster() { + CLUSTER.stop(); + } + + @Before + public void before() throws IOException { + final String stateDirBasePath = TestUtils.tempDirectory().getPath(); + streamsConfig.put(StreamsConfig.STATE_DIR_CONFIG, stateDirBasePath + "-1"); + streamsConfigTwo.put(StreamsConfig.STATE_DIR_CONFIG, stateDirBasePath + "-2"); + streamsConfigThree.put(StreamsConfig.STATE_DIR_CONFIG, stateDirBasePath + "-3"); + } + + @After + public void after() throws IOException { + if (streams != null) { + streams.close(); + streams = null; + } + if (streamsTwo != null) { + streamsTwo.close(); + streamsTwo = null; + } + if (streamsThree != null) { + streamsThree.close(); + streamsThree = null; + } + IntegrationTestUtils.purgeLocalStreamsState(asList(streamsConfig, streamsConfigTwo, streamsConfigThree)); + } + + @Test + public void shouldInnerJoinMultiPartitionQueryable() throws Exception { + final Set> expectedOne = new HashSet<>(); + expectedOne.add(new KeyValue<>("ID123-1", "value1=ID123-A1,value2=BBB")); + expectedOne.add(new KeyValue<>("ID123-2", "value1=ID123-A2,value2=BBB")); + expectedOne.add(new KeyValue<>("ID123-3", "value1=ID123-A3,value2=BBB")); + expectedOne.add(new KeyValue<>("ID123-4", "value1=ID123-A4,value2=BBB")); + + verifyKTableKTableJoin(expectedOne); + } + + private void verifyKTableKTableJoin(final Set> expectedResult) throws Exception { + final String innerJoinType = "INNER"; + final String queryableName = innerJoinType + "-store1"; + + streams = prepareTopology(queryableName, streamsConfig); + streamsTwo = prepareTopology(queryableName, streamsConfigTwo); + streamsThree = prepareTopology(queryableName, streamsConfigThree); + + final List kafkaStreamsList = asList(streams, streamsTwo, streamsThree); + startApplicationAndWaitUntilRunning(kafkaStreamsList, ofSeconds(120)); + + final Set> result = new 
HashSet<>(waitUntilMinKeyValueRecordsReceived( + CONSUMER_CONFIG, + OUTPUT, + expectedResult.size())); + + assertEquals(expectedResult, result); + } + + private static Properties getStreamsConfig() { + final Properties streamsConfig = new Properties(); + streamsConfig.put(StreamsConfig.APPLICATION_ID_CONFIG, "KTable-FKJ-Partitioner"); + streamsConfig.put(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, CLUSTER.bootstrapServers()); + streamsConfig.put(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "earliest"); + streamsConfig.put(StreamsConfig.CACHE_MAX_BYTES_BUFFERING_CONFIG, 0); + streamsConfig.put(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG, 100L); + + return streamsConfig; + } + + private static KafkaStreams prepareTopology(final String queryableName, final Properties streamsConfig) { + + final UniqueTopicSerdeScope serdeScope = new UniqueTopicSerdeScope(); + final StreamsBuilder builder = new StreamsBuilder(); + + final KTable table1 = builder.stream(TABLE_1, + Consumed.with(serdeScope.decorateSerde(Serdes.String(), streamsConfig, true), serdeScope.decorateSerde(Serdes.String(), streamsConfig, false))) + .repartition(repartitionA()) + .toTable(Named.as("table.a")); + + final KTable table2 = builder + .stream(TABLE_2, + Consumed.with(serdeScope.decorateSerde(Serdes.String(), streamsConfig, true), serdeScope.decorateSerde(Serdes.String(), streamsConfig, false))) + .repartition(repartitionB()) + .toTable(Named.as("table.b")); + + final Materialized> materialized; + if (queryableName != null) { + materialized = Materialized.>as(queryableName) + .withKeySerde(serdeScope.decorateSerde(Serdes.String(), streamsConfig, true)) + .withValueSerde(serdeScope.decorateSerde(Serdes.String(), streamsConfig, false)) + .withCachingDisabled(); + } else { + throw new RuntimeException("Current implementation of joinOnForeignKey requires a materialized store"); + } + + final ValueJoiner joiner = (value1, value2) -> "value1=" + value1 + ",value2=" + value2; + + final TableJoined tableJoined = TableJoined.with( + (topic, key, value, numPartitions) -> Math.abs(getKeyB(key).hashCode()) % numPartitions, + (topic, key, value, numPartitions) -> Math.abs(key.hashCode()) % numPartitions + ); + + table1.join(table2, KTableKTableForeignKeyInnerJoinCustomPartitionerIntegrationTest::getKeyB, joiner, tableJoined, materialized) + .toStream() + .to(OUTPUT, + Produced.with(serdeScope.decorateSerde(Serdes.String(), streamsConfig, true), + serdeScope.decorateSerde(Serdes.String(), streamsConfig, false))); + + return new KafkaStreams(builder.build(streamsConfig), streamsConfig); + } + + private static Repartitioned repartitionA() { + final Repartitioned repartitioned = Repartitioned.as("a"); + return repartitioned.withKeySerde(Serdes.String()).withValueSerde(Serdes.String()) + .withStreamPartitioner((topic, key, value, numPartitions) -> Math.abs(getKeyB(key).hashCode()) % numPartitions) + .withNumberOfPartitions(4); + } + + private static Repartitioned repartitionB() { + final Repartitioned repartitioned = Repartitioned.as("b"); + return repartitioned.withKeySerde(Serdes.String()).withValueSerde(Serdes.String()) + .withStreamPartitioner((topic, key, value, numPartitions) -> Math.abs(key.hashCode()) % numPartitions) + .withNumberOfPartitions(4); + } + + private static String getKeyB(final String value) { + return value.substring(0, value.indexOf("-")); + } + +} \ No newline at end of file diff --git a/streams/src/test/java/org/apache/kafka/streams/integration/KTableKTableForeignKeyInnerJoinMultiIntegrationTest.java 
b/streams/src/test/java/org/apache/kafka/streams/integration/KTableKTableForeignKeyInnerJoinMultiIntegrationTest.java new file mode 100644 index 0000000..0788b52 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/integration/KTableKTableForeignKeyInnerJoinMultiIntegrationTest.java @@ -0,0 +1,281 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.integration; + +import kafka.utils.MockTime; +import org.apache.kafka.clients.consumer.ConsumerConfig; +import org.apache.kafka.clients.producer.ProducerConfig; +import org.apache.kafka.common.serialization.FloatSerializer; +import org.apache.kafka.common.serialization.IntegerDeserializer; +import org.apache.kafka.common.serialization.IntegerSerializer; +import org.apache.kafka.common.serialization.LongSerializer; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.serialization.StringDeserializer; +import org.apache.kafka.common.serialization.StringSerializer; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.streams.KafkaStreams; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.integration.utils.EmbeddedKafkaCluster; +import org.apache.kafka.streams.integration.utils.IntegrationTestUtils; +import org.apache.kafka.streams.kstream.Consumed; +import org.apache.kafka.streams.kstream.KTable; +import org.apache.kafka.streams.kstream.Materialized; +import org.apache.kafka.streams.kstream.Produced; +import org.apache.kafka.streams.kstream.ValueJoiner; +import org.apache.kafka.streams.state.KeyValueStore; +import org.apache.kafka.streams.utils.UniqueTopicSerdeScope; +import org.apache.kafka.test.IntegrationTest; +import org.apache.kafka.test.TestUtils; +import org.junit.After; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; +import org.junit.experimental.categories.Category; + +import java.io.IOException; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Properties; +import java.util.Set; +import java.util.function.Function; + +import static java.time.Duration.ofSeconds; +import static java.util.Arrays.asList; +import static org.apache.kafka.streams.integration.utils.IntegrationTestUtils.startApplicationAndWaitUntilRunning; +import static org.junit.Assert.assertEquals; + +@Category({IntegrationTest.class}) +public class KTableKTableForeignKeyInnerJoinMultiIntegrationTest { + private final static int NUM_BROKERS = 1; + + public final static EmbeddedKafkaCluster CLUSTER = new EmbeddedKafkaCluster(NUM_BROKERS); + private final static MockTime MOCK_TIME = CLUSTER.time; + 
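+    // The three source tables below deliberately use different key and value types
+    // (Integer/Float, String/Long, Integer/String) and different partition counts, so the
+    // chained foreign-key joins built in prepareTopology() must map between key spaces via key extractors.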
private final static String TABLE_1 = "table1"; + private final static String TABLE_2 = "table2"; + private final static String TABLE_3 = "table3"; + private final static String OUTPUT = "output-"; + private final Properties streamsConfig = getStreamsConfig(); + private final Properties streamsConfigTwo = getStreamsConfig(); + private final Properties streamsConfigThree = getStreamsConfig(); + private KafkaStreams streams; + private KafkaStreams streamsTwo; + private KafkaStreams streamsThree; + private final static Properties CONSUMER_CONFIG = new Properties(); + + private final static Properties PRODUCER_CONFIG_1 = new Properties(); + private final static Properties PRODUCER_CONFIG_2 = new Properties(); + private final static Properties PRODUCER_CONFIG_3 = new Properties(); + + @BeforeClass + public static void startCluster() throws IOException, InterruptedException { + CLUSTER.start(); + //Use multiple partitions to ensure distribution of keys. + + CLUSTER.createTopic(TABLE_1, 3, 1); + CLUSTER.createTopic(TABLE_2, 5, 1); + CLUSTER.createTopic(TABLE_3, 7, 1); + CLUSTER.createTopic(OUTPUT, 11, 1); + + PRODUCER_CONFIG_1.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, CLUSTER.bootstrapServers()); + PRODUCER_CONFIG_1.put(ProducerConfig.ACKS_CONFIG, "all"); + PRODUCER_CONFIG_1.put(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG, IntegerSerializer.class); + PRODUCER_CONFIG_1.put(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG, FloatSerializer.class); + + PRODUCER_CONFIG_2.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, CLUSTER.bootstrapServers()); + PRODUCER_CONFIG_2.put(ProducerConfig.ACKS_CONFIG, "all"); + PRODUCER_CONFIG_2.put(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG, StringSerializer.class); + PRODUCER_CONFIG_2.put(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG, LongSerializer.class); + + PRODUCER_CONFIG_3.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, CLUSTER.bootstrapServers()); + PRODUCER_CONFIG_3.put(ProducerConfig.ACKS_CONFIG, "all"); + PRODUCER_CONFIG_3.put(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG, IntegerSerializer.class); + PRODUCER_CONFIG_3.put(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG, StringSerializer.class); + + final List> table1 = asList( + new KeyValue<>(1, 1.33f), + new KeyValue<>(2, 2.22f), + new KeyValue<>(3, -1.22f), //Won't be joined in yet. + new KeyValue<>(4, -2.22f) //Won't be joined in at all. + ); + + //Partitions pre-computed using the default Murmur2 hash, just to ensure that all 3 partitions will be exercised. + final List> table2 = asList( + new KeyValue<>("0", 0L), //partition 2 + new KeyValue<>("1", 10L), //partition 0 + new KeyValue<>("2", 20L), //partition 2 + new KeyValue<>("3", 30L), //partition 2 + new KeyValue<>("4", 40L), //partition 1 + new KeyValue<>("5", 50L), //partition 0 + new KeyValue<>("6", 60L), //partition 1 + new KeyValue<>("7", 70L), //partition 0 + new KeyValue<>("8", 80L), //partition 0 + new KeyValue<>("9", 90L) //partition 2 + ); + + //Partitions pre-computed using the default Murmur2 hash, just to ensure that all 3 partitions will be exercised. 
+ final List> table3 = Collections.singletonList( + new KeyValue<>(10, "waffle") + ); + + IntegrationTestUtils.produceKeyValuesSynchronously(TABLE_1, table1, PRODUCER_CONFIG_1, MOCK_TIME); + IntegrationTestUtils.produceKeyValuesSynchronously(TABLE_2, table2, PRODUCER_CONFIG_2, MOCK_TIME); + IntegrationTestUtils.produceKeyValuesSynchronously(TABLE_3, table3, PRODUCER_CONFIG_3, MOCK_TIME); + + CONSUMER_CONFIG.put(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, CLUSTER.bootstrapServers()); + CONSUMER_CONFIG.put(ConsumerConfig.GROUP_ID_CONFIG, "ktable-ktable-consumer"); + CONSUMER_CONFIG.put(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, IntegerDeserializer.class); + CONSUMER_CONFIG.put(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, StringDeserializer.class); + } + + @AfterClass + public static void closeCluster() { + CLUSTER.stop(); + } + + @Before + public void before() throws IOException { + final String stateDirBasePath = TestUtils.tempDirectory().getPath(); + streamsConfig.put(StreamsConfig.STATE_DIR_CONFIG, stateDirBasePath + "-1"); + streamsConfigTwo.put(StreamsConfig.STATE_DIR_CONFIG, stateDirBasePath + "-2"); + streamsConfigThree.put(StreamsConfig.STATE_DIR_CONFIG, stateDirBasePath + "-3"); + } + + @After + public void after() throws IOException { + if (streams != null) { + streams.close(); + streams = null; + } + if (streamsTwo != null) { + streamsTwo.close(); + streamsTwo = null; + } + if (streamsThree != null) { + streamsThree.close(); + streamsThree = null; + } + IntegrationTestUtils.purgeLocalStreamsState(asList(streamsConfig, streamsConfigTwo, streamsConfigThree)); + } + + @Test + public void shouldInnerJoinMultiPartitionQueryable() throws Exception { + final Set> expectedOne = new HashSet<>(); + expectedOne.add(new KeyValue<>(1, "value1=1.33,value2=10,value3=waffle")); + + verifyKTableKTableJoin(expectedOne); + } + + private void verifyKTableKTableJoin(final Set> expectedResult) throws Exception { + final String innerJoinType = "INNER"; + final String queryableName = innerJoinType + "-store1"; + final String queryableNameTwo = innerJoinType + "-store2"; + + streams = prepareTopology(queryableName, queryableNameTwo, streamsConfig); + streamsTwo = prepareTopology(queryableName, queryableNameTwo, streamsConfigTwo); + streamsThree = prepareTopology(queryableName, queryableNameTwo, streamsConfigThree); + + final List kafkaStreamsList = asList(streams, streamsTwo, streamsThree); + startApplicationAndWaitUntilRunning(kafkaStreamsList, ofSeconds(120)); + + final Set> result = new HashSet<>(IntegrationTestUtils.waitUntilMinKeyValueRecordsReceived( + CONSUMER_CONFIG, + OUTPUT, + expectedResult.size())); + + assertEquals(expectedResult, result); + } + + private static Properties getStreamsConfig() { + final Properties streamsConfig = new Properties(); + streamsConfig.put(StreamsConfig.APPLICATION_ID_CONFIG, "KTable-FKJ-Multi"); + streamsConfig.put(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, CLUSTER.bootstrapServers()); + streamsConfig.put(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "earliest"); + streamsConfig.put(StreamsConfig.CACHE_MAX_BYTES_BUFFERING_CONFIG, 0); + streamsConfig.put(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG, 100L); + + return streamsConfig; + } + + private static KafkaStreams prepareTopology(final String queryableName, + final String queryableNameTwo, + final Properties streamsConfig) { + + final UniqueTopicSerdeScope serdeScope = new UniqueTopicSerdeScope(); + final StreamsBuilder builder = new StreamsBuilder(); + + final KTable table1 = builder.table( + TABLE_1, + 
Consumed.with(serdeScope.decorateSerde(Serdes.Integer(), streamsConfig, true), + serdeScope.decorateSerde(Serdes.Float(), streamsConfig, false)) + ); + final KTable table2 = builder.table( + TABLE_2, + Consumed.with(serdeScope.decorateSerde(Serdes.String(), streamsConfig, true), + serdeScope.decorateSerde(Serdes.Long(), streamsConfig, false)) + ); + final KTable table3 = builder.table( + TABLE_3, + Consumed.with(serdeScope.decorateSerde(Serdes.Integer(), streamsConfig, true), + serdeScope.decorateSerde(Serdes.String(), streamsConfig, false)) + ); + + final Materialized> materialized; + if (queryableName != null) { + materialized = Materialized.>as(queryableName) + .withKeySerde(serdeScope.decorateSerde(Serdes.Integer(), streamsConfig, true)) + .withValueSerde(serdeScope.decorateSerde(Serdes.String(), streamsConfig, false)) + .withCachingDisabled(); + } else { + throw new RuntimeException("Current implementation of joinOnForeignKey requires a materialized store"); + } + + final Materialized> materializedTwo; + if (queryableNameTwo != null) { + materializedTwo = Materialized.>as(queryableNameTwo) + .withKeySerde(serdeScope.decorateSerde(Serdes.Integer(), streamsConfig, true)) + .withValueSerde(serdeScope.decorateSerde(Serdes.String(), streamsConfig, false)) + .withCachingDisabled(); + } else { + throw new RuntimeException("Current implementation of joinOnForeignKey requires a materialized store"); + } + + final Function tableOneKeyExtractor = value -> Integer.toString((int) value.floatValue()); + final Function joinedTableKeyExtractor = value -> { + //Hardwired to return the desired foreign key as a test shortcut + if (value.contains("value2=10")) + return 10; + else + return 0; + }; + + final ValueJoiner joiner = (value1, value2) -> "value1=" + value1 + ",value2=" + value2; + final ValueJoiner joinerTwo = (value1, value2) -> value1 + ",value3=" + value2; + + table1.join(table2, tableOneKeyExtractor, joiner, materialized) + .join(table3, joinedTableKeyExtractor, joinerTwo, materializedTwo) + .toStream() + .to(OUTPUT, + Produced.with(serdeScope.decorateSerde(Serdes.Integer(), streamsConfig, true), + serdeScope.decorateSerde(Serdes.String(), streamsConfig, false))); + + return new KafkaStreams(builder.build(streamsConfig), streamsConfig); + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/integration/KTableKTableForeignKeyJoinDistributedTest.java b/streams/src/test/java/org/apache/kafka/streams/integration/KTableKTableForeignKeyJoinDistributedTest.java new file mode 100644 index 0000000..af952bd --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/integration/KTableKTableForeignKeyJoinDistributedTest.java @@ -0,0 +1,242 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.streams.integration; + +import org.apache.kafka.clients.consumer.ConsumerConfig; +import org.apache.kafka.clients.producer.ProducerConfig; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.serialization.StringDeserializer; +import org.apache.kafka.common.serialization.StringSerializer; +import org.apache.kafka.streams.KafkaStreams; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.Topology; +import org.apache.kafka.streams.ThreadMetadata; +import org.apache.kafka.streams.integration.utils.EmbeddedKafkaCluster; +import org.apache.kafka.streams.integration.utils.IntegrationTestUtils; +import org.apache.kafka.streams.kstream.KTable; +import org.apache.kafka.streams.kstream.ValueJoiner; +import org.apache.kafka.test.IntegrationTest; +import org.apache.kafka.test.TestUtils; +import org.junit.After; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Rule; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.rules.TestName; + +import java.io.IOException; +import java.util.Arrays; +import java.util.HashSet; +import java.util.List; +import java.util.Properties; +import java.util.Set; +import java.util.function.Function; +import java.util.function.Predicate; + +import static org.apache.kafka.streams.integration.utils.IntegrationTestUtils.quietlyCleanStateAfterTest; +import static org.apache.kafka.streams.integration.utils.IntegrationTestUtils.safeUniqueTestName; +import static org.junit.Assert.assertEquals; + +@Category({IntegrationTest.class}) +public class KTableKTableForeignKeyJoinDistributedTest { + private static final int NUM_BROKERS = 1; + private static final String LEFT_TABLE = "left_table"; + private static final String RIGHT_TABLE = "right_table"; + private static final String OUTPUT = "output-topic"; + public static final EmbeddedKafkaCluster CLUSTER = new EmbeddedKafkaCluster(NUM_BROKERS); + + @BeforeClass + public static void startCluster() throws IOException, InterruptedException { + CLUSTER.start(); + CLUSTER.createTopic(INPUT_TOPIC, 2, 1); + } + + @AfterClass + public static void closeCluster() { + CLUSTER.stop(); + } + + private static final Properties CONSUMER_CONFIG = new Properties(); + + @Rule + public TestName testName = new TestName(); + + + private static final String INPUT_TOPIC = "input-topic"; + + private KafkaStreams client1; + private KafkaStreams client2; + + private volatile boolean client1IsOk = false; + private volatile boolean client2IsOk = false; + + @Before + public void setupTopics() throws InterruptedException { + CLUSTER.createTopic(LEFT_TABLE, 1, 1); + CLUSTER.createTopic(RIGHT_TABLE, 1, 1); + CLUSTER.createTopic(OUTPUT, 11, 1); + + //Fill test tables + final Properties producerConfig = new Properties(); + producerConfig.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, CLUSTER.bootstrapServers()); + producerConfig.put(ProducerConfig.ACKS_CONFIG, "all"); + producerConfig.put(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG, StringSerializer.class); + producerConfig.put(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG, StringSerializer.class); + final List> leftTable = Arrays.asList( + new KeyValue<>("lhsValue1", "lhsValue1|rhs1"), + new KeyValue<>("lhsValue2", "lhsValue2|rhs2"), + new KeyValue<>("lhsValue3", "lhsValue3|rhs3"), + new KeyValue<>("lhsValue4", "lhsValue4|rhs4") + ); + final List> 
rightTable = Arrays.asList(
+            new KeyValue<>("rhs1", "rhsValue1"),
+            new KeyValue<>("rhs2", "rhsValue2"),
+            new KeyValue<>("rhs3", "rhsValue3")
+        );
+
+        IntegrationTestUtils.produceKeyValuesSynchronously(LEFT_TABLE, leftTable, producerConfig, CLUSTER.time);
+        IntegrationTestUtils.produceKeyValuesSynchronously(RIGHT_TABLE, rightTable, producerConfig, CLUSTER.time);
+
+        CONSUMER_CONFIG.put(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, CLUSTER.bootstrapServers());
+        CONSUMER_CONFIG.put(ConsumerConfig.GROUP_ID_CONFIG, "ktable-ktable-distributed-consumer");
+        CONSUMER_CONFIG.put(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, StringDeserializer.class);
+        CONSUMER_CONFIG.put(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, StringDeserializer.class);
+    }
+
+    @After
+    public void after() {
+        client1.close();
+        client2.close();
+        quietlyCleanStateAfterTest(CLUSTER, client1);
+        quietlyCleanStateAfterTest(CLUSTER, client2);
+    }
+
+    public Properties getStreamsConfiguration() {
+        final String safeTestName = safeUniqueTestName(getClass(), testName);
+        final Properties streamsConfiguration = new Properties();
+        streamsConfiguration.put(StreamsConfig.APPLICATION_ID_CONFIG, "app-" + safeTestName);
+        streamsConfiguration.put(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, CLUSTER.bootstrapServers());
+        streamsConfiguration.put(StreamsConfig.STATE_DIR_CONFIG, TestUtils.tempDirectory().getPath());
+        streamsConfiguration.put(StreamsConfig.DEFAULT_KEY_SERDE_CLASS_CONFIG, Serdes.String().getClass());
+        streamsConfiguration.put(StreamsConfig.DEFAULT_VALUE_SERDE_CLASS_CONFIG, Serdes.String().getClass());
+        streamsConfiguration.put(StreamsConfig.TOPOLOGY_OPTIMIZATION_CONFIG, StreamsConfig.OPTIMIZE);
+        return streamsConfiguration;
+    }
+
+
+    private void configureBuilder(final StreamsBuilder builder) {
+        final KTable<String, String> left = builder.table(
+            LEFT_TABLE
+        );
+        final KTable<String, String> right = builder.table(
+            RIGHT_TABLE
+        );
+
+        final Function<String, String> extractor = value -> value.split("\\|")[1];
+        final ValueJoiner<String, String, String> joiner = (value1, value2) -> "(" + value1 + "," + value2 + ")";
+
+        final KTable<String, String> fkJoin = left.join(right, extractor, joiner);
+        fkJoin
+            .toStream()
+            .to(OUTPUT);
+    }
+
+    @Test
+    public void shouldBeInitializedWithDefaultSerde() throws Exception {
+        final Properties streamsConfiguration1 = getStreamsConfiguration();
+        final Properties streamsConfiguration2 = getStreamsConfiguration();
+
+        //Each streams client needs to have its own StreamsBuilder in order to simulate
+        //a truly distributed run
+        final StreamsBuilder builder1 = new StreamsBuilder();
+        configureBuilder(builder1);
+        final StreamsBuilder builder2 = new StreamsBuilder();
+        configureBuilder(builder2);
+
+
+        createClients(
+            builder1.build(streamsConfiguration1),
+            streamsConfiguration1,
+            builder2.build(streamsConfiguration2),
+            streamsConfiguration2
+        );
+
+        setStateListenersForVerification(thread -> !thread.activeTasks().isEmpty());
+
+        startClients();
+
+        waitUntilBothClientAreOK(
+            "At least one client did not reach state RUNNING with active tasks"
+        );
+        final Set<KeyValue<String, String>> expectedResult = new HashSet<>();
+        expectedResult.add(new KeyValue<>("lhsValue1", "(lhsValue1|rhs1,rhsValue1)"));
+        expectedResult.add(new KeyValue<>("lhsValue2", "(lhsValue2|rhs2,rhsValue2)"));
+        expectedResult.add(new KeyValue<>("lhsValue3", "(lhsValue3|rhs3,rhsValue3)"));
+        final Set<KeyValue<String, String>> result = new HashSet<>(IntegrationTestUtils.waitUntilMinKeyValueRecordsReceived(
+            CONSUMER_CONFIG,
+            OUTPUT,
+            expectedResult.size()));
+
+        assertEquals(expectedResult, result);
+        //Check that both clients are still running
+
assertEquals(KafkaStreams.State.RUNNING, client1.state()); + assertEquals(KafkaStreams.State.RUNNING, client2.state()); + } + + private void createClients(final Topology topology1, + final Properties streamsConfiguration1, + final Topology topology2, + final Properties streamsConfiguration2) { + + client1 = new KafkaStreams(topology1, streamsConfiguration1); + client2 = new KafkaStreams(topology2, streamsConfiguration2); + } + + private void setStateListenersForVerification(final Predicate taskCondition) { + client1.setStateListener((newState, oldState) -> { + if (newState == KafkaStreams.State.RUNNING && + client1.metadataForLocalThreads().stream().allMatch(taskCondition)) { + client1IsOk = true; + } + }); + client2.setStateListener((newState, oldState) -> { + if (newState == KafkaStreams.State.RUNNING && + client2.metadataForLocalThreads().stream().allMatch(taskCondition)) { + client2IsOk = true; + } + }); + } + + private void startClients() { + client1.start(); + client2.start(); + } + + private void waitUntilBothClientAreOK(final String message) throws Exception { + TestUtils.waitForCondition(() -> client1IsOk && client2IsOk, + 30 * 1000, + message + ": " + + "Client 1 is " + (!client1IsOk ? "NOT " : "") + "OK, " + + "client 2 is " + (!client2IsOk ? "NOT " : "") + "OK." + ); + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/integration/KTableKTableForeignKeyJoinIntegrationTest.java b/streams/src/test/java/org/apache/kafka/streams/integration/KTableKTableForeignKeyJoinIntegrationTest.java new file mode 100644 index 0000000..60104c4 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/integration/KTableKTableForeignKeyJoinIntegrationTest.java @@ -0,0 +1,702 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.integration; + +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.serialization.StringDeserializer; +import org.apache.kafka.common.serialization.StringSerializer; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.TestInputTopic; +import org.apache.kafka.streams.TestOutputTopic; +import org.apache.kafka.streams.Topology; +import org.apache.kafka.streams.TopologyTestDriver; +import org.apache.kafka.streams.kstream.Consumed; +import org.apache.kafka.streams.kstream.KTable; +import org.apache.kafka.streams.kstream.Materialized; +import org.apache.kafka.streams.kstream.ValueJoiner; +import org.apache.kafka.streams.state.KeyValueStore; +import org.apache.kafka.streams.state.Stores; +import org.apache.kafka.streams.utils.UniqueTopicSerdeScope; +import org.apache.kafka.test.TestUtils; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TestName; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import java.util.Arrays; +import java.util.Collection; +import java.util.HashMap; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.Properties; +import java.util.function.Function; + +import static java.util.Collections.emptyMap; +import static org.apache.kafka.common.utils.Utils.mkEntry; +import static org.apache.kafka.common.utils.Utils.mkMap; +import static org.apache.kafka.common.utils.Utils.mkProperties; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.MatcherAssert.assertThat; + + +@RunWith(Parameterized.class) +public class KTableKTableForeignKeyJoinIntegrationTest { + + private static final String LEFT_TABLE = "left_table"; + private static final String RIGHT_TABLE = "right_table"; + private static final String OUTPUT = "output-topic"; + private static final String REJOIN_OUTPUT = "rejoin-output-topic"; + private final boolean leftJoin; + private final boolean materialized; + private final String optimization; + private final boolean rejoin; + + private Properties streamsConfig; + + public KTableKTableForeignKeyJoinIntegrationTest(final boolean leftJoin, + final String optimization, + final boolean materialized, + final boolean rejoin) { + this.rejoin = rejoin; + this.leftJoin = leftJoin; + this.materialized = materialized; + this.optimization = optimization; + } + + @Rule + public TestName testName = new TestName(); + + @Before + public void before() { + streamsConfig = mkProperties(mkMap( + mkEntry(StreamsConfig.STATE_DIR_CONFIG, TestUtils.tempDirectory().getPath()), + mkEntry(StreamsConfig.TOPOLOGY_OPTIMIZATION_CONFIG, optimization) + )); + } + + @Parameterized.Parameters(name = "leftJoin={0}, optimization={1}, materialized={2}, rejoin={3}") + public static Collection data() { + final List booleans = Arrays.asList(true, false); + final List optimizations = Arrays.asList(StreamsConfig.OPTIMIZE, StreamsConfig.NO_OPTIMIZATION); + return buildParameters(booleans, optimizations, booleans, booleans); + } + + private static Collection buildParameters(final List... 
argOptions) { + List result = new LinkedList<>(); + result.add(new Object[0]); + + for (final List argOption : argOptions) { + result = times(result, argOption); + } + + return result; + } + + private static List times(final List left, final List right) { + final List result = new LinkedList<>(); + for (final Object[] args : left) { + for (final Object rightElem : right) { + final Object[] resArgs = new Object[args.length + 1]; + System.arraycopy(args, 0, resArgs, 0, args.length); + resArgs[args.length] = rightElem; + result.add(resArgs); + } + } + return result; + } + + @Test + public void doJoinFromLeftThenDeleteLeftEntity() { + final Topology topology = getTopology(streamsConfig, materialized ? "store" : null, leftJoin, rejoin); + try (final TopologyTestDriver driver = new TopologyTestDriver(topology, streamsConfig)) { + final TestInputTopic right = driver.createInputTopic(RIGHT_TABLE, new StringSerializer(), new StringSerializer()); + final TestInputTopic left = driver.createInputTopic(LEFT_TABLE, new StringSerializer(), new StringSerializer()); + final TestOutputTopic outputTopic = driver.createOutputTopic(OUTPUT, new StringDeserializer(), new StringDeserializer()); + final TestOutputTopic rejoinOutputTopic = rejoin ? driver.createOutputTopic(REJOIN_OUTPUT, new StringDeserializer(), new StringDeserializer()) : null; + final KeyValueStore store = driver.getKeyValueStore("store"); + + // Pre-populate the RHS records. This test is all about what happens when we add/remove LHS records + right.pipeInput("rhs1", "rhsValue1"); + right.pipeInput("rhs2", "rhsValue2"); + right.pipeInput("rhs3", "rhsValue3"); // this unreferenced FK won't show up in any results + + assertThat( + outputTopic.readKeyValuesToMap(), + is(emptyMap()) + ); + if (rejoin) { + assertThat( + rejoinOutputTopic.readKeyValuesToMap(), + is(emptyMap()) + ); + } + if (materialized) { + assertThat( + asMap(store), + is(emptyMap()) + ); + } + + left.pipeInput("lhs1", "lhsValue1|rhs1"); + left.pipeInput("lhs2", "lhsValue2|rhs2"); + + { + final Map expected = mkMap( + mkEntry("lhs1", "(lhsValue1|rhs1,rhsValue1)"), + mkEntry("lhs2", "(lhsValue2|rhs2,rhsValue2)") + ); + assertThat( + outputTopic.readKeyValuesToMap(), + is(expected) + ); + if (rejoin) { + assertThat( + rejoinOutputTopic.readKeyValuesToMap(), + is(mkMap( + mkEntry("lhs1", "rejoin((lhsValue1|rhs1,rhsValue1),lhsValue1|rhs1)"), + mkEntry("lhs2", "rejoin((lhsValue2|rhs2,rhsValue2),lhsValue2|rhs2)") + )) + ); + } + if (materialized) { + assertThat( + asMap(store), + is(expected) + ); + } + } + + // Add another reference to an existing FK + left.pipeInput("lhs3", "lhsValue3|rhs1"); + { + assertThat( + outputTopic.readKeyValuesToMap(), + is(mkMap( + mkEntry("lhs3", "(lhsValue3|rhs1,rhsValue1)") + )) + ); + if (rejoin) { + assertThat( + rejoinOutputTopic.readKeyValuesToMap(), + is(mkMap( + mkEntry("lhs3", "rejoin((lhsValue3|rhs1,rhsValue1),lhsValue3|rhs1)") + )) + ); + } + if (materialized) { + assertThat( + asMap(store), + is(mkMap( + mkEntry("lhs1", "(lhsValue1|rhs1,rhsValue1)"), + mkEntry("lhs2", "(lhsValue2|rhs2,rhsValue2)"), + mkEntry("lhs3", "(lhsValue3|rhs1,rhsValue1)") + )) + ); + } + } + // Now delete one LHS entity such that one delete is propagated down to the output. 
+ + left.pipeInput("lhs1", (String) null); + assertThat( + outputTopic.readKeyValuesToMap(), + is(mkMap( + mkEntry("lhs1", null) + )) + ); + if (rejoin) { + assertThat( + rejoinOutputTopic.readKeyValuesToMap(), + is(mkMap( + mkEntry("lhs1", null) + )) + ); + } + if (materialized) { + assertThat( + asMap(store), + is(mkMap( + mkEntry("lhs2", "(lhsValue2|rhs2,rhsValue2)"), + mkEntry("lhs3", "(lhsValue3|rhs1,rhsValue1)") + )) + ); + } + } + } + + @Test + public void doJoinFromRightThenDeleteRightEntity() { + final Topology topology = getTopology(streamsConfig, materialized ? "store" : null, leftJoin, rejoin); + try (final TopologyTestDriver driver = new TopologyTestDriver(topology, streamsConfig)) { + final TestInputTopic right = driver.createInputTopic(RIGHT_TABLE, new StringSerializer(), new StringSerializer()); + final TestInputTopic left = driver.createInputTopic(LEFT_TABLE, new StringSerializer(), new StringSerializer()); + final TestOutputTopic outputTopic = driver.createOutputTopic(OUTPUT, new StringDeserializer(), new StringDeserializer()); + final KeyValueStore store = driver.getKeyValueStore("store"); + + // Pre-populate the LHS records. This test is all about what happens when we add/remove RHS records + left.pipeInput("lhs1", "lhsValue1|rhs1"); + left.pipeInput("lhs2", "lhsValue2|rhs2"); + left.pipeInput("lhs3", "lhsValue3|rhs1"); + + assertThat( + outputTopic.readKeyValuesToMap(), + is(leftJoin + ? mkMap(mkEntry("lhs1", "(lhsValue1|rhs1,null)"), + mkEntry("lhs2", "(lhsValue2|rhs2,null)"), + mkEntry("lhs3", "(lhsValue3|rhs1,null)")) + : emptyMap() + ) + ); + if (materialized) { + assertThat( + asMap(store), + is(leftJoin + ? mkMap(mkEntry("lhs1", "(lhsValue1|rhs1,null)"), + mkEntry("lhs2", "(lhsValue2|rhs2,null)"), + mkEntry("lhs3", "(lhsValue3|rhs1,null)")) + : emptyMap() + ) + ); + } + + right.pipeInput("rhs1", "rhsValue1"); + + assertThat( + outputTopic.readKeyValuesToMap(), + is(mkMap(mkEntry("lhs1", "(lhsValue1|rhs1,rhsValue1)"), + mkEntry("lhs3", "(lhsValue3|rhs1,rhsValue1)")) + ) + ); + if (materialized) { + assertThat( + asMap(store), + is(leftJoin + ? mkMap(mkEntry("lhs1", "(lhsValue1|rhs1,rhsValue1)"), + mkEntry("lhs2", "(lhsValue2|rhs2,null)"), + mkEntry("lhs3", "(lhsValue3|rhs1,rhsValue1)")) + + : mkMap(mkEntry("lhs1", "(lhsValue1|rhs1,rhsValue1)"), + mkEntry("lhs3", "(lhsValue3|rhs1,rhsValue1)")) + ) + ); + } + + right.pipeInput("rhs2", "rhsValue2"); + + assertThat( + outputTopic.readKeyValuesToMap(), + is(mkMap(mkEntry("lhs2", "(lhsValue2|rhs2,rhsValue2)"))) + ); + if (materialized) { + assertThat( + asMap(store), + is(mkMap(mkEntry("lhs1", "(lhsValue1|rhs1,rhsValue1)"), + mkEntry("lhs2", "(lhsValue2|rhs2,rhsValue2)"), + mkEntry("lhs3", "(lhsValue3|rhs1,rhsValue1)")) + ) + ); + } + + right.pipeInput("rhs3", "rhsValue3"); // this unreferenced FK won't show up in any results + + assertThat( + outputTopic.readKeyValuesToMap(), + is(emptyMap()) + ); + if (materialized) { + assertThat( + asMap(store), + is(mkMap(mkEntry("lhs1", "(lhsValue1|rhs1,rhsValue1)"), + mkEntry("lhs2", "(lhsValue2|rhs2,rhsValue2)"), + mkEntry("lhs3", "(lhsValue3|rhs1,rhsValue1)")) + ) + ); + } + + // Now delete the RHS entity such that all matching keys have deletes propagated. + right.pipeInput("rhs1", (String) null); + + assertThat( + outputTopic.readKeyValuesToMap(), + is(mkMap(mkEntry("lhs1", leftJoin ? "(lhsValue1|rhs1,null)" : null), + mkEntry("lhs3", leftJoin ? "(lhsValue3|rhs1,null)" : null)) + ) + ); + if (materialized) { + assertThat( + asMap(store), + is(leftJoin + ? 
mkMap(mkEntry("lhs1", "(lhsValue1|rhs1,null)"), + mkEntry("lhs2", "(lhsValue2|rhs2,rhsValue2)"), + mkEntry("lhs3", "(lhsValue3|rhs1,null)")) + + : mkMap(mkEntry("lhs2", "(lhsValue2|rhs2,rhsValue2)")) + ) + ); + } + } + } + + @Test + public void shouldEmitTombstoneWhenDeletingNonJoiningRecords() { + final Topology topology = getTopology(streamsConfig, materialized ? "store" : null, leftJoin, rejoin); + try (final TopologyTestDriver driver = new TopologyTestDriver(topology, streamsConfig)) { + final TestInputTopic left = driver.createInputTopic(LEFT_TABLE, new StringSerializer(), new StringSerializer()); + final TestOutputTopic outputTopic = driver.createOutputTopic(OUTPUT, new StringDeserializer(), new StringDeserializer()); + final KeyValueStore store = driver.getKeyValueStore("store"); + + left.pipeInput("lhs1", "lhsValue1|rhs1"); + + { + final Map expected = + leftJoin ? mkMap(mkEntry("lhs1", "(lhsValue1|rhs1,null)")) : emptyMap(); + assertThat( + outputTopic.readKeyValuesToMap(), + is(expected) + ); + if (materialized) { + assertThat( + asMap(store), + is(expected) + ); + } + } + + // Deleting a non-joining record produces an unnecessary tombstone for inner joins, because + // it's not possible to know whether a result was previously emitted. + // For the left join, the tombstone is necessary. + left.pipeInput("lhs1", (String) null); + { + assertThat( + outputTopic.readKeyValuesToMap(), + is(mkMap(mkEntry("lhs1", null))) + ); + if (materialized) { + assertThat( + asMap(store), + is(emptyMap()) + ); + } + } + + // Deleting a non-existing record is idempotent + left.pipeInput("lhs1", (String) null); + { + assertThat( + outputTopic.readKeyValuesToMap(), + is(emptyMap()) + ); + if (materialized) { + assertThat( + asMap(store), + is(emptyMap()) + ); + } + } + } + } + + @Test + public void shouldNotEmitTombstonesWhenDeletingNonExistingRecords() { + final Topology topology = getTopology(streamsConfig, materialized ? "store" : null, leftJoin, rejoin); + try (final TopologyTestDriver driver = new TopologyTestDriver(topology, streamsConfig)) { + final TestInputTopic left = driver.createInputTopic(LEFT_TABLE, new StringSerializer(), new StringSerializer()); + final TestOutputTopic outputTopic = driver.createOutputTopic(OUTPUT, new StringDeserializer(), new StringDeserializer()); + final KeyValueStore store = driver.getKeyValueStore("store"); + + // Deleting a record that never existed doesn't need to emit tombstones. + left.pipeInput("lhs1", (String) null); + { + assertThat( + outputTopic.readKeyValuesToMap(), + is(emptyMap()) + ); + if (materialized) { + assertThat( + asMap(store), + is(emptyMap()) + ); + } + } + } + } + + @Test + public void joinShouldProduceNullsWhenValueHasNonMatchingForeignKey() { + final Topology topology = getTopology(streamsConfig, materialized ? 
"store" : null, leftJoin, rejoin); + try (final TopologyTestDriver driver = new TopologyTestDriver(topology, streamsConfig)) { + final TestInputTopic right = driver.createInputTopic(RIGHT_TABLE, new StringSerializer(), new StringSerializer()); + final TestInputTopic left = driver.createInputTopic(LEFT_TABLE, new StringSerializer(), new StringSerializer()); + final TestOutputTopic outputTopic = driver.createOutputTopic(OUTPUT, new StringDeserializer(), new StringDeserializer()); + final KeyValueStore store = driver.getKeyValueStore("store"); + + left.pipeInput("lhs1", "lhsValue1|rhs1"); + // no output for a new inner join on a non-existent FK + // the left join of course emits the half-joined output + assertThat( + outputTopic.readKeyValuesToMap(), + is(leftJoin ? mkMap(mkEntry("lhs1", "(lhsValue1|rhs1,null)")) : emptyMap()) + ); + if (materialized) { + assertThat( + asMap(store), + is(leftJoin ? mkMap(mkEntry("lhs1", "(lhsValue1|rhs1,null)")) : emptyMap()) + ); + } + // "moving" our subscription to another non-existent FK results in an unnecessary tombstone for inner join, + // since it impossible to know whether the prior FK existed or not (and thus whether any results have + // previously been emitted) + // The left join emits a _necessary_ update (since the lhs record has actually changed) + left.pipeInput("lhs1", "lhsValue1|rhs2"); + assertThat( + outputTopic.readKeyValuesToMap(), + is(mkMap(mkEntry("lhs1", leftJoin ? "(lhsValue1|rhs2,null)" : null))) + ); + if (materialized) { + assertThat( + asMap(store), + is(leftJoin ? mkMap(mkEntry("lhs1", "(lhsValue1|rhs2,null)")) : emptyMap()) + ); + } + // of course, moving it again to yet another non-existent FK has the same effect + left.pipeInput("lhs1", "lhsValue1|rhs3"); + assertThat( + outputTopic.readKeyValuesToMap(), + is(mkMap(mkEntry("lhs1", leftJoin ? "(lhsValue1|rhs3,null)" : null))) + ); + if (materialized) { + assertThat( + asMap(store), + is(leftJoin ? mkMap(mkEntry("lhs1", "(lhsValue1|rhs3,null)")) : emptyMap()) + ); + } + + // Adding an RHS record now, so that we can demonstrate "moving" from a non-existent FK to an existent one + // This RHS key was previously referenced, but it's not referenced now, so adding this record should + // result in no changes whatsoever. + right.pipeInput("rhs1", "rhsValue1"); + assertThat( + outputTopic.readKeyValuesToMap(), + is(emptyMap()) + ); + if (materialized) { + assertThat( + asMap(store), + is(leftJoin ? mkMap(mkEntry("lhs1", "(lhsValue1|rhs3,null)")) : emptyMap()) + ); + } + + // now, we change to a FK that exists, and see the join completes + left.pipeInput("lhs1", "lhsValue1|rhs1"); + assertThat( + outputTopic.readKeyValuesToMap(), + is(mkMap( + mkEntry("lhs1", "(lhsValue1|rhs1,rhsValue1)") + )) + ); + if (materialized) { + assertThat( + asMap(store), + is(mkMap( + mkEntry("lhs1", "(lhsValue1|rhs1,rhsValue1)") + )) + ); + } + + // but if we update it again to a non-existent one, we'll get a tombstone for the inner join, and the + // left join updates appropriately. + left.pipeInput("lhs1", "lhsValue1|rhs2"); + assertThat( + outputTopic.readKeyValuesToMap(), + is(mkMap( + mkEntry("lhs1", leftJoin ? "(lhsValue1|rhs2,null)" : null) + )) + ); + if (materialized) { + assertThat( + asMap(store), + is(leftJoin ? mkMap(mkEntry("lhs1", "(lhsValue1|rhs2,null)")) : emptyMap()) + ); + } + } + } + + @Test + public void shouldUnsubscribeOldForeignKeyIfLeftSideIsUpdated() { + final Topology topology = getTopology(streamsConfig, materialized ? 
"store" : null, leftJoin, rejoin); + try (final TopologyTestDriver driver = new TopologyTestDriver(topology, streamsConfig)) { + final TestInputTopic right = driver.createInputTopic(RIGHT_TABLE, new StringSerializer(), new StringSerializer()); + final TestInputTopic left = driver.createInputTopic(LEFT_TABLE, new StringSerializer(), new StringSerializer()); + final TestOutputTopic outputTopic = driver.createOutputTopic(OUTPUT, new StringDeserializer(), new StringDeserializer()); + final KeyValueStore store = driver.getKeyValueStore("store"); + + // Pre-populate the RHS records. This test is all about what happens when we change LHS records foreign key reference + // then populate update on RHS + right.pipeInput("rhs1", "rhsValue1"); + right.pipeInput("rhs2", "rhsValue2"); + + assertThat( + outputTopic.readKeyValuesToMap(), + is(emptyMap()) + ); + if (materialized) { + assertThat( + asMap(store), + is(emptyMap()) + ); + } + + left.pipeInput("lhs1", "lhsValue1|rhs1"); + { + final Map expected = mkMap( + mkEntry("lhs1", "(lhsValue1|rhs1,rhsValue1)") + ); + assertThat( + outputTopic.readKeyValuesToMap(), + is(expected) + ); + if (materialized) { + assertThat( + asMap(store), + is(expected) + ); + } + } + + // Change LHS foreign key reference + left.pipeInput("lhs1", "lhsValue1|rhs2"); + { + final Map expected = mkMap( + mkEntry("lhs1", "(lhsValue1|rhs2,rhsValue2)") + ); + assertThat( + outputTopic.readKeyValuesToMap(), + is(expected) + ); + if (materialized) { + assertThat( + asMap(store), + is(expected) + ); + } + } + + // Populate RHS update on old LHS foreign key ref + right.pipeInput("rhs1", "rhsValue1Delta"); + { + assertThat( + outputTopic.readKeyValuesToMap(), + is(emptyMap()) + ); + if (materialized) { + assertThat( + asMap(store), + is(mkMap( + mkEntry("lhs1", "(lhsValue1|rhs2,rhsValue2)") + )) + ); + } + } + } + } + + private static Map asMap(final KeyValueStore store) { + final HashMap result = new HashMap<>(); + store.all().forEachRemaining(kv -> result.put(kv.key, kv.value)); + return result; + } + + private static Topology getTopology(final Properties streamsConfig, + final String queryableStoreName, + final boolean leftJoin, + final boolean rejoin) { + final UniqueTopicSerdeScope serdeScope = new UniqueTopicSerdeScope(); + final StreamsBuilder builder = new StreamsBuilder(); + + final KTable left = builder.table( + LEFT_TABLE, + Consumed.with(serdeScope.decorateSerde(Serdes.String(), streamsConfig, true), + serdeScope.decorateSerde(Serdes.String(), streamsConfig, false)) + ); + final KTable right = builder.table( + RIGHT_TABLE, + Consumed.with(serdeScope.decorateSerde(Serdes.String(), streamsConfig, true), + serdeScope.decorateSerde(Serdes.String(), streamsConfig, false)) + ); + + final Function extractor = value -> value.split("\\|")[1]; + final ValueJoiner joiner = (value1, value2) -> "(" + value1 + "," + value2 + ")"; + final ValueJoiner rejoiner = rejoin ? (value1, value2) -> "rejoin(" + value1 + "," + value2 + ")" : null; + + // the cache suppresses some of the unnecessary tombstones we want to make assertions about + final Materialized> mainMaterialized = + queryableStoreName == null ? + Materialized.>with( + null, + serdeScope.decorateSerde(Serdes.String(), streamsConfig, false) + ).withCachingDisabled() : + Materialized.as(Stores.inMemoryKeyValueStore(queryableStoreName)) + .withValueSerde(serdeScope.decorateSerde(Serdes.String(), streamsConfig, false)) + .withCachingDisabled(); + + final Materialized> rejoinMaterialized = + !rejoin ? null : + queryableStoreName == null ? 
+ Materialized.with(null, serdeScope.decorateSerde(Serdes.String(), streamsConfig, false)) : + // not actually going to query this store, but we need to force materialization here + // to really test this confuguration + Materialized.as(Stores.inMemoryKeyValueStore(queryableStoreName + "-rejoin")) + .withValueSerde(serdeScope.decorateSerde(Serdes.String(), streamsConfig, false)) + // the cache suppresses some of the unnecessary tombstones we want to make assertions about + .withCachingDisabled(); + + if (leftJoin) { + final KTable fkJoin = + left.leftJoin(right, extractor, joiner, mainMaterialized); + + fkJoin.toStream() + .to(OUTPUT); + + // also make sure the FK join is set up right for downstream operations that require materialization + if (rejoin) { + fkJoin.leftJoin(left, rejoiner, rejoinMaterialized) + .toStream() + .to(REJOIN_OUTPUT); + } + } else { + final KTable fkJoin = left.join(right, extractor, joiner, mainMaterialized); + + fkJoin + .toStream() + .to(OUTPUT); + + // also make sure the FK join is set up right for downstream operations that require materialization + if (rejoin) { + fkJoin.join(left, rejoiner, rejoinMaterialized) + .toStream() + .to(REJOIN_OUTPUT); + } + } + + + return builder.build(streamsConfig); + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/integration/KTableKTableForeignKeyJoinMaterializationIntegrationTest.java b/streams/src/test/java/org/apache/kafka/streams/integration/KTableKTableForeignKeyJoinMaterializationIntegrationTest.java new file mode 100644 index 0000000..778f507 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/integration/KTableKTableForeignKeyJoinMaterializationIntegrationTest.java @@ -0,0 +1,205 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.integration; + +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.serialization.StringDeserializer; +import org.apache.kafka.common.serialization.StringSerializer; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.TestInputTopic; +import org.apache.kafka.streams.TestOutputTopic; +import org.apache.kafka.streams.Topology; +import org.apache.kafka.streams.TopologyTestDriver; +import org.apache.kafka.streams.kstream.Consumed; +import org.apache.kafka.streams.kstream.KTable; +import org.apache.kafka.streams.kstream.Materialized; +import org.apache.kafka.streams.kstream.Produced; +import org.apache.kafka.streams.kstream.ValueJoiner; +import org.apache.kafka.streams.state.KeyValueStore; +import org.apache.kafka.test.TestUtils; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TestName; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import java.util.Arrays; +import java.util.Collection; +import java.util.HashMap; +import java.util.Map; +import java.util.Properties; +import java.util.function.Function; + +import static java.util.Collections.emptyMap; +import static org.apache.kafka.common.utils.Utils.mkEntry; +import static org.apache.kafka.common.utils.Utils.mkMap; +import static org.apache.kafka.common.utils.Utils.mkProperties; +import static org.apache.kafka.streams.integration.utils.IntegrationTestUtils.safeUniqueTestName; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.MatcherAssert.assertThat; + + +@RunWith(Parameterized.class) +public class KTableKTableForeignKeyJoinMaterializationIntegrationTest { + + private static final String LEFT_TABLE = "left_table"; + private static final String RIGHT_TABLE = "right_table"; + private static final String OUTPUT = "output-topic"; + private final boolean materialized; + private final boolean queryable; + + private Properties streamsConfig; + + public KTableKTableForeignKeyJoinMaterializationIntegrationTest(final boolean materialized, final boolean queryable) { + this.materialized = materialized; + this.queryable = queryable; + } + + @Rule + public TestName testName = new TestName(); + + @Before + public void before() { + final String safeTestName = safeUniqueTestName(getClass(), testName); + streamsConfig = mkProperties(mkMap( + mkEntry(StreamsConfig.STATE_DIR_CONFIG, TestUtils.tempDirectory().getPath()) + )); + } + + + @Parameterized.Parameters(name = "materialized={0}, queryable={1}") + public static Collection data() { + return Arrays.asList( + new Object[] {false, false}, + new Object[] {true, false}, + new Object[] {true, true} + ); + } + + @Test + public void shouldEmitTombstoneWhenDeletingNonJoiningRecords() { + final Topology topology = getTopology(streamsConfig, "store"); + try (final TopologyTestDriver driver = new TopologyTestDriver(topology, streamsConfig)) { + final TestInputTopic left = driver.createInputTopic(LEFT_TABLE, new StringSerializer(), new StringSerializer()); + final TestOutputTopic outputTopic = driver.createOutputTopic(OUTPUT, new StringDeserializer(), new StringDeserializer()); + final KeyValueStore store = driver.getKeyValueStore("store"); + + left.pipeInput("lhs1", "lhsValue1|rhs1"); + + assertThat( + outputTopic.readKeyValuesToMap(), + is(emptyMap()) + ); + if (materialized && queryable) { + assertThat( + asMap(store), + 
is(emptyMap()) + ); + } + + // Deleting a non-joining record produces an unnecessary tombstone for inner joins, because + // it's not possible to know whether a result was previously emitted. + left.pipeInput("lhs1", (String) null); + { + if (materialized && queryable) { + // in only this specific case, the record cache will actually be activated and + // suppress the unnecessary tombstone. This is because the cache is able to determine + // for sure that there has never been a previous result. (Because the "old" and "new" values + // are both null, and the underlying store is also missing the record in question). + assertThat( + outputTopic.readKeyValuesToMap(), + is(emptyMap()) + ); + + assertThat( + asMap(store), + is(emptyMap()) + ); + } else { + assertThat( + outputTopic.readKeyValuesToMap(), + is(mkMap(mkEntry("lhs1", null))) + ); + } + } + + // Deleting a non-existing record is idempotent + left.pipeInput("lhs1", (String) null); + { + assertThat( + outputTopic.readKeyValuesToMap(), + is(emptyMap()) + ); + if (materialized && queryable) { + assertThat( + asMap(store), + is(emptyMap()) + ); + } + } + } + } + + private static Map asMap(final KeyValueStore store) { + final HashMap result = new HashMap<>(); + store.all().forEachRemaining(kv -> result.put(kv.key, kv.value)); + return result; + } + + private Topology getTopology(final Properties streamsConfig, + final String queryableStoreName) { + final StreamsBuilder builder = new StreamsBuilder(); + + final KTable left = builder.table(LEFT_TABLE, Consumed.with(Serdes.String(), Serdes.String())); + final KTable right = builder.table(RIGHT_TABLE, Consumed.with(Serdes.String(), Serdes.String())); + + final Function extractor = value -> value.split("\\|")[1]; + final ValueJoiner joiner = (value1, value2) -> "(" + value1 + "," + value2 + ")"; + + final Materialized> materialized; + if (queryable) { + materialized = Materialized.>as(queryableStoreName).withValueSerde(Serdes.String()); + } else { + materialized = Materialized.with(null, Serdes.String()); + } + + final KTable joinResult; + if (this.materialized) { + joinResult = left.join( + right, + extractor, + joiner, + materialized + ); + } else { + joinResult = left.join( + right, + extractor, + joiner + ); + } + + joinResult + .toStream() + .to(OUTPUT, Produced.with(null, Serdes.String())); + + return builder.build(streamsConfig); + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/integration/KTableSourceTopicRestartIntegrationTest.java b/streams/src/test/java/org/apache/kafka/streams/integration/KTableSourceTopicRestartIntegrationTest.java new file mode 100644 index 0000000..6d50ea9 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/integration/KTableSourceTopicRestartIntegrationTest.java @@ -0,0 +1,265 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.streams.integration; + +import org.apache.kafka.clients.consumer.ConsumerConfig; +import org.apache.kafka.clients.producer.ProducerConfig; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.serialization.StringSerializer; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.streams.KafkaStreams; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.integration.utils.EmbeddedKafkaCluster; +import org.apache.kafka.streams.integration.utils.IntegrationTestUtils; +import org.apache.kafka.streams.kstream.KTable; +import org.apache.kafka.streams.kstream.Materialized; +import org.apache.kafka.streams.processor.StateRestoreListener; +import org.apache.kafka.streams.processor.WallclockTimestampExtractor; +import org.apache.kafka.test.IntegrationTest; +import org.apache.kafka.test.TestUtils; +import org.junit.After; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Rule; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.rules.TestName; + +import java.io.IOException; +import java.time.Duration; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Properties; +import java.util.concurrent.ConcurrentHashMap; + +@Category({IntegrationTest.class}) +public class KTableSourceTopicRestartIntegrationTest { + private static final int NUM_BROKERS = 3; + private static final String SOURCE_TOPIC = "source-topic"; + private static final Properties PRODUCER_CONFIG = new Properties(); + private static final Properties STREAMS_CONFIG = new Properties(); + + public static final EmbeddedKafkaCluster CLUSTER = new EmbeddedKafkaCluster(NUM_BROKERS); + + @BeforeClass + public static void startCluster() throws IOException { + CLUSTER.start(); + STREAMS_CONFIG.put(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, CLUSTER.bootstrapServers()); + STREAMS_CONFIG.put(StreamsConfig.DEFAULT_KEY_SERDE_CLASS_CONFIG, Serdes.String().getClass().getName()); + STREAMS_CONFIG.put(StreamsConfig.DEFAULT_VALUE_SERDE_CLASS_CONFIG, Serdes.String().getClass().getName()); + STREAMS_CONFIG.put(StreamsConfig.STATE_DIR_CONFIG, TestUtils.tempDirectory().getPath()); + STREAMS_CONFIG.put(StreamsConfig.CACHE_MAX_BYTES_BUFFERING_CONFIG, 0); + STREAMS_CONFIG.put(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG, 5L); + STREAMS_CONFIG.put(StreamsConfig.DEFAULT_TIMESTAMP_EXTRACTOR_CLASS_CONFIG, WallclockTimestampExtractor.class); + STREAMS_CONFIG.put(ConsumerConfig.SESSION_TIMEOUT_MS_CONFIG, 1000); + STREAMS_CONFIG.put(ConsumerConfig.HEARTBEAT_INTERVAL_MS_CONFIG, 300); + + PRODUCER_CONFIG.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, CLUSTER.bootstrapServers()); + PRODUCER_CONFIG.put(ProducerConfig.ACKS_CONFIG, "all"); + PRODUCER_CONFIG.put(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG, StringSerializer.class); + PRODUCER_CONFIG.put(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG, StringSerializer.class); + } + + @AfterClass + public static void closeCluster() { + CLUSTER.stop(); + } + + + private final Time time = CLUSTER.time; + private final StreamsBuilder streamsBuilder = new StreamsBuilder(); + private final Map readKeyValues = new ConcurrentHashMap<>(); + + 
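+    // Every record the source KTable forwards downstream is captured into readKeyValues by the
+    // foreach() installed in before(), so the tests can wait until the map matches the expected snapshot.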
private String sourceTopic; + private KafkaStreams streams; + private Map expectedInitialResultsMap; + private Map expectedResultsWithDataWrittenDuringRestoreMap; + + @Rule + public TestName testName = new TestName(); + + @Before + public void before() throws Exception { + sourceTopic = SOURCE_TOPIC + "-" + testName.getMethodName(); + CLUSTER.createTopic(sourceTopic); + + STREAMS_CONFIG.put(StreamsConfig.APPLICATION_ID_CONFIG, IntegrationTestUtils.safeUniqueTestName(getClass(), testName)); + + final KTable kTable = streamsBuilder.table(sourceTopic, Materialized.as("store")); + kTable.toStream().foreach(readKeyValues::put); + + expectedInitialResultsMap = createExpectedResultsMap("a", "b", "c"); + expectedResultsWithDataWrittenDuringRestoreMap = createExpectedResultsMap("a", "b", "c", "d", "f", "g", "h"); + } + + @After + public void after() throws Exception { + IntegrationTestUtils.purgeLocalStreamsState(STREAMS_CONFIG); + } + + @Test + public void shouldRestoreAndProgressWhenTopicWrittenToDuringRestorationWithEosDisabled() throws Exception { + try { + streams = new KafkaStreams(streamsBuilder.build(), STREAMS_CONFIG); + streams.start(); + + produceKeyValues("a", "b", "c"); + + assertNumberValuesRead(readKeyValues, expectedInitialResultsMap, "Table did not read all values"); + + streams.close(); + streams = new KafkaStreams(streamsBuilder.build(), STREAMS_CONFIG); + // the state restore listener will append one record to the log + streams.setGlobalStateRestoreListener(new UpdatingSourceTopicOnRestoreStartStateRestoreListener()); + streams.start(); + + produceKeyValues("f", "g", "h"); + + assertNumberValuesRead( + readKeyValues, + expectedResultsWithDataWrittenDuringRestoreMap, + "Table did not get all values after restart"); + } finally { + streams.close(Duration.ofSeconds(5)); + } + } + + @SuppressWarnings("deprecation") + @Test + public void shouldRestoreAndProgressWhenTopicWrittenToDuringRestorationWithEosAlphaEnabled() throws Exception { + STREAMS_CONFIG.put(StreamsConfig.PROCESSING_GUARANTEE_CONFIG, StreamsConfig.EXACTLY_ONCE); + shouldRestoreAndProgressWhenTopicWrittenToDuringRestorationWithEosEnabled(); + } + + @Test + public void shouldRestoreAndProgressWhenTopicWrittenToDuringRestorationWithEosV2Enabled() throws Exception { + STREAMS_CONFIG.put(StreamsConfig.PROCESSING_GUARANTEE_CONFIG, StreamsConfig.EXACTLY_ONCE_V2); + shouldRestoreAndProgressWhenTopicWrittenToDuringRestorationWithEosEnabled(); + } + + private void shouldRestoreAndProgressWhenTopicWrittenToDuringRestorationWithEosEnabled() throws Exception { + try { + streams = new KafkaStreams(streamsBuilder.build(), STREAMS_CONFIG); + streams.start(); + + produceKeyValues("a", "b", "c"); + + assertNumberValuesRead(readKeyValues, expectedInitialResultsMap, "Table did not read all values"); + + streams.close(); + streams = new KafkaStreams(streamsBuilder.build(), STREAMS_CONFIG); + // the state restore listener will append one record to the log + streams.setGlobalStateRestoreListener(new UpdatingSourceTopicOnRestoreStartStateRestoreListener()); + streams.start(); + + produceKeyValues("f", "g", "h"); + + assertNumberValuesRead( + readKeyValues, + expectedResultsWithDataWrittenDuringRestoreMap, + "Table did not get all values after restart"); + } finally { + streams.close(Duration.ofSeconds(5)); + } + } + + @Test + public void shouldRestoreAndProgressWhenTopicNotWrittenToDuringRestoration() throws Exception { + try { + streams = new KafkaStreams(streamsBuilder.build(), STREAMS_CONFIG); + streams.start(); + + produceKeyValues("a", 
"b", "c"); + + assertNumberValuesRead(readKeyValues, expectedInitialResultsMap, "Table did not read all values"); + + streams.close(); + streams = new KafkaStreams(streamsBuilder.build(), STREAMS_CONFIG); + streams.start(); + + produceKeyValues("f", "g", "h"); + + final Map expectedValues = createExpectedResultsMap("a", "b", "c", "f", "g", "h"); + + assertNumberValuesRead(readKeyValues, expectedValues, "Table did not get all values after restart"); + } finally { + streams.close(Duration.ofSeconds(5)); + } + } + + private void assertNumberValuesRead(final Map valueMap, + final Map expectedMap, + final String errorMessage) throws InterruptedException { + TestUtils.waitForCondition( + () -> valueMap.equals(expectedMap), + 30 * 1000L, + errorMessage); + } + + private void produceKeyValues(final String... keys) { + final List> keyValueList = new ArrayList<>(); + + for (final String key : keys) { + keyValueList.add(new KeyValue<>(key, key + "1")); + } + + IntegrationTestUtils.produceKeyValuesSynchronously(sourceTopic, + keyValueList, + PRODUCER_CONFIG, + time); + } + + private Map createExpectedResultsMap(final String... keys) { + final Map expectedMap = new HashMap<>(); + for (final String key : keys) { + expectedMap.put(key, key + "1"); + } + return expectedMap; + } + + private class UpdatingSourceTopicOnRestoreStartStateRestoreListener implements StateRestoreListener { + + @Override + public void onRestoreStart(final TopicPartition topicPartition, + final String storeName, + final long startingOffset, + final long endingOffset) { + produceKeyValues("d"); + } + + @Override + public void onBatchRestored(final TopicPartition topicPartition, + final String storeName, + final long batchEndOffset, + final long numRestored) { + } + + @Override + public void onRestoreEnd(final TopicPartition topicPartition, + final String storeName, + final long totalRestored) { + } + } + +} diff --git a/streams/src/test/java/org/apache/kafka/streams/integration/LagFetchIntegrationTest.java b/streams/src/test/java/org/apache/kafka/streams/integration/LagFetchIntegrationTest.java new file mode 100644 index 0000000..6a02496 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/integration/LagFetchIntegrationTest.java @@ -0,0 +1,368 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.integration; + +import kafka.utils.MockTime; +import org.apache.kafka.clients.consumer.ConsumerConfig; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.serialization.LongDeserializer; +import org.apache.kafka.common.serialization.LongSerializer; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.serialization.StringDeserializer; +import org.apache.kafka.common.serialization.StringSerializer; +import org.apache.kafka.streams.KafkaStreams; +import org.apache.kafka.streams.KafkaStreamsWrapper; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.LagInfo; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.integration.utils.EmbeddedKafkaCluster; +import org.apache.kafka.streams.integration.utils.IntegrationTestUtils; +import org.apache.kafka.streams.kstream.KTable; +import org.apache.kafka.streams.kstream.Materialized; +import org.apache.kafka.streams.processor.StateRestoreListener; +import org.apache.kafka.streams.processor.internals.StreamThread; +import org.apache.kafka.streams.processor.internals.assignment.FallbackPriorTaskAssignor; +import org.apache.kafka.test.IntegrationTest; +import org.apache.kafka.test.TestUtils; +import org.junit.After; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Rule; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.rules.TestName; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.File; +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.time.Duration; +import java.util.ArrayList; +import java.util.Collections; +import java.util.Comparator; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Properties; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.CyclicBarrier; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; + +import static org.apache.kafka.common.utils.Utils.mkSet; +import static org.apache.kafka.streams.integration.utils.IntegrationTestUtils.safeUniqueTestName; +import static org.apache.kafka.streams.integration.utils.IntegrationTestUtils.startApplicationAndWaitUntilRunning; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.core.IsEqual.equalTo; +import static org.junit.Assert.assertTrue; + +@Category({IntegrationTest.class}) +public class LagFetchIntegrationTest { + + public static final EmbeddedKafkaCluster CLUSTER = new EmbeddedKafkaCluster(1); + + @BeforeClass + public static void startCluster() throws IOException { + CLUSTER.start(); + } + + @AfterClass + public static void closeCluster() { + CLUSTER.stop(); + } + + + private static final long WAIT_TIMEOUT_MS = 120000; + private static final Logger LOG = LoggerFactory.getLogger(LagFetchIntegrationTest.class); + + private final MockTime mockTime = CLUSTER.time; + private Properties streamsConfiguration; + private Properties consumerConfiguration; + private String inputTopicName; + private String outputTopicName; + private String stateStoreName; + + @Rule + public TestName testName = new TestName(); + + @Before + public void before() { + final String safeTestName = safeUniqueTestName(getClass(), testName); + inputTopicName = "input-topic-" + safeTestName; + outputTopicName = 
"output-topic-" + safeTestName; + stateStoreName = "lagfetch-test-store" + safeTestName; + + streamsConfiguration = new Properties(); + streamsConfiguration.put(StreamsConfig.APPLICATION_ID_CONFIG, "app-" + safeTestName); + streamsConfiguration.put(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, CLUSTER.bootstrapServers()); + streamsConfiguration.put(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "earliest"); + streamsConfiguration.put(StreamsConfig.DEFAULT_KEY_SERDE_CLASS_CONFIG, Serdes.String().getClass()); + streamsConfiguration.put(StreamsConfig.DEFAULT_VALUE_SERDE_CLASS_CONFIG, Serdes.String().getClass()); + streamsConfiguration.put(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG, 100L); + + consumerConfiguration = new Properties(); + consumerConfiguration.setProperty(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, CLUSTER.bootstrapServers()); + consumerConfiguration.setProperty(ConsumerConfig.GROUP_ID_CONFIG, "group-" + safeTestName); + consumerConfiguration.setProperty(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "earliest"); + consumerConfiguration.setProperty(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, StringDeserializer.class.getName()); + consumerConfiguration.setProperty(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, LongDeserializer.class.getName()); + } + + @After + public void shutdown() throws Exception { + IntegrationTestUtils.purgeLocalStreamsState(streamsConfiguration); + } + + private Map> getFirstNonEmptyLagMap(final KafkaStreams streams) throws InterruptedException { + final Map> offsetLagInfoMap = new HashMap<>(); + TestUtils.waitForCondition(() -> { + final Map> lagMap = streams.allLocalStorePartitionLags(); + if (lagMap.size() > 0) { + offsetLagInfoMap.putAll(lagMap); + } + return lagMap.size() > 0; + }, WAIT_TIMEOUT_MS, "Should obtain non-empty lag information eventually"); + return offsetLagInfoMap; + } + + private void shouldFetchLagsDuringRebalancing(final String optimization) throws Exception { + final CountDownLatch latchTillActiveIsRunning = new CountDownLatch(1); + final CountDownLatch latchTillStandbyIsRunning = new CountDownLatch(1); + final CountDownLatch latchTillStandbyHasPartitionsAssigned = new CountDownLatch(1); + final CyclicBarrier lagCheckBarrier = new CyclicBarrier(2); + final List streamsList = new ArrayList<>(); + + IntegrationTestUtils.produceKeyValuesSynchronously( + inputTopicName, + mkSet(new KeyValue<>("k1", 1L), new KeyValue<>("k2", 2L), new KeyValue<>("k3", 3L), new KeyValue<>("k4", 4L), new KeyValue<>("k5", 5L)), + TestUtils.producerConfig( + CLUSTER.bootstrapServers(), + StringSerializer.class, + LongSerializer.class, + new Properties()), + mockTime); + + // create stream threads + for (int i = 0; i < 2; i++) { + final Properties props = (Properties) streamsConfiguration.clone(); + // this test relies on the second instance getting the standby, so we specify + // an assignor with this contract. 
+ props.put(StreamsConfig.InternalConfig.INTERNAL_TASK_ASSIGNOR_CLASS, FallbackPriorTaskAssignor.class.getName()); + props.put(StreamsConfig.APPLICATION_SERVER_CONFIG, "localhost:" + i); + props.put(StreamsConfig.CLIENT_ID_CONFIG, "instance-" + i); + props.put(StreamsConfig.TOPOLOGY_OPTIMIZATION_CONFIG, optimization); + props.put(StreamsConfig.NUM_STANDBY_REPLICAS_CONFIG, 1); + props.put(StreamsConfig.STATE_DIR_CONFIG, TestUtils.tempDirectory(stateStoreName + i).getAbsolutePath()); + + final StreamsBuilder builder = new StreamsBuilder(); + final KTable t1 = builder.table(inputTopicName, Materialized.as(stateStoreName)); + t1.toStream().to(outputTopicName); + final KafkaStreamsWrapper streams = new KafkaStreamsWrapper(builder.build(props), props); + streamsList.add(streams); + } + + final KafkaStreamsWrapper activeStreams = streamsList.get(0); + final KafkaStreamsWrapper standbyStreams = streamsList.get(1); + activeStreams.setStreamThreadStateListener((thread, newState, oldState) -> { + if (newState == StreamThread.State.RUNNING) { + latchTillActiveIsRunning.countDown(); + } + }); + standbyStreams.setStreamThreadStateListener((thread, newState, oldState) -> { + if (oldState == StreamThread.State.PARTITIONS_ASSIGNED && newState == StreamThread.State.RUNNING) { + latchTillStandbyHasPartitionsAssigned.countDown(); + try { + lagCheckBarrier.await(60, TimeUnit.SECONDS); + } catch (final Exception e) { + throw new RuntimeException(e); + } + } else if (newState == StreamThread.State.RUNNING) { + latchTillStandbyIsRunning.countDown(); + } + }); + + try { + // First start up the active. + TestUtils.waitForCondition(() -> activeStreams.allLocalStorePartitionLags().size() == 0, + WAIT_TIMEOUT_MS, + "Should see empty lag map before streams is started."); + activeStreams.start(); + latchTillActiveIsRunning.await(60, TimeUnit.SECONDS); + + IntegrationTestUtils.waitUntilMinValuesRecordsReceived( + consumerConfiguration, + outputTopicName, + 5, + WAIT_TIMEOUT_MS); + // Check the active reports proper lag values. 
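+            // allLocalStorePartitionLags() keys the outer map by store name and the inner map by
+            // partition; each LagInfo exposes currentOffsetPosition, endOffsetPosition and their
+            // difference, offsetLag. Having consumed all 5 records, the active should be caught up.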
+            Map<String, Map<Integer, LagInfo>> offsetLagInfoMap = getFirstNonEmptyLagMap(activeStreams);
+            assertThat(offsetLagInfoMap.size(), equalTo(1));
+            assertThat(offsetLagInfoMap.keySet(), equalTo(mkSet(stateStoreName)));
+            assertThat(offsetLagInfoMap.get(stateStoreName).size(), equalTo(1));
+            LagInfo lagInfo = offsetLagInfoMap.get(stateStoreName).get(0);
+            assertThat(lagInfo.currentOffsetPosition(), equalTo(5L));
+            assertThat(lagInfo.endOffsetPosition(), equalTo(5L));
+            assertThat(lagInfo.offsetLag(), equalTo(0L));
+
+            // start up the standby & make it pause right after it has partitions assigned
+            standbyStreams.start();
+            latchTillStandbyHasPartitionsAssigned.await(60, TimeUnit.SECONDS);
+            offsetLagInfoMap = getFirstNonEmptyLagMap(standbyStreams);
+            assertThat(offsetLagInfoMap.size(), equalTo(1));
+            assertThat(offsetLagInfoMap.keySet(), equalTo(mkSet(stateStoreName)));
+            assertThat(offsetLagInfoMap.get(stateStoreName).size(), equalTo(1));
+            lagInfo = offsetLagInfoMap.get(stateStoreName).get(0);
+            assertThat(lagInfo.currentOffsetPosition(), equalTo(0L));
+            assertThat(lagInfo.endOffsetPosition(), equalTo(5L));
+            assertThat(lagInfo.offsetLag(), equalTo(5L));
+            // the standby thread won't proceed to RUNNING before this barrier is crossed
+            lagCheckBarrier.await(60, TimeUnit.SECONDS);
+
+            // wait till the lag goes down to 0 on the standby
+            TestUtils.waitForCondition(() -> standbyStreams.allLocalStorePartitionLags().get(stateStoreName).get(0).offsetLag() == 0,
+                WAIT_TIMEOUT_MS,
+                "Standby should eventually catch up and have zero lag.");
+        } finally {
+            for (final KafkaStreams streams : streamsList) {
+                streams.close();
+            }
+        }
+    }
+
+    @Test
+    public void shouldFetchLagsDuringRebalancingWithOptimization() throws Exception {
+        shouldFetchLagsDuringRebalancing(StreamsConfig.OPTIMIZE);
+    }
+
+    @Test
+    public void shouldFetchLagsDuringRebalancingWithNoOptimization() throws Exception {
+        shouldFetchLagsDuringRebalancing(StreamsConfig.NO_OPTIMIZATION);
+    }
+
+    @Test
+    public void shouldFetchLagsDuringRestoration() throws Exception {
+        IntegrationTestUtils.produceKeyValuesSynchronously(
+            inputTopicName,
+            mkSet(new KeyValue<>("k1", 1L), new KeyValue<>("k2", 2L), new KeyValue<>("k3", 3L), new KeyValue<>("k4", 4L), new KeyValue<>("k5", 5L)),
+            TestUtils.producerConfig(
+                CLUSTER.bootstrapServers(),
+                StringSerializer.class,
+                LongSerializer.class,
+                new Properties()),
+            mockTime);
+
+        // create stream threads
+        final Properties props = (Properties) streamsConfiguration.clone();
+        final File stateDir = TestUtils.tempDirectory(stateStoreName + "0");
+        props.put(StreamsConfig.APPLICATION_SERVER_CONFIG, "localhost:0");
+        props.put(StreamsConfig.CLIENT_ID_CONFIG, "instance-0");
+        props.put(StreamsConfig.STATE_DIR_CONFIG, stateDir.getAbsolutePath());
+
+        final StreamsBuilder builder = new StreamsBuilder();
+        final KTable t1 = builder.table(inputTopicName, Materialized.as(stateStoreName));
+        t1.toStream().to(outputTopicName);
+        final KafkaStreams streams = new KafkaStreams(builder.build(), props);
+
+        try {
+            // First start up the active.
+            TestUtils.waitForCondition(() -> streams.allLocalStorePartitionLags().size() == 0,
+                WAIT_TIMEOUT_MS,
+                "Should see empty lag map before streams is started.");
+
+            // Get the instance to fully catch up and reach RUNNING state
+            startApplicationAndWaitUntilRunning(Collections.singletonList(streams), Duration.ofSeconds(60));
+            IntegrationTestUtils.waitUntilMinValuesRecordsReceived(
+                consumerConfiguration,
+                outputTopicName,
+                5,
+                WAIT_TIMEOUT_MS);
+
+            // check for proper lag values.
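+            // Poll until the single store partition reports a fully caught-up position
+            // (currentOffsetPosition == endOffsetPosition == 5, offsetLag == 0) and remember that
+            // LagInfo so it can be compared with the lag reported once restoration completes.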
+            final AtomicReference<LagInfo> zeroLagRef = new AtomicReference<>();
+            TestUtils.waitForCondition(() -> {
+                final Map<String, Map<Integer, LagInfo>> offsetLagInfoMap = streams.allLocalStorePartitionLags();
+                assertThat(offsetLagInfoMap.size(), equalTo(1));
+                assertThat(offsetLagInfoMap.keySet(), equalTo(mkSet(stateStoreName)));
+                assertThat(offsetLagInfoMap.get(stateStoreName).size(), equalTo(1));
+
+                final LagInfo zeroLagInfo = offsetLagInfoMap.get(stateStoreName).get(0);
+                assertThat(zeroLagInfo.currentOffsetPosition(), equalTo(5L));
+                assertThat(zeroLagInfo.endOffsetPosition(), equalTo(5L));
+                assertThat(zeroLagInfo.offsetLag(), equalTo(0L));
+                zeroLagRef.set(zeroLagInfo);
+                return true;
+            }, WAIT_TIMEOUT_MS, "Eventually should reach zero lag.");
+
+            // Kill instance, delete state to force restoration.
+            assertThat("Streams instance did not close within timeout", streams.close(Duration.ofSeconds(60)));
+            IntegrationTestUtils.purgeLocalStreamsState(streamsConfiguration);
+            Files.walk(stateDir.toPath()).sorted(Comparator.reverseOrder())
+                .map(Path::toFile)
+                .forEach(f -> assertTrue("Some state " + f + " could not be deleted", f.delete()));
+
+            // wait till the lag goes down to 0
+            final KafkaStreams restartedStreams = new KafkaStreams(builder.build(), props);
+            // set a state restoration listener to track progress of restoration
+            final CountDownLatch restorationEndLatch = new CountDownLatch(1);
+            final Map<String, Map<Integer, LagInfo>> restoreStartLagInfo = new HashMap<>();
+            final Map<String, Map<Integer, LagInfo>> restoreEndLagInfo = new HashMap<>();
+            restartedStreams.setGlobalStateRestoreListener(new StateRestoreListener() {
+                @Override
+                public void onRestoreStart(final TopicPartition topicPartition, final String storeName, final long startingOffset, final long endingOffset) {
+                    try {
+                        restoreStartLagInfo.putAll(getFirstNonEmptyLagMap(restartedStreams));
+                    } catch (final Exception e) {
+                        LOG.error("Exception while trying to obtain lag map", e);
+                    }
+                }
+
+                @Override
+                public void onBatchRestored(final TopicPartition topicPartition, final String storeName, final long batchEndOffset, final long numRestored) {
+                }
+
+                @Override
+                public void onRestoreEnd(final TopicPartition topicPartition, final String storeName, final long totalRestored) {
+                    try {
+                        restoreEndLagInfo.putAll(getFirstNonEmptyLagMap(restartedStreams));
+                    } catch (final Exception e) {
+                        LOG.error("Exception while trying to obtain lag map", e);
+                    }
+                    restorationEndLatch.countDown();
+                }
+            });
+
+            restartedStreams.start();
+            restorationEndLatch.await(WAIT_TIMEOUT_MS, TimeUnit.MILLISECONDS);
+            TestUtils.waitForCondition(() -> restartedStreams.allLocalStorePartitionLags().get(stateStoreName).get(0).offsetLag() == 0,
+                WAIT_TIMEOUT_MS,
+                "Restarted instance should eventually catch up and have zero lag.");
+            final LagInfo fullLagInfo = restoreStartLagInfo.get(stateStoreName).get(0);
+            assertThat(fullLagInfo.currentOffsetPosition(), equalTo(0L));
+            assertThat(fullLagInfo.endOffsetPosition(), equalTo(5L));
+            assertThat(fullLagInfo.offsetLag(), equalTo(5L));
+
+            assertThat(restoreEndLagInfo.get(stateStoreName).get(0), equalTo(zeroLagRef.get()));
+        } finally {
+            streams.close();
+            streams.cleanUp();
+        }
+    }
+}
diff --git a/streams/src/test/java/org/apache/kafka/streams/integration/MetricsIntegrationTest.java b/streams/src/test/java/org/apache/kafka/streams/integration/MetricsIntegrationTest.java
new file mode 100644
index 0000000..9ada60f
--- /dev/null
+++ b/streams/src/test/java/org/apache/kafka/streams/integration/MetricsIntegrationTest.java
@@ -0,0 +1,740 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or
more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.integration; + +import org.apache.kafka.common.Metric; +import org.apache.kafka.common.metrics.Sensor; +import org.apache.kafka.common.serialization.IntegerSerializer; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.serialization.StringSerializer; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.streams.KafkaStreams; +import org.apache.kafka.streams.KafkaStreams.State; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.Topology; +import org.apache.kafka.streams.integration.utils.EmbeddedKafkaCluster; +import org.apache.kafka.streams.integration.utils.IntegrationTestUtils; +import org.apache.kafka.streams.kstream.Consumed; +import org.apache.kafka.streams.kstream.Materialized; +import org.apache.kafka.streams.kstream.Produced; +import org.apache.kafka.streams.kstream.SessionWindows; +import org.apache.kafka.streams.kstream.Suppressed; +import org.apache.kafka.streams.kstream.Suppressed.BufferConfig; +import org.apache.kafka.streams.kstream.TimeWindows; +import org.apache.kafka.streams.state.SessionStore; +import org.apache.kafka.streams.state.Stores; +import org.apache.kafka.streams.state.WindowStore; +import org.apache.kafka.test.IntegrationTest; +import org.apache.kafka.test.TestUtils; +import org.junit.After; +import org.junit.AfterClass; +import org.junit.Assert; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Rule; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.rules.TestName; + +import java.io.IOException; +import java.time.Duration; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Properties; +import java.util.stream.Collectors; + +import static org.apache.kafka.streams.integration.utils.IntegrationTestUtils.safeUniqueTestName; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.MatcherAssert.assertThat; + +@Category({IntegrationTest.class}) +@SuppressWarnings("deprecation") +public class MetricsIntegrationTest { + + private static final int NUM_BROKERS = 1; + private static final int NUM_THREADS = 2; + + public static final EmbeddedKafkaCluster CLUSTER = new EmbeddedKafkaCluster(NUM_BROKERS); + + @BeforeClass + public static void startCluster() throws IOException { + CLUSTER.start(); + } + + @AfterClass + public static void closeCluster() { + CLUSTER.stop(); + } + + private final long timeout = 60000; + + // Metric group + private static final String STREAM_CLIENT_NODE_METRICS = "stream-metrics"; + private static final String STREAM_THREAD_NODE_METRICS 
= "stream-thread-metrics"; + private static final String STREAM_TASK_NODE_METRICS = "stream-task-metrics"; + private static final String STREAM_PROCESSOR_NODE_METRICS = "stream-processor-node-metrics"; + private static final String STREAM_CACHE_NODE_METRICS = "stream-record-cache-metrics"; + + private static final String IN_MEMORY_KVSTORE_TAG_KEY = "in-memory-state-id"; + private static final String IN_MEMORY_LRUCACHE_TAG_KEY = "in-memory-lru-state-id"; + private static final String ROCKSDB_KVSTORE_TAG_KEY = "rocksdb-state-id"; + private static final String STATE_STORE_LEVEL_GROUP = "stream-state-metrics"; + + // Metrics name + private static final String VERSION = "version"; + private static final String COMMIT_ID = "commit-id"; + private static final String APPLICATION_ID = "application-id"; + private static final String TOPOLOGY_DESCRIPTION = "topology-description"; + private static final String STATE = "state"; + private static final String ALIVE_STREAM_THREADS = "alive-stream-threads"; + private static final String FAILED_STREAM_THREADS = "failed-stream-threads"; + private static final String PUT_LATENCY_AVG = "put-latency-avg"; + private static final String PUT_LATENCY_MAX = "put-latency-max"; + private static final String PUT_IF_ABSENT_LATENCY_AVG = "put-if-absent-latency-avg"; + private static final String PUT_IF_ABSENT_LATENCY_MAX = "put-if-absent-latency-max"; + private static final String GET_LATENCY_AVG = "get-latency-avg"; + private static final String GET_LATENCY_MAX = "get-latency-max"; + private static final String DELETE_LATENCY_AVG = "delete-latency-avg"; + private static final String DELETE_LATENCY_MAX = "delete-latency-max"; + private static final String REMOVE_LATENCY_AVG = "remove-latency-avg"; + private static final String REMOVE_LATENCY_MAX = "remove-latency-max"; + private static final String PUT_ALL_LATENCY_AVG = "put-all-latency-avg"; + private static final String PUT_ALL_LATENCY_MAX = "put-all-latency-max"; + private static final String ALL_LATENCY_AVG = "all-latency-avg"; + private static final String ALL_LATENCY_MAX = "all-latency-max"; + private static final String RANGE_LATENCY_AVG = "range-latency-avg"; + private static final String RANGE_LATENCY_MAX = "range-latency-max"; + private static final String FLUSH_LATENCY_AVG = "flush-latency-avg"; + private static final String FLUSH_LATENCY_MAX = "flush-latency-max"; + private static final String RESTORE_LATENCY_AVG = "restore-latency-avg"; + private static final String RESTORE_LATENCY_MAX = "restore-latency-max"; + private static final String PUT_RATE = "put-rate"; + private static final String PUT_TOTAL = "put-total"; + private static final String PUT_IF_ABSENT_RATE = "put-if-absent-rate"; + private static final String PUT_IF_ABSENT_TOTAL = "put-if-absent-total"; + private static final String GET_RATE = "get-rate"; + private static final String GET_TOTAL = "get-total"; + private static final String FETCH_RATE = "fetch-rate"; + private static final String FETCH_TOTAL = "fetch-total"; + private static final String FETCH_LATENCY_AVG = "fetch-latency-avg"; + private static final String FETCH_LATENCY_MAX = "fetch-latency-max"; + private static final String DELETE_RATE = "delete-rate"; + private static final String DELETE_TOTAL = "delete-total"; + private static final String REMOVE_RATE = "remove-rate"; + private static final String REMOVE_TOTAL = "remove-total"; + private static final String PUT_ALL_RATE = "put-all-rate"; + private static final String PUT_ALL_TOTAL = "put-all-total"; + private static final String 
ALL_RATE = "all-rate"; + private static final String ALL_TOTAL = "all-total"; + private static final String RANGE_RATE = "range-rate"; + private static final String RANGE_TOTAL = "range-total"; + private static final String FLUSH_RATE = "flush-rate"; + private static final String FLUSH_TOTAL = "flush-total"; + private static final String RESTORE_RATE = "restore-rate"; + private static final String RESTORE_TOTAL = "restore-total"; + private static final String PROCESS_LATENCY_AVG = "process-latency-avg"; + private static final String PROCESS_LATENCY_MAX = "process-latency-max"; + private static final String PUNCTUATE_LATENCY_AVG = "punctuate-latency-avg"; + private static final String PUNCTUATE_LATENCY_MAX = "punctuate-latency-max"; + private static final String CREATE_LATENCY_AVG = "create-latency-avg"; + private static final String CREATE_LATENCY_MAX = "create-latency-max"; + private static final String DESTROY_LATENCY_AVG = "destroy-latency-avg"; + private static final String DESTROY_LATENCY_MAX = "destroy-latency-max"; + private static final String PROCESS_RATE = "process-rate"; + private static final String PROCESS_TOTAL = "process-total"; + private static final String PROCESS_RATIO = "process-ratio"; + private static final String PROCESS_RECORDS_AVG = "process-records-avg"; + private static final String PROCESS_RECORDS_MAX = "process-records-max"; + private static final String PUNCTUATE_RATE = "punctuate-rate"; + private static final String PUNCTUATE_TOTAL = "punctuate-total"; + private static final String PUNCTUATE_RATIO = "punctuate-ratio"; + private static final String CREATE_RATE = "create-rate"; + private static final String CREATE_TOTAL = "create-total"; + private static final String DESTROY_RATE = "destroy-rate"; + private static final String DESTROY_TOTAL = "destroy-total"; + private static final String FORWARD_TOTAL = "forward-total"; + private static final String FORWARD_RATE = "forward-rate"; + private static final String STREAM_STRING = "stream"; + private static final String COMMIT_LATENCY_AVG = "commit-latency-avg"; + private static final String COMMIT_LATENCY_MAX = "commit-latency-max"; + private static final String POLL_LATENCY_AVG = "poll-latency-avg"; + private static final String POLL_LATENCY_MAX = "poll-latency-max"; + private static final String COMMIT_RATE = "commit-rate"; + private static final String COMMIT_TOTAL = "commit-total"; + private static final String COMMIT_RATIO = "commit-ratio"; + private static final String ENFORCED_PROCESSING_RATE = "enforced-processing-rate"; + private static final String ENFORCED_PROCESSING_TOTAL = "enforced-processing-total"; + private static final String POLL_RATE = "poll-rate"; + private static final String POLL_TOTAL = "poll-total"; + private static final String POLL_RATIO = "poll-ratio"; + private static final String POLL_RECORDS_AVG = "poll-records-avg"; + private static final String POLL_RECORDS_MAX = "poll-records-max"; + private static final String TASK_CREATED_RATE = "task-created-rate"; + private static final String TASK_CREATED_TOTAL = "task-created-total"; + private static final String TASK_CLOSED_RATE = "task-closed-rate"; + private static final String TASK_CLOSED_TOTAL = "task-closed-total"; + private static final String BLOCKED_TIME_TOTAL = "blocked-time-ns-total"; + private static final String THREAD_START_TIME = "thread-start-time"; + private static final String ACTIVE_PROCESS_RATIO = "active-process-ratio"; + private static final String ACTIVE_BUFFER_COUNT = "active-buffer-count"; + private static final String 
SKIPPED_RECORDS_RATE = "skipped-records-rate"; + private static final String SKIPPED_RECORDS_TOTAL = "skipped-records-total"; + private static final String RECORD_LATENESS_AVG = "record-lateness-avg"; + private static final String RECORD_LATENESS_MAX = "record-lateness-max"; + private static final String HIT_RATIO_AVG = "hit-ratio-avg"; + private static final String HIT_RATIO_MIN = "hit-ratio-min"; + private static final String HIT_RATIO_MAX = "hit-ratio-max"; + private static final String SUPPRESSION_BUFFER_SIZE_CURRENT = "suppression-buffer-size-current"; + private static final String SUPPRESSION_BUFFER_SIZE_AVG = "suppression-buffer-size-avg"; + private static final String SUPPRESSION_BUFFER_SIZE_MAX = "suppression-buffer-size-max"; + private static final String SUPPRESSION_BUFFER_COUNT_CURRENT = "suppression-buffer-count-current"; + private static final String SUPPRESSION_BUFFER_COUNT_AVG = "suppression-buffer-count-avg"; + private static final String SUPPRESSION_BUFFER_COUNT_MAX = "suppression-buffer-count-max"; + private static final String EXPIRED_WINDOW_RECORD_DROP_RATE = "expired-window-record-drop-rate"; + private static final String EXPIRED_WINDOW_RECORD_DROP_TOTAL = "expired-window-record-drop-total"; + private static final String RECORD_E2E_LATENCY_AVG = "record-e2e-latency-avg"; + private static final String RECORD_E2E_LATENCY_MIN = "record-e2e-latency-min"; + private static final String RECORD_E2E_LATENCY_MAX = "record-e2e-latency-max"; + + // stores name + private static final String TIME_WINDOWED_AGGREGATED_STREAM_STORE = "time-windowed-aggregated-stream-store"; + private static final String SESSION_AGGREGATED_STREAM_STORE = "session-aggregated-stream-store"; + private static final String MY_STORE_IN_MEMORY = "myStoreInMemory"; + private static final String MY_STORE_PERSISTENT_KEY_VALUE = "myStorePersistentKeyValue"; + private static final String MY_STORE_LRU_MAP = "myStoreLruMap"; + + // topic names + private static final String STREAM_INPUT = "STREAM_INPUT"; + private static final String STREAM_OUTPUT_1 = "STREAM_OUTPUT_1"; + private static final String STREAM_OUTPUT_2 = "STREAM_OUTPUT_2"; + private static final String STREAM_OUTPUT_3 = "STREAM_OUTPUT_3"; + private static final String STREAM_OUTPUT_4 = "STREAM_OUTPUT_4"; + + private StreamsBuilder builder; + private Properties streamsConfiguration; + private KafkaStreams kafkaStreams; + + private String appId; + + @Rule + public TestName testName = new TestName(); + + @Before + public void before() throws InterruptedException { + builder = new StreamsBuilder(); + CLUSTER.createTopics(STREAM_INPUT, STREAM_OUTPUT_1, STREAM_OUTPUT_2, STREAM_OUTPUT_3, STREAM_OUTPUT_4); + + final String safeTestName = safeUniqueTestName(getClass(), testName); + appId = "app-" + safeTestName; + + streamsConfiguration = new Properties(); + streamsConfiguration.put(StreamsConfig.APPLICATION_ID_CONFIG, appId); + streamsConfiguration.put(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, CLUSTER.bootstrapServers()); + streamsConfiguration.put(StreamsConfig.DEFAULT_KEY_SERDE_CLASS_CONFIG, Serdes.Integer().getClass()); + streamsConfiguration.put(StreamsConfig.DEFAULT_VALUE_SERDE_CLASS_CONFIG, Serdes.String().getClass()); + streamsConfiguration.put(StreamsConfig.METRICS_RECORDING_LEVEL_CONFIG, Sensor.RecordingLevel.DEBUG.name); + streamsConfiguration.put(StreamsConfig.CACHE_MAX_BYTES_BUFFERING_CONFIG, 10 * 1024 * 1024L); + streamsConfiguration.put(StreamsConfig.NUM_STREAM_THREADS_CONFIG, NUM_THREADS); + streamsConfiguration.put(StreamsConfig.STATE_DIR_CONFIG, 
TestUtils.tempDirectory().getPath()); + } + + @After + public void after() throws InterruptedException { + CLUSTER.deleteTopics(STREAM_INPUT, STREAM_OUTPUT_1, STREAM_OUTPUT_2, STREAM_OUTPUT_3, STREAM_OUTPUT_4); + } + + private void startApplication() throws InterruptedException { + final Topology topology = builder.build(); + kafkaStreams = new KafkaStreams(topology, streamsConfiguration); + + verifyAliveStreamThreadsMetric(); + verifyStateMetric(State.CREATED); + verifyTopologyDescriptionMetric(topology.describe().toString()); + verifyApplicationIdMetric(); + + kafkaStreams.start(); + TestUtils.waitForCondition( + () -> kafkaStreams.state() == State.RUNNING, + timeout, + () -> "Kafka Streams application did not reach state RUNNING in " + timeout + " ms"); + + verifyAliveStreamThreadsMetric(); + verifyStateMetric(State.RUNNING); + } + + private void produceRecordsForTwoSegments(final Duration segmentInterval) { + final MockTime mockTime = new MockTime(Math.max(segmentInterval.toMillis(), 60_000L)); + final Properties props = TestUtils.producerConfig( + CLUSTER.bootstrapServers(), + IntegerSerializer.class, + StringSerializer.class, + new Properties()); + IntegrationTestUtils.produceKeyValuesSynchronouslyWithTimestamp( + STREAM_INPUT, + Collections.singletonList(new KeyValue<>(1, "A")), + props, + mockTime.milliseconds() + ); + IntegrationTestUtils.produceKeyValuesSynchronouslyWithTimestamp( + STREAM_INPUT, + Collections.singletonList(new KeyValue<>(1, "B")), + props, + mockTime.milliseconds() + ); + } + + private void produceRecordsForClosingWindow(final Duration windowSize) { + final MockTime mockTime = new MockTime(windowSize.toMillis() + 1); + final Properties props = TestUtils.producerConfig( + CLUSTER.bootstrapServers(), + IntegerSerializer.class, + StringSerializer.class, + new Properties()); + IntegrationTestUtils.produceKeyValuesSynchronouslyWithTimestamp( + STREAM_INPUT, + Collections.singletonList(new KeyValue<>(1, "A")), + props, + mockTime.milliseconds() + ); + IntegrationTestUtils.produceKeyValuesSynchronouslyWithTimestamp( + STREAM_INPUT, + Collections.singletonList(new KeyValue<>(1, "B")), + props, + mockTime.milliseconds() + ); + } + + private void closeApplication() throws Exception { + kafkaStreams.close(); + kafkaStreams.cleanUp(); + IntegrationTestUtils.purgeLocalStreamsState(streamsConfiguration); + final long timeout = 60000; + TestUtils.waitForCondition( + () -> kafkaStreams.state() == State.NOT_RUNNING, + timeout, + () -> "Kafka Streams application did not reach state NOT_RUNNING in " + timeout + " ms"); + } + + @Test + public void shouldAddMetricsOnAllLevels() throws Exception { + builder.stream(STREAM_INPUT, Consumed.with(Serdes.Integer(), Serdes.String())) + .to(STREAM_OUTPUT_1, Produced.with(Serdes.Integer(), Serdes.String())); + builder.table(STREAM_OUTPUT_1, + Materialized.as(Stores.inMemoryKeyValueStore(MY_STORE_IN_MEMORY)).withCachingEnabled()) + .toStream() + .to(STREAM_OUTPUT_2); + builder.table(STREAM_OUTPUT_2, + Materialized.as(Stores.persistentKeyValueStore(MY_STORE_PERSISTENT_KEY_VALUE)).withCachingEnabled()) + .toStream() + .to(STREAM_OUTPUT_3); + builder.table(STREAM_OUTPUT_3, + Materialized.as(Stores.lruMap(MY_STORE_LRU_MAP, 10000)).withCachingEnabled()) + .toStream() + .to(STREAM_OUTPUT_4); + startApplication(); + + verifyStateMetric(State.RUNNING); + checkClientLevelMetrics(); + checkThreadLevelMetrics(); + checkTaskLevelMetrics(); + checkProcessorNodeLevelMetrics(); + checkKeyValueStoreMetrics(IN_MEMORY_KVSTORE_TAG_KEY); + 
checkKeyValueStoreMetrics(ROCKSDB_KVSTORE_TAG_KEY); + checkKeyValueStoreMetrics(IN_MEMORY_LRUCACHE_TAG_KEY); + checkCacheMetrics(); + + closeApplication(); + + checkMetricsDeregistration(); + } + + @Test + public void shouldAddMetricsForWindowStoreAndSuppressionBuffer() throws Exception { + final Duration windowSize = Duration.ofMillis(50); + builder.stream(STREAM_INPUT, Consumed.with(Serdes.Integer(), Serdes.String())) + .groupByKey() + .windowedBy(TimeWindows.of(windowSize).grace(Duration.ZERO)) + .aggregate(() -> 0L, + (aggKey, newValue, aggValue) -> aggValue, + Materialized.>as(TIME_WINDOWED_AGGREGATED_STREAM_STORE) + .withValueSerde(Serdes.Long()) + .withRetention(windowSize)) + .suppress(Suppressed.untilWindowCloses(BufferConfig.unbounded())) + .toStream() + .map((key, value) -> KeyValue.pair(value, value)) + .to(STREAM_OUTPUT_1, Produced.with(Serdes.Long(), Serdes.Long())); + + produceRecordsForClosingWindow(windowSize); + startApplication(); + + verifyStateMetric(State.RUNNING); + + checkWindowStoreAndSuppressionBufferMetrics(); + + closeApplication(); + + checkMetricsDeregistration(); + } + + @Test + public void shouldAddMetricsForSessionStore() throws Exception { + final Duration inactivityGap = Duration.ofMillis(50); + builder.stream(STREAM_INPUT, Consumed.with(Serdes.Integer(), Serdes.String())) + .groupByKey() + .windowedBy(SessionWindows.with(inactivityGap).grace(Duration.ZERO)) + .aggregate(() -> 0L, + (aggKey, newValue, aggValue) -> aggValue, + (aggKey, leftAggValue, rightAggValue) -> leftAggValue, + Materialized.>as(SESSION_AGGREGATED_STREAM_STORE) + .withValueSerde(Serdes.Long()) + .withRetention(inactivityGap)) + .toStream() + .map((key, value) -> KeyValue.pair(value, value)) + .to(STREAM_OUTPUT_1, Produced.with(Serdes.Long(), Serdes.Long())); + + produceRecordsForTwoSegments(inactivityGap); + + startApplication(); + + verifyStateMetric(State.RUNNING); + + checkSessionStoreMetrics(); + + closeApplication(); + + checkMetricsDeregistration(); + } + + private void verifyAliveStreamThreadsMetric() { + final List metricsList = new ArrayList(kafkaStreams.metrics().values()).stream() + .filter(m -> m.metricName().name().equals(ALIVE_STREAM_THREADS) && + m.metricName().group().equals(STREAM_CLIENT_NODE_METRICS)) + .collect(Collectors.toList()); + assertThat(metricsList.size(), is(1)); + assertThat(metricsList.get(0).metricValue(), is(NUM_THREADS)); + } + + private void verifyStateMetric(final State state) { + final List metricsList = new ArrayList(kafkaStreams.metrics().values()).stream() + .filter(m -> m.metricName().name().equals(STATE) && + m.metricName().group().equals(STREAM_CLIENT_NODE_METRICS)) + .collect(Collectors.toList()); + assertThat(metricsList.size(), is(1)); + assertThat(metricsList.get(0).metricValue(), is(state)); + assertThat(metricsList.get(0).metricValue().toString(), is(state.toString())); + } + + private void verifyTopologyDescriptionMetric(final String topologyDescription) { + final List metricsList = new ArrayList(kafkaStreams.metrics().values()).stream() + .filter(m -> m.metricName().name().equals(TOPOLOGY_DESCRIPTION) && + m.metricName().group().equals(STREAM_CLIENT_NODE_METRICS)) + .collect(Collectors.toList()); + assertThat(metricsList.size(), is(1)); + assertThat(metricsList.get(0).metricValue(), is(topologyDescription)); + } + + private void verifyApplicationIdMetric() { + final List metricsList = new ArrayList(kafkaStreams.metrics().values()).stream() + .filter(m -> m.metricName().name().equals(APPLICATION_ID) && + 
m.metricName().group().equals(STREAM_CLIENT_NODE_METRICS)) + .collect(Collectors.toList()); + assertThat(metricsList.size(), is(1)); + assertThat(metricsList.get(0).metricValue(), is(appId)); + } + + private void checkClientLevelMetrics() { + final List listMetricThread = new ArrayList(kafkaStreams.metrics().values()).stream() + .filter(m -> m.metricName().group().equals(STREAM_CLIENT_NODE_METRICS)) + .collect(Collectors.toList()); + checkMetricByName(listMetricThread, VERSION, 1); + checkMetricByName(listMetricThread, COMMIT_ID, 1); + checkMetricByName(listMetricThread, APPLICATION_ID, 1); + checkMetricByName(listMetricThread, TOPOLOGY_DESCRIPTION, 1); + checkMetricByName(listMetricThread, STATE, 1); + checkMetricByName(listMetricThread, ALIVE_STREAM_THREADS, 1); + checkMetricByName(listMetricThread, FAILED_STREAM_THREADS, 1); + } + + private void checkThreadLevelMetrics() { + final List listMetricThread = new ArrayList(kafkaStreams.metrics().values()).stream() + .filter(m -> m.metricName().group().equals(STREAM_THREAD_NODE_METRICS)) + .collect(Collectors.toList()); + checkMetricByName(listMetricThread, COMMIT_LATENCY_AVG, NUM_THREADS); + checkMetricByName(listMetricThread, COMMIT_LATENCY_MAX, NUM_THREADS); + checkMetricByName(listMetricThread, POLL_LATENCY_AVG, NUM_THREADS); + checkMetricByName(listMetricThread, POLL_LATENCY_MAX, NUM_THREADS); + checkMetricByName(listMetricThread, PROCESS_LATENCY_AVG, NUM_THREADS); + checkMetricByName(listMetricThread, PROCESS_LATENCY_MAX, NUM_THREADS); + checkMetricByName(listMetricThread, PUNCTUATE_LATENCY_AVG, NUM_THREADS); + checkMetricByName(listMetricThread, PUNCTUATE_LATENCY_MAX, NUM_THREADS); + checkMetricByName(listMetricThread, COMMIT_RATE, NUM_THREADS); + checkMetricByName(listMetricThread, COMMIT_TOTAL, NUM_THREADS); + checkMetricByName(listMetricThread, COMMIT_RATIO, NUM_THREADS); + checkMetricByName(listMetricThread, POLL_RATE, NUM_THREADS); + checkMetricByName(listMetricThread, POLL_TOTAL, NUM_THREADS); + checkMetricByName(listMetricThread, POLL_RATIO, NUM_THREADS); + checkMetricByName(listMetricThread, POLL_RECORDS_AVG, NUM_THREADS); + checkMetricByName(listMetricThread, POLL_RECORDS_MAX, NUM_THREADS); + checkMetricByName(listMetricThread, PROCESS_RATE, NUM_THREADS); + checkMetricByName(listMetricThread, PROCESS_TOTAL, NUM_THREADS); + checkMetricByName(listMetricThread, PROCESS_RATIO, NUM_THREADS); + checkMetricByName(listMetricThread, PROCESS_RECORDS_AVG, NUM_THREADS); + checkMetricByName(listMetricThread, PROCESS_RECORDS_MAX, NUM_THREADS); + checkMetricByName(listMetricThread, PUNCTUATE_RATE, NUM_THREADS); + checkMetricByName(listMetricThread, PUNCTUATE_TOTAL, NUM_THREADS); + checkMetricByName(listMetricThread, PUNCTUATE_RATIO, NUM_THREADS); + checkMetricByName(listMetricThread, TASK_CREATED_RATE, NUM_THREADS); + checkMetricByName(listMetricThread, TASK_CREATED_TOTAL, NUM_THREADS); + checkMetricByName(listMetricThread, TASK_CLOSED_RATE, NUM_THREADS); + checkMetricByName(listMetricThread, TASK_CLOSED_TOTAL, NUM_THREADS); + checkMetricByName(listMetricThread, BLOCKED_TIME_TOTAL, NUM_THREADS); + checkMetricByName(listMetricThread, THREAD_START_TIME, NUM_THREADS); + } + + private void checkTaskLevelMetrics() { + final List listMetricTask = new ArrayList(kafkaStreams.metrics().values()).stream() + .filter(m -> m.metricName().group().equals(STREAM_TASK_NODE_METRICS)) + .collect(Collectors.toList()); + checkMetricByName(listMetricTask, ENFORCED_PROCESSING_RATE, 4); + checkMetricByName(listMetricTask, ENFORCED_PROCESSING_TOTAL, 4); + 
checkMetricByName(listMetricTask, RECORD_LATENESS_AVG, 4); + checkMetricByName(listMetricTask, RECORD_LATENESS_MAX, 4); + checkMetricByName(listMetricTask, ACTIVE_PROCESS_RATIO, 4); + checkMetricByName(listMetricTask, ACTIVE_BUFFER_COUNT, 4); + checkMetricByName(listMetricTask, PROCESS_LATENCY_AVG, 4); + checkMetricByName(listMetricTask, PROCESS_LATENCY_MAX, 4); + checkMetricByName(listMetricTask, PUNCTUATE_LATENCY_AVG, 4); + checkMetricByName(listMetricTask, PUNCTUATE_LATENCY_MAX, 4); + checkMetricByName(listMetricTask, PUNCTUATE_RATE, 4); + checkMetricByName(listMetricTask, PUNCTUATE_TOTAL, 4); + checkMetricByName(listMetricTask, PROCESS_RATE, 4); + checkMetricByName(listMetricTask, PROCESS_TOTAL, 4); + } + + private void checkProcessorNodeLevelMetrics() { + final List listMetricProcessor = new ArrayList(kafkaStreams.metrics().values()).stream() + .filter(m -> m.metricName().group().equals(STREAM_PROCESSOR_NODE_METRICS)) + .collect(Collectors.toList()); + final int numberOfSourceNodes = 4; + final int numberOfTerminalNodes = 4; + checkMetricByName(listMetricProcessor, PROCESS_RATE, 4); + checkMetricByName(listMetricProcessor, PROCESS_TOTAL, 4); + checkMetricByName(listMetricProcessor, RECORD_E2E_LATENCY_AVG, numberOfSourceNodes + numberOfTerminalNodes); + checkMetricByName(listMetricProcessor, RECORD_E2E_LATENCY_MIN, numberOfSourceNodes + numberOfTerminalNodes); + checkMetricByName(listMetricProcessor, RECORD_E2E_LATENCY_MAX, numberOfSourceNodes + numberOfTerminalNodes); + } + + private void checkKeyValueStoreMetrics(final String tagKey) { + final List listMetricStore = new ArrayList(kafkaStreams.metrics().values()).stream() + .filter(m -> m.metricName().tags().containsKey(tagKey) && m.metricName().group().equals(STATE_STORE_LEVEL_GROUP)) + .collect(Collectors.toList()); + + final int expectedNumberOfLatencyMetrics = 1; + final int expectedNumberOfRateMetrics = 1; + final int expectedNumberOfTotalMetrics = 0; + checkMetricByName(listMetricStore, PUT_LATENCY_AVG, expectedNumberOfLatencyMetrics); + checkMetricByName(listMetricStore, PUT_LATENCY_MAX, expectedNumberOfLatencyMetrics); + checkMetricByName(listMetricStore, PUT_IF_ABSENT_LATENCY_AVG, expectedNumberOfLatencyMetrics); + checkMetricByName(listMetricStore, PUT_IF_ABSENT_LATENCY_MAX, expectedNumberOfLatencyMetrics); + checkMetricByName(listMetricStore, GET_LATENCY_AVG, expectedNumberOfLatencyMetrics); + checkMetricByName(listMetricStore, GET_LATENCY_MAX, expectedNumberOfLatencyMetrics); + checkMetricByName(listMetricStore, DELETE_LATENCY_AVG, expectedNumberOfLatencyMetrics); + checkMetricByName(listMetricStore, DELETE_LATENCY_MAX, expectedNumberOfLatencyMetrics); + checkMetricByName(listMetricStore, REMOVE_LATENCY_AVG, 0); + checkMetricByName(listMetricStore, REMOVE_LATENCY_MAX, 0); + checkMetricByName(listMetricStore, PUT_ALL_LATENCY_AVG, expectedNumberOfLatencyMetrics); + checkMetricByName(listMetricStore, PUT_ALL_LATENCY_MAX, expectedNumberOfLatencyMetrics); + checkMetricByName(listMetricStore, ALL_LATENCY_AVG, expectedNumberOfLatencyMetrics); + checkMetricByName(listMetricStore, ALL_LATENCY_MAX, expectedNumberOfLatencyMetrics); + checkMetricByName(listMetricStore, RANGE_LATENCY_AVG, expectedNumberOfLatencyMetrics); + checkMetricByName(listMetricStore, RANGE_LATENCY_MAX, expectedNumberOfLatencyMetrics); + checkMetricByName(listMetricStore, FLUSH_LATENCY_AVG, expectedNumberOfLatencyMetrics); + checkMetricByName(listMetricStore, FLUSH_LATENCY_MAX, expectedNumberOfLatencyMetrics); + checkMetricByName(listMetricStore, 
RESTORE_LATENCY_AVG, expectedNumberOfLatencyMetrics); + checkMetricByName(listMetricStore, RESTORE_LATENCY_MAX, expectedNumberOfLatencyMetrics); + checkMetricByName(listMetricStore, FETCH_LATENCY_AVG, 0); + checkMetricByName(listMetricStore, FETCH_LATENCY_MAX, 0); + checkMetricByName(listMetricStore, PUT_RATE, expectedNumberOfRateMetrics); + checkMetricByName(listMetricStore, PUT_TOTAL, expectedNumberOfTotalMetrics); + checkMetricByName(listMetricStore, PUT_IF_ABSENT_RATE, expectedNumberOfRateMetrics); + checkMetricByName(listMetricStore, PUT_IF_ABSENT_TOTAL, expectedNumberOfTotalMetrics); + checkMetricByName(listMetricStore, GET_RATE, expectedNumberOfRateMetrics); + checkMetricByName(listMetricStore, GET_TOTAL, expectedNumberOfTotalMetrics); + checkMetricByName(listMetricStore, DELETE_RATE, expectedNumberOfRateMetrics); + checkMetricByName(listMetricStore, DELETE_TOTAL, expectedNumberOfTotalMetrics); + checkMetricByName(listMetricStore, REMOVE_RATE, 0); + checkMetricByName(listMetricStore, REMOVE_TOTAL, 0); + checkMetricByName(listMetricStore, PUT_ALL_RATE, expectedNumberOfRateMetrics); + checkMetricByName(listMetricStore, PUT_ALL_TOTAL, expectedNumberOfTotalMetrics); + checkMetricByName(listMetricStore, ALL_RATE, expectedNumberOfRateMetrics); + checkMetricByName(listMetricStore, ALL_TOTAL, expectedNumberOfTotalMetrics); + checkMetricByName(listMetricStore, RANGE_RATE, expectedNumberOfRateMetrics); + checkMetricByName(listMetricStore, RANGE_TOTAL, expectedNumberOfTotalMetrics); + checkMetricByName(listMetricStore, FLUSH_RATE, expectedNumberOfRateMetrics); + checkMetricByName(listMetricStore, FLUSH_TOTAL, expectedNumberOfTotalMetrics); + checkMetricByName(listMetricStore, RESTORE_RATE, expectedNumberOfRateMetrics); + checkMetricByName(listMetricStore, RESTORE_TOTAL, expectedNumberOfTotalMetrics); + checkMetricByName(listMetricStore, FETCH_RATE, 0); + checkMetricByName(listMetricStore, FETCH_TOTAL, 0); + checkMetricByName(listMetricStore, SUPPRESSION_BUFFER_COUNT_CURRENT, 0); + checkMetricByName(listMetricStore, SUPPRESSION_BUFFER_COUNT_AVG, 0); + checkMetricByName(listMetricStore, SUPPRESSION_BUFFER_COUNT_MAX, 0); + checkMetricByName(listMetricStore, SUPPRESSION_BUFFER_SIZE_CURRENT, 0); + checkMetricByName(listMetricStore, SUPPRESSION_BUFFER_SIZE_AVG, 0); + checkMetricByName(listMetricStore, SUPPRESSION_BUFFER_SIZE_MAX, 0); + checkMetricByName(listMetricStore, RECORD_E2E_LATENCY_AVG, 1); + checkMetricByName(listMetricStore, RECORD_E2E_LATENCY_MIN, 1); + checkMetricByName(listMetricStore, RECORD_E2E_LATENCY_MAX, 1); + } + + private void checkMetricsDeregistration() { + final List listMetricAfterClosingApp = new ArrayList(kafkaStreams.metrics().values()).stream() + .filter(m -> m.metricName().group().contains(STREAM_STRING)) + .collect(Collectors.toList()); + assertThat(listMetricAfterClosingApp.size(), is(0)); + } + + private void checkCacheMetrics() { + final List listMetricCache = new ArrayList(kafkaStreams.metrics().values()).stream() + .filter(m -> m.metricName().group().equals(STREAM_CACHE_NODE_METRICS)) + .collect(Collectors.toList()); + checkMetricByName(listMetricCache, HIT_RATIO_AVG, 3); + checkMetricByName(listMetricCache, HIT_RATIO_MIN, 3); + checkMetricByName(listMetricCache, HIT_RATIO_MAX, 3); + } + + private void checkWindowStoreAndSuppressionBufferMetrics() { + final List listMetricStore = new ArrayList(kafkaStreams.metrics().values()).stream() + .filter(m -> m.metricName().group().equals(STATE_STORE_LEVEL_GROUP)) + .collect(Collectors.toList()); + 
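+        // The topology contains exactly one windowed store and one suppression buffer, so the
+        // window-store operations (put, fetch, flush, restore) and the suppression-buffer count/size
+        // metrics are each expected once, while key-value-store-only operations such as get, delete,
+        // put-if-absent, put-all, range, all and remove must not be registered at all.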
checkMetricByName(listMetricStore, PUT_LATENCY_AVG, 1); + checkMetricByName(listMetricStore, PUT_LATENCY_MAX, 1); + checkMetricByName(listMetricStore, PUT_IF_ABSENT_LATENCY_AVG, 0); + checkMetricByName(listMetricStore, PUT_IF_ABSENT_LATENCY_MAX, 0); + checkMetricByName(listMetricStore, GET_LATENCY_AVG, 0); + checkMetricByName(listMetricStore, GET_LATENCY_MAX, 0); + checkMetricByName(listMetricStore, DELETE_LATENCY_AVG, 0); + checkMetricByName(listMetricStore, DELETE_LATENCY_MAX, 0); + checkMetricByName(listMetricStore, REMOVE_LATENCY_AVG, 0); + checkMetricByName(listMetricStore, REMOVE_LATENCY_MAX, 0); + checkMetricByName(listMetricStore, PUT_ALL_LATENCY_AVG, 0); + checkMetricByName(listMetricStore, PUT_ALL_LATENCY_MAX, 0); + checkMetricByName(listMetricStore, ALL_LATENCY_AVG, 0); + checkMetricByName(listMetricStore, ALL_LATENCY_MAX, 0); + checkMetricByName(listMetricStore, RANGE_LATENCY_AVG, 0); + checkMetricByName(listMetricStore, RANGE_LATENCY_MAX, 0); + checkMetricByName(listMetricStore, FLUSH_LATENCY_AVG, 1); + checkMetricByName(listMetricStore, FLUSH_LATENCY_MAX, 1); + checkMetricByName(listMetricStore, RESTORE_LATENCY_AVG, 1); + checkMetricByName(listMetricStore, RESTORE_LATENCY_MAX, 1); + checkMetricByName(listMetricStore, FETCH_LATENCY_AVG, 1); + checkMetricByName(listMetricStore, FETCH_LATENCY_MAX, 1); + checkMetricByName(listMetricStore, PUT_RATE, 1); + checkMetricByName(listMetricStore, PUT_IF_ABSENT_RATE, 0); + checkMetricByName(listMetricStore, PUT_IF_ABSENT_TOTAL, 0); + checkMetricByName(listMetricStore, GET_RATE, 0); + checkMetricByName(listMetricStore, GET_TOTAL, 0); + checkMetricByName(listMetricStore, DELETE_RATE, 0); + checkMetricByName(listMetricStore, DELETE_TOTAL, 0); + checkMetricByName(listMetricStore, REMOVE_RATE, 0); + checkMetricByName(listMetricStore, REMOVE_TOTAL, 0); + checkMetricByName(listMetricStore, PUT_ALL_RATE, 0); + checkMetricByName(listMetricStore, PUT_ALL_TOTAL, 0); + checkMetricByName(listMetricStore, ALL_RATE, 0); + checkMetricByName(listMetricStore, ALL_TOTAL, 0); + checkMetricByName(listMetricStore, RANGE_RATE, 0); + checkMetricByName(listMetricStore, RANGE_TOTAL, 0); + checkMetricByName(listMetricStore, FLUSH_RATE, 1); + checkMetricByName(listMetricStore, RESTORE_RATE, 1); + checkMetricByName(listMetricStore, FETCH_RATE, 1); + checkMetricByName(listMetricStore, SUPPRESSION_BUFFER_COUNT_AVG, 1); + checkMetricByName(listMetricStore, SUPPRESSION_BUFFER_COUNT_MAX, 1); + checkMetricByName(listMetricStore, SUPPRESSION_BUFFER_SIZE_AVG, 1); + checkMetricByName(listMetricStore, SUPPRESSION_BUFFER_SIZE_MAX, 1); + checkMetricByName(listMetricStore, RECORD_E2E_LATENCY_AVG, 1); + checkMetricByName(listMetricStore, RECORD_E2E_LATENCY_MIN, 1); + checkMetricByName(listMetricStore, RECORD_E2E_LATENCY_MAX, 1); + } + + private void checkSessionStoreMetrics() { + final List listMetricStore = new ArrayList(kafkaStreams.metrics().values()).stream() + .filter(m -> m.metricName().group().equals(STATE_STORE_LEVEL_GROUP)) + .collect(Collectors.toList()); + checkMetricByName(listMetricStore, PUT_LATENCY_AVG, 1); + checkMetricByName(listMetricStore, PUT_LATENCY_MAX, 1); + checkMetricByName(listMetricStore, PUT_IF_ABSENT_LATENCY_AVG, 0); + checkMetricByName(listMetricStore, PUT_IF_ABSENT_LATENCY_MAX, 0); + checkMetricByName(listMetricStore, GET_LATENCY_AVG, 0); + checkMetricByName(listMetricStore, GET_LATENCY_MAX, 0); + checkMetricByName(listMetricStore, DELETE_LATENCY_AVG, 0); + checkMetricByName(listMetricStore, DELETE_LATENCY_MAX, 0); + 
checkMetricByName(listMetricStore, REMOVE_LATENCY_AVG, 1); + checkMetricByName(listMetricStore, REMOVE_LATENCY_MAX, 1); + checkMetricByName(listMetricStore, PUT_ALL_LATENCY_AVG, 0); + checkMetricByName(listMetricStore, PUT_ALL_LATENCY_MAX, 0); + checkMetricByName(listMetricStore, ALL_LATENCY_AVG, 0); + checkMetricByName(listMetricStore, ALL_LATENCY_MAX, 0); + checkMetricByName(listMetricStore, RANGE_LATENCY_AVG, 0); + checkMetricByName(listMetricStore, RANGE_LATENCY_MAX, 0); + checkMetricByName(listMetricStore, FLUSH_LATENCY_AVG, 1); + checkMetricByName(listMetricStore, FLUSH_LATENCY_MAX, 1); + checkMetricByName(listMetricStore, RESTORE_LATENCY_AVG, 1); + checkMetricByName(listMetricStore, RESTORE_LATENCY_MAX, 1); + checkMetricByName(listMetricStore, FETCH_LATENCY_AVG, 1); + checkMetricByName(listMetricStore, FETCH_LATENCY_MAX, 1); + checkMetricByName(listMetricStore, PUT_RATE, 1); + checkMetricByName(listMetricStore, PUT_IF_ABSENT_RATE, 0); + checkMetricByName(listMetricStore, PUT_IF_ABSENT_TOTAL, 0); + checkMetricByName(listMetricStore, GET_RATE, 0); + checkMetricByName(listMetricStore, GET_TOTAL, 0); + checkMetricByName(listMetricStore, DELETE_RATE, 0); + checkMetricByName(listMetricStore, DELETE_TOTAL, 0); + checkMetricByName(listMetricStore, REMOVE_RATE, 1); + checkMetricByName(listMetricStore, PUT_ALL_RATE, 0); + checkMetricByName(listMetricStore, PUT_ALL_TOTAL, 0); + checkMetricByName(listMetricStore, ALL_RATE, 0); + checkMetricByName(listMetricStore, ALL_TOTAL, 0); + checkMetricByName(listMetricStore, RANGE_RATE, 0); + checkMetricByName(listMetricStore, RANGE_TOTAL, 0); + checkMetricByName(listMetricStore, FLUSH_RATE, 1); + checkMetricByName(listMetricStore, RESTORE_RATE, 1); + checkMetricByName(listMetricStore, FETCH_RATE, 1); + checkMetricByName(listMetricStore, SUPPRESSION_BUFFER_COUNT_CURRENT, 0); + checkMetricByName(listMetricStore, SUPPRESSION_BUFFER_COUNT_AVG, 0); + checkMetricByName(listMetricStore, SUPPRESSION_BUFFER_COUNT_MAX, 0); + checkMetricByName(listMetricStore, SUPPRESSION_BUFFER_SIZE_CURRENT, 0); + checkMetricByName(listMetricStore, SUPPRESSION_BUFFER_SIZE_AVG, 0); + checkMetricByName(listMetricStore, SUPPRESSION_BUFFER_SIZE_MAX, 0); + checkMetricByName(listMetricStore, RECORD_E2E_LATENCY_AVG, 1); + checkMetricByName(listMetricStore, RECORD_E2E_LATENCY_MIN, 1); + checkMetricByName(listMetricStore, RECORD_E2E_LATENCY_MAX, 1); + } + + private void checkMetricByName(final List listMetric, final String metricName, final int numMetric) { + final List metrics = listMetric.stream() + .filter(m -> m.metricName().name().equals(metricName)) + .collect(Collectors.toList()); + Assert.assertEquals("Size of metrics of type:'" + metricName + "' must be equal to " + numMetric + " but it's equal to " + metrics.size(), numMetric, metrics.size()); + for (final Metric m : metrics) { + Assert.assertNotNull("Metric:'" + m.metricName() + "' must be not null", m.metricValue()); + } + } +} \ No newline at end of file diff --git a/streams/src/test/java/org/apache/kafka/streams/integration/MetricsReporterIntegrationTest.java b/streams/src/test/java/org/apache/kafka/streams/integration/MetricsReporterIntegrationTest.java new file mode 100644 index 0000000..a7c925a --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/integration/MetricsReporterIntegrationTest.java @@ -0,0 +1,135 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.integration; + +import org.apache.kafka.clients.CommonClientConfigs; +import org.apache.kafka.common.metrics.KafkaMetric; +import org.apache.kafka.common.metrics.MetricsReporter; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.streams.KafkaStreams; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.Topology; +import org.apache.kafka.streams.integration.utils.EmbeddedKafkaCluster; +import org.apache.kafka.streams.kstream.Consumed; +import org.apache.kafka.streams.kstream.Produced; +import org.apache.kafka.test.IntegrationTest; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Rule; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.rules.TestName; + +import java.io.IOException; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Properties; + +import static org.apache.kafka.streams.integration.utils.IntegrationTestUtils.safeUniqueTestName; +import static org.hamcrest.CoreMatchers.notNullValue; +import static org.hamcrest.MatcherAssert.assertThat; + +@Category({IntegrationTest.class}) +public class MetricsReporterIntegrationTest { + + private static final int NUM_BROKERS = 1; + + public static final EmbeddedKafkaCluster CLUSTER = new EmbeddedKafkaCluster(NUM_BROKERS); + + // topic names + private static final String STREAM_INPUT = "STREAM_INPUT"; + private static final String STREAM_OUTPUT = "STREAM_OUTPUT"; + + private StreamsBuilder builder; + private Properties streamsConfiguration; + + @Rule + public TestName testName = new TestName(); + + @BeforeClass + public static void startCluster() throws IOException { + CLUSTER.start(); + } + + @AfterClass + public static void closeCluster() { + CLUSTER.stop(); + } + + @Before + public void before() throws InterruptedException { + builder = new StreamsBuilder(); + + final String safeTestName = safeUniqueTestName(getClass(), testName); + final String appId = "app-" + safeTestName; + + streamsConfiguration = new Properties(); + streamsConfiguration.put(StreamsConfig.APPLICATION_ID_CONFIG, appId); + streamsConfiguration.put(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, CLUSTER.bootstrapServers()); + streamsConfiguration.put(StreamsConfig.DEFAULT_KEY_SERDE_CLASS_CONFIG, Serdes.Integer().getClass()); + streamsConfiguration.put(StreamsConfig.DEFAULT_VALUE_SERDE_CLASS_CONFIG, Serdes.String().getClass()); + streamsConfiguration.put(CommonClientConfigs.METRIC_REPORTER_CLASSES_CONFIG, MetricReporterImpl.class.getName()); + } + + final static Map METRIC_NAME_TO_INITIAL_VALUE = new HashMap<>(); + + public static class MetricReporterImpl implements MetricsReporter { + + + @Override + public void 
configure(final Map<String, ?> configs) { + } + + @Override + public void init(final List<KafkaMetric> metrics) { + } + + @Override + public void metricChange(final KafkaMetric metric) { + // get the value of the metric, e.g. if you want to check the type of the value + METRIC_NAME_TO_INITIAL_VALUE.put(metric.metricName().name(), metric.metricValue()); + } + + @Override + public void metricRemoval(final KafkaMetric metric) { + } + + @Override + public void close() { + } + + } + + @Test + public void shouldBeAbleToProvideInitialMetricValueToMetricsReporter() { + // no need to create the topics, because we don't start the streams application - we just need to create the KafkaStreams object + // to check that all initial values from the metrics are not null + builder.stream(STREAM_INPUT, Consumed.with(Serdes.Integer(), Serdes.String())) + .to(STREAM_OUTPUT, Produced.with(Serdes.Integer(), Serdes.String())); + final Topology topology = builder.build(); + final KafkaStreams kafkaStreams = new KafkaStreams(topology, streamsConfiguration); + + kafkaStreams.metrics().keySet().forEach(metricName -> { + final Object initialMetricValue = METRIC_NAME_TO_INITIAL_VALUE.get(metricName.name()); + assertThat(initialMetricValue, notNullValue()); + }); + } + +} \ No newline at end of file diff --git a/streams/src/test/java/org/apache/kafka/streams/integration/NamedTopologyIntegrationTest.java b/streams/src/test/java/org/apache/kafka/streams/integration/NamedTopologyIntegrationTest.java new file mode 100644 index 0000000..2b01fee --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/integration/NamedTopologyIntegrationTest.java @@ -0,0 +1,444 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ +package org.apache.kafka.streams.integration; + +import org.apache.kafka.clients.consumer.ConsumerConfig; +import org.apache.kafka.common.serialization.LongDeserializer; +import org.apache.kafka.common.serialization.LongSerializer; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.serialization.StringDeserializer; +import org.apache.kafka.common.serialization.StringSerializer; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.streams.KafkaClientSupplier; +import org.apache.kafka.streams.KafkaStreams.State; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.integration.utils.EmbeddedKafkaCluster; +import org.apache.kafka.streams.integration.utils.IntegrationTestUtils; +import org.apache.kafka.streams.kstream.Consumed; +import org.apache.kafka.streams.kstream.KStream; +import org.apache.kafka.streams.kstream.KTable; +import org.apache.kafka.streams.kstream.Materialized; +import org.apache.kafka.streams.processor.internals.DefaultKafkaClientSupplier; +import org.apache.kafka.streams.processor.internals.namedtopology.KafkaStreamsNamedTopologyWrapper; +import org.apache.kafka.streams.processor.internals.namedtopology.NamedTopology; +import org.apache.kafka.streams.processor.internals.namedtopology.NamedTopologyStreamsBuilder; +import org.apache.kafka.streams.state.KeyValueStore; +import org.apache.kafka.streams.state.Stores; +import org.apache.kafka.streams.utils.UniqueTopicSerdeScope; +import org.apache.kafka.test.TestUtils; + +import org.junit.After; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Ignore; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TestName; +import java.time.Duration; +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; +import java.util.Properties; +import java.util.Set; +import java.util.regex.Pattern; +import java.util.stream.Collectors; + +import static org.apache.kafka.common.utils.Utils.mkSet; +import static org.apache.kafka.streams.KeyValue.pair; +import static org.apache.kafka.streams.integration.utils.IntegrationTestUtils.safeUniqueTestName; +import static org.apache.kafka.streams.integration.utils.IntegrationTestUtils.waitUntilMinKeyValueRecordsReceived; + +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.MatcherAssert.assertThat; +import static java.util.Arrays.asList; +import static java.util.Collections.singletonList; + +public class NamedTopologyIntegrationTest { + public static final EmbeddedKafkaCluster CLUSTER = new EmbeddedKafkaCluster(1); + + // TODO KAFKA-12648: + // 1) full test coverage for add/removeNamedTopology, covering: + // - the "last topology removed" case + // - test using multiple clients, with standbys + + // "standard" input topics which are pre-filled with the STANDARD_INPUT_DATA + private final static String INPUT_STREAM_1 = "input-stream-1"; + private final static String INPUT_STREAM_2 = "input-stream-2"; + private final static String INPUT_STREAM_3 = "input-stream-3"; + + private final static String OUTPUT_STREAM_1 = "output-stream-1"; + private final static String OUTPUT_STREAM_2 = "output-stream-2"; + private final static String OUTPUT_STREAM_3 = "output-stream-3"; + + private final static String SUM_OUTPUT = "sum"; + private final static String COUNT_OUTPUT = "count"; + + + // "delayed" input topics which are empty at 
start to allow control over when input data appears + private final static String DELAYED_INPUT_STREAM_1 = "delayed-input-stream-1"; + private final static String DELAYED_INPUT_STREAM_2 = "delayed-input-stream-2"; + + private final static Materialized> IN_MEMORY_STORE = Materialized.as(Stores.inMemoryKeyValueStore("store")); + private final static Materialized> ROCKSDB_STORE = Materialized.as(Stores.persistentKeyValueStore("store")); + + private static Properties producerConfig; + private static Properties consumerConfig; + + @BeforeClass + public static void initializeClusterAndStandardTopics() throws Exception { + CLUSTER.start(); + + CLUSTER.createTopic(INPUT_STREAM_1, 2, 1); + CLUSTER.createTopic(INPUT_STREAM_2, 2, 1); + CLUSTER.createTopic(INPUT_STREAM_3, 2, 1); + + CLUSTER.createTopic(DELAYED_INPUT_STREAM_1, 2, 1); + CLUSTER.createTopic(DELAYED_INPUT_STREAM_2, 2, 1); + + producerConfig = TestUtils.producerConfig(CLUSTER.bootstrapServers(), StringSerializer.class, LongSerializer.class); + consumerConfig = TestUtils.consumerConfig(CLUSTER.bootstrapServers(), StringDeserializer.class, LongDeserializer.class); + + produceToInputTopics(INPUT_STREAM_1, STANDARD_INPUT_DATA); + produceToInputTopics(INPUT_STREAM_2, STANDARD_INPUT_DATA); + produceToInputTopics(INPUT_STREAM_3, STANDARD_INPUT_DATA); + } + + @AfterClass + public static void closeCluster() { + CLUSTER.stop(); + } + + @Rule + public final TestName testName = new TestName(); + private String appId; + private String changelog1; + private String changelog2; + private String changelog3; + + private final static List> STANDARD_INPUT_DATA = + asList(pair("A", 100L), pair("B", 200L), pair("A", 300L), pair("C", 400L), pair("C", -50L)); + private final static List> COUNT_OUTPUT_DATA = + asList(pair("B", 1L), pair("A", 2L), pair("C", 2L)); // output of count operation with caching + private final static List> SUM_OUTPUT_DATA = + asList(pair("B", 200L), pair("A", 400L), pair("C", 350L)); // output of summation with caching + + private final KafkaClientSupplier clientSupplier = new DefaultKafkaClientSupplier(); + + // builders for the 1st Streams instance (default) + private final NamedTopologyStreamsBuilder topology1Builder = new NamedTopologyStreamsBuilder("topology-1"); + private final NamedTopologyStreamsBuilder topology2Builder = new NamedTopologyStreamsBuilder("topology-2"); + private final NamedTopologyStreamsBuilder topology3Builder = new NamedTopologyStreamsBuilder("topology-3"); + + // builders for the 2nd Streams instance + private final NamedTopologyStreamsBuilder topology1Builder2 = new NamedTopologyStreamsBuilder("topology-1"); + private final NamedTopologyStreamsBuilder topology2Builder2 = new NamedTopologyStreamsBuilder("topology-2"); + private final NamedTopologyStreamsBuilder topology3Builder2 = new NamedTopologyStreamsBuilder("topology-3"); + + private Properties props; + private Properties props2; + + private KafkaStreamsNamedTopologyWrapper streams; + private KafkaStreamsNamedTopologyWrapper streams2; + + private Properties configProps() { + final Properties streamsConfiguration = new Properties(); + streamsConfiguration.put(StreamsConfig.APPLICATION_ID_CONFIG, appId); + streamsConfiguration.put(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, CLUSTER.bootstrapServers()); + streamsConfiguration.put(StreamsConfig.STATE_DIR_CONFIG, TestUtils.tempDirectory(appId).getPath()); + streamsConfiguration.put(StreamsConfig.DEFAULT_KEY_SERDE_CLASS_CONFIG, Serdes.String().getClass()); + 
streamsConfiguration.put(StreamsConfig.DEFAULT_VALUE_SERDE_CLASS_CONFIG, Serdes.Long().getClass()); + streamsConfiguration.put(StreamsConfig.NUM_STREAM_THREADS_CONFIG, 2); + streamsConfiguration.put(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG, 1000L); + streamsConfiguration.put(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "earliest"); + streamsConfiguration.put(ConsumerConfig.SESSION_TIMEOUT_MS_CONFIG, 10 * 1000); + return streamsConfiguration; + } + + @Before + public void setup() throws Exception { + appId = safeUniqueTestName(NamedTopologyIntegrationTest.class, testName); + changelog1 = appId + "-topology-1-store-changelog"; + changelog2 = appId + "-topology-2-store-changelog"; + changelog3 = appId + "-topology-3-store-changelog"; + props = configProps(); + props2 = configProps(); + + // TODO KAFKA-12648: refactor to avoid deleting & (re)creating outputs topics for each test + CLUSTER.createTopic(OUTPUT_STREAM_1, 2, 1); + CLUSTER.createTopic(OUTPUT_STREAM_2, 2, 1); + CLUSTER.createTopic(OUTPUT_STREAM_3, 2, 1); + } + + @After + public void shutdown() throws Exception { + if (streams != null) { + streams.close(Duration.ofSeconds(30)); + } + if (streams2 != null) { + streams2.close(Duration.ofSeconds(30)); + } + + CLUSTER.deleteTopics(OUTPUT_STREAM_1, OUTPUT_STREAM_2, OUTPUT_STREAM_3); + } + + @Test + public void shouldPrefixAllInternalTopicNamesWithNamedTopology() throws Exception { + final String countTopologyName = "count-topology"; + final String fkjTopologyName = "FKJ-topology"; + + final NamedTopologyStreamsBuilder countBuilder = new NamedTopologyStreamsBuilder(countTopologyName); + countBuilder.stream(INPUT_STREAM_1).groupBy((k, v) -> k).count(); + + final NamedTopologyStreamsBuilder fkjBuilder = new NamedTopologyStreamsBuilder(fkjTopologyName); + + final UniqueTopicSerdeScope serdeScope = new UniqueTopicSerdeScope(); + final KTable left = fkjBuilder.table( + INPUT_STREAM_2, + Consumed.with(serdeScope.decorateSerde(Serdes.String(), props, true), + serdeScope.decorateSerde(Serdes.Long(), props, false)) + ); + final KTable right = fkjBuilder.table( + INPUT_STREAM_3, + Consumed.with(serdeScope.decorateSerde(Serdes.String(), props, true), + serdeScope.decorateSerde(Serdes.Long(), props, false)) + ); + left.join( + right, + Object::toString, + (value1, value2) -> String.valueOf(value1 + value2), + Materialized.with(null, serdeScope.decorateSerde(Serdes.String(), props, false))); + + streams = new KafkaStreamsNamedTopologyWrapper(buildNamedTopologies(fkjBuilder, countBuilder), props, clientSupplier); + IntegrationTestUtils.startApplicationAndWaitUntilRunning(singletonList(streams), Duration.ofSeconds(15)); + + final String countTopicPrefix = appId + "-" + countTopologyName; + final String fkjTopicPrefix = appId + "-" + fkjTopologyName; + final Set internalTopics = CLUSTER + .getAllTopicsInCluster().stream() + .filter(t -> t.contains(appId)) + .filter(t -> t.endsWith("-repartition") || t.endsWith("-changelog") || t.endsWith("-topic")) + .collect(Collectors.toSet()); + assertThat(internalTopics, is(mkSet( + countTopicPrefix + "-KSTREAM-AGGREGATE-STATE-STORE-0000000002-repartition", + countTopicPrefix + "-KSTREAM-AGGREGATE-STATE-STORE-0000000002-changelog", + fkjTopicPrefix + "-KTABLE-FK-JOIN-SUBSCRIPTION-REGISTRATION-0000000006-topic", + fkjTopicPrefix + "-KTABLE-FK-JOIN-SUBSCRIPTION-RESPONSE-0000000014-topic", + fkjTopicPrefix + "-KTABLE-FK-JOIN-SUBSCRIPTION-STATE-STORE-0000000010-changelog", + fkjTopicPrefix + "-" + INPUT_STREAM_2 + "-STATE-STORE-0000000000-changelog", + fkjTopicPrefix + "-" + 
INPUT_STREAM_3 + "-STATE-STORE-0000000003-changelog")) + ); + } + + @Test + public void shouldProcessSingleNamedTopologyAndPrefixInternalTopics() throws Exception { + topology1Builder.stream(INPUT_STREAM_1) + .selectKey((k, v) -> k) + .groupByKey() + .count(ROCKSDB_STORE) + .toStream().to(OUTPUT_STREAM_1); + streams = new KafkaStreamsNamedTopologyWrapper(topology1Builder.buildNamedTopology(props), props, clientSupplier); + IntegrationTestUtils.startApplicationAndWaitUntilRunning(singletonList(streams), Duration.ofSeconds(15)); + final List> results = waitUntilMinKeyValueRecordsReceived(consumerConfig, OUTPUT_STREAM_1, 3); + assertThat(results, equalTo(COUNT_OUTPUT_DATA)); + + final Set allTopics = CLUSTER.getAllTopicsInCluster(); + assertThat(allTopics.contains(appId + "-" + "topology-1" + "-store-changelog"), is(true)); + assertThat(allTopics.contains(appId + "-" + "topology-1" + "-store-repartition"), is(true)); + } + + @Test + public void shouldProcessMultipleIdenticalNamedTopologiesWithInMemoryAndPersistentStateStores() throws Exception { + topology1Builder.stream(INPUT_STREAM_1).groupBy((k, v) -> k).count(ROCKSDB_STORE).toStream().to(OUTPUT_STREAM_1); + topology2Builder.stream(INPUT_STREAM_2).groupBy((k, v) -> k).count(IN_MEMORY_STORE).toStream().to(OUTPUT_STREAM_2); + topology3Builder.stream(INPUT_STREAM_3).groupBy((k, v) -> k).count(ROCKSDB_STORE).toStream().to(OUTPUT_STREAM_3); + streams = new KafkaStreamsNamedTopologyWrapper(buildNamedTopologies(topology1Builder, topology2Builder, topology3Builder), props, clientSupplier); + IntegrationTestUtils.startApplicationAndWaitUntilRunning(singletonList(streams), Duration.ofSeconds(15)); + + assertThat(waitUntilMinKeyValueRecordsReceived(consumerConfig, OUTPUT_STREAM_1, 3), equalTo(COUNT_OUTPUT_DATA)); + assertThat(waitUntilMinKeyValueRecordsReceived(consumerConfig, OUTPUT_STREAM_2, 3), equalTo(COUNT_OUTPUT_DATA)); + assertThat(waitUntilMinKeyValueRecordsReceived(consumerConfig, OUTPUT_STREAM_3, 3), equalTo(COUNT_OUTPUT_DATA)); + + assertThat(CLUSTER.getAllTopicsInCluster().containsAll(asList(changelog1, changelog2, changelog3)), is(true)); + } + + @Test + public void shouldAddNamedTopologyToUnstartedApplicationWithEmptyInitialTopology() throws Exception { + topology1Builder.stream(INPUT_STREAM_1).groupBy((k, v) -> k).count(IN_MEMORY_STORE).toStream().to(OUTPUT_STREAM_1); + streams = new KafkaStreamsNamedTopologyWrapper(props, clientSupplier); + streams.addNamedTopology(topology1Builder.buildNamedTopology(props)); + IntegrationTestUtils.startApplicationAndWaitUntilRunning(singletonList(streams), Duration.ofSeconds(15)); + + assertThat(waitUntilMinKeyValueRecordsReceived(consumerConfig, OUTPUT_STREAM_1, 3), equalTo(COUNT_OUTPUT_DATA)); + } + + @Test + public void shouldAddNamedTopologyToRunningApplicationWithEmptyInitialTopology() throws Exception { + topology1Builder.stream(INPUT_STREAM_1).groupBy((k, v) -> k).count(IN_MEMORY_STORE).toStream().to(OUTPUT_STREAM_1); + streams = new KafkaStreamsNamedTopologyWrapper(props, clientSupplier); + streams.start(); + + streams.addNamedTopology(topology1Builder.buildNamedTopology(props)); + IntegrationTestUtils.waitForApplicationState(singletonList(streams), State.RUNNING, Duration.ofSeconds(15)); + + assertThat(waitUntilMinKeyValueRecordsReceived(consumerConfig, OUTPUT_STREAM_1, 3), equalTo(COUNT_OUTPUT_DATA)); + } + + @Test + public void shouldAddNamedTopologyToRunningApplicationWithSingleInitialNamedTopology() throws Exception { + topology1Builder.stream(INPUT_STREAM_1).groupBy((k, v) -> 
k).count(IN_MEMORY_STORE).toStream().to(OUTPUT_STREAM_1); + topology2Builder.stream(INPUT_STREAM_2).groupBy((k, v) -> k).count(IN_MEMORY_STORE).toStream().to(OUTPUT_STREAM_2); + streams = new KafkaStreamsNamedTopologyWrapper(topology1Builder.buildNamedTopology(props), props, clientSupplier); + IntegrationTestUtils.startApplicationAndWaitUntilRunning(singletonList(streams), Duration.ofSeconds(15)); + + streams.addNamedTopology(topology2Builder.buildNamedTopology(props)); + + assertThat(waitUntilMinKeyValueRecordsReceived(consumerConfig, OUTPUT_STREAM_1, 3), equalTo(COUNT_OUTPUT_DATA)); + assertThat(waitUntilMinKeyValueRecordsReceived(consumerConfig, OUTPUT_STREAM_2, 3), equalTo(COUNT_OUTPUT_DATA)); + } + + @Test + public void shouldAddNamedTopologyToRunningApplicationWithMultipleInitialNamedTopologies() throws Exception { + topology1Builder.stream(INPUT_STREAM_1).groupBy((k, v) -> k).count(ROCKSDB_STORE).toStream().to(OUTPUT_STREAM_1); + topology2Builder.stream(INPUT_STREAM_2).groupBy((k, v) -> k).count(ROCKSDB_STORE).toStream().to(OUTPUT_STREAM_2); + topology3Builder.stream(INPUT_STREAM_3).groupBy((k, v) -> k).count(ROCKSDB_STORE).toStream().to(OUTPUT_STREAM_3); + streams = new KafkaStreamsNamedTopologyWrapper(buildNamedTopologies(topology1Builder, topology2Builder), props, clientSupplier); + IntegrationTestUtils.startApplicationAndWaitUntilRunning(singletonList(streams), Duration.ofSeconds(15)); + + streams.addNamedTopology(topology3Builder.buildNamedTopology(props)); + + assertThat(waitUntilMinKeyValueRecordsReceived(consumerConfig, OUTPUT_STREAM_1, 3), equalTo(COUNT_OUTPUT_DATA)); + assertThat(waitUntilMinKeyValueRecordsReceived(consumerConfig, OUTPUT_STREAM_2, 3), equalTo(COUNT_OUTPUT_DATA)); + assertThat(waitUntilMinKeyValueRecordsReceived(consumerConfig, OUTPUT_STREAM_3, 3), equalTo(COUNT_OUTPUT_DATA)); + } + + @Test + public void shouldAddNamedTopologyToRunningApplicationWithMultipleNodes() throws Exception { + topology1Builder.stream(INPUT_STREAM_1).groupBy((k, v) -> k).count(IN_MEMORY_STORE).toStream().to(OUTPUT_STREAM_1); + topology1Builder2.stream(INPUT_STREAM_1).groupBy((k, v) -> k).count(IN_MEMORY_STORE).toStream().to(OUTPUT_STREAM_1); + + topology2Builder.stream(INPUT_STREAM_2).groupBy((k, v) -> k).count(IN_MEMORY_STORE).toStream().to(OUTPUT_STREAM_2); + topology2Builder2.stream(INPUT_STREAM_2).groupBy((k, v) -> k).count(IN_MEMORY_STORE).toStream().to(OUTPUT_STREAM_2); + + streams = new KafkaStreamsNamedTopologyWrapper(topology1Builder.buildNamedTopology(props), props, clientSupplier); + streams2 = new KafkaStreamsNamedTopologyWrapper(topology1Builder2.buildNamedTopology(props2), props2, clientSupplier); + IntegrationTestUtils.startApplicationAndWaitUntilRunning(asList(streams, streams2), Duration.ofSeconds(15)); + + streams.addNamedTopology(topology2Builder.buildNamedTopology(props)); + streams2.addNamedTopology(topology2Builder2.buildNamedTopology(props2)); + + assertThat(waitUntilMinKeyValueRecordsReceived(consumerConfig, OUTPUT_STREAM_1, 3), equalTo(COUNT_OUTPUT_DATA)); + assertThat(waitUntilMinKeyValueRecordsReceived(consumerConfig, OUTPUT_STREAM_2, 3), equalTo(COUNT_OUTPUT_DATA)); + + // TODO KAFKA-12648: need to make sure that both instances actually did some of this processing of topology-2, + // ie that both joined the group after the new topology was added and then successfully processed records from it + // Also: test where we wait for a rebalance between streams.addNamedTopology and streams2.addNamedTopology, + // and vice versa, to make sure we hit case where not 
all new tasks are initially assigned, and when not all yet known + } + + @Ignore // TODO KAFKA-12648: re-enable once we have the ability to block on the removed topology + @Test + public void shouldRemoveOneNamedTopologyWhileAnotherContinuesProcessing() throws Exception { + topology1Builder.stream(DELAYED_INPUT_STREAM_1).groupBy((k, v) -> k).count(IN_MEMORY_STORE).toStream().to(OUTPUT_STREAM_1); + topology2Builder.stream(DELAYED_INPUT_STREAM_2).map((k, v) -> { + throw new IllegalStateException("Should not process any records for removed topology-2"); + }); + streams = new KafkaStreamsNamedTopologyWrapper(buildNamedTopologies(topology1Builder, topology2Builder), props, clientSupplier); + IntegrationTestUtils.startApplicationAndWaitUntilRunning(singletonList(streams), Duration.ofSeconds(15)); + + streams.removeNamedTopology("topology-2"); + + produceToInputTopics(DELAYED_INPUT_STREAM_1, STANDARD_INPUT_DATA); + produceToInputTopics(DELAYED_INPUT_STREAM_2, STANDARD_INPUT_DATA); + + assertThat(waitUntilMinKeyValueRecordsReceived(consumerConfig, OUTPUT_STREAM_1, 3), equalTo(COUNT_OUTPUT_DATA)); + } + + @Ignore // TODO KAFKA-12648: re-enable once we have the ability to block on the removed topology + @Test + public void shouldRemoveAndReplaceTopologicallyIncompatibleNamedTopology() throws Exception { + CLUSTER.createTopics(SUM_OUTPUT, COUNT_OUTPUT); + // Build up named topology with two stateful subtopologies + final KStream inputStream1 = topology1Builder.stream(INPUT_STREAM_1); + inputStream1.groupByKey().count().toStream().to(COUNT_OUTPUT); + inputStream1.groupByKey().reduce(Long::sum).toStream().to(SUM_OUTPUT); + streams = new KafkaStreamsNamedTopologyWrapper(buildNamedTopologies(topology1Builder), props, clientSupplier); + IntegrationTestUtils.startApplicationAndWaitUntilRunning(singletonList(streams), Duration.ofSeconds(15)); + + assertThat(waitUntilMinKeyValueRecordsReceived(consumerConfig, COUNT_OUTPUT, 3), equalTo(COUNT_OUTPUT_DATA)); + assertThat(waitUntilMinKeyValueRecordsReceived(consumerConfig, SUM_OUTPUT, 3), equalTo(SUM_OUTPUT_DATA)); + streams.removeNamedTopology("topology-1"); + streams.cleanUpNamedTopology("topology-1"); + + // Prepare a new named topology with the same name but an incompatible topology (stateful subtopologies swap order) + final KStream inputStream2 = topology1Builder2.stream(DELAYED_INPUT_STREAM_1); + inputStream2.groupByKey().reduce(Long::sum).toStream().to(SUM_OUTPUT); + inputStream2.groupByKey().count().toStream().to(COUNT_OUTPUT); + + produceToInputTopics(DELAYED_INPUT_STREAM_1, STANDARD_INPUT_DATA); + streams.addNamedTopology(topology1Builder2.buildNamedTopology(props)); + + assertThat(waitUntilMinKeyValueRecordsReceived(consumerConfig, COUNT_OUTPUT, 3), equalTo(COUNT_OUTPUT_DATA)); + assertThat(waitUntilMinKeyValueRecordsReceived(consumerConfig, SUM_OUTPUT, 3), equalTo(SUM_OUTPUT_DATA)); + CLUSTER.deleteTopics(SUM_OUTPUT, COUNT_OUTPUT); + } + + @Test + public void shouldAllowPatternSubscriptionWithMultipleNamedTopologies() throws Exception { + topology1Builder.stream(Pattern.compile(INPUT_STREAM_1)).groupBy((k, v) -> k).count().toStream().to(OUTPUT_STREAM_1); + topology2Builder.stream(Pattern.compile(INPUT_STREAM_2)).groupBy((k, v) -> k).count().toStream().to(OUTPUT_STREAM_2); + topology3Builder.stream(Pattern.compile(INPUT_STREAM_3)).groupBy((k, v) -> k).count().toStream().to(OUTPUT_STREAM_3); + streams = new KafkaStreamsNamedTopologyWrapper(buildNamedTopologies(topology1Builder, topology2Builder, topology3Builder), props, clientSupplier); + 
IntegrationTestUtils.startApplicationAndWaitUntilRunning(singletonList(streams), Duration.ofSeconds(15)); + + assertThat(waitUntilMinKeyValueRecordsReceived(consumerConfig, OUTPUT_STREAM_1, 3), equalTo(COUNT_OUTPUT_DATA)); + assertThat(waitUntilMinKeyValueRecordsReceived(consumerConfig, OUTPUT_STREAM_2, 3), equalTo(COUNT_OUTPUT_DATA)); + assertThat(waitUntilMinKeyValueRecordsReceived(consumerConfig, OUTPUT_STREAM_3, 3), equalTo(COUNT_OUTPUT_DATA)); + } + + @Test + public void shouldAllowMixedCollectionAndPatternSubscriptionWithMultipleNamedTopologies() throws Exception { + topology1Builder.stream(INPUT_STREAM_1).groupBy((k, v) -> k).count().toStream().to(OUTPUT_STREAM_1); + topology2Builder.stream(Pattern.compile(INPUT_STREAM_2)).groupBy((k, v) -> k).count().toStream().to(OUTPUT_STREAM_2); + topology3Builder.stream(Pattern.compile(INPUT_STREAM_3)).groupBy((k, v) -> k).count().toStream().to(OUTPUT_STREAM_3); + streams = new KafkaStreamsNamedTopologyWrapper(buildNamedTopologies(topology1Builder, topology2Builder, topology3Builder), props, clientSupplier); + IntegrationTestUtils.startApplicationAndWaitUntilRunning(singletonList(streams), Duration.ofSeconds(15)); + + assertThat(waitUntilMinKeyValueRecordsReceived(consumerConfig, OUTPUT_STREAM_1, 3), equalTo(COUNT_OUTPUT_DATA)); + assertThat(waitUntilMinKeyValueRecordsReceived(consumerConfig, OUTPUT_STREAM_2, 3), equalTo(COUNT_OUTPUT_DATA)); + assertThat(waitUntilMinKeyValueRecordsReceived(consumerConfig, OUTPUT_STREAM_3, 3), equalTo(COUNT_OUTPUT_DATA)); + } + + private static void produceToInputTopics(final String topic, final Collection> records) { + IntegrationTestUtils.produceKeyValuesSynchronously( + topic, + records, + producerConfig, + CLUSTER.time + ); + } + + private List buildNamedTopologies(final NamedTopologyStreamsBuilder... builders) { + final List topologies = new ArrayList<>(); + for (final NamedTopologyStreamsBuilder builder : builders) { + topologies.add(builder.buildNamedTopology(props)); + } + return topologies; + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/integration/OptimizedKTableIntegrationTest.java b/streams/src/test/java/org/apache/kafka/streams/integration/OptimizedKTableIntegrationTest.java new file mode 100644 index 0000000..44744cd --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/integration/OptimizedKTableIntegrationTest.java @@ -0,0 +1,207 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.integration; + +import static org.apache.kafka.streams.integration.utils.IntegrationTestUtils.safeUniqueTestName; +import static org.apache.kafka.streams.integration.utils.IntegrationTestUtils.startApplicationAndWaitUntilRunning; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.is; + +import java.io.IOException; +import java.time.Duration; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Properties; +import java.util.concurrent.Semaphore; +import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; +import java.util.stream.IntStream; +import org.apache.kafka.clients.consumer.ConsumerConfig; +import org.apache.kafka.clients.producer.ProducerConfig; +import org.apache.kafka.common.serialization.IntegerSerializer; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.streams.KafkaStreams; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.KeyQueryMetadata; +import org.apache.kafka.streams.integration.utils.EmbeddedKafkaCluster; +import org.apache.kafka.streams.integration.utils.IntegrationTestUtils; +import org.apache.kafka.streams.kstream.Consumed; +import org.apache.kafka.streams.kstream.Materialized; +import org.apache.kafka.streams.state.KeyValueStore; +import org.apache.kafka.streams.state.QueryableStoreTypes; +import org.apache.kafka.streams.state.ReadOnlyKeyValueStore; +import org.apache.kafka.test.IntegrationTest; +import org.apache.kafka.test.TestUtils; +import org.junit.After; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Rule; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.rules.TestName; + +@Category(IntegrationTest.class) +public class OptimizedKTableIntegrationTest { + private static final int NUM_BROKERS = 1; + private static int port = 0; + private static final String INPUT_TOPIC_NAME = "input-topic"; + private static final String TABLE_NAME = "source-table"; + + public static final EmbeddedKafkaCluster CLUSTER = new EmbeddedKafkaCluster(NUM_BROKERS); + + @BeforeClass + public static void startCluster() throws IOException { + CLUSTER.start(); + } + + @AfterClass + public static void closeCluster() { + CLUSTER.stop(); + } + + + @Rule + public final TestName testName = new TestName(); + + private final List streamsToCleanup = new ArrayList<>(); + private final MockTime mockTime = CLUSTER.time; + + @Before + public void before() throws InterruptedException { + CLUSTER.createTopic(INPUT_TOPIC_NAME, 2, 1); + } + + @After + public void after() { + for (final KafkaStreams kafkaStreams : streamsToCleanup) { + kafkaStreams.close(); + } + } + + @Test + public void shouldApplyUpdatesToStandbyStore() throws Exception { + final int batch1NumMessages = 100; + final int batch2NumMessages = 100; + final int key = 1; + final Semaphore semaphore = new Semaphore(0); + + final StreamsBuilder builder = new StreamsBuilder(); + builder + .table(INPUT_TOPIC_NAME, Consumed.with(Serdes.Integer(), Serdes.Integer()), + Materialized.>as(TABLE_NAME) + .withCachingDisabled()) + .toStream() + .peek((k, v) -> semaphore.release()); + + final KafkaStreams kafkaStreams1 = 
createKafkaStreams(builder, streamsConfiguration()); + final KafkaStreams kafkaStreams2 = createKafkaStreams(builder, streamsConfiguration()); + final List<KafkaStreams> kafkaStreamsList = Arrays.asList(kafkaStreams1, kafkaStreams2); + + startApplicationAndWaitUntilRunning(kafkaStreamsList, Duration.ofSeconds(60)); + + produceValueRange(key, 0, batch1NumMessages); + + // Assert that all messages in the first batch were processed in a timely manner + assertThat(semaphore.tryAcquire(batch1NumMessages, 60, TimeUnit.SECONDS), is(equalTo(true))); + + final ReadOnlyKeyValueStore<Integer, Integer> store1 = IntegrationTestUtils.getStore(TABLE_NAME, kafkaStreams1, QueryableStoreTypes.keyValueStore()); + final ReadOnlyKeyValueStore<Integer, Integer> store2 = IntegrationTestUtils.getStore(TABLE_NAME, kafkaStreams2, QueryableStoreTypes.keyValueStore()); + + final boolean kafkaStreams1WasFirstActive; + final KeyQueryMetadata keyQueryMetadata = kafkaStreams1.queryMetadataForKey(TABLE_NAME, key, (topic, somekey, value, numPartitions) -> 0); + + // Assert that the current value in the store reflects all messages being processed + if ((keyQueryMetadata.activeHost().port() % 2) == 1) { + assertThat(store1.get(key), is(equalTo(batch1NumMessages - 1))); + kafkaStreams1WasFirstActive = true; + } else { + assertThat(store2.get(key), is(equalTo(batch1NumMessages - 1))); + kafkaStreams1WasFirstActive = false; + } + + if (kafkaStreams1WasFirstActive) { + kafkaStreams1.close(); + } else { + kafkaStreams2.close(); + } + + final ReadOnlyKeyValueStore<Integer, Integer> newActiveStore = kafkaStreams1WasFirstActive ? store2 : store1; + TestUtils.retryOnExceptionWithTimeout(60 * 1000, 100, () -> { + // Assert that after failover we have recovered to the last store write + assertThat(newActiveStore.get(key), is(equalTo(batch1NumMessages - 1))); + }); + + final int totalNumMessages = batch1NumMessages + batch2NumMessages; + + produceValueRange(key, batch1NumMessages, totalNumMessages); + + // Assert that all messages in the second batch were processed in a timely manner + assertThat(semaphore.tryAcquire(batch2NumMessages, 60, TimeUnit.SECONDS), is(equalTo(true))); + + TestUtils.retryOnExceptionWithTimeout(60 * 1000, 100, () -> { + // Assert that the current value in the store reflects all messages being processed + assertThat(newActiveStore.get(key), is(equalTo(totalNumMessages - 1))); + }); + } + + private void produceValueRange(final int key, final int start, final int endExclusive) { + final Properties producerProps = new Properties(); + producerProps.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, CLUSTER.bootstrapServers()); + producerProps.put(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG, IntegerSerializer.class); + producerProps.put(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG, IntegerSerializer.class); + + IntegrationTestUtils.produceKeyValuesSynchronously( + INPUT_TOPIC_NAME, + IntStream.range(start, endExclusive) + .mapToObj(i -> KeyValue.pair(key, i)) + .collect(Collectors.toList()), + producerProps, + mockTime); + } + + private KafkaStreams createKafkaStreams(final StreamsBuilder builder, final Properties config) { + final KafkaStreams streams = new KafkaStreams(builder.build(config), config); + streamsToCleanup.add(streams); + return streams; + } + + private Properties streamsConfiguration() { + final String safeTestName = safeUniqueTestName(getClass(), testName); + final Properties config = new Properties(); + config.put(StreamsConfig.TOPOLOGY_OPTIMIZATION_CONFIG, StreamsConfig.OPTIMIZE); + config.put(StreamsConfig.APPLICATION_ID_CONFIG, "app-" + safeTestName); +
config.put(StreamsConfig.APPLICATION_SERVER_CONFIG, "localhost:" + (++port)); + config.put(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, CLUSTER.bootstrapServers()); + config.put(StreamsConfig.STATE_DIR_CONFIG, TestUtils.tempDirectory().getPath()); + config.put(StreamsConfig.DEFAULT_KEY_SERDE_CLASS_CONFIG, Serdes.Integer().getClass()); + config.put(StreamsConfig.DEFAULT_VALUE_SERDE_CLASS_CONFIG, Serdes.Integer().getClass()); + config.put(StreamsConfig.NUM_STANDBY_REPLICAS_CONFIG, 1); + config.put(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG, 100L); + config.put(StreamsConfig.CACHE_MAX_BYTES_BUFFERING_CONFIG, 0); + config.put(ConsumerConfig.MAX_POLL_RECORDS_CONFIG, 100); + config.put(ConsumerConfig.HEARTBEAT_INTERVAL_MS_CONFIG, 200); + config.put(ConsumerConfig.SESSION_TIMEOUT_MS_CONFIG, 1000); + return config; + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/integration/PurgeRepartitionTopicIntegrationTest.java b/streams/src/test/java/org/apache/kafka/streams/integration/PurgeRepartitionTopicIntegrationTest.java new file mode 100644 index 0000000..ffb3531 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/integration/PurgeRepartitionTopicIntegrationTest.java @@ -0,0 +1,219 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.integration; + +import org.apache.kafka.clients.admin.Admin; +import org.apache.kafka.clients.admin.Config; +import org.apache.kafka.clients.admin.LogDirDescription; +import org.apache.kafka.clients.admin.ReplicaInfo; +import org.apache.kafka.clients.producer.ProducerConfig; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.config.ConfigResource; +import org.apache.kafka.common.config.TopicConfig; +import org.apache.kafka.common.serialization.IntegerSerializer; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.streams.KafkaStreams; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.integration.utils.EmbeddedKafkaCluster; +import org.apache.kafka.streams.integration.utils.IntegrationTestUtils; +import org.apache.kafka.test.IntegrationTest; +import org.apache.kafka.test.MockMapper; +import org.apache.kafka.test.TestCondition; +import org.apache.kafka.test.TestUtils; +import org.junit.After; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; +import org.junit.experimental.categories.Category; + +import java.io.IOException; +import java.time.Duration; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.List; +import java.util.Properties; +import java.util.Set; + +@Category({IntegrationTest.class}) +public class PurgeRepartitionTopicIntegrationTest { + + private static final int NUM_BROKERS = 1; + + private static final String INPUT_TOPIC = "input-stream"; + private static final String APPLICATION_ID = "restore-test"; + private static final String REPARTITION_TOPIC = APPLICATION_ID + "-KSTREAM-AGGREGATE-STATE-STORE-0000000002-repartition"; + + private static Admin adminClient; + private static KafkaStreams kafkaStreams; + private static final Integer PURGE_INTERVAL_MS = 10; + private static final Integer PURGE_SEGMENT_BYTES = 2000; + + public static final EmbeddedKafkaCluster CLUSTER = new EmbeddedKafkaCluster(NUM_BROKERS, new Properties() { + { + put("log.retention.check.interval.ms", PURGE_INTERVAL_MS); + put(TopicConfig.FILE_DELETE_DELAY_MS_CONFIG, 0); + } + }); + + @BeforeClass + public static void startCluster() throws IOException, InterruptedException { + CLUSTER.start(); + CLUSTER.createTopic(INPUT_TOPIC, 1, 1); + } + + @AfterClass + public static void closeCluster() { + CLUSTER.stop(); + } + + + private final Time time = CLUSTER.time; + + private class RepartitionTopicCreatedWithExpectedConfigs implements TestCondition { + @Override + final public boolean conditionMet() { + try { + final Set topics = adminClient.listTopics().names().get(); + + if (!topics.contains(REPARTITION_TOPIC)) { + return false; + } + } catch (final Exception e) { + return false; + } + + try { + final ConfigResource resource = new ConfigResource(ConfigResource.Type.TOPIC, REPARTITION_TOPIC); + final Config config = adminClient + .describeConfigs(Collections.singleton(resource)) + .values() + .get(resource) + .get(); + return config.get(TopicConfig.CLEANUP_POLICY_CONFIG).value().equals(TopicConfig.CLEANUP_POLICY_DELETE) + && config.get(TopicConfig.SEGMENT_MS_CONFIG).value().equals(PURGE_INTERVAL_MS.toString()) + && config.get(TopicConfig.SEGMENT_BYTES_CONFIG).value().equals(PURGE_SEGMENT_BYTES.toString()); + } catch (final Exception e) { + 
return false; + } + } + } + + private interface TopicSizeVerifier { + boolean verify(long currentSize); + } + + private class RepartitionTopicVerified implements TestCondition { + private final TopicSizeVerifier verifier; + + RepartitionTopicVerified(final TopicSizeVerifier verifier) { + this.verifier = verifier; + } + + @Override + public final boolean conditionMet() { + time.sleep(PURGE_INTERVAL_MS); + + try { + final Collection<LogDirDescription> logDirInfo = + adminClient.describeLogDirs(Collections.singleton(0)).descriptions().get(0).get().values(); + + for (final LogDirDescription partitionInfo : logDirInfo) { + final ReplicaInfo replicaInfo = + partitionInfo.replicaInfos().get(new TopicPartition(REPARTITION_TOPIC, 0)); + if (replicaInfo != null && verifier.verify(replicaInfo.size())) { + return true; + } + } + } catch (final Exception e) { + // swallow + } + + return false; + } + } + + @Before + public void setup() { + // create admin client for verification + final Properties adminConfig = new Properties(); + adminConfig.put(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, CLUSTER.bootstrapServers()); + adminClient = Admin.create(adminConfig); + + final Properties streamsConfiguration = new Properties(); + streamsConfiguration.put(StreamsConfig.APPLICATION_ID_CONFIG, APPLICATION_ID); + streamsConfiguration.put(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG, PURGE_INTERVAL_MS); + streamsConfiguration.put(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, CLUSTER.bootstrapServers()); + streamsConfiguration.put(StreamsConfig.DEFAULT_KEY_SERDE_CLASS_CONFIG, Serdes.Integer().getClass()); + streamsConfiguration.put(StreamsConfig.DEFAULT_VALUE_SERDE_CLASS_CONFIG, Serdes.Integer().getClass()); + streamsConfiguration.put(StreamsConfig.STATE_DIR_CONFIG, TestUtils.tempDirectory(APPLICATION_ID).getPath()); + streamsConfiguration.put(StreamsConfig.topicPrefix(TopicConfig.SEGMENT_MS_CONFIG), PURGE_INTERVAL_MS); + streamsConfiguration.put(StreamsConfig.topicPrefix(TopicConfig.SEGMENT_BYTES_CONFIG), PURGE_SEGMENT_BYTES); + streamsConfiguration.put(StreamsConfig.producerPrefix(ProducerConfig.BATCH_SIZE_CONFIG), PURGE_SEGMENT_BYTES / 2); // we cannot allow the batch size to be larger than the segment size + + final StreamsBuilder builder = new StreamsBuilder(); + builder.stream(INPUT_TOPIC) + .groupBy(MockMapper.selectKeyKeyValueMapper()) + .count(); + + kafkaStreams = new KafkaStreams(builder.build(), streamsConfiguration, time); + } + + @After + public void shutdown() { + if (kafkaStreams != null) { + kafkaStreams.close(Duration.ofSeconds(30)); + } + } + + @Test + public void shouldRestoreState() throws Exception { + // produce some data to the input topic + final List<KeyValue<Integer, Integer>> messages = new ArrayList<>(); + for (int i = 0; i < 1000; i++) { + messages.add(new KeyValue<>(i, i)); + } + IntegrationTestUtils.produceKeyValuesSynchronouslyWithTimestamp(INPUT_TOPIC, + messages, + TestUtils.producerConfig(CLUSTER.bootstrapServers(), + IntegerSerializer.class, + IntegerSerializer.class), + time.milliseconds()); + + kafkaStreams.start(); + + TestUtils.waitForCondition(new RepartitionTopicCreatedWithExpectedConfigs(), 60000, + "Repartition topic " + REPARTITION_TOPIC + " was not created with the expected configs after 60000 ms."); + + TestUtils.waitForCondition( + new RepartitionTopicVerified(currentSize -> currentSize > 0), + 60000, + "Repartition topic " + REPARTITION_TOPIC + " did not receive data after 60000 ms."
+ ); + + // we need long enough timeout to by-pass the log manager's InitialTaskDelayMs, which is hard-coded on server side + TestUtils.waitForCondition( + new RepartitionTopicVerified(currentSize -> currentSize <= PURGE_SEGMENT_BYTES), + 60000, + "Repartition topic " + REPARTITION_TOPIC + " not purged data after 60000 ms." + ); + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/integration/QueryableStateIntegrationTest.java b/streams/src/test/java/org/apache/kafka/streams/integration/QueryableStateIntegrationTest.java new file mode 100644 index 0000000..15b9ea6 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/integration/QueryableStateIntegrationTest.java @@ -0,0 +1,1296 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.integration; + +import kafka.utils.MockTime; +import org.apache.kafka.clients.consumer.ConsumerConfig; +import org.apache.kafka.clients.producer.KafkaProducer; +import org.apache.kafka.clients.producer.ProducerConfig; +import org.apache.kafka.clients.producer.ProducerRecord; +import org.apache.kafka.common.serialization.LongDeserializer; +import org.apache.kafka.common.serialization.LongSerializer; +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.serialization.StringDeserializer; +import org.apache.kafka.common.serialization.StringSerializer; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.streams.KafkaStreams; +import org.apache.kafka.streams.KafkaStreams.State; +import org.apache.kafka.streams.KafkaStreamsTest; +import org.apache.kafka.streams.KeyQueryMetadata; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.LagInfo; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.errors.InvalidStateStoreException; +import org.apache.kafka.streams.errors.UnknownStateStoreException; +import org.apache.kafka.streams.integration.utils.EmbeddedKafkaCluster; +import org.apache.kafka.streams.integration.utils.IntegrationTestUtils; +import org.apache.kafka.streams.kstream.Consumed; +import org.apache.kafka.streams.kstream.KGroupedStream; +import org.apache.kafka.streams.kstream.KStream; +import org.apache.kafka.streams.kstream.KTable; +import org.apache.kafka.streams.kstream.Materialized; +import org.apache.kafka.streams.kstream.Predicate; +import org.apache.kafka.streams.kstream.Produced; +import org.apache.kafka.streams.kstream.TimeWindows; +import org.apache.kafka.streams.kstream.ValueMapper; +import org.apache.kafka.streams.state.KeyValueIterator; +import org.apache.kafka.streams.state.KeyValueStore; +import 
org.apache.kafka.streams.state.QueryableStoreTypes; +import org.apache.kafka.streams.state.ReadOnlyKeyValueStore; +import org.apache.kafka.streams.state.ReadOnlySessionStore; +import org.apache.kafka.streams.state.ReadOnlyWindowStore; +import org.apache.kafka.streams.state.WindowStoreIterator; +import org.apache.kafka.test.IntegrationTest; +import org.apache.kafka.test.MockMapper; +import org.apache.kafka.test.NoRetryException; +import org.apache.kafka.test.TestUtils; +import org.hamcrest.Matchers; +import org.junit.After; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Rule; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.rules.TestName; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.BufferedReader; +import java.io.ByteArrayOutputStream; +import java.io.File; +import java.io.FileReader; +import java.io.IOException; +import java.io.PrintStream; +import java.io.StringReader; +import java.time.Duration; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.Comparator; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; +import java.util.Objects; +import java.util.Properties; +import java.util.Set; +import java.util.TreeMap; +import java.util.TreeSet; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; + +import static java.time.Duration.ofMillis; +import static java.time.Duration.ofSeconds; +import static java.time.Instant.ofEpochMilli; +import static org.apache.kafka.common.utils.Utils.mkEntry; +import static org.apache.kafka.common.utils.Utils.mkMap; +import static org.apache.kafka.common.utils.Utils.mkProperties; +import static org.apache.kafka.common.utils.Utils.mkSet; +import static org.apache.kafka.streams.StoreQueryParameters.fromNameAndType; +import static org.apache.kafka.streams.integration.utils.IntegrationTestUtils.getRunningStreams; +import static org.apache.kafka.streams.integration.utils.IntegrationTestUtils.safeUniqueTestName; +import static org.apache.kafka.streams.integration.utils.IntegrationTestUtils.startApplicationAndWaitUntilRunning; +import static org.apache.kafka.streams.integration.utils.IntegrationTestUtils.waitForApplicationState; +import static org.apache.kafka.streams.state.QueryableStoreTypes.keyValueStore; +import static org.apache.kafka.streams.state.QueryableStoreTypes.sessionStore; +import static org.apache.kafka.test.StreamsTestUtils.startKafkaStreamsAndWaitForRunningState; +import static org.apache.kafka.test.TestUtils.retryOnExceptionWithTimeout; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.core.IsEqual.equalTo; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThrows; + +@Category({IntegrationTest.class}) +@SuppressWarnings("deprecation") +public class QueryableStateIntegrationTest { + private static final Logger log = LoggerFactory.getLogger(QueryableStateIntegrationTest.class); + + private static final long DEFAULT_TIMEOUT_MS = 120 * 1000; + + private static final int NUM_BROKERS = 1; + + public static final EmbeddedKafkaCluster CLUSTER = new EmbeddedKafkaCluster(NUM_BROKERS); + + @BeforeClass + public static void startCluster() throws IOException { + CLUSTER.start(); + } + + @AfterClass + public static void closeCluster() { + 
CLUSTER.stop(); + } + + private static final int STREAM_THREE_PARTITIONS = 4; + private final MockTime mockTime = CLUSTER.time; + private String streamOne = "stream-one"; + private String streamTwo = "stream-two"; + private String streamThree = "stream-three"; + private String streamConcurrent = "stream-concurrent"; + private String outputTopic = "output"; + private String outputTopicConcurrent = "output-concurrent"; + private String outputTopicConcurrentWindowed = "output-concurrent-windowed"; + private String outputTopicThree = "output-three"; + // sufficiently large window size such that everything falls into 1 window + private static final long WINDOW_SIZE = TimeUnit.MILLISECONDS.convert(2, TimeUnit.DAYS); + private static final int STREAM_TWO_PARTITIONS = 2; + private static final int NUM_REPLICAS = NUM_BROKERS; + private Properties streamsConfiguration; + private List<String> inputValues; + private Set<String> inputValuesKeys; + private KafkaStreams kafkaStreams; + private Comparator<KeyValue<String, String>> stringComparator; + private Comparator<KeyValue<String, Long>> stringLongComparator; + + private void createTopics() throws Exception { + final String safeTestName = safeUniqueTestName(getClass(), testName); + streamOne = streamOne + "-" + safeTestName; + streamConcurrent = streamConcurrent + "-" + safeTestName; + streamThree = streamThree + "-" + safeTestName; + outputTopic = outputTopic + "-" + safeTestName; + outputTopicConcurrent = outputTopicConcurrent + "-" + safeTestName; + outputTopicConcurrentWindowed = outputTopicConcurrentWindowed + "-" + safeTestName; + outputTopicThree = outputTopicThree + "-" + safeTestName; + streamTwo = streamTwo + "-" + safeTestName; + CLUSTER.createTopics(streamOne, streamConcurrent); + CLUSTER.createTopic(streamTwo, STREAM_TWO_PARTITIONS, NUM_REPLICAS); + CLUSTER.createTopic(streamThree, STREAM_THREE_PARTITIONS, 1); + CLUSTER.createTopics(outputTopic, outputTopicConcurrent, outputTopicConcurrentWindowed, outputTopicThree); + } + + /** + * Try to read inputValues from {@code resources/QueryableStateIntegrationTest/inputValues.txt}, which might be useful + * for larger scale testing. In case of exception, for instance if no such file can be read, return a small list + * which satisfies all the prerequisites of the tests. + */ + private List<String> getInputValues() { + List<String> input = new ArrayList<>(); + final ClassLoader classLoader = getClass().getClassLoader(); + final String fileName = "QueryableStateIntegrationTest" + File.separator + "inputValues.txt"; + try (final BufferedReader reader = new BufferedReader( + new FileReader(Objects.requireNonNull(classLoader.getResource(fileName)).getFile()))) { + + for (String line = reader.readLine(); line != null; line = reader.readLine()) { + input.add(line); + } + } catch (final Exception e) { + log.warn("Unable to read '{}{}{}'. 
Using default inputValues list", "resources", File.separator, fileName); + input = Arrays.asList( + "hello world", + "all streams lead to kafka", + "streams", + "kafka streams", + "the cat in the hat", + "green eggs and ham", + "that Sam i am", + "up the creek without a paddle", + "run forest run", + "a tank full of gas", + "eat sleep rave repeat", + "one jolly sailor", + "king of the world"); + + } + return input; + } + + @Rule + public TestName testName = new TestName(); + + @Before + public void before() throws Exception { + createTopics(); + streamsConfiguration = new Properties(); + final String safeTestName = safeUniqueTestName(getClass(), testName); + + streamsConfiguration.put(StreamsConfig.APPLICATION_ID_CONFIG, "app-" + safeTestName); + streamsConfiguration.put(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, CLUSTER.bootstrapServers()); + streamsConfiguration.put(StreamsConfig.STATE_DIR_CONFIG, TestUtils.tempDirectory().getPath()); + streamsConfiguration.put(StreamsConfig.DEFAULT_KEY_SERDE_CLASS_CONFIG, Serdes.String().getClass()); + streamsConfiguration.put(StreamsConfig.DEFAULT_VALUE_SERDE_CLASS_CONFIG, Serdes.String().getClass()); + streamsConfiguration.put(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG, 100L); + streamsConfiguration.put(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "earliest"); + streamsConfiguration.put(ConsumerConfig.SESSION_TIMEOUT_MS_CONFIG, 10000); + + stringComparator = Comparator.comparing((KeyValue o) -> o.key).thenComparing(o -> o.value); + stringLongComparator = Comparator.comparing((KeyValue o) -> o.key).thenComparingLong(o -> o.value); + inputValues = getInputValues(); + inputValuesKeys = new HashSet<>(); + for (final String sentence : inputValues) { + final String[] words = sentence.split("\\W+"); + Collections.addAll(inputValuesKeys, words); + } + } + + @After + public void shutdown() throws Exception { + if (kafkaStreams != null) { + kafkaStreams.close(ofSeconds(30)); + } + IntegrationTestUtils.purgeLocalStreamsState(streamsConfiguration); + CLUSTER.deleteAllTopicsAndWait(0L); + } + + /** + * Creates a typical word count topology + */ + private KafkaStreams createCountStream(final String inputTopic, + final String outputTopic, + final String windowOutputTopic, + final String storeName, + final String windowStoreName, + final Properties streamsConfiguration) { + final StreamsBuilder builder = new StreamsBuilder(); + final Serde stringSerde = Serdes.String(); + final KStream textLines = builder.stream(inputTopic, Consumed.with(stringSerde, stringSerde)); + + final KGroupedStream groupedByWord = textLines + .flatMapValues((ValueMapper>) value -> Arrays.asList(value.split("\\W+"))) + .groupBy(MockMapper.selectValueMapper()); + + // Create a State Store for the all time word count + groupedByWord + .count(Materialized.as(storeName + "-" + inputTopic)) + .toStream() + .to(outputTopic, Produced.with(Serdes.String(), Serdes.Long())); + + // Create a Windowed State Store that contains the word count for every 1 minute + groupedByWord + .windowedBy(TimeWindows.of(ofMillis(WINDOW_SIZE))) + .count(Materialized.as(windowStoreName + "-" + inputTopic)) + .toStream((key, value) -> key.key()) + .to(windowOutputTopic, Produced.with(Serdes.String(), Serdes.Long())); + + return new KafkaStreams(builder.build(), streamsConfiguration); + } + + + private void verifyOffsetLagFetch(final List streamsList, + final Set stores, + final List partitionsPerStreamsInstance) { + for (int i = 0; i < streamsList.size(); i++) { + final Map> localLags = 
streamsList.get(i).allLocalStorePartitionLags(); + final int expectedPartitions = partitionsPerStreamsInstance.get(i); + assertThat(localLags.values().stream().mapToInt(Map::size).sum(), equalTo(expectedPartitions)); + if (expectedPartitions > 0) { + assertThat(localLags.keySet(), equalTo(stores)); + } + } + } + + private void verifyAllKVKeys(final List streamsList, + final KafkaStreams streams, + final KafkaStreamsTest.StateListenerStub stateListener, + final Set keys, + final String storeName, + final long timeout, + final boolean pickInstanceByPort) throws Exception { + retryOnExceptionWithTimeout(timeout, () -> { + final List noMetadataKeys = new ArrayList<>(); + final List nullStoreKeys = new ArrayList<>(); + final List nullValueKeys = new ArrayList<>(); + final Map exceptionalKeys = new TreeMap<>(); + final StringSerializer serializer = new StringSerializer(); + + for (final String key: keys) { + try { + final KeyQueryMetadata queryMetadata = streams.queryMetadataForKey(storeName, key, serializer); + if (queryMetadata == null || queryMetadata.equals(KeyQueryMetadata.NOT_AVAILABLE)) { + noMetadataKeys.add(key); + continue; + } + if (!pickInstanceByPort) { + assertThat("Should have standbys to query from", !queryMetadata.standbyHosts().isEmpty()); + } + + final int index = queryMetadata.activeHost().port(); + final KafkaStreams streamsWithKey = pickInstanceByPort ? streamsList.get(index) : streams; + final ReadOnlyKeyValueStore store = + IntegrationTestUtils.getStore(storeName, streamsWithKey, true, keyValueStore()); + if (store == null) { + nullStoreKeys.add(key); + continue; + } + + if (store.get(key) == null) { + nullValueKeys.add(key); + } + } catch (final InvalidStateStoreException e) { + if (stateListener.mapStates.get(KafkaStreams.State.REBALANCING) < 1) { + throw new NoRetryException(new AssertionError( + String.format("Received %s for key %s and expected at least one rebalancing state, but had none", + e.getClass().getName(), key))); + } + } catch (final Exception e) { + exceptionalKeys.put(key, e); + } + } + + assertNoKVKeyFailures(storeName, timeout, noMetadataKeys, nullStoreKeys, nullValueKeys, exceptionalKeys); + }); + } + + private void verifyAllWindowedKeys(final List streamsList, + final KafkaStreams streams, + final KafkaStreamsTest.StateListenerStub stateListenerStub, + final Set keys, + final String storeName, + final Long from, + final Long to, + final long timeout, + final boolean pickInstanceByPort) throws Exception { + retryOnExceptionWithTimeout(timeout, () -> { + final List noMetadataKeys = new ArrayList<>(); + final List nullStoreKeys = new ArrayList<>(); + final List nullValueKeys = new ArrayList<>(); + final Map exceptionalKeys = new TreeMap<>(); + final StringSerializer serializer = new StringSerializer(); + + for (final String key: keys) { + try { + final KeyQueryMetadata queryMetadata = streams.queryMetadataForKey(storeName, key, serializer); + if (queryMetadata == null || queryMetadata.equals(KeyQueryMetadata.NOT_AVAILABLE)) { + noMetadataKeys.add(key); + continue; + } + if (pickInstanceByPort) { + assertThat(queryMetadata.standbyHosts().size(), equalTo(0)); + } else { + assertThat("Should have standbys to query from", !queryMetadata.standbyHosts().isEmpty()); + } + + final int index = queryMetadata.activeHost().port(); + final KafkaStreams streamsWithKey = pickInstanceByPort ? 
streamsList.get(index) : streams; + final ReadOnlyWindowStore store = + IntegrationTestUtils.getStore(storeName, streamsWithKey, true, QueryableStoreTypes.windowStore()); + if (store == null) { + nullStoreKeys.add(key); + continue; + } + + if (store.fetch(key, ofEpochMilli(from), ofEpochMilli(to)) == null) { + nullValueKeys.add(key); + } + } catch (final InvalidStateStoreException e) { + // there must have been at least one rebalance state + if (stateListenerStub.mapStates.get(KafkaStreams.State.REBALANCING) < 1) { + throw new NoRetryException(new AssertionError( + String.format("Received %s for key %s and expected at least one rebalancing state, but had none", + e.getClass().getName(), key))); + } + } catch (final Exception e) { + exceptionalKeys.put(key, e); + } + } + + assertNoKVKeyFailures(storeName, timeout, noMetadataKeys, nullStoreKeys, nullValueKeys, exceptionalKeys); + }); + } + + private void assertNoKVKeyFailures(final String storeName, + final long timeout, + final List noMetadataKeys, + final List nullStoreKeys, + final List nullValueKeys, + final Map exceptionalKeys) throws IOException { + final StringBuilder reason = new StringBuilder(); + reason.append(String.format("Not all keys are available for store %s in %d ms", storeName, timeout)); + if (!noMetadataKeys.isEmpty()) { + reason.append("\n * No metadata is available for these keys: ").append(noMetadataKeys); + } + if (!nullStoreKeys.isEmpty()) { + reason.append("\n * No store is available for these keys: ").append(nullStoreKeys); + } + if (!nullValueKeys.isEmpty()) { + reason.append("\n * No value is available for these keys: ").append(nullValueKeys); + } + if (!exceptionalKeys.isEmpty()) { + reason.append("\n * Exceptions were raised for the following keys: "); + for (final Entry entry : exceptionalKeys.entrySet()) { + reason.append(String.format("\n %s:", entry.getKey())); + + final ByteArrayOutputStream baos = new ByteArrayOutputStream(); + final Exception exception = entry.getValue(); + + exception.printStackTrace(new PrintStream(baos)); + try (final BufferedReader reader = new BufferedReader(new StringReader(baos.toString()))) { + String line = reader.readLine(); + while (line != null) { + reason.append("\n ").append(line); + line = reader.readLine(); + } + } + } + } + + assertThat(reason.toString(), + noMetadataKeys.isEmpty() && nullStoreKeys.isEmpty() && nullValueKeys.isEmpty() && exceptionalKeys.isEmpty()); + } + + @Test + public void shouldRejectNonExistentStoreName() throws InterruptedException { + final String uniqueTestName = safeUniqueTestName(getClass(), testName); + final String input = uniqueTestName + "-input"; + final String storeName = uniqueTestName + "-input-table"; + + final StreamsBuilder builder = new StreamsBuilder(); + builder.table( + input, + Materialized + .>as(storeName) + .withKeySerde(Serdes.String()) + .withValueSerde(Serdes.String()) + ); + + final Properties properties = mkProperties(mkMap( + mkEntry(StreamsConfig.APPLICATION_ID_CONFIG, safeUniqueTestName(getClass(), testName)), + mkEntry(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, CLUSTER.bootstrapServers()) + )); + + CLUSTER.createTopic(input); + + try (final KafkaStreams streams = getRunningStreams(properties, builder, true)) { + final ReadOnlyKeyValueStore store = + streams.store(fromNameAndType(storeName, keyValueStore())); + assertThat(store, Matchers.notNullValue()); + + final UnknownStateStoreException exception = assertThrows( + UnknownStateStoreException.class, + () -> streams.store(fromNameAndType("no-table", keyValueStore())) + 
); + assertThat( + exception.getMessage(), + is("Cannot get state store no-table because no such store is registered in the topology.") + ); + } + } + + @Test + public void shouldRejectWronglyTypedStore() throws InterruptedException { + final String uniqueTestName = safeUniqueTestName(getClass(), testName); + final String input = uniqueTestName + "-input"; + final String storeName = uniqueTestName + "-input-table"; + + final StreamsBuilder builder = new StreamsBuilder(); + builder.table( + input, + Materialized + .>as(storeName) + .withKeySerde(Serdes.String()) + .withValueSerde(Serdes.String()) + ); + + CLUSTER.createTopic(input); + + final Properties properties = mkProperties(mkMap( + mkEntry(StreamsConfig.APPLICATION_ID_CONFIG, uniqueTestName + "-app"), + mkEntry(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, CLUSTER.bootstrapServers()) + )); + + try (final KafkaStreams streams = getRunningStreams(properties, builder, true)) { + final ReadOnlyKeyValueStore store = + streams.store(fromNameAndType(storeName, keyValueStore())); + assertThat(store, Matchers.notNullValue()); + + // Note that to check the type we actually need a store reference, + // so we can't check when you get the IQ store, only when you + // try to use it. Presumably, this could be improved. + final ReadOnlySessionStore sessionStore = + streams.store(fromNameAndType(storeName, sessionStore())); + final InvalidStateStoreException exception = assertThrows( + InvalidStateStoreException.class, + () -> sessionStore.fetch("a") + ); + assertThat( + exception.getMessage(), + is( + "Cannot get state store " + storeName + " because the queryable store type" + + " [class org.apache.kafka.streams.state.QueryableStoreTypes$SessionStoreType]" + + " does not accept the actual store type" + + " [class org.apache.kafka.streams.state.internals.MeteredTimestampedKeyValueStore]." 
+ ) + ); + } + } + + @Test + public void shouldBeAbleToQueryDuringRebalance() throws Exception { + final int numThreads = STREAM_TWO_PARTITIONS; + final List streamsList = new ArrayList<>(numThreads); + final List listeners = new ArrayList<>(numThreads); + + final ProducerRunnable producerRunnable = new ProducerRunnable(streamThree, inputValues, 1); + producerRunnable.run(); + + // create stream threads + final String storeName = "word-count-store"; + final String windowStoreName = "windowed-word-count-store"; + for (int i = 0; i < numThreads; i++) { + final Properties props = (Properties) streamsConfiguration.clone(); + props.put(StreamsConfig.STATE_DIR_CONFIG, TestUtils.tempDirectory("shouldBeAbleToQueryDuringRebalance-" + i).getPath()); + props.put(StreamsConfig.APPLICATION_SERVER_CONFIG, "localhost:" + i); + props.put(StreamsConfig.CLIENT_ID_CONFIG, "instance-" + i); + final KafkaStreams streams = + createCountStream(streamThree, outputTopicThree, outputTopicConcurrentWindowed, storeName, windowStoreName, props); + final KafkaStreamsTest.StateListenerStub listener = new KafkaStreamsTest.StateListenerStub(); + streams.setStateListener(listener); + listeners.add(listener); + streamsList.add(streams); + } + startApplicationAndWaitUntilRunning(streamsList, Duration.ofSeconds(60)); + + final Set stores = mkSet(storeName + "-" + streamThree, windowStoreName + "-" + streamThree); + verifyOffsetLagFetch(streamsList, stores, Arrays.asList(4, 4)); + + try { + waitUntilAtLeastNumRecordProcessed(outputTopicThree, 1); + + for (int i = 0; i < streamsList.size(); i++) { + verifyAllKVKeys( + streamsList, + streamsList.get(i), + listeners.get(i), + inputValuesKeys, + storeName + "-" + streamThree, + DEFAULT_TIMEOUT_MS, + true); + verifyAllWindowedKeys( + streamsList, + streamsList.get(i), + listeners.get(i), + inputValuesKeys, + windowStoreName + "-" + streamThree, + 0L, + WINDOW_SIZE, + DEFAULT_TIMEOUT_MS, + true); + } + verifyOffsetLagFetch(streamsList, stores, Arrays.asList(4, 4)); + + // kill N-1 threads + for (int i = 1; i < streamsList.size(); i++) { + final Duration closeTimeout = Duration.ofSeconds(60); + assertThat(String.format("Streams instance %s did not close in %d ms", i, closeTimeout.toMillis()), + streamsList.get(i).close(closeTimeout)); + } + + waitForApplicationState(streamsList.subList(1, numThreads), State.NOT_RUNNING, Duration.ofSeconds(60)); + verifyOffsetLagFetch(streamsList, stores, Arrays.asList(4, 0)); + + // It's not enough to assert that the first instance is RUNNING because it is possible + // for the above checks to succeed while the instance is in a REBALANCING state. + waitForApplicationState(streamsList.subList(0, 1), State.RUNNING, Duration.ofSeconds(60)); + verifyOffsetLagFetch(streamsList, stores, Arrays.asList(4, 0)); + + // Even though the closed instance(s) are now in NOT_RUNNING there is no guarantee that + // the running instance is aware of this, so we must run our follow up queries with + // enough time for the shutdown to be detected. 
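+            // The per-key lookups below route requests by encoding the instance index in the
+            // application server port ("localhost:" + i above): queryMetadataForKey returns the
+            // active host for a key, and its port is used to pick the owning instance. A sketch of
+            // the pattern used inside verifyAllKVKeys/verifyAllWindowedKeys:
+            //   final KeyQueryMetadata md = streams.queryMetadataForKey(storeName, key, new StringSerializer());
+            //   final KafkaStreams owner = streamsList.get(md.activeHost().port());
+            // Transient InvalidStateStoreExceptions thrown while the remaining instance rebalances
+            // are tolerated there via retryOnExceptionWithTimeout.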
+ + // query from the remaining thread + verifyAllKVKeys( + streamsList, + streamsList.get(0), + listeners.get(0), + inputValuesKeys, + storeName + "-" + streamThree, + DEFAULT_TIMEOUT_MS, + true); + verifyAllWindowedKeys( + streamsList, + streamsList.get(0), + listeners.get(0), + inputValuesKeys, + windowStoreName + "-" + streamThree, + 0L, + WINDOW_SIZE, + DEFAULT_TIMEOUT_MS, + true); + retryOnExceptionWithTimeout(DEFAULT_TIMEOUT_MS, () -> verifyOffsetLagFetch(streamsList, stores, Arrays.asList(8, 0))); + } finally { + for (final KafkaStreams streams : streamsList) { + streams.close(); + } + } + } + + @Test + public void shouldBeAbleQueryStandbyStateDuringRebalance() throws Exception { + final int numThreads = STREAM_TWO_PARTITIONS; + final List streamsList = new ArrayList<>(numThreads); + final List listeners = new ArrayList<>(numThreads); + + final ProducerRunnable producerRunnable = new ProducerRunnable(streamThree, inputValues, 1); + producerRunnable.run(); + + // create stream threads + final String storeName = "word-count-store"; + final String windowStoreName = "windowed-word-count-store"; + final Set stores = mkSet(storeName + "-" + streamThree, windowStoreName + "-" + streamThree); + for (int i = 0; i < numThreads; i++) { + final Properties props = (Properties) streamsConfiguration.clone(); + props.put(StreamsConfig.APPLICATION_SERVER_CONFIG, "localhost:" + i); + props.put(StreamsConfig.CLIENT_ID_CONFIG, "instance-" + i); + props.put(StreamsConfig.NUM_STANDBY_REPLICAS_CONFIG, 1); + props.put(StreamsConfig.STATE_DIR_CONFIG, TestUtils.tempDirectory("shouldBeAbleQueryStandbyStateDuringRebalance-" + i).getPath()); + final KafkaStreams streams = + createCountStream(streamThree, outputTopicThree, outputTopicConcurrentWindowed, storeName, windowStoreName, props); + final KafkaStreamsTest.StateListenerStub listener = new KafkaStreamsTest.StateListenerStub(); + streams.setStateListener(listener); + listeners.add(listener); + streamsList.add(streams); + } + startApplicationAndWaitUntilRunning(streamsList, ofSeconds(60)); + verifyOffsetLagFetch(streamsList, stores, Arrays.asList(8, 8)); + + try { + waitUntilAtLeastNumRecordProcessed(outputTopicThree, 1); + + // Ensure each thread can serve all keys by itself; i.e standby replication works. 
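+            // With NUM_STANDBY_REPLICAS_CONFIG = 1 and two instances, each instance should hold a
+            // copy (active or standby) of all four partitions of both stores, which is why
+            // verifyOffsetLagFetch expects 8 store partitions per instance (4 partitions x 2 stores).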
+ for (int i = 0; i < streamsList.size(); i++) { + verifyAllKVKeys( + streamsList, + streamsList.get(i), + listeners.get(i), + inputValuesKeys, + storeName + "-" + streamThree, + DEFAULT_TIMEOUT_MS, + false); + verifyAllWindowedKeys( + streamsList, + streamsList.get(i), + listeners.get(i), + inputValuesKeys, + windowStoreName + "-" + streamThree, + 0L, + WINDOW_SIZE, + DEFAULT_TIMEOUT_MS, + false); + } + verifyOffsetLagFetch(streamsList, stores, Arrays.asList(8, 8)); + + // kill N-1 threads + for (int i = 1; i < streamsList.size(); i++) { + final Duration closeTimeout = Duration.ofSeconds(60); + assertThat(String.format("Streams instance %s did not close in %d ms", i, closeTimeout.toMillis()), + streamsList.get(i).close(closeTimeout)); + } + + waitForApplicationState(streamsList.subList(1, numThreads), State.NOT_RUNNING, Duration.ofSeconds(60)); + verifyOffsetLagFetch(streamsList, stores, Arrays.asList(8, 0)); + + // Now, confirm that all the keys are still queryable on the remaining thread, regardless of the state + verifyAllKVKeys( + streamsList, + streamsList.get(0), + listeners.get(0), + inputValuesKeys, + storeName + "-" + streamThree, + DEFAULT_TIMEOUT_MS, + false); + verifyAllWindowedKeys( + streamsList, + streamsList.get(0), + listeners.get(0), + inputValuesKeys, + windowStoreName + "-" + streamThree, + 0L, + WINDOW_SIZE, + DEFAULT_TIMEOUT_MS, + false); + + retryOnExceptionWithTimeout(DEFAULT_TIMEOUT_MS, () -> verifyOffsetLagFetch(streamsList, stores, Arrays.asList(8, 0))); + } finally { + for (final KafkaStreams streams : streamsList) { + streams.close(); + } + } + } + + @Test + public void shouldBeAbleToQueryStateWithZeroSizedCache() throws Exception { + verifyCanQueryState(0); + } + + @Test + public void shouldBeAbleToQueryStateWithNonZeroSizedCache() throws Exception { + verifyCanQueryState(10 * 1024 * 1024); + } + + @Test + public void shouldBeAbleToQueryFilterState() throws Exception { + streamsConfiguration.put(StreamsConfig.DEFAULT_KEY_SERDE_CLASS_CONFIG, Serdes.String().getClass()); + streamsConfiguration.put(StreamsConfig.DEFAULT_VALUE_SERDE_CLASS_CONFIG, Serdes.Long().getClass()); + final StreamsBuilder builder = new StreamsBuilder(); + final String[] keys = {"hello", "goodbye", "welcome", "go", "kafka"}; + final Set> batch1 = new HashSet<>( + Arrays.asList( + new KeyValue<>(keys[0], 1L), + new KeyValue<>(keys[1], 1L), + new KeyValue<>(keys[2], 3L), + new KeyValue<>(keys[3], 5L), + new KeyValue<>(keys[4], 2L)) + ); + final Set> expectedBatch1 = + new HashSet<>(Collections.singleton(new KeyValue<>(keys[4], 2L))); + + IntegrationTestUtils.produceKeyValuesSynchronously( + streamOne, + batch1, + TestUtils.producerConfig( + CLUSTER.bootstrapServers(), + StringSerializer.class, + LongSerializer.class, + new Properties()), + mockTime); + final Predicate filterPredicate = (key, value) -> key.contains("kafka"); + final KTable t1 = builder.table(streamOne); + final KTable t2 = t1.filter(filterPredicate, Materialized.as("queryFilter")); + t1.filterNot(filterPredicate, Materialized.as("queryFilterNot")); + t2.toStream().to(outputTopic); + + kafkaStreams = new KafkaStreams(builder.build(), streamsConfiguration); + startKafkaStreamsAndWaitForRunningState(kafkaStreams); + + waitUntilAtLeastNumRecordProcessed(outputTopic, 1); + + final ReadOnlyKeyValueStore myFilterStore = + IntegrationTestUtils.getStore("queryFilter", kafkaStreams, keyValueStore()); + + final ReadOnlyKeyValueStore myFilterNotStore = + IntegrationTestUtils.getStore("queryFilterNot", kafkaStreams, keyValueStore()); + + 
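+        // "queryFilter" should contain only the entries matching the predicate (keys containing
+        // "kafka"), while "queryFilterNot" holds the complement; keys dropped by a filter are not
+        // retained in the materialized store, so the checks below also assert that the
+        // filtered-out keys read back as null.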
for (final KeyValue expectedEntry : expectedBatch1) { + TestUtils.waitForCondition(() -> expectedEntry.value.equals(myFilterStore.get(expectedEntry.key)), + "Cannot get expected result"); + } + for (final KeyValue batchEntry : batch1) { + if (!expectedBatch1.contains(batchEntry)) { + TestUtils.waitForCondition(() -> myFilterStore.get(batchEntry.key) == null, + "Cannot get null result"); + } + } + + for (final KeyValue expectedEntry : expectedBatch1) { + TestUtils.waitForCondition(() -> myFilterNotStore.get(expectedEntry.key) == null, + "Cannot get null result"); + } + for (final KeyValue batchEntry : batch1) { + if (!expectedBatch1.contains(batchEntry)) { + TestUtils.waitForCondition(() -> batchEntry.value.equals(myFilterNotStore.get(batchEntry.key)), + "Cannot get expected result"); + } + } + } + + @Test + public void shouldBeAbleToQueryMapValuesState() throws Exception { + streamsConfiguration.put(StreamsConfig.DEFAULT_KEY_SERDE_CLASS_CONFIG, Serdes.String().getClass()); + streamsConfiguration.put(StreamsConfig.DEFAULT_VALUE_SERDE_CLASS_CONFIG, Serdes.String().getClass()); + final StreamsBuilder builder = new StreamsBuilder(); + final String[] keys = {"hello", "goodbye", "welcome", "go", "kafka"}; + final Set> batch1 = new HashSet<>( + Arrays.asList( + new KeyValue<>(keys[0], "1"), + new KeyValue<>(keys[1], "1"), + new KeyValue<>(keys[2], "3"), + new KeyValue<>(keys[3], "5"), + new KeyValue<>(keys[4], "2")) + ); + + IntegrationTestUtils.produceKeyValuesSynchronously( + streamOne, + batch1, + TestUtils.producerConfig( + CLUSTER.bootstrapServers(), + StringSerializer.class, + StringSerializer.class, + new Properties()), + mockTime); + + final KTable t1 = builder.table(streamOne); + t1 + .mapValues( + (ValueMapper) Long::valueOf, + Materialized.>as("queryMapValues").withValueSerde(Serdes.Long())) + .toStream() + .to(outputTopic, Produced.with(Serdes.String(), Serdes.Long())); + + kafkaStreams = new KafkaStreams(builder.build(), streamsConfiguration); + startKafkaStreamsAndWaitForRunningState(kafkaStreams); + + waitUntilAtLeastNumRecordProcessed(outputTopic, 5); + + final ReadOnlyKeyValueStore myMapStore = + IntegrationTestUtils.getStore("queryMapValues", kafkaStreams, keyValueStore()); + + for (final KeyValue batchEntry : batch1) { + assertEquals(Long.valueOf(batchEntry.value), myMapStore.get(batchEntry.key)); + } + + try (final KeyValueIterator range = myMapStore.range("hello", "kafka")) { + while (range.hasNext()) { + System.out.println(range.next()); + } + } + } + + @Test + public void shouldBeAbleToQueryKeysWithGivenPrefix() throws Exception { + streamsConfiguration.put(StreamsConfig.DEFAULT_KEY_SERDE_CLASS_CONFIG, Serdes.String().getClass()); + streamsConfiguration.put(StreamsConfig.DEFAULT_VALUE_SERDE_CLASS_CONFIG, Serdes.String().getClass()); + final StreamsBuilder builder = new StreamsBuilder(); + final String[] keys = {"hello", "goodbye", "welcome", "go", "kafka"}; + final Set> batch1 = new HashSet<>( + Arrays.asList( + new KeyValue<>(keys[0], "1"), + new KeyValue<>(keys[1], "1"), + new KeyValue<>(keys[2], "3"), + new KeyValue<>(keys[3], "5"), + new KeyValue<>(keys[4], "2")) + ); + + final List> expectedPrefixScanResult = Arrays.asList( + new KeyValue<>(keys[3], 5L), + new KeyValue<>(keys[1], 1L) + ); + + IntegrationTestUtils.produceKeyValuesSynchronously( + streamOne, + batch1, + TestUtils.producerConfig( + CLUSTER.bootstrapServers(), + StringSerializer.class, + StringSerializer.class, + new Properties()), + mockTime); + + final KTable t1 = builder.table(streamOne); + t1 + 
.mapValues( + (ValueMapper) Long::valueOf, + Materialized.>as("queryMapValues").withValueSerde(Serdes.Long())) + .toStream() + .to(outputTopic, Produced.with(Serdes.String(), Serdes.Long())); + + kafkaStreams = new KafkaStreams(builder.build(), streamsConfiguration); + startKafkaStreamsAndWaitForRunningState(kafkaStreams); + + waitUntilAtLeastNumRecordProcessed(outputTopic, 5); + + final ReadOnlyKeyValueStore myMapStore = + IntegrationTestUtils.getStore("queryMapValues", kafkaStreams, keyValueStore()); + + int index = 0; + try (final KeyValueIterator range = myMapStore.prefixScan("go", Serdes.String().serializer())) { + while (range.hasNext()) { + assertEquals(expectedPrefixScanResult.get(index++), range.next()); + } + } + } + + @Test + public void shouldBeAbleToQueryMapValuesAfterFilterState() throws Exception { + streamsConfiguration.put(StreamsConfig.DEFAULT_KEY_SERDE_CLASS_CONFIG, Serdes.String().getClass()); + streamsConfiguration.put(StreamsConfig.DEFAULT_VALUE_SERDE_CLASS_CONFIG, Serdes.String().getClass()); + final StreamsBuilder builder = new StreamsBuilder(); + final String[] keys = {"hello", "goodbye", "welcome", "go", "kafka"}; + final Set> batch1 = new HashSet<>( + Arrays.asList( + new KeyValue<>(keys[0], "1"), + new KeyValue<>(keys[1], "1"), + new KeyValue<>(keys[2], "3"), + new KeyValue<>(keys[3], "5"), + new KeyValue<>(keys[4], "2")) + ); + final Set> expectedBatch1 = + new HashSet<>(Collections.singleton(new KeyValue<>(keys[4], 2L))); + + IntegrationTestUtils.produceKeyValuesSynchronously( + streamOne, + batch1, + TestUtils.producerConfig( + CLUSTER.bootstrapServers(), + StringSerializer.class, + StringSerializer.class, + new Properties()), + mockTime); + + final Predicate filterPredicate = (key, value) -> key.contains("kafka"); + final KTable t1 = builder.table(streamOne); + final KTable t2 = t1.filter(filterPredicate, Materialized.as("queryFilter")); + final KTable t3 = t2 + .mapValues( + (ValueMapper) Long::valueOf, + Materialized.>as("queryMapValues").withValueSerde(Serdes.Long())); + t3.toStream().to(outputTopic, Produced.with(Serdes.String(), Serdes.Long())); + + kafkaStreams = new KafkaStreams(builder.build(), streamsConfiguration); + startKafkaStreamsAndWaitForRunningState(kafkaStreams); + + waitUntilAtLeastNumRecordProcessed(outputTopic, 1); + + final ReadOnlyKeyValueStore myMapStore = + IntegrationTestUtils.getStore("queryMapValues", kafkaStreams, keyValueStore()); + + for (final KeyValue expectedEntry : expectedBatch1) { + assertEquals(expectedEntry.value, myMapStore.get(expectedEntry.key)); + } + for (final KeyValue batchEntry : batch1) { + final KeyValue batchEntryMapValue = + new KeyValue<>(batchEntry.key, Long.valueOf(batchEntry.value)); + if (!expectedBatch1.contains(batchEntryMapValue)) { + assertNull(myMapStore.get(batchEntry.key)); + } + } + } + + private void verifyCanQueryState(final int cacheSizeBytes) throws Exception { + streamsConfiguration.put(StreamsConfig.CACHE_MAX_BYTES_BUFFERING_CONFIG, cacheSizeBytes); + final StreamsBuilder builder = new StreamsBuilder(); + final String[] keys = {"hello", "goodbye", "welcome", "go", "kafka"}; + + final Set> batch1 = new TreeSet<>(stringComparator); + batch1.addAll(Arrays.asList( + new KeyValue<>(keys[0], "hello"), + new KeyValue<>(keys[1], "goodbye"), + new KeyValue<>(keys[2], "welcome"), + new KeyValue<>(keys[3], "go"), + new KeyValue<>(keys[4], "kafka"))); + + final Set> expectedCount = new TreeSet<>(stringLongComparator); + for (final String key : keys) { + expectedCount.add(new KeyValue<>(key, 1L)); + } 
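+        // The same expectedCount set is reused for both the key-value and the window store below:
+        // WINDOW_SIZE spans two days, so every record falls into a single window and the windowed
+        // counts equal the plain counts.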
+ + IntegrationTestUtils.produceKeyValuesSynchronously( + streamOne, + batch1, + TestUtils.producerConfig( + CLUSTER.bootstrapServers(), + StringSerializer.class, + StringSerializer.class, + new Properties()), + mockTime); + + final KStream s1 = builder.stream(streamOne); + + // Non Windowed + final String storeName = "my-count"; + s1.groupByKey() + .count(Materialized.as(storeName)) + .toStream() + .to(outputTopic, Produced.with(Serdes.String(), Serdes.Long())); + + final String windowStoreName = "windowed-count"; + s1.groupByKey() + .windowedBy(TimeWindows.of(ofMillis(WINDOW_SIZE))) + .count(Materialized.as(windowStoreName)); + kafkaStreams = new KafkaStreams(builder.build(), streamsConfiguration); + startKafkaStreamsAndWaitForRunningState(kafkaStreams); + + waitUntilAtLeastNumRecordProcessed(outputTopic, 1); + + final ReadOnlyKeyValueStore myCount = + IntegrationTestUtils.getStore(storeName, kafkaStreams, keyValueStore()); + + final ReadOnlyWindowStore windowStore = + IntegrationTestUtils.getStore(windowStoreName, kafkaStreams, QueryableStoreTypes.windowStore()); + + verifyCanGetByKey(keys, + expectedCount, + expectedCount, + windowStore, + myCount); + + verifyRangeAndAll(expectedCount, myCount); + } + + @Test + public void shouldNotMakeStoreAvailableUntilAllStoresAvailable() throws Exception { + final StreamsBuilder builder = new StreamsBuilder(); + final KStream stream = builder.stream(streamThree); + + final String storeName = "count-by-key"; + stream.groupByKey().count(Materialized.as(storeName)); + kafkaStreams = new KafkaStreams(builder.build(), streamsConfiguration); + startKafkaStreamsAndWaitForRunningState(kafkaStreams); + + final KeyValue hello = KeyValue.pair("hello", "hello"); + IntegrationTestUtils.produceKeyValuesSynchronously( + streamThree, + Arrays.asList(hello, hello, hello, hello, hello, hello, hello, hello), + TestUtils.producerConfig( + CLUSTER.bootstrapServers(), + StringSerializer.class, + StringSerializer.class, + new Properties()), + mockTime); + + final int maxWaitMs = 30000; + + final ReadOnlyKeyValueStore store = + IntegrationTestUtils.getStore(storeName, kafkaStreams, keyValueStore()); + + TestUtils.waitForCondition( + () -> Long.valueOf(8).equals(store.get("hello")), + maxWaitMs, + "wait for count to be 8"); + + // close stream + kafkaStreams.close(); + + // start again, and since it may take time to restore we wait for it to transit to RUNNING a bit longer + kafkaStreams = new KafkaStreams(builder.build(), streamsConfiguration); + startKafkaStreamsAndWaitForRunningState(kafkaStreams, maxWaitMs); + + // make sure we never get any value other than 8 for hello + TestUtils.waitForCondition( + () -> { + try { + assertEquals(8L, IntegrationTestUtils.getStore(storeName, kafkaStreams, keyValueStore()).get("hello")); + return true; + } catch (final InvalidStateStoreException ise) { + return false; + } + }, + maxWaitMs, + "waiting for store " + storeName); + + } + + @Test + @Deprecated //A single thread should no longer die + public void shouldAllowToQueryAfterThreadDied() throws Exception { + final AtomicBoolean beforeFailure = new AtomicBoolean(true); + final AtomicBoolean failed = new AtomicBoolean(false); + final String storeName = "store"; + + final StreamsBuilder builder = new StreamsBuilder(); + final KStream input = builder.stream(streamOne); + input + .groupByKey() + .reduce((value1, value2) -> { + if (value1.length() > 1) { + if (beforeFailure.compareAndSet(true, false)) { + throw new RuntimeException("Injected test exception"); + } + } + return value1 + 
value2; + }, Materialized.as(storeName)) + .toStream() + .to(outputTopic); + + streamsConfiguration.put(StreamsConfig.NUM_STREAM_THREADS_CONFIG, 2); + kafkaStreams = new KafkaStreams(builder.build(), streamsConfiguration); + kafkaStreams.setUncaughtExceptionHandler((t, e) -> failed.set(true)); + + // since we start with two threads, wait for a bit longer for both of them to transit to running + startKafkaStreamsAndWaitForRunningState(kafkaStreams, 30000); + + IntegrationTestUtils.produceKeyValuesSynchronously( + streamOne, + Arrays.asList( + KeyValue.pair("a", "1"), + KeyValue.pair("a", "2"), + KeyValue.pair("b", "3"), + KeyValue.pair("b", "4")), + TestUtils.producerConfig( + CLUSTER.bootstrapServers(), + StringSerializer.class, + StringSerializer.class, + new Properties()), + mockTime); + + final int maxWaitMs = 30000; + + final ReadOnlyKeyValueStore store = + IntegrationTestUtils.getStore(storeName, kafkaStreams, keyValueStore()); + + TestUtils.waitForCondition( + () -> "12".equals(store.get("a")) && "34".equals(store.get("b")), + maxWaitMs, + "wait for agg to be and "); + + IntegrationTestUtils.produceKeyValuesSynchronously( + streamOne, + Collections.singleton(KeyValue.pair("a", "5")), + TestUtils.producerConfig( + CLUSTER.bootstrapServers(), + StringSerializer.class, + StringSerializer.class, + new Properties()), + mockTime); + + TestUtils.waitForCondition( + failed::get, + maxWaitMs, + "wait for thread to fail"); + + final ReadOnlyKeyValueStore store2 = + IntegrationTestUtils.getStore(storeName, kafkaStreams, keyValueStore()); + + try { + TestUtils.waitForCondition( + () -> ("125".equals(store2.get("a")) + || "1225".equals(store2.get("a")) + || "12125".equals(store2.get("a"))) + && + ("34".equals(store2.get("b")) + || "344".equals(store2.get("b")) + || "3434".equals(store2.get("b"))), + maxWaitMs, + "wait for agg to be |||| and ||||"); + } catch (final Throwable t) { + throw new RuntimeException("Store content is a: " + store2.get("a") + "; b: " + store2.get("b"), t); + } + } + + private void verifyRangeAndAll(final Set> expectedCount, + final ReadOnlyKeyValueStore myCount) { + final Set> countRangeResults = new TreeSet<>(stringLongComparator); + final Set> countAllResults = new TreeSet<>(stringLongComparator); + final Set> expectedRangeResults = new TreeSet<>(stringLongComparator); + + expectedRangeResults.addAll(Arrays.asList( + new KeyValue<>("hello", 1L), + new KeyValue<>("go", 1L), + new KeyValue<>("goodbye", 1L), + new KeyValue<>("kafka", 1L) + )); + + try (final KeyValueIterator range = myCount.range("go", "kafka")) { + while (range.hasNext()) { + countRangeResults.add(range.next()); + } + } + + try (final KeyValueIterator all = myCount.all()) { + while (all.hasNext()) { + countAllResults.add(all.next()); + } + } + + assertThat(countRangeResults, equalTo(expectedRangeResults)); + assertThat(countAllResults, equalTo(expectedCount)); + } + + private void verifyCanGetByKey(final String[] keys, + final Set> expectedWindowState, + final Set> expectedCount, + final ReadOnlyWindowStore windowStore, + final ReadOnlyKeyValueStore myCount) throws Exception { + final Set> windowState = new TreeSet<>(stringLongComparator); + final Set> countState = new TreeSet<>(stringLongComparator); + + final long timeout = System.currentTimeMillis() + 30000; + while ((windowState.size() < keys.length || + countState.size() < keys.length) && + System.currentTimeMillis() < timeout) { + Thread.sleep(10); + for (final String key : keys) { + windowState.addAll(fetch(windowStore, key)); + final Long value = 
myCount.get(key); + if (value != null) { + countState.add(new KeyValue<>(key, value)); + } + } + } + assertThat(windowState, equalTo(expectedWindowState)); + assertThat(countState, equalTo(expectedCount)); + } + + private void waitUntilAtLeastNumRecordProcessed(final String topic, + final int numRecs) throws Exception { + final long timeout = DEFAULT_TIMEOUT_MS; + final Properties config = new Properties(); + config.setProperty(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, CLUSTER.bootstrapServers()); + config.setProperty(ConsumerConfig.GROUP_ID_CONFIG, "queryable-state-consumer"); + config.setProperty(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "earliest"); + config.setProperty(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, StringDeserializer.class.getName()); + config.setProperty(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, LongDeserializer.class.getName()); + IntegrationTestUtils.waitUntilMinValuesRecordsReceived( + config, + topic, + numRecs, + timeout); + } + + private Set> fetch(final ReadOnlyWindowStore store, + final String key) { + try (final WindowStoreIterator fetch = + store.fetch(key, ofEpochMilli(0), ofEpochMilli(System.currentTimeMillis()))) { + if (fetch.hasNext()) { + final KeyValue next = fetch.next(); + return Collections.singleton(KeyValue.pair(key, next.value)); + } + } + + return Collections.emptySet(); + } + + /** + * A class that periodically produces records in a separate thread + */ + private class ProducerRunnable implements Runnable { + private final String topic; + private final List inputValues; + private final int numIterations; + private int currIteration = 0; + + ProducerRunnable(final String topic, + final List inputValues, + final int numIterations) { + this.topic = topic; + this.inputValues = inputValues; + this.numIterations = numIterations; + } + + private synchronized void incrementIteration() { + currIteration++; + } + + synchronized int getCurrIteration() { + return currIteration; + } + + @Override + public void run() { + final Properties producerConfig = new Properties(); + producerConfig.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, CLUSTER.bootstrapServers()); + producerConfig.put(ProducerConfig.ACKS_CONFIG, "all"); + producerConfig.put(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG, StringSerializer.class); + producerConfig.put(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG, StringSerializer.class); + + try (final KafkaProducer producer = + new KafkaProducer<>(producerConfig, new StringSerializer(), new StringSerializer())) { + + while (getCurrIteration() < numIterations) { + for (final String value : inputValues) { + producer.send(new ProducerRecord<>(topic, value)); + } + incrementIteration(); + } + } + } + } + +} diff --git a/streams/src/test/java/org/apache/kafka/streams/integration/RangeQueryIntegrationTest.java b/streams/src/test/java/org/apache/kafka/streams/integration/RangeQueryIntegrationTest.java new file mode 100644 index 0000000..aabf6e2 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/integration/RangeQueryIntegrationTest.java @@ -0,0 +1,303 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.streams.integration; + +import org.apache.kafka.clients.consumer.ConsumerConfig; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.serialization.StringSerializer; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.streams.KafkaStreams; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.integration.utils.EmbeddedKafkaCluster; +import org.apache.kafka.streams.integration.utils.IntegrationTestUtils; +import org.apache.kafka.streams.kstream.KTable; +import org.apache.kafka.streams.kstream.Materialized; +import org.apache.kafka.streams.state.KeyValueBytesStoreSupplier; +import org.apache.kafka.streams.state.KeyValueIterator; +import org.apache.kafka.streams.state.KeyValueStore; +import org.apache.kafka.streams.state.QueryableStoreTypes; +import org.apache.kafka.streams.state.ReadOnlyKeyValueStore; +import org.apache.kafka.streams.state.Stores; +import org.apache.kafka.test.IntegrationTest; +import org.apache.kafka.test.TestUtils; +import org.junit.After; +import org.junit.AfterClass; +import org.junit.Assert; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Rule; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.rules.TestName; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import java.io.IOException; +import java.time.Duration; +import java.util.Arrays; +import java.util.Collection; +import java.util.HashMap; +import java.util.Iterator; +import java.util.LinkedList; +import java.util.List; +import java.util.Properties; +import java.util.function.Predicate; +import java.util.function.Supplier; + +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.MatcherAssert.assertThat; + + +@SuppressWarnings("deprecation") // Old PAPI. Needs to be migrated. 
+@RunWith(Parameterized.class) +@Category({IntegrationTest.class}) +public class RangeQueryIntegrationTest { + public static final EmbeddedKafkaCluster CLUSTER = new EmbeddedKafkaCluster(1); + private static final Properties STREAMS_CONFIG = new Properties(); + private static final String APP_ID = "range-query-integration-test"; + private static final Long COMMIT_INTERVAL = 100L; + private static String inputStream; + private static final String TABLE_NAME = "mytable"; + private static final int DATA_SIZE = 5; + + private enum StoreType { InMemory, RocksDB, Timed }; + private StoreType storeType; + private boolean enableLogging; + private boolean enableCaching; + private boolean forward; + private KafkaStreams kafkaStreams; + + private LinkedList> records; + private String low; + private String high; + private String middle; + private String innerLow; + private String innerHigh; + private String innerLowBetween; + private String innerHighBetween; + + public RangeQueryIntegrationTest(final StoreType storeType, final boolean enableLogging, final boolean enableCaching, final boolean forward) { + this.storeType = storeType; + this.enableLogging = enableLogging; + this.enableCaching = enableCaching; + this.forward = forward; + + records = new LinkedList<>(); + final int m = DATA_SIZE / 2; + for (int i = 0; i < DATA_SIZE; i++) { + final String key = "key-" + i * 2; + final String value = "val-" + i * 2; + records.add(new KeyValue<>(key, value)); + high = key; + if (low == null) { + low = key; + } + if (i == m) { + middle = key; + } + if (i == 1) { + innerLow = key; + final int index = i * 2 - 1; + innerLowBetween = "key-" + index; + } + if (i == DATA_SIZE - 2) { + innerHigh = key; + final int index = i * 2 + 1; + innerHighBetween = "key-" + index; + } + } + Assert.assertNotNull(low); + Assert.assertNotNull(high); + Assert.assertNotNull(middle); + Assert.assertNotNull(innerLow); + Assert.assertNotNull(innerHigh); + Assert.assertNotNull(innerLowBetween); + Assert.assertNotNull(innerHighBetween); + } + + @Rule + public TestName testName = new TestName(); + + @Parameterized.Parameters(name = "storeType={0}, enableLogging={1}, enableCaching={2}, forward={3}") + public static Collection data() { + final List types = Arrays.asList(StoreType.InMemory, StoreType.RocksDB, StoreType.Timed); + final List logging = Arrays.asList(true, false); + final List caching = Arrays.asList(true, false); + final List forward = Arrays.asList(true, false); + return buildParameters(types, logging, caching, forward); + } + + @BeforeClass + public static void startCluster() throws IOException { + CLUSTER.start(); + STREAMS_CONFIG.put(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "earliest"); + STREAMS_CONFIG.put(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, CLUSTER.bootstrapServers()); + STREAMS_CONFIG.put(StreamsConfig.DEFAULT_KEY_SERDE_CLASS_CONFIG, Serdes.String().getClass()); + STREAMS_CONFIG.put(StreamsConfig.DEFAULT_VALUE_SERDE_CLASS_CONFIG, Serdes.String().getClass()); + STREAMS_CONFIG.put(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG, COMMIT_INTERVAL); + STREAMS_CONFIG.put(StreamsConfig.APPLICATION_ID_CONFIG, APP_ID); + STREAMS_CONFIG.put(StreamsConfig.STATE_DIR_CONFIG, TestUtils.tempDirectory().getPath()); + } + + @AfterClass + public static void closeCluster() { + CLUSTER.stop(); + } + + @Before + public void setupTopics() throws Exception { + inputStream = "input-topic"; + CLUSTER.createTopic(inputStream); + } + + @After + public void cleanup() throws InterruptedException { + CLUSTER.deleteAllTopicsAndWait(120000); + } + + @Test + 
public void testStoreConfig() throws Exception { + final StreamsBuilder builder = new StreamsBuilder(); + final Materialized> stateStoreConfig = getStoreConfig(storeType, TABLE_NAME, enableLogging, enableCaching); + final KTable table = builder.table(inputStream, stateStoreConfig); + + try (final KafkaStreams kafkaStreams = new KafkaStreams(builder.build(), STREAMS_CONFIG)) { + final List kafkaStreamsList = Arrays.asList(kafkaStreams); + + IntegrationTestUtils.startApplicationAndWaitUntilRunning(kafkaStreamsList, Duration.ofSeconds(60)); + + writeInputData(); + + final ReadOnlyKeyValueStore stateStore = IntegrationTestUtils.getStore(1000_000L, TABLE_NAME, kafkaStreams, QueryableStoreTypes.keyValueStore()); + + // wait for the store to populate + TestUtils.waitForCondition(() -> stateStore.get(high) != null, "The store never finished populating"); + + //query the state store + try (final KeyValueIterator scanIterator = forward ? stateStore.range(null, null) : stateStore.reverseRange(null, null)) { + final Iterator> dataIterator = forward ? records.iterator() : records.descendingIterator(); + TestUtils.checkEquals(scanIterator, dataIterator); + } + + try (final KeyValueIterator allIterator = forward ? stateStore.all() : stateStore.reverseAll()) { + final Iterator> dataIterator = forward ? records.iterator() : records.descendingIterator(); + TestUtils.checkEquals(allIterator, dataIterator); + } + + testRange("range", stateStore, innerLow, innerHigh, forward); + testRange("until", stateStore, null, middle, forward); + testRange("from", stateStore, middle, null, forward); + + testRange("untilBetween", stateStore, null, innerHighBetween, forward); + testRange("fromBetween", stateStore, innerLowBetween, null, forward); + } + } + + private void writeInputData() { + IntegrationTestUtils.produceKeyValuesSynchronously( + inputStream, + records, + TestUtils.producerConfig(CLUSTER.bootstrapServers(), StringSerializer.class, StringSerializer.class), + CLUSTER.time + ); + } + + private List> filterList(final KeyValueIterator iterator, final String from, final String to) { + final Predicate> pred = new Predicate>() { + @Override + public boolean test(final KeyValue elem) { + if (from != null && elem.key.compareTo(from) < 0) { + return false; + } + if (to != null && elem.key.compareTo(to) > 0) { + return false; + } + return elem != null; + } + }; + + return Utils.toList(iterator, pred); + } + + private void testRange(final String name, final ReadOnlyKeyValueStore store, final String from, final String to, final boolean forward) { + try (final KeyValueIterator resultIterator = forward ? store.range(from, to) : store.reverseRange(from, to); + final KeyValueIterator expectedIterator = forward ? 
store.all() : store.reverseAll();) { + final List> result = Utils.toList(resultIterator); + final List> expected = filterList(expectedIterator, from, to); + assertThat(result, is(expected)); + } + } + + private Materialized> getStoreConfig(final StoreType type, final String name, final boolean cachingEnabled, final boolean loggingEnabled) { + final Supplier createStore = () -> { + if (type == StoreType.InMemory) { + return Stores.inMemoryKeyValueStore(TABLE_NAME); + } else if (type == StoreType.RocksDB) { + return Stores.persistentKeyValueStore(TABLE_NAME); + } else if (type == StoreType.Timed) { + return Stores.persistentTimestampedKeyValueStore(TABLE_NAME); + } else { + return Stores.inMemoryKeyValueStore(TABLE_NAME); + } + }; + + final KeyValueBytesStoreSupplier stateStoreSupplier = createStore.get(); + final Materialized> stateStoreConfig = Materialized + .as(stateStoreSupplier) + .withKeySerde(Serdes.String()) + .withValueSerde(Serdes.String()); + if (cachingEnabled) { + stateStoreConfig.withCachingEnabled(); + } else { + stateStoreConfig.withCachingDisabled(); + } + if (loggingEnabled) { + stateStoreConfig.withLoggingEnabled(new HashMap()); + } else { + stateStoreConfig.withLoggingDisabled(); + } + return stateStoreConfig; + } + + private static Collection buildParameters(final List... argOptions) { + List result = new LinkedList<>(); + result.add(new Object[0]); + + for (final List argOption : argOptions) { + result = times(result, argOption); + } + + return result; + } + + private static List times(final List left, final List right) { + final List result = new LinkedList<>(); + for (final Object[] args : left) { + for (final Object rightElem : right) { + final Object[] resArgs = new Object[args.length + 1]; + System.arraycopy(args, 0, resArgs, 0, args.length); + resArgs[args.length] = rightElem; + result.add(resArgs); + } + } + return result; + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/integration/RegexSourceIntegrationTest.java b/streams/src/test/java/org/apache/kafka/streams/integration/RegexSourceIntegrationTest.java new file mode 100644 index 0000000..f060d41 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/integration/RegexSourceIntegrationTest.java @@ -0,0 +1,499 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.integration; + +import kafka.utils.MockTime; +import org.apache.kafka.clients.consumer.Consumer; +import org.apache.kafka.clients.consumer.ConsumerConfig; +import org.apache.kafka.clients.consumer.ConsumerRebalanceListener; +import org.apache.kafka.clients.consumer.KafkaConsumer; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.serialization.ByteArrayDeserializer; +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.serialization.StringDeserializer; +import org.apache.kafka.common.serialization.StringSerializer; +import org.apache.kafka.streams.KafkaStreams; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.Topology; +import org.apache.kafka.streams.TopologyWrapper; +import org.apache.kafka.streams.errors.StreamsUncaughtExceptionHandler; +import org.apache.kafka.streams.integration.utils.EmbeddedKafkaCluster; +import org.apache.kafka.streams.integration.utils.IntegrationTestUtils; +import org.apache.kafka.streams.kstream.KStream; +import org.apache.kafka.streams.kstream.Produced; +import org.apache.kafka.streams.processor.internals.DefaultKafkaClientSupplier; +import org.apache.kafka.streams.state.KeyValueStore; +import org.apache.kafka.streams.state.StoreBuilder; +import org.apache.kafka.test.IntegrationTest; +import org.apache.kafka.test.MockApiProcessorSupplier; +import org.apache.kafka.test.MockKeyValueStoreBuilder; +import org.apache.kafka.test.StreamsTestUtils; +import org.apache.kafka.test.TestCondition; +import org.apache.kafka.test.TestUtils; +import org.junit.After; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.rules.TestName; + +import java.io.IOException; +import java.time.Duration; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Properties; +import java.util.concurrent.CopyOnWriteArrayList; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.regex.Pattern; + +import static org.apache.kafka.streams.integration.utils.IntegrationTestUtils.startApplicationAndWaitUntilRunning; +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.greaterThan; + +/** + * End-to-end integration test based on using regex and named topics for creating sources, using + * an embedded Kafka cluster. 
+ */ +@Category({IntegrationTest.class}) +public class RegexSourceIntegrationTest { + private static final int NUM_BROKERS = 1; + public static final EmbeddedKafkaCluster CLUSTER = new EmbeddedKafkaCluster(NUM_BROKERS); + + @BeforeClass + public static void startCluster() throws IOException, InterruptedException { + CLUSTER.start(); + CLUSTER.createTopics( + TOPIC_1, + TOPIC_2, + TOPIC_A, + TOPIC_C, + TOPIC_Y, + TOPIC_Z, + FA_TOPIC, + FOO_TOPIC); + CLUSTER.createTopic(PARTITIONED_TOPIC_1, 2, 1); + CLUSTER.createTopic(PARTITIONED_TOPIC_2, 2, 1); + } + + @AfterClass + public static void closeCluster() { + CLUSTER.stop(); + } + + private final MockTime mockTime = CLUSTER.time; + + private static final String TOPIC_1 = "topic-1"; + private static final String TOPIC_2 = "topic-2"; + private static final String TOPIC_A = "topic-A"; + private static final String TOPIC_C = "topic-C"; + private static final String TOPIC_Y = "topic-Y"; + private static final String TOPIC_Z = "topic-Z"; + private static final String FA_TOPIC = "fa"; + private static final String FOO_TOPIC = "foo"; + private static final String PARTITIONED_TOPIC_1 = "partitioned-1"; + private static final String PARTITIONED_TOPIC_2 = "partitioned-2"; + + private static final String STRING_SERDE_CLASSNAME = Serdes.String().getClass().getName(); + private Properties streamsConfiguration; + private static final String STREAM_TASKS_NOT_UPDATED = "Stream tasks not updated"; + private KafkaStreams streams; + private static volatile AtomicInteger topicSuffixGenerator = new AtomicInteger(0); + private String outputTopic; + + @Before + public void setUp() throws InterruptedException { + outputTopic = createTopic(topicSuffixGenerator.incrementAndGet()); + final Properties properties = new Properties(); + properties.put(StreamsConfig.CACHE_MAX_BYTES_BUFFERING_CONFIG, 0); + properties.put(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG, 100L); + properties.put(ConsumerConfig.METADATA_MAX_AGE_CONFIG, "1000"); + properties.put(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "earliest"); + properties.put(StreamsConfig.MAX_TASK_IDLE_MS_CONFIG, 0L); + properties.put(ConsumerConfig.SESSION_TIMEOUT_MS_CONFIG, 10000); + + streamsConfiguration = StreamsTestUtils.getStreamsConfig( + IntegrationTestUtils.safeUniqueTestName(RegexSourceIntegrationTest.class, new TestName()), + CLUSTER.bootstrapServers(), + STRING_SERDE_CLASSNAME, + STRING_SERDE_CLASSNAME, + properties + ); + } + + @After + public void tearDown() throws IOException { + if (streams != null) { + streams.close(); + } + // Remove any state from previous test runs + IntegrationTestUtils.purgeLocalStreamsState(streamsConfiguration); + } + + @Test + public void testRegexMatchesTopicsAWhenCreated() throws Exception { + try { + final Serde stringSerde = Serdes.String(); + + final List expectedFirstAssignment = Collections.singletonList("TEST-TOPIC-1"); + // we compare lists of subscribed topics and hence requiring the order as well; this is guaranteed + // with KIP-429 since we would NOT revoke TEST-TOPIC-1 but only add TEST-TOPIC-2 so the list is always + // in the order of "TEST-TOPIC-1, TEST-TOPIC-2". 
Note if KIP-429 behavior ever changed it may become a flaky test + final List expectedSecondAssignment = Arrays.asList("TEST-TOPIC-1", "TEST-TOPIC-2"); + + CLUSTER.createTopic("TEST-TOPIC-1"); + + final StreamsBuilder builder = new StreamsBuilder(); + + final KStream pattern1Stream = builder.stream(Pattern.compile("TEST-TOPIC-\\d")); + + pattern1Stream.to(outputTopic, Produced.with(stringSerde, stringSerde)); + final List assignedTopics = new CopyOnWriteArrayList<>(); + streams = new KafkaStreams(builder.build(), streamsConfiguration, new DefaultKafkaClientSupplier() { + @Override + public Consumer getConsumer(final Map config) { + return new KafkaConsumer(config, new ByteArrayDeserializer(), new ByteArrayDeserializer()) { + @Override + public void subscribe(final Pattern topics, final ConsumerRebalanceListener listener) { + super.subscribe(topics, new TheConsumerRebalanceListener(assignedTopics, listener)); + } + }; + + } + }); + + streams.start(); + TestUtils.waitForCondition(() -> assignedTopics.equals(expectedFirstAssignment), STREAM_TASKS_NOT_UPDATED); + + CLUSTER.createTopic("TEST-TOPIC-2"); + + TestUtils.waitForCondition(() -> assignedTopics.equals(expectedSecondAssignment), STREAM_TASKS_NOT_UPDATED); + + streams.close(); + } finally { + CLUSTER.deleteTopicsAndWait("TEST-TOPIC-1", "TEST-TOPIC-2"); + } + } + + @Test + public void testRegexRecordsAreProcessedAfterNewTopicCreatedWithMultipleSubtopologies() throws Exception { + final String topic1 = "TEST-TOPIC-1"; + final String topic2 = "TEST-TOPIC-2"; + + try { + CLUSTER.createTopic(topic1); + + final StreamsBuilder builder = new StreamsBuilder(); + final KStream pattern1Stream = builder.stream(Pattern.compile("TEST-TOPIC-\\d")); + final KStream otherStream = builder.stream(Pattern.compile("not-a-match")); + + pattern1Stream + .selectKey((k, v) -> k) + .groupByKey() + .aggregate(() -> "", (k, v, a) -> v) + .toStream().to(outputTopic, Produced.with(Serdes.String(), Serdes.String())); + + final Topology topology = builder.build(); + assertThat(topology.describe().subtopologies().size(), greaterThan(1)); + streams = new KafkaStreams(topology, streamsConfiguration); + + startApplicationAndWaitUntilRunning(Collections.singletonList(streams), Duration.ofSeconds(30)); + + CLUSTER.createTopic(topic2); + + final KeyValue record1 = new KeyValue<>("1", "1"); + final KeyValue record2 = new KeyValue<>("2", "2"); + IntegrationTestUtils.produceKeyValuesSynchronously( + topic1, + Collections.singletonList(record1), + TestUtils.producerConfig(CLUSTER.bootstrapServers(), StringSerializer.class, StringSerializer.class), + CLUSTER.time + ); + IntegrationTestUtils.produceKeyValuesSynchronously( + topic2, + Collections.singletonList(record2), + TestUtils.producerConfig(CLUSTER.bootstrapServers(), StringSerializer.class, StringSerializer.class), + CLUSTER.time + ); + IntegrationTestUtils.waitUntilFinalKeyValueRecordsReceived( + TestUtils.consumerConfig(CLUSTER.bootstrapServers(), StringDeserializer.class, StringDeserializer.class), + outputTopic, + Arrays.asList(record1, record2) + ); + + streams.close(); + } finally { + CLUSTER.deleteTopicsAndWait(topic1, topic2); + } + } + + private String createTopic(final int suffix) throws InterruptedException { + final String outputTopic = "outputTopic_" + suffix; + CLUSTER.createTopic(outputTopic); + return outputTopic; + } + + @Test + public void testRegexMatchesTopicsAWhenDeleted() throws Exception { + final Serde stringSerde = Serdes.String(); + final List expectedFirstAssignment = Arrays.asList("TEST-TOPIC-A", 
"TEST-TOPIC-B"); + final List expectedSecondAssignment = Collections.singletonList("TEST-TOPIC-B"); + final List assignedTopics = new CopyOnWriteArrayList<>(); + + try { + CLUSTER.createTopics("TEST-TOPIC-A", "TEST-TOPIC-B"); + + final StreamsBuilder builder = new StreamsBuilder(); + + final KStream pattern1Stream = builder.stream(Pattern.compile("TEST-TOPIC-[A-Z]")); + + pattern1Stream.to(outputTopic, Produced.with(stringSerde, stringSerde)); + + streams = new KafkaStreams(builder.build(), streamsConfiguration, new DefaultKafkaClientSupplier() { + @Override + public Consumer getConsumer(final Map config) { + return new KafkaConsumer(config, new ByteArrayDeserializer(), new ByteArrayDeserializer()) { + @Override + public void subscribe(final Pattern topics, final ConsumerRebalanceListener listener) { + super.subscribe(topics, new TheConsumerRebalanceListener(assignedTopics, listener)); + } + }; + } + }); + + + streams.start(); + TestUtils.waitForCondition(() -> assignedTopics.equals(expectedFirstAssignment), STREAM_TASKS_NOT_UPDATED); + } finally { + CLUSTER.deleteTopic("TEST-TOPIC-A"); + } + + TestUtils.waitForCondition(() -> assignedTopics.equals(expectedSecondAssignment), STREAM_TASKS_NOT_UPDATED); + } + + @Test + public void shouldAddStateStoreToRegexDefinedSource() throws Exception { + final StoreBuilder> storeBuilder = new MockKeyValueStoreBuilder("testStateStore", false); + final long thirtySecondTimeout = 30 * 1000; + + final TopologyWrapper topology = new TopologyWrapper(); + topology.addSource("ingest", Pattern.compile("topic-\\d+")); + topology.addProcessor("my-processor", new MockApiProcessorSupplier<>(), "ingest"); + topology.addStateStore(storeBuilder, "my-processor"); + + streams = new KafkaStreams(topology, streamsConfiguration); + streams.start(); + + final TestCondition stateStoreNameBoundToSourceTopic = () -> { + final Map> stateStoreToSourceTopic = topology.getInternalBuilder().stateStoreNameToSourceTopics(); + final List topicNamesList = stateStoreToSourceTopic.get("testStateStore"); + return topicNamesList != null && !topicNamesList.isEmpty() && topicNamesList.get(0).equals("topic-1"); + }; + + TestUtils.waitForCondition(stateStoreNameBoundToSourceTopic, thirtySecondTimeout, "Did not find topic: [topic-1] connected to state store: [testStateStore]"); + } + + @Test + public void testShouldReadFromRegexAndNamedTopics() throws Exception { + final String topic1TestMessage = "topic-1 test"; + final String topic2TestMessage = "topic-2 test"; + final String topicATestMessage = "topic-A test"; + final String topicCTestMessage = "topic-C test"; + final String topicYTestMessage = "topic-Y test"; + final String topicZTestMessage = "topic-Z test"; + + + final Serde stringSerde = Serdes.String(); + + final StreamsBuilder builder = new StreamsBuilder(); + + final KStream pattern1Stream = builder.stream(Pattern.compile("topic-\\d")); + final KStream pattern2Stream = builder.stream(Pattern.compile("topic-[A-D]")); + final KStream namedTopicsStream = builder.stream(Arrays.asList(TOPIC_Y, TOPIC_Z)); + + pattern1Stream.to(outputTopic, Produced.with(stringSerde, stringSerde)); + pattern2Stream.to(outputTopic, Produced.with(stringSerde, stringSerde)); + namedTopicsStream.to(outputTopic, Produced.with(stringSerde, stringSerde)); + + streams = new KafkaStreams(builder.build(), streamsConfiguration); + streams.start(); + + final Properties producerConfig = TestUtils.producerConfig(CLUSTER.bootstrapServers(), StringSerializer.class, StringSerializer.class); + + 
IntegrationTestUtils.produceValuesSynchronously(TOPIC_1, Collections.singleton(topic1TestMessage), producerConfig, mockTime); + IntegrationTestUtils.produceValuesSynchronously(TOPIC_2, Collections.singleton(topic2TestMessage), producerConfig, mockTime); + IntegrationTestUtils.produceValuesSynchronously(TOPIC_A, Collections.singleton(topicATestMessage), producerConfig, mockTime); + IntegrationTestUtils.produceValuesSynchronously(TOPIC_C, Collections.singleton(topicCTestMessage), producerConfig, mockTime); + IntegrationTestUtils.produceValuesSynchronously(TOPIC_Y, Collections.singleton(topicYTestMessage), producerConfig, mockTime); + IntegrationTestUtils.produceValuesSynchronously(TOPIC_Z, Collections.singleton(topicZTestMessage), producerConfig, mockTime); + + final Properties consumerConfig = TestUtils.consumerConfig(CLUSTER.bootstrapServers(), StringDeserializer.class, StringDeserializer.class); + + final List expectedReceivedValues = Arrays.asList(topicATestMessage, topic1TestMessage, topic2TestMessage, topicCTestMessage, topicYTestMessage, topicZTestMessage); + final List> receivedKeyValues = IntegrationTestUtils.waitUntilMinKeyValueRecordsReceived(consumerConfig, outputTopic, 6); + final List actualValues = new ArrayList<>(6); + + for (final KeyValue receivedKeyValue : receivedKeyValues) { + actualValues.add(receivedKeyValue.value); + } + + Collections.sort(actualValues); + Collections.sort(expectedReceivedValues); + assertThat(actualValues, equalTo(expectedReceivedValues)); + } + + @Test + public void testMultipleConsumersCanReadFromPartitionedTopic() throws Exception { + KafkaStreams partitionedStreamsLeader = null; + KafkaStreams partitionedStreamsFollower = null; + try { + final Serde stringSerde = Serdes.String(); + final StreamsBuilder builderLeader = new StreamsBuilder(); + final StreamsBuilder builderFollower = new StreamsBuilder(); + final List expectedAssignment = Arrays.asList(PARTITIONED_TOPIC_1, PARTITIONED_TOPIC_2); + + final KStream partitionedStreamLeader = builderLeader.stream(Pattern.compile("partitioned-\\d")); + final KStream partitionedStreamFollower = builderFollower.stream(Pattern.compile("partitioned-\\d")); + + + partitionedStreamLeader.to(outputTopic, Produced.with(stringSerde, stringSerde)); + partitionedStreamFollower.to(outputTopic, Produced.with(stringSerde, stringSerde)); + + final List leaderAssignment = new CopyOnWriteArrayList<>(); + final List followerAssignment = new CopyOnWriteArrayList<>(); + + partitionedStreamsLeader = new KafkaStreams(builderLeader.build(), streamsConfiguration, new DefaultKafkaClientSupplier() { + @Override + public Consumer getConsumer(final Map config) { + return new KafkaConsumer(config, new ByteArrayDeserializer(), new ByteArrayDeserializer()) { + @Override + public void subscribe(final Pattern topics, final ConsumerRebalanceListener listener) { + super.subscribe(topics, new TheConsumerRebalanceListener(leaderAssignment, listener)); + } + }; + + } + }); + partitionedStreamsFollower = new KafkaStreams(builderFollower.build(), streamsConfiguration, new DefaultKafkaClientSupplier() { + @Override + public Consumer getConsumer(final Map config) { + return new KafkaConsumer(config, new ByteArrayDeserializer(), new ByteArrayDeserializer()) { + @Override + public void subscribe(final Pattern topics, final ConsumerRebalanceListener listener) { + super.subscribe(topics, new TheConsumerRebalanceListener(followerAssignment, listener)); + } + }; + + } + }); + + partitionedStreamsLeader.start(); + partitionedStreamsFollower.start(); + 
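+            // both instances subscribe with the same pattern, so after rebalancing each one should report both partitioned topics in its assignment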
TestUtils.waitForCondition(() -> followerAssignment.equals(expectedAssignment) && leaderAssignment.equals(expectedAssignment), "topic assignment not completed"); + } finally { + if (partitionedStreamsLeader != null) { + partitionedStreamsLeader.close(); + } + if (partitionedStreamsFollower != null) { + partitionedStreamsFollower.close(); + } + } + } + + @Test + public void testNoMessagesSentExceptionFromOverlappingPatterns() throws Exception { + final String fMessage = "fMessage"; + final String fooMessage = "fooMessage"; + final Serde stringSerde = Serdes.String(); + final StreamsBuilder builder = new StreamsBuilder(); + + // overlapping patterns here, no messages should be sent as TopologyException + // will be thrown when the processor topology is built. + final KStream pattern1Stream = builder.stream(Pattern.compile("foo.*")); + final KStream pattern2Stream = builder.stream(Pattern.compile("f.*")); + + pattern1Stream.to(outputTopic, Produced.with(stringSerde, stringSerde)); + pattern2Stream.to(outputTopic, Produced.with(stringSerde, stringSerde)); + + final AtomicBoolean expectError = new AtomicBoolean(false); + + streams = new KafkaStreams(builder.build(), streamsConfiguration); + streams.setStateListener((newState, oldState) -> { + if (newState == KafkaStreams.State.ERROR) { + expectError.set(true); + } + }); + streams.setUncaughtExceptionHandler(e -> { + expectError.set(true); + return StreamsUncaughtExceptionHandler.StreamThreadExceptionResponse.SHUTDOWN_CLIENT; + }); + streams.start(); + + final Properties producerConfig = TestUtils.producerConfig(CLUSTER.bootstrapServers(), StringSerializer.class, StringSerializer.class); + + IntegrationTestUtils.produceValuesSynchronously(FA_TOPIC, Collections.singleton(fMessage), producerConfig, mockTime); + IntegrationTestUtils.produceValuesSynchronously(FOO_TOPIC, Collections.singleton(fooMessage), producerConfig, mockTime); + + final Properties consumerConfig = TestUtils.consumerConfig(CLUSTER.bootstrapServers(), StringDeserializer.class, StringDeserializer.class); + try { + IntegrationTestUtils.waitUntilMinKeyValueRecordsReceived(consumerConfig, outputTopic, 2, 5000); + throw new IllegalStateException("This should not happen: an assertion error should have been thrown before this."); + } catch (final AssertionError e) { + // this is fine + } + + assertThat(expectError.get(), is(true)); + } + + private static class TheConsumerRebalanceListener implements ConsumerRebalanceListener { + private final List assignedTopics; + private final ConsumerRebalanceListener listener; + + TheConsumerRebalanceListener(final List assignedTopics, final ConsumerRebalanceListener listener) { + this.assignedTopics = assignedTopics; + this.listener = listener; + } + + @Override + public void onPartitionsRevoked(final Collection partitions) { + for (final TopicPartition partition : partitions) { + assignedTopics.remove(partition.topic()); + } + listener.onPartitionsRevoked(partitions); + } + + @Override + public void onPartitionsAssigned(final Collection partitions) { + for (final TopicPartition partition : partitions) { + assignedTopics.add(partition.topic()); + } + Collections.sort(assignedTopics); + listener.onPartitionsAssigned(partitions); + } + } + +} diff --git a/streams/src/test/java/org/apache/kafka/streams/integration/ResetIntegrationTest.java b/streams/src/test/java/org/apache/kafka/streams/integration/ResetIntegrationTest.java new file mode 100644 index 0000000..5c236e6 --- /dev/null +++ 
b/streams/src/test/java/org/apache/kafka/streams/integration/ResetIntegrationTest.java @@ -0,0 +1,348 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.integration; + +import kafka.server.KafkaConfig$; +import kafka.tools.StreamsResetter; + +import org.apache.kafka.clients.consumer.ConsumerConfig; +import org.apache.kafka.streams.KafkaStreams; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.integration.utils.EmbeddedKafkaCluster; +import org.apache.kafka.streams.integration.utils.IntegrationTestUtils; +import org.apache.kafka.test.IntegrationTest; + +import org.junit.After; +import org.junit.AfterClass; +import org.junit.Assert; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; +import org.junit.experimental.categories.Category; + +import java.io.BufferedWriter; +import java.io.File; +import java.io.FileWriter; +import java.io.IOException; +import java.text.SimpleDateFormat; +import java.util.Calendar; +import java.util.List; +import java.util.Map; +import java.util.Properties; + +import static org.apache.kafka.streams.integration.utils.IntegrationTestUtils.isEmptyConsumerGroup; +import static org.apache.kafka.streams.integration.utils.IntegrationTestUtils.waitForEmptyConsumerGroup; +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.MatcherAssert.assertThat; + +/** + * Tests local state store and global application cleanup. 
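+ * It also verifies that the reset tool refuses to run while the application is still running or when the
+ * given input, intermediate, or internal topics are invalid, and covers reprocessing after resets by offset file, datetime, and duration.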
+ */ +@Category({IntegrationTest.class}) +public class ResetIntegrationTest extends AbstractResetIntegrationTest { + + private static final String NON_EXISTING_TOPIC = "nonExistingTopic"; + + public static final EmbeddedKafkaCluster CLUSTER; + + static { + final Properties brokerProps = new Properties(); + // we double the value passed to `time.sleep` in each iteration in one of the map functions, so we disable + // expiration of connections by the brokers to avoid errors when `AdminClient` sends requests after potentially + // very long sleep times + brokerProps.put(KafkaConfig$.MODULE$.ConnectionsMaxIdleMsProp(), -1L); + CLUSTER = new EmbeddedKafkaCluster(1, brokerProps); + } + + @BeforeClass + public static void startCluster() throws IOException { + CLUSTER.start(); + } + + @AfterClass + public static void closeCluster() { + CLUSTER.stop(); + } + + @Override + Map getClientSslConfig() { + return null; + } + + @Before + public void before() throws Exception { + cluster = CLUSTER; + prepareTest(); + } + + @After + public void after() throws Exception { + cleanupTest(); + } + + @Test + public void shouldNotAllowToResetWhileStreamsIsRunning() { + final String appID = IntegrationTestUtils.safeUniqueTestName(getClass(), testName); + final String[] parameters = new String[] { + "--application-id", appID, + "--bootstrap-servers", cluster.bootstrapServers(), + "--input-topics", NON_EXISTING_TOPIC + }; + final Properties cleanUpConfig = new Properties(); + cleanUpConfig.put(ConsumerConfig.HEARTBEAT_INTERVAL_MS_CONFIG, 100); + cleanUpConfig.put(ConsumerConfig.SESSION_TIMEOUT_MS_CONFIG, Integer.toString(CLEANUP_CONSUMER_TIMEOUT)); + + streamsConfig.put(StreamsConfig.APPLICATION_ID_CONFIG, appID); + + // RUN + streams = new KafkaStreams(setupTopologyWithoutIntermediateUserTopic(), streamsConfig); + streams.start(); + + final int exitCode = new StreamsResetter().run(parameters, cleanUpConfig); + Assert.assertEquals(1, exitCode); + + streams.close(); + } + + @Test + public void shouldNotAllowToResetWhenInputTopicAbsent() { + final String appID = IntegrationTestUtils.safeUniqueTestName(getClass(), testName); + final String[] parameters = new String[] { + "--application-id", appID, + "--bootstrap-servers", cluster.bootstrapServers(), + "--input-topics", NON_EXISTING_TOPIC + }; + final Properties cleanUpConfig = new Properties(); + cleanUpConfig.put(ConsumerConfig.HEARTBEAT_INTERVAL_MS_CONFIG, 100); + cleanUpConfig.put(ConsumerConfig.SESSION_TIMEOUT_MS_CONFIG, Integer.toString(CLEANUP_CONSUMER_TIMEOUT)); + + final int exitCode = new StreamsResetter().run(parameters, cleanUpConfig); + Assert.assertEquals(1, exitCode); + } + + @Test + public void shouldNotAllowToResetWhenIntermediateTopicAbsent() { + final String appID = IntegrationTestUtils.safeUniqueTestName(getClass(), testName); + final String[] parameters = new String[] { + "--application-id", appID, + "--bootstrap-servers", cluster.bootstrapServers(), + "--intermediate-topics", NON_EXISTING_TOPIC + }; + final Properties cleanUpConfig = new Properties(); + cleanUpConfig.put(ConsumerConfig.HEARTBEAT_INTERVAL_MS_CONFIG, 100); + cleanUpConfig.put(ConsumerConfig.SESSION_TIMEOUT_MS_CONFIG, Integer.toString(CLEANUP_CONSUMER_TIMEOUT)); + + final int exitCode = new StreamsResetter().run(parameters, cleanUpConfig); + Assert.assertEquals(1, exitCode); + } + + @Test + public void shouldNotAllowToResetWhenSpecifiedInternalTopicDoesNotExist() { + final String appID = IntegrationTestUtils.safeUniqueTestName(getClass(), testName); + final String[] parameters = new 
String[] { + "--application-id", appID, + "--bootstrap-servers", cluster.bootstrapServers(), + "--internal-topics", NON_EXISTING_TOPIC + }; + final Properties cleanUpConfig = new Properties(); + cleanUpConfig.put(ConsumerConfig.HEARTBEAT_INTERVAL_MS_CONFIG, 100); + cleanUpConfig.put(ConsumerConfig.SESSION_TIMEOUT_MS_CONFIG, Integer.toString(CLEANUP_CONSUMER_TIMEOUT)); + + final int exitCode = new StreamsResetter().run(parameters, cleanUpConfig); + Assert.assertEquals(1, exitCode); + } + + @Test + public void shouldNotAllowToResetWhenSpecifiedInternalTopicIsNotInternal() { + final String appID = IntegrationTestUtils.safeUniqueTestName(getClass(), testName); + final String[] parameters = new String[] { + "--application-id", appID, + "--bootstrap-servers", cluster.bootstrapServers(), + "--internal-topics", INPUT_TOPIC + }; + final Properties cleanUpConfig = new Properties(); + cleanUpConfig.put(ConsumerConfig.HEARTBEAT_INTERVAL_MS_CONFIG, 100); + cleanUpConfig.put(ConsumerConfig.SESSION_TIMEOUT_MS_CONFIG, Integer.toString(CLEANUP_CONSUMER_TIMEOUT)); + + final int exitCode = new StreamsResetter().run(parameters, cleanUpConfig); + Assert.assertEquals(1, exitCode); + } + + @Test + public void testResetWhenLongSessionTimeoutConfiguredWithForceOption() throws Exception { + final String appID = IntegrationTestUtils.safeUniqueTestName(getClass(), testName); + streamsConfig.put(StreamsConfig.APPLICATION_ID_CONFIG, appID); + streamsConfig.put(ConsumerConfig.SESSION_TIMEOUT_MS_CONFIG, Integer.toString(STREAMS_CONSUMER_TIMEOUT * 100)); + + // Run + streams = new KafkaStreams(setupTopologyWithoutIntermediateUserTopic(), streamsConfig); + streams.start(); + final List> result = IntegrationTestUtils.waitUntilMinKeyValueRecordsReceived(resultConsumerConfig, OUTPUT_TOPIC, 10); + + streams.close(); + + // RESET + streams = new KafkaStreams(setupTopologyWithoutIntermediateUserTopic(), streamsConfig); + streams.cleanUp(); + + // Reset would fail since long session timeout has been configured + final boolean cleanResult = tryCleanGlobal(false, null, null, appID); + Assert.assertFalse(cleanResult); + + // Reset will success with --force, it will force delete active members on broker side + cleanGlobal(false, "--force", null, appID); + assertThat("Group is not empty after cleanGlobal", isEmptyConsumerGroup(adminClient, appID)); + + assertInternalTopicsGotDeleted(null); + + // RE-RUN + streams.start(); + final List> resultRerun = IntegrationTestUtils.waitUntilMinKeyValueRecordsReceived(resultConsumerConfig, OUTPUT_TOPIC, 10); + streams.close(); + + assertThat(resultRerun, equalTo(result)); + cleanGlobal(false, "--force", null, appID); + } + + @Test + public void testReprocessingFromFileAfterResetWithoutIntermediateUserTopic() throws Exception { + final String appID = IntegrationTestUtils.safeUniqueTestName(getClass(), testName); + streamsConfig.put(StreamsConfig.APPLICATION_ID_CONFIG, appID); + + // RUN + streams = new KafkaStreams(setupTopologyWithoutIntermediateUserTopic(), streamsConfig); + streams.start(); + final List> result = IntegrationTestUtils.waitUntilMinKeyValueRecordsReceived(resultConsumerConfig, OUTPUT_TOPIC, 10); + + streams.close(); + waitForEmptyConsumerGroup(adminClient, appID, TIMEOUT_MULTIPLIER * STREAMS_CONSUMER_TIMEOUT); + + // RESET + final File resetFile = File.createTempFile("reset", ".csv"); + try (final BufferedWriter writer = new BufferedWriter(new FileWriter(resetFile))) { + writer.write(INPUT_TOPIC + ",0,1"); + } + + streams = new 
KafkaStreams(setupTopologyWithoutIntermediateUserTopic(), streamsConfig); + streams.cleanUp(); + + cleanGlobal(false, "--from-file", resetFile.getAbsolutePath(), appID); + waitForEmptyConsumerGroup(adminClient, appID, TIMEOUT_MULTIPLIER * STREAMS_CONSUMER_TIMEOUT); + + assertInternalTopicsGotDeleted(null); + + resetFile.deleteOnExit(); + + // RE-RUN + streams.start(); + final List> resultRerun = IntegrationTestUtils.waitUntilMinKeyValueRecordsReceived(resultConsumerConfig, OUTPUT_TOPIC, 5); + streams.close(); + + result.remove(0); + assertThat(resultRerun, equalTo(result)); + + waitForEmptyConsumerGroup(adminClient, appID, TIMEOUT_MULTIPLIER * STREAMS_CONSUMER_TIMEOUT); + cleanGlobal(false, null, null, appID); + } + + @Test + public void testReprocessingFromDateTimeAfterResetWithoutIntermediateUserTopic() throws Exception { + final String appID = IntegrationTestUtils.safeUniqueTestName(getClass(), testName); + streamsConfig.put(StreamsConfig.APPLICATION_ID_CONFIG, appID); + + // RUN + streams = new KafkaStreams(setupTopologyWithoutIntermediateUserTopic(), streamsConfig); + streams.start(); + final List> result = IntegrationTestUtils.waitUntilMinKeyValueRecordsReceived(resultConsumerConfig, OUTPUT_TOPIC, 10); + + streams.close(); + waitForEmptyConsumerGroup(adminClient, appID, TIMEOUT_MULTIPLIER * STREAMS_CONSUMER_TIMEOUT); + + // RESET + final File resetFile = File.createTempFile("reset", ".csv"); + try (final BufferedWriter writer = new BufferedWriter(new FileWriter(resetFile))) { + writer.write(INPUT_TOPIC + ",0,1"); + } + + streams = new KafkaStreams(setupTopologyWithoutIntermediateUserTopic(), streamsConfig); + streams.cleanUp(); + + + final SimpleDateFormat format = new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss.SSS"); + final Calendar calendar = Calendar.getInstance(); + calendar.add(Calendar.DATE, -1); + + cleanGlobal(false, "--to-datetime", format.format(calendar.getTime()), appID); + waitForEmptyConsumerGroup(adminClient, appID, TIMEOUT_MULTIPLIER * STREAMS_CONSUMER_TIMEOUT); + + assertInternalTopicsGotDeleted(null); + + resetFile.deleteOnExit(); + + // RE-RUN + streams.start(); + final List> resultRerun = IntegrationTestUtils.waitUntilMinKeyValueRecordsReceived(resultConsumerConfig, OUTPUT_TOPIC, 10); + streams.close(); + + assertThat(resultRerun, equalTo(result)); + + waitForEmptyConsumerGroup(adminClient, appID, TIMEOUT_MULTIPLIER * STREAMS_CONSUMER_TIMEOUT); + cleanGlobal(false, null, null, appID); + } + + @Test + public void testReprocessingByDurationAfterResetWithoutIntermediateUserTopic() throws Exception { + final String appID = IntegrationTestUtils.safeUniqueTestName(getClass(), testName); + streamsConfig.put(StreamsConfig.APPLICATION_ID_CONFIG, appID); + + // RUN + streams = new KafkaStreams(setupTopologyWithoutIntermediateUserTopic(), streamsConfig); + streams.start(); + final List> result = IntegrationTestUtils.waitUntilMinKeyValueRecordsReceived(resultConsumerConfig, OUTPUT_TOPIC, 10); + + streams.close(); + waitForEmptyConsumerGroup(adminClient, appID, TIMEOUT_MULTIPLIER * STREAMS_CONSUMER_TIMEOUT); + + // RESET + final File resetFile = File.createTempFile("reset", ".csv"); + try (final BufferedWriter writer = new BufferedWriter(new FileWriter(resetFile))) { + writer.write(INPUT_TOPIC + ",0,1"); + } + + streams = new KafkaStreams(setupTopologyWithoutIntermediateUserTopic(), streamsConfig); + streams.cleanUp(); + cleanGlobal(false, "--by-duration", "PT1M", appID); + + waitForEmptyConsumerGroup(adminClient, appID, TIMEOUT_MULTIPLIER * STREAMS_CONSUMER_TIMEOUT); + + 
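+        // the reset tool should also have removed the application's internal topics before the re-run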
assertInternalTopicsGotDeleted(null); + + resetFile.deleteOnExit(); + + // RE-RUN + streams.start(); + final List> resultRerun = IntegrationTestUtils.waitUntilMinKeyValueRecordsReceived(resultConsumerConfig, OUTPUT_TOPIC, 10); + streams.close(); + + assertThat(resultRerun, equalTo(result)); + + waitForEmptyConsumerGroup(adminClient, appID, TIMEOUT_MULTIPLIER * STREAMS_CONSUMER_TIMEOUT); + cleanGlobal(false, null, null, appID); + } + +} diff --git a/streams/src/test/java/org/apache/kafka/streams/integration/ResetIntegrationWithSslTest.java b/streams/src/test/java/org/apache/kafka/streams/integration/ResetIntegrationWithSslTest.java new file mode 100644 index 0000000..34e9894 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/integration/ResetIntegrationWithSslTest.java @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.integration; + +import kafka.server.KafkaConfig$; +import org.apache.kafka.common.network.Mode; +import org.apache.kafka.streams.integration.utils.EmbeddedKafkaCluster; +import org.apache.kafka.test.IntegrationTest; +import org.apache.kafka.test.TestSslUtils; +import org.apache.kafka.test.TestUtils; +import org.junit.After; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.experimental.categories.Category; + +import java.io.IOException; +import java.util.Map; +import java.util.Properties; + +/** + * Tests command line SSL setup for reset tool. 
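+ * Runs the reset scenarios defined in AbstractResetIntegrationTest against a broker configured with an SSL listener.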
+ */ +@Category({IntegrationTest.class}) +public class ResetIntegrationWithSslTest extends AbstractResetIntegrationTest { + + public static final EmbeddedKafkaCluster CLUSTER; + + private static final Map SSL_CONFIG; + + static { + final Properties brokerProps = new Properties(); + // we double the value passed to `time.sleep` in each iteration in one of the map functions, so we disable + // expiration of connections by the brokers to avoid errors when `AdminClient` sends requests after potentially + // very long sleep times + brokerProps.put(KafkaConfig$.MODULE$.ConnectionsMaxIdleMsProp(), -1L); + + try { + SSL_CONFIG = TestSslUtils.createSslConfig(false, true, Mode.SERVER, TestUtils.tempFile(), "testCert"); + + brokerProps.put(KafkaConfig$.MODULE$.ListenersProp(), "SSL://localhost:0"); + brokerProps.put(KafkaConfig$.MODULE$.InterBrokerListenerNameProp(), "SSL"); + brokerProps.putAll(SSL_CONFIG); + } catch (final Exception e) { + throw new RuntimeException(e); + } + + CLUSTER = new EmbeddedKafkaCluster(1, brokerProps); + } + + @BeforeClass + public static void startCluster() throws IOException { + CLUSTER.start(); + } + + @AfterClass + public static void closeCluster() { + CLUSTER.stop(); + } + + @Override + Map getClientSslConfig() { + return SSL_CONFIG; + } + + @Before + public void before() throws Exception { + cluster = CLUSTER; + prepareTest(); + } + + @After + public void after() throws Exception { + cleanupTest(); + } + +} diff --git a/streams/src/test/java/org/apache/kafka/streams/integration/ResetPartitionTimeIntegrationTest.java b/streams/src/test/java/org/apache/kafka/streams/integration/ResetPartitionTimeIntegrationTest.java new file mode 100644 index 0000000..c3e7c28 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/integration/ResetPartitionTimeIntegrationTest.java @@ -0,0 +1,204 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.integration; + +import org.apache.kafka.clients.consumer.ConsumerConfig; +import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.clients.producer.ProducerConfig; +import org.apache.kafka.common.serialization.Deserializer; +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.serialization.Serializer; +import org.apache.kafka.common.serialization.StringDeserializer; +import org.apache.kafka.common.serialization.StringSerializer; +import org.apache.kafka.streams.KafkaStreams; +import org.apache.kafka.streams.KeyValueTimestamp; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.integration.utils.EmbeddedKafkaCluster; +import org.apache.kafka.streams.integration.utils.IntegrationTestUtils; +import org.apache.kafka.streams.kstream.Consumed; +import org.apache.kafka.streams.processor.TimestampExtractor; +import org.apache.kafka.test.IntegrationTest; +import org.apache.kafka.test.TestUtils; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Rule; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.rules.TestName; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.List; +import java.util.Optional; +import java.util.Properties; + +import static org.apache.kafka.common.utils.Utils.mkEntry; +import static org.apache.kafka.common.utils.Utils.mkMap; +import static org.apache.kafka.common.utils.Utils.mkProperties; +import static org.apache.kafka.streams.integration.utils.IntegrationTestUtils.cleanStateBeforeTest; +import static org.apache.kafka.streams.integration.utils.IntegrationTestUtils.getStartedStreams; +import static org.apache.kafka.streams.integration.utils.IntegrationTestUtils.quietlyCleanStateAfterTest; +import static org.apache.kafka.streams.integration.utils.IntegrationTestUtils.safeUniqueTestName; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.MatcherAssert.assertThat; + +@RunWith(Parameterized.class) +@Category({IntegrationTest.class}) +public class ResetPartitionTimeIntegrationTest { + private static final int NUM_BROKERS = 1; + private static final Properties BROKER_CONFIG; + static { + BROKER_CONFIG = new Properties(); + BROKER_CONFIG.put("transaction.state.log.replication.factor", (short) 1); + BROKER_CONFIG.put("transaction.state.log.min.isr", 1); + } + public static final EmbeddedKafkaCluster CLUSTER = + new EmbeddedKafkaCluster(NUM_BROKERS, BROKER_CONFIG, 0L); + + @BeforeClass + public static void startCluster() throws IOException { + CLUSTER.start(); + } + + @AfterClass + public static void closeCluster() { + CLUSTER.stop(); + } + + private static final StringDeserializer STRING_DESERIALIZER = new StringDeserializer(); + private static final StringSerializer STRING_SERIALIZER = new StringSerializer(); + private static final Serde STRING_SERDE = Serdes.String(); + private static final int DEFAULT_TIMEOUT = 100; + private static long lastRecordedTimestamp = -2L; + + @SuppressWarnings("deprecation") + @Parameterized.Parameters(name = "{0}") + public static Collection data() { + return Arrays.asList(new String[][] { + {StreamsConfig.AT_LEAST_ONCE}, + {StreamsConfig.EXACTLY_ONCE}, + 
{StreamsConfig.EXACTLY_ONCE_V2} + }); + } + + @Parameterized.Parameter + public String processingGuarantee; + + @Rule + public TestName testName = new TestName(); + + @Test + public void shouldPreservePartitionTimeOnKafkaStreamRestart() { + final String appId = "app-" + safeUniqueTestName(getClass(), testName); + final String input = "input"; + final String outputRaw = "output-raw"; + + cleanStateBeforeTest(CLUSTER, 2, input, outputRaw); + + final StreamsBuilder builder = new StreamsBuilder(); + builder + .stream(input, Consumed.with(STRING_SERDE, STRING_SERDE)) + .to(outputRaw); + + final Properties streamsConfig = new Properties(); + streamsConfig.put(StreamsConfig.DEFAULT_TIMESTAMP_EXTRACTOR_CLASS_CONFIG, MaxTimestampExtractor.class); + streamsConfig.put(StreamsConfig.APPLICATION_ID_CONFIG, appId); + streamsConfig.put(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, CLUSTER.bootstrapServers()); + streamsConfig.put(StreamsConfig.POLL_MS_CONFIG, Integer.toString(DEFAULT_TIMEOUT)); + streamsConfig.put(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG, (long) DEFAULT_TIMEOUT); + streamsConfig.put(StreamsConfig.PROCESSING_GUARANTEE_CONFIG, processingGuarantee); + streamsConfig.put(StreamsConfig.STATE_DIR_CONFIG, TestUtils.tempDirectory().getPath()); + + KafkaStreams kafkaStreams = getStartedStreams(streamsConfig, builder, true); + try { + // start sending some records to have partition time committed + produceSynchronouslyToPartitionZero( + input, + Collections.singletonList( + new KeyValueTimestamp<>("k3", "v3", 5000) + ) + ); + verifyOutput( + outputRaw, + Collections.singletonList( + new KeyValueTimestamp<>("k3", "v3", 5000) + ) + ); + assertThat(lastRecordedTimestamp, is(-1L)); + lastRecordedTimestamp = -2L; + + kafkaStreams.close(); + assertThat(kafkaStreams.state(), is(KafkaStreams.State.NOT_RUNNING)); + + kafkaStreams = getStartedStreams(streamsConfig, builder, true); + + // resend some records and retrieve the last committed timestamp + produceSynchronouslyToPartitionZero( + input, + Collections.singletonList( + new KeyValueTimestamp<>("k5", "v5", 4999) + ) + ); + verifyOutput( + outputRaw, + Collections.singletonList( + new KeyValueTimestamp<>("k5", "v5", 4999) + ) + ); + assertThat(lastRecordedTimestamp, is(5000L)); + } finally { + kafkaStreams.close(); + quietlyCleanStateAfterTest(CLUSTER, kafkaStreams); + } + } + + public static final class MaxTimestampExtractor implements TimestampExtractor { + @Override + public long extract(final ConsumerRecord record, final long partitionTime) { + lastRecordedTimestamp = partitionTime; + return record.timestamp(); + } + } + + private void verifyOutput(final String topic, final List> keyValueTimestamps) { + final Properties properties = mkProperties( + mkMap( + mkEntry(ConsumerConfig.GROUP_ID_CONFIG, "test-group"), + mkEntry(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, CLUSTER.bootstrapServers()), + mkEntry(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, ((Deserializer) STRING_DESERIALIZER).getClass().getName()), + mkEntry(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, ((Deserializer) STRING_DESERIALIZER).getClass().getName()) + ) + ); + IntegrationTestUtils.verifyKeyValueTimestamps(properties, topic, keyValueTimestamps); + } + + private static void produceSynchronouslyToPartitionZero(final String topic, final List> toProduce) { + final Properties producerConfig = mkProperties(mkMap( + mkEntry(ProducerConfig.CLIENT_ID_CONFIG, "anything"), + mkEntry(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG, ((Serializer) STRING_SERIALIZER).getClass().getName()), + 
mkEntry(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG, ((Serializer) STRING_SERIALIZER).getClass().getName()), + mkEntry(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, CLUSTER.bootstrapServers()) + )); + IntegrationTestUtils.produceSynchronously(producerConfig, false, topic, Optional.of(0), toProduce); + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/integration/RestoreIntegrationTest.java b/streams/src/test/java/org/apache/kafka/streams/integration/RestoreIntegrationTest.java new file mode 100644 index 0000000..1bff14d --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/integration/RestoreIntegrationTest.java @@ -0,0 +1,503 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.integration; + +import org.apache.kafka.clients.consumer.Consumer; +import org.apache.kafka.clients.consumer.ConsumerConfig; +import org.apache.kafka.clients.consumer.KafkaConsumer; +import org.apache.kafka.clients.producer.KafkaProducer; +import org.apache.kafka.clients.producer.ProducerConfig; +import org.apache.kafka.clients.producer.ProducerRecord; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.serialization.IntegerDeserializer; +import org.apache.kafka.common.serialization.IntegerSerializer; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.streams.KafkaStreams; +import org.apache.kafka.streams.KafkaStreams.State; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.Topology; +import org.apache.kafka.streams.integration.utils.EmbeddedKafkaCluster; +import org.apache.kafka.streams.integration.utils.IntegrationTestUtils; +import org.apache.kafka.streams.integration.utils.IntegrationTestUtils.TrackingStateRestoreListener; +import org.apache.kafka.streams.kstream.Consumed; +import org.apache.kafka.streams.kstream.KStream; +import org.apache.kafka.streams.kstream.Materialized; +import org.apache.kafka.streams.processor.ProcessorContext; +import org.apache.kafka.streams.processor.StateRestoreListener; +import org.apache.kafka.streams.processor.TaskId; +import org.apache.kafka.streams.processor.internals.StateDirectory; +import org.apache.kafka.streams.state.KeyValueBytesStoreSupplier; +import org.apache.kafka.streams.state.KeyValueStore; +import org.apache.kafka.streams.state.StoreBuilder; +import org.apache.kafka.streams.state.Stores; +import org.apache.kafka.streams.state.internals.InMemoryKeyValueStore; +import org.apache.kafka.streams.state.internals.KeyValueStoreBuilder; +import org.apache.kafka.streams.state.internals.OffsetCheckpoint; 
+import org.apache.kafka.test.IntegrationTest; +import org.apache.kafka.test.TestUtils; +import org.hamcrest.CoreMatchers; +import org.junit.After; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Rule; +import org.junit.Test; +import org.junit.experimental.categories.Category; + +import java.io.File; +import java.io.IOException; +import java.time.Duration; +import java.util.Collections; +import java.util.List; +import java.util.Properties; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicLong; +import org.junit.rules.TestName; + +import static java.util.Arrays.asList; +import static java.util.Collections.singletonList; +import static org.apache.kafka.streams.integration.utils.IntegrationTestUtils.purgeLocalStreamsState; +import static org.apache.kafka.streams.integration.utils.IntegrationTestUtils.safeUniqueTestName; +import static org.apache.kafka.streams.integration.utils.IntegrationTestUtils.startApplicationAndWaitUntilRunning; +import static org.apache.kafka.streams.integration.utils.IntegrationTestUtils.waitForApplicationState; +import static org.apache.kafka.streams.integration.utils.IntegrationTestUtils.waitForCompletion; +import static org.apache.kafka.streams.integration.utils.IntegrationTestUtils.waitForStandbyCompletion; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.core.IsEqual.equalTo; +import static org.junit.Assert.assertTrue; + +@Category({IntegrationTest.class}) +public class RestoreIntegrationTest { + private static final int NUM_BROKERS = 1; + + public static final EmbeddedKafkaCluster CLUSTER = new EmbeddedKafkaCluster(NUM_BROKERS); + + @BeforeClass + public static void startCluster() throws IOException { + CLUSTER.start(); + } + + @AfterClass + public static void closeCluster() { + CLUSTER.stop(); + } + + @Rule + public final TestName testName = new TestName(); + private String appId; + private String inputStream; + + private final int numberOfKeys = 10000; + private KafkaStreams kafkaStreams; + + @Before + public void createTopics() throws InterruptedException { + appId = safeUniqueTestName(RestoreIntegrationTest.class, testName); + inputStream = appId + "-input-stream"; + CLUSTER.createTopic(inputStream, 2, 1); + } + + private Properties props() { + final Properties streamsConfiguration = new Properties(); + streamsConfiguration.put(StreamsConfig.APPLICATION_ID_CONFIG, appId); + streamsConfiguration.put(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, CLUSTER.bootstrapServers()); + streamsConfiguration.put(StreamsConfig.CACHE_MAX_BYTES_BUFFERING_CONFIG, 0); + streamsConfiguration.put(StreamsConfig.STATE_DIR_CONFIG, TestUtils.tempDirectory(appId).getPath()); + streamsConfiguration.put(StreamsConfig.DEFAULT_KEY_SERDE_CLASS_CONFIG, Serdes.Integer().getClass()); + streamsConfiguration.put(StreamsConfig.DEFAULT_VALUE_SERDE_CLASS_CONFIG, Serdes.Integer().getClass()); + streamsConfiguration.put(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG, 1000L); + streamsConfiguration.put(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "earliest"); + return streamsConfiguration; + } + + @After + public void shutdown() { + if (kafkaStreams != null) { + kafkaStreams.close(Duration.ofSeconds(30)); + } + } + + @Test + public void shouldRestoreStateFromSourceTopic() throws Exception { + final AtomicInteger numReceived = new AtomicInteger(0); + final StreamsBuilder builder = new StreamsBuilder(); + 
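+        // topology optimization (enabled below) lets the table reuse its source topic as the changelog, so state is restored directly from the input topic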
+ final Properties props = props(); + props.put(StreamsConfig.TOPOLOGY_OPTIMIZATION_CONFIG, StreamsConfig.OPTIMIZE); + + // restoring from 1000 to 4000 (committed), and then process from 4000 to 5000 on each of the two partitions + final int offsetLimitDelta = 1000; + final int offsetCheckpointed = 1000; + createStateForRestoration(inputStream, 0); + setCommittedOffset(inputStream, offsetLimitDelta); + + final StateDirectory stateDirectory = new StateDirectory(new StreamsConfig(props), new MockTime(), true, false); + // note here the checkpointed offset is the last processed record's offset, so without control message we should write this offset - 1 + new OffsetCheckpoint(new File(stateDirectory.getOrCreateDirectoryForTask(new TaskId(0, 0)), ".checkpoint")) + .write(Collections.singletonMap(new TopicPartition(inputStream, 0), (long) offsetCheckpointed - 1)); + new OffsetCheckpoint(new File(stateDirectory.getOrCreateDirectoryForTask(new TaskId(0, 1)), ".checkpoint")) + .write(Collections.singletonMap(new TopicPartition(inputStream, 1), (long) offsetCheckpointed - 1)); + + final CountDownLatch startupLatch = new CountDownLatch(1); + final CountDownLatch shutdownLatch = new CountDownLatch(1); + + builder.table(inputStream, Materialized.>as("store").withKeySerde(Serdes.Integer()).withValueSerde(Serdes.Integer())) + .toStream() + .foreach((key, value) -> { + if (numReceived.incrementAndGet() == offsetLimitDelta * 2) { + shutdownLatch.countDown(); + } + }); + + kafkaStreams = new KafkaStreams(builder.build(props), props); + kafkaStreams.setStateListener((newState, oldState) -> { + if (newState == KafkaStreams.State.RUNNING && oldState == KafkaStreams.State.REBALANCING) { + startupLatch.countDown(); + } + }); + + final AtomicLong restored = new AtomicLong(0); + kafkaStreams.setGlobalStateRestoreListener(new StateRestoreListener() { + @Override + public void onRestoreStart(final TopicPartition topicPartition, final String storeName, final long startingOffset, final long endingOffset) { + + } + + @Override + public void onBatchRestored(final TopicPartition topicPartition, final String storeName, final long batchEndOffset, final long numRestored) { + + } + + @Override + public void onRestoreEnd(final TopicPartition topicPartition, final String storeName, final long totalRestored) { + restored.addAndGet(totalRestored); + } + }); + kafkaStreams.start(); + + assertTrue(startupLatch.await(30, TimeUnit.SECONDS)); + assertThat(restored.get(), equalTo((long) numberOfKeys - offsetLimitDelta * 2 - offsetCheckpointed * 2)); + + assertTrue(shutdownLatch.await(30, TimeUnit.SECONDS)); + assertThat(numReceived.get(), equalTo(offsetLimitDelta * 2)); + } + + @Test + public void shouldRestoreStateFromChangelogTopic() throws Exception { + final String changelog = appId + "-store-changelog"; + CLUSTER.createTopic(changelog, 2, 1); + + final AtomicInteger numReceived = new AtomicInteger(0); + final StreamsBuilder builder = new StreamsBuilder(); + + final Properties props = props(); + + // restoring from 1000 to 5000, and then process from 5000 to 10000 on each of the two partitions + final int offsetCheckpointed = 1000; + createStateForRestoration(changelog, 0); + createStateForRestoration(inputStream, 10000); + + final StateDirectory stateDirectory = new StateDirectory(new StreamsConfig(props), new MockTime(), true, false); + // note here the checkpointed offset is the last processed record's offset, so without control message we should write this offset - 1 + new OffsetCheckpoint(new 
File(stateDirectory.getOrCreateDirectoryForTask(new TaskId(0, 0)), ".checkpoint")) + .write(Collections.singletonMap(new TopicPartition(changelog, 0), (long) offsetCheckpointed - 1)); + new OffsetCheckpoint(new File(stateDirectory.getOrCreateDirectoryForTask(new TaskId(0, 1)), ".checkpoint")) + .write(Collections.singletonMap(new TopicPartition(changelog, 1), (long) offsetCheckpointed - 1)); + + final CountDownLatch startupLatch = new CountDownLatch(1); + final CountDownLatch shutdownLatch = new CountDownLatch(1); + + builder.table(inputStream, Consumed.with(Serdes.Integer(), Serdes.Integer()), Materialized.as("store")) + .toStream() + .foreach((key, value) -> { + if (numReceived.incrementAndGet() == numberOfKeys) { + shutdownLatch.countDown(); + } + }); + + kafkaStreams = new KafkaStreams(builder.build(), props); + kafkaStreams.setStateListener((newState, oldState) -> { + if (newState == KafkaStreams.State.RUNNING && oldState == KafkaStreams.State.REBALANCING) { + startupLatch.countDown(); + } + }); + + final AtomicLong restored = new AtomicLong(0); + kafkaStreams.setGlobalStateRestoreListener(new StateRestoreListener() { + @Override + public void onRestoreStart(final TopicPartition topicPartition, final String storeName, final long startingOffset, final long endingOffset) { + + } + + @Override + public void onBatchRestored(final TopicPartition topicPartition, final String storeName, final long batchEndOffset, final long numRestored) { + + } + + @Override + public void onRestoreEnd(final TopicPartition topicPartition, final String storeName, final long totalRestored) { + restored.addAndGet(totalRestored); + } + }); + kafkaStreams.start(); + + assertTrue(startupLatch.await(30, TimeUnit.SECONDS)); + assertThat(restored.get(), equalTo((long) numberOfKeys - 2 * offsetCheckpointed)); + + assertTrue(shutdownLatch.await(30, TimeUnit.SECONDS)); + assertThat(numReceived.get(), equalTo(numberOfKeys)); + } + + @Test + public void shouldSuccessfullyStartWhenLoggingDisabled() throws InterruptedException { + final StreamsBuilder builder = new StreamsBuilder(); + + final KStream stream = builder.stream(inputStream); + stream.groupByKey() + .reduce( + (value1, value2) -> value1 + value2, + Materialized.>as("reduce-store").withLoggingDisabled()); + + final CountDownLatch startupLatch = new CountDownLatch(1); + kafkaStreams = new KafkaStreams(builder.build(), props()); + kafkaStreams.setStateListener((newState, oldState) -> { + if (newState == KafkaStreams.State.RUNNING && oldState == KafkaStreams.State.REBALANCING) { + startupLatch.countDown(); + } + }); + + kafkaStreams.start(); + + assertTrue(startupLatch.await(30, TimeUnit.SECONDS)); + } + + @SuppressWarnings("deprecation") // Old PAPI. Needs to be migrated. 
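+    // the store below disables changelogging, so there is nothing to restore; the processor must still observe every produced record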
+ @Test + public void shouldProcessDataFromStoresWithLoggingDisabled() throws InterruptedException { + + IntegrationTestUtils.produceKeyValuesSynchronously(inputStream, + asList(KeyValue.pair(1, 1), + KeyValue.pair(2, 2), + KeyValue.pair(3, 3)), + TestUtils.producerConfig(CLUSTER.bootstrapServers(), + IntegerSerializer.class, + IntegerSerializer.class), + CLUSTER.time); + + final KeyValueBytesStoreSupplier lruMapSupplier = Stores.lruMap(inputStream, 10); + + final StoreBuilder> storeBuilder = new KeyValueStoreBuilder<>(lruMapSupplier, + Serdes.Integer(), + Serdes.Integer(), + CLUSTER.time) + .withLoggingDisabled(); + + final StreamsBuilder streamsBuilder = new StreamsBuilder(); + + streamsBuilder.addStateStore(storeBuilder); + + final KStream stream = streamsBuilder.stream(inputStream); + final CountDownLatch processorLatch = new CountDownLatch(3); + stream.process(() -> new KeyValueStoreProcessor(inputStream, processorLatch), inputStream); + + final Topology topology = streamsBuilder.build(); + + kafkaStreams = new KafkaStreams(topology, props()); + + final CountDownLatch latch = new CountDownLatch(1); + kafkaStreams.setStateListener((newState, oldState) -> { + if (newState == KafkaStreams.State.RUNNING && oldState == KafkaStreams.State.REBALANCING) { + latch.countDown(); + } + }); + kafkaStreams.start(); + + latch.await(30, TimeUnit.SECONDS); + + assertTrue(processorLatch.await(30, TimeUnit.SECONDS)); + } + + @Test + public void shouldRecycleStateFromStandbyTaskPromotedToActiveTaskAndNotRestore() throws Exception { + final StreamsBuilder builder = new StreamsBuilder(); + builder.table( + inputStream, + Consumed.with(Serdes.Integer(), Serdes.Integer()), Materialized.as(getCloseCountingStore("store")) + ); + createStateForRestoration(inputStream, 0); + + final Properties props1 = props(); + props1.put(StreamsConfig.NUM_STANDBY_REPLICAS_CONFIG, 1); + props1.put(StreamsConfig.STATE_DIR_CONFIG, TestUtils.tempDirectory(appId + "-1").getPath()); + purgeLocalStreamsState(props1); + final KafkaStreams client1 = new KafkaStreams(builder.build(), props1); + + final Properties props2 = props(); + props2.put(StreamsConfig.NUM_STANDBY_REPLICAS_CONFIG, 1); + props2.put(StreamsConfig.STATE_DIR_CONFIG, TestUtils.tempDirectory(appId + "-2").getPath()); + purgeLocalStreamsState(props2); + final KafkaStreams client2 = new KafkaStreams(builder.build(), props2); + + final TrackingStateRestoreListener restoreListener = new TrackingStateRestoreListener(); + client1.setGlobalStateRestoreListener(restoreListener); + + startApplicationAndWaitUntilRunning(asList(client1, client2), Duration.ofSeconds(60)); + + waitForCompletion(client1, 1, 30 * 1000L); + waitForCompletion(client2, 1, 30 * 1000L); + waitForStandbyCompletion(client1, 1, 30 * 1000L); + waitForStandbyCompletion(client2, 1, 30 * 1000L); + + // Sometimes the store happens to have already been closed sometime during startup, so just keep track + // of where it started and make sure it doesn't happen more times from there + final int initialStoreCloseCount = CloseCountingInMemoryStore.numStoresClosed(); + final long initialNunRestoredCount = restoreListener.totalNumRestored(); + + client2.close(); + waitForApplicationState(singletonList(client2), State.NOT_RUNNING, Duration.ofSeconds(60)); + waitForApplicationState(singletonList(client1), State.REBALANCING, Duration.ofSeconds(60)); + waitForApplicationState(singletonList(client1), State.RUNNING, Duration.ofSeconds(60)); + + waitForCompletion(client1, 1, 30 * 1000L); + waitForStandbyCompletion(client1, 1, 30 
* 1000L); + + assertThat(restoreListener.totalNumRestored(), CoreMatchers.equalTo(initialNunRestoredCount)); + + // After stopping instance 2 and letting instance 1 take over its tasks, we should have closed just two stores + // total: the active and standby tasks on instance 2 + assertThat(CloseCountingInMemoryStore.numStoresClosed(), equalTo(initialStoreCloseCount + 2)); + + client1.close(); + waitForApplicationState(singletonList(client2), State.NOT_RUNNING, Duration.ofSeconds(60)); + + assertThat(CloseCountingInMemoryStore.numStoresClosed(), CoreMatchers.equalTo(initialStoreCloseCount + 4)); + } + + private static KeyValueBytesStoreSupplier getCloseCountingStore(final String name) { + return new KeyValueBytesStoreSupplier() { + @Override + public String name() { + return name; + } + + @Override + public KeyValueStore get() { + return new CloseCountingInMemoryStore(name); + } + + @Override + public String metricsScope() { + return "close-counting"; + } + }; + } + + static class CloseCountingInMemoryStore extends InMemoryKeyValueStore { + static AtomicInteger numStoresClosed = new AtomicInteger(0); + + CloseCountingInMemoryStore(final String name) { + super(name); + } + + @Override + public void close() { + numStoresClosed.incrementAndGet(); + super.close(); + } + + static int numStoresClosed() { + return numStoresClosed.get(); + } + } + + @SuppressWarnings("deprecation") // Old PAPI. Needs to be migrated. + public static class KeyValueStoreProcessor implements org.apache.kafka.streams.processor.Processor { + + private final String topic; + private final CountDownLatch processorLatch; + + private KeyValueStore store; + + KeyValueStoreProcessor(final String topic, final CountDownLatch processorLatch) { + this.topic = topic; + this.processorLatch = processorLatch; + } + + @SuppressWarnings("unchecked") + @Override + public void init(final ProcessorContext context) { + this.store = (KeyValueStore) context.getStateStore(topic); + } + + @Override + public void process(final Integer key, final Integer value) { + if (key != null) { + store.put(key, value); + processorLatch.countDown(); + } + } + + @Override + public void close() { } + } + + private void createStateForRestoration(final String changelogTopic, final int startingOffset) { + final Properties producerConfig = new Properties(); + producerConfig.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, CLUSTER.bootstrapServers()); + + try (final KafkaProducer producer = + new KafkaProducer<>(producerConfig, new IntegerSerializer(), new IntegerSerializer())) { + + for (int i = 0; i < numberOfKeys; i++) { + final int offset = startingOffset + i; + producer.send(new ProducerRecord<>(changelogTopic, offset, offset)); + } + } + } + + private void setCommittedOffset(final String topic, final int limitDelta) { + final Properties consumerConfig = new Properties(); + consumerConfig.put(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, CLUSTER.bootstrapServers()); + consumerConfig.put(ConsumerConfig.GROUP_ID_CONFIG, appId); + consumerConfig.put(ConsumerConfig.CLIENT_ID_CONFIG, "commit-consumer"); + consumerConfig.put(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, IntegerDeserializer.class); + consumerConfig.put(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, IntegerDeserializer.class); + + final Consumer consumer = new KafkaConsumer<>(consumerConfig); + final List partitions = asList( + new TopicPartition(topic, 0), + new TopicPartition(topic, 1)); + + consumer.assign(partitions); + consumer.seekToEnd(partitions); + + for (final TopicPartition partition : partitions) 
{ + final long position = consumer.position(partition); + consumer.seek(partition, position - limitDelta); + } + + consumer.commitSync(); + consumer.close(); + } + +} diff --git a/streams/src/test/java/org/apache/kafka/streams/integration/RocksDBMetricsIntegrationTest.java b/streams/src/test/java/org/apache/kafka/streams/integration/RocksDBMetricsIntegrationTest.java new file mode 100644 index 0000000..c698d06 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/integration/RocksDBMetricsIntegrationTest.java @@ -0,0 +1,340 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.integration; + +import org.apache.kafka.common.Metric; +import org.apache.kafka.common.metrics.Sensor; +import org.apache.kafka.common.serialization.IntegerSerializer; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.serialization.StringSerializer; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.streams.KafkaStreams; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.integration.utils.EmbeddedKafkaCluster; +import org.apache.kafka.streams.integration.utils.IntegrationTestUtils; +import org.apache.kafka.streams.kstream.Consumed; +import org.apache.kafka.streams.kstream.Materialized; +import org.apache.kafka.streams.kstream.Produced; +import org.apache.kafka.streams.kstream.TimeWindows; +import org.apache.kafka.streams.state.Stores; +import org.apache.kafka.streams.state.WindowStore; +import org.apache.kafka.test.IntegrationTest; +import org.apache.kafka.test.StreamsTestUtils; +import org.apache.kafka.test.TestUtils; +import org.junit.After; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Rule; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.rules.TestName; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import org.junit.runners.Parameterized.Parameter; +import org.junit.runners.Parameterized.Parameters; + +import java.io.IOException; +import java.time.Duration; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.List; +import java.util.Properties; +import java.util.stream.Collectors; + +import static org.apache.kafka.streams.integration.utils.IntegrationTestUtils.safeUniqueTestName; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.notNullValue; + 
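+/**
+ * Verifies that the RocksDB metrics listed below are exposed for both the key-value store and the
+ * windowed store, running the same topology twice to cover a restart after a simulated failure.
+ */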
+@Category({IntegrationTest.class}) +@RunWith(Parameterized.class) +@SuppressWarnings("deprecation") +public class RocksDBMetricsIntegrationTest { + + private static final int NUM_BROKERS = 3; + + public static final EmbeddedKafkaCluster CLUSTER = new EmbeddedKafkaCluster(NUM_BROKERS); + + @BeforeClass + public static void startCluster() throws IOException { + CLUSTER.start(); + } + + @AfterClass + public static void closeCluster() { + CLUSTER.stop(); + } + + private static final String STREAM_INPUT_ONE = "STREAM_INPUT_ONE"; + private static final String STREAM_OUTPUT_ONE = "STREAM_OUTPUT_ONE"; + private static final String STREAM_INPUT_TWO = "STREAM_INPUT_TWO"; + private static final String STREAM_OUTPUT_TWO = "STREAM_OUTPUT_TWO"; + private static final String MY_STORE_PERSISTENT_KEY_VALUE = "myStorePersistentKeyValue"; + private static final Duration WINDOW_SIZE = Duration.ofMillis(50); + private static final long TIMEOUT = 60000; + + // RocksDB metrics + private static final String METRICS_GROUP = "stream-state-metrics"; + private static final String BYTES_WRITTEN_RATE = "bytes-written-rate"; + private static final String BYTES_WRITTEN_TOTAL = "bytes-written-total"; + private static final String BYTES_READ_RATE = "bytes-read-rate"; + private static final String BYTES_READ_TOTAL = "bytes-read-total"; + private static final String MEMTABLE_BYTES_FLUSHED_RATE = "memtable-bytes-flushed-rate"; + private static final String MEMTABLE_BYTES_FLUSHED_TOTAL = "memtable-bytes-flushed-total"; + private static final String MEMTABLE_HIT_RATIO = "memtable-hit-ratio"; + private static final String WRITE_STALL_DURATION_AVG = "write-stall-duration-avg"; + private static final String WRITE_STALL_DURATION_TOTAL = "write-stall-duration-total"; + private static final String BLOCK_CACHE_DATA_HIT_RATIO = "block-cache-data-hit-ratio"; + private static final String BLOCK_CACHE_INDEX_HIT_RATIO = "block-cache-index-hit-ratio"; + private static final String BLOCK_CACHE_FILTER_HIT_RATIO = "block-cache-filter-hit-ratio"; + private static final String BYTES_READ_DURING_COMPACTION_RATE = "bytes-read-compaction-rate"; + private static final String BYTES_WRITTEN_DURING_COMPACTION_RATE = "bytes-written-compaction-rate"; + private static final String NUMBER_OF_OPEN_FILES = "number-open-files"; + private static final String NUMBER_OF_FILE_ERRORS = "number-file-errors-total"; + private static final String NUMBER_OF_ENTRIES_ACTIVE_MEMTABLE = "num-entries-active-mem-table"; + private static final String NUMBER_OF_DELETES_ACTIVE_MEMTABLE = "num-deletes-active-mem-table"; + private static final String NUMBER_OF_ENTRIES_IMMUTABLE_MEMTABLES = "num-entries-imm-mem-tables"; + private static final String NUMBER_OF_DELETES_IMMUTABLE_MEMTABLES = "num-deletes-imm-mem-tables"; + private static final String NUMBER_OF_IMMUTABLE_MEMTABLES = "num-immutable-mem-table"; + private static final String CURRENT_SIZE_OF_ACTIVE_MEMTABLE = "cur-size-active-mem-table"; + private static final String CURRENT_SIZE_OF_ALL_MEMTABLES = "cur-size-all-mem-tables"; + private static final String SIZE_OF_ALL_MEMTABLES = "size-all-mem-tables"; + private static final String MEMTABLE_FLUSH_PENDING = "mem-table-flush-pending"; + private static final String NUMBER_OF_RUNNING_FLUSHES = "num-running-flushes"; + private static final String COMPACTION_PENDING = "compaction-pending"; + private static final String NUMBER_OF_RUNNING_COMPACTIONS = "num-running-compactions"; + private static final String ESTIMATED_BYTES_OF_PENDING_COMPACTION = 
"estimate-pending-compaction-bytes"; + private static final String TOTAL_SST_FILES_SIZE = "total-sst-files-size"; + private static final String LIVE_SST_FILES_SIZE = "live-sst-files-size"; + private static final String NUMBER_OF_LIVE_VERSIONS = "num-live-versions"; + private static final String CAPACITY_OF_BLOCK_CACHE = "block-cache-capacity"; + private static final String USAGE_OF_BLOCK_CACHE = "block-cache-usage"; + private static final String PINNED_USAGE_OF_BLOCK_CACHE = "block-cache-pinned-usage"; + private static final String ESTIMATED_NUMBER_OF_KEYS = "estimate-num-keys"; + private static final String ESTIMATED_MEMORY_OF_TABLE_READERS = "estimate-table-readers-mem"; + private static final String NUMBER_OF_BACKGROUND_ERRORS = "background-errors"; + + @SuppressWarnings("deprecation") + @Parameters(name = "{0}") + public static Collection data() { + return Arrays.asList(new Object[][] { + {StreamsConfig.AT_LEAST_ONCE}, + {StreamsConfig.EXACTLY_ONCE}, + {StreamsConfig.EXACTLY_ONCE_V2} + }); + } + + @Parameter + public String processingGuarantee; + + @Rule + public TestName testName = new TestName(); + + @Before + public void before() throws Exception { + CLUSTER.createTopic(STREAM_INPUT_ONE, 1, 3); + CLUSTER.createTopic(STREAM_INPUT_TWO, 1, 3); + } + + @After + public void after() throws Exception { + CLUSTER.deleteTopicsAndWait(STREAM_INPUT_ONE, STREAM_INPUT_TWO, STREAM_OUTPUT_ONE, STREAM_OUTPUT_TWO); + } + + @FunctionalInterface + private interface MetricsVerifier { + void verify(final KafkaStreams kafkaStreams, final String metricScope) throws Exception; + } + + @Test + public void shouldExposeRocksDBMetricsBeforeAndAfterFailureWithEmptyStateDir() throws Exception { + final Properties streamsConfiguration = streamsConfig(); + IntegrationTestUtils.purgeLocalStreamsState(streamsConfiguration); + final StreamsBuilder builder = builderForStateStores(); + + cleanUpStateRunVerifyAndClose( + builder, + streamsConfiguration, + this::verifyThatRocksDBMetricsAreExposed + ); + + // simulated failure + + cleanUpStateRunVerifyAndClose( + builder, + streamsConfiguration, + this::verifyThatRocksDBMetricsAreExposed + ); + } + + private Properties streamsConfig() { + final Properties streamsConfiguration = new Properties(); + final String safeTestName = safeUniqueTestName(getClass(), testName); + streamsConfiguration.put(StreamsConfig.APPLICATION_ID_CONFIG, "test-application-" + safeTestName); + streamsConfiguration.put(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, CLUSTER.bootstrapServers()); + streamsConfiguration.put(StreamsConfig.DEFAULT_KEY_SERDE_CLASS_CONFIG, Serdes.Integer().getClass()); + streamsConfiguration.put(StreamsConfig.DEFAULT_VALUE_SERDE_CLASS_CONFIG, Serdes.String().getClass()); + streamsConfiguration.put(StreamsConfig.METRICS_RECORDING_LEVEL_CONFIG, Sensor.RecordingLevel.DEBUG.name); + streamsConfiguration.put(StreamsConfig.PROCESSING_GUARANTEE_CONFIG, processingGuarantee); + streamsConfiguration.put(StreamsConfig.STATE_DIR_CONFIG, TestUtils.tempDirectory().getPath()); + streamsConfiguration.put(StreamsConfig.CACHE_MAX_BYTES_BUFFERING_CONFIG, 0); + return streamsConfiguration; + } + + private StreamsBuilder builderForStateStores() { + final StreamsBuilder builder = new StreamsBuilder(); + // create two state stores, one non-segmented and one segmented + builder.table( + STREAM_INPUT_ONE, + Materialized.as(Stores.persistentKeyValueStore(MY_STORE_PERSISTENT_KEY_VALUE)).withCachingEnabled() + ).toStream().to(STREAM_OUTPUT_ONE); + builder.stream(STREAM_INPUT_TWO, 
Consumed.with(Serdes.Integer(), Serdes.String())) + .groupByKey() + .windowedBy(TimeWindows.of(WINDOW_SIZE).grace(Duration.ZERO)) + .aggregate(() -> 0L, + (aggKey, newValue, aggValue) -> aggValue, + Materialized.>as("time-windowed-aggregated-stream-store") + .withValueSerde(Serdes.Long()) + .withRetention(WINDOW_SIZE)) + .toStream() + .map((key, value) -> KeyValue.pair(value, value)) + .to(STREAM_OUTPUT_TWO, Produced.with(Serdes.Long(), Serdes.Long())); + return builder; + } + + private void cleanUpStateRunVerifyAndClose(final StreamsBuilder builder, + final Properties streamsConfiguration, + final MetricsVerifier metricsVerifier) throws Exception { + final KafkaStreams kafkaStreams = new KafkaStreams(builder.build(), streamsConfiguration); + kafkaStreams.cleanUp(); + produceRecords(); + + StreamsTestUtils.startKafkaStreamsAndWaitForRunningState(kafkaStreams, TIMEOUT); + + metricsVerifier.verify(kafkaStreams, "rocksdb-state-id"); + metricsVerifier.verify(kafkaStreams, "rocksdb-window-state-id"); + kafkaStreams.close(); + } + + private void produceRecords() { + final MockTime mockTime = new MockTime(WINDOW_SIZE.toMillis()); + final Properties prop = TestUtils.producerConfig( + CLUSTER.bootstrapServers(), + IntegerSerializer.class, + StringSerializer.class, + new Properties() + ); + // non-segmented store do not need records with different timestamps + IntegrationTestUtils.produceKeyValuesSynchronouslyWithTimestamp( + STREAM_INPUT_ONE, + Utils.mkSet(new KeyValue<>(1, "A"), new KeyValue<>(1, "B"), new KeyValue<>(1, "C")), + prop, + mockTime.milliseconds() + ); + IntegrationTestUtils.produceKeyValuesSynchronouslyWithTimestamp( + STREAM_INPUT_TWO, + Collections.singleton(new KeyValue<>(1, "A")), + prop, + mockTime.milliseconds() + ); + IntegrationTestUtils.produceKeyValuesSynchronouslyWithTimestamp( + STREAM_INPUT_TWO, + Collections.singleton(new KeyValue<>(1, "B")), + prop, + mockTime.milliseconds() + ); + IntegrationTestUtils.produceKeyValuesSynchronouslyWithTimestamp( + STREAM_INPUT_TWO, + Collections.singleton(new KeyValue<>(1, "C")), + prop, + mockTime.milliseconds() + ); + } + + private void verifyThatRocksDBMetricsAreExposed(final KafkaStreams kafkaStreams, + final String metricsScope) { + final List listMetricStore = getRocksDBMetrics(kafkaStreams, metricsScope); + checkMetricByName(listMetricStore, BYTES_WRITTEN_RATE, 1); + checkMetricByName(listMetricStore, BYTES_WRITTEN_TOTAL, 1); + checkMetricByName(listMetricStore, BYTES_READ_RATE, 1); + checkMetricByName(listMetricStore, BYTES_READ_TOTAL, 1); + checkMetricByName(listMetricStore, MEMTABLE_BYTES_FLUSHED_RATE, 1); + checkMetricByName(listMetricStore, MEMTABLE_BYTES_FLUSHED_TOTAL, 1); + checkMetricByName(listMetricStore, MEMTABLE_HIT_RATIO, 1); + checkMetricByName(listMetricStore, WRITE_STALL_DURATION_AVG, 1); + checkMetricByName(listMetricStore, WRITE_STALL_DURATION_TOTAL, 1); + checkMetricByName(listMetricStore, BLOCK_CACHE_DATA_HIT_RATIO, 1); + checkMetricByName(listMetricStore, BLOCK_CACHE_INDEX_HIT_RATIO, 1); + checkMetricByName(listMetricStore, BLOCK_CACHE_FILTER_HIT_RATIO, 1); + checkMetricByName(listMetricStore, BYTES_READ_DURING_COMPACTION_RATE, 1); + checkMetricByName(listMetricStore, BYTES_WRITTEN_DURING_COMPACTION_RATE, 1); + checkMetricByName(listMetricStore, NUMBER_OF_OPEN_FILES, 1); + checkMetricByName(listMetricStore, NUMBER_OF_FILE_ERRORS, 1); + checkMetricByName(listMetricStore, NUMBER_OF_ENTRIES_ACTIVE_MEMTABLE, 1); + checkMetricByName(listMetricStore, NUMBER_OF_DELETES_ACTIVE_MEMTABLE, 1); + 
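+        // immutable-memtable, memtable-size, flush/compaction, SST-file, and block-cache gauges follow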
+        checkMetricByName(listMetricStore, NUMBER_OF_ENTRIES_IMMUTABLE_MEMTABLES, 1);
+        checkMetricByName(listMetricStore, NUMBER_OF_DELETES_IMMUTABLE_MEMTABLES, 1);
+        checkMetricByName(listMetricStore, NUMBER_OF_IMMUTABLE_MEMTABLES, 1);
+        checkMetricByName(listMetricStore, CURRENT_SIZE_OF_ACTIVE_MEMTABLE, 1);
+        checkMetricByName(listMetricStore, CURRENT_SIZE_OF_ALL_MEMTABLES, 1);
+        checkMetricByName(listMetricStore, SIZE_OF_ALL_MEMTABLES, 1);
+        checkMetricByName(listMetricStore, MEMTABLE_FLUSH_PENDING, 1);
+        checkMetricByName(listMetricStore, NUMBER_OF_RUNNING_FLUSHES, 1);
+        checkMetricByName(listMetricStore, COMPACTION_PENDING, 1);
+        checkMetricByName(listMetricStore, NUMBER_OF_RUNNING_COMPACTIONS, 1);
+        checkMetricByName(listMetricStore, ESTIMATED_BYTES_OF_PENDING_COMPACTION, 1);
+        checkMetricByName(listMetricStore, TOTAL_SST_FILES_SIZE, 1);
+        checkMetricByName(listMetricStore, LIVE_SST_FILES_SIZE, 1);
+        checkMetricByName(listMetricStore, NUMBER_OF_LIVE_VERSIONS, 1);
+        checkMetricByName(listMetricStore, CAPACITY_OF_BLOCK_CACHE, 1);
+        checkMetricByName(listMetricStore, USAGE_OF_BLOCK_CACHE, 1);
+        checkMetricByName(listMetricStore, PINNED_USAGE_OF_BLOCK_CACHE, 1);
+        checkMetricByName(listMetricStore, ESTIMATED_NUMBER_OF_KEYS, 1);
+        checkMetricByName(listMetricStore, ESTIMATED_MEMORY_OF_TABLE_READERS, 1);
+        checkMetricByName(listMetricStore, NUMBER_OF_BACKGROUND_ERRORS, 1);
+    }
+
+    private void checkMetricByName(final List<Metric> listMetric,
+                                   final String metricName,
+                                   final int numMetric) {
+        final List<Metric> metrics = listMetric.stream()
+            .filter(m -> m.metricName().name().equals(metricName))
+            .collect(Collectors.toList());
+        assertThat(
+            "Size of metrics of type:'" + metricName + "' must be equal to " + numMetric + " but it's equal to " + metrics.size(),
+            metrics.size(),
+            is(numMetric)
+        );
+        for (final Metric metric : metrics) {
+            assertThat("Metric:'" + metric.metricName() + "' must be not null", metric.metricValue(), is(notNullValue()));
+        }
+    }
+
+    private List<Metric> getRocksDBMetrics(final KafkaStreams kafkaStreams,
+                                           final String metricsScope) {
+        return new ArrayList<Metric>(kafkaStreams.metrics().values()).stream()
+            .filter(m -> m.metricName().group().equals(METRICS_GROUP) && m.metricName().tags().containsKey(metricsScope))
+            .collect(Collectors.toList());
+    }
+}
\ No newline at end of file
diff --git a/streams/src/test/java/org/apache/kafka/streams/integration/SmokeTestDriverIntegrationTest.java b/streams/src/test/java/org/apache/kafka/streams/integration/SmokeTestDriverIntegrationTest.java
new file mode 100644
index 0000000..22d7735
--- /dev/null
+++ b/streams/src/test/java/org/apache/kafka/streams/integration/SmokeTestDriverIntegrationTest.java
@@ -0,0 +1,166 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +package org.apache.kafka.streams.integration; + +import org.apache.kafka.clients.consumer.ConsumerConfig; +import org.apache.kafka.common.utils.Exit; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.integration.utils.EmbeddedKafkaCluster; +import org.apache.kafka.streams.integration.utils.IntegrationTestUtils; +import org.apache.kafka.streams.tests.SmokeTestClient; +import org.apache.kafka.streams.tests.SmokeTestDriver; +import org.apache.kafka.test.IntegrationTest; +import org.junit.AfterClass; +import org.junit.Assert; +import org.junit.BeforeClass; +import org.junit.Test; +import org.junit.experimental.categories.Category; + +import java.io.IOException; +import java.time.Duration; +import java.util.ArrayList; +import java.util.Map; +import java.util.Properties; +import java.util.Set; + +import static org.apache.kafka.streams.tests.SmokeTestDriver.generate; +import static org.apache.kafka.streams.tests.SmokeTestDriver.verify; + +@Category(IntegrationTest.class) +public class SmokeTestDriverIntegrationTest { + public static final EmbeddedKafkaCluster CLUSTER = new EmbeddedKafkaCluster(3); + + @BeforeClass + public static void startCluster() throws IOException { + CLUSTER.start(); + } + + @AfterClass + public static void closeCluster() { + CLUSTER.stop(); + } + + + private static class Driver extends Thread { + private final String bootstrapServers; + private final int numKeys; + private final int maxRecordsPerKey; + private Exception exception = null; + private SmokeTestDriver.VerificationResult result; + + private Driver(final String bootstrapServers, final int numKeys, final int maxRecordsPerKey) { + this.bootstrapServers = bootstrapServers; + this.numKeys = numKeys; + this.maxRecordsPerKey = maxRecordsPerKey; + } + + @Override + public void run() { + try { + final Map> allData = + generate(bootstrapServers, numKeys, maxRecordsPerKey, Duration.ofSeconds(20)); + result = verify(bootstrapServers, allData, maxRecordsPerKey); + + } catch (final Exception ex) { + this.exception = ex; + } + } + + public Exception exception() { + return exception; + } + + SmokeTestDriver.VerificationResult result() { + return result; + } + + } + + // In this test, we try to keep creating new stream, and closing the old one, to maintain only 3 streams alive. + // During the new stream added and old stream left, the stream process should still complete without issue. + // We set 2 timeout condition to fail the test before passing the verification: + // (1) 6 min timeout, (2) 30 tries of polling without getting any data + @Test + public void shouldWorkWithRebalance() throws InterruptedException { + Exit.setExitProcedure((statusCode, message) -> { + throw new AssertionError("Test called exit(). code:" + statusCode + " message:" + message); + }); + Exit.setHaltProcedure((statusCode, message) -> { + throw new AssertionError("Test called halt(). 
code:" + statusCode + " message:" + message); + }); + int numClientsCreated = 0; + final ArrayList clients = new ArrayList<>(); + + IntegrationTestUtils.cleanStateBeforeTest(CLUSTER, SmokeTestDriver.topics()); + + final String bootstrapServers = CLUSTER.bootstrapServers(); + final Driver driver = new Driver(bootstrapServers, 10, 1000); + driver.start(); + System.out.println("started driver"); + + + final Properties props = new Properties(); + props.put(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, bootstrapServers); + // decrease the session timeout so that we can trigger the rebalance soon after old client left closed + props.put(ConsumerConfig.SESSION_TIMEOUT_MS_CONFIG, 10000); + + // cycle out Streams instances as long as the test is running. + while (driver.isAlive()) { + // take a nap + Thread.sleep(1000); + + // add a new client + final SmokeTestClient smokeTestClient = new SmokeTestClient("streams-" + numClientsCreated++); + clients.add(smokeTestClient); + smokeTestClient.start(props); + + // let the oldest client die of "natural causes" + if (clients.size() >= 3) { + final SmokeTestClient client = clients.remove(0); + + client.closeAsync(); + while (!client.closed()) { + Thread.sleep(100); + } + } + } + + try { + // wait for verification to finish + driver.join(); + } finally { + // whether or not the assertions failed, tell all the streams instances to stop + for (final SmokeTestClient client : clients) { + client.closeAsync(); + } + + // then, wait for them to stop + for (final SmokeTestClient client : clients) { + while (!client.closed()) { + Thread.sleep(100); + } + } + } + + // check to make sure that it actually succeeded + if (driver.exception() != null) { + driver.exception().printStackTrace(); + throw new AssertionError(driver.exception()); + } + Assert.assertTrue(driver.result().result(), driver.result().passed()); + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/integration/StandbyTaskCreationIntegrationTest.java b/streams/src/test/java/org/apache/kafka/streams/integration/StandbyTaskCreationIntegrationTest.java new file mode 100644 index 0000000..28eeeef --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/integration/StandbyTaskCreationIntegrationTest.java @@ -0,0 +1,198 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.integration; + +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.streams.KafkaStreams; +import org.apache.kafka.streams.KafkaStreams.State; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.ThreadMetadata; +import org.apache.kafka.streams.Topology; +import org.apache.kafka.streams.integration.utils.EmbeddedKafkaCluster; +import org.apache.kafka.streams.kstream.Consumed; +import org.apache.kafka.streams.kstream.Materialized; +import org.apache.kafka.streams.kstream.Transformer; +import org.apache.kafka.streams.processor.ProcessorContext; +import org.apache.kafka.streams.state.KeyValueStore; +import org.apache.kafka.streams.state.StoreBuilder; +import org.apache.kafka.streams.state.Stores; +import org.apache.kafka.test.IntegrationTest; +import org.apache.kafka.test.TestUtils; +import org.junit.After; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Rule; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.rules.TestName; + +import java.io.IOException; +import java.util.Properties; +import java.util.function.Predicate; + +import static org.apache.kafka.streams.integration.utils.IntegrationTestUtils.safeUniqueTestName; + +@Category({IntegrationTest.class}) +public class StandbyTaskCreationIntegrationTest { + + private static final int NUM_BROKERS = 1; + + public static final EmbeddedKafkaCluster CLUSTER = new EmbeddedKafkaCluster(NUM_BROKERS); + + @BeforeClass + public static void startCluster() throws IOException, InterruptedException { + CLUSTER.start(); + CLUSTER.createTopic(INPUT_TOPIC, 2, 1); + } + + @AfterClass + public static void closeCluster() { + CLUSTER.stop(); + } + + @Rule + public TestName testName = new TestName(); + + private static final String INPUT_TOPIC = "input-topic"; + + private KafkaStreams client1; + private KafkaStreams client2; + private volatile boolean client1IsOk = false; + private volatile boolean client2IsOk = false; + + @After + public void after() { + client1.close(); + client2.close(); + } + + private Properties streamsConfiguration() { + final String safeTestName = safeUniqueTestName(getClass(), testName); + final Properties streamsConfiguration = new Properties(); + streamsConfiguration.put(StreamsConfig.APPLICATION_ID_CONFIG, "app-" + safeTestName); + streamsConfiguration.put(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, CLUSTER.bootstrapServers()); + streamsConfiguration.put(StreamsConfig.STATE_DIR_CONFIG, TestUtils.tempDirectory().getPath()); + streamsConfiguration.put(StreamsConfig.DEFAULT_KEY_SERDE_CLASS_CONFIG, Serdes.Integer().getClass()); + streamsConfiguration.put(StreamsConfig.DEFAULT_VALUE_SERDE_CLASS_CONFIG, Serdes.Integer().getClass()); + streamsConfiguration.put(StreamsConfig.NUM_STANDBY_REPLICAS_CONFIG, 1); + return streamsConfiguration; + } + + @Test + public void shouldNotCreateAnyStandByTasksForStateStoreWithLoggingDisabled() throws Exception { + final StreamsBuilder builder = new StreamsBuilder(); + final String stateStoreName = "myTransformState"; + final StoreBuilder> keyValueStoreBuilder = + Stores.keyValueStoreBuilder(Stores.persistentKeyValueStore(stateStoreName), + Serdes.Integer(), + Serdes.Integer()).withLoggingDisabled(); + builder.addStateStore(keyValueStoreBuilder); + builder.stream(INPUT_TOPIC, Consumed.with(Serdes.Integer(), Serdes.Integer())) + .transform(() -> new 
Transformer>() { + @Override + public void init(final ProcessorContext context) {} + + @Override + public KeyValue transform(final Integer key, final Integer value) { + return null; + } + + @Override + public void close() {} + }, stateStoreName); + + final Topology topology = builder.build(); + createClients(topology, streamsConfiguration(), topology, streamsConfiguration()); + + setStateListenersForVerification(thread -> thread.standbyTasks().isEmpty() && !thread.activeTasks().isEmpty()); + + startClients(); + + waitUntilBothClientAreOK( + "At least one client did not reach state RUNNING with active tasks but no stand-by tasks" + ); + } + + @Test + public void shouldCreateStandByTasksForMaterializedAndOptimizedSourceTables() throws Exception { + final Properties streamsConfiguration1 = streamsConfiguration(); + streamsConfiguration1.put(StreamsConfig.TOPOLOGY_OPTIMIZATION_CONFIG, StreamsConfig.OPTIMIZE); + final Properties streamsConfiguration2 = streamsConfiguration(); + streamsConfiguration2.put(StreamsConfig.TOPOLOGY_OPTIMIZATION_CONFIG, StreamsConfig.OPTIMIZE); + + final StreamsBuilder builder = new StreamsBuilder(); + builder.table(INPUT_TOPIC, Consumed.with(Serdes.Integer(), Serdes.Integer()), Materialized.as("source-table")); + + createClients( + builder.build(streamsConfiguration1), + streamsConfiguration1, + builder.build(streamsConfiguration2), + streamsConfiguration2 + ); + + setStateListenersForVerification(thread -> !thread.standbyTasks().isEmpty() && !thread.activeTasks().isEmpty()); + + startClients(); + + waitUntilBothClientAreOK( + "At least one client did not reach state RUNNING with active tasks and stand-by tasks" + ); + } + + private void createClients(final Topology topology1, + final Properties streamsConfiguration1, + final Topology topology2, + final Properties streamsConfiguration2) { + + client1 = new KafkaStreams(topology1, streamsConfiguration1); + client2 = new KafkaStreams(topology2, streamsConfiguration2); + } + + private void setStateListenersForVerification(final Predicate taskCondition) { + client1.setStateListener((newState, oldState) -> { + if (newState == State.RUNNING && + client1.metadataForLocalThreads().stream().allMatch(taskCondition)) { + + client1IsOk = true; + } + }); + client2.setStateListener((newState, oldState) -> { + if (newState == State.RUNNING && + client2.metadataForLocalThreads().stream().allMatch(taskCondition)) { + + client2IsOk = true; + } + }); + } + + private void startClients() { + client1.start(); + client2.start(); + } + + private void waitUntilBothClientAreOK(final String message) throws Exception { + TestUtils.waitForCondition( + () -> client1IsOk && client2IsOk, + 30 * 1000, + message + ": " + + "Client 1 is " + (!client1IsOk ? "NOT " : "") + "OK, " + + "client 2 is " + (!client2IsOk ? "NOT " : "") + "OK." + ); + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/integration/StandbyTaskEOSIntegrationTest.java b/streams/src/test/java/org/apache/kafka/streams/integration/StandbyTaskEOSIntegrationTest.java new file mode 100644 index 0000000..4fbe734 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/integration/StandbyTaskEOSIntegrationTest.java @@ -0,0 +1,416 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.integration; + +import org.apache.kafka.clients.consumer.ConsumerConfig; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.serialization.IntegerDeserializer; +import org.apache.kafka.common.serialization.IntegerSerializer; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.streams.KafkaStreams; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.StoreQueryParameters; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.integration.utils.EmbeddedKafkaCluster; +import org.apache.kafka.streams.integration.utils.IntegrationTestUtils; +import org.apache.kafka.streams.kstream.Consumed; +import org.apache.kafka.streams.kstream.Transformer; +import org.apache.kafka.streams.processor.ProcessorContext; +import org.apache.kafka.streams.processor.TaskId; +import org.apache.kafka.streams.processor.internals.StateDirectory; +import org.apache.kafka.streams.state.KeyValueStore; +import org.apache.kafka.streams.state.QueryableStoreTypes; +import org.apache.kafka.streams.state.Stores; +import org.apache.kafka.streams.state.internals.OffsetCheckpoint; +import org.apache.kafka.test.IntegrationTest; +import org.apache.kafka.test.TestUtils; +import org.junit.After; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Rule; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.rules.TestName; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import java.io.File; +import java.io.IOException; +import java.time.Duration; +import java.util.Collection; +import java.util.Collections; +import java.util.Properties; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; + +import static java.util.Arrays.asList; +import static org.apache.kafka.streams.integration.utils.IntegrationTestUtils.safeUniqueTestName; +import static org.apache.kafka.streams.integration.utils.IntegrationTestUtils.startApplicationAndWaitUntilRunning; +import static org.apache.kafka.test.TestUtils.waitForCondition; +import static org.junit.Assert.assertTrue; + +/** + * An integration test to verify the conversion of a dirty-closed EOS + * task towards a standby task is safe across restarts of the application. 
+ */ +@RunWith(Parameterized.class) +@Category(IntegrationTest.class) +public class StandbyTaskEOSIntegrationTest { + + private final static long REBALANCE_TIMEOUT = Duration.ofMinutes(2L).toMillis(); + private final static int KEY_0 = 0; + private final static int KEY_1 = 1; + + @SuppressWarnings("deprecation") + @Parameterized.Parameters(name = "{0}") + public static Collection data() { + return asList(new String[][] { + {StreamsConfig.EXACTLY_ONCE}, + {StreamsConfig.EXACTLY_ONCE_V2} + }); + } + + @Parameterized.Parameter + public String eosConfig; + + private final AtomicBoolean skipRecord = new AtomicBoolean(false); + + private String appId; + private String inputTopic; + private String storeName; + private String outputTopic; + + private KafkaStreams streamInstanceOne; + private KafkaStreams streamInstanceTwo; + private KafkaStreams streamInstanceOneRecovery; + + private static final EmbeddedKafkaCluster CLUSTER = new EmbeddedKafkaCluster(3); + + @BeforeClass + public static void startCluster() throws IOException { + CLUSTER.start(); + } + + @AfterClass + public static void closeCluster() { + CLUSTER.stop(); + } + + @Rule + public TestName testName = new TestName(); + + @Before + public void createTopics() throws Exception { + final String safeTestName = safeUniqueTestName(getClass(), testName); + appId = "app-" + safeTestName; + inputTopic = "input-" + safeTestName; + outputTopic = "output-" + safeTestName; + storeName = "store-" + safeTestName; + CLUSTER.deleteTopicsAndWait(inputTopic, outputTopic, appId + "-KSTREAM-AGGREGATE-STATE-STORE-0000000001-changelog"); + CLUSTER.createTopic(inputTopic, 1, 3); + CLUSTER.createTopic(outputTopic, 1, 3); + } + + @After + public void cleanUp() { + if (streamInstanceOne != null) { + streamInstanceOne.close(); + } + if (streamInstanceTwo != null) { + streamInstanceTwo.close(); + } + if (streamInstanceOneRecovery != null) { + streamInstanceOneRecovery.close(); + } + } + + @Test + public void shouldSurviveWithOneTaskAsStandby() throws Exception { + IntegrationTestUtils.produceKeyValuesSynchronouslyWithTimestamp( + inputTopic, + Collections.singletonList( + new KeyValue<>(0, 0) + ), + TestUtils.producerConfig( + CLUSTER.bootstrapServers(), + IntegerSerializer.class, + IntegerSerializer.class, + new Properties() + ), + 10L + ); + + final String stateDirPath = TestUtils.tempDirectory(appId).getPath(); + + final CountDownLatch instanceLatch = new CountDownLatch(1); + + streamInstanceOne = buildStreamWithDirtyStateDir(stateDirPath + "/" + appId + "-1/", instanceLatch); + streamInstanceTwo = buildStreamWithDirtyStateDir(stateDirPath + "/" + appId + "-2/", instanceLatch); + + startApplicationAndWaitUntilRunning(asList(streamInstanceOne, streamInstanceTwo), Duration.ofSeconds(60)); + + // Wait for the record to be processed + assertTrue(instanceLatch.await(15, TimeUnit.SECONDS)); + + streamInstanceOne.close(); + streamInstanceTwo.close(); + + streamInstanceOne.cleanUp(); + streamInstanceTwo.cleanUp(); + } + + private KafkaStreams buildStreamWithDirtyStateDir(final String stateDirPath, + final CountDownLatch recordProcessLatch) throws Exception { + + final StreamsBuilder builder = new StreamsBuilder(); + final TaskId taskId = new TaskId(0, 0); + + final Properties props = props(stateDirPath); + + final StateDirectory stateDirectory = new StateDirectory( + new StreamsConfig(props), new MockTime(), true, false); + + new OffsetCheckpoint(new File(stateDirectory.getOrCreateDirectoryForTask(taskId), ".checkpoint")) + .write(Collections.singletonMap(new 
TopicPartition("unknown-topic", 0), 5L)); + + assertTrue(new File(stateDirectory.getOrCreateDirectoryForTask(taskId), + "rocksdb/KSTREAM-AGGREGATE-STATE-STORE-0000000001").mkdirs()); + + builder.stream(inputTopic, + Consumed.with(Serdes.Integer(), Serdes.Integer())) + .groupByKey() + .count() + .toStream() + .peek((key, value) -> recordProcessLatch.countDown()); + + return new KafkaStreams(builder.build(), props); + } + + @Test + public void shouldWipeOutStandbyStateDirectoryIfCheckpointIsMissing() throws Exception { + final long time = System.currentTimeMillis(); + final String base = TestUtils.tempDirectory(appId).getPath(); + + IntegrationTestUtils.produceKeyValuesSynchronouslyWithTimestamp( + inputTopic, + Collections.singletonList( + new KeyValue<>(KEY_0, 0) + ), + TestUtils.producerConfig( + CLUSTER.bootstrapServers(), + IntegerSerializer.class, + IntegerSerializer.class, + new Properties() + ), + 10L + time + ); + + streamInstanceOne = buildWithDeduplicationTopology(base + "-1"); + streamInstanceTwo = buildWithDeduplicationTopology(base + "-2"); + + // start first instance and wait for processing + startApplicationAndWaitUntilRunning(Collections.singletonList(streamInstanceOne), Duration.ofSeconds(30)); + IntegrationTestUtils.waitUntilMinRecordsReceived( + TestUtils.consumerConfig( + CLUSTER.bootstrapServers(), + IntegerDeserializer.class, + IntegerDeserializer.class + ), + outputTopic, + 1 + ); + + // start second instance and wait for standby replication + startApplicationAndWaitUntilRunning(Collections.singletonList(streamInstanceTwo), Duration.ofSeconds(30)); + waitForCondition( + () -> streamInstanceTwo.store( + StoreQueryParameters.fromNameAndType( + storeName, + QueryableStoreTypes.keyValueStore() + ).enableStaleStores() + ).get(KEY_0) != null, + REBALANCE_TIMEOUT, + "Could not get key from standby store" + ); + // sanity check that first instance is still active + waitForCondition( + () -> streamInstanceOne.store( + StoreQueryParameters.fromNameAndType( + storeName, + QueryableStoreTypes.keyValueStore() + ) + ).get(KEY_0) != null, + "Could not get key from main store" + ); + + // inject poison pill and wait for crash of first instance and recovery on second instance + IntegrationTestUtils.produceKeyValuesSynchronouslyWithTimestamp( + inputTopic, + Collections.singletonList( + new KeyValue<>(KEY_1, 0) + ), + TestUtils.producerConfig( + CLUSTER.bootstrapServers(), + IntegerSerializer.class, + IntegerSerializer.class, + new Properties() + ), + 10L + time + ); + waitForCondition( + () -> streamInstanceOne.state() == KafkaStreams.State.ERROR, + "Stream instance 1 did not go into error state" + ); + streamInstanceOne.close(); + + IntegrationTestUtils.waitUntilMinRecordsReceived( + TestUtils.consumerConfig( + CLUSTER.bootstrapServers(), + IntegerDeserializer.class, + IntegerDeserializer.class + ), + outputTopic, + 2 + ); + + streamInstanceOneRecovery = buildWithDeduplicationTopology(base + "-1"); + + // "restart" first client and wait for standby recovery + // (could actually also be active, but it does not matter as long as we enable "state stores" + startApplicationAndWaitUntilRunning( + Collections.singletonList(streamInstanceOneRecovery), + Duration.ofSeconds(30) + ); + waitForCondition( + () -> streamInstanceOneRecovery.store( + StoreQueryParameters.fromNameAndType( + storeName, + QueryableStoreTypes.keyValueStore() + ).enableStaleStores() + ).get(KEY_0) != null, + "Could not get key from recovered standby store" + ); + + streamInstanceTwo.close(); + waitForCondition( + () -> 
streamInstanceOneRecovery.store( + StoreQueryParameters.fromNameAndType( + storeName, + QueryableStoreTypes.keyValueStore() + ) + ).get(KEY_0) != null, + REBALANCE_TIMEOUT, + "Could not get key from recovered main store" + ); + + // re-inject poison pill and wait for crash of first instance + skipRecord.set(false); + IntegrationTestUtils.produceKeyValuesSynchronouslyWithTimestamp( + inputTopic, + Collections.singletonList( + new KeyValue<>(KEY_1, 0) + ), + TestUtils.producerConfig( + CLUSTER.bootstrapServers(), + IntegerSerializer.class, + IntegerSerializer.class, + new Properties() + ), + 10L + time + ); + waitForCondition( + () -> streamInstanceOneRecovery.state() == KafkaStreams.State.ERROR, + "Stream instance 1 did not go into error state. Is in " + streamInstanceOneRecovery.state() + " state." + ); + } + + private KafkaStreams buildWithDeduplicationTopology(final String stateDirPath) { + final StreamsBuilder builder = new StreamsBuilder(); + + builder.addStateStore(Stores.keyValueStoreBuilder( + Stores.persistentKeyValueStore(storeName), + Serdes.Integer(), + Serdes.Integer()) + ); + builder.stream(inputTopic) + .transform( + () -> new Transformer>() { + private KeyValueStore store; + + @SuppressWarnings("unchecked") + @Override + public void init(final ProcessorContext context) { + store = (KeyValueStore) context.getStateStore(storeName); + } + + @Override + public KeyValue transform(final Integer key, final Integer value) { + if (skipRecord.get()) { + // we only forward so we can verify the skipping by reading the output topic + // the goal is skipping is to not modify the state store + return KeyValue.pair(key, value); + } + + if (store.get(key) != null) { + return null; + } + + store.put(key, value); + store.flush(); + + if (key == KEY_1) { + // after error injection, we need to avoid a consecutive error after rebalancing + skipRecord.set(true); + throw new RuntimeException("Injected test error"); + } + + return KeyValue.pair(key, value); + } + + @Override + public void close() { } + }, + storeName + ) + .to(outputTopic); + + return new KafkaStreams(builder.build(), props(stateDirPath)); + } + + + private Properties props(final String stateDirPath) { + final Properties streamsConfiguration = new Properties(); + streamsConfiguration.put(StreamsConfig.APPLICATION_ID_CONFIG, appId); + streamsConfiguration.put(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, CLUSTER.bootstrapServers()); + streamsConfiguration.put(StreamsConfig.CACHE_MAX_BYTES_BUFFERING_CONFIG, 0); + streamsConfiguration.put(StreamsConfig.STATE_DIR_CONFIG, stateDirPath); + streamsConfiguration.put(StreamsConfig.NUM_STANDBY_REPLICAS_CONFIG, 1); + streamsConfiguration.put(StreamsConfig.PROCESSING_GUARANTEE_CONFIG, eosConfig); + streamsConfiguration.put(StreamsConfig.DEFAULT_KEY_SERDE_CLASS_CONFIG, Serdes.Integer().getClass()); + streamsConfiguration.put(StreamsConfig.DEFAULT_VALUE_SERDE_CLASS_CONFIG, Serdes.Integer().getClass()); + streamsConfiguration.put(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG, 1000L); + // need to set to zero to get predictable active/standby task assignments + streamsConfiguration.put(StreamsConfig.ACCEPTABLE_RECOVERY_LAG_CONFIG, 0); + streamsConfiguration.put(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "earliest"); + + return streamsConfiguration; + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/integration/StateDirectoryIntegrationTest.java b/streams/src/test/java/org/apache/kafka/streams/integration/StateDirectoryIntegrationTest.java new file mode 100644 index 0000000..9c34fba --- /dev/null 
+++ b/streams/src/test/java/org/apache/kafka/streams/integration/StateDirectoryIntegrationTest.java @@ -0,0 +1,279 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.integration; + +import java.io.File; +import java.io.IOException; +import java.util.Arrays; +import java.util.Properties; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; + +import org.apache.kafka.clients.producer.KafkaProducer; +import org.apache.kafka.clients.producer.ProducerConfig; +import org.apache.kafka.clients.producer.ProducerRecord; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.serialization.StringSerializer; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.streams.KafkaStreams; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.Topology; +import org.apache.kafka.streams.integration.utils.EmbeddedKafkaCluster; +import org.apache.kafka.streams.integration.utils.IntegrationTestUtils; +import org.apache.kafka.streams.kstream.Materialized; +import org.apache.kafka.streams.state.KeyValueStore; +import org.apache.kafka.test.IntegrationTest; +import org.apache.kafka.test.TestUtils; + + +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Rule; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.rules.TestName; + +import static org.apache.kafka.common.utils.Utils.mkEntry; +import static org.apache.kafka.common.utils.Utils.mkMap; +import static org.apache.kafka.common.utils.Utils.mkProperties; +import static org.apache.kafka.streams.integration.utils.IntegrationTestUtils.safeUniqueTestName; +import static org.junit.Assert.assertTrue; + +@Category(IntegrationTest.class) +public class StateDirectoryIntegrationTest { + + public static final EmbeddedKafkaCluster CLUSTER = new EmbeddedKafkaCluster(3); + + @BeforeClass + public static void startCluster() throws IOException { + CLUSTER.start(); + } + + @AfterClass + public static void closeCluster() { + CLUSTER.stop(); + } + + @Rule + public TestName testName = new TestName(); + + @Test + public void testCleanUpStateDirIfEmpty() throws InterruptedException { + final String uniqueTestName = safeUniqueTestName(getClass(), testName); + + // Create Topic + final String input = uniqueTestName + "-input"; + CLUSTER.createTopic(input); + + final Properties producerConfig = mkProperties(mkMap( + mkEntry(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, CLUSTER.bootstrapServers()), + mkEntry(ProducerConfig.ACKS_CONFIG, "all"), + mkEntry(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG, StringSerializer.class.getCanonicalName()), + mkEntry(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG, 
StringSerializer.class.getCanonicalName()) + )); + + try (final KafkaProducer producer = + new KafkaProducer<>(producerConfig, Serdes.String().serializer(), Serdes.String().serializer())) { + // Create Test Records + producer.send(new ProducerRecord<>(input, "a")); + producer.send(new ProducerRecord<>(input, "b")); + producer.send(new ProducerRecord<>(input, "c")); + + // Create Topology + final String storeName = uniqueTestName + "-input-table"; + + final StreamsBuilder builder = new StreamsBuilder(); + builder.table( + input, + Materialized + .>as(storeName) + .withKeySerde(Serdes.String()) + .withValueSerde(Serdes.String()) + ); + final Topology topology = builder.build(); + + // State Store Directory + final String stateDir = TestUtils.tempDirectory(uniqueTestName).getPath(); + + // Create KafkaStreams instance + final String applicationId = uniqueTestName + "-app"; + final Properties streamsConfig = mkProperties(mkMap( + mkEntry(StreamsConfig.APPLICATION_ID_CONFIG, applicationId), + mkEntry(StreamsConfig.STATE_DIR_CONFIG, stateDir), + mkEntry(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, CLUSTER.bootstrapServers()) + )); + + final KafkaStreams streams = new KafkaStreams(topology, streamsConfig); + + // Create StateListener + final CountDownLatch runningLatch = new CountDownLatch(1); + final CountDownLatch notRunningLatch = new CountDownLatch(1); + + final KafkaStreams.StateListener stateListener = (newState, oldState) -> { + if (newState == KafkaStreams.State.RUNNING) { + runningLatch.countDown(); + } + if (newState == KafkaStreams.State.NOT_RUNNING) { + notRunningLatch.countDown(); + } + }; + streams.setStateListener(stateListener); + + // Application state directory + final File appDir = new File(stateDir, applicationId); + + // Validate application state directory is created. + streams.start(); + try { + runningLatch.await(IntegrationTestUtils.DEFAULT_TIMEOUT, TimeUnit.MILLISECONDS); + } catch (final InterruptedException e) { + throw new RuntimeException("Streams didn't start in time.", e); + } + + assertTrue((new File(stateDir)).exists()); // State directory exists + assertTrue(appDir.exists()); // Application state directory Exists + + // Validate StateStore directory is deleted. + streams.close(); + try { + notRunningLatch.await(IntegrationTestUtils.DEFAULT_TIMEOUT, TimeUnit.MILLISECONDS); + } catch (final InterruptedException e) { + throw new RuntimeException("Streams didn't cleaned up in time.", e); + } + + streams.cleanUp(); + + assertTrue((new File(stateDir)).exists()); // Root state store exists + + // case 1: the state directory is cleaned up without any problems. + // case 2: The state directory is not cleaned up, for it does not include any checkpoint file. + // case 3: The state directory is not cleaned up, for it includes a checkpoint file but it is empty. 
+ assertTrue(appDir.exists() + || Arrays.stream(appDir.listFiles()) + .filter( + (File f) -> f.isDirectory() && f.listFiles().length > 0 && !(new File(f, ".checkpoint")).exists() + ).findFirst().isPresent() + || Arrays.stream(appDir.listFiles()) + .filter( + (File f) -> f.isDirectory() && (new File(f, ".checkpoint")).length() == 0L + ).findFirst().isPresent() + ); + } finally { + CLUSTER.deleteAllTopicsAndWait(0L); + } + } + + @Test + public void testNotCleanUpStateDirIfNotEmpty() throws InterruptedException { + final String uniqueTestName = safeUniqueTestName(getClass(), testName); + + // Create Topic + final String input = uniqueTestName + "-input"; + CLUSTER.createTopic(input); + + final Properties producerConfig = mkProperties(mkMap( + mkEntry(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, CLUSTER.bootstrapServers()), + mkEntry(ProducerConfig.ACKS_CONFIG, "all"), + mkEntry(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG, StringSerializer.class.getCanonicalName()), + mkEntry(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG, StringSerializer.class.getCanonicalName()) + )); + + try (final KafkaProducer producer = + new KafkaProducer<>(producerConfig, Serdes.String().serializer(), Serdes.String().serializer())) { + // Create Test Records + producer.send(new ProducerRecord<>(input, "a")); + producer.send(new ProducerRecord<>(input, "b")); + producer.send(new ProducerRecord<>(input, "c")); + + // Create Topology + final String storeName = uniqueTestName + "-input-table"; + + final StreamsBuilder builder = new StreamsBuilder(); + builder.table( + input, + Materialized + .>as(storeName) + .withKeySerde(Serdes.String()) + .withValueSerde(Serdes.String()) + ); + final Topology topology = builder.build(); + + // State Store Directory + final String stateDir = TestUtils.tempDirectory(uniqueTestName).getPath(); + + // Create KafkaStreams instance + final String applicationId = uniqueTestName + "-app"; + final Properties streamsConfig = mkProperties(mkMap( + mkEntry(StreamsConfig.APPLICATION_ID_CONFIG, applicationId), + mkEntry(StreamsConfig.STATE_DIR_CONFIG, stateDir), + mkEntry(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, CLUSTER.bootstrapServers()) + )); + + final KafkaStreams streams = new KafkaStreams(topology, streamsConfig); + + // Create StateListener + final CountDownLatch runningLatch = new CountDownLatch(1); + final CountDownLatch notRunningLatch = new CountDownLatch(1); + + final KafkaStreams.StateListener stateListener = (newState, oldState) -> { + if (newState == KafkaStreams.State.RUNNING) { + runningLatch.countDown(); + } + if (newState == KafkaStreams.State.NOT_RUNNING) { + notRunningLatch.countDown(); + } + }; + streams.setStateListener(stateListener); + + // Application state directory + final File appDir = new File(stateDir, applicationId); + + // Validate application state directory is created. + streams.start(); + try { + runningLatch.await(IntegrationTestUtils.DEFAULT_TIMEOUT, TimeUnit.MILLISECONDS); + } catch (final InterruptedException e) { + throw new RuntimeException("Streams didn't start in time.", e); + } + + assertTrue((new File(stateDir)).exists()); // State directory exists + assertTrue(appDir.exists()); // Application state directory Exists + + try { + assertTrue((new File(appDir, "dummy")).createNewFile()); + } catch (final IOException e) { + throw new RuntimeException("Failed to create dummy file.", e); + } + + // Validate StateStore directory is deleted. 
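+            // (unlike the previous test, the "dummy" file keeps the application state directory non-empty, so it must survive the cleanup below)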
+ streams.close(); + try { + notRunningLatch.await(IntegrationTestUtils.DEFAULT_TIMEOUT, TimeUnit.MILLISECONDS); + } catch (final InterruptedException e) { + throw new RuntimeException("Streams didn't cleaned up in time.", e); + } + + streams.cleanUp(); + + assertTrue((new File(stateDir)).exists()); // Root state store exists + assertTrue(appDir.exists()); // Application state store exists + } finally { + CLUSTER.deleteAllTopicsAndWait(0L); + } + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/integration/StateRestorationIntegrationTest.java b/streams/src/test/java/org/apache/kafka/streams/integration/StateRestorationIntegrationTest.java new file mode 100644 index 0000000..d890a30 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/integration/StateRestorationIntegrationTest.java @@ -0,0 +1,133 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.integration; + +import kafka.utils.MockTime; +import org.apache.kafka.common.serialization.BytesDeserializer; +import org.apache.kafka.common.serialization.BytesSerializer; +import org.apache.kafka.common.serialization.IntegerDeserializer; +import org.apache.kafka.common.serialization.IntegerSerializer; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.streams.KafkaStreams; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.integration.utils.EmbeddedKafkaCluster; +import org.apache.kafka.streams.integration.utils.IntegrationTestUtils; +import org.apache.kafka.streams.kstream.Materialized; +import org.apache.kafka.streams.state.Stores; +import org.apache.kafka.test.IntegrationTest; +import org.apache.kafka.test.StreamsTestUtils; +import org.apache.kafka.test.TestUtils; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; +import org.junit.experimental.categories.Category; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Properties; + +@Category({IntegrationTest.class}) +public class StateRestorationIntegrationTest { + private final StreamsBuilder builder = new StreamsBuilder(); + + private static final String APPLICATION_ID = "restoration-test-app"; + private static final String STATE_STORE_NAME = "stateStore"; + private static final String INPUT_TOPIC = "input"; + private static final String OUTPUT_TOPIC = "output"; + + private Properties streamsConfiguration; + + public static final EmbeddedKafkaCluster CLUSTER = new EmbeddedKafkaCluster(1); + + @BeforeClass + public static void startCluster() throws IOException { + CLUSTER.start(); + } + + @AfterClass 
+ public static void closeCluster() { + CLUSTER.stop(); + } + + private final MockTime mockTime = CLUSTER.time; + + @Before + public void setUp() throws Exception { + final Properties props = new Properties(); + + streamsConfiguration = StreamsTestUtils.getStreamsConfig( + APPLICATION_ID, + CLUSTER.bootstrapServers(), + Serdes.Integer().getClass().getName(), + Serdes.ByteArray().getClass().getName(), + props); + + CLUSTER.createTopics(INPUT_TOPIC); + CLUSTER.createTopics(OUTPUT_TOPIC); + + IntegrationTestUtils.purgeLocalStreamsState(streamsConfiguration); + } + + @Test + public void shouldRestoreNullRecord() throws Exception { + builder.table(INPUT_TOPIC, Materialized.as( + Stores.persistentTimestampedKeyValueStore(STATE_STORE_NAME)) + .withKeySerde(Serdes.Integer()) + .withValueSerde(Serdes.Bytes()) + .withCachingDisabled()).toStream().to(OUTPUT_TOPIC); + + final Properties producerConfig = TestUtils.producerConfig( + CLUSTER.bootstrapServers(), IntegerSerializer.class, BytesSerializer.class); + + final List> initialKeyValues = Arrays.asList( + KeyValue.pair(3, new Bytes(new byte[]{3})), + KeyValue.pair(3, null), + KeyValue.pair(1, new Bytes(new byte[]{1}))); + + IntegrationTestUtils.produceKeyValuesSynchronously( + INPUT_TOPIC, initialKeyValues, producerConfig, mockTime); + + KafkaStreams streams = new KafkaStreams(builder.build(streamsConfiguration), streamsConfiguration); + streams.start(); + + final Properties consumerConfig = TestUtils.consumerConfig( + CLUSTER.bootstrapServers(), IntegerDeserializer.class, BytesDeserializer.class); + + IntegrationTestUtils.waitUntilFinalKeyValueRecordsReceived( + consumerConfig, OUTPUT_TOPIC, initialKeyValues); + + // wipe out state store to trigger restore process on restart + streams.close(); + streams.cleanUp(); + + // Restart the stream instance. There should not be exception handling the null value within changelog topic. + final List> newKeyValues = + Collections.singletonList(KeyValue.pair(2, new Bytes(new byte[3]))); + IntegrationTestUtils.produceKeyValuesSynchronously( + INPUT_TOPIC, newKeyValues, producerConfig, mockTime); + streams = new KafkaStreams(builder.build(streamsConfiguration), streamsConfiguration); + streams.start(); + IntegrationTestUtils.waitUntilFinalKeyValueRecordsReceived( + consumerConfig, OUTPUT_TOPIC, newKeyValues); + streams.close(); + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/integration/StoreQueryIntegrationTest.java b/streams/src/test/java/org/apache/kafka/streams/integration/StoreQueryIntegrationTest.java new file mode 100644 index 0000000..fb2d0a1 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/integration/StoreQueryIntegrationTest.java @@ -0,0 +1,608 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.integration; + +import org.apache.kafka.clients.consumer.ConsumerConfig; +import org.apache.kafka.clients.producer.ProducerConfig; +import org.apache.kafka.common.serialization.IntegerSerializer; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.streams.KafkaStreams; +import org.apache.kafka.streams.KeyQueryMetadata; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.StoreQueryParameters; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.errors.InvalidStateStoreException; +import org.apache.kafka.streams.integration.utils.EmbeddedKafkaCluster; +import org.apache.kafka.streams.integration.utils.IntegrationTestUtils; +import org.apache.kafka.streams.kstream.Consumed; +import org.apache.kafka.streams.kstream.Materialized; +import org.apache.kafka.streams.processor.internals.namedtopology.KafkaStreamsNamedTopologyWrapper; +import org.apache.kafka.streams.processor.internals.namedtopology.NamedTopologyStreamsBuilder; +import org.apache.kafka.streams.state.KeyValueStore; +import org.apache.kafka.streams.state.QueryableStoreType; +import org.apache.kafka.streams.state.ReadOnlyKeyValueStore; +import org.apache.kafka.test.IntegrationTest; +import org.apache.kafka.test.TestCondition; +import org.apache.kafka.test.TestUtils; +import org.junit.After; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.rules.TestName; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.time.Duration; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Optional; +import java.util.Properties; +import java.util.concurrent.Semaphore; +import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +import static java.util.Collections.singletonList; +import static org.apache.kafka.streams.integration.utils.IntegrationTestUtils.getStore; +import static org.apache.kafka.streams.integration.utils.IntegrationTestUtils.safeUniqueTestName; +import static org.apache.kafka.streams.integration.utils.IntegrationTestUtils.startApplicationAndWaitUntilRunning; +import static org.apache.kafka.streams.state.QueryableStoreTypes.keyValueStore; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.notNullValue; +import static org.hamcrest.Matchers.nullValue; +import static org.hamcrest.Matchers.anyOf; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; + +@Category({IntegrationTest.class}) +public class StoreQueryIntegrationTest { + + private static final Logger LOG = LoggerFactory.getLogger(StoreQueryIntegrationTest.class); + + private static final int NUM_BROKERS = 1; + private static int port = 0; + private static final String INPUT_TOPIC_NAME = "input-topic"; + private static final String TABLE_NAME = "source-table"; + + public final EmbeddedKafkaCluster cluster = new EmbeddedKafkaCluster(NUM_BROKERS); + + @Rule + public TestName testName = new TestName(); + + private final List streamsToCleanup = new ArrayList<>(); + private final 
MockTime mockTime = cluster.time; + + @Before + public void before() throws InterruptedException, IOException { + cluster.start(); + cluster.createTopic(INPUT_TOPIC_NAME, 2, 1); + } + + @After + public void after() { + for (final KafkaStreams kafkaStreams : streamsToCleanup) { + kafkaStreams.close(); + } + cluster.stop(); + } + + @Test + public void shouldQueryOnlyActivePartitionStoresByDefault() throws Exception { + final int batch1NumMessages = 100; + final int key = 1; + final Semaphore semaphore = new Semaphore(0); + + final StreamsBuilder builder = new StreamsBuilder(); + getStreamsBuilderWithTopology(builder, semaphore); + + final KafkaStreams kafkaStreams1 = createKafkaStreams(builder, streamsConfiguration()); + final KafkaStreams kafkaStreams2 = createKafkaStreams(builder, streamsConfiguration()); + final List kafkaStreamsList = Arrays.asList(kafkaStreams1, kafkaStreams2); + + startApplicationAndWaitUntilRunning(kafkaStreamsList, Duration.ofSeconds(60)); + + produceValueRange(key, 0, batch1NumMessages); + + // Assert that all messages in the first batch were processed in a timely manner + assertThat(semaphore.tryAcquire(batch1NumMessages, 60, TimeUnit.SECONDS), is(equalTo(true))); + until(() -> { + + final KeyQueryMetadata keyQueryMetadata = kafkaStreams1.queryMetadataForKey(TABLE_NAME, key, (topic, somekey, value, numPartitions) -> 0); + + final QueryableStoreType> queryableStoreType = keyValueStore(); + final ReadOnlyKeyValueStore store1 = getStore(TABLE_NAME, kafkaStreams1, queryableStoreType); + final ReadOnlyKeyValueStore store2 = getStore(TABLE_NAME, kafkaStreams2, queryableStoreType); + + final boolean kafkaStreams1IsActive = (keyQueryMetadata.activeHost().port() % 2) == 1; + + try { + if (kafkaStreams1IsActive) { + assertThat(store1.get(key), is(notNullValue())); + assertThat(store2.get(key), is(nullValue())); + } else { + assertThat(store1.get(key), is(nullValue())); + assertThat(store2.get(key), is(notNullValue())); + } + return true; + } catch (final InvalidStateStoreException exception) { + verifyRetrievableException(exception); + LOG.info("Either streams wasn't running or a re-balancing took place. Will try again."); + return false; + } + }); + } + + @Test + public void shouldQuerySpecificActivePartitionStores() throws Exception { + final int batch1NumMessages = 100; + final int key = 1; + final Semaphore semaphore = new Semaphore(0); + + final StreamsBuilder builder = new StreamsBuilder(); + getStreamsBuilderWithTopology(builder, semaphore); + + final KafkaStreams kafkaStreams1 = createKafkaStreams(builder, streamsConfiguration()); + final KafkaStreams kafkaStreams2 = createKafkaStreams(builder, streamsConfiguration()); + final List kafkaStreamsList = Arrays.asList(kafkaStreams1, kafkaStreams2); + + startApplicationAndWaitUntilRunning(kafkaStreamsList, Duration.ofSeconds(60)); + + produceValueRange(key, 0, batch1NumMessages); + + // Assert that all messages in the first batch were processed in a timely manner + assertThat(semaphore.tryAcquire(batch1NumMessages, 60, TimeUnit.SECONDS), is(equalTo(true))); + until(() -> { + final KeyQueryMetadata keyQueryMetadata = kafkaStreams1.queryMetadataForKey(TABLE_NAME, key, (topic, somekey, value, numPartitions) -> 0); + + //key belongs to this partition + final int keyPartition = keyQueryMetadata.partition(); + + //key doesn't belongs to this partition + final int keyDontBelongPartition = (keyPartition == 0) ? 
1 : 0; + final boolean kafkaStreams1IsActive = (keyQueryMetadata.activeHost().port() % 2) == 1; + + final StoreQueryParameters> storeQueryParam = + StoreQueryParameters.>fromNameAndType(TABLE_NAME, keyValueStore()) + .withPartition(keyPartition); + ReadOnlyKeyValueStore store1 = null; + ReadOnlyKeyValueStore store2 = null; + if (kafkaStreams1IsActive) { + store1 = getStore(kafkaStreams1, storeQueryParam); + } else { + store2 = getStore(kafkaStreams2, storeQueryParam); + } + + if (kafkaStreams1IsActive) { + assertThat(store1, is(notNullValue())); + assertThat(store2, is(nullValue())); + } else { + assertThat(store2, is(notNullValue())); + assertThat(store1, is(nullValue())); + } + + final StoreQueryParameters> storeQueryParam2 = + StoreQueryParameters.>fromNameAndType(TABLE_NAME, keyValueStore()) + .withPartition(keyDontBelongPartition); + + try { + // Assert that key is not served when wrong specific partition is requested + // If kafkaStreams1 is active for keyPartition, kafkaStreams2 would be active for keyDontBelongPartition + // So, in that case, store3 would be null and the store4 would not return the value for key as wrong partition was requested + if (kafkaStreams1IsActive) { + assertThat(store1.get(key), is(notNullValue())); + assertThat(getStore(kafkaStreams2, storeQueryParam2).get(key), is(nullValue())); + final InvalidStateStoreException exception = + assertThrows(InvalidStateStoreException.class, () -> getStore(kafkaStreams1, storeQueryParam2).get(key)); + assertThat( + exception.getMessage(), + containsString("The specified partition 1 for store source-table does not exist.") + ); + } else { + assertThat(store2.get(key), is(notNullValue())); + assertThat(getStore(kafkaStreams1, storeQueryParam2).get(key), is(nullValue())); + final InvalidStateStoreException exception = + assertThrows(InvalidStateStoreException.class, () -> getStore(kafkaStreams2, storeQueryParam2).get(key)); + assertThat( + exception.getMessage(), + containsString("The specified partition 1 for store source-table does not exist.") + ); + } + return true; + } catch (final InvalidStateStoreException exception) { + verifyRetrievableException(exception); + LOG.info("Either streams wasn't running or a re-balancing took place. 
Will try again."); + return false; + } + }); + } + + @Test + public void shouldQueryAllStalePartitionStores() throws Exception { + final int batch1NumMessages = 100; + final int key = 1; + final Semaphore semaphore = new Semaphore(0); + + final StreamsBuilder builder = new StreamsBuilder(); + getStreamsBuilderWithTopology(builder, semaphore); + + final KafkaStreams kafkaStreams1 = createKafkaStreams(builder, streamsConfiguration()); + final KafkaStreams kafkaStreams2 = createKafkaStreams(builder, streamsConfiguration()); + final List kafkaStreamsList = Arrays.asList(kafkaStreams1, kafkaStreams2); + + startApplicationAndWaitUntilRunning(kafkaStreamsList, Duration.ofSeconds(60)); + + produceValueRange(key, 0, batch1NumMessages); + + // Assert that all messages in the first batch were processed in a timely manner + assertThat(semaphore.tryAcquire(batch1NumMessages, 60, TimeUnit.SECONDS), is(equalTo(true))); + + final QueryableStoreType> queryableStoreType = keyValueStore(); + + // Assert that both active and standby are able to query for a key + TestUtils.waitForCondition(() -> { + final ReadOnlyKeyValueStore store1 = getStore(TABLE_NAME, kafkaStreams1, true, queryableStoreType); + return store1.get(key) != null; + }, "store1 cannot find results for key"); + TestUtils.waitForCondition(() -> { + final ReadOnlyKeyValueStore store2 = getStore(TABLE_NAME, kafkaStreams2, true, queryableStoreType); + return store2.get(key) != null; + }, "store2 cannot find results for key"); + } + + @Test + public void shouldQuerySpecificStalePartitionStores() throws Exception { + final int batch1NumMessages = 100; + final int key = 1; + final Semaphore semaphore = new Semaphore(0); + + final StreamsBuilder builder = new StreamsBuilder(); + getStreamsBuilderWithTopology(builder, semaphore); + + final KafkaStreams kafkaStreams1 = createKafkaStreams(builder, streamsConfiguration()); + final KafkaStreams kafkaStreams2 = createKafkaStreams(builder, streamsConfiguration()); + final List kafkaStreamsList = Arrays.asList(kafkaStreams1, kafkaStreams2); + + startApplicationAndWaitUntilRunning(kafkaStreamsList, Duration.ofSeconds(60)); + + produceValueRange(key, 0, batch1NumMessages); + + // Assert that all messages in the first batch were processed in a timely manner + assertThat(semaphore.tryAcquire(batch1NumMessages, 60, TimeUnit.SECONDS), is(equalTo(true))); + final KeyQueryMetadata keyQueryMetadata = kafkaStreams1.queryMetadataForKey(TABLE_NAME, key, (topic, somekey, value, numPartitions) -> 0); + + //key belongs to this partition + final int keyPartition = keyQueryMetadata.partition(); + + //key doesn't belongs to this partition + final int keyDontBelongPartition = (keyPartition == 0) ? 
1 : 0; + final QueryableStoreType> queryableStoreType = keyValueStore(); + + // Assert that both active and standby are able to query for a key + final StoreQueryParameters> param = StoreQueryParameters + .fromNameAndType(TABLE_NAME, queryableStoreType) + .enableStaleStores() + .withPartition(keyPartition); + TestUtils.waitForCondition(() -> { + final ReadOnlyKeyValueStore store1 = getStore(kafkaStreams1, param); + return store1.get(key) != null; + }, "store1 cannot find results for key"); + TestUtils.waitForCondition(() -> { + final ReadOnlyKeyValueStore store2 = getStore(kafkaStreams2, param); + return store2.get(key) != null; + }, "store2 cannot find results for key"); + + final StoreQueryParameters> otherParam = StoreQueryParameters + .fromNameAndType(TABLE_NAME, queryableStoreType) + .enableStaleStores() + .withPartition(keyDontBelongPartition); + final ReadOnlyKeyValueStore store3 = getStore(kafkaStreams1, otherParam); + final ReadOnlyKeyValueStore store4 = getStore(kafkaStreams2, otherParam); + + // Assert that + assertThat(store3.get(key), is(nullValue())); + assertThat(store4.get(key), is(nullValue())); + } + + @Test + public void shouldQuerySpecificStalePartitionStoresMultiStreamThreads() throws Exception { + final int batch1NumMessages = 100; + final int key = 1; + final Semaphore semaphore = new Semaphore(0); + final int numStreamThreads = 2; + + final StreamsBuilder builder = new StreamsBuilder(); + getStreamsBuilderWithTopology(builder, semaphore); + + final Properties streamsConfiguration1 = streamsConfiguration(); + streamsConfiguration1.put(StreamsConfig.NUM_STREAM_THREADS_CONFIG, numStreamThreads); + + final Properties streamsConfiguration2 = streamsConfiguration(); + streamsConfiguration2.put(StreamsConfig.NUM_STREAM_THREADS_CONFIG, numStreamThreads); + + final KafkaStreams kafkaStreams1 = createKafkaStreams(builder, streamsConfiguration1); + final KafkaStreams kafkaStreams2 = createKafkaStreams(builder, streamsConfiguration2); + final List kafkaStreamsList = Arrays.asList(kafkaStreams1, kafkaStreams2); + + startApplicationAndWaitUntilRunning(kafkaStreamsList, Duration.ofSeconds(60)); + + assertTrue(kafkaStreams1.metadataForLocalThreads().size() > 1); + assertTrue(kafkaStreams2.metadataForLocalThreads().size() > 1); + + produceValueRange(key, 0, batch1NumMessages); + + // Assert that all messages in the first batch were processed in a timely manner + assertThat(semaphore.tryAcquire(batch1NumMessages, 60, TimeUnit.SECONDS), is(equalTo(true))); + final KeyQueryMetadata keyQueryMetadata = kafkaStreams1.queryMetadataForKey(TABLE_NAME, key, new IntegerSerializer()); + + //key belongs to this partition + final int keyPartition = keyQueryMetadata.partition(); + + //key doesn't belongs to this partition + final int keyDontBelongPartition = (keyPartition == 0) ? 
1 : 0; + final QueryableStoreType> queryableStoreType = keyValueStore(); + + // Assert that both active and standby are able to query for a key + final StoreQueryParameters> param = StoreQueryParameters + .fromNameAndType(TABLE_NAME, queryableStoreType) + .enableStaleStores() + .withPartition(keyPartition); + TestUtils.waitForCondition(() -> { + final ReadOnlyKeyValueStore store1 = getStore(kafkaStreams1, param); + return store1.get(key) != null; + }, "store1 cannot find results for key"); + TestUtils.waitForCondition(() -> { + final ReadOnlyKeyValueStore store2 = getStore(kafkaStreams2, param); + return store2.get(key) != null; + }, "store2 cannot find results for key"); + + final StoreQueryParameters> otherParam = StoreQueryParameters + .fromNameAndType(TABLE_NAME, queryableStoreType) + .enableStaleStores() + .withPartition(keyDontBelongPartition); + final ReadOnlyKeyValueStore store3 = getStore(kafkaStreams1, otherParam); + final ReadOnlyKeyValueStore store4 = getStore(kafkaStreams2, otherParam); + + // Assert that + assertThat(store3.get(key), is(nullValue())); + assertThat(store4.get(key), is(nullValue())); + } + + @Test + public void shouldQuerySpecificStalePartitionStoresMultiStreamThreadsNamedTopology() throws Exception { + final int batch1NumMessages = 100; + final int key = 1; + final Semaphore semaphore = new Semaphore(0); + final int numStreamThreads = 2; + + final NamedTopologyStreamsBuilder builder1A = new NamedTopologyStreamsBuilder("topology-A"); + getStreamsBuilderWithTopology(builder1A, semaphore); + + final NamedTopologyStreamsBuilder builder2A = new NamedTopologyStreamsBuilder("topology-A"); + getStreamsBuilderWithTopology(builder2A, semaphore); + + final Properties streamsConfiguration1 = streamsConfiguration(); + streamsConfiguration1.put(StreamsConfig.NUM_STREAM_THREADS_CONFIG, numStreamThreads); + + final Properties streamsConfiguration2 = streamsConfiguration(); + streamsConfiguration2.put(StreamsConfig.NUM_STREAM_THREADS_CONFIG, numStreamThreads); + + final KafkaStreamsNamedTopologyWrapper kafkaStreams1 = createNamedTopologyKafkaStreams(builder1A, streamsConfiguration1); + final KafkaStreamsNamedTopologyWrapper kafkaStreams2 = createNamedTopologyKafkaStreams(builder2A, streamsConfiguration2); + final List kafkaStreamsList = Arrays.asList(kafkaStreams1, kafkaStreams2); + + startApplicationAndWaitUntilRunning(kafkaStreamsList, Duration.ofSeconds(60)); + + assertTrue(kafkaStreams1.metadataForLocalThreads().size() > 1); + assertTrue(kafkaStreams2.metadataForLocalThreads().size() > 1); + + produceValueRange(key, 0, batch1NumMessages); + + // Assert that all messages in the first batch were processed in a timely manner + assertThat(semaphore.tryAcquire(batch1NumMessages, 60, TimeUnit.SECONDS), is(equalTo(true))); + final KeyQueryMetadata keyQueryMetadata = kafkaStreams1.queryMetadataForKey(TABLE_NAME, key, new IntegerSerializer()); + + //key belongs to this partition + final int keyPartition = keyQueryMetadata.partition(); + + //key doesn't belongs to this partition + final int keyDontBelongPartition = (keyPartition == 0) ? 
1 : 0; + final QueryableStoreType> queryableStoreType = keyValueStore(); + + // Assert that both active and standby are able to query for a key + final StoreQueryParameters> param = StoreQueryParameters + .fromNameAndType(TABLE_NAME, queryableStoreType) + .enableStaleStores() + .withPartition(keyPartition); + TestUtils.waitForCondition(() -> { + final ReadOnlyKeyValueStore store1 = getStore(kafkaStreams1, param); + return store1.get(key) != null; + }, "store1 cannot find results for key"); + TestUtils.waitForCondition(() -> { + final ReadOnlyKeyValueStore store2 = getStore(kafkaStreams2, param); + return store2.get(key) != null; + }, "store2 cannot find results for key"); + + final StoreQueryParameters> otherParam = StoreQueryParameters + .fromNameAndType(TABLE_NAME, queryableStoreType) + .enableStaleStores() + .withPartition(keyDontBelongPartition); + final ReadOnlyKeyValueStore store3 = getStore(kafkaStreams1, otherParam); + final ReadOnlyKeyValueStore store4 = getStore(kafkaStreams2, otherParam); + + // Assert that + assertThat(store3.get(key), is(nullValue())); + assertThat(store4.get(key), is(nullValue())); + } + + @Test + public void shouldQueryStoresAfterAddingAndRemovingStreamThread() throws Exception { + final int batch1NumMessages = 100; + final int key = 1; + final int key2 = 2; + final int key3 = 3; + final Semaphore semaphore = new Semaphore(0); + + final StreamsBuilder builder = new StreamsBuilder(); + getStreamsBuilderWithTopology(builder, semaphore); + + final Properties streamsConfiguration1 = streamsConfiguration(); + streamsConfiguration1.put(StreamsConfig.NUM_STREAM_THREADS_CONFIG, 1); + + final KafkaStreams kafkaStreams1 = createKafkaStreams(builder, streamsConfiguration1); + + startApplicationAndWaitUntilRunning(singletonList(kafkaStreams1), Duration.ofSeconds(60)); + //Add thread + final Optional streamThread = kafkaStreams1.addStreamThread(); + assertThat(streamThread.isPresent(), is(true)); + until(() -> kafkaStreams1.state().isRunningOrRebalancing()); + + produceValueRange(key, 0, batch1NumMessages); + produceValueRange(key2, 0, batch1NumMessages); + produceValueRange(key3, 0, batch1NumMessages); + + // Assert that all messages in the batches were processed in a timely manner + assertThat(semaphore.tryAcquire(3 * batch1NumMessages, 60, TimeUnit.SECONDS), is(equalTo(true))); + + until(() -> KafkaStreams.State.RUNNING.equals(kafkaStreams1.state())); + until(() -> { + final QueryableStoreType> queryableStoreType = keyValueStore(); + final ReadOnlyKeyValueStore store1 = getStore(TABLE_NAME, kafkaStreams1, queryableStoreType); + + try { + assertThat(store1.get(key), is(notNullValue())); + assertThat(store1.get(key2), is(notNullValue())); + assertThat(store1.get(key3), is(notNullValue())); + return true; + } catch (final InvalidStateStoreException exception) { + verifyRetrievableException(exception); + LOG.info("Either streams wasn't running or a re-balancing took place. 
Will try again."); + return false; + } + }); + + final Optional removedThreadName = kafkaStreams1.removeStreamThread(); + assertThat(removedThreadName.isPresent(), is(true)); + until(() -> kafkaStreams1.state().isRunningOrRebalancing()); + + until(() -> KafkaStreams.State.RUNNING.equals(kafkaStreams1.state())); + until(() -> { + final QueryableStoreType> queryableStoreType = keyValueStore(); + final ReadOnlyKeyValueStore store1 = getStore(TABLE_NAME, kafkaStreams1, queryableStoreType); + + try { + assertThat(store1.get(key), is(notNullValue())); + assertThat(store1.get(key2), is(notNullValue())); + assertThat(store1.get(key3), is(notNullValue())); + return true; + } catch (final InvalidStateStoreException exception) { + verifyRetrievableException(exception); + LOG.info("Either streams wasn't running or a re-balancing took place. Will try again."); + return false; + } + }); + } + + private void verifyRetrievableException(final Exception exception) { + assertThat( + "Unexpected exception thrown while getting the value from store.", + exception.getMessage(), + is( + anyOf( + containsString("Cannot get state store source-table because the stream thread is PARTITIONS_ASSIGNED, not RUNNING"), + containsString("The state store, source-table, may have migrated to another instance"), + containsString("Cannot get state store source-table because the stream thread is STARTING, not RUNNING") + ) + ) + ); + } + + private static void until(final TestCondition condition) { + boolean success = false; + final long deadline = System.currentTimeMillis() + IntegrationTestUtils.DEFAULT_TIMEOUT; + while (!success && System.currentTimeMillis() < deadline) { + try { + success = condition.conditionMet(); + Thread.sleep(500L); + } catch (final RuntimeException e) { + throw e; + } catch (final Exception e) { + throw new RuntimeException(e); + } + } + } + + private void getStreamsBuilderWithTopology(final StreamsBuilder builder, final Semaphore semaphore) { + builder.table(INPUT_TOPIC_NAME, Consumed.with(Serdes.Integer(), Serdes.Integer()), + Materialized.>as(TABLE_NAME).withCachingDisabled()) + .toStream() + .peek((k, v) -> semaphore.release()); + } + + private KafkaStreams createKafkaStreams(final StreamsBuilder builder, final Properties config) { + final KafkaStreams streams = new KafkaStreams(builder.build(config), config); + streamsToCleanup.add(streams); + return streams; + } + + private KafkaStreamsNamedTopologyWrapper createNamedTopologyKafkaStreams(final NamedTopologyStreamsBuilder builder, final Properties config) { + final KafkaStreamsNamedTopologyWrapper streams = new KafkaStreamsNamedTopologyWrapper(builder.buildNamedTopology(config), config); + streamsToCleanup.add(streams); + return streams; + } + + private void produceValueRange(final int key, final int start, final int endExclusive) { + final Properties producerProps = new Properties(); + producerProps.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, cluster.bootstrapServers()); + producerProps.put(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG, IntegerSerializer.class); + producerProps.put(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG, IntegerSerializer.class); + + IntegrationTestUtils.produceKeyValuesSynchronously( + INPUT_TOPIC_NAME, + IntStream.range(start, endExclusive) + .mapToObj(i -> KeyValue.pair(key, i)) + .collect(Collectors.toList()), + producerProps, + mockTime); + } + + private Properties streamsConfiguration() { + final String safeTestName = safeUniqueTestName(getClass(), testName); + final Properties config = new Properties(); + 
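+        // Test-oriented settings (descriptive note added by the editor): each instance is given a
+        // unique application.server port, which the tests use to tell the two KafkaStreams instances
+        // apart; one standby replica is kept so the stale-store queries have a copy to read from; the
+        // short session timeout, heartbeat, and commit interval keep rebalances and commits fast.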
config.put(StreamsConfig.TOPOLOGY_OPTIMIZATION_CONFIG, StreamsConfig.OPTIMIZE); + config.put(StreamsConfig.APPLICATION_ID_CONFIG, "app-" + safeTestName); + config.put(StreamsConfig.APPLICATION_SERVER_CONFIG, "localhost:" + (++port)); + config.put(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, cluster.bootstrapServers()); + config.put(StreamsConfig.STATE_DIR_CONFIG, TestUtils.tempDirectory().getPath()); + config.put(StreamsConfig.DEFAULT_KEY_SERDE_CLASS_CONFIG, Serdes.Integer().getClass()); + config.put(StreamsConfig.DEFAULT_VALUE_SERDE_CLASS_CONFIG, Serdes.Integer().getClass()); + config.put(StreamsConfig.NUM_STANDBY_REPLICAS_CONFIG, 1); + config.put(ConsumerConfig.MAX_POLL_RECORDS_CONFIG, 100); + config.put(ConsumerConfig.HEARTBEAT_INTERVAL_MS_CONFIG, 200); + config.put(ConsumerConfig.SESSION_TIMEOUT_MS_CONFIG, 1000); + config.put(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG, 100L); + return config; + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/integration/StoreQuerySuite.java b/streams/src/test/java/org/apache/kafka/streams/integration/StoreQuerySuite.java new file mode 100644 index 0000000..31b8554 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/integration/StoreQuerySuite.java @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.integration; + +import org.apache.kafka.streams.state.internals.CompositeReadOnlyKeyValueStoreTest; +import org.apache.kafka.streams.state.internals.CompositeReadOnlySessionStoreTest; +import org.apache.kafka.streams.state.internals.CompositeReadOnlyWindowStoreTest; +import org.apache.kafka.streams.state.internals.GlobalStateStoreProviderTest; +import org.apache.kafka.streams.state.internals.StreamThreadStateStoreProviderTest; +import org.apache.kafka.streams.state.internals.WrappingStoreProviderTest; +import org.junit.runner.RunWith; +import org.junit.runners.Suite; + +/** + * This suite runs all the tests related to querying StateStores (IQ). + * + * It can be used from an IDE to selectively just run these tests. + * + * Tests ending in the word "Suite" are excluded from the gradle build because it + * already runs the component tests individually. 
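+ * (The Composite*Test and *ProviderTest entries below are unit tests of the IQ store wrappers and
+ * providers; QueryableStateIntegrationTest is the end-to-end IQ test.)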
+ */ +@RunWith(Suite.class) +@Suite.SuiteClasses({ + CompositeReadOnlyKeyValueStoreTest.class, + CompositeReadOnlyWindowStoreTest.class, + CompositeReadOnlySessionStoreTest.class, + GlobalStateStoreProviderTest.class, + StreamThreadStateStoreProviderTest.class, + WrappingStoreProviderTest.class, + QueryableStateIntegrationTest.class, + }) +public class StoreQuerySuite { +} + + diff --git a/streams/src/test/java/org/apache/kafka/streams/integration/StoreUpgradeIntegrationTest.java b/streams/src/test/java/org/apache/kafka/streams/integration/StoreUpgradeIntegrationTest.java new file mode 100644 index 0000000..9c6085f --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/integration/StoreUpgradeIntegrationTest.java @@ -0,0 +1,1079 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.integration; + +import org.apache.kafka.clients.consumer.ConsumerConfig; +import org.apache.kafka.common.serialization.IntegerSerializer; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.streams.KafkaStreams; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.integration.utils.EmbeddedKafkaCluster; +import org.apache.kafka.streams.integration.utils.IntegrationTestUtils; +import org.apache.kafka.streams.kstream.Windowed; +import org.apache.kafka.streams.kstream.internals.TimeWindow; +import org.apache.kafka.streams.processor.ProcessorContext; +import org.apache.kafka.streams.state.KeyValueIterator; +import org.apache.kafka.streams.state.KeyValueStore; +import org.apache.kafka.streams.state.QueryableStoreTypes; +import org.apache.kafka.streams.state.ReadOnlyKeyValueStore; +import org.apache.kafka.streams.state.ReadOnlyWindowStore; +import org.apache.kafka.streams.state.Stores; +import org.apache.kafka.streams.state.TimestampedKeyValueStore; +import org.apache.kafka.streams.state.TimestampedWindowStore; +import org.apache.kafka.streams.state.ValueAndTimestamp; +import org.apache.kafka.streams.state.WindowStore; +import org.apache.kafka.test.IntegrationTest; +import org.apache.kafka.test.TestUtils; +import org.junit.After; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Rule; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.rules.TestName; + +import java.io.IOException; +import java.time.Duration; +import java.util.LinkedList; +import java.util.List; +import java.util.Properties; + +import static java.util.Arrays.asList; +import static java.util.Collections.singletonList; +import static org.apache.kafka.streams.integration.utils.IntegrationTestUtils.safeUniqueTestName; + 
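+/**
+ * Editor's summary of the tests below: verifies upgrading plain key-value and window stores to their
+ * timestamped counterparts through the Processor API. Counts written with the old store type must
+ * still be readable after the upgrade, either with real record timestamps or with the surrogate
+ * timestamp -1 where the old data carries no timestamp (e.g. when a persistent store is proxied
+ * rather than migrated).
+ */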
+@SuppressWarnings("deprecation") // Old PAPI. Needs to be migrated. +@Category({IntegrationTest.class}) +public class StoreUpgradeIntegrationTest { + private static final String STORE_NAME = "store"; + private String inputStream; + + private KafkaStreams kafkaStreams; + + public static final EmbeddedKafkaCluster CLUSTER = new EmbeddedKafkaCluster(1); + + @BeforeClass + public static void startCluster() throws IOException { + CLUSTER.start(); + } + + @AfterClass + public static void closeCluster() { + CLUSTER.stop(); + } + + @Rule + public TestName testName = new TestName(); + + @Before + public void createTopics() throws Exception { + inputStream = "input-stream-" + safeUniqueTestName(getClass(), testName); + CLUSTER.createTopic(inputStream); + } + + private Properties props() { + final Properties streamsConfiguration = new Properties(); + final String safeTestName = safeUniqueTestName(getClass(), testName); + streamsConfiguration.put(StreamsConfig.APPLICATION_ID_CONFIG, "app-" + safeTestName); + streamsConfiguration.put(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, CLUSTER.bootstrapServers()); + streamsConfiguration.put(StreamsConfig.CACHE_MAX_BYTES_BUFFERING_CONFIG, 0); + streamsConfiguration.put(StreamsConfig.STATE_DIR_CONFIG, TestUtils.tempDirectory().getPath()); + streamsConfiguration.put(StreamsConfig.DEFAULT_KEY_SERDE_CLASS_CONFIG, Serdes.Integer().getClass()); + streamsConfiguration.put(StreamsConfig.DEFAULT_VALUE_SERDE_CLASS_CONFIG, Serdes.Integer().getClass()); + streamsConfiguration.put(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG, 1000L); + streamsConfiguration.put(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "earliest"); + return streamsConfiguration; + } + + @After + public void shutdown() { + if (kafkaStreams != null) { + kafkaStreams.close(Duration.ofSeconds(30L)); + kafkaStreams.cleanUp(); + } + } + + @Test + public void shouldMigrateInMemoryKeyValueStoreToTimestampedKeyValueStoreUsingPapi() throws Exception { + shouldMigrateKeyValueStoreToTimestampedKeyValueStoreUsingPapi(false); + } + + @Test + public void shouldMigratePersistentKeyValueStoreToTimestampedKeyValueStoreUsingPapi() throws Exception { + shouldMigrateKeyValueStoreToTimestampedKeyValueStoreUsingPapi(true); + } + + private void shouldMigrateKeyValueStoreToTimestampedKeyValueStoreUsingPapi(final boolean persistentStore) throws Exception { + final StreamsBuilder streamsBuilderForOldStore = new StreamsBuilder(); + + streamsBuilderForOldStore.addStateStore( + Stores.keyValueStoreBuilder( + persistentStore ? Stores.persistentKeyValueStore(STORE_NAME) : Stores.inMemoryKeyValueStore(STORE_NAME), + Serdes.Integer(), + Serdes.Long())) + .stream(inputStream) + .process(KeyValueProcessor::new, STORE_NAME); + + final Properties props = props(); + kafkaStreams = new KafkaStreams(streamsBuilderForOldStore.build(), props); + kafkaStreams.start(); + + processKeyValueAndVerifyPlainCount(1, singletonList(KeyValue.pair(1, 1L))); + + processKeyValueAndVerifyPlainCount(1, singletonList(KeyValue.pair(1, 2L))); + final long lastUpdateKeyOne = persistentStore ? -1L : CLUSTER.time.milliseconds() - 1L; + + processKeyValueAndVerifyPlainCount(2, asList( + KeyValue.pair(1, 2L), + KeyValue.pair(2, 1L))); + final long lastUpdateKeyTwo = persistentStore ? -1L : CLUSTER.time.milliseconds() - 1L; + + processKeyValueAndVerifyPlainCount(3, asList( + KeyValue.pair(1, 2L), + KeyValue.pair(2, 1L), + KeyValue.pair(3, 1L))); + final long lastUpdateKeyThree = persistentStore ? 
-1L : CLUSTER.time.milliseconds() - 1L; + + processKeyValueAndVerifyPlainCount(4, asList( + KeyValue.pair(1, 2L), + KeyValue.pair(2, 1L), + KeyValue.pair(3, 1L), + KeyValue.pair(4, 1L))); + + processKeyValueAndVerifyPlainCount(4, asList( + KeyValue.pair(1, 2L), + KeyValue.pair(2, 1L), + KeyValue.pair(3, 1L), + KeyValue.pair(4, 2L))); + + processKeyValueAndVerifyPlainCount(4, asList( + KeyValue.pair(1, 2L), + KeyValue.pair(2, 1L), + KeyValue.pair(3, 1L), + KeyValue.pair(4, 3L))); + final long lastUpdateKeyFour = persistentStore ? -1L : CLUSTER.time.milliseconds() - 1L; + + kafkaStreams.close(); + kafkaStreams = null; + + + + final StreamsBuilder streamsBuilderForNewStore = new StreamsBuilder(); + + streamsBuilderForNewStore.addStateStore( + Stores.timestampedKeyValueStoreBuilder( + persistentStore ? Stores.persistentTimestampedKeyValueStore(STORE_NAME) : Stores.inMemoryKeyValueStore(STORE_NAME), + Serdes.Integer(), + Serdes.Long())) + .stream(inputStream) + .process(TimestampedKeyValueProcessor::new, STORE_NAME); + + kafkaStreams = new KafkaStreams(streamsBuilderForNewStore.build(), props); + kafkaStreams.start(); + + verifyCountWithTimestamp(1, 2L, lastUpdateKeyOne); + verifyCountWithTimestamp(2, 1L, lastUpdateKeyTwo); + verifyCountWithTimestamp(3, 1L, lastUpdateKeyThree); + verifyCountWithTimestamp(4, 3L, lastUpdateKeyFour); + + final long currentTime = CLUSTER.time.milliseconds(); + processKeyValueAndVerifyCountWithTimestamp(1, currentTime + 42L, asList( + KeyValue.pair(1, ValueAndTimestamp.make(3L, currentTime + 42L)), + KeyValue.pair(2, ValueAndTimestamp.make(1L, lastUpdateKeyTwo)), + KeyValue.pair(3, ValueAndTimestamp.make(1L, lastUpdateKeyThree)), + KeyValue.pair(4, ValueAndTimestamp.make(3L, lastUpdateKeyFour)))); + + processKeyValueAndVerifyCountWithTimestamp(2, currentTime + 45L, asList( + KeyValue.pair(1, ValueAndTimestamp.make(3L, currentTime + 42L)), + KeyValue.pair(2, ValueAndTimestamp.make(2L, currentTime + 45L)), + KeyValue.pair(3, ValueAndTimestamp.make(1L, lastUpdateKeyThree)), + KeyValue.pair(4, ValueAndTimestamp.make(3L, lastUpdateKeyFour)))); + + // can process "out of order" record for different key + processKeyValueAndVerifyCountWithTimestamp(4, currentTime + 21L, asList( + KeyValue.pair(1, ValueAndTimestamp.make(3L, currentTime + 42L)), + KeyValue.pair(2, ValueAndTimestamp.make(2L, currentTime + 45L)), + KeyValue.pair(3, ValueAndTimestamp.make(1L, lastUpdateKeyThree)), + KeyValue.pair(4, ValueAndTimestamp.make(4L, currentTime + 21L)))); + + processKeyValueAndVerifyCountWithTimestamp(4, currentTime + 42L, asList( + KeyValue.pair(1, ValueAndTimestamp.make(3L, currentTime + 42L)), + KeyValue.pair(2, ValueAndTimestamp.make(2L, currentTime + 45L)), + KeyValue.pair(3, ValueAndTimestamp.make(1L, lastUpdateKeyThree)), + KeyValue.pair(4, ValueAndTimestamp.make(5L, currentTime + 42L)))); + + // out of order (same key) record should not reduce result timestamp + processKeyValueAndVerifyCountWithTimestamp(4, currentTime + 10L, asList( + KeyValue.pair(1, ValueAndTimestamp.make(3L, currentTime + 42L)), + KeyValue.pair(2, ValueAndTimestamp.make(2L, currentTime + 45L)), + KeyValue.pair(3, ValueAndTimestamp.make(1L, lastUpdateKeyThree)), + KeyValue.pair(4, ValueAndTimestamp.make(6L, currentTime + 42L)))); + + kafkaStreams.close(); + } + + @Test + public void shouldProxyKeyValueStoreToTimestampedKeyValueStoreUsingPapi() throws Exception { + final StreamsBuilder streamsBuilderForOldStore = new StreamsBuilder(); + + streamsBuilderForOldStore.addStateStore( + Stores.keyValueStoreBuilder( 
+ Stores.persistentKeyValueStore(STORE_NAME), + Serdes.Integer(), + Serdes.Long())) + .stream(inputStream) + .process(KeyValueProcessor::new, STORE_NAME); + + final Properties props = props(); + kafkaStreams = new KafkaStreams(streamsBuilderForOldStore.build(), props); + kafkaStreams.start(); + + processKeyValueAndVerifyPlainCount(1, singletonList(KeyValue.pair(1, 1L))); + + processKeyValueAndVerifyPlainCount(1, singletonList(KeyValue.pair(1, 2L))); + + processKeyValueAndVerifyPlainCount(2, asList( + KeyValue.pair(1, 2L), + KeyValue.pair(2, 1L))); + + processKeyValueAndVerifyPlainCount(3, asList( + KeyValue.pair(1, 2L), + KeyValue.pair(2, 1L), + KeyValue.pair(3, 1L))); + + processKeyValueAndVerifyPlainCount(4, asList( + KeyValue.pair(1, 2L), + KeyValue.pair(2, 1L), + KeyValue.pair(3, 1L), + KeyValue.pair(4, 1L))); + + processKeyValueAndVerifyPlainCount(4, asList( + KeyValue.pair(1, 2L), + KeyValue.pair(2, 1L), + KeyValue.pair(3, 1L), + KeyValue.pair(4, 2L))); + + processKeyValueAndVerifyPlainCount(4, asList( + KeyValue.pair(1, 2L), + KeyValue.pair(2, 1L), + KeyValue.pair(3, 1L), + KeyValue.pair(4, 3L))); + + kafkaStreams.close(); + kafkaStreams = null; + + + + final StreamsBuilder streamsBuilderForNewStore = new StreamsBuilder(); + + streamsBuilderForNewStore.addStateStore( + Stores.timestampedKeyValueStoreBuilder( + Stores.persistentKeyValueStore(STORE_NAME), + Serdes.Integer(), + Serdes.Long())) + .stream(inputStream) + .process(TimestampedKeyValueProcessor::new, STORE_NAME); + + kafkaStreams = new KafkaStreams(streamsBuilderForNewStore.build(), props); + kafkaStreams.start(); + + verifyCountWithSurrogateTimestamp(1, 2L); + verifyCountWithSurrogateTimestamp(2, 1L); + verifyCountWithSurrogateTimestamp(3, 1L); + verifyCountWithSurrogateTimestamp(4, 3L); + + processKeyValueAndVerifyCount(1, 42L, asList( + KeyValue.pair(1, ValueAndTimestamp.make(3L, -1L)), + KeyValue.pair(2, ValueAndTimestamp.make(1L, -1L)), + KeyValue.pair(3, ValueAndTimestamp.make(1L, -1L)), + KeyValue.pair(4, ValueAndTimestamp.make(3L, -1L)))); + + processKeyValueAndVerifyCount(2, 45L, asList( + KeyValue.pair(1, ValueAndTimestamp.make(3L, -1L)), + KeyValue.pair(2, ValueAndTimestamp.make(2L, -1L)), + KeyValue.pair(3, ValueAndTimestamp.make(1L, -1L)), + KeyValue.pair(4, ValueAndTimestamp.make(3L, -1L)))); + + // can process "out of order" record for different key + processKeyValueAndVerifyCount(4, 21L, asList( + KeyValue.pair(1, ValueAndTimestamp.make(3L, -1L)), + KeyValue.pair(2, ValueAndTimestamp.make(2L, -1L)), + KeyValue.pair(3, ValueAndTimestamp.make(1L, -1L)), + KeyValue.pair(4, ValueAndTimestamp.make(4L, -1L)))); + + processKeyValueAndVerifyCount(4, 42L, asList( + KeyValue.pair(1, ValueAndTimestamp.make(3L, -1L)), + KeyValue.pair(2, ValueAndTimestamp.make(2L, -1L)), + KeyValue.pair(3, ValueAndTimestamp.make(1L, -1L)), + KeyValue.pair(4, ValueAndTimestamp.make(5L, -1L)))); + + // out of order (same key) record should not reduce result timestamp + processKeyValueAndVerifyCount(4, 10L, asList( + KeyValue.pair(1, ValueAndTimestamp.make(3L, -1L)), + KeyValue.pair(2, ValueAndTimestamp.make(2L, -1L)), + KeyValue.pair(3, ValueAndTimestamp.make(1L, -1L)), + KeyValue.pair(4, ValueAndTimestamp.make(6L, -1L)))); + + kafkaStreams.close(); + } + + private void processKeyValueAndVerifyPlainCount(final K key, + final List> expectedStoreContent) + throws Exception { + + IntegrationTestUtils.produceKeyValuesSynchronously( + inputStream, + singletonList(KeyValue.pair(key, 0)), + TestUtils.producerConfig(CLUSTER.bootstrapServers(), + 
IntegerSerializer.class, + IntegerSerializer.class), + CLUSTER.time); + + TestUtils.waitForCondition( + () -> { + try { + final ReadOnlyKeyValueStore store = IntegrationTestUtils.getStore(STORE_NAME, kafkaStreams, QueryableStoreTypes.keyValueStore()); + + if (store == null) { + return false; + } + + try (final KeyValueIterator all = store.all()) { + final List> storeContent = new LinkedList<>(); + while (all.hasNext()) { + storeContent.add(all.next()); + } + return storeContent.equals(expectedStoreContent); + } + } catch (final Exception swallow) { + swallow.printStackTrace(); + System.err.println(swallow.getMessage()); + return false; + } + }, + 60_000L, + "Could not get expected result in time."); + } + + private void verifyCountWithTimestamp(final K key, + final long value, + final long timestamp) throws Exception { + TestUtils.waitForCondition( + () -> { + try { + final ReadOnlyKeyValueStore> store = IntegrationTestUtils + .getStore(STORE_NAME, kafkaStreams, QueryableStoreTypes.timestampedKeyValueStore()); + + if (store == null) + return false; + + final ValueAndTimestamp count = store.get(key); + return count.value() == value && count.timestamp() == timestamp; + } catch (final Exception swallow) { + swallow.printStackTrace(); + System.err.println(swallow.getMessage()); + return false; + } + }, + 60_000L, + "Could not get expected result in time."); + } + + private void verifyCountWithSurrogateTimestamp(final K key, + final long value) throws Exception { + TestUtils.waitForCondition( + () -> { + try { + final ReadOnlyKeyValueStore> store = IntegrationTestUtils + .getStore(STORE_NAME, kafkaStreams, QueryableStoreTypes.timestampedKeyValueStore()); + + if (store == null) + return false; + + final ValueAndTimestamp count = store.get(key); + return count.value() == value && count.timestamp() == -1L; + } catch (final Exception swallow) { + swallow.printStackTrace(); + System.err.println(swallow.getMessage()); + return false; + } + }, + 60_000L, + "Could not get expected result in time."); + } + + private void processKeyValueAndVerifyCount(final K key, + final long timestamp, + final List> expectedStoreContent) + throws Exception { + + IntegrationTestUtils.produceKeyValuesSynchronouslyWithTimestamp( + inputStream, + singletonList(KeyValue.pair(key, 0)), + TestUtils.producerConfig(CLUSTER.bootstrapServers(), + IntegerSerializer.class, + IntegerSerializer.class), + timestamp); + + TestUtils.waitForCondition( + () -> { + try { + final ReadOnlyKeyValueStore> store = IntegrationTestUtils + .getStore(STORE_NAME, kafkaStreams, QueryableStoreTypes.timestampedKeyValueStore()); + + if (store == null) + return false; + + try (final KeyValueIterator> all = store.all()) { + final List>> storeContent = new LinkedList<>(); + while (all.hasNext()) { + storeContent.add(all.next()); + } + return storeContent.equals(expectedStoreContent); + } + } catch (final Exception swallow) { + swallow.printStackTrace(); + System.err.println(swallow.getMessage()); + return false; + } + }, + 60_000L, + "Could not get expected result in time."); + } + + private void processKeyValueAndVerifyCountWithTimestamp(final K key, + final long timestamp, + final List> expectedStoreContent) + throws Exception { + + IntegrationTestUtils.produceKeyValuesSynchronouslyWithTimestamp( + inputStream, + singletonList(KeyValue.pair(key, 0)), + TestUtils.producerConfig(CLUSTER.bootstrapServers(), + IntegerSerializer.class, + IntegerSerializer.class), + timestamp); + + TestUtils.waitForCondition( + () -> { + try { + final ReadOnlyKeyValueStore> 
store = IntegrationTestUtils + .getStore(STORE_NAME, kafkaStreams, QueryableStoreTypes.timestampedKeyValueStore()); + + if (store == null) + return false; + + try (final KeyValueIterator> all = store.all()) { + final List>> storeContent = new LinkedList<>(); + while (all.hasNext()) { + storeContent.add(all.next()); + } + return storeContent.equals(expectedStoreContent); + } + } catch (final Exception swallow) { + swallow.printStackTrace(); + System.err.println(swallow.getMessage()); + return false; + } + }, + 60_000L, + "Could not get expected result in time."); + } + + @Test + public void shouldMigrateInMemoryWindowStoreToTimestampedWindowStoreUsingPapi() throws Exception { + final StreamsBuilder streamsBuilderForOldStore = new StreamsBuilder(); + streamsBuilderForOldStore + .addStateStore( + Stores.windowStoreBuilder( + Stores.inMemoryWindowStore( + STORE_NAME, + Duration.ofMillis(1000L), + Duration.ofMillis(1000L), + false), + Serdes.Integer(), + Serdes.Long())) + .stream(inputStream) + .process(WindowedProcessor::new, STORE_NAME); + + final StreamsBuilder streamsBuilderForNewStore = new StreamsBuilder(); + streamsBuilderForNewStore + .addStateStore( + Stores.timestampedWindowStoreBuilder( + Stores.inMemoryWindowStore( + STORE_NAME, + Duration.ofMillis(1000L), + Duration.ofMillis(1000L), + false), + Serdes.Integer(), + Serdes.Long())) + .stream(inputStream) + .process(TimestampedWindowedProcessor::new, STORE_NAME); + + + shouldMigrateWindowStoreToTimestampedWindowStoreUsingPapi( + streamsBuilderForOldStore, + streamsBuilderForNewStore, + false); + } + + @Test + public void shouldMigratePersistentWindowStoreToTimestampedWindowStoreUsingPapi() throws Exception { + final StreamsBuilder streamsBuilderForOldStore = new StreamsBuilder(); + + streamsBuilderForOldStore + .addStateStore( + Stores.windowStoreBuilder( + Stores.persistentWindowStore( + STORE_NAME, + Duration.ofMillis(1000L), + Duration.ofMillis(1000L), + false), + Serdes.Integer(), + Serdes.Long())) + .stream(inputStream) + .process(WindowedProcessor::new, STORE_NAME); + + final StreamsBuilder streamsBuilderForNewStore = new StreamsBuilder(); + streamsBuilderForNewStore + .addStateStore( + Stores.timestampedWindowStoreBuilder( + Stores.persistentTimestampedWindowStore( + STORE_NAME, + Duration.ofMillis(1000L), + Duration.ofMillis(1000L), + false), + Serdes.Integer(), + Serdes.Long())) + .stream(inputStream) + .process(TimestampedWindowedProcessor::new, STORE_NAME); + + shouldMigrateWindowStoreToTimestampedWindowStoreUsingPapi( + streamsBuilderForOldStore, + streamsBuilderForNewStore, + true); + } + + private void shouldMigrateWindowStoreToTimestampedWindowStoreUsingPapi(final StreamsBuilder streamsBuilderForOldStore, + final StreamsBuilder streamsBuilderForNewStore, + final boolean persistentStore) throws Exception { + final Properties props = props(); + kafkaStreams = new KafkaStreams(streamsBuilderForOldStore.build(), props); + kafkaStreams.start(); + + processWindowedKeyValueAndVerifyPlainCount(1, singletonList( + KeyValue.pair(new Windowed<>(1, new TimeWindow(0L, 1000L)), 1L))); + + processWindowedKeyValueAndVerifyPlainCount(1, singletonList( + KeyValue.pair(new Windowed<>(1, new TimeWindow(0L, 1000L)), 2L))); + final long lastUpdateKeyOne = persistentStore ? 
-1L : CLUSTER.time.milliseconds() - 1L; + + processWindowedKeyValueAndVerifyPlainCount(2, asList( + KeyValue.pair(new Windowed<>(1, new TimeWindow(0L, 1000L)), 2L), + KeyValue.pair(new Windowed<>(2, new TimeWindow(0L, 1000L)), 1L))); + final long lastUpdateKeyTwo = persistentStore ? -1L : CLUSTER.time.milliseconds() - 1L; + + processWindowedKeyValueAndVerifyPlainCount(3, asList( + KeyValue.pair(new Windowed<>(1, new TimeWindow(0L, 1000L)), 2L), + KeyValue.pair(new Windowed<>(2, new TimeWindow(0L, 1000L)), 1L), + KeyValue.pair(new Windowed<>(3, new TimeWindow(0L, 1000L)), 1L))); + final long lastUpdateKeyThree = persistentStore ? -1L : CLUSTER.time.milliseconds() - 1L; + + processWindowedKeyValueAndVerifyPlainCount(4, asList( + KeyValue.pair(new Windowed<>(1, new TimeWindow(0L, 1000L)), 2L), + KeyValue.pair(new Windowed<>(2, new TimeWindow(0L, 1000L)), 1L), + KeyValue.pair(new Windowed<>(3, new TimeWindow(0L, 1000L)), 1L), + KeyValue.pair(new Windowed<>(4, new TimeWindow(0L, 1000L)), 1L))); + + processWindowedKeyValueAndVerifyPlainCount(4, asList( + KeyValue.pair(new Windowed<>(1, new TimeWindow(0L, 1000L)), 2L), + KeyValue.pair(new Windowed<>(2, new TimeWindow(0L, 1000L)), 1L), + KeyValue.pair(new Windowed<>(3, new TimeWindow(0L, 1000L)), 1L), + KeyValue.pair(new Windowed<>(4, new TimeWindow(0L, 1000L)), 2L))); + + processWindowedKeyValueAndVerifyPlainCount(4, asList( + KeyValue.pair(new Windowed<>(1, new TimeWindow(0L, 1000L)), 2L), + KeyValue.pair(new Windowed<>(2, new TimeWindow(0L, 1000L)), 1L), + KeyValue.pair(new Windowed<>(3, new TimeWindow(0L, 1000L)), 1L), + KeyValue.pair(new Windowed<>(4, new TimeWindow(0L, 1000L)), 3L))); + final long lastUpdateKeyFour = persistentStore ? -1L : CLUSTER.time.milliseconds() - 1L; + + kafkaStreams.close(); + kafkaStreams = null; + + + kafkaStreams = new KafkaStreams(streamsBuilderForNewStore.build(), props); + kafkaStreams.start(); + + verifyWindowedCountWithTimestamp(new Windowed<>(1, new TimeWindow(0L, 1000L)), 2L, lastUpdateKeyOne); + verifyWindowedCountWithTimestamp(new Windowed<>(2, new TimeWindow(0L, 1000L)), 1L, lastUpdateKeyTwo); + verifyWindowedCountWithTimestamp(new Windowed<>(3, new TimeWindow(0L, 1000L)), 1L, lastUpdateKeyThree); + verifyWindowedCountWithTimestamp(new Windowed<>(4, new TimeWindow(0L, 1000L)), 3L, lastUpdateKeyFour); + + final long currentTime = CLUSTER.time.milliseconds(); + processKeyValueAndVerifyWindowedCountWithTimestamp(1, currentTime + 42L, asList( + KeyValue.pair( + new Windowed<>(1, new TimeWindow(0L, 1000L)), + ValueAndTimestamp.make(3L, currentTime + 42L)), + KeyValue.pair( + new Windowed<>(2, new TimeWindow(0L, 1000L)), + ValueAndTimestamp.make(1L, lastUpdateKeyTwo)), + KeyValue.pair( + new Windowed<>(3, new TimeWindow(0L, 1000L)), + ValueAndTimestamp.make(1L, lastUpdateKeyThree)), + KeyValue.pair( + new Windowed<>(4, new TimeWindow(0L, 1000L)), + ValueAndTimestamp.make(3L, lastUpdateKeyFour)))); + + processKeyValueAndVerifyWindowedCountWithTimestamp(2, currentTime + 45L, asList( + KeyValue.pair( + new Windowed<>(1, new TimeWindow(0L, 1000L)), + ValueAndTimestamp.make(3L, currentTime + 42L)), + KeyValue.pair( + new Windowed<>(2, new TimeWindow(0L, 1000L)), + ValueAndTimestamp.make(2L, currentTime + 45L)), + KeyValue.pair( + new Windowed<>(3, new TimeWindow(0L, 1000L)), + ValueAndTimestamp.make(1L, lastUpdateKeyThree)), + KeyValue.pair( + new Windowed<>(4, new TimeWindow(0L, 1000L)), + ValueAndTimestamp.make(3L, lastUpdateKeyFour)))); + + // can process "out of order" record for different key + 
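+        // (the record time currentTime + 21 is older than the timestamps just written for keys 1 and 2,
+        //  but timestamps are tracked per key, so key 4's count and timestamp are still updated)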
processKeyValueAndVerifyWindowedCountWithTimestamp(4, currentTime + 21L, asList( + KeyValue.pair( + new Windowed<>(1, new TimeWindow(0L, 1000L)), + ValueAndTimestamp.make(3L, currentTime + 42L)), + KeyValue.pair( + new Windowed<>(2, new TimeWindow(0L, 1000L)), + ValueAndTimestamp.make(2L, currentTime + 45L)), + KeyValue.pair( + new Windowed<>(3, new TimeWindow(0L, 1000L)), + ValueAndTimestamp.make(1L, lastUpdateKeyThree)), + KeyValue.pair( + new Windowed<>(4, new TimeWindow(0L, 1000L)), + ValueAndTimestamp.make(4L, currentTime + 21L)))); + + processKeyValueAndVerifyWindowedCountWithTimestamp(4, currentTime + 42L, asList( + KeyValue.pair( + new Windowed<>(1, new TimeWindow(0L, 1000L)), + ValueAndTimestamp.make(3L, currentTime + 42L)), + KeyValue.pair( + new Windowed<>(2, new TimeWindow(0L, 1000L)), + ValueAndTimestamp.make(2L, currentTime + 45L)), + KeyValue.pair( + new Windowed<>(3, new TimeWindow(0L, 1000L)), + ValueAndTimestamp.make(1L, lastUpdateKeyThree)), + KeyValue.pair( + new Windowed<>(4, new TimeWindow(0L, 1000L)), + ValueAndTimestamp.make(5L, currentTime + 42L)))); + + // out of order (same key) record should not reduce result timestamp + processKeyValueAndVerifyWindowedCountWithTimestamp(4, currentTime + 10L, asList( + KeyValue.pair( + new Windowed<>(1, new TimeWindow(0L, 1000L)), + ValueAndTimestamp.make(3L, currentTime + 42L)), + KeyValue.pair( + new Windowed<>(2, new TimeWindow(0L, 1000L)), + ValueAndTimestamp.make(2L, currentTime + 45L)), + KeyValue.pair( + new Windowed<>(3, new TimeWindow(0L, 1000L)), + ValueAndTimestamp.make(1L, lastUpdateKeyThree)), + KeyValue.pair( + new Windowed<>(4, new TimeWindow(0L, 1000L)), + ValueAndTimestamp.make(6L, currentTime + 42L)))); + + // test new segment + processKeyValueAndVerifyWindowedCountWithTimestamp(10, currentTime + 100001L, singletonList( + KeyValue.pair( + new Windowed<>(10, new TimeWindow(100000L, 101000L)), ValueAndTimestamp.make(1L, currentTime + 100001L)))); + + + kafkaStreams.close(); + } + + @Test + public void shouldProxyWindowStoreToTimestampedWindowStoreUsingPapi() throws Exception { + final StreamsBuilder streamsBuilderForOldStore = new StreamsBuilder(); + + streamsBuilderForOldStore.addStateStore( + Stores.windowStoreBuilder( + Stores.persistentWindowStore( + STORE_NAME, + Duration.ofMillis(1000L), + Duration.ofMillis(1000L), + false), + Serdes.Integer(), + Serdes.Long())) + .stream(inputStream) + .process(WindowedProcessor::new, STORE_NAME); + + final Properties props = props(); + kafkaStreams = new KafkaStreams(streamsBuilderForOldStore.build(), props); + kafkaStreams.start(); + + processWindowedKeyValueAndVerifyPlainCount(1, singletonList( + KeyValue.pair(new Windowed<>(1, new TimeWindow(0L, 1000L)), 1L))); + + processWindowedKeyValueAndVerifyPlainCount(1, singletonList( + KeyValue.pair(new Windowed<>(1, new TimeWindow(0L, 1000L)), 2L))); + + processWindowedKeyValueAndVerifyPlainCount(2, asList( + KeyValue.pair(new Windowed<>(1, new TimeWindow(0L, 1000L)), 2L), + KeyValue.pair(new Windowed<>(2, new TimeWindow(0L, 1000L)), 1L))); + + processWindowedKeyValueAndVerifyPlainCount(3, asList( + KeyValue.pair(new Windowed<>(1, new TimeWindow(0L, 1000L)), 2L), + KeyValue.pair(new Windowed<>(2, new TimeWindow(0L, 1000L)), 1L), + KeyValue.pair(new Windowed<>(3, new TimeWindow(0L, 1000L)), 1L))); + + processWindowedKeyValueAndVerifyPlainCount(4, asList( + KeyValue.pair(new Windowed<>(1, new TimeWindow(0L, 1000L)), 2L), + KeyValue.pair(new Windowed<>(2, new TimeWindow(0L, 1000L)), 1L), + KeyValue.pair(new Windowed<>(3, new 
TimeWindow(0L, 1000L)), 1L), + KeyValue.pair(new Windowed<>(4, new TimeWindow(0L, 1000L)), 1L))); + + processWindowedKeyValueAndVerifyPlainCount(4, asList( + KeyValue.pair(new Windowed<>(1, new TimeWindow(0L, 1000L)), 2L), + KeyValue.pair(new Windowed<>(2, new TimeWindow(0L, 1000L)), 1L), + KeyValue.pair(new Windowed<>(3, new TimeWindow(0L, 1000L)), 1L), + KeyValue.pair(new Windowed<>(4, new TimeWindow(0L, 1000L)), 2L))); + + processWindowedKeyValueAndVerifyPlainCount(4, asList( + KeyValue.pair(new Windowed<>(1, new TimeWindow(0L, 1000L)), 2L), + KeyValue.pair(new Windowed<>(2, new TimeWindow(0L, 1000L)), 1L), + KeyValue.pair(new Windowed<>(3, new TimeWindow(0L, 1000L)), 1L), + KeyValue.pair(new Windowed<>(4, new TimeWindow(0L, 1000L)), 3L))); + + kafkaStreams.close(); + kafkaStreams = null; + + + + final StreamsBuilder streamsBuilderForNewStore = new StreamsBuilder(); + + streamsBuilderForNewStore.addStateStore( + Stores.timestampedWindowStoreBuilder( + Stores.persistentWindowStore( + STORE_NAME, + Duration.ofMillis(1000L), + Duration.ofMillis(1000L), + false), + Serdes.Integer(), + Serdes.Long())) + .stream(inputStream) + .process(TimestampedWindowedProcessor::new, STORE_NAME); + + kafkaStreams = new KafkaStreams(streamsBuilderForNewStore.build(), props); + kafkaStreams.start(); + + verifyWindowedCountWithSurrogateTimestamp(new Windowed<>(1, new TimeWindow(0L, 1000L)), 2L); + verifyWindowedCountWithSurrogateTimestamp(new Windowed<>(2, new TimeWindow(0L, 1000L)), 1L); + verifyWindowedCountWithSurrogateTimestamp(new Windowed<>(3, new TimeWindow(0L, 1000L)), 1L); + verifyWindowedCountWithSurrogateTimestamp(new Windowed<>(4, new TimeWindow(0L, 1000L)), 3L); + + processKeyValueAndVerifyWindowedCountWithTimestamp(1, 42L, asList( + KeyValue.pair(new Windowed<>(1, new TimeWindow(0L, 1000L)), ValueAndTimestamp.make(3L, -1L)), + KeyValue.pair(new Windowed<>(2, new TimeWindow(0L, 1000L)), ValueAndTimestamp.make(1L, -1L)), + KeyValue.pair(new Windowed<>(3, new TimeWindow(0L, 1000L)), ValueAndTimestamp.make(1L, -1L)), + KeyValue.pair(new Windowed<>(4, new TimeWindow(0L, 1000L)), ValueAndTimestamp.make(3L, -1L)))); + + processKeyValueAndVerifyWindowedCountWithTimestamp(2, 45L, asList( + KeyValue.pair(new Windowed<>(1, new TimeWindow(0L, 1000L)), ValueAndTimestamp.make(3L, -1L)), + KeyValue.pair(new Windowed<>(2, new TimeWindow(0L, 1000L)), ValueAndTimestamp.make(2L, -1L)), + KeyValue.pair(new Windowed<>(3, new TimeWindow(0L, 1000L)), ValueAndTimestamp.make(1L, -1L)), + KeyValue.pair(new Windowed<>(4, new TimeWindow(0L, 1000L)), ValueAndTimestamp.make(3L, -1L)))); + + // can process "out of order" record for different key + processKeyValueAndVerifyWindowedCountWithTimestamp(4, 21L, asList( + KeyValue.pair(new Windowed<>(1, new TimeWindow(0L, 1000L)), ValueAndTimestamp.make(3L, -1L)), + KeyValue.pair(new Windowed<>(2, new TimeWindow(0L, 1000L)), ValueAndTimestamp.make(2L, -1L)), + KeyValue.pair(new Windowed<>(3, new TimeWindow(0L, 1000L)), ValueAndTimestamp.make(1L, -1L)), + KeyValue.pair(new Windowed<>(4, new TimeWindow(0L, 1000L)), ValueAndTimestamp.make(4L, -1L)))); + + processKeyValueAndVerifyWindowedCountWithTimestamp(4, 42L, asList( + KeyValue.pair(new Windowed<>(1, new TimeWindow(0L, 1000L)), ValueAndTimestamp.make(3L, -1L)), + KeyValue.pair(new Windowed<>(2, new TimeWindow(0L, 1000L)), ValueAndTimestamp.make(2L, -1L)), + KeyValue.pair(new Windowed<>(3, new TimeWindow(0L, 1000L)), ValueAndTimestamp.make(1L, -1L)), + KeyValue.pair(new Windowed<>(4, new TimeWindow(0L, 1000L)), 
ValueAndTimestamp.make(5L, -1L)))); + + // out of order (same key) record should not reduce result timestamp + processKeyValueAndVerifyWindowedCountWithTimestamp(4, 10L, asList( + KeyValue.pair(new Windowed<>(1, new TimeWindow(0L, 1000L)), ValueAndTimestamp.make(3L, -1L)), + KeyValue.pair(new Windowed<>(2, new TimeWindow(0L, 1000L)), ValueAndTimestamp.make(2L, -1L)), + KeyValue.pair(new Windowed<>(3, new TimeWindow(0L, 1000L)), ValueAndTimestamp.make(1L, -1L)), + KeyValue.pair(new Windowed<>(4, new TimeWindow(0L, 1000L)), ValueAndTimestamp.make(6L, -1L)))); + + // test new segment + processKeyValueAndVerifyWindowedCountWithTimestamp(10, 100001L, singletonList( + KeyValue.pair(new Windowed<>(10, new TimeWindow(100000L, 101000L)), ValueAndTimestamp.make(1L, -1L)))); + + + kafkaStreams.close(); + } + + private void processWindowedKeyValueAndVerifyPlainCount(final K key, + final List, Object>> expectedStoreContent) + throws Exception { + + IntegrationTestUtils.produceKeyValuesSynchronously( + inputStream, + singletonList(KeyValue.pair(key, 0)), + TestUtils.producerConfig(CLUSTER.bootstrapServers(), + IntegerSerializer.class, + IntegerSerializer.class), + CLUSTER.time); + + TestUtils.waitForCondition( + () -> { + try { + final ReadOnlyWindowStore store = IntegrationTestUtils + .getStore(STORE_NAME, kafkaStreams, QueryableStoreTypes.windowStore()); + + if (store == null) + return false; + + try (final KeyValueIterator, V> all = store.all()) { + final List, V>> storeContent = new LinkedList<>(); + while (all.hasNext()) { + storeContent.add(all.next()); + } + return storeContent.equals(expectedStoreContent); + } + } catch (final Exception swallow) { + swallow.printStackTrace(); + System.err.println(swallow.getMessage()); + return false; + } + }, + 60_000L, + "Could not get expected result in time."); + } + + private void verifyWindowedCountWithSurrogateTimestamp(final Windowed key, + final long value) throws Exception { + TestUtils.waitForCondition( + () -> { + try { + final ReadOnlyWindowStore> store = IntegrationTestUtils + .getStore(STORE_NAME, kafkaStreams, QueryableStoreTypes.timestampedWindowStore()); + + if (store == null) + return false; + + final ValueAndTimestamp count = store.fetch(key.key(), key.window().start()); + return count.value() == value && count.timestamp() == -1L; + } catch (final Exception swallow) { + swallow.printStackTrace(); + System.err.println(swallow.getMessage()); + return false; + } + }, + 60_000L, + "Could not get expected result in time."); + } + + private void verifyWindowedCountWithTimestamp(final Windowed key, + final long value, + final long timestamp) throws Exception { + TestUtils.waitForCondition( + () -> { + try { + final ReadOnlyWindowStore> store = IntegrationTestUtils + .getStore(STORE_NAME, kafkaStreams, QueryableStoreTypes.timestampedWindowStore()); + + if (store == null) + return false; + + final ValueAndTimestamp count = store.fetch(key.key(), key.window().start()); + return count.value() == value && count.timestamp() == timestamp; + } catch (final Exception swallow) { + swallow.printStackTrace(); + System.err.println(swallow.getMessage()); + return false; + } + }, + 60_000L, + "Could not get expected result in time."); + } + + private void processKeyValueAndVerifyWindowedCountWithTimestamp(final K key, + final long timestamp, + final List, Object>> expectedStoreContent) + throws Exception { + + IntegrationTestUtils.produceKeyValuesSynchronouslyWithTimestamp( + inputStream, + singletonList(KeyValue.pair(key, 0)), + 
TestUtils.producerConfig(CLUSTER.bootstrapServers(), + IntegerSerializer.class, + IntegerSerializer.class), + timestamp); + + TestUtils.waitForCondition( + () -> { + try { + final ReadOnlyWindowStore> store = IntegrationTestUtils + .getStore(STORE_NAME, kafkaStreams, QueryableStoreTypes.timestampedWindowStore()); + + if (store == null) + return false; + + try (final KeyValueIterator, ValueAndTimestamp> all = store.all()) { + final List, ValueAndTimestamp>> storeContent = new LinkedList<>(); + while (all.hasNext()) { + storeContent.add(all.next()); + } + return storeContent.equals(expectedStoreContent); + } + } catch (final Exception swallow) { + swallow.printStackTrace(); + System.err.println(swallow.getMessage()); + return false; + } + }, + 60_000L, + "Could not get expected result in time."); + } + + @SuppressWarnings("deprecation") // Old PAPI. Needs to be migrated. + private static class KeyValueProcessor implements org.apache.kafka.streams.processor.Processor { + private KeyValueStore store; + + @SuppressWarnings("unchecked") + @Override + public void init(final ProcessorContext context) { + store = (KeyValueStore) context.getStateStore(STORE_NAME); + } + + @Override + public void process(final Integer key, final Integer value) { + final long newCount; + + final Long oldCount = store.get(key); + if (oldCount != null) { + newCount = oldCount + 1L; + } else { + newCount = 1L; + } + + store.put(key, newCount); + } + + @Override + public void close() {} + } + + @SuppressWarnings("deprecation") // Old PAPI. Needs to be migrated. + private static class TimestampedKeyValueProcessor implements org.apache.kafka.streams.processor.Processor { + private ProcessorContext context; + private TimestampedKeyValueStore store; + + @SuppressWarnings("unchecked") + @Override + public void init(final ProcessorContext context) { + this.context = context; + store = (TimestampedKeyValueStore) context.getStateStore(STORE_NAME); + } + + @Override + public void process(final Integer key, final Integer value) { + final long newCount; + + final ValueAndTimestamp oldCountWithTimestamp = store.get(key); + final long newTimestamp; + + if (oldCountWithTimestamp == null) { + newCount = 1L; + newTimestamp = context.timestamp(); + } else { + newCount = oldCountWithTimestamp.value() + 1L; + newTimestamp = Math.max(oldCountWithTimestamp.timestamp(), context.timestamp()); + } + + store.put(key, ValueAndTimestamp.make(newCount, newTimestamp)); + } + + @Override + public void close() {} + } + + @SuppressWarnings("deprecation") // Old PAPI. Needs to be migrated. + private static class WindowedProcessor implements org.apache.kafka.streams.processor.Processor { + private WindowStore store; + + @SuppressWarnings("unchecked") + @Override + public void init(final ProcessorContext context) { + store = (WindowStore) context.getStateStore(STORE_NAME); + } + + @Override + public void process(final Integer key, final Integer value) { + final long newCount; + + final Long oldCount = store.fetch(key, key < 10 ? 0L : 100000L); + if (oldCount != null) { + newCount = oldCount + 1L; + } else { + newCount = 1L; + } + + store.put(key, newCount, key < 10 ? 0L : 100000L); + } + + @Override + public void close() {} + } + + @SuppressWarnings("deprecation") // Old PAPI. Needs to be migrated. 
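+    // Windowed counterpart of TimestampedKeyValueProcessor: counts records per key in the timestamped window store and keeps the largest record timestamp seen so far, using window start 0L for keys below 10 and 100000L otherwise.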
+ private static class TimestampedWindowedProcessor implements org.apache.kafka.streams.processor.Processor { + private ProcessorContext context; + private TimestampedWindowStore store; + + @SuppressWarnings("unchecked") + @Override + public void init(final ProcessorContext context) { + this.context = context; + store = (TimestampedWindowStore) context.getStateStore(STORE_NAME); + } + + @Override + public void process(final Integer key, final Integer value) { + final long newCount; + + final ValueAndTimestamp oldCountWithTimestamp = store.fetch(key, key < 10 ? 0L : 100000L); + final long newTimestamp; + + if (oldCountWithTimestamp == null) { + newCount = 1L; + newTimestamp = context.timestamp(); + } else { + newCount = oldCountWithTimestamp.value() + 1L; + newTimestamp = Math.max(oldCountWithTimestamp.timestamp(), context.timestamp()); + } + + store.put(key, ValueAndTimestamp.make(newCount, newTimestamp), key < 10 ? 0L : 100000L); + } + + @Override + public void close() {} + } +} \ No newline at end of file diff --git a/streams/src/test/java/org/apache/kafka/streams/integration/StreamStreamJoinIntegrationTest.java b/streams/src/test/java/org/apache/kafka/streams/integration/StreamStreamJoinIntegrationTest.java new file mode 100644 index 0000000..9d2bd1e --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/integration/StreamStreamJoinIntegrationTest.java @@ -0,0 +1,436 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.integration; + +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.kstream.JoinWindows; +import org.apache.kafka.streams.kstream.KStream; +import org.apache.kafka.streams.test.TestRecord; +import org.apache.kafka.test.IntegrationTest; +import org.apache.kafka.test.MockMapper; + +import org.junit.Before; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +import static java.time.Duration.ofHours; +import static java.time.Duration.ofSeconds; + +/** + * Tests all available joins of Kafka Streams DSL. 
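+ * Covers inner, left, and outer stream-stream joins, their repartitioned variants, and a cascaded multi-join, each run with caching enabled and disabled.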
+ */ +@Category({IntegrationTest.class}) +@RunWith(value = Parameterized.class) +public class StreamStreamJoinIntegrationTest extends AbstractJoinIntegrationTest { + private KStream leftStream; + private KStream rightStream; + + public StreamStreamJoinIntegrationTest(final boolean cacheEnabled) { + super(cacheEnabled); + } + + @Before + public void prepareTopology() throws InterruptedException { + super.prepareEnvironment(); + + appID = "stream-stream-join-integration-test"; + + builder = new StreamsBuilder(); + leftStream = builder.stream(INPUT_TOPIC_LEFT); + rightStream = builder.stream(INPUT_TOPIC_RIGHT); + } + + @Test + public void testInner() { + STREAMS_CONFIG.put(StreamsConfig.APPLICATION_ID_CONFIG, appID + "-inner"); + + final List>> expectedResult = Arrays.asList( + null, + null, + null, + Collections.singletonList(new TestRecord<>(ANY_UNIQUE_KEY, "A-a", null, 4L)), + Collections.singletonList(new TestRecord<>(ANY_UNIQUE_KEY, "B-a", null, 5L)), + Arrays.asList( + new TestRecord<>(ANY_UNIQUE_KEY, "A-b", null, 6L), + new TestRecord<>(ANY_UNIQUE_KEY, "B-b", null, 6L)), + null, + null, + Arrays.asList( + new TestRecord<>(ANY_UNIQUE_KEY, "C-a", null, 9L), + new TestRecord<>(ANY_UNIQUE_KEY, "C-b", null, 9L)), + Arrays.asList( + new TestRecord<>(ANY_UNIQUE_KEY, "A-c", null, 10L), + new TestRecord<>(ANY_UNIQUE_KEY, "B-c", null, 10L), + new TestRecord<>(ANY_UNIQUE_KEY, "C-c", null, 10L)), + null, + null, + null, + Arrays.asList( + new TestRecord<>(ANY_UNIQUE_KEY, "A-d", null, 14L), + new TestRecord<>(ANY_UNIQUE_KEY, "B-d", null, 14L), + new TestRecord<>(ANY_UNIQUE_KEY, "C-d", null, 14L)), + Arrays.asList( + new TestRecord<>(ANY_UNIQUE_KEY, "D-a", null, 15L), + new TestRecord<>(ANY_UNIQUE_KEY, "D-b", null, 15L), + new TestRecord<>(ANY_UNIQUE_KEY, "D-c", null, 15L), + new TestRecord<>(ANY_UNIQUE_KEY, "D-d", null, 15L)) + ); + + leftStream.join( + rightStream, + valueJoiner, + JoinWindows.ofTimeDifferenceAndGrace(ofSeconds(10), ofHours(24)) + ).to(OUTPUT_TOPIC); + + runTestWithDriver(expectedResult); + } + + @Test + public void testInnerRepartitioned() { + STREAMS_CONFIG.put(StreamsConfig.APPLICATION_ID_CONFIG, appID + "-inner-repartitioned"); + + final List>> expectedResult = Arrays.asList( + null, + null, + null, + Collections.singletonList(new TestRecord<>(ANY_UNIQUE_KEY, "A-a", null, 4L)), + Collections.singletonList(new TestRecord<>(ANY_UNIQUE_KEY, "B-a", null, 5L)), + Arrays.asList( + new TestRecord<>(ANY_UNIQUE_KEY, "A-b", null, 6L), + new TestRecord<>(ANY_UNIQUE_KEY, "B-b", null, 6L)), + null, + null, + Arrays.asList( + new TestRecord<>(ANY_UNIQUE_KEY, "C-a", null, 9L), + new TestRecord<>(ANY_UNIQUE_KEY, "C-b", null, 9L)), + Arrays.asList( + new TestRecord<>(ANY_UNIQUE_KEY, "A-c", null, 10L), + new TestRecord<>(ANY_UNIQUE_KEY, "B-c", null, 10L), + new TestRecord<>(ANY_UNIQUE_KEY, "C-c", null, 10L)), + null, + null, + null, + Arrays.asList( + new TestRecord<>(ANY_UNIQUE_KEY, "A-d", null, 14L), + new TestRecord<>(ANY_UNIQUE_KEY, "B-d", null, 14L), + new TestRecord<>(ANY_UNIQUE_KEY, "C-d", null, 14L)), + Arrays.asList( + new TestRecord<>(ANY_UNIQUE_KEY, "D-a", null, 15L), + new TestRecord<>(ANY_UNIQUE_KEY, "D-b", null, 15L), + new TestRecord<>(ANY_UNIQUE_KEY, "D-c", null, 15L), + new TestRecord<>(ANY_UNIQUE_KEY, "D-d", null, 15L)) + ); + + leftStream.map(MockMapper.noOpKeyValueMapper()) + .join( + rightStream.flatMap(MockMapper.noOpFlatKeyValueMapper()) + .selectKey(MockMapper.selectKeyKeyValueMapper()), + valueJoiner, + JoinWindows.ofTimeDifferenceAndGrace(ofSeconds(10), ofHours(24)) + 
).to(OUTPUT_TOPIC); + + runTestWithDriver(expectedResult); + } + + @Test + public void testLeft() { + STREAMS_CONFIG.put(StreamsConfig.APPLICATION_ID_CONFIG, appID + "-left"); + + final List>> expectedResult = Arrays.asList( + null, + null, + null, + Collections.singletonList(new TestRecord<>(ANY_UNIQUE_KEY, "A-a", null, 4L)), + Collections.singletonList(new TestRecord<>(ANY_UNIQUE_KEY, "B-a", null, 5L)), + Arrays.asList( + new TestRecord<>(ANY_UNIQUE_KEY, "A-b", null, 6L), + new TestRecord<>(ANY_UNIQUE_KEY, "B-b", null, 6L)), + null, + null, + Arrays.asList( + new TestRecord<>(ANY_UNIQUE_KEY, "C-a", null, 9L), + new TestRecord<>(ANY_UNIQUE_KEY, "C-b", null, 9L)), + Arrays.asList( + new TestRecord<>(ANY_UNIQUE_KEY, "A-c", null, 10L), + new TestRecord<>(ANY_UNIQUE_KEY, "B-c", null, 10L), + new TestRecord<>(ANY_UNIQUE_KEY, "C-c", null, 10L)), + null, + null, + null, + Arrays.asList( + new TestRecord<>(ANY_UNIQUE_KEY, "A-d", null, 14L), + new TestRecord<>(ANY_UNIQUE_KEY, "B-d", null, 14L), + new TestRecord<>(ANY_UNIQUE_KEY, "C-d", null, 14L)), + Arrays.asList( + new TestRecord<>(ANY_UNIQUE_KEY, "D-a", null, 15L), + new TestRecord<>(ANY_UNIQUE_KEY, "D-b", null, 15L), + new TestRecord<>(ANY_UNIQUE_KEY, "D-c", null, 15L), + new TestRecord<>(ANY_UNIQUE_KEY, "D-d", null, 15L)) + ); + + leftStream.leftJoin( + rightStream, + valueJoiner, + JoinWindows.ofTimeDifferenceAndGrace(ofSeconds(10), ofHours(24)) + ).to(OUTPUT_TOPIC); + + runTestWithDriver(expectedResult); + } + + @Test + public void testLeftRepartitioned() { + STREAMS_CONFIG.put(StreamsConfig.APPLICATION_ID_CONFIG, appID + "-left-repartitioned"); + + final List>> expectedResult = Arrays.asList( + null, + null, + null, + Collections.singletonList(new TestRecord<>(ANY_UNIQUE_KEY, "A-a", null, 4L)), + Collections.singletonList(new TestRecord<>(ANY_UNIQUE_KEY, "B-a", null, 5L)), + Arrays.asList( + new TestRecord<>(ANY_UNIQUE_KEY, "A-b", null, 6L), + new TestRecord<>(ANY_UNIQUE_KEY, "B-b", null, 6L)), + null, + null, + Arrays.asList( + new TestRecord<>(ANY_UNIQUE_KEY, "C-a", null, 9L), + new TestRecord<>(ANY_UNIQUE_KEY, "C-b", null, 9L)), + Arrays.asList( + new TestRecord<>(ANY_UNIQUE_KEY, "A-c", null, 10L), + new TestRecord<>(ANY_UNIQUE_KEY, "B-c", null, 10L), + new TestRecord<>(ANY_UNIQUE_KEY, "C-c", null, 10L)), + null, + null, + null, + Arrays.asList( + new TestRecord<>(ANY_UNIQUE_KEY, "A-d", null, 14L), + new TestRecord<>(ANY_UNIQUE_KEY, "B-d", null, 14L), + new TestRecord<>(ANY_UNIQUE_KEY, "C-d", null, 14L)), + Arrays.asList( + new TestRecord<>(ANY_UNIQUE_KEY, "D-a", null, 15L), + new TestRecord<>(ANY_UNIQUE_KEY, "D-b", null, 15L), + new TestRecord<>(ANY_UNIQUE_KEY, "D-c", null, 15L), + new TestRecord<>(ANY_UNIQUE_KEY, "D-d", null, 15L)) + ); + + leftStream.map(MockMapper.noOpKeyValueMapper()) + .leftJoin( + rightStream.flatMap(MockMapper.noOpFlatKeyValueMapper()) + .selectKey(MockMapper.selectKeyKeyValueMapper()), + valueJoiner, + JoinWindows.ofTimeDifferenceAndGrace(ofSeconds(10), ofHours(24)) + ).to(OUTPUT_TOPIC); + + runTestWithDriver(expectedResult); + } + + @Test + public void testOuter() { + STREAMS_CONFIG.put(StreamsConfig.APPLICATION_ID_CONFIG, appID + "-outer"); + + final List>> expectedResult = Arrays.asList( + null, + null, + null, + Collections.singletonList(new TestRecord<>(ANY_UNIQUE_KEY, "A-a", null, 4L)), + Collections.singletonList(new TestRecord<>(ANY_UNIQUE_KEY, "B-a", null, 5L)), + Arrays.asList( + new TestRecord<>(ANY_UNIQUE_KEY, "A-b", null, 6L), + new TestRecord<>(ANY_UNIQUE_KEY, "B-b", null, 6L)), + null, + null, + 
Arrays.asList( + new TestRecord<>(ANY_UNIQUE_KEY, "C-a", null, 9L), + new TestRecord<>(ANY_UNIQUE_KEY, "C-b", null, 9L)), + Arrays.asList( + new TestRecord<>(ANY_UNIQUE_KEY, "A-c", null, 10L), + new TestRecord<>(ANY_UNIQUE_KEY, "B-c", null, 10L), + new TestRecord<>(ANY_UNIQUE_KEY, "C-c", null, 10L)), + null, + null, + null, + Arrays.asList( + new TestRecord<>(ANY_UNIQUE_KEY, "A-d", null, 14L), + new TestRecord<>(ANY_UNIQUE_KEY, "B-d", null, 14L), + new TestRecord<>(ANY_UNIQUE_KEY, "C-d", null, 14L)), + Arrays.asList( + new TestRecord<>(ANY_UNIQUE_KEY, "D-a", null, 15L), + new TestRecord<>(ANY_UNIQUE_KEY, "D-b", null, 15L), + new TestRecord<>(ANY_UNIQUE_KEY, "D-c", null, 15L), + new TestRecord<>(ANY_UNIQUE_KEY, "D-d", null, 15L)) + ); + + leftStream.outerJoin( + rightStream, + valueJoiner, + JoinWindows.ofTimeDifferenceAndGrace(ofSeconds(10), ofHours(24)) + ).to(OUTPUT_TOPIC); + + runTestWithDriver(expectedResult); + } + + @Test + public void testOuterRepartitioned() { + STREAMS_CONFIG.put(StreamsConfig.APPLICATION_ID_CONFIG, appID + "-outer"); + + final List>> expectedResult = Arrays.asList( + null, + null, + null, + Collections.singletonList(new TestRecord<>(ANY_UNIQUE_KEY, "A-a", null, 4L)), + Collections.singletonList(new TestRecord<>(ANY_UNIQUE_KEY, "B-a", null, 5L)), + Arrays.asList( + new TestRecord<>(ANY_UNIQUE_KEY, "A-b", null, 6L), + new TestRecord<>(ANY_UNIQUE_KEY, "B-b", null, 6L)), + null, + null, + Arrays.asList( + new TestRecord<>(ANY_UNIQUE_KEY, "C-a", null, 9L), + new TestRecord<>(ANY_UNIQUE_KEY, "C-b", null, 9L)), + Arrays.asList( + new TestRecord<>(ANY_UNIQUE_KEY, "A-c", null, 10L), + new TestRecord<>(ANY_UNIQUE_KEY, "B-c", null, 10L), + new TestRecord<>(ANY_UNIQUE_KEY, "C-c", null, 10L)), + null, + null, + null, + Arrays.asList( + new TestRecord<>(ANY_UNIQUE_KEY, "A-d", null, 14L), + new TestRecord<>(ANY_UNIQUE_KEY, "B-d", null, 14L), + new TestRecord<>(ANY_UNIQUE_KEY, "C-d", null, 14L)), + Arrays.asList( + new TestRecord<>(ANY_UNIQUE_KEY, "D-a", null, 15L), + new TestRecord<>(ANY_UNIQUE_KEY, "D-b", null, 15L), + new TestRecord<>(ANY_UNIQUE_KEY, "D-c", null, 15L), + new TestRecord<>(ANY_UNIQUE_KEY, "D-d", null, 15L)) + ); + + leftStream.map(MockMapper.noOpKeyValueMapper()) + .outerJoin( + rightStream.flatMap(MockMapper.noOpFlatKeyValueMapper()) + .selectKey(MockMapper.selectKeyKeyValueMapper()), + valueJoiner, + JoinWindows.ofTimeDifferenceAndGrace(ofSeconds(10), ofHours(24)) + ).to(OUTPUT_TOPIC); + + runTestWithDriver(expectedResult); + } + + @Test + public void testMultiInner() { + STREAMS_CONFIG.put(StreamsConfig.APPLICATION_ID_CONFIG, appID + "-multi-inner"); + + final List>> expectedResult = Arrays.asList( + null, + null, + null, + Collections.singletonList(new TestRecord<>(ANY_UNIQUE_KEY, "A-a-a", null, 4L)), + Collections.singletonList(new TestRecord<>(ANY_UNIQUE_KEY, "B-a-a", null, 5L)), + Arrays.asList( + new TestRecord<>(ANY_UNIQUE_KEY, "A-b-a", null, 6L), + new TestRecord<>(ANY_UNIQUE_KEY, "B-b-a", null, 6L), + new TestRecord<>(ANY_UNIQUE_KEY, "A-a-b", null, 6L), + new TestRecord<>(ANY_UNIQUE_KEY, "B-a-b", null, 6L), + new TestRecord<>(ANY_UNIQUE_KEY, "A-b-b", null, 6L), + new TestRecord<>(ANY_UNIQUE_KEY, "B-b-b", null, 6L)), + null, + null, + Arrays.asList( + new TestRecord<>(ANY_UNIQUE_KEY, "C-a-a", null, 9L), + new TestRecord<>(ANY_UNIQUE_KEY, "C-a-b", null, 9L), + new TestRecord<>(ANY_UNIQUE_KEY, "C-b-a", null, 9L), + new TestRecord<>(ANY_UNIQUE_KEY, "C-b-b", null, 9L)), + Arrays.asList( + new TestRecord<>(ANY_UNIQUE_KEY, "A-c-a", null, 10L), + new 
TestRecord<>(ANY_UNIQUE_KEY, "A-c-b", null, 10L), + new TestRecord<>(ANY_UNIQUE_KEY, "B-c-a", null, 10L), + new TestRecord<>(ANY_UNIQUE_KEY, "B-c-b", null, 10L), + new TestRecord<>(ANY_UNIQUE_KEY, "C-c-a", null, 10L), + new TestRecord<>(ANY_UNIQUE_KEY, "C-c-b", null, 10L), + new TestRecord<>(ANY_UNIQUE_KEY, "A-a-c", null, 10L), + new TestRecord<>(ANY_UNIQUE_KEY, "B-a-c", null, 10L), + new TestRecord<>(ANY_UNIQUE_KEY, "A-b-c", null, 10L), + new TestRecord<>(ANY_UNIQUE_KEY, "B-b-c", null, 10L), + new TestRecord<>(ANY_UNIQUE_KEY, "C-a-c", null, 10L), + new TestRecord<>(ANY_UNIQUE_KEY, "C-b-c", null, 10L), + new TestRecord<>(ANY_UNIQUE_KEY, "A-c-c", null, 10L), + new TestRecord<>(ANY_UNIQUE_KEY, "B-c-c", null, 10L), + new TestRecord<>(ANY_UNIQUE_KEY, "C-c-c", null, 10L)), + null, + null, + null, + Arrays.asList( + new TestRecord<>(ANY_UNIQUE_KEY, "A-d-a", null, 14L), + new TestRecord<>(ANY_UNIQUE_KEY, "A-d-b", null, 14L), + new TestRecord<>(ANY_UNIQUE_KEY, "A-d-c", null, 14L), + new TestRecord<>(ANY_UNIQUE_KEY, "B-d-a", null, 14L), + new TestRecord<>(ANY_UNIQUE_KEY, "B-d-b", null, 14L), + new TestRecord<>(ANY_UNIQUE_KEY, "B-d-c", null, 14L), + new TestRecord<>(ANY_UNIQUE_KEY, "C-d-a", null, 14L), + new TestRecord<>(ANY_UNIQUE_KEY, "C-d-b", null, 14L), + new TestRecord<>(ANY_UNIQUE_KEY, "C-d-c", null, 14L), + new TestRecord<>(ANY_UNIQUE_KEY, "A-a-d", null, 14L), + new TestRecord<>(ANY_UNIQUE_KEY, "B-a-d", null, 14L), + new TestRecord<>(ANY_UNIQUE_KEY, "A-b-d", null, 14L), + new TestRecord<>(ANY_UNIQUE_KEY, "B-b-d", null, 14L), + new TestRecord<>(ANY_UNIQUE_KEY, "C-a-d", null, 14L), + new TestRecord<>(ANY_UNIQUE_KEY, "C-b-d", null, 14L), + new TestRecord<>(ANY_UNIQUE_KEY, "A-c-d", null, 14L), + new TestRecord<>(ANY_UNIQUE_KEY, "B-c-d", null, 14L), + new TestRecord<>(ANY_UNIQUE_KEY, "C-c-d", null, 14L), + new TestRecord<>(ANY_UNIQUE_KEY, "A-d-d", null, 14L), + new TestRecord<>(ANY_UNIQUE_KEY, "B-d-d", null, 14L), + new TestRecord<>(ANY_UNIQUE_KEY, "C-d-d", null, 14L)), + Arrays.asList( + new TestRecord<>(ANY_UNIQUE_KEY, "D-a-a", null, 15L), + new TestRecord<>(ANY_UNIQUE_KEY, "D-a-b", null, 15L), + new TestRecord<>(ANY_UNIQUE_KEY, "D-a-c", null, 15L), + new TestRecord<>(ANY_UNIQUE_KEY, "D-a-d", null, 15L), + new TestRecord<>(ANY_UNIQUE_KEY, "D-b-a", null, 15L), + new TestRecord<>(ANY_UNIQUE_KEY, "D-b-b", null, 15L), + new TestRecord<>(ANY_UNIQUE_KEY, "D-b-c", null, 15L), + new TestRecord<>(ANY_UNIQUE_KEY, "D-b-d", null, 15L), + new TestRecord<>(ANY_UNIQUE_KEY, "D-c-a", null, 15L), + new TestRecord<>(ANY_UNIQUE_KEY, "D-c-b", null, 15L), + new TestRecord<>(ANY_UNIQUE_KEY, "D-c-c", null, 15L), + new TestRecord<>(ANY_UNIQUE_KEY, "D-c-d", null, 15L), + new TestRecord<>(ANY_UNIQUE_KEY, "D-d-a", null, 15L), + new TestRecord<>(ANY_UNIQUE_KEY, "D-d-b", null, 15L), + new TestRecord<>(ANY_UNIQUE_KEY, "D-d-c", null, 15L), + new TestRecord<>(ANY_UNIQUE_KEY, "D-d-d", null, 15L)) + ); + + leftStream.join( + rightStream, + valueJoiner, + JoinWindows.ofTimeDifferenceAndGrace(ofSeconds(10), ofHours(24)) + ).join( + rightStream, + valueJoiner, + JoinWindows.ofTimeDifferenceAndGrace(ofSeconds(10), ofHours(24)) + ).to(OUTPUT_TOPIC); + + runTestWithDriver(expectedResult); + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/integration/StreamTableJoinIntegrationTest.java b/streams/src/test/java/org/apache/kafka/streams/integration/StreamTableJoinIntegrationTest.java new file mode 100644 index 0000000..0f7e8aa --- /dev/null +++ 
b/streams/src/test/java/org/apache/kafka/streams/integration/StreamTableJoinIntegrationTest.java @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.integration; + +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.kstream.KStream; +import org.apache.kafka.streams.kstream.KTable; +import org.apache.kafka.streams.test.TestRecord; +import org.apache.kafka.test.IntegrationTest; +import org.junit.Before; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +/** + * Tests all available joins of Kafka Streams DSL. + */ +@Category({IntegrationTest.class}) +@RunWith(value = Parameterized.class) +public class StreamTableJoinIntegrationTest extends AbstractJoinIntegrationTest { + private KStream leftStream; + private KTable rightTable; + + public StreamTableJoinIntegrationTest(final boolean cacheEnabled) { + super(cacheEnabled); + } + + @Before + public void prepareTopology() throws InterruptedException { + super.prepareEnvironment(); + + appID = "stream-table-join-integration-test"; + + builder = new StreamsBuilder(); + rightTable = builder.table(INPUT_TOPIC_RIGHT); + leftStream = builder.stream(INPUT_TOPIC_LEFT); + } + + @Test + public void testInner() { + STREAMS_CONFIG.put(StreamsConfig.APPLICATION_ID_CONFIG, appID + "-inner"); + + final List>> expectedResult = Arrays.asList( + null, + null, + null, + null, + Collections.singletonList(new TestRecord<>(ANY_UNIQUE_KEY, "B-a", null, 5L)), + null, + null, + null, + null, + null, + null, + null, + null, + null, + Collections.singletonList(new TestRecord<>(ANY_UNIQUE_KEY, "D-d", null, 15L)) + ); + + leftStream.join(rightTable, valueJoiner).to(OUTPUT_TOPIC); + runTestWithDriver(expectedResult); + } + + @Test + public void testLeft() { + STREAMS_CONFIG.put(StreamsConfig.APPLICATION_ID_CONFIG, appID + "-left"); + + final List>> expectedResult = Arrays.asList( + null, + null, + Collections.singletonList(new TestRecord<>(ANY_UNIQUE_KEY, "A-null", null, 3L)), + null, + Collections.singletonList(new TestRecord<>(ANY_UNIQUE_KEY, "B-a", null, 5L)), + null, + null, + null, + Collections.singletonList(new TestRecord<>(ANY_UNIQUE_KEY, "C-null", null, 9L)), + null, + null, + null, + null, + null, + Collections.singletonList(new TestRecord<>(ANY_UNIQUE_KEY, "D-d", null, 15L)) + ); + + leftStream.leftJoin(rightTable, valueJoiner).to(OUTPUT_TOPIC); + + runTestWithDriver(expectedResult); + } +} diff --git 
a/streams/src/test/java/org/apache/kafka/streams/integration/StreamTableJoinTopologyOptimizationIntegrationTest.java b/streams/src/test/java/org/apache/kafka/streams/integration/StreamTableJoinTopologyOptimizationIntegrationTest.java new file mode 100644 index 0000000..512d1c1 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/integration/StreamTableJoinTopologyOptimizationIntegrationTest.java @@ -0,0 +1,268 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.integration; + +import org.apache.kafka.clients.admin.AdminClient; +import org.apache.kafka.clients.admin.AdminClientConfig; +import org.apache.kafka.clients.admin.TopicDescription; +import org.apache.kafka.clients.consumer.ConsumerConfig; +import org.apache.kafka.common.serialization.Deserializer; +import org.apache.kafka.common.serialization.IntegerDeserializer; +import org.apache.kafka.common.serialization.IntegerSerializer; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.serialization.StringDeserializer; +import org.apache.kafka.common.serialization.StringSerializer; +import org.apache.kafka.streams.KafkaStreams; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.integration.utils.EmbeddedKafkaCluster; +import org.apache.kafka.streams.integration.utils.IntegrationTestUtils; +import org.apache.kafka.streams.kstream.KStream; +import org.apache.kafka.streams.kstream.KTable; +import org.apache.kafka.streams.kstream.Materialized; +import org.apache.kafka.streams.kstream.Named; +import org.apache.kafka.test.IntegrationTest; +import org.apache.kafka.test.TestUtils; +import org.junit.After; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Rule; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.rules.TestName; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.List; +import java.util.Properties; +import java.util.Set; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; + +import static org.apache.kafka.streams.integration.utils.IntegrationTestUtils.safeUniqueTestName; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +@RunWith(value = Parameterized.class) +@Category({IntegrationTest.class}) +public class StreamTableJoinTopologyOptimizationIntegrationTest { + private static final int NUM_BROKERS = 1; + + public 
static final EmbeddedKafkaCluster CLUSTER = new EmbeddedKafkaCluster(NUM_BROKERS); + + @BeforeClass + public static void startCluster() throws IOException { + CLUSTER.start(); + } + + @AfterClass + public static void closeCluster() { + CLUSTER.stop(); + } + + private String tableTopic; + private String inputTopic; + private String outputTopic; + private String applicationId; + private KafkaStreams kafkaStreams; + + private Properties streamsConfiguration; + + @Rule + public TestName testName = new TestName(); + + @Parameterized.Parameter + public String topologyOptimization; + + @Parameterized.Parameters(name = "Optimization = {0}") + public static Collection topologyOptimization() { + return Arrays.asList(new String[][]{ + {StreamsConfig.OPTIMIZE}, + {StreamsConfig.NO_OPTIMIZATION} + }); + } + + @Before + public void before() throws InterruptedException { + streamsConfiguration = new Properties(); + + final String safeTestName = safeUniqueTestName(getClass(), testName); + + tableTopic = "table-topic" + safeTestName; + inputTopic = "stream-topic-" + safeTestName; + outputTopic = "output-topic-" + safeTestName; + applicationId = "app-" + safeTestName; + + CLUSTER.createTopic(inputTopic, 4, 1); + CLUSTER.createTopic(tableTopic, 2, 1); + CLUSTER.createTopic(outputTopic, 4, 1); + + streamsConfiguration.put(StreamsConfig.APPLICATION_ID_CONFIG, applicationId); + streamsConfiguration.put(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, CLUSTER.bootstrapServers()); + streamsConfiguration.put(StreamsConfig.STATE_DIR_CONFIG, TestUtils.tempDirectory().getPath()); + streamsConfiguration.put(StreamsConfig.CACHE_MAX_BYTES_BUFFERING_CONFIG, 0); + streamsConfiguration.put(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG, 100L); + streamsConfiguration.put(StreamsConfig.DEFAULT_KEY_SERDE_CLASS_CONFIG, Serdes.Integer().getClass()); + streamsConfiguration.put(StreamsConfig.DEFAULT_VALUE_SERDE_CLASS_CONFIG, Serdes.String().getClass()); + streamsConfiguration.put(StreamsConfig.TOPOLOGY_OPTIMIZATION_CONFIG, topologyOptimization); + } + + @After + public void whenShuttingDown() throws IOException { + if (kafkaStreams != null) { + kafkaStreams.close(); + } + IntegrationTestUtils.purgeLocalStreamsState(streamsConfiguration); + } + + @Test + public void shouldDoStreamTableJoinWithDifferentNumberOfPartitions() throws Exception { + final String storeName = "store"; + final String selectKeyName = "selectKey"; + + final StreamsBuilder streamsBuilder = new StreamsBuilder(); + + final KStream stream = streamsBuilder.stream(inputTopic); + final KTable table = streamsBuilder.table(tableTopic, Materialized.as(storeName)); + + stream + .selectKey((key, value) -> key, Named.as(selectKeyName)) + .join(table, (value1, value2) -> value2) + .to(outputTopic); + + kafkaStreams = startStreams(streamsBuilder); + + final long timestamp = System.currentTimeMillis(); + + final List> expectedRecords = Arrays.asList( + new KeyValue<>(1, "A"), + new KeyValue<>(2, "B") + ); + + sendEvents(inputTopic, timestamp, expectedRecords); + sendEvents(outputTopic, timestamp, expectedRecords); + + validateReceivedMessages( + outputTopic, + new IntegerDeserializer(), + new StringDeserializer(), + expectedRecords + ); + + final Set allTopicsInCluster = CLUSTER.getAllTopicsInCluster(); + + final String repartitionTopicName = applicationId + "-" + selectKeyName + "-repartition"; + final String tableChangelogStoreName = applicationId + "-" + storeName + "-changelog"; + + assertTrue(topicExists(repartitionTopicName)); + assertEquals(2, 
getNumberOfPartitionsForTopic(repartitionTopicName)); + + if (StreamsConfig.OPTIMIZE.equals(topologyOptimization)) { + assertFalse(allTopicsInCluster.contains(tableChangelogStoreName)); + } else if (StreamsConfig.NO_OPTIMIZATION.equals(topologyOptimization)) { + assertTrue(allTopicsInCluster.contains(tableChangelogStoreName)); + } + } + + private KafkaStreams startStreams(final StreamsBuilder builder) throws InterruptedException { + final CountDownLatch latch = new CountDownLatch(1); + final KafkaStreams kafkaStreams = new KafkaStreams(builder.build(streamsConfiguration), streamsConfiguration); + + kafkaStreams.setStateListener((newState, oldState) -> { + if (KafkaStreams.State.REBALANCING == oldState && KafkaStreams.State.RUNNING == newState) { + latch.countDown(); + } + }); + + kafkaStreams.start(); + + latch.await(IntegrationTestUtils.DEFAULT_TIMEOUT, TimeUnit.MILLISECONDS); + + return kafkaStreams; + } + + private int getNumberOfPartitionsForTopic(final String topic) throws Exception { + try (final AdminClient adminClient = createAdminClient()) { + final TopicDescription topicDescription = adminClient.describeTopics(Collections.singleton(topic)) + .topicNameValues() + .get(topic) + .get(IntegrationTestUtils.DEFAULT_TIMEOUT, TimeUnit.MILLISECONDS); + + return topicDescription.partitions().size(); + } + } + + private boolean topicExists(final String topic) { + return CLUSTER.getAllTopicsInCluster().contains(topic); + } + + private void sendEvents(final String topic, + final long timestamp, + final List> events) throws Exception { + IntegrationTestUtils.produceKeyValuesSynchronouslyWithTimestamp( + topic, + events, + TestUtils.producerConfig( + CLUSTER.bootstrapServers(), + IntegerSerializer.class, + StringSerializer.class, + new Properties() + ), + timestamp + ); + } + + private void validateReceivedMessages(final String topic, + final Deserializer keySerializer, + final Deserializer valueSerializer, + final List> expectedRecords) throws Exception { + + final String safeTestName = safeUniqueTestName(getClass(), testName); + final Properties consumerProperties = new Properties(); + consumerProperties.setProperty(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, CLUSTER.bootstrapServers()); + consumerProperties.setProperty(ConsumerConfig.GROUP_ID_CONFIG, "group-" + safeTestName); + consumerProperties.setProperty(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "earliest"); + consumerProperties.setProperty( + ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, + keySerializer.getClass().getName() + ); + consumerProperties.setProperty( + ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, + valueSerializer.getClass().getName() + ); + + IntegrationTestUtils.waitUntilFinalKeyValueRecordsReceived( + consumerProperties, + topic, + expectedRecords + ); + } + + private static AdminClient createAdminClient() { + final Properties properties = new Properties(); + properties.put(AdminClientConfig.BOOTSTRAP_SERVERS_CONFIG, CLUSTER.bootstrapServers()); + + return AdminClient.create(properties); + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/integration/StreamsUncaughtExceptionHandlerIntegrationTest.java b/streams/src/test/java/org/apache/kafka/streams/integration/StreamsUncaughtExceptionHandlerIntegrationTest.java new file mode 100644 index 0000000..74d6ba9 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/integration/StreamsUncaughtExceptionHandlerIntegrationTest.java @@ -0,0 +1,340 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license 
agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.streams.integration; + +import org.apache.kafka.clients.consumer.ConsumerConfig; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.serialization.StringSerializer; +import org.apache.kafka.streams.KafkaStreams; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.Topology; +import org.apache.kafka.streams.errors.StreamsException; +import org.apache.kafka.streams.integration.utils.EmbeddedKafkaCluster; +import org.apache.kafka.streams.integration.utils.IntegrationTestUtils; +import org.apache.kafka.streams.kstream.Consumed; +import org.apache.kafka.streams.kstream.KStream; +import org.apache.kafka.streams.kstream.Named; +import org.apache.kafka.streams.state.Stores; +import org.apache.kafka.streams.state.internals.KeyValueStoreBuilder; +import org.apache.kafka.test.IntegrationTest; +import org.apache.kafka.test.StreamsTestUtils; +import org.apache.kafka.test.TestUtils; +import org.junit.After; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Rule; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.rules.TestName; + +import java.io.IOException; +import java.time.Duration; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Properties; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; + +import static org.apache.kafka.common.utils.Utils.mkEntry; +import static org.apache.kafka.common.utils.Utils.mkMap; +import static org.apache.kafka.common.utils.Utils.mkObjectProperties; +import static org.apache.kafka.streams.errors.StreamsUncaughtExceptionHandler.StreamThreadExceptionResponse.REPLACE_THREAD; +import static org.apache.kafka.streams.errors.StreamsUncaughtExceptionHandler.StreamThreadExceptionResponse.SHUTDOWN_APPLICATION; +import static org.apache.kafka.streams.errors.StreamsUncaughtExceptionHandler.StreamThreadExceptionResponse.SHUTDOWN_CLIENT; +import static org.apache.kafka.streams.integration.utils.IntegrationTestUtils.purgeLocalStreamsState; +import static org.apache.kafka.streams.integration.utils.IntegrationTestUtils.safeUniqueTestName; +import static org.apache.kafka.streams.integration.utils.IntegrationTestUtils.waitForApplicationState; +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.Assert.fail; + +@Category(IntegrationTest.class) +@SuppressWarnings("deprecation") //Need to call the old handler, will remove those calls when the old handler is removed +public class 
StreamsUncaughtExceptionHandlerIntegrationTest { + + public static final EmbeddedKafkaCluster CLUSTER = new EmbeddedKafkaCluster(1, new Properties(), 0L, 0L); + + @BeforeClass + public static void startCluster() throws IOException { + CLUSTER.start(); + } + + @AfterClass + public static void closeCluster() { + CLUSTER.stop(); + } + public static final Duration DEFAULT_DURATION = Duration.ofSeconds(30); + + @Rule + public TestName testName = new TestName(); + + private static String inputTopic; + private static StreamsBuilder builder; + private static Properties properties; + private static List processorValueCollector; + private static String appId = ""; + private static final AtomicBoolean THROW_ERROR = new AtomicBoolean(true); + private static final AtomicBoolean THROW_ILLEGAL_STATE_EXCEPTION = new AtomicBoolean(false); + private static final AtomicBoolean THROW_ILLEGAL_ARGUMENT_EXCEPTION = new AtomicBoolean(false); + + @Before + public void setup() { + final String testId = safeUniqueTestName(getClass(), testName); + appId = "appId_" + testId; + inputTopic = "input" + testId; + IntegrationTestUtils.cleanStateBeforeTest(CLUSTER, inputTopic); + + builder = new StreamsBuilder(); + + processorValueCollector = new ArrayList<>(); + + final KStream stream = builder.stream(inputTopic); + stream.process(() -> new ShutdownProcessor(processorValueCollector), Named.as("process")); + + properties = mkObjectProperties( + mkMap( + mkEntry(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, CLUSTER.bootstrapServers()), + mkEntry(StreamsConfig.APPLICATION_ID_CONFIG, appId), + mkEntry(StreamsConfig.STATE_DIR_CONFIG, TestUtils.tempDirectory().getPath()), + mkEntry(StreamsConfig.NUM_STREAM_THREADS_CONFIG, 2), + mkEntry(StreamsConfig.DEFAULT_KEY_SERDE_CLASS_CONFIG, Serdes.StringSerde.class), + mkEntry(StreamsConfig.DEFAULT_VALUE_SERDE_CLASS_CONFIG, Serdes.StringSerde.class), + mkEntry(StreamsConfig.consumerPrefix(ConsumerConfig.SESSION_TIMEOUT_MS_CONFIG), 10000) + ) + ); + } + + @After + public void teardown() throws IOException { + purgeLocalStreamsState(properties); + } + + @Test + public void shouldShutdownThreadUsingOldHandler() throws InterruptedException { + try (final KafkaStreams kafkaStreams = new KafkaStreams(builder.build(), properties)) { + final AtomicInteger counter = new AtomicInteger(0); + kafkaStreams.setUncaughtExceptionHandler((t, e) -> counter.incrementAndGet()); + + StreamsTestUtils.startKafkaStreamsAndWaitForRunningState(kafkaStreams); + produceMessages(0L, inputTopic, "A"); + + // should call the UncaughtExceptionHandler in current thread + TestUtils.waitForCondition(() -> counter.get() == 1, "Handler was called 1st time"); + // should call the UncaughtExceptionHandler after rebalancing to another thread + TestUtils.waitForCondition(() -> counter.get() == 2, DEFAULT_DURATION.toMillis(), "Handler was called 2nd time"); + // there is no threads running but the client is still in running + waitForApplicationState(Collections.singletonList(kafkaStreams), KafkaStreams.State.RUNNING, DEFAULT_DURATION); + + assertThat(processorValueCollector.size(), equalTo(2)); + } + } + + @Test + public void shouldShutdownClient() throws InterruptedException { + try (final KafkaStreams kafkaStreams = new KafkaStreams(builder.build(), properties)) { + kafkaStreams.setUncaughtExceptionHandler((t, e) -> fail("should not hit old handler")); + + kafkaStreams.setUncaughtExceptionHandler(exception -> SHUTDOWN_CLIENT); + + StreamsTestUtils.startKafkaStreamsAndWaitForRunningState(kafkaStreams); + + produceMessages(0L, 
inputTopic, "A"); + waitForApplicationState(Collections.singletonList(kafkaStreams), KafkaStreams.State.ERROR, DEFAULT_DURATION); + + assertThat(processorValueCollector.size(), equalTo(1)); + } + } + + + @Test + public void shouldShutdownClientWhenIllegalStateException() throws InterruptedException { + THROW_ILLEGAL_STATE_EXCEPTION.compareAndSet(false, true); + try (final KafkaStreams kafkaStreams = new KafkaStreams(builder.build(), properties)) { + kafkaStreams.setUncaughtExceptionHandler((t, e) -> fail("should not hit old handler")); + + kafkaStreams.setUncaughtExceptionHandler(exception -> REPLACE_THREAD); // if the user defined uncaught exception handler would be hit we would be replacing the thread + + StreamsTestUtils.startKafkaStreamsAndWaitForRunningState(kafkaStreams); + + produceMessages(0L, inputTopic, "A"); + waitForApplicationState(Collections.singletonList(kafkaStreams), KafkaStreams.State.ERROR, DEFAULT_DURATION); + + assertThat(processorValueCollector.size(), equalTo(1)); + } finally { + THROW_ILLEGAL_STATE_EXCEPTION.compareAndSet(true, false); + } + + } + + @Test + public void shouldShutdownClientWhenIllegalArgumentException() throws InterruptedException { + THROW_ILLEGAL_ARGUMENT_EXCEPTION.compareAndSet(false, true); + try (final KafkaStreams kafkaStreams = new KafkaStreams(builder.build(), properties)) { + kafkaStreams.setUncaughtExceptionHandler((t, e) -> fail("should not hit old handler")); + + kafkaStreams.setUncaughtExceptionHandler(exception -> REPLACE_THREAD); // if the user defined uncaught exception handler would be hit we would be replacing the thread + + StreamsTestUtils.startKafkaStreamsAndWaitForRunningState(kafkaStreams); + + produceMessages(0L, inputTopic, "A"); + waitForApplicationState(Collections.singletonList(kafkaStreams), KafkaStreams.State.ERROR, DEFAULT_DURATION); + + assertThat(processorValueCollector.size(), equalTo(1)); + } finally { + THROW_ILLEGAL_ARGUMENT_EXCEPTION.compareAndSet(true, false); + } + + } + + @Test + public void shouldReplaceThreads() throws InterruptedException { + testReplaceThreads(2); + } + + @Test + public void shouldReplaceSingleThread() throws InterruptedException { + testReplaceThreads(1); + } + + @Test + public void shouldShutdownMultipleThreadApplication() throws InterruptedException { + testShutdownApplication(2); + } + + @Test + public void shouldShutdownSingleThreadApplication() throws InterruptedException { + testShutdownApplication(1); + } + + @Test + public void shouldShutDownClientIfGlobalStreamThreadWantsToReplaceThread() throws InterruptedException { + builder = new StreamsBuilder(); + builder.addGlobalStore( + new KeyValueStoreBuilder<>( + Stores.persistentKeyValueStore("globalStore"), + Serdes.String(), + Serdes.String(), + CLUSTER.time + ), + inputTopic, + Consumed.with(Serdes.String(), Serdes.String()), + () -> new ShutdownProcessor(processorValueCollector) + ); + properties.put(StreamsConfig.NUM_STREAM_THREADS_CONFIG, 0); + + try (final KafkaStreams kafkaStreams = new KafkaStreams(builder.build(), properties)) { + kafkaStreams.setUncaughtExceptionHandler((t, e) -> fail("should not hit old handler")); + kafkaStreams.setUncaughtExceptionHandler(exception -> REPLACE_THREAD); + + StreamsTestUtils.startKafkaStreamsAndWaitForRunningState(kafkaStreams); + + produceMessages(0L, inputTopic, "A"); + waitForApplicationState(Collections.singletonList(kafkaStreams), KafkaStreams.State.ERROR, DEFAULT_DURATION); + + assertThat(processorValueCollector.size(), equalTo(1)); + } + + } + + private void produceMessages(final 
long timestamp, final String streamOneInput, final String msg) { + IntegrationTestUtils.produceKeyValuesSynchronouslyWithTimestamp( + streamOneInput, + Collections.singletonList(new KeyValue<>("1", msg)), + TestUtils.producerConfig( + CLUSTER.bootstrapServers(), + StringSerializer.class, + StringSerializer.class, + new Properties()), + timestamp); + } + + private static class ShutdownProcessor extends org.apache.kafka.streams.processor.AbstractProcessor { + final List valueList; + + ShutdownProcessor(final List valueList) { + this.valueList = valueList; + } + + @Override + public void process(final String key, final String value) { + valueList.add(value + " " + context.taskId()); + if (THROW_ERROR.get()) { + if (THROW_ILLEGAL_STATE_EXCEPTION.get()) { + throw new IllegalStateException("Something unexpected happened in " + Thread.currentThread().getName()); + } else if (THROW_ILLEGAL_ARGUMENT_EXCEPTION.get()) { + throw new IllegalArgumentException("Something unexpected happened in " + Thread.currentThread().getName()); + } else { + throw new StreamsException(Thread.currentThread().getName()); + } + } + THROW_ERROR.set(true); + } + } + + private void testShutdownApplication(final int numThreads) throws InterruptedException { + properties.put(StreamsConfig.NUM_STREAM_THREADS_CONFIG, numThreads); + + final Topology topology = builder.build(); + + try (final KafkaStreams kafkaStreams1 = new KafkaStreams(topology, properties); + final KafkaStreams kafkaStreams2 = new KafkaStreams(topology, properties)) { + kafkaStreams1.setUncaughtExceptionHandler((t, e) -> fail("should not hit old handler")); + kafkaStreams2.setUncaughtExceptionHandler((t, e) -> fail("should not hit old handler")); + kafkaStreams1.setUncaughtExceptionHandler(exception -> SHUTDOWN_APPLICATION); + kafkaStreams2.setUncaughtExceptionHandler(exception -> SHUTDOWN_APPLICATION); + + StreamsTestUtils.startKafkaStreamsAndWaitForRunningState(kafkaStreams1); + StreamsTestUtils.startKafkaStreamsAndWaitForRunningState(kafkaStreams2); + + produceMessages(0L, inputTopic, "A"); + waitForApplicationState(Arrays.asList(kafkaStreams1, kafkaStreams2), KafkaStreams.State.ERROR, DEFAULT_DURATION); + + assertThat(processorValueCollector.size(), equalTo(1)); + } + } + + private void testReplaceThreads(final int numThreads) throws InterruptedException { + properties.put(StreamsConfig.NUM_STREAM_THREADS_CONFIG, numThreads); + try (final KafkaStreams kafkaStreams = new KafkaStreams(builder.build(), properties)) { + kafkaStreams.setUncaughtExceptionHandler((t, e) -> fail("should not hit old handler")); + + final AtomicInteger count = new AtomicInteger(); + kafkaStreams.setUncaughtExceptionHandler(exception -> { + if (count.incrementAndGet() == numThreads) { + THROW_ERROR.set(false); + } + return REPLACE_THREAD; + }); + StreamsTestUtils.startKafkaStreamsAndWaitForRunningState(kafkaStreams); + + produceMessages(0L, inputTopic, "A"); + TestUtils.waitForCondition(() -> count.get() == numThreads, "finished replacing threads"); + TestUtils.waitForCondition(() -> THROW_ERROR.get(), "finished replacing threads"); + kafkaStreams.close(); + waitForApplicationState(Collections.singletonList(kafkaStreams), KafkaStreams.State.NOT_RUNNING, DEFAULT_DURATION); + + assertThat("All initial threads have failed and the replacement thread had processed on record", + processorValueCollector.size(), equalTo(numThreads + 1)); + } + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/integration/StreamsUpgradeTestIntegrationTest.java 
b/streams/src/test/java/org/apache/kafka/streams/integration/StreamsUpgradeTestIntegrationTest.java new file mode 100644 index 0000000..4285530 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/integration/StreamsUpgradeTestIntegrationTest.java @@ -0,0 +1,132 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.integration; + +import org.apache.kafka.streams.KafkaStreams; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.integration.utils.EmbeddedKafkaCluster; +import org.apache.kafka.streams.integration.utils.IntegrationTestUtils; +import org.apache.kafka.streams.tests.StreamsUpgradeTest; +import org.apache.kafka.test.IntegrationTest; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Test; +import org.junit.experimental.categories.Category; + +import java.io.IOException; +import java.time.Duration; +import java.util.Properties; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.atomic.AtomicInteger; + +import static org.apache.kafka.common.utils.Utils.mkEntry; +import static org.apache.kafka.common.utils.Utils.mkMap; +import static org.apache.kafka.common.utils.Utils.mkProperties; +import static org.apache.kafka.streams.processor.internals.assignment.StreamsAssignmentProtocolVersions.LATEST_SUPPORTED_VERSION; + +import static org.apache.kafka.test.TestUtils.retryOnExceptionWithTimeout; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.MatcherAssert.assertThat; + +@Category(IntegrationTest.class) +public class StreamsUpgradeTestIntegrationTest { + + public static final EmbeddedKafkaCluster CLUSTER = new EmbeddedKafkaCluster(3); + + @BeforeClass + public static void startCluster() throws IOException { + CLUSTER.start(); + IntegrationTestUtils.cleanStateBeforeTest(CLUSTER, 1, "data"); + } + + @AfterClass + public static void closeCluster() { + CLUSTER.stop(); + } + + @Test + public void testVersionProbingUpgrade() throws InterruptedException { + final KafkaStreams kafkaStreams1 = StreamsUpgradeTest.buildStreams(mkProperties( + mkMap( + mkEntry(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, CLUSTER.bootstrapServers()) + ) + )); + final KafkaStreams kafkaStreams2 = StreamsUpgradeTest.buildStreams(mkProperties( + mkMap( + mkEntry(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, CLUSTER.bootstrapServers()) + ) + )); + final KafkaStreams kafkaStreams3 = StreamsUpgradeTest.buildStreams(mkProperties( + mkMap( + mkEntry(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, CLUSTER.bootstrapServers()) + ) + )); + startSync(kafkaStreams1, kafkaStreams2, kafkaStreams3); + + // first roll + kafkaStreams1.close(); + final AtomicInteger usedVersion4 = new AtomicInteger(); + final KafkaStreams kafkaStreams4 = 
buildFutureStreams(usedVersion4); + startSync(kafkaStreams4); + assertThat(usedVersion4.get(), is(LATEST_SUPPORTED_VERSION)); + + // second roll + kafkaStreams2.close(); + final AtomicInteger usedVersion5 = new AtomicInteger(); + final KafkaStreams kafkaStreams5 = buildFutureStreams(usedVersion5); + startSync(kafkaStreams5); + assertThat(usedVersion5.get(), is(LATEST_SUPPORTED_VERSION)); + + // third roll, upgrade complete + kafkaStreams3.close(); + final AtomicInteger usedVersion6 = new AtomicInteger(); + final KafkaStreams kafkaStreams6 = buildFutureStreams(usedVersion6); + startSync(kafkaStreams6); + retryOnExceptionWithTimeout(() -> assertThat(usedVersion6.get(), is(LATEST_SUPPORTED_VERSION + 1))); + retryOnExceptionWithTimeout(() -> assertThat(usedVersion5.get(), is(LATEST_SUPPORTED_VERSION + 1))); + retryOnExceptionWithTimeout(() -> assertThat(usedVersion4.get(), is(LATEST_SUPPORTED_VERSION + 1))); + + kafkaStreams4.close(Duration.ZERO); + kafkaStreams5.close(Duration.ZERO); + kafkaStreams6.close(Duration.ZERO); + kafkaStreams4.close(); + kafkaStreams5.close(); + kafkaStreams6.close(); + } + + private static KafkaStreams buildFutureStreams(final AtomicInteger usedVersion4) { + final Properties properties = new Properties(); + properties.put(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, CLUSTER.bootstrapServers()); + properties.put("test.future.metadata", usedVersion4); + return StreamsUpgradeTest.buildStreams(properties); + } + + private static void startSync(final KafkaStreams... kafkaStreams) throws InterruptedException { + final CountDownLatch latch = new CountDownLatch(kafkaStreams.length); + for (final KafkaStreams streams : kafkaStreams) { + streams.setStateListener((newState, oldState) -> { + if (newState == KafkaStreams.State.RUNNING) { + latch.countDown(); + } + }); + } + for (final KafkaStreams streams : kafkaStreams) { + streams.start(); + } + latch.await(); + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/integration/SuppressionDurabilityIntegrationTest.java b/streams/src/test/java/org/apache/kafka/streams/integration/SuppressionDurabilityIntegrationTest.java new file mode 100644 index 0000000..6498592 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/integration/SuppressionDurabilityIntegrationTest.java @@ -0,0 +1,356 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.integration; + +import org.apache.kafka.clients.consumer.ConsumerConfig; +import org.apache.kafka.clients.producer.ProducerConfig; +import org.apache.kafka.common.serialization.Deserializer; +import org.apache.kafka.common.serialization.LongDeserializer; +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.serialization.Serializer; +import org.apache.kafka.common.serialization.StringDeserializer; +import org.apache.kafka.common.serialization.StringSerializer; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.streams.KafkaStreams; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.KeyValueTimestamp; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.integration.utils.EmbeddedKafkaCluster; +import org.apache.kafka.streams.integration.utils.IntegrationTestUtils; +import org.apache.kafka.streams.kstream.Consumed; +import org.apache.kafka.streams.kstream.KStream; +import org.apache.kafka.streams.kstream.KTable; +import org.apache.kafka.streams.kstream.Materialized; +import org.apache.kafka.streams.kstream.Produced; +import org.apache.kafka.streams.kstream.Transformer; +import org.apache.kafka.streams.kstream.TransformerSupplier; +import org.apache.kafka.streams.processor.ProcessorContext; +import org.apache.kafka.streams.state.KeyValueStore; +import org.apache.kafka.test.IntegrationTest; +import org.apache.kafka.test.TestUtils; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Rule; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.rules.TestName; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Collection; +import java.util.HashSet; +import java.util.List; +import java.util.Optional; +import java.util.Properties; +import java.util.Set; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; + +import static java.lang.Long.MAX_VALUE; +import static java.time.Duration.ofMillis; +import static java.util.Arrays.asList; +import static org.apache.kafka.common.utils.Utils.mkEntry; +import static org.apache.kafka.common.utils.Utils.mkMap; +import static org.apache.kafka.common.utils.Utils.mkProperties; +import static org.apache.kafka.streams.integration.utils.IntegrationTestUtils.cleanStateBeforeTest; +import static org.apache.kafka.streams.integration.utils.IntegrationTestUtils.getStartedStreams; +import static org.apache.kafka.streams.integration.utils.IntegrationTestUtils.quietlyCleanStateAfterTest; +import static org.apache.kafka.streams.integration.utils.IntegrationTestUtils.safeUniqueTestName; +import static org.apache.kafka.streams.kstream.Suppressed.BufferConfig.maxRecords; +import static org.apache.kafka.streams.kstream.Suppressed.untilTimeLimit; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.equalTo; + +@RunWith(Parameterized.class) +@Category({IntegrationTest.class}) +public class SuppressionDurabilityIntegrationTest { + + public static final EmbeddedKafkaCluster CLUSTER = new EmbeddedKafkaCluster( + 3, + mkProperties(mkMap()), + 0L + ); + + @BeforeClass + public static void 
startCluster() throws IOException { + CLUSTER.start(); + } + + @AfterClass + public static void closeCluster() { + CLUSTER.stop(); + } + + @Rule + public TestName testName = new TestName(); + + private static final StringDeserializer STRING_DESERIALIZER = new StringDeserializer(); + private static final StringSerializer STRING_SERIALIZER = new StringSerializer(); + private static final Serde<String> STRING_SERDE = Serdes.String(); + private static final LongDeserializer LONG_DESERIALIZER = new LongDeserializer(); + private static final long COMMIT_INTERVAL = 100L; + + @SuppressWarnings("deprecation") + @Parameterized.Parameters(name = "{0}") + public static Collection<String[]> data() { + return Arrays.asList(new String[][] { + {StreamsConfig.AT_LEAST_ONCE}, + {StreamsConfig.EXACTLY_ONCE}, + {StreamsConfig.EXACTLY_ONCE_V2} + }); + } + + @Parameterized.Parameter + public String processingGuarantee; + + @Test + public void shouldRecoverBufferAfterShutdown() { + final String testId = safeUniqueTestName(getClass(), testName); + final String appId = "appId_" + testId; + final String input = "input" + testId; + final String storeName = "counts"; + final String outputSuppressed = "output-suppressed" + testId; + final String outputRaw = "output-raw" + testId; + + // create multiple partitions as a trap, in case the buffer doesn't properly set the + // partition on the records, but instead relies on the default key partitioner + cleanStateBeforeTest(CLUSTER, 2, input, outputRaw, outputSuppressed); + + final StreamsBuilder builder = new StreamsBuilder(); + final KTable<String, Long> valueCounts = builder + .stream( + input, + Consumed.with(STRING_SERDE, STRING_SERDE)) + .groupByKey() + .count(Materialized.<String, Long, KeyValueStore<Bytes, byte[]>>as(storeName).withCachingDisabled()); + + final KStream<String, Long> suppressedCounts = valueCounts + .suppress(untilTimeLimit(ofMillis(MAX_VALUE), maxRecords(3L).emitEarlyWhenFull())) + .toStream(); + + final AtomicInteger eventCount = new AtomicInteger(0); + suppressedCounts.foreach((key, value) -> eventCount.incrementAndGet()); + + // expect all post-suppress records to keep the right input topic + final MetadataValidator metadataValidator = new MetadataValidator(input); + + suppressedCounts + .transform(metadataValidator) + .to(outputSuppressed, Produced.with(STRING_SERDE, Serdes.Long())); + + valueCounts + .toStream() + .transform(metadataValidator) + .to(outputRaw, Produced.with(STRING_SERDE, Serdes.Long())); + + final Properties streamsConfig = mkProperties(mkMap( + mkEntry(StreamsConfig.APPLICATION_ID_CONFIG, appId), + mkEntry(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, CLUSTER.bootstrapServers()), + mkEntry(StreamsConfig.POLL_MS_CONFIG, Long.toString(COMMIT_INTERVAL)), + mkEntry(StreamsConfig.PROCESSING_GUARANTEE_CONFIG, processingGuarantee), + mkEntry(StreamsConfig.STATE_DIR_CONFIG, TestUtils.tempDirectory().getPath()) + )); + + streamsConfig.put(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG, COMMIT_INTERVAL); + + KafkaStreams driver = getStartedStreams(streamsConfig, builder, true); + try { + // start by putting some records in the buffer + // note: we send all input records to partition 0 + // to make sure that suppress doesn't erroneously send records to other partitions.
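+ // With a buffer limit of maxRecords(3) and emitEarlyWhenFull(), the three distinct keys produced next should all sit in the suppression buffer without emitting anything downstream yet; the eventCount == 0 assertion that follows checks exactly that.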
+ produceSynchronouslyToPartitionZero( + input, + asList( + new KeyValueTimestamp<>("k1", "v1", scaledTime(1L)), + new KeyValueTimestamp<>("k2", "v2", scaledTime(2L)), + new KeyValueTimestamp<>("k3", "v3", scaledTime(3L)) + ) + ); + verifyOutput( + outputRaw, + new HashSet<>(asList( + new KeyValueTimestamp<>("k1", 1L, scaledTime(1L)), + new KeyValueTimestamp<>("k2", 1L, scaledTime(2L)), + new KeyValueTimestamp<>("k3", 1L, scaledTime(3L)) + )) + ); + assertThat(eventCount.get(), is(0)); + + // flush two of the first three events out. + produceSynchronouslyToPartitionZero( + input, + asList( + new KeyValueTimestamp<>("k4", "v4", scaledTime(4L)), + new KeyValueTimestamp<>("k5", "v5", scaledTime(5L)) + ) + ); + verifyOutput( + outputRaw, + new HashSet<>(asList( + new KeyValueTimestamp<>("k4", 1L, scaledTime(4L)), + new KeyValueTimestamp<>("k5", 1L, scaledTime(5L)) + )) + ); + assertThat(eventCount.get(), is(2)); + verifyOutput( + outputSuppressed, + asList( + new KeyValueTimestamp<>("k1", 1L, scaledTime(1L)), + new KeyValueTimestamp<>("k2", 1L, scaledTime(2L)) + ) + ); + + // bounce to ensure that the history, including retractions, + // get restored properly. (i.e., we shouldn't see those first events again) + + // restart the driver + driver.close(); + assertThat(driver.state(), is(KafkaStreams.State.NOT_RUNNING)); + driver = getStartedStreams(streamsConfig, builder, false); + + + // flush those recovered buffered events out. + produceSynchronouslyToPartitionZero( + input, + asList( + new KeyValueTimestamp<>("k6", "v6", scaledTime(6L)), + new KeyValueTimestamp<>("k7", "v7", scaledTime(7L)), + new KeyValueTimestamp<>("k8", "v8", scaledTime(8L)) + ) + ); + verifyOutput( + outputRaw, + new HashSet<>(asList( + new KeyValueTimestamp<>("k6", 1L, scaledTime(6L)), + new KeyValueTimestamp<>("k7", 1L, scaledTime(7L)), + new KeyValueTimestamp<>("k8", 1L, scaledTime(8L)) + )) + ); + assertThat("suppress has apparently produced some duplicates. 
There should only be 5 output events.", + eventCount.get(), is(5)); + + verifyOutput( + outputSuppressed, + asList( + new KeyValueTimestamp<>("k3", 1L, scaledTime(3L)), + new KeyValueTimestamp<>("k4", 1L, scaledTime(4L)), + new KeyValueTimestamp<>("k5", 1L, scaledTime(5L)) + ) + ); + + metadataValidator.raiseExceptionIfAny(); + + } finally { + driver.close(); + quietlyCleanStateAfterTest(CLUSTER, driver); + } + } + + private static final class MetadataValidator implements TransformerSupplier> { + private static final Logger LOG = LoggerFactory.getLogger(MetadataValidator.class); + private final AtomicReference firstException = new AtomicReference<>(); + private final String topic; + + public MetadataValidator(final String topic) { + this.topic = topic; + } + + @Override + public Transformer> get() { + return new Transformer>() { + private ProcessorContext context; + + @Override + public void init(final ProcessorContext context) { + this.context = context; + } + + @Override + public KeyValue transform(final String key, final Long value) { + try { + assertThat(context.topic(), equalTo(topic)); + } catch (final Throwable e) { + firstException.compareAndSet(null, e); + LOG.error("Validation Failed", e); + } + return new KeyValue<>(key, value); + } + + @Override + public void close() { + + } + }; + } + + void raiseExceptionIfAny() { + final Throwable exception = firstException.get(); + if (exception != null) { + throw new AssertionError("Got an exception during run", exception); + } + } + } + + private void verifyOutput(final String topic, final List> keyValueTimestamps) { + final Properties properties = mkProperties( + mkMap( + mkEntry(ConsumerConfig.GROUP_ID_CONFIG, "test-group"), + mkEntry(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, CLUSTER.bootstrapServers()), + mkEntry(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, ((Deserializer) STRING_DESERIALIZER).getClass().getName()), + mkEntry(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, ((Deserializer) LONG_DESERIALIZER).getClass().getName()) + ) + ); + IntegrationTestUtils.verifyKeyValueTimestamps(properties, topic, keyValueTimestamps); + } + + private void verifyOutput(final String topic, final Set> keyValueTimestamps) { + final Properties properties = mkProperties( + mkMap( + mkEntry(ConsumerConfig.GROUP_ID_CONFIG, "test-group"), + mkEntry(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, CLUSTER.bootstrapServers()), + mkEntry(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, ((Deserializer) STRING_DESERIALIZER).getClass().getName()), + mkEntry(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, ((Deserializer) LONG_DESERIALIZER).getClass().getName()) + ) + ); + IntegrationTestUtils.verifyKeyValueTimestamps(properties, topic, keyValueTimestamps); + } + + /** + * scaling to ensure that there are commits in between the various test events, + * just to exercise that everything works properly in the presence of commits. 
+ */ + private long scaledTime(final long unscaledTime) { + return COMMIT_INTERVAL * 2 * unscaledTime; + } + + private static void produceSynchronouslyToPartitionZero(final String topic, final List> toProduce) { + final Properties producerConfig = mkProperties(mkMap( + mkEntry(ProducerConfig.CLIENT_ID_CONFIG, "anything"), + mkEntry(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG, ((Serializer) STRING_SERIALIZER).getClass().getName()), + mkEntry(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG, ((Serializer) STRING_SERIALIZER).getClass().getName()), + mkEntry(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, CLUSTER.bootstrapServers()) + )); + IntegrationTestUtils.produceSynchronously(producerConfig, false, topic, Optional.of(0), toProduce); + } +} \ No newline at end of file diff --git a/streams/src/test/java/org/apache/kafka/streams/integration/SuppressionIntegrationTest.java b/streams/src/test/java/org/apache/kafka/streams/integration/SuppressionIntegrationTest.java new file mode 100644 index 0000000..71ef0e3 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/integration/SuppressionIntegrationTest.java @@ -0,0 +1,543 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.integration; + +import org.apache.kafka.clients.consumer.Consumer; +import org.apache.kafka.clients.consumer.ConsumerConfig; +import org.apache.kafka.clients.consumer.ConsumerRecords; +import org.apache.kafka.clients.consumer.KafkaConsumer; +import org.apache.kafka.clients.producer.ProducerConfig; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.serialization.Serializer; +import org.apache.kafka.common.serialization.StringDeserializer; +import org.apache.kafka.common.serialization.StringSerializer; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.streams.KafkaStreams; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.KeyValueTimestamp; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.integration.utils.EmbeddedKafkaCluster; +import org.apache.kafka.streams.integration.utils.IntegrationTestUtils; +import org.apache.kafka.streams.kstream.Consumed; +import org.apache.kafka.streams.kstream.Grouped; +import org.apache.kafka.streams.kstream.KStream; +import org.apache.kafka.streams.kstream.KTable; +import org.apache.kafka.streams.kstream.Materialized; +import org.apache.kafka.streams.kstream.Produced; +import org.apache.kafka.streams.state.KeyValueStore; +import org.apache.kafka.test.IntegrationTest; +import org.apache.kafka.test.TestUtils; +import org.hamcrest.Matchers; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Test; +import org.junit.experimental.categories.Category; + +import java.io.IOException; +import java.util.Collections; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Optional; +import java.util.Properties; +import java.util.Set; +import java.util.stream.Collectors; + +import static java.lang.Long.MAX_VALUE; +import static java.time.Duration.ofMillis; +import static java.util.Arrays.asList; +import static org.apache.kafka.common.utils.Utils.mkEntry; +import static org.apache.kafka.common.utils.Utils.mkMap; +import static org.apache.kafka.common.utils.Utils.mkProperties; +import static org.apache.kafka.streams.StreamsConfig.AT_LEAST_ONCE; +import static org.apache.kafka.streams.integration.utils.IntegrationTestUtils.DEFAULT_TIMEOUT; +import static org.apache.kafka.streams.integration.utils.IntegrationTestUtils.cleanStateBeforeTest; +import static org.apache.kafka.streams.integration.utils.IntegrationTestUtils.quietlyCleanStateAfterTest; +import static org.apache.kafka.streams.kstream.Suppressed.BufferConfig.maxBytes; +import static org.apache.kafka.streams.kstream.Suppressed.BufferConfig.maxRecords; +import static org.apache.kafka.streams.kstream.Suppressed.untilTimeLimit; +import static org.apache.kafka.test.TestUtils.waitForCondition; +import static org.hamcrest.CoreMatchers.hasItem; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.empty; + +@Category(IntegrationTest.class) +public class SuppressionIntegrationTest { + + public static final EmbeddedKafkaCluster CLUSTER = new EmbeddedKafkaCluster( + 1, + mkProperties(mkMap()), + 0L + ); + + @BeforeClass + public static void startCluster() throws IOException { + CLUSTER.start(); + } + + @AfterClass + public static void closeCluster() { + CLUSTER.stop(); + } + private static 
final StringSerializer STRING_SERIALIZER = new StringSerializer(); + private static final Serde STRING_SERDE = Serdes.String(); + private static final int COMMIT_INTERVAL = 100; + + private static KTable buildCountsTable(final String input, final StreamsBuilder builder) { + return builder + .table( + input, + Consumed.with(STRING_SERDE, STRING_SERDE), + Materialized.>with(STRING_SERDE, STRING_SERDE) + .withCachingDisabled() + .withLoggingDisabled() + ) + .groupBy((k, v) -> new KeyValue<>(v, k), Grouped.with(STRING_SERDE, STRING_SERDE)) + .count(Materialized.>as("counts").withCachingDisabled()); + } + + @Test + public void shouldUseDefaultSerdes() { + final String testId = "-shouldInheritSerdes"; + final String appId = getClass().getSimpleName().toLowerCase(Locale.getDefault()) + testId; + final String input = "input" + testId; + final String outputSuppressed = "output-suppressed" + testId; + final String outputRaw = "output-raw" + testId; + + cleanStateBeforeTest(CLUSTER, input, outputRaw, outputSuppressed); + + final StreamsBuilder builder = new StreamsBuilder(); + + final KStream inputStream = builder.stream(input); + + final KTable valueCounts = inputStream + .groupByKey() + .aggregate(() -> "()", (key, value, aggregate) -> aggregate + ",(" + key + ": " + value + ")"); + + valueCounts + .suppress(untilTimeLimit(ofMillis(MAX_VALUE), maxRecords(1L).emitEarlyWhenFull())) + .toStream() + .to(outputSuppressed); + + valueCounts + .toStream() + .to(outputRaw); + + final Properties streamsConfig = getStreamsConfig(appId); + streamsConfig.put(StreamsConfig.DEFAULT_KEY_SERDE_CLASS_CONFIG, Serdes.StringSerde.class); + streamsConfig.put(StreamsConfig.DEFAULT_VALUE_SERDE_CLASS_CONFIG, Serdes.StringSerde.class); + + final KafkaStreams driver = IntegrationTestUtils.getStartedStreams(streamsConfig, builder, true); + try { + produceSynchronously( + input, + asList( + new KeyValueTimestamp<>("k1", "v1", scaledTime(0L)), + new KeyValueTimestamp<>("k1", "v2", scaledTime(1L)), + new KeyValueTimestamp<>("k2", "v1", scaledTime(2L)), + new KeyValueTimestamp<>("x", "x", scaledTime(3L)) + ) + ); + final boolean rawRecords = waitForAnyRecord(outputRaw); + final boolean suppressedRecords = waitForAnyRecord(outputSuppressed); + assertThat(rawRecords, Matchers.is(true)); + assertThat(suppressedRecords, is(true)); + } finally { + driver.close(); + quietlyCleanStateAfterTest(CLUSTER, driver); + } + } + + @Test + public void shouldInheritSerdes() { + final String testId = "-shouldInheritSerdes"; + final String appId = getClass().getSimpleName().toLowerCase(Locale.getDefault()) + testId; + final String input = "input" + testId; + final String outputSuppressed = "output-suppressed" + testId; + final String outputRaw = "output-raw" + testId; + + cleanStateBeforeTest(CLUSTER, input, outputRaw, outputSuppressed); + + final StreamsBuilder builder = new StreamsBuilder(); + + final KStream inputStream = builder.stream(input); + + // count sets the serde to Long + final KTable valueCounts = inputStream + .groupByKey() + .count(); + + valueCounts + .suppress(untilTimeLimit(ofMillis(MAX_VALUE), maxRecords(1L).emitEarlyWhenFull())) + .toStream() + .to(outputSuppressed); + + valueCounts + .toStream() + .to(outputRaw); + + final Properties streamsConfig = getStreamsConfig(appId); + streamsConfig.put(StreamsConfig.DEFAULT_KEY_SERDE_CLASS_CONFIG, Serdes.StringSerde.class); + streamsConfig.put(StreamsConfig.DEFAULT_VALUE_SERDE_CLASS_CONFIG, Serdes.StringSerde.class); + + final KafkaStreams driver = 
IntegrationTestUtils.getStartedStreams(streamsConfig, builder, true); + try { + produceSynchronously( + input, + asList( + new KeyValueTimestamp<>("k1", "v1", scaledTime(0L)), + new KeyValueTimestamp<>("k1", "v2", scaledTime(1L)), + new KeyValueTimestamp<>("k2", "v1", scaledTime(2L)), + new KeyValueTimestamp<>("x", "x", scaledTime(3L)) + ) + ); + final boolean rawRecords = waitForAnyRecord(outputRaw); + final boolean suppressedRecords = waitForAnyRecord(outputSuppressed); + assertThat(rawRecords, Matchers.is(true)); + assertThat(suppressedRecords, is(true)); + } finally { + driver.close(); + quietlyCleanStateAfterTest(CLUSTER, driver); + } + } + + private static boolean waitForAnyRecord(final String topic) { + final Properties properties = new Properties(); + properties.put(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, CLUSTER.bootstrapServers()); + properties.put(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, StringDeserializer.class); + properties.put(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, StringDeserializer.class); + properties.put(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG, false); + + try (final Consumer consumer = new KafkaConsumer<>(properties)) { + final List partitions = + consumer.partitionsFor(topic) + .stream() + .map(pi -> new TopicPartition(pi.topic(), pi.partition())) + .collect(Collectors.toList()); + consumer.assign(partitions); + consumer.seekToBeginning(partitions); + final long start = System.currentTimeMillis(); + while ((System.currentTimeMillis() - start) < DEFAULT_TIMEOUT) { + final ConsumerRecords records = consumer.poll(ofMillis(500)); + + if (!records.isEmpty()) { + return true; + } + } + + return false; + } + } + + @Test + public void shouldShutdownWhenRecordConstraintIsViolated() throws InterruptedException { + final String testId = "-shouldShutdownWhenRecordConstraintIsViolated"; + final String appId = getClass().getSimpleName().toLowerCase(Locale.getDefault()) + testId; + final String input = "input" + testId; + final String outputSuppressed = "output-suppressed" + testId; + final String outputRaw = "output-raw" + testId; + + cleanStateBeforeTest(CLUSTER, input, outputRaw, outputSuppressed); + + final StreamsBuilder builder = new StreamsBuilder(); + final KTable valueCounts = buildCountsTable(input, builder); + + valueCounts + .suppress(untilTimeLimit(ofMillis(MAX_VALUE), maxRecords(1L).shutDownWhenFull())) + .toStream() + .to(outputSuppressed, Produced.with(STRING_SERDE, Serdes.Long())); + + valueCounts + .toStream() + .to(outputRaw, Produced.with(STRING_SERDE, Serdes.Long())); + + final Properties streamsConfig = getStreamsConfig(appId); + final KafkaStreams driver = IntegrationTestUtils.getStartedStreams(streamsConfig, builder, true); + try { + produceSynchronously( + input, + asList( + new KeyValueTimestamp<>("k1", "v1", scaledTime(0L)), + new KeyValueTimestamp<>("k1", "v2", scaledTime(1L)), + new KeyValueTimestamp<>("k2", "v1", scaledTime(2L)), + new KeyValueTimestamp<>("x", "x", scaledTime(3L)) + ) + ); + verifyErrorShutdown(driver); + } finally { + driver.close(); + quietlyCleanStateAfterTest(CLUSTER, driver); + } + } + + @Test + public void shouldShutdownWhenBytesConstraintIsViolated() throws InterruptedException { + final String testId = "-shouldShutdownWhenBytesConstraintIsViolated"; + final String appId = getClass().getSimpleName().toLowerCase(Locale.getDefault()) + testId; + final String input = "input" + testId; + final String outputSuppressed = "output-suppressed" + testId; + final String outputRaw = "output-raw" + testId; + + 
cleanStateBeforeTest(CLUSTER, input, outputRaw, outputSuppressed); + + final StreamsBuilder builder = new StreamsBuilder(); + final KTable valueCounts = buildCountsTable(input, builder); + + valueCounts + // this is a bit brittle, but I happen to know that the entries are a little over 100 bytes in size. + .suppress(untilTimeLimit(ofMillis(MAX_VALUE), maxBytes(200L).shutDownWhenFull())) + .toStream() + .to(outputSuppressed, Produced.with(STRING_SERDE, Serdes.Long())); + + valueCounts + .toStream() + .to(outputRaw, Produced.with(STRING_SERDE, Serdes.Long())); + + final Properties streamsConfig = getStreamsConfig(appId); + final KafkaStreams driver = IntegrationTestUtils.getStartedStreams(streamsConfig, builder, true); + try { + produceSynchronously( + input, + asList( + new KeyValueTimestamp<>("k1", "v1", scaledTime(0L)), + new KeyValueTimestamp<>("k1", "v2", scaledTime(1L)), + new KeyValueTimestamp<>("k2", "v1", scaledTime(2L)), + new KeyValueTimestamp<>("x", "x", scaledTime(3L)) + ) + ); + verifyErrorShutdown(driver); + } finally { + driver.close(); + quietlyCleanStateAfterTest(CLUSTER, driver); + } + } + + @Test + public void shouldAllowOverridingChangelogConfig() { + final String testId = "-shouldAllowOverridingChangelogConfig"; + final String appId = getClass().getSimpleName().toLowerCase(Locale.getDefault()) + testId; + final String input = "input" + testId; + final String outputSuppressed = "output-suppressed" + testId; + final String outputRaw = "output-raw" + testId; + final Map logConfig = Collections.singletonMap("retention.ms", "1000"); + final String changeLog = "suppressionintegrationtest-shouldAllowOverridingChangelogConfig-KTABLE-SUPPRESS-STATE-STORE-0000000004-changelog"; + + cleanStateBeforeTest(CLUSTER, input, outputRaw, outputSuppressed); + + final StreamsBuilder builder = new StreamsBuilder(); + + final KStream inputStream = builder.stream(input); + + final KTable valueCounts = inputStream + .groupByKey() + .aggregate(() -> "()", (key, value, aggregate) -> aggregate + ",(" + key + ": " + value + ")"); + + valueCounts + .suppress(untilTimeLimit(ofMillis(MAX_VALUE), maxRecords(1L) + .emitEarlyWhenFull() + .withLoggingEnabled(logConfig))) + .toStream() + .to(outputSuppressed); + + valueCounts + .toStream() + .to(outputRaw); + + final Properties streamsConfig = getStreamsConfig(appId); + streamsConfig.put(StreamsConfig.DEFAULT_KEY_SERDE_CLASS_CONFIG, Serdes.StringSerde.class); + streamsConfig.put(StreamsConfig.DEFAULT_VALUE_SERDE_CLASS_CONFIG, Serdes.StringSerde.class); + + final KafkaStreams driver = IntegrationTestUtils.getStartedStreams(streamsConfig, builder, true); + try { + produceSynchronously( + input, + asList( + new KeyValueTimestamp<>("k1", "v1", scaledTime(0L)), + new KeyValueTimestamp<>("k1", "v2", scaledTime(1L)), + new KeyValueTimestamp<>("k2", "v1", scaledTime(2L)), + new KeyValueTimestamp<>("x", "x", scaledTime(3L)) + ) + ); + final boolean rawRecords = waitForAnyRecord(outputRaw); + final boolean suppressedRecords = waitForAnyRecord(outputSuppressed); + final Properties config = CLUSTER.getLogConfig(changeLog); + + assertThat(config.getProperty("retention.ms"), is(logConfig.get("retention.ms"))); + assertThat(CLUSTER.getAllTopicsInCluster(), hasItem(changeLog)); + assertThat(rawRecords, Matchers.is(true)); + assertThat(suppressedRecords, is(true)); + } finally { + driver.close(); + quietlyCleanStateAfterTest(CLUSTER, driver); + } + } + + @Test + public void shouldCreateChangelogByDefault() { + final String testId = "-shouldCreateChangelogByDefault"; + 
final String appId = getClass().getSimpleName().toLowerCase(Locale.getDefault()) + testId; + final String input = "input" + testId; + final String outputSuppressed = "output-suppressed" + testId; + final String outputRaw = "output-raw" + testId; + final String changeLog = "suppressionintegrationtest-shouldCreateChangelogByDefault-KTABLE-SUPPRESS-STATE-STORE-0000000004-changelog"; + + cleanStateBeforeTest(CLUSTER, input, outputRaw, outputSuppressed); + + final StreamsBuilder builder = new StreamsBuilder(); + + final KStream inputStream = builder.stream(input); + + final KTable valueCounts = inputStream + .groupByKey() + .aggregate(() -> "()", (key, value, aggregate) -> aggregate + ",(" + key + ": " + value + ")"); + + valueCounts + .suppress(untilTimeLimit(ofMillis(MAX_VALUE), maxRecords(1L) + .emitEarlyWhenFull())) + .toStream() + .to(outputSuppressed); + + valueCounts + .toStream() + .to(outputRaw); + + final Properties streamsConfig = getStreamsConfig(appId); + streamsConfig.put(StreamsConfig.DEFAULT_KEY_SERDE_CLASS_CONFIG, Serdes.StringSerde.class); + streamsConfig.put(StreamsConfig.DEFAULT_VALUE_SERDE_CLASS_CONFIG, Serdes.StringSerde.class); + + final KafkaStreams driver = IntegrationTestUtils.getStartedStreams(streamsConfig, builder, true); + try { + produceSynchronously( + input, + asList( + new KeyValueTimestamp<>("k1", "v1", scaledTime(0L)), + new KeyValueTimestamp<>("k1", "v2", scaledTime(1L)), + new KeyValueTimestamp<>("k2", "v1", scaledTime(2L)), + new KeyValueTimestamp<>("x", "x", scaledTime(3L)) + ) + ); + final boolean rawRecords = waitForAnyRecord(outputRaw); + final boolean suppressedRecords = waitForAnyRecord(outputSuppressed); + + assertThat(CLUSTER.getAllTopicsInCluster(), hasItem(changeLog)); + assertThat(rawRecords, Matchers.is(true)); + assertThat(suppressedRecords, is(true)); + } finally { + driver.close(); + quietlyCleanStateAfterTest(CLUSTER, driver); + } + } + + @Test + public void shouldAllowDisablingChangelog() { + final String testId = "-shouldAllowDisablingChangelog"; + final String appId = getClass().getSimpleName().toLowerCase(Locale.getDefault()) + testId; + final String input = "input" + testId; + final String outputSuppressed = "output-suppressed" + testId; + final String outputRaw = "output-raw" + testId; + + cleanStateBeforeTest(CLUSTER, input, outputRaw, outputSuppressed); + + final StreamsBuilder builder = new StreamsBuilder(); + + final KStream inputStream = builder.stream(input); + + final KTable valueCounts = inputStream + .groupByKey() + .aggregate(() -> "()", (key, value, aggregate) -> aggregate + ",(" + key + ": " + value + ")"); + + valueCounts + .suppress(untilTimeLimit(ofMillis(MAX_VALUE), maxRecords(1L) + .emitEarlyWhenFull() + .withLoggingDisabled())) + .toStream() + .to(outputSuppressed); + + valueCounts + .toStream() + .to(outputRaw); + + final Properties streamsConfig = getStreamsConfig(appId); + streamsConfig.put(StreamsConfig.DEFAULT_KEY_SERDE_CLASS_CONFIG, Serdes.StringSerde.class); + streamsConfig.put(StreamsConfig.DEFAULT_VALUE_SERDE_CLASS_CONFIG, Serdes.StringSerde.class); + + final KafkaStreams driver = IntegrationTestUtils.getStartedStreams(streamsConfig, builder, true); + try { + produceSynchronously( + input, + asList( + new KeyValueTimestamp<>("k1", "v1", scaledTime(0L)), + new KeyValueTimestamp<>("k1", "v2", scaledTime(1L)), + new KeyValueTimestamp<>("k2", "v1", scaledTime(2L)), + new KeyValueTimestamp<>("x", "x", scaledTime(3L)) + ) + ); + final boolean rawRecords = waitForAnyRecord(outputRaw); + final boolean 
suppressedRecords = waitForAnyRecord(outputSuppressed); + final Set suppressChangeLog = CLUSTER.getAllTopicsInCluster() + .stream() + .filter(s -> s.contains("-changelog")) + .filter(s -> s.contains("KTABLE-SUPPRESS")) + .collect(Collectors.toSet()); + + assertThat(suppressChangeLog, is(empty())); + assertThat(rawRecords, Matchers.is(true)); + assertThat(suppressedRecords, is(true)); + } finally { + driver.close(); + quietlyCleanStateAfterTest(CLUSTER, driver); + } + } + + private static Properties getStreamsConfig(final String appId) { + return mkProperties(mkMap( + mkEntry(StreamsConfig.APPLICATION_ID_CONFIG, appId), + mkEntry(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, CLUSTER.bootstrapServers()), + mkEntry(StreamsConfig.POLL_MS_CONFIG, Integer.toString(COMMIT_INTERVAL)), + mkEntry(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG, Long.toString(COMMIT_INTERVAL)), + mkEntry(StreamsConfig.PROCESSING_GUARANTEE_CONFIG, AT_LEAST_ONCE), + mkEntry(StreamsConfig.STATE_DIR_CONFIG, TestUtils.tempDirectory().getPath()) + )); + } + + /** + * scaling to ensure that there are commits in between the various test events, + * just to exercise that everything works properly in the presence of commits. + */ + private static long scaledTime(final long unscaledTime) { + return COMMIT_INTERVAL * 2 * unscaledTime; + } + + private static void produceSynchronously(final String topic, final List> toProduce) { + final Properties producerConfig = mkProperties(mkMap( + mkEntry(ProducerConfig.CLIENT_ID_CONFIG, "anything"), + mkEntry(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG, ((Serializer) STRING_SERIALIZER).getClass().getName()), + mkEntry(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG, ((Serializer) STRING_SERIALIZER).getClass().getName()), + mkEntry(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, CLUSTER.bootstrapServers()) + )); + IntegrationTestUtils.produceSynchronously(producerConfig, false, topic, Optional.empty(), toProduce); + } + + private static void verifyErrorShutdown(final KafkaStreams driver) throws InterruptedException { + waitForCondition(() -> !driver.state().isRunningOrRebalancing(), DEFAULT_TIMEOUT, "Streams didn't shut down."); + waitForCondition(() -> driver.state() == KafkaStreams.State.ERROR, "Streams didn't transit to ERROR state"); + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/integration/TableTableJoinIntegrationTest.java b/streams/src/test/java/org/apache/kafka/streams/integration/TableTableJoinIntegrationTest.java new file mode 100644 index 0000000..579ed19 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/integration/TableTableJoinIntegrationTest.java @@ -0,0 +1,557 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.integration; + +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.kstream.KTable; +import org.apache.kafka.streams.kstream.Materialized; +import org.apache.kafka.streams.state.KeyValueStore; +import org.apache.kafka.streams.test.TestRecord; +import org.apache.kafka.test.IntegrationTest; +import org.junit.Before; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +/** + * Tests all available joins of Kafka Streams DSL. + */ +@Category({IntegrationTest.class}) +@RunWith(value = Parameterized.class) +public class TableTableJoinIntegrationTest extends AbstractJoinIntegrationTest { + private KTable leftTable; + private KTable rightTable; + + public TableTableJoinIntegrationTest(final boolean cacheEnabled) { + super(cacheEnabled); + } + + @Before + public void prepareTopology() throws InterruptedException { + super.prepareEnvironment(); + + appID = "table-table-join-integration-test"; + + builder = new StreamsBuilder(); + leftTable = builder.table(INPUT_TOPIC_LEFT, Materialized.>as("left").withLoggingDisabled()); + rightTable = builder.table(INPUT_TOPIC_RIGHT, Materialized.>as("right").withLoggingDisabled()); + } + + private final TestRecord expectedFinalJoinResult = new TestRecord<>(ANY_UNIQUE_KEY, "D-d", null, 15L); + private final TestRecord expectedFinalMultiJoinResult = new TestRecord<>(ANY_UNIQUE_KEY, "D-d-d", null, 15L); + private final String storeName = appID + "-store"; + + private final Materialized> materialized = Materialized.>as(storeName) + .withKeySerde(Serdes.Long()) + .withValueSerde(Serdes.String()) + .withCachingDisabled() + .withLoggingDisabled(); + + @Test + public void testInner() throws Exception { + STREAMS_CONFIG.put(StreamsConfig.APPLICATION_ID_CONFIG, appID + "-inner"); + + leftTable.join(rightTable, valueJoiner, materialized).toStream().to(OUTPUT_TOPIC); + + if (cacheEnabled) { + runTestWithDriver(expectedFinalJoinResult, storeName); + } else { + final List>> expectedResult = Arrays.asList( + null, + null, + null, + Collections.singletonList(new TestRecord<>(ANY_UNIQUE_KEY, "A-a", null, 4L)), + Collections.singletonList(new TestRecord<>(ANY_UNIQUE_KEY, "B-a", null, 5L)), + Collections.singletonList(new TestRecord<>(ANY_UNIQUE_KEY, "B-b", null, 6L)), + Collections.singletonList(new TestRecord<>(ANY_UNIQUE_KEY, null, null, 7L)), + null, + null, + Collections.singletonList(new TestRecord<>(ANY_UNIQUE_KEY, "C-c", null, 10L)), + Collections.singletonList(new TestRecord<>(ANY_UNIQUE_KEY, null, null, 11L)), + null, + null, + null, + Collections.singletonList(new TestRecord<>(ANY_UNIQUE_KEY, "D-d", null, 15L)) + ); + + runTestWithDriver(expectedResult, storeName); + } + } + + @Test + public void testLeft() throws Exception { + STREAMS_CONFIG.put(StreamsConfig.APPLICATION_ID_CONFIG, appID + "-left"); + + leftTable.leftJoin(rightTable, valueJoiner, materialized).toStream().to(OUTPUT_TOPIC); + + if (cacheEnabled) { + runTestWithDriver(expectedFinalJoinResult, storeName); + } else { + final List>> expectedResult = Arrays.asList( + null, + null, + Collections.singletonList(new TestRecord<>(ANY_UNIQUE_KEY, "A-null", null, 3L)), + Collections.singletonList(new 
TestRecord<>(ANY_UNIQUE_KEY, "A-a", null, 4L)), + Collections.singletonList(new TestRecord<>(ANY_UNIQUE_KEY, "B-a", null, 5L)), + Collections.singletonList(new TestRecord<>(ANY_UNIQUE_KEY, "B-b", null, 6L)), + Collections.singletonList(new TestRecord<>(ANY_UNIQUE_KEY, null, null, 7L)), + null, + Collections.singletonList(new TestRecord<>(ANY_UNIQUE_KEY, "C-null", null, 9L)), + Collections.singletonList(new TestRecord<>(ANY_UNIQUE_KEY, "C-c", null, 10L)), + Collections.singletonList(new TestRecord<>(ANY_UNIQUE_KEY, "C-null", null, 11L)), + Collections.singletonList(new TestRecord<>(ANY_UNIQUE_KEY, null, null, 12L)), + null, + null, + Collections.singletonList(new TestRecord<>(ANY_UNIQUE_KEY, "D-d", null, 15L)) + ); + + runTestWithDriver(expectedResult, storeName); + } + } + + @Test + public void testOuter() throws Exception { + STREAMS_CONFIG.put(StreamsConfig.APPLICATION_ID_CONFIG, appID + "-outer"); + + leftTable.outerJoin(rightTable, valueJoiner, materialized).toStream().to(OUTPUT_TOPIC); + + if (cacheEnabled) { + runTestWithDriver(expectedFinalJoinResult, storeName); + } else { + final List>> expectedResult = Arrays.asList( + null, + null, + Collections.singletonList(new TestRecord<>(ANY_UNIQUE_KEY, "A-null", null, 3L)), + Collections.singletonList(new TestRecord<>(ANY_UNIQUE_KEY, "A-a", null, 4L)), + Collections.singletonList(new TestRecord<>(ANY_UNIQUE_KEY, "B-a", null, 5L)), + Collections.singletonList(new TestRecord<>(ANY_UNIQUE_KEY, "B-b", null, 6L)), + Collections.singletonList(new TestRecord<>(ANY_UNIQUE_KEY, "null-b", null, 7L)), + Collections.singletonList(new TestRecord<>(ANY_UNIQUE_KEY, null, null, 8L)), + Collections.singletonList(new TestRecord<>(ANY_UNIQUE_KEY, "C-null", null, 9L)), + Collections.singletonList(new TestRecord<>(ANY_UNIQUE_KEY, "C-c", null, 10L)), + Collections.singletonList(new TestRecord<>(ANY_UNIQUE_KEY, "C-null", null, 11L)), + Collections.singletonList(new TestRecord<>(ANY_UNIQUE_KEY, null, null, 12L)), + null, + Collections.singletonList(new TestRecord<>(ANY_UNIQUE_KEY, "null-d", null, 14L)), + Collections.singletonList(new TestRecord<>(ANY_UNIQUE_KEY, "D-d", null, 15L)) + ); + + runTestWithDriver(expectedResult, storeName); + } + } + + @Test + public void testInnerInner() throws Exception { + STREAMS_CONFIG.put(StreamsConfig.APPLICATION_ID_CONFIG, appID + "-inner-inner"); + + leftTable.join(rightTable, valueJoiner) + .join(rightTable, valueJoiner, materialized) + .toStream() + .to(OUTPUT_TOPIC); + + if (cacheEnabled) { + runTestWithDriver(expectedFinalMultiJoinResult, storeName); + } else { + // TODO K6443: the duplicate below for all the multi-joins are due to + // KAFKA-6443, should be updated once it is fixed. 
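+ // (The duplicates come from joining `rightTable` twice: an update on the right side reaches the second join both directly and again through the first join, so a single input record can appear twice in the result until KAFKA-6443 is fixed.)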
+ final List>> expectedResult = Arrays.asList( + null, + null, + null, + Arrays.asList( + new TestRecord<>(ANY_UNIQUE_KEY, "A-a-a", null, 4L), + new TestRecord<>(ANY_UNIQUE_KEY, "A-a-a", null, 4L)), + Collections.singletonList(new TestRecord<>(ANY_UNIQUE_KEY, "B-a-a", null, 5L)), + Arrays.asList( + new TestRecord<>(ANY_UNIQUE_KEY, "B-b-b", null, 6L), + new TestRecord<>(ANY_UNIQUE_KEY, "B-b-b", null, 6L)), + Collections.singletonList(new TestRecord<>(ANY_UNIQUE_KEY, null, null, 7L)), + null, + null, + Arrays.asList( + new TestRecord<>(ANY_UNIQUE_KEY, "C-c-c", null, 10L), + new TestRecord<>(ANY_UNIQUE_KEY, "C-c-c", null, 10L)), + null, // correct would be -> new TestRecord<>(ANY_UNIQUE_KEY, null, null, 11L) + // we don't get correct value, because of self-join of `rightTable` + null, + null, + null, + Collections.singletonList(new TestRecord<>(ANY_UNIQUE_KEY, "D-d-d", null, 15L)) + ); + + runTestWithDriver(expectedResult, storeName); + } + } + + @Test + public void testInnerLeft() throws Exception { + STREAMS_CONFIG.put(StreamsConfig.APPLICATION_ID_CONFIG, appID + "-inner-left"); + + leftTable.join(rightTable, valueJoiner) + .leftJoin(rightTable, valueJoiner, materialized) + .toStream() + .to(OUTPUT_TOPIC); + + if (cacheEnabled) { + runTestWithDriver(expectedFinalMultiJoinResult, storeName); + } else { + final List>> expectedResult = Arrays.asList( + null, + null, + null, + Arrays.asList( + new TestRecord<>(ANY_UNIQUE_KEY, "A-a-a", null, 4L), + new TestRecord<>(ANY_UNIQUE_KEY, "A-a-a", null, 4L)), + Collections.singletonList(new TestRecord<>(ANY_UNIQUE_KEY, "B-a-a", null, 5L)), + Arrays.asList( + new TestRecord<>(ANY_UNIQUE_KEY, "B-b-b", null, 6L), + new TestRecord<>(ANY_UNIQUE_KEY, "B-b-b", null, 6L)), + Collections.singletonList(new TestRecord<>(ANY_UNIQUE_KEY, null, null, 7L)), + null, + null, + Arrays.asList( + new TestRecord<>(ANY_UNIQUE_KEY, "C-c-c", null, 10L), + new TestRecord<>(ANY_UNIQUE_KEY, "C-c-c", null, 10L)), + Collections.singletonList(new TestRecord<>(ANY_UNIQUE_KEY, null, null, 11L)), + null, + null, + null, + Collections.singletonList(new TestRecord<>(ANY_UNIQUE_KEY, "D-d-d", null, 15L)) + ); + + runTestWithDriver(expectedResult, storeName); + } + } + + @Test + public void testInnerOuter() throws Exception { + STREAMS_CONFIG.put(StreamsConfig.APPLICATION_ID_CONFIG, appID + "-inner-outer"); + + leftTable.join(rightTable, valueJoiner) + .outerJoin(rightTable, valueJoiner, materialized) + .toStream() + .to(OUTPUT_TOPIC); + + if (cacheEnabled) { + runTestWithDriver(expectedFinalMultiJoinResult, storeName); + } else { + final List>> expectedResult = Arrays.asList( + null, + null, + null, + Arrays.asList( + new TestRecord<>(ANY_UNIQUE_KEY, "A-a-a", null, 4L), + new TestRecord<>(ANY_UNIQUE_KEY, "A-a-a", null, 4L)), + Collections.singletonList(new TestRecord<>(ANY_UNIQUE_KEY, "B-a-a", null, 5L)), + Arrays.asList( + new TestRecord<>(ANY_UNIQUE_KEY, "B-b-b", null, 6L), + new TestRecord<>(ANY_UNIQUE_KEY, "B-b-b", null, 6L)), + Collections.singletonList(new TestRecord<>(ANY_UNIQUE_KEY, "null-b", null, 7L)), + Collections.singletonList(new TestRecord<>(ANY_UNIQUE_KEY, null, null, 8L)), + null, + Arrays.asList( + new TestRecord<>(ANY_UNIQUE_KEY, "C-c-c", null, 10L), + new TestRecord<>(ANY_UNIQUE_KEY, "C-c-c", null, 10L)), + Arrays.asList( + new TestRecord<>(ANY_UNIQUE_KEY, null, null, 11L), + new TestRecord<>(ANY_UNIQUE_KEY, null, null, 11L)), + null, + null, + null, + Arrays.asList( + // incorrect result `null-d` is caused by self-join of `rightTable` + new TestRecord<>(ANY_UNIQUE_KEY, 
"null-d", null, 14L), + new TestRecord<>(ANY_UNIQUE_KEY, "D-d-d", null, 15L)) + ); + + runTestWithDriver(expectedResult, storeName); + } + } + + @Test + public void testLeftInner() throws Exception { + STREAMS_CONFIG.put(StreamsConfig.APPLICATION_ID_CONFIG, appID + "-inner-inner"); + + leftTable.leftJoin(rightTable, valueJoiner) + .join(rightTable, valueJoiner, materialized) + .toStream() + .to(OUTPUT_TOPIC); + + if (cacheEnabled) { + runTestWithDriver(expectedFinalMultiJoinResult, storeName); + } else { + final List>> expectedResult = Arrays.asList( + null, + null, + null, + Arrays.asList( + new TestRecord<>(ANY_UNIQUE_KEY, "A-a-a", null, 4L), + new TestRecord<>(ANY_UNIQUE_KEY, "A-a-a", null, 4L)), + Collections.singletonList(new TestRecord<>(ANY_UNIQUE_KEY, "B-a-a", null, 5L)), + Arrays.asList( + new TestRecord<>(ANY_UNIQUE_KEY, "B-b-b", null, 6L), + new TestRecord<>(ANY_UNIQUE_KEY, "B-b-b", null, 6L)), + Collections.singletonList(new TestRecord<>(ANY_UNIQUE_KEY, null, null, 7L)), + null, + null, + Arrays.asList( + new TestRecord<>(ANY_UNIQUE_KEY, "C-c-c", null, 10L), + new TestRecord<>(ANY_UNIQUE_KEY, "C-c-c", null, 10L)), + Collections.singletonList(new TestRecord<>(ANY_UNIQUE_KEY, null, null, 11L)), + null, + null, + null, + Collections.singletonList(new TestRecord<>(ANY_UNIQUE_KEY, "D-d-d", null, 15L)) + ); + + runTestWithDriver(expectedResult, storeName); + } + } + + @Test + public void testLeftLeft() throws Exception { + STREAMS_CONFIG.put(StreamsConfig.APPLICATION_ID_CONFIG, appID + "-inner-left"); + + leftTable.leftJoin(rightTable, valueJoiner) + .leftJoin(rightTable, valueJoiner, materialized) + .toStream() + .to(OUTPUT_TOPIC); + + if (cacheEnabled) { + runTestWithDriver(expectedFinalMultiJoinResult, storeName); + } else { + final List>> expectedResult = Arrays.asList( + null, + null, + null, + Arrays.asList( + new TestRecord<>(ANY_UNIQUE_KEY, "A-null-null", null, 3L), + new TestRecord<>(ANY_UNIQUE_KEY, "A-a-a", null, 4L), + new TestRecord<>(ANY_UNIQUE_KEY, "A-a-a", null, 4L)), + Collections.singletonList(new TestRecord<>(ANY_UNIQUE_KEY, "B-a-a", null, 5L)), + Arrays.asList( + new TestRecord<>(ANY_UNIQUE_KEY, "B-b-b", null, 6L), + new TestRecord<>(ANY_UNIQUE_KEY, "B-b-b", null, 6L)), + Collections.singletonList(new TestRecord<>(ANY_UNIQUE_KEY, null, null, 7L)), + null, + null, + Arrays.asList( + new TestRecord<>(ANY_UNIQUE_KEY, "C-null-null", null, 9L), + new TestRecord<>(ANY_UNIQUE_KEY, "C-c-c", null, 10L), + new TestRecord<>(ANY_UNIQUE_KEY, "C-c-c", null, 10L)), + Arrays.asList( + new TestRecord<>(ANY_UNIQUE_KEY, "C-null-null", null, 11L), + new TestRecord<>(ANY_UNIQUE_KEY, "C-null-null", null, 11L)), + Collections.singletonList(new TestRecord<>(ANY_UNIQUE_KEY, null, null, 12L)), + null, + null, + Collections.singletonList(new TestRecord<>(ANY_UNIQUE_KEY, "D-d-d", null, 15L)) + ); + + runTestWithDriver(expectedResult, storeName); + } + } + + @Test + public void testLeftOuter() throws Exception { + STREAMS_CONFIG.put(StreamsConfig.APPLICATION_ID_CONFIG, appID + "-inner-outer"); + + leftTable.leftJoin(rightTable, valueJoiner) + .outerJoin(rightTable, valueJoiner, materialized) + .toStream() + .to(OUTPUT_TOPIC); + + if (cacheEnabled) { + runTestWithDriver(expectedFinalMultiJoinResult, storeName); + } else { + final List>> expectedResult = Arrays.asList( + null, + null, + null, + Arrays.asList( + new TestRecord<>(ANY_UNIQUE_KEY, "A-null-null", null, 3L), + new TestRecord<>(ANY_UNIQUE_KEY, "A-a-a", null, 4L), + new TestRecord<>(ANY_UNIQUE_KEY, "A-a-a", null, 4L)), + 
Collections.singletonList(new TestRecord<>(ANY_UNIQUE_KEY, "B-a-a", null, 5L)), + Arrays.asList( + new TestRecord<>(ANY_UNIQUE_KEY, "B-b-b", null, 6L), + new TestRecord<>(ANY_UNIQUE_KEY, "B-b-b", null, 6L)), + Collections.singletonList(new TestRecord<>(ANY_UNIQUE_KEY, "null-b", null, 7L)), + Collections.singletonList(new TestRecord<>(ANY_UNIQUE_KEY, null, null, 8L)), + null, + Arrays.asList( + new TestRecord<>(ANY_UNIQUE_KEY, "C-null-null", null, 9L), + new TestRecord<>(ANY_UNIQUE_KEY, "C-c-c", null, 10L), + new TestRecord<>(ANY_UNIQUE_KEY, "C-c-c", null, 10L)), + Arrays.asList( + new TestRecord<>(ANY_UNIQUE_KEY, "C-null-null", null, 11L), + new TestRecord<>(ANY_UNIQUE_KEY, "C-null-null", null, 11L)), + Collections.singletonList(new TestRecord<>(ANY_UNIQUE_KEY, null, null, 12L)), + null, + null, + Arrays.asList( + new TestRecord<>(ANY_UNIQUE_KEY, "null-d", null, 14L), + new TestRecord<>(ANY_UNIQUE_KEY, "D-d-d", null, 15L)) + ); + + runTestWithDriver(expectedResult, storeName); + } + } + + @Test + public void testOuterInner() throws Exception { + STREAMS_CONFIG.put(StreamsConfig.APPLICATION_ID_CONFIG, appID + "-inner-inner"); + + leftTable.outerJoin(rightTable, valueJoiner) + .join(rightTable, valueJoiner, materialized) + .toStream() + .to(OUTPUT_TOPIC); + + if (cacheEnabled) { + runTestWithDriver(expectedFinalMultiJoinResult, storeName); + } else { + final List>> expectedResult = Arrays.asList( + null, + null, + null, + Arrays.asList( + new TestRecord<>(ANY_UNIQUE_KEY, "A-a-a", null, 4L), + new TestRecord<>(ANY_UNIQUE_KEY, "A-a-a", null, 4L)), + Collections.singletonList(new TestRecord<>(ANY_UNIQUE_KEY, "B-a-a", null, 5L)), + Arrays.asList( + new TestRecord<>(ANY_UNIQUE_KEY, "B-b-b", null, 6L), + new TestRecord<>(ANY_UNIQUE_KEY, "B-b-b", null, 6L)), + Collections.singletonList(new TestRecord<>(ANY_UNIQUE_KEY, "null-b-b", null, 7L)), + null, + null, + Arrays.asList( + new TestRecord<>(ANY_UNIQUE_KEY, "C-c-c", null, 10L), + new TestRecord<>(ANY_UNIQUE_KEY, "C-c-c", null, 10L)), + Collections.singletonList(new TestRecord<>(ANY_UNIQUE_KEY, null, null, 11L)), + null, + null, + Arrays.asList( + new TestRecord<>(ANY_UNIQUE_KEY, "null-d-d", null, 14L), + new TestRecord<>(ANY_UNIQUE_KEY, "null-d-d", null, 14L)), + Collections.singletonList(new TestRecord<>(ANY_UNIQUE_KEY, "D-d-d", null, 15L)) + ); + + runTestWithDriver(expectedResult, storeName); + } + } + + @Test + public void testOuterLeft() throws Exception { + STREAMS_CONFIG.put(StreamsConfig.APPLICATION_ID_CONFIG, appID + "-inner-left"); + + leftTable.outerJoin(rightTable, valueJoiner) + .leftJoin(rightTable, valueJoiner, materialized) + .toStream() + .to(OUTPUT_TOPIC); + + if (cacheEnabled) { + runTestWithDriver(expectedFinalMultiJoinResult, storeName); + } else { + final List>> expectedResult = Arrays.asList( + null, + null, + null, + Arrays.asList( + new TestRecord<>(ANY_UNIQUE_KEY, "A-null-null", null, 3L), + new TestRecord<>(ANY_UNIQUE_KEY, "A-a-a", null, 4L), + new TestRecord<>(ANY_UNIQUE_KEY, "A-a-a", null, 4L)), + Collections.singletonList(new TestRecord<>(ANY_UNIQUE_KEY, "B-a-a", null, 5L)), + Arrays.asList( + new TestRecord<>(ANY_UNIQUE_KEY, "B-b-b", null, 6L), + new TestRecord<>(ANY_UNIQUE_KEY, "B-b-b", null, 6L)), + Collections.singletonList(new TestRecord<>(ANY_UNIQUE_KEY, "null-b-b", null, 7L)), + Collections.singletonList(new TestRecord<>(ANY_UNIQUE_KEY, null, null, 8L)), + null, + Arrays.asList( + new TestRecord<>(ANY_UNIQUE_KEY, "C-null-null", null, 9L), + new TestRecord<>(ANY_UNIQUE_KEY, "C-c-c", null, 10L), + new 
TestRecord<>(ANY_UNIQUE_KEY, "C-c-c", null, 10L)), + Arrays.asList( + new TestRecord<>(ANY_UNIQUE_KEY, "C-null-null", null, 11L), + new TestRecord<>(ANY_UNIQUE_KEY, "C-null-null", null, 11L)), + Collections.singletonList(new TestRecord<>(ANY_UNIQUE_KEY, null, null, 12L)), + null, + Arrays.asList( + new TestRecord<>(ANY_UNIQUE_KEY, "null-d-d", null, 14L), + new TestRecord<>(ANY_UNIQUE_KEY, "null-d-d", null, 14L)), + Collections.singletonList(new TestRecord<>(ANY_UNIQUE_KEY, "D-d-d", null, 15L)) + ); + + runTestWithDriver(expectedResult, storeName); + } + } + + @Test + public void testOuterOuter() throws Exception { + STREAMS_CONFIG.put(StreamsConfig.APPLICATION_ID_CONFIG, appID + "-inner-outer"); + + leftTable.outerJoin(rightTable, valueJoiner) + .outerJoin(rightTable, valueJoiner, materialized) + .toStream() + .to(OUTPUT_TOPIC); + + if (cacheEnabled) { + runTestWithDriver(expectedFinalMultiJoinResult, storeName); + } else { + final List>> expectedResult = Arrays.asList( + null, + null, + null, + Arrays.asList( + new TestRecord<>(ANY_UNIQUE_KEY, "A-null-null", null, 3L), + new TestRecord<>(ANY_UNIQUE_KEY, "A-a-a", null, 4L), + new TestRecord<>(ANY_UNIQUE_KEY, "A-a-a", null, 4L)), + Collections.singletonList(new TestRecord<>(ANY_UNIQUE_KEY, "B-a-a", null, 5L)), + Arrays.asList( + new TestRecord<>(ANY_UNIQUE_KEY, "B-b-b", null, 6L), + new TestRecord<>(ANY_UNIQUE_KEY, "B-b-b", null, 6L)), + Collections.singletonList(new TestRecord<>(ANY_UNIQUE_KEY, "null-b-b", null, 7L)), + Arrays.asList( + new TestRecord<>(ANY_UNIQUE_KEY, null, null, 8L), + new TestRecord<>(ANY_UNIQUE_KEY, null, null, 8L)), + Collections.singletonList(new TestRecord<>(ANY_UNIQUE_KEY, "C-null-null", null, 9L)), + Arrays.asList( + new TestRecord<>(ANY_UNIQUE_KEY, "C-c-c", null, 10L), + new TestRecord<>(ANY_UNIQUE_KEY, "C-c-c", null, 10L)), + Arrays.asList( + new TestRecord<>(ANY_UNIQUE_KEY, "C-null-null", null, 11L), + new TestRecord<>(ANY_UNIQUE_KEY, "C-null-null", null, 11L)), + Collections.singletonList(new TestRecord<>(ANY_UNIQUE_KEY, null, null, 12L)), + null, + null, + Arrays.asList( + new TestRecord<>(ANY_UNIQUE_KEY, "null-d-d", null, 14L), + new TestRecord<>(ANY_UNIQUE_KEY, "null-d-d", null, 14L), + new TestRecord<>(ANY_UNIQUE_KEY, "D-d-d", null, 15L)) + ); + runTestWithDriver(expectedResult, storeName); + } + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/integration/TaskAssignorIntegrationTest.java b/streams/src/test/java/org/apache/kafka/streams/integration/TaskAssignorIntegrationTest.java new file mode 100644 index 0000000..5ff6cb6 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/integration/TaskAssignorIntegrationTest.java @@ -0,0 +1,160 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.integration; + +import org.apache.kafka.clients.consumer.ConsumerPartitionAssignor; +import org.apache.kafka.clients.consumer.KafkaConsumer; +import org.apache.kafka.streams.KafkaStreams; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.integration.utils.EmbeddedKafkaCluster; +import org.apache.kafka.streams.integration.utils.IntegrationTestUtils; +import org.apache.kafka.streams.processor.internals.StreamThread; +import org.apache.kafka.streams.processor.internals.StreamsPartitionAssignor; +import org.apache.kafka.streams.processor.internals.assignment.AssignorConfiguration; +import org.apache.kafka.streams.processor.internals.assignment.AssignorConfiguration.AssignmentListener; +import org.apache.kafka.streams.processor.internals.assignment.HighAvailabilityTaskAssignor; +import org.apache.kafka.streams.processor.internals.assignment.TaskAssignor; +import org.apache.kafka.test.IntegrationTest; +import org.apache.kafka.test.TestUtils; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Rule; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.rules.TestName; + +import java.io.IOException; +import java.lang.reflect.Field; +import java.util.List; +import java.util.Properties; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Supplier; + +import static org.apache.kafka.common.utils.Utils.mkEntry; +import static org.apache.kafka.common.utils.Utils.mkMap; +import static org.apache.kafka.common.utils.Utils.mkObjectProperties; +import static org.apache.kafka.streams.integration.utils.IntegrationTestUtils.safeUniqueTestName; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.sameInstance; + +@Category(IntegrationTest.class) +public class TaskAssignorIntegrationTest { + + public static final EmbeddedKafkaCluster CLUSTER = new EmbeddedKafkaCluster(1); + + @BeforeClass + public static void startCluster() throws IOException { + CLUSTER.start(); + } + + @AfterClass + public static void closeCluster() { + CLUSTER.stop(); + } + + @Rule + public TestName testName = new TestName(); + + // Just a dummy implementation so we can check the config + public static final class MyTaskAssignor extends HighAvailabilityTaskAssignor implements TaskAssignor { } + + @SuppressWarnings("unchecked") + @Test + public void shouldProperlyConfigureTheAssignor() throws NoSuchFieldException, IllegalAccessException { + // This test uses reflection to check and make sure that all the expected configurations really + // make it all the way to configure the task assignor. There's no other use case for being able + // to extract all these fields, so reflection is a good choice until we find that the maintenance + // burden is too high. + // + // Also note that this is an integration test because so many components have to come together to + // ensure these configurations wind up where they belong, and any number of future code changes + // could break this change. 
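+ // The reflective chain exercised below is: KafkaStreams#threads -> StreamThread#mainConsumer -> KafkaConsumer#assignors -> StreamsPartitionAssignor#{assignmentConfigs, assignmentListener, taskAssignorSupplier}; each private field is read only so the test can assert on the configured values.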
+ + final String testId = safeUniqueTestName(getClass(), testName); + final String appId = "appId_" + testId; + final String inputTopic = "input" + testId; + + IntegrationTestUtils.cleanStateBeforeTest(CLUSTER, inputTopic); + + // Maybe I'm paranoid, but I don't want the compiler deciding that my lambdas are equal to the identity + // function and defeating my identity check + final AtomicInteger compilerDefeatingReference = new AtomicInteger(0); + + // the implementation doesn't matter, we're just going to verify the reference. + final AssignmentListener configuredAssignmentListener = + stable -> compilerDefeatingReference.incrementAndGet(); + + final Properties properties = mkObjectProperties( + mkMap( + mkEntry(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, CLUSTER.bootstrapServers()), + mkEntry(StreamsConfig.APPLICATION_ID_CONFIG, appId), + mkEntry(StreamsConfig.STATE_DIR_CONFIG, TestUtils.tempDirectory().getPath()), + mkEntry(StreamsConfig.NUM_STANDBY_REPLICAS_CONFIG, "5"), + mkEntry(StreamsConfig.ACCEPTABLE_RECOVERY_LAG_CONFIG, "6"), + mkEntry(StreamsConfig.MAX_WARMUP_REPLICAS_CONFIG, "7"), + mkEntry(StreamsConfig.PROBING_REBALANCE_INTERVAL_MS_CONFIG, "480000"), + mkEntry(StreamsConfig.InternalConfig.ASSIGNMENT_LISTENER, configuredAssignmentListener), + mkEntry(StreamsConfig.InternalConfig.INTERNAL_TASK_ASSIGNOR_CLASS, MyTaskAssignor.class.getName()) + ) + ); + + final StreamsBuilder builder = new StreamsBuilder(); + builder.stream(inputTopic); + + try (final KafkaStreams kafkaStreams = new KafkaStreams(builder.build(), properties)) { + kafkaStreams.start(); + + final Field threads = KafkaStreams.class.getDeclaredField("threads"); + threads.setAccessible(true); + final List streamThreads = (List) threads.get(kafkaStreams); + final StreamThread streamThread = streamThreads.get(0); + + final Field mainConsumer = StreamThread.class.getDeclaredField("mainConsumer"); + mainConsumer.setAccessible(true); + final KafkaConsumer consumer = (KafkaConsumer) mainConsumer.get(streamThread); + + final Field assignors = KafkaConsumer.class.getDeclaredField("assignors"); + assignors.setAccessible(true); + final List consumerPartitionAssignors = (List) assignors.get(consumer); + final StreamsPartitionAssignor streamsPartitionAssignor = (StreamsPartitionAssignor) consumerPartitionAssignors.get(0); + + final Field assignmentConfigs = StreamsPartitionAssignor.class.getDeclaredField("assignmentConfigs"); + assignmentConfigs.setAccessible(true); + final AssignorConfiguration.AssignmentConfigs configs = (AssignorConfiguration.AssignmentConfigs) assignmentConfigs.get(streamsPartitionAssignor); + + final Field assignmentListenerField = StreamsPartitionAssignor.class.getDeclaredField("assignmentListener"); + assignmentListenerField.setAccessible(true); + final AssignmentListener actualAssignmentListener = (AssignmentListener) assignmentListenerField.get(streamsPartitionAssignor); + + final Field taskAssignorSupplierField = StreamsPartitionAssignor.class.getDeclaredField("taskAssignorSupplier"); + taskAssignorSupplierField.setAccessible(true); + final Supplier taskAssignorSupplier = + (Supplier) taskAssignorSupplierField.get(streamsPartitionAssignor); + final TaskAssignor taskAssignor = taskAssignorSupplier.get(); + + assertThat(configs.numStandbyReplicas, is(5)); + assertThat(configs.acceptableRecoveryLag, is(6L)); + assertThat(configs.maxWarmupReplicas, is(7)); + assertThat(configs.probingRebalanceIntervalMs, is(480000L)); + assertThat(actualAssignmentListener, sameInstance(configuredAssignmentListener)); + 
assertThat(taskAssignor, instanceOf(MyTaskAssignor.class)); + } + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/integration/TaskMetadataIntegrationTest.java b/streams/src/test/java/org/apache/kafka/streams/integration/TaskMetadataIntegrationTest.java new file mode 100644 index 0000000..6f35d12 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/integration/TaskMetadataIntegrationTest.java @@ -0,0 +1,206 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.streams.integration; + +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.serialization.StringSerializer; +import org.apache.kafka.streams.KafkaStreams; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.TaskMetadata; +import org.apache.kafka.streams.integration.utils.EmbeddedKafkaCluster; +import org.apache.kafka.streams.integration.utils.IntegrationTestUtils; +import org.apache.kafka.streams.kstream.KStream; +import org.apache.kafka.test.IntegrationTest; +import org.apache.kafka.test.TestUtils; +import org.junit.After; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Rule; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.rules.TestName; + +import java.io.IOException; +import java.time.Duration; +import java.util.Collections; +import java.util.List; +import java.util.Properties; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; +import java.util.stream.Collectors; + +import static org.apache.kafka.common.utils.Utils.mkEntry; +import static org.apache.kafka.common.utils.Utils.mkMap; +import static org.apache.kafka.common.utils.Utils.mkObjectProperties; +import static org.apache.kafka.streams.integration.utils.IntegrationTestUtils.purgeLocalStreamsState; +import static org.apache.kafka.streams.integration.utils.IntegrationTestUtils.safeUniqueTestName; +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.MatcherAssert.assertThat; + +@Category(IntegrationTest.class) +public class TaskMetadataIntegrationTest { + + public static final EmbeddedKafkaCluster CLUSTER = new EmbeddedKafkaCluster(1, new Properties(), 0L, 0L); + + @BeforeClass + public static void startCluster() throws IOException { + CLUSTER.start(); + } + + @AfterClass + public static void closeCluster() { + CLUSTER.stop(); + } + public static final Duration DEFAULT_DURATION = Duration.ofSeconds(30); + + @Rule + public TestName testName = new TestName(); + + private String inputTopic; + private 
static StreamsBuilder builder; + private static Properties properties; + private static String appIdPrefix = "TaskMetadataTest_"; + private static String appId; + private AtomicBoolean process; + private AtomicBoolean commit; + + @SuppressWarnings("deprecation") // Old PAPI. Needs to be migrated. + @Before + public void setup() { + final String testId = safeUniqueTestName(getClass(), testName); + appId = appIdPrefix + testId; + inputTopic = "input" + testId; + IntegrationTestUtils.cleanStateBeforeTest(CLUSTER, inputTopic); + + builder = new StreamsBuilder(); + + process = new AtomicBoolean(true); + commit = new AtomicBoolean(true); + + final KStream stream = builder.stream(inputTopic); + stream.process(PauseProcessor::new); + + properties = mkObjectProperties( + mkMap( + mkEntry(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, CLUSTER.bootstrapServers()), + mkEntry(StreamsConfig.APPLICATION_ID_CONFIG, appId), + mkEntry(StreamsConfig.STATE_DIR_CONFIG, TestUtils.tempDirectory().getPath()), + mkEntry(StreamsConfig.NUM_STREAM_THREADS_CONFIG, 2), + mkEntry(StreamsConfig.DEFAULT_KEY_SERDE_CLASS_CONFIG, Serdes.StringSerde.class), + mkEntry(StreamsConfig.DEFAULT_VALUE_SERDE_CLASS_CONFIG, Serdes.StringSerde.class), + mkEntry(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG, 1L) + ) + ); + } + + @Test + public void shouldReportCorrectCommittedOffsetInformation() { + try (final KafkaStreams kafkaStreams = new KafkaStreams(builder.build(), properties)) { + IntegrationTestUtils.startApplicationAndWaitUntilRunning(Collections.singletonList(kafkaStreams), DEFAULT_DURATION); + final TaskMetadata taskMetadata = getTaskMetadata(kafkaStreams); + assertThat(taskMetadata.committedOffsets().size(), equalTo(1)); + final TopicPartition topicPartition = new TopicPartition(inputTopic, 0); + + produceMessages(0L, inputTopic, "test"); + TestUtils.waitForCondition(() -> !process.get(), "The record was not processed"); + TestUtils.waitForCondition(() -> taskMetadata.committedOffsets().get(topicPartition) == 1L, "the record was processed"); + process.set(true); + + produceMessages(0L, inputTopic, "test1"); + TestUtils.waitForCondition(() -> !process.get(), "The record was not processed"); + TestUtils.waitForCondition(() -> taskMetadata.committedOffsets().get(topicPartition) == 2L, "the record was processed"); + process.set(true); + + produceMessages(0L, inputTopic, "test1"); + TestUtils.waitForCondition(() -> !process.get(), "The record was not processed"); + TestUtils.waitForCondition(() -> taskMetadata.committedOffsets().get(topicPartition) == 3L, "the record was processed"); + } catch (final Exception e) { + e.printStackTrace(); + } + } + + @Test + public void shouldReportCorrectEndOffsetInformation() { + try (final KafkaStreams kafkaStreams = new KafkaStreams(builder.build(), properties)) { + IntegrationTestUtils.startApplicationAndWaitUntilRunning(Collections.singletonList(kafkaStreams), DEFAULT_DURATION); + final TaskMetadata taskMetadata = getTaskMetadata(kafkaStreams); + assertThat(taskMetadata.endOffsets().size(), equalTo(1)); + final TopicPartition topicPartition = new TopicPartition(inputTopic, 0); + commit.set(false); + + for (int i = 0; i < 10; i++) { + produceMessages(0L, inputTopic, "test"); + TestUtils.waitForCondition(() -> !process.get(), "The record was not processed"); + process.set(true); + } + assertThat(taskMetadata.endOffsets().get(topicPartition), equalTo(9L)); + + } catch (final Exception e) { + e.printStackTrace(); + } + } + + private TaskMetadata getTaskMetadata(final KafkaStreams kafkaStreams) throws 
InterruptedException { + final AtomicReference> taskMetadataList = new AtomicReference<>(); + TestUtils.waitForCondition(() -> { + taskMetadataList.set(kafkaStreams.metadataForLocalThreads().stream().flatMap(t -> t.activeTasks().stream()).collect(Collectors.toList())); + return taskMetadataList.get().size() == 1; + }, "The number of active tasks returned in the allotted time was not one."); + return taskMetadataList.get().get(0); + } + + @After + public void teardown() throws IOException { + purgeLocalStreamsState(properties); + } + + private void produceMessages(final long timestamp, final String streamOneInput, final String msg) { + IntegrationTestUtils.produceKeyValuesSynchronouslyWithTimestamp( + streamOneInput, + Collections.singletonList(new KeyValue<>("1", msg)), + TestUtils.producerConfig( + CLUSTER.bootstrapServers(), + StringSerializer.class, + StringSerializer.class, + new Properties()), + timestamp); + } + + @SuppressWarnings("deprecation") // Old PAPI. Needs to be migrated. + private class PauseProcessor extends org.apache.kafka.streams.processor.AbstractProcessor { + @Override + public void process(final String key, final String value) { + while (!process.get()) { + try { + wait(100); + } catch (final InterruptedException e) { + + } + } + context().forward(key, value); + if (commit.get()) { + context().commit(); + } + process.set(false); + } + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/integration/utils/CompositeStateListener.java b/streams/src/test/java/org/apache/kafka/streams/integration/utils/CompositeStateListener.java new file mode 100644 index 0000000..8258942 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/integration/utils/CompositeStateListener.java @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.integration.utils; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.List; +import org.apache.kafka.streams.KafkaStreams.State; +import org.apache.kafka.streams.KafkaStreams.StateListener; + +/** + * A {@link StateListener} that holds zero or more listeners internally and invokes all of them + * when a state transition occurs (i.e. {@link #onChange(State, State)} is called). If any listener + * throws {@link RuntimeException} or {@link Error} this immediately stops execution of listeners + * and causes the thrown exception to be raised. + */ +public class CompositeStateListener implements StateListener { + private final List listeners; + + public CompositeStateListener(final StateListener... 
listeners) { + this(Arrays.asList(listeners)); + } + + public CompositeStateListener(final Collection stateListeners) { + this.listeners = Collections.unmodifiableList(new ArrayList<>(stateListeners)); + } + + @Override + public void onChange(final State newState, final State oldState) { + for (final StateListener listener : listeners) { + listener.onChange(newState, oldState); + } + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/integration/utils/EmbeddedKafkaCluster.java b/streams/src/test/java/org/apache/kafka/streams/integration/utils/EmbeddedKafkaCluster.java new file mode 100644 index 0000000..d990530 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/integration/utils/EmbeddedKafkaCluster.java @@ -0,0 +1,343 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.integration.utils; + +import kafka.server.ConfigType; +import kafka.server.KafkaConfig; +import kafka.server.KafkaServer; +import kafka.utils.MockTime; +import kafka.zk.EmbeddedZookeeper; +import org.apache.kafka.clients.admin.Admin; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.errors.UnknownTopicOrPartitionException; +import org.apache.kafka.test.TestCondition; +import org.apache.kafka.test.TestUtils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Properties; +import java.util.Set; +import java.util.concurrent.ExecutionException; + +/** + * Runs an in-memory, "embedded" Kafka cluster with 1 ZooKeeper instance and supplied number of Kafka brokers. 
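+ * <p>
+ * A minimal usage sketch, mirroring the integration tests in this patch (the topic name is illustrative):
+ * <pre>{@code
+ * EmbeddedKafkaCluster cluster = new EmbeddedKafkaCluster(1);
+ * cluster.start();                          // typically from a @BeforeClass method
+ * cluster.createTopic("input-topic", 1, 1); // 1 partition, replication factor 1
+ * // ... run clients against cluster.bootstrapServers() ...
+ * cluster.stop();                           // typically from an @AfterClass method
+ * }</pre>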
+ */ +public class EmbeddedKafkaCluster { + + private static final Logger log = LoggerFactory.getLogger(EmbeddedKafkaCluster.class); + private static final int DEFAULT_BROKER_PORT = 0; // 0 results in a random port being selected + private static final int TOPIC_CREATION_TIMEOUT = 30000; + private static final int TOPIC_DELETION_TIMEOUT = 30000; + private EmbeddedZookeeper zookeeper = null; + private final KafkaEmbedded[] brokers; + + private final Properties brokerConfig; + public final MockTime time; + + public EmbeddedKafkaCluster(final int numBrokers) { + this(numBrokers, new Properties()); + } + + public EmbeddedKafkaCluster(final int numBrokers, + final Properties brokerConfig) { + this(numBrokers, brokerConfig, System.currentTimeMillis()); + } + + public EmbeddedKafkaCluster(final int numBrokers, + final Properties brokerConfig, + final long mockTimeMillisStart) { + this(numBrokers, brokerConfig, mockTimeMillisStart, System.nanoTime()); + } + + public EmbeddedKafkaCluster(final int numBrokers, + final Properties brokerConfig, + final long mockTimeMillisStart, + final long mockTimeNanoStart) { + brokers = new KafkaEmbedded[numBrokers]; + this.brokerConfig = brokerConfig; + time = new MockTime(mockTimeMillisStart, mockTimeNanoStart); + } + + /** + * Creates and starts a Kafka cluster. + */ + public void start() throws IOException { + log.debug("Initiating embedded Kafka cluster startup"); + log.debug("Starting a ZooKeeper instance"); + zookeeper = new EmbeddedZookeeper(); + log.debug("ZooKeeper instance is running at {}", zKConnectString()); + + brokerConfig.put(KafkaConfig.ZkConnectProp(), zKConnectString()); + putIfAbsent(brokerConfig, KafkaConfig.ListenersProp(), "PLAINTEXT://localhost:" + DEFAULT_BROKER_PORT); + putIfAbsent(brokerConfig, KafkaConfig.DeleteTopicEnableProp(), true); + putIfAbsent(brokerConfig, KafkaConfig.LogCleanerDedupeBufferSizeProp(), 2 * 1024 * 1024L); + putIfAbsent(brokerConfig, KafkaConfig.GroupMinSessionTimeoutMsProp(), 0); + putIfAbsent(brokerConfig, KafkaConfig.GroupInitialRebalanceDelayMsProp(), 0); + putIfAbsent(brokerConfig, KafkaConfig.OffsetsTopicReplicationFactorProp(), (short) 1); + putIfAbsent(brokerConfig, KafkaConfig.OffsetsTopicPartitionsProp(), 5); + putIfAbsent(brokerConfig, KafkaConfig.TransactionsTopicPartitionsProp(), 5); + putIfAbsent(brokerConfig, KafkaConfig.AutoCreateTopicsEnableProp(), true); + + for (int i = 0; i < brokers.length; i++) { + brokerConfig.put(KafkaConfig.BrokerIdProp(), i); + log.debug("Starting a Kafka instance on {} ...", brokerConfig.get(KafkaConfig.ListenersProp())); + brokers[i] = new KafkaEmbedded(brokerConfig, time); + + log.debug("Kafka instance is running at {}, connected to ZooKeeper at {}", + brokers[i].brokerList(), brokers[i].zookeeperConnect()); + } + } + + private void putIfAbsent(final Properties props, final String propertyKey, final Object propertyValue) { + if (!props.containsKey(propertyKey)) { + brokerConfig.put(propertyKey, propertyValue); + } + } + + /** + * Stop the Kafka cluster. 
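+ * <p>
+ * When the cluster has more than one broker, all topics are deleted first to avoid cascading
+ * leader elections while the brokers shut down; ZooKeeper is stopped last.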
+ */ + public void stop() { + if (brokers.length > 1) { + // delete the topics first to avoid cascading leader elections while shutting down the brokers + final Set topics = getAllTopicsInCluster(); + if (!topics.isEmpty()) { + try (final Admin adminClient = brokers[0].createAdminClient()) { + adminClient.deleteTopics(topics).all().get(); + } catch (final InterruptedException e) { + log.warn("Got interrupted while deleting topics in preparation for stopping embedded brokers", e); + throw new RuntimeException(e); + } catch (final ExecutionException | RuntimeException e) { + log.warn("Couldn't delete all topics before stopping brokers", e); + } + } + } + for (final KafkaEmbedded broker : brokers) { + broker.stopAsync(); + } + for (final KafkaEmbedded broker : brokers) { + broker.awaitStoppedAndPurge(); + } + zookeeper.shutdown(); + } + + /** + * The ZooKeeper connection string aka `zookeeper.connect` in `hostnameOrIp:port` format. + * Example: `127.0.0.1:2181`. + *

+ * You can use this to e.g. tell Kafka brokers how to connect to this instance. + */ + public String zKConnectString() { + return "127.0.0.1:" + zookeeper.port(); + } + + /** + * This cluster's `bootstrap.servers` value. Example: `127.0.0.1:9092`. + * <p>
                + * You can use this to tell Kafka producers how to connect to this cluster. + */ + public String bootstrapServers() { + return brokers[0].brokerList(); + } + + /** + * Create multiple Kafka topics each with 1 partition and a replication factor of 1. + * + * @param topics The name of the topics. + */ + public void createTopics(final String... topics) throws InterruptedException { + for (final String topic : topics) { + createTopic(topic, 1, 1, Collections.emptyMap()); + } + } + + /** + * Create a Kafka topic with 1 partition and a replication factor of 1. + * + * @param topic The name of the topic. + */ + public void createTopic(final String topic) throws InterruptedException { + createTopic(topic, 1, 1, Collections.emptyMap()); + } + + /** + * Create a Kafka topic with the given parameters. + * + * @param topic The name of the topic. + * @param partitions The number of partitions for this topic. + * @param replication The replication factor for (the partitions of) this topic. + */ + public void createTopic(final String topic, final int partitions, final int replication) throws InterruptedException { + createTopic(topic, partitions, replication, Collections.emptyMap()); + } + + /** + * Create a Kafka topic with the given parameters. + * + * @param topic The name of the topic. + * @param partitions The number of partitions for this topic. + * @param replication The replication factor for (partitions of) this topic. + * @param topicConfig Additional topic-level configuration settings. + */ + public void createTopic(final String topic, + final int partitions, + final int replication, + final Map topicConfig) throws InterruptedException { + brokers[0].createTopic(topic, partitions, replication, topicConfig); + final List topicPartitions = new ArrayList<>(); + for (int partition = 0; partition < partitions; partition++) { + topicPartitions.add(new TopicPartition(topic, partition)); + } + IntegrationTestUtils.waitForTopicPartitions(brokers(), topicPartitions, TOPIC_CREATION_TIMEOUT); + } + + /** + * Deletes a topic returns immediately. + * + * @param topic the name of the topic + */ + public void deleteTopic(final String topic) throws InterruptedException { + deleteTopicsAndWait(-1L, topic); + } + + /** + * Deletes a topic and blocks for max 30 sec until the topic got deleted. + * + * @param topic the name of the topic + */ + public void deleteTopicAndWait(final String topic) throws InterruptedException { + deleteTopicsAndWait(TOPIC_DELETION_TIMEOUT, topic); + } + + /** + * Deletes multiple topics returns immediately. + * + * @param topics the name of the topics + */ + public void deleteTopics(final String... topics) throws InterruptedException { + deleteTopicsAndWait(-1, topics); + } + + /** + * Deletes multiple topics and blocks for max 30 sec until all topics got deleted. + * + * @param topics the name of the topics + */ + public void deleteTopicsAndWait(final String... topics) throws InterruptedException { + deleteTopicsAndWait(TOPIC_DELETION_TIMEOUT, topics); + } + + /** + * Deletes multiple topics and blocks until all topics got deleted. + * + * @param timeoutMs the max time to wait for the topics to be deleted (does not block if {@code <= 0}) + * @param topics the name of the topics + */ + public void deleteTopicsAndWait(final long timeoutMs, final String... 
topics) throws InterruptedException { + for (final String topic : topics) { + try { + brokers[0].deleteTopic(topic); + } catch (final UnknownTopicOrPartitionException ignored) { } + } + + if (timeoutMs > 0) { + TestUtils.waitForCondition(new TopicsDeletedCondition(topics), timeoutMs, "Topics not deleted after " + timeoutMs + " milli seconds."); + } + } + + /** + * Deletes all topics and blocks until all topics got deleted. + * + * @param timeoutMs the max time to wait for the topics to be deleted (does not block if {@code <= 0}) + */ + public void deleteAllTopicsAndWait(final long timeoutMs) throws InterruptedException { + final Set topics = getAllTopicsInCluster(); + for (final String topic : topics) { + try { + brokers[0].deleteTopic(topic); + } catch (final UnknownTopicOrPartitionException ignored) { } + } + + if (timeoutMs > 0) { + TestUtils.waitForCondition(new TopicsDeletedCondition(topics), timeoutMs, "Topics not deleted after " + timeoutMs + " milli seconds."); + } + } + + public void waitForRemainingTopics(final long timeoutMs, final String... topics) throws InterruptedException { + TestUtils.waitForCondition(new TopicsRemainingCondition(topics), timeoutMs, "Topics are not expected after " + timeoutMs + " milli seconds."); + } + + private final class TopicsDeletedCondition implements TestCondition { + final Set deletedTopics = new HashSet<>(); + + private TopicsDeletedCondition(final String... topics) { + Collections.addAll(deletedTopics, topics); + } + + private TopicsDeletedCondition(final Collection topics) { + deletedTopics.addAll(topics); + } + + @Override + public boolean conditionMet() { + final Set allTopics = getAllTopicsInCluster(); + return !allTopics.removeAll(deletedTopics); + } + } + + private final class TopicsRemainingCondition implements TestCondition { + final Set remainingTopics = new HashSet<>(); + + private TopicsRemainingCondition(final String... topics) { + Collections.addAll(remainingTopics, topics); + } + + @Override + public boolean conditionMet() { + final Set allTopics = getAllTopicsInCluster(); + return allTopics.equals(remainingTopics); + } + } + + private List brokers() { + final List servers = new ArrayList<>(); + for (final KafkaEmbedded broker : brokers) { + servers.add(broker.kafkaServer()); + } + return servers; + } + + public Properties getLogConfig(final String topic) { + return brokers[0].kafkaServer().zkClient().getEntityConfigs(ConfigType.Topic(), topic); + } + + public Set getAllTopicsInCluster() { + final scala.collection.Iterator topicsIterator = brokers[0].kafkaServer().zkClient().getAllTopicsInCluster(false).iterator(); + final Set topics = new HashSet<>(); + while (topicsIterator.hasNext()) { + topics.add(topicsIterator.next()); + } + return topics; + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/integration/utils/IntegrationTestUtils.java b/streams/src/test/java/org/apache/kafka/streams/integration/utils/IntegrationTestUtils.java new file mode 100644 index 0000000..da457cb --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/integration/utils/IntegrationTestUtils.java @@ -0,0 +1,1371 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.integration.utils; + +import kafka.api.Request; +import kafka.server.KafkaServer; +import kafka.server.MetadataCache; +import org.apache.kafka.clients.admin.Admin; +import org.apache.kafka.clients.admin.ConsumerGroupDescription; +import org.apache.kafka.clients.consumer.Consumer; +import org.apache.kafka.clients.consumer.ConsumerConfig; +import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.clients.consumer.ConsumerRecords; +import org.apache.kafka.clients.consumer.KafkaConsumer; +import org.apache.kafka.clients.producer.KafkaProducer; +import org.apache.kafka.clients.producer.Producer; +import org.apache.kafka.clients.producer.ProducerRecord; +import org.apache.kafka.clients.producer.RecordMetadata; +import org.apache.kafka.common.Metric; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.header.Headers; +import org.apache.kafka.common.message.UpdateMetadataRequestData.UpdateMetadataPartitionState; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.streams.KafkaStreams; +import org.apache.kafka.streams.KafkaStreams.State; +import org.apache.kafka.streams.KafkaStreams.StateListener; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.KeyValueTimestamp; +import org.apache.kafka.streams.StoreQueryParameters; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.errors.InvalidStateStoreException; +import org.apache.kafka.streams.processor.StateRestoreListener; +import org.apache.kafka.streams.processor.internals.StreamThread; +import org.apache.kafka.streams.processor.internals.ThreadStateTransitionValidator; +import org.apache.kafka.streams.processor.internals.assignment.AssignorConfiguration.AssignmentListener; +import org.apache.kafka.streams.processor.internals.namedtopology.KafkaStreamsNamedTopologyWrapper; +import org.apache.kafka.streams.state.QueryableStoreType; +import org.apache.kafka.test.TestCondition; +import org.apache.kafka.test.TestUtils; +import org.junit.rules.TestName; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import scala.Option; + +import java.io.File; +import java.io.IOException; +import java.lang.reflect.Field; +import java.nio.file.Paths; +import java.time.Duration; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.Iterator; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; +import java.util.Objects; +import java.util.Optional; +import java.util.Properties; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.locks.Condition; +import 
java.util.concurrent.locks.Lock; +import java.util.concurrent.locks.ReentrantLock; +import java.util.stream.Collectors; + +import static org.apache.kafka.test.TestUtils.retryOnExceptionWithTimeout; +import static org.apache.kafka.test.TestUtils.waitForCondition; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.greaterThanOrEqualTo; +import static org.hamcrest.Matchers.is; +import static org.junit.Assert.fail; + +/** + * Utility functions to make integration testing more convenient. + */ +public class IntegrationTestUtils { + + public static final long DEFAULT_TIMEOUT = 60 * 1000L; + private static final Logger LOG = LoggerFactory.getLogger(IntegrationTestUtils.class); + + /* + * Records state transition for StreamThread + */ + public static class StateListenerStub implements StreamThread.StateListener { + boolean toPendingShutdownSeen = false; + @Override + public void onChange(final Thread thread, + final ThreadStateTransitionValidator newState, + final ThreadStateTransitionValidator oldState) { + if (newState == StreamThread.State.PENDING_SHUTDOWN) { + toPendingShutdownSeen = true; + } + } + + public boolean transitToPendingShutdownSeen() { + return toPendingShutdownSeen; + } + } + + /** + * Gives a test name that is safe to be used in application ids, topic names, etc. + * The name is safe even for parameterized methods. + */ + public static String safeUniqueTestName(final Class testClass, final TestName testName) { + return (testClass.getSimpleName() + testName.getMethodName()) + .replace('.', '_') + .replace('[', '_') + .replace(']', '_') + .replace(' ', '_') + .replace('=', '_'); + } + + /** + * Removes local state stores. Useful to reset state in-between integration test runs. + * + * @param streamsConfiguration Streams configuration settings + */ + public static void purgeLocalStreamsState(final Properties streamsConfiguration) throws IOException { + final String tmpDir = TestUtils.IO_TMP_DIR.getPath(); + final String path = streamsConfiguration.getProperty(StreamsConfig.STATE_DIR_CONFIG); + if (path != null) { + final File node = Paths.get(path).normalize().toFile(); + // Only purge state when it's under java.io.tmpdir. This is a safety net to prevent accidentally + // deleting important local directory trees. + if (node.getAbsolutePath().startsWith(tmpDir)) { + Utils.delete(new File(node.getAbsolutePath())); + } + } + } + + /** + * Removes local state stores. Useful to reset state in-between integration test runs. + * + * @param streamsConfigurations Streams configuration settings + */ + public static void purgeLocalStreamsState(final Collection streamsConfigurations) throws IOException { + for (final Properties streamsConfig : streamsConfigurations) { + purgeLocalStreamsState(streamsConfig); + } + } + + public static void cleanStateBeforeTest(final EmbeddedKafkaCluster cluster, final String... topics) { + cleanStateBeforeTest(cluster, 1, topics); + } + + public static void cleanStateBeforeTest(final EmbeddedKafkaCluster cluster, + final int partitionCount, + final String... 
topics) { + try { + cluster.deleteAllTopicsAndWait(DEFAULT_TIMEOUT); + for (final String topic : topics) { + cluster.createTopic(topic, partitionCount, 1); + } + } catch (final InterruptedException e) { + throw new RuntimeException(e); + } + } + + public static void quietlyCleanStateAfterTest(final EmbeddedKafkaCluster cluster, final KafkaStreams driver) { + try { + driver.cleanUp(); + cluster.deleteAllTopicsAndWait(DEFAULT_TIMEOUT); + } catch (final RuntimeException | InterruptedException e) { + LOG.warn("Ignoring failure to clean test state", e); + } + } + + /** + * @param topic Kafka topic to write the data records to + * @param records Data records to write to Kafka + * @param producerConfig Kafka producer configuration + * @param time Timestamp provider + * @param Key type of the data records + * @param Value type of the data records + */ + public static void produceKeyValuesSynchronously(final String topic, + final Collection> records, + final Properties producerConfig, + final Time time) { + produceKeyValuesSynchronously(topic, records, producerConfig, time, false); + } + + /** + * @param topic Kafka topic to write the data records to + * @param records Data records to write to Kafka + * @param producerConfig Kafka producer configuration + * @param headers {@link Headers} of the data records + * @param time Timestamp provider + * @param Key type of the data records + * @param Value type of the data records + */ + public static void produceKeyValuesSynchronously(final String topic, + final Collection> records, + final Properties producerConfig, + final Headers headers, + final Time time) { + produceKeyValuesSynchronously(topic, records, producerConfig, headers, time, false); + } + + /** + * @param topic Kafka topic to write the data records to + * @param records Data records to write to Kafka + * @param producerConfig Kafka producer configuration + * @param time Timestamp provider + * @param enableTransactions Send messages in a transaction + * @param Key type of the data records + * @param Value type of the data records + */ + public static void produceKeyValuesSynchronously(final String topic, + final Collection> records, + final Properties producerConfig, + final Time time, + final boolean enableTransactions) { + produceKeyValuesSynchronously(topic, records, producerConfig, null, time, enableTransactions); + } + + /** + * @param topic Kafka topic to write the data records to + * @param records Data records to write to Kafka + * @param producerConfig Kafka producer configuration + * @param headers {@link Headers} of the data records + * @param time Timestamp provider + * @param enableTransactions Send messages in a transaction + * @param Key type of the data records + * @param Value type of the data records + */ + public static void produceKeyValuesSynchronously(final String topic, + final Collection> records, + final Properties producerConfig, + final Headers headers, + final Time time, + final boolean enableTransactions) { + try (final Producer producer = new KafkaProducer<>(producerConfig)) { + if (enableTransactions) { + producer.initTransactions(); + producer.beginTransaction(); + } + for (final KeyValue record : records) { + producer.send(new ProducerRecord<>(topic, null, time.milliseconds(), record.key, record.value, headers)); + time.sleep(1L); + } + if (enableTransactions) { + producer.commitTransaction(); + } else { + producer.flush(); + } + } + } + + /** + * @param topic Kafka topic to write the data records to + * @param records Data records to write to Kafka + * @param 
producerConfig Kafka producer configuration + * @param timestamp Timestamp of the record + * @param Key type of the data records + * @param Value type of the data records + */ + public static void produceKeyValuesSynchronouslyWithTimestamp(final String topic, + final Collection> records, + final Properties producerConfig, + final Long timestamp) { + produceKeyValuesSynchronouslyWithTimestamp(topic, records, producerConfig, timestamp, false); + } + + /** + * @param topic Kafka topic to write the data records to + * @param records Data records to write to Kafka + * @param producerConfig Kafka producer configuration + * @param timestamp Timestamp of the record + * @param enableTransactions Send messages in a transaction + * @param Key type of the data records + * @param Value type of the data records + */ + public static void produceKeyValuesSynchronouslyWithTimestamp(final String topic, + final Collection> records, + final Properties producerConfig, + final Long timestamp, + final boolean enableTransactions) { + produceKeyValuesSynchronouslyWithTimestamp(topic, records, producerConfig, null, timestamp, enableTransactions); + } + + /** + * @param topic Kafka topic to write the data records to + * @param records Data records to write to Kafka + * @param producerConfig Kafka producer configuration + * @param headers {@link Headers} of the data records + * @param timestamp Timestamp of the record + * @param enableTransactions Send messages in a transaction + * @param Key type of the data records + * @param Value type of the data records + */ + public static void produceKeyValuesSynchronouslyWithTimestamp(final String topic, + final Collection> records, + final Properties producerConfig, + final Headers headers, + final Long timestamp, + final boolean enableTransactions) { + try (final Producer producer = new KafkaProducer<>(producerConfig)) { + if (enableTransactions) { + producer.initTransactions(); + producer.beginTransaction(); + } + for (final KeyValue record : records) { + producer.send(new ProducerRecord<>(topic, null, timestamp, record.key, record.value, headers)); + } + if (enableTransactions) { + producer.commitTransaction(); + } + } + } + + public static void produceSynchronously(final Properties producerConfig, + final boolean eos, + final String topic, + final Optional partition, + final List> toProduce) { + try (final Producer producer = new KafkaProducer<>(producerConfig)) { + if (eos) { + producer.initTransactions(); + producer.beginTransaction(); + } + final LinkedList> futures = new LinkedList<>(); + for (final KeyValueTimestamp record : toProduce) { + final Future f = producer.send( + new ProducerRecord<>( + topic, + partition.orElse(null), + record.timestamp(), + record.key(), + record.value(), + null + ) + ); + futures.add(f); + } + + if (eos) { + producer.commitTransaction(); + } else { + producer.flush(); + } + + for (final Future future : futures) { + try { + future.get(); + } catch (final InterruptedException | ExecutionException e) { + throw new RuntimeException(e); + } + } + } + } + + /** + * Produce data records and send them synchronously in an aborted transaction; that is, a transaction is started for + * each data record but not committed. 
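+ * <p>
+ * Because every transaction is aborted, consumers reading with {@code isolation.level=read_committed}
+ * should never see these records.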
+ * + * @param topic Kafka topic to write the data records to + * @param records Data records to write to Kafka + * @param producerConfig Kafka producer configuration + * @param timestamp Timestamp of the record + * @param Key type of the data records + * @param Value type of the data records + */ + public static void produceAbortedKeyValuesSynchronouslyWithTimestamp(final String topic, + final Collection> records, + final Properties producerConfig, + final Long timestamp) throws Exception { + try (final Producer producer = new KafkaProducer<>(producerConfig)) { + producer.initTransactions(); + for (final KeyValue record : records) { + producer.beginTransaction(); + final Future f = producer + .send(new ProducerRecord<>(topic, null, timestamp, record.key, record.value)); + f.get(); + producer.abortTransaction(); + } + } + } + + /** + * @param topic Kafka topic to write the data records to + * @param records Data records to write to Kafka + * @param producerConfig Kafka producer configuration + * @param time Timestamp provider + * @param Value type of the data records + */ + public static void produceValuesSynchronously(final String topic, + final Collection records, + final Properties producerConfig, + final Time time) { + produceValuesSynchronously(topic, records, producerConfig, time, false); + } + + /** + * @param topic Kafka topic to write the data records to + * @param records Data records to write to Kafka + * @param producerConfig Kafka producer configuration + * @param time Timestamp provider + * @param enableTransactions Send messages in a transaction + * @param Value type of the data records + */ + @SuppressWarnings("WeakerAccess") + public static void produceValuesSynchronously(final String topic, + final Collection records, + final Properties producerConfig, + final Time time, + final boolean enableTransactions) { + final Collection> keyedRecords = new ArrayList<>(); + for (final V value : records) { + final KeyValue kv = new KeyValue<>(null, value); + keyedRecords.add(kv); + } + produceKeyValuesSynchronously(topic, keyedRecords, producerConfig, time, enableTransactions); + } + + /** + * Wait for streams to "finish", based on the consumer lag metric. Includes only the main consumer, for + * completion of standbys as well see {@link #waitForStandbyCompletion} + * + * Caveats: + * - Inputs must be finite, fully loaded, and flushed before this method is called + * - expectedPartitions is the total number of partitions to watch the lag on, including both input and internal. + * It's somewhat ok to get this wrong, as the main failure case would be an immediate return due to the clients + * not being initialized, which you can avoid with any non-zero value. But it's probably better to get it right ;) + */ + public static void waitForCompletion(final KafkaStreams streams, + final int expectedPartitions, + final long timeoutMilliseconds) { + final long start = System.currentTimeMillis(); + while (true) { + int lagMetrics = 0; + double totalLag = 0.0; + for (final Metric metric : streams.metrics().values()) { + if (metric.metricName().name().equals("records-lag")) { + if (!metric.metricName().tags().get("client-id").endsWith("restore-consumer")) { + lagMetrics++; + totalLag += ((Number) metric.metricValue()).doubleValue(); + } + } + } + if (lagMetrics >= expectedPartitions && totalLag == 0.0) { + return; + } + if (System.currentTimeMillis() - start >= timeoutMilliseconds) { + throw new RuntimeException(String.format( + "Timed out waiting for completion. 
lagMetrics=[%s/%s] totalLag=[%s]", + lagMetrics, expectedPartitions, totalLag + )); + } + } + } + + /** + * Wait for streams to "finish" processing standbys, based on the (restore) consumer lag metric. Includes only the + * restore consumer, for completion of active tasks see {@link #waitForCompletion} + * + * Caveats: + * - Inputs must be finite, fully loaded, and flushed before this method is called + * - expectedPartitions is the total number of partitions to watch the lag on, including both input and internal. + * It's somewhat ok to get this wrong, as the main failure case would be an immediate return due to the clients + * not being initialized, which you can avoid with any non-zero value. But it's probably better to get it right ;) + */ + public static void waitForStandbyCompletion(final KafkaStreams streams, + final int expectedPartitions, + final long timeoutMilliseconds) { + final long start = System.currentTimeMillis(); + while (true) { + int lagMetrics = 0; + double totalLag = 0.0; + for (final Metric metric : streams.metrics().values()) { + if (metric.metricName().name().equals("records-lag")) { + if (metric.metricName().tags().get("client-id").endsWith("restore-consumer")) { + lagMetrics++; + totalLag += ((Number) metric.metricValue()).doubleValue(); + } + } + } + if (lagMetrics >= expectedPartitions && totalLag == 0.0) { + return; + } + if (System.currentTimeMillis() - start >= timeoutMilliseconds) { + throw new RuntimeException(String.format( + "Timed out waiting for completion. lagMetrics=[%s/%s] totalLag=[%s]", + lagMetrics, expectedPartitions, totalLag + )); + } + } + } + + /** + * Wait until enough data (consumer records) has been consumed. + * + * @param consumerConfig Kafka Consumer configuration + * @param topic Kafka topic to consume from + * @param expectedNumRecords Minimum number of expected records + * @param Key type of the data records + * @param Value type of the data records + * @return All the records consumed, or null if no records are consumed + */ + @SuppressWarnings("WeakerAccess") + public static List> waitUntilMinRecordsReceived(final Properties consumerConfig, + final String topic, + final int expectedNumRecords) throws Exception { + return waitUntilMinRecordsReceived(consumerConfig, topic, expectedNumRecords, DEFAULT_TIMEOUT); + } + + /** + * Wait until enough data (consumer records) has been consumed. 
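+ * <p>
+ * An illustrative call (topic name, count, and timeout are placeholders):
+ * <pre>{@code
+ * final List<ConsumerRecord<String, String>> records =
+ *     waitUntilMinRecordsReceived(consumerConfig, "output-topic", 3, 30_000L);
+ * }</pre>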
+ * + * @param consumerConfig Kafka Consumer configuration + * @param topic Kafka topic to consume from + * @param expectedNumRecords Minimum number of expected records + * @param waitTime Upper bound of waiting time in milliseconds + * @param Key type of the data records + * @param Value type of the data records + * @return All the records consumed, or null if no records are consumed + */ + @SuppressWarnings("WeakerAccess") + public static List> waitUntilMinRecordsReceived(final Properties consumerConfig, + final String topic, + final int expectedNumRecords, + final long waitTime) throws Exception { + final List> accumData = new ArrayList<>(); + final String reason = String.format( + "Did not receive all %d records from topic %s within %d ms", + expectedNumRecords, + topic, + waitTime + ); + try (final Consumer consumer = createConsumer(consumerConfig)) { + retryOnExceptionWithTimeout(waitTime, () -> { + final List> readData = + readRecords(topic, consumer, waitTime, expectedNumRecords); + accumData.addAll(readData); + assertThat(reason, accumData.size(), is(greaterThanOrEqualTo(expectedNumRecords))); + }); + } + return accumData; + } + + /** + * Wait until enough data (key-value records) has been consumed. + * + * @param consumerConfig Kafka Consumer configuration + * @param topic Kafka topic to consume from + * @param expectedNumRecords Minimum number of expected records + * @param Key type of the data records + * @param Value type of the data records + * @return All the records consumed, or null if no records are consumed + */ + public static List> waitUntilMinKeyValueRecordsReceived(final Properties consumerConfig, + final String topic, + final int expectedNumRecords) throws Exception { + return waitUntilMinKeyValueRecordsReceived(consumerConfig, topic, expectedNumRecords, DEFAULT_TIMEOUT); + } + + /** + * Wait until enough data (key-value records) has been consumed. + * + * @param consumerConfig Kafka Consumer configuration + * @param topic Kafka topic to consume from + * @param expectedNumRecords Minimum number of expected records + * @param waitTime Upper bound of waiting time in milliseconds + * @param Key type of the data records + * @param Value type of the data records + * @return All the records consumed, or null if no records are consumed + * @throws AssertionError if the given wait time elapses + */ + public static List> waitUntilMinKeyValueRecordsReceived(final Properties consumerConfig, + final String topic, + final int expectedNumRecords, + final long waitTime) throws Exception { + final List> accumData = new ArrayList<>(); + final String reason = String.format( + "Did not receive all %d records from topic %s within %d ms", + expectedNumRecords, + topic, + waitTime + ); + try (final Consumer consumer = createConsumer(consumerConfig)) { + retryOnExceptionWithTimeout(waitTime, () -> { + final List> readData = + readKeyValues(topic, consumer, waitTime, expectedNumRecords); + accumData.addAll(readData); + assertThat(reason + ", currently accumulated data is " + accumData, accumData.size(), is(greaterThanOrEqualTo(expectedNumRecords))); + }); + } + return accumData; + } + + /** + * Wait until enough data (timestamped key-value records) has been consumed. 
+ * + * @param consumerConfig Kafka Consumer configuration + * @param topic Kafka topic to consume from + * @param expectedNumRecords Minimum number of expected records + * @param waitTime Upper bound of waiting time in milliseconds + * @return All the records consumed, or null if no records are consumed + * @param Key type of the data records + * @param Value type of the data records + */ + public static List> waitUntilMinKeyValueWithTimestampRecordsReceived(final Properties consumerConfig, + final String topic, + final int expectedNumRecords, + final long waitTime) throws Exception { + final List> accumData = new ArrayList<>(); + final String reason = String.format( + "Did not receive all %d records from topic %s within %d ms", + expectedNumRecords, + topic, + waitTime + ); + try (final Consumer consumer = createConsumer(consumerConfig)) { + retryOnExceptionWithTimeout(waitTime, () -> { + final List> readData = + readKeyValuesWithTimestamp(topic, consumer, waitTime, expectedNumRecords); + accumData.addAll(readData); + assertThat(reason, accumData.size(), is(greaterThanOrEqualTo(expectedNumRecords))); + }); + } + return accumData; + } + + /** + * Wait until final key-value mappings have been consumed. + * + * @param consumerConfig Kafka Consumer configuration + * @param topic Kafka topic to consume from + * @param expectedRecords Expected key-value mappings + * @param Key type of the data records + * @param Value type of the data records + * @return All the mappings consumed, or null if no records are consumed + */ + public static List> waitUntilFinalKeyValueRecordsReceived(final Properties consumerConfig, + final String topic, + final List> expectedRecords) throws Exception { + return waitUntilFinalKeyValueRecordsReceived(consumerConfig, topic, expectedRecords, DEFAULT_TIMEOUT); + } + + /** + * Wait until final key-value mappings have been consumed. + * + * @param consumerConfig Kafka Consumer configuration + * @param topic Kafka topic to consume from + * @param expectedRecords Expected key-value mappings + * @param Key type of the data records + * @param Value type of the data records + * @return All the mappings consumed, or null if no records are consumed + */ + public static List> waitUntilFinalKeyValueTimestampRecordsReceived(final Properties consumerConfig, + final String topic, + final List> expectedRecords) throws Exception { + return waitUntilFinalKeyValueRecordsReceived(consumerConfig, topic, expectedRecords, DEFAULT_TIMEOUT, true); + } + + /** + * Wait until final key-value mappings have been consumed. 
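+ * <p>
+ * "Final" means that intermediate updates are ignored: the received records are filtered down to the
+ * expected mappings and, for each key, must match the expected mappings in order.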
+ * + * @param consumerConfig Kafka Consumer configuration + * @param topic Kafka topic to consume from + * @param expectedRecords Expected key-value mappings + * @param waitTime Upper bound of waiting time in milliseconds + * @param Key type of the data records + * @param Value type of the data records + * @return All the mappings consumed, or null if no records are consumed + */ + @SuppressWarnings("WeakerAccess") + public static List> waitUntilFinalKeyValueRecordsReceived(final Properties consumerConfig, + final String topic, + final List> expectedRecords, + final long waitTime) throws Exception { + return waitUntilFinalKeyValueRecordsReceived(consumerConfig, topic, expectedRecords, waitTime, false); + } + + @SuppressWarnings("unchecked") + private static List waitUntilFinalKeyValueRecordsReceived(final Properties consumerConfig, + final String topic, + final List expectedRecords, + final long waitTime, + final boolean withTimestamp) throws Exception { + final List accumData = new ArrayList<>(); + try (final Consumer consumer = createConsumer(consumerConfig)) { + final TestCondition valuesRead = () -> { + final List readData; + if (withTimestamp) { + readData = (List) readKeyValuesWithTimestamp(topic, consumer, waitTime, expectedRecords.size()); + } else { + readData = (List) readKeyValues(topic, consumer, waitTime, expectedRecords.size()); + } + accumData.addAll(readData); + + // filter out all intermediate records we don't want + final List accumulatedActual = accumData + .stream() + .filter(expectedRecords::contains) + .collect(Collectors.toList()); + + // still need to check that for each key, the ordering is expected + final Map> finalAccumData = new HashMap<>(); + for (final T kv : accumulatedActual) { + finalAccumData.computeIfAbsent( + withTimestamp ? ((KeyValueTimestamp) kv).key() : ((KeyValue) kv).key, + key -> new ArrayList<>()).add(kv); + } + final Map> finalExpected = new HashMap<>(); + for (final T kv : expectedRecords) { + finalExpected.computeIfAbsent( + withTimestamp ? ((KeyValueTimestamp) kv).key() : ((KeyValue) kv).key, + key -> new ArrayList<>()).add(kv); + } + + // returns true only if the remaining records in both lists are the same and in the same order + // and the last record received matches the last expected record + return finalAccumData.equals(finalExpected); + + }; + final String conditionDetails = "Did not receive all " + expectedRecords + " records from topic " + + topic + " (got " + accumData + ")"; + TestUtils.waitForCondition(valuesRead, waitTime, conditionDetails); + } + return accumData; + } + + /** + * Wait until enough data (value records) has been consumed. + * + * @param consumerConfig Kafka Consumer configuration + * @param topic Topic to consume from + * @param expectedNumRecords Minimum number of expected records + * @return All the records consumed, or null if no records are consumed + * @throws AssertionError if the given wait time elapses + */ + public static List waitUntilMinValuesRecordsReceived(final Properties consumerConfig, + final String topic, + final int expectedNumRecords) throws Exception { + return waitUntilMinValuesRecordsReceived(consumerConfig, topic, expectedNumRecords, DEFAULT_TIMEOUT); + } + + /** + * Wait until enough data (value records) has been consumed. 
+ * + * @param consumerConfig Kafka Consumer configuration + * @param topic Topic to consume from + * @param expectedNumRecords Minimum number of expected records + * @param waitTime Upper bound of waiting time in milliseconds + * @return All the records consumed, or null if no records are consumed + * @throws AssertionError if the given wait time elapses + */ + public static List waitUntilMinValuesRecordsReceived(final Properties consumerConfig, + final String topic, + final int expectedNumRecords, + final long waitTime) throws Exception { + final List accumData = new ArrayList<>(); + final String reason = String.format( + "Did not receive all %d records from topic %s within %d ms", + expectedNumRecords, + topic, + waitTime + ); + try (final Consumer consumer = createConsumer(consumerConfig)) { + retryOnExceptionWithTimeout(waitTime, () -> { + final List readData = + readValues(topic, consumer, waitTime, expectedNumRecords); + accumData.addAll(readData); + assertThat(reason, accumData.size(), is(greaterThanOrEqualTo(expectedNumRecords))); + }); + } + return accumData; + } + + @SuppressWarnings("WeakerAccess") + public static void waitForTopicPartitions(final List servers, + final List partitions, + final long timeout) throws InterruptedException { + final long end = System.currentTimeMillis() + timeout; + for (final TopicPartition partition : partitions) { + final long remaining = end - System.currentTimeMillis(); + if (remaining <= 0) { + throw new AssertionError("timed out while waiting for partitions to become available. Timeout=" + timeout); + } + waitUntilMetadataIsPropagated(servers, partition.topic(), partition.partition(), remaining); + } + } + + private static void waitUntilMetadataIsPropagated(final List servers, + final String topic, + final int partition, + final long timeout) throws InterruptedException { + final String baseReason = String.format("Metadata for topic=%s partition=%d was not propagated to all brokers within %d ms. ", + topic, partition, timeout); + + retryOnExceptionWithTimeout(timeout, () -> { + final List emptyPartitionInfos = new ArrayList<>(); + final List invalidBrokerIds = new ArrayList<>(); + + for (final KafkaServer server : servers) { + final MetadataCache metadataCache = server.dataPlaneRequestProcessor().metadataCache(); + final Option partitionInfo = + metadataCache.getPartitionInfo(topic, partition); + + if (partitionInfo.isEmpty()) { + emptyPartitionInfos.add(server); + continue; + } + + final UpdateMetadataPartitionState metadataPartitionState = partitionInfo.get(); + if (!Request.isValidBrokerId(metadataPartitionState.leader())) { + invalidBrokerIds.add(server); + } + } + + final String reason = baseReason + ". Brokers without partition info: " + emptyPartitionInfos + + ". Brokers with invalid broker id for partition leader: " + invalidBrokerIds; + assertThat(reason, emptyPartitionInfos.isEmpty() && invalidBrokerIds.isEmpty()); + }); + } + + /** + * Starts the given {@link KafkaStreams} instances and waits for all of them to reach the + * {@link State#RUNNING} state at the same time. Note that states may change between the time + * that this method returns and the calling function executes its next statement.
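+ * <p>
+ * An illustrative call, as used by the tests in this patch:
+ * <pre>{@code
+ * startApplicationAndWaitUntilRunning(Collections.singletonList(kafkaStreams), Duration.ofSeconds(30));
+ * }</pre>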

                + * + * If the application is already started, use {@link #waitForApplicationState(List, State, Duration)} + * to wait for instances to reach {@link State#RUNNING} state. + * + * @param streamsList the list of streams instances to run. + * @param timeout the time to wait for the streams to all be in {@link State#RUNNING} state. + */ + public static void startApplicationAndWaitUntilRunning(final List streamsList, + final Duration timeout) throws Exception { + final Lock stateLock = new ReentrantLock(); + final Condition stateUpdate = stateLock.newCondition(); + final Map stateMap = new HashMap<>(); + for (final KafkaStreams streams : streamsList) { + stateMap.put(streams, streams.state()); + final StateListener prevStateListener = getStateListener(streams); + final StateListener newStateListener = (newState, oldState) -> { + stateLock.lock(); + try { + stateMap.put(streams, newState); + if (newState == State.RUNNING) { + if (stateMap.values().stream().allMatch(state -> state == State.RUNNING)) { + stateUpdate.signalAll(); + } + } + } finally { + stateLock.unlock(); + } + }; + + streams.setStateListener(prevStateListener != null + ? new CompositeStateListener(prevStateListener, newStateListener) + : newStateListener); + } + + for (final KafkaStreams streams : streamsList) { + streams.start(); + } + + final long expectedEnd = System.currentTimeMillis() + timeout.toMillis(); + stateLock.lock(); + try { + // We use while true here because we want to run this test at least once, even if the + // timeout has expired + while (true) { + final Map nonRunningStreams = new HashMap<>(); + for (final Entry entry : stateMap.entrySet()) { + if (entry.getValue() != State.RUNNING) { + nonRunningStreams.put(entry.getKey(), entry.getValue()); + } + } + + if (nonRunningStreams.isEmpty()) { + return; + } + + final long millisRemaining = expectedEnd - System.currentTimeMillis(); + if (millisRemaining <= 0) { + fail( + "Application did not reach a RUNNING state for all streams instances. " + + "Non-running instances: " + nonRunningStreams + ); + } + + stateUpdate.await(millisRemaining, TimeUnit.MILLISECONDS); + } + } finally { + stateLock.unlock(); + } + } + + /** + * Waits for the given {@link KafkaStreams} instances to all be in a specific {@link State}. + * Prefer {@link #startApplicationAndWaitUntilRunning(List, Duration)} when possible + * because this method uses polling, which can be more error prone and slightly slower. + * + * @param streamsList the list of streams instances to run. + * @param state the expected state that all the streams to be in within timeout + * @param timeout the time to wait for the streams to all be in the specific state. + * + * @throws InterruptedException if the streams doesn't change to the expected state in time. 
+ */ + public static void waitForApplicationState(final List streamsList, + final State state, + final Duration timeout) throws InterruptedException { + retryOnExceptionWithTimeout(timeout.toMillis(), () -> { + final Map streamsToStates = streamsList + .stream() + .collect(Collectors.toMap(stream -> stream, KafkaStreams::state)); + + final Map wrongStateMap = streamsToStates.entrySet() + .stream() + .filter(entry -> entry.getValue() != state) + .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); + + final String reason = String.format( + "Expected all streams instances in %s to be %s within %d ms, but the following were not: %s", + streamsList, + state, + timeout.toMillis(), + wrongStateMap + ); + assertThat(reason, wrongStateMap.isEmpty()); + }); + } + + private static class ConsumerGroupInactiveCondition implements TestCondition { + private final Admin adminClient; + private final String applicationId; + + private ConsumerGroupInactiveCondition(final Admin adminClient, + final String applicationId) { + this.adminClient = adminClient; + this.applicationId = applicationId; + } + + @Override + public boolean conditionMet() { + return isEmptyConsumerGroup(adminClient, applicationId); + } + } + + public static void waitForEmptyConsumerGroup(final Admin adminClient, + final String applicationId, + final long timeoutMs) throws Exception { + TestUtils.waitForCondition( + new IntegrationTestUtils.ConsumerGroupInactiveCondition(adminClient, applicationId), + timeoutMs, + "Test consumer group " + applicationId + " still active even after waiting " + timeoutMs + " ms." + ); + } + + public static boolean isEmptyConsumerGroup(final Admin adminClient, + final String applicationId) { + try { + final ConsumerGroupDescription groupDescription = + adminClient.describeConsumerGroups(Collections.singletonList(applicationId)) + .describedGroups() + .get(applicationId) + .get(); + return groupDescription.members().isEmpty(); + } catch (final ExecutionException | InterruptedException e) { + return false; + } + } + + private static StateListener getStateListener(final KafkaStreams streams) { + try { + if (streams instanceof KafkaStreamsNamedTopologyWrapper) { + final Field field = streams.getClass().getSuperclass().getDeclaredField("stateListener"); + field.setAccessible(true); + return (StateListener) field.get(streams); + } else { + final Field field = streams.getClass().getDeclaredField("stateListener"); + field.setAccessible(true); + return (StateListener) field.get(streams); + } + } catch (final IllegalAccessException | NoSuchFieldException e) { + throw new RuntimeException("Failed to get StateListener through reflection", e); + } + } + + public static void verifyKeyValueTimestamps(final Properties consumerConfig, + final String topic, + final List> expected) { + final List> results; + try { + results = waitUntilMinRecordsReceived(consumerConfig, topic, expected.size()); + } catch (final Exception e) { + throw new RuntimeException(e); + } + + if (results.size() != expected.size()) { + throw new AssertionError(printRecords(results) + " != " + expected); + } + final Iterator> expectedIterator = expected.iterator(); + for (final ConsumerRecord result : results) { + final KeyValueTimestamp expected1 = expectedIterator.next(); + try { + compareKeyValueTimestamp(result, expected1.key(), expected1.value(), expected1.timestamp()); + } catch (final AssertionError e) { + throw new AssertionError(printRecords(results) + " != " + expected, e); + } + } + } + + public static void 
verifyKeyValueTimestamps(final Properties consumerConfig, + final String topic, + final Set> expected) { + final List> results; + try { + results = waitUntilMinRecordsReceived(consumerConfig, topic, expected.size()); + } catch (final Exception e) { + throw new RuntimeException(e); + } + + if (results.size() != expected.size()) { + throw new AssertionError(printRecords(results) + " != " + expected); + } + + final Set> actual = + results.stream() + .map(result -> new KeyValueTimestamp<>(result.key(), result.value(), result.timestamp())) + .collect(Collectors.toSet()); + + assertThat(actual, equalTo(expected)); + } + + private static void compareKeyValueTimestamp(final ConsumerRecord record, + final K expectedKey, + final V expectedValue, + final long expectedTimestamp) { + Objects.requireNonNull(record); + final K recordKey = record.key(); + final V recordValue = record.value(); + final long recordTimestamp = record.timestamp(); + final AssertionError error = new AssertionError( + "Expected <" + expectedKey + ", " + expectedValue + "> with timestamp=" + expectedTimestamp + + " but was <" + recordKey + ", " + recordValue + "> with timestamp=" + recordTimestamp + ); + if (recordKey != null) { + if (!recordKey.equals(expectedKey)) { + throw error; + } + } else if (expectedKey != null) { + throw error; + } + if (recordValue != null) { + if (!recordValue.equals(expectedValue)) { + throw error; + } + } else if (expectedValue != null) { + throw error; + } + if (recordTimestamp != expectedTimestamp) { + throw error; + } + } + + private static String printRecords(final List> result) { + final StringBuilder resultStr = new StringBuilder(); + resultStr.append("[\n"); + for (final ConsumerRecord record : result) { + resultStr.append(" ").append(record.toString()).append("\n"); + } + resultStr.append("]"); + return resultStr.toString(); + } + + /** + * Returns up to `maxMessages` message-values from the topic. + * + * @param topic Kafka topic to read messages from + * @param consumer Kafka consumer + * @param waitTime Maximum wait time in milliseconds + * @param maxMessages Maximum number of messages to read via the consumer. + * @return The values retrieved via the consumer. + */ + private static List readValues(final String topic, + final Consumer consumer, + final long waitTime, + final int maxMessages) { + final List returnList = new ArrayList<>(); + final List> kvs = readKeyValues(topic, consumer, waitTime, maxMessages); + for (final KeyValue kv : kvs) { + returnList.add(kv.value); + } + return returnList; + } + + /** + * Returns up to `maxMessages` by reading via the provided consumer (the topic(s) to read from + * are already configured in the consumer). + * + * @param topic Kafka topic to read messages from + * @param consumer Kafka consumer + * @param waitTime Maximum wait time in milliseconds + * @param maxMessages Maximum number of messages to read via the consumer + * @return The KeyValue elements retrieved via the consumer + */ + private static List> readKeyValues(final String topic, + final Consumer consumer, + final long waitTime, + final int maxMessages) { + final List> consumedValues = new ArrayList<>(); + final List> records = readRecords(topic, consumer, waitTime, maxMessages); + for (final ConsumerRecord record : records) { + consumedValues.add(new KeyValue<>(record.key(), record.value())); + } + return consumedValues; + } + + /** + * Returns up to `maxMessages` by reading via the provided consumer (the topic(s) to read from + * are already configured in the consumer). 
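// A minimal sketch of verifyKeyValueTimestamps(...) above (illustration only; the consumer
// config, topic name "output-topic", and expected records are hypothetical).
final List<KeyValueTimestamp<String, Long>> expected = Arrays.asList(
    new KeyValueTimestamp<>("a", 1L, 10L),
    new KeyValueTimestamp<>("b", 2L, 20L));
// Fails with a descriptive AssertionError unless "output-topic" eventually yields exactly
// these key/value/timestamp triples, in this order.
IntegrationTestUtils.verifyKeyValueTimestamps(consumerConfig, "output-topic", expected);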
+ * + * @param topic Kafka topic to read messages from + * @param consumer Kafka consumer + * @param waitTime Maximum wait time in milliseconds + * @param maxMessages Maximum number of messages to read via the consumer + * @return The KeyValue elements retrieved via the consumer + */ + private static List> readKeyValuesWithTimestamp(final String topic, + final Consumer consumer, + final long waitTime, + final int maxMessages) { + final List> consumedValues = new ArrayList<>(); + final List> records = readRecords(topic, consumer, waitTime, maxMessages); + for (final ConsumerRecord record : records) { + consumedValues.add(new KeyValueTimestamp<>(record.key(), record.value(), record.timestamp())); + } + return consumedValues; + } + + private static List> readRecords(final String topic, + final Consumer consumer, + final long waitTime, + final int maxMessages) { + final List> consumerRecords; + consumer.subscribe(Collections.singletonList(topic)); + final int pollIntervalMs = 100; + consumerRecords = new ArrayList<>(); + int totalPollTimeMs = 0; + while (totalPollTimeMs < waitTime && + continueConsuming(consumerRecords.size(), maxMessages)) { + totalPollTimeMs += pollIntervalMs; + final ConsumerRecords records = consumer.poll(Duration.ofMillis(pollIntervalMs)); + + for (final ConsumerRecord record : records) { + consumerRecords.add(record); + } + } + return consumerRecords; + } + + private static boolean continueConsuming(final int messagesConsumed, final int maxMessages) { + return maxMessages > 0 && messagesConsumed < maxMessages; + } + + /** + * Sets up a {@link KafkaConsumer} from a copy of the given configuration that has + * {@link ConsumerConfig#AUTO_OFFSET_RESET_CONFIG} set to "earliest" and {@link ConsumerConfig#ENABLE_AUTO_COMMIT_CONFIG} + * set to "true" to prevent missing events as well as repeat consumption. 
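// Note that the readRecords(...) loop above bounds consumption by the accumulated poll interval
// (pollIntervalMs increments), not by wall-clock time. The same pattern in isolation, with a
// hypothetical consumer, topic, and limits:
consumer.subscribe(Collections.singletonList("output-topic"));
final List<ConsumerRecord<String, String>> collected = new ArrayList<>();
long polledMs = 0L;
while (polledMs < 30_000L && collected.size() < 10) {       // waitTime and maxMessages
    polledMs += 100L;
    for (final ConsumerRecord<String, String> record : consumer.poll(Duration.ofMillis(100L))) {
        collected.add(record);
    }
}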
+ * @param consumerConfig Consumer configuration + * @return Consumer + */ + private static KafkaConsumer createConsumer(final Properties consumerConfig) { + final Properties filtered = new Properties(); + filtered.putAll(consumerConfig); + filtered.setProperty(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "earliest"); + filtered.setProperty(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG, "true"); + return new KafkaConsumer<>(filtered); + } + + public static KafkaStreams getStartedStreams(final Properties streamsConfig, + final StreamsBuilder builder, + final boolean clean) { + final KafkaStreams driver = new KafkaStreams(builder.build(), streamsConfig); + if (clean) { + driver.cleanUp(); + } + driver.start(); + return driver; + } + + public static KafkaStreams getRunningStreams(final Properties streamsConfig, + final StreamsBuilder builder, + final boolean clean) { + final KafkaStreams driver = new KafkaStreams(builder.build(), streamsConfig); + if (clean) { + driver.cleanUp(); + } + final CountDownLatch latch = new CountDownLatch(1); + driver.setStateListener((newState, oldState) -> { + if (newState == State.RUNNING) { + latch.countDown(); + } + }); + driver.start(); + try { + latch.await(DEFAULT_TIMEOUT, TimeUnit.MILLISECONDS); + } catch (final InterruptedException e) { + throw new RuntimeException("Streams didn't start in time.", e); + } + return driver; + } + + public static S getStore(final String storeName, + final KafkaStreams streams, + final QueryableStoreType storeType) throws Exception { + return getStore(DEFAULT_TIMEOUT, storeName, streams, storeType); + } + + public static S getStore(final String storeName, + final KafkaStreams streams, + final boolean enableStaleQuery, + final QueryableStoreType storeType) throws Exception { + return getStore(DEFAULT_TIMEOUT, storeName, streams, enableStaleQuery, storeType); + } + + public static S getStore(final long waitTime, + final String storeName, + final KafkaStreams streams, + final QueryableStoreType storeType) throws Exception { + return getStore(waitTime, storeName, streams, false, storeType); + } + + public static S getStore(final long waitTime, + final String storeName, + final KafkaStreams streams, + final boolean enableStaleQuery, + final QueryableStoreType storeType) throws Exception { + final StoreQueryParameters param = enableStaleQuery ? 
+ StoreQueryParameters.fromNameAndType(storeName, storeType).enableStaleStores() : + StoreQueryParameters.fromNameAndType(storeName, storeType); + return getStore(waitTime, streams, param); + } + + public static S getStore(final KafkaStreams streams, + final StoreQueryParameters param) throws Exception { + return getStore(DEFAULT_TIMEOUT, streams, param); + } + + public static S getStore(final long waitTime, + final KafkaStreams streams, + final StoreQueryParameters param) throws Exception { + final long expectedEnd = System.currentTimeMillis() + waitTime; + while (true) { + try { + return streams.store(param); + } catch (final InvalidStateStoreException e) { + if (System.currentTimeMillis() > expectedEnd) { + throw e; + } + } catch (final Exception e) { + if (System.currentTimeMillis() > expectedEnd) { + throw new AssertionError(e); + } + } + Thread.sleep(Math.min(100L, waitTime)); + } + } + + public static class StableAssignmentListener implements AssignmentListener { + final AtomicInteger numStableAssignments = new AtomicInteger(0); + int nextExpectedNumStableAssignments; + + @Override + public void onAssignmentComplete(final boolean stable) { + if (stable) { + numStableAssignments.incrementAndGet(); + } + } + + public int numStableAssignments() { + return numStableAssignments.get(); + } + + /** + * Saves the current number of stable rebalances so that we can tell when the next stable assignment has been + * reached. This should be called once for every invocation of {@link #waitForNextStableAssignment(long)}, + * before the rebalance-triggering event. + */ + public void prepareForRebalance() { + nextExpectedNumStableAssignments = numStableAssignments.get() + 1; + } + + /** + * Waits for the assignment to stabilize after the group rebalances. You must call {@link #prepareForRebalance()} + * prior to the rebalance-triggering event before using this method to wait. + */ + public void waitForNextStableAssignment(final long maxWaitMs) throws InterruptedException { + waitForCondition( + () -> numStableAssignments() >= nextExpectedNumStableAssignments, + maxWaitMs, + () -> "Client did not reach " + nextExpectedNumStableAssignments + " stable assignments on time, " + + "numStableAssignments was " + numStableAssignments() + ); + } + } + + /** + * Tracks the offsets and number of restored records on a per-partition basis. 
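// A minimal sketch of the getStore(...) helpers above, which retry through rebalance-related
// InvalidStateStoreExceptions until the store is queryable (illustration only; the store name
// "counts-store", the types, and the running `streams` instance are hypothetical).
final ReadOnlyKeyValueStore<String, Long> store =
    IntegrationTestUtils.getStore("counts-store", streams, QueryableStoreTypes.<String, Long>keyValueStore());
// Or allow queries against stale/standby replicas while the group is still rebalancing:
final ReadOnlyKeyValueStore<String, Long> staleOk =
    IntegrationTestUtils.getStore("counts-store", streams, true, QueryableStoreTypes.<String, Long>keyValueStore());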
+ * Currently assumes only one store in the topology; you will need to update this + * if it's important to track across multiple stores in a topology + */ + public static class TrackingStateRestoreListener implements StateRestoreListener { + public final Map changelogToStartOffset = new ConcurrentHashMap<>(); + public final Map changelogToEndOffset = new ConcurrentHashMap<>(); + public final Map changelogToTotalNumRestored = new ConcurrentHashMap<>(); + + @Override + public void onRestoreStart(final TopicPartition topicPartition, + final String storeName, + final long startingOffset, + final long endingOffset) { + changelogToStartOffset.put(topicPartition, new AtomicLong(startingOffset)); + changelogToEndOffset.put(topicPartition, new AtomicLong(endingOffset)); + changelogToTotalNumRestored.put(topicPartition, new AtomicLong(0L)); + } + + @Override + public void onBatchRestored(final TopicPartition topicPartition, + final String storeName, + final long batchEndOffset, + final long numRestored) { + changelogToTotalNumRestored.get(topicPartition).addAndGet(numRestored); + } + + @Override + public void onRestoreEnd(final TopicPartition topicPartition, + final String storeName, + final long totalRestored) { + } + + public long totalNumRestored() { + long totalNumRestored = 0L; + for (final AtomicLong numRestored : changelogToTotalNumRestored.values()) { + totalNumRestored += numRestored.get(); + } + return totalNumRestored; + } + } + +} diff --git a/streams/src/test/java/org/apache/kafka/streams/integration/utils/KafkaEmbedded.java b/streams/src/test/java/org/apache/kafka/streams/integration/utils/KafkaEmbedded.java new file mode 100644 index 0000000..165aaed --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/integration/utils/KafkaEmbedded.java @@ -0,0 +1,213 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
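// A minimal sketch of using the TrackingStateRestoreListener defined above (illustration only;
// `streams` and `expectedRestored` are hypothetical).
final TrackingStateRestoreListener restoreListener = new TrackingStateRestoreListener();
streams.setGlobalStateRestoreListener(restoreListener);   // must be registered before start()
streams.start();
// ... wait until the application has finished restoring its state stores ...
assertThat(restoreListener.totalNumRestored(), is(expectedRestored));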
+ */ +package org.apache.kafka.streams.integration.utils; + +import kafka.cluster.EndPoint; +import kafka.server.KafkaConfig; +import kafka.server.KafkaServer; +import kafka.utils.MockTime; +import kafka.utils.TestUtils; +import org.apache.kafka.clients.CommonClientConfigs; +import org.apache.kafka.clients.admin.Admin; +import org.apache.kafka.clients.admin.AdminClientConfig; +import org.apache.kafka.clients.admin.NewTopic; +import org.apache.kafka.common.config.SslConfigs; +import org.apache.kafka.common.config.types.Password; +import org.apache.kafka.common.errors.UnknownTopicOrPartitionException; +import org.apache.kafka.common.utils.Utils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.File; +import java.io.IOException; +import java.util.Collections; +import java.util.Map; +import java.util.Properties; +import java.util.concurrent.ExecutionException; + +/** + * Runs an in-memory, "embedded" instance of a Kafka broker, which listens at `127.0.0.1:9092` by + * default. + *
                + * Requires a running ZooKeeper instance to connect to. + */ +public class KafkaEmbedded { + + private static final Logger log = LoggerFactory.getLogger(KafkaEmbedded.class); + + private static final String DEFAULT_ZK_CONNECT = "127.0.0.1:2181"; + + private final Properties effectiveConfig; + private final File logDir; + private final File tmpFolder; + private final KafkaServer kafka; + + /** + * Creates and starts an embedded Kafka broker. + * + * @param config Broker configuration settings. Used to modify, for example, on which port the + * broker should listen to. Note that you cannot change the `log.dirs` setting + * currently. + */ + @SuppressWarnings("WeakerAccess") + public KafkaEmbedded(final Properties config, final MockTime time) throws IOException { + tmpFolder = org.apache.kafka.test.TestUtils.tempDirectory(); + logDir = org.apache.kafka.test.TestUtils.tempDirectory(tmpFolder.toPath(), "log"); + effectiveConfig = effectiveConfigFrom(config); + final boolean loggingEnabled = true; + final KafkaConfig kafkaConfig = new KafkaConfig(effectiveConfig, loggingEnabled); + log.debug("Starting embedded Kafka broker (with log.dirs={} and ZK ensemble at {}) ...", + logDir, zookeeperConnect()); + kafka = TestUtils.createServer(kafkaConfig, time); + log.debug("Startup of embedded Kafka broker at {} completed (with ZK ensemble at {}) ...", + brokerList(), zookeeperConnect()); + } + + /** + * Creates the configuration for starting the Kafka broker by merging default values with + * overwrites. + * + * @param initialConfig Broker configuration settings that override the default config. + */ + private Properties effectiveConfigFrom(final Properties initialConfig) { + final Properties effectiveConfig = new Properties(); + effectiveConfig.put(KafkaConfig.BrokerIdProp(), 0); + effectiveConfig.put(KafkaConfig.NumPartitionsProp(), 1); + effectiveConfig.put(KafkaConfig.AutoCreateTopicsEnableProp(), true); + effectiveConfig.put(KafkaConfig.MessageMaxBytesProp(), 1000000); + effectiveConfig.put(KafkaConfig.ControlledShutdownEnableProp(), true); + effectiveConfig.put(KafkaConfig.ZkSessionTimeoutMsProp(), 10000); + + effectiveConfig.putAll(initialConfig); + effectiveConfig.setProperty(KafkaConfig.LogDirProp(), logDir.getAbsolutePath()); + return effectiveConfig; + } + + /** + * This broker's `metadata.broker.list` value. Example: `localhost:9092`. + *
                + * You can use this to tell Kafka producers and consumers how to connect to this instance. + */ + @SuppressWarnings("WeakerAccess") + public String brokerList() { + final EndPoint endPoint = kafka.advertisedListeners().head(); + return endPoint.host() + ":" + endPoint.port(); + } + + + /** + * The ZooKeeper connection string aka `zookeeper.connect`. + */ + @SuppressWarnings("WeakerAccess") + public String zookeeperConnect() { + return effectiveConfig.getProperty("zookeeper.connect", DEFAULT_ZK_CONNECT); + } + + @SuppressWarnings("WeakerAccess") + public void stopAsync() { + log.debug("Shutting down embedded Kafka broker at {} (with ZK ensemble at {}) ...", + brokerList(), zookeeperConnect()); + kafka.shutdown(); + } + + @SuppressWarnings("WeakerAccess") + public void awaitStoppedAndPurge() { + kafka.awaitShutdown(); + log.debug("Removing log dir at {} ...", logDir); + try { + Utils.delete(tmpFolder); + } catch (final IOException e) { + throw new RuntimeException(e); + } + log.debug("Shutdown of embedded Kafka broker at {} completed (with ZK ensemble at {}) ...", + brokerList(), zookeeperConnect()); + } + + /** + * Create a Kafka topic with 1 partition and a replication factor of 1. + * + * @param topic The name of the topic. + */ + public void createTopic(final String topic) { + createTopic(topic, 1, 1, Collections.emptyMap()); + } + + /** + * Create a Kafka topic with the given parameters. + * + * @param topic The name of the topic. + * @param partitions The number of partitions for this topic. + * @param replication The replication factor for (the partitions of) this topic. + */ + public void createTopic(final String topic, final int partitions, final int replication) { + createTopic(topic, partitions, replication, Collections.emptyMap()); + } + + /** + * Create a Kafka topic with the given parameters. + * + * @param topic The name of the topic. + * @param partitions The number of partitions for this topic. + * @param replication The replication factor for (partitions of) this topic. + * @param topicConfig Additional topic-level configuration settings. 
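// A minimal lifecycle sketch for the embedded broker above (illustration only; the listener,
// ZooKeeper address, topic name, and partition counts are hypothetical).
final Properties brokerConfig = new Properties();
brokerConfig.put("zookeeper.connect", "127.0.0.1:2181");     // a running ZooKeeper is required
brokerConfig.put("listeners", "PLAINTEXT://127.0.0.1:0");    // port 0 = pick a free port
final KafkaEmbedded broker = new KafkaEmbedded(brokerConfig, new MockTime());
broker.createTopic("input", 3, 1);                           // 3 partitions, replication factor 1
// Point producers/consumers at broker.brokerList() as their bootstrap.servers, then:
broker.stopAsync();
broker.awaitStoppedAndPurge();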
+ */ + public void createTopic(final String topic, + final int partitions, + final int replication, + final Map topicConfig) { + log.debug("Creating topic { name: {}, partitions: {}, replication: {}, config: {} }", + topic, partitions, replication, topicConfig); + final NewTopic newTopic = new NewTopic(topic, partitions, (short) replication); + newTopic.configs(topicConfig); + + try (final Admin adminClient = createAdminClient()) { + adminClient.createTopics(Collections.singletonList(newTopic)).all().get(); + } catch (final InterruptedException | ExecutionException e) { + throw new RuntimeException(e); + } + } + + @SuppressWarnings("WeakerAccess") + public Admin createAdminClient() { + final Properties adminClientConfig = new Properties(); + adminClientConfig.put(AdminClientConfig.BOOTSTRAP_SERVERS_CONFIG, brokerList()); + final Object listeners = effectiveConfig.get(KafkaConfig.ListenersProp()); + if (listeners != null && listeners.toString().contains("SSL")) { + adminClientConfig.put(SslConfigs.SSL_TRUSTSTORE_LOCATION_CONFIG, effectiveConfig.get(SslConfigs.SSL_TRUSTSTORE_LOCATION_CONFIG)); + adminClientConfig.put(SslConfigs.SSL_TRUSTSTORE_PASSWORD_CONFIG, ((Password) effectiveConfig.get(SslConfigs.SSL_TRUSTSTORE_PASSWORD_CONFIG)).value()); + adminClientConfig.put(CommonClientConfigs.SECURITY_PROTOCOL_CONFIG, "SSL"); + } + return Admin.create(adminClientConfig); + } + + @SuppressWarnings("WeakerAccess") + public void deleteTopic(final String topic) { + log.debug("Deleting topic { name: {} }", topic); + try (final Admin adminClient = createAdminClient()) { + adminClient.deleteTopics(Collections.singletonList(topic)).all().get(); + } catch (final InterruptedException | ExecutionException e) { + if (!(e.getCause() instanceof UnknownTopicOrPartitionException)) { + throw new RuntimeException(e); + } + } + } + + @SuppressWarnings("WeakerAccess") + public KafkaServer kafkaServer() { + return kafka; + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/internals/ApiUtilsTest.java b/streams/src/test/java/org/apache/kafka/streams/internals/ApiUtilsTest.java new file mode 100644 index 0000000..6e4cdef --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/internals/ApiUtilsTest.java @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
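// ApiUtilsTest (below) exercises org.apache.kafka.streams.internals.ApiUtils. A minimal sketch
// of the calling pattern under test: converting a caller-supplied Duration to milliseconds with
// a descriptive failure message (the method and parameter names here are hypothetical).
public void setWindowSize(final Duration windowSize) {
    final String msgPrefix = ApiUtils.prepareMillisCheckFailMsgPrefix(windowSize, "windowSize");
    // Throws IllegalArgumentException (containing msgPrefix) if windowSize is null or overflows a long.
    final long windowSizeMs = ApiUtils.validateMillisecondDuration(windowSize, msgPrefix);
    // ... use windowSizeMs ...
}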
+ */ +package org.apache.kafka.streams.internals; + +import org.junit.Test; + +import java.time.Duration; +import java.time.Instant; + +import static org.apache.kafka.streams.internals.ApiUtils.prepareMillisCheckFailMsgPrefix; +import static org.apache.kafka.streams.internals.ApiUtils.validateMillisecondDuration; +import static org.apache.kafka.streams.internals.ApiUtils.validateMillisecondInstant; +import static org.hamcrest.CoreMatchers.containsString; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.fail; + + +public class ApiUtilsTest { + + // This is the maximum limit that Duration accepts but fails when it converts to milliseconds. + private static final long MAX_ACCEPTABLE_DAYS_FOR_DURATION = 106751991167300L; + // This is the maximum limit that Duration accepts and converts to milliseconds with out fail. + private static final long MAX_ACCEPTABLE_DAYS_FOR_DURATION_TO_MILLIS = 106751991167L; + + @Test + public void shouldThrowNullPointerExceptionForNullDuration() { + final String nullDurationPrefix = prepareMillisCheckFailMsgPrefix(null, "nullDuration"); + + try { + validateMillisecondDuration(null, nullDurationPrefix); + fail("Expected exception when null passed to duration."); + } catch (final IllegalArgumentException e) { + assertThat(e.getMessage(), containsString(nullDurationPrefix)); + } + } + + @Test + public void shouldThrowArithmeticExceptionForMaxDuration() { + final Duration maxDurationInDays = Duration.ofDays(MAX_ACCEPTABLE_DAYS_FOR_DURATION); + final String maxDurationPrefix = prepareMillisCheckFailMsgPrefix(maxDurationInDays, "maxDuration"); + + try { + validateMillisecondDuration(maxDurationInDays, maxDurationPrefix); + fail("Expected exception when maximum days passed for duration, because of long overflow"); + } catch (final IllegalArgumentException e) { + assertThat(e.getMessage(), containsString(maxDurationPrefix)); + } + } + + @Test + public void shouldThrowNullPointerExceptionForNullInstant() { + final String nullInstantPrefix = prepareMillisCheckFailMsgPrefix(null, "nullInstant"); + + try { + validateMillisecondInstant(null, nullInstantPrefix); + fail("Expected exception when null value passed for instant."); + } catch (final IllegalArgumentException e) { + assertThat(e.getMessage(), containsString(nullInstantPrefix)); + } + } + + @Test + public void shouldThrowArithmeticExceptionForMaxInstant() { + final String maxInstantPrefix = prepareMillisCheckFailMsgPrefix(Instant.MAX, "maxInstant"); + + try { + validateMillisecondInstant(Instant.MAX, maxInstantPrefix); + fail("Expected exception when maximum value passed for instant, because of long overflow."); + } catch (final IllegalArgumentException e) { + assertThat(e.getMessage(), containsString(maxInstantPrefix)); + } + } + + @Test + public void shouldReturnMillisecondsOnValidDuration() { + final Duration sampleDuration = Duration.ofDays(MAX_ACCEPTABLE_DAYS_FOR_DURATION_TO_MILLIS); + + assertEquals(sampleDuration.toMillis(), validateMillisecondDuration(sampleDuration, "sampleDuration")); + } + + @Test + public void shouldReturnMillisecondsOnValidInstant() { + final Instant sampleInstant = Instant.now(); + + assertEquals(sampleInstant.toEpochMilli(), validateMillisecondInstant(sampleInstant, "sampleInstant")); + } + + @Test + public void shouldContainsNameAndValueInFailMsgPrefix() { + final String failMsgPrefix = prepareMillisCheckFailMsgPrefix("someValue", "variableName"); + + assertThat(failMsgPrefix, 
containsString("variableName")); + assertThat(failMsgPrefix, containsString("someValue")); + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/internals/metrics/ClientMetricsTest.java b/streams/src/test/java/org/apache/kafka/streams/internals/metrics/ClientMetricsTest.java new file mode 100644 index 0000000..daaa737 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/internals/metrics/ClientMetricsTest.java @@ -0,0 +1,161 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.internals.metrics; + +import org.apache.kafka.common.metrics.Gauge; +import org.apache.kafka.common.metrics.Sensor; +import org.apache.kafka.common.metrics.Sensor.RecordingLevel; +import org.apache.kafka.streams.KafkaStreams.State; +import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl; +import org.junit.Test; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.eq; + +import java.util.Collections; +import java.util.Map; + +import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.CLIENT_LEVEL_GROUP; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.MatcherAssert.assertThat; + +public class ClientMetricsTest { + private static final String COMMIT_ID = "test-commit-ID"; + private static final String VERSION = "test-version"; + + private final StreamsMetricsImpl streamsMetrics = mock(StreamsMetricsImpl.class); + private final Sensor expectedSensor = mock(Sensor.class); + private final Map tagMap = Collections.singletonMap("hello", "world"); + + + @Test + public void shouldAddVersionMetric() { + final String name = "version"; + final String description = "The version of the Kafka Streams client"; + setUpAndVerifyImmutableMetric(name, description, VERSION, () -> ClientMetrics.addVersionMetric(streamsMetrics)); + } + + @Test + public void shouldAddCommitIdMetric() { + final String name = "commit-id"; + final String description = "The version control commit ID of the Kafka Streams client"; + setUpAndVerifyImmutableMetric(name, description, COMMIT_ID, () -> ClientMetrics.addCommitIdMetric(streamsMetrics)); + } + + @Test + public void shouldAddApplicationIdMetric() { + final String name = "application-id"; + final String description = "The application ID of the Kafka Streams client"; + final String applicationId = "thisIsAnID"; + setUpAndVerifyImmutableMetric( + name, + description, + applicationId, + () -> ClientMetrics.addApplicationIdMetric(streamsMetrics, applicationId) + ); + } + + @Test + public void shouldAddTopologyDescriptionMetric() { + final String name = "topology-description"; + final String description = "The description of 
the topology executed in the Kafka Streams client"; + final String topologyDescription = "thisIsATopologyDescription"; + final Gauge topologyDescriptionProvider = (c, n) -> topologyDescription; + setUpAndVerifyMutableMetric( + name, + description, + topologyDescriptionProvider, + () -> ClientMetrics.addTopologyDescriptionMetric(streamsMetrics, topologyDescriptionProvider) + ); + } + + @Test + public void shouldAddStateMetric() { + final String name = "state"; + final String description = "The state of the Kafka Streams client"; + final Gauge stateProvider = (config, now) -> State.RUNNING; + setUpAndVerifyMutableMetric( + name, + description, + stateProvider, + () -> ClientMetrics.addStateMetric(streamsMetrics, stateProvider) + ); + } + + @Test + public void shouldAddAliveStreamThreadsMetric() { + final String name = "alive-stream-threads"; + final String description = "The current number of alive stream threads that are running or participating in rebalance"; + final Gauge valueProvider = (config, now) -> 1; + setUpAndVerifyMutableMetric( + name, + description, + valueProvider, + () -> ClientMetrics.addNumAliveStreamThreadMetric(streamsMetrics, valueProvider) + ); + } + + @Test + public void shouldGetFailedStreamThreadsSensor() { + final String name = "failed-stream-threads"; + final String description = "The number of failed stream threads since the start of the Kafka Streams client"; + when(streamsMetrics.clientLevelSensor(name, RecordingLevel.INFO)).thenReturn(expectedSensor); + when(streamsMetrics.clientLevelTagMap()).thenReturn(tagMap); + StreamsMetricsImpl.addSumMetricToSensor( + expectedSensor, + CLIENT_LEVEL_GROUP, + tagMap, + name, + false, + description + ); + + final Sensor sensor = ClientMetrics.failedStreamThreadSensor(streamsMetrics); + assertThat(sensor, is(expectedSensor)); + } + + private void setUpAndVerifyMutableMetric(final String name, + final String description, + final Gauge valueProvider, + final Runnable metricAdder) { + + metricAdder.run(); + + verify(streamsMetrics).addClientLevelMutableMetric( + eq(name), + eq(description), + eq(RecordingLevel.INFO), + eq(valueProvider) + ); + } + + private void setUpAndVerifyImmutableMetric(final String name, + final String description, + final String value, + final Runnable metricAdder) { + + metricAdder.run(); + + verify(streamsMetrics).addClientLevelImmutableMetric( + eq(name), + eq(description), + eq(RecordingLevel.INFO), + eq(value) + ); + } +} \ No newline at end of file diff --git a/streams/src/test/java/org/apache/kafka/streams/kstream/JoinWindowsTest.java b/streams/src/test/java/org/apache/kafka/streams/kstream/JoinWindowsTest.java new file mode 100644 index 0000000..bd604ad --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/kstream/JoinWindowsTest.java @@ -0,0 +1,219 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
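// JoinWindowsTest (below) checks the window-bounds semantics sketched here:
// ofTimeDifferenceAndGrace(d, g) yields the symmetric interval [-d; +d] with grace period g,
// and before()/after() each move one bound without touching the grace period. Values are hypothetical.
final JoinWindows windows = JoinWindows
    .ofTimeDifferenceAndGrace(Duration.ofMillis(100), Duration.ofMillis(500))  // [ -100 ; +100 ], grace 500
    .before(Duration.ofMillis(50))                                             // [ -50  ; +100 ], grace 500
    .after(Duration.ofMillis(0));                                              // [ -50  ;  0   ], grace 500
assertThat(windows.gracePeriodMs(), equalTo(500L));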
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.kstream; + +import org.junit.Test; + +import java.time.Duration; + +import static java.time.Duration.ofMillis; +import static java.time.Duration.ofSeconds; +import static org.apache.kafka.streams.EqualityCheck.verifyEquality; +import static org.apache.kafka.streams.EqualityCheck.verifyInEquality; +import static org.apache.kafka.streams.kstream.Windows.DEPRECATED_DEFAULT_24_HR_GRACE_PERIOD; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.equalTo; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.fail; + +public class JoinWindowsTest { + + private static final long ANY_SIZE = 123L; + private static final long ANY_OTHER_SIZE = 456L; // should be larger than anySize + private static final long ANY_GRACE = 1024L; + + @Test + public void validWindows() { + JoinWindows.ofTimeDifferenceWithNoGrace(ofMillis(ANY_OTHER_SIZE)) // [ -anyOtherSize ; anyOtherSize ] + .before(ofMillis(ANY_SIZE)) // [ -anySize ; anyOtherSize ] + .before(ofMillis(0)) // [ 0 ; anyOtherSize ] + .before(ofMillis(-ANY_SIZE)) // [ anySize ; anyOtherSize ] + .before(ofMillis(-ANY_OTHER_SIZE)); // [ anyOtherSize ; anyOtherSize ] + + JoinWindows.ofTimeDifferenceWithNoGrace(ofMillis(ANY_OTHER_SIZE)) // [ -anyOtherSize ; anyOtherSize ] + .after(ofMillis(ANY_SIZE)) // [ -anyOtherSize ; anySize ] + .after(ofMillis(0)) // [ -anyOtherSize ; 0 ] + .after(ofMillis(-ANY_SIZE)) // [ -anyOtherSize ; -anySize ] + .after(ofMillis(-ANY_OTHER_SIZE)); // [ -anyOtherSize ; -anyOtherSize ] + } + + @Test + public void beforeShouldNotModifyGrace() { + final JoinWindows joinWindows = JoinWindows.ofTimeDifferenceAndGrace(ofMillis(ANY_SIZE), ofMillis(ANY_OTHER_SIZE)) + .before(ofSeconds(ANY_SIZE)); + + assertThat(joinWindows.gracePeriodMs(), equalTo(ANY_OTHER_SIZE)); + } + + @Test + public void afterShouldNotModifyGrace() { + final JoinWindows joinWindows = JoinWindows.ofTimeDifferenceAndGrace(ofMillis(ANY_SIZE), ofMillis(ANY_OTHER_SIZE)) + .after(ofSeconds(ANY_SIZE)); + + assertThat(joinWindows.gracePeriodMs(), equalTo(ANY_OTHER_SIZE)); + } + + @Test + public void timeDifferenceMustNotBeNegative() { + assertThrows(IllegalArgumentException.class, () -> JoinWindows.ofTimeDifferenceWithNoGrace(ofMillis(-1))); + assertThrows(IllegalArgumentException.class, () -> JoinWindows.ofTimeDifferenceAndGrace(ofMillis(-1), ofMillis(ANY_GRACE))); + } + + @SuppressWarnings("deprecation") + @Test + public void graceShouldNotCalledAfterGraceSet() { + assertThrows(IllegalStateException.class, () -> JoinWindows.ofTimeDifferenceAndGrace(ofMillis(10), ofMillis(10)).grace(ofMillis(10))); + assertThrows(IllegalStateException.class, () -> JoinWindows.ofTimeDifferenceWithNoGrace(ofMillis(10)).grace(ofMillis(10))); + } + + @Test + public void endTimeShouldNotBeBeforeStart() { + final JoinWindows windowSpec = JoinWindows.ofTimeDifferenceWithNoGrace(ofMillis(ANY_SIZE)); + try { + windowSpec.after(ofMillis(-ANY_SIZE - 1)); + fail("window end time should not be before window start time"); + } catch (final IllegalArgumentException e) { + // expected + } + } + + @Test + public void startTimeShouldNotBeAfterEnd() { + final JoinWindows windowSpec = JoinWindows.ofTimeDifferenceWithNoGrace(ofMillis(ANY_SIZE)); + try { + windowSpec.before(ofMillis(-ANY_SIZE - 1)); + fail("window start time should not be after window end time"); + } 
catch (final IllegalArgumentException e) { + // expected + } + } + + @SuppressWarnings("deprecation") + @Test + public void untilShouldSetGraceDuration() { + final JoinWindows windowSpec = JoinWindows.of(ofMillis(ANY_SIZE)); + final long windowSize = windowSpec.size(); + assertEquals(windowSize, windowSpec.grace(ofMillis(windowSize)).gracePeriodMs()); + } + + @Test + public void gracePeriodShouldEnforceBoundaries() { + JoinWindows.ofTimeDifferenceAndGrace(ofMillis(3L), ofMillis(0L)); + + try { + JoinWindows.ofTimeDifferenceAndGrace(ofMillis(3L), ofMillis(-1L)); + fail("should not accept negatives"); + } catch (final IllegalArgumentException e) { + //expected + } + } + + @SuppressWarnings("deprecation") + @Test + public void oldAPIShouldSetDefaultGracePeriod() { + assertEquals(Duration.ofDays(1).toMillis(), DEPRECATED_DEFAULT_24_HR_GRACE_PERIOD); + assertEquals(DEPRECATED_DEFAULT_24_HR_GRACE_PERIOD - 6L, JoinWindows.of(ofMillis(3L)).gracePeriodMs()); + assertEquals(0L, JoinWindows.of(ofMillis(DEPRECATED_DEFAULT_24_HR_GRACE_PERIOD)).gracePeriodMs()); + assertEquals(0L, JoinWindows.of(ofMillis(DEPRECATED_DEFAULT_24_HR_GRACE_PERIOD + 1L)).gracePeriodMs()); + } + + @Test + public void noGraceAPIShouldNotSetGracePeriod() { + assertEquals(0L, JoinWindows.ofTimeDifferenceWithNoGrace(ofMillis(3L)).gracePeriodMs()); + assertEquals(0L, JoinWindows.ofTimeDifferenceWithNoGrace(ofMillis(ANY_SIZE)).gracePeriodMs()); + assertEquals(0L, JoinWindows.ofTimeDifferenceWithNoGrace(ofMillis(ANY_OTHER_SIZE)).gracePeriodMs()); + } + + @Test + public void withGraceAPIShouldSetGracePeriod() { + assertEquals(ANY_GRACE, JoinWindows.ofTimeDifferenceAndGrace(ofMillis(3L), ofMillis(ANY_GRACE)).gracePeriodMs()); + assertEquals(ANY_GRACE, JoinWindows.ofTimeDifferenceAndGrace(ofMillis(ANY_SIZE), ofMillis(ANY_GRACE)).gracePeriodMs()); + assertEquals(ANY_GRACE, JoinWindows.ofTimeDifferenceAndGrace(ofMillis(ANY_OTHER_SIZE), ofMillis(ANY_GRACE)).gracePeriodMs()); + } + + @Test + public void equalsAndHashcodeShouldBeValidForPositiveCases() { + verifyEquality( + JoinWindows.ofTimeDifferenceWithNoGrace(ofMillis(3)), + JoinWindows.ofTimeDifferenceWithNoGrace(ofMillis(3)) + ); + + verifyEquality( + JoinWindows.ofTimeDifferenceAndGrace(ofMillis(3), ofMillis(2)), + JoinWindows.ofTimeDifferenceAndGrace(ofMillis(3), ofMillis(2)) + ); + + verifyEquality( + JoinWindows.ofTimeDifferenceWithNoGrace(ofMillis(3)).after(ofMillis(2)), + JoinWindows.ofTimeDifferenceWithNoGrace(ofMillis(3)).after(ofMillis(2)) + ); + + verifyEquality( + JoinWindows.ofTimeDifferenceWithNoGrace(ofMillis(3)).before(ofMillis(2)), + JoinWindows.ofTimeDifferenceWithNoGrace(ofMillis(3)).before(ofMillis(2)) + ); + + verifyEquality( + JoinWindows.ofTimeDifferenceAndGrace(ofMillis(3), ofMillis(2)).after(ofMillis(4)), + JoinWindows.ofTimeDifferenceAndGrace(ofMillis(3), ofMillis(2)).after(ofMillis(4)) + ); + + verifyEquality( + JoinWindows.ofTimeDifferenceAndGrace(ofMillis(3), ofMillis(2)).before(ofMillis(4)), + JoinWindows.ofTimeDifferenceAndGrace(ofMillis(3), ofMillis(2)).before(ofMillis(4)) + ); + } + + @Test + public void equalsAndHashcodeShouldBeValidForNegativeCases() { + verifyInEquality( + JoinWindows.ofTimeDifferenceWithNoGrace(ofMillis(9)), + JoinWindows.ofTimeDifferenceWithNoGrace(ofMillis(3)) + ); + + verifyInEquality( + JoinWindows.ofTimeDifferenceAndGrace(ofMillis(3), ofMillis(9)), + JoinWindows.ofTimeDifferenceAndGrace(ofMillis(3), ofMillis(2)) + ); + + verifyInEquality( + JoinWindows.ofTimeDifferenceWithNoGrace(ofMillis(3)).after(ofMillis(9)), + 
JoinWindows.ofTimeDifferenceWithNoGrace(ofMillis(3)).after(ofMillis(2)) + ); + + verifyInEquality( + JoinWindows.ofTimeDifferenceWithNoGrace(ofMillis(3)).before(ofMillis(9)), + JoinWindows.ofTimeDifferenceWithNoGrace(ofMillis(3)).before(ofMillis(2)) + ); + + verifyInEquality( + JoinWindows.ofTimeDifferenceAndGrace(ofMillis(3), ofMillis(3)).before(ofMillis(9)).after(ofMillis(2)), + JoinWindows.ofTimeDifferenceAndGrace(ofMillis(3), ofMillis(3)).before(ofMillis(1)).after(ofMillis(2)) + ); + + verifyInEquality( + JoinWindows.ofTimeDifferenceAndGrace(ofMillis(3), ofMillis(3)).before(ofMillis(1)).after(ofMillis(9)), + JoinWindows.ofTimeDifferenceAndGrace(ofMillis(3), ofMillis(3)).before(ofMillis(1)).after(ofMillis(2)) + ); + + verifyInEquality( + JoinWindows.ofTimeDifferenceAndGrace(ofMillis(3), ofMillis(9)).before(ofMillis(1)).after(ofMillis(2)), + JoinWindows.ofTimeDifferenceAndGrace(ofMillis(3), ofMillis(3)).before(ofMillis(1)).after(ofMillis(2)) + ); + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/kstream/MaterializedTest.java b/streams/src/test/java/org/apache/kafka/streams/kstream/MaterializedTest.java new file mode 100644 index 0000000..2630f35 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/kstream/MaterializedTest.java @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
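// MaterializedTest (below) enforces the store-naming rules that apply to user code like this
// sketch (store name, serdes, and retention are hypothetical): the name may contain only
// ASCII alphanumerics, '.', '_' and '-', must be at most 249 characters, and the retention
// must not be negative.
final Materialized<String, Long, KeyValueStore<Bytes, byte[]>> materialized =
    Materialized.<String, Long, KeyValueStore<Bytes, byte[]>>as("counts-store")
        .withKeySerde(Serdes.String())
        .withValueSerde(Serdes.Long())
        .withRetention(Duration.ofDays(1));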
+ */ + +package org.apache.kafka.streams.kstream; + +import org.apache.kafka.streams.errors.TopologyException; +import org.apache.kafka.streams.state.KeyValueBytesStoreSupplier; +import org.apache.kafka.streams.state.SessionBytesStoreSupplier; +import org.apache.kafka.streams.state.WindowBytesStoreSupplier; +import org.junit.Test; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; + +import java.time.Duration; +import java.time.temporal.ChronoUnit; + +public class MaterializedTest { + + @Test + public void shouldAllowValidTopicNamesAsStoreName() { + Materialized.as("valid-name"); + Materialized.as("valid.name"); + Materialized.as("valid_name"); + } + + @Test + public void shouldNotAllowInvalidTopicNames() { + final String invalidName = "not:valid"; + final TopologyException e = assertThrows(TopologyException.class, + () -> Materialized.as(invalidName)); + + assertEquals(e.getMessage(), "Invalid topology: Name \"" + invalidName + + "\" is illegal, it contains a character other than " + "ASCII alphanumerics, '.', '_' and '-'"); + } + + @Test + public void shouldThrowNullPointerIfWindowBytesStoreSupplierIsNull() { + final NullPointerException e = assertThrows(NullPointerException.class, + () -> Materialized.as((WindowBytesStoreSupplier) null)); + + assertEquals(e.getMessage(), "supplier can't be null"); + } + + @Test + public void shouldThrowNullPointerIfKeyValueBytesStoreSupplierIsNull() { + final NullPointerException e = assertThrows(NullPointerException.class, + () -> Materialized.as((KeyValueBytesStoreSupplier) null)); + + assertEquals(e.getMessage(), "supplier can't be null"); + } + + @Test + public void shouldThrowNullPointerIfSessionBytesStoreSupplierIsNull() { + final NullPointerException e = assertThrows(NullPointerException.class, + () -> Materialized.as((SessionBytesStoreSupplier) null)); + + assertEquals(e.getMessage(), "supplier can't be null"); + } + + @Test + public void shouldThrowIllegalArgumentExceptionIfRetentionIsNegative() { + final IllegalArgumentException e = assertThrows(IllegalArgumentException.class, + () -> Materialized.as("valid-name").withRetention(Duration.of(-1, ChronoUnit.DAYS))); + + assertEquals(e.getMessage(), "Retention must not be negative."); + } + + @Test + public void shouldThrowTopologyExceptionIfStoreNameExceedsMaxAllowedLength() { + final StringBuffer invalidStoreNameBuffer = new StringBuffer(); + final int maxNameLength = 249; + + for (int i = 0; i < maxNameLength + 1; i++) { + invalidStoreNameBuffer.append('a'); + } + + final String invalidStoreName = invalidStoreNameBuffer.toString(); + + final TopologyException e = assertThrows(TopologyException.class, + () -> Materialized.as(invalidStoreName)); + assertEquals(e.getMessage(), "Invalid topology: Name is illegal, it can't be longer than " + maxNameLength + + " characters, name: " + invalidStoreName); + } +} \ No newline at end of file diff --git a/streams/src/test/java/org/apache/kafka/streams/kstream/NamedTest.java b/streams/src/test/java/org/apache/kafka/streams/kstream/NamedTest.java new file mode 100644 index 0000000..6bac246 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/kstream/NamedTest.java @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.kstream; + +import org.apache.kafka.streams.errors.TopologyException; +import org.junit.Test; + +import java.util.Arrays; + +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.fail; + +public class NamedTest { + + @Test + public void shouldThrowExceptionGivenNullName() { + assertThrows(NullPointerException.class, () -> Named.as(null)); + } + + @Test + public void shouldThrowExceptionOnInvalidTopicNames() { + final char[] longString = new char[250]; + Arrays.fill(longString, 'a'); + final String[] invalidNames = {"", "foo bar", "..", "foo:bar", "foo=bar", ".", new String(longString)}; + + for (final String name : invalidNames) { + try { + Named.validate(name); + fail("No exception was thrown for named with invalid name: " + name); + } catch (final TopologyException e) { + // success + } + } + } +} \ No newline at end of file diff --git a/streams/src/test/java/org/apache/kafka/streams/kstream/PrintedTest.java b/streams/src/test/java/org/apache/kafka/streams/kstream/PrintedTest.java new file mode 100644 index 0000000..212074f --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/kstream/PrintedTest.java @@ -0,0 +1,131 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
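// PrintedTest (below) covers the Printed API used from KStream#print. A minimal sketch of the
// user-facing call (topic name, label, and mapper are hypothetical):
final StreamsBuilder builder = new StreamsBuilder();
builder.<String, Integer>stream("input")
    .print(Printed.<String, Integer>toSysOut()
        .withLabel("debug")
        .withKeyValueMapper((key, value) -> key + " -> " + value));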
+ */ + +package org.apache.kafka.streams.kstream; + +import org.apache.kafka.streams.errors.TopologyException; +import org.apache.kafka.streams.kstream.internals.PrintedInternal; +import org.apache.kafka.streams.processor.api.Processor; +import org.apache.kafka.streams.processor.api.ProcessorSupplier; +import org.apache.kafka.streams.processor.api.Record; +import org.apache.kafka.test.TestUtils; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import java.io.ByteArrayOutputStream; +import java.io.File; +import java.io.IOException; +import java.io.InputStream; +import java.io.PrintStream; +import java.io.UnsupportedEncodingException; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; + +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.Assert.assertThrows; + +public class PrintedTest { + + private final PrintStream originalSysOut = System.out; + private final ByteArrayOutputStream sysOut = new ByteArrayOutputStream(); + private Printed sysOutPrinter; + + @Before + public void before() { + System.setOut(new PrintStream(sysOut)); + sysOutPrinter = Printed.toSysOut(); + } + + @After + public void after() { + System.setOut(originalSysOut); + } + + @Test + public void shouldCreateProcessorThatPrintsToFile() throws IOException { + final File file = TestUtils.tempFile(); + final ProcessorSupplier processorSupplier = new PrintedInternal<>( + Printed.toFile(file.getPath())) + .build("processor"); + final Processor processor = processorSupplier.get(); + processor.process(new Record<>("hi", 1, 0L)); + processor.close(); + try (final InputStream stream = Files.newInputStream(file.toPath())) { + final byte[] data = new byte[stream.available()]; + stream.read(data); + assertThat(new String(data, StandardCharsets.UTF_8.name()), equalTo("[processor]: hi, 1\n")); + } + } + + @Test + public void shouldCreateProcessorThatPrintsToStdOut() throws UnsupportedEncodingException { + final ProcessorSupplier supplier = new PrintedInternal<>(sysOutPrinter).build("processor"); + final Processor processor = supplier.get(); + + processor.process(new Record<>("good", 2, 0L)); + processor.close(); + assertThat(sysOut.toString(StandardCharsets.UTF_8.name()), equalTo("[processor]: good, 2\n")); + } + + @Test + public void shouldPrintWithLabel() throws UnsupportedEncodingException { + final Processor processor = new PrintedInternal<>(sysOutPrinter.withLabel("label")) + .build("processor") + .get(); + + processor.process(new Record<>("hello", 3, 0L)); + processor.close(); + assertThat(sysOut.toString(StandardCharsets.UTF_8.name()), equalTo("[label]: hello, 3\n")); + } + + @Test + public void shouldPrintWithKeyValueMapper() throws UnsupportedEncodingException { + final Processor processor = new PrintedInternal<>( + sysOutPrinter.withKeyValueMapper((key, value) -> String.format("%s -> %d", key, value)) + ).build("processor").get(); + processor.process(new Record<>("hello", 1, 0L)); + processor.close(); + assertThat(sysOut.toString(StandardCharsets.UTF_8.name()), equalTo("[processor]: hello -> 1\n")); + } + + @Test + public void shouldThrowNullPointerExceptionIfFilePathIsNull() { + assertThrows(NullPointerException.class, () -> Printed.toFile(null)); + } + + @Test + public void shouldThrowNullPointerExceptionIfMapperIsNull() { + assertThrows(NullPointerException.class, () -> sysOutPrinter.withKeyValueMapper(null)); + } + + @Test + public void shouldThrowNullPointerExceptionIfLabelIsNull() { + 
assertThrows(NullPointerException.class, () -> sysOutPrinter.withLabel(null)); + } + + @Test + public void shouldThrowTopologyExceptionIfFilePathIsEmpty() { + assertThrows(TopologyException.class, () -> Printed.toFile("")); + } + + @Test + public void shouldThrowTopologyExceptionIfFilePathDoesntExist() { + assertThrows(TopologyException.class, () -> Printed.toFile("/this/should/not/exist")); + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/kstream/RepartitionTopicNamingTest.java b/streams/src/test/java/org/apache/kafka/streams/kstream/RepartitionTopicNamingTest.java new file mode 100644 index 0000000..cad978c --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/kstream/RepartitionTopicNamingTest.java @@ -0,0 +1,703 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.streams.kstream; + +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.Topology; +import org.apache.kafka.streams.errors.TopologyException; +import org.junit.Test; + +import java.time.Duration; +import java.util.ArrayList; +import java.util.List; +import java.util.Locale; +import java.util.Properties; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +@SuppressWarnings("deprecation") +public class RepartitionTopicNamingTest { + + private final KeyValueMapper kvMapper = (k, v) -> k + v; + private static final String INPUT_TOPIC = "input"; + private static final String COUNT_TOPIC = "outputTopic_0"; + private static final String AGGREGATION_TOPIC = "outputTopic_1"; + private static final String REDUCE_TOPIC = "outputTopic_2"; + private static final String JOINED_TOPIC = "outputTopicForJoin"; + + private final String firstRepartitionTopicName = "count-stream"; + private final String secondRepartitionTopicName = "aggregate-stream"; + private final String thirdRepartitionTopicName = "reduced-stream"; + private final String fourthRepartitionTopicName = "joined-stream"; + private final Pattern repartitionTopicPattern = Pattern.compile("Sink: .*-repartition"); + + + @Test + public void shouldReuseFirstRepartitionTopicNameWhenOptimizing() { + + final String optimizedTopology = buildTopology(StreamsConfig.OPTIMIZE).describe().toString(); + final String unOptimizedTopology = buildTopology(StreamsConfig.NO_OPTIMIZATION).describe().toString(); + + assertThat(optimizedTopology, is(EXPECTED_OPTIMIZED_TOPOLOGY)); + // only one repartition topic + assertThat(1, 
is(getCountOfRepartitionTopicsFound(optimizedTopology, repartitionTopicPattern))); + // the first named repartition topic + assertTrue(optimizedTopology.contains(firstRepartitionTopicName + "-repartition")); + + assertThat(unOptimizedTopology, is(EXPECTED_UNOPTIMIZED_TOPOLOGY)); + // now 4 repartition topic + assertThat(4, is(getCountOfRepartitionTopicsFound(unOptimizedTopology, repartitionTopicPattern))); + // all 4 named repartition topics present + assertTrue(unOptimizedTopology.contains(firstRepartitionTopicName + "-repartition")); + assertTrue(unOptimizedTopology.contains(secondRepartitionTopicName + "-repartition")); + assertTrue(unOptimizedTopology.contains(thirdRepartitionTopicName + "-repartition")); + assertTrue(unOptimizedTopology.contains(fourthRepartitionTopicName + "-left-repartition")); + + } + + // can't use same repartition topic name + @Test + public void shouldFailWithSameRepartitionTopicName() { + try { + final StreamsBuilder builder = new StreamsBuilder(); + builder.stream("topic").selectKey((k, v) -> k) + .groupByKey(Grouped.as("grouping")) + .count().toStream(); + + builder.stream("topicII").selectKey((k, v) -> k) + .groupByKey(Grouped.as("grouping")) + .count().toStream(); + builder.build(); + fail("Should not build re-using repartition topic name"); + } catch (final TopologyException te) { + // ok + } + } + + @Test + public void shouldNotFailWithSameRepartitionTopicNameUsingSameKGroupedStream() { + final StreamsBuilder builder = new StreamsBuilder(); + final KGroupedStream kGroupedStream = builder.stream("topic") + .selectKey((k, v) -> k) + .groupByKey(Grouped.as("grouping")); + + kGroupedStream.windowedBy(TimeWindows.of(Duration.ofMillis(10L))).count().toStream().to("output-one"); + kGroupedStream.windowedBy(TimeWindows.of(Duration.ofMillis(30L))).count().toStream().to("output-two"); + + final String topologyString = builder.build().describe().toString(); + assertThat(1, is(getCountOfRepartitionTopicsFound(topologyString, repartitionTopicPattern))); + assertTrue(topologyString.contains("grouping-repartition")); + } + + @Test + public void shouldNotFailWithSameRepartitionTopicNameUsingSameTimeWindowStream() { + final StreamsBuilder builder = new StreamsBuilder(); + final KGroupedStream kGroupedStream = builder.stream("topic") + .selectKey((k, v) -> k) + .groupByKey(Grouped.as("grouping")); + + final TimeWindowedKStream timeWindowedKStream = kGroupedStream.windowedBy(TimeWindows.of(Duration.ofMillis(10L))); + + timeWindowedKStream.count().toStream().to("output-one"); + timeWindowedKStream.reduce((v, v2) -> v + v2).toStream().to("output-two"); + kGroupedStream.windowedBy(TimeWindows.of(Duration.ofMillis(30L))).count().toStream().to("output-two"); + + final String topologyString = builder.build().describe().toString(); + assertThat(1, is(getCountOfRepartitionTopicsFound(topologyString, repartitionTopicPattern))); + assertTrue(topologyString.contains("grouping-repartition")); + } + + @Test + public void shouldNotFailWithSameRepartitionTopicNameUsingSameSessionWindowStream() { + final StreamsBuilder builder = new StreamsBuilder(); + final KGroupedStream kGroupedStream = builder.stream("topic") + .selectKey((k, v) -> k) + .groupByKey(Grouped.as("grouping")); + + final SessionWindowedKStream sessionWindowedKStream = kGroupedStream.windowedBy(SessionWindows.with(Duration.ofMillis(10L))); + + sessionWindowedKStream.count().toStream().to("output-one"); + sessionWindowedKStream.reduce((v, v2) -> v + v2).toStream().to("output-two"); + 
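// windowing the same named KGroupedStream again should still resolve to the single "grouping" repartition topic +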
kGroupedStream.windowedBy(TimeWindows.of(Duration.ofMillis(30L))).count().toStream().to("output-two"); + + final String topologyString = builder.build().describe().toString(); + assertThat(1, is(getCountOfRepartitionTopicsFound(topologyString, repartitionTopicPattern))); + assertTrue(topologyString.contains("grouping-repartition")); + } + + @Test + public void shouldNotFailWithSameRepartitionTopicNameUsingSameKGroupedTable() { + final StreamsBuilder builder = new StreamsBuilder(); + final KGroupedTable kGroupedTable = builder.table("topic") + .groupBy(KeyValue::pair, Grouped.as("grouping")); + kGroupedTable.count().toStream().to("output-count"); + kGroupedTable.reduce((v, v2) -> v2, (v, v2) -> v2).toStream().to("output-reduce"); + final String topologyString = builder.build().describe().toString(); + assertThat(1, is(getCountOfRepartitionTopicsFound(topologyString, repartitionTopicPattern))); + assertTrue(topologyString.contains("grouping-repartition")); + } + + @Test + public void shouldNotReuseRepartitionNodeWithUnamedRepartitionTopics() { + final StreamsBuilder builder = new StreamsBuilder(); + final KGroupedStream kGroupedStream = builder.stream("topic") + .selectKey((k, v) -> k) + .groupByKey(); + kGroupedStream.windowedBy(TimeWindows.of(Duration.ofMillis(10L))).count().toStream().to("output-one"); + kGroupedStream.windowedBy(TimeWindows.of(Duration.ofMillis(30L))).count().toStream().to("output-two"); + final String topologyString = builder.build().describe().toString(); + assertThat(2, is(getCountOfRepartitionTopicsFound(topologyString, repartitionTopicPattern))); + } + + @Test + public void shouldNotReuseRepartitionNodeWithUnamedRepartitionTopicsKGroupedTable() { + final StreamsBuilder builder = new StreamsBuilder(); + final KGroupedTable kGroupedTable = builder.table("topic").groupBy(KeyValue::pair); + kGroupedTable.count().toStream().to("output-count"); + kGroupedTable.reduce((v, v2) -> v2, (v, v2) -> v2).toStream().to("output-reduce"); + final String topologyString = builder.build().describe().toString(); + assertThat(2, is(getCountOfRepartitionTopicsFound(topologyString, repartitionTopicPattern))); + } + + @Test + public void shouldNotFailWithSameRepartitionTopicNameUsingSameKGroupedStreamOptimizationsOn() { + final StreamsBuilder builder = new StreamsBuilder(); + final KGroupedStream kGroupedStream = builder.stream("topic") + .selectKey((k, v) -> k) + .groupByKey(Grouped.as("grouping")); + kGroupedStream.windowedBy(TimeWindows.of(Duration.ofMillis(10L))).count(); + kGroupedStream.windowedBy(TimeWindows.of(Duration.ofMillis(30L))).count(); + final Properties properties = new Properties(); + properties.put(StreamsConfig.TOPOLOGY_OPTIMIZATION_CONFIG, StreamsConfig.OPTIMIZE); + final Topology topology = builder.build(properties); + assertThat(getCountOfRepartitionTopicsFound(topology.describe().toString(), repartitionTopicPattern), is(1)); + } + + + // can't use same repartition topic name in joins + @Test + public void shouldFailWithSameRepartitionTopicNameInJoin() { + try { + final StreamsBuilder builder = new StreamsBuilder(); + final KStream stream1 = builder.stream("topic").selectKey((k, v) -> k); + final KStream stream2 = builder.stream("topic2").selectKey((k, v) -> k); + final KStream stream3 = builder.stream("topic3").selectKey((k, v) -> k); + + final KStream joined = stream1.join(stream2, (v1, v2) -> v1 + v2, + JoinWindows.of(Duration.ofMillis(30L)), + StreamJoined.as("join-store").withName("join-repartition")); + + joined.join(stream3, (v1, v2) -> v1 + v2, 
JoinWindows.of(Duration.ofMillis(30L)), + StreamJoined.as("join-store").withName("join-repartition")); + + builder.build(); + fail("Should not build re-using repartition topic name"); + } catch (final TopologyException te) { + // ok + } + } + + @Test + public void shouldPassWithSameRepartitionTopicNameUsingSameKGroupedStreamOptimized() { + final StreamsBuilder builder = new StreamsBuilder(); + final Properties properties = new Properties(); + properties.put(StreamsConfig.TOPOLOGY_OPTIMIZATION_CONFIG, StreamsConfig.OPTIMIZE); + final KGroupedStream kGroupedStream = builder.stream("topic") + .selectKey((k, v) -> k) + .groupByKey(Grouped.as("grouping")); + kGroupedStream.windowedBy(TimeWindows.of(Duration.ofMillis(10L))).count(); + kGroupedStream.windowedBy(TimeWindows.of(Duration.ofMillis(30L))).count(); + builder.build(properties); + } + + + @Test + public void shouldKeepRepartitionTopicNameForJoins() { + + final String expectedLeftRepartitionTopic = "(topic: my-join-left-repartition)"; + final String expectedRightRepartitionTopic = "(topic: my-join-right-repartition)"; + + + final String joinTopologyFirst = buildStreamJoin(false); + + assertTrue(joinTopologyFirst.contains(expectedLeftRepartitionTopic)); + assertTrue(joinTopologyFirst.contains(expectedRightRepartitionTopic)); + + final String joinTopologyUpdated = buildStreamJoin(true); + + assertTrue(joinTopologyUpdated.contains(expectedLeftRepartitionTopic)); + assertTrue(joinTopologyUpdated.contains(expectedRightRepartitionTopic)); + } + + @Test + public void shouldKeepRepartitionTopicNameForGroupByKeyTimeWindows() { + + final String expectedTimeWindowRepartitionTopic = "(topic: time-window-grouping-repartition)"; + + final String timeWindowGroupingRepartitionTopology = buildStreamGroupByKeyTimeWindows(false, true); + assertTrue(timeWindowGroupingRepartitionTopology.contains(expectedTimeWindowRepartitionTopic)); + + final String timeWindowGroupingUpdatedTopology = buildStreamGroupByKeyTimeWindows(true, true); + assertTrue(timeWindowGroupingUpdatedTopology.contains(expectedTimeWindowRepartitionTopic)); + } + + @Test + public void shouldKeepRepartitionTopicNameForGroupByTimeWindows() { + + final String expectedTimeWindowRepartitionTopic = "(topic: time-window-grouping-repartition)"; + + final String timeWindowGroupingRepartitionTopology = buildStreamGroupByKeyTimeWindows(false, false); + assertTrue(timeWindowGroupingRepartitionTopology.contains(expectedTimeWindowRepartitionTopic)); + + final String timeWindowGroupingUpdatedTopology = buildStreamGroupByKeyTimeWindows(true, false); + assertTrue(timeWindowGroupingUpdatedTopology.contains(expectedTimeWindowRepartitionTopic)); + } + + + @Test + public void shouldKeepRepartitionTopicNameForGroupByKeyNoWindows() { + + final String expectedNoWindowRepartitionTopic = "(topic: kstream-grouping-repartition)"; + + final String noWindowGroupingRepartitionTopology = buildStreamGroupByKeyNoWindows(false, true); + assertTrue(noWindowGroupingRepartitionTopology.contains(expectedNoWindowRepartitionTopic)); + + final String noWindowGroupingUpdatedTopology = buildStreamGroupByKeyNoWindows(true, true); + assertTrue(noWindowGroupingUpdatedTopology.contains(expectedNoWindowRepartitionTopic)); + } + + @Test + public void shouldKeepRepartitionTopicNameForGroupByNoWindows() { + + final String expectedNoWindowRepartitionTopic = "(topic: kstream-grouping-repartition)"; + + final String noWindowGroupingRepartitionTopology = buildStreamGroupByKeyNoWindows(false, false); + 
assertTrue(noWindowGroupingRepartitionTopology.contains(expectedNoWindowRepartitionTopic)); + + final String noWindowGroupingUpdatedTopology = buildStreamGroupByKeyNoWindows(true, false); + assertTrue(noWindowGroupingUpdatedTopology.contains(expectedNoWindowRepartitionTopic)); + } + + + @Test + public void shouldKeepRepartitionTopicNameForGroupByKeySessionWindows() { + + final String expectedSessionWindowRepartitionTopic = "(topic: session-window-grouping-repartition)"; + + final String sessionWindowGroupingRepartitionTopology = buildStreamGroupByKeySessionWindows(false, true); + assertTrue(sessionWindowGroupingRepartitionTopology.contains(expectedSessionWindowRepartitionTopic)); + + final String sessionWindowGroupingUpdatedTopology = buildStreamGroupByKeySessionWindows(true, true); + assertTrue(sessionWindowGroupingUpdatedTopology.contains(expectedSessionWindowRepartitionTopic)); + } + + @Test + public void shouldKeepRepartitionTopicNameForGroupBySessionWindows() { + + final String expectedSessionWindowRepartitionTopic = "(topic: session-window-grouping-repartition)"; + + final String sessionWindowGroupingRepartitionTopology = buildStreamGroupByKeySessionWindows(false, false); + assertTrue(sessionWindowGroupingRepartitionTopology.contains(expectedSessionWindowRepartitionTopic)); + + final String sessionWindowGroupingUpdatedTopology = buildStreamGroupByKeySessionWindows(true, false); + assertTrue(sessionWindowGroupingUpdatedTopology.contains(expectedSessionWindowRepartitionTopic)); + } + + @Test + public void shouldKeepRepartitionNameForGroupByKTable() { + final String expectedKTableGroupByRepartitionTopic = "(topic: ktable-group-by-repartition)"; + + final String ktableGroupByTopology = buildKTableGroupBy(false); + assertTrue(ktableGroupByTopology.contains(expectedKTableGroupByRepartitionTopic)); + + final String ktableUpdatedGroupByTopology = buildKTableGroupBy(true); + assertTrue(ktableUpdatedGroupByTopology.contains(expectedKTableGroupByRepartitionTopic)); + } + + + private String buildKTableGroupBy(final boolean otherOperations) { + final String ktableGroupByTopicName = "ktable-group-by"; + final StreamsBuilder builder = new StreamsBuilder(); + + final KTable ktable = builder.table("topic"); + + if (otherOperations) { + ktable.filter((k, v) -> true).groupBy(KeyValue::pair, Grouped.as(ktableGroupByTopicName)).count(); + } else { + ktable.groupBy(KeyValue::pair, Grouped.as(ktableGroupByTopicName)).count(); + } + + return builder.build().describe().toString(); + } + + private String buildStreamGroupByKeyTimeWindows(final boolean otherOperations, final boolean isGroupByKey) { + + final String groupedTimeWindowRepartitionTopicName = "time-window-grouping"; + final StreamsBuilder builder = new StreamsBuilder(); + + final KStream selectKeyStream = builder.stream("topic").selectKey((k, v) -> k + v); + + + if (isGroupByKey) { + if (otherOperations) { + selectKeyStream.filter((k, v) -> true).mapValues(v -> v).groupByKey(Grouped.as(groupedTimeWindowRepartitionTopicName)).windowedBy(TimeWindows.of(Duration.ofMillis(10L))).count(); + } else { + selectKeyStream.groupByKey(Grouped.as(groupedTimeWindowRepartitionTopicName)).windowedBy(TimeWindows.of(Duration.ofMillis(10L))).count(); + } + } else { + if (otherOperations) { + selectKeyStream.filter((k, v) -> true).mapValues(v -> v).groupBy(kvMapper, Grouped.as(groupedTimeWindowRepartitionTopicName)).count(); + } else { + selectKeyStream.groupBy(kvMapper, Grouped.as(groupedTimeWindowRepartitionTopicName)).count(); + } + } + + return 
builder.build().describe().toString(); + } + + + private String buildStreamGroupByKeySessionWindows(final boolean otherOperations, final boolean isGroupByKey) { + + final StreamsBuilder builder = new StreamsBuilder(); + + final KStream selectKeyStream = builder.stream("topic").selectKey((k, v) -> k + v); + + final String groupedSessionWindowRepartitionTopicName = "session-window-grouping"; + if (isGroupByKey) { + if (otherOperations) { + selectKeyStream.filter((k, v) -> true).mapValues(v -> v).groupByKey(Grouped.as(groupedSessionWindowRepartitionTopicName)).windowedBy(SessionWindows.with(Duration.ofMillis(10L))).count(); + } else { + selectKeyStream.groupByKey(Grouped.as(groupedSessionWindowRepartitionTopicName)).windowedBy(SessionWindows.with(Duration.ofMillis(10L))).count(); + } + } else { + if (otherOperations) { + selectKeyStream.filter((k, v) -> true).mapValues(v -> v).groupBy(kvMapper, Grouped.as(groupedSessionWindowRepartitionTopicName)).windowedBy(SessionWindows.with(Duration.ofMillis(10L))).count(); + } else { + selectKeyStream.groupBy(kvMapper, Grouped.as(groupedSessionWindowRepartitionTopicName)).windowedBy(SessionWindows.with(Duration.ofMillis(10L))).count(); + } + } + + return builder.build().describe().toString(); + } + + + private String buildStreamGroupByKeyNoWindows(final boolean otherOperations, final boolean isGroupByKey) { + + final StreamsBuilder builder = new StreamsBuilder(); + + final KStream selectKeyStream = builder.stream("topic").selectKey((k, v) -> k + v); + + final String groupByAndCountRepartitionTopicName = "kstream-grouping"; + if (isGroupByKey) { + if (otherOperations) { + selectKeyStream.filter((k, v) -> true).mapValues(v -> v).groupByKey(Grouped.as(groupByAndCountRepartitionTopicName)).count(); + } else { + selectKeyStream.groupByKey(Grouped.as(groupByAndCountRepartitionTopicName)).count(); + } + } else { + if (otherOperations) { + selectKeyStream.filter((k, v) -> true).mapValues(v -> v).groupBy(kvMapper, Grouped.as(groupByAndCountRepartitionTopicName)).count(); + } else { + selectKeyStream.groupBy(kvMapper, Grouped.as(groupByAndCountRepartitionTopicName)).count(); + } + } + + return builder.build().describe().toString(); + } + + private String buildStreamJoin(final boolean includeOtherOperations) { + final StreamsBuilder builder = new StreamsBuilder(); + final KStream initialStreamOne = builder.stream("topic-one"); + final KStream initialStreamTwo = builder.stream("topic-two"); + + final KStream updatedStreamOne; + final KStream updatedStreamTwo; + + if (includeOtherOperations) { + // without naming the join, the repartition topic name would change due to operator changing before join performed + updatedStreamOne = initialStreamOne.selectKey((k, v) -> k + v).filter((k, v) -> true).peek((k, v) -> System.out.println(k + v)); + updatedStreamTwo = initialStreamTwo.selectKey((k, v) -> k + v).filter((k, v) -> true).peek((k, v) -> System.out.println(k + v)); + } else { + updatedStreamOne = initialStreamOne.selectKey((k, v) -> k + v); + updatedStreamTwo = initialStreamTwo.selectKey((k, v) -> k + v); + } + + final String joinRepartitionTopicName = "my-join"; + updatedStreamOne.join(updatedStreamTwo, (v1, v2) -> v1 + v2, JoinWindows.of(Duration.ofMillis(1000L)), + StreamJoined.with(Serdes.String(), Serdes.String(), Serdes.String()).withName(joinRepartitionTopicName)); + + return builder.build().describe().toString(); + } + + + private int getCountOfRepartitionTopicsFound(final String topologyString, final Pattern repartitionTopicPattern) { + final Matcher matcher 
= repartitionTopicPattern.matcher(topologyString); + final List repartitionTopicsFound = new ArrayList<>(); + while (matcher.find()) { + repartitionTopicsFound.add(matcher.group()); + } + return repartitionTopicsFound.size(); + } + + + @SuppressWarnings("deprecation") // Old PAPI. Needs to be migrated. + private Topology buildTopology(final String optimizationConfig) { + final Initializer initializer = () -> 0; + final Aggregator aggregator = (k, v, agg) -> agg + v.length(); + final Reducer reducer = (v1, v2) -> v1 + ":" + v2; + final List processorValueCollector = new ArrayList<>(); + + final StreamsBuilder builder = new StreamsBuilder(); + + final KStream sourceStream = builder.stream(INPUT_TOPIC, Consumed.with(Serdes.String(), Serdes.String())); + + final KStream mappedStream = sourceStream.map((k, v) -> KeyValue.pair(k.toUpperCase(Locale.getDefault()), v)); + + mappedStream.filter((k, v) -> k.equals("B")).mapValues(v -> v.toUpperCase(Locale.getDefault())) + .process(() -> new SimpleProcessor(processorValueCollector)); + + final KStream countStream = mappedStream.groupByKey(Grouped.as(firstRepartitionTopicName)).count(Materialized.with(Serdes.String(), Serdes.Long())).toStream(); + + countStream.to(COUNT_TOPIC, Produced.with(Serdes.String(), Serdes.Long())); + + mappedStream.groupByKey(Grouped.as(secondRepartitionTopicName)).aggregate(initializer, + aggregator, + Materialized.with(Serdes.String(), Serdes.Integer())) + .toStream().to(AGGREGATION_TOPIC, Produced.with(Serdes.String(), Serdes.Integer())); + + // adding operators for case where the repartition node is further downstream + mappedStream.filter((k, v) -> true).peek((k, v) -> System.out.println(k + ":" + v)).groupByKey(Grouped.as(thirdRepartitionTopicName)) + .reduce(reducer, Materialized.with(Serdes.String(), Serdes.String())) + .toStream().to(REDUCE_TOPIC, Produced.with(Serdes.String(), Serdes.String())); + + mappedStream.filter((k, v) -> k.equals("A")) + .join(countStream, (v1, v2) -> v1 + ":" + v2.toString(), + JoinWindows.of(Duration.ofMillis(5000L)), + StreamJoined.with(Serdes.String(), Serdes.String(), Serdes.Long()).withStoreName(fourthRepartitionTopicName).withName(fourthRepartitionTopicName)) + .to(JOINED_TOPIC); + + final Properties properties = new Properties(); + + properties.put(StreamsConfig.TOPOLOGY_OPTIMIZATION_CONFIG, optimizationConfig); + return builder.build(properties); + } + + + @SuppressWarnings("deprecation") // Old PAPI. Needs to be migrated. 
+ private static class SimpleProcessor extends org.apache.kafka.streams.processor.AbstractProcessor { + + final List valueList; + + SimpleProcessor(final List valueList) { + this.valueList = valueList; + } + + @Override + public void process(final String key, final String value) { + valueList.add(value); + } + } + + + private static final String EXPECTED_OPTIMIZED_TOPOLOGY = "Topologies:\n" + + " Sub-topology: 0\n" + + " Source: KSTREAM-SOURCE-0000000000 (topics: [input])\n" + + " --> KSTREAM-MAP-0000000001\n" + + " Processor: KSTREAM-MAP-0000000001 (stores: [])\n" + + " --> KSTREAM-FILTER-0000000002, count-stream-repartition-filter\n" + + " <-- KSTREAM-SOURCE-0000000000\n" + + " Processor: KSTREAM-FILTER-0000000002 (stores: [])\n" + + " --> KSTREAM-MAPVALUES-0000000003\n" + + " <-- KSTREAM-MAP-0000000001\n" + + " Processor: KSTREAM-MAPVALUES-0000000003 (stores: [])\n" + + " --> KSTREAM-PROCESSOR-0000000004\n" + + " <-- KSTREAM-FILTER-0000000002\n" + + " Processor: count-stream-repartition-filter (stores: [])\n" + + " --> count-stream-repartition-sink\n" + + " <-- KSTREAM-MAP-0000000001\n" + + " Processor: KSTREAM-PROCESSOR-0000000004 (stores: [])\n" + + " --> none\n" + + " <-- KSTREAM-MAPVALUES-0000000003\n" + + " Sink: count-stream-repartition-sink (topic: count-stream-repartition)\n" + + " <-- count-stream-repartition-filter\n" + + "\n" + + " Sub-topology: 1\n" + + " Source: count-stream-repartition-source (topics: [count-stream-repartition])\n" + + " --> KSTREAM-FILTER-0000000020, KSTREAM-AGGREGATE-0000000007, KSTREAM-AGGREGATE-0000000014, KSTREAM-FILTER-0000000029\n" + + " Processor: KSTREAM-AGGREGATE-0000000007 (stores: [KSTREAM-AGGREGATE-STATE-STORE-0000000006])\n" + + " --> KTABLE-TOSTREAM-0000000011\n" + + " <-- count-stream-repartition-source\n" + + " Processor: KTABLE-TOSTREAM-0000000011 (stores: [])\n" + + " --> joined-stream-other-windowed, KSTREAM-SINK-0000000012\n" + + " <-- KSTREAM-AGGREGATE-0000000007\n" + + " Processor: KSTREAM-FILTER-0000000020 (stores: [])\n" + + " --> KSTREAM-PEEK-0000000021\n" + + " <-- count-stream-repartition-source\n" + + " Processor: KSTREAM-FILTER-0000000029 (stores: [])\n" + + " --> joined-stream-this-windowed\n" + + " <-- count-stream-repartition-source\n" + + " Processor: KSTREAM-PEEK-0000000021 (stores: [])\n" + + " --> KSTREAM-REDUCE-0000000023\n" + + " <-- KSTREAM-FILTER-0000000020\n" + + " Processor: joined-stream-other-windowed (stores: [joined-stream-other-join-store])\n" + + " --> joined-stream-other-join\n" + + " <-- KTABLE-TOSTREAM-0000000011\n" + + " Processor: joined-stream-this-windowed (stores: [joined-stream-this-join-store])\n" + + " --> joined-stream-this-join\n" + + " <-- KSTREAM-FILTER-0000000029\n" + + " Processor: KSTREAM-AGGREGATE-0000000014 (stores: [KSTREAM-AGGREGATE-STATE-STORE-0000000013])\n" + + " --> KTABLE-TOSTREAM-0000000018\n" + + " <-- count-stream-repartition-source\n" + + " Processor: KSTREAM-REDUCE-0000000023 (stores: [KSTREAM-REDUCE-STATE-STORE-0000000022])\n" + + " --> KTABLE-TOSTREAM-0000000027\n" + + " <-- KSTREAM-PEEK-0000000021\n" + + " Processor: joined-stream-other-join (stores: [joined-stream-this-join-store])\n" + + " --> joined-stream-merge\n" + + " <-- joined-stream-other-windowed\n" + + " Processor: joined-stream-this-join (stores: [joined-stream-other-join-store])\n" + + " --> joined-stream-merge\n" + + " <-- joined-stream-this-windowed\n" + + " Processor: KTABLE-TOSTREAM-0000000018 (stores: [])\n" + + " --> KSTREAM-SINK-0000000019\n" + + " <-- KSTREAM-AGGREGATE-0000000014\n" + + " Processor: 
KTABLE-TOSTREAM-0000000027 (stores: [])\n" + + " --> KSTREAM-SINK-0000000028\n" + + " <-- KSTREAM-REDUCE-0000000023\n" + + " Processor: joined-stream-merge (stores: [])\n" + + " --> KSTREAM-SINK-0000000038\n" + + " <-- joined-stream-this-join, joined-stream-other-join\n" + + " Sink: KSTREAM-SINK-0000000012 (topic: outputTopic_0)\n" + + " <-- KTABLE-TOSTREAM-0000000011\n" + + " Sink: KSTREAM-SINK-0000000019 (topic: outputTopic_1)\n" + + " <-- KTABLE-TOSTREAM-0000000018\n" + + " Sink: KSTREAM-SINK-0000000028 (topic: outputTopic_2)\n" + + " <-- KTABLE-TOSTREAM-0000000027\n" + + " Sink: KSTREAM-SINK-0000000038 (topic: outputTopicForJoin)\n" + + " <-- joined-stream-merge\n\n"; + + + private static final String EXPECTED_UNOPTIMIZED_TOPOLOGY = "Topologies:\n" + + " Sub-topology: 0\n" + + " Source: KSTREAM-SOURCE-0000000000 (topics: [input])\n" + + " --> KSTREAM-MAP-0000000001\n" + + " Processor: KSTREAM-MAP-0000000001 (stores: [])\n" + + " --> KSTREAM-FILTER-0000000029, KSTREAM-FILTER-0000000002, KSTREAM-FILTER-0000000020, aggregate-stream-repartition-filter, count-stream-repartition-filter\n" + + " <-- KSTREAM-SOURCE-0000000000\n" + + " Processor: KSTREAM-FILTER-0000000020 (stores: [])\n" + + " --> KSTREAM-PEEK-0000000021\n" + + " <-- KSTREAM-MAP-0000000001\n" + + " Processor: KSTREAM-FILTER-0000000002 (stores: [])\n" + + " --> KSTREAM-MAPVALUES-0000000003\n" + + " <-- KSTREAM-MAP-0000000001\n" + + " Processor: KSTREAM-FILTER-0000000029 (stores: [])\n" + + " --> joined-stream-left-repartition-filter\n" + + " <-- KSTREAM-MAP-0000000001\n" + + " Processor: KSTREAM-PEEK-0000000021 (stores: [])\n" + + " --> reduced-stream-repartition-filter\n" + + " <-- KSTREAM-FILTER-0000000020\n" + + " Processor: KSTREAM-MAPVALUES-0000000003 (stores: [])\n" + + " --> KSTREAM-PROCESSOR-0000000004\n" + + " <-- KSTREAM-FILTER-0000000002\n" + + " Processor: aggregate-stream-repartition-filter (stores: [])\n" + + " --> aggregate-stream-repartition-sink\n" + + " <-- KSTREAM-MAP-0000000001\n" + + " Processor: count-stream-repartition-filter (stores: [])\n" + + " --> count-stream-repartition-sink\n" + + " <-- KSTREAM-MAP-0000000001\n" + + " Processor: joined-stream-left-repartition-filter (stores: [])\n" + + " --> joined-stream-left-repartition-sink\n" + + " <-- KSTREAM-FILTER-0000000029\n" + + " Processor: reduced-stream-repartition-filter (stores: [])\n" + + " --> reduced-stream-repartition-sink\n" + + " <-- KSTREAM-PEEK-0000000021\n" + + " Processor: KSTREAM-PROCESSOR-0000000004 (stores: [])\n" + + " --> none\n" + + " <-- KSTREAM-MAPVALUES-0000000003\n" + + " Sink: aggregate-stream-repartition-sink (topic: aggregate-stream-repartition)\n" + + " <-- aggregate-stream-repartition-filter\n" + + " Sink: count-stream-repartition-sink (topic: count-stream-repartition)\n" + + " <-- count-stream-repartition-filter\n" + + " Sink: joined-stream-left-repartition-sink (topic: joined-stream-left-repartition)\n" + + " <-- joined-stream-left-repartition-filter\n" + + " Sink: reduced-stream-repartition-sink (topic: reduced-stream-repartition)\n" + + " <-- reduced-stream-repartition-filter\n" + + "\n" + + " Sub-topology: 1\n" + + " Source: count-stream-repartition-source (topics: [count-stream-repartition])\n" + + " --> KSTREAM-AGGREGATE-0000000007\n" + + " Processor: KSTREAM-AGGREGATE-0000000007 (stores: [KSTREAM-AGGREGATE-STATE-STORE-0000000006])\n" + + " --> KTABLE-TOSTREAM-0000000011\n" + + " <-- count-stream-repartition-source\n" + + " Processor: KTABLE-TOSTREAM-0000000011 (stores: [])\n" + + " --> KSTREAM-SINK-0000000012, 
joined-stream-other-windowed\n" + + " <-- KSTREAM-AGGREGATE-0000000007\n" + + " Source: joined-stream-left-repartition-source (topics: [joined-stream-left-repartition])\n" + + " --> joined-stream-this-windowed\n" + + " Processor: joined-stream-other-windowed (stores: [joined-stream-other-join-store])\n" + + " --> joined-stream-other-join\n" + + " <-- KTABLE-TOSTREAM-0000000011\n" + + " Processor: joined-stream-this-windowed (stores: [joined-stream-this-join-store])\n" + + " --> joined-stream-this-join\n" + + " <-- joined-stream-left-repartition-source\n" + + " Processor: joined-stream-other-join (stores: [joined-stream-this-join-store])\n" + + " --> joined-stream-merge\n" + + " <-- joined-stream-other-windowed\n" + + " Processor: joined-stream-this-join (stores: [joined-stream-other-join-store])\n" + + " --> joined-stream-merge\n" + + " <-- joined-stream-this-windowed\n" + + " Processor: joined-stream-merge (stores: [])\n" + + " --> KSTREAM-SINK-0000000038\n" + + " <-- joined-stream-this-join, joined-stream-other-join\n" + + " Sink: KSTREAM-SINK-0000000012 (topic: outputTopic_0)\n" + + " <-- KTABLE-TOSTREAM-0000000011\n" + + " Sink: KSTREAM-SINK-0000000038 (topic: outputTopicForJoin)\n" + + " <-- joined-stream-merge\n" + + "\n" + + " Sub-topology: 2\n" + + " Source: aggregate-stream-repartition-source (topics: [aggregate-stream-repartition])\n" + + " --> KSTREAM-AGGREGATE-0000000014\n" + + " Processor: KSTREAM-AGGREGATE-0000000014 (stores: [KSTREAM-AGGREGATE-STATE-STORE-0000000013])\n" + + " --> KTABLE-TOSTREAM-0000000018\n" + + " <-- aggregate-stream-repartition-source\n" + + " Processor: KTABLE-TOSTREAM-0000000018 (stores: [])\n" + + " --> KSTREAM-SINK-0000000019\n" + + " <-- KSTREAM-AGGREGATE-0000000014\n" + + " Sink: KSTREAM-SINK-0000000019 (topic: outputTopic_1)\n" + + " <-- KTABLE-TOSTREAM-0000000018\n" + + "\n" + + " Sub-topology: 3\n" + + " Source: reduced-stream-repartition-source (topics: [reduced-stream-repartition])\n" + + " --> KSTREAM-REDUCE-0000000023\n" + + " Processor: KSTREAM-REDUCE-0000000023 (stores: [KSTREAM-REDUCE-STATE-STORE-0000000022])\n" + + " --> KTABLE-TOSTREAM-0000000027\n" + + " <-- reduced-stream-repartition-source\n" + + " Processor: KTABLE-TOSTREAM-0000000027 (stores: [])\n" + + " --> KSTREAM-SINK-0000000028\n" + + " <-- KSTREAM-REDUCE-0000000023\n" + + " Sink: KSTREAM-SINK-0000000028 (topic: outputTopic_2)\n" + + " <-- KTABLE-TOSTREAM-0000000027\n\n"; + +} diff --git a/streams/src/test/java/org/apache/kafka/streams/kstream/SessionWindowedDeserializerTest.java b/streams/src/test/java/org/apache/kafka/streams/kstream/SessionWindowedDeserializerTest.java new file mode 100644 index 0000000..cc8f964 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/kstream/SessionWindowedDeserializerTest.java @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.kstream; + +import org.apache.kafka.common.config.ConfigException; +import org.apache.kafka.common.serialization.ByteArrayDeserializer; +import org.apache.kafka.common.serialization.Deserializer; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.serialization.StringDeserializer; +import org.apache.kafka.streams.StreamsConfig; +import org.junit.Test; + +import java.util.HashMap; +import java.util.Map; + +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.assertThrows; + + +public class SessionWindowedDeserializerTest { + private final SessionWindowedDeserializer sessionWindowedDeserializer = new SessionWindowedDeserializer<>(new StringDeserializer()); + private final Map props = new HashMap<>(); + + @Test + public void testSessionWindowedDeserializerConstructor() { + sessionWindowedDeserializer.configure(props, true); + final Deserializer inner = sessionWindowedDeserializer.innerDeserializer(); + assertNotNull("Inner deserializer should be not null", inner); + assertTrue("Inner deserializer type should be StringDeserializer", inner instanceof StringDeserializer); + } + + @Test + public void shouldSetWindowedInnerClassDeserialiserThroughConfig() { + props.put(StreamsConfig.WINDOWED_INNER_CLASS_SERDE, Serdes.ByteArraySerde.class.getName()); + final SessionWindowedDeserializer deserializer = new SessionWindowedDeserializer<>(); + deserializer.configure(props, false); + assertTrue(deserializer.innerDeserializer() instanceof ByteArrayDeserializer); + } + + @Test + public void shouldThrowErrorIfWindowInnerClassDeserialiserIsNotSet() { + final SessionWindowedDeserializer deserializer = new SessionWindowedDeserializer<>(); + assertThrows(IllegalArgumentException.class, () -> deserializer.configure(props, false)); + } + + @Test + public void shouldThrowErrorIfDeserialisersConflictInConstructorAndConfig() { + props.put(StreamsConfig.WINDOWED_INNER_CLASS_SERDE, Serdes.ByteArraySerde.class.getName()); + assertThrows(IllegalArgumentException.class, () -> sessionWindowedDeserializer.configure(props, false)); + } + + @Test + public void shouldThrowConfigExceptionWhenInvalidWindowInnerClassDeserialiserSupplied() { + props.put(StreamsConfig.WINDOWED_INNER_CLASS_SERDE, "some.non.existent.class"); + assertThrows(ConfigException.class, () -> sessionWindowedDeserializer.configure(props, false)); + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/kstream/SessionWindowedSerializerTest.java b/streams/src/test/java/org/apache/kafka/streams/kstream/SessionWindowedSerializerTest.java new file mode 100644 index 0000000..2a560ed --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/kstream/SessionWindowedSerializerTest.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.kstream; + +import org.apache.kafka.common.config.ConfigException; +import org.apache.kafka.common.serialization.ByteArraySerializer; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.serialization.Serializer; +import org.apache.kafka.common.serialization.StringSerializer; +import org.apache.kafka.streams.StreamsConfig; +import org.junit.Test; + +import java.util.HashMap; +import java.util.Map; + +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.assertThrows; + +public class SessionWindowedSerializerTest { + private final SessionWindowedSerializer sessionWindowedSerializer = new SessionWindowedSerializer<>(Serdes.String().serializer()); + private final Map props = new HashMap<>(); + + @Test + public void testSessionWindowedSerializerConstructor() { + sessionWindowedSerializer.configure(props, true); + final Serializer inner = sessionWindowedSerializer.innerSerializer(); + assertNotNull("Inner serializer should be not null", inner); + assertTrue("Inner serializer type should be StringSerializer", inner instanceof StringSerializer); + } + + @Test + public void shouldSetWindowedInnerClassSerialiserThroughConfig() { + props.put(StreamsConfig.WINDOWED_INNER_CLASS_SERDE, Serdes.ByteArraySerde.class.getName()); + final SessionWindowedSerializer serializer = new SessionWindowedSerializer<>(); + serializer.configure(props, false); + assertTrue(serializer.innerSerializer() instanceof ByteArraySerializer); + } + + @Test + public void shouldThrowErrorIfWindowInnerClassSerialiserIsNotSet() { + final SessionWindowedSerializer serializer = new SessionWindowedSerializer<>(); + assertThrows(IllegalArgumentException.class, () -> serializer.configure(props, false)); + } + + @Test + public void shouldThrowErrorIfSerialisersConflictInConstructorAndConfig() { + props.put(StreamsConfig.WINDOWED_INNER_CLASS_SERDE, Serdes.ByteArraySerde.class.getName()); + assertThrows(IllegalArgumentException.class, () -> sessionWindowedSerializer.configure(props, false)); + } + + @Test + public void shouldThrowConfigExceptionWhenInvalidWindowInnerClassSerialiserSupplied() { + props.put(StreamsConfig.WINDOWED_INNER_CLASS_SERDE, "some.non.existent.class"); + assertThrows(ConfigException.class, () -> sessionWindowedSerializer.configure(props, false)); + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/kstream/SessionWindowsTest.java b/streams/src/test/java/org/apache/kafka/streams/kstream/SessionWindowsTest.java new file mode 100644 index 0000000..7a8fafd --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/kstream/SessionWindowsTest.java @@ -0,0 +1,127 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.kstream; + +import org.junit.Test; + +import java.time.Duration; + +import static java.time.Duration.ofMillis; +import static org.apache.kafka.streams.EqualityCheck.verifyEquality; +import static org.apache.kafka.streams.EqualityCheck.verifyInEquality; +import static org.apache.kafka.streams.kstream.Windows.DEPRECATED_DEFAULT_24_HR_GRACE_PERIOD; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.fail; + +public class SessionWindowsTest { + + private static final long ANY_SIZE = 123L; + private static final long ANY_OTHER_SIZE = 456L; // should be larger than anySize + private static final long ANY_GRACE = 1024L; + + @Test + public void shouldSetWindowGap() { + final long anyGap = 42L; + + assertEquals(anyGap, SessionWindows.ofInactivityGapWithNoGrace(ofMillis(anyGap)).inactivityGap()); + assertEquals(anyGap, SessionWindows.ofInactivityGapAndGrace(ofMillis(anyGap), ofMillis(ANY_GRACE)).inactivityGap()); + } + + @Test + public void gracePeriodShouldEnforceBoundaries() { + SessionWindows.ofInactivityGapAndGrace(ofMillis(3L), ofMillis(0)); + + try { + SessionWindows.ofInactivityGapAndGrace(ofMillis(3L), ofMillis(-1L)); + fail("should not accept negatives"); + } catch (final IllegalArgumentException e) { + //expected + } + } + + @Test + public void noGraceAPIShouldNotSetGracePeriod() { + assertEquals(0L, SessionWindows.ofInactivityGapWithNoGrace(ofMillis(3L)).gracePeriodMs()); + assertEquals(0L, SessionWindows.ofInactivityGapWithNoGrace(ofMillis(ANY_SIZE)).gracePeriodMs()); + assertEquals(0L, SessionWindows.ofInactivityGapWithNoGrace(ofMillis(ANY_OTHER_SIZE)).gracePeriodMs()); + } + + @Test + public void withGraceAPIShouldSetGracePeriod() { + assertEquals(ANY_GRACE, SessionWindows.ofInactivityGapAndGrace(ofMillis(3L), ofMillis(ANY_GRACE)).gracePeriodMs()); + assertEquals(ANY_GRACE, SessionWindows.ofInactivityGapAndGrace(ofMillis(ANY_SIZE), ofMillis(ANY_GRACE)).gracePeriodMs()); + assertEquals(ANY_GRACE, SessionWindows.ofInactivityGapAndGrace(ofMillis(ANY_OTHER_SIZE), ofMillis(ANY_GRACE)).gracePeriodMs()); + } + + @SuppressWarnings("deprecation") + @Test + public void oldAPIShouldSetDefaultGracePeriod() { + assertEquals(Duration.ofDays(1).toMillis(), DEPRECATED_DEFAULT_24_HR_GRACE_PERIOD); + assertEquals(DEPRECATED_DEFAULT_24_HR_GRACE_PERIOD - 3L, SessionWindows.with(ofMillis(3L)).gracePeriodMs()); + assertEquals(0L, SessionWindows.with(ofMillis(DEPRECATED_DEFAULT_24_HR_GRACE_PERIOD)).gracePeriodMs()); + assertEquals(0L, SessionWindows.with(ofMillis(DEPRECATED_DEFAULT_24_HR_GRACE_PERIOD + 1L)).gracePeriodMs()); + } + + @Test + public void windowSizeMustNotBeNegative() { + assertThrows(IllegalArgumentException.class, () -> SessionWindows.ofInactivityGapWithNoGrace(ofMillis(-1))); + } + + @Test + public void windowSizeMustNotBeZero() { + assertThrows(IllegalArgumentException.class, () -> SessionWindows.ofInactivityGapWithNoGrace(ofMillis(0))); + } + + @SuppressWarnings("deprecation") + @Test + public void graceShouldNotCalledAfterGraceSet() { + 
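// calling the deprecated grace() after the grace period has already been fixed should throw IllegalStateException +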
assertThrows(IllegalStateException.class, () -> SessionWindows.ofInactivityGapAndGrace(ofMillis(10), ofMillis(10)).grace(ofMillis(10))); + assertThrows(IllegalStateException.class, () -> SessionWindows.ofInactivityGapWithNoGrace(ofMillis(10)).grace(ofMillis(10))); + } + + @Test + public void equalsAndHashcodeShouldBeValidForPositiveCases() { + verifyEquality( + SessionWindows.ofInactivityGapWithNoGrace(ofMillis(1)), + SessionWindows.ofInactivityGapWithNoGrace(ofMillis(1)) + ); + + verifyEquality( + SessionWindows.ofInactivityGapAndGrace(ofMillis(1), ofMillis(11)), + SessionWindows.ofInactivityGapAndGrace(ofMillis(1), ofMillis(11)) + ); + } + + @Test + public void equalsAndHashcodeShouldBeValidForNegativeCases() { + verifyInEquality( + SessionWindows.ofInactivityGapWithNoGrace(ofMillis(9)), + SessionWindows.ofInactivityGapWithNoGrace(ofMillis(1)) + ); + + verifyInEquality( + SessionWindows.ofInactivityGapAndGrace(ofMillis(9), ofMillis(9)), + SessionWindows.ofInactivityGapAndGrace(ofMillis(1), ofMillis(9)) + ); + + verifyInEquality( + SessionWindows.ofInactivityGapAndGrace(ofMillis(1), ofMillis(9)), + SessionWindows.ofInactivityGapAndGrace(ofMillis(1), ofMillis(6)) + ); + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/kstream/SlidingWindowsTest.java b/streams/src/test/java/org/apache/kafka/streams/kstream/SlidingWindowsTest.java new file mode 100644 index 0000000..1c8656c --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/kstream/SlidingWindowsTest.java @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.kstream; + +import org.junit.Test; + +import static java.time.Duration.ofMillis; +import static org.apache.kafka.streams.EqualityCheck.verifyEquality; +import static org.apache.kafka.streams.EqualityCheck.verifyInEquality; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; + +public class SlidingWindowsTest { + + private static final long ANY_SIZE = 123L; + private static final long ANY_GRACE = 1024L; + + @SuppressWarnings("deprecation") + @Test + public void shouldSetTimeDifference() { + assertEquals(ANY_SIZE, SlidingWindows.withTimeDifferenceAndGrace(ofMillis(ANY_SIZE), ofMillis(3)).timeDifferenceMs()); + assertEquals(ANY_SIZE, SlidingWindows.ofTimeDifferenceAndGrace(ofMillis(ANY_SIZE), ofMillis(ANY_GRACE)).timeDifferenceMs()); + assertEquals(ANY_SIZE, SlidingWindows.ofTimeDifferenceWithNoGrace(ofMillis(ANY_SIZE)).timeDifferenceMs()); + } + + @Test + public void timeDifferenceMustNotBeNegative() { + assertThrows(IllegalArgumentException.class, () -> SlidingWindows.ofTimeDifferenceAndGrace(ofMillis(-1), ofMillis(5))); + } + + @Test + public void shouldSetGracePeriod() { + assertEquals(ANY_SIZE, SlidingWindows.ofTimeDifferenceAndGrace(ofMillis(10), ofMillis(ANY_SIZE)).gracePeriodMs()); + } + + @Test + public void gracePeriodMustNotBeNegative() { + assertThrows(IllegalArgumentException.class, () -> SlidingWindows.ofTimeDifferenceAndGrace(ofMillis(10), ofMillis(-1))); + } + + @Test + public void equalsAndHashcodeShouldBeValidForPositiveCases() { + final long grace = 1L + (long) (Math.random() * (20L - 1L)); + final long timeDifference = 1L + (long) (Math.random() * (20L - 1L)); + verifyEquality( + SlidingWindows.ofTimeDifferenceAndGrace(ofMillis(timeDifference), ofMillis(grace)), + SlidingWindows.ofTimeDifferenceAndGrace(ofMillis(timeDifference), ofMillis(grace)) + ); + + verifyEquality( + SlidingWindows.ofTimeDifferenceAndGrace(ofMillis(timeDifference), ofMillis(grace)), + SlidingWindows.ofTimeDifferenceAndGrace(ofMillis(timeDifference), ofMillis(grace)) + ); + + verifyEquality( + SlidingWindows.ofTimeDifferenceWithNoGrace(ofMillis(timeDifference)), + SlidingWindows.ofTimeDifferenceWithNoGrace(ofMillis(timeDifference)) + ); + } + + @Test + public void equalsAndHashcodeShouldNotBeEqualForDifferentTimeDifference() { + final long grace = 1L + (long) (Math.random() * (10L - 1L)); + final long timeDifferenceOne = 1L + (long) (Math.random() * (10L - 1L)); + final long timeDifferenceTwo = 21L + (long) (Math.random() * (41L - 21L)); + verifyInEquality( + SlidingWindows.ofTimeDifferenceAndGrace(ofMillis(timeDifferenceOne), ofMillis(grace)), + SlidingWindows.ofTimeDifferenceAndGrace(ofMillis(timeDifferenceTwo), ofMillis(grace)) + ); + } + + @Test + public void equalsAndHashcodeShouldNotBeEqualForDifferentGracePeriod() { + final long timeDifference = 1L + (long) (Math.random() * (10L - 1L)); + final long graceOne = 1L + (long) (Math.random() * (10L - 1L)); + final long graceTwo = 21L + (long) (Math.random() * (41L - 21L)); + verifyInEquality( + SlidingWindows.ofTimeDifferenceAndGrace(ofMillis(timeDifference), ofMillis(graceOne)), + SlidingWindows.ofTimeDifferenceAndGrace(ofMillis(timeDifference), ofMillis(graceTwo)) + ); + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/kstream/SuppressedTest.java b/streams/src/test/java/org/apache/kafka/streams/kstream/SuppressedTest.java new file mode 100644 index 0000000..b799884 --- /dev/null +++ 
b/streams/src/test/java/org/apache/kafka/streams/kstream/SuppressedTest.java @@ -0,0 +1,214 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.kstream; + +import org.apache.kafka.streams.kstream.internals.suppress.EagerBufferConfigImpl; +import org.apache.kafka.streams.kstream.internals.suppress.FinalResultsSuppressionBuilder; +import org.apache.kafka.streams.kstream.internals.suppress.StrictBufferConfigImpl; +import org.apache.kafka.streams.kstream.internals.suppress.SuppressedInternal; +import org.junit.Test; + +import java.util.Collections; + +import static java.lang.Long.MAX_VALUE; +import static java.time.Duration.ofMillis; +import static org.apache.kafka.streams.kstream.Suppressed.BufferConfig.maxBytes; +import static org.apache.kafka.streams.kstream.Suppressed.BufferConfig.maxRecords; +import static org.apache.kafka.streams.kstream.Suppressed.BufferConfig.unbounded; +import static org.apache.kafka.streams.kstream.Suppressed.untilTimeLimit; +import static org.apache.kafka.streams.kstream.Suppressed.untilWindowCloses; +import static org.apache.kafka.streams.kstream.internals.suppress.BufferFullStrategy.SHUT_DOWN; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.MatcherAssert.assertThat; + +public class SuppressedTest { + + @Test + public void bufferBuilderShouldBeConsistent() { + assertThat( + "noBound should remove bounds", + maxBytes(2L).withMaxRecords(4L).withNoBound(), + is(unbounded()) + ); + + assertThat( + "keys alone should be set", + maxRecords(2L), + is(new EagerBufferConfigImpl(2L, MAX_VALUE, Collections.emptyMap())) + ); + + assertThat( + "size alone should be set", + maxBytes(2L), + is(new EagerBufferConfigImpl(MAX_VALUE, 2L, Collections.emptyMap())) + ); + + assertThat( + "config should be set even after max records", + maxRecords(2L).withMaxBytes(4L).withLoggingEnabled(Collections.singletonMap("myConfigKey", "myConfigValue")), + is(new EagerBufferConfigImpl(2L, 4L, Collections.singletonMap("myConfigKey", "myConfigValue"))) + ); + } + + @Test + public void intermediateEventsShouldAcceptAnyBufferAndSetBounds() { + assertThat( + "name should be set", + untilTimeLimit(ofMillis(2), unbounded()).withName("myname"), + is(new SuppressedInternal<>("myname", ofMillis(2), unbounded(), null, false)) + ); + + assertThat( + "time alone should be set", + untilTimeLimit(ofMillis(2), unbounded()), + is(new SuppressedInternal<>(null, ofMillis(2), unbounded(), null, false)) + ); + + assertThat( + "time and unbounded buffer should be set", + untilTimeLimit(ofMillis(2), unbounded()), + is(new SuppressedInternal<>(null, ofMillis(2), unbounded(), null, false)) + ); + + assertThat( + "time and keys buffer should be set", + untilTimeLimit(ofMillis(2), maxRecords(2)), + is(new 
SuppressedInternal<>(null, ofMillis(2), maxRecords(2), null, false)) + ); + + assertThat( + "time and size buffer should be set", + untilTimeLimit(ofMillis(2), maxBytes(2)), + is(new SuppressedInternal<>(null, ofMillis(2), maxBytes(2), null, false)) + ); + + assertThat( + "all constraints should be set", + untilTimeLimit(ofMillis(2L), maxRecords(3L).withMaxBytes(2L)), + is(new SuppressedInternal<>(null, ofMillis(2), new EagerBufferConfigImpl(3L, 2L, Collections.emptyMap()), null, false)) + ); + + assertThat( + "config is not lost early emit is set", + untilTimeLimit(ofMillis(2), maxRecords(2L).withLoggingEnabled(Collections.singletonMap("myConfigKey", "myConfigValue")).emitEarlyWhenFull()), + is(new SuppressedInternal<>(null, ofMillis(2), new EagerBufferConfigImpl(2L, MAX_VALUE, Collections.singletonMap("myConfigKey", "myConfigValue")), null, false)) + ); + } + + @Test + public void finalEventsShouldAcceptStrictBuffersAndSetBounds() { + + assertThat( + untilWindowCloses(unbounded()), + is(new FinalResultsSuppressionBuilder<>(null, unbounded())) + ); + + assertThat( + untilWindowCloses(maxRecords(2L).shutDownWhenFull()), + is(new FinalResultsSuppressionBuilder<>(null, new StrictBufferConfigImpl(2L, MAX_VALUE, SHUT_DOWN, Collections.emptyMap())) + ) + ); + + assertThat( + untilWindowCloses(maxBytes(2L).shutDownWhenFull()), + is(new FinalResultsSuppressionBuilder<>(null, new StrictBufferConfigImpl(MAX_VALUE, 2L, SHUT_DOWN, Collections.emptyMap())) + ) + ); + + assertThat( + untilWindowCloses(unbounded()).withName("name"), + is(new FinalResultsSuppressionBuilder<>("name", unbounded())) + ); + + assertThat( + untilWindowCloses(maxRecords(2L).shutDownWhenFull()).withName("name"), + is(new FinalResultsSuppressionBuilder<>("name", new StrictBufferConfigImpl(2L, MAX_VALUE, SHUT_DOWN, Collections.emptyMap())) + ) + ); + + assertThat( + untilWindowCloses(maxBytes(2L).shutDownWhenFull()).withName("name"), + is(new FinalResultsSuppressionBuilder<>("name", new StrictBufferConfigImpl(MAX_VALUE, 2L, SHUT_DOWN, Collections.emptyMap())) + ) + ); + + assertThat( + "config is not lost when shutdown when full is set", + untilWindowCloses(maxBytes(2L).withLoggingEnabled(Collections.singletonMap("myConfigKey", "myConfigValue")).shutDownWhenFull()), + is(new FinalResultsSuppressionBuilder<>(null, new StrictBufferConfigImpl(MAX_VALUE, 2L, SHUT_DOWN, Collections.singletonMap("myConfigKey", "myConfigValue")))) + ); + } + + @Test + public void supportLongChainOfMethods() { + final Suppressed.BufferConfig bufferConfig = unbounded() + .emitEarlyWhenFull() + .withMaxRecords(3L) + .withMaxBytes(4L) + .withMaxRecords(5L) + .withMaxBytes(6L); + + assertThat( + "long chain of eager buffer config sets attributes properly", + bufferConfig, + is(new EagerBufferConfigImpl(5L, 6L, Collections.emptyMap())) + ); + assertThat( + "long chain of strict buffer config sets attributes properly", + bufferConfig.shutDownWhenFull(), + is(new StrictBufferConfigImpl(5L, 6L, SHUT_DOWN, Collections.emptyMap())) + ); + + final Suppressed.BufferConfig bufferConfigWithLogging = unbounded() + .withLoggingEnabled(Collections.singletonMap("myConfigKey", "myConfigValue")) + .emitEarlyWhenFull() + .withMaxRecords(3L) + .withMaxBytes(4L) + .withMaxRecords(5L) + .withMaxBytes(6L); + + assertThat( + "long chain of eager buffer config sets attributes properly with logging enabled", + bufferConfigWithLogging, + is(new EagerBufferConfigImpl(5L, 6L, Collections.singletonMap("myConfigKey", "myConfigValue"))) + ); + assertThat( + "long chain of strict buffer 
config sets attributes properly with logging enabled", + bufferConfigWithLogging.shutDownWhenFull(), + is(new StrictBufferConfigImpl(5L, 6L, SHUT_DOWN, Collections.singletonMap("myConfigKey", "myConfigValue"))) + ); + + final Suppressed.BufferConfig bufferConfigWithLoggingCalledAtTheEnd = unbounded() + .emitEarlyWhenFull() + .withMaxRecords(3L) + .withMaxBytes(4L) + .withMaxRecords(5L) + .withMaxBytes(6L) + .withLoggingEnabled(Collections.singletonMap("myConfigKey", "myConfigValue")); + + assertThat( + "long chain of eager buffer config sets logging even after other setters", + bufferConfigWithLoggingCalledAtTheEnd, + is(new EagerBufferConfigImpl(5L, 6L, Collections.singletonMap("myConfigKey", "myConfigValue"))) + ); + assertThat( + "long chain of strict buffer config sets logging even after other setters", + bufferConfigWithLoggingCalledAtTheEnd.shutDownWhenFull(), + is(new StrictBufferConfigImpl(5L, 6L, SHUT_DOWN, Collections.singletonMap("myConfigKey", "myConfigValue"))) + ); + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/kstream/TimeWindowedDeserializerTest.java b/streams/src/test/java/org/apache/kafka/streams/kstream/TimeWindowedDeserializerTest.java new file mode 100644 index 0000000..a035763 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/kstream/TimeWindowedDeserializerTest.java @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.kstream; + +import org.apache.kafka.common.config.ConfigException; +import org.apache.kafka.common.serialization.ByteArrayDeserializer; +import org.apache.kafka.common.serialization.Deserializer; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.serialization.StringDeserializer; +import org.apache.kafka.streams.StreamsConfig; +import org.junit.Test; + +import java.util.HashMap; +import java.util.Map; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.core.Is.is; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; + +public class TimeWindowedDeserializerTest { + private final long windowSize = 5000000; + private final TimeWindowedDeserializer timeWindowedDeserializer = new TimeWindowedDeserializer<>(new StringDeserializer(), windowSize); + private final Map props = new HashMap<>(); + + @Test + public void testTimeWindowedDeserializerConstructor() { + timeWindowedDeserializer.configure(props, true); + final Deserializer inner = timeWindowedDeserializer.innerDeserializer(); + assertNotNull("Inner deserializer should be not null", inner); + assertTrue("Inner deserializer type should be StringDeserializer", inner instanceof StringDeserializer); + assertThat(timeWindowedDeserializer.getWindowSize(), is(5000000L)); + } + + @Test + public void shouldSetWindowSizeAndWindowedInnerDeserialiserThroughConfigs() { + props.put(StreamsConfig.WINDOW_SIZE_MS_CONFIG, "500"); + props.put(StreamsConfig.WINDOWED_INNER_CLASS_SERDE, Serdes.ByteArraySerde.class.getName()); + final TimeWindowedDeserializer deserializer = new TimeWindowedDeserializer<>(); + deserializer.configure(props, false); + assertThat(deserializer.getWindowSize(), is(500L)); + assertTrue(deserializer.innerDeserializer() instanceof ByteArrayDeserializer); + } + + @Test + public void shouldThrowErrorIfWindowSizeSetInConfigsAndConstructor() { + props.put(StreamsConfig.WINDOW_SIZE_MS_CONFIG, "500"); + assertThrows(IllegalArgumentException.class, () -> timeWindowedDeserializer.configure(props, false)); + } + + @Test + public void shouldThrowErrorIfWindowSizeIsNotSet() { + props.put(StreamsConfig.WINDOWED_INNER_CLASS_SERDE, Serdes.ByteArraySerde.class.getName()); + final TimeWindowedDeserializer deserializer = new TimeWindowedDeserializer<>(); + assertThrows(IllegalArgumentException.class, () -> deserializer.configure(props, false)); + } + + @Test + public void shouldThrowErrorIfWindowedInnerClassDeserialiserIsNotSet() { + props.put(StreamsConfig.WINDOW_SIZE_MS_CONFIG, "500"); + final TimeWindowedDeserializer deserializer = new TimeWindowedDeserializer<>(); + assertThrows(IllegalArgumentException.class, () -> deserializer.configure(props, false)); + } + + @Test + public void shouldThrowErrorIfWindowedInnerClassDeserialisersConflictInConstructorAndConfig() { + props.put(StreamsConfig.WINDOWED_INNER_CLASS_SERDE, Serdes.ByteArraySerde.class.getName()); + assertThrows(IllegalArgumentException.class, () -> timeWindowedDeserializer.configure(props, false)); + } + + @Test + public void shouldThrowConfigExceptionWhenInvalidWindowedInnerClassDeserialiserSupplied() { + props.put(StreamsConfig.WINDOWED_INNER_CLASS_SERDE, "some.non.existent.class"); + assertThrows(ConfigException.class, () -> timeWindowedDeserializer.configure(props, false)); + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/kstream/TimeWindowedSerializerTest.java 
b/streams/src/test/java/org/apache/kafka/streams/kstream/TimeWindowedSerializerTest.java new file mode 100644 index 0000000..b5e9754 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/kstream/TimeWindowedSerializerTest.java @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.kstream; + +import org.apache.kafka.common.config.ConfigException; +import org.apache.kafka.common.serialization.ByteArraySerializer; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.serialization.Serializer; +import org.apache.kafka.common.serialization.StringSerializer; +import org.apache.kafka.streams.StreamsConfig; +import org.junit.Test; + +import java.util.HashMap; +import java.util.Map; + +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; + +public class TimeWindowedSerializerTest { + private final TimeWindowedSerializer timeWindowedSerializer = new TimeWindowedSerializer<>(Serdes.String().serializer()); + private final Map props = new HashMap<>(); + + @Test + public void testTimeWindowedSerializerConstructor() { + timeWindowedSerializer.configure(props, true); + final Serializer inner = timeWindowedSerializer.innerSerializer(); + assertNotNull("Inner serializer should be not null", inner); + assertTrue("Inner serializer type should be StringSerializer", inner instanceof StringSerializer); + } + + @Test + public void shouldSetWindowedInnerClassSerialiserThroughConfig() { + props.put(StreamsConfig.WINDOWED_INNER_CLASS_SERDE, Serdes.ByteArraySerde.class.getName()); + final TimeWindowedSerializer serializer = new TimeWindowedSerializer<>(); + serializer.configure(props, false); + assertTrue(serializer.innerSerializer() instanceof ByteArraySerializer); + } + + @Test + public void shouldThrowErrorIfWindowedInnerClassSerialiserIsNotSet() { + final TimeWindowedSerializer serializer = new TimeWindowedSerializer<>(); + assertThrows(IllegalArgumentException.class, () -> serializer.configure(props, false)); + } + + @Test + public void shouldThrowErrorIfWindowedInnerClassSerialisersConflictInConstructorAndConfig() { + props.put(StreamsConfig.WINDOWED_INNER_CLASS_SERDE, Serdes.ByteArraySerde.class.getName()); + assertThrows(IllegalArgumentException.class, () -> timeWindowedSerializer.configure(props, false)); + } + + @Test + public void shouldThrowConfigExceptionWhenInvalidWindowedInnerClassSerialiserSupplied() { + props.put(StreamsConfig.WINDOWED_INNER_CLASS_SERDE, "some.non.existent.class"); + assertThrows(ConfigException.class, () -> timeWindowedSerializer.configure(props, false)); + } + +} diff --git a/streams/src/test/java/org/apache/kafka/streams/kstream/TimeWindowsTest.java 
b/streams/src/test/java/org/apache/kafka/streams/kstream/TimeWindowsTest.java new file mode 100644 index 0000000..2e8f2e7 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/kstream/TimeWindowsTest.java @@ -0,0 +1,198 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.kstream; + +import org.apache.kafka.streams.kstream.internals.TimeWindow; +import org.junit.Test; + +import java.time.Duration; +import java.util.Map; + +import static java.time.Duration.ofMillis; +import static org.apache.kafka.streams.EqualityCheck.verifyEquality; +import static org.apache.kafka.streams.EqualityCheck.verifyInEquality; +import static org.apache.kafka.streams.kstream.Windows.DEPRECATED_DEFAULT_24_HR_GRACE_PERIOD; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotEquals; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.fail; + +public class TimeWindowsTest { + + private static final long ANY_SIZE = 123L; + private static final long ANY_GRACE = 1024L; + + @SuppressWarnings("deprecation") + @Test + public void shouldSetWindowSize() { + assertEquals(ANY_SIZE, TimeWindows.of(ofMillis(ANY_SIZE)).sizeMs); + assertEquals(ANY_SIZE, TimeWindows.ofSizeWithNoGrace(ofMillis(ANY_SIZE)).sizeMs); + assertEquals(ANY_SIZE, TimeWindows.ofSizeAndGrace(ofMillis(ANY_SIZE), ofMillis(ANY_GRACE)).sizeMs); + } + + @Test + public void shouldSetWindowAdvance() { + final long anyAdvance = 4; + assertEquals(anyAdvance, TimeWindows.ofSizeWithNoGrace(ofMillis(ANY_SIZE)).advanceBy(ofMillis(anyAdvance)).advanceMs); + } + + @Test + public void windowSizeMustNotBeZero() { + assertThrows(IllegalArgumentException.class, () -> TimeWindows.ofSizeWithNoGrace(ofMillis(0))); + } + + @Test + public void windowSizeMustNotBeNegative() { + assertThrows(IllegalArgumentException.class, () -> TimeWindows.ofSizeWithNoGrace(ofMillis(-1))); + } + + @SuppressWarnings("deprecation") + @Test + public void graceShouldNotCalledAfterGraceSet() { + assertThrows(IllegalStateException.class, () -> TimeWindows.ofSizeAndGrace(ofMillis(10), ofMillis(10)).grace(ofMillis(10))); + assertThrows(IllegalStateException.class, () -> TimeWindows.ofSizeWithNoGrace(ofMillis(10)).grace(ofMillis(10))); + } + + @Test + public void advanceIntervalMustNotBeZero() { + final TimeWindows windowSpec = TimeWindows.ofSizeWithNoGrace(ofMillis(ANY_SIZE)); + try { + windowSpec.advanceBy(ofMillis(0)); + fail("should not accept zero advance parameter"); + } catch (final IllegalArgumentException e) { + // expected + } + } + + @Test + public void advanceIntervalMustNotBeNegative() { + final TimeWindows windowSpec = TimeWindows.ofSizeWithNoGrace(ofMillis(ANY_SIZE)); + try { + windowSpec.advanceBy(ofMillis(-1)); + fail("should not accept 
negative advance parameter"); + } catch (final IllegalArgumentException e) { + // expected + } + } + + @Test + public void advanceIntervalMustNotBeLargerThanWindowSize() { + final TimeWindows windowSpec = TimeWindows.ofSizeWithNoGrace(ofMillis(ANY_SIZE)); + try { + windowSpec.advanceBy(ofMillis(ANY_SIZE + 1)); + fail("should not accept advance greater than window size"); + } catch (final IllegalArgumentException e) { + // expected + } + } + + @Test + public void gracePeriodShouldEnforceBoundaries() { + TimeWindows.ofSizeAndGrace(ofMillis(3L), ofMillis(0L)); + + try { + TimeWindows.ofSizeAndGrace(ofMillis(3L), ofMillis(-1L)); + fail("should not accept negatives"); + } catch (final IllegalArgumentException e) { + //expected + } + } + + @SuppressWarnings("deprecation") + @Test + public void oldAPIShouldSetDefaultGracePeriod() { + assertEquals(Duration.ofDays(1).toMillis(), DEPRECATED_DEFAULT_24_HR_GRACE_PERIOD); + assertEquals(DEPRECATED_DEFAULT_24_HR_GRACE_PERIOD - 3L, TimeWindows.of(ofMillis(3L)).gracePeriodMs()); + assertEquals(0L, TimeWindows.of(ofMillis(DEPRECATED_DEFAULT_24_HR_GRACE_PERIOD)).gracePeriodMs()); + assertEquals(0L, TimeWindows.of(ofMillis(DEPRECATED_DEFAULT_24_HR_GRACE_PERIOD + 1L)).gracePeriodMs()); + } + + @Test + public void shouldComputeWindowsForHoppingWindows() { + final TimeWindows windows = TimeWindows.ofSizeWithNoGrace(ofMillis(12L)).advanceBy(ofMillis(5L)); + final Map matched = windows.windowsFor(21L); + assertEquals(12L / 5L + 1, matched.size()); + assertEquals(new TimeWindow(10L, 22L), matched.get(10L)); + assertEquals(new TimeWindow(15L, 27L), matched.get(15L)); + assertEquals(new TimeWindow(20L, 32L), matched.get(20L)); + } + + @Test + public void shouldComputeWindowsForBarelyOverlappingHoppingWindows() { + final TimeWindows windows = TimeWindows.ofSizeWithNoGrace(ofMillis(6L)).advanceBy(ofMillis(5L)); + final Map matched = windows.windowsFor(7L); + assertEquals(1, matched.size()); + assertEquals(new TimeWindow(5L, 11L), matched.get(5L)); + } + + @Test + public void shouldComputeWindowsForTumblingWindows() { + final TimeWindows windows = TimeWindows.ofSizeWithNoGrace(ofMillis(12L)); + final Map matched = windows.windowsFor(21L); + assertEquals(1, matched.size()); + assertEquals(new TimeWindow(12L, 24L), matched.get(12L)); + } + + + @Test + public void equalsAndHashcodeShouldBeValidForPositiveCases() { + verifyEquality(TimeWindows.ofSizeWithNoGrace(ofMillis(3)), TimeWindows.ofSizeWithNoGrace(ofMillis(3))); + + verifyEquality(TimeWindows.ofSizeWithNoGrace(ofMillis(3)).advanceBy(ofMillis(1)), TimeWindows.ofSizeWithNoGrace(ofMillis(3)).advanceBy(ofMillis(1))); + + verifyEquality(TimeWindows.ofSizeAndGrace(ofMillis(3), ofMillis(4)), TimeWindows.ofSizeAndGrace(ofMillis(3), ofMillis(4))); + + verifyEquality(TimeWindows.ofSizeAndGrace(ofMillis(3), ofMillis(33)), + TimeWindows.ofSizeAndGrace(ofMillis(3), ofMillis(33)) + ); + } + + @Test + public void equalsAndHashcodeShouldBeValidForNegativeCases() { + + verifyInEquality( + TimeWindows.ofSizeWithNoGrace(ofMillis(9)), + TimeWindows.ofSizeWithNoGrace(ofMillis(3)) + ); + + verifyInEquality( + TimeWindows.ofSizeAndGrace(ofMillis(9), ofMillis(9)), + TimeWindows.ofSizeAndGrace(ofMillis(3), ofMillis(9)) + ); + + verifyInEquality(TimeWindows.ofSizeWithNoGrace(ofMillis(3)).advanceBy(ofMillis(2)), TimeWindows.ofSizeWithNoGrace(ofMillis(3)).advanceBy(ofMillis(1))); + + verifyInEquality(TimeWindows.ofSizeAndGrace(ofMillis(3), ofMillis(2)), TimeWindows.ofSizeAndGrace(ofMillis(3), ofMillis(1))); + + 
verifyInEquality(TimeWindows.ofSizeAndGrace(ofMillis(3), ofMillis(9)), TimeWindows.ofSizeAndGrace(ofMillis(3), ofMillis(4))); + + verifyInEquality( + TimeWindows.ofSizeAndGrace(ofMillis(4), ofMillis(2)).advanceBy(ofMillis(2)), + TimeWindows.ofSizeAndGrace(ofMillis(3), ofMillis(2)).advanceBy(ofMillis(2)) + ); + + verifyInEquality( + TimeWindows.ofSizeAndGrace(ofMillis(3), ofMillis(2)).advanceBy(ofMillis(1)), + TimeWindows.ofSizeAndGrace(ofMillis(3), ofMillis(2)).advanceBy(ofMillis(2)) + ); + + assertNotEquals( + TimeWindows.ofSizeAndGrace(ofMillis(3), ofMillis(1)).advanceBy(ofMillis(2)), + TimeWindows.ofSizeAndGrace(ofMillis(3), ofMillis(2)).advanceBy(ofMillis(2)) + ); + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/kstream/UnlimitedWindowsTest.java b/streams/src/test/java/org/apache/kafka/streams/kstream/UnlimitedWindowsTest.java new file mode 100644 index 0000000..e04da45 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/kstream/UnlimitedWindowsTest.java @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.kstream; + +import org.apache.kafka.streams.kstream.internals.UnlimitedWindow; +import org.junit.Test; + +import java.util.Map; + +import static java.time.Instant.ofEpochMilli; +import static org.apache.kafka.streams.EqualityCheck.verifyEquality; +import static org.apache.kafka.streams.EqualityCheck.verifyInEquality; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; + +public class UnlimitedWindowsTest { + + private static final long ANY_START_TIME = 10L; + + @Test + public void shouldSetWindowStartTime() { + assertEquals(ANY_START_TIME, UnlimitedWindows.of().startOn(ofEpochMilli(ANY_START_TIME)).startMs); + } + + @Test + public void startTimeMustNotBeNegative() { + assertThrows(IllegalArgumentException.class, () -> UnlimitedWindows.of().startOn(ofEpochMilli(-1))); + } + + @Test + public void shouldIncludeRecordsThatHappenedOnWindowStart() { + final UnlimitedWindows w = UnlimitedWindows.of().startOn(ofEpochMilli(ANY_START_TIME)); + final Map matchedWindows = w.windowsFor(w.startMs); + assertEquals(1, matchedWindows.size()); + assertEquals(new UnlimitedWindow(ANY_START_TIME), matchedWindows.get(ANY_START_TIME)); + } + + @Test + public void shouldIncludeRecordsThatHappenedAfterWindowStart() { + final UnlimitedWindows w = UnlimitedWindows.of().startOn(ofEpochMilli(ANY_START_TIME)); + final long timestamp = w.startMs + 1; + final Map matchedWindows = w.windowsFor(timestamp); + assertEquals(1, matchedWindows.size()); + assertEquals(new UnlimitedWindow(ANY_START_TIME), matchedWindows.get(ANY_START_TIME)); + } + + @Test + public void shouldExcludeRecordsThatHappenedBeforeWindowStart() { + final UnlimitedWindows w = UnlimitedWindows.of().startOn(ofEpochMilli(ANY_START_TIME)); + final long timestamp = w.startMs - 1; + final Map matchedWindows = w.windowsFor(timestamp); + assertTrue(matchedWindows.isEmpty()); + } + + @Test + public void equalsAndHashcodeShouldBeValidForPositiveCases() { + verifyEquality(UnlimitedWindows.of(), UnlimitedWindows.of()); + + verifyEquality(UnlimitedWindows.of().startOn(ofEpochMilli(1)), UnlimitedWindows.of().startOn(ofEpochMilli(1))); + } + + @Test + public void equalsAndHashcodeShouldBeValidForNegativeCases() { + verifyInEquality(UnlimitedWindows.of().startOn(ofEpochMilli(9)), UnlimitedWindows.of().startOn(ofEpochMilli(1))); + } + +} diff --git a/streams/src/test/java/org/apache/kafka/streams/kstream/WindowTest.java b/streams/src/test/java/org/apache/kafka/streams/kstream/WindowTest.java new file mode 100644 index 0000000..adb4b80 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/kstream/WindowTest.java @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.kstream; + +import org.junit.Test; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotEquals; +import static org.junit.Assert.assertThrows; + +public class WindowTest { + + static class TestWindow extends Window { + TestWindow(final long startMs, final long endMs) { + super(startMs, endMs); + } + + @Override + public boolean overlap(final Window other) { + return false; + } + } + + static class TestWindow2 extends Window { + TestWindow2(final long startMs, final long endMs) { + super(startMs, endMs); + } + + @Override + public boolean overlap(final Window other) { + return false; + } + } + + private final TestWindow window = new TestWindow(5, 10); + + @Test + public void shouldThrowIfStartIsNegative() { + assertThrows(IllegalArgumentException.class, () -> new TestWindow(-1, 0)); + } + + @Test + public void shouldThrowIfEndIsSmallerThanStart() { + assertThrows(IllegalArgumentException.class, () -> new TestWindow(1, 0)); + } + + @Test + public void shouldBeEqualIfStartAndEndSame() { + final TestWindow window2 = new TestWindow(window.startMs, window.endMs); + + assertEquals(window, window); + assertEquals(window, window2); + assertEquals(window2, window); + } + + @Test + public void shouldNotBeEqualIfNull() { + assertNotEquals(window, null); + } + + @Test + public void shouldNotBeEqualIfStartOrEndIsDifferent() { + assertNotEquals(window, new TestWindow(0, window.endMs)); + assertNotEquals(window, new TestWindow(7, window.endMs)); + assertNotEquals(window, new TestWindow(window.startMs, 7)); + assertNotEquals(window, new TestWindow(window.startMs, 15)); + assertNotEquals(window, new TestWindow(7, 8)); + assertNotEquals(window, new TestWindow(0, 15)); + } + + @Test + public void shouldNotBeEqualIfDifferentWindowType() { + assertNotEquals(window, new TestWindow2(window.startMs, window.endMs)); + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/kstream/WindowedSerdesTest.java b/streams/src/test/java/org/apache/kafka/streams/kstream/WindowedSerdesTest.java new file mode 100644 index 0000000..f12d0db --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/kstream/WindowedSerdesTest.java @@ -0,0 +1,164 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.kstream; + +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.common.serialization.StringDeserializer; +import org.apache.kafka.common.serialization.StringSerializer; +import org.apache.kafka.streams.kstream.internals.SessionWindow; +import org.apache.kafka.streams.kstream.internals.TimeWindow; +import org.junit.Assert; +import org.junit.Test; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.equalTo; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; + +public class WindowedSerdesTest { + + private final String topic = "sample"; + + @Test + public void shouldWrapForTimeWindowedSerde() { + final Serde> serde = WindowedSerdes.timeWindowedSerdeFrom(String.class, Long.MAX_VALUE); + assertTrue(serde.serializer() instanceof TimeWindowedSerializer); + assertTrue(serde.deserializer() instanceof TimeWindowedDeserializer); + assertTrue(((TimeWindowedSerializer) serde.serializer()).innerSerializer() instanceof StringSerializer); + assertTrue(((TimeWindowedDeserializer) serde.deserializer()).innerDeserializer() instanceof StringDeserializer); + } + + @Test + public void shouldWrapForSessionWindowedSerde() { + final Serde> serde = WindowedSerdes.sessionWindowedSerdeFrom(String.class); + assertTrue(serde.serializer() instanceof SessionWindowedSerializer); + assertTrue(serde.deserializer() instanceof SessionWindowedDeserializer); + assertTrue(((SessionWindowedSerializer) serde.serializer()).innerSerializer() instanceof StringSerializer); + assertTrue(((SessionWindowedDeserializer) serde.deserializer()).innerDeserializer() instanceof StringDeserializer); + } + + @Test + public void testTimeWindowSerdeFrom() { + final Windowed timeWindowed = new Windowed<>(10, new TimeWindow(0, Long.MAX_VALUE)); + final Serde> timeWindowedSerde = WindowedSerdes.timeWindowedSerdeFrom(Integer.class, Long.MAX_VALUE); + final byte[] bytes = timeWindowedSerde.serializer().serialize(topic, timeWindowed); + final Windowed windowed = timeWindowedSerde.deserializer().deserialize(topic, bytes); + Assert.assertEquals(timeWindowed, windowed); + } + + @Test + public void testSessionWindowedSerdeFrom() { + final Windowed sessionWindowed = new Windowed<>(10, new SessionWindow(0, 1)); + final Serde> sessionWindowedSerde = WindowedSerdes.sessionWindowedSerdeFrom(Integer.class); + final byte[] bytes = sessionWindowedSerde.serializer().serialize(topic, sessionWindowed); + final Windowed windowed = sessionWindowedSerde.deserializer().deserialize(topic, bytes); + Assert.assertEquals(sessionWindowed, windowed); + } + + @Test + public void timeWindowedSerializerShouldThrowNpeIfNotInitializedProperly() { + final TimeWindowedSerializer serializer = new TimeWindowedSerializer<>(); + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> serializer.serialize("topic", new Windowed<>(new byte[0], new TimeWindow(0, 1)))); + assertThat( + exception.getMessage(), + equalTo("Inner serializer is `null`. 
User code must use constructor " + + "`TimeWindowedSerializer(final Serializer inner)` instead of the no-arg constructor.")); + } + + @Test + public void timeWindowedSerializerShouldThrowNpeOnSerializingBaseKeyIfNotInitializedProperly() { + final TimeWindowedSerializer serializer = new TimeWindowedSerializer<>(); + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> serializer.serializeBaseKey("topic", new Windowed<>(new byte[0], new TimeWindow(0, 1)))); + assertThat( + exception.getMessage(), + equalTo("Inner serializer is `null`. User code must use constructor " + + "`TimeWindowedSerializer(final Serializer inner)` instead of the no-arg constructor.")); + } + + @Test + public void timeWindowedDeserializerShouldThrowNpeIfNotInitializedProperly() { + final TimeWindowedDeserializer deserializer = new TimeWindowedDeserializer<>(); + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> deserializer.deserialize("topic", new byte[0])); + assertThat( + exception.getMessage(), + equalTo("Inner deserializer is `null`. User code must use constructor " + + "`TimeWindowedDeserializer(final Deserializer inner)` instead of the no-arg constructor.")); + } + + @Test + public void sessionWindowedSerializerShouldThrowNpeIfNotInitializedProperly() { + final SessionWindowedSerializer serializer = new SessionWindowedSerializer<>(); + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> serializer.serialize("topic", new Windowed<>(new byte[0], new SessionWindow(0, 0)))); + assertThat( + exception.getMessage(), + equalTo("Inner serializer is `null`. User code must use constructor " + + "`SessionWindowedSerializer(final Serializer inner)` instead of the no-arg constructor.")); + } + + @Test + public void sessionWindowedSerializerShouldThrowNpeOnSerializingBaseKeyIfNotInitializedProperly() { + final SessionWindowedSerializer serializer = new SessionWindowedSerializer<>(); + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> serializer.serializeBaseKey("topic", new Windowed<>(new byte[0], new SessionWindow(0, 0)))); + assertThat( + exception.getMessage(), + equalTo("Inner serializer is `null`. User code must use constructor " + + "`SessionWindowedSerializer(final Serializer inner)` instead of the no-arg constructor.")); + } + + @Test + public void sessionWindowedDeserializerShouldThrowNpeIfNotInitializedProperly() { + final SessionWindowedDeserializer deserializer = new SessionWindowedDeserializer<>(); + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> deserializer.deserialize("topic", new byte[0])); + assertThat( + exception.getMessage(), + equalTo("Inner deserializer is `null`. 
User code must use constructor " + + "`SessionWindowedDeserializer(final Deserializer inner)` instead of the no-arg constructor.")); + } + + @Test + public void timeWindowedSerializerShouldNotThrowOnCloseIfNotInitializedProperly() { + new TimeWindowedSerializer<>().close(); + } + + @Test + public void timeWindowedDeserializerShouldNotThrowOnCloseIfNotInitializedProperly() { + new TimeWindowedDeserializer<>().close(); + } + + @Test + public void sessionWindowedSerializerShouldNotThrowOnCloseIfNotInitializedProperly() { + new SessionWindowedSerializer<>().close(); + } + + @Test + public void sessionWindowedDeserializerShouldNotThrowOnCloseIfNotInitializedProperly() { + new SessionWindowedDeserializer<>().close(); + } + +} diff --git a/streams/src/test/java/org/apache/kafka/streams/kstream/internals/AbstractStreamTest.java b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/AbstractStreamTest.java new file mode 100644 index 0000000..ae7ef8f --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/AbstractStreamTest.java @@ -0,0 +1,134 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.kstream.internals; + +import org.apache.kafka.common.serialization.IntegerSerializer; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.serialization.StringSerializer; +import org.apache.kafka.test.NoopValueTransformer; +import org.apache.kafka.test.NoopValueTransformerWithKey; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.TestInputTopic; +import org.apache.kafka.streams.TopologyTestDriver; +import org.apache.kafka.streams.kstream.Consumed; +import org.apache.kafka.streams.kstream.KStream; +import org.apache.kafka.streams.kstream.ValueTransformerSupplier; +import org.apache.kafka.streams.kstream.ValueTransformerWithKeySupplier; +import org.apache.kafka.streams.kstream.internals.graph.ProcessorGraphNode; +import org.apache.kafka.streams.kstream.internals.graph.ProcessorParameters; +import org.apache.kafka.test.MockProcessorSupplier; +import org.junit.Test; + +import java.util.Random; + +import static org.easymock.EasyMock.createMock; +import static org.easymock.EasyMock.expect; +import static org.easymock.EasyMock.replay; +import static org.easymock.EasyMock.verify; +import static org.junit.Assert.assertTrue; + +@SuppressWarnings("deprecation") // Old PAPI. Needs to be migrated. 
+public class AbstractStreamTest { + + @Test + public void testToInternalValueTransformerSupplierSuppliesNewTransformers() { + final ValueTransformerSupplier valueTransformerSupplier = createMock(ValueTransformerSupplier.class); + expect(valueTransformerSupplier.get()).andAnswer(NoopValueTransformer::new).atLeastOnce(); + replay(valueTransformerSupplier); + final ValueTransformerWithKeySupplier valueTransformerWithKeySupplier = + AbstractStream.toValueTransformerWithKeySupplier(valueTransformerSupplier); + valueTransformerWithKeySupplier.get(); + valueTransformerWithKeySupplier.get(); + valueTransformerWithKeySupplier.get(); + verify(valueTransformerSupplier); + } + + @Test + public void testToInternalValueTransformerWithKeySupplierSuppliesNewTransformers() { + final ValueTransformerWithKeySupplier valueTransformerWithKeySupplier = + createMock(ValueTransformerWithKeySupplier.class); + expect(valueTransformerWithKeySupplier.get()).andAnswer(NoopValueTransformerWithKey::new).atLeastOnce(); + replay(valueTransformerWithKeySupplier); + valueTransformerWithKeySupplier.get(); + valueTransformerWithKeySupplier.get(); + valueTransformerWithKeySupplier.get(); + verify(valueTransformerWithKeySupplier); + } + + @SuppressWarnings("deprecation") // Old PAPI. Needs to be migrated. + @Test + public void testShouldBeExtensible() { + final StreamsBuilder builder = new StreamsBuilder(); + final int[] expectedKeys = new int[]{1, 2, 3, 4, 5, 6, 7}; + final MockProcessorSupplier supplier = new MockProcessorSupplier<>(); + final String topicName = "topic"; + + final ExtendedKStream stream = new ExtendedKStream<>(builder.stream(topicName, Consumed.with(Serdes.Integer(), Serdes.String()))); + + stream.randomFilter().process(supplier); + + try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build())) { + + final TestInputTopic inputTopic = driver.createInputTopic(topicName, new IntegerSerializer(), new StringSerializer()); + for (final int expectedKey : expectedKeys) { + inputTopic.pipeInput(expectedKey, "V" + expectedKey); + } + + assertTrue(supplier.theCapturedProcessor().processed().size() <= expectedKeys.length); + } + } + + private static class ExtendedKStream extends AbstractStream { + + ExtendedKStream(final KStream stream) { + super((KStreamImpl) stream); + } + + KStream randomFilter() { + final String name = builder.newProcessorName("RANDOM-FILTER-"); + final ProcessorGraphNode processorNode = new ProcessorGraphNode<>( + name, + new ProcessorParameters<>(new ExtendedKStreamDummy<>(), name)); + builder.addGraphNode(this.graphNode, processorNode); + return new KStreamImpl<>(name, null, null, subTopologySourceNodes, false, processorNode, builder); + } + } + + private static class ExtendedKStreamDummy implements org.apache.kafka.streams.processor.ProcessorSupplier { + + private final Random rand; + + ExtendedKStreamDummy() { + rand = new Random(); + } + + @Override + public org.apache.kafka.streams.processor.Processor get() { + return new ExtendedKStreamDummyProcessor(); + } + + private class ExtendedKStreamDummyProcessor extends org.apache.kafka.streams.processor.AbstractProcessor { + @Override + public void process(final K key, final V value) { + // flip a coin and filter + if (rand.nextBoolean()) { + context().forward(key, value); + } + } + } + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/kstream/internals/CogroupedKStreamImplTest.java b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/CogroupedKStreamImplTest.java new file mode 100644 index 
0000000..59a922c --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/CogroupedKStreamImplTest.java @@ -0,0 +1,1264 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.kstream.internals; + +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; + +import java.util.Properties; +import org.apache.kafka.common.serialization.IntegerDeserializer; +import org.apache.kafka.common.serialization.IntegerSerializer; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.serialization.StringDeserializer; +import org.apache.kafka.common.serialization.StringSerializer; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.TestInputTopic; +import org.apache.kafka.streams.TestOutputTopic; +import org.apache.kafka.streams.TopologyTestDriver; +import org.apache.kafka.streams.kstream.Aggregator; +import org.apache.kafka.streams.kstream.CogroupedKStream; +import org.apache.kafka.streams.kstream.Consumed; +import org.apache.kafka.streams.kstream.Grouped; +import org.apache.kafka.streams.kstream.Initializer; +import org.apache.kafka.streams.kstream.KGroupedStream; +import org.apache.kafka.streams.kstream.KStream; +import org.apache.kafka.streams.kstream.KTable; +import org.apache.kafka.streams.kstream.Materialized; +import org.apache.kafka.streams.kstream.Named; +import org.apache.kafka.streams.kstream.SessionWindows; +import org.apache.kafka.streams.kstream.SlidingWindows; +import org.apache.kafka.streams.kstream.Window; +import org.apache.kafka.streams.kstream.Windows; +import org.apache.kafka.streams.state.KeyValueStore; +import org.apache.kafka.streams.test.TestRecord; +import org.apache.kafka.test.MockAggregator; +import org.apache.kafka.test.MockInitializer; +import org.apache.kafka.test.MockValueJoiner; +import org.apache.kafka.test.StreamsTestUtils; +import org.junit.Before; +import org.junit.Test; + +public class CogroupedKStreamImplTest { + private final Consumed stringConsumed = Consumed.with(Serdes.String(), Serdes.String()); + private static final String TOPIC = "topic"; + private static final String OUTPUT = "output"; + private KGroupedStream groupedStream; + private CogroupedKStream cogroupedStream; + + private final Properties props = StreamsTestUtils.getStreamsConfig(Serdes.String(), Serdes.String()); + + private static final Aggregator STRING_AGGREGATOR = + (key, value, aggregate) -> aggregate + value; + + private static final Initializer 
STRING_INITIALIZER = () -> "";
+
+    private static final Aggregator<String, String, Integer> STRING_SUM_AGGREGATOR =
+        (key, value, aggregate) -> aggregate + Integer.parseInt(value);
+
+    private static final Aggregator<String, Integer, Integer> SUM_AGGREGATOR =
+        (key, value, aggregate) -> aggregate + value;
+
+    private static final Initializer<Integer> SUM_INITIALIZER = () -> 0;
+
+
+    @Before
+    public void setup() {
+        final StreamsBuilder builder = new StreamsBuilder();
+        final KStream<String, String> stream = builder.stream(TOPIC, Consumed.with(Serdes.String(), Serdes.String()));
+
+        groupedStream = stream.groupByKey(Grouped.with(Serdes.String(), Serdes.String()));
+        cogroupedStream = groupedStream.cogroup(MockAggregator.TOSTRING_ADDER);
+    }
+
+    @Test
+    public void shouldThrowNPEInCogroupIfKGroupedStreamIsNull() {
+        assertThrows(NullPointerException.class, () -> cogroupedStream.cogroup(null, MockAggregator.TOSTRING_ADDER));
+    }
+
+    @Test
+    public void shouldNotHaveNullAggregatorOnCogroup() {
+        assertThrows(NullPointerException.class, () -> cogroupedStream.cogroup(groupedStream, null));
+    }
+
+    @Test
+    public void shouldNotHaveNullInitializerOnAggregate() {
+        assertThrows(NullPointerException.class, () -> cogroupedStream.aggregate(null));
+    }
+
+    @Test
+    public void shouldNotHaveNullInitializerOnAggregateWithNamed() {
+        assertThrows(NullPointerException.class, () -> cogroupedStream.aggregate(null, Named.as("name")));
+    }
+
+    @Test
+    public void shouldNotHaveNullInitializerOnAggregateWithMaterialized() {
+        assertThrows(NullPointerException.class, () -> cogroupedStream.aggregate(null, Materialized.as("store")));
+    }
+
+    @Test
+    public void shouldNotHaveNullInitializerOnAggregateWithNamedAndMaterialized() {
+        assertThrows(NullPointerException.class, () -> cogroupedStream.aggregate(null, Named.as("name"), Materialized.as("store")));
+    }
+
+    @Test
+    public void shouldNotHaveNullNamedOnAggregate() {
+        assertThrows(NullPointerException.class, () -> cogroupedStream.aggregate(STRING_INITIALIZER, (Named) null));
+    }
+
+    @Test
+    public void shouldNotHaveNullMaterializedOnAggregate() {
+        assertThrows(NullPointerException.class, () -> cogroupedStream.aggregate(STRING_INITIALIZER, (Materialized<String, String, KeyValueStore<Bytes, byte[]>>) null));
+    }
+
+    @Test
+    public void shouldNotHaveNullNamedOnAggregateWithMaterialized() {
+        assertThrows(NullPointerException.class, () -> cogroupedStream.aggregate(STRING_INITIALIZER, null, Materialized.as("store")));
+    }
+
+    @Test
+    public void shouldNotHaveNullMaterializedOnAggregateWithNames() {
+        assertThrows(NullPointerException.class, () -> cogroupedStream.aggregate(STRING_INITIALIZER, Named.as("name"), null));
+    }
+
+    @Test
+    public void shouldNotHaveNullWindowOnWindowedByTime() {
+        assertThrows(NullPointerException.class, () -> cogroupedStream.windowedBy((Windows<? extends Window>) null));
+    }
+
+    @Test
+    public void shouldNotHaveNullWindowOnWindowedBySession() {
+        assertThrows(NullPointerException.class, () -> cogroupedStream.windowedBy((SessionWindows) null));
+    }
+
+    @Test
+    public void shouldNotHaveNullWindowOnWindowedBySliding() {
+        assertThrows(NullPointerException.class, () -> cogroupedStream.windowedBy((SlidingWindows) null));
+    }
+
+    @Test
+    public void shouldNameProcessorsAndStoreBasedOnNamedParameter() {
+        final StreamsBuilder builder = new StreamsBuilder();
+
+        final KStream<String, String> stream1 = builder.stream("one", stringConsumed);
+        final KStream<String, String> test2 = builder.stream("two", stringConsumed);
+
+        final KGroupedStream<String, String> groupedOne = stream1.groupByKey();
+        final KGroupedStream<String, String> groupedTwo = test2.groupByKey();
+
+        final KTable<String, String> customers = groupedOne
+            .cogroup(STRING_AGGREGATOR) +
.cogroup(groupedTwo, STRING_AGGREGATOR) + .aggregate(STRING_INITIALIZER, Named.as("test"), Materialized.as("store")); + + customers.toStream().to(OUTPUT); + + final String topologyDescription = builder.build().describe().toString(); + + assertThat( + topologyDescription, + equalTo("Topologies:\n" + + " Sub-topology: 0\n" + + " Source: KSTREAM-SOURCE-0000000000 (topics: [one])\n" + + " --> test-cogroup-agg-0\n" + + " Source: KSTREAM-SOURCE-0000000001 (topics: [two])\n" + + " --> test-cogroup-agg-1\n" + + " Processor: test-cogroup-agg-0 (stores: [store])\n" + + " --> test-cogroup-merge\n" + + " <-- KSTREAM-SOURCE-0000000000\n" + + " Processor: test-cogroup-agg-1 (stores: [store])\n" + + " --> test-cogroup-merge\n" + + " <-- KSTREAM-SOURCE-0000000001\n" + + " Processor: test-cogroup-merge (stores: [])\n" + + " --> KTABLE-TOSTREAM-0000000005\n" + + " <-- test-cogroup-agg-0, test-cogroup-agg-1\n" + + " Processor: KTABLE-TOSTREAM-0000000005 (stores: [])\n" + + " --> KSTREAM-SINK-0000000006\n" + + " <-- test-cogroup-merge\n" + + " Sink: KSTREAM-SINK-0000000006 (topic: output)\n" + + " <-- KTABLE-TOSTREAM-0000000005\n\n")); + } + + @Test + public void shouldNameRepartitionTopic() { + final StreamsBuilder builder = new StreamsBuilder(); + + final KStream stream1 = builder.stream("one", stringConsumed); + final KStream test2 = builder.stream("two", stringConsumed); + + final KGroupedStream groupedOne = stream1.map((k, v) -> new KeyValue<>(v, k)).groupByKey(Grouped.as("repartition-test")); + final KGroupedStream groupedTwo = test2.groupByKey(); + + final KTable customers = groupedOne + .cogroup(STRING_AGGREGATOR) + .cogroup(groupedTwo, STRING_AGGREGATOR) + .aggregate(STRING_INITIALIZER); + + customers.toStream().to(OUTPUT); + + final String topologyDescription = builder.build().describe().toString(); + + assertThat( + topologyDescription, + equalTo("Topologies:\n" + + " Sub-topology: 0\n" + + " Source: KSTREAM-SOURCE-0000000000 (topics: [one])\n" + + " --> KSTREAM-MAP-0000000002\n" + + " Processor: KSTREAM-MAP-0000000002 (stores: [])\n" + + " --> repartition-test-repartition-filter\n" + + " <-- KSTREAM-SOURCE-0000000000\n" + + " Processor: repartition-test-repartition-filter (stores: [])\n" + + " --> repartition-test-repartition-sink\n" + + " <-- KSTREAM-MAP-0000000002\n" + + " Sink: repartition-test-repartition-sink (topic: repartition-test-repartition)\n" + + " <-- repartition-test-repartition-filter\n\n" + + " Sub-topology: 1\n" + + " Source: KSTREAM-SOURCE-0000000001 (topics: [two])\n" + + " --> COGROUPKSTREAM-AGGREGATE-0000000008\n" + + " Source: repartition-test-repartition-source (topics: [repartition-test-repartition])\n" + + " --> COGROUPKSTREAM-AGGREGATE-0000000007\n" + + " Processor: COGROUPKSTREAM-AGGREGATE-0000000007 (stores: [COGROUPKSTREAM-AGGREGATE-STATE-STORE-0000000003])\n" + + " --> COGROUPKSTREAM-MERGE-0000000009\n" + + " <-- repartition-test-repartition-source\n" + + " Processor: COGROUPKSTREAM-AGGREGATE-0000000008 (stores: [COGROUPKSTREAM-AGGREGATE-STATE-STORE-0000000003])\n" + + " --> COGROUPKSTREAM-MERGE-0000000009\n" + + " <-- KSTREAM-SOURCE-0000000001\n" + + " Processor: COGROUPKSTREAM-MERGE-0000000009 (stores: [])\n" + + " --> KTABLE-TOSTREAM-0000000010\n" + + " <-- COGROUPKSTREAM-AGGREGATE-0000000007, COGROUPKSTREAM-AGGREGATE-0000000008\n" + + " Processor: KTABLE-TOSTREAM-0000000010 (stores: [])\n" + + " --> KSTREAM-SINK-0000000011\n" + + " <-- COGROUPKSTREAM-MERGE-0000000009\n" + + " Sink: KSTREAM-SINK-0000000011 (topic: output)\n" + + " <-- 
KTABLE-TOSTREAM-0000000010\n\n")); + } + + @Test + public void shouldInsertRepartitionsTopicForUpstreamKeyModification() { + final StreamsBuilder builder = new StreamsBuilder(); + final KStream stream1 = builder.stream("one", stringConsumed); + final KStream test2 = builder.stream("two", stringConsumed); + + final KGroupedStream groupedOne = stream1.map((k, v) -> new KeyValue<>(v, k)).groupByKey(); + final KGroupedStream groupedTwo = test2.groupByKey(); + + final KTable customers = groupedOne + .cogroup(STRING_AGGREGATOR) + .cogroup(groupedTwo, STRING_AGGREGATOR) + .aggregate(STRING_INITIALIZER, Named.as("test"), Materialized.as("store")); + + customers.toStream().to(OUTPUT); + + final String topologyDescription = builder.build().describe().toString(); + + assertThat( + topologyDescription, + equalTo("Topologies:\n" + + " Sub-topology: 0\n" + + " Source: KSTREAM-SOURCE-0000000000 (topics: [one])\n" + + " --> KSTREAM-MAP-0000000002\n" + + " Processor: KSTREAM-MAP-0000000002 (stores: [])\n" + + " --> store-repartition-filter\n" + + " <-- KSTREAM-SOURCE-0000000000\n" + + " Processor: store-repartition-filter (stores: [])\n" + + " --> store-repartition-sink\n" + + " <-- KSTREAM-MAP-0000000002\n" + + " Sink: store-repartition-sink (topic: store-repartition)\n" + + " <-- store-repartition-filter\n\n" + + " Sub-topology: 1\n" + + " Source: KSTREAM-SOURCE-0000000001 (topics: [two])\n" + + " --> test-cogroup-agg-1\n" + + " Source: store-repartition-source (topics: [store-repartition])\n" + + " --> test-cogroup-agg-0\n" + + " Processor: test-cogroup-agg-0 (stores: [store])\n" + + " --> test-cogroup-merge\n" + + " <-- store-repartition-source\n" + + " Processor: test-cogroup-agg-1 (stores: [store])\n" + + " --> test-cogroup-merge\n" + + " <-- KSTREAM-SOURCE-0000000001\n" + + " Processor: test-cogroup-merge (stores: [])\n" + + " --> KTABLE-TOSTREAM-0000000009\n" + + " <-- test-cogroup-agg-0, test-cogroup-agg-1\n" + + " Processor: KTABLE-TOSTREAM-0000000009 (stores: [])\n" + + " --> KSTREAM-SINK-0000000010\n" + + " <-- test-cogroup-merge\n" + + " Sink: KSTREAM-SINK-0000000010 (topic: output)\n" + + " <-- KTABLE-TOSTREAM-0000000009\n\n")); + } + + @Test + public void shouldInsertRepartitionsTopicForUpstreamKeyModificationWithGroupedReusedInSameCogroups() { + final StreamsBuilder builder = new StreamsBuilder(); + + final KStream stream1 = builder.stream("one", stringConsumed); + final KStream stream2 = builder.stream("two", stringConsumed); + + final KGroupedStream groupedOne = stream1.map((k, v) -> new KeyValue<>(v, k)).groupByKey(); + final KGroupedStream groupedTwo = stream2.groupByKey(); + + final KTable cogroupedTwo = groupedOne + .cogroup(STRING_AGGREGATOR) + .cogroup(groupedTwo, STRING_AGGREGATOR) + .aggregate(STRING_INITIALIZER); + + final KTable cogroupedOne = groupedOne + .cogroup(STRING_AGGREGATOR) + .cogroup(groupedTwo, STRING_AGGREGATOR) + .aggregate(STRING_INITIALIZER); + + cogroupedOne.toStream().to(OUTPUT); + cogroupedTwo.toStream().to("OUTPUT2"); + + final String topologyDescription = builder.build().describe().toString(); + + assertThat( + topologyDescription, + equalTo("Topologies:\n" + + " Sub-topology: 0\n" + + " Source: KSTREAM-SOURCE-0000000000 (topics: [one])\n" + + " --> KSTREAM-MAP-0000000002\n" + + " Processor: KSTREAM-MAP-0000000002 (stores: [])\n" + + " --> COGROUPKSTREAM-AGGREGATE-STATE-STORE-0000000010-repartition-filter, COGROUPKSTREAM-AGGREGATE-STATE-STORE-0000000003-repartition-filter\n" + + " <-- KSTREAM-SOURCE-0000000000\n" + + " Processor: 
COGROUPKSTREAM-AGGREGATE-STATE-STORE-0000000003-repartition-filter (stores: [])\n" + + " --> COGROUPKSTREAM-AGGREGATE-STATE-STORE-0000000003-repartition-sink\n" + + " <-- KSTREAM-MAP-0000000002\n" + + " Processor: COGROUPKSTREAM-AGGREGATE-STATE-STORE-0000000010-repartition-filter (stores: [])\n" + + " --> COGROUPKSTREAM-AGGREGATE-STATE-STORE-0000000010-repartition-sink\n" + + " <-- KSTREAM-MAP-0000000002\n" + + " Sink: COGROUPKSTREAM-AGGREGATE-STATE-STORE-0000000003-repartition-sink (topic: COGROUPKSTREAM-AGGREGATE-STATE-STORE-0000000003-repartition)\n" + + " <-- COGROUPKSTREAM-AGGREGATE-STATE-STORE-0000000003-repartition-filter\n" + + " Sink: COGROUPKSTREAM-AGGREGATE-STATE-STORE-0000000010-repartition-sink (topic: COGROUPKSTREAM-AGGREGATE-STATE-STORE-0000000010-repartition)\n" + + " <-- COGROUPKSTREAM-AGGREGATE-STATE-STORE-0000000010-repartition-filter\n\n" + + " Sub-topology: 1\n" + + " Source: KSTREAM-SOURCE-0000000001 (topics: [two])\n" + + " --> COGROUPKSTREAM-AGGREGATE-0000000008, COGROUPKSTREAM-AGGREGATE-0000000015\n" + + " Source: COGROUPKSTREAM-AGGREGATE-STATE-STORE-0000000003-repartition-source (topics: [COGROUPKSTREAM-AGGREGATE-STATE-STORE-0000000003-repartition])\n" + + " --> COGROUPKSTREAM-AGGREGATE-0000000007\n" + + " Source: COGROUPKSTREAM-AGGREGATE-STATE-STORE-0000000010-repartition-source (topics: [COGROUPKSTREAM-AGGREGATE-STATE-STORE-0000000010-repartition])\n" + + " --> COGROUPKSTREAM-AGGREGATE-0000000014\n" + + " Processor: COGROUPKSTREAM-AGGREGATE-0000000007 (stores: [COGROUPKSTREAM-AGGREGATE-STATE-STORE-0000000003])\n" + + " --> COGROUPKSTREAM-MERGE-0000000009\n" + + " <-- COGROUPKSTREAM-AGGREGATE-STATE-STORE-0000000003-repartition-source\n" + + " Processor: COGROUPKSTREAM-AGGREGATE-0000000008 (stores: [COGROUPKSTREAM-AGGREGATE-STATE-STORE-0000000003])\n" + + " --> COGROUPKSTREAM-MERGE-0000000009\n" + + " <-- KSTREAM-SOURCE-0000000001\n" + + " Processor: COGROUPKSTREAM-AGGREGATE-0000000014 (stores: [COGROUPKSTREAM-AGGREGATE-STATE-STORE-0000000010])\n" + + " --> COGROUPKSTREAM-MERGE-0000000016\n" + + " <-- COGROUPKSTREAM-AGGREGATE-STATE-STORE-0000000010-repartition-source\n" + + " Processor: COGROUPKSTREAM-AGGREGATE-0000000015 (stores: [COGROUPKSTREAM-AGGREGATE-STATE-STORE-0000000010])\n" + + " --> COGROUPKSTREAM-MERGE-0000000016\n" + + " <-- KSTREAM-SOURCE-0000000001\n" + + " Processor: COGROUPKSTREAM-MERGE-0000000009 (stores: [])\n" + + " --> KTABLE-TOSTREAM-0000000019\n" + + " <-- COGROUPKSTREAM-AGGREGATE-0000000007, COGROUPKSTREAM-AGGREGATE-0000000008\n" + + " Processor: COGROUPKSTREAM-MERGE-0000000016 (stores: [])\n" + + " --> KTABLE-TOSTREAM-0000000017\n" + + " <-- COGROUPKSTREAM-AGGREGATE-0000000014, COGROUPKSTREAM-AGGREGATE-0000000015\n" + + " Processor: KTABLE-TOSTREAM-0000000017 (stores: [])\n" + + " --> KSTREAM-SINK-0000000018\n" + + " <-- COGROUPKSTREAM-MERGE-0000000016\n" + + " Processor: KTABLE-TOSTREAM-0000000019 (stores: [])\n" + + " --> KSTREAM-SINK-0000000020\n" + + " <-- COGROUPKSTREAM-MERGE-0000000009\n" + + " Sink: KSTREAM-SINK-0000000018 (topic: output)\n" + + " <-- KTABLE-TOSTREAM-0000000017\n" + + " Sink: KSTREAM-SINK-0000000020 (topic: OUTPUT2)\n" + + " <-- KTABLE-TOSTREAM-0000000019\n\n")); + } + + @Test + public void shouldInsertRepartitionsTopicForUpstreamKeyModificationWithGroupedReusedInSameCogroupsWithOptimization() { + final Properties properties = new Properties(); + properties.setProperty(StreamsConfig.TOPOLOGY_OPTIMIZATION_CONFIG, StreamsConfig.OPTIMIZE); + final StreamsBuilder builder = new StreamsBuilder(); + + final KStream stream1 
= builder.stream("one", stringConsumed); + final KStream stream2 = builder.stream("two", stringConsumed); + + final KGroupedStream groupedOne = stream1.map((k, v) -> new KeyValue<>(v, k)).groupByKey(); + final KGroupedStream groupedTwo = stream2.groupByKey(); + + final KTable cogroupedTwo = groupedOne + .cogroup(STRING_AGGREGATOR) + .cogroup(groupedTwo, STRING_AGGREGATOR) + .aggregate(STRING_INITIALIZER); + + final KTable cogroupedOne = groupedOne + .cogroup(STRING_AGGREGATOR) + .cogroup(groupedTwo, STRING_AGGREGATOR) + .aggregate(STRING_INITIALIZER); + + cogroupedOne.toStream().to(OUTPUT); + cogroupedTwo.toStream().to("OUTPUT2"); + + final String topologyDescription = builder.build(properties).describe().toString(); + + assertThat( + topologyDescription, + equalTo("Topologies:\n" + + " Sub-topology: 0\n" + + " Source: KSTREAM-SOURCE-0000000000 (topics: [one])\n" + + " --> KSTREAM-MAP-0000000002\n" + + " Processor: KSTREAM-MAP-0000000002 (stores: [])\n" + + " --> COGROUPKSTREAM-AGGREGATE-STATE-STORE-0000000003-repartition-filter\n" + + " <-- KSTREAM-SOURCE-0000000000\n" + + " Processor: COGROUPKSTREAM-AGGREGATE-STATE-STORE-0000000003-repartition-filter (stores: [])\n" + + " --> COGROUPKSTREAM-AGGREGATE-STATE-STORE-0000000003-repartition-sink\n" + + " <-- KSTREAM-MAP-0000000002\n" + + " Sink: COGROUPKSTREAM-AGGREGATE-STATE-STORE-0000000003-repartition-sink (topic: COGROUPKSTREAM-AGGREGATE-STATE-STORE-0000000003-repartition)\n" + + " <-- COGROUPKSTREAM-AGGREGATE-STATE-STORE-0000000003-repartition-filter\n\n" + + " Sub-topology: 1\n" + + " Source: COGROUPKSTREAM-AGGREGATE-STATE-STORE-0000000003-repartition-source (topics: [COGROUPKSTREAM-AGGREGATE-STATE-STORE-0000000003-repartition])\n" + + " --> COGROUPKSTREAM-AGGREGATE-0000000014, COGROUPKSTREAM-AGGREGATE-0000000007\n" + + " Source: KSTREAM-SOURCE-0000000001 (topics: [two])\n" + + " --> COGROUPKSTREAM-AGGREGATE-0000000015, COGROUPKSTREAM-AGGREGATE-0000000008\n" + + " Processor: COGROUPKSTREAM-AGGREGATE-0000000007 (stores: [COGROUPKSTREAM-AGGREGATE-STATE-STORE-0000000003])\n" + + " --> COGROUPKSTREAM-MERGE-0000000009\n" + + " <-- COGROUPKSTREAM-AGGREGATE-STATE-STORE-0000000003-repartition-source\n" + + " Processor: COGROUPKSTREAM-AGGREGATE-0000000008 (stores: [COGROUPKSTREAM-AGGREGATE-STATE-STORE-0000000003])\n" + + " --> COGROUPKSTREAM-MERGE-0000000009\n" + + " <-- KSTREAM-SOURCE-0000000001\n" + + " Processor: COGROUPKSTREAM-AGGREGATE-0000000014 (stores: [COGROUPKSTREAM-AGGREGATE-STATE-STORE-0000000010])\n" + + " --> COGROUPKSTREAM-MERGE-0000000016\n" + + " <-- COGROUPKSTREAM-AGGREGATE-STATE-STORE-0000000003-repartition-source\n" + + " Processor: COGROUPKSTREAM-AGGREGATE-0000000015 (stores: [COGROUPKSTREAM-AGGREGATE-STATE-STORE-0000000010])\n" + + " --> COGROUPKSTREAM-MERGE-0000000016\n" + + " <-- KSTREAM-SOURCE-0000000001\n" + + " Processor: COGROUPKSTREAM-MERGE-0000000009 (stores: [])\n" + + " --> KTABLE-TOSTREAM-0000000019\n" + + " <-- COGROUPKSTREAM-AGGREGATE-0000000007, COGROUPKSTREAM-AGGREGATE-0000000008\n" + + " Processor: COGROUPKSTREAM-MERGE-0000000016 (stores: [])\n" + + " --> KTABLE-TOSTREAM-0000000017\n" + + " <-- COGROUPKSTREAM-AGGREGATE-0000000014, COGROUPKSTREAM-AGGREGATE-0000000015\n" + + " Processor: KTABLE-TOSTREAM-0000000017 (stores: [])\n" + + " --> KSTREAM-SINK-0000000018\n" + + " <-- COGROUPKSTREAM-MERGE-0000000016\n" + + " Processor: KTABLE-TOSTREAM-0000000019 (stores: [])\n" + + " --> KSTREAM-SINK-0000000020\n" + + " <-- COGROUPKSTREAM-MERGE-0000000009\n" + + " Sink: KSTREAM-SINK-0000000018 (topic: output)\n" + + " 
<-- KTABLE-TOSTREAM-0000000017\n" + + " Sink: KSTREAM-SINK-0000000020 (topic: OUTPUT2)\n" + + " <-- KTABLE-TOSTREAM-0000000019\n\n")); + } + + @Test + public void shouldInsertRepartitionsTopicForUpstreamKeyModificationWithGroupedReusedInDifferentCogroups() { + final StreamsBuilder builder = new StreamsBuilder(); + + final KStream stream1 = builder.stream("one", stringConsumed); + final KStream stream2 = builder.stream("two", stringConsumed); + final KStream stream3 = builder.stream("three", stringConsumed); + + final KGroupedStream groupedOne = stream1.map((k, v) -> new KeyValue<>(v, k)).groupByKey(); + final KGroupedStream groupedTwo = stream2.groupByKey(); + final KGroupedStream groupedThree = stream3.groupByKey(); + + groupedOne.cogroup(STRING_AGGREGATOR) + .cogroup(groupedThree, STRING_AGGREGATOR) + .aggregate(STRING_INITIALIZER); + + groupedOne.cogroup(STRING_AGGREGATOR) + .cogroup(groupedTwo, STRING_AGGREGATOR) + .aggregate(STRING_INITIALIZER); + + final String topologyDescription = builder.build().describe().toString(); + + assertThat( + topologyDescription, + equalTo("Topologies:\n" + + " Sub-topology: 0\n" + + " Source: KSTREAM-SOURCE-0000000000 (topics: [one])\n" + + " --> KSTREAM-MAP-0000000003\n" + + " Processor: KSTREAM-MAP-0000000003 (stores: [])\n" + + " --> COGROUPKSTREAM-AGGREGATE-STATE-STORE-0000000004-repartition-filter, COGROUPKSTREAM-AGGREGATE-STATE-STORE-0000000011-repartition-filter\n" + + " <-- KSTREAM-SOURCE-0000000000\n" + + " Processor: COGROUPKSTREAM-AGGREGATE-STATE-STORE-0000000004-repartition-filter (stores: [])\n" + + " --> COGROUPKSTREAM-AGGREGATE-STATE-STORE-0000000004-repartition-sink\n" + + " <-- KSTREAM-MAP-0000000003\n" + + " Processor: COGROUPKSTREAM-AGGREGATE-STATE-STORE-0000000011-repartition-filter (stores: [])\n" + + " --> COGROUPKSTREAM-AGGREGATE-STATE-STORE-0000000011-repartition-sink\n" + + " <-- KSTREAM-MAP-0000000003\n" + + " Sink: COGROUPKSTREAM-AGGREGATE-STATE-STORE-0000000004-repartition-sink (topic: COGROUPKSTREAM-AGGREGATE-STATE-STORE-0000000004-repartition)\n" + + " <-- COGROUPKSTREAM-AGGREGATE-STATE-STORE-0000000004-repartition-filter\n" + + " Sink: COGROUPKSTREAM-AGGREGATE-STATE-STORE-0000000011-repartition-sink (topic: COGROUPKSTREAM-AGGREGATE-STATE-STORE-0000000011-repartition)\n" + + " <-- COGROUPKSTREAM-AGGREGATE-STATE-STORE-0000000011-repartition-filter\n\n" + + " Sub-topology: 1\n" + + " Source: COGROUPKSTREAM-AGGREGATE-STATE-STORE-0000000011-repartition-source (topics: [COGROUPKSTREAM-AGGREGATE-STATE-STORE-0000000011-repartition])\n" + + " --> COGROUPKSTREAM-AGGREGATE-0000000015\n" + + " Source: KSTREAM-SOURCE-0000000001 (topics: [two])\n" + + " --> COGROUPKSTREAM-AGGREGATE-0000000016\n" + + " Processor: COGROUPKSTREAM-AGGREGATE-0000000015 (stores: [COGROUPKSTREAM-AGGREGATE-STATE-STORE-0000000011])\n" + + " --> COGROUPKSTREAM-MERGE-0000000017\n" + + " <-- COGROUPKSTREAM-AGGREGATE-STATE-STORE-0000000011-repartition-source\n" + + " Processor: COGROUPKSTREAM-AGGREGATE-0000000016 (stores: [COGROUPKSTREAM-AGGREGATE-STATE-STORE-0000000011])\n" + + " --> COGROUPKSTREAM-MERGE-0000000017\n" + + " <-- KSTREAM-SOURCE-0000000001\n" + + " Processor: COGROUPKSTREAM-MERGE-0000000017 (stores: [])\n" + + " --> none\n" + + " <-- COGROUPKSTREAM-AGGREGATE-0000000015, COGROUPKSTREAM-AGGREGATE-0000000016\n\n" + + " Sub-topology: 2\n" + + " Source: COGROUPKSTREAM-AGGREGATE-STATE-STORE-0000000004-repartition-source (topics: [COGROUPKSTREAM-AGGREGATE-STATE-STORE-0000000004-repartition])\n" + + " --> COGROUPKSTREAM-AGGREGATE-0000000008\n" + + " Source: 
KSTREAM-SOURCE-0000000002 (topics: [three])\n" + + " --> COGROUPKSTREAM-AGGREGATE-0000000009\n" + + " Processor: COGROUPKSTREAM-AGGREGATE-0000000008 (stores: [COGROUPKSTREAM-AGGREGATE-STATE-STORE-0000000004])\n" + + " --> COGROUPKSTREAM-MERGE-0000000010\n" + + " <-- COGROUPKSTREAM-AGGREGATE-STATE-STORE-0000000004-repartition-source\n" + + " Processor: COGROUPKSTREAM-AGGREGATE-0000000009 (stores: [COGROUPKSTREAM-AGGREGATE-STATE-STORE-0000000004])\n" + + " --> COGROUPKSTREAM-MERGE-0000000010\n" + + " <-- KSTREAM-SOURCE-0000000002\n" + + " Processor: COGROUPKSTREAM-MERGE-0000000010 (stores: [])\n" + + " --> none\n" + + " <-- COGROUPKSTREAM-AGGREGATE-0000000008, COGROUPKSTREAM-AGGREGATE-0000000009\n\n")); + } + + @Test + public void shouldInsertRepartitionsTopicForUpstreamKeyModificationWithGroupedReusedInDifferentCogroupsWithOptimization() { + final StreamsBuilder builder = new StreamsBuilder(); + + final Properties properties = new Properties(); + properties.setProperty(StreamsConfig.TOPOLOGY_OPTIMIZATION_CONFIG, StreamsConfig.OPTIMIZE); + + final KStream stream1 = builder.stream("one", stringConsumed); + final KStream stream2 = builder.stream("two", stringConsumed); + final KStream stream3 = builder.stream("three", stringConsumed); + + final KGroupedStream groupedOne = stream1.map((k, v) -> new KeyValue<>(v, k)).groupByKey(); + final KGroupedStream groupedTwo = stream2.groupByKey(); + final KGroupedStream groupedThree = stream3.groupByKey(); + + groupedOne.cogroup(STRING_AGGREGATOR) + .cogroup(groupedThree, STRING_AGGREGATOR) + .aggregate(STRING_INITIALIZER); + + groupedOne.cogroup(STRING_AGGREGATOR) + .cogroup(groupedTwo, STRING_AGGREGATOR) + .aggregate(STRING_INITIALIZER); + + final String topologyDescription = builder.build(properties).describe().toString(); + + assertThat( + topologyDescription, + equalTo("Topologies:\n" + + " Sub-topology: 0\n" + + " Source: KSTREAM-SOURCE-0000000000 (topics: [one])\n" + + " --> KSTREAM-MAP-0000000003\n" + + " Processor: KSTREAM-MAP-0000000003 (stores: [])\n" + + " --> COGROUPKSTREAM-AGGREGATE-STATE-STORE-0000000004-repartition-filter\n" + + " <-- KSTREAM-SOURCE-0000000000\n" + + " Processor: COGROUPKSTREAM-AGGREGATE-STATE-STORE-0000000004-repartition-filter (stores: [])\n" + + " --> COGROUPKSTREAM-AGGREGATE-STATE-STORE-0000000004-repartition-sink\n" + + " <-- KSTREAM-MAP-0000000003\n" + + " Sink: COGROUPKSTREAM-AGGREGATE-STATE-STORE-0000000004-repartition-sink (topic: COGROUPKSTREAM-AGGREGATE-STATE-STORE-0000000004-repartition)\n" + + " <-- COGROUPKSTREAM-AGGREGATE-STATE-STORE-0000000004-repartition-filter\n\n" + + " Sub-topology: 1\n" + + " Source: COGROUPKSTREAM-AGGREGATE-STATE-STORE-0000000004-repartition-source (topics: [COGROUPKSTREAM-AGGREGATE-STATE-STORE-0000000004-repartition])\n" + + " --> COGROUPKSTREAM-AGGREGATE-0000000008, COGROUPKSTREAM-AGGREGATE-0000000015\n" + + " Source: KSTREAM-SOURCE-0000000001 (topics: [two])\n" + + " --> COGROUPKSTREAM-AGGREGATE-0000000016\n" + + " Source: KSTREAM-SOURCE-0000000002 (topics: [three])\n" + + " --> COGROUPKSTREAM-AGGREGATE-0000000009\n" + + " Processor: COGROUPKSTREAM-AGGREGATE-0000000008 (stores: [COGROUPKSTREAM-AGGREGATE-STATE-STORE-0000000004])\n" + + " --> COGROUPKSTREAM-MERGE-0000000010\n" + + " <-- COGROUPKSTREAM-AGGREGATE-STATE-STORE-0000000004-repartition-source\n" + + " Processor: COGROUPKSTREAM-AGGREGATE-0000000009 (stores: [COGROUPKSTREAM-AGGREGATE-STATE-STORE-0000000004])\n" + + " --> COGROUPKSTREAM-MERGE-0000000010\n" + + " <-- KSTREAM-SOURCE-0000000002\n" + + " Processor: 
COGROUPKSTREAM-AGGREGATE-0000000015 (stores: [COGROUPKSTREAM-AGGREGATE-STATE-STORE-0000000011])\n" + + " --> COGROUPKSTREAM-MERGE-0000000017\n" + + " <-- COGROUPKSTREAM-AGGREGATE-STATE-STORE-0000000004-repartition-source\n" + + " Processor: COGROUPKSTREAM-AGGREGATE-0000000016 (stores: [COGROUPKSTREAM-AGGREGATE-STATE-STORE-0000000011])\n" + + " --> COGROUPKSTREAM-MERGE-0000000017\n" + + " <-- KSTREAM-SOURCE-0000000001\n" + + " Processor: COGROUPKSTREAM-MERGE-0000000010 (stores: [])\n" + + " --> none\n" + + " <-- COGROUPKSTREAM-AGGREGATE-0000000008, COGROUPKSTREAM-AGGREGATE-0000000009\n" + + " Processor: COGROUPKSTREAM-MERGE-0000000017 (stores: [])\n" + + " --> none\n" + + " <-- COGROUPKSTREAM-AGGREGATE-0000000015, COGROUPKSTREAM-AGGREGATE-0000000016\n\n")); + } + + @Test + public void shouldInsertRepartitionsTopicForUpstreamKeyModificationWithGroupedReused() { + final StreamsBuilder builder = new StreamsBuilder(); + + final KStream stream1 = builder.stream("one", stringConsumed); + final KStream stream2 = builder.stream("two", stringConsumed); + + final KGroupedStream groupedOne = stream1.map((k, v) -> new KeyValue<>(v, k)).groupByKey(); + final KGroupedStream groupedTwo = stream2.groupByKey(); + + groupedOne.cogroup(STRING_AGGREGATOR) + .cogroup(groupedTwo, STRING_AGGREGATOR) + .aggregate(STRING_INITIALIZER); + + groupedOne.aggregate(STRING_INITIALIZER, STRING_AGGREGATOR); + + final String topologyDescription = builder.build().describe().toString(); + + assertThat( + topologyDescription, + equalTo("Topologies:\n" + + " Sub-topology: 0\n" + + " Source: KSTREAM-SOURCE-0000000000 (topics: [one])\n" + + " --> KSTREAM-MAP-0000000002\n" + + " Processor: KSTREAM-MAP-0000000002 (stores: [])\n" + + " --> COGROUPKSTREAM-AGGREGATE-STATE-STORE-0000000003-repartition-filter, KSTREAM-FILTER-0000000013\n" + + " <-- KSTREAM-SOURCE-0000000000\n" + + " Processor: COGROUPKSTREAM-AGGREGATE-STATE-STORE-0000000003-repartition-filter (stores: [])\n" + + " --> COGROUPKSTREAM-AGGREGATE-STATE-STORE-0000000003-repartition-sink\n" + + " <-- KSTREAM-MAP-0000000002\n" + + " Processor: KSTREAM-FILTER-0000000013 (stores: [])\n" + + " --> KSTREAM-SINK-0000000012\n" + + " <-- KSTREAM-MAP-0000000002\n" + + " Sink: COGROUPKSTREAM-AGGREGATE-STATE-STORE-0000000003-repartition-sink (topic: COGROUPKSTREAM-AGGREGATE-STATE-STORE-0000000003-repartition)\n" + + " <-- COGROUPKSTREAM-AGGREGATE-STATE-STORE-0000000003-repartition-filter\n" + + " Sink: KSTREAM-SINK-0000000012 (topic: KSTREAM-AGGREGATE-STATE-STORE-0000000010-repartition)\n" + + " <-- KSTREAM-FILTER-0000000013\n\n" + + " Sub-topology: 1\n" + + " Source: COGROUPKSTREAM-AGGREGATE-STATE-STORE-0000000003-repartition-source (topics: [COGROUPKSTREAM-AGGREGATE-STATE-STORE-0000000003-repartition])\n" + + " --> COGROUPKSTREAM-AGGREGATE-0000000007\n" + + " Source: KSTREAM-SOURCE-0000000001 (topics: [two])\n" + + " --> COGROUPKSTREAM-AGGREGATE-0000000008\n" + + " Processor: COGROUPKSTREAM-AGGREGATE-0000000007 (stores: [COGROUPKSTREAM-AGGREGATE-STATE-STORE-0000000003])\n" + + " --> COGROUPKSTREAM-MERGE-0000000009\n" + + " <-- COGROUPKSTREAM-AGGREGATE-STATE-STORE-0000000003-repartition-source\n" + + " Processor: COGROUPKSTREAM-AGGREGATE-0000000008 (stores: [COGROUPKSTREAM-AGGREGATE-STATE-STORE-0000000003])\n" + + " --> COGROUPKSTREAM-MERGE-0000000009\n" + + " <-- KSTREAM-SOURCE-0000000001\n" + + " Processor: COGROUPKSTREAM-MERGE-0000000009 (stores: [])\n" + + " --> none\n" + + " <-- COGROUPKSTREAM-AGGREGATE-0000000007, COGROUPKSTREAM-AGGREGATE-0000000008\n\n" + + " Sub-topology: 2\n" + 
+ " Source: KSTREAM-SOURCE-0000000014 (topics: [KSTREAM-AGGREGATE-STATE-STORE-0000000010-repartition])\n" + + " --> KSTREAM-AGGREGATE-0000000011\n" + + " Processor: KSTREAM-AGGREGATE-0000000011 (stores: [KSTREAM-AGGREGATE-STATE-STORE-0000000010])\n" + + " --> none\n" + + " <-- KSTREAM-SOURCE-0000000014\n\n")); + } + + @Test + public void shouldInsertRepartitionsTopicForUpstreamKeyModificationWithGroupedReusedWithOptimization() { + final StreamsBuilder builder = new StreamsBuilder(); + + final Properties properties = new Properties(); + properties.setProperty(StreamsConfig.TOPOLOGY_OPTIMIZATION_CONFIG, StreamsConfig.OPTIMIZE); + + final KStream stream1 = builder.stream("one", stringConsumed); + final KStream stream2 = builder.stream("two", stringConsumed); + + final KGroupedStream groupedOne = stream1.map((k, v) -> new KeyValue<>(v, k)).groupByKey(); + final KGroupedStream groupedTwo = stream2.groupByKey(); + + groupedOne.cogroup(STRING_AGGREGATOR) + .cogroup(groupedTwo, STRING_AGGREGATOR) + .aggregate(STRING_INITIALIZER); + + groupedOne.aggregate(STRING_INITIALIZER, STRING_AGGREGATOR); + + final String topologyDescription = builder.build(properties).describe().toString(); + + assertThat( + topologyDescription, + equalTo("Topologies:\n" + + " Sub-topology: 0\n" + + " Source: KSTREAM-SOURCE-0000000000 (topics: [one])\n" + + " --> KSTREAM-MAP-0000000002\n" + + " Processor: KSTREAM-MAP-0000000002 (stores: [])\n" + + " --> COGROUPKSTREAM-AGGREGATE-STATE-STORE-0000000003-repartition-filter\n" + + " <-- KSTREAM-SOURCE-0000000000\n" + + " Processor: COGROUPKSTREAM-AGGREGATE-STATE-STORE-0000000003-repartition-filter (stores: [])\n" + + " --> COGROUPKSTREAM-AGGREGATE-STATE-STORE-0000000003-repartition-sink\n" + + " <-- KSTREAM-MAP-0000000002\n" + + " Sink: COGROUPKSTREAM-AGGREGATE-STATE-STORE-0000000003-repartition-sink (topic: COGROUPKSTREAM-AGGREGATE-STATE-STORE-0000000003-repartition)\n" + + " <-- COGROUPKSTREAM-AGGREGATE-STATE-STORE-0000000003-repartition-filter\n\n" + + " Sub-topology: 1\n" + + " Source: COGROUPKSTREAM-AGGREGATE-STATE-STORE-0000000003-repartition-source (topics: [COGROUPKSTREAM-AGGREGATE-STATE-STORE-0000000003-repartition])\n" + + " --> COGROUPKSTREAM-AGGREGATE-0000000007, KSTREAM-AGGREGATE-0000000011\n" + + " Source: KSTREAM-SOURCE-0000000001 (topics: [two])\n" + + " --> COGROUPKSTREAM-AGGREGATE-0000000008\n" + + " Processor: COGROUPKSTREAM-AGGREGATE-0000000007 (stores: [COGROUPKSTREAM-AGGREGATE-STATE-STORE-0000000003])\n" + + " --> COGROUPKSTREAM-MERGE-0000000009\n" + + " <-- COGROUPKSTREAM-AGGREGATE-STATE-STORE-0000000003-repartition-source\n" + + " Processor: COGROUPKSTREAM-AGGREGATE-0000000008 (stores: [COGROUPKSTREAM-AGGREGATE-STATE-STORE-0000000003])\n" + + " --> COGROUPKSTREAM-MERGE-0000000009\n" + + " <-- KSTREAM-SOURCE-0000000001\n" + + " Processor: COGROUPKSTREAM-MERGE-0000000009 (stores: [])\n" + + " --> none\n" + + " <-- COGROUPKSTREAM-AGGREGATE-0000000007, COGROUPKSTREAM-AGGREGATE-0000000008\n" + + " Processor: KSTREAM-AGGREGATE-0000000011 (stores: [KSTREAM-AGGREGATE-STATE-STORE-0000000010])\n" + + " --> none\n" + + " <-- COGROUPKSTREAM-AGGREGATE-STATE-STORE-0000000003-repartition-source\n\n")); + } + + @Test + public void shouldInsertRepartitionsTopicForUpstreamKeyModificationWithGroupedRemadeWithOptimization() { + final StreamsBuilder builder = new StreamsBuilder(); + + final Properties properties = new Properties(); + properties.setProperty(StreamsConfig.TOPOLOGY_OPTIMIZATION_CONFIG, StreamsConfig.OPTIMIZE); + + final KStream stream1 = builder.stream("one", 
stringConsumed); + final KStream stream2 = builder.stream("two", stringConsumed); + final KStream stream3 = builder.stream("three", stringConsumed); + + final KGroupedStream groupedOne = stream1.map((k, v) -> new KeyValue<>(v, k)).groupByKey(); + final KGroupedStream groupedTwo = stream2.groupByKey(); + final KGroupedStream groupedThree = stream3.groupByKey(); + final KGroupedStream groupedFour = stream1.map((k, v) -> new KeyValue<>(v, k)).groupByKey(); + + + groupedOne.cogroup(STRING_AGGREGATOR) + .cogroup(groupedTwo, STRING_AGGREGATOR) + .aggregate(STRING_INITIALIZER); + + groupedThree.cogroup(STRING_AGGREGATOR) + .cogroup(groupedFour, STRING_AGGREGATOR) + .aggregate(STRING_INITIALIZER); + + + final String topologyDescription = builder.build(properties).describe().toString(); + + assertThat( + topologyDescription, + equalTo("Topologies:\n" + + " Sub-topology: 0\n" + + " Source: KSTREAM-SOURCE-0000000000 (topics: [one])\n" + + " --> KSTREAM-MAP-0000000003, KSTREAM-MAP-0000000004\n" + + " Processor: KSTREAM-MAP-0000000003 (stores: [])\n" + + " --> COGROUPKSTREAM-AGGREGATE-STATE-STORE-0000000005-repartition-filter\n" + + " <-- KSTREAM-SOURCE-0000000000\n" + + " Processor: KSTREAM-MAP-0000000004 (stores: [])\n" + + " --> COGROUPKSTREAM-AGGREGATE-STATE-STORE-0000000012-repartition-filter\n" + + " <-- KSTREAM-SOURCE-0000000000\n" + + " Processor: COGROUPKSTREAM-AGGREGATE-STATE-STORE-0000000005-repartition-filter (stores: [])\n" + + " --> COGROUPKSTREAM-AGGREGATE-STATE-STORE-0000000005-repartition-sink\n" + + " <-- KSTREAM-MAP-0000000003\n" + + " Processor: COGROUPKSTREAM-AGGREGATE-STATE-STORE-0000000012-repartition-filter (stores: [])\n" + + " --> COGROUPKSTREAM-AGGREGATE-STATE-STORE-0000000012-repartition-sink\n" + + " <-- KSTREAM-MAP-0000000004\n" + + " Sink: COGROUPKSTREAM-AGGREGATE-STATE-STORE-0000000005-repartition-sink (topic: COGROUPKSTREAM-AGGREGATE-STATE-STORE-0000000005-repartition)\n" + + " <-- COGROUPKSTREAM-AGGREGATE-STATE-STORE-0000000005-repartition-filter\n" + + " Sink: COGROUPKSTREAM-AGGREGATE-STATE-STORE-0000000012-repartition-sink (topic: COGROUPKSTREAM-AGGREGATE-STATE-STORE-0000000012-repartition)\n" + + " <-- COGROUPKSTREAM-AGGREGATE-STATE-STORE-0000000012-repartition-filter\n\n" + + " Sub-topology: 1\n" + + " Source: COGROUPKSTREAM-AGGREGATE-STATE-STORE-0000000005-repartition-source (topics: [COGROUPKSTREAM-AGGREGATE-STATE-STORE-0000000005-repartition])\n" + + " --> COGROUPKSTREAM-AGGREGATE-0000000009\n" + + " Source: KSTREAM-SOURCE-0000000001 (topics: [two])\n" + + " --> COGROUPKSTREAM-AGGREGATE-0000000010\n" + + " Processor: COGROUPKSTREAM-AGGREGATE-0000000009 (stores: [COGROUPKSTREAM-AGGREGATE-STATE-STORE-0000000005])\n" + + " --> COGROUPKSTREAM-MERGE-0000000011\n" + + " <-- COGROUPKSTREAM-AGGREGATE-STATE-STORE-0000000005-repartition-source\n" + + " Processor: COGROUPKSTREAM-AGGREGATE-0000000010 (stores: [COGROUPKSTREAM-AGGREGATE-STATE-STORE-0000000005])\n" + + " --> COGROUPKSTREAM-MERGE-0000000011\n" + + " <-- KSTREAM-SOURCE-0000000001\n" + + " Processor: COGROUPKSTREAM-MERGE-0000000011 (stores: [])\n" + + " --> none\n" + + " <-- COGROUPKSTREAM-AGGREGATE-0000000009, COGROUPKSTREAM-AGGREGATE-0000000010\n\n" + + " Sub-topology: 2\n" + + " Source: COGROUPKSTREAM-AGGREGATE-STATE-STORE-0000000012-repartition-source (topics: [COGROUPKSTREAM-AGGREGATE-STATE-STORE-0000000012-repartition])\n" + + " --> COGROUPKSTREAM-AGGREGATE-0000000017\n" + + " Source: KSTREAM-SOURCE-0000000002 (topics: [three])\n" + + " --> COGROUPKSTREAM-AGGREGATE-0000000016\n" + + " Processor: 
COGROUPKSTREAM-AGGREGATE-0000000016 (stores: [COGROUPKSTREAM-AGGREGATE-STATE-STORE-0000000012])\n" + + " --> COGROUPKSTREAM-MERGE-0000000018\n" + + " <-- KSTREAM-SOURCE-0000000002\n" + + " Processor: COGROUPKSTREAM-AGGREGATE-0000000017 (stores: [COGROUPKSTREAM-AGGREGATE-STATE-STORE-0000000012])\n" + + " --> COGROUPKSTREAM-MERGE-0000000018\n" + + " <-- COGROUPKSTREAM-AGGREGATE-STATE-STORE-0000000012-repartition-source\n" + + " Processor: COGROUPKSTREAM-MERGE-0000000018 (stores: [])\n" + + " --> none\n" + + " <-- COGROUPKSTREAM-AGGREGATE-0000000016, COGROUPKSTREAM-AGGREGATE-0000000017\n\n")); + } + + @Test + public void shouldInsertRepartitionsTopicForCogroupsUsedTwice() { + final StreamsBuilder builder = new StreamsBuilder(); + + final Properties properties = new Properties(); + + final KStream stream1 = builder.stream("one", stringConsumed); + + final KGroupedStream groupedOne = stream1.map((k, v) -> new KeyValue<>(v, k)).groupByKey(Grouped.as("foo")); + + final CogroupedKStream one = groupedOne.cogroup(STRING_AGGREGATOR); + one.aggregate(STRING_INITIALIZER); + one.aggregate(STRING_INITIALIZER); + + final String topologyDescription = builder.build(properties).describe().toString(); + + assertThat( + topologyDescription, + equalTo("Topologies:\n" + + " Sub-topology: 0\n" + + " Source: KSTREAM-SOURCE-0000000000 (topics: [one])\n" + + " --> KSTREAM-MAP-0000000001\n" + + " Processor: KSTREAM-MAP-0000000001 (stores: [])\n" + + " --> foo-repartition-filter\n" + + " <-- KSTREAM-SOURCE-0000000000\n" + + " Processor: foo-repartition-filter (stores: [])\n" + + " --> foo-repartition-sink\n" + + " <-- KSTREAM-MAP-0000000001\n" + + " Sink: foo-repartition-sink (topic: foo-repartition)\n" + + " <-- foo-repartition-filter\n\n" + + " Sub-topology: 1\n" + + " Source: foo-repartition-source (topics: [foo-repartition])\n" + + " --> COGROUPKSTREAM-AGGREGATE-0000000006, COGROUPKSTREAM-AGGREGATE-0000000012\n" + + " Processor: COGROUPKSTREAM-AGGREGATE-0000000006 (stores: [COGROUPKSTREAM-AGGREGATE-STATE-STORE-0000000002])\n" + + " --> COGROUPKSTREAM-MERGE-0000000007\n" + + " <-- foo-repartition-source\n" + + " Processor: COGROUPKSTREAM-AGGREGATE-0000000012 (stores: [COGROUPKSTREAM-AGGREGATE-STATE-STORE-0000000008])\n" + + " --> COGROUPKSTREAM-MERGE-0000000013\n" + + " <-- foo-repartition-source\n" + + " Processor: COGROUPKSTREAM-MERGE-0000000007 (stores: [])\n" + + " --> none\n" + + " <-- COGROUPKSTREAM-AGGREGATE-0000000006\n" + + " Processor: COGROUPKSTREAM-MERGE-0000000013 (stores: [])\n" + + " --> none\n" + + " <-- COGROUPKSTREAM-AGGREGATE-0000000012\n\n")); + } + + @Test + public void shouldCogroupAndAggregateSingleKStreams() { + final StreamsBuilder builder = new StreamsBuilder(); + final KStream stream1 = builder.stream("one", stringConsumed); + + final KGroupedStream grouped1 = stream1.groupByKey(); + + final KTable customers = grouped1 + .cogroup(STRING_AGGREGATOR) + .aggregate(STRING_INITIALIZER); + + customers.toStream().to(OUTPUT); + + try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) { + final TestInputTopic testInputTopic = + driver.createInputTopic("one", new StringSerializer(), new StringSerializer()); + final TestOutputTopic testOutputTopic = + driver.createOutputTopic(OUTPUT, new StringDeserializer(), new StringDeserializer()); + testInputTopic.pipeInput("k1", "A", 0); + testInputTopic.pipeInput("k2", "B", 0); + testInputTopic.pipeInput("k2", "B", 0); + testInputTopic.pipeInput("k1", "A", 0); + + assertOutputKeyValueTimestamp(testOutputTopic, "k1", "A", 0); + 
assertOutputKeyValueTimestamp(testOutputTopic, "k2", "B", 0); + assertOutputKeyValueTimestamp(testOutputTopic, "k2", "BB", 0); + assertOutputKeyValueTimestamp(testOutputTopic, "k1", "AA", 0); + } + } + + @Test + public void testCogroupHandleNullValues() { + final StreamsBuilder builder = new StreamsBuilder(); + final KStream stream1 = builder.stream("one", stringConsumed); + + final KGroupedStream grouped1 = stream1.groupByKey(); + + final KTable customers = grouped1 + .cogroup(STRING_AGGREGATOR) + .aggregate(STRING_INITIALIZER); + + customers.toStream().to(OUTPUT); + + try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) { + final TestInputTopic testInputTopic = driver.createInputTopic("one", new StringSerializer(), new StringSerializer()); + final TestOutputTopic testOutputTopic = driver.createOutputTopic(OUTPUT, new StringDeserializer(), new StringDeserializer()); + testInputTopic.pipeInput("k1", "A", 0); + testInputTopic.pipeInput("k2", "B", 0); + testInputTopic.pipeInput("k2", null, 0); + testInputTopic.pipeInput("k2", "B", 0); + testInputTopic.pipeInput("k1", "A", 0); + + assertOutputKeyValueTimestamp(testOutputTopic, "k1", "A", 0); + assertOutputKeyValueTimestamp(testOutputTopic, "k2", "B", 0); + assertOutputKeyValueTimestamp(testOutputTopic, "k2", "BB", 0); + assertOutputKeyValueTimestamp(testOutputTopic, "k1", "AA", 0); + } + } + + @Test + public void shouldCogroupAndAggregateTwoKStreamsWithDistinctKeys() { + final StreamsBuilder builder = new StreamsBuilder(); + final KStream stream1 = builder.stream("one", stringConsumed); + final KStream stream2 = builder.stream("two", stringConsumed); + + final KGroupedStream grouped1 = stream1.groupByKey(); + final KGroupedStream grouped2 = stream2.groupByKey(); + + final KTable customers = grouped1 + .cogroup(STRING_AGGREGATOR) + .cogroup(grouped2, STRING_AGGREGATOR) + .aggregate(STRING_INITIALIZER); + + customers.toStream().to(OUTPUT); + + try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) { + final TestInputTopic testInputTopic = + driver.createInputTopic("one", new StringSerializer(), new StringSerializer()); + final TestInputTopic testInputTopic2 = + driver.createInputTopic("two", new StringSerializer(), new StringSerializer()); + final TestOutputTopic testOutputTopic = + driver.createOutputTopic(OUTPUT, new StringDeserializer(), new StringDeserializer()); + + testInputTopic.pipeInput("k1", "A", 0); + testInputTopic.pipeInput("k1", "A", 1); + testInputTopic.pipeInput("k1", "A", 10); + testInputTopic.pipeInput("k1", "A", 100); + testInputTopic2.pipeInput("k2", "B", 100L); + testInputTopic2.pipeInput("k2", "B", 200L); + testInputTopic2.pipeInput("k2", "B", 1L); + testInputTopic2.pipeInput("k2", "B", 500L); + testInputTopic2.pipeInput("k2", "B", 500L); + testInputTopic2.pipeInput("k2", "B", 100L); + + assertOutputKeyValueTimestamp(testOutputTopic, "k1", "A", 0); + assertOutputKeyValueTimestamp(testOutputTopic, "k1", "AA", 1); + assertOutputKeyValueTimestamp(testOutputTopic, "k1", "AAA", 10); + assertOutputKeyValueTimestamp(testOutputTopic, "k1", "AAAA", 100); + assertOutputKeyValueTimestamp(testOutputTopic, "k2", "B", 100); + assertOutputKeyValueTimestamp(testOutputTopic, "k2", "BB", 200); + assertOutputKeyValueTimestamp(testOutputTopic, "k2", "BBB", 200); + assertOutputKeyValueTimestamp(testOutputTopic, "k2", "BBBB", 500); + assertOutputKeyValueTimestamp(testOutputTopic, "k2", "BBBBB", 500); + assertOutputKeyValueTimestamp(testOutputTopic, "k2", "BBBBBB", 500); + } + } + 
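    // Illustrative sketch (a minimal, hedged aside; the topic names "orders", "payments" and
    // "totals" and the concatenating aggregator below are assumptions chosen for illustration,
    // not identifiers from the code under test): the surrounding tests drive the
    // cogroup-and-aggregate pattern through TopologyTestDriver, and the same pattern in
    // application code could look roughly like this.
    //
    //     final StreamsBuilder builder = new StreamsBuilder();
    //     final KGroupedStream<String, String> orders =
    //         builder.<String, String>stream("orders").groupByKey();
    //     final KGroupedStream<String, String> payments =
    //         builder.<String, String>stream("payments").groupByKey();
    //     final Aggregator<String, String, String> concat =
    //         (key, value, aggregate) -> aggregate + value;
    //     // Each grouped stream contributes its own adder, but all updates land in one shared
    //     // per-key aggregate backed by a single state store.
    //     final KTable<String, String> combined = orders
    //         .cogroup(concat)
    //         .cogroup(payments, concat)
    //         .aggregate(() -> "");
    //     combined.toStream().to("totals");
    //
    // As the expected records in these tests show, each input record emits the updated aggregate
    // for its key, timestamped with the largest input timestamp observed so far for that key.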
+ @Test + public void shouldCogroupAndAggregateTwoKStreamsWithSharedKeys() { + final StreamsBuilder builder = new StreamsBuilder(); + final KStream stream1 = builder.stream("one", stringConsumed); + final KStream stream2 = builder.stream("two", stringConsumed); + + final KGroupedStream grouped1 = stream1.groupByKey(); + final KGroupedStream grouped2 = stream2.groupByKey(); + + final KTable customers = grouped1 + .cogroup(STRING_AGGREGATOR) + .cogroup(grouped2, STRING_AGGREGATOR) + .aggregate(STRING_INITIALIZER); + + customers.toStream().to(OUTPUT); + + try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) { + final TestInputTopic testInputTopic = + driver.createInputTopic("one", new StringSerializer(), new StringSerializer()); + final TestInputTopic testInputTopic2 = + driver.createInputTopic("two", new StringSerializer(), new StringSerializer()); + final TestOutputTopic testOutputTopic = + driver.createOutputTopic(OUTPUT, new StringDeserializer(), new StringDeserializer()); + + testInputTopic.pipeInput("k1", "A", 0L); + testInputTopic.pipeInput("k2", "A", 1L); + testInputTopic.pipeInput("k1", "A", 10L); + testInputTopic.pipeInput("k2", "A", 100L); + testInputTopic2.pipeInput("k2", "B", 100L); + testInputTopic2.pipeInput("k2", "B", 200L); + testInputTopic2.pipeInput("k1", "B", 1L); + testInputTopic2.pipeInput("k2", "B", 500L); + testInputTopic2.pipeInput("k1", "B", 500L); + testInputTopic2.pipeInput("k2", "B", 500L); + testInputTopic2.pipeInput("k3", "B", 500L); + testInputTopic2.pipeInput("k2", "B", 100L); + + assertOutputKeyValueTimestamp(testOutputTopic, "k1", "A", 0); + assertOutputKeyValueTimestamp(testOutputTopic, "k2", "A", 1); + assertOutputKeyValueTimestamp(testOutputTopic, "k1", "AA", 10); + assertOutputKeyValueTimestamp(testOutputTopic, "k2", "AA", 100); + assertOutputKeyValueTimestamp(testOutputTopic, "k2", "AAB", 100); + assertOutputKeyValueTimestamp(testOutputTopic, "k2", "AABB", 200); + assertOutputKeyValueTimestamp(testOutputTopic, "k1", "AAB", 10); + assertOutputKeyValueTimestamp(testOutputTopic, "k2", "AABBB", 500); + assertOutputKeyValueTimestamp(testOutputTopic, "k1", "AABB", 500); + assertOutputKeyValueTimestamp(testOutputTopic, "k2", "AABBBB", 500); + } + } + + @Test + public void shouldAllowDifferentOutputTypeInCoGroup() { + final StreamsBuilder builder = new StreamsBuilder(); + final KStream stream1 = builder.stream("one", stringConsumed); + final KStream stream2 = builder.stream("two", stringConsumed); + + final KGroupedStream grouped1 = stream1.groupByKey(); + final KGroupedStream grouped2 = stream2.groupByKey(); + + final KTable customers = grouped1 + .cogroup(STRING_SUM_AGGREGATOR) + .cogroup(grouped2, STRING_SUM_AGGREGATOR) + .aggregate( + SUM_INITIALIZER, + Materialized.>as("store1") + .withValueSerde(Serdes.Integer())); + + customers.toStream().to(OUTPUT); + + try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) { + final TestInputTopic testInputTopic = + driver.createInputTopic("one", new StringSerializer(), new StringSerializer()); + final TestInputTopic testInputTopic2 = + driver.createInputTopic("two", new StringSerializer(), new StringSerializer()); + final TestOutputTopic testOutputTopic = + driver.createOutputTopic(OUTPUT, new StringDeserializer(), new IntegerDeserializer()); + + testInputTopic.pipeInput("k1", "1", 0L); + testInputTopic.pipeInput("k2", "1", 1L); + testInputTopic.pipeInput("k1", "1", 10L); + testInputTopic.pipeInput("k2", "1", 100L); + testInputTopic2.pipeInput("k2", 
"2", 100L); + testInputTopic2.pipeInput("k2", "2", 200L); + testInputTopic2.pipeInput("k1", "2", 1L); + testInputTopic2.pipeInput("k2", "2", 500L); + testInputTopic2.pipeInput("k1", "2", 500L); + testInputTopic2.pipeInput("k2", "3", 500L); + testInputTopic2.pipeInput("k3", "2", 500L); + testInputTopic2.pipeInput("k2", "2", 100L); + + assertOutputKeyValueTimestamp(testOutputTopic, "k1", 1, 0); + assertOutputKeyValueTimestamp(testOutputTopic, "k2", 1, 1); + assertOutputKeyValueTimestamp(testOutputTopic, "k1", 2, 10); + assertOutputKeyValueTimestamp(testOutputTopic, "k2", 2, 100); + assertOutputKeyValueTimestamp(testOutputTopic, "k2", 4, 100); + assertOutputKeyValueTimestamp(testOutputTopic, "k2", 6, 200); + assertOutputKeyValueTimestamp(testOutputTopic, "k1", 4, 10); + assertOutputKeyValueTimestamp(testOutputTopic, "k2", 8, 500); + assertOutputKeyValueTimestamp(testOutputTopic, "k1", 6, 500); + assertOutputKeyValueTimestamp(testOutputTopic, "k2", 11, 500); + } + } + + @Test + public void shouldCoGroupStreamsWithDifferentInputTypes() { + final StreamsBuilder builder = new StreamsBuilder(); + final Consumed integerConsumed = Consumed.with(Serdes.String(), Serdes.Integer()); + final KStream stream1 = builder.stream("one", stringConsumed); + final KStream stream2 = builder.stream("two", integerConsumed); + + final KGroupedStream grouped1 = stream1.groupByKey(); + final KGroupedStream grouped2 = stream2.groupByKey(); + + final KTable customers = grouped1 + .cogroup(STRING_SUM_AGGREGATOR) + .cogroup(grouped2, SUM_AGGREGATOR) + .aggregate( + SUM_INITIALIZER, + Materialized.>as("store1") + .withValueSerde(Serdes.Integer())); + + customers.toStream().to(OUTPUT); + + try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) { + final TestInputTopic testInputTopic = driver.createInputTopic("one", new StringSerializer(), new StringSerializer()); + final TestInputTopic testInputTopic2 = driver.createInputTopic("two", new StringSerializer(), new IntegerSerializer()); + final TestOutputTopic testOutputTopic = driver.createOutputTopic(OUTPUT, new StringDeserializer(), new IntegerDeserializer()); + testInputTopic.pipeInput("k1", "1", 0L); + testInputTopic.pipeInput("k2", "1", 1L); + testInputTopic.pipeInput("k1", "1", 10L); + testInputTopic.pipeInput("k2", "1", 100L); + + testInputTopic2.pipeInput("k2", 2, 100L); + testInputTopic2.pipeInput("k2", 2, 200L); + testInputTopic2.pipeInput("k1", 2, 1L); + testInputTopic2.pipeInput("k2", 2, 500L); + testInputTopic2.pipeInput("k1", 2, 500L); + testInputTopic2.pipeInput("k2", 3, 500L); + testInputTopic2.pipeInput("k3", 2, 500L); + testInputTopic2.pipeInput("k2", 2, 100L); + + assertOutputKeyValueTimestamp(testOutputTopic, "k1", 1, 0); + assertOutputKeyValueTimestamp(testOutputTopic, "k2", 1, 1); + assertOutputKeyValueTimestamp(testOutputTopic, "k1", 2, 10); + assertOutputKeyValueTimestamp(testOutputTopic, "k2", 2, 100); + assertOutputKeyValueTimestamp(testOutputTopic, "k2", 4, 100); + assertOutputKeyValueTimestamp(testOutputTopic, "k2", 6, 200); + assertOutputKeyValueTimestamp(testOutputTopic, "k1", 4, 10); + assertOutputKeyValueTimestamp(testOutputTopic, "k2", 8, 500); + assertOutputKeyValueTimestamp(testOutputTopic, "k1", 6, 500); + assertOutputKeyValueTimestamp(testOutputTopic, "k2", 11, 500); + assertOutputKeyValueTimestamp(testOutputTopic, "k3", 2, 500); + } + } + + @Test + public void testCogroupKeyMixedAggregators() { + final StreamsBuilder builder = new StreamsBuilder(); + final KStream stream1 = builder.stream("one", 
stringConsumed); + final KStream stream2 = builder.stream("two", stringConsumed); + + final KGroupedStream grouped1 = stream1.groupByKey(); + final KGroupedStream grouped2 = stream2.groupByKey(); + + final KTable customers = grouped1 + .cogroup(MockAggregator.TOSTRING_REMOVER) + .cogroup(grouped2, MockAggregator.TOSTRING_ADDER) + .aggregate( + MockInitializer.STRING_INIT, + Materialized.>as("store1") + .withValueSerde(Serdes.String())); + + customers.toStream().to(OUTPUT); + + try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) { + final TestInputTopic testInputTopic = + driver.createInputTopic("one", new StringSerializer(), new StringSerializer()); + final TestInputTopic testInputTopic2 = + driver.createInputTopic("two", new StringSerializer(), new StringSerializer()); + final TestOutputTopic testOutputTopic = + driver.createOutputTopic(OUTPUT, new StringDeserializer(), new StringDeserializer()); + + testInputTopic.pipeInput("k1", "1", 0L); + testInputTopic.pipeInput("k2", "1", 1L); + testInputTopic.pipeInput("k1", "1", 10L); + testInputTopic.pipeInput("k2", "1", 100L); + testInputTopic2.pipeInput("k1", "2", 500L); + testInputTopic2.pipeInput("k2", "2", 500L); + testInputTopic2.pipeInput("k1", "2", 500L); + testInputTopic2.pipeInput("k2", "2", 100L); + + assertOutputKeyValueTimestamp(testOutputTopic, "k1", "0-1", 0); + assertOutputKeyValueTimestamp(testOutputTopic, "k2", "0-1", 1); + assertOutputKeyValueTimestamp(testOutputTopic, "k1", "0-1-1", 10); + assertOutputKeyValueTimestamp(testOutputTopic, "k2", "0-1-1", 100); + assertOutputKeyValueTimestamp(testOutputTopic, "k1", "0-1-1+2", 500L); + assertOutputKeyValueTimestamp(testOutputTopic, "k2", "0-1-1+2", 500L); + assertOutputKeyValueTimestamp(testOutputTopic, "k1", "0-1-1+2+2", 500L); + assertOutputKeyValueTimestamp(testOutputTopic, "k2", "0-1-1+2+2", 500L); + } + } + + @Test + public void testCogroupWithThreeGroupedStreams() { + final StreamsBuilder builder = new StreamsBuilder(); + final KStream stream1 = builder.stream("one", stringConsumed); + final KStream stream2 = builder.stream("two", stringConsumed); + final KStream stream3 = builder.stream("three", stringConsumed); + + final KGroupedStream grouped1 = stream1.groupByKey(); + final KGroupedStream grouped2 = stream2.groupByKey(); + final KGroupedStream grouped3 = stream3.groupByKey(); + + final KTable customers = grouped1 + .cogroup(STRING_AGGREGATOR) + .cogroup(grouped2, STRING_AGGREGATOR) + .cogroup(grouped3, STRING_AGGREGATOR) + .aggregate(STRING_INITIALIZER); + + customers.toStream().to(OUTPUT); + + try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) { + final TestInputTopic testInputTopic = + driver.createInputTopic("one", new StringSerializer(), new StringSerializer()); + final TestInputTopic testInputTopic2 = + driver.createInputTopic("two", new StringSerializer(), new StringSerializer()); + final TestInputTopic testInputTopic3 = + driver.createInputTopic("three", new StringSerializer(), new StringSerializer()); + + final TestOutputTopic testOutputTopic = + driver.createOutputTopic(OUTPUT, new StringDeserializer(), new StringDeserializer()); + + testInputTopic.pipeInput("k1", "A", 0L); + testInputTopic.pipeInput("k2", "A", 1L); + testInputTopic.pipeInput("k1", "A", 10L); + testInputTopic.pipeInput("k2", "A", 100L); + testInputTopic2.pipeInput("k2", "B", 100L); + testInputTopic2.pipeInput("k2", "B", 200L); + testInputTopic2.pipeInput("k1", "B", 1L); + testInputTopic2.pipeInput("k2", "B", 500L); + 
testInputTopic3.pipeInput("k1", "B", 500L); + testInputTopic3.pipeInput("k2", "B", 500L); + testInputTopic3.pipeInput("k3", "B", 500L); + testInputTopic3.pipeInput("k2", "B", 100L); + + assertOutputKeyValueTimestamp(testOutputTopic, "k1", "A", 0); + assertOutputKeyValueTimestamp(testOutputTopic, "k2", "A", 1); + assertOutputKeyValueTimestamp(testOutputTopic, "k1", "AA", 10); + assertOutputKeyValueTimestamp(testOutputTopic, "k2", "AA", 100); + assertOutputKeyValueTimestamp(testOutputTopic, "k2", "AAB", 100); + assertOutputKeyValueTimestamp(testOutputTopic, "k2", "AABB", 200); + assertOutputKeyValueTimestamp(testOutputTopic, "k1", "AAB", 10); + assertOutputKeyValueTimestamp(testOutputTopic, "k2", "AABBB", 500); + assertOutputKeyValueTimestamp(testOutputTopic, "k1", "AABB", 500); + assertOutputKeyValueTimestamp(testOutputTopic, "k2", "AABBBB", 500); + assertOutputKeyValueTimestamp(testOutputTopic, "k3", "B", 500); + } + } + + @Test + public void testCogroupWithKTableKTableInnerJoin() { + final StreamsBuilder builder = new StreamsBuilder(); + + final KGroupedStream grouped1 = builder.stream("one", stringConsumed).groupByKey(); + final KGroupedStream grouped2 = builder.stream("two", stringConsumed).groupByKey(); + + final KTable table1 = grouped1 + .cogroup(STRING_AGGREGATOR) + .cogroup(grouped2, STRING_AGGREGATOR) + .aggregate(STRING_INITIALIZER, Named.as("name"), Materialized.as("store")); + + final KTable table2 = builder.table("three", stringConsumed); + final KTable joined = table1.join(table2, MockValueJoiner.TOSTRING_JOINER, Materialized.with(Serdes.String(), Serdes.String())); + joined.toStream().to(OUTPUT); + + try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) { + final TestInputTopic testInputTopic = + driver.createInputTopic("one", new StringSerializer(), new StringSerializer()); + final TestInputTopic testInputTopic2 = + driver.createInputTopic("two", new StringSerializer(), new StringSerializer()); + final TestInputTopic testInputTopic3 = + driver.createInputTopic("three", new StringSerializer(), new StringSerializer()); + final TestOutputTopic testOutputTopic = + driver.createOutputTopic(OUTPUT, new StringDeserializer(), new StringDeserializer()); + + testInputTopic.pipeInput("k1", "A", 5L); + testInputTopic2.pipeInput("k2", "B", 6L); + + assertTrue(testOutputTopic.isEmpty()); + + testInputTopic3.pipeInput("k1", "C", 0L); + testInputTopic3.pipeInput("k2", "D", 10L); + + assertOutputKeyValueTimestamp(testOutputTopic, "k1", "A+C", 5L); + assertOutputKeyValueTimestamp(testOutputTopic, "k2", "B+D", 10L); + assertTrue(testOutputTopic.isEmpty()); + } + } + + private void assertOutputKeyValueTimestamp(final TestOutputTopic outputTopic, + final String expectedKey, + final String expectedValue, + final long expectedTimestamp) { + assertThat( + outputTopic.readRecord(), + equalTo(new TestRecord<>(expectedKey, expectedValue, null, expectedTimestamp))); + } + + private void assertOutputKeyValueTimestamp(final TestOutputTopic outputTopic, + final String expectedKey, + final Integer expectedValue, + final long expectedTimestamp) { + assertThat( + outputTopic.readRecord(), + equalTo(new TestRecord<>(expectedKey, expectedValue, null, expectedTimestamp))); + } +} \ No newline at end of file diff --git a/streams/src/test/java/org/apache/kafka/streams/kstream/internals/FullChangeSerdeTest.java b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/FullChangeSerdeTest.java new file mode 100644 index 0000000..e7e0c88 --- /dev/null +++ 
b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/FullChangeSerdeTest.java @@ -0,0 +1,118 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.kstream.internals; + +import org.apache.kafka.common.serialization.Serdes; +import org.junit.Test; + +import java.nio.ByteBuffer; + +import static org.hamcrest.CoreMatchers.nullValue; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.core.Is.is; + +public class FullChangeSerdeTest { + private final FullChangeSerde serde = FullChangeSerde.wrap(Serdes.String()); + + /** + * We used to serialize a Change into a single byte[]. Now, we don't anymore, but we still keep this logic here + * so that we can produce the legacy format to test that we can still deserialize it. + */ + private static byte[] mergeChangeArraysIntoSingleLegacyFormattedArray(final Change serialChange) { + if (serialChange == null) { + return null; + } + + final int oldSize = serialChange.oldValue == null ? -1 : serialChange.oldValue.length; + final int newSize = serialChange.newValue == null ? 
-1 : serialChange.newValue.length; + + final ByteBuffer buffer = ByteBuffer.allocate(Integer.BYTES * 2 + Math.max(0, oldSize) + Math.max(0, newSize)); + + + buffer.putInt(oldSize); + if (serialChange.oldValue != null) { + buffer.put(serialChange.oldValue); + } + + buffer.putInt(newSize); + if (serialChange.newValue != null) { + buffer.put(serialChange.newValue); + } + return buffer.array(); + } + + @Test + public void shouldRoundTripNull() { + assertThat(serde.serializeParts(null, null), nullValue()); + assertThat(mergeChangeArraysIntoSingleLegacyFormattedArray(null), nullValue()); + assertThat(FullChangeSerde.decomposeLegacyFormattedArrayIntoChangeArrays(null), nullValue()); + assertThat(serde.deserializeParts(null, null), nullValue()); + } + + + @Test + public void shouldRoundTripNullChange() { + assertThat( + serde.serializeParts(null, new Change<>(null, null)), + is(new Change(null, null)) + ); + + assertThat( + serde.deserializeParts(null, new Change<>(null, null)), + is(new Change(null, null)) + ); + + final byte[] legacyFormat = mergeChangeArraysIntoSingleLegacyFormattedArray(new Change<>(null, null)); + assertThat( + FullChangeSerde.decomposeLegacyFormattedArrayIntoChangeArrays(legacyFormat), + is(new Change(null, null)) + ); + } + + @Test + public void shouldRoundTripOldNull() { + final Change serialized = serde.serializeParts(null, new Change<>("new", null)); + final byte[] legacyFormat = mergeChangeArraysIntoSingleLegacyFormattedArray(serialized); + final Change decomposedLegacyFormat = FullChangeSerde.decomposeLegacyFormattedArrayIntoChangeArrays(legacyFormat); + assertThat( + serde.deserializeParts(null, decomposedLegacyFormat), + is(new Change<>("new", null)) + ); + } + + @Test + public void shouldRoundTripNewNull() { + final Change serialized = serde.serializeParts(null, new Change<>(null, "old")); + final byte[] legacyFormat = mergeChangeArraysIntoSingleLegacyFormattedArray(serialized); + final Change decomposedLegacyFormat = FullChangeSerde.decomposeLegacyFormattedArrayIntoChangeArrays(legacyFormat); + assertThat( + serde.deserializeParts(null, decomposedLegacyFormat), + is(new Change<>(null, "old")) + ); + } + + @Test + public void shouldRoundTripChange() { + final Change serialized = serde.serializeParts(null, new Change<>("new", "old")); + final byte[] legacyFormat = mergeChangeArraysIntoSingleLegacyFormattedArray(serialized); + final Change decomposedLegacyFormat = FullChangeSerde.decomposeLegacyFormattedArrayIntoChangeArrays(legacyFormat); + assertThat( + serde.deserializeParts(null, decomposedLegacyFormat), + is(new Change<>("new", "old")) + ); + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/kstream/internals/GlobalKTableJoinsTest.java b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/GlobalKTableJoinsTest.java new file mode 100644 index 0000000..e4d95e0 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/GlobalKTableJoinsTest.java @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.kstream.internals; + +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.serialization.StringSerializer; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.TopologyTestDriver; +import org.apache.kafka.streams.kstream.Consumed; +import org.apache.kafka.streams.kstream.GlobalKTable; +import org.apache.kafka.streams.kstream.KStream; +import org.apache.kafka.streams.kstream.KeyValueMapper; +import org.apache.kafka.streams.state.ValueAndTimestamp; +import org.apache.kafka.streams.TestInputTopic; +import org.apache.kafka.test.MockProcessorSupplier; +import org.apache.kafka.test.MockValueJoiner; +import org.apache.kafka.test.StreamsTestUtils; +import org.junit.Before; +import org.junit.Test; + +import java.util.HashMap; +import java.util.Map; +import java.util.Properties; + +import static org.junit.Assert.assertEquals; + + +public class GlobalKTableJoinsTest { + + private final StreamsBuilder builder = new StreamsBuilder(); + private final String streamTopic = "stream"; + private final String globalTopic = "global"; + private GlobalKTable global; + private KStream stream; + private KeyValueMapper keyValueMapper; + + @Before + public void setUp() { + final Consumed consumed = Consumed.with(Serdes.String(), Serdes.String()); + global = builder.globalTable(globalTopic, consumed); + stream = builder.stream(streamTopic, consumed); + keyValueMapper = (key, value) -> value; + } + + @SuppressWarnings("deprecation") // Old PAPI. Needs to be migrated. + @Test + public void shouldLeftJoinWithStream() { + final MockProcessorSupplier supplier = new MockProcessorSupplier<>(); + stream + .leftJoin(global, keyValueMapper, MockValueJoiner.TOSTRING_JOINER) + .process(supplier); + + final Map> expected = new HashMap<>(); + expected.put("1", ValueAndTimestamp.make("a+A", 2L)); + expected.put("2", ValueAndTimestamp.make("b+B", 10L)); + expected.put("3", ValueAndTimestamp.make("c+null", 3L)); + + verifyJoin(expected, supplier); + } + + @SuppressWarnings("deprecation") // Old PAPI. Needs to be migrated. 
+ @Test + public void shouldInnerJoinWithStream() { + final MockProcessorSupplier supplier = new MockProcessorSupplier<>(); + stream + .join(global, keyValueMapper, MockValueJoiner.TOSTRING_JOINER) + .process(supplier); + + final Map> expected = new HashMap<>(); + expected.put("1", ValueAndTimestamp.make("a+A", 2L)); + expected.put("2", ValueAndTimestamp.make("b+B", 10L)); + + verifyJoin(expected, supplier); + } + + private void verifyJoin(final Map> expected, + final MockProcessorSupplier supplier) { + final Properties props = StreamsTestUtils.getStreamsConfig(Serdes.String(), Serdes.String()); + + try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) { + final TestInputTopic globalInputTopic = driver.createInputTopic(globalTopic, new StringSerializer(), new StringSerializer()); + // write some data to the global table + globalInputTopic.pipeInput("a", "A", 1L); + globalInputTopic.pipeInput("b", "B", 5L); + final TestInputTopic streamInputTopic = driver.createInputTopic(streamTopic, new StringSerializer(), new StringSerializer()); + //write some data to the stream + streamInputTopic.pipeInput("1", "a", 2L); + streamInputTopic.pipeInput("2", "b", 10L); + streamInputTopic.pipeInput("3", "c", 3L); + } + + assertEquals(expected, supplier.theCapturedProcessor().lastValueAndTimestampPerKey()); + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/kstream/internals/InternalStreamsBuilderTest.java b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/InternalStreamsBuilderTest.java new file mode 100644 index 0000000..73155da --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/InternalStreamsBuilderTest.java @@ -0,0 +1,365 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.kstream.internals; + +import org.apache.kafka.clients.consumer.OffsetResetStrategy; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.kstream.Consumed; +import org.apache.kafka.streams.kstream.GlobalKTable; +import org.apache.kafka.streams.kstream.KStream; +import org.apache.kafka.streams.kstream.KTable; +import org.apache.kafka.streams.kstream.KeyValueMapper; +import org.apache.kafka.streams.kstream.Materialized; +import org.apache.kafka.streams.processor.StateStore; +import org.apache.kafka.streams.processor.internals.InternalTopologyBuilder; +import org.apache.kafka.streams.processor.internals.ProcessorTopology; +import org.apache.kafka.streams.state.KeyValueStore; +import org.apache.kafka.test.MockMapper; +import org.apache.kafka.test.MockTimestampExtractor; +import org.apache.kafka.test.MockValueJoiner; +import org.apache.kafka.test.StreamsTestUtils; +import org.junit.Test; + +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.regex.Pattern; + +import static java.util.Arrays.asList; +import static org.apache.kafka.streams.Topology.AutoOffsetReset; +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.core.IsInstanceOf.instanceOf; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; + +public class InternalStreamsBuilderTest { + + private static final String APP_ID = "app-id"; + + private final InternalStreamsBuilder builder = new InternalStreamsBuilder(new InternalTopologyBuilder()); + private final ConsumedInternal consumed = new ConsumedInternal<>(); + private final String storePrefix = "prefix-"; + private final MaterializedInternal> materialized = new MaterializedInternal<>(Materialized.as("test-store"), builder, storePrefix); + + @Test + public void testNewName() { + assertEquals("X-0000000000", builder.newProcessorName("X-")); + assertEquals("Y-0000000001", builder.newProcessorName("Y-")); + assertEquals("Z-0000000002", builder.newProcessorName("Z-")); + + final InternalStreamsBuilder newBuilder = new InternalStreamsBuilder(new InternalTopologyBuilder()); + + assertEquals("X-0000000000", newBuilder.newProcessorName("X-")); + assertEquals("Y-0000000001", newBuilder.newProcessorName("Y-")); + assertEquals("Z-0000000002", newBuilder.newProcessorName("Z-")); + } + + @Test + public void testNewStoreName() { + assertEquals("X-STATE-STORE-0000000000", builder.newStoreName("X-")); + assertEquals("Y-STATE-STORE-0000000001", builder.newStoreName("Y-")); + assertEquals("Z-STATE-STORE-0000000002", builder.newStoreName("Z-")); + + final InternalStreamsBuilder newBuilder = new InternalStreamsBuilder(new InternalTopologyBuilder()); + + assertEquals("X-STATE-STORE-0000000000", newBuilder.newStoreName("X-")); + assertEquals("Y-STATE-STORE-0000000001", newBuilder.newStoreName("Y-")); + assertEquals("Z-STATE-STORE-0000000002", newBuilder.newStoreName("Z-")); + } + + @Test + public void shouldHaveCorrectSourceTopicsForTableFromMergedStream() { + final String topic1 = "topic-1"; + final String topic2 = "topic-2"; + final String topic3 = "topic-3"; + final KStream source1 = builder.stream(Collections.singleton(topic1), consumed); + final KStream source2 = 
builder.stream(Collections.singleton(topic2), consumed); + final KStream source3 = builder.stream(Collections.singleton(topic3), consumed); + final KStream processedSource1 = + source1.mapValues(v -> v) + .filter((k, v) -> true); + final KStream processedSource2 = source2.filter((k, v) -> true); + + final KStream merged = processedSource1.merge(processedSource2).merge(source3); + merged.groupByKey().count(Materialized.as("my-table")); + builder.buildAndOptimizeTopology(); + final Map> actual = builder.internalTopologyBuilder.stateStoreNameToSourceTopics(); + assertEquals(asList("topic-1", "topic-2", "topic-3"), actual.get("my-table")); + } + + @Test + public void shouldNotMaterializeSourceKTableIfNotRequired() { + final MaterializedInternal> materializedInternal = + new MaterializedInternal<>(Materialized.with(null, null), builder, storePrefix); + final KTable table1 = builder.table("topic2", consumed, materializedInternal); + + builder.buildAndOptimizeTopology(); + final ProcessorTopology topology = builder.internalTopologyBuilder + .rewriteTopology(new StreamsConfig(StreamsTestUtils.getStreamsConfig(APP_ID))) + .buildTopology(); + + assertEquals(0, topology.stateStores().size()); + assertEquals(0, topology.storeToChangelogTopic().size()); + assertNull(table1.queryableStoreName()); + } + + @Test + public void shouldBuildGlobalTableWithNonQueryableStoreName() { + final MaterializedInternal> materializedInternal = + new MaterializedInternal<>(Materialized.with(null, null), builder, storePrefix); + + final GlobalKTable table1 = builder.globalTable("topic2", consumed, materializedInternal); + + assertNull(table1.queryableStoreName()); + } + + @Test + public void shouldBuildGlobalTableWithQueryaIbleStoreName() { + final MaterializedInternal> materializedInternal = + new MaterializedInternal<>(Materialized.as("globalTable"), builder, storePrefix); + final GlobalKTable table1 = builder.globalTable("topic2", consumed, materializedInternal); + + assertEquals("globalTable", table1.queryableStoreName()); + } + + @Test + public void shouldBuildSimpleGlobalTableTopology() { + final MaterializedInternal> materializedInternal = + new MaterializedInternal<>(Materialized.as("globalTable"), builder, storePrefix); + builder.globalTable("table", + consumed, + materializedInternal); + + builder.buildAndOptimizeTopology(); + final ProcessorTopology topology = builder.internalTopologyBuilder + .rewriteTopology(new StreamsConfig(StreamsTestUtils.getStreamsConfig(APP_ID))) + .buildGlobalStateTopology(); + final List stateStores = topology.globalStateStores(); + + assertEquals(1, stateStores.size()); + assertEquals("globalTable", stateStores.get(0).name()); + } + + private void doBuildGlobalTopologyWithAllGlobalTables() { + final ProcessorTopology topology = builder.internalTopologyBuilder + .rewriteTopology(new StreamsConfig(StreamsTestUtils.getStreamsConfig(APP_ID))) + .buildGlobalStateTopology(); + + final List stateStores = topology.globalStateStores(); + final Set sourceTopics = topology.sourceTopics(); + + assertEquals(Utils.mkSet("table", "table2"), sourceTopics); + assertEquals(2, stateStores.size()); + } + + @Test + public void shouldBuildGlobalTopologyWithAllGlobalTables() { + { + final MaterializedInternal> materializedInternal = + new MaterializedInternal<>(Materialized.as("global1"), builder, storePrefix); + builder.globalTable("table", consumed, materializedInternal); + } + { + final MaterializedInternal> materializedInternal = + new MaterializedInternal<>(Materialized.as("global2"), builder, 
storePrefix); + builder.globalTable("table2", consumed, materializedInternal); + } + + builder.buildAndOptimizeTopology(); + doBuildGlobalTopologyWithAllGlobalTables(); + } + + @Test + public void shouldAddGlobalTablesToEachGroup() { + final String one = "globalTable"; + final String two = "globalTable2"; + + final MaterializedInternal> materializedInternal = + new MaterializedInternal<>(Materialized.as(one), builder, storePrefix); + final GlobalKTable globalTable = builder.globalTable("table", consumed, materializedInternal); + + final MaterializedInternal> materializedInternal2 = + new MaterializedInternal<>(Materialized.as(two), builder, storePrefix); + final GlobalKTable globalTable2 = builder.globalTable("table2", consumed, materializedInternal2); + + final MaterializedInternal> materializedInternalNotGlobal = + new MaterializedInternal<>(Materialized.as("not-global"), builder, storePrefix); + builder.table("not-global", consumed, materializedInternalNotGlobal); + + final KeyValueMapper kvMapper = (key, value) -> value; + + final KStream stream = builder.stream(Collections.singleton("t1"), consumed); + stream.leftJoin(globalTable, kvMapper, MockValueJoiner.TOSTRING_JOINER); + final KStream stream2 = builder.stream(Collections.singleton("t2"), consumed); + stream2.leftJoin(globalTable2, kvMapper, MockValueJoiner.TOSTRING_JOINER); + + final Map> nodeGroups = builder.internalTopologyBuilder.nodeGroups(); + for (final Integer groupId : nodeGroups.keySet()) { + final ProcessorTopology topology = builder.internalTopologyBuilder.buildSubtopology(groupId); + final List stateStores = topology.globalStateStores(); + final Set names = new HashSet<>(); + for (final StateStore stateStore : stateStores) { + names.add(stateStore.name()); + } + + assertEquals(2, stateStores.size()); + assertTrue(names.contains(one)); + assertTrue(names.contains(two)); + } + } + + @Test + public void shouldMapStateStoresToCorrectSourceTopics() { + final KStream playEvents = builder.stream(Collections.singleton("events"), consumed); + + final MaterializedInternal> materializedInternal = + new MaterializedInternal<>(Materialized.as("table-store"), builder, storePrefix); + final KTable table = builder.table("table-topic", consumed, materializedInternal); + + final KStream mapped = playEvents.map(MockMapper.selectValueKeyValueMapper()); + mapped.leftJoin(table, MockValueJoiner.TOSTRING_JOINER).groupByKey().count(Materialized.as("count")); + builder.buildAndOptimizeTopology(); + builder.internalTopologyBuilder.rewriteTopology(new StreamsConfig(StreamsTestUtils.getStreamsConfig(APP_ID))); + assertEquals(Collections.singletonList("table-topic"), builder.internalTopologyBuilder.sourceTopicsForStore("table-store")); + assertEquals(Collections.singletonList(APP_ID + "-KSTREAM-MAP-0000000003-repartition"), builder.internalTopologyBuilder.sourceTopicsForStore("count")); + } + + @Test + public void shouldAddTopicToEarliestAutoOffsetResetList() { + final String topicName = "topic-1"; + final ConsumedInternal consumed = new ConsumedInternal<>(Consumed.with(AutoOffsetReset.EARLIEST)); + builder.stream(Collections.singleton(topicName), consumed); + builder.buildAndOptimizeTopology(); + + assertThat(builder.internalTopologyBuilder.offsetResetStrategy(topicName), equalTo(OffsetResetStrategy.EARLIEST)); + } + + @Test + public void shouldAddTopicToLatestAutoOffsetResetList() { + final String topicName = "topic-1"; + + final ConsumedInternal consumed = new ConsumedInternal<>(Consumed.with(AutoOffsetReset.LATEST)); + 
builder.stream(Collections.singleton(topicName), consumed); + builder.buildAndOptimizeTopology(); + assertThat(builder.internalTopologyBuilder.offsetResetStrategy(topicName), equalTo(OffsetResetStrategy.LATEST)); + } + + @Test + public void shouldAddTableToEarliestAutoOffsetResetList() { + final String topicName = "topic-1"; + builder.table(topicName, new ConsumedInternal<>(Consumed.with(AutoOffsetReset.EARLIEST)), materialized); + builder.buildAndOptimizeTopology(); + assertThat(builder.internalTopologyBuilder.offsetResetStrategy(topicName), equalTo(OffsetResetStrategy.EARLIEST)); + } + + @Test + public void shouldAddTableToLatestAutoOffsetResetList() { + final String topicName = "topic-1"; + builder.table(topicName, new ConsumedInternal<>(Consumed.with(AutoOffsetReset.LATEST)), materialized); + builder.buildAndOptimizeTopology(); + assertThat(builder.internalTopologyBuilder.offsetResetStrategy(topicName), equalTo(OffsetResetStrategy.LATEST)); + } + + @Test + public void shouldNotAddTableToOffsetResetLists() { + final String topicName = "topic-1"; + + builder.table(topicName, consumed, materialized); + + assertThat(builder.internalTopologyBuilder.offsetResetStrategy(topicName), equalTo(OffsetResetStrategy.NONE)); + } + + @Test + public void shouldNotAddRegexTopicsToOffsetResetLists() { + final Pattern topicPattern = Pattern.compile("topic-\\d"); + final String topic = "topic-5"; + + builder.stream(topicPattern, consumed); + + assertThat(builder.internalTopologyBuilder.offsetResetStrategy(topic), equalTo(OffsetResetStrategy.NONE)); + } + + @Test + public void shouldAddRegexTopicToEarliestAutoOffsetResetList() { + final Pattern topicPattern = Pattern.compile("topic-\\d+"); + final String topicTwo = "topic-500000"; + + builder.stream(topicPattern, new ConsumedInternal<>(Consumed.with(AutoOffsetReset.EARLIEST))); + builder.buildAndOptimizeTopology(); + + assertThat(builder.internalTopologyBuilder.offsetResetStrategy(topicTwo), equalTo(OffsetResetStrategy.EARLIEST)); + } + + @Test + public void shouldAddRegexTopicToLatestAutoOffsetResetList() { + final Pattern topicPattern = Pattern.compile("topic-\\d+"); + final String topicTwo = "topic-1000000"; + + builder.stream(topicPattern, new ConsumedInternal<>(Consumed.with(AutoOffsetReset.LATEST))); + builder.buildAndOptimizeTopology(); + + assertThat(builder.internalTopologyBuilder.offsetResetStrategy(topicTwo), equalTo(OffsetResetStrategy.LATEST)); + } + + @Test + public void shouldHaveNullTimestampExtractorWhenNoneSupplied() { + builder.stream(Collections.singleton("topic"), consumed); + builder.buildAndOptimizeTopology(); + builder.internalTopologyBuilder.rewriteTopology(new StreamsConfig(StreamsTestUtils.getStreamsConfig(APP_ID))); + final ProcessorTopology processorTopology = builder.internalTopologyBuilder.buildTopology(); + assertNull(processorTopology.source("topic").getTimestampExtractor()); + } + + @Test + public void shouldUseProvidedTimestampExtractor() { + final ConsumedInternal consumed = new ConsumedInternal<>(Consumed.with(new MockTimestampExtractor())); + builder.stream(Collections.singleton("topic"), consumed); + builder.buildAndOptimizeTopology(); + final ProcessorTopology processorTopology = builder.internalTopologyBuilder + .rewriteTopology(new StreamsConfig(StreamsTestUtils.getStreamsConfig(APP_ID))) + .buildTopology(); + assertThat(processorTopology.source("topic").getTimestampExtractor(), instanceOf(MockTimestampExtractor.class)); + } + + @Test + public void ktableShouldHaveNullTimestampExtractorWhenNoneSupplied() { + 
builder.table("topic", consumed, materialized); + builder.buildAndOptimizeTopology(); + final ProcessorTopology processorTopology = builder.internalTopologyBuilder + .rewriteTopology(new StreamsConfig(StreamsTestUtils.getStreamsConfig(APP_ID))) + .buildTopology(); + assertNull(processorTopology.source("topic").getTimestampExtractor()); + } + + @Test + public void ktableShouldUseProvidedTimestampExtractor() { + final ConsumedInternal consumed = new ConsumedInternal<>(Consumed.with(new MockTimestampExtractor())); + builder.table("topic", consumed, materialized); + builder.buildAndOptimizeTopology(); + final ProcessorTopology processorTopology = builder.internalTopologyBuilder + .rewriteTopology(new StreamsConfig(StreamsTestUtils.getStreamsConfig(APP_ID))) + .buildTopology(); + assertThat(processorTopology.source("topic").getTimestampExtractor(), instanceOf(MockTimestampExtractor.class)); + } + +} diff --git a/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KGroupedStreamImplTest.java b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KGroupedStreamImplTest.java new file mode 100644 index 0000000..333884e --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KGroupedStreamImplTest.java @@ -0,0 +1,769 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.kstream.internals; + +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.serialization.StringSerializer; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.streams.KeyValueTimestamp; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.TestInputTopic; +import org.apache.kafka.streams.TopologyTestDriver; +import org.apache.kafka.streams.errors.TopologyException; +import org.apache.kafka.streams.kstream.Consumed; +import org.apache.kafka.streams.kstream.Grouped; +import org.apache.kafka.streams.kstream.KGroupedStream; +import org.apache.kafka.streams.kstream.KStream; +import org.apache.kafka.streams.kstream.KTable; +import org.apache.kafka.streams.kstream.Materialized; +import org.apache.kafka.streams.kstream.SessionWindows; +import org.apache.kafka.streams.kstream.SlidingWindows; +import org.apache.kafka.streams.kstream.TimeWindows; +import org.apache.kafka.streams.kstream.Windowed; +import org.apache.kafka.streams.kstream.Windows; +import org.apache.kafka.streams.processor.internals.testutil.LogCaptureAppender; +import org.apache.kafka.streams.state.KeyValueStore; +import org.apache.kafka.streams.state.SessionStore; +import org.apache.kafka.streams.state.ValueAndTimestamp; +import org.apache.kafka.test.MockAggregator; +import org.apache.kafka.test.MockInitializer; +import org.apache.kafka.test.MockProcessorSupplier; +import org.apache.kafka.test.MockReducer; +import org.apache.kafka.test.StreamsTestUtils; +import org.junit.Before; +import org.junit.Test; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Comparator; +import java.util.Map; +import java.util.Properties; + +import static java.time.Duration.ofMillis; +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.CoreMatchers.hasItem; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThrows; + +@SuppressWarnings("deprecation") // Old PAPI. Needs to be migrated. 
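+// Exercises KGroupedStream argument validation plus time-, sliding-, and session-windowed count/reduce/aggregate, driven through TopologyTestDriver.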
+public class KGroupedStreamImplTest { + + private static final String TOPIC = "topic"; + private static final String INVALID_STORE_NAME = "~foo bar~"; + private final StreamsBuilder builder = new StreamsBuilder(); + private KGroupedStream groupedStream; + + private final Properties props = StreamsTestUtils.getStreamsConfig(Serdes.String(), Serdes.String()); + + @Before + public void before() { + final KStream stream = builder.stream(TOPIC, Consumed.with(Serdes.String(), Serdes.String())); + groupedStream = stream.groupByKey(Grouped.with(Serdes.String(), Serdes.String())); + } + + @Test + public void shouldNotHaveNullAggregatorOnCogroup() { + assertThrows(NullPointerException.class, () -> groupedStream.cogroup(null)); + } + + @Test + public void shouldNotHaveNullReducerOnReduce() { + assertThrows(NullPointerException.class, () -> groupedStream.reduce(null)); + } + + @Test + public void shouldNotHaveInvalidStoreNameOnReduce() { + assertThrows(TopologyException.class, () -> groupedStream.reduce(MockReducer.STRING_ADDER, Materialized.as(INVALID_STORE_NAME))); + } + + @Test + public void shouldNotHaveNullReducerWithWindowedReduce() { + assertThrows(NullPointerException.class, () -> groupedStream + .windowedBy(TimeWindows.of(ofMillis(10))) + .reduce(null, Materialized.as("store"))); + } + + @Test + public void shouldNotHaveNullWindowsWithWindowedReduce() { + assertThrows(NullPointerException.class, () -> groupedStream.windowedBy((Windows) null)); + } + + @Test + public void shouldNotHaveInvalidStoreNameWithWindowedReduce() { + assertThrows(TopologyException.class, () -> groupedStream + .windowedBy(TimeWindows.of(ofMillis(10))) + .reduce(MockReducer.STRING_ADDER, Materialized.as(INVALID_STORE_NAME))); + } + + @Test + public void shouldNotHaveNullInitializerOnAggregate() { + assertThrows(NullPointerException.class, () -> groupedStream.aggregate(null, MockAggregator.TOSTRING_ADDER, Materialized.as("store"))); + } + + @Test + public void shouldNotHaveNullAdderOnAggregate() { + assertThrows(NullPointerException.class, () -> groupedStream.aggregate(MockInitializer.STRING_INIT, null, Materialized.as("store"))); + } + + @Test + public void shouldNotHaveInvalidStoreNameOnAggregate() { + assertThrows(TopologyException.class, () -> groupedStream.aggregate( + MockInitializer.STRING_INIT, + MockAggregator.TOSTRING_ADDER, + Materialized.as(INVALID_STORE_NAME))); + } + + @Test + public void shouldNotHaveNullInitializerOnWindowedAggregate() { + assertThrows(NullPointerException.class, () -> groupedStream + .windowedBy(TimeWindows.of(ofMillis(10))) + .aggregate(null, MockAggregator.TOSTRING_ADDER, Materialized.as("store"))); + } + + @Test + public void shouldNotHaveNullAdderOnWindowedAggregate() { + assertThrows(NullPointerException.class, () -> groupedStream + .windowedBy(TimeWindows.of(ofMillis(10))) + .aggregate(MockInitializer.STRING_INIT, null, Materialized.as("store"))); + } + + @Test + public void shouldNotHaveNullWindowsOnWindowedAggregate() { + assertThrows(NullPointerException.class, () -> groupedStream.windowedBy((Windows) null)); + } + + @Test + public void shouldNotHaveInvalidStoreNameOnWindowedAggregate() { + assertThrows(TopologyException.class, () -> groupedStream + .windowedBy(TimeWindows.of(ofMillis(10))) + .aggregate(MockInitializer.STRING_INIT, MockAggregator.TOSTRING_ADDER, Materialized.as(INVALID_STORE_NAME))); + } + + @Test + public void shouldNotHaveNullReducerWithSlidingWindowedReduce() { + assertThrows(NullPointerException.class, () -> groupedStream + 
.windowedBy(SlidingWindows.withTimeDifferenceAndGrace(ofMillis(10), ofMillis(100))) + .reduce(null, Materialized.as("store"))); + } + + @Test + public void shouldNotHaveNullWindowsWithSlidingWindowedReduce() { + assertThrows(NullPointerException.class, () -> groupedStream.windowedBy((SlidingWindows) null)); + } + + @Test + public void shouldNotHaveInvalidStoreNameWithSlidingWindowedReduce() { + assertThrows(TopologyException.class, () -> groupedStream + .windowedBy(SlidingWindows.withTimeDifferenceAndGrace(ofMillis(10), ofMillis(100))) + .reduce(MockReducer.STRING_ADDER, Materialized.as(INVALID_STORE_NAME))); + } + + @Test + public void shouldNotHaveNullInitializerOnSlidingWindowedAggregate() { + assertThrows(NullPointerException.class, () -> groupedStream + .windowedBy(SlidingWindows.withTimeDifferenceAndGrace(ofMillis(10), ofMillis(100))) + .aggregate(null, MockAggregator.TOSTRING_ADDER, Materialized.as("store"))); + } + + @Test + public void shouldNotHaveNullAdderOnSlidingWindowedAggregate() { + assertThrows(NullPointerException.class, () -> groupedStream + .windowedBy(SlidingWindows.withTimeDifferenceAndGrace(ofMillis(10), ofMillis(100))) + .aggregate(MockInitializer.STRING_INIT, null, Materialized.as("store"))); + } + + @Test + public void shouldNotHaveInvalidStoreNameOnSlidingWindowedAggregate() { + assertThrows(TopologyException.class, () -> groupedStream + .windowedBy(SlidingWindows.withTimeDifferenceAndGrace(ofMillis(10), ofMillis(100))) + .aggregate(MockInitializer.STRING_INIT, MockAggregator.TOSTRING_ADDER, Materialized.as(INVALID_STORE_NAME))); + } + + @Test + public void shouldCountSlidingWindows() { + final MockProcessorSupplier, Long> supplier = new MockProcessorSupplier<>(); + groupedStream + .windowedBy(SlidingWindows.withTimeDifferenceAndGrace(ofMillis(500L), ofMillis(2000L))) + .count(Materialized.as("aggregate-by-key-windowed")) + .toStream() + .process(supplier); + + doCountSlidingWindows(supplier); + } + + @Test + public void shouldCountSlidingWindowsWithInternalStoreName() { + final MockProcessorSupplier, Long> supplier = new MockProcessorSupplier<>(); + groupedStream + .windowedBy(SlidingWindows.withTimeDifferenceAndGrace(ofMillis(500L), ofMillis(2000L))) + .count() + .toStream() + .process(supplier); + + doCountSlidingWindows(supplier); + } + + private void doCountSlidingWindows(final MockProcessorSupplier, Long> supplier) { + try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) { + final TestInputTopic inputTopic = + driver.createInputTopic(TOPIC, new StringSerializer(), new StringSerializer()); + inputTopic.pipeInput("1", "A", 500L); + inputTopic.pipeInput("1", "A", 999L); + inputTopic.pipeInput("1", "A", 600L); + inputTopic.pipeInput("2", "B", 500L); + inputTopic.pipeInput("2", "B", 600L); + inputTopic.pipeInput("2", "B", 700L); + inputTopic.pipeInput("3", "C", 501L); + inputTopic.pipeInput("1", "A", 1000L); + inputTopic.pipeInput("1", "A", 1000L); + inputTopic.pipeInput("2", "B", 1000L); + inputTopic.pipeInput("2", "B", 1000L); + inputTopic.pipeInput("3", "C", 600L); + } + + final Comparator, Long>> comparator = + Comparator.comparing((KeyValueTimestamp, Long> o) -> o.key().key()) + .thenComparing((KeyValueTimestamp, Long> o) -> o.key().window().start()); + + final ArrayList, Long>> actual = supplier.theCapturedProcessor().processed(); + actual.sort(comparator); + + assertThat(actual, equalTo(Arrays.asList( + // processing A@500 + new KeyValueTimestamp<>(new Windowed<>("1", new TimeWindow(0L, 500L)), 1L, 500L), + // processing 
A@600 + new KeyValueTimestamp<>(new Windowed<>("1", new TimeWindow(100L, 600L)), 2L, 600L), + // processing A@999 + new KeyValueTimestamp<>(new Windowed<>("1", new TimeWindow(499L, 999L)), 2L, 999L), + // processing A@600 + new KeyValueTimestamp<>(new Windowed<>("1", new TimeWindow(499L, 999L)), 3L, 999L), + // processing first A@1000 + new KeyValueTimestamp<>(new Windowed<>("1", new TimeWindow(500L, 1000L)), 4L, 1000L), + // processing second A@1000 + new KeyValueTimestamp<>(new Windowed<>("1", new TimeWindow(500L, 1000L)), 5L, 1000L), + // processing A@999 + new KeyValueTimestamp<>(new Windowed<>("1", new TimeWindow(501L, 1001L)), 1L, 999L), + // processing A@600 + new KeyValueTimestamp<>(new Windowed<>("1", new TimeWindow(501L, 1001L)), 2L, 999L), + // processing first A@1000 + new KeyValueTimestamp<>(new Windowed<>("1", new TimeWindow(501L, 1001L)), 3L, 1000L), + // processing second A@1000 + new KeyValueTimestamp<>(new Windowed<>("1", new TimeWindow(501L, 1001L)), 4L, 1000L), + // processing A@600 + new KeyValueTimestamp<>(new Windowed<>("1", new TimeWindow(601L, 1101L)), 1L, 999L), + // processing first A@1000 + new KeyValueTimestamp<>(new Windowed<>("1", new TimeWindow(601L, 1101L)), 2L, 1000L), + // processing second A@1000 + new KeyValueTimestamp<>(new Windowed<>("1", new TimeWindow(601L, 1101L)), 3L, 1000L), + // processing first A@1000 + new KeyValueTimestamp<>(new Windowed<>("1", new TimeWindow(1000L, 1500L)), 1L, 1000L), + // processing second A@1000 + new KeyValueTimestamp<>(new Windowed<>("1", new TimeWindow(1000L, 1500L)), 2L, 1000L), + + // processing B@500 + new KeyValueTimestamp<>(new Windowed<>("2", new TimeWindow(0L, 500L)), 1L, 500L), + // processing B@600 + new KeyValueTimestamp<>(new Windowed<>("2", new TimeWindow(100L, 600L)), 2L, 600L), + // processing B@700 + new KeyValueTimestamp<>(new Windowed<>("2", new TimeWindow(200L, 700L)), 3L, 700L), + // processing first B@1000 + new KeyValueTimestamp<>(new Windowed<>("2", new TimeWindow(500L, 1000L)), 4L, 1000L), + // processing second B@1000 + new KeyValueTimestamp<>(new Windowed<>("2", new TimeWindow(500L, 1000L)), 5L, 1000L), + // processing B@600 + new KeyValueTimestamp<>(new Windowed<>("2", new TimeWindow(501L, 1001L)), 1L, 600L), + // processing B@700 + new KeyValueTimestamp<>(new Windowed<>("2", new TimeWindow(501L, 1001L)), 2L, 700L), + // processing first B@1000 + new KeyValueTimestamp<>(new Windowed<>("2", new TimeWindow(501L, 1001L)), 3L, 1000L), + // processing second B@1000 + new KeyValueTimestamp<>(new Windowed<>("2", new TimeWindow(501L, 1001L)), 4L, 1000L), + // processing B@700 + new KeyValueTimestamp<>(new Windowed<>("2", new TimeWindow(601L, 1101L)), 1L, 700L), + // processing first B@1000 + new KeyValueTimestamp<>(new Windowed<>("2", new TimeWindow(601L, 1101)), 2L, 1000L), + // processing second B@1000 + new KeyValueTimestamp<>(new Windowed<>("2", new TimeWindow(601L, 1101)), 3L, 1000L), + // processing first B@1000 + new KeyValueTimestamp<>(new Windowed<>("2", new TimeWindow(701L, 1201L)), 1L, 1000L), + // processing second B@1000 + new KeyValueTimestamp<>(new Windowed<>("2", new TimeWindow(701L, 1201L)), 2L, 1000L), + + // processing C@501 + new KeyValueTimestamp<>(new Windowed<>("3", new TimeWindow(1L, 501L)), 1L, 501L), + // processing C@600 + new KeyValueTimestamp<>(new Windowed<>("3", new TimeWindow(100L, 600L)), 2L, 600L), + // processing C@600 + new KeyValueTimestamp<>(new Windowed<>("3", new TimeWindow(502L, 1002L)), 1L, 600L) + ))); + } + + private void doAggregateSessionWindows(final 
MockProcessorSupplier, Integer> supplier) { + try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) { + final TestInputTopic inputTopic = + driver.createInputTopic(TOPIC, new StringSerializer(), new StringSerializer()); + inputTopic.pipeInput("1", "1", 10); + inputTopic.pipeInput("2", "2", 15); + inputTopic.pipeInput("1", "1", 30); + inputTopic.pipeInput("1", "1", 70); + inputTopic.pipeInput("1", "1", 100); + inputTopic.pipeInput("1", "1", 90); + } + final Map, ValueAndTimestamp> result + = supplier.theCapturedProcessor().lastValueAndTimestampPerKey(); + assertEquals( + ValueAndTimestamp.make(2, 30L), + result.get(new Windowed<>("1", new SessionWindow(10L, 30L)))); + assertEquals( + ValueAndTimestamp.make(1, 15L), + result.get(new Windowed<>("2", new SessionWindow(15L, 15L)))); + assertEquals( + ValueAndTimestamp.make(3, 100L), + result.get(new Windowed<>("1", new SessionWindow(70L, 100L)))); + } + + @Test + public void shouldAggregateSessionWindows() { + final MockProcessorSupplier, Integer> supplier = new MockProcessorSupplier<>(); + final KTable, Integer> table = groupedStream + .windowedBy(SessionWindows.with(ofMillis(30))) + .aggregate( + () -> 0, + (aggKey, value, aggregate) -> aggregate + 1, + (aggKey, aggOne, aggTwo) -> aggOne + aggTwo, + Materialized + .>as("session-store"). + withValueSerde(Serdes.Integer())); + table.toStream().process(supplier); + + doAggregateSessionWindows(supplier); + assertEquals(table.queryableStoreName(), "session-store"); + } + + @Test + public void shouldAggregateSessionWindowsWithInternalStoreName() { + final MockProcessorSupplier, Integer> supplier = new MockProcessorSupplier<>(); + final KTable, Integer> table = groupedStream + .windowedBy(SessionWindows.with(ofMillis(30))) + .aggregate( + () -> 0, + (aggKey, value, aggregate) -> aggregate + 1, + (aggKey, aggOne, aggTwo) -> aggOne + aggTwo, + Materialized.with(null, Serdes.Integer())); + table.toStream().process(supplier); + + doAggregateSessionWindows(supplier); + } + + private void doCountSessionWindows(final MockProcessorSupplier, Long> supplier) { + + try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) { + final TestInputTopic inputTopic = + driver.createInputTopic(TOPIC, new StringSerializer(), new StringSerializer()); + inputTopic.pipeInput("1", "1", 10); + inputTopic.pipeInput("2", "2", 15); + inputTopic.pipeInput("1", "1", 30); + inputTopic.pipeInput("1", "1", 70); + inputTopic.pipeInput("1", "1", 100); + inputTopic.pipeInput("1", "1", 90); + } + final Map, ValueAndTimestamp> result = + supplier.theCapturedProcessor().lastValueAndTimestampPerKey(); + assertEquals( + ValueAndTimestamp.make(2L, 30L), + result.get(new Windowed<>("1", new SessionWindow(10L, 30L)))); + assertEquals( + ValueAndTimestamp.make(1L, 15L), + result.get(new Windowed<>("2", new SessionWindow(15L, 15L)))); + assertEquals( + ValueAndTimestamp.make(3L, 100L), + result.get(new Windowed<>("1", new SessionWindow(70L, 100L)))); + } + + @Test + public void shouldCountSessionWindows() { + final MockProcessorSupplier, Long> supplier = new MockProcessorSupplier<>(); + final KTable, Long> table = groupedStream + .windowedBy(SessionWindows.with(ofMillis(30))) + .count(Materialized.as("session-store")); + table.toStream().process(supplier); + doCountSessionWindows(supplier); + assertEquals(table.queryableStoreName(), "session-store"); + } + + @Test + public void shouldCountSessionWindowsWithInternalStoreName() { + final MockProcessorSupplier, Long> supplier = new 
MockProcessorSupplier<>(); + final KTable, Long> table = groupedStream + .windowedBy(SessionWindows.with(ofMillis(30))) + .count(); + table.toStream().process(supplier); + doCountSessionWindows(supplier); + assertNull(table.queryableStoreName()); + } + + private void doReduceSessionWindows(final MockProcessorSupplier, String> supplier) { + try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) { + final TestInputTopic inputTopic = + driver.createInputTopic(TOPIC, new StringSerializer(), new StringSerializer()); + inputTopic.pipeInput("1", "A", 10); + inputTopic.pipeInput("2", "Z", 15); + inputTopic.pipeInput("1", "B", 30); + inputTopic.pipeInput("1", "A", 70); + inputTopic.pipeInput("1", "B", 100); + inputTopic.pipeInput("1", "C", 90); + } + final Map, ValueAndTimestamp> result = + supplier.theCapturedProcessor().lastValueAndTimestampPerKey(); + assertEquals( + ValueAndTimestamp.make("A:B", 30L), + result.get(new Windowed<>("1", new SessionWindow(10L, 30L)))); + assertEquals( + ValueAndTimestamp.make("Z", 15L), + result.get(new Windowed<>("2", new SessionWindow(15L, 15L)))); + assertEquals( + ValueAndTimestamp.make("A:B:C", 100L), + result.get(new Windowed<>("1", new SessionWindow(70L, 100L)))); + } + + @Test + public void shouldReduceSessionWindows() { + final MockProcessorSupplier, String> supplier = new MockProcessorSupplier<>(); + final KTable, String> table = groupedStream + .windowedBy(SessionWindows.with(ofMillis(30))) + .reduce((value1, value2) -> value1 + ":" + value2, Materialized.as("session-store")); + table.toStream().process(supplier); + doReduceSessionWindows(supplier); + assertEquals(table.queryableStoreName(), "session-store"); + } + + @Test + public void shouldReduceSessionWindowsWithInternalStoreName() { + final MockProcessorSupplier, String> supplier = new MockProcessorSupplier<>(); + final KTable, String> table = groupedStream + .windowedBy(SessionWindows.with(ofMillis(30))) + .reduce((value1, value2) -> value1 + ":" + value2); + table.toStream().process(supplier); + doReduceSessionWindows(supplier); + assertNull(table.queryableStoreName()); + } + + @Test + public void shouldNotAcceptNullReducerWhenReducingSessionWindows() { + assertThrows(NullPointerException.class, () -> groupedStream + .windowedBy(SessionWindows.with(ofMillis(30))) + .reduce(null, Materialized.as("store"))); + } + + @Test + public void shouldNotAcceptNullSessionWindowsReducingSessionWindows() { + assertThrows(NullPointerException.class, () -> groupedStream.windowedBy((SessionWindows) null)); + } + + @Test + public void shouldNotAcceptInvalidStoreNameWhenReducingSessionWindows() { + assertThrows(TopologyException.class, () -> groupedStream + .windowedBy(SessionWindows.with(ofMillis(30))) + .reduce(MockReducer.STRING_ADDER, Materialized.as(INVALID_STORE_NAME)) + ); + } + + @Test + public void shouldNotAcceptNullStateStoreSupplierWhenReducingSessionWindows() { + assertThrows(NullPointerException.class, () -> groupedStream + .windowedBy(SessionWindows.with(ofMillis(30))) + .reduce(null, Materialized.>as(null)) + ); + } + + @Test + public void shouldNotAcceptNullInitializerWhenAggregatingSessionWindows() { + assertThrows(NullPointerException.class, () -> groupedStream + .windowedBy(SessionWindows.with(ofMillis(30))) + .aggregate(null, MockAggregator.TOSTRING_ADDER, (aggKey, aggOne, aggTwo) -> null, Materialized.as("storeName")) + ); + } + + @Test + public void shouldNotAcceptNullAggregatorWhenAggregatingSessionWindows() { + assertThrows(NullPointerException.class, () -> 
groupedStream. + windowedBy(SessionWindows.with(ofMillis(30))) + .aggregate(MockInitializer.STRING_INIT, null, (aggKey, aggOne, aggTwo) -> null, Materialized.as("storeName")) + ); + } + + @Test + public void shouldNotAcceptNullSessionMergerWhenAggregatingSessionWindows() { + assertThrows(NullPointerException.class, () -> groupedStream + .windowedBy(SessionWindows.with(ofMillis(30))) + .aggregate(MockInitializer.STRING_INIT, MockAggregator.TOSTRING_ADDER, null, Materialized.as("storeName")) + ); + } + + @Test + public void shouldNotAcceptNullSessionWindowsWhenAggregatingSessionWindows() { + assertThrows(NullPointerException.class, () -> groupedStream.windowedBy((SessionWindows) null)); + } + + @Test + public void shouldAcceptNullStoreNameWhenAggregatingSessionWindows() { + groupedStream + .windowedBy(SessionWindows.with(ofMillis(10))) + .aggregate( + MockInitializer.STRING_INIT, + MockAggregator.TOSTRING_ADDER, + (aggKey, aggOne, aggTwo) -> null, Materialized.with(Serdes.String(), Serdes.String()) + ); + } + + @Test + public void shouldNotAcceptInvalidStoreNameWhenAggregatingSessionWindows() { + assertThrows(TopologyException.class, () -> groupedStream + .windowedBy(SessionWindows.with(ofMillis(10))) + .aggregate(MockInitializer.STRING_INIT, MockAggregator.TOSTRING_ADDER, (aggKey, aggOne, aggTwo) -> null, Materialized.as(INVALID_STORE_NAME)) + ); + } + + @Test + public void shouldThrowNullPointerOnReduceWhenMaterializedIsNull() { + assertThrows(NullPointerException.class, () -> groupedStream.reduce(MockReducer.STRING_ADDER, null)); + } + + @Test + public void shouldThrowNullPointerOnAggregateWhenMaterializedIsNull() { + assertThrows(NullPointerException.class, () -> groupedStream.aggregate(MockInitializer.STRING_INIT, MockAggregator.TOSTRING_ADDER, null)); + } + + @Test + public void shouldThrowNullPointerOnCountWhenMaterializedIsNull() { + assertThrows(NullPointerException.class, () -> groupedStream.count((Materialized>) null)); + } + + @Test + public void shouldCountAndMaterializeResults() { + groupedStream.count(Materialized.>as("count").withKeySerde(Serdes.String())); + + try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) { + processData(driver); + + { + final KeyValueStore count = driver.getKeyValueStore("count"); + + assertThat(count.get("1"), equalTo(3L)); + assertThat(count.get("2"), equalTo(1L)); + assertThat(count.get("3"), equalTo(2L)); + } + { + final KeyValueStore> count = driver.getTimestampedKeyValueStore("count"); + + assertThat(count.get("1"), equalTo(ValueAndTimestamp.make(3L, 10L))); + assertThat(count.get("2"), equalTo(ValueAndTimestamp.make(1L, 1L))); + assertThat(count.get("3"), equalTo(ValueAndTimestamp.make(2L, 9L))); + } + } + } + + @Test + public void shouldLogAndMeasureSkipsInAggregate() { + groupedStream.count(Materialized.>as("count").withKeySerde(Serdes.String())); + + try (final LogCaptureAppender appender = LogCaptureAppender.createAndRegister(KStreamAggregate.class); + final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) { + + processData(driver); + + assertThat( + appender.getMessages(), + hasItem("Skipping record due to null key or value. 
topic=[topic] partition=[0] " + + "offset=[6]") + ); + } + } + + @Test + public void shouldReduceAndMaterializeResults() { + groupedStream.reduce( + MockReducer.STRING_ADDER, + Materialized.>as("reduce") + .withKeySerde(Serdes.String()) + .withValueSerde(Serdes.String())); + + try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) { + processData(driver); + + { + final KeyValueStore reduced = driver.getKeyValueStore("reduce"); + + assertThat(reduced.get("1"), equalTo("A+C+D")); + assertThat(reduced.get("2"), equalTo("B")); + assertThat(reduced.get("3"), equalTo("E+F")); + } + { + final KeyValueStore> reduced = driver.getTimestampedKeyValueStore("reduce"); + + assertThat(reduced.get("1"), equalTo(ValueAndTimestamp.make("A+C+D", 10L))); + assertThat(reduced.get("2"), equalTo(ValueAndTimestamp.make("B", 1L))); + assertThat(reduced.get("3"), equalTo(ValueAndTimestamp.make("E+F", 9L))); + } + } + } + + @Test + public void shouldLogAndMeasureSkipsInReduce() { + groupedStream.reduce( + MockReducer.STRING_ADDER, + Materialized.>as("reduce") + .withKeySerde(Serdes.String()) + .withValueSerde(Serdes.String()) + ); + + try (final LogCaptureAppender appender = LogCaptureAppender.createAndRegister(KStreamReduce.class); + final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) { + + processData(driver); + + assertThat( + appender.getMessages(), + hasItem("Skipping record due to null key or value. topic=[topic] partition=[0] " + + "offset=[6]") + ); + } + } + + @Test + public void shouldAggregateAndMaterializeResults() { + groupedStream.aggregate( + MockInitializer.STRING_INIT, + MockAggregator.TOSTRING_ADDER, + Materialized.>as("aggregate") + .withKeySerde(Serdes.String()) + .withValueSerde(Serdes.String())); + + try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) { + processData(driver); + + { + final KeyValueStore aggregate = driver.getKeyValueStore("aggregate"); + + assertThat(aggregate.get("1"), equalTo("0+A+C+D")); + assertThat(aggregate.get("2"), equalTo("0+B")); + assertThat(aggregate.get("3"), equalTo("0+E+F")); + } + { + final KeyValueStore> aggregate = driver.getTimestampedKeyValueStore("aggregate"); + + assertThat(aggregate.get("1"), equalTo(ValueAndTimestamp.make("0+A+C+D", 10L))); + assertThat(aggregate.get("2"), equalTo(ValueAndTimestamp.make("0+B", 1L))); + assertThat(aggregate.get("3"), equalTo(ValueAndTimestamp.make("0+E+F", 9L))); + } + } + } + + @Test + public void shouldAggregateWithDefaultSerdes() { + final MockProcessorSupplier supplier = new MockProcessorSupplier<>(); + groupedStream + .aggregate(MockInitializer.STRING_INIT, MockAggregator.TOSTRING_ADDER) + .toStream() + .process(supplier); + + try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) { + processData(driver); + + assertThat( + supplier.theCapturedProcessor().lastValueAndTimestampPerKey().get("1"), + equalTo(ValueAndTimestamp.make("0+A+C+D", 10L))); + assertThat( + supplier.theCapturedProcessor().lastValueAndTimestampPerKey().get("2"), + equalTo(ValueAndTimestamp.make("0+B", 1L))); + assertThat( + supplier.theCapturedProcessor().lastValueAndTimestampPerKey().get("3"), + equalTo(ValueAndTimestamp.make("0+E+F", 9L))); + } + } + + private void processData(final TopologyTestDriver driver) { + final TestInputTopic inputTopic = + driver.createInputTopic(TOPIC, new StringSerializer(), new StringSerializer()); + inputTopic.pipeInput("1", "A", 5L); + inputTopic.pipeInput("2", "B", 1L); + 
inputTopic.pipeInput("1", "C", 3L); + inputTopic.pipeInput("1", "D", 10L); + inputTopic.pipeInput("3", "E", 8L); + inputTopic.pipeInput("3", "F", 9L); + inputTopic.pipeInput("3", (String) null); + } + + private void doCountWindowed(final MockProcessorSupplier, Long> supplier) { + try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) { + final TestInputTopic inputTopic = + driver.createInputTopic(TOPIC, new StringSerializer(), new StringSerializer()); + inputTopic.pipeInput("1", "A", 0L); + inputTopic.pipeInput("1", "A", 499L); + inputTopic.pipeInput("1", "A", 100L); + inputTopic.pipeInput("2", "B", 0L); + inputTopic.pipeInput("2", "B", 100L); + inputTopic.pipeInput("2", "B", 200L); + inputTopic.pipeInput("3", "C", 1L); + inputTopic.pipeInput("1", "A", 500L); + inputTopic.pipeInput("1", "A", 500L); + inputTopic.pipeInput("2", "B", 500L); + inputTopic.pipeInput("2", "B", 500L); + inputTopic.pipeInput("3", "B", 100L); + } + assertThat(supplier.theCapturedProcessor().processed(), equalTo(Arrays.asList( + new KeyValueTimestamp<>(new Windowed<>("1", new TimeWindow(0L, 500L)), 1L, 0L), + new KeyValueTimestamp<>(new Windowed<>("1", new TimeWindow(0L, 500L)), 2L, 499L), + new KeyValueTimestamp<>(new Windowed<>("1", new TimeWindow(0L, 500L)), 3L, 499L), + new KeyValueTimestamp<>(new Windowed<>("2", new TimeWindow(0L, 500L)), 1L, 0L), + new KeyValueTimestamp<>(new Windowed<>("2", new TimeWindow(0L, 500L)), 2L, 100L), + new KeyValueTimestamp<>(new Windowed<>("2", new TimeWindow(0L, 500L)), 3L, 200L), + new KeyValueTimestamp<>(new Windowed<>("3", new TimeWindow(0L, 500L)), 1L, 1L), + new KeyValueTimestamp<>(new Windowed<>("1", new TimeWindow(500L, 1000L)), 1L, 500L), + new KeyValueTimestamp<>(new Windowed<>("1", new TimeWindow(500L, 1000L)), 2L, 500L), + new KeyValueTimestamp<>(new Windowed<>("2", new TimeWindow(500L, 1000L)), 1L, 500L), + new KeyValueTimestamp<>(new Windowed<>("2", new TimeWindow(500L, 1000L)), 2L, 500L), + new KeyValueTimestamp<>(new Windowed<>("3", new TimeWindow(0L, 500L)), 2L, 100L) + ))); + } + + @Test + public void shouldCountWindowed() { + final MockProcessorSupplier, Long> supplier = new MockProcessorSupplier<>(); + groupedStream + .windowedBy(TimeWindows.of(ofMillis(500L))) + .count(Materialized.as("aggregate-by-key-windowed")) + .toStream() + .process(supplier); + + doCountWindowed(supplier); + } + + @Test + public void shouldCountWindowedWithInternalStoreName() { + final MockProcessorSupplier, Long> supplier = new MockProcessorSupplier<>(); + groupedStream + .windowedBy(TimeWindows.of(ofMillis(500L))) + .count() + .toStream() + .process(supplier); + + doCountWindowed(supplier); + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KGroupedTableImplTest.java b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KGroupedTableImplTest.java new file mode 100644 index 0000000..3b5295c --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KGroupedTableImplTest.java @@ -0,0 +1,377 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.kstream.internals; + +import org.apache.kafka.common.serialization.DoubleSerializer; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.serialization.StringSerializer; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.TopologyTestDriver; +import org.apache.kafka.streams.errors.TopologyException; +import org.apache.kafka.streams.kstream.Consumed; +import org.apache.kafka.streams.kstream.Grouped; +import org.apache.kafka.streams.kstream.KGroupedTable; +import org.apache.kafka.streams.kstream.KTable; +import org.apache.kafka.streams.kstream.KeyValueMapper; +import org.apache.kafka.streams.kstream.Materialized; +import org.apache.kafka.streams.state.KeyValueStore; +import org.apache.kafka.streams.state.ValueAndTimestamp; +import org.apache.kafka.streams.TestInputTopic; +import org.apache.kafka.test.MockAggregator; +import org.apache.kafka.test.MockApiProcessorSupplier; +import org.apache.kafka.test.MockInitializer; +import org.apache.kafka.test.MockMapper; +import org.apache.kafka.test.MockReducer; +import org.apache.kafka.test.StreamsTestUtils; +import org.junit.Before; +import org.junit.Test; + +import java.util.Map; +import java.util.Properties; + +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThrows; + +public class KGroupedTableImplTest { + + private final StreamsBuilder builder = new StreamsBuilder(); + private static final String INVALID_STORE_NAME = "~foo bar~"; + private KGroupedTable groupedTable; + private final Properties props = StreamsTestUtils.getStreamsConfig(Serdes.String(), Serdes.Integer()); + private final String topic = "input"; + + @Before + public void before() { + groupedTable = builder + .table("blah", Consumed.with(Serdes.String(), Serdes.String())) + .groupBy(MockMapper.selectValueKeyValueMapper()); + } + + @Test + public void shouldNotAllowInvalidStoreNameOnAggregate() { + assertThrows(TopologyException.class, () -> groupedTable.aggregate( + MockInitializer.STRING_INIT, + MockAggregator.TOSTRING_ADDER, + MockAggregator.TOSTRING_REMOVER, + Materialized.as(INVALID_STORE_NAME))); + } + + @Test + public void shouldNotAllowNullInitializerOnAggregate() { + assertThrows(NullPointerException.class, () -> groupedTable.aggregate( + null, + MockAggregator.TOSTRING_ADDER, + MockAggregator.TOSTRING_REMOVER, + Materialized.as("store"))); + } + + @Test + public void shouldNotAllowNullAdderOnAggregate() { + assertThrows(NullPointerException.class, () -> groupedTable.aggregate( + MockInitializer.STRING_INIT, + null, + MockAggregator.TOSTRING_REMOVER, + Materialized.as("store"))); + } + + @Test + public void shouldNotAllowNullSubtractorOnAggregate() { + assertThrows(NullPointerException.class, () -> groupedTable.aggregate( + MockInitializer.STRING_INIT, + MockAggregator.TOSTRING_ADDER, + null, + 
Materialized.as("store"))); + } + + @Test + public void shouldNotAllowNullAdderOnReduce() { + assertThrows(NullPointerException.class, () -> groupedTable.reduce( + null, + MockReducer.STRING_REMOVER, + Materialized.as("store"))); + } + + @Test + public void shouldNotAllowNullSubtractorOnReduce() { + assertThrows(NullPointerException.class, () -> groupedTable.reduce( + MockReducer.STRING_ADDER, + null, + Materialized.as("store"))); + } + + @Test + public void shouldNotAllowInvalidStoreNameOnReduce() { + assertThrows(TopologyException.class, () -> groupedTable.reduce( + MockReducer.STRING_ADDER, + MockReducer.STRING_REMOVER, + Materialized.as(INVALID_STORE_NAME))); + } + + private MockApiProcessorSupplier getReducedResults(final KTable inputKTable) { + final MockApiProcessorSupplier supplier = new MockApiProcessorSupplier<>(); + inputKTable + .toStream() + .process(supplier); + return supplier; + } + + private void assertReduced(final Map> reducedResults, + final String topic, + final TopologyTestDriver driver) { + final TestInputTopic inputTopic = + driver.createInputTopic(topic, new StringSerializer(), new DoubleSerializer()); + inputTopic.pipeInput("A", 1.1, 10); + inputTopic.pipeInput("B", 2.2, 11); + + assertEquals(ValueAndTimestamp.make(1, 10L), reducedResults.get("A")); + assertEquals(ValueAndTimestamp.make(2, 11L), reducedResults.get("B")); + + inputTopic.pipeInput("A", 2.6, 30); + inputTopic.pipeInput("B", 1.3, 30); + inputTopic.pipeInput("A", 5.7, 50); + inputTopic.pipeInput("B", 6.2, 20); + + assertEquals(ValueAndTimestamp.make(5, 50L), reducedResults.get("A")); + assertEquals(ValueAndTimestamp.make(6, 30L), reducedResults.get("B")); + } + + @Test + public void shouldReduce() { + final KeyValueMapper> intProjection = + (key, value) -> KeyValue.pair(key, value.intValue()); + + final KTable reduced = builder + .table( + topic, + Consumed.with(Serdes.String(), Serdes.Double()), + Materialized.>as("store") + .withKeySerde(Serdes.String()) + .withValueSerde(Serdes.Double())) + .groupBy(intProjection) + .reduce( + MockReducer.INTEGER_ADDER, + MockReducer.INTEGER_SUBTRACTOR, + Materialized.as("reduced")); + + final MockApiProcessorSupplier supplier = getReducedResults(reduced); + try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) { + assertReduced(supplier.theCapturedProcessor().lastValueAndTimestampPerKey(), topic, driver); + assertEquals(reduced.queryableStoreName(), "reduced"); + } + } + + @Test + public void shouldReduceWithInternalStoreName() { + final KeyValueMapper> intProjection = + (key, value) -> KeyValue.pair(key, value.intValue()); + + final KTable reduced = builder + .table( + topic, + Consumed.with(Serdes.String(), Serdes.Double()), + Materialized.>as("store") + .withKeySerde(Serdes.String()) + .withValueSerde(Serdes.Double())) + .groupBy(intProjection) + .reduce(MockReducer.INTEGER_ADDER, MockReducer.INTEGER_SUBTRACTOR); + + final MockApiProcessorSupplier supplier = getReducedResults(reduced); + try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) { + assertReduced(supplier.theCapturedProcessor().lastValueAndTimestampPerKey(), topic, driver); + assertNull(reduced.queryableStoreName()); + } + } + + @Test + public void shouldReduceAndMaterializeResults() { + final KeyValueMapper> intProjection = + (key, value) -> KeyValue.pair(key, value.intValue()); + + final KTable reduced = builder + .table( + topic, + Consumed.with(Serdes.String(), Serdes.Double())) + .groupBy(intProjection) + .reduce( + 
MockReducer.INTEGER_ADDER, + MockReducer.INTEGER_SUBTRACTOR, + Materialized.>as("reduce") + .withKeySerde(Serdes.String()) + .withValueSerde(Serdes.Integer())); + + final MockApiProcessorSupplier supplier = getReducedResults(reduced); + try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) { + assertReduced(supplier.theCapturedProcessor().lastValueAndTimestampPerKey(), topic, driver); + { + final KeyValueStore reduce = driver.getKeyValueStore("reduce"); + assertThat(reduce.get("A"), equalTo(5)); + assertThat(reduce.get("B"), equalTo(6)); + } + { + final KeyValueStore> reduce = driver.getTimestampedKeyValueStore("reduce"); + assertThat(reduce.get("A"), equalTo(ValueAndTimestamp.make(5, 50L))); + assertThat(reduce.get("B"), equalTo(ValueAndTimestamp.make(6, 30L))); + } + } + } + + @Test + public void shouldCountAndMaterializeResults() { + builder + .table( + topic, + Consumed.with(Serdes.String(), Serdes.String())) + .groupBy( + MockMapper.selectValueKeyValueMapper(), + Grouped.with(Serdes.String(), Serdes.String())) + .count( + Materialized.>as("count") + .withKeySerde(Serdes.String()) + .withValueSerde(Serdes.Long())); + + try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) { + processData(topic, driver); + { + final KeyValueStore counts = driver.getKeyValueStore("count"); + assertThat(counts.get("1"), equalTo(3L)); + assertThat(counts.get("2"), equalTo(2L)); + } + { + final KeyValueStore> counts = driver.getTimestampedKeyValueStore("count"); + assertThat(counts.get("1"), equalTo(ValueAndTimestamp.make(3L, 50L))); + assertThat(counts.get("2"), equalTo(ValueAndTimestamp.make(2L, 60L))); + } + } + } + + @Test + public void shouldAggregateAndMaterializeResults() { + builder + .table( + topic, + Consumed.with(Serdes.String(), Serdes.String())) + .groupBy( + MockMapper.selectValueKeyValueMapper(), + Grouped.with(Serdes.String(), Serdes.String())) + .aggregate( + MockInitializer.STRING_INIT, + MockAggregator.TOSTRING_ADDER, + MockAggregator.TOSTRING_REMOVER, + Materialized.>as("aggregate") + .withValueSerde(Serdes.String()) + .withKeySerde(Serdes.String())); + + try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) { + processData(topic, driver); + { + { + final KeyValueStore aggregate = driver.getKeyValueStore("aggregate"); + assertThat(aggregate.get("1"), equalTo("0+1+1+1")); + assertThat(aggregate.get("2"), equalTo("0+2+2")); + } + { + final KeyValueStore> aggregate = driver.getTimestampedKeyValueStore("aggregate"); + assertThat(aggregate.get("1"), equalTo(ValueAndTimestamp.make("0+1+1+1", 50L))); + assertThat(aggregate.get("2"), equalTo(ValueAndTimestamp.make("0+2+2", 60L))); + } + } + } + } + + @SuppressWarnings("unchecked") + @Test + public void shouldThrowNullPointOnCountWhenMaterializedIsNull() { + assertThrows(NullPointerException.class, () -> groupedTable.count((Materialized) null)); + } + + @Test + public void shouldThrowNullPointerOnReduceWhenMaterializedIsNull() { + assertThrows(NullPointerException.class, () -> groupedTable.reduce( + MockReducer.STRING_ADDER, + MockReducer.STRING_REMOVER, + null)); + } + + @Test + public void shouldThrowNullPointerOnReduceWhenAdderIsNull() { + assertThrows(NullPointerException.class, () -> groupedTable.reduce( + null, + MockReducer.STRING_REMOVER, + Materialized.as("store"))); + } + + @Test + public void shouldThrowNullPointerOnReduceWhenSubtractorIsNull() { + assertThrows(NullPointerException.class, () -> groupedTable.reduce( + 
MockReducer.STRING_ADDER, + null, + Materialized.as("store"))); + } + + @Test + public void shouldThrowNullPointerOnAggregateWhenInitializerIsNull() { + assertThrows(NullPointerException.class, () -> groupedTable.aggregate( + null, + MockAggregator.TOSTRING_ADDER, + MockAggregator.TOSTRING_REMOVER, + Materialized.as("store"))); + } + + @Test + public void shouldThrowNullPointerOnAggregateWhenAdderIsNull() { + assertThrows(NullPointerException.class, () -> groupedTable.aggregate( + MockInitializer.STRING_INIT, + null, + MockAggregator.TOSTRING_REMOVER, + Materialized.as("store"))); + } + + @Test + public void shouldThrowNullPointerOnAggregateWhenSubtractorIsNull() { + assertThrows(NullPointerException.class, () -> groupedTable.aggregate( + MockInitializer.STRING_INIT, + MockAggregator.TOSTRING_ADDER, + null, + Materialized.as("store"))); + } + + @SuppressWarnings("unchecked") + @Test + public void shouldThrowNullPointerOnAggregateWhenMaterializedIsNull() { + assertThrows(NullPointerException.class, () -> groupedTable.aggregate( + MockInitializer.STRING_INIT, + MockAggregator.TOSTRING_ADDER, + MockAggregator.TOSTRING_REMOVER, + (Materialized) null)); + } + + private void processData(final String topic, + final TopologyTestDriver driver) { + final TestInputTopic inputTopic = + driver.createInputTopic(topic, new StringSerializer(), new StringSerializer()); + inputTopic.pipeInput("A", "1", 10L); + inputTopic.pipeInput("B", "1", 50L); + inputTopic.pipeInput("C", "1", 30L); + inputTopic.pipeInput("D", "2", 40L); + inputTopic.pipeInput("E", "2", 60L); + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KStreamBranchTest.java b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KStreamBranchTest.java new file mode 100644 index 0000000..8a731f9 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KStreamBranchTest.java @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.kstream.internals; + +import org.apache.kafka.common.serialization.IntegerSerializer; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.serialization.StringSerializer; +import org.apache.kafka.streams.kstream.Consumed; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.TopologyTestDriver; +import org.apache.kafka.streams.kstream.KStream; +import org.apache.kafka.streams.kstream.Predicate; +import org.apache.kafka.streams.TestInputTopic; +import org.apache.kafka.test.MockProcessor; +import org.apache.kafka.test.MockProcessorSupplier; +import org.apache.kafka.test.StreamsTestUtils; +import org.junit.Test; + +import java.util.List; +import java.util.Properties; + +import static org.junit.Assert.assertEquals; + +public class KStreamBranchTest { + + private final String topicName = "topic"; + private final Properties props = StreamsTestUtils.getStreamsConfig(Serdes.String(), Serdes.String()); + + @SuppressWarnings({"unchecked", "deprecation"}) // Old PAPI. Needs to be migrated. + @Test + public void testKStreamBranch() { + final StreamsBuilder builder = new StreamsBuilder(); + + final Predicate isEven = (key, value) -> (key % 2) == 0; + final Predicate isMultipleOfThree = (key, value) -> (key % 3) == 0; + final Predicate isOdd = (key, value) -> (key % 2) != 0; + + final int[] expectedKeys = new int[]{1, 2, 3, 4, 5, 6}; + + final KStream stream; + final KStream[] branches; + + stream = builder.stream(topicName, Consumed.with(Serdes.Integer(), Serdes.String())); + branches = stream.branch(isEven, isMultipleOfThree, isOdd); + + assertEquals(3, branches.length); + + final MockProcessorSupplier supplier = new MockProcessorSupplier<>(); + for (final KStream branch : branches) { + branch.process(supplier); + } + + try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) { + final TestInputTopic inputTopic = driver.createInputTopic(topicName, new IntegerSerializer(), new StringSerializer()); + for (final int expectedKey : expectedKeys) { + inputTopic.pipeInput(expectedKey, "V" + expectedKey); + } + } + + final List> processors = supplier.capturedProcessors(3); + assertEquals(3, processors.get(0).processed().size()); + assertEquals(1, processors.get(1).processed().size()); + assertEquals(2, processors.get(2).processed().size()); + } + + @SuppressWarnings({"unchecked", "deprecation"}) + @Test + public void testTypeVariance() { + final Predicate positive = (key, value) -> key.doubleValue() > 0; + + final Predicate negative = (key, value) -> key.doubleValue() < 0; + + new StreamsBuilder() + .stream("empty") + .branch(positive, negative); + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KStreamFilterTest.java b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KStreamFilterTest.java new file mode 100644 index 0000000..bc3f461 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KStreamFilterTest.java @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.streams.kstream.internals;
+
+import org.apache.kafka.common.serialization.IntegerSerializer;
+import org.apache.kafka.common.serialization.Serdes;
+import org.apache.kafka.common.serialization.StringSerializer;
+import org.apache.kafka.streams.kstream.Consumed;
+import org.apache.kafka.streams.StreamsBuilder;
+import org.apache.kafka.streams.TopologyTestDriver;
+import org.apache.kafka.streams.kstream.KStream;
+import org.apache.kafka.streams.kstream.Predicate;
+import org.apache.kafka.streams.TestInputTopic;
+import org.apache.kafka.test.MockProcessorSupplier;
+import org.apache.kafka.test.StreamsTestUtils;
+import org.junit.Test;
+
+import java.util.Properties;
+
+import static org.junit.Assert.assertEquals;
+
+public class KStreamFilterTest {
+
+    private final String topicName = "topic";
+    private final Properties props = StreamsTestUtils.getStreamsConfig(Serdes.Integer(), Serdes.String());
+
+    private final Predicate<Integer, String> isMultipleOfThree = (key, value) -> (key % 3) == 0;
+
+    @SuppressWarnings("deprecation") // Old PAPI. Needs to be migrated.
+    @Test
+    public void testFilter() {
+        final StreamsBuilder builder = new StreamsBuilder();
+        final int[] expectedKeys = new int[]{1, 2, 3, 4, 5, 6, 7};
+
+        final KStream<Integer, String> stream;
+        final MockProcessorSupplier<Integer, String> supplier = new MockProcessorSupplier<>();
+
+        stream = builder.stream(topicName, Consumed.with(Serdes.Integer(), Serdes.String()));
+        stream.filter(isMultipleOfThree).process(supplier);
+
+        try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) {
+            final TestInputTopic<Integer, String> inputTopic = driver.createInputTopic(topicName, new IntegerSerializer(), new StringSerializer());
+            for (final int expectedKey : expectedKeys) {
+                inputTopic.pipeInput(expectedKey, "V" + expectedKey);
+            }
+        }
+
+        assertEquals(2, supplier.theCapturedProcessor().processed().size());
+    }
+
+    @SuppressWarnings("deprecation") // Old PAPI. Needs to be migrated.
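+    // Complement of testFilter: the same predicate drops keys 3 and 6, so five of the seven records pass filterNot.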
+ @Test + public void testFilterNot() { + final StreamsBuilder builder = new StreamsBuilder(); + final int[] expectedKeys = new int[]{1, 2, 3, 4, 5, 6, 7}; + + final KStream stream; + final MockProcessorSupplier supplier = new MockProcessorSupplier<>(); + + stream = builder.stream(topicName, Consumed.with(Serdes.Integer(), Serdes.String())); + stream.filterNot(isMultipleOfThree).process(supplier); + + try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) { + for (final int expectedKey : expectedKeys) { + final TestInputTopic inputTopic = driver.createInputTopic(topicName, new IntegerSerializer(), new StringSerializer()); + inputTopic.pipeInput(expectedKey, "V" + expectedKey); + } + } + + assertEquals(5, supplier.theCapturedProcessor().processed().size()); + } + + @Test + public void testTypeVariance() { + final Predicate numberKeyPredicate = (key, value) -> false; + + new StreamsBuilder() + .stream("empty") + .filter(numberKeyPredicate) + .filterNot(numberKeyPredicate) + .to("nirvana"); + + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KStreamFlatMapTest.java b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KStreamFlatMapTest.java new file mode 100644 index 0000000..1f763a0 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KStreamFlatMapTest.java @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.kstream.internals; + +import org.apache.kafka.common.serialization.IntegerSerializer; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.serialization.StringSerializer; +import org.apache.kafka.streams.KeyValueTimestamp; +import org.apache.kafka.streams.kstream.Consumed; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.TopologyTestDriver; +import org.apache.kafka.streams.kstream.KStream; +import org.apache.kafka.streams.kstream.KeyValueMapper; +import org.apache.kafka.streams.TestInputTopic; +import org.apache.kafka.streams.processor.api.Record; +import org.apache.kafka.test.MockProcessorSupplier; +import org.apache.kafka.test.StreamsTestUtils; +import org.junit.Test; + +import java.time.Duration; +import java.time.Instant; +import java.util.ArrayList; +import java.util.Properties; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.core.Is.is; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; + +public class KStreamFlatMapTest { + private final Properties props = StreamsTestUtils.getStreamsConfig(Serdes.Integer(), Serdes.String()); + + @SuppressWarnings("deprecation") // Old PAPI. 
Needs to be migrated. + @Test + public void testFlatMap() { + final StreamsBuilder builder = new StreamsBuilder(); + final String topicName = "topic"; + + final KeyValueMapper>> mapper = + (key, value) -> { + final ArrayList> result = new ArrayList<>(); + for (int i = 0; i < key.intValue(); i++) { + result.add(KeyValue.pair(Integer.toString(key.intValue() * 10 + i), value.toString())); + } + return result; + }; + + final int[] expectedKeys = {0, 1, 2, 3}; + + final KStream stream; + final MockProcessorSupplier supplier; + + supplier = new MockProcessorSupplier<>(); + stream = builder.stream(topicName, Consumed.with(Serdes.Integer(), Serdes.String())); + stream.flatMap(mapper).process(supplier); + + try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) { + final TestInputTopic inputTopic = + driver.createInputTopic(topicName, new IntegerSerializer(), new StringSerializer(), Instant.ofEpochMilli(0), Duration.ZERO); + for (final int expectedKey : expectedKeys) { + inputTopic.pipeInput(expectedKey, "V" + expectedKey); + } + } + + assertEquals(6, supplier.theCapturedProcessor().processed().size()); + + final KeyValueTimestamp[] expected = {new KeyValueTimestamp<>("10", "V1", 0), + new KeyValueTimestamp<>("20", "V2", 0), + new KeyValueTimestamp<>("21", "V2", 0), + new KeyValueTimestamp<>("30", "V3", 0), + new KeyValueTimestamp<>("31", "V3", 0), + new KeyValueTimestamp<>("32", "V3", 0)}; + + for (int i = 0; i < expected.length; i++) { + assertEquals(expected[i], supplier.theCapturedProcessor().processed().get(i)); + } + } + + @Test + public void testKeyValueMapperResultNotNull() { + final KStreamFlatMap supplier = new KStreamFlatMap<>((key, value) -> null); + final Throwable throwable = assertThrows(NullPointerException.class, + () -> supplier.get().process(new Record<>("K", 0, 0L))); + assertThat(throwable.getMessage(), is("The provided KeyValueMapper returned null which is not allowed.")); + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KStreamFlatMapValuesTest.java b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KStreamFlatMapValuesTest.java new file mode 100644 index 0000000..c1930e5 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KStreamFlatMapValuesTest.java @@ -0,0 +1,121 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.kstream.internals; + +import org.apache.kafka.common.serialization.IntegerSerializer; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.streams.KeyValueTimestamp; +import org.apache.kafka.streams.kstream.Consumed; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.TopologyTestDriver; +import org.apache.kafka.streams.kstream.KStream; +import org.apache.kafka.streams.kstream.ValueMapper; +import org.apache.kafka.streams.kstream.ValueMapperWithKey; +import org.apache.kafka.streams.TestInputTopic; +import org.apache.kafka.test.MockProcessorSupplier; +import org.apache.kafka.test.StreamsTestUtils; +import org.junit.Test; + +import java.time.Duration; +import java.time.Instant; +import java.util.ArrayList; +import java.util.Properties; + +import static org.junit.Assert.assertArrayEquals; + +public class KStreamFlatMapValuesTest { + private final String topicName = "topic"; + private final Properties props = StreamsTestUtils.getStreamsConfig(Serdes.Integer(), Serdes.String()); + + @SuppressWarnings("deprecation") // Old PAPI. Needs to be migrated. + @Test + public void testFlatMapValues() { + final StreamsBuilder builder = new StreamsBuilder(); + + final ValueMapper> mapper = + value -> { + final ArrayList result = new ArrayList<>(); + result.add("v" + value); + result.add("V" + value); + return result; + }; + + final int[] expectedKeys = {0, 1, 2, 3}; + + final KStream stream = builder.stream(topicName, Consumed.with(Serdes.Integer(), Serdes.Integer())); + final MockProcessorSupplier supplier = new MockProcessorSupplier<>(); + stream.flatMapValues(mapper).process(supplier); + + try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) { + final TestInputTopic inputTopic = + driver.createInputTopic(topicName, new IntegerSerializer(), new IntegerSerializer(), Instant.ofEpochMilli(0L), Duration.ZERO); + for (final int expectedKey : expectedKeys) { + // passing the timestamp to inputTopic.create to disambiguate the call + inputTopic.pipeInput(expectedKey, expectedKey, 0L); + } + } + + final KeyValueTimestamp[] expected = {new KeyValueTimestamp<>(0, "v0", 0), new KeyValueTimestamp<>(0, "V0", 0), + new KeyValueTimestamp<>(1, "v1", 0), new KeyValueTimestamp<>(1, "V1", 0), + new KeyValueTimestamp<>(2, "v2", 0), new KeyValueTimestamp<>(2, "V2", 0), + new KeyValueTimestamp<>(3, "v3", 0), new KeyValueTimestamp<>(3, "V3", 0)}; + + assertArrayEquals(expected, supplier.theCapturedProcessor().processed().toArray()); + } + + + @SuppressWarnings("deprecation") // Old PAPI. Needs to be migrated. 
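+    // Same flow as testFlatMapValues, but the ValueMapperWithKey also folds the read-only key into the emitted values ("k" + key).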
+    @Test
+    public void testFlatMapValuesWithKeys() {
+        final StreamsBuilder builder = new StreamsBuilder();
+
+        final ValueMapperWithKey<Integer, Number, Iterable<String>> mapper =
+            (readOnlyKey, value) -> {
+                final ArrayList<String> result = new ArrayList<>();
+                result.add("v" + value);
+                result.add("k" + readOnlyKey);
+                return result;
+            };
+
+        final int[] expectedKeys = {0, 1, 2, 3};
+
+        final KStream<Integer, Integer> stream = builder.stream(topicName, Consumed.with(Serdes.Integer(), Serdes.Integer()));
+        final MockProcessorSupplier<Integer, String> supplier = new MockProcessorSupplier<>();
+
+        stream.flatMapValues(mapper).process(supplier);
+
+        try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) {
+            final TestInputTopic<Integer, Integer> inputTopic =
+                    driver.createInputTopic(topicName, new IntegerSerializer(), new IntegerSerializer(), Instant.ofEpochMilli(0L), Duration.ZERO);
+            for (final int expectedKey : expectedKeys) {
+                // passing the timestamp to inputTopic.create to disambiguate the call
+                inputTopic.pipeInput(expectedKey, expectedKey, 0L);
+            }
+        }
+
+        final KeyValueTimestamp[] expected = {new KeyValueTimestamp<>(0, "v0", 0),
+            new KeyValueTimestamp<>(0, "k0", 0),
+            new KeyValueTimestamp<>(1, "v1", 0),
+            new KeyValueTimestamp<>(1, "k1", 0),
+            new KeyValueTimestamp<>(2, "v2", 0),
+            new KeyValueTimestamp<>(2, "k2", 0),
+            new KeyValueTimestamp<>(3, "v3", 0),
+            new KeyValueTimestamp<>(3, "k3", 0)};
+
+        assertArrayEquals(expected, supplier.theCapturedProcessor().processed().toArray());
+    }
+}
diff --git a/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KStreamFlatTransformTest.java b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KStreamFlatTransformTest.java
new file mode 100644
index 0000000..8082255
--- /dev/null
+++ b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KStreamFlatTransformTest.java
@@ -0,0 +1,137 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.streams.kstream.internals;
+
+import org.apache.kafka.streams.KeyValue;
+import org.apache.kafka.streams.kstream.Transformer;
+import org.apache.kafka.streams.kstream.TransformerSupplier;
+import org.apache.kafka.streams.kstream.internals.KStreamFlatTransform.KStreamFlatTransformProcessor;
+import org.apache.kafka.streams.processor.ProcessorContext;
+import org.easymock.EasyMock;
+import org.easymock.EasyMockSupport;
+import org.junit.Before;
+import org.junit.Test;
+
+import java.util.Arrays;
+import java.util.Collections;
+
+import static org.junit.Assert.assertTrue;
+
+@SuppressWarnings("deprecation") // Old PAPI. Needs to be migrated.
+public class KStreamFlatTransformTest extends EasyMockSupport { + + private Number inputKey; + private Number inputValue; + + private Transformer>> transformer; + private ProcessorContext context; + + private KStreamFlatTransformProcessor processor; + + @Before + public void setUp() { + inputKey = 1; + inputValue = 10; + transformer = mock(Transformer.class); + context = strictMock(ProcessorContext.class); + processor = new KStreamFlatTransformProcessor<>(transformer); + } + + @Test + public void shouldInitialiseFlatTransformProcessor() { + transformer.init(context); + replayAll(); + + processor.init(context); + + verifyAll(); + } + + @Test + public void shouldTransformInputRecordToMultipleOutputRecords() { + final Iterable> outputRecords = Arrays.asList( + KeyValue.pair(2, 20), + KeyValue.pair(3, 30), + KeyValue.pair(4, 40)); + processor.init(context); + EasyMock.reset(transformer); + + EasyMock.expect(transformer.transform(inputKey, inputValue)).andReturn(outputRecords); + for (final KeyValue outputRecord : outputRecords) { + context.forward(outputRecord.key, outputRecord.value); + } + replayAll(); + + processor.process(inputKey, inputValue); + + verifyAll(); + } + + @Test + public void shouldAllowEmptyListAsResultOfTransform() { + processor.init(context); + EasyMock.reset(transformer); + + EasyMock.expect(transformer.transform(inputKey, inputValue)) + .andReturn(Collections.>emptyList()); + replayAll(); + + processor.process(inputKey, inputValue); + + verifyAll(); + } + + @Test + public void shouldAllowNullAsResultOfTransform() { + processor.init(context); + EasyMock.reset(transformer); + + EasyMock.expect(transformer.transform(inputKey, inputValue)) + .andReturn(null); + replayAll(); + + processor.process(inputKey, inputValue); + + verifyAll(); + } + + @Test + public void shouldCloseFlatTransformProcessor() { + transformer.close(); + replayAll(); + + processor.close(); + + verifyAll(); + } + + @Test + public void shouldGetFlatTransformProcessor() { + final TransformerSupplier>> transformerSupplier = + mock(TransformerSupplier.class); + final KStreamFlatTransform processorSupplier = + new KStreamFlatTransform<>(transformerSupplier); + + EasyMock.expect(transformerSupplier.get()).andReturn(transformer); + replayAll(); + + final org.apache.kafka.streams.processor.Processor processor = processorSupplier.get(); + + verifyAll(); + assertTrue(processor instanceof KStreamFlatTransformProcessor); + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KStreamFlatTransformValuesTest.java b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KStreamFlatTransformValuesTest.java new file mode 100644 index 0000000..fd64604 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KStreamFlatTransformValuesTest.java @@ -0,0 +1,135 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.kstream.internals; + +import static org.junit.Assert.assertTrue; + +import java.util.Arrays; +import java.util.Collections; + +import org.apache.kafka.streams.kstream.ValueTransformerWithKey; +import org.apache.kafka.streams.kstream.ValueTransformerWithKeySupplier; +import org.apache.kafka.streams.kstream.internals.KStreamFlatTransformValues.KStreamFlatTransformValuesProcessor; +import org.apache.kafka.streams.processor.ProcessorContext; +import org.apache.kafka.streams.processor.internals.ForwardingDisabledProcessorContext; +import org.easymock.EasyMock; +import org.easymock.EasyMockSupport; +import org.junit.Before; +import org.junit.Test; + +@SuppressWarnings("deprecation") // Old PAPI. Needs to be migrated. +public class KStreamFlatTransformValuesTest extends EasyMockSupport { + + private Integer inputKey; + private Integer inputValue; + + private ValueTransformerWithKey> valueTransformer; + private ProcessorContext context; + + private KStreamFlatTransformValuesProcessor processor; + + @Before + public void setUp() { + inputKey = 1; + inputValue = 10; + valueTransformer = mock(ValueTransformerWithKey.class); + context = strictMock(ProcessorContext.class); + processor = new KStreamFlatTransformValuesProcessor<>(valueTransformer); + } + + @Test + public void shouldInitializeFlatTransformValuesProcessor() { + valueTransformer.init(EasyMock.isA(ForwardingDisabledProcessorContext.class)); + replayAll(); + + processor.init(context); + + verifyAll(); + } + + @Test + public void shouldTransformInputRecordToMultipleOutputValues() { + final Iterable outputValues = Arrays.asList( + "Hello", + "Blue", + "Planet"); + processor.init(context); + EasyMock.reset(valueTransformer); + + EasyMock.expect(valueTransformer.transform(inputKey, inputValue)).andReturn(outputValues); + for (final String outputValue : outputValues) { + context.forward(inputKey, outputValue); + } + replayAll(); + + processor.process(inputKey, inputValue); + + verifyAll(); + } + + @Test + public void shouldEmitNoRecordIfTransformReturnsEmptyList() { + processor.init(context); + EasyMock.reset(valueTransformer); + + EasyMock.expect(valueTransformer.transform(inputKey, inputValue)).andReturn(Collections.emptyList()); + replayAll(); + + processor.process(inputKey, inputValue); + + verifyAll(); + } + + @Test + public void shouldEmitNoRecordIfTransformReturnsNull() { + processor.init(context); + EasyMock.reset(valueTransformer); + + EasyMock.expect(valueTransformer.transform(inputKey, inputValue)).andReturn(null); + replayAll(); + + processor.process(inputKey, inputValue); + + verifyAll(); + } + + @Test + public void shouldCloseFlatTransformValuesProcessor() { + valueTransformer.close(); + replayAll(); + + processor.close(); + + verifyAll(); + } + + @Test + public void shouldGetFlatTransformValuesProcessor() { + final ValueTransformerWithKeySupplier> valueTransformerSupplier = + mock(ValueTransformerWithKeySupplier.class); + final KStreamFlatTransformValues processorSupplier = + new KStreamFlatTransformValues<>(valueTransformerSupplier); + + EasyMock.expect(valueTransformerSupplier.get()).andReturn(valueTransformer); + replayAll(); + + final org.apache.kafka.streams.processor.Processor processor = processorSupplier.get(); + + verifyAll(); + assertTrue(processor instanceof KStreamFlatTransformValuesProcessor); + } +} diff --git 
a/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KStreamForeachTest.java b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KStreamForeachTest.java new file mode 100644 index 0000000..35db245 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KStreamForeachTest.java @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.kstream.internals; + +import org.apache.kafka.common.serialization.IntegerSerializer; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.serialization.StringSerializer; +import org.apache.kafka.streams.kstream.Consumed; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.TopologyTestDriver; +import org.apache.kafka.streams.kstream.ForeachAction; +import org.apache.kafka.streams.kstream.KStream; +import org.apache.kafka.streams.TestInputTopic; +import org.apache.kafka.test.StreamsTestUtils; +import org.junit.Test; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Locale; +import java.util.Properties; + +import static org.junit.Assert.assertEquals; + +public class KStreamForeachTest { + + private final String topicName = "topic"; + private final Properties props = StreamsTestUtils.getStreamsConfig(Serdes.Integer(), Serdes.String()); + + @Test + public void testForeach() { + // Given + final List> inputRecords = Arrays.asList( + new KeyValue<>(0, "zero"), + new KeyValue<>(1, "one"), + new KeyValue<>(2, "two"), + new KeyValue<>(3, "three") + ); + + final List> expectedRecords = Arrays.asList( + new KeyValue<>(0, "ZERO"), + new KeyValue<>(2, "ONE"), + new KeyValue<>(4, "TWO"), + new KeyValue<>(6, "THREE") + ); + + final List> actualRecords = new ArrayList<>(); + final ForeachAction action = + (key, value) -> actualRecords.add(new KeyValue<>(key * 2, value.toUpperCase(Locale.ROOT))); + + // When + final StreamsBuilder builder = new StreamsBuilder(); + final KStream stream = builder.stream(topicName, Consumed.with(Serdes.Integer(), Serdes.String())); + stream.foreach(action); + + // Then + try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) { + final TestInputTopic inputTopic = driver.createInputTopic(topicName, new IntegerSerializer(), new StringSerializer()); + for (final KeyValue record : inputRecords) { + inputTopic.pipeInput(record.key, record.value); + } + } + + assertEquals(expectedRecords.size(), actualRecords.size()); + for (int i = 0; i < expectedRecords.size(); i++) { + final KeyValue expectedRecord = expectedRecords.get(i); + final KeyValue actualRecord = actualRecords.get(i); + assertEquals(expectedRecord, actualRecord); 
+ } + } + + @Test + public void testTypeVariance() { + final ForeachAction consume = (key, value) -> { }; + + new StreamsBuilder() + .stream("emptyTopic") + .foreach(consume); + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KStreamGlobalKTableJoinTest.java b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KStreamGlobalKTableJoinTest.java new file mode 100644 index 0000000..fda87dc --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KStreamGlobalKTableJoinTest.java @@ -0,0 +1,250 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.kstream.internals; + +import org.apache.kafka.common.serialization.IntegerSerializer; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.serialization.StringSerializer; +import org.apache.kafka.streams.KeyValueTimestamp; +import org.apache.kafka.streams.kstream.Consumed; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.TopologyTestDriver; +import org.apache.kafka.streams.TopologyWrapper; +import org.apache.kafka.streams.kstream.GlobalKTable; +import org.apache.kafka.streams.kstream.KStream; +import org.apache.kafka.streams.kstream.KeyValueMapper; +import org.apache.kafka.streams.TestInputTopic; +import org.apache.kafka.test.MockProcessor; +import org.apache.kafka.test.MockProcessorSupplier; +import org.apache.kafka.test.MockValueJoiner; +import org.apache.kafka.test.StreamsTestUtils; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import java.time.Duration; +import java.time.Instant; +import java.util.Collection; +import java.util.Properties; +import java.util.Set; + +import static org.junit.Assert.assertEquals; + +public class KStreamGlobalKTableJoinTest { + private final static KeyValueTimestamp[] EMPTY = new KeyValueTimestamp[0]; + + private final String streamTopic = "streamTopic"; + private final String globalTableTopic = "globalTableTopic"; + private final int[] expectedKeys = {0, 1, 2, 3}; + + private TopologyTestDriver driver; + private MockProcessor processor; + private StreamsBuilder builder; + + @SuppressWarnings("deprecation") // Old PAPI. Needs to be migrated. 
+ @Before + public void setUp() { + + builder = new StreamsBuilder(); + final KStream stream; + final GlobalKTable table; // value of stream optionally contains key of table + final KeyValueMapper keyMapper; + + final MockProcessorSupplier supplier = new MockProcessorSupplier<>(); + final Consumed streamConsumed = Consumed.with(Serdes.Integer(), Serdes.String()); + final Consumed tableConsumed = Consumed.with(Serdes.String(), Serdes.String()); + stream = builder.stream(streamTopic, streamConsumed); + table = builder.globalTable(globalTableTopic, tableConsumed); + keyMapper = (key, value) -> { + final String[] tokens = value.split(","); + // Value is comma delimited. If second token is present, it's the key to the global ktable. + // If not present, use null to indicate no match + return tokens.length > 1 ? tokens[1] : null; + }; + stream.join(table, keyMapper, MockValueJoiner.TOSTRING_JOINER).process(supplier); + + final Properties props = StreamsTestUtils.getStreamsConfig(Serdes.Integer(), Serdes.String()); + driver = new TopologyTestDriver(builder.build(), props); + + processor = supplier.theCapturedProcessor(); + } + + @After + public void cleanup() { + driver.close(); + } + + private void pushToStream(final int messageCount, final String valuePrefix, final boolean includeForeignKey, final boolean includeNullKey) { + final TestInputTopic inputTopic = + driver.createInputTopic(streamTopic, new IntegerSerializer(), new StringSerializer(), Instant.ofEpochMilli(0L), Duration.ofMillis(1L)); + for (int i = 0; i < messageCount; i++) { + String value = valuePrefix + expectedKeys[i]; + if (includeForeignKey) { + value = value + ",FKey" + expectedKeys[i]; + } + Integer key = expectedKeys[i]; + if (includeNullKey && i == 0) { + key = null; + } + inputTopic.pipeInput(key, value); + } + } + + private void pushToGlobalTable(final int messageCount, final String valuePrefix) { + final TestInputTopic inputTopic = + driver.createInputTopic(globalTableTopic, new StringSerializer(), new StringSerializer()); + for (int i = 0; i < messageCount; i++) { + inputTopic.pipeInput("FKey" + expectedKeys[i], valuePrefix + expectedKeys[i]); + } + } + + private void pushNullValueToGlobalTable(final int messageCount) { + final TestInputTopic inputTopic = + driver.createInputTopic(globalTableTopic, new StringSerializer(), new StringSerializer()); + for (int i = 0; i < messageCount; i++) { + inputTopic.pipeInput("FKey" + expectedKeys[i], (String) null); + } + } + + @Test + public void shouldNotRequireCopartitioning() { + final Collection> copartitionGroups = + TopologyWrapper.getInternalTopologyBuilder(builder.build()).copartitionGroups(); + + assertEquals("KStream-GlobalKTable joins do not need to be co-partitioned", 0, copartitionGroups.size()); + } + + @Test + public void shouldNotJoinWithEmptyGlobalTableOnStreamUpdates() { + + // push two items to the primary stream. the globalTable is empty + + pushToStream(2, "X", true, false); + processor.checkAndClearProcessResult(EMPTY); + } + + @Test + public void shouldNotJoinOnGlobalTableUpdates() { + + // push two items to the primary stream. the globalTable is empty + + pushToStream(2, "X", true, false); + processor.checkAndClearProcessResult(EMPTY); + + // push two items to the globalTable. this should not produce any item. + + pushToGlobalTable(2, "Y"); + processor.checkAndClearProcessResult(EMPTY); + + // push all four items to the primary stream. this should produce two items. 
+ + pushToStream(4, "X", true, false); + processor.checkAndClearProcessResult(new KeyValueTimestamp<>(0, "X0,FKey0+Y0", 0), + new KeyValueTimestamp<>(1, "X1,FKey1+Y1", 1)); + + // push all items to the globalTable. this should not produce any item + + pushToGlobalTable(4, "YY"); + processor.checkAndClearProcessResult(EMPTY); + + // push all four items to the primary stream. this should produce four items. + + pushToStream(4, "X", true, false); + processor.checkAndClearProcessResult(new KeyValueTimestamp<>(0, "X0,FKey0+YY0", 0), + new KeyValueTimestamp<>(1, "X1,FKey1+YY1", 1), + new KeyValueTimestamp<>(2, "X2,FKey2+YY2", 2), + new KeyValueTimestamp<>(3, "X3,FKey3+YY3", 3)); + + // push all items to the globalTable. this should not produce any item + + pushToGlobalTable(4, "YYY"); + processor.checkAndClearProcessResult(EMPTY); + } + + @Test + public void shouldJoinOnlyIfMatchFoundOnStreamUpdates() { + + // push two items to the globalTable. this should not produce any item. + + pushToGlobalTable(2, "Y"); + processor.checkAndClearProcessResult(EMPTY); + + // push all four items to the primary stream. this should produce two items. + + pushToStream(4, "X", true, false); + processor.checkAndClearProcessResult(new KeyValueTimestamp<>(0, "X0,FKey0+Y0", 0), + new KeyValueTimestamp<>(1, "X1,FKey1+Y1", 1)); + + } + + @Test + public void shouldClearGlobalTableEntryOnNullValueUpdates() { + + // push all four items to the globalTable. this should not produce any item. + + pushToGlobalTable(4, "Y"); + processor.checkAndClearProcessResult(EMPTY); + + // push all four items to the primary stream. this should produce four items. + + pushToStream(4, "X", true, false); + processor.checkAndClearProcessResult(new KeyValueTimestamp<>(0, "X0,FKey0+Y0", 0), + new KeyValueTimestamp<>(1, "X1,FKey1+Y1", 1), + new KeyValueTimestamp<>(2, "X2,FKey2+Y2", 2), + new KeyValueTimestamp<>(3, "X3,FKey3+Y3", 3)); + + // push two items with null to the globalTable as deletes. this should not produce any item. + + pushNullValueToGlobalTable(2); + processor.checkAndClearProcessResult(EMPTY); + + // push all four items to the primary stream. this should produce two items. + + pushToStream(4, "XX", true, false); + processor.checkAndClearProcessResult(new KeyValueTimestamp<>(2, "XX2,FKey2+Y2", 2), + new KeyValueTimestamp<>(3, "XX3,FKey3+Y3", 3)); + } + + @Test + public void shouldNotJoinOnNullKeyMapperValues() { + + // push all items to the globalTable. this should not produce any item + + pushToGlobalTable(4, "Y"); + processor.checkAndClearProcessResult(EMPTY); + + // push all four items to the primary stream with no foreign key, resulting in null keyMapper values. + // this should not produce any item. + + pushToStream(4, "XXX", false, false); + processor.checkAndClearProcessResult(EMPTY); + } + + @Test + public void shouldJoinOnNullKeyWithNonNullKeyMapperValues() { + // push two items to the globalTable. this should not produce any item. + + pushToGlobalTable(2, "Y"); + processor.checkAndClearProcessResult(EMPTY); + + // push all four items to the primary stream. this should produce two items. 
+ + pushToStream(4, "X", true, true); + processor.checkAndClearProcessResult(new KeyValueTimestamp<>(null, "X0,FKey0+Y0", 0), + new KeyValueTimestamp<>(1, "X1,FKey1+Y1", 1)); + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KStreamGlobalKTableLeftJoinTest.java b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KStreamGlobalKTableLeftJoinTest.java new file mode 100644 index 0000000..9268997 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KStreamGlobalKTableLeftJoinTest.java @@ -0,0 +1,260 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.kstream.internals; + +import org.apache.kafka.common.serialization.IntegerSerializer; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.serialization.StringSerializer; +import org.apache.kafka.streams.KeyValueTimestamp; +import org.apache.kafka.streams.kstream.Consumed; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.TopologyTestDriver; +import org.apache.kafka.streams.TopologyWrapper; +import org.apache.kafka.streams.kstream.GlobalKTable; +import org.apache.kafka.streams.kstream.KStream; +import org.apache.kafka.streams.kstream.KeyValueMapper; +import org.apache.kafka.streams.TestInputTopic; +import org.apache.kafka.test.MockProcessor; +import org.apache.kafka.test.MockProcessorSupplier; +import org.apache.kafka.test.MockValueJoiner; +import org.apache.kafka.test.StreamsTestUtils; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import java.time.Duration; +import java.time.Instant; +import java.util.Collection; +import java.util.Properties; +import java.util.Set; + +import static org.junit.Assert.assertEquals; + +public class KStreamGlobalKTableLeftJoinTest { + private final static KeyValueTimestamp[] EMPTY = new KeyValueTimestamp[0]; + + private final String streamTopic = "streamTopic"; + private final String globalTableTopic = "globalTableTopic"; + private final int[] expectedKeys = {0, 1, 2, 3}; + + private MockProcessor processor; + private TopologyTestDriver driver; + private StreamsBuilder builder; + + @SuppressWarnings("deprecation") // Old PAPI. Needs to be migrated. 
+ @Before + public void setUp() { + + builder = new StreamsBuilder(); + final KStream stream; + final GlobalKTable table; // value of stream optionally contains key of table + final KeyValueMapper keyMapper; + + final MockProcessorSupplier supplier = new MockProcessorSupplier<>(); + final Consumed streamConsumed = Consumed.with(Serdes.Integer(), Serdes.String()); + final Consumed tableConsumed = Consumed.with(Serdes.String(), Serdes.String()); + stream = builder.stream(streamTopic, streamConsumed); + table = builder.globalTable(globalTableTopic, tableConsumed); + keyMapper = (key, value) -> { + final String[] tokens = value.split(","); + // Value is comma delimited. If second token is present, it's the key to the global ktable. + // If not present, use null to indicate no match + return tokens.length > 1 ? tokens[1] : null; + }; + stream.leftJoin(table, keyMapper, MockValueJoiner.TOSTRING_JOINER).process(supplier); + + final Properties props = StreamsTestUtils.getStreamsConfig(Serdes.Integer(), Serdes.String()); + driver = new TopologyTestDriver(builder.build(), props); + + processor = supplier.theCapturedProcessor(); + } + + @After + public void cleanup() { + driver.close(); + } + + private void pushToStream(final int messageCount, final String valuePrefix, final boolean includeForeignKey, final boolean includeNullKey) { + final TestInputTopic inputTopic = + driver.createInputTopic(streamTopic, new IntegerSerializer(), new StringSerializer(), Instant.ofEpochMilli(0L), Duration.ofMillis(1L)); + for (int i = 0; i < messageCount; i++) { + String value = valuePrefix + expectedKeys[i]; + if (includeForeignKey) { + value = value + ",FKey" + expectedKeys[i]; + } + Integer key = expectedKeys[i]; + if (includeNullKey && i == 0) { + key = null; + } + inputTopic.pipeInput(key, value); + } + } + + private void pushToGlobalTable(final int messageCount, final String valuePrefix) { + final TestInputTopic inputTopic = + driver.createInputTopic(globalTableTopic, new StringSerializer(), new StringSerializer(), Instant.ofEpochMilli(0L), Duration.ofMillis(1L)); + for (int i = 0; i < messageCount; i++) { + inputTopic.pipeInput("FKey" + expectedKeys[i], valuePrefix + expectedKeys[i]); + } + } + + private void pushNullValueToGlobalTable(final int messageCount) { + final TestInputTopic inputTopic = + driver.createInputTopic(globalTableTopic, new StringSerializer(), new StringSerializer(), Instant.ofEpochMilli(0L), Duration.ofMillis(1L)); + for (int i = 0; i < messageCount; i++) { + inputTopic.pipeInput("FKey" + expectedKeys[i], (String) null); + } + } + + @Test + public void shouldNotRequireCopartitioning() { + final Collection> copartitionGroups = + TopologyWrapper.getInternalTopologyBuilder(builder.build()).copartitionGroups(); + + assertEquals("KStream-GlobalKTable joins do not need to be co-partitioned", 0, copartitionGroups.size()); + } + + @Test + public void shouldNotJoinWithEmptyGlobalTableOnStreamUpdates() { + + // push two items to the primary stream. the globalTable is empty + + pushToStream(2, "X", true, false); + processor.checkAndClearProcessResult(new KeyValueTimestamp<>(0, "X0,FKey0+null", 0), + new KeyValueTimestamp<>(1, "X1,FKey1+null", 1)); + } + + @Test + public void shouldNotJoinOnGlobalTableUpdates() { + + // push two items to the primary stream. the globalTable is empty + + pushToStream(2, "X", true, false); + processor.checkAndClearProcessResult(new KeyValueTimestamp<>(0, "X0,FKey0+null", 0), + new KeyValueTimestamp<>(1, "X1,FKey1+null", 1)); + + // push two items to the globalTable. 
this should not produce any item. + + pushToGlobalTable(2, "Y"); + processor.checkAndClearProcessResult(EMPTY); + + // push all four items to the primary stream. this should produce four items. + + pushToStream(4, "X", true, false); + processor.checkAndClearProcessResult(new KeyValueTimestamp<>(0, "X0,FKey0+Y0", 0), + new KeyValueTimestamp<>(1, "X1,FKey1+Y1", 1), + new KeyValueTimestamp<>(2, "X2,FKey2+null", 2), + new KeyValueTimestamp<>(3, "X3,FKey3+null", 3)); + + // push all items to the globalTable. this should not produce any item + + pushToGlobalTable(4, "YY"); + processor.checkAndClearProcessResult(EMPTY); + + // push all four items to the primary stream. this should produce four items. + + pushToStream(4, "X", true, false); + processor.checkAndClearProcessResult(new KeyValueTimestamp<>(0, "X0,FKey0+YY0", 0), + new KeyValueTimestamp<>(1, "X1,FKey1+YY1", 1), + new KeyValueTimestamp<>(2, "X2,FKey2+YY2", 2), + new KeyValueTimestamp<>(3, "X3,FKey3+YY3", 3)); + + // push all items to the globalTable. this should not produce any item + + pushToGlobalTable(4, "YYY"); + processor.checkAndClearProcessResult(EMPTY); + } + + @Test + public void shouldJoinRegardlessIfMatchFoundOnStreamUpdates() { + + // push two items to the globalTable. this should not produce any item. + + pushToGlobalTable(2, "Y"); + processor.checkAndClearProcessResult(EMPTY); + + // push all four items to the primary stream. this should produce four items. + + pushToStream(4, "X", true, false); + processor.checkAndClearProcessResult(new KeyValueTimestamp<>(0, "X0,FKey0+Y0", 0), + new KeyValueTimestamp<>(1, "X1,FKey1+Y1", 1), + new KeyValueTimestamp<>(2, "X2,FKey2+null", 2), + new KeyValueTimestamp<>(3, "X3,FKey3+null", 3)); + + } + + @Test + public void shouldClearGlobalTableEntryOnNullValueUpdates() { + + // push all four items to the globalTable. this should not produce any item. + + pushToGlobalTable(4, "Y"); + processor.checkAndClearProcessResult(EMPTY); + + // push all four items to the primary stream. this should produce four items. + + pushToStream(4, "X", true, false); + processor.checkAndClearProcessResult(new KeyValueTimestamp<>(0, "X0,FKey0+Y0", 0), + new KeyValueTimestamp<>(1, "X1,FKey1+Y1", 1), + new KeyValueTimestamp<>(2, "X2,FKey2+Y2", 2), + new KeyValueTimestamp<>(3, "X3,FKey3+Y3", 3)); + + // push two items with null to the globalTable as deletes. this should not produce any item. + + pushNullValueToGlobalTable(2); + processor.checkAndClearProcessResult(EMPTY); + + // push all four items to the primary stream. this should produce four items. + + pushToStream(4, "XX", true, false); + processor.checkAndClearProcessResult(new KeyValueTimestamp<>(0, "XX0,FKey0+null", 0), + new KeyValueTimestamp<>(1, "XX1,FKey1+null", 1), + new KeyValueTimestamp<>(2, "XX2,FKey2+Y2", 2), + new KeyValueTimestamp<>(3, "XX3,FKey3+Y3", 3)); + } + + @Test + public void shouldNotJoinOnNullKeyMapperValues() { + + // push all items to the globalTable. this should not produce any item + + pushToGlobalTable(4, "Y"); + processor.checkAndClearProcessResult(EMPTY); + + // push all four items to the primary stream with no foreign key, resulting in null keyMapper values. + // this should not produce any item. + + pushToStream(4, "XXX", false, false); + processor.checkAndClearProcessResult(EMPTY); + } + + @Test + public void shouldJoinOnNullKeyWithNonNullKeyMapperValues() { + // push four items to the globalTable. this should not produce any item. 
+ + pushToGlobalTable(4, "Y"); + processor.checkAndClearProcessResult(EMPTY); + + // push all four items to the primary stream. this should produce four items. + + pushToStream(4, "X", true, true); + processor.checkAndClearProcessResult(new KeyValueTimestamp<>(null, "X0,FKey0+Y0", 0), + new KeyValueTimestamp<>(1, "X1,FKey1+Y1", 1), + new KeyValueTimestamp<>(2, "X2,FKey2+Y2", 2), + new KeyValueTimestamp<>(3, "X3,FKey3+Y3", 3)); + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KStreamImplTest.java b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KStreamImplTest.java new file mode 100644 index 0000000..9753453 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KStreamImplTest.java @@ -0,0 +1,3079 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.kstream.internals; + +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.serialization.StringDeserializer; +import org.apache.kafka.common.serialization.StringSerializer; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.KeyValueTimestamp; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.TestInputTopic; +import org.apache.kafka.streams.TestOutputTopic; +import org.apache.kafka.streams.Topology; +import org.apache.kafka.streams.TopologyTestDriver; +import org.apache.kafka.streams.TopologyWrapper; +import org.apache.kafka.streams.kstream.Consumed; +import org.apache.kafka.streams.kstream.GlobalKTable; +import org.apache.kafka.streams.kstream.Grouped; +import org.apache.kafka.streams.kstream.JoinWindows; +import org.apache.kafka.streams.kstream.Joined; +import org.apache.kafka.streams.kstream.KStream; +import org.apache.kafka.streams.kstream.KTable; +import org.apache.kafka.streams.kstream.KeyValueMapper; +import org.apache.kafka.streams.kstream.Materialized; +import org.apache.kafka.streams.kstream.Named; +import org.apache.kafka.streams.kstream.Predicate; +import org.apache.kafka.streams.kstream.Produced; +import org.apache.kafka.streams.kstream.Repartitioned; +import org.apache.kafka.streams.kstream.StreamJoined; +import org.apache.kafka.streams.kstream.Transformer; +import org.apache.kafka.streams.kstream.TransformerSupplier; +import org.apache.kafka.streams.kstream.ValueJoiner; +import org.apache.kafka.streams.kstream.ValueJoinerWithKey; +import org.apache.kafka.streams.kstream.ValueMapper; +import org.apache.kafka.streams.kstream.ValueMapperWithKey; +import org.apache.kafka.streams.kstream.ValueTransformer; +import org.apache.kafka.streams.kstream.ValueTransformerSupplier; +import 
org.apache.kafka.streams.kstream.ValueTransformerWithKey; +import org.apache.kafka.streams.kstream.ValueTransformerWithKeySupplier; +import org.apache.kafka.streams.processor.FailOnInvalidTimestamp; +import org.apache.kafka.streams.processor.ProcessorContext; +import org.apache.kafka.streams.processor.TopicNameExtractor; +import org.apache.kafka.streams.processor.api.ProcessorSupplier; +import org.apache.kafka.streams.processor.internals.ProcessorTopology; +import org.apache.kafka.streams.processor.internals.SourceNode; +import org.apache.kafka.streams.state.KeyValueStore; +import org.apache.kafka.streams.state.Stores; +import org.apache.kafka.streams.test.TestRecord; +import org.apache.kafka.test.MockMapper; +import org.apache.kafka.test.MockProcessor; +import org.apache.kafka.test.MockProcessorSupplier; +import org.apache.kafka.test.MockValueJoiner; +import org.apache.kafka.test.StreamsTestUtils; +import org.junit.Before; +import org.junit.Test; + +import java.time.Duration; +import java.time.Instant; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Properties; +import java.util.concurrent.TimeUnit; +import java.util.function.Function; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +import static java.time.Duration.ofMillis; +import static java.util.Arrays.asList; +import static java.util.Collections.emptyMap; +import static org.apache.kafka.common.utils.Utils.mkEntry; +import static org.apache.kafka.common.utils.Utils.mkMap; +import static org.hamcrest.CoreMatchers.containsString; +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.core.IsInstanceOf.instanceOf; +import static org.hamcrest.core.IsNull.notNullValue; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; + + +@SuppressWarnings("deprecation") // Old PAPI. Needs to be migrated. 
+public class KStreamImplTest { + + private final Consumed stringConsumed = Consumed.with(Serdes.String(), Serdes.String()); + private final MockProcessorSupplier processorSupplier = new MockProcessorSupplier<>(); + private final TransformerSupplier> transformerSupplier = + () -> new Transformer>() { + @Override + public void init(final ProcessorContext context) {} + + @Override + public KeyValue transform(final String key, final String value) { + return new KeyValue<>(key, value); + } + + @Override + public void close() {} + }; + private final TransformerSupplier>> flatTransformerSupplier = + () -> new Transformer>>() { + @Override + public void init(final ProcessorContext context) {} + + @Override + public Iterable> transform(final String key, final String value) { + return Collections.singleton(new KeyValue<>(key, value)); + } + + @Override + public void close() {} + }; + private final ValueTransformerSupplier valueTransformerSupplier = + () -> new ValueTransformer() { + @Override + public void init(final ProcessorContext context) {} + + @Override + public String transform(final String value) { + return value; + } + + @Override + public void close() {} + }; + private final ValueTransformerWithKeySupplier valueTransformerWithKeySupplier = + () -> new ValueTransformerWithKey() { + @Override + public void init(final ProcessorContext context) {} + + @Override + public String transform(final String key, final String value) { + return value; + } + + @Override + public void close() {} + }; + private final ValueTransformerSupplier> flatValueTransformerSupplier = + () -> new ValueTransformer>() { + @Override + public void init(final ProcessorContext context) {} + + @Override + public Iterable transform(final String value) { + return Collections.singleton(value); + } + + @Override + public void close() {} + }; + private final ValueTransformerWithKeySupplier> flatValueTransformerWithKeySupplier = + () -> new ValueTransformerWithKey>() { + @Override + public void init(final ProcessorContext context) {} + + @Override + public Iterable transform(final String key, final String value) { + return Collections.singleton(value); + } + + @Override + public void close() {} + }; + + private StreamsBuilder builder; + private KStream testStream; + private KTable testTable; + private GlobalKTable testGlobalTable; + + private final Properties props = StreamsTestUtils.getStreamsConfig(Serdes.String(), Serdes.String()); + + private final Serde mySerde = new Serdes.StringSerde(); + + @Before + public void before() { + builder = new StreamsBuilder(); + testStream = builder.stream("source"); + testTable = builder.table("topic"); + testGlobalTable = builder.globalTable("global"); + } + + @Test + public void shouldNotAllowNullPredicateOnFilter() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.filter(null)); + assertThat(exception.getMessage(), equalTo("predicate can't be null")); + } + + @Test + public void shouldNotAllowNullPredicateOnFilterWithNamed() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.filter(null, Named.as("filter"))); + assertThat(exception.getMessage(), equalTo("predicate can't be null")); + } + + @Test + public void shouldNotAllowNullNamedOnFilter() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.filter((k, v) -> true, null)); + assertThat(exception.getMessage(), equalTo("named can't be null")); + } + + @Test + public void 
shouldNotAllowNullPredicateOnFilterNot() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.filterNot(null)); + assertThat(exception.getMessage(), equalTo("predicate can't be null")); + } + + @Test + public void shouldNotAllowNullPredicateOnFilterNotWithName() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.filterNot(null, Named.as("filter"))); + assertThat(exception.getMessage(), equalTo("predicate can't be null")); + } + + @Test + public void shouldNotAllowNullNamedOnFilterNot() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.filterNot((k, v) -> true, null)); + assertThat(exception.getMessage(), equalTo("named can't be null")); + } + + @Test + public void shouldNotAllowNullMapperOnSelectKey() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.selectKey(null)); + assertThat(exception.getMessage(), equalTo("mapper can't be null")); + } + + @Test + public void shouldNotAllowNullMapperOnSelectKeyWithNamed() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.selectKey(null, Named.as("keySelector"))); + assertThat(exception.getMessage(), equalTo("mapper can't be null")); + } + + @Test + public void shouldNotAllowNullNamedOnSelectKey() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.selectKey((k, v) -> k, null)); + assertThat(exception.getMessage(), equalTo("named can't be null")); + } + + @Test + public void shouldNotAllowNullMapperOnMap() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.map(null)); + assertThat(exception.getMessage(), equalTo("mapper can't be null")); + } + + @Test + public void shouldNotAllowNullMapperOnMapWithNamed() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.map(null, Named.as("map"))); + assertThat(exception.getMessage(), equalTo("mapper can't be null")); + } + + @Test + public void shouldNotAllowNullNamedOnMap() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.map(KeyValue::pair, null)); + assertThat(exception.getMessage(), equalTo("named can't be null")); + } + + @Test + public void shouldNotAllowNullMapperOnMapValues() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.mapValues((ValueMapper) null)); + assertThat(exception.getMessage(), equalTo("valueMapper can't be null")); + } + + @Test + public void shouldNotAllowNullMapperOnMapValuesWithKey() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.mapValues((ValueMapperWithKey) null)); + assertThat(exception.getMessage(), equalTo("valueMapperWithKey can't be null")); + } + + @Test + public void shouldNotAllowNullMapperOnMapValuesWithNamed() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.mapValues((ValueMapper) null, Named.as("valueMapper"))); + assertThat(exception.getMessage(), equalTo("valueMapper can't be null")); + } + + @Test + public void shouldNotAllowNullMapperOnMapValuesWithKeyWithNamed() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.mapValues( + (ValueMapperWithKey) null, + 
Named.as("valueMapperWithKey"))); + assertThat(exception.getMessage(), equalTo("valueMapperWithKey can't be null")); + } + + @Test + public void shouldNotAllowNullNamedOnMapValues() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.mapValues(v -> v, null)); + assertThat(exception.getMessage(), equalTo("named can't be null")); + } + + @Test + public void shouldNotAllowNullNamedOnMapValuesWithKey() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.mapValues((k, v) -> v, null)); + assertThat(exception.getMessage(), equalTo("named can't be null")); + } + + @Test + public void shouldNotAllowNullMapperOnFlatMap() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.flatMap(null)); + assertThat(exception.getMessage(), equalTo("mapper can't be null")); + } + + @Test + public void shouldNotAllowNullMapperOnFlatMapWithNamed() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.flatMap(null, Named.as("flatMapper"))); + assertThat(exception.getMessage(), equalTo("mapper can't be null")); + } + + @Test + public void shouldNotAllowNullNamedOnFlatMap() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.flatMap((k, v) -> Collections.singleton(new KeyValue<>(k, v)), null)); + assertThat(exception.getMessage(), equalTo("named can't be null")); + } + + @Test + public void shouldNotAllowNullMapperOnFlatMapValues() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.flatMapValues((ValueMapper>) null)); + assertThat(exception.getMessage(), equalTo("valueMapper can't be null")); + } + + @Test + public void shouldNotAllowNullMapperOnFlatMapValuesWithKey() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.flatMapValues((ValueMapperWithKey>) null)); + assertThat(exception.getMessage(), equalTo("valueMapper can't be null")); + } + + @Test + public void shouldNotAllowNullMapperOnFlatMapValuesWithNamed() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.flatMapValues( + (ValueMapper>) null, + Named.as("flatValueMapper"))); + assertThat(exception.getMessage(), equalTo("valueMapper can't be null")); + } + + @Test + public void shouldNotAllowNullMapperOnFlatMapValuesWithKeyWithNamed() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.flatMapValues( + (ValueMapperWithKey>) null, + Named.as("flatValueMapperWithKey"))); + assertThat(exception.getMessage(), equalTo("valueMapper can't be null")); + } + + @Test + public void shouldNotAllowNullNameOnFlatMapValues() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.flatMapValues(v -> Collections.emptyList(), null)); + assertThat(exception.getMessage(), equalTo("named can't be null")); + } + + @Test + public void shouldNotAllowNullNameOnFlatMapValuesWithKey() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.flatMapValues((k, v) -> Collections.emptyList(), null)); + assertThat(exception.getMessage(), equalTo("named can't be null")); + } + + @Test + public void shouldNotAllowNullPrintedOnPrint() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + 
() -> testStream.print(null)); + assertThat(exception.getMessage(), equalTo("printed can't be null")); + } + + @Test + public void shouldNotAllowNullActionOnForEach() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.foreach(null)); + assertThat(exception.getMessage(), equalTo("action can't be null")); + } + + @Test + public void shouldNotAllowNullActionOnForEachWithName() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.foreach(null, Named.as("foreach"))); + assertThat(exception.getMessage(), equalTo("action can't be null")); + } + + @Test + public void shouldNotAllowNullNamedOnForEach() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.foreach((k, v) -> { }, null)); + assertThat(exception.getMessage(), equalTo("named can't be null")); + } + + @Test + public void shouldNotAllowNullActionOnPeek() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.peek(null)); + assertThat(exception.getMessage(), equalTo("action can't be null")); + } + + @Test + public void shouldNotAllowNullActionOnPeekWithName() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.peek(null, Named.as("peek"))); + assertThat(exception.getMessage(), equalTo("action can't be null")); + } + + @Test + public void shouldNotAllowNullNamedOnPeek() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.peek((k, v) -> { }, null)); + assertThat(exception.getMessage(), equalTo("named can't be null")); + } + + @Test + @SuppressWarnings({"rawtypes", "unchecked", "deprecation"}) + public void shouldNotAllowNullPredicatedOnBranch() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.branch((Predicate[]) null)); + assertThat(exception.getMessage(), equalTo("predicates can't be a null array")); + } + + @Test + @SuppressWarnings({"unchecked", "deprecation"}) + public void shouldHaveAtLeastOnPredicateWhenBranching() { + final IllegalArgumentException exception = assertThrows( + IllegalArgumentException.class, + () -> testStream.branch()); + assertThat(exception.getMessage(), equalTo("branch() requires at least one predicate")); + } + + @SuppressWarnings({"unchecked", "deprecation"}) + @Test + public void shouldHaveAtLeastOnPredicateWhenBranchingWithNamed() { + final IllegalArgumentException exception = assertThrows( + IllegalArgumentException.class, + () -> testStream.branch(Named.as("branch"))); + assertThat(exception.getMessage(), equalTo("branch() requires at least one predicate")); + } + + @SuppressWarnings({"unchecked", "deprecation"}) + @Test + public void shouldNotAllowNullNamedOnBranch() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.branch((Named) null, (k, v) -> true)); + assertThat(exception.getMessage(), equalTo("named can't be null")); + } + + @SuppressWarnings({"unchecked", "deprecation"}) + @Test + public void shouldCantHaveNullPredicate() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.branch((Predicate) null)); + assertThat(exception.getMessage(), equalTo("predicates can't be null")); + } + + @SuppressWarnings({"unchecked", "deprecation"}) + @Test + public void shouldCantHaveNullPredicateWithNamed() { + final NullPointerException 
exception = assertThrows( + NullPointerException.class, + () -> testStream.branch(Named.as("branch"), (Predicate) null)); + assertThat(exception.getMessage(), equalTo("predicates can't be null")); + } + + @Test + public void shouldNotAllowNullKStreamOnMerge() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.merge(null)); + assertThat(exception.getMessage(), equalTo("stream can't be null")); + } + + @Test + public void shouldNotAllowNullKStreamOnMergeWithNamed() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.merge(null, Named.as("merge"))); + assertThat(exception.getMessage(), equalTo("stream can't be null")); + } + + @Test + public void shouldNotAllowNullNamedOnMerge() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.merge(testStream, null)); + assertThat(exception.getMessage(), equalTo("named can't be null")); + } + + @Deprecated // specifically testing the deprecated variant + @Test + public void shouldNotAllowNullTopicOnThrough() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.through(null)); + assertThat(exception.getMessage(), equalTo("topic can't be null")); + } + + @Deprecated // specifically testing the deprecated variant + @Test + public void shouldNotAllowNullTopicOnThroughWithProduced() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.through(null, Produced.as("through"))); + assertThat(exception.getMessage(), equalTo("topic can't be null")); + } + + @Deprecated // specifically testing the deprecated variant + @Test + public void shouldNotAllowNullProducedOnThrough() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.through("topic", null)); + assertThat(exception.getMessage(), equalTo("produced can't be null")); + } + + @Test + public void shouldNotAllowNullTopicOnTo() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.to((String) null)); + assertThat(exception.getMessage(), equalTo("topic can't be null")); + } + + @Test + public void shouldNotAllowNullRepartitionedOnRepartition() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.repartition(null)); + assertThat(exception.getMessage(), equalTo("repartitioned can't be null")); + } + + @Test + public void shouldNotAllowNullTopicChooserOnTo() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.to((TopicNameExtractor) null)); + assertThat(exception.getMessage(), equalTo("topicExtractor can't be null")); + } + + @Test + public void shouldNotAllowNullTopicOnToWithProduced() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.to((String) null, Produced.as("to"))); + assertThat(exception.getMessage(), equalTo("topic can't be null")); + } + + @Test + public void shouldNotAllowNullTopicChooserOnToWithProduced() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.to((TopicNameExtractor) null, Produced.as("to"))); + assertThat(exception.getMessage(), equalTo("topicExtractor can't be null")); + } + + @Test + public void shouldNotAllowNullProducedOnToWithTopicName() { + final NullPointerException exception = assertThrows( + 
NullPointerException.class, + () -> testStream.to("topic", null)); + assertThat(exception.getMessage(), equalTo("produced can't be null")); + } + + @Test + public void shouldNotAllowNullProducedOnToWithTopicChooser() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.to((k, v, ctx) -> "topic", null)); + assertThat(exception.getMessage(), equalTo("produced can't be null")); + } + + @Test + public void shouldNotAllowNullSelectorOnGroupBy() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.groupBy(null)); + assertThat(exception.getMessage(), equalTo("keySelector can't be null")); + } + + @Test + public void shouldNotAllowNullSelectorOnGroupByWithGrouped() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.groupBy(null, Grouped.as("name"))); + assertThat(exception.getMessage(), equalTo("keySelector can't be null")); + } + + @Test + public void shouldNotAllowNullGroupedOnGroupBy() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.groupBy((k, v) -> k, (Grouped) null)); + assertThat(exception.getMessage(), equalTo("grouped can't be null")); + } + + @Test + public void shouldNotAllowNullGroupedOnGroupByKey() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.groupByKey((Grouped) null)); + assertThat(exception.getMessage(), equalTo("grouped can't be null")); + } + + @Test + public void shouldNotAllowNullNamedOnToTable() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.toTable((Named) null)); + assertThat(exception.getMessage(), equalTo("named can't be null")); + } + + @Test + public void shouldNotAllowNullMaterializedOnToTable() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.toTable((Materialized>) null)); + assertThat(exception.getMessage(), equalTo("materialized can't be null")); + } + + @Test + public void shouldNotAllowNullNamedOnToTableWithMaterialized() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.toTable(null, Materialized.with(null, null))); + assertThat(exception.getMessage(), equalTo("named can't be null")); + } + + @Test + public void shouldNotAllowNullMaterializedOnToTableWithNamed() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.toTable(Named.as("name"), null)); + assertThat(exception.getMessage(), equalTo("materialized can't be null")); + } + + @SuppressWarnings("deprecation") + @Test + public void shouldNotAllowNullOtherStreamOnJoin() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.join(null, MockValueJoiner.TOSTRING_JOINER, JoinWindows.of(ofMillis(10)))); + assertThat(exception.getMessage(), equalTo("otherStream can't be null")); + } + + @SuppressWarnings("deprecation") + @Test + public void shouldNotAllowNullOtherStreamOnJoinWithStreamJoined() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.join( + null, + MockValueJoiner.TOSTRING_JOINER, + JoinWindows.of(ofMillis(10)), + StreamJoined.as("name"))); + assertThat(exception.getMessage(), equalTo("otherStream can't be null")); + } + + @SuppressWarnings("deprecation") + @Test + public void 
shouldNotAllowNullValueJoinerOnJoin() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.join(testStream, (ValueJoiner) null, JoinWindows.of(ofMillis(10)))); + assertThat(exception.getMessage(), equalTo("joiner can't be null")); + } + + @SuppressWarnings("deprecation") + @Test + public void shouldNotAllowNullValueJoinerWithKeyOnJoin() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.join(testStream, (ValueJoinerWithKey) null, JoinWindows.of(ofMillis(10)))); + assertThat(exception.getMessage(), equalTo("joiner can't be null")); + } + + @SuppressWarnings("deprecation") + @Test + public void shouldNotAllowNullValueJoinerOnJoinWithStreamJoined() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.join( + testStream, + (ValueJoiner) null, + JoinWindows.of(ofMillis(10)), + StreamJoined.as("name"))); + assertThat(exception.getMessage(), equalTo("joiner can't be null")); + } + + @SuppressWarnings("deprecation") + @Test + public void shouldNotAllowNullValueJoinerWithKeyOnJoinWithStreamJoined() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.join( + testStream, + (ValueJoinerWithKey) null, + JoinWindows.of(ofMillis(10)), + StreamJoined.as("name"))); + assertThat(exception.getMessage(), equalTo("joiner can't be null")); + } + + @Test + public void shouldNotAllowNullJoinWindowsOnJoin() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.join(testStream, MockValueJoiner.TOSTRING_JOINER, null)); + assertThat(exception.getMessage(), equalTo("windows can't be null")); + } + + @Test + public void shouldNotAllowNullJoinWindowsOnJoinWithStreamJoined() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.join( + testStream, + MockValueJoiner.TOSTRING_JOINER, + null, + StreamJoined.as("name"))); + assertThat(exception.getMessage(), equalTo("windows can't be null")); + } + + @SuppressWarnings("deprecation") + @Test + public void shouldNotAllowNullStreamJoinedOnJoin() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.join( + testStream, + MockValueJoiner.TOSTRING_JOINER, + JoinWindows.of(ofMillis(10)), + (StreamJoined) null)); + assertThat(exception.getMessage(), equalTo("streamJoined can't be null")); + } + + @SuppressWarnings("deprecation") + @Test + public void shouldNotAllowNullOtherStreamOnLeftJoin() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.leftJoin(null, MockValueJoiner.TOSTRING_JOINER, JoinWindows.of(ofMillis(10)))); + assertThat(exception.getMessage(), equalTo("otherStream can't be null")); + } + + @SuppressWarnings("deprecation") + @Test + public void shouldNotAllowNullOtherStreamOnLeftJoinWithStreamJoined() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.leftJoin( + null, + MockValueJoiner.TOSTRING_JOINER, + JoinWindows.of(ofMillis(10)), + StreamJoined.as("name"))); + assertThat(exception.getMessage(), equalTo("otherStream can't be null")); + } + + @SuppressWarnings("deprecation") + @Test + public void shouldNotAllowNullValueJoinerOnLeftJoin() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.leftJoin(testStream, (ValueJoiner) null, 
JoinWindows.of(ofMillis(10)))); + assertThat(exception.getMessage(), equalTo("joiner can't be null")); + } + + @SuppressWarnings("deprecation") + @Test + public void shouldNotAllowNullValueJoinerWithKeyOnLeftJoin() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.leftJoin(testStream, (ValueJoinerWithKey) null, JoinWindows.of(ofMillis(10)))); + assertThat(exception.getMessage(), equalTo("joiner can't be null")); + } + + @SuppressWarnings("deprecation") + @Test + public void shouldNotAllowNullValueJoinerOnLeftJoinWithStreamJoined() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.leftJoin( + testStream, + (ValueJoiner) null, + JoinWindows.of(ofMillis(10)), + StreamJoined.as("name"))); + assertThat(exception.getMessage(), equalTo("joiner can't be null")); + } + + @SuppressWarnings("deprecation") + @Test + public void shouldNotAllowNullValueJoinerWithKeyOnLeftJoinWithStreamJoined() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.leftJoin( + testStream, + (ValueJoinerWithKey) null, + JoinWindows.of(ofMillis(10)), + StreamJoined.as("name"))); + assertThat(exception.getMessage(), equalTo("joiner can't be null")); + } + + + @Test + public void shouldNotAllowNullJoinWindowsOnLeftJoin() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.leftJoin(testStream, MockValueJoiner.TOSTRING_JOINER, null)); + assertThat(exception.getMessage(), equalTo("windows can't be null")); + } + + @Test + public void shouldNotAllowNullJoinWindowsOnLeftJoinWithStreamJoined() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.leftJoin( + testStream, + MockValueJoiner.TOSTRING_JOINER, + null, + StreamJoined.as("name"))); + assertThat(exception.getMessage(), equalTo("windows can't be null")); + } + + @SuppressWarnings("deprecation") + @Test + public void shouldNotAllowNullStreamJoinedOnLeftJoin() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.leftJoin( + testStream, + MockValueJoiner.TOSTRING_JOINER, + JoinWindows.of(ofMillis(10)), + (StreamJoined) null)); + assertThat(exception.getMessage(), equalTo("streamJoined can't be null")); + } + + @SuppressWarnings("deprecation") + @Test + public void shouldNotAllowNullOtherStreamOnOuterJoin() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.outerJoin(null, MockValueJoiner.TOSTRING_JOINER, JoinWindows.of(ofMillis(10)))); + assertThat(exception.getMessage(), equalTo("otherStream can't be null")); + } + + @SuppressWarnings("deprecation") + @Test + public void shouldNotAllowNullOtherStreamOnOuterJoinWithStreamJoined() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.outerJoin( + null, + MockValueJoiner.TOSTRING_JOINER, + JoinWindows.of(ofMillis(10)), + StreamJoined.as("name"))); + assertThat(exception.getMessage(), equalTo("otherStream can't be null")); + } + + @SuppressWarnings("deprecation") + @Test + public void shouldNotAllowNullValueJoinerOnOuterJoin() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.outerJoin(testStream, (ValueJoiner) null, JoinWindows.of(ofMillis(10)))); + assertThat(exception.getMessage(), equalTo("joiner can't be null")); + } + + 
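+    // The remaining null-argument checks below follow the same pattern for the other outerJoin() overloads and for
+    // the stream-table and stream-global-table joins: each null parameter is rejected with a descriptive
+    // NullPointerException message.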
@SuppressWarnings("deprecation") + @Test + public void shouldNotAllowNullValueJoinerWithKeyOnOuterJoin() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.outerJoin(testStream, (ValueJoinerWithKey) null, JoinWindows.of(ofMillis(10)))); + assertThat(exception.getMessage(), equalTo("joiner can't be null")); + } + + @SuppressWarnings("deprecation") + @Test + public void shouldNotAllowNullValueJoinerOnOuterJoinWithStreamJoined() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.outerJoin( + testStream, + (ValueJoiner) null, + JoinWindows.of(ofMillis(10)), + StreamJoined.as("name"))); + assertThat(exception.getMessage(), equalTo("joiner can't be null")); + } + + @SuppressWarnings("deprecation") + @Test + public void shouldNotAllowNullValueJoinerWithKeyOnOuterJoinWithStreamJoined() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.outerJoin( + testStream, + (ValueJoinerWithKey) null, + JoinWindows.of(ofMillis(10)), + StreamJoined.as("name"))); + assertThat(exception.getMessage(), equalTo("joiner can't be null")); + } + + @Test + public void shouldNotAllowNullJoinWindowsOnOuterJoin() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.outerJoin(testStream, MockValueJoiner.TOSTRING_JOINER, null)); + assertThat(exception.getMessage(), equalTo("windows can't be null")); + } + + @Test + public void shouldNotAllowNullJoinWindowsOnOuterJoinWithStreamJoined() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.outerJoin( + testStream, + MockValueJoiner.TOSTRING_JOINER, + null, + StreamJoined.as("name"))); + assertThat(exception.getMessage(), equalTo("windows can't be null")); + } + + @SuppressWarnings("deprecation") + @Test + public void shouldNotAllowNullStreamJoinedOnOuterJoin() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.outerJoin( + testStream, + MockValueJoiner.TOSTRING_JOINER, + JoinWindows.of(ofMillis(10)), + (StreamJoined) null)); + assertThat(exception.getMessage(), equalTo("streamJoined can't be null")); + } + + @Test + public void shouldNotAllowNullTableOnTableJoin() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.join(null, MockValueJoiner.TOSTRING_JOINER)); + assertThat(exception.getMessage(), equalTo("table can't be null")); + } + + @Test + public void shouldNotAllowNullTableOnTableJoinWithJoiner() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.join(null, MockValueJoiner.TOSTRING_JOINER, Joined.as("name"))); + assertThat(exception.getMessage(), equalTo("table can't be null")); + } + + @Test + public void shouldNotAllowNullValueJoinerOnTableJoin() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.join(testTable, (ValueJoiner) null)); + assertThat(exception.getMessage(), equalTo("joiner can't be null")); + } + + @Test + public void shouldNotAllowNullValueJoinerWithKeyOnTableJoin() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.join(testTable, (ValueJoinerWithKey) null)); + assertThat(exception.getMessage(), equalTo("joiner can't be null")); + } + + @Test + public void shouldNotAllowNullValueJoinerOnTableJoinWithJoiner() { + final 
NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.join(testTable, (ValueJoiner) null, Joined.as("name"))); + assertThat(exception.getMessage(), equalTo("joiner can't be null")); + } + + @Test + public void shouldNotAllowNullValueJoinerWithKeyOnTableJoinWithJoiner() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.join(testTable, (ValueJoinerWithKey) null, Joined.as("name"))); + assertThat(exception.getMessage(), equalTo("joiner can't be null")); + } + + @Test + public void shouldNotAllowNullJoinedOnTableJoin() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.join(testTable, MockValueJoiner.TOSTRING_JOINER, null)); + assertThat(exception.getMessage(), equalTo("joined can't be null")); + } + + @Test + public void shouldNotAllowNullTableOnTableLeftJoin() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.leftJoin(null, MockValueJoiner.TOSTRING_JOINER)); + assertThat(exception.getMessage(), equalTo("table can't be null")); + } + + @Test + public void shouldNotAllowNullTableOnTableLeftJoinWithJoined() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.leftJoin(null, MockValueJoiner.TOSTRING_JOINER, Joined.as("name"))); + assertThat(exception.getMessage(), equalTo("table can't be null")); + } + + @Test + public void shouldNotAllowNullValueJoinerOnTableLeftJoin() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.leftJoin(testTable, (ValueJoiner) null)); + assertThat(exception.getMessage(), equalTo("joiner can't be null")); + } + + @Test + public void shouldNotAllowNullValueJoinerWithKeyOnTableLeftJoin() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.leftJoin(testTable, (ValueJoinerWithKey) null)); + assertThat(exception.getMessage(), equalTo("joiner can't be null")); + } + + @Test + public void shouldNotAllowNullValueJoinerOnTableLeftJoinWithJoined() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.leftJoin(testTable, (ValueJoiner) null, Joined.as("name"))); + assertThat(exception.getMessage(), equalTo("joiner can't be null")); + } + + @Test + public void shouldNotAllowNullValueJoinerWithKeyOnTableLeftJoinWithJoined() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.leftJoin(testTable, (ValueJoinerWithKey) null, Joined.as("name"))); + assertThat(exception.getMessage(), equalTo("joiner can't be null")); + } + + @Test + public void shouldNotAllowNullJoinedOnTableLeftJoin() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.leftJoin(testTable, MockValueJoiner.TOSTRING_JOINER, null)); + assertThat(exception.getMessage(), equalTo("joined can't be null")); + } + + @Test + public void shouldNotAllowNullTableOnJoinWithGlobalTable() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.join(null, MockMapper.selectValueMapper(), MockValueJoiner.TOSTRING_JOINER)); + assertThat(exception.getMessage(), equalTo("globalTable can't be null")); + } + + @Test + public void shouldNotAllowNullTableOnJoinWithGlobalTableWithNamed() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () 
-> testStream.join( + null, + MockMapper.selectValueMapper(), + MockValueJoiner.TOSTRING_JOINER, + Named.as("name"))); + assertThat(exception.getMessage(), equalTo("globalTable can't be null")); + } + + @Test + public void shouldNotAllowNullMapperOnJoinWithGlobalTable() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.join(testGlobalTable, null, MockValueJoiner.TOSTRING_JOINER)); + assertThat(exception.getMessage(), equalTo("keySelector can't be null")); + } + + @Test + public void shouldNotAllowNullMapperOnJoinWithGlobalTableWithNamed() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.join( + testGlobalTable, + null, + MockValueJoiner.TOSTRING_JOINER, + Named.as("name"))); + assertThat(exception.getMessage(), equalTo("keySelector can't be null")); + } + + @Test + public void shouldNotAllowNullValueJoinerOnJoinWithGlobalTable() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.join(testGlobalTable, MockMapper.selectValueMapper(), (ValueJoiner) null)); + assertThat(exception.getMessage(), equalTo("joiner can't be null")); + } + + @Test + public void shouldNotAllowNullValueJoinerWithKeyOnJoinWithGlobalTable() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.join(testGlobalTable, MockMapper.selectValueMapper(), (ValueJoinerWithKey) null)); + assertThat(exception.getMessage(), equalTo("joiner can't be null")); + } + + @Test + public void shouldNotAllowNullValueJoinerOnJoinWithGlobalTableWithNamed() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.join( + testGlobalTable, + MockMapper.selectValueMapper(), + (ValueJoiner) null, + Named.as("name"))); + assertThat(exception.getMessage(), equalTo("joiner can't be null")); + } + + @Test + public void shouldNotAllowNullValueJoinerWithKeyOnJoinWithGlobalTableWithNamed() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.join( + testGlobalTable, + MockMapper.selectValueMapper(), + (ValueJoiner) null, + Named.as("name"))); + assertThat(exception.getMessage(), equalTo("joiner can't be null")); + } + + @Test + public void shouldNotAllowNullTableOnLeftJoinWithGlobalTable() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.leftJoin(null, MockMapper.selectValueMapper(), MockValueJoiner.TOSTRING_JOINER)); + assertThat(exception.getMessage(), equalTo("globalTable can't be null")); + } + + @Test + public void shouldNotAllowNullTableOnLeftJoinWithGlobalTableWithNamed() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.leftJoin( + null, + MockMapper.selectValueMapper(), + MockValueJoiner.TOSTRING_JOINER, + Named.as("name"))); + assertThat(exception.getMessage(), equalTo("globalTable can't be null")); + } + + @Test + public void shouldNotAllowNullMapperOnLeftJoinWithGlobalTable() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.leftJoin(testGlobalTable, null, MockValueJoiner.TOSTRING_JOINER)); + assertThat(exception.getMessage(), equalTo("keySelector can't be null")); + } + + @Test + public void shouldNotAllowNullMapperOnLeftJoinWithGlobalTableWithNamed() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> 
testStream.leftJoin(
+                testGlobalTable,
+                null,
+                MockValueJoiner.TOSTRING_JOINER,
+                Named.as("name")));
+        assertThat(exception.getMessage(), equalTo("keySelector can't be null"));
+    }
+
+    @Test
+    public void shouldNotAllowNullValueJoinerOnLeftJoinWithGlobalTable() {
+        final NullPointerException exception = assertThrows(
+            NullPointerException.class,
+            () -> testStream.leftJoin(testGlobalTable, MockMapper.selectValueMapper(), (ValueJoiner) null));
+        assertThat(exception.getMessage(), equalTo("joiner can't be null"));
+    }
+
+    @Test
+    public void shouldNotAllowNullValueJoinerWithKeyOnLeftJoinWithGlobalTable() {
+        final NullPointerException exception = assertThrows(
+            NullPointerException.class,
+            () -> testStream.leftJoin(testGlobalTable, MockMapper.selectValueMapper(), (ValueJoinerWithKey) null));
+        assertThat(exception.getMessage(), equalTo("joiner can't be null"));
+    }
+
+    @Test
+    public void shouldNotAllowNullValueJoinerOnLeftJoinWithGlobalTableWithNamed() {
+        final NullPointerException exception = assertThrows(
+            NullPointerException.class,
+            () -> testStream.leftJoin(
+                testGlobalTable,
+                MockMapper.selectValueMapper(),
+                (ValueJoiner) null,
+                Named.as("name")));
+        assertThat(exception.getMessage(), equalTo("joiner can't be null"));
+    }
+
+    @Test
+    public void shouldNotAllowNullValueJoinerWithKeyOnLeftJoinWithGlobalTableWithNamed() {
+        final NullPointerException exception = assertThrows(
+            NullPointerException.class,
+            () -> testStream.leftJoin(
+                testGlobalTable,
+                MockMapper.selectValueMapper(),
+                (ValueJoinerWithKey) null,
+                Named.as("name")));
+        assertThat(exception.getMessage(), equalTo("joiner can't be null"));
+    }
+
+    @SuppressWarnings({"unchecked", "deprecation"}) // specifically testing the deprecated variant
+    @Test
+    public void testNumProcesses() {
+        final StreamsBuilder builder = new StreamsBuilder();
+
+        final KStream<String, String> source1 = builder.stream(Arrays.asList("topic-1", "topic-2"), stringConsumed);
+
+        final KStream<String, String> source2 = builder.stream(Arrays.asList("topic-3", "topic-4"), stringConsumed);
+
+        final KStream<String, String> stream1 = source1.filter((key, value) -> true)
+            .filterNot((key, value) -> false);
+
+        final KStream<String, Integer> stream2 = stream1.mapValues((ValueMapper<String, Integer>) Integer::valueOf);
+
+        final KStream<String, Integer> stream3 = source2.flatMapValues((ValueMapper<String, Iterable<Integer>>)
+            value -> Collections.singletonList(Integer.valueOf(value)));
+
+        final KStream<String, Integer>[] streams2 = stream2.branch(
+            (key, value) -> (value % 2) == 0,
+            (key, value) -> true
+        );
+
+        final KStream<String, Integer>[] streams3 = stream3.branch(
+            (key, value) -> (value % 2) == 0,
+            (key, value) -> true
+        );
+
+        final int anyWindowSize = 1;
+        final StreamJoined<String, Integer, Integer> joined = StreamJoined.with(Serdes.String(), Serdes.Integer(), Serdes.Integer());
+        final KStream<String, Integer> stream4 = streams2[0].join(streams3[0],
+            Integer::sum, JoinWindows.of(ofMillis(anyWindowSize)), joined);
+
+        streams2[1].join(streams3[1], Integer::sum,
+            JoinWindows.of(ofMillis(anyWindowSize)), joined);
+
+        stream4.to("topic-5");
+
+        streams2[1].through("topic-6").process(new MockProcessorSupplier<>());
+
+        streams2[1].repartition().process(new MockProcessorSupplier<>());
+
+        assertEquals(2 + // sources
+            2 + // stream1
+            1 + // stream2
+            1 + // stream3
+            1 + 2 + // streams2
+            1 + 2 + // streams3
+            5 * 2 + // stream2-stream3 joins
+            1 + // to
+            2 + // through
+            1 + // process
+            3 + // repartition
+            1, // process
+            TopologyWrapper.getInternalTopologyBuilder(builder.build()).setApplicationId("X").buildTopology().processors().size());
+    }
+
+    @SuppressWarnings({"rawtypes", "deprecation"}) // specifically testing the deprecated variant
+    @Test
+    public void shouldPreserveSerdesForOperators() {
+        final StreamsBuilder builder = new StreamsBuilder();
+        final KStream<String, String> stream1 = builder.stream(Collections.singleton("topic-1"), stringConsumed);
+        final KTable<String, String> table1 = builder.table("topic-2", stringConsumed);
+        final GlobalKTable<String, String> table2 = builder.globalTable("topic-2", stringConsumed);
+        final ConsumedInternal<String, String> consumedInternal = new ConsumedInternal<>(stringConsumed);
+
+        final KeyValueMapper<String, String, String> selector = (key, value) -> key;
+        final KeyValueMapper<String, String, Iterable<KeyValue<String, String>>> flatSelector = (key, value) -> Collections.singleton(new KeyValue<>(key, value));
+        final ValueMapper<String, String> mapper = value -> value;
+        final ValueMapper<String, Iterable<String>> flatMapper = Collections::singleton;
+        final ValueJoiner<String, String, String> joiner = (value1, value2) -> value1;
+
+        assertEquals(((AbstractStream) stream1.filter((key, value) -> false)).keySerde(), consumedInternal.keySerde());
+        assertEquals(((AbstractStream) stream1.filter((key, value) -> false)).valueSerde(), consumedInternal.valueSerde());
+
+        assertEquals(((AbstractStream) stream1.filterNot((key, value) -> false)).keySerde(), consumedInternal.keySerde());
+        assertEquals(((AbstractStream) stream1.filterNot((key, value) -> false)).valueSerde(), consumedInternal.valueSerde());
+
+        assertNull(((AbstractStream) stream1.selectKey(selector)).keySerde());
+        assertEquals(((AbstractStream) stream1.selectKey(selector)).valueSerde(), consumedInternal.valueSerde());
+
+        assertNull(((AbstractStream) stream1.map(KeyValue::new)).keySerde());
+        assertNull(((AbstractStream) stream1.map(KeyValue::new)).valueSerde());
+
+        assertEquals(((AbstractStream) stream1.mapValues(mapper)).keySerde(), consumedInternal.keySerde());
+        assertNull(((AbstractStream) stream1.mapValues(mapper)).valueSerde());
+
+        assertNull(((AbstractStream) stream1.flatMap(flatSelector)).keySerde());
+        assertNull(((AbstractStream) stream1.flatMap(flatSelector)).valueSerde());
+
+        assertEquals(((AbstractStream) stream1.flatMapValues(flatMapper)).keySerde(), consumedInternal.keySerde());
+        assertNull(((AbstractStream) stream1.flatMapValues(flatMapper)).valueSerde());
+
+        assertNull(((AbstractStream) stream1.transform(transformerSupplier)).keySerde());
+        assertNull(((AbstractStream) stream1.transform(transformerSupplier)).valueSerde());
+
+        assertEquals(((AbstractStream) stream1.transformValues(valueTransformerSupplier)).keySerde(), consumedInternal.keySerde());
+        assertNull(((AbstractStream) stream1.transformValues(valueTransformerSupplier)).valueSerde());
+
+        assertNull(((AbstractStream) stream1.merge(stream1)).keySerde());
+        assertNull(((AbstractStream) stream1.merge(stream1)).valueSerde());
+
+        assertEquals(((AbstractStream) stream1.through("topic-3")).keySerde(), consumedInternal.keySerde());
+        assertEquals(((AbstractStream) stream1.through("topic-3")).valueSerde(), consumedInternal.valueSerde());
+        assertEquals(((AbstractStream) stream1.through("topic-3", Produced.with(mySerde, mySerde))).keySerde(), mySerde);
+        assertEquals(((AbstractStream) stream1.through("topic-3", Produced.with(mySerde, mySerde))).valueSerde(), mySerde);
+
+        assertEquals(((AbstractStream) stream1.repartition()).keySerde(), consumedInternal.keySerde());
+        assertEquals(((AbstractStream) stream1.repartition()).valueSerde(), consumedInternal.valueSerde());
+        assertEquals(((AbstractStream) stream1.repartition(Repartitioned.with(mySerde, mySerde))).keySerde(), mySerde);
+        assertEquals(((AbstractStream) stream1.repartition(Repartitioned.with(mySerde, mySerde))).valueSerde(), mySerde);
+
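+        // The assertions below apply the same serde-propagation checks to groupByKey()/groupBy() and to the
+        // stream-stream, stream-table, and global-table join variants.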
assertEquals(((AbstractStream) stream1.groupByKey()).keySerde(), consumedInternal.keySerde()); + assertEquals(((AbstractStream) stream1.groupByKey()).valueSerde(), consumedInternal.valueSerde()); + assertEquals(((AbstractStream) stream1.groupByKey(Grouped.with(mySerde, mySerde))).keySerde(), mySerde); + assertEquals(((AbstractStream) stream1.groupByKey(Grouped.with(mySerde, mySerde))).valueSerde(), mySerde); + + assertNull(((AbstractStream) stream1.groupBy(selector)).keySerde()); + assertEquals(((AbstractStream) stream1.groupBy(selector)).valueSerde(), consumedInternal.valueSerde()); + assertEquals(((AbstractStream) stream1.groupBy(selector, Grouped.with(mySerde, mySerde))).keySerde(), mySerde); + assertEquals(((AbstractStream) stream1.groupBy(selector, Grouped.with(mySerde, mySerde))).valueSerde(), mySerde); + + assertNull(((AbstractStream) stream1.join(stream1, joiner, JoinWindows.of(Duration.ofMillis(100L)))).keySerde()); + assertNull(((AbstractStream) stream1.join(stream1, joiner, JoinWindows.of(Duration.ofMillis(100L)))).valueSerde()); + assertEquals(((AbstractStream) stream1.join(stream1, joiner, JoinWindows.of(Duration.ofMillis(100L)), StreamJoined.with(mySerde, mySerde, mySerde))).keySerde(), mySerde); + assertNull(((AbstractStream) stream1.join(stream1, joiner, JoinWindows.of(Duration.ofMillis(100L)), StreamJoined.with(mySerde, mySerde, mySerde))).valueSerde()); + + assertNull(((AbstractStream) stream1.leftJoin(stream1, joiner, JoinWindows.of(Duration.ofMillis(100L)))).keySerde()); + assertNull(((AbstractStream) stream1.leftJoin(stream1, joiner, JoinWindows.of(Duration.ofMillis(100L)))).valueSerde()); + assertEquals(((AbstractStream) stream1.leftJoin(stream1, joiner, JoinWindows.of(Duration.ofMillis(100L)), StreamJoined.with(mySerde, mySerde, mySerde))).keySerde(), mySerde); + assertNull(((AbstractStream) stream1.leftJoin(stream1, joiner, JoinWindows.of(Duration.ofMillis(100L)), StreamJoined.with(mySerde, mySerde, mySerde))).valueSerde()); + + assertNull(((AbstractStream) stream1.outerJoin(stream1, joiner, JoinWindows.of(Duration.ofMillis(100L)))).keySerde()); + assertNull(((AbstractStream) stream1.outerJoin(stream1, joiner, JoinWindows.of(Duration.ofMillis(100L)))).valueSerde()); + assertEquals(((AbstractStream) stream1.outerJoin(stream1, joiner, JoinWindows.of(Duration.ofMillis(100L)), StreamJoined.with(mySerde, mySerde, mySerde))).keySerde(), mySerde); + assertNull(((AbstractStream) stream1.outerJoin(stream1, joiner, JoinWindows.of(Duration.ofMillis(100L)), StreamJoined.with(mySerde, mySerde, mySerde))).valueSerde()); + + assertEquals(((AbstractStream) stream1.join(table1, joiner)).keySerde(), consumedInternal.keySerde()); + assertNull(((AbstractStream) stream1.join(table1, joiner)).valueSerde()); + assertEquals(((AbstractStream) stream1.join(table1, joiner, Joined.with(mySerde, mySerde, mySerde))).keySerde(), mySerde); + assertNull(((AbstractStream) stream1.join(table1, joiner, Joined.with(mySerde, mySerde, mySerde))).valueSerde()); + + assertEquals(((AbstractStream) stream1.leftJoin(table1, joiner)).keySerde(), consumedInternal.keySerde()); + assertNull(((AbstractStream) stream1.leftJoin(table1, joiner)).valueSerde()); + assertEquals(((AbstractStream) stream1.leftJoin(table1, joiner, Joined.with(mySerde, mySerde, mySerde))).keySerde(), mySerde); + assertNull(((AbstractStream) stream1.leftJoin(table1, joiner, Joined.with(mySerde, mySerde, mySerde))).valueSerde()); + + assertEquals(((AbstractStream) stream1.join(table2, selector, joiner)).keySerde(), 
consumedInternal.keySerde()); + assertNull(((AbstractStream) stream1.join(table2, selector, joiner)).valueSerde()); + + assertEquals(((AbstractStream) stream1.leftJoin(table2, selector, joiner)).keySerde(), consumedInternal.keySerde()); + assertNull(((AbstractStream) stream1.leftJoin(table2, selector, joiner)).valueSerde()); + } + + @Deprecated + @Test + public void shouldUseRecordMetadataTimestampExtractorWithThrough() { + final StreamsBuilder builder = new StreamsBuilder(); + final KStream stream1 = builder.stream(Arrays.asList("topic-1", "topic-2"), stringConsumed); + final KStream stream2 = builder.stream(Arrays.asList("topic-3", "topic-4"), stringConsumed); + + stream1.to("topic-5"); + stream2.through("topic-6"); + + final ProcessorTopology processorTopology = TopologyWrapper.getInternalTopologyBuilder(builder.build()).setApplicationId("X").buildTopology(); + assertThat(processorTopology.source("topic-6").getTimestampExtractor(), instanceOf(FailOnInvalidTimestamp.class)); + assertNull(processorTopology.source("topic-4").getTimestampExtractor()); + assertNull(processorTopology.source("topic-3").getTimestampExtractor()); + assertNull(processorTopology.source("topic-2").getTimestampExtractor()); + assertNull(processorTopology.source("topic-1").getTimestampExtractor()); + } + + @Test + public void shouldUseRecordMetadataTimestampExtractorWithRepartition() { + final StreamsBuilder builder = new StreamsBuilder(); + final KStream stream1 = builder.stream(Arrays.asList("topic-1", "topic-2"), stringConsumed); + final KStream stream2 = builder.stream(Arrays.asList("topic-3", "topic-4"), stringConsumed); + + stream1.to("topic-5"); + stream2.repartition(Repartitioned.as("topic-6")); + + final ProcessorTopology processorTopology = TopologyWrapper.getInternalTopologyBuilder(builder.build()).setApplicationId("X").buildTopology(); + assertThat(processorTopology.source("X-topic-6-repartition").getTimestampExtractor(), instanceOf(FailOnInvalidTimestamp.class)); + assertNull(processorTopology.source("topic-4").getTimestampExtractor()); + assertNull(processorTopology.source("topic-3").getTimestampExtractor()); + assertNull(processorTopology.source("topic-2").getTimestampExtractor()); + assertNull(processorTopology.source("topic-1").getTimestampExtractor()); + } + + @Deprecated + @Test + public void shouldSendDataThroughTopicUsingProduced() { + final StreamsBuilder builder = new StreamsBuilder(); + final String input = "topic"; + final KStream stream = builder.stream(input, stringConsumed); + stream.through("through-topic", Produced.with(Serdes.String(), Serdes.String())).process(processorSupplier); + + try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) { + final TestInputTopic inputTopic = + driver.createInputTopic(input, new StringSerializer(), new StringSerializer(), Instant.ofEpochMilli(0L), Duration.ZERO); + inputTopic.pipeInput("a", "b"); + } + assertThat(processorSupplier.theCapturedProcessor().processed(), equalTo(Collections.singletonList(new KeyValueTimestamp<>("a", "b", 0)))); + } + + @Test + public void shouldSendDataThroughRepartitionTopicUsingRepartitioned() { + final StreamsBuilder builder = new StreamsBuilder(); + final String input = "topic"; + final KStream stream = builder.stream(input, stringConsumed); + stream.repartition(Repartitioned.with(Serdes.String(), Serdes.String())).process(processorSupplier); + + try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) { + final TestInputTopic inputTopic = + 
driver.createInputTopic(input, new StringSerializer(), new StringSerializer(), Instant.ofEpochMilli(0L), Duration.ZERO); + inputTopic.pipeInput("a", "b"); + } + assertThat(processorSupplier.theCapturedProcessor().processed(), equalTo(Collections.singletonList(new KeyValueTimestamp<>("a", "b", 0)))); + } + + @Test + public void shouldSendDataToTopicUsingProduced() { + final StreamsBuilder builder = new StreamsBuilder(); + final String input = "topic"; + final KStream stream = builder.stream(input, stringConsumed); + stream.to("to-topic", Produced.with(Serdes.String(), Serdes.String())); + builder.stream("to-topic", stringConsumed).process(processorSupplier); + + try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) { + final TestInputTopic inputTopic = + driver.createInputTopic(input, new StringSerializer(), new StringSerializer(), Instant.ofEpochMilli(0L), Duration.ZERO); + inputTopic.pipeInput("e", "f"); + } + assertThat(processorSupplier.theCapturedProcessor().processed(), equalTo(Collections.singletonList(new KeyValueTimestamp<>("e", "f", 0)))); + } + + @Test + public void shouldSendDataToDynamicTopics() { + final StreamsBuilder builder = new StreamsBuilder(); + final String input = "topic"; + final KStream stream = builder.stream(input, stringConsumed); + stream.to((key, value, context) -> context.topic() + "-" + key + "-" + value.substring(0, 1), + Produced.with(Serdes.String(), Serdes.String())); + builder.stream(input + "-a-v", stringConsumed).process(processorSupplier); + builder.stream(input + "-b-v", stringConsumed).process(processorSupplier); + + try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) { + final TestInputTopic inputTopic = + driver.createInputTopic(input, new StringSerializer(), new StringSerializer(), Instant.ofEpochMilli(0L), Duration.ZERO); + inputTopic.pipeInput("a", "v1"); + inputTopic.pipeInput("a", "v2"); + inputTopic.pipeInput("b", "v1"); + } + final List> mockProcessors = processorSupplier.capturedProcessors(2); + assertThat(mockProcessors.get(0).processed(), equalTo(asList(new KeyValueTimestamp<>("a", "v1", 0), + new KeyValueTimestamp<>("a", "v2", 0)))); + assertThat(mockProcessors.get(1).processed(), equalTo(Collections.singletonList(new KeyValueTimestamp<>("b", "v1", 0)))); + } + + @SuppressWarnings("deprecation") + @Test + public void shouldUseRecordMetadataTimestampExtractorWhenInternalRepartitioningTopicCreatedWithRetention() { + final StreamsBuilder builder = new StreamsBuilder(); + final KStream kStream = builder.stream("topic-1", stringConsumed); + final ValueJoiner valueJoiner = MockValueJoiner.instance(":"); + final long windowSize = TimeUnit.MILLISECONDS.convert(1, TimeUnit.DAYS); + final KStream stream = kStream + .map((key, value) -> KeyValue.pair(value, value)); + stream.join(kStream, + valueJoiner, + JoinWindows.of(ofMillis(windowSize)).grace(ofMillis(3 * windowSize)), + StreamJoined.with(Serdes.String(), Serdes.String(), Serdes.String())) + .to("output-topic", Produced.with(Serdes.String(), Serdes.String())); + + final ProcessorTopology topology = TopologyWrapper.getInternalTopologyBuilder(builder.build()).setApplicationId("X").buildTopology(); + + final SourceNode originalSourceNode = topology.source("topic-1"); + + for (final SourceNode sourceNode : topology.sources()) { + if (sourceNode.name().equals(originalSourceNode.name())) { + assertNull(sourceNode.getTimestampExtractor()); + } else { + assertThat(sourceNode.getTimestampExtractor(), 
instanceOf(FailOnInvalidTimestamp.class)); + } + } + } + + @SuppressWarnings("deprecation") + @Test + public void shouldUseRecordMetadataTimestampExtractorWhenInternalRepartitioningTopicCreated() { + final StreamsBuilder builder = new StreamsBuilder(); + final KStream kStream = builder.stream("topic-1", stringConsumed); + final ValueJoiner valueJoiner = MockValueJoiner.instance(":"); + final long windowSize = TimeUnit.MILLISECONDS.convert(1, TimeUnit.DAYS); + final KStream stream = kStream + .map((key, value) -> KeyValue.pair(value, value)); + stream.join( + kStream, + valueJoiner, + JoinWindows.of(ofMillis(windowSize)).grace(ofMillis(3L * windowSize)), + StreamJoined.with(Serdes.String(), Serdes.String(), Serdes.String()) + ) + .to("output-topic", Produced.with(Serdes.String(), Serdes.String())); + + final ProcessorTopology topology = TopologyWrapper.getInternalTopologyBuilder(builder.build()).setApplicationId("X").buildTopology(); + + final SourceNode originalSourceNode = topology.source("topic-1"); + + for (final SourceNode sourceNode : topology.sources()) { + if (sourceNode.name().equals(originalSourceNode.name())) { + assertNull(sourceNode.getTimestampExtractor()); + } else { + assertThat(sourceNode.getTimestampExtractor(), instanceOf(FailOnInvalidTimestamp.class)); + } + } + } + + @Test + public void shouldPropagateRepartitionFlagAfterGlobalKTableJoin() { + final StreamsBuilder builder = new StreamsBuilder(); + final GlobalKTable globalKTable = builder.globalTable("globalTopic"); + final KeyValueMapper kvMappper = (k, v) -> k + v; + final ValueJoiner valueJoiner = (v1, v2) -> v1 + v2; + builder.stream("topic").selectKey((k, v) -> v) + .join(globalKTable, kvMappper, valueJoiner) + .groupByKey() + .count(); + + final Pattern repartitionTopicPattern = Pattern.compile("Sink: .*-repartition"); + final String topology = builder.build().describe().toString(); + final Matcher matcher = repartitionTopicPattern.matcher(topology); + assertTrue(matcher.find()); + final String match = matcher.group(); + assertThat(match, notNullValue()); + assertTrue(match.endsWith("repartition")); + } + + @Test + public void shouldMergeTwoStreams() { + final String topic1 = "topic-1"; + final String topic2 = "topic-2"; + + final KStream source1 = builder.stream(topic1); + final KStream source2 = builder.stream(topic2); + final KStream merged = source1.merge(source2); + + merged.process(processorSupplier); + + try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) { + final TestInputTopic inputTopic1 = + driver.createInputTopic(topic1, new StringSerializer(), new StringSerializer(), Instant.ofEpochMilli(0L), Duration.ZERO); + final TestInputTopic inputTopic2 = + driver.createInputTopic(topic2, new StringSerializer(), new StringSerializer(), Instant.ofEpochMilli(0L), Duration.ZERO); + inputTopic1.pipeInput("A", "aa"); + inputTopic2.pipeInput("B", "bb"); + inputTopic2.pipeInput("C", "cc"); + inputTopic1.pipeInput("D", "dd"); + } + + assertEquals(asList(new KeyValueTimestamp<>("A", "aa", 0), + new KeyValueTimestamp<>("B", "bb", 0), + new KeyValueTimestamp<>("C", "cc", 0), + new KeyValueTimestamp<>("D", "dd", 0)), processorSupplier.theCapturedProcessor().processed()); + } + + @Test + public void shouldMergeMultipleStreams() { + final String topic1 = "topic-1"; + final String topic2 = "topic-2"; + final String topic3 = "topic-3"; + final String topic4 = "topic-4"; + + final KStream source1 = builder.stream(topic1); + final KStream source2 = builder.stream(topic2); + final KStream 
source3 = builder.stream(topic3); + final KStream source4 = builder.stream(topic4); + final KStream merged = source1.merge(source2).merge(source3).merge(source4); + + merged.process(processorSupplier); + + try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) { + final TestInputTopic inputTopic1 = + driver.createInputTopic(topic1, new StringSerializer(), new StringSerializer(), Instant.ofEpochMilli(0L), Duration.ZERO); + final TestInputTopic inputTopic2 = + driver.createInputTopic(topic2, new StringSerializer(), new StringSerializer(), Instant.ofEpochMilli(0L), Duration.ZERO); + final TestInputTopic inputTopic3 = + driver.createInputTopic(topic3, new StringSerializer(), new StringSerializer(), Instant.ofEpochMilli(0L), Duration.ZERO); + final TestInputTopic inputTopic4 = + driver.createInputTopic(topic4, new StringSerializer(), new StringSerializer(), Instant.ofEpochMilli(0L), Duration.ZERO); + + inputTopic1.pipeInput("A", "aa", 1L); + inputTopic2.pipeInput("B", "bb", 9L); + inputTopic3.pipeInput("C", "cc", 2L); + inputTopic4.pipeInput("D", "dd", 8L); + inputTopic4.pipeInput("E", "ee", 3L); + inputTopic3.pipeInput("F", "ff", 7L); + inputTopic2.pipeInput("G", "gg", 4L); + inputTopic1.pipeInput("H", "hh", 6L); + } + + assertEquals(asList(new KeyValueTimestamp<>("A", "aa", 1), + new KeyValueTimestamp<>("B", "bb", 9), + new KeyValueTimestamp<>("C", "cc", 2), + new KeyValueTimestamp<>("D", "dd", 8), + new KeyValueTimestamp<>("E", "ee", 3), + new KeyValueTimestamp<>("F", "ff", 7), + new KeyValueTimestamp<>("G", "gg", 4), + new KeyValueTimestamp<>("H", "hh", 6)), + processorSupplier.theCapturedProcessor().processed()); + } + + @Test + public void shouldProcessFromSourceThatMatchPattern() { + final KStream pattern2Source = builder.stream(Pattern.compile("topic-\\d")); + + pattern2Source.process(processorSupplier); + + try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) { + final TestInputTopic inputTopic3 = + driver.createInputTopic("topic-3", new StringSerializer(), new StringSerializer(), Instant.ofEpochMilli(0L), Duration.ZERO); + final TestInputTopic inputTopic4 = + driver.createInputTopic("topic-4", new StringSerializer(), new StringSerializer(), Instant.ofEpochMilli(0L), Duration.ZERO); + final TestInputTopic inputTopic5 = + driver.createInputTopic("topic-5", new StringSerializer(), new StringSerializer(), Instant.ofEpochMilli(0L), Duration.ZERO); + final TestInputTopic inputTopic6 = + driver.createInputTopic("topic-6", new StringSerializer(), new StringSerializer(), Instant.ofEpochMilli(0L), Duration.ZERO); + final TestInputTopic inputTopic7 = + driver.createInputTopic("topic-7", new StringSerializer(), new StringSerializer(), Instant.ofEpochMilli(0L), Duration.ZERO); + + inputTopic3.pipeInput("A", "aa", 1L); + inputTopic4.pipeInput("B", "bb", 5L); + inputTopic5.pipeInput("C", "cc", 10L); + inputTopic6.pipeInput("D", "dd", 8L); + inputTopic7.pipeInput("E", "ee", 3L); + } + + assertEquals(asList(new KeyValueTimestamp<>("A", "aa", 1), + new KeyValueTimestamp<>("B", "bb", 5), + new KeyValueTimestamp<>("C", "cc", 10), + new KeyValueTimestamp<>("D", "dd", 8), + new KeyValueTimestamp<>("E", "ee", 3)), + processorSupplier.theCapturedProcessor().processed()); + } + + @Test + public void shouldProcessFromSourcesThatMatchMultiplePattern() { + final String topic3 = "topic-without-pattern"; + + final KStream pattern2Source1 = builder.stream(Pattern.compile("topic-\\d")); + final KStream pattern2Source2 = 
builder.stream(Pattern.compile("topic-[A-Z]")); + final KStream source3 = builder.stream(topic3); + final KStream merged = pattern2Source1.merge(pattern2Source2).merge(source3); + + merged.process(processorSupplier); + + try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) { + final TestInputTopic inputTopic3 = + driver.createInputTopic("topic-3", new StringSerializer(), new StringSerializer(), Instant.ofEpochMilli(0L), Duration.ZERO); + final TestInputTopic inputTopic4 = + driver.createInputTopic("topic-4", new StringSerializer(), new StringSerializer(), Instant.ofEpochMilli(0L), Duration.ZERO); + final TestInputTopic inputTopicA = + driver.createInputTopic("topic-A", new StringSerializer(), new StringSerializer(), Instant.ofEpochMilli(0L), Duration.ZERO); + final TestInputTopic inputTopicZ = + driver.createInputTopic("topic-Z", new StringSerializer(), new StringSerializer(), Instant.ofEpochMilli(0L), Duration.ZERO); + final TestInputTopic inputTopic = + driver.createInputTopic(topic3, new StringSerializer(), new StringSerializer(), Instant.ofEpochMilli(0L), Duration.ZERO); + + inputTopic3.pipeInput("A", "aa", 1L); + inputTopic4.pipeInput("B", "bb", 5L); + inputTopicA.pipeInput("C", "cc", 10L); + inputTopicZ.pipeInput("D", "dd", 8L); + inputTopic.pipeInput("E", "ee", 3L); + } + + assertEquals(asList(new KeyValueTimestamp<>("A", "aa", 1), + new KeyValueTimestamp<>("B", "bb", 5), + new KeyValueTimestamp<>("C", "cc", 10), + new KeyValueTimestamp<>("D", "dd", 8), + new KeyValueTimestamp<>("E", "ee", 3)), + processorSupplier.theCapturedProcessor().processed()); + } + + @Test + public void shouldNotAllowNullTransformerSupplierOnTransform() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.transform(null)); + assertThat(exception.getMessage(), equalTo("transformerSupplier can't be null")); + } + + @Test + public void shouldNotAllowNullTransformerSupplierOnTransformWithStores() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.transform(null, "storeName")); + assertThat(exception.getMessage(), equalTo("transformerSupplier can't be null")); + } + + @Test + public void shouldNotAllowNullTransformerSupplierOnTransformWithNamed() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.transform(null, Named.as("transformer"))); + assertThat(exception.getMessage(), equalTo("transformerSupplier can't be null")); + } + + @Test + public void shouldNotAllowNullTransformerSupplierOnTransformWithNamedAndStores() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.transform(null, Named.as("transformer"), "storeName")); + assertThat(exception.getMessage(), equalTo("transformerSupplier can't be null")); + } + + @Test + public void shouldNotAllowNullStoreNamesOnTransform() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.transform(transformerSupplier, (String[]) null)); + assertThat(exception.getMessage(), equalTo("stateStoreNames can't be a null array")); + } + + @Test + public void shouldNotAllowNullStoreNameOnTransform() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.transform(transformerSupplier, (String) null)); + assertThat(exception.getMessage(), equalTo("stateStoreNames can't contain `null` as store name")); + } + + @Test + public void 
shouldNotAllowNullStoreNamesOnTransformWithNamed() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.transform(transformerSupplier, Named.as("transform"), (String[]) null)); + assertThat(exception.getMessage(), equalTo("stateStoreNames can't be a null array")); + } + + @Test + public void shouldNotAllowNullStoreNameOnTransformWithNamed() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.transform(transformerSupplier, Named.as("transform"), (String) null)); + assertThat(exception.getMessage(), equalTo("stateStoreNames can't contain `null` as store name")); + } + + @Test + public void shouldNotAllowNullNamedOnTransform() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.transform(transformerSupplier, (Named) null)); + assertThat(exception.getMessage(), equalTo("named can't be null")); + } + + @Test + public void shouldNotAllowNullNamedOnTransformWithStoreName() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.transform(transformerSupplier, (Named) null, "storeName")); + assertThat(exception.getMessage(), equalTo("named can't be null")); + } + + @Test + public void shouldNotAllowBadTransformerSupplierOnFlatTransform() { + final Transformer>> transformer = flatTransformerSupplier.get(); + final IllegalArgumentException exception = assertThrows( + IllegalArgumentException.class, + () -> testStream.flatTransform(() -> transformer) + ); + assertThat(exception.getMessage(), containsString("#get() must return a new object each time it is called.")); + } + + @Test + public void shouldNotAllowBadTransformerSupplierOnFlatTransformWithStores() { + final Transformer>> transformer = flatTransformerSupplier.get(); + final IllegalArgumentException exception = assertThrows( + IllegalArgumentException.class, + () -> testStream.flatTransform(() -> transformer, "storeName") + ); + assertThat(exception.getMessage(), containsString("#get() must return a new object each time it is called.")); + } + + @Test + public void shouldNotAllowBadTransformerSupplierOnFlatTransformWithNamed() { + final Transformer>> transformer = flatTransformerSupplier.get(); + final IllegalArgumentException exception = assertThrows( + IllegalArgumentException.class, + () -> testStream.flatTransform(() -> transformer, Named.as("flatTransformer")) + ); + assertThat(exception.getMessage(), containsString("#get() must return a new object each time it is called.")); + } + + @Test + public void shouldNotAllowBadTransformerSupplierOnFlatTransformWithNamedAndStores() { + final Transformer>> transformer = flatTransformerSupplier.get(); + final IllegalArgumentException exception = assertThrows( + IllegalArgumentException.class, + () -> testStream.flatTransform(() -> transformer, Named.as("flatTransformer"), "storeName") + ); + assertThat(exception.getMessage(), containsString("#get() must return a new object each time it is called.")); + } + + @Test + public void shouldNotAllowNullTransformerSupplierOnFlatTransform() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.flatTransform(null)); + assertThat(exception.getMessage(), equalTo("transformerSupplier can't be null")); + } + + @Test + public void shouldNotAllowNullTransformerSupplierOnFlatTransformWithStores() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> 
testStream.flatTransform(null, "storeName")); + assertThat(exception.getMessage(), equalTo("transformerSupplier can't be null")); + } + + @Test + public void shouldNotAllowNullTransformerSupplierOnFlatTransformWithNamed() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.flatTransform(null, Named.as("flatTransformer"))); + assertThat(exception.getMessage(), equalTo("transformerSupplier can't be null")); + } + + @Test + public void shouldNotAllowNullTransformerSupplierOnFlatTransformWithNamedAndStores() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.flatTransform(null, Named.as("flatTransformer"), "storeName")); + assertThat(exception.getMessage(), equalTo("transformerSupplier can't be null")); + } + + @Test + public void shouldNotAllowNullStoreNamesOnFlatTransform() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.flatTransform(flatTransformerSupplier, (String[]) null)); + assertThat(exception.getMessage(), equalTo("stateStoreNames can't be a null array")); + } + + @Test + public void shouldNotAllowNullStoreNameOnFlatTransform() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.flatTransform(flatTransformerSupplier, (String) null)); + assertThat(exception.getMessage(), equalTo("stateStoreNames can't contain `null` as store name")); + } + + @Test + public void shouldNotAllowNullStoreNamesOnFlatTransformWithNamed() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.flatTransform(flatTransformerSupplier, Named.as("flatTransform"), (String[]) null)); + assertThat(exception.getMessage(), equalTo("stateStoreNames can't be a null array")); + } + + @Test + public void shouldNotAllowNullStoreNameOnFlatTransformWithNamed() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.flatTransform(flatTransformerSupplier, Named.as("flatTransform"), (String) null)); + assertThat(exception.getMessage(), equalTo("stateStoreNames can't contain `null` as store name")); + } + + @Test + public void shouldNotAllowNullNamedOnFlatTransform() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.flatTransform(flatTransformerSupplier, (Named) null)); + assertThat(exception.getMessage(), equalTo("named can't be null")); + } + + @Test + public void shouldNotAllowNullNamedOnFlatTransformWithStoreName() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.flatTransform(flatTransformerSupplier, (Named) null, "storeName")); + assertThat(exception.getMessage(), equalTo("named can't be null")); + } + + @Test + public void shouldNotAllowBadTransformerSupplierOnTransformValues() { + final ValueTransformer transformer = valueTransformerSupplier.get(); + final IllegalArgumentException exception = assertThrows( + IllegalArgumentException.class, + () -> testStream.transformValues(() -> transformer) + ); + assertThat(exception.getMessage(), containsString("#get() must return a new object each time it is called.")); + } + + @Test + public void shouldNotAllowBadTransformerSupplierOnTransformValuesWithNamed() { + final ValueTransformer transformer = valueTransformerSupplier.get(); + final IllegalArgumentException exception = assertThrows( + IllegalArgumentException.class, + () -> 
testStream.transformValues(() -> transformer, Named.as("transformer")) + ); + assertThat(exception.getMessage(), containsString("#get() must return a new object each time it is called.")); + } + + @Test + public void shouldNotAllowNullValueTransformerSupplierOnTransformValues() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.transformValues((ValueTransformerSupplier) null)); + assertThat(exception.getMessage(), equalTo("valueTransformerSupplier can't be null")); + } + + @Test + public void shouldNotAllowBadValueTransformerWithKeySupplierOnTransformValues() { + final ValueTransformerWithKey transformer = valueTransformerWithKeySupplier.get(); + final IllegalArgumentException exception = assertThrows( + IllegalArgumentException.class, + () -> testStream.transformValues(() -> transformer) + ); + assertThat(exception.getMessage(), containsString("#get() must return a new object each time it is called.")); + } + + @Test + public void shouldNotAllowBadValueTransformerWithKeySupplierOnTransformValuesWithNamed() { + final ValueTransformerWithKey transformer = valueTransformerWithKeySupplier.get(); + final IllegalArgumentException exception = assertThrows( + IllegalArgumentException.class, + () -> testStream.transformValues(() -> transformer, Named.as("transformer")) + ); + assertThat(exception.getMessage(), containsString("#get() must return a new object each time it is called.")); + } + + @Test + public void shouldNotAllowNullValueTransformerWithKeySupplierOnTransformValues() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.transformValues((ValueTransformerWithKeySupplier) null)); + assertThat(exception.getMessage(), equalTo("valueTransformerSupplier can't be null")); + } + + @Test + public void shouldNotAllowNullValueTransformerSupplierOnTransformValuesWithStores() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.transformValues( + (ValueTransformerSupplier) null, + "storeName")); + assertThat(exception.getMessage(), equalTo("valueTransformerSupplier can't be null")); + } + + @Test + public void shouldNotAllowNullValueTransformerWithKeySupplierOnTransformValuesWithStores() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.transformValues( + (ValueTransformerWithKeySupplier) null, + "storeName")); + assertThat(exception.getMessage(), equalTo("valueTransformerSupplier can't be null")); + } + + @Test + public void shouldNotAllowNullValueTransformerSupplierOnTransformValuesWithNamed() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.transformValues( + (ValueTransformerSupplier) null, + Named.as("valueTransformer"))); + assertThat(exception.getMessage(), equalTo("valueTransformerSupplier can't be null")); + } + + @Test + public void shouldNotAllowNullValueTransformerWithKeySupplierOnTransformValuesWithNamed() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.transformValues( + (ValueTransformerWithKeySupplier) null, + Named.as("valueTransformerWithKey"))); + assertThat(exception.getMessage(), equalTo("valueTransformerSupplier can't be null")); + } + + @Test + public void shouldNotAllowNullValueTransformerSupplierOnTransformValuesWithNamedAndStores() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> 
testStream.transformValues( + (ValueTransformerSupplier) null, + Named.as("valueTransformer"), + "storeName")); + assertThat(exception.getMessage(), equalTo("valueTransformerSupplier can't be null")); + } + + @Test + public void shouldNotAllowNullValueTransformerWithKeySupplierOnTransformValuesWithNamedAndStores() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.transformValues( + (ValueTransformerWithKeySupplier) null, + Named.as("valueTransformerWithKey"), + "storeName")); + assertThat(exception.getMessage(), equalTo("valueTransformerSupplier can't be null")); + } + + @Test + public void shouldNotAllowNullStoreNamesOnTransformValuesWithValueTransformerSupplier() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.transformValues( + valueTransformerSupplier, + (String[]) null)); + assertThat(exception.getMessage(), equalTo("stateStoreNames can't be a null array")); + } + + @Test + public void shouldNotAllowNullStoreNamesOnTransformValuesWithValueTransformerWithKeySupplier() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.transformValues( + valueTransformerWithKeySupplier, + (String[]) null)); + assertThat(exception.getMessage(), equalTo("stateStoreNames can't be a null array")); + } + + @Test + public void shouldNotAllowNullStoreNameOnTransformValuesWithValueTransformerSupplier() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.transformValues( + valueTransformerSupplier, (String) null)); + assertThat(exception.getMessage(), equalTo("stateStoreNames can't contain `null` as store name")); + } + + @Test + public void shouldNotAllowNullStoreNameOnTransformValuesWithValueTransformerWithKeySupplier() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.transformValues( + valueTransformerWithKeySupplier, + (String) null)); + assertThat(exception.getMessage(), equalTo("stateStoreNames can't contain `null` as store name")); + } + + @Test + public void shouldNotAllowNullStoreNamesOnTransformValuesWithValueTransformerSupplierWithNamed() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.transformValues( + valueTransformerSupplier, + Named.as("valueTransformer"), + (String[]) null)); + assertThat(exception.getMessage(), equalTo("stateStoreNames can't be a null array")); + } + + @Test + public void shouldNotAllowNullStoreNamesOnTransformValuesWithValueTransformerWithKeySupplierWithNamed() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.transformValues( + valueTransformerWithKeySupplier, + Named.as("valueTransformer"), + (String[]) null)); + assertThat(exception.getMessage(), equalTo("stateStoreNames can't be a null array")); + } + + @Test + public void shouldNotAllowNullStoreNameOnTransformValuesWithValueTransformerSupplierWithNamed() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.transformValues( + valueTransformerSupplier, + Named.as("valueTransformer"), + (String) null)); + assertThat(exception.getMessage(), equalTo("stateStoreNames can't contain `null` as store name")); + } + + @Test + public void shouldNotAllowNullStoreNameOnTransformValuesWithValueTransformerWithKeySupplierWithName() { + final NullPointerException exception = assertThrows( + 
NullPointerException.class,
+            () -> testStream.transformValues(
+                valueTransformerWithKeySupplier,
+                Named.as("valueTransformerWithKey"),
+                (String) null));
+        assertThat(exception.getMessage(), equalTo("stateStoreNames can't contain `null` as store name"));
+    }
+
+    @Test
+    public void shouldNotAllowNullNamedOnTransformValuesWithValueTransformerSupplier() {
+        final NullPointerException exception = assertThrows(
+            NullPointerException.class,
+            () -> testStream.transformValues(
+                valueTransformerSupplier,
+                (Named) null));
+        assertThat(exception.getMessage(), equalTo("named can't be null"));
+    }
+
+    @Test
+    public void shouldNotAllowNullNamedOnTransformValuesWithValueTransformerWithKeySupplier() {
+        final NullPointerException exception = assertThrows(
+            NullPointerException.class,
+            () -> testStream.transformValues(
+                valueTransformerWithKeySupplier,
+                (Named) null));
+        assertThat(exception.getMessage(), equalTo("named can't be null"));
+    }
+
+    @Test
+    public void shouldNotAllowNullNamedOnTransformValuesWithValueTransformerSupplierAndStores() {
+        final NullPointerException exception = assertThrows(
+            NullPointerException.class,
+            () -> testStream.transformValues(
+                valueTransformerSupplier,
+                (Named) null,
+                "storeName"));
+        assertThat(exception.getMessage(), equalTo("named can't be null"));
+    }
+
+    @Test
+    public void shouldNotAllowNullNamedOnTransformValuesWithValueTransformerWithKeySupplierAndStores() {
+        final NullPointerException exception = assertThrows(
+            NullPointerException.class,
+            () -> testStream.transformValues(
+                valueTransformerWithKeySupplier,
+                (Named) null,
+                "storeName"));
+        assertThat(exception.getMessage(), equalTo("named can't be null"));
+    }
+
+    @Test
+    public void shouldNotAllowNullValueTransformerSupplierOnFlatTransformValues() {
+        final NullPointerException exception = assertThrows(
+            NullPointerException.class,
+            () -> testStream.flatTransformValues((ValueTransformerSupplier<String, Iterable<String>>) null));
+        assertThat(exception.getMessage(), equalTo("valueTransformerSupplier can't be null"));
+    }
+
+    @Test
+    public void shouldNotAllowNullValueTransformerWithKeySupplierOnFlatTransformValues() {
+        final NullPointerException exception = assertThrows(
+            NullPointerException.class,
+            () -> testStream.flatTransformValues((ValueTransformerWithKeySupplier<String, String, Iterable<String>>) null));
+        assertThat(exception.getMessage(), equalTo("valueTransformerSupplier can't be null"));
+    }
+
+    @Test
+    public void shouldNotAllowNullValueTransformerSupplierOnFlatTransformValuesWithStores() {
+        final NullPointerException exception = assertThrows(
+            NullPointerException.class,
+            () -> testStream.flatTransformValues(
+                (ValueTransformerSupplier<String, Iterable<String>>) null,
+                "stateStore"));
+        assertThat(exception.getMessage(), equalTo("valueTransformerSupplier can't be null"));
+    }
+
+    @Test
+    public void shouldNotAllowNullValueTransformerWithKeySupplierOnFlatTransformValuesWithStores() {
+        final NullPointerException exception = assertThrows(
+            NullPointerException.class,
+            () -> testStream.flatTransformValues(
+                (ValueTransformerWithKeySupplier<String, String, Iterable<String>>) null,
+                "stateStore"));
+        assertThat(exception.getMessage(), equalTo("valueTransformerSupplier can't be null"));
+    }
+
+    @Test
+    public void shouldNotAllowNullValueTransformerSupplierOnFlatTransformValuesWithNamed() {
+        final NullPointerException exception = assertThrows(
+            NullPointerException.class,
+            () -> testStream.flatTransformValues(
+                (ValueTransformerSupplier<String, Iterable<String>>) null,
+                Named.as("flatValueTransformer")));
+        assertThat(exception.getMessage(), equalTo("valueTransformerSupplier can't be null"));
+    }
+
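+    // The checks asserted throughout these tests are presumably plain up-front argument
+    // validations inside KStreamImpl (a sketch of the assumed pattern, not a quote of the
+    // actual implementation):
+    //
+    //     Objects.requireNonNull(valueTransformerSupplier, "valueTransformerSupplier can't be null");
+    //     Objects.requireNonNull(named, "named can't be null");
+    //     Objects.requireNonNull(stateStoreNames, "stateStoreNames can't be a null array");
+    //     for (final String stateStoreName : stateStoreNames) {
+    //         Objects.requireNonNull(stateStoreName, "stateStoreNames can't contain `null` as store name");
+    //     }
+    //
+    // which is why every test asserts both the NullPointerException and its exact message.
+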
@Test
+    public void shouldNotAllowNullValueTransformerWithKeySupplierOnFlatTransformValuesWithNamed() {
+        final NullPointerException exception = assertThrows(
+            NullPointerException.class,
+            () -> testStream.flatTransformValues(
+                (ValueTransformerWithKeySupplier<String, String, Iterable<String>>) null,
+                Named.as("flatValueWithKeyTransformer")));
+        assertThat(exception.getMessage(), equalTo("valueTransformerSupplier can't be null"));
+    }
+
+    @Test
+    public void shouldNotAllowNullValueTransformerSupplierOnFlatTransformValuesWithNamedAndStores() {
+        final NullPointerException exception = assertThrows(
+            NullPointerException.class,
+            () -> testStream.flatTransformValues(
+                (ValueTransformerSupplier<String, Iterable<String>>) null,
+                Named.as("flatValueTransformer"),
+                "stateStore"));
+        assertThat(exception.getMessage(), equalTo("valueTransformerSupplier can't be null"));
+    }
+
+    @Test
+    public void shouldNotAllowNullValueTransformerWithKeySupplierOnFlatTransformValuesWithNamedAndStores() {
+        final NullPointerException exception = assertThrows(
+            NullPointerException.class,
+            () -> testStream.flatTransformValues(
+                (ValueTransformerWithKeySupplier<String, String, Iterable<String>>) null,
+                Named.as("flatValueWitKeyTransformer"),
+                "stateStore"));
+        assertThat(exception.getMessage(), equalTo("valueTransformerSupplier can't be null"));
+    }
+
+    @Test
+    public void shouldNotAllowNullStoreNamesOnFlatTransformValuesWithFlatValueSupplier() {
+        final NullPointerException exception = assertThrows(
+            NullPointerException.class,
+            () -> testStream.flatTransformValues(
+                flatValueTransformerSupplier,
+                (String[]) null));
+        assertThat(exception.getMessage(), equalTo("stateStoreNames can't be a null array"));
+    }
+
+    @Test
+    public void shouldNotAllowNullStoreNamesOnFlatTransformValuesWithFlatValueWithKeySupplier() {
+        final NullPointerException exception = assertThrows(
+            NullPointerException.class,
+            () -> testStream.flatTransformValues(
+                flatValueTransformerWithKeySupplier,
+                (String[]) null));
+        assertThat(exception.getMessage(), equalTo("stateStoreNames can't be a null array"));
+    }
+
+    @Test
+    public void shouldNotAllowNullStoreNameOnFlatTransformValuesWithFlatValueSupplier() {
+        final NullPointerException exception = assertThrows(
+            NullPointerException.class,
+            () -> testStream.flatTransformValues(
+                flatValueTransformerSupplier,
+                (String) null));
+        assertThat(exception.getMessage(), equalTo("stateStoreNames can't contain `null` as store name"));
+    }
+
+    @Test
+    public void shouldNotAllowNullStoreNameOnFlatTransformValuesWithFlatValueWithKeySupplier() {
+        final NullPointerException exception = assertThrows(
+            NullPointerException.class,
+            () -> testStream.flatTransformValues(
+                flatValueTransformerWithKeySupplier,
+                (String) null));
+        assertThat(exception.getMessage(), equalTo("stateStoreNames can't contain `null` as store name"));
+    }
+
+    @Test
+    public void shouldNotAllowNullStoreNamesOnFlatTransformValuesWithFlatValueSupplierAndNamed() {
+        final NullPointerException exception = assertThrows(
+            NullPointerException.class,
+            () -> testStream.flatTransformValues(
+                flatValueTransformerSupplier,
+                Named.as("flatValueTransformer"),
+                (String[]) null));
+        assertThat(exception.getMessage(), equalTo("stateStoreNames can't be a null array"));
+    }
+
+    @Test
+    public void shouldNotAllowNullStoreNamesOnFlatTransformValuesWithFlatValueWithKeySupplierAndNamed() {
+        final NullPointerException exception = assertThrows(
+            NullPointerException.class,
+            () -> testStream.flatTransformValues(
+                flatValueTransformerWithKeySupplier,
+                Named.as("flatValueWitKeyTransformer"),
+                (String[]) null));
assertThat(exception.getMessage(), equalTo("stateStoreNames can't be a null array")); + } + + @Test + public void shouldNotAllowNullStoreNameOnFlatTransformValuesWithFlatValueSupplierAndNamed() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.flatTransformValues( + flatValueTransformerSupplier, + Named.as("flatValueTransformer"), + (String) null)); + assertThat(exception.getMessage(), equalTo("stateStoreNames can't contain `null` as store name")); + } + + @Test + public void shouldNotAllowNullStoreNameOnFlatTransformValuesWithFlatValueWithKeySupplierAndNamed() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.flatTransformValues( + flatValueTransformerWithKeySupplier, + Named.as("flatValueWitKeyTransformer"), + (String) null)); + assertThat(exception.getMessage(), equalTo("stateStoreNames can't contain `null` as store name")); + } + + @Test + public void shouldNotAllowNullNamedOnFlatTransformValuesWithFlatValueSupplier() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.flatTransformValues( + flatValueTransformerSupplier, + (Named) null)); + assertThat(exception.getMessage(), equalTo("named can't be null")); + } + + @Test + public void shouldNotAllowNullNamedOnFlatTransformValuesWithFlatValueWithKeySupplier() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.flatTransformValues( + flatValueTransformerWithKeySupplier, + (Named) null)); + assertThat(exception.getMessage(), equalTo("named can't be null")); + } + + @Test + public void shouldNotAllowNullNamedOnFlatTransformValuesWithFlatValueSupplierAndStores() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.flatTransformValues( + flatValueTransformerSupplier, + (Named) null, + "storeName")); + assertThat(exception.getMessage(), equalTo("named can't be null")); + } + + @Test + public void shouldNotAllowNullNamedOnFlatTransformValuesWithFlatValueWithKeySupplierAndStore() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.flatTransformValues( + flatValueTransformerWithKeySupplier, + (Named) null, + "storeName")); + assertThat(exception.getMessage(), equalTo("named can't be null")); + } + + @Test + public void shouldNotAllowNullProcessSupplierOnProcess() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.process((ProcessorSupplier) null)); + assertThat(exception.getMessage(), equalTo("processorSupplier can't be null")); + } + + @Test + public void shouldNotAllowNullProcessSupplierOnProcessWithStores() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.process((ProcessorSupplier) null, + "storeName")); + assertThat(exception.getMessage(), equalTo("processorSupplier can't be null")); + } + + @Test + public void shouldNotAllowNullProcessSupplierOnProcessWithNamed() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.process((ProcessorSupplier) null, + Named.as("processor"))); + assertThat(exception.getMessage(), equalTo("processorSupplier can't be null")); + } + + @Test + public void shouldNotAllowNullProcessSupplierOnProcessWithNamedAndStores() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> 
testStream.process((ProcessorSupplier) null, + Named.as("processor"), "stateStore")); + assertThat(exception.getMessage(), equalTo("processorSupplier can't be null")); + } + + @Test + public void shouldNotAllowNullStoreNamesOnProcess() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.process(processorSupplier, (String[]) null)); + assertThat(exception.getMessage(), equalTo("stateStoreNames can't be a null array")); + } + + @Test + public void shouldNotAllowNullStoreNameOnProcess() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.process(processorSupplier, (String) null)); + assertThat(exception.getMessage(), equalTo("stateStoreNames can't be null")); + } + + @Test + public void shouldNotAllowNullStoreNamesOnProcessWithNamed() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.process(processorSupplier, Named.as("processor"), (String[]) null)); + assertThat(exception.getMessage(), equalTo("stateStoreNames can't be a null array")); + } + + @Test + public void shouldNotAllowNullStoreNameOnProcessWithNamed() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.process(processorSupplier, Named.as("processor"), (String) null)); + assertThat(exception.getMessage(), equalTo("stateStoreNames can't be null")); + } + + @Test + public void shouldNotAllowNullNamedOnProcess() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.process(processorSupplier, (Named) null)); + assertThat(exception.getMessage(), equalTo("named can't be null")); + } + + @Test + public void shouldNotAllowNullNamedOnProcessWithStores() { + final NullPointerException exception = assertThrows( + NullPointerException.class, + () -> testStream.process(processorSupplier, (Named) null, "storeName")); + assertThat(exception.getMessage(), equalTo("named can't be null")); + } + + + @Test + public void shouldNotMaterializedKTableFromKStream() { + final Consumed consumed = Consumed.with(Serdes.String(), Serdes.String()); + + final StreamsBuilder builder = new StreamsBuilder(); + + final String input = "input"; + final String output = "output"; + + builder.stream(input, consumed).toTable().toStream().to(output); + + final String topologyDescription = builder.build().describe().toString(); + + assertThat( + topologyDescription, + equalTo("Topologies:\n" + + " Sub-topology: 0\n" + + " Source: KSTREAM-SOURCE-0000000000 (topics: [input])\n" + + " --> KSTREAM-TOTABLE-0000000001\n" + + " Processor: KSTREAM-TOTABLE-0000000001 (stores: [])\n" + + " --> KTABLE-TOSTREAM-0000000003\n" + + " <-- KSTREAM-SOURCE-0000000000\n" + + " Processor: KTABLE-TOSTREAM-0000000003 (stores: [])\n" + + " --> KSTREAM-SINK-0000000004\n" + + " <-- KSTREAM-TOTABLE-0000000001\n" + + " Sink: KSTREAM-SINK-0000000004 (topic: output)\n" + + " <-- KTABLE-TOSTREAM-0000000003\n\n") + ); + + try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) { + final TestInputTopic inputTopic = + driver.createInputTopic(input, Serdes.String().serializer(), Serdes.String().serializer()); + final TestOutputTopic outputTopic = + driver.createOutputTopic(output, Serdes.String().deserializer(), Serdes.String().deserializer()); + + inputTopic.pipeInput("A", "01", 5L); + inputTopic.pipeInput("B", "02", 100L); + inputTopic.pipeInput("C", "03", 0L); + inputTopic.pipeInput("D", "04", 0L); + 
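+            // The remaining updates for key "A" arrive out of order (timestamp 10, then 8).
+            // The toTable() above is not materialized (the topology shows `stores: []`), so
+            // every update is simply forwarded downstream, which is why all six records,
+            // including the out-of-order one, appear in the expected output below.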
inputTopic.pipeInput("A", "05", 10L); + inputTopic.pipeInput("A", "06", 8L); + + final List> outputExpectRecords = new ArrayList<>(); + outputExpectRecords.add(new TestRecord<>("A", "01", Instant.ofEpochMilli(5L))); + outputExpectRecords.add(new TestRecord<>("B", "02", Instant.ofEpochMilli(100L))); + outputExpectRecords.add(new TestRecord<>("C", "03", Instant.ofEpochMilli(0L))); + outputExpectRecords.add(new TestRecord<>("D", "04", Instant.ofEpochMilli(0L))); + outputExpectRecords.add(new TestRecord<>("A", "05", Instant.ofEpochMilli(10L))); + outputExpectRecords.add(new TestRecord<>("A", "06", Instant.ofEpochMilli(8L))); + + assertEquals(outputTopic.readRecordsToList(), outputExpectRecords); + } + } + + @Test + public void shouldMaterializeKTableFromKStream() { + final Consumed consumed = Consumed.with(Serdes.String(), Serdes.String()); + + final StreamsBuilder builder = new StreamsBuilder(); + final String storeName = "store"; + + final String input = "input"; + builder.stream(input, consumed) + .toTable(Materialized.as(Stores.inMemoryKeyValueStore(storeName))); + + final Topology topology = builder.build(); + + final String topologyDescription = topology.describe().toString(); + assertThat( + topologyDescription, + equalTo("Topologies:\n" + + " Sub-topology: 0\n" + + " Source: KSTREAM-SOURCE-0000000000 (topics: [input])\n" + + " --> KSTREAM-TOTABLE-0000000001\n" + + " Processor: KSTREAM-TOTABLE-0000000001 (stores: [store])\n" + + " --> none\n" + + " <-- KSTREAM-SOURCE-0000000000\n\n") + ); + + try (final TopologyTestDriver driver = new TopologyTestDriver(topology, props)) { + final TestInputTopic inputTopic = + driver.createInputTopic(input, Serdes.String().serializer(), Serdes.String().serializer()); + final KeyValueStore store = driver.getKeyValueStore(storeName); + + inputTopic.pipeInput("A", "01"); + inputTopic.pipeInput("B", "02"); + inputTopic.pipeInput("A", "03"); + final Map expectedStore = mkMap(mkEntry("A", "03"), mkEntry("B", "02")); + + assertThat(asMap(store), is(expectedStore)); + } + } + + @Test + public void shouldSupportKeyChangeKTableFromKStream() { + final Consumed consumed = Consumed.with(Serdes.String(), Serdes.String()); + + final StreamsBuilder builder = new StreamsBuilder(); + + final String input = "input"; + final String output = "output"; + + builder.stream(input, consumed) + .map((key, value) -> new KeyValue<>(key.charAt(0) - 'A', value)) + .toTable(Materialized.with(Serdes.Integer(), null)) + .toStream().to(output); + + final Topology topology = builder.build(); + + final String topologyDescription = topology.describe().toString(); + assertThat( + topologyDescription, + equalTo("Topologies:\n" + + " Sub-topology: 0\n" + + " Source: KSTREAM-SOURCE-0000000000 (topics: [input])\n" + + " --> KSTREAM-MAP-0000000001\n" + + " Processor: KSTREAM-MAP-0000000001 (stores: [])\n" + + " --> KSTREAM-FILTER-0000000005\n" + + " <-- KSTREAM-SOURCE-0000000000\n" + + " Processor: KSTREAM-FILTER-0000000005 (stores: [])\n" + + " --> KSTREAM-SINK-0000000004\n" + + " <-- KSTREAM-MAP-0000000001\n" + + " Sink: KSTREAM-SINK-0000000004 (topic: KSTREAM-TOTABLE-0000000002-repartition)\n" + + " <-- KSTREAM-FILTER-0000000005\n" + + "\n" + + " Sub-topology: 1\n" + + " Source: KSTREAM-SOURCE-0000000006 (topics: [KSTREAM-TOTABLE-0000000002-repartition])\n" + + " --> KSTREAM-TOTABLE-0000000002\n" + + " Processor: KSTREAM-TOTABLE-0000000002 (stores: [])\n" + + " --> KTABLE-TOSTREAM-0000000007\n" + + " <-- KSTREAM-SOURCE-0000000006\n" + + " Processor: KTABLE-TOSTREAM-0000000007 (stores: [])\n" + 
+ " --> KSTREAM-SINK-0000000008\n" + + " <-- KSTREAM-TOTABLE-0000000002\n" + + " Sink: KSTREAM-SINK-0000000008 (topic: output)\n" + + " <-- KTABLE-TOSTREAM-0000000007\n\n") + ); + + try (final TopologyTestDriver driver = new TopologyTestDriver(topology, props)) { + final TestInputTopic inputTopic = + driver.createInputTopic(input, Serdes.String().serializer(), Serdes.String().serializer()); + final TestOutputTopic outputTopic = + driver.createOutputTopic(output, Serdes.Integer().deserializer(), Serdes.String().deserializer()); + + inputTopic.pipeInput("A", "01", 5L); + inputTopic.pipeInput("B", "02", 100L); + inputTopic.pipeInput("C", "03", 0L); + inputTopic.pipeInput("D", "04", 0L); + inputTopic.pipeInput("A", "05", 10L); + inputTopic.pipeInput("A", "06", 8L); + + final List> outputExpectRecords = new ArrayList<>(); + outputExpectRecords.add(new TestRecord<>(0, "01", Instant.ofEpochMilli(5L))); + outputExpectRecords.add(new TestRecord<>(1, "02", Instant.ofEpochMilli(100L))); + outputExpectRecords.add(new TestRecord<>(2, "03", Instant.ofEpochMilli(0L))); + outputExpectRecords.add(new TestRecord<>(3, "04", Instant.ofEpochMilli(0L))); + outputExpectRecords.add(new TestRecord<>(0, "05", Instant.ofEpochMilli(10L))); + outputExpectRecords.add(new TestRecord<>(0, "06", Instant.ofEpochMilli(8L))); + + assertEquals(outputTopic.readRecordsToList(), outputExpectRecords); + } + } + + @Test + public void shouldSupportForeignKeyTableTableJoinWithKTableFromKStream() { + final Consumed consumed = Consumed.with(Serdes.String(), Serdes.String()); + final StreamsBuilder builder = new StreamsBuilder(); + + final String input1 = "input1"; + final String input2 = "input2"; + final String output = "output"; + + final KTable leftTable = builder.stream(input1, consumed).toTable(); + final KTable rightTable = builder.stream(input2, consumed).toTable(); + + final Function extractor = value -> value.split("\\|")[1]; + final ValueJoiner joiner = (value1, value2) -> "(" + value1 + "," + value2 + ")"; + + leftTable.join(rightTable, extractor, joiner).toStream().to(output); + + final Topology topology = builder.build(props); + + final String topologyDescription = topology.describe().toString(); + + assertThat( + topologyDescription, + equalTo("Topologies:\n" + + " Sub-topology: 0\n" + + " Source: KTABLE-SOURCE-0000000016 (topics: [KTABLE-FK-JOIN-SUBSCRIPTION-RESPONSE-0000000014-topic])\n" + + " --> KTABLE-FK-JOIN-SUBSCRIPTION-RESPONSE-RESOLVER-PROCESSOR-0000000017\n" + + " Source: KSTREAM-SOURCE-0000000000 (topics: [input1])\n" + + " --> KSTREAM-TOTABLE-0000000001\n" + + " Processor: KTABLE-FK-JOIN-SUBSCRIPTION-RESPONSE-RESOLVER-PROCESSOR-0000000017 (stores: [KSTREAM-TOTABLE-STATE-STORE-0000000002])\n" + + " --> KTABLE-FK-JOIN-OUTPUT-0000000018\n" + + " <-- KTABLE-SOURCE-0000000016\n" + + " Processor: KSTREAM-TOTABLE-0000000001 (stores: [KSTREAM-TOTABLE-STATE-STORE-0000000002])\n" + + " --> KTABLE-FK-JOIN-SUBSCRIPTION-REGISTRATION-0000000007\n" + + " <-- KSTREAM-SOURCE-0000000000\n" + + " Processor: KTABLE-FK-JOIN-OUTPUT-0000000018 (stores: [])\n" + + " --> KTABLE-TOSTREAM-0000000020\n" + + " <-- KTABLE-FK-JOIN-SUBSCRIPTION-RESPONSE-RESOLVER-PROCESSOR-0000000017\n" + + " Processor: KTABLE-FK-JOIN-SUBSCRIPTION-REGISTRATION-0000000007 (stores: [])\n" + + " --> KTABLE-SINK-0000000008\n" + + " <-- KSTREAM-TOTABLE-0000000001\n" + + " Processor: KTABLE-TOSTREAM-0000000020 (stores: [])\n" + + " --> KSTREAM-SINK-0000000021\n" + + " <-- KTABLE-FK-JOIN-OUTPUT-0000000018\n" + + " Sink: KSTREAM-SINK-0000000021 (topic: output)\n" + 
+ " <-- KTABLE-TOSTREAM-0000000020\n" + + " Sink: KTABLE-SINK-0000000008 (topic: KTABLE-FK-JOIN-SUBSCRIPTION-REGISTRATION-0000000006-topic)\n" + + " <-- KTABLE-FK-JOIN-SUBSCRIPTION-REGISTRATION-0000000007\n" + + "\n" + + " Sub-topology: 1\n" + + " Source: KSTREAM-SOURCE-0000000003 (topics: [input2])\n" + + " --> KSTREAM-TOTABLE-0000000004\n" + + " Source: KTABLE-SOURCE-0000000009 (topics: [KTABLE-FK-JOIN-SUBSCRIPTION-REGISTRATION-0000000006-topic])\n" + + " --> KTABLE-FK-JOIN-SUBSCRIPTION-PROCESSOR-0000000011\n" + + " Processor: KSTREAM-TOTABLE-0000000004 (stores: [KSTREAM-TOTABLE-STATE-STORE-0000000005])\n" + + " --> KTABLE-FK-JOIN-SUBSCRIPTION-PROCESSOR-0000000013\n" + + " <-- KSTREAM-SOURCE-0000000003\n" + + " Processor: KTABLE-FK-JOIN-SUBSCRIPTION-PROCESSOR-0000000011 (stores: [KTABLE-FK-JOIN-SUBSCRIPTION-STATE-STORE-0000000010])\n" + + " --> KTABLE-FK-JOIN-SUBSCRIPTION-PROCESSOR-0000000012\n" + + " <-- KTABLE-SOURCE-0000000009\n" + + " Processor: KTABLE-FK-JOIN-SUBSCRIPTION-PROCESSOR-0000000012 (stores: [KSTREAM-TOTABLE-STATE-STORE-0000000005])\n" + + " --> KTABLE-SINK-0000000015\n" + + " <-- KTABLE-FK-JOIN-SUBSCRIPTION-PROCESSOR-0000000011\n" + + " Processor: KTABLE-FK-JOIN-SUBSCRIPTION-PROCESSOR-0000000013 (stores: [KTABLE-FK-JOIN-SUBSCRIPTION-STATE-STORE-0000000010])\n" + + " --> KTABLE-SINK-0000000015\n" + + " <-- KSTREAM-TOTABLE-0000000004\n" + + " Sink: KTABLE-SINK-0000000015 (topic: KTABLE-FK-JOIN-SUBSCRIPTION-RESPONSE-0000000014-topic)\n" + + " <-- KTABLE-FK-JOIN-SUBSCRIPTION-PROCESSOR-0000000012, KTABLE-FK-JOIN-SUBSCRIPTION-PROCESSOR-0000000013\n\n") + ); + + + try (final TopologyTestDriver driver = new TopologyTestDriver(topology, props)) { + final TestInputTopic left = driver.createInputTopic(input1, new StringSerializer(), new StringSerializer()); + final TestInputTopic right = driver.createInputTopic(input2, new StringSerializer(), new StringSerializer()); + final TestOutputTopic outputTopic = driver.createOutputTopic(output, new StringDeserializer(), new StringDeserializer()); + + // Pre-populate the RHS records. 
This test is all about what happens when we add/remove LHS records + right.pipeInput("rhs1", "rhsValue1"); + right.pipeInput("rhs2", "rhsValue2"); + right.pipeInput("rhs3", "rhsValue3"); // this unreferenced FK won't show up in any results + + assertThat(outputTopic.readKeyValuesToMap(), is(emptyMap())); + + left.pipeInput("lhs1", "lhsValue1|rhs1"); + left.pipeInput("lhs2", "lhsValue2|rhs2"); + + final Map expected = mkMap( + mkEntry("lhs1", "(lhsValue1|rhs1,rhsValue1)"), + mkEntry("lhs2", "(lhsValue2|rhs2,rhsValue2)") + ); + assertThat(outputTopic.readKeyValuesToMap(), is(expected)); + + // Add another reference to an existing FK + left.pipeInput("lhs3", "lhsValue3|rhs1"); + + assertThat( + outputTopic.readKeyValuesToMap(), + is(mkMap( + mkEntry("lhs3", "(lhsValue3|rhs1,rhsValue1)") + )) + ); + + left.pipeInput("lhs1", (String) null); + assertThat( + outputTopic.readKeyValuesToMap(), + is(mkMap( + mkEntry("lhs1", null) + )) + ); + } + } + + @Test + public void shouldSupportTableTableJoinWithKStreamToKTable() { + final Consumed consumed = Consumed.with(Serdes.String(), Serdes.String()); + final StreamsBuilder builder = new StreamsBuilder(); + + final String leftTopic = "left"; + final String rightTopic = "right"; + final String outputTopic = "output"; + + final KTable table1 = builder.stream(leftTopic, consumed).toTable(); + final KTable table2 = builder.stream(rightTopic, consumed).toTable(); + + table1.join(table2, MockValueJoiner.TOSTRING_JOINER).toStream().to(outputTopic); + + final Topology topology = builder.build(props); + + final String topologyDescription = topology.describe().toString(); + assertThat( + topologyDescription, + equalTo("Topologies:\n" + + " Sub-topology: 0\n" + + " Source: KSTREAM-SOURCE-0000000000 (topics: [left])\n" + + " --> KSTREAM-TOTABLE-0000000001\n" + + " Source: KSTREAM-SOURCE-0000000003 (topics: [right])\n" + + " --> KSTREAM-TOTABLE-0000000004\n" + + " Processor: KSTREAM-TOTABLE-0000000001 (stores: [KSTREAM-TOTABLE-STATE-STORE-0000000002])\n" + + " --> KTABLE-JOINTHIS-0000000007\n" + + " <-- KSTREAM-SOURCE-0000000000\n" + + " Processor: KSTREAM-TOTABLE-0000000004 (stores: [KSTREAM-TOTABLE-STATE-STORE-0000000005])\n" + + " --> KTABLE-JOINOTHER-0000000008\n" + + " <-- KSTREAM-SOURCE-0000000003\n" + + " Processor: KTABLE-JOINOTHER-0000000008 (stores: [KSTREAM-TOTABLE-STATE-STORE-0000000002])\n" + + " --> KTABLE-MERGE-0000000006\n" + + " <-- KSTREAM-TOTABLE-0000000004\n" + + " Processor: KTABLE-JOINTHIS-0000000007 (stores: [KSTREAM-TOTABLE-STATE-STORE-0000000005])\n" + + " --> KTABLE-MERGE-0000000006\n" + + " <-- KSTREAM-TOTABLE-0000000001\n" + + " Processor: KTABLE-MERGE-0000000006 (stores: [])\n" + + " --> KTABLE-TOSTREAM-0000000009\n" + + " <-- KTABLE-JOINTHIS-0000000007, KTABLE-JOINOTHER-0000000008\n" + + " Processor: KTABLE-TOSTREAM-0000000009 (stores: [])\n" + + " --> KSTREAM-SINK-0000000010\n" + + " <-- KTABLE-MERGE-0000000006\n" + + " Sink: KSTREAM-SINK-0000000010 (topic: output)\n" + + " <-- KTABLE-TOSTREAM-0000000009\n\n")); + + try (final TopologyTestDriver driver = new TopologyTestDriver(topology, props)) { + final TestInputTopic left = driver.createInputTopic(leftTopic, new StringSerializer(), new StringSerializer()); + final TestInputTopic right = driver.createInputTopic(rightTopic, new StringSerializer(), new StringSerializer()); + final TestOutputTopic output = driver.createOutputTopic(outputTopic, new StringDeserializer(), new StringDeserializer()); + + right.pipeInput("lhs1", "rhsValue1"); + right.pipeInput("rhs2", "rhsValue2"); + 
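+            // Unlike the foreign-key join above, this is a primary-key table-table join, so
+            // records only pair up when both sides have seen the same key: "rhs2" never
+            // matches a left-side key and never shows up in the output, while the next
+            // right-side record (keyed "lhs3") only joins once the matching left record arrives.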
right.pipeInput("lhs3", "rhsValue3"); + + assertThat(output.readKeyValuesToMap(), is(emptyMap())); + + left.pipeInput("lhs1", "lhsValue1"); + left.pipeInput("lhs2", "lhsValue2"); + + final Map expected = mkMap( + mkEntry("lhs1", "lhsValue1+rhsValue1") + ); + + assertThat( + output.readKeyValuesToMap(), + is(expected) + ); + + left.pipeInput("lhs3", "lhsValue3"); + + assertThat( + output.readKeyValuesToMap(), + is(mkMap( + mkEntry("lhs3", "lhsValue3+rhsValue3") + )) + ); + + left.pipeInput("lhs1", "lhsValue4"); + assertThat( + output.readKeyValuesToMap(), + is(mkMap( + mkEntry("lhs1", "lhsValue4+rhsValue1") + )) + ); + } + } + + @Test + public void shouldSupportStreamTableJoinWithKStreamToKTable() { + final StreamsBuilder builder = new StreamsBuilder(); + final Consumed consumed = Consumed.with(Serdes.String(), Serdes.String()); + + final String streamTopic = "streamTopic"; + final String tableTopic = "tableTopic"; + final String outputTopic = "output"; + + final KStream stream = builder.stream(streamTopic, consumed); + final KTable table = builder.stream(tableTopic, consumed).toTable(); + + stream.join(table, MockValueJoiner.TOSTRING_JOINER).to(outputTopic); + + final Topology topology = builder.build(props); + + final String topologyDescription = topology.describe().toString(); + assertThat( + topologyDescription, + equalTo("Topologies:\n" + + " Sub-topology: 0\n" + + " Source: KSTREAM-SOURCE-0000000000 (topics: [streamTopic])\n" + + " --> KSTREAM-JOIN-0000000004\n" + + " Processor: KSTREAM-JOIN-0000000004 (stores: [KSTREAM-TOTABLE-STATE-STORE-0000000003])\n" + + " --> KSTREAM-SINK-0000000005\n" + + " <-- KSTREAM-SOURCE-0000000000\n" + + " Source: KSTREAM-SOURCE-0000000001 (topics: [tableTopic])\n" + + " --> KSTREAM-TOTABLE-0000000002\n" + + " Sink: KSTREAM-SINK-0000000005 (topic: output)\n" + + " <-- KSTREAM-JOIN-0000000004\n" + + " Processor: KSTREAM-TOTABLE-0000000002 (stores: [KSTREAM-TOTABLE-STATE-STORE-0000000003])\n" + + " --> none\n" + + " <-- KSTREAM-SOURCE-0000000001\n\n")); + + try (final TopologyTestDriver driver = new TopologyTestDriver(topology, props)) { + final TestInputTopic left = driver.createInputTopic(streamTopic, new StringSerializer(), new StringSerializer()); + final TestInputTopic right = driver.createInputTopic(tableTopic, new StringSerializer(), new StringSerializer()); + final TestOutputTopic output = driver.createOutputTopic(outputTopic, new StringDeserializer(), new StringDeserializer()); + + right.pipeInput("lhs1", "rhsValue1"); + right.pipeInput("rhs2", "rhsValue2"); + right.pipeInput("lhs3", "rhsValue3"); + + assertThat(output.readKeyValuesToMap(), is(emptyMap())); + + left.pipeInput("lhs1", "lhsValue1"); + left.pipeInput("lhs2", "lhsValue2"); + + final Map expected = mkMap( + mkEntry("lhs1", "lhsValue1+rhsValue1") + ); + + assertThat( + output.readKeyValuesToMap(), + is(expected) + ); + + left.pipeInput("lhs3", "lhsValue3"); + + assertThat( + output.readKeyValuesToMap(), + is(mkMap( + mkEntry("lhs3", "lhsValue3+rhsValue3") + )) + ); + + left.pipeInput("lhs1", "lhsValue4"); + assertThat( + output.readKeyValuesToMap(), + is(mkMap( + mkEntry("lhs1", "lhsValue4+rhsValue1") + )) + ); + } + } + + @Test + public void shouldSupportGroupByCountWithKStreamToKTable() { + final Consumed consumed = Consumed.with(Serdes.String(), Serdes.String()); + final StreamsBuilder builder = new StreamsBuilder(); + + final String input = "input"; + final String output = "output"; + + builder + .stream(input, consumed) + .toTable() + .groupBy(MockMapper.selectValueKeyValueMapper(), 
Grouped.with(Serdes.String(), Serdes.String())) + .count() + .toStream() + .to(output); + + final Topology topology = builder.build(props); + + final String topologyDescription = topology.describe().toString(); + assertThat( + topologyDescription, + equalTo("Topologies:\n" + + " Sub-topology: 0\n" + + " Source: KSTREAM-SOURCE-0000000000 (topics: [input])\n" + + " --> KSTREAM-TOTABLE-0000000001\n" + + " Processor: KSTREAM-TOTABLE-0000000001 (stores: [KSTREAM-TOTABLE-STATE-STORE-0000000002])\n" + + " --> KTABLE-SELECT-0000000003\n" + + " <-- KSTREAM-SOURCE-0000000000\n" + + " Processor: KTABLE-SELECT-0000000003 (stores: [])\n" + + " --> KSTREAM-SINK-0000000005\n" + + " <-- KSTREAM-TOTABLE-0000000001\n" + + " Sink: KSTREAM-SINK-0000000005 (topic: KTABLE-AGGREGATE-STATE-STORE-0000000004-repartition)\n" + + " <-- KTABLE-SELECT-0000000003\n" + + "\n" + + " Sub-topology: 1\n" + + " Source: KSTREAM-SOURCE-0000000006 (topics: [KTABLE-AGGREGATE-STATE-STORE-0000000004-repartition])\n" + + " --> KTABLE-AGGREGATE-0000000007\n" + + " Processor: KTABLE-AGGREGATE-0000000007 (stores: [KTABLE-AGGREGATE-STATE-STORE-0000000004])\n" + + " --> KTABLE-TOSTREAM-0000000008\n" + + " <-- KSTREAM-SOURCE-0000000006\n" + + " Processor: KTABLE-TOSTREAM-0000000008 (stores: [])\n" + + " --> KSTREAM-SINK-0000000009\n" + + " <-- KTABLE-AGGREGATE-0000000007\n" + + " Sink: KSTREAM-SINK-0000000009 (topic: output)\n" + + " <-- KTABLE-TOSTREAM-0000000008\n\n")); + + try ( + final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) { + final TestInputTopic inputTopic = + driver.createInputTopic(input, new StringSerializer(), new StringSerializer(), Instant.ofEpochMilli(0L), Duration.ZERO); + final TestOutputTopic outputTopic = + driver.createOutputTopic(output, Serdes.String().deserializer(), Serdes.Long().deserializer()); + + inputTopic.pipeInput("A", "green", 10L); + inputTopic.pipeInput("B", "green", 9L); + inputTopic.pipeInput("A", "blue", 12L); + inputTopic.pipeInput("C", "yellow", 15L); + inputTopic.pipeInput("D", "green", 11L); + + assertEquals( + asList( + new TestRecord<>("green", 1L, Instant.ofEpochMilli(10)), + new TestRecord<>("green", 2L, Instant.ofEpochMilli(10)), + new TestRecord<>("green", 1L, Instant.ofEpochMilli(12)), + new TestRecord<>("blue", 1L, Instant.ofEpochMilli(12)), + new TestRecord<>("yellow", 1L, Instant.ofEpochMilli(15)), + new TestRecord<>("green", 2L, Instant.ofEpochMilli(12))), + outputTopic.readRecordsToList()); + } + } + + @Test + public void shouldSupportTriggerMaterializedWithKTableFromKStream() { + final Consumed consumed = Consumed.with(Serdes.String(), Serdes.String()); + final StreamsBuilder builder = new StreamsBuilder(); + + final String input = "input"; + final String output = "output"; + final String storeName = "store"; + + builder.stream(input, consumed) + .toTable() + .mapValues( + value -> value.charAt(0) - (int) 'a', + Materialized.>as(storeName) + .withKeySerde(Serdes.String()) + .withValueSerde(Serdes.Integer())) + .toStream() + .to(output); + + final Topology topology = builder.build(props); + + final String topologyDescription = topology.describe().toString(); + assertThat( + topologyDescription, + equalTo("Topologies:\n" + + " Sub-topology: 0\n" + + " Source: KSTREAM-SOURCE-0000000000 (topics: [input])\n" + + " --> KSTREAM-TOTABLE-0000000001\n" + + " Processor: KSTREAM-TOTABLE-0000000001 (stores: [])\n" + + " --> KTABLE-MAPVALUES-0000000003\n" + + " <-- KSTREAM-SOURCE-0000000000\n" + + " Processor: KTABLE-MAPVALUES-0000000003 (stores: [store])\n" + + " 
--> KTABLE-TOSTREAM-0000000004\n" + + " <-- KSTREAM-TOTABLE-0000000001\n" + + " Processor: KTABLE-TOSTREAM-0000000004 (stores: [])\n" + + " --> KSTREAM-SINK-0000000005\n" + + " <-- KTABLE-MAPVALUES-0000000003\n" + + " Sink: KSTREAM-SINK-0000000005 (topic: output)\n" + + " <-- KTABLE-TOSTREAM-0000000004\n\n")); + + try ( + final TopologyTestDriver driver = new TopologyTestDriver(topology, props)) { + final TestInputTopic inputTopic = + driver.createInputTopic(input, new StringSerializer(), new StringSerializer(), Instant.ofEpochMilli(0L), Duration.ZERO); + final TestOutputTopic outputTopic = + driver.createOutputTopic(output, Serdes.String().deserializer(), Serdes.Integer().deserializer()); + final KeyValueStore store = driver.getKeyValueStore(storeName); + + inputTopic.pipeInput("A", "green", 10L); + inputTopic.pipeInput("B", "green", 9L); + inputTopic.pipeInput("A", "blue", 12L); + inputTopic.pipeInput("C", "yellow", 15L); + inputTopic.pipeInput("D", "green", 11L); + + final Map expectedStore = new HashMap<>(); + expectedStore.putIfAbsent("A", 1); + expectedStore.putIfAbsent("B", 6); + expectedStore.putIfAbsent("C", 24); + expectedStore.putIfAbsent("D", 6); + + assertEquals(expectedStore, asMap(store)); + + assertEquals( + asList( + new TestRecord<>("A", 6, Instant.ofEpochMilli(10)), + new TestRecord<>("B", 6, Instant.ofEpochMilli(9)), + new TestRecord<>("A", 1, Instant.ofEpochMilli(12)), + new TestRecord<>("C", 24, Instant.ofEpochMilli(15)), + new TestRecord<>("D", 6, Instant.ofEpochMilli(11))), + outputTopic.readRecordsToList()); + + } + } + + private static Map asMap(final KeyValueStore store) { + final HashMap result = new HashMap<>(); + store.all().forEachRemaining(kv -> result.put(kv.key, kv.value)); + return result; + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KStreamImplValueJoinerWithKeyTest.java b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KStreamImplValueJoinerWithKeyTest.java new file mode 100644 index 0000000..b35a7b2 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KStreamImplValueJoinerWithKeyTest.java @@ -0,0 +1,234 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+package org.apache.kafka.streams.kstream.internals;
+
+import org.apache.kafka.common.serialization.IntegerSerializer;
+import org.apache.kafka.common.serialization.Serdes;
+import org.apache.kafka.common.serialization.StringDeserializer;
+import org.apache.kafka.common.serialization.StringSerializer;
+import org.apache.kafka.streams.KeyValue;
+import org.apache.kafka.streams.StreamsBuilder;
+import org.apache.kafka.streams.TestInputTopic;
+import org.apache.kafka.streams.TestOutputTopic;
+import org.apache.kafka.streams.TopologyTestDriver;
+import org.apache.kafka.streams.kstream.Consumed;
+import org.apache.kafka.streams.kstream.GlobalKTable;
+import org.apache.kafka.streams.kstream.JoinWindows;
+import org.apache.kafka.streams.kstream.Joined;
+import org.apache.kafka.streams.kstream.KStream;
+import org.apache.kafka.streams.kstream.KTable;
+import org.apache.kafka.streams.kstream.KeyValueMapper;
+import org.apache.kafka.streams.kstream.Produced;
+import org.apache.kafka.streams.kstream.StreamJoined;
+import org.apache.kafka.streams.kstream.ValueJoinerWithKey;
+import org.apache.kafka.test.StreamsTestUtils;
+import org.junit.Before;
+import org.junit.Test;
+
+import java.util.Collections;
+import java.util.List;
+import java.util.Properties;
+
+import static java.time.Duration.ofHours;
+import static java.time.Duration.ofMillis;
+import static org.junit.Assert.assertEquals;
+
+public class KStreamImplValueJoinerWithKeyTest {
+
+    private KStream<String, Integer> leftStream;
+    private KStream<String, Integer> rightStream;
+    private KTable<String, Integer> ktable;
+    private GlobalKTable<String, Integer> globalKTable;
+    private StreamsBuilder builder;
+
+    private final Properties props = StreamsTestUtils.getStreamsConfig(Serdes.String(), Serdes.String());
+    private final String leftTopic = "left";
+    private final String rightTopic = "right";
+    private final String ktableTopic = "ktableTopic";
+    private final String globalTopic = "globalTopic";
+    private final String outputTopic = "joined-result";
+
+    private final ValueJoinerWithKey<String, Integer, Integer, String> valueJoinerWithKey =
+        (key, lv, rv) -> key + ":" + (lv + (rv == null ?
0 : rv));
+    private final JoinWindows joinWindows = JoinWindows.ofTimeDifferenceAndGrace(ofMillis(100), ofHours(24L));
+    private final StreamJoined<String, Integer, Integer> streamJoined =
+        StreamJoined.with(Serdes.String(), Serdes.Integer(), Serdes.Integer());
+    private final Joined<String, Integer, Integer> joined =
+        Joined.with(Serdes.String(), Serdes.Integer(), Serdes.Integer());
+    private final KeyValueMapper<String, Integer, String> keyValueMapper =
+        (k, v) -> k;
+
+    @Before
+    public void setup() {
+        builder = new StreamsBuilder();
+        leftStream = builder.stream(leftTopic, Consumed.with(Serdes.String(), Serdes.Integer()));
+        rightStream = builder.stream(rightTopic, Consumed.with(Serdes.String(), Serdes.Integer()));
+        ktable = builder.table(ktableTopic, Consumed.with(Serdes.String(), Serdes.Integer()));
+        globalKTable = builder.globalTable(globalTopic, Consumed.with(Serdes.String(), Serdes.Integer()));
+    }
+
+    @Test
+    public void shouldIncludeKeyInStreamSteamJoinResults() {
+        leftStream.join(
+            rightStream,
+            valueJoinerWithKey,
+            joinWindows,
+            streamJoined
+        ).to(outputTopic, Produced.with(Serdes.String(), Serdes.String()));
+        // Left KV A, 3, Right KV A, 5
+        runJoinTopology(
+            builder,
+            Collections.singletonList(KeyValue.pair("A", "A:5")),
+            false,
+            rightTopic
+        );
+    }
+
+    @Test
+    public void shouldIncludeKeyInStreamLeftJoinResults() {
+        leftStream.leftJoin(
+            rightStream,
+            valueJoinerWithKey,
+            joinWindows,
+            streamJoined
+        ).to(outputTopic, Produced.with(Serdes.String(), Serdes.String()));
+        // Left KV A, 3, Right KV A, 5
+        // TTD pipes records to left stream first, then right
+        final List<KeyValue<String, String>> expectedResults = Collections.singletonList(KeyValue.pair("A", "A:5"));
+        runJoinTopology(
+            builder,
+            expectedResults,
+            false,
+            rightTopic
+        );
+    }
+
+    @Test
+    public void shouldIncludeKeyInStreamOuterJoinResults() {
+        leftStream.outerJoin(
+            rightStream,
+            valueJoinerWithKey,
+            joinWindows,
+            streamJoined
+        ).to(outputTopic, Produced.with(Serdes.String(), Serdes.String()));
+
+        // Left KV A, 3, Right KV A, 5
+        // TTD pipes records to left stream first, then right
+        final List<KeyValue<String, String>> expectedResults = Collections.singletonList(KeyValue.pair("A", "A:5"));
+        runJoinTopology(
+            builder,
+            expectedResults,
+            false,
+            rightTopic
+        );
+    }
+
+    @Test
+    public void shouldIncludeKeyInStreamTableJoinResults() {
+        leftStream.join(
+            ktable,
+            valueJoinerWithKey,
+            joined
+        ).to(outputTopic, Produced.with(Serdes.String(), Serdes.String()));
+        // Left KV A, 3, Table KV A, 5
+        runJoinTopology(
+            builder,
+            Collections.singletonList(KeyValue.pair("A", "A:5")),
+            true,
+            ktableTopic
+        );
+    }
+
+    @Test
+    public void shouldIncludeKeyInStreamTableLeftJoinResults() {
+        leftStream.leftJoin(
+            ktable,
+            valueJoinerWithKey,
+            joined
+        ).to(outputTopic, Produced.with(Serdes.String(), Serdes.String()));
+        // Left KV A, 3, Table KV A, 5
+        runJoinTopology(
+            builder,
+            Collections.singletonList(KeyValue.pair("A", "A:5")),
+            true,
+            ktableTopic
+        );
+    }
+
+    @Test
+    public void shouldIncludeKeyInStreamGlobalTableJoinResults() {
+        leftStream.join(
+            globalKTable,
+            keyValueMapper,
+            valueJoinerWithKey
+        ).to(outputTopic, Produced.with(Serdes.String(), Serdes.String()));
+        // Left KV A, 3, GlobalTable KV A, 5
+        runJoinTopology(
+            builder,
+            Collections.singletonList(KeyValue.pair("A", "A:5")),
+            true,
+            globalTopic
+        );
+    }
+
+    @Test
+    public void shouldIncludeKeyInStreamGlobalTableLeftJoinResults() {
+        leftStream.leftJoin(
+            globalKTable,
+            keyValueMapper,
+            valueJoinerWithKey
+        ).to(outputTopic, Produced.with(Serdes.String(), Serdes.String()));
+        // Left KV A, 3, GlobalTable KV A, 5
runJoinTopology(
+            builder,
+            Collections.singletonList(KeyValue.pair("A", "A:5")),
+            true,
+            globalTopic
+        );
+    }
+
+
+    private void runJoinTopology(final StreamsBuilder builder,
+                                 final List<KeyValue<String, String>> expectedResults,
+                                 final boolean isTableJoin,
+                                 final String rightTopic) {
+        try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) {
+
+            final TestInputTopic<String, Integer> rightInputTopic =
+                driver.createInputTopic(rightTopic, new StringSerializer(), new IntegerSerializer());
+
+            final TestInputTopic<String, Integer> leftInputTopic =
+                driver.createInputTopic(leftTopic, new StringSerializer(), new IntegerSerializer());
+
+            final TestOutputTopic<String, String> joinResultTopic =
+                driver.createOutputTopic(outputTopic, new StringDeserializer(), new StringDeserializer());
+
+            // with stream table joins need to make sure records hit
+            // the table first, joins only triggered from streams side
+            if (isTableJoin) {
+                rightInputTopic.pipeInput("A", 2);
+                leftInputTopic.pipeInput("A", 3);
+            } else {
+                leftInputTopic.pipeInput("A", 3);
+                rightInputTopic.pipeInput("A", 2);
+            }
+
+            final List<KeyValue<String, String>> actualResult = joinResultTopic.readKeyValuesToList();
+            assertEquals(expectedResults, actualResult);
+        }
+    }
+}
diff --git a/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KStreamKStreamJoinTest.java b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KStreamKStreamJoinTest.java
new file mode 100644
index 0000000..08b0e93
--- /dev/null
+++ b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KStreamKStreamJoinTest.java
@@ -0,0 +1,1871 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +package org.apache.kafka.streams.kstream.internals; + +import org.apache.kafka.common.serialization.IntegerSerializer; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.serialization.StringSerializer; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.KeyValueTimestamp; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.Topology; +import org.apache.kafka.streams.TopologyTestDriver; +import org.apache.kafka.streams.TopologyWrapper; +import org.apache.kafka.streams.errors.StreamsException; +import org.apache.kafka.streams.kstream.Consumed; +import org.apache.kafka.streams.kstream.JoinWindows; +import org.apache.kafka.streams.kstream.KStream; +import org.apache.kafka.streams.kstream.StreamJoined; +import org.apache.kafka.streams.processor.internals.InternalTopicConfig; +import org.apache.kafka.streams.processor.internals.InternalTopologyBuilder; +import org.apache.kafka.streams.processor.internals.testutil.LogCaptureAppender; +import org.apache.kafka.streams.TestInputTopic; +import org.apache.kafka.streams.state.Stores; +import org.apache.kafka.streams.state.WindowBytesStoreSupplier; +import org.apache.kafka.test.MockProcessor; +import org.apache.kafka.test.MockProcessorSupplier; +import org.apache.kafka.test.MockValueJoiner; +import org.apache.kafka.test.StreamsTestUtils; +import org.junit.Test; + +import java.time.Duration; +import java.time.Instant; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.HashSet; +import java.util.Properties; +import java.util.Set; + +import static java.time.Duration.ofHours; +import static java.time.Duration.ofMillis; + +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.SUBTOPOLOGY_0; +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.CoreMatchers.hasItem; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; + + +@SuppressWarnings("deprecation") // Old PAPI. Needs to be migrated. +public class KStreamKStreamJoinTest { + private final String topic1 = "topic1"; + private final String topic2 = "topic2"; + private final Consumed consumed = Consumed.with(Serdes.Integer(), Serdes.String()); + private final Properties props = StreamsTestUtils.getStreamsConfig(Serdes.String(), Serdes.String()); + + private final JoinWindows joinWindows = JoinWindows.of(ofMillis(50)).grace(Duration.ofMillis(50)); + private final StreamJoined streamJoined = StreamJoined.with(Serdes.String(), Serdes.Integer(), Serdes.Integer()); + private final String errorMessagePrefix = "Window settings mismatch. 
WindowBytesStoreSupplier settings"; + + @Test + public void shouldLogAndMeterOnSkippedRecordsWithNullValueWithBuiltInMetricsVersionLatest() { + final StreamsBuilder builder = new StreamsBuilder(); + + final KStream left = builder.stream("left", Consumed.with(Serdes.String(), Serdes.Integer())); + final KStream right = builder.stream("right", Consumed.with(Serdes.String(), Serdes.Integer())); + + left.join( + right, + Integer::sum, + JoinWindows.of(ofMillis(100)), + StreamJoined.with(Serdes.String(), Serdes.Integer(), Serdes.Integer()) + ); + + props.setProperty(StreamsConfig.BUILT_IN_METRICS_VERSION_CONFIG, StreamsConfig.METRICS_LATEST); + + try (final LogCaptureAppender appender = LogCaptureAppender.createAndRegister(KStreamKStreamJoin.class); + final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) { + + final TestInputTopic inputTopic = + driver.createInputTopic("left", new StringSerializer(), new IntegerSerializer()); + inputTopic.pipeInput("A", null); + + assertThat( + appender.getMessages(), + hasItem("Skipping record due to null key or value. topic=[left] partition=[0] offset=[0]") + ); + } + } + + + @Test + public void shouldReuseRepartitionTopicWithGeneratedName() { + final StreamsBuilder builder = new StreamsBuilder(); + final Properties props = new Properties(); + props.put(StreamsConfig.TOPOLOGY_OPTIMIZATION_CONFIG, StreamsConfig.NO_OPTIMIZATION); + final KStream stream1 = builder.stream("topic", Consumed.with(Serdes.String(), Serdes.String())); + final KStream stream2 = builder.stream("topic2", Consumed.with(Serdes.String(), Serdes.String())); + final KStream stream3 = builder.stream("topic3", Consumed.with(Serdes.String(), Serdes.String())); + final KStream newStream = stream1.map((k, v) -> new KeyValue<>(v, k)); + newStream.join(stream2, (value1, value2) -> value1 + value2, JoinWindows.of(ofMillis(100))).to("out-one"); + newStream.join(stream3, (value1, value2) -> value1 + value2, JoinWindows.of(ofMillis(100))).to("out-to"); + assertEquals(expectedTopologyWithGeneratedRepartitionTopic, builder.build(props).describe().toString()); + } + + @Test + public void shouldCreateRepartitionTopicsWithUserProvidedName() { + final StreamsBuilder builder = new StreamsBuilder(); + final Properties props = new Properties(); + props.put(StreamsConfig.TOPOLOGY_OPTIMIZATION_CONFIG, StreamsConfig.NO_OPTIMIZATION); + final KStream stream1 = builder.stream("topic", Consumed.with(Serdes.String(), Serdes.String())); + final KStream stream2 = builder.stream("topic2", Consumed.with(Serdes.String(), Serdes.String())); + final KStream stream3 = builder.stream("topic3", Consumed.with(Serdes.String(), Serdes.String())); + final KStream newStream = stream1.map((k, v) -> new KeyValue<>(v, k)); + final StreamJoined streamJoined = StreamJoined.with(Serdes.String(), Serdes.String(), Serdes.String()); + newStream.join(stream2, (value1, value2) -> value1 + value2, JoinWindows.of(ofMillis(100)), streamJoined.withName("first-join")).to("out-one"); + newStream.join(stream3, (value1, value2) -> value1 + value2, JoinWindows.of(ofMillis(100)), streamJoined.withName("second-join")).to("out-two"); + final Topology topology = builder.build(props); + System.out.println(topology.describe().toString()); + assertEquals(expectedTopologyWithUserNamedRepartitionTopics, topology.describe().toString()); + } + + @Test + public void shouldDisableLoggingOnStreamJoined() { + + final JoinWindows joinWindows = JoinWindows.of(ofMillis(100)).grace(Duration.ofMillis(50)); + final StreamJoined streamJoined = 
StreamJoined + .with(Serdes.String(), Serdes.Integer(), Serdes.Integer()) + .withStoreName("store") + .withLoggingDisabled(); + + final StreamsBuilder builder = new StreamsBuilder(); + final KStream left = builder.stream("left", Consumed.with(Serdes.String(), Serdes.Integer())); + final KStream right = builder.stream("right", Consumed.with(Serdes.String(), Serdes.Integer())); + + left.join( + right, + Integer::sum, + joinWindows, + streamJoined + ); + + final Topology topology = builder.build(); + final InternalTopologyBuilder internalTopologyBuilder = TopologyWrapper.getInternalTopologyBuilder(topology); + + assertThat(internalTopologyBuilder.stateStores().get("store-this-join-store").loggingEnabled(), equalTo(false)); + assertThat(internalTopologyBuilder.stateStores().get("store-other-join-store").loggingEnabled(), equalTo(false)); + } + + @Test + public void shouldEnableLoggingWithCustomConfigOnStreamJoined() { + + final JoinWindows joinWindows = JoinWindows.of(ofMillis(100)).grace(Duration.ofMillis(50)); + final StreamJoined streamJoined = StreamJoined + .with(Serdes.String(), Serdes.Integer(), Serdes.Integer()) + .withStoreName("store") + .withLoggingEnabled(Collections.singletonMap("test", "property")); + + final StreamsBuilder builder = new StreamsBuilder(); + final KStream left = builder.stream("left", Consumed.with(Serdes.String(), Serdes.Integer())); + final KStream right = builder.stream("right", Consumed.with(Serdes.String(), Serdes.Integer())); + + left.join( + right, + Integer::sum, + joinWindows, + streamJoined + ); + + final Topology topology = builder.build(); + final InternalTopologyBuilder internalTopologyBuilder = TopologyWrapper.getInternalTopologyBuilder(topology); + + internalTopologyBuilder.buildSubtopology(0); + + assertThat(internalTopologyBuilder.stateStores().get("store-this-join-store").loggingEnabled(), equalTo(true)); + assertThat(internalTopologyBuilder.stateStores().get("store-other-join-store").loggingEnabled(), equalTo(true)); + assertThat(internalTopologyBuilder.topicGroups().get(SUBTOPOLOGY_0).stateChangelogTopics.size(), equalTo(2)); + for (final InternalTopicConfig config : internalTopologyBuilder.topicGroups().get(SUBTOPOLOGY_0).stateChangelogTopics.values()) { + assertThat( + config.getProperties(Collections.emptyMap(), 0).get("test"), + equalTo("property") + ); + } + } + + @Test + public void shouldThrowExceptionThisStoreSupplierRetentionDoNotMatchWindowsSizeAndGrace() { + // Case where retention of thisJoinStore doesn't match JoinWindows + final WindowBytesStoreSupplier thisStoreSupplier = buildWindowBytesStoreSupplier("in-memory-join-store", 500L, 100L, true); + final WindowBytesStoreSupplier otherStoreSupplier = buildWindowBytesStoreSupplier("in-memory-join-store-other", 150L, 100L, true); + + buildStreamsJoinThatShouldThrow( + streamJoined.withThisStoreSupplier(thisStoreSupplier).withOtherStoreSupplier(otherStoreSupplier), + joinWindows, + errorMessagePrefix + ); + } + + @Test + public void shouldThrowExceptionThisStoreSupplierWindowSizeDoesNotMatchJoinWindowsWindowSize() { + //Case where window size of thisJoinStore doesn't match JoinWindows + final WindowBytesStoreSupplier thisStoreSupplier = buildWindowBytesStoreSupplier("in-memory-join-store", 150L, 150L, true); + final WindowBytesStoreSupplier otherStoreSupplier = buildWindowBytesStoreSupplier("in-memory-join-store-other", 150L, 100L, true); + + buildStreamsJoinThatShouldThrow( + streamJoined.withThisStoreSupplier(thisStoreSupplier).withOtherStoreSupplier(otherStoreSupplier), + 
joinWindows, + errorMessagePrefix + ); + } + + @Test + public void shouldThrowExceptionWhenThisJoinStoreSetsRetainDuplicatesFalse() { + //Case where thisJoinStore retain duplicates false + final WindowBytesStoreSupplier thisStoreSupplier = buildWindowBytesStoreSupplier("in-memory-join-store", 150L, 100L, false); + final WindowBytesStoreSupplier otherStoreSupplier = buildWindowBytesStoreSupplier("in-memory-join-store-other", 150L, 100L, true); + + buildStreamsJoinThatShouldThrow( + streamJoined.withThisStoreSupplier(thisStoreSupplier).withOtherStoreSupplier(otherStoreSupplier), + joinWindows, + "The StoreSupplier must set retainDuplicates=true, found retainDuplicates=false" + ); + } + + @Test + public void shouldThrowExceptionOtherStoreSupplierRetentionDoNotMatchWindowsSizeAndGrace() { + //Case where retention size of otherJoinStore doesn't match JoinWindows + final WindowBytesStoreSupplier thisStoreSupplier = buildWindowBytesStoreSupplier("in-memory-join-store", 150L, 100L, true); + final WindowBytesStoreSupplier otherStoreSupplier = buildWindowBytesStoreSupplier("in-memory-join-store-other", 500L, 100L, true); + + buildStreamsJoinThatShouldThrow( + streamJoined.withThisStoreSupplier(thisStoreSupplier).withOtherStoreSupplier(otherStoreSupplier), + joinWindows, + errorMessagePrefix + ); + } + + @Test + public void shouldThrowExceptionOtherStoreSupplierWindowSizeDoesNotMatchJoinWindowsWindowSize() { + //Case where window size of otherJoinStore doesn't match JoinWindows + final WindowBytesStoreSupplier thisStoreSupplier = buildWindowBytesStoreSupplier("in-memory-join-store", 150L, 100L, true); + final WindowBytesStoreSupplier otherStoreSupplier = buildWindowBytesStoreSupplier("in-memory-join-store-other", 150L, 150L, true); + + buildStreamsJoinThatShouldThrow( + streamJoined.withThisStoreSupplier(thisStoreSupplier).withOtherStoreSupplier(otherStoreSupplier), + joinWindows, + errorMessagePrefix + ); + } + + @Test + public void shouldThrowExceptionWhenOtherJoinStoreSetsRetainDuplicatesFalse() { + //Case where otherJoinStore retain duplicates false + final WindowBytesStoreSupplier thisStoreSupplier = buildWindowBytesStoreSupplier("in-memory-join-store", 150L, 100L, true); + final WindowBytesStoreSupplier otherStoreSupplier = buildWindowBytesStoreSupplier("in-memory-join-store-other", 150L, 100L, false); + + buildStreamsJoinThatShouldThrow( + streamJoined.withThisStoreSupplier(thisStoreSupplier).withOtherStoreSupplier(otherStoreSupplier), + joinWindows, + "The StoreSupplier must set retainDuplicates=true, found retainDuplicates=false" + ); + } + + @Test + public void shouldBuildJoinWithCustomStoresAndCorrectWindowSettings() { + //Case where everything matches up + final StreamsBuilder builder = new StreamsBuilder(); + final KStream left = builder.stream("left", Consumed.with(Serdes.String(), Serdes.Integer())); + final KStream right = builder.stream("right", Consumed.with(Serdes.String(), Serdes.Integer())); + + left.join(right, + Integer::sum, + joinWindows, + streamJoined + ); + + builder.build(); + } + + @Test + public void shouldExceptionWhenJoinStoresDontHaveUniqueNames() { + final JoinWindows joinWindows = JoinWindows.of(ofMillis(100L)).grace(Duration.ofMillis(50L)); + final StreamJoined streamJoined = StreamJoined.with(Serdes.String(), Serdes.Integer(), Serdes.Integer()); + final WindowBytesStoreSupplier thisStoreSupplier = buildWindowBytesStoreSupplier("in-memory-join-store", 150L, 100L, true); + final WindowBytesStoreSupplier otherStoreSupplier = 
buildWindowBytesStoreSupplier("in-memory-join-store", 150L, 100L, true); + + buildStreamsJoinThatShouldThrow( + streamJoined.withThisStoreSupplier(thisStoreSupplier).withOtherStoreSupplier(otherStoreSupplier), + joinWindows, + "Both StoreSuppliers have the same name. StoreSuppliers must provide unique names" + ); + } + + @Test + public void shouldJoinWithCustomStoreSuppliers() { + final JoinWindows joinWindows = JoinWindows.of(ofMillis(100L)); + + final WindowBytesStoreSupplier thisStoreSupplier = Stores.inMemoryWindowStore( + "in-memory-join-store", + Duration.ofMillis(joinWindows.size() + joinWindows.gracePeriodMs()), + Duration.ofMillis(joinWindows.size()), + true + ); + + final WindowBytesStoreSupplier otherStoreSupplier = Stores.inMemoryWindowStore( + "in-memory-join-store-other", + Duration.ofMillis(joinWindows.size() + joinWindows.gracePeriodMs()), + Duration.ofMillis(joinWindows.size()), + true + ); + + final StreamJoined streamJoined = StreamJoined.with(Serdes.String(), Serdes.Integer(), Serdes.Integer()); + + //Case with 2 custom store suppliers + runJoin(streamJoined.withThisStoreSupplier(thisStoreSupplier).withOtherStoreSupplier(otherStoreSupplier), joinWindows); + + //Case with this stream store supplier + runJoin(streamJoined.withThisStoreSupplier(thisStoreSupplier), joinWindows); + + //Case with other stream store supplier + runJoin(streamJoined.withOtherStoreSupplier(otherStoreSupplier), joinWindows); + } + + private void runJoin(final StreamJoined streamJoined, + final JoinWindows joinWindows) { + final StreamsBuilder builder = new StreamsBuilder(); + + final KStream left = builder.stream("left", Consumed.with(Serdes.String(), Serdes.Integer())); + final KStream right = builder.stream("right", Consumed.with(Serdes.String(), Serdes.Integer())); + final MockProcessorSupplier supplier = new MockProcessorSupplier<>(); + final KStream joinedStream; + + joinedStream = left.join( + right, + Integer::sum, + joinWindows, + streamJoined + ); + + joinedStream.process(supplier); + + try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) { + final TestInputTopic inputTopicLeft = + driver.createInputTopic("left", new StringSerializer(), new IntegerSerializer(), Instant.ofEpochMilli(0L), Duration.ZERO); + final TestInputTopic inputTopicRight = + driver.createInputTopic("right", new StringSerializer(), new IntegerSerializer(), Instant.ofEpochMilli(0L), Duration.ZERO); + final MockProcessor processor = supplier.theCapturedProcessor(); + + inputTopicLeft.pipeInput("A", 1, 1L); + inputTopicLeft.pipeInput("B", 1, 2L); + + inputTopicRight.pipeInput("A", 1, 1L); + inputTopicRight.pipeInput("B", 2, 2L); + + processor.checkAndClearProcessResult( + new KeyValueTimestamp<>("A", 2, 1L), + new KeyValueTimestamp<>("B", 3, 2L) + ); + } + } + + @Test + public void testJoin() { + final StreamsBuilder builder = new StreamsBuilder(); + + final int[] expectedKeys = new int[] {0, 1, 2, 3}; + + final KStream stream1; + final KStream stream2; + final KStream joined; + final MockProcessorSupplier supplier = new MockProcessorSupplier<>(); + stream1 = builder.stream(topic1, consumed); + stream2 = builder.stream(topic2, consumed); + joined = stream1.join( + stream2, + MockValueJoiner.TOSTRING_JOINER, + JoinWindows.of(ofMillis(100L)), + StreamJoined.with(Serdes.Integer(), Serdes.String(), Serdes.String()) + ); + joined.process(supplier); + + final Collection> copartitionGroups = + TopologyWrapper.getInternalTopologyBuilder(builder.build()).copartitionGroups(); + + assertEquals(1, 
copartitionGroups.size()); + assertEquals(new HashSet<>(Arrays.asList(topic1, topic2)), copartitionGroups.iterator().next()); + + try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) { + final TestInputTopic inputTopic1 = + driver.createInputTopic(topic1, new IntegerSerializer(), new StringSerializer(), Instant.ofEpochMilli(0L), Duration.ZERO); + final TestInputTopic inputTopic2 = + driver.createInputTopic(topic2, new IntegerSerializer(), new StringSerializer(), Instant.ofEpochMilli(0L), Duration.ZERO); + final MockProcessor processor = supplier.theCapturedProcessor(); + + // push two items to the primary stream; the other window is empty + // w1 = {} + // w2 = {} + // --> w1 = { 0:A0, 1:A1 } + // w2 = {} + for (int i = 0; i < 2; i++) { + inputTopic1.pipeInput(expectedKeys[i], "A" + expectedKeys[i]); + } + processor.checkAndClearProcessResult(); + + // push two items to the other stream; this should produce two items + // w1 = { 0:A0, 1:A1 } + // w2 = {} + // --> w1 = { 0:A0, 1:A1 } + // w2 = { 0:a0, 1:a1 } + for (int i = 0; i < 2; i++) { + inputTopic2.pipeInput(expectedKeys[i], "a" + expectedKeys[i]); + } + processor.checkAndClearProcessResult( + new KeyValueTimestamp<>(0, "A0+a0", 0L), + new KeyValueTimestamp<>(1, "A1+a1", 0L) + ); + + // push all four items to the primary stream; this should produce two items + // w1 = { 0:A0, 1:A1 } + // w2 = { 0:a0, 1:a1 } + // --> w1 = { 0:A0, 1:A1, 0:B0, 1:B1, 2:B2, 3:B3 } + // w2 = { 0:a0, 1:a1 } + for (final int expectedKey : expectedKeys) { + inputTopic1.pipeInput(expectedKey, "B" + expectedKey); + } + processor.checkAndClearProcessResult( + new KeyValueTimestamp<>(0, "B0+a0", 0L), + new KeyValueTimestamp<>(1, "B1+a1", 0L) + ); + + // push all items to the other stream; this should produce six items + // w1 = { 0:A0, 1:A1, 0:B0, 1:B1, 2:B2, 3:B3 } + // w2 = { 0:a0, 1:a1 } + // --> w1 = { 0:A0, 1:A1, 0:B0, 1:B1, 2:B2, 3:B3 } + // w2 = { 0:a0, 1:a1, 0:b0, 1:b1, 2:b2, 3:b3 } + for (final int expectedKey : expectedKeys) { + inputTopic2.pipeInput(expectedKey, "b" + expectedKey); + } + processor.checkAndClearProcessResult( + new KeyValueTimestamp<>(0, "A0+b0", 0L), + new KeyValueTimestamp<>(0, "B0+b0", 0L), + new KeyValueTimestamp<>(1, "A1+b1", 0L), + new KeyValueTimestamp<>(1, "B1+b1", 0L), + new KeyValueTimestamp<>(2, "B2+b2", 0L), + new KeyValueTimestamp<>(3, "B3+b3", 0L) + ); + + // push all four items to the primary stream; this should produce six items + // w1 = { 0:A0, 1:A1, 0:B0, 1:B1, 2:B2, 3:B3 } + // w2 = { 0:a0, 1:a1, 0:b0, 1:b1, 2:b2, 3:b3 } + // --> w1 = { 0:A0, 1:A1, 0:B0, 1:B1, 2:B2, 3:B3, 0:C0, 1:C1, 2:C2, 3:C3 } + // w2 = { 0:a0, 1:a1, 0:b0, 1:b1, 2:b2, 3:b3 } + for (final int expectedKey : expectedKeys) { + inputTopic1.pipeInput(expectedKey, "C" + expectedKey); + } + processor.checkAndClearProcessResult( + new KeyValueTimestamp<>(0, "C0+a0", 0L), + new KeyValueTimestamp<>(0, "C0+b0", 0L), + new KeyValueTimestamp<>(1, "C1+a1", 0L), + new KeyValueTimestamp<>(1, "C1+b1", 0L), + new KeyValueTimestamp<>(2, "C2+b2", 0L), + new KeyValueTimestamp<>(3, "C3+b3", 0L) + ); + + // push two items to the other stream; this should produce six items + // w1 = { 0:A0, 1:A1, 0:B0, 1:B1, 2:B2, 3:B3, 0:C0, 1:C1, 2:C2, 3:C3 } + // w2 = { 0:a0, 1:a1, 0:b0, 1:b1, 2:b2, 3:b3 } + // --> w1 = { 0:A0, 1:A1, 0:B0, 1:B1, 2:B2, 3:B3, 0:C0, 1:C1, 2:C2, 3:C3 } + // w2 = { 0:a0, 1:a1, 0:b0, 1:b1, 2:b2, 3:b3, 0:c0, 1:c1 } + for (int i = 0; i < 2; i++) { + inputTopic2.pipeInput(expectedKeys[i], "c" + expectedKeys[i]); + } + 
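+            // note: every record in this test is piped without an explicit timestamp, so all of
+            // them land at ts=0 and fall inside the symmetric 100ms join window; each late cX
+            // therefore joins all three buffered left-side values (A, B, C) for its key,
+            // giving 2 keys x 3 matches = the six results asserted below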
processor.checkAndClearProcessResult( + new KeyValueTimestamp<>(0, "A0+c0", 0L), + new KeyValueTimestamp<>(0, "B0+c0", 0L), + new KeyValueTimestamp<>(0, "C0+c0", 0L), + new KeyValueTimestamp<>(1, "A1+c1", 0L), + new KeyValueTimestamp<>(1, "B1+c1", 0L), + new KeyValueTimestamp<>(1, "C1+c1", 0L) + ); + } + } + + @Test + public void testOuterJoin() { + final StreamsBuilder builder = new StreamsBuilder(); + + final int[] expectedKeys = new int[] {0, 1, 2, 3}; + + final KStream stream1; + final KStream stream2; + final KStream joined; + final MockProcessorSupplier supplier = new MockProcessorSupplier<>(); + + stream1 = builder.stream(topic1, consumed); + stream2 = builder.stream(topic2, consumed); + joined = stream1.outerJoin( + stream2, + MockValueJoiner.TOSTRING_JOINER, + JoinWindows.ofTimeDifferenceAndGrace(ofMillis(100L), ofHours(24L)), + StreamJoined.with(Serdes.Integer(), Serdes.String(), Serdes.String()) + ); + joined.process(supplier); + final Collection> copartitionGroups = + TopologyWrapper.getInternalTopologyBuilder(builder.build()).copartitionGroups(); + + assertEquals(1, copartitionGroups.size()); + assertEquals(new HashSet<>(Arrays.asList(topic1, topic2)), copartitionGroups.iterator().next()); + + try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) { + final TestInputTopic inputTopic1 = + driver.createInputTopic(topic1, new IntegerSerializer(), new StringSerializer(), Instant.ofEpochMilli(0L), Duration.ZERO); + final TestInputTopic inputTopic2 = + driver.createInputTopic(topic2, new IntegerSerializer(), new StringSerializer(), Instant.ofEpochMilli(0L), Duration.ZERO); + final MockProcessor processor = supplier.theCapturedProcessor(); + + // push two items to the primary stream; the other window is empty; this should not produce items yet + // w1 = {} + // w2 = {} + // --> w1 = { 0:A0, 1:A1 } + // w2 = {} + for (int i = 0; i < 2; i++) { + inputTopic1.pipeInput(expectedKeys[i], "A" + expectedKeys[i]); + } + processor.checkAndClearProcessResult(); + + // push two items to the other stream; this should produce two items + // w1 = { 0:A0, 1:A1 } + // w2 = {} + // --> w1 = { 0:A0, 1:A1 } + // w2 = { 0:a0, 1:a1 } + for (int i = 0; i < 2; i++) { + inputTopic2.pipeInput(expectedKeys[i], "a" + expectedKeys[i]); + } + processor.checkAndClearProcessResult( + new KeyValueTimestamp<>(0, "A0+a0", 0L), + new KeyValueTimestamp<>(1, "A1+a1", 0L) + ); + + // push all four items to the primary stream; this should produce two items + // w1 = { 0:A0, 1:A1 } + // w2 = { 0:a0, 1:a1 } + // --> w1 = { 0:A0, 1:A1, 0:B0, 1:B1, 2:B2, 3:B3 } + // w2 = { 0:a0, 1:a1 } + for (final int expectedKey : expectedKeys) { + inputTopic1.pipeInput(expectedKey, "B" + expectedKey); + } + processor.checkAndClearProcessResult( + new KeyValueTimestamp<>(0, "B0+a0", 0L), + new KeyValueTimestamp<>(1, "B1+a1", 0L) + ); + + // push all items to the other stream; this should produce six items + // w1 = { 0:A0, 1:A1, 0:B0, 1:B1, 2:B2, 3:B3 } + // w2 = { 0:a0, 1:a1 } + // --> w1 = { 0:A0, 1:A1, 0:B0, 1:B1, 2:B2, 3:B3 } + // w2 = { 0:a0, 1:a1, 0:b0, 0:b0, 1:b1, 2:b2, 3:b3 } + for (final int expectedKey : expectedKeys) { + inputTopic2.pipeInput(expectedKey, "b" + expectedKey); + } + processor.checkAndClearProcessResult( + new KeyValueTimestamp<>(0, "A0+b0", 0L), + new KeyValueTimestamp<>(0, "B0+b0", 0L), + new KeyValueTimestamp<>(1, "A1+b1", 0L), + new KeyValueTimestamp<>(1, "B1+b1", 0L), + new KeyValueTimestamp<>(2, "B2+b2", 0L), + new KeyValueTimestamp<>(3, "B3+b3", 0L) + ); + + // push all four 
items to the primary stream; this should produce six items + // w1 = { 0:A0, 1:A1, 0:B0, 1:B1, 2:B2, 3:B3 } + // w2 = { 0:a0, 1:a1, 0:b0, 0:b0, 1:b1, 2:b2, 3:b3 } + // --> w1 = { 0:A0, 1:A1, 0:B0, 1:B1, 2:B2, 3:B3, 0:C0, 1:C1, 2:C2, 3:C3 } + // w2 = { 0:a0, 1:a1, 0:b0, 0:b0, 1:b1, 2:b2, 3:b3 } + for (final int expectedKey : expectedKeys) { + inputTopic1.pipeInput(expectedKey, "C" + expectedKey); + } + processor.checkAndClearProcessResult( + new KeyValueTimestamp<>(0, "C0+a0", 0L), + new KeyValueTimestamp<>(0, "C0+b0", 0L), + new KeyValueTimestamp<>(1, "C1+a1", 0L), + new KeyValueTimestamp<>(1, "C1+b1", 0L), + new KeyValueTimestamp<>(2, "C2+b2", 0L), + new KeyValueTimestamp<>(3, "C3+b3", 0L) + ); + + // push two items to the other stream; this should produce six items + // w1 = { 0:A0, 1:A1, 0:B0, 1:B1, 2:B2, 3:B3, 0:C0, 1:C1, 2:C2, 3:C3 } + // w2 = { 0:a0, 1:a1, 0:b0, 0:b0, 1:b1, 2:b2, 3:b3 } + // --> w1 = { 0:A0, 1:A1, 0:B0, 1:B1, 2:B2, 3:B3, 0:C0, 1:C1, 2:C2, 3:C3 } + // w2 = { 0:a0, 1:a1, 0:b0, 0:b0, 1:b1, 2:b2, 3:b3, 0:c0, 1:c1 } + for (int i = 0; i < 2; i++) { + inputTopic2.pipeInput(expectedKeys[i], "c" + expectedKeys[i]); + } + processor.checkAndClearProcessResult( + new KeyValueTimestamp<>(0, "A0+c0", 0L), + new KeyValueTimestamp<>(0, "B0+c0", 0L), + new KeyValueTimestamp<>(0, "C0+c0", 0L), + new KeyValueTimestamp<>(1, "A1+c1", 0L), + new KeyValueTimestamp<>(1, "B1+c1", 0L), + new KeyValueTimestamp<>(1, "C1+c1", 0L) + ); + } + } + + @Test + public void testWindowing() { + final StreamsBuilder builder = new StreamsBuilder(); + + final int[] expectedKeys = new int[] {0, 1, 2, 3}; + + final KStream stream1; + final KStream stream2; + final KStream joined; + final MockProcessorSupplier supplier = new MockProcessorSupplier<>(); + stream1 = builder.stream(topic1, consumed); + stream2 = builder.stream(topic2, consumed); + + joined = stream1.join( + stream2, + MockValueJoiner.TOSTRING_JOINER, + JoinWindows.of(ofMillis(100L)), + StreamJoined.with(Serdes.Integer(), Serdes.String(), Serdes.String()) + ); + joined.process(supplier); + + final Collection> copartitionGroups = + TopologyWrapper.getInternalTopologyBuilder(builder.build()).copartitionGroups(); + + assertEquals(1, copartitionGroups.size()); + assertEquals(new HashSet<>(Arrays.asList(topic1, topic2)), copartitionGroups.iterator().next()); + + try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) { + final TestInputTopic inputTopic1 = + driver.createInputTopic(topic1, new IntegerSerializer(), new StringSerializer(), Instant.ofEpochMilli(0L), Duration.ZERO); + final TestInputTopic inputTopic2 = + driver.createInputTopic(topic2, new IntegerSerializer(), new StringSerializer(), Instant.ofEpochMilli(0L), Duration.ZERO); + final MockProcessor processor = supplier.theCapturedProcessor(); + long time = 0L; + + // push two items to the primary stream; the other window is empty; this should produce no items + // w1 = {} + // w2 = {} + // --> w1 = { 0:A0 (ts: 0), 1:A1 (ts: 0) } + // w2 = {} + for (int i = 0; i < 2; i++) { + inputTopic1.pipeInput(expectedKeys[i], "A" + expectedKeys[i], time); + } + processor.checkAndClearProcessResult(); + + // push two items to the other stream; this should produce two items + // w1 = { 0:A0 (ts: 0), 1:A1 (ts: 0) } + // w2 = {} + // --> w1 = { 0:A0 (ts: 0), 1:A1 (ts: 0) } + // w2 = { 0:a0 (ts: 0), 1:a1 (ts: 0) } + for (int i = 0; i < 2; i++) { + inputTopic2.pipeInput(expectedKeys[i], "a" + expectedKeys[i], time); + } + processor.checkAndClearProcessResult( + new 
KeyValueTimestamp<>(0, "A0+a0", 0L), + new KeyValueTimestamp<>(1, "A1+a1", 0L) + ); + + // push four items to the primary stream with larger and increasing timestamp; this should produce no items + // w1 = { 0:A0 (ts: 0), 1:A1 (ts: 0) } + // w2 = { 0:a0 (ts: 0), 1:a1 (ts: 0) } + // --> w1 = { 0:A0 (ts: 0), 1:A1 (ts: 0), + // 0:B0 (ts: 1000), 1:B1 (ts: 1001), 2:B2 (ts: 1002), 3:B3 (ts: 1003) } + // w2 = { 0:a0 (ts: 0), 1:a1 (ts: 0) } + time = 1000L; + for (int i = 0; i < expectedKeys.length; i++) { + inputTopic1.pipeInput(expectedKeys[i], "B" + expectedKeys[i], time + i); + } + processor.checkAndClearProcessResult(); + + // push four items to the other stream with fixed larger timestamp; this should produce four items + // w1 = { 0:A0 (ts: 0), 1:A1 (ts: 0), + // 0:B0 (ts: 1000), 1:B1 (ts: 1001), 2:B2 (ts: 1002), 3:B3 (ts: 1003) } + // w2 = { 0:a0 (ts: 0), 1:a1 (ts: 0) } + // --> w1 = { 0:A0 (ts: 0), 1:A1 (ts: 0), + // 0:B0 (ts: 1000), 1:B1 (ts: 1001), 2:B2 (ts: 1002), 3:B3 (ts: 1003) } + // w2 = { 0:a0 (ts: 0), 1:a1 (ts: 0), + // 0:b0 (ts: 1100), 1:b1 (ts: 1100), 2:b2 (ts: 1100), 3:b3 (ts: 1100) } + time += 100L; + for (final int expectedKey : expectedKeys) { + inputTopic2.pipeInput(expectedKey, "b" + expectedKey, time); + } + processor.checkAndClearProcessResult( + new KeyValueTimestamp<>(0, "B0+b0", 1100L), + new KeyValueTimestamp<>(1, "B1+b1", 1100L), + new KeyValueTimestamp<>(2, "B2+b2", 1100L), + new KeyValueTimestamp<>(3, "B3+b3", 1100L) + ); + + // push four items to the other stream with incremented timestamp; this should produce three items + // w1 = { 0:A0 (ts: 0), 1:A1 (ts: 0), + // 0:B0 (ts: 1000), 1:B1 (ts: 1001), 2:B2 (ts: 1002), 3:B3 (ts: 1003) } + // w2 = { 0:a0 (ts: 0), 1:a1 (ts: 0), + // 0:b0 (ts: 1100), 1:b1 (ts: 1100), 2:b2 (ts: 1100), 3:b3 (ts: 1100) } + // --> w1 = { 0:A0 (ts: 0), 1:A1 (ts: 0), + // 0:B0 (ts: 1000), 1:B1 (ts: 1001), 2:B2 (ts: 1002), 3:B3 (ts: 1003) } + // w2 = { 0:a0 (ts: 0), 1:a1 (ts: 0), + // 0:b0 (ts: 1100), 1:b1 (ts: 1100), 2:b2 (ts: 1100), 3:b3 (ts: 1100), + // 0:c0 (ts: 1101), 1:c1 (ts: 1101), 2:c2 (ts: 1101), 3:c3 (ts: 1101) } + time += 1L; + for (final int expectedKey : expectedKeys) { + inputTopic2.pipeInput(expectedKey, "c" + expectedKey, time); + } + processor.checkAndClearProcessResult( + new KeyValueTimestamp<>(1, "B1+c1", 1101L), + new KeyValueTimestamp<>(2, "B2+c2", 1101L), + new KeyValueTimestamp<>(3, "B3+c3", 1101L) + ); + + // push four items to the other stream with incremented timestamp; this should produce two items + // w1 = { 0:A0 (ts: 0), 1:A1 (ts: 0), + // 0:B0 (ts: 1000), 1:B1 (ts: 1001), 2:B2 (ts: 1002), 3:B3 (ts: 1003) } + // w2 = { 0:a0 (ts: 0), 1:a1 (ts: 0), + // 0:b0 (ts: 1100), 1:b1 (ts: 1100), 2:b2 (ts: 1100), 3:b3 (ts: 1100), + // 0:c0 (ts: 1101), 1:c1 (ts: 1101), 2:c2 (ts: 1101), 3:c3 (ts: 1101) } + // --> w1 = { 0:A0 (ts: 0), 1:A1 (ts: 0), + // 0:B0 (ts: 1000), 1:B1 (ts: 1001), 2:B2 (ts: 1002), 3:B3 (ts: 1003) } + // w2 = { 0:a0 (ts: 0), 1:a1 (ts: 0), + // 0:b0 (ts: 1100), 1:b1 (ts: 1100), 2:b2 (ts: 1100), 3:b3 (ts: 1100), + // 0:c0 (ts: 1101), 1:c1 (ts: 1101), 2:c2 (ts: 1101), 3:c3 (ts: 1101), + // 0:d0 (ts: 1102), 1:d1 (ts: 1102), 2:d2 (ts: 1102), 3:d3 (ts: 1102) } + time += 1L; + for (final int expectedKey : expectedKeys) { + inputTopic2.pipeInput(expectedKey, "d" + expectedKey, time); + } + processor.checkAndClearProcessResult( + new KeyValueTimestamp<>(2, "B2+d2", 1102L), + new KeyValueTimestamp<>(3, "B3+d3", 1102L) + ); + + // push four items to the other stream with incremented timestamp; this should produce 
one item + // w1 = { 0:A0 (ts: 0), 1:A1 (ts: 0), + // 0:B0 (ts: 1000), 1:B1 (ts: 1001), 2:B2 (ts: 1002), 3:B3 (ts: 1003) } + // w2 = { 0:a0 (ts: 0), 1:a1 (ts: 0), + // 0:b0 (ts: 1100), 1:b1 (ts: 1100), 2:b2 (ts: 1100), 3:b3 (ts: 1100), + // 0:c0 (ts: 1101), 1:c1 (ts: 1101), 2:c2 (ts: 1101), 3:c3 (ts: 1101), + // 0:d0 (ts: 1102), 1:d1 (ts: 1102), 2:d2 (ts: 1102), 3:d3 (ts: 1102) } + // --> w1 = { 0:A0 (ts: 0), 1:A1 (ts: 0), + // 0:B0 (ts: 1000), 1:B1 (ts: 1001), 2:B2 (ts: 1002), 3:B3 (ts: 1003) } + // w2 = { 0:a0 (ts: 0), 1:a1 (ts: 0), + // 0:b0 (ts: 1100), 1:b1 (ts: 1100), 2:b2 (ts: 1100), 3:b3 (ts: 1100), + // 0:c0 (ts: 1101), 1:c1 (ts: 1101), 2:c2 (ts: 1101), 3:c3 (ts: 1101), + // 0:d0 (ts: 1102), 1:d1 (ts: 1102), 2:d2 (ts: 1102), 3:d3 (ts: 1102), + // 0:e0 (ts: 1103), 1:e1 (ts: 1103), 2:e2 (ts: 1103), 3:e3 (ts: 1103) } + time += 1L; + for (final int expectedKey : expectedKeys) { + inputTopic2.pipeInput(expectedKey, "e" + expectedKey, time); + } + processor.checkAndClearProcessResult( + new KeyValueTimestamp<>(3, "B3+e3", 1103L) + ); + + // push four items to the other stream with incremented timestamp; this should produce no items + // w1 = { 0:A0 (ts: 0), 1:A1 (ts: 0), + // 0:B0 (ts: 1000), 1:B1 (ts: 1001), 2:B2 (ts: 1002), 3:B3 (ts: 1003) } + // w2 = { 0:a0 (ts: 0), 1:a1 (ts: 0), + // 0:b0 (ts: 1100), 1:b1 (ts: 1100), 2:b2 (ts: 1100), 3:b3 (ts: 1100), + // 0:c0 (ts: 1101), 1:c1 (ts: 1101), 2:c2 (ts: 1101), 3:c3 (ts: 1101), + // 0:d0 (ts: 1102), 1:d1 (ts: 1102), 2:d2 (ts: 1102), 3:d3 (ts: 1102), + // 0:e0 (ts: 1103), 1:e1 (ts: 1103), 2:e2 (ts: 1103), 3:e3 (ts: 1103) } + // --> w1 = { 0:A0 (ts: 0), 1:A1 (ts: 0), + // 0:B0 (ts: 1000), 1:B1 (ts: 1001), 2:B2 (ts: 1002), 3:B3 (ts: 1003) } + // w2 = { 0:a0 (ts: 0), 1:a1 (ts: 0), + // 0:b0 (ts: 1100), 1:b1 (ts: 1100), 2:b2 (ts: 1100), 3:b3 (ts: 1100), + // 0:c0 (ts: 1101), 1:c1 (ts: 1101), 2:c2 (ts: 1101), 3:c3 (ts: 1101), + // 0:d0 (ts: 1102), 1:d1 (ts: 1102), 2:d2 (ts: 1102), 3:d3 (ts: 1102), + // 0:e0 (ts: 1103), 1:e1 (ts: 1103), 2:e2 (ts: 1103), 3:e3 (ts: 1103), + // 0:f0 (ts: 1104), 1:f1 (ts: 1104), 2:f2 (ts: 1104), 3:f3 (ts: 1104) } + time += 1L; + for (final int expectedKey : expectedKeys) { + inputTopic2.pipeInput(expectedKey, "f" + expectedKey, time); + } + processor.checkAndClearProcessResult(); + + // push four items to the other stream with timestamp before the window bound; this should produce no items + // w1 = { 0:A0 (ts: 0), 1:A1 (ts: 0), + // 0:B0 (ts: 1000), 1:B1 (ts: 1001), 2:B2 (ts: 1002), 3:B3 (ts: 1003) } + // w2 = { 0:a0 (ts: 0), 1:a1 (ts: 0), + // 0:b0 (ts: 1100), 1:b1 (ts: 1100), 2:b2 (ts: 1100), 3:b3 (ts: 1100), + // 0:c0 (ts: 1101), 1:c1 (ts: 1101), 2:c2 (ts: 1101), 3:c3 (ts: 1101), + // 0:d0 (ts: 1102), 1:d1 (ts: 1102), 2:d2 (ts: 1102), 3:d3 (ts: 1102), + // 0:e0 (ts: 1103), 1:e1 (ts: 1103), 2:e2 (ts: 1103), 3:e3 (ts: 1103), + // 0:f0 (ts: 1104), 1:f1 (ts: 1104), 2:f2 (ts: 1104), 3:f3 (ts: 1104) } + // --> w1 = { 0:A0 (ts: 0), 1:A1 (ts: 0), + // 0:B0 (ts: 1000), 1:B1 (ts: 1001), 2:B2 (ts: 1002), 3:B3 (ts: 1003) } + // w2 = { 0:a0 (ts: 0), 1:a1 (ts: 0), + // 0:b0 (ts: 1100), 1:b1 (ts: 1100), 2:b2 (ts: 1100), 3:b3 (ts: 1100), + // 0:c0 (ts: 1101), 1:c1 (ts: 1101), 2:c2 (ts: 1101), 3:c3 (ts: 1101), + // 0:d0 (ts: 1102), 1:d1 (ts: 1102), 2:d2 (ts: 1102), 3:d3 (ts: 1102), + // 0:e0 (ts: 1103), 1:e1 (ts: 1103), 2:e2 (ts: 1103), 3:e3 (ts: 1103), + // 0:f0 (ts: 1104), 1:f1 (ts: 1104), 2:f2 (ts: 1104), 3:f3 (ts: 1104), + // 0:g0 (ts: 899), 1:g1 (ts: 899), 2:g2 (ts: 899), 3:g3 (ts: 899) } + time = 1000L - 100L - 1L; + for 
(final int expectedKey : expectedKeys) { + inputTopic2.pipeInput(expectedKey, "g" + expectedKey, time); + } + processor.checkAndClearProcessResult(); + + // push four items to the other stream with with incremented timestamp; this should produce one item + // w1 = { 0:A0 (ts: 0), 1:A1 (ts: 0), + // 0:B0 (ts: 1000), 1:B1 (ts: 1001), 2:B2 (ts: 1002), 3:B3 (ts: 1003) } + // w2 = { 0:a0 (ts: 0), 1:a1 (ts: 0), + // 0:b0 (ts: 1100), 1:b1 (ts: 1100), 2:b2 (ts: 1100), 3:b3 (ts: 1100), + // 0:c0 (ts: 1101), 1:c1 (ts: 1101), 2:c2 (ts: 1101), 3:c3 (ts: 1101), + // 0:d0 (ts: 1102), 1:d1 (ts: 1102), 2:d2 (ts: 1102), 3:d3 (ts: 1102), + // 0:e0 (ts: 1103), 1:e1 (ts: 1103), 2:e2 (ts: 1103), 3:e3 (ts: 1103), + // 0:f0 (ts: 1104), 1:f1 (ts: 1104), 2:f2 (ts: 1104), 3:f3 (ts: 1104), + // 0:g0 (ts: 899), 1:g1 (ts: 899), 2:g2 (ts: 899), 3:g3 (ts: 899) } + // --> w1 = { 0:A0 (ts: 0), 1:A1 (ts: 0), + // 0:B0 (ts: 1000), 1:B1 (ts: 1001), 2:B2 (ts: 1002), 3:B3 (ts: 1003) } + // w2 = { 0:a0 (ts: 0), 1:a1 (ts: 0), + // 0:b0 (ts: 1100), 1:b1 (ts: 1100), 2:b2 (ts: 1100), 3:b3 (ts: 1100), + // 0:c0 (ts: 1101), 1:c1 (ts: 1101), 2:c2 (ts: 1101), 3:c3 (ts: 1101), + // 0:d0 (ts: 1102), 1:d1 (ts: 1102), 2:d2 (ts: 1102), 3:d3 (ts: 1102), + // 0:e0 (ts: 1103), 1:e1 (ts: 1103), 2:e2 (ts: 1103), 3:e3 (ts: 1103), + // 0:f0 (ts: 1104), 1:f1 (ts: 1104), 2:f2 (ts: 1104), 3:f3 (ts: 1104), + // 0:g0 (ts: 899), 1:g1 (ts: 899), 2:g2 (ts: 899), 3:g3 (ts: 899), + // 0:h0 (ts: 900), 1:h1 (ts: 900), 2:h2 (ts: 900), 3:h3 (ts: 900) } + time += 1L; + for (final int expectedKey : expectedKeys) { + inputTopic2.pipeInput(expectedKey, "h" + expectedKey, time); + } + processor.checkAndClearProcessResult( + new KeyValueTimestamp<>(0, "B0+h0", 1000L) + ); + + // push four items to the other stream with with incremented timestamp; this should produce two items + // w1 = { 0:A0 (ts: 0), 1:A1 (ts: 0), + // 0:B0 (ts: 1000), 1:B1 (ts: 1001), 2:B2 (ts: 1002), 3:B3 (ts: 1003) } + // w2 = { 0:a0 (ts: 0), 1:a1 (ts: 0), + // 0:b0 (ts: 1100), 1:b1 (ts: 1100), 2:b2 (ts: 1100), 3:b3 (ts: 1100), + // 0:c0 (ts: 1101), 1:c1 (ts: 1101), 2:c2 (ts: 1101), 3:c3 (ts: 1101), + // 0:d0 (ts: 1102), 1:d1 (ts: 1102), 2:d2 (ts: 1102), 3:d3 (ts: 1102), + // 0:e0 (ts: 1103), 1:e1 (ts: 1103), 2:e2 (ts: 1103), 3:e3 (ts: 1103), + // 0:f0 (ts: 1104), 1:f1 (ts: 1104), 2:f2 (ts: 1104), 3:f3 (ts: 1104), + // 0:g0 (ts: 899), 1:g1 (ts: 899), 2:g2 (ts: 899), 3:g3 (ts: 899), + // 0:h0 (ts: 900), 1:h1 (ts: 900), 2:h2 (ts: 900), 3:h3 (ts: 900) } + // --> w1 = { 0:A0 (ts: 0), 1:A1 (ts: 0), + // 0:B0 (ts: 1000), 1:B1 (ts: 1001), 2:B2 (ts: 1002), 3:B3 (ts: 1003) } + // w2 = { 0:a0 (ts: 0), 1:a1 (ts: 0), + // 0:b0 (ts: 1100), 1:b1 (ts: 1100), 2:b2 (ts: 1100), 3:b3 (ts: 1100), + // 0:c0 (ts: 1101), 1:c1 (ts: 1101), 2:c2 (ts: 1101), 3:c3 (ts: 1101), + // 0:d0 (ts: 1102), 1:d1 (ts: 1102), 2:d2 (ts: 1102), 3:d3 (ts: 1102), + // 0:e0 (ts: 1103), 1:e1 (ts: 1103), 2:e2 (ts: 1103), 3:e3 (ts: 1103), + // 0:f0 (ts: 1104), 1:f1 (ts: 1104), 2:f2 (ts: 1104), 3:f3 (ts: 1104), + // 0:g0 (ts: 899), 1:g1 (ts: 899), 2:g2 (ts: 899), 3:g3 (ts: 899), + // 0:h0 (ts: 900), 1:h1 (ts: 900), 2:h2 (ts: 900), 3:h3 (ts: 900), + // 0:i0 (ts: 901), 1:i1 (ts: 901), 2:i2 (ts: 901), 3:i3 (ts: 901) } + time += 1L; + for (final int expectedKey : expectedKeys) { + inputTopic2.pipeInput(expectedKey, "i" + expectedKey, time); + } + processor.checkAndClearProcessResult( + new KeyValueTimestamp<>(0, "B0+i0", 1000L), + new KeyValueTimestamp<>(1, "B1+i1", 1001L) + ); + + // push four items to the other stream with with incremented 
timestamp; this should produce three items + // w1 = { 0:A0 (ts: 0), 1:A1 (ts: 0), + // 0:B0 (ts: 1000), 1:B1 (ts: 1001), 2:B2 (ts: 1002), 3:B3 (ts: 1003) } + // w2 = { 0:a0 (ts: 0), 1:a1 (ts: 0), + // 0:b0 (ts: 1100), 1:b1 (ts: 1100), 2:b2 (ts: 1100), 3:b3 (ts: 1100), + // 0:c0 (ts: 1101), 1:c1 (ts: 1101), 2:c2 (ts: 1101), 3:c3 (ts: 1101), + // 0:d0 (ts: 1102), 1:d1 (ts: 1102), 2:d2 (ts: 1102), 3:d3 (ts: 1102), + // 0:e0 (ts: 1103), 1:e1 (ts: 1103), 2:e2 (ts: 1103), 3:e3 (ts: 1103), + // 0:f0 (ts: 1104), 1:f1 (ts: 1104), 2:f2 (ts: 1104), 3:f3 (ts: 1104), + // 0:g0 (ts: 899), 1:g1 (ts: 899), 2:g2 (ts: 899), 3:g3 (ts: 899), + // 0:h0 (ts: 900), 1:h1 (ts: 900), 2:h2 (ts: 900), 3:h3 (ts: 900), + // 0:i0 (ts: 901), 1:i1 (ts: 901), 2:i2 (ts: 901), 3:i3 (ts: 901) } + // --> w1 = { 0:A0 (ts: 0), 1:A1 (ts: 0), + // 0:B0 (ts: 1000), 1:B1 (ts: 1001), 2:B2 (ts: 1002), 3:B3 (ts: 1003) } + // w2 = { 0:a0 (ts: 0), 1:a1 (ts: 0), + // 0:b0 (ts: 1100), 1:b1 (ts: 1100), 2:b2 (ts: 1100), 3:b3 (ts: 1100), + // 0:c0 (ts: 1101), 1:c1 (ts: 1101), 2:c2 (ts: 1101), 3:c3 (ts: 1101), + // 0:d0 (ts: 1102), 1:d1 (ts: 1102), 2:d2 (ts: 1102), 3:d3 (ts: 1102), + // 0:e0 (ts: 1103), 1:e1 (ts: 1103), 2:e2 (ts: 1103), 3:e3 (ts: 1103), + // 0:f0 (ts: 1104), 1:f1 (ts: 1104), 2:f2 (ts: 1104), 3:f3 (ts: 1104), + // 0:g0 (ts: 899), 1:g1 (ts: 899), 2:g2 (ts: 899), 3:g3 (ts: 899), + // 0:h0 (ts: 900), 1:h1 (ts: 900), 2:h2 (ts: 900), 3:h3 (ts: 900), + // 0:i0 (ts: 901), 1:i1 (ts: 901), 2:i2 (ts: 901), 3:i3 (ts: 901), + // 0:j0 (ts: 902), 1:j1 (ts: 902), 2:j2 (ts: 902), 3:j3 (ts: 902) } + time += 1; + for (final int expectedKey : expectedKeys) { + inputTopic2.pipeInput(expectedKey, "j" + expectedKey, time); + } + processor.checkAndClearProcessResult( + new KeyValueTimestamp<>(0, "B0+j0", 1000L), + new KeyValueTimestamp<>(1, "B1+j1", 1001L), + new KeyValueTimestamp<>(2, "B2+j2", 1002L) + ); + + // push four items to the other stream with with incremented timestamp; this should produce four items + // w1 = { 0:A0 (ts: 0), 1:A1 (ts: 0), + // 0:B0 (ts: 1000), 1:B1 (ts: 1001), 2:B2 (ts: 1002), 3:B3 (ts: 1003) } + // w2 = { 0:a0 (ts: 0), 1:a1 (ts: 0), + // 0:b0 (ts: 1100), 1:b1 (ts: 1100), 2:b2 (ts: 1100), 3:b3 (ts: 1100), + // 0:c0 (ts: 1101), 1:c1 (ts: 1101), 2:c2 (ts: 1101), 3:c3 (ts: 1101), + // 0:d0 (ts: 1102), 1:d1 (ts: 1102), 2:d2 (ts: 1102), 3:d3 (ts: 1102), + // 0:e0 (ts: 1103), 1:e1 (ts: 1103), 2:e2 (ts: 1103), 3:e3 (ts: 1103), + // 0:f0 (ts: 1104), 1:f1 (ts: 1104), 2:f2 (ts: 1104), 3:f3 (ts: 1104), + // 0:g0 (ts: 899), 1:g1 (ts: 899), 2:g2 (ts: 899), 3:g3 (ts: 899), + // 0:h0 (ts: 900), 1:h1 (ts: 900), 2:h2 (ts: 900), 3:h3 (ts: 900), + // 0:i0 (ts: 901), 1:i1 (ts: 901), 2:i2 (ts: 901), 3:i3 (ts: 901), + // 0:j0 (ts: 902), 1:j1 (ts: 902), 2:j2 (ts: 902), 3:j3 (ts: 902) } + // --> w1 = { 0:A0 (ts: 0), 1:A1 (ts: 0), + // 0:B0 (ts: 1000), 1:B1 (ts: 1001), 2:B2 (ts: 1002), 3:B3 (ts: 1003) } + // w2 = { 0:a0 (ts: 0), 1:a1 (ts: 0), + // 0:b0 (ts: 1100), 1:b1 (ts: 1100), 2:b2 (ts: 1100), 3:b3 (ts: 1100), + // 0:c0 (ts: 1101), 1:c1 (ts: 1101), 2:c2 (ts: 1101), 3:c3 (ts: 1101), + // 0:d0 (ts: 1102), 1:d1 (ts: 1102), 2:d2 (ts: 1102), 3:d3 (ts: 1102), + // 0:e0 (ts: 1103), 1:e1 (ts: 1103), 2:e2 (ts: 1103), 3:e3 (ts: 1103), + // 0:f0 (ts: 1104), 1:f1 (ts: 1104), 2:f2 (ts: 1104), 3:f3 (ts: 1104), + // 0:g0 (ts: 899), 1:g1 (ts: 899), 2:g2 (ts: 899), 3:g3 (ts: 899), + // 0:h0 (ts: 900), 1:h1 (ts: 900), 2:h2 (ts: 900), 3:h3 (ts: 900), + // 0:i0 (ts: 901), 1:i1 (ts: 901), 2:i2 (ts: 901), 3:i3 (ts: 901), + // 0:j0 (ts: 902), 1:j1 (ts: 902), 2:j2 
(ts: 902), 3:j3 (ts: 902) } + // 0:k0 (ts: 903), 1:k1 (ts: 903), 2:k2 (ts: 903), 3:k3 (ts: 903) } + time += 1; + for (final int expectedKey : expectedKeys) { + inputTopic2.pipeInput(expectedKey, "k" + expectedKey, time); + } + processor.checkAndClearProcessResult( + new KeyValueTimestamp<>(0, "B0+k0", 1000L), + new KeyValueTimestamp<>(1, "B1+k1", 1001L), + new KeyValueTimestamp<>(2, "B2+k2", 1002L), + new KeyValueTimestamp<>(3, "B3+k3", 1003L) + ); + + // advance time to not join with existing data + // we omit above exiting data, even if it's still in the window + // + // push four items with increasing timestamps to the other stream. the primary window is empty; this should produce no items + // w1 = {} + // w2 = {} + // --> w1 = {} + // w2 = { 0:l0 (ts: 2000), 1:l1 (ts: 2001), 2:l2 (ts: 2002), 3:l3 (ts: 2003) } + time = 2000L; + for (int i = 0; i < expectedKeys.length; i++) { + inputTopic2.pipeInput(expectedKeys[i], "l" + expectedKeys[i], time + i); + } + processor.checkAndClearProcessResult(); + + // push four items with larger timestamps to the primary stream; this should produce four items + // w1 = {} + // w2 = { 0:l0 (ts: 2000), 1:l1 (ts: 2001), 2:l2 (ts: 2002), 3:l3 (ts: 2003) } + // --> w1 = { 0:C0 (ts: 2100), 1:C1 (ts: 2100), 2:C2 (ts: 2100), 3:C3 (ts: 2100) } + // w2 = { 0:l0 (ts: 2000), 1:l1 (ts: 2001), 2:l2 (ts: 2002), 3:l3 (ts: 2003) } + time = 2000L + 100L; + for (final int expectedKey : expectedKeys) { + inputTopic1.pipeInput(expectedKey, "C" + expectedKey, time); + } + processor.checkAndClearProcessResult( + new KeyValueTimestamp<>(0, "C0+l0", 2100L), + new KeyValueTimestamp<>(1, "C1+l1", 2100L), + new KeyValueTimestamp<>(2, "C2+l2", 2100L), + new KeyValueTimestamp<>(3, "C3+l3", 2100L) + ); + + // push four items with increase timestamps to the primary stream; this should produce three items + // w1 = { 0:C0 (ts: 2100), 1:C1 (ts: 2100), 2:C2 (ts: 2100), 3:C3 (ts: 2100) } + // w2 = { 0:l0 (ts: 2000), 1:l1 (ts: 2001), 2:l2 (ts: 2002), 3:l3 (ts: 2003) } + // --> w1 = { 0:C0 (ts: 2100), 1:C1 (ts: 2100), 2:C2 (ts: 2100), 3:C3 (ts: 2100), + // 0:D0 (ts: 2101), 1:D1 (ts: 2101), 2:D2 (ts: 2101), 3:D3 (ts: 2101) } + // w2 = { 0:l0 (ts: 2000), 1:l1 (ts: 2001), 2:l2 (ts: 2002), 3:l3 (ts: 2003) } + time += 1L; + for (final int expectedKey : expectedKeys) { + inputTopic1.pipeInput(expectedKey, "D" + expectedKey, time); + } + processor.checkAndClearProcessResult( + new KeyValueTimestamp<>(1, "D1+l1", 2101L), + new KeyValueTimestamp<>(2, "D2+l2", 2101L), + new KeyValueTimestamp<>(3, "D3+l3", 2101L) + ); + + // push four items with increase timestamps to the primary stream; this should produce two items + // w1 = { 0:C0 (ts: 2100), 1:C1 (ts: 2100), 2:C2 (ts: 2100), 3:C3 (ts: 2100), + // 0:D0 (ts: 2101), 1:D1 (ts: 2101), 2:D2 (ts: 2101), 3:D3 (ts: 2101) } + // w2 = { 0:l0 (ts: 2000), 1:l1 (ts: 2001), 2:l2 (ts: 2002), 3:l3 (ts: 2003) } + // --> w1 = { 0:C0 (ts: 2100), 1:C1 (ts: 2100), 2:C2 (ts: 2100), 3:C3 (ts: 2100), + // 0:D0 (ts: 2101), 1:D1 (ts: 2101), 2:D2 (ts: 2101), 3:D3 (ts: 2101), + // 0:E0 (ts: 2102), 1:E1 (ts: 2102), 2:E2 (ts: 2102), 3:E3 (ts: 2102) } + // w2 = { 0:l0 (ts: 2000), 1:l1 (ts: 2001), 2:l2 (ts: 2002), 3:l3 (ts: 2003) } + time += 1L; + for (final int expectedKey : expectedKeys) { + inputTopic1.pipeInput(expectedKey, "E" + expectedKey, time); + } + processor.checkAndClearProcessResult( + new KeyValueTimestamp<>(2, "E2+l2", 2102L), + new KeyValueTimestamp<>(3, "E3+l3", 2102L) + ); + + // push four items with increase timestamps to the primary stream; this should produce 
one item + // w1 = { 0:C0 (ts: 2100), 1:C1 (ts: 2100), 2:C2 (ts: 2100), 3:C3 (ts: 2100), + // 0:D0 (ts: 2101), 1:D1 (ts: 2101), 2:D2 (ts: 2101), 3:D3 (ts: 2101), + // 0:E0 (ts: 2102), 1:E1 (ts: 2102), 2:E2 (ts: 2102), 3:E3 (ts: 2102) } + // w2 = { 0:l0 (ts: 2000), 1:l1 (ts: 2001), 2:l2 (ts: 2002), 3:l3 (ts: 2003) } + // --> w1 = { 0:C0 (ts: 2100), 1:C1 (ts: 2100), 2:C2 (ts: 2100), 3:C3 (ts: 2100), + // 0:D0 (ts: 2101), 1:D1 (ts: 2101), 2:D2 (ts: 2101), 3:D3 (ts: 2101), + // 0:E0 (ts: 2102), 1:E1 (ts: 2102), 2:E2 (ts: 2102), 3:E3 (ts: 2102), + // 0:F0 (ts: 2103), 1:F1 (ts: 2103), 2:F2 (ts: 2103), 3:F3 (ts: 2103) } + // w2 = { 0:l0 (ts: 2000), 1:l1 (ts: 2001), 2:l2 (ts: 2002), 3:l3 (ts: 2003) } + time += 1L; + for (final int expectedKey : expectedKeys) { + inputTopic1.pipeInput(expectedKey, "F" + expectedKey, time); + } + processor.checkAndClearProcessResult( + new KeyValueTimestamp<>(3, "F3+l3", 2103L) + ); + + // push four items with increase timestamps (now out of window) to the primary stream; this should produce no items + // w1 = { 0:C0 (ts: 2100), 1:C1 (ts: 2100), 2:C2 (ts: 2100), 3:C3 (ts: 2100), + // 0:D0 (ts: 2101), 1:D1 (ts: 2101), 2:D2 (ts: 2101), 3:D3 (ts: 2101), + // 0:E0 (ts: 2102), 1:E1 (ts: 2102), 2:E2 (ts: 2102), 3:E3 (ts: 2102), + // 0:F0 (ts: 2103), 1:F1 (ts: 2103), 2:F2 (ts: 2103), 3:F3 (ts: 2103) } + // w2 = { 0:l0 (ts: 2000), 1:l1 (ts: 2001), 2:l2 (ts: 2002), 3:l3 (ts: 2003) } + // --> w1 = { 0:C0 (ts: 2100), 1:C1 (ts: 2100), 2:C2 (ts: 2100), 3:C3 (ts: 2100), + // 0:D0 (ts: 2101), 1:D1 (ts: 2101), 2:D2 (ts: 2101), 3:D3 (ts: 2101), + // 0:E0 (ts: 2102), 1:E1 (ts: 2102), 2:E2 (ts: 2102), 3:E3 (ts: 2102), + // 0:F0 (ts: 2103), 1:F1 (ts: 2103), 2:F2 (ts: 2103), 3:F3 (ts: 2103), + // 0:G0 (ts: 2104), 1:G1 (ts: 2104), 2:G2 (ts: 2104), 3:G3 (ts: 2104) } + // w2 = { 0:l0 (ts: 2000), 1:l1 (ts: 2001), 2:l2 (ts: 2002), 3:l3 (ts: 2003) } + time += 1L; + for (final int expectedKey : expectedKeys) { + inputTopic1.pipeInput(expectedKey, "G" + expectedKey, time); + } + processor.checkAndClearProcessResult(); + + // push four items with smaller timestamps (before window) to the primary stream; this should produce no items + // w1 = { 0:C0 (ts: 2100), 1:C1 (ts: 2100), 2:C2 (ts: 2100), 3:C3 (ts: 2100), + // 0:D0 (ts: 2101), 1:D1 (ts: 2101), 2:D2 (ts: 2101), 3:D3 (ts: 2101), + // 0:E0 (ts: 2102), 1:E1 (ts: 2102), 2:E2 (ts: 2102), 3:E3 (ts: 2102), + // 0:F0 (ts: 2103), 1:F1 (ts: 2103), 2:F2 (ts: 2103), 3:F3 (ts: 2103), + // 0:G0 (ts: 2104), 1:G1 (ts: 2104), 2:G2 (ts: 2104), 3:G3 (ts: 2104) } + // w2 = { 0:l0 (ts: 2000), 1:l1 (ts: 2001), 2:l2 (ts: 2002), 3:l3 (ts: 2003) } + // --> w1 = { 0:C0 (ts: 2100), 1:C1 (ts: 2100), 2:C2 (ts: 2100), 3:C3 (ts: 2100), + // 0:D0 (ts: 2101), 1:D1 (ts: 2101), 2:D2 (ts: 2101), 3:D3 (ts: 2101), + // 0:E0 (ts: 2102), 1:E1 (ts: 2102), 2:E2 (ts: 2102), 3:E3 (ts: 2102), + // 0:F0 (ts: 2103), 1:F1 (ts: 2103), 2:F2 (ts: 2103), 3:F3 (ts: 2103), + // 0:G0 (ts: 2104), 1:G1 (ts: 2104), 2:G2 (ts: 2104), 3:G3 (ts: 2104), + // 0:H0 (ts: 1899), 1:H1 (ts: 1899), 2:H2 (ts: 1899), 3:H3 (ts: 1899) } + // w2 = { 0:l0 (ts: 2000), 1:l1 (ts: 2001), 2:l2 (ts: 2002), 3:l3 (ts: 2003) } + time = 2000L - 100L - 1L; + for (final int expectedKey : expectedKeys) { + inputTopic1.pipeInput(expectedKey, "H" + expectedKey, time); + } + processor.checkAndClearProcessResult(); + + // push four items with increased timestamps to the primary stream; this should produce one item + // w1 = { 0:C0 (ts: 2100), 1:C1 (ts: 2100), 2:C2 (ts: 2100), 3:C3 (ts: 2100), + // 0:D0 (ts: 2101), 1:D1 (ts: 2101), 
2:D2 (ts: 2101), 3:D3 (ts: 2101), + // 0:E0 (ts: 2102), 1:E1 (ts: 2102), 2:E2 (ts: 2102), 3:E3 (ts: 2102), + // 0:F0 (ts: 2103), 1:F1 (ts: 2103), 2:F2 (ts: 2103), 3:F3 (ts: 2103), + // 0:G0 (ts: 2104), 1:G1 (ts: 2104), 2:G2 (ts: 2104), 3:G3 (ts: 2104), + // 0:H0 (ts: 1899), 1:H1 (ts: 1899), 2:H2 (ts: 1899), 3:H3 (ts: 1899) } + // w2 = { 0:l0 (ts: 2000), 1:l1 (ts: 2001), 2:l2 (ts: 2002), 3:l3 (ts: 2003) } + // --> w1 = { 0:C0 (ts: 2100), 1:C1 (ts: 2100), 2:C2 (ts: 2100), 3:C3 (ts: 2100), + // 0:D0 (ts: 2101), 1:D1 (ts: 2101), 2:D2 (ts: 2101), 3:D3 (ts: 2101), + // 0:E0 (ts: 2102), 1:E1 (ts: 2102), 2:E2 (ts: 2102), 3:E3 (ts: 2102), + // 0:F0 (ts: 2103), 1:F1 (ts: 2103), 2:F2 (ts: 2103), 3:F3 (ts: 2103), + // 0:G0 (ts: 2104), 1:G1 (ts: 2104), 2:G2 (ts: 2104), 3:G3 (ts: 2104), + // 0:H0 (ts: 1899), 1:H1 (ts: 1899), 2:H2 (ts: 1899), 3:H3 (ts: 1899), + // 0:I0 (ts: 1900), 1:I1 (ts: 1900), 2:I2 (ts: 1900), 3:I3 (ts: 1900) } + // w2 = { 0:l0 (ts: 2000), 1:l1 (ts: 2001), 2:l2 (ts: 2002), 3:l3 (ts: 2003) } + time += 1L; + for (final int expectedKey : expectedKeys) { + inputTopic1.pipeInput(expectedKey, "I" + expectedKey, time); + } + processor.checkAndClearProcessResult( + new KeyValueTimestamp<>(0, "I0+l0", 2000L) + ); + + // push four items with increased timestamps to the primary stream; this should produce two items + // w1 = { 0:C0 (ts: 2100), 1:C1 (ts: 2100), 2:C2 (ts: 2100), 3:C3 (ts: 2100), + // 0:D0 (ts: 2101), 1:D1 (ts: 2101), 2:D2 (ts: 2101), 3:D3 (ts: 2101), + // 0:E0 (ts: 2102), 1:E1 (ts: 2102), 2:E2 (ts: 2102), 3:E3 (ts: 2102), + // 0:F0 (ts: 2103), 1:F1 (ts: 2103), 2:F2 (ts: 2103), 3:F3 (ts: 2103), + // 0:G0 (ts: 2104), 1:G1 (ts: 2104), 2:G2 (ts: 2104), 3:G3 (ts: 2104), + // 0:H0 (ts: 1899), 1:H1 (ts: 1899), 2:H2 (ts: 1899), 3:H3 (ts: 1899), + // 0:I0 (ts: 1900), 1:I1 (ts: 1900), 2:I2 (ts: 1900), 3:I3 (ts: 1900) } + // w2 = { 0:l0 (ts: 2000), 1:l1 (ts: 2001), 2:l2 (ts: 2002), 3:l3 (ts: 2003) } + // --> w1 = { 0:C0 (ts: 2100), 1:C1 (ts: 2100), 2:C2 (ts: 2100), 3:C3 (ts: 2100), + // 0:D0 (ts: 2101), 1:D1 (ts: 2101), 2:D2 (ts: 2101), 3:D3 (ts: 2101), + // 0:E0 (ts: 2102), 1:E1 (ts: 2102), 2:E2 (ts: 2102), 3:E3 (ts: 2102), + // 0:F0 (ts: 2103), 1:F1 (ts: 2103), 2:F2 (ts: 2103), 3:F3 (ts: 2103), + // 0:G0 (ts: 2104), 1:G1 (ts: 2104), 2:G2 (ts: 2104), 3:G3 (ts: 2104), + // 0:H0 (ts: 1899), 1:H1 (ts: 1899), 2:H2 (ts: 1899), 3:H3 (ts: 1899), + // 0:I0 (ts: 1900), 1:I1 (ts: 1900), 2:I2 (ts: 1900), 3:I3 (ts: 1900), + // 0:J0 (ts: 1901), 1:J1 (ts: 1901), 2:J2 (ts: 1901), 3:J3 (ts: 1901) } + // w2 = { 0:l0 (ts: 2000), 1:l1 (ts: 2001), 2:l2 (ts: 2002), 3:l3 (ts: 2003) } + time += 1L; + for (final int expectedKey : expectedKeys) { + inputTopic1.pipeInput(expectedKey, "J" + expectedKey, time); + } + processor.checkAndClearProcessResult( + new KeyValueTimestamp<>(0, "J0+l0", 2000L), + new KeyValueTimestamp<>(1, "J1+l1", 2001L) + ); + + // push four items with increased timestamps to the primary stream; this should produce three items + // w1 = { 0:C0 (ts: 2100), 1:C1 (ts: 2100), 2:C2 (ts: 2100), 3:C3 (ts: 2100), + // 0:D0 (ts: 2101), 1:D1 (ts: 2101), 2:D2 (ts: 2101), 3:D3 (ts: 2101), + // 0:E0 (ts: 2102), 1:E1 (ts: 2102), 2:E2 (ts: 2102), 3:E3 (ts: 2102), + // 0:F0 (ts: 2103), 1:F1 (ts: 2103), 2:F2 (ts: 2103), 3:F3 (ts: 2103), + // 0:G0 (ts: 2104), 1:G1 (ts: 2104), 2:G2 (ts: 2104), 3:G3 (ts: 2104), + // 0:H0 (ts: 1899), 1:H1 (ts: 1899), 2:H2 (ts: 1899), 3:H3 (ts: 1899), + // 0:I0 (ts: 1900), 1:I1 (ts: 1900), 2:I2 (ts: 1900), 3:I3 (ts: 1900), + // 0:J0 (ts: 1901), 1:J1 (ts: 1901), 2:J2 (ts: 1901), 
3:J3 (ts: 1901) } + // w2 = { 0:l0 (ts: 2000), 1:l1 (ts: 2001), 2:l2 (ts: 2002), 3:l3 (ts: 2003) } + // --> w1 = { 0:C0 (ts: 2100), 1:C1 (ts: 2100), 2:C2 (ts: 2100), 3:C3 (ts: 2100), + // 0:D0 (ts: 2101), 1:D1 (ts: 2101), 2:D2 (ts: 2101), 3:D3 (ts: 2101), + // 0:E0 (ts: 2102), 1:E1 (ts: 2102), 2:E2 (ts: 2102), 3:E3 (ts: 2102), + // 0:F0 (ts: 2103), 1:F1 (ts: 2103), 2:F2 (ts: 2103), 3:F3 (ts: 2103), + // 0:G0 (ts: 2104), 1:G1 (ts: 2104), 2:G2 (ts: 2104), 3:G3 (ts: 2104), + // 0:H0 (ts: 1899), 1:H1 (ts: 1899), 2:H2 (ts: 1899), 3:H3 (ts: 1899), + // 0:I0 (ts: 1900), 1:I1 (ts: 1900), 2:I2 (ts: 1900), 3:I3 (ts: 1900), + // 0:J0 (ts: 1901), 1:J1 (ts: 1901), 2:J2 (ts: 1901), 3:J3 (ts: 1901), + // 0:K0 (ts: 1902), 1:K1 (ts: 1902), 2:K2 (ts: 1902), 3:K3 (ts: 1902) } + // w2 = { 0:l0 (ts: 2000), 1:l1 (ts: 2001), 2:l2 (ts: 2002), 3:l3 (ts: 2003) } + time += 1L; + for (final int expectedKey : expectedKeys) { + inputTopic1.pipeInput(expectedKey, "K" + expectedKey, time); + } + processor.checkAndClearProcessResult( + new KeyValueTimestamp<>(0, "K0+l0", 2000L), + new KeyValueTimestamp<>(1, "K1+l1", 2001L), + new KeyValueTimestamp<>(2, "K2+l2", 2002L) + ); + + // push four items with increased timestamps to the primary stream; this should produce four items + // w1 = { 0:C0 (ts: 2100), 1:C1 (ts: 2100), 2:C2 (ts: 2100), 3:C3 (ts: 2100), + // 0:D0 (ts: 2101), 1:D1 (ts: 2101), 2:D2 (ts: 2101), 3:D3 (ts: 2101), + // 0:E0 (ts: 2102), 1:E1 (ts: 2102), 2:E2 (ts: 2102), 3:E3 (ts: 2102), + // 0:F0 (ts: 2103), 1:F1 (ts: 2103), 2:F2 (ts: 2103), 3:F3 (ts: 2103), + // 0:G0 (ts: 2104), 1:G1 (ts: 2104), 2:G2 (ts: 2104), 3:G3 (ts: 2104), + // 0:H0 (ts: 1899), 1:H1 (ts: 1899), 2:H2 (ts: 1899), 3:H3 (ts: 1899), + // 0:I0 (ts: 1900), 1:I1 (ts: 1900), 2:I2 (ts: 1900), 3:I3 (ts: 1900), + // 0:J0 (ts: 1901), 1:J1 (ts: 1901), 2:J2 (ts: 1901), 3:J3 (ts: 1901) } + // 0:K0 (ts: 1902), 1:K1 (ts: 1902), 2:K2 (ts: 1902), 3:K3 (ts: 1902) } + // w2 = { 0:l0 (ts: 2000), 1:l1 (ts: 2001), 2:l2 (ts: 2002), 3:l3 (ts: 2003) } + // --> w1 = { 0:C0 (ts: 2100), 1:C1 (ts: 2100), 2:C2 (ts: 2100), 3:C3 (ts: 2100), + // 0:D0 (ts: 2101), 1:D1 (ts: 2101), 2:D2 (ts: 2101), 3:D3 (ts: 2101), + // 0:E0 (ts: 2102), 1:E1 (ts: 2102), 2:E2 (ts: 2102), 3:E3 (ts: 2102), + // 0:F0 (ts: 2103), 1:F1 (ts: 2103), 2:F2 (ts: 2103), 3:F3 (ts: 2103), + // 0:G0 (ts: 2104), 1:G1 (ts: 2104), 2:G2 (ts: 2104), 3:G3 (ts: 2104), + // 0:H0 (ts: 1899), 1:H1 (ts: 1899), 2:H2 (ts: 1899), 3:H3 (ts: 1899), + // 0:I0 (ts: 1900), 1:I1 (ts: 1900), 2:I2 (ts: 1900), 3:I3 (ts: 1900), + // 0:J0 (ts: 1901), 1:J1 (ts: 1901), 2:J2 (ts: 1901), 3:J3 (ts: 1901), + // 0:K0 (ts: 1902), 1:K1 (ts: 1902), 2:K2 (ts: 1902), 3:K3 (ts: 1902), + // 0:L0 (ts: 1903), 1:L1 (ts: 1903), 2:L2 (ts: 1903), 3:L3 (ts: 1903) } + // w2 = { 0:l0 (ts: 2000), 1:l1 (ts: 2001), 2:l2 (ts: 2002), 3:l3 (ts: 2003) } + time += 1L; + for (final int expectedKey : expectedKeys) { + inputTopic1.pipeInput(expectedKey, "L" + expectedKey, time); + } + processor.checkAndClearProcessResult( + new KeyValueTimestamp<>(0, "L0+l0", 2000L), + new KeyValueTimestamp<>(1, "L1+l1", 2001L), + new KeyValueTimestamp<>(2, "L2+l2", 2002L), + new KeyValueTimestamp<>(3, "L3+l3", 2003L) + ); + } + } + + @Test + public void testAsymmetricWindowingAfter() { + final StreamsBuilder builder = new StreamsBuilder(); + + final int[] expectedKeys = new int[] {0, 1, 2, 3}; + + final KStream stream1; + final KStream stream2; + final KStream joined; + final MockProcessorSupplier supplier = new MockProcessorSupplier<>(); + stream1 = builder.stream(topic1, 
consumed); + stream2 = builder.stream(topic2, consumed); + + joined = stream1.join( + stream2, + MockValueJoiner.TOSTRING_JOINER, + JoinWindows.of(ofMillis(0)).after(ofMillis(100)).grace(ofMillis(0)), + StreamJoined.with(Serdes.Integer(), + Serdes.String(), + Serdes.String()) + ); + joined.process(supplier); + + final Collection> copartitionGroups = + TopologyWrapper.getInternalTopologyBuilder(builder.build()).copartitionGroups(); + + assertEquals(1, copartitionGroups.size()); + assertEquals(new HashSet<>(Arrays.asList(topic1, topic2)), copartitionGroups.iterator().next()); + + try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) { + final TestInputTopic inputTopic1 = + driver.createInputTopic(topic1, new IntegerSerializer(), new StringSerializer(), Instant.ofEpochMilli(0L), Duration.ZERO); + final TestInputTopic inputTopic2 = + driver.createInputTopic(topic2, new IntegerSerializer(), new StringSerializer(), Instant.ofEpochMilli(0L), Duration.ZERO); + final MockProcessor processor = supplier.theCapturedProcessor(); + long time = 1000L; + + // push four items with increasing timestamps to the primary stream; the other window is empty; this should produce no items + // w1 = {} + // w2 = {} + // --> w1 = { 0:A0 (ts: 1000), 1:A1 (ts: 1001), 2:A2 (ts: 1002), 3:A3 (ts: 1003) } + // w2 = {} + for (int i = 0; i < expectedKeys.length; i++) { + inputTopic1.pipeInput(expectedKeys[i], "A" + expectedKeys[i], time + i); + } + processor.checkAndClearProcessResult(); + + // push four items smaller timestamps (out of window) to the secondary stream; this should produce no items + // w1 = { 0:A0 (ts: 1000), 1:A1 (ts: 1001), 2:A2 (ts: 1002), 3:A3 (ts: 1003) } + // w2 = {} + // --> w1 = { 0:A0 (ts: 1000), 1:A1 (ts: 1001), 2:A2 (ts: 1002), 3:A3 (ts: 1003) } + // w2 = { 0:a0 (ts: 999), 1:a1 (ts: 999), 2:a2 (ts: 999), 3:a3 (ts: 999) } + time = 1000L - 1L; + for (final int expectedKey : expectedKeys) { + inputTopic2.pipeInput(expectedKey, "a" + expectedKey, time); + } + processor.checkAndClearProcessResult(); + + // push four items with increased timestamps to the secondary stream; this should produce one item + // w1 = { 0:A0 (ts: 1000), 1:A1 (ts: 1001), 2:A2 (ts: 1002), 3:A3 (ts: 1003) } + // w2 = { 0:a0 (ts: 999), 1:a1 (ts: 999), 2:a2 (ts: 999), 3:a3 (ts: 999) } + // --> w1 = { 0:A0 (ts: 1000), 1:A1 (ts: 1001), 2:A2 (ts: 1002), 3:A3 (ts: 1003) } + // w2 = { 0:a0 (ts: 999), 1:a1 (ts: 999), 2:a2 (ts: 999), 3:a3 (ts: 999), + // 0:b0 (ts: 1000), 1:b1 (ts: 1000), 2:b2 (ts: 1000), 3:b3 (ts: 1000) } + time += 1L; + for (final int expectedKey : expectedKeys) { + inputTopic2.pipeInput(expectedKey, "b" + expectedKey, time); + } + processor.checkAndClearProcessResult( + new KeyValueTimestamp<>(0, "A0+b0", 1000L) + ); + + // push four items with increased timestamps to the secondary stream; this should produce two items + // w1 = { 0:A0 (ts: 1000), 1:A1 (ts: 1001), 2:A2 (ts: 1002), 3:A3 (ts: 1003) } + // w2 = { 0:a0 (ts: 999), 1:a1 (ts: 999), 2:a2 (ts: 999), 3:a3 (ts: 999), + // 0:b0 (ts: 1000), 1:b1 (ts: 1000), 2:b2 (ts: 1000), 3:b3 (ts: 1000) } + // --> w1 = { 0:A0 (ts: 1000), 1:A1 (ts: 1001), 2:A2 (ts: 1002), 3:A3 (ts: 1003) } + // w2 = { 0:a0 (ts: 999), 1:a1 (ts: 999), 2:a2 (ts: 999), 3:a3 (ts: 999), + // 0:b0 (ts: 1000), 1:b1 (ts: 1000), 2:b2 (ts: 1000), 3:b3 (ts: 1000), + // 0:c0 (ts: 1001), 1:c1 (ts: 1001), 2:c2 (ts: 1001), 3:c3 (ts: 1001) } + time += 1L; + for (final int expectedKey : expectedKeys) { + inputTopic2.pipeInput(expectedKey, "c" + expectedKey, time); + } + 
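+            // with before=0ms and after=100ms relative to the primary stream, a right-side record
+            // at ts=1001 can only pair with left-side records whose timestamps lie in [901, 1001],
+            // i.e. A0 (ts=1000) and A1 (ts=1001), which is why exactly two joins are asserted below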
processor.checkAndClearProcessResult( + new KeyValueTimestamp<>(0, "A0+c0", 1001L), + new KeyValueTimestamp<>(1, "A1+c1", 1001L) + ); + + // push four items with increased timestamps to the secondary stream; this should produce three items + // w1 = { 0:A0 (ts: 1000), 1:A1 (ts: 1001), 2:A2 (ts: 1002), 3:A3 (ts: 1003) } + // w2 = { 0:a0 (ts: 999), 1:a1 (ts: 999), 2:a2 (ts: 999), 3:a3 (ts: 999), + // 0:b0 (ts: 1000), 1:b1 (ts: 1000), 2:b2 (ts: 1000), 3:b3 (ts: 1000), + // 0:c0 (ts: 1001), 1:c1 (ts: 1001), 2:c2 (ts: 1001), 3:c3 (ts: 1001) } + // --> w1 = { 0:A0 (ts: 1000), 1:A1 (ts: 1001), 2:A2 (ts: 1002), 3:A3 (ts: 1003) } + // w2 = { 0:a0 (ts: 999), 1:a1 (ts: 999), 2:a2 (ts: 999), 3:a3 (ts: 999), + // 0:b0 (ts: 1000), 1:b1 (ts: 1000), 2:b2 (ts: 1000), 3:b3 (ts: 1000), + // 0:c0 (ts: 1001), 1:c1 (ts: 1001), 2:c2 (ts: 1001), 3:c3 (ts: 1001), + // 0:d0 (ts: 1002), 1:d1 (ts: 1002), 2:d2 (ts: 1002), 3:d3 (ts: 1002) } + time += 1L; + for (final int expectedKey : expectedKeys) { + inputTopic2.pipeInput(expectedKey, "d" + expectedKey, time); + } + processor.checkAndClearProcessResult( + new KeyValueTimestamp<>(0, "A0+d0", 1002L), + new KeyValueTimestamp<>(1, "A1+d1", 1002L), + new KeyValueTimestamp<>(2, "A2+d2", 1002L) + ); + + // push four items with increased timestamps to the secondary stream; this should produce four items + // w1 = { 0:A0 (ts: 1000), 1:A1 (ts: 1001), 2:A2 (ts: 1002), 3:A3 (ts: 1003) } + // w2 = { 0:a0 (ts: 999), 1:a1 (ts: 999), 2:a2 (ts: 999), 3:a3 (ts: 999), + // 0:b0 (ts: 1000), 1:b1 (ts: 1000), 2:b2 (ts: 1000), 3:b3 (ts: 1000), + // 0:c0 (ts: 1001), 1:c1 (ts: 1001), 2:c2 (ts: 1001), 3:c3 (ts: 1001), + // 0:d0 (ts: 1002), 1:d1 (ts: 1002), 2:d2 (ts: 1002), 3:d3 (ts: 1002) } + // --> w1 = { 0:A0 (ts: 1000), 1:A1 (ts: 1001), 2:A2 (ts: 1002), 3:A3 (ts: 1003) } + // w2 = { 0:a0 (ts: 999), 1:a1 (ts: 999), 2:a2 (ts: 999), 3:a3 (ts: 999), + // 0:b0 (ts: 1000), 1:b1 (ts: 1000), 2:b2 (ts: 1000), 3:b3 (ts: 1000), + // 0:c0 (ts: 1001), 1:c1 (ts: 1001), 2:c2 (ts: 1001), 3:c3 (ts: 1001), + // 0:d0 (ts: 1002), 1:d1 (ts: 1002), 2:d2 (ts: 1002), 3:d3 (ts: 1002), + // 0:e0 (ts: 1003), 1:e1 (ts: 1003), 2:e2 (ts: 1003), 3:e3 (ts: 1003) } + time += 1L; + for (final int expectedKey : expectedKeys) { + inputTopic2.pipeInput(expectedKey, "e" + expectedKey, time); + } + processor.checkAndClearProcessResult( + new KeyValueTimestamp<>(0, "A0+e0", 1003L), + new KeyValueTimestamp<>(1, "A1+e1", 1003L), + new KeyValueTimestamp<>(2, "A2+e2", 1003L), + new KeyValueTimestamp<>(3, "A3+e3", 1003L) + ); + + // push four items with larger timestamps to the secondary stream; this should produce four items + // w1 = { 0:A0 (ts: 1000), 1:A1 (ts: 1001), 2:A2 (ts: 1002), 3:A3 (ts: 1003) } + // w2 = { 0:a0 (ts: 999), 1:a1 (ts: 999), 2:a2 (ts: 999), 3:a3 (ts: 999), + // 0:b0 (ts: 1000), 1:b1 (ts: 1000), 2:b2 (ts: 1000), 3:b3 (ts: 1000), + // 0:c0 (ts: 1001), 1:c1 (ts: 1001), 2:c2 (ts: 1001), 3:c3 (ts: 1001), + // 0:d0 (ts: 1002), 1:d1 (ts: 1002), 2:d2 (ts: 1002), 3:d3 (ts: 1002), + // 0:e0 (ts: 1003), 1:e1 (ts: 1003), 2:e2 (ts: 1003), 3:e3 (ts: 1003) } + // --> w1 = { 0:A0 (ts: 1000), 1:A1 (ts: 1001), 2:A2 (ts: 1002), 3:A3 (ts: 1003) } + // w2 = { 0:a0 (ts: 999), 1:a1 (ts: 999), 2:a2 (ts: 999), 3:a3 (ts: 999), + // 0:b0 (ts: 1000), 1:b1 (ts: 1000), 2:b2 (ts: 1000), 3:b3 (ts: 1000), + // 0:c0 (ts: 1001), 1:c1 (ts: 1001), 2:c2 (ts: 1001), 3:c3 (ts: 1001), + // 0:d0 (ts: 1002), 1:d1 (ts: 1002), 2:d2 (ts: 1002), 3:d3 (ts: 1002), + // 0:e0 (ts: 1003), 1:e1 (ts: 1003), 2:e2 (ts: 1003), 3:e3 (ts: 1003), + // 0:f0 (ts: 1100), 
1:f1 (ts: 1100), 2:f2 (ts: 1100), 3:f3 (ts: 1100) } + time = 1000 + 100L; + for (final int expectedKey : expectedKeys) { + inputTopic2.pipeInput(expectedKey, "f" + expectedKey, time); + } + processor.checkAndClearProcessResult( + new KeyValueTimestamp<>(0, "A0+f0", 1100L), + new KeyValueTimestamp<>(1, "A1+f1", 1100L), + new KeyValueTimestamp<>(2, "A2+f2", 1100L), + new KeyValueTimestamp<>(3, "A3+f3", 1100L) + ); + + // push four items with increased timestamps to the secondary stream; this should produce three items + // w1 = { 0:A0 (ts: 1000), 1:A1 (ts: 1001), 2:A2 (ts: 1002), 3:A3 (ts: 1003) } + // w2 = { 0:a0 (ts: 999), 1:a1 (ts: 999), 2:a2 (ts: 999), 3:a3 (ts: 999), + // 0:b0 (ts: 1000), 1:b1 (ts: 1000), 2:b2 (ts: 1000), 3:b3 (ts: 1000), + // 0:c0 (ts: 1001), 1:c1 (ts: 1001), 2:c2 (ts: 1001), 3:c3 (ts: 1001), + // 0:d0 (ts: 1002), 1:d1 (ts: 1002), 2:d2 (ts: 1002), 3:d3 (ts: 1002), + // 0:e0 (ts: 1003), 1:e1 (ts: 1003), 2:e2 (ts: 1003), 3:e3 (ts: 1003), + // 0:f0 (ts: 1100), 1:f1 (ts: 1100), 2:f2 (ts: 1100), 3:f3 (ts: 1100) } + // --> w1 = { 0:A0 (ts: 1000), 1:A1 (ts: 1001), 2:A2 (ts: 1002), 3:A3 (ts: 1003) } + // w2 = { 0:a0 (ts: 999), 1:a1 (ts: 999), 2:a2 (ts: 999), 3:a3 (ts: 999), + // 0:b0 (ts: 1000), 1:b1 (ts: 1000), 2:b2 (ts: 1000), 3:b3 (ts: 1000), + // 0:c0 (ts: 1001), 1:c1 (ts: 1001), 2:c2 (ts: 1001), 3:c3 (ts: 1001), + // 0:d0 (ts: 1002), 1:d1 (ts: 1002), 2:d2 (ts: 1002), 3:d3 (ts: 1002), + // 0:e0 (ts: 1003), 1:e1 (ts: 1003), 2:e2 (ts: 1003), 3:e3 (ts: 1003), + // 0:f0 (ts: 1100), 1:f1 (ts: 1100), 2:f2 (ts: 1100), 3:f3 (ts: 1100), + // 0:g0 (ts: 1101), 1:g1 (ts: 1101), 2:g2 (ts: 1101), 3:g3 (ts: 1101) } + time += 1L; + for (final int expectedKey : expectedKeys) { + inputTopic2.pipeInput(expectedKey, "g" + expectedKey, time); + } + processor.checkAndClearProcessResult( + new KeyValueTimestamp<>(1, "A1+g1", 1101L), + new KeyValueTimestamp<>(2, "A2+g2", 1101L), + new KeyValueTimestamp<>(3, "A3+g3", 1101L) + ); + + // push four items with increased timestamps to the secondary stream; this should produce two items + // w1 = { 0:A0 (ts: 1000), 1:A1 (ts: 1001), 2:A2 (ts: 1002), 3:A3 (ts: 1003) } + // w2 = { 0:a0 (ts: 999), 1:a1 (ts: 999), 2:a2 (ts: 999), 3:a3 (ts: 999), + // 0:b0 (ts: 1000), 1:b1 (ts: 1000), 2:b2 (ts: 1000), 3:b3 (ts: 1000), + // 0:c0 (ts: 1001), 1:c1 (ts: 1001), 2:c2 (ts: 1001), 3:c3 (ts: 1001), + // 0:d0 (ts: 1002), 1:d1 (ts: 1002), 2:d2 (ts: 1002), 3:d3 (ts: 1002), + // 0:e0 (ts: 1003), 1:e1 (ts: 1003), 2:e2 (ts: 1003), 3:e3 (ts: 1003), + // 0:f0 (ts: 1100), 1:f1 (ts: 1100), 2:f2 (ts: 1100), 3:f3 (ts: 1100), + // 0:g0 (ts: 1101), 1:g1 (ts: 1101), 2:g2 (ts: 1101), 3:g3 (ts: 1101) } + // --> w1 = { 0:A0 (ts: 1000), 1:A1 (ts: 1001), 2:A2 (ts: 1002), 3:A3 (ts: 1003) } + // w2 = { 0:a0 (ts: 999), 1:a1 (ts: 999), 2:a2 (ts: 999), 3:a3 (ts: 999), + // 0:b0 (ts: 1000), 1:b1 (ts: 1000), 2:b2 (ts: 1000), 3:b3 (ts: 1000), + // 0:c0 (ts: 1001), 1:c1 (ts: 1001), 2:c2 (ts: 1001), 3:c3 (ts: 1001), + // 0:d0 (ts: 1002), 1:d1 (ts: 1002), 2:d2 (ts: 1002), 3:d3 (ts: 1002), + // 0:e0 (ts: 1003), 1:e1 (ts: 1003), 2:e2 (ts: 1003), 3:e3 (ts: 1003), + // 0:f0 (ts: 1100), 1:f1 (ts: 1100), 2:f2 (ts: 1100), 3:f3 (ts: 1100), + // 0:g0 (ts: 1101), 1:g1 (ts: 1101), 2:g2 (ts: 1101), 3:g3 (ts: 1101), + // 0:h0 (ts: 1102), 1:h1 (ts: 1102), 2:h2 (ts: 1102), 3:h3 (ts: 1102) } + time += 1L; + for (final int expectedKey : expectedKeys) { + inputTopic2.pipeInput(expectedKey, "h" + expectedKey, time); + } + processor.checkAndClearProcessResult( + new KeyValueTimestamp<>(2, "A2+h2", 1102L), + new 
KeyValueTimestamp<>(3, "A3+h3", 1102L) + ); + + // push four items with increased timestamps to the secondary stream; this should produce one item + // w1 = { 0:A0 (ts: 1000), 1:A1 (ts: 1001), 2:A2 (ts: 1002), 3:A3 (ts: 1003) } + // w2 = { 0:a0 (ts: 999), 1:a1 (ts: 999), 2:a2 (ts: 999), 3:a3 (ts: 999), + // 0:b0 (ts: 1000), 1:b1 (ts: 1000), 2:b2 (ts: 1000), 3:b3 (ts: 1000), + // 0:c0 (ts: 1001), 1:c1 (ts: 1001), 2:c2 (ts: 1001), 3:c3 (ts: 1001), + // 0:d0 (ts: 1002), 1:d1 (ts: 1002), 2:d2 (ts: 1002), 3:d3 (ts: 1002), + // 0:e0 (ts: 1003), 1:e1 (ts: 1003), 2:e2 (ts: 1003), 3:e3 (ts: 1003), + // 0:f0 (ts: 1100), 1:f1 (ts: 1100), 2:f2 (ts: 1100), 3:f3 (ts: 1100), + // 0:g0 (ts: 1101), 1:g1 (ts: 1101), 2:g2 (ts: 1101), 3:g3 (ts: 1101), + // 0:h0 (ts: 1102), 1:h1 (ts: 1102), 2:h2 (ts: 1102), 3:h3 (ts: 1102) } + // --> w1 = { 0:A0 (ts: 1000), 1:A1 (ts: 1001), 2:A2 (ts: 1002), 3:A3 (ts: 1003) } + // w2 = { 0:a0 (ts: 999), 1:a1 (ts: 999), 2:a2 (ts: 999), 3:a3 (ts: 999), + // 0:b0 (ts: 1000), 1:b1 (ts: 1000), 2:b2 (ts: 1000), 3:b3 (ts: 1000), + // 0:c0 (ts: 1001), 1:c1 (ts: 1001), 2:c2 (ts: 1001), 3:c3 (ts: 1001), + // 0:d0 (ts: 1002), 1:d1 (ts: 1002), 2:d2 (ts: 1002), 3:d3 (ts: 1002), + // 0:e0 (ts: 1003), 1:e1 (ts: 1003), 2:e2 (ts: 1003), 3:e3 (ts: 1003), + // 0:f0 (ts: 1100), 1:f1 (ts: 1100), 2:f2 (ts: 1100), 3:f3 (ts: 1100), + // 0:g0 (ts: 1101), 1:g1 (ts: 1101), 2:g2 (ts: 1101), 3:g3 (ts: 1101), + // 0:h0 (ts: 1102), 1:h1 (ts: 1102), 2:h2 (ts: 1102), 3:h3 (ts: 1102), + // 0:i0 (ts: 1103), 1:i1 (ts: 1103), 2:i2 (ts: 1103), 3:i3 (ts: 1103) } + time += 1L; + for (final int expectedKey : expectedKeys) { + inputTopic2.pipeInput(expectedKey, "i" + expectedKey, time); + } + processor.checkAndClearProcessResult( + new KeyValueTimestamp<>(3, "A3+i3", 1103L) + ); + + // push four items with increased timestamps (no out of window) to the secondary stream; this should produce no items + // w1 = { 0:A0 (ts: 1000), 1:A1 (ts: 1001), 2:A2 (ts: 1002), 3:A3 (ts: 1003) } + // w2 = { 0:a0 (ts: 999), 1:a1 (ts: 999), 2:a2 (ts: 999), 3:a3 (ts: 999), + // 0:b0 (ts: 1000), 1:b1 (ts: 1000), 2:b2 (ts: 1000), 3:b3 (ts: 1000), + // 0:c0 (ts: 1001), 1:c1 (ts: 1001), 2:c2 (ts: 1001), 3:c3 (ts: 1001), + // 0:d0 (ts: 1002), 1:d1 (ts: 1002), 2:d2 (ts: 1002), 3:d3 (ts: 1002), + // 0:e0 (ts: 1003), 1:e1 (ts: 1003), 2:e2 (ts: 1003), 3:e3 (ts: 1003), + // 0:f0 (ts: 1100), 1:f1 (ts: 1100), 2:f2 (ts: 1100), 3:f3 (ts: 1100), + // 0:g0 (ts: 1101), 1:g1 (ts: 1101), 2:g2 (ts: 1101), 3:g3 (ts: 1101), + // 0:h0 (ts: 1102), 1:h1 (ts: 1102), 2:h2 (ts: 1102), 3:h3 (ts: 1102), + // 0:i0 (ts: 1103), 1:i1 (ts: 1103), 2:i2 (ts: 1103), 3:i3 (ts: 1103) } + // --> w1 = { 0:A0 (ts: 1000), 1:A1 (ts: 1001), 2:A2 (ts: 1002), 3:A3 (ts: 1003) } + // w2 = { 0:a0 (ts: 999), 1:a1 (ts: 999), 2:a2 (ts: 999), 3:a3 (ts: 999), + // 0:b0 (ts: 1000), 1:b1 (ts: 1000), 2:b2 (ts: 1000), 3:b3 (ts: 1000), + // 0:c0 (ts: 1001), 1:c1 (ts: 1001), 2:c2 (ts: 1001), 3:c3 (ts: 1001), + // 0:d0 (ts: 1002), 1:d1 (ts: 1002), 2:d2 (ts: 1002), 3:d3 (ts: 1002), + // 0:e0 (ts: 1003), 1:e1 (ts: 1003), 2:e2 (ts: 1003), 3:e3 (ts: 1003), + // 0:f0 (ts: 1100), 1:f1 (ts: 1100), 2:f2 (ts: 1100), 3:f3 (ts: 1100), + // 0:g0 (ts: 1101), 1:g1 (ts: 1101), 2:g2 (ts: 1101), 3:g3 (ts: 1101), + // 0:h0 (ts: 1102), 1:h1 (ts: 1102), 2:h2 (ts: 1102), 3:h3 (ts: 1102), + // 0:i0 (ts: 1103), 1:i1 (ts: 1103), 2:i2 (ts: 1103), 3:i3 (ts: 1103), + // 0:j0 (ts: 1104), 1:j1 (ts: 1104), 2:j2 (ts: 1104), 3:j3 (ts: 1104) } + time += 1L; + for (final int expectedKey : expectedKeys) { + 
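+                // ts=1104 lies beyond even A3's upper bound (1003 + 100 = 1103), so none of the
+                // "j" records can match a buffered left-side record and no join output is expected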
inputTopic2.pipeInput(expectedKey, "j" + expectedKey, time); + } + processor.checkAndClearProcessResult(); + } + } + + @Test + public void testAsymmetricWindowingBefore() { + final StreamsBuilder builder = new StreamsBuilder(); + + final int[] expectedKeys = new int[] {0, 1, 2, 3}; + + final KStream stream1; + final KStream stream2; + final KStream joined; + final MockProcessorSupplier supplier = new MockProcessorSupplier<>(); + + stream1 = builder.stream(topic1, consumed); + stream2 = builder.stream(topic2, consumed); + + joined = stream1.join( + stream2, + MockValueJoiner.TOSTRING_JOINER, + JoinWindows.of(ofMillis(0)).before(ofMillis(100)), + StreamJoined.with(Serdes.Integer(), Serdes.String(), Serdes.String()) + ); + joined.process(supplier); + + final Collection> copartitionGroups = + TopologyWrapper.getInternalTopologyBuilder(builder.build()).copartitionGroups(); + + assertEquals(1, copartitionGroups.size()); + assertEquals(new HashSet<>(Arrays.asList(topic1, topic2)), copartitionGroups.iterator().next()); + + try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) { + final TestInputTopic inputTopic1 = + driver.createInputTopic(topic1, new IntegerSerializer(), new StringSerializer(), Instant.ofEpochMilli(0L), Duration.ZERO); + final TestInputTopic inputTopic2 = + driver.createInputTopic(topic2, new IntegerSerializer(), new StringSerializer(), Instant.ofEpochMilli(0L), Duration.ZERO); + final MockProcessor processor = supplier.theCapturedProcessor(); + long time = 1000L; + + // push four items with increasing timestamps to the primary stream; the other window is empty; this should produce no items + // w1 = {} + // w2 = {} + // --> w1 = { 0:A0 (ts: 1000), 1:A1 (ts: 1001), 2:A2 (ts: 1002), 3:A3 (ts: 1003) } + // w2 = {} + for (int i = 0; i < expectedKeys.length; i++) { + inputTopic1.pipeInput(expectedKeys[i], "A" + expectedKeys[i], time + i); + } + processor.checkAndClearProcessResult(); + + // push four items with smaller timestamps (before the window) to the other stream; this should produce no items + // w1 = { 0:A0 (ts: 1000), 1:A1 (ts: 1001), 2:A2 (ts: 1002), 3:A3 (ts: 1003) } + // w2 = {} + // --> w1 = { 0:A0 (ts: 1000), 1:A1 (ts: 1001), 2:A2 (ts: 1002), 3:A3 (ts: 1003) } + // w2 = { 0:a0 (ts: 899), 1:a1 (ts: 899), 2:a2 (ts: 899), 3:a3 (ts: 899) } + time = 1000L - 100L - 1L; + for (final int expectedKey : expectedKeys) { + inputTopic2.pipeInput(expectedKey, "a" + expectedKey, time); + } + processor.checkAndClearProcessResult(); + + // push four items with increased timestamp to the other stream; this should produce one item + // w1 = { 0:A0 (ts: 1000), 1:A1 (ts: 1001), 2:A2 (ts: 1002), 3:A3 (ts: 1003) } + // w2 = { 0:a0 (ts: 899), 1:a1 (ts: 899), 2:a2 (ts: 899), 3:a3 (ts: 899) } + // --> w1 = { 0:A0 (ts: 1000), 1:A1 (ts: 1001), 2:A2 (ts: 1002), 3:A3 (ts: 1003) } + // w2 = { 0:a0 (ts: 899), 1:a1 (ts: 899), 2:a2 (ts: 899), 3:a3 (ts: 899), + // 0:b0 (ts: 900), 1:b1 (ts: 900), 2:b2 (ts: 900), 3:b3 (ts: 900) } + time += 1L; + for (final int expectedKey : expectedKeys) { + inputTopic2.pipeInput(expectedKey, "b" + expectedKey, time); + } + processor.checkAndClearProcessResult( + new KeyValueTimestamp<>(0, "A0+b0", 1000L) + ); + + // push four items with increased timestamp to the other stream; this should produce two items + // w1 = { 0:A0 (ts: 1000), 1:A1 (ts: 1001), 2:A2 (ts: 1002), 3:A3 (ts: 1003) } + // w2 = { 0:a0 (ts: 899), 1:a1 (ts: 899), 2:a2 (ts: 899), 3:a3 (ts: 899), + // 0:b0 (ts: 900), 1:b1 (ts: 900), 2:b2 (ts: 900), 3:b3 (ts: 900) } + // --> 
w1 = { 0:A0 (ts: 1000), 1:A1 (ts: 1001), 2:A2 (ts: 1002), 3:A3 (ts: 1003) } + // w2 = { 0:a0 (ts: 899), 1:a1 (ts: 899), 2:a2 (ts: 899), 3:a3 (ts: 899), + // 0:b0 (ts: 900), 1:b1 (ts: 900), 2:b2 (ts: 900), 3:b3 (ts: 900), + // 0:c0 (ts: 901), 1:c1 (ts: 901), 2:c2 (ts: 901), 3:c3 (ts: 901) } + time += 1L; + for (final int expectedKey : expectedKeys) { + inputTopic2.pipeInput(expectedKey, "c" + expectedKey, time); + } + processor.checkAndClearProcessResult( + new KeyValueTimestamp<>(0, "A0+c0", 1000L), + new KeyValueTimestamp<>(1, "A1+c1", 1001L) + ); + + // push four items with increased timestamp to the other stream; this should produce three items + // w1 = { 0:A0 (ts: 1000), 1:A1 (ts: 1001), 2:A2 (ts: 1002), 3:A3 (ts: 1003) } + // w2 = { 0:a0 (ts: 899), 1:a1 (ts: 899), 2:a2 (ts: 899), 3:a3 (ts: 899), + // 0:b0 (ts: 900), 1:b1 (ts: 900), 2:b2 (ts: 900), 3:b3 (ts: 900), + // 0:c0 (ts: 901), 1:c1 (ts: 901), 2:c2 (ts: 901), 3:c3 (ts: 901) } + // --> w1 = { 0:A0 (ts: 1000), 1:A1 (ts: 1001), 2:A2 (ts: 1002), 3:A3 (ts: 1003) } + // w2 = { 0:a0 (ts: 899), 1:a1 (ts: 899), 2:a2 (ts: 899), 3:a3 (ts: 899), + // 0:b0 (ts: 900), 1:b1 (ts: 900), 2:b2 (ts: 900), 3:b3 (ts: 900), + // 0:c0 (ts: 901), 1:c1 (ts: 901), 2:c2 (ts: 901), 3:c3 (ts: 901), + // 0:d0 (ts: 902), 1:d1 (ts: 902), 2:d2 (ts: 902), 3:d3 (ts: 902) } + time += 1L; + for (final int expectedKey : expectedKeys) { + inputTopic2.pipeInput(expectedKey, "d" + expectedKey, time); + } + processor.checkAndClearProcessResult( + new KeyValueTimestamp<>(0, "A0+d0", 1000L), + new KeyValueTimestamp<>(1, "A1+d1", 1001L), + new KeyValueTimestamp<>(2, "A2+d2", 1002L) + ); + + // push four items with increased timestamp to the other stream; this should produce four items + // w1 = { 0:A0 (ts: 1000), 1:A1 (ts: 1001), 2:A2 (ts: 1002), 3:A3 (ts: 1003) } + // w2 = { 0:a0 (ts: 899), 1:a1 (ts: 899), 2:a2 (ts: 899), 3:a3 (ts: 899), + // 0:b0 (ts: 900), 1:b1 (ts: 900), 2:b2 (ts: 900), 3:b3 (ts: 900), + // 0:c0 (ts: 901), 1:c1 (ts: 901), 2:c2 (ts: 901), 3:c3 (ts: 901), + // 0:d0 (ts: 902), 1:d1 (ts: 902), 2:d2 (ts: 902), 3:d3 (ts: 902) } + // --> w1 = { 0:A0 (ts: 1000), 1:A1 (ts: 1001), 2:A2 (ts: 1002), 3:A3 (ts: 1003) } + // w2 = { 0:a0 (ts: 899), 1:a1 (ts: 899), 2:a2 (ts: 899), 3:a3 (ts: 899), + // 0:b0 (ts: 900), 1:b1 (ts: 900), 2:b2 (ts: 900), 3:b3 (ts: 900), + // 0:c0 (ts: 901), 1:c1 (ts: 901), 2:c2 (ts: 901), 3:c3 (ts: 901), + // 0:d0 (ts: 902), 1:d1 (ts: 902), 2:d2 (ts: 902), 3:d3 (ts: 902), + // 0:e0 (ts: 903), 1:e1 (ts: 903), 2:e2 (ts: 903), 3:e3 (ts: 903) } + time += 1L; + for (final int expectedKey : expectedKeys) { + inputTopic2.pipeInput(expectedKey, "e" + expectedKey, time); + } + processor.checkAndClearProcessResult( + new KeyValueTimestamp<>(0, "A0+e0", 1000L), + new KeyValueTimestamp<>(1, "A1+e1", 1001L), + new KeyValueTimestamp<>(2, "A2+e2", 1002L), + new KeyValueTimestamp<>(3, "A3+e3", 1003L) + ); + + // push four items with larger timestamp to the other stream; this should produce four items + // w1 = { 0:A0 (ts: 1000), 1:A1 (ts: 1001), 2:A2 (ts: 1002), 3:A3 (ts: 1003) } + // w2 = { 0:a0 (ts: 899), 1:a1 (ts: 899), 2:a2 (ts: 899), 3:a3 (ts: 899), + // 0:b0 (ts: 900), 1:b1 (ts: 900), 2:b2 (ts: 900), 3:b3 (ts: 900), + // 0:c0 (ts: 901), 1:c1 (ts: 901), 2:c2 (ts: 901), 3:c3 (ts: 901), + // 0:d0 (ts: 902), 1:d1 (ts: 902), 2:d2 (ts: 902), 3:d3 (ts: 902), + // 0:e0 (ts: 903), 1:e1 (ts: 903), 2:e2 (ts: 903), 3:e3 (ts: 903) } + // --> w1 = { 0:A0 (ts: 1000), 1:A1 (ts: 1001), 2:A2 (ts: 1002), 3:A3 (ts: 1003) } + // w2 = { 0:a0 (ts: 899), 1:a1 (ts: 899), 2:a2 
(ts: 899), 3:a3 (ts: 899), + // 0:b0 (ts: 900), 1:b1 (ts: 900), 2:b2 (ts: 900), 3:b3 (ts: 900), + // 0:c0 (ts: 901), 1:c1 (ts: 901), 2:c2 (ts: 901), 3:c3 (ts: 901), + // 0:d0 (ts: 902), 1:d1 (ts: 902), 2:d2 (ts: 902), 3:d3 (ts: 902), + // 0:e0 (ts: 903), 1:e1 (ts: 903), 2:e2 (ts: 903), 3:e3 (ts: 903), + // 0:f0 (ts: 1000), 1:f1 (ts: 1000), 2:f2 (ts: 1000), 3:f3 (ts: 1000) } + time = 1000L; + for (final int expectedKey : expectedKeys) { + inputTopic2.pipeInput(expectedKey, "f" + expectedKey, time); + } + processor.checkAndClearProcessResult( + new KeyValueTimestamp<>(0, "A0+f0", 1000L), + new KeyValueTimestamp<>(1, "A1+f1", 1001L), + new KeyValueTimestamp<>(2, "A2+f2", 1002L), + new KeyValueTimestamp<>(3, "A3+f3", 1003L) + ); + + // push four items with increase timestamp to the other stream; this should produce three items + // w1 = { 0:A0 (ts: 1000), 1:A1 (ts: 1001), 2:A2 (ts: 1002), 3:A3 (ts: 1003) } + // w2 = { 0:a0 (ts: 899), 1:a1 (ts: 899), 2:a2 (ts: 899), 3:a3 (ts: 899), + // 0:b0 (ts: 900), 1:b1 (ts: 900), 2:b2 (ts: 900), 3:b3 (ts: 900), + // 0:c0 (ts: 901), 1:c1 (ts: 901), 2:c2 (ts: 901), 3:c3 (ts: 901), + // 0:d0 (ts: 902), 1:d1 (ts: 902), 2:d2 (ts: 902), 3:d3 (ts: 902), + // 0:e0 (ts: 903), 1:e1 (ts: 903), 2:e2 (ts: 903), 3:e3 (ts: 903), + // 0:f0 (ts: 1000), 1:f1 (ts: 1000), 2:f2 (ts: 1000), 3:f3 (ts: 1000) } + // --> w1 = { 0:A0 (ts: 1000), 1:A1 (ts: 1001), 2:A2 (ts: 1002), 3:A3 (ts: 1003) } + // w2 = { 0:a0 (ts: 899), 1:a1 (ts: 899), 2:a2 (ts: 899), 3:a3 (ts: 899), + // 0:b0 (ts: 900), 1:b1 (ts: 900), 2:b2 (ts: 900), 3:b3 (ts: 900), + // 0:c0 (ts: 901), 1:c1 (ts: 901), 2:c2 (ts: 901), 3:c3 (ts: 901), + // 0:d0 (ts: 902), 1:d1 (ts: 902), 2:d2 (ts: 902), 3:d3 (ts: 902), + // 0:e0 (ts: 903), 1:e1 (ts: 903), 2:e2 (ts: 903), 3:e3 (ts: 903), + // 0:f0 (ts: 1000), 1:f1 (ts: 1000), 2:f2 (ts: 1000), 3:f3 (ts: 1000), + // 0:g0 (ts: 1001), 1:g1 (ts: 1001), 2:g2 (ts: 1001), 3:g3 (ts: 1001) } + time += 1L; + for (final int expectedKey : expectedKeys) { + inputTopic2.pipeInput(expectedKey, "g" + expectedKey, time); + } + processor.checkAndClearProcessResult( + new KeyValueTimestamp<>(1, "A1+g1", 1001L), + new KeyValueTimestamp<>(2, "A2+g2", 1002L), + new KeyValueTimestamp<>(3, "A3+g3", 1003L) + ); + + // push four items with increase timestamp to the other stream; this should produce two items + // w1 = { 0:A0 (ts: 1000), 1:A1 (ts: 1001), 2:A2 (ts: 1002), 3:A3 (ts: 1003) } + // w2 = { 0:a0 (ts: 899), 1:a1 (ts: 899), 2:a2 (ts: 899), 3:a3 (ts: 899), + // 0:b0 (ts: 900), 1:b1 (ts: 900), 2:b2 (ts: 900), 3:b3 (ts: 900), + // 0:c0 (ts: 901), 1:c1 (ts: 901), 2:c2 (ts: 901), 3:c3 (ts: 901), + // 0:d0 (ts: 902), 1:d1 (ts: 902), 2:d2 (ts: 902), 3:d3 (ts: 902), + // 0:e0 (ts: 903), 1:e1 (ts: 903), 2:e2 (ts: 903), 3:e3 (ts: 903), + // 0:f0 (ts: 1000), 1:f1 (ts: 1000), 2:f2 (ts: 1000), 3:f3 (ts: 1000), + // 0:g0 (ts: 1001), 1:g1 (ts: 1001), 2:g2 (ts: 1001), 3:g3 (ts: 1001) } + // --> w1 = { 0:A0 (ts: 1000), 1:A1 (ts: 1001), 2:A2 (ts: 1002), 3:A3 (ts: 1003) } + // w2 = { 0:a0 (ts: 899), 1:a1 (ts: 899), 2:a2 (ts: 899), 3:a3 (ts: 899), + // 0:b0 (ts: 900), 1:b1 (ts: 900), 2:b2 (ts: 900), 3:b3 (ts: 900), + // 0:c0 (ts: 901), 1:c1 (ts: 901), 2:c2 (ts: 901), 3:c3 (ts: 901), + // 0:d0 (ts: 902), 1:d1 (ts: 902), 2:d2 (ts: 902), 3:d3 (ts: 902), + // 0:e0 (ts: 903), 1:e1 (ts: 903), 2:e2 (ts: 903), 3:e3 (ts: 903), + // 0:f0 (ts: 1000), 1:f1 (ts: 1000), 2:f2 (ts: 1000), 3:f3 (ts: 1000), + // 0:g0 (ts: 1001), 1:g1 (ts: 1001), 2:g2 (ts: 1001), 3:g3 (ts: 1001), + // 0:h0 (ts: 1002), 1:h1 (ts: 1002), 2:h2 (ts: 1002), 
3:h3 (ts: 1002) } + time += 1L; + for (final int expectedKey : expectedKeys) { + inputTopic2.pipeInput(expectedKey, "h" + expectedKey, time); + } + processor.checkAndClearProcessResult( + new KeyValueTimestamp<>(2, "A2+h2", 1002L), + new KeyValueTimestamp<>(3, "A3+h3", 1003L) + ); + + // push four items with increase timestamp to the other stream; this should produce one item + // w1 = { 0:A0 (ts: 1000), 1:A1 (ts: 1001), 2:A2 (ts: 1002), 3:A3 (ts: 1003) } + // w2 = { 0:a0 (ts: 899), 1:a1 (ts: 899), 2:a2 (ts: 899), 3:a3 (ts: 899), + // 0:b0 (ts: 900), 1:b1 (ts: 900), 2:b2 (ts: 900), 3:b3 (ts: 900), + // 0:c0 (ts: 901), 1:c1 (ts: 901), 2:c2 (ts: 901), 3:c3 (ts: 901), + // 0:d0 (ts: 902), 1:d1 (ts: 902), 2:d2 (ts: 902), 3:d3 (ts: 902), + // 0:e0 (ts: 903), 1:e1 (ts: 903), 2:e2 (ts: 903), 3:e3 (ts: 903), + // 0:f0 (ts: 1000), 1:f1 (ts: 1000), 2:f2 (ts: 1000), 3:f3 (ts: 1000), + // 0:g0 (ts: 1001), 1:g1 (ts: 1001), 2:g2 (ts: 1001), 3:g3 (ts: 1001), + // 0:h0 (ts: 1002), 1:h1 (ts: 1002), 2:h2 (ts: 1002), 3:h3 (ts: 1002) } + // --> w1 = { 0:A0 (ts: 1000), 1:A1 (ts: 1001), 2:A2 (ts: 1002), 3:A3 (ts: 1003) } + // w2 = { 0:a0 (ts: 899), 1:a1 (ts: 899), 2:a2 (ts: 899), 3:a3 (ts: 899), + // 0:b0 (ts: 900), 1:b1 (ts: 900), 2:b2 (ts: 900), 3:b3 (ts: 900), + // 0:c0 (ts: 901), 1:c1 (ts: 901), 2:c2 (ts: 901), 3:c3 (ts: 901), + // 0:d0 (ts: 902), 1:d1 (ts: 902), 2:d2 (ts: 902), 3:d3 (ts: 902), + // 0:e0 (ts: 903), 1:e1 (ts: 903), 2:e2 (ts: 903), 3:e3 (ts: 903), + // 0:f0 (ts: 1000), 1:f1 (ts: 1000), 2:f2 (ts: 1000), 3:f3 (ts: 1000), + // 0:g0 (ts: 1001), 1:g1 (ts: 1001), 2:g2 (ts: 1001), 3:g3 (ts: 1001), + // 0:h0 (ts: 1002), 1:h1 (ts: 1002), 2:h2 (ts: 1002), 3:h3 (ts: 1002), + // 0:i0 (ts: 1003), 1:i1 (ts: 1003), 2:i2 (ts: 1003), 3:i3 (ts: 1003) } + time += 1L; + for (final int expectedKey : expectedKeys) { + inputTopic2.pipeInput(expectedKey, "i" + expectedKey, time); + } + processor.checkAndClearProcessResult( + new KeyValueTimestamp<>(3, "A3+i3", 1003L) + ); + + // push four items with increase timestamp (no out of window) to the other stream; this should produce no items + // w1 = { 0:A0 (ts: 1000), 1:A1 (ts: 1001), 2:A2 (ts: 1002), 3:A3 (ts: 1003) } + // w2 = { 0:a0 (ts: 899), 1:a1 (ts: 899), 2:a2 (ts: 899), 3:a3 (ts: 899), + // 0:b0 (ts: 900), 1:b1 (ts: 900), 2:b2 (ts: 900), 3:b3 (ts: 900), + // 0:c0 (ts: 901), 1:c1 (ts: 901), 2:c2 (ts: 901), 3:c3 (ts: 901), + // 0:d0 (ts: 902), 1:d1 (ts: 902), 2:d2 (ts: 902), 3:d3 (ts: 902), + // 0:e0 (ts: 903), 1:e1 (ts: 903), 2:e2 (ts: 903), 3:e3 (ts: 903), + // 0:f0 (ts: 1000), 1:f1 (ts: 1000), 2:f2 (ts: 1000), 3:f3 (ts: 1000), + // 0:g0 (ts: 1001), 1:g1 (ts: 1001), 2:g2 (ts: 1001), 3:g3 (ts: 1001), + // 0:h0 (ts: 1002), 1:h1 (ts: 1002), 2:h2 (ts: 1002), 3:h3 (ts: 1002), + // 0:i0 (ts: 1003), 1:i1 (ts: 1003), 2:i2 (ts: 1003), 3:i3 (ts: 1003) } + // --> w1 = { 0:A0 (ts: 1000), 1:A1 (ts: 1001), 2:A2 (ts: 1002), 3:A3 (ts: 1003) } + // w2 = { 0:a0 (ts: 899), 1:a1 (ts: 899), 2:a2 (ts: 899), 3:a3 (ts: 899), + // 0:b0 (ts: 900), 1:b1 (ts: 900), 2:b2 (ts: 900), 3:b3 (ts: 900), + // 0:c0 (ts: 901), 1:c1 (ts: 901), 2:c2 (ts: 901), 3:c3 (ts: 901), + // 0:d0 (ts: 902), 1:d1 (ts: 902), 2:d2 (ts: 902), 3:d3 (ts: 902), + // 0:e0 (ts: 903), 1:e1 (ts: 903), 2:e2 (ts: 903), 3:e3 (ts: 903), + // 0:f0 (ts: 1000), 1:f1 (ts: 1000), 2:f2 (ts: 1000), 3:f3 (ts: 1000), + // 0:g0 (ts: 1001), 1:g1 (ts: 1001), 2:g2 (ts: 1001), 3:g3 (ts: 1001), + // 0:h0 (ts: 1002), 1:h1 (ts: 1002), 2:h2 (ts: 1002), 3:h3 (ts: 1002), + // 0:i0 (ts: 1003), 1:i1 (ts: 1003), 2:i2 (ts: 1003), 3:i3 (ts: 1003), + 
// 0:j0 (ts: 1004), 1:j1 (ts: 1004), 2:j2 (ts: 1004), 3:j3 (ts: 1004) } + time += 1L; + for (final int expectedKey : expectedKeys) { + inputTopic2.pipeInput(expectedKey, "j" + expectedKey, time); + } + processor.checkAndClearProcessResult(); + } + } + + private void buildStreamsJoinThatShouldThrow(final StreamJoined streamJoined, + final JoinWindows joinWindows, + final String expectedExceptionMessagePrefix) { + + final StreamsBuilder builder = new StreamsBuilder(); + final KStream left = builder.stream("left", Consumed.with(Serdes.String(), Serdes.Integer())); + final KStream right = builder.stream("right", Consumed.with(Serdes.String(), Serdes.Integer())); + + final StreamsException streamsException = assertThrows( + StreamsException.class, + () -> left.join( + right, + Integer::sum, + joinWindows, + streamJoined + ) + ); + + assertTrue(streamsException.getMessage().startsWith(expectedExceptionMessagePrefix)); + } + + private WindowBytesStoreSupplier buildWindowBytesStoreSupplier(final String name, + final long retentionPeriod, + final long windowSize, + final boolean retainDuplicates) { + return Stores.inMemoryWindowStore(name, + Duration.ofMillis(retentionPeriod), + Duration.ofMillis(windowSize), + retainDuplicates); + } + + + private final String expectedTopologyWithUserNamedRepartitionTopics = "Topologies:\n" + + " Sub-topology: 0\n" + + " Source: KSTREAM-SOURCE-0000000000 (topics: [topic])\n" + + " --> KSTREAM-MAP-0000000003\n" + + " Processor: KSTREAM-MAP-0000000003 (stores: [])\n" + + " --> second-join-left-repartition-filter, first-join-left-repartition-filter\n" + + " <-- KSTREAM-SOURCE-0000000000\n" + + " Processor: first-join-left-repartition-filter (stores: [])\n" + + " --> first-join-left-repartition-sink\n" + + " <-- KSTREAM-MAP-0000000003\n" + + " Processor: second-join-left-repartition-filter (stores: [])\n" + + " --> second-join-left-repartition-sink\n" + + " <-- KSTREAM-MAP-0000000003\n" + + " Sink: first-join-left-repartition-sink (topic: first-join-left-repartition)\n" + + " <-- first-join-left-repartition-filter\n" + + " Sink: second-join-left-repartition-sink (topic: second-join-left-repartition)\n" + + " <-- second-join-left-repartition-filter\n" + + "\n" + + " Sub-topology: 1\n" + + " Source: KSTREAM-SOURCE-0000000001 (topics: [topic2])\n" + + " --> first-join-other-windowed\n" + + " Source: first-join-left-repartition-source (topics: [first-join-left-repartition])\n" + + " --> first-join-this-windowed\n" + + " Processor: first-join-other-windowed (stores: [KSTREAM-JOINOTHER-0000000010-store])\n" + + " --> first-join-other-join\n" + + " <-- KSTREAM-SOURCE-0000000001\n" + + " Processor: first-join-this-windowed (stores: [KSTREAM-JOINTHIS-0000000009-store])\n" + + " --> first-join-this-join\n" + + " <-- first-join-left-repartition-source\n" + + " Processor: first-join-other-join (stores: [KSTREAM-JOINTHIS-0000000009-store])\n" + + " --> first-join-merge\n" + + " <-- first-join-other-windowed\n" + + " Processor: first-join-this-join (stores: [KSTREAM-JOINOTHER-0000000010-store])\n" + + " --> first-join-merge\n" + + " <-- first-join-this-windowed\n" + + " Processor: first-join-merge (stores: [])\n" + + " --> KSTREAM-SINK-0000000012\n" + + " <-- first-join-this-join, first-join-other-join\n" + + " Sink: KSTREAM-SINK-0000000012 (topic: out-one)\n" + + " <-- first-join-merge\n" + + "\n" + + " Sub-topology: 2\n" + + " Source: KSTREAM-SOURCE-0000000002 (topics: [topic3])\n" + + " --> second-join-other-windowed\n" + + " Source: second-join-left-repartition-source (topics: 
[second-join-left-repartition])\n" + + " --> second-join-this-windowed\n" + + " Processor: second-join-other-windowed (stores: [KSTREAM-JOINOTHER-0000000019-store])\n" + + " --> second-join-other-join\n" + + " <-- KSTREAM-SOURCE-0000000002\n" + + " Processor: second-join-this-windowed (stores: [KSTREAM-JOINTHIS-0000000018-store])\n" + + " --> second-join-this-join\n" + + " <-- second-join-left-repartition-source\n" + + " Processor: second-join-other-join (stores: [KSTREAM-JOINTHIS-0000000018-store])\n" + + " --> second-join-merge\n" + + " <-- second-join-other-windowed\n" + + " Processor: second-join-this-join (stores: [KSTREAM-JOINOTHER-0000000019-store])\n" + + " --> second-join-merge\n" + + " <-- second-join-this-windowed\n" + + " Processor: second-join-merge (stores: [])\n" + + " --> KSTREAM-SINK-0000000021\n" + + " <-- second-join-this-join, second-join-other-join\n" + + " Sink: KSTREAM-SINK-0000000021 (topic: out-two)\n" + + " <-- second-join-merge\n\n"; + + private final String expectedTopologyWithGeneratedRepartitionTopic = "Topologies:\n" + + " Sub-topology: 0\n" + + " Source: KSTREAM-SOURCE-0000000000 (topics: [topic])\n" + + " --> KSTREAM-MAP-0000000003\n" + + " Processor: KSTREAM-MAP-0000000003 (stores: [])\n" + + " --> KSTREAM-FILTER-0000000005\n" + + " <-- KSTREAM-SOURCE-0000000000\n" + + " Processor: KSTREAM-FILTER-0000000005 (stores: [])\n" + + " --> KSTREAM-SINK-0000000004\n" + + " <-- KSTREAM-MAP-0000000003\n" + + " Sink: KSTREAM-SINK-0000000004 (topic: KSTREAM-MAP-0000000003-repartition)\n" + + " <-- KSTREAM-FILTER-0000000005\n" + + "\n" + + " Sub-topology: 1\n" + + " Source: KSTREAM-SOURCE-0000000006 (topics: [KSTREAM-MAP-0000000003-repartition])\n" + + " --> KSTREAM-WINDOWED-0000000007, KSTREAM-WINDOWED-0000000016\n" + + " Source: KSTREAM-SOURCE-0000000001 (topics: [topic2])\n" + + " --> KSTREAM-WINDOWED-0000000008\n" + + " Source: KSTREAM-SOURCE-0000000002 (topics: [topic3])\n" + + " --> KSTREAM-WINDOWED-0000000017\n" + + " Processor: KSTREAM-WINDOWED-0000000007 (stores: [KSTREAM-JOINTHIS-0000000009-store])\n" + + " --> KSTREAM-JOINTHIS-0000000009\n" + + " <-- KSTREAM-SOURCE-0000000006\n" + + " Processor: KSTREAM-WINDOWED-0000000008 (stores: [KSTREAM-JOINOTHER-0000000010-store])\n" + + " --> KSTREAM-JOINOTHER-0000000010\n" + + " <-- KSTREAM-SOURCE-0000000001\n" + + " Processor: KSTREAM-WINDOWED-0000000016 (stores: [KSTREAM-JOINTHIS-0000000018-store])\n" + + " --> KSTREAM-JOINTHIS-0000000018\n" + + " <-- KSTREAM-SOURCE-0000000006\n" + + " Processor: KSTREAM-WINDOWED-0000000017 (stores: [KSTREAM-JOINOTHER-0000000019-store])\n" + + " --> KSTREAM-JOINOTHER-0000000019\n" + + " <-- KSTREAM-SOURCE-0000000002\n" + + " Processor: KSTREAM-JOINOTHER-0000000010 (stores: [KSTREAM-JOINTHIS-0000000009-store])\n" + + " --> KSTREAM-MERGE-0000000011\n" + + " <-- KSTREAM-WINDOWED-0000000008\n" + + " Processor: KSTREAM-JOINOTHER-0000000019 (stores: [KSTREAM-JOINTHIS-0000000018-store])\n" + + " --> KSTREAM-MERGE-0000000020\n" + + " <-- KSTREAM-WINDOWED-0000000017\n" + + " Processor: KSTREAM-JOINTHIS-0000000009 (stores: [KSTREAM-JOINOTHER-0000000010-store])\n" + + " --> KSTREAM-MERGE-0000000011\n" + + " <-- KSTREAM-WINDOWED-0000000007\n" + + " Processor: KSTREAM-JOINTHIS-0000000018 (stores: [KSTREAM-JOINOTHER-0000000019-store])\n" + + " --> KSTREAM-MERGE-0000000020\n" + + " <-- KSTREAM-WINDOWED-0000000016\n" + + " Processor: KSTREAM-MERGE-0000000011 (stores: [])\n" + + " --> KSTREAM-SINK-0000000012\n" + + " <-- KSTREAM-JOINTHIS-0000000009, KSTREAM-JOINOTHER-0000000010\n" + + " Processor: 
KSTREAM-MERGE-0000000020 (stores: [])\n" + + " --> KSTREAM-SINK-0000000021\n" + + " <-- KSTREAM-JOINTHIS-0000000018, KSTREAM-JOINOTHER-0000000019\n" + + " Sink: KSTREAM-SINK-0000000012 (topic: out-one)\n" + + " <-- KSTREAM-MERGE-0000000011\n" + + " Sink: KSTREAM-SINK-0000000021 (topic: out-to)\n" + + " <-- KSTREAM-MERGE-0000000020\n\n"; +} \ No newline at end of file diff --git a/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KStreamKStreamLeftJoinTest.java b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KStreamKStreamLeftJoinTest.java new file mode 100644 index 0000000..4e2b6d8 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KStreamKStreamLeftJoinTest.java @@ -0,0 +1,1072 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.kstream.internals; + +import org.apache.kafka.common.serialization.IntegerSerializer; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.serialization.StringSerializer; +import org.apache.kafka.streams.KeyValueTimestamp; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.TopologyTestDriver; +import org.apache.kafka.streams.TopologyWrapper; +import org.apache.kafka.streams.kstream.Consumed; +import org.apache.kafka.streams.kstream.JoinWindows; +import org.apache.kafka.streams.kstream.KStream; +import org.apache.kafka.streams.TestInputTopic; +import org.apache.kafka.streams.kstream.StreamJoined; +import org.apache.kafka.streams.state.Stores; +import org.apache.kafka.streams.state.WindowBytesStoreSupplier; +import org.apache.kafka.test.MockProcessor; +import org.apache.kafka.test.MockProcessorSupplier; +import org.apache.kafka.test.MockValueJoiner; +import org.apache.kafka.test.StreamsTestUtils; +import org.junit.Test; + +import java.time.Duration; +import java.time.Instant; +import java.util.Arrays; +import java.util.Collection; +import java.util.HashSet; +import java.util.Properties; +import java.util.Set; + +import static java.time.Duration.ofMillis; +import static org.junit.Assert.assertEquals; + + +@SuppressWarnings("deprecation") // Old PAPI. Needs to be migrated. 
+public class KStreamKStreamLeftJoinTest { + private final static KeyValueTimestamp[] EMPTY = new KeyValueTimestamp[0]; + + private final String topic1 = "topic1"; + private final String topic2 = "topic2"; + private final Consumed consumed = Consumed.with(Serdes.Integer(), Serdes.String()); + private final Properties props = StreamsTestUtils.getStreamsConfig(Serdes.String(), Serdes.String()); + + @Test + public void testLeftJoinWithSpuriousResultFixDisabledOldApi() { + final StreamsBuilder builder = new StreamsBuilder(); + + final int[] expectedKeys = new int[] {0, 1, 2, 3}; + + final KStream stream1; + final KStream stream2; + final KStream joined; + final MockProcessorSupplier supplier = new MockProcessorSupplier<>(); + stream1 = builder.stream(topic1, consumed); + stream2 = builder.stream(topic2, consumed); + + joined = stream1.leftJoin( + stream2, + MockValueJoiner.TOSTRING_JOINER, + JoinWindows.of(ofMillis(100L)), + StreamJoined.with(Serdes.Integer(), Serdes.String(), Serdes.String()) + ); + joined.process(supplier); + + try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(props), props)) { + final TestInputTopic inputTopic1 = + driver.createInputTopic(topic1, new IntegerSerializer(), new StringSerializer(), Instant.ofEpochMilli(0L), Duration.ZERO); + final TestInputTopic inputTopic2 = + driver.createInputTopic(topic2, new IntegerSerializer(), new StringSerializer(), Instant.ofEpochMilli(0L), Duration.ZERO); + final MockProcessor processor = supplier.theCapturedProcessor(); + + // Only 2 window stores should be available + assertEquals(2, driver.getAllStateStores().size()); + + // push two items to the primary stream; the other window is empty + // w1 {} + // w2 {} + // --> w1 = { 0:A0, 1:A1 } + // --> w2 = {} + for (int i = 0; i < 2; i++) { + inputTopic1.pipeInput(expectedKeys[i], "A" + expectedKeys[i]); + } + processor.checkAndClearProcessResult( + new KeyValueTimestamp<>(0, "A0+null", 0L), + new KeyValueTimestamp<>(1, "A1+null", 0L) + ); + + // push two items to the other stream; this should produce two items + // w1 = { 0:A0, 1:A1 } + // w2 {} + // --> w1 = { 0:A0, 1:A1 } + // --> w2 = { 0:a0, 1:a1 } + for (int i = 0; i < 2; i++) { + inputTopic2.pipeInput(expectedKeys[i], "a" + expectedKeys[i]); + } + processor.checkAndClearProcessResult( + new KeyValueTimestamp<>(0, "A0+a0", 0L), + new KeyValueTimestamp<>(1, "A1+a1", 0L) + ); + } + } + + @Test + public void testLeftJoinDuplicatesWithSpuriousResultFixDisabledOldApi() { + final StreamsBuilder builder = new StreamsBuilder(); + + final KStream stream1; + final KStream stream2; + final KStream joined; + final MockProcessorSupplier supplier = new MockProcessorSupplier<>(); + stream1 = builder.stream(topic1, consumed); + stream2 = builder.stream(topic2, consumed); + + joined = stream1.leftJoin( + stream2, + MockValueJoiner.TOSTRING_JOINER, + JoinWindows.of(ofMillis(100L)).grace(ofMillis(10L)), + StreamJoined.with(Serdes.Integer(), Serdes.String(), Serdes.String()) + ); + joined.process(supplier); + + try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(props), props)) { + final TestInputTopic inputTopic1 = + driver.createInputTopic(topic1, new IntegerSerializer(), new StringSerializer(), Instant.ofEpochMilli(0L), Duration.ZERO); + final TestInputTopic inputTopic2 = + driver.createInputTopic(topic2, new IntegerSerializer(), new StringSerializer(), Instant.ofEpochMilli(0L), Duration.ZERO); + final MockProcessor processor = supplier.theCapturedProcessor(); + + // Only 2 window stores should be 
available + assertEquals(2, driver.getAllStateStores().size()); + + inputTopic1.pipeInput(0, "A0", 0L); + inputTopic1.pipeInput(0, "A0-0", 0L); + inputTopic2.pipeInput(0, "a0", 0L); + + processor.checkAndClearProcessResult( + new KeyValueTimestamp<>(0, "A0+null", 0L), + new KeyValueTimestamp<>(0, "A0-0+null", 0L), + new KeyValueTimestamp<>(0, "A0+a0", 0L), + new KeyValueTimestamp<>(0, "A0-0+a0", 0L) + ); + } + } + + @Test + public void testLeftJoinDuplicates() { + final StreamsBuilder builder = new StreamsBuilder(); + + final KStream stream1; + final KStream stream2; + final KStream joined; + final MockProcessorSupplier supplier = new MockProcessorSupplier<>(); + stream1 = builder.stream(topic1, consumed); + stream2 = builder.stream(topic2, consumed); + + joined = stream1.leftJoin( + stream2, + MockValueJoiner.TOSTRING_JOINER, + JoinWindows.ofTimeDifferenceAndGrace(ofMillis(100L), ofMillis(10L)), + StreamJoined.with(Serdes.Integer(), Serdes.String(), Serdes.String()) + ); + joined.process(supplier); + + try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) { + final TestInputTopic inputTopic1 = + driver.createInputTopic(topic1, new IntegerSerializer(), new StringSerializer(), Instant.ofEpochMilli(0L), Duration.ZERO); + final TestInputTopic inputTopic2 = + driver.createInputTopic(topic2, new IntegerSerializer(), new StringSerializer(), Instant.ofEpochMilli(0L), Duration.ZERO); + final MockProcessor processor = supplier.theCapturedProcessor(); + + // verifies non-joined duplicates are emitted when window has closed + inputTopic1.pipeInput(0, "A0", 0L); + inputTopic1.pipeInput(0, "A0-0", 0L); + inputTopic2.pipeInput(1, "a0", 111L); + // bump stream-time to trigger left-join results + inputTopic2.pipeInput(2, "dummy", 500L); + + processor.checkAndClearProcessResult( + new KeyValueTimestamp<>(0, "A0+null", 0L), + new KeyValueTimestamp<>(0, "A0-0+null", 0L) + ); + + // verifies joined duplicates are emitted + inputTopic1.pipeInput(2, "A2", 1000L); + inputTopic1.pipeInput(2, "A2-0", 1000L); + inputTopic2.pipeInput(2, "a2", 1001L); + + processor.checkAndClearProcessResult( + new KeyValueTimestamp<>(2, "A2+a2", 1001L), + new KeyValueTimestamp<>(2, "A2-0+a2", 1001L) + ); + + // this record should expire non-joined records, but because A2 and A2-0 were already joined and + // emitted, they won't be emitted again + inputTopic2.pipeInput(3, "a3", 315L); + + processor.checkAndClearProcessResult(); + } + } + + @Test + public void testLeftExpiredNonJoinedRecordsAreEmittedByTheLeftProcessor() { + final StreamsBuilder builder = new StreamsBuilder(); + + final KStream stream1; + final KStream stream2; + final KStream joined; + final MockProcessorSupplier supplier = new MockProcessorSupplier<>(); + + stream1 = builder.stream(topic1, consumed); + stream2 = builder.stream(topic2, consumed); + joined = stream1.leftJoin( + stream2, + MockValueJoiner.TOSTRING_JOINER, + JoinWindows.ofTimeDifferenceAndGrace(ofMillis(100L), ofMillis(0L)), + StreamJoined.with(Serdes.Integer(), Serdes.String(), Serdes.String()) + ); + joined.process(supplier); + + try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) { + final TestInputTopic inputTopic1 = + driver.createInputTopic(topic1, new IntegerSerializer(), new StringSerializer(), Instant.ofEpochMilli(0L), Duration.ZERO); + final TestInputTopic inputTopic2 = + driver.createInputTopic(topic2, new IntegerSerializer(), new StringSerializer(), Instant.ofEpochMilli(0L), Duration.ZERO); + final MockProcessor 
processor = supplier.theCapturedProcessor(); + + final long windowStart = 0L; + + // No joins detected; No null-joins emitted + inputTopic1.pipeInput(0, "A0", windowStart + 1L); + inputTopic1.pipeInput(1, "A1", windowStart + 2L); + inputTopic1.pipeInput(0, "A0-0", windowStart + 3L); + processor.checkAndClearProcessResult(); + + // Join detected; No null-joins emitted + inputTopic2.pipeInput(1, "a1", windowStart + 3L); + processor.checkAndClearProcessResult( + new KeyValueTimestamp<>(1, "A1+a1", windowStart + 3L) + ); + + // Dummy record in left topic will emit expired non-joined records from the left topic + inputTopic1.pipeInput(2, "dummy", windowStart + 401L); + processor.checkAndClearProcessResult( + new KeyValueTimestamp<>(0, "A0+null", windowStart + 1L), + new KeyValueTimestamp<>(0, "A0-0+null", windowStart + 3L) + ); + + // Flush internal non-joined state store by joining the dummy record + inputTopic2.pipeInput(2, "dummy", windowStart + 401L); + processor.checkAndClearProcessResult( + new KeyValueTimestamp<>(2, "dummy+dummy", windowStart + 401L) + ); + } + } + + @Test + public void testLeftExpiredNonJoinedRecordsAreEmittedByTheRightProcessor() { + final StreamsBuilder builder = new StreamsBuilder(); + + final KStream stream1; + final KStream stream2; + final KStream joined; + final MockProcessorSupplier supplier = new MockProcessorSupplier<>(); + + stream1 = builder.stream(topic1, consumed); + stream2 = builder.stream(topic2, consumed); + joined = stream1.leftJoin( + stream2, + MockValueJoiner.TOSTRING_JOINER, + JoinWindows.ofTimeDifferenceAndGrace(ofMillis(100L), ofMillis(0L)), + StreamJoined.with(Serdes.Integer(), Serdes.String(), Serdes.String()) + ); + joined.process(supplier); + + try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) { + final TestInputTopic inputTopic1 = + driver.createInputTopic(topic1, new IntegerSerializer(), new StringSerializer(), Instant.ofEpochMilli(0L), Duration.ZERO); + final TestInputTopic inputTopic2 = + driver.createInputTopic(topic2, new IntegerSerializer(), new StringSerializer(), Instant.ofEpochMilli(0L), Duration.ZERO); + final MockProcessor processor = supplier.theCapturedProcessor(); + + final long windowStart = 0L; + + // No joins detected; No null-joins emitted + inputTopic1.pipeInput(0, "A0", windowStart + 1L); + inputTopic1.pipeInput(1, "A1", windowStart + 2L); + inputTopic1.pipeInput(0, "A0-0", windowStart + 3L); + processor.checkAndClearProcessResult(); + + // Join detected; No null-joins emitted + inputTopic2.pipeInput(1, "a1", windowStart + 3L); + processor.checkAndClearProcessResult( + new KeyValueTimestamp<>(1, "A1+a1", windowStart + 3L) + ); + + // Dummy record in right topic will emit expired non-joined records from the left topic + inputTopic2.pipeInput(2, "dummy", windowStart + 401L); + processor.checkAndClearProcessResult( + new KeyValueTimestamp<>(0, "A0+null", windowStart + 1L), + new KeyValueTimestamp<>(0, "A0-0+null", windowStart + 3L) + ); + + // Flush internal non-joined state store by joining the dummy record + inputTopic1.pipeInput(2, "dummy", windowStart + 402L); + processor.checkAndClearProcessResult( + new KeyValueTimestamp<>(2, "dummy+dummy", windowStart + 402L) + ); + } + } + + @Test + public void testRightNonJoinedRecordsAreNeverEmittedByTheLeftProcessor() { + final StreamsBuilder builder = new StreamsBuilder(); + + final KStream stream1; + final KStream stream2; + final KStream joined; + final MockProcessorSupplier supplier = new MockProcessorSupplier<>(); + + stream1 = 
builder.stream(topic1, consumed); + stream2 = builder.stream(topic2, consumed); + joined = stream1.leftJoin( + stream2, + MockValueJoiner.TOSTRING_JOINER, + JoinWindows.ofTimeDifferenceAndGrace(ofMillis(100L), ofMillis(0L)), + StreamJoined.with(Serdes.Integer(), Serdes.String(), Serdes.String()) + ); + joined.process(supplier); + + try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) { + final TestInputTopic inputTopic1 = + driver.createInputTopic(topic1, new IntegerSerializer(), new StringSerializer(), Instant.ofEpochMilli(0L), Duration.ZERO); + final TestInputTopic inputTopic2 = + driver.createInputTopic(topic2, new IntegerSerializer(), new StringSerializer(), Instant.ofEpochMilli(0L), Duration.ZERO); + final MockProcessor processor = supplier.theCapturedProcessor(); + + final long windowStart = 0L; + + // No joins detected; No null-joins emitted + inputTopic2.pipeInput(0, "A0", windowStart + 1L); + inputTopic2.pipeInput(1, "A1", windowStart + 2L); + inputTopic2.pipeInput(0, "A0-0", windowStart + 3L); + processor.checkAndClearProcessResult(); + + // Join detected; No null-joins emitted + inputTopic1.pipeInput(1, "a1", windowStart + 3L); + processor.checkAndClearProcessResult( + new KeyValueTimestamp<>(1, "a1+A1", windowStart + 3L) + ); + + // Dummy record in left topic will not emit records + inputTopic1.pipeInput(2, "dummy", windowStart + 401L); + processor.checkAndClearProcessResult(); + + // Process the dummy joined record + inputTopic2.pipeInput(2, "dummy", windowStart + 402L); + processor.checkAndClearProcessResult( + new KeyValueTimestamp<>(2, "dummy+dummy", windowStart + 402L) + ); + } + } + + @Test + public void testRightNonJoinedRecordsAreNeverEmittedByTheRightProcessor() { + final StreamsBuilder builder = new StreamsBuilder(); + + final KStream stream1; + final KStream stream2; + final KStream joined; + final MockProcessorSupplier supplier = new MockProcessorSupplier<>(); + + stream1 = builder.stream(topic1, consumed); + stream2 = builder.stream(topic2, consumed); + joined = stream1.leftJoin( + stream2, + MockValueJoiner.TOSTRING_JOINER, + JoinWindows.ofTimeDifferenceWithNoGrace(ofMillis(100L)), + StreamJoined.with(Serdes.Integer(), Serdes.String(), Serdes.String()) + ); + joined.process(supplier); + + try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) { + final TestInputTopic inputTopic1 = + driver.createInputTopic(topic1, new IntegerSerializer(), new StringSerializer(), Instant.ofEpochMilli(0L), Duration.ZERO); + final TestInputTopic inputTopic2 = + driver.createInputTopic(topic2, new IntegerSerializer(), new StringSerializer(), Instant.ofEpochMilli(0L), Duration.ZERO); + final MockProcessor processor = supplier.theCapturedProcessor(); + + final long windowStart = 0L; + + // No joins detected; No null-joins emitted + inputTopic2.pipeInput(0, "A0", windowStart + 1L); + inputTopic2.pipeInput(1, "A1", windowStart + 2L); + inputTopic2.pipeInput(0, "A0-0", windowStart + 3L); + processor.checkAndClearProcessResult(); + + // Join detected; No null-joins emitted + inputTopic1.pipeInput(1, "a1", windowStart + 3L); + processor.checkAndClearProcessResult( + new KeyValueTimestamp<>(1, "a1+A1", windowStart + 3L) + ); + + // Dummy record in right topic will not emit records + inputTopic2.pipeInput(2, "dummy", windowStart + 401L); + processor.checkAndClearProcessResult(); + + // Process the dummy joined record + inputTopic1.pipeInput(2, "dummy", windowStart + 402L); + processor.checkAndClearProcessResult( + new 
KeyValueTimestamp<>(2, "dummy+dummy", windowStart + 402L) + ); + } + } + + @Test + public void testLeftJoinWithInMemoryCustomSuppliers() { + final JoinWindows joinWindows = JoinWindows.ofTimeDifferenceAndGrace(ofMillis(100L), ofMillis(0L)); + + final WindowBytesStoreSupplier thisStoreSupplier = Stores.inMemoryWindowStore("in-memory-join-store", + Duration.ofMillis(joinWindows.size() + joinWindows.gracePeriodMs()), + Duration.ofMillis(joinWindows.size()), true); + + final WindowBytesStoreSupplier otherStoreSupplier = Stores.inMemoryWindowStore("in-memory-join-store-other", + Duration.ofMillis(joinWindows.size() + joinWindows.gracePeriodMs()), + Duration.ofMillis(joinWindows.size()), true); + + final StreamJoined streamJoined = StreamJoined.with(Serdes.Integer(), Serdes.String(), Serdes.String()); + + runLeftJoin(streamJoined.withThisStoreSupplier(thisStoreSupplier).withOtherStoreSupplier(otherStoreSupplier), joinWindows); + } + + @Test + public void testLeftJoinWithDefaultSuppliers() { + final JoinWindows joinWindows = JoinWindows.ofTimeDifferenceWithNoGrace(ofMillis(100L)); + final StreamJoined streamJoined = StreamJoined.with(Serdes.Integer(), Serdes.String(), Serdes.String()); + + runLeftJoin(streamJoined, joinWindows); + } + + public void runLeftJoin(final StreamJoined streamJoined, + final JoinWindows joinWindows) { + final StreamsBuilder builder = new StreamsBuilder(); + + final int[] expectedKeys = new int[] {0, 1, 2, 3}; + + final KStream stream1; + final KStream stream2; + final KStream joined; + final MockProcessorSupplier supplier = new MockProcessorSupplier<>(); + stream1 = builder.stream(topic1, consumed); + stream2 = builder.stream(topic2, consumed); + + joined = stream1.leftJoin( + stream2, + MockValueJoiner.TOSTRING_JOINER, + joinWindows, + streamJoined + ); + joined.process(supplier); + + final Collection> copartitionGroups = + TopologyWrapper.getInternalTopologyBuilder(builder.build()).copartitionGroups(); + + assertEquals(1, copartitionGroups.size()); + assertEquals(new HashSet<>(Arrays.asList(topic1, topic2)), copartitionGroups.iterator().next()); + + try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) { + final TestInputTopic inputTopic1 = + driver.createInputTopic(topic1, new IntegerSerializer(), new StringSerializer(), Instant.ofEpochMilli(0L), Duration.ZERO); + final TestInputTopic inputTopic2 = + driver.createInputTopic(topic2, new IntegerSerializer(), new StringSerializer(), Instant.ofEpochMilli(0L), Duration.ZERO); + final MockProcessor processor = supplier.theCapturedProcessor(); + + // 2 window stores + 1 shared window store should be available + assertEquals(3, driver.getAllStateStores().size()); + + // push two items to the primary stream; the other window is empty + // w1 {} + // w2 {} + // --> w1 = { 0:A0, 1:A1 } + // --> w2 = {} + for (int i = 0; i < 2; i++) { + inputTopic1.pipeInput(expectedKeys[i], "A" + expectedKeys[i]); + } + processor.checkAndClearProcessResult(); + + // push two items to the other stream; this should produce two items + // w1 = { 0:A0, 1:A1 } + // w2 {} + // --> w1 = { 0:A0, 1:A1 } + // --> w2 = { 0:a0, 1:a1 } + for (int i = 0; i < 2; i++) { + inputTopic2.pipeInput(expectedKeys[i], "a" + expectedKeys[i]); + } + processor.checkAndClearProcessResult( + new KeyValueTimestamp<>(0, "A0+a0", 0L), + new KeyValueTimestamp<>(1, "A1+a1", 0L) + ); + + // push three items to the primary stream; this should produce two joined items + // w1 = { 0:A0, 1:A1 } + // w2 = { 0:a0, 1:a1 } + // --> w1 = { 0:A0, 1:A1, 
0:B0, 1:B1, 2:B2 } + // --> w2 = { 0:a0, 1:a1 } + for (int i = 0; i < 3; i++) { + inputTopic1.pipeInput(expectedKeys[i], "B" + expectedKeys[i]); + } + processor.checkAndClearProcessResult( + new KeyValueTimestamp<>(0, "B0+a0", 0L), + new KeyValueTimestamp<>(1, "B1+a1", 0L) + ); + + // push all items to the other stream; this should produce five items + // w1 = { 0:A0, 1:A1, 0:B0, 1:B1, 2:B2 } + // w2 = { 0:a0, 1:a1 } + // --> w1 = { 0:A0, 1:A1, 0:B0, 1:B1, 2:B2 } + // --> w2 = { 0:a0, 1:a1, 0:b0, 1:b1, 2:b2, 3:b3 } + for (final int expectedKey : expectedKeys) { + inputTopic2.pipeInput(expectedKey, "b" + expectedKey); + } + processor.checkAndClearProcessResult( + new KeyValueTimestamp<>(0, "A0+b0", 0L), + new KeyValueTimestamp<>(0, "B0+b0", 0L), + new KeyValueTimestamp<>(1, "A1+b1", 0L), + new KeyValueTimestamp<>(1, "B1+b1", 0L), + new KeyValueTimestamp<>(2, "B2+b2", 0L) + ); + + // push all four items to the primary stream; this should produce six joined items + // w1 = { 0:A0, 1:A1, 0:B0, 1:B1, 2:B2 } + // w2 = { 0:a0, 1:a1, 0:b0, 1:b1, 2:b2, 3:b3 } + // --> w1 = { 0:A0, 1:A1, 0:B0, 1:B1, 2:B2, 0:C0, 1:C1, 2:C2, 3:C3 } + // --> w2 = { 0:a0, 1:a1, 0:b0, 1:b1, 2:b2, 3:b3 } + for (final int expectedKey : expectedKeys) { + inputTopic1.pipeInput(expectedKey, "C" + expectedKey); + } + processor.checkAndClearProcessResult( + new KeyValueTimestamp<>(0, "C0+a0", 0L), + new KeyValueTimestamp<>(0, "C0+b0", 0L), + new KeyValueTimestamp<>(1, "C1+a1", 0L), + new KeyValueTimestamp<>(1, "C1+b1", 0L), + new KeyValueTimestamp<>(2, "C2+b2", 0L), + new KeyValueTimestamp<>(3, "C3+b3", 0L) + ); + + // push a dummy record that should expire non-joined items; it should not produce any items because + // all of them are joined + inputTopic1.pipeInput(0, "dummy", 1000L); + processor.checkAndClearProcessResult(); + } + } + + @Test + public void testOrdering() { + final StreamsBuilder builder = new StreamsBuilder(); + + final KStream stream1; + final KStream stream2; + final KStream joined; + final MockProcessorSupplier supplier = new MockProcessorSupplier<>(); + stream1 = builder.stream(topic1, consumed); + stream2 = builder.stream(topic2, consumed); + + joined = stream1.leftJoin( + stream2, + MockValueJoiner.TOSTRING_JOINER, + JoinWindows.ofTimeDifferenceWithNoGrace(ofMillis(100L)), + StreamJoined.with(Serdes.Integer(), Serdes.String(), Serdes.String()) + ); + joined.process(supplier); + + try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) { + final TestInputTopic inputTopic1 = + driver.createInputTopic(topic1, new IntegerSerializer(), new StringSerializer(), Instant.ofEpochMilli(0L), Duration.ZERO); + final TestInputTopic inputTopic2 = + driver.createInputTopic(topic2, new IntegerSerializer(), new StringSerializer(), Instant.ofEpochMilli(0L), Duration.ZERO); + final MockProcessor processor = supplier.theCapturedProcessor(); + + // push two items to the primary stream; the other window is empty; this should not produce any item yet + // w1 = {} + // w2 = {} + // --> w1 = { 0:A0 (ts: 0), 1:A1 (ts: 100) } + // --> w2 = {} + inputTopic1.pipeInput(0, "A0", 0L); + inputTopic1.pipeInput(1, "A1", 100L); + processor.checkAndClearProcessResult(); + + // push one item to the other window that has a join; this should produce non-joined records with a closed window first, then + // the joined records + // by the time they were produced before + // w1 = { 0:A0 (ts: 0), 1:A1 (ts: 100) } + // w2 = { } + // --> w1 = { 0:A0 (ts: 0), 1:A1 (ts: 100) } + // --> w2 = { 1:a1 (ts: 110) } + 
inputTopic2.pipeInput(1, "a1", 110L); + processor.checkAndClearProcessResult( + new KeyValueTimestamp<>(0, "A0+null", 0L), + new KeyValueTimestamp<>(1, "A1+a1", 110L) + ); + } + } + + @Test + public void testGracePeriod() { + final StreamsBuilder builder = new StreamsBuilder(); + final int[] expectedKeys = new int[] {0, 1, 2, 3}; + + final KStream stream1; + final KStream stream2; + final KStream joined; + final MockProcessorSupplier supplier = new MockProcessorSupplier<>(); + stream1 = builder.stream(topic1, consumed); + stream2 = builder.stream(topic2, consumed); + + joined = stream1.leftJoin( + stream2, + MockValueJoiner.TOSTRING_JOINER, + JoinWindows.ofTimeDifferenceAndGrace(ofMillis(100L), ofMillis(10L)), + StreamJoined.with(Serdes.Integer(), Serdes.String(), Serdes.String()) + ); + joined.process(supplier); + + final Collection> copartitionGroups = + TopologyWrapper.getInternalTopologyBuilder(builder.build()).copartitionGroups(); + + assertEquals(1, copartitionGroups.size()); + assertEquals(new HashSet<>(Arrays.asList(topic1, topic2)), copartitionGroups.iterator().next()); + + try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) { + final TestInputTopic inputTopic1 = + driver.createInputTopic(topic1, new IntegerSerializer(), new StringSerializer(), Instant.ofEpochMilli(0L), Duration.ZERO); + final TestInputTopic inputTopic2 = + driver.createInputTopic(topic2, new IntegerSerializer(), new StringSerializer(), Instant.ofEpochMilli(0L), Duration.ZERO); + final MockProcessor processor = supplier.theCapturedProcessor(); + + // push two items to the primary stream; the other window is empty; this should not produce items because window has not closed + // w1 = {} + // w2 = {} + // --> w1 = { 0:A0 (ts: 0), 1:A1 (ts: 0) } + // --> w2 = {} + long time = 0L; + for (int i = 0; i < 2; i++) { + inputTopic1.pipeInput(expectedKeys[i], "A" + expectedKeys[i], time); + } + processor.checkAndClearProcessResult(); + + // push two items to the other stream with a window time after the previous window ended (not closed); this should not produce + // joined records because the window has ended, but not closed. 
+ // w1 = { 0:A0 (ts: 0), 1:A1 (ts: 0) } + // w2 = { } + // --> w1 = { 0:A0 (ts: 0), 1:A1 (ts: 0) } + // --> w2 = { 0:a0 (ts: 101), 1:a1 (ts: 101) } + time += 101L; + for (final int expectedKey : expectedKeys) { + inputTopic2.pipeInput(expectedKey, "a" + expectedKey, time); + } + processor.checkAndClearProcessResult(); + + // push a dummy item to the other stream after the window is closed; this should only produce the expired non-joined records, but + // not the joined records because the window has closed + // w1 = { 0:A0 (ts: 0), 1:A1 (ts: 0) } + // w2 = { 0:a0 (ts: 101), 1:a1 (ts: 101) } + // --> w1 = { 0:A0 (ts: 0), 1:A1 (ts: 0) } + // --> w2 = { 0:a0 (ts: 101), 1:a1 (ts: 101), + // 0:dummy (ts: 211)} + time += 110L; + inputTopic2.pipeInput(0, "dummy", time); + processor.checkAndClearProcessResult( + new KeyValueTimestamp<>(0, "A0+null", 0L), + new KeyValueTimestamp<>(1, "A1+null", 0L) + ); + } + } + + @Test + public void testWindowing() { + final StreamsBuilder builder = new StreamsBuilder(); + final int[] expectedKeys = new int[] {0, 1, 2, 3}; + + final KStream stream1; + final KStream stream2; + final KStream joined; + final MockProcessorSupplier supplier = new MockProcessorSupplier<>(); + stream1 = builder.stream(topic1, consumed); + stream2 = builder.stream(topic2, consumed); + + joined = stream1.leftJoin( + stream2, + MockValueJoiner.TOSTRING_JOINER, + JoinWindows.ofTimeDifferenceWithNoGrace(ofMillis(100)), + StreamJoined.with(Serdes.Integer(), Serdes.String(), Serdes.String()) + ); + joined.process(supplier); + + final Collection> copartitionGroups = + TopologyWrapper.getInternalTopologyBuilder(builder.build()).copartitionGroups(); + + assertEquals(1, copartitionGroups.size()); + assertEquals(new HashSet<>(Arrays.asList(topic1, topic2)), copartitionGroups.iterator().next()); + + try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) { + final TestInputTopic inputTopic1 = + driver.createInputTopic(topic1, new IntegerSerializer(), new StringSerializer(), Instant.ofEpochMilli(0L), Duration.ZERO); + final TestInputTopic inputTopic2 = + driver.createInputTopic(topic2, new IntegerSerializer(), new StringSerializer(), Instant.ofEpochMilli(0L), Duration.ZERO); + final MockProcessor processor = supplier.theCapturedProcessor(); + final long time = 0L; + + // push two items to the primary stream; the other window is empty; this should not produce any items + // w1 = {} + // w2 = {} + // --> w1 = { 0:A0 (ts: 0), 1:A1 (ts: 0) } + // --> w2 = {} + for (int i = 0; i < 2; i++) { + inputTopic1.pipeInput(expectedKeys[i], "A" + expectedKeys[i], time); + } + processor.checkAndClearProcessResult(); + + // push four items to the other stream; this should produce two full-join items + // w1 = { 0:A0 (ts: 0), 1:A1 (ts: 0) } + // w2 = {} + // --> w1 = { 0:A0 (ts: 0), 1:A1 (ts: 0) } + // --> w2 = { 0:a0 (ts: 0), 1:a1 (ts: 0), 2:a2 (ts: 0), 3:a3 (ts: 0) } + for (final int expectedKey : expectedKeys) { + inputTopic2.pipeInput(expectedKey, "a" + expectedKey, time); + } + processor.checkAndClearProcessResult( + new KeyValueTimestamp<>(0, "A0+a0", 0L), + new KeyValueTimestamp<>(1, "A1+a1", 0L) + ); + testUpperWindowBound(expectedKeys, driver, processor); + testLowerWindowBound(expectedKeys, driver, processor); + } + } + + private void testUpperWindowBound(final int[] expectedKeys, + final TopologyTestDriver driver, + final MockProcessor processor) { + long time; + + final TestInputTopic inputTopic1 = + driver.createInputTopic(topic1, new IntegerSerializer(), new 
StringSerializer(), Instant.ofEpochMilli(0L), Duration.ZERO); + final TestInputTopic inputTopic2 = + driver.createInputTopic(topic2, new IntegerSerializer(), new StringSerializer(), Instant.ofEpochMilli(0L), Duration.ZERO); + + // push four items with larger and increasing timestamp (out of window) to the other stream; this should produce no items + // w1 = { 0:A0 (ts: 0), 1:A1 (ts: 0) } + // w2 = { 0:a0 (ts: 0), 1:a1 (ts: 0), 2:a2 (ts: 0), 3:a3 (ts: 0) } + // --> w1 = { 0:A0 (ts: 0), 1:A1 (ts: 0) } + // --> w2 = { 0:a0 (ts: 0), 1:a1 (ts: 0), 2:a2 (ts: 0), 3:a3 (ts: 0), + // 0:b0 (ts: 1000), 1:b1 (ts: 1001), 2:b2 (ts: 1002), 3:b3 (ts: 1003) } + time = 1000L; + for (int i = 0; i < expectedKeys.length; i++) { + inputTopic2.pipeInput(expectedKeys[i], "b" + expectedKeys[i], time + i); + } + processor.checkAndClearProcessResult(EMPTY); + + // push four items with larger timestamp to the primary stream; this should produce four full-join items + // w1 = { 0:A0 (ts: 0), 1:A1 (ts: 0) } + // w2 = { 0:a0 (ts: 0), 1:a1 (ts: 0), 2:a2 (ts: 0), 3:a3 (ts: 0), + // 0:b0 (ts: 1000), 1:b1 (ts: 1001), 2:b2 (ts: 1002), 3:b3 (ts: 1003) } + // --> w1 = { 0:A0 (ts: 0), 1:A1 (ts: 0), + // 0:B0 (ts: 1100), 1:B1 (ts: 1100), 2:B2 (ts: 1100), 3:B3 (ts: 1100) } + // --> w2 = { 0:a0 (ts: 0), 1:a1 (ts: 0), 2:a2 (ts: 0), 3:a3 (ts: 0), + // 0:b0 (ts: 1000), 1:b1 (ts: 1001), 2:b2 (ts: 1002), 3:b3 (ts: 1003) } + time = 1000L + 100L; + for (final int expectedKey : expectedKeys) { + inputTopic1.pipeInput(expectedKey, "B" + expectedKey, time); + } + processor.checkAndClearProcessResult( + new KeyValueTimestamp<>(0, "B0+b0", 1100L), + new KeyValueTimestamp<>(1, "B1+b1", 1100L), + new KeyValueTimestamp<>(2, "B2+b2", 1100L), + new KeyValueTimestamp<>(3, "B3+b3", 1100L) + ); + + // push four items with increased timestamp to the primary stream; this should produce one left-join and three full-join items (non-joined item is not produced yet) + // w1 = { 0:A0 (ts: 0), 1:A1 (ts: 0), + // 0:B0 (ts: 1100), 1:B1 (ts: 1100), 2:B2 (ts: 1100), 3:B3 (ts: 1100) } + // w2 = { 0:a0 (ts: 0), 1:a1 (ts: 0), 2:a2 (ts: 0), 3:a3 (ts: 0), + // 0:b0 (ts: 1000), 1:b1 (ts: 1001), 2:b2 (ts: 1002), 3:b3 (ts: 1003) } + // --> w1 = { 0:A0 (ts: 0), 1:A1 (ts: 0), + // 0:B0 (ts: 1100), 1:B1 (ts: 1100), 2:B2 (ts: 1100), 3:B3 (ts: 1100), + // 0:C0 (ts: 1101), 1:C1 (ts: 1101), 2:C2 (ts: 1101), 3:C3 (ts: 1101) } + // --> w2 = { 0:a0 (ts: 0), 1:a1 (ts: 0), 2:a2 (ts: 0), 3:a3 (ts: 0), + // 0:b0 (ts: 1000), 1:b1 (ts: 1001), 2:b2 (ts: 1002), 3:b3 (ts: 1003) } + time += 1L; + for (final int expectedKey : expectedKeys) { + inputTopic1.pipeInput(expectedKey, "C" + expectedKey, time); + } + processor.checkAndClearProcessResult( + new KeyValueTimestamp<>(1, "C1+b1", 1101L), + new KeyValueTimestamp<>(2, "C2+b2", 1101L), + new KeyValueTimestamp<>(3, "C3+b3", 1101L) + ); + + // push four items with increased timestamp to the primary stream; this should produce two left-join and two full-join items (non-joined items are not produced yet) + // w1 = { 0:A0 (ts: 0), 1:A1 (ts: 0), + // 0:B0 (ts: 1100), 1:B1 (ts: 1100), 2:B2 (ts: 1100), 3:B3 (ts: 1100), + // 0:C0 (ts: 1101), 1:C1 (ts: 1101), 2:C2 (ts: 1101), 3:C3 (ts: 1101) } + // w2 = { 0:a0 (ts: 0), 1:a1 (ts: 0), 2:a2 (ts: 0), 3:a3 (ts: 0), + // 0:b0 (ts: 1000), 1:b1 (ts: 1001), 2:b2 (ts: 1002), 3:b3 (ts: 1003) } + // --> w1 = { 0:A0 (ts: 0), 1:A1 (ts: 0), + // 0:B0 (ts: 1100), 1:B1 (ts: 1100), 2:B2 (ts: 1100), 3:B3 (ts: 1100), + // 0:C0 (ts: 1101), 1:C1 (ts: 1101), 2:C2 (ts: 1101), 3:C3 (ts: 1101), + // 0:D0 (ts: 1102), 1:D1 
(ts: 1102), 2:D2 (ts: 1102), 3:D3 (ts: 1102) } + // --> w2 = { 0:a0 (ts: 0), 1:a1 (ts: 0), 2:a2 (ts: 0), 3:a3 (ts: 0), + // 0:b0 (ts: 1000), 1:b1 (ts: 1001), 2:b2 (ts: 1002), 3:b3 (ts: 1003) } + time += 1L; + for (final int expectedKey : expectedKeys) { + inputTopic1.pipeInput(expectedKey, "D" + expectedKey, time); + } + processor.checkAndClearProcessResult( + new KeyValueTimestamp<>(2, "D2+b2", 1102L), + new KeyValueTimestamp<>(3, "D3+b3", 1102L) + ); + + // push four items with increased timestamp to the primary stream; this should produce one full-join items (three non-joined left-join are not produced yet) + // w1 = { 0:A0 (ts: 0), 1:A1 (ts: 0), + // 0:B0 (ts: 1100), 1:B1 (ts: 1100), 2:B2 (ts: 1100), 3:B3 (ts: 1100), + // 0:C0 (ts: 1101), 1:C1 (ts: 1101), 2:C2 (ts: 1101), 3:C3 (ts: 1101), + // 0:D0 (ts: 1102), 1:D1 (ts: 1102), 2:D2 (ts: 1102), 3:D3 (ts: 1102) } + // w2 = { 0:a0 (ts: 0), 1:a1 (ts: 0), 2:a2 (ts: 0), 3:a3 (ts: 0), + // 0:b0 (ts: 1000), 1:b1 (ts: 1001), 2:b2 (ts: 1002), 3:b3 (ts: 1003) } + // --> w1 = { 0:A0 (ts: 0), 1:A1 (ts: 0), + // 0:B0 (ts: 1100), 1:B1 (ts: 1100), 2:B2 (ts: 1100), 3:B3 (ts: 1100), + // 0:C0 (ts: 1101), 1:C1 (ts: 1101), 2:C2 (ts: 1101), 3:C3 (ts: 1101), + // 0:D0 (ts: 1102), 1:D1 (ts: 1102), 2:D2 (ts: 1102), 3:D3 (ts: 1102), + // 0:E0 (ts: 1103), 1:E1 (ts: 1103), 2:E2 (ts: 1103), 3:E3 (ts: 1103) } + // --> w2 = { 0:a0 (ts: 0), 1:a1 (ts: 0), 2:a2 (ts: 0), 3:a3 (ts: 0), + // 0:b0 (ts: 1000), 1:b1 (ts: 1001), 2:b2 (ts: 1002), 3:b3 (ts: 1003) } + time += 1L; + for (final int expectedKey : expectedKeys) { + inputTopic1.pipeInput(expectedKey, "E" + expectedKey, time); + } + processor.checkAndClearProcessResult( + new KeyValueTimestamp<>(3, "E3+b3", 1103L) + ); + + // push four items with increased timestamp to the primary stream; this should produce no full-join items (four non-joined left-join are not produced yet) + // w1 = { 0:A0 (ts: 0), 1:A1 (ts: 0), + // 0:B0 (ts: 1100), 1:B1 (ts: 1100), 2:B2 (ts: 1100), 3:B3 (ts: 1100), + // 0:C0 (ts: 1101), 1:C1 (ts: 1101), 2:C2 (ts: 1101), 3:C3 (ts: 1101), + // 0:D0 (ts: 1102), 1:D1 (ts: 1102), 2:D2 (ts: 1102), 3:D3 (ts: 1102), + // 0:E0 (ts: 1103), 1:E1 (ts: 1103), 2:E2 (ts: 1103), 3:E3 (ts: 1103) } + // w2 = { 0:a0 (ts: 0), 1:a1 (ts: 0), 2:a2 (ts: 0), 3:a3 (ts: 0), + // 0:b0 (ts: 1000), 1:b1 (ts: 1001), 2:b2 (ts: 1002), 3:b3 (ts: 1003) } + // --> w1 = { 0:A0 (ts: 0), 1:A1 (ts: 0), + // 0:B0 (ts: 1100), 1:B1 (ts: 1100), 2:B2 (ts: 1100), 3:B3 (ts: 1100), + // 0:C0 (ts: 1101), 1:C1 (ts: 1101), 2:C2 (ts: 1101), 3:C3 (ts: 1101), + // 0:D0 (ts: 1102), 1:D1 (ts: 1102), 2:D2 (ts: 1102), 3:D3 (ts: 1102), + // 0:E0 (ts: 1103), 1:E1 (ts: 1103), 2:E2 (ts: 1103), 3:E3 (ts: 1103), + // 0:F0 (ts: 1104), 1:F1 (ts: 1104), 2:F2 (ts: 1104), 3:F3 (ts: 1104) } + // --> w2 = { 0:a0 (ts: 0), 1:a1 (ts: 0), 2:a2 (ts: 0), 3:a3 (ts: 0), + // 0:b0 (ts: 1000), 1:b1 (ts: 1001), 2:b2 (ts: 1002), 3:b3 (ts: 1003) } + time += 1L; + for (final int expectedKey : expectedKeys) { + inputTopic1.pipeInput(expectedKey, "F" + expectedKey, time); + } + processor.checkAndClearProcessResult(); + + // push a dummy record to produce all left-join non-joined items + time += 301L; + driver.advanceWallClockTime(Duration.ofMillis(1000L)); + inputTopic1.pipeInput(0, "dummy", time); + processor.checkAndClearProcessResult( + new KeyValueTimestamp<>(0, "C0+null", 1101L), + new KeyValueTimestamp<>(0, "D0+null", 1102L), + new KeyValueTimestamp<>(1, "D1+null", 1102L), + new KeyValueTimestamp<>(0, "E0+null", 1103L), + new KeyValueTimestamp<>(1, "E1+null", 1103L), + new 
KeyValueTimestamp<>(2, "E2+null", 1103L), + new KeyValueTimestamp<>(0, "F0+null", 1104L), + new KeyValueTimestamp<>(1, "F1+null", 1104L), + new KeyValueTimestamp<>(2, "F2+null", 1104L), + new KeyValueTimestamp<>(3, "F3+null", 1104L) + ); + } + + private void testLowerWindowBound(final int[] expectedKeys, + final TopologyTestDriver driver, + final MockProcessor processor) { + long time; + final TestInputTopic inputTopic1 = driver.createInputTopic(topic1, new IntegerSerializer(), new StringSerializer()); + + // push four items with smaller timestamp (before the window) to the primary stream; this should produce four left-join and no full-join items + // w1 = { 0:A0 (ts: 0), 1:A1 (ts: 0), + // 0:B0 (ts: 1100), 1:B1 (ts: 1100), 2:B2 (ts: 1100), 3:B3 (ts: 1100), + // 0:C0 (ts: 1101), 1:C1 (ts: 1101), 2:C2 (ts: 1101), 3:C3 (ts: 1101), + // 0:D0 (ts: 1102), 1:D1 (ts: 1102), 2:D2 (ts: 1102), 3:D3 (ts: 1102), + // 0:E0 (ts: 1103), 1:E1 (ts: 1103), 2:E2 (ts: 1103), 3:E3 (ts: 1103), + // 0:F0 (ts: 1104), 1:F1 (ts: 1104), 2:F2 (ts: 1104), 3:F3 (ts: 1104) } + // w2 = { 0:a0 (ts: 0), 1:a1 (ts: 0), 2:a2 (ts: 0), 3:a3 (ts: 0), + // 0:b0 (ts: 1000), 1:b1 (ts: 1001), 2:b2 (ts: 1002), 3:b3 (ts: 1003) } + // --> w1 = { 0:A0 (ts: 0), 1:A1 (ts: 0), + // 0:B0 (ts: 1100), 1:B1 (ts: 1100), 2:B2 (ts: 1100), 3:B3 (ts: 1100), + // 0:C0 (ts: 1101), 1:C1 (ts: 1101), 2:C2 (ts: 1101), 3:C3 (ts: 1101), + // 0:D0 (ts: 1102), 1:D1 (ts: 1102), 2:D2 (ts: 1102), 3:D3 (ts: 1102), + // 0:E0 (ts: 1103), 1:E1 (ts: 1103), 2:E2 (ts: 1103), 3:E3 (ts: 1103), + // 0:F0 (ts: 1104), 1:F1 (ts: 1104), 2:F2 (ts: 1104), 3:F3 (ts: 1104), + // 0:G0 (ts: 899), 1:G1 (ts: 899), 2:G2 (ts: 899), 3:G3 (ts: 899) } + // --> w2 = { 0:a0 (ts: 0), 1:a1 (ts: 0), 2:a2 (ts: 0), 3:a3 (ts: 0), + // 0:b0 (ts: 1000), 1:b1 (ts: 1001), 2:b2 (ts: 1002), 3:b3 (ts: 1003) } + time = 1000L - 100L - 1L; + for (final int expectedKey : expectedKeys) { + inputTopic1.pipeInput(expectedKey, "G" + expectedKey, time); + } + processor.checkAndClearProcessResult( + new KeyValueTimestamp<>(0, "G0+null", 899L), + new KeyValueTimestamp<>(1, "G1+null", 899L), + new KeyValueTimestamp<>(2, "G2+null", 899L), + new KeyValueTimestamp<>(3, "G3+null", 899L) + ); + + // push four items with increase timestamp to the primary stream; this should produce three left-join and one full-join items + // w1 = { 0:A0 (ts: 0), 1:A1 (ts: 0), + // 0:B0 (ts: 1100), 1:B1 (ts: 1100), 2:B2 (ts: 1100), 3:B3 (ts: 1100), + // 0:C0 (ts: 1101), 1:C1 (ts: 1101), 2:C2 (ts: 1101), 3:C3 (ts: 1101), + // 0:D0 (ts: 1102), 1:D1 (ts: 1102), 2:D2 (ts: 1102), 3:D3 (ts: 1102), + // 0:E0 (ts: 1103), 1:E1 (ts: 1103), 2:E2 (ts: 1103), 3:E3 (ts: 1103), + // 0:F0 (ts: 1104), 1:F1 (ts: 1104), 2:F2 (ts: 1104), 3:F3 (ts: 1104), + // 0:G0 (ts: 899), 1:G1 (ts: 899), 2:G2 (ts: 899), 3:G3 (ts: 899) } + // w2 = { 0:a0 (ts: 0), 1:a1 (ts: 0), 2:a2 (ts: 0), 3:a3 (ts: 0), + // 0:b0 (ts: 1000), 1:b1 (ts: 1001), 2:b2 (ts: 1002), 3:b3 (ts: 1003) } + // --> w1 = { 0:A0 (ts: 0), 1:A1 (ts: 0), + // 0:B0 (ts: 1100), 1:B1 (ts: 1100), 2:B2 (ts: 1100), 3:B3 (ts: 1100), + // 0:C0 (ts: 1101), 1:C1 (ts: 1101), 2:C2 (ts: 1101), 3:C3 (ts: 1101), + // 0:D0 (ts: 1102), 1:D1 (ts: 1102), 2:D2 (ts: 1102), 3:D3 (ts: 1102), + // 0:E0 (ts: 1103), 1:E1 (ts: 1103), 2:E2 (ts: 1103), 3:E3 (ts: 1103), + // 0:F0 (ts: 1104), 1:F1 (ts: 1104), 2:F2 (ts: 1104), 3:F3 (ts: 1104), + // 0:G0 (ts: 899), 1:G1 (ts: 899), 2:G2 (ts: 899), 3:G3 (ts: 899), + // 0:H0 (ts: 900), 1:H1 (ts: 900), 2:H2 (ts: 900), 3:H3 (ts: 900) } + // --> w2 = { 0:a0 (ts: 0), 1:a1 (ts: 0), 2:a2 (ts: 0), 
3:a3 (ts: 0), + // 0:b0 (ts: 1000), 1:b1 (ts: 1001), 2:b2 (ts: 1002), 3:b3 (ts: 1003) } + time += 1L; + for (final int expectedKey : expectedKeys) { + inputTopic1.pipeInput(expectedKey, "H" + expectedKey, time); + } + processor.checkAndClearProcessResult( + new KeyValueTimestamp<>(0, "H0+b0", 1000L), + new KeyValueTimestamp<>(1, "H1+null", 900L), + new KeyValueTimestamp<>(2, "H2+null", 900L), + new KeyValueTimestamp<>(3, "H3+null", 900L) + ); + + // push four items with increase timestamp to the primary stream; this should produce two left-join and two full-join items + // w1 = { 0:A0 (ts: 0), 1:A1 (ts: 0), + // 0:B0 (ts: 1100), 1:B1 (ts: 1100), 2:B2 (ts: 1100), 3:B3 (ts: 1100), + // 0:C0 (ts: 1101), 1:C1 (ts: 1101), 2:C2 (ts: 1101), 3:C3 (ts: 1101), + // 0:D0 (ts: 1102), 1:D1 (ts: 1102), 2:D2 (ts: 1102), 3:D3 (ts: 1102), + // 0:E0 (ts: 1103), 1:E1 (ts: 1103), 2:E2 (ts: 1103), 3:E3 (ts: 1103), + // 0:F0 (ts: 1104), 1:F1 (ts: 1104), 2:F2 (ts: 1104), 3:F3 (ts: 1104), + // 0:G0 (ts: 899), 1:G1 (ts: 899), 2:G2 (ts: 899), 3:G3 (ts: 899), + // 0:H0 (ts: 900), 1:H1 (ts: 900), 2:H2 (ts: 900), 3:H3 (ts: 900) } + // w2 = { 0:a0 (ts: 0), 1:a1 (ts: 0), 2:a2 (ts: 0), 3:a3 (ts: 0), + // 0:b0 (ts: 1000), 1:b1 (ts: 1001), 2:b2 (ts: 1002), 3:b3 (ts: 1003) } + // --> w1 = { 0:A0 (ts: 0), 1:A1 (ts: 0), + // 0:B0 (ts: 1100), 1:B1 (ts: 1100), 2:B2 (ts: 1100), 3:B3 (ts: 1100), + // 0:C0 (ts: 1101), 1:C1 (ts: 1101), 2:C2 (ts: 1101), 3:C3 (ts: 1101), + // 0:D0 (ts: 1102), 1:D1 (ts: 1102), 2:D2 (ts: 1102), 3:D3 (ts: 1102), + // 0:E0 (ts: 1103), 1:E1 (ts: 1103), 2:E2 (ts: 1103), 3:E3 (ts: 1103), + // 0:F0 (ts: 1104), 1:F1 (ts: 1104), 2:F2 (ts: 1104), 3:F3 (ts: 1104), + // 0:G0 (ts: 899), 1:G1 (ts: 899), 2:G2 (ts: 899), 3:G3 (ts: 899), + // 0:H0 (ts: 900), 1:H1 (ts: 900), 2:H2 (ts: 900), 3:H3 (ts: 900), + // 0:I0 (ts: 901), 1:I1 (ts: 901), 2:I2 (ts: 901), 3:I3 (ts: 901) } + // --> w2 = { 0:a0 (ts: 0), 1:a1 (ts: 0), 2:a2 (ts: 0), 3:a3 (ts: 0), + // 0:b0 (ts: 1000), 1:b1 (ts: 1001), 2:b2 (ts: 1002), 3:b3 (ts: 1003) } + time += 1L; + for (final int expectedKey : expectedKeys) { + inputTopic1.pipeInput(expectedKey, "I" + expectedKey, time); + } + processor.checkAndClearProcessResult( + new KeyValueTimestamp<>(0, "I0+b0", 1000L), + new KeyValueTimestamp<>(1, "I1+b1", 1001L), + new KeyValueTimestamp<>(2, "I2+null", 901L), + new KeyValueTimestamp<>(3, "I3+null", 901L) + ); + + // push four items with increase timestamp to the primary stream; this should produce one left-join and three full-join items + // w1 = { 0:A0 (ts: 0), 1:A1 (ts: 0), + // 0:B0 (ts: 1100), 1:B1 (ts: 1100), 2:B2 (ts: 1100), 3:B3 (ts: 1100), + // 0:C0 (ts: 1101), 1:C1 (ts: 1101), 2:C2 (ts: 1101), 3:C3 (ts: 1101), + // 0:D0 (ts: 1102), 1:D1 (ts: 1102), 2:D2 (ts: 1102), 3:D3 (ts: 1102), + // 0:E0 (ts: 1103), 1:E1 (ts: 1103), 2:E2 (ts: 1103), 3:E3 (ts: 1103), + // 0:F0 (ts: 1104), 1:F1 (ts: 1104), 2:F2 (ts: 1104), 3:F3 (ts: 1104), + // 0:G0 (ts: 899), 1:G1 (ts: 899), 2:G2 (ts: 899), 3:G3 (ts: 899), + // 0:H0 (ts: 900), 1:H1 (ts: 900), 2:H2 (ts: 900), 3:H3 (ts: 900), + // 0:I0 (ts: 901), 1:I1 (ts: 901), 2:I2 (ts: 901), 3:I3 (ts: 901) } + // w2 = { 0:a0 (ts: 0), 1:a1 (ts: 0), 2:a2 (ts: 0), 3:a3 (ts: 0), + // 0:b0 (ts: 1000), 1:b1 (ts: 1001), 2:b2 (ts: 1002), 3:b3 (ts: 1003) } + // --> w1 = { 0:A0 (ts: 0), 1:A1 (ts: 0), + // 0:B0 (ts: 1100), 1:B1 (ts: 1100), 2:B2 (ts: 1100), 3:B3 (ts: 1100), + // 0:C0 (ts: 1101), 1:C1 (ts: 1101), 2:C2 (ts: 1101), 3:C3 (ts: 1101), + // 0:D0 (ts: 1102), 1:D1 (ts: 1102), 2:D2 (ts: 1102), 3:D3 (ts: 1102), + // 0:E0 (ts: 1103), 
1:E1 (ts: 1103), 2:E2 (ts: 1103), 3:E3 (ts: 1103), + // 0:F0 (ts: 1104), 1:F1 (ts: 1104), 2:F2 (ts: 1104), 3:F3 (ts: 1104), + // 0:G0 (ts: 899), 1:G1 (ts: 899), 2:G2 (ts: 899), 3:G3 (ts: 899), + // 0:H0 (ts: 900), 1:H1 (ts: 900), 2:H2 (ts: 900), 3:H3 (ts: 900), + // 0:I0 (ts: 901), 1:I1 (ts: 901), 2:I2 (ts: 901), 3:I3 (ts: 901), + // 0:J0 (ts: 902), 1:J1 (ts: 902), 2:J2 (ts: 902), 3:J3 (ts: 902) } + // --> w2 = { 0:a0 (ts: 0), 1:a1 (ts: 0), 2:a2 (ts: 0), 3:a3 (ts: 0), + // 0:b0 (ts: 1000), 1:b1 (ts: 1001), 2:b2 (ts: 1002), 3:b3 (ts: 1003) } + time += 1L; + for (final int expectedKey : expectedKeys) { + inputTopic1.pipeInput(expectedKey, "J" + expectedKey, time); + } + processor.checkAndClearProcessResult( + new KeyValueTimestamp<>(0, "J0+b0", 1000L), + new KeyValueTimestamp<>(1, "J1+b1", 1001L), + new KeyValueTimestamp<>(2, "J2+b2", 1002L), + new KeyValueTimestamp<>(3, "J3+null", 902L) + ); + + // push four items with increase timestamp to the primary stream; this should produce one left-join and three full-join items + // w1 = { 0:A0 (ts: 0), 1:A1 (ts: 0), + // 0:B0 (ts: 1100), 1:B1 (ts: 1100), 2:B2 (ts: 1100), 3:B3 (ts: 1100), + // 0:C0 (ts: 1101), 1:C1 (ts: 1101), 2:C2 (ts: 1101), 3:C3 (ts: 1101), + // 0:D0 (ts: 1102), 1:D1 (ts: 1102), 2:D2 (ts: 1102), 3:D3 (ts: 1102), + // 0:E0 (ts: 1103), 1:E1 (ts: 1103), 2:E2 (ts: 1103), 3:E3 (ts: 1103), + // 0:F0 (ts: 1104), 1:F1 (ts: 1104), 2:F2 (ts: 1104), 3:F3 (ts: 1104), + // 0:G0 (ts: 899), 1:G1 (ts: 899), 2:G2 (ts: 899), 3:G3 (ts: 899), + // 0:H0 (ts: 900), 1:H1 (ts: 900), 2:H2 (ts: 900), 3:H3 (ts: 900), + // 0:I0 (ts: 901), 1:I1 (ts: 901), 2:I2 (ts: 901), 3:I3 (ts: 901), + // 0:J0 (ts: 902), 1:J1 (ts: 902), 2:J2 (ts: 902), 3:J3 (ts: 902) } + // w2 = { 0:a0 (ts: 0), 1:a1 (ts: 0), 2:a2 (ts: 0), 3:a3 (ts: 0), + // 0:b0 (ts: 1000), 1:b1 (ts: 1001), 2:b2 (ts: 1002), 3:b3 (ts: 1003) } + // --> w1 = { 0:A0 (ts: 0), 1:A1 (ts: 0), + // 0:B0 (ts: 1100), 1:B1 (ts: 1100), 2:B2 (ts: 1100), 3:B3 (ts: 1100), + // 0:C0 (ts: 1101), 1:C1 (ts: 1101), 2:C2 (ts: 1101), 3:C3 (ts: 1101), + // 0:D0 (ts: 1102), 1:D1 (ts: 1102), 2:D2 (ts: 1102), 3:D3 (ts: 1102), + // 0:E0 (ts: 1103), 1:E1 (ts: 1103), 2:E2 (ts: 1103), 3:E3 (ts: 1103), + // 0:F0 (ts: 1104), 1:F1 (ts: 1104), 2:F2 (ts: 1104), 3:F3 (ts: 1104), + // 0:G0 (ts: 899), 1:G1 (ts: 899), 2:G2 (ts: 899), 3:G3 (ts: 899), + // 0:H0 (ts: 900), 1:H1 (ts: 900), 2:H2 (ts: 900), 3:H3 (ts: 900), + // 0:I0 (ts: 901), 1:I1 (ts: 901), 2:I2 (ts: 901), 3:I3 (ts: 901), + // 0:J0 (ts: 902), 1:J1 (ts: 902), 2:J2 (ts: 902), 3:J3 (ts: 902), + // 0:K0 (ts: 903), 1:K1 (ts: 903), 2:K2 (ts: 903), 3:K3 (ts: 903) } + // --> w2 = { 0:a0 (ts: 0), 1:a1 (ts: 0), 2:a2 (ts: 0), 3:a3 (ts: 0), + // 0:b0 (ts: 1000), 1:b1 (ts: 1001), 2:b2 (ts: 1002), 3:b3 (ts: 1003) } + time += 1L; + for (final int expectedKey : expectedKeys) { + inputTopic1.pipeInput(expectedKey, "K" + expectedKey, time); + } + processor.checkAndClearProcessResult( + new KeyValueTimestamp<>(0, "K0+b0", 1000L), + new KeyValueTimestamp<>(1, "K1+b1", 1001L), + new KeyValueTimestamp<>(2, "K2+b2", 1002L), + new KeyValueTimestamp<>(3, "K3+b3", 1003L) + ); + + // push a dummy record that should expire non-joined items; it should produce only the dummy+null record because + // all previous late records were emitted immediately + inputTopic1.pipeInput(0, "dummy", time + 300L); + processor.checkAndClearProcessResult(new KeyValueTimestamp<>(0, "dummy+null", 1203L)); + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KStreamKStreamOuterJoinTest.java 
b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KStreamKStreamOuterJoinTest.java new file mode 100644 index 0000000..2d9e320 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KStreamKStreamOuterJoinTest.java @@ -0,0 +1,1037 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.kstream.internals; + +import org.apache.kafka.common.serialization.IntegerSerializer; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.serialization.StringSerializer; +import org.apache.kafka.streams.KeyValueTimestamp; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.TestInputTopic; +import org.apache.kafka.streams.TopologyTestDriver; +import org.apache.kafka.streams.TopologyWrapper; +import org.apache.kafka.streams.kstream.Consumed; +import org.apache.kafka.streams.kstream.JoinWindows; +import org.apache.kafka.streams.kstream.KStream; +import org.apache.kafka.streams.kstream.StreamJoined; +import org.apache.kafka.streams.state.Stores; +import org.apache.kafka.streams.state.WindowBytesStoreSupplier; +import org.apache.kafka.test.MockProcessor; +import org.apache.kafka.test.MockProcessorSupplier; +import org.apache.kafka.test.MockValueJoiner; +import org.apache.kafka.test.StreamsTestUtils; +import org.junit.Test; + +import java.time.Duration; +import java.time.Instant; +import java.util.Arrays; +import java.util.Collection; +import java.util.HashSet; +import java.util.Properties; +import java.util.Set; + +import static java.time.Duration.ofMillis; +import static org.junit.Assert.assertEquals; + + +@SuppressWarnings("deprecation") // Old PAPI. Needs to be migrated. 
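+// The "duplicates" tests below contrast the two JoinWindows APIs exercised in this class: with the
+// deprecated JoinWindows.of(...).grace(...) ("fix disabled", old API), an unmatched record eagerly
+// emits a null-padded result right away (e.g. "A0+null", possibly followed later by "A0+a0" once a
+// match arrives), whereas with JoinWindows.ofTimeDifferenceAndGrace(...) the null-padded result is
+// held back until the window has closed, i.e. until stream time has advanced beyond the record's
+// timestamp plus the time difference plus the grace period. A minimal sketch of the DSL call under
+// test (mirroring the setup used throughout this class):
+//
+//   stream1.outerJoin(stream2, MockValueJoiner.TOSTRING_JOINER,
+//       JoinWindows.ofTimeDifferenceAndGrace(ofMillis(100L), ofMillis(10L)),
+//       StreamJoined.with(Serdes.Integer(), Serdes.String(), Serdes.String()));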
+public class KStreamKStreamOuterJoinTest { + private final String topic1 = "topic1"; + private final String topic2 = "topic2"; + private final Consumed consumed = Consumed.with(Serdes.Integer(), Serdes.String()); + private final Properties props = StreamsTestUtils.getStreamsConfig(Serdes.String(), Serdes.String()); + + @Test + public void testOuterJoinDuplicatesWithFixDisabledOldApi() { + final StreamsBuilder builder = new StreamsBuilder(); + + final KStream stream1; + final KStream stream2; + final KStream joined; + final MockProcessorSupplier supplier = new MockProcessorSupplier<>(); + stream1 = builder.stream(topic1, consumed); + stream2 = builder.stream(topic2, consumed); + + joined = stream1.outerJoin( + stream2, + MockValueJoiner.TOSTRING_JOINER, + JoinWindows.of(ofMillis(100L)).grace(ofMillis(10L)), + StreamJoined.with(Serdes.Integer(), Serdes.String(), Serdes.String()) + ); + joined.process(supplier); + + try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(props), props)) { + final TestInputTopic inputTopic1 = + driver.createInputTopic(topic1, new IntegerSerializer(), new StringSerializer(), Instant.ofEpochMilli(0L), Duration.ZERO); + final TestInputTopic inputTopic2 = + driver.createInputTopic(topic2, new IntegerSerializer(), new StringSerializer(), Instant.ofEpochMilli(0L), Duration.ZERO); + final MockProcessor processor = supplier.theCapturedProcessor(); + + // Only 2 window stores should be available + assertEquals(2, driver.getAllStateStores().size()); + + inputTopic1.pipeInput(0, "A0", 0L); + inputTopic1.pipeInput(0, "A0-0", 0L); + inputTopic2.pipeInput(0, "a0", 0L); + inputTopic2.pipeInput(1, "b1", 0L); + + processor.checkAndClearProcessResult( + new KeyValueTimestamp<>(0, "A0+null", 0L), + new KeyValueTimestamp<>(0, "A0-0+null", 0L), + new KeyValueTimestamp<>(0, "A0+a0", 0L), + new KeyValueTimestamp<>(0, "A0-0+a0", 0L), + new KeyValueTimestamp<>(1, "null+b1", 0L) + ); + } + } + + @Test + public void testOuterJoinDuplicates() { + final StreamsBuilder builder = new StreamsBuilder(); + + final KStream stream1; + final KStream stream2; + final KStream joined; + final MockProcessorSupplier supplier = new MockProcessorSupplier<>(); + stream1 = builder.stream(topic1, consumed); + stream2 = builder.stream(topic2, consumed); + + joined = stream1.outerJoin( + stream2, + MockValueJoiner.TOSTRING_JOINER, + JoinWindows.ofTimeDifferenceAndGrace(ofMillis(100L), ofMillis(10L)), + StreamJoined.with(Serdes.Integer(), Serdes.String(), Serdes.String()) + ); + joined.process(supplier); + + try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) { + final TestInputTopic inputTopic1 = + driver.createInputTopic(topic1, new IntegerSerializer(), new StringSerializer(), Instant.ofEpochMilli(0L), Duration.ZERO); + final TestInputTopic inputTopic2 = + driver.createInputTopic(topic2, new IntegerSerializer(), new StringSerializer(), Instant.ofEpochMilli(0L), Duration.ZERO); + final MockProcessor processor = supplier.theCapturedProcessor(); + + // verifies non-joined duplicates are emitted when window has closed + inputTopic1.pipeInput(0, "A0", 0L); + inputTopic1.pipeInput(0, "A0-0", 0L); + inputTopic2.pipeInput(1, "a1", 0L); + inputTopic2.pipeInput(1, "a1-0", 0L); + inputTopic2.pipeInput(1, "a0", 111L); + // bump stream-time to trigger outer-join results + inputTopic2.pipeInput(3, "dummy", 211); + + processor.checkAndClearProcessResult( + new KeyValueTimestamp<>(1, "null+a1", 0L), + new KeyValueTimestamp<>(1, "null+a1-0", 0L), + new 
KeyValueTimestamp<>(0, "A0+null", 0L), + new KeyValueTimestamp<>(0, "A0-0+null", 0L) + ); + + // verifies joined duplicates are emitted + inputTopic1.pipeInput(2, "A2", 200L); + inputTopic1.pipeInput(2, "A2-0", 200L); + inputTopic2.pipeInput(2, "a2", 201L); + inputTopic2.pipeInput(2, "a2-0", 201L); + + processor.checkAndClearProcessResult( + new KeyValueTimestamp<>(2, "A2+a2", 201L), + new KeyValueTimestamp<>(2, "A2-0+a2", 201L), + new KeyValueTimestamp<>(2, "A2+a2-0", 201L), + new KeyValueTimestamp<>(2, "A2-0+a2-0", 201L) + ); + + // this record should expired non-joined records; only null+a0 will be emitted because + // it did not have a join + driver.advanceWallClockTime(Duration.ofMillis(1000L)); + inputTopic2.pipeInput(3, "dummy", 1500L); + + processor.checkAndClearProcessResult( + new KeyValueTimestamp<>(1, "null+a0", 111L), + new KeyValueTimestamp<>(3, "null+dummy", 211) + ); + } + } + + @Test + public void testLeftExpiredNonJoinedRecordsAreEmittedByTheLeftProcessor() { + final StreamsBuilder builder = new StreamsBuilder(); + + final KStream stream1; + final KStream stream2; + final KStream joined; + final MockProcessorSupplier supplier = new MockProcessorSupplier<>(); + + stream1 = builder.stream(topic1, consumed); + stream2 = builder.stream(topic2, consumed); + joined = stream1.outerJoin( + stream2, + MockValueJoiner.TOSTRING_JOINER, + JoinWindows.ofTimeDifferenceWithNoGrace(ofMillis(100L)), + StreamJoined.with(Serdes.Integer(), Serdes.String(), Serdes.String()) + ); + joined.process(supplier); + + try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) { + final TestInputTopic inputTopic1 = + driver.createInputTopic(topic1, new IntegerSerializer(), new StringSerializer(), Instant.ofEpochMilli(0L), Duration.ZERO); + final TestInputTopic inputTopic2 = + driver.createInputTopic(topic2, new IntegerSerializer(), new StringSerializer(), Instant.ofEpochMilli(0L), Duration.ZERO); + final MockProcessor processor = supplier.theCapturedProcessor(); + + final long windowStart = 0L; + + // No joins detected; No null-joins emitted + inputTopic1.pipeInput(0, "A0", windowStart + 1L); + inputTopic1.pipeInput(1, "A1", windowStart + 2L); + inputTopic1.pipeInput(0, "A0-0", windowStart + 3L); + processor.checkAndClearProcessResult(); + + // Join detected; No null-joins emitted + inputTopic2.pipeInput(1, "a1", windowStart + 3L); + processor.checkAndClearProcessResult( + new KeyValueTimestamp<>(1, "A1+a1", windowStart + 3L) + ); + + // Dummy record in left topic will emit expired non-joined records from the left topic + inputTopic1.pipeInput(2, "dummy", windowStart + 401L); + processor.checkAndClearProcessResult( + new KeyValueTimestamp<>(0, "A0+null", windowStart + 1L), + new KeyValueTimestamp<>(0, "A0-0+null", windowStart + 3L) + ); + + // Flush internal non-joined state store by joining the dummy record + inputTopic2.pipeInput(2, "dummy", windowStart + 401L); + processor.checkAndClearProcessResult( + new KeyValueTimestamp<>(2, "dummy+dummy", windowStart + 401L) + ); + } + } + + @Test + public void testLeftExpiredNonJoinedRecordsAreEmittedByTheRightProcessor() { + final StreamsBuilder builder = new StreamsBuilder(); + + final KStream stream1; + final KStream stream2; + final KStream joined; + final MockProcessorSupplier supplier = new MockProcessorSupplier<>(); + + stream1 = builder.stream(topic1, consumed); + stream2 = builder.stream(topic2, consumed); + joined = stream1.outerJoin( + stream2, + MockValueJoiner.TOSTRING_JOINER, + 
JoinWindows.ofTimeDifferenceAndGrace(ofMillis(100L), ofMillis(0L)), + StreamJoined.with(Serdes.Integer(), Serdes.String(), Serdes.String()) + ); + joined.process(supplier); + + try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) { + final TestInputTopic inputTopic1 = + driver.createInputTopic(topic1, new IntegerSerializer(), new StringSerializer(), Instant.ofEpochMilli(0L), Duration.ZERO); + final TestInputTopic inputTopic2 = + driver.createInputTopic(topic2, new IntegerSerializer(), new StringSerializer(), Instant.ofEpochMilli(0L), Duration.ZERO); + final MockProcessor processor = supplier.theCapturedProcessor(); + + final long windowStart = 0L; + + // No joins detected; No null-joins emitted + inputTopic1.pipeInput(0, "A0", windowStart + 1L); + inputTopic1.pipeInput(1, "A1", windowStart + 2L); + inputTopic1.pipeInput(0, "A0-0", windowStart + 3L); + processor.checkAndClearProcessResult(); + + // Join detected; No null-joins emitted + inputTopic2.pipeInput(1, "a1", windowStart + 3L); + processor.checkAndClearProcessResult( + new KeyValueTimestamp<>(1, "A1+a1", windowStart + 3L) + ); + + // Dummy record in right topic will emit expired non-joined records from the left topic + inputTopic2.pipeInput(2, "dummy", windowStart + 401L); + processor.checkAndClearProcessResult( + new KeyValueTimestamp<>(0, "A0+null", windowStart + 1L), + new KeyValueTimestamp<>(0, "A0-0+null", windowStart + 3L) + ); + + // Flush internal non-joined state store by joining the dummy record + inputTopic1.pipeInput(2, "dummy", windowStart + 402L); + processor.checkAndClearProcessResult( + new KeyValueTimestamp<>(2, "dummy+dummy", windowStart + 402L) + ); + } + } + + @Test + public void testRightExpiredNonJoinedRecordsAreEmittedByTheLeftProcessor() { + final StreamsBuilder builder = new StreamsBuilder(); + + final KStream stream1; + final KStream stream2; + final KStream joined; + final MockProcessorSupplier supplier = new MockProcessorSupplier<>(); + + stream1 = builder.stream(topic1, consumed); + stream2 = builder.stream(topic2, consumed); + joined = stream1.outerJoin( + stream2, + MockValueJoiner.TOSTRING_JOINER, + JoinWindows.ofTimeDifferenceWithNoGrace(ofMillis(100L)), + StreamJoined.with(Serdes.Integer(), Serdes.String(), Serdes.String()) + ); + joined.process(supplier); + + try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) { + final TestInputTopic inputTopic1 = + driver.createInputTopic(topic1, new IntegerSerializer(), new StringSerializer(), Instant.ofEpochMilli(0L), Duration.ZERO); + final TestInputTopic inputTopic2 = + driver.createInputTopic(topic2, new IntegerSerializer(), new StringSerializer(), Instant.ofEpochMilli(0L), Duration.ZERO); + final MockProcessor processor = supplier.theCapturedProcessor(); + + final long windowStart = 0L; + + // No joins detected; No null-joins emitted + inputTopic2.pipeInput(0, "A0", windowStart + 1L); + inputTopic2.pipeInput(1, "A1", windowStart + 2L); + inputTopic2.pipeInput(0, "A0-0", windowStart + 3L); + processor.checkAndClearProcessResult(); + + // Join detected; No null-joins emitted + inputTopic1.pipeInput(1, "a1", windowStart + 3L); + processor.checkAndClearProcessResult( + new KeyValueTimestamp<>(1, "a1+A1", windowStart + 3L) + ); + + // Dummy record in left topic will emit expired non-joined records from the right topic + inputTopic1.pipeInput(2, "dummy", windowStart + 401L); + processor.checkAndClearProcessResult( + new KeyValueTimestamp<>(0, "null+A0", windowStart + 1L), + new 
KeyValueTimestamp<>(0, "null+A0-0", windowStart + 3L) + ); + + // Process the dummy joined record + inputTopic2.pipeInput(2, "dummy", windowStart + 402L); + processor.checkAndClearProcessResult( + new KeyValueTimestamp<>(2, "dummy+dummy", windowStart + 402L) + ); + } + } + + @Test + public void testRightExpiredNonJoinedRecordsAreEmittedByTheRightProcessor() { + final StreamsBuilder builder = new StreamsBuilder(); + + final KStream stream1; + final KStream stream2; + final KStream joined; + final MockProcessorSupplier supplier = new MockProcessorSupplier<>(); + + stream1 = builder.stream(topic1, consumed); + stream2 = builder.stream(topic2, consumed); + joined = stream1.outerJoin( + stream2, + MockValueJoiner.TOSTRING_JOINER, + JoinWindows.ofTimeDifferenceAndGrace(ofMillis(100L), ofMillis(0L)), + StreamJoined.with(Serdes.Integer(), Serdes.String(), Serdes.String()) + ); + joined.process(supplier); + + try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) { + final TestInputTopic inputTopic1 = + driver.createInputTopic(topic1, new IntegerSerializer(), new StringSerializer(), Instant.ofEpochMilli(0L), Duration.ZERO); + final TestInputTopic inputTopic2 = + driver.createInputTopic(topic2, new IntegerSerializer(), new StringSerializer(), Instant.ofEpochMilli(0L), Duration.ZERO); + final MockProcessor processor = supplier.theCapturedProcessor(); + + final long windowStart = 0L; + + // No joins detected; No null-joins emitted + inputTopic2.pipeInput(0, "A0", windowStart + 1L); + inputTopic2.pipeInput(1, "A1", windowStart + 2L); + inputTopic2.pipeInput(0, "A0-0", windowStart + 3L); + processor.checkAndClearProcessResult(); + + // Join detected; No null-joins emitted + inputTopic1.pipeInput(1, "a1", windowStart + 3L); + processor.checkAndClearProcessResult( + new KeyValueTimestamp<>(1, "a1+A1", windowStart + 3L) + ); + + // Dummy record in right topic will emit expired non-joined records from the right topic + inputTopic2.pipeInput(2, "dummy", windowStart + 401L); + processor.checkAndClearProcessResult( + new KeyValueTimestamp<>(0, "null+A0", windowStart + 1L), + new KeyValueTimestamp<>(0, "null+A0-0", windowStart + 3L) + ); + + // Process the dummy joined record + inputTopic1.pipeInput(2, "dummy", windowStart + 402L); + processor.checkAndClearProcessResult( + new KeyValueTimestamp<>(2, "dummy+dummy", windowStart + 402L) + ); + } + } + + @Test + public void testOrdering() { + final StreamsBuilder builder = new StreamsBuilder(); + + final KStream stream1; + final KStream stream2; + final KStream joined; + final MockProcessorSupplier supplier = new MockProcessorSupplier<>(); + stream1 = builder.stream(topic1, consumed); + stream2 = builder.stream(topic2, consumed); + + joined = stream1.outerJoin( + stream2, + MockValueJoiner.TOSTRING_JOINER, + JoinWindows.ofTimeDifferenceWithNoGrace(ofMillis(100)), + StreamJoined.with(Serdes.Integer(), Serdes.String(), Serdes.String()) + ); + joined.process(supplier); + + try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) { + final TestInputTopic inputTopic1 = + driver.createInputTopic(topic1, new IntegerSerializer(), new StringSerializer(), Instant.ofEpochMilli(0L), Duration.ZERO); + final TestInputTopic inputTopic2 = + driver.createInputTopic(topic2, new IntegerSerializer(), new StringSerializer(), Instant.ofEpochMilli(0L), Duration.ZERO); + final MockProcessor processor = supplier.theCapturedProcessor(); + + // push two items to the primary stream; the other window is empty; this should not 
produce any item yet + // w1 = {} + // w2 = {} + // --> w1 = { 0:A0 (ts: 0), 1:A1 (ts: 100) } + // --> w2 = {} + inputTopic1.pipeInput(0, "A0", 0L); + inputTopic1.pipeInput(1, "A1", 100L); + processor.checkAndClearProcessResult(); + + // push one item to the other window that has a join; this should produce non-joined records with a closed window first, then + // the joined records + // by the time they were produced before + // w1 = { 0:A0 (ts: 0), 1:A1 (ts: 100) } + // w2 = { } + // --> w1 = { 0:A0 (ts: 0), 1:A1 (ts: 0) } + // --> w2 = { 0:a0 (ts: 110) } + inputTopic2.pipeInput(1, "a1", 110L); + processor.checkAndClearProcessResult( + new KeyValueTimestamp<>(0, "A0+null", 0L), + new KeyValueTimestamp<>(1, "A1+a1", 110L) + ); + } + } + + @Test + public void testGracePeriod() { + final StreamsBuilder builder = new StreamsBuilder(); + + final KStream stream1; + final KStream stream2; + final KStream joined; + final MockProcessorSupplier supplier = new MockProcessorSupplier<>(); + stream1 = builder.stream(topic1, consumed); + stream2 = builder.stream(topic2, consumed); + + joined = stream1.outerJoin( + stream2, + MockValueJoiner.TOSTRING_JOINER, + JoinWindows.ofTimeDifferenceAndGrace(ofMillis(100), ofMillis(10)), + StreamJoined.with(Serdes.Integer(), Serdes.String(), Serdes.String()) + ); + joined.process(supplier); + + final Collection> copartitionGroups = + TopologyWrapper.getInternalTopologyBuilder(builder.build()).copartitionGroups(); + + assertEquals(1, copartitionGroups.size()); + assertEquals(new HashSet<>(Arrays.asList(topic1, topic2)), copartitionGroups.iterator().next()); + + try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) { + final TestInputTopic inputTopic1 = + driver.createInputTopic(topic1, new IntegerSerializer(), new StringSerializer(), Instant.ofEpochMilli(0L), Duration.ZERO); + final TestInputTopic inputTopic2 = + driver.createInputTopic(topic2, new IntegerSerializer(), new StringSerializer(), Instant.ofEpochMilli(0L), Duration.ZERO); + final MockProcessor processor = supplier.theCapturedProcessor(); + + // push one item to the primary stream; and one item in other stream; this should not produce items because there are no joins + // and window has not ended + // w1 = {} + // w2 = {} + // --> w1 = { 0:A0 (ts: 0) } + // --> w2 = { 1:a1 (ts: 0) } + inputTopic1.pipeInput(0, "A0", 0L); + inputTopic2.pipeInput(1, "a1", 0L); + processor.checkAndClearProcessResult(); + + // push one item on each stream with a window time after the previous window ended (not closed); this should not produce + // joined records because the window has ended, but will not produce non-joined records because the window has not closed. 
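+ // (With a 100 ms time difference and a 10 ms grace period, the record at ts=0 can only join
+ // partners up to ts=100, and its window does not close before stream time reaches 0 + 100 + 10 = 110;
+ // hence the records at ts=101 yield neither joined nor null-padded output at this point, and the
+ // expired "null+a1" / "A0+null" results are only flushed by the dummy record pushed below.)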
+ // w1 = { 0:A0 (ts: 0) } + // w2 = { 1:a1 (ts: 0) } + // --> w1 = { 0:A0 (ts: 0), 1:A1 (ts: 0) } + // --> w2 = { 0:a0 (ts: 101), 1:a1 (ts: 101) } + inputTopic2.pipeInput(0, "a0", 101L); + inputTopic1.pipeInput(1, "A1", 101L); + processor.checkAndClearProcessResult(); + + // push a dummy item to the any stream after the window is closed; this should produced all expired non-joined records because + // the window has closed + // w1 = { 0:A0 (ts: 0), 1:A1 (ts: 0) } + // w2 = { 0:a0 (ts: 101), 1:a1 (ts: 101) } + // --> w1 = { 0:A0 (ts: 0), 1:A1 (ts: 0) } + // --> w2 = { 0:a0 (ts: 101), 1:a1 (ts: 101), 0:dummy (ts: 112) } + inputTopic2.pipeInput(0, "dummy", 211); + processor.checkAndClearProcessResult( + new KeyValueTimestamp<>(1, "null+a1", 0L), + new KeyValueTimestamp<>(0, "A0+null", 0L) + ); + } + } + + @Test + public void testOuterJoinWithInMemoryCustomSuppliers() { + final JoinWindows joinWindows = JoinWindows.ofTimeDifferenceWithNoGrace(ofMillis(100L)); + + final WindowBytesStoreSupplier thisStoreSupplier = Stores.inMemoryWindowStore( + "in-memory-join-store", + Duration.ofMillis(joinWindows.size() + joinWindows.gracePeriodMs()), + Duration.ofMillis(joinWindows.size()), + true + ); + + final WindowBytesStoreSupplier otherStoreSupplier = Stores.inMemoryWindowStore( + "in-memory-join-store-other", + Duration.ofMillis(joinWindows.size() + joinWindows.gracePeriodMs()), + Duration.ofMillis(joinWindows.size()), + true + ); + + final StreamJoined streamJoined = StreamJoined.with(Serdes.Integer(), Serdes.String(), Serdes.String()); + + runOuterJoin(streamJoined.withThisStoreSupplier(thisStoreSupplier).withOtherStoreSupplier(otherStoreSupplier), joinWindows); + } + + @Test + public void testOuterJoinWithDefaultSuppliers() { + final JoinWindows joinWindows = JoinWindows.ofTimeDifferenceWithNoGrace(ofMillis(100L)); + final StreamJoined streamJoined = StreamJoined.with(Serdes.Integer(), Serdes.String(), Serdes.String()); + + runOuterJoin(streamJoined, joinWindows); + } + + public void runOuterJoin(final StreamJoined streamJoined, + final JoinWindows joinWindows) { + final StreamsBuilder builder = new StreamsBuilder(); + + final int[] expectedKeys = new int[] {0, 1, 2, 3}; + + final KStream stream1; + final KStream stream2; + final KStream joined; + final MockProcessorSupplier supplier = new MockProcessorSupplier<>(); + stream1 = builder.stream(topic1, consumed); + stream2 = builder.stream(topic2, consumed); + + joined = stream1.outerJoin( + stream2, + MockValueJoiner.TOSTRING_JOINER, + joinWindows, + streamJoined + ); + joined.process(supplier); + + final Collection> copartitionGroups = + TopologyWrapper.getInternalTopologyBuilder(builder.build()).copartitionGroups(); + + assertEquals(1, copartitionGroups.size()); + assertEquals(new HashSet<>(Arrays.asList(topic1, topic2)), copartitionGroups.iterator().next()); + + try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) { + final TestInputTopic inputTopic1 = + driver.createInputTopic(topic1, new IntegerSerializer(), new StringSerializer(), Instant.ofEpochMilli(0L), Duration.ZERO); + final TestInputTopic inputTopic2 = + driver.createInputTopic(topic2, new IntegerSerializer(), new StringSerializer(), Instant.ofEpochMilli(0L), Duration.ZERO); + final MockProcessor processor = supplier.theCapturedProcessor(); + + // 2 window stores + 1 shared window store should be available + assertEquals(3, driver.getAllStateStores().size()); + + // push two items to the primary stream; the other window is empty; this should not + // 
produce any items because window has not expired + // w1 {} + // w2 {} + // --> w1 = { 0:A0, 1:A1 } + // --> w2 = {} + for (int i = 0; i < 2; i++) { + inputTopic1.pipeInput(expectedKeys[i], "A" + expectedKeys[i]); + } + processor.checkAndClearProcessResult(); + + // push two items to the other stream; this should produce two full-joined items + // w1 = { 0:A0, 1:A1 } + // w2 {} + // --> w1 = { 0:A0, 1:A1 } + // --> w2 = { 0:a0, 1:a1 } + for (int i = 0; i < 2; i++) { + inputTopic2.pipeInput(expectedKeys[i], "a" + expectedKeys[i]); + } + processor.checkAndClearProcessResult( + new KeyValueTimestamp<>(0, "A0+a0", 0L), + new KeyValueTimestamp<>(1, "A1+a1", 0L) + ); + + // push three items to the primary stream; this should produce two full-joined items + // w1 = { 0:A0, 1:A1 } + // w2 = { 0:a0, 1:a1 } + // --> w1 = { 0:A0, 1:A1, 0:B0, 1:B1, 2:B2 } + // --> w2 = { 0:a0, 1:a1 } + for (int i = 0; i < 3; i++) { + inputTopic1.pipeInput(expectedKeys[i], "B" + expectedKeys[i]); + } + processor.checkAndClearProcessResult( + new KeyValueTimestamp<>(0, "B0+a0", 0L), + new KeyValueTimestamp<>(1, "B1+a1", 0L) + ); + + // push all items to the other stream; this should produce five full-joined items + // w1 = { 0:A0, 1:A1, 0:B0, 1:B1, 2:B2 } + // w2 = { 0:a0, 1:a1 } + // --> w1 = { 0:A0, 1:A1, 0:B0, 1:B1, 2:B2 } + // --> w2 = { 0:a0, 1:a1, 0:b0, 1:b1, 2:b2, 3:b3 } + for (final int expectedKey : expectedKeys) { + inputTopic2.pipeInput(expectedKey, "b" + expectedKey); + } + processor.checkAndClearProcessResult( + new KeyValueTimestamp<>(0, "A0+b0", 0L), + new KeyValueTimestamp<>(0, "B0+b0", 0L), + new KeyValueTimestamp<>(1, "A1+b1", 0L), + new KeyValueTimestamp<>(1, "B1+b1", 0L), + new KeyValueTimestamp<>(2, "B2+b2", 0L) + ); + + // push all four items to the primary stream; this should produce six full-joined items + // w1 = { 0:A0, 1:A1, 0:B0, 1:B1, 2:B2 } + // w2 = { 0:a0, 1:a1, 0:b0, 1:b1, 2:b2, 3:b3 } + // --> w1 = { 0:A0, 1:A1, 0:B0, 1:B1, 2:B2, 0:C0, 1:C1, 2:C2, 3:C3 } + // --> w2 = { 0:a0, 1:a1, 0:b0, 1:b1, 2:b2, 3:b3 } + for (final int expectedKey : expectedKeys) { + inputTopic1.pipeInput(expectedKey, "C" + expectedKey); + } + processor.checkAndClearProcessResult( + new KeyValueTimestamp<>(0, "C0+a0", 0L), + new KeyValueTimestamp<>(0, "C0+b0", 0L), + new KeyValueTimestamp<>(1, "C1+a1", 0L), + new KeyValueTimestamp<>(1, "C1+b1", 0L), + new KeyValueTimestamp<>(2, "C2+b2", 0L), + new KeyValueTimestamp<>(3, "C3+b3", 0L) + ); + + // push a dummy record that should expire non-joined items; it should not produce any items because + // all of them are joined + inputTopic1.pipeInput(0, "dummy", 400L); + processor.checkAndClearProcessResult(); + } + } + + @Test + public void testWindowing() { + final StreamsBuilder builder = new StreamsBuilder(); + final int[] expectedKeys = new int[]{0, 1, 2, 3}; + + final KStream stream1; + final KStream stream2; + final KStream joined; + final MockProcessorSupplier supplier = new MockProcessorSupplier<>(); + stream1 = builder.stream(topic1, consumed); + stream2 = builder.stream(topic2, consumed); + + joined = stream1.outerJoin( + stream2, + MockValueJoiner.TOSTRING_JOINER, + JoinWindows.ofTimeDifferenceWithNoGrace(ofMillis(100)), + StreamJoined.with(Serdes.Integer(), Serdes.String(), Serdes.String()) + ); + joined.process(supplier); + + final Collection> copartitionGroups = + TopologyWrapper.getInternalTopologyBuilder(builder.build()).copartitionGroups(); + + assertEquals(1, copartitionGroups.size()); + assertEquals(new HashSet<>(Arrays.asList(topic1, topic2)), 
copartitionGroups.iterator().next()); + + try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) { + final TestInputTopic inputTopic1 = + driver.createInputTopic(topic1, new IntegerSerializer(), new StringSerializer(), Instant.ofEpochMilli(0L), Duration.ZERO); + final TestInputTopic inputTopic2 = + driver.createInputTopic(topic2, new IntegerSerializer(), new StringSerializer(), Instant.ofEpochMilli(0L), Duration.ZERO); + final MockProcessor processor = supplier.theCapturedProcessor(); + final long time = 0L; + + // push two items to the primary stream; the other window is empty; this should not produce items because window has not closed + // w1 = {} + // w2 = {} + // --> w1 = { 0:A0 (ts: 0), 1:A1 (ts: 0) } + // --> w2 = {} + for (int i = 0; i < 2; i++) { + inputTopic1.pipeInput(expectedKeys[i], "A" + expectedKeys[i], time); + } + processor.checkAndClearProcessResult(); + + // push four items to the other stream; this should produce two full-join items + // w1 = { 0:A0 (ts: 0), 1:A1 (ts: 0) } + // w2 = {} + // --> w1 = { 0:A0 (ts: 0), 1:A1 (ts: 0) } + // --> w2 = { 0:a0 (ts: 0), 1:a1 (ts: 0), 2:a2 (ts: 0), 3:a3 (ts: 0) } + for (final int expectedKey : expectedKeys) { + inputTopic2.pipeInput(expectedKey, "a" + expectedKey, time); + } + processor.checkAndClearProcessResult( + new KeyValueTimestamp<>(0, "A0+a0", 0L), + new KeyValueTimestamp<>(1, "A1+a1", 0L) + ); + + testUpperWindowBound(expectedKeys, driver, processor); + testLowerWindowBound(expectedKeys, driver, processor); + } + } + + private void testUpperWindowBound(final int[] expectedKeys, + final TopologyTestDriver driver, + final MockProcessor processor) { + long time; + + final TestInputTopic inputTopic1 = + driver.createInputTopic(topic1, new IntegerSerializer(), new StringSerializer(), Instant.ofEpochMilli(0L), Duration.ZERO); + final TestInputTopic inputTopic2 = + driver.createInputTopic(topic2, new IntegerSerializer(), new StringSerializer(), Instant.ofEpochMilli(0L), Duration.ZERO); + + // push four items with larger and increasing timestamp (out of window) to the other stream; this should produced 2 expired non-joined records + // w1 = { 0:A0 (ts: 0), 1:A1 (ts: 0) } + // w2 = { 0:a0 (ts: 0), 1:a1 (ts: 0), 2:a2 (ts: 0), 3:a3 (ts: 0) } + // --> w1 = { 0:A0 (ts: 0), 1:A1 (ts: 0) } + // --> w2 = { 0:a0 (ts: 0), 1:a1 (ts: 0), 2:a2 (ts: 0), 3:a3 (ts: 0), + // 0:b0 (ts: 1000), 1:b1 (ts: 1001), 2:b2 (ts: 1002), 3:b3 (ts: 1003) } + time = 1000L; + for (int i = 0; i < expectedKeys.length; i++) { + inputTopic2.pipeInput(expectedKeys[i], "b" + expectedKeys[i], time + i); + } + processor.checkAndClearProcessResult( + new KeyValueTimestamp<>(2, "null+a2", 0L), + new KeyValueTimestamp<>(3, "null+a3", 0L) + ); + + // push four items with larger timestamp to the primary stream; this should produce four full-join items + // w1 = { 0:A0 (ts: 0), 1:A1 (ts: 0) } + // w2 = { 0:a0 (ts: 0), 1:a1 (ts: 0), 2:a2 (ts: 0), 3:a3 (ts: 0), + // 0:b0 (ts: 1000), 1:b1 (ts: 1001), 2:b2 (ts: 1002), 3:b3 (ts: 1003) } + // --> w1 = { 0:A0 (ts: 0), 1:A1 (ts: 0), + // 0:B0 (ts: 1100), 1:B1 (ts: 1100), 2:B2 (ts: 1100), 3:B3 (ts: 1100) } + // --> w2 = { 0:a0 (ts: 0), 1:a1 (ts: 0), 2:a2 (ts: 0), 3:a3 (ts: 0), + // 0:b0 (ts: 1000), 1:b1 (ts: 1001), 2:b2 (ts: 1002), 3:b3 (ts: 1003) } + time = 1000L + 100L; + for (final int expectedKey : expectedKeys) { + inputTopic1.pipeInput(expectedKey, "B" + expectedKey, time); + } + processor.checkAndClearProcessResult( + new KeyValueTimestamp<>(0, "B0+b0", 1100L), + new KeyValueTimestamp<>(1, 
"B1+b1", 1100L), + new KeyValueTimestamp<>(2, "B2+b2", 1100L), + new KeyValueTimestamp<>(3, "B3+b3", 1100L) + ); + + // push four items with increased timestamp to the primary stream; this should produce three full-join items (non-joined item is not produced yet) + // w1 = { 0:A0 (ts: 0), 1:A1 (ts: 0), + // 0:B0 (ts: 1100), 1:B1 (ts: 1100), 2:B2 (ts: 1100), 3:B3 (ts: 1100) } + // w2 = { 0:a0 (ts: 0), 1:a1 (ts: 0), 2:a2 (ts: 0), 3:a3 (ts: 0), + // 0:b0 (ts: 1000), 1:b1 (ts: 1001), 2:b2 (ts: 1002), 3:b3 (ts: 1003) } + // --> w1 = { 0:A0 (ts: 0), 1:A1 (ts: 0), + // 0:B0 (ts: 1100), 1:B1 (ts: 1100), 2:B2 (ts: 1100), 3:B3 (ts: 1100), + // 0:C0 (ts: 1101), 1:C1 (ts: 1101), 2:C2 (ts: 1101), 3:C3 (ts: 1101) } + // --> w2 = { 0:a0 (ts: 0), 1:a1 (ts: 0), 2:a2 (ts: 0), 3:a3 (ts: 0), + // 0:b0 (ts: 1000), 1:b1 (ts: 1001), 2:b2 (ts: 1002), 3:b3 (ts: 1003) } + time += 1L; + for (final int expectedKey : expectedKeys) { + inputTopic1.pipeInput(expectedKey, "C" + expectedKey, time); + } + processor.checkAndClearProcessResult( + new KeyValueTimestamp<>(1, "C1+b1", 1101L), + new KeyValueTimestamp<>(2, "C2+b2", 1101L), + new KeyValueTimestamp<>(3, "C3+b3", 1101L) + ); + + // push four items with increased timestamp to the primary stream; this should produce two full-join items (non-joined items are not produced yet) + // w1 = { 0:A0 (ts: 0), 1:A1 (ts: 0), + // 0:B0 (ts: 1100), 1:B1 (ts: 1100), 2:B2 (ts: 1100), 3:B3 (ts: 1100), + // 0:C0 (ts: 1101), 1:C1 (ts: 1101), 2:C2 (ts: 1101), 3:C3 (ts: 1101) } + // w2 = { 0:a0 (ts: 0), 1:a1 (ts: 0), 2:a2 (ts: 0), 3:a3 (ts: 0), + // 0:b0 (ts: 1000), 1:b1 (ts: 1001), 2:b2 (ts: 1002), 3:b3 (ts: 1003) } + // --> w1 = { 0:A0 (ts: 0), 1:A1 (ts: 0), + // 0:B0 (ts: 1100), 1:B1 (ts: 1100), 2:B2 (ts: 1100), 3:B3 (ts: 1100), + // 0:C0 (ts: 1101), 1:C1 (ts: 1101), 2:C2 (ts: 1101), 3:C3 (ts: 1101), + // 0:D0 (ts: 1102), 1:D1 (ts: 1102), 2:D2 (ts: 1102), 3:D3 (ts: 1102) } + // --> w2 = { 0:a0 (ts: 0), 1:a1 (ts: 0), 2:a2 (ts: 0), 3:a3 (ts: 0), + // 0:b0 (ts: 1000), 1:b1 (ts: 1001), 2:b2 (ts: 1002), 3:b3 (ts: 1003) } + time += 1L; + for (final int expectedKey : expectedKeys) { + inputTopic1.pipeInput(expectedKey, "D" + expectedKey, time); + } + processor.checkAndClearProcessResult( + new KeyValueTimestamp<>(2, "D2+b2", 1102L), + new KeyValueTimestamp<>(3, "D3+b3", 1102L) + ); + + // push four items with increased timestamp to the primary stream; this should produce one full-join items (three non-joined left-join are not produced yet) + // w1 = { 0:A0 (ts: 0), 1:A1 (ts: 0), + // 0:B0 (ts: 1100), 1:B1 (ts: 1100), 2:B2 (ts: 1100), 3:B3 (ts: 1100), + // 0:C0 (ts: 1101), 1:C1 (ts: 1101), 2:C2 (ts: 1101), 3:C3 (ts: 1101), + // 0:D0 (ts: 1102), 1:D1 (ts: 1102), 2:D2 (ts: 1102), 3:D3 (ts: 1102) } + // w2 = { 0:a0 (ts: 0), 1:a1 (ts: 0), 2:a2 (ts: 0), 3:a3 (ts: 0), + // 0:b0 (ts: 1000), 1:b1 (ts: 1001), 2:b2 (ts: 1002), 3:b3 (ts: 1003) } + // --> w1 = { 0:A0 (ts: 0), 1:A1 (ts: 0), + // 0:B0 (ts: 1100), 1:B1 (ts: 1100), 2:B2 (ts: 1100), 3:B3 (ts: 1100), + // 0:C0 (ts: 1101), 1:C1 (ts: 1101), 2:C2 (ts: 1101), 3:C3 (ts: 1101), + // 0:D0 (ts: 1102), 1:D1 (ts: 1102), 2:D2 (ts: 1102), 3:D3 (ts: 1102), + // 0:E0 (ts: 1103), 1:E1 (ts: 1103), 2:E2 (ts: 1103), 3:E3 (ts: 1103) } + // --> w2 = { 0:a0 (ts: 0), 1:a1 (ts: 0), 2:a2 (ts: 0), 3:a3 (ts: 0), + // 0:b0 (ts: 1000), 1:b1 (ts: 1001), 2:b2 (ts: 1002), 3:b3 (ts: 1003) } + time += 1L; + for (final int expectedKey : expectedKeys) { + inputTopic1.pipeInput(expectedKey, "E" + expectedKey, time); + } + processor.checkAndClearProcessResult( + new 
KeyValueTimestamp<>(3, "E3+b3", 1103L) + ); + + // push four items with increased timestamp to the primary stream; this should produce no full-join items (four non-joined left-join are not produced yet) + // w1 = { 0:A0 (ts: 0), 1:A1 (ts: 0), + // 0:B0 (ts: 1100), 1:B1 (ts: 1100), 2:B2 (ts: 1100), 3:B3 (ts: 1100), + // 0:C0 (ts: 1101), 1:C1 (ts: 1101), 2:C2 (ts: 1101), 3:C3 (ts: 1101), + // 0:D0 (ts: 1102), 1:D1 (ts: 1102), 2:D2 (ts: 1102), 3:D3 (ts: 1102), + // 0:E0 (ts: 1103), 1:E1 (ts: 1103), 2:E2 (ts: 1103), 3:E3 (ts: 1103) } + // w2 = { 0:a0 (ts: 0), 1:a1 (ts: 0), 2:a2 (ts: 0), 3:a3 (ts: 0), + // 0:b0 (ts: 1000), 1:b1 (ts: 1001), 2:b2 (ts: 1002), 3:b3 (ts: 1003) } + // --> w1 = { 0:A0 (ts: 0), 1:A1 (ts: 0), + // 0:B0 (ts: 1100), 1:B1 (ts: 1100), 2:B2 (ts: 1100), 3:B3 (ts: 1100), + // 0:C0 (ts: 1101), 1:C1 (ts: 1101), 2:C2 (ts: 1101), 3:C3 (ts: 1101), + // 0:D0 (ts: 1102), 1:D1 (ts: 1102), 2:D2 (ts: 1102), 3:D3 (ts: 1102), + // 0:E0 (ts: 1103), 1:E1 (ts: 1103), 2:E2 (ts: 1103), 3:E3 (ts: 1103), + // 0:F0 (ts: 1104), 1:F1 (ts: 1104), 2:F2 (ts: 1104), 3:F3 (ts: 1104) } + // --> w2 = { 0:a0 (ts: 0), 1:a1 (ts: 0), 2:a2 (ts: 0), 3:a3 (ts: 0), + // 0:b0 (ts: 1000), 1:b1 (ts: 1001), 2:b2 (ts: 1002), 3:b3 (ts: 1003) } + time += 1L; + for (final int expectedKey : expectedKeys) { + inputTopic1.pipeInput(expectedKey, "F" + expectedKey, time); + } + processor.checkAndClearProcessResult(); + + // push a dummy record to produce all left-join non-joined items + time += 301L; + driver.advanceWallClockTime(Duration.ofMillis(1000L)); + inputTopic1.pipeInput(0, "dummy", time); + processor.checkAndClearProcessResult( + new KeyValueTimestamp<>(0, "C0+null", 1101L), + new KeyValueTimestamp<>(0, "D0+null", 1102L), + new KeyValueTimestamp<>(1, "D1+null", 1102L), + new KeyValueTimestamp<>(0, "E0+null", 1103L), + new KeyValueTimestamp<>(1, "E1+null", 1103L), + new KeyValueTimestamp<>(2, "E2+null", 1103L), + new KeyValueTimestamp<>(0, "F0+null", 1104L), + new KeyValueTimestamp<>(1, "F1+null", 1104L), + new KeyValueTimestamp<>(2, "F2+null", 1104L), + new KeyValueTimestamp<>(3, "F3+null", 1104L) + ); + } + + private void testLowerWindowBound(final int[] expectedKeys, + final TopologyTestDriver driver, + final MockProcessor processor) { + long time; + final TestInputTopic inputTopic1 = driver.createInputTopic(topic1, new IntegerSerializer(), new StringSerializer()); + + // push four items with smaller timestamp (before the window) to the primary stream; this should produce four left-join and no full-join items + // w1 = { 0:A0 (ts: 0), 1:A1 (ts: 0), + // 0:B0 (ts: 1100), 1:B1 (ts: 1100), 2:B2 (ts: 1100), 3:B3 (ts: 1100), + // 0:C0 (ts: 1101), 1:C1 (ts: 1101), 2:C2 (ts: 1101), 3:C3 (ts: 1101), + // 0:D0 (ts: 1102), 1:D1 (ts: 1102), 2:D2 (ts: 1102), 3:D3 (ts: 1102), + // 0:E0 (ts: 1103), 1:E1 (ts: 1103), 2:E2 (ts: 1103), 3:E3 (ts: 1103), + // 0:F0 (ts: 1104), 1:F1 (ts: 1104), 2:F2 (ts: 1104), 3:F3 (ts: 1104) } + // w2 = { 0:a0 (ts: 0), 1:a1 (ts: 0), 2:a2 (ts: 0), 3:a3 (ts: 0), + // 0:b0 (ts: 1000), 1:b1 (ts: 1001), 2:b2 (ts: 1002), 3:b3 (ts: 1003) } + // --> w1 = { 0:A0 (ts: 0), 1:A1 (ts: 0), + // 0:B0 (ts: 1100), 1:B1 (ts: 1100), 2:B2 (ts: 1100), 3:B3 (ts: 1100), + // 0:C0 (ts: 1101), 1:C1 (ts: 1101), 2:C2 (ts: 1101), 3:C3 (ts: 1101), + // 0:D0 (ts: 1102), 1:D1 (ts: 1102), 2:D2 (ts: 1102), 3:D3 (ts: 1102), + // 0:E0 (ts: 1103), 1:E1 (ts: 1103), 2:E2 (ts: 1103), 3:E3 (ts: 1103), + // 0:F0 (ts: 1104), 1:F1 (ts: 1104), 2:F2 (ts: 1104), 3:F3 (ts: 1104), + // 0:G0 (ts: 899), 1:G1 (ts: 899), 2:G2 (ts: 899), 3:G3 (ts: 899) 
} + // --> w2 = { 0:a0 (ts: 0), 1:a1 (ts: 0), 2:a2 (ts: 0), 3:a3 (ts: 0), + // 0:b0 (ts: 1000), 1:b1 (ts: 1001), 2:b2 (ts: 1002), 3:b3 (ts: 1003) } + time = 1000L - 100L - 1L; + for (final int expectedKey : expectedKeys) { + inputTopic1.pipeInput(expectedKey, "G" + expectedKey, time); + } + processor.checkAndClearProcessResult( + new KeyValueTimestamp<>(0, "G0+null", 899L), + new KeyValueTimestamp<>(1, "G1+null", 899L), + new KeyValueTimestamp<>(2, "G2+null", 899L), + new KeyValueTimestamp<>(3, "G3+null", 899L) + ); + + // push four items with increase timestamp to the primary stream; this should produce three left-join and one full-join items + // w1 = { 0:A0 (ts: 0), 1:A1 (ts: 0), + // 0:B0 (ts: 1100), 1:B1 (ts: 1100), 2:B2 (ts: 1100), 3:B3 (ts: 1100), + // 0:C0 (ts: 1101), 1:C1 (ts: 1101), 2:C2 (ts: 1101), 3:C3 (ts: 1101), + // 0:D0 (ts: 1102), 1:D1 (ts: 1102), 2:D2 (ts: 1102), 3:D3 (ts: 1102), + // 0:E0 (ts: 1103), 1:E1 (ts: 1103), 2:E2 (ts: 1103), 3:E3 (ts: 1103), + // 0:F0 (ts: 1104), 1:F1 (ts: 1104), 2:F2 (ts: 1104), 3:F3 (ts: 1104), + // 0:G0 (ts: 899), 1:G1 (ts: 899), 2:G2 (ts: 899), 3:G3 (ts: 899) } + // w2 = { 0:a0 (ts: 0), 1:a1 (ts: 0), 2:a2 (ts: 0), 3:a3 (ts: 0), + // 0:b0 (ts: 1000), 1:b1 (ts: 1001), 2:b2 (ts: 1002), 3:b3 (ts: 1003) } + // --> w1 = { 0:A0 (ts: 0), 1:A1 (ts: 0), + // 0:B0 (ts: 1100), 1:B1 (ts: 1100), 2:B2 (ts: 1100), 3:B3 (ts: 1100), + // 0:C0 (ts: 1101), 1:C1 (ts: 1101), 2:C2 (ts: 1101), 3:C3 (ts: 1101), + // 0:D0 (ts: 1102), 1:D1 (ts: 1102), 2:D2 (ts: 1102), 3:D3 (ts: 1102), + // 0:E0 (ts: 1103), 1:E1 (ts: 1103), 2:E2 (ts: 1103), 3:E3 (ts: 1103), + // 0:F0 (ts: 1104), 1:F1 (ts: 1104), 2:F2 (ts: 1104), 3:F3 (ts: 1104), + // 0:G0 (ts: 899), 1:G1 (ts: 899), 2:G2 (ts: 899), 3:G3 (ts: 899), + // 0:H0 (ts: 900), 1:H1 (ts: 900), 2:H2 (ts: 900), 3:H3 (ts: 900) } + // --> w2 = { 0:a0 (ts: 0), 1:a1 (ts: 0), 2:a2 (ts: 0), 3:a3 (ts: 0), + // 0:b0 (ts: 1000), 1:b1 (ts: 1001), 2:b2 (ts: 1002), 3:b3 (ts: 1003) } + time += 1L; + for (final int expectedKey : expectedKeys) { + inputTopic1.pipeInput(expectedKey, "H" + expectedKey, time); + } + processor.checkAndClearProcessResult( + new KeyValueTimestamp<>(0, "H0+b0", 1000L), + new KeyValueTimestamp<>(1, "H1+null", 900L), + new KeyValueTimestamp<>(2, "H2+null", 900L), + new KeyValueTimestamp<>(3, "H3+null", 900L) + ); + + // push four items with increase timestamp to the primary stream; this should produce two left-join and two full-join items + // w1 = { 0:A0 (ts: 0), 1:A1 (ts: 0), + // 0:B0 (ts: 1100), 1:B1 (ts: 1100), 2:B2 (ts: 1100), 3:B3 (ts: 1100), + // 0:C0 (ts: 1101), 1:C1 (ts: 1101), 2:C2 (ts: 1101), 3:C3 (ts: 1101), + // 0:D0 (ts: 1102), 1:D1 (ts: 1102), 2:D2 (ts: 1102), 3:D3 (ts: 1102), + // 0:E0 (ts: 1103), 1:E1 (ts: 1103), 2:E2 (ts: 1103), 3:E3 (ts: 1103), + // 0:F0 (ts: 1104), 1:F1 (ts: 1104), 2:F2 (ts: 1104), 3:F3 (ts: 1104), + // 0:G0 (ts: 899), 1:G1 (ts: 899), 2:G2 (ts: 899), 3:G3 (ts: 899), + // 0:H0 (ts: 900), 1:H1 (ts: 900), 2:H2 (ts: 900), 3:H3 (ts: 900) } + // w2 = { 0:a0 (ts: 0), 1:a1 (ts: 0), 2:a2 (ts: 0), 3:a3 (ts: 0), + // 0:b0 (ts: 1000), 1:b1 (ts: 1001), 2:b2 (ts: 1002), 3:b3 (ts: 1003) } + // --> w1 = { 0:A0 (ts: 0), 1:A1 (ts: 0), + // 0:B0 (ts: 1100), 1:B1 (ts: 1100), 2:B2 (ts: 1100), 3:B3 (ts: 1100), + // 0:C0 (ts: 1101), 1:C1 (ts: 1101), 2:C2 (ts: 1101), 3:C3 (ts: 1101), + // 0:D0 (ts: 1102), 1:D1 (ts: 1102), 2:D2 (ts: 1102), 3:D3 (ts: 1102), + // 0:E0 (ts: 1103), 1:E1 (ts: 1103), 2:E2 (ts: 1103), 3:E3 (ts: 1103), + // 0:F0 (ts: 1104), 1:F1 (ts: 1104), 2:F2 (ts: 1104), 3:F3 (ts: 1104), + // 0:G0 
(ts: 899), 1:G1 (ts: 899), 2:G2 (ts: 899), 3:G3 (ts: 899), + // 0:H0 (ts: 900), 1:H1 (ts: 900), 2:H2 (ts: 900), 3:H3 (ts: 900), + // 0:I0 (ts: 901), 1:I1 (ts: 901), 2:I2 (ts: 901), 3:I3 (ts: 901) } + // --> w2 = { 0:a0 (ts: 0), 1:a1 (ts: 0), 2:a2 (ts: 0), 3:a3 (ts: 0), + // 0:b0 (ts: 1000), 1:b1 (ts: 1001), 2:b2 (ts: 1002), 3:b3 (ts: 1003) } + time += 1L; + for (final int expectedKey : expectedKeys) { + inputTopic1.pipeInput(expectedKey, "I" + expectedKey, time); + } + processor.checkAndClearProcessResult( + new KeyValueTimestamp<>(0, "I0+b0", 1000L), + new KeyValueTimestamp<>(1, "I1+b1", 1001L), + new KeyValueTimestamp<>(2, "I2+null", 901L), + new KeyValueTimestamp<>(3, "I3+null", 901L) + ); + + // push four items with increase timestamp to the primary stream; this should produce one left-join and three full-join items + // w1 = { 0:A0 (ts: 0), 1:A1 (ts: 0), + // 0:B0 (ts: 1100), 1:B1 (ts: 1100), 2:B2 (ts: 1100), 3:B3 (ts: 1100), + // 0:C0 (ts: 1101), 1:C1 (ts: 1101), 2:C2 (ts: 1101), 3:C3 (ts: 1101), + // 0:D0 (ts: 1102), 1:D1 (ts: 1102), 2:D2 (ts: 1102), 3:D3 (ts: 1102), + // 0:E0 (ts: 1103), 1:E1 (ts: 1103), 2:E2 (ts: 1103), 3:E3 (ts: 1103), + // 0:F0 (ts: 1104), 1:F1 (ts: 1104), 2:F2 (ts: 1104), 3:F3 (ts: 1104), + // 0:G0 (ts: 899), 1:G1 (ts: 899), 2:G2 (ts: 899), 3:G3 (ts: 899), + // 0:H0 (ts: 900), 1:H1 (ts: 900), 2:H2 (ts: 900), 3:H3 (ts: 900), + // 0:I0 (ts: 901), 1:I1 (ts: 901), 2:I2 (ts: 901), 3:I3 (ts: 901) } + // w2 = { 0:a0 (ts: 0), 1:a1 (ts: 0), 2:a2 (ts: 0), 3:a3 (ts: 0), + // 0:b0 (ts: 1000), 1:b1 (ts: 1001), 2:b2 (ts: 1002), 3:b3 (ts: 1003) } + // --> w1 = { 0:A0 (ts: 0), 1:A1 (ts: 0), + // 0:B0 (ts: 1100), 1:B1 (ts: 1100), 2:B2 (ts: 1100), 3:B3 (ts: 1100), + // 0:C0 (ts: 1101), 1:C1 (ts: 1101), 2:C2 (ts: 1101), 3:C3 (ts: 1101), + // 0:D0 (ts: 1102), 1:D1 (ts: 1102), 2:D2 (ts: 1102), 3:D3 (ts: 1102), + // 0:E0 (ts: 1103), 1:E1 (ts: 1103), 2:E2 (ts: 1103), 3:E3 (ts: 1103), + // 0:F0 (ts: 1104), 1:F1 (ts: 1104), 2:F2 (ts: 1104), 3:F3 (ts: 1104), + // 0:G0 (ts: 899), 1:G1 (ts: 899), 2:G2 (ts: 899), 3:G3 (ts: 899), + // 0:H0 (ts: 900), 1:H1 (ts: 900), 2:H2 (ts: 900), 3:H3 (ts: 900), + // 0:I0 (ts: 901), 1:I1 (ts: 901), 2:I2 (ts: 901), 3:I3 (ts: 901), + // 0:J0 (ts: 902), 1:J1 (ts: 902), 2:J2 (ts: 902), 3:J3 (ts: 902) } + // --> w2 = { 0:a0 (ts: 0), 1:a1 (ts: 0), 2:a2 (ts: 0), 3:a3 (ts: 0), + // 0:b0 (ts: 1000), 1:b1 (ts: 1001), 2:b2 (ts: 1002), 3:b3 (ts: 1003) } + time += 1L; + for (final int expectedKey : expectedKeys) { + inputTopic1.pipeInput(expectedKey, "J" + expectedKey, time); + } + processor.checkAndClearProcessResult( + new KeyValueTimestamp<>(0, "J0+b0", 1000L), + new KeyValueTimestamp<>(1, "J1+b1", 1001L), + new KeyValueTimestamp<>(2, "J2+b2", 1002L), + new KeyValueTimestamp<>(3, "J3+null", 902L) + ); + + // push four items with increase timestamp to the primary stream; this should produce one left-join and three full-join items + // w1 = { 0:A0 (ts: 0), 1:A1 (ts: 0), + // 0:B0 (ts: 1100), 1:B1 (ts: 1100), 2:B2 (ts: 1100), 3:B3 (ts: 1100), + // 0:C0 (ts: 1101), 1:C1 (ts: 1101), 2:C2 (ts: 1101), 3:C3 (ts: 1101), + // 0:D0 (ts: 1102), 1:D1 (ts: 1102), 2:D2 (ts: 1102), 3:D3 (ts: 1102), + // 0:E0 (ts: 1103), 1:E1 (ts: 1103), 2:E2 (ts: 1103), 3:E3 (ts: 1103), + // 0:F0 (ts: 1104), 1:F1 (ts: 1104), 2:F2 (ts: 1104), 3:F3 (ts: 1104), + // 0:G0 (ts: 899), 1:G1 (ts: 899), 2:G2 (ts: 899), 3:G3 (ts: 899), + // 0:H0 (ts: 900), 1:H1 (ts: 900), 2:H2 (ts: 900), 3:H3 (ts: 900), + // 0:I0 (ts: 901), 1:I1 (ts: 901), 2:I2 (ts: 901), 3:I3 (ts: 901), + // 0:J0 (ts: 902), 1:J1 (ts: 902), 
2:J2 (ts: 902), 3:J3 (ts: 902) } + // w2 = { 0:a0 (ts: 0), 1:a1 (ts: 0), 2:a2 (ts: 0), 3:a3 (ts: 0), + // 0:b0 (ts: 1000), 1:b1 (ts: 1001), 2:b2 (ts: 1002), 3:b3 (ts: 1003) } + // --> w1 = { 0:A0 (ts: 0), 1:A1 (ts: 0), + // 0:B0 (ts: 1100), 1:B1 (ts: 1100), 2:B2 (ts: 1100), 3:B3 (ts: 1100), + // 0:C0 (ts: 1101), 1:C1 (ts: 1101), 2:C2 (ts: 1101), 3:C3 (ts: 1101), + // 0:D0 (ts: 1102), 1:D1 (ts: 1102), 2:D2 (ts: 1102), 3:D3 (ts: 1102), + // 0:E0 (ts: 1103), 1:E1 (ts: 1103), 2:E2 (ts: 1103), 3:E3 (ts: 1103), + // 0:F0 (ts: 1104), 1:F1 (ts: 1104), 2:F2 (ts: 1104), 3:F3 (ts: 1104), + // 0:G0 (ts: 899), 1:G1 (ts: 899), 2:G2 (ts: 899), 3:G3 (ts: 899), + // 0:H0 (ts: 900), 1:H1 (ts: 900), 2:H2 (ts: 900), 3:H3 (ts: 900), + // 0:I0 (ts: 901), 1:I1 (ts: 901), 2:I2 (ts: 901), 3:I3 (ts: 901), + // 0:J0 (ts: 902), 1:J1 (ts: 902), 2:J2 (ts: 902), 3:J3 (ts: 902), + // 0:K0 (ts: 903), 1:K1 (ts: 903), 2:K2 (ts: 903), 3:K3 (ts: 903) } + // --> w2 = { 0:a0 (ts: 0), 1:a1 (ts: 0), 2:a2 (ts: 0), 3:a3 (ts: 0), + // 0:b0 (ts: 1000), 1:b1 (ts: 1001), 2:b2 (ts: 1002), 3:b3 (ts: 1003) } + time += 1L; + for (final int expectedKey : expectedKeys) { + inputTopic1.pipeInput(expectedKey, "K" + expectedKey, time); + } + processor.checkAndClearProcessResult( + new KeyValueTimestamp<>(0, "K0+b0", 1000L), + new KeyValueTimestamp<>(1, "K1+b1", 1001L), + new KeyValueTimestamp<>(2, "K2+b2", 1002L), + new KeyValueTimestamp<>(3, "K3+b3", 1003L) + ); + + // push a dummy record to verify there are no expired records to produce + // dummy window is behind the max. stream time seen (1205 used in testUpperWindowBound) + inputTopic1.pipeInput(0, "dummy", time + 200L); + processor.checkAndClearProcessResult( + new KeyValueTimestamp<>(0, "dummy+null", 1103L) + ); + } +} \ No newline at end of file diff --git a/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KStreamKTableJoinTest.java b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KStreamKTableJoinTest.java new file mode 100644 index 0000000..3d33827 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KStreamKTableJoinTest.java @@ -0,0 +1,345 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.kstream.internals; + +import static org.hamcrest.CoreMatchers.hasItem; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.Assert.assertEquals; + +import java.time.Duration; +import java.time.Instant; +import java.util.Arrays; +import java.util.Collection; +import java.util.HashSet; +import java.util.Properties; +import java.util.Random; +import java.util.Set; +import org.apache.kafka.common.serialization.IntegerSerializer; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.serialization.StringSerializer; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.KeyValueTimestamp; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.TestInputTopic; +import org.apache.kafka.streams.Topology; +import org.apache.kafka.streams.TopologyTestDriver; +import org.apache.kafka.streams.TopologyWrapper; +import org.apache.kafka.streams.kstream.Consumed; +import org.apache.kafka.streams.kstream.Joined; +import org.apache.kafka.streams.kstream.KStream; +import org.apache.kafka.streams.kstream.KTable; +import org.apache.kafka.streams.processor.internals.testutil.LogCaptureAppender; +import org.apache.kafka.test.MockProcessor; +import org.apache.kafka.test.MockProcessorSupplier; +import org.apache.kafka.test.MockValueJoiner; +import org.apache.kafka.test.StreamsTestUtils; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +public class KStreamKTableJoinTest { + private final static KeyValueTimestamp[] EMPTY = new KeyValueTimestamp[0]; + + private final String streamTopic = "streamTopic"; + private final String tableTopic = "tableTopic"; + private TestInputTopic inputStreamTopic; + private TestInputTopic inputTableTopic; + private final int[] expectedKeys = {0, 1, 2, 3}; + + private MockProcessor processor; + private TopologyTestDriver driver; + private StreamsBuilder builder; + private final MockProcessorSupplier supplier = new MockProcessorSupplier<>(); + + @SuppressWarnings("deprecation") // Old PAPI. Needs to be migrated. 
+ @Before + public void setUp() { + builder = new StreamsBuilder(); + + final KStream stream; + final KTable table; + + final Consumed consumed = Consumed.with(Serdes.Integer(), Serdes.String()); + stream = builder.stream(streamTopic, consumed); + table = builder.table(tableTopic, consumed); + stream.join(table, MockValueJoiner.TOSTRING_JOINER).process(supplier); + final Properties props = StreamsTestUtils.getStreamsConfig(Serdes.Integer(), Serdes.String()); + driver = new TopologyTestDriver(builder.build(), props); + inputStreamTopic = driver.createInputTopic(streamTopic, new IntegerSerializer(), new StringSerializer(), Instant.ofEpochMilli(0L), Duration.ZERO); + inputTableTopic = driver.createInputTopic(tableTopic, new IntegerSerializer(), new StringSerializer(), Instant.ofEpochMilli(0L), Duration.ZERO); + + processor = supplier.theCapturedProcessor(); + } + + @After + public void cleanup() { + driver.close(); + } + + private void pushToStream(final int messageCount, final String valuePrefix) { + for (int i = 0; i < messageCount; i++) { + inputStreamTopic.pipeInput(expectedKeys[i], valuePrefix + expectedKeys[i], i); + } + } + + private void pushToTable(final int messageCount, final String valuePrefix) { + final Random r = new Random(System.currentTimeMillis()); + for (int i = 0; i < messageCount; i++) { + inputTableTopic.pipeInput( + expectedKeys[i], + valuePrefix + expectedKeys[i], + r.nextInt(Integer.MAX_VALUE)); + } + } + + private void pushNullValueToTable() { + for (int i = 0; i < 2; i++) { + inputTableTopic.pipeInput(expectedKeys[i], null); + } + } + + @Test + public void shouldReuseRepartitionTopicWithGeneratedName() { + final StreamsBuilder builder = new StreamsBuilder(); + final Properties props = new Properties(); + props.put(StreamsConfig.TOPOLOGY_OPTIMIZATION_CONFIG, StreamsConfig.NO_OPTIMIZATION); + final KStream streamA = builder.stream("topic", Consumed.with(Serdes.String(), Serdes.String())); + final KTable tableB = builder.table("topic2", Consumed.with(Serdes.String(), Serdes.String())); + final KTable tableC = builder.table("topic3", Consumed.with(Serdes.String(), Serdes.String())); + final KStream rekeyedStream = streamA.map((k, v) -> new KeyValue<>(v, k)); + rekeyedStream.join(tableB, (value1, value2) -> value1 + value2).to("out-one"); + rekeyedStream.join(tableC, (value1, value2) -> value1 + value2).to("out-two"); + final Topology topology = builder.build(props); + assertEquals(expectedTopologyWithGeneratedRepartitionTopicNames, topology.describe().toString()); + } + + @Test + public void shouldCreateRepartitionTopicsWithUserProvidedName() { + final StreamsBuilder builder = new StreamsBuilder(); + final Properties props = new Properties(); + props.put(StreamsConfig.TOPOLOGY_OPTIMIZATION_CONFIG, StreamsConfig.NO_OPTIMIZATION); + final KStream streamA = builder.stream("topic", Consumed.with(Serdes.String(), Serdes.String())); + final KTable tableB = builder.table("topic2", Consumed.with(Serdes.String(), Serdes.String())); + final KTable tableC = builder.table("topic3", Consumed.with(Serdes.String(), Serdes.String())); + final KStream rekeyedStream = streamA.map((k, v) -> new KeyValue<>(v, k)); + + rekeyedStream.join(tableB, (value1, value2) -> value1 + value2, Joined.with(Serdes.String(), Serdes.String(), Serdes.String(), "first-join")).to("out-one"); + rekeyedStream.join(tableC, (value1, value2) -> value1 + value2, Joined.with(Serdes.String(), Serdes.String(), Serdes.String(), "second-join")).to("out-two"); + final Topology topology = builder.build(props); + 
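+        // the names passed to Joined ("first-join", "second-join") become the prefixes of the generated repartition topics,
+        // as reflected in expectedTopologyWithUserProvidedRepartitionTopicNames below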
System.out.println(topology.describe().toString()); + assertEquals(expectedTopologyWithUserProvidedRepartitionTopicNames, topology.describe().toString()); + } + + @Test + public void shouldRequireCopartitionedStreams() { + final Collection> copartitionGroups = + TopologyWrapper.getInternalTopologyBuilder(builder.build()).copartitionGroups(); + + assertEquals(1, copartitionGroups.size()); + assertEquals(new HashSet<>(Arrays.asList(streamTopic, tableTopic)), copartitionGroups.iterator().next()); + } + + @Test + public void shouldNotJoinWithEmptyTableOnStreamUpdates() { + // push two items to the primary stream. the table is empty + pushToStream(2, "X"); + processor.checkAndClearProcessResult(EMPTY); + } + + @Test + public void shouldNotJoinOnTableUpdates() { + // push two items to the primary stream. the table is empty + pushToStream(2, "X"); + processor.checkAndClearProcessResult(EMPTY); + + // push two items to the table. this should not produce any item. + pushToTable(2, "Y"); + processor.checkAndClearProcessResult(EMPTY); + + // push all four items to the primary stream. this should produce two items. + pushToStream(4, "X"); + processor.checkAndClearProcessResult(new KeyValueTimestamp<>(0, "X0+Y0", 0), + new KeyValueTimestamp<>(1, "X1+Y1", 1)); + + // push all items to the table. this should not produce any item + pushToTable(4, "YY"); + processor.checkAndClearProcessResult(EMPTY); + + // push all four items to the primary stream. this should produce four items. + pushToStream(4, "X"); + processor.checkAndClearProcessResult(new KeyValueTimestamp<>(0, "X0+YY0", 0), + new KeyValueTimestamp<>(1, "X1+YY1", 1), + new KeyValueTimestamp<>(2, "X2+YY2", 2), + new KeyValueTimestamp<>(3, "X3+YY3", 3)); + + // push all items to the table. this should not produce any item + pushToTable(4, "YYY"); + processor.checkAndClearProcessResult(EMPTY); + } + + @Test + public void shouldJoinOnlyIfMatchFoundOnStreamUpdates() { + // push two items to the table. this should not produce any item. + pushToTable(2, "Y"); + processor.checkAndClearProcessResult(EMPTY); + + // push all four items to the primary stream. this should produce two items. + pushToStream(4, "X"); + processor.checkAndClearProcessResult(new KeyValueTimestamp<>(0, "X0+Y0", 0), + new KeyValueTimestamp<>(1, "X1+Y1", 1)); + } + + @Test + public void shouldClearTableEntryOnNullValueUpdates() { + // push all four items to the table. this should not produce any item. + pushToTable(4, "Y"); + processor.checkAndClearProcessResult(EMPTY); + + // push all four items to the primary stream. this should produce four items. + pushToStream(4, "X"); + processor.checkAndClearProcessResult(new KeyValueTimestamp<>(0, "X0+Y0", 0), + new KeyValueTimestamp<>(1, "X1+Y1", 1), + new KeyValueTimestamp<>(2, "X2+Y2", 2), + new KeyValueTimestamp<>(3, "X3+Y3", 3)); + + // push two items with null to the table as deletes. this should not produce any item. + pushNullValueToTable(); + processor.checkAndClearProcessResult(EMPTY); + + // push all four items to the primary stream. this should produce two items. 
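+        // (keys 0 and 1 were deleted from the table by the null-value updates above, so only keys 2 and 3 find a match)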
+ pushToStream(4, "XX"); + processor.checkAndClearProcessResult(new KeyValueTimestamp<>(2, "XX2+Y2", 2), + new KeyValueTimestamp<>(3, "XX3+Y3", 3)); + } + + @Test + public void shouldLogAndMeterWhenSkippingNullLeftKey() { + try (final LogCaptureAppender appender = LogCaptureAppender.createAndRegister(KStreamKTableJoin.class)) { + final TestInputTopic inputTopic = + driver.createInputTopic(streamTopic, new IntegerSerializer(), new StringSerializer()); + inputTopic.pipeInput(null, "A"); + + assertThat( + appender.getMessages(), + hasItem("Skipping record due to null join key or value. key=[null] value=[A] topic=[streamTopic] partition=[0] " + + "offset=[0]")); + } + } + + @Test + public void shouldLogAndMeterWhenSkippingNullLeftValue() { + try (final LogCaptureAppender appender = LogCaptureAppender.createAndRegister(KStreamKTableJoin.class)) { + final TestInputTopic inputTopic = + driver.createInputTopic(streamTopic, new IntegerSerializer(), new StringSerializer()); + inputTopic.pipeInput(1, null); + + assertThat( + appender.getMessages(), + hasItem("Skipping record due to null join key or value. key=[1] value=[null] topic=[streamTopic] partition=[0] " + + "offset=[0]") + ); + } + } + + + private final String expectedTopologyWithGeneratedRepartitionTopicNames = + "Topologies:\n" + + " Sub-topology: 0\n" + + " Source: KSTREAM-SOURCE-0000000000 (topics: [topic])\n" + + " --> KSTREAM-MAP-0000000007\n" + + " Processor: KSTREAM-MAP-0000000007 (stores: [])\n" + + " --> KSTREAM-FILTER-0000000009\n" + + " <-- KSTREAM-SOURCE-0000000000\n" + + " Processor: KSTREAM-FILTER-0000000009 (stores: [])\n" + + " --> KSTREAM-SINK-0000000008\n" + + " <-- KSTREAM-MAP-0000000007\n" + + " Sink: KSTREAM-SINK-0000000008 (topic: KSTREAM-MAP-0000000007-repartition)\n" + + " <-- KSTREAM-FILTER-0000000009\n" + + "\n" + + " Sub-topology: 1\n" + + " Source: KSTREAM-SOURCE-0000000010 (topics: [KSTREAM-MAP-0000000007-repartition])\n" + + " --> KSTREAM-JOIN-0000000011, KSTREAM-JOIN-0000000016\n" + + " Processor: KSTREAM-JOIN-0000000011 (stores: [topic2-STATE-STORE-0000000001])\n" + + " --> KSTREAM-SINK-0000000012\n" + + " <-- KSTREAM-SOURCE-0000000010\n" + + " Processor: KSTREAM-JOIN-0000000016 (stores: [topic3-STATE-STORE-0000000004])\n" + + " --> KSTREAM-SINK-0000000017\n" + + " <-- KSTREAM-SOURCE-0000000010\n" + + " Source: KSTREAM-SOURCE-0000000002 (topics: [topic2])\n" + + " --> KTABLE-SOURCE-0000000003\n" + + " Source: KSTREAM-SOURCE-0000000005 (topics: [topic3])\n" + + " --> KTABLE-SOURCE-0000000006\n" + + " Sink: KSTREAM-SINK-0000000012 (topic: out-one)\n" + + " <-- KSTREAM-JOIN-0000000011\n" + + " Sink: KSTREAM-SINK-0000000017 (topic: out-two)\n" + + " <-- KSTREAM-JOIN-0000000016\n" + + " Processor: KTABLE-SOURCE-0000000003 (stores: [topic2-STATE-STORE-0000000001])\n" + + " --> none\n" + + " <-- KSTREAM-SOURCE-0000000002\n" + + " Processor: KTABLE-SOURCE-0000000006 (stores: [topic3-STATE-STORE-0000000004])\n" + + " --> none\n" + + " <-- KSTREAM-SOURCE-0000000005\n\n"; + + + private final String expectedTopologyWithUserProvidedRepartitionTopicNames = + "Topologies:\n" + + " Sub-topology: 0\n" + + " Source: KSTREAM-SOURCE-0000000000 (topics: [topic])\n" + + " --> KSTREAM-MAP-0000000007\n" + + " Processor: KSTREAM-MAP-0000000007 (stores: [])\n" + + " --> first-join-repartition-filter, second-join-repartition-filter\n" + + " <-- KSTREAM-SOURCE-0000000000\n" + + " Processor: first-join-repartition-filter (stores: [])\n" + + " --> first-join-repartition-sink\n" + + " <-- KSTREAM-MAP-0000000007\n" + + " Processor: 
second-join-repartition-filter (stores: [])\n" + + " --> second-join-repartition-sink\n" + + " <-- KSTREAM-MAP-0000000007\n" + + " Sink: first-join-repartition-sink (topic: first-join-repartition)\n" + + " <-- first-join-repartition-filter\n" + + " Sink: second-join-repartition-sink (topic: second-join-repartition)\n" + + " <-- second-join-repartition-filter\n" + + "\n" + + " Sub-topology: 1\n" + + " Source: first-join-repartition-source (topics: [first-join-repartition])\n" + + " --> first-join\n" + + " Source: KSTREAM-SOURCE-0000000002 (topics: [topic2])\n" + + " --> KTABLE-SOURCE-0000000003\n" + + " Processor: first-join (stores: [topic2-STATE-STORE-0000000001])\n" + + " --> KSTREAM-SINK-0000000012\n" + + " <-- first-join-repartition-source\n" + + " Sink: KSTREAM-SINK-0000000012 (topic: out-one)\n" + + " <-- first-join\n" + + " Processor: KTABLE-SOURCE-0000000003 (stores: [topic2-STATE-STORE-0000000001])\n" + + " --> none\n" + + " <-- KSTREAM-SOURCE-0000000002\n" + + "\n" + + " Sub-topology: 2\n" + + " Source: second-join-repartition-source (topics: [second-join-repartition])\n" + + " --> second-join\n" + + " Source: KSTREAM-SOURCE-0000000005 (topics: [topic3])\n" + + " --> KTABLE-SOURCE-0000000006\n" + + " Processor: second-join (stores: [topic3-STATE-STORE-0000000004])\n" + + " --> KSTREAM-SINK-0000000017\n" + + " <-- second-join-repartition-source\n" + + " Sink: KSTREAM-SINK-0000000017 (topic: out-two)\n" + + " <-- second-join\n" + + " Processor: KTABLE-SOURCE-0000000006 (stores: [topic3-STATE-STORE-0000000004])\n" + + " --> none\n" + + " <-- KSTREAM-SOURCE-0000000005\n\n"; +} diff --git a/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KStreamKTableLeftJoinTest.java b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KStreamKTableLeftJoinTest.java new file mode 100644 index 0000000..d9f227c --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KStreamKTableLeftJoinTest.java @@ -0,0 +1,202 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.kstream.internals; + +import org.apache.kafka.common.serialization.IntegerSerializer; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.serialization.StringSerializer; +import org.apache.kafka.streams.KeyValueTimestamp; +import org.apache.kafka.streams.kstream.Consumed; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.TopologyTestDriver; +import org.apache.kafka.streams.TopologyWrapper; +import org.apache.kafka.streams.kstream.KStream; +import org.apache.kafka.streams.kstream.KTable; +import org.apache.kafka.streams.TestInputTopic; +import org.apache.kafka.test.MockProcessor; +import org.apache.kafka.test.MockProcessorSupplier; +import org.apache.kafka.test.MockValueJoiner; +import org.apache.kafka.test.StreamsTestUtils; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import java.time.Duration; +import java.time.Instant; +import java.util.Arrays; +import java.util.Collection; +import java.util.HashSet; +import java.util.Properties; +import java.util.Random; +import java.util.Set; + +import static org.junit.Assert.assertEquals; + +public class KStreamKTableLeftJoinTest { + private final static KeyValueTimestamp[] EMPTY = new KeyValueTimestamp[0]; + + private final String streamTopic = "streamTopic"; + private final String tableTopic = "tableTopic"; + private TestInputTopic inputStreamTopic; + private TestInputTopic inputTableTopic; + private final int[] expectedKeys = {0, 1, 2, 3}; + + private TopologyTestDriver driver; + private MockProcessor processor; + private StreamsBuilder builder; + + @SuppressWarnings("deprecation") // Old PAPI. Needs to be migrated. + @Before + public void setUp() { + builder = new StreamsBuilder(); + + final KStream stream; + final KTable table; + + final MockProcessorSupplier supplier = new MockProcessorSupplier<>(); + final Consumed consumed = Consumed.with(Serdes.Integer(), Serdes.String()); + stream = builder.stream(streamTopic, consumed); + table = builder.table(tableTopic, consumed); + stream.leftJoin(table, MockValueJoiner.TOSTRING_JOINER).process(supplier); + + final Properties props = StreamsTestUtils.getStreamsConfig(Serdes.Integer(), Serdes.String()); + driver = new TopologyTestDriver(builder.build(), props); + inputStreamTopic = driver.createInputTopic(streamTopic, new IntegerSerializer(), new StringSerializer(), Instant.ofEpochMilli(0L), Duration.ZERO); + inputTableTopic = driver.createInputTopic(tableTopic, new IntegerSerializer(), new StringSerializer(), Instant.ofEpochMilli(0L), Duration.ZERO); + + processor = supplier.theCapturedProcessor(); + } + + @After + public void cleanup() { + driver.close(); + } + + private void pushToStream(final int messageCount, final String valuePrefix) { + for (int i = 0; i < messageCount; i++) { + inputStreamTopic.pipeInput(expectedKeys[i], valuePrefix + expectedKeys[i], i); + } + } + + private void pushToTable(final int messageCount, final String valuePrefix) { + final Random r = new Random(System.currentTimeMillis()); + for (int i = 0; i < messageCount; i++) { + inputTableTopic.pipeInput( + expectedKeys[i], + valuePrefix + expectedKeys[i], + r.nextInt(Integer.MAX_VALUE)); + } + } + + private void pushNullValueToTable(final int messageCount) { + for (int i = 0; i < messageCount; i++) { + inputTableTopic.pipeInput(expectedKeys[i], null); + } + } + + @Test + public void shouldRequireCopartitionedStreams() { + final Collection> copartitionGroups = + 
TopologyWrapper.getInternalTopologyBuilder(builder.build()).copartitionGroups(); + + assertEquals(1, copartitionGroups.size()); + assertEquals(new HashSet<>(Arrays.asList(streamTopic, tableTopic)), copartitionGroups.iterator().next()); + } + + @Test + public void shouldJoinWithEmptyTableOnStreamUpdates() { + // push two items to the primary stream. the table is empty + pushToStream(2, "X"); + processor.checkAndClearProcessResult(new KeyValueTimestamp<>(0, "X0+null", 0), + new KeyValueTimestamp<>(1, "X1+null", 1)); + } + + @Test + public void shouldNotJoinOnTableUpdates() { + // push two items to the primary stream. the table is empty + pushToStream(2, "X"); + processor.checkAndClearProcessResult(new KeyValueTimestamp<>(0, "X0+null", 0), + new KeyValueTimestamp<>(1, "X1+null", 1)); + + // push two items to the table. this should not produce any item. + pushToTable(2, "Y"); + processor.checkAndClearProcessResult(EMPTY); + + // push all four items to the primary stream. this should produce four items. + pushToStream(4, "X"); + processor.checkAndClearProcessResult(new KeyValueTimestamp<>(0, "X0+Y0", 0), + new KeyValueTimestamp<>(1, "X1+Y1", 1), + new KeyValueTimestamp<>(2, "X2+null", 2), + new KeyValueTimestamp<>(3, "X3+null", 3)); + + // push all items to the table. this should not produce any item + pushToTable(4, "YY"); + processor.checkAndClearProcessResult(EMPTY); + + // push all four items to the primary stream. this should produce four items. + pushToStream(4, "X"); + processor.checkAndClearProcessResult(new KeyValueTimestamp<>(0, "X0+YY0", 0), + new KeyValueTimestamp<>(1, "X1+YY1", 1), + new KeyValueTimestamp<>(2, "X2+YY2", 2), + new KeyValueTimestamp<>(3, "X3+YY3", 3)); + + // push all items to the table. this should not produce any item + pushToTable(4, "YYY"); + processor.checkAndClearProcessResult(EMPTY); + } + + @Test + public void shouldJoinRegardlessIfMatchFoundOnStreamUpdates() { + // push two items to the table. this should not produce any item. + pushToTable(2, "Y"); + processor.checkAndClearProcessResult(EMPTY); + + // push all four items to the primary stream. this should produce four items. + pushToStream(4, "X"); + processor.checkAndClearProcessResult(new KeyValueTimestamp<>(0, "X0+Y0", 0), + new KeyValueTimestamp<>(1, "X1+Y1", 1), + new KeyValueTimestamp<>(2, "X2+null", 2), + new KeyValueTimestamp<>(3, "X3+null", 3)); + + } + + @Test + public void shouldClearTableEntryOnNullValueUpdates() { + // push all four items to the table. this should not produce any item. + pushToTable(4, "Y"); + processor.checkAndClearProcessResult(EMPTY); + + // push all four items to the primary stream. this should produce four items. + pushToStream(4, "X"); + processor.checkAndClearProcessResult(new KeyValueTimestamp<>(0, "X0+Y0", 0), + new KeyValueTimestamp<>(1, "X1+Y1", 1), + new KeyValueTimestamp<>(2, "X2+Y2", 2), + new KeyValueTimestamp<>(3, "X3+Y3", 3)); + + // push two items with null to the table as deletes. this should not produce any item. + pushNullValueToTable(2); + processor.checkAndClearProcessResult(EMPTY); + + // push all four items to the primary stream. this should produce four items. 
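+        // (keys 0 and 1 were deleted from the table, so the left join emits them with a null table-side value)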
+ pushToStream(4, "XX"); + processor.checkAndClearProcessResult(new KeyValueTimestamp<>(0, "XX0+null", 0), + new KeyValueTimestamp<>(1, "XX1+null", 1), + new KeyValueTimestamp<>(2, "XX2+Y2", 2), + new KeyValueTimestamp<>(3, "XX3+Y3", 3)); + } + +} diff --git a/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KStreamMapTest.java b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KStreamMapTest.java new file mode 100644 index 0000000..bf5ff26 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KStreamMapTest.java @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.kstream.internals; + +import org.apache.kafka.common.serialization.IntegerSerializer; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.serialization.StringSerializer; +import org.apache.kafka.streams.KeyValueTimestamp; +import org.apache.kafka.streams.kstream.Consumed; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.TopologyTestDriver; +import org.apache.kafka.streams.kstream.KStream; +import org.apache.kafka.streams.TestInputTopic; +import org.apache.kafka.streams.processor.api.Record; +import org.apache.kafka.test.MockProcessorSupplier; +import org.apache.kafka.test.StreamsTestUtils; +import org.junit.Test; + +import java.time.Duration; +import java.time.Instant; +import java.util.Properties; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.core.Is.is; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; + +public class KStreamMapTest { + private final Properties props = StreamsTestUtils.getStreamsConfig(Serdes.Integer(), Serdes.String()); + + @SuppressWarnings("deprecation") // Old PAPI. Needs to be migrated. 
+ @Test + public void testMap() { + final StreamsBuilder builder = new StreamsBuilder(); + final String topicName = "topic"; + final int[] expectedKeys = new int[] {0, 1, 2, 3}; + + final MockProcessorSupplier supplier = new MockProcessorSupplier<>(); + final KStream stream = builder.stream(topicName, Consumed.with(Serdes.Integer(), Serdes.String())); + stream.map((key, value) -> KeyValue.pair(value, key)).process(supplier); + + try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) { + for (final int expectedKey : expectedKeys) { + final TestInputTopic inputTopic = + driver.createInputTopic(topicName, new IntegerSerializer(), new StringSerializer(), Instant.ofEpochMilli(0L), Duration.ZERO); + inputTopic.pipeInput(expectedKey, "V" + expectedKey, 10L - expectedKey); + } + } + + final KeyValueTimestamp[] expected = new KeyValueTimestamp[] {new KeyValueTimestamp<>("V0", 0, 10), + new KeyValueTimestamp<>("V1", 1, 9), + new KeyValueTimestamp<>("V2", 2, 8), + new KeyValueTimestamp<>("V3", 3, 7)}; + assertEquals(4, supplier.theCapturedProcessor().processed().size()); + for (int i = 0; i < expected.length; i++) { + assertEquals(expected[i], supplier.theCapturedProcessor().processed().get(i)); + } + } + + @Test + public void testKeyValueMapperResultNotNull() { + final KStreamMap supplier = new KStreamMap<>((key, value) -> null); + final Throwable throwable = assertThrows(NullPointerException.class, + () -> supplier.get().process(new Record<>("K", 0, 0L))); + assertThat(throwable.getMessage(), is("The provided KeyValueMapper returned null which is not allowed.")); + } + + @Test + public void testTypeVariance() { + new StreamsBuilder() + .stream("numbers") + .map((key, value) -> KeyValue.pair(key, key + ":" + value)) + .to("strings"); + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KStreamMapValuesTest.java b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KStreamMapValuesTest.java new file mode 100644 index 0000000..fe1e4a3 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KStreamMapValuesTest.java @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.kstream.internals; + +import org.apache.kafka.common.serialization.IntegerSerializer; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.serialization.StringSerializer; +import org.apache.kafka.streams.KeyValueTimestamp; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.TopologyTestDriver; +import org.apache.kafka.streams.kstream.Consumed; +import org.apache.kafka.streams.kstream.KStream; +import org.apache.kafka.streams.kstream.ValueMapperWithKey; +import org.apache.kafka.streams.TestInputTopic; +import org.apache.kafka.test.MockProcessorSupplier; +import org.apache.kafka.test.StreamsTestUtils; +import org.junit.Test; + +import java.util.Properties; + +import static org.junit.Assert.assertArrayEquals; + +public class KStreamMapValuesTest { + private final String topicName = "topic"; + private final MockProcessorSupplier supplier = new MockProcessorSupplier<>(); + private final Properties props = StreamsTestUtils.getStreamsConfig(Serdes.Integer(), Serdes.String()); + + @SuppressWarnings("deprecation") // Old PAPI. Needs to be migrated. + @Test + public void testFlatMapValues() { + final StreamsBuilder builder = new StreamsBuilder(); + + final int[] expectedKeys = {1, 10, 100, 1000}; + + final KStream stream = builder.stream(topicName, Consumed.with(Serdes.Integer(), Serdes.String())); + stream.mapValues(CharSequence::length).process(supplier); + + try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) { + for (final int expectedKey : expectedKeys) { + final TestInputTopic inputTopic = + driver.createInputTopic(topicName, new IntegerSerializer(), new StringSerializer()); + inputTopic.pipeInput(expectedKey, Integer.toString(expectedKey), expectedKey / 2L); + } + } + final KeyValueTimestamp[] expected = {new KeyValueTimestamp<>(1, 1, 0), + new KeyValueTimestamp<>(10, 2, 5), + new KeyValueTimestamp<>(100, 3, 50), + new KeyValueTimestamp<>(1000, 4, 500)}; + + assertArrayEquals(expected, supplier.theCapturedProcessor().processed().toArray()); + } + + @SuppressWarnings("deprecation") // Old PAPI. Needs to be migrated. 
+ @Test + public void testMapValuesWithKeys() { + final StreamsBuilder builder = new StreamsBuilder(); + + final ValueMapperWithKey mapper = (readOnlyKey, value) -> value.length() + readOnlyKey; + + final int[] expectedKeys = {1, 10, 100, 1000}; + + final KStream stream = builder.stream(topicName, Consumed.with(Serdes.Integer(), Serdes.String())); + stream.mapValues(mapper).process(supplier); + + try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) { + final TestInputTopic inputTopic = + driver.createInputTopic(topicName, new IntegerSerializer(), new StringSerializer()); + for (final int expectedKey : expectedKeys) { + inputTopic.pipeInput(expectedKey, Integer.toString(expectedKey), expectedKey / 2L); + } + } + final KeyValueTimestamp[] expected = {new KeyValueTimestamp<>(1, 2, 0), + new KeyValueTimestamp<>(10, 12, 5), + new KeyValueTimestamp<>(100, 103, 50), + new KeyValueTimestamp<>(1000, 1004, 500)}; + + assertArrayEquals(expected, supplier.theCapturedProcessor().processed().toArray()); + } + +} diff --git a/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KStreamPeekTest.java b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KStreamPeekTest.java new file mode 100644 index 0000000..f04278b --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KStreamPeekTest.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.kstream.internals; + +import org.apache.kafka.common.serialization.IntegerSerializer; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.serialization.StringSerializer; +import org.apache.kafka.streams.kstream.Consumed; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.TopologyTestDriver; +import org.apache.kafka.streams.kstream.ForeachAction; +import org.apache.kafka.streams.kstream.KStream; +import org.apache.kafka.streams.TestInputTopic; +import org.apache.kafka.test.StreamsTestUtils; +import org.junit.Test; + +import java.util.ArrayList; +import java.util.List; +import java.util.Properties; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.fail; + +public class KStreamPeekTest { + + private final String topicName = "topic"; + private final Properties props = StreamsTestUtils.getStreamsConfig(Serdes.Integer(), Serdes.String()); + + @Test + public void shouldObserveStreamElements() { + final StreamsBuilder builder = new StreamsBuilder(); + final KStream stream = builder.stream(topicName, Consumed.with(Serdes.Integer(), Serdes.String())); + final List> peekObserved = new ArrayList<>(), streamObserved = new ArrayList<>(); + stream.peek(collect(peekObserved)).foreach(collect(streamObserved)); + + try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) { + final TestInputTopic inputTopic = driver.createInputTopic(topicName, new IntegerSerializer(), new StringSerializer()); + final List> expected = new ArrayList<>(); + for (int key = 0; key < 32; key++) { + final String value = "V" + key; + inputTopic.pipeInput(key, value); + expected.add(new KeyValue<>(key, value)); + } + + assertEquals(expected, peekObserved); + assertEquals(expected, streamObserved); + } + } + + @Test + public void shouldNotAllowNullAction() { + final StreamsBuilder builder = new StreamsBuilder(); + final KStream stream = builder.stream(topicName, Consumed.with(Serdes.Integer(), Serdes.String())); + try { + stream.peek(null); + fail("expected null action to throw NPE"); + } catch (final NullPointerException expected) { + // do nothing + } + } + + private static ForeachAction collect(final List> into) { + return (key, value) -> into.add(new KeyValue<>(key, value)); + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KStreamPrintTest.java b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KStreamPrintTest.java new file mode 100644 index 0000000..2915a11 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KStreamPrintTest.java @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.kstream.internals; + +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.processor.api.Processor; +import org.apache.kafka.streams.processor.api.ProcessorContext; +import org.apache.kafka.streams.processor.api.Record; +import org.easymock.EasyMock; +import org.junit.Before; +import org.junit.Test; + +import java.io.ByteArrayOutputStream; +import java.nio.charset.StandardCharsets; +import java.util.Arrays; +import java.util.List; + +import static org.junit.Assert.assertEquals; + +public class KStreamPrintTest { + + private ByteArrayOutputStream byteOutStream; + private Processor printProcessor; + + @Before + public void setUp() { + byteOutStream = new ByteArrayOutputStream(); + + final KStreamPrint kStreamPrint = new KStreamPrint<>(new PrintForeachAction<>( + byteOutStream, + (key, value) -> String.format("%d, %s", key, value), + "test-stream")); + + printProcessor = kStreamPrint.get(); + final ProcessorContext processorContext = EasyMock.createNiceMock(ProcessorContext.class); + EasyMock.replay(processorContext); + + printProcessor.init(processorContext); + } + + @Test + public void testPrintStreamWithProvidedKeyValueMapper() { + final List> inputRecords = Arrays.asList( + new KeyValue<>(0, "zero"), + new KeyValue<>(1, "one"), + new KeyValue<>(2, "two"), + new KeyValue<>(3, "three")); + + final String[] expectedResult = { + "[test-stream]: 0, zero", + "[test-stream]: 1, one", + "[test-stream]: 2, two", + "[test-stream]: 3, three"}; + + for (final KeyValue record: inputRecords) { + final Record r = new Record<>(record.key, record.value, 0L); + printProcessor.process(r); + } + printProcessor.close(); + + final String[] flushOutDatas = new String(byteOutStream.toByteArray(), StandardCharsets.UTF_8).split("\\r*\\n"); + for (int i = 0; i < flushOutDatas.length; i++) { + assertEquals(expectedResult[i], flushOutDatas[i]); + } + } + +} diff --git a/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KStreamRepartitionTest.java b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KStreamRepartitionTest.java new file mode 100644 index 0000000..0344f46 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KStreamRepartitionTest.java @@ -0,0 +1,167 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.kstream.internals; + +import org.apache.kafka.common.serialization.IntegerDeserializer; +import org.apache.kafka.common.serialization.IntegerSerializer; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.serialization.StringDeserializer; +import org.apache.kafka.common.serialization.StringSerializer; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.TestInputTopic; +import org.apache.kafka.streams.TestOutputTopic; +import org.apache.kafka.streams.TopologyTestDriver; +import org.apache.kafka.streams.errors.TopologyException; +import org.apache.kafka.streams.kstream.Consumed; +import org.apache.kafka.streams.kstream.JoinWindows; +import org.apache.kafka.streams.kstream.KStream; +import org.apache.kafka.streams.kstream.Repartitioned; +import org.apache.kafka.streams.processor.StreamPartitioner; +import org.apache.kafka.streams.test.TestRecord; +import org.apache.kafka.test.StreamsTestUtils; +import org.easymock.EasyMock; +import org.easymock.EasyMockRunner; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; + +import java.time.Duration; +import java.time.Instant; +import java.util.Map; +import java.util.Properties; +import java.util.TreeMap; + +import static org.easymock.EasyMock.anyInt; +import static org.easymock.EasyMock.anyString; +import static org.easymock.EasyMock.eq; +import static org.easymock.EasyMock.expect; +import static org.easymock.EasyMock.replay; +import static org.easymock.EasyMock.verify; +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.assertThrows; + +@SuppressWarnings("deprecation") +@RunWith(EasyMockRunner.class) +public class KStreamRepartitionTest { + private final String inputTopic = "input-topic"; + + private final Properties props = StreamsTestUtils.getStreamsConfig(Serdes.Integer(), Serdes.String()); + + private StreamsBuilder builder; + + @Before + public void setUp() { + builder = new StreamsBuilder(); + } + + @Test + public void shouldInvokePartitionerWhenSet() { + final int[] expectedKeys = new int[]{0, 1}; + final StreamPartitioner streamPartitionerMock = EasyMock.mock(StreamPartitioner.class); + + expect(streamPartitionerMock.partition(anyString(), eq(0), eq("X0"), anyInt())).andReturn(1).times(1); + expect(streamPartitionerMock.partition(anyString(), eq(1), eq("X1"), anyInt())).andReturn(1).times(1); + replay(streamPartitionerMock); + + final String repartitionOperationName = "test"; + final Repartitioned repartitioned = Repartitioned + .streamPartitioner(streamPartitionerMock) + .withName(repartitionOperationName); + + builder.stream(inputTopic) + .repartition(repartitioned); + + try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) { + final TestInputTopic testInputTopic = driver.createInputTopic(inputTopic, + new IntegerSerializer(), + new StringSerializer()); + + final String topicName = repartitionOutputTopic(props, repartitionOperationName); + + final TestOutputTopic testOutputTopic = driver.createOutputTopic( + topicName, + new IntegerDeserializer(), + new StringDeserializer() + ); + + for (int i = 0; i < 2; i++) { + testInputTopic.pipeInput(expectedKeys[i], "X" + expectedKeys[i], i + 10); + } + + 
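+            // records should pass through the repartition topic with key, value, and timestamp unchanged;
+            // the mocked StreamPartitioner only selects the target partition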
assertThat(testOutputTopic.readRecord(), equalTo(new TestRecord<>(0, "X0", Instant.ofEpochMilli(10)))); + assertThat(testOutputTopic.readRecord(), equalTo(new TestRecord<>(1, "X1", Instant.ofEpochMilli(11)))); + assertTrue(testOutputTopic.readRecordsToList().isEmpty()); + } + + verify(streamPartitionerMock); + } + + @Test + public void shouldThrowAnExceptionWhenNumberOfPartitionsOfRepartitionOperationsDoNotMatchWhenJoining() { + final String topicB = "topic-b"; + final String outputTopic = "topic-output"; + final String topicBRepartitionedName = "topic-b-scale-up"; + final String inputTopicRepartitionedName = "input-topic-scale-up"; + final int topicBNumberOfPartitions = 2; + final int inputTopicNumberOfPartitions = 4; + final StreamsBuilder builder = new StreamsBuilder(); + + final Repartitioned inputTopicRepartitioned = Repartitioned + .as(inputTopicRepartitionedName) + .withNumberOfPartitions(inputTopicNumberOfPartitions); + + final Repartitioned topicBRepartitioned = Repartitioned + .as(topicBRepartitionedName) + .withNumberOfPartitions(topicBNumberOfPartitions); + + final KStream topicBStream = builder + .stream(topicB, Consumed.with(Serdes.Integer(), Serdes.String())) + .repartition(topicBRepartitioned); + + builder.stream(inputTopic, Consumed.with(Serdes.Integer(), Serdes.String())) + .repartition(inputTopicRepartitioned) + .join(topicBStream, (value1, value2) -> value2, JoinWindows.of(Duration.ofSeconds(10))) + .to(outputTopic); + + final Map repartitionTopicsWithNumOfPartitions = Utils.mkMap( + Utils.mkEntry(toRepartitionTopicName(topicBRepartitionedName), topicBNumberOfPartitions), + Utils.mkEntry(toRepartitionTopicName(inputTopicRepartitionedName), inputTopicNumberOfPartitions) + ); + + final TopologyException expected = assertThrows( + TopologyException.class, () -> builder.build(props) + ); + final String expectedErrorMessage = String.format("Following topics do not have the same " + + "number of partitions: [%s]", + new TreeMap<>(repartitionTopicsWithNumOfPartitions)); + assertNotNull(expected); + assertTrue(expected.getMessage().contains(expectedErrorMessage)); + } + + private String toRepartitionTopicName(final String input) { + return input + "-repartition"; + } + + private String repartitionOutputTopic(final Properties props, final String repartitionOperationName) { + return props.getProperty(StreamsConfig.APPLICATION_ID_CONFIG) + "-" + repartitionOperationName + "-repartition"; + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KStreamSelectKeyTest.java b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KStreamSelectKeyTest.java new file mode 100644 index 0000000..d8e70f9 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KStreamSelectKeyTest.java @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.kstream.internals; + +import org.apache.kafka.common.serialization.IntegerSerializer; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.serialization.StringSerializer; +import org.apache.kafka.streams.KeyValueTimestamp; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.TopologyTestDriver; +import org.apache.kafka.streams.kstream.Consumed; +import org.apache.kafka.streams.kstream.KStream; +import org.apache.kafka.streams.TestInputTopic; +import org.apache.kafka.test.MockProcessorSupplier; +import org.apache.kafka.test.StreamsTestUtils; +import org.junit.Test; + +import java.time.Duration; +import java.time.Instant; +import java.util.HashMap; +import java.util.Map; +import java.util.Properties; + +import static org.junit.Assert.assertEquals; + +public class KStreamSelectKeyTest { + private final String topicName = "topic_key_select"; + private final Properties props = StreamsTestUtils.getStreamsConfig(Serdes.String(), Serdes.Integer()); + + @SuppressWarnings("deprecation") // Old PAPI. Needs to be migrated. + @Test + public void testSelectKey() { + final StreamsBuilder builder = new StreamsBuilder(); + + final Map keyMap = new HashMap<>(); + keyMap.put(1, "ONE"); + keyMap.put(2, "TWO"); + keyMap.put(3, "THREE"); + + final KeyValueTimestamp[] expected = new KeyValueTimestamp[]{new KeyValueTimestamp<>("ONE", 1, 0), + new KeyValueTimestamp<>("TWO", 2, 0), + new KeyValueTimestamp<>("THREE", 3, 0)}; + final int[] expectedValues = new int[]{1, 2, 3}; + + final KStream stream = + builder.stream(topicName, Consumed.with(Serdes.String(), Serdes.Integer())); + final MockProcessorSupplier supplier = new MockProcessorSupplier<>(); + stream.selectKey((key, value) -> keyMap.get(value)).process(supplier); + + try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) { + final TestInputTopic inputTopic = + driver.createInputTopic(topicName, new StringSerializer(), new IntegerSerializer(), Instant.ofEpochMilli(0L), Duration.ZERO); + for (final int expectedValue : expectedValues) { + inputTopic.pipeInput(expectedValue); + } + } + + assertEquals(3, supplier.theCapturedProcessor().processed().size()); + for (int i = 0; i < expected.length; i++) { + assertEquals(expected[i], supplier.theCapturedProcessor().processed().get(i)); + } + + } + + @Test + public void testTypeVariance() { + new StreamsBuilder() + .stream("empty") + .foreach((key, value) -> { }); + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KStreamSessionWindowAggregateProcessorTest.java b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KStreamSessionWindowAggregateProcessorTest.java new file mode 100644 index 0000000..ab92b70 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KStreamSessionWindowAggregateProcessorTest.java @@ -0,0 +1,531 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.kstream.internals; + +import org.apache.kafka.common.MetricName; +import org.apache.kafka.common.header.internals.RecordHeaders; +import org.apache.kafka.common.metrics.Metrics; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.KeyValueTimestamp; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.kstream.Aggregator; +import org.apache.kafka.streams.kstream.Initializer; +import org.apache.kafka.streams.kstream.Merger; +import org.apache.kafka.streams.kstream.SessionWindows; +import org.apache.kafka.streams.kstream.Windowed; +import org.apache.kafka.streams.processor.StateStoreContext; +import org.apache.kafka.streams.processor.api.Processor; +import org.apache.kafka.streams.processor.api.Record; +import org.apache.kafka.streams.processor.internals.metrics.TaskMetrics; +import org.apache.kafka.streams.processor.internals.ProcessorRecordContext; +import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl; +import org.apache.kafka.streams.processor.internals.testutil.LogCaptureAppender; +import org.apache.kafka.streams.processor.internals.testutil.LogCaptureAppender.Event; +import org.apache.kafka.streams.state.KeyValueIterator; +import org.apache.kafka.streams.state.SessionStore; +import org.apache.kafka.streams.state.StoreBuilder; +import org.apache.kafka.streams.state.Stores; +import org.apache.kafka.streams.state.internals.ThreadCache; +import org.apache.kafka.test.InternalMockProcessorContext; +import org.apache.kafka.test.MockRecordCollector; +import org.apache.kafka.test.StreamsTestUtils; +import org.apache.kafka.test.TestUtils; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.stream.Collectors; + +import static java.time.Duration.ofMillis; +import static org.apache.kafka.common.utils.Utils.mkEntry; +import static org.apache.kafka.common.utils.Utils.mkMap; +import static org.apache.kafka.test.StreamsTestUtils.getMetricByName; +import static org.hamcrest.CoreMatchers.hasItem; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.greaterThan; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + + +public class KStreamSessionWindowAggregateProcessorTest { + + private static final long GAP_MS = 5 * 60 * 1000L; + private static final String STORE_NAME = "session-store"; + + private final String threadId = Thread.currentThread().getName(); + private final Initializer initializer = () -> 0L; + private final Aggregator aggregator = (aggKey, value, aggregate) -> aggregate + 1; + private final Merger sessionMerger = (aggKey, aggOne, aggTwo) -> aggOne + aggTwo; + private final 
KStreamSessionWindowAggregate sessionAggregator = + new KStreamSessionWindowAggregate<>( + SessionWindows.ofInactivityGapWithNoGrace(ofMillis(GAP_MS)), + STORE_NAME, + initializer, + aggregator, + sessionMerger); + + private final List, Change>> results = new ArrayList<>(); + private final Processor, Change> processor = sessionAggregator.get(); + private SessionStore sessionStore; + private InternalMockProcessorContext, Change> context; + private final Metrics metrics = new Metrics(); + + @Before + public void setup() { + setup(true); + } + + private void setup(final boolean enableCache) { + final StreamsMetricsImpl streamsMetrics = + new StreamsMetricsImpl(metrics, "test", StreamsConfig.METRICS_LATEST, new MockTime()); + context = new InternalMockProcessorContext, Change>( + TestUtils.tempDirectory(), + Serdes.String(), + Serdes.String(), + streamsMetrics, + new StreamsConfig(StreamsTestUtils.getStreamsConfig()), + MockRecordCollector::new, + new ThreadCache(new LogContext("testCache "), 100000, streamsMetrics), + Time.SYSTEM + ) { + @Override + public , V extends Change> void forward(final Record record) { + results.add(new KeyValueTimestamp<>(record.key(), record.value(), record.timestamp())); + } + }; + TaskMetrics.droppedRecordsSensor(threadId, context.taskId().toString(), streamsMetrics); + + initStore(enableCache); + processor.init(context); + } + + private void initStore(final boolean enableCaching) { + final StoreBuilder> storeBuilder = + Stores.sessionStoreBuilder( + Stores.persistentSessionStore(STORE_NAME, ofMillis(GAP_MS * 3)), + Serdes.String(), + Serdes.Long()) + .withLoggingDisabled(); + + if (enableCaching) { + storeBuilder.withCachingEnabled(); + } + + if (sessionStore != null) { + sessionStore.close(); + } + sessionStore = storeBuilder.build(); + sessionStore.init((StateStoreContext) context, sessionStore); + } + + @After + public void closeStore() { + sessionStore.close(); + } + + @Test + public void shouldCreateSingleSessionWhenWithinGap() { + processor.process(new Record<>("john", "first", 0L)); + processor.process(new Record<>("john", "second", 500L)); + + try (final KeyValueIterator, Long> values = + sessionStore.findSessions("john", 0, 2000)) { + assertTrue(values.hasNext()); + assertEquals(Long.valueOf(2), values.next().value); + } + } + + @Test + public void shouldMergeSessions() { + final String sessionId = "mel"; + processor.process(new Record<>(sessionId, "first", 0L)); + assertTrue(sessionStore.findSessions(sessionId, 0, 0).hasNext()); + + // move time beyond gap + processor.process(new Record<>(sessionId, "second", GAP_MS + 1)); + assertTrue(sessionStore.findSessions(sessionId, GAP_MS + 1, GAP_MS + 1).hasNext()); + // should still exist as not within gap + assertTrue(sessionStore.findSessions(sessionId, 0, 0).hasNext()); + // move time back + processor.process(new Record<>(sessionId, "third", GAP_MS / 2)); + + try (final KeyValueIterator, Long> iterator = + sessionStore.findSessions(sessionId, 0, GAP_MS + 1)) { + final KeyValue, Long> kv = iterator.next(); + + assertEquals(Long.valueOf(3), kv.value); + assertFalse(iterator.hasNext()); + } + } + + @Test + public void shouldUpdateSessionIfTheSameTime() { + processor.process(new Record<>("mel", "first", 0L)); + processor.process(new Record<>("mel", "second", 0L)); + try (final KeyValueIterator, Long> iterator = + sessionStore.findSessions("mel", 0, 0)) { + assertEquals(Long.valueOf(2L), iterator.next().value); + assertFalse(iterator.hasNext()); + } + } + + @Test + public void 
shouldHaveMultipleSessionsForSameIdWhenTimestampApartBySessionGap() { + final String sessionId = "mel"; + long time = 0; + processor.process(new Record<>(sessionId, "first", time)); + final long time1 = time += GAP_MS + 1; + processor.process(new Record<>(sessionId, "second", time1)); + processor.process(new Record<>(sessionId, "second", time1)); + final long time2 = time += GAP_MS + 1; + processor.process(new Record<>(sessionId, "third", time2)); + processor.process(new Record<>(sessionId, "third", time2)); + processor.process(new Record<>(sessionId, "third", time2)); + + sessionStore.flush(); + assertEquals( + Arrays.asList( + new KeyValueTimestamp<>( + new Windowed<>(sessionId, new SessionWindow(0, 0)), + new Change<>(1L, null), + 0L), + new KeyValueTimestamp<>( + new Windowed<>(sessionId, new SessionWindow(GAP_MS + 1, GAP_MS + 1)), + new Change<>(2L, null), + GAP_MS + 1), + new KeyValueTimestamp<>( + new Windowed<>(sessionId, new SessionWindow(time, time)), + new Change<>(3L, null), + time) + ), + results + ); + + } + + @Test + public void shouldRemoveMergedSessionsFromStateStore() { + processor.process(new Record<>("a", "1", 0L)); + + // first ensure it is in the store + try (final KeyValueIterator, Long> a1 = + sessionStore.findSessions("a", 0, 0)) { + assertEquals(KeyValue.pair(new Windowed<>("a", new SessionWindow(0, 0)), 1L), a1.next()); + } + + + processor.process(new Record<>("a", "2", 100L)); + // a1 from above should have been removed + // should have merged session in store + try (final KeyValueIterator, Long> a2 = + sessionStore.findSessions("a", 0, 100)) { + assertEquals(KeyValue.pair(new Windowed<>("a", new SessionWindow(0, 100)), 2L), a2.next()); + assertFalse(a2.hasNext()); + } + } + + @Test + public void shouldHandleMultipleSessionsAndMerging() { + processor.process(new Record<>("a", "1", 0L)); + processor.process(new Record<>("b", "1", 0L)); + processor.process(new Record<>("c", "1", 0L)); + processor.process(new Record<>("d", "1", 0L)); + processor.process(new Record<>("d", "2", GAP_MS / 2)); + processor.process(new Record<>("a", "2", GAP_MS + 1)); + processor.process(new Record<>("b", "2", GAP_MS + 1)); + processor.process(new Record<>("a", "3", GAP_MS + 1 + GAP_MS / 2)); + processor.process(new Record<>("c", "3", GAP_MS + 1 + GAP_MS / 2)); + + sessionStore.flush(); + + assertEquals( + Arrays.asList( + new KeyValueTimestamp<>( + new Windowed<>("a", new SessionWindow(0, 0)), + new Change<>(1L, null), + 0L), + new KeyValueTimestamp<>( + new Windowed<>("b", new SessionWindow(0, 0)), + new Change<>(1L, null), + 0L), + new KeyValueTimestamp<>( + new Windowed<>("c", new SessionWindow(0, 0)), + new Change<>(1L, null), + 0L), + new KeyValueTimestamp<>( + new Windowed<>("d", new SessionWindow(0, GAP_MS / 2)), + new Change<>(2L, null), + GAP_MS / 2), + new KeyValueTimestamp<>( + new Windowed<>("b", new SessionWindow(GAP_MS + 1, GAP_MS + 1)), + new Change<>(1L, null), + GAP_MS + 1), + new KeyValueTimestamp<>( + new Windowed<>("a", new SessionWindow(GAP_MS + 1, GAP_MS + 1 + GAP_MS / 2)), + new Change<>(2L, null), + GAP_MS + 1 + GAP_MS / 2), + new KeyValueTimestamp<>(new Windowed<>( + "c", + new SessionWindow(GAP_MS + 1 + GAP_MS / 2, GAP_MS + 1 + GAP_MS / 2)), new Change<>(1L, null), + GAP_MS + 1 + GAP_MS / 2) + ), + results + ); + } + + @Test + public void shouldGetAggregatedValuesFromValueGetter() { + final KTableValueGetter, Long> getter = sessionAggregator.view().get(); + getter.init(context); + processor.process(new Record<>("a", "1", 0L)); + processor.process(new 
Record<>("a", "1", GAP_MS + 1)); + processor.process(new Record<>("a", "2", GAP_MS + 1)); + final long t0 = getter.get(new Windowed<>("a", new SessionWindow(0, 0))).value(); + final long t1 = getter.get(new Windowed<>("a", new SessionWindow(GAP_MS + 1, GAP_MS + 1))).value(); + assertEquals(1L, t0); + assertEquals(2L, t1); + } + + @Test + public void shouldImmediatelyForwardNewSessionWhenNonCachedStore() { + initStore(false); + processor.init(context); + + processor.process(new Record<>("a", "1", 0L)); + processor.process(new Record<>("b", "1", 0L)); + processor.process(new Record<>("c", "1", 0L)); + + assertEquals( + Arrays.asList( + new KeyValueTimestamp<>( + new Windowed<>("a", new SessionWindow(0, 0)), + new Change<>(1L, null), + 0L), + new KeyValueTimestamp<>( + new Windowed<>("b", new SessionWindow(0, 0)), + new Change<>(1L, null), + 0L), + new KeyValueTimestamp<>( + new Windowed<>("c", new SessionWindow(0, 0)), + new Change<>(1L, null), + 0L) + ), + results + ); + } + + @Test + public void shouldImmediatelyForwardRemovedSessionsWhenMerging() { + initStore(false); + processor.init(context); + + processor.process(new Record<>("a", "1", 0L)); + processor.process(new Record<>("a", "1", 5L)); + assertEquals( + Arrays.asList( + new KeyValueTimestamp<>( + new Windowed<>("a", new SessionWindow(0, 0)), + new Change<>(1L, null), + 0L), + new KeyValueTimestamp<>( + new Windowed<>("a", new SessionWindow(0, 0)), + new Change<>(null, null), + 0L), + new KeyValueTimestamp<>( + new Windowed<>("a", new SessionWindow(0, 5)), + new Change<>(2L, null), + 5L) + ), + results + ); + + } + + @Test + public void shouldLogAndMeterWhenSkippingNullKeyWithBuiltInMetrics() { + setup(false); + context.setRecordContext( + new ProcessorRecordContext(-1, -2, -3, "topic", new RecordHeaders()) + ); + + try (final LogCaptureAppender appender = + LogCaptureAppender.createAndRegister(KStreamSessionWindowAggregate.class)) { + + processor.process(new Record<>(null, "1", 0L)); + + assertThat( + appender.getEvents().stream() + .filter(e -> e.getLevel().equals("WARN")) + .map(Event::getMessage) + .collect(Collectors.toList()), + hasItem("Skipping record due to null key. 
topic=[topic] partition=[-3] offset=[-2]") + ); + } + + assertEquals( + 1.0, + getMetricByName(context.metrics().metrics(), "dropped-records-total", "stream-task-metrics").metricValue() + ); + } + + @Test + public void shouldLogAndMeterWhenSkippingLateRecordWithZeroGrace() { + setup(false); + final Processor, Change> processor = new KStreamSessionWindowAggregate<>( + SessionWindows.ofInactivityGapAndGrace(ofMillis(10L), ofMillis(0L)), + STORE_NAME, + initializer, + aggregator, + sessionMerger + ).get(); + processor.init(context); + + // dummy record to establish stream time = 0 + context.setRecordContext(new ProcessorRecordContext(0, -2, -3, "topic", new RecordHeaders())); + processor.process(new Record<>("dummy", "dummy", 0L)); + + // record arrives on time, should not be skipped + context.setRecordContext(new ProcessorRecordContext(0, -2, -3, "topic", new RecordHeaders())); + processor.process(new Record<>("OnTime1", "1", 0L)); + + // dummy record to advance stream time = 11, 10 for gap time plus 1 to place outside window + context.setRecordContext(new ProcessorRecordContext(11, -2, -3, "topic", new RecordHeaders())); + processor.process(new Record<>("dummy", "dummy", 11L)); + + try (final LogCaptureAppender appender = + LogCaptureAppender.createAndRegister(KStreamSessionWindowAggregate.class)) { + + // record is late + context.setRecordContext(new ProcessorRecordContext(0, -2, -3, "topic", new RecordHeaders())); + processor.process(new Record<>("Late1", "1", 0L)); + + assertThat( + appender.getMessages(), + hasItem("Skipping record for expired window." + + " topic=[topic] partition=[-3] offset=[-2] timestamp=[0] window=[0,0] expiration=[1] streamTime=[11]") + ); + } + + final MetricName dropTotal; + final MetricName dropRate; + dropTotal = new MetricName( + "dropped-records-total", + "stream-task-metrics", + "The total number of dropped records", + mkMap( + mkEntry("thread-id", threadId), + mkEntry("task-id", "0_0") + ) + ); + dropRate = new MetricName( + "dropped-records-rate", + "stream-task-metrics", + "The average number of dropped records per second", + mkMap( + mkEntry("thread-id", threadId), + mkEntry("task-id", "0_0") + ) + ); + assertThat(metrics.metrics().get(dropTotal).metricValue(), is(1.0)); + assertThat( + (Double) metrics.metrics().get(dropRate).metricValue(), + greaterThan(0.0) + ); + } + + @Test + public void shouldLogAndMeterWhenSkippingLateRecordWithNonzeroGrace() { + setup(false); + final Processor, Change> processor = new KStreamSessionWindowAggregate<>( + SessionWindows.ofInactivityGapAndGrace(ofMillis(10L), ofMillis(1L)), + STORE_NAME, + initializer, + aggregator, + sessionMerger + ).get(); + processor.init(context); + + try (final LogCaptureAppender appender = + LogCaptureAppender.createAndRegister(KStreamSessionWindowAggregate.class)) { + + // dummy record to establish stream time = 0 + context.setRecordContext(new ProcessorRecordContext(0, -2, -3, "topic", new RecordHeaders())); + processor.process(new Record<>("dummy", "dummy", 0L)); + + // record arrives on time, should not be skipped + context.setRecordContext(new ProcessorRecordContext(0, -2, -3, "topic", new RecordHeaders())); + processor.process(new Record<>("OnTime1", "1", 0L)); + + // dummy record to advance stream time = 11, 10 for gap time plus 1 to place at edge of window + context.setRecordContext(new ProcessorRecordContext(11, -2, -3, "topic", new RecordHeaders())); + processor.process(new Record<>("dummy", "dummy", 11L)); + + // delayed record arrives on time, should not be skipped + 
context.setRecordContext(new ProcessorRecordContext(0, -2, -3, "topic", new RecordHeaders())); + processor.process(new Record<>("OnTime2", "1", 0L)); + + // dummy record to advance stream time = 12, 10 for gap time plus 2 to place outside window + context.setRecordContext(new ProcessorRecordContext(12, -2, -3, "topic", new RecordHeaders())); + processor.process(new Record<>("dummy", "dummy", 12L)); + + // delayed record arrives late + context.setRecordContext(new ProcessorRecordContext(0, -2, -3, "topic", new RecordHeaders())); + processor.process(new Record<>("Late1", "1", 0L)); + + assertThat( + appender.getMessages(), + hasItem("Skipping record for expired window." + + " topic=[topic] partition=[-3] offset=[-2] timestamp=[0] window=[0,0] expiration=[1] streamTime=[12]") + ); + } + + final MetricName dropTotal; + final MetricName dropRate; + dropTotal = new MetricName( + "dropped-records-total", + "stream-task-metrics", + "The total number of dropped records", + mkMap( + mkEntry("thread-id", threadId), + mkEntry("task-id", "0_0") + ) + ); + dropRate = new MetricName( + "dropped-records-rate", + "stream-task-metrics", + "The average number of dropped records per second", + mkMap( + mkEntry("thread-id", threadId), + mkEntry("task-id", "0_0") + ) + ); + + assertThat(metrics.metrics().get(dropTotal).metricValue(), is(1.0)); + assertThat( + (Double) metrics.metrics().get(dropRate).metricValue(), + greaterThan(0.0)); + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KStreamSlidingWindowAggregateTest.java b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KStreamSlidingWindowAggregateTest.java new file mode 100644 index 0000000..38c3fa7 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KStreamSlidingWindowAggregateTest.java @@ -0,0 +1,1027 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.kstream.internals; + +import org.apache.kafka.common.MetricName; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.serialization.StringDeserializer; +import org.apache.kafka.common.serialization.StringSerializer; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.streams.KeyValueTimestamp; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.TestOutputTopic; +import org.apache.kafka.streams.TopologyTestDriver; +import org.apache.kafka.streams.kstream.Consumed; +import org.apache.kafka.streams.kstream.Grouped; +import org.apache.kafka.streams.kstream.KStream; +import org.apache.kafka.streams.kstream.KTable; +import org.apache.kafka.streams.kstream.Materialized; +import org.apache.kafka.streams.kstream.SlidingWindows; +import org.apache.kafka.streams.kstream.TimeWindowedDeserializer; +import org.apache.kafka.streams.kstream.Windowed; +import org.apache.kafka.streams.processor.internals.testutil.LogCaptureAppender; +import org.apache.kafka.streams.processor.internals.testutil.LogCaptureAppender.Event; +import org.apache.kafka.streams.state.KeyValueIterator; +import org.apache.kafka.streams.state.Stores; +import org.apache.kafka.streams.state.ValueAndTimestamp; +import org.apache.kafka.streams.state.WindowBytesStoreSupplier; +import org.apache.kafka.streams.state.WindowStore; +import org.apache.kafka.streams.TestInputTopic; +import org.apache.kafka.streams.state.WindowStoreIterator; +import org.apache.kafka.streams.state.internals.InMemoryWindowBytesStoreSupplier; +import org.apache.kafka.streams.state.internals.InMemoryWindowStore; +import org.apache.kafka.streams.test.TestRecord; +import org.apache.kafka.test.MockAggregator; +import org.apache.kafka.test.MockInitializer; +import org.apache.kafka.test.MockProcessor; +import org.apache.kafka.test.MockProcessorSupplier; +import org.apache.kafka.test.MockReducer; +import org.apache.kafka.test.StreamsTestUtils; +import org.hamcrest.Matcher; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import java.time.Duration; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.Comparator; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Properties; +import java.util.Random; +import java.util.stream.Collectors; + +import static java.time.Duration.ofMillis; +import static java.util.Arrays.asList; +import static org.apache.kafka.common.utils.Utils.mkEntry; +import static org.apache.kafka.common.utils.Utils.mkMap; +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.CoreMatchers.hasItem; +import static org.hamcrest.CoreMatchers.hasItems; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.CoreMatchers.not; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +@SuppressWarnings("deprecation") // Old PAPI. Needs to be migrated. 
+@RunWith(Parameterized.class) +public class KStreamSlidingWindowAggregateTest { + + @Parameterized.Parameters(name = "{0}") + public static Collection data() { + return Arrays.asList(new Boolean[][] { + {false}, + {true} + }); + } + + @Parameterized.Parameter + public boolean inOrderIterator; + + private final Properties props = StreamsTestUtils.getStreamsConfig(Serdes.String(), Serdes.String()); + private final String threadId = Thread.currentThread().getName(); + + @SuppressWarnings("unchecked") + @Test + public void testAggregateSmallInput() { + final StreamsBuilder builder = new StreamsBuilder(); + final String topic = "topic"; + + final WindowBytesStoreSupplier storeSupplier = + inOrderIterator + ? new InOrderMemoryWindowStoreSupplier("InOrder", 50000L, 10L, false) + : Stores.inMemoryWindowStore("Reverse", Duration.ofMillis(50000), Duration.ofMillis(10), false); + final KTable, String> table = builder + .stream(topic, Consumed.with(Serdes.String(), Serdes.String())) + .groupByKey(Grouped.with(Serdes.String(), Serdes.String())) + .windowedBy(SlidingWindows.withTimeDifferenceAndGrace(ofMillis(10), ofMillis(50))) + .aggregate( + MockInitializer.STRING_INIT, + MockAggregator.TOSTRING_ADDER, + Materialized.as(storeSupplier) + ); + final MockProcessorSupplier, String> supplier = new MockProcessorSupplier<>(); + table.toStream().process(supplier); + try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) { + final TestInputTopic inputTopic = + driver.createInputTopic(topic, new StringSerializer(), new StringSerializer()); + inputTopic.pipeInput("A", "1", 10L); + inputTopic.pipeInput("A", "2", 15L); + inputTopic.pipeInput("A", "3", 20L); + inputTopic.pipeInput("A", "4", 22L); + inputTopic.pipeInput("A", "5", 30L); + } + + final Map> actual = new HashMap<>(); + + for (final KeyValueTimestamp, String> entry : supplier.theCapturedProcessor().processed()) { + final Windowed window = entry.key(); + final Long start = window.window().start(); + final ValueAndTimestamp valueAndTimestamp = ValueAndTimestamp.make(entry.value(), entry.timestamp()); + + if (actual.putIfAbsent(start, valueAndTimestamp) != null) { + actual.replace(start, valueAndTimestamp); + } + } + + final Map> expected = new HashMap<>(); + expected.put(0L, ValueAndTimestamp.make("0+1", 10L)); + expected.put(5L, ValueAndTimestamp.make("0+1+2", 15L)); + expected.put(10L, ValueAndTimestamp.make("0+1+2+3", 20L)); + expected.put(11L, ValueAndTimestamp.make("0+2+3", 20L)); + expected.put(12L, ValueAndTimestamp.make("0+2+3+4", 22L)); + expected.put(16L, ValueAndTimestamp.make("0+3+4", 22L)); + expected.put(20L, ValueAndTimestamp.make("0+3+4+5", 30L)); + expected.put(21L, ValueAndTimestamp.make("0+4+5", 30L)); + expected.put(23L, ValueAndTimestamp.make("0+5", 30L)); + + assertEquals(expected, actual); + } + + @SuppressWarnings("unchecked") + @Test + public void testReduceSmallInput() { + final StreamsBuilder builder = new StreamsBuilder(); + final String topic = "topic"; + final WindowBytesStoreSupplier storeSupplier = + inOrderIterator + ? 
new InOrderMemoryWindowStoreSupplier("InOrder", 50000L, 10L, false) + : Stores.inMemoryWindowStore("Reverse", Duration.ofMillis(50000), Duration.ofMillis(10), false); + + final KTable, String> table = builder + .stream(topic, Consumed.with(Serdes.String(), Serdes.String())) + .groupByKey(Grouped.with(Serdes.String(), Serdes.String())) + .windowedBy(SlidingWindows.withTimeDifferenceAndGrace(ofMillis(10), ofMillis(50))) + .reduce( + MockReducer.STRING_ADDER, + Materialized.as(storeSupplier) + ); + final MockProcessorSupplier, String> supplier = new MockProcessorSupplier<>(); + table.toStream().process(supplier); + try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) { + final TestInputTopic inputTopic = + driver.createInputTopic(topic, new StringSerializer(), new StringSerializer()); + inputTopic.pipeInput("A", "1", 10L); + inputTopic.pipeInput("A", "2", 14L); + inputTopic.pipeInput("A", "3", 15L); + inputTopic.pipeInput("A", "4", 22L); + inputTopic.pipeInput("A", "5", 26L); + inputTopic.pipeInput("A", "6", 30L); + } + + final Map> actual = new HashMap<>(); + + for (final KeyValueTimestamp, String> entry : supplier.theCapturedProcessor().processed()) { + final Windowed window = entry.key(); + final Long start = window.window().start(); + final ValueAndTimestamp valueAndTimestamp = ValueAndTimestamp.make(entry.value(), entry.timestamp()); + if (actual.putIfAbsent(start, valueAndTimestamp) != null) { + actual.replace(start, valueAndTimestamp); + } + } + + final Map> expected = new HashMap<>(); + expected.put(0L, ValueAndTimestamp.make("1", 10L)); + expected.put(4L, ValueAndTimestamp.make("1+2", 14L)); + expected.put(5L, ValueAndTimestamp.make("1+2+3", 15L)); + expected.put(11L, ValueAndTimestamp.make("2+3", 15L)); + expected.put(12L, ValueAndTimestamp.make("2+3+4", 22L)); + expected.put(15L, ValueAndTimestamp.make("3+4", 22L)); + expected.put(16L, ValueAndTimestamp.make("4+5", 26L)); + expected.put(20L, ValueAndTimestamp.make("4+5+6", 30L)); + expected.put(23L, ValueAndTimestamp.make("5+6", 30L)); + expected.put(27L, ValueAndTimestamp.make("6", 30L)); + + assertEquals(expected, actual); + } + + @Test + public void testAggregateLargeInput() { + final StreamsBuilder builder = new StreamsBuilder(); + final String topic1 = "topic1"; + + final WindowBytesStoreSupplier storeSupplier = + inOrderIterator + ? 
new InOrderMemoryWindowStoreSupplier("InOrder", 50000L, 10L, false) + : Stores.inMemoryWindowStore("Reverse", Duration.ofMillis(50000), Duration.ofMillis(10), false); + final KTable, String> table2 = builder + .stream(topic1, Consumed.with(Serdes.String(), Serdes.String())) + .groupByKey(Grouped.with(Serdes.String(), Serdes.String())) + .windowedBy(SlidingWindows.withTimeDifferenceAndGrace(ofMillis(10), ofMillis(50))) + .aggregate( + MockInitializer.STRING_INIT, + MockAggregator.TOSTRING_ADDER, + Materialized.as(storeSupplier) + ); + + final MockProcessorSupplier, String> supplier = new MockProcessorSupplier<>(); + table2.toStream().process(supplier); + + try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) { + final TestInputTopic inputTopic1 = + driver.createInputTopic(topic1, new StringSerializer(), new StringSerializer()); + inputTopic1.pipeInput("A", "1", 10L); + inputTopic1.pipeInput("A", "2", 20L); + inputTopic1.pipeInput("A", "3", 22L); + inputTopic1.pipeInput("A", "4", 15L); + + inputTopic1.pipeInput("B", "1", 12L); + inputTopic1.pipeInput("B", "2", 13L); + inputTopic1.pipeInput("B", "3", 18L); + inputTopic1.pipeInput("B", "4", 19L); + inputTopic1.pipeInput("B", "5", 25L); + inputTopic1.pipeInput("B", "6", 14L); + + inputTopic1.pipeInput("C", "1", 11L); + inputTopic1.pipeInput("C", "2", 15L); + inputTopic1.pipeInput("C", "3", 16L); + inputTopic1.pipeInput("C", "4", 21); + inputTopic1.pipeInput("C", "5", 23L); + + inputTopic1.pipeInput("D", "4", 11L); + inputTopic1.pipeInput("D", "2", 12L); + inputTopic1.pipeInput("D", "3", 29L); + inputTopic1.pipeInput("D", "5", 16L); + } + final Comparator, String>> comparator = + Comparator.comparing((KeyValueTimestamp, String> o) -> o.key().key()) + .thenComparing((KeyValueTimestamp, String> o) -> o.key().window().start()); + + final ArrayList, String>> actual = supplier.theCapturedProcessor().processed(); + actual.sort(comparator); + assertEquals( + asList( + // FINAL WINDOW: A@10 left window created when A@10 processed + new KeyValueTimestamp<>(new Windowed<>("A", new TimeWindow(0, 10)), "0+1", 10), + // FINAL WINDOW: A@15 left window created when A@15 processed + new KeyValueTimestamp<>(new Windowed<>("A", new TimeWindow(5, 15)), "0+1+4", 15), + // A@20 left window created when A@20 processed + new KeyValueTimestamp<>(new Windowed<>("A", new TimeWindow(10, 20)), "0+1+2", 20), + // FINAL WINDOW: A@20 left window updated when A@15 processed + new KeyValueTimestamp<>(new Windowed<>("A", new TimeWindow(10, 20)), "0+1+2+4", 20), + // A@10 right window created when A@20 processed + new KeyValueTimestamp<>(new Windowed<>("A", new TimeWindow(11, 21)), "0+2", 20), + // FINAL WINDOW: A@10 right window updated when A@15 processed + new KeyValueTimestamp<>(new Windowed<>("A", new TimeWindow(11, 21)), "0+2+4", 20), + // A@22 left window created when A@22 processed + new KeyValueTimestamp<>(new Windowed<>("A", new TimeWindow(12, 22)), "0+2+3", 22), + // FINAL WINDOW: A@22 left window updated when A@15 processed + new KeyValueTimestamp<>(new Windowed<>("A", new TimeWindow(12, 22)), "0+2+3+4", 22), + // FINAL WINDOW: A@15 right window created when A@15 processed + new KeyValueTimestamp<>(new Windowed<>("A", new TimeWindow(16, 26)), "0+2+3", 22), + // FINAL WINDOW: A@20 right window created when A@22 processed + new KeyValueTimestamp<>(new Windowed<>("A", new TimeWindow(21, 31)), "0+3", 22), + // FINAL WINDOW: B@12 left window created when B@12 processed + new KeyValueTimestamp<>(new Windowed<>("B", new TimeWindow(2, 12)), 
"0+1", 12), + // FINAL WINDOW: B@13 left window created when B@13 processed + new KeyValueTimestamp<>(new Windowed<>("B", new TimeWindow(3, 13)), "0+1+2", 13), + // FINAL WINDOW: B@14 left window created when B@14 processed + new KeyValueTimestamp<>(new Windowed<>("B", new TimeWindow(4, 14)), "0+1+2+6", 14), + // B@18 left window created when B@18 processed + new KeyValueTimestamp<>(new Windowed<>("B", new TimeWindow(8, 18)), "0+1+2+3", 18), + // FINAL WINDOW: B@18 left window updated when B@14 processed + new KeyValueTimestamp<>(new Windowed<>("B", new TimeWindow(8, 18)), "0+1+2+3+6", 18), + // B@19 left window created when B@19 processed + new KeyValueTimestamp<>(new Windowed<>("B", new TimeWindow(9, 19)), "0+1+2+3+4", 19), + // FINAL WINDOW: B@19 left window updated when B@14 processed + new KeyValueTimestamp<>(new Windowed<>("B", new TimeWindow(9, 19)), "0+1+2+3+4+6", 19), + // B@12 right window created when B@13 processed + new KeyValueTimestamp<>(new Windowed<>("B", new TimeWindow(13, 23)), "0+2", 13), + // B@12 right window updated when B@18 processed + new KeyValueTimestamp<>(new Windowed<>("B", new TimeWindow(13, 23)), "0+2+3", 18), + // B@12 right window updated when B@19 processed + new KeyValueTimestamp<>(new Windowed<>("B", new TimeWindow(13, 23)), "0+2+3+4", 19), + // FINAL WINDOW: B@12 right window updated when B@14 processed + new KeyValueTimestamp<>(new Windowed<>("B", new TimeWindow(13, 23)), "0+2+3+4+6", 19), + // B@13 right window created when B@18 processed + new KeyValueTimestamp<>(new Windowed<>("B", new TimeWindow(14, 24)), "0+3", 18), + // B@13 right window updated when B@19 processed + new KeyValueTimestamp<>(new Windowed<>("B", new TimeWindow(14, 24)), "0+3+4", 19), + // FINAL WINDOW: B@13 right window updated when B@14 processed + new KeyValueTimestamp<>(new Windowed<>("B", new TimeWindow(14, 24)), "0+3+4+6", 19), + // FINAL WINDOW: B@25 left window created when B@25 processed + new KeyValueTimestamp<>(new Windowed<>("B", new TimeWindow(15, 25)), "0+3+4+5", 25), + // B@18 right window created when B@19 processed + new KeyValueTimestamp<>(new Windowed<>("B", new TimeWindow(19, 29)), "0+4", 19), + // FINAL WINDOW: B@18 right window updated when B@25 processed + new KeyValueTimestamp<>(new Windowed<>("B", new TimeWindow(19, 29)), "0+4+5", 25), + // FINAL WINDOW: B@19 right window updated when B@25 processed + new KeyValueTimestamp<>(new Windowed<>("B", new TimeWindow(20, 30)), "0+5", 25), + // FINAL WINDOW: C@11 left window created when C@11 processed + new KeyValueTimestamp<>(new Windowed<>("C", new TimeWindow(1, 11)), "0+1", 11), + // FINAL WINDOW: C@15 left window created when C@15 processed + new KeyValueTimestamp<>(new Windowed<>("C", new TimeWindow(5, 15)), "0+1+2", 15), + // FINAL WINDOW: C@16 left window created when C@16 processed + new KeyValueTimestamp<>(new Windowed<>("C", new TimeWindow(6, 16)), "0+1+2+3", 16), + // FINAL WINDOW: C@21 left window created when C@21 processed + new KeyValueTimestamp<>(new Windowed<>("C", new TimeWindow(11, 21)), "0+1+2+3+4", 21), + // C@11 right window created when C@15 processed + new KeyValueTimestamp<>(new Windowed<>("C", new TimeWindow(12, 22)), "0+2", 15), + // C@11 right window updated when C@16 processed + new KeyValueTimestamp<>(new Windowed<>("C", new TimeWindow(12, 22)), "0+2+3", 16), + // FINAL WINDOW: C@11 right window updated when C@21 processed + new KeyValueTimestamp<>(new Windowed<>("C", new TimeWindow(12, 22)), "0+2+3+4", 21), + // FINAL WINDOW: C@23 left window created when C@23 processed + new 
KeyValueTimestamp<>(new Windowed<>("C", new TimeWindow(13, 23)), "0+2+3+4+5", 23), + // C@15 right window created when C@16 processed + new KeyValueTimestamp<>(new Windowed<>("C", new TimeWindow(16, 26)), "0+3", 16), + // C@15 right window updated when C@21 processed + new KeyValueTimestamp<>(new Windowed<>("C", new TimeWindow(16, 26)), "0+3+4", 21), + // FINAL WINDOW: C@15 right window updated when C@23 processed + new KeyValueTimestamp<>(new Windowed<>("C", new TimeWindow(16, 26)), "0+3+4+5", 23), + // C@16 right window created when C@21 processed + new KeyValueTimestamp<>(new Windowed<>("C", new TimeWindow(17, 27)), "0+4", 21), + // FINAL WINDOW: C@16 right window updated when C@23 processed + new KeyValueTimestamp<>(new Windowed<>("C", new TimeWindow(17, 27)), "0+4+5", 23), + // FINAL WINDOW: C@21 right window created when C@23 processed + new KeyValueTimestamp<>(new Windowed<>("C", new TimeWindow(22, 32)), "0+5", 23), + // FINAL WINDOW: D@11 left window created when D@11 processed + new KeyValueTimestamp<>(new Windowed<>("D", new TimeWindow(1, 11)), "0+4", 11), + // FINAL WINDOW: D@12 left window created when D@12 processed + new KeyValueTimestamp<>(new Windowed<>("D", new TimeWindow(2, 12)), "0+4+2", 12), + // FINAL WINDOW: D@16 left window created when D@16 processed + new KeyValueTimestamp<>(new Windowed<>("D", new TimeWindow(6, 16)), "0+4+2+5", 16), + // D@11 right window created when D@12 processed + new KeyValueTimestamp<>(new Windowed<>("D", new TimeWindow(12, 22)), "0+2", 12), + // FINAL WINDOW: D@11 right window updated when D@16 processed + new KeyValueTimestamp<>(new Windowed<>("D", new TimeWindow(12, 22)), "0+2+5", 16), + // FINAL WINDOW: D@12 right window created when D@16 processed + new KeyValueTimestamp<>(new Windowed<>("D", new TimeWindow(13, 23)), "0+5", 16), + // FINAL WINDOW: D@29 left window created when D@29 processed + new KeyValueTimestamp<>(new Windowed<>("D", new TimeWindow(19, 29)), "0+3", 29)), + actual + ); + } + + @Test + public void testJoin() { + final StreamsBuilder builder = new StreamsBuilder(); + final String topic1 = "topic1"; + final String topic2 = "topic2"; + final WindowBytesStoreSupplier storeSupplier1 = + inOrderIterator + ? new InOrderMemoryWindowStoreSupplier("InOrder1", 50000L, 10L, false) + : Stores.inMemoryWindowStore("Reverse1", Duration.ofMillis(50000), Duration.ofMillis(10), false); + final WindowBytesStoreSupplier storeSupplier2 = + inOrderIterator + ? 
new InOrderMemoryWindowStoreSupplier("InOrder2", 50000L, 10L, false) + : Stores.inMemoryWindowStore("Reverse2", Duration.ofMillis(50000), Duration.ofMillis(10), false); + + final KTable, String> table1 = builder + .stream(topic1, Consumed.with(Serdes.String(), Serdes.String())) + .groupByKey(Grouped.with(Serdes.String(), Serdes.String())) + .windowedBy(SlidingWindows.withTimeDifferenceAndGrace(ofMillis(10), ofMillis(100))) + .aggregate( + MockInitializer.STRING_INIT, + MockAggregator.TOSTRING_ADDER, + Materialized.as(storeSupplier1) + ); + final KTable, String> table2 = builder + .stream(topic2, Consumed.with(Serdes.String(), Serdes.String())) + .groupByKey(Grouped.with(Serdes.String(), Serdes.String())) + .windowedBy(SlidingWindows.withTimeDifferenceAndGrace(ofMillis(10), ofMillis(100))) + .aggregate( + MockInitializer.STRING_INIT, + MockAggregator.TOSTRING_ADDER, + Materialized.as(storeSupplier2) + ); + + final MockProcessorSupplier, String> supplier = new MockProcessorSupplier<>(); + table1.toStream().process(supplier); + + table2.toStream().process(supplier); + + table1.join(table2, (p1, p2) -> p1 + "%" + p2).toStream().process(supplier); + + try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) { + final TestInputTopic inputTopic1 = + driver.createInputTopic(topic1, new StringSerializer(), new StringSerializer()); + final TestInputTopic inputTopic2 = + driver.createInputTopic(topic2, new StringSerializer(), new StringSerializer()); + inputTopic1.pipeInput("A", "1", 10L); + inputTopic1.pipeInput("B", "2", 11L); + inputTopic1.pipeInput("C", "3", 12L); + + final List, String>> processors = supplier.capturedProcessors(3); + + processors.get(0).checkAndClearProcessResult( + // left windows created by the first set of records to table 1 + new KeyValueTimestamp<>(new Windowed<>("A", new TimeWindow(0, 10)), "0+1", 10), + new KeyValueTimestamp<>(new Windowed<>("B", new TimeWindow(1, 11)), "0+2", 11), + new KeyValueTimestamp<>(new Windowed<>("C", new TimeWindow(2, 12)), "0+3", 12) + ); + processors.get(1).checkAndClearProcessResult(); + processors.get(2).checkAndClearProcessResult(); + + inputTopic1.pipeInput("A", "1", 15L); + inputTopic1.pipeInput("B", "2", 16L); + inputTopic1.pipeInput("C", "3", 19L); + + processors.get(0).checkAndClearProcessResult( + // right windows from previous records are created, and left windows from new records to table 1 + new KeyValueTimestamp<>(new Windowed<>("A", new TimeWindow(11, 21)), "0+1", 15), + new KeyValueTimestamp<>(new Windowed<>("A", new TimeWindow(5, 15)), "0+1+1", 15), + new KeyValueTimestamp<>(new Windowed<>("B", new TimeWindow(12, 22)), "0+2", 16), + new KeyValueTimestamp<>(new Windowed<>("B", new TimeWindow(6, 16)), "0+2+2", 16), + new KeyValueTimestamp<>(new Windowed<>("C", new TimeWindow(13, 23)), "0+3", 19), + new KeyValueTimestamp<>(new Windowed<>("C", new TimeWindow(9, 19)), "0+3+3", 19) + ); + processors.get(1).checkAndClearProcessResult(); + processors.get(2).checkAndClearProcessResult(); + + inputTopic2.pipeInput("A", "a", 10L); + inputTopic2.pipeInput("B", "b", 30L); + inputTopic2.pipeInput("C", "c", 12L); + inputTopic2.pipeInput("C", "c", 35L); + + + processors.get(0).checkAndClearProcessResult(); + processors.get(1).checkAndClearProcessResult( + // left windows from first set of records sent to table 2 + new KeyValueTimestamp<>(new Windowed<>("A", new TimeWindow(0, 10)), "0+a", 10), + new KeyValueTimestamp<>(new Windowed<>("B", new TimeWindow(20, 30)), "0+b", 30), + new KeyValueTimestamp<>(new 
Windowed<>("C", new TimeWindow(2, 12)), "0+c", 12), + new KeyValueTimestamp<>(new Windowed<>("C", new TimeWindow(25, 35)), "0+c", 35) + ); + processors.get(2).checkAndClearProcessResult( + // set of join windows from windows created by table 1 and table 2 + new KeyValueTimestamp<>(new Windowed<>("A", new TimeWindow(0, 10)), "0+1%0+a", 10), + new KeyValueTimestamp<>(new Windowed<>("C", new TimeWindow(2, 12)), "0+3%0+c", 12) + ); + + inputTopic2.pipeInput("A", "a", 15L); + inputTopic2.pipeInput("B", "b", 16L); + inputTopic2.pipeInput("C", "c", 17L); + + processors.get(0).checkAndClearProcessResult(); + processors.get(1).checkAndClearProcessResult( + // right windows from previous records are created (where applicable), and left windows from new records to table 2 + new KeyValueTimestamp<>(new Windowed<>("A", new TimeWindow(11, 21)), "0+a", 15), + new KeyValueTimestamp<>(new Windowed<>("A", new TimeWindow(5, 15)), "0+a+a", 15), + new KeyValueTimestamp<>(new Windowed<>("B", new TimeWindow(6, 16)), "0+b", 16), + new KeyValueTimestamp<>(new Windowed<>("C", new TimeWindow(13, 23)), "0+c", 17), + new KeyValueTimestamp<>(new Windowed<>("C", new TimeWindow(7, 17)), "0+c+c", 17) + ); + processors.get(2).checkAndClearProcessResult( + // set of join windows from windows created by table 1 and table 2 + new KeyValueTimestamp<>(new Windowed<>("A", new TimeWindow(11, 21)), "0+1%0+a", 15), + new KeyValueTimestamp<>(new Windowed<>("A", new TimeWindow(5, 15)), "0+1+1%0+a+a", 15), + new KeyValueTimestamp<>(new Windowed<>("B", new TimeWindow(6, 16)), "0+2+2%0+b", 16), + new KeyValueTimestamp<>(new Windowed<>("C", new TimeWindow(13, 23)), "0+3%0+c", 19) + ); + } + } + + @SuppressWarnings("unchecked") + @Test + public void testEarlyRecordsSmallInput() { + final StreamsBuilder builder = new StreamsBuilder(); + final String topic = "topic"; + + final KTable, String> table2 = builder + .stream(topic, Consumed.with(Serdes.String(), Serdes.String())) + .groupByKey(Grouped.with(Serdes.String(), Serdes.String())) + .windowedBy(SlidingWindows.withTimeDifferenceAndGrace(ofMillis(50), ofMillis(200))) + .aggregate( + MockInitializer.STRING_INIT, + MockAggregator.TOSTRING_ADDER, + Materialized.>as("topic-Canonized").withValueSerde(Serdes.String()) + ); + final MockProcessorSupplier, String> supplier = new MockProcessorSupplier<>(); + table2.toStream().process(supplier); + + try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) { + final TestInputTopic inputTopic = + driver.createInputTopic(topic, new StringSerializer(), new StringSerializer()); + + inputTopic.pipeInput("A", "1", 0L); + inputTopic.pipeInput("A", "2", 5L); + inputTopic.pipeInput("A", "3", 6L); + inputTopic.pipeInput("A", "4", 3L); + inputTopic.pipeInput("A", "5", 13L); + inputTopic.pipeInput("A", "6", 10L); + } + + final Map> actual = new HashMap<>(); + for (final KeyValueTimestamp, String> entry : supplier.theCapturedProcessor().processed()) { + final Windowed window = entry.key(); + final Long start = window.window().start(); + final ValueAndTimestamp valueAndTimestamp = ValueAndTimestamp.make(entry.value(), entry.timestamp()); + if (actual.putIfAbsent(start, valueAndTimestamp) != null) { + actual.replace(start, valueAndTimestamp); + } + } + + final Map> expected = new HashMap<>(); + expected.put(0L, ValueAndTimestamp.make("0+1+2+3+4+5+6", 13L)); + expected.put(1L, ValueAndTimestamp.make("0+2+3+4+5+6", 13L)); + expected.put(4L, ValueAndTimestamp.make("0+2+3+5+6", 13L)); + expected.put(6L, ValueAndTimestamp.make("0+3+5+6", 
13L)); + expected.put(7L, ValueAndTimestamp.make("0+5+6", 13L)); + expected.put(11L, ValueAndTimestamp.make("0+5", 13L)); + + assertEquals(expected, actual); + } + + @SuppressWarnings("unchecked") + @Test + public void testEarlyRecordsRepeatedInput() { + final StreamsBuilder builder = new StreamsBuilder(); + final String topic = "topic"; + + final KTable, String> table2 = builder + .stream(topic, Consumed.with(Serdes.String(), Serdes.String())) + .groupByKey(Grouped.with(Serdes.String(), Serdes.String())) + .windowedBy(SlidingWindows.withTimeDifferenceAndGrace(ofMillis(5), ofMillis(20))) + .aggregate( + MockInitializer.STRING_INIT, + MockAggregator.TOSTRING_ADDER, + Materialized.>as("topic-Canonized").withValueSerde(Serdes.String()) + ); + final MockProcessorSupplier, String> supplier = new MockProcessorSupplier<>(); + table2.toStream().process(supplier); + + try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) { + final TestInputTopic inputTopic = + driver.createInputTopic(topic, new StringSerializer(), new StringSerializer()); + + inputTopic.pipeInput("A", "1", 0L); + inputTopic.pipeInput("A", "2", 2L); + inputTopic.pipeInput("A", "3", 4L); + inputTopic.pipeInput("A", "4", 0L); + inputTopic.pipeInput("A", "5", 2L); + inputTopic.pipeInput("A", "6", 2L); + inputTopic.pipeInput("A", "7", 0L); + } + + final Map> actual = new HashMap<>(); + for (final KeyValueTimestamp, String> entry : supplier.theCapturedProcessor().processed()) { + final Windowed window = entry.key(); + final Long start = window.window().start(); + final ValueAndTimestamp valueAndTimestamp = ValueAndTimestamp.make(entry.value(), entry.timestamp()); + if (actual.putIfAbsent(start, valueAndTimestamp) != null) { + actual.replace(start, valueAndTimestamp); + } + } + + final Map> expected = new HashMap<>(); + expected.put(0L, ValueAndTimestamp.make("0+1+2+3+4+5+6+7", 4L)); + expected.put(1L, ValueAndTimestamp.make("0+2+3+5+6", 4L)); + expected.put(3L, ValueAndTimestamp.make("0+3", 4L)); + assertEquals(expected, actual); + } + + @Test + public void testEarlyRecordsLargeInput() { + final StreamsBuilder builder = new StreamsBuilder(); + final String topic = "topic"; + final WindowBytesStoreSupplier storeSupplier = + inOrderIterator + ? 
new InOrderMemoryWindowStoreSupplier("InOrder", 50000L, 10L, false) + : Stores.inMemoryWindowStore("Reverse", Duration.ofMillis(50000), Duration.ofMillis(10), false); + + final KTable, String> table2 = builder + .stream(topic, Consumed.with(Serdes.String(), Serdes.String())) + .groupByKey(Grouped.with(Serdes.String(), Serdes.String())) + .windowedBy(SlidingWindows.withTimeDifferenceAndGrace(ofMillis(10), ofMillis(50))) + .aggregate( + MockInitializer.STRING_INIT, + MockAggregator.TOSTRING_ADDER, + Materialized.as(storeSupplier) + ); + + final MockProcessorSupplier, String> supplier = new MockProcessorSupplier<>(); + table2.toStream().process(supplier); + + try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) { + final TestInputTopic inputTopic1 = + driver.createInputTopic(topic, new StringSerializer(), new StringSerializer()); + + inputTopic1.pipeInput("E", "1", 0L); + inputTopic1.pipeInput("E", "3", 5L); + inputTopic1.pipeInput("E", "4", 6L); + inputTopic1.pipeInput("E", "2", 3L); + inputTopic1.pipeInput("E", "6", 13L); + inputTopic1.pipeInput("E", "5", 10L); + inputTopic1.pipeInput("E", "7", 4L); + inputTopic1.pipeInput("E", "8", 2L); + inputTopic1.pipeInput("E", "9", 15L); + } + final Comparator, String>> comparator = + Comparator.comparing((KeyValueTimestamp, String> o) -> o.key().key()) + .thenComparing((KeyValueTimestamp, String> o) -> o.key().window().start()); + + final ArrayList, String>> actual = supplier.theCapturedProcessor().processed(); + actual.sort(comparator); + assertEquals( + asList( + // E@0 + new KeyValueTimestamp<>(new Windowed<>("E", new TimeWindow(0, 10)), "0+1", 0), + // E@5 + new KeyValueTimestamp<>(new Windowed<>("E", new TimeWindow(0, 10)), "0+1+3", 5), + // E@6 + new KeyValueTimestamp<>(new Windowed<>("E", new TimeWindow(0, 10)), "0+1+3+4", 6), + // E@3 + new KeyValueTimestamp<>(new Windowed<>("E", new TimeWindow(0, 10)), "0+1+3+4+2", 6), + //E@10 + new KeyValueTimestamp<>(new Windowed<>("E", new TimeWindow(0, 10)), "0+1+3+4+2+5", 10), + //E@4 + new KeyValueTimestamp<>(new Windowed<>("E", new TimeWindow(0, 10)), "0+1+3+4+2+5+7", 10), + //E@2 + new KeyValueTimestamp<>(new Windowed<>("E", new TimeWindow(0, 10)), "0+1+3+4+2+5+7+8", 10), + // E@5 + new KeyValueTimestamp<>(new Windowed<>("E", new TimeWindow(1, 11)), "0+3", 5), + // E@6 + new KeyValueTimestamp<>(new Windowed<>("E", new TimeWindow(1, 11)), "0+3+4", 6), + // E@3 + new KeyValueTimestamp<>(new Windowed<>("E", new TimeWindow(1, 11)), "0+3+4+2", 6), + //E@10 + new KeyValueTimestamp<>(new Windowed<>("E", new TimeWindow(1, 11)), "0+3+4+2+5", 10), + //E@4 + new KeyValueTimestamp<>(new Windowed<>("E", new TimeWindow(1, 11)), "0+3+4+2+5+7", 10), + //E@2 + new KeyValueTimestamp<>(new Windowed<>("E", new TimeWindow(1, 11)), "0+3+4+2+5+7+8", 10), + //E@13 + new KeyValueTimestamp<>(new Windowed<>("E", new TimeWindow(3, 13)), "0+3+4+2+6", 13), + //E@10 + new KeyValueTimestamp<>(new Windowed<>("E", new TimeWindow(3, 13)), "0+3+4+2+6+5", 13), + //E@4 + new KeyValueTimestamp<>(new Windowed<>("E", new TimeWindow(3, 13)), "0+3+4+2+6+5+7", 13), + // E@3 + new KeyValueTimestamp<>(new Windowed<>("E", new TimeWindow(4, 14)), "0+3+4", 6), + //E@13 + new KeyValueTimestamp<>(new Windowed<>("E", new TimeWindow(4, 14)), "0+3+4+6", 13), + //E@10 + new KeyValueTimestamp<>(new Windowed<>("E", new TimeWindow(4, 14)), "0+3+4+6+5", 13), + //E@4 + new KeyValueTimestamp<>(new Windowed<>("E", new TimeWindow(4, 14)), "0+3+4+6+5+7", 13), + //E@4 + new KeyValueTimestamp<>(new Windowed<>("E", new 
TimeWindow(5, 15)), "0+3+4+6+5", 13), + //E@15 + new KeyValueTimestamp<>(new Windowed<>("E", new TimeWindow(5, 15)), "0+3+4+6+5+9", 15), + // E@6 + new KeyValueTimestamp<>(new Windowed<>("E", new TimeWindow(6, 16)), "0+4", 6), + //E@13 + new KeyValueTimestamp<>(new Windowed<>("E", new TimeWindow(6, 16)), "0+4+6", 13), + //E@10 + new KeyValueTimestamp<>(new Windowed<>("E", new TimeWindow(6, 16)), "0+4+6+5", 13), + //E@15 + new KeyValueTimestamp<>(new Windowed<>("E", new TimeWindow(6, 16)), "0+4+6+5+9", 15), + //E@13 + new KeyValueTimestamp<>(new Windowed<>("E", new TimeWindow(7, 17)), "0+6", 13), + //E@10 + new KeyValueTimestamp<>(new Windowed<>("E", new TimeWindow(7, 17)), "0+6+5", 13), + //E@15 + new KeyValueTimestamp<>(new Windowed<>("E", new TimeWindow(7, 17)), "0+6+5+9", 15), + //E@10 + new KeyValueTimestamp<>(new Windowed<>("E", new TimeWindow(11, 21)), "0+6", 13), + //E@15 + new KeyValueTimestamp<>(new Windowed<>("E", new TimeWindow(11, 21)), "0+6+9", 15), + //E@15 + new KeyValueTimestamp<>(new Windowed<>("E", new TimeWindow(14, 24)), "0+9", 15)), + actual + ); + } + + @Test + public void shouldLogAndMeterWhenSkippingNullKey() { + final String builtInMetricsVersion = StreamsConfig.METRICS_LATEST; + final StreamsBuilder builder = new StreamsBuilder(); + final String topic = "topic"; + builder + .stream(topic, Consumed.with(Serdes.String(), Serdes.String())) + .groupByKey(Grouped.with(Serdes.String(), Serdes.String())) + .windowedBy(SlidingWindows.ofTimeDifferenceAndGrace(ofMillis(10), ofMillis(100))) + .aggregate(MockInitializer.STRING_INIT, MockAggregator.toStringInstance("+"), Materialized.>as("topic1-Canonicalized").withValueSerde(Serdes.String())); + + props.setProperty(StreamsConfig.BUILT_IN_METRICS_VERSION_CONFIG, builtInMetricsVersion); + + try (final LogCaptureAppender appender = LogCaptureAppender.createAndRegister(KStreamSlidingWindowAggregate.class); + final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) { + final TestInputTopic inputTopic = + driver.createInputTopic(topic, new StringSerializer(), new StringSerializer()); + inputTopic.pipeInput(null, "1"); + assertThat( + appender.getEvents().stream() + .filter(e -> e.getLevel().equals("WARN")) + .map(Event::getMessage) + .collect(Collectors.toList()), + hasItem("Skipping record due to null key or value. topic=[topic] partition=[0] offset=[0]") + ); + } + } + + @Test + public void shouldLogAndMeterWhenSkippingExpiredWindowByGrace() { + final String builtInMetricsVersion = StreamsConfig.METRICS_LATEST; + final StreamsBuilder builder = new StreamsBuilder(); + final String topic = "topic"; + final WindowBytesStoreSupplier storeSupplier = + inOrderIterator + ? 
new InOrderMemoryWindowStoreSupplier("InOrder", 50000L, 10L, false) + : Stores.inMemoryWindowStore("Reverse", Duration.ofMillis(50000), Duration.ofMillis(10), false); + + final KStream stream1 = builder.stream(topic, Consumed.with(Serdes.String(), Serdes.String())); + stream1.groupByKey(Grouped.with(Serdes.String(), Serdes.String())) + .windowedBy(SlidingWindows.ofTimeDifferenceAndGrace(ofMillis(10), ofMillis(90))) + .aggregate( + MockInitializer.STRING_INIT, + MockAggregator.TOSTRING_ADDER, + Materialized.as(storeSupplier) + ) + .toStream() + .to("output"); + + props.setProperty(StreamsConfig.BUILT_IN_METRICS_VERSION_CONFIG, builtInMetricsVersion); + + try (final LogCaptureAppender appender = LogCaptureAppender.createAndRegister(KStreamSlidingWindowAggregate.class); + final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) { + + final TestInputTopic inputTopic = + driver.createInputTopic(topic, new StringSerializer(), new StringSerializer()); + inputTopic.pipeInput("k", "100", 200L); + inputTopic.pipeInput("k", "0", 100L); + inputTopic.pipeInput("k", "1", 101L); + inputTopic.pipeInput("k", "2", 102L); + inputTopic.pipeInput("k", "3", 103L); + inputTopic.pipeInput("k", "4", 104L); + inputTopic.pipeInput("k", "5", 105L); + inputTopic.pipeInput("k", "6", 15L); + + assertLatenessMetrics(driver, is(7.0), is(185.0), is(96.25)); + + assertThat(appender.getMessages(), hasItems( + // left window for k@100 + "Skipping record for expired window. topic=[topic] partition=[0] offset=[1] timestamp=[100] window=[90,100] expiration=[110] streamTime=[200]", + // left window for k@101 + "Skipping record for expired window. topic=[topic] partition=[0] offset=[2] timestamp=[101] window=[91,101] expiration=[110] streamTime=[200]", + // left window for k@102 + "Skipping record for expired window. topic=[topic] partition=[0] offset=[3] timestamp=[102] window=[92,102] expiration=[110] streamTime=[200]", + // left window for k@103 + "Skipping record for expired window. topic=[topic] partition=[0] offset=[4] timestamp=[103] window=[93,103] expiration=[110] streamTime=[200]", + // left window for k@104 + "Skipping record for expired window. topic=[topic] partition=[0] offset=[5] timestamp=[104] window=[94,104] expiration=[110] streamTime=[200]", + // left window for k@105 + "Skipping record for expired window. topic=[topic] partition=[0] offset=[6] timestamp=[105] window=[95,105] expiration=[110] streamTime=[200]", + // left window for k@15 + "Skipping record for expired window. topic=[topic] partition=[0] offset=[7] timestamp=[15] window=[5,15] expiration=[110] streamTime=[200]" + )); + final TestOutputTopic, String> outputTopic = + driver.createOutputTopic("output", new TimeWindowedDeserializer<>(new StringDeserializer(), 10L), new StringDeserializer()); + assertThat(outputTopic.readRecord(), equalTo(new TestRecord<>(new Windowed<>("k", new TimeWindow(190, 200)), "0+100", null, 200L))); + assertTrue(outputTopic.isEmpty()); + } + } + + @SuppressWarnings("unchecked") + @Test + public void testAggregateRandomInput() { + + final StreamsBuilder builder = new StreamsBuilder(); + final String topic1 = "topic1"; + final WindowBytesStoreSupplier storeSupplier = + inOrderIterator + ? 
new InOrderMemoryWindowStoreSupplier("InOrder", 50000L, 10L, false) + : Stores.inMemoryWindowStore("Reverse", Duration.ofMillis(50000), Duration.ofMillis(10), false); + + final KTable, String> table = builder + .stream(topic1, Consumed.with(Serdes.String(), Serdes.String())) + .groupByKey(Grouped.with(Serdes.String(), Serdes.String())) + .windowedBy(SlidingWindows.ofTimeDifferenceAndGrace(ofMillis(10), ofMillis(10000))) + // The aggregator needs to sort the strings so the window value is the same for the final windows even when + // records are processed in a different order. Here, we sort alphabetically. + .aggregate( + () -> "", + (key, value, aggregate) -> { + aggregate += value; + final char[] ch = aggregate.toCharArray(); + Arrays.sort(ch); + aggregate = String.valueOf(ch); + return aggregate; + }, + Materialized.as(storeSupplier) + ); + final MockProcessorSupplier, String> supplier = new MockProcessorSupplier<>(); + table.toStream().process(supplier); + final long seed = new Random().nextLong(); + final Random shuffle = new Random(seed); + + try { + + final List> input = Arrays.asList( + ValueAndTimestamp.make("A", 10L), + ValueAndTimestamp.make("B", 15L), + ValueAndTimestamp.make("C", 16L), + ValueAndTimestamp.make("D", 18L), + ValueAndTimestamp.make("E", 30L), + ValueAndTimestamp.make("F", 40L), + ValueAndTimestamp.make("G", 55L), + ValueAndTimestamp.make("H", 56L), + ValueAndTimestamp.make("I", 58L), + ValueAndTimestamp.make("J", 58L), + ValueAndTimestamp.make("K", 62L), + ValueAndTimestamp.make("L", 63L), + ValueAndTimestamp.make("M", 63L), + ValueAndTimestamp.make("N", 63L), + ValueAndTimestamp.make("O", 76L), + ValueAndTimestamp.make("P", 77L), + ValueAndTimestamp.make("Q", 80L), + ValueAndTimestamp.make("R", 2L), + ValueAndTimestamp.make("S", 3L), + ValueAndTimestamp.make("T", 5L), + ValueAndTimestamp.make("U", 8L) + ); + + Collections.shuffle(input, shuffle); + try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) { + final TestInputTopic inputTopic1 = + driver.createInputTopic(topic1, new StringSerializer(), new StringSerializer()); + for (int i = 0; i < input.size(); i++) { + inputTopic1.pipeInput("A", input.get(i).value(), input.get(i).timestamp()); + } + } + + final Map> results = new HashMap<>(); + + for (final KeyValueTimestamp, String> entry : supplier.theCapturedProcessor().processed()) { + final Windowed window = entry.key(); + final Long start = window.window().start(); + final ValueAndTimestamp valueAndTimestamp = ValueAndTimestamp.make(entry.value(), entry.timestamp()); + if (results.putIfAbsent(start, valueAndTimestamp) != null) { + results.replace(start, valueAndTimestamp); + } + } + verifyRandomTestResults(results); + } catch (final AssertionError t) { + throw new AssertionError( + "Assertion failed in randomized test. Reproduce with seed: " + seed + ".", + t + ); + } catch (final Throwable t) { + final StringBuilder sb = + new StringBuilder() + .append("Exception in randomized scenario. 
Reproduce with seed: ") + .append(seed) + .append("."); + throw new AssertionError(sb.toString(), t); + } + } + + private void verifyRandomTestResults(final Map> actual) { + final Map> expected = new HashMap<>(); + expected.put(0L, ValueAndTimestamp.make("ARSTU", 10L)); + expected.put(3L, ValueAndTimestamp.make("ASTU", 10L)); + expected.put(4L, ValueAndTimestamp.make("ATU", 10L)); + expected.put(5L, ValueAndTimestamp.make("ABTU", 15L)); + expected.put(6L, ValueAndTimestamp.make("ABCU", 16L)); + expected.put(8L, ValueAndTimestamp.make("ABCDU", 18L)); + expected.put(9L, ValueAndTimestamp.make("ABCD", 18L)); + expected.put(11L, ValueAndTimestamp.make("BCD", 18L)); + expected.put(16L, ValueAndTimestamp.make("CD", 18L)); + expected.put(17L, ValueAndTimestamp.make("D", 18L)); + expected.put(20L, ValueAndTimestamp.make("E", 30L)); + expected.put(30L, ValueAndTimestamp.make("EF", 40L)); + expected.put(31L, ValueAndTimestamp.make("F", 40L)); + expected.put(45L, ValueAndTimestamp.make("G", 55L)); + expected.put(46L, ValueAndTimestamp.make("GH", 56L)); + expected.put(48L, ValueAndTimestamp.make("GHIJ", 58L)); + expected.put(52L, ValueAndTimestamp.make("GHIJK", 62L)); + expected.put(53L, ValueAndTimestamp.make("GHIJKLMN", 63L)); + expected.put(56L, ValueAndTimestamp.make("HIJKLMN", 63L)); + expected.put(57L, ValueAndTimestamp.make("IJKLMN", 63L)); + expected.put(59L, ValueAndTimestamp.make("KLMN", 63L)); + expected.put(63L, ValueAndTimestamp.make("LMN", 63L)); + expected.put(66L, ValueAndTimestamp.make("O", 76L)); + expected.put(67L, ValueAndTimestamp.make("OP", 77L)); + expected.put(70L, ValueAndTimestamp.make("OPQ", 80L)); + expected.put(77L, ValueAndTimestamp.make("PQ", 80L)); + expected.put(78L, ValueAndTimestamp.make("Q", 80L)); + + assertEquals(expected, actual); + } + + private void assertLatenessMetrics(final TopologyTestDriver driver, + final Matcher dropTotal, + final Matcher maxLateness, + final Matcher avgLateness) { + + final MetricName dropTotalMetric; + final MetricName dropRateMetric; + final MetricName latenessMaxMetric; + final MetricName latenessAvgMetric; + dropTotalMetric = new MetricName( + "dropped-records-total", + "stream-task-metrics", + "The total number of dropped records", + mkMap( + mkEntry("thread-id", threadId), + mkEntry("task-id", "0_0") + ) + ); + dropRateMetric = new MetricName( + "dropped-records-rate", + "stream-task-metrics", + "The average number of dropped records per second", + mkMap( + mkEntry("thread-id", threadId), + mkEntry("task-id", "0_0") + ) + ); + latenessMaxMetric = new MetricName( + "record-lateness-max", + "stream-task-metrics", + "The observed maximum lateness of records in milliseconds, measured by comparing the record " + + "timestamp with the current stream time", + mkMap( + mkEntry("thread-id", threadId), + mkEntry("task-id", "0_0") + ) + ); + latenessAvgMetric = new MetricName( + "record-lateness-avg", + "stream-task-metrics", + "The observed average lateness of records in milliseconds, measured by comparing the record " + + "timestamp with the current stream time", + mkMap( + mkEntry("thread-id", threadId), + mkEntry("task-id", "0_0") + ) + ); + assertThat(driver.metrics().get(dropTotalMetric).metricValue(), dropTotal); + assertThat(driver.metrics().get(dropRateMetric).metricValue(), not(0.0)); + assertThat(driver.metrics().get(latenessMaxMetric).metricValue(), maxLateness); + assertThat(driver.metrics().get(latenessAvgMetric).metricValue(), avgLateness); + } + + private static class InOrderMemoryWindowStore extends InMemoryWindowStore { + 
InOrderMemoryWindowStore(final String name,
+                                 final long retentionPeriod,
+                                 final long windowSize,
+                                 final boolean retainDuplicates,
+                                 final String metricScope) {
+            super(name, retentionPeriod, windowSize, retainDuplicates, metricScope);
+        }
+
+        @Override
+        public WindowStoreIterator<byte[]> backwardFetch(final Bytes key, final long timeFrom, final long timeTo) {
+            throw new UnsupportedOperationException("Backward fetch not supported here");
+        }
+
+        @Override
+        public KeyValueIterator<Windowed<Bytes>, byte[]> backwardFetch(final Bytes keyFrom,
+                                                                       final Bytes keyTo,
+                                                                       final long timeFrom,
+                                                                       final long timeTo) {
+            throw new UnsupportedOperationException("Backward fetch not supported here");
+        }
+
+        @Override
+        public KeyValueIterator<Windowed<Bytes>, byte[]> backwardFetchAll(final long timeFrom, final long timeTo) {
+            throw new UnsupportedOperationException("Backward fetch not supported here");
+        }
+
+        @Override
+        public KeyValueIterator<Windowed<Bytes>, byte[]> backwardAll() {
+            throw new UnsupportedOperationException("Backward fetch not supported here");
+        }
+    }
+
+    private static class InOrderMemoryWindowStoreSupplier extends InMemoryWindowBytesStoreSupplier {
+
+        InOrderMemoryWindowStoreSupplier(final String name,
+                                         final long retentionPeriod,
+                                         final long windowSize,
+                                         final boolean retainDuplicates) {
+            super(name, retentionPeriod, windowSize, retainDuplicates);
+        }
+
+        @Override
+        public WindowStore<Bytes, byte[]> get() {
+            return new InOrderMemoryWindowStore(name(),
+                                                retentionPeriod(),
+                                                windowSize(),
+                                                retainDuplicates(),
+                                                metricsScope());
+        }
+    }
+}
diff --git a/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KStreamSplitTest.java b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KStreamSplitTest.java
new file mode 100644
index 0000000..29eaf1a
--- /dev/null
+++ b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KStreamSplitTest.java
@@ -0,0 +1,149 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.streams.kstream.internals;
+
+import org.apache.kafka.common.serialization.IntegerDeserializer;
+import org.apache.kafka.common.serialization.IntegerSerializer;
+import org.apache.kafka.common.serialization.Serdes;
+import org.apache.kafka.common.serialization.StringDeserializer;
+import org.apache.kafka.common.serialization.StringSerializer;
+import org.apache.kafka.streams.StreamsBuilder;
+import org.apache.kafka.streams.TestInputTopic;
+import org.apache.kafka.streams.TestOutputTopic;
+import org.apache.kafka.streams.Topology;
+import org.apache.kafka.streams.TopologyTestDriver;
+import org.apache.kafka.streams.kstream.Branched;
+import org.apache.kafka.streams.kstream.Consumed;
+import org.apache.kafka.streams.kstream.KStream;
+import org.apache.kafka.streams.kstream.Named;
+import org.apache.kafka.streams.kstream.Predicate;
+import org.apache.kafka.test.StreamsTestUtils;
+import org.junit.Test;
+
+import java.util.Arrays;
+import java.util.Map;
+import java.util.Properties;
+import java.util.function.Consumer;
+
+import static org.junit.Assert.assertEquals;
+
+public class KStreamSplitTest {
+
+    private final String topicName = "topic";
+    private final Properties props = StreamsTestUtils.getStreamsConfig(Serdes.String(), Serdes.String());
+    private final StreamsBuilder builder = new StreamsBuilder();
+    private final Predicate<Integer, String> isEven = (key, value) -> (key % 2) == 0;
+    private final Predicate<Integer, String> isMultipleOfThree = (key, value) -> (key % 3) == 0;
+    private final Predicate<Integer, String> isMultipleOfFive = (key, value) -> (key % 5) == 0;
+    private final Predicate<Integer, String> isMultipleOfSeven = (key, value) -> (key % 7) == 0;
+    private final Predicate<Integer, String> isNegative = (key, value) -> key < 0;
+    private final KStream<Integer, String> source = builder.stream(topicName, Consumed.with(Serdes.Integer(), Serdes.String()));
+
+    @Test
+    public void testKStreamSplit() {
+        final Map<String, KStream<Integer, String>> branches =
+            source.split()
+                .branch(isEven, Branched.withConsumer(ks -> ks.to("x2")))
+                .branch(isMultipleOfThree, Branched.withConsumer(ks -> ks.to("x3")))
+                .branch(isMultipleOfFive, Branched.withConsumer(ks -> ks.to("x5"))).noDefaultBranch();
+
+        assertEquals(0, branches.size());
+
+        builder.build();
+
+        withDriver(driver -> {
+            final TestOutputTopic<Integer, String> x2 = driver.createOutputTopic("x2", new IntegerDeserializer(), new StringDeserializer());
+            final TestOutputTopic<Integer, String> x3 = driver.createOutputTopic("x3", new IntegerDeserializer(), new StringDeserializer());
+            final TestOutputTopic<Integer, String> x5 = driver.createOutputTopic("x5", new IntegerDeserializer(), new StringDeserializer());
+            assertEquals(Arrays.asList("V0", "V2", "V4", "V6"), x2.readValuesToList());
+            assertEquals(Arrays.asList("V3"), x3.readValuesToList());
+            assertEquals(Arrays.asList("V5"), x5.readValuesToList());
+        });
+    }
+
+    private void withDriver(final Consumer<TopologyTestDriver> test) {
+        final int[] expectedKeys = new int[]{-1, 0, 1, 2, 3, 4, 5, 6, 7};
+        final Topology topology = builder.build();
+        try (final TopologyTestDriver driver = new TopologyTestDriver(topology, props)) {
+            final TestInputTopic<Integer, String> inputTopic = driver.createInputTopic(topicName, new IntegerSerializer(), new StringSerializer());
+            for (final int expectedKey : expectedKeys) {
+                inputTopic.pipeInput(expectedKey, "V" + expectedKey);
+            }
+            test.accept(driver);
+        }
+    }
+
+    @Test
+    public void testTypeVariance() {
+        final Predicate<Number, Object> positive = (key, value) -> key.doubleValue() > 0;
+        final Predicate<Number, Object> negative = (key, value) -> key.doubleValue() < 0;
+        new StreamsBuilder()
+            .<Integer, String>stream("empty")
+            .split()
+            .branch(positive)
+            .branch(negative);
+    }
+
+    @Test
+    public void testResultingMap() {
+        final Map<String, KStream<Integer, String>> branches =
+            source.split(Named.as("foo-"))
+                // "foo-bar"
+                .branch(isEven, Branched.as("bar"))
+                // no entry: a Consumer is provided
+                .branch(isMultipleOfThree, Branched.withConsumer(ks -> { }))
+                // no entry: chain function returns null
+                .branch(isMultipleOfFive, Branched.withFunction(ks -> null))
+                // "foo-4": chain function returns non-null value
+                .branch(isNegative, Branched.withFunction(ks -> ks))
+                // "foo-5": name defaults to the branch position
+                .branch(isMultipleOfSeven)
+                // "foo-0": "0" is the default name for the default branch
+                .defaultBranch();
+        assertEquals(4, branches.size());
+        // direct the branched streams into different topics named with branch name
+        for (final Map.Entry<String, KStream<Integer, String>> branch: branches.entrySet()) {
+            branch.getValue().to(branch.getKey());
+        }
+        builder.build();
+
+        withDriver(driver -> {
+            final TestOutputTopic<Integer, String> even = driver.createOutputTopic("foo-bar", new IntegerDeserializer(), new StringDeserializer());
+            final TestOutputTopic<Integer, String> negative = driver.createOutputTopic("foo-4", new IntegerDeserializer(), new StringDeserializer());
+            final TestOutputTopic<Integer, String> x7 = driver.createOutputTopic("foo-5", new IntegerDeserializer(), new StringDeserializer());
+            final TestOutputTopic<Integer, String> defaultBranch = driver.createOutputTopic("foo-0", new IntegerDeserializer(), new StringDeserializer());
+            assertEquals(Arrays.asList("V0", "V2", "V4", "V6"), even.readValuesToList());
+            assertEquals(Arrays.asList("V-1"), negative.readValuesToList());
+            assertEquals(Arrays.asList("V7"), x7.readValuesToList());
+            assertEquals(Arrays.asList("V1"), defaultBranch.readValuesToList());
+        });
+    }
+
+    @Test
+    public void testBranchingWithNoTerminalOperation() {
+        final String outputTopicName = "output";
+        source.split()
+            .branch(isEven, Branched.withConsumer(ks -> ks.to(outputTopicName)))
+            .branch(isMultipleOfFive, Branched.withConsumer(ks -> ks.to(outputTopicName)));
+        builder.build();
+        withDriver(driver -> {
+            final TestOutputTopic<Integer, String> outputTopic =
+                driver.createOutputTopic(outputTopicName, new IntegerDeserializer(), new StringDeserializer());
+            assertEquals(Arrays.asList("V0", "V2", "V4", "V5", "V6"), outputTopic.readValuesToList());
+        });
+    }
+}
diff --git a/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KStreamTransformTest.java b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KStreamTransformTest.java
new file mode 100644
index 0000000..5aad9f0
--- /dev/null
+++ b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KStreamTransformTest.java
@@ -0,0 +1,170 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +package org.apache.kafka.streams.kstream.internals; + +import org.apache.kafka.common.serialization.IntegerSerializer; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.KeyValueTimestamp; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.TestInputTopic; +import org.apache.kafka.streams.TopologyTestDriver; +import org.apache.kafka.streams.kstream.Consumed; +import org.apache.kafka.streams.kstream.KStream; +import org.apache.kafka.streams.kstream.Transformer; +import org.apache.kafka.streams.kstream.TransformerSupplier; +import org.apache.kafka.streams.processor.ProcessorContext; +import org.apache.kafka.streams.processor.PunctuationType; +import org.apache.kafka.streams.processor.To; +import org.apache.kafka.test.MockProcessorSupplier; +import org.apache.kafka.test.StreamsTestUtils; +import org.junit.Test; + +import java.time.Duration; +import java.time.Instant; +import java.util.Properties; + +import static org.junit.Assert.assertEquals; + +public class KStreamTransformTest { + private static final String TOPIC_NAME = "topic"; + private final Properties props = StreamsTestUtils.getStreamsConfig(Serdes.Integer(), Serdes.Integer()); + + @SuppressWarnings("deprecation") // Old PAPI. Needs to be migrated. + @Test + public void testTransform() { + final StreamsBuilder builder = new StreamsBuilder(); + + final TransformerSupplier> transformerSupplier = + () -> new Transformer>() { + private int total = 0; + + @Override + public void init(final ProcessorContext context) { + context.schedule( + Duration.ofMillis(1), + PunctuationType.WALL_CLOCK_TIME, + timestamp -> context.forward(-1, (int) timestamp, To.all().withTimestamp(timestamp)) + ); + } + + @Override + public KeyValue transform(final Number key, final Number value) { + total += value.intValue(); + return KeyValue.pair(key.intValue() * 2, total); + } + + @Override + public void close() { } + }; + + final int[] expectedKeys = {1, 10, 100, 1000}; + + final MockProcessorSupplier processor = new MockProcessorSupplier<>(); + final KStream stream = builder.stream(TOPIC_NAME, Consumed.with(Serdes.Integer(), Serdes.Integer())); + stream.transform(transformerSupplier).process(processor); + + try (final TopologyTestDriver driver = new TopologyTestDriver( + builder.build(), + Instant.ofEpochMilli(0L))) { + final TestInputTopic inputTopic = + driver.createInputTopic(TOPIC_NAME, new IntegerSerializer(), new IntegerSerializer()); + + for (final int expectedKey : expectedKeys) { + inputTopic.pipeInput(expectedKey, expectedKey * 10, expectedKey / 2L); + } + + driver.advanceWallClockTime(Duration.ofMillis(2)); + driver.advanceWallClockTime(Duration.ofMillis(1)); + + final KeyValueTimestamp[] expected = { + new KeyValueTimestamp<>(2, 10, 0), + new KeyValueTimestamp<>(20, 110, 5), + new KeyValueTimestamp<>(200, 1110, 50), + new KeyValueTimestamp<>(2000, 11110, 500), + new KeyValueTimestamp<>(-1, 2, 2), + new KeyValueTimestamp<>(-1, 3, 3) + }; + + assertEquals(expected.length, processor.theCapturedProcessor().processed().size()); + for (int i = 0; i < expected.length; i++) { + assertEquals(expected[i], processor.theCapturedProcessor().processed().get(i)); + } + } + } + + @SuppressWarnings("deprecation") // Old PAPI. Needs to be migrated. 
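+    /*
+     * Both tests in this class use the same stateful Transformer: transform() doubles the key,
+     * keeps a running total of the values, and init() schedules a 1 ms wall-clock punctuator
+     * that forwards an extra (-1, wallClockMillis) record via To.all().withTimestamp(). That
+     * punctuator is where the trailing (-1, 2) and (-1, 3) entries of the expected output come
+     * from, one for each advanceWallClockTime() call. A minimal sketch of wiring such a
+     * transformer outside the test harness (topic names are illustrative only):
+     *
+     *   builder.stream("input", Consumed.with(Serdes.Integer(), Serdes.Integer()))
+     *          .transform(transformerSupplier)
+     *          .to("output", Produced.with(Serdes.Integer(), Serdes.Integer()));
+     */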
+ @Test + public void testTransformWithNewDriverAndPunctuator() { + final StreamsBuilder builder = new StreamsBuilder(); + + final TransformerSupplier> transformerSupplier = + () -> new Transformer>() { + private int total = 0; + + @Override + public void init(final ProcessorContext context) { + context.schedule( + Duration.ofMillis(1), + PunctuationType.WALL_CLOCK_TIME, + timestamp -> context.forward(-1, (int) timestamp, To.all().withTimestamp(timestamp))); + } + + @Override + public KeyValue transform(final Number key, final Number value) { + total += value.intValue(); + return KeyValue.pair(key.intValue() * 2, total); + } + + @Override + public void close() { } + }; + + final int[] expectedKeys = {1, 10, 100, 1000}; + + final MockProcessorSupplier processor = new MockProcessorSupplier<>(); + final KStream stream = builder.stream(TOPIC_NAME, Consumed.with(Serdes.Integer(), Serdes.Integer())); + stream.transform(transformerSupplier).process(processor); + + try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props, Instant.ofEpochMilli(0L))) { + final TestInputTopic inputTopic = + driver.createInputTopic(TOPIC_NAME, new IntegerSerializer(), new IntegerSerializer()); + for (final int expectedKey : expectedKeys) { + inputTopic.pipeInput(expectedKey, expectedKey * 10, 0L); + } + + // This tick yields the "-1:2" result + driver.advanceWallClockTime(Duration.ofMillis(2)); + // This tick further advances the clock to 3, which leads to the "-1:3" result + driver.advanceWallClockTime(Duration.ofMillis(1)); + } + + assertEquals(6, processor.theCapturedProcessor().processed().size()); + + final KeyValueTimestamp[] expected = {new KeyValueTimestamp<>(2, 10, 0), + new KeyValueTimestamp<>(20, 110, 0), + new KeyValueTimestamp<>(200, 1110, 0), + new KeyValueTimestamp<>(2000, 11110, 0), + new KeyValueTimestamp<>(-1, 2, 2), + new KeyValueTimestamp<>(-1, 3, 3)}; + + for (int i = 0; i < expected.length; i++) { + assertEquals(expected[i], processor.theCapturedProcessor().processed().get(i)); + } + } + +} diff --git a/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KStreamTransformValuesTest.java b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KStreamTransformValuesTest.java new file mode 100644 index 0000000..983e52d --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KStreamTransformValuesTest.java @@ -0,0 +1,154 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.kstream.internals; + +import org.apache.kafka.common.serialization.IntegerSerializer; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.streams.KeyValueTimestamp; +import org.apache.kafka.streams.kstream.Consumed; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.TopologyTestDriver; +import org.apache.kafka.streams.kstream.KStream; +import org.apache.kafka.streams.kstream.ValueTransformer; +import org.apache.kafka.streams.kstream.ValueTransformerSupplier; +import org.apache.kafka.streams.kstream.ValueTransformerWithKey; +import org.apache.kafka.streams.kstream.ValueTransformerWithKeySupplier; +import org.apache.kafka.streams.processor.ProcessorContext; +import org.apache.kafka.streams.processor.internals.ForwardingDisabledProcessorContext; +import org.apache.kafka.streams.TestInputTopic; +import org.apache.kafka.test.MockProcessorSupplier; +import org.apache.kafka.test.NoOpValueTransformerWithKeySupplier; +import org.apache.kafka.test.StreamsTestUtils; +import org.easymock.EasyMockRunner; +import org.easymock.Mock; +import org.easymock.MockType; +import org.junit.Test; +import org.junit.runner.RunWith; + +import java.util.Properties; + +import static org.hamcrest.CoreMatchers.isA; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.Assert.assertArrayEquals; + +@RunWith(EasyMockRunner.class) +@SuppressWarnings("deprecation") // Old PAPI. Needs to be migrated. +public class KStreamTransformValuesTest { + private final String topicName = "topic"; + private final MockProcessorSupplier supplier = new MockProcessorSupplier<>(); + private final Properties props = StreamsTestUtils.getStreamsConfig(Serdes.Integer(), Serdes.Integer()); + @Mock(MockType.NICE) + private ProcessorContext context; + + @SuppressWarnings("deprecation") // Old PAPI. Needs to be migrated. + @Test + public void testTransform() { + final StreamsBuilder builder = new StreamsBuilder(); + + final ValueTransformerSupplier valueTransformerSupplier = + () -> new ValueTransformer() { + private int total = 0; + + @Override + public void init(final ProcessorContext context) { } + + @Override + public Integer transform(final Number value) { + total += value.intValue(); + return total; + } + + @Override + public void close() { } + }; + + final int[] expectedKeys = {1, 10, 100, 1000}; + + final KStream stream; + stream = builder.stream(topicName, Consumed.with(Serdes.Integer(), Serdes.Integer())); + stream.transformValues(valueTransformerSupplier).process(supplier); + + try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) { + for (final int expectedKey : expectedKeys) { + final TestInputTopic inputTopic = + driver.createInputTopic(topicName, new IntegerSerializer(), new IntegerSerializer()); + inputTopic.pipeInput(expectedKey, expectedKey * 10, expectedKey / 2L); + } + } + final KeyValueTimestamp[] expected = {new KeyValueTimestamp<>(1, 10, 0), + new KeyValueTimestamp<>(10, 110, 5), + new KeyValueTimestamp<>(100, 1110, 50), + new KeyValueTimestamp<>(1000, 11110, 500)}; + + assertArrayEquals(expected, supplier.theCapturedProcessor().processed().toArray()); + } + + @SuppressWarnings("deprecation") // Old PAPI. Needs to be migrated. 
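+    /*
+     * transformValues() replaces only the value: the key and the record timestamp pass through
+     * unchanged, which is why the expected results above keep the original keys (1, 10, 100, 1000)
+     * and timestamps (key / 2). The test below exercises the key-aware variant,
+     * ValueTransformerWithKey, whose transform() additionally receives the read-only key and folds
+     * it into the running total. A minimal usage sketch (assuming the same suppliers as the tests):
+     *
+     *   builder.stream("topic", Consumed.with(Serdes.Integer(), Serdes.Integer()))
+     *          .transformValues(valueTransformerSupplier)
+     *          .process(supplier);
+     */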
+ @Test + public void testTransformWithKey() { + final StreamsBuilder builder = new StreamsBuilder(); + + final ValueTransformerWithKeySupplier valueTransformerSupplier = + () -> new ValueTransformerWithKey() { + private int total = 0; + + @Override + public void init(final ProcessorContext context) { } + + @Override + public Integer transform(final Integer readOnlyKey, final Number value) { + total += value.intValue() + readOnlyKey; + return total; + } + + @Override + public void close() { } + }; + + final int[] expectedKeys = {1, 10, 100, 1000}; + + final KStream stream; + stream = builder.stream(topicName, Consumed.with(Serdes.Integer(), Serdes.Integer())); + stream.transformValues(valueTransformerSupplier).process(supplier); + + try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) { + final TestInputTopic inputTopic = + driver.createInputTopic(topicName, new IntegerSerializer(), new IntegerSerializer()); + for (final int expectedKey : expectedKeys) { + inputTopic.pipeInput(expectedKey, expectedKey * 10, expectedKey / 2L); + } + } + final KeyValueTimestamp[] expected = {new KeyValueTimestamp<>(1, 11, 0), + new KeyValueTimestamp<>(10, 121, 5), + new KeyValueTimestamp<>(100, 1221, 50), + new KeyValueTimestamp<>(1000, 12221, 500)}; + + assertArrayEquals(expected, supplier.theCapturedProcessor().processed().toArray()); + } + + @SuppressWarnings("unchecked") + @Test + public void shouldInitializeTransformerWithForwardDisabledProcessorContext() { + final NoOpValueTransformerWithKeySupplier transformer = new NoOpValueTransformerWithKeySupplier<>(); + final KStreamTransformValues transformValues = new KStreamTransformValues<>(transformer); + final org.apache.kafka.streams.processor.Processor processor = transformValues.get(); + + processor.init(context); + + assertThat(transformer.context, isA((Class) ForwardingDisabledProcessorContext.class)); + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KStreamWindowAggregateTest.java b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KStreamWindowAggregateTest.java new file mode 100644 index 0000000..df40c94 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KStreamWindowAggregateTest.java @@ -0,0 +1,462 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.kstream.internals; + +import org.apache.kafka.common.MetricName; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.serialization.StringDeserializer; +import org.apache.kafka.common.serialization.StringSerializer; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.KeyValueTimestamp; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.TestOutputTopic; +import org.apache.kafka.streams.TopologyTestDriver; +import org.apache.kafka.streams.kstream.Consumed; +import org.apache.kafka.streams.kstream.Grouped; +import org.apache.kafka.streams.kstream.KStream; +import org.apache.kafka.streams.kstream.KTable; +import org.apache.kafka.streams.kstream.Materialized; +import org.apache.kafka.streams.kstream.TimeWindows; +import org.apache.kafka.streams.kstream.Windowed; +import org.apache.kafka.streams.processor.internals.testutil.LogCaptureAppender; +import org.apache.kafka.streams.state.WindowStore; +import org.apache.kafka.streams.TestInputTopic; +import org.apache.kafka.streams.test.TestRecord; +import org.apache.kafka.test.MockAggregator; +import org.apache.kafka.test.MockInitializer; +import org.apache.kafka.test.MockProcessor; +import org.apache.kafka.test.MockProcessorSupplier; +import org.apache.kafka.test.StreamsTestUtils; +import org.hamcrest.Matcher; +import org.junit.Test; + +import java.time.Duration; +import java.util.List; +import java.util.Properties; + +import static java.time.Duration.ofMillis; +import static java.util.Arrays.asList; +import static org.apache.kafka.common.utils.Utils.mkEntry; +import static org.apache.kafka.common.utils.Utils.mkMap; +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.CoreMatchers.hasItem; +import static org.hamcrest.CoreMatchers.hasItems; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.CoreMatchers.not; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +@SuppressWarnings("deprecation") +public class KStreamWindowAggregateTest { + private final Properties props = StreamsTestUtils.getStreamsConfig(Serdes.String(), Serdes.String()); + private final String threadId = Thread.currentThread().getName(); + + @SuppressWarnings("deprecation") // Old PAPI. Needs to be migrated. 
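+    /*
+     * The tests below use hopping windows: TimeWindows.of(ofMillis(10)).advanceBy(ofMillis(5))
+     * yields overlapping windows [0,10), [5,15), [10,20), ..., so a record with timestamp 6
+     * updates the aggregates of both [0,10) and [5,15). A minimal sketch of the same aggregation
+     * with the generics spelled out (store name illustrative):
+     *
+     *   final KTable<Windowed<String>, String> aggregated = builder
+     *       .stream("topic1", Consumed.with(Serdes.String(), Serdes.String()))
+     *       .groupByKey(Grouped.with(Serdes.String(), Serdes.String()))
+     *       .windowedBy(TimeWindows.of(ofMillis(10)).advanceBy(ofMillis(5)))
+     *       .aggregate(
+     *           MockInitializer.STRING_INIT,
+     *           MockAggregator.TOSTRING_ADDER,
+     *           Materialized.<String, String, WindowStore<Bytes, byte[]>>as("agg-store")
+     *               .withValueSerde(Serdes.String()));
+     */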
+ @Test + public void testAggBasic() { + final StreamsBuilder builder = new StreamsBuilder(); + final String topic1 = "topic1"; + + final KTable, String> table2 = builder + .stream(topic1, Consumed.with(Serdes.String(), Serdes.String())) + .groupByKey(Grouped.with(Serdes.String(), Serdes.String())) + .windowedBy(TimeWindows.of(ofMillis(10)).advanceBy(ofMillis(5))) + .aggregate(MockInitializer.STRING_INIT, MockAggregator.TOSTRING_ADDER, Materialized.>as("topic1-Canonized").withValueSerde(Serdes.String())); + + final MockProcessorSupplier, String> supplier = new MockProcessorSupplier<>(); + table2.toStream().process(supplier); + + try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) { + final TestInputTopic inputTopic1 = + driver.createInputTopic(topic1, new StringSerializer(), new StringSerializer()); + inputTopic1.pipeInput("A", "1", 0L); + inputTopic1.pipeInput("B", "2", 1L); + inputTopic1.pipeInput("C", "3", 2L); + inputTopic1.pipeInput("D", "4", 3L); + inputTopic1.pipeInput("A", "1", 4L); + + inputTopic1.pipeInput("A", "1", 5L); + inputTopic1.pipeInput("B", "2", 6L); + inputTopic1.pipeInput("D", "4", 7L); + inputTopic1.pipeInput("B", "2", 8L); + inputTopic1.pipeInput("C", "3", 9L); + + inputTopic1.pipeInput("A", "1", 10L); + inputTopic1.pipeInput("B", "2", 11L); + inputTopic1.pipeInput("D", "4", 12L); + inputTopic1.pipeInput("B", "2", 13L); + inputTopic1.pipeInput("C", "3", 14L); + + inputTopic1.pipeInput("B", "1", 3L); + inputTopic1.pipeInput("B", "2", 2L); + inputTopic1.pipeInput("B", "3", 9L); + } + + assertEquals( + asList( + new KeyValueTimestamp<>(new Windowed<>("A", new TimeWindow(0, 10)), "0+1", 0), + new KeyValueTimestamp<>(new Windowed<>("B", new TimeWindow(0, 10)), "0+2", 1), + new KeyValueTimestamp<>(new Windowed<>("C", new TimeWindow(0, 10)), "0+3", 2), + new KeyValueTimestamp<>(new Windowed<>("D", new TimeWindow(0, 10)), "0+4", 3), + new KeyValueTimestamp<>(new Windowed<>("A", new TimeWindow(0, 10)), "0+1+1", 4), + new KeyValueTimestamp<>(new Windowed<>("A", new TimeWindow(0, 10)), "0+1+1+1", 5), + new KeyValueTimestamp<>(new Windowed<>("A", new TimeWindow(5, 15)), "0+1", 5), + new KeyValueTimestamp<>(new Windowed<>("B", new TimeWindow(0, 10)), "0+2+2", 6), + new KeyValueTimestamp<>(new Windowed<>("B", new TimeWindow(5, 15)), "0+2", 6), + new KeyValueTimestamp<>(new Windowed<>("D", new TimeWindow(0, 10)), "0+4+4", 7), + new KeyValueTimestamp<>(new Windowed<>("D", new TimeWindow(5, 15)), "0+4", 7), + new KeyValueTimestamp<>(new Windowed<>("B", new TimeWindow(0, 10)), "0+2+2+2", 8), + new KeyValueTimestamp<>(new Windowed<>("B", new TimeWindow(5, 15)), "0+2+2", 8), + new KeyValueTimestamp<>(new Windowed<>("C", new TimeWindow(0, 10)), "0+3+3", 9), + new KeyValueTimestamp<>(new Windowed<>("C", new TimeWindow(5, 15)), "0+3", 9), + new KeyValueTimestamp<>(new Windowed<>("A", new TimeWindow(5, 15)), "0+1+1", 10), + new KeyValueTimestamp<>(new Windowed<>("A", new TimeWindow(10, 20)), "0+1", 10), + new KeyValueTimestamp<>(new Windowed<>("B", new TimeWindow(5, 15)), "0+2+2+2", 11), + new KeyValueTimestamp<>(new Windowed<>("B", new TimeWindow(10, 20)), "0+2", 11), + new KeyValueTimestamp<>(new Windowed<>("D", new TimeWindow(5, 15)), "0+4+4", 12), + new KeyValueTimestamp<>(new Windowed<>("D", new TimeWindow(10, 20)), "0+4", 12), + new KeyValueTimestamp<>(new Windowed<>("B", new TimeWindow(5, 15)), "0+2+2+2+2", 13), + new KeyValueTimestamp<>(new Windowed<>("B", new TimeWindow(10, 20)), "0+2+2", 13), + new KeyValueTimestamp<>(new Windowed<>("C", new 
TimeWindow(5, 15)), "0+3+3", 14), + new KeyValueTimestamp<>(new Windowed<>("C", new TimeWindow(10, 20)), "0+3", 14), + new KeyValueTimestamp<>(new Windowed<>("B", new TimeWindow(0, 10)), "0+2+2+2+1", 8), + new KeyValueTimestamp<>(new Windowed<>("B", new TimeWindow(0, 10)), "0+2+2+2+1+2", 8), + new KeyValueTimestamp<>(new Windowed<>("B", new TimeWindow(0, 10)), "0+2+2+2+1+2+3", 9), + new KeyValueTimestamp<>(new Windowed<>("B", new TimeWindow(5, 15)), "0+2+2+2+2+3", 13) + + ), + supplier.theCapturedProcessor().processed() + ); + } + + @SuppressWarnings("deprecation") // Old PAPI. Needs to be migrated. + @Test + public void testJoin() { + final StreamsBuilder builder = new StreamsBuilder(); + final String topic1 = "topic1"; + final String topic2 = "topic2"; + + final KTable, String> table1 = builder + .stream(topic1, Consumed.with(Serdes.String(), Serdes.String())) + .groupByKey(Grouped.with(Serdes.String(), Serdes.String())) + .windowedBy(TimeWindows.of(ofMillis(10)).advanceBy(ofMillis(5))) + .aggregate(MockInitializer.STRING_INIT, MockAggregator.TOSTRING_ADDER, Materialized.>as("topic1-Canonized").withValueSerde(Serdes.String())); + + final MockProcessorSupplier, String> supplier = new MockProcessorSupplier<>(); + table1.toStream().process(supplier); + + final KTable, String> table2 = builder + .stream(topic2, Consumed.with(Serdes.String(), Serdes.String())) + .groupByKey(Grouped.with(Serdes.String(), Serdes.String())) + .windowedBy(TimeWindows.of(ofMillis(10)).advanceBy(ofMillis(5))) + .aggregate(MockInitializer.STRING_INIT, MockAggregator.TOSTRING_ADDER, Materialized.>as("topic2-Canonized").withValueSerde(Serdes.String())); + table2.toStream().process(supplier); + + table1.join(table2, (p1, p2) -> p1 + "%" + p2).toStream().process(supplier); + + try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) { + final TestInputTopic inputTopic1 = + driver.createInputTopic(topic1, new StringSerializer(), new StringSerializer()); + final TestInputTopic inputTopic2 = + driver.createInputTopic(topic2, new StringSerializer(), new StringSerializer()); + inputTopic1.pipeInput("A", "1", 0L); + inputTopic1.pipeInput("B", "2", 1L); + inputTopic1.pipeInput("C", "3", 2L); + inputTopic1.pipeInput("D", "4", 3L); + inputTopic1.pipeInput("A", "1", 9L); + + final List, String>> processors = supplier.capturedProcessors(3); + + processors.get(0).checkAndClearProcessResult( + new KeyValueTimestamp<>(new Windowed<>("A", new TimeWindow(0, 10)), "0+1", 0), + new KeyValueTimestamp<>(new Windowed<>("B", new TimeWindow(0, 10)), "0+2", 1), + new KeyValueTimestamp<>(new Windowed<>("C", new TimeWindow(0, 10)), "0+3", 2), + new KeyValueTimestamp<>(new Windowed<>("D", new TimeWindow(0, 10)), "0+4", 3), + new KeyValueTimestamp<>(new Windowed<>("A", new TimeWindow(0, 10)), "0+1+1", 9), + new KeyValueTimestamp<>(new Windowed<>("A", new TimeWindow(5, 15)), "0+1", 9) + ); + processors.get(1).checkAndClearProcessResult(); + processors.get(2).checkAndClearProcessResult(); + + inputTopic1.pipeInput("A", "1", 5L); + inputTopic1.pipeInput("B", "2", 6L); + inputTopic1.pipeInput("D", "4", 7L); + inputTopic1.pipeInput("B", "2", 8L); + inputTopic1.pipeInput("C", "3", 9L); + + processors.get(0).checkAndClearProcessResult( + new KeyValueTimestamp<>(new Windowed<>("A", new TimeWindow(0, 10)), "0+1+1+1", 9), + new KeyValueTimestamp<>(new Windowed<>("A", new TimeWindow(5, 15)), "0+1+1", 9), + new KeyValueTimestamp<>(new Windowed<>("B", new TimeWindow(0, 10)), "0+2+2", 6), + new KeyValueTimestamp<>(new 
Windowed<>("B", new TimeWindow(5, 15)), "0+2", 6), + new KeyValueTimestamp<>(new Windowed<>("D", new TimeWindow(0, 10)), "0+4+4", 7), + new KeyValueTimestamp<>(new Windowed<>("D", new TimeWindow(5, 15)), "0+4", 7), + new KeyValueTimestamp<>(new Windowed<>("B", new TimeWindow(0, 10)), "0+2+2+2", 8), + new KeyValueTimestamp<>(new Windowed<>("B", new TimeWindow(5, 15)), "0+2+2", 8), + new KeyValueTimestamp<>(new Windowed<>("C", new TimeWindow(0, 10)), "0+3+3", 9), + new KeyValueTimestamp<>(new Windowed<>("C", new TimeWindow(5, 15)), "0+3", 9) + ); + processors.get(1).checkAndClearProcessResult(); + processors.get(2).checkAndClearProcessResult(); + + inputTopic2.pipeInput("A", "a", 0L); + inputTopic2.pipeInput("B", "b", 1L); + inputTopic2.pipeInput("C", "c", 2L); + inputTopic2.pipeInput("D", "d", 20L); + inputTopic2.pipeInput("A", "a", 20L); + + processors.get(0).checkAndClearProcessResult(); + processors.get(1).checkAndClearProcessResult( + new KeyValueTimestamp<>(new Windowed<>("A", new TimeWindow(0, 10)), "0+a", 0), + new KeyValueTimestamp<>(new Windowed<>("B", new TimeWindow(0, 10)), "0+b", 1), + new KeyValueTimestamp<>(new Windowed<>("C", new TimeWindow(0, 10)), "0+c", 2), + new KeyValueTimestamp<>(new Windowed<>("D", new TimeWindow(15, 25)), "0+d", 20), + new KeyValueTimestamp<>(new Windowed<>("D", new TimeWindow(20, 30)), "0+d", 20), + new KeyValueTimestamp<>(new Windowed<>("A", new TimeWindow(15, 25)), "0+a", 20), + new KeyValueTimestamp<>(new Windowed<>("A", new TimeWindow(20, 30)), "0+a", 20) + ); + processors.get(2).checkAndClearProcessResult( + new KeyValueTimestamp<>(new Windowed<>("A", new TimeWindow(0, 10)), "0+1+1+1%0+a", 9), + new KeyValueTimestamp<>(new Windowed<>("B", new TimeWindow(0, 10)), "0+2+2+2%0+b", 8), + new KeyValueTimestamp<>(new Windowed<>("C", new TimeWindow(0, 10)), "0+3+3%0+c", 9)); + + inputTopic2.pipeInput("A", "a", 5L); + inputTopic2.pipeInput("B", "b", 6L); + inputTopic2.pipeInput("D", "d", 7L); + inputTopic2.pipeInput("D", "d", 18L); + inputTopic2.pipeInput("A", "a", 21L); + + processors.get(0).checkAndClearProcessResult(); + processors.get(1).checkAndClearProcessResult( + new KeyValueTimestamp<>(new Windowed<>("A", new TimeWindow(0, 10)), "0+a+a", 5), + new KeyValueTimestamp<>(new Windowed<>("A", new TimeWindow(5, 15)), "0+a", 5), + new KeyValueTimestamp<>(new Windowed<>("B", new TimeWindow(0, 10)), "0+b+b", 6), + new KeyValueTimestamp<>(new Windowed<>("B", new TimeWindow(5, 15)), "0+b", 6), + new KeyValueTimestamp<>(new Windowed<>("D", new TimeWindow(0, 10)), "0+d", 7), + new KeyValueTimestamp<>(new Windowed<>("D", new TimeWindow(5, 15)), "0+d", 7), + new KeyValueTimestamp<>(new Windowed<>("D", new TimeWindow(10, 20)), "0+d", 18), + new KeyValueTimestamp<>(new Windowed<>("D", new TimeWindow(15, 25)), "0+d+d", 20), + new KeyValueTimestamp<>(new Windowed<>("A", new TimeWindow(15, 25)), "0+a+a", 21), + new KeyValueTimestamp<>(new Windowed<>("A", new TimeWindow(20, 30)), "0+a+a", 21) + ); + processors.get(2).checkAndClearProcessResult( + new KeyValueTimestamp<>(new Windowed<>("A", new TimeWindow(0, 10)), "0+1+1+1%0+a+a", 9), + new KeyValueTimestamp<>(new Windowed<>("A", new TimeWindow(5, 15)), "0+1+1%0+a", 9), + new KeyValueTimestamp<>(new Windowed<>("B", new TimeWindow(0, 10)), "0+2+2+2%0+b+b", 8), + new KeyValueTimestamp<>(new Windowed<>("B", new TimeWindow(5, 15)), "0+2+2%0+b", 8), + new KeyValueTimestamp<>(new Windowed<>("D", new TimeWindow(0, 10)), "0+4+4%0+d", 7), + new KeyValueTimestamp<>(new Windowed<>("D", new TimeWindow(5, 15)), "0+4%0+d", 7) + ); 
+ } + } + + @Test + public void shouldLogAndMeterWhenSkippingNullKey() { + final StreamsBuilder builder = new StreamsBuilder(); + final String topic = "topic"; + + builder + .stream(topic, Consumed.with(Serdes.String(), Serdes.String())) + .groupByKey(Grouped.with(Serdes.String(), Serdes.String())) + .windowedBy(TimeWindows.of(ofMillis(10)).advanceBy(ofMillis(5))) + .aggregate( + MockInitializer.STRING_INIT, + MockAggregator.toStringInstance("+"), + Materialized.>as("topic1-Canonicalized").withValueSerde(Serdes.String()) + ); + + try (final LogCaptureAppender appender = LogCaptureAppender.createAndRegister(KStreamWindowAggregate.class); + final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) { + + final TestInputTopic inputTopic = + driver.createInputTopic(topic, new StringSerializer(), new StringSerializer()); + inputTopic.pipeInput(null, "1"); + + assertThat(appender.getMessages(), hasItem("Skipping record due to null key. topic=[topic] partition=[0] offset=[0]")); + } + } + + @Test + public void shouldLogAndMeterWhenSkippingExpiredWindow() { + final StreamsBuilder builder = new StreamsBuilder(); + final String topic = "topic"; + + final KStream stream1 = builder.stream(topic, Consumed.with(Serdes.String(), Serdes.String())); + stream1.groupByKey(Grouped.with(Serdes.String(), Serdes.String())) + .windowedBy(TimeWindows.of(ofMillis(10)).advanceBy(ofMillis(5)).grace(ofMillis(90))) + .aggregate( + () -> "", + MockAggregator.toStringInstance("+"), + Materialized.>as("topic1-Canonicalized") + .withValueSerde(Serdes.String()) + .withCachingDisabled() + .withLoggingDisabled() + .withRetention(Duration.ofMillis(100)) + ) + .toStream() + .map((key, value) -> new KeyValue<>(key.toString(), value)) + .to("output"); + + try (final LogCaptureAppender appender = LogCaptureAppender.createAndRegister(KStreamWindowAggregate.class); + final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) { + + final TestInputTopic inputTopic = + driver.createInputTopic(topic, new StringSerializer(), new StringSerializer()); + inputTopic.pipeInput("k", "100", 100L); + inputTopic.pipeInput("k", "0", 0L); + inputTopic.pipeInput("k", "1", 1L); + inputTopic.pipeInput("k", "2", 2L); + inputTopic.pipeInput("k", "3", 3L); + inputTopic.pipeInput("k", "4", 4L); + inputTopic.pipeInput("k", "5", 5L); + inputTopic.pipeInput("k", "6", 6L); + + assertLatenessMetrics( + driver, + is(7.0), // how many events get dropped + is(100.0), // k:0 is 100ms late, since its time is 0, but it arrives at stream time 100. + is(84.875) // (0 + 100 + 99 + 98 + 97 + 96 + 95 + 94) / 8 + ); + + assertThat(appender.getMessages(), hasItems( + "Skipping record for expired window. topic=[topic] partition=[0] offset=[1] timestamp=[0] window=[0,10) expiration=[10] streamTime=[100]", + "Skipping record for expired window. topic=[topic] partition=[0] offset=[2] timestamp=[1] window=[0,10) expiration=[10] streamTime=[100]", + "Skipping record for expired window. topic=[topic] partition=[0] offset=[3] timestamp=[2] window=[0,10) expiration=[10] streamTime=[100]", + "Skipping record for expired window. topic=[topic] partition=[0] offset=[4] timestamp=[3] window=[0,10) expiration=[10] streamTime=[100]", + "Skipping record for expired window. topic=[topic] partition=[0] offset=[5] timestamp=[4] window=[0,10) expiration=[10] streamTime=[100]", + "Skipping record for expired window. 
topic=[topic] partition=[0] offset=[6] timestamp=[5] window=[0,10) expiration=[10] streamTime=[100]", + "Skipping record for expired window. topic=[topic] partition=[0] offset=[7] timestamp=[6] window=[0,10) expiration=[10] streamTime=[100]" + )); + + final TestOutputTopic outputTopic = + driver.createOutputTopic("output", new StringDeserializer(), new StringDeserializer()); + + assertThat(outputTopic.readRecord(), equalTo(new TestRecord<>("[k@95/105]", "+100", null, 100L))); + assertThat(outputTopic.readRecord(), equalTo(new TestRecord<>("[k@100/110]", "+100", null, 100L))); + assertThat(outputTopic.readRecord(), equalTo(new TestRecord<>("[k@5/15]", "+5", null, 5L))); + assertThat(outputTopic.readRecord(), equalTo(new TestRecord<>("[k@5/15]", "+5+6", null, 6L))); + assertTrue(outputTopic.isEmpty()); + } + } + + @Test + public void shouldLogAndMeterWhenSkippingExpiredWindowByGrace() { + final StreamsBuilder builder = new StreamsBuilder(); + final String topic = "topic"; + + final KStream stream1 = builder.stream(topic, Consumed.with(Serdes.String(), Serdes.String())); + stream1.groupByKey(Grouped.with(Serdes.String(), Serdes.String())) + .windowedBy(TimeWindows.of(ofMillis(10)).advanceBy(ofMillis(10)).grace(ofMillis(90L))) + .aggregate( + () -> "", + MockAggregator.toStringInstance("+"), + Materialized.>as("topic1-Canonicalized").withValueSerde(Serdes.String()).withCachingDisabled().withLoggingDisabled() + ) + .toStream() + .map((key, value) -> new KeyValue<>(key.toString(), value)) + .to("output"); + + try (final LogCaptureAppender appender = LogCaptureAppender.createAndRegister(KStreamWindowAggregate.class); + final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) { + + final TestInputTopic inputTopic = + driver.createInputTopic(topic, new StringSerializer(), new StringSerializer()); + inputTopic.pipeInput("k", "100", 200L); + inputTopic.pipeInput("k", "0", 100L); + inputTopic.pipeInput("k", "1", 101L); + inputTopic.pipeInput("k", "2", 102L); + inputTopic.pipeInput("k", "3", 103L); + inputTopic.pipeInput("k", "4", 104L); + inputTopic.pipeInput("k", "5", 105L); + inputTopic.pipeInput("k", "6", 6L); + + assertLatenessMetrics(driver, is(7.0), is(194.0), is(97.375)); + + assertThat(appender.getMessages(), hasItems( + "Skipping record for expired window. topic=[topic] partition=[0] offset=[1] timestamp=[100] window=[100,110) expiration=[110] streamTime=[200]", + "Skipping record for expired window. topic=[topic] partition=[0] offset=[2] timestamp=[101] window=[100,110) expiration=[110] streamTime=[200]", + "Skipping record for expired window. topic=[topic] partition=[0] offset=[3] timestamp=[102] window=[100,110) expiration=[110] streamTime=[200]", + "Skipping record for expired window. topic=[topic] partition=[0] offset=[4] timestamp=[103] window=[100,110) expiration=[110] streamTime=[200]", + "Skipping record for expired window. topic=[topic] partition=[0] offset=[5] timestamp=[104] window=[100,110) expiration=[110] streamTime=[200]", + "Skipping record for expired window. topic=[topic] partition=[0] offset=[6] timestamp=[105] window=[100,110) expiration=[110] streamTime=[200]", + "Skipping record for expired window. 
topic=[topic] partition=[0] offset=[7] timestamp=[6] window=[0,10) expiration=[110] streamTime=[200]" + )); + + final TestOutputTopic outputTopic = + driver.createOutputTopic("output", new StringDeserializer(), new StringDeserializer()); + assertThat(outputTopic.readRecord(), equalTo(new TestRecord<>("[k@200/210]", "+100", null, 200L))); + assertTrue(outputTopic.isEmpty()); + } + } + + private void assertLatenessMetrics(final TopologyTestDriver driver, + final Matcher dropTotal, + final Matcher maxLateness, + final Matcher avgLateness) { + + final MetricName dropTotalMetric; + final MetricName dropRateMetric; + final MetricName latenessMaxMetric; + final MetricName latenessAvgMetric; + dropTotalMetric = new MetricName( + "dropped-records-total", + "stream-task-metrics", + "The total number of dropped records", + mkMap( + mkEntry("thread-id", threadId), + mkEntry("task-id", "0_0") + ) + ); + dropRateMetric = new MetricName( + "dropped-records-rate", + "stream-task-metrics", + "The average number of dropped records per second", + mkMap( + mkEntry("thread-id", threadId), + mkEntry("task-id", "0_0") + ) + ); + latenessMaxMetric = new MetricName( + "record-lateness-max", + "stream-task-metrics", + "The observed maximum lateness of records in milliseconds, measured by comparing the record " + + "timestamp with the current stream time", + mkMap( + mkEntry("thread-id", threadId), + mkEntry("task-id", "0_0") + ) + ); + latenessAvgMetric = new MetricName( + "record-lateness-avg", + "stream-task-metrics", + "The observed average lateness of records in milliseconds, measured by comparing the record " + + "timestamp with the current stream time", + mkMap( + mkEntry("thread-id", threadId), + mkEntry("task-id", "0_0") + ) + ); + + assertThat(driver.metrics().get(dropTotalMetric).metricValue(), dropTotal); + assertThat(driver.metrics().get(dropRateMetric).metricValue(), not(0.0)); + assertThat(driver.metrics().get(latenessMaxMetric).metricValue(), maxLateness); + assertThat(driver.metrics().get(latenessAvgMetric).metricValue(), avgLateness); + } + +} \ No newline at end of file diff --git a/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KTableAggregateTest.java b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KTableAggregateTest.java new file mode 100644 index 0000000..220734f --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KTableAggregateTest.java @@ -0,0 +1,274 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.kstream.internals; + +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.serialization.StringSerializer; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.KeyValueTimestamp; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.TestInputTopic; +import org.apache.kafka.streams.TopologyTestDriver; +import org.apache.kafka.streams.kstream.Consumed; +import org.apache.kafka.streams.kstream.Grouped; +import org.apache.kafka.streams.kstream.KTable; +import org.apache.kafka.streams.kstream.Materialized; +import org.apache.kafka.streams.state.KeyValueStore; +import org.apache.kafka.test.MockAggregator; +import org.apache.kafka.test.MockApiProcessor; +import org.apache.kafka.test.MockApiProcessorSupplier; +import org.apache.kafka.test.MockInitializer; +import org.apache.kafka.test.MockMapper; +import org.apache.kafka.test.TestUtils; +import org.junit.Test; +import java.util.Properties; + +import java.time.Duration; +import java.time.Instant; + +import static java.util.Arrays.asList; +import static org.apache.kafka.common.utils.Utils.mkEntry; +import static org.apache.kafka.common.utils.Utils.mkMap; +import static org.apache.kafka.common.utils.Utils.mkProperties; +import static org.junit.Assert.assertEquals; + +public class KTableAggregateTest { + private final Serde stringSerde = Serdes.String(); + private final Consumed consumed = Consumed.with(stringSerde, stringSerde); + private final Grouped stringSerialized = Grouped.with(stringSerde, stringSerde); + private final MockApiProcessorSupplier supplier = new MockApiProcessorSupplier<>(); + private final static Properties CONFIG = mkProperties(mkMap( + mkEntry(StreamsConfig.STATE_DIR_CONFIG, TestUtils.tempDirectory("kafka-test").getAbsolutePath()))); + + @Test + public void testAggBasic() { + final StreamsBuilder builder = new StreamsBuilder(); + final String topic1 = "topic1"; + + final KTable table1 = builder.table(topic1, consumed); + final KTable table2 = table1 + .groupBy( + MockMapper.noOpKeyValueMapper(), + stringSerialized) + .aggregate( + MockInitializer.STRING_INIT, + MockAggregator.TOSTRING_ADDER, + MockAggregator.TOSTRING_REMOVER, + Materialized.>as("topic1-Canonized") + .withValueSerde(stringSerde)); + + table2.toStream().process(supplier); + + try ( + final TopologyTestDriver driver = new TopologyTestDriver( + builder.build(), CONFIG, Instant.ofEpochMilli(0L))) { + final TestInputTopic inputTopic = + driver.createInputTopic(topic1, new StringSerializer(), new StringSerializer(), Instant.ofEpochMilli(0L), Duration.ZERO); + + inputTopic.pipeInput("A", "1", 10L); + inputTopic.pipeInput("B", "2", 15L); + inputTopic.pipeInput("A", "3", 20L); + inputTopic.pipeInput("B", "4", 18L); + inputTopic.pipeInput("C", "5", 5L); + inputTopic.pipeInput("D", "6", 25L); + inputTopic.pipeInput("B", "7", 15L); + inputTopic.pipeInput("C", "8", 10L); + + assertEquals( + asList( + new KeyValueTimestamp<>("A", "0+1", 10L), + new KeyValueTimestamp<>("B", "0+2", 15L), + new KeyValueTimestamp<>("A", "0+1-1", 20L), + new KeyValueTimestamp<>("A", "0+1-1+3", 20L), + new KeyValueTimestamp<>("B", "0+2-2", 18L), + new KeyValueTimestamp<>("B", "0+2-2+4", 18L), + new KeyValueTimestamp<>("C", "0+5", 5L), + new KeyValueTimestamp<>("D", "0+6", 25L), + new KeyValueTimestamp<>("B", "0+2-2+4-4", 18L), + new 
KeyValueTimestamp<>("B", "0+2-2+4-4+7", 18L), + new KeyValueTimestamp<>("C", "0+5-5", 10L), + new KeyValueTimestamp<>("C", "0+5-5+8", 10L)), + supplier.theCapturedProcessor().processed()); + } + } + + @Test + public void testAggRepartition() { + final StreamsBuilder builder = new StreamsBuilder(); + final String topic1 = "topic1"; + + final KTable table1 = builder.table(topic1, consumed); + final KTable table2 = table1 + .groupBy( + (key, value) -> { + switch (key) { + case "null": + return KeyValue.pair(null, value); + case "NULL": + return null; + default: + return KeyValue.pair(value, value); + } + }, + stringSerialized) + .aggregate( + MockInitializer.STRING_INIT, + MockAggregator.TOSTRING_ADDER, + MockAggregator.TOSTRING_REMOVER, + Materialized.>as("topic1-Canonized") + .withValueSerde(stringSerde)); + + table2.toStream().process(supplier); + + try ( + final TopologyTestDriver driver = new TopologyTestDriver( + builder.build(), CONFIG, Instant.ofEpochMilli(0L))) { + final TestInputTopic inputTopic = + driver.createInputTopic(topic1, new StringSerializer(), new StringSerializer(), Instant.ofEpochMilli(0L), Duration.ZERO); + + inputTopic.pipeInput("A", "1", 10L); + inputTopic.pipeInput("A", (String) null, 15L); + inputTopic.pipeInput("A", "1", 12L); + inputTopic.pipeInput("B", "2", 20L); + inputTopic.pipeInput("null", "3", 25L); + inputTopic.pipeInput("B", "4", 23L); + inputTopic.pipeInput("NULL", "5", 24L); + inputTopic.pipeInput("B", "7", 22L); + + assertEquals( + asList( + new KeyValueTimestamp<>("1", "0+1", 10), + new KeyValueTimestamp<>("1", "0+1-1", 15), + new KeyValueTimestamp<>("1", "0+1-1+1", 15), + new KeyValueTimestamp<>("2", "0+2", 20), + new KeyValueTimestamp<>("2", "0+2-2", 23), + new KeyValueTimestamp<>("4", "0+4", 23), + new KeyValueTimestamp<>("4", "0+4-4", 23), + new KeyValueTimestamp<>("7", "0+7", 22)), + supplier.theCapturedProcessor().processed()); + } + } + + private static void testCountHelper(final StreamsBuilder builder, + final String input, + final MockApiProcessorSupplier supplier) { + try ( + final TopologyTestDriver driver = new TopologyTestDriver( + builder.build(), CONFIG, Instant.ofEpochMilli(0L))) { + final TestInputTopic inputTopic = + driver.createInputTopic(input, new StringSerializer(), new StringSerializer(), Instant.ofEpochMilli(0L), Duration.ZERO); + + inputTopic.pipeInput("A", "green", 10L); + inputTopic.pipeInput("B", "green", 9L); + inputTopic.pipeInput("A", "blue", 12L); + inputTopic.pipeInput("C", "yellow", 15L); + inputTopic.pipeInput("D", "green", 11L); + + assertEquals( + asList( + new KeyValueTimestamp<>("green", 1L, 10), + new KeyValueTimestamp<>("green", 2L, 10), + new KeyValueTimestamp<>("green", 1L, 12), + new KeyValueTimestamp<>("blue", 1L, 12), + new KeyValueTimestamp<>("yellow", 1L, 15), + new KeyValueTimestamp<>("green", 2L, 12)), + supplier.theCapturedProcessor().processed()); + } + } + + + @Test + public void testCount() { + final StreamsBuilder builder = new StreamsBuilder(); + final String input = "count-test-input"; + + builder + .table(input, consumed) + .groupBy(MockMapper.selectValueKeyValueMapper(), stringSerialized) + .count(Materialized.as("count")) + .toStream() + .process(supplier); + + testCountHelper(builder, input, supplier); + } + + @Test + public void testCountWithInternalStore() { + final StreamsBuilder builder = new StreamsBuilder(); + final String input = "count-test-input"; + + builder + .table(input, consumed) + .groupBy(MockMapper.selectValueKeyValueMapper(), stringSerialized) + .count() + .toStream() + 
.process(supplier); + + testCountHelper(builder, input, supplier); + } + + @Test + public void testRemoveOldBeforeAddNew() { + final StreamsBuilder builder = new StreamsBuilder(); + final String input = "count-test-input"; + final MockApiProcessorSupplier supplier = new MockApiProcessorSupplier<>(); + + builder + .table(input, consumed) + .groupBy( + (key, value) -> KeyValue.pair( + String.valueOf(key.charAt(0)), + String.valueOf(key.charAt(1))), + stringSerialized) + .aggregate( + () -> "", + (aggKey, value, aggregate) -> aggregate + value, + (key, value, aggregate) -> aggregate.replaceAll(value, ""), + Materialized.>as("someStore") + .withValueSerde(Serdes.String())) + .toStream() + .process(supplier); + + try ( + final TopologyTestDriver driver = new TopologyTestDriver( + builder.build(), CONFIG, Instant.ofEpochMilli(0L))) { + final TestInputTopic inputTopic = + driver.createInputTopic(input, new StringSerializer(), new StringSerializer(), Instant.ofEpochMilli(0L), Duration.ZERO); + + final MockApiProcessor proc = supplier.theCapturedProcessor(); + + inputTopic.pipeInput("11", "A", 10L); + inputTopic.pipeInput("12", "B", 8L); + inputTopic.pipeInput("11", (String) null, 12L); + inputTopic.pipeInput("12", "C", 6L); + + assertEquals( + asList( + new KeyValueTimestamp<>("1", "1", 10), + new KeyValueTimestamp<>("1", "12", 10), + new KeyValueTimestamp<>("1", "2", 12), + new KeyValueTimestamp<>("1", "", 12), + new KeyValueTimestamp<>("1", "2", 12L) + ), + proc.processed() + ); + } + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KTableFilterTest.java b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KTableFilterTest.java new file mode 100644 index 0000000..d3ed6b5 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KTableFilterTest.java @@ -0,0 +1,509 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.kstream.internals; + +import org.apache.kafka.common.serialization.IntegerSerializer; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.serialization.StringSerializer; +import org.apache.kafka.streams.KeyValueTimestamp; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.TestInputTopic; +import org.apache.kafka.streams.Topology; +import org.apache.kafka.streams.TopologyTestDriver; +import org.apache.kafka.streams.TopologyTestDriverWrapper; +import org.apache.kafka.streams.TopologyWrapper; +import org.apache.kafka.streams.kstream.Consumed; +import org.apache.kafka.streams.kstream.KTable; +import org.apache.kafka.streams.kstream.Materialized; +import org.apache.kafka.streams.kstream.Predicate; +import org.apache.kafka.streams.processor.internals.InternalTopologyBuilder; +import org.apache.kafka.streams.state.ValueAndTimestamp; +import org.apache.kafka.test.MockApiProcessor; +import org.apache.kafka.test.MockApiProcessorSupplier; +import org.apache.kafka.test.MockMapper; +import org.apache.kafka.test.MockProcessor; +import org.apache.kafka.test.MockProcessorSupplier; +import org.apache.kafka.test.MockReducer; +import org.apache.kafka.test.StreamsTestUtils; +import org.junit.Before; +import org.junit.Test; + +import java.time.Duration; +import java.time.Instant; +import java.util.List; +import java.util.Properties; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.is; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; + +@SuppressWarnings("unchecked") +public class KTableFilterTest { + private final Consumed consumed = Consumed.with(Serdes.String(), Serdes.Integer()); + private final Properties props = StreamsTestUtils.getStreamsConfig(Serdes.String(), Serdes.Integer()); + + @Before + public void setUp() { + // disable caching at the config level + props.setProperty(StreamsConfig.CACHE_MAX_BYTES_BUFFERING_CONFIG, "0"); + } + + private final Predicate predicate = (key, value) -> (value % 2) == 0; + + @SuppressWarnings("deprecation") // Old PAPI. Needs to be migrated. 
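+    /*
+     * Filtering a KTable does not silently drop records: an update whose value fails the predicate
+     * is forwarded with a null value (a tombstone) so that downstream views delete the key. That is
+     * why the expected output of doTestKTable() contains entries such as ("A", null, 10) for
+     * filter() while filterNot() emits the complementary ("A", 1, 10). A minimal sketch of the
+     * tables under test, with the generics spelled out:
+     *
+     *   final KTable<String, Integer> table1 = builder.table("topic1", consumed);
+     *   final KTable<String, Integer> evens  = table1.filter(predicate);      // (value % 2) == 0
+     *   final KTable<String, Integer> odds   = table1.filterNot(predicate);
+     */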
+ private void doTestKTable(final StreamsBuilder builder, + final KTable table2, + final KTable table3, + final String topic) { + final MockProcessorSupplier supplier = new MockProcessorSupplier<>(); + table2.toStream().process(supplier); + table3.toStream().process(supplier); + + try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) { + final TestInputTopic inputTopic = + driver.createInputTopic(topic, new StringSerializer(), new IntegerSerializer(), Instant.ofEpochMilli(0L), Duration.ZERO); + inputTopic.pipeInput("A", 1, 10L); + inputTopic.pipeInput("B", 2, 5L); + inputTopic.pipeInput("C", 3, 8L); + inputTopic.pipeInput("D", 4, 14L); + inputTopic.pipeInput("A", null, 18L); + inputTopic.pipeInput("B", null, 15L); + } + + final List> processors = supplier.capturedProcessors(2); + + processors.get(0).checkAndClearProcessResult(new KeyValueTimestamp<>("A", null, 10), + new KeyValueTimestamp<>("B", 2, 5), + new KeyValueTimestamp<>("C", null, 8), + new KeyValueTimestamp<>("D", 4, 14), + new KeyValueTimestamp<>("A", null, 18), + new KeyValueTimestamp<>("B", null, 15)); + processors.get(1).checkAndClearProcessResult(new KeyValueTimestamp<>("A", 1, 10), + new KeyValueTimestamp<>("B", null, 5), + new KeyValueTimestamp<>("C", 3, 8), + new KeyValueTimestamp<>("D", null, 14), + new KeyValueTimestamp<>("A", null, 18), + new KeyValueTimestamp<>("B", null, 15)); + } + + @Test + public void shouldPassThroughWithoutMaterialization() { + final StreamsBuilder builder = new StreamsBuilder(); + final String topic1 = "topic1"; + + final KTable table1 = builder.table(topic1, consumed); + final KTable table2 = table1.filter(predicate); + final KTable table3 = table1.filterNot(predicate); + + assertNull(table1.queryableStoreName()); + assertNull(table2.queryableStoreName()); + assertNull(table3.queryableStoreName()); + + doTestKTable(builder, table2, table3, topic1); + } + + @Test + public void shouldPassThroughOnMaterialization() { + final StreamsBuilder builder = new StreamsBuilder(); + final String topic1 = "topic1"; + + final KTable table1 = builder.table(topic1, consumed); + final KTable table2 = table1.filter(predicate, Materialized.as("store2")); + final KTable table3 = table1.filterNot(predicate); + + assertNull(table1.queryableStoreName()); + assertEquals("store2", table2.queryableStoreName()); + assertNull(table3.queryableStoreName()); + + doTestKTable(builder, table2, table3, topic1); + } + + private void doTestValueGetter(final StreamsBuilder builder, + final KTableImpl table2, + final KTableImpl table3, + final String topic1) { + + final Topology topology = builder.build(); + + final KTableValueGetterSupplier getterSupplier2 = table2.valueGetterSupplier(); + final KTableValueGetterSupplier getterSupplier3 = table3.valueGetterSupplier(); + + final InternalTopologyBuilder topologyBuilder = TopologyWrapper.getInternalTopologyBuilder(topology); + topologyBuilder.connectProcessorAndStateStores(table2.name, getterSupplier2.storeNames()); + topologyBuilder.connectProcessorAndStateStores(table3.name, getterSupplier3.storeNames()); + + try (final TopologyTestDriverWrapper driver = new TopologyTestDriverWrapper(topology, props)) { + final TestInputTopic inputTopic = + driver.createInputTopic(topic1, new StringSerializer(), new IntegerSerializer(), Instant.ofEpochMilli(0L), Duration.ZERO); + + final KTableValueGetter getter2 = getterSupplier2.get(); + final KTableValueGetter getter3 = getterSupplier3.get(); + + getter2.init(driver.setCurrentNodeForProcessorContext(table2.name)); 
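+
+            // The value getters read the filtered tables' materialized state: a key whose latest
+            // value fails (for getter2) or passes (for getter3) the predicate resolves to null,
+            // otherwise to a ValueAndTimestamp of the value and its record timestamp, e.g.
+            // ValueAndTimestamp.make(1, 5L) for the ("A", 1) record piped below at timestamp 5.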
+ getter3.init(driver.setCurrentNodeForProcessorContext(table3.name)); + + inputTopic.pipeInput("A", 1, 5L); + inputTopic.pipeInput("B", 1, 10L); + inputTopic.pipeInput("C", 1, 15L); + + assertNull(getter2.get("A")); + assertNull(getter2.get("B")); + assertNull(getter2.get("C")); + + assertEquals(ValueAndTimestamp.make(1, 5L), getter3.get("A")); + assertEquals(ValueAndTimestamp.make(1, 10L), getter3.get("B")); + assertEquals(ValueAndTimestamp.make(1, 15L), getter3.get("C")); + + inputTopic.pipeInput("A", 2, 10L); + inputTopic.pipeInput("B", 2, 5L); + + assertEquals(ValueAndTimestamp.make(2, 10L), getter2.get("A")); + assertEquals(ValueAndTimestamp.make(2, 5L), getter2.get("B")); + assertNull(getter2.get("C")); + + assertNull(getter3.get("A")); + assertNull(getter3.get("B")); + assertEquals(ValueAndTimestamp.make(1, 15L), getter3.get("C")); + + inputTopic.pipeInput("A", 3, 15L); + + assertNull(getter2.get("A")); + assertEquals(ValueAndTimestamp.make(2, 5L), getter2.get("B")); + assertNull(getter2.get("C")); + + assertEquals(ValueAndTimestamp.make(3, 15L), getter3.get("A")); + assertNull(getter3.get("B")); + assertEquals(ValueAndTimestamp.make(1, 15L), getter3.get("C")); + + inputTopic.pipeInput("A", null, 10L); + inputTopic.pipeInput("B", null, 20L); + + assertNull(getter2.get("A")); + assertNull(getter2.get("B")); + assertNull(getter2.get("C")); + + assertNull(getter3.get("A")); + assertNull(getter3.get("B")); + assertEquals(ValueAndTimestamp.make(1, 15L), getter3.get("C")); + } + } + + @Test + public void shouldGetValuesOnMaterialization() { + final StreamsBuilder builder = new StreamsBuilder(); + final String topic1 = "topic1"; + + final KTableImpl table1 = + (KTableImpl) builder.table(topic1, consumed); + final KTableImpl table2 = + (KTableImpl) table1.filter(predicate, Materialized.as("store2")); + final KTableImpl table3 = + (KTableImpl) table1.filterNot(predicate, Materialized.as("store3")); + final KTableImpl table4 = + (KTableImpl) table1.filterNot(predicate); + + assertNull(table1.queryableStoreName()); + assertEquals("store2", table2.queryableStoreName()); + assertEquals("store3", table3.queryableStoreName()); + assertNull(table4.queryableStoreName()); + + doTestValueGetter(builder, table2, table3, topic1); + } + + private void doTestNotSendingOldValue(final StreamsBuilder builder, + final KTableImpl table1, + final KTableImpl table2, + final String topic1) { + final MockApiProcessorSupplier supplier = new MockApiProcessorSupplier<>(); + + builder.build().addProcessor("proc1", supplier, table1.name); + builder.build().addProcessor("proc2", supplier, table2.name); + + try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) { + final TestInputTopic inputTopic = + driver.createInputTopic(topic1, new StringSerializer(), new IntegerSerializer(), Instant.ofEpochMilli(0L), Duration.ZERO); + + inputTopic.pipeInput("A", 1, 5L); + inputTopic.pipeInput("B", 1, 10L); + inputTopic.pipeInput("C", 1, 15L); + + final List> processors = supplier.capturedProcessors(2); + + processors.get(0).checkAndClearProcessResult(new KeyValueTimestamp<>("A", new Change<>(1, null), 5), + new KeyValueTimestamp<>("B", new Change<>(1, null), 10), + new KeyValueTimestamp<>("C", new Change<>(1, null), 15)); + processors.get(1).checkAndClearProcessResult(new KeyValueTimestamp<>("A", new Change<>(null, null), 5), + new KeyValueTimestamp<>("B", new Change<>(null, null), 10), + new KeyValueTimestamp<>("C", new Change<>(null, null), 15)); + + inputTopic.pipeInput("A", 2, 15L); + 
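+
+            // Change<>(newValue, oldValue): old values are not being sent in this helper, so every
+            // forwarded Change carries null in the old-value slot; the update of "A" from 1 to 2,
+            // for example, is observed as new Change<>(2, null) rather than new Change<>(2, 1).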
inputTopic.pipeInput("B", 2, 8L); + + processors.get(0).checkAndClearProcessResult(new KeyValueTimestamp<>("A", new Change<>(2, null), 15), + new KeyValueTimestamp<>("B", new Change<>(2, null), 8)); + processors.get(1).checkAndClearProcessResult(new KeyValueTimestamp<>("A", new Change<>(2, null), 15), + new KeyValueTimestamp<>("B", new Change<>(2, null), 8)); + + inputTopic.pipeInput("A", 3, 20L); + + processors.get(0).checkAndClearProcessResult(new KeyValueTimestamp<>("A", new Change<>(3, null), 20)); + processors.get(1).checkAndClearProcessResult(new KeyValueTimestamp<>("A", new Change<>(null, null), 20)); + inputTopic.pipeInput("A", null, 10L); + inputTopic.pipeInput("B", null, 20L); + + processors.get(0).checkAndClearProcessResult(new KeyValueTimestamp<>("A", new Change<>(null, null), 10), + new KeyValueTimestamp<>("B", new Change<>(null, null), 20)); + processors.get(1).checkAndClearProcessResult(new KeyValueTimestamp<>("A", new Change<>(null, null), 10), + new KeyValueTimestamp<>("B", new Change<>(null, null), 20)); + } + } + + + @Test + public void shouldNotSendOldValuesWithoutMaterialization() { + final StreamsBuilder builder = new StreamsBuilder(); + final String topic1 = "topic1"; + + final KTableImpl table1 = + (KTableImpl) builder.table(topic1, consumed); + final KTableImpl table2 = (KTableImpl) table1.filter(predicate); + + doTestNotSendingOldValue(builder, table1, table2, topic1); + } + + @Test + public void shouldNotSendOldValuesOnMaterialization() { + final StreamsBuilder builder = new StreamsBuilder(); + final String topic1 = "topic1"; + + final KTableImpl table1 = + (KTableImpl) builder.table(topic1, consumed); + final KTableImpl table2 = + (KTableImpl) table1.filter(predicate, Materialized.as("store2")); + + doTestNotSendingOldValue(builder, table1, table2, topic1); + } + + @Test + public void shouldNotEnableSendingOldValuesIfNotAlreadyMaterializedAndNotForcedToMaterialize() { + final StreamsBuilder builder = new StreamsBuilder(); + final String topic1 = "topic1"; + + final KTableImpl table1 = + (KTableImpl) builder.table(topic1, consumed); + final KTableImpl table2 = (KTableImpl) table1.filter(predicate); + + table2.enableSendingOldValues(false); + + doTestNotSendingOldValue(builder, table1, table2, topic1); + } + + private void doTestSendingOldValue(final StreamsBuilder builder, + final KTableImpl table1, + final KTableImpl table2, + final String topic1) { + final MockApiProcessorSupplier supplier = new MockApiProcessorSupplier<>(); + final Topology topology = builder.build(); + + topology.addProcessor("proc1", supplier, table1.name); + topology.addProcessor("proc2", supplier, table2.name); + + final boolean parentSendOldVals = table1.sendingOldValueEnabled(); + + try (final TopologyTestDriver driver = new TopologyTestDriver(topology, props)) { + final TestInputTopic inputTopic = + driver.createInputTopic(topic1, new StringSerializer(), new IntegerSerializer(), Instant.ofEpochMilli(0L), Duration.ZERO); + + inputTopic.pipeInput("A", 1, 5L); + inputTopic.pipeInput("B", 1, 10L); + inputTopic.pipeInput("C", 1, 15L); + + final List> processors = supplier.capturedProcessors(2); + final MockApiProcessor table1Output = processors.get(0); + final MockApiProcessor table2Output = processors.get(1); + + table1Output.checkAndClearProcessResult( + new KeyValueTimestamp<>("A", new Change<>(1, null), 5), + new KeyValueTimestamp<>("B", new Change<>(1, null), 10), + new KeyValueTimestamp<>("C", new Change<>(1, null), 15) + ); + table2Output.checkEmptyAndClearProcessResult(); + + 
inputTopic.pipeInput("A", 2, 15L); + inputTopic.pipeInput("B", 2, 8L); + + table1Output.checkAndClearProcessResult( + new KeyValueTimestamp<>("A", new Change<>(2, parentSendOldVals ? 1 : null), 15), + new KeyValueTimestamp<>("B", new Change<>(2, parentSendOldVals ? 1 : null), 8) + ); + table2Output.checkAndClearProcessResult( + new KeyValueTimestamp<>("A", new Change<>(2, null), 15), + new KeyValueTimestamp<>("B", new Change<>(2, null), 8) + ); + + inputTopic.pipeInput("A", 3, 20L); + + table1Output.checkAndClearProcessResult( + new KeyValueTimestamp<>("A", new Change<>(3, parentSendOldVals ? 2 : null), 20) + ); + table2Output.checkAndClearProcessResult( + new KeyValueTimestamp<>("A", new Change<>(null, 2), 20) + ); + + inputTopic.pipeInput("A", null, 10L); + inputTopic.pipeInput("B", null, 20L); + + table1Output.checkAndClearProcessResult( + new KeyValueTimestamp<>("A", new Change<>(null, parentSendOldVals ? 3 : null), 10), + new KeyValueTimestamp<>("B", new Change<>(null, parentSendOldVals ? 2 : null), 20) + ); + table2Output.checkAndClearProcessResult( + new KeyValueTimestamp<>("B", new Change<>(null, 2), 20) + ); + } + } + + @Test + public void shouldEnableSendOldValuesWhenNotMaterializedAlreadyButForcedToMaterialize() { + final StreamsBuilder builder = new StreamsBuilder(); + final String topic1 = "topic1"; + + final KTableImpl table1 = + (KTableImpl) builder.table(topic1, consumed); + final KTableImpl table2 = + (KTableImpl) table1.filter(predicate); + + table2.enableSendingOldValues(true); + + assertThat(table1.sendingOldValueEnabled(), is(true)); + assertThat(table2.sendingOldValueEnabled(), is(true)); + + doTestSendingOldValue(builder, table1, table2, topic1); + } + + @Test + public void shouldEnableSendOldValuesWhenMaterializedAlreadyAndForcedToMaterialize() { + final StreamsBuilder builder = new StreamsBuilder(); + final String topic1 = "topic1"; + + final KTableImpl table1 = + (KTableImpl) builder.table(topic1, consumed); + final KTableImpl table2 = + (KTableImpl) table1.filter(predicate, Materialized.as("store2")); + + table2.enableSendingOldValues(true); + + assertThat(table1.sendingOldValueEnabled(), is(false)); + assertThat(table2.sendingOldValueEnabled(), is(true)); + + doTestSendingOldValue(builder, table1, table2, topic1); + } + + @Test + public void shouldSendOldValuesWhenEnabledOnUpStreamMaterialization() { + final StreamsBuilder builder = new StreamsBuilder(); + final String topic1 = "topic1"; + + final KTableImpl table1 = + (KTableImpl) builder.table(topic1, consumed, Materialized.as("store2")); + final KTableImpl table2 = + (KTableImpl) table1.filter(predicate); + + table2.enableSendingOldValues(false); + + assertThat(table1.sendingOldValueEnabled(), is(true)); + assertThat(table2.sendingOldValueEnabled(), is(true)); + + doTestSendingOldValue(builder, table1, table2, topic1); + } + + private void doTestSkipNullOnMaterialization(final StreamsBuilder builder, + final KTableImpl table1, + final KTableImpl table2, + final String topic1) { + final MockApiProcessorSupplier supplier = new MockApiProcessorSupplier<>(); + final Topology topology = builder.build(); + + topology.addProcessor("proc1", supplier, table1.name); + topology.addProcessor("proc2", supplier, table2.name); + + try (final TopologyTestDriver driver = new TopologyTestDriver(topology, props)) { + final TestInputTopic stringinputTopic = + driver.createInputTopic(topic1, new StringSerializer(), new StringSerializer(), Instant.ofEpochMilli(0L), Duration.ZERO); + + stringinputTopic.pipeInput("A", "reject", 5L); 
+ stringinputTopic.pipeInput("B", "reject", 10L); + stringinputTopic.pipeInput("C", "reject", 20L); + } + + final List> processors = supplier.capturedProcessors(2); + processors.get(0).checkAndClearProcessResult(new KeyValueTimestamp<>("A", new Change<>("reject", null), 5), + new KeyValueTimestamp<>("B", new Change<>("reject", null), 10), + new KeyValueTimestamp<>("C", new Change<>("reject", null), 20)); + processors.get(1).checkEmptyAndClearProcessResult(); + } + + @Test + public void shouldSkipNullToRepartitionWithoutMaterialization() { + // Do not explicitly set enableSendingOldValues. Let a further downstream stateful operator trigger it instead. + final StreamsBuilder builder = new StreamsBuilder(); + + final String topic1 = "topic1"; + + final Consumed consumed = Consumed.with(Serdes.String(), Serdes.String()); + final KTableImpl table1 = + (KTableImpl) builder.table(topic1, consumed); + final KTableImpl table2 = + (KTableImpl) table1.filter((key, value) -> value.equalsIgnoreCase("accept")) + .groupBy(MockMapper.noOpKeyValueMapper()) + .reduce(MockReducer.STRING_ADDER, MockReducer.STRING_REMOVER); + + doTestSkipNullOnMaterialization(builder, table1, table2, topic1); + } + + @Test + public void shouldSkipNullToRepartitionOnMaterialization() { + // Do not explicitly set enableSendingOldValues. Let a further downstream stateful operator trigger it instead. + final StreamsBuilder builder = new StreamsBuilder(); + + final String topic1 = "topic1"; + + final Consumed consumed = Consumed.with(Serdes.String(), Serdes.String()); + final KTableImpl table1 = + (KTableImpl) builder.table(topic1, consumed); + final KTableImpl table2 = + (KTableImpl) table1.filter((key, value) -> value.equalsIgnoreCase("accept"), Materialized.as("store2")) + .groupBy(MockMapper.noOpKeyValueMapper()) + .reduce(MockReducer.STRING_ADDER, MockReducer.STRING_REMOVER, Materialized.as("mock-result")); + + doTestSkipNullOnMaterialization(builder, table1, table2, topic1); + } + + @Test + public void testTypeVariance() { + final Predicate numberKeyPredicate = (key, value) -> false; + + new StreamsBuilder() + .table("empty") + .filter(numberKeyPredicate) + .filterNot(numberKeyPredicate) + .toStream() + .to("nirvana"); + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KTableImplTest.java b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KTableImplTest.java new file mode 100644 index 0000000..462423a --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KTableImplTest.java @@ -0,0 +1,585 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.kstream.internals; + +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.serialization.StringSerializer; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.KeyValueTimestamp; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.TestInputTopic; +import org.apache.kafka.streams.Topology; +import org.apache.kafka.streams.TopologyDescription; +import org.apache.kafka.streams.TopologyDescription.Subtopology; +import org.apache.kafka.streams.TopologyTestDriver; +import org.apache.kafka.streams.TopologyTestDriverWrapper; +import org.apache.kafka.streams.kstream.Consumed; +import org.apache.kafka.streams.kstream.Grouped; +import org.apache.kafka.streams.kstream.KTable; +import org.apache.kafka.streams.kstream.KeyValueMapper; +import org.apache.kafka.streams.kstream.Materialized; +import org.apache.kafka.streams.kstream.Produced; +import org.apache.kafka.streams.kstream.ValueJoiner; +import org.apache.kafka.streams.kstream.ValueMapper; +import org.apache.kafka.streams.kstream.ValueMapperWithKey; +import org.apache.kafka.streams.kstream.ValueTransformerWithKey; +import org.apache.kafka.streams.kstream.ValueTransformerWithKeySupplier; +import org.apache.kafka.streams.processor.ProcessorContext; +import org.apache.kafka.streams.processor.internals.SinkNode; +import org.apache.kafka.streams.processor.internals.SourceNode; +import org.apache.kafka.streams.state.KeyValueStore; +import org.apache.kafka.test.MockAggregator; +import org.apache.kafka.test.MockInitializer; +import org.apache.kafka.test.MockMapper; +import org.apache.kafka.test.MockProcessor; +import org.apache.kafka.test.MockProcessorSupplier; +import org.apache.kafka.test.MockReducer; +import org.apache.kafka.test.MockValueJoiner; +import org.apache.kafka.test.StreamsTestUtils; +import org.junit.Before; +import org.junit.Test; + +import java.lang.reflect.Field; +import java.util.List; +import java.util.Properties; + +import static java.util.Arrays.asList; +import static org.easymock.EasyMock.mock; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.is; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThrows; + +@SuppressWarnings("unchecked") +public class KTableImplTest { + private final Consumed stringConsumed = Consumed.with(Serdes.String(), Serdes.String()); + private final Consumed consumed = Consumed.with(Serdes.String(), Serdes.String()); + private final Produced produced = Produced.with(Serdes.String(), Serdes.String()); + private final Properties props = StreamsTestUtils.getStreamsConfig(Serdes.String(), Serdes.String()); + private final Serde mySerde = new Serdes.StringSerde(); + + private KTable table; + + @Before + public void setUp() { + table = new StreamsBuilder().table("test"); + } + + @SuppressWarnings("deprecation") // Old PAPI. Needs to be migrated. 
+ @Test + public void testKTable() { + final StreamsBuilder builder = new StreamsBuilder(); + + final String topic1 = "topic1"; + final String topic2 = "topic2"; + + final KTable table1 = builder.table(topic1, consumed); + + final MockProcessorSupplier supplier = new MockProcessorSupplier<>(); + table1.toStream().process(supplier); + + final KTable table2 = table1.mapValues(s -> Integer.valueOf(s)); + table2.toStream().process(supplier); + + final KTable table3 = table2.filter((key, value) -> (value % 2) == 0); + table3.toStream().process(supplier); + table1.toStream().to(topic2, produced); + + final KTable table4 = builder.table(topic2, consumed); + table4.toStream().process(supplier); + + try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) { + final TestInputTopic inputTopic = + driver.createInputTopic(topic1, new StringSerializer(), new StringSerializer()); + inputTopic.pipeInput("A", "01", 5L); + inputTopic.pipeInput("B", "02", 100L); + inputTopic.pipeInput("C", "03", 0L); + inputTopic.pipeInput("D", "04", 0L); + inputTopic.pipeInput("A", "05", 10L); + inputTopic.pipeInput("A", "06", 8L); + } + + final List> processors = supplier.capturedProcessors(4); + assertEquals(asList( + new KeyValueTimestamp<>("A", "01", 5), + new KeyValueTimestamp<>("B", "02", 100), + new KeyValueTimestamp<>("C", "03", 0), + new KeyValueTimestamp<>("D", "04", 0), + new KeyValueTimestamp<>("A", "05", 10), + new KeyValueTimestamp<>("A", "06", 8)), + processors.get(0).processed()); + assertEquals(asList( + new KeyValueTimestamp<>("A", 1, 5), + new KeyValueTimestamp<>("B", 2, 100), + new KeyValueTimestamp<>("C", 3, 0), + new KeyValueTimestamp<>("D", 4, 0), + new KeyValueTimestamp<>("A", 5, 10), + new KeyValueTimestamp<>("A", 6, 8)), + processors.get(1).processed()); + assertEquals(asList( + new KeyValueTimestamp<>("A", null, 5), + new KeyValueTimestamp<>("B", 2, 100), + new KeyValueTimestamp<>("C", null, 0), + new KeyValueTimestamp<>("D", 4, 0), + new KeyValueTimestamp<>("A", null, 10), + new KeyValueTimestamp<>("A", 6, 8)), + processors.get(2).processed()); + assertEquals(asList( + new KeyValueTimestamp<>("A", "01", 5), + new KeyValueTimestamp<>("B", "02", 100), + new KeyValueTimestamp<>("C", "03", 0), + new KeyValueTimestamp<>("D", "04", 0), + new KeyValueTimestamp<>("A", "05", 10), + new KeyValueTimestamp<>("A", "06", 8)), + processors.get(3).processed()); + } + + @SuppressWarnings("deprecation") // Old PAPI. Needs to be migrated. 
+ @Test + public void testMaterializedKTable() { + final StreamsBuilder builder = new StreamsBuilder(); + + final String topic1 = "topic1"; + final String topic2 = "topic2"; + + final KTable table1 = builder.table(topic1, consumed, Materialized.as("fred")); + + final MockProcessorSupplier supplier = new MockProcessorSupplier<>(); + table1.toStream().process(supplier); + + final KTable table2 = table1.mapValues(s -> Integer.valueOf(s)); + table2.toStream().process(supplier); + + final KTable table3 = table2.filter((key, value) -> (value % 2) == 0); + table3.toStream().process(supplier); + table1.toStream().to(topic2, produced); + + final KTable table4 = builder.table(topic2, consumed); + table4.toStream().process(supplier); + + try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) { + final TestInputTopic inputTopic = + driver.createInputTopic(topic1, new StringSerializer(), new StringSerializer()); + inputTopic.pipeInput("A", "01", 5L); + inputTopic.pipeInput("B", "02", 100L); + inputTopic.pipeInput("C", "03", 0L); + inputTopic.pipeInput("D", "04", 0L); + inputTopic.pipeInput("A", "05", 10L); + inputTopic.pipeInput("A", "06", 8L); + } + + final List> processors = supplier.capturedProcessors(4); + assertEquals(asList( + new KeyValueTimestamp<>("A", "01", 5), + new KeyValueTimestamp<>("B", "02", 100), + new KeyValueTimestamp<>("C", "03", 0), + new KeyValueTimestamp<>("D", "04", 0), + new KeyValueTimestamp<>("A", "05", 10), + new KeyValueTimestamp<>("A", "06", 8)), + processors.get(0).processed()); + assertEquals(asList( + new KeyValueTimestamp<>("A", 1, 5), + new KeyValueTimestamp<>("B", 2, 100), + new KeyValueTimestamp<>("C", 3, 0), + new KeyValueTimestamp<>("D", 4, 0), + new KeyValueTimestamp<>("A", 5, 10), + new KeyValueTimestamp<>("A", 6, 8)), + processors.get(1).processed()); + assertEquals(asList( + new KeyValueTimestamp<>("B", 2, 100), + new KeyValueTimestamp<>("D", 4, 0), + new KeyValueTimestamp<>("A", 6, 8)), + processors.get(2).processed()); + assertEquals(asList( + new KeyValueTimestamp<>("A", "01", 5), + new KeyValueTimestamp<>("B", "02", 100), + new KeyValueTimestamp<>("C", "03", 0), + new KeyValueTimestamp<>("D", "04", 0), + new KeyValueTimestamp<>("A", "05", 10), + new KeyValueTimestamp<>("A", "06", 8)), + processors.get(3).processed()); + } + + @Test + public void shouldPreserveSerdesForOperators() { + final StreamsBuilder builder = new StreamsBuilder(); + final KTable table1 = builder.table("topic-2", stringConsumed); + final ConsumedInternal consumedInternal = new ConsumedInternal<>(stringConsumed); + + final KeyValueMapper selector = (key, value) -> key; + final ValueMapper mapper = value -> value; + final ValueJoiner joiner = (value1, value2) -> value1; + final ValueTransformerWithKeySupplier valueTransformerWithKeySupplier = + () -> new ValueTransformerWithKey() { + @Override + public void init(final ProcessorContext context) {} + + @Override + public String transform(final String key, final String value) { + return value; + } + + @Override + public void close() {} + }; + + assertEquals( + ((AbstractStream) table1.filter((key, value) -> false)).keySerde(), + consumedInternal.keySerde()); + assertEquals( + ((AbstractStream) table1.filter((key, value) -> false)).valueSerde(), + consumedInternal.valueSerde()); + assertEquals( + ((AbstractStream) table1.filter((key, value) -> false, Materialized.with(mySerde, mySerde))).keySerde(), + mySerde); + assertEquals( + ((AbstractStream) table1.filter((key, value) -> false, Materialized.with(mySerde, 
mySerde))).valueSerde(), + mySerde); + + assertEquals( + ((AbstractStream) table1.filterNot((key, value) -> false)).keySerde(), + consumedInternal.keySerde()); + assertEquals( + ((AbstractStream) table1.filterNot((key, value) -> false)).valueSerde(), + consumedInternal.valueSerde()); + assertEquals( + ((AbstractStream) table1.filterNot((key, value) -> false, Materialized.with(mySerde, mySerde))).keySerde(), + mySerde); + assertEquals( + ((AbstractStream) table1.filterNot((key, value) -> false, Materialized.with(mySerde, mySerde))).valueSerde(), + mySerde); + + assertEquals( + ((AbstractStream) table1.mapValues(mapper)).keySerde(), + consumedInternal.keySerde()); + assertNull(((AbstractStream) table1.mapValues(mapper)).valueSerde()); + assertEquals( + ((AbstractStream) table1.mapValues(mapper, Materialized.with(mySerde, mySerde))).keySerde(), + mySerde); + assertEquals( + ((AbstractStream) table1.mapValues(mapper, Materialized.with(mySerde, mySerde))).valueSerde(), + mySerde); + + assertEquals( + ((AbstractStream) table1.toStream()).keySerde(), + consumedInternal.keySerde()); + assertEquals( + ((AbstractStream) table1.toStream()).valueSerde(), + consumedInternal.valueSerde()); + assertNull(((AbstractStream) table1.toStream(selector)).keySerde()); + assertEquals( + ((AbstractStream) table1.toStream(selector)).valueSerde(), + consumedInternal.valueSerde()); + + assertEquals( + ((AbstractStream) table1.transformValues(valueTransformerWithKeySupplier)).keySerde(), + consumedInternal.keySerde()); + assertNull(((AbstractStream) table1.transformValues(valueTransformerWithKeySupplier)).valueSerde()); + assertEquals( + ((AbstractStream) table1.transformValues(valueTransformerWithKeySupplier, Materialized.with(mySerde, mySerde))).keySerde(), + mySerde); + assertEquals(((AbstractStream) table1.transformValues(valueTransformerWithKeySupplier, Materialized.with(mySerde, mySerde))).valueSerde(), + mySerde); + + assertNull(((AbstractStream) table1.groupBy(KeyValue::new)).keySerde()); + assertNull(((AbstractStream) table1.groupBy(KeyValue::new)).valueSerde()); + assertEquals( + ((AbstractStream) table1.groupBy(KeyValue::new, Grouped.with(mySerde, mySerde))).keySerde(), + mySerde); + assertEquals( + ((AbstractStream) table1.groupBy(KeyValue::new, Grouped.with(mySerde, mySerde))).valueSerde(), + mySerde); + + assertEquals( + ((AbstractStream) table1.join(table1, joiner)).keySerde(), + consumedInternal.keySerde()); + assertNull(((AbstractStream) table1.join(table1, joiner)).valueSerde()); + assertEquals( + ((AbstractStream) table1.join(table1, joiner, Materialized.with(mySerde, mySerde))).keySerde(), + mySerde); + assertEquals( + ((AbstractStream) table1.join(table1, joiner, Materialized.with(mySerde, mySerde))).valueSerde(), + mySerde); + + assertEquals( + ((AbstractStream) table1.leftJoin(table1, joiner)).keySerde(), + consumedInternal.keySerde()); + assertNull(((AbstractStream) table1.leftJoin(table1, joiner)).valueSerde()); + assertEquals( + ((AbstractStream) table1.leftJoin(table1, joiner, Materialized.with(mySerde, mySerde))).keySerde(), + mySerde); + assertEquals( + ((AbstractStream) table1.leftJoin(table1, joiner, Materialized.with(mySerde, mySerde))).valueSerde(), + mySerde); + + assertEquals( + ((AbstractStream) table1.outerJoin(table1, joiner)).keySerde(), + consumedInternal.keySerde()); + assertNull(((AbstractStream) table1.outerJoin(table1, joiner)).valueSerde()); + assertEquals( + ((AbstractStream) table1.outerJoin(table1, joiner, Materialized.with(mySerde, mySerde))).keySerde(), + mySerde); + 
assertEquals( + ((AbstractStream) table1.outerJoin(table1, joiner, Materialized.with(mySerde, mySerde))).valueSerde(), + mySerde); + } + + @Test + public void testStateStoreLazyEval() { + final StreamsBuilder builder = new StreamsBuilder(); + final String topic1 = "topic1"; + final String topic2 = "topic2"; + + final KTableImpl table1 = + (KTableImpl) builder.table(topic1, consumed); + builder.table(topic2, consumed); + + final KTableImpl table1Mapped = + (KTableImpl) table1.mapValues(s -> Integer.valueOf(s)); + table1Mapped.filter((key, value) -> (value % 2) == 0); + + try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) { + assertEquals(0, driver.getAllStateStores().size()); + } + } + + @Test + public void testStateStore() { + final StreamsBuilder builder = new StreamsBuilder(); + final String topic1 = "topic1"; + final String topic2 = "topic2"; + + final KTableImpl table1 = + (KTableImpl) builder.table(topic1, consumed); + final KTableImpl table2 = + (KTableImpl) builder.table(topic2, consumed); + + final KTableImpl table1Mapped = + (KTableImpl) table1.mapValues(s -> Integer.valueOf(s)); + final KTableImpl table1MappedFiltered = + (KTableImpl) table1Mapped.filter((key, value) -> (value % 2) == 0); + table2.join(table1MappedFiltered, (v1, v2) -> v1 + v2); + + try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) { + assertEquals(2, driver.getAllStateStores().size()); + } + } + + @Test + public void shouldNotEnableSendingOldValuesIfNotMaterializedAlreadyAndNotForcedToMaterialize() { + final StreamsBuilder builder = new StreamsBuilder(); + + final KTableImpl table = + (KTableImpl) builder.table("topic1", consumed); + + table.enableSendingOldValues(false); + + assertThat(table.sendingOldValueEnabled(), is(false)); + } + + @Test + public void shouldEnableSendingOldValuesIfNotMaterializedAlreadyButForcedToMaterialize() { + final StreamsBuilder builder = new StreamsBuilder(); + + final KTableImpl table = + (KTableImpl) builder.table("topic1", consumed); + + table.enableSendingOldValues(true); + + assertThat(table.sendingOldValueEnabled(), is(true)); + } + + private void assertTopologyContainsProcessor(final Topology topology, final String processorName) { + for (final Subtopology subtopology: topology.describe().subtopologies()) { + for (final TopologyDescription.Node node: subtopology.nodes()) { + if (node.name().equals(processorName)) { + return; + } + } + } + throw new AssertionError("No processor named '" + processorName + "'" + + "found in the provided Topology:\n" + topology.describe()); + } + + @Test + public void shouldCreateSourceAndSinkNodesForRepartitioningTopic() throws Exception { + final StreamsBuilder builder = new StreamsBuilder(); + final String topic1 = "topic1"; + final String storeName1 = "storeName1"; + + final KTableImpl table1 = + (KTableImpl) builder.table( + topic1, + consumed, + Materialized.>as(storeName1) + .withKeySerde(Serdes.String()) + .withValueSerde(Serdes.String()) + ); + + table1.groupBy(MockMapper.noOpKeyValueMapper()) + .aggregate( + MockInitializer.STRING_INIT, + MockAggregator.TOSTRING_ADDER, + MockAggregator.TOSTRING_REMOVER, + Materialized.as("mock-result1")); + + table1.groupBy(MockMapper.noOpKeyValueMapper()) + .reduce( + MockReducer.STRING_ADDER, + MockReducer.STRING_REMOVER, + Materialized.as("mock-result2")); + + final Topology topology = builder.build(); + try (final TopologyTestDriverWrapper driver = new TopologyTestDriverWrapper(topology, props)) { + + assertEquals(3, 
driver.getAllStateStores().size()); + + assertTopologyContainsProcessor(topology, "KSTREAM-SINK-0000000003"); + assertTopologyContainsProcessor(topology, "KSTREAM-SOURCE-0000000004"); + assertTopologyContainsProcessor(topology, "KSTREAM-SINK-0000000007"); + assertTopologyContainsProcessor(topology, "KSTREAM-SOURCE-0000000008"); + + final Field valSerializerField = ((SinkNode) driver.getProcessor("KSTREAM-SINK-0000000003")) + .getClass() + .getDeclaredField("valSerializer"); + final Field valDeserializerField = ((SourceNode) driver.getProcessor("KSTREAM-SOURCE-0000000004")) + .getClass() + .getDeclaredField("valDeserializer"); + valSerializerField.setAccessible(true); + valDeserializerField.setAccessible(true); + + assertNotNull(((ChangedSerializer) valSerializerField.get(driver.getProcessor("KSTREAM-SINK-0000000003"))).inner()); + assertNotNull(((ChangedDeserializer) valDeserializerField.get(driver.getProcessor("KSTREAM-SOURCE-0000000004"))).inner()); + assertNotNull(((ChangedSerializer) valSerializerField.get(driver.getProcessor("KSTREAM-SINK-0000000007"))).inner()); + assertNotNull(((ChangedDeserializer) valDeserializerField.get(driver.getProcessor("KSTREAM-SOURCE-0000000008"))).inner()); + } + } + + @Test + public void shouldNotAllowNullSelectorOnToStream() { + assertThrows(NullPointerException.class, () -> table.toStream((KeyValueMapper) null)); + } + + @Test + public void shouldNotAllowNullPredicateOnFilter() { + assertThrows(NullPointerException.class, () -> table.filter(null)); + } + + @Test + public void shouldNotAllowNullPredicateOnFilterNot() { + assertThrows(NullPointerException.class, () -> table.filterNot(null)); + } + + @Test + public void shouldNotAllowNullMapperOnMapValues() { + assertThrows(NullPointerException.class, () -> table.mapValues((ValueMapper) null)); + } + + @Test + public void shouldNotAllowNullMapperOnMapValueWithKey() { + assertThrows(NullPointerException.class, () -> table.mapValues((ValueMapperWithKey) null)); + } + + @Test + public void shouldNotAllowNullSelectorOnGroupBy() { + assertThrows(NullPointerException.class, () -> table.groupBy(null)); + } + + @Test + public void shouldNotAllowNullOtherTableOnJoin() { + assertThrows(NullPointerException.class, () -> table.join(null, MockValueJoiner.TOSTRING_JOINER)); + } + + @Test + public void shouldAllowNullStoreInJoin() { + table.join(table, MockValueJoiner.TOSTRING_JOINER); + } + + @Test + public void shouldNotAllowNullJoinerJoin() { + assertThrows(NullPointerException.class, () -> table.join(table, null)); + } + + @Test + public void shouldNotAllowNullOtherTableOnOuterJoin() { + assertThrows(NullPointerException.class, () -> table.outerJoin(null, MockValueJoiner.TOSTRING_JOINER)); + } + + @Test + public void shouldNotAllowNullJoinerOnOuterJoin() { + assertThrows(NullPointerException.class, () -> table.outerJoin(table, null)); + } + + @Test + public void shouldNotAllowNullJoinerOnLeftJoin() { + assertThrows(NullPointerException.class, () -> table.leftJoin(table, null)); + } + + @Test + public void shouldNotAllowNullOtherTableOnLeftJoin() { + assertThrows(NullPointerException.class, () -> table.leftJoin(null, MockValueJoiner.TOSTRING_JOINER)); + } + + @Test + public void shouldThrowNullPointerOnFilterWhenMaterializedIsNull() { + assertThrows(NullPointerException.class, () -> table.filter((key, value) -> false, (Materialized) null)); + } + + @Test + public void shouldThrowNullPointerOnFilterNotWhenMaterializedIsNull() { + assertThrows(NullPointerException.class, () -> table.filterNot((key, value) -> false, 
(Materialized) null)); + } + + @Test + public void shouldThrowNullPointerOnJoinWhenMaterializedIsNull() { + assertThrows(NullPointerException.class, () -> table.join(table, MockValueJoiner.TOSTRING_JOINER, (Materialized) null)); + } + + @Test + public void shouldThrowNullPointerOnLeftJoinWhenMaterializedIsNull() { + assertThrows(NullPointerException.class, () -> table.leftJoin(table, MockValueJoiner.TOSTRING_JOINER, (Materialized) null)); + } + + @Test + public void shouldThrowNullPointerOnOuterJoinWhenMaterializedIsNull() { + assertThrows(NullPointerException.class, () -> table.outerJoin(table, MockValueJoiner.TOSTRING_JOINER, (Materialized) null)); + } + + @Test + public void shouldThrowNullPointerOnTransformValuesWithKeyWhenTransformerSupplierIsNull() { + assertThrows(NullPointerException.class, () -> table.transformValues(null)); + } + + @SuppressWarnings("unchecked") + @Test + public void shouldThrowNullPointerOnTransformValuesWithKeyWhenMaterializedIsNull() { + final ValueTransformerWithKeySupplier valueTransformerSupplier = + mock(ValueTransformerWithKeySupplier.class); + assertThrows(NullPointerException.class, () -> table.transformValues(valueTransformerSupplier, (Materialized) null)); + } + + @Test + public void shouldThrowNullPointerOnTransformValuesWithKeyWhenStoreNamesNull() { + final ValueTransformerWithKeySupplier valueTransformerSupplier = + mock(ValueTransformerWithKeySupplier.class); + assertThrows(NullPointerException.class, () -> table.transformValues(valueTransformerSupplier, (String[]) null)); + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KTableKTableForeignKeyJoinScenarioTest.java b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KTableKTableForeignKeyJoinScenarioTest.java new file mode 100644 index 0000000..a111e82 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KTableKTableForeignKeyJoinScenarioTest.java @@ -0,0 +1,257 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.kstream.internals; + +import org.apache.kafka.common.serialization.IntegerDeserializer; +import org.apache.kafka.common.serialization.IntegerSerializer; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.serialization.StringDeserializer; +import org.apache.kafka.common.serialization.StringSerializer; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.TestInputTopic; +import org.apache.kafka.streams.TestOutputTopic; +import org.apache.kafka.streams.Topology; +import org.apache.kafka.streams.TopologyTestDriver; +import org.apache.kafka.streams.kstream.Consumed; +import org.apache.kafka.streams.kstream.KTable; +import org.apache.kafka.streams.kstream.Materialized; +import org.apache.kafka.streams.kstream.Produced; +import org.apache.kafka.streams.state.KeyValueStore; +import org.apache.kafka.streams.utils.UniqueTopicSerdeScope; +import org.apache.kafka.test.TestUtils; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TestName; + +import java.util.Collections; +import java.util.Map; +import java.util.Properties; + +import static org.apache.kafka.common.utils.Utils.mkEntry; +import static org.apache.kafka.common.utils.Utils.mkMap; +import static org.apache.kafka.common.utils.Utils.mkProperties; +import static org.apache.kafka.common.utils.Utils.mkSet; +import static org.apache.kafka.streams.integration.utils.IntegrationTestUtils.safeUniqueTestName; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.is; + +public class KTableKTableForeignKeyJoinScenarioTest { + + private static final String LEFT_TABLE = "left_table"; + private static final String RIGHT_TABLE = "right_table"; + private static final String OUTPUT = "output-topic"; + + @Rule + public TestName testName = new TestName(); + + @Test + public void shouldWorkWithDefaultSerdes() { + final StreamsBuilder builder = new StreamsBuilder(); + final KTable aTable = builder.table("A"); + final KTable bTable = builder.table("B"); + + final KTable fkJoinResult = aTable.join( + bTable, + value -> Integer.parseInt(value.split("-")[0]), + (aVal, bVal) -> "(" + aVal + "," + bVal + ")", + Materialized.as("asdf") + ); + + final KTable finalJoinResult = aTable.join( + fkJoinResult, + (aVal, fkJoinVal) -> "(" + aVal + "," + fkJoinVal + ")" + ); + + finalJoinResult.toStream().to("output"); + + validateTopologyCanProcessData(builder); + } + + @Test + public void shouldWorkWithDefaultAndConsumedSerdes() { + final StreamsBuilder builder = new StreamsBuilder(); + final KTable aTable = builder.table("A", Consumed.with(Serdes.Integer(), Serdes.String())); + final KTable bTable = builder.table("B"); + + final KTable fkJoinResult = aTable.join( + bTable, + value -> Integer.parseInt(value.split("-")[0]), + (aVal, bVal) -> "(" + aVal + "," + bVal + ")", + Materialized.as("asdf") + ); + + final KTable finalJoinResult = aTable.join( + fkJoinResult, + (aVal, fkJoinVal) -> "(" + aVal + "," + fkJoinVal + ")" + ); + + finalJoinResult.toStream().to("output"); + + validateTopologyCanProcessData(builder); + } + + @Test + public void shouldWorkWithDefaultAndJoinResultSerdes() { + final StreamsBuilder builder = new StreamsBuilder(); + final KTable aTable = builder.table("A"); + final KTable bTable = builder.table("B"); + + final KTable fkJoinResult = aTable.join( + bTable, + value -> Integer.parseInt(value.split("-")[0]), 
+ (aVal, bVal) -> "(" + aVal + "," + bVal + ")", + Materialized.>as("asdf") + .withKeySerde(Serdes.Integer()) + .withValueSerde(Serdes.String()) + ); + + final KTable finalJoinResult = aTable.join( + fkJoinResult, + (aVal, fkJoinVal) -> "(" + aVal + "," + fkJoinVal + ")" + ); + + finalJoinResult.toStream().to("output"); + + validateTopologyCanProcessData(builder); + } + + @Test + public void shouldWorkWithDefaultAndEquiJoinResultSerdes() { + final StreamsBuilder builder = new StreamsBuilder(); + final KTable aTable = builder.table("A"); + final KTable bTable = builder.table("B"); + + final KTable fkJoinResult = aTable.join( + bTable, + value -> Integer.parseInt(value.split("-")[0]), + (aVal, bVal) -> "(" + aVal + "," + bVal + ")", + Materialized.as("asdf") + ); + + final KTable finalJoinResult = aTable.join( + fkJoinResult, + (aVal, fkJoinVal) -> "(" + aVal + "," + fkJoinVal + ")", + Materialized.with(Serdes.Integer(), Serdes.String()) + ); + + finalJoinResult.toStream().to("output"); + + validateTopologyCanProcessData(builder); + } + + @Test + public void shouldWorkWithDefaultAndProducedSerdes() { + final StreamsBuilder builder = new StreamsBuilder(); + final KTable aTable = builder.table("A"); + final KTable bTable = builder.table("B"); + + final KTable fkJoinResult = aTable.join( + bTable, + value -> Integer.parseInt(value.split("-")[0]), + (aVal, bVal) -> "(" + aVal + "," + bVal + ")", + Materialized.as("asdf") + ); + + final KTable finalJoinResult = aTable.join( + fkJoinResult, + (aVal, fkJoinVal) -> "(" + aVal + "," + fkJoinVal + ")" + ); + + finalJoinResult.toStream().to("output", Produced.with(Serdes.Integer(), Serdes.String())); + + validateTopologyCanProcessData(builder); + } + + @Test + public void shouldUseExpectedTopicsWithSerde() { + final String applicationId = "ktable-ktable-joinOnForeignKey"; + final Properties streamsConfig = mkProperties(mkMap( + mkEntry(StreamsConfig.APPLICATION_ID_CONFIG, applicationId), + mkEntry(StreamsConfig.STATE_DIR_CONFIG, TestUtils.tempDirectory().getPath()) + )); + + final UniqueTopicSerdeScope serdeScope = new UniqueTopicSerdeScope(); + final StreamsBuilder builder = new StreamsBuilder(); + + final KTable left = builder.table( + LEFT_TABLE, + Consumed.with(serdeScope.decorateSerde(Serdes.Integer(), streamsConfig, true), + serdeScope.decorateSerde(Serdes.String(), streamsConfig, false)) + ); + final KTable right = builder.table( + RIGHT_TABLE, + Consumed.with(serdeScope.decorateSerde(Serdes.Integer(), streamsConfig, true), + serdeScope.decorateSerde(Serdes.String(), streamsConfig, false)) + ); + + left.join( + right, + value -> Integer.parseInt(value.split("\\|")[1]), + (value1, value2) -> "(" + value1 + "," + value2 + ")", + Materialized.with(null, serdeScope.decorateSerde(Serdes.String(), streamsConfig, false) + )) + .toStream() + .to(OUTPUT); + + + final Topology topology = builder.build(streamsConfig); + try (final TopologyTestDriver driver = new TopologyTestDriver(topology, streamsConfig)) { + final TestInputTopic leftInput = driver.createInputTopic(LEFT_TABLE, new IntegerSerializer(), new StringSerializer()); + final TestInputTopic rightInput = driver.createInputTopic(RIGHT_TABLE, new IntegerSerializer(), new StringSerializer()); + leftInput.pipeInput(2, "lhsValue1|1"); + rightInput.pipeInput(1, "rhsValue1"); + } + // verifying primarily that no extra pseudo-topics were used, but it's nice to also verify the rest of the + // topics our serdes serialize data for + assertThat(serdeScope.registeredTopics(), is(mkSet( + // expected 
pseudo-topics + applicationId + "-KTABLE-FK-JOIN-SUBSCRIPTION-REGISTRATION-0000000006-topic-fk--key", + applicationId + "-KTABLE-FK-JOIN-SUBSCRIPTION-REGISTRATION-0000000006-topic-pk--key", + applicationId + "-KTABLE-FK-JOIN-SUBSCRIPTION-REGISTRATION-0000000006-topic-vh--value", + // internal topics + applicationId + "-KTABLE-FK-JOIN-SUBSCRIPTION-REGISTRATION-0000000006-topic--key", + applicationId + "-KTABLE-FK-JOIN-SUBSCRIPTION-RESPONSE-0000000014-topic--key", + applicationId + "-KTABLE-FK-JOIN-SUBSCRIPTION-RESPONSE-0000000014-topic--value", + applicationId + "-left_table-STATE-STORE-0000000000-changelog--key", + applicationId + "-left_table-STATE-STORE-0000000000-changelog--value", + applicationId + "-right_table-STATE-STORE-0000000003-changelog--key", + applicationId + "-right_table-STATE-STORE-0000000003-changelog--value", + // output topics + "output-topic--key", + "output-topic--value" + ))); + } + + private void validateTopologyCanProcessData(final StreamsBuilder builder) { + final Properties config = new Properties(); + final String safeTestName = safeUniqueTestName(getClass(), testName); + config.setProperty(StreamsConfig.DEFAULT_KEY_SERDE_CLASS_CONFIG, Serdes.IntegerSerde.class.getName()); + config.setProperty(StreamsConfig.DEFAULT_VALUE_SERDE_CLASS_CONFIG, Serdes.StringSerde.class.getName()); + config.setProperty(StreamsConfig.STATE_DIR_CONFIG, TestUtils.tempDirectory().getAbsolutePath()); + try (final TopologyTestDriver topologyTestDriver = new TopologyTestDriver(builder.build(), config)) { + final TestInputTopic aTopic = topologyTestDriver.createInputTopic("A", new IntegerSerializer(), new StringSerializer()); + final TestInputTopic bTopic = topologyTestDriver.createInputTopic("B", new IntegerSerializer(), new StringSerializer()); + final TestOutputTopic output = topologyTestDriver.createOutputTopic("output", new IntegerDeserializer(), new StringDeserializer()); + aTopic.pipeInput(1, "999-alpha"); + bTopic.pipeInput(999, "beta"); + final Map x = output.readKeyValuesToMap(); + assertThat(x, is(Collections.singletonMap(1, "(999-alpha,(999-alpha,beta))"))); + } + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KTableKTableInnerJoinTest.java b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KTableKTableInnerJoinTest.java new file mode 100644 index 0000000..13caa9f --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KTableKTableInnerJoinTest.java @@ -0,0 +1,481 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.kstream.internals; + +import org.apache.kafka.common.header.internals.RecordHeaders; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.streams.KeyValueTimestamp; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.TestInputTopic; +import org.apache.kafka.streams.TestOutputTopic; +import org.apache.kafka.streams.TopologyTestDriver; +import org.apache.kafka.streams.TopologyWrapper; +import org.apache.kafka.streams.kstream.Consumed; +import org.apache.kafka.streams.kstream.KTable; +import org.apache.kafka.streams.kstream.Materialized; +import org.apache.kafka.streams.processor.MockProcessorContext; +import org.apache.kafka.streams.processor.internals.testutil.LogCaptureAppender; +import org.apache.kafka.streams.state.KeyValueStore; +import org.apache.kafka.streams.test.TestRecord; +import org.apache.kafka.test.MockApiProcessor; +import org.apache.kafka.test.MockApiProcessorSupplier; +import org.apache.kafka.test.MockValueJoiner; +import org.apache.kafka.test.StreamsTestUtils; +import org.junit.Test; + +import java.time.Duration; +import java.time.Instant; +import java.util.Arrays; +import java.util.Collection; +import java.util.HashSet; +import java.util.Properties; +import java.util.Set; + +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.CoreMatchers.hasItem; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +@SuppressWarnings("deprecation") // Old PAPI. Needs to be migrated. +public class KTableKTableInnerJoinTest { + private final static KeyValueTimestamp[] EMPTY = new KeyValueTimestamp[0]; + + private final String topic1 = "topic1"; + private final String topic2 = "topic2"; + private final String output = "output"; + private final Consumed consumed = Consumed.with(Serdes.Integer(), Serdes.String()); + private final Materialized> materialized = + Materialized.with(Serdes.Integer(), Serdes.String()); + private final Properties props = StreamsTestUtils.getStreamsConfig(Serdes.Integer(), Serdes.String()); + + @Test + public void testJoin() { + final StreamsBuilder builder = new StreamsBuilder(); + + final int[] expectedKeys = new int[] {0, 1, 2, 3}; + + final KTable table1; + final KTable table2; + final KTable joined; + table1 = builder.table(topic1, consumed); + table2 = builder.table(topic2, consumed); + joined = table1.join(table2, MockValueJoiner.TOSTRING_JOINER); + joined.toStream().to(output); + + doTestJoin(builder, expectedKeys); + } + + @Test + public void testQueryableJoin() { + final StreamsBuilder builder = new StreamsBuilder(); + + final int[] expectedKeys = new int[] {0, 1, 2, 3}; + + final KTable table1; + final KTable table2; + final KTable table3; + table1 = builder.table(topic1, consumed); + table2 = builder.table(topic2, consumed); + table3 = table1.join(table2, MockValueJoiner.TOSTRING_JOINER, materialized); + table3.toStream().to(output); + + doTestJoin(builder, expectedKeys); + } + + @Test + public void testQueryableNotSendingOldValues() { + final StreamsBuilder builder = new StreamsBuilder(); + + final int[] expectedKeys = new int[] {0, 1, 2, 3}; + + final KTable table1; + final KTable table2; + final KTable joined; + final MockApiProcessorSupplier supplier = new MockApiProcessorSupplier<>(); + + table1 = builder.table(topic1, consumed); + table2 = 
builder.table(topic2, consumed); + joined = table1.join(table2, MockValueJoiner.TOSTRING_JOINER, materialized); + builder.build().addProcessor("proc", supplier, ((KTableImpl) joined).name); + + doTestNotSendingOldValues(builder, expectedKeys, table1, table2, supplier, joined); + } + + @Test + public void testNotSendingOldValues() { + final StreamsBuilder builder = new StreamsBuilder(); + + final int[] expectedKeys = new int[] {0, 1, 2, 3}; + + final KTable table1; + final KTable table2; + final KTable joined; + final MockApiProcessorSupplier supplier = new MockApiProcessorSupplier<>(); + + table1 = builder.table(topic1, consumed); + table2 = builder.table(topic2, consumed); + joined = table1.join(table2, MockValueJoiner.TOSTRING_JOINER); + builder.build().addProcessor("proc", supplier, ((KTableImpl) joined).name); + + doTestNotSendingOldValues(builder, expectedKeys, table1, table2, supplier, joined); + } + + @Test + public void testSendingOldValues() { + final StreamsBuilder builder = new StreamsBuilder(); + + final int[] expectedKeys = new int[] {0, 1, 2, 3}; + + final KTable table1; + final KTable table2; + final KTable joined; + final MockApiProcessorSupplier supplier = new MockApiProcessorSupplier<>(); + + table1 = builder.table(topic1, consumed); + table2 = builder.table(topic2, consumed); + joined = table1.join(table2, MockValueJoiner.TOSTRING_JOINER); + + ((KTableImpl) joined).enableSendingOldValues(true); + + builder.build().addProcessor("proc", supplier, ((KTableImpl) joined).name); + + try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) { + final TestInputTopic inputTopic1 = + driver.createInputTopic(topic1, Serdes.Integer().serializer(), Serdes.String().serializer(), Instant.ofEpochMilli(0L), Duration.ZERO); + final TestInputTopic inputTopic2 = + driver.createInputTopic(topic2, Serdes.Integer().serializer(), Serdes.String().serializer(), Instant.ofEpochMilli(0L), Duration.ZERO); + final MockApiProcessor proc = supplier.theCapturedProcessor(); + + assertTrue(((KTableImpl) table1).sendingOldValueEnabled()); + assertTrue(((KTableImpl) table2).sendingOldValueEnabled()); + assertTrue(((KTableImpl) joined).sendingOldValueEnabled()); + + // push two items to the primary stream. the other table is empty + for (int i = 0; i < 2; i++) { + inputTopic1.pipeInput(expectedKeys[i], "X" + expectedKeys[i], 5L + i); + } + // pass tuple with null key, it will be discarded in join process + inputTopic1.pipeInput(null, "SomeVal", 42L); + // left: X0:0 (ts: 5), X1:1 (ts: 6) + // right: + proc.checkAndClearProcessResult(EMPTY); + + // push two items to the other stream. this should produce two items. + for (int i = 0; i < 2; i++) { + inputTopic2.pipeInput(expectedKeys[i], "Y" + expectedKeys[i], 10L * i); + } + // pass tuple with null key, it will be discarded in join process + inputTopic2.pipeInput(null, "AnotherVal", 73L); + // left: X0:0 (ts: 5), X1:1 (ts: 6) + // right: Y0:0 (ts: 0), Y1:1 (ts: 10) + proc.checkAndClearProcessResult(new KeyValueTimestamp<>(0, new Change<>("X0+Y0", null), 5), + new KeyValueTimestamp<>(1, new Change<>("X1+Y1", null), 10)); + // push all four items to the primary stream. this should produce two items. 
+ for (final int expectedKey : expectedKeys) { + inputTopic1.pipeInput(expectedKey, "XX" + expectedKey, 7L); + } + // left: XX0:0 (ts: 7), XX1:1 (ts: 7), XX2:2 (ts: 7), XX3:3 (ts: 7) + // right: Y0:0 (ts: 0), Y1:1 (ts: 10) + proc.checkAndClearProcessResult(new KeyValueTimestamp<>(0, new Change<>("XX0+Y0", "X0+Y0"), 7), + new KeyValueTimestamp<>(1, new Change<>("XX1+Y1", "X1+Y1"), 10)); + // push all items to the other stream. this should produce four items. + for (final int expectedKey : expectedKeys) { + inputTopic2.pipeInput(expectedKey, "YY" + expectedKey, expectedKey * 5L); + } + // left: XX0:0 (ts: 7), XX1:1 (ts: 7), XX2:2 (ts: 7), XX3:3 (ts: 7) + // right: YY0:0 (ts: 0), YY1:1 (ts: 5), YY2:2 (ts: 10), YY3:3 (ts: 15) + proc.checkAndClearProcessResult( + new KeyValueTimestamp<>(0, new Change<>("XX0+YY0", "XX0+Y0"), 7), + new KeyValueTimestamp<>(1, new Change<>("XX1+YY1", "XX1+Y1"), 7), + new KeyValueTimestamp<>(2, new Change<>("XX2+YY2", null), 10), + new KeyValueTimestamp<>(3, new Change<>("XX3+YY3", null), 15)); + + // push all four items to the primary stream. this should produce four items. + for (final int expectedKey : expectedKeys) { + inputTopic1.pipeInput(expectedKey, "XXX" + expectedKey, 6L); + } + // left: XXX0:0 (ts: 6), XXX1:1 (ts: 6), XXX2:2 (ts: 6), XXX3:3 (ts: 6) + // right: YY0:0 (ts: 0), YY1:1 (ts: 5), YY2:2 (ts: 10), YY3:3 (ts: 15) + proc.checkAndClearProcessResult( + new KeyValueTimestamp<>(0, new Change<>("XXX0+YY0", "XX0+YY0"), 6), + new KeyValueTimestamp<>(1, new Change<>("XXX1+YY1", "XX1+YY1"), 6), + new KeyValueTimestamp<>(2, new Change<>("XXX2+YY2", "XX2+YY2"), 10), + new KeyValueTimestamp<>(3, new Change<>("XXX3+YY3", "XX3+YY3"), 15)); + + // push two items with null to the other stream as deletes. this should produce two item. + inputTopic2.pipeInput(expectedKeys[0], null, 5L); + inputTopic2.pipeInput(expectedKeys[1], null, 7L); + // left: XXX0:0 (ts: 6), XXX1:1 (ts: 6), XXX2:2 (ts: 6), XXX3:3 (ts: 6) + // right: YY2:2 (ts: 10), YY3:3 (ts: 15) + proc.checkAndClearProcessResult(new KeyValueTimestamp<>(0, new Change<>(null, "XXX0+YY0"), 6), + new KeyValueTimestamp<>(1, new Change<>(null, "XXX1+YY1"), 7)); + // push all four items to the primary stream. this should produce two items. + for (final int expectedKey : expectedKeys) { + inputTopic1.pipeInput(expectedKey, "XXXX" + expectedKey, 13L); + } + // left: XXXX0:0 (ts: 13), XXXX1:1 (ts: 13), XXXX2:2 (ts: 13), XXXX3:3 (ts: 13) + // right: YY2:2 (ts: 10), YY3:3 (ts: 15) + proc.checkAndClearProcessResult(new KeyValueTimestamp<>(2, new Change<>("XXXX2+YY2", "XXX2+YY2"), 13), + new KeyValueTimestamp<>(3, new Change<>("XXXX3+YY3", "XXX3+YY3"), 15)); + // push four items to the primary stream with null. this should produce two items. 
+ inputTopic1.pipeInput(expectedKeys[0], null, 0L); + inputTopic1.pipeInput(expectedKeys[1], null, 42L); + inputTopic1.pipeInput(expectedKeys[2], null, 5L); + inputTopic1.pipeInput(expectedKeys[3], null, 20L); + // left: + // right: YY2:2 (ts: 10), YY3:3 (ts: 15) + proc.checkAndClearProcessResult(new KeyValueTimestamp<>(2, new Change<>(null, "XXXX2+YY2"), 10), + new KeyValueTimestamp<>(3, new Change<>(null, "XXXX3+YY3"), 20)); + } + } + + @Test + public void shouldLogAndMeterSkippedRecordsDueToNullLeftKey() { + final StreamsBuilder builder = new StreamsBuilder(); + + @SuppressWarnings("unchecked") + final org.apache.kafka.streams.processor.Processor> join = new KTableKTableInnerJoin<>( + (KTableImpl) builder.table("left", Consumed.with(Serdes.String(), Serdes.String())), + (KTableImpl) builder.table("right", Consumed.with(Serdes.String(), Serdes.String())), + null + ).get(); + + final MockProcessorContext context = new MockProcessorContext(props); + context.setRecordMetadata("left", -1, -2, new RecordHeaders(), -3); + join.init(context); + + try (final LogCaptureAppender appender = LogCaptureAppender.createAndRegister(KTableKTableInnerJoin.class)) { + join.process(null, new Change<>("new", "old")); + + assertThat( + appender.getMessages(), + hasItem("Skipping record due to null key. change=[(new<-old)] topic=[left] partition=[-1] offset=[-2]") + ); + } + } + + private void doTestNotSendingOldValues(final StreamsBuilder builder, + final int[] expectedKeys, + final KTable table1, + final KTable table2, + final MockApiProcessorSupplier supplier, + final KTable joined) { + + try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) { + final TestInputTopic inputTopic1 = + driver.createInputTopic(topic1, Serdes.Integer().serializer(), Serdes.String().serializer(), Instant.ofEpochMilli(0L), Duration.ZERO); + final TestInputTopic inputTopic2 = + driver.createInputTopic(topic2, Serdes.Integer().serializer(), Serdes.String().serializer(), Instant.ofEpochMilli(0L), Duration.ZERO); + final MockApiProcessor proc = supplier.theCapturedProcessor(); + + assertFalse(((KTableImpl) table1).sendingOldValueEnabled()); + assertFalse(((KTableImpl) table2).sendingOldValueEnabled()); + assertFalse(((KTableImpl) joined).sendingOldValueEnabled()); + + // push two items to the primary stream. the other table is empty + for (int i = 0; i < 2; i++) { + inputTopic1.pipeInput(expectedKeys[i], "X" + expectedKeys[i], 5L + i); + } + // pass tuple with null key, it will be discarded in join process + inputTopic1.pipeInput(null, "SomeVal", 42L); + // left: X0:0 (ts: 5), X1:1 (ts: 6) + // right: + proc.checkAndClearProcessResult(EMPTY); + + // push two items to the other stream. this should produce two items. + for (int i = 0; i < 2; i++) { + inputTopic2.pipeInput(expectedKeys[i], "Y" + expectedKeys[i], 10L * i); + } + // pass tuple with null key, it will be discarded in join process + inputTopic2.pipeInput(null, "AnotherVal", 73L); + // left: X0:0 (ts: 5), X1:1 (ts: 6) + // right: Y0:0 (ts: 0), Y1:1 (ts: 10) + proc.checkAndClearProcessResult(new KeyValueTimestamp<>(0, new Change<>("X0+Y0", null), 5), + new KeyValueTimestamp<>(1, new Change<>("X1+Y1", null), 10)); + // push all four items to the primary stream. this should produce two items. 
+ for (final int expectedKey : expectedKeys) { + inputTopic1.pipeInput(expectedKey, "XX" + expectedKey, 7L); + } + // left: XX0:0 (ts: 7), XX1:1 (ts: 7), XX2:2 (ts: 7), XX3:3 (ts: 7) + // right: Y0:0 (ts: 0), Y1:1 (ts: 10) + proc.checkAndClearProcessResult(new KeyValueTimestamp<>(0, new Change<>("XX0+Y0", null), 7), + new KeyValueTimestamp<>(1, new Change<>("XX1+Y1", null), 10)); + // push all items to the other stream. this should produce four items. + for (final int expectedKey : expectedKeys) { + inputTopic2.pipeInput(expectedKey, "YY" + expectedKey, expectedKey * 5L); + } + // left: XX0:0 (ts: 7), XX1:1 (ts: 7), XX2:2 (ts: 7), XX3:3 (ts: 7) + // right: YY0:0 (ts: 0), YY1:1 (ts: 5), YY2:2 (ts: 10), YY3:3 (ts: 15) + proc.checkAndClearProcessResult( + new KeyValueTimestamp<>(0, new Change<>("XX0+YY0", null), 7), + new KeyValueTimestamp<>(1, new Change<>("XX1+YY1", null), 7), + new KeyValueTimestamp<>(2, new Change<>("XX2+YY2", null), 10), + new KeyValueTimestamp<>(3, new Change<>("XX3+YY3", null), 15)); + + // push all four items to the primary stream. this should produce four items. + for (final int expectedKey : expectedKeys) { + inputTopic1.pipeInput(expectedKey, "XXX" + expectedKey, 6L); + } + // left: XXX0:0 (ts: 6), XXX1:1 (ts: 6), XXX2:2 (ts: 6), XXX3:3 (ts: 6) + // right: YY0:0 (ts: 0), YY1:1 (ts: 5), YY2:2 (ts: 10), YY3:3 (ts: 15) + proc.checkAndClearProcessResult( + new KeyValueTimestamp<>(0, new Change<>("XXX0+YY0", null), 6), + new KeyValueTimestamp<>(1, new Change<>("XXX1+YY1", null), 6), + new KeyValueTimestamp<>(2, new Change<>("XXX2+YY2", null), 10), + new KeyValueTimestamp<>(3, new Change<>("XXX3+YY3", null), 15)); + + // push two items with null to the other stream as deletes. this should produce two item. + inputTopic2.pipeInput(expectedKeys[0], null, 5L); + inputTopic2.pipeInput(expectedKeys[1], null, 7L); + // left: XXX0:0 (ts: 6), XXX1:1 (ts: 6), XXX2:2 (ts: 6), XXX3:3 (ts: 6) + // right: YY2:2 (ts: 10), YY3:3 (ts: 15) + proc.checkAndClearProcessResult(new KeyValueTimestamp<>(0, new Change<>(null, null), 6), + new KeyValueTimestamp<>(1, new Change<>(null, null), 7)); + // push all four items to the primary stream. this should produce two items. + for (final int expectedKey : expectedKeys) { + inputTopic1.pipeInput(expectedKey, "XXXX" + expectedKey, 13L); + } + // left: XXXX0:0 (ts: 13), XXXX1:1 (ts: 13), XXXX2:2 (ts: 13), XXXX3:3 (ts: 13) + // right: YY2:2 (ts: 10), YY3:3 (ts: 15) + proc.checkAndClearProcessResult(new KeyValueTimestamp<>(2, new Change<>("XXXX2+YY2", null), 13), + new KeyValueTimestamp<>(3, new Change<>("XXXX3+YY3", null), 15)); + // push four items to the primary stream with null. this should produce two items. 
+ inputTopic1.pipeInput(expectedKeys[0], null, 0L); + inputTopic1.pipeInput(expectedKeys[1], null, 42L); + inputTopic1.pipeInput(expectedKeys[2], null, 5L); + inputTopic1.pipeInput(expectedKeys[3], null, 20L); + // left: + // right: YY2:2 (ts: 10), YY3:3 (ts: 15) + proc.checkAndClearProcessResult(new KeyValueTimestamp<>(2, new Change<>(null, null), 10), + new KeyValueTimestamp<>(3, new Change<>(null, null), 20)); + } + } + + private void doTestJoin(final StreamsBuilder builder, final int[] expectedKeys) { + final Collection> copartitionGroups = + TopologyWrapper.getInternalTopologyBuilder(builder.build()).copartitionGroups(); + + assertEquals(1, copartitionGroups.size()); + assertEquals(new HashSet<>(Arrays.asList(topic1, topic2)), copartitionGroups.iterator().next()); + + try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) { + final TestInputTopic inputTopic1 = + driver.createInputTopic(topic1, Serdes.Integer().serializer(), Serdes.String().serializer(), Instant.ofEpochMilli(0L), Duration.ZERO); + final TestInputTopic inputTopic2 = + driver.createInputTopic(topic2, Serdes.Integer().serializer(), Serdes.String().serializer(), Instant.ofEpochMilli(0L), Duration.ZERO); + final TestOutputTopic outputTopic = + driver.createOutputTopic(output, Serdes.Integer().deserializer(), Serdes.String().deserializer()); + + // push two items to the primary stream. the other table is empty + for (int i = 0; i < 2; i++) { + inputTopic1.pipeInput(expectedKeys[i], "X" + expectedKeys[i], 5L + i); + } + // pass tuple with null key, it will be discarded in join process + inputTopic1.pipeInput(null, "SomeVal", 42L); + // left: X0:0 (ts: 5), X1:1 (ts: 6) + // right: + assertTrue(outputTopic.isEmpty()); + + // push two items to the other stream. this should produce two items. + for (int i = 0; i < 2; i++) { + inputTopic2.pipeInput(expectedKeys[i], "Y" + expectedKeys[i], 10L * i); + } + // pass tuple with null key, it will be discarded in join process + inputTopic2.pipeInput(null, "AnotherVal", 73L); + // left: X0:0 (ts: 5), X1:1 (ts: 6) + // right: Y0:0 (ts: 0), Y1:1 (ts: 10) + assertOutputKeyValueTimestamp(outputTopic, 0, "X0+Y0", 5L); + assertOutputKeyValueTimestamp(outputTopic, 1, "X1+Y1", 10L); + assertTrue(outputTopic.isEmpty()); + + // push all four items to the primary stream. this should produce two items. + for (final int expectedKey : expectedKeys) { + inputTopic1.pipeInput(expectedKey, "XX" + expectedKey, 7L); + } + // left: XX0:0 (ts: 7), XX1:1 (ts: 7), XX2:2 (ts: 7), XX3:3 (ts: 7) + // right: Y0:0 (ts: 0), Y1:1 (ts: 10) + assertOutputKeyValueTimestamp(outputTopic, 0, "XX0+Y0", 7L); + assertOutputKeyValueTimestamp(outputTopic, 1, "XX1+Y1", 10L); + assertTrue(outputTopic.isEmpty()); + + // push all items to the other stream. this should produce four items. + for (final int expectedKey : expectedKeys) { + inputTopic2.pipeInput(expectedKey, "YY" + expectedKey, expectedKey * 5L); + } + // left: XX0:0 (ts: 7), XX1:1 (ts: 7), XX2:2 (ts: 7), XX3:3 (ts: 7) + // right: YY0:0 (ts: 0), YY1:1 (ts: 5), YY2:2 (ts: 10), YY3:3 (ts: 15) + assertOutputKeyValueTimestamp(outputTopic, 0, "XX0+YY0", 7L); + assertOutputKeyValueTimestamp(outputTopic, 1, "XX1+YY1", 7L); + assertOutputKeyValueTimestamp(outputTopic, 2, "XX2+YY2", 10L); + assertOutputKeyValueTimestamp(outputTopic, 3, "XX3+YY3", 15L); + assertTrue(outputTopic.isEmpty()); + + // push all four items to the primary stream. this should produce four items. 
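+            // note: each joined result's timestamp below is the larger of the two input record timestamps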
+            for (final int expectedKey : expectedKeys) {
+                inputTopic1.pipeInput(expectedKey, "XXX" + expectedKey, 6L);
+            }
+            // left: XXX0:0 (ts: 6), XXX1:1 (ts: 6), XXX2:2 (ts: 6), XXX3:3 (ts: 6)
+            // right: YY0:0 (ts: 0), YY1:1 (ts: 5), YY2:2 (ts: 10), YY3:3 (ts: 15)
+            assertOutputKeyValueTimestamp(outputTopic, 0, "XXX0+YY0", 6L);
+            assertOutputKeyValueTimestamp(outputTopic, 1, "XXX1+YY1", 6L);
+            assertOutputKeyValueTimestamp(outputTopic, 2, "XXX2+YY2", 10L);
+            assertOutputKeyValueTimestamp(outputTopic, 3, "XXX3+YY3", 15L);
+            assertTrue(outputTopic.isEmpty());
+
+            // push two items with null to the other stream as deletes. this should produce two items.
+            inputTopic2.pipeInput(expectedKeys[0], null, 5L);
+            inputTopic2.pipeInput(expectedKeys[1], null, 7L);
+            // left: XXX0:0 (ts: 6), XXX1:1 (ts: 6), XXX2:2 (ts: 6), XXX3:3 (ts: 6)
+            // right: YY2:2 (ts: 10), YY3:3 (ts: 15)
+            assertOutputKeyValueTimestamp(outputTopic, 0, null, 6L);
+            assertOutputKeyValueTimestamp(outputTopic, 1, null, 7L);
+            assertTrue(outputTopic.isEmpty());
+
+            // push all four items to the primary stream. this should produce two items.
+            for (final int expectedKey : expectedKeys) {
+                inputTopic1.pipeInput(expectedKey, "XXXX" + expectedKey, 13L);
+            }
+            // left: XXXX0:0 (ts: 13), XXXX1:1 (ts: 13), XXXX2:2 (ts: 13), XXXX3:3 (ts: 13)
+            // right: YY2:2 (ts: 10), YY3:3 (ts: 15)
+            assertOutputKeyValueTimestamp(outputTopic, 2, "XXXX2+YY2", 13L);
+            assertOutputKeyValueTimestamp(outputTopic, 3, "XXXX3+YY3", 15L);
+            assertTrue(outputTopic.isEmpty());
+
+            // push four items to the primary stream with null. this should produce two items.
+            inputTopic1.pipeInput(expectedKeys[0], null, 0L);
+            inputTopic1.pipeInput(expectedKeys[1], null, 42L);
+            inputTopic1.pipeInput(expectedKeys[2], null, 5L);
+            inputTopic1.pipeInput(expectedKeys[3], null, 20L);
+            // left:
+            // right: YY2:2 (ts: 10), YY3:3 (ts: 15)
+            assertOutputKeyValueTimestamp(outputTopic, 2, null, 10L);
+            assertOutputKeyValueTimestamp(outputTopic, 3, null, 20L);
+            assertTrue(outputTopic.isEmpty());
+        }
+    }
+
+    private void assertOutputKeyValueTimestamp(final TestOutputTopic outputTopic,
+                                               final Integer expectedKey,
+                                               final String expectedValue,
+                                               final long expectedTimestamp) {
+        assertThat(outputTopic.readRecord(), equalTo(new TestRecord<>(expectedKey, expectedValue, null, expectedTimestamp)));
+    }
+
+}
diff --git a/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KTableKTableLeftJoinTest.java b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KTableKTableLeftJoinTest.java
new file mode 100644
index 0000000..0b14d8b
--- /dev/null
+++ b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KTableKTableLeftJoinTest.java
@@ -0,0 +1,546 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +package org.apache.kafka.streams.kstream.internals; + +import org.apache.kafka.common.header.internals.RecordHeaders; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.KeyValueTimestamp; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.TestInputTopic; +import org.apache.kafka.streams.TestOutputTopic; +import org.apache.kafka.streams.Topology; +import org.apache.kafka.streams.TopologyTestDriver; +import org.apache.kafka.streams.TopologyTestDriverWrapper; +import org.apache.kafka.streams.TopologyWrapper; +import org.apache.kafka.streams.kstream.Consumed; +import org.apache.kafka.streams.kstream.Grouped; +import org.apache.kafka.streams.kstream.KTable; +import org.apache.kafka.streams.kstream.Materialized; +import org.apache.kafka.streams.kstream.ValueMapper; +import org.apache.kafka.streams.processor.MockProcessorContext; +import org.apache.kafka.streams.processor.internals.testutil.LogCaptureAppender; +import org.apache.kafka.streams.state.Stores; +import org.apache.kafka.streams.test.TestRecord; +import org.apache.kafka.test.MockApiProcessor; +import org.apache.kafka.test.MockApiProcessorSupplier; +import org.apache.kafka.test.MockReducer; +import org.apache.kafka.test.MockValueJoiner; +import org.apache.kafka.test.StreamsTestUtils; +import org.junit.Test; + +import java.time.Duration; +import java.time.Instant; +import java.util.Arrays; +import java.util.Collection; +import java.util.HashSet; +import java.util.Locale; +import java.util.Properties; +import java.util.Random; +import java.util.Set; + +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.CoreMatchers.hasItem; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.is; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +@SuppressWarnings("deprecation") // Old PAPI. Needs to be migrated. 
+public class KTableKTableLeftJoinTest { + private final String topic1 = "topic1"; + private final String topic2 = "topic2"; + private final String output = "output"; + private final Consumed consumed = Consumed.with(Serdes.Integer(), Serdes.String()); + private final Properties props = StreamsTestUtils.getStreamsConfig(Serdes.Integer(), Serdes.String()); + + @Test + public void testJoin() { + final StreamsBuilder builder = new StreamsBuilder(); + + final int[] expectedKeys = new int[] {0, 1, 2, 3}; + + final KTable table1 = builder.table(topic1, consumed); + final KTable table2 = builder.table(topic2, consumed); + final KTable joined = table1.leftJoin(table2, MockValueJoiner.TOSTRING_JOINER); + joined.toStream().to(output); + + final Collection> copartitionGroups = + TopologyWrapper.getInternalTopologyBuilder(builder.build()).copartitionGroups(); + + assertEquals(1, copartitionGroups.size()); + assertEquals(new HashSet<>(Arrays.asList(topic1, topic2)), copartitionGroups.iterator().next()); + + try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) { + final TestInputTopic inputTopic1 = + driver.createInputTopic(topic1, Serdes.Integer().serializer(), Serdes.String().serializer(), Instant.ofEpochMilli(0L), Duration.ZERO); + final TestInputTopic inputTopic2 = + driver.createInputTopic(topic2, Serdes.Integer().serializer(), Serdes.String().serializer(), Instant.ofEpochMilli(0L), Duration.ZERO); + final TestOutputTopic outputTopic = + driver.createOutputTopic(output, Serdes.Integer().deserializer(), Serdes.String().deserializer()); + + // push two items to the primary stream. the other table is empty + for (int i = 0; i < 2; i++) { + inputTopic1.pipeInput(expectedKeys[i], "X" + expectedKeys[i], 5L + i); + } + // pass tuple with null key, it will be discarded in join process + inputTopic1.pipeInput(null, "SomeVal", 42L); + // left: X0:0 (ts: 5), X1:1 (ts: 6) + // right: + assertOutputKeyValueTimestamp(outputTopic, 0, "X0+null", 5L); + assertOutputKeyValueTimestamp(outputTopic, 1, "X1+null", 6L); + assertTrue(outputTopic.isEmpty()); + + // push two items to the other stream. this should produce two items. + for (int i = 0; i < 2; i++) { + inputTopic2.pipeInput(expectedKeys[i], "Y" + expectedKeys[i], 10L * i); + } + // pass tuple with null key, it will be discarded in join process + inputTopic2.pipeInput(null, "AnotherVal", 73L); + // left: X0:0 (ts: 5), X1:1 (ts: 6) + // right: Y0:0 (ts: 0), Y1:1 (ts: 10) + assertOutputKeyValueTimestamp(outputTopic, 0, "X0+Y0", 5L); + assertOutputKeyValueTimestamp(outputTopic, 1, "X1+Y1", 10L); + assertTrue(outputTopic.isEmpty()); + + // push all four items to the primary stream. this should produce four items. + for (final int expectedKey : expectedKeys) { + inputTopic1.pipeInput(expectedKey, "XX" + expectedKey, 7L); + } + // left: XX0:0 (ts: 7), XX1:1 (ts: 7), XX2:2 (ts: 7), XX3:3 (ts: 7) + // right: Y0:0 (ts: 0), Y1:1 (ts: 10) + assertOutputKeyValueTimestamp(outputTopic, 0, "XX0+Y0", 7L); + assertOutputKeyValueTimestamp(outputTopic, 1, "XX1+Y1", 10L); + assertOutputKeyValueTimestamp(outputTopic, 2, "XX2+null", 7L); + assertOutputKeyValueTimestamp(outputTopic, 3, "XX3+null", 7L); + assertTrue(outputTopic.isEmpty()); + + // push all items to the other stream. this should produce four items. 
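+            // note: updating the right side re-evaluates the join for every left-table row with that key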
+            for (final int expectedKey : expectedKeys) {
+                inputTopic2.pipeInput(expectedKey, "YY" + expectedKey, expectedKey * 5L);
+            }
+            // left: XX0:0 (ts: 7), XX1:1 (ts: 7), XX2:2 (ts: 7), XX3:3 (ts: 7)
+            // right: YY0:0 (ts: 0), YY1:1 (ts: 5), YY2:2 (ts: 10), YY3:3 (ts: 15)
+            assertOutputKeyValueTimestamp(outputTopic, 0, "XX0+YY0", 7L);
+            assertOutputKeyValueTimestamp(outputTopic, 1, "XX1+YY1", 7L);
+            assertOutputKeyValueTimestamp(outputTopic, 2, "XX2+YY2", 10L);
+            assertOutputKeyValueTimestamp(outputTopic, 3, "XX3+YY3", 15L);
+            assertTrue(outputTopic.isEmpty());
+
+            // push all four items to the primary stream. this should produce four items.
+            for (final int expectedKey : expectedKeys) {
+                inputTopic1.pipeInput(expectedKey, "XXX" + expectedKey, 6L);
+            }
+            // left: XXX0:0 (ts: 6), XXX1:1 (ts: 6), XXX2:2 (ts: 6), XXX3:3 (ts: 6)
+            // right: YY0:0 (ts: 0), YY1:1 (ts: 5), YY2:2 (ts: 10), YY3:3 (ts: 15)
+            assertOutputKeyValueTimestamp(outputTopic, 0, "XXX0+YY0", 6L);
+            assertOutputKeyValueTimestamp(outputTopic, 1, "XXX1+YY1", 6L);
+            assertOutputKeyValueTimestamp(outputTopic, 2, "XXX2+YY2", 10L);
+            assertOutputKeyValueTimestamp(outputTopic, 3, "XXX3+YY3", 15L);
+            assertTrue(outputTopic.isEmpty());
+
+            // push two items with null to the other stream as deletes. this should produce two items.
+            inputTopic2.pipeInput(expectedKeys[0], null, 5L);
+            inputTopic2.pipeInput(expectedKeys[1], null, 7L);
+            // left: XXX0:0 (ts: 6), XXX1:1 (ts: 6), XXX2:2 (ts: 6), XXX3:3 (ts: 6)
+            // right: YY2:2 (ts: 10), YY3:3 (ts: 15)
+            assertOutputKeyValueTimestamp(outputTopic, 0, "XXX0+null", 6L);
+            assertOutputKeyValueTimestamp(outputTopic, 1, "XXX1+null", 7L);
+            assertTrue(outputTopic.isEmpty());
+
+            // push all four items to the primary stream. this should produce four items.
+            for (final int expectedKey : expectedKeys) {
+                inputTopic1.pipeInput(expectedKey, "XXXX" + expectedKey, 13L);
+            }
+            // left: XXXX0:0 (ts: 13), XXXX1:1 (ts: 13), XXXX2:2 (ts: 13), XXXX3:3 (ts: 13)
+            // right: YY2:2 (ts: 10), YY3:3 (ts: 15)
+            assertOutputKeyValueTimestamp(outputTopic, 0, "XXXX0+null", 13L);
+            assertOutputKeyValueTimestamp(outputTopic, 1, "XXXX1+null", 13L);
+            assertOutputKeyValueTimestamp(outputTopic, 2, "XXXX2+YY2", 13L);
+            assertOutputKeyValueTimestamp(outputTopic, 3, "XXXX3+YY3", 15L);
+            assertTrue(outputTopic.isEmpty());
+
+            // push four items to the primary stream with null. this should produce four items.
+ inputTopic1.pipeInput(expectedKeys[0], null, 0L); + inputTopic1.pipeInput(expectedKeys[1], null, 42L); + inputTopic1.pipeInput(expectedKeys[2], null, 5L); + inputTopic1.pipeInput(expectedKeys[3], null, 20L); + // left: + // right: YY2:2 (ts: 10), YY3:3 (ts: 15) + assertOutputKeyValueTimestamp(outputTopic, 0, null, 0L); + assertOutputKeyValueTimestamp(outputTopic, 1, null, 42L); + assertOutputKeyValueTimestamp(outputTopic, 2, null, 10L); + assertOutputKeyValueTimestamp(outputTopic, 3, null, 20L); + assertTrue(outputTopic.isEmpty()); + } + } + + @Test + public void testNotSendingOldValue() { + final StreamsBuilder builder = new StreamsBuilder(); + + final int[] expectedKeys = new int[] {0, 1, 2, 3}; + + final KTable table1; + final KTable table2; + final KTable joined; + final MockApiProcessorSupplier supplier = new MockApiProcessorSupplier<>(); + + table1 = builder.table(topic1, consumed); + table2 = builder.table(topic2, consumed); + joined = table1.leftJoin(table2, MockValueJoiner.TOSTRING_JOINER); + + final Topology topology = builder.build().addProcessor("proc", supplier, ((KTableImpl) joined).name); + + try (final TopologyTestDriver driver = new TopologyTestDriver(topology, props)) { + final TestInputTopic inputTopic1 = + driver.createInputTopic(topic1, Serdes.Integer().serializer(), Serdes.String().serializer(), Instant.ofEpochMilli(0L), Duration.ZERO); + final TestInputTopic inputTopic2 = + driver.createInputTopic(topic2, Serdes.Integer().serializer(), Serdes.String().serializer(), Instant.ofEpochMilli(0L), Duration.ZERO); + final MockApiProcessor proc = supplier.theCapturedProcessor(); + + assertTrue(((KTableImpl) table1).sendingOldValueEnabled()); + assertFalse(((KTableImpl) table2).sendingOldValueEnabled()); + assertFalse(((KTableImpl) joined).sendingOldValueEnabled()); + + // push two items to the primary stream. the other table is empty + for (int i = 0; i < 2; i++) { + inputTopic1.pipeInput(expectedKeys[i], "X" + expectedKeys[i], 5L + i); + } + // pass tuple with null key, it will be discarded in join process + inputTopic1.pipeInput(null, "SomeVal", 42L); + // left: X0:0 (ts: 5), X1:1 (ts: 6) + // right: + proc.checkAndClearProcessResult(new KeyValueTimestamp<>(0, new Change<>("X0+null", null), 5), + new KeyValueTimestamp<>(1, new Change<>("X1+null", null), 6)); + + // push two items to the other stream. this should produce two items. + for (int i = 0; i < 2; i++) { + inputTopic2.pipeInput(expectedKeys[i], "Y" + expectedKeys[i], 10L * i); + } + // pass tuple with null key, it will be discarded in join process + inputTopic2.pipeInput(null, "AnotherVal", 73L); + // left: X0:0 (ts: 5), X1:1 (ts: 6) + // right: Y0:0 (ts: 0), Y1:1 (ts: 10) + proc.checkAndClearProcessResult(new KeyValueTimestamp<>(0, new Change<>("X0+Y0", null), 5), + new KeyValueTimestamp<>(1, new Change<>("X1+Y1", null), 10)); + + // push all four items to the primary stream. this should produce four items. + for (final int expectedKey : expectedKeys) { + inputTopic1.pipeInput(expectedKey, "XX" + expectedKey, 7L); + } + // left: XX0:0 (ts: 7), XX1:1 (ts: 7), XX2:2 (ts: 7), XX3:3 (ts: 7) + // right: Y0:0 (ts: 0), Y1:1 (ts: 10) + proc.checkAndClearProcessResult(new KeyValueTimestamp<>(0, new Change<>("XX0+Y0", null), 7), + new KeyValueTimestamp<>(1, new Change<>("XX1+Y1", null), 10), + new KeyValueTimestamp<>(2, new Change<>("XX2+null", null), 7), + new KeyValueTimestamp<>(3, new Change<>("XX3+null", null), 7)); + + // push all items to the other stream. this should produce four items. 
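+            // note: sending old values is not enabled on the join result here, so the old value in each emitted Change stays null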
+ for (final int expectedKey : expectedKeys) { + inputTopic2.pipeInput(expectedKey, "YY" + expectedKey, expectedKey * 5L); + } + // left: XX0:0 (ts: 7), XX1:1 (ts: 7), XX2:2 (ts: 7), XX3:3 (ts: 7) + // right: YY0:0 (ts: 0), YY1:1 (ts: 5), YY2:2 (ts: 10), YY3:3 (ts: 15) + proc.checkAndClearProcessResult(new KeyValueTimestamp<>(0, new Change<>("XX0+YY0", null), 7), + new KeyValueTimestamp<>(1, new Change<>("XX1+YY1", null), 7), + new KeyValueTimestamp<>(2, new Change<>("XX2+YY2", null), 10), + new KeyValueTimestamp<>(3, new Change<>("XX3+YY3", null), 15)); + + // push all four items to the primary stream. this should produce four items. + for (final int expectedKey : expectedKeys) { + inputTopic1.pipeInput(expectedKey, "XXX" + expectedKey, 6L); + } + // left: XXX0:0 (ts: 6), XXX1:1 (ts: 6), XXX2:2 (ts: 6), XXX3:3 (ts: 6) + // right: YY0:0 (ts: 0), YY1:1 (ts: 5), YY2:2 (ts: 10), YY3:3 (ts: 15) + proc.checkAndClearProcessResult(new KeyValueTimestamp<>(0, new Change<>("XXX0+YY0", null), 6), + new KeyValueTimestamp<>(1, new Change<>("XXX1+YY1", null), 6), + new KeyValueTimestamp<>(2, new Change<>("XXX2+YY2", null), 10), + new KeyValueTimestamp<>(3, new Change<>("XXX3+YY3", null), 15)); + + // push two items with null to the other stream as deletes. this should produce two item. + inputTopic2.pipeInput(expectedKeys[0], null, 5L); + inputTopic2.pipeInput(expectedKeys[1], null, 7L); + // left: XXX0:0 (ts: 6), XXX1:1 (ts: 6), XXX2:2 (ts: 6), XXX3:3 (ts: 6) + // right: YY2:2 (ts: 10), YY3:3 (ts: 15) + proc.checkAndClearProcessResult(new KeyValueTimestamp<>(0, new Change<>("XXX0+null", null), 6), + new KeyValueTimestamp<>(1, new Change<>("XXX1+null", null), 7)); + + // push all four items to the primary stream. this should produce four items. + for (final int expectedKey : expectedKeys) { + inputTopic1.pipeInput(expectedKey, "XXXX" + expectedKey, 13L); + } + // left: XXXX0:0 (ts: 13), XXXX1:1 (ts: 13), XXXX2:2 (ts: 13), XXXX3:3 (ts: 13) + // right: YY2:2 (ts: 10), YY3:3 (ts: 15) + proc.checkAndClearProcessResult(new KeyValueTimestamp<>(0, new Change<>("XXXX0+null", null), 13), + new KeyValueTimestamp<>(1, new Change<>("XXXX1+null", null), 13), + new KeyValueTimestamp<>(2, new Change<>("XXXX2+YY2", null), 13), + new KeyValueTimestamp<>(3, new Change<>("XXXX3+YY3", null), 15)); + + // push four items to the primary stream with null. this should produce four items. 
+ inputTopic1.pipeInput(expectedKeys[0], null, 0L); + inputTopic1.pipeInput(expectedKeys[1], null, 42L); + inputTopic1.pipeInput(expectedKeys[2], null, 5L); + inputTopic1.pipeInput(expectedKeys[3], null, 20L); + // left: + // right: YY2:2 (ts: 10), YY3:3 (ts: 15) + proc.checkAndClearProcessResult(new KeyValueTimestamp<>(0, new Change<>(null, null), 0), + new KeyValueTimestamp<>(1, new Change<>(null, null), 42), + new KeyValueTimestamp<>(2, new Change<>(null, null), 10), + new KeyValueTimestamp<>(3, new Change<>(null, null), 20)); + } + } + + @Test + public void testSendingOldValue() { + final StreamsBuilder builder = new StreamsBuilder(); + + final int[] expectedKeys = new int[] {0, 1, 2, 3}; + + final KTable table1; + final KTable table2; + final KTable joined; + final MockApiProcessorSupplier supplier = new MockApiProcessorSupplier<>(); + + table1 = builder.table(topic1, consumed); + table2 = builder.table(topic2, consumed); + joined = table1.leftJoin(table2, MockValueJoiner.TOSTRING_JOINER); + + ((KTableImpl) joined).enableSendingOldValues(true); + + assertThat(((KTableImpl) table1).sendingOldValueEnabled(), is(true)); + assertThat(((KTableImpl) table2).sendingOldValueEnabled(), is(true)); + assertThat(((KTableImpl) joined).sendingOldValueEnabled(), is(true)); + + final Topology topology = builder.build().addProcessor("proc", supplier, ((KTableImpl) joined).name); + + try (final TopologyTestDriver driver = new TopologyTestDriverWrapper(topology, props)) { + final TestInputTopic inputTopic1 = + driver.createInputTopic(topic1, Serdes.Integer().serializer(), Serdes.String().serializer(), Instant.ofEpochMilli(0L), Duration.ZERO); + final TestInputTopic inputTopic2 = + driver.createInputTopic(topic2, Serdes.Integer().serializer(), Serdes.String().serializer(), Instant.ofEpochMilli(0L), Duration.ZERO); + final MockApiProcessor proc = supplier.theCapturedProcessor(); + + assertTrue(((KTableImpl) table1).sendingOldValueEnabled()); + assertTrue(((KTableImpl) table2).sendingOldValueEnabled()); + assertTrue(((KTableImpl) joined).sendingOldValueEnabled()); + + // push two items to the primary stream. the other table is empty + for (int i = 0; i < 2; i++) { + inputTopic1.pipeInput(expectedKeys[i], "X" + expectedKeys[i], 5L + i); + } + // pass tuple with null key, it will be discarded in join process + inputTopic1.pipeInput(null, "SomeVal", 42L); + // left: X0:0 (ts: 5), X1:1 (ts: 6) + // right: + proc.checkAndClearProcessResult(new KeyValueTimestamp<>(0, new Change<>("X0+null", null), 5), + new KeyValueTimestamp<>(1, new Change<>("X1+null", null), 6)); + + // push two items to the other stream. this should produce two items. + for (int i = 0; i < 2; i++) { + inputTopic2.pipeInput(expectedKeys[i], "Y" + expectedKeys[i], 10L * i); + } + // pass tuple with null key, it will be discarded in join process + inputTopic2.pipeInput(null, "AnotherVal", 73L); + // left: X0:0 (ts: 5), X1:1 (ts: 6) + // right: Y0:0 (ts: 0), Y1:1 (ts: 10) + proc.checkAndClearProcessResult(new KeyValueTimestamp<>(0, new Change<>("X0+Y0", "X0+null"), 5), + new KeyValueTimestamp<>(1, new Change<>("X1+Y1", "X1+null"), 10)); + + // push all four items to the primary stream. this should produce four items. 
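+            // note: with old values enabled, each emitted Change pairs the new join result with the previous one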
+ for (final int expectedKey : expectedKeys) { + inputTopic1.pipeInput(expectedKey, "XX" + expectedKey, 7L); + } + // left: XX0:0 (ts: 7), XX1:1 (ts: 7), XX2:2 (ts: 7), XX3:3 (ts: 7) + // right: Y0:0 (ts: 0), Y1:1 (ts: 10) + proc.checkAndClearProcessResult(new KeyValueTimestamp<>(0, new Change<>("XX0+Y0", "X0+Y0"), 7), + new KeyValueTimestamp<>(1, new Change<>("XX1+Y1", "X1+Y1"), 10), + new KeyValueTimestamp<>(2, new Change<>("XX2+null", null), 7), + new KeyValueTimestamp<>(3, new Change<>("XX3+null", null), 7)); + + // push all items to the other stream. this should produce four items. + for (final int expectedKey : expectedKeys) { + inputTopic2.pipeInput(expectedKey, "YY" + expectedKey, expectedKey * 5L); + } + // left: XX0:0 (ts: 7), XX1:1 (ts: 7), XX2:2 (ts: 7), XX3:3 (ts: 7) + // right: YY0:0 (ts: 0), YY1:1 (ts: 5), YY2:2 (ts: 10), YY3:3 (ts: 15) + proc.checkAndClearProcessResult(new KeyValueTimestamp<>(0, new Change<>("XX0+YY0", "XX0+Y0"), 7), + new KeyValueTimestamp<>(1, new Change<>("XX1+YY1", "XX1+Y1"), 7), + new KeyValueTimestamp<>(2, new Change<>("XX2+YY2", "XX2+null"), 10), + new KeyValueTimestamp<>(3, new Change<>("XX3+YY3", "XX3+null"), 15)); + // push all four items to the primary stream. this should produce four items. + for (final int expectedKey : expectedKeys) { + inputTopic1.pipeInput(expectedKey, "XXX" + expectedKey, 6L); + } + // left: XXX0:0 (ts: 6), XXX1:1 (ts: 6), XXX2:2 (ts: 6), XXX3:3 (ts: 6) + // right: YY0:0 (ts: 0), YY1:1 (ts: 5), YY2:2 (ts: 10), YY3:3 (ts: 15) + proc.checkAndClearProcessResult(new KeyValueTimestamp<>(0, new Change<>("XXX0+YY0", "XX0+YY0"), 6), + new KeyValueTimestamp<>(1, new Change<>("XXX1+YY1", "XX1+YY1"), 6), + new KeyValueTimestamp<>(2, new Change<>("XXX2+YY2", "XX2+YY2"), 10), + new KeyValueTimestamp<>(3, new Change<>("XXX3+YY3", "XX3+YY3"), 15)); + + // push two items with null to the other stream as deletes. this should produce two item. + inputTopic2.pipeInput(expectedKeys[0], null, 5L); + inputTopic2.pipeInput(expectedKeys[1], null, 7L); + // left: XXX0:0 (ts: 6), XXX1:1 (ts: 6), XXX2:2 (ts: 6), XXX3:3 (ts: 6) + // right: YY2:2 (ts: 10), YY3:3 (ts: 15) + proc.checkAndClearProcessResult(new KeyValueTimestamp<>(0, new Change<>("XXX0+null", "XXX0+YY0"), 6), + new KeyValueTimestamp<>(1, new Change<>("XXX1+null", "XXX1+YY1"), 7)); + + // push all four items to the primary stream. this should produce four items. + for (final int expectedKey : expectedKeys) { + inputTopic1.pipeInput(expectedKey, "XXXX" + expectedKey, 13L); + } + // left: XXXX0:0 (ts: 13), XXXX1:1 (ts: 13), XXXX2:2 (ts: 13), XXXX3:3 (ts: 13) + // right: YY2:2 (ts: 10), YY3:3 (ts: 15) + proc.checkAndClearProcessResult(new KeyValueTimestamp<>(0, new Change<>("XXXX0+null", "XXX0+null"), 13), + new KeyValueTimestamp<>(1, new Change<>("XXXX1+null", "XXX1+null"), 13), + new KeyValueTimestamp<>(2, new Change<>("XXXX2+YY2", "XXX2+YY2"), 13), + new KeyValueTimestamp<>(3, new Change<>("XXXX3+YY3", "XXX3+YY3"), 15)); + // push four items to the primary stream with null. this should produce four items. 
+ inputTopic1.pipeInput(expectedKeys[0], null, 0L); + inputTopic1.pipeInput(expectedKeys[1], null, 42L); + inputTopic1.pipeInput(expectedKeys[2], null, 5L); + inputTopic1.pipeInput(expectedKeys[3], null, 20L); + // left: + // right: YY2:2 (ts: 10), YY3:3 (ts: 15) + proc.checkAndClearProcessResult(new KeyValueTimestamp<>(0, new Change<>(null, "XXXX0+null"), 0), + new KeyValueTimestamp<>(1, new Change<>(null, "XXXX1+null"), 42), + new KeyValueTimestamp<>(2, new Change<>(null, "XXXX2+YY2"), 10), + new KeyValueTimestamp<>(3, new Change<>(null, "XXXX3+YY3"), 20)); + } + } + + /** + * This test was written to reproduce https://issues.apache.org/jira/browse/KAFKA-4492 + * It is based on a fairly complicated join used by the developer that reported the bug. + * Before the fix this would trigger an IllegalStateException. + */ + @Test + public void shouldNotThrowIllegalStateExceptionWhenMultiCacheEvictions() { + final String agg = "agg"; + final String tableOne = "tableOne"; + final String tableTwo = "tableTwo"; + final String tableThree = "tableThree"; + final String tableFour = "tableFour"; + final String tableFive = "tableFive"; + final String tableSix = "tableSix"; + final String[] inputs = {agg, tableOne, tableTwo, tableThree, tableFour, tableFive, tableSix}; + + final StreamsBuilder builder = new StreamsBuilder(); + final Consumed consumed = Consumed.with(Serdes.Long(), Serdes.String()); + final KTable aggTable = builder + .table(agg, consumed, Materialized.as(Stores.inMemoryKeyValueStore("agg-base-store"))) + .groupBy(KeyValue::new, Grouped.with(Serdes.Long(), Serdes.String())) + .reduce( + MockReducer.STRING_ADDER, + MockReducer.STRING_ADDER, + Materialized.as(Stores.inMemoryKeyValueStore("agg-store"))); + + final KTable one = builder.table( + tableOne, + consumed, + Materialized.as(Stores.inMemoryKeyValueStore("tableOne-base-store"))); + final KTable two = builder.table( + tableTwo, + consumed, + Materialized.as(Stores.inMemoryKeyValueStore("tableTwo-base-store"))); + final KTable three = builder.table( + tableThree, + consumed, + Materialized.as(Stores.inMemoryKeyValueStore("tableThree-base-store"))); + final KTable four = builder.table( + tableFour, + consumed, + Materialized.as(Stores.inMemoryKeyValueStore("tableFour-base-store"))); + final KTable five = builder.table( + tableFive, + consumed, + Materialized.as(Stores.inMemoryKeyValueStore("tableFive-base-store"))); + final KTable six = builder.table( + tableSix, + consumed, + Materialized.as(Stores.inMemoryKeyValueStore("tableSix-base-store"))); + + final ValueMapper mapper = value -> value.toUpperCase(Locale.ROOT); + + final KTable seven = one.mapValues(mapper); + + final KTable eight = six.leftJoin(seven, MockValueJoiner.TOSTRING_JOINER); + + aggTable + .leftJoin(one, MockValueJoiner.TOSTRING_JOINER) + .leftJoin(two, MockValueJoiner.TOSTRING_JOINER) + .leftJoin(three, MockValueJoiner.TOSTRING_JOINER) + .leftJoin(four, MockValueJoiner.TOSTRING_JOINER) + .leftJoin(five, MockValueJoiner.TOSTRING_JOINER) + .leftJoin(eight, MockValueJoiner.TOSTRING_JOINER) + .mapValues(mapper); + + try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) { + final String[] values = { + "a", "AA", "BBB", "CCCC", "DD", "EEEEEEEE", "F", "GGGGGGGGGGGGGGG", "HHH", "IIIIIIIIII", + "J", "KK", "LLLL", "MMMMMMMMMMMMMMMMMMMMMM", "NNNNN", "O", "P", "QQQQQ", "R", "SSSS", + "T", "UU", "VVVVVVVVVVVVVVVVVVV" + }; + + TestInputTopic inputTopic; + final Random random = new Random(); + for (int i = 0; i < 1000; i++) { + for (final String input 
: inputs) { + final Long key = (long) random.nextInt(1000); + final String value = values[random.nextInt(values.length)]; + inputTopic = driver.createInputTopic(input, Serdes.Long().serializer(), Serdes.String().serializer()); + inputTopic.pipeInput(key, value); + } + } + } + } + + @Test + public void shouldLogAndMeterSkippedRecordsDueToNullLeftKey() { + final StreamsBuilder builder = new StreamsBuilder(); + + @SuppressWarnings("unchecked") + final org.apache.kafka.streams.processor.Processor> join = new KTableKTableLeftJoin<>( + (KTableImpl) builder.table("left", Consumed.with(Serdes.String(), Serdes.String())), + (KTableImpl) builder.table("right", Consumed.with(Serdes.String(), Serdes.String())), + null + ).get(); + + final MockProcessorContext context = new MockProcessorContext(props); + context.setRecordMetadata("left", -1, -2, new RecordHeaders(), -3); + join.init(context); + + try (final LogCaptureAppender appender = LogCaptureAppender.createAndRegister(KTableKTableLeftJoin.class)) { + join.process(null, new Change<>("new", "old")); + + assertThat( + appender.getMessages(), + hasItem("Skipping record due to null key. change=[(new<-old)] topic=[left] partition=[-1] offset=[-2]") + ); + } + } + + private void assertOutputKeyValueTimestamp(final TestOutputTopic outputTopic, + final Integer expectedKey, + final String expectedValue, + final long expectedTimestamp) { + assertThat(outputTopic.readRecord(), equalTo(new TestRecord<>(expectedKey, expectedValue, null, expectedTimestamp))); + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KTableKTableOuterJoinTest.java b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KTableKTableOuterJoinTest.java new file mode 100644 index 0000000..6b3cf3b --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KTableKTableOuterJoinTest.java @@ -0,0 +1,438 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.kstream.internals; + +import org.apache.kafka.common.header.internals.RecordHeaders; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.streams.KeyValueTimestamp; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.TestInputTopic; +import org.apache.kafka.streams.TestOutputTopic; +import org.apache.kafka.streams.Topology; +import org.apache.kafka.streams.TopologyTestDriver; +import org.apache.kafka.streams.TopologyWrapper; +import org.apache.kafka.streams.kstream.Consumed; +import org.apache.kafka.streams.kstream.KTable; +import org.apache.kafka.streams.processor.MockProcessorContext; +import org.apache.kafka.streams.processor.internals.testutil.LogCaptureAppender; +import org.apache.kafka.streams.test.TestRecord; +import org.apache.kafka.test.MockApiProcessor; +import org.apache.kafka.test.MockApiProcessorSupplier; +import org.apache.kafka.test.MockValueJoiner; +import org.apache.kafka.test.StreamsTestUtils; +import org.junit.Test; + +import java.time.Duration; +import java.time.Instant; +import java.util.Arrays; +import java.util.Collection; +import java.util.HashSet; +import java.util.Properties; +import java.util.Set; + +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.CoreMatchers.hasItem; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +@SuppressWarnings("deprecation") // Old PAPI. Needs to be migrated. +public class KTableKTableOuterJoinTest { + private final String topic1 = "topic1"; + private final String topic2 = "topic2"; + private final String output = "output"; + private final Consumed consumed = Consumed.with(Serdes.Integer(), Serdes.String()); + private final Properties props = StreamsTestUtils.getStreamsConfig(Serdes.Integer(), Serdes.String()); + + @Test + public void testJoin() { + final StreamsBuilder builder = new StreamsBuilder(); + + final int[] expectedKeys = new int[]{0, 1, 2, 3}; + + final KTable table1; + final KTable table2; + final KTable joined; + + table1 = builder.table(topic1, consumed); + table2 = builder.table(topic2, consumed); + joined = table1.outerJoin(table2, MockValueJoiner.TOSTRING_JOINER); + joined.toStream().to(output); + + final Collection> copartitionGroups = + TopologyWrapper.getInternalTopologyBuilder(builder.build()).copartitionGroups(); + + assertEquals(1, copartitionGroups.size()); + assertEquals(new HashSet<>(Arrays.asList(topic1, topic2)), copartitionGroups.iterator().next()); + + try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) { + final TestInputTopic inputTopic1 = + driver.createInputTopic(topic1, Serdes.Integer().serializer(), Serdes.String().serializer(), Instant.ofEpochMilli(0L), Duration.ZERO); + final TestInputTopic inputTopic2 = + driver.createInputTopic(topic2, Serdes.Integer().serializer(), Serdes.String().serializer(), Instant.ofEpochMilli(0L), Duration.ZERO); + final TestOutputTopic outputTopic = + driver.createOutputTopic(output, Serdes.Integer().deserializer(), Serdes.String().deserializer()); + // push two items to the primary stream. 
the other table is empty + for (int i = 0; i < 2; i++) { + inputTopic1.pipeInput(expectedKeys[i], "X" + expectedKeys[i], 5L + i); + } + // pass tuple with null key, it will be discarded in join process + inputTopic1.pipeInput(null, "SomeVal", 42L); + // left: X0:0 (ts: 5), X1:1 (ts: 6) + // right: + assertOutputKeyValueTimestamp(outputTopic, 0, "X0+null", 5L); + assertOutputKeyValueTimestamp(outputTopic, 1, "X1+null", 6L); + assertTrue(outputTopic.isEmpty()); + + // push two items to the other stream. this should produce two items. + for (int i = 0; i < 2; i++) { + inputTopic2.pipeInput(expectedKeys[i], "Y" + expectedKeys[i], 10L * i); + } + // pass tuple with null key, it will be discarded in join process + inputTopic2.pipeInput(null, "AnotherVal", 73L); + // left: X0:0 (ts: 5), X1:1 (ts: 6) + // right: Y0:0 (ts: 0), Y1:1 (ts: 10) + assertOutputKeyValueTimestamp(outputTopic, 0, "X0+Y0", 5L); + assertOutputKeyValueTimestamp(outputTopic, 1, "X1+Y1", 10L); + assertTrue(outputTopic.isEmpty()); + + // push all four items to the primary stream. this should produce four items. + for (final int expectedKey : expectedKeys) { + inputTopic1.pipeInput(expectedKey, "XX" + expectedKey, 7L); + } + // left: XX0:0 (ts: 7), XX1:1 (ts: 7), XX2:2 (ts: 7), XX3:3 (ts: 7) + // right: Y0:0 (ts: 0), Y1:1 (ts: 10) + assertOutputKeyValueTimestamp(outputTopic, 0, "XX0+Y0", 7L); + assertOutputKeyValueTimestamp(outputTopic, 1, "XX1+Y1", 10L); + assertOutputKeyValueTimestamp(outputTopic, 2, "XX2+null", 7L); + assertOutputKeyValueTimestamp(outputTopic, 3, "XX3+null", 7L); + assertTrue(outputTopic.isEmpty()); + + // push all items to the other stream. this should produce four items. + for (final int expectedKey : expectedKeys) { + inputTopic2.pipeInput(expectedKey, "YY" + expectedKey, expectedKey * 5L); + } + // left: XX0:0 (ts: 7), XX1:1 (ts: 7), XX2:2 (ts: 7), XX3:3 (ts: 7) + // right: YY0:0 (ts: 0), YY1:1 (ts: 5), YY2:2 (ts: 10), YY3:3 (ts: 15) + assertOutputKeyValueTimestamp(outputTopic, 0, "XX0+YY0", 7L); + assertOutputKeyValueTimestamp(outputTopic, 1, "XX1+YY1", 7L); + assertOutputKeyValueTimestamp(outputTopic, 2, "XX2+YY2", 10L); + assertOutputKeyValueTimestamp(outputTopic, 3, "XX3+YY3", 15L); + assertTrue(outputTopic.isEmpty()); + + // push all four items to the primary stream. this should produce four items. + for (final int expectedKey : expectedKeys) { + inputTopic1.pipeInput(expectedKey, "XXX" + expectedKey, 6L); + } + // left: XXX0:0 (ts: 6), XXX1:1 (ts: 6), XXX2:2 (ts: 6), XXX3:3 (ts: 6) + // right: YY0:0 (ts: 0), YY1:1 (ts: 5), YY2:2 (ts: 10), YY3:3 (ts: 15) + assertOutputKeyValueTimestamp(outputTopic, 0, "XXX0+YY0", 6L); + assertOutputKeyValueTimestamp(outputTopic, 1, "XXX1+YY1", 6L); + assertOutputKeyValueTimestamp(outputTopic, 2, "XXX2+YY2", 10L); + assertOutputKeyValueTimestamp(outputTopic, 3, "XXX3+YY3", 15L); + assertTrue(outputTopic.isEmpty()); + + // push two items with null to the other stream as deletes. this should produce two item. + inputTopic2.pipeInput(expectedKeys[0], null, 5L); + inputTopic2.pipeInput(expectedKeys[1], null, 7L); + // left: XXX0:0 (ts: 6), XXX1:1 (ts: 6), XXX2:2 (ts: 6), XXX3:3 (ts: 6) + // right: YY2:2 (ts: 10), YY3:3 (ts: 15) + assertOutputKeyValueTimestamp(outputTopic, 0, "XXX0+null", 6L); + assertOutputKeyValueTimestamp(outputTopic, 1, "XXX1+null", 7L); + assertTrue(outputTopic.isEmpty()); + + // push all four items to the primary stream. this should produce four items. 
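+            // note: keys 0 and 1 were deleted from the right table above, so the outer join now pairs them with null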
+ for (final int expectedKey : expectedKeys) { + inputTopic1.pipeInput(expectedKey, "XXXX" + expectedKey, 13L); + } + // left: XXXX0:0 (ts: 13), XXXX1:1 (ts: 13), XXXX2:2 (ts: 13), XXXX3:3 (ts: 13) + // right: YY2:2 (ts: 10), YY3:3 (ts: 15) + assertOutputKeyValueTimestamp(outputTopic, 0, "XXXX0+null", 13L); + assertOutputKeyValueTimestamp(outputTopic, 1, "XXXX1+null", 13L); + assertOutputKeyValueTimestamp(outputTopic, 2, "XXXX2+YY2", 13L); + assertOutputKeyValueTimestamp(outputTopic, 3, "XXXX3+YY3", 15L); + assertTrue(outputTopic.isEmpty()); + + // push four items to the primary stream with null. this should produce four items. + inputTopic1.pipeInput(expectedKeys[0], null, 0L); + inputTopic1.pipeInput(expectedKeys[1], null, 42L); + inputTopic1.pipeInput(expectedKeys[2], null, 5L); + inputTopic1.pipeInput(expectedKeys[3], null, 20L); + // left: + // right: YY2:2 (ts: 10), YY3:3 (ts: 15) + assertOutputKeyValueTimestamp(outputTopic, 0, null, 0L); + assertOutputKeyValueTimestamp(outputTopic, 1, null, 42L); + assertOutputKeyValueTimestamp(outputTopic, 2, "null+YY2", 10L); + assertOutputKeyValueTimestamp(outputTopic, 3, "null+YY3", 20L); + assertTrue(outputTopic.isEmpty()); + } + } + + @Test + public void testNotSendingOldValue() { + final StreamsBuilder builder = new StreamsBuilder(); + + final int[] expectedKeys = new int[]{0, 1, 2, 3}; + + final KTable table1; + final KTable table2; + final KTable joined; + final MockApiProcessorSupplier supplier = new MockApiProcessorSupplier<>(); + + table1 = builder.table(topic1, consumed); + table2 = builder.table(topic2, consumed); + joined = table1.outerJoin(table2, MockValueJoiner.TOSTRING_JOINER); + + final Topology topology = builder.build().addProcessor("proc", supplier, ((KTableImpl) joined).name); + + try (final TopologyTestDriver driver = new TopologyTestDriver(topology, props)) { + final TestInputTopic inputTopic1 = + driver.createInputTopic(topic1, Serdes.Integer().serializer(), Serdes.String().serializer(), Instant.ofEpochMilli(0L), Duration.ZERO); + final TestInputTopic inputTopic2 = + driver.createInputTopic(topic2, Serdes.Integer().serializer(), Serdes.String().serializer(), Instant.ofEpochMilli(0L), Duration.ZERO); + final MockApiProcessor proc = supplier.theCapturedProcessor(); + + assertTrue(((KTableImpl) table1).sendingOldValueEnabled()); + assertTrue(((KTableImpl) table2).sendingOldValueEnabled()); + assertFalse(((KTableImpl) joined).sendingOldValueEnabled()); + + // push two items to the primary stream. the other table is empty + for (int i = 0; i < 2; i++) { + inputTopic1.pipeInput(expectedKeys[i], "X" + expectedKeys[i], 5L + i); + } + // pass tuple with null key, it will be discarded in join process + inputTopic1.pipeInput(null, "SomeVal", 42L); + // left: X0:0 (ts: 5), X1:1 (ts: 6) + // right: + proc.checkAndClearProcessResult(new KeyValueTimestamp<>(0, new Change<>("X0+null", null), 5), + new KeyValueTimestamp<>(1, new Change<>("X1+null", null), 6)); + // push two items to the other stream. this should produce two items. + for (int i = 0; i < 2; i++) { + inputTopic2.pipeInput(expectedKeys[i], "Y" + expectedKeys[i], 10L * i); + } + // pass tuple with null key, it will be discarded in join process + inputTopic2.pipeInput(null, "AnotherVal", 73L); + // left: X0:0 (ts: 5), X1:1 (ts: 6) + // right: Y0:0 (ts: 0), Y1:1 (ts: 10) + proc.checkAndClearProcessResult(new KeyValueTimestamp<>(0, new Change<>("X0+Y0", null), 5), + new KeyValueTimestamp<>(1, new Change<>("X1+Y1", null), 10)); + // push all four items to the primary stream. 
this should produce four items. + for (final int expectedKey : expectedKeys) { + inputTopic1.pipeInput(expectedKey, "XX" + expectedKey, 7L); + } + // left: XX0:0 (ts: 7), XX1:1 (ts: 7), XX2:2 (ts: 7), XX3:3 (ts: 7) + // right: Y0:0 (ts: 0), Y1:1 (ts: 10) + proc.checkAndClearProcessResult(new KeyValueTimestamp<>(0, new Change<>("XX0+Y0", null), 7), + new KeyValueTimestamp<>(1, new Change<>("XX1+Y1", null), 10), + new KeyValueTimestamp<>(2, new Change<>("XX2+null", null), 7), + new KeyValueTimestamp<>(3, new Change<>("XX3+null", null), 7)); + // push all items to the other stream. this should produce four items. + for (final int expectedKey : expectedKeys) { + inputTopic2.pipeInput(expectedKey, "YY" + expectedKey, expectedKey * 5L); + } + // left: XX0:0 (ts: 7), XX1:1 (ts: 7), XX2:2 (ts: 7), XX3:3 (ts: 7) + // right: YY0:0 (ts: 0), YY1:1 (ts: 5), YY2:2 (ts: 10), YY3:3 (ts: 15) + proc.checkAndClearProcessResult(new KeyValueTimestamp<>(0, new Change<>("XX0+YY0", null), 7), + new KeyValueTimestamp<>(1, new Change<>("XX1+YY1", null), 7), + new KeyValueTimestamp<>(2, new Change<>("XX2+YY2", null), 10), + new KeyValueTimestamp<>(3, new Change<>("XX3+YY3", null), 15)); + // push all four items to the primary stream. this should produce four items. + for (final int expectedKey : expectedKeys) { + inputTopic1.pipeInput(expectedKey, "XXX" + expectedKey, 6L); + } + // left: XXX0:0 (ts: 6), XXX1:1 (ts: 6), XXX2:2 (ts: 6), XXX3:3 (ts: 6) + // right: YY0:0 (ts: 0), YY1:1 (ts: 5), YY2:2 (ts: 10), YY3:3 (ts: 15) + proc.checkAndClearProcessResult(new KeyValueTimestamp<>(0, new Change<>("XXX0+YY0", null), 6), + new KeyValueTimestamp<>(1, new Change<>("XXX1+YY1", null), 6), + new KeyValueTimestamp<>(2, new Change<>("XXX2+YY2", null), 10), + new KeyValueTimestamp<>(3, new Change<>("XXX3+YY3", null), 15)); + // push two items with null to the other stream as deletes. this should produce two item. + inputTopic2.pipeInput(expectedKeys[0], null, 5L); + inputTopic2.pipeInput(expectedKeys[1], null, 7L); + // left: XXX0:0 (ts: 6), XXX1:1 (ts: 6), XXX2:2 (ts: 6), XXX3:3 (ts: 6) + // right: YY2:2 (ts: 10), YY3:3 (ts: 15) + proc.checkAndClearProcessResult(new KeyValueTimestamp<>(0, new Change<>("XXX0+null", null), 6), + new KeyValueTimestamp<>(1, new Change<>("XXX1+null", null), 7)); + // push all four items to the primary stream. this should produce four items. + for (final int expectedKey : expectedKeys) { + inputTopic1.pipeInput(expectedKey, "XXXX" + expectedKey, 13L); + } + // left: XXXX0:0 (ts: 13), XXXX1:1 (ts: 13), XXXX2:2 (ts: 13), XXXX3:3 (ts: 13) + // right: YY2:2 (ts: 10), YY3:3 (ts: 15) + proc.checkAndClearProcessResult(new KeyValueTimestamp<>(0, new Change<>("XXXX0+null", null), 13), + new KeyValueTimestamp<>(1, new Change<>("XXXX1+null", null), 13), + new KeyValueTimestamp<>(2, new Change<>("XXXX2+YY2", null), 13), + new KeyValueTimestamp<>(3, new Change<>("XXXX3+YY3", null), 15)); + // push four items to the primary stream with null. this should produce four items. 
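+            // note: keys 2 and 3 still have right-table entries, so deleting the left rows yields "null+YY" results instead of plain tombstones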
+ inputTopic1.pipeInput(expectedKeys[0], null, 0L); + inputTopic1.pipeInput(expectedKeys[1], null, 42L); + inputTopic1.pipeInput(expectedKeys[2], null, 5L); + inputTopic1.pipeInput(expectedKeys[3], null, 20L); + // left: + // right: YY2:2 (ts: 10), YY3:3 (ts: 15) + proc.checkAndClearProcessResult(new KeyValueTimestamp<>(0, new Change<>(null, null), 0), + new KeyValueTimestamp<>(1, new Change<>(null, null), 42), + new KeyValueTimestamp<>(2, new Change<>("null+YY2", null), 10), + new KeyValueTimestamp<>(3, new Change<>("null+YY3", null), 20)); + } + } + + @Test + public void testSendingOldValue() { + final StreamsBuilder builder = new StreamsBuilder(); + + final int[] expectedKeys = new int[]{0, 1, 2, 3}; + + final KTable table1; + final KTable table2; + final KTable joined; + final MockApiProcessorSupplier supplier = new MockApiProcessorSupplier<>(); + + table1 = builder.table(topic1, consumed); + table2 = builder.table(topic2, consumed); + joined = table1.outerJoin(table2, MockValueJoiner.TOSTRING_JOINER); + + ((KTableImpl) joined).enableSendingOldValues(true); + + final Topology topology = builder.build().addProcessor("proc", supplier, ((KTableImpl) joined).name); + + try (final TopologyTestDriver driver = new TopologyTestDriver(topology, props)) { + final TestInputTopic inputTopic1 = + driver.createInputTopic(topic1, Serdes.Integer().serializer(), Serdes.String().serializer(), Instant.ofEpochMilli(0L), Duration.ZERO); + final TestInputTopic inputTopic2 = + driver.createInputTopic(topic2, Serdes.Integer().serializer(), Serdes.String().serializer(), Instant.ofEpochMilli(0L), Duration.ZERO); + final MockApiProcessor proc = supplier.theCapturedProcessor(); + + assertTrue(((KTableImpl) table1).sendingOldValueEnabled()); + assertTrue(((KTableImpl) table2).sendingOldValueEnabled()); + assertTrue(((KTableImpl) joined).sendingOldValueEnabled()); + + // push two items to the primary stream. the other table is empty + for (int i = 0; i < 2; i++) { + inputTopic1.pipeInput(expectedKeys[i], "X" + expectedKeys[i], 5L + i); + } + // pass tuple with null key, it will be discarded in join process + inputTopic1.pipeInput(null, "SomeVal", 42L); + // left: X0:0 (ts: 5), X1:1 (ts: 6) + // right: + proc.checkAndClearProcessResult(new KeyValueTimestamp<>(0, new Change<>("X0+null", null), 5), + new KeyValueTimestamp<>(1, new Change<>("X1+null", null), 6)); + // push two items to the other stream. this should produce two items. + for (int i = 0; i < 2; i++) { + inputTopic2.pipeInput(expectedKeys[i], "Y" + expectedKeys[i], 10L * i); + } + // pass tuple with null key, it will be discarded in join process + inputTopic2.pipeInput(null, "AnotherVal", 73L); + // left: X0:0 (ts: 5), X1:1 (ts: 6) + // right: Y0:0 (ts: 0), Y1:1 (ts: 10) + proc.checkAndClearProcessResult(new KeyValueTimestamp<>(0, new Change<>("X0+Y0", "X0+null"), 5), + new KeyValueTimestamp<>(1, new Change<>("X1+Y1", "X1+null"), 10)); + // push all four items to the primary stream. this should produce four items. 
+ for (final int expectedKey : expectedKeys) { + inputTopic1.pipeInput(expectedKey, "XX" + expectedKey, 7L); + } + // left: XX0:0 (ts: 7), XX1:1 (ts: 7), XX2:2 (ts: 7), XX3:3 (ts: 7) + // right: Y0:0 (ts: 0), Y1:1 (ts: 10) + proc.checkAndClearProcessResult(new KeyValueTimestamp<>(0, new Change<>("XX0+Y0", "X0+Y0"), 7), + new KeyValueTimestamp<>(1, new Change<>("XX1+Y1", "X1+Y1"), 10), + new KeyValueTimestamp<>(2, new Change<>("XX2+null", null), 7), + new KeyValueTimestamp<>(3, new Change<>("XX3+null", null), 7)); + // push all items to the other stream. this should produce four items. + for (final int expectedKey : expectedKeys) { + inputTopic2.pipeInput(expectedKey, "YY" + expectedKey, expectedKey * 5L); + } + // left: XX0:0 (ts: 7), XX1:1 (ts: 7), XX2:2 (ts: 7), XX3:3 (ts: 7) + // right: YY0:0 (ts: 0), YY1:1 (ts: 5), YY2:2 (ts: 10), YY3:3 (ts: 15) + proc.checkAndClearProcessResult(new KeyValueTimestamp<>(0, new Change<>("XX0+YY0", "XX0+Y0"), 7), + new KeyValueTimestamp<>(1, new Change<>("XX1+YY1", "XX1+Y1"), 7), + new KeyValueTimestamp<>(2, new Change<>("XX2+YY2", "XX2+null"), 10), + new KeyValueTimestamp<>(3, new Change<>("XX3+YY3", "XX3+null"), 15)); + // push all four items to the primary stream. this should produce four items. + for (final int expectedKey : expectedKeys) { + inputTopic1.pipeInput(expectedKey, "XXX" + expectedKey, 6L); + } + // left: XXX0:0 (ts: 6), XXX1:1 (ts: 6), XXX2:2 (ts: 6), XXX3:3 (ts: 6) + // right: YY0:0 (ts: 0), YY1:1 (ts: 5), YY2:2 (ts: 10), YY3:3 (ts: 15) + proc.checkAndClearProcessResult(new KeyValueTimestamp<>(0, new Change<>("XXX0+YY0", "XX0+YY0"), 6), + new KeyValueTimestamp<>(1, new Change<>("XXX1+YY1", "XX1+YY1"), 6), + new KeyValueTimestamp<>(2, new Change<>("XXX2+YY2", "XX2+YY2"), 10), + new KeyValueTimestamp<>(3, new Change<>("XXX3+YY3", "XX3+YY3"), 15)); + // push two items with null to the other stream as deletes. this should produce two item. + inputTopic2.pipeInput(expectedKeys[0], null, 5L); + inputTopic2.pipeInput(expectedKeys[1], null, 7L); + // left: XXX0:0 (ts: 6), XXX1:1 (ts: 6), XXX2:2 (ts: 6), XXX3:3 (ts: 6) + // right: YY2:2 (ts: 10), YY3:3 (ts: 15) + proc.checkAndClearProcessResult(new KeyValueTimestamp<>(0, new Change<>("XXX0+null", "XXX0+YY0"), 6), + new KeyValueTimestamp<>(1, new Change<>("XXX1+null", "XXX1+YY1"), 7)); + // push all four items to the primary stream. this should produce four items. + for (final int expectedKey : expectedKeys) { + inputTopic1.pipeInput(expectedKey, "XXXX" + expectedKey, 13L); + } + // left: XXXX0:0 (ts: 13), XXXX1:1 (ts: 13), XXXX2:2 (ts: 13), XXXX3:3 (ts: 13) + // right: YY2:2 (ts: 10), YY3:3 (ts: 15) + proc.checkAndClearProcessResult(new KeyValueTimestamp<>(0, new Change<>("XXXX0+null", "XXX0+null"), 13), + new KeyValueTimestamp<>(1, new Change<>("XXXX1+null", "XXX1+null"), 13), + new KeyValueTimestamp<>(2, new Change<>("XXXX2+YY2", "XXX2+YY2"), 13), + new KeyValueTimestamp<>(3, new Change<>("XXXX3+YY3", "XXX3+YY3"), 15)); + // push four items to the primary stream with null. this should produce four items. 
+ inputTopic1.pipeInput(expectedKeys[0], null, 0L); + inputTopic1.pipeInput(expectedKeys[1], null, 42L); + inputTopic1.pipeInput(expectedKeys[2], null, 5L); + inputTopic1.pipeInput(expectedKeys[3], null, 20L); + // left: + // right: YY2:2 (ts: 10), YY3:3 (ts: 15) + proc.checkAndClearProcessResult(new KeyValueTimestamp<>(0, new Change<>(null, "XXXX0+null"), 0), + new KeyValueTimestamp<>(1, new Change<>(null, "XXXX1+null"), 42), + new KeyValueTimestamp<>(2, new Change<>("null+YY2", "XXXX2+YY2"), 10), + new KeyValueTimestamp<>(3, new Change<>("null+YY3", "XXXX3+YY3"), 20)); + } + } + + @Test + public void shouldLogAndMeterSkippedRecordsDueToNullLeftKey() { + final StreamsBuilder builder = new StreamsBuilder(); + + @SuppressWarnings("unchecked") + final org.apache.kafka.streams.processor.Processor> join = new KTableKTableOuterJoin<>( + (KTableImpl) builder.table("left", Consumed.with(Serdes.String(), Serdes.String())), + (KTableImpl) builder.table("right", Consumed.with(Serdes.String(), Serdes.String())), + null + ).get(); + + final MockProcessorContext context = new MockProcessorContext(props); + context.setRecordMetadata("left", -1, -2, new RecordHeaders(), -3); + join.init(context); + + try (final LogCaptureAppender appender = LogCaptureAppender.createAndRegister(KTableKTableOuterJoin.class)) { + join.process(null, new Change<>("new", "old")); + + assertThat( + appender.getMessages(), + hasItem("Skipping record due to null key. change=[(new<-old)] topic=[left] partition=[-1] offset=[-2]") + ); + } + } + + private void assertOutputKeyValueTimestamp(final TestOutputTopic outputTopic, + final Integer expectedKey, + final String expectedValue, + final long expectedTimestamp) { + assertThat(outputTopic.readRecord(), equalTo(new TestRecord<>(expectedKey, expectedValue, null, expectedTimestamp))); + } + +} diff --git a/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KTableKTableRightJoinTest.java b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KTableKTableRightJoinTest.java new file mode 100644 index 0000000..56d9c99 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KTableKTableRightJoinTest.java @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.kstream.internals; + +import org.apache.kafka.common.header.internals.RecordHeaders; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.kstream.Consumed; +import org.apache.kafka.streams.processor.MockProcessorContext; +import org.apache.kafka.streams.processor.internals.testutil.LogCaptureAppender; +import org.apache.kafka.streams.processor.internals.testutil.LogCaptureAppender.Event; +import org.apache.kafka.test.StreamsTestUtils; +import org.junit.Test; + +import java.util.Properties; +import java.util.stream.Collectors; + +import static org.hamcrest.CoreMatchers.hasItem; +import static org.hamcrest.MatcherAssert.assertThat; + +@SuppressWarnings("deprecation") // Old PAPI. Needs to be migrated. +public class KTableKTableRightJoinTest { + + private final Properties props = StreamsTestUtils.getStreamsConfig(Serdes.String(), Serdes.String()); + + @Test + public void shouldLogAndMeterSkippedRecordsDueToNullLeftKeyWithBuiltInMetricsVersionLatest() { + final StreamsBuilder builder = new StreamsBuilder(); + + @SuppressWarnings("unchecked") + final org.apache.kafka.streams.processor.Processor> join = new KTableKTableRightJoin<>( + (KTableImpl) builder.table("left", Consumed.with(Serdes.String(), Serdes.String())), + (KTableImpl) builder.table("right", Consumed.with(Serdes.String(), Serdes.String())), + null + ).get(); + + props.setProperty(StreamsConfig.BUILT_IN_METRICS_VERSION_CONFIG, StreamsConfig.METRICS_LATEST); + final MockProcessorContext context = new MockProcessorContext(props); + context.setRecordMetadata("left", -1, -2, new RecordHeaders(), -3); + join.init(context); + + try (final LogCaptureAppender appender = LogCaptureAppender.createAndRegister(KTableKTableRightJoin.class)) { + join.process(null, new Change<>("new", "old")); + + assertThat( + appender.getEvents().stream() + .filter(e -> e.getLevel().equals("WARN")) + .map(Event::getMessage) + .collect(Collectors.toList()), + hasItem("Skipping record due to null key. change=[(new<-old)] topic=[left] partition=[-1] offset=[-2]") + ); + } + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KTableMapKeysTest.java b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KTableMapKeysTest.java new file mode 100644 index 0000000..e761968 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KTableMapKeysTest.java @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+package org.apache.kafka.streams.kstream.internals;
+
+import org.apache.kafka.common.serialization.IntegerSerializer;
+import org.apache.kafka.common.serialization.Serdes;
+import org.apache.kafka.common.serialization.StringSerializer;
+import org.apache.kafka.streams.KeyValueTimestamp;
+import org.apache.kafka.streams.StreamsBuilder;
+import org.apache.kafka.streams.TopologyTestDriver;
+import org.apache.kafka.streams.kstream.Consumed;
+import org.apache.kafka.streams.kstream.KStream;
+import org.apache.kafka.streams.kstream.KTable;
+import org.apache.kafka.streams.TestInputTopic;
+import org.apache.kafka.test.MockProcessorSupplier;
+import org.apache.kafka.test.StreamsTestUtils;
+import org.junit.Test;
+
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Properties;
+
+import static org.junit.Assert.assertEquals;
+
+public class KTableMapKeysTest {
+    private final Properties props = StreamsTestUtils.getStreamsConfig(Serdes.Integer(), Serdes.String());
+
+    @SuppressWarnings("deprecation") // Old PAPI. Needs to be migrated.
+    @Test
+    public void testMapKeysConvertingToStream() {
+        final StreamsBuilder builder = new StreamsBuilder();
+        final String topic1 = "topic_map_keys";
+
+        final KTable<Integer, String> table1 = builder.table(topic1, Consumed.with(Serdes.Integer(), Serdes.String()));
+
+        final Map<Integer, String> keyMap = new HashMap<>();
+        keyMap.put(1, "ONE");
+        keyMap.put(2, "TWO");
+        keyMap.put(3, "THREE");
+
+        final KStream<String, String> convertedStream = table1.toStream((key, value) -> keyMap.get(key));
+
+        final KeyValueTimestamp[] expected = new KeyValueTimestamp[] {new KeyValueTimestamp<>("ONE", "V_ONE", 5),
+            new KeyValueTimestamp<>("TWO", "V_TWO", 10),
+            new KeyValueTimestamp<>("THREE", "V_THREE", 15)};
+        final int[] originalKeys = new int[] {1, 2, 3};
+        final String[] values = new String[] {"V_ONE", "V_TWO", "V_THREE"};
+
+        final MockProcessorSupplier<String, String> supplier = new MockProcessorSupplier<>();
+        convertedStream.process(supplier);
+
+        try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) {
+            for (int i = 0; i < originalKeys.length; i++) {
+                final TestInputTopic<Integer, String> inputTopic =
+                    driver.createInputTopic(topic1, new IntegerSerializer(), new StringSerializer());
+                inputTopic.pipeInput(originalKeys[i], values[i], 5 + i * 5);
+            }
+        }
+
+        assertEquals(3, supplier.theCapturedProcessor().processed().size());
+        for (int i = 0; i < expected.length; i++) {
+            assertEquals(expected[i], supplier.theCapturedProcessor().processed().get(i));
+        }
+    }
+}
\ No newline at end of file
diff --git a/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KTableMapValuesTest.java b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KTableMapValuesTest.java
new file mode 100644
index 0000000..ff6e1b9
--- /dev/null
+++ b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KTableMapValuesTest.java
@@ -0,0 +1,325 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.kstream.internals; + +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.serialization.StringSerializer; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.streams.KeyValueTimestamp; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.TestInputTopic; +import org.apache.kafka.streams.Topology; +import org.apache.kafka.streams.TopologyTestDriver; +import org.apache.kafka.streams.TopologyTestDriverWrapper; +import org.apache.kafka.streams.TopologyWrapper; +import org.apache.kafka.streams.kstream.Consumed; +import org.apache.kafka.streams.kstream.KTable; +import org.apache.kafka.streams.kstream.Materialized; +import org.apache.kafka.streams.processor.internals.InternalTopologyBuilder; +import org.apache.kafka.streams.state.KeyValueStore; +import org.apache.kafka.streams.state.ValueAndTimestamp; +import org.apache.kafka.test.MockApiProcessor; +import org.apache.kafka.test.MockApiProcessorSupplier; +import org.apache.kafka.test.MockProcessorSupplier; +import org.apache.kafka.test.StreamsTestUtils; +import org.junit.Test; + +import java.time.Duration; +import java.time.Instant; +import java.util.Properties; + +import static java.util.Arrays.asList; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.is; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNull; + +@SuppressWarnings("unchecked") +public class KTableMapValuesTest { + private final Consumed consumed = Consumed.with(Serdes.String(), Serdes.String()); + private final Properties props = StreamsTestUtils.getStreamsConfig(Serdes.String(), Serdes.String()); + + private void doTestKTable(final StreamsBuilder builder, + final String topic1, + final MockProcessorSupplier supplier) { + try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) { + final TestInputTopic inputTopic1 = + driver.createInputTopic(topic1, new StringSerializer(), new StringSerializer(), Instant.ofEpochMilli(0L), Duration.ZERO); + inputTopic1.pipeInput("A", "1", 5L); + inputTopic1.pipeInput("B", "2", 25L); + inputTopic1.pipeInput("C", "3", 20L); + inputTopic1.pipeInput("D", "4", 10L); + assertEquals(asList(new KeyValueTimestamp<>("A", 1, 5), + new KeyValueTimestamp<>("B", 2, 25), + new KeyValueTimestamp<>("C", 3, 20), + new KeyValueTimestamp<>("D", 4, 10)), supplier.theCapturedProcessor().processed()); + } + } + + @SuppressWarnings("deprecation") // Old PAPI. Needs to be migrated. + @Test + public void testKTable() { + final StreamsBuilder builder = new StreamsBuilder(); + final String topic1 = "topic1"; + + final KTable table1 = builder.table(topic1, consumed); + final KTable table2 = table1.mapValues(value -> value.charAt(0) - 48); + + final MockProcessorSupplier supplier = new MockProcessorSupplier<>(); + table2.toStream().process(supplier); + + doTestKTable(builder, topic1, supplier); + } + + @SuppressWarnings("deprecation") // Old PAPI. Needs to be migrated. 
+ @Test + public void testQueryableKTable() { + final StreamsBuilder builder = new StreamsBuilder(); + final String topic1 = "topic1"; + + final KTable table1 = builder.table(topic1, consumed); + final KTable table2 = table1 + .mapValues( + value -> value.charAt(0) - 48, + Materialized.>as("anyName") + .withValueSerde(Serdes.Integer())); + + final MockProcessorSupplier supplier = new MockProcessorSupplier<>(); + table2.toStream().process(supplier); + + doTestKTable(builder, topic1, supplier); + } + + private void doTestValueGetter(final StreamsBuilder builder, + final String topic1, + final KTableImpl table2, + final KTableImpl table3) { + + final Topology topology = builder.build(); + + final KTableValueGetterSupplier getterSupplier2 = table2.valueGetterSupplier(); + final KTableValueGetterSupplier getterSupplier3 = table3.valueGetterSupplier(); + + final InternalTopologyBuilder topologyBuilder = TopologyWrapper.getInternalTopologyBuilder(topology); + topologyBuilder.connectProcessorAndStateStores(table2.name, getterSupplier2.storeNames()); + topologyBuilder.connectProcessorAndStateStores(table3.name, getterSupplier3.storeNames()); + + try (final TopologyTestDriverWrapper driver = new TopologyTestDriverWrapper(builder.build(), props)) { + final TestInputTopic inputTopic1 = + driver.createInputTopic(topic1, new StringSerializer(), new StringSerializer(), Instant.ofEpochMilli(0L), Duration.ZERO); + final KTableValueGetter getter2 = getterSupplier2.get(); + final KTableValueGetter getter3 = getterSupplier3.get(); + + getter2.init(driver.setCurrentNodeForProcessorContext(table2.name)); + getter3.init(driver.setCurrentNodeForProcessorContext(table3.name)); + + inputTopic1.pipeInput("A", "01", 50L); + inputTopic1.pipeInput("B", "01", 10L); + inputTopic1.pipeInput("C", "01", 30L); + + assertEquals(ValueAndTimestamp.make(1, 50L), getter2.get("A")); + assertEquals(ValueAndTimestamp.make(1, 10L), getter2.get("B")); + assertEquals(ValueAndTimestamp.make(1, 30L), getter2.get("C")); + + assertEquals(ValueAndTimestamp.make(-1, 50L), getter3.get("A")); + assertEquals(ValueAndTimestamp.make(-1, 10L), getter3.get("B")); + assertEquals(ValueAndTimestamp.make(-1, 30L), getter3.get("C")); + + inputTopic1.pipeInput("A", "02", 25L); + inputTopic1.pipeInput("B", "02", 20L); + + assertEquals(ValueAndTimestamp.make(2, 25L), getter2.get("A")); + assertEquals(ValueAndTimestamp.make(2, 20L), getter2.get("B")); + assertEquals(ValueAndTimestamp.make(1, 30L), getter2.get("C")); + + assertEquals(ValueAndTimestamp.make(-2, 25L), getter3.get("A")); + assertEquals(ValueAndTimestamp.make(-2, 20L), getter3.get("B")); + assertEquals(ValueAndTimestamp.make(-1, 30L), getter3.get("C")); + + inputTopic1.pipeInput("A", "03", 35L); + + assertEquals(ValueAndTimestamp.make(3, 35L), getter2.get("A")); + assertEquals(ValueAndTimestamp.make(2, 20L), getter2.get("B")); + assertEquals(ValueAndTimestamp.make(1, 30L), getter2.get("C")); + + assertEquals(ValueAndTimestamp.make(-3, 35L), getter3.get("A")); + assertEquals(ValueAndTimestamp.make(-2, 20L), getter3.get("B")); + assertEquals(ValueAndTimestamp.make(-1, 30L), getter3.get("C")); + + inputTopic1.pipeInput("A", (String) null, 1L); + + assertNull(getter2.get("A")); + assertEquals(ValueAndTimestamp.make(2, 20L), getter2.get("B")); + assertEquals(ValueAndTimestamp.make(1, 30L), getter2.get("C")); + + assertNull(getter3.get("A")); + assertEquals(ValueAndTimestamp.make(-2, 20L), getter3.get("B")); + assertEquals(ValueAndTimestamp.make(-1, 30L), getter3.get("C")); + } + } + + @Test + public 
void testQueryableValueGetter() { + final StreamsBuilder builder = new StreamsBuilder(); + final String topic1 = "topic1"; + final String storeName2 = "store2"; + final String storeName3 = "store3"; + + final KTableImpl table1 = + (KTableImpl) builder.table(topic1, consumed); + final KTableImpl table2 = + (KTableImpl) table1.mapValues( + s -> Integer.valueOf(s), + Materialized.>as(storeName2) + .withValueSerde(Serdes.Integer())); + final KTableImpl table3 = + (KTableImpl) table1.mapValues( + value -> Integer.valueOf(value) * (-1), + Materialized.>as(storeName3) + .withValueSerde(Serdes.Integer())); + final KTableImpl table4 = + (KTableImpl) table1.mapValues(s -> Integer.valueOf(s)); + + assertEquals(storeName2, table2.queryableStoreName()); + assertEquals(storeName3, table3.queryableStoreName()); + assertNull(table4.queryableStoreName()); + + doTestValueGetter(builder, topic1, table2, table3); + } + + @Test + public void testNotSendingOldValue() { + final StreamsBuilder builder = new StreamsBuilder(); + final String topic1 = "topic1"; + + final KTableImpl table1 = + (KTableImpl) builder.table(topic1, consumed); + final KTableImpl table2 = + (KTableImpl) table1.mapValues(s -> Integer.valueOf(s)); + + final MockApiProcessorSupplier supplier = new MockApiProcessorSupplier<>(); + final Topology topology = builder.build().addProcessor("proc", supplier, table2.name); + + try (final TopologyTestDriver driver = new TopologyTestDriver(topology, props)) { + final TestInputTopic inputTopic1 = + driver.createInputTopic(topic1, new StringSerializer(), new StringSerializer(), Instant.ofEpochMilli(0L), Duration.ZERO); + final MockApiProcessor proc = supplier.theCapturedProcessor(); + + assertFalse(table1.sendingOldValueEnabled()); + assertFalse(table2.sendingOldValueEnabled()); + + inputTopic1.pipeInput("A", "01", 5L); + inputTopic1.pipeInput("B", "01", 10L); + inputTopic1.pipeInput("C", "01", 15L); + proc.checkAndClearProcessResult(new KeyValueTimestamp<>("A", new Change<>(1, null), 5), + new KeyValueTimestamp<>("B", new Change<>(1, null), 10), + new KeyValueTimestamp<>("C", new Change<>(1, null), 15)); + + inputTopic1.pipeInput("A", "02", 10L); + inputTopic1.pipeInput("B", "02", 8L); + proc.checkAndClearProcessResult(new KeyValueTimestamp<>("A", new Change<>(2, null), 10), + new KeyValueTimestamp<>("B", new Change<>(2, null), 8)); + + inputTopic1.pipeInput("A", "03", 20L); + proc.checkAndClearProcessResult(new KeyValueTimestamp<>("A", new Change<>(3, null), 20)); + + inputTopic1.pipeInput("A", (String) null, 30L); + proc.checkAndClearProcessResult(new KeyValueTimestamp<>("A", new Change<>(null, null), 30)); + } + } + + @Test + public void shouldEnableSendingOldValuesOnParentIfMapValuesNotMaterialized() { + final StreamsBuilder builder = new StreamsBuilder(); + final String topic1 = "topic1"; + + final KTableImpl table1 = + (KTableImpl) builder.table(topic1, consumed); + final KTableImpl table2 = + (KTableImpl) table1.mapValues(s -> Integer.valueOf(s)); + + table2.enableSendingOldValues(true); + + assertThat(table1.sendingOldValueEnabled(), is(true)); + assertThat(table2.sendingOldValueEnabled(), is(true)); + + testSendingOldValues(builder, topic1, table2); + } + + @Test + public void shouldNotEnableSendingOldValuesOnParentIfMapValuesMaterialized() { + final StreamsBuilder builder = new StreamsBuilder(); + final String topic1 = "topic1"; + + final KTableImpl table1 = + (KTableImpl) builder.table(topic1, consumed); + final KTableImpl table2 = + (KTableImpl) table1.mapValues( + s -> Integer.valueOf(s), + 
Materialized.>as("bob").withValueSerde(Serdes.Integer()) + ); + + table2.enableSendingOldValues(true); + + assertThat(table1.sendingOldValueEnabled(), is(false)); + assertThat(table2.sendingOldValueEnabled(), is(true)); + + testSendingOldValues(builder, topic1, table2); + } + + private void testSendingOldValues( + final StreamsBuilder builder, + final String topic1, + final KTableImpl table2 + ) { + final MockApiProcessorSupplier supplier = new MockApiProcessorSupplier<>(); + builder.build().addProcessor("proc", supplier, table2.name); + + try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) { + final TestInputTopic inputTopic1 = + driver.createInputTopic(topic1, new StringSerializer(), new StringSerializer(), Instant.ofEpochMilli(0L), Duration.ZERO); + final MockApiProcessor proc = supplier.theCapturedProcessor(); + + inputTopic1.pipeInput("A", "01", 5L); + inputTopic1.pipeInput("B", "01", 10L); + inputTopic1.pipeInput("C", "01", 15L); + proc.checkAndClearProcessResult( + new KeyValueTimestamp<>("A", new Change<>(1, null), 5), + new KeyValueTimestamp<>("B", new Change<>(1, null), 10), + new KeyValueTimestamp<>("C", new Change<>(1, null), 15) + ); + + inputTopic1.pipeInput("A", "02", 10L); + inputTopic1.pipeInput("B", "02", 8L); + proc.checkAndClearProcessResult( + new KeyValueTimestamp<>("A", new Change<>(2, 1), 10), + new KeyValueTimestamp<>("B", new Change<>(2, 1), 8) + ); + + inputTopic1.pipeInput("A", "03", 20L); + proc.checkAndClearProcessResult( + new KeyValueTimestamp<>("A", new Change<>(3, 2), 20) + ); + + inputTopic1.pipeInput("A", (String) null, 30L); + proc.checkAndClearProcessResult( + new KeyValueTimestamp<>("A", new Change<>(null, 3), 30) + ); + } + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KTableReduceTest.java b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KTableReduceTest.java new file mode 100644 index 0000000..89aa17a --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KTableReduceTest.java @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+package org.apache.kafka.streams.kstream.internals;
+
+import org.apache.kafka.streams.processor.api.Processor;
+import org.apache.kafka.streams.processor.api.Record;
+import org.apache.kafka.streams.processor.internals.ProcessorNode;
+import org.apache.kafka.streams.state.TimestampedKeyValueStore;
+import org.apache.kafka.streams.state.ValueAndTimestamp;
+import org.apache.kafka.test.GenericInMemoryTimestampedKeyValueStore;
+import org.apache.kafka.test.InternalMockProcessorContext;
+import org.junit.Test;
+
+import java.util.HashSet;
+import java.util.Set;
+
+import static java.util.Collections.emptySet;
+import static java.util.Collections.singleton;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
+
+public class KTableReduceTest {
+
+    @Test
+    public void shouldAddAndSubtract() {
+        final InternalMockProcessorContext<String, Change<Set<String>>> context = new InternalMockProcessorContext<>();
+
+        final Processor<String, Change<Set<String>>, String, Change<Set<String>>> reduceProcessor =
+            new KTableReduce<String, Set<String>>(
+                "myStore",
+                this::unionNotNullArgs,
+                this::differenceNotNullArgs
+            ).get();
+
+        final TimestampedKeyValueStore<String, Set<String>> myStore =
+            new GenericInMemoryTimestampedKeyValueStore<>("myStore");
+
+        context.register(myStore, null);
+        reduceProcessor.init(context);
+        context.setCurrentNode(new ProcessorNode<>("reduce", reduceProcessor, singleton("myStore")));
+
+        reduceProcessor.process(new Record<>("A", new Change<>(singleton("a"), null), 10L));
+        assertEquals(ValueAndTimestamp.make(singleton("a"), 10L), myStore.get("A"));
+        reduceProcessor.process(new Record<>("A", new Change<>(singleton("b"), singleton("a")), 15L));
+        assertEquals(ValueAndTimestamp.make(singleton("b"), 15L), myStore.get("A"));
+        reduceProcessor.process(new Record<>("A", new Change<>(null, singleton("b")), 12L));
+        assertEquals(ValueAndTimestamp.make(emptySet(), 15L), myStore.get("A"));
+    }
+
+    private Set<String> differenceNotNullArgs(final Set<String> left, final Set<String> right) {
+        assertNotNull(left);
+        assertNotNull(right);
+
+        final HashSet<String> strings = new HashSet<>(left);
+        strings.removeAll(right);
+        return strings;
+    }
+
+    private Set<String> unionNotNullArgs(final Set<String> left, final Set<String> right) {
+        assertNotNull(left);
+        assertNotNull(right);
+
+        final HashSet<String> strings = new HashSet<>();
+        strings.addAll(left);
+        strings.addAll(right);
+        return strings;
+    }
+}
diff --git a/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KTableSourceTest.java b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KTableSourceTest.java
new file mode 100644
index 0000000..83b8ac8
--- /dev/null
+++ b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KTableSourceTest.java
@@ -0,0 +1,358 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +package org.apache.kafka.streams.kstream.internals; + +import org.apache.kafka.common.serialization.IntegerDeserializer; +import org.apache.kafka.common.serialization.IntegerSerializer; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.serialization.StringDeserializer; +import org.apache.kafka.common.serialization.StringSerializer; +import org.apache.kafka.streams.KeyValueTimestamp; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.TestInputTopic; +import org.apache.kafka.streams.TestOutputTopic; +import org.apache.kafka.streams.Topology; +import org.apache.kafka.streams.TopologyTestDriver; +import org.apache.kafka.streams.TopologyTestDriverWrapper; +import org.apache.kafka.streams.TopologyWrapper; +import org.apache.kafka.streams.kstream.Consumed; +import org.apache.kafka.streams.kstream.KTable; +import org.apache.kafka.streams.kstream.Materialized; +import org.apache.kafka.streams.processor.internals.InternalTopologyBuilder; +import org.apache.kafka.streams.processor.internals.testutil.LogCaptureAppender; +import org.apache.kafka.streams.processor.internals.testutil.LogCaptureAppender.Event; +import org.apache.kafka.streams.state.ValueAndTimestamp; +import org.apache.kafka.streams.test.TestRecord; +import org.apache.kafka.test.MockApiProcessor; +import org.apache.kafka.test.MockApiProcessorSupplier; +import org.apache.kafka.test.MockProcessorSupplier; +import org.apache.kafka.test.StreamsTestUtils; +import org.junit.Ignore; +import org.junit.Test; + +import java.time.Duration; +import java.time.Instant; +import java.util.Properties; +import java.util.stream.Collectors; + +import static java.util.Arrays.asList; +import static org.apache.kafka.test.StreamsTestUtils.getMetricByName; +import static org.hamcrest.CoreMatchers.hasItem; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; + +public class KTableSourceTest { + private final Consumed stringConsumed = Consumed.with(Serdes.String(), Serdes.String()); + private final Properties props = StreamsTestUtils.getStreamsConfig(Serdes.String(), Serdes.String()); + + @SuppressWarnings("deprecation") // Old PAPI. Needs to be migrated. 
+ @Test + public void testKTable() { + final StreamsBuilder builder = new StreamsBuilder(); + final String topic1 = "topic1"; + + final KTable table1 = builder.table(topic1, Consumed.with(Serdes.String(), Serdes.Integer())); + + final MockProcessorSupplier supplier = new MockProcessorSupplier<>(); + table1.toStream().process(supplier); + + try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) { + final TestInputTopic inputTopic = + driver.createInputTopic(topic1, new StringSerializer(), new IntegerSerializer()); + inputTopic.pipeInput("A", 1, 10L); + inputTopic.pipeInput("B", 2, 11L); + inputTopic.pipeInput("C", 3, 12L); + inputTopic.pipeInput("D", 4, 13L); + inputTopic.pipeInput("A", null, 14L); + inputTopic.pipeInput("B", null, 15L); + } + + assertEquals( + asList(new KeyValueTimestamp<>("A", 1, 10L), + new KeyValueTimestamp<>("B", 2, 11L), + new KeyValueTimestamp<>("C", 3, 12L), + new KeyValueTimestamp<>("D", 4, 13L), + new KeyValueTimestamp<>("A", null, 14L), + new KeyValueTimestamp<>("B", null, 15L)), + supplier.theCapturedProcessor().processed()); + } + + @Ignore // we have disabled KIP-557 until KAFKA-12508 can be properly addressed + @Test + public void testKTableSourceEmitOnChange() { + final StreamsBuilder builder = new StreamsBuilder(); + final String topic1 = "topic1"; + + builder.table(topic1, Consumed.with(Serdes.String(), Serdes.Integer()), Materialized.as("store")) + .toStream() + .to("output"); + + try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) { + final TestInputTopic inputTopic = + driver.createInputTopic(topic1, new StringSerializer(), new IntegerSerializer()); + final TestOutputTopic outputTopic = + driver.createOutputTopic("output", new StringDeserializer(), new IntegerDeserializer()); + + inputTopic.pipeInput("A", 1, 10L); + inputTopic.pipeInput("B", 2, 11L); + inputTopic.pipeInput("A", 1, 12L); + inputTopic.pipeInput("B", 3, 13L); + // this record should be kept since this is out of order, so the timestamp + // should be updated in this scenario + inputTopic.pipeInput("A", 1, 9L); + + assertEquals( + 1.0, + getMetricByName(driver.metrics(), "idempotent-update-skip-total", "stream-processor-node-metrics").metricValue() + ); + + assertEquals( + asList(new TestRecord<>("A", 1, Instant.ofEpochMilli(10L)), + new TestRecord<>("B", 2, Instant.ofEpochMilli(11L)), + new TestRecord<>("B", 3, Instant.ofEpochMilli(13L)), + new TestRecord<>("A", 1, Instant.ofEpochMilli(9L))), + outputTopic.readRecordsToList() + ); + } + } + + @Test + public void kTableShouldLogAndMeterOnSkippedRecords() { + final StreamsBuilder builder = new StreamsBuilder(); + final String topic = "topic"; + builder.table(topic, stringConsumed); + + try (final LogCaptureAppender appender = LogCaptureAppender.createAndRegister(KTableSource.class); + final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) { + + final TestInputTopic inputTopic = + driver.createInputTopic( + topic, + new StringSerializer(), + new StringSerializer(), + Instant.ofEpochMilli(0L), + Duration.ZERO + ); + inputTopic.pipeInput(null, "value"); + + assertThat( + appender.getEvents().stream() + .filter(e -> e.getLevel().equals("WARN")) + .map(Event::getMessage) + .collect(Collectors.toList()), + hasItem("Skipping record due to null key. 
topic=[topic] partition=[0] offset=[0]") + ); + } + } + + @Test + public void kTableShouldLogOnOutOfOrder() { + final StreamsBuilder builder = new StreamsBuilder(); + final String topic = "topic"; + builder.table(topic, stringConsumed, Materialized.as("store")); + + try (final LogCaptureAppender appender = LogCaptureAppender.createAndRegister(KTableSource.class); + final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) { + + final TestInputTopic inputTopic = + driver.createInputTopic( + topic, + new StringSerializer(), + new StringSerializer(), + Instant.ofEpochMilli(0L), + Duration.ZERO + ); + inputTopic.pipeInput("key", "value", 10L); + inputTopic.pipeInput("key", "value", 5L); + + assertThat( + appender.getEvents().stream() + .filter(e -> e.getLevel().equals("WARN")) + .map(Event::getMessage) + .collect(Collectors.toList()), + hasItem("Detected out-of-order KTable update for store, old timestamp=[10] new timestamp=[5]. topic=[topic] partition=[1] offset=[0].") + ); + } + } + + @Test + public void testValueGetter() { + final StreamsBuilder builder = new StreamsBuilder(); + final String topic1 = "topic1"; + + @SuppressWarnings("unchecked") + final KTableImpl table1 = + (KTableImpl) builder.table(topic1, stringConsumed, Materialized.as("store")); + + final Topology topology = builder.build(); + final KTableValueGetterSupplier getterSupplier1 = table1.valueGetterSupplier(); + + final InternalTopologyBuilder topologyBuilder = TopologyWrapper.getInternalTopologyBuilder(topology); + topologyBuilder.connectProcessorAndStateStores(table1.name, getterSupplier1.storeNames()); + + try (final TopologyTestDriverWrapper driver = new TopologyTestDriverWrapper(builder.build(), props)) { + final TestInputTopic inputTopic1 = + driver.createInputTopic( + topic1, + new StringSerializer(), + new StringSerializer(), + Instant.ofEpochMilli(0L), + Duration.ZERO + ); + final KTableValueGetter getter1 = getterSupplier1.get(); + getter1.init(driver.setCurrentNodeForProcessorContext(table1.name)); + + inputTopic1.pipeInput("A", "01", 10L); + inputTopic1.pipeInput("B", "01", 20L); + inputTopic1.pipeInput("C", "01", 15L); + + assertEquals(ValueAndTimestamp.make("01", 10L), getter1.get("A")); + assertEquals(ValueAndTimestamp.make("01", 20L), getter1.get("B")); + assertEquals(ValueAndTimestamp.make("01", 15L), getter1.get("C")); + + inputTopic1.pipeInput("A", "02", 30L); + inputTopic1.pipeInput("B", "02", 5L); + + assertEquals(ValueAndTimestamp.make("02", 30L), getter1.get("A")); + assertEquals(ValueAndTimestamp.make("02", 5L), getter1.get("B")); + assertEquals(ValueAndTimestamp.make("01", 15L), getter1.get("C")); + + inputTopic1.pipeInput("A", "03", 29L); + + assertEquals(ValueAndTimestamp.make("03", 29L), getter1.get("A")); + assertEquals(ValueAndTimestamp.make("02", 5L), getter1.get("B")); + assertEquals(ValueAndTimestamp.make("01", 15L), getter1.get("C")); + + inputTopic1.pipeInput("A", null, 50L); + inputTopic1.pipeInput("B", null, 3L); + + assertNull(getter1.get("A")); + assertNull(getter1.get("B")); + assertEquals(ValueAndTimestamp.make("01", 15L), getter1.get("C")); + } + } + + @Test + public void testNotSendingOldValue() { + final StreamsBuilder builder = new StreamsBuilder(); + final String topic1 = "topic1"; + + @SuppressWarnings("unchecked") + final KTableImpl table1 = + (KTableImpl) builder.table(topic1, stringConsumed); + + final MockApiProcessorSupplier supplier = new MockApiProcessorSupplier<>(); + final Topology topology = builder.build().addProcessor("proc1", supplier, 
table1.name); + + try (final TopologyTestDriver driver = new TopologyTestDriver(topology, props)) { + final TestInputTopic inputTopic1 = + driver.createInputTopic( + topic1, + new StringSerializer(), + new StringSerializer(), + Instant.ofEpochMilli(0L), + Duration.ZERO + ); + final MockApiProcessor proc1 = supplier.theCapturedProcessor(); + + inputTopic1.pipeInput("A", "01", 10L); + inputTopic1.pipeInput("B", "01", 20L); + inputTopic1.pipeInput("C", "01", 15L); + proc1.checkAndClearProcessResult( + new KeyValueTimestamp<>("A", new Change<>("01", null), 10), + new KeyValueTimestamp<>("B", new Change<>("01", null), 20), + new KeyValueTimestamp<>("C", new Change<>("01", null), 15) + ); + + inputTopic1.pipeInput("A", "02", 8L); + inputTopic1.pipeInput("B", "02", 22L); + proc1.checkAndClearProcessResult( + new KeyValueTimestamp<>("A", new Change<>("02", null), 8), + new KeyValueTimestamp<>("B", new Change<>("02", null), 22) + ); + + inputTopic1.pipeInput("A", "03", 12L); + proc1.checkAndClearProcessResult( + new KeyValueTimestamp<>("A", new Change<>("03", null), 12) + ); + + inputTopic1.pipeInput("A", null, 15L); + inputTopic1.pipeInput("B", null, 20L); + proc1.checkAndClearProcessResult( + new KeyValueTimestamp<>("A", new Change<>(null, null), 15), + new KeyValueTimestamp<>("B", new Change<>(null, null), 20) + ); + } + } + + @Test + public void testSendingOldValue() { + final StreamsBuilder builder = new StreamsBuilder(); + final String topic1 = "topic1"; + + @SuppressWarnings("unchecked") + final KTableImpl table1 = + (KTableImpl) builder.table(topic1, stringConsumed); + table1.enableSendingOldValues(true); + assertTrue(table1.sendingOldValueEnabled()); + + final MockApiProcessorSupplier supplier = new MockApiProcessorSupplier<>(); + final Topology topology = builder.build().addProcessor("proc1", supplier, table1.name); + + try (final TopologyTestDriver driver = new TopologyTestDriver(topology, props)) { + final TestInputTopic inputTopic1 = + driver.createInputTopic( + topic1, + new StringSerializer(), + new StringSerializer(), + Instant.ofEpochMilli(0L), + Duration.ZERO + ); + final MockApiProcessor proc1 = supplier.theCapturedProcessor(); + + inputTopic1.pipeInput("A", "01", 10L); + inputTopic1.pipeInput("B", "01", 20L); + inputTopic1.pipeInput("C", "01", 15L); + proc1.checkAndClearProcessResult( + new KeyValueTimestamp<>("A", new Change<>("01", null), 10), + new KeyValueTimestamp<>("B", new Change<>("01", null), 20), + new KeyValueTimestamp<>("C", new Change<>("01", null), 15) + ); + + inputTopic1.pipeInput("A", "02", 8L); + inputTopic1.pipeInput("B", "02", 22L); + proc1.checkAndClearProcessResult( + new KeyValueTimestamp<>("A", new Change<>("02", "01"), 8), + new KeyValueTimestamp<>("B", new Change<>("02", "01"), 22) + ); + + inputTopic1.pipeInput("A", "03", 12L); + proc1.checkAndClearProcessResult( + new KeyValueTimestamp<>("A", new Change<>("03", "02"), 12) + ); + + inputTopic1.pipeInput("A", null, 15L); + inputTopic1.pipeInput("B", null, 20L); + proc1.checkAndClearProcessResult( + new KeyValueTimestamp<>("A", new Change<>(null, "03"), 15), + new KeyValueTimestamp<>("B", new Change<>(null, "02"), 20) + ); + } + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KTableTransformValuesTest.java b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KTableTransformValuesTest.java new file mode 100644 index 0000000..1690dc9 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/KTableTransformValuesTest.java @@ -0,0 
+1,597 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.kstream.internals; + +import org.apache.kafka.common.header.internals.RecordHeaders; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.serialization.StringSerializer; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.KeyValueTimestamp; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.TestInputTopic; +import org.apache.kafka.streams.TopologyTestDriver; +import org.apache.kafka.streams.kstream.Consumed; +import org.apache.kafka.streams.kstream.Grouped; +import org.apache.kafka.streams.kstream.KeyValueMapper; +import org.apache.kafka.streams.kstream.Materialized; +import org.apache.kafka.streams.kstream.ValueMapper; +import org.apache.kafka.streams.kstream.ValueTransformerWithKey; +import org.apache.kafka.streams.kstream.ValueTransformerWithKeySupplier; +import org.apache.kafka.streams.processor.ProcessorContext; +import org.apache.kafka.streams.processor.internals.ForwardingDisabledProcessorContext; +import org.apache.kafka.streams.processor.internals.InternalProcessorContext; +import org.apache.kafka.streams.processor.internals.ProcessorRecordContext; +import org.apache.kafka.streams.state.KeyValueStore; +import org.apache.kafka.streams.state.StoreBuilder; +import org.apache.kafka.streams.state.Stores; +import org.apache.kafka.streams.state.TimestampedKeyValueStore; +import org.apache.kafka.streams.state.ValueAndTimestamp; +import org.apache.kafka.test.MockProcessorSupplier; +import org.apache.kafka.test.MockReducer; +import org.apache.kafka.test.NoOpValueTransformerWithKeySupplier; +import org.apache.kafka.test.TestUtils; +import org.easymock.EasyMockRunner; +import org.easymock.Mock; +import org.easymock.MockType; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Properties; + +import static org.easymock.EasyMock.anyBoolean; +import static org.easymock.EasyMock.expect; +import static org.easymock.EasyMock.expectLastCall; +import static org.easymock.EasyMock.replay; +import static org.easymock.EasyMock.verify; +import static org.hamcrest.CoreMatchers.hasItems; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.CoreMatchers.isA; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.fail; + +@RunWith(EasyMockRunner.class) +@SuppressWarnings("deprecation") // Old PAPI. Needs to be migrated. 
+public class KTableTransformValuesTest { + private static final String QUERYABLE_NAME = "queryable-store"; + private static final String INPUT_TOPIC = "inputTopic"; + private static final String STORE_NAME = "someStore"; + private static final String OTHER_STORE_NAME = "otherStore"; + + private static final Consumed CONSUMED = Consumed.with(Serdes.String(), Serdes.String()); + + private TopologyTestDriver driver; + private MockProcessorSupplier capture; + private StreamsBuilder builder; + @Mock(MockType.NICE) + private KTableImpl parent; + @Mock(MockType.NICE) + private InternalProcessorContext context; + @Mock(MockType.NICE) + private KTableValueGetterSupplier parentGetterSupplier; + @Mock(MockType.NICE) + private KTableValueGetter parentGetter; + @Mock(MockType.NICE) + private TimestampedKeyValueStore stateStore; + @Mock(MockType.NICE) + private ValueTransformerWithKeySupplier mockSupplier; + @Mock(MockType.NICE) + private ValueTransformerWithKey transformer; + + @After + public void cleanup() { + if (driver != null) { + driver.close(); + driver = null; + } + } + + @Before + public void setUp() { + capture = new MockProcessorSupplier<>(); + builder = new StreamsBuilder(); + } + + @Test + public void shouldThrowOnGetIfSupplierReturnsNull() { + final KTableTransformValues transformer = + new KTableTransformValues<>(parent, new NullSupplier(), QUERYABLE_NAME); + + try { + transformer.get(); + fail("NPE expected"); + } catch (final NullPointerException expected) { + // expected + } + } + + @Test + public void shouldThrowOnViewGetIfSupplierReturnsNull() { + final KTableValueGetterSupplier view = + new KTableTransformValues<>(parent, new NullSupplier(), null).view(); + + try { + view.get(); + fail("NPE expected"); + } catch (final NullPointerException expected) { + // expected + } + } + + @SuppressWarnings("unchecked") + @Test + public void shouldInitializeTransformerWithForwardDisabledProcessorContext() { + final NoOpValueTransformerWithKeySupplier transformer = new NoOpValueTransformerWithKeySupplier<>(); + final KTableTransformValues transformValues = + new KTableTransformValues<>(parent, transformer, null); + final org.apache.kafka.streams.processor.Processor> processor = transformValues.get(); + + processor.init(context); + + assertThat(transformer.context, isA((Class) ForwardingDisabledProcessorContext.class)); + } + + @Test + public void shouldNotSendOldValuesByDefault() { + final KTableTransformValues transformValues = + new KTableTransformValues<>(parent, new ExclamationValueTransformerSupplier(), null); + + final org.apache.kafka.streams.processor.Processor> processor = transformValues.get(); + processor.init(context); + + context.forward("Key", new Change<>("Key->newValue!", null)); + expectLastCall(); + replay(context); + + processor.process("Key", new Change<>("newValue", "oldValue")); + + verify(context); + } + + @Test + public void shouldSendOldValuesIfConfigured() { + final KTableTransformValues transformValues = + new KTableTransformValues<>(parent, new ExclamationValueTransformerSupplier(), null); + + expect(parent.enableSendingOldValues(true)).andReturn(true); + replay(parent); + + transformValues.enableSendingOldValues(true); + final org.apache.kafka.streams.processor.Processor> processor = transformValues.get(); + processor.init(context); + + context.forward("Key", new Change<>("Key->newValue!", "Key->oldValue!")); + expectLastCall(); + replay(context); + + processor.process("Key", new Change<>("newValue", "oldValue")); + + verify(context); + } + + @Test + public void 
shouldNotSetSendOldValuesOnParentIfMaterialized() { + expect(parent.enableSendingOldValues(anyBoolean())) + .andThrow(new AssertionError("Should not call enableSendingOldValues")) + .anyTimes(); + + replay(parent); + + new KTableTransformValues<>(parent, new NoOpValueTransformerWithKeySupplier<>(), QUERYABLE_NAME).enableSendingOldValues(true); + + verify(parent); + } + + @Test + public void shouldSetSendOldValuesOnParentIfNotMaterialized() { + expect(parent.enableSendingOldValues(true)).andReturn(true); + replay(parent); + + new KTableTransformValues<>(parent, new NoOpValueTransformerWithKeySupplier<>(), null).enableSendingOldValues(true); + + verify(parent); + } + + @Test + public void shouldTransformOnGetIfNotMaterialized() { + final KTableTransformValues transformValues = + new KTableTransformValues<>(parent, new ExclamationValueTransformerSupplier(), null); + + expect(parent.valueGetterSupplier()).andReturn(parentGetterSupplier); + expect(parentGetterSupplier.get()).andReturn(parentGetter); + expect(parentGetter.get("Key")).andReturn(ValueAndTimestamp.make("Value", 73L)); + final ProcessorRecordContext recordContext = new ProcessorRecordContext( + 42L, + 23L, + -1, + "foo", + new RecordHeaders() + ); + expect(context.recordContext()).andReturn(recordContext); + context.setRecordContext(new ProcessorRecordContext( + 73L, + -1L, + -1, + null, + new RecordHeaders() + )); + expectLastCall(); + context.setRecordContext(recordContext); + expectLastCall(); + replay(parent, parentGetterSupplier, parentGetter, context); + + final KTableValueGetter getter = transformValues.view().get(); + getter.init(context); + + final String result = getter.get("Key").value(); + + assertThat(result, is("Key->Value!")); + verify(context); + } + + @Test + public void shouldGetFromStateStoreIfMaterialized() { + final KTableTransformValues transformValues = + new KTableTransformValues<>(parent, new ExclamationValueTransformerSupplier(), QUERYABLE_NAME); + + expect(context.getStateStore(QUERYABLE_NAME)).andReturn(stateStore); + expect(stateStore.get("Key")).andReturn(ValueAndTimestamp.make("something", 0L)); + replay(context, stateStore); + + final KTableValueGetter getter = transformValues.view().get(); + getter.init(context); + + final String result = getter.get("Key").value(); + + assertThat(result, is("something")); + } + + @Test + public void shouldGetStoreNamesFromParentIfNotMaterialized() { + final KTableTransformValues transformValues = + new KTableTransformValues<>(parent, new ExclamationValueTransformerSupplier(), null); + + expect(parent.valueGetterSupplier()).andReturn(parentGetterSupplier); + expect(parentGetterSupplier.storeNames()).andReturn(new String[]{"store1", "store2"}); + replay(parent, parentGetterSupplier); + + final String[] storeNames = transformValues.view().storeNames(); + + assertThat(storeNames, is(new String[]{"store1", "store2"})); + } + + @Test + public void shouldGetQueryableStoreNameIfMaterialized() { + final KTableTransformValues transformValues = + new KTableTransformValues<>(parent, new ExclamationValueTransformerSupplier(), QUERYABLE_NAME); + + final String[] storeNames = transformValues.view().storeNames(); + + assertThat(storeNames, is(new String[]{QUERYABLE_NAME})); + } + + @Test + public void shouldCloseTransformerOnProcessorClose() { + final KTableTransformValues transformValues = + new KTableTransformValues<>(parent, mockSupplier, null); + + expect(mockSupplier.get()).andReturn(transformer); + transformer.close(); + expectLastCall(); + replay(mockSupplier, transformer); + 
+ final org.apache.kafka.streams.processor.Processor> processor = transformValues.get(); + processor.close(); + + verify(transformer); + } + + @Test + public void shouldCloseTransformerOnGetterClose() { + final KTableTransformValues transformValues = + new KTableTransformValues<>(parent, mockSupplier, null); + + expect(mockSupplier.get()).andReturn(transformer); + expect(parentGetterSupplier.get()).andReturn(parentGetter); + expect(parent.valueGetterSupplier()).andReturn(parentGetterSupplier); + + transformer.close(); + expectLastCall(); + + replay(mockSupplier, transformer, parent, parentGetterSupplier); + + final KTableValueGetter getter = transformValues.view().get(); + getter.close(); + + verify(transformer); + } + + @Test + public void shouldCloseParentGetterClose() { + final KTableTransformValues transformValues = + new KTableTransformValues<>(parent, mockSupplier, null); + + expect(parent.valueGetterSupplier()).andReturn(parentGetterSupplier); + expect(mockSupplier.get()).andReturn(transformer); + expect(parentGetterSupplier.get()).andReturn(parentGetter); + + parentGetter.close(); + expectLastCall(); + + replay(mockSupplier, parent, parentGetterSupplier, parentGetter); + + final KTableValueGetter getter = transformValues.view().get(); + getter.close(); + + verify(parentGetter); + } + + @Test + public void shouldTransformValuesWithKey() { + builder + .addStateStore(storeBuilder(STORE_NAME)) + .addStateStore(storeBuilder(OTHER_STORE_NAME)) + .table(INPUT_TOPIC, CONSUMED) + .transformValues( + new ExclamationValueTransformerSupplier(STORE_NAME, OTHER_STORE_NAME), + STORE_NAME, OTHER_STORE_NAME) + .toStream() + .process(capture); + + driver = new TopologyTestDriver(builder.build(), props()); + final TestInputTopic inputTopic = + driver.createInputTopic(INPUT_TOPIC, new StringSerializer(), new StringSerializer()); + + inputTopic.pipeInput("A", "a", 5L); + inputTopic.pipeInput("B", "b", 10L); + inputTopic.pipeInput("D", null, 15L); + + + assertThat(output(), hasItems(new KeyValueTimestamp<>("A", "A->a!", 5), + new KeyValueTimestamp<>("B", "B->b!", 10), + new KeyValueTimestamp<>("D", "D->null!", 15) + )); + assertNull("Store should not be materialized", driver.getKeyValueStore(QUERYABLE_NAME)); + } + + @Test + public void shouldTransformValuesWithKeyAndMaterialize() { + builder + .addStateStore(storeBuilder(STORE_NAME)) + .table(INPUT_TOPIC, CONSUMED) + .transformValues( + new ExclamationValueTransformerSupplier(STORE_NAME, QUERYABLE_NAME), + Materialized.>as(QUERYABLE_NAME) + .withKeySerde(Serdes.String()) + .withValueSerde(Serdes.String()), + STORE_NAME) + .toStream() + .process(capture); + + driver = new TopologyTestDriver(builder.build(), props()); + final TestInputTopic inputTopic = + driver.createInputTopic(INPUT_TOPIC, new StringSerializer(), new StringSerializer()); + inputTopic.pipeInput("A", "a", 5L); + inputTopic.pipeInput("B", "b", 10L); + inputTopic.pipeInput("C", null, 15L); + + assertThat(output(), hasItems(new KeyValueTimestamp<>("A", "A->a!", 5), + new KeyValueTimestamp<>("B", "B->b!", 10), + new KeyValueTimestamp<>("C", "C->null!", 15))); + + { + final KeyValueStore keyValueStore = driver.getKeyValueStore(QUERYABLE_NAME); + assertThat(keyValueStore.get("A"), is("A->a!")); + assertThat(keyValueStore.get("B"), is("B->b!")); + assertThat(keyValueStore.get("C"), is("C->null!")); + } + { + final KeyValueStore> keyValueStore = driver.getTimestampedKeyValueStore(QUERYABLE_NAME); + assertThat(keyValueStore.get("A"), is(ValueAndTimestamp.make("A->a!", 5L))); + 
assertThat(keyValueStore.get("B"), is(ValueAndTimestamp.make("B->b!", 10L))); + assertThat(keyValueStore.get("C"), is(ValueAndTimestamp.make("C->null!", 15L))); + } + } + + @Test + public void shouldCalculateCorrectOldValuesIfMaterializedEvenIfStateful() { + builder + .table(INPUT_TOPIC, CONSUMED) + .transformValues( + new StatefulTransformerSupplier(), + Materialized.>as(QUERYABLE_NAME) + .withKeySerde(Serdes.String()) + .withValueSerde(Serdes.Integer())) + .groupBy(toForceSendingOfOldValues(), Grouped.with(Serdes.String(), Serdes.Integer())) + .reduce(MockReducer.INTEGER_ADDER, MockReducer.INTEGER_SUBTRACTOR) + .mapValues(mapBackToStrings()) + .toStream() + .process(capture); + + driver = new TopologyTestDriver(builder.build(), props()); + final TestInputTopic inputTopic = + driver.createInputTopic(INPUT_TOPIC, new StringSerializer(), new StringSerializer()); + + inputTopic.pipeInput("A", "ignored", 5L); + inputTopic.pipeInput("A", "ignored1", 15L); + inputTopic.pipeInput("A", "ignored2", 10L); + + assertThat(output(), hasItems(new KeyValueTimestamp<>("A", "1", 5), + new KeyValueTimestamp<>("A", "0", 15), + new KeyValueTimestamp<>("A", "2", 15), + new KeyValueTimestamp<>("A", "0", 15), + new KeyValueTimestamp<>("A", "3", 15))); + + final KeyValueStore keyValueStore = driver.getKeyValueStore(QUERYABLE_NAME); + assertThat(keyValueStore.get("A"), is(3)); + } + + @Test + public void shouldCalculateCorrectOldValuesIfNotStatefulEvenIfNotMaterialized() { + builder + .table(INPUT_TOPIC, CONSUMED) + .transformValues(new StatelessTransformerSupplier()) + .groupBy(toForceSendingOfOldValues(), Grouped.with(Serdes.String(), Serdes.Integer())) + .reduce(MockReducer.INTEGER_ADDER, MockReducer.INTEGER_SUBTRACTOR) + .mapValues(mapBackToStrings()) + .toStream() + .process(capture); + + driver = new TopologyTestDriver(builder.build(), props()); + final TestInputTopic inputTopic = + driver.createInputTopic(INPUT_TOPIC, new StringSerializer(), new StringSerializer()); + + inputTopic.pipeInput("A", "a", 5L); + inputTopic.pipeInput("A", "aa", 15L); + inputTopic.pipeInput("A", "aaa", 10); + + assertThat(output(), hasItems(new KeyValueTimestamp<>("A", "1", 5), + new KeyValueTimestamp<>("A", "0", 15), + new KeyValueTimestamp<>("A", "2", 15), + new KeyValueTimestamp<>("A", "0", 15), + new KeyValueTimestamp<>("A", "3", 15))); + } + + private ArrayList> output() { + return capture.capturedProcessors(1).get(0).processed(); + } + + private static KeyValueMapper> toForceSendingOfOldValues() { + return KeyValue::new; + } + + private static ValueMapper mapBackToStrings() { + return Object::toString; + } + + private static StoreBuilder> storeBuilder(final String storeName) { + return Stores.keyValueStoreBuilder(Stores.persistentKeyValueStore(storeName), Serdes.Long(), Serdes.Long()); + } + + public static Properties props() { + final Properties props = new Properties(); + props.setProperty(StreamsConfig.STATE_DIR_CONFIG, TestUtils.tempDirectory().getAbsolutePath()); + props.setProperty(StreamsConfig.DEFAULT_KEY_SERDE_CLASS_CONFIG, Serdes.Integer().getClass().getName()); + props.setProperty(StreamsConfig.DEFAULT_VALUE_SERDE_CLASS_CONFIG, Serdes.Integer().getClass().getName()); + return props; + } + + private static void throwIfStoresNotAvailable(final ProcessorContext context, + final List expectedStoredNames) { + final List missing = new ArrayList<>(); + + for (final String storedName : expectedStoredNames) { + if (context.getStateStore(storedName) == null) { + missing.add(storedName); + } + } + + if (!missing.isEmpty()) 
{ + throw new AssertionError("State stores are not accessible: " + missing); + } + } + + public static class ExclamationValueTransformerSupplier implements ValueTransformerWithKeySupplier { + private final List expectedStoredNames; + + ExclamationValueTransformerSupplier(final String... expectedStoreNames) { + this.expectedStoredNames = Arrays.asList(expectedStoreNames); + } + + @Override + public ExclamationValueTransformer get() { + return new ExclamationValueTransformer(expectedStoredNames); + } + } + + public static class ExclamationValueTransformer implements ValueTransformerWithKey { + private final List expectedStoredNames; + + ExclamationValueTransformer(final List expectedStoredNames) { + this.expectedStoredNames = expectedStoredNames; + } + + @Override + public void init(final ProcessorContext context) { + throwIfStoresNotAvailable(context, expectedStoredNames); + } + + @Override + public String transform(final Object readOnlyKey, final String value) { + return readOnlyKey.toString() + "->" + value + "!"; + } + + @Override + public void close() {} + } + + private static class NullSupplier implements ValueTransformerWithKeySupplier { + @Override + public ValueTransformerWithKey get() { + return null; + } + } + + private static class StatefulTransformerSupplier implements ValueTransformerWithKeySupplier { + @Override + public ValueTransformerWithKey get() { + return new StatefulTransformer(); + } + } + + private static class StatefulTransformer implements ValueTransformerWithKey { + private int counter; + + @Override + public void init(final ProcessorContext context) {} + + @Override + public Integer transform(final String readOnlyKey, final String value) { + return ++counter; + } + + @Override + public void close() {} + } + + private static class StatelessTransformerSupplier implements ValueTransformerWithKeySupplier { + @Override + public ValueTransformerWithKey get() { + return new StatelessTransformer(); + } + } + + private static class StatelessTransformer implements ValueTransformerWithKey { + @Override + public void init(final ProcessorContext context) {} + + @Override + public Integer transform(final String readOnlyKey, final String value) { + return value == null ? null : value.length(); + } + + @Override + public void close() {} + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/kstream/internals/MaterializedInternalTest.java b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/MaterializedInternalTest.java new file mode 100644 index 0000000..5d5e888 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/MaterializedInternalTest.java @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.kafka.streams.kstream.internals;
+
+import org.apache.kafka.common.utils.Bytes;
+import org.apache.kafka.streams.kstream.Materialized;
+import org.apache.kafka.streams.processor.StateStore;
+import org.apache.kafka.streams.state.KeyValueBytesStoreSupplier;
+import org.apache.kafka.streams.state.KeyValueStore;
+import org.easymock.EasyMock;
+import org.easymock.EasyMockRunner;
+import org.easymock.Mock;
+import org.easymock.MockType;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+
+import static org.hamcrest.CoreMatchers.equalTo;
+import static org.hamcrest.MatcherAssert.assertThat;
+
+@RunWith(EasyMockRunner.class)
+public class MaterializedInternalTest {
+
+    @Mock(type = MockType.NICE)
+    private InternalNameProvider nameProvider;
+
+    @Mock(type = MockType.NICE)
+    private KeyValueBytesStoreSupplier supplier;
+    private final String prefix = "prefix";
+
+    @Test
+    public void shouldGenerateStoreNameWithPrefixIfProvidedNameIsNull() {
+        final String generatedName = prefix + "-store";
+        EasyMock.expect(nameProvider.newStoreName(prefix)).andReturn(generatedName);
+
+        EasyMock.replay(nameProvider);
+
+        final MaterializedInternal<Object, Object, StateStore> materialized =
+            new MaterializedInternal<>(Materialized.with(null, null), nameProvider, prefix);
+
+        assertThat(materialized.storeName(), equalTo(generatedName));
+        EasyMock.verify(nameProvider);
+    }
+
+    @Test
+    public void shouldUseProvidedStoreNameWhenSet() {
+        final String storeName = "store-name";
+        final MaterializedInternal<Object, Object, StateStore> materialized =
+            new MaterializedInternal<>(Materialized.as(storeName), nameProvider, prefix);
+        assertThat(materialized.storeName(), equalTo(storeName));
+    }
+
+    @Test
+    public void shouldUseStoreNameOfSupplierWhenProvided() {
+        final String storeName = "other-store-name";
+        EasyMock.expect(supplier.name()).andReturn(storeName).anyTimes();
+        EasyMock.replay(supplier);
+        final MaterializedInternal<Object, Object, KeyValueStore<Bytes, byte[]>> materialized =
+            new MaterializedInternal<>(Materialized.as(supplier), nameProvider, prefix);
+        assertThat(materialized.storeName(), equalTo(storeName));
+    }
+}
\ No newline at end of file
diff --git a/streams/src/test/java/org/apache/kafka/streams/kstream/internals/NamedInternalTest.java b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/NamedInternalTest.java
new file mode 100644
index 0000000..1c4c700
--- /dev/null
+++ b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/NamedInternalTest.java
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +package org.apache.kafka.streams.kstream.internals; + +import org.junit.Test; + +import static org.junit.Assert.assertEquals; + +public class NamedInternalTest { + + private static final String TEST_PREFIX = "prefix-"; + private static final String TEST_VALUE = "default-value"; + private static final String TEST_SUFFIX = "-suffix"; + + private static class TestNameProvider implements InternalNameProvider { + int index = 0; + + @Override + public String newProcessorName(final String prefix) { + return prefix + "PROCESSOR-" + index++; + } + + @Override + public String newStoreName(final String prefix) { + return prefix + "STORE-" + index++; + } + + } + + @Test + public void shouldSuffixNameOrReturnProviderValue() { + final String name = "foo"; + final TestNameProvider provider = new TestNameProvider(); + + assertEquals( + name + TEST_SUFFIX, + NamedInternal.with(name).suffixWithOrElseGet(TEST_SUFFIX, provider, TEST_PREFIX) + ); + + // 1, not 0, indicates that the named call still burned an index number. + assertEquals( + "prefix-PROCESSOR-1", + NamedInternal.with(null).suffixWithOrElseGet(TEST_SUFFIX, provider, TEST_PREFIX) + ); + } + + @Test + public void shouldGenerateWithPrefixGivenEmptyName() { + final String prefix = "KSTREAM-MAP-"; + assertEquals(prefix + "PROCESSOR-0", NamedInternal.with(null).orElseGenerateWithPrefix( + new TestNameProvider(), + prefix) + ); + } + + @Test + public void shouldNotGenerateWithPrefixGivenValidName() { + final String validName = "validName"; + assertEquals(validName, NamedInternal.with(validName).orElseGenerateWithPrefix(new TestNameProvider(), "KSTREAM-MAP-") + ); + } +} \ No newline at end of file diff --git a/streams/src/test/java/org/apache/kafka/streams/kstream/internals/SessionCacheFlushListenerTest.java b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/SessionCacheFlushListenerTest.java new file mode 100644 index 0000000..da71149 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/SessionCacheFlushListenerTest.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.kstream.internals; + +import org.apache.kafka.streams.kstream.Windowed; +import org.apache.kafka.streams.processor.api.Record; +import org.apache.kafka.streams.processor.internals.InternalProcessorContext; +import org.junit.Test; + +import static org.easymock.EasyMock.expect; +import static org.easymock.EasyMock.expectLastCall; +import static org.easymock.EasyMock.mock; +import static org.easymock.EasyMock.replay; +import static org.easymock.EasyMock.verify; + +public class SessionCacheFlushListenerTest { + @Test + public void shouldForwardKeyNewValueOldValueAndTimestamp() { + final InternalProcessorContext, Change> context = mock(InternalProcessorContext.class); + expect(context.currentNode()).andReturn(null).anyTimes(); + context.setCurrentNode(null); + context.setCurrentNode(null); + context.forward( + new Record<>( + new Windowed<>("key", new SessionWindow(21L, 73L)), + new Change<>("newValue", "oldValue"), + 73L)); + expectLastCall(); + replay(context); + + new SessionCacheFlushListener<>(context).apply( + new Windowed<>("key", new SessionWindow(21L, 73L)), + "newValue", + "oldValue", + 42L); + + verify(context); + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/kstream/internals/SessionTupleForwarderTest.java b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/SessionTupleForwarderTest.java new file mode 100644 index 0000000..f0a963d --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/SessionTupleForwarderTest.java @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.kstream.internals; + +import org.apache.kafka.streams.kstream.Windowed; +import org.apache.kafka.streams.processor.StateStore; +import org.apache.kafka.streams.processor.api.ProcessorContext; +import org.apache.kafka.streams.processor.api.Record; +import org.apache.kafka.streams.state.internals.WrappedStateStore; +import org.junit.Test; + +import static org.easymock.EasyMock.expect; +import static org.easymock.EasyMock.expectLastCall; +import static org.easymock.EasyMock.mock; +import static org.easymock.EasyMock.replay; +import static org.easymock.EasyMock.verify; + +public class SessionTupleForwarderTest { + + @Test + public void shouldSetFlushListenerOnWrappedStateStore() { + setFlushListener(true); + setFlushListener(false); + } + + private void setFlushListener(final boolean sendOldValues) { + final WrappedStateStore, Object> store = mock(WrappedStateStore.class); + final SessionCacheFlushListener flushListener = mock(SessionCacheFlushListener.class); + + expect(store.setFlushListener(flushListener, sendOldValues)).andReturn(false); + replay(store); + + new SessionTupleForwarder<>(store, null, flushListener, sendOldValues); + + verify(store); + } + + @Test + public void shouldForwardRecordsIfWrappedStateStoreDoesNotCache() { + shouldForwardRecordsIfWrappedStateStoreDoesNotCache(false); + shouldForwardRecordsIfWrappedStateStoreDoesNotCache(true); + } + + private void shouldForwardRecordsIfWrappedStateStoreDoesNotCache(final boolean sendOldValued) { + final WrappedStateStore store = mock(WrappedStateStore.class); + final ProcessorContext, Change> context = mock( + ProcessorContext.class); + + expect(store.setFlushListener(null, sendOldValued)).andReturn(false); + if (sendOldValued) { + context.forward( + new Record<>( + new Windowed<>("key", new SessionWindow(21L, 42L)), + new Change<>("value", "oldValue"), + 42L)); + } else { + context.forward( + new Record<>( + new Windowed<>("key", new SessionWindow(21L, 42L)), + new Change<>("value", null), + 42L)); + } + expectLastCall(); + replay(store, context); + + new SessionTupleForwarder<>(store, context, null, + sendOldValued) + .maybeForward( + new Windowed<>("key", new SessionWindow(21L, 42L)), + "value", + "oldValue"); + + verify(store, context); + } + + @Test + public void shouldNotForwardRecordsIfWrappedStateStoreDoesCache() { + final WrappedStateStore store = mock(WrappedStateStore.class); + final ProcessorContext, Change> context = mock(ProcessorContext.class); + + expect(store.setFlushListener(null, false)).andReturn(true); + replay(store, context); + + new SessionTupleForwarder<>(store, context, null, false) + .maybeForward(new Windowed<>("key", new SessionWindow(21L, 42L)), "value", "oldValue"); + + verify(store, context); + } + +} diff --git a/streams/src/test/java/org/apache/kafka/streams/kstream/internals/SessionWindowTest.java b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/SessionWindowTest.java new file mode 100644 index 0000000..df23ddd --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/SessionWindowTest.java @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.kstream.internals; + +import org.junit.Test; + +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; + +public class SessionWindowTest { + + private final long start = 50; + private final long end = 100; + private final SessionWindow window = new SessionWindow(start, end); + private final TimeWindow timeWindow = new TimeWindow(start, end); + + @Test + public void shouldNotOverlapIfOtherWindowIsBeforeThisWindow() { + /* + * This: [-------] + * Other: [---] + */ + assertFalse(window.overlap(new SessionWindow(0, 25))); + assertFalse(window.overlap(new SessionWindow(0, start - 1))); + assertFalse(window.overlap(new SessionWindow(start - 1, start - 1))); + } + + @Test + public void shouldOverlapIfOtherWindowEndIsWithinThisWindow() { + /* + * This: [-------] + * Other: [---------] + */ + assertTrue(window.overlap(new SessionWindow(0, start))); + assertTrue(window.overlap(new SessionWindow(0, start + 1))); + assertTrue(window.overlap(new SessionWindow(0, 75))); + assertTrue(window.overlap(new SessionWindow(0, end - 1))); + assertTrue(window.overlap(new SessionWindow(0, end))); + + assertTrue(window.overlap(new SessionWindow(start - 1, start))); + assertTrue(window.overlap(new SessionWindow(start - 1, start + 1))); + assertTrue(window.overlap(new SessionWindow(start - 1, 75))); + assertTrue(window.overlap(new SessionWindow(start - 1, end - 1))); + assertTrue(window.overlap(new SessionWindow(start - 1, end))); + } + + @Test + public void shouldOverlapIfOtherWindowContainsThisWindow() { + /* + * This: [-------] + * Other: [------------------] + */ + assertTrue(window.overlap(new SessionWindow(0, end))); + assertTrue(window.overlap(new SessionWindow(0, end + 1))); + assertTrue(window.overlap(new SessionWindow(0, 150))); + + assertTrue(window.overlap(new SessionWindow(start - 1, end))); + assertTrue(window.overlap(new SessionWindow(start - 1, end + 1))); + assertTrue(window.overlap(new SessionWindow(start - 1, 150))); + + assertTrue(window.overlap(new SessionWindow(start, end))); + assertTrue(window.overlap(new SessionWindow(start, end + 1))); + assertTrue(window.overlap(new SessionWindow(start, 150))); + } + + @Test + public void shouldOverlapIfOtherWindowIsWithinThisWindow() { + /* + * This: [-------] + * Other: [---] + */ + assertTrue(window.overlap(new SessionWindow(start, start))); + assertTrue(window.overlap(new SessionWindow(start, 75))); + assertTrue(window.overlap(new SessionWindow(start, end))); + assertTrue(window.overlap(new SessionWindow(75, end))); + assertTrue(window.overlap(new SessionWindow(end, end))); + } + + @Test + public void shouldOverlapIfOtherWindowStartIsWithinThisWindow() { + /* + * This: [-------] + * Other: [-------] + */ + assertTrue(window.overlap(new SessionWindow(start, end + 1))); + assertTrue(window.overlap(new SessionWindow(start, 150))); + assertTrue(window.overlap(new SessionWindow(75, end + 1))); + assertTrue(window.overlap(new SessionWindow(75, 150))); + assertTrue(window.overlap(new SessionWindow(end, end + 1))); + 
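// Session window bounds are inclusive, so a window that starts exactly at this window's end still overlaps. + 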
assertTrue(window.overlap(new SessionWindow(end, 150))); + } + + @Test + public void shouldNotOverlapIsOtherWindowIsAfterThisWindow() { + /* + * This: [-------] + * Other: [---] + */ + assertFalse(window.overlap(new SessionWindow(end + 1, end + 1))); + assertFalse(window.overlap(new SessionWindow(end + 1, 150))); + assertFalse(window.overlap(new SessionWindow(125, 150))); + } + + @Test + public void cannotCompareSessionWindowWithDifferentWindowType() { + assertThrows(IllegalArgumentException.class, () -> window.overlap(timeWindow)); + } +} \ No newline at end of file diff --git a/streams/src/test/java/org/apache/kafka/streams/kstream/internals/SessionWindowedCogroupedKStreamImplTest.java b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/SessionWindowedCogroupedKStreamImplTest.java new file mode 100644 index 0000000..eee7cc5 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/SessionWindowedCogroupedKStreamImplTest.java @@ -0,0 +1,337 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.streams.kstream.internals; + +import static java.time.Duration.ofMillis; +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.Assert.assertThrows; + +import java.util.Properties; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.serialization.StringDeserializer; +import org.apache.kafka.common.serialization.StringSerializer; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.TestInputTopic; +import org.apache.kafka.streams.TestOutputTopic; +import org.apache.kafka.streams.TopologyTestDriver; +import org.apache.kafka.streams.kstream.CogroupedKStream; +import org.apache.kafka.streams.kstream.Consumed; +import org.apache.kafka.streams.kstream.Grouped; +import org.apache.kafka.streams.kstream.KGroupedStream; +import org.apache.kafka.streams.kstream.KStream; +import org.apache.kafka.streams.kstream.KTable; +import org.apache.kafka.streams.kstream.Materialized; +import org.apache.kafka.streams.kstream.Merger; +import org.apache.kafka.streams.kstream.Named; +import org.apache.kafka.streams.kstream.SessionWindowedCogroupedKStream; +import org.apache.kafka.streams.kstream.SessionWindowedDeserializer; +import org.apache.kafka.streams.kstream.SessionWindows; +import org.apache.kafka.streams.kstream.Windowed; +import org.apache.kafka.streams.test.TestRecord; +import org.apache.kafka.test.MockAggregator; +import org.apache.kafka.test.MockInitializer; +import org.apache.kafka.test.StreamsTestUtils; +import org.junit.Before; +import org.junit.Test; + +@SuppressWarnings("deprecation") +public class SessionWindowedCogroupedKStreamImplTest { + + private final StreamsBuilder builder = new StreamsBuilder(); + private static final String TOPIC = "topic"; + private static final String TOPIC2 = "topic2"; + private static final String OUTPUT = "output"; + + private final Merger sessionMerger = (aggKey, aggOne, aggTwo) -> aggOne + "+" + aggTwo; + + private KGroupedStream groupedStream; + private KGroupedStream groupedStream2; + private CogroupedKStream cogroupedStream; + private SessionWindowedCogroupedKStream windowedCogroupedStream; + + private final Properties props = StreamsTestUtils.getStreamsConfig(Serdes.String(), Serdes.String()); + + @Before + public void setup() { + final KStream stream = builder.stream(TOPIC, Consumed + .with(Serdes.String(), Serdes.String())); + final KStream stream2 = builder.stream(TOPIC2, Consumed + .with(Serdes.String(), Serdes.String())); + + groupedStream = stream.groupByKey(Grouped.with(Serdes.String(), Serdes.String())); + groupedStream2 = stream2.groupByKey(Grouped.with(Serdes.String(), Serdes.String())); + cogroupedStream = groupedStream.cogroup(MockAggregator.TOSTRING_ADDER) + .cogroup(groupedStream2, MockAggregator.TOSTRING_REMOVER); + windowedCogroupedStream = cogroupedStream.windowedBy(SessionWindows.with(ofMillis(100))); + } + + @Test + public void shouldNotHaveNullInitializerOnAggregate() { + assertThrows(NullPointerException.class, () -> windowedCogroupedStream.aggregate(null, sessionMerger)); + } + + @Test + public void shouldNotHaveNullSessionMergerOnAggregate() { + assertThrows(NullPointerException.class, () -> windowedCogroupedStream.aggregate(MockInitializer.STRING_INIT, null)); + } + + @Test + public void shouldNotHaveNullMaterializedOnAggregate() { + assertThrows(NullPointerException.class, () -> windowedCogroupedStream.aggregate(MockInitializer.STRING_INIT, + sessionMerger, (Named) 
null)); + } + + @Test + public void shouldNotHaveNullSessionMerger2OnAggregate() { + assertThrows(NullPointerException.class, () -> windowedCogroupedStream.aggregate(MockInitializer.STRING_INIT, + null, Materialized.as("test"))); + } + + @Test + public void shouldNotHaveNullInitializer2OnAggregate() { + assertThrows(NullPointerException.class, () -> windowedCogroupedStream.aggregate(null, sessionMerger, + Materialized.as("test"))); + } + + @Test + public void shouldNotHaveNullMaterialized2OnAggregate() { + assertThrows(NullPointerException.class, () -> windowedCogroupedStream.aggregate(MockInitializer.STRING_INIT, + sessionMerger, Named.as("name"), null)); + } + + @Test + public void shouldNotHaveNullSessionMerger3OnAggregate() { + assertThrows(NullPointerException.class, () -> windowedCogroupedStream.aggregate(MockInitializer.STRING_INIT, + null, Named.as("name"), Materialized.as("test"))); + } + + @Test + public void shouldNotHaveNullNamedOnAggregate() { + assertThrows(NullPointerException.class, () -> windowedCogroupedStream.aggregate(MockInitializer.STRING_INIT, + sessionMerger, null, Materialized.as("test"))); + } + + @Test + public void shouldNotHaveNullInitializer3OnAggregate() { + assertThrows(NullPointerException.class, () -> windowedCogroupedStream.aggregate(null, sessionMerger, + Named.as("name"), Materialized.as("test"))); + } + + @Test + public void shouldNotHaveNullNamed2OnAggregate() { + assertThrows(NullPointerException.class, () -> windowedCogroupedStream.aggregate(MockInitializer.STRING_INIT, sessionMerger, (Named) null)); + } + + @Test + public void namedParamShouldSetName() { + final StreamsBuilder builder = new StreamsBuilder(); + final KStream stream = builder.stream(TOPIC, Consumed + .with(Serdes.String(), Serdes.String())); + groupedStream = stream.groupByKey(Grouped.with(Serdes.String(), Serdes.String())); + groupedStream.cogroup(MockAggregator.TOSTRING_ADDER) + .windowedBy(SessionWindows.with(ofMillis(1))) + .aggregate(MockInitializer.STRING_INIT, sessionMerger, Named.as("foo")); + + assertThat(builder.build().describe().toString(), equalTo( + "Topologies:\n" + + " Sub-topology: 0\n" + + " Source: KSTREAM-SOURCE-0000000000 (topics: [topic])\n" + + " --> foo-cogroup-agg-0\n" + + " Processor: foo-cogroup-agg-0 (stores: [COGROUPKSTREAM-AGGREGATE-STATE-STORE-0000000001])\n" + + " --> foo-cogroup-merge\n" + + " <-- KSTREAM-SOURCE-0000000000\n" + + " Processor: foo-cogroup-merge (stores: [])\n" + + " --> none\n" + + " <-- foo-cogroup-agg-0\n\n")); + } + + @Test + public void sessionWindowAggregateTest() { + final KTable, String> customers = groupedStream.cogroup(MockAggregator.TOSTRING_ADDER) + .windowedBy(SessionWindows.with(ofMillis(500))) + .aggregate(MockInitializer.STRING_INIT, sessionMerger, Materialized.with(Serdes.String(), Serdes.String())); + customers.toStream().to(OUTPUT); + + try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) { + final TestInputTopic testInputTopic = driver.createInputTopic( + TOPIC, new StringSerializer(), new StringSerializer()); + final TestOutputTopic, String> testOutputTopic = driver.createOutputTopic( + OUTPUT, new SessionWindowedDeserializer<>(new StringDeserializer()), new StringDeserializer()); + testInputTopic.pipeInput("k1", "A", 0); + testInputTopic.pipeInput("k2", "A", 0); + testInputTopic.pipeInput("k1", "B", 599); + testInputTopic.pipeInput("k2", "B", 607); + + assertOutputKeyValueTimestamp(testOutputTopic, "k1", "0+A", 0); + assertOutputKeyValueTimestamp(testOutputTopic, "k2", "0+A", 0); + 
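// The 500 ms inactivity gap has already elapsed, so the records at 599 and 607 start new sessions instead of merging with the earlier ones. + 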
assertOutputKeyValueTimestamp(testOutputTopic, "k1", "0+B", 599); + assertOutputKeyValueTimestamp(testOutputTopic, "k2", "0+B", 607); + } + } + + @Test + public void sessionWindowAggregate2Test() { + final KTable, String> customers = groupedStream.cogroup(MockAggregator.TOSTRING_ADDER) + .windowedBy(SessionWindows.with(ofMillis(500))) + .aggregate(MockInitializer.STRING_INIT, sessionMerger, Materialized.with(Serdes.String(), Serdes.String())); + customers.toStream().to(OUTPUT); + + try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) { + final TestInputTopic testInputTopic = driver.createInputTopic( + TOPIC, new StringSerializer(), new StringSerializer()); + final TestOutputTopic, String> testOutputTopic = driver.createOutputTopic( + OUTPUT, new SessionWindowedDeserializer<>(new StringDeserializer()), new StringDeserializer()); + + testInputTopic.pipeInput("k1", "A", 0); + testInputTopic.pipeInput("k1", "A", 0); + testInputTopic.pipeInput("k2", "B", 599); + testInputTopic.pipeInput("k1", "B", 607); + + assertOutputKeyValueTimestamp(testOutputTopic, "k1", "0+A", 0); + assertOutputKeyValueTimestamp(testOutputTopic, "k1", "0+0+A+A", 0); + assertOutputKeyValueTimestamp(testOutputTopic, "k2", "0+B", 599); + assertOutputKeyValueTimestamp(testOutputTopic, "k1", "0+B", 607); + } + + + } + + @Test + public void sessionWindowAggregateTest2StreamsTest() { + final KTable, String> customers = windowedCogroupedStream.aggregate( + MockInitializer.STRING_INIT, sessionMerger, Materialized.with(Serdes.String(), Serdes.String())); + customers.toStream().to(OUTPUT); + + try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) { + final TestInputTopic testInputTopic = driver.createInputTopic( + TOPIC, new StringSerializer(), new StringSerializer()); + final TestOutputTopic, String> testOutputTopic = driver.createOutputTopic( + OUTPUT, new SessionWindowedDeserializer<>(new StringDeserializer()), new StringDeserializer()); + + testInputTopic.pipeInput("k1", "A", 0); + testInputTopic.pipeInput("k1", "A", 84); + testInputTopic.pipeInput("k1", "A", 113); + testInputTopic.pipeInput("k1", "A", 199); + testInputTopic.pipeInput("k1", "B", 300); + testInputTopic.pipeInput("k2", "B", 301); + testInputTopic.pipeInput("k2", "B", 400); + testInputTopic.pipeInput("k1", "B", 400); + + assertOutputKeyValueTimestamp(testOutputTopic, "k1", "0+A", 0); + assertOutputKeyValueTimestamp(testOutputTopic, "k1", null, 0); + assertOutputKeyValueTimestamp(testOutputTopic, "k1", "0+0+A+A", 84); + assertOutputKeyValueTimestamp(testOutputTopic, "k1", null, 84); + assertOutputKeyValueTimestamp(testOutputTopic, "k1", "0+0+0+A+A+A", 113); + assertOutputKeyValueTimestamp(testOutputTopic, "k1", null, 113); + assertOutputKeyValueTimestamp(testOutputTopic, "k1", "0+0+0+0+A+A+A+A", 199); + assertOutputKeyValueTimestamp(testOutputTopic, "k1", "0+B", 300); + assertOutputKeyValueTimestamp(testOutputTopic, "k2", "0+B", 301); + assertOutputKeyValueTimestamp(testOutputTopic, "k2", null, 301); + assertOutputKeyValueTimestamp(testOutputTopic, "k2", "0+0+B+B", 400); + assertOutputKeyValueTimestamp(testOutputTopic, "k1", null, 300); + assertOutputKeyValueTimestamp(testOutputTopic, "k1", "0+0+B+B", 400); + + } + } + + @Test + public void sessionWindowMixAggregatorsTest() { + final KTable, String> customers = windowedCogroupedStream.aggregate( + MockInitializer.STRING_INIT, sessionMerger, Materialized.with(Serdes.String(), Serdes.String())); + customers.toStream().to(OUTPUT); + + try (final 
TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) { + final TestInputTopic testInputTopic = driver.createInputTopic(TOPIC, new StringSerializer(), new StringSerializer()); + final TestInputTopic testInputTopic2 = driver.createInputTopic(TOPIC2, new StringSerializer(), new StringSerializer()); + + final TestOutputTopic, String> testOutputTopic = driver.createOutputTopic( + OUTPUT, new SessionWindowedDeserializer<>(new StringDeserializer()), new StringDeserializer()); + testInputTopic.pipeInput("k1", "A", 0); + testInputTopic.pipeInput("k2", "A", 0); + testInputTopic.pipeInput("k2", "A", 1); + testInputTopic.pipeInput("k1", "A", 2); + testInputTopic2.pipeInput("k1", "B", 3); + testInputTopic2.pipeInput("k2", "B", 3); + testInputTopic2.pipeInput("k2", "B", 444); + testInputTopic2.pipeInput("k1", "B", 444); + + assertOutputKeyValueTimestamp(testOutputTopic, "k1", "0+A", 0); + assertOutputKeyValueTimestamp(testOutputTopic, "k2", "0+A", 0); + assertOutputKeyValueTimestamp(testOutputTopic, "k2", null, 0); + assertOutputKeyValueTimestamp(testOutputTopic, "k2", "0+0+A+A", 1); + assertOutputKeyValueTimestamp(testOutputTopic, "k1", null, 0); + assertOutputKeyValueTimestamp(testOutputTopic, "k1", "0+0+A+A", 2); + assertOutputKeyValueTimestamp(testOutputTopic, "k1", null, 2); + assertOutputKeyValueTimestamp(testOutputTopic, "k1", "0+0+0+A+A-B", 3); + assertOutputKeyValueTimestamp(testOutputTopic, "k2", null, 1); + assertOutputKeyValueTimestamp(testOutputTopic, "k2", "0+0+0+A+A-B", 3); + assertOutputKeyValueTimestamp(testOutputTopic, "k2", "0-B", 444); + assertOutputKeyValueTimestamp(testOutputTopic, "k1", "0-B", 444); + } + + } + + @Test + public void sessionWindowMixAggregatorsManyWindowsTest() { + final KTable, String> customers = windowedCogroupedStream.aggregate( + MockInitializer.STRING_INIT, sessionMerger, Materialized.with(Serdes.String(), Serdes.String())); + customers.toStream().to(OUTPUT); + + try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) { + final TestInputTopic testInputTopic = driver.createInputTopic(TOPIC, new StringSerializer(), new StringSerializer()); + final TestInputTopic testInputTopic2 = driver.createInputTopic(TOPIC2, new StringSerializer(), new StringSerializer()); + final TestOutputTopic, String> testOutputTopic = driver.createOutputTopic( + OUTPUT, new SessionWindowedDeserializer<>(new StringDeserializer()), new StringDeserializer()); + testInputTopic.pipeInput("k1", "A", 0); + testInputTopic.pipeInput("k2", "A", 0); + testInputTopic.pipeInput("k2", "A", 1); + testInputTopic.pipeInput("k1", "A", 2); + testInputTopic2.pipeInput("k1", "B", 3); + testInputTopic2.pipeInput("k2", "B", 500); + testInputTopic2.pipeInput("k2", "B", 501); + testInputTopic2.pipeInput("k1", "B", 501); + + assertOutputKeyValueTimestamp(testOutputTopic, "k1", "0+A", 0); + assertOutputKeyValueTimestamp(testOutputTopic, "k2", "0+A", 0); + assertOutputKeyValueTimestamp(testOutputTopic, "k2", null, 0); + assertOutputKeyValueTimestamp(testOutputTopic, "k2", "0+0+A+A", 1); + assertOutputKeyValueTimestamp(testOutputTopic, "k1", null, 0); + assertOutputKeyValueTimestamp(testOutputTopic, "k1", "0+0+A+A", 2); + assertOutputKeyValueTimestamp(testOutputTopic, "k1", null, 2); + assertOutputKeyValueTimestamp(testOutputTopic, "k1", "0+0+0+A+A-B", 3); + assertOutputKeyValueTimestamp(testOutputTopic, "k2", "0-B", 500); + assertOutputKeyValueTimestamp(testOutputTopic, "k2", null, 500); + assertOutputKeyValueTimestamp(testOutputTopic, "k2", "0+0-B-B", 501); + 
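// Records piped to TOPIC2 run through MockAggregator.TOSTRING_REMOVER, which is why they show up as "-B" in the aggregates. + 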
assertOutputKeyValueTimestamp(testOutputTopic, "k1", "0-B", 501); + } + + } + + private void assertOutputKeyValueTimestamp(final TestOutputTopic, String> outputTopic, + final String expectedKey, + final String expectedValue, + final long expectedTimestamp) { + final TestRecord, String> realRecord = outputTopic.readRecord(); + final TestRecord nonWindowedRecord = new TestRecord<>( + realRecord.getKey().key(), realRecord.getValue(), null, realRecord.timestamp()); + final TestRecord testRecord = new TestRecord<>(expectedKey, expectedValue, null, expectedTimestamp); + assertThat(nonWindowedRecord, equalTo(testRecord)); + } + +} diff --git a/streams/src/test/java/org/apache/kafka/streams/kstream/internals/SessionWindowedKStreamImplTest.java b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/SessionWindowedKStreamImplTest.java new file mode 100644 index 0000000..f26fbc6 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/SessionWindowedKStreamImplTest.java @@ -0,0 +1,305 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.streams.kstream.internals; + +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.serialization.StringSerializer; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.TopologyTestDriver; +import org.apache.kafka.streams.kstream.Consumed; +import org.apache.kafka.streams.kstream.Grouped; +import org.apache.kafka.streams.kstream.KStream; +import org.apache.kafka.streams.kstream.Materialized; +import org.apache.kafka.streams.kstream.Merger; +import org.apache.kafka.streams.kstream.Named; +import org.apache.kafka.streams.kstream.SessionWindowedKStream; +import org.apache.kafka.streams.kstream.SessionWindows; +import org.apache.kafka.streams.kstream.Windowed; +import org.apache.kafka.streams.state.SessionStore; +import org.apache.kafka.streams.state.ValueAndTimestamp; +import org.apache.kafka.streams.TestInputTopic; +import org.apache.kafka.test.MockAggregator; +import org.apache.kafka.test.MockInitializer; +import org.apache.kafka.test.MockProcessorSupplier; +import org.apache.kafka.test.MockReducer; +import org.apache.kafka.test.StreamsTestUtils; +import org.junit.Before; +import org.junit.Test; + +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.Properties; + +import static java.time.Duration.ofMillis; +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.Assert.assertThrows; + +@SuppressWarnings("deprecation") // Old PAPI. Needs to be migrated. 
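+// Exercises count, reduce, and aggregate over a session-windowed stream with a 500 ms inactivity gap; see processData for the input records.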
+public class SessionWindowedKStreamImplTest { + private static final String TOPIC = "input"; + private final StreamsBuilder builder = new StreamsBuilder(); + private final Properties props = StreamsTestUtils.getStreamsConfig(Serdes.String(), Serdes.String()); + private final Merger sessionMerger = (aggKey, aggOne, aggTwo) -> aggOne + "+" + aggTwo; + private SessionWindowedKStream stream; + + @Before + public void before() { + final KStream stream = builder.stream(TOPIC, Consumed.with(Serdes.String(), Serdes.String())); + this.stream = stream.groupByKey(Grouped.with(Serdes.String(), Serdes.String())) + .windowedBy(SessionWindows.with(ofMillis(500))); + } + + @Test + public void shouldCountSessionWindowedWithCachingDisabled() { + props.put(StreamsConfig.CACHE_MAX_BYTES_BUFFERING_CONFIG, 0); + shouldCountSessionWindowed(); + } + + @Test + public void shouldCountSessionWindowedWithCachingEnabled() { + shouldCountSessionWindowed(); + } + + private void shouldCountSessionWindowed() { + final MockProcessorSupplier, Long> supplier = new MockProcessorSupplier<>(); + stream.count() + .toStream() + .process(supplier); + + try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) { + processData(driver); + } + + final Map, ValueAndTimestamp> result = + supplier.theCapturedProcessor().lastValueAndTimestampPerKey(); + + assertThat(result.size(), equalTo(3)); + assertThat( + result.get(new Windowed<>("1", new SessionWindow(10L, 15L))), + equalTo(ValueAndTimestamp.make(2L, 15L))); + assertThat( + result.get(new Windowed<>("2", new SessionWindow(599L, 600L))), + equalTo(ValueAndTimestamp.make(2L, 600L))); + assertThat( + result.get(new Windowed<>("1", new SessionWindow(600L, 600L))), + equalTo(ValueAndTimestamp.make(1L, 600L))); + } + + @Test + public void shouldReduceWindowed() { + final MockProcessorSupplier, String> supplier = new MockProcessorSupplier<>(); + stream.reduce(MockReducer.STRING_ADDER) + .toStream() + .process(supplier); + + try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) { + processData(driver); + } + + final Map, ValueAndTimestamp> result = + supplier.theCapturedProcessor().lastValueAndTimestampPerKey(); + + assertThat(result.size(), equalTo(3)); + assertThat( + result.get(new Windowed<>("1", new SessionWindow(10, 15))), + equalTo(ValueAndTimestamp.make("1+2", 15L))); + assertThat( + result.get(new Windowed<>("2", new SessionWindow(599L, 600))), + equalTo(ValueAndTimestamp.make("1+2", 600L))); + assertThat( + result.get(new Windowed<>("1", new SessionWindow(600, 600))), + equalTo(ValueAndTimestamp.make("3", 600L))); + } + + @Test + public void shouldAggregateSessionWindowed() { + final MockProcessorSupplier, String> supplier = new MockProcessorSupplier<>(); + stream.aggregate(MockInitializer.STRING_INIT, + MockAggregator.TOSTRING_ADDER, + sessionMerger, + Materialized.with(Serdes.String(), Serdes.String())) + .toStream() + .process(supplier); + try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) { + processData(driver); + } + + final Map, ValueAndTimestamp> result = + supplier.theCapturedProcessor().lastValueAndTimestampPerKey(); + + assertThat(result.size(), equalTo(3)); + assertThat( + result.get(new Windowed<>("1", new SessionWindow(10, 15))), + equalTo(ValueAndTimestamp.make("0+0+1+2", 15L))); + assertThat( + result.get(new Windowed<>("2", new SessionWindow(599, 600))), + equalTo(ValueAndTimestamp.make("0+0+1+2", 600L))); + assertThat( + result.get(new Windowed<>("1", new 
SessionWindow(600, 600))), + equalTo(ValueAndTimestamp.make("0+3", 600L))); + } + + @Test + public void shouldMaterializeCount() { + stream.count(Materialized.as("count-store")); + + try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) { + processData(driver); + final SessionStore store = driver.getSessionStore("count-store"); + final List, Long>> data = StreamsTestUtils.toList(store.fetch("1", "2")); + assertThat( + data, + equalTo(Arrays.asList( + KeyValue.pair(new Windowed<>("1", new SessionWindow(10, 15)), 2L), + KeyValue.pair(new Windowed<>("1", new SessionWindow(600, 600)), 1L), + KeyValue.pair(new Windowed<>("2", new SessionWindow(599, 600)), 2L)))); + } + } + + @Test + public void shouldMaterializeReduced() { + stream.reduce(MockReducer.STRING_ADDER, Materialized.as("reduced")); + + try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) { + processData(driver); + final SessionStore sessionStore = driver.getSessionStore("reduced"); + final List, String>> data = StreamsTestUtils.toList(sessionStore.fetch("1", "2")); + + assertThat( + data, + equalTo(Arrays.asList( + KeyValue.pair(new Windowed<>("1", new SessionWindow(10, 15)), "1+2"), + KeyValue.pair(new Windowed<>("1", new SessionWindow(600, 600)), "3"), + KeyValue.pair(new Windowed<>("2", new SessionWindow(599, 600)), "1+2")))); + } + } + + @Test + public void shouldMaterializeAggregated() { + stream.aggregate( + MockInitializer.STRING_INIT, + MockAggregator.TOSTRING_ADDER, + sessionMerger, + Materialized.>as("aggregated").withValueSerde(Serdes.String())); + + try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) { + processData(driver); + final SessionStore sessionStore = driver.getSessionStore("aggregated"); + final List, String>> data = StreamsTestUtils.toList(sessionStore.fetch("1", "2")); + assertThat( + data, + equalTo(Arrays.asList( + KeyValue.pair(new Windowed<>("1", new SessionWindow(10, 15)), "0+0+1+2"), + KeyValue.pair(new Windowed<>("1", new SessionWindow(600, 600)), "0+3"), + KeyValue.pair(new Windowed<>("2", new SessionWindow(599, 600)), "0+0+1+2")))); + } + } + + @Test + public void shouldThrowNullPointerOnAggregateIfInitializerIsNull() { + assertThrows(NullPointerException.class, () -> stream.aggregate(null, MockAggregator.TOSTRING_ADDER, sessionMerger)); + } + + @Test + public void shouldThrowNullPointerOnAggregateIfAggregatorIsNull() { + assertThrows(NullPointerException.class, () -> stream.aggregate(MockInitializer.STRING_INIT, null, sessionMerger)); + } + + @Test + public void shouldThrowNullPointerOnAggregateIfMergerIsNull() { + assertThrows(NullPointerException.class, () -> stream.aggregate(MockInitializer.STRING_INIT, MockAggregator.TOSTRING_ADDER, null)); + } + + @Test + public void shouldThrowNullPointerOnReduceIfReducerIsNull() { + assertThrows(NullPointerException.class, () -> stream.reduce(null)); + } + + @Test + public void shouldThrowNullPointerOnMaterializedAggregateIfInitializerIsNull() { + assertThrows(NullPointerException.class, () -> stream.aggregate( + null, + MockAggregator.TOSTRING_ADDER, + sessionMerger, + Materialized.as("store"))); + } + + @Test + public void shouldThrowNullPointerOnMaterializedAggregateIfAggregatorIsNull() { + assertThrows(NullPointerException.class, () -> stream.aggregate( + MockInitializer.STRING_INIT, + null, + sessionMerger, + Materialized.as("store"))); + } + + @Test + public void shouldThrowNullPointerOnMaterializedAggregateIfMergerIsNull() { + 
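// The null third argument is the session merger, which must be rejected with an NPE. + 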
assertThrows(NullPointerException.class, () -> stream.aggregate( + MockInitializer.STRING_INIT, + MockAggregator.TOSTRING_ADDER, + null, + Materialized.as("store"))); + } + + @SuppressWarnings("unchecked") + @Test + public void shouldThrowNullPointerOnMaterializedAggregateIfMaterializedIsNull() { + assertThrows(NullPointerException.class, () -> stream.aggregate( + MockInitializer.STRING_INIT, + MockAggregator.TOSTRING_ADDER, + sessionMerger, + (Materialized) null)); + } + + @Test + public void shouldThrowNullPointerOnMaterializedReduceIfReducerIsNull() { + assertThrows(NullPointerException.class, () -> stream.reduce(null, Materialized.as("store"))); + } + + @Test + @SuppressWarnings("unchecked") + public void shouldThrowNullPointerOnMaterializedReduceIfMaterializedIsNull() { + assertThrows(NullPointerException.class, () -> stream.reduce(MockReducer.STRING_ADDER, (Materialized) null)); + } + + @Test + public void shouldThrowNullPointerOnMaterializedReduceIfNamedIsNull() { + assertThrows(NullPointerException.class, () -> stream.reduce(MockReducer.STRING_ADDER, (Named) null)); + } + + @Test + public void shouldThrowNullPointerOnCountIfMaterializedIsNull() { + assertThrows(NullPointerException.class, () -> stream.count((Materialized>) null)); + } + + private void processData(final TopologyTestDriver driver) { + final TestInputTopic inputTopic = + driver.createInputTopic(TOPIC, new StringSerializer(), new StringSerializer()); + inputTopic.pipeInput("1", "1", 10); + inputTopic.pipeInput("1", "2", 15); + inputTopic.pipeInput("1", "3", 600); + inputTopic.pipeInput("2", "1", 600); + inputTopic.pipeInput("2", "2", 599); + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/kstream/internals/SlidingWindowedCogroupedKStreamImplTest.java b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/SlidingWindowedCogroupedKStreamImplTest.java new file mode 100644 index 0000000..52ff858 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/SlidingWindowedCogroupedKStreamImplTest.java @@ -0,0 +1,265 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.streams.kstream.internals; + +import static java.time.Duration.ofMillis; +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; + +import java.util.LinkedList; +import java.util.List; +import java.util.Properties; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.serialization.StringDeserializer; +import org.apache.kafka.common.serialization.StringSerializer; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.TestInputTopic; +import org.apache.kafka.streams.TestOutputTopic; +import org.apache.kafka.streams.TopologyTestDriver; +import org.apache.kafka.streams.kstream.CogroupedKStream; +import org.apache.kafka.streams.kstream.Consumed; +import org.apache.kafka.streams.kstream.Grouped; +import org.apache.kafka.streams.kstream.KGroupedStream; +import org.apache.kafka.streams.kstream.KStream; +import org.apache.kafka.streams.kstream.KTable; +import org.apache.kafka.streams.kstream.Materialized; +import org.apache.kafka.streams.kstream.Named; +import org.apache.kafka.streams.kstream.SlidingWindows; +import org.apache.kafka.streams.kstream.TimeWindowedCogroupedKStream; +import org.apache.kafka.streams.kstream.TimeWindowedDeserializer; +import org.apache.kafka.streams.kstream.Windowed; +import org.apache.kafka.streams.state.WindowStore; +import org.apache.kafka.streams.test.TestRecord; +import org.apache.kafka.test.MockAggregator; +import org.apache.kafka.test.MockInitializer; +import org.apache.kafka.test.StreamsTestUtils; +import org.junit.Before; +import org.junit.Test; + +@SuppressWarnings("deprecation") +public class SlidingWindowedCogroupedKStreamImplTest { + + private static final String TOPIC = "topic"; + private static final String TOPIC2 = "topic2"; + private static final String OUTPUT = "output"; + private static final long WINDOW_SIZE_MS = 500L; + private final StreamsBuilder builder = new StreamsBuilder(); + + private KGroupedStream groupedStream; + + private TimeWindowedCogroupedKStream windowedCogroupedStream; + + private final Properties props = StreamsTestUtils.getStreamsConfig(Serdes.String(), Serdes.String()); + + @Before + public void setup() { + final KStream stream = builder.stream(TOPIC, Consumed + .with(Serdes.String(), Serdes.String())); + final KStream stream2 = builder.stream(TOPIC2, Consumed + .with(Serdes.String(), Serdes.String())); + + groupedStream = stream.groupByKey(Grouped.with(Serdes.String(), Serdes.String())); + final KGroupedStream groupedStream2 = stream2.groupByKey(Grouped.with(Serdes.String(), Serdes.String())); + final CogroupedKStream cogroupedStream = groupedStream.cogroup(MockAggregator.TOSTRING_ADDER) + .cogroup(groupedStream2, MockAggregator.TOSTRING_REMOVER); + windowedCogroupedStream = cogroupedStream.windowedBy(SlidingWindows.withTimeDifferenceAndGrace(ofMillis( + WINDOW_SIZE_MS), ofMillis(2000L))); + } + + @Test + public void shouldNotHaveNullInitializerOnAggregate() { + assertThrows(NullPointerException.class, () -> windowedCogroupedStream.aggregate(null)); + } + + @Test + public void shouldNotHaveNullMaterializedOnTwoOptionAggregate() { + assertThrows(NullPointerException.class, () -> windowedCogroupedStream.aggregate(MockInitializer.STRING_INIT, (Materialized>) null)); + } + + @Test + public void shouldNotHaveNullNamedTwoOptionOnAggregate() { + 
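// The (Named) cast selects the two-argument overload, so the NPE check targets the Named parameter. + 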
assertThrows(NullPointerException.class, () -> windowedCogroupedStream.aggregate(MockInitializer.STRING_INIT, (Named) null)); + } + + @Test + public void shouldNotHaveNullInitializerTwoOptionNamedOnAggregate() { + assertThrows(NullPointerException.class, () -> windowedCogroupedStream.aggregate(null, Named.as("test"))); + } + + @Test + public void shouldNotHaveNullInitializerTwoOptionMaterializedOnAggregate() { + assertThrows(NullPointerException.class, () -> windowedCogroupedStream.aggregate(null, Materialized.as("test"))); + } + + @Test + public void shouldNotHaveNullInitializerThreeOptionOnAggregate() { + assertThrows(NullPointerException.class, () -> windowedCogroupedStream.aggregate(null, Named.as("test"), Materialized.as("test"))); + } + + @Test + public void shouldNotHaveNullMaterializedOnAggregate() { + assertThrows(NullPointerException.class, () -> windowedCogroupedStream.aggregate(MockInitializer.STRING_INIT, Named.as("Test"), null)); + } + + @Test + public void shouldNotHaveNullNamedOnAggregate() { + assertThrows(NullPointerException.class, () -> windowedCogroupedStream.aggregate(MockInitializer.STRING_INIT, null, Materialized.as("test"))); + } + + @Test + public void namedParamShouldSetName() { + final StreamsBuilder builder = new StreamsBuilder(); + final KStream stream = builder.stream(TOPIC, Consumed + .with(Serdes.String(), Serdes.String())); + groupedStream = stream.groupByKey(Grouped.with(Serdes.String(), Serdes.String())); + groupedStream.cogroup(MockAggregator.TOSTRING_ADDER) + .windowedBy(SlidingWindows.withTimeDifferenceAndGrace(ofMillis(WINDOW_SIZE_MS), ofMillis(2000L))) + .aggregate(MockInitializer.STRING_INIT, Named.as("foo")); + + assertThat(builder.build().describe().toString(), equalTo( + "Topologies:\n" + + " Sub-topology: 0\n" + + " Source: KSTREAM-SOURCE-0000000000 (topics: [topic])\n" + + " --> foo-cogroup-agg-0\n" + + " Processor: foo-cogroup-agg-0 (stores: [COGROUPKSTREAM-AGGREGATE-STATE-STORE-0000000001])\n" + + " --> foo-cogroup-merge\n" + + " <-- KSTREAM-SOURCE-0000000000\n" + + " Processor: foo-cogroup-merge (stores: [])\n" + + " --> none\n" + + " <-- foo-cogroup-agg-0\n\n")); + } + + @Test + public void slidingWindowAggregateStreamsTest() { + final KTable, String> customers = windowedCogroupedStream.aggregate( + MockInitializer.STRING_INIT, Materialized.with(Serdes.String(), Serdes.String())); + customers.toStream().to(OUTPUT); + + try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) { + final TestInputTopic testInputTopic = driver.createInputTopic( + TOPIC, new StringSerializer(), new StringSerializer()); + final TestOutputTopic, String> testOutputTopic = driver.createOutputTopic( + OUTPUT, new TimeWindowedDeserializer<>(new StringDeserializer(), WINDOW_SIZE_MS), new StringDeserializer()); + + testInputTopic.pipeInput("k1", "A", 500); + testInputTopic.pipeInput("k2", "A", 500); + testInputTopic.pipeInput("k2", "A", 501); + testInputTopic.pipeInput("k1", "A", 502); + testInputTopic.pipeInput("k1", "B", 503); + testInputTopic.pipeInput("k2", "B", 503); + testInputTopic.pipeInput("k2", "B", 504); + testInputTopic.pipeInput("k1", "B", 504); + + final List, String>> results = testOutputTopic.readRecordsToList(); + + final List, String>> expected = new LinkedList<>(); + // k1-A-500 + expected.add(new TestRecord<>(new Windowed<>("k1", new TimeWindow(0L, 500L)), "0+A", null, 500L)); + // k2-A-500 + expected.add(new TestRecord<>(new Windowed<>("k2", new TimeWindow(0L, 500L)), "0+A", null, 500L)); + // k2-A-501 + 
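// The k2 record at t=501 creates a right window [501, 1001] for the earlier k2 record at t=500 and its own left window [1, 501], which aggregates both records. + 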
expected.add(new TestRecord<>(new Windowed<>("k2", new TimeWindow(501L, 1001L)), "0+A", null, 501L)); + expected.add(new TestRecord<>(new Windowed<>("k2", new TimeWindow(1L, 501L)), "0+A+A", null, 501L)); + // k1-A-502 + expected.add(new TestRecord<>(new Windowed<>("k1", new TimeWindow(501L, 1001L)), "0+A", null, 502L)); + expected.add(new TestRecord<>(new Windowed<>("k1", new TimeWindow(2L, 502L)), "0+A+A", null, 502L)); + // k1-B-503 + expected.add(new TestRecord<>(new Windowed<>("k1", new TimeWindow(501L, 1001L)), "0+A+B", null, 503L)); + expected.add(new TestRecord<>(new Windowed<>("k1", new TimeWindow(503L, 1003L)), "0+B", null, 503L)); + expected.add(new TestRecord<>(new Windowed<>("k1", new TimeWindow(3L, 503L)), "0+A+A+B", null, 503L)); + // k2-B-503 + expected.add(new TestRecord<>(new Windowed<>("k2", new TimeWindow(501L, 1001L)), "0+A+B", null, 503L)); + expected.add(new TestRecord<>(new Windowed<>("k2", new TimeWindow(502L, 1002)), "0+B", null, 503L)); + expected.add(new TestRecord<>(new Windowed<>("k2", new TimeWindow(3L, 503L)), "0+A+A+B", null, 503L)); + // k2-B-504 + expected.add(new TestRecord<>(new Windowed<>("k2", new TimeWindow(502L, 1002L)), "0+B+B", null, 504L)); + expected.add(new TestRecord<>(new Windowed<>("k2", new TimeWindow(501L, 1001L)), "0+A+B+B", null, 504L)); + expected.add(new TestRecord<>(new Windowed<>("k2", new TimeWindow(504L, 1004L)), "0+B", null, 504L)); + expected.add(new TestRecord<>(new Windowed<>("k2", new TimeWindow(4L, 504L)), "0+A+A+B+B", null, 504L)); + // k1-B-504 + expected.add(new TestRecord<>(new Windowed<>("k1", new TimeWindow(503L, 1003L)), "0+B+B", null, 504L)); + expected.add(new TestRecord<>(new Windowed<>("k1", new TimeWindow(501L, 1001L)), "0+A+B+B", null, 504L)); + expected.add(new TestRecord<>(new Windowed<>("k1", new TimeWindow(504L, 1004L)), "0+B", null, 504L)); + expected.add(new TestRecord<>(new Windowed<>("k1", new TimeWindow(4L, 504L)), "0+A+A+B+B", null, 504L)); + + assertEquals(expected, results); + } + } + + @Test + public void slidingWindowAggregateOverlappingWindowsTest() { + + final KTable, String> customers = groupedStream.cogroup(MockAggregator.TOSTRING_ADDER) + .windowedBy(SlidingWindows.withTimeDifferenceAndGrace(ofMillis(WINDOW_SIZE_MS), ofMillis(2000L))).aggregate( + MockInitializer.STRING_INIT, Materialized.with(Serdes.String(), Serdes.String())); + customers.toStream().to(OUTPUT); + + try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) { + final TestInputTopic testInputTopic = driver.createInputTopic( + TOPIC, new StringSerializer(), new StringSerializer()); + final TestOutputTopic, String> testOutputTopic = driver.createOutputTopic( + OUTPUT, new TimeWindowedDeserializer<>(new StringDeserializer(), WINDOW_SIZE_MS), new StringDeserializer()); + + testInputTopic.pipeInput("k1", "A", 500); + testInputTopic.pipeInput("k2", "A", 500); + testInputTopic.pipeInput("k1", "B", 750); + testInputTopic.pipeInput("k2", "B", 750); + testInputTopic.pipeInput("k2", "A", 1000L); + testInputTopic.pipeInput("k1", "A", 1000L); + + // left window k1@500 + assertOutputKeyValueTimestamp(testOutputTopic, "k1", "0+A", 500); + // left window k2@500 + assertOutputKeyValueTimestamp(testOutputTopic, "k2", "0+A", 500); + // right window k1@500 + assertOutputKeyValueTimestamp(testOutputTopic, "k1", "0+B", 750); + // left window k1@750 + assertOutputKeyValueTimestamp(testOutputTopic, "k1", "0+A+B", 750); + // right window k2@500 + assertOutputKeyValueTimestamp(testOutputTopic, "k2", "0+B", 750); + // left 
window k2@750 + assertOutputKeyValueTimestamp(testOutputTopic, "k2", "0+A+B", 750); + // right window k2@500 update + assertOutputKeyValueTimestamp(testOutputTopic, "k2", "0+B+A", 1000); + // right window k2@750 + assertOutputKeyValueTimestamp(testOutputTopic, "k2", "0+A", 1000); + // left window k2@1000 + assertOutputKeyValueTimestamp(testOutputTopic, "k2", "0+A+B+A", 1000); + // right window k1@500 update + assertOutputKeyValueTimestamp(testOutputTopic, "k1", "0+B+A", 1000); + // right window k1@750 + assertOutputKeyValueTimestamp(testOutputTopic, "k1", "0+A", 1000); + // left window k1@1000 + assertOutputKeyValueTimestamp(testOutputTopic, "k1", "0+A+B+A", 1000); + } + } + + private void assertOutputKeyValueTimestamp(final TestOutputTopic, String> outputTopic, + final String expectedKey, + final String expectedValue, + final long expectedTimestamp) { + final TestRecord, String> realRecord = outputTopic.readRecord(); + final TestRecord nonWindowedRecord = new TestRecord<>( + realRecord.getKey().key(), realRecord.getValue(), null, realRecord.timestamp()); + final TestRecord testRecord = new TestRecord<>(expectedKey, expectedValue, null, expectedTimestamp); + assertThat(nonWindowedRecord, equalTo(testRecord)); + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/kstream/internals/SlidingWindowedKStreamImplTest.java b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/SlidingWindowedKStreamImplTest.java new file mode 100644 index 0000000..e4a965e --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/SlidingWindowedKStreamImplTest.java @@ -0,0 +1,439 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.streams.kstream.internals; + +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.serialization.StringSerializer; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.TopologyTestDriver; +import org.apache.kafka.streams.kstream.Consumed; +import org.apache.kafka.streams.kstream.Grouped; +import org.apache.kafka.streams.kstream.KStream; +import org.apache.kafka.streams.kstream.Materialized; +import org.apache.kafka.streams.kstream.Named; +import org.apache.kafka.streams.kstream.SlidingWindows; +import org.apache.kafka.streams.kstream.TimeWindowedKStream; +import org.apache.kafka.streams.kstream.Windowed; +import org.apache.kafka.streams.state.Stores; +import org.apache.kafka.streams.state.ValueAndTimestamp; +import org.apache.kafka.streams.state.WindowBytesStoreSupplier; +import org.apache.kafka.streams.state.WindowStore; +import org.apache.kafka.streams.TestInputTopic; +import org.apache.kafka.test.MockAggregator; +import org.apache.kafka.test.MockInitializer; +import org.apache.kafka.test.MockProcessorSupplier; +import org.apache.kafka.test.MockReducer; +import org.apache.kafka.test.StreamsTestUtils; +import org.junit.Before; +import org.junit.Test; + +import java.util.Arrays; +import java.util.List; +import java.util.Properties; + +import static java.time.Duration.ofMillis; +import static java.time.Instant.ofEpochMilli; +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.Assert.assertThrows; + +@SuppressWarnings("deprecation") // Old PAPI. Needs to be migrated. +public class SlidingWindowedKStreamImplTest { + + private static final String TOPIC = "input"; + private final StreamsBuilder builder = new StreamsBuilder(); + private final Properties props = StreamsTestUtils.getStreamsConfig(Serdes.String(), Serdes.String()); + private TimeWindowedKStream windowedStream; + + @Before + public void before() { + final KStream stream = builder.stream(TOPIC, Consumed.with(Serdes.String(), Serdes.String())); + windowedStream = stream. 
groupByKey(Grouped.with(Serdes.String(), Serdes.String()))
+            .windowedBy(SlidingWindows.withTimeDifferenceAndGrace(ofMillis(100L), ofMillis(1000L)));
+    }
+
+    @Test
+    public void shouldCountSlidingWindows() {
+        final MockProcessorSupplier<Windowed<String>, Long> supplier = new MockProcessorSupplier<>();
+        windowedStream
+            .count()
+            .toStream()
+            .process(supplier);
+
+        try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) {
+            processData(driver);
+        }
+        assertThat(
+            supplier.theCapturedProcessor().lastValueAndTimestampPerKey()
+                .get(new Windowed<>("1", new TimeWindow(0L, 100L))),
+            equalTo(ValueAndTimestamp.make(1L, 100L)));
+        assertThat(
+            supplier.theCapturedProcessor().lastValueAndTimestampPerKey()
+                .get(new Windowed<>("1", new TimeWindow(101L, 201L))),
+            equalTo(ValueAndTimestamp.make(1L, 150L)));
+        assertThat(
+            supplier.theCapturedProcessor().lastValueAndTimestampPerKey()
+                .get(new Windowed<>("1", new TimeWindow(50L, 150L))),
+            equalTo(ValueAndTimestamp.make(2L, 150L)));
+        assertThat(
+            supplier.theCapturedProcessor().lastValueAndTimestampPerKey()
+                .get(new Windowed<>("1", new TimeWindow(400L, 500L))),
+            equalTo(ValueAndTimestamp.make(1L, 500L)));
+        assertThat(
+            supplier.theCapturedProcessor().lastValueAndTimestampPerKey()
+                .get(new Windowed<>("2", new TimeWindow(100L, 200L))),
+            equalTo(ValueAndTimestamp.make(2L, 200L)));
+        assertThat(
+            supplier.theCapturedProcessor().lastValueAndTimestampPerKey()
+                .get(new Windowed<>("2", new TimeWindow(50L, 150L))),
+            equalTo(ValueAndTimestamp.make(1L, 150L)));
+        assertThat(
+            supplier.theCapturedProcessor().lastValueAndTimestampPerKey()
+                .get(new Windowed<>("2", new TimeWindow(151L, 251L))),
+            equalTo(ValueAndTimestamp.make(1L, 200L)));
+    }
+
+    @Test
+    public void shouldReduceSlidingWindows() {
+        final MockProcessorSupplier<Windowed<String>, String> supplier = new MockProcessorSupplier<>();
+        windowedStream
+            .reduce(MockReducer.STRING_ADDER)
+            .toStream()
+            .process(supplier);
+
+        try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) {
+            processData(driver);
+        }
+        assertThat(
+            supplier.theCapturedProcessor().lastValueAndTimestampPerKey()
+                .get(new Windowed<>("1", new TimeWindow(0L, 100L))),
+            equalTo(ValueAndTimestamp.make("1", 100L)));
+        assertThat(
+            supplier.theCapturedProcessor().lastValueAndTimestampPerKey()
+                .get(new Windowed<>("1", new TimeWindow(101L, 201L))),
+            equalTo(ValueAndTimestamp.make("2", 150L)));
+        assertThat(
+            supplier.theCapturedProcessor().lastValueAndTimestampPerKey()
+                .get(new Windowed<>("1", new TimeWindow(50L, 150L))),
+            equalTo(ValueAndTimestamp.make("1+2", 150L)));
+        assertThat(
+            supplier.theCapturedProcessor().lastValueAndTimestampPerKey()
+                .get(new Windowed<>("1", new TimeWindow(400L, 500L))),
+            equalTo(ValueAndTimestamp.make("3", 500L)));
+        assertThat(
+            supplier.theCapturedProcessor().lastValueAndTimestampPerKey()
+                .get(new Windowed<>("2", new TimeWindow(100L, 200L))),
+            equalTo(ValueAndTimestamp.make("10+20", 200L)));
+        assertThat(
+            supplier.theCapturedProcessor().lastValueAndTimestampPerKey()
+                .get(new Windowed<>("2", new TimeWindow(50L, 150L))),
+            equalTo(ValueAndTimestamp.make("20", 150L)));
+        assertThat(
+            supplier.theCapturedProcessor().lastValueAndTimestampPerKey()
+                .get(new Windowed<>("2", new TimeWindow(151L, 251L))),
+            equalTo(ValueAndTimestamp.make("10", 200L)));
+    }
+
+    @Test
+    public void shouldAggregateSlidingWindows() {
+        final MockProcessorSupplier<Windowed<String>, String> supplier = new MockProcessorSupplier<>();
+        windowedStream
+            .aggregate(
MockInitializer.STRING_INIT,
+                MockAggregator.TOSTRING_ADDER,
+                Materialized.with(Serdes.String(), Serdes.String()))
+            .toStream()
+            .process(supplier);
+
+        try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) {
+            processData(driver);
+        }
+        assertThat(
+            supplier.theCapturedProcessor().lastValueAndTimestampPerKey()
+                .get(new Windowed<>("1", new TimeWindow(0L, 100L))),
+            equalTo(ValueAndTimestamp.make("0+1", 100L)));
+        assertThat(
+            supplier.theCapturedProcessor().lastValueAndTimestampPerKey()
+                .get(new Windowed<>("1", new TimeWindow(101L, 201L))),
+            equalTo(ValueAndTimestamp.make("0+2", 150L)));
+        assertThat(
+            supplier.theCapturedProcessor().lastValueAndTimestampPerKey()
+                .get(new Windowed<>("1", new TimeWindow(50L, 150L))),
+            equalTo(ValueAndTimestamp.make("0+1+2", 150L)));
+        assertThat(
+            supplier.theCapturedProcessor().lastValueAndTimestampPerKey()
+                .get(new Windowed<>("1", new TimeWindow(400L, 500L))),
+            equalTo(ValueAndTimestamp.make("0+3", 500L)));
+        assertThat(
+            supplier.theCapturedProcessor().lastValueAndTimestampPerKey()
+                .get(new Windowed<>("2", new TimeWindow(100L, 200L))),
+            equalTo(ValueAndTimestamp.make("0+10+20", 200L)));
+        assertThat(
+            supplier.theCapturedProcessor().lastValueAndTimestampPerKey()
+                .get(new Windowed<>("2", new TimeWindow(50L, 150L))),
+            equalTo(ValueAndTimestamp.make("0+20", 150L)));
+        assertThat(
+            supplier.theCapturedProcessor().lastValueAndTimestampPerKey()
+                .get(new Windowed<>("2", new TimeWindow(151L, 251L))),
+            equalTo(ValueAndTimestamp.make("0+10", 200L)));
+    }
+
+    @Test
+    public void shouldMaterializeCount() {
+        windowedStream.count(
+            Materialized.<String, Long, WindowStore<Bytes, byte[]>>as("count-store")
+                .withKeySerde(Serdes.String())
+                .withValueSerde(Serdes.Long()));
+
+        try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) {
+            processData(driver);
+            {
+                final WindowStore<String, Long> windowStore = driver.getWindowStore("count-store");
+                final List<KeyValue<Windowed<String>, Long>> data =
+                    StreamsTestUtils.toList(windowStore.fetch("1", "2", ofEpochMilli(0), ofEpochMilli(1000L)));
+
+                assertThat(data, equalTo(Arrays.asList(
+                    KeyValue.pair(new Windowed<>("1", new TimeWindow(0, 100)), 1L),
+                    KeyValue.pair(new Windowed<>("1", new TimeWindow(50, 150)), 2L),
+                    KeyValue.pair(new Windowed<>("1", new TimeWindow(101, 201)), 1L),
+                    KeyValue.pair(new Windowed<>("1", new TimeWindow(400, 500)), 1L),
+                    KeyValue.pair(new Windowed<>("2", new TimeWindow(50, 150)), 1L),
+                    KeyValue.pair(new Windowed<>("2", new TimeWindow(100, 200)), 2L),
+                    KeyValue.pair(new Windowed<>("2", new TimeWindow(151, 251)), 1L))));
+            }
+            {
+                final WindowStore<String, ValueAndTimestamp<Long>> windowStore =
+                    driver.getTimestampedWindowStore("count-store");
+                final List<KeyValue<Windowed<String>, ValueAndTimestamp<Long>>> data =
+                    StreamsTestUtils.toList(windowStore.fetch("1", "2", ofEpochMilli(0), ofEpochMilli(1000L)));
+                assertThat(data, equalTo(Arrays.asList(
+                    KeyValue.pair(new Windowed<>("1", new TimeWindow(0, 100)), ValueAndTimestamp.make(1L, 100L)),
+                    KeyValue.pair(new Windowed<>("1", new TimeWindow(50, 150)), ValueAndTimestamp.make(2L, 150L)),
+                    KeyValue.pair(new Windowed<>("1", new TimeWindow(101, 201)), ValueAndTimestamp.make(1L, 150L)),
+                    KeyValue.pair(new Windowed<>("1", new TimeWindow(400, 500)), ValueAndTimestamp.make(1L, 500L)),
+                    KeyValue.pair(new Windowed<>("2", new TimeWindow(50, 150)), ValueAndTimestamp.make(1L, 150L)),
+                    KeyValue.pair(new Windowed<>("2", new TimeWindow(100, 200)), ValueAndTimestamp.make(2L, 200L)),
+                    KeyValue.pair(new Windowed<>("2", new TimeWindow(151, 251)), ValueAndTimestamp.make(1L, 200L)))));
+            }
+        }
+    }
+
+    @Test
+    public void shouldMaterializeReduced() {
+        windowedStream.reduce(
+            MockReducer.STRING_ADDER,
+            Materialized.<String, String, WindowStore<Bytes, byte[]>>as("reduced")
+                .withKeySerde(Serdes.String())
+                .withValueSerde(Serdes.String()));
+
+        try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) {
+            processData(driver);
+            {
+                final WindowStore<String, String> windowStore = driver.getWindowStore("reduced");
+                final List<KeyValue<Windowed<String>, String>> data =
+                    StreamsTestUtils.toList(windowStore.fetch("1", "2", ofEpochMilli(0), ofEpochMilli(1000L)));
+                assertThat(data, equalTo(Arrays.asList(
+                    KeyValue.pair(new Windowed<>("1", new TimeWindow(0, 100)), "1"),
+                    KeyValue.pair(new Windowed<>("1", new TimeWindow(50, 150)), "1+2"),
+                    KeyValue.pair(new Windowed<>("1", new TimeWindow(101, 201)), "2"),
+                    KeyValue.pair(new Windowed<>("1", new TimeWindow(400, 500)), "3"),
+                    KeyValue.pair(new Windowed<>("2", new TimeWindow(50, 150)), "20"),
+                    KeyValue.pair(new Windowed<>("2", new TimeWindow(100, 200)), "10+20"),
+                    KeyValue.pair(new Windowed<>("2", new TimeWindow(151, 251)), "10"))));
+            }
+            {
+                final WindowStore<String, ValueAndTimestamp<String>> windowStore =
+                    driver.getTimestampedWindowStore("reduced");
+                final List<KeyValue<Windowed<String>, ValueAndTimestamp<String>>> data =
+                    StreamsTestUtils.toList(windowStore.fetch("1", "2", ofEpochMilli(0), ofEpochMilli(1000L)));
+                assertThat(data, equalTo(Arrays.asList(
+                    KeyValue.pair(new Windowed<>("1", new TimeWindow(0, 100)), ValueAndTimestamp.make("1", 100L)),
+                    KeyValue.pair(new Windowed<>("1", new TimeWindow(50, 150)), ValueAndTimestamp.make("1+2", 150L)),
+                    KeyValue.pair(new Windowed<>("1", new TimeWindow(101, 201)), ValueAndTimestamp.make("2", 150L)),
+                    KeyValue.pair(new Windowed<>("1", new TimeWindow(400, 500)), ValueAndTimestamp.make("3", 500L)),
+                    KeyValue.pair(new Windowed<>("2", new TimeWindow(50, 150)), ValueAndTimestamp.make("20", 150L)),
+                    KeyValue.pair(new Windowed<>("2", new TimeWindow(100, 200)), ValueAndTimestamp.make("10+20", 200L)),
+                    KeyValue.pair(new Windowed<>("2", new TimeWindow(151, 251)), ValueAndTimestamp.make("10", 200L)))));
+            }
+        }
+    }
+
+    @Test
+    public void shouldMaterializeAggregated() {
+        windowedStream.aggregate(
+            MockInitializer.STRING_INIT,
+            MockAggregator.TOSTRING_ADDER,
+            Materialized.<String, String, WindowStore<Bytes, byte[]>>as("aggregated")
+                .withKeySerde(Serdes.String())
+                .withValueSerde(Serdes.String()));
+
+        try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) {
+            processData(driver);
+            {
+                final WindowStore<String, String> windowStore = driver.getWindowStore("aggregated");
+                final List<KeyValue<Windowed<String>, String>> data =
+                    StreamsTestUtils.toList(windowStore.fetch("1", "2", ofEpochMilli(0), ofEpochMilli(1000L)));
+                assertThat(data, equalTo(Arrays.asList(
+                    KeyValue.pair(new Windowed<>("1", new TimeWindow(0, 100)), "0+1"),
+                    KeyValue.pair(new Windowed<>("1", new TimeWindow(50, 150)), "0+1+2"),
+                    KeyValue.pair(new Windowed<>("1", new TimeWindow(101, 201)), "0+2"),
+                    KeyValue.pair(new Windowed<>("1", new TimeWindow(400, 500)), "0+3"),
+                    KeyValue.pair(new Windowed<>("2", new TimeWindow(50, 150)), "0+20"),
+                    KeyValue.pair(new Windowed<>("2", new TimeWindow(100, 200)), "0+10+20"),
+                    KeyValue.pair(new Windowed<>("2", new TimeWindow(151, 251)), "0+10"))));
+            }
+            {
+                final WindowStore<String, ValueAndTimestamp<String>> windowStore =
+                    driver.getTimestampedWindowStore("aggregated");
+                final List<KeyValue<Windowed<String>, ValueAndTimestamp<String>>> data =
+                    StreamsTestUtils.toList(windowStore.fetch("1", "2", ofEpochMilli(0), ofEpochMilli(1000L)));
+                assertThat(data, equalTo(Arrays.asList(
+                    KeyValue.pair(new Windowed<>("1", new TimeWindow(0, 100)), ValueAndTimestamp.make("0+1", 100L)),
+                    KeyValue.pair(new Windowed<>("1", new TimeWindow(50,
150)), ValueAndTimestamp.make("0+1+2", 150L)),
+                    KeyValue.pair(new Windowed<>("1", new TimeWindow(101, 201)), ValueAndTimestamp.make("0+2", 150L)),
+                    KeyValue.pair(new Windowed<>("1", new TimeWindow(400, 500)), ValueAndTimestamp.make("0+3", 500L)),
+                    KeyValue.pair(new Windowed<>("2", new TimeWindow(50, 150)), ValueAndTimestamp.make("0+20", 150L)),
+                    KeyValue.pair(new Windowed<>("2", new TimeWindow(100, 200)), ValueAndTimestamp.make("0+10+20", 200L)),
+                    KeyValue.pair(new Windowed<>("2", new TimeWindow(151, 251)), ValueAndTimestamp.make("0+10", 200L)))));
+            }
+        }
+    }
+
+    @Test
+    public void shouldThrowNullPointerOnAggregateIfInitializerIsNull() {
+        assertThrows(NullPointerException.class, () -> windowedStream.aggregate(null, MockAggregator.TOSTRING_ADDER));
+    }
+
+    @Test
+    public void shouldThrowNullPointerOnAggregateIfAggregatorIsNull() {
+        assertThrows(NullPointerException.class, () -> windowedStream.aggregate(MockInitializer.STRING_INIT, null));
+    }
+
+    @Test
+    public void shouldThrowNullPointerOnReduceIfReducerIsNull() {
+        assertThrows(NullPointerException.class, () -> windowedStream.reduce(null));
+    }
+
+    @Test
+    public void shouldThrowNullPointerOnMaterializedAggregateIfInitializerIsNull() {
+        assertThrows(NullPointerException.class, () -> windowedStream.aggregate(null, MockAggregator.TOSTRING_ADDER, Materialized.as("store")));
+    }
+
+    @Test
+    public void shouldThrowNullPointerOnMaterializedAggregateIfAggregatorIsNull() {
+        assertThrows(NullPointerException.class, () -> windowedStream.aggregate(
+            MockInitializer.STRING_INIT,
+            null,
+            Materialized.as("store")));
+    }
+
+    @SuppressWarnings("unchecked")
+    @Test
+    public void shouldThrowNullPointerOnMaterializedAggregateIfMaterializedIsNull() {
+        assertThrows(NullPointerException.class, () -> windowedStream.aggregate(MockInitializer.STRING_INIT, MockAggregator.TOSTRING_ADDER, (Materialized) null));
+    }
+
+    @Test
+    public void shouldThrowNullPointerOnMaterializedReduceIfReducerIsNull() {
+        assertThrows(NullPointerException.class, () -> windowedStream.reduce(null, Materialized.as("store")));
+    }
+
+    @Test
+    @SuppressWarnings("unchecked")
+    public void shouldThrowNullPointerOnMaterializedReduceIfMaterializedIsNull() {
+        assertThrows(NullPointerException.class, () -> windowedStream.reduce(MockReducer.STRING_ADDER, (Materialized) null));
+    }
+
+    @Test
+    public void shouldThrowNullPointerOnMaterializedReduceIfNamedIsNull() {
+        assertThrows(NullPointerException.class, () -> windowedStream.reduce(MockReducer.STRING_ADDER, (Named) null));
+    }
+
+    @Test
+    public void shouldThrowNullPointerOnCountIfMaterializedIsNull() {
+        assertThrows(NullPointerException.class, () -> windowedStream.count((Materialized<String, Long, WindowStore<Bytes, byte[]>>) null));
+    }
+
+    @Test
+    public void shouldThrowIllegalArgumentWhenRetentionIsTooSmall() {
+        assertThrows(IllegalArgumentException.class, () -> windowedStream
+            .aggregate(
+                MockInitializer.STRING_INIT,
+                MockAggregator.TOSTRING_ADDER,
+                Materialized
+                    .<String, String, WindowStore<Bytes, byte[]>>as("aggregated")
+                    .withKeySerde(Serdes.String())
+                    .withValueSerde(Serdes.String())
+                    .withRetention(ofMillis(1L))
+            )
+        );
+    }
+
+    @Test
+    public void shouldDropWindowsOutsideOfRetention() {
+        final WindowBytesStoreSupplier storeSupplier = Stores.inMemoryWindowStore("aggregated", ofMillis(1200L), ofMillis(100L), false);
+        windowedStream.aggregate(
+            MockInitializer.STRING_INIT,
+            MockAggregator.TOSTRING_ADDER,
+            Materialized.<String, String>as(storeSupplier)
+                .withKeySerde(Serdes.String())
+                .withValueSerde(Serdes.String())
+                .withCachingDisabled());
+
+        try (final TopologyTestDriver driver = new
TopologyTestDriver(builder.build(), props)) {
+            final TestInputTopic<String, String> inputTopic =
+                driver.createInputTopic(TOPIC, new StringSerializer(), new StringSerializer());
+
+            inputTopic.pipeInput("1", "2", 100L);
+            inputTopic.pipeInput("1", "3", 500L);
+            inputTopic.pipeInput("1", "4", 799L);
+            inputTopic.pipeInput("1", "4", 1000L);
+            inputTopic.pipeInput("1", "5", 2000L);
+
+            {
+                final WindowStore<String, String> windowStore = driver.getWindowStore("aggregated");
+                final List<KeyValue<Windowed<String>, String>> data =
+                    StreamsTestUtils.toList(windowStore.fetch("1", "1", ofEpochMilli(0), ofEpochMilli(10000L)));
+                assertThat(data, equalTo(Arrays.asList(
+                    KeyValue.pair(new Windowed<>("1", new TimeWindow(900, 1000)), "0+4"),
+                    KeyValue.pair(new Windowed<>("1", new TimeWindow(1900, 2000)), "0+5"))));
+            }
+            {
+                final WindowStore<String, ValueAndTimestamp<String>> windowStore =
+                    driver.getTimestampedWindowStore("aggregated");
+                final List<KeyValue<Windowed<String>, ValueAndTimestamp<String>>> data =
+                    StreamsTestUtils.toList(windowStore.fetch("1", "1", ofEpochMilli(0), ofEpochMilli(2000L)));
+                assertThat(data, equalTo(Arrays.asList(
+                    KeyValue.pair(new Windowed<>("1", new TimeWindow(900, 1000)), ValueAndTimestamp.make("0+4", 1000L)),
+                    KeyValue.pair(new Windowed<>("1", new TimeWindow(1900, 2000)), ValueAndTimestamp.make("0+5", 2000L)))));
+            }
+        }
+    }
+
+    private void processData(final TopologyTestDriver driver) {
+        final TestInputTopic<String, String> inputTopic =
+            driver.createInputTopic(TOPIC, new StringSerializer(), new StringSerializer());
+        inputTopic.pipeInput("1", "1", 100L);
+        inputTopic.pipeInput("1", "2", 150L);
+        inputTopic.pipeInput("1", "3", 500L);
+        inputTopic.pipeInput("2", "10", 200L);
+        inputTopic.pipeInput("2", "20", 150L);
+    }
+}
\ No newline at end of file
diff --git a/streams/src/test/java/org/apache/kafka/streams/kstream/internals/SuppressScenarioTest.java b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/SuppressScenarioTest.java
new file mode 100644
index 0000000..7b521ab
--- /dev/null
+++ b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/SuppressScenarioTest.java
@@ -0,0 +1,857 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +package org.apache.kafka.streams.kstream.internals; + +import org.apache.kafka.common.serialization.Deserializer; +import org.apache.kafka.common.serialization.LongDeserializer; +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.serialization.StringDeserializer; +import org.apache.kafka.common.serialization.StringSerializer; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.KeyValueTimestamp; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.Topology; +import org.apache.kafka.streams.TopologyTestDriver; +import org.apache.kafka.streams.kstream.Consumed; +import org.apache.kafka.streams.kstream.Grouped; +import org.apache.kafka.streams.kstream.KGroupedStream; +import org.apache.kafka.streams.kstream.KStream; +import org.apache.kafka.streams.kstream.KTable; +import org.apache.kafka.streams.kstream.Materialized; +import org.apache.kafka.streams.kstream.Named; +import org.apache.kafka.streams.kstream.Produced; +import org.apache.kafka.streams.kstream.SessionWindows; +import org.apache.kafka.streams.kstream.SlidingWindows; +import org.apache.kafka.streams.kstream.Suppressed; +import org.apache.kafka.streams.kstream.TimeWindows; +import org.apache.kafka.streams.kstream.Windowed; +import org.apache.kafka.streams.state.KeyValueStore; +import org.apache.kafka.streams.state.SessionStore; +import org.apache.kafka.streams.state.WindowStore; +import org.apache.kafka.streams.TestInputTopic; +import org.apache.kafka.streams.test.TestRecord; +import org.apache.kafka.test.TestUtils; +import org.junit.Test; + +import java.time.Duration; +import java.util.Comparator; +import java.util.Iterator; +import java.util.List; +import java.util.Properties; + +import static java.time.Duration.ZERO; +import static java.time.Duration.ofMillis; +import static java.util.Arrays.asList; +import static java.util.Collections.emptyList; +import static java.util.Collections.singletonList; +import static org.apache.kafka.streams.kstream.Suppressed.BufferConfig.maxBytes; +import static org.apache.kafka.streams.kstream.Suppressed.BufferConfig.maxRecords; +import static org.apache.kafka.streams.kstream.Suppressed.BufferConfig.unbounded; +import static org.apache.kafka.streams.kstream.Suppressed.untilTimeLimit; +import static org.apache.kafka.streams.kstream.Suppressed.untilWindowCloses; +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.MatcherAssert.assertThat; + +@SuppressWarnings("deprecation") +public class SuppressScenarioTest { + private static final StringDeserializer STRING_DESERIALIZER = new StringDeserializer(); + private static final StringSerializer STRING_SERIALIZER = new StringSerializer(); + private static final Serde STRING_SERDE = Serdes.String(); + private static final LongDeserializer LONG_DESERIALIZER = new LongDeserializer(); + private final Properties config = Utils.mkProperties(Utils.mkMap( + Utils.mkEntry(StreamsConfig.STATE_DIR_CONFIG, TestUtils.tempDirectory().getPath()) + )); + + @Test + public void shouldImmediatelyEmitEventsWithZeroEmitAfter() { + final StreamsBuilder builder = new StreamsBuilder(); + + final KTable valueCounts = builder + .table( + "input", + Consumed.with(STRING_SERDE, STRING_SERDE), + Materialized.>with(STRING_SERDE, STRING_SERDE) + .withCachingDisabled() + 
.withLoggingDisabled() + ) + .groupBy((k, v) -> new KeyValue<>(v, k), Grouped.with(STRING_SERDE, STRING_SERDE)) + .count(); + + valueCounts + .suppress(untilTimeLimit(ZERO, unbounded())) + .toStream() + .to("output-suppressed", Produced.with(STRING_SERDE, Serdes.Long())); + + valueCounts + .toStream() + .to("output-raw", Produced.with(STRING_SERDE, Serdes.Long())); + + final Topology topology = builder.build(); + + try (final TopologyTestDriver driver = new TopologyTestDriver(topology, config)) { + final TestInputTopic inputTopic = + driver.createInputTopic("input", STRING_SERIALIZER, STRING_SERIALIZER); + inputTopic.pipeInput("k1", "v1", 0L); + inputTopic.pipeInput("k1", "v2", 1L); + inputTopic.pipeInput("k2", "v1", 2L); + verify( + drainProducerRecords(driver, "output-raw", STRING_DESERIALIZER, LONG_DESERIALIZER), + asList( + new KeyValueTimestamp<>("v1", 1L, 0L), + new KeyValueTimestamp<>("v1", 0L, 1L), + new KeyValueTimestamp<>("v2", 1L, 1L), + new KeyValueTimestamp<>("v1", 1L, 2L) + ) + ); + verify( + drainProducerRecords(driver, "output-suppressed", STRING_DESERIALIZER, LONG_DESERIALIZER), + asList( + new KeyValueTimestamp<>("v1", 1L, 0L), + new KeyValueTimestamp<>("v1", 0L, 1L), + new KeyValueTimestamp<>("v2", 1L, 1L), + new KeyValueTimestamp<>("v1", 1L, 2L) + ) + ); + inputTopic.pipeInput("x", "x", 3L); + verify( + drainProducerRecords(driver, "output-raw", STRING_DESERIALIZER, LONG_DESERIALIZER), + singletonList( + new KeyValueTimestamp<>("x", 1L, 3L) + ) + ); + verify( + drainProducerRecords(driver, "output-suppressed", STRING_DESERIALIZER, LONG_DESERIALIZER), + singletonList( + new KeyValueTimestamp<>("x", 1L, 3L) + ) + ); + inputTopic.pipeInput("x", "y", 4L); + verify( + drainProducerRecords(driver, "output-raw", STRING_DESERIALIZER, LONG_DESERIALIZER), + asList( + new KeyValueTimestamp<>("x", 0L, 4L), + new KeyValueTimestamp<>("y", 1L, 4L) + ) + ); + verify( + drainProducerRecords(driver, "output-suppressed", STRING_DESERIALIZER, LONG_DESERIALIZER), + asList( + new KeyValueTimestamp<>("x", 0L, 4L), + new KeyValueTimestamp<>("y", 1L, 4L) + ) + ); + } + } + + @Test + public void shouldSuppressIntermediateEventsWithTimeLimit() { + final StreamsBuilder builder = new StreamsBuilder(); + final KTable valueCounts = builder + .table( + "input", + Consumed.with(STRING_SERDE, STRING_SERDE), + Materialized.>with(STRING_SERDE, STRING_SERDE) + .withCachingDisabled() + .withLoggingDisabled() + ) + .groupBy((k, v) -> new KeyValue<>(v, k), Grouped.with(STRING_SERDE, STRING_SERDE)) + .count(); + valueCounts + .suppress(untilTimeLimit(ofMillis(2L), unbounded())) + .toStream() + .to("output-suppressed", Produced.with(STRING_SERDE, Serdes.Long())); + valueCounts + .toStream() + .to("output-raw", Produced.with(STRING_SERDE, Serdes.Long())); + final Topology topology = builder.build(); + try (final TopologyTestDriver driver = new TopologyTestDriver(topology, config)) { + final TestInputTopic inputTopic = + driver.createInputTopic("input", STRING_SERIALIZER, STRING_SERIALIZER); + inputTopic.pipeInput("k1", "v1", 0L); + inputTopic.pipeInput("k1", "v2", 1L); + inputTopic.pipeInput("k2", "v1", 2L); + verify( + drainProducerRecords(driver, "output-raw", STRING_DESERIALIZER, LONG_DESERIALIZER), + asList( + new KeyValueTimestamp<>("v1", 1L, 0L), + new KeyValueTimestamp<>("v1", 0L, 1L), + new KeyValueTimestamp<>("v2", 1L, 1L), + new KeyValueTimestamp<>("v1", 1L, 2L) + ) + ); + verify( + drainProducerRecords(driver, "output-suppressed", STRING_DESERIALIZER, LONG_DESERIALIZER), + singletonList(new 
KeyValueTimestamp<>("v1", 1L, 2L)) + ); + // inserting a dummy "tick" record just to advance stream time + inputTopic.pipeInput("tick", "tick", 3L); + verify( + drainProducerRecords(driver, "output-raw", STRING_DESERIALIZER, LONG_DESERIALIZER), + singletonList(new KeyValueTimestamp<>("tick", 1L, 3L)) + ); + // the stream time is now 3, so it's time to emit this record + verify( + drainProducerRecords(driver, "output-suppressed", STRING_DESERIALIZER, LONG_DESERIALIZER), + singletonList(new KeyValueTimestamp<>("v2", 1L, 1L)) + ); + + + inputTopic.pipeInput("tick", "tock", 4L); + verify( + drainProducerRecords(driver, "output-raw", STRING_DESERIALIZER, LONG_DESERIALIZER), + asList( + new KeyValueTimestamp<>("tick", 0L, 4L), + new KeyValueTimestamp<>("tock", 1L, 4L) + ) + ); + // tick is still buffered, since it was first inserted at time 3, and it is only time 4 right now. + verify( + drainProducerRecords(driver, "output-suppressed", STRING_DESERIALIZER, LONG_DESERIALIZER), + emptyList() + ); + } + } + + @Test + public void shouldSuppressIntermediateEventsWithRecordLimit() { + final StreamsBuilder builder = new StreamsBuilder(); + final KTable valueCounts = builder + .table( + "input", + Consumed.with(STRING_SERDE, STRING_SERDE), + Materialized.>with(STRING_SERDE, STRING_SERDE) + .withCachingDisabled() + .withLoggingDisabled() + ) + .groupBy((k, v) -> new KeyValue<>(v, k), Grouped.with(STRING_SERDE, STRING_SERDE)) + .count(Materialized.with(STRING_SERDE, Serdes.Long())); + valueCounts + .suppress(untilTimeLimit(ofMillis(Long.MAX_VALUE), maxRecords(1L).emitEarlyWhenFull())) + .toStream() + .to("output-suppressed", Produced.with(STRING_SERDE, Serdes.Long())); + valueCounts + .toStream() + .to("output-raw", Produced.with(STRING_SERDE, Serdes.Long())); + final Topology topology = builder.build(); + System.out.println(topology.describe()); + try (final TopologyTestDriver driver = new TopologyTestDriver(topology, config)) { + final TestInputTopic inputTopic = + driver.createInputTopic("input", STRING_SERIALIZER, STRING_SERIALIZER); + inputTopic.pipeInput("k1", "v1", 0L); + inputTopic.pipeInput("k1", "v2", 1L); + inputTopic.pipeInput("k2", "v1", 2L); + verify( + drainProducerRecords(driver, "output-raw", STRING_DESERIALIZER, LONG_DESERIALIZER), + asList( + new KeyValueTimestamp<>("v1", 1L, 0L), + new KeyValueTimestamp<>("v1", 0L, 1L), + new KeyValueTimestamp<>("v2", 1L, 1L), + new KeyValueTimestamp<>("v1", 1L, 2L) + ) + ); + verify( + drainProducerRecords(driver, "output-suppressed", STRING_DESERIALIZER, LONG_DESERIALIZER), + asList( + // consecutive updates to v1 get suppressed into only the latter. + new KeyValueTimestamp<>("v1", 0L, 1L), + new KeyValueTimestamp<>("v2", 1L, 1L) + // the last update won't be evicted until another key comes along. 
+ ) + ); + inputTopic.pipeInput("x", "x", 3L); + verify( + drainProducerRecords(driver, "output-raw", STRING_DESERIALIZER, LONG_DESERIALIZER), + singletonList( + new KeyValueTimestamp<>("x", 1L, 3L) + ) + ); + verify( + drainProducerRecords(driver, "output-suppressed", STRING_DESERIALIZER, LONG_DESERIALIZER), + singletonList( + // now we see that last update to v1, but we won't see the update to x until it gets evicted + new KeyValueTimestamp<>("v1", 1L, 2L) + ) + ); + } + } + + @Test + public void shouldSuppressIntermediateEventsWithBytesLimit() { + final StreamsBuilder builder = new StreamsBuilder(); + final KTable valueCounts = builder + .table( + "input", + Consumed.with(STRING_SERDE, STRING_SERDE), + Materialized.>with(STRING_SERDE, STRING_SERDE) + .withCachingDisabled() + .withLoggingDisabled() + ) + .groupBy((k, v) -> new KeyValue<>(v, k), Grouped.with(STRING_SERDE, STRING_SERDE)) + .count(); + valueCounts + // this is a bit brittle, but I happen to know that the entries are a little over 100 bytes in size. + .suppress(untilTimeLimit(ofMillis(Long.MAX_VALUE), maxBytes(200L).emitEarlyWhenFull())) + .toStream() + .to("output-suppressed", Produced.with(STRING_SERDE, Serdes.Long())); + valueCounts + .toStream() + .to("output-raw", Produced.with(STRING_SERDE, Serdes.Long())); + final Topology topology = builder.build(); + System.out.println(topology.describe()); + try (final TopologyTestDriver driver = new TopologyTestDriver(topology, config)) { + final TestInputTopic inputTopic = + driver.createInputTopic("input", STRING_SERIALIZER, STRING_SERIALIZER); + inputTopic.pipeInput("k1", "v1", 0L); + inputTopic.pipeInput("k1", "v2", 1L); + inputTopic.pipeInput("k2", "v1", 2L); + verify( + drainProducerRecords(driver, "output-raw", STRING_DESERIALIZER, LONG_DESERIALIZER), + asList( + new KeyValueTimestamp<>("v1", 1L, 0L), + new KeyValueTimestamp<>("v1", 0L, 1L), + new KeyValueTimestamp<>("v2", 1L, 1L), + new KeyValueTimestamp<>("v1", 1L, 2L) + ) + ); + verify( + drainProducerRecords(driver, "output-suppressed", STRING_DESERIALIZER, LONG_DESERIALIZER), + asList( + // consecutive updates to v1 get suppressed into only the latter. + new KeyValueTimestamp<>("v1", 0L, 1L), + new KeyValueTimestamp<>("v2", 1L, 1L) + // the last update won't be evicted until another key comes along. 
+ ) + ); + inputTopic.pipeInput("x", "x", 3L); + verify( + drainProducerRecords(driver, "output-raw", STRING_DESERIALIZER, LONG_DESERIALIZER), + singletonList( + new KeyValueTimestamp<>("x", 1L, 3L) + ) + ); + verify( + drainProducerRecords(driver, "output-suppressed", STRING_DESERIALIZER, LONG_DESERIALIZER), + singletonList( + // now we see that last update to v1, but we won't see the update to x until it gets evicted + new KeyValueTimestamp<>("v1", 1L, 2L) + ) + ); + } + } + + @Test + public void shouldSupportFinalResultsForTimeWindows() { + final StreamsBuilder builder = new StreamsBuilder(); + final KTable, Long> valueCounts = builder + .stream("input", Consumed.with(STRING_SERDE, STRING_SERDE)) + .groupBy((String k, String v) -> k, Grouped.with(STRING_SERDE, STRING_SERDE)) + .windowedBy(TimeWindows.of(ofMillis(2L)).grace(ofMillis(1L))) + .count(Materialized.>as("counts").withCachingDisabled()); + valueCounts + .suppress(untilWindowCloses(unbounded())) + .toStream() + .map((final Windowed k, final Long v) -> new KeyValue<>(k.toString(), v)) + .to("output-suppressed", Produced.with(STRING_SERDE, Serdes.Long())); + valueCounts + .toStream() + .map((final Windowed k, final Long v) -> new KeyValue<>(k.toString(), v)) + .to("output-raw", Produced.with(STRING_SERDE, Serdes.Long())); + final Topology topology = builder.build(); + System.out.println(topology.describe()); + try (final TopologyTestDriver driver = new TopologyTestDriver(topology, config)) { + final TestInputTopic inputTopic = + driver.createInputTopic("input", STRING_SERIALIZER, STRING_SERIALIZER); + inputTopic.pipeInput("k1", "v1", 0L); + inputTopic.pipeInput("k1", "v1", 1L); + inputTopic.pipeInput("k1", "v1", 2L); + inputTopic.pipeInput("k1", "v1", 1L); + inputTopic.pipeInput("k1", "v1", 0L); + inputTopic.pipeInput("k1", "v1", 5L); + // note this last record gets dropped because it is out of the grace period + inputTopic.pipeInput("k1", "v1", 0L); + verify( + drainProducerRecords(driver, "output-raw", STRING_DESERIALIZER, LONG_DESERIALIZER), + asList( + new KeyValueTimestamp<>("[k1@0/2]", 1L, 0L), + new KeyValueTimestamp<>("[k1@0/2]", 2L, 1L), + new KeyValueTimestamp<>("[k1@2/4]", 1L, 2L), + new KeyValueTimestamp<>("[k1@0/2]", 3L, 1L), + new KeyValueTimestamp<>("[k1@0/2]", 4L, 1L), + new KeyValueTimestamp<>("[k1@4/6]", 1L, 5L) + ) + ); + verify( + drainProducerRecords(driver, "output-suppressed", STRING_DESERIALIZER, LONG_DESERIALIZER), + asList( + new KeyValueTimestamp<>("[k1@0/2]", 4L, 1L), + new KeyValueTimestamp<>("[k1@2/4]", 1L, 2L) + ) + ); + } + } + + @Test + public void shouldSupportFinalResultsForTimeWindowsWithLargeJump() { + final StreamsBuilder builder = new StreamsBuilder(); + final KTable, Long> valueCounts = builder + .stream("input", Consumed.with(STRING_SERDE, STRING_SERDE)) + .groupBy((String k, String v) -> k, Grouped.with(STRING_SERDE, STRING_SERDE)) + .windowedBy(TimeWindows.of(ofMillis(2L)).grace(ofMillis(2L))) + .count(Materialized.>as("counts").withCachingDisabled().withKeySerde(STRING_SERDE)); + valueCounts + .suppress(untilWindowCloses(unbounded())) + .toStream() + .map((final Windowed k, final Long v) -> new KeyValue<>(k.toString(), v)) + .to("output-suppressed", Produced.with(STRING_SERDE, Serdes.Long())); + valueCounts + .toStream() + .map((final Windowed k, final Long v) -> new KeyValue<>(k.toString(), v)) + .to("output-raw", Produced.with(STRING_SERDE, Serdes.Long())); + final Topology topology = builder.build(); + System.out.println(topology.describe()); + try (final TopologyTestDriver driver = 
new TopologyTestDriver(topology, config)) { + final TestInputTopic inputTopic = + driver.createInputTopic("input", STRING_SERIALIZER, STRING_SERIALIZER); + inputTopic.pipeInput("k1", "v1", 0L); + inputTopic.pipeInput("k1", "v1", 1L); + inputTopic.pipeInput("k1", "v1", 2L); + inputTopic.pipeInput("k1", "v1", 0L); + inputTopic.pipeInput("k1", "v1", 3L); + inputTopic.pipeInput("k1", "v1", 0L); + inputTopic.pipeInput("k1", "v1", 4L); + // this update should get dropped, since the previous event advanced the stream time and closed the window. + inputTopic.pipeInput("k1", "v1", 0L); + inputTopic.pipeInput("k1", "v1", 30L); + verify( + drainProducerRecords(driver, "output-raw", STRING_DESERIALIZER, LONG_DESERIALIZER), + asList( + new KeyValueTimestamp<>("[k1@0/2]", 1L, 0L), + new KeyValueTimestamp<>("[k1@0/2]", 2L, 1L), + new KeyValueTimestamp<>("[k1@2/4]", 1L, 2L), + new KeyValueTimestamp<>("[k1@0/2]", 3L, 1L), + new KeyValueTimestamp<>("[k1@2/4]", 2L, 3L), + new KeyValueTimestamp<>("[k1@0/2]", 4L, 1L), + new KeyValueTimestamp<>("[k1@4/6]", 1L, 4L), + new KeyValueTimestamp<>("[k1@30/32]", 1L, 30L) + ) + ); + verify( + drainProducerRecords(driver, "output-suppressed", STRING_DESERIALIZER, LONG_DESERIALIZER), + asList( + new KeyValueTimestamp<>("[k1@0/2]", 4L, 1L), + new KeyValueTimestamp<>("[k1@2/4]", 2L, 3L), + new KeyValueTimestamp<>("[k1@4/6]", 1L, 4L) + ) + ); + } + } + + @Test + public void shouldSupportFinalResultsForSlidingWindows() { + final StreamsBuilder builder = new StreamsBuilder(); + final KTable, Long> valueCounts = builder + .stream("input", Consumed.with(STRING_SERDE, STRING_SERDE)) + .groupBy((String k, String v) -> k, Grouped.with(STRING_SERDE, STRING_SERDE)) + .windowedBy(SlidingWindows.withTimeDifferenceAndGrace(ofMillis(5L), ofMillis(15L))) + .count(Materialized.>as("counts").withCachingDisabled().withKeySerde(STRING_SERDE)); + valueCounts + .suppress(untilWindowCloses(unbounded())) + .toStream() + .map((final Windowed k, final Long v) -> new KeyValue<>(k.toString(), v)) + .to("output-suppressed", Produced.with(STRING_SERDE, Serdes.Long())); + valueCounts + .toStream() + .map((final Windowed k, final Long v) -> new KeyValue<>(k.toString(), v)) + .to("output-raw", Produced.with(STRING_SERDE, Serdes.Long())); + final Topology topology = builder.build(); + System.out.println(topology.describe()); + try (final TopologyTestDriver driver = new TopologyTestDriver(topology, config)) { + final TestInputTopic inputTopic = + driver.createInputTopic("input", STRING_SERIALIZER, STRING_SERIALIZER); + inputTopic.pipeInput("k1", "v1", 10L); + inputTopic.pipeInput("k1", "v1", 11L); + inputTopic.pipeInput("k1", "v1", 10L); + inputTopic.pipeInput("k1", "v1", 13L); + inputTopic.pipeInput("k1", "v1", 10L); + inputTopic.pipeInput("k1", "v1", 24L); + // this update should get dropped, since the previous event advanced the stream time and closed the window. 
+ inputTopic.pipeInput("k1", "v1", 5L); + inputTopic.pipeInput("k1", "v1", 7L); + // final record to advance stream time and flush windows + inputTopic.pipeInput("k1", "v1", 90L); + final Comparator> comparator = + Comparator.comparing((TestRecord o) -> o.getKey()) + .thenComparing((TestRecord o) -> o.timestamp()); + + final List> actual = drainProducerRecords(driver, "output-raw", STRING_DESERIALIZER, LONG_DESERIALIZER); + actual.sort(comparator); + verify( + actual, + asList( + // right window for k1@10 created when k1@11 is processed + new KeyValueTimestamp<>("[k1@11/16]", 1L, 11L), + // right window for k1@10 updated when k1@13 is processed + new KeyValueTimestamp<>("[k1@11/16]", 2L, 13L), + // right window for k1@11 created when k1@13 is processed + new KeyValueTimestamp<>("[k1@12/17]", 1L, 13L), + // left window for k1@24 created when k1@24 is processed + new KeyValueTimestamp<>("[k1@19/24]", 1L, 24L), + // left window for k1@10 created when k1@10 is processed + new KeyValueTimestamp<>("[k1@5/10]", 1L, 10L), + // left window for k1@10 updated when k1@10 is processed + new KeyValueTimestamp<>("[k1@5/10]", 2L, 10L), + // left window for k1@10 updated when k1@10 is processed + new KeyValueTimestamp<>("[k1@5/10]", 3L, 10L), + // left window for k1@10 updated when k1@5 is processed + new KeyValueTimestamp<>("[k1@5/10]", 4L, 10L), + // left window for k1@10 updated when k1@7 is processed + new KeyValueTimestamp<>("[k1@5/10]", 5L, 10L), + // left window for k1@11 created when k1@11 is processed + new KeyValueTimestamp<>("[k1@6/11]", 2L, 11L), + // left window for k1@11 updated when k1@10 is processed + new KeyValueTimestamp<>("[k1@6/11]", 3L, 11L), + // left window for k1@11 updated when k1@10 is processed + new KeyValueTimestamp<>("[k1@6/11]", 4L, 11L), + // left window for k1@11 updated when k1@7 is processed + new KeyValueTimestamp<>("[k1@6/11]", 5L, 11L), + // left window for k1@13 created when k1@13 is processed + new KeyValueTimestamp<>("[k1@8/13]", 4L, 13L), + // left window for k1@13 updated when k1@10 is processed + new KeyValueTimestamp<>("[k1@8/13]", 5L, 13L), + // right window for k1@90 created when k1@90 is processed + new KeyValueTimestamp<>("[k1@85/90]", 1L, 90L) + ) + ); + verify( + drainProducerRecords(driver, "output-suppressed", STRING_DESERIALIZER, LONG_DESERIALIZER), + asList( + new KeyValueTimestamp<>("[k1@5/10]", 5L, 10L), + new KeyValueTimestamp<>("[k1@6/11]", 5L, 11L), + new KeyValueTimestamp<>("[k1@8/13]", 5L, 13L), + new KeyValueTimestamp<>("[k1@11/16]", 2L, 13L), + new KeyValueTimestamp<>("[k1@12/17]", 1L, 13L), + new KeyValueTimestamp<>("[k1@19/24]", 1L, 24L) + ) + ); + } + } + + @Test + public void shouldSupportFinalResultsForSessionWindows() { + final StreamsBuilder builder = new StreamsBuilder(); + final KTable, Long> valueCounts = builder + .stream("input", Consumed.with(STRING_SERDE, STRING_SERDE)) + .groupBy((String k, String v) -> k, Grouped.with(STRING_SERDE, STRING_SERDE)) + .windowedBy(SessionWindows.with(ofMillis(5L)).grace(ofMillis(0L))) + .count(Materialized.>as("counts").withCachingDisabled()); + valueCounts + .suppress(untilWindowCloses(unbounded())) + .toStream() + .map((final Windowed k, final Long v) -> new KeyValue<>(k.toString(), v)) + .to("output-suppressed", Produced.with(STRING_SERDE, Serdes.Long())); + valueCounts + .toStream() + .map((final Windowed k, final Long v) -> new KeyValue<>(k.toString(), v)) + .to("output-raw", Produced.with(STRING_SERDE, Serdes.Long())); + final Topology topology = builder.build(); + 
System.out.println(topology.describe()); + try (final TopologyTestDriver driver = new TopologyTestDriver(topology, config)) { + final TestInputTopic inputTopic = + driver.createInputTopic("input", STRING_SERIALIZER, STRING_SERIALIZER); + // first window + inputTopic.pipeInput("k1", "v1", 0L); + inputTopic.pipeInput("k1", "v1", 5L); + // arbitrarily disordered records are admitted, because the *window* is not closed until stream-time > window-end + grace + inputTopic.pipeInput("k1", "v1", 1L); + // any record in the same partition advances stream time (note the key is different) + inputTopic.pipeInput("k2", "v1", 11L); + // late event for first window - this should get dropped from all streams, since the first window is now closed. + inputTopic.pipeInput("k1", "v1", 5L); + // just pushing stream time forward to flush the other events through. + inputTopic.pipeInput("k1", "v1", 30L); + verify( + drainProducerRecords(driver, "output-raw", STRING_DESERIALIZER, LONG_DESERIALIZER), + asList( + new KeyValueTimestamp<>("[k1@0/0]", 1L, 0L), + new KeyValueTimestamp<>("[k1@0/0]", null, 0L), + new KeyValueTimestamp<>("[k1@0/5]", 2L, 5L), + new KeyValueTimestamp<>("[k1@0/5]", null, 5L), + new KeyValueTimestamp<>("[k1@0/5]", 3L, 5L), + new KeyValueTimestamp<>("[k2@11/11]", 1L, 11L), + new KeyValueTimestamp<>("[k1@30/30]", 1L, 30L) + ) + ); + verify( + drainProducerRecords(driver, "output-suppressed", STRING_DESERIALIZER, LONG_DESERIALIZER), + asList( + new KeyValueTimestamp<>("[k1@0/5]", 3L, 5L), + new KeyValueTimestamp<>("[k2@11/11]", 1L, 11L) + ) + ); + } + } + + @Test + public void shouldWorkBeforeGroupBy() { + final StreamsBuilder builder = new StreamsBuilder(); + + builder + .table("topic", Consumed.with(Serdes.String(), Serdes.String())) + .suppress(untilTimeLimit(ofMillis(10), unbounded())) + .groupBy(KeyValue::pair, Grouped.with(Serdes.String(), Serdes.String())) + .count() + .toStream() + .to("output", Produced.with(Serdes.String(), Serdes.Long())); + + try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), config)) { + final TestInputTopic inputTopic = + driver.createInputTopic("topic", STRING_SERIALIZER, STRING_SERIALIZER); + + inputTopic.pipeInput("A", "a", 0L); + inputTopic.pipeInput("tick", "tick", 10L); + + verify( + drainProducerRecords(driver, "output", STRING_DESERIALIZER, LONG_DESERIALIZER), + singletonList(new KeyValueTimestamp<>("A", 1L, 0L)) + ); + } + } + + @Test + public void shouldWorkBeforeJoinRight() { + final StreamsBuilder builder = new StreamsBuilder(); + + final KTable left = builder + .table("left", Consumed.with(Serdes.String(), Serdes.String())); + + final KTable right = builder + .table("right", Consumed.with(Serdes.String(), Serdes.String())) + .suppress(untilTimeLimit(ofMillis(10), unbounded())); + + left + .outerJoin(right, (l, r) -> String.format("(%s,%s)", l, r)) + .toStream() + .to("output", Produced.with(Serdes.String(), Serdes.String())); + + try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), config)) { + final TestInputTopic inputTopicRight = + driver.createInputTopic("right", STRING_SERIALIZER, STRING_SERIALIZER); + final TestInputTopic inputTopicLeft = + driver.createInputTopic("left", STRING_SERIALIZER, STRING_SERIALIZER); + + inputTopicRight.pipeInput("B", "1", 0L); + inputTopicRight.pipeInput("A", "1", 0L); + // buffered, no output + verify( + drainProducerRecords(driver, "output", STRING_DESERIALIZER, STRING_DESERIALIZER), + emptyList() + ); + + + inputTopicRight.pipeInput("tick", "tick", 10L); + // 
flush buffer + verify( + drainProducerRecords(driver, "output", STRING_DESERIALIZER, STRING_DESERIALIZER), + asList( + new KeyValueTimestamp<>("A", "(null,1)", 0L), + new KeyValueTimestamp<>("B", "(null,1)", 0L) + ) + ); + + + inputTopicRight.pipeInput("A", "2", 11L); + // buffered, no output + verify( + drainProducerRecords(driver, "output", STRING_DESERIALIZER, STRING_DESERIALIZER), + emptyList() + ); + + + inputTopicLeft.pipeInput("A", "a", 12L); + // should join with previously emitted right side + verify( + drainProducerRecords(driver, "output", STRING_DESERIALIZER, STRING_DESERIALIZER), + singletonList(new KeyValueTimestamp<>("A", "(a,1)", 12L)) + ); + + + inputTopicLeft.pipeInput("B", "b", 12L); + // should view through to the parent KTable, since B is no longer buffered + verify( + drainProducerRecords(driver, "output", STRING_DESERIALIZER, STRING_DESERIALIZER), + singletonList(new KeyValueTimestamp<>("B", "(b,1)", 12L)) + ); + + + inputTopicLeft.pipeInput("A", "b", 13L); + // should join with previously emitted right side + verify( + drainProducerRecords(driver, "output", STRING_DESERIALIZER, STRING_DESERIALIZER), + singletonList(new KeyValueTimestamp<>("A", "(b,1)", 13L)) + ); + + + inputTopicRight.pipeInput("tick", "tick1", 21L); + verify( + drainProducerRecords(driver, "output", STRING_DESERIALIZER, STRING_DESERIALIZER), + asList( + new KeyValueTimestamp<>("tick", "(null,tick1)", 21), // just a testing artifact + new KeyValueTimestamp<>("A", "(b,2)", 13L) + ) + ); + } + + } + + + @Test + public void shouldWorkBeforeJoinLeft() { + final StreamsBuilder builder = new StreamsBuilder(); + + final KTable left = builder + .table("left", Consumed.with(Serdes.String(), Serdes.String())) + .suppress(untilTimeLimit(ofMillis(10), unbounded())); + + final KTable right = builder + .table("right", Consumed.with(Serdes.String(), Serdes.String())); + + left + .outerJoin(right, (l, r) -> String.format("(%s,%s)", l, r)) + .toStream() + .to("output", Produced.with(Serdes.String(), Serdes.String())); + + final Topology topology = builder.build(); + try (final TopologyTestDriver driver = new TopologyTestDriver(topology, config)) { + final TestInputTopic inputTopicRight = + driver.createInputTopic("right", STRING_SERIALIZER, STRING_SERIALIZER); + final TestInputTopic inputTopicLeft = + driver.createInputTopic("left", STRING_SERIALIZER, STRING_SERIALIZER); + + inputTopicLeft.pipeInput("B", "1", 0L); + inputTopicLeft.pipeInput("A", "1", 0L); + // buffered, no output + verify( + drainProducerRecords(driver, "output", STRING_DESERIALIZER, STRING_DESERIALIZER), + emptyList() + ); + + + inputTopicLeft.pipeInput("tick", "tick", 10L); + // flush buffer + verify( + drainProducerRecords(driver, "output", STRING_DESERIALIZER, STRING_DESERIALIZER), + asList( + new KeyValueTimestamp<>("A", "(1,null)", 0L), + new KeyValueTimestamp<>("B", "(1,null)", 0L) + ) + ); + + + inputTopicLeft.pipeInput("A", "2", 11L); + // buffered, no output + verify( + drainProducerRecords(driver, "output", STRING_DESERIALIZER, STRING_DESERIALIZER), + emptyList() + ); + + + inputTopicRight.pipeInput("A", "a", 12L); + // should join with previously emitted left side + verify( + drainProducerRecords(driver, "output", STRING_DESERIALIZER, STRING_DESERIALIZER), + singletonList(new KeyValueTimestamp<>("A", "(1,a)", 12L)) + ); + + + inputTopicRight.pipeInput("B", "b", 12L); + // should view through to the parent KTable, since B is no longer buffered + verify( + drainProducerRecords(driver, "output", STRING_DESERIALIZER, STRING_DESERIALIZER), + 
singletonList(new KeyValueTimestamp<>("B", "(1,b)", 12L))
+            );
+
+
+            inputTopicRight.pipeInput("A", "b", 13L);
+            // should join with previously emitted left side
+            verify(
+                drainProducerRecords(driver, "output", STRING_DESERIALIZER, STRING_DESERIALIZER),
+                singletonList(new KeyValueTimestamp<>("A", "(1,b)", 13L))
+            );
+
+
+            inputTopicLeft.pipeInput("tick", "tick1", 21L);
+            verify(
+                drainProducerRecords(driver, "output", STRING_DESERIALIZER, STRING_DESERIALIZER),
+                asList(
+                    new KeyValueTimestamp<>("tick", "(tick1,null)", 21), // just a testing artifact
+                    new KeyValueTimestamp<>("A", "(2,b)", 13L)
+                )
+            );
+        }
+
+    }
+
+    @Test
+    public void shouldWorkWithCogrouped() {
+        final StreamsBuilder builder = new StreamsBuilder();
+
+        final KGroupedStream<String, String> stream1 = builder.stream("one", Consumed.with(Serdes.String(), Serdes.String())).groupByKey(Grouped.with(Serdes.String(), Serdes.String()));
+        final KGroupedStream<String, String> stream2 = builder.stream("two", Consumed.with(Serdes.String(), Serdes.String())).groupByKey(Grouped.with(Serdes.String(), Serdes.String()));
+        final KStream<Windowed<String>, Object> cogrouped = stream1.cogroup((key, value, aggregate) -> aggregate + value).cogroup(stream2, (key, value, aggregate) -> aggregate + value)
+            .windowedBy(TimeWindows.of(Duration.ofMinutes(15)))
+            .aggregate(() -> "", Named.as("test"), Materialized.as("store"))
+            .suppress(Suppressed.untilWindowCloses(unbounded()))
+            .toStream();
+    }
+
+    private static <K, V> void verify(final List<TestRecord<K, V>> results,
+                                      final List<KeyValueTimestamp<K, V>> expectedResults) {
+        if (results.size() != expectedResults.size()) {
+            throw new AssertionError(printRecords(results) + " != " + expectedResults);
+        }
+        final Iterator<KeyValueTimestamp<K, V>> expectedIterator = expectedResults.iterator();
+        for (final TestRecord<K, V> result : results) {
+            final KeyValueTimestamp<K, V> expected = expectedIterator.next();
+            try {
+                assertThat(result, equalTo(new TestRecord<>(expected.key(), expected.value(), null, expected.timestamp())));
+            } catch (final AssertionError e) {
+                throw new AssertionError(printRecords(results) + " != " + expectedResults, e);
+            }
+        }
+    }
+
+    private static <K, V> List<TestRecord<K, V>> drainProducerRecords(final TopologyTestDriver driver,
+                                                                      final String topic,
+                                                                      final Deserializer<K> keyDeserializer,
+                                                                      final Deserializer<V> valueDeserializer) {
+        return driver.createOutputTopic(topic, keyDeserializer, valueDeserializer).readRecordsToList();
+    }
+
+    private static <K, V> String printRecords(final List<TestRecord<K, V>> result) {
+        final StringBuilder resultStr = new StringBuilder();
+        resultStr.append("[\n");
+        for (final TestRecord<K, V> record : result) {
+            resultStr.append(" ").append(record).append("\n");
+        }
+        resultStr.append("]");
+        return resultStr.toString();
+    }
+}
diff --git a/streams/src/test/java/org/apache/kafka/streams/kstream/internals/SuppressTopologyTest.java b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/SuppressTopologyTest.java
new file mode 100644
index 0000000..d775796
--- /dev/null
+++ b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/SuppressTopologyTest.java
@@ -0,0 +1,212 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.kstream.internals; + +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.kstream.Consumed; +import org.apache.kafka.streams.kstream.Grouped; +import org.apache.kafka.streams.kstream.Materialized; +import org.apache.kafka.streams.kstream.Produced; +import org.apache.kafka.streams.kstream.SessionWindows; +import org.apache.kafka.streams.kstream.Windowed; +import org.apache.kafka.streams.state.SessionStore; +import org.junit.Test; + +import java.time.Duration; + +import static java.time.Duration.ofMillis; +import static org.apache.kafka.streams.kstream.Suppressed.BufferConfig.unbounded; +import static org.apache.kafka.streams.kstream.Suppressed.untilTimeLimit; +import static org.apache.kafka.streams.kstream.Suppressed.untilWindowCloses; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.core.Is.is; + +@SuppressWarnings("deprecation") +public class SuppressTopologyTest { + private static final Serde STRING_SERDE = Serdes.String(); + + private static final String NAMED_FINAL_TOPOLOGY = "Topologies:\n" + + " Sub-topology: 0\n" + + " Source: KSTREAM-SOURCE-0000000000 (topics: [input])\n" + + " --> KSTREAM-KEY-SELECT-0000000001\n" + + " Processor: KSTREAM-KEY-SELECT-0000000001 (stores: [])\n" + + " --> counts-repartition-filter\n" + + " <-- KSTREAM-SOURCE-0000000000\n" + + " Processor: counts-repartition-filter (stores: [])\n" + + " --> counts-repartition-sink\n" + + " <-- KSTREAM-KEY-SELECT-0000000001\n" + + " Sink: counts-repartition-sink (topic: counts-repartition)\n" + + " <-- counts-repartition-filter\n" + + "\n" + + " Sub-topology: 1\n" + + " Source: counts-repartition-source (topics: [counts-repartition])\n" + + " --> KSTREAM-AGGREGATE-0000000002\n" + + " Processor: KSTREAM-AGGREGATE-0000000002 (stores: [counts])\n" + + " --> myname\n" + + " <-- counts-repartition-source\n" + + " Processor: myname (stores: [myname-store])\n" + + " --> KTABLE-TOSTREAM-0000000006\n" + + " <-- KSTREAM-AGGREGATE-0000000002\n" + + " Processor: KTABLE-TOSTREAM-0000000006 (stores: [])\n" + + " --> KSTREAM-MAP-0000000007\n" + + " <-- myname\n" + + " Processor: KSTREAM-MAP-0000000007 (stores: [])\n" + + " --> KSTREAM-SINK-0000000008\n" + + " <-- KTABLE-TOSTREAM-0000000006\n" + + " Sink: KSTREAM-SINK-0000000008 (topic: output-suppressed)\n" + + " <-- KSTREAM-MAP-0000000007\n" + + "\n"; + + private static final String ANONYMOUS_FINAL_TOPOLOGY = "Topologies:\n" + + " Sub-topology: 0\n" + + " Source: KSTREAM-SOURCE-0000000000 (topics: [input])\n" + + " --> KSTREAM-KEY-SELECT-0000000001\n" + + " Processor: KSTREAM-KEY-SELECT-0000000001 (stores: [])\n" + + " --> counts-repartition-filter\n" + + " <-- KSTREAM-SOURCE-0000000000\n" + + " Processor: counts-repartition-filter (stores: [])\n" + + " --> counts-repartition-sink\n" + + " <-- KSTREAM-KEY-SELECT-0000000001\n" + + " Sink: counts-repartition-sink (topic: 
counts-repartition)\n" + + " <-- counts-repartition-filter\n" + + "\n" + + " Sub-topology: 1\n" + + " Source: counts-repartition-source (topics: [counts-repartition])\n" + + " --> KSTREAM-AGGREGATE-0000000002\n" + + " Processor: KSTREAM-AGGREGATE-0000000002 (stores: [counts])\n" + + " --> KTABLE-SUPPRESS-0000000006\n" + + " <-- counts-repartition-source\n" + + " Processor: KTABLE-SUPPRESS-0000000006 (stores: [KTABLE-SUPPRESS-STATE-STORE-0000000007])\n" + + " --> KTABLE-TOSTREAM-0000000008\n" + + " <-- KSTREAM-AGGREGATE-0000000002\n" + + " Processor: KTABLE-TOSTREAM-0000000008 (stores: [])\n" + + " --> KSTREAM-MAP-0000000009\n" + + " <-- KTABLE-SUPPRESS-0000000006\n" + + " Processor: KSTREAM-MAP-0000000009 (stores: [])\n" + + " --> KSTREAM-SINK-0000000010\n" + + " <-- KTABLE-TOSTREAM-0000000008\n" + + " Sink: KSTREAM-SINK-0000000010 (topic: output-suppressed)\n" + + " <-- KSTREAM-MAP-0000000009\n" + + "\n"; + + private static final String NAMED_INTERMEDIATE_TOPOLOGY = "Topologies:\n" + + " Sub-topology: 0\n" + + " Source: KSTREAM-SOURCE-0000000000 (topics: [input])\n" + + " --> KSTREAM-AGGREGATE-0000000002\n" + + " Processor: KSTREAM-AGGREGATE-0000000002 (stores: [KSTREAM-AGGREGATE-STATE-STORE-0000000001])\n" + + " --> asdf\n" + + " <-- KSTREAM-SOURCE-0000000000\n" + + " Processor: asdf (stores: [asdf-store])\n" + + " --> KTABLE-TOSTREAM-0000000003\n" + + " <-- KSTREAM-AGGREGATE-0000000002\n" + + " Processor: KTABLE-TOSTREAM-0000000003 (stores: [])\n" + + " --> KSTREAM-SINK-0000000004\n" + + " <-- asdf\n" + + " Sink: KSTREAM-SINK-0000000004 (topic: output)\n" + + " <-- KTABLE-TOSTREAM-0000000003\n" + + "\n"; + + private static final String ANONYMOUS_INTERMEDIATE_TOPOLOGY = "Topologies:\n" + + " Sub-topology: 0\n" + + " Source: KSTREAM-SOURCE-0000000000 (topics: [input])\n" + + " --> KSTREAM-AGGREGATE-0000000002\n" + + " Processor: KSTREAM-AGGREGATE-0000000002 (stores: [KSTREAM-AGGREGATE-STATE-STORE-0000000001])\n" + + " --> KTABLE-SUPPRESS-0000000003\n" + + " <-- KSTREAM-SOURCE-0000000000\n" + + " Processor: KTABLE-SUPPRESS-0000000003 (stores: [KTABLE-SUPPRESS-STATE-STORE-0000000004])\n" + + " --> KTABLE-TOSTREAM-0000000005\n" + + " <-- KSTREAM-AGGREGATE-0000000002\n" + + " Processor: KTABLE-TOSTREAM-0000000005 (stores: [])\n" + + " --> KSTREAM-SINK-0000000006\n" + + " <-- KTABLE-SUPPRESS-0000000003\n" + + " Sink: KSTREAM-SINK-0000000006 (topic: output)\n" + + " <-- KTABLE-TOSTREAM-0000000005\n" + + "\n"; + + + @Test + public void shouldUseNumberingForAnonymousFinalSuppressionNode() { + final StreamsBuilder anonymousNodeBuilder = new StreamsBuilder(); + anonymousNodeBuilder + .stream("input", Consumed.with(STRING_SERDE, STRING_SERDE)) + .groupBy((String k, String v) -> k, Grouped.with(STRING_SERDE, STRING_SERDE)) + .windowedBy(SessionWindows.with(ofMillis(5L)).grace(ofMillis(5L))) + .count(Materialized.>as("counts").withCachingDisabled()) + .suppress(untilWindowCloses(unbounded())) + .toStream() + .map((final Windowed k, final Long v) -> new KeyValue<>(k.toString(), v)) + .to("output-suppressed", Produced.with(STRING_SERDE, Serdes.Long())); + final String anonymousNodeTopology = anonymousNodeBuilder.build().describe().toString(); + + // without the name, the suppression node increments the topology index + assertThat(anonymousNodeTopology, is(ANONYMOUS_FINAL_TOPOLOGY)); + } + + @Test + public void shouldApplyNameToFinalSuppressionNode() { + final StreamsBuilder namedNodeBuilder = new StreamsBuilder(); + namedNodeBuilder + .stream("input", Consumed.with(STRING_SERDE, STRING_SERDE)) + 
.groupBy((String k, String v) -> k, Grouped.with(STRING_SERDE, STRING_SERDE)) + .windowedBy(SessionWindows.with(ofMillis(5L)).grace(ofMillis(5L))) + .count(Materialized.>as("counts").withCachingDisabled()) + .suppress(untilWindowCloses(unbounded()).withName("myname")) + .toStream() + .map((final Windowed k, final Long v) -> new KeyValue<>(k.toString(), v)) + .to("output-suppressed", Produced.with(STRING_SERDE, Serdes.Long())); + final String namedNodeTopology = namedNodeBuilder.build().describe().toString(); + + // without the name, the suppression node does not increment the topology index + assertThat(namedNodeTopology, is(NAMED_FINAL_TOPOLOGY)); + } + + @Test + public void shouldUseNumberingForAnonymousSuppressionNode() { + final StreamsBuilder anonymousNodeBuilder = new StreamsBuilder(); + anonymousNodeBuilder + .stream("input", Consumed.with(STRING_SERDE, STRING_SERDE)) + .groupByKey() + .count() + .suppress(untilTimeLimit(Duration.ofSeconds(1), unbounded())) + .toStream() + .to("output", Produced.with(STRING_SERDE, Serdes.Long())); + final String anonymousNodeTopology = anonymousNodeBuilder.build().describe().toString(); + + // without the name, the suppression node increments the topology index + assertThat(anonymousNodeTopology, is(ANONYMOUS_INTERMEDIATE_TOPOLOGY)); + } + + @Test + public void shouldApplyNameToSuppressionNode() { + final StreamsBuilder namedNodeBuilder = new StreamsBuilder(); + namedNodeBuilder + .stream("input", Consumed.with(STRING_SERDE, STRING_SERDE)) + .groupByKey() + .count() + .suppress(untilTimeLimit(Duration.ofSeconds(1), unbounded()).withName("asdf")) + .toStream() + .to("output", Produced.with(STRING_SERDE, Serdes.Long())); + final String namedNodeTopology = namedNodeBuilder.build().describe().toString(); + + // without the name, the suppression node does not increment the topology index + assertThat(namedNodeTopology, is(NAMED_INTERMEDIATE_TOPOLOGY)); + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/kstream/internals/TimeWindowTest.java b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/TimeWindowTest.java new file mode 100644 index 0000000..bdfdb16 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/TimeWindowTest.java @@ -0,0 +1,138 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.kstream.internals; + +import org.apache.kafka.streams.kstream.TimeWindows; +import org.junit.Test; + +import java.util.Map; + +import static java.time.Duration.ofMillis; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; + +public class TimeWindowTest { + + private final long start = 50; + private final long end = 100; + private final TimeWindow window = new TimeWindow(start, end); + private final SessionWindow sessionWindow = new SessionWindow(start, end); + + @Test + public void endMustBeLargerThanStart() { + assertThrows(IllegalArgumentException.class, () -> new TimeWindow(start, start)); + } + + @Test + public void shouldNotOverlapIfOtherWindowIsBeforeThisWindow() { + /* + * This: [-------) + * Other: [-----) + */ + assertFalse(window.overlap(new TimeWindow(0, 25))); + assertFalse(window.overlap(new TimeWindow(0, start - 1))); + assertFalse(window.overlap(new TimeWindow(0, start))); + } + + @Test + public void shouldOverlapIfOtherWindowEndIsWithinThisWindow() { + /* + * This: [-------) + * Other: [---------) + */ + assertTrue(window.overlap(new TimeWindow(0, start + 1))); + assertTrue(window.overlap(new TimeWindow(0, 75))); + assertTrue(window.overlap(new TimeWindow(0, end - 1))); + + assertTrue(window.overlap(new TimeWindow(start - 1, start + 1))); + assertTrue(window.overlap(new TimeWindow(start - 1, 75))); + assertTrue(window.overlap(new TimeWindow(start - 1, end - 1))); + } + + @Test + public void shouldOverlapIfOtherWindowContainsThisWindow() { + /* + * This: [-------) + * Other: [------------------) + */ + assertTrue(window.overlap(new TimeWindow(0, end))); + assertTrue(window.overlap(new TimeWindow(0, end + 1))); + assertTrue(window.overlap(new TimeWindow(0, 150))); + + assertTrue(window.overlap(new TimeWindow(start - 1, end))); + assertTrue(window.overlap(new TimeWindow(start - 1, end + 1))); + assertTrue(window.overlap(new TimeWindow(start - 1, 150))); + + assertTrue(window.overlap(new TimeWindow(start, end))); + assertTrue(window.overlap(new TimeWindow(start, end + 1))); + assertTrue(window.overlap(new TimeWindow(start, 150))); + } + + @Test + public void shouldOverlapIfOtherWindowIsWithinThisWindow() { + /* + * This: [-------) + * Other: [---) + */ + assertTrue(window.overlap(new TimeWindow(start, 75))); + assertTrue(window.overlap(new TimeWindow(start, end))); + assertTrue(window.overlap(new TimeWindow(75, end))); + } + + @Test + public void shouldOverlapIfOtherWindowStartIsWithinThisWindow() { + /* + * This: [-------) + * Other: [-------) + */ + assertTrue(window.overlap(new TimeWindow(start, end + 1))); + assertTrue(window.overlap(new TimeWindow(start, 150))); + assertTrue(window.overlap(new TimeWindow(75, end + 1))); + assertTrue(window.overlap(new TimeWindow(75, 150))); + } + + @Test + public void shouldNotOverlapIsOtherWindowIsAfterThisWindow() { + /* + * This: [-------) + * Other: [------) + */ + assertFalse(window.overlap(new TimeWindow(end, end + 1))); + assertFalse(window.overlap(new TimeWindow(end, 150))); + assertFalse(window.overlap(new TimeWindow(end + 1, 150))); + assertFalse(window.overlap(new TimeWindow(125, 150))); + } + + @Test + public void cannotCompareTimeWindowWithDifferentWindowType() { + assertThrows(IllegalArgumentException.class, () -> window.overlap(sessionWindow)); + } + + @SuppressWarnings("deprecation") + @Test + public void shouldReturnMatchedWindowsOrderedByTimestamp() { + 
final TimeWindows windows = TimeWindows.of(ofMillis(12L)).advanceBy(ofMillis(5L)); + final Map matched = windows.windowsFor(21L); + + final Long[] expected = matched.keySet().toArray(new Long[0]); + assertEquals(expected[0].longValue(), 10L); + assertEquals(expected[1].longValue(), 15L); + assertEquals(expected[2].longValue(), 20L); + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/kstream/internals/TimeWindowedCogroupedKStreamImplTest.java b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/TimeWindowedCogroupedKStreamImplTest.java new file mode 100644 index 0000000..cd9ca19 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/TimeWindowedCogroupedKStreamImplTest.java @@ -0,0 +1,324 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.streams.kstream.internals; + +import static java.time.Duration.ofMillis; +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.Assert.assertThrows; + +import java.util.Properties; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.serialization.StringDeserializer; +import org.apache.kafka.common.serialization.StringSerializer; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.TestInputTopic; +import org.apache.kafka.streams.TestOutputTopic; +import org.apache.kafka.streams.TopologyTestDriver; +import org.apache.kafka.streams.kstream.CogroupedKStream; +import org.apache.kafka.streams.kstream.Consumed; +import org.apache.kafka.streams.kstream.Grouped; +import org.apache.kafka.streams.kstream.KGroupedStream; +import org.apache.kafka.streams.kstream.KStream; +import org.apache.kafka.streams.kstream.KTable; +import org.apache.kafka.streams.kstream.Materialized; +import org.apache.kafka.streams.kstream.Named; +import org.apache.kafka.streams.kstream.TimeWindowedCogroupedKStream; +import org.apache.kafka.streams.kstream.TimeWindowedDeserializer; +import org.apache.kafka.streams.kstream.TimeWindows; +import org.apache.kafka.streams.kstream.Windowed; +import org.apache.kafka.streams.state.WindowStore; +import org.apache.kafka.streams.test.TestRecord; +import org.apache.kafka.test.MockAggregator; +import org.apache.kafka.test.MockInitializer; +import org.apache.kafka.test.StreamsTestUtils; +import org.junit.Before; +import org.junit.Test; + +@SuppressWarnings("deprecation") +public class TimeWindowedCogroupedKStreamImplTest { + + private static final Long WINDOW_SIZE = 500L; + private static final String TOPIC = "topic"; + private static final String TOPIC2 = "topic2"; + private static final String OUTPUT = "output"; + private final StreamsBuilder 
builder = new StreamsBuilder();
+
+    private KGroupedStream<String, String> groupedStream;
+
+    private KGroupedStream<String, String> groupedStream2;
+    private CogroupedKStream<String, String> cogroupedStream;
+    private TimeWindowedCogroupedKStream<String, String> windowedCogroupedStream;
+
+    private final Properties props = StreamsTestUtils.getStreamsConfig(Serdes.String(), Serdes.String());
+
+    @Before
+    public void setup() {
+        final KStream<String, String> stream = builder.stream(TOPIC, Consumed
+            .with(Serdes.String(), Serdes.String()));
+        final KStream<String, String> stream2 = builder.stream(TOPIC2, Consumed
+            .with(Serdes.String(), Serdes.String()));
+
+        groupedStream = stream.groupByKey(Grouped.with(Serdes.String(), Serdes.String()));
+        groupedStream2 = stream2.groupByKey(Grouped.with(Serdes.String(), Serdes.String()));
+        cogroupedStream = groupedStream.cogroup(MockAggregator.TOSTRING_ADDER)
+            .cogroup(groupedStream2, MockAggregator.TOSTRING_REMOVER);
+        windowedCogroupedStream = cogroupedStream.windowedBy(TimeWindows.of(ofMillis(WINDOW_SIZE)));
+    }
+
+    @Test
+    public void shouldNotHaveNullInitializerOnAggregate() {
+        assertThrows(NullPointerException.class, () -> windowedCogroupedStream.aggregate(null));
+    }
+
+    @Test
+    public void shouldNotHaveNullMaterializedOnTwoOptionAggregate() {
+        assertThrows(NullPointerException.class, () -> windowedCogroupedStream.aggregate(MockInitializer.STRING_INIT,
+            (Materialized<String, String, WindowStore<Bytes, byte[]>>) null));
+    }
+
+    @Test
+    public void shouldNotHaveNullNamedTwoOptionOnAggregate() {
+        assertThrows(NullPointerException.class, () -> windowedCogroupedStream.aggregate(MockInitializer.STRING_INIT, (Named) null));
+    }
+
+    @Test
+    public void shouldNotHaveNullInitializerTwoOptionNamedOnAggregate() {
+        assertThrows(NullPointerException.class, () -> windowedCogroupedStream.aggregate(null, Named.as("test")));
+    }
+
+    @Test
+    public void shouldNotHaveNullInitializerTwoOptionMaterializedOnAggregate() {
+        assertThrows(NullPointerException.class, () -> windowedCogroupedStream.aggregate(null, Materialized.as("test")));
+    }
+
+    @Test
+    public void shouldNotHaveNullInitializerThreeOptionOnAggregate() {
+        assertThrows(NullPointerException.class, () -> windowedCogroupedStream.aggregate(null, Named.as("test"), Materialized.as("test")));
+    }
+
+    @Test
+    public void shouldNotHaveNullMaterializedOnAggregate() {
+        assertThrows(NullPointerException.class, () -> windowedCogroupedStream.aggregate(MockInitializer.STRING_INIT, Named.as("Test"), null));
+    }
+
+    @Test
+    public void shouldNotHaveNullNamedOnAggregate() {
+        assertThrows(NullPointerException.class, () -> windowedCogroupedStream.aggregate(MockInitializer.STRING_INIT, null, Materialized.as("test")));
+    }
+
+    @Test
+    public void namedParamShouldSetName() {
+        final StreamsBuilder builder = new StreamsBuilder();
+        final KStream<String, String> stream = builder.stream(TOPIC, Consumed
+            .with(Serdes.String(), Serdes.String()));
+        groupedStream = stream.groupByKey(Grouped.with(Serdes.String(), Serdes.String()));
+        groupedStream.cogroup(MockAggregator.TOSTRING_ADDER)
+            .windowedBy(TimeWindows.of(ofMillis(WINDOW_SIZE)))
+            .aggregate(MockInitializer.STRING_INIT, Named.as("foo"));
+
+        assertThat(builder.build().describe().toString(), equalTo(
+            "Topologies:\n" +
+                "   Sub-topology: 0\n" +
+                "    Source: KSTREAM-SOURCE-0000000000 (topics: [topic])\n" +
+                "      --> foo-cogroup-agg-0\n" +
+                "    Processor: foo-cogroup-agg-0 (stores: [COGROUPKSTREAM-AGGREGATE-STATE-STORE-0000000001])\n" +
+                "      --> foo-cogroup-merge\n" +
+                "      <-- KSTREAM-SOURCE-0000000000\n" +
+                "    Processor: foo-cogroup-merge (stores: [])\n" +
+                "      --> none\n" +
+                "      <-- foo-cogroup-agg-0\n\n"));
+    }
+
+    @Test
+
public void timeWindowAggregateTestStreamsTest() { + + final KTable, String> customers = windowedCogroupedStream.aggregate( + MockInitializer.STRING_INIT, Materialized.with(Serdes.String(), Serdes.String())); + customers.toStream().to(OUTPUT); + + try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) { + final TestInputTopic testInputTopic = driver.createInputTopic( + TOPIC, new StringSerializer(), new StringSerializer()); + final TestOutputTopic, String> testOutputTopic = driver.createOutputTopic( + OUTPUT, new TimeWindowedDeserializer<>(new StringDeserializer(), WINDOW_SIZE), new StringDeserializer()); + + testInputTopic.pipeInput("k1", "A", 0); + testInputTopic.pipeInput("k2", "A", 0); + testInputTopic.pipeInput("k2", "A", 1); + testInputTopic.pipeInput("k1", "A", 2); + testInputTopic.pipeInput("k1", "B", 3); + testInputTopic.pipeInput("k2", "B", 3); + testInputTopic.pipeInput("k2", "B", 4); + testInputTopic.pipeInput("k1", "B", 4); + + assertOutputKeyValueTimestamp(testOutputTopic, "k1", "0+A", 0); + assertOutputKeyValueTimestamp(testOutputTopic, "k2", "0+A", 0); + assertOutputKeyValueTimestamp(testOutputTopic, "k2", "0+A+A", 1); + assertOutputKeyValueTimestamp(testOutputTopic, "k1", "0+A+A", 2); + assertOutputKeyValueTimestamp(testOutputTopic, "k1", "0+A+A+B", 3); + assertOutputKeyValueTimestamp(testOutputTopic, "k2", "0+A+A+B", 3); + assertOutputKeyValueTimestamp(testOutputTopic, "k2", "0+A+A+B+B", 4); + assertOutputKeyValueTimestamp(testOutputTopic, "k1", "0+A+A+B+B", 4); + } + + } + + @Test + public void timeWindowMixAggregatorsTest() { + + final KTable, String> customers = windowedCogroupedStream.aggregate( + MockInitializer.STRING_INIT, Materialized.with(Serdes.String(), Serdes.String())); + customers.toStream().to(OUTPUT); + + try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) { + final TestInputTopic testInputTopic = driver.createInputTopic( + TOPIC, new StringSerializer(), new StringSerializer()); + final TestInputTopic testInputTopic2 = driver.createInputTopic( + TOPIC2, new StringSerializer(), new StringSerializer()); + final TestOutputTopic, String> testOutputTopic = driver.createOutputTopic( + OUTPUT, new TimeWindowedDeserializer<>(new StringDeserializer(), WINDOW_SIZE), new StringDeserializer()); + + testInputTopic.pipeInput("k1", "A", 0); + testInputTopic.pipeInput("k2", "A", 0); + testInputTopic.pipeInput("k2", "A", 1); + testInputTopic.pipeInput("k1", "A", 2); + testInputTopic2.pipeInput("k1", "B", 3); + testInputTopic2.pipeInput("k2", "B", 3); + testInputTopic2.pipeInput("k2", "B", 4); + testInputTopic2.pipeInput("k1", "B", 4); + + assertOutputKeyValueTimestamp(testOutputTopic, "k1", "0+A", 0); + assertOutputKeyValueTimestamp(testOutputTopic, "k2", "0+A", 0); + assertOutputKeyValueTimestamp(testOutputTopic, "k2", "0+A+A", 1); + assertOutputKeyValueTimestamp(testOutputTopic, "k1", "0+A+A", 2); + assertOutputKeyValueTimestamp(testOutputTopic, "k1", "0+A+A-B", 3); + assertOutputKeyValueTimestamp(testOutputTopic, "k2", "0+A+A-B", 3); + assertOutputKeyValueTimestamp(testOutputTopic, "k2", "0+A+A-B-B", 4); + assertOutputKeyValueTimestamp(testOutputTopic, "k1", "0+A+A-B-B", 4); + } + + } + + @Test + public void timeWindowAggregateManyWindowsTest() { + + final KTable, String> customers = groupedStream.cogroup(MockAggregator.TOSTRING_ADDER) + .windowedBy(TimeWindows.of(ofMillis(500L))).aggregate( + MockInitializer.STRING_INIT, Materialized.with(Serdes.String(), Serdes.String())); + 
customers.toStream().to(OUTPUT); + + try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) { + final TestInputTopic testInputTopic = driver.createInputTopic( + TOPIC, new StringSerializer(), new StringSerializer()); + final TestOutputTopic, String> testOutputTopic = driver.createOutputTopic( + OUTPUT, new TimeWindowedDeserializer<>(new StringDeserializer(), WINDOW_SIZE), new StringDeserializer()); + + testInputTopic.pipeInput("k1", "A", 0); + testInputTopic.pipeInput("k2", "A", 499); + testInputTopic.pipeInput("k2", "A", 500L); + testInputTopic.pipeInput("k1", "A", 500L); + + assertOutputKeyValueTimestamp(testOutputTopic, "k1", "0+A", 0); + assertOutputKeyValueTimestamp(testOutputTopic, "k2", "0+A", 499); + assertOutputKeyValueTimestamp(testOutputTopic, "k2", "0+A", 500); + assertOutputKeyValueTimestamp(testOutputTopic, "k1", "0+A", 500); + } + } + + @Test + public void timeWindowAggregateOverlappingWindowsTest() { + + final KTable, String> customers = groupedStream.cogroup(MockAggregator.TOSTRING_ADDER) + .windowedBy(TimeWindows.of(ofMillis(500L)).advanceBy(ofMillis(200L))).aggregate( + MockInitializer.STRING_INIT, Materialized.with(Serdes.String(), Serdes.String())); + customers.toStream().to(OUTPUT); + + try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) { + final TestInputTopic testInputTopic = driver.createInputTopic( + TOPIC, new StringSerializer(), new StringSerializer()); + final TestOutputTopic, String> testOutputTopic = driver.createOutputTopic( + OUTPUT, new TimeWindowedDeserializer<>(new StringDeserializer(), WINDOW_SIZE), new StringDeserializer()); + + testInputTopic.pipeInput("k1", "A", 0); + testInputTopic.pipeInput("k2", "A", 0); + testInputTopic.pipeInput("k1", "B", 250); + testInputTopic.pipeInput("k2", "B", 250); + testInputTopic.pipeInput("k2", "A", 500L); + testInputTopic.pipeInput("k1", "A", 500L); + + assertOutputKeyValueTimestamp(testOutputTopic, "k1", "0+A", 0); + assertOutputKeyValueTimestamp(testOutputTopic, "k2", "0+A", 0); + assertOutputKeyValueTimestamp(testOutputTopic, "k1", "0+A+B", 250); + assertOutputKeyValueTimestamp(testOutputTopic, "k1", "0+B", 250); + assertOutputKeyValueTimestamp(testOutputTopic, "k2", "0+A+B", 250); + assertOutputKeyValueTimestamp(testOutputTopic, "k2", "0+B", 250); + assertOutputKeyValueTimestamp(testOutputTopic, "k2", "0+B+A", 500); + assertOutputKeyValueTimestamp(testOutputTopic, "k2", "0+A", 500); + assertOutputKeyValueTimestamp(testOutputTopic, "k1", "0+B+A", 500); + assertOutputKeyValueTimestamp(testOutputTopic, "k1", "0+A", 500); + } + } + + @Test + public void timeWindowMixAggregatorsManyWindowsTest() { + + final KTable, String> customers = windowedCogroupedStream.aggregate( + MockInitializer.STRING_INIT, Materialized.with(Serdes.String(), Serdes.String())); + customers.toStream().to(OUTPUT); + + try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) { + final TestInputTopic testInputTopic = driver.createInputTopic( + TOPIC, new StringSerializer(), new StringSerializer()); + final TestInputTopic testInputTopic2 = driver.createInputTopic( + TOPIC2, new StringSerializer(), new StringSerializer()); + final TestOutputTopic, String> testOutputTopic = driver.createOutputTopic( + OUTPUT, new TimeWindowedDeserializer<>(new StringDeserializer(), WINDOW_SIZE), new StringDeserializer()); + + testInputTopic.pipeInput("k1", "A", 0); + testInputTopic.pipeInput("k2", "A", 0); + testInputTopic.pipeInput("k2", "A", 1); + 
testInputTopic.pipeInput("k1", "A", 2); + testInputTopic2.pipeInput("k1", "B", 3); + testInputTopic2.pipeInput("k2", "B", 3); + testInputTopic2.pipeInput("k2", "B", 501); + testInputTopic2.pipeInput("k1", "B", 501); + + assertOutputKeyValueTimestamp(testOutputTopic, "k1", "0+A", 0); + assertOutputKeyValueTimestamp(testOutputTopic, "k2", "0+A", 0); + assertOutputKeyValueTimestamp(testOutputTopic, "k2", "0+A+A", 1); + assertOutputKeyValueTimestamp(testOutputTopic, "k1", "0+A+A", 2); + assertOutputKeyValueTimestamp(testOutputTopic, "k1", "0+A+A-B", 3); + assertOutputKeyValueTimestamp(testOutputTopic, "k2", "0+A+A-B", 3); + assertOutputKeyValueTimestamp(testOutputTopic, "k2", "0-B", 501); + assertOutputKeyValueTimestamp(testOutputTopic, "k1", "0-B", 501); + } + } + + private void assertOutputKeyValueTimestamp(final TestOutputTopic, String> outputTopic, + final String expectedKey, + final String expectedValue, + final long expectedTimestamp) { + final TestRecord, String> realRecord = outputTopic.readRecord(); + final TestRecord nonWindowedRecord = new TestRecord<>( + realRecord.getKey().key(), realRecord.getValue(), null, realRecord.timestamp()); + final TestRecord testRecord = new TestRecord<>(expectedKey, expectedValue, null, expectedTimestamp); + assertThat(nonWindowedRecord, equalTo(testRecord)); + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/kstream/internals/TimeWindowedKStreamImplTest.java b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/TimeWindowedKStreamImplTest.java new file mode 100644 index 0000000..aef79c8 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/TimeWindowedKStreamImplTest.java @@ -0,0 +1,325 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.streams.kstream.internals; + +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.serialization.StringSerializer; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.TopologyTestDriver; +import org.apache.kafka.streams.kstream.Consumed; +import org.apache.kafka.streams.kstream.Grouped; +import org.apache.kafka.streams.kstream.KStream; +import org.apache.kafka.streams.kstream.Materialized; +import org.apache.kafka.streams.kstream.Named; +import org.apache.kafka.streams.kstream.TimeWindowedKStream; +import org.apache.kafka.streams.kstream.TimeWindows; +import org.apache.kafka.streams.kstream.Windowed; +import org.apache.kafka.streams.state.ValueAndTimestamp; +import org.apache.kafka.streams.state.WindowStore; +import org.apache.kafka.streams.TestInputTopic; +import org.apache.kafka.test.MockAggregator; +import org.apache.kafka.test.MockInitializer; +import org.apache.kafka.test.MockProcessorSupplier; +import org.apache.kafka.test.MockReducer; +import org.apache.kafka.test.StreamsTestUtils; +import org.junit.Before; +import org.junit.Test; + +import java.util.Arrays; +import java.util.List; +import java.util.Properties; + +import static java.time.Duration.ofMillis; +import static java.time.Instant.ofEpochMilli; +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.Assert.assertThrows; + +@SuppressWarnings("deprecation") // Old PAPI. Needs to be migrated. +public class TimeWindowedKStreamImplTest { + private static final String TOPIC = "input"; + private final StreamsBuilder builder = new StreamsBuilder(); + private final Properties props = StreamsTestUtils.getStreamsConfig(Serdes.String(), Serdes.String()); + private TimeWindowedKStream windowedStream; + + @Before + public void before() { + final KStream stream = builder.stream(TOPIC, Consumed.with(Serdes.String(), Serdes.String())); + windowedStream = stream. 
+            groupByKey(Grouped.with(Serdes.String(), Serdes.String()))
+            .windowedBy(TimeWindows.of(ofMillis(500L)));
+    }
+
+    @Test
+    public void shouldCountWindowed() {
+        final MockProcessorSupplier<Windowed<String>, Long> supplier = new MockProcessorSupplier<>();
+        windowedStream
+            .count()
+            .toStream()
+            .process(supplier);
+
+        try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) {
+            processData(driver);
+        }
+        assertThat(
+            supplier.theCapturedProcessor().lastValueAndTimestampPerKey()
+                .get(new Windowed<>("1", new TimeWindow(0L, 500L))),
+            equalTo(ValueAndTimestamp.make(2L, 15L)));
+        assertThat(
+            supplier.theCapturedProcessor().lastValueAndTimestampPerKey()
+                .get(new Windowed<>("2", new TimeWindow(500L, 1000L))),
+            equalTo(ValueAndTimestamp.make(2L, 550L)));
+        assertThat(
+            supplier.theCapturedProcessor().lastValueAndTimestampPerKey()
+                .get(new Windowed<>("1", new TimeWindow(500L, 1000L))),
+            equalTo(ValueAndTimestamp.make(1L, 500L)));
+    }
+
+    @Test
+    public void shouldReduceWindowed() {
+        final MockProcessorSupplier<Windowed<String>, String> supplier = new MockProcessorSupplier<>();
+        windowedStream
+            .reduce(MockReducer.STRING_ADDER)
+            .toStream()
+            .process(supplier);
+
+        try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) {
+            processData(driver);
+        }
+        assertThat(
+            supplier.theCapturedProcessor().lastValueAndTimestampPerKey()
+                .get(new Windowed<>("1", new TimeWindow(0L, 500L))),
+            equalTo(ValueAndTimestamp.make("1+2", 15L)));
+        assertThat(
+            supplier.theCapturedProcessor().lastValueAndTimestampPerKey()
+                .get(new Windowed<>("2", new TimeWindow(500L, 1000L))),
+            equalTo(ValueAndTimestamp.make("10+20", 550L)));
+        assertThat(
+            supplier.theCapturedProcessor().lastValueAndTimestampPerKey()
+                .get(new Windowed<>("1", new TimeWindow(500L, 1000L))),
+            equalTo(ValueAndTimestamp.make("3", 500L)));
+    }
+
+    @Test
+    public void shouldAggregateWindowed() {
+        final MockProcessorSupplier<Windowed<String>, String> supplier = new MockProcessorSupplier<>();
+        windowedStream
+            .aggregate(
+                MockInitializer.STRING_INIT,
+                MockAggregator.TOSTRING_ADDER,
+                Materialized.with(Serdes.String(), Serdes.String()))
+            .toStream()
+            .process(supplier);
+
+        try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) {
+            processData(driver);
+        }
+        assertThat(
+            supplier.theCapturedProcessor().lastValueAndTimestampPerKey()
+                .get(new Windowed<>("1", new TimeWindow(0L, 500L))),
+            equalTo(ValueAndTimestamp.make("0+1+2", 15L)));
+        assertThat(
+            supplier.theCapturedProcessor().lastValueAndTimestampPerKey()
+                .get(new Windowed<>("2", new TimeWindow(500L, 1000L))),
+            equalTo(ValueAndTimestamp.make("0+10+20", 550L)));
+        assertThat(
+            supplier.theCapturedProcessor().lastValueAndTimestampPerKey()
+                .get(new Windowed<>("1", new TimeWindow(500L, 1000L))),
+            equalTo(ValueAndTimestamp.make("0+3", 500L)));
+    }
+
+    @Test
+    public void shouldMaterializeCount() {
+        windowedStream.count(
+            Materialized.<String, Long, WindowStore<Bytes, byte[]>>as("count-store")
+                .withKeySerde(Serdes.String())
+                .withValueSerde(Serdes.Long()));
+
+        try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) {
+            processData(driver);
+            {
+                final WindowStore<String, Long> windowStore = driver.getWindowStore("count-store");
+                final List<KeyValue<Windowed<String>, Long>> data =
+                    StreamsTestUtils.toList(windowStore.fetch("1", "2", ofEpochMilli(0), ofEpochMilli(1000L)));
+
+                assertThat(data, equalTo(Arrays.asList(
+                    KeyValue.pair(new Windowed<>("1", new TimeWindow(0, 500)), 2L),
+                    KeyValue.pair(new Windowed<>("1", new TimeWindow(500, 1000)), 1L),
+                    KeyValue.pair(new
Windowed<>("2", new TimeWindow(500, 1000)), 2L)))); + } + { + final WindowStore> windowStore = + driver.getTimestampedWindowStore("count-store"); + final List, ValueAndTimestamp>> data = + StreamsTestUtils.toList(windowStore.fetch("1", "2", ofEpochMilli(0), ofEpochMilli(1000L))); + + assertThat(data, equalTo(Arrays.asList( + KeyValue.pair(new Windowed<>("1", new TimeWindow(0, 500)), ValueAndTimestamp.make(2L, 15L)), + KeyValue.pair(new Windowed<>("1", new TimeWindow(500, 1000)), ValueAndTimestamp.make(1L, 500L)), + KeyValue.pair(new Windowed<>("2", new TimeWindow(500, 1000)), ValueAndTimestamp.make(2L, 550L))))); + } + } + } + + @Test + public void shouldMaterializeReduced() { + windowedStream.reduce( + MockReducer.STRING_ADDER, + Materialized.>as("reduced") + .withKeySerde(Serdes.String()) + .withValueSerde(Serdes.String())); + + try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) { + processData(driver); + { + final WindowStore windowStore = driver.getWindowStore("reduced"); + final List, String>> data = + StreamsTestUtils.toList(windowStore.fetch("1", "2", ofEpochMilli(0), ofEpochMilli(1000L))); + + assertThat(data, equalTo(Arrays.asList( + KeyValue.pair(new Windowed<>("1", new TimeWindow(0, 500)), "1+2"), + KeyValue.pair(new Windowed<>("1", new TimeWindow(500, 1000)), "3"), + KeyValue.pair(new Windowed<>("2", new TimeWindow(500, 1000)), "10+20")))); + } + { + final WindowStore> windowStore = driver.getTimestampedWindowStore("reduced"); + final List, ValueAndTimestamp>> data = + StreamsTestUtils.toList(windowStore.fetch("1", "2", ofEpochMilli(0), ofEpochMilli(1000L))); + + assertThat(data, equalTo(Arrays.asList( + KeyValue.pair(new Windowed<>("1", new TimeWindow(0, 500)), ValueAndTimestamp.make("1+2", 15L)), + KeyValue.pair(new Windowed<>("1", new TimeWindow(500, 1000)), ValueAndTimestamp.make("3", 500L)), + KeyValue.pair(new Windowed<>("2", new TimeWindow(500, 1000)), ValueAndTimestamp.make("10+20", 550L))))); + } + } + } + + @Test + public void shouldMaterializeAggregated() { + windowedStream.aggregate( + MockInitializer.STRING_INIT, + MockAggregator.TOSTRING_ADDER, + Materialized.>as("aggregated") + .withKeySerde(Serdes.String()) + .withValueSerde(Serdes.String())); + + try (final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), props)) { + processData(driver); + { + final WindowStore windowStore = driver.getWindowStore("aggregated"); + final List, String>> data = + StreamsTestUtils.toList(windowStore.fetch("1", "2", ofEpochMilli(0), ofEpochMilli(1000L))); + + assertThat(data, equalTo(Arrays.asList( + KeyValue.pair(new Windowed<>("1", new TimeWindow(0, 500)), "0+1+2"), + KeyValue.pair(new Windowed<>("1", new TimeWindow(500, 1000)), "0+3"), + KeyValue.pair(new Windowed<>("2", new TimeWindow(500, 1000)), "0+10+20")))); + } + { + final WindowStore> windowStore = driver.getTimestampedWindowStore("aggregated"); + final List, ValueAndTimestamp>> data = + StreamsTestUtils.toList(windowStore.fetch("1", "2", ofEpochMilli(0), ofEpochMilli(1000L))); + + assertThat(data, equalTo(Arrays.asList( + KeyValue.pair(new Windowed<>("1", new TimeWindow(0, 500)), ValueAndTimestamp.make("0+1+2", 15L)), + KeyValue.pair(new Windowed<>("1", new TimeWindow(500, 1000)), ValueAndTimestamp.make("0+3", 500L)), + KeyValue.pair(new Windowed<>("2", new TimeWindow(500, 1000)), ValueAndTimestamp.make("0+10+20", 550L))))); + } + } + } + + @Test + public void shouldThrowNullPointerOnAggregateIfInitializerIsNull() { + assertThrows(NullPointerException.class, () -> 
windowedStream.aggregate(null, MockAggregator.TOSTRING_ADDER)); + } + + @Test + public void shouldThrowNullPointerOnAggregateIfAggregatorIsNull() { + assertThrows(NullPointerException.class, () -> windowedStream.aggregate(MockInitializer.STRING_INIT, null)); + } + + @Test + public void shouldThrowNullPointerOnReduceIfReducerIsNull() { + assertThrows(NullPointerException.class, () -> windowedStream.reduce(null)); + } + + @Test + public void shouldThrowNullPointerOnMaterializedAggregateIfInitializerIsNull() { + assertThrows(NullPointerException.class, () -> windowedStream.aggregate( + null, + MockAggregator.TOSTRING_ADDER, + Materialized.as("store"))); + } + + @Test + public void shouldThrowNullPointerOnMaterializedAggregateIfAggregatorIsNull() { + assertThrows(NullPointerException.class, () -> windowedStream.aggregate( + MockInitializer.STRING_INIT, + null, + Materialized.as("store"))); + } + + @SuppressWarnings("unchecked") + @Test + public void shouldThrowNullPointerOnMaterializedAggregateIfMaterializedIsNull() { + assertThrows(NullPointerException.class, () -> windowedStream.aggregate( + MockInitializer.STRING_INIT, + MockAggregator.TOSTRING_ADDER, + (Materialized) null)); + } + + @Test + public void shouldThrowNullPointerOnMaterializedReduceIfReducerIsNull() { + assertThrows(NullPointerException.class, () -> windowedStream.reduce( + null, + Materialized.as("store"))); + } + + @Test + @SuppressWarnings("unchecked") + public void shouldThrowNullPointerOnMaterializedReduceIfMaterializedIsNull() { + assertThrows(NullPointerException.class, () -> windowedStream.reduce( + MockReducer.STRING_ADDER, + (Materialized) null)); + } + + @Test + public void shouldThrowNullPointerOnMaterializedReduceIfNamedIsNull() { + assertThrows(NullPointerException.class, () -> windowedStream.reduce( + MockReducer.STRING_ADDER, + (Named) null)); + } + + @Test + public void shouldThrowNullPointerOnCountIfMaterializedIsNull() { + assertThrows(NullPointerException.class, () -> windowedStream.count((Materialized>) null)); + } + + private void processData(final TopologyTestDriver driver) { + final TestInputTopic inputTopic = + driver.createInputTopic(TOPIC, new StringSerializer(), new StringSerializer()); + inputTopic.pipeInput("1", "1", 10L); + inputTopic.pipeInput("1", "2", 15L); + inputTopic.pipeInput("1", "3", 500L); + inputTopic.pipeInput("2", "10", 550L); + inputTopic.pipeInput("2", "20", 500L); + } + +} diff --git a/streams/src/test/java/org/apache/kafka/streams/kstream/internals/TimestampedCacheFlushListenerTest.java b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/TimestampedCacheFlushListenerTest.java new file mode 100644 index 0000000..7c25b2e --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/TimestampedCacheFlushListenerTest.java @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.kstream.internals; + +import org.apache.kafka.streams.processor.To; +import org.apache.kafka.streams.processor.api.ProcessorContext; +import org.apache.kafka.streams.processor.internals.InternalProcessorContext; +import org.apache.kafka.streams.state.ValueAndTimestamp; +import org.junit.Test; + +import static org.easymock.EasyMock.expect; +import static org.easymock.EasyMock.expectLastCall; +import static org.easymock.EasyMock.mock; +import static org.easymock.EasyMock.replay; +import static org.easymock.EasyMock.verify; + +public class TimestampedCacheFlushListenerTest { + + @Test + public void shouldForwardValueTimestampIfNewValueExists() { + final InternalProcessorContext> context = mock(InternalProcessorContext.class); + expect(context.currentNode()).andReturn(null).anyTimes(); + context.setCurrentNode(null); + context.setCurrentNode(null); + context.forward( + "key", + new Change<>("newValue", "oldValue"), + To.all().withTimestamp(42L)); + expectLastCall(); + replay(context); + + new TimestampedCacheFlushListener<>((ProcessorContext>) context).apply( + "key", + ValueAndTimestamp.make("newValue", 42L), + ValueAndTimestamp.make("oldValue", 21L), + 73L); + + verify(context); + } + + @Test + public void shouldForwardParameterTimestampIfNewValueIsNull() { + final InternalProcessorContext> context = mock(InternalProcessorContext.class); + expect(context.currentNode()).andReturn(null).anyTimes(); + context.setCurrentNode(null); + context.setCurrentNode(null); + context.forward( + "key", + new Change<>(null, "oldValue"), + To.all().withTimestamp(73L)); + expectLastCall(); + replay(context); + + new TimestampedCacheFlushListener<>((ProcessorContext>) context).apply( + "key", + null, + ValueAndTimestamp.make("oldValue", 21L), + 73L); + + verify(context); + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/kstream/internals/TimestampedTupleForwarderTest.java b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/TimestampedTupleForwarderTest.java new file mode 100644 index 0000000..89b732e --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/TimestampedTupleForwarderTest.java @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.kstream.internals; + +import org.apache.kafka.streams.processor.StateStore; +import org.apache.kafka.streams.processor.To; +import org.apache.kafka.streams.processor.internals.InternalProcessorContext; +import org.apache.kafka.streams.state.ValueAndTimestamp; +import org.apache.kafka.streams.state.internals.WrappedStateStore; +import org.junit.Test; + +import static org.easymock.EasyMock.expect; +import static org.easymock.EasyMock.expectLastCall; +import static org.easymock.EasyMock.mock; +import static org.easymock.EasyMock.replay; +import static org.easymock.EasyMock.verify; + +public class TimestampedTupleForwarderTest { + + @Test + public void shouldSetFlushListenerOnWrappedStateStore() { + setFlushListener(true); + setFlushListener(false); + } + + private void setFlushListener(final boolean sendOldValues) { + final WrappedStateStore> store = mock(WrappedStateStore.class); + final TimestampedCacheFlushListener flushListener = mock(TimestampedCacheFlushListener.class); + + expect(store.setFlushListener(flushListener, sendOldValues)).andReturn(false); + replay(store); + + new TimestampedTupleForwarder<>( + store, + (org.apache.kafka.streams.processor.api.ProcessorContext>) null, + flushListener, + sendOldValues + ); + + verify(store); + } + + @Test + public void shouldForwardRecordsIfWrappedStateStoreDoesNotCache() { + shouldForwardRecordsIfWrappedStateStoreDoesNotCache(false); + shouldForwardRecordsIfWrappedStateStoreDoesNotCache(true); + } + + private void shouldForwardRecordsIfWrappedStateStoreDoesNotCache(final boolean sendOldValues) { + final WrappedStateStore store = mock(WrappedStateStore.class); + final InternalProcessorContext> context = mock(InternalProcessorContext.class); + + expect(store.setFlushListener(null, sendOldValues)).andReturn(false); + if (sendOldValues) { + context.forward("key1", new Change<>("newValue1", "oldValue1")); + context.forward("key2", new Change<>("newValue2", "oldValue2"), To.all().withTimestamp(42L)); + } else { + context.forward("key1", new Change<>("newValue1", null)); + context.forward("key2", new Change<>("newValue2", null), To.all().withTimestamp(42L)); + } + expectLastCall(); + replay(store, context); + + final TimestampedTupleForwarder forwarder = + new TimestampedTupleForwarder<>( + store, + (org.apache.kafka.streams.processor.api.ProcessorContext>) context, + null, + sendOldValues + ); + forwarder.maybeForward("key1", "newValue1", "oldValue1"); + forwarder.maybeForward("key2", "newValue2", "oldValue2", 42L); + + verify(store, context); + } + + @Test + public void shouldNotForwardRecordsIfWrappedStateStoreDoesCache() { + final WrappedStateStore store = mock(WrappedStateStore.class); + final InternalProcessorContext> context = mock(InternalProcessorContext.class); + + expect(store.setFlushListener(null, false)).andReturn(true); + replay(store, context); + + final TimestampedTupleForwarder forwarder = + new TimestampedTupleForwarder<>( + store, + (org.apache.kafka.streams.processor.api.ProcessorContext>) context, + null, + false + ); + forwarder.maybeForward("key", "newValue", "oldValue"); + forwarder.maybeForward("key", "newValue", "oldValue", 42L); + + verify(store, context); + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/kstream/internals/TransformerSupplierAdapterTest.java b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/TransformerSupplierAdapterTest.java new file mode 100644 index 0000000..1eb55d0 --- /dev/null +++ 
b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/TransformerSupplierAdapterTest.java @@ -0,0 +1,154 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.kstream.internals; + +import java.util.Iterator; +import java.util.Set; + +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.kstream.Transformer; +import org.apache.kafka.streams.kstream.TransformerSupplier; +import org.apache.kafka.streams.processor.ProcessorContext; +import org.apache.kafka.streams.state.StoreBuilder; +import org.easymock.EasyMock; +import org.easymock.EasyMockSupport; +import org.junit.Before; +import org.junit.Test; + +import static org.hamcrest.core.IsEqual.equalTo; +import static org.hamcrest.core.IsSame.sameInstance; +import static org.hamcrest.core.IsNot.not; +import static org.hamcrest.MatcherAssert.assertThat; + +public class TransformerSupplierAdapterTest extends EasyMockSupport { + + private ProcessorContext context; + private Transformer> transformer; + private TransformerSupplier> transformerSupplier; + private Set> stores; + + final String key = "Hello"; + final String value = "World"; + + @Before + public void before() { + context = mock(ProcessorContext.class); + transformer = mock(Transformer.class); + transformerSupplier = mock(TransformerSupplier.class); + stores = mock(Set.class); + } + + @Test + public void shouldCallInitOfAdapteeTransformer() { + EasyMock.expect(transformerSupplier.get()).andReturn(transformer); + transformer.init(context); + replayAll(); + + final TransformerSupplierAdapter adapter = + new TransformerSupplierAdapter<>(transformerSupplier); + final Transformer>> adaptedTransformer = adapter.get(); + adaptedTransformer.init(context); + + verifyAll(); + } + + @Test + public void shouldCallCloseOfAdapteeTransformer() { + EasyMock.expect(transformerSupplier.get()).andReturn(transformer); + transformer.close(); + replayAll(); + + final TransformerSupplierAdapter adapter = + new TransformerSupplierAdapter<>(transformerSupplier); + final Transformer>> adaptedTransformer = adapter.get(); + adaptedTransformer.close(); + + verifyAll(); + } + + @Test + public void shouldCallStoresOfAdapteeTransformerSupplier() { + EasyMock.expect(transformerSupplier.stores()).andReturn(stores); + replayAll(); + + final TransformerSupplierAdapter adapter = + new TransformerSupplierAdapter<>(transformerSupplier); + adapter.stores(); + verifyAll(); + } + + @Test + public void shouldCallTransformOfAdapteeTransformerAndReturnSingletonIterable() { + EasyMock.expect(transformerSupplier.get()).andReturn(transformer); + EasyMock.expect(transformer.transform(key, value)).andReturn(KeyValue.pair(0, 1)); + replayAll(); + + final TransformerSupplierAdapter adapter = + new 
TransformerSupplierAdapter<>(transformerSupplier);
+        final Transformer<String, String, Iterable<KeyValue<Integer, Integer>>> adaptedTransformer = adapter.get();
+        final Iterator<KeyValue<Integer, Integer>> iterator = adaptedTransformer.transform(key, value).iterator();
+
+        verifyAll();
+        assertThat(iterator.hasNext(), equalTo(true));
+        iterator.next();
+        assertThat(iterator.hasNext(), equalTo(false));
+    }
+
+    @Test
+    public void shouldCallTransformOfAdapteeTransformerAndReturnEmptyIterable() {
+        EasyMock.expect(transformerSupplier.get()).andReturn(transformer);
+        EasyMock.expect(transformer.transform(key, value)).andReturn(null);
+        replayAll();
+
+        final TransformerSupplierAdapter<String, String, Integer, Integer> adapter =
+            new TransformerSupplierAdapter<>(transformerSupplier);
+        final Transformer<String, String, Iterable<KeyValue<Integer, Integer>>> adaptedTransformer = adapter.get();
+        final Iterator<KeyValue<Integer, Integer>> iterator = adaptedTransformer.transform(key, value).iterator();
+
+        verifyAll();
+        assertThat(iterator.hasNext(), equalTo(false));
+    }
+
+    @Test
+    public void shouldAlwaysGetNewAdapterTransformer() {
+        final Transformer<String, String, KeyValue<Integer, Integer>> transformer1 = mock(Transformer.class);
+        final Transformer<String, String, KeyValue<Integer, Integer>> transformer2 = mock(Transformer.class);
+        final Transformer<String, String, KeyValue<Integer, Integer>> transformer3 = mock(Transformer.class);
+        EasyMock.expect(transformerSupplier.get()).andReturn(transformer1);
+        transformer1.init(context);
+        EasyMock.expect(transformerSupplier.get()).andReturn(transformer2);
+        transformer2.init(context);
+        EasyMock.expect(transformerSupplier.get()).andReturn(transformer3);
+        transformer3.init(context);
+        replayAll();
+
+        final TransformerSupplierAdapter<String, String, Integer, Integer> adapter =
+            new TransformerSupplierAdapter<>(transformerSupplier);
+        final Transformer<String, String, Iterable<KeyValue<Integer, Integer>>> adapterTransformer1 = adapter.get();
+        adapterTransformer1.init(context);
+        final Transformer<String, String, Iterable<KeyValue<Integer, Integer>>> adapterTransformer2 = adapter.get();
+        adapterTransformer2.init(context);
+        final Transformer<String, String, Iterable<KeyValue<Integer, Integer>>> adapterTransformer3 = adapter.get();
+        adapterTransformer3.init(context);
+
+        verifyAll();
+        assertThat(adapterTransformer1, not(sameInstance(adapterTransformer2)));
+        assertThat(adapterTransformer2, not(sameInstance(adapterTransformer3)));
+        assertThat(adapterTransformer3, not(sameInstance(adapterTransformer1)));
+    }
+
+}
diff --git a/streams/src/test/java/org/apache/kafka/streams/kstream/internals/UnlimitedWindowTest.java b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/UnlimitedWindowTest.java
new file mode 100644
index 0000000..f8e5731
--- /dev/null
+++ b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/UnlimitedWindowTest.java
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +package org.apache.kafka.streams.kstream.internals; + +import org.junit.Test; + +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; + +public class UnlimitedWindowTest { + + private long start = 50; + private final UnlimitedWindow window = new UnlimitedWindow(start); + private final SessionWindow sessionWindow = new SessionWindow(start, start); + + @Test + public void shouldAlwaysOverlap() { + assertTrue(window.overlap(new UnlimitedWindow(start - 1))); + assertTrue(window.overlap(new UnlimitedWindow(start))); + assertTrue(window.overlap(new UnlimitedWindow(start + 1))); + } + + @Test + public void cannotCompareUnlimitedWindowWithDifferentWindowType() { + assertThrows(IllegalArgumentException.class, () -> window.overlap(sessionWindow)); + } +} \ No newline at end of file diff --git a/streams/src/test/java/org/apache/kafka/streams/kstream/internals/WindowedStreamPartitionerTest.java b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/WindowedStreamPartitionerTest.java new file mode 100644 index 0000000..adebf16 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/WindowedStreamPartitionerTest.java @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.kstream.internals; + +import org.apache.kafka.clients.producer.internals.DefaultPartitioner; +import org.apache.kafka.common.Cluster; +import org.apache.kafka.common.Node; +import org.apache.kafka.common.PartitionInfo; +import org.apache.kafka.common.serialization.IntegerSerializer; +import org.apache.kafka.common.serialization.StringSerializer; +import org.apache.kafka.streams.kstream.TimeWindowedSerializer; +import org.apache.kafka.streams.kstream.Windowed; +import org.junit.Test; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Random; + +import static org.junit.Assert.assertEquals; + +public class WindowedStreamPartitionerTest { + + private String topicName = "topic"; + + private IntegerSerializer intSerializer = new IntegerSerializer(); + private StringSerializer stringSerializer = new StringSerializer(); + + private List infos = Arrays.asList( + new PartitionInfo(topicName, 0, Node.noNode(), new Node[0], new Node[0]), + new PartitionInfo(topicName, 1, Node.noNode(), new Node[0], new Node[0]), + new PartitionInfo(topicName, 2, Node.noNode(), new Node[0], new Node[0]), + new PartitionInfo(topicName, 3, Node.noNode(), new Node[0], new Node[0]), + new PartitionInfo(topicName, 4, Node.noNode(), new Node[0], new Node[0]), + new PartitionInfo(topicName, 5, Node.noNode(), new Node[0], new Node[0]) + ); + + private Cluster cluster = new Cluster("cluster", Collections.singletonList(Node.noNode()), infos, + Collections.emptySet(), Collections.emptySet()); + + @Test + public void testCopartitioning() { + final Random rand = new Random(); + final DefaultPartitioner defaultPartitioner = new DefaultPartitioner(); + final WindowedSerializer timeWindowedSerializer = new TimeWindowedSerializer<>(intSerializer); + final WindowedStreamPartitioner streamPartitioner = new WindowedStreamPartitioner<>(timeWindowedSerializer); + + for (int k = 0; k < 10; k++) { + final Integer key = rand.nextInt(); + final byte[] keyBytes = intSerializer.serialize(topicName, key); + + final String value = key.toString(); + final byte[] valueBytes = stringSerializer.serialize(topicName, value); + + final Integer expected = defaultPartitioner.partition("topic", key, keyBytes, value, valueBytes, cluster); + + for (int w = 1; w < 10; w++) { + final TimeWindow window = new TimeWindow(10 * w, 20 * w); + + final Windowed windowedKey = new Windowed<>(key, window); + final Integer actual = streamPartitioner.partition(topicName, windowedKey, value, infos.size()); + + assertEquals(expected, actual); + } + } + + defaultPartitioner.close(); + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/kstream/internals/foreignkeyjoin/CombinedKeySchemaTest.java b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/foreignkeyjoin/CombinedKeySchemaTest.java new file mode 100644 index 0000000..eb2c57e --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/foreignkeyjoin/CombinedKeySchemaTest.java @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.streams.kstream.internals.foreignkeyjoin;
+
+import org.apache.kafka.common.serialization.Serdes;
+import org.apache.kafka.common.utils.Bytes;
+import org.junit.Test;
+
+import java.nio.ByteBuffer;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertThrows;
+
+public class CombinedKeySchemaTest {
+
+    @Test
+    public void nonNullPrimaryKeySerdeTest() {
+        final CombinedKeySchema<String, Integer> cks = new CombinedKeySchema<>(
+            () -> "fkTopic", Serdes.String(),
+            () -> "pkTopic", Serdes.Integer()
+        );
+        final Integer primary = -999;
+        final Bytes result = cks.toBytes("foreignKey", primary);
+
+        final CombinedKey<String, Integer> deserializedKey = cks.fromBytes(result);
+        assertEquals("foreignKey", deserializedKey.getForeignKey());
+        assertEquals(primary, deserializedKey.getPrimaryKey());
+    }
+
+    @Test
+    public void nullPrimaryKeySerdeTest() {
+        final CombinedKeySchema<String, Integer> cks = new CombinedKeySchema<>(
+            () -> "fkTopic", Serdes.String(),
+            () -> "pkTopic", Serdes.Integer()
+        );
+        assertThrows(NullPointerException.class, () -> cks.toBytes("foreignKey", null));
+    }
+
+    @Test
+    public void nullForeignKeySerdeTest() {
+        final CombinedKeySchema<String, Integer> cks = new CombinedKeySchema<>(
+            () -> "fkTopic", Serdes.String(),
+            () -> "pkTopic", Serdes.Integer()
+        );
+        assertThrows(NullPointerException.class, () -> cks.toBytes(null, 10));
+    }
+
+    @Test
+    public void prefixKeySerdeTest() {
+        final CombinedKeySchema<String, Integer> cks = new CombinedKeySchema<>(
+            () -> "fkTopic", Serdes.String(),
+            () -> "pkTopic", Serdes.Integer()
+        );
+        final String foreignKey = "someForeignKey";
+        final byte[] foreignKeySerializedData =
+            Serdes.String().serializer().serialize("fkTopic", foreignKey);
+        final Bytes prefix = cks.prefixBytes(foreignKey);
+
+        final ByteBuffer buf = ByteBuffer.allocate(Integer.BYTES + foreignKeySerializedData.length);
+        buf.putInt(foreignKeySerializedData.length);
+        buf.put(foreignKeySerializedData);
+        final Bytes expectedPrefixBytes = Bytes.wrap(buf.array());
+
+        assertEquals(expectedPrefixBytes, prefix);
+    }
+
+    @Test
+    public void nullPrefixKeySerdeTest() {
+        final CombinedKeySchema<String, Integer> cks = new CombinedKeySchema<>(
+            () -> "fkTopic", Serdes.String(),
+            () -> "pkTopic", Serdes.Integer()
+        );
+        final String foreignKey = null;
+        assertThrows(NullPointerException.class, () -> cks.prefixBytes(foreignKey));
+    }
+}
diff --git a/streams/src/test/java/org/apache/kafka/streams/kstream/internals/foreignkeyjoin/SubscriptionResolverJoinProcessorSupplierTest.java b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/foreignkeyjoin/SubscriptionResolverJoinProcessorSupplierTest.java
new file mode 100644
index 0000000..4a379d6
--- /dev/null
+++ b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/foreignkeyjoin/SubscriptionResolverJoinProcessorSupplierTest.java
@@ -0,0 +1,226 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.streams.kstream.internals.foreignkeyjoin; + +import org.apache.kafka.common.header.internals.RecordHeaders; +import org.apache.kafka.common.serialization.StringSerializer; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.kstream.ValueJoiner; +import org.apache.kafka.streams.kstream.internals.KTableValueGetter; +import org.apache.kafka.streams.kstream.internals.KTableValueGetterSupplier; +import org.apache.kafka.streams.processor.MockProcessorContext; +import org.apache.kafka.streams.processor.ProcessorContext; +import org.apache.kafka.streams.state.ValueAndTimestamp; +import org.apache.kafka.streams.state.internals.Murmur3; +import org.junit.Test; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.collection.IsEmptyCollection.empty; + +@SuppressWarnings("deprecation") // Old PAPI. Needs to be migrated. +public class SubscriptionResolverJoinProcessorSupplierTest { + private static final StringSerializer STRING_SERIALIZER = new StringSerializer(); + private static final ValueJoiner JOINER = + (value1, value2) -> "(" + value1 + "," + value2 + ")"; + + private static class TestKTableValueGetterSupplier implements KTableValueGetterSupplier { + private final Map map = new HashMap<>(); + + @Override + public KTableValueGetter get() { + return new KTableValueGetter() { + @Override + public void init(final ProcessorContext context) { + } + + @Override + public ValueAndTimestamp get(final K key) { + return ValueAndTimestamp.make(map.get(key), -1); + } + }; + } + + @Override + public String[] storeNames() { + return new String[0]; + } + + void put(final K key, final V value) { + map.put(key, value); + } + } + + @Test + public void shouldNotForwardWhenHashDoesNotMatch() { + final TestKTableValueGetterSupplier valueGetterSupplier = + new TestKTableValueGetterSupplier<>(); + final boolean leftJoin = false; + final SubscriptionResolverJoinProcessorSupplier processorSupplier = + new SubscriptionResolverJoinProcessorSupplier<>( + valueGetterSupplier, + STRING_SERIALIZER, + () -> "value-hash-dummy-topic", + JOINER, + leftJoin + ); + final org.apache.kafka.streams.processor.Processor> processor = processorSupplier.get(); + final MockProcessorContext context = new MockProcessorContext(); + processor.init(context); + context.setRecordMetadata("topic", 0, 0, new RecordHeaders(), 0); + + valueGetterSupplier.put("lhs1", "lhsValue"); + final long[] oldHash = Murmur3.hash128(STRING_SERIALIZER.serialize("topic-join-resolver", "oldLhsValue")); + processor.process("lhs1", new SubscriptionResponseWrapper<>(oldHash, "rhsValue")); + final List forwarded = context.forwarded(); + assertThat(forwarded, empty()); + } + + @Test + public void shouldIgnoreUpdateWhenLeftHasBecomeNull() { + final TestKTableValueGetterSupplier 
valueGetterSupplier = + new TestKTableValueGetterSupplier<>(); + final boolean leftJoin = false; + final SubscriptionResolverJoinProcessorSupplier processorSupplier = + new SubscriptionResolverJoinProcessorSupplier<>( + valueGetterSupplier, + STRING_SERIALIZER, + () -> "value-hash-dummy-topic", + JOINER, + leftJoin + ); + final org.apache.kafka.streams.processor.Processor> processor = processorSupplier.get(); + final MockProcessorContext context = new MockProcessorContext(); + processor.init(context); + context.setRecordMetadata("topic", 0, 0, new RecordHeaders(), 0); + + valueGetterSupplier.put("lhs1", null); + final long[] hash = Murmur3.hash128(STRING_SERIALIZER.serialize("topic-join-resolver", "lhsValue")); + processor.process("lhs1", new SubscriptionResponseWrapper<>(hash, "rhsValue")); + final List forwarded = context.forwarded(); + assertThat(forwarded, empty()); + } + + @Test + public void shouldForwardWhenHashMatches() { + final TestKTableValueGetterSupplier valueGetterSupplier = + new TestKTableValueGetterSupplier<>(); + final boolean leftJoin = false; + final SubscriptionResolverJoinProcessorSupplier processorSupplier = + new SubscriptionResolverJoinProcessorSupplier<>( + valueGetterSupplier, + STRING_SERIALIZER, + () -> "value-hash-dummy-topic", + JOINER, + leftJoin + ); + final org.apache.kafka.streams.processor.Processor> processor = processorSupplier.get(); + final MockProcessorContext context = new MockProcessorContext(); + processor.init(context); + context.setRecordMetadata("topic", 0, 0, new RecordHeaders(), 0); + + valueGetterSupplier.put("lhs1", "lhsValue"); + final long[] hash = Murmur3.hash128(STRING_SERIALIZER.serialize("topic-join-resolver", "lhsValue")); + processor.process("lhs1", new SubscriptionResponseWrapper<>(hash, "rhsValue")); + final List forwarded = context.forwarded(); + assertThat(forwarded.size(), is(1)); + assertThat(forwarded.get(0).keyValue(), is(new KeyValue<>("lhs1", "(lhsValue,rhsValue)"))); + } + + @Test + public void shouldEmitTombstoneForInnerJoinWhenRightIsNull() { + final TestKTableValueGetterSupplier valueGetterSupplier = + new TestKTableValueGetterSupplier<>(); + final boolean leftJoin = false; + final SubscriptionResolverJoinProcessorSupplier processorSupplier = + new SubscriptionResolverJoinProcessorSupplier<>( + valueGetterSupplier, + STRING_SERIALIZER, + () -> "value-hash-dummy-topic", + JOINER, + leftJoin + ); + final org.apache.kafka.streams.processor.Processor> processor = processorSupplier.get(); + final MockProcessorContext context = new MockProcessorContext(); + processor.init(context); + context.setRecordMetadata("topic", 0, 0, new RecordHeaders(), 0); + + valueGetterSupplier.put("lhs1", "lhsValue"); + final long[] hash = Murmur3.hash128(STRING_SERIALIZER.serialize("topic-join-resolver", "lhsValue")); + processor.process("lhs1", new SubscriptionResponseWrapper<>(hash, null)); + final List forwarded = context.forwarded(); + assertThat(forwarded.size(), is(1)); + assertThat(forwarded.get(0).keyValue(), is(new KeyValue<>("lhs1", null))); + } + + @Test + public void shouldEmitResultForLeftJoinWhenRightIsNull() { + final TestKTableValueGetterSupplier valueGetterSupplier = + new TestKTableValueGetterSupplier<>(); + final boolean leftJoin = true; + final SubscriptionResolverJoinProcessorSupplier processorSupplier = + new SubscriptionResolverJoinProcessorSupplier<>( + valueGetterSupplier, + STRING_SERIALIZER, + () -> "value-hash-dummy-topic", + JOINER, + leftJoin + ); + final org.apache.kafka.streams.processor.Processor> processor = 
processorSupplier.get(); + final MockProcessorContext context = new MockProcessorContext(); + processor.init(context); + context.setRecordMetadata("topic", 0, 0, new RecordHeaders(), 0); + + valueGetterSupplier.put("lhs1", "lhsValue"); + final long[] hash = Murmur3.hash128(STRING_SERIALIZER.serialize("topic-join-resolver", "lhsValue")); + processor.process("lhs1", new SubscriptionResponseWrapper<>(hash, null)); + final List forwarded = context.forwarded(); + assertThat(forwarded.size(), is(1)); + assertThat(forwarded.get(0).keyValue(), is(new KeyValue<>("lhs1", "(lhsValue,null)"))); + } + + @Test + public void shouldEmitTombstoneForLeftJoinWhenRightIsNullAndLeftIsNull() { + final TestKTableValueGetterSupplier valueGetterSupplier = + new TestKTableValueGetterSupplier<>(); + final boolean leftJoin = true; + final SubscriptionResolverJoinProcessorSupplier processorSupplier = + new SubscriptionResolverJoinProcessorSupplier<>( + valueGetterSupplier, + STRING_SERIALIZER, + () -> "value-hash-dummy-topic", + JOINER, + leftJoin + ); + final org.apache.kafka.streams.processor.Processor> processor = processorSupplier.get(); + final MockProcessorContext context = new MockProcessorContext(); + processor.init(context); + context.setRecordMetadata("topic", 0, 0, new RecordHeaders(), 0); + + valueGetterSupplier.put("lhs1", null); + final long[] hash = null; + processor.process("lhs1", new SubscriptionResponseWrapper<>(hash, null)); + final List forwarded = context.forwarded(); + assertThat(forwarded.size(), is(1)); + assertThat(forwarded.get(0).keyValue(), is(new KeyValue<>("lhs1", null))); + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/kstream/internals/foreignkeyjoin/SubscriptionResponseWrapperSerdeTest.java b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/foreignkeyjoin/SubscriptionResponseWrapperSerdeTest.java new file mode 100644 index 0000000..30fc0c3 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/foreignkeyjoin/SubscriptionResponseWrapperSerdeTest.java @@ -0,0 +1,135 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
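The resolver tests above all exercise a single gate: the processor looks up the current left-hand-side value, hashes it with Murmur3.hash128, and only forwards a join result when that hash equals the hash carried in the SubscriptionResponseWrapper; a mismatch means the LHS changed while the subscription round-trip was in flight, so the stale response is dropped. A self-contained sketch of that comparison, assuming the processor compares the two long[] hashes for equality (the class below is illustrative, not the actual SubscriptionResolverJoinProcessorSupplier):

import org.apache.kafka.common.serialization.StringSerializer;
import org.apache.kafka.streams.state.internals.Murmur3;

import java.util.Arrays;

public class HashGateSketch {
    private static final StringSerializer SERIALIZER = new StringSerializer();

    // Forward only if the current LHS value still hashes to the value recorded
    // when the subscription was sent; a null hash stands for "LHS was null".
    static boolean shouldForward(final String currentLeftValue, final long[] recordedHash) {
        final long[] currentHash = currentLeftValue == null
            ? null
            : Murmur3.hash128(SERIALIZER.serialize("topic-join-resolver", currentLeftValue));
        return Arrays.equals(currentHash, recordedHash);
    }

    public static void main(final String[] args) {
        final long[] recorded = Murmur3.hash128(SERIALIZER.serialize("topic-join-resolver", "lhsValue"));
        System.out.println(shouldForward("lhsValue", recorded));    // true  -> join result is forwarded
        System.out.println(shouldForward("oldLhsValue", recorded)); // false -> stale response is dropped
        System.out.println(shouldForward(null, recorded));          // false -> LHS deleted meanwhile, ignore
    }
}

Note that Arrays.equals treats two null arrays as equal, which matches the last test above, where both the recorded hash and the current LHS value are null and a tombstone is still forwarded.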
+ */ +package org.apache.kafka.streams.kstream.internals.foreignkeyjoin; + +import org.apache.kafka.common.errors.UnsupportedVersionException; +import org.apache.kafka.common.serialization.Deserializer; +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.serialization.Serializer; +import org.apache.kafka.streams.state.internals.Murmur3; +import org.junit.Test; + +import java.util.Map; + +import static java.util.Objects.requireNonNull; +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThrows; + +public class SubscriptionResponseWrapperSerdeTest { + private static final class NonNullableSerde implements Serde, Serializer, Deserializer { + private final Serde delegate; + + NonNullableSerde(final Serde delegate) { + this.delegate = delegate; + } + + @Override + public void configure(final Map configs, final boolean isKey) { + + } + + @Override + public void close() { + + } + + @Override + public Serializer serializer() { + return this; + } + + @Override + public Deserializer deserializer() { + return this; + } + + @Override + public byte[] serialize(final String topic, final T data) { + return delegate.serializer().serialize(topic, requireNonNull(data)); + } + + @Override + public T deserialize(final String topic, final byte[] data) { + return delegate.deserializer().deserialize(topic, requireNonNull(data)); + } + } + + @Test + @SuppressWarnings("unchecked") + public void ShouldSerdeWithNonNullsTest() { + final long[] hashedValue = Murmur3.hash128(new byte[] {(byte) 0x01, (byte) 0x9A, (byte) 0xFF, (byte) 0x00}); + final String foreignValue = "foreignValue"; + final SubscriptionResponseWrapper srw = new SubscriptionResponseWrapper<>(hashedValue, foreignValue); + final SubscriptionResponseWrapperSerde srwSerde = new SubscriptionResponseWrapperSerde(new NonNullableSerde(Serdes.String())); + final byte[] serResponse = srwSerde.serializer().serialize(null, srw); + final SubscriptionResponseWrapper result = srwSerde.deserializer().deserialize(null, serResponse); + + assertArrayEquals(hashedValue, result.getOriginalValueHash()); + assertEquals(foreignValue, result.getForeignValue()); + } + + @Test + @SuppressWarnings("unchecked") + public void shouldSerdeWithNullForeignValueTest() { + final long[] hashedValue = Murmur3.hash128(new byte[] {(byte) 0x01, (byte) 0x9A, (byte) 0xFF, (byte) 0x00}); + final SubscriptionResponseWrapper srw = new SubscriptionResponseWrapper<>(hashedValue, null); + final SubscriptionResponseWrapperSerde srwSerde = new SubscriptionResponseWrapperSerde(new NonNullableSerde(Serdes.String())); + final byte[] serResponse = srwSerde.serializer().serialize(null, srw); + final SubscriptionResponseWrapper result = srwSerde.deserializer().deserialize(null, serResponse); + + assertArrayEquals(hashedValue, result.getOriginalValueHash()); + assertNull(result.getForeignValue()); + } + + @Test + @SuppressWarnings("unchecked") + public void shouldSerdeWithNullHashTest() { + final long[] hashedValue = null; + final String foreignValue = "foreignValue"; + final SubscriptionResponseWrapper srw = new SubscriptionResponseWrapper<>(hashedValue, foreignValue); + final SubscriptionResponseWrapperSerde srwSerde = new SubscriptionResponseWrapperSerde(new NonNullableSerde(Serdes.String())); + final byte[] serResponse = srwSerde.serializer().serialize(null, srw); + final 
SubscriptionResponseWrapper result = srwSerde.deserializer().deserialize(null, serResponse); + + assertArrayEquals(hashedValue, result.getOriginalValueHash()); + assertEquals(foreignValue, result.getForeignValue()); + } + + @Test + @SuppressWarnings("unchecked") + public void shouldSerdeWithNullsTest() { + final long[] hashedValue = null; + final String foreignValue = null; + final SubscriptionResponseWrapper srw = new SubscriptionResponseWrapper<>(hashedValue, foreignValue); + final SubscriptionResponseWrapperSerde srwSerde = new SubscriptionResponseWrapperSerde(new NonNullableSerde(Serdes.String())); + final byte[] serResponse = srwSerde.serializer().serialize(null, srw); + final SubscriptionResponseWrapper result = srwSerde.deserializer().deserialize(null, serResponse); + + assertArrayEquals(hashedValue, result.getOriginalValueHash()); + assertEquals(foreignValue, result.getForeignValue()); + } + + @Test + public void shouldThrowExceptionWithBadVersionTest() { + final long[] hashedValue = null; + assertThrows(UnsupportedVersionException.class, + () -> new SubscriptionResponseWrapper<>(hashedValue, "foreignValue", (byte) 0xFF)); + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/kstream/internals/foreignkeyjoin/SubscriptionWrapperSerdeTest.java b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/foreignkeyjoin/SubscriptionWrapperSerdeTest.java new file mode 100644 index 0000000..e937efe --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/foreignkeyjoin/SubscriptionWrapperSerdeTest.java @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
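shouldThrowExceptionWithBadVersionTest above (and shouldThrowExceptionOnUnsupportedVersionTest in the next file) both check that an unrecognized version byte is rejected with UnsupportedVersionException as soon as the wrapper is constructed. A hedged sketch of that guard pattern, using hypothetical class, field, and constant names rather than the actual SubscriptionResponseWrapper or SubscriptionWrapper code:

import org.apache.kafka.common.errors.UnsupportedVersionException;

public class VersionedWrapperSketch {
    // The only wire-format version this sketch understands (hypothetical constant).
    static final byte CURRENT_VERSION = 0;

    private final byte version;
    private final String payload;

    VersionedWrapperSketch(final String payload, final byte version) {
        // Fail fast: an unknown version byte never reaches the serializer.
        if (version != CURRENT_VERSION) {
            throw new UnsupportedVersionException("Unsupported version: " + version);
        }
        this.version = version;
        this.payload = payload;
    }

    byte version() {
        return version;
    }

    String payload() {
        return payload;
    }

    public static void main(final String[] args) {
        System.out.println(new VersionedWrapperSketch("ok", (byte) 0).payload()); // accepted
        new VersionedWrapperSketch("boom", (byte) 0xFF);                          // throws UnsupportedVersionException
    }
}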
+ */ +package org.apache.kafka.streams.kstream.internals.foreignkeyjoin; + +import org.apache.kafka.common.errors.UnsupportedVersionException; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.streams.state.internals.Murmur3; +import org.junit.Test; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; + +@SuppressWarnings({"unchecked", "rawtypes"}) +public class SubscriptionWrapperSerdeTest { + + @Test + @SuppressWarnings("unchecked") + public void shouldSerdeTest() { + final String originalKey = "originalKey"; + final SubscriptionWrapperSerde swSerde = new SubscriptionWrapperSerde<>(() -> "pkTopic", Serdes.String()); + final long[] hashedValue = Murmur3.hash128(new byte[] {(byte) 0xFF, (byte) 0xAA, (byte) 0x00, (byte) 0x19}); + final SubscriptionWrapper wrapper = new SubscriptionWrapper<>(hashedValue, SubscriptionWrapper.Instruction.DELETE_KEY_AND_PROPAGATE, originalKey); + final byte[] serialized = swSerde.serializer().serialize(null, wrapper); + final SubscriptionWrapper deserialized = (SubscriptionWrapper) swSerde.deserializer().deserialize(null, serialized); + + assertEquals(SubscriptionWrapper.Instruction.DELETE_KEY_AND_PROPAGATE, deserialized.getInstruction()); + assertArrayEquals(hashedValue, deserialized.getHash()); + assertEquals(originalKey, deserialized.getPrimaryKey()); + } + + @Test + @SuppressWarnings("unchecked") + public void shouldSerdeNullHashTest() { + final String originalKey = "originalKey"; + final SubscriptionWrapperSerde swSerde = new SubscriptionWrapperSerde<>(() -> "pkTopic", Serdes.String()); + final long[] hashedValue = null; + final SubscriptionWrapper wrapper = new SubscriptionWrapper<>(hashedValue, SubscriptionWrapper.Instruction.PROPAGATE_ONLY_IF_FK_VAL_AVAILABLE, originalKey); + final byte[] serialized = swSerde.serializer().serialize(null, wrapper); + final SubscriptionWrapper deserialized = (SubscriptionWrapper) swSerde.deserializer().deserialize(null, serialized); + + assertEquals(SubscriptionWrapper.Instruction.PROPAGATE_ONLY_IF_FK_VAL_AVAILABLE, deserialized.getInstruction()); + assertArrayEquals(hashedValue, deserialized.getHash()); + assertEquals(originalKey, deserialized.getPrimaryKey()); + } + + @Test + public void shouldThrowExceptionOnNullKeyTest() { + final String originalKey = null; + final long[] hashedValue = Murmur3.hash128(new byte[] {(byte) 0xFF, (byte) 0xAA, (byte) 0x00, (byte) 0x19}); + assertThrows(NullPointerException.class, () -> new SubscriptionWrapper<>(hashedValue, + SubscriptionWrapper.Instruction.PROPAGATE_ONLY_IF_FK_VAL_AVAILABLE, originalKey)); + } + + @Test + public void shouldThrowExceptionOnNullInstructionTest() { + final String originalKey = "originalKey"; + final long[] hashedValue = Murmur3.hash128(new byte[] {(byte) 0xFF, (byte) 0xAA, (byte) 0x00, (byte) 0x19}); + assertThrows(NullPointerException.class, () -> new SubscriptionWrapper<>(hashedValue, null, originalKey)); + } + + @Test (expected = UnsupportedVersionException.class) + public void shouldThrowExceptionOnUnsupportedVersionTest() { + final String originalKey = "originalKey"; + final long[] hashedValue = null; + new SubscriptionWrapper<>(hashedValue, SubscriptionWrapper.Instruction.PROPAGATE_ONLY_IF_FK_VAL_AVAILABLE, originalKey, (byte) 0x80); + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/kstream/internals/graph/GraphGraceSearchUtilTest.java 
b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/graph/GraphGraceSearchUtilTest.java new file mode 100644 index 0000000..a898338 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/graph/GraphGraceSearchUtilTest.java @@ -0,0 +1,227 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.kstream.internals.graph; + +import org.apache.kafka.streams.errors.TopologyException; +import org.apache.kafka.streams.kstream.SessionWindows; +import org.apache.kafka.streams.kstream.TimeWindows; +import org.apache.kafka.streams.kstream.internals.KStreamSessionWindowAggregate; +import org.apache.kafka.streams.kstream.internals.KStreamWindowAggregate; +import org.apache.kafka.streams.kstream.internals.TimeWindow; +import org.apache.kafka.streams.processor.ProcessorContext; +import org.apache.kafka.streams.state.StoreBuilder; +import org.junit.Test; + +import static java.time.Duration.ofMillis; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.Assert.fail; + +@SuppressWarnings("deprecation") // Old PAPI. Needs to be migrated. +public class GraphGraceSearchUtilTest { + @Test + public void shouldThrowOnNull() { + try { + GraphGraceSearchUtil.findAndVerifyWindowGrace(null); + fail("Should have thrown."); + } catch (final TopologyException e) { + assertThat(e.getMessage(), is("Invalid topology: Window close time is only defined for windowed computations. Got [].")); + } + } + + @Test + public void shouldFailIfThereIsNoGraceAncestor() { + // doesn't matter if this ancestor is stateless or stateful. The important thing it that there is + // no grace period defined on any ancestor of the node + final StatefulProcessorNode gracelessAncestor = new StatefulProcessorNode<>( + "stateful", + new ProcessorParameters<>( + () -> new org.apache.kafka.streams.processor.Processor() { + @Override + public void init(final ProcessorContext context) {} + + @Override + public void process(final String key, final Long value) {} + + @Override + public void close() {} + }, + "dummy" + ), + (StoreBuilder) null + ); + + final ProcessorGraphNode node = new ProcessorGraphNode<>("stateless", null); + gracelessAncestor.addChild(node); + + try { + GraphGraceSearchUtil.findAndVerifyWindowGrace(node); + fail("should have thrown."); + } catch (final TopologyException e) { + assertThat(e.getMessage(), is("Invalid topology: Window close time is only defined for windowed computations. 
Got [stateful->stateless].")); + } + } + + @Test + public void shouldExtractGraceFromKStreamWindowAggregateNode() { + final TimeWindows windows = TimeWindows.of(ofMillis(10L)).grace(ofMillis(1234L)); + final StatefulProcessorNode node = new StatefulProcessorNode<>( + "asdf", + new ProcessorParameters<>( + new KStreamWindowAggregate( + windows, + "asdf", + null, + null + ), + "asdf" + ), + (StoreBuilder) null + ); + + final long extracted = GraphGraceSearchUtil.findAndVerifyWindowGrace(node); + assertThat(extracted, is(windows.gracePeriodMs())); + } + + @Test + public void shouldExtractGraceFromKStreamSessionWindowAggregateNode() { + final SessionWindows windows = SessionWindows.with(ofMillis(10L)).grace(ofMillis(1234L)); + + final StatefulProcessorNode node = new StatefulProcessorNode<>( + "asdf", + new ProcessorParameters<>( + new KStreamSessionWindowAggregate( + windows, + "asdf", + null, + null, + null + ), + "asdf" + ), + (StoreBuilder) null + ); + + final long extracted = GraphGraceSearchUtil.findAndVerifyWindowGrace(node); + assertThat(extracted, is(windows.gracePeriodMs() + windows.inactivityGap())); + } + + @Test + public void shouldExtractGraceFromSessionAncestorThroughStatefulParent() { + final SessionWindows windows = SessionWindows.with(ofMillis(10L)).grace(ofMillis(1234L)); + final StatefulProcessorNode graceGrandparent = new StatefulProcessorNode<>( + "asdf", + new ProcessorParameters<>(new KStreamSessionWindowAggregate( + windows, "asdf", null, null, null + ), "asdf"), + (StoreBuilder) null + ); + + final StatefulProcessorNode statefulParent = new StatefulProcessorNode<>( + "stateful", + new ProcessorParameters<>( + () -> new org.apache.kafka.streams.processor.Processor() { + @Override + public void init(final ProcessorContext context) {} + + @Override + public void process(final String key, final Long value) {} + + @Override + public void close() {} + }, + "dummy" + ), + (StoreBuilder) null + ); + graceGrandparent.addChild(statefulParent); + + final ProcessorGraphNode node = new ProcessorGraphNode<>("stateless", null); + statefulParent.addChild(node); + + final long extracted = GraphGraceSearchUtil.findAndVerifyWindowGrace(node); + assertThat(extracted, is(windows.gracePeriodMs() + windows.inactivityGap())); + } + + @Test + public void shouldExtractGraceFromSessionAncestorThroughStatelessParent() { + final SessionWindows windows = SessionWindows.with(ofMillis(10L)).grace(ofMillis(1234L)); + final StatefulProcessorNode graceGrandparent = new StatefulProcessorNode<>( + "asdf", + new ProcessorParameters<>( + new KStreamSessionWindowAggregate( + windows, + "asdf", + null, + null, + null + ), + "asdf" + ), + (StoreBuilder) null + ); + + final ProcessorGraphNode statelessParent = new ProcessorGraphNode<>("stateless", null); + graceGrandparent.addChild(statelessParent); + + final ProcessorGraphNode node = new ProcessorGraphNode<>("stateless", null); + statelessParent.addChild(node); + + final long extracted = GraphGraceSearchUtil.findAndVerifyWindowGrace(node); + assertThat(extracted, is(windows.gracePeriodMs() + windows.inactivityGap())); + } + + @Test + public void shouldUseMaxIfMultiParentsDoNotAgreeOnGrace() { + final StatefulProcessorNode leftParent = new StatefulProcessorNode<>( + "asdf", + new ProcessorParameters<>( + new KStreamSessionWindowAggregate( + SessionWindows.with(ofMillis(10L)).grace(ofMillis(1234L)), + "asdf", + null, + null, + null + ), + "asdf" + ), + (StoreBuilder) null + ); + + final StatefulProcessorNode rightParent = new StatefulProcessorNode<>( + "asdf", 
+ new ProcessorParameters<>( + new KStreamWindowAggregate( + TimeWindows.of(ofMillis(10L)).grace(ofMillis(4321L)), + "asdf", + null, + null + ), + "asdf" + ), + (StoreBuilder) null + ); + + final ProcessorGraphNode node = new ProcessorGraphNode<>("stateless", null); + leftParent.addChild(node); + rightParent.addChild(node); + + final long extracted = GraphGraceSearchUtil.findAndVerifyWindowGrace(node); + assertThat(extracted, is(4321L)); + } + +} diff --git a/streams/src/test/java/org/apache/kafka/streams/kstream/internals/graph/StreamsGraphTest.java b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/graph/StreamsGraphTest.java new file mode 100644 index 0000000..b912a96 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/graph/StreamsGraphTest.java @@ -0,0 +1,518 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.streams.kstream.internals.graph; + +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.Topology; +import org.apache.kafka.streams.kstream.Aggregator; +import org.apache.kafka.streams.kstream.Branched; +import org.apache.kafka.streams.kstream.Consumed; +import org.apache.kafka.streams.kstream.Grouped; +import org.apache.kafka.streams.kstream.Initializer; +import org.apache.kafka.streams.kstream.JoinWindows; +import org.apache.kafka.streams.kstream.Joined; +import org.apache.kafka.streams.kstream.KStream; +import org.apache.kafka.streams.kstream.KTable; +import org.apache.kafka.streams.kstream.Materialized; +import org.apache.kafka.streams.kstream.Produced; +import org.apache.kafka.streams.kstream.Suppressed; +import org.apache.kafka.streams.kstream.TimeWindows; +import org.apache.kafka.streams.kstream.Transformer; +import org.apache.kafka.streams.kstream.TransformerSupplier; +import org.apache.kafka.streams.kstream.ValueJoiner; +import org.apache.kafka.streams.processor.ProcessorContext; +import org.junit.Test; + +import java.time.Duration; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Locale; +import java.util.Properties; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +import static java.time.Duration.ofMillis; +import static org.junit.Assert.assertEquals; + +@SuppressWarnings("deprecation") +public class StreamsGraphTest { + + private final Pattern repartitionTopicPattern = Pattern.compile("Sink: .*-repartition"); + private Initializer initializer; + private Aggregator aggregator; + + // Test builds topology in succesive manner but only graph node not yet processed written to topology + + @Test + public void 
shouldBeAbleToBuildTopologyIncrementally() { + final StreamsBuilder builder = new StreamsBuilder(); + + final KStream stream = builder.stream("topic"); + final KStream streamII = builder.stream("other-topic"); + final ValueJoiner valueJoiner = (v, v2) -> v + v2; + + + final KStream joinedStream = stream.join(streamII, valueJoiner, JoinWindows.of(ofMillis(5000))); + + // build step one + assertEquals(expectedJoinedTopology, builder.build().describe().toString()); + + final KStream filteredJoinStream = joinedStream.filter((k, v) -> v.equals("foo")); + // build step two + assertEquals(expectedJoinedFilteredTopology, builder.build().describe().toString()); + + filteredJoinStream.mapValues(v -> v + "some value").to("output-topic"); + // build step three + assertEquals(expectedFullTopology, builder.build().describe().toString()); + + } + + @Test + public void shouldBeAbleToProcessNestedMultipleKeyChangingNodes() { + final Properties properties = new Properties(); + properties.setProperty(StreamsConfig.APPLICATION_ID_CONFIG, "test-application"); + properties.setProperty(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9092"); + properties.setProperty(StreamsConfig.TOPOLOGY_OPTIMIZATION_CONFIG, StreamsConfig.OPTIMIZE); + + final StreamsBuilder builder = new StreamsBuilder(); + final KStream inputStream = builder.stream("inputTopic"); + + final KStream changedKeyStream = inputStream.selectKey((k, v) -> v.substring(0, 5)); + + // first repartition + changedKeyStream.groupByKey(Grouped.as("count-repartition")) + .count(Materialized.as("count-store")) + .toStream().to("count-topic", Produced.with(Serdes.String(), Serdes.Long())); + + // second repartition + changedKeyStream.groupByKey(Grouped.as("windowed-repartition")) + .windowedBy(TimeWindows.of(Duration.ofSeconds(5))) + .count(Materialized.as("windowed-count-store")) + .toStream() + .map((k, v) -> KeyValue.pair(k.key(), v)).to("windowed-count", Produced.with(Serdes.String(), Serdes.Long())); + + builder.build(properties); + } + + @Test + // Topology in this test from https://issues.apache.org/jira/browse/KAFKA-9739 + public void shouldNotThrowNPEWithMergeNodes() { + final Properties properties = new Properties(); + properties.setProperty(StreamsConfig.APPLICATION_ID_CONFIG, "test-application"); + properties.setProperty(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9092"); + properties.setProperty(StreamsConfig.TOPOLOGY_OPTIMIZATION_CONFIG, StreamsConfig.OPTIMIZE); + + final StreamsBuilder builder = new StreamsBuilder(); + initializer = () -> ""; + aggregator = (aggKey, value, aggregate) -> aggregate + value.length(); + final TransformerSupplier> transformSupplier = () -> new Transformer>() { + @Override + public void init(final ProcessorContext context) { + + } + + @Override + public KeyValue transform(final String key, final String value) { + return KeyValue.pair(key, value); + } + + @Override + public void close() { + + } + }; + + final KStream retryStream = builder.stream("retryTopic", Consumed.with(Serdes.String(), Serdes.String())) + .transform(transformSupplier) + .groupByKey(Grouped.with(Serdes.String(), Serdes.String())) + .aggregate(initializer, + aggregator, + Materialized.with(Serdes.String(), Serdes.String())) + .suppress(Suppressed.untilTimeLimit(Duration.ofSeconds(500), Suppressed.BufferConfig.maxBytes(64_000_000))) + .toStream() + .flatMap((k, v) -> new ArrayList<>()); + + final KTable idTable = builder.stream("id-table-topic", Consumed.with(Serdes.String(), Serdes.String())) + .flatMap((k, v) -> new ArrayList>()) + 
.peek((subscriptionId, recipientId) -> System.out.println("data " + subscriptionId + " " + recipientId)) + .groupByKey(Grouped.with(Serdes.String(), Serdes.String())) + .aggregate(initializer, + aggregator, + Materialized.with(Serdes.String(), Serdes.String())); + + final KStream joinStream = builder.stream("internal-topic-command", Consumed.with(Serdes.String(), Serdes.String())) + .peek((subscriptionId, command) -> System.out.println("stdoutput")) + .mapValues((k, v) -> v) + .merge(retryStream) + .leftJoin(idTable, (v1, v2) -> v1 + v2, + Joined.with(Serdes.String(), Serdes.String(), Serdes.String())); + + + joinStream.split() + .branch((k, v) -> v.equals("some-value"), Branched.withConsumer(ks -> ks.map(KeyValue::pair) + .peek((recipientId, command) -> System.out.println("printing out")) + .to("external-command", Produced.with(Serdes.String(), Serdes.String())) + )) + .defaultBranch(Branched.withConsumer(ks -> { + ks.filter((k, v) -> v != null) + .peek((subscriptionId, wrapper) -> System.out.println("Printing output")) + .mapValues((k, v) -> v) + .to("dlq-topic", Produced.with(Serdes.String(), Serdes.String())); + ks.map(KeyValue::pair).to("retryTopic", Produced.with(Serdes.String(), Serdes.String())); + })); + + final Topology topology = builder.build(properties); + assertEquals(expectedComplexMergeOptimizeTopology, topology.describe().toString()); + } + + @Test + public void shouldNotOptimizeWithValueOrKeyChangingOperatorsAfterInitialKeyChange() { + + final Topology attemptedOptimize = getTopologyWithChangingValuesAfterChangingKey(StreamsConfig.OPTIMIZE); + final Topology noOptimization = getTopologyWithChangingValuesAfterChangingKey(StreamsConfig.NO_OPTIMIZATION); + + assertEquals(attemptedOptimize.describe().toString(), noOptimization.describe().toString()); + assertEquals(2, getCountOfRepartitionTopicsFound(attemptedOptimize.describe().toString())); + assertEquals(2, getCountOfRepartitionTopicsFound(noOptimization.describe().toString())); + } + + // no need to optimize as user has already performed the repartitioning manually + @Deprecated + @Test + public void shouldNotOptimizeWhenAThroughOperationIsDone() { + final Topology attemptedOptimize = getTopologyWithThroughOperation(StreamsConfig.OPTIMIZE); + final Topology noOptimziation = getTopologyWithThroughOperation(StreamsConfig.NO_OPTIMIZATION); + + assertEquals(attemptedOptimize.describe().toString(), noOptimziation.describe().toString()); + assertEquals(0, getCountOfRepartitionTopicsFound(attemptedOptimize.describe().toString())); + assertEquals(0, getCountOfRepartitionTopicsFound(noOptimziation.describe().toString())); + + } + + @Test + public void shouldOptimizeSeveralMergeNodesWithCommonKeyChangingParent() { + final StreamsBuilder streamsBuilder = new StreamsBuilder(); + final KStream parentStream = streamsBuilder.stream("input_topic", Consumed.with(Serdes.Integer(), Serdes.Integer())) + .selectKey(Integer::sum); + + final KStream childStream1 = parentStream.mapValues(v -> v + 1); + final KStream childStream2 = parentStream.mapValues(v -> v + 2); + final KStream childStream3 = parentStream.mapValues(v -> v + 3); + + childStream1 + .merge(childStream2) + .merge(childStream3) + .to("output_topic"); + + final Properties properties = new Properties(); + properties.setProperty(StreamsConfig.TOPOLOGY_OPTIMIZATION_CONFIG, StreamsConfig.OPTIMIZE); + final Topology topology = streamsBuilder.build(properties); + + assertEquals(expectedMergeOptimizedTopology, topology.describe().toString()); + } + + @Test + public void 
shouldNotOptimizeWhenRepartitionOperationIsDone() { + final Topology attemptedOptimize = getTopologyWithRepartitionOperation(StreamsConfig.OPTIMIZE); + final Topology noOptimziation = getTopologyWithRepartitionOperation(StreamsConfig.NO_OPTIMIZATION); + + assertEquals(attemptedOptimize.describe().toString(), noOptimziation.describe().toString()); + assertEquals(2, getCountOfRepartitionTopicsFound(attemptedOptimize.describe().toString())); + assertEquals(2, getCountOfRepartitionTopicsFound(noOptimziation.describe().toString())); + } + + private Topology getTopologyWithChangingValuesAfterChangingKey(final String optimizeConfig) { + + final StreamsBuilder builder = new StreamsBuilder(); + final Properties properties = new Properties(); + properties.put(StreamsConfig.TOPOLOGY_OPTIMIZATION_CONFIG, optimizeConfig); + + final KStream inputStream = builder.stream("input"); + final KStream mappedKeyStream = inputStream.selectKey((k, v) -> k + v); + + mappedKeyStream.mapValues(v -> v.toUpperCase(Locale.getDefault())).groupByKey().count().toStream().to("output"); + mappedKeyStream.flatMapValues(v -> Arrays.asList(v.split("\\s"))).groupByKey().windowedBy(TimeWindows.of(ofMillis(5000))).count().toStream().to("windowed-output"); + + return builder.build(properties); + + } + + @Deprecated // specifically testing the deprecated variant + private Topology getTopologyWithThroughOperation(final String optimizeConfig) { + + final StreamsBuilder builder = new StreamsBuilder(); + final Properties properties = new Properties(); + properties.put(StreamsConfig.TOPOLOGY_OPTIMIZATION_CONFIG, optimizeConfig); + + final KStream inputStream = builder.stream("input"); + final KStream mappedKeyStream = inputStream.selectKey((k, v) -> k + v).through("through-topic"); + + mappedKeyStream.groupByKey().count().toStream().to("output"); + mappedKeyStream.groupByKey().windowedBy(TimeWindows.of(ofMillis(5000))).count().toStream().to("windowed-output"); + + return builder.build(properties); + + } + + private Topology getTopologyWithRepartitionOperation(final String optimizeConfig) { + final StreamsBuilder builder = new StreamsBuilder(); + final Properties properties = new Properties(); + properties.put(StreamsConfig.TOPOLOGY_OPTIMIZATION_CONFIG, optimizeConfig); + + final KStream inputStream = builder.stream("input").selectKey((k, v) -> k + v); + + inputStream + .repartition() + .groupByKey() + .count() + .toStream() + .to("output"); + + inputStream + .repartition() + .groupByKey() + .windowedBy(TimeWindows.of(ofMillis(5000))) + .count() + .toStream() + .to("windowed-output"); + + return builder.build(properties); + } + + private int getCountOfRepartitionTopicsFound(final String topologyString) { + final Matcher matcher = repartitionTopicPattern.matcher(topologyString); + final List repartitionTopicsFound = new ArrayList<>(); + while (matcher.find()) { + repartitionTopicsFound.add(matcher.group()); + } + return repartitionTopicsFound.size(); + } + + private final String expectedJoinedTopology = "Topologies:\n" + + " Sub-topology: 0\n" + + " Source: KSTREAM-SOURCE-0000000000 (topics: [topic])\n" + + " --> KSTREAM-WINDOWED-0000000002\n" + + " Source: KSTREAM-SOURCE-0000000001 (topics: [other-topic])\n" + + " --> KSTREAM-WINDOWED-0000000003\n" + + " Processor: KSTREAM-WINDOWED-0000000002 (stores: [KSTREAM-JOINTHIS-0000000004-store])\n" + + " --> KSTREAM-JOINTHIS-0000000004\n" + + " <-- KSTREAM-SOURCE-0000000000\n" + + " Processor: KSTREAM-WINDOWED-0000000003 (stores: [KSTREAM-JOINOTHER-0000000005-store])\n" + + " --> 
KSTREAM-JOINOTHER-0000000005\n" + + " <-- KSTREAM-SOURCE-0000000001\n" + + " Processor: KSTREAM-JOINOTHER-0000000005 (stores: [KSTREAM-JOINTHIS-0000000004-store])\n" + + " --> KSTREAM-MERGE-0000000006\n" + + " <-- KSTREAM-WINDOWED-0000000003\n" + + " Processor: KSTREAM-JOINTHIS-0000000004 (stores: [KSTREAM-JOINOTHER-0000000005-store])\n" + + " --> KSTREAM-MERGE-0000000006\n" + + " <-- KSTREAM-WINDOWED-0000000002\n" + + " Processor: KSTREAM-MERGE-0000000006 (stores: [])\n" + + " --> none\n" + + " <-- KSTREAM-JOINTHIS-0000000004, KSTREAM-JOINOTHER-0000000005\n\n"; + + private final String expectedJoinedFilteredTopology = "Topologies:\n" + + " Sub-topology: 0\n" + + " Source: KSTREAM-SOURCE-0000000000 (topics: [topic])\n" + + " --> KSTREAM-WINDOWED-0000000002\n" + + " Source: KSTREAM-SOURCE-0000000001 (topics: [other-topic])\n" + + " --> KSTREAM-WINDOWED-0000000003\n" + + " Processor: KSTREAM-WINDOWED-0000000002 (stores: [KSTREAM-JOINTHIS-0000000004-store])\n" + + " --> KSTREAM-JOINTHIS-0000000004\n" + + " <-- KSTREAM-SOURCE-0000000000\n" + + " Processor: KSTREAM-WINDOWED-0000000003 (stores: [KSTREAM-JOINOTHER-0000000005-store])\n" + + " --> KSTREAM-JOINOTHER-0000000005\n" + + " <-- KSTREAM-SOURCE-0000000001\n" + + " Processor: KSTREAM-JOINOTHER-0000000005 (stores: [KSTREAM-JOINTHIS-0000000004-store])\n" + + " --> KSTREAM-MERGE-0000000006\n" + + " <-- KSTREAM-WINDOWED-0000000003\n" + + " Processor: KSTREAM-JOINTHIS-0000000004 (stores: [KSTREAM-JOINOTHER-0000000005-store])\n" + + " --> KSTREAM-MERGE-0000000006\n" + + " <-- KSTREAM-WINDOWED-0000000002\n" + + " Processor: KSTREAM-MERGE-0000000006 (stores: [])\n" + + " --> KSTREAM-FILTER-0000000007\n" + + " <-- KSTREAM-JOINTHIS-0000000004, KSTREAM-JOINOTHER-0000000005\n" + + " Processor: KSTREAM-FILTER-0000000007 (stores: [])\n" + + " --> none\n" + + " <-- KSTREAM-MERGE-0000000006\n\n"; + + private final String expectedFullTopology = "Topologies:\n" + + " Sub-topology: 0\n" + + " Source: KSTREAM-SOURCE-0000000000 (topics: [topic])\n" + + " --> KSTREAM-WINDOWED-0000000002\n" + + " Source: KSTREAM-SOURCE-0000000001 (topics: [other-topic])\n" + + " --> KSTREAM-WINDOWED-0000000003\n" + + " Processor: KSTREAM-WINDOWED-0000000002 (stores: [KSTREAM-JOINTHIS-0000000004-store])\n" + + " --> KSTREAM-JOINTHIS-0000000004\n" + + " <-- KSTREAM-SOURCE-0000000000\n" + + " Processor: KSTREAM-WINDOWED-0000000003 (stores: [KSTREAM-JOINOTHER-0000000005-store])\n" + + " --> KSTREAM-JOINOTHER-0000000005\n" + + " <-- KSTREAM-SOURCE-0000000001\n" + + " Processor: KSTREAM-JOINOTHER-0000000005 (stores: [KSTREAM-JOINTHIS-0000000004-store])\n" + + " --> KSTREAM-MERGE-0000000006\n" + + " <-- KSTREAM-WINDOWED-0000000003\n" + + " Processor: KSTREAM-JOINTHIS-0000000004 (stores: [KSTREAM-JOINOTHER-0000000005-store])\n" + + " --> KSTREAM-MERGE-0000000006\n" + + " <-- KSTREAM-WINDOWED-0000000002\n" + + " Processor: KSTREAM-MERGE-0000000006 (stores: [])\n" + + " --> KSTREAM-FILTER-0000000007\n" + + " <-- KSTREAM-JOINTHIS-0000000004, KSTREAM-JOINOTHER-0000000005\n" + + " Processor: KSTREAM-FILTER-0000000007 (stores: [])\n" + + " --> KSTREAM-MAPVALUES-0000000008\n" + + " <-- KSTREAM-MERGE-0000000006\n" + + " Processor: KSTREAM-MAPVALUES-0000000008 (stores: [])\n" + + " --> KSTREAM-SINK-0000000009\n" + + " <-- KSTREAM-FILTER-0000000007\n" + + " Sink: KSTREAM-SINK-0000000009 (topic: output-topic)\n" + + " <-- KSTREAM-MAPVALUES-0000000008\n\n"; + + + private final String expectedMergeOptimizedTopology = "Topologies:\n" + + " Sub-topology: 0\n" + + " Source: KSTREAM-SOURCE-0000000000 
(topics: [input_topic])\n" + + " --> KSTREAM-KEY-SELECT-0000000001\n" + + " Processor: KSTREAM-KEY-SELECT-0000000001 (stores: [])\n" + + " --> KSTREAM-MAPVALUES-0000000002, KSTREAM-MAPVALUES-0000000003, KSTREAM-MAPVALUES-0000000004\n" + + " <-- KSTREAM-SOURCE-0000000000\n" + + " Processor: KSTREAM-MAPVALUES-0000000002 (stores: [])\n" + + " --> KSTREAM-MERGE-0000000005\n" + + " <-- KSTREAM-KEY-SELECT-0000000001\n" + + " Processor: KSTREAM-MAPVALUES-0000000003 (stores: [])\n" + + " --> KSTREAM-MERGE-0000000005\n" + + " <-- KSTREAM-KEY-SELECT-0000000001\n" + + " Processor: KSTREAM-MAPVALUES-0000000004 (stores: [])\n" + + " --> KSTREAM-MERGE-0000000006\n" + + " <-- KSTREAM-KEY-SELECT-0000000001\n" + + " Processor: KSTREAM-MERGE-0000000005 (stores: [])\n" + + " --> KSTREAM-MERGE-0000000006\n" + + " <-- KSTREAM-MAPVALUES-0000000002, KSTREAM-MAPVALUES-0000000003\n" + + " Processor: KSTREAM-MERGE-0000000006 (stores: [])\n" + + " --> KSTREAM-SINK-0000000007\n" + + " <-- KSTREAM-MERGE-0000000005, KSTREAM-MAPVALUES-0000000004\n" + + " Sink: KSTREAM-SINK-0000000007 (topic: output_topic)\n" + + " <-- KSTREAM-MERGE-0000000006\n\n"; + + + private final String expectedComplexMergeOptimizeTopology = "Topologies:\n" + + " Sub-topology: 0\n" + + " Source: KSTREAM-SOURCE-0000000000 (topics: [retryTopic])\n" + + " --> KSTREAM-TRANSFORM-0000000001\n" + + " Processor: KSTREAM-TRANSFORM-0000000001 (stores: [])\n" + + " --> KSTREAM-AGGREGATE-STATE-STORE-0000000002-repartition-filter\n" + + " <-- KSTREAM-SOURCE-0000000000\n" + + " Processor: KSTREAM-AGGREGATE-STATE-STORE-0000000002-repartition-filter (stores: [])\n" + + " --> KSTREAM-AGGREGATE-STATE-STORE-0000000002-repartition-sink\n" + + " <-- KSTREAM-TRANSFORM-0000000001\n" + + " Sink: KSTREAM-AGGREGATE-STATE-STORE-0000000002-repartition-sink (topic: KSTREAM-AGGREGATE-STATE-STORE-0000000002-repartition)\n" + + " <-- KSTREAM-AGGREGATE-STATE-STORE-0000000002-repartition-filter\n" + + "\n" + + " Sub-topology: 1\n" + + " Source: KSTREAM-AGGREGATE-STATE-STORE-0000000002-repartition-source (topics: [KSTREAM-AGGREGATE-STATE-STORE-0000000002-repartition])\n" + + " --> KSTREAM-AGGREGATE-0000000003\n" + + " Processor: KSTREAM-AGGREGATE-0000000003 (stores: [KSTREAM-AGGREGATE-STATE-STORE-0000000002])\n" + + " --> KTABLE-SUPPRESS-0000000007\n" + + " <-- KSTREAM-AGGREGATE-STATE-STORE-0000000002-repartition-source\n" + + " Source: KSTREAM-SOURCE-0000000019 (topics: [internal-topic-command])\n" + + " --> KSTREAM-PEEK-0000000020\n" + + " Processor: KTABLE-SUPPRESS-0000000007 (stores: [KTABLE-SUPPRESS-STATE-STORE-0000000008])\n" + + " --> KTABLE-TOSTREAM-0000000009\n" + + " <-- KSTREAM-AGGREGATE-0000000003\n" + + " Processor: KSTREAM-PEEK-0000000020 (stores: [])\n" + + " --> KSTREAM-MAPVALUES-0000000021\n" + + " <-- KSTREAM-SOURCE-0000000019\n" + + " Processor: KTABLE-TOSTREAM-0000000009 (stores: [])\n" + + " --> KSTREAM-FLATMAP-0000000010\n" + + " <-- KTABLE-SUPPRESS-0000000007\n" + + " Processor: KSTREAM-FLATMAP-0000000010 (stores: [])\n" + + " --> KSTREAM-MERGE-0000000022\n" + + " <-- KTABLE-TOSTREAM-0000000009\n" + + " Processor: KSTREAM-MAPVALUES-0000000021 (stores: [])\n" + + " --> KSTREAM-MERGE-0000000022\n" + + " <-- KSTREAM-PEEK-0000000020\n" + + " Processor: KSTREAM-MERGE-0000000022 (stores: [])\n" + + " --> KSTREAM-FILTER-0000000024\n" + + " <-- KSTREAM-MAPVALUES-0000000021, KSTREAM-FLATMAP-0000000010\n" + + " Processor: KSTREAM-FILTER-0000000024 (stores: [])\n" + + " --> KSTREAM-SINK-0000000023\n" + + " <-- KSTREAM-MERGE-0000000022\n" + + " Sink: 
KSTREAM-SINK-0000000023 (topic: KSTREAM-MERGE-0000000022-repartition)\n" + + " <-- KSTREAM-FILTER-0000000024\n" + + "\n" + + " Sub-topology: 2\n" + + " Source: KSTREAM-SOURCE-0000000011 (topics: [id-table-topic])\n" + + " --> KSTREAM-FLATMAP-0000000012\n" + + " Processor: KSTREAM-FLATMAP-0000000012 (stores: [])\n" + + " --> KSTREAM-AGGREGATE-STATE-STORE-0000000014-repartition-filter\n" + + " <-- KSTREAM-SOURCE-0000000011\n" + + " Processor: KSTREAM-AGGREGATE-STATE-STORE-0000000014-repartition-filter (stores: [])\n" + + " --> KSTREAM-AGGREGATE-STATE-STORE-0000000014-repartition-sink\n" + + " <-- KSTREAM-FLATMAP-0000000012\n" + + " Sink: KSTREAM-AGGREGATE-STATE-STORE-0000000014-repartition-sink (topic: KSTREAM-AGGREGATE-STATE-STORE-0000000014-repartition)\n" + + " <-- KSTREAM-AGGREGATE-STATE-STORE-0000000014-repartition-filter\n" + + "\n" + + " Sub-topology: 3\n" + + " Source: KSTREAM-SOURCE-0000000025 (topics: [KSTREAM-MERGE-0000000022-repartition])\n" + + " --> KSTREAM-LEFTJOIN-0000000026\n" + + " Processor: KSTREAM-LEFTJOIN-0000000026 (stores: [KSTREAM-AGGREGATE-STATE-STORE-0000000014])\n" + + " --> KSTREAM-BRANCH-0000000027\n" + + " <-- KSTREAM-SOURCE-0000000025\n" + + " Processor: KSTREAM-BRANCH-0000000027 (stores: [])\n" + + " --> KSTREAM-BRANCH-00000000270, KSTREAM-BRANCH-00000000271\n" + + " <-- KSTREAM-LEFTJOIN-0000000026\n" + + " Processor: KSTREAM-BRANCH-00000000270 (stores: [])\n" + + " --> KSTREAM-FILTER-0000000033, KSTREAM-MAP-0000000037\n" + + " <-- KSTREAM-BRANCH-0000000027\n" + + " Processor: KSTREAM-BRANCH-00000000271 (stores: [])\n" + + " --> KSTREAM-MAP-0000000029\n" + + " <-- KSTREAM-BRANCH-0000000027\n" + + " Processor: KSTREAM-FILTER-0000000033 (stores: [])\n" + + " --> KSTREAM-PEEK-0000000034\n" + + " <-- KSTREAM-BRANCH-00000000270\n" + + " Source: KSTREAM-AGGREGATE-STATE-STORE-0000000014-repartition-source (topics: [KSTREAM-AGGREGATE-STATE-STORE-0000000014-repartition])\n" + + " --> KSTREAM-PEEK-0000000013\n" + + " Processor: KSTREAM-MAP-0000000029 (stores: [])\n" + + " --> KSTREAM-PEEK-0000000030\n" + + " <-- KSTREAM-BRANCH-00000000271\n" + + " Processor: KSTREAM-PEEK-0000000034 (stores: [])\n" + + " --> KSTREAM-MAPVALUES-0000000035\n" + + " <-- KSTREAM-FILTER-0000000033\n" + + " Processor: KSTREAM-MAP-0000000037 (stores: [])\n" + + " --> KSTREAM-SINK-0000000038\n" + + " <-- KSTREAM-BRANCH-00000000270\n" + + " Processor: KSTREAM-MAPVALUES-0000000035 (stores: [])\n" + + " --> KSTREAM-SINK-0000000036\n" + + " <-- KSTREAM-PEEK-0000000034\n" + + " Processor: KSTREAM-PEEK-0000000013 (stores: [])\n" + + " --> KSTREAM-AGGREGATE-0000000015\n" + + " <-- KSTREAM-AGGREGATE-STATE-STORE-0000000014-repartition-source\n" + + " Processor: KSTREAM-PEEK-0000000030 (stores: [])\n" + + " --> KSTREAM-SINK-0000000031\n" + + " <-- KSTREAM-MAP-0000000029\n" + + " Processor: KSTREAM-AGGREGATE-0000000015 (stores: [KSTREAM-AGGREGATE-STATE-STORE-0000000014])\n" + + " --> none\n" + + " <-- KSTREAM-PEEK-0000000013\n" + + " Sink: KSTREAM-SINK-0000000031 (topic: external-command)\n" + + " <-- KSTREAM-PEEK-0000000030\n" + + " Sink: KSTREAM-SINK-0000000036 (topic: dlq-topic)\n" + + " <-- KSTREAM-MAPVALUES-0000000035\n" + + " Sink: KSTREAM-SINK-0000000038 (topic: retryTopic)\n" + + " <-- KSTREAM-MAP-0000000037\n\n"; +} diff --git a/streams/src/test/java/org/apache/kafka/streams/kstream/internals/graph/TableProcessorNodeTest.java b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/graph/TableProcessorNodeTest.java new file mode 100644 index 0000000..99be7f8 --- /dev/null +++ 
b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/graph/TableProcessorNodeTest.java @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.streams.kstream.internals.graph; + +import org.junit.Test; + +import static org.junit.Assert.assertTrue; + +public class TableProcessorNodeTest { + @SuppressWarnings("deprecation") // Old PAPI. Needs to be migrated. + private static class TestProcessor extends org.apache.kafka.streams.processor.AbstractProcessor { + @Override + public void init(final org.apache.kafka.streams.processor.ProcessorContext context) { + } + + @Override + public void process(final String key, final String value) { + } + + @Override + public void close() { + } + } + + @Test + public void shouldConvertToStringWithNullStoreBuilder() { + final TableProcessorNode node = new TableProcessorNode<>( + "name", + new ProcessorParameters<>(TestProcessor::new, "processor"), + null, + new String[]{"store1", "store2"} + ); + + final String asString = node.toString(); + final String expected = "storeBuilder=null"; + assertTrue( + String.format( + "Expected toString to return string with \"%s\", received: %s", + expected, + asString), + asString.contains(expected) + ); + } +} \ No newline at end of file diff --git a/streams/src/test/java/org/apache/kafka/streams/kstream/internals/graph/TableSourceNodeTest.java b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/graph/TableSourceNodeTest.java new file mode 100644 index 0000000..66c55c0 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/graph/TableSourceNodeTest.java @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.kstream.internals.graph; + +import org.apache.kafka.streams.kstream.Consumed; +import org.apache.kafka.streams.kstream.Materialized; +import org.apache.kafka.streams.kstream.internals.ConsumedInternal; +import org.apache.kafka.streams.kstream.internals.KTableSource; +import org.apache.kafka.streams.kstream.internals.MaterializedInternal; +import org.apache.kafka.streams.kstream.internals.graph.TableSourceNode.TableSourceNodeBuilder; +import org.apache.kafka.streams.processor.internals.InternalTopologyBuilder; +import org.easymock.EasyMock; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.powermock.api.easymock.PowerMock; +import org.powermock.core.classloader.annotations.PrepareForTest; +import org.powermock.modules.junit4.PowerMockRunner; + +@RunWith(PowerMockRunner.class) +@PrepareForTest({InternalTopologyBuilder.class}) +public class TableSourceNodeTest { + + private static final String STORE_NAME = "store-name"; + private static final String TOPIC = "input-topic"; + + private final InternalTopologyBuilder topologyBuilder = PowerMock.createNiceMock(InternalTopologyBuilder.class); + + @Test + public void shouldConnectStateStoreToInputTopicIfInputTopicIsUsedAsChangelog() { + final boolean shouldReuseSourceTopicForChangelog = true; + topologyBuilder.connectSourceStoreAndTopic(STORE_NAME, TOPIC); + EasyMock.replay(topologyBuilder); + + buildTableSourceNode(shouldReuseSourceTopicForChangelog); + + EasyMock.verify(topologyBuilder); + } + + @Test + public void shouldConnectStateStoreToChangelogTopic() { + final boolean shouldReuseSourceTopicForChangelog = false; + EasyMock.replay(topologyBuilder); + + buildTableSourceNode(shouldReuseSourceTopicForChangelog); + + EasyMock.verify(topologyBuilder); + } + + private void buildTableSourceNode(final boolean shouldReuseSourceTopicForChangelog) { + final TableSourceNodeBuilder tableSourceNodeBuilder = TableSourceNode.tableSourceNodeBuilder(); + final TableSourceNode tableSourceNode = tableSourceNodeBuilder + .withTopic(TOPIC) + .withMaterializedInternal(new MaterializedInternal<>(Materialized.as(STORE_NAME))) + .withConsumedInternal(new ConsumedInternal<>(Consumed.as("node-name"))) + .withProcessorParameters( + new ProcessorParameters<>(new KTableSource<>(STORE_NAME, STORE_NAME), null)) + .build(); + tableSourceNode.reuseSourceTopicForChangeLog(shouldReuseSourceTopicForChangelog); + + tableSourceNode.writeToTopology(topologyBuilder); + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/kstream/internals/suppress/KTableSuppressProcessorMetricsTest.java b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/suppress/KTableSuppressProcessorMetricsTest.java new file mode 100644 index 0000000..009a70d --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/suppress/KTableSuppressProcessorMetricsTest.java @@ -0,0 +1,202 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.kstream.internals.suppress; + +import org.apache.kafka.common.Metric; +import org.apache.kafka.common.MetricName; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.utils.SystemTime; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.kstream.Suppressed; +import org.apache.kafka.streams.kstream.internals.Change; +import org.apache.kafka.streams.kstream.internals.KTableImpl; +import org.apache.kafka.streams.processor.StateStore; +import org.apache.kafka.streams.processor.StateStoreContext; +import org.apache.kafka.streams.processor.TaskId; +import org.apache.kafka.streams.processor.api.Processor; +import org.apache.kafka.streams.processor.api.Record; +import org.apache.kafka.streams.processor.internals.ProcessorNode; +import org.apache.kafka.streams.state.internals.InMemoryTimeOrderedKeyValueBuffer; +import org.apache.kafka.test.MockInternalNewProcessorContext; +import org.apache.kafka.test.StreamsTestUtils; +import org.apache.kafka.test.TestUtils; +import org.easymock.EasyMock; +import org.hamcrest.Matcher; +import org.junit.Test; + +import java.time.Duration; +import java.util.Map; +import java.util.Properties; + +import static org.apache.kafka.common.utils.Utils.mkEntry; +import static org.apache.kafka.common.utils.Utils.mkMap; +import static org.apache.kafka.streams.kstream.Suppressed.BufferConfig.maxRecords; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.greaterThan; +import static org.hamcrest.core.Is.is; + +public class KTableSuppressProcessorMetricsTest { + private static final long ARBITRARY_LONG = 5L; + private static final TaskId TASK_ID = new TaskId(0, 0); + private final Properties streamsConfig = StreamsTestUtils.getStreamsConfig(); + private final String threadId = Thread.currentThread().getName(); + + private final MetricName evictionTotalMetricLatest = new MetricName( + "suppression-emit-total", + "stream-processor-node-metrics", + "The total number of emitted records from the suppression buffer", + mkMap( + mkEntry("thread-id", threadId), + mkEntry("task-id", TASK_ID.toString()), + mkEntry("processor-node-id", "testNode") + ) + ); + + private final MetricName evictionRateMetricLatest = new MetricName( + "suppression-emit-rate", + "stream-processor-node-metrics", + "The average number of emitted records from the suppression buffer per second", + mkMap( + mkEntry("thread-id", threadId), + mkEntry("task-id", TASK_ID.toString()), + mkEntry("processor-node-id", "testNode") + ) + ); + + private final MetricName bufferSizeAvgMetricLatest = new MetricName( + "suppression-buffer-size-avg", + "stream-state-metrics", + "The average size of buffered records", + mkMap( + mkEntry("thread-id", threadId), + mkEntry("task-id", TASK_ID.toString()), + mkEntry("in-memory-suppression-state-id", "test-store") + ) + ); + + private final MetricName bufferSizeMaxMetricLatest = new MetricName( + "suppression-buffer-size-max", + "stream-state-metrics", + "The maximum size of buffered records", + 
        mkMap(
+            mkEntry("thread-id", threadId),
+            mkEntry("task-id", TASK_ID.toString()),
+            mkEntry("in-memory-suppression-state-id", "test-store")
+        )
+    );
+
+    private final MetricName bufferCountAvgMetricLatest = new MetricName(
+        "suppression-buffer-count-avg",
+        "stream-state-metrics",
+        "The average count of buffered records",
+        mkMap(
+            mkEntry("thread-id", threadId),
+            mkEntry("task-id", TASK_ID.toString()),
+            mkEntry("in-memory-suppression-state-id", "test-store")
+        )
+    );
+
+    private final MetricName bufferCountMaxMetricLatest = new MetricName(
+        "suppression-buffer-count-max",
+        "stream-state-metrics",
+        "The maximum count of buffered records",
+        mkMap(
+            mkEntry("thread-id", threadId),
+            mkEntry("task-id", TASK_ID.toString()),
+            mkEntry("in-memory-suppression-state-id", "test-store")
+        )
+    );
+
+    @Test
+    public void shouldRecordMetricsWithBuiltInMetricsVersionLatest() {
+        final String storeName = "test-store";
+
+        final StateStore buffer = new InMemoryTimeOrderedKeyValueBuffer.Builder<>(
+            storeName, Serdes.String(),
+            Serdes.Long()
+        )
+            .withLoggingDisabled()
+            .build();
+
+        final KTableImpl mock = EasyMock.mock(KTableImpl.class);
+        final Processor<String, Change<Long>, String, Change<Long>> processor =
+            new KTableSuppressProcessorSupplier<>(
+                (SuppressedInternal) Suppressed.untilTimeLimit(Duration.ofDays(100), maxRecords(1)),
+                storeName,
+                mock
+            ).get();
+
+        streamsConfig.setProperty(StreamsConfig.BUILT_IN_METRICS_VERSION_CONFIG, StreamsConfig.METRICS_LATEST);
+        final MockInternalNewProcessorContext<String, Change<Long>> context =
+            new MockInternalNewProcessorContext<>(streamsConfig, TASK_ID, TestUtils.tempDirectory());
+        final Time time = new SystemTime();
+        context.setCurrentNode(new ProcessorNode("testNode"));
+        context.setSystemTimeMs(time.milliseconds());
+
+        buffer.init((StateStoreContext) context, buffer);
+        processor.init(context);
+
+        final long timestamp = 100L;
+        context.setRecordMetadata("", 0, 0L);
+        context.setTimestamp(timestamp);
+        final String key = "longKey";
+        final Change<Long> value = new Change<>(null, ARBITRARY_LONG);
+        processor.process(new Record<>(key, value, timestamp));
+
+        final MetricName evictionRateMetric = evictionRateMetricLatest;
+        final MetricName evictionTotalMetric = evictionTotalMetricLatest;
+        final MetricName bufferSizeAvgMetric = bufferSizeAvgMetricLatest;
+        final MetricName bufferSizeMaxMetric = bufferSizeMaxMetricLatest;
+        final MetricName bufferCountAvgMetric = bufferCountAvgMetricLatest;
+        final MetricName bufferCountMaxMetric = bufferCountMaxMetricLatest;
+
+        {
+            final Map<MetricName, ? extends Metric> metrics = context.metrics().metrics();
+
+            verifyMetric(metrics, evictionRateMetric, is(0.0));
+            verifyMetric(metrics, evictionTotalMetric, is(0.0));
+            verifyMetric(metrics, bufferSizeAvgMetric, is(21.5));
+            verifyMetric(metrics, bufferSizeMaxMetric, is(43.0));
+            verifyMetric(metrics, bufferCountAvgMetric, is(0.5));
+            verifyMetric(metrics, bufferCountMaxMetric, is(1.0));
+        }
+
+        context.setRecordMetadata("", 0, 1L);
+        context.setTimestamp(timestamp + 1);
+        processor.process(new Record<>("key", value, timestamp + 1));
+
+        {
+            final Map<MetricName, ? extends Metric> metrics = context.metrics().metrics();
+
+            verifyMetric(metrics, evictionRateMetric, greaterThan(0.0));
+            verifyMetric(metrics, evictionTotalMetric, is(1.0));
+            verifyMetric(metrics, bufferSizeAvgMetric, is(41.0));
+            verifyMetric(metrics, bufferSizeMaxMetric, is(82.0));
+            verifyMetric(metrics, bufferCountAvgMetric, is(1.0));
+            verifyMetric(metrics, bufferCountMaxMetric, is(2.0));
+        }
+    }
+
+    @SuppressWarnings("unchecked")
+    private static <T> void verifyMetric(final Map<MetricName, ? extends Metric> metrics,
+                                         final 
MetricName metricName, + final Matcher matcher) { + assertThat(metrics.get(metricName).metricName().description(), is(metricName.description())); + assertThat((T) metrics.get(metricName).metricValue(), matcher); + } +} \ No newline at end of file diff --git a/streams/src/test/java/org/apache/kafka/streams/kstream/internals/suppress/KTableSuppressProcessorTest.java b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/suppress/KTableSuppressProcessorTest.java new file mode 100644 index 0000000..1505d0d --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/suppress/KTableSuppressProcessorTest.java @@ -0,0 +1,482 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.kstream.internals.suppress; + +import org.apache.kafka.common.header.Headers; +import org.apache.kafka.common.header.internals.RecordHeaders; +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.streams.errors.StreamsException; +import org.apache.kafka.streams.kstream.Suppressed; +import org.apache.kafka.streams.kstream.TimeWindowedDeserializer; +import org.apache.kafka.streams.kstream.TimeWindowedSerializer; +import org.apache.kafka.streams.kstream.Windowed; +import org.apache.kafka.streams.kstream.internals.Change; +import org.apache.kafka.streams.kstream.internals.KTableImpl; +import org.apache.kafka.streams.kstream.internals.SessionWindow; +import org.apache.kafka.streams.kstream.internals.TimeWindow; +import org.apache.kafka.streams.processor.StateStore; +import org.apache.kafka.streams.processor.StateStoreContext; +import org.apache.kafka.streams.processor.api.MockProcessorContext; +import org.apache.kafka.streams.processor.api.Processor; +import org.apache.kafka.streams.processor.api.Record; +import org.apache.kafka.streams.processor.internals.ProcessorNode; +import org.apache.kafka.streams.state.internals.InMemoryTimeOrderedKeyValueBuffer; +import org.apache.kafka.test.MockInternalNewProcessorContext; +import org.easymock.EasyMock; +import org.hamcrest.BaseMatcher; +import org.hamcrest.Description; +import org.hamcrest.Matcher; +import org.junit.Test; + +import java.nio.charset.StandardCharsets; +import java.time.Duration; +import java.util.Collection; + +import static java.time.Duration.ZERO; +import static java.time.Duration.ofMillis; +import static org.apache.kafka.common.serialization.Serdes.Long; +import static org.apache.kafka.common.serialization.Serdes.String; +import static org.apache.kafka.streams.kstream.Suppressed.BufferConfig.maxBytes; +import static org.apache.kafka.streams.kstream.Suppressed.BufferConfig.maxRecords; +import static org.apache.kafka.streams.kstream.Suppressed.BufferConfig.unbounded; +import 
static org.apache.kafka.streams.kstream.Suppressed.untilTimeLimit; +import static org.apache.kafka.streams.kstream.Suppressed.untilWindowCloses; +import static org.apache.kafka.streams.kstream.WindowedSerdes.sessionWindowedSerdeFrom; +import static org.hamcrest.CoreMatchers.containsString; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.Assert.fail; + +public class KTableSuppressProcessorTest { + private static final long ARBITRARY_LONG = 5L; + + private static final Change ARBITRARY_CHANGE = new Change<>(7L, 14L); + + private static class Harness { + private final Processor, K, Change> processor; + private final MockInternalNewProcessorContext> context; + + + Harness(final Suppressed suppressed, + final Serde keySerde, + final Serde valueSerde) { + + final String storeName = "test-store"; + + final StateStore buffer = new InMemoryTimeOrderedKeyValueBuffer.Builder<>(storeName, keySerde, valueSerde) + .withLoggingDisabled() + .build(); + + final KTableImpl parent = EasyMock.mock(KTableImpl.class); + final Processor, K, Change> processor = + new KTableSuppressProcessorSupplier<>((SuppressedInternal) suppressed, storeName, parent).get(); + + final MockInternalNewProcessorContext> context = new MockInternalNewProcessorContext<>(); + context.setCurrentNode(new ProcessorNode("testNode")); + + buffer.init((StateStoreContext) context, buffer); + processor.init(context); + + this.processor = processor; + this.context = context; + } + } + + @Test + public void zeroTimeLimitShouldImmediatelyEmit() { + final Harness harness = + new Harness<>(untilTimeLimit(ZERO, unbounded()), String(), Long()); + final MockInternalNewProcessorContext> context = harness.context; + + final long timestamp = ARBITRARY_LONG; + context.setRecordMetadata("", 0, 0L); + context.setTimestamp(timestamp); + final String key = "hey"; + final Change value = ARBITRARY_CHANGE; + harness.processor.process(new Record<>(key, value, timestamp)); + + assertThat(context.forwarded(), hasSize(1)); + final MockProcessorContext.CapturedForward capturedForward = context.forwarded().get(0); + assertThat(capturedForward.record(), is(new Record<>(key, value, timestamp))); + } + + @Test + public void windowedZeroTimeLimitShouldImmediatelyEmit() { + final Harness, Long> harness = + new Harness<>(untilTimeLimit(ZERO, unbounded()), timeWindowedSerdeFrom(String.class, 100L), Long()); + final MockInternalNewProcessorContext, Change> context = harness.context; + + final long timestamp = ARBITRARY_LONG; + context.setRecordMetadata("", 0, 0L); + context.setTimestamp(timestamp); + final Windowed key = new Windowed<>("hey", new TimeWindow(0L, 100L)); + final Change value = ARBITRARY_CHANGE; + harness.processor.process(new Record<>(key, value, timestamp)); + + assertThat(context.forwarded(), hasSize(1)); + final MockProcessorContext.CapturedForward capturedForward = context.forwarded().get(0); + assertThat(capturedForward.record(), is(new Record<>(key, value, timestamp))); + } + + @Test + public void intermediateSuppressionShouldBufferAndEmitLater() { + final Harness harness = + new Harness<>(untilTimeLimit(ofMillis(1), unbounded()), String(), Long()); + final MockInternalNewProcessorContext> context = harness.context; + + final long timestamp = 0L; + context.setRecordMetadata("topic", 0, 0); + context.setTimestamp(timestamp); + final String key = "hey"; + final Change value = new Change<>(null, 1L); + harness.processor.process(new Record<>(key, value, timestamp)); + 
assertThat(context.forwarded(), hasSize(0)); + + context.setRecordMetadata("topic", 0, 1); + context.setTimestamp(1L); + harness.processor.process(new Record<>("tick", new Change<>(null, null), 1L)); + + assertThat(context.forwarded(), hasSize(1)); + final MockProcessorContext.CapturedForward capturedForward = context.forwarded().get(0); + assertThat(capturedForward.record(), is(new Record<>(key, value, timestamp))); + } + + @Test + public void finalResultsSuppressionShouldBufferAndEmitAtGraceExpiration() { + final Harness, Long> harness = + new Harness<>(finalResults(ofMillis(1L)), timeWindowedSerdeFrom(String.class, 1L), Long()); + final MockInternalNewProcessorContext, Change> context = harness.context; + + final long windowStart = 99L; + final long recordTime = 99L; + final long windowEnd = 100L; + context.setRecordMetadata("topic", 0, 0); + context.setTimestamp(recordTime); + final Windowed key = new Windowed<>("hey", new TimeWindow(windowStart, windowEnd)); + final Change value = ARBITRARY_CHANGE; + harness.processor.process(new Record<>(key, value, recordTime)); + assertThat(context.forwarded(), hasSize(0)); + + // although the stream time is now 100, we have to wait 1 ms after the window *end* before we + // emit "hey", so we don't emit yet. + final long windowStart2 = 100L; + final long recordTime2 = 100L; + final long windowEnd2 = 101L; + context.setRecordMetadata("topic", 0, 1); + context.setTimestamp(recordTime2); + harness.processor.process(new Record<>(new Windowed<>("dummyKey1", new TimeWindow(windowStart2, windowEnd2)), ARBITRARY_CHANGE, recordTime2)); + assertThat(context.forwarded(), hasSize(0)); + + // ok, now it's time to emit "hey" + final long windowStart3 = 101L; + final long recordTime3 = 101L; + final long windowEnd3 = 102L; + context.setRecordMetadata("topic", 0, 1); + context.setTimestamp(recordTime3); + harness.processor.process(new Record<>(new Windowed<>("dummyKey2", new TimeWindow(windowStart3, windowEnd3)), ARBITRARY_CHANGE, recordTime3)); + + assertThat(context.forwarded(), hasSize(1)); + final MockProcessorContext.CapturedForward capturedForward = context.forwarded().get(0); + assertThat(capturedForward.record(), is(new Record<>(key, value, recordTime))); + } + + /** + * Testing a special case of final results: that even with a grace period of 0, + * it will still buffer events and emit only after the end of the window. + * As opposed to emitting immediately the way regular suppression would with a time limit of 0. + */ + @Test + public void finalResultsWithZeroGraceShouldStillBufferUntilTheWindowEnd() { + final Harness, Long> harness = + new Harness<>(finalResults(ofMillis(0L)), timeWindowedSerdeFrom(String.class, 100L), Long()); + final MockInternalNewProcessorContext, Change> context = harness.context; + + // note the record is in the past, but the window end is in the future, so we still have to buffer, + // even though the grace period is 0. 
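+        // concretely: the window here is [0, 100) with grace 0 ms, so nothing can be emitted until stream time reaches 100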
+ final long timestamp = 5L; + final long windowEnd = 100L; + context.setRecordMetadata("", 0, 0L); + context.setTimestamp(timestamp); + final Windowed key = new Windowed<>("hey", new TimeWindow(0, windowEnd)); + final Change value = ARBITRARY_CHANGE; + harness.processor.process(new Record<>(key, value, timestamp)); + assertThat(context.forwarded(), hasSize(0)); + + context.setRecordMetadata("", 0, 1L); + context.setTimestamp(windowEnd); + harness.processor.process(new Record<>(new Windowed<>("dummyKey", new TimeWindow(windowEnd, windowEnd + 100L)), ARBITRARY_CHANGE, windowEnd)); + + assertThat(context.forwarded(), hasSize(1)); + final MockProcessorContext.CapturedForward capturedForward = context.forwarded().get(0); + assertThat(capturedForward.record(), is(new Record<>(key, value, timestamp))); + } + + @Test + public void finalResultsWithZeroGraceAtWindowEndShouldImmediatelyEmit() { + final Harness, Long> harness = + new Harness<>(finalResults(ofMillis(0L)), timeWindowedSerdeFrom(String.class, 100L), Long()); + final MockInternalNewProcessorContext, Change> context = harness.context; + + final long timestamp = 100L; + context.setRecordMetadata("", 0, 0L); + context.setTimestamp(timestamp); + final Windowed key = new Windowed<>("hey", new TimeWindow(0, 100L)); + final Change value = ARBITRARY_CHANGE; + harness.processor.process(new Record<>(key, value, timestamp)); + + assertThat(context.forwarded(), hasSize(1)); + final MockProcessorContext.CapturedForward capturedForward = context.forwarded().get(0); + assertThat(capturedForward.record(), is(new Record<>(key, value, timestamp))); + } + + /** + * It's desirable to drop tombstones for final-results windowed streams, since (as described in the + * {@link SuppressedInternal} javadoc), they are unnecessary to emit. + */ + @Test + public void finalResultsShouldDropTombstonesForTimeWindows() { + final Harness, Long> harness = + new Harness<>(finalResults(ofMillis(0L)), timeWindowedSerdeFrom(String.class, 100L), Long()); + final MockInternalNewProcessorContext, Change> context = harness.context; + + final long timestamp = 100L; + context.setRecordMetadata("", 0, 0L); + context.setTimestamp(timestamp); + final Windowed key = new Windowed<>("hey", new TimeWindow(0, 100L)); + final Change value = new Change<>(null, ARBITRARY_LONG); + harness.processor.process(new Record<>(key, value, timestamp)); + + assertThat(context.forwarded(), hasSize(0)); + } + + + /** + * It's desirable to drop tombstones for final-results windowed streams, since (as described in the + * {@link SuppressedInternal} javadoc), they are unnecessary to emit. + */ + @Test + public void finalResultsShouldDropTombstonesForSessionWindows() { + final Harness, Long> harness = + new Harness<>(finalResults(ofMillis(0L)), sessionWindowedSerdeFrom(String.class), Long()); + final MockInternalNewProcessorContext, Change> context = harness.context; + + final long timestamp = 100L; + context.setRecordMetadata("", 0, 0L); + context.setTimestamp(timestamp); + final Windowed key = new Windowed<>("hey", new SessionWindow(0L, 0L)); + final Change value = new Change<>(null, ARBITRARY_LONG); + harness.processor.process(new Record<>(key, value, timestamp)); + + assertThat(context.forwarded(), hasSize(0)); + } + + /** + * It's NOT OK to drop tombstones for non-final-results windowed streams, since we may have emitted some results for + * the window before getting the tombstone (see the {@link SuppressedInternal} javadoc). 
+ */ + @Test + public void suppressShouldNotDropTombstonesForTimeWindows() { + final Harness, Long> harness = + new Harness<>(untilTimeLimit(ofMillis(0), maxRecords(0)), timeWindowedSerdeFrom(String.class, 100L), Long()); + final MockInternalNewProcessorContext, Change> context = harness.context; + + final long timestamp = 100L; + final Headers headers = new RecordHeaders().add("k", "v".getBytes(StandardCharsets.UTF_8)); + context.setRecordMetadata("", 0, 0L); + context.setTimestamp(timestamp); + context.setHeaders(headers); + final Windowed key = new Windowed<>("hey", new TimeWindow(0L, 100L)); + final Change value = new Change<>(null, ARBITRARY_LONG); + harness.processor.process(new Record<>(key, value, timestamp)); + + assertThat(context.forwarded(), hasSize(1)); + final MockProcessorContext.CapturedForward capturedForward = context.forwarded().get(0); + assertThat(capturedForward.record(), is(new Record<>(key, value, timestamp, headers))); + } + + + /** + * It's NOT OK to drop tombstones for non-final-results windowed streams, since we may have emitted some results for + * the window before getting the tombstone (see the {@link SuppressedInternal} javadoc). + */ + @Test + public void suppressShouldNotDropTombstonesForSessionWindows() { + final Harness, Long> harness = + new Harness<>(untilTimeLimit(ofMillis(0), maxRecords(0)), sessionWindowedSerdeFrom(String.class), Long()); + final MockInternalNewProcessorContext, Change> context = harness.context; + + final long timestamp = 100L; + context.setRecordMetadata("", 0, 0L); + context.setTimestamp(timestamp); + final Windowed key = new Windowed<>("hey", new SessionWindow(0L, 0L)); + final Change value = new Change<>(null, ARBITRARY_LONG); + harness.processor.process(new Record<>(key, value, timestamp)); + + assertThat(context.forwarded(), hasSize(1)); + final MockProcessorContext.CapturedForward capturedForward = context.forwarded().get(0); + assertThat(capturedForward.record(), is(new Record<>(key, value, timestamp))); + } + + + /** + * It's SUPER NOT OK to drop tombstones for non-windowed streams, since we may have emitted some results for + * the key before getting the tombstone (see the {@link SuppressedInternal} javadoc). 
+ */ + @Test + public void suppressShouldNotDropTombstonesForKTable() { + final Harness harness = + new Harness<>(untilTimeLimit(ofMillis(0), maxRecords(0)), String(), Long()); + final MockInternalNewProcessorContext> context = harness.context; + + final long timestamp = 100L; + context.setRecordMetadata("", 0, 0L); + context.setTimestamp(timestamp); + final String key = "hey"; + final Change value = new Change<>(null, ARBITRARY_LONG); + harness.processor.process(new Record<>(key, value, timestamp)); + + assertThat(context.forwarded(), hasSize(1)); + final MockProcessorContext.CapturedForward capturedForward = context.forwarded().get(0); + assertThat(capturedForward.record(), is(new Record<>(key, value, timestamp))); + } + + @Test + public void suppressShouldEmitWhenOverRecordCapacity() { + final Harness harness = + new Harness<>(untilTimeLimit(Duration.ofDays(100), maxRecords(1)), String(), Long()); + final MockInternalNewProcessorContext> context = harness.context; + + final long timestamp = 100L; + context.setRecordMetadata("", 0, 0L); + context.setTimestamp(timestamp); + final String key = "hey"; + final Change value = new Change<>(null, ARBITRARY_LONG); + harness.processor.process(new Record<>(key, value, timestamp)); + + context.setRecordMetadata("", 0, 1L); + context.setTimestamp(timestamp + 1); + harness.processor.process(new Record<>("dummyKey", value, timestamp + 1)); + + assertThat(context.forwarded(), hasSize(1)); + final MockProcessorContext.CapturedForward capturedForward = context.forwarded().get(0); + assertThat(capturedForward.record(), is(new Record<>(key, value, timestamp))); + } + + @Test + public void suppressShouldEmitWhenOverByteCapacity() { + final Harness harness = + new Harness<>(untilTimeLimit(Duration.ofDays(100), maxBytes(60L)), String(), Long()); + final MockInternalNewProcessorContext> context = harness.context; + + final long timestamp = 100L; + context.setRecordMetadata("", 0, 0L); + context.setTimestamp(timestamp); + final String key = "hey"; + final Change value = new Change<>(null, ARBITRARY_LONG); + harness.processor.process(new Record<>(key, value, timestamp)); + + context.setRecordMetadata("", 0, 1L); + context.setTimestamp(timestamp + 1); + harness.processor.process(new Record<>("dummyKey", value, timestamp + 1)); + + assertThat(context.forwarded(), hasSize(1)); + final MockProcessorContext.CapturedForward capturedForward = context.forwarded().get(0); + assertThat(capturedForward.record(), is(new Record<>(key, value, timestamp))); + } + + @Test + public void suppressShouldShutDownWhenOverRecordCapacity() { + final Harness harness = + new Harness<>(untilTimeLimit(Duration.ofDays(100), maxRecords(1).shutDownWhenFull()), String(), Long()); + final MockInternalNewProcessorContext> context = harness.context; + + final long timestamp = 100L; + context.setRecordMetadata("", 0, 0L); + context.setTimestamp(timestamp); + context.setCurrentNode(new ProcessorNode("testNode")); + final String key = "hey"; + final Change value = new Change<>(null, ARBITRARY_LONG); + harness.processor.process(new Record<>(key, value, timestamp)); + + context.setRecordMetadata("", 0, 1L); + context.setTimestamp(timestamp); + try { + harness.processor.process(new Record<>("dummyKey", value, timestamp)); + fail("expected an exception"); + } catch (final StreamsException e) { + assertThat(e.getMessage(), containsString("buffer exceeded its max capacity")); + } + } + + @Test + public void suppressShouldShutDownWhenOverByteCapacity() { + final Harness harness = + new 
Harness<>(untilTimeLimit(Duration.ofDays(100), maxBytes(60L).shutDownWhenFull()), String(), Long()); + final MockInternalNewProcessorContext> context = harness.context; + + final long timestamp = 100L; + context.setRecordMetadata("", 0, 0L); + context.setTimestamp(timestamp); + context.setCurrentNode(new ProcessorNode("testNode")); + final String key = "hey"; + final Change value = new Change<>(null, ARBITRARY_LONG); + harness.processor.process(new Record<>(key, value, timestamp)); + + context.setRecordMetadata("", 0, 1L); + context.setTimestamp(1L); + try { + harness.processor.process(new Record<>("dummyKey", value, timestamp)); + fail("expected an exception"); + } catch (final StreamsException e) { + assertThat(e.getMessage(), containsString("buffer exceeded its max capacity")); + } + } + + @SuppressWarnings({"unchecked", "rawtypes"}) + private static SuppressedInternal finalResults(final Duration grace) { + return ((FinalResultsSuppressionBuilder) untilWindowCloses(unbounded())).buildFinalResultsSuppression(grace); + } + + private static Matcher> hasSize(final int i) { + return new BaseMatcher>() { + @Override + public void describeTo(final Description description) { + description.appendText("a collection of size " + i); + } + + @SuppressWarnings("unchecked") + @Override + public boolean matches(final Object item) { + if (item == null) { + return false; + } else { + return ((Collection) item).size() == i; + } + } + + }; + } + + private static Serde> timeWindowedSerdeFrom(final Class rawType, final long windowSize) { + final Serde kSerde = Serdes.serdeFrom(rawType); + return new Serdes.WrapperSerde<>( + new TimeWindowedSerializer<>(kSerde.serializer()), + new TimeWindowedDeserializer<>(kSerde.deserializer(), windowSize) + ); + } +} \ No newline at end of file diff --git a/streams/src/test/java/org/apache/kafka/streams/kstream/internals/suppress/SuppressSuite.java b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/suppress/SuppressSuite.java new file mode 100644 index 0000000..a323b9b --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/kstream/internals/suppress/SuppressSuite.java @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.kstream.internals.suppress; + +import org.apache.kafka.streams.integration.SuppressionDurabilityIntegrationTest; +import org.apache.kafka.streams.integration.SuppressionIntegrationTest; +import org.apache.kafka.streams.kstream.SuppressedTest; +import org.apache.kafka.streams.kstream.internals.FullChangeSerdeTest; +import org.apache.kafka.streams.kstream.internals.SuppressScenarioTest; +import org.apache.kafka.streams.kstream.internals.SuppressTopologyTest; +import org.apache.kafka.streams.state.internals.BufferValueTest; +import org.apache.kafka.streams.state.internals.InMemoryTimeOrderedKeyValueBufferTest; +import org.apache.kafka.streams.state.internals.TimeOrderedKeyValueBufferTest; +import org.junit.runner.RunWith; +import org.junit.runners.Suite; + +/** + * This suite runs all the tests related to the Suppression feature. + * + * It can be used from an IDE to selectively just run these tests when developing code related to Suppress. + * + * If desired, it can also be added to a Gradle build task, although this isn't strictly necessary, since all + * these tests are already included in the `:streams:test` task. + */ +@RunWith(Suite.class) +@Suite.SuiteClasses({ + BufferValueTest.class, + KTableSuppressProcessorMetricsTest.class, + KTableSuppressProcessorTest.class, + SuppressScenarioTest.class, + SuppressTopologyTest.class, + SuppressedTest.class, + InMemoryTimeOrderedKeyValueBufferTest.class, + TimeOrderedKeyValueBufferTest.class, + FullChangeSerdeTest.class, + SuppressionIntegrationTest.class, + SuppressionDurabilityIntegrationTest.class +}) +public class SuppressSuite { +} + + diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/FailOnInvalidTimestampTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/FailOnInvalidTimestampTest.java new file mode 100644 index 0000000..78ede86 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/processor/FailOnInvalidTimestampTest.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.processor; + +import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.streams.errors.StreamsException; +import static org.junit.Assert.assertThrows; +import org.junit.Test; + +public class FailOnInvalidTimestampTest extends TimestampExtractorTest { + + @Test + public void extractMetadataTimestamp() { + testExtractMetadataTimestamp(new FailOnInvalidTimestamp()); + } + + @Test + public void failOnInvalidTimestamp() { + final TimestampExtractor extractor = new FailOnInvalidTimestamp(); + assertThrows(StreamsException.class, () -> extractor.extract(new ConsumerRecord<>("anyTopic", + 0, 0, null, null), 42)); + } + +} diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/LogAndSkipOnInvalidTimestampTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/LogAndSkipOnInvalidTimestampTest.java new file mode 100644 index 0000000..5474f9f --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/processor/LogAndSkipOnInvalidTimestampTest.java @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.processor; + +import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.common.header.internals.RecordHeaders; +import org.apache.kafka.common.record.TimestampType; +import org.junit.Test; + +import java.util.Optional; + +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.MatcherAssert.assertThat; + +public class LogAndSkipOnInvalidTimestampTest extends TimestampExtractorTest { + + @Test + public void extractMetadataTimestamp() { + testExtractMetadataTimestamp(new LogAndSkipOnInvalidTimestamp()); + } + + @Test + public void logAndSkipOnInvalidTimestamp() { + final long invalidMetadataTimestamp = -42; + + final TimestampExtractor extractor = new LogAndSkipOnInvalidTimestamp(); + final long timestamp = extractor.extract( + new ConsumerRecord<>( + "anyTopic", + 0, + 0, + invalidMetadataTimestamp, + TimestampType.NO_TIMESTAMP_TYPE, + 0, + 0, + null, + null, + new RecordHeaders(), + Optional.empty()), + 0 + ); + + assertThat(timestamp, is(invalidMetadataTimestamp)); + } + +} diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/TimestampExtractorTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/TimestampExtractorTest.java new file mode 100644 index 0000000..dca9a70 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/processor/TimestampExtractorTest.java @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.processor; + +import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.common.header.internals.RecordHeaders; +import org.apache.kafka.common.record.TimestampType; + +import java.util.Optional; + +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.MatcherAssert.assertThat; + +class TimestampExtractorTest { + + void testExtractMetadataTimestamp(final TimestampExtractor extractor) { + final long metadataTimestamp = 42; + + final long timestamp = extractor.extract( + new ConsumerRecord<>( + "anyTopic", + 0, + 0, + metadataTimestamp, + TimestampType.NO_TIMESTAMP_TYPE, + 0, + 0, + null, + null, + new RecordHeaders(), + Optional.empty()), + 0 + ); + + assertThat(timestamp, is(metadataTimestamp)); + } + +} diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/UsePartitionTimeOnInvalidTimestampTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/UsePartitionTimeOnInvalidTimestampTest.java new file mode 100644 index 0000000..fb93032 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/processor/UsePartitionTimeOnInvalidTimestampTest.java @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.processor; + +import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.streams.errors.StreamsException; +import org.junit.Test; + +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.Assert.fail; + +public class UsePartitionTimeOnInvalidTimestampTest extends TimestampExtractorTest { + + @Test + public void extractMetadataTimestamp() { + testExtractMetadataTimestamp(new UsePartitionTimeOnInvalidTimestamp()); + } + + @Test + public void usePartitionTimeOnInvalidTimestamp() { + final long partitionTime = 42; + + final TimestampExtractor extractor = new UsePartitionTimeOnInvalidTimestamp(); + final long timestamp = extractor.extract( + new ConsumerRecord<>("anyTopic", 0, 0, null, null), + partitionTime + ); + + assertThat(timestamp, is(partitionTime)); + } + + @Test + public void shouldThrowStreamsException() { + final TimestampExtractor extractor = new UsePartitionTimeOnInvalidTimestamp(); + final ConsumerRecord record = new ConsumerRecord<>("anyTopic", 0, 0, null, null); + try { + extractor.extract(record, -1); + fail("should have thrown StreamsException"); + } catch (final StreamsException expected) { } + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/WallclockTimestampExtractorTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/WallclockTimestampExtractorTest.java new file mode 100644 index 0000000..ac4f4e7 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/processor/WallclockTimestampExtractorTest.java @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.processor; + +import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.hamcrest.BaseMatcher; +import org.hamcrest.Description; +import org.junit.Test; + +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.MatcherAssert.assertThat; + +public class WallclockTimestampExtractorTest { + + @Test + public void extractSystemTimestamp() { + final TimestampExtractor extractor = new WallclockTimestampExtractor(); + + final long before = System.currentTimeMillis(); + final long timestamp = extractor.extract(new ConsumerRecord<>("anyTopic", 0, 0, null, null), 42); + final long after = System.currentTimeMillis(); + + assertThat(timestamp, is(new InBetween(before, after))); + } + + private static class InBetween extends BaseMatcher { + private final long before; + private final long after; + + public InBetween(final long before, final long after) { + this.before = before; + this.after = after; + } + + @Override + public boolean matches(final Object item) { + final long timestamp = (Long) item; + return before <= timestamp && timestamp <= after; + } + + @Override + public void describeMismatch(final Object item, final Description mismatchDescription) {} + + @Override + public void describeTo(final Description description) {} + } + +} diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/AbstractProcessorContextTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/AbstractProcessorContextTest.java new file mode 100644 index 0000000..7a78af5 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/AbstractProcessorContextTest.java @@ -0,0 +1,265 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.processor.internals; + +import org.apache.kafka.common.config.ConfigException; +import org.apache.kafka.common.header.Header; +import org.apache.kafka.common.header.Headers; +import org.apache.kafka.common.header.internals.RecordHeader; +import org.apache.kafka.common.header.internals.RecordHeaders; +import org.apache.kafka.common.metrics.Metrics; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.processor.Cancellable; +import org.apache.kafka.streams.processor.PunctuationType; +import org.apache.kafka.streams.processor.Punctuator; +import org.apache.kafka.streams.processor.StateStore; +import org.apache.kafka.streams.processor.TaskId; +import org.apache.kafka.streams.processor.To; +import org.apache.kafka.streams.processor.api.Record; +import org.apache.kafka.streams.state.RocksDBConfigSetter; +import org.apache.kafka.streams.state.internals.ThreadCache; +import org.apache.kafka.streams.state.internals.ThreadCache.DirtyEntryFlushListener; +import org.apache.kafka.test.MockKeyValueStore; +import org.junit.Before; +import org.junit.Test; + +import java.time.Duration; +import java.util.Properties; + +import static org.apache.kafka.test.StreamsTestUtils.getStreamsConfig; +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.CoreMatchers.nullValue; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.emptyIterable; +import static org.hamcrest.Matchers.is; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.fail; + +public class AbstractProcessorContextTest { + + private final MockStreamsMetrics metrics = new MockStreamsMetrics(new Metrics()); + private final AbstractProcessorContext context = new TestProcessorContext(metrics); + private final MockKeyValueStore stateStore = new MockKeyValueStore("store", false); + private final Headers headers = new RecordHeaders(new Header[]{new RecordHeader("key", "value".getBytes())}); + private final ProcessorRecordContext recordContext = new ProcessorRecordContext(10, System.currentTimeMillis(), 1, "foo", headers); + + @Before + public void before() { + context.setRecordContext(recordContext); + } + + @Test + public void shouldThrowIllegalStateExceptionOnRegisterWhenContextIsInitialized() { + context.initialize(); + try { + context.register(stateStore, null); + fail("should throw illegal state exception when context already initialized"); + } catch (final IllegalStateException e) { + // pass + } + } + + @Test + public void shouldNotThrowIllegalStateExceptionOnRegisterWhenContextIsNotInitialized() { + context.register(stateStore, null); + } + + @Test + public void shouldThrowNullPointerOnRegisterIfStateStoreIsNull() { + assertThrows(NullPointerException.class, () -> context.register(null, null)); + } + + @Test + public void shouldReturnNullTopicIfNoRecordContext() { + context.setRecordContext(null); + assertThat(context.topic(), is(nullValue())); + } + + @Test + public void shouldNotThrowNullPointerExceptionOnTopicIfRecordContextTopicIsNull() { + context.setRecordContext(new ProcessorRecordContext(0, 0, 0, null, new RecordHeaders())); + assertThat(context.topic(), nullValue()); + } + + @Test + public void shouldReturnTopicFromRecordContext() { + assertThat(context.topic(), equalTo(recordContext.topic())); + } + + @Test + public void 
shouldReturnNullIfTopicEqualsNonExistTopic() { + context.setRecordContext(null); + assertThat(context.topic(), nullValue()); + } + + @Test + public void shouldReturnDummyPartitionIfNoRecordContext() { + context.setRecordContext(null); + assertThat(context.partition(), is(-1)); + } + + @Test + public void shouldReturnPartitionFromRecordContext() { + assertThat(context.partition(), equalTo(recordContext.partition())); + } + + @Test + public void shouldThrowIllegalStateExceptionOnOffsetIfNoRecordContext() { + context.setRecordContext(null); + try { + context.offset(); + } catch (final IllegalStateException e) { + // pass + } + } + + @Test + public void shouldReturnOffsetFromRecordContext() { + assertThat(context.offset(), equalTo(recordContext.offset())); + } + + @Test + public void shouldReturnDummyTimestampIfNoRecordContext() { + context.setRecordContext(null); + assertThat(context.timestamp(), is(0L)); + } + + @Test + public void shouldReturnTimestampFromRecordContext() { + assertThat(context.timestamp(), equalTo(recordContext.timestamp())); + } + + @Test + public void shouldReturnHeadersFromRecordContext() { + assertThat(context.headers(), equalTo(recordContext.headers())); + } + + @Test + public void shouldReturnEmptyHeadersIfHeadersAreNotSet() { + context.setRecordContext(null); + assertThat(context.headers(), is(emptyIterable())); + } + + @Test + public void appConfigsShouldReturnParsedValues() { + assertThat( + context.appConfigs().get(StreamsConfig.ROCKSDB_CONFIG_SETTER_CLASS_CONFIG), + equalTo(RocksDBConfigSetter.class) + ); + } + + @Test + public void appConfigsShouldReturnUnrecognizedValues() { + assertThat( + context.appConfigs().get("user.supplied.config"), + equalTo("user-supplied-value") + ); + } + @Test + public void shouldThrowErrorIfSerdeDefaultNotSet() { + final Properties config = getStreamsConfig(); + config.put(StreamsConfig.ROCKSDB_CONFIG_SETTER_CLASS_CONFIG, RocksDBConfigSetter.class.getName()); + config.put("user.supplied.config", "user-supplied-value"); + final TestProcessorContext pc = new TestProcessorContext(metrics, config); + assertThrows(ConfigException.class, pc::keySerde); + assertThrows(ConfigException.class, pc::valueSerde); + } + + private static class TestProcessorContext extends AbstractProcessorContext { + static Properties config; + static { + config = getStreamsConfig(); + // Value must be a string to test className -> class conversion + config.put(StreamsConfig.ROCKSDB_CONFIG_SETTER_CLASS_CONFIG, RocksDBConfigSetter.class.getName()); + config.put(StreamsConfig.DEFAULT_VALUE_SERDE_CLASS_CONFIG, Serdes.ByteArraySerde.class); + config.put(StreamsConfig.DEFAULT_KEY_SERDE_CLASS_CONFIG, Serdes.ByteArraySerde.class); + config.put("user.supplied.config", "user-supplied-value"); + } + + TestProcessorContext(final MockStreamsMetrics metrics) { + super(new TaskId(0, 0), new StreamsConfig(config), metrics, new ThreadCache(new LogContext("name "), 0, metrics)); + } + + TestProcessorContext(final MockStreamsMetrics metrics, final Properties config) { + super(new TaskId(0, 0), new StreamsConfig(config), metrics, new ThreadCache(new LogContext("name "), 0, metrics)); + } + + @Override + protected StateManager stateManager() { + return new StateManagerStub(); + } + + @Override + public S getStateStore(final String name) { + return null; + } + + @Override + public Cancellable schedule(final Duration interval, + final PunctuationType type, + final Punctuator callback) throws IllegalArgumentException { + return null; + } + + @Override + public void forward(final Record 
record) {} + + @Override + public void forward(final Record record, final String childName) {} + + @Override + public void forward(final K key, final V value) {} + + @Override + public void forward(final K key, final V value, final To to) {} + + @Override + public void commit() {} + + @Override + public long currentStreamTimeMs() { + throw new UnsupportedOperationException("this method is not supported in TestProcessorContext"); + } + + @Override + public void logChange(final String storeName, + final Bytes key, + final byte[] value, + final long timestamp) { + } + + @Override + public void transitionToActive(final StreamTask streamTask, final RecordCollector recordCollector, final ThreadCache newCache) { + } + + @Override + public void transitionToStandby(final ThreadCache newCache) { + } + + @Override + public void registerCacheFlushListener(final String namespace, final DirtyEntryFlushListener listener) { + } + + @Override + public String changelogFor(final String storeName) { + return ProcessorStateManager.storeChangelogTopic(applicationId(), storeName, taskId().topologyName()); + } + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/ActiveTaskCreatorTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/ActiveTaskCreatorTest.java new file mode 100644 index 0000000..74d81bd --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/ActiveTaskCreatorTest.java @@ -0,0 +1,541 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.processor.internals; + +import org.apache.kafka.clients.producer.MockProducer; +import org.apache.kafka.common.Metric; +import org.apache.kafka.common.MetricName; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.metrics.KafkaMetric; +import org.apache.kafka.common.metrics.Measurable; +import org.apache.kafka.common.metrics.Metrics; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.errors.StreamsException; +import org.apache.kafka.streams.processor.TaskId; +import org.apache.kafka.streams.processor.TimestampExtractor; +import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl; +import org.apache.kafka.streams.state.internals.ThreadCache; +import org.apache.kafka.test.MockClientSupplier; +import org.easymock.EasyMockRunner; +import org.easymock.Mock; +import org.easymock.MockType; +import org.junit.Test; +import org.junit.runner.RunWith; + +import java.io.File; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.Set; +import java.util.UUID; +import java.util.stream.Collectors; + +import static org.apache.kafka.common.utils.Utils.mkEntry; +import static org.apache.kafka.common.utils.Utils.mkMap; +import static org.apache.kafka.common.utils.Utils.mkSet; +import static org.easymock.EasyMock.expect; +import static org.easymock.EasyMock.mock; +import static org.easymock.EasyMock.replay; +import static org.easymock.EasyMock.reset; +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.closeTo; +import static org.hamcrest.core.IsNot.not; +import static org.junit.Assert.assertThrows; + +@RunWith(EasyMockRunner.class) +public class ActiveTaskCreatorTest { + + @Mock(type = MockType.NICE) + private InternalTopologyBuilder builder; + @Mock(type = MockType.NICE) + private StateDirectory stateDirectory; + @Mock(type = MockType.NICE) + private ChangelogReader changeLogReader; + + private final MockClientSupplier mockClientSupplier = new MockClientSupplier(); + private final StreamsMetricsImpl streamsMetrics = new StreamsMetricsImpl(new Metrics(), "clientId", StreamsConfig.METRICS_LATEST, new MockTime()); + private final Map properties = mkMap( + mkEntry(StreamsConfig.APPLICATION_ID_CONFIG, "appId"), + mkEntry(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, "dummy:1234") + ); + final UUID uuid = UUID.randomUUID(); + + private ActiveTaskCreator activeTaskCreator; + + + + // non-EOS test + + // functional test + + @Test + public void shouldConstructProducerMetricsWithEosDisabled() { + shouldConstructThreadProducerMetric(); + } + + @Test + public void shouldConstructClientIdWithEosDisabled() { + createTasks(); + + final Set clientIds = activeTaskCreator.producerClientIds(); + + assertThat(clientIds, is(Collections.singleton("clientId-StreamThread-0-producer"))); + } + + @Test + public void shouldCloseThreadProducerIfEosDisabled() { + createTasks(); + + activeTaskCreator.closeThreadProducerIfNeeded(); + + assertThat(mockClientSupplier.producers.get(0).closed(), is(true)); + } + + @Test + public void shouldNoOpCloseTaskProducerIfEosDisabled() { + createTasks(); + + activeTaskCreator.closeAndRemoveTaskProducerIfNeeded(new TaskId(0, 0)); + activeTaskCreator.closeAndRemoveTaskProducerIfNeeded(new TaskId(0, 1)); + + 
assertThat(mockClientSupplier.producers.get(0).closed(), is(false)); + } + + @Test + public void shouldReturnBlockedTimeWhenThreadProducer() { + final double blockedTime = 123.0; + createTasks(); + final MockProducer producer = mockClientSupplier.producers.get(0); + addMetric(producer, "flush-time-ns-total", blockedTime); + + assertThat(activeTaskCreator.totalProducerBlockedTime(), closeTo(blockedTime, 0.01)); + } + + // error handling + + @Test + public void shouldFailOnStreamsProducerPerTaskIfEosDisabled() { + createTasks(); + + final IllegalStateException thrown = assertThrows( + IllegalStateException.class, + () -> activeTaskCreator.streamsProducerForTask(null) + ); + + assertThat(thrown.getMessage(), is("Expected EXACTLY_ONCE to be enabled, but the processing mode was AT_LEAST_ONCE")); + } + + @Test + public void shouldFailOnGetThreadProducerIfEosDisabled() { + createTasks(); + + final IllegalStateException thrown = assertThrows( + IllegalStateException.class, + activeTaskCreator::threadProducer + ); + + assertThat(thrown.getMessage(), is("Expected EXACTLY_ONCE_V2 to be enabled, but the processing mode was AT_LEAST_ONCE")); + } + + @Test + public void shouldThrowStreamsExceptionOnErrorCloseThreadProducerIfEosDisabled() { + createTasks(); + mockClientSupplier.producers.get(0).closeException = new RuntimeException("KABOOM!"); + + final StreamsException thrown = assertThrows( + StreamsException.class, + activeTaskCreator::closeThreadProducerIfNeeded + ); + + assertThat(thrown.getMessage(), is("Thread producer encounter error trying to close.")); + assertThat(thrown.getCause().getMessage(), is("KABOOM!")); + } + + + + // eos-alpha test + + // functional test + + @SuppressWarnings("deprecation") + @Test + public void shouldReturnStreamsProducerPerTaskIfEosAlphaEnabled() { + properties.put(StreamsConfig.PROCESSING_GUARANTEE_CONFIG, StreamsConfig.EXACTLY_ONCE); + + shouldReturnStreamsProducerPerTask(); + } + + @SuppressWarnings("deprecation") + @Test + public void shouldConstructProducerMetricsWithEosAlphaEnabled() { + properties.put(StreamsConfig.PROCESSING_GUARANTEE_CONFIG, StreamsConfig.EXACTLY_ONCE); + + shouldConstructProducerMetricsPerTask(); + } + + @SuppressWarnings("deprecation") + @Test + public void shouldConstructClientIdWithEosAlphaEnabled() { + properties.put(StreamsConfig.PROCESSING_GUARANTEE_CONFIG, StreamsConfig.EXACTLY_ONCE); + mockClientSupplier.setApplicationIdForProducer("appId"); + createTasks(); + + final Set clientIds = activeTaskCreator.producerClientIds(); + + assertThat(clientIds, is(mkSet("clientId-StreamThread-0-0_0-producer", "clientId-StreamThread-0-0_1-producer"))); + } + + @SuppressWarnings("deprecation") + @Test + public void shouldNoOpCloseThreadProducerIfEosAlphaEnabled() { + properties.put(StreamsConfig.PROCESSING_GUARANTEE_CONFIG, StreamsConfig.EXACTLY_ONCE); + mockClientSupplier.setApplicationIdForProducer("appId"); + createTasks(); + + activeTaskCreator.closeThreadProducerIfNeeded(); + + assertThat(mockClientSupplier.producers.get(0).closed(), is(false)); + assertThat(mockClientSupplier.producers.get(1).closed(), is(false)); + } + + @SuppressWarnings("deprecation") + @Test + public void shouldCloseTaskProducersIfEosAlphaEnabled() { + properties.put(StreamsConfig.PROCESSING_GUARANTEE_CONFIG, StreamsConfig.EXACTLY_ONCE); + mockClientSupplier.setApplicationIdForProducer("appId"); + createTasks(); + + activeTaskCreator.closeAndRemoveTaskProducerIfNeeded(new TaskId(0, 0)); + activeTaskCreator.closeAndRemoveTaskProducerIfNeeded(new TaskId(0, 1)); + // should 
no-op unknown task + activeTaskCreator.closeAndRemoveTaskProducerIfNeeded(new TaskId(0, 2)); + + assertThat(mockClientSupplier.producers.get(0).closed(), is(true)); + assertThat(mockClientSupplier.producers.get(1).closed(), is(true)); + + // should not throw because producer should be removed + mockClientSupplier.producers.get(0).closeException = new RuntimeException("KABOOM!"); + activeTaskCreator.closeAndRemoveTaskProducerIfNeeded(new TaskId(0, 0)); + } + + @SuppressWarnings("deprecation") + @Test + public void shouldReturnBlockedTimeWhenTaskProducers() { + properties.put(StreamsConfig.PROCESSING_GUARANTEE_CONFIG, StreamsConfig.EXACTLY_ONCE); + mockClientSupplier.setApplicationIdForProducer("appId"); + createTasks(); + double total = 0.0; + double blocked = 1.0; + for (final MockProducer producer : mockClientSupplier.producers) { + addMetric(producer, "flush-time-ns-total", blocked); + total += blocked; + blocked += 1.0; + } + + assertThat(activeTaskCreator.totalProducerBlockedTime(), closeTo(total, 0.01)); + } + + // error handling + + @SuppressWarnings("deprecation") + @Test + public void shouldFailForUnknownTaskOnStreamsProducerPerTaskIfEosAlphaEnabled() { + properties.put(StreamsConfig.PROCESSING_GUARANTEE_CONFIG, StreamsConfig.EXACTLY_ONCE); + mockClientSupplier.setApplicationIdForProducer("appId"); + + createTasks(); + + { + final IllegalStateException thrown = assertThrows( + IllegalStateException.class, + () -> activeTaskCreator.streamsProducerForTask(null) + ); + + assertThat(thrown.getMessage(), is("Unknown TaskId: null")); + } + { + final IllegalStateException thrown = assertThrows( + IllegalStateException.class, + () -> activeTaskCreator.streamsProducerForTask(new TaskId(0, 2)) + ); + + assertThat(thrown.getMessage(), is("Unknown TaskId: 0_2")); + } + } + + @SuppressWarnings("deprecation") + @Test + public void shouldFailOnGetThreadProducerIfEosAlphaEnabled() { + properties.put(StreamsConfig.PROCESSING_GUARANTEE_CONFIG, StreamsConfig.EXACTLY_ONCE); + mockClientSupplier.setApplicationIdForProducer("appId"); + + createTasks(); + + final IllegalStateException thrown = assertThrows( + IllegalStateException.class, + activeTaskCreator::threadProducer + ); + + assertThat(thrown.getMessage(), is("Expected EXACTLY_ONCE_V2 to be enabled, but the processing mode was EXACTLY_ONCE_ALPHA")); + } + + @SuppressWarnings("deprecation") + @Test + public void shouldThrowStreamsExceptionOnErrorCloseTaskProducerIfEosAlphaEnabled() { + properties.put(StreamsConfig.PROCESSING_GUARANTEE_CONFIG, StreamsConfig.EXACTLY_ONCE); + mockClientSupplier.setApplicationIdForProducer("appId"); + createTasks(); + mockClientSupplier.producers.get(0).closeException = new RuntimeException("KABOOM!"); + + final StreamsException thrown = assertThrows( + StreamsException.class, + () -> activeTaskCreator.closeAndRemoveTaskProducerIfNeeded(new TaskId(0, 0)) + ); + + assertThat(thrown.getMessage(), is("[0_0] task producer encounter error trying to close.")); + assertThat(thrown.getCause().getMessage(), is("KABOOM!")); + + // should not throw again because producer should be removed + activeTaskCreator.closeAndRemoveTaskProducerIfNeeded(new TaskId(0, 0)); + } + + + // eos-v2 test + + // functional test + + @Test + public void shouldReturnThreadProducerIfEosV2Enabled() { + properties.put(StreamsConfig.PROCESSING_GUARANTEE_CONFIG, StreamsConfig.EXACTLY_ONCE_V2); + mockClientSupplier.setApplicationIdForProducer("appId"); + + createTasks(); + + final StreamsProducer threadProducer = activeTaskCreator.threadProducer(); + + 
assertThat(mockClientSupplier.producers.size(), is(1)); + assertThat(threadProducer.kafkaProducer(), is(mockClientSupplier.producers.get(0))); + } + + @Test + public void shouldConstructProducerMetricsWithEosV2Enabled() { + properties.put(StreamsConfig.PROCESSING_GUARANTEE_CONFIG, StreamsConfig.EXACTLY_ONCE_V2); + mockClientSupplier.setApplicationIdForProducer("appId"); + + shouldConstructThreadProducerMetric(); + } + + @Test + public void shouldConstructClientIdWithEosV2Enabled() { + properties.put(StreamsConfig.PROCESSING_GUARANTEE_CONFIG, StreamsConfig.EXACTLY_ONCE_V2); + mockClientSupplier.setApplicationIdForProducer("appId"); + createTasks(); + + final Set clientIds = activeTaskCreator.producerClientIds(); + + assertThat(clientIds, is(Collections.singleton("clientId-StreamThread-0-producer"))); + } + + @Test + public void shouldCloseThreadProducerIfEosV2Enabled() { + properties.put(StreamsConfig.PROCESSING_GUARANTEE_CONFIG, StreamsConfig.EXACTLY_ONCE_V2); + mockClientSupplier.setApplicationIdForProducer("appId"); + createTasks(); + + activeTaskCreator.closeThreadProducerIfNeeded(); + + assertThat(mockClientSupplier.producers.get(0).closed(), is(true)); + } + + @Test + public void shouldNoOpCloseTaskProducerIfEosV2Enabled() { + properties.put(StreamsConfig.PROCESSING_GUARANTEE_CONFIG, StreamsConfig.EXACTLY_ONCE_V2); + mockClientSupplier.setApplicationIdForProducer("appId"); + + createTasks(); + + activeTaskCreator.closeAndRemoveTaskProducerIfNeeded(new TaskId(0, 0)); + activeTaskCreator.closeAndRemoveTaskProducerIfNeeded(new TaskId(0, 1)); + + assertThat(mockClientSupplier.producers.get(0).closed(), is(false)); + } + + // error handling + + @Test + public void shouldFailOnStreamsProducerPerTaskIfEosV2Enabled() { + properties.put(StreamsConfig.PROCESSING_GUARANTEE_CONFIG, StreamsConfig.EXACTLY_ONCE_V2); + mockClientSupplier.setApplicationIdForProducer("appId"); + + createTasks(); + + final IllegalStateException thrown = assertThrows( + IllegalStateException.class, + () -> activeTaskCreator.streamsProducerForTask(null) + ); + + assertThat(thrown.getMessage(), is("Expected EXACTLY_ONCE to be enabled, but the processing mode was EXACTLY_ONCE_V2")); + } + + @Test + public void shouldThrowStreamsExceptionOnErrorCloseThreadProducerIfEosV2Enabled() { + properties.put(StreamsConfig.PROCESSING_GUARANTEE_CONFIG, StreamsConfig.EXACTLY_ONCE_V2); + mockClientSupplier.setApplicationIdForProducer("appId"); + createTasks(); + mockClientSupplier.producers.get(0).closeException = new RuntimeException("KABOOM!"); + + final StreamsException thrown = assertThrows( + StreamsException.class, + activeTaskCreator::closeThreadProducerIfNeeded + ); + + assertThat(thrown.getMessage(), is("Thread producer encounter error trying to close.")); + assertThat(thrown.getCause().getMessage(), is("KABOOM!")); + } + + private void shouldReturnStreamsProducerPerTask() { + mockClientSupplier.setApplicationIdForProducer("appId"); + + createTasks(); + + final StreamsProducer streamsProducer1 = activeTaskCreator.streamsProducerForTask(new TaskId(0, 0)); + final StreamsProducer streamsProducer2 = activeTaskCreator.streamsProducerForTask(new TaskId(0, 1)); + + assertThat(streamsProducer1, not(is(streamsProducer2))); + } + + private void shouldConstructProducerMetricsPerTask() { + mockClientSupplier.setApplicationIdForProducer("appId"); + + createTasks(); + + final MetricName testMetricName1 = new MetricName("test_metric_1", "", "", new HashMap<>()); + final Metric testMetric1 = new KafkaMetric( + new Object(), + testMetricName1, + 
(Measurable) (config, now) -> 0, + null, + new MockTime()); + mockClientSupplier.producers.get(0).setMockMetrics(testMetricName1, testMetric1); + final MetricName testMetricName2 = new MetricName("test_metric_2", "", "", new HashMap<>()); + final Metric testMetric2 = new KafkaMetric( + new Object(), + testMetricName2, + (Measurable) (config, now) -> 0, + null, + new MockTime()); + mockClientSupplier.producers.get(0).setMockMetrics(testMetricName2, testMetric2); + + final Map producerMetrics = activeTaskCreator.producerMetrics(); + + assertThat(producerMetrics, is(mkMap(mkEntry(testMetricName1, testMetric1), mkEntry(testMetricName2, testMetric2)))); + } + + private void shouldConstructThreadProducerMetric() { + createTasks(); + + final MetricName testMetricName = new MetricName("test_metric", "", "", new HashMap<>()); + final Metric testMetric = new KafkaMetric( + new Object(), + testMetricName, + (Measurable) (config, now) -> 0, + null, + new MockTime()); + mockClientSupplier.producers.get(0).setMockMetrics(testMetricName, testMetric); + assertThat(mockClientSupplier.producers.size(), is(1)); + + final Map producerMetrics = activeTaskCreator.producerMetrics(); + + assertThat(producerMetrics.size(), is(1)); + assertThat(producerMetrics.get(testMetricName), is(testMetric)); + } + + @SuppressWarnings({"unchecked", "rawtypes"}) + private void createTasks() { + final TaskId task00 = new TaskId(0, 0); + final TaskId task01 = new TaskId(0, 1); + + final ProcessorTopology topology = mock(ProcessorTopology.class); + final SourceNode sourceNode = mock(SourceNode.class); + + reset(builder, stateDirectory); + expect(builder.buildSubtopology(0)).andReturn(topology).anyTimes(); + expect(stateDirectory.getOrCreateDirectoryForTask(task00)).andReturn(mock(File.class)); + expect(stateDirectory.checkpointFileFor(task00)).andReturn(mock(File.class)); + expect(stateDirectory.getOrCreateDirectoryForTask(task01)).andReturn(mock(File.class)); + expect(stateDirectory.checkpointFileFor(task01)).andReturn(mock(File.class)); + expect(topology.storeToChangelogTopic()).andReturn(Collections.emptyMap()).anyTimes(); + expect(topology.source("topic")).andReturn(sourceNode).anyTimes(); + expect(sourceNode.getTimestampExtractor()).andReturn(mock(TimestampExtractor.class)).anyTimes(); + expect(topology.globalStateStores()).andReturn(Collections.emptyList()).anyTimes(); + expect(topology.terminalNodes()).andStubReturn(Collections.singleton(sourceNode.name())); + expect(topology.sources()).andStubReturn(Collections.singleton(sourceNode)); + replay(builder, stateDirectory, topology, sourceNode); + + final StreamsConfig config = new StreamsConfig(properties); + activeTaskCreator = new ActiveTaskCreator( + new TopologyMetadata(builder, config), + config, + streamsMetrics, + stateDirectory, + changeLogReader, + new ThreadCache(new LogContext(), 0L, streamsMetrics), + new MockTime(), + mockClientSupplier, + "clientId-StreamThread-0", + uuid, + new LogContext().logger(ActiveTaskCreator.class) + ); + + assertThat( + activeTaskCreator.createTasks( + mockClientSupplier.consumer, + mkMap( + mkEntry(task00, Collections.singleton(new TopicPartition("topic", 0))), + mkEntry(task01, Collections.singleton(new TopicPartition("topic", 1))) + ) + ).stream().map(Task::id).collect(Collectors.toSet()), + equalTo(mkSet(task00, task01)) + ); + } + + private void addMetric( + final MockProducer producer, + final String name, + final double value) { + final MetricName metricName = metricName(name); + producer.setMockMetrics(metricName, new Metric() { 
+ @Override + public MetricName metricName() { + return metricName; + } + + @Override + public Object metricValue() { + return value; + } + }); + } + + private MetricName metricName(final String name) { + return new MetricName(name, "", "", Collections.emptyMap()); + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/ChangelogTopicsTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/ChangelogTopicsTest.java new file mode 100644 index 0000000..17db61d --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/ChangelogTopicsTest.java @@ -0,0 +1,212 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.processor.internals; + +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.streams.processor.TaskId; +import org.apache.kafka.streams.processor.internals.InternalTopologyBuilder.TopicsInfo; +import org.apache.kafka.streams.processor.internals.TopologyMetadata.Subtopology; + +import org.junit.Test; + +import java.util.Collections; +import java.util.Map; +import java.util.Set; + +import static org.apache.kafka.common.utils.Utils.mkEntry; +import static org.apache.kafka.common.utils.Utils.mkMap; +import static org.apache.kafka.common.utils.Utils.mkSet; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.SUBTOPOLOGY_0; + +import static org.easymock.EasyMock.expect; +import static org.easymock.EasyMock.mock; +import static org.easymock.EasyMock.replay; +import static org.easymock.EasyMock.verify; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.MatcherAssert.assertThat; + +public class ChangelogTopicsTest { + + private static final String SOURCE_TOPIC_NAME = "source"; + private static final String SINK_TOPIC_NAME = "sink"; + private static final String REPARTITION_TOPIC_NAME = "repartition"; + private static final String CHANGELOG_TOPIC_NAME1 = "changelog1"; + private static final Map TOPIC_CONFIG = Collections.singletonMap("config1", "val1"); + private static final RepartitionTopicConfig REPARTITION_TOPIC_CONFIG = + new RepartitionTopicConfig(REPARTITION_TOPIC_NAME, TOPIC_CONFIG); + private static final UnwindowedChangelogTopicConfig CHANGELOG_TOPIC_CONFIG = + new UnwindowedChangelogTopicConfig(CHANGELOG_TOPIC_NAME1, TOPIC_CONFIG); + + private static final TopicsInfo TOPICS_INFO1 = new TopicsInfo( + mkSet(SINK_TOPIC_NAME), + mkSet(SOURCE_TOPIC_NAME), + mkMap(mkEntry(REPARTITION_TOPIC_NAME, REPARTITION_TOPIC_CONFIG)), + mkMap(mkEntry(CHANGELOG_TOPIC_NAME1, CHANGELOG_TOPIC_CONFIG)) + ); + private static final TopicsInfo TOPICS_INFO2 = new TopicsInfo( + mkSet(SINK_TOPIC_NAME), + mkSet(SOURCE_TOPIC_NAME), + mkMap(mkEntry(REPARTITION_TOPIC_NAME, 
REPARTITION_TOPIC_CONFIG)), + mkMap() + ); + private static final TopicsInfo TOPICS_INFO3 = new TopicsInfo( + mkSet(SINK_TOPIC_NAME), + mkSet(SOURCE_TOPIC_NAME), + mkMap(mkEntry(REPARTITION_TOPIC_NAME, REPARTITION_TOPIC_CONFIG)), + mkMap(mkEntry(SOURCE_TOPIC_NAME, CHANGELOG_TOPIC_CONFIG)) + ); + private static final TopicsInfo TOPICS_INFO4 = new TopicsInfo( + mkSet(SINK_TOPIC_NAME), + mkSet(SOURCE_TOPIC_NAME), + mkMap(mkEntry(REPARTITION_TOPIC_NAME, REPARTITION_TOPIC_CONFIG)), + mkMap(mkEntry(SOURCE_TOPIC_NAME, null), mkEntry(CHANGELOG_TOPIC_NAME1, CHANGELOG_TOPIC_CONFIG)) + ); + private static final TaskId TASK_0_0 = new TaskId(0, 0); + private static final TaskId TASK_0_1 = new TaskId(0, 1); + private static final TaskId TASK_0_2 = new TaskId(0, 2); + + final InternalTopicManager internalTopicManager = mock(InternalTopicManager.class); + + @Test + public void shouldNotContainChangelogsForStatelessTasks() { + expect(internalTopicManager.makeReady(Collections.emptyMap())).andStubReturn(Collections.emptySet()); + final Map<Subtopology, TopicsInfo> topicGroups = mkMap(mkEntry(SUBTOPOLOGY_0, TOPICS_INFO2)); + final Map<Subtopology, Set<TaskId>> tasksForTopicGroup = mkMap(mkEntry(SUBTOPOLOGY_0, mkSet(TASK_0_0, TASK_0_1, TASK_0_2))); + replay(internalTopicManager); + + final ChangelogTopics changelogTopics = + new ChangelogTopics(internalTopicManager, topicGroups, tasksForTopicGroup, "[test] "); + changelogTopics.setup(); + + verify(internalTopicManager); + assertThat(changelogTopics.preExistingPartitionsFor(TASK_0_0), is(Collections.emptySet())); + assertThat(changelogTopics.preExistingPartitionsFor(TASK_0_1), is(Collections.emptySet())); + assertThat(changelogTopics.preExistingPartitionsFor(TASK_0_2), is(Collections.emptySet())); + assertThat(changelogTopics.preExistingSourceTopicBasedPartitions(), is(Collections.emptySet())); + assertThat(changelogTopics.preExistingNonSourceTopicBasedPartitions(), is(Collections.emptySet())); + } + + @Test + public void shouldNotContainAnyPreExistingChangelogsIfChangelogIsNewlyCreated() { + expect(internalTopicManager.makeReady(mkMap(mkEntry(CHANGELOG_TOPIC_NAME1, CHANGELOG_TOPIC_CONFIG)))) + .andStubReturn(mkSet(CHANGELOG_TOPIC_NAME1)); + final Map<Subtopology, TopicsInfo> topicGroups = mkMap(mkEntry(SUBTOPOLOGY_0, TOPICS_INFO1)); + final Set<TaskId> tasks = mkSet(TASK_0_0, TASK_0_1, TASK_0_2); + final Map<Subtopology, Set<TaskId>> tasksForTopicGroup = mkMap(mkEntry(SUBTOPOLOGY_0, tasks)); + replay(internalTopicManager); + + final ChangelogTopics changelogTopics = + new ChangelogTopics(internalTopicManager, topicGroups, tasksForTopicGroup, "[test] "); + changelogTopics.setup(); + + verify(internalTopicManager); + assertThat(CHANGELOG_TOPIC_CONFIG.numberOfPartitions().orElse(Integer.MIN_VALUE), is(3)); + assertThat(changelogTopics.preExistingPartitionsFor(TASK_0_0), is(Collections.emptySet())); + assertThat(changelogTopics.preExistingPartitionsFor(TASK_0_1), is(Collections.emptySet())); + assertThat(changelogTopics.preExistingPartitionsFor(TASK_0_2), is(Collections.emptySet())); + assertThat(changelogTopics.preExistingSourceTopicBasedPartitions(), is(Collections.emptySet())); + assertThat(changelogTopics.preExistingNonSourceTopicBasedPartitions(), is(Collections.emptySet())); + } + + @Test + public void shouldOnlyContainPreExistingNonSourceBasedChangelogs() { + expect(internalTopicManager.makeReady(mkMap(mkEntry(CHANGELOG_TOPIC_NAME1, CHANGELOG_TOPIC_CONFIG)))) + .andStubReturn(Collections.emptySet()); + final Map<Subtopology, TopicsInfo> topicGroups = mkMap(mkEntry(SUBTOPOLOGY_0, TOPICS_INFO1)); + final Set<TaskId> tasks = mkSet(TASK_0_0, TASK_0_1, TASK_0_2); + final Map<Subtopology, Set<TaskId>> tasksForTopicGroup = 
mkMap(mkEntry(SUBTOPOLOGY_0, tasks)); + replay(internalTopicManager); + + final ChangelogTopics changelogTopics = + new ChangelogTopics(internalTopicManager, topicGroups, tasksForTopicGroup, "[test] "); + changelogTopics.setup(); + + verify(internalTopicManager); + assertThat(CHANGELOG_TOPIC_CONFIG.numberOfPartitions().orElse(Integer.MIN_VALUE), is(3)); + final TopicPartition changelogPartition0 = new TopicPartition(CHANGELOG_TOPIC_NAME1, 0); + final TopicPartition changelogPartition1 = new TopicPartition(CHANGELOG_TOPIC_NAME1, 1); + final TopicPartition changelogPartition2 = new TopicPartition(CHANGELOG_TOPIC_NAME1, 2); + assertThat(changelogTopics.preExistingPartitionsFor(TASK_0_0), is(mkSet(changelogPartition0))); + assertThat(changelogTopics.preExistingPartitionsFor(TASK_0_1), is(mkSet(changelogPartition1))); + assertThat(changelogTopics.preExistingPartitionsFor(TASK_0_2), is(mkSet(changelogPartition2))); + assertThat(changelogTopics.preExistingSourceTopicBasedPartitions(), is(Collections.emptySet())); + assertThat( + changelogTopics.preExistingNonSourceTopicBasedPartitions(), + is(mkSet(changelogPartition0, changelogPartition1, changelogPartition2)) + ); + } + + @Test + public void shouldOnlyContainPreExistingSourceBasedChangelogs() { + expect(internalTopicManager.makeReady(Collections.emptyMap())).andStubReturn(Collections.emptySet()); + final Map<Subtopology, TopicsInfo> topicGroups = mkMap(mkEntry(SUBTOPOLOGY_0, TOPICS_INFO3)); + final Set<TaskId> tasks = mkSet(TASK_0_0, TASK_0_1, TASK_0_2); + final Map<Subtopology, Set<TaskId>> tasksForTopicGroup = mkMap(mkEntry(SUBTOPOLOGY_0, tasks)); + replay(internalTopicManager); + + final ChangelogTopics changelogTopics = + new ChangelogTopics(internalTopicManager, topicGroups, tasksForTopicGroup, "[test] "); + changelogTopics.setup(); + + verify(internalTopicManager); + final TopicPartition changelogPartition0 = new TopicPartition(SOURCE_TOPIC_NAME, 0); + final TopicPartition changelogPartition1 = new TopicPartition(SOURCE_TOPIC_NAME, 1); + final TopicPartition changelogPartition2 = new TopicPartition(SOURCE_TOPIC_NAME, 2); + assertThat(changelogTopics.preExistingPartitionsFor(TASK_0_0), is(mkSet(changelogPartition0))); + assertThat(changelogTopics.preExistingPartitionsFor(TASK_0_1), is(mkSet(changelogPartition1))); + assertThat(changelogTopics.preExistingPartitionsFor(TASK_0_2), is(mkSet(changelogPartition2))); + assertThat( + changelogTopics.preExistingSourceTopicBasedPartitions(), + is(mkSet(changelogPartition0, changelogPartition1, changelogPartition2)) + ); + assertThat(changelogTopics.preExistingNonSourceTopicBasedPartitions(), is(Collections.emptySet())); + } + + @Test + public void shouldContainBothTypesOfPreExistingChangelogs() { + expect(internalTopicManager.makeReady(mkMap(mkEntry(CHANGELOG_TOPIC_NAME1, CHANGELOG_TOPIC_CONFIG)))) + .andStubReturn(Collections.emptySet()); + final Map<Subtopology, TopicsInfo> topicGroups = mkMap(mkEntry(SUBTOPOLOGY_0, TOPICS_INFO4)); + final Set<TaskId> tasks = mkSet(TASK_0_0, TASK_0_1, TASK_0_2); + final Map<Subtopology, Set<TaskId>> tasksForTopicGroup = mkMap(mkEntry(SUBTOPOLOGY_0, tasks)); + replay(internalTopicManager); + + final ChangelogTopics changelogTopics = + new ChangelogTopics(internalTopicManager, topicGroups, tasksForTopicGroup, "[test] "); + changelogTopics.setup(); + + verify(internalTopicManager); + assertThat(CHANGELOG_TOPIC_CONFIG.numberOfPartitions().orElse(Integer.MIN_VALUE), is(3)); + final TopicPartition changelogPartition0 = new TopicPartition(CHANGELOG_TOPIC_NAME1, 0); + final TopicPartition changelogPartition1 = new TopicPartition(CHANGELOG_TOPIC_NAME1, 1); + final TopicPartition 
changelogPartition2 = new TopicPartition(CHANGELOG_TOPIC_NAME1, 2); + final TopicPartition sourcePartition0 = new TopicPartition(SOURCE_TOPIC_NAME, 0); + final TopicPartition sourcePartition1 = new TopicPartition(SOURCE_TOPIC_NAME, 1); + final TopicPartition sourcePartition2 = new TopicPartition(SOURCE_TOPIC_NAME, 2); + assertThat(changelogTopics.preExistingPartitionsFor(TASK_0_0), is(mkSet(sourcePartition0, changelogPartition0))); + assertThat(changelogTopics.preExistingPartitionsFor(TASK_0_1), is(mkSet(sourcePartition1, changelogPartition1))); + assertThat(changelogTopics.preExistingPartitionsFor(TASK_0_2), is(mkSet(sourcePartition2, changelogPartition2))); + assertThat( + changelogTopics.preExistingSourceTopicBasedPartitions(), + is(mkSet(sourcePartition0, sourcePartition1, sourcePartition2)) + ); + assertThat( + changelogTopics.preExistingNonSourceTopicBasedPartitions(), + is(mkSet(changelogPartition0, changelogPartition1, changelogPartition2)) + ); + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/ClientUtilsTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/ClientUtilsTest.java new file mode 100644 index 0000000..a6c5e3d --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/ClientUtilsTest.java @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.processor.internals; + +import java.util.Map; +import java.util.Set; +import java.util.concurrent.ExecutionException; +import org.apache.kafka.clients.admin.Admin; +import org.apache.kafka.clients.admin.AdminClient; +import org.apache.kafka.clients.admin.ListOffsetsResult; +import org.apache.kafka.clients.admin.ListOffsetsResult.ListOffsetsResultInfo; +import org.apache.kafka.clients.consumer.Consumer; +import org.apache.kafka.common.KafkaException; +import org.apache.kafka.common.KafkaFuture; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.errors.TimeoutException; +import org.apache.kafka.streams.errors.StreamsException; +import org.easymock.EasyMock; +import org.junit.Test; + +import static java.util.Collections.emptySet; +import static org.apache.kafka.common.utils.Utils.mkSet; +import static org.apache.kafka.streams.processor.internals.ClientUtils.fetchCommittedOffsets; +import static org.apache.kafka.streams.processor.internals.ClientUtils.fetchEndOffsets; +import static org.easymock.EasyMock.expect; +import static org.easymock.EasyMock.replay; +import static org.easymock.EasyMock.verify; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; + +public class ClientUtilsTest { + + private static final Set<TopicPartition> PARTITIONS = mkSet( + new TopicPartition("topic", 1), + new TopicPartition("topic", 2) + ); + + @Test + public void fetchCommittedOffsetsShouldRethrowKafkaExceptionAsStreamsException() { + final Consumer<byte[], byte[]> consumer = EasyMock.createMock(Consumer.class); + expect(consumer.committed(PARTITIONS)).andThrow(new KafkaException()); + replay(consumer); + assertThrows(StreamsException.class, () -> fetchCommittedOffsets(PARTITIONS, consumer)); + } + + @Test + public void fetchCommittedOffsetsShouldRethrowTimeoutException() { + final Consumer<byte[], byte[]> consumer = EasyMock.createMock(Consumer.class); + expect(consumer.committed(PARTITIONS)).andThrow(new TimeoutException()); + replay(consumer); + assertThrows(TimeoutException.class, () -> fetchCommittedOffsets(PARTITIONS, consumer)); + } + + @Test + public void fetchCommittedOffsetsShouldReturnEmptyMapIfPartitionsAreEmpty() { + final Consumer<byte[], byte[]> consumer = EasyMock.createMock(Consumer.class); + assertTrue(fetchCommittedOffsets(emptySet(), consumer).isEmpty()); + } + + @Test + public void fetchEndOffsetsShouldReturnEmptyMapIfPartitionsAreEmpty() { + final Admin adminClient = EasyMock.createMock(AdminClient.class); + assertTrue(fetchEndOffsets(emptySet(), adminClient).isEmpty()); + } + + @Test + public void fetchEndOffsetsShouldRethrowRuntimeExceptionAsStreamsException() throws Exception { + final Admin adminClient = EasyMock.createMock(AdminClient.class); + final ListOffsetsResult result = EasyMock.createNiceMock(ListOffsetsResult.class); + final KafkaFuture<Map<TopicPartition, ListOffsetsResultInfo>> allFuture = EasyMock.createMock(KafkaFuture.class); + + EasyMock.expect(adminClient.listOffsets(EasyMock.anyObject())).andStubReturn(result); + EasyMock.expect(result.all()).andStubReturn(allFuture); + EasyMock.expect(allFuture.get()).andThrow(new RuntimeException()); + replay(adminClient, result, allFuture); + + assertThrows(StreamsException.class, () -> fetchEndOffsets(PARTITIONS, adminClient)); + verify(adminClient); + } + + @Test + public void fetchEndOffsetsShouldRethrowInterruptedExceptionAsStreamsException() throws Exception { + final Admin adminClient = EasyMock.createMock(AdminClient.class); + final ListOffsetsResult result = EasyMock.createNiceMock(ListOffsetsResult.class); + final KafkaFuture<Map<TopicPartition, ListOffsetsResultInfo>> 
allFuture = EasyMock.createMock(KafkaFuture.class); + + EasyMock.expect(adminClient.listOffsets(EasyMock.anyObject())).andStubReturn(result); + EasyMock.expect(result.all()).andStubReturn(allFuture); + EasyMock.expect(allFuture.get()).andThrow(new InterruptedException()); + replay(adminClient, result, allFuture); + + assertThrows(StreamsException.class, () -> fetchEndOffsets(PARTITIONS, adminClient)); + verify(adminClient); + } + + @Test + public void fetchEndOffsetsShouldRethrowExecutionExceptionAsStreamsException() throws Exception { + final Admin adminClient = EasyMock.createMock(AdminClient.class); + final ListOffsetsResult result = EasyMock.createNiceMock(ListOffsetsResult.class); + final KafkaFuture<Map<TopicPartition, ListOffsetsResultInfo>> allFuture = EasyMock.createMock(KafkaFuture.class); + + EasyMock.expect(adminClient.listOffsets(EasyMock.anyObject())).andStubReturn(result); + EasyMock.expect(result.all()).andStubReturn(allFuture); + EasyMock.expect(allFuture.get()).andThrow(new ExecutionException(new RuntimeException())); + replay(adminClient, result, allFuture); + + assertThrows(StreamsException.class, () -> fetchEndOffsets(PARTITIONS, adminClient)); + verify(adminClient); + } + +} diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/CopartitionedTopicsEnforcerTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/CopartitionedTopicsEnforcerTest.java new file mode 100644 index 0000000..09e1269 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/CopartitionedTopicsEnforcerTest.java @@ -0,0 +1,209 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.processor.internals; + +import org.apache.kafka.common.Cluster; +import org.apache.kafka.common.PartitionInfo; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.streams.errors.TopologyException; +import org.apache.kafka.streams.processor.internals.assignment.CopartitionedTopicsEnforcer; +import org.junit.Before; +import org.junit.Test; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.Optional; +import java.util.TreeMap; + +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; + +public class CopartitionedTopicsEnforcerTest { + + private final CopartitionedTopicsEnforcer validator = new CopartitionedTopicsEnforcer("thread "); + private final Map partitions = new HashMap<>(); + private final Cluster cluster = Cluster.empty(); + + @Before + public void before() { + partitions.put( + new TopicPartition("first", 0), + new PartitionInfo("first", 0, null, null, null)); + partitions.put( + new TopicPartition("first", 1), + new PartitionInfo("first", 1, null, null, null)); + partitions.put( + new TopicPartition("second", 0), + new PartitionInfo("second", 0, null, null, null)); + partitions.put( + new TopicPartition("second", 1), + new PartitionInfo("second", 1, null, null, null)); + } + + @Test + public void shouldThrowTopologyBuilderExceptionIfNoPartitionsFoundForCoPartitionedTopic() { + assertThrows(IllegalStateException.class, () -> validator.enforce(Collections.singleton("topic"), + Collections.emptyMap(), cluster)); + } + + @Test + public void shouldThrowTopologyBuilderExceptionIfPartitionCountsForCoPartitionedTopicsDontMatch() { + partitions.remove(new TopicPartition("second", 0)); + assertThrows(TopologyException.class, () -> validator.enforce(Utils.mkSet("first", "second"), + Collections.emptyMap(), + cluster.withPartitions(partitions))); + } + + + @Test + public void shouldEnforceCopartitioningOnRepartitionTopics() { + final InternalTopicConfig config = createTopicConfig("repartitioned", 10); + + validator.enforce(Utils.mkSet("first", "second", config.name()), + Collections.singletonMap(config.name(), config), + cluster.withPartitions(partitions)); + + assertThat(config.numberOfPartitions(), equalTo(Optional.of(2))); + } + + + @Test + public void shouldSetNumPartitionsToMaximumPartitionsWhenAllTopicsAreRepartitionTopics() { + final InternalTopicConfig one = createTopicConfig("one", 1); + final InternalTopicConfig two = createTopicConfig("two", 15); + final InternalTopicConfig three = createTopicConfig("three", 5); + final Map repartitionTopicConfig = new HashMap<>(); + + repartitionTopicConfig.put(one.name(), one); + repartitionTopicConfig.put(two.name(), two); + repartitionTopicConfig.put(three.name(), three); + + validator.enforce(Utils.mkSet(one.name(), + two.name(), + three.name()), + repartitionTopicConfig, + cluster + ); + + assertThat(one.numberOfPartitions(), equalTo(Optional.of(15))); + assertThat(two.numberOfPartitions(), equalTo(Optional.of(15))); + assertThat(three.numberOfPartitions(), equalTo(Optional.of(15))); + } + + @Test + public void shouldThrowAnExceptionIfRepartitionTopicConfigsWithEnforcedNumOfPartitionsHaveDifferentNumOfPartitiones() { + final InternalTopicConfig topic1 = createRepartitionTopicConfigWithEnforcedNumberOfPartitions("repartitioned-1", 10); + final 
InternalTopicConfig topic2 = createRepartitionTopicConfigWithEnforcedNumberOfPartitions("repartitioned-2", 5); + + final TopologyException ex = assertThrows( + TopologyException.class, + () -> validator.enforce(Utils.mkSet(topic1.name(), topic2.name()), + Utils.mkMap(Utils.mkEntry(topic1.name(), topic1), + Utils.mkEntry(topic2.name(), topic2)), + cluster.withPartitions(partitions)) + ); + + final TreeMap sorted = new TreeMap<>( + Utils.mkMap(Utils.mkEntry(topic1.name(), topic1.numberOfPartitions().get()), + Utils.mkEntry(topic2.name(), topic2.numberOfPartitions().get())) + ); + + assertEquals(String.format("Invalid topology: thread " + + "Following topics do not have the same number of partitions: " + + "[%s]", sorted), ex.getMessage()); + } + + @Test + public void shouldNotThrowAnExceptionWhenRepartitionTopicConfigsWithEnforcedNumOfPartitionsAreValid() { + final InternalTopicConfig topic1 = createRepartitionTopicConfigWithEnforcedNumberOfPartitions("repartitioned-1", 10); + final InternalTopicConfig topic2 = createRepartitionTopicConfigWithEnforcedNumberOfPartitions("repartitioned-2", 10); + + validator.enforce(Utils.mkSet(topic1.name(), topic2.name()), + Utils.mkMap(Utils.mkEntry(topic1.name(), topic1), + Utils.mkEntry(topic2.name(), topic2)), + cluster.withPartitions(partitions)); + + assertThat(topic1.numberOfPartitions(), equalTo(Optional.of(10))); + assertThat(topic2.numberOfPartitions(), equalTo(Optional.of(10))); + } + + @Test + public void shouldThrowAnExceptionWhenNumberOfPartitionsOfNonRepartitionTopicAndRepartitionTopicWithEnforcedNumOfPartitionsDoNotMatch() { + final InternalTopicConfig topic1 = createRepartitionTopicConfigWithEnforcedNumberOfPartitions("repartitioned-1", 10); + + final TopologyException ex = assertThrows( + TopologyException.class, + () -> validator.enforce(Utils.mkSet(topic1.name(), "second"), + Utils.mkMap(Utils.mkEntry(topic1.name(), topic1)), + cluster.withPartitions(partitions)) + ); + + assertEquals(String.format("Invalid topology: thread Number of partitions [%s] " + + "of repartition topic [%s] " + + "doesn't match number of partitions [%s] of the source topic.", + topic1.numberOfPartitions().get(), topic1.name(), 2), ex.getMessage()); + } + + @Test + public void shouldNotThrowAnExceptionWhenNumberOfPartitionsOfNonRepartitionTopicAndRepartitionTopicWithEnforcedNumOfPartitionsMatch() { + final InternalTopicConfig topic1 = createRepartitionTopicConfigWithEnforcedNumberOfPartitions("repartitioned-1", 2); + + validator.enforce(Utils.mkSet(topic1.name(), "second"), + Utils.mkMap(Utils.mkEntry(topic1.name(), topic1)), + cluster.withPartitions(partitions)); + + assertThat(topic1.numberOfPartitions(), equalTo(Optional.of(2))); + } + + @Test + public void shouldDeductNumberOfPartitionsFromRepartitionTopicWithEnforcedNumberOfPartitions() { + final InternalTopicConfig topic1 = createRepartitionTopicConfigWithEnforcedNumberOfPartitions("repartitioned-1", 2); + final InternalTopicConfig topic2 = createTopicConfig("repartitioned-2", 5); + final InternalTopicConfig topic3 = createRepartitionTopicConfigWithEnforcedNumberOfPartitions("repartitioned-3", 2); + + validator.enforce(Utils.mkSet(topic1.name(), topic2.name()), + Utils.mkMap(Utils.mkEntry(topic1.name(), topic1), + Utils.mkEntry(topic2.name(), topic2), + Utils.mkEntry(topic3.name(), topic3)), + cluster.withPartitions(partitions)); + + assertEquals(topic1.numberOfPartitions(), topic2.numberOfPartitions()); + assertEquals(topic2.numberOfPartitions(), topic3.numberOfPartitions()); + } + + private 
InternalTopicConfig createTopicConfig(final String repartitionTopic, + final int partitions) { + final InternalTopicConfig repartitionTopicConfig = + new RepartitionTopicConfig(repartitionTopic, Collections.emptyMap()); + + repartitionTopicConfig.setNumberOfPartitions(partitions); + return repartitionTopicConfig; + } + + private InternalTopicConfig createRepartitionTopicConfigWithEnforcedNumberOfPartitions(final String repartitionTopic, + final int partitions) { + return new RepartitionTopicConfig(repartitionTopic, + Collections.emptyMap(), + partitions, + true); + } + +} \ No newline at end of file diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/ForwardingDisabledProcessorContextTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/ForwardingDisabledProcessorContextTest.java new file mode 100644 index 0000000..9e78c6d --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/ForwardingDisabledProcessorContextTest.java @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.processor.internals; + +import org.apache.kafka.streams.errors.StreamsException; +import org.apache.kafka.streams.processor.ProcessorContext; +import org.apache.kafka.streams.processor.To; +import org.junit.Before; +import org.junit.Test; + +import static org.junit.Assert.assertThrows; +import static org.mockito.Mockito.mock; + +public class ForwardingDisabledProcessorContextTest { + + private ForwardingDisabledProcessorContext context; + + @Before + public void setUp() { + context = new ForwardingDisabledProcessorContext(mock(ProcessorContext.class)); + } + + @Test + public void shouldThrowOnForward() { + assertThrows(StreamsException.class, () -> context.forward("key", "value")); + } + + @Test + public void shouldThrowOnForwardWithTo() { + assertThrows(StreamsException.class, () -> context.forward("key", "value", To.all())); + } +} \ No newline at end of file diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/GlobalProcessorContextImplTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/GlobalProcessorContextImplTest.java new file mode 100644 index 0000000..1dffd7a --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/GlobalProcessorContextImplTest.java @@ -0,0 +1,224 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.processor.internals; + +import org.apache.kafka.common.header.internals.RecordHeaders; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.processor.StateStore; +import org.apache.kafka.streams.processor.StateStoreContext; +import org.apache.kafka.streams.processor.To; +import org.apache.kafka.streams.processor.internals.Task.TaskType; +import org.apache.kafka.streams.state.KeyValueStore; +import org.apache.kafka.streams.state.SessionStore; +import org.apache.kafka.streams.state.TimestampedKeyValueStore; +import org.apache.kafka.streams.state.TimestampedWindowStore; +import org.apache.kafka.streams.state.WindowStore; +import org.hamcrest.core.IsInstanceOf; +import org.junit.Before; +import org.junit.Test; + +import static org.easymock.EasyMock.anyObject; +import static org.easymock.EasyMock.expect; +import static org.easymock.EasyMock.expectLastCall; +import static org.easymock.EasyMock.mock; +import static org.easymock.EasyMock.replay; +import static org.easymock.EasyMock.verify; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.fail; + +public class GlobalProcessorContextImplTest { + private static final String GLOBAL_STORE_NAME = "global-store"; + private static final String GLOBAL_KEY_VALUE_STORE_NAME = "global-key-value-store"; + private static final String GLOBAL_TIMESTAMPED_KEY_VALUE_STORE_NAME = "global-timestamped-key-value-store"; + private static final String GLOBAL_WINDOW_STORE_NAME = "global-window-store"; + private static final String GLOBAL_TIMESTAMPED_WINDOW_STORE_NAME = "global-timestamped-window-store"; + private static final String GLOBAL_SESSION_STORE_NAME = "global-session-store"; + private static final String UNKNOWN_STORE = "unknown-store"; + + private GlobalProcessorContextImpl globalContext; + + private ProcessorNode child; + private ProcessorRecordContext recordContext; + + @Before + public void setup() { + final StreamsConfig streamsConfig = mock(StreamsConfig.class); + expect(streamsConfig.getString(StreamsConfig.APPLICATION_ID_CONFIG)).andReturn("dummy-id"); + expect(streamsConfig.defaultValueSerde()).andReturn(Serdes.ByteArray()); + expect(streamsConfig.defaultKeySerde()).andReturn(Serdes.ByteArray()); + replay(streamsConfig); + + final GlobalStateManager stateManager = mock(GlobalStateManager.class); + expect(stateManager.getGlobalStore(GLOBAL_STORE_NAME)).andReturn(mock(StateStore.class)); + expect(stateManager.getGlobalStore(GLOBAL_KEY_VALUE_STORE_NAME)).andReturn(mock(KeyValueStore.class)); + expect(stateManager.getGlobalStore(GLOBAL_TIMESTAMPED_KEY_VALUE_STORE_NAME)).andReturn(mock(TimestampedKeyValueStore.class)); + expect(stateManager.getGlobalStore(GLOBAL_WINDOW_STORE_NAME)).andReturn(mock(WindowStore.class)); + expect(stateManager.getGlobalStore(GLOBAL_TIMESTAMPED_WINDOW_STORE_NAME)).andReturn(mock(TimestampedWindowStore.class)); + 
expect(stateManager.getGlobalStore(GLOBAL_SESSION_STORE_NAME)).andReturn(mock(SessionStore.class)); + expect(stateManager.getGlobalStore(UNKNOWN_STORE)).andReturn(null); + expect(stateManager.taskType()).andStubReturn(TaskType.GLOBAL); + replay(stateManager); + + globalContext = new GlobalProcessorContextImpl( + streamsConfig, + stateManager, + null, + null, + Time.SYSTEM); + + final ProcessorNode processorNode = new ProcessorNode<>("testNode"); + + child = mock(ProcessorNode.class); + processorNode.addChild(child); + + globalContext.setCurrentNode(processorNode); + recordContext = mock(ProcessorRecordContext.class); + globalContext.setRecordContext(recordContext); + } + + @Test + public void shouldReturnGlobalOrNullStore() { + assertThat(globalContext.getStateStore(GLOBAL_STORE_NAME), new IsInstanceOf(StateStore.class)); + assertNull(globalContext.getStateStore(UNKNOWN_STORE)); + } + + @Test + public void shouldForwardToSingleChild() { + child.process(anyObject()); + expectLastCall(); + + expect(recordContext.timestamp()).andStubReturn(0L); + expect(recordContext.headers()).andStubReturn(new RecordHeaders()); + replay(child, recordContext); + globalContext.forward((Object /*forcing a call to the K/V forward*/) null, null); + verify(child, recordContext); + } + + @Test + public void shouldFailToForwardUsingToParameter() { + assertThrows(IllegalStateException.class, () -> globalContext.forward(null, null, To.all())); + } + + @Test + public void shouldNotFailOnNoOpCommit() { + globalContext.commit(); + } + + @Test + public void shouldNotAllowToSchedulePunctuations() { + assertThrows(UnsupportedOperationException.class, () -> globalContext.schedule(null, null, null)); + } + + @Test + public void shouldNotAllowInitForKeyValueStore() { + final StateStore store = globalContext.getStateStore(GLOBAL_KEY_VALUE_STORE_NAME); + try { + store.init((StateStoreContext) null, null); + fail("Should have thrown UnsupportedOperationException."); + } catch (final UnsupportedOperationException expected) { } + } + + @Test + public void shouldNotAllowInitForTimestampedKeyValueStore() { + final StateStore store = globalContext.getStateStore(GLOBAL_TIMESTAMPED_KEY_VALUE_STORE_NAME); + try { + store.init((StateStoreContext) null, null); + fail("Should have thrown UnsupportedOperationException."); + } catch (final UnsupportedOperationException expected) { } + } + + @Test + public void shouldNotAllowInitForWindowStore() { + final StateStore store = globalContext.getStateStore(GLOBAL_WINDOW_STORE_NAME); + try { + store.init((StateStoreContext) null, null); + fail("Should have thrown UnsupportedOperationException."); + } catch (final UnsupportedOperationException expected) { } + } + + @Test + public void shouldNotAllowInitForTimestampedWindowStore() { + final StateStore store = globalContext.getStateStore(GLOBAL_TIMESTAMPED_WINDOW_STORE_NAME); + try { + store.init((StateStoreContext) null, null); + fail("Should have thrown UnsupportedOperationException."); + } catch (final UnsupportedOperationException expected) { } + } + + @Test + public void shouldNotAllowInitForSessionStore() { + final StateStore store = globalContext.getStateStore(GLOBAL_SESSION_STORE_NAME); + try { + store.init((StateStoreContext) null, null); + fail("Should have thrown UnsupportedOperationException."); + } catch (final UnsupportedOperationException expected) { } + } + + @Test + public void shouldNotAllowCloseForKeyValueStore() { + final StateStore store = globalContext.getStateStore(GLOBAL_KEY_VALUE_STORE_NAME); + try { + store.close(); + 
fail("Should have thrown UnsupportedOperationException."); + } catch (final UnsupportedOperationException expected) { } + } + + @Test + public void shouldNotAllowCloseForTimestampedKeyValueStore() { + final StateStore store = globalContext.getStateStore(GLOBAL_TIMESTAMPED_KEY_VALUE_STORE_NAME); + try { + store.close(); + fail("Should have thrown UnsupportedOperationException."); + } catch (final UnsupportedOperationException expected) { } + } + + @Test + public void shouldNotAllowCloseForWindowStore() { + final StateStore store = globalContext.getStateStore(GLOBAL_WINDOW_STORE_NAME); + try { + store.close(); + fail("Should have thrown UnsupportedOperationException."); + } catch (final UnsupportedOperationException expected) { } + } + + @Test + public void shouldNotAllowCloseForTimestampedWindowStore() { + final StateStore store = globalContext.getStateStore(GLOBAL_TIMESTAMPED_WINDOW_STORE_NAME); + try { + store.close(); + fail("Should have thrown UnsupportedOperationException."); + } catch (final UnsupportedOperationException expected) { } + } + + @Test + public void shouldNotAllowCloseForSessionStore() { + final StateStore store = globalContext.getStateStore(GLOBAL_SESSION_STORE_NAME); + try { + store.close(); + fail("Should have thrown UnsupportedOperationException."); + } catch (final UnsupportedOperationException expected) { } + } + + @Test(expected = UnsupportedOperationException.class) + public void shouldThrowOnCurrentStreamTime() { + globalContext.currentStreamTimeMs(); + } +} \ No newline at end of file diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/GlobalStateManagerImplTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/GlobalStateManagerImplTest.java new file mode 100644 index 0000000..9b50ead --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/GlobalStateManagerImplTest.java @@ -0,0 +1,1142 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.processor.internals; + +import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.clients.consumer.ConsumerRecords; +import org.apache.kafka.clients.consumer.MockConsumer; +import org.apache.kafka.clients.consumer.OffsetResetStrategy; +import org.apache.kafka.common.PartitionInfo; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.errors.TimeoutException; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.errors.ProcessorStateException; +import org.apache.kafka.streams.errors.StreamsException; +import org.apache.kafka.streams.processor.StateRestoreCallback; +import org.apache.kafka.streams.processor.StateStore; +import org.apache.kafka.streams.state.TimestampedBytesStore; +import org.apache.kafka.streams.state.internals.OffsetCheckpoint; +import org.apache.kafka.streams.state.internals.WrappedStateStore; +import org.apache.kafka.streams.processor.internals.testutil.LogCaptureAppender; +import org.apache.kafka.test.InternalMockProcessorContext; +import org.apache.kafka.test.MockStateRestoreListener; +import org.apache.kafka.test.NoOpReadOnlyStore; +import org.apache.kafka.test.TestUtils; +import org.junit.Before; +import org.junit.Test; + +import java.io.File; +import java.io.IOException; +import java.io.OutputStream; +import java.nio.file.Files; +import java.time.Duration; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Properties; +import java.util.Set; +import java.util.concurrent.atomic.AtomicInteger; + +import static java.util.Arrays.asList; +import static org.apache.kafka.common.utils.Utils.mkEntry; +import static org.apache.kafka.common.utils.Utils.mkMap; +import static org.apache.kafka.test.MockStateRestoreListener.RESTORE_BATCH; +import static org.apache.kafka.test.MockStateRestoreListener.RESTORE_END; +import static org.apache.kafka.test.MockStateRestoreListener.RESTORE_START; +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.CoreMatchers.containsString; +import static org.hamcrest.CoreMatchers.hasItem; +import static org.hamcrest.CoreMatchers.instanceOf; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +public class GlobalStateManagerImplTest { + + + private final MockTime time = new MockTime(); + private final TheStateRestoreCallback stateRestoreCallback = new TheStateRestoreCallback(); + private final MockStateRestoreListener stateRestoreListener = new MockStateRestoreListener(); + private final String storeName1 = "t1-store"; + private final String storeName2 = "t2-store"; + private final String storeName3 = "t3-store"; + private final String storeName4 = "t4-store"; + private final TopicPartition t1 = new TopicPartition("t1", 1); + private final TopicPartition t2 = new TopicPartition("t2", 1); + private final TopicPartition t3 = new TopicPartition("t3", 1); + private final TopicPartition t4 = new TopicPartition("t4", 1); + private GlobalStateManagerImpl stateManager; 
+ private StateDirectory stateDirectory; + private StreamsConfig streamsConfig; + private NoOpReadOnlyStore<Object, Object> store1, store2, store3, store4; + private MockConsumer<byte[], byte[]> consumer; + private File checkpointFile; + private ProcessorTopology topology; + private InternalMockProcessorContext processorContext; + + static ProcessorTopology withGlobalStores(final List<StateStore> stateStores, + final Map<String, String> storeToChangelogTopic) { + return new ProcessorTopology(Collections.emptyList(), + Collections.emptyMap(), + Collections.emptyMap(), + Collections.emptyList(), + stateStores, + storeToChangelogTopic, + Collections.emptySet()); + } + + @Before + public void before() { + final Map<String, String> storeToTopic = new HashMap<>(); + + storeToTopic.put(storeName1, t1.topic()); + storeToTopic.put(storeName2, t2.topic()); + storeToTopic.put(storeName3, t3.topic()); + storeToTopic.put(storeName4, t4.topic()); + + store1 = new NoOpReadOnlyStore<>(storeName1, true); + store2 = new ConverterStore<>(storeName2, true); + store3 = new NoOpReadOnlyStore<>(storeName3); + store4 = new NoOpReadOnlyStore<>(storeName4); + + topology = withGlobalStores(asList(store1, store2, store3, store4), storeToTopic); + + streamsConfig = new StreamsConfig(new Properties() { + { + put(StreamsConfig.APPLICATION_ID_CONFIG, "appId"); + put(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, "dummy:1234"); + put(StreamsConfig.STATE_DIR_CONFIG, TestUtils.tempDirectory().getPath()); + } + }); + stateDirectory = new StateDirectory(streamsConfig, time, true, false); + consumer = new MockConsumer<>(OffsetResetStrategy.NONE); + stateManager = new GlobalStateManagerImpl( + new LogContext("test"), + time, + topology, + consumer, + stateDirectory, + stateRestoreListener, + streamsConfig + ); + processorContext = new InternalMockProcessorContext(stateDirectory.globalStateDir(), streamsConfig); + stateManager.setGlobalProcessorContext(processorContext); + checkpointFile = new File(stateManager.baseDir(), StateManagerUtil.CHECKPOINT_FILE_NAME); + } + + @Test + public void shouldReadCheckpointOffsets() throws IOException { + final Map<TopicPartition, Long> expected = writeCheckpoint(); + + stateManager.initialize(); + final Map<TopicPartition, Long> offsets = stateManager.changelogOffsets(); + assertEquals(expected, offsets); + } + + @Test + public void shouldLogWarningMessageWhenIOExceptionInCheckPoint() throws IOException { + final Map<TopicPartition, Long> offsets = Collections.singletonMap(t1, 25L); + stateManager.initialize(); + stateManager.updateChangelogOffsets(offsets); + + // set readonly to the CHECKPOINT_FILE_NAME.tmp file because we will write data to the .tmp file first + // and then swap to CHECKPOINT_FILE_NAME by replacing it + final File file = new File(stateDirectory.globalStateDir(), StateManagerUtil.CHECKPOINT_FILE_NAME + ".tmp"); + file.createNewFile(); + file.setWritable(false); + + try (final LogCaptureAppender appender = LogCaptureAppender.createAndRegister(GlobalStateManagerImpl.class)) { + stateManager.checkpoint(); + assertThat(appender.getMessages(), hasItem(containsString( + "Failed to write offset checkpoint file to " + checkpointFile.getPath() + " for global stores"))); + } + } + + @Test + public void shouldThrowStreamsExceptionForOldTopicPartitions() throws IOException { + final HashMap<TopicPartition, Long> expectedOffsets = new HashMap<>(); + expectedOffsets.put(t1, 1L); + expectedOffsets.put(t2, 1L); + expectedOffsets.put(t3, 1L); + expectedOffsets.put(t4, 1L); + + // add an old topic (a topic not associated with any global state store) + final HashMap<TopicPartition, Long> startOffsets = new HashMap<>(expectedOffsets); + final TopicPartition tOld = new 
TopicPartition("oldTopic", 1); + startOffsets.put(tOld, 1L); + + // start with a checkpoint file will all topic-partitions: expected and old (not + // associated with any global state store). + final OffsetCheckpoint checkpoint = new OffsetCheckpoint(checkpointFile); + checkpoint.write(startOffsets); + + // initialize will throw exception + final StreamsException e = assertThrows(StreamsException.class, () -> stateManager.initialize()); + assertThat(e.getMessage(), equalTo("Encountered a topic-partition not associated with any global state store")); + } + + @Test + public void shouldNotDeleteCheckpointFileAfterLoaded() throws IOException { + writeCheckpoint(); + stateManager.initialize(); + assertTrue(checkpointFile.exists()); + } + + @Test + public void shouldThrowStreamsExceptionIfFailedToReadCheckpointedOffsets() throws IOException { + writeCorruptCheckpoint(); + assertThrows(StreamsException.class, stateManager::initialize); + } + + @Test + public void shouldInitializeStateStores() { + stateManager.initialize(); + assertTrue(store1.initialized); + assertTrue(store2.initialized); + } + + @Test + public void shouldReturnInitializedStoreNames() { + final Set storeNames = stateManager.initialize(); + assertEquals(Utils.mkSet(storeName1, storeName2, storeName3, storeName4), storeNames); + } + + @Test + public void shouldThrowIllegalArgumentIfTryingToRegisterStoreThatIsNotGlobal() { + stateManager.initialize(); + + try { + stateManager.registerStore(new NoOpReadOnlyStore<>("not-in-topology"), stateRestoreCallback); + fail("should have raised an illegal argument exception as store is not in the topology"); + } catch (final IllegalArgumentException e) { + // pass + } + } + + @Test + public void shouldThrowIllegalArgumentExceptionIfAttemptingToRegisterStoreTwice() { + stateManager.initialize(); + initializeConsumer(2, 0, t1); + stateManager.registerStore(store1, stateRestoreCallback); + try { + stateManager.registerStore(store1, stateRestoreCallback); + fail("should have raised an illegal argument exception as store has already been registered"); + } catch (final IllegalArgumentException e) { + // pass + } + } + + @Test + public void shouldThrowStreamsExceptionIfNoPartitionsFoundForStore() { + stateManager.initialize(); + try { + stateManager.registerStore(store1, stateRestoreCallback); + fail("Should have raised a StreamsException as there are no partition for the store"); + } catch (final StreamsException e) { + // pass + } + } + + @Test + public void shouldNotConvertValuesIfStoreDoesNotImplementTimestampedBytesStore() { + initializeConsumer(1, 0, t1); + + stateManager.initialize(); + stateManager.registerStore(store1, stateRestoreCallback); + + final KeyValue restoredRecord = stateRestoreCallback.restored.get(0); + assertEquals(3, restoredRecord.key.length); + assertEquals(5, restoredRecord.value.length); + } + + @Test + public void shouldNotConvertValuesIfInnerStoreDoesNotImplementTimestampedBytesStore() { + initializeConsumer(1, 0, t1); + + stateManager.initialize(); + stateManager.registerStore( + new WrappedStateStore, Object, Object>(store1) { + }, + stateRestoreCallback + ); + + final KeyValue restoredRecord = stateRestoreCallback.restored.get(0); + assertEquals(3, restoredRecord.key.length); + assertEquals(5, restoredRecord.value.length); + } + + @Test + public void shouldConvertValuesIfStoreImplementsTimestampedBytesStore() { + initializeConsumer(1, 0, t2); + + stateManager.initialize(); + stateManager.registerStore(store2, stateRestoreCallback); + + final KeyValue restoredRecord = 
stateRestoreCallback.restored.get(0); + assertEquals(3, restoredRecord.key.length); + assertEquals(13, restoredRecord.value.length); + } + + @Test + public void shouldConvertValuesIfInnerStoreImplementsTimestampedBytesStore() { + initializeConsumer(1, 0, t2); + + stateManager.initialize(); + stateManager.registerStore( + new WrappedStateStore, Object, Object>(store2) { + }, + stateRestoreCallback + ); + + final KeyValue restoredRecord = stateRestoreCallback.restored.get(0); + assertEquals(3, restoredRecord.key.length); + assertEquals(13, restoredRecord.value.length); + } + + @Test + public void shouldRestoreRecordsUpToHighwatermark() { + initializeConsumer(2, 0, t1); + + stateManager.initialize(); + + stateManager.registerStore(store1, stateRestoreCallback); + assertEquals(2, stateRestoreCallback.restored.size()); + } + + @Test + public void shouldListenForRestoreEvents() { + initializeConsumer(5, 1, t1); + stateManager.initialize(); + + stateManager.registerStore(store1, stateRestoreCallback); + + assertThat(stateRestoreListener.restoreStartOffset, equalTo(1L)); + assertThat(stateRestoreListener.restoreEndOffset, equalTo(6L)); + assertThat(stateRestoreListener.totalNumRestored, equalTo(5L)); + + + assertThat(stateRestoreListener.storeNameCalledStates.get(RESTORE_START), equalTo(store1.name())); + assertThat(stateRestoreListener.storeNameCalledStates.get(RESTORE_BATCH), equalTo(store1.name())); + assertThat(stateRestoreListener.storeNameCalledStates.get(RESTORE_END), equalTo(store1.name())); + } + + @Test + public void shouldRestoreRecordsFromCheckpointToHighWatermark() throws IOException { + initializeConsumer(5, 5, t1); + + final OffsetCheckpoint offsetCheckpoint = new OffsetCheckpoint(new File(stateManager.baseDir(), + StateManagerUtil.CHECKPOINT_FILE_NAME)); + offsetCheckpoint.write(Collections.singletonMap(t1, 5L)); + + stateManager.initialize(); + stateManager.registerStore(store1, stateRestoreCallback); + assertEquals(5, stateRestoreCallback.restored.size()); + } + + + @Test + public void shouldFlushStateStores() { + stateManager.initialize(); + // register the stores + initializeConsumer(1, 0, t1); + stateManager.registerStore(store1, stateRestoreCallback); + initializeConsumer(1, 0, t2); + stateManager.registerStore(store2, stateRestoreCallback); + + stateManager.flush(); + assertTrue(store1.flushed); + assertTrue(store2.flushed); + } + + @Test + public void shouldThrowProcessorStateStoreExceptionIfStoreFlushFailed() { + stateManager.initialize(); + // register the stores + initializeConsumer(1, 0, t1); + stateManager.registerStore(new NoOpReadOnlyStore(store1.name()) { + @Override + public void flush() { + throw new RuntimeException("KABOOM!"); + } + }, stateRestoreCallback); + assertThrows(StreamsException.class, stateManager::flush); + } + + @Test + public void shouldCloseStateStores() throws IOException { + stateManager.initialize(); + // register the stores + initializeConsumer(1, 0, t1); + stateManager.registerStore(store1, stateRestoreCallback); + initializeConsumer(1, 0, t2); + stateManager.registerStore(store2, stateRestoreCallback); + + stateManager.close(); + assertFalse(store1.isOpen()); + assertFalse(store2.isOpen()); + } + + @Test + public void shouldThrowProcessorStateStoreExceptionIfStoreCloseFailed() { + stateManager.initialize(); + initializeConsumer(1, 0, t1); + stateManager.registerStore(new NoOpReadOnlyStore(store1.name()) { + @Override + public void close() { + throw new RuntimeException("KABOOM!"); + } + }, stateRestoreCallback); + + 
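+        // The store registered above throws from close(). Closing the state manager is therefore
+        // expected to surface that failure as a ProcessorStateException, roughly along these lines
+        // (an illustrative sketch only, not the actual GlobalStateManagerImpl code):
+        //   try { store.close(); } catch (final RuntimeException fatal) {
+        //       throw new ProcessorStateException("Failed to close global store " + store.name(), fatal);
+        //   }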
assertThrows(ProcessorStateException.class, stateManager::close); + } + + @Test + public void shouldThrowIllegalArgumentExceptionIfCallbackIsNull() { + stateManager.initialize(); + try { + stateManager.registerStore(store1, null); + fail("should have thrown due to null callback"); + } catch (final IllegalArgumentException e) { + //pass + } + } + + @Test + public void shouldNotCloseStoresIfCloseAlreadyCalled() { + stateManager.initialize(); + initializeConsumer(1, 0, t1); + stateManager.registerStore(new NoOpReadOnlyStore("t1-store") { + @Override + public void close() { + if (!isOpen()) { + throw new RuntimeException("store already closed"); + } + super.close(); + } + }, stateRestoreCallback); + stateManager.close(); + + stateManager.close(); + } + + @Test + public void shouldAttemptToCloseAllStoresEvenWhenSomeException() { + stateManager.initialize(); + initializeConsumer(1, 0, t1); + final NoOpReadOnlyStore store = new NoOpReadOnlyStore("t1-store") { + @Override + public void close() { + super.close(); + throw new RuntimeException("KABOOM!"); + } + }; + stateManager.registerStore(store, stateRestoreCallback); + + initializeConsumer(1, 0, t2); + stateManager.registerStore(store2, stateRestoreCallback); + + try { + stateManager.close(); + } catch (final ProcessorStateException e) { + // expected + } + assertFalse(store.isOpen()); + assertFalse(store2.isOpen()); + } + + @Test + public void shouldCheckpointOffsets() throws IOException { + final Map offsets = Collections.singletonMap(t1, 25L); + stateManager.initialize(); + + stateManager.updateChangelogOffsets(offsets); + stateManager.checkpoint(); + + final Map result = readOffsetsCheckpoint(); + assertThat(result, equalTo(offsets)); + assertThat(stateManager.changelogOffsets(), equalTo(offsets)); + } + + @Test + public void shouldNotRemoveOffsetsOfUnUpdatedTablesDuringCheckpoint() { + stateManager.initialize(); + initializeConsumer(10, 0, t1); + stateManager.registerStore(store1, stateRestoreCallback); + initializeConsumer(20, 0, t2); + stateManager.registerStore(store2, stateRestoreCallback); + + final Map initialCheckpoint = stateManager.changelogOffsets(); + stateManager.updateChangelogOffsets(Collections.singletonMap(t1, 101L)); + stateManager.checkpoint(); + + final Map updatedCheckpoint = stateManager.changelogOffsets(); + assertThat(updatedCheckpoint.get(t2), equalTo(initialCheckpoint.get(t2))); + assertThat(updatedCheckpoint.get(t1), equalTo(101L)); + } + + @Test + public void shouldSkipNullKeysWhenRestoring() { + final HashMap startOffsets = new HashMap<>(); + startOffsets.put(t1, 1L); + final HashMap endOffsets = new HashMap<>(); + endOffsets.put(t1, 3L); + consumer.updatePartitions(t1.topic(), Collections.singletonList(new PartitionInfo(t1.topic(), t1.partition(), null, null, null))); + consumer.assign(Collections.singletonList(t1)); + consumer.updateEndOffsets(endOffsets); + consumer.updateBeginningOffsets(startOffsets); + consumer.addRecord(new ConsumerRecord<>(t1.topic(), t1.partition(), 1, null, "null".getBytes())); + final byte[] expectedKey = "key".getBytes(); + final byte[] expectedValue = "value".getBytes(); + consumer.addRecord(new ConsumerRecord<>(t1.topic(), t1.partition(), 2, expectedKey, expectedValue)); + + stateManager.initialize(); + stateManager.registerStore(store1, stateRestoreCallback); + final KeyValue restoredKv = stateRestoreCallback.restored.get(0); + assertThat(stateRestoreCallback.restored, equalTo(Collections.singletonList(KeyValue.pair(restoredKv.key, restoredKv.value)))); + } + + @Test + public void 
shouldCheckpointRestoredOffsetsToFile() throws IOException { + stateManager.initialize(); + initializeConsumer(10, 0, t1); + stateManager.registerStore(store1, stateRestoreCallback); + stateManager.checkpoint(); + stateManager.close(); + + final Map checkpointMap = stateManager.changelogOffsets(); + assertThat(checkpointMap, equalTo(Collections.singletonMap(t1, 10L))); + assertThat(readOffsetsCheckpoint(), equalTo(checkpointMap)); + } + + @Test + public void shouldSkipGlobalInMemoryStoreOffsetsToFile() throws IOException { + stateManager.initialize(); + initializeConsumer(10, 0, t3); + stateManager.registerStore(store3, stateRestoreCallback); + stateManager.close(); + + assertThat(readOffsetsCheckpoint(), equalTo(Collections.emptyMap())); + } + + private Map readOffsetsCheckpoint() throws IOException { + final OffsetCheckpoint offsetCheckpoint = new OffsetCheckpoint(new File(stateManager.baseDir(), + StateManagerUtil.CHECKPOINT_FILE_NAME)); + return offsetCheckpoint.read(); + } + + @Test + public void shouldNotRetryWhenEndOffsetsThrowsTimeoutExceptionAndTaskTimeoutIsZero() { + final AtomicInteger numberOfCalls = new AtomicInteger(0); + consumer = new MockConsumer(OffsetResetStrategy.EARLIEST) { + @Override + public synchronized Map endOffsets(final Collection partitions) { + numberOfCalls.incrementAndGet(); + throw new TimeoutException("KABOOM!"); + } + }; + initializeConsumer(0, 0, t1, t2, t3, t4); + + streamsConfig = new StreamsConfig(mkMap( + mkEntry(StreamsConfig.APPLICATION_ID_CONFIG, "appId"), + mkEntry(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, "dummy:1234"), + mkEntry(StreamsConfig.STATE_DIR_CONFIG, TestUtils.tempDirectory().getPath()), + mkEntry(StreamsConfig.TASK_TIMEOUT_MS_CONFIG, 0L) + )); + + stateManager = new GlobalStateManagerImpl( + new LogContext("mock"), + time, + topology, + consumer, + stateDirectory, + stateRestoreListener, + streamsConfig + ); + processorContext.setStateManger(stateManager); + stateManager.setGlobalProcessorContext(processorContext); + + final StreamsException expected = assertThrows( + StreamsException.class, + () -> stateManager.initialize() + ); + final Throwable cause = expected.getCause(); + assertThat(cause, instanceOf(TimeoutException.class)); + assertThat(cause.getMessage(), equalTo("KABOOM!")); + + assertEquals(numberOfCalls.get(), 1); + } + + @Test + public void shouldRetryAtLeastOnceWhenEndOffsetsThrowsTimeoutException() { + final AtomicInteger numberOfCalls = new AtomicInteger(0); + consumer = new MockConsumer(OffsetResetStrategy.EARLIEST) { + @Override + public synchronized Map endOffsets(final Collection partitions) { + time.sleep(100L); + numberOfCalls.incrementAndGet(); + throw new TimeoutException("KABOOM!"); + } + }; + initializeConsumer(0, 0, t1, t2, t3, t4); + + streamsConfig = new StreamsConfig(mkMap( + mkEntry(StreamsConfig.APPLICATION_ID_CONFIG, "appId"), + mkEntry(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, "dummy:1234"), + mkEntry(StreamsConfig.STATE_DIR_CONFIG, TestUtils.tempDirectory().getPath()), + mkEntry(StreamsConfig.TASK_TIMEOUT_MS_CONFIG, 1L) + )); + + stateManager = new GlobalStateManagerImpl( + new LogContext("mock"), + time, + topology, + consumer, + stateDirectory, + stateRestoreListener, + streamsConfig + ); + processorContext.setStateManger(stateManager); + stateManager.setGlobalProcessorContext(processorContext); + + final TimeoutException expected = assertThrows( + TimeoutException.class, + () -> stateManager.initialize() + ); + assertThat(expected.getMessage(), equalTo("Global task did not make progress to restore 
state within 100 ms. Adjust `task.timeout.ms` if needed.")); + + assertEquals(numberOfCalls.get(), 2); + } + + @Test + public void shouldRetryWhenEndOffsetsThrowsTimeoutExceptionUntilTaskTimeoutExpired() { + final AtomicInteger numberOfCalls = new AtomicInteger(0); + consumer = new MockConsumer(OffsetResetStrategy.EARLIEST) { + @Override + public synchronized Map endOffsets(final Collection partitions) { + time.sleep(100L); + numberOfCalls.incrementAndGet(); + throw new TimeoutException("KABOOM!"); + } + }; + initializeConsumer(0, 0, t1, t2, t3, t4); + + streamsConfig = new StreamsConfig(mkMap( + mkEntry(StreamsConfig.APPLICATION_ID_CONFIG, "appId"), + mkEntry(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, "dummy:1234"), + mkEntry(StreamsConfig.STATE_DIR_CONFIG, TestUtils.tempDirectory().getPath()), + mkEntry(StreamsConfig.TASK_TIMEOUT_MS_CONFIG, 1000L) + )); + + stateManager = new GlobalStateManagerImpl( + new LogContext("mock"), + time, + topology, + consumer, + stateDirectory, + stateRestoreListener, + streamsConfig + ); + processorContext.setStateManger(stateManager); + stateManager.setGlobalProcessorContext(processorContext); + + final TimeoutException expected = assertThrows( + TimeoutException.class, + () -> stateManager.initialize() + ); + assertThat(expected.getMessage(), equalTo("Global task did not make progress to restore state within 1000 ms. Adjust `task.timeout.ms` if needed.")); + + assertEquals(numberOfCalls.get(), 11); + } + + @Test + public void shouldNotFailOnSlowProgressWhenEndOffsetsThrowsTimeoutException() { + final AtomicInteger numberOfCalls = new AtomicInteger(0); + consumer = new MockConsumer(OffsetResetStrategy.EARLIEST) { + @Override + public synchronized Map endOffsets(final Collection partitions) { + time.sleep(1L); + if (numberOfCalls.incrementAndGet() % 3 == 0) { + return super.endOffsets(partitions); + } + throw new TimeoutException("KABOOM!"); + } + + @Override + public synchronized long position(final TopicPartition partition) { + return numberOfCalls.incrementAndGet(); + } + }; + initializeConsumer(0, 0, t1, t2, t3, t4); + + streamsConfig = new StreamsConfig(mkMap( + mkEntry(StreamsConfig.APPLICATION_ID_CONFIG, "appId"), + mkEntry(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, "dummy:1234"), + mkEntry(StreamsConfig.STATE_DIR_CONFIG, TestUtils.tempDirectory().getPath()), + mkEntry(StreamsConfig.TASK_TIMEOUT_MS_CONFIG, 10L) + )); + + stateManager = new GlobalStateManagerImpl( + new LogContext("mock"), + time, + topology, + consumer, + stateDirectory, + stateRestoreListener, + streamsConfig + ); + processorContext.setStateManger(stateManager); + stateManager.setGlobalProcessorContext(processorContext); + + stateManager.initialize(); + } + + @Test + public void shouldNotRetryWhenPartitionsForThrowsTimeoutExceptionAndTaskTimeoutIsZero() { + final AtomicInteger numberOfCalls = new AtomicInteger(0); + consumer = new MockConsumer(OffsetResetStrategy.EARLIEST) { + @Override + public List partitionsFor(final String topic) { + numberOfCalls.incrementAndGet(); + throw new TimeoutException("KABOOM!"); + } + }; + initializeConsumer(0, 0, t1, t2, t3, t4); + + streamsConfig = new StreamsConfig(mkMap( + mkEntry(StreamsConfig.APPLICATION_ID_CONFIG, "appId"), + mkEntry(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, "dummy:1234"), + mkEntry(StreamsConfig.STATE_DIR_CONFIG, TestUtils.tempDirectory().getPath()), + mkEntry(StreamsConfig.TASK_TIMEOUT_MS_CONFIG, 0L) + )); + + stateManager = new GlobalStateManagerImpl( + new LogContext("mock"), + time, + topology, + consumer, + stateDirectory, + 
stateRestoreListener, + streamsConfig + ); + processorContext.setStateManger(stateManager); + stateManager.setGlobalProcessorContext(processorContext); + + final StreamsException expected = assertThrows( + StreamsException.class, + () -> stateManager.initialize() + ); + final Throwable cause = expected.getCause(); + assertThat(cause, instanceOf(TimeoutException.class)); + assertThat(cause.getMessage(), equalTo("KABOOM!")); + + assertEquals(numberOfCalls.get(), 1); + } + + @Test + public void shouldRetryAtLeastOnceWhenPartitionsForThrowsTimeoutException() { + final AtomicInteger numberOfCalls = new AtomicInteger(0); + consumer = new MockConsumer(OffsetResetStrategy.EARLIEST) { + @Override + public List partitionsFor(final String topic) { + time.sleep(100L); + numberOfCalls.incrementAndGet(); + throw new TimeoutException("KABOOM!"); + } + }; + initializeConsumer(0, 0, t1, t2, t3, t4); + + streamsConfig = new StreamsConfig(mkMap( + mkEntry(StreamsConfig.APPLICATION_ID_CONFIG, "appId"), + mkEntry(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, "dummy:1234"), + mkEntry(StreamsConfig.STATE_DIR_CONFIG, TestUtils.tempDirectory().getPath()), + mkEntry(StreamsConfig.TASK_TIMEOUT_MS_CONFIG, 1L) + )); + + stateManager = new GlobalStateManagerImpl( + new LogContext("mock"), + time, + topology, + consumer, + stateDirectory, + stateRestoreListener, + streamsConfig + ); + processorContext.setStateManger(stateManager); + stateManager.setGlobalProcessorContext(processorContext); + + final TimeoutException expected = assertThrows( + TimeoutException.class, + () -> stateManager.initialize() + ); + assertThat(expected.getMessage(), equalTo("Global task did not make progress to restore state within 100 ms. Adjust `task.timeout.ms` if needed.")); + + assertEquals(numberOfCalls.get(), 2); + } + + @Test + public void shouldRetryWhenPartitionsForThrowsTimeoutExceptionUntilTaskTimeoutExpires() { + final AtomicInteger numberOfCalls = new AtomicInteger(0); + consumer = new MockConsumer(OffsetResetStrategy.EARLIEST) { + @Override + public List partitionsFor(final String topic) { + time.sleep(100L); + numberOfCalls.incrementAndGet(); + throw new TimeoutException("KABOOM!"); + } + }; + initializeConsumer(0, 0, t1, t2, t3, t4); + + streamsConfig = new StreamsConfig(mkMap( + mkEntry(StreamsConfig.APPLICATION_ID_CONFIG, "appId"), + mkEntry(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, "dummy:1234"), + mkEntry(StreamsConfig.STATE_DIR_CONFIG, TestUtils.tempDirectory().getPath()), + mkEntry(StreamsConfig.TASK_TIMEOUT_MS_CONFIG, 1000L) + )); + + stateManager = new GlobalStateManagerImpl( + new LogContext("mock"), + time, + topology, + consumer, + stateDirectory, + stateRestoreListener, + streamsConfig + ); + processorContext.setStateManger(stateManager); + stateManager.setGlobalProcessorContext(processorContext); + + final TimeoutException expected = assertThrows( + TimeoutException.class, + () -> stateManager.initialize() + ); + assertThat(expected.getMessage(), equalTo("Global task did not make progress to restore state within 1000 ms. 
Adjust `task.timeout.ms` if needed.")); + + assertEquals(numberOfCalls.get(), 11); + } + + @Test + public void shouldNotFailOnSlowProgressWhenPartitionForThrowsTimeoutException() { + final AtomicInteger numberOfCalls = new AtomicInteger(0); + consumer = new MockConsumer(OffsetResetStrategy.EARLIEST) { + @Override + public List partitionsFor(final String topic) { + time.sleep(1L); + if (numberOfCalls.incrementAndGet() % 3 == 0) { + return super.partitionsFor(topic); + } + throw new TimeoutException("KABOOM!"); + } + + @Override + public synchronized long position(final TopicPartition partition) { + return numberOfCalls.incrementAndGet(); + } + }; + initializeConsumer(0, 0, t1, t2, t3, t4); + + streamsConfig = new StreamsConfig(mkMap( + mkEntry(StreamsConfig.APPLICATION_ID_CONFIG, "appId"), + mkEntry(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, "dummy:1234"), + mkEntry(StreamsConfig.STATE_DIR_CONFIG, TestUtils.tempDirectory().getPath()), + mkEntry(StreamsConfig.TASK_TIMEOUT_MS_CONFIG, 10L) + )); + + stateManager = new GlobalStateManagerImpl( + new LogContext("mock"), + time, + topology, + consumer, + stateDirectory, + stateRestoreListener, + streamsConfig + ); + processorContext.setStateManger(stateManager); + stateManager.setGlobalProcessorContext(processorContext); + + stateManager.initialize(); + } + + @Test + public void shouldNotRetryWhenPositionThrowsTimeoutExceptionAndTaskTimeoutIsZero() { + final AtomicInteger numberOfCalls = new AtomicInteger(0); + consumer = new MockConsumer(OffsetResetStrategy.EARLIEST) { + @Override + public synchronized long position(final TopicPartition partition) { + numberOfCalls.incrementAndGet(); + throw new TimeoutException("KABOOM!"); + } + }; + initializeConsumer(0, 0, t1, t2, t3, t4); + + streamsConfig = new StreamsConfig(mkMap( + mkEntry(StreamsConfig.APPLICATION_ID_CONFIG, "appId"), + mkEntry(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, "dummy:1234"), + mkEntry(StreamsConfig.STATE_DIR_CONFIG, TestUtils.tempDirectory().getPath()), + mkEntry(StreamsConfig.TASK_TIMEOUT_MS_CONFIG, 0L) + )); + + stateManager = new GlobalStateManagerImpl( + new LogContext("mock"), + time, + topology, + consumer, + stateDirectory, + stateRestoreListener, + streamsConfig + ); + processorContext.setStateManger(stateManager); + stateManager.setGlobalProcessorContext(processorContext); + + final StreamsException expected = assertThrows( + StreamsException.class, + () -> stateManager.initialize() + ); + final Throwable cause = expected.getCause(); + assertThat(cause, instanceOf(TimeoutException.class)); + assertThat(cause.getMessage(), equalTo("KABOOM!")); + + assertEquals(numberOfCalls.get(), 1); + } + + @Test + public void shouldRetryAtLeastOnceWhenPositionThrowsTimeoutException() { + final AtomicInteger numberOfCalls = new AtomicInteger(0); + consumer = new MockConsumer(OffsetResetStrategy.EARLIEST) { + @Override + public synchronized long position(final TopicPartition partition) { + time.sleep(100L); + numberOfCalls.incrementAndGet(); + throw new TimeoutException("KABOOM!"); + } + }; + initializeConsumer(0, 0, t1, t2, t3, t4); + + streamsConfig = new StreamsConfig(mkMap( + mkEntry(StreamsConfig.APPLICATION_ID_CONFIG, "appId"), + mkEntry(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, "dummy:1234"), + mkEntry(StreamsConfig.STATE_DIR_CONFIG, TestUtils.tempDirectory().getPath()), + mkEntry(StreamsConfig.TASK_TIMEOUT_MS_CONFIG, 1L) + )); + + stateManager = new GlobalStateManagerImpl( + new LogContext("mock"), + time, + topology, + consumer, + stateDirectory, + stateRestoreListener, + streamsConfig 
+ ); + processorContext.setStateManger(stateManager); + stateManager.setGlobalProcessorContext(processorContext); + + final TimeoutException expected = assertThrows( + TimeoutException.class, + () -> stateManager.initialize() + ); + assertThat(expected.getMessage(), equalTo("Global task did not make progress to restore state within 100 ms. Adjust `task.timeout.ms` if needed.")); + + assertEquals(numberOfCalls.get(), 2); + } + + @Test + public void shouldRetryWhenPositionThrowsTimeoutExceptionUntilTaskTimeoutExpired() { + final AtomicInteger numberOfCalls = new AtomicInteger(0); + consumer = new MockConsumer(OffsetResetStrategy.EARLIEST) { + @Override + public synchronized long position(final TopicPartition partition) { + time.sleep(100L); + numberOfCalls.incrementAndGet(); + throw new TimeoutException("KABOOM!"); + } + }; + initializeConsumer(0, 0, t1, t2, t3, t4); + + streamsConfig = new StreamsConfig(mkMap( + mkEntry(StreamsConfig.APPLICATION_ID_CONFIG, "appId"), + mkEntry(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, "dummy:1234"), + mkEntry(StreamsConfig.STATE_DIR_CONFIG, TestUtils.tempDirectory().getPath()), + mkEntry(StreamsConfig.TASK_TIMEOUT_MS_CONFIG, 1000L) + )); + + stateManager = new GlobalStateManagerImpl( + new LogContext("mock"), + time, + topology, + consumer, + stateDirectory, + stateRestoreListener, + streamsConfig + ); + processorContext.setStateManger(stateManager); + stateManager.setGlobalProcessorContext(processorContext); + + final TimeoutException expected = assertThrows( + TimeoutException.class, + () -> stateManager.initialize() + ); + assertThat(expected.getMessage(), equalTo("Global task did not make progress to restore state within 1000 ms. Adjust `task.timeout.ms` if needed.")); + + assertEquals(numberOfCalls.get(), 11); + } + + @Test + public void shouldNotFailOnSlowProgressWhenPositionThrowsTimeoutException() { + final AtomicInteger numberOfCalls = new AtomicInteger(0); + consumer = new MockConsumer(OffsetResetStrategy.EARLIEST) { + @Override + public synchronized long position(final TopicPartition partition) { + time.sleep(1L); + if (numberOfCalls.incrementAndGet() % 3 == 0) { + return numberOfCalls.incrementAndGet(); + } + throw new TimeoutException("KABOOM!"); + } + }; + initializeConsumer(0, 0, t1, t2, t3, t4); + + streamsConfig = new StreamsConfig(mkMap( + mkEntry(StreamsConfig.APPLICATION_ID_CONFIG, "appId"), + mkEntry(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, "dummy:1234"), + mkEntry(StreamsConfig.STATE_DIR_CONFIG, TestUtils.tempDirectory().getPath()), + mkEntry(StreamsConfig.TASK_TIMEOUT_MS_CONFIG, 10L) + )); + + stateManager = new GlobalStateManagerImpl( + new LogContext("mock"), + time, + topology, + consumer, + stateDirectory, + stateRestoreListener, + streamsConfig + ); + processorContext.setStateManger(stateManager); + stateManager.setGlobalProcessorContext(processorContext); + + stateManager.initialize(); + } + + @Test + public void shouldUsePollMsPlusRequestTimeoutInPollDuringRestoreAndTimeoutWhenNoProgressDuringRestore() { + consumer = new MockConsumer(OffsetResetStrategy.EARLIEST) { + @Override + public synchronized ConsumerRecords poll(final Duration timeout) { + time.sleep(timeout.toMillis()); + return super.poll(timeout); + } + }; + + final HashMap startOffsets = new HashMap<>(); + startOffsets.put(t1, 1L); + final HashMap endOffsets = new HashMap<>(); + endOffsets.put(t1, 3L); + consumer.updatePartitions(t1.topic(), Collections.singletonList(new PartitionInfo(t1.topic(), t1.partition(), null, null, null))); + 
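+        // The end offset (3) is ahead of the beginning offset (1), but no records are ever added
+        // to the MockConsumer, so the restore loop can never make progress. The overridden poll()
+        // above only advances the mocked clock by the full poll duration, so the assertions below
+        // expect initialize() to eventually give up with a TimeoutException.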
consumer.assign(Collections.singletonList(t1)); + consumer.updateBeginningOffsets(startOffsets); + consumer.updateEndOffsets(endOffsets); + + streamsConfig = new StreamsConfig(mkMap( + mkEntry(StreamsConfig.APPLICATION_ID_CONFIG, "appId"), + mkEntry(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, "dummy:1234"), + mkEntry(StreamsConfig.STATE_DIR_CONFIG, TestUtils.tempDirectory().getPath()) + )); + + stateManager = new GlobalStateManagerImpl( + new LogContext("mock"), + time, + topology, + consumer, + stateDirectory, + stateRestoreListener, + streamsConfig + ); + processorContext.setStateManger(stateManager); + stateManager.setGlobalProcessorContext(processorContext); + + final long startTime = time.milliseconds(); + + final TimeoutException exception = assertThrows( + TimeoutException.class, + () -> stateManager.initialize() + ); + assertThat( + exception.getMessage(), + equalTo("Global task did not make progress to restore state within 301000 ms. Adjust `task.timeout.ms` if needed.") + ); + assertThat(time.milliseconds() - startTime, equalTo(331_100L)); + } + + private void writeCorruptCheckpoint() throws IOException { + final File checkpointFile = new File(stateManager.baseDir(), StateManagerUtil.CHECKPOINT_FILE_NAME); + try (final OutputStream stream = Files.newOutputStream(checkpointFile.toPath())) { + stream.write("0\n1\nfoo".getBytes()); + } + } + + private void initializeConsumer(final long numRecords, final long startOffset, final TopicPartition... topicPartitions) { + consumer.assign(Arrays.asList(topicPartitions)); + + final Map startOffsets = new HashMap<>(); + final Map endOffsets = new HashMap<>(); + for (final TopicPartition topicPartition : topicPartitions) { + startOffsets.put(topicPartition, startOffset); + endOffsets.put(topicPartition, startOffset + numRecords); + consumer.updatePartitions(topicPartition.topic(), Collections.singletonList(new PartitionInfo(topicPartition.topic(), topicPartition.partition(), null, null, null))); + for (int i = 0; i < numRecords; i++) { + consumer.addRecord(new ConsumerRecord<>(topicPartition.topic(), topicPartition.partition(), startOffset + i, "key".getBytes(), "value".getBytes())); + } + } + consumer.updateEndOffsets(endOffsets); + consumer.updateBeginningOffsets(startOffsets); + } + + private Map writeCheckpoint() throws IOException { + final OffsetCheckpoint checkpoint = new OffsetCheckpoint(checkpointFile); + final Map expected = Collections.singletonMap(t1, 1L); + checkpoint.write(expected); + return expected; + } + + private static class TheStateRestoreCallback implements StateRestoreCallback { + private final List> restored = new ArrayList<>(); + + @Override + public void restore(final byte[] key, final byte[] value) { + restored.add(KeyValue.pair(key, value)); + } + } + + private static class ConverterStore extends NoOpReadOnlyStore implements TimestampedBytesStore { + ConverterStore(final String name, + final boolean rocksdbStore) { + super(name, rocksdbStore); + } + } + +} diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/GlobalStateTaskTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/GlobalStateTaskTest.java new file mode 100644 index 0000000..31be9dc --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/GlobalStateTaskTest.java @@ -0,0 +1,243 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.processor.internals; + +import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.header.internals.RecordHeaders; +import org.apache.kafka.common.record.TimestampType; +import org.apache.kafka.common.serialization.IntegerDeserializer; +import org.apache.kafka.common.serialization.IntegerSerializer; +import org.apache.kafka.common.serialization.LongSerializer; +import org.apache.kafka.common.serialization.StringDeserializer; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.streams.errors.LogAndContinueExceptionHandler; +import org.apache.kafka.streams.errors.LogAndFailExceptionHandler; +import org.apache.kafka.streams.errors.StreamsException; +import org.apache.kafka.test.GlobalStateManagerStub; +import org.apache.kafka.test.MockProcessorNode; +import org.apache.kafka.test.MockSourceNode; +import org.apache.kafka.test.NoOpProcessorContext; +import org.apache.kafka.test.TestUtils; +import org.junit.Before; +import org.junit.Test; + +import java.io.File; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.Optional; +import java.util.Set; + +import static java.util.Arrays.asList; +import static org.apache.kafka.streams.processor.internals.testutil.ConsumerRecordUtil.record; +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +public class GlobalStateTaskTest { + + private final LogContext logContext = new LogContext(); + + private final String topic1 = "t1"; + private final String topic2 = "t2"; + private final TopicPartition t1 = new TopicPartition(topic1, 1); + private final TopicPartition t2 = new TopicPartition(topic2, 1); + private final MockSourceNode sourceOne = new MockSourceNode<>( + new StringDeserializer(), + new StringDeserializer()); + private final MockSourceNode sourceTwo = new MockSourceNode<>( + new IntegerDeserializer(), + new IntegerDeserializer()); + private final MockProcessorNode processorOne = new MockProcessorNode<>(); + private final MockProcessorNode processorTwo = new MockProcessorNode<>(); + + private final Map offsets = new HashMap<>(); + private File testDirectory = TestUtils.tempDirectory("global-store"); + private final NoOpProcessorContext context = new NoOpProcessorContext(); + + private ProcessorTopology topology; + private GlobalStateManagerStub stateMgr; + private GlobalStateUpdateTask globalStateTask; + + @Before + public void before() { + final Set storeNames = Utils.mkSet("t1-store", "t2-store"); + final Map> 
sourceByTopics = new HashMap<>(); + sourceByTopics.put(topic1, sourceOne); + sourceByTopics.put(topic2, sourceTwo); + final Map storeToTopic = new HashMap<>(); + storeToTopic.put("t1-store", topic1); + storeToTopic.put("t2-store", topic2); + topology = ProcessorTopologyFactories.with( + asList(sourceOne, sourceTwo, processorOne, processorTwo), + sourceByTopics, + Collections.emptyList(), + storeToTopic); + + offsets.put(t1, 50L); + offsets.put(t2, 100L); + stateMgr = new GlobalStateManagerStub(storeNames, offsets, testDirectory); + globalStateTask = new GlobalStateUpdateTask( + logContext, + topology, + context, + stateMgr, + new LogAndFailExceptionHandler() + ); + } + + @Test + public void shouldInitializeStateManager() { + final Map startingOffsets = globalStateTask.initialize(); + assertTrue(stateMgr.initialized); + assertEquals(offsets, startingOffsets); + } + + @Test + public void shouldInitializeContext() { + globalStateTask.initialize(); + assertTrue(context.initialized); + } + + @Test + public void shouldInitializeProcessorTopology() { + globalStateTask.initialize(); + assertTrue(sourceOne.initialized); + assertTrue(sourceTwo.initialized); + assertTrue(processorOne.initialized); + assertTrue(processorTwo.initialized); + } + + @Test + public void shouldProcessRecordsForTopic() { + globalStateTask.initialize(); + globalStateTask.update(record(topic1, 1, 1, "foo".getBytes(), "bar".getBytes())); + assertEquals(1, sourceOne.numReceived); + assertEquals(0, sourceTwo.numReceived); + } + + @Test + public void shouldProcessRecordsForOtherTopic() { + final byte[] integerBytes = new IntegerSerializer().serialize("foo", 1); + globalStateTask.initialize(); + globalStateTask.update(record(topic2, 1, 1, integerBytes, integerBytes)); + assertEquals(1, sourceTwo.numReceived); + assertEquals(0, sourceOne.numReceived); + } + + private void maybeDeserialize(final GlobalStateUpdateTask globalStateTask, + final byte[] key, + final byte[] recordValue, + final boolean failExpected) { + final ConsumerRecord record = new ConsumerRecord<>( + topic2, 1, 1, 0L, TimestampType.CREATE_TIME, + 0, 0, key, recordValue, new RecordHeaders(), Optional.empty() + ); + globalStateTask.initialize(); + try { + globalStateTask.update(record); + if (failExpected) { + fail("Should have failed to deserialize."); + } + } catch (final StreamsException e) { + if (!failExpected) { + fail("Shouldn't have failed to deserialize."); + } + } + } + + + @Test + public void shouldThrowStreamsExceptionWhenKeyDeserializationFails() { + final byte[] key = new LongSerializer().serialize(topic2, 1L); + final byte[] recordValue = new IntegerSerializer().serialize(topic2, 10); + maybeDeserialize(globalStateTask, key, recordValue, true); + } + + + @Test + public void shouldThrowStreamsExceptionWhenValueDeserializationFails() { + final byte[] key = new IntegerSerializer().serialize(topic2, 1); + final byte[] recordValue = new LongSerializer().serialize(topic2, 10L); + maybeDeserialize(globalStateTask, key, recordValue, true); + } + + @Test + public void shouldNotThrowStreamsExceptionWhenKeyDeserializationFailsWithSkipHandler() { + final GlobalStateUpdateTask globalStateTask2 = new GlobalStateUpdateTask( + logContext, + topology, + context, + stateMgr, + new LogAndContinueExceptionHandler() + ); + final byte[] key = new LongSerializer().serialize(topic2, 1L); + final byte[] recordValue = new IntegerSerializer().serialize(topic2, 10); + + maybeDeserialize(globalStateTask2, key, recordValue, false); + } + + @Test + public void 
shouldNotThrowStreamsExceptionWhenValueDeserializationFails() { + final GlobalStateUpdateTask globalStateTask2 = new GlobalStateUpdateTask( + logContext, + topology, + context, + stateMgr, + new LogAndContinueExceptionHandler() + ); + final byte[] key = new IntegerSerializer().serialize(topic2, 1); + final byte[] recordValue = new LongSerializer().serialize(topic2, 10L); + + maybeDeserialize(globalStateTask2, key, recordValue, false); + } + + + @Test + public void shouldFlushStateManagerWithOffsets() { + final Map expectedOffsets = new HashMap<>(); + expectedOffsets.put(t1, 52L); + expectedOffsets.put(t2, 100L); + globalStateTask.initialize(); + globalStateTask.update(record(topic1, 1, 51, "foo".getBytes(), "foo".getBytes())); + globalStateTask.flushState(); + assertEquals(expectedOffsets, stateMgr.changelogOffsets()); + } + + @Test + public void shouldCheckpointOffsetsWhenStateIsFlushed() { + final Map expectedOffsets = new HashMap<>(); + expectedOffsets.put(t1, 102L); + expectedOffsets.put(t2, 100L); + globalStateTask.initialize(); + globalStateTask.update(record(topic1, 1, 101, "foo".getBytes(), "foo".getBytes())); + globalStateTask.flushState(); + assertThat(stateMgr.changelogOffsets(), equalTo(expectedOffsets)); + } + + @Test + public void shouldWipeGlobalStateDirectory() throws Exception { + assertTrue(stateMgr.baseDir().exists()); + globalStateTask.close(true); + assertFalse(stateMgr.baseDir().exists()); + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/GlobalStreamThreadTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/GlobalStreamThreadTest.java new file mode 100644 index 0000000..391fa09 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/GlobalStreamThreadTest.java @@ -0,0 +1,319 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.processor.internals; + +import org.apache.kafka.clients.consumer.InvalidOffsetException; +import org.apache.kafka.clients.consumer.MockConsumer; +import org.apache.kafka.clients.consumer.OffsetResetStrategy; +import org.apache.kafka.common.Node; +import org.apache.kafka.common.PartitionInfo; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.metrics.Metrics; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.errors.StreamsException; +import org.apache.kafka.streams.kstream.Materialized; +import org.apache.kafka.streams.kstream.internals.InternalNameProvider; +import org.apache.kafka.streams.kstream.internals.MaterializedInternal; +import org.apache.kafka.streams.kstream.internals.TimestampedKeyValueStoreMaterializer; +import org.apache.kafka.streams.processor.StateStore; +import org.apache.kafka.streams.processor.api.ContextualProcessor; +import org.apache.kafka.streams.processor.api.ProcessorSupplier; +import org.apache.kafka.streams.processor.api.Record; +import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl; +import org.apache.kafka.streams.state.KeyValueStore; +import org.apache.kafka.test.MockStateRestoreListener; +import org.apache.kafka.test.TestUtils; +import org.junit.Before; +import org.junit.Test; + +import java.io.File; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Set; + +import static org.apache.kafka.streams.processor.internals.GlobalStreamThread.State.DEAD; +import static org.apache.kafka.streams.processor.internals.GlobalStreamThread.State.RUNNING; +import static org.apache.kafka.streams.processor.internals.testutil.ConsumerRecordUtil.record; +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.core.IsInstanceOf.instanceOf; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +public class GlobalStreamThreadTest { + private final InternalTopologyBuilder builder = new InternalTopologyBuilder(); + private final MockConsumer mockConsumer = new MockConsumer<>(OffsetResetStrategy.NONE); + private final MockTime time = new MockTime(); + private final MockStateRestoreListener stateRestoreListener = new MockStateRestoreListener(); + private GlobalStreamThread globalStreamThread; + private StreamsConfig config; + private String baseDirectoryName; + + private final static String GLOBAL_STORE_TOPIC_NAME = "foo"; + private final static String GLOBAL_STORE_NAME = "bar"; + private final TopicPartition topicPartition = new TopicPartition(GLOBAL_STORE_TOPIC_NAME, 0); + + @Before + public void before() { + final MaterializedInternal> materialized = + new MaterializedInternal<>(Materialized.with(null, null), + new InternalNameProvider() { + @Override + public String newProcessorName(final String prefix) { + return "processorName"; + } + + @Override + public String newStoreName(final String prefix) { + return GLOBAL_STORE_NAME; + } + }, + "store-" + ); + + final ProcessorSupplier processorSupplier = () -> + new ContextualProcessor() { + @Override + public void process(final Record record) { + } + }; + + builder.addGlobalStore( + new 
TimestampedKeyValueStoreMaterializer<>(materialized).materialize().withLoggingDisabled(), + "sourceName", + null, + null, + null, + GLOBAL_STORE_TOPIC_NAME, + "processorName", + processorSupplier); + + baseDirectoryName = TestUtils.tempDirectory().getAbsolutePath(); + final HashMap properties = new HashMap<>(); + properties.put(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, "blah"); + properties.put(StreamsConfig.APPLICATION_ID_CONFIG, "testAppId"); + properties.put(StreamsConfig.STATE_DIR_CONFIG, baseDirectoryName); + properties.put(StreamsConfig.DEFAULT_KEY_SERDE_CLASS_CONFIG, Serdes.ByteArraySerde.class.getName()); + properties.put(StreamsConfig.DEFAULT_VALUE_SERDE_CLASS_CONFIG, Serdes.ByteArraySerde.class.getName()); + config = new StreamsConfig(properties); + globalStreamThread = new GlobalStreamThread( + builder.rewriteTopology(config).buildGlobalStateTopology(), + config, + mockConsumer, + new StateDirectory(config, time, true, false), + 0, + new StreamsMetricsImpl(new Metrics(), "test-client", StreamsConfig.METRICS_LATEST, time), + time, + "clientId", + stateRestoreListener, + e -> { } + ); + } + + @Test + public void shouldThrowStreamsExceptionOnStartupIfThereIsAStreamsException() throws Exception { + // should throw as the MockConsumer hasn't been configured and there are no + // partitions available + final StateStore globalStore = builder.globalStateStores().get(GLOBAL_STORE_NAME); + try { + globalStreamThread.start(); + fail("Should have thrown StreamsException if start up failed"); + } catch (final StreamsException e) { + // ok + } + globalStreamThread.join(); + assertThat(globalStore.isOpen(), is(false)); + assertFalse(globalStreamThread.stillRunning()); + } + + @Test + public void shouldThrowStreamsExceptionOnStartupIfExceptionOccurred() throws Exception { + final MockConsumer mockConsumer = new MockConsumer(OffsetResetStrategy.EARLIEST) { + @Override + public List partitionsFor(final String topic) { + throw new RuntimeException("KABOOM!"); + } + }; + final StateStore globalStore = builder.globalStateStores().get(GLOBAL_STORE_NAME); + globalStreamThread = new GlobalStreamThread( + builder.buildGlobalStateTopology(), + config, + mockConsumer, + new StateDirectory(config, time, true, false), + 0, + new StreamsMetricsImpl(new Metrics(), "test-client", StreamsConfig.METRICS_LATEST, time), + time, + "clientId", + stateRestoreListener, + e -> { } + ); + + try { + globalStreamThread.start(); + fail("Should have thrown StreamsException if start up failed"); + } catch (final StreamsException e) { + assertThat(e.getCause(), instanceOf(RuntimeException.class)); + assertThat(e.getCause().getMessage(), equalTo("KABOOM!")); + } + globalStreamThread.join(); + assertThat(globalStore.isOpen(), is(false)); + assertFalse(globalStreamThread.stillRunning()); + } + + @Test + public void shouldBeRunningAfterSuccessfulStart() throws Exception { + initializeConsumer(); + startAndSwallowError(); + assertTrue(globalStreamThread.stillRunning()); + + globalStreamThread.shutdown(); + globalStreamThread.join(); + } + + @Test(timeout = 30000) + public void shouldStopRunningWhenClosedByUser() throws Exception { + initializeConsumer(); + startAndSwallowError(); + globalStreamThread.shutdown(); + globalStreamThread.join(); + assertEquals(GlobalStreamThread.State.DEAD, globalStreamThread.state()); + } + + @Test + public void shouldCloseStateStoresOnClose() throws Exception { + initializeConsumer(); + startAndSwallowError(); + final StateStore globalStore = builder.globalStateStores().get(GLOBAL_STORE_NAME); + 
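+        // The global store should have been opened while the thread started up and restored;
+        // shutting the thread down below is expected to close it again.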
assertTrue(globalStore.isOpen()); + globalStreamThread.shutdown(); + globalStreamThread.join(); + assertFalse(globalStore.isOpen()); + } + + @Test + public void shouldStayDeadAfterTwoCloses() throws Exception { + initializeConsumer(); + startAndSwallowError(); + globalStreamThread.shutdown(); + globalStreamThread.join(); + globalStreamThread.shutdown(); + + assertEquals(GlobalStreamThread.State.DEAD, globalStreamThread.state()); + } + + @Test + public void shouldTransitionToRunningOnStart() throws Exception { + initializeConsumer(); + startAndSwallowError(); + + TestUtils.waitForCondition( + () -> globalStreamThread.state() == RUNNING, + 10 * 1000, + "Thread never started."); + + globalStreamThread.shutdown(); + } + + @Test + public void shouldDieOnInvalidOffsetExceptionDuringStartup() throws Exception { + final StateStore globalStore = builder.globalStateStores().get(GLOBAL_STORE_NAME); + initializeConsumer(); + mockConsumer.setPollException(new InvalidOffsetException("Try Again!") { + @Override + public Set partitions() { + return Collections.singleton(topicPartition); + } + }); + + startAndSwallowError(); + + TestUtils.waitForCondition( + () -> globalStreamThread.state() == DEAD, + 10 * 1000, + "GlobalStreamThread should have died." + ); + globalStreamThread.join(); + + assertThat(globalStore.isOpen(), is(false)); + assertFalse(new File(baseDirectoryName + File.separator + "testAppId" + File.separator + "global").exists()); + } + + @Test + public void shouldDieOnInvalidOffsetExceptionWhileRunning() throws Exception { + final StateStore globalStore = builder.globalStateStores().get(GLOBAL_STORE_NAME); + initializeConsumer(); + startAndSwallowError(); + + TestUtils.waitForCondition( + () -> globalStreamThread.state() == RUNNING, + 10 * 1000, + "Thread never started."); + + mockConsumer.updateEndOffsets(Collections.singletonMap(topicPartition, 1L)); + mockConsumer.addRecord(record(GLOBAL_STORE_TOPIC_NAME, 0, 0L, "K1".getBytes(), "V1".getBytes())); + + TestUtils.waitForCondition( + () -> mockConsumer.position(topicPartition) == 1L, + 10 * 1000, + "Input record never consumed"); + + mockConsumer.setPollException(new InvalidOffsetException("Try Again!") { + @Override + public Set partitions() { + return Collections.singleton(topicPartition); + } + }); + + TestUtils.waitForCondition( + () -> globalStreamThread.state() == DEAD, + 10 * 1000, + "GlobalStreamThread should have died." 
+ ); + globalStreamThread.join(); + + assertThat(globalStore.isOpen(), is(false)); + assertFalse(new File(baseDirectoryName + File.separator + "testAppId" + File.separator + "global").exists()); + } + + private void initializeConsumer() { + mockConsumer.updatePartitions( + GLOBAL_STORE_TOPIC_NAME, + Collections.singletonList(new PartitionInfo( + GLOBAL_STORE_TOPIC_NAME, + 0, + null, + new Node[0], + new Node[0]))); + mockConsumer.updateBeginningOffsets(Collections.singletonMap(topicPartition, 0L)); + mockConsumer.updateEndOffsets(Collections.singletonMap(topicPartition, 0L)); + mockConsumer.assign(Collections.singleton(topicPartition)); + } + + private void startAndSwallowError() { + try { + globalStreamThread.start(); + } catch (final IllegalStateException ignored) { + } + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/HandlingSourceTopicDeletionIntegrationTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/HandlingSourceTopicDeletionIntegrationTest.java new file mode 100644 index 0000000..017cccf --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/HandlingSourceTopicDeletionIntegrationTest.java @@ -0,0 +1,132 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.processor.internals; + +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.streams.KafkaStreams; +import org.apache.kafka.streams.KafkaStreams.State; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.Topology; +import org.apache.kafka.streams.errors.StreamsUncaughtExceptionHandler; +import org.apache.kafka.streams.integration.utils.EmbeddedKafkaCluster; +import org.apache.kafka.streams.kstream.Consumed; +import org.apache.kafka.streams.kstream.Produced; +import org.apache.kafka.test.IntegrationTest; +import org.apache.kafka.test.TestUtils; +import org.junit.After; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Rule; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.rules.TestName; + +import java.io.IOException; +import java.util.Properties; +import java.util.concurrent.atomic.AtomicBoolean; + +import static org.apache.kafka.streams.integration.utils.IntegrationTestUtils.safeUniqueTestName; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.MatcherAssert.assertThat; + +@Category({IntegrationTest.class}) +public class HandlingSourceTopicDeletionIntegrationTest { + + private static final int NUM_BROKERS = 1; + private static final int NUM_THREADS = 2; + private static final long TIMEOUT = 60000; + private static final String INPUT_TOPIC = "inputTopic"; + private static final String OUTPUT_TOPIC = "outputTopic"; + + public static final EmbeddedKafkaCluster CLUSTER = new EmbeddedKafkaCluster(NUM_BROKERS); + + @BeforeClass + public static void startCluster() throws IOException { + CLUSTER.start(); + } + + @AfterClass + public static void closeCluster() { + CLUSTER.stop(); + } + + @Rule + public TestName testName = new TestName(); + + @Before + public void before() throws InterruptedException { + CLUSTER.createTopics(INPUT_TOPIC, OUTPUT_TOPIC); + } + + @After + public void after() throws InterruptedException { + CLUSTER.deleteTopics(INPUT_TOPIC, OUTPUT_TOPIC); + } + + @Test + public void shouldThrowErrorAfterSourceTopicDeleted() throws InterruptedException { + final StreamsBuilder builder = new StreamsBuilder(); + builder.stream(INPUT_TOPIC, Consumed.with(Serdes.Integer(), Serdes.String())) + .to(OUTPUT_TOPIC, Produced.with(Serdes.Integer(), Serdes.String())); + + final String safeTestName = safeUniqueTestName(getClass(), testName); + final String appId = "app-" + safeTestName; + + final Properties streamsConfiguration = new Properties(); + streamsConfiguration.put(StreamsConfig.APPLICATION_ID_CONFIG, appId); + streamsConfiguration.put(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, CLUSTER.bootstrapServers()); + streamsConfiguration.put(StreamsConfig.DEFAULT_KEY_SERDE_CLASS_CONFIG, Serdes.Integer().getClass()); + streamsConfiguration.put(StreamsConfig.DEFAULT_VALUE_SERDE_CLASS_CONFIG, Serdes.String().getClass()); + streamsConfiguration.put(StreamsConfig.NUM_STREAM_THREADS_CONFIG, NUM_THREADS); + streamsConfiguration.put(StreamsConfig.METADATA_MAX_AGE_CONFIG, 2000); + + final Topology topology = builder.build(); + final KafkaStreams kafkaStreams1 = new KafkaStreams(topology, streamsConfiguration); + final AtomicBoolean calledUncaughtExceptionHandler1 = new AtomicBoolean(false); + kafkaStreams1.setUncaughtExceptionHandler(exception -> { + calledUncaughtExceptionHandler1.set(true); + return 
StreamsUncaughtExceptionHandler.StreamThreadExceptionResponse.SHUTDOWN_CLIENT; + }); + kafkaStreams1.start(); + final KafkaStreams kafkaStreams2 = new KafkaStreams(topology, streamsConfiguration); + final AtomicBoolean calledUncaughtExceptionHandler2 = new AtomicBoolean(false); + kafkaStreams2.setUncaughtExceptionHandler(exception -> { + calledUncaughtExceptionHandler2.set(true); + return StreamsUncaughtExceptionHandler.StreamThreadExceptionResponse.SHUTDOWN_CLIENT; + }); + kafkaStreams2.start(); + + TestUtils.waitForCondition( + () -> kafkaStreams1.state() == State.RUNNING && kafkaStreams2.state() == State.RUNNING, + TIMEOUT, + () -> "Kafka Streams clients did not reach state RUNNING" + ); + + CLUSTER.deleteTopicAndWait(INPUT_TOPIC); + + TestUtils.waitForCondition( + () -> kafkaStreams1.state() == State.ERROR && kafkaStreams2.state() == State.ERROR, + TIMEOUT, + () -> "Kafka Streams clients did not reach state ERROR" + ); + + assertThat(calledUncaughtExceptionHandler1.get(), is(true)); + assertThat(calledUncaughtExceptionHandler2.get(), is(true)); + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/HighAvailabilityStreamsPartitionAssignorTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/HighAvailabilityStreamsPartitionAssignorTest.java new file mode 100644 index 0000000..9c2d119 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/HighAvailabilityStreamsPartitionAssignorTest.java @@ -0,0 +1,341 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.processor.internals; + +import org.apache.kafka.clients.admin.Admin; +import org.apache.kafka.clients.admin.AdminClient; +import org.apache.kafka.clients.admin.ListOffsetsResult; +import org.apache.kafka.clients.admin.ListOffsetsResult.ListOffsetsResultInfo; +import org.apache.kafka.clients.consumer.Consumer; +import org.apache.kafka.clients.consumer.ConsumerPartitionAssignor.Assignment; +import org.apache.kafka.clients.consumer.ConsumerPartitionAssignor.GroupSubscription; +import org.apache.kafka.clients.consumer.ConsumerPartitionAssignor.Subscription; +import org.apache.kafka.common.Cluster; +import org.apache.kafka.common.Node; +import org.apache.kafka.common.PartitionInfo; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.internals.KafkaFutureImpl; +import org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.StreamsConfig.InternalConfig; +import org.apache.kafka.streams.errors.StreamsException; +import org.apache.kafka.streams.processor.TaskId; +import org.apache.kafka.streams.processor.internals.assignment.AssignmentInfo; +import org.apache.kafka.streams.processor.internals.assignment.AssignorError; +import org.apache.kafka.streams.processor.internals.assignment.ReferenceContainer; +import org.apache.kafka.streams.processor.internals.assignment.SubscriptionInfo; +import org.apache.kafka.test.MockApiProcessorSupplier; +import org.apache.kafka.test.MockClientSupplier; +import org.apache.kafka.test.MockInternalTopicManager; +import org.apache.kafka.test.MockKeyValueStoreBuilder; +import org.easymock.EasyMock; +import org.junit.Before; +import org.junit.Test; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; +import java.util.Set; +import java.util.UUID; +import java.util.stream.Collectors; + +import static java.util.Arrays.asList; +import static java.util.Collections.emptySet; +import static java.util.Collections.singletonList; +import static java.util.Collections.singletonMap; +import static org.apache.kafka.common.utils.Utils.mkSet; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.EMPTY_CHANGELOG_END_OFFSETS; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.EMPTY_TASKS; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_0_0; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_0_1; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_0_2; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.UUID_1; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.UUID_2; +import static org.apache.kafka.streams.processor.internals.assignment.StreamsAssignmentProtocolVersions.LATEST_SUPPORTED_VERSION; +import static org.easymock.EasyMock.anyObject; +import static org.easymock.EasyMock.expect; +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.anyOf; +import static org.hamcrest.Matchers.empty; +import static org.hamcrest.Matchers.is; + +public class HighAvailabilityStreamsPartitionAssignorTest { + + private final List infos = asList( + new PartitionInfo("topic1", 0, Node.noNode(), 
new Node[0], new Node[0]), + new PartitionInfo("topic1", 1, Node.noNode(), new Node[0], new Node[0]), + new PartitionInfo("topic1", 2, Node.noNode(), new Node[0], new Node[0]), + new PartitionInfo("topic2", 0, Node.noNode(), new Node[0], new Node[0]), + new PartitionInfo("topic2", 1, Node.noNode(), new Node[0], new Node[0]), + new PartitionInfo("topic2", 2, Node.noNode(), new Node[0], new Node[0]), + new PartitionInfo("topic3", 0, Node.noNode(), new Node[0], new Node[0]), + new PartitionInfo("topic3", 1, Node.noNode(), new Node[0], new Node[0]), + new PartitionInfo("topic3", 2, Node.noNode(), new Node[0], new Node[0]), + new PartitionInfo("topic3", 3, Node.noNode(), new Node[0], new Node[0]) + ); + + private final Cluster metadata = new Cluster( + "cluster", + singletonList(Node.noNode()), + infos, + emptySet(), + emptySet()); + + private final StreamsPartitionAssignor partitionAssignor = new StreamsPartitionAssignor(); + private final MockClientSupplier mockClientSupplier = new MockClientSupplier(); + private static final String USER_END_POINT = "localhost:8080"; + private static final String APPLICATION_ID = "stream-partition-assignor-test"; + + private TaskManager taskManager; + private Admin adminClient; + private StreamsConfig streamsConfig = new StreamsConfig(configProps()); + private final InternalTopologyBuilder builder = new InternalTopologyBuilder(); + private TopologyMetadata topologyMetadata = new TopologyMetadata(builder, streamsConfig); + private final StreamsMetadataState streamsMetadataState = EasyMock.createNiceMock(StreamsMetadataState.class); + private final Map subscriptions = new HashMap<>(); + + private ReferenceContainer referenceContainer; + private final MockTime time = new MockTime(); + + private Map configProps() { + final Map configurationMap = new HashMap<>(); + configurationMap.put(StreamsConfig.APPLICATION_ID_CONFIG, APPLICATION_ID); + configurationMap.put(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, USER_END_POINT); + referenceContainer = new ReferenceContainer(); + referenceContainer.mainConsumer = EasyMock.mock(Consumer.class); + referenceContainer.adminClient = adminClient; + referenceContainer.taskManager = taskManager; + referenceContainer.streamsMetadataState = streamsMetadataState; + referenceContainer.time = time; + configurationMap.put(InternalConfig.REFERENCE_CONTAINER_PARTITION_ASSIGNOR, referenceContainer); + return configurationMap; + } + + // Make sure to complete setting up any mocks (such as TaskManager or AdminClient) before configuring the assignor + private void configurePartitionAssignorWith(final Map props) { + final Map configMap = configProps(); + configMap.putAll(props); + + streamsConfig = new StreamsConfig(configMap); + topologyMetadata = new TopologyMetadata(builder, streamsConfig); + partitionAssignor.configure(configMap); + EasyMock.replay(taskManager, adminClient); + + overwriteInternalTopicManagerWithMock(); + } + + // Useful for tests that don't care about the task offset sums + private void createMockTaskManager(final Set activeTasks) { + createMockTaskManager(getTaskOffsetSums(activeTasks)); + } + + private void createMockTaskManager(final Map taskOffsetSums) { + taskManager = EasyMock.createNiceMock(TaskManager.class); + expect(taskManager.topologyMetadata()).andStubReturn(topologyMetadata); + expect(taskManager.getTaskOffsetSums()).andReturn(taskOffsetSums).anyTimes(); + expect(taskManager.processId()).andReturn(UUID_1).anyTimes(); + topologyMetadata.buildAndRewriteTopology(); + } + + // If you don't care about setting the end 
offsets for each specific topic partition, the helper method + // getTopicPartitionOffsetMap is useful for building this input map for all partitions + private void createMockAdminClient(final Map changelogEndOffsets) { + adminClient = EasyMock.createMock(AdminClient.class); + + final ListOffsetsResult result = EasyMock.createNiceMock(ListOffsetsResult.class); + final KafkaFutureImpl> allFuture = new KafkaFutureImpl<>(); + allFuture.complete(changelogEndOffsets.entrySet().stream().collect(Collectors.toMap( + Entry::getKey, + t -> { + final ListOffsetsResultInfo info = EasyMock.createNiceMock(ListOffsetsResultInfo.class); + expect(info.offset()).andStubReturn(t.getValue()); + EasyMock.replay(info); + return info; + })) + ); + + expect(adminClient.listOffsets(anyObject())).andStubReturn(result); + expect(result.all()).andReturn(allFuture); + + EasyMock.replay(result); + } + + private void overwriteInternalTopicManagerWithMock() { + final MockInternalTopicManager mockInternalTopicManager = new MockInternalTopicManager( + time, + streamsConfig, + mockClientSupplier.restoreConsumer, + false + ); + partitionAssignor.setInternalTopicManager(mockInternalTopicManager); + } + + @Before + public void setUp() { + createMockAdminClient(EMPTY_CHANGELOG_END_OFFSETS); + } + + + @Test + public void shouldReturnAllActiveTasksToPreviousOwnerRegardlessOfBalanceAndTriggerRebalanceIfEndOffsetFetchFailsAndHighAvailabilityEnabled() { + final long rebalanceInterval = 5 * 60 * 1000L; + + builder.addSource(null, "source1", null, null, null, "topic1"); + builder.addProcessor("processor1", new MockApiProcessorSupplier<>(), "source1"); + builder.addStateStore(new MockKeyValueStoreBuilder("store1", false), "processor1"); + final Set allTasks = mkSet(TASK_0_0, TASK_0_1, TASK_0_2); + + createMockTaskManager(allTasks); + adminClient = EasyMock.createMock(AdminClient.class); + expect(adminClient.listOffsets(anyObject())).andThrow(new StreamsException("Should be handled")); + configurePartitionAssignorWith(singletonMap(StreamsConfig.PROBING_REBALANCE_INTERVAL_MS_CONFIG, rebalanceInterval)); + + final String firstConsumer = "consumer1"; + final String newConsumer = "consumer2"; + + subscriptions.put(firstConsumer, + new Subscription( + singletonList("source1"), + getInfo(UUID_1, allTasks).encode() + )); + subscriptions.put(newConsumer, + new Subscription( + singletonList("source1"), + getInfo(UUID_2, EMPTY_TASKS).encode() + )); + + final Map assignments = partitionAssignor + .assign(metadata, new GroupSubscription(subscriptions)) + .groupAssignment(); + + final AssignmentInfo firstConsumerUserData = AssignmentInfo.decode(assignments.get(firstConsumer).userData()); + final List firstConsumerActiveTasks = firstConsumerUserData.activeTasks(); + final AssignmentInfo newConsumerUserData = AssignmentInfo.decode(assignments.get(newConsumer).userData()); + final List newConsumerActiveTasks = newConsumerUserData.activeTasks(); + + // The tasks were returned to their prior owner + final ArrayList sortedExpectedTasks = new ArrayList<>(allTasks); + Collections.sort(sortedExpectedTasks); + assertThat(firstConsumerActiveTasks, equalTo(sortedExpectedTasks)); + assertThat(newConsumerActiveTasks, empty()); + + // There is a rebalance scheduled + assertThat( + time.milliseconds() + rebalanceInterval, + anyOf( + is(firstConsumerUserData.nextRebalanceMs()), + is(newConsumerUserData.nextRebalanceMs()) + ) + ); + } + + @Test + public void shouldScheduleProbingRebalanceOnThisClientIfWarmupTasksRequired() { + final long rebalanceInterval = 5 * 60 * 
1000L; + + builder.addSource(null, "source1", null, null, null, "topic1"); + builder.addProcessor("processor1", new MockApiProcessorSupplier<>(), "source1"); + builder.addStateStore(new MockKeyValueStoreBuilder("store1", false), "processor1"); + final Set allTasks = mkSet(TASK_0_0, TASK_0_1, TASK_0_2); + + createMockTaskManager(allTasks); + createMockAdminClient(getTopicPartitionOffsetsMap( + singletonList(APPLICATION_ID + "-store1-changelog"), + singletonList(3))); + configurePartitionAssignorWith(singletonMap(StreamsConfig.PROBING_REBALANCE_INTERVAL_MS_CONFIG, rebalanceInterval)); + + final String firstConsumer = "consumer1"; + final String newConsumer = "consumer2"; + + subscriptions.put(firstConsumer, + new Subscription( + singletonList("source1"), + getInfo(UUID_1, allTasks).encode() + )); + subscriptions.put(newConsumer, + new Subscription( + singletonList("source1"), + getInfo(UUID_2, EMPTY_TASKS).encode() + )); + + final Map assignments = partitionAssignor + .assign(metadata, new GroupSubscription(subscriptions)) + .groupAssignment(); + + final List firstConsumerActiveTasks = + AssignmentInfo.decode(assignments.get(firstConsumer).userData()).activeTasks(); + final List newConsumerActiveTasks = + AssignmentInfo.decode(assignments.get(newConsumer).userData()).activeTasks(); + + final ArrayList sortedExpectedTasks = new ArrayList<>(allTasks); + Collections.sort(sortedExpectedTasks); + assertThat(firstConsumerActiveTasks, equalTo(sortedExpectedTasks)); + assertThat(newConsumerActiveTasks, empty()); + + assertThat(referenceContainer.assignmentErrorCode.get(), equalTo(AssignorError.NONE.code())); + + final long nextScheduledRebalanceOnThisClient = + AssignmentInfo.decode(assignments.get(firstConsumer).userData()).nextRebalanceMs(); + final long nextScheduledRebalanceOnOtherClient = + AssignmentInfo.decode(assignments.get(newConsumer).userData()).nextRebalanceMs(); + + assertThat(nextScheduledRebalanceOnThisClient, equalTo(time.milliseconds() + rebalanceInterval)); + assertThat(nextScheduledRebalanceOnOtherClient, equalTo(Long.MAX_VALUE)); + } + + + /** + * Helper for building the input to createMockAdminClient in cases where we don't care about the actual offsets + * @param changelogTopics The names of all changelog topics in the topology + * @param topicsNumPartitions The number of partitions for the corresponding changelog topic, such that the number + * of partitions of the ith topic in changelogTopics is given by the ith element of topicsNumPartitions + */ + private static Map getTopicPartitionOffsetsMap(final List changelogTopics, + final List topicsNumPartitions) { + if (changelogTopics.size() != topicsNumPartitions.size()) { + throw new IllegalStateException("Passed in " + changelogTopics.size() + " changelog topic names, but " + + topicsNumPartitions.size() + " different numPartitions for the topics"); + } + final Map changelogEndOffsets = new HashMap<>(); + for (int i = 0; i < changelogTopics.size(); ++i) { + final String topic = changelogTopics.get(i); + final int numPartitions = topicsNumPartitions.get(i); + for (int partition = 0; partition < numPartitions; ++partition) { + changelogEndOffsets.put(new TopicPartition(topic, partition), Long.MAX_VALUE); + } + } + return changelogEndOffsets; + } + + private static SubscriptionInfo getInfo(final UUID processId, + final Set prevTasks) { + return new SubscriptionInfo( + LATEST_SUPPORTED_VERSION, LATEST_SUPPORTED_VERSION, processId, null, getTaskOffsetSums(prevTasks), (byte) 0, 0); + } + + // Stub offset sums for when we only care 
about the prev/standby task sets, not the actual offsets + private static Map getTaskOffsetSums(final Set activeTasks) { + final Map taskOffsetSums = activeTasks.stream().collect(Collectors.toMap(t -> t, t -> Task.LATEST_OFFSET)); + taskOffsetSums.putAll(EMPTY_TASKS.stream().collect(Collectors.toMap(t -> t, t -> 0L))); + return taskOffsetSums; + } + +} diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/InternalTopicConfigTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/InternalTopicConfigTest.java new file mode 100644 index 0000000..f528457 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/InternalTopicConfigTest.java @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.processor.internals; + +import org.apache.kafka.common.config.TopicConfig; +import org.apache.kafka.common.errors.InvalidTopicException; +import org.junit.Test; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; + +public class InternalTopicConfigTest { + + @Test + public void shouldThrowNpeIfTopicConfigIsNull() { + assertThrows(NullPointerException.class, () -> new RepartitionTopicConfig("topic", null)); + } + + @Test + public void shouldThrowIfNameIsNull() { + assertThrows(NullPointerException.class, () -> new RepartitionTopicConfig(null, Collections.emptyMap())); + } + + @Test + public void shouldThrowIfNameIsInvalid() { + assertThrows(InvalidTopicException.class, () -> new RepartitionTopicConfig("foo bar baz", Collections.emptyMap())); + } + + @Test + public void shouldSetCreateTimeByDefaultForWindowedChangelog() { + final WindowedChangelogTopicConfig topicConfig = new WindowedChangelogTopicConfig("name", Collections.emptyMap()); + + final Map properties = topicConfig.getProperties(Collections.emptyMap(), 0); + assertEquals("CreateTime", properties.get(TopicConfig.MESSAGE_TIMESTAMP_TYPE_CONFIG)); + } + + @Test + public void shouldSetCreateTimeByDefaultForUnwindowedChangelog() { + final UnwindowedChangelogTopicConfig topicConfig = new UnwindowedChangelogTopicConfig("name", Collections.emptyMap()); + + final Map properties = topicConfig.getProperties(Collections.emptyMap(), 0); + assertEquals("CreateTime", properties.get(TopicConfig.MESSAGE_TIMESTAMP_TYPE_CONFIG)); + } + + @Test + public void shouldSetCreateTimeByDefaultForRepartitionTopic() { + final RepartitionTopicConfig topicConfig = new RepartitionTopicConfig("name", Collections.emptyMap()); + + final Map properties = topicConfig.getProperties(Collections.emptyMap(), 0); + assertEquals("CreateTime", 
properties.get(TopicConfig.MESSAGE_TIMESTAMP_TYPE_CONFIG)); + } + + @Test + public void shouldAugmentRetentionMsWithWindowedChangelog() { + final WindowedChangelogTopicConfig topicConfig = new WindowedChangelogTopicConfig("name", Collections.emptyMap()); + topicConfig.setRetentionMs(10); + assertEquals("30", topicConfig.getProperties(Collections.emptyMap(), 20).get(TopicConfig.RETENTION_MS_CONFIG)); + } + + @Test + public void shouldUseSuppliedConfigsForWindowedChangelogConfig() { + final Map configs = new HashMap<>(); + configs.put("message.timestamp.type", "LogAppendTime"); + + final WindowedChangelogTopicConfig topicConfig = new WindowedChangelogTopicConfig("name", configs); + + final Map properties = topicConfig.getProperties(Collections.emptyMap(), 0); + assertEquals("LogAppendTime", properties.get(TopicConfig.MESSAGE_TIMESTAMP_TYPE_CONFIG)); + } + + @Test + public void shouldUseSuppliedConfigsForUnwindowedChangelogConfig() { + final Map configs = new HashMap<>(); + configs.put("retention.ms", "1000"); + configs.put("retention.bytes", "10000"); + configs.put("message.timestamp.type", "LogAppendTime"); + + final UnwindowedChangelogTopicConfig topicConfig = new UnwindowedChangelogTopicConfig("name", configs); + + final Map properties = topicConfig.getProperties(Collections.emptyMap(), 0); + assertEquals("1000", properties.get(TopicConfig.RETENTION_MS_CONFIG)); + assertEquals("10000", properties.get(TopicConfig.RETENTION_BYTES_CONFIG)); + assertEquals("LogAppendTime", properties.get(TopicConfig.MESSAGE_TIMESTAMP_TYPE_CONFIG)); + } + + @Test + public void shouldUseSuppliedConfigsForRepartitionConfig() { + final Map configs = new HashMap<>(); + configs.put("retention.ms", "1000"); + configs.put("message.timestamp.type", "LogAppendTime"); + + final RepartitionTopicConfig topicConfig = new RepartitionTopicConfig("name", configs); + + final Map properties = topicConfig.getProperties(Collections.emptyMap(), 0); + assertEquals("1000", properties.get(TopicConfig.RETENTION_MS_CONFIG)); + assertEquals("LogAppendTime", properties.get(TopicConfig.MESSAGE_TIMESTAMP_TYPE_CONFIG)); + } +} \ No newline at end of file diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/InternalTopicManagerTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/InternalTopicManagerTest.java new file mode 100644 index 0000000..853d8d7 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/InternalTopicManagerTest.java @@ -0,0 +1,1744 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.processor.internals; + +import org.apache.kafka.clients.admin.AdminClient; +import org.apache.kafka.clients.admin.Config; +import org.apache.kafka.clients.admin.ConfigEntry; +import org.apache.kafka.clients.admin.CreateTopicsOptions; +import org.apache.kafka.clients.admin.CreateTopicsResult; +import org.apache.kafka.clients.admin.CreateTopicsResult.TopicMetadataAndConfig; +import org.apache.kafka.clients.admin.DeleteTopicsResult; +import org.apache.kafka.clients.admin.DescribeConfigsResult; +import org.apache.kafka.clients.admin.DescribeTopicsResult; +import org.apache.kafka.clients.admin.MockAdminClient; +import org.apache.kafka.clients.admin.NewTopic; +import org.apache.kafka.clients.admin.TopicDescription; +import org.apache.kafka.clients.consumer.ConsumerConfig; +import org.apache.kafka.clients.producer.ProducerConfig; +import org.apache.kafka.common.KafkaFuture; +import org.apache.kafka.common.Node; +import org.apache.kafka.common.TopicPartitionInfo; +import org.apache.kafka.common.Uuid; +import org.apache.kafka.common.config.ConfigResource; +import org.apache.kafka.common.config.ConfigResource.Type; +import org.apache.kafka.common.config.TopicConfig; +import org.apache.kafka.common.errors.LeaderNotAvailableException; +import org.apache.kafka.common.errors.TimeoutException; +import org.apache.kafka.common.errors.TopicExistsException; +import org.apache.kafka.common.errors.UnknownTopicOrPartitionException; +import org.apache.kafka.common.errors.UnsupportedVersionException; +import org.apache.kafka.common.internals.KafkaFutureImpl; +import org.apache.kafka.common.message.CreateTopicsRequestData; +import org.apache.kafka.common.message.CreateTopicsRequestData.CreatableReplicaAssignmentCollection; +import org.apache.kafka.common.message.CreateTopicsRequestData.CreatableTopic; +import org.apache.kafka.common.message.CreateTopicsRequestData.CreatableTopicCollection; +import org.apache.kafka.common.requests.CreateTopicsRequest; +import org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.errors.StreamsException; +import org.apache.kafka.streams.processor.internals.InternalTopicManager.ValidationResult; +import org.apache.kafka.streams.processor.internals.testutil.LogCaptureAppender; +import org.easymock.EasyMock; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.stream.Collectors; + +import static org.apache.kafka.common.utils.Utils.mkEntry; +import static org.apache.kafka.common.utils.Utils.mkMap; +import static org.apache.kafka.common.utils.Utils.mkSet; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.anEmptyMap; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.empty; +import static org.hamcrest.Matchers.hasItem; +import static org.hamcrest.Matchers.hasKey; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.not; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.fail; + +public class InternalTopicManagerTest { + private final Node broker1 = new Node(0, "dummyHost-1", 
1234); + private final Node broker2 = new Node(1, "dummyHost-2", 1234); + private final List cluster = new ArrayList(2) { + { + add(broker1); + add(broker2); + } + }; + private final String topic1 = "test_topic"; + private final String topic2 = "test_topic_2"; + private final String topic3 = "test_topic_3"; + private final String topic4 = "test_topic_4"; + private final String topic5 = "test_topic_5"; + private final List singleReplica = Collections.singletonList(broker1); + + private String threadName; + + private MockAdminClient mockAdminClient; + private InternalTopicManager internalTopicManager; + + private final Map config = new HashMap() { + { + put(StreamsConfig.APPLICATION_ID_CONFIG, "app-id"); + put(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, broker1.host() + ":" + broker1.port()); + put(StreamsConfig.REPLICATION_FACTOR_CONFIG, 1); + put(StreamsConfig.producerPrefix(ProducerConfig.BATCH_SIZE_CONFIG), 16384); + put(StreamsConfig.consumerPrefix(ConsumerConfig.MAX_POLL_INTERVAL_MS_CONFIG), 100); + put(StreamsConfig.RETRY_BACKOFF_MS_CONFIG, 10); + } + }; + + @Before + public void init() { + threadName = Thread.currentThread().getName(); + + mockAdminClient = new MockAdminClient(cluster, broker1); + internalTopicManager = new InternalTopicManager( + Time.SYSTEM, + mockAdminClient, + new StreamsConfig(config) + ); + } + + @After + public void shutdown() { + mockAdminClient.close(); + } + + @Test + public void shouldCreateTopics() throws Exception { + final InternalTopicConfig internalTopicConfig1 = setupRepartitionTopicConfig(topic1, 1); + final InternalTopicConfig internalTopicConfig2 = setupRepartitionTopicConfig(topic2, 1); + + internalTopicManager.setup(mkMap( + mkEntry(topic1, internalTopicConfig1), + mkEntry(topic2, internalTopicConfig2) + )); + + final Set newlyCreatedTopics = mockAdminClient.listTopics().names().get(); + assertThat(newlyCreatedTopics.size(), is(2)); + assertThat(newlyCreatedTopics, hasItem(topic1)); + assertThat(newlyCreatedTopics, hasItem(topic2)); + } + + @Test + public void shouldNotCreateTopicsWithEmptyInput() throws Exception { + + internalTopicManager.setup(Collections.emptyMap()); + + final Set newlyCreatedTopics = mockAdminClient.listTopics().names().get(); + assertThat(newlyCreatedTopics, empty()); + } + + @Test + public void shouldOnlyRetryNotSuccessfulFuturesDuringSetup() { + final AdminClient admin = EasyMock.createMock(AdminClient.class); + final StreamsConfig streamsConfig = new StreamsConfig(config); + final InternalTopicManager topicManager = new InternalTopicManager(new MockTime(1L), admin, streamsConfig); + final KafkaFutureImpl createTopicFailFuture = new KafkaFutureImpl<>(); + createTopicFailFuture.completeExceptionally(new TopicExistsException("exists")); + final KafkaFutureImpl createTopicSuccessfulFuture = new KafkaFutureImpl<>(); + createTopicSuccessfulFuture.complete( + new TopicMetadataAndConfig(Uuid.randomUuid(), 1, 1, new Config(Collections.emptyList())) + ); + final InternalTopicConfig internalTopicConfig1 = setupRepartitionTopicConfig(topic1, 1); + final InternalTopicConfig internalTopicConfig2 = setupRepartitionTopicConfig(topic2, 1); + final NewTopic newTopic1 = newTopic(topic1, internalTopicConfig1, streamsConfig); + final NewTopic newTopic2 = newTopic(topic2, internalTopicConfig2, streamsConfig); + EasyMock.expect(admin.createTopics(mkSet(newTopic1, newTopic2))) + .andAnswer(() -> new MockCreateTopicsResult(mkMap( + mkEntry(topic1, createTopicSuccessfulFuture), + mkEntry(topic2, createTopicFailFuture) + ))); + 
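+        // The second createTopics() expectation below models the retry pass: only topic2, whose first
+        // future completed exceptionally with TopicExistsException, is re-submitted; topic1 already
+        // succeeded in the first pass and must not be created again.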
EasyMock.expect(admin.createTopics(mkSet(newTopic2))) + .andAnswer(() -> new MockCreateTopicsResult(mkMap( + mkEntry(topic2, createTopicSuccessfulFuture) + ))); + EasyMock.replay(admin); + + topicManager.setup(mkMap( + mkEntry(topic1, internalTopicConfig1), + mkEntry(topic2, internalTopicConfig2) + )); + + EasyMock.verify(admin); + } + + @Test + public void shouldRetryCreateTopicWhenCreationTimesOut() { + shouldRetryCreateTopicWhenRetriableExceptionIsThrown(new TimeoutException("timed out")); + } + + @Test + public void shouldRetryCreateTopicWhenTopicNotYetDeleted() { + shouldRetryCreateTopicWhenRetriableExceptionIsThrown(new TopicExistsException("exists")); + } + + private void shouldRetryCreateTopicWhenRetriableExceptionIsThrown(final Exception retriableException) { + final AdminClient admin = EasyMock.createNiceMock(AdminClient.class); + final StreamsConfig streamsConfig = new StreamsConfig(config); + final InternalTopicManager topicManager = new InternalTopicManager(Time.SYSTEM, admin, streamsConfig); + final KafkaFutureImpl createTopicFailFuture = new KafkaFutureImpl<>(); + createTopicFailFuture.completeExceptionally(retriableException); + final KafkaFutureImpl createTopicSuccessfulFuture = new KafkaFutureImpl<>(); + createTopicSuccessfulFuture.complete( + new TopicMetadataAndConfig(Uuid.randomUuid(), 1, 1, new Config(Collections.emptyList())) + ); + final InternalTopicConfig internalTopicConfig = setupRepartitionTopicConfig(topic1, 1); + final NewTopic newTopic = newTopic(topic1, internalTopicConfig, streamsConfig); + EasyMock.expect(admin.createTopics(mkSet(newTopic))) + .andAnswer(() -> new MockCreateTopicsResult(mkMap( + mkEntry(topic1, createTopicSuccessfulFuture) + ))); + EasyMock.expect(admin.createTopics(mkSet(newTopic))) + .andAnswer(() -> new MockCreateTopicsResult(mkMap( + mkEntry(topic2, createTopicSuccessfulFuture) + ))); + EasyMock.replay(admin); + + topicManager.setup(mkMap( + mkEntry(topic1, internalTopicConfig) + )); + } + + @Test + public void shouldThrowInformativeExceptionForOlderBrokers() { + final AdminClient admin = new MockAdminClient() { + @Override + public CreateTopicsResult createTopics(final Collection newTopics, + final CreateTopicsOptions options) { + final CreatableTopic topicToBeCreated = new CreatableTopic(); + topicToBeCreated.setAssignments(new CreatableReplicaAssignmentCollection()); + topicToBeCreated.setNumPartitions((short) 1); + // set unsupported replication factor for older brokers + topicToBeCreated.setReplicationFactor((short) -1); + + final CreatableTopicCollection topicsToBeCreated = new CreatableTopicCollection(); + topicsToBeCreated.add(topicToBeCreated); + + try { + new CreateTopicsRequest.Builder( + new CreateTopicsRequestData() + .setTopics(topicsToBeCreated) + .setTimeoutMs(0) + .setValidateOnly(options.shouldValidateOnly())) + .build((short) 3); // pass in old unsupported request version for old brokers + + throw new IllegalStateException("Building CreateTopicRequest should have thrown."); + } catch (final UnsupportedVersionException expected) { + final KafkaFutureImpl future = new KafkaFutureImpl<>(); + future.completeExceptionally(expected); + + return new CreateTopicsResult(Collections.singletonMap(topic1, future)) { }; + } + } + }; + + final StreamsConfig streamsConfig = new StreamsConfig(config); + final InternalTopicManager topicManager = new InternalTopicManager(Time.SYSTEM, admin, streamsConfig); + + final InternalTopicConfig topicConfig = new RepartitionTopicConfig(topic1, Collections.emptyMap()); + 
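+        // Give the repartition topic an explicit partition count so that makeReady() can build a concrete
+        // CreateTopicsRequest; the overridden createTopics() above then surfaces the UnsupportedVersionException
+        // an older broker would raise for replication.factor=-1.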
topicConfig.setNumberOfPartitions(1); + + final StreamsException exception = assertThrows( + StreamsException.class, + () -> topicManager.makeReady(Collections.singletonMap(topic1, topicConfig)) + ); + assertThat( + exception.getMessage(), + equalTo("Could not create topic " + topic1 + ", because brokers don't support configuration replication.factor=-1." + + " You can change the replication.factor config or upgrade your brokers to version 2.4 or newer to avoid this error.")); + } + + @Test + public void shouldThrowTimeoutExceptionIfTopicExistsDuringSetup() { + setupTopicInMockAdminClient(topic1, Collections.emptyMap()); + final InternalTopicConfig internalTopicConfig = setupRepartitionTopicConfig(topic1, 1); + + final TimeoutException exception = assertThrows( + TimeoutException.class, + () -> internalTopicManager.setup(Collections.singletonMap(topic1, internalTopicConfig)) + ); + + assertThat( + exception.getMessage(), + is("Could not create internal topics within " + + (Integer) config.get(StreamsConfig.consumerPrefix(ConsumerConfig.MAX_POLL_INTERVAL_MS_CONFIG)) / 2 + + " milliseconds. This can happen if the Kafka cluster is temporarily not available or a topic is marked" + + " for deletion and the broker did not complete its deletion within the timeout." + + " The last errors seen per topic are:" + + " {" + topic1 + "=org.apache.kafka.common.errors.TopicExistsException: Topic test_topic exists already.}") + ); + } + + @Test + public void shouldThrowWhenCreateTopicsThrowsUnexpectedException() { + final AdminClient admin = EasyMock.createNiceMock(AdminClient.class); + final StreamsConfig streamsConfig = new StreamsConfig(config); + final InternalTopicManager topicManager = new InternalTopicManager(Time.SYSTEM, admin, streamsConfig); + final InternalTopicConfig internalTopicConfig = setupRepartitionTopicConfig(topic1, 1); + final KafkaFutureImpl createTopicFailFuture = new KafkaFutureImpl<>(); + createTopicFailFuture.completeExceptionally(new IllegalStateException("Nobody expects the Spanish inquisition")); + final NewTopic newTopic = newTopic(topic1, internalTopicConfig, streamsConfig); + EasyMock.expect(admin.createTopics(mkSet(newTopic))) + .andStubAnswer(() -> new MockCreateTopicsResult(mkMap( + mkEntry(topic1, createTopicFailFuture) + ))); + EasyMock.replay(admin); + + assertThrows(StreamsException.class, () -> topicManager.setup(mkMap( + mkEntry(topic1, internalTopicConfig) + ))); + } + + @Test + public void shouldThrowWhenCreateTopicsResultsDoNotContainTopic() { + final AdminClient admin = EasyMock.createNiceMock(AdminClient.class); + final StreamsConfig streamsConfig = new StreamsConfig(config); + final InternalTopicManager topicManager = new InternalTopicManager(Time.SYSTEM, admin, streamsConfig); + final InternalTopicConfig internalTopicConfig = setupRepartitionTopicConfig(topic1, 1); + final NewTopic newTopic = newTopic(topic1, internalTopicConfig, streamsConfig); + EasyMock.expect(admin.createTopics(mkSet(newTopic))) + .andStubAnswer(() -> new MockCreateTopicsResult(Collections.singletonMap(topic2, new KafkaFutureImpl<>()))); + EasyMock.replay(admin); + + assertThrows( + IllegalStateException.class, + () -> topicManager.setup(Collections.singletonMap(topic1, internalTopicConfig)) + ); + } + + @Test + public void shouldThrowTimeoutExceptionWhenCreateTopicExceedsTimeout() { + final AdminClient admin = EasyMock.createNiceMock(AdminClient.class); + final MockTime time = new MockTime( + (Integer) config.get(StreamsConfig.consumerPrefix(ConsumerConfig.MAX_POLL_INTERVAL_MS_CONFIG)) / 
3 + ); + final StreamsConfig streamsConfig = new StreamsConfig(config); + final InternalTopicManager topicManager = new InternalTopicManager(time, admin, streamsConfig); + final KafkaFutureImpl createTopicFailFuture = new KafkaFutureImpl<>(); + createTopicFailFuture.completeExceptionally(new TimeoutException()); + final InternalTopicConfig internalTopicConfig = setupRepartitionTopicConfig(topic1, 1); + final NewTopic newTopic = newTopic(topic1, internalTopicConfig, streamsConfig); + EasyMock.expect(admin.createTopics(mkSet(newTopic))) + .andStubAnswer(() -> new MockCreateTopicsResult(mkMap(mkEntry(topic1, createTopicFailFuture)))); + EasyMock.replay(admin); + + assertThrows( + TimeoutException.class, + () -> topicManager.setup(Collections.singletonMap(topic1, internalTopicConfig)) + ); + } + + @Test + public void shouldThrowTimeoutExceptionWhenFuturesNeverCompleteDuringSetup() { + final AdminClient admin = EasyMock.createNiceMock(AdminClient.class); + final MockTime time = new MockTime( + (Integer) config.get(StreamsConfig.consumerPrefix(ConsumerConfig.MAX_POLL_INTERVAL_MS_CONFIG)) / 3 + ); + final StreamsConfig streamsConfig = new StreamsConfig(config); + final InternalTopicManager topicManager = new InternalTopicManager(time, admin, streamsConfig); + final KafkaFutureImpl createTopicFutureThatNeverCompletes = new KafkaFutureImpl<>(); + final InternalTopicConfig internalTopicConfig = setupRepartitionTopicConfig(topic1, 1); + final NewTopic newTopic = newTopic(topic1, internalTopicConfig, streamsConfig); + EasyMock.expect(admin.createTopics(mkSet(newTopic))) + .andStubAnswer(() -> new MockCreateTopicsResult(mkMap(mkEntry(topic1, createTopicFutureThatNeverCompletes)))); + EasyMock.replay(admin); + + assertThrows( + TimeoutException.class, + () -> topicManager.setup(Collections.singletonMap(topic1, internalTopicConfig)) + ); + } + + @Test + public void shouldCleanUpWhenUnexpectedExceptionIsThrownDuringSetup() { + final AdminClient admin = EasyMock.createNiceMock(AdminClient.class); + final StreamsConfig streamsConfig = new StreamsConfig(config); + final MockTime time = new MockTime( + (Integer) config.get(StreamsConfig.consumerPrefix(ConsumerConfig.MAX_POLL_INTERVAL_MS_CONFIG)) / 3 + ); + final InternalTopicManager topicManager = new InternalTopicManager(time, admin, streamsConfig); + final InternalTopicConfig internalTopicConfig1 = setupRepartitionTopicConfig(topic1, 1); + final InternalTopicConfig internalTopicConfig2 = setupRepartitionTopicConfig(topic2, 1); + setupCleanUpScenario(admin, streamsConfig, internalTopicConfig1, internalTopicConfig2); + final KafkaFutureImpl deleteTopicSuccessfulFuture = new KafkaFutureImpl<>(); + deleteTopicSuccessfulFuture.complete(null); + EasyMock.expect(admin.deleteTopics(mkSet(topic1))) + .andAnswer(() -> new MockDeleteTopicsResult(mkMap(mkEntry(topic1, deleteTopicSuccessfulFuture)))); + EasyMock.replay(admin); + + assertThrows( + StreamsException.class, + () -> topicManager.setup(mkMap( + mkEntry(topic1, internalTopicConfig1), + mkEntry(topic2, internalTopicConfig2) + )) + ); + + EasyMock.verify(admin); + } + + @Test + public void shouldCleanUpWhenCreateTopicsResultsDoNotContainTopic() { + final AdminClient admin = EasyMock.createNiceMock(AdminClient.class); + final StreamsConfig streamsConfig = new StreamsConfig(config); + final InternalTopicManager topicManager = new InternalTopicManager(Time.SYSTEM, admin, streamsConfig); + final InternalTopicConfig internalTopicConfig1 = setupRepartitionTopicConfig(topic1, 1); + final InternalTopicConfig 
internalTopicConfig2 = setupRepartitionTopicConfig(topic2, 1); + final KafkaFutureImpl createTopicFailFuture1 = new KafkaFutureImpl<>(); + createTopicFailFuture1.completeExceptionally(new TopicExistsException("exists")); + final KafkaFutureImpl createTopicSuccessfulFuture = new KafkaFutureImpl<>(); + createTopicSuccessfulFuture.complete( + new TopicMetadataAndConfig(Uuid.randomUuid(), 1, 1, new Config(Collections.emptyList())) + ); + final NewTopic newTopic1 = newTopic(topic1, internalTopicConfig1, streamsConfig); + final NewTopic newTopic2 = newTopic(topic2, internalTopicConfig2, streamsConfig); + EasyMock.expect(admin.createTopics(mkSet(newTopic1, newTopic2))) + .andAnswer(() -> new MockCreateTopicsResult(mkMap( + mkEntry(topic1, createTopicSuccessfulFuture), + mkEntry(topic2, createTopicFailFuture1) + ))); + EasyMock.expect(admin.createTopics(mkSet(newTopic2))) + .andAnswer(() -> new MockCreateTopicsResult(mkMap( + mkEntry(topic3, createTopicSuccessfulFuture) + ))); + final KafkaFutureImpl deleteTopicSuccessfulFuture = new KafkaFutureImpl<>(); + deleteTopicSuccessfulFuture.complete(null); + EasyMock.expect(admin.deleteTopics(mkSet(topic1))) + .andAnswer(() -> new MockDeleteTopicsResult(mkMap(mkEntry(topic1, deleteTopicSuccessfulFuture)))); + EasyMock.replay(admin); + + assertThrows( + IllegalStateException.class, + () -> topicManager.setup(mkMap( + mkEntry(topic1, internalTopicConfig1), + mkEntry(topic2, internalTopicConfig2) + )) + ); + + EasyMock.verify(admin); + } + + @Test + public void shouldCleanUpWhenCreateTopicsTimesOut() { + final AdminClient admin = EasyMock.createNiceMock(AdminClient.class); + final StreamsConfig streamsConfig = new StreamsConfig(config); + final MockTime time = new MockTime( + (Integer) config.get(StreamsConfig.consumerPrefix(ConsumerConfig.MAX_POLL_INTERVAL_MS_CONFIG)) / 3 + ); + final InternalTopicManager topicManager = new InternalTopicManager(time, admin, streamsConfig); + final InternalTopicConfig internalTopicConfig1 = setupRepartitionTopicConfig(topic1, 1); + final InternalTopicConfig internalTopicConfig2 = setupRepartitionTopicConfig(topic2, 1); + final KafkaFutureImpl createTopicFailFuture1 = new KafkaFutureImpl<>(); + createTopicFailFuture1.completeExceptionally(new TopicExistsException("exists")); + final KafkaFutureImpl createTopicSuccessfulFuture = new KafkaFutureImpl<>(); + createTopicSuccessfulFuture.complete( + new TopicMetadataAndConfig(Uuid.randomUuid(), 1, 1, new Config(Collections.emptyList())) + ); + final NewTopic newTopic1 = newTopic(topic1, internalTopicConfig1, streamsConfig); + final NewTopic newTopic2 = newTopic(topic2, internalTopicConfig2, streamsConfig); + EasyMock.expect(admin.createTopics(mkSet(newTopic1, newTopic2))) + .andAnswer(() -> new MockCreateTopicsResult(mkMap( + mkEntry(topic1, createTopicSuccessfulFuture), + mkEntry(topic2, createTopicFailFuture1) + ))); + final KafkaFutureImpl createTopicFutureThatNeverCompletes = new KafkaFutureImpl<>(); + EasyMock.expect(admin.createTopics(mkSet(newTopic2))) + .andStubAnswer(() -> new MockCreateTopicsResult(mkMap(mkEntry(topic2, createTopicFutureThatNeverCompletes)))); + final KafkaFutureImpl deleteTopicSuccessfulFuture = new KafkaFutureImpl<>(); + deleteTopicSuccessfulFuture.complete(null); + EasyMock.expect(admin.deleteTopics(mkSet(topic1))) + .andAnswer(() -> new MockDeleteTopicsResult(mkMap(mkEntry(topic1, deleteTopicSuccessfulFuture)))); + EasyMock.replay(admin); + + assertThrows( + TimeoutException.class, + () -> topicManager.setup(mkMap( + mkEntry(topic1, 
internalTopicConfig1), + mkEntry(topic2, internalTopicConfig2) + )) + ); + + EasyMock.verify(admin); + } + + @Test + public void shouldRetryDeleteTopicWhenTopicUnknown() { + shouldRetryDeleteTopicWhenRetriableException(new UnknownTopicOrPartitionException()); + } + + @Test + public void shouldRetryDeleteTopicWhenLeaderNotAvailable() { + shouldRetryDeleteTopicWhenRetriableException(new LeaderNotAvailableException("leader not available")); + } + + @Test + public void shouldRetryDeleteTopicWhenFutureTimesOut() { + shouldRetryDeleteTopicWhenRetriableException(new TimeoutException("timed out")); + } + + private void shouldRetryDeleteTopicWhenRetriableException(final Exception retriableException) { + final AdminClient admin = EasyMock.createNiceMock(AdminClient.class); + final StreamsConfig streamsConfig = new StreamsConfig(config); + final InternalTopicManager topicManager = new InternalTopicManager(Time.SYSTEM, admin, streamsConfig); + final InternalTopicConfig internalTopicConfig1 = setupRepartitionTopicConfig(topic1, 1); + final InternalTopicConfig internalTopicConfig2 = setupRepartitionTopicConfig(topic2, 1); + setupCleanUpScenario(admin, streamsConfig, internalTopicConfig1, internalTopicConfig2); + final KafkaFutureImpl deleteTopicFailFuture = new KafkaFutureImpl<>(); + deleteTopicFailFuture.completeExceptionally(retriableException); + final KafkaFutureImpl deleteTopicSuccessfulFuture = new KafkaFutureImpl<>(); + deleteTopicSuccessfulFuture.complete(null); + EasyMock.expect(admin.deleteTopics(mkSet(topic1))) + .andAnswer(() -> new MockDeleteTopicsResult(mkMap(mkEntry(topic1, deleteTopicFailFuture)))) + .andAnswer(() -> new MockDeleteTopicsResult(mkMap(mkEntry(topic1, deleteTopicSuccessfulFuture)))); + EasyMock.replay(admin); + + assertThrows( + StreamsException.class, + () -> topicManager.setup(mkMap( + mkEntry(topic1, internalTopicConfig1), + mkEntry(topic2, internalTopicConfig2) + )) + ); + EasyMock.verify(); + } + + @Test + public void shouldThrowTimeoutExceptionWhenFuturesNeverCompleteDuringCleanUp() { + final AdminClient admin = EasyMock.createNiceMock(AdminClient.class); + final StreamsConfig streamsConfig = new StreamsConfig(config); + final MockTime time = new MockTime( + (Integer) config.get(StreamsConfig.consumerPrefix(ConsumerConfig.MAX_POLL_INTERVAL_MS_CONFIG)) / 3 + ); + final InternalTopicManager topicManager = new InternalTopicManager(time, admin, streamsConfig); + final InternalTopicConfig internalTopicConfig1 = setupRepartitionTopicConfig(topic1, 1); + final InternalTopicConfig internalTopicConfig2 = setupRepartitionTopicConfig(topic2, 1); + setupCleanUpScenario(admin, streamsConfig, internalTopicConfig1, internalTopicConfig2); + final KafkaFutureImpl deleteTopicFutureThatNeverCompletes = new KafkaFutureImpl<>(); + EasyMock.expect(admin.deleteTopics(mkSet(topic1))) + .andStubAnswer(() -> new MockDeleteTopicsResult(mkMap(mkEntry(topic1, deleteTopicFutureThatNeverCompletes)))); + EasyMock.replay(admin); + + assertThrows( + TimeoutException.class, + () -> topicManager.setup(mkMap( + mkEntry(topic1, internalTopicConfig1), + mkEntry(topic2, internalTopicConfig2) + )) + ); + } + + @Test + public void shouldThrowWhenDeleteTopicsThrowsUnexpectedException() { + final AdminClient admin = EasyMock.createNiceMock(AdminClient.class); + final StreamsConfig streamsConfig = new StreamsConfig(config); + final InternalTopicManager topicManager = new InternalTopicManager(Time.SYSTEM, admin, streamsConfig); + final InternalTopicConfig internalTopicConfig1 = setupRepartitionTopicConfig(topic1, 
1); + final InternalTopicConfig internalTopicConfig2 = setupRepartitionTopicConfig(topic2, 1); + setupCleanUpScenario(admin, streamsConfig, internalTopicConfig1, internalTopicConfig2); + final KafkaFutureImpl deleteTopicFailFuture = new KafkaFutureImpl<>(); + deleteTopicFailFuture.completeExceptionally(new IllegalStateException("Nobody expects the Spanish inquisition")); + EasyMock.expect(admin.deleteTopics(mkSet(topic1))) + .andStubAnswer(() -> new MockDeleteTopicsResult(mkMap(mkEntry(topic1, deleteTopicFailFuture)))); + EasyMock.replay(admin); + + assertThrows( + StreamsException.class, + () -> topicManager.setup(mkMap( + mkEntry(topic1, internalTopicConfig1), + mkEntry(topic2, internalTopicConfig2) + )) + ); + } + + private void setupCleanUpScenario(final AdminClient admin, final StreamsConfig streamsConfig, final InternalTopicConfig internalTopicConfig1, final InternalTopicConfig internalTopicConfig2) { + final KafkaFutureImpl createTopicFailFuture1 = new KafkaFutureImpl<>(); + createTopicFailFuture1.completeExceptionally(new TopicExistsException("exists")); + final KafkaFutureImpl createTopicFailFuture2 = new KafkaFutureImpl<>(); + createTopicFailFuture2.completeExceptionally(new IllegalStateException("Nobody expects the Spanish inquisition")); + final KafkaFutureImpl createTopicSuccessfulFuture = new KafkaFutureImpl<>(); + createTopicSuccessfulFuture.complete( + new TopicMetadataAndConfig(Uuid.randomUuid(), 1, 1, new Config(Collections.emptyList())) + ); + final NewTopic newTopic1 = newTopic(topic1, internalTopicConfig1, streamsConfig); + final NewTopic newTopic2 = newTopic(topic2, internalTopicConfig2, streamsConfig); + EasyMock.expect(admin.createTopics(mkSet(newTopic1, newTopic2))) + .andAnswer(() -> new MockCreateTopicsResult(mkMap( + mkEntry(topic1, createTopicSuccessfulFuture), + mkEntry(topic2, createTopicFailFuture1) + ))); + EasyMock.expect(admin.createTopics(mkSet(newTopic2))) + .andAnswer(() -> new MockCreateTopicsResult(mkMap( + mkEntry(topic2, createTopicFailFuture2) + ))); + } + + @Test + public void shouldReturnCorrectPartitionCounts() { + mockAdminClient.addTopic( + false, + topic1, + Collections.singletonList(new TopicPartitionInfo(0, broker1, singleReplica, Collections.emptyList())), + null); + assertEquals(Collections.singletonMap(topic1, 1), + internalTopicManager.getNumPartitions(Collections.singleton(topic1), Collections.emptySet())); + } + + @Test + public void shouldCreateRequiredTopics() throws Exception { + final InternalTopicConfig topicConfig = new RepartitionTopicConfig(topic1, Collections.emptyMap()); + topicConfig.setNumberOfPartitions(1); + final InternalTopicConfig topicConfig2 = new UnwindowedChangelogTopicConfig(topic2, Collections.emptyMap()); + topicConfig2.setNumberOfPartitions(1); + final InternalTopicConfig topicConfig3 = new WindowedChangelogTopicConfig(topic3, Collections.emptyMap()); + topicConfig3.setNumberOfPartitions(1); + + internalTopicManager.makeReady(Collections.singletonMap(topic1, topicConfig)); + internalTopicManager.makeReady(Collections.singletonMap(topic2, topicConfig2)); + internalTopicManager.makeReady(Collections.singletonMap(topic3, topicConfig3)); + + assertEquals(mkSet(topic1, topic2, topic3), mockAdminClient.listTopics().names().get()); + assertEquals(new TopicDescription(topic1, false, new ArrayList() { + { + add(new TopicPartitionInfo(0, broker1, singleReplica, Collections.emptyList())); + } + }), mockAdminClient.describeTopics(Collections.singleton(topic1)).topicNameValues().get(topic1).get()); + assertEquals(new 
TopicDescription(topic2, false, new ArrayList() { + { + add(new TopicPartitionInfo(0, broker1, singleReplica, Collections.emptyList())); + } + }), mockAdminClient.describeTopics(Collections.singleton(topic2)).topicNameValues().get(topic2).get()); + assertEquals(new TopicDescription(topic3, false, new ArrayList() { + { + add(new TopicPartitionInfo(0, broker1, singleReplica, Collections.emptyList())); + } + }), mockAdminClient.describeTopics(Collections.singleton(topic3)).topicNameValues().get(topic3).get()); + + final ConfigResource resource = new ConfigResource(ConfigResource.Type.TOPIC, topic1); + final ConfigResource resource2 = new ConfigResource(ConfigResource.Type.TOPIC, topic2); + final ConfigResource resource3 = new ConfigResource(ConfigResource.Type.TOPIC, topic3); + + assertEquals( + new ConfigEntry(TopicConfig.CLEANUP_POLICY_CONFIG, TopicConfig.CLEANUP_POLICY_DELETE), + mockAdminClient.describeConfigs(Collections.singleton(resource)).values().get(resource).get().get(TopicConfig.CLEANUP_POLICY_CONFIG) + ); + assertEquals( + new ConfigEntry(TopicConfig.CLEANUP_POLICY_CONFIG, TopicConfig.CLEANUP_POLICY_COMPACT), + mockAdminClient.describeConfigs(Collections.singleton(resource2)).values().get(resource2).get().get(TopicConfig.CLEANUP_POLICY_CONFIG) + ); + assertEquals( + new ConfigEntry(TopicConfig.CLEANUP_POLICY_CONFIG, TopicConfig.CLEANUP_POLICY_COMPACT + "," + TopicConfig.CLEANUP_POLICY_DELETE), + mockAdminClient.describeConfigs(Collections.singleton(resource3)).values().get(resource3).get().get(TopicConfig.CLEANUP_POLICY_CONFIG) + ); + } + + @Test + public void shouldCompleteTopicValidationOnRetry() { + final AdminClient admin = EasyMock.createNiceMock(AdminClient.class); + final InternalTopicManager topicManager = new InternalTopicManager( + Time.SYSTEM, + admin, + new StreamsConfig(config) + ); + final TopicPartitionInfo partitionInfo = new TopicPartitionInfo(0, broker1, + Collections.singletonList(broker1), Collections.singletonList(broker1)); + + final KafkaFutureImpl topicDescriptionSuccessFuture = new KafkaFutureImpl<>(); + final KafkaFutureImpl topicDescriptionFailFuture = new KafkaFutureImpl<>(); + topicDescriptionSuccessFuture.complete( + new TopicDescription(topic1, false, Collections.singletonList(partitionInfo), Collections.emptySet()) + ); + topicDescriptionFailFuture.completeExceptionally(new UnknownTopicOrPartitionException("KABOOM!")); + + final KafkaFutureImpl topicCreationFuture = new KafkaFutureImpl<>(); + topicCreationFuture.completeExceptionally(new TopicExistsException("KABOOM!")); + + // let the first describe succeed on topic, and fail on topic2, and then let creation throws topics-existed; + // it should retry with just topic2 and then let it succeed + EasyMock.expect(admin.describeTopics(mkSet(topic1, topic2))) + .andReturn(new MockDescribeTopicsResult(mkMap( + mkEntry(topic1, topicDescriptionSuccessFuture), + mkEntry(topic2, topicDescriptionFailFuture) + ))).once(); + EasyMock.expect(admin.createTopics(Collections.singleton(new NewTopic(topic2, Optional.of(1), Optional.of((short) 1)) + .configs(mkMap(mkEntry(TopicConfig.CLEANUP_POLICY_CONFIG, TopicConfig.CLEANUP_POLICY_COMPACT), + mkEntry(TopicConfig.MESSAGE_TIMESTAMP_TYPE_CONFIG, "CreateTime")))))) + .andReturn(new MockCreateTopicsResult(Collections.singletonMap(topic2, topicCreationFuture))).once(); + EasyMock.expect(admin.describeTopics(Collections.singleton(topic2))) + .andReturn(new MockDescribeTopicsResult(Collections.singletonMap(topic2, topicDescriptionSuccessFuture))); + + 
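+        // replay() switches the mocked admin client into replay mode; the verify(admin) at the end of the
+        // test then asserts that the describe-create-describe choreography recorded above actually happened.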
EasyMock.replay(admin); + + final InternalTopicConfig topicConfig = new UnwindowedChangelogTopicConfig(topic1, Collections.emptyMap()); + topicConfig.setNumberOfPartitions(1); + final InternalTopicConfig topic2Config = new UnwindowedChangelogTopicConfig(topic2, Collections.emptyMap()); + topic2Config.setNumberOfPartitions(1); + topicManager.makeReady(mkMap( + mkEntry(topic1, topicConfig), + mkEntry(topic2, topic2Config) + )); + + EasyMock.verify(admin); + } + + @Test + public void shouldNotCreateTopicIfExistsWithDifferentPartitions() { + mockAdminClient.addTopic( + false, + topic1, + new ArrayList() { + { + add(new TopicPartitionInfo(0, broker1, singleReplica, Collections.emptyList())); + add(new TopicPartitionInfo(1, broker1, singleReplica, Collections.emptyList())); + } + }, + null); + + try { + final InternalTopicConfig internalTopicConfig = new RepartitionTopicConfig(topic1, Collections.emptyMap()); + internalTopicConfig.setNumberOfPartitions(1); + internalTopicManager.makeReady(Collections.singletonMap(topic1, internalTopicConfig)); + fail("Should have thrown StreamsException"); + } catch (final StreamsException expected) { /* pass */ } + } + + @Test + public void shouldNotThrowExceptionIfExistsWithDifferentReplication() { + mockAdminClient.addTopic( + false, + topic1, + Collections.singletonList(new TopicPartitionInfo(0, broker1, cluster, Collections.emptyList())), + null); + + // attempt to create it again with replication 1 + final InternalTopicManager internalTopicManager2 = new InternalTopicManager( + Time.SYSTEM, + mockAdminClient, + new StreamsConfig(config) + ); + + final InternalTopicConfig internalTopicConfig = new RepartitionTopicConfig(topic1, Collections.emptyMap()); + internalTopicConfig.setNumberOfPartitions(1); + internalTopicManager2.makeReady(Collections.singletonMap(topic1, internalTopicConfig)); + } + + @Test + public void shouldNotThrowExceptionForEmptyTopicMap() { + internalTopicManager.makeReady(Collections.emptyMap()); + } + + @Test + public void shouldExhaustRetriesOnTimeoutExceptionForMakeReady() { + mockAdminClient.timeoutNextRequest(1); + + final InternalTopicConfig internalTopicConfig = new RepartitionTopicConfig(topic1, Collections.emptyMap()); + internalTopicConfig.setNumberOfPartitions(1); + try { + internalTopicManager.makeReady(Collections.singletonMap(topic1, internalTopicConfig)); + fail("Should have thrown StreamsException."); + } catch (final StreamsException expected) { + assertEquals(TimeoutException.class, expected.getCause().getClass()); + } + } + + @Test + public void shouldLogWhenTopicNotFoundAndNotThrowException() { + mockAdminClient.addTopic( + false, + topic1, + Collections.singletonList(new TopicPartitionInfo(0, broker1, cluster, Collections.emptyList())), + null); + + final InternalTopicConfig internalTopicConfig = new RepartitionTopicConfig(topic1, Collections.emptyMap()); + internalTopicConfig.setNumberOfPartitions(1); + + final InternalTopicConfig internalTopicConfigII = + new RepartitionTopicConfig("internal-topic", Collections.emptyMap()); + internalTopicConfigII.setNumberOfPartitions(1); + + final Map topicConfigMap = new HashMap<>(); + topicConfigMap.put(topic1, internalTopicConfig); + topicConfigMap.put("internal-topic", internalTopicConfigII); + + LogCaptureAppender.setClassLoggerToDebug(InternalTopicManager.class); + try (final LogCaptureAppender appender = LogCaptureAppender.createAndRegister(InternalTopicManager.class)) { + internalTopicManager.makeReady(topicConfigMap); + + assertThat( + appender.getMessages(), + 
hasItem("stream-thread [" + threadName + "] Topic internal-topic is unknown or not found, hence not existed yet.\n" + + "Error message was: org.apache.kafka.common.errors.UnknownTopicOrPartitionException: Topic internal-topic not found.") + ); + } + } + + @Test + public void shouldCreateTopicWhenTopicLeaderNotAvailableAndThenTopicNotFound() { + final AdminClient admin = EasyMock.createNiceMock(AdminClient.class); + final InternalTopicManager topicManager = new InternalTopicManager( + Time.SYSTEM, + admin, + new StreamsConfig(config) + ); + + final KafkaFutureImpl topicDescriptionLeaderNotAvailableFuture = new KafkaFutureImpl<>(); + topicDescriptionLeaderNotAvailableFuture.completeExceptionally(new LeaderNotAvailableException("Leader Not Available!")); + final KafkaFutureImpl topicDescriptionUnknownTopicFuture = new KafkaFutureImpl<>(); + topicDescriptionUnknownTopicFuture.completeExceptionally(new UnknownTopicOrPartitionException("Unknown Topic!")); + final KafkaFutureImpl topicCreationFuture = new KafkaFutureImpl<>(); + topicCreationFuture.complete(EasyMock.createNiceMock(CreateTopicsResult.TopicMetadataAndConfig.class)); + + EasyMock.expect(admin.describeTopics(Collections.singleton(topic1))) + .andReturn(new MockDescribeTopicsResult( + Collections.singletonMap(topic1, topicDescriptionLeaderNotAvailableFuture))) + .once(); + // we would not need to call create-topics for the first time + EasyMock.expect(admin.describeTopics(Collections.singleton(topic1))) + .andReturn(new MockDescribeTopicsResult( + Collections.singletonMap(topic1, topicDescriptionUnknownTopicFuture))) + .once(); + EasyMock.expect(admin.createTopics(Collections.singleton( + new NewTopic(topic1, Optional.of(1), Optional.of((short) 1)) + .configs(mkMap(mkEntry(TopicConfig.CLEANUP_POLICY_CONFIG, TopicConfig.CLEANUP_POLICY_DELETE), + mkEntry(TopicConfig.MESSAGE_TIMESTAMP_TYPE_CONFIG, "CreateTime"), + mkEntry(TopicConfig.SEGMENT_BYTES_CONFIG, "52428800"), + mkEntry(TopicConfig.RETENTION_MS_CONFIG, "-1")))))) + .andReturn(new MockCreateTopicsResult(Collections.singletonMap(topic1, topicCreationFuture))).once(); + + EasyMock.replay(admin); + + final InternalTopicConfig internalTopicConfig = new RepartitionTopicConfig(topic1, Collections.emptyMap()); + internalTopicConfig.setNumberOfPartitions(1); + topicManager.makeReady(Collections.singletonMap(topic1, internalTopicConfig)); + + EasyMock.verify(admin); + } + + @Test + public void shouldCompleteValidateWhenTopicLeaderNotAvailableAndThenDescribeSuccess() { + final AdminClient admin = EasyMock.createNiceMock(AdminClient.class); + final InternalTopicManager topicManager = new InternalTopicManager( + Time.SYSTEM, + admin, + new StreamsConfig(config) + ); + final TopicPartitionInfo partitionInfo = new TopicPartitionInfo(0, broker1, + Collections.singletonList(broker1), Collections.singletonList(broker1)); + + final KafkaFutureImpl topicDescriptionFailFuture = new KafkaFutureImpl<>(); + topicDescriptionFailFuture.completeExceptionally(new LeaderNotAvailableException("Leader Not Available!")); + final KafkaFutureImpl topicDescriptionSuccessFuture = new KafkaFutureImpl<>(); + topicDescriptionSuccessFuture.complete( + new TopicDescription(topic1, false, Collections.singletonList(partitionInfo), Collections.emptySet()) + ); + + EasyMock.expect(admin.describeTopics(Collections.singleton(topic1))) + .andReturn(new MockDescribeTopicsResult( + Collections.singletonMap(topic1, topicDescriptionFailFuture))) + .once(); + EasyMock.expect(admin.describeTopics(Collections.singleton(topic1))) + 
.andReturn(new MockDescribeTopicsResult( + Collections.singletonMap(topic1, topicDescriptionSuccessFuture))) + .once(); + + EasyMock.replay(admin); + + final InternalTopicConfig internalTopicConfig = new RepartitionTopicConfig(topic1, Collections.emptyMap()); + internalTopicConfig.setNumberOfPartitions(1); + topicManager.makeReady(Collections.singletonMap(topic1, internalTopicConfig)); + + EasyMock.verify(admin); + } + + @Test + public void shouldThrowExceptionWhenKeepsTopicLeaderNotAvailable() { + final AdminClient admin = EasyMock.createNiceMock(AdminClient.class); + final InternalTopicManager topicManager = new InternalTopicManager( + Time.SYSTEM, + admin, + new StreamsConfig(config) + ); + + final KafkaFutureImpl topicDescriptionFailFuture = new KafkaFutureImpl<>(); + topicDescriptionFailFuture.completeExceptionally(new LeaderNotAvailableException("Leader Not Available!")); + + // simulate describeTopics got LeaderNotAvailableException + EasyMock.expect(admin.describeTopics(Collections.singleton(topic1))) + .andReturn(new MockDescribeTopicsResult( + Collections.singletonMap(topic1, topicDescriptionFailFuture))) + .anyTimes(); + + EasyMock.replay(admin); + + final InternalTopicConfig internalTopicConfig = new RepartitionTopicConfig(topic1, Collections.emptyMap()); + internalTopicConfig.setNumberOfPartitions(1); + + final TimeoutException exception = assertThrows( + TimeoutException.class, + () -> topicManager.makeReady(Collections.singletonMap(topic1, internalTopicConfig)) + ); + assertNull(exception.getCause()); + assertThat( + exception.getMessage(), + equalTo("Could not create topics within 50 milliseconds." + + " This can happen if the Kafka cluster is temporarily not available.") + ); + + EasyMock.verify(admin); + } + + @Test + public void shouldExhaustRetriesOnMarkedForDeletionTopic() { + mockAdminClient.addTopic( + false, + topic1, + Collections.singletonList(new TopicPartitionInfo(0, broker1, cluster, Collections.emptyList())), + null); + mockAdminClient.markTopicForDeletion(topic1); + + final InternalTopicConfig internalTopicConfig = new RepartitionTopicConfig(topic1, Collections.emptyMap()); + internalTopicConfig.setNumberOfPartitions(1); + + final TimeoutException exception = assertThrows( + TimeoutException.class, + () -> internalTopicManager.makeReady(Collections.singletonMap(topic1, internalTopicConfig)) + ); + assertNull(exception.getCause()); + assertThat( + exception.getMessage(), + equalTo("Could not create topics within 50 milliseconds." 
+ + " This can happen if the Kafka cluster is temporarily not available.") + ); + } + + @Test + public void shouldValidateSuccessfully() { + setupTopicInMockAdminClient(topic1, repartitionTopicConfig()); + setupTopicInMockAdminClient(topic2, repartitionTopicConfig()); + final InternalTopicConfig internalTopicConfig1 = setupRepartitionTopicConfig(topic1, 1); + final InternalTopicConfig internalTopicConfig2 = setupRepartitionTopicConfig(topic2, 1); + + final ValidationResult validationResult = internalTopicManager.validate(mkMap( + mkEntry(topic1, internalTopicConfig1), + mkEntry(topic2, internalTopicConfig2) + )); + + assertThat(validationResult.missingTopics(), empty()); + assertThat(validationResult.misconfigurationsForTopics(), anEmptyMap()); + } + + @Test + public void shouldValidateSuccessfullyWithEmptyInternalTopics() { + setupTopicInMockAdminClient(topic1, repartitionTopicConfig()); + + final ValidationResult validationResult = internalTopicManager.validate(Collections.emptyMap()); + + assertThat(validationResult.missingTopics(), empty()); + assertThat(validationResult.misconfigurationsForTopics(), anEmptyMap()); + } + + @Test + public void shouldReportMissingTopics() { + final String missingTopic1 = "missingTopic1"; + final String missingTopic2 = "missingTopic2"; + setupTopicInMockAdminClient(topic1, repartitionTopicConfig()); + final InternalTopicConfig internalTopicConfig1 = setupRepartitionTopicConfig(topic1, 1); + final InternalTopicConfig internalTopicConfig2 = setupRepartitionTopicConfig(missingTopic1, 1); + final InternalTopicConfig internalTopicConfig3 = setupRepartitionTopicConfig(missingTopic2, 1); + + final ValidationResult validationResult = internalTopicManager.validate(mkMap( + mkEntry(topic1, internalTopicConfig1), + mkEntry(missingTopic1, internalTopicConfig2), + mkEntry(missingTopic2, internalTopicConfig3) + )); + + final Set missingTopics = validationResult.missingTopics(); + assertThat(missingTopics.size(), is(2)); + assertThat(missingTopics, hasItem(missingTopic1)); + assertThat(missingTopics, hasItem(missingTopic2)); + assertThat(validationResult.misconfigurationsForTopics(), anEmptyMap()); + } + + @Test + public void shouldReportMisconfigurationsOfPartitionCount() { + setupTopicInMockAdminClient(topic1, repartitionTopicConfig()); + setupTopicInMockAdminClient(topic2, repartitionTopicConfig()); + setupTopicInMockAdminClient(topic3, repartitionTopicConfig()); + final InternalTopicConfig internalTopicConfig1 = setupRepartitionTopicConfig(topic1, 2); + final InternalTopicConfig internalTopicConfig2 = setupRepartitionTopicConfig(topic2, 3); + final InternalTopicConfig internalTopicConfig3 = setupRepartitionTopicConfig(topic3, 1); + + final ValidationResult validationResult = internalTopicManager.validate(mkMap( + mkEntry(topic1, internalTopicConfig1), + mkEntry(topic2, internalTopicConfig2), + mkEntry(topic3, internalTopicConfig3) + )); + + final Map> misconfigurationsForTopics = validationResult.misconfigurationsForTopics(); + assertThat(validationResult.missingTopics(), empty()); + assertThat(misconfigurationsForTopics.size(), is(2)); + assertThat(misconfigurationsForTopics, hasKey(topic1)); + assertThat(misconfigurationsForTopics.get(topic1).size(), is(1)); + assertThat( + misconfigurationsForTopics.get(topic1).get(0), + is("Internal topic " + topic1 + " requires 2 partitions, but the existing topic on the broker has 1 partitions.") + ); + assertThat(misconfigurationsForTopics, hasKey(topic2)); + assertThat(misconfigurationsForTopics.get(topic2).size(), is(1)); + 
assertThat( + misconfigurationsForTopics.get(topic2).get(0), + is("Internal topic " + topic2 + " requires 3 partitions, but the existing topic on the broker has 1 partitions.") + ); + assertThat(misconfigurationsForTopics, not(hasKey(topic3))); + } + + @Test + public void shouldReportMisconfigurationsOfCleanupPolicyForUnwindowedChangelogTopics() { + final Map unwindowedChangelogConfigWithDeleteCleanupPolicy = unwindowedChangelogConfig(); + unwindowedChangelogConfigWithDeleteCleanupPolicy.put( + TopicConfig.CLEANUP_POLICY_CONFIG, + TopicConfig.CLEANUP_POLICY_DELETE + ); + setupTopicInMockAdminClient(topic1, unwindowedChangelogConfigWithDeleteCleanupPolicy); + final Map unwindowedChangelogConfigWithDeleteCompactCleanupPolicy = unwindowedChangelogConfig(); + unwindowedChangelogConfigWithDeleteCompactCleanupPolicy.put( + TopicConfig.CLEANUP_POLICY_CONFIG, + TopicConfig.CLEANUP_POLICY_COMPACT + "," + TopicConfig.CLEANUP_POLICY_DELETE + ); + setupTopicInMockAdminClient(topic2, unwindowedChangelogConfigWithDeleteCompactCleanupPolicy); + setupTopicInMockAdminClient(topic3, unwindowedChangelogConfig()); + final InternalTopicConfig internalTopicConfig1 = setupUnwindowedChangelogTopicConfig(topic1, 1); + final InternalTopicConfig internalTopicConfig2 = setupUnwindowedChangelogTopicConfig(topic2, 1); + final InternalTopicConfig internalTopicConfig3 = setupUnwindowedChangelogTopicConfig(topic3, 1); + + final ValidationResult validationResult = internalTopicManager.validate(mkMap( + mkEntry(topic1, internalTopicConfig1), + mkEntry(topic2, internalTopicConfig2), + mkEntry(topic3, internalTopicConfig3) + )); + + final Map> misconfigurationsForTopics = validationResult.misconfigurationsForTopics(); + assertThat(validationResult.missingTopics(), empty()); + assertThat(misconfigurationsForTopics.size(), is(2)); + assertThat(misconfigurationsForTopics, hasKey(topic1)); + assertThat(misconfigurationsForTopics.get(topic1).size(), is(1)); + assertThat( + misconfigurationsForTopics.get(topic1).get(0), + is("Cleanup policy (" + TopicConfig.CLEANUP_POLICY_CONFIG + ") of existing internal topic " + topic1 + " should not contain \"" + + TopicConfig.CLEANUP_POLICY_DELETE + "\".") + ); + assertThat(misconfigurationsForTopics, hasKey(topic2)); + assertThat(misconfigurationsForTopics.get(topic2).size(), is(1)); + assertThat( + misconfigurationsForTopics.get(topic2).get(0), + is("Cleanup policy (" + TopicConfig.CLEANUP_POLICY_CONFIG + ") of existing internal topic " + topic2 + " should not contain \"" + + TopicConfig.CLEANUP_POLICY_DELETE + "\".") + ); + assertThat(misconfigurationsForTopics, not(hasKey(topic3))); + } + + @Test + public void shouldReportMisconfigurationsOfCleanupPolicyForWindowedChangelogTopics() { + final long retentionMs = 1000; + final long shorterRetentionMs = 900; + setupTopicInMockAdminClient(topic1, windowedChangelogConfig(retentionMs)); + setupTopicInMockAdminClient(topic2, windowedChangelogConfig(shorterRetentionMs)); + final Map windowedChangelogConfigOnlyCleanupPolicyCompact = windowedChangelogConfig(retentionMs); + windowedChangelogConfigOnlyCleanupPolicyCompact.put(TopicConfig.CLEANUP_POLICY_CONFIG, TopicConfig.CLEANUP_POLICY_COMPACT); + setupTopicInMockAdminClient(topic3, windowedChangelogConfigOnlyCleanupPolicyCompact); + final Map windowedChangelogConfigOnlyCleanupPolicyDelete = windowedChangelogConfig(shorterRetentionMs); + windowedChangelogConfigOnlyCleanupPolicyDelete.put(TopicConfig.CLEANUP_POLICY_CONFIG, TopicConfig.CLEANUP_POLICY_DELETE); + setupTopicInMockAdminClient(topic4, 
windowedChangelogConfigOnlyCleanupPolicyDelete); + final Map windowedChangelogConfigWithRetentionBytes = windowedChangelogConfig(retentionMs); + windowedChangelogConfigWithRetentionBytes.put(TopicConfig.RETENTION_BYTES_CONFIG, "1024"); + setupTopicInMockAdminClient(topic5, windowedChangelogConfigWithRetentionBytes); + final InternalTopicConfig internalTopicConfig1 = setupWindowedChangelogTopicConfig(topic1, 1, retentionMs); + final InternalTopicConfig internalTopicConfig2 = setupWindowedChangelogTopicConfig(topic2, 1, retentionMs); + final InternalTopicConfig internalTopicConfig3 = setupWindowedChangelogTopicConfig(topic3, 1, retentionMs); + final InternalTopicConfig internalTopicConfig4 = setupWindowedChangelogTopicConfig(topic4, 1, retentionMs); + final InternalTopicConfig internalTopicConfig5 = setupWindowedChangelogTopicConfig(topic5, 1, retentionMs); + + final ValidationResult validationResult = internalTopicManager.validate(mkMap( + mkEntry(topic1, internalTopicConfig1), + mkEntry(topic2, internalTopicConfig2), + mkEntry(topic3, internalTopicConfig3), + mkEntry(topic4, internalTopicConfig4), + mkEntry(topic5, internalTopicConfig5) + )); + + final Map> misconfigurationsForTopics = validationResult.misconfigurationsForTopics(); + assertThat(validationResult.missingTopics(), empty()); + assertThat(misconfigurationsForTopics.size(), is(3)); + assertThat(misconfigurationsForTopics, hasKey(topic2)); + assertThat(misconfigurationsForTopics.get(topic2).size(), is(1)); + assertThat( + misconfigurationsForTopics.get(topic2).get(0), + is("Retention time (" + TopicConfig.RETENTION_MS_CONFIG + ") of existing internal topic " + + topic2 + " is " + shorterRetentionMs + " but should be " + retentionMs + " or larger.") + ); + assertThat(misconfigurationsForTopics, hasKey(topic4)); + assertThat(misconfigurationsForTopics.get(topic4).size(), is(1)); + assertThat( + misconfigurationsForTopics.get(topic4).get(0), + is("Retention time (" + TopicConfig.RETENTION_MS_CONFIG + ") of existing internal topic " + + topic4 + " is " + shorterRetentionMs + " but should be " + retentionMs + " or larger.") + ); + assertThat(misconfigurationsForTopics, hasKey(topic5)); + assertThat(misconfigurationsForTopics.get(topic5).size(), is(1)); + assertThat( + misconfigurationsForTopics.get(topic5).get(0), + is("Retention byte (" + TopicConfig.RETENTION_BYTES_CONFIG + ") of existing internal topic " + + topic5 + " is set but it should be unset.") + ); + assertThat(misconfigurationsForTopics, not(hasKey(topic1))); + assertThat(misconfigurationsForTopics, not(hasKey(topic3))); + } + + @Test + public void shouldReportMisconfigurationsOfCleanupPolicyForRepartitionTopics() { + final long retentionMs = 1000; + setupTopicInMockAdminClient(topic1, repartitionTopicConfig()); + final Map repartitionTopicConfigCleanupPolicyCompact = repartitionTopicConfig(); + repartitionTopicConfigCleanupPolicyCompact.put(TopicConfig.CLEANUP_POLICY_CONFIG, TopicConfig.CLEANUP_POLICY_COMPACT); + setupTopicInMockAdminClient(topic2, repartitionTopicConfigCleanupPolicyCompact); + final Map repartitionTopicConfigCleanupPolicyCompactAndDelete = repartitionTopicConfig(); + repartitionTopicConfigCleanupPolicyCompactAndDelete.put( + TopicConfig.CLEANUP_POLICY_CONFIG, + TopicConfig.CLEANUP_POLICY_COMPACT + "," + TopicConfig.CLEANUP_POLICY_DELETE + ); + setupTopicInMockAdminClient(topic3, repartitionTopicConfigCleanupPolicyCompactAndDelete); + final Map repartitionTopicConfigWithFiniteRetentionMs = repartitionTopicConfig(); + 
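+        // repartition topics are expected to use retention.ms=-1 (Streams purges records explicitly),
+        // so the finite retention configured below must be reported as a misconfiguration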
repartitionTopicConfigWithFiniteRetentionMs.put(TopicConfig.RETENTION_MS_CONFIG, String.valueOf(retentionMs)); + setupTopicInMockAdminClient(topic4, repartitionTopicConfigWithFiniteRetentionMs); + final Map repartitionTopicConfigWithRetentionBytesSet = repartitionTopicConfig(); + repartitionTopicConfigWithRetentionBytesSet.put(TopicConfig.RETENTION_BYTES_CONFIG, "1024"); + setupTopicInMockAdminClient(topic5, repartitionTopicConfigWithRetentionBytesSet); + final InternalTopicConfig internalTopicConfig1 = setupRepartitionTopicConfig(topic1, 1); + final InternalTopicConfig internalTopicConfig2 = setupRepartitionTopicConfig(topic2, 1); + final InternalTopicConfig internalTopicConfig3 = setupRepartitionTopicConfig(topic3, 1); + final InternalTopicConfig internalTopicConfig4 = setupRepartitionTopicConfig(topic4, 1); + final InternalTopicConfig internalTopicConfig5 = setupRepartitionTopicConfig(topic5, 1); + + final ValidationResult validationResult = internalTopicManager.validate(mkMap( + mkEntry(topic1, internalTopicConfig1), + mkEntry(topic2, internalTopicConfig2), + mkEntry(topic3, internalTopicConfig3), + mkEntry(topic4, internalTopicConfig4), + mkEntry(topic5, internalTopicConfig5) + )); + + final Map> misconfigurationsForTopics = validationResult.misconfigurationsForTopics(); + assertThat(validationResult.missingTopics(), empty()); + assertThat(misconfigurationsForTopics.size(), is(4)); + assertThat(misconfigurationsForTopics, hasKey(topic2)); + assertThat(misconfigurationsForTopics.get(topic2).size(), is(1)); + assertThat( + misconfigurationsForTopics.get(topic2).get(0), + is("Cleanup policy (" + TopicConfig.CLEANUP_POLICY_CONFIG + ") of existing internal topic " + + topic2 + " should not contain \"" + TopicConfig.CLEANUP_POLICY_COMPACT + "\".") + ); + assertThat(misconfigurationsForTopics, hasKey(topic3)); + assertThat(misconfigurationsForTopics.get(topic3).size(), is(1)); + assertThat( + misconfigurationsForTopics.get(topic3).get(0), + is("Cleanup policy (" + TopicConfig.CLEANUP_POLICY_CONFIG + ") of existing internal topic " + + topic3 + " should not contain \"" + TopicConfig.CLEANUP_POLICY_COMPACT + "\".") + ); + assertThat(misconfigurationsForTopics, hasKey(topic4)); + assertThat(misconfigurationsForTopics.get(topic4).size(), is(1)); + assertThat( + misconfigurationsForTopics.get(topic4).get(0), + is("Retention time (" + TopicConfig.RETENTION_MS_CONFIG + ") of existing internal topic " + + topic4 + " is " + retentionMs + " but should be -1.") + ); + assertThat(misconfigurationsForTopics, hasKey(topic5)); + assertThat(misconfigurationsForTopics.get(topic5).size(), is(1)); + assertThat( + misconfigurationsForTopics.get(topic5).get(0), + is("Retention byte (" + TopicConfig.RETENTION_BYTES_CONFIG + ") of existing internal topic " + + topic5 + " is set but it should be unset.") + ); + } + + @Test + public void shouldReportMultipleMisconfigurationsForSameTopic() { + final long retentionMs = 1000; + final long shorterRetentionMs = 900; + final Map windowedChangelogConfig = windowedChangelogConfig(shorterRetentionMs); + windowedChangelogConfig.put(TopicConfig.RETENTION_BYTES_CONFIG, "1024"); + setupTopicInMockAdminClient(topic1, windowedChangelogConfig); + final InternalTopicConfig internalTopicConfig1 = setupWindowedChangelogTopicConfig(topic1, 1, retentionMs); + + final ValidationResult validationResult = internalTopicManager.validate(mkMap( + mkEntry(topic1, internalTopicConfig1) + )); + + final Map> misconfigurationsForTopics = validationResult.misconfigurationsForTopics(); + 
assertThat(validationResult.missingTopics(), empty()); + assertThat(misconfigurationsForTopics.size(), is(1)); + assertThat(misconfigurationsForTopics, hasKey(topic1)); + assertThat(misconfigurationsForTopics.get(topic1).size(), is(2)); + assertThat( + misconfigurationsForTopics.get(topic1).get(0), + is("Retention time (" + TopicConfig.RETENTION_MS_CONFIG + ") of existing internal topic " + + topic1 + " is " + shorterRetentionMs + " but should be " + retentionMs + " or larger.") + ); + assertThat( + misconfigurationsForTopics.get(topic1).get(1), + is("Retention byte (" + TopicConfig.RETENTION_BYTES_CONFIG + ") of existing internal topic " + + topic1 + " is set but it should be unset.") + ); + } + + @Test + public void shouldThrowWhenPartitionCountUnknown() { + setupTopicInMockAdminClient(topic1, repartitionTopicConfig()); + final InternalTopicConfig internalTopicConfig = new RepartitionTopicConfig(topic1, Collections.emptyMap()); + + assertThrows( + IllegalStateException.class, + () -> internalTopicManager.validate(Collections.singletonMap(topic1, internalTopicConfig)) + ); + } + + @Test + public void shouldNotThrowExceptionIfTopicExistsWithDifferentReplication() { + setupTopicInMockAdminClient(topic1, repartitionTopicConfig()); + // attempt to create it again with replication 1 + final InternalTopicManager internalTopicManager2 = new InternalTopicManager( + Time.SYSTEM, + mockAdminClient, + new StreamsConfig(config) + ); + + final InternalTopicConfig internalTopicConfig = setupRepartitionTopicConfig(topic1, 1); + final ValidationResult validationResult = + internalTopicManager2.validate(Collections.singletonMap(topic1, internalTopicConfig)); + + assertThat(validationResult.missingTopics(), empty()); + assertThat(validationResult.misconfigurationsForTopics(), anEmptyMap()); + } + + @Test + public void shouldRetryWhenCallsThrowTimeoutExceptionDuringValidation() { + setupTopicInMockAdminClient(topic1, repartitionTopicConfig()); + mockAdminClient.timeoutNextRequest(2); + final InternalTopicConfig internalTopicConfig = setupRepartitionTopicConfig(topic1, 1); + + final ValidationResult validationResult = internalTopicManager.validate(Collections.singletonMap(topic1, internalTopicConfig)); + + assertThat(validationResult.missingTopics(), empty()); + assertThat(validationResult.misconfigurationsForTopics(), anEmptyMap()); + } + + @Test + public void shouldOnlyRetryDescribeTopicsWhenDescribeTopicsThrowsLeaderNotAvailableExceptionDuringValidation() { + final AdminClient admin = EasyMock.createNiceMock(AdminClient.class); + final InternalTopicManager topicManager = new InternalTopicManager( + Time.SYSTEM, + admin, + new StreamsConfig(config) + ); + final KafkaFutureImpl topicDescriptionFailFuture = new KafkaFutureImpl<>(); + topicDescriptionFailFuture.completeExceptionally(new LeaderNotAvailableException("Leader Not Available!")); + final KafkaFutureImpl topicDescriptionSuccessfulFuture = new KafkaFutureImpl<>(); + topicDescriptionSuccessfulFuture.complete(new TopicDescription( + topic1, + false, + Collections.singletonList(new TopicPartitionInfo(0, broker1, cluster, Collections.emptyList())) + )); + EasyMock.expect(admin.describeTopics(Collections.singleton(topic1))) + .andReturn(new MockDescribeTopicsResult(mkMap(mkEntry(topic1, topicDescriptionFailFuture)))) + .andReturn(new MockDescribeTopicsResult(mkMap(mkEntry(topic1, topicDescriptionSuccessfulFuture)))); + final KafkaFutureImpl topicConfigSuccessfulFuture = new KafkaFutureImpl<>(); + topicConfigSuccessfulFuture.complete( + new 
Config(repartitionTopicConfig().entrySet().stream() + .map(entry -> new ConfigEntry(entry.getKey(), entry.getValue())).collect(Collectors.toSet())) + ); + final ConfigResource topicResource = new ConfigResource(Type.TOPIC, topic1); + EasyMock.expect(admin.describeConfigs(Collections.singleton(topicResource))) + .andReturn(new MockDescribeConfigsResult(mkMap(mkEntry(topicResource, topicConfigSuccessfulFuture)))); + EasyMock.replay(admin); + final InternalTopicConfig internalTopicConfig = setupRepartitionTopicConfig(topic1, 1); + + final ValidationResult validationResult = topicManager.validate(Collections.singletonMap(topic1, internalTopicConfig)); + + assertThat(validationResult.missingTopics(), empty()); + assertThat(validationResult.misconfigurationsForTopics(), anEmptyMap()); + EasyMock.verify(admin); + } + + @Test + public void shouldOnlyRetryDescribeConfigsWhenDescribeConfigsThrowsLeaderNotAvailableExceptionDuringValidation() { + final AdminClient admin = EasyMock.createNiceMock(AdminClient.class); + final InternalTopicManager topicManager = new InternalTopicManager( + Time.SYSTEM, + admin, + new StreamsConfig(config) + ); + final KafkaFutureImpl topicDescriptionSuccessfulFuture = new KafkaFutureImpl<>(); + topicDescriptionSuccessfulFuture.complete(new TopicDescription( + topic1, + false, + Collections.singletonList(new TopicPartitionInfo(0, broker1, cluster, Collections.emptyList())) + )); + EasyMock.expect(admin.describeTopics(Collections.singleton(topic1))) + .andReturn(new MockDescribeTopicsResult(mkMap(mkEntry(topic1, topicDescriptionSuccessfulFuture)))); + final KafkaFutureImpl topicConfigsFailFuture = new KafkaFutureImpl<>(); + topicConfigsFailFuture.completeExceptionally(new LeaderNotAvailableException("Leader Not Available!")); + final KafkaFutureImpl topicConfigSuccessfulFuture = new KafkaFutureImpl<>(); + topicConfigSuccessfulFuture.complete( + new Config(repartitionTopicConfig().entrySet().stream() + .map(entry -> new ConfigEntry(entry.getKey(), entry.getValue())).collect(Collectors.toSet())) + ); + final ConfigResource topicResource = new ConfigResource(Type.TOPIC, topic1); + EasyMock.expect(admin.describeConfigs(Collections.singleton(topicResource))) + .andReturn(new MockDescribeConfigsResult(mkMap(mkEntry(topicResource, topicConfigsFailFuture)))) + .andReturn(new MockDescribeConfigsResult(mkMap(mkEntry(topicResource, topicConfigSuccessfulFuture)))); + EasyMock.replay(admin); + final InternalTopicConfig internalTopicConfig = setupRepartitionTopicConfig(topic1, 1); + + final ValidationResult validationResult = topicManager.validate(Collections.singletonMap(topic1, internalTopicConfig)); + + assertThat(validationResult.missingTopics(), empty()); + assertThat(validationResult.misconfigurationsForTopics(), anEmptyMap()); + EasyMock.verify(admin); + } + + @Test + public void shouldOnlyRetryNotSuccessfulFuturesDuringValidation() { + final AdminClient admin = EasyMock.createNiceMock(AdminClient.class); + final InternalTopicManager topicManager = new InternalTopicManager( + Time.SYSTEM, + admin, + new StreamsConfig(config) + ); + final KafkaFutureImpl topicDescriptionFailFuture = new KafkaFutureImpl<>(); + topicDescriptionFailFuture.completeExceptionally(new LeaderNotAvailableException("Leader Not Available!")); + final KafkaFutureImpl topicDescriptionSuccessfulFuture1 = new KafkaFutureImpl<>(); + topicDescriptionSuccessfulFuture1.complete(new TopicDescription( + topic1, + false, + Collections.singletonList(new TopicPartitionInfo(0, broker1, cluster, Collections.emptyList())) + 
)); + final KafkaFutureImpl topicDescriptionSuccessfulFuture2 = new KafkaFutureImpl<>(); + topicDescriptionSuccessfulFuture2.complete(new TopicDescription( + topic2, + false, + Collections.singletonList(new TopicPartitionInfo(0, broker1, cluster, Collections.emptyList())) + )); + EasyMock.expect(admin.describeTopics(mkSet(topic1, topic2))) + .andAnswer(() -> new MockDescribeTopicsResult(mkMap( + mkEntry(topic1, topicDescriptionSuccessfulFuture1), + mkEntry(topic2, topicDescriptionFailFuture) + ))); + EasyMock.expect(admin.describeTopics(mkSet(topic2))) + .andAnswer(() -> new MockDescribeTopicsResult(mkMap( + mkEntry(topic2, topicDescriptionSuccessfulFuture2) + ))); + final KafkaFutureImpl topicConfigSuccessfulFuture = new KafkaFutureImpl<>(); + topicConfigSuccessfulFuture.complete( + new Config(repartitionTopicConfig().entrySet().stream() + .map(entry -> new ConfigEntry(entry.getKey(), entry.getValue())).collect(Collectors.toSet())) + ); + final ConfigResource topicResource1 = new ConfigResource(Type.TOPIC, topic1); + final ConfigResource topicResource2 = new ConfigResource(Type.TOPIC, topic2); + EasyMock.expect(admin.describeConfigs(mkSet(topicResource1, topicResource2))) + .andAnswer(() -> new MockDescribeConfigsResult(mkMap( + mkEntry(topicResource1, topicConfigSuccessfulFuture), + mkEntry(topicResource2, topicConfigSuccessfulFuture) + ))); + EasyMock.replay(admin); + final InternalTopicConfig internalTopicConfig1 = setupRepartitionTopicConfig(topic1, 1); + final InternalTopicConfig internalTopicConfig2 = setupRepartitionTopicConfig(topic2, 1); + + final ValidationResult validationResult = topicManager.validate(mkMap( + mkEntry(topic1, internalTopicConfig1), + mkEntry(topic2, internalTopicConfig2) + )); + + assertThat(validationResult.missingTopics(), empty()); + assertThat(validationResult.misconfigurationsForTopics(), anEmptyMap()); + EasyMock.verify(admin); + } + + @Test + public void shouldThrowWhenDescribeTopicsThrowsUnexpectedExceptionDuringValidation() { + final AdminClient admin = EasyMock.createNiceMock(AdminClient.class); + final InternalTopicManager topicManager = new InternalTopicManager( + Time.SYSTEM, + admin, + new StreamsConfig(config) + ); + final KafkaFutureImpl topicDescriptionFailFuture = new KafkaFutureImpl<>(); + topicDescriptionFailFuture.completeExceptionally(new IllegalStateException("Nobody expects the Spanish inquisition")); + EasyMock.expect(admin.describeTopics(Collections.singleton(topic1))) + .andStubAnswer(() -> new MockDescribeTopicsResult(mkMap(mkEntry(topic1, topicDescriptionFailFuture)))); + EasyMock.replay(admin); + final InternalTopicConfig internalTopicConfig = setupRepartitionTopicConfig(topic1, 1); + + assertThrows(Throwable.class, () -> topicManager.validate(Collections.singletonMap(topic1, internalTopicConfig))); + } + + @Test + public void shouldThrowWhenDescribeConfigsThrowsUnexpectedExceptionDuringValidation() { + final AdminClient admin = EasyMock.createNiceMock(AdminClient.class); + final InternalTopicManager topicManager = new InternalTopicManager( + Time.SYSTEM, + admin, + new StreamsConfig(config) + ); + final KafkaFutureImpl configDescriptionFailFuture = new KafkaFutureImpl<>(); + configDescriptionFailFuture.completeExceptionally(new IllegalStateException("Nobody expects the Spanish inquisition")); + final ConfigResource topicResource = new ConfigResource(Type.TOPIC, topic1); + EasyMock.expect(admin.describeConfigs(Collections.singleton(topicResource))) + .andStubAnswer(() -> new MockDescribeConfigsResult(mkMap(mkEntry(topicResource, 
configDescriptionFailFuture)))); + EasyMock.replay(admin); + final InternalTopicConfig internalTopicConfig = setupRepartitionTopicConfig(topic1, 1); + + assertThrows(Throwable.class, () -> topicManager.validate(Collections.singletonMap(topic1, internalTopicConfig))); + } + + @Test + public void shouldThrowWhenTopicDescriptionsDoNotContainTopicDuringValidation() { + final AdminClient admin = EasyMock.createNiceMock(AdminClient.class); + final InternalTopicManager topicManager = new InternalTopicManager( + Time.SYSTEM, + admin, + new StreamsConfig(config) + ); + final KafkaFutureImpl topicDescriptionSuccessfulFuture = new KafkaFutureImpl<>(); + topicDescriptionSuccessfulFuture.complete(new TopicDescription( + topic1, + false, + Collections.singletonList(new TopicPartitionInfo(0, broker1, cluster, Collections.emptyList())) + )); + EasyMock.expect(admin.describeTopics(Collections.singleton(topic1))) + .andStubAnswer(() -> new MockDescribeTopicsResult(mkMap(mkEntry(topic2, topicDescriptionSuccessfulFuture)))); + final KafkaFutureImpl topicConfigSuccessfulFuture = new KafkaFutureImpl<>(); + topicConfigSuccessfulFuture.complete(new Config(Collections.emptySet())); + final ConfigResource topicResource = new ConfigResource(Type.TOPIC, topic1); + EasyMock.expect(admin.describeConfigs(Collections.singleton(topicResource))) + .andStubAnswer(() -> new MockDescribeConfigsResult(mkMap(mkEntry(topicResource, topicConfigSuccessfulFuture)))); + EasyMock.replay(admin); + final InternalTopicConfig internalTopicConfig = setupRepartitionTopicConfig(topic1, 1); + + assertThrows( + IllegalStateException.class, + () -> topicManager.validate(Collections.singletonMap(topic1, internalTopicConfig)) + ); + } + + @Test + public void shouldThrowWhenConfigDescriptionsDoNotContainTopicDuringValidation() { + final AdminClient admin = EasyMock.createNiceMock(AdminClient.class); + final InternalTopicManager topicManager = new InternalTopicManager( + Time.SYSTEM, + admin, + new StreamsConfig(config) + ); + final KafkaFutureImpl topicDescriptionSuccessfulFuture = new KafkaFutureImpl<>(); + topicDescriptionSuccessfulFuture.complete(new TopicDescription( + topic1, + false, + Collections.singletonList(new TopicPartitionInfo(0, broker1, cluster, Collections.emptyList())) + )); + EasyMock.expect(admin.describeTopics(Collections.singleton(topic1))) + .andStubAnswer(() -> new MockDescribeTopicsResult(mkMap(mkEntry(topic1, topicDescriptionSuccessfulFuture)))); + final KafkaFutureImpl topicConfigSuccessfulFuture = new KafkaFutureImpl<>(); + topicConfigSuccessfulFuture.complete(new Config(Collections.emptySet())); + final ConfigResource topicResource1 = new ConfigResource(Type.TOPIC, topic1); + final ConfigResource topicResource2 = new ConfigResource(Type.TOPIC, topic2); + EasyMock.expect(admin.describeConfigs(Collections.singleton(topicResource1))) + .andStubAnswer(() -> new MockDescribeConfigsResult(mkMap(mkEntry(topicResource2, topicConfigSuccessfulFuture)))); + EasyMock.replay(admin); + final InternalTopicConfig internalTopicConfig = setupRepartitionTopicConfig(topic1, 1); + + assertThrows( + IllegalStateException.class, + () -> topicManager.validate(Collections.singletonMap(topic1, internalTopicConfig)) + ); + } + + @Test + public void shouldThrowWhenConfigDescriptionsDoNotCleanupPolicyForUnwindowedConfigDuringValidation() { + shouldThrowWhenConfigDescriptionsDoNotContainConfigDuringValidation( + setupUnwindowedChangelogTopicConfig(topic1, 1), + configWithoutKey(unwindowedChangelogConfig(), TopicConfig.CLEANUP_POLICY_CONFIG) + ); + 
} + + @Test + public void shouldThrowWhenConfigDescriptionsDoNotContainCleanupPolicyForWindowedConfigDuringValidation() { + final long retentionMs = 1000; + shouldThrowWhenConfigDescriptionsDoNotContainConfigDuringValidation( + setupWindowedChangelogTopicConfig(topic1, 1, retentionMs), + configWithoutKey(windowedChangelogConfig(retentionMs), TopicConfig.CLEANUP_POLICY_CONFIG) + ); + } + + @Test + public void shouldThrowWhenConfigDescriptionsDoNotContainRetentionMsForWindowedConfigDuringValidation() { + final long retentionMs = 1000; + shouldThrowWhenConfigDescriptionsDoNotContainConfigDuringValidation( + setupWindowedChangelogTopicConfig(topic1, 1, retentionMs), + configWithoutKey(windowedChangelogConfig(retentionMs), TopicConfig.RETENTION_MS_CONFIG) + ); + } + + @Test + public void shouldThrowWhenConfigDescriptionsDoNotContainRetentionBytesForWindowedConfigDuringValidation() { + final long retentionMs = 1000; + shouldThrowWhenConfigDescriptionsDoNotContainConfigDuringValidation( + setupWindowedChangelogTopicConfig(topic1, 1, retentionMs), + configWithoutKey(windowedChangelogConfig(retentionMs), TopicConfig.RETENTION_BYTES_CONFIG) + ); + } + + @Test + public void shouldThrowWhenConfigDescriptionsDoNotContainCleanupPolicyForRepartitionConfigDuringValidation() { + shouldThrowWhenConfigDescriptionsDoNotContainConfigDuringValidation( + setupRepartitionTopicConfig(topic1, 1), + configWithoutKey(repartitionTopicConfig(), TopicConfig.CLEANUP_POLICY_CONFIG) + ); + } + + @Test + public void shouldThrowWhenConfigDescriptionsDoNotContainRetentionMsForRepartitionConfigDuringValidation() { + shouldThrowWhenConfigDescriptionsDoNotContainConfigDuringValidation( + setupRepartitionTopicConfig(topic1, 1), + configWithoutKey(repartitionTopicConfig(), TopicConfig.RETENTION_MS_CONFIG) + ); + } + + @Test + public void shouldThrowWhenConfigDescriptionsDoNotContainRetentionBytesForRepartitionConfigDuringValidation() { + shouldThrowWhenConfigDescriptionsDoNotContainConfigDuringValidation( + setupRepartitionTopicConfig(topic1, 1), + configWithoutKey(repartitionTopicConfig(), TopicConfig.RETENTION_BYTES_CONFIG) + ); + } + + private Config configWithoutKey(final Map config, final String key) { + return new Config(config.entrySet().stream() + .filter(entry -> !entry.getKey().equals(key)) + .map(entry -> new ConfigEntry(entry.getKey(), entry.getValue())).collect(Collectors.toSet()) + ); + } + + private void shouldThrowWhenConfigDescriptionsDoNotContainConfigDuringValidation(final InternalTopicConfig streamsSideTopicConfig, + final Config brokerSideTopicConfig) { + final AdminClient admin = EasyMock.createNiceMock(AdminClient.class); + final InternalTopicManager topicManager = new InternalTopicManager( + Time.SYSTEM, + admin, + new StreamsConfig(config) + ); + final KafkaFutureImpl topicDescriptionSuccessfulFuture = new KafkaFutureImpl<>(); + topicDescriptionSuccessfulFuture.complete(new TopicDescription( + topic1, + false, + Collections.singletonList(new TopicPartitionInfo(0, broker1, cluster, Collections.emptyList())) + )); + EasyMock.expect(admin.describeTopics(Collections.singleton(topic1))) + .andStubAnswer(() -> new MockDescribeTopicsResult(mkMap(mkEntry(topic1, topicDescriptionSuccessfulFuture)))); + final KafkaFutureImpl topicConfigSuccessfulFuture = new KafkaFutureImpl<>(); + topicConfigSuccessfulFuture.complete(brokerSideTopicConfig); + final ConfigResource topicResource1 = new ConfigResource(Type.TOPIC, topic1); + EasyMock.expect(admin.describeConfigs(Collections.singleton(topicResource1))) + 
.andStubAnswer(() -> new MockDescribeConfigsResult(mkMap(mkEntry(topicResource1, topicConfigSuccessfulFuture)))); + EasyMock.replay(admin); + + assertThrows( + IllegalStateException.class, + () -> topicManager.validate(Collections.singletonMap(topic1, streamsSideTopicConfig)) + ); + } + + @Test + public void shouldThrowTimeoutExceptionWhenTimeoutIsExceededDuringValidation() { + final AdminClient admin = EasyMock.createNiceMock(AdminClient.class); + final MockTime time = new MockTime( + (Integer) config.get(StreamsConfig.consumerPrefix(ConsumerConfig.MAX_POLL_INTERVAL_MS_CONFIG)) / 3 + ); + final InternalTopicManager topicManager = new InternalTopicManager( + time, + admin, + new StreamsConfig(config) + ); + final KafkaFutureImpl topicDescriptionFailFuture = new KafkaFutureImpl<>(); + topicDescriptionFailFuture.completeExceptionally(new TimeoutException()); + EasyMock.expect(admin.describeTopics(Collections.singleton(topic1))) + .andStubAnswer(() -> new MockDescribeTopicsResult(mkMap(mkEntry(topic1, topicDescriptionFailFuture)))); + final KafkaFutureImpl topicConfigSuccessfulFuture = new KafkaFutureImpl<>(); + topicConfigSuccessfulFuture.complete( + new Config(repartitionTopicConfig().entrySet().stream() + .map(entry -> new ConfigEntry(entry.getKey(), entry.getValue())).collect(Collectors.toSet())) + ); + final ConfigResource topicResource = new ConfigResource(Type.TOPIC, topic1); + EasyMock.expect(admin.describeConfigs(Collections.singleton(topicResource))) + .andStubAnswer(() -> new MockDescribeConfigsResult(mkMap(mkEntry(topicResource, topicConfigSuccessfulFuture)))); + EasyMock.replay(admin); + final InternalTopicConfig internalTopicConfig = setupRepartitionTopicConfig(topic1, 1); + + assertThrows( + TimeoutException.class, + () -> topicManager.validate(Collections.singletonMap(topic1, internalTopicConfig)) + ); + } + + @Test + public void shouldThrowTimeoutExceptionWhenFuturesNeverCompleteDuringValidation() { + final AdminClient admin = EasyMock.createNiceMock(AdminClient.class); + final MockTime time = new MockTime( + (Integer) config.get(StreamsConfig.consumerPrefix(ConsumerConfig.MAX_POLL_INTERVAL_MS_CONFIG)) / 3 + ); + final InternalTopicManager topicManager = new InternalTopicManager( + time, + admin, + new StreamsConfig(config) + ); + final KafkaFutureImpl topicDescriptionFutureThatNeverCompletes = new KafkaFutureImpl<>(); + EasyMock.expect(admin.describeTopics(Collections.singleton(topic1))) + .andStubAnswer(() -> new MockDescribeTopicsResult(mkMap(mkEntry(topic1, topicDescriptionFutureThatNeverCompletes)))); + final KafkaFutureImpl topicConfigSuccessfulFuture = new KafkaFutureImpl<>(); + topicConfigSuccessfulFuture.complete( + new Config(repartitionTopicConfig().entrySet().stream() + .map(entry -> new ConfigEntry(entry.getKey(), entry.getValue())).collect(Collectors.toSet())) + ); + final ConfigResource topicResource = new ConfigResource(Type.TOPIC, topic1); + EasyMock.expect(admin.describeConfigs(Collections.singleton(topicResource))) + .andStubAnswer(() -> new MockDescribeConfigsResult(mkMap(mkEntry(topicResource, topicConfigSuccessfulFuture)))); + EasyMock.replay(admin); + final InternalTopicConfig internalTopicConfig = setupRepartitionTopicConfig(topic1, 1); + + assertThrows( + TimeoutException.class, + () -> topicManager.validate(Collections.singletonMap(topic1, internalTopicConfig)) + ); + } + + private NewTopic newTopic(final String topicName, + final InternalTopicConfig topicConfig, + final StreamsConfig streamsConfig) { + return new NewTopic( + topicName, + 
            topicConfig.numberOfPartitions(),
+            Optional.of(streamsConfig.getInt(StreamsConfig.REPLICATION_FACTOR_CONFIG).shortValue())
+        ).configs(topicConfig.getProperties(
+            Collections.emptyMap(),
+            streamsConfig.getLong(StreamsConfig.WINDOW_STORE_CHANGE_LOG_ADDITIONAL_RETENTION_MS_CONFIG))
+        );
+    }
+
+    private Map<String, String> repartitionTopicConfig() {
+        return mkMap(
+            mkEntry(TopicConfig.CLEANUP_POLICY_CONFIG, TopicConfig.CLEANUP_POLICY_DELETE),
+            mkEntry(TopicConfig.RETENTION_MS_CONFIG, "-1"),
+            mkEntry(TopicConfig.RETENTION_BYTES_CONFIG, null)
+        );
+    }
+
+    private Map<String, String> unwindowedChangelogConfig() {
+        return mkMap(
+            mkEntry(TopicConfig.CLEANUP_POLICY_CONFIG, TopicConfig.CLEANUP_POLICY_COMPACT)
+        );
+    }
+
+    private Map<String, String> windowedChangelogConfig(final long retentionMs) {
+        return mkMap(
+            mkEntry(TopicConfig.CLEANUP_POLICY_CONFIG, TopicConfig.CLEANUP_POLICY_COMPACT + "," + TopicConfig.CLEANUP_POLICY_DELETE),
+            mkEntry(TopicConfig.RETENTION_MS_CONFIG, String.valueOf(retentionMs)),
+            mkEntry(TopicConfig.RETENTION_BYTES_CONFIG, null)
+        );
+    }
+
+    private void setupTopicInMockAdminClient(final String topic, final Map<String, String> topicConfig) {
+        mockAdminClient.addTopic(
+            false,
+            topic,
+            Collections.singletonList(new TopicPartitionInfo(0, broker1, cluster, Collections.emptyList())),
+            topicConfig
+        );
+    }
+
+    private InternalTopicConfig setupUnwindowedChangelogTopicConfig(final String topicName,
+                                                                    final int partitionCount) {
+        final InternalTopicConfig internalTopicConfig =
+            new UnwindowedChangelogTopicConfig(topicName, Collections.emptyMap());
+        internalTopicConfig.setNumberOfPartitions(partitionCount);
+        return internalTopicConfig;
+    }
+
+    private InternalTopicConfig setupWindowedChangelogTopicConfig(final String topicName,
+                                                                  final int partitionCount,
+                                                                  final long retentionMs) {
+        final InternalTopicConfig internalTopicConfig = new WindowedChangelogTopicConfig(
+            topicName,
+            mkMap(mkEntry(TopicConfig.RETENTION_MS_CONFIG, String.valueOf(retentionMs)))
+        );
+        internalTopicConfig.setNumberOfPartitions(partitionCount);
+        return internalTopicConfig;
+    }
+
+    private InternalTopicConfig setupRepartitionTopicConfig(final String topicName,
+                                                            final int partitionCount) {
+        final InternalTopicConfig internalTopicConfig = new RepartitionTopicConfig(topicName, Collections.emptyMap());
+        internalTopicConfig.setNumberOfPartitions(partitionCount);
+        return internalTopicConfig;
+    }
+
+    private static class MockCreateTopicsResult extends CreateTopicsResult {
+        MockCreateTopicsResult(final Map<String, KafkaFuture<CreateTopicsResult.TopicMetadataAndConfig>> futures) {
+            super(futures);
+        }
+    }
+
+    private static class MockDeleteTopicsResult extends DeleteTopicsResult {
+        MockDeleteTopicsResult(final Map<String, KafkaFuture<Void>> futures) {
+            super(null, futures);
+        }
+    }
+
+    private static class MockDescribeTopicsResult extends DescribeTopicsResult {
+        MockDescribeTopicsResult(final Map<String, KafkaFuture<TopicDescription>> futures) {
+            super(null, futures);
+        }
+    }
+
+    private static class MockDescribeConfigsResult extends DescribeConfigsResult {
+        MockDescribeConfigsResult(final Map<ConfigResource, KafkaFuture<Config>> futures) {
+            super(futures);
+        }
+    }
+}
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/InternalTopologyBuilderTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/InternalTopologyBuilderTest.java
new file mode 100644
index 0000000..74f4059
--- /dev/null
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/InternalTopologyBuilderTest.java
@@ -0,0 +1,1186 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.processor.internals; + +import org.apache.kafka.clients.consumer.OffsetResetStrategy; +import org.apache.kafka.common.config.TopicConfig; +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.Topology; +import org.apache.kafka.streams.TopologyDescription; +import org.apache.kafka.streams.errors.TopologyException; +import org.apache.kafka.streams.processor.StateStore; +import org.apache.kafka.streams.processor.TopicNameExtractor; +import org.apache.kafka.streams.processor.api.Processor; +import org.apache.kafka.streams.processor.internals.InternalTopologyBuilder.SubtopologyDescription; +import org.apache.kafka.streams.processor.internals.TopologyMetadata.Subtopology; +import org.apache.kafka.streams.state.KeyValueStore; +import org.apache.kafka.streams.state.StoreBuilder; +import org.apache.kafka.streams.state.Stores; +import org.apache.kafka.test.MockApiProcessor; +import org.apache.kafka.test.MockApiProcessorSupplier; +import org.apache.kafka.test.MockKeyValueStoreBuilder; +import org.apache.kafka.test.MockTimestampExtractor; +import org.apache.kafka.test.StreamsTestUtils; +import org.junit.Test; + +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.regex.Pattern; + +import static java.time.Duration.ofSeconds; +import static java.util.Arrays.asList; +import static org.apache.kafka.common.utils.Utils.mkEntry; +import static org.apache.kafka.common.utils.Utils.mkMap; +import static org.apache.kafka.common.utils.Utils.mkProperties; +import static org.apache.kafka.common.utils.Utils.mkSet; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.SUBTOPOLOGY_0; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.SUBTOPOLOGY_1; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.SUBTOPOLOGY_2; + +import static org.hamcrest.CoreMatchers.containsString; +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.not; +import static org.hamcrest.core.IsInstanceOf.instanceOf; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotEquals; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +public class InternalTopologyBuilderTest { + + private final Serde 
stringSerde = Serdes.String(); + private final InternalTopologyBuilder builder = new InternalTopologyBuilder(); + private final StoreBuilder storeBuilder = new MockKeyValueStoreBuilder("testStore", false); + + @Test + public void shouldAddSourceWithOffsetReset() { + final String earliestTopic = "earliestTopic"; + final String latestTopic = "latestTopic"; + + builder.addSource(Topology.AutoOffsetReset.EARLIEST, "source", null, null, null, earliestTopic); + builder.addSource(Topology.AutoOffsetReset.LATEST, "source2", null, null, null, latestTopic); + builder.initializeSubscription(); + + assertThat(builder.offsetResetStrategy(earliestTopic), equalTo(OffsetResetStrategy.EARLIEST)); + assertThat(builder.offsetResetStrategy(latestTopic), equalTo(OffsetResetStrategy.LATEST)); + } + + @Test + public void shouldAddSourcePatternWithOffsetReset() { + final String earliestTopicPattern = "earliest.*Topic"; + final String latestTopicPattern = "latest.*Topic"; + + builder.addSource(Topology.AutoOffsetReset.EARLIEST, "source", null, null, null, Pattern.compile(earliestTopicPattern)); + builder.addSource(Topology.AutoOffsetReset.LATEST, "source2", null, null, null, Pattern.compile(latestTopicPattern)); + builder.initializeSubscription(); + + assertThat(builder.offsetResetStrategy("earliestTestTopic"), equalTo(OffsetResetStrategy.EARLIEST)); + assertThat(builder.offsetResetStrategy("latestTestTopic"), equalTo(OffsetResetStrategy.LATEST)); + } + + @Test + public void shouldAddSourceWithoutOffsetReset() { + builder.addSource(null, "source", null, stringSerde.deserializer(), stringSerde.deserializer(), "test-topic"); + builder.initializeSubscription(); + + assertEquals(Collections.singletonList("test-topic"), builder.sourceTopicCollection()); + + assertThat(builder.offsetResetStrategy("test-topic"), equalTo(OffsetResetStrategy.NONE)); + } + + @Test + public void shouldAddPatternSourceWithoutOffsetReset() { + final Pattern expectedPattern = Pattern.compile("test-.*"); + + builder.addSource(null, "source", null, stringSerde.deserializer(), stringSerde.deserializer(), Pattern.compile("test-.*")); + builder.initializeSubscription(); + + assertThat(expectedPattern.pattern(), builder.sourceTopicsPatternString(), equalTo("test-.*")); + + assertThat(builder.offsetResetStrategy("test-topic"), equalTo(OffsetResetStrategy.NONE)); + } + + @Test + public void shouldNotAllowOffsetResetSourceWithoutTopics() { + assertThrows(TopologyException.class, () -> builder.addSource(Topology.AutoOffsetReset.EARLIEST, "source", + null, stringSerde.deserializer(), stringSerde.deserializer())); + } + + @Test + public void shouldNotAllowOffsetResetSourceWithDuplicateSourceName() { + builder.addSource(Topology.AutoOffsetReset.EARLIEST, "source", null, stringSerde.deserializer(), stringSerde.deserializer(), "topic-1"); + try { + builder.addSource(Topology.AutoOffsetReset.LATEST, "source", null, stringSerde.deserializer(), stringSerde.deserializer(), "topic-2"); + fail("Should throw TopologyException for duplicate source name"); + } catch (final TopologyException expected) { /* ok */ } + } + + @Test + public void testAddSourceWithSameName() { + builder.addSource(null, "source", null, null, null, "topic-1"); + try { + builder.addSource(null, "source", null, null, null, "topic-2"); + fail("Should throw TopologyException with source name conflict"); + } catch (final TopologyException expected) { /* ok */ } + } + + @Test + public void testAddSourceWithSameTopic() { + builder.addSource(null, "source", null, null, null, "topic-1"); + try { + 
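+            // adding a second source that subscribes to the already-registered "topic-1" must fail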
builder.addSource(null, "source-2", null, null, null, "topic-1"); + fail("Should throw TopologyException with topic conflict"); + } catch (final TopologyException expected) { /* ok */ } + } + + @Test + public void testAddProcessorWithSameName() { + builder.addSource(null, "source", null, null, null, "topic-1"); + builder.addProcessor("processor", new MockApiProcessorSupplier<>(), "source"); + try { + builder.addProcessor("processor", new MockApiProcessorSupplier<>(), "source"); + fail("Should throw TopologyException with processor name conflict"); + } catch (final TopologyException expected) { /* ok */ } + } + + @Test + public void testAddProcessorWithWrongParent() { + assertThrows(TopologyException.class, () -> builder.addProcessor("processor", new MockApiProcessorSupplier<>(), "source")); + } + + @Test + public void testAddProcessorWithSelfParent() { + assertThrows(TopologyException.class, () -> builder.addProcessor("processor", new MockApiProcessorSupplier<>(), "processor")); + } + + @Test + public void testAddProcessorWithEmptyParents() { + assertThrows(TopologyException.class, () -> builder.addProcessor("processor", new MockApiProcessorSupplier<>())); + } + + @Test + public void testAddProcessorWithNullParents() { + assertThrows(NullPointerException.class, () -> builder.addProcessor("processor", + new MockApiProcessorSupplier<>(), (String) null)); + } + + @Test + public void testAddProcessorWithBadSupplier() { + final Processor processor = new MockApiProcessor<>(); + final IllegalArgumentException exception = assertThrows( + IllegalArgumentException.class, + () -> builder.addProcessor("processor", () -> processor, (String) null) + ); + assertThat(exception.getMessage(), containsString("#get() must return a new object each time it is called.")); + } + + @Test + public void testAddGlobalStoreWithBadSupplier() { + final org.apache.kafka.streams.processor.api.Processor processor = new MockApiProcessorSupplier().get(); + final IllegalArgumentException exception = assertThrows( + IllegalArgumentException.class, + () -> builder.addGlobalStore( + new MockKeyValueStoreBuilder("global-store", false).withLoggingDisabled(), + "globalSource", + null, + null, + null, + "globalTopic", + "global-processor", + () -> processor) + ); + assertThat(exception.getMessage(), containsString("#get() must return a new object each time it is called.")); + } + + @Test + public void testAddSinkWithSameName() { + builder.addSource(null, "source", null, null, null, "topic-1"); + builder.addSink("sink", "topic-2", null, null, null, "source"); + try { + builder.addSink("sink", "topic-3", null, null, null, "source"); + fail("Should throw TopologyException with sink name conflict"); + } catch (final TopologyException expected) { /* ok */ } + } + + @Test + public void testAddSinkWithWrongParent() { + assertThrows(TopologyException.class, () -> builder.addSink("sink", "topic-2", null, null, null, "source")); + } + + @Test + public void testAddSinkWithSelfParent() { + assertThrows(TopologyException.class, () -> builder.addSink("sink", "topic-2", null, null, null, "sink")); + } + + + @Test + public void testAddSinkWithEmptyParents() { + assertThrows(TopologyException.class, () -> builder.addSink("sink", "topic", null, null, null)); + } + + @Test + public void testAddSinkWithNullParents() { + assertThrows(NullPointerException.class, () -> builder.addSink("sink", "topic", null, + null, null, (String) null)); + } + + @Test + public void testAddSinkConnectedWithParent() { + builder.addSource(null, "source", null, null, null, 
"source-topic"); + builder.addSink("sink", "dest-topic", null, null, null, "source"); + + final Map> nodeGroups = builder.nodeGroups(); + final Set nodeGroup = nodeGroups.get(0); + + assertTrue(nodeGroup.contains("sink")); + assertTrue(nodeGroup.contains("source")); + } + + @Test + public void testAddSinkConnectedWithMultipleParent() { + builder.addSource(null, "source", null, null, null, "source-topic"); + builder.addSource(null, "sourceII", null, null, null, "source-topicII"); + builder.addSink("sink", "dest-topic", null, null, null, "source", "sourceII"); + + final Map> nodeGroups = builder.nodeGroups(); + final Set nodeGroup = nodeGroups.get(0); + + assertTrue(nodeGroup.contains("sink")); + assertTrue(nodeGroup.contains("source")); + assertTrue(nodeGroup.contains("sourceII")); + } + + @Test + public void testOnlyTopicNameSourceTopics() { + builder.setApplicationId("X"); + builder.addSource(null, "source-1", null, null, null, "topic-1"); + builder.addSource(null, "source-2", null, null, null, "topic-2"); + builder.addSource(null, "source-3", null, null, null, "topic-3"); + builder.addInternalTopic("topic-3", InternalTopicProperties.empty()); + builder.initializeSubscription(); + + assertFalse(builder.usesPatternSubscription()); + assertEquals(Arrays.asList("X-topic-3", "topic-1", "topic-2"), builder.sourceTopicCollection()); + } + + @Test + public void testPatternAndNameSourceTopics() { + final Pattern sourcePattern = Pattern.compile("topic-4|topic-5"); + + builder.setApplicationId("X"); + builder.addSource(null, "source-1", null, null, null, "topic-1"); + builder.addSource(null, "source-2", null, null, null, "topic-2"); + builder.addSource(null, "source-3", null, null, null, "topic-3"); + builder.addSource(null, "source-4", null, null, null, sourcePattern); + + builder.addInternalTopic("topic-3", InternalTopicProperties.empty()); + builder.initializeSubscription(); + + final Pattern expectedPattern = Pattern.compile("X-topic-3|topic-1|topic-2|topic-4|topic-5"); + final String patternString = builder.sourceTopicsPatternString(); + + assertEquals(expectedPattern.pattern(), Pattern.compile(patternString).pattern()); + } + + @Test + public void testPatternSourceTopicsWithGlobalTopics() { + builder.setApplicationId("X"); + builder.addSource(null, "source-1", null, null, null, Pattern.compile("topic-1")); + builder.addSource(null, "source-2", null, null, null, Pattern.compile("topic-2")); + builder.addGlobalStore( + new MockKeyValueStoreBuilder("global-store", false).withLoggingDisabled(), + "globalSource", + null, + null, + null, + "globalTopic", + "global-processor", + new MockApiProcessorSupplier<>() + ); + builder.initializeSubscription(); + + final Pattern expectedPattern = Pattern.compile("topic-1|topic-2"); + + final String patternString = builder.sourceTopicsPatternString(); + + assertEquals(expectedPattern.pattern(), Pattern.compile(patternString).pattern()); + } + + @Test + public void testNameSourceTopicsWithGlobalTopics() { + builder.setApplicationId("X"); + builder.addSource(null, "source-1", null, null, null, "topic-1"); + builder.addSource(null, "source-2", null, null, null, "topic-2"); + builder.addGlobalStore( + new MockKeyValueStoreBuilder("global-store", false).withLoggingDisabled(), + "globalSource", + null, + null, + null, + "globalTopic", + "global-processor", + new MockApiProcessorSupplier<>() + ); + builder.initializeSubscription(); + + assertThat(builder.sourceTopicCollection(), equalTo(asList("topic-1", "topic-2"))); + } + + @Test + public void 
testPatternSourceTopic() { + final Pattern expectedPattern = Pattern.compile("topic-\\d"); + builder.addSource(null, "source-1", null, null, null, expectedPattern); + builder.initializeSubscription(); + final String patternString = builder.sourceTopicsPatternString(); + + assertEquals(expectedPattern.pattern(), Pattern.compile(patternString).pattern()); + } + + @Test + public void testAddMoreThanOnePatternSourceNode() { + final Pattern expectedPattern = Pattern.compile("topics[A-Z]|.*-\\d"); + builder.addSource(null, "source-1", null, null, null, Pattern.compile("topics[A-Z]")); + builder.addSource(null, "source-2", null, null, null, Pattern.compile(".*-\\d")); + builder.initializeSubscription(); + final String patternString = builder.sourceTopicsPatternString(); + + assertEquals(expectedPattern.pattern(), Pattern.compile(patternString).pattern()); + } + + @Test + public void testSubscribeTopicNameAndPattern() { + final Pattern expectedPattern = Pattern.compile("topic-bar|topic-foo|.*-\\d"); + builder.addSource(null, "source-1", null, null, null, "topic-foo", "topic-bar"); + builder.addSource(null, "source-2", null, null, null, Pattern.compile(".*-\\d")); + builder.initializeSubscription(); + final String patternString = builder.sourceTopicsPatternString(); + + assertEquals(expectedPattern.pattern(), Pattern.compile(patternString).pattern()); + } + + @Test + public void testPatternMatchesAlreadyProvidedTopicSource() { + builder.addSource(null, "source-1", null, null, null, "foo"); + try { + builder.addSource(null, "source-2", null, null, null, Pattern.compile("f.*")); + fail("Should throw TopologyException with topic name/pattern conflict"); + } catch (final TopologyException expected) { /* ok */ } + } + + @Test + public void testNamedTopicMatchesAlreadyProvidedPattern() { + builder.addSource(null, "source-1", null, null, null, Pattern.compile("f.*")); + try { + builder.addSource(null, "source-2", null, null, null, "foo"); + fail("Should throw TopologyException with topic name/pattern conflict"); + } catch (final TopologyException expected) { /* ok */ } + } + + @Test + public void testAddStateStoreWithNonExistingProcessor() { + assertThrows(TopologyException.class, () -> builder.addStateStore(storeBuilder, "no-such-processor")); + } + + @Test + public void testAddStateStoreWithSource() { + builder.addSource(null, "source-1", null, null, null, "topic-1"); + try { + builder.addStateStore(storeBuilder, "source-1"); + fail("Should throw TopologyException with store cannot be added to source"); + } catch (final TopologyException expected) { /* ok */ } + } + + @Test + public void testAddStateStoreWithSink() { + builder.addSource(null, "source-1", null, null, null, "topic-1"); + builder.addSink("sink-1", "topic-1", null, null, null, "source-1"); + try { + builder.addStateStore(storeBuilder, "sink-1"); + fail("Should throw TopologyException with store cannot be added to sink"); + } catch (final TopologyException expected) { /* ok */ } + } + + @Test + public void shouldNotAllowToAddStoresWithSameName() { + final StoreBuilder> otherBuilder = + new MockKeyValueStoreBuilder("testStore", false); + + builder.addStateStore(storeBuilder); + + final TopologyException exception = assertThrows( + TopologyException.class, + () -> builder.addStateStore(otherBuilder) + ); + + assertThat( + exception.getMessage(), + equalTo("Invalid topology: A different StateStore has already been added with the name testStore") + ); + } + + @Test + public void shouldNotAllowToAddStoresWithSameNameWhenFirstStoreIsGlobal() { + 
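+        // the global store claims the name "testStore" first, so adding a regular state store with
+        // the same name afterwards must be rejected with a TopologyException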
final StoreBuilder> globalBuilder = + new MockKeyValueStoreBuilder("testStore", false).withLoggingDisabled(); + + builder.addGlobalStore( + globalBuilder, + "global-store", + null, + null, + null, + "global-topic", + "global-processor", + new MockApiProcessorSupplier<>() + ); + + final TopologyException exception = assertThrows( + TopologyException.class, + () -> builder.addStateStore(storeBuilder) + ); + + assertThat( + exception.getMessage(), + equalTo("Invalid topology: A different GlobalStateStore has already been added with the name testStore") + ); + } + + @Test + public void shouldNotAllowToAddStoresWithSameNameWhenSecondStoreIsGlobal() { + final StoreBuilder> globalBuilder = + new MockKeyValueStoreBuilder("testStore", false).withLoggingDisabled(); + + builder.addStateStore(storeBuilder); + + final TopologyException exception = assertThrows( + TopologyException.class, + () -> builder.addGlobalStore( + globalBuilder, + "global-store", + null, + null, + null, + "global-topic", + "global-processor", + new MockApiProcessorSupplier<>() + ) + ); + + assertThat( + exception.getMessage(), + equalTo("Invalid topology: A different StateStore has already been added with the name testStore") + ); + } + + @Test + public void shouldNotAllowToAddGlobalStoresWithSameName() { + final StoreBuilder> firstGlobalBuilder = + new MockKeyValueStoreBuilder("testStore", false).withLoggingDisabled(); + final StoreBuilder> secondGlobalBuilder = + new MockKeyValueStoreBuilder("testStore", false).withLoggingDisabled(); + + builder.addGlobalStore( + firstGlobalBuilder, + "global-store", + null, + null, + null, + "global-topic", + "global-processor", + new MockApiProcessorSupplier<>() + ); + + final TopologyException exception = assertThrows( + TopologyException.class, + () -> builder.addGlobalStore( + secondGlobalBuilder, + "global-store-2", + null, + null, + null, + "global-topic", + "global-processor-2", + new MockApiProcessorSupplier<>() + ) + ); + + assertThat( + exception.getMessage(), + equalTo("Invalid topology: A different GlobalStateStore has already been added with the name testStore") + ); + } + + @Test + public void testAddStateStore() { + builder.addStateStore(storeBuilder); + builder.setApplicationId("X"); + builder.addSource(null, "source-1", null, null, null, "topic-1"); + builder.addProcessor("processor-1", new MockApiProcessorSupplier<>(), "source-1"); + + assertEquals(0, builder.buildTopology().stateStores().size()); + + builder.connectProcessorAndStateStores("processor-1", storeBuilder.name()); + + final List suppliers = builder.buildTopology().stateStores(); + assertEquals(1, suppliers.size()); + assertEquals(storeBuilder.name(), suppliers.get(0).name()); + } + + @Test + public void shouldAllowAddingSameStoreBuilderMultipleTimes() { + builder.setApplicationId("X"); + builder.addSource(null, "source-1", null, null, null, "topic-1"); + + builder.addStateStore(storeBuilder); + builder.addProcessor("processor-1", new MockApiProcessorSupplier<>(), "source-1"); + builder.connectProcessorAndStateStores("processor-1", storeBuilder.name()); + + builder.addStateStore(storeBuilder); + builder.addProcessor("processor-2", new MockApiProcessorSupplier<>(), "source-1"); + builder.connectProcessorAndStateStores("processor-2", storeBuilder.name()); + + assertEquals(1, builder.buildTopology().stateStores().size()); + } + + @Test + public void testTopicGroups() { + builder.setApplicationId("X"); + builder.addInternalTopic("topic-1x", InternalTopicProperties.empty()); + builder.addSource(null, "source-1", null, 
null, null, "topic-1", "topic-1x"); + builder.addSource(null, "source-2", null, null, null, "topic-2"); + builder.addSource(null, "source-3", null, null, null, "topic-3"); + builder.addSource(null, "source-4", null, null, null, "topic-4"); + builder.addSource(null, "source-5", null, null, null, "topic-5"); + + builder.addProcessor("processor-1", new MockApiProcessorSupplier<>(), "source-1"); + + builder.addProcessor("processor-2", new MockApiProcessorSupplier<>(), "source-2", "processor-1"); + builder.copartitionSources(asList("source-1", "source-2")); + + builder.addProcessor("processor-3", new MockApiProcessorSupplier<>(), "source-3", "source-4"); + + final Map topicGroups = builder.topicGroups(); + + final Map expectedTopicGroups = new HashMap<>(); + expectedTopicGroups.put(SUBTOPOLOGY_0, new InternalTopologyBuilder.TopicsInfo(Collections.emptySet(), mkSet("topic-1", "X-topic-1x", "topic-2"), Collections.emptyMap(), Collections.emptyMap())); + expectedTopicGroups.put(SUBTOPOLOGY_1, new InternalTopologyBuilder.TopicsInfo(Collections.emptySet(), mkSet("topic-3", "topic-4"), Collections.emptyMap(), Collections.emptyMap())); + expectedTopicGroups.put(SUBTOPOLOGY_2, new InternalTopologyBuilder.TopicsInfo(Collections.emptySet(), mkSet("topic-5"), Collections.emptyMap(), Collections.emptyMap())); + + assertEquals(3, topicGroups.size()); + assertEquals(expectedTopicGroups, topicGroups); + + final Collection> copartitionGroups = builder.copartitionGroups(); + + assertEquals(mkSet(mkSet("topic-1", "X-topic-1x", "topic-2")), new HashSet<>(copartitionGroups)); + } + + @Test + public void testTopicGroupsByStateStore() { + builder.setApplicationId("X"); + builder.addSource(null, "source-1", null, null, null, "topic-1", "topic-1x"); + builder.addSource(null, "source-2", null, null, null, "topic-2"); + builder.addSource(null, "source-3", null, null, null, "topic-3"); + builder.addSource(null, "source-4", null, null, null, "topic-4"); + builder.addSource(null, "source-5", null, null, null, "topic-5"); + + builder.addProcessor("processor-1", new MockApiProcessorSupplier<>(), "source-1"); + builder.addProcessor("processor-2", new MockApiProcessorSupplier<>(), "source-2"); + builder.addStateStore(new MockKeyValueStoreBuilder("store-1", false), "processor-1", "processor-2"); + + builder.addProcessor("processor-3", new MockApiProcessorSupplier<>(), "source-3"); + builder.addProcessor("processor-4", new MockApiProcessorSupplier<>(), "source-4"); + builder.addStateStore(new MockKeyValueStoreBuilder("store-2", false), "processor-3", "processor-4"); + + builder.addProcessor("processor-5", new MockApiProcessorSupplier<>(), "source-5"); + builder.addStateStore(new MockKeyValueStoreBuilder("store-3", false)); + builder.connectProcessorAndStateStores("processor-5", "store-3"); + builder.buildTopology(); + + final Map topicGroups = builder.topicGroups(); + + final Map expectedTopicGroups = new HashMap<>(); + final String store1 = ProcessorStateManager.storeChangelogTopic("X", "store-1", builder.topologyName()); + final String store2 = ProcessorStateManager.storeChangelogTopic("X", "store-2", builder.topologyName()); + final String store3 = ProcessorStateManager.storeChangelogTopic("X", "store-3", builder.topologyName()); + expectedTopicGroups.put(SUBTOPOLOGY_0, new InternalTopologyBuilder.TopicsInfo( + Collections.emptySet(), mkSet("topic-1", "topic-1x", "topic-2"), + Collections.emptyMap(), + Collections.singletonMap(store1, new UnwindowedChangelogTopicConfig(store1, Collections.emptyMap())))); + 
expectedTopicGroups.put(SUBTOPOLOGY_1, new InternalTopologyBuilder.TopicsInfo( + Collections.emptySet(), mkSet("topic-3", "topic-4"), + Collections.emptyMap(), + Collections.singletonMap(store2, new UnwindowedChangelogTopicConfig(store2, Collections.emptyMap())))); + expectedTopicGroups.put(SUBTOPOLOGY_2, new InternalTopologyBuilder.TopicsInfo( + Collections.emptySet(), mkSet("topic-5"), + Collections.emptyMap(), + Collections.singletonMap(store3, new UnwindowedChangelogTopicConfig(store3, Collections.emptyMap())))); + + assertEquals(3, topicGroups.size()); + assertEquals(expectedTopicGroups, topicGroups); + } + + @Test + public void testBuild() { + builder.addSource(null, "source-1", null, null, null, "topic-1", "topic-1x"); + builder.addSource(null, "source-2", null, null, null, "topic-2"); + builder.addSource(null, "source-3", null, null, null, "topic-3"); + builder.addSource(null, "source-4", null, null, null, "topic-4"); + builder.addSource(null, "source-5", null, null, null, "topic-5"); + + builder.addProcessor("processor-1", new MockApiProcessorSupplier<>(), "source-1"); + builder.addProcessor("processor-2", new MockApiProcessorSupplier<>(), "source-2", "processor-1"); + builder.addProcessor("processor-3", new MockApiProcessorSupplier<>(), "source-3", "source-4"); + + builder.setApplicationId("X"); + final ProcessorTopology topology0 = builder.buildSubtopology(0); + final ProcessorTopology topology1 = builder.buildSubtopology(1); + final ProcessorTopology topology2 = builder.buildSubtopology(2); + + assertEquals(mkSet("source-1", "source-2", "processor-1", "processor-2"), nodeNames(topology0.processors())); + assertEquals(mkSet("source-3", "source-4", "processor-3"), nodeNames(topology1.processors())); + assertEquals(mkSet("source-5"), nodeNames(topology2.processors())); + } + + @Test + public void shouldAllowIncrementalBuilds() { + Map> oldNodeGroups, newNodeGroups; + + oldNodeGroups = builder.nodeGroups(); + builder.addSource(null, "source-1", null, null, null, "topic-1"); + builder.addSource(null, "source-2", null, null, null, "topic-2"); + newNodeGroups = builder.nodeGroups(); + assertNotEquals(oldNodeGroups, newNodeGroups); + + oldNodeGroups = newNodeGroups; + builder.addSource(null, "source-3", null, null, null, Pattern.compile("")); + builder.addSource(null, "source-4", null, null, null, Pattern.compile("")); + newNodeGroups = builder.nodeGroups(); + assertNotEquals(oldNodeGroups, newNodeGroups); + + oldNodeGroups = newNodeGroups; + builder.addProcessor("processor-1", new MockApiProcessorSupplier<>(), "source-1"); + builder.addProcessor("processor-2", new MockApiProcessorSupplier<>(), "source-2"); + builder.addProcessor("processor-3", new MockApiProcessorSupplier<>(), "source-3"); + newNodeGroups = builder.nodeGroups(); + assertNotEquals(oldNodeGroups, newNodeGroups); + + oldNodeGroups = newNodeGroups; + builder.addSink("sink-1", "sink-topic", null, null, null, "processor-1"); + newNodeGroups = builder.nodeGroups(); + assertNotEquals(oldNodeGroups, newNodeGroups); + + oldNodeGroups = newNodeGroups; + builder.addSink("sink-2", (k, v, ctx) -> "sink-topic", null, null, null, "processor-2"); + newNodeGroups = builder.nodeGroups(); + assertNotEquals(oldNodeGroups, newNodeGroups); + + oldNodeGroups = newNodeGroups; + builder.addStateStore(new MockKeyValueStoreBuilder("store-1", false), "processor-1", "processor-2"); + newNodeGroups = builder.nodeGroups(); + assertNotEquals(oldNodeGroups, newNodeGroups); + + oldNodeGroups = newNodeGroups; + builder.addStateStore(new 
MockKeyValueStoreBuilder("store-2", false)); + builder.connectProcessorAndStateStores("processor-2", "store-2"); + builder.connectProcessorAndStateStores("processor-3", "store-2"); + newNodeGroups = builder.nodeGroups(); + assertNotEquals(oldNodeGroups, newNodeGroups); + + oldNodeGroups = newNodeGroups; + builder.addGlobalStore( + new MockKeyValueStoreBuilder("global-store", false).withLoggingDisabled(), + "globalSource", + null, + null, + null, + "globalTopic", + "global-processor", + new MockApiProcessorSupplier<>() + ); + newNodeGroups = builder.nodeGroups(); + assertNotEquals(oldNodeGroups, newNodeGroups); + } + + @Test + public void shouldNotAllowNullNameWhenAddingSink() { + assertThrows(NullPointerException.class, () -> builder.addSink(null, "topic", null, null, null)); + } + + @Test + public void shouldNotAllowNullTopicWhenAddingSink() { + assertThrows(NullPointerException.class, () -> builder.addSink("name", (String) null, null, null, null)); + } + + @Test + public void shouldNotAllowNullTopicChooserWhenAddingSink() { + assertThrows(NullPointerException.class, () -> builder.addSink("name", (TopicNameExtractor) null, null, null, null)); + } + + @Test + public void shouldNotAllowNullNameWhenAddingProcessor() { + assertThrows(NullPointerException.class, () -> builder.addProcessor(null, () -> null)); + } + + @Test + public void shouldNotAllowNullProcessorSupplier() { + assertThrows(NullPointerException.class, () -> builder.addProcessor("name", null)); + } + + @Test + public void shouldNotAllowNullNameWhenAddingSource() { + assertThrows(NullPointerException.class, () -> builder.addSource(null, null, null, null, null, Pattern.compile(".*"))); + } + + @Test + public void shouldNotAllowNullProcessorNameWhenConnectingProcessorAndStateStores() { + assertThrows(NullPointerException.class, () -> builder.connectProcessorAndStateStores(null, "store")); + } + + @Test + public void shouldNotAllowNullStateStoreNameWhenConnectingProcessorAndStateStores() { + assertThrows(NullPointerException.class, () -> builder.connectProcessorAndStateStores("processor", new String[]{null})); + } + + @Test + public void shouldNotAddNullInternalTopic() { + assertThrows(NullPointerException.class, () -> builder.addInternalTopic(null, InternalTopicProperties.empty())); + } + + @Test + public void shouldNotAddNullInternalTopicProperties() { + assertThrows(NullPointerException.class, () -> builder.addInternalTopic("topic", null)); + } + + @Test + public void shouldNotSetApplicationIdToNull() { + assertThrows(NullPointerException.class, () -> builder.setApplicationId(null)); + } + + @Test + public void shouldNotSetStreamsConfigToNull() { + assertThrows(NullPointerException.class, () -> builder.setStreamsConfig(null)); + } + + @Test + public void shouldNotAddNullStateStoreSupplier() { + assertThrows(NullPointerException.class, () -> builder.addStateStore(null)); + } + + private Set nodeNames(final Collection> nodes) { + final Set nodeNames = new HashSet<>(); + for (final ProcessorNode node : nodes) { + nodeNames.add(node.name()); + } + return nodeNames; + } + + @Test + public void shouldAssociateStateStoreNameWhenStateStoreSupplierIsInternal() { + builder.addSource(null, "source", null, null, null, "topic"); + builder.addProcessor("processor", new MockApiProcessorSupplier<>(), "source"); + builder.addStateStore(storeBuilder, "processor"); + final Map> stateStoreNameToSourceTopic = builder.stateStoreNameToSourceTopics(); + assertEquals(1, stateStoreNameToSourceTopic.size()); + 
assertEquals(Collections.singletonList("topic"), stateStoreNameToSourceTopic.get("testStore")); + } + + @Test + public void shouldAssociateStateStoreNameWhenStateStoreSupplierIsExternal() { + builder.addSource(null, "source", null, null, null, "topic"); + builder.addProcessor("processor", new MockApiProcessorSupplier<>(), "source"); + builder.addStateStore(storeBuilder, "processor"); + final Map> stateStoreNameToSourceTopic = builder.stateStoreNameToSourceTopics(); + assertEquals(1, stateStoreNameToSourceTopic.size()); + assertEquals(Collections.singletonList("topic"), stateStoreNameToSourceTopic.get("testStore")); + } + + @Test + public void shouldCorrectlyMapStateStoreToInternalTopics() { + builder.setApplicationId("appId"); + builder.addInternalTopic("internal-topic", InternalTopicProperties.empty()); + builder.addSource(null, "source", null, null, null, "internal-topic"); + builder.addProcessor("processor", new MockApiProcessorSupplier<>(), "source"); + builder.addStateStore(storeBuilder, "processor"); + final Map> stateStoreNameToSourceTopic = builder.stateStoreNameToSourceTopics(); + assertEquals(1, stateStoreNameToSourceTopic.size()); + assertEquals(Collections.singletonList("appId-internal-topic"), stateStoreNameToSourceTopic.get("testStore")); + } + + @Test + public void shouldAddInternalTopicConfigForWindowStores() { + builder.setApplicationId("appId"); + builder.addSource(null, "source", null, null, null, "topic"); + builder.addProcessor("processor", new MockApiProcessorSupplier<>(), "source"); + builder.addStateStore( + Stores.windowStoreBuilder( + Stores.persistentWindowStore("store1", ofSeconds(30L), ofSeconds(10L), false), + Serdes.String(), + Serdes.String() + ), + "processor" + ); + builder.addStateStore( + Stores.sessionStoreBuilder( + Stores.persistentSessionStore("store2", ofSeconds(30)), Serdes.String(), Serdes.String() + ), + "processor" + ); + builder.buildTopology(); + final Map topicGroups = builder.topicGroups(); + final InternalTopologyBuilder.TopicsInfo topicsInfo = topicGroups.values().iterator().next(); + final InternalTopicConfig topicConfig1 = topicsInfo.stateChangelogTopics.get("appId-store1-changelog"); + final Map properties1 = topicConfig1.getProperties(Collections.emptyMap(), 10000); + assertEquals(3, properties1.size()); + assertEquals(TopicConfig.CLEANUP_POLICY_COMPACT + "," + TopicConfig.CLEANUP_POLICY_DELETE, properties1.get(TopicConfig.CLEANUP_POLICY_CONFIG)); + assertEquals("40000", properties1.get(TopicConfig.RETENTION_MS_CONFIG)); + assertEquals("appId-store1-changelog", topicConfig1.name()); + assertTrue(topicConfig1 instanceof WindowedChangelogTopicConfig); + final InternalTopicConfig topicConfig2 = topicsInfo.stateChangelogTopics.get("appId-store2-changelog"); + final Map properties2 = topicConfig2.getProperties(Collections.emptyMap(), 10000); + assertEquals(3, properties2.size()); + assertEquals(TopicConfig.CLEANUP_POLICY_COMPACT + "," + TopicConfig.CLEANUP_POLICY_DELETE, properties2.get(TopicConfig.CLEANUP_POLICY_CONFIG)); + assertEquals("40000", properties2.get(TopicConfig.RETENTION_MS_CONFIG)); + assertEquals("appId-store2-changelog", topicConfig2.name()); + assertTrue(topicConfig2 instanceof WindowedChangelogTopicConfig); + } + + @Test + public void shouldAddInternalTopicConfigForNonWindowStores() { + builder.setApplicationId("appId"); + builder.addSource(null, "source", null, null, null, "topic"); + builder.addProcessor("processor", new MockApiProcessorSupplier<>(), "source"); + builder.addStateStore(storeBuilder, "processor"); + 
builder.buildTopology(); + final Map topicGroups = builder.topicGroups(); + final InternalTopologyBuilder.TopicsInfo topicsInfo = topicGroups.values().iterator().next(); + final InternalTopicConfig topicConfig = topicsInfo.stateChangelogTopics.get("appId-testStore-changelog"); + final Map properties = topicConfig.getProperties(Collections.emptyMap(), 10000); + assertEquals(2, properties.size()); + assertEquals(TopicConfig.CLEANUP_POLICY_COMPACT, properties.get(TopicConfig.CLEANUP_POLICY_CONFIG)); + assertEquals("appId-testStore-changelog", topicConfig.name()); + assertTrue(topicConfig instanceof UnwindowedChangelogTopicConfig); + } + + @Test + public void shouldAddInternalTopicConfigForRepartitionTopics() { + builder.setApplicationId("appId"); + builder.addInternalTopic("foo", InternalTopicProperties.empty()); + builder.addSource(null, "source", null, null, null, "foo"); + builder.buildTopology(); + final InternalTopologyBuilder.TopicsInfo topicsInfo = builder.topicGroups().values().iterator().next(); + final InternalTopicConfig topicConfig = topicsInfo.repartitionSourceTopics.get("appId-foo"); + final Map properties = topicConfig.getProperties(Collections.emptyMap(), 10000); + assertEquals(4, properties.size()); + assertEquals(String.valueOf(-1), properties.get(TopicConfig.RETENTION_MS_CONFIG)); + assertEquals(TopicConfig.CLEANUP_POLICY_DELETE, properties.get(TopicConfig.CLEANUP_POLICY_CONFIG)); + assertEquals("appId-foo", topicConfig.name()); + assertTrue(topicConfig instanceof RepartitionTopicConfig); + } + + @Test + public void shouldSetCorrectSourceNodesWithRegexUpdatedTopics() { + builder.addSource(null, "source-1", null, null, null, "topic-foo"); + builder.addSource(null, "source-2", null, null, null, Pattern.compile("topic-[A-C]")); + builder.addSource(null, "source-3", null, null, null, Pattern.compile("topic-\\d")); + + final Set updatedTopics = new HashSet<>(); + + updatedTopics.add("topic-B"); + updatedTopics.add("topic-3"); + updatedTopics.add("topic-A"); + + builder.addSubscribedTopicsFromMetadata(updatedTopics, null); + builder.setApplicationId("test-id"); + + final Map topicGroups = builder.topicGroups(); + assertTrue(topicGroups.get(SUBTOPOLOGY_0).sourceTopics.contains("topic-foo")); + assertTrue(topicGroups.get(SUBTOPOLOGY_1).sourceTopics.contains("topic-A")); + assertTrue(topicGroups.get(SUBTOPOLOGY_1).sourceTopics.contains("topic-B")); + assertTrue(topicGroups.get(SUBTOPOLOGY_2).sourceTopics.contains("topic-3")); + } + + @Test + public void shouldSetStreamsConfigOnRewriteTopology() { + final StreamsConfig config = new StreamsConfig(StreamsTestUtils.getStreamsConfig()); + final InternalTopologyBuilder topologyBuilder = builder.rewriteTopology(config); + assertThat(topologyBuilder.getStreamsConfig(), equalTo(config)); + } + + @Test + public void shouldAddTimestampExtractorPerSource() { + builder.addSource(null, "source", new MockTimestampExtractor(), null, null, "topic"); + final ProcessorTopology processorTopology = builder.rewriteTopology(new StreamsConfig(StreamsTestUtils.getStreamsConfig())).buildTopology(); + assertThat(processorTopology.source("topic").getTimestampExtractor(), instanceOf(MockTimestampExtractor.class)); + } + + @Test + public void shouldAddTimestampExtractorWithPatternPerSource() { + final Pattern pattern = Pattern.compile("t.*"); + builder.addSource(null, "source", new MockTimestampExtractor(), null, null, pattern); + final ProcessorTopology processorTopology = builder.rewriteTopology(new 
StreamsConfig(StreamsTestUtils.getStreamsConfig())).buildTopology(); + assertThat(processorTopology.source(pattern.pattern()).getTimestampExtractor(), instanceOf(MockTimestampExtractor.class)); + } + + @Test + public void shouldSortProcessorNodesCorrectly() { + builder.addSource(null, "source1", null, null, null, "topic1"); + builder.addSource(null, "source2", null, null, null, "topic2"); + builder.addProcessor("processor1", new MockApiProcessorSupplier<>(), "source1"); + builder.addProcessor("processor2", new MockApiProcessorSupplier<>(), "source1", "source2"); + builder.addProcessor("processor3", new MockApiProcessorSupplier<>(), "processor2"); + builder.addSink("sink1", "topic2", null, null, null, "processor1", "processor3"); + + assertEquals(1, builder.describe().subtopologies().size()); + + final Iterator iterator = ((SubtopologyDescription) builder.describe().subtopologies().iterator().next()).nodesInOrder(); + + assertTrue(iterator.hasNext()); + InternalTopologyBuilder.AbstractNode node = (InternalTopologyBuilder.AbstractNode) iterator.next(); + assertEquals("source1", node.name); + assertEquals(6, node.size); + + assertTrue(iterator.hasNext()); + node = (InternalTopologyBuilder.AbstractNode) iterator.next(); + assertEquals("source2", node.name); + assertEquals(4, node.size); + + assertTrue(iterator.hasNext()); + node = (InternalTopologyBuilder.AbstractNode) iterator.next(); + assertEquals("processor2", node.name); + assertEquals(3, node.size); + + assertTrue(iterator.hasNext()); + node = (InternalTopologyBuilder.AbstractNode) iterator.next(); + assertEquals("processor1", node.name); + assertEquals(2, node.size); + + assertTrue(iterator.hasNext()); + node = (InternalTopologyBuilder.AbstractNode) iterator.next(); + assertEquals("processor3", node.name); + assertEquals(2, node.size); + + assertTrue(iterator.hasNext()); + node = (InternalTopologyBuilder.AbstractNode) iterator.next(); + assertEquals("sink1", node.name); + assertEquals(1, node.size); + } + + @Test + public void shouldConnectRegexMatchedTopicsToStateStore() { + builder.addSource(null, "ingest", null, null, null, Pattern.compile("topic-\\d+")); + builder.addProcessor("my-processor", new MockApiProcessorSupplier<>(), "ingest"); + builder.addStateStore(storeBuilder, "my-processor"); + + final Set updatedTopics = new HashSet<>(); + + updatedTopics.add("topic-2"); + updatedTopics.add("topic-3"); + updatedTopics.add("topic-A"); + + builder.addSubscribedTopicsFromMetadata(updatedTopics, "test-thread"); + builder.setApplicationId("test-app"); + + final Map> stateStoreAndTopics = builder.stateStoreNameToSourceTopics(); + final List topics = stateStoreAndTopics.get(storeBuilder.name()); + + assertEquals("Expected to contain two topics", 2, topics.size()); + + assertTrue(topics.contains("topic-2")); + assertTrue(topics.contains("topic-3")); + assertFalse(topics.contains("topic-A")); + } + + @Test + public void shouldNotAllowToAddGlobalStoreWithSourceNameEqualsProcessorName() { + final String sameNameForSourceAndProcessor = "sameName"; + assertThrows(TopologyException.class, () -> builder.addGlobalStore( + storeBuilder, + sameNameForSourceAndProcessor, + null, + null, + null, + "anyTopicName", + sameNameForSourceAndProcessor, + new MockApiProcessorSupplier<>() + )); + } + + @Test + public void shouldThrowIfNameIsNull() { + final Exception e = assertThrows(NullPointerException.class, () -> new InternalTopologyBuilder.Source(null, Collections.emptySet(), null)); + assertEquals("name cannot be null", e.getMessage()); + } + + @Test + 
public void shouldThrowIfTopicAndPatternAreNull() { + final Exception e = assertThrows(IllegalArgumentException.class, () -> new InternalTopologyBuilder.Source("name", null, null)); + assertEquals("Either topics or pattern must be not-null, but both are null.", e.getMessage()); + } + + @Test + public void shouldThrowIfBothTopicAndPatternAreNotNull() { + final Exception e = assertThrows(IllegalArgumentException.class, () -> new InternalTopologyBuilder.Source("name", Collections.emptySet(), Pattern.compile(""))); + assertEquals("Either topics or pattern must be null, but both are not null.", e.getMessage()); + } + + @Test + public void sourceShouldBeEqualIfNameAndTopicListAreTheSame() { + final InternalTopologyBuilder.Source base = new InternalTopologyBuilder.Source("name", Collections.singleton("topic"), null); + final InternalTopologyBuilder.Source sameAsBase = new InternalTopologyBuilder.Source("name", Collections.singleton("topic"), null); + + assertThat(base, equalTo(sameAsBase)); + } + + @Test + public void sourceShouldBeEqualIfNameAndPatternAreTheSame() { + final InternalTopologyBuilder.Source base = new InternalTopologyBuilder.Source("name", null, Pattern.compile("topic")); + final InternalTopologyBuilder.Source sameAsBase = new InternalTopologyBuilder.Source("name", null, Pattern.compile("topic")); + + assertThat(base, equalTo(sameAsBase)); + } + + @Test + public void sourceShouldNotBeEqualForDifferentNamesWithSameTopicList() { + final InternalTopologyBuilder.Source base = new InternalTopologyBuilder.Source("name", Collections.singleton("topic"), null); + final InternalTopologyBuilder.Source differentName = new InternalTopologyBuilder.Source("name2", Collections.singleton("topic"), null); + + assertThat(base, not(equalTo(differentName))); + } + + @Test + public void sourceShouldNotBeEqualForDifferentNamesWithSamePattern() { + final InternalTopologyBuilder.Source base = new InternalTopologyBuilder.Source("name", null, Pattern.compile("topic")); + final InternalTopologyBuilder.Source differentName = new InternalTopologyBuilder.Source("name2", null, Pattern.compile("topic")); + + assertThat(base, not(equalTo(differentName))); + } + + @Test + public void sourceShouldNotBeEqualForDifferentTopicList() { + final InternalTopologyBuilder.Source base = new InternalTopologyBuilder.Source("name", Collections.singleton("topic"), null); + final InternalTopologyBuilder.Source differentTopicList = new InternalTopologyBuilder.Source("name", Collections.emptySet(), null); + final InternalTopologyBuilder.Source differentTopic = new InternalTopologyBuilder.Source("name", Collections.singleton("topic2"), null); + + assertThat(base, not(equalTo(differentTopicList))); + assertThat(base, not(equalTo(differentTopic))); + } + + @Test + public void sourceShouldNotBeEqualForDifferentPattern() { + final InternalTopologyBuilder.Source base = new InternalTopologyBuilder.Source("name", null, Pattern.compile("topic")); + final InternalTopologyBuilder.Source differentPattern = new InternalTopologyBuilder.Source("name", null, Pattern.compile("topic2")); + final InternalTopologyBuilder.Source overlappingPattern = new InternalTopologyBuilder.Source("name", null, Pattern.compile("top*")); + + assertThat(base, not(equalTo(differentPattern))); + assertThat(base, not(equalTo(overlappingPattern))); + } + + @Test + public void shouldHaveCorrectInternalTopicConfigWhenInternalTopicPropertiesArePresent() { + final int numberOfPartitions = 10; + builder.setApplicationId("Z"); + builder.addInternalTopic("topic-1z", new 
InternalTopicProperties(numberOfPartitions)); + builder.addSource(null, "source-1", null, null, null, "topic-1z"); + + final Map topicGroups = builder.topicGroups(); + + final Map repartitionSourceTopics = topicGroups.get(SUBTOPOLOGY_0).repartitionSourceTopics; + + assertEquals( + repartitionSourceTopics.get("Z-topic-1z"), + new RepartitionTopicConfig( + "Z-topic-1z", + Collections.emptyMap(), + numberOfPartitions, + true + ) + ); + } + + @Test + public void shouldHandleWhenTopicPropertiesNumberOfPartitionsIsNull() { + builder.setApplicationId("T"); + builder.addInternalTopic("topic-1t", InternalTopicProperties.empty()); + builder.addSource(null, "source-1", null, null, null, "topic-1t"); + + final Map topicGroups = builder.topicGroups(); + + final Map repartitionSourceTopics = topicGroups.get(SUBTOPOLOGY_0).repartitionSourceTopics; + + assertEquals( + repartitionSourceTopics.get("T-topic-1t"), + new RepartitionTopicConfig( + "T-topic-1t", + Collections.emptyMap() + ) + ); + } + + @Test + public void shouldHaveCorrectInternalTopicConfigWhenInternalTopicPropertiesAreNotPresent() { + builder.setApplicationId("Y"); + builder.addInternalTopic("topic-1y", InternalTopicProperties.empty()); + builder.addSource(null, "source-1", null, null, null, "topic-1y"); + + final Map topicGroups = builder.topicGroups(); + + final Map repartitionSourceTopics = topicGroups.get(SUBTOPOLOGY_0).repartitionSourceTopics; + + assertEquals( + repartitionSourceTopics.get("Y-topic-1y"), + new RepartitionTopicConfig("Y-topic-1y", Collections.emptyMap()) + ); + } + + @Test + public void shouldConnectGlobalStateStoreToInputTopic() { + final String globalStoreName = "global-store"; + final String globalTopic = "global-topic"; + builder.setApplicationId("X"); + builder.addGlobalStore( + new MockKeyValueStoreBuilder(globalStoreName, false).withLoggingDisabled(), + "globalSource", + null, + null, + null, + globalTopic, + "global-processor", + new MockApiProcessorSupplier<>() + ); + builder.initializeSubscription(); + + builder.rewriteTopology(new StreamsConfig(mkProperties(mkMap( + mkEntry(StreamsConfig.APPLICATION_ID_CONFIG, "asdf"), + mkEntry(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, "asdf") + )))); + + assertThat(builder.buildGlobalStateTopology().storeToChangelogTopic().get(globalStoreName), is(globalTopic)); + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/MockChangelogReader.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/MockChangelogReader.java new file mode 100644 index 0000000..6ea7fc3 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/MockChangelogReader.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+package org.apache.kafka.streams.processor.internals;
+
+import org.apache.kafka.common.TopicPartition;
+import org.apache.kafka.streams.processor.TaskId;
+
+import java.util.Collection;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.Map;
+import java.util.Set;
+
+public class MockChangelogReader implements ChangelogReader {
+    private final Set<TopicPartition> restoringPartitions = new HashSet<>();
+    private Map<TopicPartition, Long> restoredOffsets = Collections.emptyMap();
+
+    public boolean isPartitionRegistered(final TopicPartition partition) {
+        return restoringPartitions.contains(partition);
+    }
+
+    @Override
+    public void register(final TopicPartition partition, final ProcessorStateManager stateManager) {
+        restoringPartitions.add(partition);
+    }
+
+    @Override
+    public void restore(final Map<TaskId, Task> tasks) {
+        // do nothing
+    }
+
+    @Override
+    public void enforceRestoreActive() {
+        // do nothing
+    }
+
+    @Override
+    public void transitToUpdateStandby() {
+        // do nothing
+    }
+
+    @Override
+    public Set<TopicPartition> completedChangelogs() {
+        // assuming all restoring partitions are completed
+        return restoringPartitions;
+    }
+
+    @Override
+    public void clear() {
+        restoringPartitions.clear();
+    }
+
+    @Override
+    public void unregister(final Collection<TopicPartition> partitions) {
+        restoringPartitions.removeAll(partitions);
+
+        for (final TopicPartition partition : partitions) {
+            restoredOffsets.remove(partition);
+        }
+    }
+
+    @Override
+    public boolean isEmpty() {
+        return restoredOffsets.isEmpty() && restoringPartitions.isEmpty();
+    }
+}
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/MockStreamsMetrics.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/MockStreamsMetrics.java
new file mode 100644
index 0000000..bb0303c
--- /dev/null
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/MockStreamsMetrics.java
@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +package org.apache.kafka.streams.processor.internals; + +import org.apache.kafka.common.metrics.Metrics; +import org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl; + +public class MockStreamsMetrics extends StreamsMetricsImpl { + + public MockStreamsMetrics(final Metrics metrics) { + super(metrics, "test", StreamsConfig.METRICS_LATEST, new MockTime()); + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/NamedTopologyTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/NamedTopologyTest.java new file mode 100644 index 0000000..59b3eda --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/NamedTopologyTest.java @@ -0,0 +1,302 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.processor.internals; + +import org.apache.kafka.streams.KafkaClientSupplier; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.errors.TopologyException; +import org.apache.kafka.streams.kstream.Materialized; +import org.apache.kafka.streams.processor.internals.namedtopology.KafkaStreamsNamedTopologyWrapper; +import org.apache.kafka.streams.processor.internals.namedtopology.NamedTopology; +import org.apache.kafka.streams.processor.internals.namedtopology.NamedTopologyStreamsBuilder; +import org.apache.kafka.streams.state.Stores; +import org.apache.kafka.test.TestUtils; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import java.util.Properties; +import java.util.regex.Pattern; + +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.Assert.assertThrows; +import static java.util.Arrays.asList; + +public class NamedTopologyTest { + final KafkaClientSupplier clientSupplier = new DefaultKafkaClientSupplier(); + final Properties props = configProps(); + + final NamedTopologyStreamsBuilder builder1 = new NamedTopologyStreamsBuilder("topology-1"); + final NamedTopologyStreamsBuilder builder2 = new NamedTopologyStreamsBuilder("topology-2"); + final NamedTopologyStreamsBuilder builder3 = new NamedTopologyStreamsBuilder("topology-3"); + + KafkaStreamsNamedTopologyWrapper streams; + + @Before + public void setup() { + builder1.stream("input-1"); + builder2.stream("input-2"); + builder3.stream("input-3"); + } + + @After + public void cleanup() { + if (streams != null) { + streams.close(); + } + } + + private static Properties configProps() { + final Properties props = new Properties(); + props.put(StreamsConfig.APPLICATION_ID_CONFIG, "Named-Topology-App"); + 
props.put(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:2018"); + props.put(StreamsConfig.STATE_DIR_CONFIG, TestUtils.tempDirectory().getPath()); + return props; + } + + @Test + public void shouldThrowIllegalArgumentOnIllegalName() { + assertThrows(IllegalArgumentException.class, () -> new NamedTopologyStreamsBuilder("__not-allowed__")); + } + + @Test + public void shouldBuildSingleNamedTopology() { + builder1.stream("stream-1").filter((k, v) -> !k.equals(v)).to("output-1"); + + streams = new KafkaStreamsNamedTopologyWrapper(builder1.buildNamedTopology(props), props, clientSupplier); + } + + @Test + public void shouldBuildMultipleIdenticalNamedTopologyWithRepartition() { + builder1.stream("stream-1").selectKey((k, v) -> v).groupByKey().count().toStream().to("output-1"); + builder2.stream("stream-2").selectKey((k, v) -> v).groupByKey().count().toStream().to("output-2"); + builder3.stream("stream-3").selectKey((k, v) -> v).groupByKey().count().toStream().to("output-3"); + + streams = new KafkaStreamsNamedTopologyWrapper( + asList( + builder1.buildNamedTopology(props), + builder2.buildNamedTopology(props), + builder3.buildNamedTopology(props)), + props, + clientSupplier + ); + } + + @Test + public void shouldReturnTopologyByName() { + final NamedTopology topology1 = builder1.buildNamedTopology(props); + final NamedTopology topology2 = builder2.buildNamedTopology(props); + final NamedTopology topology3 = builder3.buildNamedTopology(props); + streams = new KafkaStreamsNamedTopologyWrapper(asList(topology1, topology2, topology3), props, clientSupplier); + assertThat(streams.getTopologyByName("topology-1").get(), equalTo(topology1)); + assertThat(streams.getTopologyByName("topology-2").get(), equalTo(topology2)); + assertThat(streams.getTopologyByName("topology-3").get(), equalTo(topology3)); + } + + @Test + public void shouldReturnEmptyWhenLookingUpNonExistentTopologyByName() { + streams = new KafkaStreamsNamedTopologyWrapper(builder1.buildNamedTopology(props), props, clientSupplier); + assertThat(streams.getTopologyByName("non-existent-topology").isPresent(), equalTo(false)); + } + + @Test + public void shouldAllowSameStoreNameToBeUsedByMultipleNamedTopologies() { + builder1.stream("stream-1").selectKey((k, v) -> v).groupByKey().count(Materialized.as(Stores.inMemoryKeyValueStore("store"))); + builder2.stream("stream-2").selectKey((k, v) -> v).groupByKey().count(Materialized.as(Stores.inMemoryKeyValueStore("store"))); + + streams = new KafkaStreamsNamedTopologyWrapper(asList( + builder1.buildNamedTopology(props), + builder2.buildNamedTopology(props)), + props, + clientSupplier + ); + } + + @Test + public void shouldThrowTopologyExceptionWhenMultipleNamedTopologiesCreateStreamFromSameInputTopic() { + builder1.stream("stream"); + builder2.stream("stream"); + + assertThrows( + TopologyException.class, + () -> streams = new KafkaStreamsNamedTopologyWrapper( + asList( + builder1.buildNamedTopology(props), + builder2.buildNamedTopology(props)), + props, + clientSupplier) + ); + } + + @Test + public void shouldThrowTopologyExceptionWhenMultipleNamedTopologiesCreateTableFromSameInputTopic() { + builder1.table("table"); + builder2.table("table"); + + assertThrows( + TopologyException.class, + () -> streams = new KafkaStreamsNamedTopologyWrapper( + asList( + builder1.buildNamedTopology(props), + builder2.buildNamedTopology(props)), + props, + clientSupplier) + ); + } + + @Test + public void shouldThrowTopologyExceptionWhenMultipleNamedTopologiesCreateStreamAndTableFromSameInputTopic() { + 
builder1.stream("input"); + builder2.table("input"); + + assertThrows( + TopologyException.class, + () -> streams = new KafkaStreamsNamedTopologyWrapper( + asList( + builder1.buildNamedTopology(props), + builder2.buildNamedTopology(props)), + props, + clientSupplier) + ); + } + + @Test + public void shouldThrowTopologyExceptionWhenMultipleNamedTopologiesCreateStreamFromOverlappingInputTopicCollection() { + builder1.stream("stream"); + builder2.stream(asList("unique-input", "stream")); + + assertThrows( + TopologyException.class, + () -> streams = new KafkaStreamsNamedTopologyWrapper( + asList( + builder1.buildNamedTopology(props), + builder2.buildNamedTopology(props)), + props, + clientSupplier) + ); + } + + @Test + public void shouldThrowTopologyExceptionWhenMultipleNamedTopologiesCreateStreamFromSamePattern() { + builder1.stream(Pattern.compile("some-regex")); + builder2.stream(Pattern.compile("some-regex")); + + assertThrows( + TopologyException.class, + () -> streams = new KafkaStreamsNamedTopologyWrapper( + asList( + builder1.buildNamedTopology(props), + builder2.buildNamedTopology(props)), + props, + clientSupplier) + ); + } + + @Test + public void shouldDescribeWithSingleNamedTopology() { + builder1.stream("input").filter((k, v) -> !k.equals(v)).to("output"); + streams = new KafkaStreamsNamedTopologyWrapper(builder1.buildNamedTopology(props), props, clientSupplier); + + assertThat( + streams.getFullTopologyDescription(), + equalTo( + "Topology - topology-1:\n" + + " Sub-topology: 0\n" + + " Source: KSTREAM-SOURCE-0000000000 (topics: [input-1])\n" + + " --> none\n" + + "\n" + + " Sub-topology: 1\n" + + " Source: KSTREAM-SOURCE-0000000001 (topics: [input])\n" + + " --> KSTREAM-FILTER-0000000002\n" + + " Processor: KSTREAM-FILTER-0000000002 (stores: [])\n" + + " --> KSTREAM-SINK-0000000003\n" + + " <-- KSTREAM-SOURCE-0000000001\n" + + " Sink: KSTREAM-SINK-0000000003 (topic: output)\n" + + " <-- KSTREAM-FILTER-0000000002\n\n") + ); + } + + @Test + public void shouldDescribeWithMultipleNamedTopologies() { + builder1.stream("stream-1").filter((k, v) -> !k.equals(v)).to("output-1"); + builder2.stream("stream-2").filter((k, v) -> !k.equals(v)).to("output-2"); + builder3.stream("stream-3").filter((k, v) -> !k.equals(v)).to("output-3"); + + streams = new KafkaStreamsNamedTopologyWrapper( + asList( + builder1.buildNamedTopology(props), + builder2.buildNamedTopology(props), + builder3.buildNamedTopology(props)), + props, + clientSupplier + ); + + assertThat( + streams.getFullTopologyDescription(), + equalTo( + "Topology - topology-1:\n" + + " Sub-topology: 0\n" + + " Source: KSTREAM-SOURCE-0000000000 (topics: [input-1])\n" + + " --> none\n" + + "\n" + + " Sub-topology: 1\n" + + " Source: KSTREAM-SOURCE-0000000001 (topics: [stream-1])\n" + + " --> KSTREAM-FILTER-0000000002\n" + + " Processor: KSTREAM-FILTER-0000000002 (stores: [])\n" + + " --> KSTREAM-SINK-0000000003\n" + + " <-- KSTREAM-SOURCE-0000000001\n" + + " Sink: KSTREAM-SINK-0000000003 (topic: output-1)\n" + + " <-- KSTREAM-FILTER-0000000002\n" + + "\n" + + "Topology - topology-2:\n" + + " Sub-topology: 0\n" + + " Source: KSTREAM-SOURCE-0000000000 (topics: [input-2])\n" + + " --> none\n" + + "\n" + + " Sub-topology: 1\n" + + " Source: KSTREAM-SOURCE-0000000001 (topics: [stream-2])\n" + + " --> KSTREAM-FILTER-0000000002\n" + + " Processor: KSTREAM-FILTER-0000000002 (stores: [])\n" + + " --> KSTREAM-SINK-0000000003\n" + + " <-- KSTREAM-SOURCE-0000000001\n" + + " Sink: KSTREAM-SINK-0000000003 (topic: output-2)\n" + + " <-- 
KSTREAM-FILTER-0000000002\n" + + "\n" + + "Topology - topology-3:\n" + + " Sub-topology: 0\n" + + " Source: KSTREAM-SOURCE-0000000000 (topics: [input-3])\n" + + " --> none\n" + + "\n" + + " Sub-topology: 1\n" + + " Source: KSTREAM-SOURCE-0000000001 (topics: [stream-3])\n" + + " --> KSTREAM-FILTER-0000000002\n" + + " Processor: KSTREAM-FILTER-0000000002 (stores: [])\n" + + " --> KSTREAM-SINK-0000000003\n" + + " <-- KSTREAM-SOURCE-0000000001\n" + + " Sink: KSTREAM-SINK-0000000003 (topic: output-3)\n" + + " <-- KSTREAM-FILTER-0000000002\n\n") + ); + } + + @Test + public void shouldDescribeWithEmptyNamedTopology() { + streams = new KafkaStreamsNamedTopologyWrapper(props, clientSupplier); + + assertThat(streams.getFullTopologyDescription(), equalTo("")); + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/PartitionGroupTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/PartitionGroupTest.java new file mode 100644 index 0000000..40602b5 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/PartitionGroupTest.java @@ -0,0 +1,779 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+package org.apache.kafka.streams.processor.internals;
+
+import org.apache.kafka.clients.consumer.ConsumerRecord;
+import org.apache.kafka.common.MetricName;
+import org.apache.kafka.common.TopicPartition;
+import org.apache.kafka.common.metrics.Metrics;
+import org.apache.kafka.common.metrics.Sensor;
+import org.apache.kafka.common.metrics.stats.Value;
+import org.apache.kafka.common.serialization.Deserializer;
+import org.apache.kafka.common.serialization.IntegerDeserializer;
+import org.apache.kafka.common.serialization.IntegerSerializer;
+import org.apache.kafka.common.serialization.Serializer;
+import org.apache.kafka.common.utils.LogContext;
+import org.apache.kafka.common.utils.MockTime;
+import org.apache.kafka.common.utils.Time;
+import org.apache.kafka.streams.StreamsConfig;
+import org.apache.kafka.streams.errors.LogAndContinueExceptionHandler;
+import org.apache.kafka.streams.processor.TimestampExtractor;
+import org.apache.kafka.streams.processor.internals.testutil.LogCaptureAppender;
+import org.apache.kafka.test.InternalMockProcessorContext;
+import org.apache.kafka.test.MockSourceNode;
+import org.apache.kafka.test.MockTimestampExtractor;
+import org.hamcrest.Matchers;
+import org.junit.Test;
+
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.OptionalLong;
+import java.util.UUID;
+
+import static org.apache.kafka.common.utils.Utils.mkEntry;
+import static org.apache.kafka.common.utils.Utils.mkMap;
+import static org.apache.kafka.common.utils.Utils.mkSet;
+import static org.hamcrest.CoreMatchers.is;
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.hasItem;
+import static org.hamcrest.Matchers.notNullValue;
+import static org.hamcrest.Matchers.nullValue;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertNull;
+import static org.junit.Assert.assertThrows;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
+
+public class PartitionGroupTest {
+
+    private final long maxTaskIdleMs = StreamsConfig.MAX_TASK_IDLE_MS_DISABLED;
+    private final LogContext logContext = new LogContext("[test] ");
+    private final Time time = new MockTime();
+    private final Serializer<Integer> intSerializer = new IntegerSerializer();
+    private final Deserializer<Integer> intDeserializer = new IntegerDeserializer();
+    private final TimestampExtractor timestampExtractor = new MockTimestampExtractor();
+    private final TopicPartition unknownPartition = new TopicPartition("unknown-partition", 0);
+    private final String errMessage = "Partition " + unknownPartition + " not found.";
+    private final String[] topics = {"topic"};
+    private final TopicPartition partition1 = createPartition1();
+    private final TopicPartition partition2 = createPartition2();
+    private final RecordQueue queue1 = createQueue1();
+    private final RecordQueue queue2 = createQueue2();
+
+    private final byte[] recordValue = intSerializer.serialize(null, 10);
+    private final byte[] recordKey = intSerializer.serialize(null, 1);
+
+    private final Metrics metrics = new Metrics();
+    private final Sensor enforcedProcessingSensor = metrics.sensor(UUID.randomUUID().toString());
+    private final MetricName lastLatenessValue = new MetricName("record-lateness-last-value", "", "", mkMap());
+
+
+    private static Sensor getValueSensor(final Metrics metrics, final MetricName metricName) {
+        final Sensor lastRecordedValue =
metrics.sensor(metricName.name()); + lastRecordedValue.add(metricName, new Value()); + return lastRecordedValue; + } + + @Test + public void testTimeTracking() { + final PartitionGroup group = getBasicGroup(); + + testFirstBatch(group); + testSecondBatch(group); + } + + private RecordQueue createQueue1() { + return new RecordQueue( + partition1, + new MockSourceNode<>(intDeserializer, intDeserializer), + timestampExtractor, + new LogAndContinueExceptionHandler(), + new InternalMockProcessorContext(), + logContext + ); + } + + private RecordQueue createQueue2() { + return new RecordQueue( + partition2, + new MockSourceNode<>(intDeserializer, intDeserializer), + timestampExtractor, + new LogAndContinueExceptionHandler(), + new InternalMockProcessorContext(), + logContext + ); + } + + private TopicPartition createPartition1() { + return new TopicPartition(topics[0], 1); + } + + private TopicPartition createPartition2() { + return new TopicPartition(topics[0], 2); + } + + private void testFirstBatch(final PartitionGroup group) { + StampedRecord record; + final PartitionGroup.RecordInfo info = new PartitionGroup.RecordInfo(); + assertThat(group.numBuffered(), is(0)); + + // add three 3 records with timestamp 1, 3, 5 to partition-1 + final List> list1 = Arrays.asList( + new ConsumerRecord<>("topic", 1, 1L, recordKey, recordValue), + new ConsumerRecord<>("topic", 1, 3L, recordKey, recordValue), + new ConsumerRecord<>("topic", 1, 5L, recordKey, recordValue)); + + group.addRawRecords(partition1, list1); + + // add three 3 records with timestamp 2, 4, 6 to partition-2 + final List> list2 = Arrays.asList( + new ConsumerRecord<>("topic", 2, 2L, recordKey, recordValue), + new ConsumerRecord<>("topic", 2, 4L, recordKey, recordValue), + new ConsumerRecord<>("topic", 2, 6L, recordKey, recordValue)); + + group.addRawRecords(partition2, list2); + // 1:[1, 3, 5] + // 2:[2, 4, 6] + // st: -1 since no records was being processed yet + + verifyBuffered(6, 3, 3, group); + assertThat(group.partitionTimestamp(partition1), is(RecordQueue.UNKNOWN)); + assertThat(group.partitionTimestamp(partition2), is(RecordQueue.UNKNOWN)); + assertThat(group.headRecordOffset(partition1), is(1L)); + assertThat(group.headRecordOffset(partition2), is(2L)); + assertThat(group.streamTime(), is(RecordQueue.UNKNOWN)); + assertThat(metrics.metric(lastLatenessValue).metricValue(), is(0.0)); + + // get one record, now the time should be advanced + record = group.nextRecord(info, time.milliseconds()); + // 1:[3, 5] + // 2:[2, 4, 6] + // st: 1 + assertThat(info.partition(), equalTo(partition1)); + assertThat(group.partitionTimestamp(partition1), is(1L)); + assertThat(group.partitionTimestamp(partition2), is(RecordQueue.UNKNOWN)); + assertThat(group.headRecordOffset(partition1), is(3L)); + assertThat(group.headRecordOffset(partition2), is(2L)); + verifyTimes(record, 1L, 1L, group); + assertThat(metrics.metric(lastLatenessValue).metricValue(), is(0.0)); + + // get one record, now the time should be advanced + record = group.nextRecord(info, time.milliseconds()); + // 1:[3, 5] + // 2:[4, 6] + // st: 2 + assertThat(info.partition(), equalTo(partition2)); + assertThat(group.partitionTimestamp(partition1), is(1L)); + assertThat(group.partitionTimestamp(partition2), is(2L)); + assertThat(group.headRecordOffset(partition1), is(3L)); + assertThat(group.headRecordOffset(partition2), is(4L)); + verifyTimes(record, 2L, 2L, group); + verifyBuffered(4, 2, 2, group); + assertEquals(0.0, metrics.metric(lastLatenessValue).metricValue()); + } + + private void 
testSecondBatch(final PartitionGroup group) { + StampedRecord record; + final PartitionGroup.RecordInfo info = new PartitionGroup.RecordInfo(); + + // add 2 more records with timestamp 2, 4 to partition-1 + final List> list3 = Arrays.asList( + new ConsumerRecord<>("topic", 1, 2L, recordKey, recordValue), + new ConsumerRecord<>("topic", 1, 4L, recordKey, recordValue)); + + group.addRawRecords(partition1, list3); + // 1:[3, 5, 2, 4] + // 2:[4, 6] + // st: 2 (just adding records shouldn't change it) + verifyBuffered(6, 4, 2, group); + assertThat(group.partitionTimestamp(partition1), is(1L)); + assertThat(group.partitionTimestamp(partition2), is(2L)); + assertThat(group.headRecordOffset(partition1), is(3L)); + assertThat(group.headRecordOffset(partition2), is(4L)); + assertThat(group.streamTime(), is(2L)); + assertThat(metrics.metric(lastLatenessValue).metricValue(), is(0.0)); + + // get one record, time should be advanced + record = group.nextRecord(info, time.milliseconds()); + // 1:[5, 2, 4] + // 2:[4, 6] + // st: 3 + assertThat(info.partition(), equalTo(partition1)); + assertThat(group.partitionTimestamp(partition1), is(3L)); + assertThat(group.partitionTimestamp(partition2), is(2L)); + assertThat(group.headRecordOffset(partition1), is(5L)); + assertThat(group.headRecordOffset(partition2), is(4L)); + verifyTimes(record, 3L, 3L, group); + verifyBuffered(5, 3, 2, group); + assertThat(metrics.metric(lastLatenessValue).metricValue(), is(0.0)); + + // get one record, time should be advanced + record = group.nextRecord(info, time.milliseconds()); + // 1:[5, 2, 4] + // 2:[6] + // st: 4 + assertThat(info.partition(), equalTo(partition2)); + assertThat(group.partitionTimestamp(partition1), is(3L)); + assertThat(group.partitionTimestamp(partition2), is(4L)); + assertThat(group.headRecordOffset(partition1), is(5L)); + assertThat(group.headRecordOffset(partition2), is(6L)); + verifyTimes(record, 4L, 4L, group); + verifyBuffered(4, 3, 1, group); + assertThat(metrics.metric(lastLatenessValue).metricValue(), is(0.0)); + + // get one more record, time should be advanced + record = group.nextRecord(info, time.milliseconds()); + // 1:[2, 4] + // 2:[6] + // st: 5 + assertThat(info.partition(), equalTo(partition1)); + assertThat(group.partitionTimestamp(partition1), is(5L)); + assertThat(group.partitionTimestamp(partition2), is(4L)); + assertThat(group.headRecordOffset(partition1), is(2L)); + assertThat(group.headRecordOffset(partition2), is(6L)); + verifyTimes(record, 5L, 5L, group); + verifyBuffered(3, 2, 1, group); + assertThat(metrics.metric(lastLatenessValue).metricValue(), is(0.0)); + + // get one more record, time should not be advanced + record = group.nextRecord(info, time.milliseconds()); + // 1:[4] + // 2:[6] + // st: 5 + assertThat(info.partition(), equalTo(partition1)); + assertThat(group.partitionTimestamp(partition1), is(5L)); + assertThat(group.partitionTimestamp(partition2), is(4L)); + assertThat(group.headRecordOffset(partition1), is(4L)); + assertThat(group.headRecordOffset(partition2), is(6L)); + verifyTimes(record, 2L, 5L, group); + verifyBuffered(2, 1, 1, group); + assertThat(metrics.metric(lastLatenessValue).metricValue(), is(3.0)); + + // get one more record, time should not be advanced + record = group.nextRecord(info, time.milliseconds()); + // 1:[] + // 2:[6] + // st: 5 + assertThat(info.partition(), equalTo(partition1)); + assertThat(group.partitionTimestamp(partition1), is(5L)); + assertThat(group.partitionTimestamp(partition2), is(4L)); + 
assertNull(group.headRecordOffset(partition1)); + assertThat(group.headRecordOffset(partition2), is(6L)); + verifyTimes(record, 4L, 5L, group); + verifyBuffered(1, 0, 1, group); + assertThat(metrics.metric(lastLatenessValue).metricValue(), is(1.0)); + + // get one more record, time should be advanced + record = group.nextRecord(info, time.milliseconds()); + // 1:[] + // 2:[] + // st: 6 + assertThat(info.partition(), equalTo(partition2)); + assertThat(group.partitionTimestamp(partition1), is(5L)); + assertThat(group.partitionTimestamp(partition2), is(6L)); + assertNull(group.headRecordOffset(partition1)); + assertNull(group.headRecordOffset(partition2)); + verifyTimes(record, 6L, 6L, group); + verifyBuffered(0, 0, 0, group); + assertThat(metrics.metric(lastLatenessValue).metricValue(), is(0.0)); + } + + @Test + public void shouldChooseNextRecordBasedOnHeadTimestamp() { + final PartitionGroup group = getBasicGroup(); + + assertEquals(0, group.numBuffered()); + + // add three 3 records with timestamp 1, 5, 3 to partition-1 + final List> list1 = Arrays.asList( + new ConsumerRecord<>("topic", 1, 1L, recordKey, recordValue), + new ConsumerRecord<>("topic", 1, 5L, recordKey, recordValue), + new ConsumerRecord<>("topic", 1, 3L, recordKey, recordValue)); + + group.addRawRecords(partition1, list1); + + verifyBuffered(3, 3, 0, group); + assertEquals(-1L, group.streamTime()); + assertEquals(0.0, metrics.metric(lastLatenessValue).metricValue()); + + StampedRecord record; + final PartitionGroup.RecordInfo info = new PartitionGroup.RecordInfo(); + + // get first two records from partition 1 + record = group.nextRecord(info, time.milliseconds()); + assertEquals(record.timestamp, 1L); + record = group.nextRecord(info, time.milliseconds()); + assertEquals(record.timestamp, 5L); + + // add three 3 records with timestamp 2, 4, 6 to partition-2 + final List> list2 = Arrays.asList( + new ConsumerRecord<>("topic", 2, 2L, recordKey, recordValue), + new ConsumerRecord<>("topic", 2, 4L, recordKey, recordValue), + new ConsumerRecord<>("topic", 2, 6L, recordKey, recordValue)); + + group.addRawRecords(partition2, list2); + // 1:[3] + // 2:[2, 4, 6] + + // get one record, next record should be ts=2 from partition 2 + record = group.nextRecord(info, time.milliseconds()); + // 1:[3] + // 2:[4, 6] + assertEquals(record.timestamp, 2L); + + // get one record, next up should have ts=3 from partition 1 (even though it has seen a larger max timestamp =5) + record = group.nextRecord(info, time.milliseconds()); + // 1:[] + // 2:[4, 6] + assertEquals(record.timestamp, 3L); + } + + private void verifyTimes(final StampedRecord record, + final long recordTime, + final long streamTime, + final PartitionGroup group) { + assertThat(record.timestamp, is(recordTime)); + assertThat(group.streamTime(), is(streamTime)); + } + + private void verifyBuffered(final int totalBuffered, + final int partitionOneBuffered, + final int partitionTwoBuffered, + final PartitionGroup group) { + assertEquals(totalBuffered, group.numBuffered()); + assertEquals(partitionOneBuffered, group.numBuffered(partition1)); + assertEquals(partitionTwoBuffered, group.numBuffered(partition2)); + } + + @Test + public void shouldSetPartitionTimestampAndStreamTime() { + final PartitionGroup group = getBasicGroup(); + + group.setPartitionTime(partition1, 100L); + assertEquals(100L, group.partitionTimestamp(partition1)); + assertEquals(100L, group.streamTime()); + group.setPartitionTime(partition2, 50L); + assertEquals(50L, group.partitionTimestamp(partition2)); + 
assertEquals(100L, group.streamTime()); + } + + @Test + public void shouldThrowIllegalStateExceptionUponAddRecordsIfPartitionUnknown() { + final PartitionGroup group = getBasicGroup(); + + final IllegalStateException exception = assertThrows( + IllegalStateException.class, + () -> group.addRawRecords(unknownPartition, null)); + assertThat(errMessage, equalTo(exception.getMessage())); + } + + @Test + public void shouldThrowIllegalStateExceptionUponNumBufferedIfPartitionUnknown() { + final PartitionGroup group = getBasicGroup(); + + final IllegalStateException exception = assertThrows( + IllegalStateException.class, + () -> group.numBuffered(unknownPartition)); + assertThat(errMessage, equalTo(exception.getMessage())); + } + + @Test + public void shouldThrowIllegalStateExceptionUponSetPartitionTimestampIfPartitionUnknown() { + final PartitionGroup group = getBasicGroup(); + + final IllegalStateException exception = assertThrows( + IllegalStateException.class, + () -> group.setPartitionTime(unknownPartition, 0L)); + assertThat(errMessage, equalTo(exception.getMessage())); + } + + @Test + public void shouldThrowIllegalStateExceptionUponGetPartitionTimestampIfPartitionUnknown() { + final PartitionGroup group = getBasicGroup(); + + final IllegalStateException exception = assertThrows( + IllegalStateException.class, + () -> group.partitionTimestamp(unknownPartition)); + assertThat(errMessage, equalTo(exception.getMessage())); + } + + @Test + public void shouldThrowIllegalStateExceptionUponGetHeadRecordOffsetIfPartitionUnknown() { + final PartitionGroup group = getBasicGroup(); + + final IllegalStateException exception = assertThrows( + IllegalStateException.class, + () -> group.headRecordOffset(unknownPartition)); + assertThat(errMessage, equalTo(exception.getMessage())); + } + + @Test + public void shouldEmptyPartitionsOnClear() { + final PartitionGroup group = getBasicGroup(); + + final List> list = Arrays.asList( + new ConsumerRecord<>("topic", 1, 1L, recordKey, recordValue), + new ConsumerRecord<>("topic", 1, 3L, recordKey, recordValue), + new ConsumerRecord<>("topic", 1, 5L, recordKey, recordValue)); + group.addRawRecords(partition1, list); + group.nextRecord(new PartitionGroup.RecordInfo(), time.milliseconds()); + group.nextRecord(new PartitionGroup.RecordInfo(), time.milliseconds()); + + group.clear(); + + assertThat(group.numBuffered(), equalTo(0)); + assertThat(group.streamTime(), equalTo(RecordQueue.UNKNOWN)); + assertThat(group.nextRecord(new PartitionGroup.RecordInfo(), time.milliseconds()), equalTo(null)); + assertThat(group.partitionTimestamp(partition1), equalTo(RecordQueue.UNKNOWN)); + + group.addRawRecords(partition1, list); + } + + @Test + public void shouldUpdatePartitionQueuesShrink() { + final PartitionGroup group = getBasicGroup(); + + final List> list1 = Arrays.asList( + new ConsumerRecord<>("topic", 1, 1L, recordKey, recordValue), + new ConsumerRecord<>("topic", 1, 5L, recordKey, recordValue)); + group.addRawRecords(partition1, list1); + final List> list2 = Arrays.asList( + new ConsumerRecord<>("topic", 2, 2L, recordKey, recordValue), + new ConsumerRecord<>("topic", 2, 4L, recordKey, recordValue), + new ConsumerRecord<>("topic", 2, 6L, recordKey, recordValue)); + group.addRawRecords(partition2, list2); + assertEquals(list1.size() + list2.size(), group.numBuffered()); + assertTrue(group.allPartitionsBufferedLocally()); + group.nextRecord(new PartitionGroup.RecordInfo(), time.milliseconds()); + + // shrink list of queues + group.updatePartitions(mkSet(createPartition2()), p 
-> { + fail("should not create any queues"); + return null; + }); + + assertTrue(group.allPartitionsBufferedLocally()); // because didn't add any new partitions + assertEquals(list2.size(), group.numBuffered()); + assertEquals(1, group.streamTime()); + assertThrows(IllegalStateException.class, () -> group.partitionTimestamp(partition1)); + assertThat(group.nextRecord(new PartitionGroup.RecordInfo(), time.milliseconds()), notNullValue()); // can access buffered records + assertThat(group.partitionTimestamp(partition2), equalTo(2L)); + } + + @Test + public void shouldUpdatePartitionQueuesExpand() { + final PartitionGroup group = new PartitionGroup( + logContext, + mkMap(mkEntry(partition1, queue1)), + tp -> OptionalLong.of(0L), + getValueSensor(metrics, lastLatenessValue), + enforcedProcessingSensor, + maxTaskIdleMs + ); + final List> list1 = Arrays.asList( + new ConsumerRecord<>("topic", 1, 1L, recordKey, recordValue), + new ConsumerRecord<>("topic", 1, 5L, recordKey, recordValue)); + group.addRawRecords(partition1, list1); + + assertEquals(list1.size(), group.numBuffered()); + assertTrue(group.allPartitionsBufferedLocally()); + group.nextRecord(new PartitionGroup.RecordInfo(), time.milliseconds()); + + // expand list of queues + group.updatePartitions(mkSet(createPartition1(), createPartition2()), p -> { + assertEquals(createPartition2(), p); + return createQueue2(); + }); + + assertFalse(group.allPartitionsBufferedLocally()); // because added new partition + assertEquals(1, group.numBuffered()); + assertEquals(1, group.streamTime()); + assertThat(group.partitionTimestamp(partition1), equalTo(1L)); + assertThat(group.partitionTimestamp(partition2), equalTo(RecordQueue.UNKNOWN)); + assertThat(group.nextRecord(new PartitionGroup.RecordInfo(), time.milliseconds()), notNullValue()); // can access buffered records + } + + @Test + public void shouldUpdatePartitionQueuesShrinkAndExpand() { + final PartitionGroup group = new PartitionGroup( + logContext, + mkMap(mkEntry(partition1, queue1)), + tp -> OptionalLong.of(0L), + getValueSensor(metrics, lastLatenessValue), + enforcedProcessingSensor, + maxTaskIdleMs + ); + final List> list1 = Arrays.asList( + new ConsumerRecord<>("topic", 1, 1L, recordKey, recordValue), + new ConsumerRecord<>("topic", 1, 5L, recordKey, recordValue)); + group.addRawRecords(partition1, list1); + assertEquals(list1.size(), group.numBuffered()); + assertTrue(group.allPartitionsBufferedLocally()); + group.nextRecord(new PartitionGroup.RecordInfo(), time.milliseconds()); + + // expand and shrink list of queues + group.updatePartitions(mkSet(createPartition2()), p -> { + assertEquals(createPartition2(), p); + return createQueue2(); + }); + + assertFalse(group.allPartitionsBufferedLocally()); // because added new partition + assertEquals(0, group.numBuffered()); + assertEquals(1, group.streamTime()); + assertThrows(IllegalStateException.class, () -> group.partitionTimestamp(partition1)); + assertThat(group.partitionTimestamp(partition2), equalTo(RecordQueue.UNKNOWN)); + assertThat(group.nextRecord(new PartitionGroup.RecordInfo(), time.milliseconds()), nullValue()); // all available records removed + } + + @Test + public void shouldNeverWaitIfIdlingIsDisabled() { + final PartitionGroup group = new PartitionGroup( + logContext, + mkMap( + mkEntry(partition1, queue1), + mkEntry(partition2, queue2) + ), + tp -> OptionalLong.of(0L), + getValueSensor(metrics, lastLatenessValue), + enforcedProcessingSensor, + StreamsConfig.MAX_TASK_IDLE_MS_DISABLED + ); + + final List> list1 = 
Arrays.asList( + new ConsumerRecord<>("topic", 1, 1L, recordKey, recordValue), + new ConsumerRecord<>("topic", 1, 5L, recordKey, recordValue)); + group.addRawRecords(partition1, list1); + + assertThat(group.allPartitionsBufferedLocally(), is(false)); + try (final LogCaptureAppender appender = LogCaptureAppender.createAndRegister(PartitionGroup.class)) { + LogCaptureAppender.setClassLoggerToTrace(PartitionGroup.class); + assertThat(group.readyToProcess(0L), is(true)); + assertThat( + appender.getEvents(), + hasItem(Matchers.allOf( + Matchers.hasProperty("level", equalTo("TRACE")), + Matchers.hasProperty("message", equalTo( + "[test] Ready for processing because max.task.idle.ms is disabled.\n" + + "\tThere may be out-of-order processing for this task as a result.\n" + + "\tBuffered partitions: [topic-1]\n" + + "\tNon-buffered partitions: [topic-2]" + )) + )) + ); + } + } + + @Test + public void shouldBeReadyIfAllPartitionsAreBuffered() { + final PartitionGroup group = new PartitionGroup( + logContext, + mkMap( + mkEntry(partition1, queue1), + mkEntry(partition2, queue2) + ), + tp -> OptionalLong.of(0L), + getValueSensor(metrics, lastLatenessValue), + enforcedProcessingSensor, + 0L + ); + + final List> list1 = Arrays.asList( + new ConsumerRecord<>("topic", 1, 1L, recordKey, recordValue), + new ConsumerRecord<>("topic", 1, 5L, recordKey, recordValue)); + group.addRawRecords(partition1, list1); + + final List> list2 = Arrays.asList( + new ConsumerRecord<>("topic", 2, 1L, recordKey, recordValue), + new ConsumerRecord<>("topic", 2, 5L, recordKey, recordValue)); + group.addRawRecords(partition2, list2); + + assertThat(group.allPartitionsBufferedLocally(), is(true)); + try (final LogCaptureAppender appender = LogCaptureAppender.createAndRegister(PartitionGroup.class)) { + LogCaptureAppender.setClassLoggerToTrace(PartitionGroup.class); + assertThat(group.readyToProcess(0L), is(true)); + assertThat( + appender.getEvents(), + hasItem(Matchers.allOf( + Matchers.hasProperty("level", equalTo("TRACE")), + Matchers.hasProperty("message", equalTo("[test] All partitions were buffered locally, so this task is ready for processing.")) + )) + ); + } + } + + @Test + public void shouldWaitForFetchesWhenMetadataIsIncomplete() { + final HashMap lags = new HashMap<>(); + final PartitionGroup group = new PartitionGroup( + logContext, + mkMap( + mkEntry(partition1, queue1), + mkEntry(partition2, queue2) + ), + tp -> lags.getOrDefault(tp, OptionalLong.empty()), + getValueSensor(metrics, lastLatenessValue), + enforcedProcessingSensor, + 0L + ); + + final List> list1 = Arrays.asList( + new ConsumerRecord<>("topic", 1, 1L, recordKey, recordValue), + new ConsumerRecord<>("topic", 1, 5L, recordKey, recordValue)); + group.addRawRecords(partition1, list1); + + assertThat(group.allPartitionsBufferedLocally(), is(false)); + try (final LogCaptureAppender appender = LogCaptureAppender.createAndRegister(PartitionGroup.class)) { + LogCaptureAppender.setClassLoggerToTrace(PartitionGroup.class); + assertThat(group.readyToProcess(0L), is(false)); + assertThat( + appender.getEvents(), + hasItem(Matchers.allOf( + Matchers.hasProperty("level", equalTo("TRACE")), + Matchers.hasProperty("message", equalTo("[test] Waiting to fetch data for topic-2")) + )) + ); + } + lags.put(partition2, OptionalLong.of(0L)); + assertThat(group.readyToProcess(0L), is(true)); + } + + @Test + public void shouldWaitForPollWhenLagIsNonzero() { + final HashMap lags = new HashMap<>(); + final PartitionGroup group = new PartitionGroup( + logContext, + mkMap( + 
mkEntry(partition1, queue1), + mkEntry(partition2, queue2) + ), + tp -> lags.getOrDefault(tp, OptionalLong.empty()), + getValueSensor(metrics, lastLatenessValue), + enforcedProcessingSensor, + 0L + ); + + final List> list1 = Arrays.asList( + new ConsumerRecord<>("topic", 1, 1L, recordKey, recordValue), + new ConsumerRecord<>("topic", 1, 5L, recordKey, recordValue)); + group.addRawRecords(partition1, list1); + + lags.put(partition2, OptionalLong.of(1L)); + + assertThat(group.allPartitionsBufferedLocally(), is(false)); + + try (final LogCaptureAppender appender = LogCaptureAppender.createAndRegister(PartitionGroup.class)) { + LogCaptureAppender.setClassLoggerToTrace(PartitionGroup.class); + assertThat(group.readyToProcess(0L), is(false)); + assertThat( + appender.getEvents(), + hasItem(Matchers.allOf( + Matchers.hasProperty("level", equalTo("TRACE")), + Matchers.hasProperty("message", equalTo("[test] Lag for topic-2 is currently 1, but no data is buffered locally. Waiting to buffer some records.")) + )) + ); + } + } + + @Test + public void shouldIdleAsSpecifiedWhenLagIsZero() { + final PartitionGroup group = new PartitionGroup( + logContext, + mkMap( + mkEntry(partition1, queue1), + mkEntry(partition2, queue2) + ), + tp -> OptionalLong.of(0L), + getValueSensor(metrics, lastLatenessValue), + enforcedProcessingSensor, + 1L + ); + + final List> list1 = Arrays.asList( + new ConsumerRecord<>("topic", 1, 1L, recordKey, recordValue), + new ConsumerRecord<>("topic", 1, 5L, recordKey, recordValue)); + group.addRawRecords(partition1, list1); + + assertThat(group.allPartitionsBufferedLocally(), is(false)); + + try (final LogCaptureAppender appender = LogCaptureAppender.createAndRegister(PartitionGroup.class)) { + LogCaptureAppender.setClassLoggerToTrace(PartitionGroup.class); + assertThat(group.readyToProcess(0L), is(false)); + assertThat( + appender.getEvents(), + hasItem(Matchers.allOf( + Matchers.hasProperty("level", equalTo("TRACE")), + Matchers.hasProperty("message", equalTo("[test] Lag for topic-2 is currently 0 and current time is 0. Waiting for new data to be produced for configured idle time 1 (deadline is 1).")) + )) + ); + } + + try (final LogCaptureAppender appender = LogCaptureAppender.createAndRegister(PartitionGroup.class)) { + LogCaptureAppender.setClassLoggerToTrace(PartitionGroup.class); + assertThat(group.readyToProcess(1L), is(true)); + assertThat( + appender.getEvents(), + hasItem(Matchers.allOf( + Matchers.hasProperty("level", equalTo("TRACE")), + Matchers.hasProperty("message", equalTo( + "[test] Continuing to process although some partitions are empty on the broker.\n" + + "\tThere may be out-of-order processing for this task as a result.\n" + + "\tPartitions with local data: [topic-1].\n" + + "\tPartitions we gave up waiting for, with their corresponding deadlines: {topic-2=1}.\n" + + "\tConfigured max.task.idle.ms: 1.\n" + + "\tCurrent wall-clock time: 1." 
+ )) + )) + ); + } + + try (final LogCaptureAppender appender = LogCaptureAppender.createAndRegister(PartitionGroup.class)) { + LogCaptureAppender.setClassLoggerToTrace(PartitionGroup.class); + assertThat(group.readyToProcess(2L), is(true)); + assertThat( + appender.getEvents(), + hasItem(Matchers.allOf( + Matchers.hasProperty("level", equalTo("TRACE")), + Matchers.hasProperty("message", equalTo( + "[test] Continuing to process although some partitions are empty on the broker.\n" + + "\tThere may be out-of-order processing for this task as a result.\n" + + "\tPartitions with local data: [topic-1].\n" + + "\tPartitions we gave up waiting for, with their corresponding deadlines: {topic-2=1}.\n" + + "\tConfigured max.task.idle.ms: 1.\n" + + "\tCurrent wall-clock time: 2." + )) + )) + ); + } + } + + private PartitionGroup getBasicGroup() { + return new PartitionGroup( + logContext, + mkMap( + mkEntry(partition1, queue1), + mkEntry(partition2, queue2) + ), + tp -> OptionalLong.of(0L), + getValueSensor(metrics, lastLatenessValue), + enforcedProcessingSensor, + maxTaskIdleMs + ); + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/PartitionGrouperTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/PartitionGrouperTest.java new file mode 100644 index 0000000..93d2dce --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/PartitionGrouperTest.java @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.processor.internals; + +import org.apache.kafka.common.Cluster; +import org.apache.kafka.common.Node; +import org.apache.kafka.common.PartitionInfo; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.streams.processor.TaskId; +import org.apache.kafka.streams.processor.internals.TopologyMetadata.Subtopology; + +import org.junit.Test; + +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import static org.apache.kafka.common.utils.Utils.mkSet; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.SUBTOPOLOGY_0; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.SUBTOPOLOGY_1; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; + +public class PartitionGrouperTest { + + private final List infos = Arrays.asList( + new PartitionInfo("topic1", 0, Node.noNode(), new Node[0], new Node[0]), + new PartitionInfo("topic1", 1, Node.noNode(), new Node[0], new Node[0]), + new PartitionInfo("topic1", 2, Node.noNode(), new Node[0], new Node[0]), + new PartitionInfo("topic2", 0, Node.noNode(), new Node[0], new Node[0]), + new PartitionInfo("topic2", 1, Node.noNode(), new Node[0], new Node[0]) + ); + + private final Cluster metadata = new Cluster( + "cluster", + Collections.singletonList(Node.noNode()), + infos, + Collections.emptySet(), + Collections.emptySet()); + + @Test + public void shouldComputeGroupingForTwoGroups() { + final PartitionGrouper grouper = new PartitionGrouper(); + final Map> expectedPartitionsForTask = new HashMap<>(); + final Map> topicGroups = new HashMap<>(); + + topicGroups.put(SUBTOPOLOGY_0, mkSet("topic1")); + expectedPartitionsForTask.put(new TaskId(SUBTOPOLOGY_0.nodeGroupId, 0, SUBTOPOLOGY_0.namedTopology), mkSet(new TopicPartition("topic1", 0))); + expectedPartitionsForTask.put(new TaskId(SUBTOPOLOGY_0.nodeGroupId, 1, SUBTOPOLOGY_0.namedTopology), mkSet(new TopicPartition("topic1", 1))); + expectedPartitionsForTask.put(new TaskId(SUBTOPOLOGY_0.nodeGroupId, 2, SUBTOPOLOGY_0.namedTopology), mkSet(new TopicPartition("topic1", 2))); + + topicGroups.put(SUBTOPOLOGY_1, mkSet("topic2")); + expectedPartitionsForTask.put(new TaskId(SUBTOPOLOGY_1.nodeGroupId, 0, SUBTOPOLOGY_1.namedTopology), mkSet(new TopicPartition("topic2", 0))); + expectedPartitionsForTask.put(new TaskId(SUBTOPOLOGY_1.nodeGroupId, 1, SUBTOPOLOGY_1.namedTopology), mkSet(new TopicPartition("topic2", 1))); + + assertEquals(expectedPartitionsForTask, grouper.partitionGroups(topicGroups, metadata)); + } + + @Test + public void shouldComputeGroupingForSingleGroupWithMultipleTopics() { + final PartitionGrouper grouper = new PartitionGrouper(); + final Map> expectedPartitionsForTask = new HashMap<>(); + final Map> topicGroups = new HashMap<>(); + + topicGroups.put(SUBTOPOLOGY_0, mkSet("topic1", "topic2")); + expectedPartitionsForTask.put( + new TaskId(SUBTOPOLOGY_0.nodeGroupId, 0, SUBTOPOLOGY_0.namedTopology), + mkSet(new TopicPartition("topic1", 0), new TopicPartition("topic2", 0))); + expectedPartitionsForTask.put( + new TaskId(SUBTOPOLOGY_0.nodeGroupId, 1, SUBTOPOLOGY_0.namedTopology), + mkSet(new TopicPartition("topic1", 1), new TopicPartition("topic2", 1))); + expectedPartitionsForTask.put( + new TaskId(SUBTOPOLOGY_0.nodeGroupId, 2, SUBTOPOLOGY_0.namedTopology), + mkSet(new TopicPartition("topic1", 2))); + + 
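The expected task assignments built in these tests encode the default grouping rule: within one subtopology the number of tasks equals the largest partition count among its source topics, and task i reads partition i of every topic that has such a partition (which is why the third task ends up with only topic1's partition 2). A small self-contained sketch of that rule, using plain strings instead of the real TaskId/TopicPartition types:

```java
import java.util.HashMap;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;
import java.util.TreeSet;

final class GroupingSketch {
    // Illustrative only: turn "topic -> partition count" into per-task partition sets.
    static Map<Integer, Set<String>> group(final Map<String, Integer> partitionCounts) {
        final int numTasks = partitionCounts.values().stream().mapToInt(Integer::intValue).max().orElse(0);
        final Map<Integer, Set<String>> tasks = new TreeMap<>();
        for (int taskId = 0; taskId < numTasks; taskId++) {
            final Set<String> partitions = new TreeSet<>();
            for (final Map.Entry<String, Integer> topic : partitionCounts.entrySet()) {
                if (taskId < topic.getValue()) {
                    partitions.add(topic.getKey() + "-" + taskId);
                }
            }
            tasks.put(taskId, partitions);
        }
        return tasks;
    }

    public static void main(final String[] args) {
        final Map<String, Integer> counts = new HashMap<>();
        counts.put("topic1", 3); // partitions 0, 1, 2
        counts.put("topic2", 2); // partitions 0, 1
        // prints {0=[topic1-0, topic2-0], 1=[topic1-1, topic2-1], 2=[topic1-2]}
        System.out.println(group(counts));
    }
}
```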
assertEquals(expectedPartitionsForTask, grouper.partitionGroups(topicGroups, metadata)); + } + + @Test + public void shouldNotCreateAnyTasksBecauseOneTopicHasUnknownPartitions() { + final PartitionGrouper grouper = new PartitionGrouper(); + final Map> topicGroups = new HashMap<>(); + + topicGroups.put(SUBTOPOLOGY_0, mkSet("topic1", "unknownTopic", "topic2")); + assertThrows(RuntimeException.class, () -> grouper.partitionGroups(topicGroups, metadata)); + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/ProcessorContextImplTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/ProcessorContextImplTest.java new file mode 100644 index 0000000..cb7949a --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/ProcessorContextImplTest.java @@ -0,0 +1,753 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.processor.internals; + +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.kstream.Windowed; +import org.apache.kafka.streams.processor.ProcessorContext; +import org.apache.kafka.streams.processor.PunctuationType; +import org.apache.kafka.streams.processor.StateStore; +import org.apache.kafka.streams.processor.StateStoreContext; +import org.apache.kafka.streams.processor.TaskId; +import org.apache.kafka.streams.processor.To; +import org.apache.kafka.streams.processor.internals.Task.TaskType; +import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl; +import org.apache.kafka.streams.state.KeyValueIterator; +import org.apache.kafka.streams.state.KeyValueStore; +import org.apache.kafka.streams.state.SessionStore; +import org.apache.kafka.streams.state.TimestampedKeyValueStore; +import org.apache.kafka.streams.state.TimestampedWindowStore; +import org.apache.kafka.streams.state.ValueAndTimestamp; +import org.apache.kafka.streams.state.WindowStore; +import org.apache.kafka.streams.state.WindowStoreIterator; +import org.apache.kafka.streams.state.internals.ThreadCache; +import org.easymock.EasyMock; +import org.junit.Before; +import org.junit.Test; + +import java.time.Duration; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.function.Consumer; + +import static java.util.Arrays.asList; +import static org.apache.kafka.streams.processor.internals.ProcessorContextImpl.BYTEARRAY_VALUE_SERIALIZER; +import static org.apache.kafka.streams.processor.internals.ProcessorContextImpl.BYTES_KEY_SERIALIZER; +import static org.easymock.EasyMock.anyLong; +import static 
org.easymock.EasyMock.anyObject; +import static org.easymock.EasyMock.anyString; +import static org.easymock.EasyMock.expect; +import static org.easymock.EasyMock.expectLastCall; +import static org.easymock.EasyMock.mock; +import static org.easymock.EasyMock.replay; +import static org.easymock.EasyMock.verify; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +public class ProcessorContextImplTest { + private ProcessorContextImpl context; + + private final StreamsConfig streamsConfig = streamsConfigMock(); + + private final RecordCollector recordCollector = mock(RecordCollector.class); + + private static final String KEY = "key"; + private static final Bytes KEY_BYTES = Bytes.wrap(KEY.getBytes()); + private static final long VALUE = 42L; + private static final byte[] VALUE_BYTES = String.valueOf(VALUE).getBytes(); + private static final long TIMESTAMP = 21L; + private static final long STREAM_TIME = 50L; + private static final ValueAndTimestamp VALUE_AND_TIMESTAMP = ValueAndTimestamp.make(42L, 21L); + private static final String STORE_NAME = "underlying-store"; + private static final String REGISTERED_STORE_NAME = "registered-store"; + private static final TopicPartition CHANGELOG_PARTITION = new TopicPartition("store-changelog", 1); + + private boolean flushExecuted; + private boolean putExecuted; + private boolean putWithTimestampExecuted; + private boolean putIfAbsentExecuted; + private boolean putAllExecuted; + private boolean deleteExecuted; + private boolean removeExecuted; + + private KeyValueIterator rangeIter; + private KeyValueIterator> timestampedRangeIter; + private KeyValueIterator allIter; + private KeyValueIterator> timestampedAllIter; + + private final List, Long>> iters = new ArrayList<>(7); + private final List, ValueAndTimestamp>> timestampedIters = new ArrayList<>(7); + private WindowStoreIterator windowStoreIter; + + @Before + public void setup() { + flushExecuted = false; + putExecuted = false; + putIfAbsentExecuted = false; + putAllExecuted = false; + deleteExecuted = false; + removeExecuted = false; + + rangeIter = mock(KeyValueIterator.class); + timestampedRangeIter = mock(KeyValueIterator.class); + allIter = mock(KeyValueIterator.class); + timestampedAllIter = mock(KeyValueIterator.class); + windowStoreIter = mock(WindowStoreIterator.class); + + for (int i = 0; i < 7; i++) { + iters.add(i, mock(KeyValueIterator.class)); + timestampedIters.add(i, mock(KeyValueIterator.class)); + } + + final ProcessorStateManager stateManager = mock(ProcessorStateManager.class); + expect(stateManager.taskType()).andStubReturn(TaskType.ACTIVE); + + expect(stateManager.getGlobalStore("GlobalKeyValueStore")).andReturn(keyValueStoreMock()); + expect(stateManager.getGlobalStore("GlobalTimestampedKeyValueStore")).andReturn(timestampedKeyValueStoreMock()); + expect(stateManager.getGlobalStore("GlobalWindowStore")).andReturn(windowStoreMock()); + expect(stateManager.getGlobalStore("GlobalTimestampedWindowStore")).andReturn(timestampedWindowStoreMock()); + expect(stateManager.getGlobalStore("GlobalSessionStore")).andReturn(sessionStoreMock()); + expect(stateManager.getGlobalStore(anyString())).andReturn(null); + expect(stateManager.getStore("LocalKeyValueStore")).andReturn(keyValueStoreMock()); + expect(stateManager.getStore("LocalTimestampedKeyValueStore")).andReturn(timestampedKeyValueStoreMock()); + 
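The setup above leans on EasyMock's record/replay/verify lifecycle: every expect(...) call records an expectation on the stateManager mock, and nothing is actually answered until replay(...) a few lines further down. A standalone mini-example of that lifecycle, using a made-up Greeter interface rather than the real Streams types:

```java
import static org.easymock.EasyMock.expect;
import static org.easymock.EasyMock.mock;
import static org.easymock.EasyMock.replay;
import static org.easymock.EasyMock.verify;

final class EasyMockLifecycleSketch {
    interface Greeter {
        String greet(String name);
    }

    public static void main(final String[] args) {
        final Greeter greeter = mock(Greeter.class);
        expect(greeter.greet("kafka")).andReturn("hello kafka"); // record phase
        replay(greeter);                                         // switch the mock to replay mode
        System.out.println(greeter.greet("kafka"));              // prints "hello kafka"
        verify(greeter);                                         // every recorded expectation was used
    }
}
```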
expect(stateManager.getStore("LocalWindowStore")).andReturn(windowStoreMock()); + expect(stateManager.getStore("LocalTimestampedWindowStore")).andReturn(timestampedWindowStoreMock()); + expect(stateManager.getStore("LocalSessionStore")).andReturn(sessionStoreMock()); + expect(stateManager.registeredChangelogPartitionFor(REGISTERED_STORE_NAME)).andStubReturn(CHANGELOG_PARTITION); + + replay(stateManager); + + context = new ProcessorContextImpl( + mock(TaskId.class), + streamsConfig, + stateManager, + mock(StreamsMetricsImpl.class), + mock(ThreadCache.class) + ); + + final StreamTask task = mock(StreamTask.class); + expect(task.streamTime()).andReturn(STREAM_TIME); + EasyMock.expect(task.recordCollector()).andStubReturn(recordCollector); + replay(task); + ((InternalProcessorContext) context).transitionToActive(task, null, null); + + context.setCurrentNode( + new ProcessorNode<>( + "fake", + (org.apache.kafka.streams.processor.api.Processor) null, + new HashSet<>( + asList( + "LocalKeyValueStore", + "LocalTimestampedKeyValueStore", + "LocalWindowStore", + "LocalTimestampedWindowStore", + "LocalSessionStore" + ) + ) + ) + ); + } + + private ProcessorContextImpl getStandbyContext() { + final ProcessorStateManager stateManager = EasyMock.createNiceMock(ProcessorStateManager.class); + expect(stateManager.taskType()).andStubReturn(TaskType.STANDBY); + replay(stateManager); + return new ProcessorContextImpl( + mock(TaskId.class), + streamsConfig, + stateManager, + mock(StreamsMetricsImpl.class), + mock(ThreadCache.class) + ); + } + + @Test + public void globalKeyValueStoreShouldBeReadOnly() { + doTest("GlobalKeyValueStore", (Consumer>) store -> { + verifyStoreCannotBeInitializedOrClosed(store); + + checkThrowsUnsupportedOperation(store::flush, "flush()"); + checkThrowsUnsupportedOperation(() -> store.put("1", 1L), "put()"); + checkThrowsUnsupportedOperation(() -> store.putIfAbsent("1", 1L), "putIfAbsent()"); + checkThrowsUnsupportedOperation(() -> store.putAll(Collections.emptyList()), "putAll()"); + checkThrowsUnsupportedOperation(() -> store.delete("1"), "delete()"); + + assertEquals((Long) VALUE, store.get(KEY)); + assertEquals(rangeIter, store.range("one", "two")); + assertEquals(allIter, store.all()); + assertEquals(VALUE, store.approximateNumEntries()); + }); + } + + @Test + public void globalTimestampedKeyValueStoreShouldBeReadOnly() { + doTest("GlobalTimestampedKeyValueStore", (Consumer>) store -> { + verifyStoreCannotBeInitializedOrClosed(store); + + checkThrowsUnsupportedOperation(store::flush, "flush()"); + checkThrowsUnsupportedOperation(() -> store.put("1", ValueAndTimestamp.make(1L, 2L)), "put()"); + checkThrowsUnsupportedOperation(() -> store.putIfAbsent("1", ValueAndTimestamp.make(1L, 2L)), "putIfAbsent()"); + checkThrowsUnsupportedOperation(() -> store.putAll(Collections.emptyList()), "putAll()"); + checkThrowsUnsupportedOperation(() -> store.delete("1"), "delete()"); + + assertEquals(VALUE_AND_TIMESTAMP, store.get(KEY)); + assertEquals(timestampedRangeIter, store.range("one", "two")); + assertEquals(timestampedAllIter, store.all()); + assertEquals(VALUE, store.approximateNumEntries()); + }); + } + + @Test + public void globalWindowStoreShouldBeReadOnly() { + doTest("GlobalWindowStore", (Consumer>) store -> { + verifyStoreCannotBeInitializedOrClosed(store); + + checkThrowsUnsupportedOperation(store::flush, "flush()"); + checkThrowsUnsupportedOperation(() -> store.put("1", 1L, 1L), "put()"); + + assertEquals(iters.get(0), store.fetchAll(0L, 0L)); + assertEquals(windowStoreIter, 
store.fetch(KEY, 0L, 1L)); + assertEquals(iters.get(1), store.fetch(KEY, KEY, 0L, 1L)); + assertEquals((Long) VALUE, store.fetch(KEY, 1L)); + assertEquals(iters.get(2), store.all()); + }); + } + + + @Test + public void globalTimestampedWindowStoreShouldBeReadOnly() { + doTest("GlobalTimestampedWindowStore", (Consumer>) store -> { + verifyStoreCannotBeInitializedOrClosed(store); + + checkThrowsUnsupportedOperation(store::flush, "flush()"); + checkThrowsUnsupportedOperation(() -> store.put("1", ValueAndTimestamp.make(1L, 1L), 1L), "put() [with timestamp]"); + + assertEquals(timestampedIters.get(0), store.fetchAll(0L, 0L)); + assertEquals(windowStoreIter, store.fetch(KEY, 0L, 1L)); + assertEquals(timestampedIters.get(1), store.fetch(KEY, KEY, 0L, 1L)); + assertEquals(VALUE_AND_TIMESTAMP, store.fetch(KEY, 1L)); + assertEquals(timestampedIters.get(2), store.all()); + }); + } + + @Test + public void globalSessionStoreShouldBeReadOnly() { + doTest("GlobalSessionStore", (Consumer>) store -> { + verifyStoreCannotBeInitializedOrClosed(store); + + checkThrowsUnsupportedOperation(store::flush, "flush()"); + checkThrowsUnsupportedOperation(() -> store.remove(null), "remove()"); + checkThrowsUnsupportedOperation(() -> store.put(null, null), "put()"); + + assertEquals(iters.get(3), store.findSessions(KEY, 1L, 2L)); + assertEquals(iters.get(4), store.findSessions(KEY, KEY, 1L, 2L)); + assertEquals(iters.get(5), store.fetch(KEY)); + assertEquals(iters.get(6), store.fetch(KEY, KEY)); + }); + } + + @Test + public void localKeyValueStoreShouldNotAllowInitOrClose() { + doTest("LocalKeyValueStore", (Consumer>) store -> { + verifyStoreCannotBeInitializedOrClosed(store); + + store.flush(); + assertTrue(flushExecuted); + + store.put("1", 1L); + assertTrue(putExecuted); + + store.putIfAbsent("1", 1L); + assertTrue(putIfAbsentExecuted); + + store.putAll(Collections.emptyList()); + assertTrue(putAllExecuted); + + store.delete("1"); + assertTrue(deleteExecuted); + + assertEquals((Long) VALUE, store.get(KEY)); + assertEquals(rangeIter, store.range("one", "two")); + assertEquals(allIter, store.all()); + assertEquals(VALUE, store.approximateNumEntries()); + }); + } + + @Test + public void localTimestampedKeyValueStoreShouldNotAllowInitOrClose() { + doTest("LocalTimestampedKeyValueStore", (Consumer>) store -> { + verifyStoreCannotBeInitializedOrClosed(store); + + store.flush(); + assertTrue(flushExecuted); + + store.put("1", ValueAndTimestamp.make(1L, 2L)); + assertTrue(putExecuted); + + store.putIfAbsent("1", ValueAndTimestamp.make(1L, 2L)); + assertTrue(putIfAbsentExecuted); + + store.putAll(Collections.emptyList()); + assertTrue(putAllExecuted); + + store.delete("1"); + assertTrue(deleteExecuted); + + assertEquals(VALUE_AND_TIMESTAMP, store.get(KEY)); + assertEquals(timestampedRangeIter, store.range("one", "two")); + assertEquals(timestampedAllIter, store.all()); + assertEquals(VALUE, store.approximateNumEntries()); + }); + } + + @Test + public void localWindowStoreShouldNotAllowInitOrClose() { + doTest("LocalWindowStore", (Consumer>) store -> { + verifyStoreCannotBeInitializedOrClosed(store); + + store.flush(); + assertTrue(flushExecuted); + + store.put("1", 1L, 1L); + assertTrue(putExecuted); + + assertEquals(iters.get(0), store.fetchAll(0L, 0L)); + assertEquals(windowStoreIter, store.fetch(KEY, 0L, 1L)); + assertEquals(iters.get(1), store.fetch(KEY, KEY, 0L, 1L)); + assertEquals((Long) VALUE, store.fetch(KEY, 1L)); + assertEquals(iters.get(2), store.all()); + }); + } + + @Test + public void 
localTimestampedWindowStoreShouldNotAllowInitOrClose() { + doTest("LocalTimestampedWindowStore", (Consumer>) store -> { + verifyStoreCannotBeInitializedOrClosed(store); + + store.flush(); + assertTrue(flushExecuted); + + store.put("1", ValueAndTimestamp.make(1L, 1L), 1L); + assertTrue(putExecuted); + + store.put("1", ValueAndTimestamp.make(1L, 1L), 1L); + assertTrue(putWithTimestampExecuted); + + assertEquals(timestampedIters.get(0), store.fetchAll(0L, 0L)); + assertEquals(windowStoreIter, store.fetch(KEY, 0L, 1L)); + assertEquals(timestampedIters.get(1), store.fetch(KEY, KEY, 0L, 1L)); + assertEquals(VALUE_AND_TIMESTAMP, store.fetch(KEY, 1L)); + assertEquals(timestampedIters.get(2), store.all()); + }); + } + + @Test + public void localSessionStoreShouldNotAllowInitOrClose() { + doTest("LocalSessionStore", (Consumer>) store -> { + verifyStoreCannotBeInitializedOrClosed(store); + + store.flush(); + assertTrue(flushExecuted); + + store.remove(null); + assertTrue(removeExecuted); + + store.put(null, null); + assertTrue(putExecuted); + + assertEquals(iters.get(3), store.findSessions(KEY, 1L, 2L)); + assertEquals(iters.get(4), store.findSessions(KEY, KEY, 1L, 2L)); + assertEquals(iters.get(5), store.fetch(KEY)); + assertEquals(iters.get(6), store.fetch(KEY, KEY)); + }); + } + + @Test + public void shouldNotSendRecordHeadersToChangelogTopic() { + recordCollector.send( + CHANGELOG_PARTITION.topic(), + KEY_BYTES, + VALUE_BYTES, + null, + CHANGELOG_PARTITION.partition(), + TIMESTAMP, + BYTES_KEY_SERIALIZER, + BYTEARRAY_VALUE_SERIALIZER + ); + + final StreamTask task = EasyMock.createNiceMock(StreamTask.class); + + replay(recordCollector, task); + context.transitionToActive(task, recordCollector, null); + context.logChange(REGISTERED_STORE_NAME, KEY_BYTES, VALUE_BYTES, TIMESTAMP); + + verify(recordCollector); + } + + @Test + public void shouldThrowUnsupportedOperationExceptionOnLogChange() { + context = getStandbyContext(); + assertThrows( + UnsupportedOperationException.class, + () -> context.logChange("Store", Bytes.wrap("k".getBytes()), null, 0L) + ); + } + + @Test + public void shouldThrowUnsupportedOperationExceptionOnGetStateStore() { + context = getStandbyContext(); + assertThrows( + UnsupportedOperationException.class, + () -> context.getStateStore("store") + ); + } + + @Test + public void shouldThrowUnsupportedOperationExceptionOnForward() { + context = getStandbyContext(); + assertThrows( + UnsupportedOperationException.class, + () -> context.forward("key", "value") + ); + } + + @Test + public void shouldThrowUnsupportedOperationExceptionOnForwardWithTo() { + context = getStandbyContext(); + assertThrows( + UnsupportedOperationException.class, + () -> context.forward("key", "value", To.child("child-name")) + ); + } + + @Test + public void shouldThrowUnsupportedOperationExceptionOnCommit() { + context = getStandbyContext(); + assertThrows( + UnsupportedOperationException.class, + () -> context.commit() + ); + } + + @Test + public void shouldThrowUnsupportedOperationExceptionOnSchedule() { + context = getStandbyContext(); + assertThrows( + UnsupportedOperationException.class, + () -> context.schedule(Duration.ofMillis(100L), PunctuationType.STREAM_TIME, t -> { }) + ); + } + + @Test + public void shouldThrowUnsupportedOperationExceptionOnTopic() { + context = getStandbyContext(); + assertThrows( + UnsupportedOperationException.class, + () -> context.topic() + ); + } + @Test + public void shouldThrowUnsupportedOperationExceptionOnPartition() { + context = getStandbyContext(); + assertThrows( + 
UnsupportedOperationException.class, + () -> context.partition() + ); + } + + @Test + public void shouldThrowUnsupportedOperationExceptionOnOffset() { + context = getStandbyContext(); + assertThrows( + UnsupportedOperationException.class, + () -> context.offset() + ); + } + + @Test + public void shouldThrowUnsupportedOperationExceptionOnTimestamp() { + context = getStandbyContext(); + assertThrows( + UnsupportedOperationException.class, + () -> context.timestamp() + ); + } + + @Test + public void shouldThrowUnsupportedOperationExceptionOnCurrentNode() { + context = getStandbyContext(); + assertThrows( + UnsupportedOperationException.class, + () -> context.currentNode() + ); + } + + @Test + public void shouldThrowUnsupportedOperationExceptionOnSetRecordContext() { + context = getStandbyContext(); + assertThrows( + UnsupportedOperationException.class, + () -> context.setRecordContext(mock(ProcessorRecordContext.class)) + ); + } + + @Test + public void shouldThrowUnsupportedOperationExceptionOnRecordContext() { + context = getStandbyContext(); + assertThrows( + UnsupportedOperationException.class, + () -> context.recordContext() + ); + } + + @Test + public void shouldMatchStreamTime() { + assertEquals(STREAM_TIME, context.currentStreamTimeMs()); + } + + @SuppressWarnings("unchecked") + private KeyValueStore keyValueStoreMock() { + final KeyValueStore keyValueStoreMock = mock(KeyValueStore.class); + + initStateStoreMock(keyValueStoreMock); + + expect(keyValueStoreMock.get(KEY)).andReturn(VALUE); + expect(keyValueStoreMock.approximateNumEntries()).andReturn(VALUE); + + expect(keyValueStoreMock.range("one", "two")).andReturn(rangeIter); + expect(keyValueStoreMock.all()).andReturn(allIter); + + + keyValueStoreMock.put(anyString(), anyLong()); + expectLastCall().andAnswer(() -> { + putExecuted = true; + return null; + }); + + keyValueStoreMock.putIfAbsent(anyString(), anyLong()); + expectLastCall().andAnswer(() -> { + putIfAbsentExecuted = true; + return null; + }); + + keyValueStoreMock.putAll(anyObject(List.class)); + expectLastCall().andAnswer(() -> { + putAllExecuted = true; + return null; + }); + + keyValueStoreMock.delete(anyString()); + expectLastCall().andAnswer(() -> { + deleteExecuted = true; + return null; + }); + + replay(keyValueStoreMock); + + return keyValueStoreMock; + } + + @SuppressWarnings("unchecked") + private TimestampedKeyValueStore timestampedKeyValueStoreMock() { + final TimestampedKeyValueStore timestampedKeyValueStoreMock = mock(TimestampedKeyValueStore.class); + + initStateStoreMock(timestampedKeyValueStoreMock); + + expect(timestampedKeyValueStoreMock.get(KEY)).andReturn(VALUE_AND_TIMESTAMP); + expect(timestampedKeyValueStoreMock.approximateNumEntries()).andReturn(VALUE); + + expect(timestampedKeyValueStoreMock.range("one", "two")).andReturn(timestampedRangeIter); + expect(timestampedKeyValueStoreMock.all()).andReturn(timestampedAllIter); + + + timestampedKeyValueStoreMock.put(anyString(), anyObject(ValueAndTimestamp.class)); + expectLastCall().andAnswer(() -> { + putExecuted = true; + return null; + }); + + timestampedKeyValueStoreMock.putIfAbsent(anyString(), anyObject(ValueAndTimestamp.class)); + expectLastCall().andAnswer(() -> { + putIfAbsentExecuted = true; + return null; + }); + + timestampedKeyValueStoreMock.putAll(anyObject(List.class)); + expectLastCall().andAnswer(() -> { + putAllExecuted = true; + return null; + }); + + timestampedKeyValueStoreMock.delete(anyString()); + expectLastCall().andAnswer(() -> { + deleteExecuted = true; + return null; + }); + + 
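Each of the mutating store methods stubbed here is void, so it cannot be wrapped in expect(...); instead the call is recorded directly on the mock and behavior is attached with expectLastCall().andAnswer(...), which is how the putExecuted/deleteExecuted flags get flipped. A stripped-down, standalone version of the same pattern (the Store interface below is made up for illustration):

```java
import static org.easymock.EasyMock.anyString;
import static org.easymock.EasyMock.expectLastCall;
import static org.easymock.EasyMock.mock;
import static org.easymock.EasyMock.replay;

final class VoidStubbingSketch {
    interface Store {
        void delete(String key);
    }

    private static boolean deleteExecuted = false;

    public static void main(final String[] args) {
        final Store store = mock(Store.class);
        store.delete(anyString());          // record the void call with an argument matcher
        expectLastCall().andAnswer(() -> {  // attach a side effect instead of a return value
            deleteExecuted = true;
            return null;
        });
        replay(store);

        store.delete("some-key");
        System.out.println(deleteExecuted); // true
    }
}
```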
replay(timestampedKeyValueStoreMock); + + return timestampedKeyValueStoreMock; + } + + @SuppressWarnings("unchecked") + private WindowStore windowStoreMock() { + final WindowStore windowStore = mock(WindowStore.class); + + initStateStoreMock(windowStore); + + expect(windowStore.fetchAll(anyLong(), anyLong())).andReturn(iters.get(0)); + expect(windowStore.fetch(anyString(), anyString(), anyLong(), anyLong())).andReturn(iters.get(1)); + expect(windowStore.fetch(anyString(), anyLong(), anyLong())).andReturn(windowStoreIter); + expect(windowStore.fetch(anyString(), anyLong())).andReturn(VALUE); + expect(windowStore.all()).andReturn(iters.get(2)); + + windowStore.put(anyString(), anyLong(), anyLong()); + expectLastCall().andAnswer(() -> { + putExecuted = true; + return null; + }); + + replay(windowStore); + + return windowStore; + } + + @SuppressWarnings("unchecked") + private TimestampedWindowStore timestampedWindowStoreMock() { + final TimestampedWindowStore windowStore = mock(TimestampedWindowStore.class); + + initStateStoreMock(windowStore); + + expect(windowStore.fetchAll(anyLong(), anyLong())).andReturn(timestampedIters.get(0)); + expect(windowStore.fetch(anyString(), anyString(), anyLong(), anyLong())).andReturn(timestampedIters.get(1)); + expect(windowStore.fetch(anyString(), anyLong(), anyLong())).andReturn(windowStoreIter); + expect(windowStore.fetch(anyString(), anyLong())).andReturn(VALUE_AND_TIMESTAMP); + expect(windowStore.all()).andReturn(timestampedIters.get(2)); + + windowStore.put(anyString(), anyObject(ValueAndTimestamp.class), anyLong()); + expectLastCall().andAnswer(() -> { + putExecuted = true; + return null; + }); + + windowStore.put(anyString(), anyObject(ValueAndTimestamp.class), anyLong()); + expectLastCall().andAnswer(() -> { + putWithTimestampExecuted = true; + return null; + }); + + replay(windowStore); + + return windowStore; + } + + @SuppressWarnings("unchecked") + private SessionStore sessionStoreMock() { + final SessionStore sessionStore = mock(SessionStore.class); + + initStateStoreMock(sessionStore); + + expect(sessionStore.findSessions(anyString(), anyLong(), anyLong())).andReturn(iters.get(3)); + expect(sessionStore.findSessions(anyString(), anyString(), anyLong(), anyLong())).andReturn(iters.get(4)); + expect(sessionStore.fetch(anyString())).andReturn(iters.get(5)); + expect(sessionStore.fetch(anyString(), anyString())).andReturn(iters.get(6)); + + sessionStore.put(anyObject(Windowed.class), anyLong()); + expectLastCall().andAnswer(() -> { + putExecuted = true; + return null; + }); + + sessionStore.remove(anyObject(Windowed.class)); + expectLastCall().andAnswer(() -> { + removeExecuted = true; + return null; + }); + + replay(sessionStore); + + return sessionStore; + } + + private StreamsConfig streamsConfigMock() { + final StreamsConfig streamsConfig = mock(StreamsConfig.class); + expect(streamsConfig.getString(StreamsConfig.APPLICATION_ID_CONFIG)).andStubReturn("add-id"); + expect(streamsConfig.defaultValueSerde()).andStubReturn(Serdes.ByteArray()); + expect(streamsConfig.defaultKeySerde()).andStubReturn(Serdes.ByteArray()); + replay(streamsConfig); + return streamsConfig; + } + + private void initStateStoreMock(final StateStore stateStore) { + expect(stateStore.name()).andReturn(STORE_NAME); + expect(stateStore.persistent()).andReturn(true); + expect(stateStore.isOpen()).andReturn(true); + + stateStore.flush(); + expectLastCall().andAnswer(() -> { + flushExecuted = true; + return null; + }); + } + + private void doTest(final String name, final Consumer 
checker) { + @SuppressWarnings("deprecation") final org.apache.kafka.streams.processor.Processor processor = new org.apache.kafka.streams.processor.Processor() { + @Override + public void init(final ProcessorContext context) { + final T store = context.getStateStore(name); + checker.accept(store); + } + + @Override + public void process(final String k, final Long v) { + //No-op. + } + + @Override + public void close() { + //No-op. + } + }; + + processor.init(context); + } + + private void verifyStoreCannotBeInitializedOrClosed(final StateStore store) { + assertEquals(STORE_NAME, store.name()); + assertTrue(store.persistent()); + assertTrue(store.isOpen()); + + checkThrowsUnsupportedOperation(() -> store.init((StateStoreContext) null, null), "init()"); + checkThrowsUnsupportedOperation(store::close, "close()"); + } + + private void checkThrowsUnsupportedOperation(final Runnable check, final String name) { + try { + check.run(); + fail(name + " should throw exception"); + } catch (final UnsupportedOperationException e) { + //ignore. + } + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/ProcessorContextTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/ProcessorContextTest.java new file mode 100644 index 0000000..81a98c8 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/ProcessorContextTest.java @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.processor.internals; + +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.processor.ProcessorContext; +import org.apache.kafka.streams.processor.TaskId; +import org.apache.kafka.streams.processor.internals.Task.TaskType; +import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl; +import org.apache.kafka.streams.state.internals.ThreadCache; +import org.junit.Before; +import org.junit.Test; + +import java.time.Duration; + +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.Assert.fail; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.mock; + +public class ProcessorContextTest { + private ProcessorContext context; + + @Before + public void prepare() { + final StreamsConfig streamsConfig = mock(StreamsConfig.class); + doReturn("add-id").when(streamsConfig).getString(StreamsConfig.APPLICATION_ID_CONFIG); + doReturn(Serdes.ByteArray()).when(streamsConfig).defaultValueSerde(); + doReturn(Serdes.ByteArray()).when(streamsConfig).defaultKeySerde(); + + final ProcessorStateManager stateManager = mock(ProcessorStateManager.class); + doReturn(TaskType.ACTIVE).when(stateManager).taskType(); + + context = new ProcessorContextImpl( + mock(TaskId.class), + streamsConfig, + stateManager, + mock(StreamsMetricsImpl.class), + mock(ThreadCache.class) + ); + ((InternalProcessorContext) context).transitionToActive(mock(StreamTask.class), null, null); + } + + @Test + public void shouldNotAllowToScheduleZeroMillisecondPunctuation() { + try { + context.schedule(Duration.ofMillis(0L), null, null); + fail("Should have thrown IllegalArgumentException"); + } catch (final IllegalArgumentException expected) { + assertThat(expected.getMessage(), equalTo("The minimum supported scheduling interval is 1 millisecond.")); + } + } + + @Test + public void shouldNotAllowToScheduleSubMillisecondPunctuation() { + try { + context.schedule(Duration.ofNanos(999_999L), null, null); + fail("Should have thrown IllegalArgumentException"); + } catch (final IllegalArgumentException expected) { + assertThat(expected.getMessage(), equalTo("The minimum supported scheduling interval is 1 millisecond.")); + } + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/ProcessorNodeTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/ProcessorNodeTest.java new file mode 100644 index 0000000..87a4c68 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/ProcessorNodeTest.java @@ -0,0 +1,207 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.processor.internals; + +import org.apache.kafka.common.config.ConfigException; +import org.apache.kafka.common.metrics.Metrics; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.serialization.StringSerializer; +import org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.TestInputTopic; +import org.apache.kafka.streams.Topology; +import org.apache.kafka.streams.TopologyTestDriver; +import org.apache.kafka.streams.errors.StreamsException; +import org.apache.kafka.streams.processor.ProcessorContext; +import org.apache.kafka.streams.processor.api.Record; +import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl; +import org.apache.kafka.test.InternalMockProcessorContext; +import org.apache.kafka.test.StreamsTestUtils; +import org.junit.Test; + +import java.util.Collections; +import java.util.LinkedHashMap; +import java.util.Map; +import java.util.Properties; + +import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.ROLLUP_VALUE; +import static org.hamcrest.CoreMatchers.containsString; +import static org.hamcrest.CoreMatchers.instanceOf; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; + +public class ProcessorNodeTest { + + @SuppressWarnings("unchecked") + @Test + public void shouldThrowStreamsExceptionIfExceptionCaughtDuringInit() { + final ProcessorNode node = new ProcessorNode("name", new ExceptionalProcessor(), Collections.emptySet()); + assertThrows(StreamsException.class, () -> node.init(null)); + } + + @SuppressWarnings("unchecked") + @Test + public void shouldThrowStreamsExceptionIfExceptionCaughtDuringClose() { + final ProcessorNode node = new ProcessorNode("name", new ExceptionalProcessor(), Collections.emptySet()); + assertThrows(StreamsException.class, () -> node.init(null)); + } + + @SuppressWarnings("deprecation") // Old PAPI. Needs to be migrated. + private static class ExceptionalProcessor implements org.apache.kafka.streams.processor.Processor { + @Override + public void init(final ProcessorContext context) { + throw new RuntimeException(); + } + + @Override + public void process(final Object key, final Object value) { + throw new RuntimeException(); + } + + @Override + public void close() { + throw new RuntimeException(); + } + } + + @SuppressWarnings("deprecation") // Old PAPI. Needs to be migrated. 
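The two tests above assert that a failure thrown from user processor code during initialization surfaces as a StreamsException rather than escaping raw, so the runtime can handle all user-code failures uniformly. A hedged, self-contained sketch of that wrapping idea (StreamsLikeException is a stand-in class, not the actual ProcessorNode code):

```java
final class ExceptionWrappingSketch {
    static class StreamsLikeException extends RuntimeException {
        StreamsLikeException(final String message, final Throwable cause) {
            super(message, cause);
        }
    }

    // Run user-supplied init logic and rewrap any RuntimeException it throws.
    static void initNode(final Runnable userInit, final String nodeName) {
        try {
            userInit.run();
        } catch (final RuntimeException fatal) {
            throw new StreamsLikeException("failed to initialize processor " + nodeName, fatal);
        }
    }

    public static void main(final String[] args) {
        try {
            initNode(() -> { throw new RuntimeException("boom"); }, "name");
        } catch (final StreamsLikeException e) {
            System.out.println(e.getMessage() + " caused by: " + e.getCause().getMessage());
        }
    }
}
```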
+ private static class NoOpProcessor implements org.apache.kafka.streams.processor.Processor { + @Override + public void init(final ProcessorContext context) { + + } + + @Override + public void process(final Object key, final Object value) { + + } + + @Override + public void close() { + + } + } + + @Test + public void testMetricsWithBuiltInMetricsVersionLatest() { + final Metrics metrics = new Metrics(); + final StreamsMetricsImpl streamsMetrics = + new StreamsMetricsImpl(metrics, "test-client", StreamsConfig.METRICS_LATEST, new MockTime()); + final InternalMockProcessorContext context = new InternalMockProcessorContext<>(streamsMetrics); + final ProcessorNode node = new ProcessorNode<>("name", new NoOpProcessor(), Collections.emptySet()); + node.init(context); + + final String threadId = Thread.currentThread().getName(); + final String[] latencyOperations = {"process", "punctuate", "create", "destroy"}; + final String groupName = "stream-processor-node-metrics"; + final Map metricTags = new LinkedHashMap<>(); + final String threadIdTagKey = "client-id"; + metricTags.put("processor-node-id", node.name()); + metricTags.put("task-id", context.taskId().toString()); + metricTags.put(threadIdTagKey, threadId); + + for (final String opName : latencyOperations) { + assertFalse(StreamsTestUtils.containsMetric(metrics, opName + "-latency-avg", groupName, metricTags)); + assertFalse(StreamsTestUtils.containsMetric(metrics, opName + "-latency-max", groupName, metricTags)); + assertFalse(StreamsTestUtils.containsMetric(metrics, opName + "-rate", groupName, metricTags)); + assertFalse(StreamsTestUtils.containsMetric(metrics, opName + "-total", groupName, metricTags)); + } + + // test parent sensors + metricTags.put("processor-node-id", ROLLUP_VALUE); + for (final String opName : latencyOperations) { + assertFalse(StreamsTestUtils.containsMetric(metrics, opName + "-latency-avg", groupName, metricTags)); + assertFalse(StreamsTestUtils.containsMetric(metrics, opName + "-latency-max", groupName, metricTags)); + assertFalse(StreamsTestUtils.containsMetric(metrics, opName + "-rate", groupName, metricTags)); + assertFalse(StreamsTestUtils.containsMetric(metrics, opName + "-total", groupName, metricTags)); + } + } + + @Test + public void testTopologyLevelClassCastException() { + // Serdes configuration is missing and no default is set which will trigger an exception + final StreamsBuilder builder = new StreamsBuilder(); + + builder.stream("streams-plaintext-input") + .flatMapValues(value -> { + return Collections.singletonList(""); + }); + final Topology topology = builder.build(); + final Properties config = new Properties(); + config.put(StreamsConfig.DEFAULT_KEY_SERDE_CLASS_CONFIG, Serdes.ByteArraySerde.class); + config.put(StreamsConfig.DEFAULT_VALUE_SERDE_CLASS_CONFIG, Serdes.ByteArraySerde.class); + + try (final TopologyTestDriver testDriver = new TopologyTestDriver(topology, config)) { + final TestInputTopic topic = testDriver.createInputTopic("streams-plaintext-input", new StringSerializer(), new StringSerializer()); + + final StreamsException se = assertThrows(StreamsException.class, () -> topic.pipeInput("a-key", "a value")); + final String msg = se.getMessage(); + assertTrue("Error about class cast with serdes", msg.contains("ClassCastException")); + assertTrue("Error about class cast with serdes", msg.contains("Serdes")); + } + } + + @Test + public void testTopologyLevelConfigException() { + // Serdes configuration is missing and no default is set which will trigger an exception + final 
StreamsBuilder builder = new StreamsBuilder(); + + builder.stream("streams-plaintext-input") + .flatMapValues(value -> { + return Collections.singletonList(""); + }); + final Topology topology = builder.build(); + + final ConfigException se = assertThrows(ConfigException.class, () -> new TopologyTestDriver(topology)); + final String msg = se.getMessage(); + assertTrue("Error about class cast with serdes", msg.contains("StreamsConfig#DEFAULT_KEY_SERDE_CLASS_CONFIG")); + assertTrue("Error about class cast with serdes", msg.contains("specify a key serde")); + } + + private static class ClassCastProcessor extends ExceptionalProcessor { + + @Override + public void init(final ProcessorContext context) { + } + + @Override + public void process(final Object key, final Object value) { + throw new ClassCastException("Incompatible types simulation exception."); + } + } + + @Test + public void testTopologyLevelClassCastExceptionDirect() { + final Metrics metrics = new Metrics(); + final StreamsMetricsImpl streamsMetrics = + new StreamsMetricsImpl(metrics, "test-client", StreamsConfig.METRICS_LATEST, new MockTime()); + final InternalMockProcessorContext context = new InternalMockProcessorContext<>(streamsMetrics); + final ProcessorNode node = new ProcessorNode<>("pname", new ClassCastProcessor(), Collections.emptySet()); + node.init(context); + final StreamsException se = assertThrows( + StreamsException.class, + () -> node.process(new Record<>("aKey", "aValue", 0)) + ); + assertThat(se.getCause(), instanceOf(ClassCastException.class)); + assertThat(se.getMessage(), containsString("default Serdes")); + assertThat(se.getMessage(), containsString("input types")); + assertThat(se.getMessage(), containsString("pname")); + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/ProcessorRecordContextTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/ProcessorRecordContextTest.java new file mode 100644 index 0000000..68eb21e --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/ProcessorRecordContextTest.java @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.processor.internals; + +import org.apache.kafka.common.header.Headers; +import org.apache.kafka.common.header.internals.RecordHeaders; +import org.junit.Test; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; + +public class ProcessorRecordContextTest { + // timestamp + offset + partition: 8 + 8 + 4 + private final static long MIN_SIZE = 20L; + + @Test + public void shouldNotAllowNullHeaders() { + assertThrows( + NullPointerException.class, + () -> new ProcessorRecordContext( + 42L, + 73L, + 0, + "topic", + null + ) + ); + } + + @Test + public void shouldEstimateNullTopicAndEmptyHeadersAsZeroLength() { + final Headers headers = new RecordHeaders(); + final ProcessorRecordContext context = new ProcessorRecordContext( + 42L, + 73L, + 0, + null, + new RecordHeaders() + ); + + assertEquals(MIN_SIZE, context.residentMemorySizeEstimate()); + } + + @Test + public void shouldEstimateEmptyHeaderAsZeroLength() { + final ProcessorRecordContext context = new ProcessorRecordContext( + 42L, + 73L, + 0, + null, + new RecordHeaders() + ); + + assertEquals(MIN_SIZE, context.residentMemorySizeEstimate()); + } + + @Test + public void shouldEstimateTopicLength() { + final ProcessorRecordContext context = new ProcessorRecordContext( + 42L, + 73L, + 0, + "topic", + new RecordHeaders() + ); + + assertEquals(MIN_SIZE + 5L, context.residentMemorySizeEstimate()); + } + + @Test + public void shouldEstimateHeadersLength() { + final Headers headers = new RecordHeaders(); + headers.add("header-key", "header-value".getBytes()); + final ProcessorRecordContext context = new ProcessorRecordContext( + 42L, + 73L, + 0, + null, + headers + ); + + assertEquals(MIN_SIZE + 10L + 12L, context.residentMemorySizeEstimate()); + } + + @Test + public void shouldEstimateNullValueInHeaderAsZero() { + final Headers headers = new RecordHeaders(); + headers.add("header-key", null); + final ProcessorRecordContext context = new ProcessorRecordContext( + 42L, + 73L, + 0, + null, + headers + ); + + assertEquals(MIN_SIZE + 10L, context.residentMemorySizeEstimate()); + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/ProcessorStateManagerTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/ProcessorStateManagerTest.java new file mode 100644 index 0000000..3e88ced --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/ProcessorStateManagerTest.java @@ -0,0 +1,1065 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.processor.internals; + +import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.errors.ProcessorStateException; +import org.apache.kafka.streams.errors.StreamsException; +import org.apache.kafka.streams.errors.TaskCorruptedException; +import org.apache.kafka.streams.processor.StateRestoreCallback; +import org.apache.kafka.streams.processor.StateStore; +import org.apache.kafka.streams.processor.StateStoreContext; +import org.apache.kafka.streams.processor.TaskId; +import org.apache.kafka.streams.processor.internals.testutil.LogCaptureAppender; +import org.apache.kafka.streams.processor.internals.ProcessorStateManager.StateStoreMetadata; +import org.apache.kafka.streams.state.TimestampedBytesStore; +import org.apache.kafka.streams.state.internals.OffsetCheckpoint; +import org.apache.kafka.test.MockKeyValueStore; +import org.apache.kafka.test.MockRestoreCallback; +import org.apache.kafka.test.TestUtils; +import org.easymock.EasyMock; +import org.easymock.EasyMockRunner; +import org.easymock.Mock; +import org.easymock.MockType; +import org.hamcrest.Matchers; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; + +import java.io.File; +import java.io.FileWriter; +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.Properties; +import java.util.concurrent.atomic.AtomicBoolean; + +import static java.util.Collections.emptyMap; +import static java.util.Collections.emptySet; +import static java.util.Collections.singletonList; +import static java.util.Collections.singletonMap; +import static org.apache.kafka.common.utils.Utils.mkEntry; +import static org.apache.kafka.common.utils.Utils.mkMap; +import static org.apache.kafka.common.utils.Utils.mkSet; +import static org.apache.kafka.streams.processor.internals.StateManagerUtil.CHECKPOINT_FILE_NAME; +import static org.easymock.EasyMock.expect; +import static org.easymock.EasyMock.replay; +import static org.easymock.EasyMock.reset; +import static org.easymock.EasyMock.verify; +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.CoreMatchers.notNullValue; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.core.Is.is; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +@RunWith(EasyMockRunner.class) +public class ProcessorStateManagerTest { + + private final String applicationId = "test-application"; + private final TaskId taskId = new TaskId(0, 1, "My-Topology"); + private final String persistentStoreName = "persistentStore"; + private final String persistentStoreTwoName = "persistentStore2"; + private final String nonPersistentStoreName = "nonPersistentStore"; + private final String persistentStoreTopicName = + ProcessorStateManager.storeChangelogTopic(applicationId, persistentStoreName, taskId.topologyName()); + private final String 
persistentStoreTwoTopicName = + ProcessorStateManager.storeChangelogTopic(applicationId, persistentStoreTwoName, taskId.topologyName()); + private final String nonPersistentStoreTopicName = + ProcessorStateManager.storeChangelogTopic(applicationId, nonPersistentStoreName, taskId.topologyName()); + private final MockKeyValueStore persistentStore = new MockKeyValueStore(persistentStoreName, true); + private final MockKeyValueStore persistentStoreTwo = new MockKeyValueStore(persistentStoreTwoName, true); + private final MockKeyValueStore nonPersistentStore = new MockKeyValueStore(nonPersistentStoreName, false); + private final TopicPartition persistentStorePartition = new TopicPartition(persistentStoreTopicName, 1); + private final TopicPartition persistentStoreTwoPartition = new TopicPartition(persistentStoreTwoTopicName, 1); + private final TopicPartition nonPersistentStorePartition = new TopicPartition(nonPersistentStoreTopicName, 1); + private final TopicPartition irrelevantPartition = new TopicPartition("other-topic", 1); + private final Integer key = 1; + private final String value = "the-value"; + private final byte[] keyBytes = new byte[] {0x0, 0x0, 0x0, 0x1}; + private final byte[] valueBytes = value.getBytes(StandardCharsets.UTF_8); + private final ConsumerRecord consumerRecord = + new ConsumerRecord<>(persistentStoreTopicName, 1, 100L, keyBytes, valueBytes); + private final MockChangelogReader changelogReader = new MockChangelogReader(); + private final LogContext logContext = new LogContext("process-state-manager-test "); + private final StateRestoreCallback noopStateRestoreCallback = (k, v) -> { }; + + private File baseDir; + private File checkpointFile; + private OffsetCheckpoint checkpoint; + private StateDirectory stateDirectory; + + @Mock(type = MockType.NICE) + private StateStore store; + @Mock(type = MockType.NICE) + private StateStoreMetadata storeMetadata; + @Mock(type = MockType.NICE) + private InternalProcessorContext context; + + @Before + public void setup() { + baseDir = TestUtils.tempDirectory(); + + stateDirectory = new StateDirectory(new StreamsConfig(new Properties() { + { + put(StreamsConfig.APPLICATION_ID_CONFIG, applicationId); + put(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, "dummy:1234"); + put(StreamsConfig.STATE_DIR_CONFIG, baseDir.getPath()); + } + }), new MockTime(), true, true); + checkpointFile = new File(stateDirectory.getOrCreateDirectoryForTask(taskId), CHECKPOINT_FILE_NAME); + checkpoint = new OffsetCheckpoint(checkpointFile); + + expect(storeMetadata.changelogPartition()).andReturn(persistentStorePartition).anyTimes(); + expect(storeMetadata.store()).andReturn(store).anyTimes(); + expect(store.name()).andReturn(persistentStoreName).anyTimes(); + replay(storeMetadata, store); + } + + @After + public void cleanup() throws IOException { + Utils.delete(baseDir); + } + + @Test + public void shouldReturnDefaultChangelogTopicName() { + final String applicationId = "appId"; + final String storeName = "store"; + + assertThat( + ProcessorStateManager.storeChangelogTopic(applicationId, storeName, null), + is(applicationId + "-" + storeName + "-changelog") + ); + } + + @Test + public void shouldReturnDefaultChangelogTopicNameWithNamedTopology() { + final String applicationId = "appId"; + final String namedTopology = "namedTopology"; + final String storeName = "store"; + + assertThat( + ProcessorStateManager.storeChangelogTopic(applicationId, storeName, namedTopology), + is(applicationId + "-" + namedTopology + "-" + storeName + "-changelog") + ); + } + + @Test + 
public void shouldReturnBaseDir() { + final ProcessorStateManager stateMgr = getStateManager(Task.TaskType.ACTIVE); + assertEquals(stateDirectory.getOrCreateDirectoryForTask(taskId), stateMgr.baseDir()); + } + + // except this test for all other tests active / standby state managers acts the same, so + // for all others we always use ACTIVE unless explained specifically. + @Test + public void shouldReportTaskType() { + ProcessorStateManager stateMgr = getStateManager(Task.TaskType.STANDBY); + assertEquals(Task.TaskType.STANDBY, stateMgr.taskType()); + + stateMgr = getStateManager(Task.TaskType.ACTIVE); + assertEquals(Task.TaskType.ACTIVE, stateMgr.taskType()); + } + + @Test + public void shouldReportChangelogAsSource() { + final ProcessorStateManager stateMgr = new ProcessorStateManager( + taskId, + Task.TaskType.STANDBY, + false, + logContext, + stateDirectory, + changelogReader, + mkMap( + mkEntry(persistentStoreName, persistentStoreTopicName), + mkEntry(persistentStoreTwoName, persistentStoreTwoTopicName), + mkEntry(nonPersistentStoreName, nonPersistentStoreTopicName) + ), + mkSet(persistentStorePartition, nonPersistentStorePartition)); + + assertTrue(stateMgr.changelogAsSource(persistentStorePartition)); + assertTrue(stateMgr.changelogAsSource(nonPersistentStorePartition)); + assertFalse(stateMgr.changelogAsSource(persistentStoreTwoPartition)); + } + + @Test + public void shouldFindSingleStoreForChangelog() { + final ProcessorStateManager stateMgr = new ProcessorStateManager( + taskId, + Task.TaskType.STANDBY, + false, + logContext, + stateDirectory, + changelogReader, mkMap( + mkEntry(persistentStoreName, persistentStoreTopicName), + mkEntry(persistentStoreTwoName, persistentStoreTopicName) + ), + Collections.emptySet()); + + stateMgr.registerStore(persistentStore, persistentStore.stateRestoreCallback); + stateMgr.registerStore(persistentStoreTwo, persistentStore.stateRestoreCallback); + + assertThrows( + IllegalStateException.class, + () -> stateMgr.updateChangelogOffsets(Collections.singletonMap(persistentStorePartition, 0L)) + ); + } + + @Test + public void shouldRestoreStoreWithRestoreCallback() { + final MockRestoreCallback restoreCallback = new MockRestoreCallback(); + + final KeyValue expectedKeyValue = KeyValue.pair(keyBytes, valueBytes); + + final ProcessorStateManager stateMgr = getStateManager(Task.TaskType.ACTIVE); + + try { + stateMgr.registerStore(persistentStore, restoreCallback); + final StateStoreMetadata storeMetadata = stateMgr.storeMetadata(persistentStorePartition); + assertThat(storeMetadata, notNullValue()); + + stateMgr.restore(storeMetadata, singletonList(consumerRecord)); + + assertThat(restoreCallback.restored.size(), is(1)); + assertTrue(restoreCallback.restored.contains(expectedKeyValue)); + + assertEquals(Collections.singletonMap(persistentStorePartition, 101L), stateMgr.changelogOffsets()); + } finally { + stateMgr.close(); + } + } + + @Test + public void shouldRestoreNonTimestampedStoreWithNoConverter() { + final ProcessorStateManager stateMgr = getStateManager(Task.TaskType.ACTIVE); + + try { + stateMgr.registerStore(persistentStore, persistentStore.stateRestoreCallback); + final StateStoreMetadata storeMetadata = stateMgr.storeMetadata(persistentStorePartition); + assertThat(storeMetadata, notNullValue()); + + stateMgr.restore(storeMetadata, singletonList(consumerRecord)); + + assertThat(persistentStore.keys.size(), is(1)); + assertTrue(persistentStore.keys.contains(key)); + // we just check non timestamped value length + assertEquals(9, 
persistentStore.values.get(0).length); + } finally { + stateMgr.close(); + } + } + + @Test + public void shouldRestoreTimestampedStoreWithConverter() { + final ProcessorStateManager stateMgr = getStateManager(Task.TaskType.ACTIVE); + final MockKeyValueStore store = getConverterStore(); + + try { + stateMgr.registerStore(store, store.stateRestoreCallback); + final StateStoreMetadata storeMetadata = stateMgr.storeMetadata(persistentStorePartition); + assertThat(storeMetadata, notNullValue()); + + stateMgr.restore(storeMetadata, singletonList(consumerRecord)); + + assertThat(store.keys.size(), is(1)); + assertTrue(store.keys.contains(key)); + // we just check timestamped value length + assertEquals(17, store.values.get(0).length); + } finally { + stateMgr.close(); + } + } + + @Test + public void shouldUnregisterChangelogsDuringClose() { + final ProcessorStateManager stateMgr = getStateManager(Task.TaskType.ACTIVE); + reset(storeMetadata); + final StateStore store = EasyMock.createMock(StateStore.class); + expect(storeMetadata.changelogPartition()).andStubReturn(persistentStorePartition); + expect(storeMetadata.store()).andStubReturn(store); + expect(store.name()).andStubReturn(persistentStoreName); + + context.uninitialize(); + store.init((StateStoreContext) context, store); + replay(storeMetadata, context, store); + + stateMgr.registerStateStores(singletonList(store), context); + verify(context, store); + + stateMgr.registerStore(store, noopStateRestoreCallback); + assertTrue(changelogReader.isPartitionRegistered(persistentStorePartition)); + + reset(store); + expect(store.name()).andStubReturn(persistentStoreName); + store.close(); + replay(store); + + stateMgr.close(); + verify(store); + + assertFalse(changelogReader.isPartitionRegistered(persistentStorePartition)); + } + + @Test + public void shouldRecycleStoreAndReregisterChangelog() { + final ProcessorStateManager stateMgr = getStateManager(Task.TaskType.ACTIVE); + reset(storeMetadata); + final StateStore store = EasyMock.createMock(StateStore.class); + expect(storeMetadata.changelogPartition()).andStubReturn(persistentStorePartition); + expect(storeMetadata.store()).andStubReturn(store); + expect(store.name()).andStubReturn(persistentStoreName); + + context.uninitialize(); + store.init((StateStoreContext) context, store); + replay(storeMetadata, context, store); + + stateMgr.registerStateStores(singletonList(store), context); + verify(context, store); + + stateMgr.registerStore(store, noopStateRestoreCallback); + assertTrue(changelogReader.isPartitionRegistered(persistentStorePartition)); + + stateMgr.recycle(); + assertFalse(changelogReader.isPartitionRegistered(persistentStorePartition)); + assertThat(stateMgr.getStore(persistentStoreName), equalTo(store)); + + reset(context, store); + context.uninitialize(); + expect(store.name()).andStubReturn(persistentStoreName); + replay(context, store); + + stateMgr.registerStateStores(singletonList(store), context); + + verify(context, store); + assertTrue(changelogReader.isPartitionRegistered(persistentStorePartition)); + } + + @Test + public void shouldRegisterPersistentStores() { + final ProcessorStateManager stateMgr = getStateManager(Task.TaskType.ACTIVE); + + try { + stateMgr.registerStore(persistentStore, persistentStore.stateRestoreCallback); + assertTrue(changelogReader.isPartitionRegistered(persistentStorePartition)); + } finally { + stateMgr.close(); + } + } + + @Test + public void shouldRegisterNonPersistentStore() { + final ProcessorStateManager stateMgr = 
getStateManager(Task.TaskType.ACTIVE); + + try { + stateMgr.registerStore(nonPersistentStore, nonPersistentStore.stateRestoreCallback); + assertTrue(changelogReader.isPartitionRegistered(nonPersistentStorePartition)); + } finally { + stateMgr.close(); + } + } + + @Test + public void shouldNotRegisterNonLoggedStore() { + final ProcessorStateManager stateMgr = new ProcessorStateManager( + taskId, + Task.TaskType.STANDBY, + false, + logContext, + stateDirectory, + changelogReader, + emptyMap(), + emptySet()); + + try { + stateMgr.registerStore(persistentStore, persistentStore.stateRestoreCallback); + assertFalse(changelogReader.isPartitionRegistered(persistentStorePartition)); + } finally { + stateMgr.close(); + } + } + + @Test + public void shouldInitializeOffsetsFromCheckpointFile() throws IOException { + final long checkpointOffset = 10L; + + final Map offsets = mkMap( + mkEntry(persistentStorePartition, checkpointOffset), + mkEntry(nonPersistentStorePartition, checkpointOffset), + mkEntry(irrelevantPartition, 999L) + ); + checkpoint.write(offsets); + + final ProcessorStateManager stateMgr = getStateManager(Task.TaskType.ACTIVE); + + try { + stateMgr.registerStore(persistentStore, persistentStore.stateRestoreCallback); + stateMgr.registerStore(persistentStoreTwo, persistentStoreTwo.stateRestoreCallback); + stateMgr.registerStore(nonPersistentStore, nonPersistentStore.stateRestoreCallback); + stateMgr.initializeStoreOffsetsFromCheckpoint(true); + + assertTrue(checkpointFile.exists()); + assertEquals(mkSet( + persistentStorePartition, + persistentStoreTwoPartition, + nonPersistentStorePartition), + stateMgr.changelogPartitions()); + assertEquals(mkMap( + mkEntry(persistentStorePartition, checkpointOffset + 1L), + mkEntry(persistentStoreTwoPartition, 0L), + mkEntry(nonPersistentStorePartition, 0L)), + stateMgr.changelogOffsets() + ); + + assertNull(stateMgr.storeMetadata(irrelevantPartition)); + assertNull(stateMgr.storeMetadata(persistentStoreTwoPartition).offset()); + assertThat(stateMgr.storeMetadata(persistentStorePartition).offset(), equalTo(checkpointOffset)); + assertNull(stateMgr.storeMetadata(nonPersistentStorePartition).offset()); + } finally { + stateMgr.close(); + } + } + + @Test + public void shouldInitializeOffsetsFromCheckpointFileAndDeleteIfEOSEnabled() throws IOException { + final long checkpointOffset = 10L; + + final Map offsets = mkMap( + mkEntry(persistentStorePartition, checkpointOffset), + mkEntry(nonPersistentStorePartition, checkpointOffset), + mkEntry(irrelevantPartition, 999L) + ); + checkpoint.write(offsets); + + final ProcessorStateManager stateMgr = getStateManager(Task.TaskType.ACTIVE, true); + + try { + stateMgr.registerStore(persistentStore, persistentStore.stateRestoreCallback); + stateMgr.registerStore(persistentStoreTwo, persistentStoreTwo.stateRestoreCallback); + stateMgr.registerStore(nonPersistentStore, nonPersistentStore.stateRestoreCallback); + stateMgr.initializeStoreOffsetsFromCheckpoint(true); + + assertFalse(checkpointFile.exists()); + assertEquals(mkSet( + persistentStorePartition, + persistentStoreTwoPartition, + nonPersistentStorePartition), + stateMgr.changelogPartitions()); + assertEquals(mkMap( + mkEntry(persistentStorePartition, checkpointOffset + 1L), + mkEntry(persistentStoreTwoPartition, 0L), + mkEntry(nonPersistentStorePartition, 0L)), + stateMgr.changelogOffsets() + ); + + assertNull(stateMgr.storeMetadata(irrelevantPartition)); + assertNull(stateMgr.storeMetadata(persistentStoreTwoPartition).offset()); + 
assertThat(stateMgr.storeMetadata(persistentStorePartition).offset(), equalTo(checkpointOffset)); + assertNull(stateMgr.storeMetadata(nonPersistentStorePartition).offset()); + } finally { + stateMgr.close(); + } + } + + @Test + public void shouldGetRegisteredStore() { + final ProcessorStateManager stateMgr = getStateManager(Task.TaskType.ACTIVE); + try { + stateMgr.registerStore(persistentStore, persistentStore.stateRestoreCallback); + stateMgr.registerStore(nonPersistentStore, nonPersistentStore.stateRestoreCallback); + + assertNull(stateMgr.getStore("noSuchStore")); + assertEquals(persistentStore, stateMgr.getStore(persistentStoreName)); + assertEquals(nonPersistentStore, stateMgr.getStore(nonPersistentStoreName)); + } finally { + stateMgr.close(); + } + } + + @Test + public void shouldGetChangelogPartitionForRegisteredStore() { + final ProcessorStateManager stateMgr = getStateManager(Task.TaskType.ACTIVE); + stateMgr.registerStore(persistentStore, persistentStore.stateRestoreCallback); + + final TopicPartition changelogPartition = stateMgr.registeredChangelogPartitionFor(persistentStoreName); + + assertThat(changelogPartition.topic(), is(persistentStoreTopicName)); + assertThat(changelogPartition.partition(), is(taskId.partition())); + } + + @Test + public void shouldThrowIfStateStoreIsNotRegistered() { + final ProcessorStateManager stateMgr = getStateManager(Task.TaskType.ACTIVE); + + assertThrows("State store " + persistentStoreName + + " for which the registered changelog partition should be" + + " retrieved has not been registered", + IllegalStateException.class, + () -> stateMgr.registeredChangelogPartitionFor(persistentStoreName) + ); + } + + @Test + public void shouldThrowIfStateStoreHasLoggingDisabled() { + final ProcessorStateManager stateMgr = getStateManager(Task.TaskType.ACTIVE); + final String storeName = "store-with-logging-disabled"; + final MockKeyValueStore storeWithLoggingDisabled = new MockKeyValueStore(storeName, true); + stateMgr.registerStore(storeWithLoggingDisabled, null); + + assertThrows("Registered state store " + storeName + + " does not have a registered changelog partition." + + " This may happen if logging is disabled for the state store.", + IllegalStateException.class, + () -> stateMgr.registeredChangelogPartitionFor(storeName) + ); + } + + @Test + public void shouldFlushCheckpointAndClose() throws IOException { + checkpoint.write(emptyMap()); + + // set up ack'ed offsets + final HashMap ackedOffsets = new HashMap<>(); + ackedOffsets.put(persistentStorePartition, 123L); + ackedOffsets.put(nonPersistentStorePartition, 456L); + ackedOffsets.put(new TopicPartition("nonRegisteredTopic", 1), 789L); + + final ProcessorStateManager stateMgr = getStateManager(Task.TaskType.ACTIVE); + try { + // make sure the checkpoint file is not written yet + assertFalse(checkpointFile.exists()); + + stateMgr.registerStore(persistentStore, persistentStore.stateRestoreCallback); + stateMgr.registerStore(nonPersistentStore, nonPersistentStore.stateRestoreCallback); + } finally { + stateMgr.flush(); + + assertTrue(persistentStore.flushed); + assertTrue(nonPersistentStore.flushed); + + // make sure that flush is called in the proper order + assertThat(persistentStore.getLastFlushCount(), Matchers.lessThan(nonPersistentStore.getLastFlushCount())); + + stateMgr.updateChangelogOffsets(ackedOffsets); + stateMgr.checkpoint(); + + assertTrue(checkpointFile.exists()); + + // the checkpoint file should contain an offset from the persistent store only. 
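+            // offsets for the non-persistent store and for the partition of the unregistered topic must be left out of the checkpoint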
+ final Map checkpointedOffsets = checkpoint.read(); + assertThat(checkpointedOffsets, is(singletonMap(new TopicPartition(persistentStoreTopicName, 1), 123L))); + + stateMgr.close(); + + assertTrue(persistentStore.closed); + assertTrue(nonPersistentStore.closed); + } + } + + @Test + public void shouldOverrideOffsetsWhenRestoreAndProcess() throws IOException { + final Map offsets = singletonMap(persistentStorePartition, 99L); + checkpoint.write(offsets); + + final ProcessorStateManager stateMgr = getStateManager(Task.TaskType.ACTIVE); + try { + stateMgr.registerStore(persistentStore, persistentStore.stateRestoreCallback); + stateMgr.initializeStoreOffsetsFromCheckpoint(true); + + final StateStoreMetadata storeMetadata = stateMgr.storeMetadata(persistentStorePartition); + assertThat(storeMetadata, notNullValue()); + assertThat(storeMetadata.offset(), equalTo(99L)); + + stateMgr.restore(storeMetadata, singletonList(consumerRecord)); + + assertThat(storeMetadata.offset(), equalTo(100L)); + + // should ignore irrelevant topic partitions + stateMgr.updateChangelogOffsets(mkMap( + mkEntry(persistentStorePartition, 220L), + mkEntry(irrelevantPartition, 9000L) + )); + stateMgr.checkpoint(); + + assertThat(stateMgr.storeMetadata(irrelevantPartition), equalTo(null)); + assertThat(storeMetadata.offset(), equalTo(220L)); + } finally { + stateMgr.close(); + } + } + + @Test + public void shouldWriteCheckpointForPersistentStore() throws IOException { + final ProcessorStateManager stateMgr = getStateManager(Task.TaskType.ACTIVE); + + try { + stateMgr.registerStore(persistentStore, persistentStore.stateRestoreCallback); + stateMgr.initializeStoreOffsetsFromCheckpoint(true); + + final StateStoreMetadata storeMetadata = stateMgr.storeMetadata(persistentStorePartition); + assertThat(storeMetadata, notNullValue()); + + stateMgr.restore(storeMetadata, singletonList(consumerRecord)); + + stateMgr.checkpoint(); + + final Map read = checkpoint.read(); + assertThat(read, equalTo(singletonMap(persistentStorePartition, 100L))); + } finally { + stateMgr.close(); + } + } + + @Test + public void shouldNotWriteCheckpointForNonPersistentStore() throws IOException { + final ProcessorStateManager stateMgr = getStateManager(Task.TaskType.ACTIVE); + + try { + stateMgr.registerStore(nonPersistentStore, nonPersistentStore.stateRestoreCallback); + stateMgr.initializeStoreOffsetsFromCheckpoint(true); + + final StateStoreMetadata storeMetadata = stateMgr.storeMetadata(nonPersistentStorePartition); + assertThat(storeMetadata, notNullValue()); + + stateMgr.updateChangelogOffsets(singletonMap(nonPersistentStorePartition, 876L)); + stateMgr.checkpoint(); + + final Map read = checkpoint.read(); + assertThat(read, equalTo(emptyMap())); + } finally { + stateMgr.close(); + } + } + + @Test + public void shouldNotWriteCheckpointForStoresWithoutChangelogTopic() throws IOException { + final ProcessorStateManager stateMgr = new ProcessorStateManager( + taskId, + Task.TaskType.STANDBY, + false, + logContext, + stateDirectory, + changelogReader, + emptyMap(), + emptySet()); + + try { + stateMgr.registerStore(persistentStore, persistentStore.stateRestoreCallback); + + stateMgr.updateChangelogOffsets(singletonMap(persistentStorePartition, 987L)); + stateMgr.checkpoint(); + + final Map read = checkpoint.read(); + assertThat(read, equalTo(emptyMap())); + } finally { + stateMgr.close(); + } + } + + @Test + public void shouldThrowIllegalArgumentExceptionIfStoreNameIsSameAsCheckpointFileName() { + final ProcessorStateManager stateManager = 
getStateManager(Task.TaskType.ACTIVE); + + assertThrows(IllegalArgumentException.class, () -> + stateManager.registerStore(new MockKeyValueStore(CHECKPOINT_FILE_NAME, true), null)); + } + + @Test + public void shouldThrowIllegalArgumentExceptionOnRegisterWhenStoreHasAlreadyBeenRegistered() { + final ProcessorStateManager stateManager = getStateManager(Task.TaskType.ACTIVE); + + stateManager.registerStore(persistentStore, persistentStore.stateRestoreCallback); + + assertThrows(IllegalArgumentException.class, () -> + stateManager.registerStore(persistentStore, persistentStore.stateRestoreCallback)); + } + + @Test + public void shouldThrowProcessorStateExceptionOnFlushIfStoreThrowsAnException() { + final RuntimeException exception = new RuntimeException("KABOOM!"); + final ProcessorStateManager stateManager = getStateManager(Task.TaskType.ACTIVE); + final MockKeyValueStore stateStore = new MockKeyValueStore(persistentStoreName, true) { + @Override + public void flush() { + throw exception; + } + }; + stateManager.registerStore(stateStore, stateStore.stateRestoreCallback); + + final ProcessorStateException thrown = assertThrows(ProcessorStateException.class, stateManager::flush); + assertEquals(exception, thrown.getCause()); + } + + @Test + public void shouldPreserveStreamsExceptionOnFlushIfStoreThrows() { + final StreamsException exception = new StreamsException("KABOOM!"); + final ProcessorStateManager stateManager = getStateManager(Task.TaskType.ACTIVE); + final MockKeyValueStore stateStore = new MockKeyValueStore(persistentStoreName, true) { + @Override + public void flush() { + throw exception; + } + }; + stateManager.registerStore(stateStore, stateStore.stateRestoreCallback); + + final StreamsException thrown = assertThrows(StreamsException.class, stateManager::flush); + assertEquals(exception, thrown); + } + + @Test + public void shouldThrowProcessorStateExceptionOnCloseIfStoreThrowsAnException() { + final RuntimeException exception = new RuntimeException("KABOOM!"); + final ProcessorStateManager stateManager = getStateManager(Task.TaskType.ACTIVE); + final MockKeyValueStore stateStore = new MockKeyValueStore(persistentStoreName, true) { + @Override + public void close() { + throw exception; + } + }; + stateManager.registerStore(stateStore, stateStore.stateRestoreCallback); + + final ProcessorStateException thrown = assertThrows(ProcessorStateException.class, stateManager::close); + assertEquals(exception, thrown.getCause()); + } + + @Test + public void shouldPreserveStreamsExceptionOnCloseIfStoreThrows() { + final StreamsException exception = new StreamsException("KABOOM!"); + final ProcessorStateManager stateManager = getStateManager(Task.TaskType.ACTIVE); + final MockKeyValueStore stateStore = new MockKeyValueStore(persistentStoreName, true) { + @Override + public void close() { + throw exception; + } + }; + stateManager.registerStore(stateStore, stateStore.stateRestoreCallback); + + final StreamsException thrown = assertThrows(StreamsException.class, stateManager::close); + assertEquals(exception, thrown); + } + + @Test + public void shouldThrowIfRestoringUnregisteredStore() { + final ProcessorStateManager stateManager = getStateManager(Task.TaskType.ACTIVE); + + assertThrows(IllegalStateException.class, () -> stateManager.restore(storeMetadata, Collections.emptyList())); + } + + @SuppressWarnings("OptionalGetWithoutIsPresent") + @Test + public void shouldLogAWarningIfCheckpointThrowsAnIOException() { + final ProcessorStateManager stateMgr = getStateManager(Task.TaskType.ACTIVE); + 
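+        // register the store, then wipe the state directory so that the checkpoint write below fails and is only reported as a WARN log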
stateMgr.registerStore(persistentStore, persistentStore.stateRestoreCallback); + stateDirectory.clean(); + + try (final LogCaptureAppender appender = LogCaptureAppender.createAndRegister(ProcessorStateManager.class)) { + stateMgr.updateChangelogOffsets(singletonMap(persistentStorePartition, 10L)); + stateMgr.checkpoint(); + + boolean foundExpectedLogMessage = false; + for (final LogCaptureAppender.Event event : appender.getEvents()) { + if ("WARN".equals(event.getLevel()) + && event.getMessage().startsWith("process-state-manager-test Failed to write offset checkpoint file to [") + && event.getMessage().endsWith(".checkpoint]." + + " This may occur if OS cleaned the state.dir in case when it located in ${java.io.tmpdir} directory." + + " This may also occur due to running multiple instances on the same machine using the same state dir." + + " Changing the location of state.dir may resolve the problem.") + && event.getThrowableInfo().get().startsWith("java.io.FileNotFoundException: ")) { + + foundExpectedLogMessage = true; + break; + } + } + assertTrue(foundExpectedLogMessage); + } + } + + @Test + public void shouldThrowIfLoadCheckpointThrows() throws Exception { + final ProcessorStateManager stateMgr = getStateManager(Task.TaskType.ACTIVE); + + stateMgr.registerStore(persistentStore, persistentStore.stateRestoreCallback); + final File file = new File(stateMgr.baseDir(), CHECKPOINT_FILE_NAME); + file.createNewFile(); + final FileWriter writer = new FileWriter(file); + writer.write("abcdefg"); + writer.close(); + + try { + stateMgr.initializeStoreOffsetsFromCheckpoint(true); + fail("should have thrown processor state exception when IO exception happens"); + } catch (final ProcessorStateException e) { + // pass + } + } + + @Test + public void shouldThrowIfRestoreCallbackThrows() { + final ProcessorStateManager stateMgr = getStateManager(Task.TaskType.ACTIVE); + + stateMgr.registerStore(persistentStore, (key, value) -> { + throw new RuntimeException("KABOOM!"); + }); + + final StateStoreMetadata storeMetadata = stateMgr.storeMetadata(persistentStorePartition); + + try { + stateMgr.restore(storeMetadata, singletonList(consumerRecord)); + fail("should have thrown processor state exception when IO exception happens"); + } catch (final ProcessorStateException e) { + // pass + } + } + + @Test + public void shouldFlushGoodStoresEvenSomeThrowsException() { + final AtomicBoolean flushedStore = new AtomicBoolean(false); + + final MockKeyValueStore stateStore1 = new MockKeyValueStore(persistentStoreName, true) { + @Override + public void flush() { + throw new RuntimeException("KABOOM!"); + } + }; + final MockKeyValueStore stateStore2 = new MockKeyValueStore(persistentStoreTwoName, true) { + @Override + public void flush() { + flushedStore.set(true); + } + }; + final ProcessorStateManager stateManager = getStateManager(Task.TaskType.ACTIVE); + + stateManager.registerStore(stateStore1, stateStore1.stateRestoreCallback); + stateManager.registerStore(stateStore2, stateStore2.stateRestoreCallback); + + try { + stateManager.flush(); + } catch (final ProcessorStateException expected) { /* ignore */ } + + Assert.assertTrue(flushedStore.get()); + } + + @Test + public void shouldCloseAllStoresEvenIfStoreThrowsException() { + final AtomicBoolean closedStore = new AtomicBoolean(false); + + final MockKeyValueStore stateStore1 = new MockKeyValueStore(persistentStoreName, true) { + @Override + public void close() { + throw new RuntimeException("KABOOM!"); + } + }; + final MockKeyValueStore stateStore2 = new 
MockKeyValueStore(persistentStoreTwoName, true) { + @Override + public void close() { + closedStore.set(true); + } + }; + final ProcessorStateManager stateManager = getStateManager(Task.TaskType.ACTIVE); + + stateManager.registerStore(stateStore1, stateStore1.stateRestoreCallback); + stateManager.registerStore(stateStore2, stateStore2.stateRestoreCallback); + + try { + stateManager.close(); + } catch (final ProcessorStateException expected) { /* ignore */ } + + Assert.assertTrue(closedStore.get()); + } + + @Test + public void shouldThrowTaskCorruptedWithoutPersistentStoreCheckpointAndNonEmptyDir() throws IOException { + final long checkpointOffset = 10L; + + final Map offsets = mkMap( + mkEntry(persistentStorePartition, checkpointOffset), + mkEntry(nonPersistentStorePartition, checkpointOffset), + mkEntry(irrelevantPartition, 999L) + ); + checkpoint.write(offsets); + + final ProcessorStateManager stateMgr = getStateManager(Task.TaskType.ACTIVE, true); + + try { + stateMgr.registerStore(persistentStore, persistentStore.stateRestoreCallback); + stateMgr.registerStore(persistentStoreTwo, persistentStoreTwo.stateRestoreCallback); + stateMgr.registerStore(nonPersistentStore, nonPersistentStore.stateRestoreCallback); + + final TaskCorruptedException exception = assertThrows(TaskCorruptedException.class, + () -> stateMgr.initializeStoreOffsetsFromCheckpoint(false)); + + assertEquals( + Collections.singleton(taskId), + exception.corruptedTasks() + ); + } finally { + stateMgr.close(); + } + } + + @Test + public void shouldNotThrowTaskCorruptedWithoutInMemoryStoreCheckpointAndNonEmptyDir() throws IOException { + final long checkpointOffset = 10L; + + final Map offsets = mkMap( + mkEntry(persistentStorePartition, checkpointOffset), + mkEntry(irrelevantPartition, 999L) + ); + checkpoint.write(offsets); + + final ProcessorStateManager stateMgr = getStateManager(Task.TaskType.ACTIVE, true); + + try { + stateMgr.registerStore(persistentStore, persistentStore.stateRestoreCallback); + stateMgr.registerStore(nonPersistentStore, nonPersistentStore.stateRestoreCallback); + + stateMgr.initializeStoreOffsetsFromCheckpoint(false); + } finally { + stateMgr.close(); + } + } + + @Test + public void shouldNotThrowTaskCorruptedExceptionAfterCheckpointing() { + final ProcessorStateManager stateMgr = getStateManager(Task.TaskType.ACTIVE, true); + + try { + stateMgr.registerStore(persistentStore, persistentStore.stateRestoreCallback); + stateMgr.registerStore(nonPersistentStore, nonPersistentStore.stateRestoreCallback); + stateMgr.initializeStoreOffsetsFromCheckpoint(true); + + assertThat(stateMgr.storeMetadata(nonPersistentStorePartition), notNullValue()); + assertThat(stateMgr.storeMetadata(persistentStorePartition), notNullValue()); + + stateMgr.updateChangelogOffsets(mkMap( + mkEntry(nonPersistentStorePartition, 876L), + mkEntry(persistentStorePartition, 666L)) + ); + stateMgr.checkpoint(); + + // reset the state and offsets, for example as in a corrupted task + stateMgr.close(); + assertNull(stateMgr.storeMetadata(nonPersistentStorePartition)); + assertNull(stateMgr.storeMetadata(persistentStorePartition)); + + stateMgr.registerStore(persistentStore, persistentStore.stateRestoreCallback); + stateMgr.registerStore(nonPersistentStore, nonPersistentStore.stateRestoreCallback); + + // This should not throw a TaskCorruptedException! 
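+            // offsets were checkpointed before the close above, so re-initializing from the checkpoint must not report the task as corrupted even with EOS enabled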
+ stateMgr.initializeStoreOffsetsFromCheckpoint(false); + assertThat(stateMgr.storeMetadata(nonPersistentStorePartition), notNullValue()); + assertThat(stateMgr.storeMetadata(persistentStorePartition), notNullValue()); + } finally { + stateMgr.close(); + } + } + + @Test + public void shouldThrowIllegalStateIfInitializingOffsetsForCorruptedTasks() { + final ProcessorStateManager stateMgr = getStateManager(Task.TaskType.ACTIVE, true); + + try { + stateMgr.registerStore(persistentStore, persistentStore.stateRestoreCallback); + stateMgr.markChangelogAsCorrupted(mkSet(persistentStorePartition)); + + final ProcessorStateException thrown = assertThrows(ProcessorStateException.class, () -> stateMgr.initializeStoreOffsetsFromCheckpoint(true)); + assertTrue(thrown.getCause() instanceof IllegalStateException); + } finally { + stateMgr.close(); + } + } + + @Test + public void shouldBeAbleToCloseWithoutRegisteringAnyStores() { + final ProcessorStateManager stateMgr = getStateManager(Task.TaskType.ACTIVE, true); + + stateMgr.close(); + } + + @Test + public void shouldDeleteCheckPointFileIfEosEnabled() throws IOException { + final long checkpointOffset = 10L; + final Map offsets = mkMap( + mkEntry(persistentStorePartition, checkpointOffset), + mkEntry(nonPersistentStorePartition, checkpointOffset), + mkEntry(irrelevantPartition, 999L) + ); + checkpoint.write(offsets); + final ProcessorStateManager stateMgr = getStateManager(Task.TaskType.ACTIVE, true); + stateMgr.deleteCheckPointFileIfEOSEnabled(); + stateMgr.close(); + assertFalse(checkpointFile.exists()); + } + + @Test + public void shouldNotDeleteCheckPointFileIfEosNotEnabled() throws IOException { + final long checkpointOffset = 10L; + final Map offsets = mkMap( + mkEntry(persistentStorePartition, checkpointOffset), + mkEntry(nonPersistentStorePartition, checkpointOffset), + mkEntry(irrelevantPartition, 999L) + ); + checkpoint.write(offsets); + final ProcessorStateManager stateMgr = getStateManager(Task.TaskType.ACTIVE, false); + stateMgr.deleteCheckPointFileIfEOSEnabled(); + stateMgr.close(); + assertTrue(checkpointFile.exists()); + } + + private ProcessorStateManager getStateManager(final Task.TaskType taskType, final boolean eosEnabled) { + return new ProcessorStateManager( + taskId, + taskType, + eosEnabled, + logContext, + stateDirectory, + changelogReader, + mkMap( + mkEntry(persistentStoreName, persistentStoreTopicName), + mkEntry(persistentStoreTwoName, persistentStoreTwoTopicName), + mkEntry(nonPersistentStoreName, nonPersistentStoreTopicName) + ), + emptySet()); + } + + private ProcessorStateManager getStateManager(final Task.TaskType taskType) { + return getStateManager(taskType, false); + } + + private MockKeyValueStore getConverterStore() { + return new ConverterStore(persistentStoreName, true); + } + + private static class ConverterStore extends MockKeyValueStore implements TimestampedBytesStore { + ConverterStore(final String name, final boolean persistent) { + super(name, persistent); + } + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/ProcessorTopologyFactories.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/ProcessorTopologyFactories.java new file mode 100644 index 0000000..57e4490 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/ProcessorTopologyFactories.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.processor.internals; + +import org.apache.kafka.streams.processor.StateStore; + +import java.util.Collections; +import java.util.List; +import java.util.Map; + +public final class ProcessorTopologyFactories { + private ProcessorTopologyFactories() {} + + + public static ProcessorTopology with(final List> processorNodes, + final Map> sourcesByTopic, + final List stateStoresByName, + final Map storeToChangelogTopic) { + return new ProcessorTopology(processorNodes, + sourcesByTopic, + Collections.emptyMap(), + stateStoresByName, + Collections.emptyList(), + storeToChangelogTopic, + Collections.emptySet()); + } + + static ProcessorTopology withLocalStores(final List stateStores, + final Map storeToChangelogTopic) { + return new ProcessorTopology(Collections.emptyList(), + Collections.emptyMap(), + Collections.emptyMap(), + stateStores, + Collections.emptyList(), + storeToChangelogTopic, + Collections.emptySet()); + } + +} diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/ProcessorTopologyTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/ProcessorTopologyTest.java new file mode 100644 index 0000000..de26c7b --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/ProcessorTopologyTest.java @@ -0,0 +1,1849 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.processor.internals; + +import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.common.header.Header; +import org.apache.kafka.common.header.Headers; +import org.apache.kafka.common.header.internals.RecordHeader; +import org.apache.kafka.common.header.internals.RecordHeaders; +import org.apache.kafka.common.serialization.Deserializer; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.serialization.Serializer; +import org.apache.kafka.common.serialization.StringDeserializer; +import org.apache.kafka.common.serialization.StringSerializer; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.TestInputTopic; +import org.apache.kafka.streams.TestOutputTopic; +import org.apache.kafka.streams.Topology; +import org.apache.kafka.streams.TopologyTestDriver; +import org.apache.kafka.streams.TopologyWrapper; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.processor.StreamPartitioner; +import org.apache.kafka.streams.processor.TimestampExtractor; +import org.apache.kafka.streams.processor.api.Processor; +import org.apache.kafka.streams.processor.api.ProcessorContext; +import org.apache.kafka.streams.processor.api.ProcessorSupplier; +import org.apache.kafka.streams.processor.api.Record; +import org.apache.kafka.streams.state.KeyValueBytesStoreSupplier; +import org.apache.kafka.streams.state.KeyValueStore; +import org.apache.kafka.streams.state.StoreBuilder; +import org.apache.kafka.streams.state.Stores; +import org.apache.kafka.streams.state.KeyValueIterator; +import org.apache.kafka.streams.test.TestRecord; +import org.apache.kafka.test.MockApiProcessorSupplier; +import org.apache.kafka.test.TestUtils; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import java.io.File; +import java.time.Duration; +import java.time.Instant; +import java.util.Arrays; +import java.util.Collections; +import java.util.Properties; +import java.util.Set; +import java.util.ArrayList; +import java.util.List; +import java.util.function.Supplier; + +import static java.util.Arrays.asList; +import static org.apache.kafka.common.utils.Utils.mkEntry; +import static org.apache.kafka.common.utils.Utils.mkMap; +import static org.apache.kafka.common.utils.Utils.mkSet; +import static org.hamcrest.CoreMatchers.containsString; +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.CoreMatchers.nullValue; +import static org.hamcrest.CoreMatchers.startsWith; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; + +public class ProcessorTopologyTest { + + private static final Serializer STRING_SERIALIZER = new StringSerializer(); + private static final Deserializer STRING_DESERIALIZER = new StringDeserializer(); + + private static final String DEFAULT_STORE_NAME = "prefixScanStore"; + private static final String DEFAULT_PREFIX = "key"; + + private static final String INPUT_TOPIC_1 = "input-topic-1"; + private static final String INPUT_TOPIC_2 = "input-topic-2"; + private static final String OUTPUT_TOPIC_1 = "output-topic-1"; + private static final String OUTPUT_TOPIC_2 = "output-topic-2"; + private static final String 
THROUGH_TOPIC_1 = "through-topic-1"; + + private static final Header HEADER = new RecordHeader("key", "value".getBytes()); + private static final Headers HEADERS = new RecordHeaders(new Header[]{HEADER}); + + private final TopologyWrapper topology = new TopologyWrapper(); + private final MockApiProcessorSupplier mockProcessorSupplier = new MockApiProcessorSupplier<>(); + + private TopologyTestDriver driver; + private final Properties props = new Properties(); + + @Before + public void setup() { + // Create a new directory in which we'll put all of the state for this test, enabling running tests in parallel ... + final File localState = TestUtils.tempDirectory(); + props.setProperty(StreamsConfig.STATE_DIR_CONFIG, localState.getAbsolutePath()); + props.setProperty(StreamsConfig.DEFAULT_KEY_SERDE_CLASS_CONFIG, Serdes.String().getClass().getName()); + props.setProperty(StreamsConfig.DEFAULT_VALUE_SERDE_CLASS_CONFIG, Serdes.String().getClass().getName()); + props.setProperty(StreamsConfig.DEFAULT_TIMESTAMP_EXTRACTOR_CLASS_CONFIG, CustomTimestampExtractor.class.getName()); + } + + @After + public void cleanup() { + props.clear(); + if (driver != null) { + driver.close(); + } + driver = null; + } + + private List> prefixScanResults(final KeyValueStore store, final String prefix) { + final List> results = new ArrayList<>(); + try (final KeyValueIterator prefixScan = store.prefixScan(prefix, Serdes.String().serializer())) { + while (prefixScan.hasNext()) { + final KeyValue next = prefixScan.next(); + results.add(next); + } + } + + return results; + } + + @Test + public void testTopologyMetadata() { + topology.addSource("source-1", "topic-1"); + topology.addSource("source-2", "topic-2", "topic-3"); + topology.addProcessor("processor-1", new MockApiProcessorSupplier<>(), "source-1"); + topology.addProcessor("processor-2", new MockApiProcessorSupplier<>(), "source-1", "source-2"); + topology.addSink("sink-1", "topic-3", "processor-1"); + topology.addSink("sink-2", "topic-4", "processor-1", "processor-2"); + + final ProcessorTopology processorTopology = topology.getInternalBuilder("X").buildTopology(); + + assertEquals(6, processorTopology.processors().size()); + + assertEquals(2, processorTopology.sources().size()); + + assertEquals(3, processorTopology.sourceTopics().size()); + + assertNotNull(processorTopology.source("topic-1")); + + assertNotNull(processorTopology.source("topic-2")); + + assertNotNull(processorTopology.source("topic-3")); + + assertEquals(processorTopology.source("topic-2"), processorTopology.source("topic-3")); + } + + @Test + public void shouldGetTerminalNodes() { + topology.addSource("source-1", "topic-1"); + topology.addSource("source-2", "topic-2", "topic-3"); + topology.addProcessor("processor-1", new MockApiProcessorSupplier<>(), "source-1"); + topology.addProcessor("processor-2", new MockApiProcessorSupplier<>(), "source-1", "source-2"); + topology.addSink("sink-1", "topic-3", "processor-1"); + + final ProcessorTopology processorTopology = topology.getInternalBuilder("X").buildTopology(); + + assertThat(processorTopology.terminalNodes(), equalTo(mkSet("processor-2", "sink-1"))); + } + + @Test + public void shouldUpdateSourceTopicsWithNewMatchingTopic() { + final String sourceNode = "source-1"; + final String topic = "topic-1"; + final String newTopic = "topic-2"; + topology.addSource(sourceNode, topic); + final ProcessorTopology processorTopology = topology.getInternalBuilder("X").buildTopology(); + assertThat(processorTopology.source(newTopic), is(nullValue())); + + 
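+        // after the update, the new topic should resolve to the existing source node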
processorTopology.updateSourceTopics(Collections.singletonMap(sourceNode, asList(topic, newTopic))); + + assertThat(processorTopology.source(newTopic).name(), equalTo(sourceNode)); + } + + @Test + public void shouldUpdateSourceTopicsWithRemovedTopic() { + final String sourceNode = "source-1"; + final String topic = "topic-1"; + final String topicToRemove = "topic-2"; + topology.addSource(sourceNode, topic, topicToRemove); + final ProcessorTopology processorTopology = topology.getInternalBuilder("X").buildTopology(); + assertThat(processorTopology.source(topicToRemove).name(), equalTo(sourceNode)); + + processorTopology.updateSourceTopics(Collections.singletonMap(sourceNode, Collections.singletonList(topic))); + + assertThat(processorTopology.source(topicToRemove), is(nullValue())); + } + + @Test + public void shouldUpdateSourceTopicsWithAllTopicsRemoved() { + final String sourceNode = "source-1"; + final String topic = "topic-1"; + topology.addSource(sourceNode, topic); + final ProcessorTopology processorTopology = topology.getInternalBuilder("X").buildTopology(); + assertThat(processorTopology.source(topic).name(), equalTo(sourceNode)); + + processorTopology.updateSourceTopics(Collections.singletonMap(sourceNode, Collections.emptyList())); + + assertThat(processorTopology.source(topic), is(nullValue())); + } + + @Test + public void shouldUpdateSourceTopicsOnlyForSourceNodesWithinTheSubtopology() { + final String sourceNodeWithinSubtopology = "source-1"; + final String sourceNodeOutsideSubtopology = "source-2"; + final String topicWithinSubtopology = "topic-1"; + final String topicOutsideSubtopology = "topic-2"; + topology.addSource(sourceNodeWithinSubtopology, topicWithinSubtopology); + final ProcessorTopology processorTopology = topology.getInternalBuilder("X").buildTopology(); + + processorTopology.updateSourceTopics(mkMap( + mkEntry(sourceNodeWithinSubtopology, Collections.singletonList(topicWithinSubtopology)), + mkEntry(sourceNodeOutsideSubtopology, Collections.singletonList(topicOutsideSubtopology)) + ) + ); + + assertThat(processorTopology.source(topicOutsideSubtopology), is(nullValue())); + assertThat(processorTopology.sources().size(), equalTo(1)); + } + + @Test + public void shouldThrowIfSourceNodeToUpdateDoesNotExist() { + final String existingSourceNode = "source-1"; + final String nonExistingSourceNode = "source-2"; + final String topicOfExistingSourceNode = "topic-1"; + final String topicOfNonExistingSourceNode = "topic-2"; + topology.addSource(nonExistingSourceNode, topicOfNonExistingSourceNode); + final ProcessorTopology processorTopology = topology.getInternalBuilder("X").buildTopology(); + + final Throwable exception = assertThrows( + IllegalStateException.class, + () -> processorTopology.updateSourceTopics(Collections.singletonMap( + existingSourceNode, Collections.singletonList(topicOfExistingSourceNode) + )) + ); + assertThat(exception.getMessage(), is("Node " + nonExistingSourceNode + " not found in full topology")); + } + + @Test + public void shouldThrowIfMultipleSourceNodeOfSameSubtopologySubscribedToSameTopic() { + final String sourceNode = "source-1"; + final String updatedSourceNode = "source-2"; + final String doublySubscribedTopic = "topic-1"; + final String topic = "topic-2"; + topology.addSource(sourceNode, doublySubscribedTopic); + topology.addSource(updatedSourceNode, topic); + final ProcessorTopology processorTopology = topology.getInternalBuilder("X").buildTopology(); + + final Throwable exception = assertThrows( + IllegalStateException.class, + () -> 
processorTopology.updateSourceTopics(mkMap( + mkEntry(sourceNode, Collections.singletonList(doublySubscribedTopic)), + mkEntry(updatedSourceNode, Arrays.asList(topic, doublySubscribedTopic)) + )) + ); + assertThat( + exception.getMessage(), + startsWith("Topic " + doublySubscribedTopic + " was already registered to source node") + ); + } + + @Test + public void testDrivingSimpleTopology() { + final int partition = 10; + driver = new TopologyTestDriver(createSimpleTopology(partition), props); + final TestInputTopic inputTopic = driver.createInputTopic(INPUT_TOPIC_1, STRING_SERIALIZER, STRING_SERIALIZER, Instant.ofEpochMilli(0L), Duration.ZERO); + final TestOutputTopic outputTopic1 = + driver.createOutputTopic(OUTPUT_TOPIC_1, Serdes.String().deserializer(), Serdes.String().deserializer()); + + inputTopic.pipeInput("key1", "value1"); + assertNextOutputRecord(outputTopic1.readRecord(), "key1", "value1"); + assertTrue(outputTopic1.isEmpty()); + + inputTopic.pipeInput("key2", "value2"); + assertNextOutputRecord(outputTopic1.readRecord(), "key2", "value2"); + assertTrue(outputTopic1.isEmpty()); + + inputTopic.pipeInput("key3", "value3"); + inputTopic.pipeInput("key4", "value4"); + inputTopic.pipeInput("key5", "value5"); + assertNextOutputRecord(outputTopic1.readRecord(), "key3", "value3"); + assertNextOutputRecord(outputTopic1.readRecord(), "key4", "value4"); + assertNextOutputRecord(outputTopic1.readRecord(), "key5", "value5"); + assertTrue(outputTopic1.isEmpty()); + } + + @Test + public void testDrivingStatefulTopology() { + final String storeName = "entries"; + driver = new TopologyTestDriver(createStatefulTopology(storeName), props); + final TestInputTopic inputTopic = driver.createInputTopic(INPUT_TOPIC_1, STRING_SERIALIZER, STRING_SERIALIZER); + final TestOutputTopic outputTopic1 = + driver.createOutputTopic(OUTPUT_TOPIC_1, Serdes.Integer().deserializer(), Serdes.String().deserializer()); + + inputTopic.pipeInput("key1", "value1"); + inputTopic.pipeInput("key2", "value2"); + inputTopic.pipeInput("key3", "value3"); + inputTopic.pipeInput("key1", "value4"); + assertTrue(outputTopic1.isEmpty()); + + final KeyValueStore store = driver.getKeyValueStore(storeName); + assertEquals("value4", store.get("key1")); + assertEquals("value2", store.get("key2")); + assertEquals("value3", store.get("key3")); + assertNull(store.get("key4")); + } + + @Test + public void testDrivingConnectedStateStoreTopology() { + driver = new TopologyTestDriver(createConnectedStateStoreTopology("connectedStore"), props); + final TestInputTopic inputTopic = driver.createInputTopic(INPUT_TOPIC_1, STRING_SERIALIZER, STRING_SERIALIZER); + final TestOutputTopic outputTopic1 = + driver.createOutputTopic(OUTPUT_TOPIC_1, Serdes.Integer().deserializer(), Serdes.String().deserializer()); + + inputTopic.pipeInput("key1", "value1"); + inputTopic.pipeInput("key2", "value2"); + inputTopic.pipeInput("key3", "value3"); + inputTopic.pipeInput("key1", "value4"); + assertTrue(outputTopic1.isEmpty()); + + final KeyValueStore store = driver.getKeyValueStore("connectedStore"); + assertEquals("value4", store.get("key1")); + assertEquals("value2", store.get("key2")); + assertEquals("value3", store.get("key3")); + assertNull(store.get("key4")); + } + + @Deprecated // testing old PAPI + @Test + public void testDrivingConnectedStateStoreInDifferentProcessorsTopologyWithOldAPI() { + final String storeName = "connectedStore"; + final StoreBuilder> storeBuilder = + Stores.keyValueStoreBuilder(Stores.inMemoryKeyValueStore(storeName), Serdes.String(), 
Serdes.String()); + topology + .addSource("source1", STRING_DESERIALIZER, STRING_DESERIALIZER, INPUT_TOPIC_1) + .addSource("source2", STRING_DESERIALIZER, STRING_DESERIALIZER, INPUT_TOPIC_2) + .addProcessor("processor1", defineWithStoresOldAPI(() -> new OldAPIStatefulProcessor(storeName), Collections.singleton(storeBuilder)), "source1") + .addProcessor("processor2", defineWithStoresOldAPI(() -> new OldAPIStatefulProcessor(storeName), Collections.singleton(storeBuilder)), "source2") + .addSink("counts", OUTPUT_TOPIC_1, "processor1", "processor2"); + + driver = new TopologyTestDriver(topology, props); + + final TestInputTopic inputTopic = driver.createInputTopic(INPUT_TOPIC_1, STRING_SERIALIZER, STRING_SERIALIZER); + final TestOutputTopic outputTopic1 = + driver.createOutputTopic(OUTPUT_TOPIC_1, Serdes.Integer().deserializer(), Serdes.String().deserializer()); + + inputTopic.pipeInput("key1", "value1"); + inputTopic.pipeInput("key2", "value2"); + inputTopic.pipeInput("key3", "value3"); + inputTopic.pipeInput("key1", "value4"); + assertTrue(outputTopic1.isEmpty()); + + final KeyValueStore store = driver.getKeyValueStore("connectedStore"); + assertEquals("value4", store.get("key1")); + assertEquals("value2", store.get("key2")); + assertEquals("value3", store.get("key3")); + assertNull(store.get("key4")); + } + + @Test + public void testDrivingConnectedStateStoreInDifferentProcessorsTopology() { + final String storeName = "connectedStore"; + final StoreBuilder> storeBuilder = + Stores.keyValueStoreBuilder(Stores.inMemoryKeyValueStore(storeName), Serdes.String(), Serdes.String()); + topology + .addSource("source1", STRING_DESERIALIZER, STRING_DESERIALIZER, INPUT_TOPIC_1) + .addSource("source2", STRING_DESERIALIZER, STRING_DESERIALIZER, INPUT_TOPIC_2) + .addProcessor("processor1", defineWithStores(() -> new StatefulProcessor(storeName), Collections.singleton(storeBuilder)), "source1") + .addProcessor("processor2", defineWithStores(() -> new StatefulProcessor(storeName), Collections.singleton(storeBuilder)), "source2") + .addSink("counts", OUTPUT_TOPIC_1, "processor1", "processor2"); + + driver = new TopologyTestDriver(topology, props); + + final TestInputTopic inputTopic = driver.createInputTopic(INPUT_TOPIC_1, STRING_SERIALIZER, STRING_SERIALIZER); + final TestOutputTopic outputTopic1 = + driver.createOutputTopic(OUTPUT_TOPIC_1, Serdes.Integer().deserializer(), Serdes.String().deserializer()); + + inputTopic.pipeInput("key1", "value1"); + inputTopic.pipeInput("key2", "value2"); + inputTopic.pipeInput("key3", "value3"); + inputTopic.pipeInput("key1", "value4"); + assertTrue(outputTopic1.isEmpty()); + + final KeyValueStore store = driver.getKeyValueStore("connectedStore"); + assertEquals("value4", store.get("key1")); + assertEquals("value2", store.get("key2")); + assertEquals("value3", store.get("key3")); + assertNull(store.get("key4")); + } + + @Test + public void testPrefixScanInMemoryStoreNoCachingNoLogging() { + final StoreBuilder> storeBuilder = + Stores.keyValueStoreBuilder(Stores.inMemoryKeyValueStore(DEFAULT_STORE_NAME), Serdes.String(), Serdes.String()) + .withCachingDisabled() + .withLoggingDisabled(); + topology + .addSource("source1", STRING_DESERIALIZER, STRING_DESERIALIZER, INPUT_TOPIC_1) + .addProcessor("processor1", defineWithStores(() -> new StatefulProcessor(DEFAULT_STORE_NAME), Collections.singleton(storeBuilder)), "source1") + .addSink("counts", OUTPUT_TOPIC_1, "processor1"); + + driver = new TopologyTestDriver(topology, props); + + final TestInputTopic inputTopic = 
driver.createInputTopic(INPUT_TOPIC_1, STRING_SERIALIZER, STRING_SERIALIZER); + final TestOutputTopic outputTopic1 = + driver.createOutputTopic(OUTPUT_TOPIC_1, Serdes.Integer().deserializer(), Serdes.String().deserializer()); + + inputTopic.pipeInput("key1", "value1"); + inputTopic.pipeInput("key2", "value2"); + inputTopic.pipeInput("key3", "value3"); + inputTopic.pipeInput("key1", "value4"); + assertTrue(outputTopic1.isEmpty()); + + final KeyValueStore store = driver.getKeyValueStore(DEFAULT_STORE_NAME); + final List> results = prefixScanResults(store, DEFAULT_PREFIX); + + assertEquals("key1", results.get(0).key); + assertEquals("value4", results.get(0).value); + assertEquals("key2", results.get(1).key); + assertEquals("value2", results.get(1).value); + assertEquals("key3", results.get(2).key); + assertEquals("value3", results.get(2).value); + + } + + @Test + public void testPrefixScanInMemoryStoreWithCachingNoLogging() { + final StoreBuilder> storeBuilder = + Stores.keyValueStoreBuilder(Stores.inMemoryKeyValueStore(DEFAULT_STORE_NAME), Serdes.String(), Serdes.String()) + .withCachingEnabled() + .withLoggingDisabled(); + topology + .addSource("source1", STRING_DESERIALIZER, STRING_DESERIALIZER, INPUT_TOPIC_1) + .addProcessor("processor1", defineWithStores(() -> new StatefulProcessor(DEFAULT_STORE_NAME), Collections.singleton(storeBuilder)), "source1") + .addSink("counts", OUTPUT_TOPIC_1, "processor1"); + + driver = new TopologyTestDriver(topology, props); + + final TestInputTopic inputTopic = driver.createInputTopic(INPUT_TOPIC_1, STRING_SERIALIZER, STRING_SERIALIZER); + final TestOutputTopic outputTopic1 = + driver.createOutputTopic(OUTPUT_TOPIC_1, Serdes.Integer().deserializer(), Serdes.String().deserializer()); + + inputTopic.pipeInput("key1", "value1"); + inputTopic.pipeInput("key2", "value2"); + inputTopic.pipeInput("key3", "value3"); + inputTopic.pipeInput("key1", "value4"); + assertTrue(outputTopic1.isEmpty()); + + final KeyValueStore store = driver.getKeyValueStore(DEFAULT_STORE_NAME); + final List> results = prefixScanResults(store, DEFAULT_PREFIX); + + assertEquals("key1", results.get(0).key); + assertEquals("value4", results.get(0).value); + assertEquals("key2", results.get(1).key); + assertEquals("value2", results.get(1).value); + assertEquals("key3", results.get(2).key); + assertEquals("value3", results.get(2).value); + + } + + @Test + public void testPrefixScanInMemoryStoreWithCachingWithLogging() { + final StoreBuilder> storeBuilder = + Stores.keyValueStoreBuilder(Stores.inMemoryKeyValueStore(DEFAULT_STORE_NAME), Serdes.String(), Serdes.String()) + .withCachingEnabled() + .withLoggingEnabled(Collections.emptyMap()); + topology + .addSource("source1", STRING_DESERIALIZER, STRING_DESERIALIZER, INPUT_TOPIC_1) + .addProcessor("processor1", defineWithStores(() -> new StatefulProcessor(DEFAULT_STORE_NAME), Collections.singleton(storeBuilder)), "source1") + .addSink("counts", OUTPUT_TOPIC_1, "processor1"); + + driver = new TopologyTestDriver(topology, props); + + final TestInputTopic inputTopic = driver.createInputTopic(INPUT_TOPIC_1, STRING_SERIALIZER, STRING_SERIALIZER); + final TestOutputTopic outputTopic1 = + driver.createOutputTopic(OUTPUT_TOPIC_1, Serdes.Integer().deserializer(), Serdes.String().deserializer()); + + inputTopic.pipeInput("key1", "value1"); + inputTopic.pipeInput("key2", "value2"); + inputTopic.pipeInput("key3", "value3"); + inputTopic.pipeInput("key1", "value4"); + assertTrue(outputTopic1.isEmpty()); + + final KeyValueStore store = 
driver.getKeyValueStore(DEFAULT_STORE_NAME); + final List> results = prefixScanResults(store, DEFAULT_PREFIX); + + assertEquals("key1", results.get(0).key); + assertEquals("value4", results.get(0).value); + assertEquals("key2", results.get(1).key); + assertEquals("value2", results.get(1).value); + assertEquals("key3", results.get(2).key); + assertEquals("value3", results.get(2).value); + + } + + @Test + public void testPrefixScanPersistentStoreNoCachingNoLogging() { + final StoreBuilder> storeBuilder = + Stores.keyValueStoreBuilder(Stores.persistentKeyValueStore(DEFAULT_STORE_NAME), Serdes.String(), Serdes.String()) + .withCachingDisabled() + .withLoggingDisabled(); + topology + .addSource("source1", STRING_DESERIALIZER, STRING_DESERIALIZER, INPUT_TOPIC_1) + .addProcessor("processor1", defineWithStores(() -> new StatefulProcessor(DEFAULT_STORE_NAME), Collections.singleton(storeBuilder)), "source1") + .addSink("counts", OUTPUT_TOPIC_1, "processor1"); + + driver = new TopologyTestDriver(topology, props); + + final TestInputTopic inputTopic = driver.createInputTopic(INPUT_TOPIC_1, STRING_SERIALIZER, STRING_SERIALIZER); + final TestOutputTopic outputTopic1 = + driver.createOutputTopic(OUTPUT_TOPIC_1, Serdes.Integer().deserializer(), Serdes.String().deserializer()); + + inputTopic.pipeInput("key1", "value1"); + inputTopic.pipeInput("key2", "value2"); + inputTopic.pipeInput("key3", "value3"); + inputTopic.pipeInput("key1", "value4"); + assertTrue(outputTopic1.isEmpty()); + + final KeyValueStore store = driver.getKeyValueStore(DEFAULT_STORE_NAME); + final List> results = prefixScanResults(store, DEFAULT_PREFIX); + + assertEquals("key1", results.get(0).key); + assertEquals("value4", results.get(0).value); + assertEquals("key2", results.get(1).key); + assertEquals("value2", results.get(1).value); + assertEquals("key3", results.get(2).key); + assertEquals("value3", results.get(2).value); + + } + + @Test + public void testPrefixScanPersistentStoreWithCachingNoLogging() { + final StoreBuilder> storeBuilder = + Stores.keyValueStoreBuilder(Stores.persistentKeyValueStore(DEFAULT_STORE_NAME), Serdes.String(), Serdes.String()) + .withCachingEnabled() + .withLoggingDisabled(); + topology + .addSource("source1", STRING_DESERIALIZER, STRING_DESERIALIZER, INPUT_TOPIC_1) + .addProcessor("processor1", defineWithStores(() -> new StatefulProcessor(DEFAULT_STORE_NAME), Collections.singleton(storeBuilder)), "source1") + .addSink("counts", OUTPUT_TOPIC_1, "processor1"); + + driver = new TopologyTestDriver(topology, props); + + final TestInputTopic inputTopic = driver.createInputTopic(INPUT_TOPIC_1, STRING_SERIALIZER, STRING_SERIALIZER); + final TestOutputTopic outputTopic1 = + driver.createOutputTopic(OUTPUT_TOPIC_1, Serdes.Integer().deserializer(), Serdes.String().deserializer()); + + inputTopic.pipeInput("key1", "value1"); + inputTopic.pipeInput("key2", "value2"); + inputTopic.pipeInput("key3", "value3"); + inputTopic.pipeInput("key1", "value4"); + assertTrue(outputTopic1.isEmpty()); + + final KeyValueStore store = driver.getKeyValueStore(DEFAULT_STORE_NAME); + final List> results = prefixScanResults(store, DEFAULT_PREFIX); + + assertEquals("key1", results.get(0).key); + assertEquals("value4", results.get(0).value); + assertEquals("key2", results.get(1).key); + assertEquals("value2", results.get(1).value); + assertEquals("key3", results.get(2).key); + assertEquals("value3", results.get(2).value); + + } + + @Test + public void testPrefixScanPersistentStoreWithCachingWithLogging() { + final StoreBuilder> storeBuilder 
= + Stores.keyValueStoreBuilder(Stores.persistentKeyValueStore(DEFAULT_STORE_NAME), Serdes.String(), Serdes.String()) + .withCachingEnabled() + .withLoggingEnabled(Collections.emptyMap()); + topology + .addSource("source1", STRING_DESERIALIZER, STRING_DESERIALIZER, INPUT_TOPIC_1) + .addProcessor("processor1", defineWithStores(() -> new StatefulProcessor(DEFAULT_STORE_NAME), Collections.singleton(storeBuilder)), "source1") + .addSink("counts", OUTPUT_TOPIC_1, "processor1"); + + driver = new TopologyTestDriver(topology, props); + + final TestInputTopic inputTopic = driver.createInputTopic(INPUT_TOPIC_1, STRING_SERIALIZER, STRING_SERIALIZER); + final TestOutputTopic outputTopic1 = + driver.createOutputTopic(OUTPUT_TOPIC_1, Serdes.Integer().deserializer(), Serdes.String().deserializer()); + + inputTopic.pipeInput("key1", "value1"); + inputTopic.pipeInput("key2", "value2"); + inputTopic.pipeInput("key3", "value3"); + inputTopic.pipeInput("key1", "value4"); + assertTrue(outputTopic1.isEmpty()); + + final KeyValueStore store = driver.getKeyValueStore(DEFAULT_STORE_NAME); + final List> results = prefixScanResults(store, DEFAULT_PREFIX); + + assertEquals("key1", results.get(0).key); + assertEquals("value4", results.get(0).value); + assertEquals("key2", results.get(1).key); + assertEquals("value2", results.get(1).value); + assertEquals("key3", results.get(2).key); + assertEquals("value3", results.get(2).value); + + } + + @Test + public void testPrefixScanPersistentTimestampedStoreNoCachingNoLogging() { + final StoreBuilder> storeBuilder = + Stores.keyValueStoreBuilder(Stores.persistentTimestampedKeyValueStore(DEFAULT_STORE_NAME), Serdes.String(), Serdes.String()) + .withCachingDisabled() + .withLoggingDisabled(); + topology + .addSource("source1", STRING_DESERIALIZER, STRING_DESERIALIZER, INPUT_TOPIC_1) + .addProcessor("processor1", defineWithStores(() -> new StatefulProcessor(DEFAULT_STORE_NAME), Collections.singleton(storeBuilder)), "source1") + .addSink("counts", OUTPUT_TOPIC_1, "processor1"); + + driver = new TopologyTestDriver(topology, props); + + final TestInputTopic inputTopic = driver.createInputTopic(INPUT_TOPIC_1, STRING_SERIALIZER, STRING_SERIALIZER); + final TestOutputTopic outputTopic1 = + driver.createOutputTopic(OUTPUT_TOPIC_1, Serdes.Integer().deserializer(), Serdes.String().deserializer()); + + inputTopic.pipeInput("key1", "value1"); + inputTopic.pipeInput("key2", "value2"); + inputTopic.pipeInput("key3", "value3"); + inputTopic.pipeInput("key1", "value4"); + assertTrue(outputTopic1.isEmpty()); + + final KeyValueStore store = driver.getKeyValueStore(DEFAULT_STORE_NAME); + final List> results = prefixScanResults(store, DEFAULT_PREFIX); + + assertEquals("key1", results.get(0).key); + assertEquals("value4", results.get(0).value); + assertEquals("key2", results.get(1).key); + assertEquals("value2", results.get(1).value); + assertEquals("key3", results.get(2).key); + assertEquals("value3", results.get(2).value); + + } + + @Test + public void testPrefixScanPersistentTimestampedStoreWithCachingNoLogging() { + final StoreBuilder> storeBuilder = + Stores.keyValueStoreBuilder(Stores.persistentTimestampedKeyValueStore(DEFAULT_STORE_NAME), Serdes.String(), Serdes.String()) + .withCachingEnabled() + .withLoggingDisabled(); + topology + .addSource("source1", STRING_DESERIALIZER, STRING_DESERIALIZER, INPUT_TOPIC_1) + .addProcessor("processor1", defineWithStores(() -> new StatefulProcessor(DEFAULT_STORE_NAME), Collections.singleton(storeBuilder)), "source1") + .addSink("counts", OUTPUT_TOPIC_1, 
"processor1"); + + driver = new TopologyTestDriver(topology, props); + + final TestInputTopic inputTopic = driver.createInputTopic(INPUT_TOPIC_1, STRING_SERIALIZER, STRING_SERIALIZER); + final TestOutputTopic outputTopic1 = + driver.createOutputTopic(OUTPUT_TOPIC_1, Serdes.Integer().deserializer(), Serdes.String().deserializer()); + + inputTopic.pipeInput("key1", "value1"); + inputTopic.pipeInput("key2", "value2"); + inputTopic.pipeInput("key3", "value3"); + inputTopic.pipeInput("key1", "value4"); + assertTrue(outputTopic1.isEmpty()); + + final KeyValueStore store = driver.getKeyValueStore(DEFAULT_STORE_NAME); + final List> results = prefixScanResults(store, DEFAULT_PREFIX); + + assertEquals("key1", results.get(0).key); + assertEquals("value4", results.get(0).value); + assertEquals("key2", results.get(1).key); + assertEquals("value2", results.get(1).value); + assertEquals("key3", results.get(2).key); + assertEquals("value3", results.get(2).value); + + } + + @Test + public void testPrefixScanPersistentTimestampedStoreWithCachingWithLogging() { + final StoreBuilder> storeBuilder = + Stores.keyValueStoreBuilder(Stores.persistentTimestampedKeyValueStore(DEFAULT_STORE_NAME), Serdes.String(), Serdes.String()) + .withCachingEnabled() + .withLoggingEnabled(Collections.emptyMap()); + topology + .addSource("source1", STRING_DESERIALIZER, STRING_DESERIALIZER, INPUT_TOPIC_1) + .addProcessor("processor1", defineWithStores(() -> new StatefulProcessor(DEFAULT_STORE_NAME), Collections.singleton(storeBuilder)), "source1") + .addSink("counts", OUTPUT_TOPIC_1, "processor1"); + + driver = new TopologyTestDriver(topology, props); + + final TestInputTopic inputTopic = driver.createInputTopic(INPUT_TOPIC_1, STRING_SERIALIZER, STRING_SERIALIZER); + final TestOutputTopic outputTopic1 = + driver.createOutputTopic(OUTPUT_TOPIC_1, Serdes.Integer().deserializer(), Serdes.String().deserializer()); + + inputTopic.pipeInput("key1", "value1"); + inputTopic.pipeInput("key2", "value2"); + inputTopic.pipeInput("key3", "value3"); + inputTopic.pipeInput("key1", "value4"); + assertTrue(outputTopic1.isEmpty()); + + final KeyValueStore store = driver.getKeyValueStore(DEFAULT_STORE_NAME); + final List> results = prefixScanResults(store, DEFAULT_PREFIX); + + assertEquals("key1", results.get(0).key); + assertEquals("value4", results.get(0).value); + assertEquals("key2", results.get(1).key); + assertEquals("value2", results.get(1).value); + assertEquals("key3", results.get(2).key); + assertEquals("value3", results.get(2).value); + + } + + @Test + public void testPrefixScanLruMapNoCachingNoLogging() { + final StoreBuilder> storeBuilder = + Stores.keyValueStoreBuilder(Stores.lruMap(DEFAULT_STORE_NAME, 100), Serdes.String(), Serdes.String()) + .withCachingDisabled() + .withLoggingDisabled(); + topology + .addSource("source1", STRING_DESERIALIZER, STRING_DESERIALIZER, INPUT_TOPIC_1) + .addProcessor("processor1", defineWithStores(() -> new StatefulProcessor(DEFAULT_STORE_NAME), Collections.singleton(storeBuilder)), "source1") + .addSink("counts", OUTPUT_TOPIC_1, "processor1"); + + driver = new TopologyTestDriver(topology, props); + + final TestInputTopic inputTopic = driver.createInputTopic(INPUT_TOPIC_1, STRING_SERIALIZER, STRING_SERIALIZER); + final TestOutputTopic outputTopic1 = + driver.createOutputTopic(OUTPUT_TOPIC_1, Serdes.Integer().deserializer(), Serdes.String().deserializer()); + + inputTopic.pipeInput("key1", "value1"); + inputTopic.pipeInput("key2", "value2"); + inputTopic.pipeInput("key3", "value3"); + 
inputTopic.pipeInput("key1", "value4"); + assertTrue(outputTopic1.isEmpty()); + + final KeyValueStore store = driver.getKeyValueStore(DEFAULT_STORE_NAME); + final List> results = prefixScanResults(store, DEFAULT_PREFIX); + + assertEquals("key1", results.get(0).key); + assertEquals("value4", results.get(0).value); + assertEquals("key2", results.get(1).key); + assertEquals("value2", results.get(1).value); + assertEquals("key3", results.get(2).key); + assertEquals("value3", results.get(2).value); + + } + + @Test + public void testPrefixScanLruMapWithCachingNoLogging() { + final StoreBuilder> storeBuilder = + Stores.keyValueStoreBuilder(Stores.lruMap(DEFAULT_STORE_NAME, 100), Serdes.String(), Serdes.String()) + .withCachingEnabled() + .withLoggingDisabled(); + topology + .addSource("source1", STRING_DESERIALIZER, STRING_DESERIALIZER, INPUT_TOPIC_1) + .addProcessor("processor1", defineWithStores(() -> new StatefulProcessor(DEFAULT_STORE_NAME), Collections.singleton(storeBuilder)), "source1") + .addSink("counts", OUTPUT_TOPIC_1, "processor1"); + + driver = new TopologyTestDriver(topology, props); + + final TestInputTopic inputTopic = driver.createInputTopic(INPUT_TOPIC_1, STRING_SERIALIZER, STRING_SERIALIZER); + final TestOutputTopic outputTopic1 = + driver.createOutputTopic(OUTPUT_TOPIC_1, Serdes.Integer().deserializer(), Serdes.String().deserializer()); + + inputTopic.pipeInput("key1", "value1"); + inputTopic.pipeInput("key2", "value2"); + inputTopic.pipeInput("key3", "value3"); + inputTopic.pipeInput("key1", "value4"); + assertTrue(outputTopic1.isEmpty()); + + final KeyValueStore store = driver.getKeyValueStore(DEFAULT_STORE_NAME); + final List> results = prefixScanResults(store, DEFAULT_PREFIX); + + assertEquals("key1", results.get(0).key); + assertEquals("value4", results.get(0).value); + assertEquals("key2", results.get(1).key); + assertEquals("value2", results.get(1).value); + assertEquals("key3", results.get(2).key); + assertEquals("value3", results.get(2).value); + + } + + @Test + public void testPrefixScanLruMapWithCachingWithLogging() { + final StoreBuilder> storeBuilder = + Stores.keyValueStoreBuilder(Stores.lruMap(DEFAULT_STORE_NAME, 100), Serdes.String(), Serdes.String()) + .withCachingEnabled() + .withLoggingEnabled(Collections.emptyMap()); + topology + .addSource("source1", STRING_DESERIALIZER, STRING_DESERIALIZER, INPUT_TOPIC_1) + .addProcessor("processor1", defineWithStores(() -> new StatefulProcessor(DEFAULT_STORE_NAME), Collections.singleton(storeBuilder)), "source1") + .addSink("counts", OUTPUT_TOPIC_1, "processor1"); + + driver = new TopologyTestDriver(topology, props); + + final TestInputTopic inputTopic = driver.createInputTopic(INPUT_TOPIC_1, STRING_SERIALIZER, STRING_SERIALIZER); + final TestOutputTopic outputTopic1 = + driver.createOutputTopic(OUTPUT_TOPIC_1, Serdes.Integer().deserializer(), Serdes.String().deserializer()); + + inputTopic.pipeInput("key1", "value1"); + inputTopic.pipeInput("key2", "value2"); + inputTopic.pipeInput("key3", "value3"); + inputTopic.pipeInput("key1", "value4"); + assertTrue(outputTopic1.isEmpty()); + + final KeyValueStore store = driver.getKeyValueStore(DEFAULT_STORE_NAME); + final List> results = prefixScanResults(store, DEFAULT_PREFIX); + + assertEquals("key1", results.get(0).key); + assertEquals("value4", results.get(0).value); + assertEquals("key2", results.get(1).key); + assertEquals("value2", results.get(1).value); + assertEquals("key3", results.get(2).key); + assertEquals("value3", results.get(2).value); + + } + + @Deprecated // 
testing old PAPI + @Test + public void testPrefixScanInMemoryStoreNoCachingNoLoggingOldProcessor() { + final StoreBuilder> storeBuilder = + Stores.keyValueStoreBuilder(Stores.inMemoryKeyValueStore(DEFAULT_STORE_NAME), Serdes.String(), Serdes.String()) + .withCachingDisabled() + .withLoggingDisabled(); + topology + .addSource("source1", STRING_DESERIALIZER, STRING_DESERIALIZER, INPUT_TOPIC_1) + .addProcessor("processor1", defineWithStoresOldAPI(() -> new OldAPIStatefulProcessor(DEFAULT_STORE_NAME), Collections.singleton(storeBuilder)), "source1") + .addSink("counts", OUTPUT_TOPIC_1, "processor1"); + + driver = new TopologyTestDriver(topology, props); + + final TestInputTopic inputTopic = driver.createInputTopic(INPUT_TOPIC_1, STRING_SERIALIZER, STRING_SERIALIZER); + final TestOutputTopic outputTopic1 = + driver.createOutputTopic(OUTPUT_TOPIC_1, Serdes.Integer().deserializer(), Serdes.String().deserializer()); + + inputTopic.pipeInput("key1", "value1"); + inputTopic.pipeInput("key2", "value2"); + inputTopic.pipeInput("key3", "value3"); + inputTopic.pipeInput("key1", "value4"); + assertTrue(outputTopic1.isEmpty()); + + final KeyValueStore store = driver.getKeyValueStore(DEFAULT_STORE_NAME); + final List> results = prefixScanResults(store, DEFAULT_PREFIX); + + assertEquals("key1", results.get(0).key); + assertEquals("value4", results.get(0).value); + assertEquals("key2", results.get(1).key); + assertEquals("value2", results.get(1).value); + assertEquals("key3", results.get(2).key); + assertEquals("value3", results.get(2).value); + + } + + @Deprecated // testing old PAPI + @Test + public void testPrefixScanInMemoryStoreWithCachingNoLoggingOldProcessor() { + final StoreBuilder> storeBuilder = + Stores.keyValueStoreBuilder(Stores.inMemoryKeyValueStore(DEFAULT_STORE_NAME), Serdes.String(), Serdes.String()) + .withCachingEnabled() + .withLoggingDisabled(); + topology + .addSource("source1", STRING_DESERIALIZER, STRING_DESERIALIZER, INPUT_TOPIC_1) + .addProcessor("processor1", defineWithStoresOldAPI(() -> new OldAPIStatefulProcessor(DEFAULT_STORE_NAME), Collections.singleton(storeBuilder)), "source1") + .addSink("counts", OUTPUT_TOPIC_1, "processor1"); + + driver = new TopologyTestDriver(topology, props); + + final TestInputTopic inputTopic = driver.createInputTopic(INPUT_TOPIC_1, STRING_SERIALIZER, STRING_SERIALIZER); + final TestOutputTopic outputTopic1 = + driver.createOutputTopic(OUTPUT_TOPIC_1, Serdes.Integer().deserializer(), Serdes.String().deserializer()); + + inputTopic.pipeInput("key1", "value1"); + inputTopic.pipeInput("key2", "value2"); + inputTopic.pipeInput("key3", "value3"); + inputTopic.pipeInput("key1", "value4"); + assertTrue(outputTopic1.isEmpty()); + + final KeyValueStore store = driver.getKeyValueStore(DEFAULT_STORE_NAME); + final List> results = prefixScanResults(store, DEFAULT_PREFIX); + + assertEquals("key1", results.get(0).key); + assertEquals("value4", results.get(0).value); + assertEquals("key2", results.get(1).key); + assertEquals("value2", results.get(1).value); + assertEquals("key3", results.get(2).key); + assertEquals("value3", results.get(2).value); + + } + + @Deprecated // testing old PAPI + @Test + public void testPrefixScanInMemoryStoreWithCachingWithLoggingOldProcessor() { + final StoreBuilder> storeBuilder = + Stores.keyValueStoreBuilder(Stores.inMemoryKeyValueStore(DEFAULT_STORE_NAME), Serdes.String(), Serdes.String()) + .withCachingEnabled() + .withLoggingEnabled(Collections.emptyMap()); + topology + .addSource("source1", STRING_DESERIALIZER, 
STRING_DESERIALIZER, INPUT_TOPIC_1) + .addProcessor("processor1", defineWithStoresOldAPI(() -> new OldAPIStatefulProcessor(DEFAULT_STORE_NAME), Collections.singleton(storeBuilder)), "source1") + .addSink("counts", OUTPUT_TOPIC_1, "processor1"); + + driver = new TopologyTestDriver(topology, props); + + final TestInputTopic inputTopic = driver.createInputTopic(INPUT_TOPIC_1, STRING_SERIALIZER, STRING_SERIALIZER); + final TestOutputTopic outputTopic1 = + driver.createOutputTopic(OUTPUT_TOPIC_1, Serdes.Integer().deserializer(), Serdes.String().deserializer()); + + inputTopic.pipeInput("key1", "value1"); + inputTopic.pipeInput("key2", "value2"); + inputTopic.pipeInput("key3", "value3"); + inputTopic.pipeInput("key1", "value4"); + assertTrue(outputTopic1.isEmpty()); + + final KeyValueStore store = driver.getKeyValueStore(DEFAULT_STORE_NAME); + final List> results = prefixScanResults(store, DEFAULT_PREFIX); + + assertEquals("key1", results.get(0).key); + assertEquals("value4", results.get(0).value); + assertEquals("key2", results.get(1).key); + assertEquals("value2", results.get(1).value); + assertEquals("key3", results.get(2).key); + assertEquals("value3", results.get(2).value); + + } + + @Deprecated // testing old PAPI + @Test + public void testPrefixScanPersistentStoreNoCachingNoLoggingOldProcessor() { + final StoreBuilder> storeBuilder = + Stores.keyValueStoreBuilder(Stores.persistentKeyValueStore(DEFAULT_STORE_NAME), Serdes.String(), Serdes.String()) + .withCachingDisabled() + .withLoggingDisabled(); + topology + .addSource("source1", STRING_DESERIALIZER, STRING_DESERIALIZER, INPUT_TOPIC_1) + .addProcessor("processor1", defineWithStoresOldAPI(() -> new OldAPIStatefulProcessor(DEFAULT_STORE_NAME), Collections.singleton(storeBuilder)), "source1") + .addSink("counts", OUTPUT_TOPIC_1, "processor1"); + + driver = new TopologyTestDriver(topology, props); + + final TestInputTopic inputTopic = driver.createInputTopic(INPUT_TOPIC_1, STRING_SERIALIZER, STRING_SERIALIZER); + final TestOutputTopic outputTopic1 = + driver.createOutputTopic(OUTPUT_TOPIC_1, Serdes.Integer().deserializer(), Serdes.String().deserializer()); + + inputTopic.pipeInput("key1", "value1"); + inputTopic.pipeInput("key2", "value2"); + inputTopic.pipeInput("key3", "value3"); + inputTopic.pipeInput("key1", "value4"); + assertTrue(outputTopic1.isEmpty()); + + final KeyValueStore store = driver.getKeyValueStore(DEFAULT_STORE_NAME); + final List> results = prefixScanResults(store, DEFAULT_PREFIX); + + assertEquals("key1", results.get(0).key); + assertEquals("value4", results.get(0).value); + assertEquals("key2", results.get(1).key); + assertEquals("value2", results.get(1).value); + assertEquals("key3", results.get(2).key); + assertEquals("value3", results.get(2).value); + + } + + @Deprecated // testing old PAPI + @Test + public void testPrefixScanPersistentStoreWithCachingNoLoggingOldProcessor() { + final StoreBuilder> storeBuilder = + Stores.keyValueStoreBuilder(Stores.persistentKeyValueStore(DEFAULT_STORE_NAME), Serdes.String(), Serdes.String()) + .withCachingEnabled() + .withLoggingDisabled(); + topology + .addSource("source1", STRING_DESERIALIZER, STRING_DESERIALIZER, INPUT_TOPIC_1) + .addProcessor("processor1", defineWithStoresOldAPI(() -> new OldAPIStatefulProcessor(DEFAULT_STORE_NAME), Collections.singleton(storeBuilder)), "source1") + .addSink("counts", OUTPUT_TOPIC_1, "processor1"); + + driver = new TopologyTestDriver(topology, props); + + final TestInputTopic inputTopic = driver.createInputTopic(INPUT_TOPIC_1, 
STRING_SERIALIZER, STRING_SERIALIZER); + final TestOutputTopic outputTopic1 = + driver.createOutputTopic(OUTPUT_TOPIC_1, Serdes.Integer().deserializer(), Serdes.String().deserializer()); + + inputTopic.pipeInput("key1", "value1"); + inputTopic.pipeInput("key2", "value2"); + inputTopic.pipeInput("key3", "value3"); + inputTopic.pipeInput("key1", "value4"); + assertTrue(outputTopic1.isEmpty()); + + final KeyValueStore store = driver.getKeyValueStore(DEFAULT_STORE_NAME); + final List> results = prefixScanResults(store, DEFAULT_PREFIX); + + assertEquals("key1", results.get(0).key); + assertEquals("value4", results.get(0).value); + assertEquals("key2", results.get(1).key); + assertEquals("value2", results.get(1).value); + assertEquals("key3", results.get(2).key); + assertEquals("value3", results.get(2).value); + + } + + @Deprecated // testing old PAPI + @Test + public void testPrefixScanPersistentStoreWithCachingWithLoggingOldProcessor() { + final StoreBuilder> storeBuilder = + Stores.keyValueStoreBuilder(Stores.persistentKeyValueStore(DEFAULT_STORE_NAME), Serdes.String(), Serdes.String()) + .withCachingEnabled() + .withLoggingEnabled(Collections.emptyMap()); + topology + .addSource("source1", STRING_DESERIALIZER, STRING_DESERIALIZER, INPUT_TOPIC_1) + .addProcessor("processor1", defineWithStoresOldAPI(() -> new OldAPIStatefulProcessor(DEFAULT_STORE_NAME), Collections.singleton(storeBuilder)), "source1") + .addSink("counts", OUTPUT_TOPIC_1, "processor1"); + + driver = new TopologyTestDriver(topology, props); + + final TestInputTopic inputTopic = driver.createInputTopic(INPUT_TOPIC_1, STRING_SERIALIZER, STRING_SERIALIZER); + final TestOutputTopic outputTopic1 = + driver.createOutputTopic(OUTPUT_TOPIC_1, Serdes.Integer().deserializer(), Serdes.String().deserializer()); + + inputTopic.pipeInput("key1", "value1"); + inputTopic.pipeInput("key2", "value2"); + inputTopic.pipeInput("key3", "value3"); + inputTopic.pipeInput("key1", "value4"); + assertTrue(outputTopic1.isEmpty()); + + final KeyValueStore store = driver.getKeyValueStore(DEFAULT_STORE_NAME); + final List> results = prefixScanResults(store, DEFAULT_PREFIX); + + assertEquals("key1", results.get(0).key); + assertEquals("value4", results.get(0).value); + assertEquals("key2", results.get(1).key); + assertEquals("value2", results.get(1).value); + assertEquals("key3", results.get(2).key); + assertEquals("value3", results.get(2).value); + + } + + @Deprecated // testing old PAPI + @Test + public void testPrefixScanPersistentTimestampedStoreNoCachingNoLoggingOldProcessor() { + final StoreBuilder> storeBuilder = + Stores.keyValueStoreBuilder(Stores.persistentTimestampedKeyValueStore(DEFAULT_STORE_NAME), Serdes.String(), Serdes.String()) + .withCachingDisabled() + .withLoggingDisabled(); + topology + .addSource("source1", STRING_DESERIALIZER, STRING_DESERIALIZER, INPUT_TOPIC_1) + .addProcessor("processor1", defineWithStoresOldAPI(() -> new OldAPIStatefulProcessor(DEFAULT_STORE_NAME), Collections.singleton(storeBuilder)), "source1") + .addSink("counts", OUTPUT_TOPIC_1, "processor1"); + + driver = new TopologyTestDriver(topology, props); + + final TestInputTopic inputTopic = driver.createInputTopic(INPUT_TOPIC_1, STRING_SERIALIZER, STRING_SERIALIZER); + final TestOutputTopic outputTopic1 = + driver.createOutputTopic(OUTPUT_TOPIC_1, Serdes.Integer().deserializer(), Serdes.String().deserializer()); + + inputTopic.pipeInput("key1", "value1"); + inputTopic.pipeInput("key2", "value2"); + inputTopic.pipeInput("key3", "value3"); + inputTopic.pipeInput("key1", 
"value4"); + assertTrue(outputTopic1.isEmpty()); + + final KeyValueStore store = driver.getKeyValueStore(DEFAULT_STORE_NAME); + final List> results = prefixScanResults(store, DEFAULT_PREFIX); + + assertEquals("key1", results.get(0).key); + assertEquals("value4", results.get(0).value); + assertEquals("key2", results.get(1).key); + assertEquals("value2", results.get(1).value); + assertEquals("key3", results.get(2).key); + assertEquals("value3", results.get(2).value); + + } + + @Deprecated // testing old PAPI + @Test + public void testPrefixScanPersistentTimestampedStoreWithCachingNoLoggingOldProcessor() { + final StoreBuilder> storeBuilder = + Stores.keyValueStoreBuilder(Stores.persistentTimestampedKeyValueStore(DEFAULT_STORE_NAME), Serdes.String(), Serdes.String()) + .withCachingEnabled() + .withLoggingDisabled(); + topology + .addSource("source1", STRING_DESERIALIZER, STRING_DESERIALIZER, INPUT_TOPIC_1) + .addProcessor("processor1", defineWithStoresOldAPI(() -> new OldAPIStatefulProcessor(DEFAULT_STORE_NAME), Collections.singleton(storeBuilder)), "source1") + .addSink("counts", OUTPUT_TOPIC_1, "processor1"); + + driver = new TopologyTestDriver(topology, props); + + final TestInputTopic inputTopic = driver.createInputTopic(INPUT_TOPIC_1, STRING_SERIALIZER, STRING_SERIALIZER); + final TestOutputTopic outputTopic1 = + driver.createOutputTopic(OUTPUT_TOPIC_1, Serdes.Integer().deserializer(), Serdes.String().deserializer()); + + inputTopic.pipeInput("key1", "value1"); + inputTopic.pipeInput("key2", "value2"); + inputTopic.pipeInput("key3", "value3"); + inputTopic.pipeInput("key1", "value4"); + assertTrue(outputTopic1.isEmpty()); + + final KeyValueStore store = driver.getKeyValueStore(DEFAULT_STORE_NAME); + final List> results = prefixScanResults(store, DEFAULT_PREFIX); + + assertEquals("key1", results.get(0).key); + assertEquals("value4", results.get(0).value); + assertEquals("key2", results.get(1).key); + assertEquals("value2", results.get(1).value); + assertEquals("key3", results.get(2).key); + assertEquals("value3", results.get(2).value); + + } + + @Deprecated // testing old PAPI + @Test + public void testPrefixScanPersistentTimestampedStoreWithCachingWithLoggingOldProcessor() { + final StoreBuilder> storeBuilder = + Stores.keyValueStoreBuilder(Stores.persistentTimestampedKeyValueStore(DEFAULT_STORE_NAME), Serdes.String(), Serdes.String()) + .withCachingEnabled() + .withLoggingEnabled(Collections.emptyMap()); + topology + .addSource("source1", STRING_DESERIALIZER, STRING_DESERIALIZER, INPUT_TOPIC_1) + .addProcessor("processor1", defineWithStoresOldAPI(() -> new OldAPIStatefulProcessor(DEFAULT_STORE_NAME), Collections.singleton(storeBuilder)), "source1") + .addSink("counts", OUTPUT_TOPIC_1, "processor1"); + + driver = new TopologyTestDriver(topology, props); + + final TestInputTopic inputTopic = driver.createInputTopic(INPUT_TOPIC_1, STRING_SERIALIZER, STRING_SERIALIZER); + final TestOutputTopic outputTopic1 = + driver.createOutputTopic(OUTPUT_TOPIC_1, Serdes.Integer().deserializer(), Serdes.String().deserializer()); + + inputTopic.pipeInput("key1", "value1"); + inputTopic.pipeInput("key2", "value2"); + inputTopic.pipeInput("key3", "value3"); + inputTopic.pipeInput("key1", "value4"); + assertTrue(outputTopic1.isEmpty()); + + final KeyValueStore store = driver.getKeyValueStore(DEFAULT_STORE_NAME); + final List> results = prefixScanResults(store, DEFAULT_PREFIX); + + assertEquals("key1", results.get(0).key); + assertEquals("value4", results.get(0).value); + assertEquals("key2", 
results.get(1).key); + assertEquals("value2", results.get(1).value); + assertEquals("key3", results.get(2).key); + assertEquals("value3", results.get(2).value); + + } + + @Deprecated // testing old PAPI + @Test + public void testPrefixScanLruMapNoCachingNoLoggingOldProcessor() { + final StoreBuilder> storeBuilder = + Stores.keyValueStoreBuilder(Stores.lruMap(DEFAULT_STORE_NAME, 100), Serdes.String(), Serdes.String()) + .withCachingDisabled() + .withLoggingDisabled(); + topology + .addSource("source1", STRING_DESERIALIZER, STRING_DESERIALIZER, INPUT_TOPIC_1) + .addProcessor("processor1", defineWithStoresOldAPI(() -> new OldAPIStatefulProcessor(DEFAULT_STORE_NAME), Collections.singleton(storeBuilder)), "source1") + .addSink("counts", OUTPUT_TOPIC_1, "processor1"); + + driver = new TopologyTestDriver(topology, props); + + final TestInputTopic inputTopic = driver.createInputTopic(INPUT_TOPIC_1, STRING_SERIALIZER, STRING_SERIALIZER); + final TestOutputTopic outputTopic1 = + driver.createOutputTopic(OUTPUT_TOPIC_1, Serdes.Integer().deserializer(), Serdes.String().deserializer()); + + inputTopic.pipeInput("key1", "value1"); + inputTopic.pipeInput("key2", "value2"); + inputTopic.pipeInput("key3", "value3"); + inputTopic.pipeInput("key1", "value4"); + assertTrue(outputTopic1.isEmpty()); + + final KeyValueStore store = driver.getKeyValueStore(DEFAULT_STORE_NAME); + final List> results = prefixScanResults(store, DEFAULT_PREFIX); + + assertEquals("key1", results.get(0).key); + assertEquals("value4", results.get(0).value); + assertEquals("key2", results.get(1).key); + assertEquals("value2", results.get(1).value); + assertEquals("key3", results.get(2).key); + assertEquals("value3", results.get(2).value); + + } + + @Deprecated // testing old PAPI + @Test + public void testPrefixScanLruMapWithCachingNoLoggingOldProcessor() { + final StoreBuilder> storeBuilder = + Stores.keyValueStoreBuilder(Stores.lruMap(DEFAULT_STORE_NAME, 100), Serdes.String(), Serdes.String()) + .withCachingEnabled() + .withLoggingDisabled(); + topology + .addSource("source1", STRING_DESERIALIZER, STRING_DESERIALIZER, INPUT_TOPIC_1) + .addProcessor("processor1", defineWithStoresOldAPI(() -> new OldAPIStatefulProcessor(DEFAULT_STORE_NAME), Collections.singleton(storeBuilder)), "source1") + .addSink("counts", OUTPUT_TOPIC_1, "processor1"); + + driver = new TopologyTestDriver(topology, props); + + final TestInputTopic inputTopic = driver.createInputTopic(INPUT_TOPIC_1, STRING_SERIALIZER, STRING_SERIALIZER); + final TestOutputTopic outputTopic1 = + driver.createOutputTopic(OUTPUT_TOPIC_1, Serdes.Integer().deserializer(), Serdes.String().deserializer()); + + inputTopic.pipeInput("key1", "value1"); + inputTopic.pipeInput("key2", "value2"); + inputTopic.pipeInput("key3", "value3"); + inputTopic.pipeInput("key1", "value4"); + assertTrue(outputTopic1.isEmpty()); + + final KeyValueStore store = driver.getKeyValueStore(DEFAULT_STORE_NAME); + final List> results = prefixScanResults(store, DEFAULT_PREFIX); + + assertEquals("key1", results.get(0).key); + assertEquals("value4", results.get(0).value); + assertEquals("key2", results.get(1).key); + assertEquals("value2", results.get(1).value); + assertEquals("key3", results.get(2).key); + assertEquals("value3", results.get(2).value); + + } + + @Deprecated // testing old PAPI + @Test + public void testPrefixScanLruMapWithCachingWithLoggingOldProcessor() { + final StoreBuilder> storeBuilder = + Stores.keyValueStoreBuilder(Stores.lruMap(DEFAULT_STORE_NAME, 100), Serdes.String(), Serdes.String()) + 
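+                // this LRU-map variant enables both caching and changelogging on the store before running the prefix-scan assertions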
.withCachingEnabled() + .withLoggingEnabled(Collections.emptyMap()); + topology + .addSource("source1", STRING_DESERIALIZER, STRING_DESERIALIZER, INPUT_TOPIC_1) + .addProcessor("processor1", defineWithStoresOldAPI(() -> new OldAPIStatefulProcessor(DEFAULT_STORE_NAME), Collections.singleton(storeBuilder)), "source1") + .addSink("counts", OUTPUT_TOPIC_1, "processor1"); + + driver = new TopologyTestDriver(topology, props); + + final TestInputTopic inputTopic = driver.createInputTopic(INPUT_TOPIC_1, STRING_SERIALIZER, STRING_SERIALIZER); + final TestOutputTopic outputTopic1 = + driver.createOutputTopic(OUTPUT_TOPIC_1, Serdes.Integer().deserializer(), Serdes.String().deserializer()); + + inputTopic.pipeInput("key1", "value1"); + inputTopic.pipeInput("key2", "value2"); + inputTopic.pipeInput("key3", "value3"); + inputTopic.pipeInput("key1", "value4"); + assertTrue(outputTopic1.isEmpty()); + + final KeyValueStore store = driver.getKeyValueStore(DEFAULT_STORE_NAME); + final List> results = prefixScanResults(store, DEFAULT_PREFIX); + + assertEquals("key1", results.get(0).key); + assertEquals("value4", results.get(0).value); + assertEquals("key2", results.get(1).key); + assertEquals("value2", results.get(1).value); + assertEquals("key3", results.get(2).key); + assertEquals("value3", results.get(2).value); + + } + + @Deprecated // testing old PAPI + @Test + public void shouldDriveGlobalStore() { + final String storeName = "my-store"; + final String global = "global"; + final String topic = "topic"; + + topology.addGlobalStore( + Stores.keyValueStoreBuilder( + Stores.inMemoryKeyValueStore(storeName), + Serdes.String(), + Serdes.String() + ).withLoggingDisabled(), + global, + STRING_DESERIALIZER, + STRING_DESERIALIZER, + topic, + "processor", + define(new OldAPIStatefulProcessor(storeName))); + + driver = new TopologyTestDriver(topology, props); + final TestInputTopic inputTopic = driver.createInputTopic(topic, STRING_SERIALIZER, STRING_SERIALIZER); + final KeyValueStore globalStore = driver.getKeyValueStore(storeName); + inputTopic.pipeInput("key1", "value1"); + inputTopic.pipeInput("key2", "value2"); + assertEquals("value1", globalStore.get("key1")); + assertEquals("value2", globalStore.get("key2")); + } + + @Test + public void testDrivingSimpleMultiSourceTopology() { + final int partition = 10; + driver = new TopologyTestDriver(createSimpleMultiSourceTopology(partition), props); + final TestInputTopic inputTopic = driver.createInputTopic(INPUT_TOPIC_1, STRING_SERIALIZER, STRING_SERIALIZER, Instant.ofEpochMilli(0L), Duration.ZERO); + final TestOutputTopic outputTopic1 = + driver.createOutputTopic(OUTPUT_TOPIC_1, Serdes.String().deserializer(), Serdes.String().deserializer()); + final TestOutputTopic outputTopic2 = + driver.createOutputTopic(OUTPUT_TOPIC_2, Serdes.String().deserializer(), Serdes.String().deserializer()); + + inputTopic.pipeInput("key1", "value1"); + assertNextOutputRecord(outputTopic1.readRecord(), "key1", "value1"); + assertTrue(outputTopic2.isEmpty()); + + final TestInputTopic inputTopic2 = driver.createInputTopic(INPUT_TOPIC_2, STRING_SERIALIZER, STRING_SERIALIZER, Instant.ofEpochMilli(0L), Duration.ZERO); + inputTopic2.pipeInput("key2", "value2"); + assertNextOutputRecord(outputTopic2.readRecord(), "key2", "value2"); + assertTrue(outputTopic2.isEmpty()); + } + + @Test + public void testDrivingForwardToSourceTopology() { + driver = new TopologyTestDriver(createForwardToSourceTopology(), props); + final TestInputTopic inputTopic = driver.createInputTopic(INPUT_TOPIC_1, 
STRING_SERIALIZER, STRING_SERIALIZER, Instant.ofEpochMilli(0L), Duration.ZERO); + inputTopic.pipeInput("key1", "value1"); + inputTopic.pipeInput("key2", "value2"); + inputTopic.pipeInput("key3", "value3"); + final TestOutputTopic outputTopic2 = + driver.createOutputTopic(OUTPUT_TOPIC_2, Serdes.String().deserializer(), Serdes.String().deserializer()); + assertNextOutputRecord(outputTopic2.readRecord(), "key1", "value1"); + assertNextOutputRecord(outputTopic2.readRecord(), "key2", "value2"); + assertNextOutputRecord(outputTopic2.readRecord(), "key3", "value3"); + } + + @Test + public void testDrivingInternalRepartitioningTopology() { + driver = new TopologyTestDriver(createInternalRepartitioningTopology(), props); + final TestInputTopic inputTopic = driver.createInputTopic(INPUT_TOPIC_1, STRING_SERIALIZER, STRING_SERIALIZER, Instant.ofEpochMilli(0L), Duration.ZERO); + inputTopic.pipeInput("key1", "value1"); + inputTopic.pipeInput("key2", "value2"); + inputTopic.pipeInput("key3", "value3"); + final TestOutputTopic outputTopic1 = driver.createOutputTopic(OUTPUT_TOPIC_1, STRING_DESERIALIZER, STRING_DESERIALIZER); + assertNextOutputRecord(outputTopic1.readRecord(), "key1", "value1"); + assertNextOutputRecord(outputTopic1.readRecord(), "key2", "value2"); + assertNextOutputRecord(outputTopic1.readRecord(), "key3", "value3"); + } + + @Test + public void testDrivingInternalRepartitioningForwardingTimestampTopology() { + driver = new TopologyTestDriver(createInternalRepartitioningWithValueTimestampTopology(), props); + final TestInputTopic inputTopic = driver.createInputTopic(INPUT_TOPIC_1, STRING_SERIALIZER, STRING_SERIALIZER); + inputTopic.pipeInput("key1", "value1@1000"); + inputTopic.pipeInput("key2", "value2@2000"); + inputTopic.pipeInput("key3", "value3@3000"); + final TestOutputTopic outputTopic = driver.createOutputTopic(OUTPUT_TOPIC_1, STRING_DESERIALIZER, STRING_DESERIALIZER); + assertThat(outputTopic.readRecord(), + equalTo(new TestRecord<>("key1", "value1", null, 1000L))); + assertThat(outputTopic.readRecord(), + equalTo(new TestRecord<>("key2", "value2", null, 2000L))); + assertThat(outputTopic.readRecord(), + equalTo(new TestRecord<>("key3", "value3", null, 3000L))); + } + + @Test + public void shouldCreateStringWithSourceAndTopics() { + topology.addSource("source", "topic1", "topic2"); + final ProcessorTopology processorTopology = topology.getInternalBuilder().buildTopology(); + final String result = processorTopology.toString(); + assertThat(result, containsString("source:\n\t\ttopics:\t\t[topic1, topic2]\n")); + } + + @Test + public void shouldCreateStringWithMultipleSourcesAndTopics() { + topology.addSource("source", "topic1", "topic2"); + topology.addSource("source2", "t", "t1", "t2"); + final ProcessorTopology processorTopology = topology.getInternalBuilder().buildTopology(); + final String result = processorTopology.toString(); + assertThat(result, containsString("source:\n\t\ttopics:\t\t[topic1, topic2]\n")); + assertThat(result, containsString("source2:\n\t\ttopics:\t\t[t, t1, t2]\n")); + } + + @Test + public void shouldCreateStringWithProcessors() { + topology.addSource("source", "t") + .addProcessor("processor", mockProcessorSupplier, "source") + .addProcessor("other", mockProcessorSupplier, "source"); + final ProcessorTopology processorTopology = topology.getInternalBuilder().buildTopology(); + final String result = processorTopology.toString(); + assertThat(result, containsString("\t\tchildren:\t[processor, other]")); + assertThat(result, containsString("processor:\n")); + 
assertThat(result, containsString("other:\n")); + } + + @Test + public void shouldRecursivelyPrintChildren() { + topology.addSource("source", "t") + .addProcessor("processor", mockProcessorSupplier, "source") + .addProcessor("child-one", mockProcessorSupplier, "processor") + .addProcessor("child-one-one", mockProcessorSupplier, "child-one") + .addProcessor("child-two", mockProcessorSupplier, "processor") + .addProcessor("child-two-one", mockProcessorSupplier, "child-two"); + + final String result = topology.getInternalBuilder().buildTopology().toString(); + assertThat(result, containsString("child-one:\n\t\tchildren:\t[child-one-one]")); + assertThat(result, containsString("child-two:\n\t\tchildren:\t[child-two-one]")); + } + + @Test + public void shouldConsiderTimeStamps() { + final int partition = 10; + driver = new TopologyTestDriver(createSimpleTopology(partition), props); + final TestInputTopic inputTopic = driver.createInputTopic(INPUT_TOPIC_1, STRING_SERIALIZER, STRING_SERIALIZER); + inputTopic.pipeInput("key1", "value1", 10L); + inputTopic.pipeInput("key2", "value2", 20L); + inputTopic.pipeInput("key3", "value3", 30L); + final TestOutputTopic outputTopic1 = + driver.createOutputTopic(OUTPUT_TOPIC_1, Serdes.String().deserializer(), Serdes.String().deserializer()); + assertNextOutputRecord(outputTopic1.readRecord(), "key1", "value1", 10L); + assertNextOutputRecord(outputTopic1.readRecord(), "key2", "value2", 20L); + assertNextOutputRecord(outputTopic1.readRecord(), "key3", "value3", 30L); + } + + @Test + public void shouldConsiderModifiedTimeStamps() { + final int partition = 10; + driver = new TopologyTestDriver(createTimestampTopology(partition), props); + final TestInputTopic inputTopic = driver.createInputTopic(INPUT_TOPIC_1, STRING_SERIALIZER, STRING_SERIALIZER); + inputTopic.pipeInput("key1", "value1", 10L); + inputTopic.pipeInput("key2", "value2", 20L); + inputTopic.pipeInput("key3", "value3", 30L); + final TestOutputTopic outputTopic1 = + driver.createOutputTopic(OUTPUT_TOPIC_1, Serdes.String().deserializer(), Serdes.String().deserializer()); + assertNextOutputRecord(outputTopic1.readRecord(), "key1", "value1", 20L); + assertNextOutputRecord(outputTopic1.readRecord(), "key2", "value2", 30L); + assertNextOutputRecord(outputTopic1.readRecord(), "key3", "value3", 40L); + } + + @Test + public void shouldConsiderModifiedTimeStampsForMultipleProcessors() { + final int partition = 10; + driver = new TopologyTestDriver(createMultiProcessorTimestampTopology(partition), props); + final TestInputTopic inputTopic = driver.createInputTopic(INPUT_TOPIC_1, STRING_SERIALIZER, STRING_SERIALIZER); + final TestOutputTopic outputTopic1 = + driver.createOutputTopic(OUTPUT_TOPIC_1, Serdes.String().deserializer(), Serdes.String().deserializer()); + final TestOutputTopic outputTopic2 = + driver.createOutputTopic(OUTPUT_TOPIC_2, Serdes.String().deserializer(), Serdes.String().deserializer()); + + inputTopic.pipeInput("key1", "value1", 10L); + assertNextOutputRecord(outputTopic1.readRecord(), "key1", "value1", 10L); + assertNextOutputRecord(outputTopic2.readRecord(), "key1", "value1", 20L); + assertNextOutputRecord(outputTopic1.readRecord(), "key1", "value1", 15L); + assertNextOutputRecord(outputTopic2.readRecord(), "key1", "value1", 20L); + assertNextOutputRecord(outputTopic1.readRecord(), "key1", "value1", 12L); + assertNextOutputRecord(outputTopic2.readRecord(), "key1", "value1", 22L); + assertTrue(outputTopic1.isEmpty()); + assertTrue(outputTopic2.isEmpty()); + + inputTopic.pipeInput("key2", 
"value2", 20L); + assertNextOutputRecord(outputTopic1.readRecord(), "key2", "value2", 20L); + assertNextOutputRecord(outputTopic2.readRecord(), "key2", "value2", 30L); + assertNextOutputRecord(outputTopic1.readRecord(), "key2", "value2", 25L); + assertNextOutputRecord(outputTopic2.readRecord(), "key2", "value2", 30L); + assertNextOutputRecord(outputTopic1.readRecord(), "key2", "value2", 22L); + assertNextOutputRecord(outputTopic2.readRecord(), "key2", "value2", 32L); + assertTrue(outputTopic1.isEmpty()); + assertTrue(outputTopic2.isEmpty()); + } + + @Test + public void shouldConsiderHeaders() { + final int partition = 10; + driver = new TopologyTestDriver(createSimpleTopology(partition), props); + final TestInputTopic inputTopic = driver.createInputTopic(INPUT_TOPIC_1, STRING_SERIALIZER, STRING_SERIALIZER); + inputTopic.pipeInput(new TestRecord<>("key1", "value1", HEADERS, 10L)); + inputTopic.pipeInput(new TestRecord<>("key2", "value2", HEADERS, 20L)); + inputTopic.pipeInput(new TestRecord<>("key3", "value3", HEADERS, 30L)); + final TestOutputTopic outputTopic1 = driver.createOutputTopic(OUTPUT_TOPIC_1, STRING_DESERIALIZER, STRING_DESERIALIZER); + assertNextOutputRecord(outputTopic1.readRecord(), "key1", "value1", HEADERS, 10L); + assertNextOutputRecord(outputTopic1.readRecord(), "key2", "value2", HEADERS, 20L); + assertNextOutputRecord(outputTopic1.readRecord(), "key3", "value3", HEADERS, 30L); + } + + @Test + public void shouldAddHeaders() { + driver = new TopologyTestDriver(createAddHeaderTopology(), props); + final TestInputTopic inputTopic = driver.createInputTopic(INPUT_TOPIC_1, STRING_SERIALIZER, STRING_SERIALIZER); + inputTopic.pipeInput("key1", "value1", 10L); + inputTopic.pipeInput("key2", "value2", 20L); + inputTopic.pipeInput("key3", "value3", 30L); + final TestOutputTopic outputTopic1 = driver.createOutputTopic(OUTPUT_TOPIC_1, STRING_DESERIALIZER, STRING_DESERIALIZER); + assertNextOutputRecord(outputTopic1.readRecord(), "key1", "value1", HEADERS, 10L); + assertNextOutputRecord(outputTopic1.readRecord(), "key2", "value2", HEADERS, 20L); + assertNextOutputRecord(outputTopic1.readRecord(), "key3", "value3", HEADERS, 30L); + } + + @Test + public void statelessTopologyShouldNotHavePersistentStore() { + final TopologyWrapper topology = new TopologyWrapper(); + final ProcessorTopology processorTopology = topology.getInternalBuilder("anyAppId").buildTopology(); + assertFalse(processorTopology.hasPersistentLocalStore()); + assertFalse(processorTopology.hasPersistentGlobalStore()); + } + + @Test + public void inMemoryStoreShouldNotResultInPersistentLocalStore() { + final ProcessorTopology processorTopology = createLocalStoreTopology(Stores.inMemoryKeyValueStore("my-store")); + assertFalse(processorTopology.hasPersistentLocalStore()); + } + + @Test + public void persistentLocalStoreShouldBeDetected() { + final ProcessorTopology processorTopology = createLocalStoreTopology(Stores.persistentKeyValueStore("my-store")); + assertTrue(processorTopology.hasPersistentLocalStore()); + } + + @Test + public void inMemoryStoreShouldNotResultInPersistentGlobalStore() { + final ProcessorTopology processorTopology = createGlobalStoreTopology(Stores.inMemoryKeyValueStore("my-store")); + assertFalse(processorTopology.hasPersistentGlobalStore()); + } + + @Test + public void persistentGlobalStoreShouldBeDetected() { + final ProcessorTopology processorTopology = createGlobalStoreTopology(Stores.persistentKeyValueStore("my-store")); + assertTrue(processorTopology.hasPersistentGlobalStore()); + } + + private 
ProcessorTopology createLocalStoreTopology(final KeyValueBytesStoreSupplier storeSupplier) {
+        final TopologyWrapper topology = new TopologyWrapper();
+        final String processor = "processor";
+        final StoreBuilder<KeyValueStore<String, String>> storeBuilder =
+            Stores.keyValueStoreBuilder(storeSupplier, Serdes.String(), Serdes.String());
+        topology.addSource("source", STRING_DESERIALIZER, STRING_DESERIALIZER, "topic")
+            .addProcessor(processor, () -> new StatefulProcessor(storeSupplier.name()), "source")
+            .addStateStore(storeBuilder, processor);
+        return topology.getInternalBuilder("anyAppId").buildTopology();
+    }
+
+    @Deprecated // testing old PAPI
+    private ProcessorTopology createGlobalStoreTopology(final KeyValueBytesStoreSupplier storeSupplier) {
+        final TopologyWrapper topology = new TopologyWrapper();
+        final StoreBuilder<KeyValueStore<String, String>> storeBuilder =
+            Stores.keyValueStoreBuilder(storeSupplier, Serdes.String(), Serdes.String()).withLoggingDisabled();
+        topology.addGlobalStore(storeBuilder, "global", STRING_DESERIALIZER, STRING_DESERIALIZER, "topic", "processor",
+            define(new OldAPIStatefulProcessor(storeSupplier.name())));
+        return topology.getInternalBuilder("anyAppId").buildTopology();
+    }
+
+    private void assertNextOutputRecord(final TestRecord<String, String> record,
+                                        final String key,
+                                        final String value) {
+        assertNextOutputRecord(record, key, value, 0L);
+    }
+
+    private void assertNextOutputRecord(final TestRecord<String, String> record,
+                                        final String key,
+                                        final String value,
+                                        final Long timestamp) {
+        assertNextOutputRecord(record, key, value, new RecordHeaders(), timestamp);
+    }
+
+    private void assertNextOutputRecord(final TestRecord<String, String> record,
+                                        final String key,
+                                        final String value,
+                                        final Headers headers,
+                                        final Long timestamp) {
+        assertEquals(key, record.key());
+        assertEquals(value, record.value());
+        assertEquals(timestamp, record.timestamp());
+        assertEquals(headers, record.headers());
+    }
+
+    private StreamPartitioner<Object, Object> constantPartitioner(final Integer partition) {
+        return (topic, key, value, numPartitions) -> partition;
+    }
+
+    private Topology createSimpleTopology(final int partition) {
+        return topology
+            .addSource("source", STRING_DESERIALIZER, STRING_DESERIALIZER, INPUT_TOPIC_1)
+            .addProcessor("processor", ForwardingProcessor::new, "source")
+            .addSink("sink", OUTPUT_TOPIC_1, constantPartitioner(partition), "processor");
+    }
+
+    private Topology createTimestampTopology(final int partition) {
+        return topology
+            .addSource("source", STRING_DESERIALIZER, STRING_DESERIALIZER, INPUT_TOPIC_1)
+            .addProcessor("processor", TimestampProcessor::new, "source")
+            .addSink("sink", OUTPUT_TOPIC_1, constantPartitioner(partition), "processor");
+    }
+
+    private Topology createMultiProcessorTimestampTopology(final int partition) {
+        return topology
+            .addSource("source", STRING_DESERIALIZER, STRING_DESERIALIZER, INPUT_TOPIC_1)
+            .addProcessor("processor", () -> new FanOutTimestampProcessor("child1", "child2"), "source")
+            .addProcessor("child1", ForwardingProcessor::new, "processor")
+            .addProcessor("child2", TimestampProcessor::new, "processor")
+            .addSink("sink1", OUTPUT_TOPIC_1, constantPartitioner(partition), "child1")
+            .addSink("sink2", OUTPUT_TOPIC_2, constantPartitioner(partition), "child2");
+    }
+
+    @Deprecated // testing old PAPI
+    private Topology createStatefulTopology(final String storeName) {
+        return topology
+            .addSource("source", STRING_DESERIALIZER, STRING_DESERIALIZER, INPUT_TOPIC_1)
+            .addProcessor("processor", define(new OldAPIStatefulProcessor(storeName)), "source")
+            .addStateStore(Stores.keyValueStoreBuilder(Stores.inMemoryKeyValueStore(storeName), Serdes.String(), Serdes.String()), "processor")
+            .addSink("counts", OUTPUT_TOPIC_1, "processor");
+    }
+
+    @Deprecated // testing old PAPI
+    private Topology createConnectedStateStoreTopology(final String storeName) {
+        final StoreBuilder<KeyValueStore<String, String>> storeBuilder = Stores.keyValueStoreBuilder(Stores.inMemoryKeyValueStore(storeName), Serdes.String(), Serdes.String());
+        return topology
+            .addSource("source", STRING_DESERIALIZER, STRING_DESERIALIZER, INPUT_TOPIC_1)
+            .addProcessor("processor", defineWithStoresOldAPI(() -> new OldAPIStatefulProcessor(storeName), Collections.singleton(storeBuilder)), "source")
+            .addSink("counts", OUTPUT_TOPIC_1, "processor");
+    }
+
+    private Topology createInternalRepartitioningTopology() {
+        topology.addSource("source", INPUT_TOPIC_1)
+            .addSink("sink0", THROUGH_TOPIC_1, "source")
+            .addSource("source1", THROUGH_TOPIC_1)
+            .addSink("sink1", OUTPUT_TOPIC_1, "source1");
+
+        // use wrapper to get the internal topology builder to add internal topic
+        final InternalTopologyBuilder internalTopologyBuilder = TopologyWrapper.getInternalTopologyBuilder(topology);
+        internalTopologyBuilder.addInternalTopic(THROUGH_TOPIC_1, InternalTopicProperties.empty());
+
+        return topology;
+    }
+
+    private Topology createInternalRepartitioningWithValueTimestampTopology() {
+        topology.addSource("source", INPUT_TOPIC_1)
+            .addProcessor("processor", ValueTimestampProcessor::new, "source")
+            .addSink("sink0", THROUGH_TOPIC_1, "processor")
+            .addSource("source1", THROUGH_TOPIC_1)
+            .addSink("sink1", OUTPUT_TOPIC_1, "source1");
+
+        // use wrapper to get the internal topology builder to add internal topic
+        final InternalTopologyBuilder internalTopologyBuilder = TopologyWrapper.getInternalTopologyBuilder(topology);
+        internalTopologyBuilder.addInternalTopic(THROUGH_TOPIC_1, InternalTopicProperties.empty());
+
+        return topology;
+    }
+
+    private Topology createForwardToSourceTopology() {
+        return topology.addSource("source-1", INPUT_TOPIC_1)
+            .addSink("sink-1", OUTPUT_TOPIC_1, "source-1")
+            .addSource("source-2", OUTPUT_TOPIC_1)
+            .addSink("sink-2", OUTPUT_TOPIC_2, "source-2");
+    }
+
+    private Topology createSimpleMultiSourceTopology(final int partition) {
+        return topology.addSource("source-1", STRING_DESERIALIZER, STRING_DESERIALIZER, INPUT_TOPIC_1)
+            .addProcessor("processor-1", ForwardingProcessor::new, "source-1")
+            .addSink("sink-1", OUTPUT_TOPIC_1, constantPartitioner(partition), "processor-1")
+            .addSource("source-2", STRING_DESERIALIZER, STRING_DESERIALIZER, INPUT_TOPIC_2)
+            .addProcessor("processor-2", ForwardingProcessor::new, "source-2")
+            .addSink("sink-2", OUTPUT_TOPIC_2, constantPartitioner(partition), "processor-2");
+    }
+
+    private Topology createAddHeaderTopology() {
+        return topology.addSource("source-1", STRING_DESERIALIZER, STRING_DESERIALIZER, INPUT_TOPIC_1)
+            .addProcessor("processor-1", AddHeaderProcessor::new, "source-1")
+            .addSink("sink-1", OUTPUT_TOPIC_1, "processor-1");
+    }
+
+    /**
+     * A processor that simply forwards all messages to all children.
+     */
+    protected static class ForwardingProcessor implements Processor<String, String, String, String> {
+        private ProcessorContext<String, String> context;
+
+        @Override
+        public void init(final ProcessorContext<String, String> context) {
+            this.context = context;
+        }
+
+        @Override
+        public void process(final Record<String, String> record) {
+            context.forward(record);
+        }
+    }
+
+    /**
+     * A processor that simply forwards all messages to all children with advanced timestamps.
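+     * Each forwarded record carries the input record's timestamp plus 10.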
+ */ + protected static class TimestampProcessor implements Processor { + private ProcessorContext context; + + @Override + public void init(final ProcessorContext context) { + this.context = context; + } + + @Override + public void process(final Record record) { + context.forward(record.withTimestamp(record.timestamp() + 10)); + } + } + + protected static class FanOutTimestampProcessor implements Processor { + private final String firstChild; + private final String secondChild; + private ProcessorContext context; + + FanOutTimestampProcessor(final String firstChild, + final String secondChild) { + this.firstChild = firstChild; + this.secondChild = secondChild; + } + + @Override + public void init(final ProcessorContext context) { + this.context = context; + } + + @Override + public void process(final Record record) { + context.forward(record); + context.forward(record.withTimestamp(record.timestamp() + 5), firstChild); + context.forward(record, secondChild); + context.forward(record.withTimestamp(record.timestamp() + 2)); + } + } + + protected static class AddHeaderProcessor implements Processor { + private ProcessorContext context; + + @Override + public void init(final ProcessorContext context) { + this.context = context; + } + + @Override + public void process(final Record record) { + // making a copy of headers for safety. + final Record toForward = record.withHeaders(record.headers()); + toForward.headers().add(HEADER); + context.forward(toForward); + } + } + + /** + * A processor that removes custom timestamp information from messages and forwards modified messages to each child. + * A message contains custom timestamp information if the value is in ".*@[0-9]+" format. + */ + protected static class ValueTimestampProcessor implements Processor { + private ProcessorContext context; + + @Override + public void init(final ProcessorContext context) { + this.context = context; + } + + @Override + public void process(final Record record) { + context.forward(record.withValue(record.value().split("@")[0])); + } + } + + /** + * A processor that stores each key-value pair in an in-memory key-value store registered with the context. + */ + @SuppressWarnings("deprecation") // Old PAPI. Needs to be migrated. + protected static class OldAPIStatefulProcessor extends org.apache.kafka.streams.processor.AbstractProcessor { + private KeyValueStore store; + private final String storeName; + + OldAPIStatefulProcessor(final String storeName) { + this.storeName = storeName; + } + + @Override + public void init(final org.apache.kafka.streams.processor.ProcessorContext context) { + super.init(context); + store = context.getStateStore(storeName); + } + + @Override + public void process(final String key, final String value) { + store.put(key, value); + } + } + + /** + * A processor that stores each key-value pair in an in-memory key-value store registered with the context. + */ + protected static class StatefulProcessor implements Processor { + private KeyValueStore store; + private final String storeName; + + StatefulProcessor(final String storeName) { + this.storeName = storeName; + } + + @Override + public void init(final ProcessorContext context) { + store = context.getStateStore(storeName); + } + + @Override + public void process(final Record record) { + store.put(record.key(), record.value()); + } + } + + @SuppressWarnings("deprecation") // Old PAPI. Needs to be migrated. 
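+    // Adapts a single old-PAPI Processor instance into a ProcessorSupplier that always hands back that same instance.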
+ private org.apache.kafka.streams.processor.ProcessorSupplier define(final org.apache.kafka.streams.processor.Processor processor) { + return () -> processor; + } + + @SuppressWarnings("deprecation") // Old PAPI. Needs to be migrated. + private org.apache.kafka.streams.processor.ProcessorSupplier defineWithStoresOldAPI(final Supplier> supplier, + final Set> stores) { + return new org.apache.kafka.streams.processor.ProcessorSupplier() { + @Override + public org.apache.kafka.streams.processor.Processor get() { + return supplier.get(); + } + + @Override + public Set> stores() { + return stores; + } + }; + } + + private ProcessorSupplier defineWithStores(final Supplier> supplier, + final Set> stores) { + return new ProcessorSupplier() { + @Override + public Processor get() { + return supplier.get(); + } + + @Override + public Set> stores() { + return stores; + } + }; + } + + /** + * A custom timestamp extractor that extracts the timestamp from the record's value if the value is in ".*@[0-9]+" + * format. Otherwise, it returns the record's timestamp or the default timestamp if the record's timestamp is negative. + */ + public static class CustomTimestampExtractor implements TimestampExtractor { + private static final long DEFAULT_TIMESTAMP = 1000L; + + @Override + public long extract(final ConsumerRecord record, final long partitionTime) { + if (record.value().toString().matches(".*@[0-9]+")) { + return Long.parseLong(record.value().toString().split("@")[1]); + } + + if (record.timestamp() >= 0L) { + return record.timestamp(); + } + + return DEFAULT_TIMESTAMP; + } + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/PunctuationQueueTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/PunctuationQueueTest.java new file mode 100644 index 0000000..cde573a --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/PunctuationQueueTest.java @@ -0,0 +1,156 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.processor.internals; + +import org.apache.kafka.streams.processor.Cancellable; +import org.apache.kafka.streams.processor.ProcessorContext; +import org.apache.kafka.streams.processor.PunctuationType; +import org.apache.kafka.streams.processor.Punctuator; +import org.apache.kafka.test.MockProcessorNode; +import org.junit.Test; + +import static org.junit.Assert.assertEquals; + +public class PunctuationQueueTest { + + private final MockProcessorNode node = new MockProcessorNode<>(); + private final PunctuationQueue queue = new PunctuationQueue(); + private final Punctuator punctuator = new Punctuator() { + @Override + public void punctuate(final long timestamp) { + node.mockProcessor.punctuatedStreamTime().add(timestamp); + } + }; + + @Test + public void testPunctuationInterval() { + final PunctuationSchedule sched = new PunctuationSchedule(node, 0L, 100L, punctuator); + final long now = sched.timestamp - 100L; + + queue.schedule(sched); + + final ProcessorNodePunctuator processorNodePunctuator = new ProcessorNodePunctuator() { + @Override + public void punctuate(final ProcessorNode node, final long timestamp, final PunctuationType type, final Punctuator punctuator) { + punctuator.punctuate(timestamp); + } + }; + + queue.mayPunctuate(now, PunctuationType.STREAM_TIME, processorNodePunctuator); + assertEquals(0, node.mockProcessor.punctuatedStreamTime().size()); + + queue.mayPunctuate(now + 99L, PunctuationType.STREAM_TIME, processorNodePunctuator); + assertEquals(0, node.mockProcessor.punctuatedStreamTime().size()); + + queue.mayPunctuate(now + 100L, PunctuationType.STREAM_TIME, processorNodePunctuator); + assertEquals(1, node.mockProcessor.punctuatedStreamTime().size()); + + queue.mayPunctuate(now + 199L, PunctuationType.STREAM_TIME, processorNodePunctuator); + assertEquals(1, node.mockProcessor.punctuatedStreamTime().size()); + + queue.mayPunctuate(now + 200L, PunctuationType.STREAM_TIME, processorNodePunctuator); + assertEquals(2, node.mockProcessor.punctuatedStreamTime().size()); + + queue.mayPunctuate(now + 1001L, PunctuationType.STREAM_TIME, processorNodePunctuator); + assertEquals(3, node.mockProcessor.punctuatedStreamTime().size()); + + queue.mayPunctuate(now + 1002L, PunctuationType.STREAM_TIME, processorNodePunctuator); + assertEquals(3, node.mockProcessor.punctuatedStreamTime().size()); + + queue.mayPunctuate(now + 1100L, PunctuationType.STREAM_TIME, processorNodePunctuator); + assertEquals(4, node.mockProcessor.punctuatedStreamTime().size()); + } + + @Test + public void testPunctuationIntervalCustomAlignment() { + final PunctuationSchedule sched = new PunctuationSchedule(node, 50L, 100L, punctuator); + final long now = sched.timestamp - 50L; + + queue.schedule(sched); + + final ProcessorNodePunctuator processorNodePunctuator = new ProcessorNodePunctuator() { + @Override + public void punctuate(final ProcessorNode node, final long timestamp, final PunctuationType type, final Punctuator punctuator) { + punctuator.punctuate(timestamp); + } + }; + + queue.mayPunctuate(now, PunctuationType.STREAM_TIME, processorNodePunctuator); + assertEquals(0, node.mockProcessor.punctuatedStreamTime().size()); + + queue.mayPunctuate(now + 49L, PunctuationType.STREAM_TIME, processorNodePunctuator); + assertEquals(0, node.mockProcessor.punctuatedStreamTime().size()); + + queue.mayPunctuate(now + 50L, PunctuationType.STREAM_TIME, processorNodePunctuator); + assertEquals(1, node.mockProcessor.punctuatedStreamTime().size()); + + queue.mayPunctuate(now + 149L, 
PunctuationType.STREAM_TIME, processorNodePunctuator); + assertEquals(1, node.mockProcessor.punctuatedStreamTime().size()); + + queue.mayPunctuate(now + 150L, PunctuationType.STREAM_TIME, processorNodePunctuator); + assertEquals(2, node.mockProcessor.punctuatedStreamTime().size()); + + queue.mayPunctuate(now + 1051L, PunctuationType.STREAM_TIME, processorNodePunctuator); + assertEquals(3, node.mockProcessor.punctuatedStreamTime().size()); + + queue.mayPunctuate(now + 1052L, PunctuationType.STREAM_TIME, processorNodePunctuator); + assertEquals(3, node.mockProcessor.punctuatedStreamTime().size()); + + queue.mayPunctuate(now + 1150L, PunctuationType.STREAM_TIME, processorNodePunctuator); + assertEquals(4, node.mockProcessor.punctuatedStreamTime().size()); + } + + @Test + public void testPunctuationIntervalCancelFromPunctuator() { + final PunctuationSchedule sched = new PunctuationSchedule(node, 0L, 100L, punctuator); + final long now = sched.timestamp - 100L; + + final Cancellable cancellable = queue.schedule(sched); + + final ProcessorNodePunctuator processorNodePunctuator = new ProcessorNodePunctuator() { + @Override + public void punctuate(final ProcessorNode node, final long timestamp, final PunctuationType type, final Punctuator punctuator) { + punctuator.punctuate(timestamp); + // simulate scheduler cancelled from within punctuator + cancellable.cancel(); + } + }; + + queue.mayPunctuate(now, PunctuationType.STREAM_TIME, processorNodePunctuator); + assertEquals(0, node.mockProcessor.punctuatedStreamTime().size()); + + queue.mayPunctuate(now + 100L, PunctuationType.STREAM_TIME, processorNodePunctuator); + assertEquals(1, node.mockProcessor.punctuatedStreamTime().size()); + + queue.mayPunctuate(now + 200L, PunctuationType.STREAM_TIME, processorNodePunctuator); + assertEquals(1, node.mockProcessor.punctuatedStreamTime().size()); + } + + @SuppressWarnings("deprecation") // Old PAPI. Needs to be migrated. + private static class TestProcessor extends org.apache.kafka.streams.processor.AbstractProcessor { + + @Override + public void init(final ProcessorContext context) {} + + @Override + public void process(final String key, final String value) {} + + @Override + public void close() {} + } + +} diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/QuickUnionTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/QuickUnionTest.java new file mode 100644 index 0000000..f5c15ea --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/QuickUnionTest.java @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+package org.apache.kafka.streams.processor.internals;
+
+import org.junit.Test;
+
+import java.util.HashSet;
+import java.util.Set;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotEquals;
+
+public class QuickUnionTest {
+
+    @Test
+    public void testUnite() {
+        final QuickUnion<Long> qu = new QuickUnion<>();
+
+        final long[] ids = {
+            1L, 2L, 3L, 4L, 5L
+        };
+
+        for (final long id : ids) {
+            qu.add(id);
+        }
+
+        assertEquals(5, roots(qu, ids).size());
+
+        qu.unite(1L, 2L);
+        assertEquals(4, roots(qu, ids).size());
+        assertEquals(qu.root(1L), qu.root(2L));
+
+        qu.unite(3L, 4L);
+        assertEquals(3, roots(qu, ids).size());
+        assertEquals(qu.root(1L), qu.root(2L));
+        assertEquals(qu.root(3L), qu.root(4L));
+
+        qu.unite(1L, 5L);
+        assertEquals(2, roots(qu, ids).size());
+        assertEquals(qu.root(1L), qu.root(2L));
+        assertEquals(qu.root(2L), qu.root(5L));
+        assertEquals(qu.root(3L), qu.root(4L));
+
+        qu.unite(3L, 5L);
+        assertEquals(1, roots(qu, ids).size());
+        assertEquals(qu.root(1L), qu.root(2L));
+        assertEquals(qu.root(2L), qu.root(3L));
+        assertEquals(qu.root(3L), qu.root(4L));
+        assertEquals(qu.root(4L), qu.root(5L));
+    }
+
+    @Test
+    public void testUniteMany() {
+        final QuickUnion<Long> qu = new QuickUnion<>();
+
+        final long[] ids = {
+            1L, 2L, 3L, 4L, 5L
+        };
+
+        for (final long id : ids) {
+            qu.add(id);
+        }
+
+        assertEquals(5, roots(qu, ids).size());
+
+        qu.unite(1L, 2L, 3L, 4L);
+        assertEquals(2, roots(qu, ids).size());
+        assertEquals(qu.root(1L), qu.root(2L));
+        assertEquals(qu.root(2L), qu.root(3L));
+        assertEquals(qu.root(3L), qu.root(4L));
+        assertNotEquals(qu.root(1L), qu.root(5L));
+    }
+
+    private Set<Long> roots(final QuickUnion<Long> qu, final long... ids) {
+        final HashSet<Long> roots = new HashSet<>();
+        for (final long id : ids) {
+            roots.add(qu.root(id));
+        }
+        return roots;
+    }
+}
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/RecordCollectorTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/RecordCollectorTest.java
new file mode 100644
index 0000000..48364f2
--- /dev/null
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/RecordCollectorTest.java
@@ -0,0 +1,949 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +package org.apache.kafka.streams.processor.internals; + +import org.apache.kafka.clients.producer.Callback; +import org.apache.kafka.clients.producer.MockProducer; +import org.apache.kafka.clients.producer.Producer; +import org.apache.kafka.clients.producer.ProducerRecord; +import org.apache.kafka.clients.producer.RecordMetadata; +import org.apache.kafka.clients.producer.internals.DefaultPartitioner; +import org.apache.kafka.common.Cluster; +import org.apache.kafka.common.KafkaException; +import org.apache.kafka.common.Metric; +import org.apache.kafka.common.MetricName; +import org.apache.kafka.common.Node; +import org.apache.kafka.common.PartitionInfo; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.errors.AuthenticationException; +import org.apache.kafka.common.errors.InvalidProducerEpochException; +import org.apache.kafka.common.errors.ProducerFencedException; +import org.apache.kafka.common.errors.TimeoutException; +import org.apache.kafka.common.header.Header; +import org.apache.kafka.common.header.Headers; +import org.apache.kafka.common.header.internals.RecordHeader; +import org.apache.kafka.common.header.internals.RecordHeaders; +import org.apache.kafka.common.metrics.Metrics; +import org.apache.kafka.common.serialization.ByteArraySerializer; +import org.apache.kafka.common.serialization.LongSerializer; +import org.apache.kafka.common.serialization.Serializer; +import org.apache.kafka.common.serialization.StringSerializer; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.errors.AlwaysContinueProductionExceptionHandler; +import org.apache.kafka.streams.errors.DefaultProductionExceptionHandler; +import org.apache.kafka.streams.errors.ProductionExceptionHandler; +import org.apache.kafka.streams.errors.StreamsException; +import org.apache.kafka.streams.errors.TaskMigratedException; +import org.apache.kafka.streams.processor.StreamPartitioner; +import org.apache.kafka.streams.processor.TaskId; +import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl; +import org.apache.kafka.streams.processor.internals.testutil.LogCaptureAppender; +import org.apache.kafka.test.MockClientSupplier; + +import java.util.UUID; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.concurrent.Future; +import java.util.concurrent.atomic.AtomicBoolean; + +import static org.apache.kafka.common.utils.Utils.mkEntry; +import static org.apache.kafka.common.utils.Utils.mkMap; +import static org.easymock.EasyMock.expect; +import static org.easymock.EasyMock.expectLastCall; +import static org.easymock.EasyMock.mock; +import static org.easymock.EasyMock.replay; +import static org.easymock.EasyMock.verify; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.core.IsEqual.equalTo; +import static org.hamcrest.core.IsInstanceOf.instanceOf; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; + +public class RecordCollectorTest { + + private final LogContext logContext = new LogContext("test "); + private final TaskId taskId = new TaskId(0, 0); + private final ProductionExceptionHandler productionExceptionHandler = new 
DefaultProductionExceptionHandler(); + private final StreamsMetricsImpl streamsMetrics = new MockStreamsMetrics(new Metrics()); + private final StreamsConfig config = new StreamsConfig(mkMap( + mkEntry(StreamsConfig.APPLICATION_ID_CONFIG, "appId"), + mkEntry(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, "dummy:1234") + )); + private final StreamsConfig eosConfig = new StreamsConfig(mkMap( + mkEntry(StreamsConfig.APPLICATION_ID_CONFIG, "appId"), + mkEntry(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, "dummy:1234"), + mkEntry(StreamsConfig.PROCESSING_GUARANTEE_CONFIG, StreamsConfig.EXACTLY_ONCE_V2) + )); + + private final String topic = "topic"; + private final Cluster cluster = new Cluster( + "cluster", + Collections.singletonList(Node.noNode()), + Arrays.asList( + new PartitionInfo(topic, 0, Node.noNode(), new Node[0], new Node[0]), + new PartitionInfo(topic, 1, Node.noNode(), new Node[0], new Node[0]), + new PartitionInfo(topic, 2, Node.noNode(), new Node[0], new Node[0]) + ), + Collections.emptySet(), + Collections.emptySet() + ); + + private final StringSerializer stringSerializer = new StringSerializer(); + private final ByteArraySerializer byteArraySerializer = new ByteArraySerializer(); + private final UUID processId = UUID.randomUUID(); + + private final StreamPartitioner streamPartitioner = + (topic, key, value, numPartitions) -> Integer.parseInt(key) % numPartitions; + + private MockProducer mockProducer; + private StreamsProducer streamsProducer; + + private RecordCollectorImpl collector; + + @Before + public void setup() { + final MockClientSupplier clientSupplier = new MockClientSupplier(); + clientSupplier.setCluster(cluster); + streamsProducer = new StreamsProducer( + config, + processId + "-StreamThread-1", + clientSupplier, + null, + processId, + logContext, + Time.SYSTEM + ); + mockProducer = clientSupplier.producers.get(0); + collector = new RecordCollectorImpl( + logContext, + taskId, + streamsProducer, + productionExceptionHandler, + streamsMetrics); + } + + @After + public void cleanup() { + collector.closeClean(); + } + + @Test + public void shouldSendToSpecificPartition() { + final Headers headers = new RecordHeaders(new Header[] {new RecordHeader("key", "value".getBytes())}); + + collector.send(topic, "999", "0", null, 0, null, stringSerializer, stringSerializer); + collector.send(topic, "999", "0", null, 0, null, stringSerializer, stringSerializer); + collector.send(topic, "999", "0", null, 0, null, stringSerializer, stringSerializer); + collector.send(topic, "999", "0", headers, 1, null, stringSerializer, stringSerializer); + collector.send(topic, "999", "0", headers, 1, null, stringSerializer, stringSerializer); + collector.send(topic, "999", "0", headers, 2, null, stringSerializer, stringSerializer); + + Map offsets = collector.offsets(); + + assertEquals(2L, (long) offsets.get(new TopicPartition(topic, 0))); + assertEquals(1L, (long) offsets.get(new TopicPartition(topic, 1))); + assertEquals(0L, (long) offsets.get(new TopicPartition(topic, 2))); + assertEquals(6, mockProducer.history().size()); + + collector.send(topic, "999", "0", null, 0, null, stringSerializer, stringSerializer); + collector.send(topic, "999", "0", null, 1, null, stringSerializer, stringSerializer); + collector.send(topic, "999", "0", headers, 2, null, stringSerializer, stringSerializer); + + offsets = collector.offsets(); + + assertEquals(3L, (long) offsets.get(new TopicPartition(topic, 0))); + assertEquals(2L, (long) offsets.get(new TopicPartition(topic, 1))); + assertEquals(1L, (long) offsets.get(new 
TopicPartition(topic, 2))); + assertEquals(9, mockProducer.history().size()); + } + + @Test + public void shouldSendWithPartitioner() { + final Headers headers = new RecordHeaders(new Header[] {new RecordHeader("key", "value".getBytes())}); + + collector.send(topic, "3", "0", null, null, stringSerializer, stringSerializer, streamPartitioner); + collector.send(topic, "9", "0", null, null, stringSerializer, stringSerializer, streamPartitioner); + collector.send(topic, "27", "0", null, null, stringSerializer, stringSerializer, streamPartitioner); + collector.send(topic, "81", "0", null, null, stringSerializer, stringSerializer, streamPartitioner); + collector.send(topic, "243", "0", null, null, stringSerializer, stringSerializer, streamPartitioner); + collector.send(topic, "28", "0", headers, null, stringSerializer, stringSerializer, streamPartitioner); + collector.send(topic, "82", "0", headers, null, stringSerializer, stringSerializer, streamPartitioner); + collector.send(topic, "244", "0", headers, null, stringSerializer, stringSerializer, streamPartitioner); + collector.send(topic, "245", "0", null, null, stringSerializer, stringSerializer, streamPartitioner); + + final Map offsets = collector.offsets(); + + assertEquals(4L, (long) offsets.get(new TopicPartition(topic, 0))); + assertEquals(2L, (long) offsets.get(new TopicPartition(topic, 1))); + assertEquals(0L, (long) offsets.get(new TopicPartition(topic, 2))); + assertEquals(9, mockProducer.history().size()); + + // returned offsets should not be modified + final TopicPartition topicPartition = new TopicPartition(topic, 0); + assertThrows(UnsupportedOperationException.class, () -> offsets.put(topicPartition, 50L)); + } + + @Test + public void shouldSendWithNoPartition() { + final Headers headers = new RecordHeaders(new Header[] {new RecordHeader("key", "value".getBytes())}); + + collector.send(topic, "3", "0", headers, null, null, stringSerializer, stringSerializer); + collector.send(topic, "9", "0", headers, null, null, stringSerializer, stringSerializer); + collector.send(topic, "27", "0", headers, null, null, stringSerializer, stringSerializer); + collector.send(topic, "81", "0", headers, null, null, stringSerializer, stringSerializer); + collector.send(topic, "243", "0", headers, null, null, stringSerializer, stringSerializer); + collector.send(topic, "28", "0", headers, null, null, stringSerializer, stringSerializer); + collector.send(topic, "82", "0", headers, null, null, stringSerializer, stringSerializer); + collector.send(topic, "244", "0", headers, null, null, stringSerializer, stringSerializer); + collector.send(topic, "245", "0", headers, null, null, stringSerializer, stringSerializer); + + final Map offsets = collector.offsets(); + + // with mock producer without specific partition, we would use default producer partitioner with murmur hash + assertEquals(3L, (long) offsets.get(new TopicPartition(topic, 0))); + assertEquals(2L, (long) offsets.get(new TopicPartition(topic, 1))); + assertEquals(1L, (long) offsets.get(new TopicPartition(topic, 2))); + assertEquals(9, mockProducer.history().size()); + } + + @Test + public void shouldUpdateOffsetsUponCompletion() { + Map offsets = collector.offsets(); + + collector.send(topic, "999", "0", null, 0, null, stringSerializer, stringSerializer); + collector.send(topic, "999", "0", null, 1, null, stringSerializer, stringSerializer); + collector.send(topic, "999", "0", null, 2, null, stringSerializer, stringSerializer); + + assertEquals(Collections.emptyMap(), offsets); + + 
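+        // Offsets are only recorded once the producer acknowledges each send;
+        // flush() below completes the MockProducer's pending sends, at which point the offsets map is populated.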
collector.flush(); + + offsets = collector.offsets(); + assertEquals((Long) 0L, offsets.get(new TopicPartition(topic, 0))); + assertEquals((Long) 0L, offsets.get(new TopicPartition(topic, 1))); + assertEquals((Long) 0L, offsets.get(new TopicPartition(topic, 2))); + } + + @Test + public void shouldPassThroughRecordHeaderToSerializer() { + final CustomStringSerializer keySerializer = new CustomStringSerializer(); + final CustomStringSerializer valueSerializer = new CustomStringSerializer(); + keySerializer.configure(Collections.emptyMap(), true); + + collector.send(topic, "3", "0", new RecordHeaders(), null, keySerializer, valueSerializer, streamPartitioner); + + final List> recordHistory = mockProducer.history(); + for (final ProducerRecord sentRecord : recordHistory) { + final Headers headers = sentRecord.headers(); + assertEquals(2, headers.toArray().length); + assertEquals(new RecordHeader("key", "key".getBytes()), headers.lastHeader("key")); + assertEquals(new RecordHeader("value", "value".getBytes()), headers.lastHeader("value")); + } + } + + @Test + public void shouldForwardFlushToStreamsProducer() { + final StreamsProducer streamsProducer = mock(StreamsProducer.class); + expect(streamsProducer.eosEnabled()).andReturn(false); + streamsProducer.flush(); + expectLastCall(); + replay(streamsProducer); + + final RecordCollector collector = new RecordCollectorImpl( + logContext, + taskId, + streamsProducer, + productionExceptionHandler, + streamsMetrics); + + collector.flush(); + + verify(streamsProducer); + } + + @Test + public void shouldForwardFlushToStreamsProducerEosEnabled() { + final StreamsProducer streamsProducer = mock(StreamsProducer.class); + expect(streamsProducer.eosEnabled()).andReturn(true); + streamsProducer.flush(); + expectLastCall(); + replay(streamsProducer); + + final RecordCollector collector = new RecordCollectorImpl( + logContext, + taskId, + streamsProducer, + productionExceptionHandler, + streamsMetrics); + + collector.flush(); + + verify(streamsProducer); + } + + @Test + public void shouldNotAbortTxOnCloseCleanIfEosEnabled() { + final StreamsProducer streamsProducer = mock(StreamsProducer.class); + expect(streamsProducer.eosEnabled()).andReturn(true); + replay(streamsProducer); + + final RecordCollector collector = new RecordCollectorImpl( + logContext, + taskId, + streamsProducer, + productionExceptionHandler, + streamsMetrics); + + collector.closeClean(); + + verify(streamsProducer); + } + + @Test + public void shouldAbortTxOnCloseDirtyIfEosEnabled() { + final StreamsProducer streamsProducer = mock(StreamsProducer.class); + expect(streamsProducer.eosEnabled()).andReturn(true); + streamsProducer.abortTransaction(); + replay(streamsProducer); + + final RecordCollector collector = new RecordCollectorImpl( + logContext, + taskId, + streamsProducer, + productionExceptionHandler, + streamsMetrics); + + collector.closeDirty(); + + verify(streamsProducer); + } + + @SuppressWarnings({"unchecked", "rawtypes"}) + @Test + public void shouldThrowInformativeStreamsExceptionOnKeyClassCastException() { + final StreamsException expected = assertThrows( + StreamsException.class, + () -> this.collector.send( + "topic", + "key", + "value", + new RecordHeaders(), + 0, + 0L, + (Serializer) new LongSerializer(), // need to add cast to trigger `ClassCastException` + new StringSerializer()) + ); + + assertThat(expected.getCause(), instanceOf(ClassCastException.class)); + assertThat( + expected.getMessage(), + equalTo( + "ClassCastException while producing data to topic topic. 
" + + "A serializer (key: org.apache.kafka.common.serialization.LongSerializer / value: org.apache.kafka.common.serialization.StringSerializer) " + + "is not compatible to the actual key or value type (key type: java.lang.String / value type: java.lang.String). " + + "Change the default Serdes in StreamConfig or provide correct Serdes via method parameters " + + "(for example if using the DSL, `#to(String topic, Produced produced)` with `Produced.keySerde(WindowedSerdes.timeWindowedSerdeFrom(String.class))`).") + ); + } + + @SuppressWarnings({"unchecked", "rawtypes"}) + @Test + public void shouldThrowInformativeStreamsExceptionOnKeyAndNullValueClassCastException() { + final StreamsException expected = assertThrows( + StreamsException.class, + () -> this.collector.send( + "topic", + "key", + null, + new RecordHeaders(), + 0, + 0L, + (Serializer) new LongSerializer(), // need to add cast to trigger `ClassCastException` + new StringSerializer()) + ); + + assertThat(expected.getCause(), instanceOf(ClassCastException.class)); + assertThat( + expected.getMessage(), + equalTo( + "ClassCastException while producing data to topic topic. " + + "A serializer (key: org.apache.kafka.common.serialization.LongSerializer / value: org.apache.kafka.common.serialization.StringSerializer) " + + "is not compatible to the actual key or value type (key type: java.lang.String / value type: unknown because value is null). " + + "Change the default Serdes in StreamConfig or provide correct Serdes via method parameters " + + "(for example if using the DSL, `#to(String topic, Produced produced)` with `Produced.keySerde(WindowedSerdes.timeWindowedSerdeFrom(String.class))`).") + ); + } + + @SuppressWarnings({"unchecked", "rawtypes"}) + @Test + public void shouldThrowInformativeStreamsExceptionOnValueClassCastException() { + final StreamsException expected = assertThrows( + StreamsException.class, + () -> this.collector.send( + "topic", + "key", + "value", + new RecordHeaders(), + 0, + 0L, + new StringSerializer(), + (Serializer) new LongSerializer()) // need to add cast to trigger `ClassCastException` + ); + + assertThat(expected.getCause(), instanceOf(ClassCastException.class)); + assertThat( + expected.getMessage(), + equalTo( + "ClassCastException while producing data to topic topic. " + + "A serializer (key: org.apache.kafka.common.serialization.StringSerializer / value: org.apache.kafka.common.serialization.LongSerializer) " + + "is not compatible to the actual key or value type (key type: java.lang.String / value type: java.lang.String). " + + "Change the default Serdes in StreamConfig or provide correct Serdes via method parameters " + + "(for example if using the DSL, `#to(String topic, Produced produced)` with `Produced.keySerde(WindowedSerdes.timeWindowedSerdeFrom(String.class))`).") + ); + } + + @SuppressWarnings({"unchecked", "rawtypes"}) + @Test + public void shouldThrowInformativeStreamsExceptionOnValueAndNullKeyClassCastException() { + final StreamsException expected = assertThrows( + StreamsException.class, + () -> this.collector.send( + "topic", + null, + "value", + new RecordHeaders(), + 0, + 0L, + new StringSerializer(), + (Serializer) new LongSerializer()) // need to add cast to trigger `ClassCastException` + ); + + assertThat(expected.getCause(), instanceOf(ClassCastException.class)); + assertThat( + expected.getMessage(), + equalTo( + "ClassCastException while producing data to topic topic. 
" + + "A serializer (key: org.apache.kafka.common.serialization.StringSerializer / value: org.apache.kafka.common.serialization.LongSerializer) " + + "is not compatible to the actual key or value type (key type: unknown because key is null / value type: java.lang.String). " + + "Change the default Serdes in StreamConfig or provide correct Serdes via method parameters " + + "(for example if using the DSL, `#to(String topic, Produced produced)` with `Produced.keySerde(WindowedSerdes.timeWindowedSerdeFrom(String.class))`).") + ); + } + + @Test + public void shouldThrowInformativeStreamsExceptionOnKafkaExceptionFromStreamPartitioner() { + final RecordCollector collector = new RecordCollectorImpl( + logContext, + taskId, + getExceptionalStreamProducerOnPartitionsFor(new KafkaException("Kaboom!")), + productionExceptionHandler, + streamsMetrics + ); + collector.initialize(); + + final StreamsException exception = assertThrows( + StreamsException.class, + () -> collector.send(topic, "0", "0", null, null, stringSerializer, stringSerializer, streamPartitioner) + ); + assertThat( + exception.getMessage(), + equalTo("Could not determine the number of partitions for topic '" + topic + "' for task " + + taskId + " due to org.apache.kafka.common.KafkaException: Kaboom!") + ); + } + + @Test + public void shouldForwardTimeoutExceptionFromStreamPartitionerWithoutWrappingIt() { + shouldForwardExceptionWithoutWrappingIt(new TimeoutException("Kaboom!")); + } + + @Test + public void shouldForwardRuntimeExceptionFromStreamPartitionerWithoutWrappingIt() { + shouldForwardExceptionWithoutWrappingIt(new RuntimeException("Kaboom!")); + } + + private void shouldForwardExceptionWithoutWrappingIt(final E runtimeException) { + final RecordCollector collector = new RecordCollectorImpl( + logContext, + taskId, + getExceptionalStreamProducerOnPartitionsFor(runtimeException), + productionExceptionHandler, + streamsMetrics + ); + collector.initialize(); + + final RuntimeException exception = assertThrows( + runtimeException.getClass(), + () -> collector.send(topic, "0", "0", null, null, stringSerializer, stringSerializer, streamPartitioner) + ); + assertThat(exception.getMessage(), equalTo("Kaboom!")); + } + + @Test + public void shouldThrowTaskMigratedExceptionOnSubsequentSendWhenProducerFencedInCallback() { + testThrowTaskMigratedExceptionOnSubsequentSend(new ProducerFencedException("KABOOM!")); + } + + @Test + public void shouldThrowTaskMigratedExceptionOnSubsequentSendWhenInvalidEpochInCallback() { + testThrowTaskMigratedExceptionOnSubsequentSend(new InvalidProducerEpochException("KABOOM!")); + } + + private void testThrowTaskMigratedExceptionOnSubsequentSend(final RuntimeException exception) { + final RecordCollector collector = new RecordCollectorImpl( + logContext, + taskId, + getExceptionalStreamsProducerOnSend(exception), + productionExceptionHandler, + streamsMetrics + ); + collector.initialize(); + + collector.send(topic, "3", "0", null, null, stringSerializer, stringSerializer, streamPartitioner); + + final TaskMigratedException thrown = assertThrows( + TaskMigratedException.class, + () -> collector.send(topic, "3", "0", null, null, stringSerializer, stringSerializer, streamPartitioner) + ); + assertEquals(exception, thrown.getCause()); + } + + @Test + public void shouldThrowTaskMigratedExceptionOnSubsequentFlushWhenProducerFencedInCallback() { + testThrowTaskMigratedExceptionOnSubsequentFlush(new ProducerFencedException("KABOOM!")); + } + + @Test + public void 
shouldThrowTaskMigratedExceptionOnSubsequentFlushWhenInvalidEpochInCallback() { + testThrowTaskMigratedExceptionOnSubsequentFlush(new InvalidProducerEpochException("KABOOM!")); + } + + private void testThrowTaskMigratedExceptionOnSubsequentFlush(final RuntimeException exception) { + final RecordCollector collector = new RecordCollectorImpl( + logContext, + taskId, + getExceptionalStreamsProducerOnSend(exception), + productionExceptionHandler, + streamsMetrics + ); + collector.initialize(); + + collector.send(topic, "3", "0", null, null, stringSerializer, stringSerializer, streamPartitioner); + + final TaskMigratedException thrown = assertThrows(TaskMigratedException.class, collector::flush); + assertEquals(exception, thrown.getCause()); + } + + @Test + public void shouldThrowTaskMigratedExceptionOnSubsequentCloseWhenProducerFencedInCallback() { + testThrowTaskMigratedExceptionOnSubsequentClose(new ProducerFencedException("KABOOM!")); + } + + @Test + public void shouldThrowTaskMigratedExceptionOnSubsequentCloseWhenInvalidEpochInCallback() { + testThrowTaskMigratedExceptionOnSubsequentClose(new InvalidProducerEpochException("KABOOM!")); + } + + private void testThrowTaskMigratedExceptionOnSubsequentClose(final RuntimeException exception) { + final RecordCollector collector = new RecordCollectorImpl( + logContext, + taskId, + getExceptionalStreamsProducerOnSend(exception), + productionExceptionHandler, + streamsMetrics + ); + collector.initialize(); + + collector.send(topic, "3", "0", null, null, stringSerializer, stringSerializer, streamPartitioner); + + final TaskMigratedException thrown = assertThrows(TaskMigratedException.class, collector::closeClean); + assertEquals(exception, thrown.getCause()); + } + + @Test + public void shouldThrowStreamsExceptionOnSubsequentSendIfASendFailsWithDefaultExceptionHandler() { + final KafkaException exception = new KafkaException("KABOOM!"); + final RecordCollector collector = new RecordCollectorImpl( + logContext, + taskId, + getExceptionalStreamsProducerOnSend(exception), + productionExceptionHandler, + streamsMetrics + ); + + collector.send(topic, "3", "0", null, null, stringSerializer, stringSerializer, streamPartitioner); + + final StreamsException thrown = assertThrows( + StreamsException.class, + () -> collector.send(topic, "3", "0", null, null, stringSerializer, stringSerializer, streamPartitioner) + ); + assertEquals(exception, thrown.getCause()); + assertThat( + thrown.getMessage(), + equalTo("Error encountered sending record to topic topic for task 0_0 due to:" + + "\norg.apache.kafka.common.KafkaException: KABOOM!" + + "\nException handler choose to FAIL the processing, no more records would be sent.") + ); + } + + @Test + public void shouldThrowStreamsExceptionOnSubsequentFlushIfASendFailsWithDefaultExceptionHandler() { + final KafkaException exception = new KafkaException("KABOOM!"); + final RecordCollector collector = new RecordCollectorImpl( + logContext, + taskId, + getExceptionalStreamsProducerOnSend(exception), + productionExceptionHandler, + streamsMetrics + ); + + collector.send(topic, "3", "0", null, null, stringSerializer, stringSerializer, streamPartitioner); + + final StreamsException thrown = assertThrows(StreamsException.class, collector::flush); + assertEquals(exception, thrown.getCause()); + assertThat( + thrown.getMessage(), + equalTo("Error encountered sending record to topic topic for task 0_0 due to:" + + "\norg.apache.kafka.common.KafkaException: KABOOM!" 
+ + "\nException handler choose to FAIL the processing, no more records would be sent.") + ); + } + + @Test + public void shouldThrowStreamsExceptionOnSubsequentCloseIfASendFailsWithDefaultExceptionHandler() { + final KafkaException exception = new KafkaException("KABOOM!"); + final RecordCollector collector = new RecordCollectorImpl( + logContext, + taskId, + getExceptionalStreamsProducerOnSend(exception), + productionExceptionHandler, + streamsMetrics + ); + + collector.send(topic, "3", "0", null, null, stringSerializer, stringSerializer, streamPartitioner); + + final StreamsException thrown = assertThrows(StreamsException.class, collector::closeClean); + assertEquals(exception, thrown.getCause()); + assertThat( + thrown.getMessage(), + equalTo("Error encountered sending record to topic topic for task 0_0 due to:" + + "\norg.apache.kafka.common.KafkaException: KABOOM!" + + "\nException handler choose to FAIL the processing, no more records would be sent.") + ); + } + + @Test + public void shouldThrowStreamsExceptionOnSubsequentSendIfFatalEvenWithContinueExceptionHandler() { + final KafkaException exception = new AuthenticationException("KABOOM!"); + final RecordCollector collector = new RecordCollectorImpl( + logContext, + taskId, + getExceptionalStreamsProducerOnSend(exception), + new AlwaysContinueProductionExceptionHandler(), + streamsMetrics + ); + + collector.send(topic, "3", "0", null, null, stringSerializer, stringSerializer, streamPartitioner); + + final StreamsException thrown = assertThrows( + StreamsException.class, + () -> collector.send(topic, "3", "0", null, null, stringSerializer, stringSerializer, streamPartitioner) + ); + assertEquals(exception, thrown.getCause()); + assertThat( + thrown.getMessage(), + equalTo("Error encountered sending record to topic topic for task 0_0 due to:" + + "\norg.apache.kafka.common.errors.AuthenticationException: KABOOM!" + + "\nWritten offsets would not be recorded and no more records would be sent since this is a fatal error.") + ); + } + + @Test + public void shouldThrowStreamsExceptionOnSubsequentFlushIfFatalEvenWithContinueExceptionHandler() { + final KafkaException exception = new AuthenticationException("KABOOM!"); + final RecordCollector collector = new RecordCollectorImpl( + logContext, + taskId, + getExceptionalStreamsProducerOnSend(exception), + new AlwaysContinueProductionExceptionHandler(), + streamsMetrics + ); + + collector.send(topic, "3", "0", null, null, stringSerializer, stringSerializer, streamPartitioner); + + final StreamsException thrown = assertThrows(StreamsException.class, collector::flush); + assertEquals(exception, thrown.getCause()); + assertThat( + thrown.getMessage(), + equalTo("Error encountered sending record to topic topic for task 0_0 due to:" + + "\norg.apache.kafka.common.errors.AuthenticationException: KABOOM!" 
+ + "\nWritten offsets would not be recorded and no more records would be sent since this is a fatal error.") + ); + } + + @Test + public void shouldThrowStreamsExceptionOnSubsequentCloseIfFatalEvenWithContinueExceptionHandler() { + final KafkaException exception = new AuthenticationException("KABOOM!"); + final RecordCollector collector = new RecordCollectorImpl( + logContext, + taskId, + getExceptionalStreamsProducerOnSend(exception), + new AlwaysContinueProductionExceptionHandler(), + streamsMetrics + ); + + collector.send(topic, "3", "0", null, null, stringSerializer, stringSerializer, streamPartitioner); + + final StreamsException thrown = assertThrows(StreamsException.class, collector::closeClean); + assertEquals(exception, thrown.getCause()); + assertThat( + thrown.getMessage(), + equalTo("Error encountered sending record to topic topic for task 0_0 due to:" + + "\norg.apache.kafka.common.errors.AuthenticationException: KABOOM!" + + "\nWritten offsets would not be recorded and no more records would be sent since this is a fatal error.") + ); + } + + @Test + public void shouldNotThrowStreamsExceptionOnSubsequentCallIfASendFailsWithContinueExceptionHandler() { + final RecordCollector collector = new RecordCollectorImpl( + logContext, + taskId, + getExceptionalStreamsProducerOnSend(new Exception()), + new AlwaysContinueProductionExceptionHandler(), + streamsMetrics + ); + + try (final LogCaptureAppender logCaptureAppender = + LogCaptureAppender.createAndRegister(RecordCollectorImpl.class)) { + + collector.send(topic, "3", "0", null, null, stringSerializer, stringSerializer, streamPartitioner); + collector.flush(); + + final List messages = logCaptureAppender.getMessages(); + final StringBuilder errorMessage = new StringBuilder("Messages received:"); + for (final String error : messages) { + errorMessage.append("\n - ").append(error); + } + assertTrue( + errorMessage.toString(), + messages.get(messages.size() - 1) + .endsWith("Exception handler choose to CONTINUE processing in spite of this error but written offsets would not be recorded.") + ); + } + + final Metric metric = streamsMetrics.metrics().get(new MetricName( + "dropped-records-total", + "stream-task-metrics", + "The total number of dropped records", + mkMap( + mkEntry("thread-id", Thread.currentThread().getName()), + mkEntry("task-id", taskId.toString()) + ) + )); + assertEquals(1.0, metric.metricValue()); + + collector.send(topic, "3", "0", null, null, stringSerializer, stringSerializer, streamPartitioner); + collector.flush(); + collector.closeClean(); + } + + @Test + public void shouldNotAbortTxnOnEOSCloseDirtyIfNothingSent() { + final AtomicBoolean functionCalled = new AtomicBoolean(false); + final RecordCollector collector = new RecordCollectorImpl( + logContext, + taskId, + new StreamsProducer( + eosConfig, + "-StreamThread-1", + new MockClientSupplier() { + @Override + public Producer getProducer(final Map config) { + return new MockProducer(cluster, true, new DefaultPartitioner(), byteArraySerializer, byteArraySerializer) { + @Override + public void abortTransaction() { + functionCalled.set(true); + } + }; + } + }, + taskId, + processId, + logContext, + Time.SYSTEM + ), + productionExceptionHandler, + streamsMetrics + ); + + collector.closeDirty(); + assertFalse(functionCalled.get()); + } + + @Test + public void shouldThrowIfTopicIsUnknownOnSendWithPartitioner() { + final RecordCollector collector = new RecordCollectorImpl( + logContext, + taskId, + new StreamsProducer( + config, + processId + "-StreamThread-1", + new 
MockClientSupplier() { + @Override + public Producer getProducer(final Map config) { + return new MockProducer(cluster, true, new DefaultPartitioner(), byteArraySerializer, byteArraySerializer) { + @Override + public List partitionsFor(final String topic) { + return Collections.emptyList(); + } + }; + } + }, + null, + null, + logContext, + Time.SYSTEM + ), + productionExceptionHandler, + streamsMetrics + ); + collector.initialize(); + + final StreamsException thrown = assertThrows( + StreamsException.class, + () -> collector.send(topic, "3", "0", null, null, stringSerializer, stringSerializer, streamPartitioner) + ); + assertThat( + thrown.getMessage(), + equalTo("Could not get partition information for topic topic for task 0_0." + + " This can happen if the topic does not exist.") + ); + } + + @Test + public void shouldNotCloseInternalProducerForEOS() { + final RecordCollector collector = new RecordCollectorImpl( + logContext, + taskId, + new StreamsProducer( + eosConfig, + processId + "-StreamThread-1", + new MockClientSupplier() { + @Override + public Producer getProducer(final Map config) { + return mockProducer; + } + }, + taskId, + processId, + logContext, + Time.SYSTEM + ), + productionExceptionHandler, + streamsMetrics + ); + + collector.closeClean(); + + // Flush should not throw as producer is still alive. + streamsProducer.flush(); + } + + @Test + public void shouldNotCloseInternalProducerForNonEOS() { + collector.closeClean(); + + // Flush should not throw as producer is still alive. + streamsProducer.flush(); + } + + private StreamsProducer getExceptionalStreamsProducerOnSend(final Exception exception) { + return new StreamsProducer( + config, + processId + "-StreamThread-1", + new MockClientSupplier() { + @Override + public Producer getProducer(final Map config) { + return new MockProducer(cluster, true, new DefaultPartitioner(), byteArraySerializer, byteArraySerializer) { + @Override + public synchronized Future send(final ProducerRecord record, final Callback callback) { + callback.onCompletion(null, exception); + return null; + } + }; + } + }, + null, + null, + logContext, + Time.SYSTEM + ); + } + + private StreamsProducer getExceptionalStreamProducerOnPartitionsFor(final RuntimeException exception) { + return new StreamsProducer( + config, + processId + "-StreamThread-1", + new MockClientSupplier() { + @Override + public Producer getProducer(final Map config) { + return new MockProducer(cluster, true, new DefaultPartitioner(), byteArraySerializer, byteArraySerializer) { + @Override + public synchronized List partitionsFor(final String topic) { + throw exception; + } + }; + } + }, + null, + null, + logContext, + Time.SYSTEM + ); + } + + private static class CustomStringSerializer extends StringSerializer { + private boolean isKey; + + @Override + public void configure(final Map configs, final boolean isKey) { + this.isKey = isKey; + super.configure(configs, isKey); + } + + @Override + public byte[] serialize(final String topic, final Headers headers, final String data) { + if (isKey) { + headers.add(new RecordHeader("key", "key".getBytes())); + } else { + headers.add(new RecordHeader("value", "value".getBytes())); + } + return serialize(topic, data); + } + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/RecordDeserializerTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/RecordDeserializerTest.java new file mode 100644 index 0000000..448ceaf --- /dev/null +++ 
b/streams/src/test/java/org/apache/kafka/streams/processor/internals/RecordDeserializerTest.java @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.processor.internals; + +import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.common.header.Header; +import org.apache.kafka.common.header.Headers; +import org.apache.kafka.common.header.internals.RecordHeader; +import org.apache.kafka.common.header.internals.RecordHeaders; +import org.apache.kafka.common.metrics.Metrics; +import org.apache.kafka.common.record.TimestampType; +import org.apache.kafka.common.utils.LogContext; +import org.junit.Test; + +import java.util.Optional; + +import static org.junit.Assert.assertEquals; + +public class RecordDeserializerTest { + + private final RecordHeaders headers = new RecordHeaders(new Header[] {new RecordHeader("key", "value".getBytes())}); + private final ConsumerRecord rawRecord = new ConsumerRecord<>("topic", + 1, + 1, + 10, + TimestampType.LOG_APPEND_TIME, + 3, + 5, + new byte[0], + new byte[0], + headers, + Optional.empty()); + + @Test + public void shouldReturnConsumerRecordWithDeserializedValueWhenNoExceptions() { + final RecordDeserializer recordDeserializer = new RecordDeserializer( + new TheSourceNode( + false, + false, + "key", "value" + ), + null, + new LogContext(), + new Metrics().sensor("dropped-records") + ); + final ConsumerRecord record = recordDeserializer.deserialize(null, rawRecord); + assertEquals(rawRecord.topic(), record.topic()); + assertEquals(rawRecord.partition(), record.partition()); + assertEquals(rawRecord.offset(), record.offset()); + assertEquals("key", record.key()); + assertEquals("value", record.value()); + assertEquals(rawRecord.timestamp(), record.timestamp()); + assertEquals(TimestampType.CREATE_TIME, record.timestampType()); + assertEquals(rawRecord.headers(), record.headers()); + } + + static class TheSourceNode extends SourceNode { + private final boolean keyThrowsException; + private final boolean valueThrowsException; + private final Object key; + private final Object value; + + TheSourceNode(final boolean keyThrowsException, + final boolean valueThrowsException, + final Object key, + final Object value) { + super("", null, null); + this.keyThrowsException = keyThrowsException; + this.valueThrowsException = valueThrowsException; + this.key = key; + this.value = value; + } + + @Override + public Object deserializeKey(final String topic, final Headers headers, final byte[] data) { + if (keyThrowsException) { + throw new RuntimeException(); + } + return key; + } + + @Override + public Object deserializeValue(final String topic, final Headers headers, final byte[] data) { + if (valueThrowsException) { + throw new RuntimeException(); + } + 
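+            // Otherwise return the canned value this test source node was constructed with.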
return value; + } + } + +} diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/RecordQueueTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/RecordQueueTest.java new file mode 100644 index 0000000..d23311b --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/RecordQueueTest.java @@ -0,0 +1,414 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.processor.internals; + +import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.errors.SerializationException; +import org.apache.kafka.common.header.internals.RecordHeaders; +import org.apache.kafka.common.record.TimestampType; +import org.apache.kafka.common.serialization.Deserializer; +import org.apache.kafka.common.serialization.IntegerDeserializer; +import org.apache.kafka.common.serialization.IntegerSerializer; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.serialization.Serializer; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.streams.errors.LogAndContinueExceptionHandler; +import org.apache.kafka.streams.errors.LogAndFailExceptionHandler; +import org.apache.kafka.streams.errors.StreamsException; +import org.apache.kafka.streams.processor.FailOnInvalidTimestamp; +import org.apache.kafka.streams.processor.LogAndSkipOnInvalidTimestamp; +import org.apache.kafka.streams.processor.TimestampExtractor; +import org.apache.kafka.streams.state.StateSerdes; +import org.apache.kafka.test.InternalMockProcessorContext; +import org.apache.kafka.test.MockRecordCollector; +import org.apache.kafka.test.MockSourceNode; +import org.apache.kafka.test.MockTimestampExtractor; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Optional; + +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.core.IsInstanceOf.instanceOf; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; + +public class RecordQueueTest { + private final Serializer intSerializer = new IntegerSerializer(); + private final Deserializer intDeserializer = new IntegerDeserializer(); + private final TimestampExtractor timestampExtractor = new MockTimestampExtractor(); + + @SuppressWarnings("rawtypes") + final InternalMockProcessorContext context = new 
InternalMockProcessorContext<>(
+        StateSerdes.withBuiltinTypes("anyName", Bytes.class, Bytes.class),
+        new MockRecordCollector()
+    );
+    private final MockSourceNode<Integer, Integer> mockSourceNodeWithMetrics
+        = new MockSourceNode<>(intDeserializer, intDeserializer);
+    private final RecordQueue queue = new RecordQueue(
+        new TopicPartition("topic", 1),
+        mockSourceNodeWithMetrics,
+        timestampExtractor,
+        new LogAndFailExceptionHandler(),
+        context,
+        new LogContext());
+    private final RecordQueue queueThatSkipsDeserializeErrors = new RecordQueue(
+        new TopicPartition("topic", 1),
+        mockSourceNodeWithMetrics,
+        timestampExtractor,
+        new LogAndContinueExceptionHandler(),
+        context,
+        new LogContext());
+
+    private final byte[] recordValue = intSerializer.serialize(null, 10);
+    private final byte[] recordKey = intSerializer.serialize(null, 1);
+
+    @SuppressWarnings("unchecked")
+    @Before
+    public void before() {
+        mockSourceNodeWithMetrics.init(context);
+    }
+
+    @After
+    public void after() {
+        mockSourceNodeWithMetrics.close();
+    }
+
+    @Test
+    public void testTimeTracking() {
+        assertTrue(queue.isEmpty());
+        assertEquals(0, queue.size());
+        assertEquals(RecordQueue.UNKNOWN, queue.headRecordTimestamp());
+        assertNull(queue.headRecordOffset());
+
+        // add three out-of-order records with timestamps 2, 1, 3
+        final List<ConsumerRecord<byte[], byte[]>> list1 = Arrays.asList(
+            new ConsumerRecord<>("topic", 1, 2, 0L, TimestampType.CREATE_TIME, 0, 0, recordKey, recordValue, new RecordHeaders(), Optional.empty()),
+            new ConsumerRecord<>("topic", 1, 1, 0L, TimestampType.CREATE_TIME, 0, 0, recordKey, recordValue, new RecordHeaders(), Optional.empty()),
+            new ConsumerRecord<>("topic", 1, 3, 0L, TimestampType.CREATE_TIME, 0, 0, recordKey, recordValue, new RecordHeaders(), Optional.empty()));
+
+        queue.addRawRecords(list1);
+
+        assertEquals(3, queue.size());
+        assertEquals(2L, queue.headRecordTimestamp());
+        assertEquals(2L, queue.headRecordOffset().longValue());
+
+        // poll the first record, now with 1, 3
+        assertEquals(2L, queue.poll().timestamp);
+        assertEquals(2, queue.size());
+        assertEquals(1L, queue.headRecordTimestamp());
+        assertEquals(1L, queue.headRecordOffset().longValue());
+
+        // poll the second record, now with 3
+        assertEquals(1L, queue.poll().timestamp);
+        assertEquals(1, queue.size());
+        assertEquals(3L, queue.headRecordTimestamp());
+        assertEquals(3L, queue.headRecordOffset().longValue());
+
+        // add three out-of-order records with timestamps 4, 1, 2
+        // now with 3, 4, 1, 2
+        final List<ConsumerRecord<byte[], byte[]>> list2 = Arrays.asList(
+            new ConsumerRecord<>("topic", 1, 4, 0L, TimestampType.CREATE_TIME, 0, 0, recordKey, recordValue, new RecordHeaders(), Optional.empty()),
+            new ConsumerRecord<>("topic", 1, 1, 0L, TimestampType.CREATE_TIME, 0, 0, recordKey, recordValue, new RecordHeaders(), Optional.empty()),
+            new ConsumerRecord<>("topic", 1, 2, 0L, TimestampType.CREATE_TIME, 0, 0, recordKey, recordValue, new RecordHeaders(), Optional.empty()));
+
+        queue.addRawRecords(list2);
+
+        assertEquals(4, queue.size());
+        assertEquals(3L, queue.headRecordTimestamp());
+        assertEquals(3L, queue.headRecordOffset().longValue());
+
+        // poll the third record, now with 4, 1, 2
+        assertEquals(3L, queue.poll().timestamp);
+        assertEquals(3, queue.size());
+        assertEquals(4L, queue.headRecordTimestamp());
+        assertEquals(4L, queue.headRecordOffset().longValue());
+
+        // poll the remaining records
+        assertEquals(4L, queue.poll().timestamp);
+        assertEquals(1L, queue.headRecordTimestamp());
+        assertEquals(1L, queue.headRecordOffset().longValue());
+
+        assertEquals(1L, queue.poll().timestamp);
+        assertEquals(2L, queue.headRecordTimestamp());
+        assertEquals(2L, queue.headRecordOffset().longValue());
+
+        assertEquals(2L, queue.poll().timestamp);
+        assertTrue(queue.isEmpty());
+        assertEquals(0, queue.size());
+        assertEquals(RecordQueue.UNKNOWN, queue.headRecordTimestamp());
+        assertNull(queue.headRecordOffset());
+
+        // add three more records with 4, 5, 6
+        final List<ConsumerRecord<byte[], byte[]>> list3 = Arrays.asList(
+            new ConsumerRecord<>("topic", 1, 4, 0L, TimestampType.CREATE_TIME, 0, 0, recordKey, recordValue, new RecordHeaders(), Optional.empty()),
+            new ConsumerRecord<>("topic", 1, 5, 0L, TimestampType.CREATE_TIME, 0, 0, recordKey, recordValue, new RecordHeaders(), Optional.empty()),
+            new ConsumerRecord<>("topic", 1, 6, 0L, TimestampType.CREATE_TIME, 0, 0, recordKey, recordValue, new RecordHeaders(), Optional.empty()));
+
+        queue.addRawRecords(list3);
+
+        assertEquals(3, queue.size());
+        assertEquals(4L, queue.headRecordTimestamp());
+        assertEquals(4L, queue.headRecordOffset().longValue());
+
+        // poll one record again, the timestamp should advance now
+        assertEquals(4L, queue.poll().timestamp);
+        assertEquals(2, queue.size());
+        assertEquals(5L, queue.headRecordTimestamp());
+        assertEquals(5L, queue.headRecordOffset().longValue());
+
+        // clear the queue
+        queue.clear();
+        assertTrue(queue.isEmpty());
+        assertEquals(0, queue.size());
+        assertEquals(RecordQueue.UNKNOWN, queue.headRecordTimestamp());
+        assertEquals(RecordQueue.UNKNOWN, queue.partitionTime());
+        assertNull(queue.headRecordOffset());
+
+        // re-insert the three records with 4, 5, 6
+        queue.addRawRecords(list3);
+
+        assertEquals(3, queue.size());
+        assertEquals(4L, queue.headRecordTimestamp());
+        assertEquals(4L, queue.headRecordOffset().longValue());
+    }
+
+    @Test
+    public void shouldTrackPartitionTimeAsMaxProcessedTimestamp() {
+        assertTrue(queue.isEmpty());
+        assertThat(queue.size(), is(0));
+        assertThat(queue.headRecordTimestamp(), is(RecordQueue.UNKNOWN));
+        assertThat(queue.partitionTime(), is(RecordQueue.UNKNOWN));
+
+        // add four out-of-order records with timestamps 2, 1, 3, 4
+        final List<ConsumerRecord<byte[], byte[]>> list1 = Arrays.asList(
+            new ConsumerRecord<>("topic", 1, 2, 0L, TimestampType.CREATE_TIME, 0, 0, recordKey, recordValue,
+                new RecordHeaders(), Optional.empty()),
+            new ConsumerRecord<>("topic", 1, 1, 0L, TimestampType.CREATE_TIME, 0, 0, recordKey, recordValue,
+                new RecordHeaders(), Optional.empty()),
+            new ConsumerRecord<>("topic", 1, 3, 0L, TimestampType.CREATE_TIME, 0, 0, recordKey, recordValue,
+                new RecordHeaders(), Optional.empty()),
+            new ConsumerRecord<>("topic", 1, 4, 0L, TimestampType.CREATE_TIME, 0, 0, recordKey, recordValue,
+                new RecordHeaders(), Optional.empty()));
+
+        queue.addRawRecords(list1);
+        assertThat(queue.partitionTime(), is(RecordQueue.UNKNOWN));
+
+        queue.poll();
+        assertThat(queue.partitionTime(), is(2L));
+
+        queue.poll();
+        assertThat(queue.partitionTime(), is(2L));
+
+        queue.poll();
+        assertThat(queue.partitionTime(), is(3L));
+    }
+
+    @Test
+    public void shouldSetTimestampAndRespectMaxTimestampPolicy() {
+        assertTrue(queue.isEmpty());
+        assertThat(queue.size(), is(0));
+        assertThat(queue.headRecordTimestamp(), is(RecordQueue.UNKNOWN));
+        assertThat(queue.partitionTime(), is(RecordQueue.UNKNOWN));
+
+        queue.setPartitionTime(150L);
+        assertThat(queue.partitionTime(), is(150L));
+
+        final List<ConsumerRecord<byte[], byte[]>> list1 = Arrays.asList(
+            new ConsumerRecord<>("topic", 1, 200, 0L, TimestampType.CREATE_TIME, 0, 0, recordKey, recordValue,
+                new RecordHeaders(), Optional.empty()),
+            new ConsumerRecord<>("topic",
1, 100, 0L, TimestampType.CREATE_TIME, 0, 0, recordKey, recordValue, + new RecordHeaders(), Optional.empty()), + new ConsumerRecord<>("topic", 1, 300, 0L, TimestampType.CREATE_TIME, 0, 0, recordKey, recordValue, + new RecordHeaders(), Optional.empty()), + new ConsumerRecord<>("topic", 1, 400, 0L, TimestampType.CREATE_TIME, 0, 0, recordKey, recordValue, + new RecordHeaders(), Optional.empty())); + + queue.addRawRecords(list1); + assertThat(queue.partitionTime(), is(150L)); + + queue.poll(); + assertThat(queue.partitionTime(), is(200L)); + + queue.setPartitionTime(500L); + assertThat(queue.partitionTime(), is(500L)); + + queue.poll(); + assertThat(queue.partitionTime(), is(500L)); + } + + @Test + public void shouldThrowStreamsExceptionWhenKeyDeserializationFails() { + final byte[] key = Serdes.Long().serializer().serialize("foo", 1L); + final List> records = Collections.singletonList( + new ConsumerRecord<>("topic", 1, 1, 0L, TimestampType.CREATE_TIME, 0, 0, key, recordValue, + new RecordHeaders(), Optional.empty())); + + final StreamsException exception = assertThrows( + StreamsException.class, + () -> queue.addRawRecords(records) + ); + assertThat(exception.getCause(), instanceOf(SerializationException.class)); + } + + @Test + public void shouldThrowStreamsExceptionWhenValueDeserializationFails() { + final byte[] value = Serdes.Long().serializer().serialize("foo", 1L); + final List> records = Collections.singletonList( + new ConsumerRecord<>("topic", 1, 1, 0L, TimestampType.CREATE_TIME, 0, 0, recordKey, value, + new RecordHeaders(), Optional.empty())); + + final StreamsException exception = assertThrows( + StreamsException.class, + () -> queue.addRawRecords(records) + ); + assertThat(exception.getCause(), instanceOf(SerializationException.class)); + } + + @Test + public void shouldNotThrowStreamsExceptionWhenKeyDeserializationFailsWithSkipHandler() { + final byte[] key = Serdes.Long().serializer().serialize("foo", 1L); + final List> records = Collections.singletonList( + new ConsumerRecord<>("topic", 1, 1, 0L, TimestampType.CREATE_TIME, 0, 0, key, recordValue, + new RecordHeaders(), Optional.empty())); + + queueThatSkipsDeserializeErrors.addRawRecords(records); + assertEquals(0, queueThatSkipsDeserializeErrors.size()); + } + + @Test + public void shouldNotThrowStreamsExceptionWhenValueDeserializationFailsWithSkipHandler() { + final byte[] value = Serdes.Long().serializer().serialize("foo", 1L); + final List> records = Collections.singletonList( + new ConsumerRecord<>("topic", 1, 1, 0L, TimestampType.CREATE_TIME, 0, 0, recordKey, value, + new RecordHeaders(), Optional.empty())); + + queueThatSkipsDeserializeErrors.addRawRecords(records); + assertEquals(0, queueThatSkipsDeserializeErrors.size()); + } + + @Test + public void shouldThrowOnNegativeTimestamp() { + final List> records = Collections.singletonList( + new ConsumerRecord<>("topic", 1, 1, -1L, TimestampType.CREATE_TIME, 0, 0, recordKey, recordValue, + new RecordHeaders(), Optional.empty())); + + final RecordQueue queue = new RecordQueue( + new TopicPartition("topic", 1), + mockSourceNodeWithMetrics, + new FailOnInvalidTimestamp(), + new LogAndContinueExceptionHandler(), + new InternalMockProcessorContext(), + new LogContext()); + + final StreamsException exception = assertThrows( + StreamsException.class, + () -> queue.addRawRecords(records) + ); + assertThat(exception.getMessage(), equalTo("Input record ConsumerRecord(topic = topic, partition = 1, " + + "leaderEpoch = null, offset = 1, CreateTime = -1, serialized key size = 0, 
serialized value size = 0, " + + "headers = RecordHeaders(headers = [], isReadOnly = false), key = 1, value = 10) has invalid (negative) " + + "timestamp. Possibly because a pre-0.10 producer client was used to write this record to Kafka without " + + "embedding a timestamp, or because the input topic was created before upgrading the Kafka cluster to 0.10+. " + + "Use a different TimestampExtractor to process this data.")); + } + + @Test + public void shouldDropOnNegativeTimestamp() { + final List> records = Collections.singletonList( + new ConsumerRecord<>("topic", 1, 1, -1L, TimestampType.CREATE_TIME, 0, 0, recordKey, recordValue, + new RecordHeaders(), Optional.empty())); + + final RecordQueue queue = new RecordQueue( + new TopicPartition("topic", 1), + mockSourceNodeWithMetrics, + new LogAndSkipOnInvalidTimestamp(), + new LogAndContinueExceptionHandler(), + new InternalMockProcessorContext(), + new LogContext()); + queue.addRawRecords(records); + + assertEquals(0, queue.size()); + } + + @Test + public void shouldPassPartitionTimeToTimestampExtractor() { + + final PartitionTimeTrackingTimestampExtractor timestampExtractor = new PartitionTimeTrackingTimestampExtractor(); + final RecordQueue queue = new RecordQueue( + new TopicPartition("topic", 1), + mockSourceNodeWithMetrics, + timestampExtractor, + new LogAndFailExceptionHandler(), + context, + new LogContext()); + + assertTrue(queue.isEmpty()); + assertEquals(0, queue.size()); + assertEquals(RecordQueue.UNKNOWN, queue.headRecordTimestamp()); + + // add three 3 out-of-order records with timestamp 2, 1, 3, 4 + final List> list1 = Arrays.asList( + new ConsumerRecord<>("topic", 1, 2, 0L, TimestampType.CREATE_TIME, 0, 0, recordKey, recordValue, + new RecordHeaders(), Optional.empty()), + new ConsumerRecord<>("topic", 1, 1, 0L, TimestampType.CREATE_TIME, 0, 0, recordKey, recordValue, + new RecordHeaders(), Optional.empty()), + new ConsumerRecord<>("topic", 1, 3, 0L, TimestampType.CREATE_TIME, 0, 0, recordKey, recordValue, + new RecordHeaders(), Optional.empty()), + new ConsumerRecord<>("topic", 1, 4, 0L, TimestampType.CREATE_TIME, 0, 0, recordKey, recordValue, + new RecordHeaders(), Optional.empty())); + + assertEquals(RecordQueue.UNKNOWN, timestampExtractor.partitionTime); + + queue.addRawRecords(list1); + + // no (known) timestamp has yet been passed to the timestamp extractor + assertEquals(RecordQueue.UNKNOWN, timestampExtractor.partitionTime); + + queue.poll(); + assertEquals(2L, timestampExtractor.partitionTime); + + queue.poll(); + assertEquals(2L, timestampExtractor.partitionTime); + + queue.poll(); + assertEquals(3L, timestampExtractor.partitionTime); + + } + + private static class PartitionTimeTrackingTimestampExtractor implements TimestampExtractor { + private long partitionTime = RecordQueue.UNKNOWN; + + public long extract(final ConsumerRecord record, final long partitionTime) { + if (partitionTime < this.partitionTime) { + throw new IllegalStateException("Partition time should not decrease"); + } + this.partitionTime = partitionTime; + return record.offset(); + } + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/RepartitionOptimizingTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/RepartitionOptimizingTest.java new file mode 100644 index 0000000..eb817cc --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/RepartitionOptimizingTest.java @@ -0,0 +1,456 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more 
+ * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.streams.processor.internals; + +import java.util.HashMap; +import java.util.Map; +import org.apache.kafka.common.serialization.Deserializer; +import org.apache.kafka.common.serialization.IntegerDeserializer; +import org.apache.kafka.common.serialization.LongDeserializer; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.serialization.Serializer; +import org.apache.kafka.common.serialization.StringDeserializer; +import org.apache.kafka.common.serialization.StringSerializer; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.TestInputTopic; +import org.apache.kafka.streams.TestOutputTopic; +import org.apache.kafka.streams.Topology; +import org.apache.kafka.streams.TopologyTestDriver; +import org.apache.kafka.streams.kstream.Aggregator; +import org.apache.kafka.streams.kstream.Consumed; +import org.apache.kafka.streams.kstream.Grouped; +import org.apache.kafka.streams.kstream.Initializer; +import org.apache.kafka.streams.kstream.JoinWindows; +import org.apache.kafka.streams.kstream.KStream; +import org.apache.kafka.streams.kstream.Materialized; +import org.apache.kafka.streams.kstream.Named; +import org.apache.kafka.streams.kstream.Produced; +import org.apache.kafka.streams.kstream.Reducer; +import org.apache.kafka.streams.kstream.StreamJoined; +import org.apache.kafka.streams.state.Stores; +import org.apache.kafka.test.StreamsTestUtils; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Locale; +import java.util.Properties; +import java.util.regex.Matcher; +import java.util.regex.Pattern; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import static java.time.Duration.ofDays; +import static java.time.Duration.ofMillis; +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.Assert.assertEquals; + +@SuppressWarnings("deprecation") +public class RepartitionOptimizingTest { + + private final Logger log = LoggerFactory.getLogger(RepartitionOptimizingTest.class); + + private static final String INPUT_TOPIC = "input"; + private static final String COUNT_TOPIC = "outputTopic_0"; + private static final String AGGREGATION_TOPIC = "outputTopic_1"; + private static final String REDUCE_TOPIC = "outputTopic_2"; + private static final String JOINED_TOPIC = "joinedOutputTopic"; + + private static final int ONE_REPARTITION_TOPIC = 1; + private static final int FOUR_REPARTITION_TOPICS = 4; + + private final Serializer stringSerializer = new StringSerializer(); + private final Deserializer 
stringDeserializer = new StringDeserializer(); + + private final Pattern repartitionTopicPattern = Pattern.compile("Sink: .*-repartition"); + + private Properties streamsConfiguration; + private TopologyTestDriver topologyTestDriver; + + private final Initializer initializer = () -> 0; + private final Aggregator aggregator = (k, v, agg) -> agg + v.length(); + private final Reducer reducer = (v1, v2) -> v1 + ":" + v2; + + private final List processorValueCollector = new ArrayList<>(); + + private final List> expectedCountKeyValues = + Arrays.asList(KeyValue.pair("A", 3L), KeyValue.pair("B", 3L), KeyValue.pair("C", 3L)); + private final List> expectedAggKeyValues = + Arrays.asList(KeyValue.pair("A", 9), KeyValue.pair("B", 9), KeyValue.pair("C", 9)); + private final List> expectedReduceKeyValues = + Arrays.asList(KeyValue.pair("A", "foo:bar:baz"), KeyValue.pair("B", "foo:bar:baz"), KeyValue.pair("C", "foo:bar:baz")); + private final List> expectedJoinKeyValues = + Arrays.asList(KeyValue.pair("A", "foo:3"), KeyValue.pair("A", "bar:3"), KeyValue.pair("A", "baz:3")); + private final List expectedCollectedProcessorValues = + Arrays.asList("FOO", "BAR", "BAZ"); + + @Before + public void setUp() { + streamsConfiguration = StreamsTestUtils.getStreamsConfig(Serdes.String(), Serdes.String()); + streamsConfiguration.setProperty(StreamsConfig.CACHE_MAX_BYTES_BUFFERING_CONFIG, Integer.toString(1024 * 10)); + streamsConfiguration.setProperty(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG, Long.toString(5000)); + + processorValueCollector.clear(); + } + + @After + public void tearDown() { + try { + topologyTestDriver.close(); + } catch (final RuntimeException e) { + log.warn("The following exception was thrown while trying to close the TopologyTestDriver (note that " + + "KAFKA-6647 causes this when running on Windows):", e); + } + } + + @Test + public void shouldSendCorrectRecords_OPTIMIZED() { + runTest(StreamsConfig.OPTIMIZE, ONE_REPARTITION_TOPIC); + } + + @Test + public void shouldSendCorrectResults_NO_OPTIMIZATION() { + runTest(StreamsConfig.NO_OPTIMIZATION, FOUR_REPARTITION_TOPICS); + } + + + @SuppressWarnings("deprecation") // Old PAPI. Needs to be migrated. 
+ private void runTest(final String optimizationConfig, final int expectedNumberRepartitionTopics) { + + final StreamsBuilder builder = new StreamsBuilder(); + + final KStream sourceStream = + builder.stream(INPUT_TOPIC, Consumed.with(Serdes.String(), Serdes.String()).withName("sourceStream")); + + final KStream mappedStream = sourceStream + .map((k, v) -> KeyValue.pair(k.toUpperCase(Locale.getDefault()), v), Named.as("source-map")); + + mappedStream + .filter((k, v) -> k.equals("B"), Named.as("process-filter")) + .mapValues(v -> v.toUpperCase(Locale.getDefault()), Named.as("process-mapValues")) + .process(() -> new SimpleProcessor(processorValueCollector), Named.as("process")); + + final KStream countStream = mappedStream + .groupByKey(Grouped.as("count-groupByKey")) + .count(Named.as("count"), Materialized.as(Stores.inMemoryKeyValueStore("count-store")) + .withKeySerde(Serdes.String()) + .withValueSerde(Serdes.Long())) + .toStream(Named.as("count-toStream")); + + countStream.to(COUNT_TOPIC, Produced.with(Serdes.String(), Serdes.Long()).withName("count-to")); + + mappedStream + .groupByKey(Grouped.as("aggregate-groupByKey")) + .aggregate(initializer, + aggregator, + Named.as("aggregate"), + Materialized.as(Stores.inMemoryKeyValueStore("aggregate-store")) + .withKeySerde(Serdes.String()) + .withValueSerde(Serdes.Integer())) + .toStream(Named.as("aggregate-toStream")) + .to(AGGREGATION_TOPIC, Produced.with(Serdes.String(), Serdes.Integer()).withName("reduce-to")); + + // adding operators for case where the repartition node is further downstream + mappedStream + .filter((k, v) -> true, Named.as("reduce-filter")) + .peek((k, v) -> System.out.println(k + ":" + v), Named.as("reduce-peek")) + .groupByKey(Grouped.as("reduce-groupByKey")) + .reduce(reducer, + Named.as("reducer"), + Materialized.as(Stores.inMemoryKeyValueStore("reduce-store"))) + .toStream(Named.as("reduce-toStream")) + .to(REDUCE_TOPIC, Produced.with(Serdes.String(), Serdes.String())); + + mappedStream + .filter((k, v) -> k.equals("A"), Named.as("join-filter")) + .join(countStream, (v1, v2) -> v1 + ":" + v2.toString(), + JoinWindows.of(ofMillis(5000)), + StreamJoined.with(Stores.inMemoryWindowStore("join-store", ofDays(1), ofMillis(10000), true), + Stores.inMemoryWindowStore("other-join-store", ofDays(1), ofMillis(10000), true)) + .withName("join") + .withKeySerde(Serdes.String()) + .withValueSerde(Serdes.String()) + .withOtherValueSerde(Serdes.Long())) + .to(JOINED_TOPIC, Produced.as("join-to")); + + streamsConfiguration.setProperty(StreamsConfig.TOPOLOGY_OPTIMIZATION_CONFIG, optimizationConfig); + final Topology topology = builder.build(streamsConfiguration); + + topologyTestDriver = new TopologyTestDriver(topology, streamsConfiguration); + + final TestInputTopic inputTopicA = topologyTestDriver.createInputTopic(INPUT_TOPIC, stringSerializer, stringSerializer); + final TestOutputTopic countOutputTopic = topologyTestDriver.createOutputTopic(COUNT_TOPIC, stringDeserializer, new LongDeserializer()); + final TestOutputTopic aggregationOutputTopic = topologyTestDriver.createOutputTopic(AGGREGATION_TOPIC, stringDeserializer, new IntegerDeserializer()); + final TestOutputTopic reduceOutputTopic = topologyTestDriver.createOutputTopic(REDUCE_TOPIC, stringDeserializer, stringDeserializer); + final TestOutputTopic joinedOutputTopic = topologyTestDriver.createOutputTopic(JOINED_TOPIC, stringDeserializer, stringDeserializer); + + inputTopicA.pipeKeyValueList(getKeyValues()); + + // Verify the topology + final String topologyString = 
topology.describe().toString(); + if (optimizationConfig.equals(StreamsConfig.OPTIMIZE)) { + assertEquals(EXPECTED_OPTIMIZED_TOPOLOGY, topologyString); + } else { + assertEquals(EXPECTED_UNOPTIMIZED_TOPOLOGY, topologyString); + } + + // Verify the number of repartition topics + assertEquals(expectedNumberRepartitionTopics, getCountOfRepartitionTopicsFound(topologyString)); + + // Verify the values collected by the processor + assertThat(3, equalTo(processorValueCollector.size())); + assertThat(processorValueCollector, equalTo(expectedCollectedProcessorValues)); + + // Verify the expected output + assertThat(countOutputTopic.readKeyValuesToMap(), equalTo(keyValueListToMap(expectedCountKeyValues))); + assertThat(aggregationOutputTopic.readKeyValuesToMap(), equalTo(keyValueListToMap(expectedAggKeyValues))); + assertThat(reduceOutputTopic.readKeyValuesToMap(), equalTo(keyValueListToMap(expectedReduceKeyValues))); + assertThat(joinedOutputTopic.readKeyValuesToMap(), equalTo(keyValueListToMap(expectedJoinKeyValues))); + } + + private Map keyValueListToMap(final List> keyValuePairs) { + final Map map = new HashMap<>(); + for (final KeyValue pair : keyValuePairs) { + map.put(pair.key, pair.value); + } + return map; + } + + private int getCountOfRepartitionTopicsFound(final String topologyString) { + final Matcher matcher = repartitionTopicPattern.matcher(topologyString); + final List repartitionTopicsFound = new ArrayList<>(); + while (matcher.find()) { + repartitionTopicsFound.add(matcher.group()); + } + return repartitionTopicsFound.size(); + } + + private List> getKeyValues() { + final List> keyValueList = new ArrayList<>(); + final String[] keys = new String[]{"a", "b", "c"}; + final String[] values = new String[]{"foo", "bar", "baz"}; + for (final String key : keys) { + for (final String value : values) { + keyValueList.add(KeyValue.pair(key, value)); + } + } + return keyValueList; + } + + @SuppressWarnings("deprecation") // Old PAPI. Needs to be migrated. 
+ private static class SimpleProcessor extends org.apache.kafka.streams.processor.AbstractProcessor { + + final List valueList; + + SimpleProcessor(final List valueList) { + this.valueList = valueList; + } + + @Override + public void process(final String key, final String value) { + valueList.add(value); + } + } + + private static final String EXPECTED_OPTIMIZED_TOPOLOGY = "Topologies:\n" + + " Sub-topology: 0\n" + + " Source: sourceStream (topics: [input])\n" + + " --> source-map\n" + + " Processor: source-map (stores: [])\n" + + " --> process-filter, count-groupByKey-repartition-filter\n" + + " <-- sourceStream\n" + + " Processor: process-filter (stores: [])\n" + + " --> process-mapValues\n" + + " <-- source-map\n" + + " Processor: count-groupByKey-repartition-filter (stores: [])\n" + + " --> count-groupByKey-repartition-sink\n" + + " <-- source-map\n" + + " Processor: process-mapValues (stores: [])\n" + + " --> process\n" + + " <-- process-filter\n" + + " Sink: count-groupByKey-repartition-sink (topic: count-groupByKey-repartition)\n" + + " <-- count-groupByKey-repartition-filter\n" + + " Processor: process (stores: [])\n" + + " --> none\n" + + " <-- process-mapValues\n" + + "\n" + + " Sub-topology: 1\n" + + " Source: count-groupByKey-repartition-source (topics: [count-groupByKey-repartition])\n" + + " --> aggregate, count, join-filter, reduce-filter\n" + + " Processor: count (stores: [count-store])\n" + + " --> count-toStream\n" + + " <-- count-groupByKey-repartition-source\n" + + " Processor: count-toStream (stores: [])\n" + + " --> join-other-windowed, count-to\n" + + " <-- count\n" + + " Processor: join-filter (stores: [])\n" + + " --> join-this-windowed\n" + + " <-- count-groupByKey-repartition-source\n" + + " Processor: reduce-filter (stores: [])\n" + + " --> reduce-peek\n" + + " <-- count-groupByKey-repartition-source\n" + + " Processor: join-other-windowed (stores: [other-join-store])\n" + + " --> join-other-join\n" + + " <-- count-toStream\n" + + " Processor: join-this-windowed (stores: [join-store])\n" + + " --> join-this-join\n" + + " <-- join-filter\n" + + " Processor: reduce-peek (stores: [])\n" + + " --> reducer\n" + + " <-- reduce-filter\n" + + " Processor: aggregate (stores: [aggregate-store])\n" + + " --> aggregate-toStream\n" + + " <-- count-groupByKey-repartition-source\n" + + " Processor: join-other-join (stores: [join-store])\n" + + " --> join-merge\n" + + " <-- join-other-windowed\n" + + " Processor: join-this-join (stores: [other-join-store])\n" + + " --> join-merge\n" + + " <-- join-this-windowed\n" + + " Processor: reducer (stores: [reduce-store])\n" + + " --> reduce-toStream\n" + + " <-- reduce-peek\n" + + " Processor: aggregate-toStream (stores: [])\n" + + " --> reduce-to\n" + + " <-- aggregate\n" + + " Processor: join-merge (stores: [])\n" + + " --> join-to\n" + + " <-- join-this-join, join-other-join\n" + + " Processor: reduce-toStream (stores: [])\n" + + " --> KSTREAM-SINK-0000000023\n" + + " <-- reducer\n" + + " Sink: KSTREAM-SINK-0000000023 (topic: outputTopic_2)\n" + + " <-- reduce-toStream\n" + + " Sink: count-to (topic: outputTopic_0)\n" + + " <-- count-toStream\n" + + " Sink: join-to (topic: joinedOutputTopic)\n" + + " <-- join-merge\n" + + " Sink: reduce-to (topic: outputTopic_1)\n" + + " <-- aggregate-toStream\n\n"; + + + + + private static final String EXPECTED_UNOPTIMIZED_TOPOLOGY = "Topologies:\n" + + " Sub-topology: 0\n" + + " Source: sourceStream (topics: [input])\n" + + " --> source-map\n" + + " Processor: source-map (stores: [])\n" + + " --> 
reduce-filter, process-filter, aggregate-groupByKey-repartition-filter, count-groupByKey-repartition-filter, join-filter\n" + + " <-- sourceStream\n" + + " Processor: reduce-filter (stores: [])\n" + + " --> reduce-peek\n" + + " <-- source-map\n" + + " Processor: join-filter (stores: [])\n" + + " --> join-left-repartition-filter\n" + + " <-- source-map\n" + + " Processor: process-filter (stores: [])\n" + + " --> process-mapValues\n" + + " <-- source-map\n" + + " Processor: reduce-peek (stores: [])\n" + + " --> reduce-groupByKey-repartition-filter\n" + + " <-- reduce-filter\n" + + " Processor: aggregate-groupByKey-repartition-filter (stores: [])\n" + + " --> aggregate-groupByKey-repartition-sink\n" + + " <-- source-map\n" + + " Processor: count-groupByKey-repartition-filter (stores: [])\n" + + " --> count-groupByKey-repartition-sink\n" + + " <-- source-map\n" + + " Processor: join-left-repartition-filter (stores: [])\n" + + " --> join-left-repartition-sink\n" + + " <-- join-filter\n" + + " Processor: process-mapValues (stores: [])\n" + + " --> process\n" + + " <-- process-filter\n" + + " Processor: reduce-groupByKey-repartition-filter (stores: [])\n" + + " --> reduce-groupByKey-repartition-sink\n" + + " <-- reduce-peek\n" + + " Sink: aggregate-groupByKey-repartition-sink (topic: aggregate-groupByKey-repartition)\n" + + " <-- aggregate-groupByKey-repartition-filter\n" + + " Sink: count-groupByKey-repartition-sink (topic: count-groupByKey-repartition)\n" + + " <-- count-groupByKey-repartition-filter\n" + + " Sink: join-left-repartition-sink (topic: join-left-repartition)\n" + + " <-- join-left-repartition-filter\n" + + " Processor: process (stores: [])\n" + + " --> none\n" + + " <-- process-mapValues\n" + + " Sink: reduce-groupByKey-repartition-sink (topic: reduce-groupByKey-repartition)\n" + + " <-- reduce-groupByKey-repartition-filter\n" + + "\n" + + " Sub-topology: 1\n" + + " Source: count-groupByKey-repartition-source (topics: [count-groupByKey-repartition])\n" + + " --> count\n" + + " Processor: count (stores: [count-store])\n" + + " --> count-toStream\n" + + " <-- count-groupByKey-repartition-source\n" + + " Processor: count-toStream (stores: [])\n" + + " --> join-other-windowed, count-to\n" + + " <-- count\n" + + " Source: join-left-repartition-source (topics: [join-left-repartition])\n" + + " --> join-this-windowed\n" + + " Processor: join-other-windowed (stores: [other-join-store])\n" + + " --> join-other-join\n" + + " <-- count-toStream\n" + + " Processor: join-this-windowed (stores: [join-store])\n" + + " --> join-this-join\n" + + " <-- join-left-repartition-source\n" + + " Processor: join-other-join (stores: [join-store])\n" + + " --> join-merge\n" + + " <-- join-other-windowed\n" + + " Processor: join-this-join (stores: [other-join-store])\n" + + " --> join-merge\n" + + " <-- join-this-windowed\n" + + " Processor: join-merge (stores: [])\n" + + " --> join-to\n" + + " <-- join-this-join, join-other-join\n" + + " Sink: count-to (topic: outputTopic_0)\n" + + " <-- count-toStream\n" + + " Sink: join-to (topic: joinedOutputTopic)\n" + + " <-- join-merge\n" + + "\n" + + " Sub-topology: 2\n" + + " Source: aggregate-groupByKey-repartition-source (topics: [aggregate-groupByKey-repartition])\n" + + " --> aggregate\n" + + " Processor: aggregate (stores: [aggregate-store])\n" + + " --> aggregate-toStream\n" + + " <-- aggregate-groupByKey-repartition-source\n" + + " Processor: aggregate-toStream (stores: [])\n" + + " --> reduce-to\n" + + " <-- aggregate\n" + + " Sink: reduce-to (topic: 
outputTopic_1)\n" + + " <-- aggregate-toStream\n" + + "\n" + + " Sub-topology: 3\n" + + " Source: reduce-groupByKey-repartition-source (topics: [reduce-groupByKey-repartition])\n" + + " --> reducer\n" + + " Processor: reducer (stores: [reduce-store])\n" + + " --> reduce-toStream\n" + + " <-- reduce-groupByKey-repartition-source\n" + + " Processor: reduce-toStream (stores: [])\n" + + " --> KSTREAM-SINK-0000000023\n" + + " <-- reducer\n" + + " Sink: KSTREAM-SINK-0000000023 (topic: outputTopic_2)\n" + + " <-- reduce-toStream\n\n"; + +} diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/RepartitionTopicConfigTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/RepartitionTopicConfigTest.java new file mode 100644 index 0000000..01fff7d --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/RepartitionTopicConfigTest.java @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.processor.internals; + +import org.junit.Test; + +import java.util.Collections; +import java.util.Optional; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; + +public class RepartitionTopicConfigTest { + + @Test + public void shouldThrowAnExceptionWhenSettingNumberOfPartitionsIfTheyAreEnforced() { + final String name = "my-topic"; + final RepartitionTopicConfig repartitionTopicConfig = new RepartitionTopicConfig(name, + Collections.emptyMap(), + 10, + true); + + final UnsupportedOperationException ex = assertThrows( + UnsupportedOperationException.class, + () -> repartitionTopicConfig.setNumberOfPartitions(2) + ); + + assertEquals(String.format("number of partitions are enforced on topic " + + "%s and can't be altered.", name), ex.getMessage()); + } + + @Test + public void shouldNotThrowAnExceptionWhenSettingNumberOfPartitionsIfTheyAreNotEnforced() { + final String name = "my-topic"; + final RepartitionTopicConfig repartitionTopicConfig = new RepartitionTopicConfig(name, + Collections.emptyMap(), + 10, + false); + + repartitionTopicConfig.setNumberOfPartitions(4); + + assertEquals(repartitionTopicConfig.numberOfPartitions(), Optional.of(4)); + } +} \ No newline at end of file diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/RepartitionTopicsTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/RepartitionTopicsTest.java new file mode 100644 index 0000000..ce94294 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/RepartitionTopicsTest.java @@ -0,0 +1,443 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.processor.internals; + +import org.apache.kafka.common.Cluster; +import org.apache.kafka.common.Node; +import org.apache.kafka.common.PartitionInfo; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.errors.MissingSourceTopicException; +import org.apache.kafka.streams.errors.TaskAssignmentException; +import org.apache.kafka.streams.processor.internals.InternalTopologyBuilder.TopicsInfo; +import org.apache.kafka.streams.processor.internals.assignment.CopartitionedTopicsEnforcer; +import org.apache.kafka.streams.processor.internals.testutil.DummyStreamsConfig; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.powermock.core.classloader.annotations.PrepareForTest; +import org.powermock.modules.junit4.PowerMockRunner; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import static org.apache.kafka.common.utils.Utils.mkEntry; +import static org.apache.kafka.common.utils.Utils.mkMap; +import static org.apache.kafka.common.utils.Utils.mkSet; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.SUBTOPOLOGY_0; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.SUBTOPOLOGY_1; + +import static org.easymock.EasyMock.anyObject; +import static org.easymock.EasyMock.eq; +import static org.easymock.EasyMock.expect; +import static org.easymock.EasyMock.mock; +import static org.easymock.EasyMock.niceMock; +import static org.easymock.EasyMock.replay; +import static org.easymock.EasyMock.verify; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.CoreMatchers.nullValue; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.Assert.assertThrows; + +@RunWith(PowerMockRunner.class) +@PrepareForTest({Cluster.class}) +public class RepartitionTopicsTest { + + private static final String SOURCE_TOPIC_NAME1 = "source1"; + private static final String SOURCE_TOPIC_NAME2 = "source2"; + private static final String SOURCE_TOPIC_NAME3 = "source3"; + private static final String SINK_TOPIC_NAME1 = "sink1"; + private static final String SINK_TOPIC_NAME2 = "sink2"; + private static final String REPARTITION_TOPIC_NAME1 = "repartition1"; + private static final String REPARTITION_TOPIC_NAME2 = "repartition2"; + private static final String REPARTITION_TOPIC_NAME3 = "repartition3"; + private static final String REPARTITION_TOPIC_NAME4 = "repartition4"; + private static final String REPARTITION_WITHOUT_PARTITION_COUNT = "repartitionWithoutPartitionCount"; + private static final String SOME_OTHER_TOPIC = "someOtherTopic"; + private static final Map TOPIC_CONFIG1 = Collections.singletonMap("config1", "val1"); + 
private static final Map TOPIC_CONFIG2 = Collections.singletonMap("config2", "val2"); + private static final Map TOPIC_CONFIG5 = Collections.singletonMap("config5", "val5"); + private static final RepartitionTopicConfig REPARTITION_TOPIC_CONFIG1 = + new RepartitionTopicConfig(REPARTITION_TOPIC_NAME1, TOPIC_CONFIG1, 4, true); + private static final RepartitionTopicConfig REPARTITION_TOPIC_CONFIG2 = + new RepartitionTopicConfig(REPARTITION_TOPIC_NAME2, TOPIC_CONFIG2, 2, true); + private static final TopicsInfo TOPICS_INFO1 = new TopicsInfo( + mkSet(REPARTITION_TOPIC_NAME1), + mkSet(SOURCE_TOPIC_NAME1, SOURCE_TOPIC_NAME2), + mkMap( + mkEntry(REPARTITION_TOPIC_NAME1, REPARTITION_TOPIC_CONFIG1), + mkEntry(REPARTITION_TOPIC_NAME2, REPARTITION_TOPIC_CONFIG2) + ), + Collections.emptyMap() + ); + private static final TopicsInfo TOPICS_INFO2 = new TopicsInfo( + mkSet(SINK_TOPIC_NAME1), + mkSet(REPARTITION_TOPIC_NAME1), + mkMap(mkEntry(REPARTITION_TOPIC_NAME1, REPARTITION_TOPIC_CONFIG1)), + Collections.emptyMap() + ); + final StreamsConfig config = new DummyStreamsConfig(); + + final InternalTopologyBuilder internalTopologyBuilder = mock(InternalTopologyBuilder.class); + final InternalTopicManager internalTopicManager = mock(InternalTopicManager.class); + final CopartitionedTopicsEnforcer copartitionedTopicsEnforcer = mock(CopartitionedTopicsEnforcer.class); + final Cluster clusterMetadata = niceMock(Cluster.class); + + @Test + public void shouldSetupRepartitionTopics() { + expect(internalTopologyBuilder.hasNamedTopology()).andStubReturn(false); + expect(internalTopologyBuilder.topicGroups()) + .andReturn(mkMap(mkEntry(SUBTOPOLOGY_0, TOPICS_INFO1), mkEntry(SUBTOPOLOGY_1, TOPICS_INFO2))); + final Set coPartitionGroup1 = mkSet(SOURCE_TOPIC_NAME1, SOURCE_TOPIC_NAME2); + final Set coPartitionGroup2 = mkSet(REPARTITION_TOPIC_NAME1, REPARTITION_TOPIC_NAME2); + final List> coPartitionGroups = Arrays.asList(coPartitionGroup1, coPartitionGroup2); + expect(internalTopologyBuilder.copartitionGroups()).andReturn(coPartitionGroups); + copartitionedTopicsEnforcer.enforce(eq(coPartitionGroup1), anyObject(), eq(clusterMetadata)); + copartitionedTopicsEnforcer.enforce(eq(coPartitionGroup2), anyObject(), eq(clusterMetadata)); + expect(internalTopicManager.makeReady( + mkMap( + mkEntry(REPARTITION_TOPIC_NAME1, REPARTITION_TOPIC_CONFIG1), + mkEntry(REPARTITION_TOPIC_NAME2, REPARTITION_TOPIC_CONFIG2) + )) + ).andReturn(Collections.emptySet()); + setupCluster(); + replay(internalTopicManager, internalTopologyBuilder, clusterMetadata); + final RepartitionTopics repartitionTopics = new RepartitionTopics( + new TopologyMetadata(internalTopologyBuilder, config), + internalTopicManager, + copartitionedTopicsEnforcer, + clusterMetadata, + "[test] " + ); + + repartitionTopics.setup(); + + verify(internalTopicManager, internalTopologyBuilder); + final Map topicPartitionsInfo = repartitionTopics.topicPartitionsInfo(); + assertThat(topicPartitionsInfo.size(), is(6)); + verifyRepartitionTopicPartitionInfo(topicPartitionsInfo, REPARTITION_TOPIC_NAME1, 0); + verifyRepartitionTopicPartitionInfo(topicPartitionsInfo, REPARTITION_TOPIC_NAME1, 1); + verifyRepartitionTopicPartitionInfo(topicPartitionsInfo, REPARTITION_TOPIC_NAME1, 2); + verifyRepartitionTopicPartitionInfo(topicPartitionsInfo, REPARTITION_TOPIC_NAME1, 3); + verifyRepartitionTopicPartitionInfo(topicPartitionsInfo, REPARTITION_TOPIC_NAME2, 0); + verifyRepartitionTopicPartitionInfo(topicPartitionsInfo, REPARTITION_TOPIC_NAME2, 1); + } + + @Test + public void 
shouldThrowMissingSourceTopicException() { + expect(internalTopologyBuilder.hasNamedTopology()).andStubReturn(false); + expect(internalTopologyBuilder.topicGroups()) + .andReturn(mkMap(mkEntry(SUBTOPOLOGY_0, TOPICS_INFO1), mkEntry(SUBTOPOLOGY_1, TOPICS_INFO2))); + expect(internalTopologyBuilder.copartitionGroups()).andReturn(Collections.emptyList()); + copartitionedTopicsEnforcer.enforce(eq(Collections.emptySet()), anyObject(), eq(clusterMetadata)); + expect(internalTopicManager.makeReady( + mkMap( + mkEntry(REPARTITION_TOPIC_NAME1, REPARTITION_TOPIC_CONFIG1) + )) + ).andReturn(Collections.emptySet()); + setupClusterWithMissingTopics(mkSet(SOURCE_TOPIC_NAME1)); + replay(internalTopicManager, internalTopologyBuilder, clusterMetadata); + final RepartitionTopics repartitionTopics = new RepartitionTopics( + new TopologyMetadata(internalTopologyBuilder, config), + internalTopicManager, + copartitionedTopicsEnforcer, + clusterMetadata, + "[test] " + ); + + assertThrows(MissingSourceTopicException.class, repartitionTopics::setup); + } + + @Test + public void shouldThrowTaskAssignmentExceptionIfPartitionCountCannotBeComputedForAllRepartitionTopics() { + final RepartitionTopicConfig repartitionTopicConfigWithoutPartitionCount = + new RepartitionTopicConfig(REPARTITION_WITHOUT_PARTITION_COUNT, TOPIC_CONFIG5); + expect(internalTopologyBuilder.hasNamedTopology()).andStubReturn(false); + expect(internalTopologyBuilder.topicGroups()) + .andReturn(mkMap( + mkEntry(SUBTOPOLOGY_0, TOPICS_INFO1), + mkEntry(SUBTOPOLOGY_1, setupTopicInfoWithRepartitionTopicWithoutPartitionCount(repartitionTopicConfigWithoutPartitionCount)) + )); + expect(internalTopologyBuilder.copartitionGroups()).andReturn(Collections.emptyList()); + copartitionedTopicsEnforcer.enforce(eq(Collections.emptySet()), anyObject(), eq(clusterMetadata)); + expect(internalTopicManager.makeReady( + mkMap( + mkEntry(REPARTITION_TOPIC_NAME1, REPARTITION_TOPIC_CONFIG1) + )) + ).andReturn(Collections.emptySet()); + setupCluster(); + replay(internalTopicManager, internalTopologyBuilder, clusterMetadata); + final RepartitionTopics repartitionTopics = new RepartitionTopics( + new TopologyMetadata(internalTopologyBuilder, config), + internalTopicManager, + copartitionedTopicsEnforcer, + clusterMetadata, + "[test] " + ); + + final TaskAssignmentException exception = assertThrows(TaskAssignmentException.class, repartitionTopics::setup); + assertThat(exception.getMessage(), is("Failed to compute number of partitions for all repartition topics, make sure all user input topics are created and all Pattern subscriptions match at least one topic in the cluster")); + } + + @Test + public void shouldThrowTaskAssignmentExceptionIfSourceTopicHasNoPartitionCount() { + final RepartitionTopicConfig repartitionTopicConfigWithoutPartitionCount = + new RepartitionTopicConfig(REPARTITION_WITHOUT_PARTITION_COUNT, TOPIC_CONFIG5); + final TopicsInfo topicsInfo = new TopicsInfo( + mkSet(REPARTITION_WITHOUT_PARTITION_COUNT), + mkSet(SOURCE_TOPIC_NAME1), + mkMap( + mkEntry(REPARTITION_WITHOUT_PARTITION_COUNT, repartitionTopicConfigWithoutPartitionCount) + ), + Collections.emptyMap() + ); + expect(internalTopologyBuilder.hasNamedTopology()).andStubReturn(false); + expect(internalTopologyBuilder.topicGroups()) + .andReturn(mkMap( + mkEntry(SUBTOPOLOGY_0, topicsInfo), + mkEntry(SUBTOPOLOGY_1, setupTopicInfoWithRepartitionTopicWithoutPartitionCount(repartitionTopicConfigWithoutPartitionCount)) + )); + 
expect(internalTopologyBuilder.copartitionGroups()).andReturn(Collections.emptyList()); + copartitionedTopicsEnforcer.enforce(eq(Collections.emptySet()), anyObject(), eq(clusterMetadata)); + expect(internalTopicManager.makeReady( + mkMap( + mkEntry(REPARTITION_WITHOUT_PARTITION_COUNT, repartitionTopicConfigWithoutPartitionCount) + )) + ).andReturn(Collections.emptySet()); + setupClusterWithMissingPartitionCounts(mkSet(SOURCE_TOPIC_NAME1)); + replay(internalTopicManager, internalTopologyBuilder, clusterMetadata); + final RepartitionTopics repartitionTopics = new RepartitionTopics( + new TopologyMetadata(internalTopologyBuilder, config), + internalTopicManager, + copartitionedTopicsEnforcer, + clusterMetadata, + "[test] " + ); + + final TaskAssignmentException exception = assertThrows(TaskAssignmentException.class, repartitionTopics::setup); + assertThat( + exception.getMessage(), + is("No partition count found for source topic " + SOURCE_TOPIC_NAME1 + ", but it should have been.") + ); + } + + @Test + public void shouldSetRepartitionTopicPartitionCountFromUpstreamExternalSourceTopic() { + final RepartitionTopicConfig repartitionTopicConfigWithoutPartitionCount = + new RepartitionTopicConfig(REPARTITION_WITHOUT_PARTITION_COUNT, TOPIC_CONFIG5); + final TopicsInfo topicsInfo = new TopicsInfo( + mkSet(REPARTITION_TOPIC_NAME1, REPARTITION_WITHOUT_PARTITION_COUNT), + mkSet(SOURCE_TOPIC_NAME1, REPARTITION_TOPIC_NAME2), + mkMap( + mkEntry(REPARTITION_TOPIC_NAME1, REPARTITION_TOPIC_CONFIG1), + mkEntry(REPARTITION_TOPIC_NAME2, REPARTITION_TOPIC_CONFIG2), + mkEntry(REPARTITION_WITHOUT_PARTITION_COUNT, repartitionTopicConfigWithoutPartitionCount) + ), + Collections.emptyMap() + ); + expect(internalTopologyBuilder.hasNamedTopology()).andStubReturn(false); + expect(internalTopologyBuilder.topicGroups()) + .andReturn(mkMap( + mkEntry(SUBTOPOLOGY_0, topicsInfo), + mkEntry(SUBTOPOLOGY_1, setupTopicInfoWithRepartitionTopicWithoutPartitionCount(repartitionTopicConfigWithoutPartitionCount)) + )); + expect(internalTopologyBuilder.copartitionGroups()).andReturn(Collections.emptyList()); + copartitionedTopicsEnforcer.enforce(eq(Collections.emptySet()), anyObject(), eq(clusterMetadata)); + expect(internalTopicManager.makeReady( + mkMap( + mkEntry(REPARTITION_TOPIC_NAME1, REPARTITION_TOPIC_CONFIG1), + mkEntry(REPARTITION_TOPIC_NAME2, REPARTITION_TOPIC_CONFIG2), + mkEntry(REPARTITION_WITHOUT_PARTITION_COUNT, repartitionTopicConfigWithoutPartitionCount) + )) + ).andReturn(Collections.emptySet()); + setupCluster(); + replay(internalTopicManager, internalTopologyBuilder, clusterMetadata); + final RepartitionTopics repartitionTopics = new RepartitionTopics( + new TopologyMetadata(internalTopologyBuilder, config), + internalTopicManager, + copartitionedTopicsEnforcer, + clusterMetadata, + "[test] " + ); + + repartitionTopics.setup(); + + verify(internalTopicManager, internalTopologyBuilder); + final Map topicPartitionsInfo = repartitionTopics.topicPartitionsInfo(); + assertThat(topicPartitionsInfo.size(), is(9)); + verifyRepartitionTopicPartitionInfo(topicPartitionsInfo, REPARTITION_TOPIC_NAME1, 0); + verifyRepartitionTopicPartitionInfo(topicPartitionsInfo, REPARTITION_TOPIC_NAME1, 1); + verifyRepartitionTopicPartitionInfo(topicPartitionsInfo, REPARTITION_TOPIC_NAME1, 2); + verifyRepartitionTopicPartitionInfo(topicPartitionsInfo, REPARTITION_TOPIC_NAME1, 3); + verifyRepartitionTopicPartitionInfo(topicPartitionsInfo, REPARTITION_TOPIC_NAME2, 0); + verifyRepartitionTopicPartitionInfo(topicPartitionsInfo, 
REPARTITION_TOPIC_NAME2, 1); + verifyRepartitionTopicPartitionInfo(topicPartitionsInfo, REPARTITION_WITHOUT_PARTITION_COUNT, 0); + verifyRepartitionTopicPartitionInfo(topicPartitionsInfo, REPARTITION_WITHOUT_PARTITION_COUNT, 1); + verifyRepartitionTopicPartitionInfo(topicPartitionsInfo, REPARTITION_WITHOUT_PARTITION_COUNT, 2); + } + + @Test + public void shouldSetRepartitionTopicPartitionCountFromUpstreamInternalRepartitionSourceTopic() { + final RepartitionTopicConfig repartitionTopicConfigWithoutPartitionCount = + new RepartitionTopicConfig(REPARTITION_WITHOUT_PARTITION_COUNT, TOPIC_CONFIG5); + final TopicsInfo topicsInfo = new TopicsInfo( + mkSet(REPARTITION_TOPIC_NAME2, REPARTITION_WITHOUT_PARTITION_COUNT), + mkSet(SOURCE_TOPIC_NAME1, REPARTITION_TOPIC_NAME1), + mkMap( + mkEntry(REPARTITION_TOPIC_NAME1, REPARTITION_TOPIC_CONFIG1), + mkEntry(REPARTITION_TOPIC_NAME2, REPARTITION_TOPIC_CONFIG2), + mkEntry(REPARTITION_WITHOUT_PARTITION_COUNT, repartitionTopicConfigWithoutPartitionCount) + ), + Collections.emptyMap() + ); + expect(internalTopologyBuilder.hasNamedTopology()).andStubReturn(false); + expect(internalTopologyBuilder.topicGroups()) + .andReturn(mkMap( + mkEntry(SUBTOPOLOGY_0, topicsInfo), + mkEntry(SUBTOPOLOGY_1, setupTopicInfoWithRepartitionTopicWithoutPartitionCount(repartitionTopicConfigWithoutPartitionCount)) + )); + expect(internalTopologyBuilder.copartitionGroups()).andReturn(Collections.emptyList()); + copartitionedTopicsEnforcer.enforce(eq(Collections.emptySet()), anyObject(), eq(clusterMetadata)); + expect(internalTopicManager.makeReady( + mkMap( + mkEntry(REPARTITION_TOPIC_NAME1, REPARTITION_TOPIC_CONFIG1), + mkEntry(REPARTITION_TOPIC_NAME2, REPARTITION_TOPIC_CONFIG2), + mkEntry(REPARTITION_WITHOUT_PARTITION_COUNT, repartitionTopicConfigWithoutPartitionCount) + )) + ).andReturn(Collections.emptySet()); + setupCluster(); + replay(internalTopicManager, internalTopologyBuilder, clusterMetadata); + final RepartitionTopics repartitionTopics = new RepartitionTopics( + new TopologyMetadata(internalTopologyBuilder, config), + internalTopicManager, + copartitionedTopicsEnforcer, + clusterMetadata, + "[test] " + ); + + repartitionTopics.setup(); + + verify(internalTopicManager, internalTopologyBuilder); + final Map topicPartitionsInfo = repartitionTopics.topicPartitionsInfo(); + assertThat(topicPartitionsInfo.size(), is(10)); + verifyRepartitionTopicPartitionInfo(topicPartitionsInfo, REPARTITION_TOPIC_NAME1, 0); + verifyRepartitionTopicPartitionInfo(topicPartitionsInfo, REPARTITION_TOPIC_NAME1, 1); + verifyRepartitionTopicPartitionInfo(topicPartitionsInfo, REPARTITION_TOPIC_NAME1, 2); + verifyRepartitionTopicPartitionInfo(topicPartitionsInfo, REPARTITION_TOPIC_NAME1, 3); + verifyRepartitionTopicPartitionInfo(topicPartitionsInfo, REPARTITION_TOPIC_NAME2, 0); + verifyRepartitionTopicPartitionInfo(topicPartitionsInfo, REPARTITION_TOPIC_NAME2, 1); + verifyRepartitionTopicPartitionInfo(topicPartitionsInfo, REPARTITION_WITHOUT_PARTITION_COUNT, 0); + verifyRepartitionTopicPartitionInfo(topicPartitionsInfo, REPARTITION_WITHOUT_PARTITION_COUNT, 1); + verifyRepartitionTopicPartitionInfo(topicPartitionsInfo, REPARTITION_WITHOUT_PARTITION_COUNT, 2); + verifyRepartitionTopicPartitionInfo(topicPartitionsInfo, REPARTITION_WITHOUT_PARTITION_COUNT, 3); + } + + @Test + public void shouldNotSetupRepartitionTopicsWhenTopologyDoesNotContainAnyRepartitionTopics() { + final TopicsInfo topicsInfo = new TopicsInfo( + mkSet(SINK_TOPIC_NAME1), + mkSet(SOURCE_TOPIC_NAME1), + Collections.emptyMap(), + 
Collections.emptyMap() + ); + expect(internalTopologyBuilder.hasNamedTopology()).andStubReturn(false); + expect(internalTopologyBuilder.topicGroups()) + .andReturn(mkMap(mkEntry(SUBTOPOLOGY_0, topicsInfo))); + expect(internalTopologyBuilder.copartitionGroups()).andReturn(Collections.emptySet()); + expect(internalTopicManager.makeReady(Collections.emptyMap())).andReturn(Collections.emptySet()); + setupCluster(); + replay(internalTopicManager, internalTopologyBuilder, clusterMetadata); + final RepartitionTopics repartitionTopics = new RepartitionTopics( + new TopologyMetadata(internalTopologyBuilder, config), + internalTopicManager, + copartitionedTopicsEnforcer, + clusterMetadata, + "[test] " + ); + + repartitionTopics.setup(); + + verify(internalTopicManager, internalTopologyBuilder); + final Map topicPartitionsInfo = repartitionTopics.topicPartitionsInfo(); + assertThat(topicPartitionsInfo, is(Collections.emptyMap())); + } + + private void verifyRepartitionTopicPartitionInfo(final Map topicPartitionsInfo, + final String topic, + final int partition) { + final TopicPartition repartitionTopicPartition = new TopicPartition(topic, partition); + assertThat(topicPartitionsInfo.containsKey(repartitionTopicPartition), is(true)); + final PartitionInfo repartitionTopicInfo = topicPartitionsInfo.get(repartitionTopicPartition); + assertThat(repartitionTopicInfo.topic(), is(topic)); + assertThat(repartitionTopicInfo.partition(), is(partition)); + assertThat(repartitionTopicInfo.inSyncReplicas(), is(new Node[0])); + assertThat(repartitionTopicInfo.leader(), nullValue()); + assertThat(repartitionTopicInfo.offlineReplicas(), is(new Node[0])); + assertThat(repartitionTopicInfo.replicas(), is(new Node[0])); + } + + private void setupCluster() { + setupClusterWithMissingTopicsAndMissingPartitionCounts(Collections.emptySet(), Collections.emptySet()); + } + + private void setupClusterWithMissingTopics(final Set missingTopics) { + setupClusterWithMissingTopicsAndMissingPartitionCounts(missingTopics, Collections.emptySet()); + } + + private void setupClusterWithMissingPartitionCounts(final Set topicsWithMissingPartitionCounts) { + setupClusterWithMissingTopicsAndMissingPartitionCounts(Collections.emptySet(), topicsWithMissingPartitionCounts); + } + + private void setupClusterWithMissingTopicsAndMissingPartitionCounts(final Set missingTopics, + final Set topicsWithMissingPartitionCounts) { + final Set topics = mkSet( + SOURCE_TOPIC_NAME1, + SOURCE_TOPIC_NAME2, + SOURCE_TOPIC_NAME3, + SINK_TOPIC_NAME1, + SINK_TOPIC_NAME2, + REPARTITION_TOPIC_NAME1, + REPARTITION_TOPIC_NAME2, + REPARTITION_TOPIC_NAME3, + REPARTITION_TOPIC_NAME4, + SOME_OTHER_TOPIC + ); + topics.removeAll(missingTopics); + expect(clusterMetadata.topics()).andStubReturn(topics); + expect(clusterMetadata.partitionCountForTopic(SOURCE_TOPIC_NAME1)) + .andStubReturn(topicsWithMissingPartitionCounts.contains(SOURCE_TOPIC_NAME1) ? null : 3); + expect(clusterMetadata.partitionCountForTopic(SOURCE_TOPIC_NAME2)) + .andStubReturn(topicsWithMissingPartitionCounts.contains(SOURCE_TOPIC_NAME2) ? null : 1); + expect(clusterMetadata.partitionCountForTopic(SOURCE_TOPIC_NAME3)) + .andStubReturn(topicsWithMissingPartitionCounts.contains(SOURCE_TOPIC_NAME3) ? 
null : 2); + } + + private TopicsInfo setupTopicInfoWithRepartitionTopicWithoutPartitionCount(final RepartitionTopicConfig repartitionTopicConfigWithoutPartitionCount) { + return new TopicsInfo( + mkSet(SINK_TOPIC_NAME2), + mkSet(REPARTITION_TOPIC_NAME1, REPARTITION_WITHOUT_PARTITION_COUNT), + mkMap( + mkEntry(REPARTITION_TOPIC_NAME1, REPARTITION_TOPIC_CONFIG1), + mkEntry(REPARTITION_WITHOUT_PARTITION_COUNT, repartitionTopicConfigWithoutPartitionCount) + ), + Collections.emptyMap() + ); + } +} \ No newline at end of file diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/RepartitionWithMergeOptimizingTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/RepartitionWithMergeOptimizingTest.java new file mode 100644 index 0000000..dbc855d --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/RepartitionWithMergeOptimizingTest.java @@ -0,0 +1,303 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.streams.processor.internals; + +import java.util.HashMap; +import java.util.Map; +import org.apache.kafka.common.serialization.Deserializer; +import org.apache.kafka.common.serialization.LongDeserializer; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.serialization.Serializer; +import org.apache.kafka.common.serialization.StringDeserializer; +import org.apache.kafka.common.serialization.StringSerializer; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.TestInputTopic; +import org.apache.kafka.streams.TestOutputTopic; +import org.apache.kafka.streams.Topology; +import org.apache.kafka.streams.TopologyTestDriver; +import org.apache.kafka.streams.kstream.Consumed; +import org.apache.kafka.streams.kstream.Grouped; +import org.apache.kafka.streams.kstream.KStream; +import org.apache.kafka.streams.kstream.Materialized; +import org.apache.kafka.streams.kstream.Named; +import org.apache.kafka.streams.kstream.Produced; +import org.apache.kafka.streams.state.Stores; +import org.apache.kafka.test.StreamsTestUtils; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Properties; +import java.util.regex.Matcher; +import java.util.regex.Pattern; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.Assert.assertEquals; + +public class RepartitionWithMergeOptimizingTest { + + private final Logger log = 
LoggerFactory.getLogger(RepartitionWithMergeOptimizingTest.class); + + private static final String INPUT_A_TOPIC = "inputA"; + private static final String INPUT_B_TOPIC = "inputB"; + private static final String COUNT_TOPIC = "outputTopic_0"; + private static final String STRING_COUNT_TOPIC = "outputTopic_1"; + + private static final int ONE_REPARTITION_TOPIC = 1; + private static final int TWO_REPARTITION_TOPICS = 2; + + private final Serializer stringSerializer = new StringSerializer(); + private final Deserializer stringDeserializer = new StringDeserializer(); + + private final Pattern repartitionTopicPattern = Pattern.compile("Sink: .*-repartition"); + + private Properties streamsConfiguration; + private TopologyTestDriver topologyTestDriver; + + private final List> expectedCountKeyValues = + Arrays.asList(KeyValue.pair("A", 6L), KeyValue.pair("B", 6L), KeyValue.pair("C", 6L)); + private final List> expectedStringCountKeyValues = + Arrays.asList(KeyValue.pair("A", "6"), KeyValue.pair("B", "6"), KeyValue.pair("C", "6")); + + @Before + public void setUp() { + streamsConfiguration = StreamsTestUtils.getStreamsConfig(Serdes.String(), Serdes.String()); + streamsConfiguration.setProperty(StreamsConfig.CACHE_MAX_BYTES_BUFFERING_CONFIG, Integer.toString(1024 * 10)); + streamsConfiguration.setProperty(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG, Long.toString(5000)); + } + + @After + public void tearDown() { + try { + topologyTestDriver.close(); + } catch (final RuntimeException e) { + log.warn("The following exception was thrown while trying to close the TopologyTestDriver (note that " + + "KAFKA-6647 causes this when running on Windows):", e); + } + } + + @Test + public void shouldSendCorrectRecords_OPTIMIZED() { + runTest(StreamsConfig.OPTIMIZE, ONE_REPARTITION_TOPIC); + } + + @Test + public void shouldSendCorrectResults_NO_OPTIMIZATION() { + runTest(StreamsConfig.NO_OPTIMIZATION, TWO_REPARTITION_TOPICS); + } + + + private void runTest(final String optimizationConfig, final int expectedNumberRepartitionTopics) { + + streamsConfiguration.setProperty(StreamsConfig.TOPOLOGY_OPTIMIZATION_CONFIG, optimizationConfig); + + final StreamsBuilder builder = new StreamsBuilder(); + + final KStream sourceAStream = + builder.stream(INPUT_A_TOPIC, Consumed.with(Serdes.String(), Serdes.String()).withName("sourceAStream")); + + final KStream sourceBStream = + builder.stream(INPUT_B_TOPIC, Consumed.with(Serdes.String(), Serdes.String()).withName("sourceBStream")); + + final KStream mappedAStream = + sourceAStream.map((k, v) -> KeyValue.pair(v.split(":")[0], v), Named.as("mappedAStream")); + final KStream mappedBStream = + sourceBStream.map((k, v) -> KeyValue.pair(v.split(":")[0], v), Named.as("mappedBStream")); + + final KStream mergedStream = mappedAStream.merge(mappedBStream, Named.as("mergedStream")); + + mergedStream + .groupByKey(Grouped.as("long-groupByKey")) + .count(Named.as("long-count"), Materialized.as(Stores.inMemoryKeyValueStore("long-store"))) + .toStream(Named.as("long-toStream")) + .to(COUNT_TOPIC, Produced.with(Serdes.String(), Serdes.Long()).withName("long-to")); + + mergedStream + .groupByKey(Grouped.as("string-groupByKey")) + .count(Named.as("string-count"), Materialized.as(Stores.inMemoryKeyValueStore("string-store"))) + .toStream(Named.as("string-toStream")) + .mapValues(v -> v.toString(), Named.as("string-mapValues")) + .to(STRING_COUNT_TOPIC, Produced.with(Serdes.String(), Serdes.String()).withName("string-to")); + + final Topology topology = builder.build(streamsConfiguration); + + 
topologyTestDriver = new TopologyTestDriver(topology, streamsConfiguration); + + final TestInputTopic inputTopicA = topologyTestDriver.createInputTopic(INPUT_A_TOPIC, stringSerializer, stringSerializer); + final TestInputTopic inputTopicB = topologyTestDriver.createInputTopic(INPUT_B_TOPIC, stringSerializer, stringSerializer); + + final TestOutputTopic countOutputTopic = topologyTestDriver.createOutputTopic(COUNT_TOPIC, stringDeserializer, new LongDeserializer()); + final TestOutputTopic stringCountOutputTopic = topologyTestDriver.createOutputTopic(STRING_COUNT_TOPIC, stringDeserializer, stringDeserializer); + + inputTopicA.pipeKeyValueList(getKeyValues()); + inputTopicB.pipeKeyValueList(getKeyValues()); + + final String topologyString = topology.describe().toString(); + + // Verify the topology + if (optimizationConfig.equals(StreamsConfig.OPTIMIZE)) { + assertEquals(EXPECTED_OPTIMIZED_TOPOLOGY, topologyString); + } else { + assertEquals(EXPECTED_UNOPTIMIZED_TOPOLOGY, topologyString); + } + + // Verify the number of repartition topics + assertEquals(expectedNumberRepartitionTopics, getCountOfRepartitionTopicsFound(topologyString)); + + // Verify the expected output + assertThat(countOutputTopic.readKeyValuesToMap(), equalTo(keyValueListToMap(expectedCountKeyValues))); + assertThat(stringCountOutputTopic.readKeyValuesToMap(), equalTo(keyValueListToMap(expectedStringCountKeyValues))); + } + + private Map keyValueListToMap(final List> keyValuePairs) { + final Map map = new HashMap<>(); + for (final KeyValue pair : keyValuePairs) { + map.put(pair.key, pair.value); + } + return map; + } + + private int getCountOfRepartitionTopicsFound(final String topologyString) { + final Matcher matcher = repartitionTopicPattern.matcher(topologyString); + final List repartitionTopicsFound = new ArrayList<>(); + while (matcher.find()) { + repartitionTopicsFound.add(matcher.group()); + } + return repartitionTopicsFound.size(); + } + + private List> getKeyValues() { + final List> keyValueList = new ArrayList<>(); + final String[] keys = new String[]{"X", "Y", "Z"}; + final String[] values = new String[]{"A:foo", "B:foo", "C:foo"}; + for (final String key : keys) { + for (final String value : values) { + keyValueList.add(KeyValue.pair(key, value)); + } + } + return keyValueList; + } + + private static final String EXPECTED_OPTIMIZED_TOPOLOGY = "Topologies:\n" + + " Sub-topology: 0\n" + + " Source: sourceAStream (topics: [inputA])\n" + + " --> mappedAStream\n" + + " Source: sourceBStream (topics: [inputB])\n" + + " --> mappedBStream\n" + + " Processor: mappedAStream (stores: [])\n" + + " --> mergedStream\n" + + " <-- sourceAStream\n" + + " Processor: mappedBStream (stores: [])\n" + + " --> mergedStream\n" + + " <-- sourceBStream\n" + + " Processor: mergedStream (stores: [])\n" + + " --> long-groupByKey-repartition-filter\n" + + " <-- mappedAStream, mappedBStream\n" + + " Processor: long-groupByKey-repartition-filter (stores: [])\n" + + " --> long-groupByKey-repartition-sink\n" + + " <-- mergedStream\n" + + " Sink: long-groupByKey-repartition-sink (topic: long-groupByKey-repartition)\n" + + " <-- long-groupByKey-repartition-filter\n" + + "\n" + + " Sub-topology: 1\n" + + " Source: long-groupByKey-repartition-source (topics: [long-groupByKey-repartition])\n" + + " --> long-count, string-count\n" + + " Processor: string-count (stores: [string-store])\n" + + " --> string-toStream\n" + + " <-- long-groupByKey-repartition-source\n" + + " Processor: long-count (stores: [long-store])\n" + + " --> long-toStream\n" + + " 
<-- long-groupByKey-repartition-source\n" + + " Processor: string-toStream (stores: [])\n" + + " --> string-mapValues\n" + + " <-- string-count\n" + + " Processor: long-toStream (stores: [])\n" + + " --> long-to\n" + + " <-- long-count\n" + + " Processor: string-mapValues (stores: [])\n" + + " --> string-to\n" + + " <-- string-toStream\n" + + " Sink: long-to (topic: outputTopic_0)\n" + + " <-- long-toStream\n" + + " Sink: string-to (topic: outputTopic_1)\n" + + " <-- string-mapValues\n\n"; + + + private static final String EXPECTED_UNOPTIMIZED_TOPOLOGY = "Topologies:\n" + + " Sub-topology: 0\n" + + " Source: sourceAStream (topics: [inputA])\n" + + " --> mappedAStream\n" + + " Source: sourceBStream (topics: [inputB])\n" + + " --> mappedBStream\n" + + " Processor: mappedAStream (stores: [])\n" + + " --> mergedStream\n" + + " <-- sourceAStream\n" + + " Processor: mappedBStream (stores: [])\n" + + " --> mergedStream\n" + + " <-- sourceBStream\n" + + " Processor: mergedStream (stores: [])\n" + + " --> long-groupByKey-repartition-filter, string-groupByKey-repartition-filter\n" + + " <-- mappedAStream, mappedBStream\n" + + " Processor: long-groupByKey-repartition-filter (stores: [])\n" + + " --> long-groupByKey-repartition-sink\n" + + " <-- mergedStream\n" + + " Processor: string-groupByKey-repartition-filter (stores: [])\n" + + " --> string-groupByKey-repartition-sink\n" + + " <-- mergedStream\n" + + " Sink: long-groupByKey-repartition-sink (topic: long-groupByKey-repartition)\n" + + " <-- long-groupByKey-repartition-filter\n" + + " Sink: string-groupByKey-repartition-sink (topic: string-groupByKey-repartition)\n" + + " <-- string-groupByKey-repartition-filter\n" + + "\n" + + " Sub-topology: 1\n" + + " Source: long-groupByKey-repartition-source (topics: [long-groupByKey-repartition])\n" + + " --> long-count\n" + + " Processor: long-count (stores: [long-store])\n" + + " --> long-toStream\n" + + " <-- long-groupByKey-repartition-source\n" + + " Processor: long-toStream (stores: [])\n" + + " --> long-to\n" + + " <-- long-count\n" + + " Sink: long-to (topic: outputTopic_0)\n" + + " <-- long-toStream\n" + + "\n" + + " Sub-topology: 2\n" + + " Source: string-groupByKey-repartition-source (topics: [string-groupByKey-repartition])\n" + + " --> string-count\n" + + " Processor: string-count (stores: [string-store])\n" + + " --> string-toStream\n" + + " <-- string-groupByKey-repartition-source\n" + + " Processor: string-toStream (stores: [])\n" + + " --> string-mapValues\n" + + " <-- string-count\n" + + " Processor: string-mapValues (stores: [])\n" + + " --> string-to\n" + + " <-- string-toStream\n" + + " Sink: string-to (topic: outputTopic_1)\n" + + " <-- string-mapValues\n\n"; + +} diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/SinkNodeTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/SinkNodeTest.java new file mode 100644 index 0000000..7e7f7b8 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/SinkNodeTest.java @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.processor.internals; + +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.serialization.Serializer; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.streams.errors.StreamsException; +import org.apache.kafka.streams.processor.api.Record; +import org.apache.kafka.streams.state.StateSerdes; +import org.apache.kafka.test.InternalMockProcessorContext; +import org.apache.kafka.test.MockRecordCollector; +import org.junit.Before; +import org.junit.Test; + +import static org.junit.Assert.fail; + +public class SinkNodeTest { + private final StateSerdes anyStateSerde = StateSerdes.withBuiltinTypes("anyName", Bytes.class, Bytes.class); + private final Serializer anySerializer = Serdes.ByteArray().serializer(); + private final RecordCollector recordCollector = new MockRecordCollector(); + private final InternalMockProcessorContext context = new InternalMockProcessorContext<>(anyStateSerde, recordCollector); + private final SinkNode sink = new SinkNode<>("anyNodeName", + new StaticTopicNameExtractor<>("any-output-topic"), anySerializer, anySerializer, null); + + // Used to verify that the correct exceptions are thrown if the compiler checks are bypassed + @SuppressWarnings({"unchecked", "rawtypes"}) + private final SinkNode illTypedSink = (SinkNode) sink; + + @Before + public void before() { + sink.init(context); + } + + @Test + public void shouldThrowStreamsExceptionOnInputRecordWithInvalidTimestamp() { + // When/Then + context.setTime(-1); // ensures a negative timestamp is set for the record we send next + try { + illTypedSink.process(new Record<>("any key".getBytes(), "any value".getBytes(), -1)); + fail("Should have thrown StreamsException"); + } catch (final StreamsException ignored) { + // expected + } + } + +} diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/SourceNodeTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/SourceNodeTest.java new file mode 100644 index 0000000..03f22a3 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/SourceNodeTest.java @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.processor.internals; + +import org.apache.kafka.common.header.Headers; +import org.apache.kafka.common.header.internals.RecordHeaders; +import org.apache.kafka.common.metrics.Metrics; +import org.apache.kafka.common.metrics.Sensor; +import org.apache.kafka.common.metrics.SensorAccessor; +import org.apache.kafka.common.serialization.Deserializer; +import org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl; +import org.apache.kafka.test.InternalMockProcessorContext; +import org.apache.kafka.test.MockSourceNode; +import org.apache.kafka.test.StreamsTestUtils; +import org.junit.Test; + +import java.nio.charset.StandardCharsets; +import java.util.Map; +import java.util.stream.Collectors; + +import static org.apache.kafka.common.utils.Utils.mkEntry; +import static org.apache.kafka.common.utils.Utils.mkMap; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.contains; +import static org.junit.Assert.assertTrue; + +public class SourceNodeTest { + @Test + public void shouldProvideTopicHeadersAndDataToKeyDeserializer() { + final SourceNode sourceNode = new MockSourceNode<>(new TheDeserializer(), new TheDeserializer()); + final RecordHeaders headers = new RecordHeaders(); + final String deserializeKey = sourceNode.deserializeKey("topic", headers, "data".getBytes(StandardCharsets.UTF_8)); + assertThat(deserializeKey, is("topic" + headers + "data")); + } + + @Test + public void shouldProvideTopicHeadersAndDataToValueDeserializer() { + final SourceNode sourceNode = new MockSourceNode<>(new TheDeserializer(), new TheDeserializer()); + final RecordHeaders headers = new RecordHeaders(); + final String deserializedValue = sourceNode.deserializeValue("topic", headers, "data".getBytes(StandardCharsets.UTF_8)); + assertThat(deserializedValue, is("topic" + headers + "data")); + } + + public static class TheDeserializer implements Deserializer { + @Override + public String deserialize(final String topic, final Headers headers, final byte[] data) { + return topic + headers + new String(data, StandardCharsets.UTF_8); + } + + @Override + public String deserialize(final String topic, final byte[] data) { + return deserialize(topic, null, data); + } + } + + @Test + public void shouldExposeProcessMetrics() { + final Metrics metrics = new Metrics(); + final StreamsMetricsImpl streamsMetrics = + new StreamsMetricsImpl(metrics, "test-client", StreamsConfig.METRICS_LATEST, new MockTime()); + final InternalMockProcessorContext context = new InternalMockProcessorContext<>(streamsMetrics); + final SourceNode node = + new SourceNode<>(context.currentNode().name(), new TheDeserializer(), new TheDeserializer()); + node.init(context); + + final String threadId = Thread.currentThread().getName(); + final String groupName = "stream-processor-node-metrics"; + final Map metricTags = mkMap( + mkEntry("thread-id", threadId), + mkEntry("task-id", context.taskId().toString()), + mkEntry("processor-node-id", node.name()) + ); + + assertTrue(StreamsTestUtils.containsMetric(metrics, "process-rate", groupName, metricTags)); + assertTrue(StreamsTestUtils.containsMetric(metrics, "process-total", groupName, metricTags)); + + // test parent sensors + final String parentGroupName = "stream-task-metrics"; + metricTags.remove("processor-node-id"); + assertTrue(StreamsTestUtils.containsMetric(metrics, 
"process-rate", parentGroupName, metricTags)); + assertTrue(StreamsTestUtils.containsMetric(metrics, "process-total", parentGroupName, metricTags)); + + final String sensorNamePrefix = "internal." + threadId + ".task." + context.taskId().toString(); + final Sensor processSensor = + metrics.getSensor(sensorNamePrefix + ".node." + context.currentNode().name() + ".s.process"); + final SensorAccessor sensorAccessor = new SensorAccessor(processSensor); + assertThat( + sensorAccessor.parents().stream().map(Sensor::name).collect(Collectors.toList()), + contains(sensorNamePrefix + ".s.process") + ); + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StandbyTaskTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StandbyTaskTest.java new file mode 100644 index 0000000..50f4c33 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StandbyTaskTest.java @@ -0,0 +1,638 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.processor.internals; + +import org.apache.kafka.common.MetricName; +import org.apache.kafka.common.Node; +import org.apache.kafka.common.PartitionInfo; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.errors.TimeoutException; +import org.apache.kafka.common.metrics.KafkaMetric; +import org.apache.kafka.common.metrics.Metrics; +import org.apache.kafka.common.metrics.Sensor; +import org.apache.kafka.common.metrics.stats.CumulativeSum; +import org.apache.kafka.common.serialization.IntegerSerializer; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.errors.LockException; +import org.apache.kafka.streams.errors.ProcessorStateException; +import org.apache.kafka.streams.errors.StreamsException; +import org.apache.kafka.streams.processor.TaskId; +import org.apache.kafka.streams.processor.internals.Task.TaskType; +import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl; +import org.apache.kafka.streams.state.internals.ThreadCache; +import org.apache.kafka.test.MockKeyValueStore; +import org.apache.kafka.test.MockKeyValueStoreBuilder; +import org.apache.kafka.test.MockRestoreConsumer; +import org.apache.kafka.test.MockTimestampExtractor; +import org.apache.kafka.test.TestUtils; +import org.easymock.EasyMock; +import org.easymock.EasyMockRunner; +import org.easymock.Mock; +import org.easymock.MockType; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; + +import java.io.File; +import java.io.IOException; +import java.time.Duration; +import 
java.util.Collections; +import java.util.List; +import java.util.stream.Collectors; + +import static java.util.Arrays.asList; +import static org.apache.kafka.common.utils.Utils.mkEntry; +import static org.apache.kafka.common.utils.Utils.mkMap; +import static org.apache.kafka.common.utils.Utils.mkProperties; +import static org.apache.kafka.streams.processor.internals.Task.State.CREATED; +import static org.apache.kafka.streams.processor.internals.Task.State.RUNNING; +import static org.apache.kafka.streams.processor.internals.Task.State.SUSPENDED; +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.empty; +import static org.hamcrest.Matchers.isA; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; + +@RunWith(EasyMockRunner.class) +public class StandbyTaskTest { + + private final String threadName = "threadName"; + private final String threadId = Thread.currentThread().getName(); + private final TaskId taskId = new TaskId(0, 0, "My-Topology"); + + private final String storeName1 = "store1"; + private final String storeName2 = "store2"; + private final String applicationId = "test-application"; + private final String storeChangelogTopicName1 = ProcessorStateManager.storeChangelogTopic(applicationId, storeName1, taskId.topologyName()); + private final String storeChangelogTopicName2 = ProcessorStateManager.storeChangelogTopic(applicationId, storeName2, taskId.topologyName()); + + private final TopicPartition partition = new TopicPartition(storeChangelogTopicName1, 0); + private final MockKeyValueStore store1 = (MockKeyValueStore) new MockKeyValueStoreBuilder(storeName1, false).build(); + private final MockKeyValueStore store2 = (MockKeyValueStore) new MockKeyValueStoreBuilder(storeName2, true).build(); + + private final ProcessorTopology topology = ProcessorTopologyFactories.withLocalStores( + asList(store1, store2), + mkMap(mkEntry(storeName1, storeChangelogTopicName1), mkEntry(storeName2, storeChangelogTopicName2)) + ); + private final StreamsMetricsImpl streamsMetrics = + new StreamsMetricsImpl(new Metrics(), threadName, StreamsConfig.METRICS_LATEST, new MockTime()); + + private File baseDir; + private StreamsConfig config; + private StateDirectory stateDirectory; + private StandbyTask task; + + private StreamsConfig createConfig(final File baseDir) throws IOException { + return new StreamsConfig(mkProperties(mkMap( + mkEntry(StreamsConfig.APPLICATION_ID_CONFIG, applicationId), + mkEntry(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:2171"), + mkEntry(StreamsConfig.BUFFERED_RECORDS_PER_PARTITION_CONFIG, "3"), + mkEntry(StreamsConfig.STATE_DIR_CONFIG, baseDir.getCanonicalPath()), + mkEntry(StreamsConfig.DEFAULT_TIMESTAMP_EXTRACTOR_CLASS_CONFIG, MockTimestampExtractor.class.getName()) + ))); + } + + private final MockRestoreConsumer restoreStateConsumer = new MockRestoreConsumer<>( + new IntegerSerializer(), + new IntegerSerializer() + ); + + @Mock(type = MockType.NICE) + private ProcessorStateManager stateManager; + + @Before + public void setup() throws Exception { + EasyMock.expect(stateManager.taskId()).andStubReturn(taskId); + EasyMock.expect(stateManager.taskType()).andStubReturn(TaskType.STANDBY); + + restoreStateConsumer.reset(); + restoreStateConsumer.updatePartitions(storeChangelogTopicName1, asList( + new PartitionInfo(storeChangelogTopicName1, 0, Node.noNode(), 
new Node[0], new Node[0]), + new PartitionInfo(storeChangelogTopicName1, 1, Node.noNode(), new Node[0], new Node[0]), + new PartitionInfo(storeChangelogTopicName1, 2, Node.noNode(), new Node[0], new Node[0]) + )); + + restoreStateConsumer.updatePartitions(storeChangelogTopicName2, asList( + new PartitionInfo(storeChangelogTopicName2, 0, Node.noNode(), new Node[0], new Node[0]), + new PartitionInfo(storeChangelogTopicName2, 1, Node.noNode(), new Node[0], new Node[0]), + new PartitionInfo(storeChangelogTopicName2, 2, Node.noNode(), new Node[0], new Node[0]) + )); + baseDir = TestUtils.tempDirectory(); + config = createConfig(baseDir); + stateDirectory = new StateDirectory(config, new MockTime(), true, true); + } + + @After + public void cleanup() throws IOException { + if (task != null) { + try { + task.suspend(); + } catch (final IllegalStateException maybeSwallow) { + if (!maybeSwallow.getMessage().startsWith("Illegal state CLOSED while suspending standby task")) { + throw maybeSwallow; + } + } + task.closeDirty(); + task = null; + } + Utils.delete(baseDir); + } + + @Test + public void shouldThrowLockExceptionIfFailedToLockStateDirectory() throws IOException { + stateDirectory = EasyMock.createNiceMock(StateDirectory.class); + EasyMock.expect(stateDirectory.lock(taskId)).andReturn(false); + EasyMock.expect(stateManager.taskType()).andStubReturn(TaskType.STANDBY); + + EasyMock.replay(stateDirectory, stateManager); + + task = createStandbyTask(); + + assertThrows(LockException.class, () -> task.initializeIfNeeded()); + task = null; + } + + @Test + public void shouldTransitToRunningAfterInitialization() { + EasyMock.expect(stateManager.changelogOffsets()).andStubReturn(Collections.emptyMap()); + stateManager.registerStateStores(EasyMock.anyObject(), EasyMock.anyObject()); + EasyMock.expect(stateManager.changelogPartitions()).andReturn(Collections.emptySet()).anyTimes(); + EasyMock.replay(stateManager); + + task = createStandbyTask(); + + assertEquals(CREATED, task.state()); + + task.initializeIfNeeded(); + + assertEquals(RUNNING, task.state()); + + // initialize should be idempotent + task.initializeIfNeeded(); + + assertEquals(RUNNING, task.state()); + + EasyMock.verify(stateManager); + } + + @Test + public void shouldThrowIfCommittingOnIllegalState() { + EasyMock.replay(stateManager); + task = createStandbyTask(); + task.suspend(); + task.closeClean(); + + assertThrows(IllegalStateException.class, task::prepareCommit); + } + + @Test + public void shouldFlushAndCheckpointStateManagerOnCommit() { + EasyMock.expect(stateManager.changelogOffsets()).andStubReturn(Collections.emptyMap()); + stateManager.flush(); + EasyMock.expectLastCall(); + stateManager.checkpoint(); + EasyMock.expectLastCall().once(); + EasyMock.expect(stateManager.changelogOffsets()) + .andReturn(Collections.singletonMap(partition, 50L)) + .andReturn(Collections.singletonMap(partition, 11000L)) + .andReturn(Collections.singletonMap(partition, 11000L)); + EasyMock.expect(stateManager.changelogPartitions()).andReturn(Collections.singleton(partition)).anyTimes(); + EasyMock.replay(stateManager); + + task = createStandbyTask(); + task.initializeIfNeeded(); + task.prepareCommit(); + task.postCommit(false); // this should not checkpoint + + task.prepareCommit(); + task.postCommit(false); // this should checkpoint + + task.prepareCommit(); + task.postCommit(false); // this should not checkpoint + + EasyMock.verify(stateManager); + } + + @Test + public void shouldReturnStateManagerChangelogOffsets() { + 
EasyMock.expect(stateManager.changelogOffsets()).andReturn(Collections.singletonMap(partition, 50L)); + EasyMock.replay(stateManager); + + task = createStandbyTask(); + + assertEquals(Collections.singletonMap(partition, 50L), task.changelogOffsets()); + + EasyMock.verify(stateManager); + } + + @Test + public void shouldNotFlushAndThrowOnCloseDirty() { + EasyMock.expect(stateManager.changelogOffsets()).andStubReturn(Collections.emptyMap()); + stateManager.close(); + EasyMock.expectLastCall().andThrow(new ProcessorStateException("KABOOM!")).anyTimes(); + stateManager.flush(); + EasyMock.expectLastCall().andThrow(new AssertionError("Flush should not be called")).anyTimes(); + stateManager.checkpoint(); + EasyMock.expectLastCall().andThrow(new AssertionError("Checkpoint should not be called")).anyTimes(); + EasyMock.expect(stateManager.changelogPartitions()).andReturn(Collections.emptySet()).anyTimes(); + EasyMock.replay(stateManager); + final MetricName metricName = setupCloseTaskMetric(); + + task = createStandbyTask(); + task.initializeIfNeeded(); + task.suspend(); + task.closeDirty(); + + assertEquals(Task.State.CLOSED, task.state()); + + final double expectedCloseTaskMetric = 1.0; + verifyCloseTaskMetric(expectedCloseTaskMetric, streamsMetrics, metricName); + + EasyMock.verify(stateManager); + } + + @Test + public void shouldNotThrowFromStateManagerCloseInCloseDirty() { + EasyMock.expect(stateManager.changelogOffsets()).andStubReturn(Collections.emptyMap()); + stateManager.close(); + EasyMock.expectLastCall().andThrow(new RuntimeException("KABOOM!")).anyTimes(); + EasyMock.replay(stateManager); + + task = createStandbyTask(); + task.initializeIfNeeded(); + + task.suspend(); + task.closeDirty(); + + EasyMock.verify(stateManager); + } + + @Test + public void shouldSuspendAndCommitBeforeCloseClean() { + stateManager.close(); + EasyMock.expectLastCall(); + stateManager.checkpoint(); + EasyMock.expectLastCall().once(); + EasyMock.expect(stateManager.changelogOffsets()) + .andReturn(Collections.singletonMap(partition, 60L)); + EasyMock.expect(stateManager.changelogPartitions()).andReturn(Collections.singleton(partition)).anyTimes(); + EasyMock.replay(stateManager); + final MetricName metricName = setupCloseTaskMetric(); + + task = createStandbyTask(); + task.initializeIfNeeded(); + task.suspend(); + task.prepareCommit(); + task.postCommit(true); + task.closeClean(); + + assertEquals(Task.State.CLOSED, task.state()); + + final double expectedCloseTaskMetric = 1.0; + verifyCloseTaskMetric(expectedCloseTaskMetric, streamsMetrics, metricName); + + EasyMock.verify(stateManager); + } + + @Test + public void shouldRequireSuspendingCreatedTasksBeforeClose() { + EasyMock.replay(stateManager); + task = createStandbyTask(); + assertThat(task.state(), equalTo(CREATED)); + assertThrows(IllegalStateException.class, () -> task.closeClean()); + + task.suspend(); + task.closeClean(); + } + + @Test + public void shouldOnlyNeedCommitWhenChangelogOffsetChanged() { + EasyMock.expect(stateManager.changelogPartitions()).andReturn(Collections.singleton(partition)).anyTimes(); + EasyMock.expect(stateManager.changelogOffsets()) + .andReturn(Collections.singletonMap(partition, 50L)) + .andReturn(Collections.singletonMap(partition, 10100L)).anyTimes(); + stateManager.flush(); + EasyMock.expectLastCall(); + stateManager.checkpoint(); + EasyMock.expectLastCall(); + EasyMock.replay(stateManager); + + task = createStandbyTask(); + task.initializeIfNeeded(); + + // no need to commit if we've just initialized and offset not 
advanced much + assertFalse(task.commitNeeded()); + + // could commit if the offset advanced beyond threshold + assertTrue(task.commitNeeded()); + + task.prepareCommit(); + task.postCommit(true); + + EasyMock.verify(stateManager); + } + + @Test + public void shouldThrowOnCloseCleanError() { + EasyMock.expect(stateManager.changelogOffsets()).andStubReturn(Collections.emptyMap()); + stateManager.close(); + EasyMock.expectLastCall().andThrow(new RuntimeException("KABOOM!")).anyTimes(); + EasyMock.expect(stateManager.changelogPartitions()).andReturn(Collections.singleton(partition)).anyTimes(); + EasyMock.replay(stateManager); + final MetricName metricName = setupCloseTaskMetric(); + + task = createStandbyTask(); + task.initializeIfNeeded(); + + task.suspend(); + assertThrows(RuntimeException.class, () -> task.closeClean()); + + final double expectedCloseTaskMetric = 0.0; + verifyCloseTaskMetric(expectedCloseTaskMetric, streamsMetrics, metricName); + + EasyMock.verify(stateManager); + EasyMock.reset(stateManager); + EasyMock.expect(stateManager.changelogPartitions()).andReturn(Collections.singleton(partition)).anyTimes(); + EasyMock.replay(stateManager); + } + + @Test + public void shouldThrowOnCloseCleanCheckpointError() { + EasyMock.expect(stateManager.changelogOffsets()) + .andReturn(Collections.singletonMap(partition, 50L)); + stateManager.checkpoint(); + EasyMock.expectLastCall().andThrow(new RuntimeException("KABOOM!")).anyTimes(); + EasyMock.expect(stateManager.changelogPartitions()).andReturn(Collections.emptySet()).anyTimes(); + EasyMock.replay(stateManager); + final MetricName metricName = setupCloseTaskMetric(); + + task = createStandbyTask(); + task.initializeIfNeeded(); + + task.prepareCommit(); + assertThrows(RuntimeException.class, () -> task.postCommit(true)); + + assertEquals(RUNNING, task.state()); + + final double expectedCloseTaskMetric = 0.0; + verifyCloseTaskMetric(expectedCloseTaskMetric, streamsMetrics, metricName); + + EasyMock.verify(stateManager); + EasyMock.reset(stateManager); + EasyMock.expect(stateManager.changelogPartitions()).andReturn(Collections.emptySet()).anyTimes(); + EasyMock.replay(stateManager); + } + + @Test + public void shouldUnregisterMetricsInCloseClean() { + EasyMock.expect(stateManager.changelogOffsets()).andStubReturn(Collections.emptyMap()); + EasyMock.expect(stateManager.changelogPartitions()).andReturn(Collections.emptySet()).anyTimes(); + EasyMock.replay(stateManager); + + task = createStandbyTask(); + task.initializeIfNeeded(); + + task.suspend(); + task.closeClean(); + // Currently, there are no metrics registered for standby tasks. + // This is a regression test so that, if we add some, we will be sure to deregister them. + assertThat(getTaskMetrics(), empty()); + } + + @Test + public void shouldUnregisterMetricsInCloseDirty() { + EasyMock.expect(stateManager.changelogOffsets()).andStubReturn(Collections.emptyMap()); + EasyMock.expect(stateManager.changelogPartitions()).andReturn(Collections.emptySet()).anyTimes(); + EasyMock.replay(stateManager); + + task = createStandbyTask(); + task.initializeIfNeeded(); + + task.suspend(); + task.closeDirty(); + + // Currently, there are no metrics registered for standby tasks. + // This is a regression test so that, if we add some, we will be sure to deregister them. 
+ assertThat(getTaskMetrics(), empty()); + } + + @Test + public void shouldCloseStateManagerOnTaskCreated() { + stateManager.close(); + EasyMock.expectLastCall(); + + EasyMock.replay(stateManager); + + final MetricName metricName = setupCloseTaskMetric(); + + task = createStandbyTask(); + task.suspend(); + + task.closeDirty(); + + final double expectedCloseTaskMetric = 1.0; + verifyCloseTaskMetric(expectedCloseTaskMetric, streamsMetrics, metricName); + + EasyMock.verify(stateManager); + + assertEquals(Task.State.CLOSED, task.state()); + } + + @SuppressWarnings("deprecation") + @Test + public void shouldDeleteStateDirOnTaskCreatedAndEosAlphaUncleanClose() { + stateManager.close(); + EasyMock.expectLastCall(); + + EasyMock.expect(stateManager.baseDir()).andReturn(baseDir); + + EasyMock.replay(stateManager); + + final MetricName metricName = setupCloseTaskMetric(); + + config = new StreamsConfig(mkProperties(mkMap( + mkEntry(StreamsConfig.APPLICATION_ID_CONFIG, applicationId), + mkEntry(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:2171"), + mkEntry(StreamsConfig.PROCESSING_GUARANTEE_CONFIG, StreamsConfig.EXACTLY_ONCE) + ))); + + task = createStandbyTask(); + task.suspend(); + + task.closeDirty(); + + final double expectedCloseTaskMetric = 1.0; + verifyCloseTaskMetric(expectedCloseTaskMetric, streamsMetrics, metricName); + + EasyMock.verify(stateManager); + + assertEquals(Task.State.CLOSED, task.state()); + } + + @Test + public void shouldDeleteStateDirOnTaskCreatedAndEosV2UncleanClose() { + stateManager.close(); + EasyMock.expectLastCall(); + + EasyMock.expect(stateManager.baseDir()).andReturn(baseDir); + + EasyMock.replay(stateManager); + + final MetricName metricName = setupCloseTaskMetric(); + + config = new StreamsConfig(mkProperties(mkMap( + mkEntry(StreamsConfig.APPLICATION_ID_CONFIG, applicationId), + mkEntry(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:2171"), + mkEntry(StreamsConfig.PROCESSING_GUARANTEE_CONFIG, StreamsConfig.EXACTLY_ONCE_V2) + ))); + + task = createStandbyTask(); + + task.suspend(); + task.closeDirty(); + + final double expectedCloseTaskMetric = 1.0; + verifyCloseTaskMetric(expectedCloseTaskMetric, streamsMetrics, metricName); + + EasyMock.verify(stateManager); + + assertEquals(Task.State.CLOSED, task.state()); + } + + @Test + public void shouldRecycleTask() { + EasyMock.expect(stateManager.changelogOffsets()).andStubReturn(Collections.emptyMap()); + stateManager.recycle(); + EasyMock.replay(stateManager); + + task = createStandbyTask(); + assertThrows(IllegalStateException.class, () -> task.closeCleanAndRecycleState()); // CREATED + + task.initializeIfNeeded(); + assertThrows(IllegalStateException.class, () -> task.closeCleanAndRecycleState()); // RUNNING + + task.suspend(); + task.closeCleanAndRecycleState(); // SUSPENDED + + // Currently, there are no metrics registered for standby tasks. + // This is a regression test so that, if we add some, we will be sure to deregister them. 
+ assertThat(getTaskMetrics(), empty()); + + EasyMock.verify(stateManager); + } + + @Test + public void shouldAlwaysSuspendCreatedTasks() { + EasyMock.replay(stateManager); + task = createStandbyTask(); + assertThat(task.state(), equalTo(CREATED)); + task.suspend(); + assertThat(task.state(), equalTo(SUSPENDED)); + } + + @Test + public void shouldAlwaysSuspendRunningTasks() { + EasyMock.expect(stateManager.changelogOffsets()).andStubReturn(Collections.emptyMap()); + EasyMock.replay(stateManager); + task = createStandbyTask(); + task.initializeIfNeeded(); + assertThat(task.state(), equalTo(RUNNING)); + task.suspend(); + assertThat(task.state(), equalTo(SUSPENDED)); + } + + @Test + public void shouldInitTaskTimeoutAndEventuallyThrow() { + EasyMock.replay(stateManager); + + task = createStandbyTask(); + + task.maybeInitTaskTimeoutOrThrow(0L, null); + task.maybeInitTaskTimeoutOrThrow(Duration.ofMinutes(5).toMillis(), null); + + final StreamsException thrown = assertThrows( + StreamsException.class, + () -> task.maybeInitTaskTimeoutOrThrow(Duration.ofMinutes(5).plus(Duration.ofMillis(1L)).toMillis(), null) + ); + + assertThat(thrown.getCause(), isA(TimeoutException.class)); + + } + + @Test + public void shouldCLearTaskTimeout() { + EasyMock.replay(stateManager); + + task = createStandbyTask(); + + task.maybeInitTaskTimeoutOrThrow(0L, null); + task.clearTaskTimeout(); + task.maybeInitTaskTimeoutOrThrow(Duration.ofMinutes(5).plus(Duration.ofMillis(1L)).toMillis(), null); + } + + private StandbyTask createStandbyTask() { + + final ThreadCache cache = new ThreadCache( + new LogContext(String.format("stream-thread [%s] ", Thread.currentThread().getName())), + 0, + streamsMetrics + ); + + final InternalProcessorContext context = new ProcessorContextImpl( + taskId, + config, + stateManager, + streamsMetrics, + cache + ); + + return new StandbyTask( + taskId, + Collections.singleton(partition), + topology, + config, + streamsMetrics, + stateManager, + stateDirectory, + cache, + context); + } + + private MetricName setupCloseTaskMetric() { + final MetricName metricName = new MetricName("name", "group", "description", Collections.emptyMap()); + final Sensor sensor = streamsMetrics.threadLevelSensor(threadId, "task-closed", Sensor.RecordingLevel.INFO); + sensor.add(metricName, new CumulativeSum()); + return metricName; + } + + private void verifyCloseTaskMetric(final double expected, final StreamsMetricsImpl streamsMetrics, final MetricName metricName) { + final KafkaMetric metric = (KafkaMetric) streamsMetrics.metrics().get(metricName); + final double totalCloses = metric.measurable().measure(metric.config(), System.currentTimeMillis()); + assertThat(totalCloses, equalTo(expected)); + } + + private List getTaskMetrics() { + return streamsMetrics.metrics().keySet().stream().filter(m -> m.tags().containsKey("task-id")).collect(Collectors.toList()); + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StateConsumerTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StateConsumerTest.java new file mode 100644 index 0000000..1f98eb4 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StateConsumerTest.java @@ -0,0 +1,166 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.processor.internals; + +import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.clients.consumer.MockConsumer; +import org.apache.kafka.clients.consumer.OffsetResetStrategy; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.common.utils.Utils; +import org.junit.Before; +import org.junit.Test; + +import java.io.IOException; +import java.time.Duration; +import java.util.HashMap; +import java.util.Map; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + + +public class StateConsumerTest { + + private static final long FLUSH_INTERVAL = 1000L; + private final TopicPartition topicOne = new TopicPartition("topic-one", 1); + private final TopicPartition topicTwo = new TopicPartition("topic-two", 1); + private final MockTime time = new MockTime(); + private final MockConsumer consumer = new MockConsumer<>(OffsetResetStrategy.EARLIEST); + private final Map partitionOffsets = new HashMap<>(); + private final LogContext logContext = new LogContext("test "); + private GlobalStreamThread.StateConsumer stateConsumer; + private TaskStub stateMaintainer; + + @Before + public void setUp() { + partitionOffsets.put(topicOne, 20L); + partitionOffsets.put(topicTwo, 30L); + stateMaintainer = new TaskStub(partitionOffsets); + stateConsumer = new GlobalStreamThread.StateConsumer(logContext, consumer, stateMaintainer, time, Duration.ofMillis(10L), FLUSH_INTERVAL); + } + + @Test + public void shouldAssignPartitionsToConsumer() { + stateConsumer.initialize(); + assertEquals(Utils.mkSet(topicOne, topicTwo), consumer.assignment()); + } + + @Test + public void shouldSeekToInitialOffsets() { + stateConsumer.initialize(); + assertEquals(20L, consumer.position(topicOne)); + assertEquals(30L, consumer.position(topicTwo)); + } + + @Test + public void shouldUpdateStateWithReceivedRecordsForPartition() { + stateConsumer.initialize(); + consumer.addRecord(new ConsumerRecord<>("topic-one", 1, 20L, new byte[0], new byte[0])); + consumer.addRecord(new ConsumerRecord<>("topic-one", 1, 21L, new byte[0], new byte[0])); + stateConsumer.pollAndUpdate(); + assertEquals(2, stateMaintainer.updatedPartitions.get(topicOne).intValue()); + } + + @Test + public void shouldUpdateStateWithReceivedRecordsForAllTopicPartition() { + stateConsumer.initialize(); + consumer.addRecord(new ConsumerRecord<>("topic-one", 1, 20L, new byte[0], new byte[0])); + consumer.addRecord(new ConsumerRecord<>("topic-two", 1, 31L, new byte[0], new byte[0])); + consumer.addRecord(new ConsumerRecord<>("topic-two", 1, 32L, new byte[0], new byte[0])); + stateConsumer.pollAndUpdate(); + assertEquals(1, stateMaintainer.updatedPartitions.get(topicOne).intValue()); + assertEquals(2, 
stateMaintainer.updatedPartitions.get(topicTwo).intValue()); + } + + @Test + public void shouldFlushStoreWhenFlushIntervalHasLapsed() { + stateConsumer.initialize(); + consumer.addRecord(new ConsumerRecord<>("topic-one", 1, 20L, new byte[0], new byte[0])); + time.sleep(FLUSH_INTERVAL); + + stateConsumer.pollAndUpdate(); + assertTrue(stateMaintainer.flushed); + } + + @Test + public void shouldNotFlushOffsetsWhenFlushIntervalHasNotLapsed() { + stateConsumer.initialize(); + consumer.addRecord(new ConsumerRecord<>("topic-one", 1, 20L, new byte[0], new byte[0])); + time.sleep(FLUSH_INTERVAL / 2); + stateConsumer.pollAndUpdate(); + assertFalse(stateMaintainer.flushed); + } + + @Test + public void shouldCloseConsumer() throws IOException { + stateConsumer.close(false); + assertTrue(consumer.closed()); + } + + @Test + public void shouldCloseStateMaintainer() throws IOException { + stateConsumer.close(false); + assertTrue(stateMaintainer.closed); + } + + @Test + public void shouldWipeStoreOnClose() throws IOException { + stateConsumer.close(true); + assertTrue(stateMaintainer.wipeStore); + } + + private static class TaskStub implements GlobalStateMaintainer { + private final Map partitionOffsets; + private final Map updatedPartitions = new HashMap<>(); + private boolean flushed; + private boolean wipeStore; + private boolean closed; + + TaskStub(final Map partitionOffsets) { + this.partitionOffsets = partitionOffsets; + } + + @Override + public Map initialize() { + return partitionOffsets; + } + + public void flushState() { + flushed = true; + } + + @Override + public void close(final boolean wipeStateStore) { + closed = true; + wipeStore = wipeStateStore; + } + + @Override + public void update(final ConsumerRecord record) { + final TopicPartition tp = new TopicPartition(record.topic(), record.partition()); + if (!updatedPartitions.containsKey(tp)) { + updatedPartitions.put(tp, 0); + } + updatedPartitions.put(tp, updatedPartitions.get(tp) + 1); + } + + } + +} \ No newline at end of file diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StateDirectoryTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StateDirectoryTest.java new file mode 100644 index 0000000..81bc7d7 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StateDirectoryTest.java @@ -0,0 +1,863 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.processor.internals; + +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.databind.ObjectMapper; +import java.io.BufferedWriter; +import java.io.FileOutputStream; +import java.io.OutputStreamWriter; +import java.nio.charset.StandardCharsets; +import java.util.HashSet; +import java.util.List; +import java.util.UUID; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.errors.ProcessorStateException; +import org.apache.kafka.streams.processor.TaskId; +import org.apache.kafka.streams.processor.internals.StateDirectory.TaskDirectory; +import org.apache.kafka.streams.processor.internals.testutil.LogCaptureAppender; +import org.apache.kafka.streams.state.internals.OffsetCheckpoint; +import org.apache.kafka.test.TestUtils; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import java.io.File; +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.attribute.PosixFilePermission; +import java.time.Duration; +import java.util.Arrays; +import java.util.Collections; +import java.util.EnumSet; +import java.util.Objects; +import java.util.Properties; +import java.util.Set; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; +import java.util.stream.Collectors; + +import static org.apache.kafka.common.utils.Utils.mkEntry; +import static org.apache.kafka.common.utils.Utils.mkMap; +import static org.apache.kafka.common.utils.Utils.mkSet; +import static org.apache.kafka.streams.processor.internals.StateDirectory.PROCESS_FILE_NAME; +import static org.apache.kafka.streams.processor.internals.StateManagerUtil.CHECKPOINT_FILE_NAME; +import static org.apache.kafka.streams.processor.internals.StateManagerUtil.toTaskDirString; + +import static java.util.Collections.emptyList; +import static java.util.Collections.singleton; +import static java.util.Collections.singletonList; +import static org.hamcrest.CoreMatchers.containsString; +import static org.hamcrest.CoreMatchers.endsWith; +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.CoreMatchers.hasItem; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.CoreMatchers.not; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +public class StateDirectoryTest { + + private final MockTime time = new MockTime(); + private File stateDir; + private final String applicationId = "applicationId"; + private StateDirectory directory; + private File appDir; + + private void initializeStateDirectory(final boolean createStateDirectory, final boolean hasNamedTopology) throws IOException { + stateDir = new File(TestUtils.IO_TMP_DIR, "kafka-" + TestUtils.randomString(5)); + if (!createStateDirectory) { + cleanup(); + } + directory = new StateDirectory( + new StreamsConfig(new Properties() { + { + put(StreamsConfig.APPLICATION_ID_CONFIG, applicationId); + 
put(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, "dummy:1234"); + put(StreamsConfig.STATE_DIR_CONFIG, stateDir.getPath()); + } + }), + time, createStateDirectory, hasNamedTopology); + appDir = new File(stateDir, applicationId); + } + + @Before + public void before() throws IOException { + initializeStateDirectory(true, false); + } + + @After + public void cleanup() throws IOException { + Utils.delete(stateDir); + } + + @Test + public void shouldCreateBaseDirectory() { + assertTrue(stateDir.exists()); + assertTrue(stateDir.isDirectory()); + assertTrue(appDir.exists()); + assertTrue(appDir.isDirectory()); + } + + @Test + public void shouldHaveSecurePermissions() { + assertPermissions(stateDir); + assertPermissions(appDir); + } + + private void assertPermissions(final File file) { + final Path path = file.toPath(); + if (path.getFileSystem().supportedFileAttributeViews().contains("posix")) { + final Set expectedPermissions = EnumSet.of( + PosixFilePermission.OWNER_EXECUTE, + PosixFilePermission.GROUP_READ, + PosixFilePermission.OWNER_WRITE, + PosixFilePermission.GROUP_EXECUTE, + PosixFilePermission.OWNER_READ); + try { + final Set filePermissions = Files.getPosixFilePermissions(path); + assertThat(expectedPermissions, equalTo(filePermissions)); + } catch (final IOException e) { + fail("Should create correct files and set correct permissions"); + } + } else { + assertThat(file.canRead(), is(true)); + assertThat(file.canWrite(), is(true)); + assertThat(file.canExecute(), is(true)); + } + } + + @Test + public void shouldParseUnnamedTaskId() { + final TaskId task = new TaskId(1, 0); + assertThat(TaskId.parse(task.toString()), equalTo(task)); + } + + @Test + public void shouldParseNamedTaskId() { + final TaskId task = new TaskId(1, 0, "namedTopology"); + assertThat(TaskId.parse(task.toString()), equalTo(task)); + } + + @Test + public void shouldCreateTaskStateDirectory() { + final TaskId taskId = new TaskId(0, 0); + final File taskDirectory = directory.getOrCreateDirectoryForTask(taskId); + assertTrue(taskDirectory.exists()); + assertTrue(taskDirectory.isDirectory()); + } + + @Test + public void shouldBeTrueIfAlreadyHoldsLock() { + final TaskId taskId = new TaskId(0, 0); + directory.getOrCreateDirectoryForTask(taskId); + directory.lock(taskId); + try { + assertTrue(directory.lock(taskId)); + } finally { + directory.unlock(taskId); + } + } + + @Test + public void shouldBeAbleToUnlockEvenWithoutLocking() { + final TaskId taskId = new TaskId(0, 0); + directory.unlock(taskId); + } + + @Test + public void shouldReportDirectoryEmpty() throws IOException { + final TaskId taskId = new TaskId(0, 0); + + // when task dir first created, it should be empty + assertTrue(directory.directoryForTaskIsEmpty(taskId)); + + // after locking, it should still be empty + directory.lock(taskId); + assertTrue(directory.directoryForTaskIsEmpty(taskId)); + + // after writing checkpoint, it should still be empty + final OffsetCheckpoint checkpointFile = new OffsetCheckpoint(new File(directory.getOrCreateDirectoryForTask(taskId), CHECKPOINT_FILE_NAME)); + assertTrue(directory.directoryForTaskIsEmpty(taskId)); + + checkpointFile.write(Collections.singletonMap(new TopicPartition("topic", 0), 0L)); + assertTrue(directory.directoryForTaskIsEmpty(taskId)); + + // if some store dir is created, it should not be empty + final File dbDir = new File(new File(directory.getOrCreateDirectoryForTask(taskId), "db"), "store1"); + + Files.createDirectories(dbDir.getParentFile().toPath()); + 
Files.createDirectories(dbDir.getAbsoluteFile().toPath()); + + assertFalse(directory.directoryForTaskIsEmpty(taskId)); + + // after wiping out the state dir, the dir should show as empty again + Utils.delete(dbDir.getParentFile()); + assertTrue(directory.directoryForTaskIsEmpty(taskId)); + + directory.unlock(taskId); + assertTrue(directory.directoryForTaskIsEmpty(taskId)); + } + + @Test + public void shouldThrowProcessorStateException() throws IOException { + final TaskId taskId = new TaskId(0, 0); + + Utils.delete(stateDir); + + assertThrows(ProcessorStateException.class, () -> directory.getOrCreateDirectoryForTask(taskId)); + } + + @Test + public void shouldThrowProcessorStateExceptionIfStateDirOccupied() throws IOException { + final TaskId taskId = new TaskId(0, 0); + + // Replace application's stateDir to regular file + Utils.delete(appDir); + appDir.createNewFile(); + + assertThrows(ProcessorStateException.class, () -> directory.getOrCreateDirectoryForTask(taskId)); + } + + @Test + public void shouldThrowProcessorStateExceptionIfTestDirOccupied() throws IOException { + final TaskId taskId = new TaskId(0, 0); + + // Replace taskDir to a regular file + final File taskDir = new File(appDir, toTaskDirString(taskId)); + Utils.delete(taskDir); + taskDir.createNewFile(); + + // Error: ProcessorStateException should be thrown. + assertThrows(ProcessorStateException.class, () -> directory.getOrCreateDirectoryForTask(taskId)); + } + + @Test + public void shouldNotThrowIfStateDirectoryHasBeenDeleted() throws IOException { + final TaskId taskId = new TaskId(0, 0); + + Utils.delete(stateDir); + assertThrows(IllegalStateException.class, () -> directory.lock(taskId)); + } + + @Test + public void shouldLockMultipleTaskDirectories() { + final TaskId taskId = new TaskId(0, 0); + final TaskId taskId2 = new TaskId(1, 0); + + assertThat(directory.lock(taskId), is(true)); + assertThat(directory.lock(taskId2), is(true)); + directory.unlock(taskId); + directory.unlock(taskId2); + } + + @Test + public void shouldCleanUpTaskStateDirectoriesThatAreNotCurrentlyLocked() { + final TaskId task0 = new TaskId(0, 0); + final TaskId task1 = new TaskId(1, 0); + final TaskId task2 = new TaskId(2, 0); + try { + assertTrue(new File(directory.getOrCreateDirectoryForTask(task0), "store").mkdir()); + assertTrue(new File(directory.getOrCreateDirectoryForTask(task1), "store").mkdir()); + assertTrue(new File(directory.getOrCreateDirectoryForTask(task2), "store").mkdir()); + + directory.lock(task0); + directory.lock(task1); + + final TaskDirectory dir0 = new TaskDirectory(new File(appDir, toTaskDirString(task0)), null); + final TaskDirectory dir1 = new TaskDirectory(new File(appDir, toTaskDirString(task1)), null); + final TaskDirectory dir2 = new TaskDirectory(new File(appDir, toTaskDirString(task2)), null); + + List files = directory.listAllTaskDirectories(); + assertEquals(mkSet(dir0, dir1, dir2), new HashSet<>(files)); + + files = directory.listNonEmptyTaskDirectories(); + assertEquals(mkSet(dir0, dir1, dir2), new HashSet<>(files)); + + time.sleep(5000); + directory.cleanRemovedTasks(0); + + files = directory.listAllTaskDirectories(); + assertEquals(mkSet(dir0, dir1), new HashSet<>(files)); + + files = directory.listNonEmptyTaskDirectories(); + assertEquals(mkSet(dir0, dir1), new HashSet<>(files)); + } finally { + directory.unlock(task0); + directory.unlock(task1); + } + } + + @Test + public void shouldCleanupStateDirectoriesWhenLastModifiedIsLessThanNowMinusCleanupDelay() { + final File dir = 
directory.getOrCreateDirectoryForTask(new TaskId(2, 0)); + assertTrue(new File(dir, "store").mkdir()); + + final int cleanupDelayMs = 60000; + directory.cleanRemovedTasks(cleanupDelayMs); + assertTrue(dir.exists()); + assertEquals(1, directory.listAllTaskDirectories().size()); + assertEquals(1, directory.listNonEmptyTaskDirectories().size()); + + time.sleep(cleanupDelayMs + 1000); + directory.cleanRemovedTasks(cleanupDelayMs); + assertFalse(dir.exists()); + assertEquals(0, directory.listAllTaskDirectories().size()); + assertEquals(0, directory.listNonEmptyTaskDirectories().size()); + } + + @Test + public void shouldCleanupObsoleteTaskDirectoriesAndDeleteTheDirectoryItself() { + final File dir = directory.getOrCreateDirectoryForTask(new TaskId(2, 0)); + assertTrue(new File(dir, "store").mkdir()); + assertEquals(1, directory.listAllTaskDirectories().size()); + assertEquals(1, directory.listNonEmptyTaskDirectories().size()); + + try (final LogCaptureAppender appender = LogCaptureAppender.createAndRegister(StateDirectory.class)) { + time.sleep(5000); + directory.cleanRemovedTasks(0); + assertFalse(dir.exists()); + assertEquals(0, directory.listAllTaskDirectories().size()); + assertEquals(0, directory.listNonEmptyTaskDirectories().size()); + assertThat( + appender.getMessages(), + hasItem(containsString("Deleting obsolete state directory")) + ); + } + } + + @Test + public void shouldNotRemoveNonTaskDirectoriesAndFiles() { + final File otherDir = TestUtils.tempDirectory(stateDir.toPath(), "foo"); + directory.cleanRemovedTasks(0); + assertTrue(otherDir.exists()); + } + + @Test + public void shouldReturnEmptyArrayForNonPersistentApp() throws IOException { + initializeStateDirectory(false, false); + assertTrue(directory.listAllTaskDirectories().isEmpty()); + } + + @Test + public void shouldReturnEmptyArrayIfStateDirDoesntExist() throws IOException { + cleanup(); + assertFalse(stateDir.exists()); + assertTrue(directory.listAllTaskDirectories().isEmpty()); + } + + @Test + public void shouldReturnEmptyArrayIfListFilesReturnsNull() throws IOException { + stateDir = new File(TestUtils.IO_TMP_DIR, "kafka-" + TestUtils.randomString(5)); + directory = new StateDirectory( + new StreamsConfig(new Properties() { + { + put(StreamsConfig.APPLICATION_ID_CONFIG, applicationId); + put(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, "dummy:1234"); + put(StreamsConfig.STATE_DIR_CONFIG, stateDir.getPath()); + } + }), + time, + true, + false); + appDir = new File(stateDir, applicationId); + + // make sure the File#listFiles returns null and StateDirectory#listAllTaskDirectories is able to handle null + Utils.delete(appDir); + assertTrue(appDir.createNewFile()); + assertTrue(appDir.exists()); + assertNull(appDir.listFiles()); + assertEquals(0, directory.listAllTaskDirectories().size()); + } + + @Test + public void shouldOnlyListNonEmptyTaskDirectories() throws IOException { + TestUtils.tempDirectory(stateDir.toPath(), "foo"); + final TaskDirectory taskDir1 = new TaskDirectory(directory.getOrCreateDirectoryForTask(new TaskId(0, 0)), null); + final TaskDirectory taskDir2 = new TaskDirectory(directory.getOrCreateDirectoryForTask(new TaskId(0, 1)), null); + + final File storeDir = new File(taskDir1.file(), "store"); + assertTrue(storeDir.mkdir()); + + assertThat(mkSet(taskDir1, taskDir2), equalTo(new HashSet<>(directory.listAllTaskDirectories()))); + assertThat(singletonList(taskDir1), equalTo(directory.listNonEmptyTaskDirectories())); + + Utils.delete(taskDir1.file()); + + assertThat(singleton(taskDir2), equalTo(new 
HashSet<>(directory.listAllTaskDirectories()))); + assertThat(emptyList(), equalTo(directory.listNonEmptyTaskDirectories())); + } + + @Test + public void shouldCreateDirectoriesIfParentDoesntExist() { + final File tempDir = TestUtils.tempDirectory(); + final File stateDir = new File(new File(tempDir, "foo"), "state-dir"); + final StateDirectory stateDirectory = new StateDirectory( + new StreamsConfig(new Properties() { + { + put(StreamsConfig.APPLICATION_ID_CONFIG, applicationId); + put(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, "dummy:1234"); + put(StreamsConfig.STATE_DIR_CONFIG, stateDir.getPath()); + } + }), + time, + true, + false); + final File taskDir = stateDirectory.getOrCreateDirectoryForTask(new TaskId(0, 0)); + assertTrue(stateDir.exists()); + assertTrue(taskDir.exists()); + } + + @Test + public void shouldNotLockStateDirLockedByAnotherThread() throws Exception { + final TaskId taskId = new TaskId(0, 0); + final Thread thread = new Thread(() -> directory.lock(taskId)); + thread.start(); + thread.join(30000); + assertFalse(directory.lock(taskId)); + } + + @Test + public void shouldNotUnLockStateDirLockedByAnotherThread() throws Exception { + final TaskId taskId = new TaskId(0, 0); + final CountDownLatch lockLatch = new CountDownLatch(1); + final CountDownLatch unlockLatch = new CountDownLatch(1); + final AtomicReference exceptionOnThread = new AtomicReference<>(); + final Thread thread = new Thread(() -> { + try { + directory.lock(taskId); + lockLatch.countDown(); + unlockLatch.await(); + directory.unlock(taskId); + } catch (final Exception e) { + exceptionOnThread.set(e); + } + }); + thread.start(); + lockLatch.await(5, TimeUnit.SECONDS); + + assertNull("should not have had an exception on other thread", exceptionOnThread.get()); + directory.unlock(taskId); + assertFalse(directory.lock(taskId)); + + unlockLatch.countDown(); + thread.join(30000); + + assertNull("should not have had an exception on other thread", exceptionOnThread.get()); + assertTrue(directory.lock(taskId)); + } + + @Test + public void shouldCleanupAllTaskDirectoriesIncludingGlobalOne() { + final TaskId id = new TaskId(1, 0); + directory.getOrCreateDirectoryForTask(id); + directory.globalStateDir(); + + final File dir0 = new File(appDir, id.toString()); + final File globalDir = new File(appDir, "global"); + assertEquals(mkSet(dir0, globalDir), Arrays.stream( + Objects.requireNonNull(appDir.listFiles())).collect(Collectors.toSet())); + + directory.clean(); + + // if appDir is empty, it is deleted in StateDirectory#clean process. 
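+ // clean() should have removed the task and global state directories created above, leaving appDir empty so that it can be deleted as well.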
+ assertFalse(appDir.exists()); + } + + @Test + public void shouldNotCreateBaseDirectory() throws IOException { + try (final LogCaptureAppender appender = LogCaptureAppender.createAndRegister(StateDirectory.class)) { + initializeStateDirectory(false, false); + assertThat(stateDir.exists(), is(false)); + assertThat(appDir.exists(), is(false)); + assertThat(appender.getMessages(), + not(hasItem(containsString("Error changing permissions for the state or base directory")))); + } + } + + @Test + public void shouldNotCreateTaskStateDirectory() throws IOException { + initializeStateDirectory(false, false); + final TaskId taskId = new TaskId(0, 0); + final File taskDirectory = directory.getOrCreateDirectoryForTask(taskId); + assertFalse(taskDirectory.exists()); + } + + @Test + public void shouldNotCreateGlobalStateDirectory() throws IOException { + initializeStateDirectory(false, false); + final File globalStateDir = directory.globalStateDir(); + assertFalse(globalStateDir.exists()); + } + + @Test + public void shouldLockTaskStateDirectoryWhenDirectoryCreationDisabled() throws IOException { + initializeStateDirectory(false, false); + final TaskId taskId = new TaskId(0, 0); + assertTrue(directory.lock(taskId)); + } + + @Test + public void shouldNotFailWhenCreatingTaskDirectoryInParallel() throws Exception { + final TaskId taskId = new TaskId(0, 0); + final AtomicBoolean passed = new AtomicBoolean(true); + + final CreateTaskDirRunner runner = new CreateTaskDirRunner(directory, taskId, passed); + + final Thread t1 = new Thread(runner); + final Thread t2 = new Thread(runner); + + t1.start(); + t2.start(); + + t1.join(Duration.ofMillis(500L).toMillis()); + t2.join(Duration.ofMillis(500L).toMillis()); + + assertNotNull(runner.taskDirectory); + assertTrue(passed.get()); + assertTrue(runner.taskDirectory.exists()); + assertTrue(runner.taskDirectory.isDirectory()); + } + + @Test + public void shouldDeleteAppDirWhenCleanUpIfEmpty() { + final TaskId taskId = new TaskId(0, 0); + final File taskDirectory = directory.getOrCreateDirectoryForTask(taskId); + final File testFile = new File(taskDirectory, "testFile"); + assertThat(testFile.mkdir(), is(true)); + assertThat(directory.directoryForTaskIsEmpty(taskId), is(false)); + + // call StateDirectory#clean + directory.clean(); + + // if appDir is empty, it is deleted in StateDirectory#clean process. + assertFalse(appDir.exists()); + } + + @Test + public void shouldNotDeleteAppDirWhenCleanUpIfNotEmpty() throws IOException { + final TaskId taskId = new TaskId(0, 0); + final File taskDirectory = directory.getOrCreateDirectoryForTask(taskId); + final File testFile = new File(taskDirectory, "testFile"); + assertThat(testFile.mkdir(), is(true)); + assertThat(directory.directoryForTaskIsEmpty(taskId), is(false)); + + // Create a dummy file in appDir; for this, appDir will not be empty after cleanup. 
+ final File dummyFile = new File(appDir, "dummy"); + assertTrue(dummyFile.createNewFile()); + + try (final LogCaptureAppender appender = LogCaptureAppender.createAndRegister(StateDirectory.class)) { + // call StateDirectory#clean + directory.clean(); + assertThat( + appender.getMessages(), + hasItem(endsWith(String.format("Failed to delete state store directory of %s for it is not empty", appDir.getAbsolutePath()))) + ); + } + } + + @Test + public void shouldLogManualUserCallMessage() { + final TaskId taskId = new TaskId(0, 0); + final File taskDirectory = directory.getOrCreateDirectoryForTask(taskId); + final File testFile = new File(taskDirectory, "testFile"); + assertThat(testFile.mkdir(), is(true)); + assertThat(directory.directoryForTaskIsEmpty(taskId), is(false)); + + try (final LogCaptureAppender appender = LogCaptureAppender.createAndRegister(StateDirectory.class)) { + directory.clean(); + assertThat( + appender.getMessages(), + hasItem(endsWith("as user calling cleanup.")) + ); + } + } + + @Test + public void shouldLogStateDirCleanerMessage() { + final TaskId taskId = new TaskId(0, 0); + final File taskDirectory = directory.getOrCreateDirectoryForTask(taskId); + final File testFile = new File(taskDirectory, "testFile"); + assertThat(testFile.mkdir(), is(true)); + assertThat(directory.directoryForTaskIsEmpty(taskId), is(false)); + + try (final LogCaptureAppender appender = LogCaptureAppender.createAndRegister(StateDirectory.class)) { + final long cleanupDelayMs = 0; + time.sleep(5000); + directory.cleanRemovedTasks(cleanupDelayMs); + assertThat(appender.getMessages(), hasItem(endsWith("ms has elapsed (cleanup delay is " + cleanupDelayMs + "ms)."))); + } + } + + @Test + public void shouldLogTempDirMessage() { + try (final LogCaptureAppender appender = LogCaptureAppender.createAndRegister(StateDirectory.class)) { + new StateDirectory( + new StreamsConfig( + mkMap( + mkEntry(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, ""), + mkEntry(StreamsConfig.APPLICATION_ID_CONFIG, "") + ) + ), + new MockTime(), + true, + false + ); + assertThat( + appender.getMessages(), + hasItem("Using an OS temp directory in the state.dir property can cause failures with writing the" + + " checkpoint file due to the fact that this directory can be cleared by the OS." 
+ + " Resolved state.dir: [" + System.getProperty("java.io.tmpdir") + "/kafka-streams]") + ); + } + } + + /************* Named Topology Tests *************/ + + @Test + public void shouldCreateTaskDirectoriesUnderNamedTopologyDirs() throws IOException { + initializeStateDirectory(true, true); + + directory.getOrCreateDirectoryForTask(new TaskId(0, 0, "topology1")); + directory.getOrCreateDirectoryForTask(new TaskId(0, 1, "topology1")); + directory.getOrCreateDirectoryForTask(new TaskId(0, 0, "topology2")); + + assertThat(new File(appDir, "__topology1__").exists(), is(true)); + assertThat(new File(appDir, "__topology1__").isDirectory(), is(true)); + assertThat(new File(appDir, "__topology2__").exists(), is(true)); + assertThat(new File(appDir, "__topology2__").isDirectory(), is(true)); + + assertThat(new File(new File(appDir, "__topology1__"), "0_0").exists(), is(true)); + assertThat(new File(new File(appDir, "__topology1__"), "0_0").isDirectory(), is(true)); + assertThat(new File(new File(appDir, "__topology1__"), "0_1").exists(), is(true)); + assertThat(new File(new File(appDir, "__topology1__"), "0_1").isDirectory(), is(true)); + assertThat(new File(new File(appDir, "__topology2__"), "0_0").exists(), is(true)); + assertThat(new File(new File(appDir, "__topology2__"), "0_0").isDirectory(), is(true)); + } + + @Test + public void shouldOnlyListNonEmptyTaskDirectoriesInNamedTopologies() throws IOException { + initializeStateDirectory(true, true); + + TestUtils.tempDirectory(appDir.toPath(), "foo"); + final TaskDirectory taskDir1 = new TaskDirectory(directory.getOrCreateDirectoryForTask(new TaskId(0, 0, "topology1")), "topology1"); + final TaskDirectory taskDir2 = new TaskDirectory(directory.getOrCreateDirectoryForTask(new TaskId(0, 1, "topology1")), "topology1"); + final TaskDirectory taskDir3 = new TaskDirectory(directory.getOrCreateDirectoryForTask(new TaskId(0, 0, "topology2")), "topology2"); + + final File storeDir = new File(taskDir1.file(), "store"); + assertTrue(storeDir.mkdir()); + + assertThat(new HashSet<>(directory.listAllTaskDirectories()), equalTo(mkSet(taskDir1, taskDir2, taskDir3))); + assertThat(directory.listNonEmptyTaskDirectories(), equalTo(singletonList(taskDir1))); + + Utils.delete(taskDir1.file()); + + assertThat(new HashSet<>(directory.listAllTaskDirectories()), equalTo(mkSet(taskDir2, taskDir3))); + assertThat(directory.listNonEmptyTaskDirectories(), equalTo(emptyList())); + } + + @Test + public void shouldRemoveNonEmptyNamedTopologyDirsWhenCallingClean() throws Exception { + initializeStateDirectory(true, true); + final File taskDir = directory.getOrCreateDirectoryForTask(new TaskId(2, 0, "topology1")); + final File namedTopologyDir = new File(appDir, "__topology1__"); + + assertThat(taskDir.exists(), is(true)); + assertThat(namedTopologyDir.exists(), is(true)); + directory.clean(); + assertThat(taskDir.exists(), is(false)); + assertThat(namedTopologyDir.exists(), is(false)); + } + + @Test + public void shouldRemoveEmptyNamedTopologyDirsWhenCallingClean() throws IOException { + initializeStateDirectory(true, true); + final File namedTopologyDir = new File(appDir, "__topology1__"); + assertThat(namedTopologyDir.mkdir(), is(true)); + assertThat(namedTopologyDir.exists(), is(true)); + directory.clean(); + assertThat(namedTopologyDir.exists(), is(false)); + } + + @Test + public void shouldRemoveNonEmptyNamedTopologyDirsWhenCallingClearLocalStateForNamedTopology() throws Exception { + initializeStateDirectory(true, true); + final String topologyName = "topology1"; + final 
File taskDir = directory.getOrCreateDirectoryForTask(new TaskId(2, 0, topologyName)); + final File namedTopologyDir = new File(appDir, "__" + topologyName + "__"); + + assertThat(taskDir.exists(), is(true)); + assertThat(namedTopologyDir.exists(), is(true)); + directory.clearLocalStateForNamedTopology(topologyName); + assertThat(taskDir.exists(), is(false)); + assertThat(namedTopologyDir.exists(), is(false)); + } + + @Test + public void shouldRemoveEmptyNamedTopologyDirsWhenCallingClearLocalStateForNamedTopology() throws IOException { + initializeStateDirectory(true, true); + final String topologyName = "topology1"; + final File namedTopologyDir = new File(appDir, "__" + topologyName + "__"); + assertThat(namedTopologyDir.mkdir(), is(true)); + assertThat(namedTopologyDir.exists(), is(true)); + directory.clearLocalStateForNamedTopology(topologyName); + assertThat(namedTopologyDir.exists(), is(false)); + } + + @Test + public void shouldNotRemoveDirsThatDoNotMatchNamedTopologyDirsWhenCallingClean() throws IOException { + initializeStateDirectory(true, true); + final File someDir = new File(appDir, "_not-a-valid-named-topology_dir_name_"); + assertThat(someDir.mkdir(), is(true)); + assertThat(someDir.exists(), is(true)); + directory.clean(); + assertThat(someDir.exists(), is(true)); + } + + @Test + public void shouldCleanupObsoleteTaskDirectoriesInNamedTopologiesAndDeleteTheParentDirectories() throws IOException { + initializeStateDirectory(true, true); + + final File taskDir = directory.getOrCreateDirectoryForTask(new TaskId(2, 0, "topology1")); + final File namedTopologyDir = new File(appDir, "__topology1__"); + assertThat(namedTopologyDir.exists(), is(true)); + assertThat(taskDir.exists(), is(true)); + assertTrue(new File(taskDir, "store").mkdir()); + assertThat(directory.listAllTaskDirectories().size(), is(1)); + assertThat(directory.listNonEmptyTaskDirectories().size(), is(1)); + + try (final LogCaptureAppender appender = LogCaptureAppender.createAndRegister(StateDirectory.class)) { + time.sleep(5000); + directory.cleanRemovedTasks(0); + assertThat(taskDir.exists(), is(false)); + assertThat(namedTopologyDir.exists(), is(false)); + assertThat(directory.listAllTaskDirectories().size(), is(0)); + assertThat(directory.listNonEmptyTaskDirectories().size(), is(0)); + assertThat( + appender.getMessages(), + hasItem(containsString("Deleting obsolete state directory")) + ); + } + } + + /************************************************/ + + @Test + public void shouldPersistProcessIdAcrossRestart() { + final UUID processId = directory.initializeProcessId(); + directory.close(); + assertThat(directory.initializeProcessId(), equalTo(processId)); + } + + @Test + public void shouldGetFreshProcessIdIfProcessFileDeleted() { + final UUID processId = directory.initializeProcessId(); + directory.close(); + + final File processFile = new File(appDir, PROCESS_FILE_NAME); + assertThat(processFile.exists(), is(true)); + assertThat(processFile.delete(), is(true)); + + assertThat(directory.initializeProcessId(), not(processId)); + } + + @Test + public void shouldGetFreshProcessIdIfJsonUnreadable() throws Exception { + final File processFile = new File(appDir, PROCESS_FILE_NAME); + assertThat(processFile.createNewFile(), is(true)); + final UUID processId = UUID.randomUUID(); + + final FileOutputStream fileOutputStream = new FileOutputStream(processFile); + try (final BufferedWriter writer = new BufferedWriter( + new OutputStreamWriter(fileOutputStream, StandardCharsets.UTF_8))) { + 
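// write a bare UUID string rather than the JSON payload the process file normally holds, so initializeProcessId() should fall back to a fresh process id +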
writer.write(processId.toString()); + writer.flush(); + fileOutputStream.getFD().sync(); + } + + assertThat(directory.initializeProcessId(), not(processId)); + } + + @Test + public void shouldReadFutureProcessFileFormat() throws Exception { + final File processFile = new File(appDir, PROCESS_FILE_NAME); + final ObjectMapper mapper = new ObjectMapper(); + final UUID processId = UUID.randomUUID(); + mapper.writeValue(processFile, new FutureStateDirectoryProcessFile(processId, "some random junk")); + + assertThat(directory.initializeProcessId(), equalTo(processId)); + } + + private static class FutureStateDirectoryProcessFile { + + @JsonProperty + private final UUID processId; + + @JsonProperty + private final String newField; + + // required by jackson -- do not remove, your IDE may be warning that this is unused but it's lying to you + public FutureStateDirectoryProcessFile() { + this.processId = null; + this.newField = null; + } + + FutureStateDirectoryProcessFile(final UUID processId, final String newField) { + this.processId = processId; + this.newField = newField; + + } + } + + private static class CreateTaskDirRunner implements Runnable { + private final StateDirectory directory; + private final TaskId taskId; + private final AtomicBoolean passed; + + private File taskDirectory; + + private CreateTaskDirRunner(final StateDirectory directory, + final TaskId taskId, + final AtomicBoolean passed) { + this.directory = directory; + this.taskId = taskId; + this.passed = passed; + } + + @Override + public void run() { + try { + taskDirectory = directory.getOrCreateDirectoryForTask(taskId); + } catch (final ProcessorStateException error) { + passed.set(false); + } + } + } +} \ No newline at end of file diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StateManagerStub.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StateManagerStub.java new file mode 100644 index 0000000..122c992 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StateManagerStub.java @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.processor.internals; + + +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.streams.processor.StateRestoreCallback; +import org.apache.kafka.streams.processor.StateStore; + +import java.io.File; +import java.util.Map; +import org.apache.kafka.streams.processor.internals.Task.TaskType; + +public class StateManagerStub implements StateManager { + + @Override + public File baseDir() { + return null; + } + + @Override + public void registerStore(final StateStore store, + final StateRestoreCallback stateRestoreCallback) {} + + @Override + public void flush() {} + + @Override + public void close() {} + + @Override + public StateStore getStore(final String name) { + return null; + } + + @Override + public StateStore getGlobalStore(final String name) { + return null; + } + + @Override + public Map<TopicPartition, Long> changelogOffsets() { + return null; + } + + @Override + public void updateChangelogOffsets(final Map<TopicPartition, Long> writtenOffsets) {} + + @Override + public void checkpoint() {} + + @Override + public TaskType taskType() { + return null; + } + + @Override + public String changelogFor(final String storeName) { + return null; + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StateManagerUtilTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StateManagerUtilTest.java new file mode 100644 index 0000000..bc7fb14 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StateManagerUtilTest.java @@ -0,0 +1,337 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ +package org.apache.kafka.streams.processor.internals; + +import java.util.List; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.streams.errors.LockException; +import org.apache.kafka.streams.errors.ProcessorStateException; +import org.apache.kafka.streams.processor.StateStore; +import org.apache.kafka.streams.processor.TaskId; +import org.apache.kafka.streams.processor.internals.Task.TaskType; +import org.apache.kafka.test.MockKeyValueStore; +import org.apache.kafka.test.TestUtils; +import org.easymock.IMocksControl; +import org.easymock.Mock; +import org.easymock.MockType; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.powermock.core.classloader.annotations.PrepareForTest; +import org.powermock.modules.junit4.PowerMockRunner; +import org.slf4j.Logger; + +import java.io.File; +import java.io.IOException; +import java.util.Arrays; + +import static java.util.Collections.emptyList; +import static java.util.Collections.singletonList; +import static org.easymock.EasyMock.createStrictControl; +import static org.easymock.EasyMock.expect; +import static org.easymock.EasyMock.expectLastCall; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; +import static org.powermock.api.easymock.PowerMock.mockStatic; +import static org.powermock.api.easymock.PowerMock.replayAll; + +@RunWith(PowerMockRunner.class) +@PrepareForTest(Utils.class) +public class StateManagerUtilTest { + + @Mock(type = MockType.NICE) + private ProcessorStateManager stateManager; + + @Mock(type = MockType.NICE) + private StateDirectory stateDirectory; + + @Mock(type = MockType.NICE) + private ProcessorTopology topology; + + @Mock(type = MockType.NICE) + private InternalProcessorContext processorContext; + + private IMocksControl ctrl; + + private Logger logger = new LogContext("test").logger(AbstractTask.class); + + private final TaskId taskId = new TaskId(0, 0); + + @Before + public void setup() { + ctrl = createStrictControl(); + topology = ctrl.createMock(ProcessorTopology.class); + processorContext = ctrl.createMock(InternalProcessorContext.class); + + stateManager = ctrl.createMock(ProcessorStateManager.class); + stateDirectory = ctrl.createMock(StateDirectory.class); + } + + @Test + public void testRegisterStateStoreWhenTopologyEmpty() { + expect(topology.stateStores()).andReturn(emptyList()); + + ctrl.checkOrder(true); + ctrl.replay(); + + StateManagerUtil.registerStateStores(logger, + "logPrefix:", topology, stateManager, stateDirectory, processorContext); + + ctrl.verify(); + } + + @Test + public void testRegisterStateStoreFailToLockStateDirectory() { + expect(topology.stateStores()).andReturn(singletonList(new MockKeyValueStore("store", false))); + + expect(stateManager.taskId()).andReturn(taskId); + + expect(stateDirectory.lock(taskId)).andReturn(false); + + ctrl.checkOrder(true); + ctrl.replay(); + + final LockException thrown = assertThrows(LockException.class, + () -> StateManagerUtil.registerStateStores(logger, "logPrefix:", + topology, stateManager, stateDirectory, processorContext)); + + assertEquals("logPrefix:Failed to lock the state directory for task 0_0", thrown.getMessage()); + + ctrl.verify(); + } + + @Test + public void testRegisterStateStores() { + final MockKeyValueStore store1 = new MockKeyValueStore("store1", false); + final MockKeyValueStore store2 = new MockKeyValueStore("store2", false); + final List stateStores = Arrays.asList(store1, 
store2); + + expect(topology.stateStores()).andReturn(stateStores); + + expect(stateManager.taskId()).andReturn(taskId); + + expect(stateDirectory.lock(taskId)).andReturn(true); + expect(stateDirectory.directoryForTaskIsEmpty(taskId)).andReturn(true); + + expect(topology.stateStores()).andReturn(stateStores); + + stateManager.registerStateStores(stateStores, processorContext); + + stateManager.initializeStoreOffsetsFromCheckpoint(true); + expectLastCall(); + + ctrl.checkOrder(true); + ctrl.replay(); + + StateManagerUtil.registerStateStores(logger, "logPrefix:", + topology, stateManager, stateDirectory, processorContext); + + ctrl.verify(); + } + + @Test + public void testCloseStateManagerClean() { + expect(stateManager.taskId()).andReturn(taskId); + + expect(stateDirectory.lock(taskId)).andReturn(true); + + stateManager.close(); + expectLastCall(); + + stateDirectory.unlock(taskId); + expectLastCall(); + + ctrl.checkOrder(true); + ctrl.replay(); + + StateManagerUtil.closeStateManager(logger, + "logPrefix:", true, false, stateManager, stateDirectory, TaskType.ACTIVE); + + ctrl.verify(); + } + + @Test + public void testCloseStateManagerThrowsExceptionWhenClean() { + expect(stateManager.taskId()).andReturn(taskId); + + expect(stateDirectory.lock(taskId)).andReturn(true); + + stateManager.close(); + expectLastCall().andThrow(new ProcessorStateException("state manager failed to close")); + + // The unlock logic should still be executed. + stateDirectory.unlock(taskId); + + ctrl.checkOrder(true); + ctrl.replay(); + + final ProcessorStateException thrown = assertThrows( + ProcessorStateException.class, () -> StateManagerUtil.closeStateManager(logger, + "logPrefix:", true, false, stateManager, stateDirectory, TaskType.ACTIVE)); + + // Thrown stateMgr exception will not be wrapped. + assertEquals("state manager failed to close", thrown.getMessage()); + + ctrl.verify(); + } + + @Test + public void testCloseStateManagerThrowsExceptionWhenDirty() { + expect(stateManager.taskId()).andReturn(taskId); + + expect(stateDirectory.lock(taskId)).andReturn(true); + + stateManager.close(); + expectLastCall().andThrow(new ProcessorStateException("state manager failed to close")); + + stateDirectory.unlock(taskId); + + ctrl.checkOrder(true); + ctrl.replay(); + + assertThrows( + ProcessorStateException.class, + () -> StateManagerUtil.closeStateManager( + logger, "logPrefix:", false, false, stateManager, stateDirectory, TaskType.ACTIVE)); + + ctrl.verify(); + } + + @Test + public void testCloseStateManagerWithStateStoreWipeOut() { + expect(stateManager.taskId()).andReturn(taskId); + expect(stateDirectory.lock(taskId)).andReturn(true); + + stateManager.close(); + expectLastCall(); + + // The `baseDir` will be accessed when attempting to delete the state store. 
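+ // the wipe-out path deletes everything under baseDir (see the Utils.delete expectations in the tests below), so a real temporary directory is returned here.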
+ expect(stateManager.baseDir()).andReturn(TestUtils.tempDirectory("state_store")); + + stateDirectory.unlock(taskId); + expectLastCall(); + + ctrl.checkOrder(true); + ctrl.replay(); + + StateManagerUtil.closeStateManager(logger, + "logPrefix:", false, true, stateManager, stateDirectory, TaskType.ACTIVE); + + ctrl.verify(); + } + + @Test + public void shouldStillWipeStateStoresIfCloseThrowsException() throws IOException { + final File randomFile = new File("/random/path"); + mockStatic(Utils.class); + + expect(stateManager.taskId()).andReturn(taskId); + expect(stateDirectory.lock(taskId)).andReturn(true); + + stateManager.close(); + expectLastCall().andThrow(new ProcessorStateException("Close failed")); + + expect(stateManager.baseDir()).andReturn(randomFile); + + Utils.delete(randomFile); + + stateDirectory.unlock(taskId); + expectLastCall(); + + ctrl.checkOrder(true); + ctrl.replay(); + + replayAll(); + + assertThrows(ProcessorStateException.class, () -> + StateManagerUtil.closeStateManager(logger, "logPrefix:", false, true, stateManager, stateDirectory, TaskType.ACTIVE)); + + ctrl.verify(); + } + + @Test + public void testCloseStateManagerWithStateStoreWipeOutRethrowWrappedIOException() throws IOException { + final File unknownFile = new File("/unknown/path"); + mockStatic(Utils.class); + + expect(stateManager.taskId()).andReturn(taskId); + expect(stateDirectory.lock(taskId)).andReturn(true); + + stateManager.close(); + expectLastCall(); + + expect(stateManager.baseDir()).andReturn(unknownFile); + + Utils.delete(unknownFile); + expectLastCall().andThrow(new IOException("Deletion failed")); + + stateDirectory.unlock(taskId); + expectLastCall(); + + ctrl.checkOrder(true); + ctrl.replay(); + + replayAll(); + + final ProcessorStateException thrown = assertThrows( + ProcessorStateException.class, () -> StateManagerUtil.closeStateManager(logger, + "logPrefix:", false, true, stateManager, stateDirectory, TaskType.ACTIVE)); + + assertEquals(IOException.class, thrown.getCause().getClass()); + + ctrl.verify(); + } + + @Test + public void shouldNotCloseStateManagerIfUnableToLockTaskDirectory() { + expect(stateManager.taskId()).andReturn(taskId); + + expect(stateDirectory.lock(taskId)).andReturn(false); + + stateManager.close(); + expectLastCall().andThrow(new AssertionError("Should not be trying to close state you don't own!")); + + ctrl.checkOrder(true); + ctrl.replay(); + + replayAll(); + + StateManagerUtil.closeStateManager( + logger, "logPrefix:", true, false, stateManager, stateDirectory, TaskType.ACTIVE); + } + + @Test + public void shouldNotWipeStateStoresIfUnableToLockTaskDirectory() throws IOException { + final File unknownFile = new File("/unknown/path"); + expect(stateManager.taskId()).andReturn(taskId); + expect(stateDirectory.lock(taskId)).andReturn(false); + + expect(stateManager.baseDir()).andReturn(unknownFile); + + Utils.delete(unknownFile); + expectLastCall().andThrow(new AssertionError("Should not be trying to wipe state you don't own!")); + + ctrl.checkOrder(true); + ctrl.replay(); + + replayAll(); + + StateManagerUtil.closeStateManager( + logger, "logPrefix:", false, true, stateManager, stateDirectory, TaskType.ACTIVE); + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StateRestoreCallbackAdapterTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StateRestoreCallbackAdapterTest.java new file mode 100644 index 0000000..258bf1d --- /dev/null +++ 
b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StateRestoreCallbackAdapterTest.java @@ -0,0 +1,148 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.streams.processor.internals; + + +import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.processor.BatchingStateRestoreCallback; +import org.apache.kafka.streams.processor.StateRestoreCallback; +import org.junit.Test; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; + +import static java.util.Arrays.asList; +import static org.apache.kafka.streams.processor.internals.StateRestoreCallbackAdapter.adapt; +import static org.easymock.EasyMock.mock; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.core.Is.is; +import static org.junit.Assert.assertThrows; + +public class StateRestoreCallbackAdapterTest { + @Test + public void shouldThrowOnRestoreAll() { + assertThrows(UnsupportedOperationException.class, () -> adapt(mock(StateRestoreCallback.class)).restoreAll(null)); + } + + @Test + public void shouldThrowOnRestore() { + assertThrows(UnsupportedOperationException.class, () -> adapt(mock(StateRestoreCallback.class)).restore(null, null)); + } + + @Test + public void shouldPassRecordsThrough() { + final ArrayList<ConsumerRecord<byte[], byte[]>> actual = new ArrayList<>(); + final RecordBatchingStateRestoreCallback callback = actual::addAll; + + final RecordBatchingStateRestoreCallback adapted = adapt(callback); + + final byte[] key1 = {1}; + final byte[] value1 = {2}; + final byte[] key2 = {3}; + final byte[] value2 = {4}; + + final List<ConsumerRecord<byte[], byte[]>> recordList = asList( + new ConsumerRecord<>("topic1", 0, 0L, key1, value1), + new ConsumerRecord<>("topic2", 1, 1L, key2, value2) + ); + + adapted.restoreBatch(recordList); + + validate(actual, recordList); + } + + @Test + public void shouldConvertToKeyValueBatches() { + final ArrayList<KeyValue<byte[], byte[]>> actual = new ArrayList<>(); + final BatchingStateRestoreCallback callback = new BatchingStateRestoreCallback() { + @Override + public void restoreAll(final Collection<KeyValue<byte[], byte[]>> records) { + actual.addAll(records); + } + + @Override + public void restore(final byte[] key, final byte[] value) { + // unreachable + } + }; + + final RecordBatchingStateRestoreCallback adapted = adapt(callback); + + final byte[] key1 = {1}; + final byte[] value1 = {2}; + final byte[] key2 = {3}; + final byte[] value2 = {4}; + adapted.restoreBatch(asList( + new ConsumerRecord<>("topic1", 0, 0L, key1, value1), + new ConsumerRecord<>("topic2", 1, 1L, key2, value2) + )); + + assertThat( + actual, + is(asList( + new KeyValue<>(key1, value1), + new KeyValue<>(key2, value2) + )) + ); + } + + @Test + public void shouldConvertToKeyValue() { + final ArrayList<KeyValue<byte[], byte[]>>
actual = new ArrayList<>(); + final StateRestoreCallback callback = (key, value) -> actual.add(new KeyValue<>(key, value)); + + final RecordBatchingStateRestoreCallback adapted = adapt(callback); + + final byte[] key1 = {1}; + final byte[] value1 = {2}; + final byte[] key2 = {3}; + final byte[] value2 = {4}; + adapted.restoreBatch(asList( + new ConsumerRecord<>("topic1", 0, 0L, key1, value1), + new ConsumerRecord<>("topic2", 1, 1L, key2, value2) + )); + + assertThat( + actual, + is(asList( + new KeyValue<>(key1, value1), + new KeyValue<>(key2, value2) + )) + ); + } + + private void validate(final List<ConsumerRecord<byte[], byte[]>> actual, + final List<ConsumerRecord<byte[], byte[]>> expected) { + assertThat(actual.size(), is(expected.size())); + for (int i = 0; i < actual.size(); i++) { + final ConsumerRecord<byte[], byte[]> actual1 = actual.get(i); + final ConsumerRecord<byte[], byte[]> expected1 = expected.get(i); + assertThat(actual1.topic(), is(expected1.topic())); + assertThat(actual1.partition(), is(expected1.partition())); + assertThat(actual1.offset(), is(expected1.offset())); + assertThat(actual1.key(), is(expected1.key())); + assertThat(actual1.value(), is(expected1.value())); + assertThat(actual1.timestamp(), is(expected1.timestamp())); + assertThat(actual1.headers(), is(expected1.headers())); + } + } + + +} \ No newline at end of file diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StoreChangelogReaderTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StoreChangelogReaderTest.java new file mode 100644 index 0000000..594fc7e --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StoreChangelogReaderTest.java @@ -0,0 +1,1175 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ +package org.apache.kafka.streams.processor.internals; + +import org.apache.kafka.clients.admin.ListOffsetsOptions; +import org.apache.kafka.clients.admin.ListOffsetsResult; +import org.apache.kafka.clients.admin.MockAdminClient; +import org.apache.kafka.clients.admin.OffsetSpec; +import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.clients.consumer.MockConsumer; +import org.apache.kafka.clients.consumer.OffsetAndMetadata; +import org.apache.kafka.clients.consumer.OffsetResetStrategy; +import org.apache.kafka.common.KafkaException; +import org.apache.kafka.common.PartitionInfo; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.errors.TimeoutException; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.errors.StreamsException; +import org.apache.kafka.streams.processor.StateStore; +import org.apache.kafka.streams.processor.TaskId; +import org.apache.kafka.streams.processor.internals.ProcessorStateManager.StateStoreMetadata; +import org.apache.kafka.streams.processor.internals.testutil.LogCaptureAppender; +import org.apache.kafka.test.MockStateRestoreListener; +import org.apache.kafka.test.StreamsTestUtils; +import org.easymock.EasyMock; +import org.easymock.EasyMockRule; +import org.easymock.EasyMockSupport; +import org.easymock.Mock; +import org.easymock.MockType; +import org.junit.After; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import java.time.Duration; +import java.util.Collections; +import java.util.Map; +import java.util.Properties; +import java.util.Set; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicLong; +import java.util.function.Function; +import java.util.stream.Collectors; + +import static java.util.Collections.singletonMap; +import static org.apache.kafka.common.utils.Utils.mkEntry; +import static org.apache.kafka.common.utils.Utils.mkMap; +import static org.apache.kafka.common.utils.Utils.mkSet; +import static org.apache.kafka.streams.processor.internals.StoreChangelogReader.ChangelogReaderState.ACTIVE_RESTORING; +import static org.apache.kafka.streams.processor.internals.StoreChangelogReader.ChangelogReaderState.STANDBY_UPDATING; +import static org.apache.kafka.streams.processor.internals.Task.TaskType.ACTIVE; +import static org.apache.kafka.streams.processor.internals.Task.TaskType.STANDBY; +import static org.apache.kafka.test.MockStateRestoreListener.RESTORE_BATCH; +import static org.apache.kafka.test.MockStateRestoreListener.RESTORE_END; +import static org.apache.kafka.test.MockStateRestoreListener.RESTORE_START; +import static org.easymock.EasyMock.anyLong; +import static org.easymock.EasyMock.anyObject; +import static org.easymock.EasyMock.expectLastCall; +import static org.easymock.EasyMock.replay; +import static org.easymock.EasyMock.resetToDefault; +import static org.easymock.EasyMock.verify; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasItem; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; + +@RunWith(Parameterized.class) +public class 
StoreChangelogReaderTest extends EasyMockSupport { + + @Rule + public EasyMockRule rule = new EasyMockRule(this); + + @Mock(type = MockType.NICE) + private ProcessorStateManager stateManager; + @Mock(type = MockType.NICE) + private ProcessorStateManager activeStateManager; + @Mock(type = MockType.NICE) + private ProcessorStateManager standbyStateManager; + @Mock(type = MockType.NICE) + private StateStoreMetadata storeMetadata; + @Mock(type = MockType.NICE) + private StateStoreMetadata storeMetadataOne; + @Mock(type = MockType.NICE) + private StateStoreMetadata storeMetadataTwo; + @Mock(type = MockType.NICE) + private StateStore store; + + @Parameterized.Parameters + public static Object[] data() { + return new Object[] {STANDBY, ACTIVE}; + } + + @Parameterized.Parameter + public Task.TaskType type; + + private final String storeName = "store"; + private final String topicName = "topic"; + private final LogContext logContext = new LogContext("test-reader "); + private final TopicPartition tp = new TopicPartition(topicName, 0); + private final TopicPartition tp1 = new TopicPartition("one", 0); + private final TopicPartition tp2 = new TopicPartition("two", 0); + private final StreamsConfig config = new StreamsConfig(StreamsTestUtils.getStreamsConfig("test-reader")); + private final MockTime time = new MockTime(); + private final MockStateRestoreListener callback = new MockStateRestoreListener(); + private final KafkaException kaboom = new KafkaException("KABOOM!"); + private final MockStateRestoreListener exceptionCallback = new MockStateRestoreListener() { + @Override + public void onRestoreStart(final TopicPartition tp, final String store, final long stOffset, final long edOffset) { + throw kaboom; + } + + @Override + public void onBatchRestored(final TopicPartition tp, final String store, final long bedOffset, final long numRestored) { + throw kaboom; + } + + @Override + public void onRestoreEnd(final TopicPartition tp, final String store, final long totalRestored) { + throw kaboom; + } + }; + + private final MockConsumer consumer = new MockConsumer<>(OffsetResetStrategy.EARLIEST); + private final MockAdminClient adminClient = new MockAdminClient(); + private final StoreChangelogReader changelogReader = + new StoreChangelogReader(time, config, logContext, adminClient, consumer, callback); + + @Before + public void setUp() { + EasyMock.expect(stateManager.storeMetadata(tp)).andReturn(storeMetadata).anyTimes(); + EasyMock.expect(stateManager.taskType()).andReturn(type).anyTimes(); + EasyMock.expect(activeStateManager.storeMetadata(tp)).andReturn(storeMetadata).anyTimes(); + EasyMock.expect(activeStateManager.taskType()).andReturn(ACTIVE).anyTimes(); + EasyMock.expect(standbyStateManager.storeMetadata(tp)).andReturn(storeMetadata).anyTimes(); + EasyMock.expect(standbyStateManager.taskType()).andReturn(STANDBY).anyTimes(); + + EasyMock.expect(storeMetadata.changelogPartition()).andReturn(tp).anyTimes(); + EasyMock.expect(storeMetadata.store()).andReturn(store).anyTimes(); + EasyMock.expect(store.name()).andReturn(storeName).anyTimes(); + } + + @After + public void tearDown() { + EasyMock.reset( + stateManager, + activeStateManager, + standbyStateManager, + storeMetadata, + storeMetadataOne, + storeMetadataTwo, + store + ); + } + + @Test + public void shouldNotRegisterSameStoreMultipleTimes() { + EasyMock.replay(stateManager, storeMetadata); + + changelogReader.register(tp, stateManager); + + assertEquals(StoreChangelogReader.ChangelogState.REGISTERED, 
changelogReader.changelogMetadata(tp).state()); + assertNull(changelogReader.changelogMetadata(tp).endOffset()); + assertEquals(0L, changelogReader.changelogMetadata(tp).totalRestored()); + + assertThrows(IllegalStateException.class, () -> changelogReader.register(tp, stateManager)); + } + + @Test + public void shouldNotRegisterStoreWithoutMetadata() { + EasyMock.replay(stateManager, storeMetadata); + + assertThrows(IllegalStateException.class, + () -> changelogReader.register(new TopicPartition("ChangelogWithoutStoreMetadata", 0), stateManager)); + } + + @Test + public void shouldInitializeChangelogAndCheckForCompletion() { + final Map mockTasks = mock(Map.class); + EasyMock.expect(mockTasks.get(null)).andReturn(mock(Task.class)).anyTimes(); + EasyMock.expect(storeMetadata.offset()).andReturn(9L).anyTimes(); + EasyMock.replay(mockTasks, stateManager, storeMetadata, store); + + adminClient.updateEndOffsets(Collections.singletonMap(tp, 10L)); + + final StoreChangelogReader changelogReader = + new StoreChangelogReader(time, config, logContext, adminClient, consumer, callback); + + changelogReader.register(tp, stateManager); + changelogReader.restore(mockTasks); + + assertEquals( + type == ACTIVE ? + StoreChangelogReader.ChangelogState.COMPLETED : + StoreChangelogReader.ChangelogState.RESTORING, + changelogReader.changelogMetadata(tp).state() + ); + assertEquals(type == ACTIVE ? 10L : null, changelogReader.changelogMetadata(tp).endOffset()); + assertEquals(0L, changelogReader.changelogMetadata(tp).totalRestored()); + assertEquals( + type == ACTIVE ? Collections.singleton(tp) : Collections.emptySet(), + changelogReader.completedChangelogs() + ); + assertEquals(10L, consumer.position(tp)); + assertEquals(Collections.singleton(tp), consumer.paused()); + + if (type == ACTIVE) { + assertEquals(tp, callback.restoreTopicPartition); + assertEquals(storeName, callback.storeNameCalledStates.get(RESTORE_START)); + assertEquals(storeName, callback.storeNameCalledStates.get(RESTORE_END)); + assertNull(callback.storeNameCalledStates.get(RESTORE_BATCH)); + } + } + + @Test + public void shouldTriggerRestoreListenerWithOffsetZeroIfPositionThrowsTimeoutException() { + // restore listener is only triggered for active tasks + if (type == ACTIVE) { + final Map mockTasks = mock(Map.class); + EasyMock.expect(mockTasks.get(null)).andReturn(mock(Task.class)).anyTimes(); + EasyMock.expect(stateManager.changelogOffsets()).andReturn(singletonMap(tp, 5L)); + EasyMock.replay(mockTasks, stateManager, storeMetadata, store); + + adminClient.updateEndOffsets(Collections.singletonMap(tp, 10L)); + + final MockConsumer consumer = new MockConsumer(OffsetResetStrategy.EARLIEST) { + @Override + public long position(final TopicPartition partition) { + throw new TimeoutException("KABOOM!"); + } + }; + consumer.updateBeginningOffsets(Collections.singletonMap(tp, 5L)); + + final StoreChangelogReader changelogReader = + new StoreChangelogReader(time, config, logContext, adminClient, consumer, callback); + + changelogReader.register(tp, stateManager); + changelogReader.restore(mockTasks); + + assertThat(callback.restoreStartOffset, equalTo(0L)); + } + } + + @Test + public void shouldPollWithRightTimeout() { + final TaskId taskId = new TaskId(0, 0); + + EasyMock.expect(storeMetadata.offset()).andReturn(null).andReturn(9L).anyTimes(); + EasyMock.expect(stateManager.changelogOffsets()).andReturn(singletonMap(tp, 5L)); + EasyMock.expect(stateManager.taskId()).andReturn(taskId); + EasyMock.replay(stateManager, storeMetadata, store); + + 
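// per the assertions at the end of this test, active restoration is expected to poll with the configured poll.ms timeout while standby updating polls with Duration.ZERO +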
consumer.updateBeginningOffsets(Collections.singletonMap(tp, 5L)); + adminClient.updateEndOffsets(Collections.singletonMap(tp, 11L)); + + final StoreChangelogReader changelogReader = + new StoreChangelogReader(time, config, logContext, adminClient, consumer, callback); + + changelogReader.register(tp, stateManager); + + if (type == STANDBY) { + changelogReader.transitToUpdateStandby(); + } + + changelogReader.restore(Collections.singletonMap(taskId, mock(Task.class))); + + if (type == ACTIVE) { + assertEquals(Duration.ofMillis(config.getLong(StreamsConfig.POLL_MS_CONFIG)), consumer.lastPollTimeout()); + } else { + assertEquals(Duration.ZERO, consumer.lastPollTimeout()); + } + } + + @Test + public void shouldRestoreFromPositionAndCheckForCompletion() { + final TaskId taskId = new TaskId(0, 0); + + EasyMock.expect(storeMetadata.offset()).andReturn(5L).anyTimes(); + EasyMock.expect(stateManager.changelogOffsets()).andReturn(singletonMap(tp, 5L)); + EasyMock.expect(stateManager.taskId()).andReturn(taskId).anyTimes(); + EasyMock.replay(stateManager, storeMetadata, store); + + adminClient.updateEndOffsets(Collections.singletonMap(tp, 10L)); + + final StoreChangelogReader changelogReader = + new StoreChangelogReader(time, config, logContext, adminClient, consumer, callback); + + changelogReader.register(tp, stateManager); + + if (type == STANDBY) { + changelogReader.transitToUpdateStandby(); + } + + changelogReader.restore(Collections.singletonMap(taskId, mock(Task.class))); + + assertEquals(StoreChangelogReader.ChangelogState.RESTORING, changelogReader.changelogMetadata(tp).state()); + assertEquals(0L, changelogReader.changelogMetadata(tp).totalRestored()); + assertTrue(changelogReader.completedChangelogs().isEmpty()); + assertEquals(6L, consumer.position(tp)); + assertEquals(Collections.emptySet(), consumer.paused()); + + if (type == ACTIVE) { + assertEquals(10L, (long) changelogReader.changelogMetadata(tp).endOffset()); + + assertEquals(tp, callback.restoreTopicPartition); + assertEquals(storeName, callback.storeNameCalledStates.get(RESTORE_START)); + assertNull(callback.storeNameCalledStates.get(RESTORE_END)); + assertNull(callback.storeNameCalledStates.get(RESTORE_BATCH)); + } else { + assertNull(changelogReader.changelogMetadata(tp).endOffset()); + } + + consumer.addRecord(new ConsumerRecord<>(topicName, 0, 6L, "key".getBytes(), "value".getBytes())); + consumer.addRecord(new ConsumerRecord<>(topicName, 0, 7L, "key".getBytes(), "value".getBytes())); + // null key should be ignored + consumer.addRecord(new ConsumerRecord<>(topicName, 0, 8L, null, "value".getBytes())); + consumer.addRecord(new ConsumerRecord<>(topicName, 0, 9L, "key".getBytes(), "value".getBytes())); + // beyond end records should be skipped even when there's gap at the end offset + consumer.addRecord(new ConsumerRecord<>(topicName, 0, 11L, "key".getBytes(), "value".getBytes())); + + changelogReader.restore(Collections.singletonMap(taskId, mock(Task.class))); + + assertEquals(12L, consumer.position(tp)); + + if (type == ACTIVE) { + assertEquals(StoreChangelogReader.ChangelogState.COMPLETED, changelogReader.changelogMetadata(tp).state()); + assertEquals(3L, changelogReader.changelogMetadata(tp).totalRestored()); + assertEquals(1, changelogReader.changelogMetadata(tp).bufferedRecords().size()); + assertEquals(Collections.singleton(tp), changelogReader.completedChangelogs()); + assertEquals(Collections.singleton(tp), consumer.paused()); + + assertEquals(storeName, callback.storeNameCalledStates.get(RESTORE_BATCH)); + 
assertEquals(storeName, callback.storeNameCalledStates.get(RESTORE_END)); + } else { + assertEquals(StoreChangelogReader.ChangelogState.RESTORING, changelogReader.changelogMetadata(tp).state()); + assertEquals(4L, changelogReader.changelogMetadata(tp).totalRestored()); + assertEquals(0, changelogReader.changelogMetadata(tp).bufferedRecords().size()); + assertEquals(Collections.emptySet(), changelogReader.completedChangelogs()); + assertEquals(Collections.emptySet(), consumer.paused()); + } + } + + @Test + public void shouldRestoreFromBeginningAndCheckCompletion() { + final TaskId taskId = new TaskId(0, 0); + + EasyMock.expect(storeMetadata.offset()).andReturn(null).andReturn(9L).anyTimes(); + EasyMock.expect(stateManager.changelogOffsets()).andReturn(singletonMap(tp, 5L)); + EasyMock.expect(stateManager.taskId()).andReturn(taskId).anyTimes(); + EasyMock.replay(stateManager, storeMetadata, store); + + consumer.updateBeginningOffsets(Collections.singletonMap(tp, 5L)); + adminClient.updateEndOffsets(Collections.singletonMap(tp, 11L)); + + final StoreChangelogReader changelogReader = + new StoreChangelogReader(time, config, logContext, adminClient, consumer, callback); + + changelogReader.register(tp, stateManager); + + if (type == STANDBY) { + changelogReader.transitToUpdateStandby(); + } + + changelogReader.restore(Collections.singletonMap(taskId, mock(Task.class))); + + assertEquals(StoreChangelogReader.ChangelogState.RESTORING, changelogReader.changelogMetadata(tp).state()); + assertEquals(0L, changelogReader.changelogMetadata(tp).totalRestored()); + assertEquals(5L, consumer.position(tp)); + assertEquals(Collections.emptySet(), consumer.paused()); + + if (type == ACTIVE) { + assertEquals(11L, (long) changelogReader.changelogMetadata(tp).endOffset()); + + assertEquals(tp, callback.restoreTopicPartition); + assertEquals(storeName, callback.storeNameCalledStates.get(RESTORE_START)); + assertNull(callback.storeNameCalledStates.get(RESTORE_END)); + assertNull(callback.storeNameCalledStates.get(RESTORE_BATCH)); + } else { + assertNull(changelogReader.changelogMetadata(tp).endOffset()); + } + + consumer.addRecord(new ConsumerRecord<>(topicName, 0, 6L, "key".getBytes(), "value".getBytes())); + consumer.addRecord(new ConsumerRecord<>(topicName, 0, 7L, "key".getBytes(), "value".getBytes())); + // null key should be ignored + consumer.addRecord(new ConsumerRecord<>(topicName, 0, 8L, null, "value".getBytes())); + consumer.addRecord(new ConsumerRecord<>(topicName, 0, 9L, "key".getBytes(), "value".getBytes())); + + changelogReader.restore(Collections.singletonMap(taskId, mock(Task.class))); + + assertEquals(StoreChangelogReader.ChangelogState.RESTORING, changelogReader.changelogMetadata(tp).state()); + assertEquals(3L, changelogReader.changelogMetadata(tp).totalRestored()); + assertEquals(0, changelogReader.changelogMetadata(tp).bufferedRecords().size()); + assertEquals(0, changelogReader.changelogMetadata(tp).bufferedLimitIndex()); + + // consumer position bypassing the gap in the next poll + consumer.seek(tp, 11L); + + changelogReader.restore(Collections.singletonMap(taskId, mock(Task.class))); + + assertEquals(11L, consumer.position(tp)); + assertEquals(3L, changelogReader.changelogMetadata(tp).totalRestored()); + + if (type == ACTIVE) { + assertEquals(StoreChangelogReader.ChangelogState.COMPLETED, changelogReader.changelogMetadata(tp).state()); + assertEquals(3L, changelogReader.changelogMetadata(tp).totalRestored()); + assertEquals(Collections.singleton(tp), changelogReader.completedChangelogs()); 
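+            // completed changelog partitions should be paused on the restore consumer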
+ assertEquals(Collections.singleton(tp), consumer.paused()); + + assertEquals(storeName, callback.storeNameCalledStates.get(RESTORE_BATCH)); + assertEquals(storeName, callback.storeNameCalledStates.get(RESTORE_END)); + } else { + assertEquals(StoreChangelogReader.ChangelogState.RESTORING, changelogReader.changelogMetadata(tp).state()); + assertEquals(Collections.emptySet(), changelogReader.completedChangelogs()); + assertEquals(Collections.emptySet(), consumer.paused()); + } + } + + @Test + public void shouldCheckCompletionIfPositionLargerThanEndOffset() { + final Map mockTasks = mock(Map.class); + EasyMock.expect(mockTasks.get(null)).andReturn(mock(Task.class)).anyTimes(); + EasyMock.expect(storeMetadata.offset()).andReturn(5L).anyTimes(); + EasyMock.replay(mockTasks, activeStateManager, storeMetadata, store); + + adminClient.updateEndOffsets(Collections.singletonMap(tp, 0L)); + + final StoreChangelogReader changelogReader = + new StoreChangelogReader(time, config, logContext, adminClient, consumer, callback); + + changelogReader.register(tp, activeStateManager); + changelogReader.restore(mockTasks); + + assertEquals(StoreChangelogReader.ChangelogState.COMPLETED, changelogReader.changelogMetadata(tp).state()); + assertEquals(0L, (long) changelogReader.changelogMetadata(tp).endOffset()); + assertEquals(0L, changelogReader.changelogMetadata(tp).totalRestored()); + assertEquals(Collections.singleton(tp), changelogReader.completedChangelogs()); + assertEquals(6L, consumer.position(tp)); + assertEquals(Collections.singleton(tp), consumer.paused()); + assertEquals(tp, callback.restoreTopicPartition); + assertEquals(storeName, callback.storeNameCalledStates.get(RESTORE_START)); + assertEquals(storeName, callback.storeNameCalledStates.get(RESTORE_END)); + assertNull(callback.storeNameCalledStates.get(RESTORE_BATCH)); + } + + @Test + public void shouldRequestPositionAndHandleTimeoutException() { + final TaskId taskId = new TaskId(0, 0); + + final Task mockTask = mock(Task.class); + mockTask.clearTaskTimeout(); + mockTask.maybeInitTaskTimeoutOrThrow(anyLong(), anyObject()); + EasyMock.expectLastCall(); + EasyMock.expect(storeMetadata.offset()).andReturn(10L).anyTimes(); + EasyMock.expect(activeStateManager.changelogOffsets()).andReturn(singletonMap(tp, 10L)); + EasyMock.expect(activeStateManager.taskId()).andReturn(taskId).anyTimes(); + EasyMock.replay(mockTask, activeStateManager, storeMetadata, store); + + final AtomicBoolean clearException = new AtomicBoolean(false); + final MockConsumer consumer = new MockConsumer(OffsetResetStrategy.EARLIEST) { + @Override + public long position(final TopicPartition partition) { + if (clearException.get()) { + return 10L; + } else { + throw new TimeoutException("KABOOM!"); + } + } + }; + + adminClient.updateEndOffsets(Collections.singletonMap(tp, 10L)); + + final StoreChangelogReader changelogReader = + new StoreChangelogReader(time, config, logContext, adminClient, consumer, callback); + + changelogReader.register(tp, activeStateManager); + changelogReader.restore(Collections.singletonMap(taskId, mockTask)); + + assertEquals(StoreChangelogReader.ChangelogState.RESTORING, changelogReader.changelogMetadata(tp).state()); + assertTrue(changelogReader.completedChangelogs().isEmpty()); + assertEquals(10L, (long) changelogReader.changelogMetadata(tp).endOffset()); + verify(mockTask); + + clearException.set(true); + resetToDefault(mockTask); + mockTask.clearTaskTimeout(); + EasyMock.expectLastCall(); + EasyMock.replay(mockTask); + 
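+        // once position() no longer times out, the task timeout should be cleared and restoration can complete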
changelogReader.restore(Collections.singletonMap(taskId, mockTask)); + + assertEquals(StoreChangelogReader.ChangelogState.COMPLETED, changelogReader.changelogMetadata(tp).state()); + assertEquals(10L, (long) changelogReader.changelogMetadata(tp).endOffset()); + assertEquals(Collections.singleton(tp), changelogReader.completedChangelogs()); + assertEquals(10L, consumer.position(tp)); + verify(mockTask); + } + + @Test + public void shouldThrowIfPositionFail() { + final TaskId taskId = new TaskId(0, 0); + EasyMock.expect(activeStateManager.taskId()).andReturn(taskId); + EasyMock.expect(storeMetadata.offset()).andReturn(10L).anyTimes(); + EasyMock.replay(activeStateManager, storeMetadata, store); + + final MockConsumer consumer = new MockConsumer(OffsetResetStrategy.EARLIEST) { + @Override + public long position(final TopicPartition partition) { + throw kaboom; + } + }; + + adminClient.updateEndOffsets(Collections.singletonMap(tp, 10L)); + + final StoreChangelogReader changelogReader = + new StoreChangelogReader(time, config, logContext, adminClient, consumer, callback); + + changelogReader.register(tp, activeStateManager); + + final StreamsException thrown = assertThrows( + StreamsException.class, + () -> changelogReader.restore(Collections.singletonMap(taskId, mock(Task.class))) + ); + assertEquals(kaboom, thrown.getCause()); + } + + @Test + public void shouldRequestEndOffsetsAndHandleTimeoutException() { + final TaskId taskId = new TaskId(0, 0); + + final Task mockTask = mock(Task.class); + mockTask.maybeInitTaskTimeoutOrThrow(anyLong(), anyObject()); + EasyMock.expectLastCall(); + + EasyMock.expect(storeMetadata.offset()).andReturn(5L).anyTimes(); + EasyMock.expect(activeStateManager.changelogOffsets()).andReturn(singletonMap(tp, 5L)); + EasyMock.expect(activeStateManager.taskId()).andReturn(taskId).anyTimes(); + EasyMock.replay(mockTask, activeStateManager, storeMetadata, store); + + final AtomicBoolean functionCalled = new AtomicBoolean(false); + + final MockAdminClient adminClient = new MockAdminClient() { + @Override + public ListOffsetsResult listOffsets(final Map topicPartitionOffsets, + final ListOffsetsOptions options) { + if (functionCalled.get()) { + return super.listOffsets(topicPartitionOffsets, options); + } else { + functionCalled.set(true); + throw new TimeoutException("KABOOM!"); + } + } + }; + adminClient.updateEndOffsets(Collections.singletonMap(tp, 10L)); + + final MockConsumer consumer = new MockConsumer(OffsetResetStrategy.EARLIEST) { + @Override + public Map committed(final Set partitions) { + throw new AssertionError("Should not trigger this function"); + } + }; + + final StoreChangelogReader changelogReader = + new StoreChangelogReader(time, config, logContext, adminClient, consumer, callback); + + changelogReader.register(tp, activeStateManager); + changelogReader.restore(Collections.singletonMap(taskId, mockTask)); + + assertEquals(StoreChangelogReader.ChangelogState.REGISTERED, changelogReader.changelogMetadata(tp).state()); + assertNull(changelogReader.changelogMetadata(tp).endOffset()); + assertTrue(functionCalled.get()); + verify(mockTask); + + EasyMock.resetToDefault(mockTask); + mockTask.clearTaskTimeout(); + EasyMock.replay(mockTask); + + changelogReader.restore(Collections.singletonMap(taskId, mockTask)); + + assertEquals(StoreChangelogReader.ChangelogState.RESTORING, changelogReader.changelogMetadata(tp).state()); + assertEquals(10L, (long) changelogReader.changelogMetadata(tp).endOffset()); + assertEquals(6L, consumer.position(tp)); + verify(mockTask); + 
} + + @Test + public void shouldThrowIfEndOffsetsFail() { + EasyMock.expect(storeMetadata.offset()).andReturn(10L).anyTimes(); + EasyMock.replay(activeStateManager, storeMetadata, store); + + final MockAdminClient adminClient = new MockAdminClient() { + @Override + public ListOffsetsResult listOffsets(final Map topicPartitionOffsets, + final ListOffsetsOptions options) { + throw kaboom; + } + }; + adminClient.updateEndOffsets(Collections.singletonMap(tp, 0L)); + + final StoreChangelogReader changelogReader = + new StoreChangelogReader(time, config, logContext, adminClient, consumer, callback); + + changelogReader.register(tp, activeStateManager); + + final StreamsException thrown = assertThrows( + StreamsException.class, + () -> changelogReader.restore(Collections.emptyMap()) + ); + assertEquals(kaboom, thrown.getCause()); + } + + @Test + public void shouldRequestCommittedOffsetsAndHandleTimeoutException() { + final TaskId taskId = new TaskId(0, 0); + + final Task mockTask = mock(Task.class); + if (type == ACTIVE) { + mockTask.clearTaskTimeout(); + } + mockTask.maybeInitTaskTimeoutOrThrow(anyLong(), anyObject()); + EasyMock.expectLastCall(); + + EasyMock.expect(stateManager.changelogAsSource(tp)).andReturn(true).anyTimes(); + EasyMock.expect(storeMetadata.offset()).andReturn(5L).anyTimes(); + EasyMock.expect(stateManager.changelogOffsets()).andReturn(singletonMap(tp, 5L)); + EasyMock.expect(stateManager.taskId()).andReturn(taskId).anyTimes(); + EasyMock.replay(mockTask, stateManager, storeMetadata, store); + + final AtomicBoolean functionCalled = new AtomicBoolean(false); + final MockConsumer consumer = new MockConsumer(OffsetResetStrategy.EARLIEST) { + @Override + public Map committed(final Set partitions) { + if (functionCalled.get()) { + return partitions + .stream() + .collect(Collectors.toMap(Function.identity(), partition -> new OffsetAndMetadata(10L))); + } else { + functionCalled.set(true); + throw new TimeoutException("KABOOM!"); + } + } + }; + + adminClient.updateEndOffsets(Collections.singletonMap(tp, 20L)); + + final StoreChangelogReader changelogReader = + new StoreChangelogReader(time, config, logContext, adminClient, consumer, callback); + changelogReader.setMainConsumer(consumer); + + changelogReader.register(tp, stateManager); + changelogReader.restore(Collections.singletonMap(taskId, mockTask)); + + assertEquals( + type == ACTIVE ? + StoreChangelogReader.ChangelogState.REGISTERED : + StoreChangelogReader.ChangelogState.RESTORING, + changelogReader.changelogMetadata(tp).state() + ); + if (type == ACTIVE) { + assertNull(changelogReader.changelogMetadata(tp).endOffset()); + } else { + assertEquals(0L, (long) changelogReader.changelogMetadata(tp).endOffset()); + } + assertTrue(functionCalled.get()); + verify(mockTask); + + resetToDefault(mockTask); + if (type == ACTIVE) { + mockTask.clearTaskTimeout(); + mockTask.clearTaskTimeout(); + expectLastCall(); + } + replay(mockTask); + + changelogReader.restore(Collections.singletonMap(taskId, mockTask)); + + assertEquals(StoreChangelogReader.ChangelogState.RESTORING, changelogReader.changelogMetadata(tp).state()); + assertEquals(type == ACTIVE ? 
10L : 0L, (long) changelogReader.changelogMetadata(tp).endOffset()); + assertEquals(6L, consumer.position(tp)); + verify(mockTask); + } + + @Test + public void shouldThrowIfCommittedOffsetsFail() { + final TaskId taskId = new TaskId(0, 0); + + EasyMock.expect(stateManager.taskId()).andReturn(taskId); + EasyMock.expect(stateManager.changelogAsSource(tp)).andReturn(true).anyTimes(); + EasyMock.expect(storeMetadata.offset()).andReturn(10L).anyTimes(); + EasyMock.replay(stateManager, storeMetadata, store); + + final MockConsumer consumer = new MockConsumer(OffsetResetStrategy.EARLIEST) { + @Override + public Map committed(final Set partitions) { + throw kaboom; + } + }; + + adminClient.updateEndOffsets(Collections.singletonMap(tp, 10L)); + + final StoreChangelogReader changelogReader = + new StoreChangelogReader(time, config, logContext, adminClient, consumer, callback); + changelogReader.setMainConsumer(consumer); + + changelogReader.register(tp, stateManager); + + final StreamsException thrown = assertThrows( + StreamsException.class, + () -> changelogReader.restore(Collections.singletonMap(taskId, mock(Task.class))) + ); + assertEquals(kaboom, thrown.getCause()); + } + + @Test + public void shouldThrowIfUnsubscribeFail() { + EasyMock.replay(stateManager, storeMetadata, store); + + final MockConsumer consumer = new MockConsumer(OffsetResetStrategy.EARLIEST) { + @Override + public void unsubscribe() { + throw kaboom; + } + }; + final StoreChangelogReader changelogReader = + new StoreChangelogReader(time, config, logContext, adminClient, consumer, callback); + + final StreamsException thrown = assertThrows(StreamsException.class, changelogReader::clear); + assertEquals(kaboom, thrown.getCause()); + } + + @Test + public void shouldOnlyRestoreStandbyChangelogInUpdateStandbyState() { + final Map mockTasks = mock(Map.class); + EasyMock.expect(mockTasks.get(null)).andReturn(mock(Task.class)).anyTimes(); + EasyMock.replay(mockTasks, standbyStateManager, storeMetadata, store); + + consumer.updateBeginningOffsets(Collections.singletonMap(tp, 5L)); + changelogReader.register(tp, standbyStateManager); + changelogReader.restore(mockTasks); + + assertNull(callback.restoreTopicPartition); + assertNull(callback.storeNameCalledStates.get(RESTORE_START)); + assertEquals(StoreChangelogReader.ChangelogState.RESTORING, changelogReader.changelogMetadata(tp).state()); + assertNull(changelogReader.changelogMetadata(tp).endOffset()); + assertEquals(0L, changelogReader.changelogMetadata(tp).totalRestored()); + + consumer.addRecord(new ConsumerRecord<>(topicName, 0, 6L, "key".getBytes(), "value".getBytes())); + consumer.addRecord(new ConsumerRecord<>(topicName, 0, 7L, "key".getBytes(), "value".getBytes())); + // null key should be ignored + consumer.addRecord(new ConsumerRecord<>(topicName, 0, 8L, null, "value".getBytes())); + consumer.addRecord(new ConsumerRecord<>(topicName, 0, 9L, "key".getBytes(), "value".getBytes())); + consumer.addRecord(new ConsumerRecord<>(topicName, 0, 10L, "key".getBytes(), "value".getBytes())); + consumer.addRecord(new ConsumerRecord<>(topicName, 0, 11L, "key".getBytes(), "value".getBytes())); + + changelogReader.restore(mockTasks); + assertEquals(StoreChangelogReader.ChangelogState.RESTORING, changelogReader.changelogMetadata(tp).state()); + assertEquals(0L, changelogReader.changelogMetadata(tp).totalRestored()); + assertTrue(changelogReader.changelogMetadata(tp).bufferedRecords().isEmpty()); + + assertEquals(Collections.singleton(tp), consumer.paused()); + + 
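+        // only after transiting to update-standby mode are the buffered standby records actually restored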
changelogReader.transitToUpdateStandby(); + changelogReader.restore(mockTasks); + assertEquals(StoreChangelogReader.ChangelogState.RESTORING, changelogReader.changelogMetadata(tp).state()); + assertEquals(5L, changelogReader.changelogMetadata(tp).totalRestored()); + assertTrue(changelogReader.changelogMetadata(tp).bufferedRecords().isEmpty()); + } + + @Test + public void shouldNotUpdateLimitForNonSourceStandbyChangelog() { + final Map mockTasks = mock(Map.class); + EasyMock.expect(mockTasks.get(null)).andReturn(mock(Task.class)).anyTimes(); + EasyMock.expect(standbyStateManager.changelogAsSource(tp)).andReturn(false).anyTimes(); + EasyMock.replay(mockTasks, standbyStateManager, storeMetadata, store); + + final MockConsumer consumer = new MockConsumer(OffsetResetStrategy.EARLIEST) { + @Override + public Map committed(final Set partitions) { + throw new AssertionError("Should not try to fetch committed offsets"); + } + }; + + final Properties properties = new Properties(); + properties.put(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG, 100L); + final StreamsConfig config = new StreamsConfig(StreamsTestUtils.getStreamsConfig("test-reader", properties)); + final StoreChangelogReader changelogReader = new StoreChangelogReader(time, config, logContext, adminClient, consumer, callback); + changelogReader.setMainConsumer(consumer); + changelogReader.transitToUpdateStandby(); + + consumer.updateBeginningOffsets(Collections.singletonMap(tp, 5L)); + changelogReader.register(tp, standbyStateManager); + assertNull(changelogReader.changelogMetadata(tp).endOffset()); + assertEquals(0L, changelogReader.changelogMetadata(tp).totalRestored()); + + // if there's no records fetchable, nothings gets restored + changelogReader.restore(mockTasks); + assertNull(callback.restoreTopicPartition); + assertNull(callback.storeNameCalledStates.get(RESTORE_START)); + assertEquals(StoreChangelogReader.ChangelogState.RESTORING, changelogReader.changelogMetadata(tp).state()); + assertNull(changelogReader.changelogMetadata(tp).endOffset()); + assertEquals(0L, changelogReader.changelogMetadata(tp).totalRestored()); + + consumer.addRecord(new ConsumerRecord<>(topicName, 0, 5L, "key".getBytes(), "value".getBytes())); + consumer.addRecord(new ConsumerRecord<>(topicName, 0, 6L, "key".getBytes(), "value".getBytes())); + consumer.addRecord(new ConsumerRecord<>(topicName, 0, 7L, "key".getBytes(), "value".getBytes())); + // null key should be ignored + consumer.addRecord(new ConsumerRecord<>(topicName, 0, 8L, null, "value".getBytes())); + consumer.addRecord(new ConsumerRecord<>(topicName, 0, 9L, "key".getBytes(), "value".getBytes())); + consumer.addRecord(new ConsumerRecord<>(topicName, 0, 10L, "key".getBytes(), "value".getBytes())); + consumer.addRecord(new ConsumerRecord<>(topicName, 0, 11L, "key".getBytes(), "value".getBytes())); + + // we should be able to restore to the log end offsets since there's no limit + changelogReader.restore(mockTasks); + assertEquals(StoreChangelogReader.ChangelogState.RESTORING, changelogReader.changelogMetadata(tp).state()); + assertNull(changelogReader.changelogMetadata(tp).endOffset()); + assertEquals(6L, changelogReader.changelogMetadata(tp).totalRestored()); + assertEquals(0, changelogReader.changelogMetadata(tp).bufferedRecords().size()); + assertEquals(0, changelogReader.changelogMetadata(tp).bufferedLimitIndex()); + assertNull(callback.storeNameCalledStates.get(RESTORE_END)); + assertNull(callback.storeNameCalledStates.get(RESTORE_BATCH)); + } + + @Test + public void 
shouldRestoreToLimitInStandbyState() { + final Map mockTasks = mock(Map.class); + EasyMock.expect(mockTasks.get(null)).andReturn(mock(Task.class)).anyTimes(); + EasyMock.expect(standbyStateManager.changelogAsSource(tp)).andReturn(true).anyTimes(); + EasyMock.replay(mockTasks, standbyStateManager, storeMetadata, store); + + final AtomicLong offset = new AtomicLong(7L); + final MockConsumer consumer = new MockConsumer(OffsetResetStrategy.EARLIEST) { + @Override + public Map committed(final Set partitions) { + return partitions + .stream() + .collect(Collectors.toMap(Function.identity(), partition -> new OffsetAndMetadata(offset.get()))); + } + }; + + final long now = time.milliseconds(); + final Properties properties = new Properties(); + properties.put(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG, 100L); + final StreamsConfig config = new StreamsConfig(StreamsTestUtils.getStreamsConfig("test-reader", properties)); + final StoreChangelogReader changelogReader = new StoreChangelogReader(time, config, logContext, adminClient, consumer, callback); + changelogReader.setMainConsumer(consumer); + changelogReader.transitToUpdateStandby(); + + consumer.updateBeginningOffsets(Collections.singletonMap(tp, 5L)); + changelogReader.register(tp, standbyStateManager); + assertEquals(0L, (long) changelogReader.changelogMetadata(tp).endOffset()); + assertEquals(0L, changelogReader.changelogMetadata(tp).totalRestored()); + + changelogReader.restore(mockTasks); + + assertNull(callback.restoreTopicPartition); + assertNull(callback.storeNameCalledStates.get(RESTORE_START)); + assertEquals(StoreChangelogReader.ChangelogState.RESTORING, changelogReader.changelogMetadata(tp).state()); + assertEquals(7L, (long) changelogReader.changelogMetadata(tp).endOffset()); + assertEquals(0L, changelogReader.changelogMetadata(tp).totalRestored()); + + consumer.addRecord(new ConsumerRecord<>(topicName, 0, 5L, "key".getBytes(), "value".getBytes())); + consumer.addRecord(new ConsumerRecord<>(topicName, 0, 6L, "key".getBytes(), "value".getBytes())); + consumer.addRecord(new ConsumerRecord<>(topicName, 0, 7L, "key".getBytes(), "value".getBytes())); + // null key should be ignored + consumer.addRecord(new ConsumerRecord<>(topicName, 0, 8L, null, "value".getBytes())); + consumer.addRecord(new ConsumerRecord<>(topicName, 0, 9L, "key".getBytes(), "value".getBytes())); + consumer.addRecord(new ConsumerRecord<>(topicName, 0, 10L, "key".getBytes(), "value".getBytes())); + consumer.addRecord(new ConsumerRecord<>(topicName, 0, 11L, "key".getBytes(), "value".getBytes())); + + changelogReader.restore(mockTasks); + assertEquals(StoreChangelogReader.ChangelogState.RESTORING, changelogReader.changelogMetadata(tp).state()); + assertEquals(7L, (long) changelogReader.changelogMetadata(tp).endOffset()); + assertEquals(2L, changelogReader.changelogMetadata(tp).totalRestored()); + assertEquals(4, changelogReader.changelogMetadata(tp).bufferedRecords().size()); + assertEquals(0, changelogReader.changelogMetadata(tp).bufferedLimitIndex()); + assertNull(callback.storeNameCalledStates.get(RESTORE_END)); + assertNull(callback.storeNameCalledStates.get(RESTORE_BATCH)); + + offset.set(10L); + time.setCurrentTimeMs(now + 100L); + // should not try to read committed offsets if interval has not reached + changelogReader.restore(mockTasks); + assertEquals(7L, (long) changelogReader.changelogMetadata(tp).endOffset()); + assertEquals(2L, changelogReader.changelogMetadata(tp).totalRestored()); + assertEquals(4, 
changelogReader.changelogMetadata(tp).bufferedRecords().size()); + assertEquals(0, changelogReader.changelogMetadata(tp).bufferedLimitIndex()); + + time.setCurrentTimeMs(now + 101L); + // the first restore would only update the limit, same below + changelogReader.restore(mockTasks); + assertEquals(10L, (long) changelogReader.changelogMetadata(tp).endOffset()); + assertEquals(2L, changelogReader.changelogMetadata(tp).totalRestored()); + assertEquals(4, changelogReader.changelogMetadata(tp).bufferedRecords().size()); + assertEquals(2, changelogReader.changelogMetadata(tp).bufferedLimitIndex()); + + changelogReader.restore(mockTasks); + assertEquals(10L, (long) changelogReader.changelogMetadata(tp).endOffset()); + assertEquals(4L, changelogReader.changelogMetadata(tp).totalRestored()); + assertEquals(2, changelogReader.changelogMetadata(tp).bufferedRecords().size()); + assertEquals(0, changelogReader.changelogMetadata(tp).bufferedLimitIndex()); + + offset.set(15L); + + // after we've updated once, the timer should be reset and we should not try again until next interval elapsed + time.setCurrentTimeMs(now + 201L); + changelogReader.restore(mockTasks); + assertEquals(10L, (long) changelogReader.changelogMetadata(tp).endOffset()); + assertEquals(4L, changelogReader.changelogMetadata(tp).totalRestored()); + assertEquals(2, changelogReader.changelogMetadata(tp).bufferedRecords().size()); + assertEquals(0, changelogReader.changelogMetadata(tp).bufferedLimitIndex()); + + // once we are in update active mode, we should not try to update limit offset + time.setCurrentTimeMs(now + 202L); + changelogReader.enforceRestoreActive(); + changelogReader.restore(mockTasks); + assertEquals(10L, (long) changelogReader.changelogMetadata(tp).endOffset()); + assertEquals(4L, changelogReader.changelogMetadata(tp).totalRestored()); + assertEquals(2, changelogReader.changelogMetadata(tp).bufferedRecords().size()); + assertEquals(0, changelogReader.changelogMetadata(tp).bufferedLimitIndex()); + + changelogReader.transitToUpdateStandby(); + changelogReader.restore(mockTasks); + assertEquals(15L, (long) changelogReader.changelogMetadata(tp).endOffset()); + assertEquals(4L, changelogReader.changelogMetadata(tp).totalRestored()); + assertEquals(2, changelogReader.changelogMetadata(tp).bufferedRecords().size()); + assertEquals(2, changelogReader.changelogMetadata(tp).bufferedLimitIndex()); + + changelogReader.restore(mockTasks); + assertEquals(15L, (long) changelogReader.changelogMetadata(tp).endOffset()); + assertEquals(6L, changelogReader.changelogMetadata(tp).totalRestored()); + assertEquals(0, changelogReader.changelogMetadata(tp).bufferedRecords().size()); + assertEquals(0, changelogReader.changelogMetadata(tp).bufferedLimitIndex()); + + consumer.addRecord(new ConsumerRecord<>(topicName, 0, 12L, "key".getBytes(), "value".getBytes())); + consumer.addRecord(new ConsumerRecord<>(topicName, 0, 13L, "key".getBytes(), "value".getBytes())); + consumer.addRecord(new ConsumerRecord<>(topicName, 0, 14L, "key".getBytes(), "value".getBytes())); + consumer.addRecord(new ConsumerRecord<>(topicName, 0, 15L, "key".getBytes(), "value".getBytes())); + + changelogReader.restore(mockTasks); + assertEquals(15L, (long) changelogReader.changelogMetadata(tp).endOffset()); + assertEquals(9L, changelogReader.changelogMetadata(tp).totalRestored()); + assertEquals(1, changelogReader.changelogMetadata(tp).bufferedRecords().size()); + assertEquals(0, changelogReader.changelogMetadata(tp).bufferedLimitIndex()); + } + + @Test + public void 
shouldRestoreMultipleChangelogs() { + final Map mockTasks = mock(Map.class); + EasyMock.expect(mockTasks.get(null)).andReturn(mock(Task.class)).anyTimes(); + EasyMock.expect(storeMetadataOne.changelogPartition()).andReturn(tp1).anyTimes(); + EasyMock.expect(storeMetadataOne.store()).andReturn(store).anyTimes(); + EasyMock.expect(storeMetadataTwo.changelogPartition()).andReturn(tp2).anyTimes(); + EasyMock.expect(storeMetadataTwo.store()).andReturn(store).anyTimes(); + EasyMock.expect(storeMetadata.offset()).andReturn(0L).anyTimes(); + EasyMock.expect(storeMetadataOne.offset()).andReturn(0L).anyTimes(); + EasyMock.expect(storeMetadataTwo.offset()).andReturn(0L).anyTimes(); + EasyMock.expect(activeStateManager.storeMetadata(tp1)).andReturn(storeMetadataOne).anyTimes(); + EasyMock.expect(activeStateManager.storeMetadata(tp2)).andReturn(storeMetadataTwo).anyTimes(); + EasyMock.expect(activeStateManager.changelogOffsets()).andReturn(mkMap( + mkEntry(tp, 5L), + mkEntry(tp1, 5L), + mkEntry(tp2, 5L) + )).anyTimes(); + EasyMock.replay(mockTasks, activeStateManager, storeMetadata, store, storeMetadataOne, storeMetadataTwo); + + setupConsumer(10, tp); + setupConsumer(5, tp1); + setupConsumer(3, tp2); + + changelogReader.register(tp, activeStateManager); + changelogReader.register(tp1, activeStateManager); + changelogReader.register(tp2, activeStateManager); + + changelogReader.restore(mockTasks); + + assertEquals(StoreChangelogReader.ChangelogState.RESTORING, changelogReader.changelogMetadata(tp).state()); + assertEquals(StoreChangelogReader.ChangelogState.RESTORING, changelogReader.changelogMetadata(tp1).state()); + assertEquals(StoreChangelogReader.ChangelogState.RESTORING, changelogReader.changelogMetadata(tp2).state()); + + // should support removing and clearing changelogs + changelogReader.unregister(Collections.singletonList(tp)); + assertNull(changelogReader.changelogMetadata(tp)); + assertFalse(changelogReader.isEmpty()); + assertEquals(StoreChangelogReader.ChangelogState.RESTORING, changelogReader.changelogMetadata(tp1).state()); + assertEquals(StoreChangelogReader.ChangelogState.RESTORING, changelogReader.changelogMetadata(tp2).state()); + + changelogReader.clear(); + assertTrue(changelogReader.isEmpty()); + assertNull(changelogReader.changelogMetadata(tp1)); + assertNull(changelogReader.changelogMetadata(tp2)); + } + + @Test + public void shouldTransitState() { + final TaskId taskId = new TaskId(0, 0); + EasyMock.expect(storeMetadataOne.changelogPartition()).andReturn(tp1).anyTimes(); + EasyMock.expect(storeMetadataOne.store()).andReturn(store).anyTimes(); + EasyMock.expect(storeMetadataTwo.changelogPartition()).andReturn(tp2).anyTimes(); + EasyMock.expect(storeMetadataTwo.store()).andReturn(store).anyTimes(); + EasyMock.expect(storeMetadata.offset()).andReturn(5L).anyTimes(); + EasyMock.expect(storeMetadataOne.offset()).andReturn(5L).anyTimes(); + EasyMock.expect(storeMetadataTwo.offset()).andReturn(5L).anyTimes(); + EasyMock.expect(standbyStateManager.storeMetadata(tp1)).andReturn(storeMetadataOne).anyTimes(); + EasyMock.expect(standbyStateManager.storeMetadata(tp2)).andReturn(storeMetadataTwo).anyTimes(); + EasyMock.expect(activeStateManager.changelogOffsets()).andReturn(singletonMap(tp, 5L)); + EasyMock.expect(activeStateManager.taskId()).andReturn(taskId).anyTimes(); + EasyMock.replay(activeStateManager, standbyStateManager, storeMetadata, store, storeMetadataOne, storeMetadataTwo); + + adminClient.updateEndOffsets(Collections.singletonMap(tp, 10L)); + 
adminClient.updateEndOffsets(Collections.singletonMap(tp1, 10L)); + adminClient.updateEndOffsets(Collections.singletonMap(tp2, 10L)); + final StoreChangelogReader changelogReader = new StoreChangelogReader(time, config, logContext, adminClient, consumer, callback); + assertEquals(ACTIVE_RESTORING, changelogReader.state()); + + changelogReader.register(tp, activeStateManager); + changelogReader.register(tp1, standbyStateManager); + changelogReader.register(tp2, standbyStateManager); + assertEquals(StoreChangelogReader.ChangelogState.REGISTERED, changelogReader.changelogMetadata(tp).state()); + assertEquals(StoreChangelogReader.ChangelogState.REGISTERED, changelogReader.changelogMetadata(tp1).state()); + assertEquals(StoreChangelogReader.ChangelogState.REGISTERED, changelogReader.changelogMetadata(tp2).state()); + + assertEquals(Collections.emptySet(), consumer.assignment()); + + changelogReader.restore(Collections.singletonMap(taskId, mock(Task.class))); + + assertEquals(StoreChangelogReader.ChangelogState.RESTORING, changelogReader.changelogMetadata(tp).state()); + assertEquals(StoreChangelogReader.ChangelogState.RESTORING, changelogReader.changelogMetadata(tp1).state()); + assertEquals(StoreChangelogReader.ChangelogState.RESTORING, changelogReader.changelogMetadata(tp2).state()); + assertEquals(mkSet(tp, tp1, tp2), consumer.assignment()); + assertEquals(mkSet(tp1, tp2), consumer.paused()); + assertEquals(ACTIVE_RESTORING, changelogReader.state()); + + // transition to restore active is idempotent + changelogReader.enforceRestoreActive(); + assertEquals(ACTIVE_RESTORING, changelogReader.state()); + + changelogReader.transitToUpdateStandby(); + assertEquals(STANDBY_UPDATING, changelogReader.state()); + + assertEquals(StoreChangelogReader.ChangelogState.RESTORING, changelogReader.changelogMetadata(tp).state()); + assertEquals(StoreChangelogReader.ChangelogState.RESTORING, changelogReader.changelogMetadata(tp1).state()); + assertEquals(StoreChangelogReader.ChangelogState.RESTORING, changelogReader.changelogMetadata(tp2).state()); + assertEquals(mkSet(tp, tp1, tp2), consumer.assignment()); + assertEquals(Collections.emptySet(), consumer.paused()); + + // transition to update standby is NOT idempotent + assertThrows(IllegalStateException.class, changelogReader::transitToUpdateStandby); + + changelogReader.unregister(Collections.singletonList(tp)); + changelogReader.register(tp, activeStateManager); + + // if a new active is registered, we should immediately transit to standby updating + assertThrows( + IllegalStateException.class, + () -> changelogReader.restore(Collections.singletonMap(taskId, mock(Task.class))) + ); + + assertEquals(StoreChangelogReader.ChangelogState.RESTORING, changelogReader.changelogMetadata(tp).state()); + assertEquals(StoreChangelogReader.ChangelogState.RESTORING, changelogReader.changelogMetadata(tp1).state()); + assertEquals(StoreChangelogReader.ChangelogState.RESTORING, changelogReader.changelogMetadata(tp2).state()); + assertEquals(mkSet(tp, tp1, tp2), consumer.assignment()); + assertEquals(Collections.emptySet(), consumer.paused()); + assertEquals(STANDBY_UPDATING, changelogReader.state()); + + changelogReader.enforceRestoreActive(); + assertEquals(ACTIVE_RESTORING, changelogReader.state()); + assertEquals(mkSet(tp, tp1, tp2), consumer.assignment()); + assertEquals(mkSet(tp1, tp2), consumer.paused()); + } + + @Test + public void shouldThrowIfRestoreCallbackThrows() { + final TaskId taskId = new TaskId(0, 0); + + 
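+        // the exception callback throws on restore start, batch, and end; each failure should surface as the cause of a StreamsException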
EasyMock.expect(storeMetadata.offset()).andReturn(5L).anyTimes(); + EasyMock.expect(activeStateManager.taskId()).andReturn(taskId); + EasyMock.replay(activeStateManager, storeMetadata, store); + + adminClient.updateEndOffsets(Collections.singletonMap(tp, 10L)); + + final StoreChangelogReader changelogReader = + new StoreChangelogReader(time, config, logContext, adminClient, consumer, exceptionCallback); + + changelogReader.register(tp, activeStateManager); + + StreamsException thrown = assertThrows( + StreamsException.class, + () -> changelogReader.restore(Collections.singletonMap(taskId, mock(Task.class))) + ); + assertEquals(kaboom, thrown.getCause()); + + consumer.addRecord(new ConsumerRecord<>(topicName, 0, 6L, "key".getBytes(), "value".getBytes())); + consumer.addRecord(new ConsumerRecord<>(topicName, 0, 7L, "key".getBytes(), "value".getBytes())); + + thrown = assertThrows( + StreamsException.class, + () -> changelogReader.restore(Collections.singletonMap(taskId, mock(Task.class))) + ); + assertEquals(kaboom, thrown.getCause()); + + consumer.seek(tp, 10L); + + thrown = assertThrows( + StreamsException.class, + () -> changelogReader.restore(Collections.singletonMap(taskId, mock(Task.class))) + ); + assertEquals(kaboom, thrown.getCause()); + } + + @Test + public void shouldNotThrowOnUnknownRevokedPartition() { + LogCaptureAppender.setClassLoggerToDebug(StoreChangelogReader.class); + try (final LogCaptureAppender appender = LogCaptureAppender.createAndRegister(StoreChangelogReader.class)) { + changelogReader.unregister(Collections.singletonList(new TopicPartition("unknown", 0))); + + assertThat( + appender.getMessages(), + hasItem("test-reader Changelog partition unknown-0 could not be found," + + " it could be already cleaned up during the handling of task corruption and never restore again") + ); + } + } + + private void setupConsumer(final long messages, final TopicPartition topicPartition) { + assignPartition(messages, topicPartition); + addRecords(messages, topicPartition); + consumer.assign(Collections.emptyList()); + } + + private void addRecords(final long messages, final TopicPartition topicPartition) { + for (int i = 0; i < messages; i++) { + consumer.addRecord(new ConsumerRecord<>( + topicPartition.topic(), + topicPartition.partition(), + i, + new byte[0], + new byte[0])); + } + } + + private void assignPartition(final long messages, + final TopicPartition topicPartition) { + consumer.updatePartitions( + topicPartition.topic(), + Collections.singletonList(new PartitionInfo( + topicPartition.topic(), + topicPartition.partition(), + null, + null, + null))); + consumer.updateBeginningOffsets(Collections.singletonMap(topicPartition, 0L)); + consumer.updateEndOffsets(Collections.singletonMap(topicPartition, Math.max(0, messages) + 1)); + adminClient.updateEndOffsets(Collections.singletonMap(topicPartition, Math.max(0, messages) + 1)); + consumer.assign(Collections.singletonList(topicPartition)); + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StoreToProcessorContextAdapterTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StoreToProcessorContextAdapterTest.java new file mode 100644 index 0000000..90db566 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StoreToProcessorContextAdapterTest.java @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.processor.internals; + +import org.apache.kafka.streams.processor.ProcessorContext; +import org.apache.kafka.streams.processor.PunctuationType; +import org.apache.kafka.streams.processor.Punctuator; +import org.apache.kafka.streams.processor.StateStoreContext; +import org.apache.kafka.streams.processor.To; +import org.easymock.EasyMockRunner; +import org.easymock.Mock; +import org.easymock.MockType; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; + +import java.time.Duration; + +@RunWith(EasyMockRunner.class) +public class StoreToProcessorContextAdapterTest { + @Mock(MockType.NICE) + private StateStoreContext delegate; + private ProcessorContext context; + @Mock(MockType.NICE) + private Punctuator punctuator; + + @Before + public void setUp() { + context = StoreToProcessorContextAdapter.adapt(delegate); + } + + @Test(expected = UnsupportedOperationException.class) + public void shouldThrowOnCurrentSystemTime() { + context.currentSystemTimeMs(); + } + + @Test(expected = UnsupportedOperationException.class) + public void shouldThrowOnCurrentStreamTime() { + context.currentStreamTimeMs(); + } + + @Test(expected = UnsupportedOperationException.class) + public void shouldThrowOnGetStateStore() { + context.getStateStore("store"); + } + + @Test(expected = UnsupportedOperationException.class) + public void shouldThrowOnScheduleWithDuration() { + context.schedule(Duration.ZERO, PunctuationType.WALL_CLOCK_TIME, punctuator); + } + + @Test(expected = UnsupportedOperationException.class) + public void shouldThrowOnForward() { + context.forward("key", "value"); + } + + @Test(expected = UnsupportedOperationException.class) + public void shouldThrowOnForwardWithTo() { + context.forward("key", "value", To.all()); + } + + @Test(expected = UnsupportedOperationException.class) + public void shouldThrowOnCommit() { + context.commit(); + } + + @Test(expected = UnsupportedOperationException.class) + public void shouldThrowOnTopic() { + context.topic(); + } + + @Test(expected = UnsupportedOperationException.class) + public void shouldThrowOnPartition() { + context.partition(); + } + + @Test(expected = UnsupportedOperationException.class) + public void shouldThrowOnOffset() { + context.offset(); + } + + @Test(expected = UnsupportedOperationException.class) + public void shouldThrowOnHeaders() { + context.headers(); + } + + @Test(expected = UnsupportedOperationException.class) + public void shouldThrowOnTimestamp() { + context.timestamp(); + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamTaskTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamTaskTest.java new file mode 100644 index 0000000..c372234 --- /dev/null +++ 
b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamTaskTest.java @@ -0,0 +1,2590 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.processor.internals; + +import org.apache.kafka.clients.consumer.Consumer; +import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.clients.consumer.MockConsumer; +import org.apache.kafka.clients.consumer.OffsetAndMetadata; +import org.apache.kafka.clients.consumer.OffsetResetStrategy; +import org.apache.kafka.common.KafkaException; +import org.apache.kafka.common.Metric; +import org.apache.kafka.common.MetricName; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.errors.TimeoutException; +import org.apache.kafka.common.header.internals.RecordHeaders; +import org.apache.kafka.common.metrics.JmxReporter; +import org.apache.kafka.common.metrics.KafkaMetric; +import org.apache.kafka.common.metrics.KafkaMetricsContext; +import org.apache.kafka.common.metrics.MetricConfig; +import org.apache.kafka.common.metrics.Metrics; +import org.apache.kafka.common.metrics.MetricsContext; +import org.apache.kafka.common.metrics.Sensor; +import org.apache.kafka.common.metrics.stats.CumulativeSum; +import org.apache.kafka.common.record.TimestampType; +import org.apache.kafka.common.serialization.Deserializer; +import org.apache.kafka.common.serialization.IntegerSerializer; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.serialization.Serializer; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.errors.LockException; +import org.apache.kafka.streams.errors.ProcessorStateException; +import org.apache.kafka.streams.errors.StreamsException; +import org.apache.kafka.streams.errors.TaskCorruptedException; +import org.apache.kafka.streams.errors.TaskMigratedException; +import org.apache.kafka.streams.errors.TopologyException; +import org.apache.kafka.streams.processor.PunctuationType; +import org.apache.kafka.streams.processor.Punctuator; +import org.apache.kafka.streams.processor.StateStore; +import org.apache.kafka.streams.processor.TaskId; +import org.apache.kafka.streams.processor.api.Record; +import org.apache.kafka.streams.processor.internals.Task.TaskType; +import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl; +import org.apache.kafka.streams.state.internals.ThreadCache; +import org.apache.kafka.test.MockKeyValueStore; +import org.apache.kafka.test.MockProcessorNode; +import org.apache.kafka.test.MockSourceNode; +import org.apache.kafka.test.MockTimestampExtractor; +import 
org.apache.kafka.test.TestUtils; +import org.easymock.EasyMock; +import org.easymock.EasyMockRunner; +import org.easymock.IMocksControl; +import org.easymock.Mock; +import org.easymock.MockType; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; + +import java.io.File; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.time.Duration; +import java.util.Base64; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Function; +import java.util.stream.Collectors; + +import static java.util.Arrays.asList; +import static java.util.Collections.emptyList; +import static java.util.Collections.emptyMap; +import static java.util.Collections.singleton; +import static java.util.Collections.singletonList; +import static java.util.Collections.singletonMap; +import static org.apache.kafka.common.utils.Utils.mkEntry; +import static org.apache.kafka.common.utils.Utils.mkMap; +import static org.apache.kafka.common.utils.Utils.mkProperties; +import static org.apache.kafka.common.utils.Utils.mkSet; +import static org.apache.kafka.streams.StreamsConfig.AT_LEAST_ONCE; +import static org.apache.kafka.streams.processor.internals.StreamTask.encodeTimestamp; +import static org.apache.kafka.streams.processor.internals.Task.State.CREATED; +import static org.apache.kafka.streams.processor.internals.Task.State.RESTORING; +import static org.apache.kafka.streams.processor.internals.Task.State.RUNNING; +import static org.apache.kafka.streams.processor.internals.Task.State.SUSPENDED; +import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.THREAD_ID_TAG; +import static org.apache.kafka.test.StreamsTestUtils.getMetricByNameFilterByTags; +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.CoreMatchers.nullValue; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.empty; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.isA; +import static org.hamcrest.Matchers.not; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +@RunWith(EasyMockRunner.class) +public class StreamTaskTest { + + private static final String APPLICATION_ID = "stream-task-test"; + private static final File BASE_DIR = TestUtils.tempDirectory(); + private static final long DEFAULT_TIMESTAMP = 1000; + + private final LogContext logContext = new LogContext("[test] "); + private final String topic1 = "topic1"; + private final String topic2 = "topic2"; + private final TopicPartition partition1 = new TopicPartition(topic1, 1); + private final TopicPartition partition2 = new TopicPartition(topic2, 1); + private final Set partitions = mkSet(partition1, partition2); + private final Serializer intSerializer = Serdes.Integer().serializer(); + private final Deserializer intDeserializer = Serdes.Integer().deserializer(); + + private final MockSourceNode source1 = new MockSourceNode<>(intDeserializer, intDeserializer); + private final MockSourceNode source2 = new MockSourceNode<>(intDeserializer, intDeserializer); + private final MockSourceNode 
source3 = new MockSourceNode(intDeserializer, intDeserializer) { + @Override + public void process(final Record record) { + throw new RuntimeException("KABOOM!"); + } + + @Override + public void close() { + throw new RuntimeException("KABOOM!"); + } + }; + private final MockSourceNode timeoutSource = new MockSourceNode(intDeserializer, intDeserializer) { + @Override + public void process(final Record record) { + throw new TimeoutException("Kaboom!"); + } + }; + private final MockProcessorNode processorStreamTime = new MockProcessorNode<>(10L); + private final MockProcessorNode processorSystemTime = new MockProcessorNode<>(10L, PunctuationType.WALL_CLOCK_TIME); + + private final String storeName = "store"; + private final MockKeyValueStore stateStore = new MockKeyValueStore(storeName, false); + private final TopicPartition changelogPartition = new TopicPartition("store-changelog", 1); + + private final MockConsumer consumer = new MockConsumer<>(OffsetResetStrategy.EARLIEST); + private final byte[] recordValue = intSerializer.serialize(null, 10); + private final byte[] recordKey = intSerializer.serialize(null, 1); + private final String threadId = Thread.currentThread().getName(); + private final TaskId taskId = new TaskId(0, 0); + + private MockTime time = new MockTime(); + private Metrics metrics = new Metrics(new MetricConfig().recordLevel(Sensor.RecordingLevel.DEBUG), time); + private final StreamsMetricsImpl streamsMetrics = new MockStreamsMetrics(metrics); + + private StateDirectory stateDirectory; + private StreamTask task; + private long punctuatedAt; + + @Mock(type = MockType.NICE) + private ProcessorStateManager stateManager; + @Mock(type = MockType.NICE) + private RecordCollector recordCollector; + @Mock(type = MockType.NICE) + private ThreadCache cache; + + private final Punctuator punctuator = new Punctuator() { + @Override + public void punctuate(final long timestamp) { + punctuatedAt = timestamp; + } + }; + + private static ProcessorTopology withRepartitionTopics(final List> processorNodes, + final Map> sourcesByTopic, + final Set repartitionTopics) { + return new ProcessorTopology(processorNodes, + sourcesByTopic, + emptyMap(), + emptyList(), + emptyList(), + emptyMap(), + repartitionTopics); + } + + private static ProcessorTopology withSources(final List> processorNodes, + final Map> sourcesByTopic) { + return new ProcessorTopology(processorNodes, + sourcesByTopic, + emptyMap(), + emptyList(), + emptyList(), + emptyMap(), + Collections.emptySet()); + } + + private static StreamsConfig createConfig() { + return createConfig("0"); + } + + private static StreamsConfig createConfig(final String enforcedProcessingValue) { + return createConfig(AT_LEAST_ONCE, enforcedProcessingValue); + } + + private static StreamsConfig createConfig(final String eosConfig, final String enforcedProcessingValue) { + final String canonicalPath; + try { + canonicalPath = BASE_DIR.getCanonicalPath(); + } catch (final IOException e) { + throw new RuntimeException(e); + } + return new StreamsConfig(mkProperties(mkMap( + mkEntry(StreamsConfig.APPLICATION_ID_CONFIG, APPLICATION_ID), + mkEntry(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:2171"), + mkEntry(StreamsConfig.BUFFERED_RECORDS_PER_PARTITION_CONFIG, "3"), + mkEntry(StreamsConfig.STATE_DIR_CONFIG, canonicalPath), + mkEntry(StreamsConfig.DEFAULT_TIMESTAMP_EXTRACTOR_CLASS_CONFIG, MockTimestampExtractor.class.getName()), + mkEntry(StreamsConfig.PROCESSING_GUARANTEE_CONFIG, eosConfig), + mkEntry(StreamsConfig.MAX_TASK_IDLE_MS_CONFIG, 
enforcedProcessingValue) + ))); + } + + @Before + public void setup() { + EasyMock.expect(stateManager.taskId()).andStubReturn(taskId); + EasyMock.expect(stateManager.taskType()).andStubReturn(TaskType.ACTIVE); + + consumer.assign(asList(partition1, partition2)); + consumer.updateBeginningOffsets(mkMap(mkEntry(partition1, 0L), mkEntry(partition2, 0L))); + stateDirectory = new StateDirectory(createConfig("100"), new MockTime(), true, false); + } + + @After + public void cleanup() throws IOException { + if (task != null) { + try { + task.suspend(); + } catch (final IllegalStateException maybeSwallow) { + if (!maybeSwallow.getMessage().startsWith("Illegal state CLOSED")) { + throw maybeSwallow; + } + } catch (final RuntimeException swallow) { + // suspend dirty case + } + task.closeDirty(); + task = null; + } + Utils.delete(BASE_DIR); + } + + @Test + public void shouldThrowLockExceptionIfFailedToLockStateDirectory() throws IOException { + stateDirectory = EasyMock.createNiceMock(StateDirectory.class); + EasyMock.expect(stateDirectory.lock(taskId)).andReturn(false); + EasyMock.expect(stateManager.changelogPartitions()).andReturn(Collections.emptySet()); + stateManager.registerStore(stateStore, stateStore.stateRestoreCallback); + EasyMock.expectLastCall(); + EasyMock.replay(stateDirectory, stateManager); + + task = createStatefulTask(createConfig("100"), false); + + assertThrows(LockException.class, () -> task.initializeIfNeeded()); + } + + @Test + public void shouldNotAttemptToLockIfNoStores() { + stateDirectory = EasyMock.createNiceMock(StateDirectory.class); + EasyMock.replay(stateDirectory); + + task = createStatelessTask(createConfig("100")); + + task.initializeIfNeeded(); + + // should fail if lock is called + EasyMock.verify(stateDirectory); + } + + @Test + public void shouldAttemptToDeleteStateDirectoryWhenCloseDirtyAndEosEnabled() throws IOException { + final IMocksControl ctrl = EasyMock.createStrictControl(); + final ProcessorStateManager stateManager = ctrl.createMock(ProcessorStateManager.class); + EasyMock.expect(stateManager.taskType()).andStubReturn(TaskType.ACTIVE); + stateDirectory = ctrl.createMock(StateDirectory.class); + + stateManager.registerGlobalStateStores(emptyList()); + EasyMock.expectLastCall(); + + EasyMock.expect(stateManager.taskId()).andReturn(taskId); + + EasyMock.expect(stateDirectory.lock(taskId)).andReturn(true); + + stateManager.close(); + EasyMock.expectLastCall(); + + // The `baseDir` will be accessed when attempting to delete the state store. 
+ EasyMock.expect(stateManager.baseDir()).andReturn(TestUtils.tempDirectory("state_store")); + + stateDirectory.unlock(taskId); + EasyMock.expectLastCall(); + + ctrl.checkOrder(true); + ctrl.replay(); + + task = createStatefulTask(createConfig(StreamsConfig.EXACTLY_ONCE_V2, "100"), true, stateManager); + task.suspend(); + task.closeDirty(); + task = null; + + ctrl.verify(); + } + + @Test + public void shouldResetOffsetsToLastCommittedForSpecifiedPartitions() { + task = createStatelessTask(createConfig("100")); + task.addPartitionsForOffsetReset(Collections.singleton(partition1)); + + consumer.seek(partition1, 5L); + consumer.commitSync(); + + consumer.seek(partition1, 10L); + consumer.seek(partition2, 15L); + + final java.util.function.Consumer> resetter = + EasyMock.mock(java.util.function.Consumer.class); + resetter.accept(Collections.emptySet()); + EasyMock.expectLastCall(); + EasyMock.replay(resetter); + + task.initializeIfNeeded(); + task.completeRestoration(resetter); + + assertThat(consumer.position(partition1), equalTo(5L)); + assertThat(consumer.position(partition2), equalTo(15L)); + } + + @Test + public void shouldAutoOffsetResetIfNoCommittedOffsetFound() { + task = createStatelessTask(createConfig("100")); + task.addPartitionsForOffsetReset(Collections.singleton(partition1)); + + final AtomicReference shouldNotSeek = new AtomicReference<>(); + final MockConsumer consumer = new MockConsumer(OffsetResetStrategy.EARLIEST) { + @Override + public void seek(final TopicPartition partition, final long offset) { + final AssertionError error = shouldNotSeek.get(); + if (error != null) { + throw error; + } + super.seek(partition, offset); + } + }; + consumer.assign(asList(partition1, partition2)); + consumer.updateBeginningOffsets(mkMap(mkEntry(partition1, 0L), mkEntry(partition2, 0L))); + + consumer.seek(partition1, 5L); + consumer.seek(partition2, 15L); + + shouldNotSeek.set(new AssertionError("Should not seek")); + + final java.util.function.Consumer> resetter = + EasyMock.mock(java.util.function.Consumer.class); + resetter.accept(Collections.singleton(partition1)); + EasyMock.expectLastCall(); + EasyMock.replay(resetter); + + task.initializeIfNeeded(); + task.completeRestoration(resetter); + + // because we mocked the `resetter` positions don't change + assertThat(consumer.position(partition1), equalTo(5L)); + assertThat(consumer.position(partition2), equalTo(15L)); + EasyMock.verify(resetter); + } + + @Test + public void shouldReadCommittedStreamTimeOnInitialize() { + stateDirectory = EasyMock.createNiceMock(StateDirectory.class); + EasyMock.replay(stateDirectory); + + consumer.commitSync(partitions.stream() + .collect(Collectors.toMap(Function.identity(), tp -> new OffsetAndMetadata(0L, encodeTimestamp(10L))))); + + task = createStatelessTask(createConfig("100")); + + assertEquals(RecordQueue.UNKNOWN, task.streamTime()); + + task.initializeIfNeeded(); + task.completeRestoration(noOpResetter -> { }); + + assertEquals(10L, task.streamTime()); + } + + @Test + public void shouldTransitToRestoringThenRunningAfterCreation() throws IOException { + stateDirectory = EasyMock.createNiceMock(StateDirectory.class); + EasyMock.expect(stateDirectory.lock(taskId)).andReturn(true); + EasyMock.expect(stateManager.changelogPartitions()).andReturn(singleton(changelogPartition)); + EasyMock.expect(stateManager.changelogOffsets()).andReturn(singletonMap(changelogPartition, 10L)); + stateManager.registerStore(stateStore, stateStore.stateRestoreCallback); + EasyMock.expectLastCall(); + 
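+        // the task should move from CREATED to RESTORING on initialize and to RUNNING once restoration completes (asserted below)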
EasyMock.expect(recordCollector.offsets()).andReturn(emptyMap()).anyTimes(); + EasyMock.replay(stateDirectory, stateManager, recordCollector); + + task = createStatefulTask(createConfig("100"), true); + + assertEquals(CREATED, task.state()); + + task.initializeIfNeeded(); + + assertEquals(RESTORING, task.state()); + assertFalse(source1.initialized); + assertFalse(source2.initialized); + + // initialize should be idempotent + task.initializeIfNeeded(); + + assertEquals(RESTORING, task.state()); + + task.completeRestoration(noOpResetter -> { }); + + assertEquals(RUNNING, task.state()); + assertTrue(source1.initialized); + assertTrue(source2.initialized); + + EasyMock.verify(stateDirectory); + } + + @Test + public void shouldProcessInOrder() { + task = createStatelessTask(createConfig()); + + task.addRecords(partition1, asList( + getConsumerRecordWithOffsetAsTimestamp(partition1, 10, 101), + getConsumerRecordWithOffsetAsTimestamp(partition1, 20, 102), + getConsumerRecordWithOffsetAsTimestamp(partition1, 30, 103) + )); + + task.addRecords(partition2, asList( + getConsumerRecordWithOffsetAsTimestamp(partition2, 25, 201), + getConsumerRecordWithOffsetAsTimestamp(partition2, 35, 202), + getConsumerRecordWithOffsetAsTimestamp(partition2, 45, 203) + )); + + assertTrue(task.process(0L)); + assertEquals(5, task.numBuffered()); + assertEquals(1, source1.numReceived); + assertEquals(0, source2.numReceived); + assertEquals(singletonList(101), source1.values); + assertEquals(emptyList(), source2.values); + + assertTrue(task.process(0L)); + assertEquals(4, task.numBuffered()); + assertEquals(2, source1.numReceived); + assertEquals(0, source2.numReceived); + assertEquals(asList(101, 102), source1.values); + assertEquals(emptyList(), source2.values); + + assertTrue(task.process(0L)); + assertEquals(3, task.numBuffered()); + assertEquals(2, source1.numReceived); + assertEquals(1, source2.numReceived); + assertEquals(asList(101, 102), source1.values); + assertEquals(singletonList(201), source2.values); + + assertTrue(task.process(0L)); + assertEquals(2, task.numBuffered()); + assertEquals(3, source1.numReceived); + assertEquals(1, source2.numReceived); + assertEquals(asList(101, 102, 103), source1.values); + assertEquals(singletonList(201), source2.values); + + assertTrue(task.process(0L)); + assertEquals(1, task.numBuffered()); + assertEquals(3, source1.numReceived); + assertEquals(2, source2.numReceived); + assertEquals(asList(101, 102, 103), source1.values); + assertEquals(asList(201, 202), source2.values); + + assertTrue(task.process(0L)); + assertEquals(0, task.numBuffered()); + assertEquals(3, source1.numReceived); + assertEquals(3, source2.numReceived); + assertEquals(asList(101, 102, 103), source1.values); + assertEquals(asList(201, 202, 203), source2.values); + } + + @Test + public void shouldProcessRecordsAfterPrepareCommitWhenEosDisabled() { + task = createSingleSourceStateless(createConfig(), StreamsConfig.METRICS_LATEST); + + assertFalse(task.process(time.milliseconds())); + + task.addRecords(partition1, asList( + getConsumerRecordWithOffsetAsTimestamp(partition1, 10), + getConsumerRecordWithOffsetAsTimestamp(partition1, 20), + getConsumerRecordWithOffsetAsTimestamp(partition1, 30) + )); + + assertTrue(task.process(time.milliseconds())); + task.prepareCommit(); + assertTrue(task.process(time.milliseconds())); + task.postCommit(false); + assertTrue(task.process(time.milliseconds())); + + assertFalse(task.process(time.milliseconds())); + } + + @SuppressWarnings("deprecation") + @Test + public void 
shouldNotProcessRecordsAfterPrepareCommitWhenEosAlphaEnabled() { + task = createSingleSourceStateless(createConfig(StreamsConfig.EXACTLY_ONCE, "0"), StreamsConfig.METRICS_LATEST); + + assertFalse(task.process(time.milliseconds())); + + task.addRecords(partition1, asList( + getConsumerRecordWithOffsetAsTimestamp(partition1, 10), + getConsumerRecordWithOffsetAsTimestamp(partition1, 20), + getConsumerRecordWithOffsetAsTimestamp(partition1, 30) + )); + + assertTrue(task.process(time.milliseconds())); + task.prepareCommit(); + assertFalse(task.process(time.milliseconds())); + task.postCommit(false); + assertTrue(task.process(time.milliseconds())); + assertTrue(task.process(time.milliseconds())); + + assertFalse(task.process(time.milliseconds())); + } + + @Test + public void shouldNotProcessRecordsAfterPrepareCommitWhenEosV2Enabled() { + task = createSingleSourceStateless(createConfig(StreamsConfig.EXACTLY_ONCE_V2, "0"), StreamsConfig.METRICS_LATEST); + + assertFalse(task.process(time.milliseconds())); + + task.addRecords(partition1, asList( + getConsumerRecordWithOffsetAsTimestamp(partition1, 10), + getConsumerRecordWithOffsetAsTimestamp(partition1, 20), + getConsumerRecordWithOffsetAsTimestamp(partition1, 30) + )); + + assertTrue(task.process(time.milliseconds())); + task.prepareCommit(); + assertFalse(task.process(time.milliseconds())); + task.postCommit(false); + assertTrue(task.process(time.milliseconds())); + assertTrue(task.process(time.milliseconds())); + + assertFalse(task.process(time.milliseconds())); + } + + @Test + public void shouldRecordBufferedRecords() { + task = createSingleSourceStateless(createConfig(AT_LEAST_ONCE, "0"), StreamsConfig.METRICS_LATEST); + + final KafkaMetric metric = getMetric("active-buffer", "%s-count", task.id().toString()); + + assertThat(metric.metricValue(), equalTo(0.0)); + + task.addRecords(partition1, asList( + getConsumerRecordWithOffsetAsTimestamp(partition1, 10), + getConsumerRecordWithOffsetAsTimestamp(partition1, 20) + )); + task.recordProcessTimeRatioAndBufferSize(100L, time.milliseconds()); + + assertThat(metric.metricValue(), equalTo(2.0)); + + assertTrue(task.process(0L)); + task.recordProcessTimeRatioAndBufferSize(100L, time.milliseconds()); + + assertThat(metric.metricValue(), equalTo(1.0)); + } + + @Test + public void shouldRecordProcessRatio() { + task = createStatelessTask(createConfig()); + + final KafkaMetric metric = getMetric("active-process", "%s-ratio", task.id().toString()); + + assertThat(metric.metricValue(), equalTo(0.0)); + + task.recordProcessBatchTime(10L); + task.recordProcessBatchTime(15L); + task.recordProcessTimeRatioAndBufferSize(100L, time.milliseconds()); + + assertThat(metric.metricValue(), equalTo(0.25)); + + task.recordProcessBatchTime(10L); + + assertThat(metric.metricValue(), equalTo(0.25)); + + task.recordProcessBatchTime(10L); + task.recordProcessTimeRatioAndBufferSize(20L, time.milliseconds()); + + assertThat(metric.metricValue(), equalTo(1.0)); + } + + @Test + public void shouldRecordE2ELatencyOnSourceNodeAndTerminalNodes() { + time = new MockTime(0L, 0L, 0L); + metrics = new Metrics(new MetricConfig().recordLevel(Sensor.RecordingLevel.INFO), time); + + // Create a processor that only forwards even keys to test the metrics at the source and terminal nodes + final MockSourceNode evenKeyForwardingSourceNode = new MockSourceNode(intDeserializer, intDeserializer) { + InternalProcessorContext context; + + @Override + public void init(final InternalProcessorContext context) { + this.context = context; + 
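+                // the captured context is used by process() below to forward only even-keyed records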
super.init(context); + } + + @Override + public void process(final Record record) { + if (record.key() % 2 == 0) { + context.forward(record); + } + } + }; + + task = createStatelessTaskWithForwardingTopology(evenKeyForwardingSourceNode); + task.initializeIfNeeded(); + task.completeRestoration(noOpResetter -> { }); + + final String sourceNodeName = evenKeyForwardingSourceNode.name(); + final String terminalNodeName = processorStreamTime.name(); + + final Metric sourceAvg = getProcessorMetric("record-e2e-latency", "%s-avg", task.id().toString(), sourceNodeName, StreamsConfig.METRICS_LATEST); + final Metric sourceMin = getProcessorMetric("record-e2e-latency", "%s-min", task.id().toString(), sourceNodeName, StreamsConfig.METRICS_LATEST); + final Metric sourceMax = getProcessorMetric("record-e2e-latency", "%s-max", task.id().toString(), sourceNodeName, StreamsConfig.METRICS_LATEST); + + final Metric terminalAvg = getProcessorMetric("record-e2e-latency", "%s-avg", task.id().toString(), terminalNodeName, StreamsConfig.METRICS_LATEST); + final Metric terminalMin = getProcessorMetric("record-e2e-latency", "%s-min", task.id().toString(), terminalNodeName, StreamsConfig.METRICS_LATEST); + final Metric terminalMax = getProcessorMetric("record-e2e-latency", "%s-max", task.id().toString(), terminalNodeName, StreamsConfig.METRICS_LATEST); + + // e2e latency = 10 + task.addRecords(partition1, singletonList(getConsumerRecordWithOffsetAsTimestamp(0, 0L))); + task.process(10L); + + assertThat(sourceAvg.metricValue(), equalTo(10.0)); + assertThat(sourceMin.metricValue(), equalTo(10.0)); + assertThat(sourceMax.metricValue(), equalTo(10.0)); + + // key 0: reaches terminal node + assertThat(terminalAvg.metricValue(), equalTo(10.0)); + assertThat(terminalMin.metricValue(), equalTo(10.0)); + assertThat(terminalMax.metricValue(), equalTo(10.0)); + + + // e2e latency = 15 + task.addRecords(partition1, singletonList(getConsumerRecordWithOffsetAsTimestamp(1, 0L))); + task.process(15L); + + assertThat(sourceAvg.metricValue(), equalTo(12.5)); + assertThat(sourceMin.metricValue(), equalTo(10.0)); + assertThat(sourceMax.metricValue(), equalTo(15.0)); + + // key 1: stops at source, doesn't affect terminal node metrics + assertThat(terminalAvg.metricValue(), equalTo(10.0)); + assertThat(terminalMin.metricValue(), equalTo(10.0)); + assertThat(terminalMax.metricValue(), equalTo(10.0)); + + + // e2e latency = 23 + task.addRecords(partition1, singletonList(getConsumerRecordWithOffsetAsTimestamp(2, 0L))); + task.process(23L); + + assertThat(sourceAvg.metricValue(), equalTo(16.0)); + assertThat(sourceMin.metricValue(), equalTo(10.0)); + assertThat(sourceMax.metricValue(), equalTo(23.0)); + + // key 2: reaches terminal node + assertThat(terminalAvg.metricValue(), equalTo(16.5)); + assertThat(terminalMin.metricValue(), equalTo(10.0)); + assertThat(terminalMax.metricValue(), equalTo(23.0)); + + + // e2e latency = 5 + task.addRecords(partition1, singletonList(getConsumerRecordWithOffsetAsTimestamp(3, 0L))); + task.process(5L); + + assertThat(sourceAvg.metricValue(), equalTo(13.25)); + assertThat(sourceMin.metricValue(), equalTo(5.0)); + assertThat(sourceMax.metricValue(), equalTo(23.0)); + + // key 3: stops at source, doesn't affect terminal node metrics + assertThat(terminalAvg.metricValue(), equalTo(16.5)); + assertThat(terminalMin.metricValue(), equalTo(10.0)); + assertThat(terminalMax.metricValue(), equalTo(23.0)); + } + + @Test + public void shouldThrowOnTimeoutExceptionAndBufferRecordForRetryIfEosDisabled() { + 
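+        // under at-least-once, the TimeoutException is rethrown while the failed record is kept by the task itself,
+        // so the next process() call retries it (and throws again)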
createTimeoutTask(AT_LEAST_ONCE); + + task.addRecords(partition1, singletonList(getConsumerRecordWithOffsetAsTimestamp(0, 0L))); + + final TimeoutException exception = assertThrows( + TimeoutException.class, + () -> task.process(0) + ); + assertThat(exception.getMessage(), equalTo("Kaboom!")); + + // we have only a single record that was not successfully processed + // however, the record should not be in the record buffer any longer, but should be cached within the task itself + assertThat(task.commitNeeded(), equalTo(false)); + assertThat(task.hasRecordsQueued(), equalTo(false)); + + // -> thus the task should try process the cached record now (that thus throw again) + final TimeoutException nextException = assertThrows( + TimeoutException.class, + () -> task.process(0) + ); + assertThat(nextException.getMessage(), equalTo("Kaboom!")); + } + + @Test + public void shouldThrowTaskCorruptedExceptionOnTimeoutExceptionIfEosEnabled() { + createTimeoutTask(StreamsConfig.EXACTLY_ONCE_V2); + + task.addRecords(partition1, singletonList(getConsumerRecordWithOffsetAsTimestamp(0, 0L))); + + assertThrows( + TaskCorruptedException.class, + () -> task.process(0) + ); + } + + @Test + public void testMetrics() { + task = createStatelessTask(createConfig("100")); + + assertNotNull(getMetric( + "enforced-processing", + "%s-rate", + task.id().toString() + )); + assertNotNull(getMetric( + "enforced-processing", + "%s-total", + task.id().toString() + )); + + assertNotNull(getMetric( + "record-lateness", + "%s-avg", + task.id().toString() + )); + assertNotNull(getMetric( + "record-lateness", + "%s-max", + task.id().toString() + )); + + assertNotNull(getMetric( + "active-process", + "%s-ratio", + task.id().toString() + )); + + assertNotNull(getMetric( + "active-buffer", + "%s-count", + task.id().toString() + )); + + testMetricsForBuiltInMetricsVersionLatest(); + + final JmxReporter reporter = new JmxReporter(); + final MetricsContext metricsContext = new KafkaMetricsContext("kafka.streams"); + reporter.contextChange(metricsContext); + + metrics.addReporter(reporter); + final String threadIdTag = THREAD_ID_TAG; + assertTrue(reporter.containsMbean(String.format( + "kafka.streams:type=stream-task-metrics,%s=%s,task-id=%s", + threadIdTag, + threadId, + task.id() + ))); + } + + private void testMetricsForBuiltInMetricsVersionLatest() { + final String builtInMetricsVersion = StreamsConfig.METRICS_LATEST; + assertNull(getMetric("commit", "%s-latency-avg", "all")); + assertNull(getMetric("commit", "%s-latency-max", "all")); + assertNull(getMetric("commit", "%s-rate", "all")); + assertNull(getMetric("commit", "%s-total", "all")); + + assertNotNull(getMetric("process", "%s-latency-max", task.id().toString())); + assertNotNull(getMetric("process", "%s-latency-avg", task.id().toString())); + + assertNotNull(getMetric("punctuate", "%s-latency-avg", task.id().toString())); + assertNotNull(getMetric("punctuate", "%s-latency-max", task.id().toString())); + assertNotNull(getMetric("punctuate", "%s-rate", task.id().toString())); + assertNotNull(getMetric("punctuate", "%s-total", task.id().toString())); + } + + private KafkaMetric getMetric(final String operation, + final String nameFormat, + final String taskId) { + final String descriptionIsNotVerified = ""; + return metrics.metrics().get(metrics.metricName( + String.format(nameFormat, operation), + "stream-task-metrics", + descriptionIsNotVerified, + mkMap( + mkEntry("task-id", taskId), + mkEntry(THREAD_ID_TAG, Thread.currentThread().getName()) + ) + )); + } + + private Metric 
getProcessorMetric(final String operation, + final String nameFormat, + final String taskId, + final String processorNodeId, + final String builtInMetricsVersion) { + + return getMetricByNameFilterByTags( + metrics.metrics(), + String.format(nameFormat, operation), + "stream-processor-node-metrics", + mkMap( + mkEntry("task-id", taskId), + mkEntry("processor-node-id", processorNodeId), + mkEntry(THREAD_ID_TAG, Thread.currentThread().getName() + ) + ) + ); + } + + @Test + public void shouldPauseAndResumeBasedOnBufferedRecords() { + task = createStatelessTask(createConfig("100")); + + task.addRecords(partition1, asList( + getConsumerRecordWithOffsetAsTimestamp(partition1, 10), + getConsumerRecordWithOffsetAsTimestamp(partition1, 20) + )); + + task.addRecords(partition2, asList( + getConsumerRecordWithOffsetAsTimestamp(partition2, 35), + getConsumerRecordWithOffsetAsTimestamp(partition2, 45), + getConsumerRecordWithOffsetAsTimestamp(partition2, 55), + getConsumerRecordWithOffsetAsTimestamp(partition2, 65) + )); + + assertTrue(task.process(0L)); + assertEquals(1, source1.numReceived); + assertEquals(0, source2.numReceived); + + assertEquals(1, consumer.paused().size()); + assertTrue(consumer.paused().contains(partition2)); + + task.addRecords(partition1, asList( + getConsumerRecordWithOffsetAsTimestamp(partition1, 30), + getConsumerRecordWithOffsetAsTimestamp(partition1, 40), + getConsumerRecordWithOffsetAsTimestamp(partition1, 50) + )); + + assertEquals(2, consumer.paused().size()); + assertTrue(consumer.paused().contains(partition1)); + assertTrue(consumer.paused().contains(partition2)); + + assertTrue(task.process(0L)); + assertEquals(2, source1.numReceived); + assertEquals(0, source2.numReceived); + + assertEquals(1, consumer.paused().size()); + assertTrue(consumer.paused().contains(partition2)); + + assertTrue(task.process(0L)); + assertEquals(3, source1.numReceived); + assertEquals(0, source2.numReceived); + + assertEquals(1, consumer.paused().size()); + assertTrue(consumer.paused().contains(partition2)); + + assertTrue(task.process(0L)); + assertEquals(3, source1.numReceived); + assertEquals(1, source2.numReceived); + + assertEquals(0, consumer.paused().size()); + } + + @Test + public void shouldPunctuateOnceStreamTimeAfterGap() { + task = createStatelessTask(createConfig()); + task.initializeIfNeeded(); + task.completeRestoration(noOpResetter -> { }); + + task.addRecords(partition1, asList( + getConsumerRecordWithOffsetAsTimestamp(partition1, 20), + getConsumerRecordWithOffsetAsTimestamp(partition1, 142), + getConsumerRecordWithOffsetAsTimestamp(partition1, 155), + getConsumerRecordWithOffsetAsTimestamp(partition1, 160) + )); + + task.addRecords(partition2, asList( + getConsumerRecordWithOffsetAsTimestamp(partition2, 25), + getConsumerRecordWithOffsetAsTimestamp(partition2, 145), + getConsumerRecordWithOffsetAsTimestamp(partition2, 159), + getConsumerRecordWithOffsetAsTimestamp(partition2, 161) + )); + + // st: -1 + assertFalse(task.maybePunctuateStreamTime()); // punctuate at 20 + + // st: 20 + assertTrue(task.process(0L)); + assertEquals(7, task.numBuffered()); + assertEquals(1, source1.numReceived); + assertEquals(0, source2.numReceived); + assertTrue(task.maybePunctuateStreamTime()); + + // st: 25 + assertTrue(task.process(0L)); + assertEquals(6, task.numBuffered()); + assertEquals(1, source1.numReceived); + assertEquals(1, source2.numReceived); + assertFalse(task.maybePunctuateStreamTime()); + + // st: 142 + // punctuate at 142 + assertTrue(task.process(0L)); + assertEquals(5, 
task.numBuffered()); + assertEquals(2, source1.numReceived); + assertEquals(1, source2.numReceived); + assertTrue(task.maybePunctuateStreamTime()); + + // st: 145 + // only one punctuation after 100ms gap + assertTrue(task.process(0L)); + assertEquals(4, task.numBuffered()); + assertEquals(2, source1.numReceived); + assertEquals(2, source2.numReceived); + assertFalse(task.maybePunctuateStreamTime()); + + // st: 155 + // punctuate at 155 + assertTrue(task.process(0L)); + assertEquals(3, task.numBuffered()); + assertEquals(3, source1.numReceived); + assertEquals(2, source2.numReceived); + assertTrue(task.maybePunctuateStreamTime()); + + // st: 159 + assertTrue(task.process(0L)); + assertEquals(2, task.numBuffered()); + assertEquals(3, source1.numReceived); + assertEquals(3, source2.numReceived); + assertFalse(task.maybePunctuateStreamTime()); + + // st: 160, aligned at 0 + assertTrue(task.process(0L)); + assertEquals(1, task.numBuffered()); + assertEquals(4, source1.numReceived); + assertEquals(3, source2.numReceived); + assertTrue(task.maybePunctuateStreamTime()); + + // st: 161 + assertTrue(task.process(0L)); + assertEquals(0, task.numBuffered()); + assertEquals(4, source1.numReceived); + assertEquals(4, source2.numReceived); + assertFalse(task.maybePunctuateStreamTime()); + + processorStreamTime.mockProcessor.checkAndClearPunctuateResult(PunctuationType.STREAM_TIME, 20L, 142L, 155L, 160L); + } + + @Test + public void shouldRespectPunctuateCancellationStreamTime() { + task = createStatelessTask(createConfig("100")); + task.initializeIfNeeded(); + task.completeRestoration(noOpResetter -> { }); + + task.addRecords(partition1, asList( + getConsumerRecordWithOffsetAsTimestamp(partition1, 20), + getConsumerRecordWithOffsetAsTimestamp(partition1, 30), + getConsumerRecordWithOffsetAsTimestamp(partition1, 40) + )); + + task.addRecords(partition2, asList( + getConsumerRecordWithOffsetAsTimestamp(partition2, 25), + getConsumerRecordWithOffsetAsTimestamp(partition2, 35), + getConsumerRecordWithOffsetAsTimestamp(partition2, 45) + )); + + assertFalse(task.maybePunctuateStreamTime()); + + // st is now 20 + assertTrue(task.process(0L)); + + assertTrue(task.maybePunctuateStreamTime()); + + // st is now 25 + assertTrue(task.process(0L)); + + assertFalse(task.maybePunctuateStreamTime()); + + // st is now 30 + assertTrue(task.process(0L)); + + processorStreamTime.mockProcessor.scheduleCancellable().cancel(); + + assertFalse(task.maybePunctuateStreamTime()); + + processorStreamTime.mockProcessor.checkAndClearPunctuateResult(PunctuationType.STREAM_TIME, 20L); + } + + @Test + public void shouldRespectPunctuateCancellationSystemTime() { + task = createStatelessTask(createConfig("100")); + task.initializeIfNeeded(); + task.completeRestoration(noOpResetter -> { }); + final long now = time.milliseconds(); + time.sleep(10); + assertTrue(task.maybePunctuateSystemTime()); + processorSystemTime.mockProcessor.scheduleCancellable().cancel(); + time.sleep(10); + assertFalse(task.maybePunctuateSystemTime()); + processorSystemTime.mockProcessor.checkAndClearPunctuateResult(PunctuationType.WALL_CLOCK_TIME, now + 10); + } + + @Test + public void shouldRespectCommitNeeded() { + task = createSingleSourceStateless(createConfig(AT_LEAST_ONCE, "0"), StreamsConfig.METRICS_LATEST); + task.initializeIfNeeded(); + task.completeRestoration(noOpResetter -> { }); + + assertFalse(task.commitNeeded()); + + task.addRecords(partition1, singletonList(getConsumerRecordWithOffsetAsTimestamp(partition1, 0))); + assertTrue(task.process(0L)); + 
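+        // processing, stream-time punctuation, and system-time punctuation should each mark the task as
+        // commit-needed until postCommit() clears the flag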
assertTrue(task.commitNeeded()); + + task.prepareCommit(); + assertTrue(task.commitNeeded()); + + task.postCommit(true); + assertFalse(task.commitNeeded()); + + assertTrue(task.maybePunctuateStreamTime()); + assertTrue(task.commitNeeded()); + + task.prepareCommit(); + assertTrue(task.commitNeeded()); + + task.postCommit(true); + assertFalse(task.commitNeeded()); + + time.sleep(10); + assertTrue(task.maybePunctuateSystemTime()); + assertTrue(task.commitNeeded()); + + task.prepareCommit(); + assertTrue(task.commitNeeded()); + + task.postCommit(true); + assertFalse(task.commitNeeded()); + } + + @Test + public void shouldCommitNextOffsetFromQueueIfAvailable() { + task = createSingleSourceStateless(createConfig(AT_LEAST_ONCE, "0"), StreamsConfig.METRICS_LATEST); + task.initializeIfNeeded(); + task.completeRestoration(noOpResetter -> { }); + + task.addRecords(partition1, asList( + getConsumerRecordWithOffsetAsTimestamp(partition1, 0L), + getConsumerRecordWithOffsetAsTimestamp(partition1, 3L), + getConsumerRecordWithOffsetAsTimestamp(partition1, 5L))); + + task.process(0L); + task.process(0L); + + final Map offsetsAndMetadata = task.prepareCommit(); + + assertThat(offsetsAndMetadata, equalTo(mkMap(mkEntry(partition1, new OffsetAndMetadata(5L, encodeTimestamp(3L)))))); + } + + @Test + public void shouldCommitConsumerPositionIfRecordQueueIsEmpty() { + task = createStatelessTask(createConfig()); + task.initializeIfNeeded(); + task.completeRestoration(noOpResetter -> { }); + + consumer.addRecord(getConsumerRecordWithOffsetAsTimestamp(partition1, 0L)); + consumer.addRecord(getConsumerRecordWithOffsetAsTimestamp(partition1, 1L)); + consumer.addRecord(getConsumerRecordWithOffsetAsTimestamp(partition1, 2L)); + consumer.addRecord(getConsumerRecordWithOffsetAsTimestamp(partition2, 0L)); + consumer.poll(Duration.ZERO); + + task.addRecords(partition1, singletonList(getConsumerRecordWithOffsetAsTimestamp(partition1, 0L))); + task.addRecords(partition2, singletonList(getConsumerRecordWithOffsetAsTimestamp(partition2, 0L))); + task.process(0L); + + assertTrue(task.commitNeeded()); + assertThat(task.prepareCommit(), equalTo(mkMap(mkEntry(partition1, new OffsetAndMetadata(3L, encodeTimestamp(0L)))))); + task.postCommit(false); + + // the task should still be committed since the processed records have not reached the consumer position + assertTrue(task.commitNeeded()); + + consumer.poll(Duration.ZERO); + task.process(0L); + + assertTrue(task.commitNeeded()); + assertThat(task.prepareCommit(), equalTo(mkMap(mkEntry(partition1, new OffsetAndMetadata(3L, encodeTimestamp(0L))), + mkEntry(partition2, new OffsetAndMetadata(1L, encodeTimestamp(0L)))))); + task.postCommit(false); + + assertFalse(task.commitNeeded()); + } + + @Test + public void shouldFailOnCommitIfTaskIsClosed() { + task = createStatelessTask(createConfig()); + task.suspend(); + task.transitionTo(Task.State.CLOSED); + + final IllegalStateException thrown = assertThrows( + IllegalStateException.class, + task::prepareCommit + ); + + assertThat(thrown.getMessage(), is("Illegal state CLOSED while preparing active task 0_0 for committing")); + } + + @Test + public void shouldRespectCommitRequested() { + task = createStatelessTask(createConfig("100")); + task.initializeIfNeeded(); + task.completeRestoration(noOpResetter -> { }); + + task.requestCommit(); + assertTrue(task.commitRequested()); + } + + @Test + public void shouldEncodeAndDecodeMetadata() { + task = createStatelessTask(createConfig("100")); + assertEquals(DEFAULT_TIMESTAMP, 
task.decodeTimestamp(encodeTimestamp(DEFAULT_TIMESTAMP))); + } + + @Test + public void shouldReturnUnknownTimestampIfUnknownVersion() { + task = createStatelessTask(createConfig("100")); + + final byte[] emptyMessage = {StreamTask.LATEST_MAGIC_BYTE + 1}; + final String encodedString = Base64.getEncoder().encodeToString(emptyMessage); + assertEquals(RecordQueue.UNKNOWN, task.decodeTimestamp(encodedString)); + } + + @Test + public void shouldReturnUnknownTimestampIfEmptyMessage() { + task = createStatelessTask(createConfig("100")); + + assertEquals(RecordQueue.UNKNOWN, task.decodeTimestamp("")); + } + + @Test + public void shouldBeProcessableIfAllPartitionsBuffered() { + task = createStatelessTask(createConfig("100")); + task.initializeIfNeeded(); + task.completeRestoration(noOpResetter -> { }); + + assertThat("task is not idling", !task.timeCurrentIdlingStarted().isPresent()); + + assertFalse(task.process(0L)); + + final byte[] bytes = ByteBuffer.allocate(4).putInt(1).array(); + + task.addRecords(partition1, singleton(new ConsumerRecord<>(topic1, 1, 0, bytes, bytes))); + + assertFalse(task.process(0L)); + assertThat("task is idling", task.timeCurrentIdlingStarted().isPresent()); + + task.addRecords(partition2, singleton(new ConsumerRecord<>(topic2, 1, 0, bytes, bytes))); + + assertTrue(task.process(0L)); + assertThat("task is not idling", !task.timeCurrentIdlingStarted().isPresent()); + + } + + @Test + public void shouldBeRecordIdlingTimeIfSuspended() { + task = createStatelessTask(createConfig("100")); + task.initializeIfNeeded(); + task.completeRestoration(noOpResetter -> { }); + task.suspend(); + + assertThat("task is idling", task.timeCurrentIdlingStarted().isPresent()); + + task.resume(); + + assertThat("task is not idling", !task.timeCurrentIdlingStarted().isPresent()); + } + + public void shouldPunctuateSystemTimeWhenIntervalElapsed() { + task = createStatelessTask(createConfig("100")); + task.initializeIfNeeded(); + task.completeRestoration(noOpResetter -> { }); + final long now = time.milliseconds(); + time.sleep(10); + assertTrue(task.maybePunctuateSystemTime()); + time.sleep(10); + assertTrue(task.maybePunctuateSystemTime()); + time.sleep(9); + assertFalse(task.maybePunctuateSystemTime()); + time.sleep(1); + assertTrue(task.maybePunctuateSystemTime()); + time.sleep(20); + assertTrue(task.maybePunctuateSystemTime()); + assertFalse(task.maybePunctuateSystemTime()); + processorSystemTime.mockProcessor.checkAndClearPunctuateResult(PunctuationType.WALL_CLOCK_TIME, now + 10, now + 20, now + 30, now + 50); + } + + @Test + public void shouldNotPunctuateSystemTimeWhenIntervalNotElapsed() { + task = createStatelessTask(createConfig("100")); + task.initializeIfNeeded(); + task.completeRestoration(noOpResetter -> { }); + assertFalse(task.maybePunctuateSystemTime()); + time.sleep(9); + assertFalse(task.maybePunctuateSystemTime()); + processorSystemTime.mockProcessor.checkAndClearPunctuateResult(PunctuationType.WALL_CLOCK_TIME); + } + + @Test + public void shouldPunctuateOnceSystemTimeAfterGap() { + task = createStatelessTask(createConfig("100")); + task.initializeIfNeeded(); + task.completeRestoration(noOpResetter -> { }); + final long now = time.milliseconds(); + time.sleep(100); + assertTrue(task.maybePunctuateSystemTime()); + assertFalse(task.maybePunctuateSystemTime()); + time.sleep(10); + assertTrue(task.maybePunctuateSystemTime()); + time.sleep(12); + assertTrue(task.maybePunctuateSystemTime()); + time.sleep(7); + assertFalse(task.maybePunctuateSystemTime()); + time.sleep(1); // 
punctuate at now + 130 + assertTrue(task.maybePunctuateSystemTime()); + time.sleep(105); // punctuate at now + 235 + assertTrue(task.maybePunctuateSystemTime()); + assertFalse(task.maybePunctuateSystemTime()); + time.sleep(5); // punctuate at now + 240, still aligned on the initial punctuation + assertTrue(task.maybePunctuateSystemTime()); + assertFalse(task.maybePunctuateSystemTime()); + processorSystemTime.mockProcessor.checkAndClearPunctuateResult(PunctuationType.WALL_CLOCK_TIME, now + 100, now + 110, now + 122, now + 130, now + 235, now + 240); + } + + @Test + public void shouldWrapKafkaExceptionsWithStreamsExceptionAndAddContextWhenPunctuatingStreamTime() { + task = createStatelessTask(createConfig("100")); + task.initializeIfNeeded(); + task.completeRestoration(noOpResetter -> { }); + + try { + task.punctuate(processorStreamTime, 1, PunctuationType.STREAM_TIME, timestamp -> { + throw new KafkaException("KABOOM!"); + }); + fail("Should've thrown StreamsException"); + } catch (final StreamsException e) { + final String message = e.getMessage(); + assertTrue("message=" + message + " should contain processor", message.contains("processor '" + processorStreamTime.name() + "'")); + assertThat(task.processorContext().currentNode(), nullValue()); + } + } + + @Test + public void shouldWrapKafkaExceptionsWithStreamsExceptionAndAddContextWhenPunctuatingWallClockTimeTime() { + task = createStatelessTask(createConfig("100")); + task.initializeIfNeeded(); + task.completeRestoration(noOpResetter -> { }); + + try { + task.punctuate(processorSystemTime, 1, PunctuationType.WALL_CLOCK_TIME, timestamp -> { + throw new KafkaException("KABOOM!"); + }); + fail("Should've thrown StreamsException"); + } catch (final StreamsException e) { + final String message = e.getMessage(); + assertTrue("message=" + message + " should contain processor", message.contains("processor '" + processorSystemTime.name() + "'")); + assertThat(task.processorContext().currentNode(), nullValue()); + } + } + + @Test + public void shouldNotShareHeadersBetweenPunctuateIterations() { + task = createStatelessTask(createConfig("100")); + task.initializeIfNeeded(); + task.completeRestoration(noOpResetter -> { }); + + task.punctuate( + processorSystemTime, + 1L, + PunctuationType.WALL_CLOCK_TIME, + timestamp -> task.processorContext().headers().add("dummy", null) + ); + task.punctuate( + processorSystemTime, + 1L, + PunctuationType.WALL_CLOCK_TIME, + timestamp -> assertFalse(task.processorContext().headers().iterator().hasNext()) + ); + } + + @Test + public void shouldWrapKafkaExceptionWithStreamsExceptionWhenProcess() { + EasyMock.expect(stateManager.changelogPartitions()).andReturn(Collections.emptySet()).anyTimes(); + EasyMock.expect(stateManager.changelogOffsets()).andReturn(emptyMap()).anyTimes(); + EasyMock.expect(recordCollector.offsets()).andReturn(emptyMap()).anyTimes(); + EasyMock.replay(stateManager, recordCollector); + + task = createFaultyStatefulTask(createConfig("100")); + + task.initializeIfNeeded(); + task.completeRestoration(noOpResetter -> { }); + + task.addRecords(partition1, asList( + getConsumerRecordWithOffsetAsTimestamp(partition1, 10), + getConsumerRecordWithOffsetAsTimestamp(partition1, 20), + getConsumerRecordWithOffsetAsTimestamp(partition1, 30) + )); + task.addRecords(partition2, asList( + getConsumerRecordWithOffsetAsTimestamp(partition2, 5), // this is the first record to process + getConsumerRecordWithOffsetAsTimestamp(partition2, 35), + getConsumerRecordWithOffsetAsTimestamp(partition2, 45) + )); + + 
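+        // the faulty task throws on its first processed record, and that failure should surface wrapped in a StreamsException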
assertThat("Map did not contain the partitions", task.highWaterMark().containsKey(partition1) + && task.highWaterMark().containsKey(partition2)); + assertThrows(StreamsException.class, () -> task.process(0L)); + } + + @Test + public void shouldReadCommittedOffsetAndRethrowTimeoutWhenCompleteRestoration() throws IOException { + stateDirectory = EasyMock.createNiceMock(StateDirectory.class); + EasyMock.expect(stateDirectory.lock(taskId)).andReturn(true); + EasyMock.expect(stateManager.changelogPartitions()).andReturn(Collections.emptySet()).anyTimes(); + EasyMock.expect(stateManager.changelogOffsets()).andReturn(emptyMap()).anyTimes(); + + EasyMock.replay(recordCollector, stateDirectory, stateManager); + + task = createDisconnectedTask(createConfig("100")); + + task.transitionTo(RESTORING); + + assertThrows(TimeoutException.class, () -> task.completeRestoration(noOpResetter -> { })); + } + + @Test + public void shouldReInitializeTopologyWhenResuming() throws IOException { + stateDirectory = EasyMock.createNiceMock(StateDirectory.class); + EasyMock.expect(stateDirectory.lock(taskId)).andReturn(true); + EasyMock.expect(recordCollector.offsets()).andThrow(new AssertionError("Should not try to read offsets")).anyTimes(); + EasyMock.expect(stateManager.changelogPartitions()).andReturn(Collections.emptySet()).anyTimes(); + + EasyMock.replay(recordCollector, stateDirectory, stateManager); + + task = createStatefulTask(createConfig("100"), true); + + task.initializeIfNeeded(); + + task.suspend(); + + assertEquals(SUSPENDED, task.state()); + assertFalse(source1.initialized); + assertFalse(source2.initialized); + + task.resume(); + + assertEquals(RESTORING, task.state()); + assertFalse(source1.initialized); + assertFalse(source2.initialized); + + task.completeRestoration(noOpResetter -> { }); + + assertEquals(RUNNING, task.state()); + assertTrue(source1.initialized); + assertTrue(source2.initialized); + + EasyMock.verify(stateManager, recordCollector); + + EasyMock.reset(recordCollector); + EasyMock.expect(recordCollector.offsets()).andReturn(emptyMap()); + EasyMock.replay(recordCollector); + assertThat("Map did not contain the partition", task.highWaterMark().containsKey(partition1)); + } + + @Test + public void shouldNotCheckpointOffsetsAgainOnCommitIfSnapshotNotChangedMuch() { + final Long offset = 543L; + + EasyMock.expect(recordCollector.offsets()).andReturn(singletonMap(changelogPartition, offset)).anyTimes(); + stateManager.checkpoint(); + EasyMock.expectLastCall().once(); + EasyMock.expect(stateManager.changelogOffsets()) + .andReturn(singletonMap(changelogPartition, 10L)) + .andReturn(singletonMap(changelogPartition, 20L)); + EasyMock.expectLastCall(); + EasyMock.replay(stateManager, recordCollector); + + task = createStatefulTask(createConfig("100"), true); + + task.initializeIfNeeded(); + task.completeRestoration(noOpResetter -> { }); + + task.prepareCommit(); + task.postCommit(true); // should checkpoint + + task.prepareCommit(); + task.postCommit(false); // should not checkpoint + + EasyMock.verify(stateManager, recordCollector); + assertThat("Map was empty", task.highWaterMark().size() == 2); + } + + @Test + public void shouldCheckpointOffsetsOnCommitIfSnapshotMuchChanged() { + final Long offset = 543L; + + EasyMock.expect(recordCollector.offsets()).andReturn(singletonMap(changelogPartition, offset)).anyTimes(); + stateManager.checkpoint(); + EasyMock.expectLastCall().times(2); + EasyMock.expect(stateManager.changelogPartitions()).andReturn(singleton(changelogPartition)); + 
EasyMock.expect(stateManager.changelogOffsets()) + .andReturn(singletonMap(changelogPartition, 0L)) + .andReturn(singletonMap(changelogPartition, 10L)) + .andReturn(singletonMap(changelogPartition, 12000L)); + stateManager.registerStore(stateStore, stateStore.stateRestoreCallback); + EasyMock.expectLastCall(); + EasyMock.replay(stateManager, recordCollector); + + task = createStatefulTask(createConfig("100"), true); + + task.initializeIfNeeded(); + task.completeRestoration(noOpResetter -> { }); + task.prepareCommit(); + task.postCommit(true); + + task.prepareCommit(); + task.postCommit(false); + + EasyMock.verify(recordCollector); + assertThat("Map was empty", task.highWaterMark().size() == 2); + } + + @Test + public void shouldNotCheckpointOffsetsOnCommitIfEosIsEnabled() { + EasyMock.expect(stateManager.changelogPartitions()).andReturn(singleton(changelogPartition)); + stateManager.registerStore(stateStore, stateStore.stateRestoreCallback); + EasyMock.expectLastCall(); + EasyMock.expect(recordCollector.offsets()).andReturn(emptyMap()).anyTimes(); + EasyMock.replay(stateManager, recordCollector); + + task = createStatefulTask(createConfig(StreamsConfig.EXACTLY_ONCE_V2, "100"), true); + + task.initializeIfNeeded(); + task.completeRestoration(noOpResetter -> { }); + task.prepareCommit(); + task.postCommit(false); + final File checkpointFile = new File( + stateDirectory.getOrCreateDirectoryForTask(taskId), + StateManagerUtil.CHECKPOINT_FILE_NAME + ); + + assertFalse(checkpointFile.exists()); + } + + @SuppressWarnings("unchecked") + @Test + public void shouldThrowIllegalStateExceptionIfCurrentNodeIsNotNullWhenPunctuateCalled() { + task = createStatelessTask(createConfig("100")); + task.initializeIfNeeded(); + task.completeRestoration(noOpResetter -> { }); + task.processorContext().setCurrentNode(processorStreamTime); + try { + task.punctuate(processorStreamTime, 10, PunctuationType.STREAM_TIME, punctuator); + fail("Should throw illegal state exception as current node is not null"); + } catch (final IllegalStateException e) { + // pass + } + } + + @Test + public void shouldCallPunctuateOnPassedInProcessorNode() { + task = createStatelessTask(createConfig("100")); + task.initializeIfNeeded(); + task.completeRestoration(noOpResetter -> { }); + task.punctuate(processorStreamTime, 5, PunctuationType.STREAM_TIME, punctuator); + assertThat(punctuatedAt, equalTo(5L)); + task.punctuate(processorStreamTime, 10, PunctuationType.STREAM_TIME, punctuator); + assertThat(punctuatedAt, equalTo(10L)); + } + + @Test + public void shouldSetProcessorNodeOnContextBackToNullAfterSuccessfulPunctuate() { + task = createStatelessTask(createConfig("100")); + task.initializeIfNeeded(); + task.completeRestoration(noOpResetter -> { }); + task.punctuate(processorStreamTime, 5, PunctuationType.STREAM_TIME, punctuator); + assertThat(task.processorContext().currentNode(), nullValue()); + } + + @Test + public void shouldThrowIllegalStateExceptionOnScheduleIfCurrentNodeIsNull() { + task = createStatelessTask(createConfig("100")); + assertThrows(IllegalStateException.class, () -> task.schedule(1, PunctuationType.STREAM_TIME, timestamp -> { })); + } + + @SuppressWarnings("unchecked") + @Test + public void shouldNotThrowExceptionOnScheduleIfCurrentNodeIsNotNull() { + task = createStatelessTask(createConfig("100")); + task.processorContext().setCurrentNode(processorStreamTime); + task.schedule(1, PunctuationType.STREAM_TIME, timestamp -> { }); + } + + @Test + public void shouldCloseStateManagerEvenDuringFailureOnUncleanTaskClose() { 
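+        // even if suspend() fails on the faulty task, closeDirty() must still close the state manager;
+        // the mock set up below verifies the close() call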
+ EasyMock.expect(stateManager.changelogPartitions()).andReturn(Collections.emptySet()).anyTimes(); + EasyMock.expectLastCall(); + stateManager.close(); + EasyMock.expectLastCall(); + EasyMock.expect(recordCollector.offsets()).andReturn(emptyMap()).anyTimes(); + EasyMock.replay(stateManager, recordCollector); + + task = createFaultyStatefulTask(createConfig("100")); + + task.initializeIfNeeded(); + task.completeRestoration(noOpResetter -> { }); + + assertThrows(RuntimeException.class, () -> task.suspend()); + task.closeDirty(); + + EasyMock.verify(stateManager); + } + + @Test + public void shouldReturnOffsetsForRepartitionTopicsForPurging() { + final TopicPartition repartition = new TopicPartition("repartition", 1); + + final ProcessorTopology topology = withRepartitionTopics( + asList(source1, source2), + mkMap(mkEntry(topic1, source1), mkEntry(repartition.topic(), source2)), + singleton(repartition.topic()) + ); + consumer.assign(asList(partition1, repartition)); + consumer.updateBeginningOffsets(mkMap(mkEntry(repartition, 0L))); + + EasyMock.expect(stateManager.changelogPartitions()).andReturn(Collections.emptySet()); + EasyMock.expect(recordCollector.offsets()).andReturn(emptyMap()).anyTimes(); + EasyMock.replay(stateManager, recordCollector); + + final StreamsConfig config = createConfig(); + final InternalProcessorContext context = new ProcessorContextImpl( + taskId, + config, + stateManager, + streamsMetrics, + null + ); + + task = new StreamTask( + taskId, + mkSet(partition1, repartition), + topology, + consumer, + config, + streamsMetrics, + stateDirectory, + cache, + time, + stateManager, + recordCollector, + context, + logContext); + + task.initializeIfNeeded(); + task.completeRestoration(noOpResetter -> { }); + + task.addRecords(partition1, singletonList(getConsumerRecordWithOffsetAsTimestamp(partition1, 5L))); + task.addRecords(repartition, singletonList(getConsumerRecordWithOffsetAsTimestamp(repartition, 10L))); + + assertTrue(task.process(0L)); + assertTrue(task.process(0L)); + + task.prepareCommit(); + + final Map map = task.purgeableOffsets(); + + assertThat(map, equalTo(singletonMap(repartition, 11L))); + } + + @Test + public void shouldThrowStreamsExceptionWhenFetchCommittedFailed() { + EasyMock.expect(stateManager.changelogPartitions()).andReturn(singleton(partition1)); + EasyMock.replay(stateManager); + + final Consumer consumer = new MockConsumer(OffsetResetStrategy.EARLIEST) { + @Override + public Map committed(final Set partitions) { + throw new KafkaException("KABOOM!"); + } + }; + + task = createOptimizedStatefulTask(createConfig("100"), consumer); + + task.transitionTo(RESTORING); + + assertThrows(StreamsException.class, () -> task.completeRestoration(noOpResetter -> { })); + } + + @Test + public void shouldThrowIfCommittingOnIllegalState() { + task = createStatelessTask(createConfig("100")); + + task.transitionTo(SUSPENDED); + task.transitionTo(Task.State.CLOSED); + assertThrows(IllegalStateException.class, task::prepareCommit); + } + + @Test + public void shouldThrowIfPostCommittingOnIllegalState() { + task = createStatelessTask(createConfig("100")); + + task.transitionTo(SUSPENDED); + task.transitionTo(Task.State.CLOSED); + assertThrows(IllegalStateException.class, () -> task.postCommit(true)); + } + + @Test + public void shouldSkipCheckpointingSuspendedCreatedTask() { + stateManager.checkpoint(); + EasyMock.expectLastCall().andThrow(new AssertionError("Should not have tried to checkpoint")); + 
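+        // a task suspended straight out of CREATED has made no progress, so postCommit() must not write a checkpoint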
EasyMock.expect(recordCollector.offsets()).andReturn(emptyMap()).anyTimes(); + EasyMock.replay(stateManager, recordCollector); + + task = createStatefulTask(createConfig("100"), true); + task.suspend(); + task.postCommit(true); + } + + @Test + public void shouldCheckpointForSuspendedTask() { + stateManager.checkpoint(); + EasyMock.expectLastCall().once(); + EasyMock.expect(stateManager.changelogOffsets()) + .andReturn(singletonMap(partition1, 1L)); + EasyMock.expect(recordCollector.offsets()).andReturn(emptyMap()).anyTimes(); + EasyMock.replay(stateManager, recordCollector); + + task = createStatefulTask(createConfig("100"), true); + task.initializeIfNeeded(); + task.suspend(); + task.postCommit(true); + EasyMock.verify(stateManager); + } + + @Test + public void shouldNotCheckpointForSuspendedRunningTaskWithSmallProgress() { + EasyMock.expect(stateManager.changelogOffsets()) + .andReturn(singletonMap(partition1, 1L)) + .andReturn(singletonMap(partition1, 2L)); + stateManager.checkpoint(); + EasyMock.expectLastCall().andThrow(new AssertionError("Checkpoint should not be called")).anyTimes(); + EasyMock.replay(stateManager); + + task = createStatefulTask(createConfig("100"), true); + task.initializeIfNeeded(); + task.completeRestoration(noOpResetter -> { }); + + task.prepareCommit(); + task.postCommit(false); + + task.suspend(); + task.postCommit(false); + EasyMock.verify(stateManager); + } + + @Test + public void shouldCheckpointForSuspendedRunningTaskWithLargeProgress() { + EasyMock.expect(stateManager.changelogOffsets()) + .andReturn(singletonMap(partition1, 12000L)) + .andReturn(singletonMap(partition1, 24000L)); + stateManager.checkpoint(); + EasyMock.expectLastCall().times(2); + EasyMock.replay(stateManager); + + task = createStatefulTask(createConfig("100"), true); + task.initializeIfNeeded(); + task.completeRestoration(noOpResetter -> { }); + + task.prepareCommit(); + task.postCommit(false); + + task.suspend(); + task.postCommit(false); + EasyMock.verify(stateManager); + } + + @Test + public void shouldCheckpointWhileUpdateSnapshotWithTheConsumedOffsetsForSuspendedRunningTask() { + final Map checkpointableOffsets = singletonMap(partition1, 1L); + stateManager.checkpoint(); + EasyMock.expectLastCall().once(); + stateManager.updateChangelogOffsets(EasyMock.eq(checkpointableOffsets)); + EasyMock.expectLastCall().once(); + EasyMock.expect(stateManager.changelogOffsets()) + .andReturn(checkpointableOffsets); + EasyMock.expect(recordCollector.offsets()).andReturn(checkpointableOffsets).once(); + EasyMock.replay(stateManager, recordCollector); + + task = createStatefulTask(createConfig(), true); + task.initializeIfNeeded(); + task.completeRestoration(noOpResetter -> { }); + task.addRecords(partition1, singleton(getConsumerRecordWithOffsetAsTimestamp(partition1, 10))); + task.addRecords(partition2, singleton(getConsumerRecordWithOffsetAsTimestamp(partition2, 10))); + task.process(100L); + assertTrue(task.commitNeeded()); + + task.suspend(); + task.postCommit(true); + EasyMock.verify(stateManager, recordCollector); + } + + @Test + public void shouldReturnStateManagerChangelogOffsets() { + EasyMock.expect(stateManager.changelogOffsets()).andReturn(singletonMap(partition1, 50L)).anyTimes(); + EasyMock.expect(stateManager.changelogPartitions()).andReturn(singleton(partition1)).anyTimes(); + EasyMock.expect(recordCollector.offsets()).andReturn(emptyMap()).anyTimes(); + EasyMock.replay(stateManager, recordCollector); + + task = createOptimizedStatefulTask(createConfig("100"), consumer); + + 
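+        // while RESTORING the task should report the state manager's changelog offsets;
+        // after completeRestoration() it reports Task.LATEST_OFFSET instead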
task.initializeIfNeeded(); + + assertEquals(singletonMap(partition1, 50L), task.changelogOffsets()); + + task.completeRestoration(noOpResetter -> { }); + + assertEquals(singletonMap(partition1, Task.LATEST_OFFSET), task.changelogOffsets()); + } + + @Test + public void shouldNotCheckpointOnCloseCreated() { + stateManager.flush(); + EasyMock.expectLastCall().andThrow(new AssertionError("Flush should not be called")).anyTimes(); + stateManager.checkpoint(); + EasyMock.expectLastCall().andThrow(new AssertionError("Checkpoint should not be called")).anyTimes(); + EasyMock.expect(recordCollector.offsets()).andReturn(emptyMap()).anyTimes(); + EasyMock.replay(stateManager, recordCollector); + + final MetricName metricName = setupCloseTaskMetric(); + + task = createOptimizedStatefulTask(createConfig("100"), consumer); + + task.suspend(); + task.closeClean(); + + assertEquals(Task.State.CLOSED, task.state()); + assertFalse(source1.initialized); + assertFalse(source1.closed); + + EasyMock.verify(stateManager, recordCollector); + + final double expectedCloseTaskMetric = 1.0; + verifyCloseTaskMetric(expectedCloseTaskMetric, streamsMetrics, metricName); + } + + @Test + public void shouldCheckpointOnCloseRestoringIfNoProgress() { + stateManager.flush(); + EasyMock.expectLastCall().once(); + stateManager.checkpoint(); + EasyMock.expectLastCall().once(); + EasyMock.expect(stateManager.changelogPartitions()).andReturn(Collections.emptySet()).anyTimes(); + EasyMock.expect(stateManager.changelogOffsets()).andReturn(emptyMap()).anyTimes(); + EasyMock.expect(recordCollector.offsets()).andReturn(emptyMap()).anyTimes(); + EasyMock.replay(stateManager, recordCollector); + + task = createOptimizedStatefulTask(createConfig("100"), consumer); + + task.initializeIfNeeded(); + task.completeRestoration(noOpResetter -> { }); + task.suspend(); + task.prepareCommit(); + task.postCommit(true); + task.closeClean(); + + assertEquals(Task.State.CLOSED, task.state()); + + EasyMock.verify(stateManager); + } + + @Test + public void shouldCheckpointOffsetsOnPostCommit() { + final long offset = 543L; + final long consumedOffset = 345L; + + EasyMock.expect(recordCollector.offsets()).andReturn(singletonMap(changelogPartition, offset)).anyTimes(); + EasyMock.expectLastCall(); + stateManager.checkpoint(); + EasyMock.expectLastCall().once(); + EasyMock.expect(stateManager.changelogPartitions()).andReturn(Collections.emptySet()).anyTimes(); + EasyMock.expect(stateManager.changelogOffsets()) + .andReturn(singletonMap(partition1, offset + 12000L)); + EasyMock.replay(recordCollector, stateManager); + + task = createOptimizedStatefulTask(createConfig(), consumer); + task.initializeIfNeeded(); + task.completeRestoration(noOpResetter -> { }); + + task.addRecords(partition1, singletonList(getConsumerRecordWithOffsetAsTimestamp(partition1, consumedOffset))); + task.process(100L); + assertTrue(task.commitNeeded()); + + task.suspend(); + task.prepareCommit(); + task.postCommit(false); + + assertEquals(SUSPENDED, task.state()); + + EasyMock.verify(stateManager); + } + + @Test + public void shouldThrowExceptionOnCloseCleanError() { + final long offset = 543L; + + EasyMock.expect(recordCollector.offsets()).andReturn(emptyMap()).anyTimes(); + stateManager.checkpoint(); + EasyMock.expectLastCall().once(); + EasyMock.expect(stateManager.changelogPartitions()).andReturn(singleton(changelogPartition)).anyTimes(); + EasyMock.expect(stateManager.changelogOffsets()).andReturn(singletonMap(changelogPartition, offset)).anyTimes(); + stateManager.close(); + 
EasyMock.expectLastCall().andThrow(new ProcessorStateException("KABOOM!")).anyTimes(); + EasyMock.replay(recordCollector, stateManager); + final MetricName metricName = setupCloseTaskMetric(); + + task = createOptimizedStatefulTask(createConfig("100"), consumer); + task.initializeIfNeeded(); + task.completeRestoration(noOpResetter -> { }); + + task.addRecords(partition1, singletonList(getConsumerRecordWithOffsetAsTimestamp(partition1, offset))); + task.process(100L); + assertTrue(task.commitNeeded()); + + task.suspend(); + task.prepareCommit(); + task.postCommit(true); + assertThrows(ProcessorStateException.class, () -> task.closeClean()); + + final double expectedCloseTaskMetric = 0.0; + verifyCloseTaskMetric(expectedCloseTaskMetric, streamsMetrics, metricName); + + EasyMock.verify(stateManager); + EasyMock.reset(stateManager); + EasyMock.expect(stateManager.changelogPartitions()).andStubReturn(singleton(changelogPartition)); + stateManager.close(); + EasyMock.expectLastCall(); + EasyMock.replay(stateManager); + } + + @Test + public void shouldThrowOnCloseCleanFlushError() { + final long offset = 543L; + + EasyMock.expect(recordCollector.offsets()).andReturn(singletonMap(changelogPartition, offset)); + stateManager.flushCache(); + EasyMock.expectLastCall().andThrow(new ProcessorStateException("KABOOM!")).anyTimes(); + stateManager.flush(); + EasyMock.expectLastCall().andThrow(new AssertionError("Flush should not be called")).anyTimes(); + stateManager.checkpoint(); + EasyMock.expectLastCall().andThrow(new AssertionError("Checkpoint should not be called")).anyTimes(); + stateManager.close(); + EasyMock.expectLastCall().andThrow(new AssertionError("Close should not be called!")).anyTimes(); + EasyMock.expect(stateManager.changelogPartitions()).andReturn(Collections.emptySet()).anyTimes(); + EasyMock.expect(stateManager.changelogOffsets()).andReturn(emptyMap()).anyTimes(); + EasyMock.replay(recordCollector, stateManager); + final MetricName metricName = setupCloseTaskMetric(); + + task = createOptimizedStatefulTask(createConfig("100"), consumer); + task.initializeIfNeeded(); + task.completeRestoration(noOpResetter -> { }); + + // process one record to make commit needed + task.addRecords(partition1, singletonList(getConsumerRecordWithOffsetAsTimestamp(partition1, offset))); + task.process(100L); + + assertThrows(ProcessorStateException.class, task::prepareCommit); + + assertEquals(RUNNING, task.state()); + + final double expectedCloseTaskMetric = 0.0; + verifyCloseTaskMetric(expectedCloseTaskMetric, streamsMetrics, metricName); + + EasyMock.verify(stateManager); + EasyMock.reset(stateManager); + EasyMock.expect(stateManager.changelogPartitions()).andReturn(Collections.emptySet()).anyTimes(); + EasyMock.replay(stateManager); + } + + @Test + public void shouldThrowOnCloseCleanCheckpointError() { + final long offset = 54300L; + EasyMock.expect(recordCollector.offsets()).andReturn(emptyMap()); + stateManager.checkpoint(); + EasyMock.expectLastCall().andThrow(new ProcessorStateException("KABOOM!")).anyTimes(); + stateManager.close(); + EasyMock.expectLastCall().andThrow(new AssertionError("Close should not be called!")).anyTimes(); + EasyMock.expect(stateManager.changelogPartitions()).andReturn(Collections.emptySet()).anyTimes(); + EasyMock.expect(stateManager.changelogOffsets()) + .andReturn(singletonMap(partition1, offset)); + EasyMock.replay(recordCollector, stateManager); + final MetricName metricName = setupCloseTaskMetric(); + + task = createOptimizedStatefulTask(createConfig("100"), 
consumer); + task.initializeIfNeeded(); + + task.addRecords(partition1, singletonList(getConsumerRecordWithOffsetAsTimestamp(partition1, offset))); + task.process(100L); + assertTrue(task.commitNeeded()); + + task.suspend(); + task.prepareCommit(); + assertThrows(ProcessorStateException.class, () -> task.postCommit(true)); + + assertEquals(Task.State.SUSPENDED, task.state()); + + final double expectedCloseTaskMetric = 0.0; + verifyCloseTaskMetric(expectedCloseTaskMetric, streamsMetrics, metricName); + + EasyMock.verify(stateManager); + EasyMock.reset(stateManager); + EasyMock.expect(stateManager.changelogPartitions()).andReturn(Collections.emptySet()).anyTimes(); + stateManager.close(); + EasyMock.expectLastCall(); + EasyMock.replay(stateManager); + } + + @Test + public void shouldNotThrowFromStateManagerCloseInCloseDirty() { + stateManager.close(); + EasyMock.expectLastCall().andThrow(new RuntimeException("KABOOM!")).anyTimes(); + EasyMock.expect(stateManager.changelogOffsets()).andReturn(Collections.emptyMap()).anyTimes(); + EasyMock.replay(stateManager); + + task = createOptimizedStatefulTask(createConfig("100"), consumer); + task.initializeIfNeeded(); + + task.suspend(); + task.closeDirty(); + + EasyMock.verify(stateManager); + } + + @Test + public void shouldUnregisterMetricsInCloseClean() { + EasyMock.expect(stateManager.changelogPartitions()).andReturn(Collections.emptySet()).anyTimes(); + EasyMock.expect(recordCollector.offsets()).andReturn(Collections.emptyMap()).anyTimes(); + EasyMock.replay(stateManager, recordCollector); + + task = createOptimizedStatefulTask(createConfig("100"), consumer); + + task.suspend(); + assertThat(getTaskMetrics(), not(empty())); + task.closeClean(); + assertThat(getTaskMetrics(), empty()); + } + + @Test + public void shouldUnregisterMetricsInCloseDirty() { + EasyMock.expect(stateManager.changelogPartitions()).andReturn(Collections.emptySet()).anyTimes(); + EasyMock.expect(recordCollector.offsets()).andReturn(Collections.emptyMap()).anyTimes(); + EasyMock.replay(stateManager, recordCollector); + + task = createOptimizedStatefulTask(createConfig("100"), consumer); + + task.suspend(); + assertThat(getTaskMetrics(), not(empty())); + task.closeDirty(); + assertThat(getTaskMetrics(), empty()); + } + + @Test + public void shouldUnregisterMetricsInCloseCleanAndRecycleState() { + EasyMock.expect(stateManager.changelogPartitions()).andReturn(Collections.emptySet()).anyTimes(); + EasyMock.expect(recordCollector.offsets()).andReturn(Collections.emptyMap()).anyTimes(); + EasyMock.replay(stateManager, recordCollector); + + task = createOptimizedStatefulTask(createConfig("100"), consumer); + + task.suspend(); + assertThat(getTaskMetrics(), not(empty())); + task.closeCleanAndRecycleState(); + assertThat(getTaskMetrics(), empty()); + } + + @Test + public void shouldClearCommitStatusesInCloseDirty() { + task = createSingleSourceStateless(createConfig(AT_LEAST_ONCE, "0"), StreamsConfig.METRICS_LATEST); + task.initializeIfNeeded(); + task.completeRestoration(noOpResetter -> { }); + + task.addRecords(partition1, singletonList(getConsumerRecordWithOffsetAsTimestamp(partition1, 0))); + assertTrue(task.process(0L)); + task.requestCommit(); + + task.suspend(); + assertThat(task.commitNeeded(), is(true)); + assertThat(task.commitRequested(), is(true)); + task.closeDirty(); + assertThat(task.commitNeeded(), is(false)); + assertThat(task.commitRequested(), is(false)); + } + + @Test + public void closeShouldBeIdempotent() { + 
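+        // once the task is CLOSED, repeated closeClean()/closeDirty() calls should be harmless no-ops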
EasyMock.expect(stateManager.changelogPartitions()).andReturn(Collections.emptySet()).anyTimes(); + EasyMock.expect(recordCollector.offsets()).andReturn(Collections.emptyMap()).anyTimes(); + EasyMock.replay(stateManager, recordCollector); + + task = createOptimizedStatefulTask(createConfig("100"), consumer); + + task.suspend(); + task.closeClean(); + + // close calls are idempotent since we are already in closed + task.closeClean(); + task.closeDirty(); + + EasyMock.reset(stateManager); + EasyMock.replay(stateManager); + } + + @Test + public void shouldUpdatePartitions() { + task = createStatelessTask(createConfig()); + final Set newPartitions = new HashSet<>(task.inputPartitions()); + newPartitions.add(new TopicPartition("newTopic", 0)); + + task.updateInputPartitions(newPartitions, mkMap( + mkEntry(source1.name(), asList(topic1, "newTopic")), + mkEntry(source2.name(), singletonList(topic2))) + ); + + assertThat(task.inputPartitions(), equalTo(newPartitions)); + } + + @Test + public void shouldThrowIfCleanClosingDirtyTask() { + task = createSingleSourceStateless(createConfig(AT_LEAST_ONCE, "0"), StreamsConfig.METRICS_LATEST); + task.initializeIfNeeded(); + task.completeRestoration(noOpResetter -> { }); + + task.addRecords(partition1, singletonList(getConsumerRecordWithOffsetAsTimestamp(partition1, 0))); + assertTrue(task.process(0L)); + assertTrue(task.commitNeeded()); + + assertThrows(TaskMigratedException.class, () -> task.closeClean()); + } + + @Test + public void shouldThrowIfRecyclingDirtyTask() { + task = createStatelessTask(createConfig()); + task.initializeIfNeeded(); + task.completeRestoration(noOpResetter -> { }); + + task.addRecords(partition1, singletonList(getConsumerRecordWithOffsetAsTimestamp(partition1, 0))); + task.addRecords(partition2, singletonList(getConsumerRecordWithOffsetAsTimestamp(partition2, 0))); + task.process(0L); + assertTrue(task.commitNeeded()); + + assertThrows(TaskMigratedException.class, () -> task.closeCleanAndRecycleState()); + } + + @Test + public void shouldOnlyRecycleSuspendedTasks() { + stateManager.recycle(); + recordCollector.closeClean(); + EasyMock.expect(stateManager.changelogOffsets()).andReturn(Collections.emptyMap()).anyTimes(); + EasyMock.replay(stateManager, recordCollector); + + task = createStatefulTask(createConfig("100"), true); + assertThrows(IllegalStateException.class, () -> task.closeCleanAndRecycleState()); // CREATED + + task.initializeIfNeeded(); + assertThrows(IllegalStateException.class, () -> task.closeCleanAndRecycleState()); // RESTORING + + task.completeRestoration(noOpResetter -> { }); + assertThrows(IllegalStateException.class, () -> task.closeCleanAndRecycleState()); // RUNNING + + task.suspend(); + task.closeCleanAndRecycleState(); // SUSPENDED + + EasyMock.verify(stateManager, recordCollector); + } + + @Test + public void shouldAlwaysSuspendCreatedTasks() { + EasyMock.replay(stateManager); + task = createStatefulTask(createConfig("100"), true); + assertThat(task.state(), equalTo(CREATED)); + task.suspend(); + assertThat(task.state(), equalTo(SUSPENDED)); + } + + @Test + public void shouldAlwaysSuspendRestoringTasks() { + EasyMock.expect(stateManager.changelogOffsets()).andReturn(Collections.emptyMap()).anyTimes(); + EasyMock.replay(stateManager); + task = createStatefulTask(createConfig("100"), true); + task.initializeIfNeeded(); + assertThat(task.state(), equalTo(RESTORING)); + task.suspend(); + assertThat(task.state(), equalTo(SUSPENDED)); + } + + @Test + public void shouldAlwaysSuspendRunningTasks() { + 
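+        // suspend() should rethrow the faulty task's exception but still leave the task in SUSPENDED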
EasyMock.expect(stateManager.changelogOffsets()).andReturn(Collections.emptyMap()).anyTimes(); + EasyMock.replay(stateManager); + task = createFaultyStatefulTask(createConfig("100")); + task.initializeIfNeeded(); + task.completeRestoration(noOpResetter -> { }); + assertThat(task.state(), equalTo(RUNNING)); + assertThrows(RuntimeException.class, () -> task.suspend()); + assertThat(task.state(), equalTo(SUSPENDED)); + } + + @Test + public void shouldThrowTopologyExceptionIfTaskCreatedForUnknownTopic() { + final InternalProcessorContext context = new ProcessorContextImpl( + taskId, + createConfig("100"), + stateManager, + streamsMetrics, + null + ); + final StreamsMetricsImpl metrics = new StreamsMetricsImpl(this.metrics, "test", StreamsConfig.METRICS_LATEST, time); + EasyMock.expect(stateManager.changelogPartitions()).andReturn(Collections.emptySet()); + EasyMock.replay(stateManager); + + // The processor topology is missing the topics + final ProcessorTopology topology = withSources(emptyList(), mkMap()); + + final TopologyException exception = assertThrows( + TopologyException.class, + () -> new StreamTask( + taskId, + partitions, + topology, + consumer, + createConfig("100"), + metrics, + stateDirectory, + cache, + time, + stateManager, + recordCollector, + context, + logContext) + ); + + assertThat(exception.getMessage(), equalTo("Invalid topology: " + + "Topic is unknown to the topology. This may happen if different KafkaStreams instances of the same " + + "application execute different Topologies. Note that Topologies are only identical if all operators " + + "are added in the same order.")); + } + + @Test + public void shouldInitTaskTimeoutAndEventuallyThrow() { + task = createStatelessTask(createConfig()); + + task.maybeInitTaskTimeoutOrThrow(0L, null); + task.maybeInitTaskTimeoutOrThrow(Duration.ofMinutes(5).toMillis(), null); + + final StreamsException thrown = assertThrows( + StreamsException.class, + () -> task.maybeInitTaskTimeoutOrThrow(Duration.ofMinutes(5).plus(Duration.ofMillis(1L)).toMillis(), null) + ); + + assertThat(thrown.getCause(), isA(TimeoutException.class)); + } + + @Test + public void shouldCLearTaskTimeout() { + task = createStatelessTask(createConfig()); + + task.maybeInitTaskTimeoutOrThrow(0L, null); + task.clearTaskTimeout(); + task.maybeInitTaskTimeoutOrThrow(Duration.ofMinutes(5).plus(Duration.ofMillis(1L)).toMillis(), null); + } + + private List getTaskMetrics() { + return metrics.metrics().keySet().stream().filter(m -> m.tags().containsKey("task-id")).collect(Collectors.toList()); + } + + private StreamTask createOptimizedStatefulTask(final StreamsConfig config, final Consumer consumer) { + final StateStore stateStore = new MockKeyValueStore(storeName, true); + + final ProcessorTopology topology = ProcessorTopologyFactories.with( + singletonList(source1), + mkMap(mkEntry(topic1, source1)), + singletonList(stateStore), + Collections.singletonMap(storeName, topic1)); + + final InternalProcessorContext context = new ProcessorContextImpl( + taskId, + config, + stateManager, + streamsMetrics, + null + ); + + return new StreamTask( + taskId, + mkSet(partition1), + topology, + consumer, + config, + streamsMetrics, + stateDirectory, + cache, + time, + stateManager, + recordCollector, + context, + logContext + ); + } + + private StreamTask createDisconnectedTask(final StreamsConfig config) { + final MockKeyValueStore stateStore = new MockKeyValueStore(storeName, false); + + final ProcessorTopology topology = ProcessorTopologyFactories.with( + asList(source1, 
source2), + mkMap(mkEntry(topic1, source1), mkEntry(topic2, source2)), + singletonList(stateStore), + emptyMap()); + + final MockConsumer consumer = new MockConsumer(OffsetResetStrategy.EARLIEST) { + @Override + public Map committed(final Set partitions) { + throw new TimeoutException("KABOOM!"); + } + }; + + final InternalProcessorContext context = new ProcessorContextImpl( + taskId, + config, + stateManager, + streamsMetrics, + null + ); + + return new StreamTask( + taskId, + partitions, + topology, + consumer, + config, + streamsMetrics, + stateDirectory, + cache, + time, + stateManager, + recordCollector, + context, + logContext + ); + } + + private StreamTask createFaultyStatefulTask(final StreamsConfig config) { + final ProcessorTopology topology = ProcessorTopologyFactories.with( + asList(source1, source3), + mkMap(mkEntry(topic1, source1), mkEntry(topic2, source3)), + singletonList(stateStore), + emptyMap() + ); + + final InternalProcessorContext context = new ProcessorContextImpl( + taskId, + config, + stateManager, + streamsMetrics, + null + ); + + return new StreamTask( + taskId, + partitions, + topology, + consumer, + config, + streamsMetrics, + stateDirectory, + cache, + time, + stateManager, + recordCollector, + context, + logContext + ); + } + + private StreamTask createStatefulTask(final StreamsConfig config, final boolean logged) { + return createStatefulTask(config, logged, stateManager); + } + + private StreamTask createStatefulTask(final StreamsConfig config, final boolean logged, final ProcessorStateManager stateManager) { + final MockKeyValueStore stateStore = new MockKeyValueStore(storeName, logged); + + final ProcessorTopology topology = ProcessorTopologyFactories.with( + asList(source1, source2), + mkMap(mkEntry(topic1, source1), mkEntry(topic2, source2)), + singletonList(stateStore), + logged ? 
Collections.singletonMap(storeName, storeName + "-changelog") : Collections.emptyMap()); + + final InternalProcessorContext context = new ProcessorContextImpl( + taskId, + config, + stateManager, + streamsMetrics, + null + ); + + return new StreamTask( + taskId, + partitions, + topology, + consumer, + config, + streamsMetrics, + stateDirectory, + cache, + time, + stateManager, + recordCollector, + context, + logContext + ); + } + + private StreamTask createSingleSourceStateless(final StreamsConfig config, + final String builtInMetricsVersion) { + final ProcessorTopology topology = withSources( + asList(source1, processorStreamTime, processorSystemTime), + mkMap(mkEntry(topic1, source1)) + ); + + source1.addChild(processorStreamTime); + source1.addChild(processorSystemTime); + + EasyMock.expect(stateManager.changelogPartitions()).andReturn(Collections.emptySet()); + EasyMock.expect(stateManager.changelogOffsets()).andReturn(Collections.emptyMap()).anyTimes(); + EasyMock.expect(recordCollector.offsets()).andReturn(Collections.emptyMap()).anyTimes(); + EasyMock.replay(stateManager, recordCollector); + + final InternalProcessorContext context = new ProcessorContextImpl( + taskId, + config, + stateManager, + streamsMetrics, + null + ); + + return new StreamTask( + taskId, + mkSet(partition1), + topology, + consumer, + config, + new StreamsMetricsImpl(metrics, "test", builtInMetricsVersion, time), + stateDirectory, + cache, + time, + stateManager, + recordCollector, + context, + logContext + ); + } + + private StreamTask createStatelessTask(final StreamsConfig config) { + final ProcessorTopology topology = withSources( + asList(source1, source2, processorStreamTime, processorSystemTime), + mkMap(mkEntry(topic1, source1), mkEntry(topic2, source2)) + ); + + source1.addChild(processorStreamTime); + source2.addChild(processorStreamTime); + source1.addChild(processorSystemTime); + source2.addChild(processorSystemTime); + + EasyMock.expect(stateManager.changelogPartitions()).andReturn(Collections.emptySet()); + EasyMock.expect(stateManager.changelogOffsets()).andReturn(Collections.emptyMap()).anyTimes(); + EasyMock.expect(recordCollector.offsets()).andReturn(Collections.emptyMap()).anyTimes(); + EasyMock.replay(stateManager, recordCollector); + + final InternalProcessorContext context = new ProcessorContextImpl( + taskId, + config, + stateManager, + streamsMetrics, + null + ); + + return new StreamTask( + taskId, + partitions, + topology, + consumer, + config, + new StreamsMetricsImpl(metrics, "test", StreamsConfig.METRICS_LATEST, time), + stateDirectory, + cache, + time, + stateManager, + recordCollector, + context, + logContext + ); + } + + private StreamTask createStatelessTaskWithForwardingTopology(final SourceNode sourceNode) { + final ProcessorTopology topology = withSources( + asList(sourceNode, processorStreamTime), + singletonMap(topic1, sourceNode) + ); + + sourceNode.addChild(processorStreamTime); + + EasyMock.expect(stateManager.changelogPartitions()).andReturn(Collections.emptySet()); + EasyMock.expect(recordCollector.offsets()).andReturn(Collections.emptyMap()).anyTimes(); + EasyMock.replay(stateManager, recordCollector); + + final StreamsConfig config = createConfig(); + + final InternalProcessorContext context = new ProcessorContextImpl( + taskId, + config, + stateManager, + streamsMetrics, + null + ); + + return new StreamTask( + taskId, + singleton(partition1), + topology, + consumer, + config, + new StreamsMetricsImpl(metrics, "test", StreamsConfig.METRICS_LATEST, time), + 
stateDirectory, + cache, + time, + stateManager, + recordCollector, + context, + logContext + ); + } + + private void createTimeoutTask(final String eosConfig) { + EasyMock.replay(stateManager); + + final ProcessorTopology topology = withSources( + singletonList(timeoutSource), + mkMap(mkEntry(topic1, timeoutSource)) + ); + + final StreamsConfig config = createConfig(eosConfig, "0"); + final InternalProcessorContext context = new ProcessorContextImpl( + taskId, + config, + stateManager, + streamsMetrics, + null + ); + + task = new StreamTask( + taskId, + mkSet(partition1), + topology, + consumer, + config, + streamsMetrics, + stateDirectory, + cache, + time, + stateManager, + recordCollector, + context, + logContext + ); + } + + private ConsumerRecord getConsumerRecordWithOffsetAsTimestamp(final TopicPartition topicPartition, + final long offset, + final int value) { + return new ConsumerRecord<>( + topicPartition.topic(), + topicPartition.partition(), + offset, + offset, // use the offset as the timestamp + TimestampType.CREATE_TIME, + 0, + 0, + recordKey, + intSerializer.serialize(null, value), + new RecordHeaders(), + Optional.empty() + ); + } + + private ConsumerRecord getConsumerRecordWithOffsetAsTimestamp(final TopicPartition topicPartition, + final long offset) { + return new ConsumerRecord<>( + topicPartition.topic(), + topicPartition.partition(), + offset, + offset, // use the offset as the timestamp + TimestampType.CREATE_TIME, + 0, + 0, + recordKey, + recordValue, + new RecordHeaders(), + Optional.empty() + ); + } + + private ConsumerRecord getConsumerRecordWithOffsetAsTimestamp(final Integer key, final long offset) { + return new ConsumerRecord<>( + topic1, + 1, + offset, + offset, // use the offset as the timestamp + TimestampType.CREATE_TIME, + 0, + 0, + new IntegerSerializer().serialize(topic1, key), + recordValue, + new RecordHeaders(), + Optional.empty() + ); + } + + private MetricName setupCloseTaskMetric() { + final MetricName metricName = new MetricName("name", "group", "description", Collections.emptyMap()); + final Sensor sensor = streamsMetrics.threadLevelSensor(threadId, "task-closed", Sensor.RecordingLevel.INFO); + sensor.add(metricName, new CumulativeSum()); + return metricName; + } + + private void verifyCloseTaskMetric(final double expected, final StreamsMetricsImpl streamsMetrics, final MetricName metricName) { + final KafkaMetric metric = (KafkaMetric) streamsMetrics.metrics().get(metricName); + final double totalCloses = metric.measurable().measure(metric.config(), System.currentTimeMillis()); + assertThat(totalCloses, equalTo(expected)); + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java new file mode 100644 index 0000000..c735978 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTest.java @@ -0,0 +1,2991 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.processor.internals; + +import org.apache.kafka.clients.admin.MockAdminClient; +import org.apache.kafka.clients.consumer.Consumer; +import org.apache.kafka.clients.consumer.ConsumerGroupMetadata; +import org.apache.kafka.clients.consumer.ConsumerRebalanceListener; +import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.clients.consumer.ConsumerRecords; +import org.apache.kafka.clients.consumer.InvalidOffsetException; +import org.apache.kafka.clients.consumer.MockConsumer; +import org.apache.kafka.clients.consumer.OffsetResetStrategy; +import org.apache.kafka.clients.consumer.internals.MockRebalanceListener; +import org.apache.kafka.clients.producer.MockProducer; +import org.apache.kafka.clients.producer.Producer; +import org.apache.kafka.common.Cluster; +import org.apache.kafka.common.KafkaException; +import org.apache.kafka.common.Metric; +import org.apache.kafka.common.MetricName; +import org.apache.kafka.common.Node; +import org.apache.kafka.common.PartitionInfo; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.errors.ProducerFencedException; +import org.apache.kafka.common.errors.TimeoutException; +import org.apache.kafka.common.header.internals.RecordHeaders; +import org.apache.kafka.common.metrics.JmxReporter; +import org.apache.kafka.common.metrics.KafkaMetric; +import org.apache.kafka.common.metrics.KafkaMetricsContext; +import org.apache.kafka.common.metrics.Measurable; +import org.apache.kafka.common.metrics.Metrics; +import org.apache.kafka.common.metrics.MetricsContext; +import org.apache.kafka.common.record.TimestampType; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.ThreadMetadata; +import org.apache.kafka.streams.errors.LogAndContinueExceptionHandler; +import org.apache.kafka.streams.errors.StreamsException; +import org.apache.kafka.streams.errors.TaskCorruptedException; +import org.apache.kafka.streams.errors.TaskMigratedException; +import org.apache.kafka.streams.kstream.Materialized; +import org.apache.kafka.streams.kstream.internals.ConsumedInternal; +import org.apache.kafka.streams.kstream.internals.InternalStreamsBuilder; +import org.apache.kafka.streams.kstream.internals.MaterializedInternal; +import org.apache.kafka.streams.processor.LogAndSkipOnInvalidTimestamp; +import org.apache.kafka.streams.processor.PunctuationType; +import org.apache.kafka.streams.processor.TaskId; +import org.apache.kafka.streams.processor.api.ProcessorSupplier; +import org.apache.kafka.streams.processor.internals.assignment.ReferenceContainer; +import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl; +import org.apache.kafka.streams.processor.internals.testutil.LogCaptureAppender; +import org.apache.kafka.streams.state.KeyValueStore; +import 
org.apache.kafka.streams.state.StoreBuilder; +import org.apache.kafka.streams.state.Stores; +import org.apache.kafka.streams.state.internals.OffsetCheckpoint; +import org.apache.kafka.test.MockApiProcessor; +import org.apache.kafka.test.MockClientSupplier; +import org.apache.kafka.test.MockKeyValueStoreBuilder; +import org.apache.kafka.test.MockStateRestoreListener; +import org.apache.kafka.test.MockTimestampExtractor; +import org.apache.kafka.test.StreamsTestUtils; +import org.apache.kafka.test.TestUtils; +import org.easymock.EasyMock; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.slf4j.Logger; + +import java.io.File; +import java.time.Duration; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Properties; +import java.util.Set; +import java.util.UUID; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicLong; +import java.util.stream.Stream; + +import static java.util.Collections.emptyMap; +import static java.util.Collections.emptySet; +import static java.util.Collections.singleton; +import static java.util.Collections.singletonMap; +import static org.apache.kafka.common.utils.Utils.mkEntry; +import static org.apache.kafka.common.utils.Utils.mkMap; +import static org.apache.kafka.common.utils.Utils.mkProperties; +import static org.apache.kafka.common.utils.Utils.mkSet; +import static org.apache.kafka.streams.processor.internals.ClientUtils.getSharedAdminClientId; +import static org.apache.kafka.streams.processor.internals.StateManagerUtil.CHECKPOINT_FILE_NAME; +import static org.easymock.EasyMock.anyObject; +import static org.easymock.EasyMock.expect; +import static org.easymock.EasyMock.expectLastCall; +import static org.easymock.EasyMock.mock; +import static org.easymock.EasyMock.niceMock; +import static org.easymock.EasyMock.verify; +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.CoreMatchers.not; +import static org.hamcrest.CoreMatchers.startsWith; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.empty; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.isA; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +public class StreamThreadTest { + + private final static String APPLICATION_ID = "stream-thread-test"; + private final static UUID PROCESS_ID = UUID.fromString("87bf53a8-54f2-485f-a4b6-acdbec0a8b3d"); + private final static String CLIENT_ID = APPLICATION_ID + "-" + PROCESS_ID; + + private final int threadIdx = 1; + private final Metrics metrics = new Metrics(); + private final MockTime mockTime = new MockTime(); + private final String stateDir = TestUtils.tempDirectory().getPath(); + private final MockClientSupplier clientSupplier = new MockClientSupplier(); + private final StreamsConfig config = new StreamsConfig(configProps(false)); + private final StreamsConfig eosEnabledConfig = new 
StreamsConfig(configProps(true)); + private final ConsumedInternal consumed = new ConsumedInternal<>(); + private final ChangelogReader changelogReader = new MockChangelogReader(); + private final StateDirectory stateDirectory = new StateDirectory(config, mockTime, true, false); + private final InternalTopologyBuilder internalTopologyBuilder = new InternalTopologyBuilder(); + private final InternalStreamsBuilder internalStreamsBuilder = new InternalStreamsBuilder(internalTopologyBuilder); + + private StreamsMetadataState streamsMetadataState; + private final static java.util.function.Consumer HANDLER = e -> { + if (e instanceof RuntimeException) { + throw (RuntimeException) e; + } else if (e instanceof Error) { + throw (Error) e; + } else { + throw new RuntimeException("Unexpected checked exception caught in the uncaught exception handler", e); + } + }; + + @Before + public void setUp() { + Thread.currentThread().setName(CLIENT_ID + "-StreamThread-" + threadIdx); + internalTopologyBuilder.setApplicationId(APPLICATION_ID); + streamsMetadataState = new StreamsMetadataState(new TopologyMetadata(internalTopologyBuilder, config), StreamsMetadataState.UNKNOWN_HOST); + } + + private final String topic1 = "topic1"; + private final String topic2 = "topic2"; + + private final TopicPartition t1p1 = new TopicPartition(topic1, 1); + private final TopicPartition t1p2 = new TopicPartition(topic1, 2); + private final TopicPartition t2p1 = new TopicPartition(topic2, 1); + + // task0 is unused + private final TaskId task1 = new TaskId(0, 1); + private final TaskId task2 = new TaskId(0, 2); + private final TaskId task3 = new TaskId(1, 1); + + private Properties configProps(final boolean enableEoS) { + return mkProperties(mkMap( + mkEntry(StreamsConfig.APPLICATION_ID_CONFIG, APPLICATION_ID), + mkEntry(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:2171"), + mkEntry(StreamsConfig.BUFFERED_RECORDS_PER_PARTITION_CONFIG, "3"), + mkEntry(StreamsConfig.DEFAULT_TIMESTAMP_EXTRACTOR_CLASS_CONFIG, MockTimestampExtractor.class.getName()), + mkEntry(StreamsConfig.STATE_DIR_CONFIG, TestUtils.tempDirectory().getAbsolutePath()), + mkEntry(StreamsConfig.PROCESSING_GUARANTEE_CONFIG, enableEoS ? 
StreamsConfig.EXACTLY_ONCE_V2 : StreamsConfig.AT_LEAST_ONCE), + mkEntry(StreamsConfig.DEFAULT_KEY_SERDE_CLASS_CONFIG, Serdes.ByteArraySerde.class.getName()), + mkEntry(StreamsConfig.DEFAULT_VALUE_SERDE_CLASS_CONFIG, Serdes.ByteArraySerde.class.getName()) + )); + } + + private Cluster createCluster() { + final Node node = new Node(-1, "localhost", 8121); + return new Cluster( + "mockClusterId", + Collections.singletonList(node), + Collections.emptySet(), + Collections.emptySet(), + Collections.emptySet(), + node + ); + } + + private StreamThread createStreamThread(@SuppressWarnings("SameParameterValue") final String clientId, + final StreamsConfig config, + final boolean eosEnabled) { + if (eosEnabled) { + clientSupplier.setApplicationIdForProducer(APPLICATION_ID); + } + + clientSupplier.setCluster(createCluster()); + + final StreamsMetricsImpl streamsMetrics = new StreamsMetricsImpl( + metrics, + APPLICATION_ID, + config.getString(StreamsConfig.BUILT_IN_METRICS_VERSION_CONFIG), + mockTime + ); + + internalTopologyBuilder.buildTopology(); + + return StreamThread.create( + new TopologyMetadata(internalTopologyBuilder, config), + config, + clientSupplier, + clientSupplier.getAdmin(config.getAdminConfigs(clientId)), + PROCESS_ID, + clientId, + streamsMetrics, + mockTime, + streamsMetadataState, + 0, + stateDirectory, + new MockStateRestoreListener(), + threadIdx, + null, + HANDLER + ); + } + + private static class StateListenerStub implements StreamThread.StateListener { + int numChanges = 0; + ThreadStateTransitionValidator oldState = null; + ThreadStateTransitionValidator newState = null; + + @Override + public void onChange(final Thread thread, + final ThreadStateTransitionValidator newState, + final ThreadStateTransitionValidator oldState) { + ++numChanges; + if (this.newState != null) { + if (!this.newState.equals(oldState)) { + throw new RuntimeException("State mismatch " + oldState + " different from " + this.newState); + } + } + this.oldState = oldState; + this.newState = newState; + } + } + + @Test + public void shouldChangeStateInRebalanceListener() { + final StreamThread thread = createStreamThread(CLIENT_ID, config, false); + + final StateListenerStub stateListener = new StateListenerStub(); + thread.setStateListener(stateListener); + assertEquals(thread.state(), StreamThread.State.CREATED); + + final ConsumerRebalanceListener rebalanceListener = thread.rebalanceListener(); + + final List revokedPartitions; + final List assignedPartitions; + + // revoke nothing + thread.setState(StreamThread.State.STARTING); + revokedPartitions = Collections.emptyList(); + rebalanceListener.onPartitionsRevoked(revokedPartitions); + + assertEquals(thread.state(), StreamThread.State.PARTITIONS_REVOKED); + + // assign single partition + assignedPartitions = Collections.singletonList(t1p1); + + final MockConsumer mockConsumer = (MockConsumer) thread.mainConsumer(); + mockConsumer.assign(assignedPartitions); + mockConsumer.updateBeginningOffsets(Collections.singletonMap(t1p1, 0L)); + rebalanceListener.onPartitionsAssigned(assignedPartitions); + thread.runOnce(); + assertEquals(thread.state(), StreamThread.State.RUNNING); + Assert.assertEquals(4, stateListener.numChanges); + Assert.assertEquals(StreamThread.State.PARTITIONS_ASSIGNED, stateListener.oldState); + + thread.shutdown(); + assertSame(StreamThread.State.PENDING_SHUTDOWN, thread.state()); + } + + @Test + public void shouldChangeStateAtStartClose() throws Exception { + final StreamThread thread = createStreamThread(CLIENT_ID, config, false); + + 
final StateListenerStub stateListener = new StateListenerStub(); + thread.setStateListener(stateListener); + + thread.start(); + TestUtils.waitForCondition( + () -> thread.state() == StreamThread.State.STARTING, + 10 * 1000, + "Thread never started."); + + thread.shutdown(); + TestUtils.waitForCondition( + () -> thread.state() == StreamThread.State.DEAD, + 10 * 1000, + "Thread never shut down."); + + thread.shutdown(); + assertEquals(thread.state(), StreamThread.State.DEAD); + } + + @Test + public void shouldCreateMetricsAtStartup() { + final StreamThread thread = createStreamThread(CLIENT_ID, config, false); + final String defaultGroupName = "stream-thread-metrics"; + final Map defaultTags = Collections.singletonMap( + "thread-id", + thread.getName() + ); + final String descriptionIsNotVerified = ""; + + assertNotNull(metrics.metrics().get(metrics.metricName( + "commit-latency-avg", defaultGroupName, descriptionIsNotVerified, defaultTags))); + assertNotNull(metrics.metrics().get(metrics.metricName( + "commit-latency-max", defaultGroupName, descriptionIsNotVerified, defaultTags))); + assertNotNull(metrics.metrics().get(metrics.metricName( + "commit-rate", defaultGroupName, descriptionIsNotVerified, defaultTags))); + assertNotNull(metrics.metrics().get(metrics.metricName( + "commit-total", defaultGroupName, descriptionIsNotVerified, defaultTags))); + assertNotNull(metrics.metrics().get(metrics.metricName( + "commit-ratio", defaultGroupName, descriptionIsNotVerified, defaultTags))); + assertNotNull(metrics.metrics().get(metrics.metricName( + "poll-latency-avg", defaultGroupName, descriptionIsNotVerified, defaultTags))); + assertNotNull(metrics.metrics().get(metrics.metricName( + "poll-latency-max", defaultGroupName, descriptionIsNotVerified, defaultTags))); + assertNotNull(metrics.metrics().get(metrics.metricName( + "poll-rate", defaultGroupName, descriptionIsNotVerified, defaultTags))); + assertNotNull(metrics.metrics().get(metrics.metricName( + "poll-total", defaultGroupName, descriptionIsNotVerified, defaultTags))); + assertNotNull(metrics.metrics().get(metrics.metricName( + "poll-ratio", defaultGroupName, descriptionIsNotVerified, defaultTags))); + assertNotNull(metrics.metrics().get(metrics.metricName( + "poll-records-avg", defaultGroupName, descriptionIsNotVerified, defaultTags))); + assertNotNull(metrics.metrics().get(metrics.metricName( + "poll-records-max", defaultGroupName, descriptionIsNotVerified, defaultTags))); + assertNotNull(metrics.metrics().get(metrics.metricName( + "process-latency-avg", defaultGroupName, descriptionIsNotVerified, defaultTags))); + assertNotNull(metrics.metrics().get(metrics.metricName( + "process-latency-max", defaultGroupName, descriptionIsNotVerified, defaultTags))); + assertNotNull(metrics.metrics().get(metrics.metricName( + "process-rate", defaultGroupName, descriptionIsNotVerified, defaultTags))); + assertNotNull(metrics.metrics().get(metrics.metricName( + "process-total", defaultGroupName, descriptionIsNotVerified, defaultTags))); + assertNotNull(metrics.metrics().get(metrics.metricName( + "process-ratio", defaultGroupName, descriptionIsNotVerified, defaultTags))); + assertNotNull(metrics.metrics().get(metrics.metricName( + "process-records-avg", defaultGroupName, descriptionIsNotVerified, defaultTags))); + assertNotNull(metrics.metrics().get(metrics.metricName( + "process-records-max", defaultGroupName, descriptionIsNotVerified, defaultTags))); + assertNotNull(metrics.metrics().get(metrics.metricName( + "punctuate-latency-avg", defaultGroupName, 
descriptionIsNotVerified, defaultTags))); + assertNotNull(metrics.metrics().get(metrics.metricName( + "punctuate-latency-max", defaultGroupName, descriptionIsNotVerified, defaultTags))); + assertNotNull(metrics.metrics().get(metrics.metricName( + "punctuate-rate", defaultGroupName, descriptionIsNotVerified, defaultTags))); + assertNotNull(metrics.metrics().get(metrics.metricName( + "punctuate-total", defaultGroupName, descriptionIsNotVerified, defaultTags))); + assertNotNull(metrics.metrics().get(metrics.metricName( + "punctuate-ratio", defaultGroupName, descriptionIsNotVerified, defaultTags))); + assertNotNull(metrics.metrics().get(metrics.metricName( + "task-created-rate", defaultGroupName, descriptionIsNotVerified, defaultTags))); + assertNotNull(metrics.metrics().get(metrics.metricName( + "task-created-total", defaultGroupName, descriptionIsNotVerified, defaultTags))); + assertNotNull(metrics.metrics().get(metrics.metricName( + "task-closed-rate", defaultGroupName, descriptionIsNotVerified, defaultTags))); + assertNotNull(metrics.metrics().get(metrics.metricName( + "task-closed-total", defaultGroupName, descriptionIsNotVerified, defaultTags))); + + assertNull(metrics.metrics().get(metrics.metricName( + "skipped-records-rate", defaultGroupName, descriptionIsNotVerified, defaultTags))); + assertNull(metrics.metrics().get(metrics.metricName( + "skipped-records-total", defaultGroupName, descriptionIsNotVerified, defaultTags))); + + final String taskGroupName = "stream-task-metrics"; + final Map taskTags = + mkMap(mkEntry("task-id", "all"), mkEntry("thread-id", thread.getName())); + assertNull(metrics.metrics().get(metrics.metricName( + "commit-latency-avg", taskGroupName, descriptionIsNotVerified, taskTags))); + assertNull(metrics.metrics().get(metrics.metricName( + "commit-latency-max", taskGroupName, descriptionIsNotVerified, taskTags))); + assertNull(metrics.metrics().get(metrics.metricName( + "commit-rate", taskGroupName, descriptionIsNotVerified, taskTags))); + + final JmxReporter reporter = new JmxReporter(); + final MetricsContext metricsContext = new KafkaMetricsContext("kafka.streams"); + reporter.contextChange(metricsContext); + + metrics.addReporter(reporter); + assertEquals(CLIENT_ID + "-StreamThread-1", thread.getName()); + assertTrue(reporter.containsMbean(String.format("kafka.streams:type=%s,%s=%s", + defaultGroupName, + "thread-id", + thread.getName()) + )); + assertFalse(reporter.containsMbean(String.format( + "kafka.streams:type=stream-task-metrics,%s=%s,task-id=all", + "thread-id", + thread.getName()))); + } + + @Test + public void shouldNotCommitBeforeTheCommitInterval() { + final long commitInterval = 1000L; + final Properties props = configProps(false); + props.setProperty(StreamsConfig.STATE_DIR_CONFIG, stateDir); + props.setProperty(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG, Long.toString(commitInterval)); + + final StreamsConfig config = new StreamsConfig(props); + final Consumer consumer = EasyMock.createNiceMock(Consumer.class); + final ConsumerGroupMetadata consumerGroupMetadata = mock(ConsumerGroupMetadata.class); + expect(consumer.groupMetadata()).andStubReturn(consumerGroupMetadata); + expect(consumerGroupMetadata.groupInstanceId()).andReturn(Optional.empty()); + final TaskManager taskManager = mockTaskManagerCommit(consumer, 1, 1); + EasyMock.replay(consumer, consumerGroupMetadata); + + final StreamsMetricsImpl streamsMetrics = + new StreamsMetricsImpl(metrics, CLIENT_ID, StreamsConfig.METRICS_LATEST, mockTime); + final StreamThread thread = new StreamThread( 
+ mockTime, + config, + null, + consumer, + consumer, + null, + null, + taskManager, + streamsMetrics, + new TopologyMetadata(internalTopologyBuilder, config), + CLIENT_ID, + new LogContext(""), + new AtomicInteger(), + new AtomicLong(Long.MAX_VALUE), + null, + HANDLER, + null + ); + thread.setNow(mockTime.milliseconds()); + thread.maybeCommit(); + mockTime.sleep(commitInterval - 10L); + thread.setNow(mockTime.milliseconds()); + thread.maybeCommit(); + + verify(taskManager); + } + + @Test + public void shouldEnforceRebalanceAfterNextScheduledProbingRebalanceTime() throws InterruptedException { + final StreamsConfig config = new StreamsConfig(configProps(false)); + internalTopologyBuilder.buildTopology(); + final StreamsMetricsImpl streamsMetrics = new StreamsMetricsImpl( + metrics, + APPLICATION_ID, + config.getString(StreamsConfig.BUILT_IN_METRICS_VERSION_CONFIG), + mockTime + ); + + final Consumer mockConsumer = EasyMock.createNiceMock(Consumer.class); + expect(mockConsumer.poll(anyObject())).andStubReturn(ConsumerRecords.empty()); + final ConsumerGroupMetadata consumerGroupMetadata = mock(ConsumerGroupMetadata.class); + expect(mockConsumer.groupMetadata()).andStubReturn(consumerGroupMetadata); + expect(consumerGroupMetadata.groupInstanceId()).andReturn(Optional.empty()); + EasyMock.replay(consumerGroupMetadata); + final EasyMockConsumerClientSupplier mockClientSupplier = new EasyMockConsumerClientSupplier(mockConsumer); + + mockClientSupplier.setCluster(createCluster()); + EasyMock.replay(mockConsumer); + final StreamThread thread = StreamThread.create( + new TopologyMetadata(internalTopologyBuilder, config), + config, + mockClientSupplier, + mockClientSupplier.getAdmin(config.getAdminConfigs(CLIENT_ID)), + PROCESS_ID, + CLIENT_ID, + streamsMetrics, + mockTime, + streamsMetadataState, + 0, + stateDirectory, + new MockStateRestoreListener(), + threadIdx, + null, + null + ); + + mockConsumer.enforceRebalance(); + + mockClientSupplier.nextRebalanceMs().set(mockTime.milliseconds() - 1L); + + thread.start(); + TestUtils.waitForCondition( + () -> thread.state() == StreamThread.State.STARTING, + 10 * 1000, + "Thread never started."); + + TestUtils.retryOnExceptionWithTimeout( + () -> verify(mockConsumer) + ); + + thread.shutdown(); + TestUtils.waitForCondition( + () -> thread.state() == StreamThread.State.DEAD, + 10 * 1000, + "Thread never shut down."); + + } + + private static class EasyMockConsumerClientSupplier extends MockClientSupplier { + final Consumer mockConsumer; + final Map consumerConfigs = new HashMap<>(); + + EasyMockConsumerClientSupplier(final Consumer mockConsumer) { + this.mockConsumer = mockConsumer; + } + + @Override + public Consumer getConsumer(final Map config) { + consumerConfigs.putAll(config); + return mockConsumer; + } + + AtomicLong nextRebalanceMs() { + return ((ReferenceContainer) consumerConfigs.get( + StreamsConfig.InternalConfig.REFERENCE_CONTAINER_PARTITION_ASSIGNOR) + ).nextScheduledRebalanceMs; + } + } + + @Test + public void shouldRespectNumIterationsInMainLoop() { + final List> mockProcessors = new LinkedList<>(); + internalTopologyBuilder.addSource(null, "source1", null, null, null, topic1); + internalTopologyBuilder.addProcessor( + "processor1", + (ProcessorSupplier) () -> { + final MockApiProcessor processor = new MockApiProcessor<>(PunctuationType.WALL_CLOCK_TIME, 10L); + mockProcessors.add(processor); + return processor; + }, + "source1" + ); + internalTopologyBuilder.addProcessor( + "processor2", + (ProcessorSupplier) () -> new 
MockApiProcessor<>(PunctuationType.STREAM_TIME, 10L), + "source1" + ); + + final Properties properties = new Properties(); + properties.put(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG, 100L); + final StreamsConfig config = new StreamsConfig(StreamsTestUtils.getStreamsConfig(APPLICATION_ID, + "localhost:2171", + Serdes.ByteArraySerde.class.getName(), + Serdes.ByteArraySerde.class.getName(), + properties)); + final StreamThread thread = createStreamThread(CLIENT_ID, config, false); + + thread.setState(StreamThread.State.STARTING); + thread.setState(StreamThread.State.PARTITIONS_REVOKED); + + final TaskId task1 = new TaskId(0, t1p1.partition()); + final Set assignedPartitions = Collections.singleton(t1p1); + + thread.taskManager().handleAssignment(Collections.singletonMap(task1, assignedPartitions), emptyMap()); + + final MockConsumer mockConsumer = (MockConsumer) thread.mainConsumer(); + mockConsumer.assign(Collections.singleton(t1p1)); + mockConsumer.updateBeginningOffsets(Collections.singletonMap(t1p1, 0L)); + thread.rebalanceListener().onPartitionsAssigned(assignedPartitions); + thread.runOnce(); + + // processed one record, punctuated after the first record, and hence num.iterations is still 1 + long offset = -1; + addRecord(mockConsumer, ++offset, 0L); + thread.runOnce(); + + assertThat(thread.currentNumIterations(), equalTo(1)); + + // processed one more record without punctuation, and bump num.iterations to 2 + addRecord(mockConsumer, ++offset, 1L); + thread.runOnce(); + + assertThat(thread.currentNumIterations(), equalTo(2)); + + // processed zero records, early exit and iterations stays as 2 + thread.runOnce(); + assertThat(thread.currentNumIterations(), equalTo(2)); + + // system time based punctutation without processing any record, iteration stays as 2 + mockTime.sleep(11L); + + thread.runOnce(); + assertThat(thread.currentNumIterations(), equalTo(2)); + + // system time based punctutation after processing a record, half iteration to 1 + mockTime.sleep(11L); + addRecord(mockConsumer, ++offset, 5L); + + thread.runOnce(); + assertThat(thread.currentNumIterations(), equalTo(1)); + + // processed two records, bumping up iterations to 3 (1 + 2) + addRecord(mockConsumer, ++offset, 5L); + addRecord(mockConsumer, ++offset, 6L); + thread.runOnce(); + + assertThat(thread.currentNumIterations(), equalTo(3)); + + // stream time based punctutation halves to 1 + addRecord(mockConsumer, ++offset, 11L); + thread.runOnce(); + + assertThat(thread.currentNumIterations(), equalTo(1)); + + // processed three records, bumping up iterations to 3 (1 + 2) + addRecord(mockConsumer, ++offset, 12L); + addRecord(mockConsumer, ++offset, 13L); + addRecord(mockConsumer, ++offset, 14L); + thread.runOnce(); + + assertThat(thread.currentNumIterations(), equalTo(3)); + + mockProcessors.forEach(MockApiProcessor::requestCommit); + addRecord(mockConsumer, ++offset, 15L); + thread.runOnce(); + + // user requested commit should half iteration to 1 + assertThat(thread.currentNumIterations(), equalTo(1)); + + // processed three records, bumping up iterations to 3 (1 + 2) + addRecord(mockConsumer, ++offset, 15L); + addRecord(mockConsumer, ++offset, 16L); + addRecord(mockConsumer, ++offset, 17L); + thread.runOnce(); + + assertThat(thread.currentNumIterations(), equalTo(3)); + + // time based commit without processing, should keep the iteration as 3 + mockTime.sleep(90L); + thread.runOnce(); + + assertThat(thread.currentNumIterations(), equalTo(3)); + + // time based commit without processing, should half the iteration to 1 + 
mockTime.sleep(90L); + addRecord(mockConsumer, ++offset, 18L); + thread.runOnce(); + + assertThat(thread.currentNumIterations(), equalTo(1)); + } + + @Test + public void shouldNotCauseExceptionIfNothingCommitted() { + final long commitInterval = 1000L; + final Properties props = configProps(false); + props.setProperty(StreamsConfig.STATE_DIR_CONFIG, stateDir); + props.setProperty(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG, Long.toString(commitInterval)); + + final StreamsConfig config = new StreamsConfig(props); + final Consumer consumer = EasyMock.createNiceMock(Consumer.class); + final ConsumerGroupMetadata consumerGroupMetadata = mock(ConsumerGroupMetadata.class); + expect(consumer.groupMetadata()).andStubReturn(consumerGroupMetadata); + expect(consumerGroupMetadata.groupInstanceId()).andReturn(Optional.empty()); + EasyMock.replay(consumer, consumerGroupMetadata); + final TaskManager taskManager = mockTaskManagerCommit(consumer, 1, 0); + + final StreamsMetricsImpl streamsMetrics = + new StreamsMetricsImpl(metrics, CLIENT_ID, StreamsConfig.METRICS_LATEST, mockTime); + final StreamThread thread = new StreamThread( + mockTime, + config, + null, + consumer, + consumer, + null, + null, + taskManager, + streamsMetrics, + new TopologyMetadata(internalTopologyBuilder, config), + CLIENT_ID, + new LogContext(""), + new AtomicInteger(), + new AtomicLong(Long.MAX_VALUE), + null, + HANDLER, + null + ); + thread.setNow(mockTime.milliseconds()); + thread.maybeCommit(); + mockTime.sleep(commitInterval - 10L); + thread.setNow(mockTime.milliseconds()); + thread.maybeCommit(); + + verify(taskManager); + } + + @Test + public void shouldCommitAfterCommitInterval() { + final long commitInterval = 100L; + final long commitLatency = 10L; + + final Properties props = configProps(false); + props.setProperty(StreamsConfig.STATE_DIR_CONFIG, stateDir); + props.setProperty(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG, Long.toString(commitInterval)); + + final StreamsConfig config = new StreamsConfig(props); + final Consumer consumer = EasyMock.createNiceMock(Consumer.class); + final ConsumerGroupMetadata consumerGroupMetadata = mock(ConsumerGroupMetadata.class); + expect(consumer.groupMetadata()).andStubReturn(consumerGroupMetadata); + expect(consumerGroupMetadata.groupInstanceId()).andReturn(Optional.empty()); + EasyMock.replay(consumer, consumerGroupMetadata); + + final StreamsMetricsImpl streamsMetrics = + new StreamsMetricsImpl(metrics, CLIENT_ID, StreamsConfig.METRICS_LATEST, mockTime); + + final AtomicBoolean committed = new AtomicBoolean(false); + final TaskManager taskManager = new TaskManager( + null, + null, + null, + null, + null, + null, + null, + null, + null, + null, + null + ) { + @Override + int commit(final Collection tasksToCommit) { + committed.set(true); + // we advance time to make sure the commit delay is considered when computing the next commit timestamp + mockTime.sleep(commitLatency); + return 1; + } + }; + + final StreamThread thread = new StreamThread( + mockTime, + config, + null, + consumer, + consumer, + changelogReader, + null, + taskManager, + streamsMetrics, + new TopologyMetadata(internalTopologyBuilder, config), + CLIENT_ID, + new LogContext(""), + new AtomicInteger(), + new AtomicLong(Long.MAX_VALUE), + null, + HANDLER, + null + ); + + thread.setNow(mockTime.milliseconds()); + thread.maybeCommit(); + assertTrue(committed.get()); + + mockTime.sleep(commitInterval); + + committed.set(false); + thread.setNow(mockTime.milliseconds()); + thread.maybeCommit(); + assertFalse(committed.get()); + 
+ mockTime.sleep(1); + + committed.set(false); + thread.setNow(mockTime.milliseconds()); + thread.maybeCommit(); + assertTrue(committed.get()); + } + + @Test + public void shouldRecordCommitLatency() { + final Consumer consumer = EasyMock.createNiceMock(Consumer.class); + final ConsumerGroupMetadata consumerGroupMetadata = mock(ConsumerGroupMetadata.class); + expect(consumer.groupMetadata()).andStubReturn(consumerGroupMetadata); + expect(consumerGroupMetadata.groupInstanceId()).andReturn(Optional.empty()); + expect(consumer.poll(anyObject())).andStubReturn(new ConsumerRecords<>(Collections.emptyMap())); + final Task task = niceMock(Task.class); + expect(task.id()).andStubReturn(task1); + expect(task.inputPartitions()).andStubReturn(Collections.singleton(t1p1)); + expect(task.committedOffsets()).andStubReturn(Collections.emptyMap()); + expect(task.highWaterMark()).andStubReturn(Collections.emptyMap()); + final ActiveTaskCreator activeTaskCreator = mock(ActiveTaskCreator.class); + expect(activeTaskCreator.createTasks(anyObject(), anyObject())).andStubReturn(Collections.singleton(task)); + expect(activeTaskCreator.producerClientIds()).andStubReturn(Collections.singleton("producerClientId")); + expect(activeTaskCreator.uncreatedTasksForTopologies(anyObject())).andStubReturn(emptyMap()); + activeTaskCreator.removeRevokedUnknownTasks(singleton(task1)); + + final StandbyTaskCreator standbyTaskCreator = mock(StandbyTaskCreator.class); + expect(standbyTaskCreator.uncreatedTasksForTopologies(anyObject())).andStubReturn(emptyMap()); + standbyTaskCreator.removeRevokedUnknownTasks(emptySet()); + + EasyMock.replay(consumer, consumerGroupMetadata, task, activeTaskCreator, standbyTaskCreator); + + final StreamsMetricsImpl streamsMetrics = + new StreamsMetricsImpl(metrics, CLIENT_ID, StreamsConfig.METRICS_LATEST, mockTime); + + final TaskManager taskManager = new TaskManager( + null, + null, + null, + null, + null, + activeTaskCreator, + standbyTaskCreator, + new TopologyMetadata(internalTopologyBuilder, config), + null, + null, + null + ) { + @Override + int commit(final Collection tasksToCommit) { + mockTime.sleep(10L); + return 1; + } + }; + taskManager.setMainConsumer(consumer); + + final StreamThread thread = new StreamThread( + mockTime, + config, + null, + consumer, + consumer, + changelogReader, + null, + taskManager, + streamsMetrics, + new TopologyMetadata(internalTopologyBuilder, config), + CLIENT_ID, + new LogContext(""), + new AtomicInteger(), + new AtomicLong(Long.MAX_VALUE), + null, + HANDLER, + null + ); + thread.updateThreadMetadata("adminClientId"); + thread.setState(StreamThread.State.STARTING); + + final Map> activeTasks = new HashMap<>(); + activeTasks.put(task1, Collections.singleton(t1p1)); + thread.taskManager().handleAssignment(activeTasks, emptyMap()); + thread.rebalanceListener().onPartitionsAssigned(Collections.singleton(t1p1)); + + assertTrue( + Double.isNaN( + (Double) streamsMetrics.metrics().get(new MetricName( + "commit-latency-max", + "stream-thread-metrics", + "", + Collections.singletonMap("thread-id", CLIENT_ID)) + ).metricValue() + ) + ); + assertTrue( + Double.isNaN( + (Double) streamsMetrics.metrics().get(new MetricName( + "commit-latency-avg", + "stream-thread-metrics", + "", + Collections.singletonMap("thread-id", CLIENT_ID)) + ).metricValue() + ) + ); + + thread.runOnce(); + + assertThat( + streamsMetrics.metrics().get( + new MetricName( + "commit-latency-max", + "stream-thread-metrics", + "", + Collections.singletonMap("thread-id", CLIENT_ID) + ) + 
).metricValue(), + equalTo(10.0) + ); + assertThat( + streamsMetrics.metrics().get( + new MetricName( + "commit-latency-avg", + "stream-thread-metrics", + "", + Collections.singletonMap("thread-id", CLIENT_ID) + ) + ).metricValue(), + equalTo(10.0) + ); + } + + @Test + public void shouldInjectSharedProducerForAllTasksUsingClientSupplierOnCreateIfEosDisabled() { + internalTopologyBuilder.addSource(null, "source1", null, null, null, topic1); + internalStreamsBuilder.buildAndOptimizeTopology(); + + final StreamThread thread = createStreamThread(CLIENT_ID, config, false); + + thread.setState(StreamThread.State.STARTING); + thread.rebalanceListener().onPartitionsRevoked(Collections.emptyList()); + + final Map> activeTasks = new HashMap<>(); + final List assignedPartitions = new ArrayList<>(); + + // assign single partition + assignedPartitions.add(t1p1); + assignedPartitions.add(t1p2); + activeTasks.put(task1, Collections.singleton(t1p1)); + activeTasks.put(task2, Collections.singleton(t1p2)); + + thread.taskManager().handleAssignment(activeTasks, emptyMap()); + + final MockConsumer mockConsumer = (MockConsumer) thread.mainConsumer(); + mockConsumer.assign(assignedPartitions); + final Map beginOffsets = new HashMap<>(); + beginOffsets.put(t1p1, 0L); + beginOffsets.put(t1p2, 0L); + mockConsumer.updateBeginningOffsets(beginOffsets); + thread.rebalanceListener().onPartitionsAssigned(new HashSet<>(assignedPartitions)); + + assertEquals(1, clientSupplier.producers.size()); + final Producer globalProducer = clientSupplier.producers.get(0); + for (final Task task : thread.activeTasks()) { + assertSame(globalProducer, ((RecordCollectorImpl) ((StreamTask) task).recordCollector()).producer()); + } + assertSame(clientSupplier.consumer, thread.mainConsumer()); + assertSame(clientSupplier.restoreConsumer, thread.restoreConsumer()); + } + + @SuppressWarnings("deprecation") + @Test + public void shouldInjectProducerPerTaskUsingClientSupplierOnCreateIfEosAlphaEnabled() { + internalTopologyBuilder.addSource(null, "source1", null, null, null, topic1); + + final Properties props = configProps(true); + props.put(StreamsConfig.PROCESSING_GUARANTEE_CONFIG, StreamsConfig.EXACTLY_ONCE); + final StreamThread thread = createStreamThread(CLIENT_ID, new StreamsConfig(props), true); + + thread.setState(StreamThread.State.STARTING); + thread.rebalanceListener().onPartitionsRevoked(Collections.emptyList()); + + final Map> activeTasks = new HashMap<>(); + final List assignedPartitions = new ArrayList<>(); + + // assign single partition + assignedPartitions.add(t1p1); + assignedPartitions.add(t1p2); + activeTasks.put(task1, Collections.singleton(t1p1)); + activeTasks.put(task2, Collections.singleton(t1p2)); + + thread.taskManager().handleAssignment(activeTasks, emptyMap()); + + final MockConsumer mockConsumer = (MockConsumer) thread.mainConsumer(); + mockConsumer.assign(assignedPartitions); + final Map beginOffsets = new HashMap<>(); + beginOffsets.put(t1p1, 0L); + beginOffsets.put(t1p2, 0L); + mockConsumer.updateBeginningOffsets(beginOffsets); + thread.rebalanceListener().onPartitionsAssigned(new HashSet<>(assignedPartitions)); + + thread.runOnce(); + + assertEquals(thread.activeTasks().size(), clientSupplier.producers.size()); + assertSame(clientSupplier.consumer, thread.mainConsumer()); + assertSame(clientSupplier.restoreConsumer, thread.restoreConsumer()); + } + + @Test + public void shouldInjectProducerPerThreadUsingClientSupplierOnCreateIfEosV2Enabled() { + internalTopologyBuilder.addSource(null, "source1", null, null, 
null, topic1); + + final Properties props = configProps(true); + final StreamThread thread = createStreamThread(CLIENT_ID, new StreamsConfig(props), true); + + thread.setState(StreamThread.State.STARTING); + thread.rebalanceListener().onPartitionsRevoked(Collections.emptyList()); + + final Map> activeTasks = new HashMap<>(); + final List assignedPartitions = new ArrayList<>(); + + // assign single partition + assignedPartitions.add(t1p1); + assignedPartitions.add(t1p2); + activeTasks.put(task1, Collections.singleton(t1p1)); + activeTasks.put(task2, Collections.singleton(t1p2)); + + thread.taskManager().handleAssignment(activeTasks, emptyMap()); + + final MockConsumer mockConsumer = (MockConsumer) thread.mainConsumer(); + mockConsumer.assign(assignedPartitions); + final Map beginOffsets = new HashMap<>(); + beginOffsets.put(t1p1, 0L); + beginOffsets.put(t1p2, 0L); + mockConsumer.updateBeginningOffsets(beginOffsets); + thread.rebalanceListener().onPartitionsAssigned(new HashSet<>(assignedPartitions)); + + thread.runOnce(); + + assertThat(clientSupplier.producers.size(), is(1)); + assertSame(clientSupplier.consumer, thread.mainConsumer()); + assertSame(clientSupplier.restoreConsumer, thread.restoreConsumer()); + } + + @Test + public void shouldOnlyCompleteShutdownAfterRebalanceNotInProgress() throws InterruptedException { + internalTopologyBuilder.addSource(null, "source1", null, null, null, topic1); + + final StreamThread thread = createStreamThread(CLIENT_ID, new StreamsConfig(configProps(true)), true); + + thread.start(); + TestUtils.waitForCondition( + () -> thread.state() == StreamThread.State.STARTING, + 10 * 1000, + "Thread never started."); + + thread.rebalanceListener().onPartitionsRevoked(Collections.emptyList()); + thread.taskManager().handleRebalanceStart(Collections.singleton(topic1)); + + final Map> activeTasks = new HashMap<>(); + final List assignedPartitions = new ArrayList<>(); + + // assign single partition + assignedPartitions.add(t1p1); + assignedPartitions.add(t1p2); + activeTasks.put(task1, Collections.singleton(t1p1)); + activeTasks.put(task2, Collections.singleton(t1p2)); + + thread.taskManager().handleAssignment(activeTasks, emptyMap()); + + thread.shutdown(); + + // even if thread is no longer running, it should still be polling + // as long as the rebalance is still ongoing + assertFalse(thread.isRunning()); + + Thread.sleep(1000); + assertEquals(Utils.mkSet(task1, task2), thread.taskManager().activeTaskIds()); + assertEquals(StreamThread.State.PENDING_SHUTDOWN, thread.state()); + + thread.rebalanceListener().onPartitionsAssigned(assignedPartitions); + + TestUtils.waitForCondition( + () -> thread.state() == StreamThread.State.DEAD, + 10 * 1000, + "Thread never shut down."); + assertEquals(Collections.emptySet(), thread.taskManager().activeTaskIds()); + } + + @Test + public void shouldCloseAllTaskProducersOnCloseIfEosEnabled() throws InterruptedException { + internalTopologyBuilder.addSource(null, "source1", null, null, null, topic1); + + final StreamThread thread = createStreamThread(CLIENT_ID, new StreamsConfig(configProps(true)), true); + + thread.start(); + TestUtils.waitForCondition( + () -> thread.state() == StreamThread.State.STARTING, + 10 * 1000, + "Thread never started."); + + thread.rebalanceListener().onPartitionsRevoked(Collections.emptyList()); + + final Map> activeTasks = new HashMap<>(); + final List assignedPartitions = new ArrayList<>(); + + // assign single partition + assignedPartitions.add(t1p1); + assignedPartitions.add(t1p2); + 
activeTasks.put(task1, Collections.singleton(t1p1)); + activeTasks.put(task2, Collections.singleton(t1p2)); + + thread.taskManager().handleAssignment(activeTasks, emptyMap()); + thread.rebalanceListener().onPartitionsAssigned(assignedPartitions); + + thread.shutdown(); + TestUtils.waitForCondition( + () -> thread.state() == StreamThread.State.DEAD, + 10 * 1000, + "Thread never shut down."); + + for (final Task task : thread.activeTasks()) { + assertTrue(((MockProducer) ((RecordCollectorImpl) ((StreamTask) task).recordCollector()).producer()).closed()); + } + } + + @Test + public void shouldShutdownTaskManagerOnClose() { + final Consumer consumer = EasyMock.createNiceMock(Consumer.class); + final ConsumerGroupMetadata consumerGroupMetadata = mock(ConsumerGroupMetadata.class); + expect(consumer.groupMetadata()).andStubReturn(consumerGroupMetadata); + expect(consumerGroupMetadata.groupInstanceId()).andReturn(Optional.empty()); + EasyMock.replay(consumerGroupMetadata); + final TaskManager taskManager = EasyMock.createNiceMock(TaskManager.class); + expect(taskManager.producerClientIds()).andStubReturn(Collections.emptySet()); + taskManager.shutdown(true); + EasyMock.expectLastCall(); + EasyMock.replay(taskManager, consumer); + + final StreamsMetricsImpl streamsMetrics = + new StreamsMetricsImpl(metrics, CLIENT_ID, StreamsConfig.METRICS_LATEST, mockTime); + final TopologyMetadata topologyMetadata = new TopologyMetadata(internalTopologyBuilder, config); + topologyMetadata.buildAndRewriteTopology(); + final StreamThread thread = new StreamThread( + mockTime, + config, + null, + consumer, + consumer, + null, + null, + taskManager, + streamsMetrics, + topologyMetadata, + CLIENT_ID, + new LogContext(""), + new AtomicInteger(), + new AtomicLong(Long.MAX_VALUE), + null, + HANDLER, + null + ).updateThreadMetadata(getSharedAdminClientId(CLIENT_ID)); + thread.setStateListener( + (t, newState, oldState) -> { + if (oldState == StreamThread.State.CREATED && newState == StreamThread.State.STARTING) { + thread.shutdown(); + } + }); + thread.run(); + verify(taskManager); + } + + @Test + public void shouldNotReturnDataAfterTaskMigrated() { + final TaskManager taskManager = EasyMock.createNiceMock(TaskManager.class); + expect(taskManager.producerClientIds()).andStubReturn(Collections.emptySet()); + final InternalTopologyBuilder internalTopologyBuilder = EasyMock.createNiceMock(InternalTopologyBuilder.class); + + expect(internalTopologyBuilder.sourceTopicCollection()).andReturn(Collections.singletonList(topic1)).times(2); + + final MockConsumer consumer = new MockConsumer<>(OffsetResetStrategy.LATEST); + + consumer.subscribe(Collections.singletonList(topic1), new MockRebalanceListener()); + consumer.rebalance(Collections.singletonList(t1p1)); + consumer.updateEndOffsets(Collections.singletonMap(t1p1, 10L)); + consumer.seekToEnd(Collections.singletonList(t1p1)); + + final ChangelogReader changelogReader = new MockChangelogReader() { + @Override + public void restore(final Map tasks) { + consumer.addRecord(new ConsumerRecord<>(topic1, 1, 11, new byte[0], new byte[0])); + consumer.addRecord(new ConsumerRecord<>(topic1, 1, 12, new byte[1], new byte[0])); + + throw new TaskMigratedException( + "Changelog restore found task migrated", new RuntimeException("restore task migrated")); + } + }; + + taskManager.handleLostAll(); + + EasyMock.replay(taskManager, internalTopologyBuilder); + + final StreamsMetricsImpl streamsMetrics = + new StreamsMetricsImpl(metrics, CLIENT_ID, StreamsConfig.METRICS_LATEST, mockTime); + + final 
StreamThread thread = new StreamThread( + mockTime, + config, + null, + consumer, + consumer, + changelogReader, + null, + taskManager, + streamsMetrics, + new TopologyMetadata(internalTopologyBuilder, config), + CLIENT_ID, + new LogContext(""), + new AtomicInteger(), + new AtomicLong(Long.MAX_VALUE), + null, + HANDLER, + null + ).updateThreadMetadata(getSharedAdminClientId(CLIENT_ID)); + + final StreamsException thrown = assertThrows(StreamsException.class, thread::run); + + verify(taskManager); + + assertThat(thrown.getCause(), isA(IllegalStateException.class)); + // The Mock consumer shall throw as the assignment has been wiped out, but records are assigned. + assertEquals("No current assignment for partition topic1-1", thrown.getCause().getMessage()); + assertFalse(consumer.shouldRebalance()); + } + + @Test + public void shouldShutdownTaskManagerOnCloseWithoutStart() { + final Consumer consumer = EasyMock.createNiceMock(Consumer.class); + final ConsumerGroupMetadata consumerGroupMetadata = mock(ConsumerGroupMetadata.class); + expect(consumer.groupMetadata()).andStubReturn(consumerGroupMetadata); + expect(consumerGroupMetadata.groupInstanceId()).andReturn(Optional.empty()); + EasyMock.replay(consumerGroupMetadata); + final TaskManager taskManager = EasyMock.createNiceMock(TaskManager.class); + expect(taskManager.producerClientIds()).andStubReturn(Collections.emptySet()); + taskManager.shutdown(true); + EasyMock.expectLastCall(); + EasyMock.replay(taskManager, consumer); + + final StreamsMetricsImpl streamsMetrics = + new StreamsMetricsImpl(metrics, CLIENT_ID, StreamsConfig.METRICS_LATEST, mockTime); + final StreamThread thread = new StreamThread( + mockTime, + config, + null, + consumer, + consumer, + null, + null, + taskManager, + streamsMetrics, + new TopologyMetadata(internalTopologyBuilder, config), + CLIENT_ID, + new LogContext(""), + new AtomicInteger(), + new AtomicLong(Long.MAX_VALUE), + null, + HANDLER, + null + ).updateThreadMetadata(getSharedAdminClientId(CLIENT_ID)); + thread.shutdown(); + verify(taskManager); + } + + @Test + public void shouldOnlyShutdownOnce() { + final Consumer consumer = EasyMock.createNiceMock(Consumer.class); + final ConsumerGroupMetadata consumerGroupMetadata = mock(ConsumerGroupMetadata.class); + expect(consumer.groupMetadata()).andStubReturn(consumerGroupMetadata); + expect(consumerGroupMetadata.groupInstanceId()).andReturn(Optional.empty()); + EasyMock.replay(consumerGroupMetadata); + final TaskManager taskManager = EasyMock.createNiceMock(TaskManager.class); + expect(taskManager.producerClientIds()).andStubReturn(Collections.emptySet()); + taskManager.shutdown(true); + EasyMock.expectLastCall(); + EasyMock.replay(taskManager, consumer); + + final StreamsMetricsImpl streamsMetrics = + new StreamsMetricsImpl(metrics, CLIENT_ID, StreamsConfig.METRICS_LATEST, mockTime); + final StreamThread thread = new StreamThread( + mockTime, + config, + null, + consumer, + consumer, + null, + null, + taskManager, + streamsMetrics, + new TopologyMetadata(internalTopologyBuilder, config), + CLIENT_ID, + new LogContext(""), + new AtomicInteger(), + new AtomicLong(Long.MAX_VALUE), + null, + HANDLER, + null + ).updateThreadMetadata(getSharedAdminClientId(CLIENT_ID)); + thread.shutdown(); + // Execute the run method. 
Verification of the mock will check that shutdown was only done once + thread.run(); + verify(taskManager); + } + + @Test + public void shouldNotThrowWhenStandbyTasksAssignedAndNoStateStoresForTopology() { + internalTopologyBuilder.addSource(null, "name", null, null, null, "topic"); + internalTopologyBuilder.addSink("out", "output", null, null, null, "name"); + + final StreamThread thread = createStreamThread(CLIENT_ID, config, false); + + thread.setState(StreamThread.State.STARTING); + thread.rebalanceListener().onPartitionsRevoked(Collections.emptyList()); + + final Map> standbyTasks = new HashMap<>(); + + // assign single partition + standbyTasks.put(task1, Collections.singleton(t1p1)); + + thread.taskManager().handleAssignment(emptyMap(), standbyTasks); + + thread.rebalanceListener().onPartitionsAssigned(Collections.emptyList()); + } + + @Test + public void shouldNotCloseTaskAndRemoveFromTaskManagerIfProducerWasFencedWhileProcessing() throws Exception { + internalTopologyBuilder.addSource(null, "source", null, null, null, topic1); + internalTopologyBuilder.addSink("sink", "dummyTopic", null, null, null, "source"); + + final StreamThread thread = createStreamThread(CLIENT_ID, new StreamsConfig(configProps(true)), true); + + final MockConsumer consumer = clientSupplier.consumer; + + consumer.updatePartitions(topic1, Collections.singletonList(new PartitionInfo(topic1, 1, null, null, null))); + + thread.setState(StreamThread.State.STARTING); + thread.rebalanceListener().onPartitionsRevoked(Collections.emptySet()); + + final Map> activeTasks = new HashMap<>(); + final List assignedPartitions = new ArrayList<>(); + + // assign single partition + assignedPartitions.add(t1p1); + activeTasks.put(task1, Collections.singleton(t1p1)); + + thread.taskManager().handleAssignment(activeTasks, emptyMap()); + + final MockConsumer mockConsumer = (MockConsumer) thread.mainConsumer(); + mockConsumer.assign(assignedPartitions); + mockConsumer.updateBeginningOffsets(Collections.singletonMap(t1p1, 0L)); + thread.rebalanceListener().onPartitionsAssigned(assignedPartitions); + + thread.runOnce(); + assertThat(thread.activeTasks().size(), equalTo(1)); + final MockProducer producer = clientSupplier.producers.get(0); + + // change consumer subscription from "pattern" to "manual" to be able to call .addRecords() + consumer.updateBeginningOffsets(Collections.singletonMap(assignedPartitions.iterator().next(), 0L)); + consumer.unsubscribe(); + consumer.assign(new HashSet<>(assignedPartitions)); + + consumer.addRecord(new ConsumerRecord<>(topic1, 1, 0, new byte[0], new byte[0])); + mockTime.sleep(config.getLong(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG) + 1L); + thread.runOnce(); + assertThat(producer.history().size(), equalTo(1)); + + mockTime.sleep(config.getLong(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG) + 1L); + TestUtils.waitForCondition( + () -> producer.commitCount() == 1, + "StreamsThread did not commit transaction."); + + producer.fenceProducer(); + mockTime.sleep(config.getLong(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG) + 1L); + consumer.addRecord(new ConsumerRecord<>(topic1, 1, 1, new byte[0], new byte[0])); + try { + thread.runOnce(); + fail("Should have thrown TaskMigratedException"); + } catch (final KafkaException expected) { + assertTrue(expected instanceof TaskMigratedException); + assertTrue("StreamsThread removed the fenced zombie task already, should wait for rebalance to close all zombies together.", + thread.activeTasks().stream().anyMatch(task -> task.id().equals(task1))); + } + + 
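+        // Illustrative sketch (not part of the original test), assuming only the public
+        // MockProducer API as I recall it: fenceProducer() makes later transactional calls
+        // fail with ProducerFencedException, which the stream thread surfaces as a
+        // TaskMigratedException, as asserted above.
+        //
+        //   final MockProducer<byte[], byte[]> fenced = new MockProducer<>();
+        //   fenced.initTransactions();
+        //   fenced.beginTransaction();
+        //   fenced.fenceProducer();
+        //   try {
+        //       fenced.commitTransaction();   // expected to throw ProducerFencedException
+        //   } catch (final ProducerFencedException swallowed) {
+        //       // Streams treats this as "task migrated" and defers cleanup to the rebalance
+        //   }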
assertThat(producer.commitCount(), equalTo(1L)); + } + + @Test + public void shouldNotCloseTaskAndRemoveFromTaskManagerIfProducerGotFencedInCommitTransactionWhenSuspendingTasks() { + final StreamThread thread = createStreamThread(CLIENT_ID, new StreamsConfig(configProps(true)), true); + + internalTopologyBuilder.addSource(null, "name", null, null, null, topic1); + internalTopologyBuilder.addSink("out", "output", null, null, null, "name"); + + thread.setState(StreamThread.State.STARTING); + thread.rebalanceListener().onPartitionsRevoked(Collections.emptySet()); + + final Map> activeTasks = new HashMap<>(); + final List assignedPartitions = new ArrayList<>(); + + // assign single partition + assignedPartitions.add(t1p1); + activeTasks.put(task1, Collections.singleton(t1p1)); + + thread.taskManager().handleAssignment(activeTasks, emptyMap()); + + final MockConsumer mockConsumer = (MockConsumer) thread.mainConsumer(); + mockConsumer.assign(assignedPartitions); + mockConsumer.updateBeginningOffsets(Collections.singletonMap(t1p1, 0L)); + thread.rebalanceListener().onPartitionsAssigned(assignedPartitions); + + thread.runOnce(); + + assertThat(thread.activeTasks().size(), equalTo(1)); + + // need to process a record to enable committing + addRecord(mockConsumer, 0L); + thread.runOnce(); + + clientSupplier.producers.get(0).commitTransactionException = new ProducerFencedException("Producer is fenced"); + assertThrows(TaskMigratedException.class, () -> thread.rebalanceListener().onPartitionsRevoked(assignedPartitions)); + assertFalse(clientSupplier.producers.get(0).transactionCommitted()); + assertFalse(clientSupplier.producers.get(0).closed()); + assertEquals(1, thread.activeTasks().size()); + } + + @Test + public void shouldReinitializeRevivedTasksInAnyState() { + final StreamThread thread = createStreamThread(CLIENT_ID, new StreamsConfig(configProps(false)), false); + + final String storeName = "store"; + final String storeChangelog = "stream-thread-test-store-changelog"; + final TopicPartition storeChangelogTopicPartition = new TopicPartition(storeChangelog, 1); + + internalTopologyBuilder.addSource(null, "name", null, null, null, topic1); + final AtomicBoolean shouldThrow = new AtomicBoolean(false); + final AtomicBoolean processed = new AtomicBoolean(false); + internalTopologyBuilder.addProcessor( + "proc", + () -> record -> { + if (shouldThrow.get()) { + throw new TaskCorruptedException(singleton(task1)); + } else { + processed.set(true); + } + }, + "name" + ); + internalTopologyBuilder.addStateStore( + Stores.keyValueStoreBuilder( + Stores.persistentKeyValueStore(storeName), + Serdes.String(), + Serdes.String() + ), + "proc" + ); + + thread.setState(StreamThread.State.STARTING); + thread.rebalanceListener().onPartitionsRevoked(Collections.emptySet()); + + final Map> activeTasks = new HashMap<>(); + final List assignedPartitions = new ArrayList<>(); + + // assign single partition + assignedPartitions.add(t1p1); + activeTasks.put(task1, Collections.singleton(t1p1)); + + thread.taskManager().handleAssignment(activeTasks, emptyMap()); + + final MockConsumer mockConsumer = (MockConsumer) thread.mainConsumer(); + mockConsumer.assign(assignedPartitions); + mockConsumer.updateBeginningOffsets(mkMap( + mkEntry(t1p1, 0L) + )); + + final MockConsumer restoreConsumer = (MockConsumer) thread.restoreConsumer(); + restoreConsumer.updateBeginningOffsets(mkMap( + mkEntry(storeChangelogTopicPartition, 0L) + )); + final MockAdminClient admin = (MockAdminClient) thread.adminClient(); + 
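+        // Restoration progress is measured against the changelog end offsets reported through
+        // the admin client; setting the end offset to 0 below leaves nothing to replay, which
+        // is why the first runOnce() further down can complete the restoration immediately.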
admin.updateEndOffsets(singletonMap(storeChangelogTopicPartition, 0L)); + + thread.rebalanceListener().onPartitionsAssigned(assignedPartitions); + + + // the first iteration completes the restoration + thread.runOnce(); + assertThat(thread.activeTasks().size(), equalTo(1)); + + // the second transits to running and unpause the input + thread.runOnce(); + + // the third actually polls, processes the record, and throws the corruption exception + addRecord(mockConsumer, 0L); + shouldThrow.set(true); + final TaskCorruptedException taskCorruptedException = assertThrows(TaskCorruptedException.class, thread::runOnce); + + // Now, we can handle the corruption + thread.taskManager().handleCorruption(taskCorruptedException.corruptedTasks()); + + // again, complete the restoration + thread.runOnce(); + // transit to running and unpause + thread.runOnce(); + // process the record + addRecord(mockConsumer, 0L); + shouldThrow.set(false); + assertThat(processed.get(), is(false)); + thread.runOnce(); + assertThat(processed.get(), is(true)); + thread.taskManager().shutdown(true); + } + + @Test + public void shouldNotCloseTaskAndRemoveFromTaskManagerIfProducerGotFencedInCommitTransactionWhenCommitting() { + // only have source but no sink so that we would not get fenced in producer.send + internalTopologyBuilder.addSource(null, "source", null, null, null, topic1); + + final StreamThread thread = createStreamThread(CLIENT_ID, new StreamsConfig(configProps(true)), true); + + final MockConsumer consumer = clientSupplier.consumer; + + consumer.updatePartitions(topic1, Collections.singletonList(new PartitionInfo(topic1, 1, null, null, null))); + + thread.setState(StreamThread.State.STARTING); + thread.rebalanceListener().onPartitionsRevoked(Collections.emptySet()); + + final Map> activeTasks = new HashMap<>(); + final List assignedPartitions = new ArrayList<>(); + + // assign single partition + assignedPartitions.add(t1p1); + activeTasks.put(task1, Collections.singleton(t1p1)); + + thread.taskManager().handleAssignment(activeTasks, emptyMap()); + + final MockConsumer mockConsumer = (MockConsumer) thread.mainConsumer(); + mockConsumer.assign(assignedPartitions); + mockConsumer.updateBeginningOffsets(Collections.singletonMap(t1p1, 0L)); + thread.rebalanceListener().onPartitionsAssigned(assignedPartitions); + + thread.runOnce(); + assertThat(thread.activeTasks().size(), equalTo(1)); + final MockProducer producer = clientSupplier.producers.get(0); + + producer.commitTransactionException = new ProducerFencedException("Producer is fenced"); + mockTime.sleep(config.getLong(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG) + 1L); + consumer.addRecord(new ConsumerRecord<>(topic1, 1, 1, new byte[0], new byte[0])); + try { + thread.runOnce(); + fail("Should have thrown TaskMigratedException"); + } catch (final KafkaException expected) { + assertTrue(expected instanceof TaskMigratedException); + assertTrue("StreamsThread removed the fenced zombie task already, should wait for rebalance to close all zombies together.", + thread.activeTasks().stream().anyMatch(task -> task.id().equals(task1))); + } + + assertThat(producer.commitCount(), equalTo(0L)); + + assertTrue(clientSupplier.producers.get(0).transactionInFlight()); + assertFalse(clientSupplier.producers.get(0).transactionCommitted()); + assertFalse(clientSupplier.producers.get(0).closed()); + assertEquals(1, thread.activeTasks().size()); + } + + @Test + public void shouldNotCloseTaskProducerWhenSuspending() { + final StreamThread thread = createStreamThread(CLIENT_ID, new 
StreamsConfig(configProps(true)), true); + + internalTopologyBuilder.addSource(null, "name", null, null, null, topic1); + internalTopologyBuilder.addSink("out", "output", null, null, null, "name"); + + thread.setState(StreamThread.State.STARTING); + thread.rebalanceListener().onPartitionsRevoked(Collections.emptySet()); + + final Map> activeTasks = new HashMap<>(); + final List assignedPartitions = new ArrayList<>(); + + // assign single partition + assignedPartitions.add(t1p1); + activeTasks.put(task1, Collections.singleton(t1p1)); + + thread.taskManager().handleAssignment(activeTasks, emptyMap()); + + final MockConsumer mockConsumer = (MockConsumer) thread.mainConsumer(); + mockConsumer.assign(assignedPartitions); + mockConsumer.updateBeginningOffsets(Collections.singletonMap(t1p1, 0L)); + thread.rebalanceListener().onPartitionsAssigned(assignedPartitions); + + thread.runOnce(); + + assertThat(thread.activeTasks().size(), equalTo(1)); + + // need to process a record to enable committing + addRecord(mockConsumer, 0L); + thread.runOnce(); + + thread.rebalanceListener().onPartitionsRevoked(assignedPartitions); + assertTrue(clientSupplier.producers.get(0).transactionCommitted()); + assertFalse(clientSupplier.producers.get(0).closed()); + assertEquals(1, thread.activeTasks().size()); + } + + @Test + public void shouldReturnActiveTaskMetadataWhileRunningState() { + internalTopologyBuilder.addSource(null, "source", null, null, null, topic1); + + clientSupplier.setCluster(createCluster()); + + final StreamsMetricsImpl streamsMetrics = new StreamsMetricsImpl( + metrics, + APPLICATION_ID, + config.getString(StreamsConfig.BUILT_IN_METRICS_VERSION_CONFIG), + mockTime + ); + + internalTopologyBuilder.buildTopology(); + + final StreamThread thread = StreamThread.create( + new TopologyMetadata(internalTopologyBuilder, config), + config, + clientSupplier, + clientSupplier.getAdmin(config.getAdminConfigs(CLIENT_ID)), + PROCESS_ID, + CLIENT_ID, + streamsMetrics, + mockTime, + streamsMetadataState, + 0, + stateDirectory, + new MockStateRestoreListener(), + threadIdx, + null, + HANDLER + ); + + thread.setState(StreamThread.State.STARTING); + thread.rebalanceListener().onPartitionsRevoked(Collections.emptySet()); + + final Map> activeTasks = new HashMap<>(); + final List assignedPartitions = new ArrayList<>(); + + // assign single partition + assignedPartitions.add(t1p1); + activeTasks.put(task1, Collections.singleton(t1p1)); + + thread.taskManager().handleAssignment(activeTasks, emptyMap()); + + final MockConsumer mockConsumer = (MockConsumer) thread.mainConsumer(); + mockConsumer.assign(assignedPartitions); + mockConsumer.updateBeginningOffsets(Collections.singletonMap(t1p1, 0L)); + thread.rebalanceListener().onPartitionsAssigned(assignedPartitions); + + thread.runOnce(); + + final ThreadMetadata metadata = thread.threadMetadata(); + assertEquals(StreamThread.State.RUNNING.name(), metadata.threadState()); + assertTrue(metadata.activeTasks().contains(new TaskMetadataImpl(task1, Utils.mkSet(t1p1), new HashMap<>(), new HashMap<>(), Optional.empty()))); + assertTrue(metadata.standbyTasks().isEmpty()); + + assertTrue("#threadState() was: " + metadata.threadState() + "; expected either RUNNING, STARTING, PARTITIONS_REVOKED, PARTITIONS_ASSIGNED, or CREATED", + Arrays.asList("RUNNING", "STARTING", "PARTITIONS_REVOKED", "PARTITIONS_ASSIGNED", "CREATED").contains(metadata.threadState())); + final String threadName = metadata.threadName(); + assertThat(threadName, startsWith(CLIENT_ID + "-StreamThread-" + 
threadIdx)); + assertEquals(threadName + "-consumer", metadata.consumerClientId()); + assertEquals(threadName + "-restore-consumer", metadata.restoreConsumerClientId()); + assertEquals(Collections.singleton(threadName + "-producer"), metadata.producerClientIds()); + assertEquals(CLIENT_ID + "-admin", metadata.adminClientId()); + } + + @Test + public void shouldReturnStandbyTaskMetadataWhileRunningState() { + internalStreamsBuilder.stream(Collections.singleton(topic1), consumed) + .groupByKey().count(Materialized.as("count-one")); + + internalStreamsBuilder.buildAndOptimizeTopology(); + final StreamThread thread = createStreamThread(CLIENT_ID, config, false); + final MockConsumer restoreConsumer = clientSupplier.restoreConsumer; + restoreConsumer.updatePartitions( + "stream-thread-test-count-one-changelog", + Collections.singletonList( + new PartitionInfo("stream-thread-test-count-one-changelog", + 0, + null, + new Node[0], + new Node[0]) + ) + ); + + final HashMap offsets = new HashMap<>(); + offsets.put(new TopicPartition("stream-thread-test-count-one-changelog", 1), 0L); + restoreConsumer.updateEndOffsets(offsets); + restoreConsumer.updateBeginningOffsets(offsets); + + thread.setState(StreamThread.State.STARTING); + thread.rebalanceListener().onPartitionsRevoked(Collections.emptySet()); + + final Map> standbyTasks = new HashMap<>(); + + // assign single partition + standbyTasks.put(task1, Collections.singleton(t1p1)); + + thread.taskManager().handleAssignment(emptyMap(), standbyTasks); + + thread.rebalanceListener().onPartitionsAssigned(Collections.emptyList()); + + thread.runOnce(); + + final ThreadMetadata threadMetadata = thread.threadMetadata(); + assertEquals(StreamThread.State.RUNNING.name(), threadMetadata.threadState()); + assertTrue(threadMetadata.standbyTasks().contains(new TaskMetadataImpl(task1, Utils.mkSet(t1p1), new HashMap<>(), new HashMap<>(), Optional.empty()))); + assertTrue(threadMetadata.activeTasks().isEmpty()); + + thread.taskManager().shutdown(true); + } + + @SuppressWarnings("unchecked") + @Test + public void shouldUpdateStandbyTask() throws Exception { + final String storeName1 = "count-one"; + final String storeName2 = "table-two"; + final String changelogName1 = APPLICATION_ID + "-" + storeName1 + "-changelog"; + final String changelogName2 = APPLICATION_ID + "-" + storeName2 + "-changelog"; + final TopicPartition partition1 = new TopicPartition(changelogName1, 1); + final TopicPartition partition2 = new TopicPartition(changelogName2, 1); + internalStreamsBuilder + .stream(Collections.singleton(topic1), consumed) + .groupByKey() + .count(Materialized.as(storeName1)); + final MaterializedInternal> materialized + = new MaterializedInternal<>(Materialized.as(storeName2), internalStreamsBuilder, ""); + internalStreamsBuilder.table(topic2, new ConsumedInternal<>(), materialized); + + internalStreamsBuilder.buildAndOptimizeTopology(); + final StreamThread thread = createStreamThread(CLIENT_ID, config, false); + final MockConsumer restoreConsumer = clientSupplier.restoreConsumer; + restoreConsumer.updatePartitions(changelogName1, + Collections.singletonList(new PartitionInfo(changelogName1, 1, null, new Node[0], new Node[0])) + ); + + restoreConsumer.updateEndOffsets(Collections.singletonMap(partition1, 10L)); + restoreConsumer.updateBeginningOffsets(Collections.singletonMap(partition1, 0L)); + restoreConsumer.updateEndOffsets(Collections.singletonMap(partition2, 10L)); + restoreConsumer.updateBeginningOffsets(Collections.singletonMap(partition2, 0L)); + final 
OffsetCheckpoint checkpoint + = new OffsetCheckpoint(new File(stateDirectory.getOrCreateDirectoryForTask(task3), CHECKPOINT_FILE_NAME)); + checkpoint.write(Collections.singletonMap(partition2, 5L)); + + thread.setState(StreamThread.State.STARTING); + thread.rebalanceListener().onPartitionsRevoked(Collections.emptySet()); + + final Map> standbyTasks = new HashMap<>(); + + // assign single partition + standbyTasks.put(task1, Collections.singleton(t1p1)); + standbyTasks.put(task3, Collections.singleton(t2p1)); + + thread.taskManager().handleAssignment(emptyMap(), standbyTasks); + thread.taskManager().tryToCompleteRestoration(mockTime.milliseconds(), null); + + thread.rebalanceListener().onPartitionsAssigned(Collections.emptyList()); + + thread.runOnce(); + + final StandbyTask standbyTask1 = standbyTask(thread.taskManager(), t1p1); + final StandbyTask standbyTask2 = standbyTask(thread.taskManager(), t2p1); + assertEquals(task1, standbyTask1.id()); + assertEquals(task3, standbyTask2.id()); + + final KeyValueStore store1 = (KeyValueStore) standbyTask1.getStore(storeName1); + final KeyValueStore store2 = (KeyValueStore) standbyTask2.getStore(storeName2); + assertEquals(0L, store1.approximateNumEntries()); + assertEquals(0L, store2.approximateNumEntries()); + + // let the store1 be restored from 0 to 10; store2 be restored from 5 (checkpointed) to 10 + for (long i = 0L; i < 10L; i++) { + restoreConsumer.addRecord(new ConsumerRecord<>( + changelogName1, + 1, + i, + ("K" + i).getBytes(), + ("V" + i).getBytes())); + restoreConsumer.addRecord(new ConsumerRecord<>( + changelogName2, + 1, + i, + ("K" + i).getBytes(), + ("V" + i).getBytes())); + } + + thread.runOnce(); + + assertEquals(10L, store1.approximateNumEntries()); + assertEquals(4L, store2.approximateNumEntries()); + + thread.taskManager().shutdown(true); + } + + @Test + public void shouldCreateStandbyTask() { + setupInternalTopologyWithoutState(); + internalTopologyBuilder.addStateStore(new MockKeyValueStoreBuilder("myStore", true), "processor1"); + + assertThat(createStandbyTask(), not(empty())); + } + + @Test + public void shouldNotCreateStandbyTaskWithoutStateStores() { + setupInternalTopologyWithoutState(); + + assertThat(createStandbyTask(), empty()); + } + + @Test + public void shouldNotCreateStandbyTaskIfStateStoresHaveLoggingDisabled() { + setupInternalTopologyWithoutState(); + final StoreBuilder> storeBuilder = + new MockKeyValueStoreBuilder("myStore", true); + storeBuilder.withLoggingDisabled(); + internalTopologyBuilder.addStateStore(storeBuilder, "processor1"); + + assertThat(createStandbyTask(), empty()); + } + + @Test + @SuppressWarnings("deprecation") // Old PAPI. Needs to be migrated. 
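+    // STREAM_TIME punctuators fire only when stream time advances, i.e. when a newly polled
+    // record carries a larger timestamp, whereas WALL_CLOCK_TIME punctuators fire from the
+    // wall clock alone. The test below exercises exactly that difference: after a plain
+    // mockTime.sleep(100) with no new input record, only the wall-clock punctuation fires again.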
+ public void shouldPunctuateActiveTask() { + final List punctuatedStreamTime = new ArrayList<>(); + final List punctuatedWallClockTime = new ArrayList<>(); + final org.apache.kafka.streams.processor.ProcessorSupplier punctuateProcessor = + () -> new org.apache.kafka.streams.processor.AbstractProcessor() { + @Override + public void init(final org.apache.kafka.streams.processor.ProcessorContext context) { + context.schedule(Duration.ofMillis(100L), PunctuationType.STREAM_TIME, punctuatedStreamTime::add); + context.schedule(Duration.ofMillis(100L), PunctuationType.WALL_CLOCK_TIME, punctuatedWallClockTime::add); + } + + @Override + public void process(final Object key, final Object value) {} + }; + + internalStreamsBuilder.stream(Collections.singleton(topic1), consumed).process(punctuateProcessor); + internalStreamsBuilder.buildAndOptimizeTopology(); + + final StreamThread thread = createStreamThread(CLIENT_ID, config, false); + + thread.setState(StreamThread.State.STARTING); + thread.rebalanceListener().onPartitionsRevoked(Collections.emptySet()); + final List assignedPartitions = new ArrayList<>(); + + final Map> activeTasks = new HashMap<>(); + + // assign single partition + assignedPartitions.add(t1p1); + activeTasks.put(task1, Collections.singleton(t1p1)); + + thread.taskManager().handleAssignment(activeTasks, emptyMap()); + + clientSupplier.consumer.assign(assignedPartitions); + clientSupplier.consumer.updateBeginningOffsets(Collections.singletonMap(t1p1, 0L)); + thread.rebalanceListener().onPartitionsAssigned(assignedPartitions); + + thread.runOnce(); + + assertEquals(0, punctuatedStreamTime.size()); + assertEquals(0, punctuatedWallClockTime.size()); + + mockTime.sleep(100L); + clientSupplier.consumer.addRecord(new ConsumerRecord<>( + topic1, + 1, + 100L, + 100L, + TimestampType.CREATE_TIME, + "K".getBytes().length, + "V".getBytes().length, + "K".getBytes(), + "V".getBytes(), + new RecordHeaders(), + Optional.empty())); + + thread.runOnce(); + + assertEquals(1, punctuatedStreamTime.size()); + assertEquals(1, punctuatedWallClockTime.size()); + + mockTime.sleep(100L); + + thread.runOnce(); + + // we should skip stream time punctuation, only trigger wall-clock time punctuation + assertEquals(1, punctuatedStreamTime.size()); + assertEquals(2, punctuatedWallClockTime.size()); + } + + @Test + @SuppressWarnings("deprecation") // Old PAPI. Needs to be migrated. 
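+    // When a punctuator forwards a record, the downstream processor observes the punctuation
+    // timestamp via context.timestamp(); the test below asserts that the first forwarded
+    // record carries currTime + 100 rather than the timestamp of any input record.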
+ public void shouldPunctuateWithTimestampPreservedInProcessorContext() { + final org.apache.kafka.streams.kstream.TransformerSupplier> punctuateProcessor = + () -> new org.apache.kafka.streams.kstream.Transformer>() { + @Override + public void init(final org.apache.kafka.streams.processor.ProcessorContext context) { + context.schedule(Duration.ofMillis(100L), PunctuationType.WALL_CLOCK_TIME, timestamp -> context.forward("key", "value")); + context.schedule(Duration.ofMillis(100L), PunctuationType.STREAM_TIME, timestamp -> context.forward("key", "value")); + } + + @Override + public KeyValue transform(final Object key, final Object value) { + return null; + } + + @Override + public void close() {} + }; + + final List peekedContextTime = new ArrayList<>(); + final org.apache.kafka.streams.processor.ProcessorSupplier peekProcessor = + () -> new org.apache.kafka.streams.processor.AbstractProcessor() { + @Override + public void process(final Object key, final Object value) { + peekedContextTime.add(context.timestamp()); + } + }; + + internalStreamsBuilder.stream(Collections.singleton(topic1), consumed) + .transform(punctuateProcessor) + .process(peekProcessor); + internalStreamsBuilder.buildAndOptimizeTopology(); + + final long currTime = mockTime.milliseconds(); + final StreamThread thread = createStreamThread(CLIENT_ID, config, false); + + thread.setState(StreamThread.State.STARTING); + thread.rebalanceListener().onPartitionsRevoked(Collections.emptySet()); + final List assignedPartitions = new ArrayList<>(); + + final Map> activeTasks = new HashMap<>(); + + // assign single partition + assignedPartitions.add(t1p1); + activeTasks.put(task1, Collections.singleton(t1p1)); + + thread.taskManager().handleAssignment(activeTasks, emptyMap()); + + clientSupplier.consumer.assign(assignedPartitions); + clientSupplier.consumer.updateBeginningOffsets(Collections.singletonMap(t1p1, 0L)); + thread.rebalanceListener().onPartitionsAssigned(assignedPartitions); + + thread.runOnce(); + assertEquals(0, peekedContextTime.size()); + + mockTime.sleep(100L); + thread.runOnce(); + + assertEquals(1, peekedContextTime.size()); + assertEquals(currTime + 100L, peekedContextTime.get(0).longValue()); + + clientSupplier.consumer.addRecord(new ConsumerRecord<>( + topic1, + 1, + 110L, + 110L, + TimestampType.CREATE_TIME, + "K".getBytes().length, + "V".getBytes().length, + "K".getBytes(), + "V".getBytes(), + new RecordHeaders(), + Optional.empty())); + + thread.runOnce(); + + assertEquals(2, peekedContextTime.size()); + assertEquals(110L, peekedContextTime.get(1).longValue()); + } + + @Test + public void shouldAlwaysUpdateTasksMetadataAfterChangingState() { + final StreamThread thread = createStreamThread(CLIENT_ID, config, false); + ThreadMetadata metadata = thread.threadMetadata(); + assertEquals(StreamThread.State.CREATED.name(), metadata.threadState()); + + thread.setState(StreamThread.State.STARTING); + thread.setState(StreamThread.State.PARTITIONS_REVOKED); + thread.setState(StreamThread.State.PARTITIONS_ASSIGNED); + thread.setState(StreamThread.State.RUNNING); + metadata = thread.threadMetadata(); + assertEquals(StreamThread.State.RUNNING.name(), metadata.threadState()); + } + + @Test + public void shouldRecoverFromInvalidOffsetExceptionOnRestoreAndFinishRestore() throws Exception { + internalStreamsBuilder.stream(Collections.singleton("topic"), consumed) + .groupByKey() + .count(Materialized.as("count")); + internalStreamsBuilder.buildAndOptimizeTopology(); + + final StreamThread thread = 
createStreamThread("clientId", config, false); + final MockConsumer mockConsumer = (MockConsumer) thread.mainConsumer(); + final MockConsumer mockRestoreConsumer = (MockConsumer) thread.restoreConsumer(); + final MockAdminClient mockAdminClient = (MockAdminClient) thread.adminClient(); + + final TopicPartition topicPartition = new TopicPartition("topic", 0); + final Set topicPartitionSet = Collections.singleton(topicPartition); + + final Map> activeTasks = new HashMap<>(); + final TaskId task0 = new TaskId(0, 0); + activeTasks.put(task0, topicPartitionSet); + + thread.taskManager().handleAssignment(activeTasks, emptyMap()); + + mockConsumer.updatePartitions( + "topic", + Collections.singletonList( + new PartitionInfo( + "topic", + 0, + null, + new Node[0], + new Node[0] + ) + ) + ); + mockConsumer.updateBeginningOffsets(Collections.singletonMap(topicPartition, 0L)); + + mockRestoreConsumer.updatePartitions( + "stream-thread-test-count-changelog", + Collections.singletonList( + new PartitionInfo( + "stream-thread-test-count-changelog", + 0, + null, + new Node[0], + new Node[0] + ) + ) + ); + + final TopicPartition changelogPartition = new TopicPartition("stream-thread-test-count-changelog", 0); + final Set changelogPartitionSet = Collections.singleton(changelogPartition); + mockRestoreConsumer.updateBeginningOffsets(Collections.singletonMap(changelogPartition, 0L)); + mockAdminClient.updateEndOffsets(Collections.singletonMap(changelogPartition, 2L)); + + mockConsumer.schedulePollTask(() -> { + thread.setState(StreamThread.State.PARTITIONS_REVOKED); + thread.rebalanceListener().onPartitionsAssigned(topicPartitionSet); + }); + + try { + thread.start(); + + TestUtils.waitForCondition( + () -> mockRestoreConsumer.assignment().size() == 1, + "Never get the assignment"); + + mockRestoreConsumer.addRecord(new ConsumerRecord<>( + "stream-thread-test-count-changelog", + 0, + 0L, + "K1".getBytes(), + "V1".getBytes())); + + TestUtils.waitForCondition( + () -> mockRestoreConsumer.position(changelogPartition) == 1L, + "Never restore first record"); + + mockRestoreConsumer.setPollException(new InvalidOffsetException("Try Again!") { + @Override + public Set partitions() { + return changelogPartitionSet; + } + }); + + // after handling the exception and reviving the task, the position + // should be reset to the beginning. 
+ TestUtils.waitForCondition( + () -> mockRestoreConsumer.position(changelogPartition) == 0L, + "Never restore first record"); + + mockRestoreConsumer.addRecord(new ConsumerRecord<>( + "stream-thread-test-count-changelog", + 0, + 0L, + "K1".getBytes(), + "V1".getBytes())); + mockRestoreConsumer.addRecord(new ConsumerRecord<>( + "stream-thread-test-count-changelog", + 0, + 1L, + "K2".getBytes(), + "V2".getBytes())); + + TestUtils.waitForCondition( + () -> { + mockRestoreConsumer.assign(changelogPartitionSet); + return mockRestoreConsumer.position(changelogPartition) == 2L; + }, + "Never finished restore"); + } finally { + thread.shutdown(); + thread.join(10000); + } + } + + @Test + public void shouldLogAndRecordSkippedMetricForDeserializationException() { + internalTopologyBuilder.addSource(null, "source1", null, null, null, topic1); + + final Properties config = configProps(false); + config.setProperty( + StreamsConfig.DEFAULT_DESERIALIZATION_EXCEPTION_HANDLER_CLASS_CONFIG, + LogAndContinueExceptionHandler.class.getName() + ); + config.setProperty(StreamsConfig.DEFAULT_VALUE_SERDE_CLASS_CONFIG, Serdes.Integer().getClass().getName()); + final StreamThread thread = createStreamThread(CLIENT_ID, new StreamsConfig(config), false); + + thread.setState(StreamThread.State.STARTING); + thread.setState(StreamThread.State.PARTITIONS_REVOKED); + + final TaskId task1 = new TaskId(0, t1p1.partition()); + final Set assignedPartitions = Collections.singleton(t1p1); + thread.taskManager().handleAssignment( + Collections.singletonMap(task1, assignedPartitions), + emptyMap()); + + final MockConsumer mockConsumer = (MockConsumer) thread.mainConsumer(); + mockConsumer.assign(Collections.singleton(t1p1)); + mockConsumer.updateBeginningOffsets(Collections.singletonMap(t1p1, 0L)); + thread.rebalanceListener().onPartitionsAssigned(assignedPartitions); + thread.runOnce(); + + long offset = -1; + mockConsumer.addRecord(new ConsumerRecord<>( + t1p1.topic(), + t1p1.partition(), + ++offset, + -1, + TimestampType.CREATE_TIME, + -1, + -1, + new byte[0], + "I am not an integer.".getBytes(), + new RecordHeaders(), + Optional.empty())); + mockConsumer.addRecord(new ConsumerRecord<>( + t1p1.topic(), + t1p1.partition(), + ++offset, + -1, + TimestampType.CREATE_TIME, + -1, + -1, + new byte[0], + "I am not an integer.".getBytes(), + new RecordHeaders(), + Optional.empty())); + + try (final LogCaptureAppender appender = LogCaptureAppender.createAndRegister(RecordDeserializer.class)) { + thread.runOnce(); + + final List strings = appender.getMessages(); + assertTrue(strings.contains("stream-thread [" + Thread.currentThread().getName() + "] task [0_1]" + + " Skipping record due to deserialization error. topic=[topic1] partition=[1] offset=[0]")); + assertTrue(strings.contains("stream-thread [" + Thread.currentThread().getName() + "] task [0_1]" + + " Skipping record due to deserialization error. 
topic=[topic1] partition=[1] offset=[1]")); + } + } + + @Test + public void shouldThrowTaskMigratedExceptionHandlingTaskLost() { + final Set assignedPartitions = Collections.singleton(t1p1); + + final TaskManager taskManager = EasyMock.createNiceMock(TaskManager.class); + expect(taskManager.producerClientIds()).andStubReturn(Collections.emptySet()); + final MockConsumer consumer = new MockConsumer<>(OffsetResetStrategy.LATEST); + consumer.assign(assignedPartitions); + consumer.updateBeginningOffsets(Collections.singletonMap(t1p1, 0L)); + consumer.updateEndOffsets(Collections.singletonMap(t1p1, 10L)); + + taskManager.handleLostAll(); + EasyMock.expectLastCall() + .andThrow(new TaskMigratedException("Task lost exception", new RuntimeException())); + + EasyMock.replay(taskManager); + + final StreamsMetricsImpl streamsMetrics = + new StreamsMetricsImpl(metrics, CLIENT_ID, StreamsConfig.METRICS_LATEST, mockTime); + final StreamThread thread = new StreamThread( + mockTime, + config, + null, + consumer, + consumer, + null, + null, + taskManager, + streamsMetrics, + new TopologyMetadata(internalTopologyBuilder, config), + CLIENT_ID, + new LogContext(""), + new AtomicInteger(), + new AtomicLong(Long.MAX_VALUE), + null, + HANDLER, + null + ).updateThreadMetadata(getSharedAdminClientId(CLIENT_ID)); + + consumer.schedulePollTask(() -> { + thread.setState(StreamThread.State.PARTITIONS_REVOKED); + thread.rebalanceListener().onPartitionsLost(assignedPartitions); + }); + + thread.setState(StreamThread.State.STARTING); + assertThrows(TaskMigratedException.class, thread::runOnce); + } + + @Test + public void shouldThrowTaskMigratedExceptionHandlingRevocation() { + final Set assignedPartitions = Collections.singleton(t1p1); + + final TaskManager taskManager = EasyMock.createNiceMock(TaskManager.class); + expect(taskManager.producerClientIds()).andStubReturn(Collections.emptySet()); + final MockConsumer consumer = new MockConsumer<>(OffsetResetStrategy.LATEST); + consumer.assign(assignedPartitions); + consumer.updateBeginningOffsets(Collections.singletonMap(t1p1, 0L)); + consumer.updateEndOffsets(Collections.singletonMap(t1p1, 10L)); + + taskManager.handleRevocation(assignedPartitions); + EasyMock.expectLastCall() + .andThrow(new TaskMigratedException("Revocation non fatal exception", new RuntimeException())); + + EasyMock.replay(taskManager); + + final StreamsMetricsImpl streamsMetrics = + new StreamsMetricsImpl(metrics, CLIENT_ID, StreamsConfig.METRICS_LATEST, mockTime); + final StreamThread thread = new StreamThread( + mockTime, + config, + null, + consumer, + consumer, + null, + null, + taskManager, + streamsMetrics, + new TopologyMetadata(internalTopologyBuilder, config), + CLIENT_ID, + new LogContext(""), + new AtomicInteger(), + new AtomicLong(Long.MAX_VALUE), + null, + HANDLER, + null + ).updateThreadMetadata(getSharedAdminClientId(CLIENT_ID)); + + consumer.schedulePollTask(() -> { + thread.setState(StreamThread.State.PARTITIONS_REVOKED); + thread.rebalanceListener().onPartitionsRevoked(assignedPartitions); + }); + + thread.setState(StreamThread.State.STARTING); + assertThrows(TaskMigratedException.class, thread::runOnce); + } + + @Test + @SuppressWarnings("unchecked") + public void shouldCatchHandleCorruptionOnTaskCorruptedExceptionPath() { + final TaskManager taskManager = EasyMock.createNiceMock(TaskManager.class); + expect(taskManager.producerClientIds()).andStubReturn(Collections.emptySet()); + final Consumer consumer = mock(Consumer.class); + final ConsumerGroupMetadata consumerGroupMetadata = 
mock(ConsumerGroupMetadata.class); + consumer.subscribe((Collection) anyObject(), anyObject()); + EasyMock.expectLastCall().anyTimes(); + consumer.unsubscribe(); + EasyMock.expectLastCall().anyTimes(); + expect(consumer.groupMetadata()).andStubReturn(consumerGroupMetadata); + expect(consumerGroupMetadata.groupInstanceId()).andReturn(Optional.empty()); + EasyMock.replay(consumerGroupMetadata); + final Task task1 = mock(Task.class); + final Task task2 = mock(Task.class); + final TaskId taskId1 = new TaskId(0, 0); + final TaskId taskId2 = new TaskId(0, 2); + + final Set corruptedTasks = singleton(taskId1); + + expect(task1.state()).andReturn(Task.State.RUNNING).anyTimes(); + expect(task1.id()).andReturn(taskId1).anyTimes(); + expect(task2.state()).andReturn(Task.State.RUNNING).anyTimes(); + expect(task2.id()).andReturn(taskId2).anyTimes(); + + expect(taskManager.handleCorruption(corruptedTasks)).andReturn(true); + + EasyMock.replay(task1, task2, taskManager, consumer); + + final StreamsMetricsImpl streamsMetrics = + new StreamsMetricsImpl(metrics, CLIENT_ID, StreamsConfig.METRICS_LATEST, mockTime); + final TopologyMetadata topologyMetadata = new TopologyMetadata(internalTopologyBuilder, config); + topologyMetadata.buildAndRewriteTopology(); + final StreamThread thread = new StreamThread( + mockTime, + config, + null, + consumer, + consumer, + null, + null, + taskManager, + streamsMetrics, + topologyMetadata, + CLIENT_ID, + new LogContext(""), + new AtomicInteger(), + new AtomicLong(Long.MAX_VALUE), + null, + HANDLER, + null + ) { + @Override + void runOnce() { + setState(State.PENDING_SHUTDOWN); + throw new TaskCorruptedException(corruptedTasks); + } + }.updateThreadMetadata(getSharedAdminClientId(CLIENT_ID)); + + thread.run(); + + verify(taskManager); + } + + @Test + @SuppressWarnings("unchecked") + public void shouldCatchTimeoutExceptionFromHandleCorruptionAndInvokeExceptionHandler() { + final TaskManager taskManager = EasyMock.createNiceMock(TaskManager.class); + expect(taskManager.producerClientIds()).andStubReturn(Collections.emptySet()); + final Consumer consumer = mock(Consumer.class); + final ConsumerGroupMetadata consumerGroupMetadata = mock(ConsumerGroupMetadata.class); + expect(consumer.groupMetadata()).andStubReturn(consumerGroupMetadata); + expect(consumerGroupMetadata.groupInstanceId()).andReturn(Optional.empty()); + consumer.subscribe((Collection) anyObject(), anyObject()); + EasyMock.expectLastCall().atLeastOnce(); + consumer.unsubscribe(); + EasyMock.expectLastCall().atLeastOnce(); + EasyMock.replay(consumerGroupMetadata); + final Task task1 = mock(Task.class); + final Task task2 = mock(Task.class); + final TaskId taskId1 = new TaskId(0, 0); + final TaskId taskId2 = new TaskId(0, 2); + + final Set corruptedTasks = singleton(taskId1); + + expect(task1.state()).andStubReturn(Task.State.RUNNING); + expect(task1.id()).andStubReturn(taskId1); + expect(task2.state()).andStubReturn(Task.State.RUNNING); + expect(task2.id()).andStubReturn(taskId2); + + taskManager.handleCorruption(corruptedTasks); + expectLastCall().andThrow(new TimeoutException()); + + EasyMock.replay(task1, task2, taskManager, consumer); + + final StreamsMetricsImpl streamsMetrics = + new StreamsMetricsImpl(metrics, CLIENT_ID, StreamsConfig.METRICS_LATEST, mockTime); + final TopologyMetadata topologyMetadata = new TopologyMetadata(internalTopologyBuilder, config); + topologyMetadata.buildAndRewriteTopology(); + final StreamThread thread = new StreamThread( + mockTime, + config, + null, + consumer, + consumer, + 
null, + null, + taskManager, + streamsMetrics, + topologyMetadata, + CLIENT_ID, + new LogContext(""), + new AtomicInteger(), + new AtomicLong(Long.MAX_VALUE), + null, + HANDLER, + null + ) { + @Override + void runOnce() { + setState(State.PENDING_SHUTDOWN); + throw new TaskCorruptedException(corruptedTasks); + } + }.updateThreadMetadata(getSharedAdminClientId(CLIENT_ID)); + + final AtomicBoolean exceptionHandlerInvoked = new AtomicBoolean(false); + + thread.setStreamsUncaughtExceptionHandler(e -> exceptionHandlerInvoked.set(true)); + thread.run(); + + verify(taskManager); + assertThat(exceptionHandlerInvoked.get(), is(true)); + } + + @Test + @SuppressWarnings("unchecked") + public void shouldCatchTaskMigratedExceptionOnOnTaskCorruptedExceptionPath() { + final TaskManager taskManager = EasyMock.createNiceMock(TaskManager.class); + expect(taskManager.producerClientIds()).andStubReturn(Collections.emptySet()); + final Consumer consumer = mock(Consumer.class); + final ConsumerGroupMetadata consumerGroupMetadata = mock(ConsumerGroupMetadata.class); + expect(consumer.groupMetadata()).andStubReturn(consumerGroupMetadata); + expect(consumerGroupMetadata.groupInstanceId()).andReturn(Optional.empty()); + consumer.subscribe((Collection) anyObject(), anyObject()); + EasyMock.expectLastCall().anyTimes(); + consumer.unsubscribe(); + EasyMock.expectLastCall().anyTimes(); + EasyMock.replay(consumerGroupMetadata); + final Task task1 = mock(Task.class); + final Task task2 = mock(Task.class); + final TaskId taskId1 = new TaskId(0, 0); + final TaskId taskId2 = new TaskId(0, 2); + + final Set corruptedTasks = singleton(taskId1); + + expect(task1.state()).andReturn(Task.State.RUNNING).anyTimes(); + expect(task1.id()).andReturn(taskId1).anyTimes(); + expect(task2.state()).andReturn(Task.State.RUNNING).anyTimes(); + expect(task2.id()).andReturn(taskId2).anyTimes(); + + taskManager.handleCorruption(corruptedTasks); + expectLastCall().andThrow(new TaskMigratedException("Task migrated", + new RuntimeException("non-corrupted task migrated"))); + + taskManager.handleLostAll(); + expectLastCall(); + + EasyMock.replay(task1, task2, taskManager, consumer); + + final StreamsMetricsImpl streamsMetrics = + new StreamsMetricsImpl(metrics, CLIENT_ID, StreamsConfig.METRICS_LATEST, mockTime); + final TopologyMetadata topologyMetadata = new TopologyMetadata(internalTopologyBuilder, config); + topologyMetadata.buildAndRewriteTopology(); + final StreamThread thread = new StreamThread( + mockTime, + config, + null, + consumer, + consumer, + null, + null, + taskManager, + streamsMetrics, + topologyMetadata, + CLIENT_ID, + new LogContext(""), + new AtomicInteger(), + new AtomicLong(Long.MAX_VALUE), + null, + HANDLER, + null + ) { + @Override + void runOnce() { + setState(State.PENDING_SHUTDOWN); + throw new TaskCorruptedException(corruptedTasks); + } + }.updateThreadMetadata(getSharedAdminClientId(CLIENT_ID)); + + thread.setState(StreamThread.State.STARTING); + thread.runLoop(); + + verify(taskManager); + } + + @Test + @SuppressWarnings("unchecked") + public void shouldEnforceRebalanceWhenTaskCorruptedExceptionIsThrownForAnActiveTask() { + final TaskManager taskManager = EasyMock.createNiceMock(TaskManager.class); + expect(taskManager.producerClientIds()).andStubReturn(Collections.emptySet()); + final Consumer consumer = mock(Consumer.class); + final ConsumerGroupMetadata consumerGroupMetadata = mock(ConsumerGroupMetadata.class); + expect(consumer.groupMetadata()).andStubReturn(consumerGroupMetadata); + 
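+        // Nice mocks answer any call that was not explicitly expected with a default value;
+        // only the expect(...)ed interactions are checked by verify(...) at the end. The
+        // interaction under test here is consumer.enforceRebalance(): because
+        // handleCorruption(...) is stubbed below to return true (an active task was corrupted),
+        // the thread is expected to force a rebalance, unlike the inactive-task variant that follows.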
expect(consumerGroupMetadata.groupInstanceId()).andReturn(Optional.empty()); + consumer.subscribe((Collection) anyObject(), anyObject()); + EasyMock.expectLastCall().anyTimes(); + consumer.unsubscribe(); + EasyMock.expectLastCall().anyTimes(); + EasyMock.replay(consumerGroupMetadata); + final Task task1 = mock(Task.class); + final Task task2 = mock(Task.class); + + final TaskId taskId1 = new TaskId(0, 0); + final TaskId taskId2 = new TaskId(0, 2); + + final Set corruptedTasks = singleton(taskId1); + + expect(task1.state()).andReturn(Task.State.RUNNING).anyTimes(); + expect(task1.id()).andReturn(taskId1).anyTimes(); + expect(task2.state()).andReturn(Task.State.CREATED).anyTimes(); + expect(task2.id()).andReturn(taskId2).anyTimes(); + expect(taskManager.handleCorruption(corruptedTasks)).andReturn(true); + + consumer.enforceRebalance(); + expectLastCall(); + + EasyMock.replay(task1, task2, taskManager, consumer); + + final StreamsMetricsImpl streamsMetrics = + new StreamsMetricsImpl(metrics, CLIENT_ID, StreamsConfig.METRICS_LATEST, mockTime); + final TopologyMetadata topologyMetadata = new TopologyMetadata(internalTopologyBuilder, config); + topologyMetadata.buildAndRewriteTopology(); + final StreamThread thread = new StreamThread( + mockTime, + eosEnabledConfig, + null, + consumer, + consumer, + null, + null, + taskManager, + streamsMetrics, + topologyMetadata, + CLIENT_ID, + new LogContext(""), + new AtomicInteger(), + new AtomicLong(Long.MAX_VALUE), + null, + HANDLER, + null + ) { + @Override + void runOnce() { + setState(State.PENDING_SHUTDOWN); + throw new TaskCorruptedException(corruptedTasks); + } + }.updateThreadMetadata(getSharedAdminClientId(CLIENT_ID)); + + thread.setState(StreamThread.State.STARTING); + thread.runLoop(); + + verify(taskManager); + verify(consumer); + } + + @Test + @SuppressWarnings("unchecked") + public void shouldNotEnforceRebalanceWhenTaskCorruptedExceptionIsThrownForAnInactiveTask() { + final TaskManager taskManager = EasyMock.createNiceMock(TaskManager.class); + expect(taskManager.producerClientIds()).andStubReturn(Collections.emptySet()); + final Consumer consumer = mock(Consumer.class); + final ConsumerGroupMetadata consumerGroupMetadata = mock(ConsumerGroupMetadata.class); + expect(consumer.groupMetadata()).andStubReturn(consumerGroupMetadata); + expect(consumerGroupMetadata.groupInstanceId()).andReturn(Optional.empty()); + consumer.subscribe((Collection) anyObject(), anyObject()); + EasyMock.expectLastCall().anyTimes(); + consumer.unsubscribe(); + EasyMock.expectLastCall().anyTimes(); + EasyMock.replay(consumerGroupMetadata); + final Task task1 = mock(Task.class); + final Task task2 = mock(Task.class); + + final TaskId taskId1 = new TaskId(0, 0); + final TaskId taskId2 = new TaskId(0, 2); + + final Set corruptedTasks = singleton(taskId1); + + expect(task1.state()).andReturn(Task.State.CLOSED).anyTimes(); + expect(task1.id()).andReturn(taskId1).anyTimes(); + expect(task2.state()).andReturn(Task.State.CLOSED).anyTimes(); + expect(task2.id()).andReturn(taskId2).anyTimes(); + expect(taskManager.handleCorruption(corruptedTasks)).andReturn(false); + + EasyMock.replay(task1, task2, taskManager, consumer); + + final StreamsMetricsImpl streamsMetrics = + new StreamsMetricsImpl(metrics, CLIENT_ID, StreamsConfig.METRICS_LATEST, mockTime); + final TopologyMetadata topologyMetadata = new TopologyMetadata(internalTopologyBuilder, config); + topologyMetadata.buildAndRewriteTopology(); + final StreamThread thread = new StreamThread( + mockTime, + eosEnabledConfig, + null, 
+ consumer, + consumer, + null, + null, + taskManager, + streamsMetrics, + topologyMetadata, + CLIENT_ID, + new LogContext(""), + new AtomicInteger(), + new AtomicLong(Long.MAX_VALUE), + null, + HANDLER, + null + ) { + @Override + void runOnce() { + setState(State.PENDING_SHUTDOWN); + throw new TaskCorruptedException(corruptedTasks); + } + }.updateThreadMetadata(getSharedAdminClientId(CLIENT_ID)); + + thread.setState(StreamThread.State.STARTING); + thread.runLoop(); + + verify(taskManager); + verify(consumer); + } + + @Test + public void shouldNotCommitNonRunningNonRestoringTasks() { + final TaskManager taskManager = EasyMock.createNiceMock(TaskManager.class); + final Consumer consumer = mock(Consumer.class); + final ConsumerGroupMetadata consumerGroupMetadata = mock(ConsumerGroupMetadata.class); + expect(consumer.groupMetadata()).andStubReturn(consumerGroupMetadata); + expect(consumerGroupMetadata.groupInstanceId()).andReturn(Optional.empty()); + EasyMock.replay(consumer, consumerGroupMetadata); + final Task task1 = mock(Task.class); + final Task task2 = mock(Task.class); + final Task task3 = mock(Task.class); + + final TaskId taskId1 = new TaskId(0, 1); + final TaskId taskId2 = new TaskId(0, 2); + final TaskId taskId3 = new TaskId(0, 3); + + expect(task1.state()).andReturn(Task.State.RUNNING).anyTimes(); + expect(task1.id()).andReturn(taskId1).anyTimes(); + expect(task2.state()).andReturn(Task.State.RESTORING).anyTimes(); + expect(task2.id()).andReturn(taskId2).anyTimes(); + expect(task3.state()).andReturn(Task.State.CREATED).anyTimes(); + expect(task3.id()).andReturn(taskId3).anyTimes(); + + expect(taskManager.tasks()).andReturn(mkMap( + mkEntry(taskId1, task1), + mkEntry(taskId2, task2), + mkEntry(taskId3, task3) + )).anyTimes(); + + // expect not to try and commit task3, because it's not running. 
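+        // maybeCommit() only hands RUNNING and RESTORING tasks to TaskManager#commit, so the
+        // expected commit set below contains task1 and task2 but leaves out the CREATED task3.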
+ expect(taskManager.commit(mkSet(task1, task2))).andReturn(2).times(1); + + final StreamsMetricsImpl streamsMetrics = + new StreamsMetricsImpl(metrics, CLIENT_ID, StreamsConfig.METRICS_LATEST, mockTime); + final StreamThread thread = new StreamThread( + mockTime, + config, + null, + consumer, + consumer, + null, + null, + taskManager, + streamsMetrics, + new TopologyMetadata(internalTopologyBuilder, config), + CLIENT_ID, + new LogContext(""), + new AtomicInteger(), + new AtomicLong(Long.MAX_VALUE), + null, + HANDLER, + null + ); + + EasyMock.replay(task1, task2, task3, taskManager); + + thread.setNow(mockTime.milliseconds()); + thread.maybeCommit(); + + verify(taskManager); + } + + @Test + public void shouldLogAndRecordSkippedRecordsForInvalidTimestamps() { + internalTopologyBuilder.addSource(null, "source1", null, null, null, topic1); + + final Properties config = configProps(false); + config.setProperty( + StreamsConfig.DEFAULT_TIMESTAMP_EXTRACTOR_CLASS_CONFIG, + LogAndSkipOnInvalidTimestamp.class.getName() + ); + final StreamThread thread = createStreamThread(CLIENT_ID, new StreamsConfig(config), false); + + thread.setState(StreamThread.State.STARTING); + thread.setState(StreamThread.State.PARTITIONS_REVOKED); + + final TaskId task1 = new TaskId(0, t1p1.partition()); + final Set assignedPartitions = Collections.singleton(t1p1); + thread.taskManager().handleAssignment( + Collections.singletonMap( + task1, + assignedPartitions), + emptyMap()); + + final MockConsumer mockConsumer = (MockConsumer) thread.mainConsumer(); + mockConsumer.assign(Collections.singleton(t1p1)); + mockConsumer.updateBeginningOffsets(Collections.singletonMap(t1p1, 0L)); + thread.rebalanceListener().onPartitionsAssigned(assignedPartitions); + thread.runOnce(); + + final MetricName skippedTotalMetric = metrics.metricName( + "skipped-records-total", + "stream-metrics", + Collections.singletonMap("client-id", thread.getName()) + ); + final MetricName skippedRateMetric = metrics.metricName( + "skipped-records-rate", + "stream-metrics", + Collections.singletonMap("client-id", thread.getName()) + ); + + try (final LogCaptureAppender appender = LogCaptureAppender.createAndRegister(RecordQueue.class)) { + long offset = -1; + addRecord(mockConsumer, ++offset); + addRecord(mockConsumer, ++offset); + thread.runOnce(); + + addRecord(mockConsumer, ++offset); + addRecord(mockConsumer, ++offset); + addRecord(mockConsumer, ++offset); + addRecord(mockConsumer, ++offset); + thread.runOnce(); + + addRecord(mockConsumer, ++offset, 1L); + addRecord(mockConsumer, ++offset, 1L); + thread.runOnce(); + + final List strings = appender.getMessages(); + + final String threadTaskPrefix = "stream-thread [" + Thread.currentThread().getName() + "] task [0_1] "; + assertTrue(strings.contains( + threadTaskPrefix + "Skipping record due to negative extracted timestamp. " + + "topic=[topic1] partition=[1] offset=[0] extractedTimestamp=[-1] " + + "extractor=[org.apache.kafka.streams.processor.LogAndSkipOnInvalidTimestamp]" + )); + assertTrue(strings.contains( + threadTaskPrefix + "Skipping record due to negative extracted timestamp. " + + "topic=[topic1] partition=[1] offset=[1] extractedTimestamp=[-1] " + + "extractor=[org.apache.kafka.streams.processor.LogAndSkipOnInvalidTimestamp]" + )); + assertTrue(strings.contains( + threadTaskPrefix + "Skipping record due to negative extracted timestamp. 
" + + "topic=[topic1] partition=[1] offset=[2] extractedTimestamp=[-1] " + + "extractor=[org.apache.kafka.streams.processor.LogAndSkipOnInvalidTimestamp]" + )); + assertTrue(strings.contains( + threadTaskPrefix + "Skipping record due to negative extracted timestamp. " + + "topic=[topic1] partition=[1] offset=[3] extractedTimestamp=[-1] " + + "extractor=[org.apache.kafka.streams.processor.LogAndSkipOnInvalidTimestamp]" + )); + assertTrue(strings.contains( + threadTaskPrefix + "Skipping record due to negative extracted timestamp. " + + "topic=[topic1] partition=[1] offset=[4] extractedTimestamp=[-1] " + + "extractor=[org.apache.kafka.streams.processor.LogAndSkipOnInvalidTimestamp]" + )); + assertTrue(strings.contains( + threadTaskPrefix + "Skipping record due to negative extracted timestamp. " + + "topic=[topic1] partition=[1] offset=[5] extractedTimestamp=[-1] " + + "extractor=[org.apache.kafka.streams.processor.LogAndSkipOnInvalidTimestamp]" + )); + } + } + + @Test + public void shouldTransmitTaskManagerMetrics() { + final Consumer consumer = EasyMock.createNiceMock(Consumer.class); + final ConsumerGroupMetadata consumerGroupMetadata = mock(ConsumerGroupMetadata.class); + expect(consumer.groupMetadata()).andStubReturn(consumerGroupMetadata); + expect(consumerGroupMetadata.groupInstanceId()).andReturn(Optional.empty()); + EasyMock.replay(consumer, consumerGroupMetadata); + final TaskManager taskManager = EasyMock.createNiceMock(TaskManager.class); + + final MetricName testMetricName = new MetricName("test_metric", "", "", new HashMap<>()); + final Metric testMetric = new KafkaMetric( + new Object(), + testMetricName, + (Measurable) (config, now) -> 0, + null, + new MockTime()); + final Map dummyProducerMetrics = singletonMap(testMetricName, testMetric); + + expect(taskManager.producerMetrics()).andReturn(dummyProducerMetrics); + EasyMock.replay(taskManager); + + final StreamsMetricsImpl streamsMetrics = + new StreamsMetricsImpl(metrics, CLIENT_ID, StreamsConfig.METRICS_LATEST, mockTime); + final StreamThread thread = new StreamThread( + mockTime, + new StreamsConfig(configProps(true)), + null, + consumer, + consumer, + null, + null, + taskManager, + streamsMetrics, + new TopologyMetadata(internalTopologyBuilder, config), + CLIENT_ID, + new LogContext(""), + new AtomicInteger(), + new AtomicLong(Long.MAX_VALUE), + null, + HANDLER, + null + ); + + assertThat(dummyProducerMetrics, is(thread.producerMetrics())); + } + + @Test + public void shouldConstructAdminMetrics() { + final Node broker1 = new Node(0, "dummyHost-1", 1234); + final Node broker2 = new Node(1, "dummyHost-2", 1234); + final List cluster = Arrays.asList(broker1, broker2); + + final MockAdminClient adminClient = new MockAdminClient.Builder(). 
+ brokers(cluster).clusterId(null).build(); + + final Consumer consumer = EasyMock.createNiceMock(Consumer.class); + final ConsumerGroupMetadata consumerGroupMetadata = mock(ConsumerGroupMetadata.class); + expect(consumer.groupMetadata()).andStubReturn(consumerGroupMetadata); + expect(consumerGroupMetadata.groupInstanceId()).andReturn(Optional.empty()); + EasyMock.replay(consumer, consumerGroupMetadata); + final TaskManager taskManager = EasyMock.createNiceMock(TaskManager.class); + + final StreamsMetricsImpl streamsMetrics = + new StreamsMetricsImpl(metrics, CLIENT_ID, StreamsConfig.METRICS_LATEST, mockTime); + final StreamThread thread = new StreamThread( + mockTime, + config, + adminClient, + consumer, + consumer, + null, + null, + taskManager, + streamsMetrics, + new TopologyMetadata(internalTopologyBuilder, config), + CLIENT_ID, + new LogContext(""), + new AtomicInteger(), + new AtomicLong(Long.MAX_VALUE), + null, + HANDLER, + null + ); + final MetricName testMetricName = new MetricName("test_metric", "", "", new HashMap<>()); + final Metric testMetric = new KafkaMetric( + new Object(), + testMetricName, + (Measurable) (config, now) -> 0, + null, + new MockTime()); + + EasyMock.replay(taskManager); + + adminClient.setMockMetrics(testMetricName, testMetric); + final Map adminClientMetrics = thread.adminClientMetrics(); + assertEquals(testMetricName, adminClientMetrics.get(testMetricName).metricName()); + } + + @Test + public void shouldNotRecordFailedStreamThread() { + runAndVerifyFailedStreamThreadRecording(false); + } + + @Test + public void shouldRecordFailedStreamThread() { + runAndVerifyFailedStreamThreadRecording(true); + } + + public void runAndVerifyFailedStreamThreadRecording(final boolean shouldFail) { + final Consumer consumer = EasyMock.createNiceMock(Consumer.class); + final ConsumerGroupMetadata consumerGroupMetadata = mock(ConsumerGroupMetadata.class); + expect(consumer.groupMetadata()).andStubReturn(consumerGroupMetadata); + expect(consumerGroupMetadata.groupInstanceId()).andReturn(Optional.empty()); + EasyMock.replay(consumer, consumerGroupMetadata); + final TaskManager taskManager = EasyMock.createNiceMock(TaskManager.class); + expect(taskManager.producerClientIds()).andStubReturn(Collections.emptySet()); + final StreamsMetricsImpl streamsMetrics = + new StreamsMetricsImpl(metrics, CLIENT_ID, StreamsConfig.METRICS_LATEST, mockTime); + final TopologyMetadata topologyMetadata = new TopologyMetadata(internalTopologyBuilder, config); + topologyMetadata.buildAndRewriteTopology(); + final StreamThread thread = new StreamThread( + mockTime, + config, + null, + consumer, + consumer, + null, + null, + taskManager, + streamsMetrics, + topologyMetadata, + CLIENT_ID, + new LogContext(""), + new AtomicInteger(), + new AtomicLong(Long.MAX_VALUE), + null, + e -> { }, + null + ) { + @Override + void runOnce() { + setState(StreamThread.State.PENDING_SHUTDOWN); + if (shouldFail) { + throw new StreamsException(Thread.currentThread().getName()); + } + } + }; + EasyMock.replay(taskManager); + thread.updateThreadMetadata("metadata"); + + thread.run(); + + final Metric failedThreads = StreamsTestUtils.getMetricByName(metrics.metrics(), "failed-stream-threads", "stream-metrics"); + assertThat(failedThreads.metricValue(), is(shouldFail ? 
1.0 : 0.0)); + } + + private TaskManager mockTaskManagerCommit(final Consumer consumer, + final int numberOfCommits, + final int commits) { + final TaskManager taskManager = EasyMock.createNiceMock(TaskManager.class); + final Task runningTask = mock(Task.class); + final TaskId taskId = new TaskId(0, 0); + + expect(runningTask.state()).andReturn(Task.State.RUNNING).anyTimes(); + expect(runningTask.id()).andReturn(taskId).anyTimes(); + expect(taskManager.tasks()) + .andReturn(Collections.singletonMap(taskId, runningTask)).times(numberOfCommits); + expect(taskManager.commit(Collections.singleton(runningTask))).andReturn(commits).times(numberOfCommits); + EasyMock.replay(taskManager, runningTask); + return taskManager; + } + + private void setupInternalTopologyWithoutState() { + internalTopologyBuilder.addSource(null, "source1", null, null, null, topic1); + internalTopologyBuilder.addProcessor( + "processor1", + (ProcessorSupplier) MockApiProcessor::new, + "source1" + ); + } + + // TODO: change return type to `StandbyTask` + private Collection createStandbyTask() { + final LogContext logContext = new LogContext("test"); + final Logger log = logContext.logger(StreamThreadTest.class); + final StreamsMetricsImpl streamsMetrics = + new StreamsMetricsImpl(metrics, CLIENT_ID, StreamsConfig.METRICS_LATEST, mockTime); + final StandbyTaskCreator standbyTaskCreator = new StandbyTaskCreator( + new TopologyMetadata(internalTopologyBuilder, config), + config, + streamsMetrics, + stateDirectory, + new MockChangelogReader(), + CLIENT_ID, + log); + return standbyTaskCreator.createTasks(singletonMap(new TaskId(1, 2), emptySet())); + } + + private void addRecord(final MockConsumer mockConsumer, + final long offset) { + addRecord(mockConsumer, offset, -1L); + } + + private void addRecord(final MockConsumer mockConsumer, + final long offset, + final long timestamp) { + mockConsumer.addRecord(new ConsumerRecord<>( + t1p1.topic(), + t1p1.partition(), + offset, + timestamp, + TimestampType.CREATE_TIME, + -1, + -1, + new byte[0], + new byte[0], + new RecordHeaders(), + Optional.empty())); + } + + StandbyTask standbyTask(final TaskManager taskManager, final TopicPartition partition) { + final Stream standbys = taskManager.tasks().values().stream().filter(t -> !t.isActive()); + for (final Task task : (Iterable) standbys::iterator) { + if (task.inputPartitions().contains(partition)) { + return (StandbyTask) task; + } + } + return null; + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTotalBlockedTimeTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTotalBlockedTimeTest.java new file mode 100644 index 0000000..7151016 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamThreadTotalBlockedTimeTest.java @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.streams.processor.internals; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.equalTo; +import static org.mockito.Mockito.when; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.function.Supplier; +import org.apache.kafka.clients.consumer.Consumer; +import org.apache.kafka.common.Metric; +import org.apache.kafka.common.MetricName; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnit; +import org.mockito.junit.MockitoRule; + +public class StreamThreadTotalBlockedTimeTest { + private static final int IO_TIME_TOTAL = 1; + private static final int IO_WAIT_TIME_TOTAL = 2; + private static final int COMMITTED_TIME_TOTAL = 3; + private static final int COMMIT_SYNC_TIME_TOTAL = 4; + private static final int RESTORE_IOTIME_TOTAL = 5; + private static final int RESTORE_IO_WAITTIME_TOTAL = 6; + private static final double PRODUCER_BLOCKED_TIME = 7.0; + + @Mock + Consumer consumer; + @Mock + Consumer restoreConsumer; + @Mock + Supplier producerBlocked; + + private StreamThreadTotalBlockedTime blockedTime; + + @Rule + public final MockitoRule mockitoRule = MockitoJUnit.rule(); + + @Before + public void setup() { + blockedTime = new StreamThreadTotalBlockedTime(consumer, restoreConsumer, producerBlocked); + when(consumer.metrics()).thenAnswer(a -> new MetricsBuilder() + .addMetric("io-time-ns-total", IO_TIME_TOTAL) + .addMetric("io-wait-time-ns-total", IO_WAIT_TIME_TOTAL) + .addMetric("committed-time-ns-total", COMMITTED_TIME_TOTAL) + .addMetric("commit-sync-time-ns-total", COMMIT_SYNC_TIME_TOTAL) + .build() + ); + when(restoreConsumer.metrics()).thenAnswer(a -> new MetricsBuilder() + .addMetric("io-time-ns-total", RESTORE_IOTIME_TOTAL) + .addMetric("io-wait-time-ns-total", RESTORE_IO_WAITTIME_TOTAL) + .build() + ); + when(producerBlocked.get()).thenReturn(PRODUCER_BLOCKED_TIME); + } + + @Test + public void shouldComputeTotalBlockedTime() { + assertThat( + blockedTime.compute(), + equalTo(IO_TIME_TOTAL + IO_WAIT_TIME_TOTAL + COMMITTED_TIME_TOTAL + + COMMIT_SYNC_TIME_TOTAL + RESTORE_IOTIME_TOTAL + RESTORE_IO_WAITTIME_TOTAL + + PRODUCER_BLOCKED_TIME) + ); + } + + private static class MetricsBuilder { + private final HashMap metrics = new HashMap<>(); + + private MetricsBuilder addMetric(final String name, final double value) { + final MetricName metricName = new MetricName(name, "", "", Collections.emptyMap()); + metrics.put( + metricName, + new Metric() { + @Override + public MetricName metricName() { + return metricName; + } + + @Override + public Object metricValue() { + return value; + } + } + ); + return this; + } + + public Map build() { + return Collections.unmodifiableMap(metrics); + } + } +} \ No newline at end of file diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamsAssignmentScaleTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamsAssignmentScaleTest.java new file mode 100644 index 0000000..2782c1a 
--- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamsAssignmentScaleTest.java @@ -0,0 +1,254 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.processor.internals; + +import org.apache.kafka.clients.admin.AdminClient; +import org.apache.kafka.clients.consumer.Consumer; +import org.apache.kafka.clients.consumer.ConsumerPartitionAssignor.Assignment; +import org.apache.kafka.clients.consumer.ConsumerPartitionAssignor.GroupSubscription; +import org.apache.kafka.clients.consumer.ConsumerPartitionAssignor.Subscription; +import org.apache.kafka.common.Cluster; +import org.apache.kafka.common.Node; +import org.apache.kafka.common.PartitionInfo; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.StreamsConfig.InternalConfig; +import org.apache.kafka.streams.processor.internals.assignment.AssignmentInfo; +import org.apache.kafka.streams.processor.internals.assignment.FallbackPriorTaskAssignor; +import org.apache.kafka.streams.processor.internals.assignment.HighAvailabilityTaskAssignor; +import org.apache.kafka.streams.processor.internals.assignment.ReferenceContainer; +import org.apache.kafka.streams.processor.internals.assignment.StickyTaskAssignor; +import org.apache.kafka.streams.processor.internals.assignment.TaskAssignor; +import org.apache.kafka.test.IntegrationTest; +import org.apache.kafka.test.MockApiProcessorSupplier; +import org.apache.kafka.test.MockClientSupplier; +import org.apache.kafka.test.MockInternalTopicManager; +import org.apache.kafka.test.MockKeyValueStoreBuilder; +import org.easymock.EasyMock; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; + +import static java.util.Collections.emptySet; +import static java.util.Collections.singletonList; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.EMPTY_TASKS; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.createMockAdminClientForAssignor; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.getInfo; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.uuidForInt; +import static org.easymock.EasyMock.expect; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.MatcherAssert.assertThat; + +@Category({IntegrationTest.class}) +public class StreamsAssignmentScaleTest { 
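+    // Scale check: each test below drives StreamsPartitionAssignor through two complete rebalances
+    // (an initial assignment, then a follow-up that feeds the first result back in as "previous tasks")
+    // with large partition/consumer/thread/standby counts, and fails if either assignment exceeds
+    // MAX_ASSIGNMENT_DURATION (see completeLargeAssignment below).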
+    final static long MAX_ASSIGNMENT_DURATION = 60 * 1000L; // each individual assignment should complete well within this 60s limit
+    final static String APPLICATION_ID = "streams-assignment-scale-test";
+
+    private final Logger log = LoggerFactory.getLogger(StreamsAssignmentScaleTest.class);
+
+    /************ HighAvailabilityTaskAssignor tests ************/
+
+    @Test(timeout = 120 * 1000)
+    public void testHighAvailabilityTaskAssignorLargePartitionCount() {
+        completeLargeAssignment(6_000, 2, 1, 1, HighAvailabilityTaskAssignor.class);
+    }
+
+    @Test(timeout = 120 * 1000)
+    public void testHighAvailabilityTaskAssignorLargeNumConsumers() {
+        completeLargeAssignment(1_000, 1_000, 1, 1, HighAvailabilityTaskAssignor.class);
+    }
+
+    @Test(timeout = 120 * 1000)
+    public void testHighAvailabilityTaskAssignorManyStandbys() {
+        completeLargeAssignment(1_000, 100, 1, 50, HighAvailabilityTaskAssignor.class);
+    }
+
+    @Test(timeout = 120 * 1000)
+    public void testHighAvailabilityTaskAssignorManyThreadsPerClient() {
+        completeLargeAssignment(1_000, 10, 1000, 1, HighAvailabilityTaskAssignor.class);
+    }
+
+    /************ StickyTaskAssignor tests ************/
+
+    @Test(timeout = 120 * 1000)
+    public void testStickyTaskAssignorLargePartitionCount() {
+        completeLargeAssignment(2_000, 2, 1, 1, StickyTaskAssignor.class);
+    }
+
+    @Test(timeout = 120 * 1000)
+    public void testStickyTaskAssignorLargeNumConsumers() {
+        completeLargeAssignment(1_000, 1_000, 1, 1, StickyTaskAssignor.class);
+    }
+
+    @Test(timeout = 120 * 1000)
+    public void testStickyTaskAssignorManyStandbys() {
+        completeLargeAssignment(1_000, 100, 1, 20, StickyTaskAssignor.class);
+    }
+
+    @Test(timeout = 120 * 1000)
+    public void testStickyTaskAssignorManyThreadsPerClient() {
+        completeLargeAssignment(1_000, 10, 1000, 1, StickyTaskAssignor.class);
+    }
+
+    /************ FallbackPriorTaskAssignor tests ************/
+
+    @Test(timeout = 120 * 1000)
+    public void testFallbackPriorTaskAssignorLargePartitionCount() {
+        completeLargeAssignment(2_000, 2, 1, 1, FallbackPriorTaskAssignor.class);
+    }
+
+    @Test(timeout = 120 * 1000)
+    public void testFallbackPriorTaskAssignorLargeNumConsumers() {
+        completeLargeAssignment(1_000, 1_000, 1, 1, FallbackPriorTaskAssignor.class);
+    }
+
+    @Test(timeout = 120 * 1000)
+    public void testFallbackPriorTaskAssignorManyStandbys() {
+        completeLargeAssignment(1_000, 100, 1, 20, FallbackPriorTaskAssignor.class);
+    }
+
+    @Test(timeout = 120 * 1000)
+    public void testFallbackPriorTaskAssignorManyThreadsPerClient() {
+        completeLargeAssignment(1_000, 10, 1000, 1, FallbackPriorTaskAssignor.class);
+    }
+
+    private void completeLargeAssignment(final int numPartitions,
+                                         final int numClients,
+                                         final int numThreadsPerClient,
+                                         final int numStandbys,
+                                         final Class<? extends TaskAssignor> taskAssignor) {
+        final List<String> topic = singletonList("topic");
+
+        final Map<TopicPartition, Long> changelogEndOffsets = new HashMap<>();
+        for (int p = 0; p < numPartitions; ++p) {
+            changelogEndOffsets.put(new TopicPartition(APPLICATION_ID + "-store-changelog", p), 100_000L);
+        }
+
+        final List<PartitionInfo> partitionInfos = new ArrayList<>();
+        for (int p = 0; p < numPartitions; ++p) {
+            partitionInfos.add(new PartitionInfo("topic", p, Node.noNode(), new Node[0], new Node[0]));
+        }
+
+        final Cluster clusterMetadata = new Cluster(
+            "cluster",
+            Collections.singletonList(Node.noNode()),
+            partitionInfos,
+            emptySet(),
+            emptySet()
+        );
+        final Map<String, Object> configMap = new HashMap<>();
+        configMap.put(StreamsConfig.APPLICATION_ID_CONFIG, APPLICATION_ID);
+        configMap.put(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:8080");
+        final
InternalTopologyBuilder builder = new InternalTopologyBuilder(); + builder.addSource(null, "source", null, null, null, "topic"); + builder.addProcessor("processor", new MockApiProcessorSupplier<>(), "source"); + builder.addStateStore(new MockKeyValueStoreBuilder("store", false), "processor"); + final TopologyMetadata topologyMetadata = new TopologyMetadata(builder, new StreamsConfig(configMap)); + topologyMetadata.buildAndRewriteTopology(); + + final Consumer mainConsumer = EasyMock.createNiceMock(Consumer.class); + final TaskManager taskManager = EasyMock.createNiceMock(TaskManager.class); + expect(taskManager.topologyMetadata()).andStubReturn(topologyMetadata); + expect(mainConsumer.committed(new HashSet<>())).andStubReturn(Collections.emptyMap()); + final AdminClient adminClient = createMockAdminClientForAssignor(changelogEndOffsets); + + final ReferenceContainer referenceContainer = new ReferenceContainer(); + referenceContainer.mainConsumer = mainConsumer; + referenceContainer.adminClient = adminClient; + referenceContainer.taskManager = taskManager; + referenceContainer.streamsMetadataState = EasyMock.createNiceMock(StreamsMetadataState.class); + referenceContainer.time = new MockTime(); + configMap.put(InternalConfig.REFERENCE_CONTAINER_PARTITION_ASSIGNOR, referenceContainer); + configMap.put(InternalConfig.INTERNAL_TASK_ASSIGNOR_CLASS, taskAssignor.getName()); + configMap.put(StreamsConfig.NUM_STANDBY_REPLICAS_CONFIG, numStandbys); + + final MockInternalTopicManager mockInternalTopicManager = new MockInternalTopicManager( + new MockTime(), + new StreamsConfig(configMap), + new MockClientSupplier().restoreConsumer, + false + ); + EasyMock.replay(taskManager, adminClient, mainConsumer); + + final StreamsPartitionAssignor partitionAssignor = new StreamsPartitionAssignor(); + partitionAssignor.configure(configMap); + partitionAssignor.setInternalTopicManager(mockInternalTopicManager); + + final Map subscriptions = new HashMap<>(); + for (int client = 0; client < numClients; ++client) { + for (int i = 0; i < numThreadsPerClient; ++i) { + subscriptions.put( + getConsumerName(i, client), + new Subscription(topic, getInfo(uuidForInt(client), EMPTY_TASKS, EMPTY_TASKS).encode()) + ); + } + } + + final long firstAssignmentStartMs = System.currentTimeMillis(); + final Map firstAssignments = partitionAssignor.assign(clusterMetadata, new GroupSubscription(subscriptions)).groupAssignment(); + final long firstAssignmentEndMs = System.currentTimeMillis(); + + final long firstAssignmentDuration = firstAssignmentEndMs - firstAssignmentStartMs; + if (firstAssignmentDuration > MAX_ASSIGNMENT_DURATION) { + throw new AssertionError("The first assignment took too long to complete at " + firstAssignmentDuration + "ms."); + } else { + log.info("First assignment took {}ms.", firstAssignmentDuration); + } + + // Use the assignment to generate the subscriptions' prev task data for the next rebalance + for (int client = 0; client < numClients; ++client) { + for (int i = 0; i < numThreadsPerClient; ++i) { + final String consumer = getConsumerName(i, client); + final Assignment assignment = firstAssignments.get(consumer); + final AssignmentInfo info = AssignmentInfo.decode(assignment.userData()); + + subscriptions.put( + consumer, + new Subscription( + topic, + getInfo(uuidForInt(client), new HashSet<>(info.activeTasks()), info.standbyTasks().keySet()).encode(), + assignment.partitions()) + ); + } + } + + final long secondAssignmentStartMs = System.currentTimeMillis(); + final Map secondAssignments = 
partitionAssignor.assign(clusterMetadata, new GroupSubscription(subscriptions)).groupAssignment(); + final long secondAssignmentEndMs = System.currentTimeMillis(); + final long secondAssignmentDuration = secondAssignmentEndMs - secondAssignmentStartMs; + if (secondAssignmentDuration > MAX_ASSIGNMENT_DURATION) { + throw new AssertionError("The second assignment took too long to complete at " + secondAssignmentDuration + "ms."); + } else { + log.info("Second assignment took {}ms.", secondAssignmentDuration); + } + + assertThat(secondAssignments.size(), is(numClients * numThreadsPerClient)); + } + + private String getConsumerName(final int consumerIndex, final int clientIndex) { + return "consumer-" + clientIndex + "-" + consumerIndex; + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamsMetadataStateTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamsMetadataStateTest.java new file mode 100644 index 0000000..84578b2 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamsMetadataStateTest.java @@ -0,0 +1,381 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.processor.internals; + +import org.apache.kafka.common.Cluster; +import org.apache.kafka.common.Node; +import org.apache.kafka.common.PartitionInfo; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.serialization.Serializer; +import org.apache.kafka.streams.KeyQueryMetadata; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.StreamsMetadata; +import org.apache.kafka.streams.TopologyWrapper; +import org.apache.kafka.streams.kstream.Consumed; +import org.apache.kafka.streams.kstream.KStream; +import org.apache.kafka.streams.kstream.Materialized; +import org.apache.kafka.streams.processor.StreamPartitioner; +import org.apache.kafka.streams.processor.internals.testutil.DummyStreamsConfig; +import org.apache.kafka.streams.state.HostInfo; +import org.apache.kafka.streams.state.internals.StreamsMetadataImpl; +import org.junit.Before; +import org.junit.Test; + +import java.util.Arrays; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.function.Function; +import java.util.stream.Collectors; + +import static org.apache.kafka.common.utils.Utils.mkSet; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.assertFalse; + +public class StreamsMetadataStateTest { + + private StreamsMetadataState metadataState; + private HostInfo hostOne; + private HostInfo hostTwo; + private HostInfo hostThree; + private TopicPartition topic1P0; + private TopicPartition topic2P0; + private TopicPartition topic3P0; + private Map> hostToActivePartitions; + private Map> hostToStandbyPartitions; + private StreamsBuilder builder; + private TopicPartition topic1P1; + private TopicPartition topic2P1; + private TopicPartition topic4P0; + private Cluster cluster; + private final String globalTable = "global-table"; + private StreamPartitioner partitioner; + private Set storeNames; + + @Before + public void before() { + builder = new StreamsBuilder(); + final KStream one = builder.stream("topic-one"); + one.groupByKey().count(Materialized.as("table-one")); + + final KStream two = builder.stream("topic-two"); + two.groupByKey().count(Materialized.as("table-two")); + + builder.stream("topic-three") + .groupByKey() + .count(Materialized.as("table-three")); + + one.merge(two).groupByKey().count(Materialized.as("merged-table")); + + builder.stream("topic-four").mapValues(value -> value); + + builder.globalTable("global-topic", + Consumed.with(null, null), + Materialized.as(globalTable)); + + TopologyWrapper.getInternalTopologyBuilder(builder.build()).setApplicationId("appId"); + + topic1P0 = new TopicPartition("topic-one", 0); + topic1P1 = new TopicPartition("topic-one", 1); + topic2P0 = new TopicPartition("topic-two", 0); + topic2P1 = new TopicPartition("topic-two", 1); + topic3P0 = new TopicPartition("topic-three", 0); + topic4P0 = new TopicPartition("topic-four", 0); + + hostOne = new HostInfo("host-one", 8080); + hostTwo = new HostInfo("host-two", 9090); + hostThree = new HostInfo("host-three", 7070); + hostToActivePartitions = new HashMap<>(); + hostToActivePartitions.put(hostOne, mkSet(topic1P0, topic2P1, topic4P0)); + 
hostToActivePartitions.put(hostTwo, mkSet(topic2P0, topic1P1)); + hostToActivePartitions.put(hostThree, Collections.singleton(topic3P0)); + hostToStandbyPartitions = new HashMap<>(); + hostToStandbyPartitions.put(hostThree, mkSet(topic1P0, topic2P1, topic4P0)); + hostToStandbyPartitions.put(hostOne, mkSet(topic2P0, topic1P1)); + hostToStandbyPartitions.put(hostTwo, Collections.singleton(topic3P0)); + + final List partitionInfos = Arrays.asList( + new PartitionInfo("topic-one", 0, null, null, null), + new PartitionInfo("topic-one", 1, null, null, null), + new PartitionInfo("topic-two", 0, null, null, null), + new PartitionInfo("topic-two", 1, null, null, null), + new PartitionInfo("topic-three", 0, null, null, null), + new PartitionInfo("topic-four", 0, null, null, null)); + + cluster = new Cluster(null, Collections.emptyList(), partitionInfos, Collections.emptySet(), Collections.emptySet()); + final TopologyMetadata topologyMetadata = new TopologyMetadata(TopologyWrapper.getInternalTopologyBuilder(builder.build()), new DummyStreamsConfig()); + topologyMetadata.buildAndRewriteTopology(); + metadataState = new StreamsMetadataState(topologyMetadata, hostOne); + metadataState.onChange(hostToActivePartitions, hostToStandbyPartitions, cluster); + partitioner = (topic, key, value, numPartitions) -> 1; + storeNames = mkSet("table-one", "table-two", "merged-table", globalTable); + } + + @Test + public void shouldNotThrowExceptionWhenOnChangeNotCalled() { + final Collection metadata = new StreamsMetadataState( + new TopologyMetadata(TopologyWrapper.getInternalTopologyBuilder(builder.build()), new DummyStreamsConfig()), + hostOne + ).getAllMetadataForStore("store"); + assertEquals(0, metadata.size()); + } + + @Test + public void shouldGetAllStreamInstances() { + final StreamsMetadata one = new StreamsMetadataImpl(hostOne, + mkSet(globalTable, "table-one", "table-two", "merged-table"), + mkSet(topic1P0, topic2P1, topic4P0), + mkSet("table-one", "table-two", "merged-table"), + mkSet(topic2P0, topic1P1)); + final StreamsMetadata two = new StreamsMetadataImpl(hostTwo, + mkSet(globalTable, "table-two", "table-one", "merged-table"), + mkSet(topic2P0, topic1P1), + mkSet("table-three"), + mkSet(topic3P0)); + final StreamsMetadata three = new StreamsMetadataImpl(hostThree, + mkSet(globalTable, "table-three"), + Collections.singleton(topic3P0), + mkSet("table-one", "table-two", "merged-table"), + mkSet(topic1P0, topic2P1, topic4P0)); + + final Collection actual = metadataState.getAllMetadata(); + assertEquals(3, actual.size()); + assertTrue("expected " + actual + " to contain " + one, actual.contains(one)); + assertTrue("expected " + actual + " to contain " + two, actual.contains(two)); + assertTrue("expected " + actual + " to contain " + three, actual.contains(three)); + } + + @Test + public void shouldGetAllStreamsInstancesWithNoStores() { + builder.stream("topic-five").filter((key, value) -> true).to("some-other-topic"); + + final TopicPartition tp5 = new TopicPartition("topic-five", 1); + final HostInfo hostFour = new HostInfo("host-four", 8080); + hostToActivePartitions.put(hostFour, mkSet(tp5)); + + metadataState.onChange(hostToActivePartitions, Collections.emptyMap(), + cluster.withPartitions(Collections.singletonMap(tp5, new PartitionInfo("topic-five", 1, null, null, null)))); + + final StreamsMetadata expected = new StreamsMetadataImpl(hostFour, Collections.singleton(globalTable), + Collections.singleton(tp5), Collections.emptySet(), Collections.emptySet()); + final Collection actual = 
metadataState.getAllMetadata(); + assertTrue("expected " + actual + " to contain " + expected, actual.contains(expected)); + } + + @Test + public void shouldGetInstancesForStoreName() { + final StreamsMetadata one = new StreamsMetadataImpl(hostOne, + mkSet(globalTable, "table-one", "table-two", "merged-table"), + mkSet(topic1P0, topic2P1, topic4P0), + mkSet("table-one", "table-two", "merged-table"), + mkSet(topic2P0, topic1P1)); + final StreamsMetadata two = new StreamsMetadataImpl(hostTwo, + mkSet(globalTable, "table-two", "table-one", "merged-table"), + mkSet(topic2P0, topic1P1), + mkSet("table-three"), + mkSet(topic3P0)); + final Collection actual = metadataState.getAllMetadataForStore("table-one"); + final Map actualAsMap = actual.stream() + .collect(Collectors.toMap(StreamsMetadata::hostInfo, Function.identity())); + assertEquals(3, actual.size()); + assertTrue("expected " + actual + " to contain " + one, actual.contains(one)); + assertTrue("expected " + actual + " to contain " + two, actual.contains(two)); + assertTrue("expected " + hostThree + " to contain as standby", + actualAsMap.get(hostThree).standbyStateStoreNames().contains("table-one")); + } + + @Test + public void shouldThrowIfStoreNameIsNullOnGetAllInstancesWithStore() { + assertThrows(NullPointerException.class, () -> metadataState.getAllMetadataForStore(null)); + } + + @Test + public void shouldReturnEmptyCollectionOnGetAllInstancesWithStoreWhenStoreDoesntExist() { + final Collection actual = metadataState.getAllMetadataForStore("not-a-store"); + assertTrue(actual.isEmpty()); + } + + @Test + public void shouldGetInstanceWithKey() { + final TopicPartition tp4 = new TopicPartition("topic-three", 1); + hostToActivePartitions.put(hostTwo, mkSet(topic2P0, tp4)); + + metadataState.onChange(hostToActivePartitions, hostToStandbyPartitions, + cluster.withPartitions(Collections.singletonMap(tp4, new PartitionInfo("topic-three", 1, null, null, null)))); + + final KeyQueryMetadata expected = new KeyQueryMetadata(hostThree, mkSet(hostTwo), 0); + final KeyQueryMetadata actual = metadataState.getKeyQueryMetadataForKey("table-three", + "the-key", + Serdes.String().serializer()); + assertEquals(expected, actual); + } + + @Test + public void shouldGetInstanceWithKeyAndCustomPartitioner() { + final TopicPartition tp4 = new TopicPartition("topic-three", 1); + hostToActivePartitions.put(hostTwo, mkSet(topic2P0, tp4)); + + metadataState.onChange(hostToActivePartitions, hostToStandbyPartitions, + cluster.withPartitions(Collections.singletonMap(tp4, new PartitionInfo("topic-three", 1, null, null, null)))); + + final KeyQueryMetadata expected = new KeyQueryMetadata(hostTwo, Collections.emptySet(), 1); + + final KeyQueryMetadata actual = metadataState.getKeyQueryMetadataForKey("table-three", + "the-key", + partitioner); + assertEquals(expected, actual); + assertEquals(1, actual.partition()); + } + + @Test + public void shouldReturnNotAvailableWhenClusterIsEmpty() { + metadataState.onChange(Collections.emptyMap(), Collections.emptyMap(), Cluster.empty()); + final KeyQueryMetadata result = metadataState.getKeyQueryMetadataForKey("table-one", "a", Serdes.String().serializer()); + assertEquals(KeyQueryMetadata.NOT_AVAILABLE, result); + } + + @Test + public void shouldGetInstanceWithKeyWithMergedStreams() { + final TopicPartition topic2P2 = new TopicPartition("topic-two", 2); + hostToActivePartitions.put(hostTwo, mkSet(topic2P0, topic1P1, topic2P2)); + hostToStandbyPartitions.put(hostOne, mkSet(topic2P0, topic1P1, topic2P2)); + 
metadataState.onChange(hostToActivePartitions, hostToStandbyPartitions, + cluster.withPartitions(Collections.singletonMap(topic2P2, new PartitionInfo("topic-two", 2, null, null, null)))); + + final KeyQueryMetadata expected = new KeyQueryMetadata(hostTwo, mkSet(hostOne), 2); + + final KeyQueryMetadata actual = metadataState.getKeyQueryMetadataForKey("merged-table", "the-key", + (topic, key, value, numPartitions) -> 2); + + assertEquals(expected, actual); + } + + @Test + public void shouldReturnNullOnGetWithKeyWhenStoreDoesntExist() { + final KeyQueryMetadata actual = metadataState.getKeyQueryMetadataForKey("not-a-store", + "key", + Serdes.String().serializer()); + assertNull(actual); + } + + @Test + public void shouldThrowWhenKeyIsNull() { + assertThrows(NullPointerException.class, () -> metadataState.getKeyQueryMetadataForKey("table-three", null, Serdes.String().serializer())); + } + + @Test + public void shouldThrowWhenSerializerIsNull() { + assertThrows(NullPointerException.class, () -> metadataState.getKeyQueryMetadataForKey("table-three", "key", (Serializer) null)); + } + + @Test + public void shouldThrowIfStoreNameIsNull() { + assertThrows(NullPointerException.class, () -> metadataState.getKeyQueryMetadataForKey(null, "key", Serdes.String().serializer())); + } + + @SuppressWarnings("unchecked") + @Test + public void shouldThrowIfStreamPartitionerIsNull() { + assertThrows(NullPointerException.class, () -> metadataState.getKeyQueryMetadataForKey(null, "key", (StreamPartitioner) null)); + } + + @Test + public void shouldHaveGlobalStoreInAllMetadata() { + final Collection metadata = metadataState.getAllMetadataForStore(globalTable); + assertEquals(3, metadata.size()); + for (final StreamsMetadata streamsMetadata : metadata) { + assertTrue(streamsMetadata.stateStoreNames().contains(globalTable)); + } + } + + @Test + public void shouldGetLocalMetadataWithRightActiveStandbyInfo() { + assertEquals(hostOne, metadataState.getLocalMetadata().hostInfo()); + assertEquals(hostToActivePartitions.get(hostOne), metadataState.getLocalMetadata().topicPartitions()); + assertEquals(hostToStandbyPartitions.get(hostOne), metadataState.getLocalMetadata().standbyTopicPartitions()); + assertEquals(storeNames, metadataState.getLocalMetadata().stateStoreNames()); + assertEquals(storeNames.stream().filter(s -> !s.equals(globalTable)).collect(Collectors.toSet()), + metadataState.getLocalMetadata().standbyStateStoreNames()); + } + + @Test + public void shouldGetQueryMetadataForGlobalStoreWithKey() { + final KeyQueryMetadata metadata = metadataState.getKeyQueryMetadataForKey(globalTable, "key", Serdes.String().serializer()); + assertEquals(hostOne, metadata.activeHost()); + assertTrue(metadata.standbyHosts().isEmpty()); + } + + @Test + public void shouldGetAnyHostForGlobalStoreByKeyIfMyHostUnknown() { + final StreamsMetadataState streamsMetadataState = new StreamsMetadataState( + new TopologyMetadata(TopologyWrapper.getInternalTopologyBuilder(builder.build()), new DummyStreamsConfig()), + StreamsMetadataState.UNKNOWN_HOST + ); + streamsMetadataState.onChange(hostToActivePartitions, hostToStandbyPartitions, cluster); + assertNotNull(streamsMetadataState.getKeyQueryMetadataForKey(globalTable, "key", Serdes.String().serializer())); + } + + @Test + public void shouldGetQueryMetadataForGlobalStoreWithKeyAndPartitioner() { + final KeyQueryMetadata metadata = metadataState.getKeyQueryMetadataForKey(globalTable, "key", partitioner); + assertEquals(hostOne, metadata.activeHost()); + 
assertTrue(metadata.standbyHosts().isEmpty()); + } + + @Test + public void shouldGetAnyHostForGlobalStoreByKeyAndPartitionerIfMyHostUnknown() { + final StreamsMetadataState streamsMetadataState = new StreamsMetadataState( + new TopologyMetadata(TopologyWrapper.getInternalTopologyBuilder(builder.build()), new DummyStreamsConfig()), + StreamsMetadataState.UNKNOWN_HOST + ); + streamsMetadataState.onChange(hostToActivePartitions, hostToStandbyPartitions, cluster); + assertNotNull(streamsMetadataState.getKeyQueryMetadataForKey(globalTable, "key", partitioner)); + } + + @Test + public void shouldReturnAllMetadataThatRemainsValidAfterChange() { + final Collection allMetadata = metadataState.getAllMetadata(); + final Collection copy = new ArrayList<>(allMetadata); + assertFalse("invalid test", allMetadata.isEmpty()); + metadataState.onChange(Collections.emptyMap(), Collections.emptyMap(), cluster); + assertEquals("encapsulation broken", allMetadata, copy); + } + + @Test + public void shouldNotReturnMutableReferenceToInternalAllMetadataCollection() { + final Collection allMetadata = metadataState.getAllMetadata(); + assertFalse("invalid test", allMetadata.isEmpty()); + + try { + // Either this should not affect internal state of 'metadataState' + allMetadata.clear(); + } catch (final UnsupportedOperationException e) { + // Or should fail. + } + + assertFalse("encapsulation broken", metadataState.getAllMetadata().isEmpty()); + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamsPartitionAssignorTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamsPartitionAssignorTest.java new file mode 100644 index 0000000..24bf5b8 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamsPartitionAssignorTest.java @@ -0,0 +1,2197 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.processor.internals; + +import java.time.Duration; +import java.util.Properties; +import org.apache.kafka.clients.admin.Admin; +import org.apache.kafka.clients.admin.AdminClient; +import org.apache.kafka.clients.admin.ListOffsetsResult; +import org.apache.kafka.clients.admin.ListOffsetsResult.ListOffsetsResultInfo; +import org.apache.kafka.clients.admin.OffsetSpec; +import org.apache.kafka.clients.consumer.Consumer; +import org.apache.kafka.clients.consumer.ConsumerPartitionAssignor.Assignment; +import org.apache.kafka.clients.consumer.ConsumerPartitionAssignor.GroupSubscription; +import org.apache.kafka.clients.consumer.ConsumerPartitionAssignor.RebalanceProtocol; +import org.apache.kafka.clients.consumer.ConsumerPartitionAssignor.Subscription; +import org.apache.kafka.clients.consumer.OffsetAndMetadata; +import org.apache.kafka.common.Cluster; +import org.apache.kafka.common.KafkaException; +import org.apache.kafka.common.Node; +import org.apache.kafka.common.PartitionInfo; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.config.ConfigException; +import org.apache.kafka.common.errors.TimeoutException; +import org.apache.kafka.common.internals.KafkaFutureImpl; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.StreamsConfig.InternalConfig; +import org.apache.kafka.streams.TopologyWrapper; +import org.apache.kafka.streams.kstream.Grouped; +import org.apache.kafka.streams.kstream.JoinWindows; +import org.apache.kafka.streams.kstream.KStream; +import org.apache.kafka.streams.kstream.KTable; +import org.apache.kafka.streams.kstream.KeyValueMapper; +import org.apache.kafka.streams.kstream.Materialized; +import org.apache.kafka.streams.kstream.TimeWindows; +import org.apache.kafka.streams.kstream.ValueJoiner; +import org.apache.kafka.streams.kstream.internals.ConsumedInternal; +import org.apache.kafka.streams.kstream.internals.InternalStreamsBuilder; +import org.apache.kafka.streams.kstream.internals.MaterializedInternal; +import org.apache.kafka.streams.processor.TaskId; +import org.apache.kafka.streams.processor.internals.TopologyMetadata.Subtopology; +import org.apache.kafka.streams.processor.internals.assignment.AssignmentInfo; +import org.apache.kafka.streams.processor.internals.assignment.AssignorConfiguration; +import org.apache.kafka.streams.processor.internals.assignment.AssignorError; +import org.apache.kafka.streams.processor.internals.assignment.ClientState; +import org.apache.kafka.streams.processor.internals.assignment.FallbackPriorTaskAssignor; +import org.apache.kafka.streams.processor.internals.assignment.HighAvailabilityTaskAssignor; +import org.apache.kafka.streams.processor.internals.assignment.ReferenceContainer; +import org.apache.kafka.streams.processor.internals.assignment.StickyTaskAssignor; +import org.apache.kafka.streams.processor.internals.assignment.SubscriptionInfo; +import org.apache.kafka.streams.processor.internals.assignment.TaskAssignor; +import org.apache.kafka.streams.state.HostInfo; +import org.apache.kafka.test.MockApiProcessorSupplier; +import org.apache.kafka.test.MockClientSupplier; +import org.apache.kafka.test.MockInternalTopicManager; +import org.apache.kafka.test.MockKeyValueStoreBuilder; +import org.easymock.Capture; +import org.easymock.EasyMock; +import 
org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.SortedSet; +import java.util.UUID; +import java.util.stream.Collectors; + +import static java.time.Duration.ofMillis; +import static java.util.Arrays.asList; +import static java.util.Collections.emptyList; +import static java.util.Collections.emptyMap; +import static java.util.Collections.emptySet; +import static java.util.Collections.singleton; +import static java.util.Collections.singletonList; +import static java.util.Collections.singletonMap; +import static org.apache.kafka.common.utils.Utils.mkEntry; +import static org.apache.kafka.common.utils.Utils.mkMap; +import static org.apache.kafka.common.utils.Utils.mkSet; +import static org.apache.kafka.common.utils.Utils.mkSortedSet; +import static org.apache.kafka.streams.processor.internals.StreamsPartitionAssignor.assignTasksToThreads; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.EMPTY_CHANGELOG_END_OFFSETS; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.EMPTY_TASKS; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_0_0; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_0_1; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_0_2; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_0_3; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_1_0; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_1_1; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_1_2; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_1_3; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_2_0; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_2_1; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.UUID_1; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.UUID_2; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.createMockAdminClientForAssignor; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.getInfo; +import static org.apache.kafka.streams.processor.internals.assignment.StreamsAssignmentProtocolVersions.LATEST_SUPPORTED_VERSION; +import static org.easymock.EasyMock.expect; +import static org.easymock.EasyMock.mock; +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.CoreMatchers.not; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.anEmptyMap; +import static org.hamcrest.Matchers.empty; +import static org.hamcrest.Matchers.is; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotEquals; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; +import 
static org.junit.Assert.fail; + +@RunWith(value = Parameterized.class) +@SuppressWarnings("deprecation") +public class StreamsPartitionAssignorTest { + private static final String CONSUMER_1 = "consumer1"; + private static final String CONSUMER_2 = "consumer2"; + private static final String CONSUMER_3 = "consumer3"; + private static final String CONSUMER_4 = "consumer4"; + + private final Set allTopics = mkSet("topic1", "topic2"); + + private final TopicPartition t1p0 = new TopicPartition("topic1", 0); + private final TopicPartition t1p1 = new TopicPartition("topic1", 1); + private final TopicPartition t1p2 = new TopicPartition("topic1", 2); + private final TopicPartition t1p3 = new TopicPartition("topic1", 3); + private final TopicPartition t2p0 = new TopicPartition("topic2", 0); + private final TopicPartition t2p1 = new TopicPartition("topic2", 1); + private final TopicPartition t2p2 = new TopicPartition("topic2", 2); + private final TopicPartition t2p3 = new TopicPartition("topic2", 3); + private final TopicPartition t3p0 = new TopicPartition("topic3", 0); + private final TopicPartition t3p1 = new TopicPartition("topic3", 1); + private final TopicPartition t3p2 = new TopicPartition("topic3", 2); + private final TopicPartition t3p3 = new TopicPartition("topic3", 3); + + private final List infos = asList( + new PartitionInfo("topic1", 0, Node.noNode(), new Node[0], new Node[0]), + new PartitionInfo("topic1", 1, Node.noNode(), new Node[0], new Node[0]), + new PartitionInfo("topic1", 2, Node.noNode(), new Node[0], new Node[0]), + new PartitionInfo("topic2", 0, Node.noNode(), new Node[0], new Node[0]), + new PartitionInfo("topic2", 1, Node.noNode(), new Node[0], new Node[0]), + new PartitionInfo("topic2", 2, Node.noNode(), new Node[0], new Node[0]), + new PartitionInfo("topic3", 0, Node.noNode(), new Node[0], new Node[0]), + new PartitionInfo("topic3", 1, Node.noNode(), new Node[0], new Node[0]), + new PartitionInfo("topic3", 2, Node.noNode(), new Node[0], new Node[0]), + new PartitionInfo("topic3", 3, Node.noNode(), new Node[0], new Node[0]) + ); + + private final SubscriptionInfo defaultSubscriptionInfo = getInfo(UUID_1, EMPTY_TASKS, EMPTY_TASKS); + + private final Cluster metadata = new Cluster( + "cluster", + Collections.singletonList(Node.noNode()), + infos, + emptySet(), + emptySet() + ); + + private final StreamsPartitionAssignor partitionAssignor = new StreamsPartitionAssignor(); + private final MockClientSupplier mockClientSupplier = new MockClientSupplier(); + private static final String USER_END_POINT = "localhost:8080"; + private static final String OTHER_END_POINT = "other:9090"; + private static final String APPLICATION_ID = "stream-partition-assignor-test"; + + private TaskManager taskManager; + private Admin adminClient; + private InternalTopologyBuilder builder = new InternalTopologyBuilder(); + private TopologyMetadata topologyMetadata; + private StreamsMetadataState streamsMetadataState = EasyMock.createNiceMock(StreamsMetadataState.class); + private final Map subscriptions = new HashMap<>(); + private final Class taskAssignor; + + private final ReferenceContainer referenceContainer = new ReferenceContainer(); + private final MockTime time = new MockTime(); + private final byte uniqueField = 1; + + private Map configProps() { + final Map configurationMap = new HashMap<>(); + configurationMap.put(StreamsConfig.APPLICATION_ID_CONFIG, APPLICATION_ID); + configurationMap.put(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, USER_END_POINT); + referenceContainer.mainConsumer = 
mock(Consumer.class); + referenceContainer.adminClient = adminClient != null ? adminClient : mock(Admin.class); + referenceContainer.taskManager = taskManager; + referenceContainer.streamsMetadataState = streamsMetadataState; + referenceContainer.time = time; + configurationMap.put(InternalConfig.REFERENCE_CONTAINER_PARTITION_ASSIGNOR, referenceContainer); + configurationMap.put(InternalConfig.INTERNAL_TASK_ASSIGNOR_CLASS, taskAssignor.getName()); + return configurationMap; + } + + private MockInternalTopicManager configureDefault() { + createDefaultMockTaskManager(); + return configureDefaultPartitionAssignor(); + } + + // Make sure to complete setting up any mocks (such as TaskManager or AdminClient) before configuring the assignor + private MockInternalTopicManager configureDefaultPartitionAssignor() { + return configurePartitionAssignorWith(emptyMap()); + } + + // Make sure to complete setting up any mocks (such as TaskManager or AdminClient) before configuring the assignor + private MockInternalTopicManager configurePartitionAssignorWith(final Map props) { + final Map configMap = configProps(); + configMap.putAll(props); + + partitionAssignor.configure(configMap); + EasyMock.replay(taskManager, adminClient); + + topologyMetadata = new TopologyMetadata(builder, new StreamsConfig(configProps())); + return overwriteInternalTopicManagerWithMock(false); + } + + private void createDefaultMockTaskManager() { + createMockTaskManager(EMPTY_TASKS, EMPTY_TASKS); + } + + private void createMockTaskManager(final Set activeTasks, + final Set standbyTasks) { + taskManager = EasyMock.createNiceMock(TaskManager.class); + expect(taskManager.topologyMetadata()).andStubReturn(topologyMetadata); + expect(taskManager.getTaskOffsetSums()).andStubReturn(getTaskOffsetSums(activeTasks, standbyTasks)); + expect(taskManager.processId()).andStubReturn(UUID_1); + builder.setApplicationId(APPLICATION_ID); + topologyMetadata.buildAndRewriteTopology(); + } + + // If mockCreateInternalTopics is true the internal topic manager will report that it had to create all internal + // topics and we will skip the listOffsets request for these changelogs + private MockInternalTopicManager overwriteInternalTopicManagerWithMock(final boolean mockCreateInternalTopics) { + final MockInternalTopicManager mockInternalTopicManager = new MockInternalTopicManager( + time, + new StreamsConfig(configProps()), + mockClientSupplier.restoreConsumer, + mockCreateInternalTopics + ); + partitionAssignor.setInternalTopicManager(mockInternalTopicManager); + return mockInternalTopicManager; + } + + @Parameterized.Parameters(name = "task assignor = {0}") + public static Collection parameters() { + return asList( + new Object[]{HighAvailabilityTaskAssignor.class}, + new Object[]{StickyTaskAssignor.class}, + new Object[]{FallbackPriorTaskAssignor.class} + ); + } + + public StreamsPartitionAssignorTest(final Class taskAssignor) { + this.taskAssignor = taskAssignor; + adminClient = createMockAdminClientForAssignor(EMPTY_CHANGELOG_END_OFFSETS); + topologyMetadata = new TopologyMetadata(builder, new StreamsConfig(configProps())); + } + + @Test + public void shouldUseEagerRebalancingProtocol() { + createDefaultMockTaskManager(); + configurePartitionAssignorWith(Collections.singletonMap(StreamsConfig.UPGRADE_FROM_CONFIG, StreamsConfig.UPGRADE_FROM_23)); + + assertEquals(1, partitionAssignor.supportedProtocols().size()); + assertTrue(partitionAssignor.supportedProtocols().contains(RebalanceProtocol.EAGER)); + 
assertFalse(partitionAssignor.supportedProtocols().contains(RebalanceProtocol.COOPERATIVE));
+    }
+
+    @Test
+    public void shouldUseCooperativeRebalancingProtocol() {
+        configureDefault();
+
+        assertEquals(2, partitionAssignor.supportedProtocols().size());
+        assertTrue(partitionAssignor.supportedProtocols().contains(RebalanceProtocol.COOPERATIVE));
+    }
+
+    @Test
+    public void shouldProduceStickyAndBalancedAssignmentWhenNothingChanges() {
+        final List<TaskId> allTasks =
+            asList(TASK_0_0, TASK_0_1, TASK_0_2, TASK_0_3, TASK_1_0, TASK_1_1, TASK_1_2, TASK_1_3);
+
+        final Map<String, List<TaskId>> previousAssignment = mkMap(
+            mkEntry(CONSUMER_1, asList(TASK_0_0, TASK_1_1, TASK_1_3)),
+            mkEntry(CONSUMER_2, asList(TASK_0_3, TASK_1_0)),
+            mkEntry(CONSUMER_3, asList(TASK_0_1, TASK_0_2, TASK_1_2))
+        );
+
+        final ClientState state = new ClientState();
+        final SortedSet<String> consumers = mkSortedSet(CONSUMER_1, CONSUMER_2, CONSUMER_3);
+        state.addPreviousTasksAndOffsetSums(CONSUMER_1, getTaskOffsetSums(asList(TASK_0_0, TASK_1_1, TASK_1_3), EMPTY_TASKS));
+        state.addPreviousTasksAndOffsetSums(CONSUMER_2, getTaskOffsetSums(asList(TASK_0_3, TASK_1_0), EMPTY_TASKS));
+        state.addPreviousTasksAndOffsetSums(CONSUMER_3, getTaskOffsetSums(asList(TASK_0_1, TASK_0_2, TASK_1_2), EMPTY_TASKS));
+        state.initializePrevTasks(emptyMap());
+        state.computeTaskLags(UUID_1, getTaskEndOffsetSums(allTasks));
+
+        assertEquivalentAssignment(
+            previousAssignment,
+            assignTasksToThreads(
+                allTasks,
+                emptySet(),
+                consumers,
+                state
+            )
+        );
+    }
+
+    @Test
+    public void shouldProduceStickyAndBalancedAssignmentWhenNewTasksAreAdded() {
+        final List<TaskId> allTasks =
+            new ArrayList<>(asList(TASK_0_0, TASK_0_1, TASK_0_2, TASK_0_3, TASK_1_0, TASK_1_1, TASK_1_2, TASK_1_3));
+
+        final Map<String, List<TaskId>> previousAssignment = mkMap(
+            mkEntry(CONSUMER_1, new ArrayList<>(asList(TASK_0_0, TASK_1_1, TASK_1_3))),
+            mkEntry(CONSUMER_2, new ArrayList<>(asList(TASK_0_3, TASK_1_0))),
+            mkEntry(CONSUMER_3, new ArrayList<>(asList(TASK_0_1, TASK_0_2, TASK_1_2)))
+        );
+
+        final ClientState state = new ClientState();
+        final SortedSet<String> consumers = mkSortedSet(CONSUMER_1, CONSUMER_2, CONSUMER_3);
+        state.addPreviousTasksAndOffsetSums(CONSUMER_1, getTaskOffsetSums(asList(TASK_0_0, TASK_1_1, TASK_1_3), EMPTY_TASKS));
+        state.addPreviousTasksAndOffsetSums(CONSUMER_2, getTaskOffsetSums(asList(TASK_0_3, TASK_1_0), EMPTY_TASKS));
+        state.addPreviousTasksAndOffsetSums(CONSUMER_3, getTaskOffsetSums(asList(TASK_0_1, TASK_0_2, TASK_1_2), EMPTY_TASKS));
+        state.initializePrevTasks(emptyMap());
+        state.computeTaskLags(UUID_1, getTaskEndOffsetSums(allTasks));
+
+        // We should be able to add a new task without sacrificing stickiness
+        final TaskId newTask = TASK_2_0;
+        allTasks.add(newTask);
+        state.assignActiveTasks(allTasks);
+
+        final Map<String, List<TaskId>> newAssignment =
+            assignTasksToThreads(
+                allTasks,
+                emptySet(),
+                consumers,
+                state
+            );
+
+        previousAssignment.get(CONSUMER_2).add(newTask);
+        assertEquivalentAssignment(previousAssignment, newAssignment);
+    }
+
+    @Test
+    public void shouldProduceMaximallyStickyAssignmentWhenMemberLeaves() {
+        final List<TaskId> allTasks =
+            asList(TASK_0_0, TASK_0_1, TASK_0_2, TASK_0_3, TASK_1_0, TASK_1_1, TASK_1_2, TASK_1_3);
+
+        final Map<String, List<TaskId>> previousAssignment = mkMap(
+            mkEntry(CONSUMER_1, asList(TASK_0_0, TASK_1_1, TASK_1_3)),
+            mkEntry(CONSUMER_2, asList(TASK_0_3, TASK_1_0)),
+            mkEntry(CONSUMER_3, asList(TASK_0_1, TASK_0_2, TASK_1_2))
+        );
+
+        final ClientState state = new ClientState();
+        final SortedSet<String> consumers = mkSortedSet(CONSUMER_1, CONSUMER_2, CONSUMER_3);
state.addPreviousTasksAndOffsetSums(CONSUMER_1, getTaskOffsetSums(asList(TASK_0_0, TASK_1_1, TASK_1_3), EMPTY_TASKS)); + state.addPreviousTasksAndOffsetSums(CONSUMER_2, getTaskOffsetSums(asList(TASK_0_3, TASK_1_0), EMPTY_TASKS)); + state.addPreviousTasksAndOffsetSums(CONSUMER_3, getTaskOffsetSums(asList(TASK_0_1, TASK_0_2, TASK_1_2), EMPTY_TASKS)); + state.initializePrevTasks(emptyMap()); + state.computeTaskLags(UUID_1, getTaskEndOffsetSums(allTasks)); + + // Consumer 3 leaves the group + consumers.remove(CONSUMER_3); + + final Map> assignment = assignTasksToThreads( + allTasks, + emptySet(), + consumers, + state + ); + + // Each member should have all of its previous tasks reassigned plus some of consumer 3's tasks + // We should give one of its tasks to consumer 1, and two of its tasks to consumer 2 + assertTrue(assignment.get(CONSUMER_1).containsAll(previousAssignment.get(CONSUMER_1))); + assertTrue(assignment.get(CONSUMER_2).containsAll(previousAssignment.get(CONSUMER_2))); + + assertThat(assignment.get(CONSUMER_1).size(), equalTo(4)); + assertThat(assignment.get(CONSUMER_2).size(), equalTo(4)); + } + + @Test + public void shouldProduceStickyEnoughAssignmentWhenNewMemberJoins() { + final List allTasks = + asList(TASK_0_0, TASK_0_1, TASK_0_2, TASK_0_3, TASK_1_0, TASK_1_1, TASK_1_2, TASK_1_3); + + final Map> previousAssignment = mkMap( + mkEntry(CONSUMER_1, asList(TASK_0_0, TASK_1_1, TASK_1_3)), + mkEntry(CONSUMER_2, asList(TASK_0_3, TASK_1_0)), + mkEntry(CONSUMER_3, asList(TASK_0_1, TASK_0_2, TASK_1_2)) + ); + + final ClientState state = new ClientState(); + final SortedSet consumers = mkSortedSet(CONSUMER_1, CONSUMER_2, CONSUMER_3); + state.addPreviousTasksAndOffsetSums(CONSUMER_1, getTaskOffsetSums(asList(TASK_0_0, TASK_1_1, TASK_1_3), EMPTY_TASKS)); + state.addPreviousTasksAndOffsetSums(CONSUMER_2, getTaskOffsetSums(asList(TASK_0_3, TASK_1_0), EMPTY_TASKS)); + state.addPreviousTasksAndOffsetSums(CONSUMER_3, getTaskOffsetSums(asList(TASK_0_1, TASK_0_2, TASK_1_2), EMPTY_TASKS)); + + // Consumer 4 joins the group + consumers.add(CONSUMER_4); + state.addPreviousTasksAndOffsetSums(CONSUMER_4, getTaskOffsetSums(EMPTY_TASKS, EMPTY_TASKS)); + + state.initializePrevTasks(emptyMap()); + state.computeTaskLags(UUID_1, getTaskEndOffsetSums(allTasks)); + + final Map> assignment = assignTasksToThreads( + allTasks, + emptySet(), + consumers, + state + ); + + // we should move one task each from consumer 1 and consumer 3 to the new member, and none from consumer 2 + assertTrue(previousAssignment.get(CONSUMER_1).containsAll(assignment.get(CONSUMER_1))); + assertTrue(previousAssignment.get(CONSUMER_3).containsAll(assignment.get(CONSUMER_3))); + + assertTrue(assignment.get(CONSUMER_2).containsAll(previousAssignment.get(CONSUMER_2))); + + + assertThat(assignment.get(CONSUMER_1).size(), equalTo(2)); + assertThat(assignment.get(CONSUMER_2).size(), equalTo(2)); + assertThat(assignment.get(CONSUMER_3).size(), equalTo(2)); + assertThat(assignment.get(CONSUMER_4).size(), equalTo(2)); + } + + @Test + public void shouldInterleaveTasksByGroupIdDuringNewAssignment() { + final List allTasks = + asList(TASK_0_0, TASK_0_1, TASK_0_2, TASK_0_3, TASK_1_0, TASK_1_1, TASK_1_2, TASK_2_0, TASK_2_1); + + final Map> assignment = mkMap( + mkEntry(CONSUMER_1, new ArrayList<>(asList(TASK_0_0, TASK_0_3, TASK_1_2))), + mkEntry(CONSUMER_2, new ArrayList<>(asList(TASK_0_1, TASK_1_0, TASK_2_0))), + mkEntry(CONSUMER_3, new ArrayList<>(asList(TASK_0_2, TASK_1_1, TASK_2_1))) + ); + + final ClientState state = new ClientState(); + final 
SortedSet consumers = mkSortedSet(CONSUMER_1, CONSUMER_2, CONSUMER_3); + state.addPreviousTasksAndOffsetSums(CONSUMER_1, emptyMap()); + state.addPreviousTasksAndOffsetSums(CONSUMER_2, emptyMap()); + state.addPreviousTasksAndOffsetSums(CONSUMER_3, emptyMap()); + + Collections.shuffle(allTasks); + + final Map> interleavedTaskIds = + assignTasksToThreads( + allTasks, + emptySet(), + consumers, + state + ); + + assertThat(interleavedTaskIds, equalTo(assignment)); + } + + @Test + public void testEagerSubscription() { + builder.addSource(null, "source1", null, null, null, "topic1"); + builder.addSource(null, "source2", null, null, null, "topic2"); + builder.addProcessor("processor", new MockApiProcessorSupplier<>(), "source1", "source2"); + + final Set prevTasks = mkSet( + new TaskId(0, 1), new TaskId(1, 1), new TaskId(2, 1) + ); + final Set standbyTasks = mkSet( + new TaskId(0, 2), new TaskId(1, 2), new TaskId(2, 2) + ); + + createMockTaskManager(prevTasks, standbyTasks); + configurePartitionAssignorWith(Collections.singletonMap(StreamsConfig.UPGRADE_FROM_CONFIG, StreamsConfig.UPGRADE_FROM_23)); + assertThat(partitionAssignor.rebalanceProtocol(), equalTo(RebalanceProtocol.EAGER)); + + final Set topics = mkSet("topic1", "topic2"); + final Subscription subscription = new Subscription(new ArrayList<>(topics), partitionAssignor.subscriptionUserData(topics)); + + Collections.sort(subscription.topics()); + assertEquals(asList("topic1", "topic2"), subscription.topics()); + + final SubscriptionInfo info = getInfo(UUID_1, prevTasks, standbyTasks, uniqueField); + assertEquals(info, SubscriptionInfo.decode(subscription.userData())); + } + + @Test + public void testCooperativeSubscription() { + builder.addSource(null, "source1", null, null, null, "topic1"); + builder.addSource(null, "source2", null, null, null, "topic2"); + builder.addProcessor("processor", new MockApiProcessorSupplier<>(), "source1", "source2"); + + final Set prevTasks = mkSet( + new TaskId(0, 1), new TaskId(1, 1), new TaskId(2, 1)); + final Set standbyTasks = mkSet( + new TaskId(0, 1), new TaskId(1, 1), new TaskId(2, 1), + new TaskId(0, 2), new TaskId(1, 2), new TaskId(2, 2)); + + createMockTaskManager(prevTasks, standbyTasks); + configureDefaultPartitionAssignor(); + + final Set topics = mkSet("topic1", "topic2"); + final Subscription subscription = new Subscription( + new ArrayList<>(topics), partitionAssignor.subscriptionUserData(topics)); + + Collections.sort(subscription.topics()); + assertEquals(asList("topic1", "topic2"), subscription.topics()); + + final SubscriptionInfo info = getInfo(UUID_1, prevTasks, standbyTasks, uniqueField); + assertEquals(info, SubscriptionInfo.decode(subscription.userData())); + } + + @Test + public void testAssignBasic() { + builder.addSource(null, "source1", null, null, null, "topic1"); + builder.addSource(null, "source2", null, null, null, "topic2"); + builder.addProcessor("processor", new MockApiProcessorSupplier<>(), "source1", "source2"); + builder.addStateStore(new MockKeyValueStoreBuilder("store", false), "processor"); + final List topics = asList("topic1", "topic2"); + final Set allTasks = mkSet(TASK_0_0, TASK_0_1, TASK_0_2); + + final Set prevTasks10 = mkSet(TASK_0_0); + final Set prevTasks11 = mkSet(TASK_0_1); + final Set prevTasks20 = mkSet(TASK_0_2); + final Set standbyTasks10 = EMPTY_TASKS; + final Set standbyTasks11 = mkSet(TASK_0_2); + final Set standbyTasks20 = mkSet(TASK_0_0); + + createMockTaskManager(prevTasks10, standbyTasks10); + adminClient = 
createMockAdminClientForAssignor(getTopicPartitionOffsetsMap( + singletonList(APPLICATION_ID + "-store-changelog"), + singletonList(3)) + ); + configureDefaultPartitionAssignor(); + + subscriptions.put("consumer10", + new Subscription( + topics, + getInfo(UUID_1, prevTasks10, standbyTasks10).encode() + )); + subscriptions.put("consumer11", + new Subscription( + topics, + getInfo(UUID_1, prevTasks11, standbyTasks11).encode() + )); + subscriptions.put("consumer20", + new Subscription( + topics, + getInfo(UUID_2, prevTasks20, standbyTasks20).encode() + )); + + final Map assignments = partitionAssignor.assign(metadata, new GroupSubscription(subscriptions)).groupAssignment(); + + + // check the assignment + assertEquals(mkSet(mkSet(t1p0, t2p0), mkSet(t1p1, t2p1)), + mkSet(new HashSet<>(assignments.get("consumer10").partitions()), + new HashSet<>(assignments.get("consumer11").partitions()))); + assertEquals(mkSet(t1p2, t2p2), new HashSet<>(assignments.get("consumer20").partitions())); + + // check assignment info + + // the first consumer + final AssignmentInfo info10 = checkAssignment(allTopics, assignments.get("consumer10")); + final Set allActiveTasks = new HashSet<>(info10.activeTasks()); + + // the second consumer + final AssignmentInfo info11 = checkAssignment(allTopics, assignments.get("consumer11")); + allActiveTasks.addAll(info11.activeTasks()); + + assertEquals(mkSet(TASK_0_0, TASK_0_1), allActiveTasks); + + // the third consumer + final AssignmentInfo info20 = checkAssignment(allTopics, assignments.get("consumer20")); + allActiveTasks.addAll(info20.activeTasks()); + + assertEquals(3, allActiveTasks.size()); + assertEquals(allTasks, new HashSet<>(allActiveTasks)); + + assertEquals(3, allActiveTasks.size()); + assertEquals(allTasks, allActiveTasks); + } + + @Test + public void shouldAssignEvenlyAcrossConsumersOneClientMultipleThreads() { + builder.addSource(null, "source1", null, null, null, "topic1"); + builder.addSource(null, "source2", null, null, null, "topic2"); + builder.addProcessor("processor", new MockApiProcessorSupplier<>(), "source1"); + builder.addProcessor("processorII", new MockApiProcessorSupplier<>(), "source2"); + + final List localInfos = asList( + new PartitionInfo("topic1", 0, Node.noNode(), new Node[0], new Node[0]), + new PartitionInfo("topic1", 1, Node.noNode(), new Node[0], new Node[0]), + new PartitionInfo("topic1", 2, Node.noNode(), new Node[0], new Node[0]), + new PartitionInfo("topic1", 3, Node.noNode(), new Node[0], new Node[0]), + new PartitionInfo("topic2", 0, Node.noNode(), new Node[0], new Node[0]), + new PartitionInfo("topic2", 1, Node.noNode(), new Node[0], new Node[0]), + new PartitionInfo("topic2", 2, Node.noNode(), new Node[0], new Node[0]), + new PartitionInfo("topic2", 3, Node.noNode(), new Node[0], new Node[0]) + ); + + final Cluster localMetadata = new Cluster( + "cluster", + Collections.singletonList(Node.noNode()), + localInfos, + emptySet(), + emptySet()); + + final List topics = asList("topic1", "topic2"); + + configureDefault(); + + subscriptions.put("consumer10", + new Subscription( + topics, + defaultSubscriptionInfo.encode() + )); + subscriptions.put("consumer11", + new Subscription( + topics, + defaultSubscriptionInfo.encode() + )); + + final Map assignments = partitionAssignor.assign(localMetadata, new GroupSubscription(subscriptions)).groupAssignment(); + + // check assigned partitions + assertEquals(mkSet(mkSet(t2p2, t1p0, t1p2, t2p0), mkSet(t1p1, t2p1, t1p3, t2p3)), + mkSet(new HashSet<>(assignments.get("consumer10").partitions()), 
new HashSet<>(assignments.get("consumer11").partitions()))); + + // the first consumer + final AssignmentInfo info10 = AssignmentInfo.decode(assignments.get("consumer10").userData()); + + final List expectedInfo10TaskIds = asList(TASK_0_0, TASK_0_2, TASK_1_0, TASK_1_2); + assertEquals(expectedInfo10TaskIds, info10.activeTasks()); + + // the second consumer + final AssignmentInfo info11 = AssignmentInfo.decode(assignments.get("consumer11").userData()); + final List expectedInfo11TaskIds = asList(TASK_0_1, TASK_0_3, TASK_1_1, TASK_1_3); + + assertEquals(expectedInfo11TaskIds, info11.activeTasks()); + } + + @Test + public void testAssignEmptyMetadata() { + builder.addSource(null, "source1", null, null, null, "topic1"); + builder.addSource(null, "source2", null, null, null, "topic2"); + builder.addProcessor("processor", new MockApiProcessorSupplier<>(), "source1", "source2"); + final List topics = asList("topic1", "topic2"); + final Set allTasks = mkSet(TASK_0_0, TASK_0_1, TASK_0_2); + + final Set prevTasks10 = mkSet(TASK_0_0); + final Set standbyTasks10 = mkSet(TASK_0_1); + final Cluster emptyMetadata = new Cluster("cluster", Collections.singletonList(Node.noNode()), + emptySet(), + emptySet(), + emptySet()); + + createMockTaskManager(prevTasks10, standbyTasks10); + configureDefaultPartitionAssignor(); + + subscriptions.put("consumer10", + new Subscription( + topics, + getInfo(UUID_1, prevTasks10, standbyTasks10).encode() + )); + + // initially metadata is empty + Map assignments = + partitionAssignor.assign(emptyMetadata, new GroupSubscription(subscriptions)).groupAssignment(); + + // check assigned partitions + assertEquals(emptySet(), + new HashSet<>(assignments.get("consumer10").partitions())); + + // check assignment info + AssignmentInfo info10 = checkAssignment(emptySet(), assignments.get("consumer10")); + final Set allActiveTasks = new HashSet<>(info10.activeTasks()); + + assertEquals(0, allActiveTasks.size()); + + // then metadata gets populated + assignments = partitionAssignor.assign(metadata, new GroupSubscription(subscriptions)).groupAssignment(); + // check assigned partitions + assertEquals(mkSet(mkSet(t1p0, t2p0, t1p0, t2p0, t1p1, t2p1, t1p2, t2p2)), + mkSet(new HashSet<>(assignments.get("consumer10").partitions()))); + + // the first consumer + info10 = checkAssignment(allTopics, assignments.get("consumer10")); + allActiveTasks.addAll(info10.activeTasks()); + + assertEquals(3, allActiveTasks.size()); + assertEquals(allTasks, new HashSet<>(allActiveTasks)); + + assertEquals(3, allActiveTasks.size()); + assertEquals(allTasks, allActiveTasks); + } + + @Test + public void testAssignWithNewTasks() { + builder.addSource(null, "source1", null, null, null, "topic1"); + builder.addSource(null, "source2", null, null, null, "topic2"); + builder.addSource(null, "source3", null, null, null, "topic3"); + builder.addProcessor("processor", new MockApiProcessorSupplier<>(), "source1", "source2", "source3"); + final List topics = asList("topic1", "topic2", "topic3"); + final Set allTasks = mkSet(TASK_0_0, TASK_0_1, TASK_0_2, TASK_0_3); + + // assuming that previous tasks do not have topic3 + final Set prevTasks10 = mkSet(TASK_0_0); + final Set prevTasks11 = mkSet(TASK_0_1); + final Set prevTasks20 = mkSet(TASK_0_2); + + createMockTaskManager(prevTasks10, EMPTY_TASKS); + configureDefaultPartitionAssignor(); + + subscriptions.put("consumer10", + new Subscription( + topics, + getInfo(UUID_1, prevTasks10, EMPTY_TASKS).encode())); + subscriptions.put("consumer11", + new Subscription( + topics, + 
getInfo(UUID_1, prevTasks11, EMPTY_TASKS).encode())); + subscriptions.put("consumer20", + new Subscription( + topics, + getInfo(UUID_2, prevTasks20, EMPTY_TASKS).encode())); + + final Map assignments = partitionAssignor.assign(metadata, new GroupSubscription(subscriptions)).groupAssignment(); + + // check assigned partitions: since there is no previous task for topic 3 it will be assigned randomly so we cannot check exact match + // also note that previously assigned partitions / tasks may not stay on the previous host since we may assign the new task first and + // then later ones will be re-assigned to other hosts due to load balancing + AssignmentInfo info = AssignmentInfo.decode(assignments.get("consumer10").userData()); + final Set allActiveTasks = new HashSet<>(info.activeTasks()); + final Set allPartitions = new HashSet<>(assignments.get("consumer10").partitions()); + + info = AssignmentInfo.decode(assignments.get("consumer11").userData()); + allActiveTasks.addAll(info.activeTasks()); + allPartitions.addAll(assignments.get("consumer11").partitions()); + + info = AssignmentInfo.decode(assignments.get("consumer20").userData()); + allActiveTasks.addAll(info.activeTasks()); + allPartitions.addAll(assignments.get("consumer20").partitions()); + + assertEquals(allTasks, allActiveTasks); + assertEquals(mkSet(t1p0, t1p1, t1p2, t2p0, t2p1, t2p2, t3p0, t3p1, t3p2, t3p3), allPartitions); + } + + @Test + public void testAssignWithStates() { + builder.addSource(null, "source1", null, null, null, "topic1"); + builder.addSource(null, "source2", null, null, null, "topic2"); + + builder.addProcessor("processor-1", new MockApiProcessorSupplier<>(), "source1"); + builder.addStateStore(new MockKeyValueStoreBuilder("store1", false), "processor-1"); + + builder.addProcessor("processor-2", new MockApiProcessorSupplier<>(), "source2"); + builder.addStateStore(new MockKeyValueStoreBuilder("store2", false), "processor-2"); + builder.addStateStore(new MockKeyValueStoreBuilder("store3", false), "processor-2"); + + final List topics = asList("topic1", "topic2"); + + final List tasks = asList(TASK_0_0, TASK_0_1, TASK_0_2, TASK_1_0, TASK_1_1, TASK_1_2); + + adminClient = createMockAdminClientForAssignor(getTopicPartitionOffsetsMap( + asList(APPLICATION_ID + "-store1-changelog", + APPLICATION_ID + "-store2-changelog", + APPLICATION_ID + "-store3-changelog"), + asList(3, 3, 3)) + ); + configureDefault(); + + subscriptions.put("consumer10", + new Subscription(topics, defaultSubscriptionInfo.encode())); + subscriptions.put("consumer11", + new Subscription(topics, defaultSubscriptionInfo.encode())); + subscriptions.put("consumer20", + new Subscription(topics, getInfo(UUID_2, EMPTY_TASKS, EMPTY_TASKS).encode())); + + final Map assignments = partitionAssignor.assign(metadata, new GroupSubscription(subscriptions)).groupAssignment(); + + // check assigned partition size: since there is no previous task and there are two sub-topologies the assignment is random so we cannot check exact match + assertEquals(2, assignments.get("consumer10").partitions().size()); + assertEquals(2, assignments.get("consumer11").partitions().size()); + assertEquals(2, assignments.get("consumer20").partitions().size()); + + final AssignmentInfo info10 = AssignmentInfo.decode(assignments.get("consumer10").userData()); + final AssignmentInfo info11 = AssignmentInfo.decode(assignments.get("consumer11").userData()); + final AssignmentInfo info20 = AssignmentInfo.decode(assignments.get("consumer20").userData()); + + assertEquals(2, 
info10.activeTasks().size()); + assertEquals(2, info11.activeTasks().size()); + assertEquals(2, info20.activeTasks().size()); + + final Set allTasks = new HashSet<>(); + allTasks.addAll(info10.activeTasks()); + allTasks.addAll(info11.activeTasks()); + allTasks.addAll(info20.activeTasks()); + assertEquals(new HashSet<>(tasks), allTasks); + + // check tasks for state topics + final Map topicGroups = builder.topicGroups(); + + assertEquals(mkSet(TASK_0_0, TASK_0_1, TASK_0_2), tasksForState("store1", tasks, topicGroups)); + assertEquals(mkSet(TASK_1_0, TASK_1_1, TASK_1_2), tasksForState("store2", tasks, topicGroups)); + assertEquals(mkSet(TASK_1_0, TASK_1_1, TASK_1_2), tasksForState("store3", tasks, topicGroups)); + } + + private static Set tasksForState(final String storeName, + final List tasks, + final Map topicGroups) { + final String changelogTopic = ProcessorStateManager.storeChangelogTopic(APPLICATION_ID, storeName, null); + + final Set ids = new HashSet<>(); + for (final Map.Entry entry : topicGroups.entrySet()) { + final Set stateChangelogTopics = entry.getValue().stateChangelogTopics.keySet(); + + if (stateChangelogTopics.contains(changelogTopic)) { + for (final TaskId id : tasks) { + if (id.subtopology() == entry.getKey().nodeGroupId) { + ids.add(id); + } + } + } + } + return ids; + } + + @Test + public void testAssignWithStandbyReplicasAndStatelessTasks() { + builder.addSource(null, "source1", null, null, null, "topic1", "topic2"); + builder.addProcessor("processor", new MockApiProcessorSupplier<>(), "source1"); + + final List topics = asList("topic1", "topic2"); + + createMockTaskManager(mkSet(TASK_0_0), emptySet()); + configurePartitionAssignorWith(Collections.singletonMap(StreamsConfig.NUM_STANDBY_REPLICAS_CONFIG, 1)); + + subscriptions.put("consumer10", + new Subscription( + topics, + getInfo(UUID_1, mkSet(TASK_0_0), emptySet()).encode())); + subscriptions.put("consumer20", + new Subscription( + topics, + getInfo(UUID_2, mkSet(TASK_0_2), emptySet()).encode())); + + final Map assignments = + partitionAssignor.assign(metadata, new GroupSubscription(subscriptions)).groupAssignment(); + + final AssignmentInfo info10 = checkAssignment(allTopics, assignments.get("consumer10")); + assertTrue(info10.standbyTasks().isEmpty()); + + final AssignmentInfo info20 = checkAssignment(allTopics, assignments.get("consumer20")); + assertTrue(info20.standbyTasks().isEmpty()); + } + + @Test + public void testAssignWithStandbyReplicasAndLoggingDisabled() { + builder.addSource(null, "source1", null, null, null, "topic1", "topic2"); + builder.addProcessor("processor", new MockApiProcessorSupplier<>(), "source1"); + builder.addStateStore(new MockKeyValueStoreBuilder("store1", false).withLoggingDisabled(), "processor"); + + final List topics = asList("topic1", "topic2"); + + createMockTaskManager(mkSet(TASK_0_0), emptySet()); + configurePartitionAssignorWith(Collections.singletonMap(StreamsConfig.NUM_STANDBY_REPLICAS_CONFIG, 1)); + + subscriptions.put("consumer10", + new Subscription( + topics, + getInfo(UUID_1, mkSet(TASK_0_0), emptySet()).encode())); + subscriptions.put("consumer20", + new Subscription( + topics, + getInfo(UUID_2, mkSet(TASK_0_2), emptySet()).encode())); + + final Map assignments = + partitionAssignor.assign(metadata, new GroupSubscription(subscriptions)).groupAssignment(); + + final AssignmentInfo info10 = checkAssignment(allTopics, assignments.get("consumer10")); + assertTrue(info10.standbyTasks().isEmpty()); + + final AssignmentInfo info20 = checkAssignment(allTopics, 
assignments.get("consumer20")); + assertTrue(info20.standbyTasks().isEmpty()); + } + + @Test + public void testAssignWithStandbyReplicas() { + builder.addSource(null, "source1", null, null, null, "topic1"); + builder.addSource(null, "source2", null, null, null, "topic2"); + builder.addProcessor("processor", new MockApiProcessorSupplier<>(), "source1", "source2"); + builder.addStateStore(new MockKeyValueStoreBuilder("store1", false), "processor"); + + final List topics = asList("topic1", "topic2"); + final Set allTopicPartitions = topics.stream() + .map(topic -> asList(new TopicPartition(topic, 0), new TopicPartition(topic, 1), new TopicPartition(topic, 2))) + .flatMap(Collection::stream) + .collect(Collectors.toSet()); + + final Set allTasks = mkSet(TASK_0_0, TASK_0_1, TASK_0_2); + + final Set prevTasks00 = mkSet(TASK_0_0); + final Set prevTasks01 = mkSet(TASK_0_1); + final Set prevTasks02 = mkSet(TASK_0_2); + final Set standbyTasks00 = mkSet(TASK_0_0); + final Set standbyTasks01 = mkSet(TASK_0_1); + final Set standbyTasks02 = mkSet(TASK_0_2); + + createMockTaskManager(prevTasks00, standbyTasks01); + adminClient = createMockAdminClientForAssignor(getTopicPartitionOffsetsMap( + singletonList(APPLICATION_ID + "-store1-changelog"), + singletonList(3)) + ); + configurePartitionAssignorWith(Collections.singletonMap(StreamsConfig.NUM_STANDBY_REPLICAS_CONFIG, 1)); + + subscriptions.put("consumer10", + new Subscription( + topics, + getInfo(UUID_1, prevTasks00, EMPTY_TASKS, USER_END_POINT).encode())); + subscriptions.put("consumer11", + new Subscription( + topics, + getInfo(UUID_1, prevTasks01, standbyTasks02, USER_END_POINT).encode())); + subscriptions.put("consumer20", + new Subscription( + topics, + getInfo(UUID_2, prevTasks02, standbyTasks00, OTHER_END_POINT).encode())); + + final Map assignments = + partitionAssignor.assign(metadata, new GroupSubscription(subscriptions)).groupAssignment(); + + // the first consumer + final AssignmentInfo info10 = checkAssignment(allTopics, assignments.get("consumer10")); + final Set allActiveTasks = new HashSet<>(info10.activeTasks()); + final Set allStandbyTasks = new HashSet<>(info10.standbyTasks().keySet()); + + // the second consumer + final AssignmentInfo info11 = checkAssignment(allTopics, assignments.get("consumer11")); + allActiveTasks.addAll(info11.activeTasks()); + allStandbyTasks.addAll(info11.standbyTasks().keySet()); + + assertNotEquals("same processId has same set of standby tasks", info11.standbyTasks().keySet(), info10.standbyTasks().keySet()); + + // check active tasks assigned to the first client + assertEquals(mkSet(TASK_0_0, TASK_0_1), new HashSet<>(allActiveTasks)); + assertEquals(mkSet(TASK_0_2), new HashSet<>(allStandbyTasks)); + + // the third consumer + final AssignmentInfo info20 = checkAssignment(allTopics, assignments.get("consumer20")); + allActiveTasks.addAll(info20.activeTasks()); + allStandbyTasks.addAll(info20.standbyTasks().keySet()); + + // all task ids are in the active tasks and also in the standby tasks + assertEquals(3, allActiveTasks.size()); + assertEquals(allTasks, allActiveTasks); + + assertEquals(3, allStandbyTasks.size()); + assertEquals(allTasks, allStandbyTasks); + + // Check host partition assignments + final Map> partitionsByHost = info10.partitionsByHost(); + assertEquals(2, partitionsByHost.size()); + assertEquals(allTopicPartitions, partitionsByHost.values().stream() + .flatMap(Collection::stream).collect(Collectors.toSet())); + + final Map> standbyPartitionsByHost = info10.standbyPartitionByHost(); + 
assertEquals(2, standbyPartitionsByHost.size()); + assertEquals(allTopicPartitions, standbyPartitionsByHost.values().stream() + .flatMap(Collection::stream).collect(Collectors.toSet())); + + for (final HostInfo hostInfo : partitionsByHost.keySet()) { + assertTrue(Collections.disjoint(partitionsByHost.get(hostInfo), standbyPartitionsByHost.get(hostInfo))); + } + + // All consumers got the same host info + assertEquals(partitionsByHost, info11.partitionsByHost()); + assertEquals(partitionsByHost, info20.partitionsByHost()); + assertEquals(standbyPartitionsByHost, info11.standbyPartitionByHost()); + assertEquals(standbyPartitionsByHost, info20.standbyPartitionByHost()); + } + + @Test + public void testOnAssignment() { + taskManager = EasyMock.createStrictMock(TaskManager.class); + + final Map> hostState = Collections.singletonMap( + new HostInfo("localhost", 9090), + mkSet(t3p0, t3p3)); + + final Map> activeTasks = new HashMap<>(); + activeTasks.put(TASK_0_0, mkSet(t3p0)); + activeTasks.put(TASK_0_3, mkSet(t3p3)); + final Map> standbyTasks = new HashMap<>(); + standbyTasks.put(TASK_0_1, mkSet(t3p1)); + standbyTasks.put(TASK_0_2, mkSet(t3p2)); + + taskManager.handleAssignment(activeTasks, standbyTasks); + EasyMock.expectLastCall(); + streamsMetadataState = EasyMock.createStrictMock(StreamsMetadataState.class); + final Capture capturedCluster = EasyMock.newCapture(); + streamsMetadataState.onChange(EasyMock.eq(hostState), EasyMock.anyObject(), EasyMock.capture(capturedCluster)); + EasyMock.expectLastCall(); + EasyMock.replay(streamsMetadataState); + + configureDefaultPartitionAssignor(); + + final List activeTaskList = asList(TASK_0_0, TASK_0_3); + final AssignmentInfo info = new AssignmentInfo(LATEST_SUPPORTED_VERSION, activeTaskList, standbyTasks, hostState, emptyMap(), 0); + final Assignment assignment = new Assignment(asList(t3p0, t3p3), info.encode()); + + partitionAssignor.onAssignment(assignment, null); + + EasyMock.verify(streamsMetadataState); + EasyMock.verify(taskManager); + + assertEquals(singleton(t3p0.topic()), capturedCluster.getValue().topics()); + assertEquals(2, capturedCluster.getValue().partitionsForTopic(t3p0.topic()).size()); + } + + @Test + public void testAssignWithInternalTopics() { + builder.addInternalTopic("topicX", InternalTopicProperties.empty()); + builder.addSource(null, "source1", null, null, null, "topic1"); + builder.addProcessor("processor1", new MockApiProcessorSupplier<>(), "source1"); + builder.addSink("sink1", "topicX", null, null, null, "processor1"); + builder.addSource(null, "source2", null, null, null, "topicX"); + builder.addProcessor("processor2", new MockApiProcessorSupplier<>(), "source2"); + final List topics = asList("topic1", APPLICATION_ID + "-topicX"); + final Set allTasks = mkSet(TASK_0_0, TASK_0_1, TASK_0_2); + + final MockInternalTopicManager internalTopicManager = configureDefault(); + + subscriptions.put("consumer10", + new Subscription( + topics, + defaultSubscriptionInfo.encode()) + ); + partitionAssignor.assign(metadata, new GroupSubscription(subscriptions)); + + // check prepared internal topics + assertEquals(1, internalTopicManager.readyTopics.size()); + assertEquals(allTasks.size(), (long) internalTopicManager.readyTopics.get(APPLICATION_ID + "-topicX")); + } + + @Test + public void testAssignWithInternalTopicThatsSourceIsAnotherInternalTopic() { + builder.addInternalTopic("topicX", InternalTopicProperties.empty()); + builder.addSource(null, "source1", null, null, null, "topic1"); + builder.addProcessor("processor1", new 
MockApiProcessorSupplier<>(), "source1"); + builder.addSink("sink1", "topicX", null, null, null, "processor1"); + builder.addSource(null, "source2", null, null, null, "topicX"); + builder.addInternalTopic("topicZ", InternalTopicProperties.empty()); + builder.addProcessor("processor2", new MockApiProcessorSupplier<>(), "source2"); + builder.addSink("sink2", "topicZ", null, null, null, "processor2"); + builder.addSource(null, "source3", null, null, null, "topicZ"); + final List topics = asList("topic1", APPLICATION_ID + "-topicX", APPLICATION_ID + "-topicZ"); + final Set allTasks = mkSet(TASK_0_0, TASK_0_1, TASK_0_2); + + final MockInternalTopicManager internalTopicManager = configureDefault(); + + subscriptions.put("consumer10", + new Subscription( + topics, + defaultSubscriptionInfo.encode()) + ); + partitionAssignor.assign(metadata, new GroupSubscription(subscriptions)); + + // check prepared internal topics + assertEquals(2, internalTopicManager.readyTopics.size()); + assertEquals(allTasks.size(), (long) internalTopicManager.readyTopics.get(APPLICATION_ID + "-topicZ")); + } + + @Test + public void shouldGenerateTasksForAllCreatedPartitions() { + final StreamsBuilder streamsBuilder = new StreamsBuilder(); + + // KStream with 3 partitions + final KStream stream1 = streamsBuilder + .stream("topic1") + // force creation of internal repartition topic + .map((KeyValueMapper>) KeyValue::new); + + // KTable with 4 partitions + final KTable table1 = streamsBuilder + .table("topic3") + // force creation of internal repartition topic + .groupBy(KeyValue::new) + .count(); + + // joining the stream and the table + // this triggers the enforceCopartitioning() routine in the StreamsPartitionAssignor, + // forcing the stream.map to get repartitioned to a topic with four partitions. 
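+        // as a result, every internal repartition and changelog topic created for this topology is expected to get four partitions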
+        stream1.join(
+            table1,
+            (ValueJoiner<Object, Object, Void>) (value1, value2) -> null);
+
+        final String client = "client1";
+        builder = TopologyWrapper.getInternalTopologyBuilder(streamsBuilder.build());
+        topologyMetadata = new TopologyMetadata(builder, new StreamsConfig(configProps()));
+
+        adminClient = createMockAdminClientForAssignor(getTopicPartitionOffsetsMap(
+            asList(APPLICATION_ID + "-topic3-STATE-STORE-0000000002-changelog",
+                   APPLICATION_ID + "-KTABLE-AGGREGATE-STATE-STORE-0000000006-changelog"),
+            asList(4, 4))
+        );
+
+        final MockInternalTopicManager mockInternalTopicManager = configureDefault();
+
+        subscriptions.put(client,
+            new Subscription(
+                asList("topic1", "topic3"),
+                defaultSubscriptionInfo.encode())
+        );
+        final Map<String, Assignment> assignment =
+            partitionAssignor.assign(metadata, new GroupSubscription(subscriptions)).groupAssignment();
+
+        final Map<String, Integer> expectedCreatedInternalTopics = new HashMap<>();
+        expectedCreatedInternalTopics.put(APPLICATION_ID + "-KTABLE-AGGREGATE-STATE-STORE-0000000006-repartition", 4);
+        expectedCreatedInternalTopics.put(APPLICATION_ID + "-KTABLE-AGGREGATE-STATE-STORE-0000000006-changelog", 4);
+        expectedCreatedInternalTopics.put(APPLICATION_ID + "-topic3-STATE-STORE-0000000002-changelog", 4);
+        expectedCreatedInternalTopics.put(APPLICATION_ID + "-KSTREAM-MAP-0000000001-repartition", 4);
+
+        // check if all internal topics were created as expected
+        assertThat(mockInternalTopicManager.readyTopics, equalTo(expectedCreatedInternalTopics));
+
+        final List<TopicPartition> expectedAssignment = asList(
+            new TopicPartition("topic1", 0),
+            new TopicPartition("topic1", 1),
+            new TopicPartition("topic1", 2),
+            new TopicPartition("topic3", 0),
+            new TopicPartition("topic3", 1),
+            new TopicPartition("topic3", 2),
+            new TopicPartition("topic3", 3),
+            new TopicPartition(APPLICATION_ID + "-KTABLE-AGGREGATE-STATE-STORE-0000000006-repartition", 0),
+            new TopicPartition(APPLICATION_ID + "-KTABLE-AGGREGATE-STATE-STORE-0000000006-repartition", 1),
+            new TopicPartition(APPLICATION_ID + "-KTABLE-AGGREGATE-STATE-STORE-0000000006-repartition", 2),
+            new TopicPartition(APPLICATION_ID + "-KTABLE-AGGREGATE-STATE-STORE-0000000006-repartition", 3),
+            new TopicPartition(APPLICATION_ID + "-KSTREAM-MAP-0000000001-repartition", 0),
+            new TopicPartition(APPLICATION_ID + "-KSTREAM-MAP-0000000001-repartition", 1),
+            new TopicPartition(APPLICATION_ID + "-KSTREAM-MAP-0000000001-repartition", 2),
+            new TopicPartition(APPLICATION_ID + "-KSTREAM-MAP-0000000001-repartition", 3)
+        );
+
+        // check if we created a task for all expected topicPartitions.
+ assertThat(new HashSet<>(assignment.get(client).partitions()), equalTo(new HashSet<>(expectedAssignment))); + } + + @Test + public void shouldThrowTimeoutExceptionWhenCreatingRepartitionTopicsTimesOut() { + final StreamsBuilder streamsBuilder = new StreamsBuilder(); + streamsBuilder.stream("topic1").repartition(); + + final String client = "client1"; + builder = TopologyWrapper.getInternalTopologyBuilder(streamsBuilder.build()); + + createDefaultMockTaskManager(); + EasyMock.replay(taskManager); + partitionAssignor.configure(configProps()); + final MockInternalTopicManager mockInternalTopicManager = new MockInternalTopicManager( + time, + new StreamsConfig(configProps()), + mockClientSupplier.restoreConsumer, + false + ) { + @Override + public Set makeReady(final Map topics) { + throw new TimeoutException("KABOOM!"); + } + }; + partitionAssignor.setInternalTopicManager(mockInternalTopicManager); + + subscriptions.put(client, + new Subscription( + singletonList("topic1"), + defaultSubscriptionInfo.encode() + ) + ); + assertThrows(TimeoutException.class, () -> partitionAssignor.assign(metadata, new GroupSubscription(subscriptions))); + } + + @Test + public void shouldThrowTimeoutExceptionWhenCreatingChangelogTopicsTimesOut() { + final StreamsConfig config = new StreamsConfig(configProps()); + final StreamsBuilder streamsBuilder = new StreamsBuilder(); + streamsBuilder.table("topic1", Materialized.as("store")); + + final String client = "client1"; + builder = TopologyWrapper.getInternalTopologyBuilder(streamsBuilder.build()); + topologyMetadata = new TopologyMetadata(builder, config); + + createDefaultMockTaskManager(); + EasyMock.replay(taskManager); + partitionAssignor.configure(configProps()); + final MockInternalTopicManager mockInternalTopicManager = new MockInternalTopicManager( + time, + config, + mockClientSupplier.restoreConsumer, + false + ) { + @Override + public Set makeReady(final Map topics) { + if (topics.isEmpty()) { + return emptySet(); + } + throw new TimeoutException("KABOOM!"); + } + }; + partitionAssignor.setInternalTopicManager(mockInternalTopicManager); + + subscriptions.put(client, + new Subscription( + singletonList("topic1"), + defaultSubscriptionInfo.encode() + ) + ); + + assertThrows(TimeoutException.class, () -> partitionAssignor.assign(metadata, new GroupSubscription(subscriptions))); + } + + @Test + public void shouldAddUserDefinedEndPointToSubscription() { + builder.addSource(null, "source", null, null, null, "input"); + builder.addProcessor("processor", new MockApiProcessorSupplier<>(), "source"); + builder.addSink("sink", "output", null, null, null, "processor"); + + createDefaultMockTaskManager(); + configurePartitionAssignorWith(Collections.singletonMap(StreamsConfig.APPLICATION_SERVER_CONFIG, USER_END_POINT)); + + final Set topics = mkSet("input"); + final ByteBuffer userData = partitionAssignor.subscriptionUserData(topics); + final Subscription subscription = + new Subscription(new ArrayList<>(topics), userData); + final SubscriptionInfo subscriptionInfo = SubscriptionInfo.decode(subscription.userData()); + assertEquals("localhost:8080", subscriptionInfo.userEndPoint()); + } + + @Test + public void shouldMapUserEndPointToTopicPartitions() { + builder.addSource(null, "source", null, null, null, "topic1"); + builder.addProcessor("processor", new MockApiProcessorSupplier<>(), "source"); + builder.addSink("sink", "output", null, null, null, "processor"); + + final List topics = Collections.singletonList("topic1"); + + createDefaultMockTaskManager(); + 
configurePartitionAssignorWith(Collections.singletonMap(StreamsConfig.APPLICATION_SERVER_CONFIG, USER_END_POINT)); + + subscriptions.put("consumer1", + new Subscription( + topics, + getInfo(UUID_1, EMPTY_TASKS, EMPTY_TASKS, USER_END_POINT).encode()) + ); + final Map assignments = partitionAssignor.assign(metadata, new GroupSubscription(subscriptions)).groupAssignment(); + final Assignment consumerAssignment = assignments.get("consumer1"); + final AssignmentInfo assignmentInfo = AssignmentInfo.decode(consumerAssignment.userData()); + final Set topicPartitions = assignmentInfo.partitionsByHost().get(new HostInfo("localhost", 8080)); + assertEquals( + mkSet( + new TopicPartition("topic1", 0), + new TopicPartition("topic1", 1), + new TopicPartition("topic1", 2)), + topicPartitions); + } + + @Test + public void shouldThrowExceptionIfApplicationServerConfigIsNotHostPortPair() { + createDefaultMockTaskManager(); + try { + configurePartitionAssignorWith(Collections.singletonMap(StreamsConfig.APPLICATION_SERVER_CONFIG, "localhost")); + fail("expected to an exception due to invalid config"); + } catch (final ConfigException e) { + // pass + } + } + + @Test + public void shouldThrowExceptionIfApplicationServerConfigPortIsNotAnInteger() { + createDefaultMockTaskManager(); + assertThrows(ConfigException.class, () -> configurePartitionAssignorWith(Collections.singletonMap(StreamsConfig.APPLICATION_SERVER_CONFIG, "localhost:j87yhk"))); + } + + @Test + public void shouldNotLoopInfinitelyOnMissingMetadataAndShouldNotCreateRelatedTasks() { + final StreamsBuilder streamsBuilder = new StreamsBuilder(); + + final KStream stream1 = streamsBuilder + + // Task 1 (should get created): + .stream("topic1") + // force repartitioning for aggregation + .selectKey((key, value) -> null) + .groupByKey() + + // Task 2 (should get created): + // create repartitioning and changelog topic as task 1 exists + .count(Materialized.as("count")) + + // force repartitioning for join, but second join input topic unknown + // -> internal repartitioning topic should not get created + .toStream() + .map((KeyValueMapper>) (key, value) -> null); + + streamsBuilder + // Task 3 (should not get created because input topic unknown) + .stream("unknownTopic") + + // force repartitioning for join, but input topic unknown + // -> thus should not create internal repartitioning topic + .selectKey((key, value) -> null) + + // Task 4 (should not get created because input topics unknown) + // should not create any of both input repartition topics or any of both changelog topics + .join( + stream1, + (ValueJoiner) (value1, value2) -> null, + JoinWindows.of(ofMillis(0)) + ); + + final String client = "client1"; + + builder = TopologyWrapper.getInternalTopologyBuilder(streamsBuilder.build()); + + final MockInternalTopicManager mockInternalTopicManager = configureDefault(); + + subscriptions.put(client, + new Subscription( + Collections.singletonList("unknownTopic"), + defaultSubscriptionInfo.encode()) + ); + final Map assignment = partitionAssignor.assign(metadata, new GroupSubscription(subscriptions)).groupAssignment(); + + assertThat(mockInternalTopicManager.readyTopics.isEmpty(), equalTo(true)); + + assertThat(assignment.get(client).partitions().isEmpty(), equalTo(true)); + } + + @Test + public void shouldUpdateClusterMetadataAndHostInfoOnAssignment() { + final Map> initialHostState = mkMap( + mkEntry(new HostInfo("localhost", 9090), mkSet(t1p0, t1p1)), + mkEntry(new HostInfo("otherhost", 9090), mkSet(t2p0, t2p1)) + ); + + final Map> newHostState = 
mkMap( + mkEntry(new HostInfo("localhost", 9090), mkSet(t1p0, t1p1)), + mkEntry(new HostInfo("newotherhost", 9090), mkSet(t2p0, t2p1)) + ); + + streamsMetadataState = EasyMock.createStrictMock(StreamsMetadataState.class); + + streamsMetadataState.onChange(EasyMock.eq(initialHostState), EasyMock.anyObject(), EasyMock.anyObject()); + streamsMetadataState.onChange(EasyMock.eq(newHostState), EasyMock.anyObject(), EasyMock.anyObject()); + EasyMock.replay(streamsMetadataState); + + createDefaultMockTaskManager(); + configureDefaultPartitionAssignor(); + + partitionAssignor.onAssignment(createAssignment(initialHostState), null); + partitionAssignor.onAssignment(createAssignment(newHostState), null); + + EasyMock.verify(taskManager, streamsMetadataState); + } + + @Test + public void shouldTriggerImmediateRebalanceOnHostInfoChange() { + final Map> oldHostState = mkMap( + mkEntry(new HostInfo("localhost", 9090), mkSet(t1p0, t1p1)), + mkEntry(new HostInfo("otherhost", 9090), mkSet(t2p0, t2p1)) + ); + + final Map> newHostState = mkMap( + mkEntry(new HostInfo("newhost", 9090), mkSet(t1p0, t1p1)), + mkEntry(new HostInfo("otherhost", 9090), mkSet(t2p0, t2p1)) + ); + + createDefaultMockTaskManager(); + configurePartitionAssignorWith(Collections.singletonMap(StreamsConfig.APPLICATION_SERVER_CONFIG, "newhost:9090")); + + partitionAssignor.onAssignment(createAssignment(oldHostState), null); + + assertThat(referenceContainer.nextScheduledRebalanceMs.get(), is(0L)); + + partitionAssignor.onAssignment(createAssignment(newHostState), null); + + assertThat(referenceContainer.nextScheduledRebalanceMs.get(), is(Long.MAX_VALUE)); + } + + @Test + public void shouldTriggerImmediateRebalanceOnTasksRevoked() { + builder.addSource(null, "source1", null, null, null, "topic1"); + + final Set allTasks = mkSet(TASK_0_0, TASK_0_1, TASK_0_2); + final List allPartitions = asList(t1p0, t1p1, t1p2); + + subscriptions.put(CONSUMER_1, + new Subscription( + Collections.singletonList("topic1"), + getInfo(UUID_1, allTasks, EMPTY_TASKS).encode(), + allPartitions) + ); + subscriptions.put(CONSUMER_2, + new Subscription( + Collections.singletonList("topic1"), + getInfo(UUID_1, EMPTY_TASKS, allTasks).encode(), + emptyList()) + ); + + createMockTaskManager(allTasks, allTasks); + configurePartitionAssignorWith(singletonMap(StreamsConfig.ACCEPTABLE_RECOVERY_LAG_CONFIG, 0L)); + + final Map assignment = partitionAssignor.assign(metadata, new GroupSubscription(subscriptions)).groupAssignment(); + + // Verify at least one partition was revoked + assertThat(assignment.get(CONSUMER_1).partitions(), not(allPartitions)); + assertThat(assignment.get(CONSUMER_2).partitions(), equalTo(emptyList())); + + // Verify that stateless revoked tasks would not be assigned as standbys + assertThat(AssignmentInfo.decode(assignment.get(CONSUMER_2).userData()).activeTasks(), equalTo(emptyList())); + assertThat(AssignmentInfo.decode(assignment.get(CONSUMER_2).userData()).standbyTasks(), equalTo(emptyMap())); + + partitionAssignor.onAssignment(assignment.get(CONSUMER_2), null); + + assertThat(referenceContainer.nextScheduledRebalanceMs.get(), is(0L)); + } + + @Test + public void shouldNotAddStandbyTaskPartitionsToPartitionsForHost() { + final Map props = configProps(); + props.put(StreamsConfig.NUM_STANDBY_REPLICAS_CONFIG, 1); + props.put(StreamsConfig.APPLICATION_SERVER_CONFIG, USER_END_POINT); + + final StreamsBuilder streamsBuilder = new StreamsBuilder(); + streamsBuilder.stream("topic1").groupByKey().count(); + builder = 
TopologyWrapper.getInternalTopologyBuilder(streamsBuilder.build()); + topologyMetadata = new TopologyMetadata(builder, new StreamsConfig(props)); + + createDefaultMockTaskManager(); + adminClient = createMockAdminClientForAssignor(getTopicPartitionOffsetsMap( + singletonList(APPLICATION_ID + "-KSTREAM-AGGREGATE-STATE-STORE-0000000001-changelog"), + singletonList(3)) + ); + + configurePartitionAssignorWith(props); + + subscriptions.put("consumer1", + new Subscription( + Collections.singletonList("topic1"), + getInfo(UUID_1, EMPTY_TASKS, EMPTY_TASKS, USER_END_POINT).encode()) + ); + subscriptions.put("consumer2", + new Subscription( + Collections.singletonList("topic1"), + getInfo(UUID_2, EMPTY_TASKS, EMPTY_TASKS, OTHER_END_POINT).encode()) + ); + final Set allPartitions = mkSet(t1p0, t1p1, t1p2); + final Map assign = partitionAssignor.assign(metadata, new GroupSubscription(subscriptions)).groupAssignment(); + final Assignment consumer1Assignment = assign.get("consumer1"); + final AssignmentInfo assignmentInfo = AssignmentInfo.decode(consumer1Assignment.userData()); + + final Set consumer1ActivePartitions = assignmentInfo.partitionsByHost().get(new HostInfo("localhost", 8080)); + final Set consumer2ActivePartitions = assignmentInfo.partitionsByHost().get(new HostInfo("other", 9090)); + final Set consumer1StandbyPartitions = assignmentInfo.standbyPartitionByHost().get(new HostInfo("localhost", 8080)); + final Set consumer2StandbyPartitions = assignmentInfo.standbyPartitionByHost().get(new HostInfo("other", 9090)); + final HashSet allAssignedPartitions = new HashSet<>(consumer1ActivePartitions); + allAssignedPartitions.addAll(consumer2ActivePartitions); + assertThat(consumer1ActivePartitions, not(allPartitions)); + assertThat(consumer2ActivePartitions, not(allPartitions)); + assertThat(consumer1ActivePartitions, equalTo(consumer2StandbyPartitions)); + assertThat(consumer2ActivePartitions, equalTo(consumer1StandbyPartitions)); + assertThat(allAssignedPartitions, equalTo(allPartitions)); + } + + @Test + public void shouldThrowKafkaExceptionIfReferenceContainerNotConfigured() { + final Map config = configProps(); + config.remove(InternalConfig.REFERENCE_CONTAINER_PARTITION_ASSIGNOR); + + final KafkaException expected = assertThrows( + KafkaException.class, + () -> partitionAssignor.configure(config) + ); + assertThat(expected.getMessage(), equalTo("ReferenceContainer is not specified")); + } + + @Test + public void shouldThrowKafkaExceptionIfReferenceContainerConfigIsNotTaskManagerInstance() { + final Map config = configProps(); + config.put(InternalConfig.REFERENCE_CONTAINER_PARTITION_ASSIGNOR, "i am not a reference container"); + + final KafkaException expected = assertThrows( + KafkaException.class, + () -> partitionAssignor.configure(config) + ); + assertThat( + expected.getMessage(), + equalTo("java.lang.String is not an instance of org.apache.kafka.streams.processor.internals.assignment.ReferenceContainer") + ); + } + + @Test + public void shouldReturnLowestAssignmentVersionForDifferentSubscriptionVersionsV1V2() { + shouldReturnLowestAssignmentVersionForDifferentSubscriptionVersions(1, 2); + } + + @Test + public void shouldReturnLowestAssignmentVersionForDifferentSubscriptionVersionsV1V3() { + shouldReturnLowestAssignmentVersionForDifferentSubscriptionVersions(1, 3); + } + + @Test + public void shouldReturnLowestAssignmentVersionForDifferentSubscriptionVersionsV2V3() { + shouldReturnLowestAssignmentVersionForDifferentSubscriptionVersions(2, 3); + } + + private void 
shouldReturnLowestAssignmentVersionForDifferentSubscriptionVersions(final int smallestVersion, + final int otherVersion) { + subscriptions.put("consumer1", + new Subscription( + Collections.singletonList("topic1"), + getInfoForOlderVersion(smallestVersion, UUID_1, EMPTY_TASKS, EMPTY_TASKS).encode()) + ); + subscriptions.put("consumer2", + new Subscription( + Collections.singletonList("topic1"), + getInfoForOlderVersion(otherVersion, UUID_2, EMPTY_TASKS, EMPTY_TASKS).encode() + ) + ); + + configureDefault(); + + final Map assignment = partitionAssignor.assign(metadata, new GroupSubscription(subscriptions)).groupAssignment(); + + assertThat(assignment.size(), equalTo(2)); + assertThat(AssignmentInfo.decode(assignment.get("consumer1").userData()).version(), equalTo(smallestVersion)); + assertThat(AssignmentInfo.decode(assignment.get("consumer2").userData()).version(), equalTo(smallestVersion)); + } + + @Test + public void shouldDownGradeSubscriptionToVersion1() { + createDefaultMockTaskManager(); + configurePartitionAssignorWith(Collections.singletonMap(StreamsConfig.UPGRADE_FROM_CONFIG, StreamsConfig.UPGRADE_FROM_0100)); + + final Set topics = mkSet("topic1"); + final Subscription subscription = new Subscription(new ArrayList<>(topics), partitionAssignor.subscriptionUserData(topics)); + + assertThat(SubscriptionInfo.decode(subscription.userData()).version(), equalTo(1)); + } + + @Test + public void shouldDownGradeSubscriptionToVersion2For0101() { + shouldDownGradeSubscriptionToVersion2(StreamsConfig.UPGRADE_FROM_0101); + } + + @Test + public void shouldDownGradeSubscriptionToVersion2For0102() { + shouldDownGradeSubscriptionToVersion2(StreamsConfig.UPGRADE_FROM_0102); + } + + @Test + public void shouldDownGradeSubscriptionToVersion2For0110() { + shouldDownGradeSubscriptionToVersion2(StreamsConfig.UPGRADE_FROM_0110); + } + + @Test + public void shouldDownGradeSubscriptionToVersion2For10() { + shouldDownGradeSubscriptionToVersion2(StreamsConfig.UPGRADE_FROM_10); + } + + @Test + public void shouldDownGradeSubscriptionToVersion2For11() { + shouldDownGradeSubscriptionToVersion2(StreamsConfig.UPGRADE_FROM_11); + } + + private void shouldDownGradeSubscriptionToVersion2(final Object upgradeFromValue) { + createDefaultMockTaskManager(); + configurePartitionAssignorWith(Collections.singletonMap(StreamsConfig.UPGRADE_FROM_CONFIG, upgradeFromValue)); + + final Set topics = mkSet("topic1"); + final Subscription subscription = new Subscription(new ArrayList<>(topics), partitionAssignor.subscriptionUserData(topics)); + + assertThat(SubscriptionInfo.decode(subscription.userData()).version(), equalTo(2)); + } + + @Test + public void shouldReturnInterleavedAssignmentWithUnrevokedPartitionsRemovedWhenNewConsumerJoins() { + builder.addSource(null, "source1", null, null, null, "topic1"); + + final Set allTasks = mkSet(TASK_0_0, TASK_0_1, TASK_0_2); + + subscriptions.put(CONSUMER_1, + new Subscription( + Collections.singletonList("topic1"), + getInfo(UUID_1, allTasks, EMPTY_TASKS).encode(), + asList(t1p0, t1p1, t1p2)) + ); + subscriptions.put(CONSUMER_2, + new Subscription( + Collections.singletonList("topic1"), + getInfo(UUID_2, EMPTY_TASKS, EMPTY_TASKS).encode(), + emptyList()) + ); + + createMockTaskManager(allTasks, allTasks); + configureDefaultPartitionAssignor(); + + final Map assignment = partitionAssignor.assign(metadata, new GroupSubscription(subscriptions)).groupAssignment(); + + assertThat(assignment.size(), equalTo(2)); + + // The new consumer's assignment should be empty until c1 has the chance to 
revoke its partitions/tasks + assertThat(assignment.get(CONSUMER_2).partitions(), equalTo(emptyList())); + + final AssignmentInfo actualAssignment = AssignmentInfo.decode(assignment.get(CONSUMER_2).userData()); + assertThat(actualAssignment.version(), is(LATEST_SUPPORTED_VERSION)); + assertThat(actualAssignment.activeTasks(), empty()); + // Note we're not asserting anything about standbys. If the assignor gave an active task to CONSUMER_2, it would + // be converted to a standby, but we don't know whether the assignor will do that. + assertThat(actualAssignment.partitionsByHost(), anEmptyMap()); + assertThat(actualAssignment.standbyPartitionByHost(), anEmptyMap()); + assertThat(actualAssignment.errCode(), is(0)); + } + + @Test + public void shouldReturnInterleavedAssignmentForOnlyFutureInstancesDuringVersionProbing() { + builder.addSource(null, "source1", null, null, null, "topic1"); + + final Set allTasks = mkSet(TASK_0_0, TASK_0_1, TASK_0_2); + + subscriptions.put(CONSUMER_1, + new Subscription( + Collections.singletonList("topic1"), + encodeFutureSubscription(), + emptyList()) + ); + subscriptions.put(CONSUMER_2, + new Subscription( + Collections.singletonList("topic1"), + encodeFutureSubscription(), + emptyList()) + ); + + createMockTaskManager(allTasks, allTasks); + configurePartitionAssignorWith(Collections.singletonMap(StreamsConfig.NUM_STANDBY_REPLICAS_CONFIG, 1)); + + final Map assignment = + partitionAssignor.assign(metadata, new GroupSubscription(subscriptions)).groupAssignment(); + + assertThat(assignment.size(), equalTo(2)); + + assertThat(assignment.get(CONSUMER_1).partitions(), equalTo(asList(t1p0, t1p2))); + assertThat( + AssignmentInfo.decode(assignment.get(CONSUMER_1).userData()), + equalTo(new AssignmentInfo(LATEST_SUPPORTED_VERSION, asList(TASK_0_0, TASK_0_2), emptyMap(), emptyMap(), emptyMap(), 0))); + + + assertThat(assignment.get(CONSUMER_2).partitions(), equalTo(Collections.singletonList(t1p1))); + assertThat( + AssignmentInfo.decode(assignment.get(CONSUMER_2).userData()), + equalTo(new AssignmentInfo(LATEST_SUPPORTED_VERSION, Collections.singletonList(TASK_0_1), emptyMap(), emptyMap(), emptyMap(), 0))); + } + + @Test + public void shouldEncodeAssignmentErrorIfV1SubscriptionAndFutureSubscriptionIsMixed() { + shouldEncodeAssignmentErrorIfPreVersionProbingSubscriptionAndFutureSubscriptionIsMixed(1); + } + + @Test + public void shouldEncodeAssignmentErrorIfV2SubscriptionAndFutureSubscriptionIsMixed() { + shouldEncodeAssignmentErrorIfPreVersionProbingSubscriptionAndFutureSubscriptionIsMixed(2); + } + + @Test + public void shouldNotFailOnBranchedMultiLevelRepartitionConnectedTopology() { + // Test out a topology with 3 level of sub-topology as: + // 0 + // / \ + // 1 3 + // \ / + // 2 + // where each pair of the sub topology is connected by repartition topic. + // The purpose of this test is to verify the robustness of the stream partition assignor algorithm, + // especially whether it could build the repartition topic counts (step zero) with a complex topology. + // The traversal path 0 -> 1 -> 2 -> 3 hits the case where sub-topology 2 will be initialized while its + // parent 3 hasn't been initialized yet. 
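+        // the topology below is wired up manually (sources, processors, sinks, and the internal repartition topics) to reproduce that shape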
+ builder.addSource(null, "KSTREAM-SOURCE-0000000000", null, null, null, "input-stream"); + builder.addProcessor("KSTREAM-FLATMAPVALUES-0000000001", new MockApiProcessorSupplier<>(), "KSTREAM-SOURCE-0000000000"); + builder.addProcessor("KSTREAM-BRANCH-0000000002", new MockApiProcessorSupplier<>(), "KSTREAM-FLATMAPVALUES-0000000001"); + builder.addProcessor("KSTREAM-BRANCHCHILD-0000000003", new MockApiProcessorSupplier<>(), "KSTREAM-BRANCH-0000000002"); + builder.addProcessor("KSTREAM-BRANCHCHILD-0000000004", new MockApiProcessorSupplier<>(), "KSTREAM-BRANCH-0000000002"); + builder.addProcessor("KSTREAM-MAP-0000000005", new MockApiProcessorSupplier<>(), "KSTREAM-BRANCHCHILD-0000000003"); + + builder.addInternalTopic("odd_store-repartition", InternalTopicProperties.empty()); + builder.addProcessor("odd_store-repartition-filter", new MockApiProcessorSupplier<>(), "KSTREAM-MAP-0000000005"); + builder.addSink("odd_store-repartition-sink", "odd_store-repartition", null, null, null, "odd_store-repartition-filter"); + builder.addSource(null, "odd_store-repartition-source", null, null, null, "odd_store-repartition"); + builder.addProcessor("KSTREAM-REDUCE-0000000006", new MockApiProcessorSupplier<>(), "odd_store-repartition-source"); + builder.addProcessor("KTABLE-TOSTREAM-0000000010", new MockApiProcessorSupplier<>(), "KSTREAM-REDUCE-0000000006"); + builder.addProcessor("KSTREAM-PEEK-0000000011", new MockApiProcessorSupplier<>(), "KTABLE-TOSTREAM-0000000010"); + builder.addProcessor("KSTREAM-MAP-0000000012", new MockApiProcessorSupplier<>(), "KSTREAM-PEEK-0000000011"); + + builder.addInternalTopic("odd_store_2-repartition", InternalTopicProperties.empty()); + builder.addProcessor("odd_store_2-repartition-filter", new MockApiProcessorSupplier<>(), "KSTREAM-MAP-0000000012"); + builder.addSink("odd_store_2-repartition-sink", "odd_store_2-repartition", null, null, null, "odd_store_2-repartition-filter"); + builder.addSource(null, "odd_store_2-repartition-source", null, null, null, "odd_store_2-repartition"); + builder.addProcessor("KSTREAM-REDUCE-0000000013", new MockApiProcessorSupplier<>(), "odd_store_2-repartition-source"); + builder.addProcessor("KSTREAM-MAP-0000000017", new MockApiProcessorSupplier<>(), "KSTREAM-BRANCHCHILD-0000000004"); + + builder.addInternalTopic("even_store-repartition", InternalTopicProperties.empty()); + builder.addProcessor("even_store-repartition-filter", new MockApiProcessorSupplier<>(), "KSTREAM-MAP-0000000017"); + builder.addSink("even_store-repartition-sink", "even_store-repartition", null, null, null, "even_store-repartition-filter"); + builder.addSource(null, "even_store-repartition-source", null, null, null, "even_store-repartition"); + builder.addProcessor("KSTREAM-REDUCE-0000000018", new MockApiProcessorSupplier<>(), "even_store-repartition-source"); + builder.addProcessor("KTABLE-TOSTREAM-0000000022", new MockApiProcessorSupplier<>(), "KSTREAM-REDUCE-0000000018"); + builder.addProcessor("KSTREAM-PEEK-0000000023", new MockApiProcessorSupplier<>(), "KTABLE-TOSTREAM-0000000022"); + builder.addProcessor("KSTREAM-MAP-0000000024", new MockApiProcessorSupplier<>(), "KSTREAM-PEEK-0000000023"); + + builder.addInternalTopic("even_store_2-repartition", InternalTopicProperties.empty()); + builder.addProcessor("even_store_2-repartition-filter", new MockApiProcessorSupplier<>(), "KSTREAM-MAP-0000000024"); + builder.addSink("even_store_2-repartition-sink", "even_store_2-repartition", null, null, null, "even_store_2-repartition-filter"); + builder.addSource(null, 
"even_store_2-repartition-source", null, null, null, "even_store_2-repartition"); + builder.addProcessor("KSTREAM-REDUCE-0000000025", new MockApiProcessorSupplier<>(), "even_store_2-repartition-source"); + builder.addProcessor("KTABLE-JOINTHIS-0000000030", new MockApiProcessorSupplier<>(), "KSTREAM-REDUCE-0000000013"); + builder.addProcessor("KTABLE-JOINOTHER-0000000031", new MockApiProcessorSupplier<>(), "KSTREAM-REDUCE-0000000025"); + builder.addProcessor("KTABLE-MERGE-0000000029", new MockApiProcessorSupplier<>(), "KTABLE-JOINTHIS-0000000030", "KTABLE-JOINOTHER-0000000031"); + builder.addProcessor("KTABLE-TOSTREAM-0000000032", new MockApiProcessorSupplier<>(), "KTABLE-MERGE-0000000029"); + + final List topics = asList("input-stream", "test-even_store-repartition", "test-even_store_2-repartition", "test-odd_store-repartition", "test-odd_store_2-repartition"); + + configureDefault(); + + subscriptions.put("consumer10", + new Subscription( + topics, + defaultSubscriptionInfo.encode()) + ); + + final Cluster metadata = new Cluster( + "cluster", + Collections.singletonList(Node.noNode()), + Collections.singletonList(new PartitionInfo("input-stream", 0, Node.noNode(), new Node[0], new Node[0])), + emptySet(), + emptySet()); + + // This shall fail if we have bugs in the repartition topic creation due to the inconsistent order of sub-topologies. + partitionAssignor.assign(metadata, new GroupSubscription(subscriptions)); + } + + @Test + public void shouldGetAssignmentConfigs() { + createDefaultMockTaskManager(); + + final Map props = configProps(); + props.put(StreamsConfig.ACCEPTABLE_RECOVERY_LAG_CONFIG, 11); + props.put(StreamsConfig.MAX_WARMUP_REPLICAS_CONFIG, 33); + props.put(StreamsConfig.NUM_STANDBY_REPLICAS_CONFIG, 44); + props.put(StreamsConfig.PROBING_REBALANCE_INTERVAL_MS_CONFIG, 55 * 60 * 1000L); + + partitionAssignor.configure(props); + + assertThat(partitionAssignor.acceptableRecoveryLag(), equalTo(11L)); + assertThat(partitionAssignor.maxWarmupReplicas(), equalTo(33)); + assertThat(partitionAssignor.numStandbyReplicas(), equalTo(44)); + assertThat(partitionAssignor.probingRebalanceIntervalMs(), equalTo(55 * 60 * 1000L)); + } + + @Test + public void shouldGetTime() { + time.setCurrentTimeMs(Long.MAX_VALUE); + + createDefaultMockTaskManager(); + final Map props = configProps(); + final AssignorConfiguration assignorConfiguration = new AssignorConfiguration(props); + + assertThat(assignorConfiguration.referenceContainer().time.milliseconds(), equalTo(Long.MAX_VALUE)); + } + + @Test + public void shouldThrowIllegalStateExceptionIfAnyPartitionsMissingFromChangelogEndOffsets() { + final int changelogNumPartitions = 3; + builder.addSource(null, "source1", null, null, null, "topic1"); + builder.addProcessor("processor1", new MockApiProcessorSupplier<>(), "source1"); + builder.addStateStore(new MockKeyValueStoreBuilder("store1", false), "processor1"); + + adminClient = createMockAdminClientForAssignor(getTopicPartitionOffsetsMap( + singletonList(APPLICATION_ID + "-store1-changelog"), + singletonList(changelogNumPartitions - 1)) + ); + + configureDefault(); + + subscriptions.put("consumer10", + new Subscription( + singletonList("topic1"), + defaultSubscriptionInfo.encode() + )); + assertThrows(IllegalStateException.class, () -> partitionAssignor.assign(metadata, new GroupSubscription(subscriptions))); + } + + @Test + public void shouldThrowIllegalStateExceptionIfAnyTopicsMissingFromChangelogEndOffsets() { + builder.addSource(null, "source1", null, null, null, "topic1"); + 
builder.addProcessor("processor1", new MockApiProcessorSupplier<>(), "source1"); + builder.addStateStore(new MockKeyValueStoreBuilder("store1", false), "processor1"); + builder.addStateStore(new MockKeyValueStoreBuilder("store2", false), "processor1"); + + adminClient = createMockAdminClientForAssignor(getTopicPartitionOffsetsMap( + singletonList(APPLICATION_ID + "-store1-changelog"), + singletonList(3)) + ); + + configureDefault(); + + subscriptions.put("consumer10", + new Subscription( + singletonList("topic1"), + defaultSubscriptionInfo.encode() + )); + assertThrows(IllegalStateException.class, () -> partitionAssignor.assign(metadata, new GroupSubscription(subscriptions))); + } + + @Test + public void shouldSkipListOffsetsRequestForNewlyCreatedChangelogTopics() { + adminClient = EasyMock.createMock(AdminClient.class); + final ListOffsetsResult result = EasyMock.createNiceMock(ListOffsetsResult.class); + final KafkaFutureImpl> allFuture = new KafkaFutureImpl<>(); + allFuture.complete(emptyMap()); + + expect(adminClient.listOffsets(emptyMap())).andStubReturn(result); + expect(result.all()).andReturn(allFuture); + + builder.addSource(null, "source1", null, null, null, "topic1"); + builder.addProcessor("processor1", new MockApiProcessorSupplier<>(), "source1"); + builder.addStateStore(new MockKeyValueStoreBuilder("store1", false), "processor1"); + + subscriptions.put("consumer10", + new Subscription( + singletonList("topic1"), + defaultSubscriptionInfo.encode() + )); + + EasyMock.replay(result); + configureDefault(); + overwriteInternalTopicManagerWithMock(true); + + partitionAssignor.assign(metadata, new GroupSubscription(subscriptions)); + + EasyMock.verify(adminClient); + } + + @Test + public void shouldRequestEndOffsetsForPreexistingChangelogs() { + final Set changelogs = mkSet( + new TopicPartition(APPLICATION_ID + "-store-changelog", 0), + new TopicPartition(APPLICATION_ID + "-store-changelog", 1), + new TopicPartition(APPLICATION_ID + "-store-changelog", 2) + ); + adminClient = EasyMock.createMock(AdminClient.class); + final ListOffsetsResult result = EasyMock.createNiceMock(ListOffsetsResult.class); + final KafkaFutureImpl> allFuture = new KafkaFutureImpl<>(); + allFuture.complete(changelogs.stream().collect(Collectors.toMap( + tp -> tp, + tp -> { + final ListOffsetsResultInfo info = EasyMock.createNiceMock(ListOffsetsResultInfo.class); + expect(info.offset()).andStubReturn(Long.MAX_VALUE); + EasyMock.replay(info); + return info; + })) + ); + final Capture> capturedChangelogs = EasyMock.newCapture(); + + expect(adminClient.listOffsets(EasyMock.capture(capturedChangelogs))).andReturn(result).once(); + expect(result.all()).andReturn(allFuture); + + builder.addSource(null, "source1", null, null, null, "topic1"); + builder.addProcessor("processor1", new MockApiProcessorSupplier<>(), "source1"); + builder.addStateStore(new MockKeyValueStoreBuilder("store", false), "processor1"); + + subscriptions.put("consumer10", + new Subscription( + singletonList("topic1"), + defaultSubscriptionInfo.encode() + )); + + EasyMock.replay(result); + configureDefault(); + overwriteInternalTopicManagerWithMock(false); + + partitionAssignor.assign(metadata, new GroupSubscription(subscriptions)); + + EasyMock.verify(adminClient); + assertThat( + capturedChangelogs.getValue().keySet(), + equalTo(changelogs) + ); + } + + @Test + public void shouldRequestCommittedOffsetsForPreexistingSourceChangelogs() { + final Set changelogs = mkSet( + new TopicPartition("topic1", 0), + new TopicPartition("topic1", 1), + new 
TopicPartition("topic1", 2) + ); + + final StreamsBuilder streamsBuilder = new StreamsBuilder(); + streamsBuilder.table("topic1", Materialized.as("store")); + + final Properties props = new Properties(); + props.putAll(configProps()); + props.put(StreamsConfig.TOPOLOGY_OPTIMIZATION_CONFIG, StreamsConfig.OPTIMIZE); + builder = TopologyWrapper.getInternalTopologyBuilder(streamsBuilder.build(props)); + topologyMetadata = new TopologyMetadata(builder, new StreamsConfig(props)); + + subscriptions.put("consumer10", + new Subscription( + singletonList("topic1"), + defaultSubscriptionInfo.encode() + )); + + createDefaultMockTaskManager(); + configurePartitionAssignorWith(singletonMap(StreamsConfig.TOPOLOGY_OPTIMIZATION_CONFIG, StreamsConfig.OPTIMIZE)); + overwriteInternalTopicManagerWithMock(false); + + final Consumer consumerClient = referenceContainer.mainConsumer; + EasyMock.expect(consumerClient.committed(EasyMock.eq(changelogs))) + .andReturn(changelogs.stream().collect(Collectors.toMap(tp -> tp, tp -> new OffsetAndMetadata(Long.MAX_VALUE)))).once(); + + EasyMock.replay(consumerClient); + partitionAssignor.assign(metadata, new GroupSubscription(subscriptions)); + + EasyMock.verify(consumerClient); + } + + @Test + public void shouldEncodeMissingSourceTopicError() { + final Cluster emptyClusterMetadata = new Cluster( + "cluster", + Collections.singletonList(Node.noNode()), + emptyList(), + emptySet(), + emptySet() + ); + + builder.addSource(null, "source1", null, null, null, "topic1"); + configureDefault(); + + subscriptions.put("consumer", + new Subscription( + singletonList("topic"), + defaultSubscriptionInfo.encode() + )); + final Map assignments = partitionAssignor.assign(emptyClusterMetadata, new GroupSubscription(subscriptions)).groupAssignment(); + assertThat(AssignmentInfo.decode(assignments.get("consumer").userData()).errCode(), + equalTo(AssignorError.INCOMPLETE_SOURCE_TOPIC_METADATA.code())); + } + + @Test + public void testUniqueField() { + createDefaultMockTaskManager(); + configureDefaultPartitionAssignor(); + final Set topics = mkSet("input"); + + assertEquals(0, partitionAssignor.uniqueField()); + partitionAssignor.subscriptionUserData(topics); + assertEquals(1, partitionAssignor.uniqueField()); + partitionAssignor.subscriptionUserData(topics); + assertEquals(2, partitionAssignor.uniqueField()); + } + + @Test + public void testUniqueFieldOverflow() { + createDefaultMockTaskManager(); + configureDefaultPartitionAssignor(); + final Set topics = mkSet("input"); + + for (int i = 0; i < 127; i++) { + partitionAssignor.subscriptionUserData(topics); + } + assertEquals(127, partitionAssignor.uniqueField()); + partitionAssignor.subscriptionUserData(topics); + assertEquals(-128, partitionAssignor.uniqueField()); + } + + @Test + public void shouldThrowTaskAssignmentExceptionWhenUnableToResolvePartitionCount() { + builder = new CorruptedInternalTopologyBuilder(); + topologyMetadata = new TopologyMetadata(builder, new StreamsConfig(configProps())); + + final InternalStreamsBuilder streamsBuilder = new InternalStreamsBuilder(builder); + + final KStream inputTopic = streamsBuilder.stream(singleton("topic1"), new ConsumedInternal<>()); + final KTable inputTable = streamsBuilder.table("topic2", new ConsumedInternal<>(), new MaterializedInternal<>(Materialized.as("store"))); + inputTopic + .groupBy( + (k, v) -> k, + Grouped.with("GroupName", Serdes.String(), Serdes.String()) + ) + .windowedBy(TimeWindows.of(Duration.ofMinutes(10))) + .aggregate( + () -> "", + (k, v, a) -> a + k) + .leftJoin( + 
inputTable,
+                v -> v,
+                (x, y) -> x + y
+            );
+        streamsBuilder.buildAndOptimizeTopology();
+
+        configureDefault();
+
+        subscriptions.put("consumer",
+                          new Subscription(
+                              singletonList("topic"),
+                              defaultSubscriptionInfo.encode()
+                          ));
+        final Map<String, Assignment> assignments = partitionAssignor.assign(metadata, new GroupSubscription(subscriptions)).groupAssignment();
+        assertThat(AssignmentInfo.decode(assignments.get("consumer").userData()).errCode(),
+                   equalTo(AssignorError.ASSIGNMENT_ERROR.code()));
+    }
+
+    private static class CorruptedInternalTopologyBuilder extends InternalTopologyBuilder {
+        private Map<Subtopology, TopicsInfo> corruptedTopicGroups;
+
+        @Override
+        public synchronized Map<Subtopology, TopicsInfo> topicGroups() {
+            if (corruptedTopicGroups == null) {
+                corruptedTopicGroups = new HashMap<>();
+                for (final Map.Entry<Subtopology, TopicsInfo> topicGroupEntry : super.topicGroups().entrySet()) {
+                    final TopicsInfo originalInfo = topicGroupEntry.getValue();
+                    corruptedTopicGroups.put(
+                        topicGroupEntry.getKey(),
+                        new TopicsInfo(
+                            emptySet(),
+                            originalInfo.sourceTopics,
+                            originalInfo.repartitionSourceTopics,
+                            originalInfo.stateChangelogTopics
+                        ));
+                }
+            }
+
+            return corruptedTopicGroups;
+        }
+    }
+
+    private static ByteBuffer encodeFutureSubscription() {
+        final ByteBuffer buf = ByteBuffer.allocate(4 /* used version */ + 4 /* supported version */);
+        buf.putInt(LATEST_SUPPORTED_VERSION + 1);
+        buf.putInt(LATEST_SUPPORTED_VERSION + 1);
+        return buf;
+    }
+
+    private void shouldEncodeAssignmentErrorIfPreVersionProbingSubscriptionAndFutureSubscriptionIsMixed(final int oldVersion) {
+        subscriptions.put("consumer1",
+                          new Subscription(
+                              Collections.singletonList("topic1"),
+                              getInfoForOlderVersion(oldVersion, UUID_1, EMPTY_TASKS, EMPTY_TASKS).encode())
+        );
+        subscriptions.put("future-consumer",
+                          new Subscription(
+                              Collections.singletonList("topic1"),
+                              encodeFutureSubscription())
+        );
+        configureDefault();
+
+        final Map<String, Assignment> assignment = partitionAssignor.assign(metadata, new GroupSubscription(subscriptions)).groupAssignment();
+
+        assertThat(AssignmentInfo.decode(assignment.get("consumer1").userData()).errCode(), equalTo(AssignorError.ASSIGNMENT_ERROR.code()));
+        assertThat(AssignmentInfo.decode(assignment.get("future-consumer").userData()).errCode(), equalTo(AssignorError.ASSIGNMENT_ERROR.code()));
+    }
+
+    private static Assignment createAssignment(final Map<HostInfo, Set<TopicPartition>> firstHostState) {
+        final AssignmentInfo info = new AssignmentInfo(LATEST_SUPPORTED_VERSION, emptyList(), emptyMap(), firstHostState, emptyMap(), 0);
+        return new Assignment(emptyList(), info.encode());
+    }
+
+    private static AssignmentInfo checkAssignment(final Set<String> expectedTopics,
+                                                  final Assignment assignment) {
+
+        // This assumes that 1) the DefaultPartitionGrouper is used, and 2) there is only one topic group.
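+        // Under these assumptions, each assigned partition p corresponds to active task (0, p),
+        // which is what the checks below rely on.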
+ + final AssignmentInfo info = AssignmentInfo.decode(assignment.userData()); + + // check if the number of assigned partitions == the size of active task id list + assertEquals(assignment.partitions().size(), info.activeTasks().size()); + + // check if active tasks are consistent + final List activeTasks = new ArrayList<>(); + final Set activeTopics = new HashSet<>(); + for (final TopicPartition partition : assignment.partitions()) { + // since default grouper, taskid.partition == partition.partition() + activeTasks.add(new TaskId(0, partition.partition())); + activeTopics.add(partition.topic()); + } + assertEquals(activeTasks, info.activeTasks()); + + // check if active partitions cover all topics + assertEquals(expectedTopics, activeTopics); + + // check if standby tasks are consistent + final Set standbyTopics = new HashSet<>(); + for (final Map.Entry> entry : info.standbyTasks().entrySet()) { + final TaskId id = entry.getKey(); + final Set partitions = entry.getValue(); + for (final TopicPartition partition : partitions) { + // since default grouper, taskid.partition == partition.partition() + assertEquals(id.partition(), partition.partition()); + + standbyTopics.add(partition.topic()); + } + } + + if (!info.standbyTasks().isEmpty()) { + // check if standby partitions cover all topics + assertEquals(expectedTopics, standbyTopics); + } + + return info; + } + + private static void assertEquivalentAssignment(final Map> thisAssignment, + final Map> otherAssignment) { + assertEquals(thisAssignment.size(), otherAssignment.size()); + for (final Map.Entry> entry : thisAssignment.entrySet()) { + final String consumer = entry.getKey(); + assertTrue(otherAssignment.containsKey(consumer)); + + final List thisTaskList = entry.getValue(); + Collections.sort(thisTaskList); + final List otherTaskList = otherAssignment.get(consumer); + Collections.sort(otherTaskList); + + assertThat(thisTaskList, equalTo(otherTaskList)); + } + } + + /** + * Helper for building the input to createMockAdminClient in cases where we don't care about the actual offsets + * @param changelogTopics The names of all changelog topics in the topology + * @param topicsNumPartitions The number of partitions for the corresponding changelog topic, such that the number + * of partitions of the ith topic in changelogTopics is given by the ith element of topicsNumPartitions + */ + private static Map getTopicPartitionOffsetsMap(final List changelogTopics, + final List topicsNumPartitions) { + if (changelogTopics.size() != topicsNumPartitions.size()) { + throw new IllegalStateException("Passed in " + changelogTopics.size() + " changelog topic names, but " + + topicsNumPartitions.size() + " different numPartitions for the topics"); + } + final Map changelogEndOffsets = new HashMap<>(); + for (int i = 0; i < changelogTopics.size(); ++i) { + final String topic = changelogTopics.get(i); + final int numPartitions = topicsNumPartitions.get(i); + for (int partition = 0; partition < numPartitions; ++partition) { + changelogEndOffsets.put(new TopicPartition(topic, partition), Long.MAX_VALUE); + } + } + return changelogEndOffsets; + } + + private static SubscriptionInfo getInfoForOlderVersion(final int version, + final UUID processId, + final Set prevTasks, + final Set standbyTasks) { + return new SubscriptionInfo( + version, LATEST_SUPPORTED_VERSION, processId, null, getTaskOffsetSums(prevTasks, standbyTasks), (byte) 0, 0); + } + + // Stub offset sums for when we only care about the prev/standby task sets, not the actual offsets + private static 
Map getTaskOffsetSums(final Collection activeTasks, final Collection standbyTasks) { + final Map taskOffsetSums = activeTasks.stream().collect(Collectors.toMap(t -> t, t -> Task.LATEST_OFFSET)); + taskOffsetSums.putAll(standbyTasks.stream().collect(Collectors.toMap(t -> t, t -> 0L))); + return taskOffsetSums; + } + + // Stub end offsets sums for situations where we don't really care about computing exact lags + private static Map getTaskEndOffsetSums(final Collection allStatefulTasks) { + return allStatefulTasks.stream().collect(Collectors.toMap(t -> t, t -> Long.MAX_VALUE)); + } + +} diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamsProducerTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamsProducerTest.java new file mode 100644 index 0000000..420d94a --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamsProducerTest.java @@ -0,0 +1,1255 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.processor.internals; + +import org.apache.kafka.clients.consumer.ConsumerGroupMetadata; +import org.apache.kafka.clients.consumer.OffsetAndMetadata; +import org.apache.kafka.clients.producer.MockProducer; +import org.apache.kafka.clients.producer.Producer; +import org.apache.kafka.clients.producer.ProducerConfig; +import org.apache.kafka.clients.producer.ProducerRecord; +import org.apache.kafka.common.Cluster; +import org.apache.kafka.common.KafkaException; +import org.apache.kafka.common.Metric; +import org.apache.kafka.common.MetricName; +import org.apache.kafka.common.Node; +import org.apache.kafka.common.PartitionInfo; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.errors.InvalidProducerEpochException; +import org.apache.kafka.common.errors.ProducerFencedException; +import org.apache.kafka.common.errors.TimeoutException; +import org.apache.kafka.common.errors.UnknownProducerIdException; +import org.apache.kafka.common.header.internals.RecordHeaders; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.streams.KafkaClientSupplier; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.errors.StreamsException; +import org.apache.kafka.streams.errors.TaskMigratedException; +import org.apache.kafka.streams.processor.TaskId; +import org.apache.kafka.test.MockClientSupplier; +import org.junit.Before; +import org.junit.Test; + +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.UUID; + +import static org.apache.kafka.common.utils.Utils.mkEntry; +import static org.apache.kafka.common.utils.Utils.mkMap; +import static 
org.easymock.EasyMock.expect; +import static org.easymock.EasyMock.expectLastCall; +import static org.easymock.EasyMock.mock; +import static org.easymock.EasyMock.replay; +import static org.easymock.EasyMock.reset; +import static org.easymock.EasyMock.verify; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.closeTo; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.sameInstance; +import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; + +public class StreamsProducerTest { + private static final double BUFFER_POOL_WAIT_TIME = 1; + private static final double FLUSH_TME = 2; + private static final double TXN_INIT_TIME = 3; + private static final double TXN_BEGIN_TIME = 4; + private static final double TXN_SEND_OFFSETS_TIME = 5; + private static final double TXN_COMMIT_TIME = 6; + private static final double TXN_ABORT_TIME = 7; + + private final LogContext logContext = new LogContext("test "); + private final String topic = "topic"; + private final Cluster cluster = new Cluster( + "cluster", + Collections.singletonList(Node.noNode()), + Collections.singletonList(new PartitionInfo(topic, 0, Node.noNode(), new Node[0], new Node[0])), + Collections.emptySet(), + Collections.emptySet() + ); + + private final StreamsConfig nonEosConfig = new StreamsConfig(mkMap( + mkEntry(StreamsConfig.APPLICATION_ID_CONFIG, "appId"), + mkEntry(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, "dummy:1234")) + ); + + @SuppressWarnings("deprecation") + private final StreamsConfig eosAlphaConfig = new StreamsConfig(mkMap( + mkEntry(StreamsConfig.APPLICATION_ID_CONFIG, "appId"), + mkEntry(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, "dummy:1234"), + mkEntry(StreamsConfig.PROCESSING_GUARANTEE_CONFIG, StreamsConfig.EXACTLY_ONCE)) + ); + + private final StreamsConfig eosBetaConfig = new StreamsConfig(mkMap( + mkEntry(StreamsConfig.APPLICATION_ID_CONFIG, "appId"), + mkEntry(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, "dummy:1234"), + mkEntry(StreamsConfig.PROCESSING_GUARANTEE_CONFIG, StreamsConfig.EXACTLY_ONCE_V2)) + ); + + private final Time mockTime = mock(Time.class); + + final Producer mockedProducer = mock(Producer.class); + final KafkaClientSupplier clientSupplier = new MockClientSupplier() { + @Override + public Producer getProducer(final Map config) { + return mockedProducer; + } + }; + final StreamsProducer streamsProducerWithMock = new StreamsProducer( + nonEosConfig, + "threadId", + clientSupplier, + null, + null, + logContext, + mockTime + ); + final StreamsProducer eosAlphaStreamsProducerWithMock = new StreamsProducer( + eosAlphaConfig, + "threadId", + clientSupplier, + new TaskId(0, 0), + null, + logContext, + mockTime + ); + + private final MockClientSupplier mockClientSupplier = new MockClientSupplier(); + private StreamsProducer nonEosStreamsProducer; + private MockProducer nonEosMockProducer; + + private final MockClientSupplier eosAlphaMockClientSupplier = new MockClientSupplier(); + private StreamsProducer eosAlphaStreamsProducer; + private MockProducer eosAlphaMockProducer; + + private final MockClientSupplier eosBetaMockClientSupplier = new MockClientSupplier(); + private StreamsProducer eosBetaStreamsProducer; + private MockProducer eosBetaMockProducer; + + private final ProducerRecord record = + new ProducerRecord<>(topic, 0, 0L, new byte[0], new byte[0], new RecordHeaders()); + + private final Map offsetsAndMetadata = mkMap( + 
mkEntry(new TopicPartition(topic, 0), new OffsetAndMetadata(0L, null)) + ); + + @Before + public void before() { + mockClientSupplier.setCluster(cluster); + nonEosStreamsProducer = + new StreamsProducer( + nonEosConfig, + "threadId-StreamThread-0", + mockClientSupplier, + null, + null, + logContext, + mockTime + ); + nonEosMockProducer = mockClientSupplier.producers.get(0); + + eosAlphaMockClientSupplier.setCluster(cluster); + eosAlphaMockClientSupplier.setApplicationIdForProducer("appId"); + eosAlphaStreamsProducer = + new StreamsProducer( + eosAlphaConfig, + "threadId-StreamThread-0", + eosAlphaMockClientSupplier, + new TaskId(0, 0), + null, + logContext, + mockTime + ); + eosAlphaStreamsProducer.initTransaction(); + eosAlphaMockProducer = eosAlphaMockClientSupplier.producers.get(0); + + eosBetaMockClientSupplier.setCluster(cluster); + eosBetaMockClientSupplier.setApplicationIdForProducer("appId"); + eosBetaStreamsProducer = + new StreamsProducer( + eosBetaConfig, + "threadId-StreamThread-0", + eosBetaMockClientSupplier, + null, + UUID.randomUUID(), + logContext, + mockTime + ); + eosBetaStreamsProducer.initTransaction(); + eosBetaMockProducer = eosBetaMockClientSupplier.producers.get(0); + expect(mockTime.nanoseconds()).andAnswer(Time.SYSTEM::nanoseconds).anyTimes(); + replay(mockTime); + } + + + + // common tests (non-EOS and EOS-alpha/beta) + + // functional tests + + @Test + public void shouldCreateProducer() { + assertThat(mockClientSupplier.producers.size(), is(1)); + assertThat(eosAlphaMockClientSupplier.producers.size(), is(1)); + } + + @Test + public void shouldForwardCallToPartitionsFor() { + final List expectedPartitionInfo = Collections.emptyList(); + expect(mockedProducer.partitionsFor("topic")).andReturn(expectedPartitionInfo); + replay(mockedProducer); + + final List partitionInfo = streamsProducerWithMock.partitionsFor(topic); + + assertThat(partitionInfo, sameInstance(expectedPartitionInfo)); + verify(mockedProducer); + } + + @Test + public void shouldForwardCallToFlush() { + mockedProducer.flush(); + expectLastCall(); + replay(mockedProducer); + + streamsProducerWithMock.flush(); + + verify(mockedProducer); + } + + @SuppressWarnings({"unchecked", "rawtypes"}) + @Test + public void shouldForwardCallToMetrics() { + final Map metrics = new HashMap<>(); + expect(mockedProducer.metrics()).andReturn(metrics); + replay(mockedProducer); + + assertSame(metrics, streamsProducerWithMock.metrics()); + + verify(mockedProducer); + } + + @Test + public void shouldForwardCallToClose() { + mockedProducer.close(); + expectLastCall(); + replay(mockedProducer); + + streamsProducerWithMock.close(); + + verify(mockedProducer); + } + + // error handling tests + + @Test + public void shouldFailIfStreamsConfigIsNull() { + final NullPointerException thrown = assertThrows( + NullPointerException.class, + () -> new StreamsProducer( + null, + "threadId", + mockClientSupplier, + new TaskId(0, 0), + UUID.randomUUID(), + logContext, + mockTime) + ); + + assertThat(thrown.getMessage(), is("config cannot be null")); + } + + @Test + public void shouldFailIfThreadIdIsNull() { + final NullPointerException thrown = assertThrows( + NullPointerException.class, + () -> new StreamsProducer( + nonEosConfig, + null, + mockClientSupplier, + new TaskId(0, 0), + UUID.randomUUID(), + logContext, + mockTime) + ); + + assertThat(thrown.getMessage(), is("threadId cannot be null")); + } + + @Test + public void shouldFailIfClientSupplierIsNull() { + final NullPointerException thrown = assertThrows( + 
NullPointerException.class, + () -> new StreamsProducer( + nonEosConfig, + "threadId", + null, + new TaskId(0, 0), + UUID.randomUUID(), + logContext, + mockTime) + ); + + assertThat(thrown.getMessage(), is("clientSupplier cannot be null")); + } + + @Test + public void shouldFailIfLogContextIsNull() { + final NullPointerException thrown = assertThrows( + NullPointerException.class, + () -> new StreamsProducer( + nonEosConfig, + "threadId", + mockClientSupplier, + new TaskId(0, 0), + UUID.randomUUID(), + null, + mockTime) + ); + + assertThat(thrown.getMessage(), is("logContext cannot be null")); + } + + @Test + public void shouldFailOnResetProducerForAtLeastOnce() { + final IllegalStateException thrown = assertThrows( + IllegalStateException.class, + () -> nonEosStreamsProducer.resetProducer() + ); + + assertThat(thrown.getMessage(), is("Expected eos-v2 to be enabled, but the processing mode was AT_LEAST_ONCE")); + } + + @Test + public void shouldFailOnResetProducerForExactlyOnceAlpha() { + final IllegalStateException thrown = assertThrows( + IllegalStateException.class, + () -> eosAlphaStreamsProducer.resetProducer() + ); + + assertThat(thrown.getMessage(), is("Expected eos-v2 to be enabled, but the processing mode was EXACTLY_ONCE_ALPHA")); + } + + + // non-EOS tests + + // functional tests + + @Test + public void shouldNotSetTransactionIdIfEosDisabled() { + final StreamsConfig mockConfig = mock(StreamsConfig.class); + expect(mockConfig.getProducerConfigs("threadId-producer")).andReturn(mock(Map.class)); + expect(mockConfig.getString(StreamsConfig.PROCESSING_GUARANTEE_CONFIG)).andReturn(StreamsConfig.AT_LEAST_ONCE).anyTimes(); + replay(mockConfig); + + new StreamsProducer( + mockConfig, + "threadId", + mockClientSupplier, + null, + null, + logContext, + mockTime + ); + } + + @Test + public void shouldNotHaveEosEnabledIfEosDisabled() { + assertThat(nonEosStreamsProducer.eosEnabled(), is(false)); + } + + @Test + public void shouldNotInitTxIfEosDisable() { + assertThat(nonEosMockProducer.transactionInitialized(), is(false)); + } + + @Test + public void shouldNotBeginTxOnSendIfEosDisable() { + nonEosStreamsProducer.send(record, null); + assertThat(nonEosMockProducer.transactionInFlight(), is(false)); + } + + @Test + public void shouldForwardRecordOnSend() { + nonEosStreamsProducer.send(record, null); + assertThat(nonEosMockProducer.history().size(), is(1)); + assertThat(nonEosMockProducer.history().get(0), is(record)); + } + + // error handling tests + + @Test + public void shouldFailOnInitTxIfEosDisabled() { + final IllegalStateException thrown = assertThrows( + IllegalStateException.class, + nonEosStreamsProducer::initTransaction + ); + + assertThat(thrown.getMessage(), is("Exactly-once is not enabled [test]")); + } + + @Test + public void shouldThrowStreamsExceptionOnSendError() { + nonEosMockProducer.sendException = new KafkaException("KABOOM!"); + + final StreamsException thrown = assertThrows( + StreamsException.class, + () -> nonEosStreamsProducer.send(record, null) + ); + + assertThat(thrown.getCause(), is(nonEosMockProducer.sendException)); + assertThat(thrown.getMessage(), is("Error encountered trying to send record to topic topic [test]")); + } + + @Test + public void shouldFailOnSendFatal() { + nonEosMockProducer.sendException = new RuntimeException("KABOOM!"); + + final RuntimeException thrown = assertThrows( + RuntimeException.class, + () -> nonEosStreamsProducer.send(record, null) + ); + + assertThat(thrown.getMessage(), is("KABOOM!")); + } + + @Test + public void 
shouldFailOnCommitIfEosDisabled() { + final IllegalStateException thrown = assertThrows( + IllegalStateException.class, + () -> nonEosStreamsProducer.commitTransaction(null, new ConsumerGroupMetadata("appId")) + ); + + assertThat(thrown.getMessage(), is("Exactly-once is not enabled [test]")); + } + + @Test + public void shouldFailOnAbortIfEosDisabled() { + final IllegalStateException thrown = assertThrows( + IllegalStateException.class, + nonEosStreamsProducer::abortTransaction + ); + + assertThat(thrown.getMessage(), is("Exactly-once is not enabled [test]")); + } + + + // EOS tests (alpha and beta) + + // functional tests + + @Test + public void shouldEnableEosIfEosAlphaEnabled() { + assertThat(eosAlphaStreamsProducer.eosEnabled(), is(true)); + } + + @Test + public void shouldEnableEosIfEosBetaEnabled() { + assertThat(eosBetaStreamsProducer.eosEnabled(), is(true)); + } + + @SuppressWarnings("deprecation") + @Test + public void shouldSetTransactionIdUsingTaskIdIfEosAlphaEnabled() { + final Map mockMap = mock(Map.class); + expect(mockMap.put(ProducerConfig.TRANSACTIONAL_ID_CONFIG, "appId-0_0")).andReturn(null); + expect(mockMap.get(ProducerConfig.TRANSACTIONAL_ID_CONFIG)).andReturn("appId-0_0"); + + final StreamsConfig mockConfig = mock(StreamsConfig.class); + expect(mockConfig.getProducerConfigs("threadId-0_0-producer")).andReturn(mockMap); + expect(mockConfig.getString(StreamsConfig.APPLICATION_ID_CONFIG)).andReturn("appId"); + expect(mockConfig.getString(StreamsConfig.PROCESSING_GUARANTEE_CONFIG)).andReturn(StreamsConfig.EXACTLY_ONCE); + + replay(mockMap, mockConfig); + + new StreamsProducer( + mockConfig, + "threadId", + eosAlphaMockClientSupplier, + new TaskId(0, 0), + null, + logContext, + mockTime + ); + + verify(mockMap); + } + + @Test + public void shouldSetTransactionIdUsingProcessIdIfEosV2Enabled() { + final UUID processId = UUID.randomUUID(); + + final Map mockMap = mock(Map.class); + expect(mockMap.put(ProducerConfig.TRANSACTIONAL_ID_CONFIG, "appId-" + processId + "-0")).andReturn(null); + expect(mockMap.get(ProducerConfig.TRANSACTIONAL_ID_CONFIG)).andReturn("appId-" + processId); + + final StreamsConfig mockConfig = mock(StreamsConfig.class); + expect(mockConfig.getProducerConfigs("threadId-StreamThread-0-producer")).andReturn(mockMap); + expect(mockConfig.getString(StreamsConfig.APPLICATION_ID_CONFIG)).andReturn("appId"); + expect(mockConfig.getString(StreamsConfig.PROCESSING_GUARANTEE_CONFIG)).andReturn(StreamsConfig.EXACTLY_ONCE_V2).anyTimes(); + + replay(mockMap, mockConfig); + + new StreamsProducer( + mockConfig, + "threadId-StreamThread-0", + eosAlphaMockClientSupplier, + null, + processId, + logContext, + mockTime + ); + + verify(mockMap); + } + + @Test + public void shouldNotHaveEosEnabledIfEosAlphaEnable() { + assertThat(eosAlphaStreamsProducer.eosEnabled(), is(true)); + } + + @Test + public void shouldHaveEosEnabledIfEosBetaEnabled() { + assertThat(eosBetaStreamsProducer.eosEnabled(), is(true)); + } + + @Test + public void shouldInitTxOnEos() { + assertThat(eosAlphaMockProducer.transactionInitialized(), is(true)); + } + + @Test + public void shouldBeginTxOnEosSend() { + eosAlphaStreamsProducer.send(record, null); + assertThat(eosAlphaMockProducer.transactionInFlight(), is(true)); + } + + @Test + public void shouldContinueTxnSecondEosSend() { + eosAlphaStreamsProducer.send(record, null); + eosAlphaStreamsProducer.send(record, null); + assertThat(eosAlphaMockProducer.transactionInFlight(), is(true)); + assertThat(eosAlphaMockProducer.uncommittedRecords().size(), 
is(2)); + } + + @Test + public void shouldForwardRecordButNotCommitOnEosSend() { + eosAlphaStreamsProducer.send(record, null); + assertThat(eosAlphaMockProducer.transactionInFlight(), is(true)); + assertThat(eosAlphaMockProducer.history().isEmpty(), is(true)); + assertThat(eosAlphaMockProducer.uncommittedRecords().size(), is(1)); + assertThat(eosAlphaMockProducer.uncommittedRecords().get(0), is(record)); + } + + @Test + public void shouldBeginTxOnEosCommit() { + mockedProducer.initTransactions(); + mockedProducer.beginTransaction(); + mockedProducer.sendOffsetsToTransaction(offsetsAndMetadata, new ConsumerGroupMetadata("appId")); + mockedProducer.commitTransaction(); + expectLastCall(); + replay(mockedProducer); + + eosAlphaStreamsProducerWithMock.initTransaction(); + + eosAlphaStreamsProducerWithMock.commitTransaction(offsetsAndMetadata, new ConsumerGroupMetadata("appId")); + + verify(mockedProducer); + } + + @Test + public void shouldSendOffsetToTxOnEosCommit() { + eosAlphaStreamsProducer.commitTransaction(offsetsAndMetadata, new ConsumerGroupMetadata("appId")); + assertThat(eosAlphaMockProducer.sentOffsets(), is(true)); + } + + @Test + public void shouldCommitTxOnEosCommit() { + eosAlphaStreamsProducer.send(record, null); + assertThat(eosAlphaMockProducer.transactionInFlight(), is(true)); + + eosAlphaStreamsProducer.commitTransaction(offsetsAndMetadata, new ConsumerGroupMetadata("appId")); + + assertThat(eosAlphaMockProducer.transactionInFlight(), is(false)); + assertThat(eosAlphaMockProducer.uncommittedRecords().isEmpty(), is(true)); + assertThat(eosAlphaMockProducer.uncommittedOffsets().isEmpty(), is(true)); + assertThat(eosAlphaMockProducer.history().size(), is(1)); + assertThat(eosAlphaMockProducer.history().get(0), is(record)); + assertThat(eosAlphaMockProducer.consumerGroupOffsetsHistory().size(), is(1)); + assertThat(eosAlphaMockProducer.consumerGroupOffsetsHistory().get(0).get("appId"), is(offsetsAndMetadata)); + } + + @Test + public void shouldCommitTxWithApplicationIdOnEosAlphaCommit() { + mockedProducer.initTransactions(); + expectLastCall(); + mockedProducer.beginTransaction(); + expectLastCall(); + expect(mockedProducer.send(record, null)).andReturn(null); + mockedProducer.sendOffsetsToTransaction(null, new ConsumerGroupMetadata("appId")); + expectLastCall(); + mockedProducer.commitTransaction(); + expectLastCall(); + replay(mockedProducer); + + eosAlphaStreamsProducerWithMock.initTransaction(); + // call `send()` to start a transaction + eosAlphaStreamsProducerWithMock.send(record, null); + + eosAlphaStreamsProducerWithMock.commitTransaction(null, new ConsumerGroupMetadata("appId")); + + verify(mockedProducer); + } + + @Test + public void shouldCommitTxWithConsumerGroupMetadataOnEosBetaCommit() { + mockedProducer.initTransactions(); + expectLastCall(); + mockedProducer.beginTransaction(); + expectLastCall(); + expect(mockedProducer.send(record, null)).andReturn(null); + mockedProducer.sendOffsetsToTransaction(null, new ConsumerGroupMetadata("appId")); + expectLastCall(); + mockedProducer.commitTransaction(); + expectLastCall(); + replay(mockedProducer); + + final StreamsProducer streamsProducer = new StreamsProducer( + eosBetaConfig, + "threadId-StreamThread-0", + clientSupplier, + null, + UUID.randomUUID(), + logContext, + mockTime + ); + streamsProducer.initTransaction(); + // call `send()` to start a transaction + streamsProducer.send(record, null); + + streamsProducer.commitTransaction(null, new ConsumerGroupMetadata("appId")); + + verify(mockedProducer); + } + + @Test 
+ public void shouldAbortTxOnEosAbort() { + // call `send()` to start a transaction + eosAlphaStreamsProducer.send(record, null); + assertThat(eosAlphaMockProducer.transactionInFlight(), is(true)); + assertThat(eosAlphaMockProducer.uncommittedRecords().size(), is(1)); + assertThat(eosAlphaMockProducer.uncommittedRecords().get(0), is(record)); + + eosAlphaStreamsProducer.abortTransaction(); + + assertThat(eosAlphaMockProducer.transactionInFlight(), is(false)); + assertThat(eosAlphaMockProducer.uncommittedRecords().isEmpty(), is(true)); + assertThat(eosAlphaMockProducer.uncommittedOffsets().isEmpty(), is(true)); + assertThat(eosAlphaMockProducer.history().isEmpty(), is(true)); + assertThat(eosAlphaMockProducer.consumerGroupOffsetsHistory().isEmpty(), is(true)); + } + + @Test + public void shouldSkipAbortTxOnEosAbortIfNotTxInFlight() { + mockedProducer.initTransactions(); + expectLastCall(); + replay(mockedProducer); + + eosAlphaStreamsProducerWithMock.initTransaction(); + + eosAlphaStreamsProducerWithMock.abortTransaction(); + + verify(mockedProducer); + } + + // error handling tests + + @Test + public void shouldFailIfTaskIdIsNullForEosAlpha() { + final NullPointerException thrown = assertThrows( + NullPointerException.class, + () -> new StreamsProducer( + eosAlphaConfig, + "threadId", + mockClientSupplier, + null, + UUID.randomUUID(), + logContext, + mockTime) + ); + + assertThat(thrown.getMessage(), is("taskId cannot be null for exactly-once alpha")); + } + + @Test + public void shouldFailIfProcessIdNullForEosBeta() { + final NullPointerException thrown = assertThrows( + NullPointerException.class, + () -> new StreamsProducer( + eosBetaConfig, + "threadId", + mockClientSupplier, + new TaskId(0, 0), + null, + logContext, + mockTime) + ); + + assertThat(thrown.getMessage(), is("processId cannot be null for exactly-once v2")); + } + + @Test + public void shouldThrowTimeoutExceptionOnEosInitTxTimeout() { + // use `nonEosMockProducer` instead of `eosMockProducer` to avoid double Tx-Init + nonEosMockProducer.initTransactionException = new TimeoutException("KABOOM!"); + final KafkaClientSupplier clientSupplier = new MockClientSupplier() { + @Override + public Producer getProducer(final Map config) { + return nonEosMockProducer; + } + }; + + final StreamsProducer streamsProducer = new StreamsProducer( + eosAlphaConfig, + "threadId", + clientSupplier, + new TaskId(0, 0), + null, + logContext, + mockTime + ); + + final TimeoutException thrown = assertThrows( + TimeoutException.class, + streamsProducer::initTransaction + ); + + assertThat(thrown.getMessage(), is("KABOOM!")); + } + + @Test + public void shouldFailOnMaybeBeginTransactionIfTransactionsNotInitializedForExactlyOnceAlpha() { + final StreamsProducer streamsProducer = + new StreamsProducer( + eosAlphaConfig, + "threadId", + eosAlphaMockClientSupplier, + new TaskId(0, 0), + null, + logContext, + mockTime + ); + + final IllegalStateException thrown = assertThrows( + IllegalStateException.class, + () -> streamsProducer.send(record, null) + ); + + assertThat(thrown.getMessage(), is("MockProducer hasn't been initialized for transactions.")); + } + + @Test + public void shouldFailOnMaybeBeginTransactionIfTransactionsNotInitializedForExactlyOnceBeta() { + final StreamsProducer streamsProducer = + new StreamsProducer( + eosBetaConfig, + "threadId-StreamThread-0", + eosBetaMockClientSupplier, + null, + UUID.randomUUID(), + logContext, + mockTime + ); + + final IllegalStateException thrown = assertThrows( + IllegalStateException.class, + () -> 
streamsProducer.send(record, null) + ); + + assertThat(thrown.getMessage(), is("MockProducer hasn't been initialized for transactions.")); + } + + @Test + public void shouldThrowStreamsExceptionOnEosInitError() { + // use `nonEosMockProducer` instead of `eosMockProducer` to avoid double Tx-Init + nonEosMockProducer.initTransactionException = new KafkaException("KABOOM!"); + final KafkaClientSupplier clientSupplier = new MockClientSupplier() { + @Override + public Producer getProducer(final Map config) { + return nonEosMockProducer; + } + }; + + final StreamsProducer streamsProducer = new StreamsProducer( + eosAlphaConfig, + "threadId", + clientSupplier, + new TaskId(0, 0), + null, + logContext, + mockTime + ); + + final StreamsException thrown = assertThrows( + StreamsException.class, + streamsProducer::initTransaction + ); + + assertThat(thrown.getCause(), is(nonEosMockProducer.initTransactionException)); + assertThat(thrown.getMessage(), is("Error encountered trying to initialize transactions [test]")); + } + + @Test + public void shouldFailOnEosInitFatal() { + // use `nonEosMockProducer` instead of `eosMockProducer` to avoid double Tx-Init + nonEosMockProducer.initTransactionException = new RuntimeException("KABOOM!"); + final KafkaClientSupplier clientSupplier = new MockClientSupplier() { + @Override + public Producer getProducer(final Map config) { + return nonEosMockProducer; + } + }; + + final StreamsProducer streamsProducer = new StreamsProducer( + eosAlphaConfig, + "threadId", + clientSupplier, + new TaskId(0, 0), + null, + logContext, + mockTime + ); + + final RuntimeException thrown = assertThrows( + RuntimeException.class, + streamsProducer::initTransaction + ); + + assertThat(thrown.getMessage(), is("KABOOM!")); + } + + @Test + public void shouldThrowTaskMigrateExceptionOnEosBeginTxnFenced() { + eosAlphaMockProducer.fenceProducer(); + + final TaskMigratedException thrown = assertThrows( + TaskMigratedException.class, + () -> eosAlphaStreamsProducer.send(null, null) + ); + + assertThat( + thrown.getMessage(), + is("Producer got fenced trying to begin a new transaction [test];" + + " it means all tasks belonging to this thread should be migrated.") + ); + } + + @Test + public void shouldThrowTaskMigrateExceptionOnEosBeginTxnError() { + eosAlphaMockProducer.beginTransactionException = new KafkaException("KABOOM!"); + + // calling `send()` implicitly starts a new transaction + final StreamsException thrown = assertThrows( + StreamsException.class, + () -> eosAlphaStreamsProducer.send(null, null)); + + assertThat(thrown.getCause(), is(eosAlphaMockProducer.beginTransactionException)); + assertThat( + thrown.getMessage(), + is("Error encountered trying to begin a new transaction [test]") + ); + } + + @Test + public void shouldFailOnEosBeginTxnFatal() { + eosAlphaMockProducer.beginTransactionException = new RuntimeException("KABOOM!"); + + // calling `send()` implicitly starts a new transaction + final RuntimeException thrown = assertThrows( + RuntimeException.class, + () -> eosAlphaStreamsProducer.send(null, null)); + + assertThat(thrown.getMessage(), is("KABOOM!")); + } + + @Test + public void shouldThrowTaskMigratedExceptionOnEosSendProducerFenced() { + testThrowTaskMigratedExceptionOnEosSend(new ProducerFencedException("KABOOM!")); + } + + @Test + public void shouldThrowTaskMigratedExceptionOnEosSendInvalidEpoch() { + testThrowTaskMigratedExceptionOnEosSend(new InvalidProducerEpochException("KABOOM!")); + } + + private void testThrowTaskMigratedExceptionOnEosSend(final 
RuntimeException exception) { + // we need to mimic that `send()` always wraps error in a KafkaException + // cannot use `eosMockProducer.fenceProducer()` because this would already trigger in `beginTransaction()` + eosAlphaMockProducer.sendException = new KafkaException(exception); + + final TaskMigratedException thrown = assertThrows( + TaskMigratedException.class, + () -> eosAlphaStreamsProducer.send(record, null) + ); + + assertThat(thrown.getCause(), is(exception)); + assertThat( + thrown.getMessage(), + is("Producer got fenced trying to send a record [test];" + + " it means all tasks belonging to this thread should be migrated.") + ); + } + + @Test + public void shouldThrowTaskMigratedExceptionOnEosSendUnknownPid() { + final UnknownProducerIdException exception = new UnknownProducerIdException("KABOOM!"); + // we need to mimic that `send()` always wraps error in a KafkaException + eosAlphaMockProducer.sendException = new KafkaException(exception); + + final TaskMigratedException thrown = assertThrows( + TaskMigratedException.class, + () -> eosAlphaStreamsProducer.send(record, null) + ); + + assertThat(thrown.getCause(), is(exception)); + assertThat( + thrown.getMessage(), + is("Producer got fenced trying to send a record [test];" + + " it means all tasks belonging to this thread should be migrated.") + ); + } + + @Test + public void shouldThrowTaskMigrateExceptionOnEosSendOffsetProducerFenced() { + // cannot use `eosMockProducer.fenceProducer()` because this would already trigger in `beginTransaction()` + testThrowTaskMigrateExceptionOnEosSendOffset(new ProducerFencedException("KABOOM!")); + } + + @Test + public void shouldThrowTaskMigrateExceptionOnEosSendOffsetInvalidEpoch() { + // cannot use `eosMockProducer.fenceProducer()` because this would already trigger in `beginTransaction()` + testThrowTaskMigrateExceptionOnEosSendOffset(new InvalidProducerEpochException("KABOOM!")); + } + + private void testThrowTaskMigrateExceptionOnEosSendOffset(final RuntimeException exception) { + // cannot use `eosMockProducer.fenceProducer()` because this would already trigger in `beginTransaction()` + eosAlphaMockProducer.sendOffsetsToTransactionException = exception; + + final TaskMigratedException thrown = assertThrows( + TaskMigratedException.class, + // we pass in `null` to verify that `sendOffsetsToTransaction()` fails instead of `commitTransaction()` + // `sendOffsetsToTransaction()` would throw an NPE on `null` offsets + () -> eosAlphaStreamsProducer.commitTransaction(null, new ConsumerGroupMetadata("appId")) + ); + + assertThat(thrown.getCause(), is(eosAlphaMockProducer.sendOffsetsToTransactionException)); + assertThat( + thrown.getMessage(), + is("Producer got fenced trying to commit a transaction [test];" + + " it means all tasks belonging to this thread should be migrated.") + ); + } + + @Test + public void shouldThrowStreamsExceptionOnEosSendOffsetError() { + eosAlphaMockProducer.sendOffsetsToTransactionException = new KafkaException("KABOOM!"); + + final StreamsException thrown = assertThrows( + StreamsException.class, + // we pass in `null` to verify that `sendOffsetsToTransaction()` fails instead of `commitTransaction()` + // `sendOffsetsToTransaction()` would throw an NPE on `null` offsets + () -> eosAlphaStreamsProducer.commitTransaction(null, new ConsumerGroupMetadata("appId")) + ); + + assertThat(thrown.getCause(), is(eosAlphaMockProducer.sendOffsetsToTransactionException)); + assertThat( + thrown.getMessage(), + is("Error encountered trying to commit a transaction [test]") + ); 
+ } + + @Test + public void shouldFailOnEosSendOffsetFatal() { + eosAlphaMockProducer.sendOffsetsToTransactionException = new RuntimeException("KABOOM!"); + + final RuntimeException thrown = assertThrows( + RuntimeException.class, + // we pass in `null` to verify that `sendOffsetsToTransaction()` fails instead of `commitTransaction()` + // `sendOffsetsToTransaction()` would throw an NPE on `null` offsets + () -> eosAlphaStreamsProducer.commitTransaction(null, new ConsumerGroupMetadata("appId")) + ); + + assertThat(thrown.getMessage(), is("KABOOM!")); + } + + @Test + public void shouldThrowTaskMigratedExceptionOnEosCommitWithProducerFenced() { + testThrowTaskMigratedExceptionOnEos(new ProducerFencedException("KABOOM!")); + } + + @Test + public void shouldThrowTaskMigratedExceptionOnEosCommitWithInvalidEpoch() { + testThrowTaskMigratedExceptionOnEos(new InvalidProducerEpochException("KABOOM!")); + } + + private void testThrowTaskMigratedExceptionOnEos(final RuntimeException exception) { + // cannot use `eosMockProducer.fenceProducer()` because this would already trigger in `beginTransaction()` + eosAlphaMockProducer.commitTransactionException = exception; + + final TaskMigratedException thrown = assertThrows( + TaskMigratedException.class, + () -> eosAlphaStreamsProducer.commitTransaction(offsetsAndMetadata, new ConsumerGroupMetadata("appId")) + ); + + assertThat(eosAlphaMockProducer.sentOffsets(), is(true)); + assertThat(thrown.getCause(), is(eosAlphaMockProducer.commitTransactionException)); + assertThat( + thrown.getMessage(), + is("Producer got fenced trying to commit a transaction [test];" + + " it means all tasks belonging to this thread should be migrated.") + ); + } + + @Test + public void shouldThrowStreamsExceptionOnEosCommitTxError() { + eosAlphaMockProducer.commitTransactionException = new KafkaException("KABOOM!"); + + final StreamsException thrown = assertThrows( + StreamsException.class, + () -> eosAlphaStreamsProducer.commitTransaction(offsetsAndMetadata, new ConsumerGroupMetadata("appId")) + ); + + assertThat(eosAlphaMockProducer.sentOffsets(), is(true)); + assertThat(thrown.getCause(), is(eosAlphaMockProducer.commitTransactionException)); + assertThat( + thrown.getMessage(), + is("Error encountered trying to commit a transaction [test]") + ); + } + + @Test + public void shouldFailOnEosCommitTxFatal() { + eosAlphaMockProducer.commitTransactionException = new RuntimeException("KABOOM!"); + + final RuntimeException thrown = assertThrows( + RuntimeException.class, + () -> eosAlphaStreamsProducer.commitTransaction(offsetsAndMetadata, new ConsumerGroupMetadata("appId")) + ); + + assertThat(eosAlphaMockProducer.sentOffsets(), is(true)); + assertThat(thrown.getMessage(), is("KABOOM!")); + } + + @Test + public void shouldSwallowExceptionOnEosAbortTxProducerFenced() { + testSwallowExceptionOnEosAbortTx(new ProducerFencedException("KABOOM!")); + } + + @Test + public void shouldSwallowExceptionOnEosAbortTxInvalidEpoch() { + testSwallowExceptionOnEosAbortTx(new InvalidProducerEpochException("KABOOM!")); + } + + private void testSwallowExceptionOnEosAbortTx(final RuntimeException exception) { + mockedProducer.initTransactions(); + mockedProducer.beginTransaction(); + expect(mockedProducer.send(record, null)).andReturn(null); + mockedProducer.abortTransaction(); + expectLastCall().andThrow(exception); + replay(mockedProducer); + + eosAlphaStreamsProducerWithMock.initTransaction(); + // call `send()` to start a transaction + eosAlphaStreamsProducerWithMock.send(record, null); + + 
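+        // `abortTransaction()` should swallow the ProducerFencedException / InvalidProducerEpochException
+        // thrown by the mocked producer above, so no exception is expected here.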
eosAlphaStreamsProducerWithMock.abortTransaction(); + + verify(mockedProducer); + } + + @Test + public void shouldThrowStreamsExceptionOnEosAbortTxError() { + eosAlphaMockProducer.abortTransactionException = new KafkaException("KABOOM!"); + // call `send()` to start a transaction + eosAlphaStreamsProducer.send(record, null); + + final StreamsException thrown = assertThrows(StreamsException.class, eosAlphaStreamsProducer::abortTransaction); + + assertThat(thrown.getCause(), is(eosAlphaMockProducer.abortTransactionException)); + assertThat( + thrown.getMessage(), + is("Error encounter trying to abort a transaction [test]") + ); + } + + @Test + public void shouldFailOnEosAbortTxFatal() { + eosAlphaMockProducer.abortTransactionException = new RuntimeException("KABOOM!"); + // call `send()` to start a transaction + eosAlphaStreamsProducer.send(record, null); + + final RuntimeException thrown = assertThrows(RuntimeException.class, eosAlphaStreamsProducer::abortTransaction); + + assertThat(thrown.getMessage(), is("KABOOM!")); + } + + + // EOS beta test + + // functional tests + + @Test + public void shouldCloseExistingProducerOnResetProducer() { + eosBetaStreamsProducer.resetProducer(); + + assertTrue(eosBetaMockProducer.closed()); + } + + @Test + public void shouldSetNewProducerOnResetProducer() { + eosBetaStreamsProducer.resetProducer(); + + assertThat(eosBetaMockClientSupplier.producers.size(), is(2)); + assertThat(eosBetaStreamsProducer.kafkaProducer(), is(eosBetaMockClientSupplier.producers.get(1))); + } + + @Test + public void shouldResetTransactionInitializedOnResetProducer() { + final StreamsProducer streamsProducer = new StreamsProducer( + eosBetaConfig, + "threadId-StreamThread-0", + clientSupplier, + null, + UUID.randomUUID(), + logContext, + mockTime + ); + streamsProducer.initTransaction(); + + reset(mockedProducer); + mockedProducer.close(); + mockedProducer.initTransactions(); + expectLastCall(); + expect(mockedProducer.metrics()).andReturn(Collections.emptyMap()).anyTimes(); + replay(mockedProducer); + + streamsProducer.resetProducer(); + streamsProducer.initTransaction(); + + verify(mockedProducer); + } + + @Test + public void shouldComputeTotalBlockedTime() { + setProducerMetrics( + nonEosMockProducer, + BUFFER_POOL_WAIT_TIME, + FLUSH_TME, + TXN_INIT_TIME, + TXN_BEGIN_TIME, + TXN_SEND_OFFSETS_TIME, + TXN_COMMIT_TIME, + TXN_ABORT_TIME + ); + + final double expectedTotalBlocked = BUFFER_POOL_WAIT_TIME + FLUSH_TME + TXN_INIT_TIME + + TXN_BEGIN_TIME + TXN_SEND_OFFSETS_TIME + TXN_COMMIT_TIME + TXN_ABORT_TIME; + assertThat(nonEosStreamsProducer.totalBlockedTime(), closeTo(expectedTotalBlocked, 0.01)); + } + + @Test + public void shouldComputeTotalBlockedTimeAfterReset() { + setProducerMetrics( + eosBetaMockProducer, + BUFFER_POOL_WAIT_TIME, + FLUSH_TME, + TXN_INIT_TIME, + TXN_BEGIN_TIME, + TXN_SEND_OFFSETS_TIME, + TXN_COMMIT_TIME, + TXN_ABORT_TIME + ); + final double expectedTotalBlocked = BUFFER_POOL_WAIT_TIME + FLUSH_TME + TXN_INIT_TIME + + TXN_BEGIN_TIME + TXN_SEND_OFFSETS_TIME + TXN_COMMIT_TIME + TXN_ABORT_TIME; + assertThat(eosBetaStreamsProducer.totalBlockedTime(), equalTo(expectedTotalBlocked)); + reset(mockTime); + final long closeStart = 1L; + final long clodeDelay = 1L; + expect(mockTime.nanoseconds()).andReturn(closeStart).andReturn(closeStart + clodeDelay); + replay(mockTime); + eosBetaStreamsProducer.resetProducer(); + setProducerMetrics( + eosBetaMockClientSupplier.producers.get(1), + BUFFER_POOL_WAIT_TIME, + FLUSH_TME, + TXN_INIT_TIME, + TXN_BEGIN_TIME, + 
TXN_SEND_OFFSETS_TIME, + TXN_COMMIT_TIME, + TXN_ABORT_TIME + ); + + assertThat( + eosBetaStreamsProducer.totalBlockedTime(), + closeTo(2 * expectedTotalBlocked + clodeDelay, 0.01) + ); + } + + private MetricName metricName(final String name) { + return new MetricName(name, "", "", Collections.emptyMap()); + } + + private void addMetric( + final MockProducer producer, + final String name, + final double value) { + final MetricName metricName = metricName(name); + producer.setMockMetrics(metricName, new Metric() { + @Override + public MetricName metricName() { + return metricName; + } + + @Override + public Object metricValue() { + return value; + } + }); + } + + private void setProducerMetrics( + final MockProducer producer, + final double bufferPoolWaitTime, + final double flushTime, + final double txnInitTime, + final double txnBeginTime, + final double txnSendOffsetsTime, + final double txnCommitTime, + final double txnAbortTime) { + addMetric(producer, "bufferpool-wait-time-ns-total", bufferPoolWaitTime); + addMetric(producer, "flush-time-ns-total", flushTime); + addMetric(producer, "txn-init-time-ns-total", txnInitTime); + addMetric(producer, "txn-begin-time-ns-total", txnBeginTime); + addMetric(producer, "txn-send-offsets-time-ns-total", txnSendOffsetsTime); + addMetric(producer, "txn-commit-time-ns-total", txnCommitTime); + addMetric(producer, "txn-abort-time-ns-total", txnAbortTime); + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamsRebalanceListenerTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamsRebalanceListenerTest.java new file mode 100644 index 0000000..31177dd --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/StreamsRebalanceListenerTest.java @@ -0,0 +1,185 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.processor.internals; + +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.streams.errors.MissingSourceTopicException; +import org.apache.kafka.streams.errors.TaskAssignmentException; +import org.apache.kafka.streams.processor.internals.StreamThread.State; +import org.apache.kafka.streams.processor.internals.assignment.AssignorError; +import org.easymock.EasyMock; +import org.junit.Before; +import org.junit.Test; +import org.slf4j.LoggerFactory; + +import java.util.Collection; +import java.util.Collections; +import java.util.concurrent.atomic.AtomicInteger; + +import static org.easymock.EasyMock.expect; +import static org.easymock.EasyMock.expectLastCall; +import static org.easymock.EasyMock.mock; +import static org.easymock.EasyMock.replay; +import static org.easymock.EasyMock.verify; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.Assert.assertThrows; + +public class StreamsRebalanceListenerTest { + + private final TaskManager taskManager = mock(TaskManager.class); + private final StreamThread streamThread = mock(StreamThread.class); + private final AtomicInteger assignmentErrorCode = new AtomicInteger(); + private final MockTime time = new MockTime(); + private final StreamsRebalanceListener streamsRebalanceListener = new StreamsRebalanceListener( + time, + taskManager, + streamThread, + LoggerFactory.getLogger(StreamsRebalanceListenerTest.class), + assignmentErrorCode + ); + + @Before + public void before() { + expect(streamThread.state()).andStubReturn(null); + expect(taskManager.activeTaskIds()).andStubReturn(null); + expect(taskManager.standbyTaskIds()).andStubReturn(null); + } + + @Test + public void shouldThrowMissingSourceTopicException() { + taskManager.handleRebalanceComplete(); + expectLastCall(); + replay(taskManager, streamThread); + assignmentErrorCode.set(AssignorError.INCOMPLETE_SOURCE_TOPIC_METADATA.code()); + + final MissingSourceTopicException exception = assertThrows( + MissingSourceTopicException.class, + () -> streamsRebalanceListener.onPartitionsAssigned(Collections.emptyList()) + ); + assertThat(exception.getMessage(), is("One or more source topics were missing during rebalance")); + verify(taskManager, streamThread); + } + + @Test + public void shouldSwallowVersionProbingError() { + expect(streamThread.setState(State.PARTITIONS_ASSIGNED)).andStubReturn(State.PARTITIONS_REVOKED); + streamThread.setPartitionAssignedTime(time.milliseconds()); + taskManager.handleRebalanceComplete(); + replay(taskManager, streamThread); + assignmentErrorCode.set(AssignorError.VERSION_PROBING.code()); + streamsRebalanceListener.onPartitionsAssigned(Collections.emptyList()); + verify(taskManager, streamThread); + } + + @Test + public void shouldSendShutdown() { + streamThread.shutdownToError(); + EasyMock.expectLastCall(); + taskManager.handleRebalanceComplete(); + EasyMock.expectLastCall(); + replay(taskManager, streamThread); + assignmentErrorCode.set(AssignorError.SHUTDOWN_REQUESTED.code()); + streamsRebalanceListener.onPartitionsAssigned(Collections.emptyList()); + verify(taskManager, streamThread); + } + + @Test + public void shouldThrowTaskAssignmentException() { + taskManager.handleRebalanceComplete(); + expectLastCall(); + replay(taskManager, streamThread); + assignmentErrorCode.set(AssignorError.ASSIGNMENT_ERROR.code()); + + final TaskAssignmentException exception = assertThrows( + 
TaskAssignmentException.class, + () -> streamsRebalanceListener.onPartitionsAssigned(Collections.emptyList()) + ); + assertThat(exception.getMessage(), is("Hit an unexpected exception during task assignment phase of rebalance")); + verify(taskManager, streamThread); + } + + @Test + public void shouldThrowTaskAssignmentExceptionOnUnrecognizedErrorCode() { + replay(taskManager, streamThread); + assignmentErrorCode.set(Integer.MAX_VALUE); + + final TaskAssignmentException exception = assertThrows( + TaskAssignmentException.class, + () -> streamsRebalanceListener.onPartitionsAssigned(Collections.emptyList()) + ); + assertThat(exception.getMessage(), is("Hit an unrecognized exception during rebalance")); + verify(taskManager, streamThread); + } + + @Test + public void shouldHandleAssignedPartitions() { + taskManager.handleRebalanceComplete(); + expect(streamThread.setState(State.PARTITIONS_ASSIGNED)).andReturn(State.RUNNING); + streamThread.setPartitionAssignedTime(time.milliseconds()); + + replay(taskManager, streamThread); + assignmentErrorCode.set(AssignorError.NONE.code()); + + streamsRebalanceListener.onPartitionsAssigned(Collections.emptyList()); + + verify(taskManager, streamThread); + } + + @Test + public void shouldHandleRevokedPartitions() { + final Collection partitions = Collections.singletonList(new TopicPartition("topic", 0)); + expect(streamThread.setState(State.PARTITIONS_REVOKED)).andReturn(State.RUNNING); + taskManager.handleRevocation(partitions); + replay(streamThread, taskManager); + + streamsRebalanceListener.onPartitionsRevoked(partitions); + + verify(taskManager, streamThread); + } + + @Test + public void shouldNotHandleRevokedPartitionsIfStateCannotTransitToPartitionRevoked() { + expect(streamThread.setState(State.PARTITIONS_REVOKED)).andReturn(null); + replay(streamThread, taskManager); + + streamsRebalanceListener.onPartitionsRevoked(Collections.singletonList(new TopicPartition("topic", 0))); + + verify(taskManager, streamThread); + } + + @Test + public void shouldNotHandleEmptySetOfRevokedPartitions() { + expect(streamThread.setState(State.PARTITIONS_REVOKED)).andReturn(State.RUNNING); + replay(streamThread, taskManager); + + streamsRebalanceListener.onPartitionsRevoked(Collections.emptyList()); + + verify(taskManager, streamThread); + } + + @Test + public void shouldHandleLostPartitions() { + taskManager.handleLostAll(); + replay(streamThread, taskManager); + + streamsRebalanceListener.onPartitionsLost(Collections.singletonList(new TopicPartition("topic", 0))); + + verify(taskManager, streamThread); + } +} \ No newline at end of file diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java new file mode 100644 index 0000000..1067b66 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java @@ -0,0 +1,3519 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.processor.internals; + +import org.apache.kafka.clients.admin.Admin; +import org.apache.kafka.clients.admin.DeleteRecordsResult; +import org.apache.kafka.clients.admin.DeletedRecords; +import org.apache.kafka.clients.admin.RecordsToDelete; +import org.apache.kafka.clients.consumer.CommitFailedException; +import org.apache.kafka.clients.consumer.Consumer; +import org.apache.kafka.clients.consumer.ConsumerGroupMetadata; +import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.clients.consumer.OffsetAndMetadata; +import org.apache.kafka.common.KafkaException; +import org.apache.kafka.common.Metric; +import org.apache.kafka.common.MetricName; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.errors.TimeoutException; +import org.apache.kafka.common.internals.KafkaFutureImpl; +import org.apache.kafka.common.metrics.KafkaMetric; +import org.apache.kafka.common.metrics.Measurable; +import org.apache.kafka.common.metrics.Metrics; +import org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.errors.LockException; +import org.apache.kafka.streams.errors.StreamsException; +import org.apache.kafka.streams.errors.TaskCorruptedException; +import org.apache.kafka.streams.errors.TaskMigratedException; +import org.apache.kafka.streams.processor.StateStore; +import org.apache.kafka.streams.processor.TaskId; +import org.apache.kafka.streams.processor.internals.StateDirectory.TaskDirectory; +import org.apache.kafka.streams.processor.internals.StreamThread.ProcessingMode; +import org.apache.kafka.streams.processor.internals.Task.State; +import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl; +import org.apache.kafka.streams.processor.internals.testutil.DummyStreamsConfig; +import org.apache.kafka.streams.processor.internals.testutil.LogCaptureAppender; +import org.apache.kafka.streams.state.internals.OffsetCheckpoint; + +import java.util.ArrayList; +import org.easymock.EasyMock; +import org.easymock.EasyMockRunner; +import org.easymock.Mock; +import org.easymock.MockType; +import org.hamcrest.Matchers; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; +import org.junit.runner.RunWith; + +import java.io.File; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.Deque; +import java.util.HashMap; +import java.util.HashSet; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.UUID; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; +import java.util.stream.Collectors; + +import static java.util.Arrays.asList; +import static java.util.Collections.emptyList; +import static java.util.Collections.emptyMap; +import static java.util.Collections.emptySet; +import static java.util.Collections.singleton; +import static 
java.util.Collections.singletonList; +import static java.util.Collections.singletonMap; +import static org.apache.kafka.common.utils.Utils.mkEntry; +import static org.apache.kafka.common.utils.Utils.mkMap; +import static org.apache.kafka.common.utils.Utils.mkSet; +import static org.apache.kafka.common.utils.Utils.union; +import static org.easymock.EasyMock.anyObject; +import static org.easymock.EasyMock.anyString; +import static org.easymock.EasyMock.eq; +import static org.easymock.EasyMock.expect; +import static org.easymock.EasyMock.expectLastCall; +import static org.easymock.EasyMock.mock; +import static org.easymock.EasyMock.replay; +import static org.easymock.EasyMock.reset; +import static org.easymock.EasyMock.resetToStrict; +import static org.easymock.EasyMock.verify; +import static org.hamcrest.CoreMatchers.hasItem; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.empty; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.core.IsEqual.equalTo; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +@RunWith(EasyMockRunner.class) +public class TaskManagerTest { + + private final String topic1 = "topic1"; + private final String topic2 = "topic2"; + + private final TaskId taskId00 = new TaskId(0, 0); + private final TopicPartition t1p0 = new TopicPartition(topic1, 0); + private final TopicPartition t1p0changelog = new TopicPartition("changelog", 0); + private final Set taskId00Partitions = mkSet(t1p0); + private final Set taskId00ChangelogPartitions = mkSet(t1p0changelog); + private final Map> taskId00Assignment = singletonMap(taskId00, taskId00Partitions); + + private final TaskId taskId01 = new TaskId(0, 1); + private final TopicPartition t1p1 = new TopicPartition(topic1, 1); + private final TopicPartition t1p1changelog = new TopicPartition("changelog", 1); + private final Set taskId01Partitions = mkSet(t1p1); + private final Set taskId01ChangelogPartitions = mkSet(t1p1changelog); + private final Map> taskId01Assignment = singletonMap(taskId01, taskId01Partitions); + + private final TaskId taskId02 = new TaskId(0, 2); + private final TopicPartition t1p2 = new TopicPartition(topic1, 2); + private final Set taskId02Partitions = mkSet(t1p2); + + private final TaskId taskId03 = new TaskId(0, 3); + private final TopicPartition t1p3 = new TopicPartition(topic1, 3); + private final Set taskId03Partitions = mkSet(t1p3); + + private final TaskId taskId04 = new TaskId(0, 4); + private final TopicPartition t1p4 = new TopicPartition(topic1, 4); + private final Set taskId04Partitions = mkSet(t1p4); + + private final TaskId taskId05 = new TaskId(0, 5); + private final TopicPartition t1p5 = new TopicPartition(topic1, 5); + private final Set taskId05Partitions = mkSet(t1p5); + + private final TaskId taskId10 = new TaskId(1, 0); + private final TopicPartition t2p0 = new TopicPartition(topic2, 0); + private final Set taskId10Partitions = mkSet(t2p0); + + @Mock(type = MockType.STRICT) + private InternalTopologyBuilder topologyBuilder; + @Mock(type = MockType.DEFAULT) + private StateDirectory stateDirectory; + @Mock(type = MockType.NICE) + private ChangelogReader changeLogReader; + @Mock(type = MockType.STRICT) + private Consumer consumer; + @Mock(type = MockType.STRICT) + private 
ActiveTaskCreator activeTaskCreator; + @Mock(type = MockType.NICE) + private StandbyTaskCreator standbyTaskCreator; + @Mock(type = MockType.NICE) + private Admin adminClient; + + private TaskManager taskManager; + private final Time time = new MockTime(); + + @Rule + public final TemporaryFolder testFolder = new TemporaryFolder(); + + @Before + public void setUp() { + setUpTaskManager(StreamThread.ProcessingMode.AT_LEAST_ONCE); + } + + private void setUpTaskManager(final StreamThread.ProcessingMode processingMode) { + taskManager = new TaskManager( + time, + changeLogReader, + UUID.randomUUID(), + "taskManagerTest", + new StreamsMetricsImpl(new Metrics(), "clientId", StreamsConfig.METRICS_LATEST, time), + activeTaskCreator, + standbyTaskCreator, + new TopologyMetadata(topologyBuilder, new DummyStreamsConfig()), + adminClient, + stateDirectory, + processingMode + ); + taskManager.setMainConsumer(consumer); + reset(topologyBuilder); + expect(topologyBuilder.hasNamedTopology()).andStubReturn(false); + activeTaskCreator.removeRevokedUnknownTasks(anyObject()); + expectLastCall().asStub(); + standbyTaskCreator.removeRevokedUnknownTasks(anyObject()); + expectLastCall().asStub(); + } + + @Test + public void shouldIdempotentlyUpdateSubscriptionFromActiveAssignment() { + final TopicPartition newTopicPartition = new TopicPartition("topic2", 1); + final Map> assignment = mkMap(mkEntry(taskId01, mkSet(t1p1, newTopicPartition))); + + expect(activeTaskCreator.createTasks(anyObject(), eq(assignment))).andStubReturn(emptyList()); + + topologyBuilder.addSubscribedTopicsFromAssignment(eq(asList(t1p1, newTopicPartition)), anyString()); + expectLastCall(); + + replay(activeTaskCreator, topologyBuilder); + + taskManager.handleAssignment(assignment, emptyMap()); + + verify(activeTaskCreator, topologyBuilder); + } + + @Test + public void shouldNotLockAnythingIfStateDirIsEmpty() { + expect(stateDirectory.listNonEmptyTaskDirectories()).andReturn(new ArrayList<>()).once(); + + replay(stateDirectory); + taskManager.handleRebalanceStart(singleton("topic")); + + verify(stateDirectory); + assertTrue(taskManager.lockedTaskDirectories().isEmpty()); + } + + @Test + public void shouldTryToLockValidTaskDirsAtRebalanceStart() throws Exception { + expectLockObtainedFor(taskId01); + expectLockFailedFor(taskId10); + + makeTaskFolders( + taskId01.toString(), + taskId10.toString(), + "dummy" + ); + replay(stateDirectory); + taskManager.handleRebalanceStart(singleton("topic")); + + verify(stateDirectory); + assertThat(taskManager.lockedTaskDirectories(), is(singleton(taskId01))); + } + + @Test + public void shouldReleaseLockForUnassignedTasksAfterRebalance() throws Exception { + expectLockObtainedFor(taskId00, taskId01, taskId02); + expectUnlockFor(taskId02); + + makeTaskFolders( + taskId00.toString(), // active task + taskId01.toString(), // standby task + taskId02.toString() // unassigned but able to lock + ); + replay(stateDirectory); + taskManager.handleRebalanceStart(singleton("topic")); + + assertThat(taskManager.lockedTaskDirectories(), is(mkSet(taskId00, taskId01, taskId02))); + + handleAssignment(taskId00Assignment, taskId01Assignment, emptyMap()); + reset(consumer); + expectConsumerAssignmentPaused(consumer); + replay(consumer); + + taskManager.handleRebalanceComplete(); + assertThat(taskManager.lockedTaskDirectories(), is(mkSet(taskId00, taskId01))); + verify(stateDirectory); + } + + @Test + public void shouldReportLatestOffsetAsOffsetSumForRunningTask() throws Exception { + final Map changelogOffsets = mkMap( + 
mkEntry(new TopicPartition("changelog", 0), Task.LATEST_OFFSET), + mkEntry(new TopicPartition("changelog", 1), Task.LATEST_OFFSET) + ); + final Map expectedOffsetSums = mkMap(mkEntry(taskId00, Task.LATEST_OFFSET)); + + computeOffsetSumAndVerify(changelogOffsets, expectedOffsetSums); + } + + @Test + public void shouldComputeOffsetSumForNonRunningActiveTask() throws Exception { + final Map changelogOffsets = mkMap( + mkEntry(new TopicPartition("changelog", 0), 5L), + mkEntry(new TopicPartition("changelog", 1), 10L) + ); + final Map expectedOffsetSums = mkMap(mkEntry(taskId00, 15L)); + + computeOffsetSumAndVerify(changelogOffsets, expectedOffsetSums); + } + + @Test + public void shouldSkipUnknownOffsetsWhenComputingOffsetSum() throws Exception { + final Map changelogOffsets = mkMap( + mkEntry(new TopicPartition("changelog", 0), OffsetCheckpoint.OFFSET_UNKNOWN), + mkEntry(new TopicPartition("changelog", 1), 10L) + ); + final Map expectedOffsetSums = mkMap(mkEntry(taskId00, 10L)); + + computeOffsetSumAndVerify(changelogOffsets, expectedOffsetSums); + } + + private void computeOffsetSumAndVerify(final Map changelogOffsets, + final Map expectedOffsetSums) throws Exception { + expectLockObtainedFor(taskId00); + makeTaskFolders(taskId00.toString()); + replay(stateDirectory); + + taskManager.handleRebalanceStart(singleton("topic")); + final StateMachineTask restoringTask = handleAssignment( + emptyMap(), + emptyMap(), + taskId00Assignment + ).get(taskId00); + restoringTask.setChangelogOffsets(changelogOffsets); + + assertThat(taskManager.getTaskOffsetSums(), is(expectedOffsetSums)); + } + + @Test + public void shouldComputeOffsetSumForStandbyTask() throws Exception { + final Map changelogOffsets = mkMap( + mkEntry(new TopicPartition("changelog", 0), 5L), + mkEntry(new TopicPartition("changelog", 1), 10L) + ); + final Map expectedOffsetSums = mkMap(mkEntry(taskId00, 15L)); + + expectLockObtainedFor(taskId00); + makeTaskFolders(taskId00.toString()); + replay(stateDirectory); + + taskManager.handleRebalanceStart(singleton("topic")); + final StateMachineTask restoringTask = handleAssignment( + emptyMap(), + taskId00Assignment, + emptyMap() + ).get(taskId00); + restoringTask.setChangelogOffsets(changelogOffsets); + + assertThat(taskManager.getTaskOffsetSums(), is(expectedOffsetSums)); + } + + @Test + public void shouldComputeOffsetSumForUnassignedTaskWeCanLock() throws Exception { + final Map changelogOffsets = mkMap( + mkEntry(new TopicPartition("changelog", 0), 5L), + mkEntry(new TopicPartition("changelog", 1), 10L) + ); + final Map expectedOffsetSums = mkMap(mkEntry(taskId00, 15L)); + + expectLockObtainedFor(taskId00); + makeTaskFolders(taskId00.toString()); + writeCheckpointFile(taskId00, changelogOffsets); + + replay(stateDirectory); + taskManager.handleRebalanceStart(singleton("topic")); + + assertThat(taskManager.getTaskOffsetSums(), is(expectedOffsetSums)); + } + + @Test + public void shouldComputeOffsetSumFromCheckpointFileForUninitializedTask() throws Exception { + final Map changelogOffsets = mkMap( + mkEntry(new TopicPartition("changelog", 0), 5L), + mkEntry(new TopicPartition("changelog", 1), 10L) + ); + final Map expectedOffsetSums = mkMap(mkEntry(taskId00, 15L)); + + expectLockObtainedFor(taskId00); + makeTaskFolders(taskId00.toString()); + writeCheckpointFile(taskId00, changelogOffsets); + replay(stateDirectory); + + taskManager.handleRebalanceStart(singleton("topic")); + final StateMachineTask uninitializedTask = new StateMachineTask(taskId00, taskId00Partitions, true); + 
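+ // The uninitialized task is handed to the task manager through the stubbed creator below; since it never leaves CREATED, its offset sum should still be read from the checkpoint file written above.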
expect(activeTaskCreator.createTasks(anyObject(), eq(taskId00Assignment))).andStubReturn(singleton(uninitializedTask)); + replay(activeTaskCreator); + taskManager.handleAssignment(taskId00Assignment, emptyMap()); + + assertThat(uninitializedTask.state(), is(State.CREATED)); + + assertThat(taskManager.getTaskOffsetSums(), is(expectedOffsetSums)); + } + + @Test + public void shouldComputeOffsetSumFromCheckpointFileForClosedTask() throws Exception { + final Map changelogOffsets = mkMap( + mkEntry(new TopicPartition("changelog", 0), 5L), + mkEntry(new TopicPartition("changelog", 1), 10L) + ); + final Map expectedOffsetSums = mkMap(mkEntry(taskId00, 15L)); + + expectLockObtainedFor(taskId00); + makeTaskFolders(taskId00.toString()); + writeCheckpointFile(taskId00, changelogOffsets); + replay(stateDirectory); + + final StateMachineTask closedTask = new StateMachineTask(taskId00, taskId00Partitions, true); + + taskManager.handleRebalanceStart(singleton("topic")); + expect(activeTaskCreator.createTasks(anyObject(), eq(taskId00Assignment))).andStubReturn(singleton(closedTask)); + replay(activeTaskCreator); + taskManager.handleAssignment(taskId00Assignment, emptyMap()); + + closedTask.suspend(); + closedTask.closeClean(); + assertThat(closedTask.state(), is(State.CLOSED)); + + assertThat(taskManager.getTaskOffsetSums(), is(expectedOffsetSums)); + } + + @Test + public void shouldNotReportOffsetSumsForTaskWeCantLock() throws Exception { + expectLockFailedFor(taskId00); + makeTaskFolders(taskId00.toString()); + replay(stateDirectory); + taskManager.handleRebalanceStart(singleton("topic")); + assertTrue(taskManager.lockedTaskDirectories().isEmpty()); + + assertTrue(taskManager.getTaskOffsetSums().isEmpty()); + } + + @Test + public void shouldNotReportOffsetSumsAndReleaseLockForUnassignedTaskWithoutCheckpoint() throws Exception { + expectLockObtainedFor(taskId00); + makeTaskFolders(taskId00.toString()); + expect(stateDirectory.checkpointFileFor(taskId00)).andReturn(getCheckpointFile(taskId00)); + replay(stateDirectory); + taskManager.handleRebalanceStart(singleton("topic")); + + assertTrue(taskManager.getTaskOffsetSums().isEmpty()); + verify(stateDirectory); + } + + @Test + public void shouldPinOffsetSumToLongMaxValueInCaseOfOverflow() throws Exception { + final long largeOffset = Long.MAX_VALUE / 2; + final Map changelogOffsets = mkMap( + mkEntry(new TopicPartition("changelog", 1), largeOffset), + mkEntry(new TopicPartition("changelog", 2), largeOffset), + mkEntry(new TopicPartition("changelog", 3), largeOffset) + ); + final Map expectedOffsetSums = mkMap(mkEntry(taskId00, Long.MAX_VALUE)); + + expectLockObtainedFor(taskId00); + makeTaskFolders(taskId00.toString()); + writeCheckpointFile(taskId00, changelogOffsets); + replay(stateDirectory); + taskManager.handleRebalanceStart(singleton("topic")); + + assertThat(taskManager.getTaskOffsetSums(), is(expectedOffsetSums)); + } + + @Test + public void shouldCloseActiveUnassignedSuspendedTasksWhenClosingRevokedTasks() { + final StateMachineTask task00 = new StateMachineTask(taskId00, taskId00Partitions, true); + final Map offsets = singletonMap(t1p0, new OffsetAndMetadata(0L, null)); + task00.setCommittableOffsetsAndMetadata(offsets); + + // first `handleAssignment` + expectRestoreToBeCompleted(consumer, changeLogReader); + expect(activeTaskCreator.createTasks(anyObject(), eq(taskId00Assignment))).andStubReturn(singletonList(task00)); + expect(activeTaskCreator.createTasks(anyObject(), eq(emptyMap()))).andStubReturn(emptyList()); + 
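+ // The second (empty) handleAssignment below closes the revoked task, so its task producer is expected to be removed as well.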
activeTaskCreator.closeAndRemoveTaskProducerIfNeeded(taskId00); + expectLastCall(); + expect(standbyTaskCreator.createTasks(anyObject())).andStubReturn(emptyList()); + topologyBuilder.addSubscribedTopicsFromAssignment(anyObject(), anyString()); + expectLastCall().anyTimes(); + + // `handleRevocation` + consumer.commitSync(offsets); + expectLastCall(); + + // second `handleAssignment` + consumer.commitSync(offsets); + expectLastCall(); + + replay(activeTaskCreator, standbyTaskCreator, topologyBuilder, consumer, changeLogReader); + + taskManager.handleAssignment(taskId00Assignment, emptyMap()); + assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), null), is(true)); + assertThat(task00.state(), is(Task.State.RUNNING)); + + taskManager.handleRevocation(taskId00Partitions); + assertThat(task00.state(), is(Task.State.SUSPENDED)); + + taskManager.handleAssignment(emptyMap(), emptyMap()); + assertThat(task00.state(), is(Task.State.CLOSED)); + assertThat(taskManager.activeTaskMap(), Matchers.anEmptyMap()); + assertThat(taskManager.standbyTaskMap(), Matchers.anEmptyMap()); + } + + @Test + public void shouldCloseDirtyActiveUnassignedTasksWhenErrorCleanClosingTask() { + final StateMachineTask task00 = new StateMachineTask(taskId00, taskId00Partitions, true) { + @Override + public void closeClean() { + throw new RuntimeException("KABOOM!"); + } + }; + + // first `handleAssignment` + expectRestoreToBeCompleted(consumer, changeLogReader); + expect(activeTaskCreator.createTasks(anyObject(), eq(taskId00Assignment))).andStubReturn(singletonList(task00)); + expect(activeTaskCreator.createTasks(anyObject(), eq(emptyMap()))).andStubReturn(emptyList()); + activeTaskCreator.closeAndRemoveTaskProducerIfNeeded(taskId00); + expectLastCall(); + expect(standbyTaskCreator.createTasks(anyObject())).andStubReturn(emptyList()); + topologyBuilder.addSubscribedTopicsFromAssignment(anyObject(), anyString()); + expectLastCall().anyTimes(); + + replay(activeTaskCreator, standbyTaskCreator, topologyBuilder, consumer, changeLogReader); + + taskManager.handleAssignment(taskId00Assignment, emptyMap()); + taskManager.handleRevocation(taskId00Partitions); + + final RuntimeException thrown = assertThrows( + RuntimeException.class, + () -> taskManager.handleAssignment(emptyMap(), emptyMap()) + ); + + assertThat(task00.state(), is(Task.State.CLOSED)); + assertThat( + thrown.getMessage(), + is("Unexpected failure to close 1 task(s) [[0_0]]. First unexpected exception (for task 0_0) follows.") + ); + assertThat(thrown.getCause().getMessage(), is("KABOOM!")); + } + + @Test + public void shouldCloseActiveTasksWhenHandlingLostTasks() throws Exception { + final StateMachineTask task00 = new StateMachineTask(taskId00, taskId00Partitions, true); + final StateMachineTask task01 = new StateMachineTask(taskId01, taskId01Partitions, false); + + // `handleAssignment` + expectRestoreToBeCompleted(consumer, changeLogReader); + expect(activeTaskCreator.createTasks(anyObject(), eq(taskId00Assignment))).andStubReturn(singletonList(task00)); + expect(standbyTaskCreator.createTasks(eq(taskId01Assignment))).andStubReturn(singletonList(task01)); + topologyBuilder.addSubscribedTopicsFromAssignment(anyObject(), anyString()); + expectLastCall().anyTimes(); + + makeTaskFolders(taskId00.toString(), taskId01.toString()); + expectLockObtainedFor(taskId00, taskId01); + + // The second attempt will return empty tasks. 
+ makeTaskFolders(); + expectLockObtainedFor(); + replay(stateDirectory); + + taskManager.handleRebalanceStart(emptySet()); + assertThat(taskManager.lockedTaskDirectories(), Matchers.is(mkSet(taskId00, taskId01))); + + // `handleLostAll` + activeTaskCreator.closeAndRemoveTaskProducerIfNeeded(taskId00); + expectLastCall(); + + replay(activeTaskCreator, standbyTaskCreator, topologyBuilder, consumer, changeLogReader); + + taskManager.handleAssignment(taskId00Assignment, taskId01Assignment); + assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), null), is(true)); + assertThat(task00.state(), is(Task.State.RUNNING)); + assertThat(task01.state(), is(Task.State.RUNNING)); + + taskManager.handleLostAll(); + assertThat(task00.commitPrepared, is(true)); + assertThat(task00.state(), is(Task.State.CLOSED)); + assertThat(task01.state(), is(Task.State.RUNNING)); + assertThat(taskManager.activeTaskMap(), Matchers.anEmptyMap()); + assertThat(taskManager.standbyTaskMap(), is(singletonMap(taskId01, task01))); + + // The locked task map will not be cleared. + assertThat(taskManager.lockedTaskDirectories(), is(mkSet(taskId00, taskId01))); + + taskManager.handleRebalanceStart(emptySet()); + + assertThat(taskManager.lockedTaskDirectories(), is(emptySet())); + } + + @Test + public void shouldReInitializeThreadProducerOnHandleLostAllIfEosV2Enabled() { + activeTaskCreator.reInitializeThreadProducer(); + expectLastCall(); + + setUpTaskManager(ProcessingMode.EXACTLY_ONCE_V2); + + replay(activeTaskCreator); + + taskManager.handleLostAll(); + + verify(activeTaskCreator); + } + + @Test + public void shouldThrowWhenHandlingClosingTasksOnProducerCloseError() { + final StateMachineTask task00 = new StateMachineTask(taskId00, taskId00Partitions, true); + final Map offsets = singletonMap(t1p0, new OffsetAndMetadata(0L, null)); + task00.setCommittableOffsetsAndMetadata(offsets); + + // `handleAssignment` + expectRestoreToBeCompleted(consumer, changeLogReader); + expect(activeTaskCreator.createTasks(anyObject(), eq(taskId00Assignment))).andStubReturn(singletonList(task00)); + expect(standbyTaskCreator.createTasks(anyObject())).andStubReturn(emptyList()); + topologyBuilder.addSubscribedTopicsFromAssignment(anyObject(), anyString()); + expectLastCall().anyTimes(); + + // `handleAssignment` + consumer.commitSync(offsets); + expectLastCall(); + activeTaskCreator.closeAndRemoveTaskProducerIfNeeded(taskId00); + expectLastCall().andThrow(new RuntimeException("KABOOM!")); + + replay(activeTaskCreator, standbyTaskCreator, topologyBuilder, consumer, changeLogReader); + + taskManager.handleAssignment(taskId00Assignment, emptyMap()); + assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), null), is(true)); + assertThat(task00.state(), is(Task.State.RUNNING)); + + taskManager.handleRevocation(taskId00Partitions); + + final RuntimeException thrown = assertThrows( + RuntimeException.class, + () -> taskManager.handleAssignment(emptyMap(), emptyMap()) + ); + + assertThat( + thrown.getMessage(), + is("Unexpected failure to close 1 task(s) [[0_0]]. 
First unexpected exception (for task 0_0) follows.") + ); + assertThat(thrown.getCause(), instanceOf(RuntimeException.class)); + assertThat(thrown.getCause().getMessage(), is("KABOOM!")); + } + + @Test + public void shouldReviveCorruptTasks() { + final ProcessorStateManager stateManager = EasyMock.createStrictMock(ProcessorStateManager.class); + stateManager.markChangelogAsCorrupted(taskId00Partitions); + EasyMock.expectLastCall().once(); + replay(stateManager); + + final AtomicBoolean enforcedCheckpoint = new AtomicBoolean(false); + final StateMachineTask task00 = new StateMachineTask(taskId00, taskId00Partitions, true, stateManager) { + @Override + public void postCommit(final boolean enforceCheckpoint) { + if (enforceCheckpoint) { + enforcedCheckpoint.set(true); + } + super.postCommit(enforceCheckpoint); + } + }; + + // `handleAssignment` + expectRestoreToBeCompleted(consumer, changeLogReader); + expect(activeTaskCreator.createTasks(anyObject(), eq(taskId00Assignment))).andStubReturn(singletonList(task00)); + topologyBuilder.addSubscribedTopicsFromAssignment(anyObject(), anyString()); + expectLastCall().anyTimes(); + expect(consumer.assignment()).andReturn(taskId00Partitions); + replay(activeTaskCreator, topologyBuilder, consumer, changeLogReader); + + taskManager.handleAssignment(taskId00Assignment, emptyMap()); + assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), tp -> assertThat(tp, is(empty()))), is(true)); + assertThat(task00.state(), is(Task.State.RUNNING)); + + task00.setChangelogOffsets(singletonMap(t1p0, 0L)); + taskManager.handleCorruption(singleton(taskId00)); + + assertThat(task00.commitPrepared, is(true)); + assertThat(task00.state(), is(Task.State.CREATED)); + assertThat(task00.partitionsForOffsetReset, equalTo(taskId00Partitions)); + assertThat(enforcedCheckpoint.get(), is(true)); + assertThat(taskManager.activeTaskMap(), is(singletonMap(taskId00, task00))); + assertThat(taskManager.standbyTaskMap(), Matchers.anEmptyMap()); + + verify(stateManager); + verify(consumer); + } + + @Test + public void shouldReviveCorruptTasksEvenIfTheyCannotCloseClean() { + final ProcessorStateManager stateManager = EasyMock.createStrictMock(ProcessorStateManager.class); + stateManager.markChangelogAsCorrupted(taskId00Partitions); + replay(stateManager); + + final StateMachineTask task00 = new StateMachineTask(taskId00, taskId00Partitions, true, stateManager) { + @Override + public void suspend() { + super.suspend(); + throw new RuntimeException("oops"); + } + }; + + expectRestoreToBeCompleted(consumer, changeLogReader); + expect(activeTaskCreator.createTasks(anyObject(), eq(taskId00Assignment))).andStubReturn(singletonList(task00)); + topologyBuilder.addSubscribedTopicsFromAssignment(anyObject(), anyString()); + expectLastCall().anyTimes(); + expect(consumer.assignment()).andReturn(taskId00Partitions); + replay(activeTaskCreator, topologyBuilder, consumer, changeLogReader); + + taskManager.handleAssignment(taskId00Assignment, emptyMap()); + assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), tp -> assertThat(tp, is(empty()))), is(true)); + assertThat(task00.state(), is(Task.State.RUNNING)); + + task00.setChangelogOffsets(singletonMap(t1p0, 0L)); + taskManager.handleCorruption(singleton(taskId00)); + assertThat(task00.commitPrepared, is(true)); + assertThat(task00.state(), is(Task.State.CREATED)); + assertThat(task00.partitionsForOffsetReset, equalTo(taskId00Partitions)); + assertThat(taskManager.activeTaskMap(), is(singletonMap(taskId00, task00))); + 
assertThat(taskManager.standbyTaskMap(), Matchers.anEmptyMap()); + + verify(stateManager); + verify(consumer); + } + + @Test + public void shouldCommitNonCorruptedTasksOnTaskCorruptedException() { + final ProcessorStateManager stateManager = EasyMock.createStrictMock(ProcessorStateManager.class); + stateManager.markChangelogAsCorrupted(taskId00Partitions); + replay(stateManager); + + final StateMachineTask corruptedTask = new StateMachineTask(taskId00, taskId00Partitions, true, stateManager); + final StateMachineTask nonCorruptedTask = new StateMachineTask(taskId01, taskId01Partitions, true, stateManager); + + final Map> assignment = new HashMap<>(taskId00Assignment); + assignment.putAll(taskId01Assignment); + + // `handleAssignment` + expect(activeTaskCreator.createTasks(anyObject(), eq(assignment))) + .andStubReturn(asList(corruptedTask, nonCorruptedTask)); + topologyBuilder.addSubscribedTopicsFromAssignment(anyObject(), anyString()); + expectLastCall().anyTimes(); + expectRestoreToBeCompleted(consumer, changeLogReader); + expect(consumer.assignment()).andReturn(taskId00Partitions); + // check that we should not commit empty map either + consumer.commitSync(eq(emptyMap())); + expectLastCall().andStubThrow(new AssertionError("should not invoke commitSync when offset map is empty")); + replay(activeTaskCreator, topologyBuilder, consumer, changeLogReader); + + taskManager.handleAssignment(assignment, emptyMap()); + assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), tp -> assertThat(tp, is(empty()))), is(true)); + + assertThat(nonCorruptedTask.state(), is(Task.State.RUNNING)); + nonCorruptedTask.setCommitNeeded(); + + corruptedTask.setChangelogOffsets(singletonMap(t1p0, 0L)); + taskManager.handleCorruption(singleton(taskId00)); + + assertTrue(nonCorruptedTask.commitPrepared); + assertThat(nonCorruptedTask.partitionsForOffsetReset, equalTo(Collections.emptySet())); + assertThat(corruptedTask.partitionsForOffsetReset, equalTo(taskId00Partitions)); + + verify(consumer); + } + + @Test + public void shouldNotCommitNonRunningNonCorruptedTasks() { + final ProcessorStateManager stateManager = EasyMock.createStrictMock(ProcessorStateManager.class); + stateManager.markChangelogAsCorrupted(taskId00Partitions); + replay(stateManager); + + final StateMachineTask corruptedTask = new StateMachineTask(taskId00, taskId00Partitions, true, stateManager); + final StateMachineTask nonRunningNonCorruptedTask = new StateMachineTask(taskId01, taskId01Partitions, true, stateManager); + + nonRunningNonCorruptedTask.setCommitNeeded(); + + final Map> assignment = new HashMap<>(taskId00Assignment); + assignment.putAll(taskId01Assignment); + + // `handleAssignment` + expect(activeTaskCreator.createTasks(anyObject(), eq(assignment))) + .andStubReturn(asList(corruptedTask, nonRunningNonCorruptedTask)); + topologyBuilder.addSubscribedTopicsFromAssignment(anyObject(), anyString()); + expectLastCall().anyTimes(); + expect(consumer.assignment()).andReturn(taskId00Partitions); + replay(activeTaskCreator, topologyBuilder, consumer, changeLogReader); + + taskManager.handleAssignment(assignment, emptyMap()); + + corruptedTask.setChangelogOffsets(singletonMap(t1p0, 0L)); + taskManager.handleCorruption(singleton(taskId00)); + + assertThat(nonRunningNonCorruptedTask.state(), is(Task.State.CREATED)); + assertThat(nonRunningNonCorruptedTask.partitionsForOffsetReset, equalTo(Collections.emptySet())); + assertThat(corruptedTask.partitionsForOffsetReset, equalTo(taskId00Partitions)); + + verify(activeTaskCreator); + 
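+ // A task that never reached RUNNING must not have had a commit prepared while the corrupted task was being handled.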
assertFalse(nonRunningNonCorruptedTask.commitPrepared); + verify(consumer); + } + + @Test + public void shouldCleanAndReviveCorruptedStandbyTasksBeforeCommittingNonCorruptedTasks() { + final ProcessorStateManager stateManager = EasyMock.createStrictMock(ProcessorStateManager.class); + stateManager.markChangelogAsCorrupted(taskId00Partitions); + replay(stateManager); + + final StateMachineTask corruptedStandby = new StateMachineTask(taskId00, taskId00Partitions, false, stateManager); + final StateMachineTask runningNonCorruptedActive = new StateMachineTask(taskId01, taskId01Partitions, true, stateManager) { + @Override + public Map prepareCommit() { + throw new TaskMigratedException("You dropped out of the group!", new RuntimeException()); + } + }; + + // handleAssignment + expect(standbyTaskCreator.createTasks(eq(taskId00Assignment))).andStubReturn(singleton(corruptedStandby)); + expect(activeTaskCreator.createTasks(anyObject(), eq(taskId01Assignment))).andStubReturn(singleton(runningNonCorruptedActive)); + topologyBuilder.addSubscribedTopicsFromAssignment(anyObject(), anyString()); + expectLastCall().anyTimes(); + + expectRestoreToBeCompleted(consumer, changeLogReader); + + replay(activeTaskCreator, standbyTaskCreator, topologyBuilder, consumer, changeLogReader); + + taskManager.handleAssignment(taskId01Assignment, taskId00Assignment); + assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), null), is(true)); + + // make sure this will be committed and throw + assertThat(runningNonCorruptedActive.state(), is(Task.State.RUNNING)); + assertThat(corruptedStandby.state(), is(Task.State.RUNNING)); + + runningNonCorruptedActive.setCommitNeeded(); + + corruptedStandby.setChangelogOffsets(singletonMap(t1p0, 0L)); + assertThrows(TaskMigratedException.class, () -> taskManager.handleCorruption(singleton(taskId00))); + + + assertThat(corruptedStandby.commitPrepared, is(true)); + assertThat(corruptedStandby.state(), is(Task.State.CREATED)); + verify(consumer); + } + + @Test + public void shouldNotAttemptToCommitInHandleCorruptedDuringARebalance() { + final ProcessorStateManager stateManager = EasyMock.createNiceMock(ProcessorStateManager.class); + expect(stateDirectory.listNonEmptyTaskDirectories()).andStubReturn(new ArrayList<>()); + + final StateMachineTask corruptedActive = new StateMachineTask(taskId00, taskId00Partitions, true, stateManager); + + // make sure this will attempt to be committed and throw + final StateMachineTask uncorruptedActive = new StateMachineTask(taskId01, taskId01Partitions, true, stateManager); + final Map offsets = singletonMap(t1p1, new OffsetAndMetadata(0L, null)); + uncorruptedActive.setCommitNeeded(); + + // handleAssignment + final Map> assignment = new HashMap<>(); + assignment.putAll(taskId00Assignment); + assignment.putAll(taskId01Assignment); + expect(activeTaskCreator.createTasks(anyObject(), eq(assignment))).andStubReturn(asList(corruptedActive, uncorruptedActive)); + topologyBuilder.addSubscribedTopicsFromAssignment(anyObject(), anyString()); + expectLastCall().anyTimes(); + topologyBuilder.addSubscribedTopicsFromMetadata(eq(singleton(topic1)), anyObject()); + expectLastCall().anyTimes(); + + expectRestoreToBeCompleted(consumer, changeLogReader); + + expect(consumer.assignment()).andStubReturn(union(HashSet::new, taskId00Partitions, taskId01Partitions)); + + replay(activeTaskCreator, standbyTaskCreator, topologyBuilder, consumer, changeLogReader, stateDirectory, stateManager); + + uncorruptedActive.setCommittableOffsetsAndMetadata(offsets); + + 
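+ // Bring both tasks to RUNNING before simulating a rebalance; handleCorruption() must then leave the uncorrupted task's commit flags untouched.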
taskManager.handleAssignment(assignment, emptyMap()); + assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), null), is(true)); + + assertThat(uncorruptedActive.state(), is(Task.State.RUNNING)); + + assertThat(uncorruptedActive.commitPrepared, is(false)); + assertThat(uncorruptedActive.commitNeeded, is(true)); + assertThat(uncorruptedActive.commitCompleted, is(false)); + + taskManager.handleRebalanceStart(singleton(topic1)); + assertThat(taskManager.isRebalanceInProgress(), is(true)); + taskManager.handleCorruption(singleton(taskId00)); + + assertThat(uncorruptedActive.commitPrepared, is(false)); + assertThat(uncorruptedActive.commitNeeded, is(true)); + assertThat(uncorruptedActive.commitCompleted, is(false)); + + assertThat(uncorruptedActive.state(), is(State.RUNNING)); + verify(consumer); + } + + @Test + public void shouldCloseAndReviveUncorruptedTasksWhenTimeoutExceptionThrownFromCommitWithALOS() { + final ProcessorStateManager stateManager = EasyMock.createStrictMock(ProcessorStateManager.class); + stateManager.markChangelogAsCorrupted(taskId00Partitions); + replay(stateManager); + + final StateMachineTask corruptedActive = new StateMachineTask(taskId00, taskId00Partitions, true, stateManager); + final StateMachineTask uncorruptedActive = new StateMachineTask(taskId01, taskId01Partitions, true, stateManager) { + @Override + public void markChangelogAsCorrupted(final Collection partitions) { + fail("Should not try to mark changelogs as corrupted for uncorrupted task"); + } + }; + final Map offsets = singletonMap(t1p1, new OffsetAndMetadata(0L, null)); + uncorruptedActive.setCommittableOffsetsAndMetadata(offsets); + + // handleAssignment + final Map> assignment = new HashMap<>(); + assignment.putAll(taskId00Assignment); + assignment.putAll(taskId01Assignment); + expect(activeTaskCreator.createTasks(anyObject(), eq(assignment))).andStubReturn(asList(corruptedActive, uncorruptedActive)); + topologyBuilder.addSubscribedTopicsFromAssignment(anyObject(), anyString()); + expectLastCall().anyTimes(); + + expectRestoreToBeCompleted(consumer, changeLogReader); + + consumer.commitSync(offsets); + expectLastCall().andThrow(new TimeoutException()); + + expect(consumer.assignment()).andStubReturn(union(HashSet::new, taskId00Partitions, taskId01Partitions)); + + replay(activeTaskCreator, standbyTaskCreator, topologyBuilder, consumer, changeLogReader); + + taskManager.handleAssignment(assignment, emptyMap()); + assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), null), is(true)); + + assertThat(uncorruptedActive.state(), is(Task.State.RUNNING)); + assertThat(corruptedActive.state(), is(Task.State.RUNNING)); + + // make sure this will be committed and throw + uncorruptedActive.setCommitNeeded(); + corruptedActive.setChangelogOffsets(singletonMap(t1p0, 0L)); + + assertThat(uncorruptedActive.commitPrepared, is(false)); + assertThat(uncorruptedActive.commitNeeded, is(true)); + assertThat(uncorruptedActive.commitCompleted, is(false)); + assertThat(corruptedActive.commitPrepared, is(false)); + assertThat(corruptedActive.commitNeeded, is(false)); + assertThat(corruptedActive.commitCompleted, is(false)); + + taskManager.handleCorruption(singleton(taskId00)); + + assertThat(uncorruptedActive.commitPrepared, is(true)); + assertThat(uncorruptedActive.commitNeeded, is(false)); + assertThat(uncorruptedActive.commitCompleted, is(false)); //if not corrupted, we should close dirty without committing + assertThat(corruptedActive.commitPrepared, is(true)); + 
assertThat(corruptedActive.commitNeeded, is(false)); + assertThat(corruptedActive.commitCompleted, is(true)); //if corrupted, should enforce checkpoint with corrupted tasks removed + + assertThat(corruptedActive.state(), is(Task.State.CREATED)); + assertThat(uncorruptedActive.state(), is(Task.State.CREATED)); + verify(consumer); + } + + @Test + public void shouldCloseAndReviveUncorruptedTasksWhenTimeoutExceptionThrownFromCommitDuringHandleCorruptedWithEOS() { + setUpTaskManager(ProcessingMode.EXACTLY_ONCE_V2); + final StreamsProducer producer = mock(StreamsProducer.class); + expect(activeTaskCreator.threadProducer()).andStubReturn(producer); + final ProcessorStateManager stateManager = EasyMock.createMock(ProcessorStateManager.class); + + final AtomicBoolean corruptedTaskChangelogMarkedAsCorrupted = new AtomicBoolean(false); + final StateMachineTask corruptedActiveTask = new StateMachineTask(taskId00, taskId00Partitions, true, stateManager) { + @Override + public void markChangelogAsCorrupted(final Collection partitions) { + super.markChangelogAsCorrupted(partitions); + corruptedTaskChangelogMarkedAsCorrupted.set(true); + } + }; + stateManager.markChangelogAsCorrupted(taskId00ChangelogPartitions); + + final AtomicBoolean uncorruptedTaskChangelogMarkedAsCorrupted = new AtomicBoolean(false); + final StateMachineTask uncorruptedActiveTask = new StateMachineTask(taskId01, taskId01Partitions, true, stateManager) { + @Override + public void markChangelogAsCorrupted(final Collection partitions) { + super.markChangelogAsCorrupted(partitions); + uncorruptedTaskChangelogMarkedAsCorrupted.set(true); + } + }; + final Map offsets = singletonMap(t1p1, new OffsetAndMetadata(0L, null)); + uncorruptedActiveTask.setCommittableOffsetsAndMetadata(offsets); + stateManager.markChangelogAsCorrupted(taskId01ChangelogPartitions); + + // handleAssignment + final Map> assignment = new HashMap<>(); + assignment.putAll(taskId00Assignment); + assignment.putAll(taskId01Assignment); + expect(activeTaskCreator.createTasks(anyObject(), eq(assignment))).andStubReturn(asList(corruptedActiveTask, uncorruptedActiveTask)); + topologyBuilder.addSubscribedTopicsFromAssignment(anyObject(), anyString()); + expectLastCall().anyTimes(); + + expectRestoreToBeCompleted(consumer, changeLogReader); + + final ConsumerGroupMetadata groupMetadata = new ConsumerGroupMetadata("appId"); + expect(consumer.groupMetadata()).andReturn(groupMetadata); + producer.commitTransaction(offsets, groupMetadata); + expectLastCall().andThrow(new TimeoutException()); + + expect(consumer.assignment()).andStubReturn(union(HashSet::new, taskId00Partitions, taskId01Partitions)); + + replay(activeTaskCreator, standbyTaskCreator, topologyBuilder, consumer, changeLogReader, stateManager, producer); + + taskManager.handleAssignment(assignment, emptyMap()); + assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), null), is(true)); + + assertThat(uncorruptedActiveTask.state(), is(Task.State.RUNNING)); + assertThat(corruptedActiveTask.state(), is(Task.State.RUNNING)); + + // make sure this will be committed and throw + uncorruptedActiveTask.setCommitNeeded(); + + final Map corruptedActiveTaskChangelogOffsets = singletonMap(t1p0changelog, 0L); + corruptedActiveTask.setChangelogOffsets(corruptedActiveTaskChangelogOffsets); + final Map uncorruptedActiveTaskChangelogOffsets = singletonMap(t1p1changelog, 0L); + uncorruptedActiveTask.setChangelogOffsets(uncorruptedActiveTaskChangelogOffsets); + + assertThat(uncorruptedActiveTask.commitPrepared, is(false)); + 
assertThat(uncorruptedActiveTask.commitNeeded, is(true)); + assertThat(uncorruptedActiveTask.commitCompleted, is(false)); + assertThat(corruptedActiveTask.commitPrepared, is(false)); + assertThat(corruptedActiveTask.commitNeeded, is(false)); + assertThat(corruptedActiveTask.commitCompleted, is(false)); + + taskManager.handleCorruption(singleton(taskId00)); + + assertThat(uncorruptedActiveTask.commitPrepared, is(true)); + assertThat(uncorruptedActiveTask.commitNeeded, is(false)); + assertThat(uncorruptedActiveTask.commitCompleted, is(true)); //if corrupted due to timeout on commit, should enforce checkpoint with corrupted tasks removed + assertThat(corruptedActiveTask.commitPrepared, is(true)); + assertThat(corruptedActiveTask.commitNeeded, is(false)); + assertThat(corruptedActiveTask.commitCompleted, is(true)); //if corrupted, should enforce checkpoint with corrupted tasks removed + + assertThat(corruptedActiveTask.state(), is(Task.State.CREATED)); + assertThat(uncorruptedActiveTask.state(), is(Task.State.CREATED)); + assertThat(corruptedTaskChangelogMarkedAsCorrupted.get(), is(true)); + assertThat(uncorruptedTaskChangelogMarkedAsCorrupted.get(), is(true)); + verify(consumer); + } + + @Test + public void shouldCloseAndReviveUncorruptedTasksWhenTimeoutExceptionThrownFromCommitDuringRevocationWithALOS() { + final StateMachineTask revokedActiveTask = new StateMachineTask(taskId00, taskId00Partitions, true); + final Map offsets00 = singletonMap(t1p0, new OffsetAndMetadata(0L, null)); + revokedActiveTask.setCommittableOffsetsAndMetadata(offsets00); + revokedActiveTask.setCommitNeeded(); + + final StateMachineTask unrevokedActiveTaskWithCommitNeeded = new StateMachineTask(taskId01, taskId01Partitions, true) { + @Override + public void markChangelogAsCorrupted(final Collection partitions) { + fail("Should not try to mark changelogs as corrupted for uncorrupted task"); + } + }; + final Map offsets01 = singletonMap(t1p1, new OffsetAndMetadata(1L, null)); + unrevokedActiveTaskWithCommitNeeded.setCommittableOffsetsAndMetadata(offsets01); + unrevokedActiveTaskWithCommitNeeded.setCommitNeeded(); + + final StateMachineTask unrevokedActiveTaskWithoutCommitNeeded = new StateMachineTask(taskId02, taskId02Partitions, true); + + final Map expectedCommittedOffsets = new HashMap<>(); + expectedCommittedOffsets.putAll(offsets00); + expectedCommittedOffsets.putAll(offsets01); + + final Map> assignmentActive = mkMap( + mkEntry(taskId00, taskId00Partitions), + mkEntry(taskId01, taskId01Partitions), + mkEntry(taskId02, taskId02Partitions) + ); + + expectRestoreToBeCompleted(consumer, changeLogReader); + + expect(activeTaskCreator.createTasks(anyObject(), eq(assignmentActive))).andReturn(asList(revokedActiveTask, unrevokedActiveTaskWithCommitNeeded, unrevokedActiveTaskWithoutCommitNeeded)); + activeTaskCreator.closeAndRemoveTaskProducerIfNeeded(taskId00); + expectLastCall(); + consumer.commitSync(expectedCommittedOffsets); + expectLastCall().andThrow(new TimeoutException()); + expect(consumer.assignment()).andStubReturn(union(HashSet::new, taskId00Partitions, taskId01Partitions, taskId02Partitions)); + + replay(activeTaskCreator, standbyTaskCreator, consumer, changeLogReader); + + taskManager.handleAssignment(assignmentActive, emptyMap()); + assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), null), is(true)); + assertThat(revokedActiveTask.state(), is(Task.State.RUNNING)); + assertThat(unrevokedActiveTaskWithCommitNeeded.state(), is(State.RUNNING)); + 
assertThat(unrevokedActiveTaskWithoutCommitNeeded.state(), is(Task.State.RUNNING)); + + taskManager.handleRevocation(taskId00Partitions); + + assertThat(revokedActiveTask.state(), is(State.SUSPENDED)); + assertThat(unrevokedActiveTaskWithCommitNeeded.state(), is(State.CREATED)); + assertThat(unrevokedActiveTaskWithoutCommitNeeded.state(), is(State.RUNNING)); + } + + @Test + public void shouldCloseAndReviveUncorruptedTasksWhenTimeoutExceptionThrownFromCommitDuringRevocationWithEOS() { + setUpTaskManager(ProcessingMode.EXACTLY_ONCE_V2); + final StreamsProducer producer = mock(StreamsProducer.class); + expect(activeTaskCreator.threadProducer()).andStubReturn(producer); + final ProcessorStateManager stateManager = EasyMock.createMock(ProcessorStateManager.class); + + final StateMachineTask revokedActiveTask = new StateMachineTask(taskId00, taskId00Partitions, true, stateManager); + final Map revokedActiveTaskOffsets = singletonMap(t1p0, new OffsetAndMetadata(0L, null)); + revokedActiveTask.setCommittableOffsetsAndMetadata(revokedActiveTaskOffsets); + revokedActiveTask.setCommitNeeded(); + + final AtomicBoolean unrevokedTaskChangelogMarkedAsCorrupted = new AtomicBoolean(false); + final StateMachineTask unrevokedActiveTask = new StateMachineTask(taskId01, taskId01Partitions, true, stateManager) { + @Override + public void markChangelogAsCorrupted(final Collection partitions) { + super.markChangelogAsCorrupted(partitions); + unrevokedTaskChangelogMarkedAsCorrupted.set(true); + } + }; + final Map unrevokedTaskOffsets = singletonMap(t1p1, new OffsetAndMetadata(1L, null)); + unrevokedActiveTask.setCommittableOffsetsAndMetadata(unrevokedTaskOffsets); + unrevokedActiveTask.setCommitNeeded(); + + final StateMachineTask unrevokedActiveTaskWithoutCommitNeeded = new StateMachineTask(taskId02, taskId02Partitions, true, stateManager); + + final Map expectedCommittedOffsets = new HashMap<>(); + expectedCommittedOffsets.putAll(revokedActiveTaskOffsets); + expectedCommittedOffsets.putAll(unrevokedTaskOffsets); + + stateManager.markChangelogAsCorrupted(taskId00ChangelogPartitions); + stateManager.markChangelogAsCorrupted(taskId01ChangelogPartitions); + + final Map> assignmentActive = mkMap( + mkEntry(taskId00, taskId00Partitions), + mkEntry(taskId01, taskId01Partitions), + mkEntry(taskId02, taskId02Partitions) + ); + + expectRestoreToBeCompleted(consumer, changeLogReader); + + expect(activeTaskCreator.createTasks(anyObject(), eq(assignmentActive))).andReturn(asList(revokedActiveTask, unrevokedActiveTask, unrevokedActiveTaskWithoutCommitNeeded)); + activeTaskCreator.closeAndRemoveTaskProducerIfNeeded(taskId00); + expectLastCall(); + + final ConsumerGroupMetadata groupMetadata = new ConsumerGroupMetadata("appId"); + expect(consumer.groupMetadata()).andReturn(groupMetadata); + producer.commitTransaction(expectedCommittedOffsets, groupMetadata); + expectLastCall().andThrow(new TimeoutException()); + + expect(consumer.assignment()).andStubReturn(union(HashSet::new, taskId00Partitions, taskId01Partitions, taskId02Partitions)); + + replay(activeTaskCreator, standbyTaskCreator, consumer, changeLogReader, producer, stateManager); + + taskManager.handleAssignment(assignmentActive, emptyMap()); + assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), null), is(true)); + assertThat(revokedActiveTask.state(), is(Task.State.RUNNING)); + assertThat(unrevokedActiveTask.state(), is(Task.State.RUNNING)); + assertThat(unrevokedActiveTaskWithoutCommitNeeded.state(), is(State.RUNNING)); + + final Map 
revokedActiveTaskChangelogOffsets = singletonMap(t1p0changelog, 0L); + revokedActiveTask.setChangelogOffsets(revokedActiveTaskChangelogOffsets); + final Map unrevokedActiveTaskChangelogOffsets = singletonMap(t1p1changelog, 0L); + unrevokedActiveTask.setChangelogOffsets(unrevokedActiveTaskChangelogOffsets); + + taskManager.handleRevocation(taskId00Partitions); + + assertThat(unrevokedTaskChangelogMarkedAsCorrupted.get(), is(true)); + assertThat(revokedActiveTask.state(), is(State.SUSPENDED)); + assertThat(unrevokedActiveTask.state(), is(State.CREATED)); + assertThat(unrevokedActiveTaskWithoutCommitNeeded.state(), is(State.RUNNING)); + } + + @Test + public void shouldCloseStandbyUnassignedTasksWhenCreatingNewTasks() { + final Task task00 = new StateMachineTask(taskId00, taskId00Partitions, false); + + expectRestoreToBeCompleted(consumer, changeLogReader); + expect(standbyTaskCreator.createTasks(eq(taskId00Assignment))).andStubReturn(singletonList(task00)); + consumer.commitSync(Collections.emptyMap()); + expectLastCall(); + replay(activeTaskCreator, standbyTaskCreator, consumer, changeLogReader); + + taskManager.handleAssignment(emptyMap(), taskId00Assignment); + assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), null), is(true)); + assertThat(task00.state(), is(Task.State.RUNNING)); + + taskManager.handleAssignment(emptyMap(), emptyMap()); + assertThat(task00.state(), is(Task.State.CLOSED)); + assertThat(taskManager.activeTaskMap(), Matchers.anEmptyMap()); + assertThat(taskManager.standbyTaskMap(), Matchers.anEmptyMap()); + } + + @Test + public void shouldAddNonResumedSuspendedTasks() { + final Task task00 = new StateMachineTask(taskId00, taskId00Partitions, true); + final Task task01 = new StateMachineTask(taskId01, taskId01Partitions, false); + + expectRestoreToBeCompleted(consumer, changeLogReader); + // expect these calls twice (because we're going to tryToCompleteRestoration twice) + expectRestoreToBeCompleted(consumer, changeLogReader); + expect(activeTaskCreator.createTasks(anyObject(), eq(taskId00Assignment))).andReturn(singletonList(task00)); + expect(standbyTaskCreator.createTasks(eq(taskId01Assignment))).andReturn(singletonList(task01)); + replay(activeTaskCreator, standbyTaskCreator, consumer, changeLogReader); + + taskManager.handleAssignment(taskId00Assignment, taskId01Assignment); + assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), null), is(true)); + assertThat(task00.state(), is(Task.State.RUNNING)); + assertThat(task01.state(), is(Task.State.RUNNING)); + + taskManager.handleAssignment(taskId00Assignment, taskId01Assignment); + assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), null), is(true)); + assertThat(task00.state(), is(Task.State.RUNNING)); + assertThat(task01.state(), is(Task.State.RUNNING)); + + verify(activeTaskCreator); + } + + @Test + public void shouldUpdateInputPartitionsAfterRebalance() { + final Task task00 = new StateMachineTask(taskId00, taskId00Partitions, true); + + expectRestoreToBeCompleted(consumer, changeLogReader); + // expect these calls twice (because we're going to tryToCompleteRestoration twice) + expectRestoreToBeCompleted(consumer, changeLogReader, false); + expect(activeTaskCreator.createTasks(anyObject(), eq(taskId00Assignment))).andReturn(singletonList(task00)); + replay(activeTaskCreator, consumer, changeLogReader); + + + taskManager.handleAssignment(taskId00Assignment, emptyMap()); + assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), null), is(true)); + 
assertThat(task00.state(), is(Task.State.RUNNING)); + + final Set newPartitionsSet = mkSet(t1p1); + final Map> taskIdSetMap = singletonMap(taskId00, newPartitionsSet); + taskManager.handleAssignment(taskIdSetMap, emptyMap()); + assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), null), is(true)); + assertThat(task00.state(), is(Task.State.RUNNING)); + assertEquals(newPartitionsSet, task00.inputPartitions()); + verify(activeTaskCreator, consumer, changeLogReader); + } + + @Test + public void shouldAddNewActiveTasks() { + final Map> assignment = taskId00Assignment; + final Task task00 = new StateMachineTask(taskId00, taskId00Partitions, true); + + expect(changeLogReader.completedChangelogs()).andReturn(emptySet()); + expect(consumer.assignment()).andReturn(emptySet()); + consumer.resume(eq(emptySet())); + expectLastCall(); + changeLogReader.enforceRestoreActive(); + expectLastCall(); + expect(activeTaskCreator.createTasks(anyObject(), eq(assignment))).andStubReturn(singletonList(task00)); + expect(standbyTaskCreator.createTasks(eq(emptyMap()))).andStubReturn(emptyList()); + replay(consumer, activeTaskCreator, standbyTaskCreator, changeLogReader); + + taskManager.handleAssignment(assignment, emptyMap()); + + assertThat(task00.state(), is(Task.State.CREATED)); + + taskManager.tryToCompleteRestoration(time.milliseconds(), noOpResetter -> { }); + + assertThat(task00.state(), is(Task.State.RUNNING)); + assertThat(taskManager.activeTaskMap(), Matchers.equalTo(singletonMap(taskId00, task00))); + assertThat(taskManager.standbyTaskMap(), Matchers.anEmptyMap()); + verify(activeTaskCreator); + } + + @Test + public void shouldNotCompleteRestorationIfTasksCannotInitialize() { + final Map> assignment = mkMap( + mkEntry(taskId00, taskId00Partitions), + mkEntry(taskId01, taskId01Partitions) + ); + final Task task00 = new StateMachineTask(taskId00, taskId00Partitions, true) { + @Override + public void initializeIfNeeded() { + throw new LockException("can't lock"); + } + }; + final Task task01 = new StateMachineTask(taskId01, taskId01Partitions, true) { + @Override + public void initializeIfNeeded() { + throw new TimeoutException("timed out"); + } + }; + + consumer.commitSync(Collections.emptyMap()); + expectLastCall(); + expect(changeLogReader.completedChangelogs()).andReturn(emptySet()); + expect(consumer.assignment()).andReturn(emptySet()); + consumer.resume(eq(emptySet())); + expectLastCall(); + changeLogReader.enforceRestoreActive(); + expectLastCall(); + expect(activeTaskCreator.createTasks(anyObject(), eq(assignment))).andStubReturn(asList(task00, task01)); + expect(standbyTaskCreator.createTasks(eq(emptyMap()))).andStubReturn(emptyList()); + replay(consumer, activeTaskCreator, standbyTaskCreator, changeLogReader); + + taskManager.handleAssignment(assignment, emptyMap()); + + assertThat(task00.state(), is(Task.State.CREATED)); + assertThat(task01.state(), is(Task.State.CREATED)); + + assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), null), is(false)); + + assertThat(task00.state(), is(Task.State.CREATED)); + assertThat(task01.state(), is(Task.State.CREATED)); + assertThat( + taskManager.activeTaskMap(), + Matchers.equalTo(mkMap(mkEntry(taskId00, task00), mkEntry(taskId01, task01))) + ); + assertThat(taskManager.standbyTaskMap(), Matchers.anEmptyMap()); + verify(activeTaskCreator); + } + + @Test + public void shouldNotCompleteRestorationIfTaskCannotCompleteRestoration() { + final Map> assignment = mkMap( + mkEntry(taskId00, taskId00Partitions) + ); + final Task task00 
= new StateMachineTask(taskId00, taskId00Partitions, true) { + @Override + public void completeRestoration(final java.util.function.Consumer> offsetResetter) { + throw new TimeoutException("timeout!"); + } + }; + + consumer.commitSync(Collections.emptyMap()); + expectLastCall(); + expect(changeLogReader.completedChangelogs()).andReturn(emptySet()); + expect(consumer.assignment()).andReturn(emptySet()); + consumer.resume(eq(emptySet())); + expectLastCall(); + changeLogReader.enforceRestoreActive(); + expectLastCall(); + expect(activeTaskCreator.createTasks(anyObject(), eq(assignment))).andStubReturn(singletonList(task00)); + expect(standbyTaskCreator.createTasks(eq(emptyMap()))).andStubReturn(emptyList()); + replay(consumer, activeTaskCreator, standbyTaskCreator, changeLogReader); + + taskManager.handleAssignment(assignment, emptyMap()); + + assertThat(task00.state(), is(Task.State.CREATED)); + + assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), null), is(false)); + + assertThat(task00.state(), is(Task.State.RESTORING)); + assertThat( + taskManager.activeTaskMap(), + Matchers.equalTo(mkMap(mkEntry(taskId00, task00))) + ); + assertThat(taskManager.standbyTaskMap(), Matchers.anEmptyMap()); + verify(activeTaskCreator); + } + + @Test + public void shouldSuspendActiveTasksDuringRevocation() { + final StateMachineTask task00 = new StateMachineTask(taskId00, taskId00Partitions, true); + final Map offsets = singletonMap(t1p0, new OffsetAndMetadata(0L, null)); + task00.setCommittableOffsetsAndMetadata(offsets); + + expectRestoreToBeCompleted(consumer, changeLogReader); + expect(activeTaskCreator.createTasks(anyObject(), eq(taskId00Assignment))).andReturn(singletonList(task00)); + consumer.commitSync(offsets); + expectLastCall(); + + replay(activeTaskCreator, consumer, changeLogReader); + + taskManager.handleAssignment(taskId00Assignment, emptyMap()); + assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), null), is(true)); + assertThat(task00.state(), is(Task.State.RUNNING)); + + taskManager.handleRevocation(taskId00Partitions); + assertThat(task00.state(), is(Task.State.SUSPENDED)); + } + + @Test + public void shouldCommitAllActiveTasksThatNeedCommittingOnHandleRevocationWithEosV2() { + final StreamsProducer producer = mock(StreamsProducer.class); + setUpTaskManager(ProcessingMode.EXACTLY_ONCE_V2); + + final StateMachineTask task00 = new StateMachineTask(taskId00, taskId00Partitions, true); + final Map offsets00 = singletonMap(t1p0, new OffsetAndMetadata(0L, null)); + task00.setCommittableOffsetsAndMetadata(offsets00); + task00.setCommitNeeded(); + + final StateMachineTask task01 = new StateMachineTask(taskId01, taskId01Partitions, true); + final Map offsets01 = singletonMap(t1p1, new OffsetAndMetadata(1L, null)); + task01.setCommittableOffsetsAndMetadata(offsets01); + task01.setCommitNeeded(); + + final StateMachineTask task02 = new StateMachineTask(taskId02, taskId02Partitions, true); + final Map offsets02 = singletonMap(t1p2, new OffsetAndMetadata(2L, null)); + task02.setCommittableOffsetsAndMetadata(offsets02); + + final StateMachineTask task10 = new StateMachineTask(taskId10, taskId10Partitions, false); + + final Map expectedCommittedOffsets = new HashMap<>(); + expectedCommittedOffsets.putAll(offsets00); + expectedCommittedOffsets.putAll(offsets01); + + final Map> assignmentActive = mkMap( + mkEntry(taskId00, taskId00Partitions), + mkEntry(taskId01, taskId01Partitions), + mkEntry(taskId02, taskId02Partitions) + ); + + final Map> assignmentStandby = mkMap( + 
mkEntry(taskId10, taskId10Partitions) + ); + expectRestoreToBeCompleted(consumer, changeLogReader); + + expect(activeTaskCreator.createTasks(anyObject(), eq(assignmentActive))) + .andReturn(asList(task00, task01, task02)); + + expect(activeTaskCreator.threadProducer()).andReturn(producer); + activeTaskCreator.closeAndRemoveTaskProducerIfNeeded(taskId00); + expect(standbyTaskCreator.createTasks(eq(assignmentStandby))) + .andReturn(singletonList(task10)); + + final ConsumerGroupMetadata groupMetadata = new ConsumerGroupMetadata("appId"); + expect(consumer.groupMetadata()).andReturn(groupMetadata); + producer.commitTransaction(expectedCommittedOffsets, groupMetadata); + expectLastCall(); + + task00.committedOffsets(); + EasyMock.expectLastCall(); + task01.committedOffsets(); + EasyMock.expectLastCall(); + task02.committedOffsets(); + EasyMock.expectLastCall(); + task10.committedOffsets(); + EasyMock.expectLastCall(); + + replay(activeTaskCreator, standbyTaskCreator, consumer, changeLogReader); + + taskManager.handleAssignment(assignmentActive, assignmentStandby); + assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), null), is(true)); + assertThat(task00.state(), is(Task.State.RUNNING)); + assertThat(task01.state(), is(Task.State.RUNNING)); + assertThat(task02.state(), is(Task.State.RUNNING)); + assertThat(task10.state(), is(Task.State.RUNNING)); + + taskManager.handleRevocation(taskId00Partitions); + + assertThat(task00.commitNeeded, is(false)); + assertThat(task01.commitNeeded, is(false)); + assertThat(task02.commitPrepared, is(false)); + assertThat(task10.commitPrepared, is(false)); + } + + @Test + public void shouldCommitAllNeededTasksOnHandleRevocation() { + final StateMachineTask task00 = new StateMachineTask(taskId00, taskId00Partitions, true); + final Map offsets00 = singletonMap(t1p0, new OffsetAndMetadata(0L, null)); + task00.setCommittableOffsetsAndMetadata(offsets00); + task00.setCommitNeeded(); + + final StateMachineTask task01 = new StateMachineTask(taskId01, taskId01Partitions, true); + final Map offsets01 = singletonMap(t1p1, new OffsetAndMetadata(1L, null)); + task01.setCommittableOffsetsAndMetadata(offsets01); + task01.setCommitNeeded(); + + final StateMachineTask task02 = new StateMachineTask(taskId02, taskId02Partitions, true); + final Map offsets02 = singletonMap(t1p2, new OffsetAndMetadata(2L, null)); + task02.setCommittableOffsetsAndMetadata(offsets02); + + final StateMachineTask task10 = new StateMachineTask(taskId10, taskId10Partitions, false); + + final Map expectedCommittedOffsets = new HashMap<>(); + expectedCommittedOffsets.putAll(offsets00); + expectedCommittedOffsets.putAll(offsets01); + + final Map> assignmentActive = mkMap( + mkEntry(taskId00, taskId00Partitions), + mkEntry(taskId01, taskId01Partitions), + mkEntry(taskId02, taskId02Partitions) + ); + + final Map> assignmentStandby = mkMap( + mkEntry(taskId10, taskId10Partitions) + ); + expectRestoreToBeCompleted(consumer, changeLogReader); + + expect(activeTaskCreator.createTasks(anyObject(), eq(assignmentActive))) + .andReturn(asList(task00, task01, task02)); + activeTaskCreator.closeAndRemoveTaskProducerIfNeeded(taskId00); + expectLastCall(); + expect(standbyTaskCreator.createTasks(eq(assignmentStandby))) + .andReturn(singletonList(task10)); + consumer.commitSync(expectedCommittedOffsets); + expectLastCall(); + + replay(activeTaskCreator, standbyTaskCreator, consumer, changeLogReader); + + taskManager.handleAssignment(assignmentActive, assignmentStandby); + 
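+        // once restoration completes, the three active tasks and the standby task should all be RUNNING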
+        assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), null), is(true));
+        assertThat(task00.state(), is(Task.State.RUNNING));
+        assertThat(task01.state(), is(Task.State.RUNNING));
+        assertThat(task02.state(), is(Task.State.RUNNING));
+        assertThat(task10.state(), is(Task.State.RUNNING));
+
+        taskManager.handleRevocation(taskId00Partitions);
+
+        assertThat(task00.commitNeeded, is(false));
+        assertThat(task00.commitPrepared, is(true));
+        assertThat(task01.commitNeeded, is(false));
+        assertThat(task01.commitPrepared, is(true));
+        assertThat(task02.commitPrepared, is(false));
+        assertThat(task10.commitPrepared, is(false));
+    }
+
+    @Test
+    public void shouldNotCommitOnHandleAssignmentIfNoTaskClosed() {
+        final StateMachineTask task00 = new StateMachineTask(taskId00, taskId00Partitions, true);
+        final Map<TopicPartition, OffsetAndMetadata> offsets00 = singletonMap(t1p0, new OffsetAndMetadata(0L, null));
+        task00.setCommittableOffsetsAndMetadata(offsets00);
+        task00.setCommitNeeded();
+
+        final StateMachineTask task10 = new StateMachineTask(taskId10, taskId10Partitions, false);
+
+        final Map<TaskId, Set<TopicPartition>> assignmentActive = singletonMap(taskId00, taskId00Partitions);
+        final Map<TaskId, Set<TopicPartition>> assignmentStandby = singletonMap(taskId10, taskId10Partitions);
+
+        expectRestoreToBeCompleted(consumer, changeLogReader);
+
+        expect(activeTaskCreator.createTasks(anyObject(), eq(assignmentActive))).andReturn(singleton(task00));
+        expect(standbyTaskCreator.createTasks(eq(assignmentStandby))).andReturn(singletonList(task10));
+
+        replay(activeTaskCreator, standbyTaskCreator, consumer, changeLogReader);
+
+        taskManager.handleAssignment(assignmentActive, assignmentStandby);
+        assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), null), is(true));
+        assertThat(task00.state(), is(Task.State.RUNNING));
+        assertThat(task10.state(), is(Task.State.RUNNING));
+
+        taskManager.handleAssignment(assignmentActive, assignmentStandby);
+
+        assertThat(task00.commitNeeded, is(true));
+        assertThat(task10.commitPrepared, is(false));
+    }
+
+    @Test
+    public void shouldNotCommitOnHandleAssignmentIfOnlyStandbyTaskClosed() {
+        final StateMachineTask task00 = new StateMachineTask(taskId00, taskId00Partitions, true);
+        final Map<TopicPartition, OffsetAndMetadata> offsets00 = singletonMap(t1p0, new OffsetAndMetadata(0L, null));
+        task00.setCommittableOffsetsAndMetadata(offsets00);
+        task00.setCommitNeeded();
+
+        final StateMachineTask task10 = new StateMachineTask(taskId10, taskId10Partitions, false);
+
+        final Map<TaskId, Set<TopicPartition>> assignmentActive = singletonMap(taskId00, taskId00Partitions);
+        final Map<TaskId, Set<TopicPartition>> assignmentStandby = singletonMap(taskId10, taskId10Partitions);
+
+        expectRestoreToBeCompleted(consumer, changeLogReader);
+
+        expect(activeTaskCreator.createTasks(anyObject(), eq(assignmentActive))).andReturn(singleton(task00));
+        expect(standbyTaskCreator.createTasks(eq(assignmentStandby))).andReturn(singletonList(task10));
+
+        replay(activeTaskCreator, standbyTaskCreator, consumer, changeLogReader);
+
+        taskManager.handleAssignment(assignmentActive, assignmentStandby);
+        assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), null), is(true));
+        assertThat(task00.state(), is(Task.State.RUNNING));
+        assertThat(task10.state(), is(Task.State.RUNNING));
+
+        taskManager.handleAssignment(assignmentActive, Collections.emptyMap());
+
+        assertThat(task00.commitNeeded, is(true));
+    }
+
+    @Test
+    public void shouldNotCommitCreatedTasksOnRevocationOrClosure() {
+        final StateMachineTask task00 = new StateMachineTask(taskId00, taskId00Partitions, true);
+
+        expect(activeTaskCreator.createTasks(anyObject(),
eq(taskId00Assignment))).andReturn(singletonList(task00)); + activeTaskCreator.closeAndRemoveTaskProducerIfNeeded(eq(taskId00)); + replay(activeTaskCreator, consumer, changeLogReader); + + taskManager.handleAssignment(taskId00Assignment, emptyMap()); + assertThat(task00.state(), is(Task.State.CREATED)); + + taskManager.handleRevocation(taskId00Partitions); + assertThat(task00.state(), is(Task.State.SUSPENDED)); + + taskManager.handleAssignment(emptyMap(), emptyMap()); + assertThat(task00.state(), is(Task.State.CLOSED)); + } + + @Test + public void shouldPassUpIfExceptionDuringSuspend() { + final Task task00 = new StateMachineTask(taskId00, taskId00Partitions, true) { + @Override + public void suspend() { + super.suspend(); + throw new RuntimeException("KABOOM!"); + } + }; + + expectRestoreToBeCompleted(consumer, changeLogReader); + expect(activeTaskCreator.createTasks(anyObject(), eq(taskId00Assignment))).andReturn(singletonList(task00)); + + replay(activeTaskCreator, consumer, changeLogReader); + taskManager.handleAssignment(taskId00Assignment, emptyMap()); + assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), null), is(true)); + assertThat(task00.state(), is(Task.State.RUNNING)); + + assertThrows(RuntimeException.class, () -> taskManager.handleRevocation(taskId00Partitions)); + assertThat(task00.state(), is(Task.State.SUSPENDED)); + + verify(consumer); + } + + @Test + public void shouldCloseActiveTasksAndPropagateExceptionsOnCleanShutdown() { + final TopicPartition changelog = new TopicPartition("changelog", 0); + final Map> assignment = mkMap( + mkEntry(taskId00, taskId00Partitions), + mkEntry(taskId01, taskId01Partitions), + mkEntry(taskId02, taskId02Partitions), + mkEntry(taskId03, taskId03Partitions) + ); + final Task task00 = new StateMachineTask(taskId00, taskId00Partitions, true) { + @Override + public Collection changelogPartitions() { + return singletonList(changelog); + } + }; + final AtomicBoolean closedDirtyTask01 = new AtomicBoolean(false); + final AtomicBoolean closedDirtyTask02 = new AtomicBoolean(false); + final AtomicBoolean closedDirtyTask03 = new AtomicBoolean(false); + final Task task01 = new StateMachineTask(taskId01, taskId01Partitions, true) { + @Override + public void suspend() { + super.suspend(); + throw new TaskMigratedException("migrated", new RuntimeException("cause")); + } + + @Override + public void closeDirty() { + super.closeDirty(); + closedDirtyTask01.set(true); + } + }; + final Task task02 = new StateMachineTask(taskId02, taskId02Partitions, true) { + @Override + public void suspend() { + super.suspend(); + throw new RuntimeException("oops"); + } + + @Override + public void closeDirty() { + super.closeDirty(); + closedDirtyTask02.set(true); + } + }; + final Task task03 = new StateMachineTask(taskId03, taskId03Partitions, true) { + @Override + public void suspend() { + super.suspend(); + throw new RuntimeException("oops"); + } + + @Override + public void closeDirty() { + super.closeDirty(); + closedDirtyTask03.set(true); + } + }; + + resetToStrict(changeLogReader); + expect(changeLogReader.completedChangelogs()).andReturn(emptySet()); + expect(activeTaskCreator.createTasks(anyObject(), eq(assignment))) + .andStubReturn(asList(task00, task01, task02, task03)); + activeTaskCreator.closeAndRemoveTaskProducerIfNeeded(eq(taskId00)); + expectLastCall(); + activeTaskCreator.closeAndRemoveTaskProducerIfNeeded(eq(taskId01)); + expectLastCall(); + activeTaskCreator.closeAndRemoveTaskProducerIfNeeded(eq(taskId02)); + expectLastCall(); + 
activeTaskCreator.closeAndRemoveTaskProducerIfNeeded(eq(taskId03)); + expectLastCall(); + activeTaskCreator.closeThreadProducerIfNeeded(); + expectLastCall(); + expect(standbyTaskCreator.createTasks(eq(emptyMap()))).andStubReturn(emptyList()); + replay(activeTaskCreator, standbyTaskCreator, changeLogReader); + + taskManager.handleAssignment(assignment, emptyMap()); + + assertThat(task00.state(), is(Task.State.CREATED)); + assertThat(task01.state(), is(Task.State.CREATED)); + assertThat(task02.state(), is(Task.State.CREATED)); + assertThat(task03.state(), is(Task.State.CREATED)); + + taskManager.tryToCompleteRestoration(time.milliseconds(), null); + + assertThat(task00.state(), is(Task.State.RESTORING)); + assertThat(task01.state(), is(Task.State.RUNNING)); + assertThat(task02.state(), is(Task.State.RUNNING)); + assertThat(task03.state(), is(Task.State.RUNNING)); + assertThat( + taskManager.activeTaskMap(), + Matchers.equalTo( + mkMap( + mkEntry(taskId00, task00), + mkEntry(taskId01, task01), + mkEntry(taskId02, task02), + mkEntry(taskId03, task03) + ) + ) + ); + assertThat(taskManager.standbyTaskMap(), Matchers.anEmptyMap()); + + final RuntimeException exception = assertThrows( + RuntimeException.class, + () -> taskManager.shutdown(true) + ); + assertThat(exception.getMessage(), equalTo("Unexpected exception while closing task")); + assertThat(exception.getCause().getMessage(), is("migrated; it means all tasks belonging to this thread should be migrated.")); + assertThat(exception.getCause().getCause().getMessage(), is("cause")); + + assertThat(closedDirtyTask01.get(), is(true)); + assertThat(closedDirtyTask02.get(), is(true)); + assertThat(closedDirtyTask03.get(), is(true)); + assertThat(task00.state(), is(Task.State.CLOSED)); + assertThat(task01.state(), is(Task.State.CLOSED)); + assertThat(task02.state(), is(Task.State.CLOSED)); + assertThat(task03.state(), is(Task.State.CLOSED)); + assertThat(taskManager.activeTaskMap(), Matchers.anEmptyMap()); + assertThat(taskManager.standbyTaskMap(), Matchers.anEmptyMap()); + // the active task creator should also get closed (so that it closes the thread producer if applicable) + verify(activeTaskCreator, changeLogReader); + } + + @Test + public void shouldCloseActiveTasksAndPropagateTaskProducerExceptionsOnCleanShutdown() { + final TopicPartition changelog = new TopicPartition("changelog", 0); + final Map> assignment = mkMap( + mkEntry(taskId00, taskId00Partitions) + ); + final StateMachineTask task00 = new StateMachineTask(taskId00, taskId00Partitions, true) { + @Override + public Collection changelogPartitions() { + return singletonList(changelog); + } + }; + final Map offsets = singletonMap(t1p0, new OffsetAndMetadata(0L, null)); + task00.setCommittableOffsetsAndMetadata(offsets); + + resetToStrict(changeLogReader); + expect(changeLogReader.completedChangelogs()).andReturn(emptySet()); + expect(activeTaskCreator.createTasks(anyObject(), eq(assignment))).andStubReturn(singletonList(task00)); + activeTaskCreator.closeAndRemoveTaskProducerIfNeeded(eq(taskId00)); + expectLastCall().andThrow(new RuntimeException("whatever")); + activeTaskCreator.closeThreadProducerIfNeeded(); + expectLastCall(); + expect(standbyTaskCreator.createTasks(eq(emptyMap()))).andStubReturn(emptyList()); + replay(activeTaskCreator, standbyTaskCreator, changeLogReader); + + taskManager.handleAssignment(assignment, emptyMap()); + + assertThat(task00.state(), is(Task.State.CREATED)); + + taskManager.tryToCompleteRestoration(time.milliseconds(), null); + + 
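+        // task 0_0 reports an uncompleted changelog, so it stays in RESTORING instead of reaching RUNNING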
assertThat(task00.state(), is(Task.State.RESTORING)); + assertThat( + taskManager.activeTaskMap(), + Matchers.equalTo( + mkMap( + mkEntry(taskId00, task00) + ) + ) + ); + assertThat(taskManager.standbyTaskMap(), Matchers.anEmptyMap()); + + final RuntimeException exception = assertThrows(RuntimeException.class, () -> taskManager.shutdown(true)); + + assertThat(task00.state(), is(Task.State.CLOSED)); + assertThat(exception.getMessage(), is("Unexpected exception while closing task")); + assertThat(exception.getCause().getMessage(), is("whatever")); + assertThat(taskManager.activeTaskMap(), Matchers.anEmptyMap()); + assertThat(taskManager.standbyTaskMap(), Matchers.anEmptyMap()); + // the active task creator should also get closed (so that it closes the thread producer if applicable) + verify(activeTaskCreator, changeLogReader); + } + + @Test + public void shouldCloseActiveTasksAndPropagateThreadProducerExceptionsOnCleanShutdown() { + final TopicPartition changelog = new TopicPartition("changelog", 0); + final Map> assignment = mkMap( + mkEntry(taskId00, taskId00Partitions) + ); + final Task task00 = new StateMachineTask(taskId00, taskId00Partitions, true) { + @Override + public Collection changelogPartitions() { + return singletonList(changelog); + } + }; + + resetToStrict(changeLogReader); + expect(changeLogReader.completedChangelogs()).andReturn(emptySet()); + expect(activeTaskCreator.createTasks(anyObject(), eq(assignment))).andStubReturn(singletonList(task00)); + activeTaskCreator.closeAndRemoveTaskProducerIfNeeded(eq(taskId00)); + expectLastCall(); + activeTaskCreator.closeThreadProducerIfNeeded(); + expectLastCall().andThrow(new RuntimeException("whatever")); + expect(standbyTaskCreator.createTasks(eq(emptyMap()))).andStubReturn(emptyList()); + replay(activeTaskCreator, standbyTaskCreator, changeLogReader); + + taskManager.handleAssignment(assignment, emptyMap()); + + assertThat(task00.state(), is(Task.State.CREATED)); + + taskManager.tryToCompleteRestoration(time.milliseconds(), null); + + assertThat(task00.state(), is(Task.State.RESTORING)); + assertThat( + taskManager.activeTaskMap(), + Matchers.equalTo( + mkMap( + mkEntry(taskId00, task00) + ) + ) + ); + assertThat(taskManager.standbyTaskMap(), Matchers.anEmptyMap()); + + final RuntimeException exception = assertThrows(RuntimeException.class, () -> taskManager.shutdown(true)); + + assertThat(task00.state(), is(Task.State.CLOSED)); + assertThat(exception.getMessage(), is("Unexpected exception while closing task")); + assertThat(exception.getCause().getMessage(), is("whatever")); + assertThat(taskManager.activeTaskMap(), Matchers.anEmptyMap()); + assertThat(taskManager.standbyTaskMap(), Matchers.anEmptyMap()); + // the active task creator should also get closed (so that it closes the thread producer if applicable) + verify(activeTaskCreator, changeLogReader); + } + + @Test + public void shouldOnlyCommitRevokedStandbyTaskAndPropagatePrepareCommitException() { + setUpTaskManager(StreamThread.ProcessingMode.EXACTLY_ONCE_ALPHA); + + final Task task00 = new StateMachineTask(taskId00, taskId00Partitions, false); + + final StateMachineTask task01 = new StateMachineTask(taskId01, taskId01Partitions, false) { + @Override + public Map prepareCommit() { + throw new RuntimeException("task 0_1 prepare commit boom!"); + } + }; + task01.setCommitNeeded(); + + taskManager.addTask(task00); + taskManager.addTask(task01); + + final RuntimeException thrown = assertThrows(RuntimeException.class, + () -> taskManager.handleAssignment( + 
Collections.emptyMap(), + singletonMap(taskId00, taskId00Partitions) + )); + assertThat(thrown.getCause().getMessage(), is("task 0_1 prepare commit boom!")); + + assertThat(task00.state(), is(Task.State.CREATED)); + assertThat(task01.state(), is(Task.State.CLOSED)); + + // All the tasks involving in the commit should already be removed. + assertThat(taskManager.tasks(), is(Collections.singletonMap(taskId00, task00))); + } + + @Test + public void shouldSuspendAllRevokedActiveTasksAndPropagateSuspendException() { + final Task task00 = new StateMachineTask(taskId00, taskId00Partitions, true); + + final StateMachineTask task01 = new StateMachineTask(taskId01, taskId01Partitions, true) { + @Override + public void suspend() { + super.suspend(); + throw new RuntimeException("task 0_1 suspend boom!"); + } + }; + + final StateMachineTask task02 = new StateMachineTask(taskId02, taskId02Partitions, true); + + taskManager.addTask(task00); + taskManager.addTask(task01); + taskManager.addTask(task02); + + replay(activeTaskCreator); + + final RuntimeException thrown = assertThrows(RuntimeException.class, + () -> taskManager.handleRevocation(union(HashSet::new, taskId01Partitions, taskId02Partitions))); + assertThat(thrown.getCause().getMessage(), is("task 0_1 suspend boom!")); + + assertThat(task00.state(), is(Task.State.CREATED)); + assertThat(task01.state(), is(Task.State.SUSPENDED)); + assertThat(task02.state(), is(Task.State.SUSPENDED)); + + verify(activeTaskCreator); + } + + @Test + public void shouldCloseActiveTasksAndIgnoreExceptionsOnUncleanShutdown() { + final TopicPartition changelog = new TopicPartition("changelog", 0); + final Map> assignment = mkMap( + mkEntry(taskId00, taskId00Partitions), + mkEntry(taskId01, taskId01Partitions), + mkEntry(taskId02, taskId02Partitions) + ); + final Task task00 = new StateMachineTask(taskId00, taskId00Partitions, true) { + @Override + public Collection changelogPartitions() { + return singletonList(changelog); + } + }; + final Task task01 = new StateMachineTask(taskId01, taskId01Partitions, true) { + @Override + public void suspend() { + super.suspend(); + throw new TaskMigratedException("migrated", new RuntimeException("cause")); + } + }; + final Task task02 = new StateMachineTask(taskId02, taskId02Partitions, true) { + @Override + public void suspend() { + super.suspend(); + throw new RuntimeException("oops"); + } + }; + + resetToStrict(changeLogReader); + expect(changeLogReader.completedChangelogs()).andReturn(emptySet()); + expect(activeTaskCreator.createTasks(anyObject(), eq(assignment))).andStubReturn(asList(task00, task01, task02)); + activeTaskCreator.closeAndRemoveTaskProducerIfNeeded(eq(taskId00)); + expectLastCall().andThrow(new RuntimeException("whatever 0")); + activeTaskCreator.closeAndRemoveTaskProducerIfNeeded(eq(taskId01)); + expectLastCall().andThrow(new RuntimeException("whatever 1")); + activeTaskCreator.closeAndRemoveTaskProducerIfNeeded(eq(taskId02)); + expectLastCall().andThrow(new RuntimeException("whatever 2")); + activeTaskCreator.closeThreadProducerIfNeeded(); + expectLastCall().andThrow(new RuntimeException("whatever all")); + expect(standbyTaskCreator.createTasks(eq(emptyMap()))).andStubReturn(emptyList()); + replay(activeTaskCreator, standbyTaskCreator, changeLogReader); + + taskManager.handleAssignment(assignment, emptyMap()); + + assertThat(task00.state(), is(Task.State.CREATED)); + assertThat(task01.state(), is(Task.State.CREATED)); + assertThat(task02.state(), is(Task.State.CREATED)); + + 
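+        // task 0_0 keeps an uncompleted changelog and stays in RESTORING, while 0_1 and 0_2 reach RUNNING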
taskManager.tryToCompleteRestoration(time.milliseconds(), null); + + assertThat(task00.state(), is(Task.State.RESTORING)); + assertThat(task01.state(), is(Task.State.RUNNING)); + assertThat(task02.state(), is(Task.State.RUNNING)); + assertThat( + taskManager.activeTaskMap(), + Matchers.equalTo( + mkMap( + mkEntry(taskId00, task00), + mkEntry(taskId01, task01), + mkEntry(taskId02, task02) + ) + ) + ); + assertThat(taskManager.standbyTaskMap(), Matchers.anEmptyMap()); + + taskManager.shutdown(false); + + assertThat(task00.state(), is(Task.State.CLOSED)); + assertThat(task01.state(), is(Task.State.CLOSED)); + assertThat(task02.state(), is(Task.State.CLOSED)); + assertThat(taskManager.activeTaskMap(), Matchers.anEmptyMap()); + assertThat(taskManager.standbyTaskMap(), Matchers.anEmptyMap()); + // the active task creator should also get closed (so that it closes the thread producer if applicable) + verify(activeTaskCreator, changeLogReader); + } + + @Test + public void shouldCloseStandbyTasksOnShutdown() { + final Map> assignment = singletonMap(taskId00, taskId00Partitions); + final Task task00 = new StateMachineTask(taskId00, taskId00Partitions, false); + + // `handleAssignment` + expect(standbyTaskCreator.createTasks(eq(assignment))).andStubReturn(singletonList(task00)); + + // `tryToCompleteRestoration` + expect(changeLogReader.completedChangelogs()).andReturn(emptySet()); + expect(consumer.assignment()).andReturn(emptySet()); + consumer.resume(eq(emptySet())); + expectLastCall(); + + // `shutdown` + consumer.commitSync(Collections.emptyMap()); + expectLastCall(); + activeTaskCreator.closeThreadProducerIfNeeded(); + expectLastCall(); + + replay(consumer, activeTaskCreator, standbyTaskCreator, changeLogReader); + + taskManager.handleAssignment(emptyMap(), assignment); + assertThat(task00.state(), is(Task.State.CREATED)); + + taskManager.tryToCompleteRestoration(time.milliseconds(), null); + assertThat(task00.state(), is(Task.State.RUNNING)); + assertThat(taskManager.activeTaskMap(), Matchers.anEmptyMap()); + assertThat(taskManager.standbyTaskMap(), Matchers.equalTo(singletonMap(taskId00, task00))); + + taskManager.shutdown(true); + assertThat(task00.state(), is(Task.State.CLOSED)); + assertThat(taskManager.activeTaskMap(), Matchers.anEmptyMap()); + assertThat(taskManager.standbyTaskMap(), Matchers.anEmptyMap()); + // the active task creator should also get closed (so that it closes the thread producer if applicable) + verify(activeTaskCreator); + } + + @Test + public void shouldInitializeNewActiveTasks() { + final StateMachineTask task00 = new StateMachineTask(taskId00, taskId00Partitions, true); + expectRestoreToBeCompleted(consumer, changeLogReader); + expect(activeTaskCreator.createTasks(anyObject(), eq(taskId00Assignment))) + .andStubReturn(singletonList(task00)); + replay(activeTaskCreator, consumer, changeLogReader); + + taskManager.handleAssignment(taskId00Assignment, emptyMap()); + assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), null), is(true)); + + assertThat(task00.state(), is(Task.State.RUNNING)); + assertThat(taskManager.activeTaskMap(), Matchers.equalTo(singletonMap(taskId00, task00))); + assertThat(taskManager.standbyTaskMap(), Matchers.anEmptyMap()); + // verifies that we actually resume the assignment at the end of restoration. 
+ verify(consumer); + } + + @Test + public void shouldInitializeNewStandbyTasks() { + final StateMachineTask task01 = new StateMachineTask(taskId01, taskId01Partitions, false); + + expectRestoreToBeCompleted(consumer, changeLogReader); + expect(standbyTaskCreator.createTasks(eq(taskId01Assignment))) + .andStubReturn(singletonList(task01)); + + replay(standbyTaskCreator, consumer, changeLogReader); + + taskManager.handleAssignment(emptyMap(), taskId01Assignment); + assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), null), is(true)); + + assertThat(task01.state(), is(Task.State.RUNNING)); + assertThat(taskManager.activeTaskMap(), Matchers.anEmptyMap()); + assertThat(taskManager.standbyTaskMap(), Matchers.equalTo(singletonMap(taskId01, task01))); + } + + @Test + public void shouldHandleRebalanceEvents() { + final Set assignment = singleton(new TopicPartition("assignment", 0)); + expect(consumer.assignment()).andReturn(assignment); + consumer.pause(assignment); + expectLastCall(); + expect(stateDirectory.listNonEmptyTaskDirectories()).andReturn(new ArrayList<>()); + replay(consumer, stateDirectory); + assertThat(taskManager.isRebalanceInProgress(), is(false)); + taskManager.handleRebalanceStart(emptySet()); + assertThat(taskManager.isRebalanceInProgress(), is(true)); + taskManager.handleRebalanceComplete(); + assertThat(taskManager.isRebalanceInProgress(), is(false)); + } + + @Test + public void shouldCommitActiveAndStandbyTasks() { + final StateMachineTask task00 = new StateMachineTask(taskId00, taskId00Partitions, true); + final Map offsets = singletonMap(t1p0, new OffsetAndMetadata(0L, null)); + task00.setCommittableOffsetsAndMetadata(offsets); + final StateMachineTask task01 = new StateMachineTask(taskId01, taskId01Partitions, false); + + expectRestoreToBeCompleted(consumer, changeLogReader); + expect(activeTaskCreator.createTasks(anyObject(), eq(taskId00Assignment))) + .andStubReturn(singletonList(task00)); + expect(standbyTaskCreator.createTasks(eq(taskId01Assignment))) + .andStubReturn(singletonList(task01)); + consumer.commitSync(offsets); + expectLastCall(); + + replay(activeTaskCreator, standbyTaskCreator, consumer, changeLogReader); + + taskManager.handleAssignment(taskId00Assignment, taskId01Assignment); + assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), null), is(true)); + + assertThat(task00.state(), is(Task.State.RUNNING)); + assertThat(task01.state(), is(Task.State.RUNNING)); + + task00.setCommitNeeded(); + task01.setCommitNeeded(); + + assertThat(taskManager.commitAll(), equalTo(2)); + assertThat(task00.commitNeeded, is(false)); + assertThat(task01.commitNeeded, is(false)); + } + + @Test + public void shouldCommitProvidedTasksIfNeeded() { + final StateMachineTask task00 = new StateMachineTask(taskId00, taskId00Partitions, true); + final StateMachineTask task01 = new StateMachineTask(taskId01, taskId01Partitions, true); + final StateMachineTask task02 = new StateMachineTask(taskId02, taskId02Partitions, true); + final StateMachineTask task03 = new StateMachineTask(taskId03, taskId03Partitions, false); + final StateMachineTask task04 = new StateMachineTask(taskId04, taskId04Partitions, false); + final StateMachineTask task05 = new StateMachineTask(taskId05, taskId05Partitions, false); + + final Map> assignmentActive = mkMap( + mkEntry(taskId00, taskId00Partitions), + mkEntry(taskId01, taskId01Partitions), + mkEntry(taskId02, taskId02Partitions) + ); + final Map> assignmentStandby = mkMap( + mkEntry(taskId03, taskId03Partitions), + 
mkEntry(taskId04, taskId04Partitions), + mkEntry(taskId05, taskId05Partitions) + ); + + expectRestoreToBeCompleted(consumer, changeLogReader); + expect(activeTaskCreator.createTasks(anyObject(), eq(assignmentActive))) + .andStubReturn(Arrays.asList(task00, task01, task02)); + expect(standbyTaskCreator.createTasks(eq(assignmentStandby))) + .andStubReturn(Arrays.asList(task03, task04, task05)); + + consumer.commitSync(eq(emptyMap())); + + replay(activeTaskCreator, standbyTaskCreator, consumer, changeLogReader); + + taskManager.handleAssignment(assignmentActive, assignmentStandby); + assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), null), is(true)); + + assertThat(task00.state(), is(Task.State.RUNNING)); + assertThat(task01.state(), is(Task.State.RUNNING)); + + task00.setCommitNeeded(); + task01.setCommitNeeded(); + task03.setCommitNeeded(); + task04.setCommitNeeded(); + + assertThat(taskManager.commit(mkSet(task00, task02, task03, task05)), equalTo(2)); + assertThat(task00.commitNeeded, is(false)); + assertThat(task01.commitNeeded, is(true)); + assertThat(task02.commitNeeded, is(false)); + assertThat(task03.commitNeeded, is(false)); + assertThat(task04.commitNeeded, is(true)); + assertThat(task05.commitNeeded, is(false)); + } + + @Test + public void shouldNotCommitOffsetsIfOnlyStandbyTasksAssigned() { + final StateMachineTask task00 = new StateMachineTask(taskId00, taskId00Partitions, false); + + expectRestoreToBeCompleted(consumer, changeLogReader); + expect(standbyTaskCreator.createTasks(eq(taskId00Assignment))) + .andStubReturn(singletonList(task00)); + expectLastCall(); + + replay(activeTaskCreator, standbyTaskCreator, consumer, changeLogReader); + + taskManager.handleAssignment(Collections.emptyMap(), taskId00Assignment); + assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), null), is(true)); + + assertThat(task00.state(), is(Task.State.RUNNING)); + + task00.setCommitNeeded(); + + assertThat(taskManager.commitAll(), equalTo(1)); + assertThat(task00.commitNeeded, is(false)); + } + + @Test + public void shouldNotCommitActiveAndStandbyTasksWhileRebalanceInProgress() throws Exception { + final StateMachineTask task00 = new StateMachineTask(taskId00, taskId00Partitions, true); + final StateMachineTask task01 = new StateMachineTask(taskId01, taskId01Partitions, false); + + makeTaskFolders(taskId00.toString(), task01.toString()); + expectLockObtainedFor(taskId00, taskId01); + expectRestoreToBeCompleted(consumer, changeLogReader); + expect(activeTaskCreator.createTasks(anyObject(), eq(taskId00Assignment))) + .andStubReturn(singletonList(task00)); + expect(standbyTaskCreator.createTasks(eq(taskId01Assignment))) + .andStubReturn(singletonList(task01)); + + replay(activeTaskCreator, standbyTaskCreator, stateDirectory, consumer, changeLogReader); + + taskManager.handleAssignment(taskId00Assignment, taskId01Assignment); + assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), null), is(true)); + + assertThat(task00.state(), is(Task.State.RUNNING)); + assertThat(task01.state(), is(Task.State.RUNNING)); + + task00.setCommitNeeded(); + task01.setCommitNeeded(); + + taskManager.handleRebalanceStart(emptySet()); + + assertThat( + taskManager.commitAll(), + equalTo(-1) // sentinel indicating that nothing was done because a rebalance is in progress + ); + + assertThat( + taskManager.maybeCommitActiveTasksPerUserRequested(), + equalTo(-1) // sentinel indicating that nothing was done because a rebalance is in progress + ); + } + + @Test + public void 
shouldCommitViaConsumerIfEosDisabled() { + final StateMachineTask task01 = new StateMachineTask(taskId01, taskId01Partitions, true); + final Map offsets = singletonMap(t1p1, new OffsetAndMetadata(0L, null)); + task01.setCommittableOffsetsAndMetadata(offsets); + task01.setCommitNeeded(); + taskManager.addTask(task01); + + consumer.commitSync(offsets); + expectLastCall(); + replay(consumer); + + taskManager.commitAll(); + + verify(consumer); + } + + @Test + public void shouldCommitViaProducerIfEosAlphaEnabled() { + final StreamsProducer producer = mock(StreamsProducer.class); + expect(activeTaskCreator.streamsProducerForTask(anyObject(TaskId.class))) + .andReturn(producer) + .andReturn(producer); + + final Map offsetsT01 = singletonMap(t1p1, new OffsetAndMetadata(0L, null)); + final Map offsetsT02 = singletonMap(t1p2, new OffsetAndMetadata(1L, null)); + + producer.commitTransaction(offsetsT01, new ConsumerGroupMetadata("appId")); + expectLastCall(); + producer.commitTransaction(offsetsT02, new ConsumerGroupMetadata("appId")); + expectLastCall(); + + shouldCommitViaProducerIfEosEnabled(StreamThread.ProcessingMode.EXACTLY_ONCE_ALPHA, producer, offsetsT01, offsetsT02); + } + + @Test + public void shouldCommitViaProducerIfEosV2Enabled() { + final StreamsProducer producer = mock(StreamsProducer.class); + expect(activeTaskCreator.threadProducer()).andReturn(producer); + + final Map offsetsT01 = singletonMap(t1p1, new OffsetAndMetadata(0L, null)); + final Map offsetsT02 = singletonMap(t1p2, new OffsetAndMetadata(1L, null)); + final Map allOffsets = new HashMap<>(); + allOffsets.putAll(offsetsT01); + allOffsets.putAll(offsetsT02); + + producer.commitTransaction(allOffsets, new ConsumerGroupMetadata("appId")); + expectLastCall(); + + shouldCommitViaProducerIfEosEnabled(StreamThread.ProcessingMode.EXACTLY_ONCE_V2, producer, offsetsT01, offsetsT02); + } + + private void shouldCommitViaProducerIfEosEnabled(final StreamThread.ProcessingMode processingMode, + final StreamsProducer producer, + final Map offsetsT01, + final Map offsetsT02) { + setUpTaskManager(processingMode); + + final StateMachineTask task01 = new StateMachineTask(taskId01, taskId01Partitions, true); + task01.setCommittableOffsetsAndMetadata(offsetsT01); + task01.setCommitNeeded(); + taskManager.addTask(task01); + final StateMachineTask task02 = new StateMachineTask(taskId02, taskId02Partitions, true); + task02.setCommittableOffsetsAndMetadata(offsetsT02); + task02.setCommitNeeded(); + taskManager.addTask(task02); + + reset(consumer); + expect(consumer.groupMetadata()).andStubReturn(new ConsumerGroupMetadata("appId")); + replay(activeTaskCreator, consumer, producer); + + taskManager.commitAll(); + + verify(producer, consumer); + } + + @Test + public void shouldPropagateExceptionFromActiveCommit() { + final StateMachineTask task00 = new StateMachineTask(taskId00, taskId00Partitions, true) { + @Override + public Map prepareCommit() { + throw new RuntimeException("opsh."); + } + }; + + expectRestoreToBeCompleted(consumer, changeLogReader); + expect(activeTaskCreator.createTasks(anyObject(), eq(taskId00Assignment))) + .andStubReturn(singletonList(task00)); + + replay(activeTaskCreator, consumer, changeLogReader); + + taskManager.handleAssignment(taskId00Assignment, emptyMap()); + assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), null), is(true)); + + assertThat(task00.state(), is(Task.State.RUNNING)); + + task00.setCommitNeeded(); + + final RuntimeException thrown = + assertThrows(RuntimeException.class, () -> 
taskManager.commitAll()); + assertThat(thrown.getMessage(), equalTo("opsh.")); + } + + @Test + public void shouldPropagateExceptionFromStandbyCommit() { + final StateMachineTask task01 = new StateMachineTask(taskId01, taskId01Partitions, false) { + @Override + public Map prepareCommit() { + throw new RuntimeException("opsh."); + } + }; + + expectRestoreToBeCompleted(consumer, changeLogReader); + expect(standbyTaskCreator.createTasks(eq(taskId01Assignment))) + .andStubReturn(singletonList(task01)); + + replay(standbyTaskCreator, consumer, changeLogReader); + + taskManager.handleAssignment(emptyMap(), taskId01Assignment); + assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), null), is(true)); + + assertThat(task01.state(), is(Task.State.RUNNING)); + + task01.setCommitNeeded(); + + final RuntimeException thrown = + assertThrows(RuntimeException.class, () -> taskManager.commitAll()); + assertThat(thrown.getMessage(), equalTo("opsh.")); + } + + @Test + public void shouldSendPurgeData() { + resetToStrict(adminClient); + expect(adminClient.deleteRecords(singletonMap(t1p1, RecordsToDelete.beforeOffset(5L)))) + .andReturn(new DeleteRecordsResult(singletonMap(t1p1, completedFuture()))); + expect(adminClient.deleteRecords(singletonMap(t1p1, RecordsToDelete.beforeOffset(17L)))) + .andReturn(new DeleteRecordsResult(singletonMap(t1p1, completedFuture()))); + replay(adminClient); + + final Map purgableOffsets = new HashMap<>(); + final StateMachineTask task00 = new StateMachineTask(taskId00, taskId00Partitions, true) { + @Override + public Map purgeableOffsets() { + return purgableOffsets; + } + }; + + expectRestoreToBeCompleted(consumer, changeLogReader); + expect(activeTaskCreator.createTasks(anyObject(), eq(taskId00Assignment))) + .andStubReturn(singletonList(task00)); + + replay(activeTaskCreator, consumer, changeLogReader); + + taskManager.handleAssignment(taskId00Assignment, emptyMap()); + assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), null), is(true)); + + assertThat(task00.state(), is(Task.State.RUNNING)); + + purgableOffsets.put(t1p1, 5L); + taskManager.maybePurgeCommittedRecords(); + + purgableOffsets.put(t1p1, 17L); + taskManager.maybePurgeCommittedRecords(); + + verify(adminClient); + } + + @Test + public void shouldNotSendPurgeDataIfPreviousNotDone() { + resetToStrict(adminClient); + final KafkaFutureImpl futureDeletedRecords = new KafkaFutureImpl<>(); + expect(adminClient.deleteRecords(singletonMap(t1p1, RecordsToDelete.beforeOffset(5L)))) + .andReturn(new DeleteRecordsResult(singletonMap(t1p1, futureDeletedRecords))); + replay(adminClient); + + final Map purgableOffsets = new HashMap<>(); + final StateMachineTask task00 = new StateMachineTask(taskId00, taskId00Partitions, true) { + @Override + public Map purgeableOffsets() { + return purgableOffsets; + } + }; + + expectRestoreToBeCompleted(consumer, changeLogReader); + expect(activeTaskCreator.createTasks(anyObject(), eq(taskId00Assignment))) + .andStubReturn(singletonList(task00)); + + replay(activeTaskCreator, consumer, changeLogReader); + + taskManager.handleAssignment(taskId00Assignment, emptyMap()); + assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), null), is(true)); + + assertThat(task00.state(), is(Task.State.RUNNING)); + + purgableOffsets.put(t1p1, 5L); + taskManager.maybePurgeCommittedRecords(); + + // this call should be a no-op. 
+ // this is verified, as there is no expectation on adminClient for this second call, + // so it would fail verification if we invoke the admin client again. + purgableOffsets.put(t1p1, 17L); + taskManager.maybePurgeCommittedRecords(); + + verify(adminClient); + } + + @Test + public void shouldIgnorePurgeDataErrors() { + final StateMachineTask task00 = new StateMachineTask(taskId00, taskId00Partitions, true); + + expectRestoreToBeCompleted(consumer, changeLogReader); + expect(activeTaskCreator.createTasks(anyObject(), eq(taskId00Assignment))) + .andStubReturn(singletonList(task00)); + + final KafkaFutureImpl futureDeletedRecords = new KafkaFutureImpl<>(); + final DeleteRecordsResult deleteRecordsResult = new DeleteRecordsResult(singletonMap(t1p1, futureDeletedRecords)); + futureDeletedRecords.completeExceptionally(new Exception("KABOOM!")); + expect(adminClient.deleteRecords(anyObject())).andReturn(deleteRecordsResult).times(2); + + replay(activeTaskCreator, adminClient, consumer, changeLogReader); + + taskManager.handleAssignment(taskId00Assignment, emptyMap()); + assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), null), is(true)); + + assertThat(task00.state(), is(Task.State.RUNNING)); + + task00.setPurgeableOffsets(singletonMap(t1p1, 5L)); + + taskManager.maybePurgeCommittedRecords(); + taskManager.maybePurgeCommittedRecords(); + + verify(adminClient); + } + + @Test + public void shouldMaybeCommitAllActiveTasksThatNeedCommit() { + final StateMachineTask task00 = new StateMachineTask(taskId00, taskId00Partitions, true); + final Map offsets0 = singletonMap(t1p0, new OffsetAndMetadata(0L, null)); + task00.setCommittableOffsetsAndMetadata(offsets0); + final StateMachineTask task01 = new StateMachineTask(taskId01, taskId01Partitions, true); + final Map offsets1 = singletonMap(t1p1, new OffsetAndMetadata(1L, null)); + task01.setCommittableOffsetsAndMetadata(offsets1); + final StateMachineTask task02 = new StateMachineTask(taskId02, taskId02Partitions, true); + final Map offsets2 = singletonMap(t1p2, new OffsetAndMetadata(2L, null)); + task02.setCommittableOffsetsAndMetadata(offsets2); + final StateMachineTask task03 = new StateMachineTask(taskId03, taskId03Partitions, true); + final StateMachineTask task04 = new StateMachineTask(taskId10, taskId10Partitions, false); + + final Map expectedCommittedOffsets = new HashMap<>(); + expectedCommittedOffsets.putAll(offsets0); + expectedCommittedOffsets.putAll(offsets1); + + final Map> assignmentActive = mkMap( + mkEntry(taskId00, taskId00Partitions), + mkEntry(taskId01, taskId01Partitions), + mkEntry(taskId02, taskId02Partitions), + mkEntry(taskId03, taskId03Partitions) + ); + + final Map> assignmentStandby = mkMap( + mkEntry(taskId10, taskId10Partitions) + ); + + expectRestoreToBeCompleted(consumer, changeLogReader); + expect(activeTaskCreator.createTasks(anyObject(), eq(assignmentActive))) + .andStubReturn(asList(task00, task01, task02, task03)); + expect(standbyTaskCreator.createTasks(eq(assignmentStandby))) + .andStubReturn(singletonList(task04)); + consumer.commitSync(expectedCommittedOffsets); + expectLastCall(); + + replay(activeTaskCreator, standbyTaskCreator, consumer, changeLogReader); + + taskManager.handleAssignment(assignmentActive, assignmentStandby); + assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), null), is(true)); + + assertThat(task00.state(), is(Task.State.RUNNING)); + assertThat(task01.state(), is(Task.State.RUNNING)); + assertThat(task02.state(), is(Task.State.RUNNING)); + 
+        assertThat(task03.state(), is(Task.State.RUNNING));
+        assertThat(task04.state(), is(Task.State.RUNNING));
+
+        task00.setCommitNeeded();
+        task00.setCommitRequested();
+
+        task01.setCommitNeeded();
+
+        task02.setCommitRequested();
+
+        task03.setCommitNeeded();
+        task03.setCommitRequested();
+
+        task04.setCommitNeeded();
+        task04.setCommitRequested();
+
+        assertThat(taskManager.maybeCommitActiveTasksPerUserRequested(), equalTo(3));
+    }
+
+    @Test
+    public void shouldProcessActiveTasks() {
+        final StateMachineTask task00 = new StateMachineTask(taskId00, taskId00Partitions, true);
+        final StateMachineTask task01 = new StateMachineTask(taskId01, taskId01Partitions, true);
+
+        final Map<TaskId, Set<TopicPartition>> assignment = new HashMap<>();
+        assignment.put(taskId00, taskId00Partitions);
+        assignment.put(taskId01, taskId01Partitions);
+
+        expectRestoreToBeCompleted(consumer, changeLogReader);
+        expect(activeTaskCreator.createTasks(anyObject(), eq(assignment)))
+            .andStubReturn(Arrays.asList(task00, task01));
+
+        replay(activeTaskCreator, consumer, changeLogReader);
+
+        taskManager.handleAssignment(assignment, emptyMap());
+        assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), null), is(true));
+
+        assertThat(task00.state(), is(Task.State.RUNNING));
+        assertThat(task01.state(), is(Task.State.RUNNING));
+
+        task00.addRecords(
+            t1p0,
+            Arrays.asList(
+                getConsumerRecord(t1p0, 0L),
+                getConsumerRecord(t1p0, 1L),
+                getConsumerRecord(t1p0, 2L),
+                getConsumerRecord(t1p0, 3L),
+                getConsumerRecord(t1p0, 4L),
+                getConsumerRecord(t1p0, 5L)
+            )
+        );
+        task01.addRecords(
+            t1p1,
+            Arrays.asList(
+                getConsumerRecord(t1p1, 0L),
+                getConsumerRecord(t1p1, 1L),
+                getConsumerRecord(t1p1, 2L),
+                getConsumerRecord(t1p1, 3L),
+                getConsumerRecord(t1p1, 4L)
+            )
+        );
+
+        // check that we process at most the requested max number of records per task
+        assertThat(taskManager.process(3, time), is(6));
+
+        // check that if there are no processable records left, we stop early
+        assertThat(taskManager.process(3, time), is(5));
+        assertThat(taskManager.process(3, time), is(0));
+    }
+
+    @Test
+    public void shouldNotFailOnTimeoutException() {
+        final AtomicReference<TimeoutException> timeoutException = new AtomicReference<>();
+        timeoutException.set(new TimeoutException("Skip me!"));
+
+        final StateMachineTask task00 = new StateMachineTask(taskId00, taskId00Partitions, true);
+        task00.transitionTo(State.RESTORING);
+        task00.transitionTo(State.RUNNING);
+        final StateMachineTask task01 = new StateMachineTask(taskId01, taskId01Partitions, true) {
+            @Override
+            public boolean process(final long wallClockTime) {
+                final TimeoutException exception = timeoutException.get();
+                if (exception != null) {
+                    throw exception;
+                }
+                return true;
+            }
+        };
+        task01.transitionTo(State.RESTORING);
+        task01.transitionTo(State.RUNNING);
+        final StateMachineTask task02 = new StateMachineTask(taskId02, taskId02Partitions, true);
+        task02.transitionTo(State.RESTORING);
+        task02.transitionTo(State.RUNNING);
+
+        taskManager.addTask(task00);
+        taskManager.addTask(task01);
+        taskManager.addTask(task02);
+
+        task00.addRecords(
+            t1p0,
+            Arrays.asList(
+                getConsumerRecord(t1p0, 0L),
+                getConsumerRecord(t1p0, 1L)
+            )
+        );
+        task01.addRecords(
+            t1p1,
+            Arrays.asList(
+                getConsumerRecord(t1p1, 0L),
+                getConsumerRecord(t1p1, 1L)
+            )
+        );
+        task02.addRecords(
+            t1p2,
+            Arrays.asList(
+                getConsumerRecord(t1p2, 0L),
+                getConsumerRecord(t1p2, 1L)
+            )
+        );
+
+        // should only process 2 records, because task01 throws TimeoutException
+        assertThat(taskManager.process(1, time), is(2));
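+        // the TimeoutException is swallowed; the failure is only recorded as the task's timeout timestamp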
assertThat(task01.timeout, equalTo(time.milliseconds())); + + // retry without error + timeoutException.set(null); + assertThat(taskManager.process(1, time), is(3)); + assertThat(task01.timeout, equalTo(null)); + + // there should still be one record for task01 to be processed + assertThat(taskManager.process(1, time), is(1)); + } + + @Test + public void shouldPropagateTaskMigratedExceptionsInProcessActiveTasks() { + final StateMachineTask task00 = new StateMachineTask(taskId00, taskId00Partitions, true) { + @Override + public boolean process(final long wallClockTime) { + throw new TaskMigratedException("migrated", new RuntimeException("cause")); + } + }; + + expectRestoreToBeCompleted(consumer, changeLogReader); + expect(activeTaskCreator.createTasks(anyObject(), eq(taskId00Assignment))) + .andStubReturn(singletonList(task00)); + + replay(activeTaskCreator, consumer, changeLogReader); + + taskManager.handleAssignment(taskId00Assignment, emptyMap()); + assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), null), is(true)); + + assertThat(task00.state(), is(Task.State.RUNNING)); + + final TopicPartition partition = taskId00Partitions.iterator().next(); + task00.addRecords(partition, singletonList(getConsumerRecord(partition, 0L))); + + assertThrows(TaskMigratedException.class, () -> taskManager.process(1, time)); + } + + @Test + public void shouldWrapRuntimeExceptionsInProcessActiveTasksAndSetTaskId() { + final StateMachineTask task00 = new StateMachineTask(taskId00, taskId00Partitions, true) { + @Override + public boolean process(final long wallClockTime) { + throw new RuntimeException("oops"); + } + }; + + expectRestoreToBeCompleted(consumer, changeLogReader); + expect(activeTaskCreator.createTasks(anyObject(), eq(taskId00Assignment))) + .andStubReturn(singletonList(task00)); + + replay(activeTaskCreator, consumer, changeLogReader); + + taskManager.handleAssignment(taskId00Assignment, emptyMap()); + assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), null), is(true)); + + assertThat(task00.state(), is(Task.State.RUNNING)); + + final TopicPartition partition = taskId00Partitions.iterator().next(); + task00.addRecords(partition, singletonList(getConsumerRecord(partition, 0L))); + + final StreamsException exception = assertThrows(StreamsException.class, () -> taskManager.process(1, time)); + assertThat(exception.taskId().isPresent(), is(true)); + assertThat(exception.taskId().get(), is(taskId00)); + assertThat(exception.getCause().getMessage(), is("oops")); + } + + @Test + public void shouldPropagateTaskMigratedExceptionsInPunctuateActiveTasks() { + final StateMachineTask task00 = new StateMachineTask(taskId00, taskId00Partitions, true) { + @Override + public boolean maybePunctuateStreamTime() { + throw new TaskMigratedException("migrated", new RuntimeException("cause")); + } + }; + + expectRestoreToBeCompleted(consumer, changeLogReader); + expect(activeTaskCreator.createTasks(anyObject(), eq(taskId00Assignment))) + .andStubReturn(singletonList(task00)); + + replay(activeTaskCreator, consumer, changeLogReader); + + taskManager.handleAssignment(taskId00Assignment, emptyMap()); + assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), null), is(true)); + + assertThat(task00.state(), is(Task.State.RUNNING)); + + assertThrows(TaskMigratedException.class, () -> taskManager.punctuate()); + } + + @Test + public void shouldPropagateKafkaExceptionsInPunctuateActiveTasks() { + final StateMachineTask task00 = new StateMachineTask(taskId00, 
taskId00Partitions, true) { + @Override + public boolean maybePunctuateStreamTime() { + throw new KafkaException("oops"); + } + }; + + expectRestoreToBeCompleted(consumer, changeLogReader); + expect(activeTaskCreator.createTasks(anyObject(), eq(taskId00Assignment))) + .andStubReturn(singletonList(task00)); + + replay(activeTaskCreator, consumer, changeLogReader); + + taskManager.handleAssignment(taskId00Assignment, emptyMap()); + assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), null), is(true)); + + assertThat(task00.state(), is(Task.State.RUNNING)); + + assertThrows(KafkaException.class, () -> taskManager.punctuate()); + } + + @Test + public void shouldPunctuateActiveTasks() { + final StateMachineTask task00 = new StateMachineTask(taskId00, taskId00Partitions, true) { + @Override + public boolean maybePunctuateStreamTime() { + return true; + } + + @Override + public boolean maybePunctuateSystemTime() { + return true; + } + }; + + expectRestoreToBeCompleted(consumer, changeLogReader); + expect(activeTaskCreator.createTasks(anyObject(), eq(taskId00Assignment))) + .andStubReturn(singletonList(task00)); + + replay(activeTaskCreator, consumer, changeLogReader); + + taskManager.handleAssignment(taskId00Assignment, emptyMap()); + assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), null), is(true)); + + assertThat(task00.state(), is(Task.State.RUNNING)); + + // one for stream and one for system time + assertThat(taskManager.punctuate(), equalTo(2)); + } + + @Test + public void shouldReturnFalseWhenThereAreStillNonRunningTasks() { + final StateMachineTask task00 = new StateMachineTask(taskId00, taskId00Partitions, true) { + @Override + public Collection changelogPartitions() { + return singletonList(new TopicPartition("fake", 0)); + } + }; + + expect(changeLogReader.completedChangelogs()).andReturn(emptySet()); + expect(activeTaskCreator.createTasks(anyObject(), eq(taskId00Assignment))) + .andStubReturn(singletonList(task00)); + + replay(activeTaskCreator, changeLogReader, consumer); + + taskManager.handleAssignment(taskId00Assignment, emptyMap()); + assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), null), is(false)); + assertThat(task00.state(), is(Task.State.RESTORING)); + // this could be a bit mysterious; we're verifying _no_ interactions on the consumer, + // since the taskManager should _not_ resume the assignment while we're still in RESTORING + verify(consumer); + } + + @Test + public void shouldHaveRemainingPartitionsUncleared() { + final StateMachineTask task00 = new StateMachineTask(taskId00, taskId00Partitions, true); + final Map offsets = singletonMap(t1p0, new OffsetAndMetadata(0L, null)); + task00.setCommittableOffsetsAndMetadata(offsets); + + expectRestoreToBeCompleted(consumer, changeLogReader); + expect(activeTaskCreator.createTasks(anyObject(), eq(taskId00Assignment))).andReturn(singletonList(task00)); + consumer.commitSync(offsets); + expectLastCall(); + + replay(activeTaskCreator, consumer, changeLogReader); + + try (final LogCaptureAppender appender = LogCaptureAppender.createAndRegister(TaskManager.class)) { + LogCaptureAppender.setClassLoggerToDebug(TaskManager.class); + taskManager.handleAssignment(taskId00Assignment, emptyMap()); + assertThat(taskManager.tryToCompleteRestoration(time.milliseconds(), null), is(true)); + assertThat(task00.state(), is(Task.State.RUNNING)); + + taskManager.handleRevocation(mkSet(t1p0, new TopicPartition("unknown", 0))); + assertThat(task00.state(), is(Task.State.SUSPENDED)); + + final 
List messages = appender.getMessages(); + assertThat( + messages, + hasItem("taskManagerTestThe following revoked partitions [unknown-0] are missing " + + "from the current task partitions. It could potentially be due to race " + + "condition of consumer detecting the heartbeat failure, or the " + + "tasks have been cleaned up by the handleAssignment callback.") + ); + } + } + + @Test + public void shouldThrowTaskMigratedWhenAllTaskCloseExceptionsAreTaskMigrated() { + final StateMachineTask migratedTask01 = new StateMachineTask(taskId01, taskId01Partitions, false) { + @Override + public void suspend() { + super.suspend(); + throw new TaskMigratedException("t1 close exception", new RuntimeException()); + } + }; + + final StateMachineTask migratedTask02 = new StateMachineTask(taskId02, taskId02Partitions, false) { + @Override + public void suspend() { + super.suspend(); + throw new TaskMigratedException("t2 close exception", new RuntimeException()); + } + }; + taskManager.addTask(migratedTask01); + taskManager.addTask(migratedTask02); + + final TaskMigratedException thrown = assertThrows( + TaskMigratedException.class, + () -> taskManager.handleAssignment(emptyMap(), emptyMap()) + ); + // The task map orders tasks based on topic group id and partition, so here + // t1 should always be the first. + assertThat( + thrown.getMessage(), + equalTo("t1 close exception; it means all tasks belonging to this thread should be migrated.") + ); + } + + @Test + public void shouldThrowRuntimeExceptionWhenEncounteredUnknownExceptionDuringTaskClose() { + final StateMachineTask migratedTask01 = new StateMachineTask(taskId01, taskId01Partitions, false) { + @Override + public void suspend() { + super.suspend(); + throw new TaskMigratedException("t1 close exception", new RuntimeException()); + } + }; + + final StateMachineTask migratedTask02 = new StateMachineTask(taskId02, taskId02Partitions, false) { + @Override + public void suspend() { + super.suspend(); + throw new IllegalStateException("t2 illegal state exception", new RuntimeException()); + } + }; + taskManager.addTask(migratedTask01); + taskManager.addTask(migratedTask02); + + final RuntimeException thrown = assertThrows( + RuntimeException.class, + () -> taskManager.handleAssignment(emptyMap(), emptyMap()) + ); + // Fatal exception thrown first. + assertThat(thrown.getMessage(), equalTo("Unexpected failure to close 2 task(s) [[0_1, 0_2]]. " + + "First unexpected exception (for task 0_2) follows.")); + + assertThat(thrown.getCause().getMessage(), equalTo("t2 illegal state exception")); + } + + @Test + public void shouldThrowSameKafkaExceptionWhenEncounteredDuringTaskClose() { + final StateMachineTask migratedTask01 = new StateMachineTask(taskId01, taskId01Partitions, false) { + @Override + public void suspend() { + super.suspend(); + throw new TaskMigratedException("t1 close exception", new RuntimeException()); + } + }; + + final StateMachineTask migratedTask02 = new StateMachineTask(taskId02, taskId02Partitions, false) { + @Override + public void suspend() { + super.suspend(); + throw new KafkaException("Kaboom for t2!", new RuntimeException()); + } + }; + taskManager.addTask(migratedTask01); + taskManager.addTask(migratedTask02); + + final StreamsException thrown = assertThrows( + StreamsException.class, + () -> taskManager.handleAssignment(emptyMap(), emptyMap()) + ); + + assertThat(thrown.taskId().isPresent(), is(true)); + assertThat(thrown.taskId().get(), is(taskId02)); + + // Expecting the original Kafka exception wrapped in the StreamsException. 
+ assertThat(thrown.getCause().getMessage(), equalTo("Kaboom for t2!")); + } + + @Test + public void shouldTransmitProducerMetrics() { + final MetricName testMetricName = new MetricName("test_metric", "", "", new HashMap<>()); + final Metric testMetric = new KafkaMetric( + new Object(), + testMetricName, + (Measurable) (config, now) -> 0, + null, + new MockTime()); + final Map<MetricName, Metric> dummyProducerMetrics = singletonMap(testMetricName, testMetric); + + expect(activeTaskCreator.producerMetrics()).andReturn(dummyProducerMetrics); + replay(activeTaskCreator); + + assertThat(taskManager.producerMetrics(), is(dummyProducerMetrics)); + } + + private Map<TaskId, StateMachineTask> handleAssignment(final Map<TaskId, Set<TopicPartition>> runningActiveAssignment, + final Map<TaskId, Set<TopicPartition>> standbyAssignment, + final Map<TaskId, Set<TopicPartition>> restoringActiveAssignment) { + final Set<Task> runningTasks = runningActiveAssignment.entrySet().stream() + .map(t -> new StateMachineTask(t.getKey(), t.getValue(), true)) + .collect(Collectors.toSet()); + final Set<Task> standbyTasks = standbyAssignment.entrySet().stream() + .map(t -> new StateMachineTask(t.getKey(), t.getValue(), false)) + .collect(Collectors.toSet()); + final Set<Task> restoringTasks = restoringActiveAssignment.entrySet().stream() + .map(t -> new StateMachineTask(t.getKey(), t.getValue(), true)) + .collect(Collectors.toSet()); + // give the restoring tasks some uncompleted changelog partitions so they'll stay in restoring + restoringTasks.forEach(t -> ((StateMachineTask) t).setChangelogOffsets(singletonMap(new TopicPartition("changelog", 0), 0L))); + + // Initially assign only the active tasks we want to complete restoration + final Map<TaskId, Set<TopicPartition>> allActiveTasksAssignment = new HashMap<>(runningActiveAssignment); + allActiveTasksAssignment.putAll(restoringActiveAssignment); + final Set<Task> allActiveTasks = new HashSet<>(runningTasks); + allActiveTasks.addAll(restoringTasks); + + expect(standbyTaskCreator.createTasks(eq(standbyAssignment))).andStubReturn(standbyTasks); + expect(activeTaskCreator.createTasks(anyObject(), eq(allActiveTasksAssignment))).andStubReturn(allActiveTasks); + + expectRestoreToBeCompleted(consumer, changeLogReader); + replay(activeTaskCreator, standbyTaskCreator, consumer, changeLogReader); + + taskManager.handleAssignment(allActiveTasksAssignment, standbyAssignment); + taskManager.tryToCompleteRestoration(time.milliseconds(), null); + + final Map<TaskId, StateMachineTask> allTasks = new HashMap<>(); + + // Just make sure all tasks ended up in the expected state + for (final Task task : runningTasks) { + assertThat(task.state(), is(Task.State.RUNNING)); + allTasks.put(task.id(), (StateMachineTask) task); + } + for (final Task task : restoringTasks) { + assertThat(task.state(), is(Task.State.RESTORING)); + allTasks.put(task.id(), (StateMachineTask) task); + } + for (final Task task : standbyTasks) { + assertThat(task.state(), is(Task.State.RUNNING)); + allTasks.put(task.id(), (StateMachineTask) task); + } + return allTasks; + } + + private void expectLockObtainedFor(final TaskId... tasks) throws Exception { + for (final TaskId task : tasks) { + expect(stateDirectory.lock(task)).andReturn(true).once(); + } + } + + private void expectLockFailedFor(final TaskId... tasks) throws Exception { + for (final TaskId task : tasks) { + expect(stateDirectory.lock(task)).andReturn(false).once(); + } + } + + private void expectUnlockFor(final TaskId...
tasks) throws Exception { + for (final TaskId task : tasks) { + stateDirectory.unlock(task); + expectLastCall(); + } + } + + private static void expectConsumerAssignmentPaused(final Consumer consumer) { + final Set assignment = singleton(new TopicPartition("assignment", 0)); + expect(consumer.assignment()).andReturn(assignment); + consumer.pause(assignment); + } + + @Test + public void shouldThrowTaskMigratedExceptionOnCommitFailed() { + final StateMachineTask task01 = new StateMachineTask(taskId01, taskId01Partitions, true); + final Map offsets = singletonMap(t1p0, new OffsetAndMetadata(0L, null)); + task01.setCommittableOffsetsAndMetadata(offsets); + task01.setCommitNeeded(); + taskManager.addTask(task01); + + consumer.commitSync(offsets); + expectLastCall().andThrow(new CommitFailedException()); + replay(consumer); + + final TaskMigratedException thrown = assertThrows( + TaskMigratedException.class, + () -> taskManager.commitAll() + ); + + assertThat(thrown.getCause(), instanceOf(CommitFailedException.class)); + assertThat( + thrown.getMessage(), + equalTo("Consumer committing offsets failed, indicating the corresponding thread is no longer part of the group;" + + " it means all tasks belonging to this thread should be migrated.") + ); + assertThat(task01.state(), is(Task.State.CREATED)); + } + + @SuppressWarnings("unchecked") + @Test + public void shouldNotFailForTimeoutExceptionOnConsumerCommit() { + final StateMachineTask task00 = new StateMachineTask(taskId00, taskId00Partitions, true); + final StateMachineTask task01 = new StateMachineTask(taskId01, taskId01Partitions, true); + + task00.setCommittableOffsetsAndMetadata(taskId00Partitions.stream().collect(Collectors.toMap(p -> p, p -> new OffsetAndMetadata(0)))); + task01.setCommittableOffsetsAndMetadata(taskId00Partitions.stream().collect(Collectors.toMap(p -> p, p -> new OffsetAndMetadata(0)))); + + consumer.commitSync(anyObject(Map.class)); + expectLastCall().andThrow(new TimeoutException("KABOOM!")); + consumer.commitSync(anyObject(Map.class)); + expectLastCall(); + replay(consumer); + + task00.setCommitNeeded(); + + assertThat(taskManager.commit(mkSet(task00, task01)), equalTo(0)); + assertThat(task00.timeout, equalTo(time.milliseconds())); + assertNull(task01.timeout); + + assertThat(taskManager.commit(mkSet(task00, task01)), equalTo(1)); + assertNull(task00.timeout); + assertNull(task01.timeout); + } + + @Test + public void shouldNotFailForTimeoutExceptionOnCommitWithEosAlpha() { + setUpTaskManager(ProcessingMode.EXACTLY_ONCE_ALPHA); + + final StreamsProducer producer = mock(StreamsProducer.class); + expect(activeTaskCreator.streamsProducerForTask(anyObject(TaskId.class))) + .andReturn(producer) + .andReturn(producer) + .andReturn(producer); + + final Map offsetsT00 = singletonMap(t1p0, new OffsetAndMetadata(0L, null)); + final Map offsetsT01 = singletonMap(t1p1, new OffsetAndMetadata(1L, null)); + + producer.commitTransaction(offsetsT00, null); + expectLastCall().andThrow(new TimeoutException("KABOOM!")); + producer.commitTransaction(offsetsT00, null); + expectLastCall(); + + producer.commitTransaction(offsetsT01, null); + expectLastCall(); + producer.commitTransaction(offsetsT01, null); + expectLastCall(); + + final StateMachineTask task00 = new StateMachineTask(taskId00, taskId00Partitions, true); + task00.setCommittableOffsetsAndMetadata(offsetsT00); + final StateMachineTask task01 = new StateMachineTask(taskId01, taskId01Partitions, true); + task01.setCommittableOffsetsAndMetadata(offsetsT01); + final StateMachineTask 
task02 = new StateMachineTask(taskId02, taskId02Partitions, true); + + expect(consumer.groupMetadata()).andStubReturn(null); + replay(producer, activeTaskCreator, consumer); + + task00.setCommitNeeded(); + task01.setCommitNeeded(); + + final TaskCorruptedException exception = assertThrows( + TaskCorruptedException.class, + () -> taskManager.commit(mkSet(task00, task01, task02)) + ); + assertThat( + exception.corruptedTasks(), + equalTo(Collections.singleton(taskId00)) + ); + } + + @Test + public void shouldThrowTaskCorruptedExceptionForTimeoutExceptionOnCommitWithEosV2() { + setUpTaskManager(ProcessingMode.EXACTLY_ONCE_V2); + + final StreamsProducer producer = mock(StreamsProducer.class); + expect(activeTaskCreator.threadProducer()) + .andReturn(producer) + .andReturn(producer); + + final Map offsetsT00 = singletonMap(t1p0, new OffsetAndMetadata(0L, null)); + final Map offsetsT01 = singletonMap(t1p1, new OffsetAndMetadata(1L, null)); + final Map allOffsets = new HashMap<>(offsetsT00); + allOffsets.putAll(offsetsT01); + + producer.commitTransaction(allOffsets, null); + expectLastCall().andThrow(new TimeoutException("KABOOM!")); + producer.commitTransaction(allOffsets, null); + expectLastCall(); + + final StateMachineTask task00 = new StateMachineTask(taskId00, taskId00Partitions, true); + task00.setCommittableOffsetsAndMetadata(offsetsT00); + final StateMachineTask task01 = new StateMachineTask(taskId01, taskId01Partitions, true); + task01.setCommittableOffsetsAndMetadata(offsetsT01); + final StateMachineTask task02 = new StateMachineTask(taskId02, taskId02Partitions, true); + + expect(consumer.groupMetadata()).andStubReturn(null); + replay(producer, activeTaskCreator, consumer); + + task00.setCommitNeeded(); + task01.setCommitNeeded(); + + final TaskCorruptedException exception = assertThrows( + TaskCorruptedException.class, + () -> taskManager.commit(mkSet(task00, task01, task02)) + ); + assertThat( + exception.corruptedTasks(), + equalTo(mkSet(taskId00, taskId01)) + ); + } + + @Test + public void shouldStreamsExceptionOnCommitError() { + final StateMachineTask task01 = new StateMachineTask(taskId01, taskId01Partitions, true); + final Map offsets = singletonMap(t1p0, new OffsetAndMetadata(0L, null)); + task01.setCommittableOffsetsAndMetadata(offsets); + task01.setCommitNeeded(); + taskManager.addTask(task01); + + consumer.commitSync(offsets); + expectLastCall().andThrow(new KafkaException()); + replay(consumer); + + final StreamsException thrown = assertThrows( + StreamsException.class, + () -> taskManager.commitAll() + ); + + assertThat(thrown.getCause(), instanceOf(KafkaException.class)); + assertThat(thrown.getMessage(), equalTo("Error encountered committing offsets via consumer")); + assertThat(task01.state(), is(Task.State.CREATED)); + } + + @Test + public void shouldFailOnCommitFatal() { + final StateMachineTask task01 = new StateMachineTask(taskId01, taskId01Partitions, true); + final Map offsets = singletonMap(t1p0, new OffsetAndMetadata(0L, null)); + task01.setCommittableOffsetsAndMetadata(offsets); + task01.setCommitNeeded(); + taskManager.addTask(task01); + + consumer.commitSync(offsets); + expectLastCall().andThrow(new RuntimeException("KABOOM")); + replay(consumer); + + final RuntimeException thrown = assertThrows( + RuntimeException.class, + () -> taskManager.commitAll() + ); + + assertThat(thrown.getMessage(), equalTo("KABOOM")); + assertThat(task01.state(), is(Task.State.CREATED)); + } + + @Test + public void 
shouldSuspendAllTasksButSkipCommitIfSuspendingFailsDuringRevocation() { + final StateMachineTask task00 = new StateMachineTask(taskId00, taskId00Partitions, true) { + @Override + public void suspend() { + super.suspend(); + throw new RuntimeException("KABOOM!"); + } + }; + final StateMachineTask task01 = new StateMachineTask(taskId01, taskId01Partitions, true); + + final Map> assignment = new HashMap<>(taskId00Assignment); + assignment.putAll(taskId01Assignment); + expect(activeTaskCreator.createTasks(anyObject(), eq(assignment))) + .andReturn(asList(task00, task01)); + replay(activeTaskCreator, consumer); + + taskManager.handleAssignment(assignment, Collections.emptyMap()); + + final RuntimeException thrown = assertThrows( + RuntimeException.class, + () -> taskManager.handleRevocation(asList(t1p0, t1p1))); + + assertThat(thrown.getCause().getMessage(), is("KABOOM!")); + assertThat(task00.state(), is(Task.State.SUSPENDED)); + assertThat(task01.state(), is(Task.State.SUSPENDED)); + } + + @Test + public void shouldConvertActiveTaskToStandbyTask() { + final StreamTask activeTask = mock(StreamTask.class); + expect(activeTask.id()).andStubReturn(taskId00); + expect(activeTask.inputPartitions()).andStubReturn(taskId00Partitions); + expect(activeTask.isActive()).andStubReturn(true); + expect(activeTask.prepareCommit()).andStubReturn(Collections.emptyMap()); + + final StandbyTask standbyTask = mock(StandbyTask.class); + expect(standbyTask.id()).andStubReturn(taskId00); + + expect(activeTaskCreator.createTasks(anyObject(), eq(taskId00Assignment))) + .andReturn(singletonList(activeTask)); + activeTaskCreator.closeAndRemoveTaskProducerIfNeeded(taskId00); + expectLastCall().anyTimes(); + + expect(standbyTaskCreator.createStandbyTaskFromActive(anyObject(), eq(taskId00Partitions))) + .andReturn(standbyTask); + + replay(activeTask, standbyTask, activeTaskCreator, standbyTaskCreator, consumer); + + taskManager.handleAssignment(taskId00Assignment, Collections.emptyMap()); + taskManager.handleAssignment(Collections.emptyMap(), taskId00Assignment); + + verify(activeTaskCreator, standbyTaskCreator); + } + + @Test + public void shouldConvertStandbyTaskToActiveTask() { + final StandbyTask standbyTask = mock(StandbyTask.class); + expect(standbyTask.id()).andStubReturn(taskId00); + expect(standbyTask.isActive()).andStubReturn(false); + expect(standbyTask.prepareCommit()).andStubReturn(Collections.emptyMap()); + standbyTask.suspend(); + expectLastCall().anyTimes(); + standbyTask.postCommit(true); + expectLastCall().anyTimes(); + + final StreamTask activeTask = mock(StreamTask.class); + expect(activeTask.id()).andStubReturn(taskId00); + expect(activeTask.inputPartitions()).andStubReturn(taskId00Partitions); + + expect(standbyTaskCreator.createTasks(eq(taskId00Assignment))) + .andReturn(singletonList(standbyTask)); + + expect(activeTaskCreator.createActiveTaskFromStandby(anyObject(), eq(taskId00Partitions), anyObject())) + .andReturn(activeTask); + + replay(standbyTask, activeTask, standbyTaskCreator, activeTaskCreator, consumer); + + taskManager.handleAssignment(Collections.emptyMap(), taskId00Assignment); + taskManager.handleAssignment(taskId00Assignment, Collections.emptyMap()); + + verify(standbyTaskCreator, activeTaskCreator); + } + + private static void expectRestoreToBeCompleted(final Consumer consumer, + final ChangelogReader changeLogReader) { + expectRestoreToBeCompleted(consumer, changeLogReader, true); + } + + private static void expectRestoreToBeCompleted(final Consumer consumer, + final ChangelogReader 
changeLogReader, + final boolean changeLogUpdateRequired) { + final Set assignment = singleton(new TopicPartition("assignment", 0)); + expect(consumer.assignment()).andReturn(assignment); + consumer.resume(assignment); + expectLastCall(); + expect(changeLogReader.completedChangelogs()).andReturn(emptySet()).times(changeLogUpdateRequired ? 1 : 0, 1); + } + + private static KafkaFutureImpl completedFuture() { + final KafkaFutureImpl futureDeletedRecords = new KafkaFutureImpl<>(); + futureDeletedRecords.complete(null); + return futureDeletedRecords; + } + + private void makeTaskFolders(final String... names) throws Exception { + final ArrayList taskFolders = new ArrayList<>(names.length); + for (int i = 0; i < names.length; ++i) { + taskFolders.add(new TaskDirectory(testFolder.newFolder(names[i]), null)); + } + expect(stateDirectory.listNonEmptyTaskDirectories()).andReturn(taskFolders).once(); + } + + private void writeCheckpointFile(final TaskId task, final Map offsets) throws Exception { + final File checkpointFile = getCheckpointFile(task); + assertThat(checkpointFile.createNewFile(), is(true)); + new OffsetCheckpoint(checkpointFile).write(offsets); + expect(stateDirectory.checkpointFileFor(task)).andReturn(checkpointFile); + } + + private File getCheckpointFile(final TaskId task) { + return new File(new File(testFolder.getRoot(), task.toString()), StateManagerUtil.CHECKPOINT_FILE_NAME); + } + + private static ConsumerRecord getConsumerRecord(final TopicPartition topicPartition, final long offset) { + return new ConsumerRecord<>(topicPartition.topic(), topicPartition.partition(), offset, null, null); + } + + private static class StateMachineTask extends AbstractTask implements Task { + private final boolean active; + + // TODO: KAFKA-12569 clean up usage of these flags and use the new commitCompleted flag where appropriate + private boolean commitNeeded = false; + private boolean commitRequested = false; + private boolean commitPrepared = false; + private boolean commitCompleted = false; + private Map committableOffsets = Collections.emptyMap(); + private Map purgeableOffsets; + private Map changelogOffsets = Collections.emptyMap(); + private Set partitionsForOffsetReset = Collections.emptySet(); + private Long timeout = null; + + private final Map>> queue = new HashMap<>(); + + StateMachineTask(final TaskId id, + final Set partitions, + final boolean active) { + this(id, partitions, active, null); + } + + StateMachineTask(final TaskId id, + final Set partitions, + final boolean active, + final ProcessorStateManager processorStateManager) { + super(id, null, null, processorStateManager, partitions, 0L, "test-task", StateMachineTask.class); + this.active = active; + } + + @Override + public void initializeIfNeeded() { + if (state() == State.CREATED) { + transitionTo(State.RESTORING); + if (!active) { + transitionTo(State.RUNNING); + } + } + } + + @Override + public void addPartitionsForOffsetReset(final Set partitionsForOffsetReset) { + this.partitionsForOffsetReset = partitionsForOffsetReset; + } + + @Override + public void completeRestoration(final java.util.function.Consumer> offsetResetter) { + if (state() == State.RUNNING) { + return; + } + transitionTo(State.RUNNING); + } + + public void setCommitNeeded() { + commitNeeded = true; + } + + @Override + public boolean commitNeeded() { + return commitNeeded; + } + + public void setCommitRequested() { + commitRequested = true; + } + + @Override + public boolean commitRequested() { + return commitRequested; + } + + @Override + public Map 
prepareCommit() { + commitPrepared = true; + + if (commitNeeded) { + return committableOffsets; + } else { + return Collections.emptyMap(); + } + } + + @Override + public void postCommit(final boolean enforceCheckpoint) { + commitNeeded = false; + commitCompleted = true; + } + + @Override + public void suspend() { + if (state() == State.CLOSED) { + throw new IllegalStateException("Illegal state " + state() + " while suspending active task " + id); + } else if (state() == State.SUSPENDED) { + // do nothing + } else { + transitionTo(State.SUSPENDED); + } + } + + @Override + public void resume() { + if (state() == State.SUSPENDED) { + transitionTo(State.RUNNING); + } + } + + @Override + public void revive() { + //TODO: KAFKA-12569 move clearing of commit-required statuses to closeDirty/Clean/AndRecycle methods + commitNeeded = false; + commitRequested = false; + super.revive(); + } + + @Override + public void maybeInitTaskTimeoutOrThrow(final long currentWallClockMs, + final Exception cause) { + timeout = currentWallClockMs; + } + + @Override + public void clearTaskTimeout() { + timeout = null; + } + + @Override + public void closeClean() { + transitionTo(State.CLOSED); + } + + @Override + public void closeDirty() { + transitionTo(State.CLOSED); + } + + @Override + public void closeCleanAndRecycleState() { + transitionTo(State.CLOSED); + } + + @Override + public void updateInputPartitions(final Set topicPartitions, final Map> allTopologyNodesToSourceTopics) { + inputPartitions = topicPartitions; + } + + void setCommittableOffsetsAndMetadata(final Map committableOffsets) { + if (!active) { + throw new IllegalStateException("Cannot set CommittableOffsetsAndMetadate for StandbyTasks"); + } + this.committableOffsets = committableOffsets; + } + + @Override + public StateStore getStore(final String name) { + return null; + } + + @Override + public Collection changelogPartitions() { + return changelogOffsets.keySet(); + } + + public boolean isActive() { + return active; + } + + void setPurgeableOffsets(final Map purgeableOffsets) { + this.purgeableOffsets = purgeableOffsets; + } + + @Override + public Map purgeableOffsets() { + return purgeableOffsets; + } + + void setChangelogOffsets(final Map changelogOffsets) { + this.changelogOffsets = changelogOffsets; + } + + @Override + public Map changelogOffsets() { + return changelogOffsets; + } + + @Override + public Map committedOffsets() { + return Collections.emptyMap(); + } + + @Override + public Map highWaterMark() { + return Collections.emptyMap(); + } + + @Override + public Optional timeCurrentIdlingStarted() { + return Optional.empty(); + } + + @Override + public void addRecords(final TopicPartition partition, final Iterable> records) { + if (isActive()) { + final Deque> partitionQueue = + queue.computeIfAbsent(partition, k -> new LinkedList<>()); + + for (final ConsumerRecord record : records) { + partitionQueue.add(record); + } + } else { + throw new IllegalStateException("Can't add records to an inactive task."); + } + } + + @Override + public boolean process(final long wallClockTime) { + if (isActive() && state() == State.RUNNING) { + for (final LinkedList> records : queue.values()) { + final ConsumerRecord record = records.poll(); + if (record != null) { + return true; + } + } + return false; + } else { + throw new IllegalStateException("Can't process an inactive or non-running task."); + } + } + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskMetadataImplTest.java 
b/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskMetadataImplTest.java new file mode 100644 index 0000000..dfe5daf --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskMetadataImplTest.java @@ -0,0 +1,156 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.processor.internals; + +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.streams.TaskMetadata; +import org.apache.kafka.streams.processor.TaskId; +import org.junit.Before; +import org.junit.Test; + +import java.util.Collection; +import java.util.Map; +import java.util.Optional; +import java.util.Set; + +import static org.apache.kafka.common.utils.Utils.mkEntry; +import static org.apache.kafka.common.utils.Utils.mkMap; +import static org.apache.kafka.common.utils.Utils.mkSet; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.not; + +public class TaskMetadataImplTest { + + public static final TaskId TASK_ID = new TaskId(1, 2); + public static final TopicPartition TP_0 = new TopicPartition("t", 0); + public static final TopicPartition TP_1 = new TopicPartition("t", 1); + public static final Set TOPIC_PARTITIONS = mkSet(TP_0, TP_1); + public static final Map COMMITTED_OFFSETS = mkMap(mkEntry(TP_1, 1L), mkEntry(TP_1, 2L)); + public static final Map END_OFFSETS = mkMap(mkEntry(TP_1, 1L), mkEntry(TP_1, 3L)); + public static final Optional TIME_CURRENT_IDLING_STARTED = Optional.of(3L); + + private TaskMetadata taskMetadata; + + @Before + public void setUp() { + taskMetadata = new TaskMetadataImpl( + TASK_ID, + TOPIC_PARTITIONS, + COMMITTED_OFFSETS, + END_OFFSETS, + TIME_CURRENT_IDLING_STARTED); + } + + @Test + public void shouldNotAllowModificationOfInternalStateViaGetters() { + assertThat(isUnmodifiable(taskMetadata.topicPartitions()), is(true)); + assertThat(isUnmodifiable(taskMetadata.committedOffsets()), is(true)); + assertThat(isUnmodifiable(taskMetadata.endOffsets()), is(true)); + } + + @Test + public void shouldBeEqualsIfSameObject() { + final TaskMetadataImpl same = new TaskMetadataImpl( + TASK_ID, + TOPIC_PARTITIONS, + COMMITTED_OFFSETS, + END_OFFSETS, + TIME_CURRENT_IDLING_STARTED); + assertThat(taskMetadata, equalTo(same)); + assertThat(taskMetadata.hashCode(), equalTo(same.hashCode())); + } + + @Test + public void shouldBeEqualsIfOnlyDifferInCommittedOffsets() { + final TaskMetadataImpl stillSameDifferCommittedOffsets = new TaskMetadataImpl( + TASK_ID, + TOPIC_PARTITIONS, + mkMap(mkEntry(TP_1, 1000000L), mkEntry(TP_1, 2L)), + END_OFFSETS, + TIME_CURRENT_IDLING_STARTED); + assertThat(taskMetadata, equalTo(stillSameDifferCommittedOffsets)); + 
assertThat(taskMetadata.hashCode(), equalTo(stillSameDifferCommittedOffsets.hashCode())); + } + + @Test + public void shouldBeEqualsIfOnlyDifferInEndOffsets() { + final TaskMetadataImpl stillSameDifferEndOffsets = new TaskMetadataImpl( + TASK_ID, + TOPIC_PARTITIONS, + COMMITTED_OFFSETS, + mkMap(mkEntry(TP_1, 1000000L), mkEntry(TP_1, 2L)), + TIME_CURRENT_IDLING_STARTED); + assertThat(taskMetadata, equalTo(stillSameDifferEndOffsets)); + assertThat(taskMetadata.hashCode(), equalTo(stillSameDifferEndOffsets.hashCode())); + } + + @Test + public void shouldBeEqualsIfOnlyDifferInIdlingTime() { + final TaskMetadataImpl stillSameDifferIdlingTime = new TaskMetadataImpl( + TASK_ID, + TOPIC_PARTITIONS, + COMMITTED_OFFSETS, + END_OFFSETS, + Optional.empty()); + assertThat(taskMetadata, equalTo(stillSameDifferIdlingTime)); + assertThat(taskMetadata.hashCode(), equalTo(stillSameDifferIdlingTime.hashCode())); + } + + @Test + public void shouldNotBeEqualsIfDifferInTaskID() { + final TaskMetadataImpl differTaskId = new TaskMetadataImpl( + new TaskId(1, 10000), + TOPIC_PARTITIONS, + COMMITTED_OFFSETS, + END_OFFSETS, + TIME_CURRENT_IDLING_STARTED); + assertThat(taskMetadata, not(equalTo(differTaskId))); + assertThat(taskMetadata.hashCode(), not(equalTo(differTaskId.hashCode()))); + } + + @Test + public void shouldNotBeEqualsIfDifferInTopicPartitions() { + final TaskMetadataImpl differTopicPartitions = new TaskMetadataImpl( + TASK_ID, + mkSet(TP_0), + COMMITTED_OFFSETS, + END_OFFSETS, + TIME_CURRENT_IDLING_STARTED); + assertThat(taskMetadata, not(equalTo(differTopicPartitions))); + assertThat(taskMetadata.hashCode(), not(equalTo(differTopicPartitions.hashCode()))); + } + + private static boolean isUnmodifiable(final Collection collection) { + try { + collection.clear(); + return false; + } catch (final UnsupportedOperationException e) { + return true; + } + } + + private static boolean isUnmodifiable(final Map collection) { + try { + collection.clear(); + return false; + } catch (final UnsupportedOperationException e) { + return true; + } + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskSuite.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskSuite.java new file mode 100644 index 0000000..25c4d71 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskSuite.java @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.processor.internals; + +import org.apache.kafka.streams.integration.StandbyTaskCreationIntegrationTest; +import org.apache.kafka.streams.processor.internals.metrics.TaskMetricsTest; +import org.apache.kafka.streams.processor.internals.assignment.StickyTaskAssignorTest; +import org.junit.runner.RunWith; +import org.junit.runners.Suite; + +/** + * This suite runs all the tests related to task management. It's intended to simplify feature testing from IDEs. + * + * If desired, it can also be added to a Gradle build task, although this isn't strictly necessary, since all + * these tests are already included in the `:streams:test` task. + */ +@RunWith(Suite.class) +@Suite.SuiteClasses({ + StreamTaskTest.class, + StandbyTaskTest.class, + GlobalStateTaskTest.class, + TaskManagerTest.class, + TaskMetricsTest.class, + StickyTaskAssignorTest.class, + StreamsPartitionAssignorTest.class, + StandbyTaskCreationIntegrationTest.class, + }) +public class TaskSuite { +} + + diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/ThreadMetadataImplTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/ThreadMetadataImplTest.java new file mode 100644 index 0000000..b87f662 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/ThreadMetadataImplTest.java @@ -0,0 +1,244 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.processor.internals; + +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.streams.TaskMetadata; +import org.apache.kafka.streams.ThreadMetadata; +import org.apache.kafka.streams.processor.TaskId; +import org.junit.Before; +import org.junit.Test; + +import java.util.Collection; +import java.util.Optional; +import java.util.Set; + +import static org.apache.kafka.common.utils.Utils.mkEntry; +import static org.apache.kafka.common.utils.Utils.mkMap; +import static org.apache.kafka.common.utils.Utils.mkSet; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.not; + +public class ThreadMetadataImplTest { + + public static final String THREAD_NAME = "thread name"; + public static final String THREAD_STATE = "thread state"; + public static final String MAIN_CONSUMER_CLIENT_ID = "main Consumer ClientID"; + public static final String RESTORE_CONSUMER_CLIENT_ID = "restore Consumer ClientID"; + public static final String CLIENT_ID_1 = "client Id 1"; + public static final String CLIENT_ID_2 = "client Id 2"; + public static final Set PRODUCER_CLIENT_IDS = mkSet(CLIENT_ID_1, CLIENT_ID_2); + public static final TaskId TASK_ID_0 = new TaskId(1, 2); + public static final TaskId TASK_ID_1 = new TaskId(1, 1); + public static final TopicPartition TP_0_0 = new TopicPartition("t", 0); + public static final TopicPartition TP_1_0 = new TopicPartition("t", 1); + public static final TopicPartition TP_0_1 = new TopicPartition("t", 2); + public static final TopicPartition TP_1_1 = new TopicPartition("t", 3); + public static final TaskMetadata TM_0 = new TaskMetadataImpl( + TASK_ID_0, + mkSet(TP_0_0, TP_1_0), + mkMap(mkEntry(TP_0_0, 1L), mkEntry(TP_1_0, 2L)), + mkMap(mkEntry(TP_0_0, 1L), mkEntry(TP_1_0, 2L)), + Optional.of(3L)); + public static final TaskMetadata TM_1 = new TaskMetadataImpl( + TASK_ID_1, + mkSet(TP_0_1, TP_1_1), + mkMap(mkEntry(TP_0_1, 1L), mkEntry(TP_1_1, 2L)), + mkMap(mkEntry(TP_0_1, 1L), mkEntry(TP_1_1, 2L)), + Optional.of(3L)); + public static final Set STANDBY_TASKS = mkSet(TM_0, TM_1); + public static final Set ACTIVE_TASKS = mkSet(TM_1); + public static final String ADMIN_CLIENT_ID = "admin ClientID"; + + private ThreadMetadata threadMetadata; + + @Before + public void setUp() { + threadMetadata = new ThreadMetadataImpl( + THREAD_NAME, + THREAD_STATE, + MAIN_CONSUMER_CLIENT_ID, + RESTORE_CONSUMER_CLIENT_ID, + PRODUCER_CLIENT_IDS, + ADMIN_CLIENT_ID, + ACTIVE_TASKS, + STANDBY_TASKS + ); + } + + @Test + public void shouldNotAllowModificationOfInternalStateViaGetters() { + assertThat(isUnmodifiable(threadMetadata.producerClientIds()), is(true)); + assertThat(isUnmodifiable(threadMetadata.activeTasks()), is(true)); + assertThat(isUnmodifiable(threadMetadata.standbyTasks()), is(true)); + } + + @Test + public void shouldBeEqualIfSameObject() { + final ThreadMetadata same = new ThreadMetadataImpl( + THREAD_NAME, + THREAD_STATE, + MAIN_CONSUMER_CLIENT_ID, + RESTORE_CONSUMER_CLIENT_ID, + PRODUCER_CLIENT_IDS, + ADMIN_CLIENT_ID, + ACTIVE_TASKS, + STANDBY_TASKS + ); + assertThat(threadMetadata, equalTo(same)); + assertThat(threadMetadata.hashCode(), equalTo(same.hashCode())); + } + + @Test + public void shouldNotBeEqualIfDifferInThreadName() { + final ThreadMetadata differThreadName = new ThreadMetadataImpl( + "different", + THREAD_STATE, + MAIN_CONSUMER_CLIENT_ID, + RESTORE_CONSUMER_CLIENT_ID, + PRODUCER_CLIENT_IDS, + 
ADMIN_CLIENT_ID, + ACTIVE_TASKS, + STANDBY_TASKS + ); + assertThat(threadMetadata, not(equalTo(differThreadName))); + assertThat(threadMetadata.hashCode(), not(equalTo(differThreadName.hashCode()))); + } + + @Test + public void shouldNotBeEqualIfDifferInThreadState() { + final ThreadMetadata differThreadState = new ThreadMetadataImpl( + THREAD_NAME, + "different", + MAIN_CONSUMER_CLIENT_ID, + RESTORE_CONSUMER_CLIENT_ID, + PRODUCER_CLIENT_IDS, + ADMIN_CLIENT_ID, + ACTIVE_TASKS, + STANDBY_TASKS + ); + assertThat(threadMetadata, not(equalTo(differThreadState))); + assertThat(threadMetadata.hashCode(), not(equalTo(differThreadState.hashCode()))); + } + + @Test + public void shouldNotBeEqualIfDifferInClientId() { + final ThreadMetadata differMainConsumerClientId = new ThreadMetadataImpl( + THREAD_NAME, + THREAD_STATE, + "different", + RESTORE_CONSUMER_CLIENT_ID, + PRODUCER_CLIENT_IDS, + ADMIN_CLIENT_ID, + ACTIVE_TASKS, + STANDBY_TASKS + ); + assertThat(threadMetadata, not(equalTo(differMainConsumerClientId))); + assertThat(threadMetadata.hashCode(), not(equalTo(differMainConsumerClientId.hashCode()))); + } + + @Test + public void shouldNotBeEqualIfDifferInConsumerClientId() { + final ThreadMetadata differRestoreConsumerClientId = new ThreadMetadataImpl( + THREAD_NAME, + THREAD_STATE, + MAIN_CONSUMER_CLIENT_ID, + "different", + PRODUCER_CLIENT_IDS, + ADMIN_CLIENT_ID, + ACTIVE_TASKS, + STANDBY_TASKS + ); + assertThat(threadMetadata, not(equalTo(differRestoreConsumerClientId))); + assertThat(threadMetadata.hashCode(), not(equalTo(differRestoreConsumerClientId.hashCode()))); + } + + @Test + public void shouldNotBeEqualIfDifferInProducerClientIds() { + final ThreadMetadata differProducerClientIds = new ThreadMetadataImpl( + THREAD_NAME, + THREAD_STATE, + MAIN_CONSUMER_CLIENT_ID, + RESTORE_CONSUMER_CLIENT_ID, + mkSet(CLIENT_ID_1), + ADMIN_CLIENT_ID, + ACTIVE_TASKS, + STANDBY_TASKS + ); + assertThat(threadMetadata, not(equalTo(differProducerClientIds))); + assertThat(threadMetadata.hashCode(), not(equalTo(differProducerClientIds.hashCode()))); + } + + @Test + public void shouldNotBeEqualIfDifferInAdminClientId() { + final ThreadMetadata differAdminClientId = new ThreadMetadataImpl( + THREAD_NAME, + THREAD_STATE, + MAIN_CONSUMER_CLIENT_ID, + RESTORE_CONSUMER_CLIENT_ID, + PRODUCER_CLIENT_IDS, + "different", + ACTIVE_TASKS, + STANDBY_TASKS + ); + assertThat(threadMetadata, not(equalTo(differAdminClientId))); + assertThat(threadMetadata.hashCode(), not(equalTo(differAdminClientId.hashCode()))); + } + + @Test + public void shouldNotBeEqualIfDifferInActiveTasks() { + final ThreadMetadata differActiveTasks = new ThreadMetadataImpl( + THREAD_NAME, + THREAD_STATE, + MAIN_CONSUMER_CLIENT_ID, + RESTORE_CONSUMER_CLIENT_ID, + PRODUCER_CLIENT_IDS, + ADMIN_CLIENT_ID, + mkSet(TM_0), + STANDBY_TASKS + ); + assertThat(threadMetadata, not(equalTo(differActiveTasks))); + assertThat(threadMetadata.hashCode(), not(equalTo(differActiveTasks.hashCode()))); + } + + @Test + public void shouldNotBeEqualIfDifferInStandByTasks() { + final ThreadMetadata differStandByTasks = new ThreadMetadataImpl( + THREAD_NAME, + THREAD_STATE, + MAIN_CONSUMER_CLIENT_ID, + RESTORE_CONSUMER_CLIENT_ID, + PRODUCER_CLIENT_IDS, + ADMIN_CLIENT_ID, + ACTIVE_TASKS, + mkSet(TM_0) + ); + assertThat(threadMetadata, not(equalTo(differStandByTasks))); + assertThat(threadMetadata.hashCode(), not(equalTo(differStandByTasks.hashCode()))); + } + + private static boolean isUnmodifiable(final Collection collection) { + try { + collection.clear(); + return false; + 
} catch (final UnsupportedOperationException e) { + return true; + } + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/TimestampedKeyValueStoreMaterializerTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/TimestampedKeyValueStoreMaterializerTest.java new file mode 100644 index 0000000..77f4ab8 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/TimestampedKeyValueStoreMaterializerTest.java @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.streams.processor.internals; + +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.streams.kstream.Materialized; +import org.apache.kafka.streams.kstream.internals.InternalNameProvider; +import org.apache.kafka.streams.kstream.internals.MaterializedInternal; +import org.apache.kafka.streams.kstream.internals.TimestampedKeyValueStoreMaterializer; +import org.apache.kafka.streams.processor.StateStore; +import org.apache.kafka.streams.state.KeyValueBytesStoreSupplier; +import org.apache.kafka.streams.state.KeyValueStore; +import org.apache.kafka.streams.state.StoreBuilder; +import org.apache.kafka.streams.state.TimestampedKeyValueStore; +import org.apache.kafka.streams.state.internals.CachingKeyValueStore; +import org.apache.kafka.streams.state.internals.ChangeLoggingKeyValueBytesStore; +import org.apache.kafka.streams.state.internals.ChangeLoggingTimestampedKeyValueBytesStore; +import org.apache.kafka.streams.state.internals.InMemoryKeyValueStore; +import org.apache.kafka.streams.state.internals.MeteredTimestampedKeyValueStore; +import org.apache.kafka.streams.state.internals.WrappedStateStore; +import org.easymock.EasyMock; +import org.easymock.EasyMockRunner; +import org.easymock.Mock; +import org.easymock.MockType; +import org.hamcrest.CoreMatchers; +import org.junit.Test; +import org.junit.runner.RunWith; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.core.IsInstanceOf.instanceOf; +import static org.hamcrest.core.IsNot.not; + +@RunWith(EasyMockRunner.class) +public class TimestampedKeyValueStoreMaterializerTest { + + private final String storePrefix = "prefix"; + @Mock(type = MockType.NICE) + private InternalNameProvider nameProvider; + + @Test + public void shouldCreateBuilderThatBuildsMeteredStoreWithCachingAndLoggingEnabled() { + final MaterializedInternal> materialized = + new MaterializedInternal<>(Materialized.as("store"), nameProvider, storePrefix); + + final TimestampedKeyValueStoreMaterializer materializer = new TimestampedKeyValueStoreMaterializer<>(materialized); + final StoreBuilder> builder = materializer.materialize(); + final TimestampedKeyValueStore store = builder.build(); + final 
WrappedStateStore caching = (WrappedStateStore) ((WrappedStateStore) store).wrapped(); + final StateStore logging = caching.wrapped(); + assertThat(store, instanceOf(MeteredTimestampedKeyValueStore.class)); + assertThat(caching, instanceOf(CachingKeyValueStore.class)); + assertThat(logging, instanceOf(ChangeLoggingTimestampedKeyValueBytesStore.class)); + } + + @Test + public void shouldCreateBuilderThatBuildsStoreWithCachingDisabled() { + final MaterializedInternal<String, String, KeyValueStore<Bytes, byte[]>> materialized = new MaterializedInternal<>( + Materialized.<String, String, KeyValueStore<Bytes, byte[]>>as("store").withCachingDisabled(), nameProvider, storePrefix + ); + final TimestampedKeyValueStoreMaterializer<String, String> materializer = new TimestampedKeyValueStoreMaterializer<>(materialized); + final StoreBuilder<TimestampedKeyValueStore<String, String>> builder = materializer.materialize(); + final TimestampedKeyValueStore<String, String> store = builder.build(); + final WrappedStateStore logging = (WrappedStateStore) ((WrappedStateStore) store).wrapped(); + assertThat(logging, instanceOf(ChangeLoggingKeyValueBytesStore.class)); + } + + @Test + public void shouldCreateBuilderThatBuildsStoreWithLoggingDisabled() { + final MaterializedInternal<String, String, KeyValueStore<Bytes, byte[]>> materialized = new MaterializedInternal<>( + Materialized.<String, String, KeyValueStore<Bytes, byte[]>>as("store").withLoggingDisabled(), nameProvider, storePrefix + ); + final TimestampedKeyValueStoreMaterializer<String, String> materializer = new TimestampedKeyValueStoreMaterializer<>(materialized); + final StoreBuilder<TimestampedKeyValueStore<String, String>> builder = materializer.materialize(); + final TimestampedKeyValueStore<String, String> store = builder.build(); + final WrappedStateStore caching = (WrappedStateStore) ((WrappedStateStore) store).wrapped(); + assertThat(caching, instanceOf(CachingKeyValueStore.class)); + assertThat(caching.wrapped(), not(instanceOf(ChangeLoggingKeyValueBytesStore.class))); + } + + @Test + public void shouldCreateBuilderThatBuildsStoreWithCachingAndLoggingDisabled() { + final MaterializedInternal<String, String, KeyValueStore<Bytes, byte[]>> materialized = new MaterializedInternal<>( + Materialized.<String, String, KeyValueStore<Bytes, byte[]>>as("store").withCachingDisabled().withLoggingDisabled(), nameProvider, storePrefix + ); + final TimestampedKeyValueStoreMaterializer<String, String> materializer = new TimestampedKeyValueStoreMaterializer<>(materialized); + final StoreBuilder<TimestampedKeyValueStore<String, String>> builder = materializer.materialize(); + final TimestampedKeyValueStore<String, String> store = builder.build(); + final StateStore wrapped = ((WrappedStateStore) store).wrapped(); + assertThat(wrapped, not(instanceOf(CachingKeyValueStore.class))); + assertThat(wrapped, not(instanceOf(ChangeLoggingKeyValueBytesStore.class))); + } + + @Test + public void shouldCreateKeyValueStoreWithTheProvidedInnerStore() { + final KeyValueBytesStoreSupplier supplier = EasyMock.createNiceMock(KeyValueBytesStoreSupplier.class); + final InMemoryKeyValueStore store = new InMemoryKeyValueStore("name"); + EasyMock.expect(supplier.name()).andReturn("name").anyTimes(); + EasyMock.expect(supplier.get()).andReturn(store); + EasyMock.expect(supplier.metricsScope()).andReturn("metricScope"); + EasyMock.replay(supplier); + + final MaterializedInternal<String, String, KeyValueStore<Bytes, byte[]>> materialized = + new MaterializedInternal<>(Materialized.as(supplier), nameProvider, storePrefix); + final TimestampedKeyValueStoreMaterializer<String, String> materializer = new TimestampedKeyValueStoreMaterializer<>(materialized); + final StoreBuilder<TimestampedKeyValueStore<String, String>> builder = materializer.materialize(); + final TimestampedKeyValueStore<String, String> built = builder.build(); + + assertThat(store.name(), CoreMatchers.equalTo(built.name())); + } + +} \ No newline at end of file diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/AssignmentInfoTest.java
b/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/AssignmentInfoTest.java new file mode 100644 index 0000000..a4c9534 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/AssignmentInfoTest.java @@ -0,0 +1,243 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.processor.internals.assignment; + +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.streams.errors.TaskAssignmentException; +import org.apache.kafka.streams.processor.TaskId; +import org.apache.kafka.streams.state.HostInfo; +import org.junit.Test; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import static org.apache.kafka.common.utils.Utils.mkEntry; +import static org.apache.kafka.common.utils.Utils.mkMap; +import static org.apache.kafka.common.utils.Utils.mkSet; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.NAMED_TASK_T0_0_0; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.NAMED_TASK_T0_0_1; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.NAMED_TASK_T0_1_0; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.NAMED_TASK_T1_0_0; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.NAMED_TASK_T1_0_1; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.NAMED_TASK_T2_0_0; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.NAMED_TASK_T2_2_0; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_0_0; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_0_1; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_1_0; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_1_1; +import static org.apache.kafka.streams.processor.internals.assignment.StreamsAssignmentProtocolVersions.LATEST_SUPPORTED_VERSION; +import static org.apache.kafka.streams.processor.internals.assignment.StreamsAssignmentProtocolVersions.MIN_NAMED_TOPOLOGY_VERSION; +import static org.apache.kafka.streams.processor.internals.assignment.StreamsAssignmentProtocolVersions.UNKNOWN; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; + +public class AssignmentInfoTest { + private final List activeTasks = Arrays.asList( + TASK_0_0, + TASK_0_1, + TASK_1_0, + TASK_1_0 + ); + + private final Map> standbyTasks = mkMap( + 
mkEntry(TASK_1_0, mkSet(new TopicPartition("t1", 0), new TopicPartition("t2", 0))), + mkEntry(TASK_1_1, mkSet(new TopicPartition("t1", 1), new TopicPartition("t2", 1))) + ); + + private static final List NAMED_ACTIVE_TASKS = Arrays.asList( + NAMED_TASK_T0_0_1, + NAMED_TASK_T0_1_0, + NAMED_TASK_T0_1_0, + NAMED_TASK_T1_0_1, + NAMED_TASK_T1_0_1, + NAMED_TASK_T2_0_0, + NAMED_TASK_T2_2_0 + ); + + private static final Map> NAMED_STANDBY_TASKS = mkMap( + mkEntry(NAMED_TASK_T0_0_0, mkSet(new TopicPartition("t0-1", 0), new TopicPartition("t0-2", 0))), + mkEntry(NAMED_TASK_T0_0_1, mkSet(new TopicPartition("t0-1", 1), new TopicPartition("t0-2", 1))), + mkEntry(NAMED_TASK_T1_0_0, mkSet(new TopicPartition("t1-1", 0), new TopicPartition("t1-2", 0))) + ); + + private final Map> activeAssignment = mkMap( + mkEntry(new HostInfo("localhost", 8088), + mkSet(new TopicPartition("t0", 0), + new TopicPartition("t1", 0), + new TopicPartition("t2", 0))), + mkEntry(new HostInfo("localhost", 8089), + mkSet(new TopicPartition("t0", 1), + new TopicPartition("t1", 1), + new TopicPartition("t2", 1))) + ); + + private final Map> standbyAssignment = mkMap( + mkEntry(new HostInfo("localhost", 8088), + mkSet(new TopicPartition("t1", 0), + new TopicPartition("t2", 0))), + mkEntry(new HostInfo("localhost", 8089), + mkSet(new TopicPartition("t1", 1), + new TopicPartition("t2", 1))) + ); + + @Test + public void shouldUseLatestSupportedVersionByDefault() { + final AssignmentInfo info = new AssignmentInfo(LATEST_SUPPORTED_VERSION, activeTasks, standbyTasks, activeAssignment, standbyAssignment, 0); + assertEquals(LATEST_SUPPORTED_VERSION, info.version()); + } + + @Test + public void shouldThrowForUnknownVersion1() { + assertThrows(IllegalArgumentException.class, () -> new AssignmentInfo(0, activeTasks, standbyTasks, + activeAssignment, Collections.emptyMap(), 0)); + } + + @Test + public void shouldThrowForUnknownVersion2() { + assertThrows(IllegalArgumentException.class, () -> new AssignmentInfo(LATEST_SUPPORTED_VERSION + 1, + activeTasks, standbyTasks, activeAssignment, Collections.emptyMap(), 0)); + } + + @Test + public void shouldEncodeAndDecodeVersion1() { + final AssignmentInfo info = new AssignmentInfo(1, activeTasks, standbyTasks, activeAssignment, standbyAssignment, 0); + final AssignmentInfo expectedInfo = new AssignmentInfo(1, UNKNOWN, activeTasks, standbyTasks, Collections.emptyMap(), Collections.emptyMap(), 0); + assertEquals(expectedInfo, AssignmentInfo.decode(info.encode())); + } + + @Test + public void shouldEncodeAndDecodeVersion2() { + final AssignmentInfo info = new AssignmentInfo(2, activeTasks, standbyTasks, activeAssignment, standbyAssignment, 0); + final AssignmentInfo expectedInfo = new AssignmentInfo(2, UNKNOWN, activeTasks, standbyTasks, activeAssignment, Collections.emptyMap(), 0); + assertEquals(expectedInfo, AssignmentInfo.decode(info.encode())); + } + + @Test + public void shouldEncodeAndDecodeVersion3() { + final AssignmentInfo info = new AssignmentInfo(3, activeTasks, standbyTasks, activeAssignment, standbyAssignment, 0); + final AssignmentInfo expectedInfo = new AssignmentInfo(3, LATEST_SUPPORTED_VERSION, activeTasks, standbyTasks, + activeAssignment, Collections.emptyMap(), 0); + assertEquals(expectedInfo, AssignmentInfo.decode(info.encode())); + } + + @Test + public void shouldEncodeAndDecodeVersion4() { + final AssignmentInfo info = new AssignmentInfo(4, activeTasks, standbyTasks, activeAssignment, standbyAssignment, 2); + final AssignmentInfo expectedInfo = new AssignmentInfo(4, 
LATEST_SUPPORTED_VERSION, activeTasks, standbyTasks, + activeAssignment, Collections.emptyMap(), 2); + assertEquals(expectedInfo, AssignmentInfo.decode(info.encode())); + } + + @Test + public void shouldEncodeAndDecodeVersion5() { + final AssignmentInfo info = new AssignmentInfo(5, activeTasks, standbyTasks, activeAssignment, standbyAssignment, 2); + final AssignmentInfo expectedInfo = new AssignmentInfo(5, LATEST_SUPPORTED_VERSION, activeTasks, standbyTasks, + activeAssignment, Collections.emptyMap(), 2); + assertEquals(expectedInfo, AssignmentInfo.decode(info.encode())); + } + + @Test + public void shouldEncodeAndDecodeVersion6() { + final AssignmentInfo info = new AssignmentInfo(6, activeTasks, standbyTasks, activeAssignment, standbyAssignment, 2); + final AssignmentInfo expectedInfo = new AssignmentInfo(6, LATEST_SUPPORTED_VERSION, activeTasks, standbyTasks, + activeAssignment, standbyAssignment, 2); + assertEquals(expectedInfo, AssignmentInfo.decode(info.encode())); + } + + @Test + public void shouldEncodeAndDecodeVersion7() { + final AssignmentInfo info = + new AssignmentInfo(7, activeTasks, standbyTasks, activeAssignment, standbyAssignment, 2); + final AssignmentInfo expectedInfo = + new AssignmentInfo(7, LATEST_SUPPORTED_VERSION, activeTasks, standbyTasks, activeAssignment, standbyAssignment, 2); + assertEquals(expectedInfo, AssignmentInfo.decode(info.encode())); + } + + @Test + public void shouldEncodeAndDecodeVersion8() { + final AssignmentInfo info = + new AssignmentInfo(8, activeTasks, standbyTasks, activeAssignment, standbyAssignment, 2); + final AssignmentInfo expectedInfo = + new AssignmentInfo(8, LATEST_SUPPORTED_VERSION, activeTasks, standbyTasks, activeAssignment, standbyAssignment, 2); + assertEquals(expectedInfo, AssignmentInfo.decode(info.encode())); + } + + @Test + public void shouldEncodeAndDecodeVersion9() { + final AssignmentInfo info = + new AssignmentInfo(9, activeTasks, standbyTasks, activeAssignment, standbyAssignment, 2); + final AssignmentInfo expectedInfo = + new AssignmentInfo(9, LATEST_SUPPORTED_VERSION, activeTasks, standbyTasks, activeAssignment, standbyAssignment, 2); + assertEquals(expectedInfo, AssignmentInfo.decode(info.encode())); + } + + @Test + public void shouldEncodeAndDecodeVersion10() { + final AssignmentInfo info = + new AssignmentInfo(10, activeTasks, standbyTasks, activeAssignment, standbyAssignment, 2); + final AssignmentInfo expectedInfo = + new AssignmentInfo(10, LATEST_SUPPORTED_VERSION, activeTasks, standbyTasks, activeAssignment, standbyAssignment, 2); + assertEquals(expectedInfo, AssignmentInfo.decode(info.encode())); + } + + @Test + public void shouldEncodeAndDecodeVersion10WithNamedTopologies() { + final AssignmentInfo info = + new AssignmentInfo(10, LATEST_SUPPORTED_VERSION, NAMED_ACTIVE_TASKS, NAMED_STANDBY_TASKS, activeAssignment, standbyAssignment, 2); + assertEquals(info, AssignmentInfo.decode(info.encode())); + } + + @Test + public void shouldNotEncodeAndDecodeNamedTopologiesWithOlderVersion() { + final AssignmentInfo info = + new AssignmentInfo(MIN_NAMED_TOPOLOGY_VERSION - 1, LATEST_SUPPORTED_VERSION, NAMED_ACTIVE_TASKS, NAMED_STANDBY_TASKS, activeAssignment, standbyAssignment, 2); + assertThrows(TaskAssignmentException.class, () -> AssignmentInfo.decode(info.encode())); + } + + @Test + public void shouldEncodeAndDecodeSmallerCommonlySupportedVersion() { + final int usedVersion = 5; + final int commonlySupportedVersion = 5; + final AssignmentInfo info = new AssignmentInfo(usedVersion, commonlySupportedVersion, activeTasks, 
standbyTasks, + activeAssignment, standbyAssignment, 2); + final AssignmentInfo expectedInfo = new AssignmentInfo(usedVersion, commonlySupportedVersion, activeTasks, standbyTasks, + activeAssignment, Collections.emptyMap(), 2); + assertEquals(expectedInfo, AssignmentInfo.decode(info.encode())); + } + + @Test + public void nextRebalanceTimeShouldBeMaxValueByDefault() { + final AssignmentInfo info = new AssignmentInfo(7, activeTasks, standbyTasks, activeAssignment, standbyAssignment, 0); + assertEquals(info.nextRebalanceMs(), Long.MAX_VALUE); + } + + @Test + public void shouldDecodeDefaultNextRebalanceTime() { + final AssignmentInfo info = new AssignmentInfo(7, activeTasks, standbyTasks, activeAssignment, standbyAssignment, 0); + assertEquals(AssignmentInfo.decode(info.encode()).nextRebalanceMs(), Long.MAX_VALUE); + } + + @Test + public void shouldEncodeAndDecodeNextRebalanceTime() { + final AssignmentInfo info = new AssignmentInfo(7, activeTasks, standbyTasks, activeAssignment, standbyAssignment, 0); + info.setNextRebalanceTime(1000L); + assertEquals(1000L, AssignmentInfo.decode(info.encode()).nextRebalanceMs()); + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/AssignmentTestUtils.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/AssignmentTestUtils.java new file mode 100644 index 0000000..38669da --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/AssignmentTestUtils.java @@ -0,0 +1,482 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ +package org.apache.kafka.streams.processor.internals.assignment; + +import java.util.Collection; +import java.util.Map.Entry; +import org.apache.kafka.clients.admin.AdminClient; +import org.apache.kafka.clients.admin.ListOffsetsResult; +import org.apache.kafka.clients.admin.ListOffsetsResult.ListOffsetsResultInfo; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.internals.KafkaFutureImpl; +import org.apache.kafka.streams.processor.TaskId; +import org.apache.kafka.streams.processor.internals.Task; +import org.apache.kafka.streams.processor.internals.TopologyMetadata.Subtopology; + +import org.easymock.EasyMock; +import org.hamcrest.BaseMatcher; +import org.hamcrest.Description; +import org.hamcrest.Matcher; + +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; +import java.util.Objects; +import java.util.Set; +import java.util.TreeMap; +import java.util.TreeSet; +import java.util.UUID; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Function; +import java.util.stream.Collectors; + +import static java.util.Collections.emptyMap; +import static java.util.Collections.emptySet; +import static org.apache.kafka.common.utils.Utils.entriesToMap; +import static org.apache.kafka.common.utils.Utils.intersection; +import static org.apache.kafka.streams.processor.internals.assignment.StreamsAssignmentProtocolVersions.LATEST_SUPPORTED_VERSION; +import static org.easymock.EasyMock.anyObject; +import static org.easymock.EasyMock.expect; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.is; +import static org.junit.Assert.fail; + +public final class AssignmentTestUtils { + + public static final UUID UUID_1 = uuidForInt(1); + public static final UUID UUID_2 = uuidForInt(2); + public static final UUID UUID_3 = uuidForInt(3); + public static final UUID UUID_4 = uuidForInt(4); + public static final UUID UUID_5 = uuidForInt(5); + public static final UUID UUID_6 = uuidForInt(6); + + public static final TopicPartition TP_0_0 = new TopicPartition("topic0", 0); + public static final TopicPartition TP_0_1 = new TopicPartition("topic0", 1); + public static final TopicPartition TP_0_2 = new TopicPartition("topic0", 2); + public static final TopicPartition TP_1_0 = new TopicPartition("topic1", 0); + public static final TopicPartition TP_1_1 = new TopicPartition("topic1", 1); + public static final TopicPartition TP_1_2 = new TopicPartition("topic1", 2); + + public static final TaskId TASK_0_0 = new TaskId(0, 0); + public static final TaskId TASK_0_1 = new TaskId(0, 1); + public static final TaskId TASK_0_2 = new TaskId(0, 2); + public static final TaskId TASK_0_3 = new TaskId(0, 3); + public static final TaskId TASK_0_4 = new TaskId(0, 4); + public static final TaskId TASK_0_5 = new TaskId(0, 5); + public static final TaskId TASK_0_6 = new TaskId(0, 6); + public static final TaskId TASK_1_0 = new TaskId(1, 0); + public static final TaskId TASK_1_1 = new TaskId(1, 1); + public static final TaskId TASK_1_2 = new TaskId(1, 2); + public static final TaskId TASK_1_3 = new TaskId(1, 3); + public static final TaskId TASK_2_0 = new TaskId(2, 0); + public static final TaskId TASK_2_1 = new TaskId(2, 1); + public static final TaskId TASK_2_2 = new TaskId(2, 2); + public static final TaskId TASK_2_3 = new TaskId(2, 3); + + public static final TaskId NAMED_TASK_T0_0_0 = new TaskId(0, 0, "topology0"); + public static final TaskId NAMED_TASK_T0_0_1 = new TaskId(0, 1, "topology0"); + public static final 
TaskId NAMED_TASK_T0_1_0 = new TaskId(1, 0, "topology0"); + public static final TaskId NAMED_TASK_T0_1_1 = new TaskId(1, 1, "topology0"); + public static final TaskId NAMED_TASK_T1_0_0 = new TaskId(0, 0, "topology1"); + public static final TaskId NAMED_TASK_T1_0_1 = new TaskId(0, 1, "topology1"); + public static final TaskId NAMED_TASK_T2_0_0 = new TaskId(0, 0, "topology2"); + public static final TaskId NAMED_TASK_T2_2_0 = new TaskId(2, 0, "topology2"); + + public static final Subtopology SUBTOPOLOGY_0 = new Subtopology(0, null); + public static final Subtopology SUBTOPOLOGY_1 = new Subtopology(1, null); + public static final Subtopology SUBTOPOLOGY_2 = new Subtopology(2, null); + + public static final Set<TaskId> EMPTY_TASKS = emptySet(); + public static final Map<TopicPartition, Long> EMPTY_CHANGELOG_END_OFFSETS = new HashMap<>(); + + private AssignmentTestUtils() {} + + static Map<UUID, ClientState> getClientStatesMap(final ClientState... states) { + final Map<UUID, ClientState> clientStates = new HashMap<>(); + int nthState = 1; + for (final ClientState state : states) { + clientStates.put(uuidForInt(nthState), state); + ++nthState; + } + return clientStates; + } + + // If you don't care about setting the end offsets for each specific topic partition, the helper method + // getTopicPartitionOffsetMap is useful for building this input map for all partitions + public static AdminClient createMockAdminClientForAssignor(final Map<TopicPartition, Long> changelogEndOffsets) { + final AdminClient adminClient = EasyMock.createMock(AdminClient.class); + + final ListOffsetsResult result = EasyMock.createNiceMock(ListOffsetsResult.class); + final KafkaFutureImpl<Map<TopicPartition, ListOffsetsResultInfo>> allFuture = new KafkaFutureImpl<>(); + allFuture.complete(changelogEndOffsets.entrySet().stream().collect(Collectors.toMap( + Entry::getKey, + t -> { + final ListOffsetsResultInfo info = EasyMock.createNiceMock(ListOffsetsResultInfo.class); + expect(info.offset()).andStubReturn(t.getValue()); + EasyMock.replay(info); + return info; + })) + ); + + expect(adminClient.listOffsets(anyObject())).andStubReturn(result); + expect(result.all()).andStubReturn(allFuture); + + EasyMock.replay(result); + return adminClient; + } + + public static SubscriptionInfo getInfo(final UUID processId, + final Set<TaskId> prevTasks, + final Set<TaskId> standbyTasks) { + return new SubscriptionInfo( + LATEST_SUPPORTED_VERSION, LATEST_SUPPORTED_VERSION, processId, null, getTaskOffsetSums(prevTasks, standbyTasks), (byte) 0, 0); + } + + public static SubscriptionInfo getInfo(final UUID processId, + final Set<TaskId> prevTasks, + final Set<TaskId> standbyTasks, + final String userEndPoint) { + return new SubscriptionInfo( + LATEST_SUPPORTED_VERSION, LATEST_SUPPORTED_VERSION, processId, userEndPoint, getTaskOffsetSums(prevTasks, standbyTasks), (byte) 0, 0); + } + + public static SubscriptionInfo getInfo(final UUID processId, + final Set<TaskId> prevTasks, + final Set<TaskId> standbyTasks, + final byte uniqueField) { + return new SubscriptionInfo( + LATEST_SUPPORTED_VERSION, LATEST_SUPPORTED_VERSION, processId, null, getTaskOffsetSums(prevTasks, standbyTasks), uniqueField, 0); + } + + // Stub offset sums for when we only care about the prev/standby task sets, not the actual offsets + private static Map<TaskId, Long> getTaskOffsetSums(final Collection<TaskId> activeTasks, final Collection<TaskId> standbyTasks) { + final Map<TaskId, Long> taskOffsetSums = activeTasks.stream().collect(Collectors.toMap(t -> t, t -> Task.LATEST_OFFSET)); + taskOffsetSums.putAll(standbyTasks.stream().collect(Collectors.toMap(t -> t, t -> 0L))); + return taskOffsetSums; + } + + /** + * Builds a UUID by repeating the given number n.
For valid n, it is guaranteed that the returned UUIDs satisfy + * the same relation relative to others as their parameter n does: iff n < m, then uuidForInt(n) < uuidForInt(m) + */ + public static UUID uuidForInt(final int n) { + return new UUID(0, n); + } + + static void assertValidAssignment(final int numStandbyReplicas, + final Set statefulTasks, + final Set statelessTasks, + final Map assignedStates, + final StringBuilder failureContext) { + assertValidAssignment( + numStandbyReplicas, + 0, + statefulTasks, + statelessTasks, + assignedStates, + failureContext + ); + } + + static void assertValidAssignment(final int numStandbyReplicas, + final int maxWarmupReplicas, + final Set statefulTasks, + final Set statelessTasks, + final Map assignedStates, + final StringBuilder failureContext) { + final Map> assignments = new TreeMap<>(); + for (final TaskId taskId : statefulTasks) { + assignments.put(taskId, new TreeSet<>()); + } + for (final TaskId taskId : statelessTasks) { + assignments.put(taskId, new TreeSet<>()); + } + for (final Map.Entry entry : assignedStates.entrySet()) { + validateAndAddActiveAssignments(statefulTasks, statelessTasks, failureContext, assignments, entry); + validateAndAddStandbyAssignments(statefulTasks, statelessTasks, failureContext, assignments, entry); + } + + final AtomicInteger remainingWarmups = new AtomicInteger(maxWarmupReplicas); + + final TreeMap> misassigned = + assignments + .entrySet() + .stream() + .filter(entry -> { + final int expectedActives = 1; + final boolean isStateless = statelessTasks.contains(entry.getKey()); + final int expectedStandbys = isStateless ? 0 : numStandbyReplicas; + // We'll never assign even the expected number of standbys if they don't actually fit in the cluster + final int expectedAssignments = Math.min( + assignedStates.size(), + expectedActives + expectedStandbys + ); + final int actualAssignments = entry.getValue().size(); + if (actualAssignments == expectedAssignments) { + return false; // not misassigned + } else { + if (actualAssignments == expectedAssignments + 1 && remainingWarmups.get() > 0) { + remainingWarmups.getAndDecrement(); + return false; // it's a warmup, so it's fine + } else { + return true; // misassigned + } + } + }) + .collect(entriesToMap(TreeMap::new)); + + if (!misassigned.isEmpty()) { + assertThat( + new StringBuilder().append("Found some over- or under-assigned tasks in the final assignment with ") + .append(numStandbyReplicas) + .append(" and max warmups ") + .append(maxWarmupReplicas) + .append(" standby replicas, stateful tasks:") + .append(statefulTasks) + .append(", and stateless tasks:") + .append(statelessTasks) + .append(failureContext) + .toString(), + misassigned, + is(emptyMap())); + } + } + + private static void validateAndAddStandbyAssignments(final Set statefulTasks, + final Set statelessTasks, + final StringBuilder failureContext, + final Map> assignments, + final Map.Entry entry) { + for (final TaskId standbyTask : entry.getValue().standbyTasks()) { + if (statelessTasks.contains(standbyTask)) { + throw new AssertionError( + new StringBuilder().append("Found a standby task for stateless task ") + .append(standbyTask) + .append(" on client ") + .append(entry) + .append(" stateless tasks:") + .append(statelessTasks) + .append(failureContext) + .toString() + ); + } else if (assignments.containsKey(standbyTask)) { + assignments.get(standbyTask).add(entry.getKey()); + } else { + throw new AssertionError( + new StringBuilder().append("Found an extra standby task ") + .append(standbyTask) + 
.append(" on client ") + .append(entry) + .append(" but expected stateful tasks:") + .append(statefulTasks) + .append(failureContext) + .toString() + ); + } + } + } + + private static void validateAndAddActiveAssignments(final Set statefulTasks, + final Set statelessTasks, + final StringBuilder failureContext, + final Map> assignments, + final Map.Entry entry) { + for (final TaskId activeTask : entry.getValue().activeTasks()) { + if (assignments.containsKey(activeTask)) { + assignments.get(activeTask).add(entry.getKey()); + } else { + throw new AssertionError( + new StringBuilder().append("Found an extra active task ") + .append(activeTask) + .append(" on client ") + .append(entry) + .append(" but expected stateful tasks:") + .append(statefulTasks) + .append(" and stateless tasks:") + .append(statelessTasks) + .append(failureContext) + .toString() + ); + } + } + } + + static void assertBalancedStatefulAssignment(final Set allStatefulTasks, + final Map clientStates, + final StringBuilder failureContext) { + double maxStateful = Double.MIN_VALUE; + double minStateful = Double.MAX_VALUE; + for (final ClientState clientState : clientStates.values()) { + final Set statefulTasks = + intersection(HashSet::new, clientState.assignedTasks(), allStatefulTasks); + final double statefulTaskLoad = 1.0 * statefulTasks.size() / clientState.capacity(); + maxStateful = Math.max(maxStateful, statefulTaskLoad); + minStateful = Math.min(minStateful, statefulTaskLoad); + } + final double statefulDiff = maxStateful - minStateful; + + if (statefulDiff > 1.0) { + final StringBuilder builder = new StringBuilder() + .append("detected a stateful assignment balance factor violation: ") + .append(statefulDiff) + .append(">") + .append(1.0) + .append(" in: "); + appendClientStates(builder, clientStates); + fail(builder.append(failureContext).toString()); + } + } + + static void assertBalancedActiveAssignment(final Map clientStates, + final StringBuilder failureContext) { + double maxActive = Double.MIN_VALUE; + double minActive = Double.MAX_VALUE; + for (final ClientState clientState : clientStates.values()) { + final double activeTaskLoad = clientState.activeTaskLoad(); + maxActive = Math.max(maxActive, activeTaskLoad); + minActive = Math.min(minActive, activeTaskLoad); + } + final double activeDiff = maxActive - minActive; + if (activeDiff > 1.0) { + final StringBuilder builder = new StringBuilder() + .append("detected an active assignment balance factor violation: ") + .append(activeDiff) + .append(">") + .append(1.0) + .append(" in: "); + appendClientStates(builder, clientStates); + fail(builder.append(failureContext).toString()); + } + } + + static void assertBalancedTasks(final Map clientStates) { + final TaskSkewReport taskSkewReport = analyzeTaskAssignmentBalance(clientStates); + if (taskSkewReport.totalSkewedTasks() > 0) { + fail("Expected a balanced task assignment, but was: " + taskSkewReport); + } + } + + static TaskSkewReport analyzeTaskAssignmentBalance(final Map clientStates) { + final Function> initialClientCounts = + i -> clientStates.keySet().stream().collect(Collectors.toMap(c -> c, c -> new AtomicInteger(0))); + + final Map> subtopologyToClientsWithPartition = new TreeMap<>(); + for (final Map.Entry entry : clientStates.entrySet()) { + final UUID client = entry.getKey(); + final ClientState clientState = entry.getValue(); + for (final TaskId task : clientState.activeTasks()) { + final int subtopology = task.subtopology(); + subtopologyToClientsWithPartition + .computeIfAbsent(subtopology, 
initialClientCounts) + .get(client) + .incrementAndGet(); + } + } + + int maxTaskSkew = 0; + final Set skewedSubtopologies = new TreeSet<>(); + + for (final Map.Entry> entry : subtopologyToClientsWithPartition.entrySet()) { + final Map clientsWithPartition = entry.getValue(); + int max = Integer.MIN_VALUE; + int min = Integer.MAX_VALUE; + for (final AtomicInteger count : clientsWithPartition.values()) { + max = Math.max(max, count.get()); + min = Math.min(min, count.get()); + } + final int taskSkew = max - min; + maxTaskSkew = Math.max(maxTaskSkew, taskSkew); + if (taskSkew > 1) { + skewedSubtopologies.add(entry.getKey()); + } + } + + return new TaskSkewReport(maxTaskSkew, skewedSubtopologies, subtopologyToClientsWithPartition); + } + + static Matcher hasAssignedTasks(final int taskCount) { + return hasProperty("assignedTasks", ClientState::assignedTaskCount, taskCount); + } + + static Matcher hasActiveTasks(final int taskCount) { + return hasProperty("activeTasks", ClientState::activeTaskCount, taskCount); + } + + static Matcher hasStandbyTasks(final int taskCount) { + return hasProperty("standbyTasks", ClientState::standbyTaskCount, taskCount); + } + + static Matcher hasProperty(final String propertyName, + final Function propertyExtractor, + final V propertyValue) { + return new BaseMatcher() { + @Override + public void describeTo(final Description description) { + description.appendText(propertyName).appendText(":").appendValue(propertyValue); + } + + @Override + public boolean matches(final Object actual) { + if (actual instanceof ClientState) { + return Objects.equals(propertyExtractor.apply((ClientState) actual), propertyValue); + } else { + return false; + } + } + }; + } + + static void appendClientStates(final StringBuilder stringBuilder, + final Map clientStates) { + stringBuilder.append('{').append('\n'); + for (final Map.Entry entry : clientStates.entrySet()) { + stringBuilder.append(" ").append(entry.getKey()).append(": ").append(entry.getValue()).append('\n'); + } + stringBuilder.append('}').append('\n'); + } + + static final class TaskSkewReport { + private final int maxTaskSkew; + private final Set skewedSubtopologies; + private final Map> subtopologyToClientsWithPartition; + + private TaskSkewReport(final int maxTaskSkew, + final Set skewedSubtopologies, + final Map> subtopologyToClientsWithPartition) { + this.maxTaskSkew = maxTaskSkew; + this.skewedSubtopologies = skewedSubtopologies; + this.subtopologyToClientsWithPartition = subtopologyToClientsWithPartition; + } + + int totalSkewedTasks() { + return skewedSubtopologies.size(); + } + + Set skewedSubtopologies() { + return skewedSubtopologies; + } + + @Override + public String toString() { + return "TaskSkewReport{" + + "maxTaskSkew=" + maxTaskSkew + + ", skewedSubtopologies=" + skewedSubtopologies + + ", subtopologyToClientsWithPartition=" + subtopologyToClientsWithPartition + + '}'; + } + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/AssignorConfigurationTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/AssignorConfigurationTest.java new file mode 100644 index 0000000..868c303 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/AssignorConfigurationTest.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.processor.internals.assignment; + +import org.apache.kafka.common.config.ConfigException; +import org.junit.Test; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.containsString; +import static org.junit.Assert.assertThrows; + +public class AssignorConfigurationTest { + + @Test + public void configsShouldRejectZeroWarmups() { + final ConfigException exception = assertThrows( + ConfigException.class, + () -> new AssignorConfiguration.AssignmentConfigs(1L, 0, 1, 1L) + ); + + assertThat(exception.getMessage(), containsString("Invalid value 0 for configuration max.warmup.replicas")); + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/ClientStateTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/ClientStateTest.java new file mode 100644 index 0000000..e8acdc3 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/ClientStateTest.java @@ -0,0 +1,512 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.processor.internals.assignment; + +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.streams.processor.TaskId; +import org.apache.kafka.streams.processor.internals.Task; +import org.junit.Test; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import static java.util.Arrays.asList; +import static org.apache.kafka.common.utils.Utils.mkEntry; +import static org.apache.kafka.common.utils.Utils.mkMap; +import static org.apache.kafka.common.utils.Utils.mkSet; +import static org.apache.kafka.common.utils.Utils.mkSortedSet; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.NAMED_TASK_T0_0_0; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.NAMED_TASK_T1_0_0; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_0_0; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_0_1; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_0_2; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_0_3; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TP_0_0; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TP_0_1; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TP_0_2; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TP_1_0; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TP_1_1; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TP_1_2; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.hasActiveTasks; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.hasStandbyTasks; +import static org.apache.kafka.streams.processor.internals.assignment.SubscriptionInfo.UNKNOWN_OFFSET_SUM; +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.empty; +import static org.hamcrest.Matchers.is; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; + +public class ClientStateTest { + private final ClientState client = new ClientState(1); + private final ClientState zeroCapacityClient = new ClientState(0); + + @Test + public void previousStateConstructorShouldCreateAValidObject() { + final ClientState clientState = new ClientState( + mkSet(TASK_0_0, TASK_0_1), + mkSet(TASK_0_2, TASK_0_3), + mkMap(mkEntry(TASK_0_0, 5L), mkEntry(TASK_0_2, -1L)), + 4 + ); + + // all the "next assignment" fields should be empty + assertThat(clientState.activeTaskCount(), is(0)); + assertThat(clientState.activeTaskLoad(), is(0.0)); + assertThat(clientState.activeTasks(), is(empty())); + assertThat(clientState.standbyTaskCount(), is(0)); + assertThat(clientState.standbyTasks(), is(empty())); + assertThat(clientState.assignedTaskCount(), is(0)); + assertThat(clientState.assignedTasks(), is(empty())); + + // and the "previous assignment" fields should match the constructor args + assertThat(clientState.prevActiveTasks(), is(mkSet(TASK_0_0, TASK_0_1))); + 
assertThat(clientState.prevStandbyTasks(), is(mkSet(TASK_0_2, TASK_0_3))); + assertThat(clientState.previousAssignedTasks(), is(mkSet(TASK_0_0, TASK_0_1, TASK_0_2, TASK_0_3))); + assertThat(clientState.capacity(), is(4)); + assertThat(clientState.lagFor(TASK_0_0), is(5L)); + assertThat(clientState.lagFor(TASK_0_2), is(-1L)); + } + + @Test + public void shouldHaveNotReachedCapacityWhenAssignedTasksLessThanCapacity() { + assertFalse(client.reachedCapacity()); + } + + @Test + public void shouldHaveReachedCapacityWhenAssignedTasksGreaterThanOrEqualToCapacity() { + client.assignActive(TASK_0_1); + assertTrue(client.reachedCapacity()); + } + + @Test + public void shouldRefuseDoubleActiveTask() { + final ClientState clientState = new ClientState(1); + clientState.assignActive(TASK_0_0); + assertThrows(IllegalArgumentException.class, () -> clientState.assignActive(TASK_0_0)); + } + + @Test + public void shouldRefuseActiveAndStandbyTask() { + final ClientState clientState = new ClientState(1); + clientState.assignActive(TASK_0_0); + assertThrows(IllegalArgumentException.class, () -> clientState.assignStandby(TASK_0_0)); + } + + @Test + public void shouldRefuseDoubleStandbyTask() { + final ClientState clientState = new ClientState(1); + clientState.assignStandby(TASK_0_0); + assertThrows(IllegalArgumentException.class, () -> clientState.assignStandby(TASK_0_0)); + } + + @Test + public void shouldRefuseStandbyAndActiveTask() { + final ClientState clientState = new ClientState(1); + clientState.assignStandby(TASK_0_0); + assertThrows(IllegalArgumentException.class, () -> clientState.assignActive(TASK_0_0)); + } + + @Test + public void shouldRefuseToUnassignNotAssignedActiveTask() { + final ClientState clientState = new ClientState(1); + assertThrows(IllegalArgumentException.class, () -> clientState.unassignActive(TASK_0_0)); + } + + @Test + public void shouldRefuseToUnassignNotAssignedStandbyTask() { + final ClientState clientState = new ClientState(1); + assertThrows(IllegalArgumentException.class, () -> clientState.unassignStandby(TASK_0_0)); + } + + @Test + public void shouldRefuseToUnassignActiveTaskAsStandby() { + final ClientState clientState = new ClientState(1); + clientState.assignActive(TASK_0_0); + assertThrows(IllegalArgumentException.class, () -> clientState.unassignStandby(TASK_0_0)); + } + + @Test + public void shouldRefuseToUnassignStandbyTaskAsActive() { + final ClientState clientState = new ClientState(1); + clientState.assignStandby(TASK_0_0); + assertThrows(IllegalArgumentException.class, () -> clientState.unassignActive(TASK_0_0)); + } + + @Test + public void shouldUnassignActiveTask() { + final ClientState clientState = new ClientState(1); + clientState.assignActive(TASK_0_0); + assertThat(clientState, hasActiveTasks(1)); + clientState.unassignActive(TASK_0_0); + assertThat(clientState, hasActiveTasks(0)); + } + + @Test + public void shouldUnassignStandbyTask() { + final ClientState clientState = new ClientState(1); + clientState.assignStandby(TASK_0_0); + assertThat(clientState, hasStandbyTasks(1)); + clientState.unassignStandby(TASK_0_0); + assertThat(clientState, hasStandbyTasks(0)); + } + + @Test + public void shouldNotModifyActiveView() { + final ClientState clientState = new ClientState(1); + final Set taskIds = clientState.activeTasks(); + assertThrows(UnsupportedOperationException.class, () -> taskIds.add(TASK_0_0)); + assertThat(clientState, hasActiveTasks(0)); + } + + @Test + public void shouldNotModifyStandbyView() { + final ClientState clientState = new 
ClientState(1); + final Set taskIds = clientState.standbyTasks(); + assertThrows(UnsupportedOperationException.class, () -> taskIds.add(TASK_0_0)); + assertThat(clientState, hasStandbyTasks(0)); + } + + @Test + public void shouldNotModifyAssignedView() { + final ClientState clientState = new ClientState(1); + final Set taskIds = clientState.assignedTasks(); + assertThrows(UnsupportedOperationException.class, () -> taskIds.add(TASK_0_0)); + assertThat(clientState, hasActiveTasks(0)); + assertThat(clientState, hasStandbyTasks(0)); + } + + @Test + public void shouldAddActiveTasksToBothAssignedAndActive() { + client.assignActive(TASK_0_1); + assertThat(client.activeTasks(), equalTo(Collections.singleton(TASK_0_1))); + assertThat(client.assignedTasks(), equalTo(Collections.singleton(TASK_0_1))); + assertThat(client.assignedTaskCount(), equalTo(1)); + assertThat(client.standbyTasks().size(), equalTo(0)); + } + + @Test + public void shouldAddStandbyTasksToBothStandbyAndAssigned() { + client.assignStandby(TASK_0_1); + assertThat(client.assignedTasks(), equalTo(Collections.singleton(TASK_0_1))); + assertThat(client.standbyTasks(), equalTo(Collections.singleton(TASK_0_1))); + assertThat(client.assignedTaskCount(), equalTo(1)); + assertThat(client.activeTasks().size(), equalTo(0)); + } + + @Test + public void shouldAddPreviousActiveTasksToPreviousAssignedAndPreviousActive() { + client.addPreviousActiveTasks(Utils.mkSet(TASK_0_1, TASK_0_2)); + assertThat(client.prevActiveTasks(), equalTo(Utils.mkSet(TASK_0_1, TASK_0_2))); + assertThat(client.previousAssignedTasks(), equalTo(Utils.mkSet(TASK_0_1, TASK_0_2))); + } + + @Test + public void shouldAddPreviousStandbyTasksToPreviousAssignedAndPreviousStandby() { + client.addPreviousStandbyTasks(Utils.mkSet(TASK_0_1, TASK_0_2)); + assertThat(client.prevActiveTasks().size(), equalTo(0)); + assertThat(client.previousAssignedTasks(), equalTo(Utils.mkSet(TASK_0_1, TASK_0_2))); + } + + @Test + public void shouldHaveAssignedTaskIfActiveTaskAssigned() { + client.assignActive(TASK_0_1); + assertTrue(client.hasAssignedTask(TASK_0_1)); + } + + @Test + public void shouldHaveAssignedTaskIfStandbyTaskAssigned() { + client.assignStandby(TASK_0_1); + assertTrue(client.hasAssignedTask(TASK_0_1)); + } + + @Test + public void shouldNotHaveAssignedTaskIfTaskNotAssigned() { + client.assignActive(TASK_0_1); + assertFalse(client.hasAssignedTask(TASK_0_2)); + } + + @Test + public void shouldHaveMoreAvailableCapacityWhenCapacityTheSameButFewerAssignedTasks() { + final ClientState otherClient = new ClientState(1); + client.assignActive(TASK_0_1); + assertTrue(otherClient.hasMoreAvailableCapacityThan(client)); + assertFalse(client.hasMoreAvailableCapacityThan(otherClient)); + } + + @Test + public void shouldHaveMoreAvailableCapacityWhenCapacityHigherAndSameAssignedTaskCount() { + final ClientState otherClient = new ClientState(2); + assertTrue(otherClient.hasMoreAvailableCapacityThan(client)); + assertFalse(client.hasMoreAvailableCapacityThan(otherClient)); + } + + @Test + public void shouldUseMultiplesOfCapacityToDetermineClientWithMoreAvailableCapacity() { + final ClientState otherClient = new ClientState(2); + + for (int i = 0; i < 7; i++) { + otherClient.assignActive(new TaskId(0, i)); + } + + for (int i = 7; i < 11; i++) { + client.assignActive(new TaskId(0, i)); + } + + assertTrue(otherClient.hasMoreAvailableCapacityThan(client)); + } + + @Test + public void shouldHaveMoreAvailableCapacityWhenCapacityIsTheSameButAssignedTasksIsLess() { + final ClientState client = new 
ClientState(3); + final ClientState otherClient = new ClientState(3); + for (int i = 0; i < 4; i++) { + client.assignActive(new TaskId(0, i)); + otherClient.assignActive(new TaskId(0, i)); + } + otherClient.assignActive(new TaskId(0, 5)); + assertTrue(client.hasMoreAvailableCapacityThan(otherClient)); + } + + @Test + public void shouldThrowIllegalStateExceptionIfCapacityOfThisClientStateIsZero() { + assertThrows(IllegalStateException.class, () -> zeroCapacityClient.hasMoreAvailableCapacityThan(client)); + } + + @Test + public void shouldThrowIllegalStateExceptionIfCapacityOfOtherClientStateIsZero() { + assertThrows(IllegalStateException.class, () -> client.hasMoreAvailableCapacityThan(zeroCapacityClient)); + } + + @Test + public void shouldHaveUnfulfilledQuotaWhenActiveTaskSizeLessThanCapacityTimesTasksPerThread() { + client.assignActive(new TaskId(0, 1)); + assertTrue(client.hasUnfulfilledQuota(2)); + } + + @Test + public void shouldNotHaveUnfulfilledQuotaWhenActiveTaskSizeGreaterEqualThanCapacityTimesTasksPerThread() { + client.assignActive(new TaskId(0, 1)); + assertFalse(client.hasUnfulfilledQuota(1)); + } + + @Test + public void shouldAddTasksWithLatestOffsetToPrevActiveTasks() { + final Map taskOffsetSums = Collections.singletonMap(TASK_0_1, Task.LATEST_OFFSET); + client.addPreviousTasksAndOffsetSums("c1", taskOffsetSums); + client.initializePrevTasks(Collections.emptyMap()); + assertThat(client.prevActiveTasks(), equalTo(Collections.singleton(TASK_0_1))); + assertThat(client.previousAssignedTasks(), equalTo(Collections.singleton(TASK_0_1))); + assertTrue(client.prevStandbyTasks().isEmpty()); + } + + @Test + public void shouldReturnPreviousStatefulTasksForConsumer() { + client.addPreviousTasksAndOffsetSums("c1", mkMap( + mkEntry(TASK_0_0, 100L), + mkEntry(TASK_0_1, Task.LATEST_OFFSET) + )); + client.addPreviousTasksAndOffsetSums("c2", Collections.singletonMap(TASK_0_2, 0L)); + client.addPreviousTasksAndOffsetSums("c3", Collections.emptyMap()); + + client.initializePrevTasks(Collections.emptyMap()); + + assertThat(client.prevOwnedStatefulTasksByConsumer("c1"), equalTo(mkSet(TASK_0_0, TASK_0_1))); + assertThat(client.prevOwnedStatefulTasksByConsumer("c2"), equalTo(mkSet(TASK_0_2))); + assertTrue(client.prevOwnedStatefulTasksByConsumer("c3").isEmpty()); + } + + @Test + public void shouldReturnPreviousActiveStandbyTasksForConsumer() { + client.addOwnedPartitions(mkSet(TP_0_1, TP_1_1), "c1"); + client.addOwnedPartitions(mkSet(TP_0_2, TP_1_2), "c2"); + client.initializePrevTasks(mkMap( + mkEntry(TP_0_0, TASK_0_0), + mkEntry(TP_0_1, TASK_0_1), + mkEntry(TP_0_2, TASK_0_2), + mkEntry(TP_1_0, TASK_0_0), + mkEntry(TP_1_1, TASK_0_1), + mkEntry(TP_1_2, TASK_0_2)) + ); + + client.addPreviousTasksAndOffsetSums("c1", mkMap( + mkEntry(TASK_0_1, Task.LATEST_OFFSET), + mkEntry(TASK_0_0, 10L))); + client.addPreviousTasksAndOffsetSums("c2", Collections.singletonMap(TASK_0_2, 0L)); + + assertThat(client.prevOwnedStatefulTasksByConsumer("c1"), equalTo(mkSet(TASK_0_1, TASK_0_0))); + assertThat(client.prevOwnedStatefulTasksByConsumer("c2"), equalTo(mkSet(TASK_0_2))); + assertThat(client.prevOwnedActiveTasksByConsumer(), equalTo( + mkMap( + mkEntry("c1", Collections.singleton(TASK_0_1)), + mkEntry("c2", Collections.singleton(TASK_0_2)) + )) + ); + assertThat(client.prevOwnedStandbyByConsumer(), equalTo( + mkMap( + mkEntry("c1", Collections.singleton(TASK_0_0)), + mkEntry("c2", Collections.emptySet()) + )) + ); + } + + @Test + public void shouldReturnAssignedTasksForConsumer() { + final List allTasks = new 
ArrayList<>(asList(TASK_0_0, TASK_0_1, TASK_0_2)); + client.assignActiveTasks(allTasks); + + client.assignActiveToConsumer(TASK_0_0, "c1"); + // calling it multiple times should be idempotent + client.assignActiveToConsumer(TASK_0_0, "c1"); + client.assignActiveToConsumer(TASK_0_1, "c1"); + client.assignActiveToConsumer(TASK_0_2, "c2"); + + client.assignStandbyToConsumer(TASK_0_2, "c1"); + client.assignStandbyToConsumer(TASK_0_0, "c2"); + // calling it multiple times should be idempotent + client.assignStandbyToConsumer(TASK_0_0, "c2"); + + client.revokeActiveFromConsumer(TASK_0_1, "c1"); + // calling it multiple times should be idempotent + client.revokeActiveFromConsumer(TASK_0_1, "c1"); + + assertThat(client.assignedActiveTasksByConsumer(), equalTo(mkMap( + mkEntry("c1", mkSet(TASK_0_0, TASK_0_1)), + mkEntry("c2", mkSet(TASK_0_2)) + ))); + assertThat(client.assignedStandbyTasksByConsumer(), equalTo(mkMap( + mkEntry("c1", mkSet(TASK_0_2)), + mkEntry("c2", mkSet(TASK_0_0)) + ))); + assertThat(client.revokingActiveTasksByConsumer(), equalTo(Collections.singletonMap("c1", mkSet(TASK_0_1)))); + } + + @Test + public void shouldAddTasksInOffsetSumsMapToPrevStandbyTasks() { + final Map<TaskId, Long> taskOffsetSums = mkMap( + mkEntry(TASK_0_1, 0L), + mkEntry(TASK_0_2, 100L) + ); + client.addPreviousTasksAndOffsetSums("c1", taskOffsetSums); + client.initializePrevTasks(Collections.emptyMap()); + assertThat(client.prevStandbyTasks(), equalTo(mkSet(TASK_0_1, TASK_0_2))); + assertThat(client.previousAssignedTasks(), equalTo(mkSet(TASK_0_1, TASK_0_2))); + assertTrue(client.prevActiveTasks().isEmpty()); + } + + @Test + public void shouldComputeTaskLags() { + final Map<TaskId, Long> taskOffsetSums = mkMap( + mkEntry(TASK_0_1, 0L), + mkEntry(TASK_0_2, 100L) + ); + final Map<TaskId, Long> allTaskEndOffsetSums = mkMap( + mkEntry(TASK_0_1, 500L), + mkEntry(TASK_0_2, 100L) + ); + client.addPreviousTasksAndOffsetSums("c1", taskOffsetSums); + client.computeTaskLags(null, allTaskEndOffsetSums); + + assertThat(client.lagFor(TASK_0_1), equalTo(500L)); + assertThat(client.lagFor(TASK_0_2), equalTo(0L)); + } + + @Test + public void shouldNotTryToLookupTasksThatWerePreviouslyAssignedButNoLongerExist() { + final Map<TaskId, Long> clientReportedTaskEndOffsetSums = mkMap( + mkEntry(NAMED_TASK_T0_0_0, 500L), + mkEntry(NAMED_TASK_T1_0_0, 500L) + ); + final Map<TaskId, Long> allTaskEndOffsetSumsComputedByAssignor = Collections.singletonMap(NAMED_TASK_T0_0_0, 500L); + client.addPreviousTasksAndOffsetSums("c1", clientReportedTaskEndOffsetSums); + client.computeTaskLags(null, allTaskEndOffsetSumsComputedByAssignor); + + assertThrows(IllegalStateException.class, () -> client.lagFor(NAMED_TASK_T1_0_0)); + + client.assignActive(NAMED_TASK_T0_0_0); + assertThat(client.prevTasksByLag("c1"), equalTo(mkSortedSet(NAMED_TASK_T0_0_0))); + } + + @Test + public void shouldReturnEndOffsetSumForLagOfTaskWeDidNotPreviouslyOwn() { + final Map<TaskId, Long> taskOffsetSums = Collections.emptyMap(); + final Map<TaskId, Long> allTaskEndOffsetSums = Collections.singletonMap(TASK_0_1, 500L); + client.addPreviousTasksAndOffsetSums("c1", taskOffsetSums); + client.computeTaskLags(null, allTaskEndOffsetSums); + assertThat(client.lagFor(TASK_0_1), equalTo(500L)); + } + + @Test + public void shouldReturnLatestOffsetForLagOfPreviousActiveRunningTask() { + final Map<TaskId, Long> taskOffsetSums = Collections.singletonMap(TASK_0_1, Task.LATEST_OFFSET); + final Map<TaskId, Long> allTaskEndOffsetSums = Collections.singletonMap(TASK_0_1, 500L); + client.addPreviousTasksAndOffsetSums("c1", taskOffsetSums); + client.computeTaskLags(null, allTaskEndOffsetSums); +
assertThat(client.lagFor(TASK_0_1), equalTo(Task.LATEST_OFFSET)); + } + + @Test + public void shouldReturnUnknownOffsetSumForLagOfTaskWithUnknownOffset() { + final Map taskOffsetSums = Collections.singletonMap(TASK_0_1, UNKNOWN_OFFSET_SUM); + final Map allTaskEndOffsetSums = Collections.singletonMap(TASK_0_1, 500L); + client.addPreviousTasksAndOffsetSums("c1", taskOffsetSums); + client.computeTaskLags(null, allTaskEndOffsetSums); + assertThat(client.lagFor(TASK_0_1), equalTo(UNKNOWN_OFFSET_SUM)); + } + + @Test + public void shouldReturnEndOffsetSumIfOffsetSumIsGreaterThanEndOffsetSum() { + final Map taskOffsetSums = Collections.singletonMap(TASK_0_1, 5L); + final Map allTaskEndOffsetSums = Collections.singletonMap(TASK_0_1, 1L); + client.addPreviousTasksAndOffsetSums("c1", taskOffsetSums); + client.computeTaskLags(null, allTaskEndOffsetSums); + assertThat(client.lagFor(TASK_0_1), equalTo(1L)); + } + + @Test + public void shouldThrowIllegalStateExceptionIfTaskLagsMapIsNotEmpty() { + final Map taskOffsetSums = Collections.singletonMap(TASK_0_1, 5L); + final Map allTaskEndOffsetSums = Collections.singletonMap(TASK_0_1, 1L); + client.computeTaskLags(null, taskOffsetSums); + assertThrows(IllegalStateException.class, () -> client.computeTaskLags(null, allTaskEndOffsetSums)); + } + + @Test + public void shouldThrowIllegalStateExceptionOnLagForUnknownTask() { + final Map taskOffsetSums = Collections.singletonMap(TASK_0_1, 0L); + final Map allTaskEndOffsetSums = Collections.singletonMap(TASK_0_1, 500L); + client.addPreviousTasksAndOffsetSums("c1", taskOffsetSums); + client.computeTaskLags(null, allTaskEndOffsetSums); + assertThrows(IllegalStateException.class, () -> client.lagFor(TASK_0_2)); + } + + @Test + public void shouldThrowIllegalStateExceptionIfAttemptingToInitializeNonEmptyPrevTaskSets() { + client.addPreviousActiveTasks(Collections.singleton(TASK_0_1)); + assertThrows(IllegalStateException.class, () -> client.initializePrevTasks(Collections.emptyMap())); + } + + @Test + public void shouldThrowIllegalStateExceptionIfAssignedTasksForConsumerToNonClientAssignActive() { + assertThrows(IllegalStateException.class, () -> client.assignActiveToConsumer(TASK_0_0, "c1")); + } + +} diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/ConstrainedPrioritySetTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/ConstrainedPrioritySetTest.java new file mode 100644 index 0000000..d14e279 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/ConstrainedPrioritySetTest.java @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.processor.internals.assignment; + + +import org.apache.kafka.streams.processor.TaskId; +import org.junit.Test; + +import java.util.UUID; +import java.util.function.BiFunction; + +import static java.util.Arrays.asList; +import static java.util.Collections.singleton; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.UUID_1; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.UUID_2; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.UUID_3; +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.CoreMatchers.nullValue; +import static org.hamcrest.MatcherAssert.assertThat; + +public class ConstrainedPrioritySetTest { + private static final TaskId DUMMY_TASK = new TaskId(0, 0); + + private final BiFunction alwaysTrue = (client, task) -> true; + private final BiFunction alwaysFalse = (client, task) -> false; + + @Test + public void shouldReturnOnlyClient() { + final ConstrainedPrioritySet queue = new ConstrainedPrioritySet(alwaysTrue, client -> 1.0); + queue.offerAll(singleton(UUID_1)); + + assertThat(queue.poll(DUMMY_TASK), equalTo(UUID_1)); + assertThat(queue.poll(DUMMY_TASK), nullValue()); + } + + @Test + public void shouldReturnNull() { + final ConstrainedPrioritySet queue = new ConstrainedPrioritySet(alwaysFalse, client -> 1.0); + queue.offerAll(singleton(UUID_1)); + + assertThat(queue.poll(DUMMY_TASK), nullValue()); + } + + @Test + public void shouldReturnLeastLoadedClient() { + final ConstrainedPrioritySet queue = new ConstrainedPrioritySet( + alwaysTrue, + client -> (client == UUID_1) ? 3.0 : (client == UUID_2) ? 2.0 : 1.0 + ); + + queue.offerAll(asList(UUID_1, UUID_2, UUID_3)); + + assertThat(queue.poll(DUMMY_TASK), equalTo(UUID_3)); + assertThat(queue.poll(DUMMY_TASK), equalTo(UUID_2)); + assertThat(queue.poll(DUMMY_TASK), equalTo(UUID_1)); + assertThat(queue.poll(DUMMY_TASK), nullValue()); + } + + @Test + public void shouldNotRetainDuplicates() { + final ConstrainedPrioritySet queue = new ConstrainedPrioritySet(alwaysTrue, client -> 1.0); + + queue.offerAll(singleton(UUID_1)); + queue.offer(UUID_1); + + assertThat(queue.poll(DUMMY_TASK), equalTo(UUID_1)); + assertThat(queue.poll(DUMMY_TASK), nullValue()); + } + + @Test + public void shouldOnlyReturnValidClients() { + final ConstrainedPrioritySet queue = new ConstrainedPrioritySet( + (client, task) -> client.equals(UUID_1), + client -> 1.0 + ); + + queue.offerAll(asList(UUID_1, UUID_2)); + + assertThat(queue.poll(DUMMY_TASK), equalTo(UUID_1)); + assertThat(queue.poll(DUMMY_TASK), nullValue()); + } + + @Test + public void shouldApplyPollFilter() { + final ConstrainedPrioritySet queue = new ConstrainedPrioritySet( + alwaysTrue, + client -> 1.0 + ); + + queue.offerAll(asList(UUID_1, UUID_2)); + + assertThat(queue.poll(DUMMY_TASK, client -> client.equals(UUID_1)), equalTo(UUID_1)); + assertThat(queue.poll(DUMMY_TASK, client -> client.equals(UUID_1)), nullValue()); + assertThat(queue.poll(DUMMY_TASK), equalTo(UUID_2)); + assertThat(queue.poll(DUMMY_TASK), nullValue()); + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/FallbackPriorTaskAssignorTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/FallbackPriorTaskAssignorTest.java new file mode 100644 index 0000000..491fff1 --- /dev/null +++ 
b/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/FallbackPriorTaskAssignorTest.java @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.processor.internals.assignment; + +import org.apache.kafka.streams.processor.TaskId; +import org.junit.Test; + +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.TreeMap; +import java.util.UUID; + +import static java.util.Arrays.asList; +import static org.apache.kafka.common.utils.Utils.mkSet; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_0_0; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_0_1; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_0_2; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.UUID_1; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.UUID_2; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.empty; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.is; + +public class FallbackPriorTaskAssignorTest { + + private final Map clients = new TreeMap<>(); + + @Test + public void shouldViolateBalanceToPreserveActiveTaskStickiness() { + final ClientState c1 = createClientWithPreviousActiveTasks(UUID_1, 1, TASK_0_0, TASK_0_1, TASK_0_2); + final ClientState c2 = createClient(UUID_2, 1); + + final List taskIds = asList(TASK_0_0, TASK_0_1, TASK_0_2); + Collections.shuffle(taskIds); + final boolean probingRebalanceNeeded = new FallbackPriorTaskAssignor().assign( + clients, + new HashSet<>(taskIds), + new HashSet<>(taskIds), + new AssignorConfiguration.AssignmentConfigs(0L, 1, 0, 60_000L) + ); + assertThat(probingRebalanceNeeded, is(true)); + + assertThat(c1.activeTasks(), equalTo(mkSet(TASK_0_0, TASK_0_1, TASK_0_2))); + assertThat(c2.activeTasks(), empty()); + } + + private ClientState createClient(final UUID processId, final int capacity) { + return createClientWithPreviousActiveTasks(processId, capacity); + } + + private ClientState createClientWithPreviousActiveTasks(final UUID processId, final int capacity, final TaskId... 
taskIds) { + final ClientState clientState = new ClientState(capacity); + clientState.addPreviousActiveTasks(mkSet(taskIds)); + clients.put(processId, clientState); + return clientState; + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/HighAvailabilityTaskAssignorTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/HighAvailabilityTaskAssignorTest.java new file mode 100644 index 0000000..a2d4716 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/HighAvailabilityTaskAssignorTest.java @@ -0,0 +1,834 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.processor.internals.assignment; + +import org.apache.kafka.streams.processor.TaskId; +import org.apache.kafka.streams.processor.internals.assignment.AssignorConfiguration.AssignmentConfigs; +import org.junit.Test; + +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; +import java.util.UUID; +import java.util.stream.Collectors; + +import static java.util.Collections.emptySet; +import static java.util.Collections.singleton; +import static java.util.Collections.singletonMap; +import static org.apache.kafka.common.utils.Utils.mkEntry; +import static org.apache.kafka.common.utils.Utils.mkMap; +import static org.apache.kafka.common.utils.Utils.mkSet; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.EMPTY_TASKS; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_0_0; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_0_1; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_0_2; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_0_3; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_1_0; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_1_1; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_1_2; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_1_3; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_2_0; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_2_1; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_2_2; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.UUID_1; +import static 
org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.UUID_2; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.UUID_3; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.analyzeTaskAssignmentBalance; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.assertBalancedActiveAssignment; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.assertBalancedStatefulAssignment; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.assertBalancedTasks; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.assertValidAssignment; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.getClientStatesMap; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.hasActiveTasks; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.hasAssignedTasks; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.hasStandbyTasks; +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.empty; +import static org.hamcrest.Matchers.greaterThanOrEqualTo; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.not; +import static org.junit.Assert.fail; + +public class HighAvailabilityTaskAssignorTest { + private final AssignmentConfigs configWithoutStandbys = new AssignmentConfigs( + /*acceptableRecoveryLag*/ 100L, + /*maxWarmupReplicas*/ 2, + /*numStandbyReplicas*/ 0, + /*probingRebalanceIntervalMs*/ 60 * 1000L + ); + + private final AssignmentConfigs configWithStandbys = new AssignmentConfigs( + /*acceptableRecoveryLag*/ 100L, + /*maxWarmupReplicas*/ 2, + /*numStandbyReplicas*/ 1, + /*probingRebalanceIntervalMs*/ 60 * 1000L + ); + + @Test + public void shouldBeStickyForActiveAndStandbyTasksWhileWarmingUp() { + final Set allTaskIds = mkSet(TASK_0_0, TASK_0_1, TASK_0_2, TASK_1_0, TASK_1_1, TASK_1_2, TASK_2_0, TASK_2_1, TASK_2_2); + final ClientState clientState1 = new ClientState(allTaskIds, emptySet(), allTaskIds.stream().collect(Collectors.toMap(k -> k, k -> 0L)), 1); + final ClientState clientState2 = new ClientState(emptySet(), allTaskIds, allTaskIds.stream().collect(Collectors.toMap(k -> k, k -> 10L)), 1); + final ClientState clientState3 = new ClientState(emptySet(), emptySet(), allTaskIds.stream().collect(Collectors.toMap(k -> k, k -> Long.MAX_VALUE)), 1); + + final Map clientStates = mkMap( + mkEntry(UUID_1, clientState1), + mkEntry(UUID_2, clientState2), + mkEntry(UUID_3, clientState3) + ); + + final boolean unstable = new HighAvailabilityTaskAssignor().assign( + clientStates, + allTaskIds, + allTaskIds, + new AssignmentConfigs(11L, 2, 1, 60_000L) + ); + + assertThat(clientState1, hasAssignedTasks(allTaskIds.size())); + + assertThat(clientState2, hasAssignedTasks(allTaskIds.size())); + + assertThat(clientState3, hasAssignedTasks(2)); + + assertThat(unstable, is(true)); + } + + @Test + public void shouldSkipWarmupsWhenAcceptableLagIsMax() { + final Set allTaskIds = mkSet(TASK_0_0, TASK_0_1, TASK_0_2, TASK_1_0, TASK_1_1, TASK_1_2, TASK_2_0, TASK_2_1, TASK_2_2); + final ClientState clientState1 = new ClientState(allTaskIds, emptySet(), allTaskIds.stream().collect(Collectors.toMap(k -> k, k -> 0L)), 1); + final 
ClientState clientState2 = new ClientState(emptySet(), emptySet(), allTaskIds.stream().collect(Collectors.toMap(k -> k, k -> Long.MAX_VALUE)), 1); + final ClientState clientState3 = new ClientState(emptySet(), emptySet(), allTaskIds.stream().collect(Collectors.toMap(k -> k, k -> Long.MAX_VALUE)), 1); + + final Map clientStates = mkMap( + mkEntry(UUID_1, clientState1), + mkEntry(UUID_2, clientState2), + mkEntry(UUID_3, clientState3) + ); + + final boolean unstable = new HighAvailabilityTaskAssignor().assign( + clientStates, + allTaskIds, + allTaskIds, + new AssignmentConfigs(Long.MAX_VALUE, 1, 1, 60_000L) + ); + + assertThat(clientState1, hasAssignedTasks(6)); + assertThat(clientState2, hasAssignedTasks(6)); + assertThat(clientState3, hasAssignedTasks(6)); + assertThat(unstable, is(false)); + } + + @Test + public void shouldAssignActiveStatefulTasksEvenlyOverClientsWhereNumberOfClientsIntegralDivisorOfNumberOfTasks() { + final Set allTaskIds = mkSet(TASK_0_0, TASK_0_1, TASK_0_2, TASK_1_0, TASK_1_1, TASK_1_2, TASK_2_0, TASK_2_1, TASK_2_2); + final Map lags = allTaskIds.stream().collect(Collectors.toMap(k -> k, k -> 10L)); + final ClientState clientState1 = new ClientState(emptySet(), emptySet(), lags, 1); + final ClientState clientState2 = new ClientState(emptySet(), emptySet(), lags, 1); + final ClientState clientState3 = new ClientState(emptySet(), emptySet(), lags, 1); + final Map clientStates = getClientStatesMap(clientState1, clientState2, clientState3); + final boolean unstable = new HighAvailabilityTaskAssignor().assign( + clientStates, + allTaskIds, + allTaskIds, + new AssignmentConfigs(0L, 1, 0, 60_000L) + ); + assertThat(unstable, is(false)); + assertValidAssignment(0, allTaskIds, emptySet(), clientStates, new StringBuilder()); + assertBalancedActiveAssignment(clientStates, new StringBuilder()); + assertBalancedStatefulAssignment(allTaskIds, clientStates, new StringBuilder()); + assertBalancedTasks(clientStates); + } + + @Test + public void shouldAssignActiveStatefulTasksEvenlyOverClientsWhereNumberOfThreadsIntegralDivisorOfNumberOfTasks() { + final Set allTaskIds = mkSet(TASK_0_0, TASK_0_1, TASK_0_2, TASK_1_0, TASK_1_1, TASK_1_2, TASK_2_0, TASK_2_1, TASK_2_2); + final Map lags = allTaskIds.stream().collect(Collectors.toMap(k -> k, k -> 10L)); + final ClientState clientState1 = new ClientState(emptySet(), emptySet(), lags, 3); + final ClientState clientState2 = new ClientState(emptySet(), emptySet(), lags, 3); + final ClientState clientState3 = new ClientState(emptySet(), emptySet(), lags, 3); + final Map clientStates = getClientStatesMap(clientState1, clientState2, clientState3); + final boolean unstable = new HighAvailabilityTaskAssignor().assign( + clientStates, + allTaskIds, + allTaskIds, + new AssignmentConfigs(0L, 1, 0, 60_000L) + ); + assertThat(unstable, is(false)); + assertValidAssignment(0, allTaskIds, emptySet(), clientStates, new StringBuilder()); + assertBalancedActiveAssignment(clientStates, new StringBuilder()); + assertBalancedStatefulAssignment(allTaskIds, clientStates, new StringBuilder()); + assertBalancedTasks(clientStates); + } + + @Test + public void shouldAssignActiveStatefulTasksEvenlyOverClientsWhereNumberOfClientsNotIntegralDivisorOfNumberOfTasks() { + final Set allTaskIds = mkSet(TASK_0_0, TASK_0_1, TASK_0_2, TASK_1_0, TASK_1_1, TASK_1_2, TASK_2_0, TASK_2_1, TASK_2_2); + final Map lags = allTaskIds.stream().collect(Collectors.toMap(k -> k, k -> 10L)); + final ClientState clientState1 = new ClientState(emptySet(), emptySet(), lags, 1); + final ClientState 
clientState2 = new ClientState(emptySet(), emptySet(), lags, 1); + final Map clientStates = getClientStatesMap(clientState1, clientState2); + final boolean unstable = new HighAvailabilityTaskAssignor().assign( + clientStates, + allTaskIds, + allTaskIds, + new AssignmentConfigs(0L, 1, 0, 60_000L) + ); + + assertThat(unstable, is(false)); + assertValidAssignment(0, allTaskIds, emptySet(), clientStates, new StringBuilder()); + assertBalancedActiveAssignment(clientStates, new StringBuilder()); + assertBalancedStatefulAssignment(allTaskIds, clientStates, new StringBuilder()); + assertBalancedTasks(clientStates); + } + + @Test + public void shouldAssignActiveStatefulTasksEvenlyOverUnevenlyDistributedStreamThreads() { + final Set allTaskIds = mkSet(TASK_0_0, TASK_0_1, TASK_0_2, TASK_1_0, TASK_1_1, TASK_1_2); + final Map lags = allTaskIds.stream().collect(Collectors.toMap(k -> k, k -> 10L)); + final ClientState clientState1 = new ClientState(emptySet(), emptySet(), lags, 1); + final ClientState clientState2 = new ClientState(emptySet(), emptySet(), lags, 2); + final ClientState clientState3 = new ClientState(emptySet(), emptySet(), lags, 3); + final Map clientStates = getClientStatesMap(clientState1, clientState2, clientState3); + final boolean unstable = new HighAvailabilityTaskAssignor().assign( + clientStates, + allTaskIds, + allTaskIds, + new AssignmentConfigs(0L, 1, 0, 60_000L) + ); + + assertThat(unstable, is(false)); + assertValidAssignment(0, allTaskIds, emptySet(), clientStates, new StringBuilder()); + assertBalancedActiveAssignment(clientStates, new StringBuilder()); + assertBalancedStatefulAssignment(allTaskIds, clientStates, new StringBuilder()); + + assertThat(clientState1, hasActiveTasks(1)); + assertThat(clientState2, hasActiveTasks(2)); + assertThat(clientState3, hasActiveTasks(3)); + final AssignmentTestUtils.TaskSkewReport taskSkewReport = analyzeTaskAssignmentBalance(clientStates); + if (taskSkewReport.totalSkewedTasks() == 0) { + fail("Expected a skewed task assignment, but was: " + taskSkewReport); + } + } + + @Test + public void shouldAssignActiveStatefulTasksEvenlyOverClientsWithMoreClientsThanTasks() { + final Set allTaskIds = mkSet(TASK_0_0, TASK_0_1); + final Map lags = allTaskIds.stream().collect(Collectors.toMap(k -> k, k -> 10L)); + final ClientState clientState1 = new ClientState(emptySet(), emptySet(), lags, 1); + final ClientState clientState2 = new ClientState(emptySet(), emptySet(), lags, 1); + final ClientState clientState3 = new ClientState(emptySet(), emptySet(), lags, 1); + final Map clientStates = getClientStatesMap(clientState1, clientState2, clientState3); + final boolean unstable = new HighAvailabilityTaskAssignor().assign( + clientStates, + allTaskIds, + allTaskIds, + new AssignmentConfigs(0L, 1, 0, 60_000L) + ); + + assertThat(unstable, is(false)); + assertValidAssignment(0, allTaskIds, emptySet(), clientStates, new StringBuilder()); + assertBalancedActiveAssignment(clientStates, new StringBuilder()); + assertBalancedStatefulAssignment(allTaskIds, clientStates, new StringBuilder()); + assertBalancedTasks(clientStates); + } + + @Test + public void shouldAssignActiveStatefulTasksEvenlyOverClientsAndStreamThreadsWithEqualStreamThreadsPerClientAsTasks() { + final Set allTaskIds = mkSet(TASK_0_0, TASK_0_1, TASK_0_2, TASK_1_0, TASK_1_1, TASK_1_2, TASK_2_0, TASK_2_1, TASK_2_2); + final Map lags = allTaskIds.stream().collect(Collectors.toMap(k -> k, k -> 10L)); + final ClientState clientState1 = new ClientState(emptySet(), emptySet(), lags, 9); + final 
ClientState clientState2 = new ClientState(emptySet(), emptySet(), lags, 9); + final ClientState clientState3 = new ClientState(emptySet(), emptySet(), lags, 9); + final Map clientStates = getClientStatesMap(clientState1, clientState2, clientState3); + final boolean unstable = new HighAvailabilityTaskAssignor().assign( + clientStates, + allTaskIds, + allTaskIds, + new AssignmentConfigs(0L, 1, 0, 60_000L) + ); + + assertThat(unstable, is(false)); + assertValidAssignment(0, allTaskIds, emptySet(), clientStates, new StringBuilder()); + assertBalancedActiveAssignment(clientStates, new StringBuilder()); + assertBalancedStatefulAssignment(allTaskIds, clientStates, new StringBuilder()); + assertBalancedTasks(clientStates); + } + + @Test + public void shouldAssignWarmUpTasksIfStatefulActiveTasksBalancedOverStreamThreadsButNotOverClients() { + final Set allTaskIds = mkSet(TASK_0_0, TASK_0_1, TASK_1_0, TASK_1_1); + final Map lagsForCaughtUpClient = allTaskIds.stream().collect(Collectors.toMap(k -> k, k -> 0L)); + final Map lagsForNotCaughtUpClient = + allTaskIds.stream().collect(Collectors.toMap(k -> k, k -> Long.MAX_VALUE)); + final ClientState caughtUpClientState = new ClientState(allTaskIds, emptySet(), lagsForCaughtUpClient, 5); + final ClientState notCaughtUpClientState1 = new ClientState(emptySet(), emptySet(), lagsForNotCaughtUpClient, 5); + final ClientState notCaughtUpClientState2 = new ClientState(emptySet(), emptySet(), lagsForNotCaughtUpClient, 5); + final Map clientStates = + getClientStatesMap(caughtUpClientState, notCaughtUpClientState1, notCaughtUpClientState2); + final boolean unstable = new HighAvailabilityTaskAssignor().assign( + clientStates, + allTaskIds, + allTaskIds, + new AssignmentConfigs(0L, allTaskIds.size() / 3 + 1, 0, 60_000L) + ); + + assertThat(unstable, is(true)); + assertThat(notCaughtUpClientState1.standbyTaskCount(), greaterThanOrEqualTo(allTaskIds.size() / 3)); + assertThat(notCaughtUpClientState2.standbyTaskCount(), greaterThanOrEqualTo(allTaskIds.size() / 3)); + assertValidAssignment(0, allTaskIds.size() / 3 + 1, allTaskIds, emptySet(), clientStates, new StringBuilder()); + } + + @Test + public void shouldEvenlyAssignActiveStatefulTasksIfClientsAreWarmedUpToBalanceTaskOverClients() { + final Set allTaskIds = mkSet(TASK_0_0, TASK_0_1, TASK_1_0, TASK_1_1); + final Set warmedUpTaskIds1 = mkSet(TASK_0_1); + final Set warmedUpTaskIds2 = mkSet(TASK_1_0); + final Map lagsForCaughtUpClient = allTaskIds.stream().collect(Collectors.toMap(k -> k, k -> 0L)); + final Map lagsForWarmedUpClient1 = + allTaskIds.stream().collect(Collectors.toMap(k -> k, k -> Long.MAX_VALUE)); + lagsForWarmedUpClient1.put(TASK_0_1, 0L); + final Map lagsForWarmedUpClient2 = + allTaskIds.stream().collect(Collectors.toMap(k -> k, k -> Long.MAX_VALUE)); + lagsForWarmedUpClient2.put(TASK_1_0, 0L); + final ClientState caughtUpClientState = new ClientState(allTaskIds, emptySet(), lagsForCaughtUpClient, 5); + final ClientState warmedUpClientState1 = new ClientState(emptySet(), warmedUpTaskIds1, lagsForWarmedUpClient1, 5); + final ClientState warmedUpClientState2 = new ClientState(emptySet(), warmedUpTaskIds2, lagsForWarmedUpClient2, 5); + final Map clientStates = + getClientStatesMap(caughtUpClientState, warmedUpClientState1, warmedUpClientState2); + final boolean unstable = new HighAvailabilityTaskAssignor().assign( + clientStates, + allTaskIds, + allTaskIds, + new AssignmentConfigs(0L, allTaskIds.size() / 3 + 1, 0, 60_000L) + ); + + assertThat(unstable, is(false)); + assertValidAssignment(0, 
allTaskIds, emptySet(), clientStates, new StringBuilder()); + assertBalancedActiveAssignment(clientStates, new StringBuilder()); + assertBalancedStatefulAssignment(allTaskIds, clientStates, new StringBuilder()); + assertBalancedTasks(clientStates); + } + + @Test + public void shouldAssignActiveStatefulTasksEvenlyOverStreamThreadsButBestEffortOverClients() { + final Set allTaskIds = mkSet(TASK_0_0, TASK_0_1, TASK_0_2, TASK_1_0, TASK_1_1, TASK_1_2, TASK_2_0, TASK_2_1, TASK_2_2); + final Map lags = allTaskIds.stream().collect(Collectors.toMap(k -> k, k -> 10L)); + final ClientState clientState1 = new ClientState(emptySet(), emptySet(), lags, 6); + final ClientState clientState2 = new ClientState(emptySet(), emptySet(), lags, 3); + final Map clientStates = getClientStatesMap(clientState1, clientState2); + final boolean unstable = new HighAvailabilityTaskAssignor().assign( + clientStates, + allTaskIds, + allTaskIds, + new AssignmentConfigs(0L, 1, 0, 60_000L) + ); + + assertThat(unstable, is(false)); + assertValidAssignment(0, allTaskIds, emptySet(), clientStates, new StringBuilder()); + assertBalancedActiveAssignment(clientStates, new StringBuilder()); + assertBalancedStatefulAssignment(allTaskIds, clientStates, new StringBuilder()); + assertThat(clientState1, hasActiveTasks(6)); + assertThat(clientState2, hasActiveTasks(3)); + } + + @Test + public void shouldComputeNewAssignmentIfThereAreUnassignedActiveTasks() { + final Set allTasks = mkSet(TASK_0_0, TASK_0_1); + final ClientState client1 = new ClientState(singleton(TASK_0_0), emptySet(), singletonMap(TASK_0_0, 0L), 1); + final Map clientStates = singletonMap(UUID_1, client1); + + final boolean probingRebalanceNeeded = new HighAvailabilityTaskAssignor().assign(clientStates, + allTasks, + singleton(TASK_0_0), + configWithoutStandbys); + + assertThat(probingRebalanceNeeded, is(false)); + assertThat(client1, hasActiveTasks(2)); + assertThat(client1, hasStandbyTasks(0)); + + assertValidAssignment(0, allTasks, emptySet(), clientStates, new StringBuilder()); + assertBalancedActiveAssignment(clientStates, new StringBuilder()); + assertBalancedStatefulAssignment(allTasks, clientStates, new StringBuilder()); + assertBalancedTasks(clientStates); + } + + @Test + public void shouldComputeNewAssignmentIfThereAreUnassignedStandbyTasks() { + final Set allTasks = mkSet(TASK_0_0); + final Set statefulTasks = mkSet(TASK_0_0); + final ClientState client1 = new ClientState(singleton(TASK_0_0), emptySet(), singletonMap(TASK_0_0, 0L), 1); + final ClientState client2 = new ClientState(emptySet(), emptySet(), singletonMap(TASK_0_0, 0L), 1); + final Map clientStates = mkMap(mkEntry(UUID_1, client1), mkEntry(UUID_2, client2)); + + final boolean probingRebalanceNeeded = new HighAvailabilityTaskAssignor().assign(clientStates, + allTasks, + statefulTasks, + configWithStandbys); + + assertThat(clientStates.get(UUID_2).standbyTasks(), not(empty())); + assertThat(probingRebalanceNeeded, is(false)); + assertValidAssignment(1, allTasks, emptySet(), clientStates, new StringBuilder()); + assertBalancedActiveAssignment(clientStates, new StringBuilder()); + assertBalancedStatefulAssignment(allTasks, clientStates, new StringBuilder()); + assertBalancedTasks(clientStates); + } + + @Test + public void shouldComputeNewAssignmentIfActiveTasksWasNotOnCaughtUpClient() { + final Set allTasks = mkSet(TASK_0_0, TASK_0_1); + final Set statefulTasks = mkSet(TASK_0_0); + final ClientState client1 = new ClientState(singleton(TASK_0_0), emptySet(), singletonMap(TASK_0_0, 500L), 1); + final 
ClientState client2 = new ClientState(singleton(TASK_0_1), emptySet(), singletonMap(TASK_0_0, 0L), 1); + final Map clientStates = mkMap( + mkEntry(UUID_1, client1), + mkEntry(UUID_2, client2) + ); + + final boolean probingRebalanceNeeded = + new HighAvailabilityTaskAssignor().assign(clientStates, allTasks, statefulTasks, configWithoutStandbys); + + assertThat(clientStates.get(UUID_1).activeTasks(), is(singleton(TASK_0_1))); + assertThat(clientStates.get(UUID_2).activeTasks(), is(singleton(TASK_0_0))); + // we'll warm up task 0_0 on client1 because it's first in sorted order, + // although this isn't an optimal convergence + assertThat(probingRebalanceNeeded, is(true)); + assertValidAssignment(0, 1, allTasks, emptySet(), clientStates, new StringBuilder()); + assertBalancedActiveAssignment(clientStates, new StringBuilder()); + assertBalancedStatefulAssignment(allTasks, clientStates, new StringBuilder()); + assertBalancedTasks(clientStates); + } + + @Test + public void shouldAssignStandbysForStatefulTasks() { + final Set allTasks = mkSet(TASK_0_0, TASK_0_1); + final Set statefulTasks = mkSet(TASK_0_0, TASK_0_1); + + final ClientState client1 = getMockClientWithPreviousCaughtUpTasks(mkSet(TASK_0_0), statefulTasks); + final ClientState client2 = getMockClientWithPreviousCaughtUpTasks(mkSet(TASK_0_1), statefulTasks); + + final Map clientStates = getClientStatesMap(client1, client2); + final boolean probingRebalanceNeeded = + new HighAvailabilityTaskAssignor().assign(clientStates, allTasks, statefulTasks, configWithStandbys); + + + assertThat(client1.activeTasks(), equalTo(mkSet(TASK_0_0))); + assertThat(client2.activeTasks(), equalTo(mkSet(TASK_0_1))); + assertThat(client1.standbyTasks(), equalTo(mkSet(TASK_0_1))); + assertThat(client2.standbyTasks(), equalTo(mkSet(TASK_0_0))); + assertThat(probingRebalanceNeeded, is(false)); + } + + @Test + public void shouldNotAssignStandbysForStatelessTasks() { + final Set allTasks = mkSet(TASK_0_0, TASK_0_1); + final Set statefulTasks = EMPTY_TASKS; + + final ClientState client1 = getMockClientWithPreviousCaughtUpTasks(EMPTY_TASKS, statefulTasks); + final ClientState client2 = getMockClientWithPreviousCaughtUpTasks(EMPTY_TASKS, statefulTasks); + + final Map clientStates = getClientStatesMap(client1, client2); + final boolean probingRebalanceNeeded = + new HighAvailabilityTaskAssignor().assign(clientStates, allTasks, statefulTasks, configWithStandbys); + + + assertThat(client1.activeTaskCount(), equalTo(1)); + assertThat(client2.activeTaskCount(), equalTo(1)); + assertHasNoStandbyTasks(client1, client2); + assertThat(probingRebalanceNeeded, is(false)); + } + + @Test + public void shouldAssignWarmupReplicasEvenIfNoStandbyReplicasConfigured() { + final Set allTasks = mkSet(TASK_0_0, TASK_0_1); + final Set statefulTasks = mkSet(TASK_0_0, TASK_0_1); + final ClientState client1 = getMockClientWithPreviousCaughtUpTasks(mkSet(TASK_0_0, TASK_0_1), statefulTasks); + final ClientState client2 = getMockClientWithPreviousCaughtUpTasks(EMPTY_TASKS, statefulTasks); + + final Map clientStates = getClientStatesMap(client1, client2); + final boolean probingRebalanceNeeded = + new HighAvailabilityTaskAssignor().assign(clientStates, allTasks, statefulTasks, configWithoutStandbys); + + + assertThat(client1.activeTasks(), equalTo(mkSet(TASK_0_0, TASK_0_1))); + assertThat(client2.standbyTaskCount(), equalTo(1)); + assertHasNoStandbyTasks(client1); + assertHasNoActiveTasks(client2); + assertThat(probingRebalanceNeeded, is(true)); + } + + + @Test + public void 
shouldNotAssignMoreThanMaxWarmupReplicas() { + final Set allTasks = mkSet(TASK_0_0, TASK_0_1, TASK_0_2, TASK_0_3); + final Set statefulTasks = mkSet(TASK_0_0, TASK_0_1, TASK_0_2, TASK_0_3); + final ClientState client1 = getMockClientWithPreviousCaughtUpTasks(mkSet(TASK_0_0, TASK_0_1, TASK_0_2, TASK_0_3), statefulTasks); + final ClientState client2 = getMockClientWithPreviousCaughtUpTasks(EMPTY_TASKS, statefulTasks); + + final Map clientStates = getClientStatesMap(client1, client2); + final boolean probingRebalanceNeeded = new HighAvailabilityTaskAssignor().assign( + clientStates, + allTasks, + statefulTasks, + new AssignmentConfigs( + /*acceptableRecoveryLag*/ 100L, + /*maxWarmupReplicas*/ 1, + /*numStandbyReplicas*/ 0, + /*probingRebalanceIntervalMs*/ 60 * 1000L + ) + ); + + + assertThat(client1.activeTasks(), equalTo(mkSet(TASK_0_0, TASK_0_1, TASK_0_2, TASK_0_3))); + assertThat(client2.standbyTaskCount(), equalTo(1)); + assertHasNoStandbyTasks(client1); + assertHasNoActiveTasks(client2); + assertThat(probingRebalanceNeeded, is(true)); + } + + @Test + public void shouldNotAssignWarmupAndStandbyToTheSameClient() { + final Set allTasks = mkSet(TASK_0_0, TASK_0_1, TASK_0_2, TASK_0_3); + final Set statefulTasks = mkSet(TASK_0_0, TASK_0_1, TASK_0_2, TASK_0_3); + final ClientState client1 = getMockClientWithPreviousCaughtUpTasks(mkSet(TASK_0_0, TASK_0_1, TASK_0_2, TASK_0_3), statefulTasks); + final ClientState client2 = getMockClientWithPreviousCaughtUpTasks(EMPTY_TASKS, statefulTasks); + + final Map clientStates = getClientStatesMap(client1, client2); + final boolean probingRebalanceNeeded = new HighAvailabilityTaskAssignor().assign( + clientStates, + allTasks, + statefulTasks, + new AssignmentConfigs( + /*acceptableRecoveryLag*/ 100L, + /*maxWarmupReplicas*/ 1, + /*numStandbyReplicas*/ 1, + /*probingRebalanceIntervalMs*/ 60 * 1000L + ) + ); + + assertThat(client1.activeTasks(), equalTo(mkSet(TASK_0_0, TASK_0_1, TASK_0_2, TASK_0_3))); + assertThat(client2.standbyTasks(), equalTo(mkSet(TASK_0_0, TASK_0_1, TASK_0_2, TASK_0_3))); + assertHasNoStandbyTasks(client1); + assertHasNoActiveTasks(client2); + assertThat(probingRebalanceNeeded, is(true)); + } + + @Test + public void shouldNotAssignAnyStandbysWithInsufficientCapacity() { + final Set allTasks = mkSet(TASK_0_0, TASK_0_1); + final Set statefulTasks = mkSet(TASK_0_0, TASK_0_1); + final ClientState client1 = getMockClientWithPreviousCaughtUpTasks(mkSet(TASK_0_0, TASK_0_1), statefulTasks); + + final Map clientStates = getClientStatesMap(client1); + final boolean probingRebalanceNeeded = + new HighAvailabilityTaskAssignor().assign(clientStates, allTasks, statefulTasks, configWithStandbys); + + assertThat(client1.activeTasks(), equalTo(mkSet(TASK_0_0, TASK_0_1))); + assertHasNoStandbyTasks(client1); + assertThat(probingRebalanceNeeded, is(false)); + } + + @Test + public void shouldAssignActiveTasksToNotCaughtUpClientIfNoneExist() { + final Set allTasks = mkSet(TASK_0_0, TASK_0_1); + final Set statefulTasks = mkSet(TASK_0_0, TASK_0_1); + final ClientState client1 = getMockClientWithPreviousCaughtUpTasks(EMPTY_TASKS, statefulTasks); + + final Map clientStates = getClientStatesMap(client1); + + final boolean probingRebalanceNeeded = + new HighAvailabilityTaskAssignor().assign(clientStates, allTasks, statefulTasks, configWithStandbys); + assertThat(client1.activeTasks(), equalTo(mkSet(TASK_0_0, TASK_0_1))); + assertHasNoStandbyTasks(client1); + assertThat(probingRebalanceNeeded, is(false)); + } + + @Test + public void 
shouldNotAssignMoreThanMaxWarmupReplicasWithStandbys() { + final Set allTasks = mkSet(TASK_0_0, TASK_0_1, TASK_0_2, TASK_0_3); + final Set statefulTasks = mkSet(TASK_0_0, TASK_0_1, TASK_0_2, TASK_0_3); + final ClientState client1 = getMockClientWithPreviousCaughtUpTasks(statefulTasks, statefulTasks); + final ClientState client2 = getMockClientWithPreviousCaughtUpTasks(EMPTY_TASKS, statefulTasks); + final ClientState client3 = getMockClientWithPreviousCaughtUpTasks(EMPTY_TASKS, statefulTasks); + + final Map clientStates = getClientStatesMap(client1, client2, client3); + + final boolean probingRebalanceNeeded = + new HighAvailabilityTaskAssignor().assign(clientStates, allTasks, statefulTasks, configWithStandbys); + + assertValidAssignment( + 1, + 2, + statefulTasks, + emptySet(), + clientStates, + new StringBuilder() + ); + assertThat(probingRebalanceNeeded, is(true)); + } + + @Test + public void shouldDistributeStatelessTasksToBalanceTotalTaskLoad() { + final Set allTasks = mkSet(TASK_0_0, TASK_0_1, TASK_0_2, TASK_0_3, TASK_1_0, TASK_1_1, TASK_1_2); + final Set statefulTasks = mkSet(TASK_0_0, TASK_0_1, TASK_0_2, TASK_0_3); + final Set statelessTasks = mkSet(TASK_1_0, TASK_1_1, TASK_1_2); + + final ClientState client1 = getMockClientWithPreviousCaughtUpTasks(statefulTasks, statefulTasks); + final ClientState client2 = getMockClientWithPreviousCaughtUpTasks(EMPTY_TASKS, statefulTasks); + + final Map clientStates = getClientStatesMap(client1, client2); + + final boolean probingRebalanceNeeded = + new HighAvailabilityTaskAssignor().assign(clientStates, allTasks, statefulTasks, configWithStandbys); + assertValidAssignment( + 1, + 2, + statefulTasks, + statelessTasks, + clientStates, + new StringBuilder() + ); + assertBalancedActiveAssignment(clientStates, new StringBuilder()); + assertBalancedStatefulAssignment(statefulTasks, clientStates, new StringBuilder()); + + // since only client1 is caught up on the stateful tasks, we expect it to get _all_ the active tasks, + // which means that client2 should have gotten all of the stateless tasks, so the tasks should be skewed + final AssignmentTestUtils.TaskSkewReport taskSkewReport = analyzeTaskAssignmentBalance(clientStates); + assertThat(taskSkewReport.toString(), taskSkewReport.skewedSubtopologies(), not(empty())); + + assertThat(probingRebalanceNeeded, is(true)); + } + + @Test + public void shouldDistributeStatefulActiveTasksToAllClients() { + final Set allTasks = + mkSet(TASK_0_0, TASK_0_1, TASK_0_2, TASK_0_3, TASK_1_0, TASK_1_1, TASK_1_2, TASK_1_3, TASK_2_0); // 9 total + final Map allTaskLags = allTasks.stream().collect(Collectors.toMap(t -> t, t -> 0L)); + final Set statefulTasks = new HashSet<>(allTasks); + final ClientState client1 = new ClientState(emptySet(), emptySet(), allTaskLags, 100); + final ClientState client2 = new ClientState(emptySet(), emptySet(), allTaskLags, 50); + final ClientState client3 = new ClientState(emptySet(), emptySet(), allTaskLags, 1); + + final Map clientStates = getClientStatesMap(client1, client2, client3); + + final boolean probingRebalanceNeeded = + new HighAvailabilityTaskAssignor().assign(clientStates, allTasks, statefulTasks, configWithoutStandbys); + + assertThat(client1.activeTasks(), not(empty())); + assertThat(client2.activeTasks(), not(empty())); + assertThat(client3.activeTasks(), not(empty())); + assertThat(probingRebalanceNeeded, is(false)); + } + + @Test + public void shouldReturnFalseIfPreviousAssignmentIsReused() { + final Set allTasks = mkSet(TASK_0_0, TASK_0_1, TASK_0_2, TASK_0_3); + final 
Set<TaskId> statefulTasks = new HashSet<>(allTasks); + final ClientState client1 = getMockClientWithPreviousCaughtUpTasks(mkSet(TASK_0_0, TASK_0_2), statefulTasks); + final ClientState client2 = getMockClientWithPreviousCaughtUpTasks(mkSet(TASK_0_1, TASK_0_3), statefulTasks); + + final Map<UUID, ClientState> clientStates = getClientStatesMap(client1, client2); + final boolean probingRebalanceNeeded = + new HighAvailabilityTaskAssignor().assign(clientStates, allTasks, statefulTasks, configWithoutStandbys); + + assertThat(probingRebalanceNeeded, is(false)); + assertThat(client1.activeTasks(), equalTo(client1.prevActiveTasks())); + assertThat(client2.activeTasks(), equalTo(client2.prevActiveTasks())); + } + + @Test + public void shouldReturnFalseIfNoWarmupTasksAreAssigned() { + final Set<TaskId> allTasks = mkSet(TASK_0_0, TASK_0_1, TASK_0_2, TASK_0_3); + final Set<TaskId> statefulTasks = EMPTY_TASKS; + final ClientState client1 = getMockClientWithPreviousCaughtUpTasks(EMPTY_TASKS, statefulTasks); + final ClientState client2 = getMockClientWithPreviousCaughtUpTasks(EMPTY_TASKS, statefulTasks); + + final Map<UUID, ClientState> clientStates = getClientStatesMap(client1, client2); + final boolean probingRebalanceNeeded = + new HighAvailabilityTaskAssignor().assign(clientStates, allTasks, statefulTasks, configWithoutStandbys); + assertThat(probingRebalanceNeeded, is(false)); + assertHasNoStandbyTasks(client1, client2); + } + + @Test + public void shouldReturnTrueIfWarmupTasksAreAssigned() { + final Set<TaskId> allTasks = mkSet(TASK_0_0, TASK_0_1); + final Set<TaskId> statefulTasks = mkSet(TASK_0_0, TASK_0_1); + final ClientState client1 = getMockClientWithPreviousCaughtUpTasks(allTasks, statefulTasks); + final ClientState client2 = getMockClientWithPreviousCaughtUpTasks(EMPTY_TASKS, statefulTasks); + + final Map<UUID, ClientState> clientStates = getClientStatesMap(client1, client2); + final boolean probingRebalanceNeeded = + new HighAvailabilityTaskAssignor().assign(clientStates, allTasks, statefulTasks, configWithoutStandbys); + assertThat(probingRebalanceNeeded, is(true)); + assertThat(client2.standbyTaskCount(), equalTo(1)); + } + + @Test + public void shouldDistributeStatelessTasksEvenlyOverClientsWithEqualStreamThreadsPerClientAsTasksAndNoStatefulTasks() { + final Set<TaskId> allTasks = mkSet(TASK_0_0, TASK_0_1, TASK_0_2, TASK_0_3, TASK_1_0, TASK_1_1, TASK_1_2); + final Set<TaskId> statefulTasks = EMPTY_TASKS; + final Set<TaskId> statelessTasks = new HashSet<>(allTasks); + + final Map<TaskId, Long> taskLags = new HashMap<>(); + final ClientState client1 = new ClientState(emptySet(), emptySet(), taskLags, 7); + final ClientState client2 = new ClientState(emptySet(), emptySet(), taskLags, 7); + final ClientState client3 = new ClientState(emptySet(), emptySet(), taskLags, 7); + + final Map<UUID, ClientState> clientStates = getClientStatesMap(client1, client2, client3); + + final boolean probingRebalanceNeeded = new HighAvailabilityTaskAssignor().assign( + clientStates, + allTasks, + statefulTasks, + new AssignmentConfigs(0L, 1, 0, 60_000L) + ); + + assertValidAssignment( + 0, + EMPTY_TASKS, + statelessTasks, + clientStates, + new StringBuilder() + ); + assertBalancedActiveAssignment(clientStates, new StringBuilder()); + assertThat(probingRebalanceNeeded, is(false)); + } + + @Test + public void shouldDistributeStatelessTasksEvenlyOverClientsWithLessStreamThreadsPerClientAsTasksAndNoStatefulTasks() { + final Set<TaskId> allTasks = mkSet(TASK_0_0, TASK_0_1, TASK_0_2, TASK_0_3, TASK_1_0, TASK_1_1, TASK_1_2); + final Set<TaskId> statefulTasks = EMPTY_TASKS; + final Set<TaskId> statelessTasks = new HashSet<>(allTasks); + + final Map<TaskId, Long> taskLags = new HashMap<>(); + final
ClientState client1 = new ClientState(emptySet(), emptySet(), taskLags, 2); + final ClientState client2 = new ClientState(emptySet(), emptySet(), taskLags, 2); + final ClientState client3 = new ClientState(emptySet(), emptySet(), taskLags, 2); + + final Map clientStates = getClientStatesMap(client1, client2, client3); + + final boolean probingRebalanceNeeded = new HighAvailabilityTaskAssignor().assign( + clientStates, + allTasks, + statefulTasks, + new AssignmentConfigs(0L, 1, 0, 60_000L) + ); + + assertValidAssignment( + 0, + EMPTY_TASKS, + statelessTasks, + clientStates, + new StringBuilder() + ); + assertBalancedActiveAssignment(clientStates, new StringBuilder()); + assertThat(probingRebalanceNeeded, is(false)); + } + + @Test + public void shouldDistributeStatelessTasksEvenlyOverClientsWithUnevenlyDistributedStreamThreadsAndNoStatefulTasks() { + final Set allTasks = mkSet(TASK_0_0, TASK_0_1, TASK_0_2, TASK_0_3, TASK_1_0, TASK_1_1, TASK_1_2); + final Set statefulTasks = EMPTY_TASKS; + final Set statelessTasks = new HashSet<>(allTasks); + + final Map taskLags = new HashMap<>(); + final ClientState client1 = new ClientState(emptySet(), emptySet(), taskLags, 1); + final ClientState client2 = new ClientState(emptySet(), emptySet(), taskLags, 2); + final ClientState client3 = new ClientState(emptySet(), emptySet(), taskLags, 3); + + final Map clientStates = getClientStatesMap(client1, client2, client3); + + final boolean probingRebalanceNeeded = new HighAvailabilityTaskAssignor().assign( + clientStates, + allTasks, + statefulTasks, + new AssignmentConfigs(0L, 1, 0, 60_000L) + ); + + assertValidAssignment( + 0, + EMPTY_TASKS, + statelessTasks, + clientStates, + new StringBuilder() + ); + assertBalancedActiveAssignment(clientStates, new StringBuilder()); + assertThat(probingRebalanceNeeded, is(false)); + } + + @Test + public void shouldDistributeStatelessTasksEvenlyWithPreviousAssignmentAndNoStatefulTasks() { + final Set allTasks = mkSet(TASK_0_0, TASK_0_1, TASK_0_2, TASK_0_3, TASK_1_0, TASK_1_1, TASK_1_2); + final Set statefulTasks = EMPTY_TASKS; + final Set statelessTasks = new HashSet<>(allTasks); + + final Map taskLags = new HashMap<>(); + final ClientState client1 = new ClientState(statelessTasks, emptySet(), taskLags, 3); + final ClientState client2 = new ClientState(emptySet(), emptySet(), taskLags, 3); + final ClientState client3 = new ClientState(emptySet(), emptySet(), taskLags, 3); + + final Map clientStates = getClientStatesMap(client1, client2, client3); + + final boolean probingRebalanceNeeded = new HighAvailabilityTaskAssignor().assign( + clientStates, + allTasks, + statefulTasks, + new AssignmentConfigs(0L, 1, 0, 60_000L) + ); + + assertValidAssignment( + 0, + EMPTY_TASKS, + statelessTasks, + clientStates, + new StringBuilder() + ); + assertBalancedActiveAssignment(clientStates, new StringBuilder()); + assertThat(probingRebalanceNeeded, is(false)); + } + + private static void assertHasNoActiveTasks(final ClientState... clients) { + for (final ClientState client : clients) { + assertThat(client.activeTasks(), is(empty())); + } + } + + private static void assertHasNoStandbyTasks(final ClientState... 
clients) { + for (final ClientState client : clients) { + assertThat(client, hasStandbyTasks(0)); + } + } + + private static ClientState getMockClientWithPreviousCaughtUpTasks(final Set statefulActiveTasks, + final Set statefulTasks) { + if (!statefulTasks.containsAll(statefulActiveTasks)) { + throw new IllegalArgumentException("Need to initialize stateful tasks set before creating mock clients"); + } + final Map taskLags = new HashMap<>(); + for (final TaskId task : statefulTasks) { + if (statefulActiveTasks.contains(task)) { + taskLags.put(task, 0L); + } else { + taskLags.put(task, Long.MAX_VALUE); + } + } + return new ClientState(statefulActiveTasks, emptySet(), taskLags, 1); + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/LegacySubscriptionInfoSerde.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/LegacySubscriptionInfoSerde.java new file mode 100644 index 0000000..5de0c22 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/LegacySubscriptionInfoSerde.java @@ -0,0 +1,280 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.processor.internals.assignment; + +import static org.apache.kafka.streams.processor.internals.assignment.ConsumerProtocolUtils.readTaskIdFrom; +import static org.apache.kafka.streams.processor.internals.assignment.ConsumerProtocolUtils.writeTaskIdTo; +import static org.apache.kafka.streams.processor.internals.assignment.StreamsAssignmentProtocolVersions.LATEST_SUPPORTED_VERSION; + +import org.apache.kafka.streams.errors.TaskAssignmentException; +import org.apache.kafka.streams.processor.TaskId; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; +import java.util.Collection; +import java.util.HashSet; +import java.util.Set; +import java.util.UUID; + +public class LegacySubscriptionInfoSerde { + + private static final Logger log = LoggerFactory.getLogger(LegacySubscriptionInfoSerde.class); + + static final int UNKNOWN = -1; + + private final int usedVersion; + private final int latestSupportedVersion; + private final UUID processId; + private final Set prevTasks; + private final Set standbyTasks; + private final String userEndPoint; + + public LegacySubscriptionInfoSerde(final int version, + final int latestSupportedVersion, + final UUID processId, + final Set prevTasks, + final Set standbyTasks, + final String userEndPoint) { + if (latestSupportedVersion == UNKNOWN && (version < 1 || version > 2)) { + throw new IllegalArgumentException( + "Only versions 1 and 2 are expected to use an UNKNOWN (-1) latest supported version. " + + "Got " + version + "." 
+ ); + } else if (latestSupportedVersion != UNKNOWN && (version < 1 || version > latestSupportedVersion)) { + throw new IllegalArgumentException( + "version must be between 1 and " + latestSupportedVersion + "; was: " + version + ); + } + usedVersion = version; + this.latestSupportedVersion = latestSupportedVersion; + this.processId = processId; + this.prevTasks = prevTasks; + this.standbyTasks = standbyTasks; + // Coerce empty string to null. This was the effect of the serialization logic, anyway. + this.userEndPoint = userEndPoint == null || userEndPoint.isEmpty() ? null : userEndPoint; + } + + public int version() { + return usedVersion; + } + + public int latestSupportedVersion() { + return latestSupportedVersion; + } + + public UUID processId() { + return processId; + } + + public Set prevTasks() { + return prevTasks; + } + + public Set standbyTasks() { + return standbyTasks; + } + + public String userEndPoint() { + return userEndPoint; + } + + /** + * @throws TaskAssignmentException if method fails to encode the data + */ + public ByteBuffer encode() { + if (usedVersion == 3 || usedVersion == 4 || usedVersion == 5 || usedVersion == 6) { + final byte[] endPointBytes = prepareUserEndPoint(this.userEndPoint); + + final ByteBuffer buf = ByteBuffer.allocate( + 4 + // used version + 4 + // latest supported version version + 16 + // client ID + 4 + prevTasks.size() * 8 + // length + prev tasks + 4 + standbyTasks.size() * 8 + // length + standby tasks + 4 + endPointBytes.length + ); + + buf.putInt(usedVersion); // used version + buf.putInt(LATEST_SUPPORTED_VERSION); // supported version + encodeClientUUID(buf, processId()); + encodeTasks(buf, prevTasks, usedVersion); + encodeTasks(buf, standbyTasks, usedVersion); + encodeUserEndPoint(buf, endPointBytes); + + buf.rewind(); + + return buf; + } else if (usedVersion == 2) { + final byte[] endPointBytes = prepareUserEndPoint(this.userEndPoint); + + final ByteBuffer buf = ByteBuffer.allocate( + 4 + // version + 16 + // client ID + 4 + prevTasks.size() * 8 + // length + prev tasks + 4 + standbyTasks.size() * 8 + // length + standby tasks + 4 + endPointBytes.length + ); + + buf.putInt(2); // version + encodeClientUUID(buf, processId()); + encodeTasks(buf, prevTasks, usedVersion); + encodeTasks(buf, standbyTasks, usedVersion); + encodeUserEndPoint(buf, endPointBytes); + + buf.rewind(); + + return buf; + } else if (usedVersion == 1) { + final ByteBuffer buf1 = ByteBuffer.allocate( + 4 + // version + 16 + // client ID + 4 + prevTasks.size() * 8 + // length + prev tasks + 4 + standbyTasks.size() * 8 + ); + + buf1.putInt(1); // version + encodeClientUUID(buf1, processId()); + encodeTasks(buf1, prevTasks, usedVersion); + encodeTasks(buf1, standbyTasks, usedVersion); + buf1.rewind(); + return buf1; + } else { + throw new IllegalStateException("Unknown metadata version: " + usedVersion + + "; latest supported version: " + LATEST_SUPPORTED_VERSION); + } + } + + public static void encodeClientUUID(final ByteBuffer buf, final UUID processId) { + buf.putLong(processId.getMostSignificantBits()); + buf.putLong(processId.getLeastSignificantBits()); + } + + public static void encodeTasks(final ByteBuffer buf, + final Collection taskIds, + final int version) { + buf.putInt(taskIds.size()); + for (final TaskId id : taskIds) { + writeTaskIdTo(id, buf, version); + } + } + + public static void encodeUserEndPoint(final ByteBuffer buf, + final byte[] endPointBytes) { + if (endPointBytes != null) { + buf.putInt(endPointBytes.length); + buf.put(endPointBytes); + } + } + + 
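+    /*
+     * Reader's note on the legacy wire format handled by this class (summarized from encode() above and
+     * decode() below; see those methods for the authoritative layout):
+     *   versions 3-6: used version (int), latest supported version (int), process UUID (two longs),
+     *                 prev task count (int) + task ids, standby task count (int) + task ids,
+     *                 endpoint byte length (int) + UTF-8 endpoint bytes
+     *   version 2:    version (int), process UUID, prev tasks, standby tasks, endpoint
+     *   version 1:    version (int), process UUID, prev tasks, standby tasks (no endpoint field)
+     */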
public static byte[] prepareUserEndPoint(final String userEndPoint) { + if (userEndPoint == null) { + return new byte[0]; + } else { + return userEndPoint.getBytes(StandardCharsets.UTF_8); + } + } + + /** + * @throws TaskAssignmentException if method fails to decode the data + */ + public static LegacySubscriptionInfoSerde decode(final ByteBuffer data) { + + // ensure we are at the beginning of the ByteBuffer + data.rewind(); + + final int usedVersion = data.getInt(); + if (usedVersion > 2 && usedVersion < 7) { + final int latestSupportedVersion = data.getInt(); + final UUID processId = decodeProcessId(data); + final Set<TaskId> prevTasks = decodeTasks(data, usedVersion); + final Set<TaskId> standbyTasks = decodeTasks(data, usedVersion); + final String userEndPoint = decodeUserEndpoint(data); + return new LegacySubscriptionInfoSerde(usedVersion, latestSupportedVersion, processId, prevTasks, standbyTasks, userEndPoint); + } else if (usedVersion == 2) { + final UUID processId = decodeProcessId(data); + final Set<TaskId> prevTasks = decodeTasks(data, usedVersion); + final Set<TaskId> standbyTasks = decodeTasks(data, usedVersion); + final String userEndPoint = decodeUserEndpoint(data); + return new LegacySubscriptionInfoSerde(2, UNKNOWN, processId, prevTasks, standbyTasks, userEndPoint); + } else if (usedVersion == 1) { + final UUID processId = decodeProcessId(data); + final Set<TaskId> prevTasks = decodeTasks(data, usedVersion); + final Set<TaskId> standbyTasks = decodeTasks(data, usedVersion); + return new LegacySubscriptionInfoSerde(1, UNKNOWN, processId, prevTasks, standbyTasks, null); + } else { + final int latestSupportedVersion = data.getInt(); + log.info("Unable to decode subscription data: used version: {}; latest supported version: {}", usedVersion, LATEST_SUPPORTED_VERSION); + return new LegacySubscriptionInfoSerde(usedVersion, latestSupportedVersion, null, null, null, null); + } + } + + private static String decodeUserEndpoint(final ByteBuffer data) { + final int userEndpointBytesLength = data.getInt(); + final byte[] userEndpointBytes = new byte[userEndpointBytesLength]; + data.get(userEndpointBytes); + return new String(userEndpointBytes, StandardCharsets.UTF_8); + } + + private static Set<TaskId> decodeTasks(final ByteBuffer data, final int version) { + final Set<TaskId> prevTasks = new HashSet<>(); + final int numPrevTasks = data.getInt(); + for (int i = 0; i < numPrevTasks; i++) { + prevTasks.add(readTaskIdFrom(data, version)); + } + return prevTasks; + } + + private static UUID decodeProcessId(final ByteBuffer data) { + return new UUID(data.getLong(), data.getLong()); + } + + @Override + public int hashCode() { + final int hashCode = usedVersion ^ latestSupportedVersion ^ processId.hashCode() ^ prevTasks.hashCode() ^ standbyTasks.hashCode(); + if (userEndPoint == null) { + return hashCode; + } + return hashCode ^ userEndPoint.hashCode(); + } + + @Override + public boolean equals(final Object o) { + if (o instanceof LegacySubscriptionInfoSerde) { + final LegacySubscriptionInfoSerde other = (LegacySubscriptionInfoSerde) o; + return usedVersion == other.usedVersion && + latestSupportedVersion == other.latestSupportedVersion && + processId.equals(other.processId) && + prevTasks.equals(other.prevTasks) && + standbyTasks.equals(other.standbyTasks) && + userEndPoint != null ?
userEndPoint.equals(other.userEndPoint) : other.userEndPoint == null; + } else { + return false; + } + } + + @Override + public String toString() { + return "[version=" + usedVersion + + ", supported version=" + latestSupportedVersion + + ", process ID=" + processId + + ", prev tasks=" + prevTasks + + ", standby tasks=" + standbyTasks + + ", user endpoint=" + userEndPoint + "]"; + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/StickyTaskAssignorTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/StickyTaskAssignorTest.java new file mode 100644 index 0000000..7536ad2 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/StickyTaskAssignorTest.java @@ -0,0 +1,759 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.processor.internals.assignment; + +import org.apache.kafka.streams.processor.TaskId; +import org.junit.Test; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.TreeMap; +import java.util.TreeSet; +import java.util.UUID; + +import static java.util.Arrays.asList; +import static java.util.Collections.singleton; +import static org.apache.kafka.common.utils.Utils.mkSet; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_0_0; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_0_1; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_0_2; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_0_3; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_0_4; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_0_5; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_0_6; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_1_0; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_1_1; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_1_2; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_1_3; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_2_0; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_2_1; +import static 
org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_2_2; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_2_3; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.UUID_1; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.UUID_2; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.UUID_3; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.UUID_4; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.UUID_5; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.UUID_6; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.empty; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.greaterThanOrEqualTo; +import static org.hamcrest.Matchers.hasItem; +import static org.hamcrest.Matchers.hasItems; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.lessThanOrEqualTo; +import static org.hamcrest.Matchers.not; + +public class StickyTaskAssignorTest { + + private final List expectedTopicGroupIds = asList(1, 2); + + private final Map clients = new TreeMap<>(); + + @Test + public void shouldAssignOneActiveTaskToEachProcessWhenTaskCountSameAsProcessCount() { + createClient(UUID_1, 1); + createClient(UUID_2, 1); + createClient(UUID_3, 1); + + final boolean probingRebalanceNeeded = assign(TASK_0_0, TASK_0_1, TASK_0_2); + assertThat(probingRebalanceNeeded, is(false)); + + for (final ClientState clientState : clients.values()) { + assertThat(clientState.activeTaskCount(), equalTo(1)); + } + } + + @Test + public void shouldAssignTopicGroupIdEvenlyAcrossClientsWithNoStandByTasks() { + createClient(UUID_1, 2); + createClient(UUID_2, 2); + createClient(UUID_3, 2); + + final boolean probingRebalanceNeeded = assign(TASK_1_0, TASK_1_1, TASK_2_2, TASK_2_0, TASK_2_1, TASK_1_2); + assertThat(probingRebalanceNeeded, is(false)); + + assertActiveTaskTopicGroupIdsEvenlyDistributed(); + } + + @Test + public void shouldAssignTopicGroupIdEvenlyAcrossClientsWithStandByTasks() { + createClient(UUID_1, 2); + createClient(UUID_2, 2); + createClient(UUID_3, 2); + + final boolean probingRebalanceNeeded = assign(1, TASK_2_0, TASK_1_1, TASK_1_2, TASK_1_0, TASK_2_1, TASK_2_2); + assertThat(probingRebalanceNeeded, is(false)); + + assertActiveTaskTopicGroupIdsEvenlyDistributed(); + } + + @Test + public void shouldNotMigrateActiveTaskToOtherProcess() { + createClientWithPreviousActiveTasks(UUID_1, 1, TASK_0_0); + createClientWithPreviousActiveTasks(UUID_2, 1, TASK_0_1); + + assertThat(assign(TASK_0_0, TASK_0_1, TASK_0_2), is(false)); + + assertThat(clients.get(UUID_1).activeTasks(), hasItems(TASK_0_0)); + assertThat(clients.get(UUID_2).activeTasks(), hasItems(TASK_0_1)); + assertThat(allActiveTasks(), equalTo(asList(TASK_0_0, TASK_0_1, TASK_0_2))); + + clients.clear(); + + // flip the previous active tasks assignment around. 
+ createClientWithPreviousActiveTasks(UUID_1, 1, TASK_0_1); + createClientWithPreviousActiveTasks(UUID_2, 1, TASK_0_2); + + assertThat(assign(TASK_0_0, TASK_0_1, TASK_0_2), is(false)); + + assertThat(clients.get(UUID_1).activeTasks(), hasItems(TASK_0_1)); + assertThat(clients.get(UUID_2).activeTasks(), hasItems(TASK_0_2)); + assertThat(allActiveTasks(), equalTo(asList(TASK_0_0, TASK_0_1, TASK_0_2))); + } + + @Test + public void shouldMigrateActiveTasksToNewProcessWithoutChangingAllAssignments() { + createClientWithPreviousActiveTasks(UUID_1, 1, TASK_0_0, TASK_0_2); + createClientWithPreviousActiveTasks(UUID_2, 1, TASK_0_1); + createClient(UUID_3, 1); + + final boolean probingRebalanceNeeded = assign(TASK_0_0, TASK_0_1, TASK_0_2); + + assertThat(probingRebalanceNeeded, is(false)); + assertThat(clients.get(UUID_2).activeTasks(), equalTo(singleton(TASK_0_1))); + assertThat(clients.get(UUID_1).activeTasks().size(), equalTo(1)); + assertThat(clients.get(UUID_3).activeTasks().size(), equalTo(1)); + assertThat(allActiveTasks(), equalTo(asList(TASK_0_0, TASK_0_1, TASK_0_2))); + } + + @Test + public void shouldAssignBasedOnCapacity() { + createClient(UUID_1, 1); + createClient(UUID_2, 2); + final boolean probingRebalanceNeeded = assign(TASK_0_0, TASK_0_1, TASK_0_2); + + assertThat(probingRebalanceNeeded, is(false)); + assertThat(clients.get(UUID_1).activeTasks().size(), equalTo(1)); + assertThat(clients.get(UUID_2).activeTasks().size(), equalTo(2)); + } + + @Test + public void shouldAssignTasksEvenlyWithUnequalTopicGroupSizes() { + createClientWithPreviousActiveTasks(UUID_1, 1, TASK_0_0, TASK_0_1, TASK_0_2, TASK_0_3, TASK_0_4, TASK_0_5, TASK_1_0); + + createClient(UUID_2, 1); + + assertThat(assign(TASK_1_0, TASK_0_0, TASK_0_1, TASK_0_2, TASK_0_3, TASK_0_4, TASK_0_5), is(false)); + + final Set allTasks = new HashSet<>(asList(TASK_0_0, TASK_0_1, TASK_1_0, TASK_0_5, TASK_0_2, TASK_0_3, TASK_0_4)); + final Set client1Tasks = clients.get(UUID_1).activeTasks(); + final Set client2Tasks = clients.get(UUID_2).activeTasks(); + + // one client should get 3 tasks and the other should have 4 + assertThat( + (client1Tasks.size() == 3 && client2Tasks.size() == 4) || + (client1Tasks.size() == 4 && client2Tasks.size() == 3), + is(true)); + allTasks.removeAll(client1Tasks); + // client2 should have all the remaining tasks not assigned to client 1 + assertThat(client2Tasks, equalTo(allTasks)); + } + + @Test + public void shouldKeepActiveTaskStickinessWhenMoreClientThanActiveTasks() { + createClientWithPreviousActiveTasks(UUID_1, 1, TASK_0_0); + createClientWithPreviousActiveTasks(UUID_2, 1, TASK_0_2); + createClientWithPreviousActiveTasks(UUID_3, 1, TASK_0_1); + createClient(UUID_4, 1); + createClient(UUID_5, 1); + + assertThat(assign(TASK_0_0, TASK_0_1, TASK_0_2), is(false)); + + assertThat(clients.get(UUID_1).activeTasks(), equalTo(singleton(TASK_0_0))); + assertThat(clients.get(UUID_2).activeTasks(), equalTo(singleton(TASK_0_2))); + assertThat(clients.get(UUID_3).activeTasks(), equalTo(singleton(TASK_0_1))); + + // change up the assignment and make sure it is still sticky + clients.clear(); + createClient(UUID_1, 1); + createClientWithPreviousActiveTasks(UUID_2, 1, TASK_0_0); + createClient(UUID_3, 1); + createClientWithPreviousActiveTasks(UUID_4, 1, TASK_0_2); + createClientWithPreviousActiveTasks(UUID_5, 1, TASK_0_1); + + assertThat(assign(TASK_0_0, TASK_0_1, TASK_0_2), is(false)); + + assertThat(clients.get(UUID_2).activeTasks(), equalTo(singleton(TASK_0_0))); + assertThat(clients.get(UUID_4).activeTasks(), 
equalTo(singleton(TASK_0_2))); + assertThat(clients.get(UUID_5).activeTasks(), equalTo(singleton(TASK_0_1))); + } + + @Test + public void shouldAssignTasksToClientWithPreviousStandbyTasks() { + final ClientState client1 = createClient(UUID_1, 1); + client1.addPreviousStandbyTasks(mkSet(TASK_0_2)); + final ClientState client2 = createClient(UUID_2, 1); + client2.addPreviousStandbyTasks(mkSet(TASK_0_1)); + final ClientState client3 = createClient(UUID_3, 1); + client3.addPreviousStandbyTasks(mkSet(TASK_0_0)); + + final boolean probingRebalanceNeeded = assign(TASK_0_0, TASK_0_1, TASK_0_2); + + assertThat(probingRebalanceNeeded, is(false)); + + assertThat(clients.get(UUID_1).activeTasks(), equalTo(singleton(TASK_0_2))); + assertThat(clients.get(UUID_2).activeTasks(), equalTo(singleton(TASK_0_1))); + assertThat(clients.get(UUID_3).activeTasks(), equalTo(singleton(TASK_0_0))); + } + + @Test + public void shouldAssignBasedOnCapacityWhenMultipleClientHaveStandbyTasks() { + final ClientState c1 = createClientWithPreviousActiveTasks(UUID_1, 1, TASK_0_0); + c1.addPreviousStandbyTasks(mkSet(TASK_0_1)); + final ClientState c2 = createClientWithPreviousActiveTasks(UUID_2, 2, TASK_0_2); + c2.addPreviousStandbyTasks(mkSet(TASK_0_1)); + + final boolean probingRebalanceNeeded = assign(TASK_0_0, TASK_0_1, TASK_0_2); + + assertThat(probingRebalanceNeeded, is(false)); + + assertThat(clients.get(UUID_1).activeTasks(), equalTo(singleton(TASK_0_0))); + assertThat(clients.get(UUID_2).activeTasks(), equalTo(mkSet(TASK_0_2, TASK_0_1))); + } + + @Test + public void shouldAssignStandbyTasksToDifferentClientThanCorrespondingActiveTaskIsAssignedTo() { + createClientWithPreviousActiveTasks(UUID_1, 1, TASK_0_0); + createClientWithPreviousActiveTasks(UUID_2, 1, TASK_0_1); + createClientWithPreviousActiveTasks(UUID_3, 1, TASK_0_2); + createClientWithPreviousActiveTasks(UUID_4, 1, TASK_0_3); + + final boolean probingRebalanceNeeded = assign(1, TASK_0_0, TASK_0_1, TASK_0_2, TASK_0_3); + assertThat(probingRebalanceNeeded, is(false)); + + + assertThat(clients.get(UUID_1).standbyTasks(), not(hasItems(TASK_0_0))); + assertThat(clients.get(UUID_1).standbyTasks().size(), lessThanOrEqualTo(2)); + assertThat(clients.get(UUID_2).standbyTasks(), not(hasItems(TASK_0_1))); + assertThat(clients.get(UUID_2).standbyTasks().size(), lessThanOrEqualTo(2)); + assertThat(clients.get(UUID_3).standbyTasks(), not(hasItems(TASK_0_2))); + assertThat(clients.get(UUID_3).standbyTasks().size(), lessThanOrEqualTo(2)); + assertThat(clients.get(UUID_4).standbyTasks(), not(hasItems(TASK_0_3))); + assertThat(clients.get(UUID_4).standbyTasks().size(), lessThanOrEqualTo(2)); + + int nonEmptyStandbyTaskCount = 0; + for (final ClientState clientState : clients.values()) { + nonEmptyStandbyTaskCount += clientState.standbyTasks().isEmpty() ? 
0 : 1; + } + + assertThat(nonEmptyStandbyTaskCount, greaterThanOrEqualTo(3)); + assertThat(allStandbyTasks(), equalTo(asList(TASK_0_0, TASK_0_1, TASK_0_2, TASK_0_3))); + } + + @Test + public void shouldAssignMultipleReplicasOfStandbyTask() { + createClientWithPreviousActiveTasks(UUID_1, 1, TASK_0_0); + createClientWithPreviousActiveTasks(UUID_2, 1, TASK_0_1); + createClientWithPreviousActiveTasks(UUID_3, 1, TASK_0_2); + + final boolean probingRebalanceNeeded = assign(2, TASK_0_0, TASK_0_1, TASK_0_2); + assertThat(probingRebalanceNeeded, is(false)); + + assertThat(clients.get(UUID_1).standbyTasks(), equalTo(mkSet(TASK_0_1, TASK_0_2))); + assertThat(clients.get(UUID_2).standbyTasks(), equalTo(mkSet(TASK_0_2, TASK_0_0))); + assertThat(clients.get(UUID_3).standbyTasks(), equalTo(mkSet(TASK_0_0, TASK_0_1))); + } + + @Test + public void shouldNotAssignStandbyTaskReplicasWhenNoClientAvailableWithoutHavingTheTaskAssigned() { + createClient(UUID_1, 1); + final boolean probingRebalanceNeeded = assign(1, TASK_0_0); + assertThat(probingRebalanceNeeded, is(false)); + assertThat(clients.get(UUID_1).standbyTasks().size(), equalTo(0)); + } + + @Test + public void shouldAssignActiveAndStandbyTasks() { + createClient(UUID_1, 1); + createClient(UUID_2, 1); + createClient(UUID_3, 1); + + final boolean probingRebalanceNeeded = assign(1, TASK_0_0, TASK_0_1, TASK_0_2); + assertThat(probingRebalanceNeeded, is(false)); + + assertThat(allActiveTasks(), equalTo(asList(TASK_0_0, TASK_0_1, TASK_0_2))); + assertThat(allStandbyTasks(), equalTo(asList(TASK_0_0, TASK_0_1, TASK_0_2))); + } + + @Test + public void shouldAssignAtLeastOneTaskToEachClientIfPossible() { + createClient(UUID_1, 3); + createClient(UUID_2, 1); + createClient(UUID_3, 1); + + final boolean probingRebalanceNeeded = assign(TASK_0_0, TASK_0_1, TASK_0_2); + assertThat(probingRebalanceNeeded, is(false)); + assertThat(clients.get(UUID_1).assignedTaskCount(), equalTo(1)); + assertThat(clients.get(UUID_2).assignedTaskCount(), equalTo(1)); + assertThat(clients.get(UUID_3).assignedTaskCount(), equalTo(1)); + } + + @Test + public void shouldAssignEachActiveTaskToOneClientWhenMoreClientsThanTasks() { + createClient(UUID_1, 1); + createClient(UUID_2, 1); + createClient(UUID_3, 1); + createClient(UUID_4, 1); + createClient(UUID_5, 1); + createClient(UUID_6, 1); + + final boolean probingRebalanceNeeded = assign(TASK_0_0, TASK_0_1, TASK_0_2); + assertThat(probingRebalanceNeeded, is(false)); + + assertThat(allActiveTasks(), equalTo(asList(TASK_0_0, TASK_0_1, TASK_0_2))); + } + + @Test + public void shouldBalanceActiveAndStandbyTasksAcrossAvailableClients() { + createClient(UUID_1, 1); + createClient(UUID_2, 1); + createClient(UUID_3, 1); + createClient(UUID_4, 1); + createClient(UUID_5, 1); + createClient(UUID_6, 1); + + final boolean probingRebalanceNeeded = assign(1, TASK_0_0, TASK_0_1, TASK_0_2); + assertThat(probingRebalanceNeeded, is(false)); + + for (final ClientState clientState : clients.values()) { + assertThat(clientState.assignedTaskCount(), equalTo(1)); + } + } + + @Test + public void shouldAssignMoreTasksToClientWithMoreCapacity() { + createClient(UUID_2, 2); + createClient(UUID_1, 1); + + final boolean probingRebalanceNeeded = assign( + TASK_0_0, + TASK_0_1, + TASK_0_2, + new TaskId(1, 0), + new TaskId(1, 1), + new TaskId(1, 2), + new TaskId(2, 0), + new TaskId(2, 1), + new TaskId(2, 2), + new TaskId(3, 0), + new TaskId(3, 1), + new TaskId(3, 2) + ); + + assertThat(probingRebalanceNeeded, is(false)); + assertThat(clients.get(UUID_2).assignedTaskCount(), 
equalTo(8)); + assertThat(clients.get(UUID_1).assignedTaskCount(), equalTo(4)); + } + + @Test + public void shouldEvenlyDistributeByTaskIdAndPartition() { + createClient(UUID_1, 4); + createClient(UUID_2, 4); + createClient(UUID_3, 4); + createClient(UUID_4, 4); + + final List<TaskId> taskIds = new ArrayList<>(); + final TaskId[] taskIdArray = new TaskId[16]; + + for (int i = 1; i <= 2; i++) { + for (int j = 0; j < 8; j++) { + taskIds.add(new TaskId(i, j)); + } + } + + Collections.shuffle(taskIds); + taskIds.toArray(taskIdArray); + + final boolean probingRebalanceNeeded = assign(taskIdArray); + assertThat(probingRebalanceNeeded, is(false)); + + Collections.sort(taskIds); + final Set<TaskId> expectedClientOneAssignment = getExpectedTaskIdAssignment(taskIds, 0, 4, 8, 12); + final Set<TaskId> expectedClientTwoAssignment = getExpectedTaskIdAssignment(taskIds, 1, 5, 9, 13); + final Set<TaskId> expectedClientThreeAssignment = getExpectedTaskIdAssignment(taskIds, 2, 6, 10, 14); + final Set<TaskId> expectedClientFourAssignment = getExpectedTaskIdAssignment(taskIds, 3, 7, 11, 15); + + final Map<UUID, Set<TaskId>> sortedAssignments = sortClientAssignments(clients); + + assertThat(sortedAssignments.get(UUID_1), equalTo(expectedClientOneAssignment)); + assertThat(sortedAssignments.get(UUID_2), equalTo(expectedClientTwoAssignment)); + assertThat(sortedAssignments.get(UUID_3), equalTo(expectedClientThreeAssignment)); + assertThat(sortedAssignments.get(UUID_4), equalTo(expectedClientFourAssignment)); + } + + @Test + public void shouldNotHaveSameAssignmentOnAnyTwoHosts() { + final List<UUID> allUUIDs = asList(UUID_1, UUID_2, UUID_3, UUID_4); + createClient(UUID_1, 1); + createClient(UUID_2, 1); + createClient(UUID_3, 1); + createClient(UUID_4, 1); + + final boolean probingRebalanceNeeded = assign(1, TASK_0_0, TASK_0_2, TASK_0_1, TASK_0_3); + assertThat(probingRebalanceNeeded, is(false)); + + for (final UUID uuid : allUUIDs) { + final Set<TaskId> taskIds = clients.get(uuid).assignedTasks(); + for (final UUID otherUUID : allUUIDs) { + if (!uuid.equals(otherUUID)) { + assertThat("clients shouldn't have same task assignment", clients.get(otherUUID).assignedTasks(), + not(equalTo(taskIds))); + } + } + + } + } + + @Test + public void shouldNotHaveSameAssignmentOnAnyTwoHostsWhenThereArePreviousActiveTasks() { + final List<UUID> allUUIDs = asList(UUID_1, UUID_2, UUID_3); + createClientWithPreviousActiveTasks(UUID_1, 1, TASK_0_1, TASK_0_2); + createClientWithPreviousActiveTasks(UUID_2, 1, TASK_0_3); + createClientWithPreviousActiveTasks(UUID_3, 1, TASK_0_0); + createClient(UUID_4, 1); + + final boolean probingRebalanceNeeded = assign(1, TASK_0_0, TASK_0_2, TASK_0_1, TASK_0_3); + assertThat(probingRebalanceNeeded, is(false)); + + for (final UUID uuid : allUUIDs) { + final Set<TaskId> taskIds = clients.get(uuid).assignedTasks(); + for (final UUID otherUUID : allUUIDs) { + if (!uuid.equals(otherUUID)) { + assertThat("clients shouldn't have same task assignment", clients.get(otherUUID).assignedTasks(), + not(equalTo(taskIds))); + } + } + + } + } + + @Test + public void shouldNotHaveSameAssignmentOnAnyTwoHostsWhenThereArePreviousStandbyTasks() { + final List<UUID> allUUIDs = asList(UUID_1, UUID_2, UUID_3, UUID_4); + + final ClientState c1 = createClientWithPreviousActiveTasks(UUID_1, 1, TASK_0_1, TASK_0_2); + c1.addPreviousStandbyTasks(mkSet(TASK_0_3, TASK_0_0)); + final ClientState c2 = createClientWithPreviousActiveTasks(UUID_2, 1, TASK_0_3, TASK_0_0); + c2.addPreviousStandbyTasks(mkSet(TASK_0_1, TASK_0_2)); + + createClient(UUID_3, 1); + createClient(UUID_4, 1); + + final boolean probingRebalanceNeeded = 
assign(1, TASK_0_0, TASK_0_2, TASK_0_1, TASK_0_3); + assertThat(probingRebalanceNeeded, is(false)); + + for (final UUID uuid : allUUIDs) { + final Set taskIds = clients.get(uuid).assignedTasks(); + for (final UUID otherUUID : allUUIDs) { + if (!uuid.equals(otherUUID)) { + assertThat("clients shouldn't have same task assignment", clients.get(otherUUID).assignedTasks(), + not(equalTo(taskIds))); + } + } + + } + } + + @Test + public void shouldReBalanceTasksAcrossAllClientsWhenCapacityAndTaskCountTheSame() { + createClientWithPreviousActiveTasks(UUID_3, 1, TASK_0_0, TASK_0_1, TASK_0_2, TASK_0_3); + createClient(UUID_1, 1); + createClient(UUID_2, 1); + createClient(UUID_4, 1); + + final boolean probingRebalanceNeeded = assign(TASK_0_0, TASK_0_2, TASK_0_1, TASK_0_3); + assertThat(probingRebalanceNeeded, is(false)); + + assertThat(clients.get(UUID_1).assignedTaskCount(), equalTo(1)); + assertThat(clients.get(UUID_2).assignedTaskCount(), equalTo(1)); + assertThat(clients.get(UUID_3).assignedTaskCount(), equalTo(1)); + assertThat(clients.get(UUID_4).assignedTaskCount(), equalTo(1)); + } + + @Test + public void shouldReBalanceTasksAcrossClientsWhenCapacityLessThanTaskCount() { + createClientWithPreviousActiveTasks(UUID_3, 1, TASK_0_0, TASK_0_1, TASK_0_2, TASK_0_3); + createClient(UUID_1, 1); + createClient(UUID_2, 1); + + final boolean probingRebalanceNeeded = assign(TASK_0_0, TASK_0_2, TASK_0_1, TASK_0_3); + assertThat(probingRebalanceNeeded, is(false)); + + assertThat(clients.get(UUID_3).assignedTaskCount(), equalTo(2)); + assertThat(clients.get(UUID_1).assignedTaskCount(), equalTo(1)); + assertThat(clients.get(UUID_2).assignedTaskCount(), equalTo(1)); + } + + @Test + public void shouldRebalanceTasksToClientsBasedOnCapacity() { + createClientWithPreviousActiveTasks(UUID_2, 1, TASK_0_0, TASK_0_3, TASK_0_2); + createClient(UUID_3, 2); + final boolean probingRebalanceNeeded = assign(TASK_0_0, TASK_0_2, TASK_0_3); + assertThat(probingRebalanceNeeded, is(false)); + assertThat(clients.get(UUID_2).assignedTaskCount(), equalTo(1)); + assertThat(clients.get(UUID_3).assignedTaskCount(), equalTo(2)); + } + + @Test + public void shouldMoveMinimalNumberOfTasksWhenPreviouslyAboveCapacityAndNewClientAdded() { + final Set p1PrevTasks = mkSet(TASK_0_0, TASK_0_2); + final Set p2PrevTasks = mkSet(TASK_0_1, TASK_0_3); + + createClientWithPreviousActiveTasks(UUID_1, 1, TASK_0_0, TASK_0_2); + createClientWithPreviousActiveTasks(UUID_2, 1, TASK_0_1, TASK_0_3); + createClientWithPreviousActiveTasks(UUID_3, 1); + + final boolean probingRebalanceNeeded = assign(TASK_0_0, TASK_0_2, TASK_0_1, TASK_0_3); + assertThat(probingRebalanceNeeded, is(false)); + + final Set p3ActiveTasks = clients.get(UUID_3).activeTasks(); + assertThat(p3ActiveTasks.size(), equalTo(1)); + if (p1PrevTasks.removeAll(p3ActiveTasks)) { + assertThat(clients.get(UUID_2).activeTasks(), equalTo(p2PrevTasks)); + } else { + assertThat(clients.get(UUID_1).activeTasks(), equalTo(p1PrevTasks)); + } + } + + @Test + public void shouldNotMoveAnyTasksWhenNewTasksAdded() { + createClientWithPreviousActiveTasks(UUID_1, 1, TASK_0_0, TASK_0_1); + createClientWithPreviousActiveTasks(UUID_2, 1, TASK_0_2, TASK_0_3); + + final boolean probingRebalanceNeeded = assign(TASK_0_3, TASK_0_1, TASK_0_4, TASK_0_2, TASK_0_0, TASK_0_5); + assertThat(probingRebalanceNeeded, is(false)); + + assertThat(clients.get(UUID_1).activeTasks(), hasItems(TASK_0_0, TASK_0_1)); + assertThat(clients.get(UUID_2).activeTasks(), hasItems(TASK_0_2, TASK_0_3)); + } + + @Test + public void 
shouldAssignNewTasksToNewClientWhenPreviousTasksAssignedToOldClients() { + + createClientWithPreviousActiveTasks(UUID_1, 1, TASK_0_2, TASK_0_1); + createClientWithPreviousActiveTasks(UUID_2, 1, TASK_0_0, TASK_0_3); + createClient(UUID_3, 1); + + final boolean probingRebalanceNeeded = assign(TASK_0_3, TASK_0_1, TASK_0_4, TASK_0_2, TASK_0_0, TASK_0_5); + assertThat(probingRebalanceNeeded, is(false)); + + assertThat(clients.get(UUID_1).activeTasks(), hasItems(TASK_0_2, TASK_0_1)); + assertThat(clients.get(UUID_2).activeTasks(), hasItems(TASK_0_0, TASK_0_3)); + assertThat(clients.get(UUID_3).activeTasks(), hasItems(TASK_0_4, TASK_0_5)); + } + + @Test + public void shouldAssignTasksNotPreviouslyActiveToNewClient() { + final ClientState c1 = createClientWithPreviousActiveTasks(UUID_1, 1, TASK_0_1, TASK_1_2, TASK_1_3); + c1.addPreviousStandbyTasks(mkSet(TASK_0_0, TASK_1_1, TASK_2_0, TASK_2_1, TASK_2_3)); + final ClientState c2 = createClientWithPreviousActiveTasks(UUID_2, 1, TASK_0_0, TASK_1_1, TASK_2_2); + c2.addPreviousStandbyTasks(mkSet(TASK_0_1, TASK_1_0, TASK_0_2, TASK_2_0, TASK_0_3, TASK_1_2, TASK_2_1, TASK_1_3, TASK_2_3)); + final ClientState c3 = createClientWithPreviousActiveTasks(UUID_3, 1, TASK_2_0, TASK_2_1, TASK_2_3); + c3.addPreviousStandbyTasks(mkSet(TASK_0_2, TASK_1_2)); + + final ClientState newClient = createClient(UUID_4, 1); + newClient.addPreviousStandbyTasks(mkSet(TASK_0_0, TASK_1_0, TASK_0_1, TASK_0_2, TASK_1_1, TASK_2_0, TASK_0_3, TASK_1_2, TASK_2_1, TASK_1_3, TASK_2_2, TASK_2_3)); + + final boolean probingRebalanceNeeded = assign(TASK_0_0, TASK_1_0, TASK_0_1, TASK_0_2, TASK_1_1, TASK_2_0, TASK_0_3, TASK_1_2, TASK_2_1, TASK_1_3, TASK_2_2, TASK_2_3); + assertThat(probingRebalanceNeeded, is(false)); + + assertThat(c1.activeTasks(), equalTo(mkSet(TASK_0_1, TASK_1_2, TASK_1_3))); + assertThat(c2.activeTasks(), equalTo(mkSet(TASK_0_0, TASK_1_1, TASK_2_2))); + assertThat(c3.activeTasks(), equalTo(mkSet(TASK_2_0, TASK_2_1, TASK_2_3))); + assertThat(newClient.activeTasks(), equalTo(mkSet(TASK_0_2, TASK_0_3, TASK_1_0))); + } + + @Test + public void shouldAssignTasksNotPreviouslyActiveToMultipleNewClients() { + final ClientState c1 = createClientWithPreviousActiveTasks(UUID_1, 1, TASK_0_1, TASK_1_2, TASK_1_3); + c1.addPreviousStandbyTasks(mkSet(TASK_0_0, TASK_1_1, TASK_2_0, TASK_2_1, TASK_2_3)); + final ClientState c2 = createClientWithPreviousActiveTasks(UUID_2, 1, TASK_0_0, TASK_1_1, TASK_2_2); + c2.addPreviousStandbyTasks(mkSet(TASK_0_1, TASK_1_0, TASK_0_2, TASK_2_0, TASK_0_3, TASK_1_2, TASK_2_1, TASK_1_3, TASK_2_3)); + + final ClientState bounce1 = createClient(UUID_3, 1); + bounce1.addPreviousStandbyTasks(mkSet(TASK_2_0, TASK_2_1, TASK_2_3)); + + final ClientState bounce2 = createClient(UUID_4, 1); + bounce2.addPreviousStandbyTasks(mkSet(TASK_0_2, TASK_0_3, TASK_1_0)); + + final boolean probingRebalanceNeeded = assign(TASK_0_0, TASK_1_0, TASK_0_1, TASK_0_2, TASK_1_1, TASK_2_0, TASK_0_3, TASK_1_2, TASK_2_1, TASK_1_3, TASK_2_2, TASK_2_3); + assertThat(probingRebalanceNeeded, is(false)); + + assertThat(c1.activeTasks(), equalTo(mkSet(TASK_0_1, TASK_1_2, TASK_1_3))); + assertThat(c2.activeTasks(), equalTo(mkSet(TASK_0_0, TASK_1_1, TASK_2_2))); + assertThat(bounce1.activeTasks(), equalTo(mkSet(TASK_2_0, TASK_2_1, TASK_2_3))); + assertThat(bounce2.activeTasks(), equalTo(mkSet(TASK_0_2, TASK_0_3, TASK_1_0))); + } + + @Test + public void shouldAssignTasksToNewClient() { + createClientWithPreviousActiveTasks(UUID_1, 1, TASK_0_1, TASK_0_2); + createClient(UUID_2, 1); + 
assertThat(assign(TASK_0_1, TASK_0_2), is(false)); + assertThat(clients.get(UUID_1).activeTaskCount(), equalTo(1)); + } + + @Test + public void shouldAssignTasksToNewClientWithoutFlippingAssignmentBetweenExistingClients() { + final ClientState c1 = createClientWithPreviousActiveTasks(UUID_1, 1, TASK_0_0, TASK_0_1, TASK_0_2); + final ClientState c2 = createClientWithPreviousActiveTasks(UUID_2, 1, TASK_0_3, TASK_0_4, TASK_0_5); + final ClientState newClient = createClient(UUID_3, 1); + + final boolean probingRebalanceNeeded = assign(TASK_0_0, TASK_0_1, TASK_0_2, TASK_0_3, TASK_0_4, TASK_0_5); + assertThat(probingRebalanceNeeded, is(false)); + assertThat(c1.activeTasks(), not(hasItem(TASK_0_3))); + assertThat(c1.activeTasks(), not(hasItem(TASK_0_4))); + assertThat(c1.activeTasks(), not(hasItem(TASK_0_5))); + assertThat(c1.activeTaskCount(), equalTo(2)); + assertThat(c2.activeTasks(), not(hasItems(TASK_0_0))); + assertThat(c2.activeTasks(), not(hasItems(TASK_0_1))); + assertThat(c2.activeTasks(), not(hasItems(TASK_0_2))); + assertThat(c2.activeTaskCount(), equalTo(2)); + assertThat(newClient.activeTaskCount(), equalTo(2)); + } + + @Test + public void shouldAssignTasksToNewClientWithoutFlippingAssignmentBetweenExistingAndBouncedClients() { + final ClientState c1 = createClientWithPreviousActiveTasks(UUID_1, 1, TASK_0_0, TASK_0_1, TASK_0_2, TASK_0_6); + final ClientState c2 = createClient(UUID_2, 1); + c2.addPreviousStandbyTasks(mkSet(TASK_0_3, TASK_0_4, TASK_0_5)); + final ClientState newClient = createClient(UUID_3, 1); + + final boolean probingRebalanceNeeded = assign(TASK_0_0, TASK_0_1, TASK_0_2, TASK_0_3, TASK_0_4, TASK_0_5, TASK_0_6); + assertThat(probingRebalanceNeeded, is(false)); + + // it's possible for either client 1 or 2 to get three tasks since they both had three previously assigned + assertThat(c1.activeTasks(), not(hasItem(TASK_0_3))); + assertThat(c1.activeTasks(), not(hasItem(TASK_0_4))); + assertThat(c1.activeTasks(), not(hasItem(TASK_0_5))); + assertThat(c1.activeTaskCount(), greaterThanOrEqualTo(2)); + assertThat(c2.activeTasks(), not(hasItems(TASK_0_0))); + assertThat(c2.activeTasks(), not(hasItems(TASK_0_1))); + assertThat(c2.activeTasks(), not(hasItems(TASK_0_2))); + assertThat(c2.activeTaskCount(), greaterThanOrEqualTo(2)); + assertThat(newClient.activeTaskCount(), equalTo(2)); + } + + @Test + public void shouldViolateBalanceToPreserveActiveTaskStickiness() { + final ClientState c1 = createClientWithPreviousActiveTasks(UUID_1, 1, TASK_0_0, TASK_0_1, TASK_0_2); + final ClientState c2 = createClient(UUID_2, 1); + + final List taskIds = asList(TASK_0_0, TASK_0_1, TASK_0_2); + Collections.shuffle(taskIds); + final boolean probingRebalanceNeeded = new StickyTaskAssignor(true).assign( + clients, + new HashSet<>(taskIds), + new HashSet<>(taskIds), + new AssignorConfiguration.AssignmentConfigs(0L, 1, 0, 60_000L) + ); + assertThat(probingRebalanceNeeded, is(false)); + + assertThat(c1.activeTasks(), equalTo(mkSet(TASK_0_0, TASK_0_1, TASK_0_2))); + assertThat(c2.activeTasks(), empty()); + } + + private boolean assign(final TaskId... tasks) { + return assign(0, tasks); + } + + private boolean assign(final int numStandbys, final TaskId... 
tasks) { + final List<TaskId> taskIds = asList(tasks); + Collections.shuffle(taskIds); + return new StickyTaskAssignor().assign( + clients, + new HashSet<>(taskIds), + new HashSet<>(taskIds), + new AssignorConfiguration.AssignmentConfigs(0L, 1, numStandbys, 60_000L) + ); + } + + private List<TaskId> allActiveTasks() { + final List<TaskId> allActive = new ArrayList<>(); + for (final ClientState client : clients.values()) { + allActive.addAll(client.activeTasks()); + } + Collections.sort(allActive); + return allActive; + } + + private List<TaskId> allStandbyTasks() { + final List<TaskId> tasks = new ArrayList<>(); + for (final ClientState client : clients.values()) { + tasks.addAll(client.standbyTasks()); + } + Collections.sort(tasks); + return tasks; + } + + private ClientState createClient(final UUID processId, final int capacity) { + return createClientWithPreviousActiveTasks(processId, capacity); + } + + private ClientState createClientWithPreviousActiveTasks(final UUID processId, final int capacity, final TaskId... taskIds) { + final ClientState clientState = new ClientState(capacity); + clientState.addPreviousActiveTasks(mkSet(taskIds)); + clients.put(processId, clientState); + return clientState; + } + + private void assertActiveTaskTopicGroupIdsEvenlyDistributed() { + for (final Map.Entry<UUID, ClientState> clientStateEntry : clients.entrySet()) { + final List<Integer> topicGroupIds = new ArrayList<>(); + final Set<TaskId> activeTasks = clientStateEntry.getValue().activeTasks(); + for (final TaskId activeTask : activeTasks) { + topicGroupIds.add(activeTask.subtopology()); + } + Collections.sort(topicGroupIds); + assertThat(topicGroupIds, equalTo(expectedTopicGroupIds)); + } + } + + private static Map<UUID, Set<TaskId>> sortClientAssignments(final Map<UUID, ClientState> clients) { + final Map<UUID, Set<TaskId>> sortedAssignments = new HashMap<>(); + for (final Map.Entry<UUID, ClientState> entry : clients.entrySet()) { + final Set<TaskId> sorted = new TreeSet<>(entry.getValue().activeTasks()); + sortedAssignments.put(entry.getKey(), sorted); + } + return sortedAssignments; + } + + private static Set<TaskId> getExpectedTaskIdAssignment(final List<TaskId> tasks, final int... indices) { + final Set<TaskId> sortedAssignment = new TreeSet<>(); + for (final int index : indices) { + sortedAssignment.add(tasks.get(index)); + } + return sortedAssignment; + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/SubscriptionInfoTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/SubscriptionInfoTest.java new file mode 100644 index 0000000..dd65196 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/SubscriptionInfoTest.java @@ -0,0 +1,445 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.processor.internals.assignment; + +import java.util.Map; + +import org.apache.kafka.streams.errors.TaskAssignmentException; +import org.apache.kafka.streams.processor.TaskId; +import org.apache.kafka.streams.processor.internals.Task; +import org.junit.Test; + +import java.nio.ByteBuffer; +import java.util.Arrays; +import java.util.HashSet; +import java.util.Set; + +import static org.apache.kafka.common.utils.Utils.mkEntry; +import static org.apache.kafka.common.utils.Utils.mkMap; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.NAMED_TASK_T0_0_0; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.NAMED_TASK_T0_0_1; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.NAMED_TASK_T0_1_0; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.NAMED_TASK_T0_1_1; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.NAMED_TASK_T1_0_0; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.NAMED_TASK_T1_0_1; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.NAMED_TASK_T2_0_0; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.NAMED_TASK_T2_2_0; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_0_0; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_0_1; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_1_0; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_1_1; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_2_0; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.UUID_1; +import static org.apache.kafka.streams.processor.internals.assignment.StreamsAssignmentProtocolVersions.LATEST_SUPPORTED_VERSION; +import static org.apache.kafka.streams.processor.internals.assignment.StreamsAssignmentProtocolVersions.MIN_NAMED_TOPOLOGY_VERSION; +import static org.apache.kafka.streams.processor.internals.assignment.SubscriptionInfo.MIN_VERSION_OFFSET_SUM_SUBSCRIPTION; +import static org.apache.kafka.streams.processor.internals.assignment.SubscriptionInfo.UNKNOWN_OFFSET_SUM; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.fail; + +public class SubscriptionInfoTest { + private static final Set ACTIVE_TASKS = new HashSet<>(Arrays.asList( + TASK_0_0, + TASK_0_1, + TASK_1_0)); + private static final Set STANDBY_TASKS = new HashSet<>(Arrays.asList( + TASK_1_1, + TASK_2_0)); + private static final Map TASK_OFFSET_SUMS = mkMap( + mkEntry(TASK_0_0, Task.LATEST_OFFSET), + mkEntry(TASK_0_1, Task.LATEST_OFFSET), + mkEntry(TASK_1_0, Task.LATEST_OFFSET), + mkEntry(TASK_1_1, 0L), + mkEntry(TASK_2_0, 10L) + ); + + private static final Map NAMED_TASK_OFFSET_SUMS = mkMap( + mkEntry(NAMED_TASK_T0_0_0, Task.LATEST_OFFSET), + mkEntry(NAMED_TASK_T0_0_1, Task.LATEST_OFFSET), + mkEntry(NAMED_TASK_T0_1_0, 5L), + mkEntry(NAMED_TASK_T0_1_1, 10_000L), + mkEntry(NAMED_TASK_T1_0_0, Task.LATEST_OFFSET), + 
mkEntry(NAMED_TASK_T1_0_1, 0L), + mkEntry(NAMED_TASK_T2_0_0, 10L), + mkEntry(NAMED_TASK_T2_2_0, 5L) + ); + + private final static String IGNORED_USER_ENDPOINT = "ignoredUserEndpoint:80"; + private static final byte IGNORED_UNIQUE_FIELD = (byte) 0; + private static final int IGNORED_ERROR_CODE = 0; + + + @Test + public void shouldThrowForUnknownVersion1() { + assertThrows(IllegalArgumentException.class, () -> new SubscriptionInfo( + 0, + LATEST_SUPPORTED_VERSION, + UUID_1, + "localhost:80", + TASK_OFFSET_SUMS, + IGNORED_UNIQUE_FIELD, + IGNORED_ERROR_CODE + )); + } + + @Test + public void shouldThrowForUnknownVersion2() { + assertThrows(IllegalArgumentException.class, () -> new SubscriptionInfo( + LATEST_SUPPORTED_VERSION + 1, + LATEST_SUPPORTED_VERSION, + UUID_1, + "localhost:80", + TASK_OFFSET_SUMS, + IGNORED_UNIQUE_FIELD, + IGNORED_ERROR_CODE + )); + } + + @Test + public void shouldEncodeAndDecodeVersion1() { + final SubscriptionInfo info = new SubscriptionInfo( + 1, + LATEST_SUPPORTED_VERSION, + UUID_1, + IGNORED_USER_ENDPOINT, + TASK_OFFSET_SUMS, + IGNORED_UNIQUE_FIELD, + IGNORED_ERROR_CODE + ); + final SubscriptionInfo decoded = SubscriptionInfo.decode(info.encode()); + assertEquals(1, decoded.version()); + assertEquals(SubscriptionInfo.UNKNOWN, decoded.latestSupportedVersion()); + assertEquals(UUID_1, decoded.processId()); + assertEquals(ACTIVE_TASKS, decoded.prevTasks()); + assertEquals(STANDBY_TASKS, decoded.standbyTasks()); + assertNull(decoded.userEndPoint()); + } + + @Test + public void generatedVersion1ShouldBeDecodableByLegacyLogic() { + final SubscriptionInfo info = new SubscriptionInfo( + 1, + 1234, + UUID_1, + "ignoreme", + TASK_OFFSET_SUMS, + IGNORED_UNIQUE_FIELD, + IGNORED_ERROR_CODE + ); + final ByteBuffer buffer = info.encode(); + + final LegacySubscriptionInfoSerde decoded = LegacySubscriptionInfoSerde.decode(buffer); + assertEquals(1, decoded.version()); + assertEquals(SubscriptionInfo.UNKNOWN, decoded.latestSupportedVersion()); + assertEquals(UUID_1, decoded.processId()); + assertEquals(ACTIVE_TASKS, decoded.prevTasks()); + assertEquals(STANDBY_TASKS, decoded.standbyTasks()); + assertNull(decoded.userEndPoint()); + } + + @Test + public void generatedVersion1ShouldDecodeLegacyFormat() { + final LegacySubscriptionInfoSerde info = new LegacySubscriptionInfoSerde( + 1, + LATEST_SUPPORTED_VERSION, + UUID_1, + ACTIVE_TASKS, + STANDBY_TASKS, + "localhost:80" + ); + final ByteBuffer buffer = info.encode(); + buffer.rewind(); + final SubscriptionInfo decoded = SubscriptionInfo.decode(buffer); + assertEquals(1, decoded.version()); + assertEquals(SubscriptionInfo.UNKNOWN, decoded.latestSupportedVersion()); + assertEquals(UUID_1, decoded.processId()); + assertEquals(ACTIVE_TASKS, decoded.prevTasks()); + assertEquals(STANDBY_TASKS, decoded.standbyTasks()); + assertNull(decoded.userEndPoint()); + } + + @Test + public void shouldEncodeAndDecodeVersion2() { + final SubscriptionInfo info = new SubscriptionInfo( + 2, + LATEST_SUPPORTED_VERSION, + UUID_1, + "localhost:80", + TASK_OFFSET_SUMS, + IGNORED_UNIQUE_FIELD, + IGNORED_ERROR_CODE + ); + final SubscriptionInfo decoded = SubscriptionInfo.decode(info.encode()); + assertEquals(2, decoded.version()); + assertEquals(SubscriptionInfo.UNKNOWN, decoded.latestSupportedVersion()); + assertEquals(UUID_1, decoded.processId()); + assertEquals(ACTIVE_TASKS, decoded.prevTasks()); + assertEquals(STANDBY_TASKS, decoded.standbyTasks()); + assertEquals("localhost:80", decoded.userEndPoint()); + } + + @Test + public void 
generatedVersion2ShouldBeDecodableByLegacyLogic() { + final SubscriptionInfo info = new SubscriptionInfo( + 2, + LATEST_SUPPORTED_VERSION, + UUID_1, + "localhost:80", + TASK_OFFSET_SUMS, + IGNORED_UNIQUE_FIELD, + IGNORED_ERROR_CODE + ); + final ByteBuffer buffer = info.encode(); + + final LegacySubscriptionInfoSerde decoded = LegacySubscriptionInfoSerde.decode(buffer); + assertEquals(2, decoded.version()); + assertEquals(SubscriptionInfo.UNKNOWN, decoded.latestSupportedVersion()); + assertEquals(UUID_1, decoded.processId()); + assertEquals(ACTIVE_TASKS, decoded.prevTasks()); + assertEquals(STANDBY_TASKS, decoded.standbyTasks()); + assertEquals("localhost:80", decoded.userEndPoint()); + } + + @Test + public void generatedVersion2ShouldDecodeLegacyFormat() { + final LegacySubscriptionInfoSerde info = new LegacySubscriptionInfoSerde( + 2, + LATEST_SUPPORTED_VERSION, + UUID_1, + ACTIVE_TASKS, + STANDBY_TASKS, + "localhost:80" + ); + final ByteBuffer buffer = info.encode(); + buffer.rewind(); + final SubscriptionInfo decoded = SubscriptionInfo.decode(buffer); + assertEquals(2, decoded.version()); + assertEquals(SubscriptionInfo.UNKNOWN, decoded.latestSupportedVersion()); + assertEquals(UUID_1, decoded.processId()); + assertEquals(ACTIVE_TASKS, decoded.prevTasks()); + assertEquals(STANDBY_TASKS, decoded.standbyTasks()); + assertEquals("localhost:80", decoded.userEndPoint()); + } + + @Test + public void shouldEncodeAndDecodeVersion3And4() { + for (int version = 3; version <= 4; version++) { + final SubscriptionInfo info = new SubscriptionInfo( + version, + LATEST_SUPPORTED_VERSION, + UUID_1, + "localhost:80", + TASK_OFFSET_SUMS, + IGNORED_UNIQUE_FIELD, + IGNORED_ERROR_CODE + ); + final SubscriptionInfo decoded = SubscriptionInfo.decode(info.encode()); + assertEquals(version, decoded.version()); + assertEquals(LATEST_SUPPORTED_VERSION, decoded.latestSupportedVersion()); + assertEquals(UUID_1, decoded.processId()); + assertEquals(ACTIVE_TASKS, decoded.prevTasks()); + assertEquals(STANDBY_TASKS, decoded.standbyTasks()); + assertEquals("localhost:80", decoded.userEndPoint()); + } + } + + @Test + public void generatedVersion3And4ShouldBeDecodableByLegacyLogic() { + for (int version = 3; version <= 4; version++) { + final SubscriptionInfo info = new SubscriptionInfo( + version, + LATEST_SUPPORTED_VERSION, + UUID_1, + "localhost:80", + TASK_OFFSET_SUMS, + IGNORED_UNIQUE_FIELD, + IGNORED_ERROR_CODE + ); + final ByteBuffer buffer = info.encode(); + + final LegacySubscriptionInfoSerde decoded = LegacySubscriptionInfoSerde.decode(buffer); + assertEquals(version, decoded.version()); + assertEquals(LATEST_SUPPORTED_VERSION, decoded.latestSupportedVersion()); + assertEquals(UUID_1, decoded.processId()); + assertEquals(ACTIVE_TASKS, decoded.prevTasks()); + assertEquals(STANDBY_TASKS, decoded.standbyTasks()); + assertEquals("localhost:80", decoded.userEndPoint()); + } + } + + @Test + public void generatedVersion3To6ShouldDecodeLegacyFormat() { + for (int version = 3; version <= 6; version++) { + final LegacySubscriptionInfoSerde info = new LegacySubscriptionInfoSerde( + version, + LATEST_SUPPORTED_VERSION, + UUID_1, + ACTIVE_TASKS, + STANDBY_TASKS, + "localhost:80" + ); + final ByteBuffer buffer = info.encode(); + buffer.rewind(); + final SubscriptionInfo decoded = SubscriptionInfo.decode(buffer); + final String message = "for version: " + version; + assertEquals(message, version, decoded.version()); + assertEquals(message, LATEST_SUPPORTED_VERSION, decoded.latestSupportedVersion()); + assertEquals(message, 
UUID_1, decoded.processId()); + assertEquals(message, ACTIVE_TASKS, decoded.prevTasks()); + assertEquals(message, STANDBY_TASKS, decoded.standbyTasks()); + assertEquals(message, "localhost:80", decoded.userEndPoint()); + } + } + + @Test + public void shouldEncodeAndDecodeVersion5() { + final SubscriptionInfo info = + new SubscriptionInfo(5, LATEST_SUPPORTED_VERSION, UUID_1, "localhost:80", TASK_OFFSET_SUMS, IGNORED_UNIQUE_FIELD, IGNORED_ERROR_CODE); + assertEquals(info, SubscriptionInfo.decode(info.encode())); + } + + @Test + public void shouldAllowToDecodeFutureSupportedVersion() { + final SubscriptionInfo info = SubscriptionInfo.decode(encodeFutureVersion()); + assertEquals(LATEST_SUPPORTED_VERSION + 1, info.version()); + assertEquals(LATEST_SUPPORTED_VERSION + 1, info.latestSupportedVersion()); + } + + @Test + public void shouldEncodeAndDecodeSmallerLatestSupportedVersion() { + final int usedVersion = LATEST_SUPPORTED_VERSION - 1; + final int latestSupportedVersion = LATEST_SUPPORTED_VERSION - 1; + + final SubscriptionInfo info = + new SubscriptionInfo(usedVersion, latestSupportedVersion, UUID_1, "localhost:80", TASK_OFFSET_SUMS, IGNORED_UNIQUE_FIELD, IGNORED_ERROR_CODE); + final SubscriptionInfo expectedInfo = + new SubscriptionInfo(usedVersion, latestSupportedVersion, UUID_1, "localhost:80", TASK_OFFSET_SUMS, IGNORED_UNIQUE_FIELD, IGNORED_ERROR_CODE); + assertEquals(expectedInfo, SubscriptionInfo.decode(info.encode())); + } + + @Test + public void shouldEncodeAndDecodeVersion7() { + final SubscriptionInfo info = + new SubscriptionInfo(7, LATEST_SUPPORTED_VERSION, UUID_1, "localhost:80", TASK_OFFSET_SUMS, IGNORED_UNIQUE_FIELD, IGNORED_ERROR_CODE); + assertThat(info, is(SubscriptionInfo.decode(info.encode()))); + } + + @Test + public void shouldConvertTaskOffsetSumMapToTaskSets() { + final SubscriptionInfo info = + new SubscriptionInfo(7, LATEST_SUPPORTED_VERSION, UUID_1, "localhost:80", TASK_OFFSET_SUMS, IGNORED_UNIQUE_FIELD, IGNORED_ERROR_CODE); + assertThat(info.prevTasks(), is(ACTIVE_TASKS)); + assertThat(info.standbyTasks(), is(STANDBY_TASKS)); + } + + @Test + public void shouldReturnTaskOffsetSumsMapForDecodedSubscription() { + final SubscriptionInfo info = SubscriptionInfo.decode( + new SubscriptionInfo(MIN_VERSION_OFFSET_SUM_SUBSCRIPTION, + LATEST_SUPPORTED_VERSION, UUID_1, + "localhost:80", + TASK_OFFSET_SUMS, + IGNORED_UNIQUE_FIELD, + IGNORED_ERROR_CODE + ).encode()); + assertThat(info.taskOffsetSums(), is(TASK_OFFSET_SUMS)); + } + + @Test + public void shouldConvertTaskSetsToTaskOffsetSumMapWithOlderSubscription() { + final Map expectedOffsetSumsMap = mkMap( + mkEntry(new TaskId(0, 0), Task.LATEST_OFFSET), + mkEntry(new TaskId(0, 1), Task.LATEST_OFFSET), + mkEntry(new TaskId(1, 0), Task.LATEST_OFFSET), + mkEntry(new TaskId(1, 1), UNKNOWN_OFFSET_SUM), + mkEntry(new TaskId(2, 0), UNKNOWN_OFFSET_SUM) + ); + + final SubscriptionInfo info = SubscriptionInfo.decode( + new LegacySubscriptionInfoSerde( + SubscriptionInfo.MIN_VERSION_OFFSET_SUM_SUBSCRIPTION - 1, + LATEST_SUPPORTED_VERSION, + UUID_1, + ACTIVE_TASKS, + STANDBY_TASKS, + "localhost:80") + .encode()); + + assertThat(info.taskOffsetSums(), is(expectedOffsetSumsMap)); + } + + @Test + public void shouldEncodeAndDecodeVersion8() { + final SubscriptionInfo info = + new SubscriptionInfo(8, LATEST_SUPPORTED_VERSION, UUID_1, "localhost:80", TASK_OFFSET_SUMS, IGNORED_UNIQUE_FIELD, IGNORED_ERROR_CODE); + assertThat(info, is(SubscriptionInfo.decode(info.encode()))); + } + + @Test + public void shouldNotErrorAccessingFutureVars() { + 
final SubscriptionInfo info = + new SubscriptionInfo(8, LATEST_SUPPORTED_VERSION, UUID_1, "localhost:80", TASK_OFFSET_SUMS, IGNORED_UNIQUE_FIELD, IGNORED_ERROR_CODE); + try { + info.errorCode(); + } catch (final Exception e) { + fail("should not error"); + } + } + + @Test + public void shouldEncodeAndDecodeVersion9() { + final SubscriptionInfo info = + new SubscriptionInfo(9, LATEST_SUPPORTED_VERSION, UUID_1, "localhost:80", TASK_OFFSET_SUMS, IGNORED_UNIQUE_FIELD, IGNORED_ERROR_CODE); + assertThat(info, is(SubscriptionInfo.decode(info.encode()))); + } + + @Test + public void shouldEncodeAndDecodeVersion10() { + final SubscriptionInfo info = + new SubscriptionInfo(10, LATEST_SUPPORTED_VERSION, UUID_1, "localhost:80", TASK_OFFSET_SUMS, IGNORED_UNIQUE_FIELD, IGNORED_ERROR_CODE); + assertThat(info, is(SubscriptionInfo.decode(info.encode()))); + } + + @Test + public void shouldEncodeAndDecodeVersion10WithNamedTopologies() { + final SubscriptionInfo info = + new SubscriptionInfo(10, LATEST_SUPPORTED_VERSION, UUID_1, "localhost:80", NAMED_TASK_OFFSET_SUMS, IGNORED_UNIQUE_FIELD, IGNORED_ERROR_CODE); + assertThat(info, is(SubscriptionInfo.decode(info.encode()))); + } + + @Test + public void shouldThrowIfAttemptingToUseNamedTopologiesWithOlderVersion() { + assertThrows( + TaskAssignmentException.class, + () -> new SubscriptionInfo(MIN_NAMED_TOPOLOGY_VERSION - 1, LATEST_SUPPORTED_VERSION, UUID_1, "localhost:80", NAMED_TASK_OFFSET_SUMS, IGNORED_UNIQUE_FIELD, IGNORED_ERROR_CODE) + ); + } + + private static ByteBuffer encodeFutureVersion() { + final ByteBuffer buf = ByteBuffer.allocate(4 /* used version */ + + 4 /* supported version */); + buf.putInt(LATEST_SUPPORTED_VERSION + 1); + buf.putInt(LATEST_SUPPORTED_VERSION + 1); + buf.rewind(); + return buf; + } + +} diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/TaskAssignorConvergenceTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/TaskAssignorConvergenceTest.java new file mode 100644 index 0000000..68c9dfe --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/TaskAssignorConvergenceTest.java @@ -0,0 +1,426 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.processor.internals.assignment; + +import org.apache.kafka.streams.processor.TaskId; +import org.apache.kafka.streams.processor.internals.assignment.AssignorConfiguration.AssignmentConfigs; +import org.junit.Test; + +import java.util.Map; +import java.util.NoSuchElementException; +import java.util.Random; +import java.util.Set; +import java.util.TreeMap; +import java.util.TreeSet; +import java.util.UUID; +import java.util.function.Supplier; + +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.appendClientStates; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.assertBalancedActiveAssignment; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.assertBalancedStatefulAssignment; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.assertValidAssignment; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.uuidForInt; +import static org.junit.Assert.fail; + +public class TaskAssignorConvergenceTest { + private static final class Harness { + private final Set<TaskId> statelessTasks; + private final Map<TaskId, Long> statefulTaskEndOffsetSums; + private final Map<UUID, ClientState> clientStates; + private final Map<UUID, ClientState> droppedClientStates; + private final StringBuilder history = new StringBuilder(); + + private static Harness initializeCluster(final int numStatelessTasks, + final int numStatefulTasks, + final int numNodes, + final Supplier<Integer> partitionCountSupplier) { + int subtopology = 0; + final Set<TaskId> statelessTasks = new TreeSet<>(); + int remainingStatelessTasks = numStatelessTasks; + while (remainingStatelessTasks > 0) { + final int partitions = Math.min(remainingStatelessTasks, partitionCountSupplier.get()); + for (int i = 0; i < partitions; i++) { + statelessTasks.add(new TaskId(subtopology, i)); + remainingStatelessTasks--; + } + subtopology++; + } + + final Map<TaskId, Long> statefulTaskEndOffsetSums = new TreeMap<>(); + int remainingStatefulTasks = numStatefulTasks; + while (remainingStatefulTasks > 0) { + final int partitions = Math.min(remainingStatefulTasks, partitionCountSupplier.get()); + for (int i = 0; i < partitions; i++) { + statefulTaskEndOffsetSums.put(new TaskId(subtopology, i), 150000L); + remainingStatefulTasks--; + } + subtopology++; + } + + final Map<UUID, ClientState> clientStates = new TreeMap<>(); + for (int i = 0; i < numNodes; i++) { + final UUID uuid = uuidForInt(i); + clientStates.put(uuid, emptyInstance(uuid, statefulTaskEndOffsetSums)); + } + + return new Harness(statelessTasks, statefulTaskEndOffsetSums, clientStates); + } + + private Harness(final Set<TaskId> statelessTasks, + final Map<TaskId, Long> statefulTaskEndOffsetSums, + final Map<UUID, ClientState> clientStates) { + this.statelessTasks = statelessTasks; + this.statefulTaskEndOffsetSums = statefulTaskEndOffsetSums; + this.clientStates = clientStates; + droppedClientStates = new TreeMap<>(); + history.append('\n'); + history.append("Cluster and application initial state: \n"); + history.append("Stateless tasks: ").append(statelessTasks).append('\n'); + history.append("Stateful tasks: ").append(statefulTaskEndOffsetSums.keySet()).append('\n'); + formatClientStates(true); + history.append("History of the cluster: \n"); + } + + private void addNode() { + final UUID uuid = uuidForInt(clientStates.size() + droppedClientStates.size()); + history.append("Adding new node ").append(uuid).append('\n'); + clientStates.put(uuid, emptyInstance(uuid, statefulTaskEndOffsetSums)); + } + + private static 
ClientState emptyInstance(final UUID uuid, final Map allTaskEndOffsetSums) { + final ClientState clientState = new ClientState(1); + clientState.computeTaskLags(uuid, allTaskEndOffsetSums); + return clientState; + } + + private void addOrResurrectNodesRandomly(final Random prng, final int limit) { + final int numberToAdd = prng.nextInt(limit); + for (int i = 0; i < numberToAdd; i++) { + final boolean addNew = prng.nextBoolean(); + if (addNew || droppedClientStates.isEmpty()) { + addNode(); + } else { + final UUID uuid = selectRandomElement(prng, droppedClientStates); + history.append("Resurrecting node ").append(uuid).append('\n'); + clientStates.put(uuid, droppedClientStates.get(uuid)); + droppedClientStates.remove(uuid); + } + } + } + + private void dropNode() { + if (clientStates.isEmpty()) { + throw new NoSuchElementException("There are no nodes to drop"); + } else { + final UUID toDrop = clientStates.keySet().iterator().next(); + dropNode(toDrop); + } + } + + private void dropRandomNodes(final int numNode, final Random prng) { + int dropped = 0; + while (!clientStates.isEmpty() && dropped < numNode) { + final UUID toDrop = selectRandomElement(prng, clientStates); + dropNode(toDrop); + dropped++; + } + history.append("Stateless tasks: ").append(statelessTasks).append('\n'); + history.append("Stateful tasks: ").append(statefulTaskEndOffsetSums.keySet()).append('\n'); + formatClientStates(true); + } + + private void dropNode(final UUID toDrop) { + final ClientState clientState = clientStates.remove(toDrop); + history.append("Dropping node ").append(toDrop).append(": ").append(clientState).append('\n'); + droppedClientStates.put(toDrop, clientState); + } + + private static UUID selectRandomElement(final Random prng, final Map clients) { + int dropIndex = prng.nextInt(clients.size()); + UUID toDrop = null; + for (final UUID uuid : clients.keySet()) { + if (dropIndex == 0) { + toDrop = uuid; + break; + } else { + dropIndex--; + } + } + return toDrop; + } + + /** + * Flip the cluster states from "assigned" to "subscribed" so they can be used for another round of assignments. 
+ */ + private void prepareForNextRebalance() { + final Map newClientStates = new TreeMap<>(); + for (final Map.Entry entry : clientStates.entrySet()) { + final UUID uuid = entry.getKey(); + final ClientState newClientState = new ClientState(1); + final ClientState clientState = entry.getValue(); + final Map taskOffsetSums = new TreeMap<>(); + for (final TaskId taskId : clientState.activeTasks()) { + if (statefulTaskEndOffsetSums.containsKey(taskId)) { + taskOffsetSums.put(taskId, statefulTaskEndOffsetSums.get(taskId)); + } + } + for (final TaskId taskId : clientState.standbyTasks()) { + if (statefulTaskEndOffsetSums.containsKey(taskId)) { + taskOffsetSums.put(taskId, statefulTaskEndOffsetSums.get(taskId)); + } + } + newClientState.addPreviousActiveTasks(clientState.activeTasks()); + newClientState.addPreviousStandbyTasks(clientState.standbyTasks()); + newClientState.addPreviousTasksAndOffsetSums("consumer", taskOffsetSums); + newClientState.computeTaskLags(uuid, statefulTaskEndOffsetSums); + newClientStates.put(uuid, newClientState); + } + + clientStates.clear(); + clientStates.putAll(newClientStates); + } + + private void recordConfig(final AssignmentConfigs configuration) { + history.append("Creating assignor with configuration: ") + .append(configuration) + .append('\n'); + } + + private void recordBefore(final int iteration) { + history.append("Starting Iteration: ").append(iteration).append('\n'); + formatClientStates(false); + } + + private void recordAfter(final int iteration, final boolean rebalancePending) { + history.append("After assignment: ").append(iteration).append('\n'); + history.append("Rebalance pending: ").append(rebalancePending).append('\n'); + formatClientStates(true); + history.append('\n'); + } + + private void formatClientStates(final boolean printUnassigned) { + appendClientStates(history, clientStates); + if (printUnassigned) { + final Set unassignedTasks = new TreeSet<>(); + unassignedTasks.addAll(statefulTaskEndOffsetSums.keySet()); + unassignedTasks.addAll(statelessTasks); + for (final Map.Entry entry : clientStates.entrySet()) { + unassignedTasks.removeAll(entry.getValue().assignedTasks()); + } + history.append("Unassigned Tasks: ").append(unassignedTasks).append('\n'); + } + } + } + + @Test + public void staticAssignmentShouldConvergeWithTheFirstAssignment() { + final AssignmentConfigs configs = new AssignmentConfigs(100L, + 2, + 0, + 60_000L); + + final Harness harness = Harness.initializeCluster(1, 1, 1, () -> 1); + + testForConvergence(harness, configs, 1); + verifyValidAssignment(0, harness); + verifyBalancedAssignment(harness); + } + + @Test + public void assignmentShouldConvergeAfterAddingNode() { + final int numStatelessTasks = 7; + final int numStatefulTasks = 11; + final int maxWarmupReplicas = 2; + final int numStandbyReplicas = 0; + + final AssignmentConfigs configs = new AssignmentConfigs(100L, + maxWarmupReplicas, + numStandbyReplicas, + 60_000L); + + final Harness harness = Harness.initializeCluster(numStatelessTasks, numStatefulTasks, 1, () -> 5); + testForConvergence(harness, configs, 1); + harness.addNode(); + // we expect convergence to involve moving each task at most once, and we can move "maxWarmupReplicas" number + // of tasks at once, hence the iteration limit + testForConvergence(harness, configs, numStatefulTasks / maxWarmupReplicas + 1); + verifyValidAssignment(numStandbyReplicas, harness); + verifyBalancedAssignment(harness); + } + + @Test + public void droppingNodesShouldConverge() { + final int numStatelessTasks = 11; + final 
int numStatefulTasks = 13; + final int maxWarmupReplicas = 2; + final int numStandbyReplicas = 0; + + final AssignmentConfigs configs = new AssignmentConfigs(100L, + maxWarmupReplicas, + numStandbyReplicas, + 60_000L); + + final Harness harness = Harness.initializeCluster(numStatelessTasks, numStatefulTasks, 7, () -> 5); + testForConvergence(harness, configs, 1); + harness.dropNode(); + // This time, we allow one extra iteration because the + // first stateful task needs to get shuffled back to the first node + testForConvergence(harness, configs, numStatefulTasks / maxWarmupReplicas + 2); + + verifyValidAssignment(numStandbyReplicas, harness); + verifyBalancedAssignment(harness); + } + + @Test + public void randomClusterPerturbationsShouldConverge() { + // do as many tests as we can in 10 seconds + final long deadline = System.currentTimeMillis() + 10_000L; + do { + final long seed = new Random().nextLong(); + runRandomizedScenario(seed); + } while (System.currentTimeMillis() < deadline); + } + + private static void runRandomizedScenario(final long seed) { + Harness harness = null; + try { + final Random prng = new Random(seed); + + // These are all rand(limit)+1 because we need them to be at least 1 and the upper bound is exclusive + final int initialClusterSize = prng.nextInt(10) + 1; + final int numStatelessTasks = prng.nextInt(10) + 1; + final int numStatefulTasks = prng.nextInt(10) + 1; + final int maxWarmupReplicas = prng.nextInt(numStatefulTasks) + 1; + // This one is rand(limit+1) because we _want_ to test zero and the upper bound is exclusive + final int numStandbyReplicas = prng.nextInt(initialClusterSize + 1); + + final int numberOfEvents = prng.nextInt(10) + 1; + + final AssignmentConfigs configs = new AssignmentConfigs(100L, + maxWarmupReplicas, + numStandbyReplicas, + 60_000L); + + harness = Harness.initializeCluster( + numStatelessTasks, + numStatefulTasks, + initialClusterSize, + () -> prng.nextInt(10) + 1 + ); + testForConvergence(harness, configs, 1); + verifyValidAssignment(numStandbyReplicas, harness); + verifyBalancedAssignment(harness); + + for (int i = 0; i < numberOfEvents; i++) { + final int event = prng.nextInt(2); + switch (event) { + case 0: + harness.dropRandomNodes(prng.nextInt(initialClusterSize), prng); + break; + case 1: + harness.addOrResurrectNodesRandomly(prng, initialClusterSize); + break; + default: + throw new IllegalStateException("Unexpected event: " + event); + } + if (!harness.clientStates.isEmpty()) { + testForConvergence(harness, configs, 2 * (numStatefulTasks + numStatefulTasks * numStandbyReplicas)); + verifyValidAssignment(numStandbyReplicas, harness); + verifyBalancedAssignment(harness); + } + } + } catch (final AssertionError t) { + throw new AssertionError( + "Assertion failed in randomized test. Reproduce with: `runRandomizedScenario(" + seed + ")`.", + t + ); + } catch (final Throwable t) { + final StringBuilder builder = + new StringBuilder() + .append("Exception in randomized scenario. Reproduce with: `runRandomizedScenario(") + .append(seed) + .append(")`. 
"); + if (harness != null) { + builder.append(harness.history); + } + throw new AssertionError(builder.toString(), t); + } + } + + private static void verifyBalancedAssignment(final Harness harness) { + final Set allStatefulTasks = harness.statefulTaskEndOffsetSums.keySet(); + final Map clientStates = harness.clientStates; + final StringBuilder failureContext = harness.history; + + assertBalancedActiveAssignment(clientStates, failureContext); + assertBalancedStatefulAssignment(allStatefulTasks, clientStates, failureContext); + final AssignmentTestUtils.TaskSkewReport taskSkewReport = AssignmentTestUtils.analyzeTaskAssignmentBalance(harness.clientStates); + if (taskSkewReport.totalSkewedTasks() > 0) { + fail( + new StringBuilder().append("Expected a balanced task assignment, but was: ") + .append(taskSkewReport) + .append('\n') + .append(failureContext) + .toString() + ); + } + } + + private static void verifyValidAssignment(final int numStandbyReplicas, final Harness harness) { + final Set statefulTasks = harness.statefulTaskEndOffsetSums.keySet(); + final Set statelessTasks = harness.statelessTasks; + final Map assignedStates = harness.clientStates; + final StringBuilder failureContext = harness.history; + + assertValidAssignment(numStandbyReplicas, statefulTasks, statelessTasks, assignedStates, failureContext); + } + + private static void testForConvergence(final Harness harness, + final AssignmentConfigs configs, + final int iterationLimit) { + final Set allTasks = new TreeSet<>(); + allTasks.addAll(harness.statelessTasks); + allTasks.addAll(harness.statefulTaskEndOffsetSums.keySet()); + + harness.recordConfig(configs); + + boolean rebalancePending = true; + int iteration = 0; + while (rebalancePending && iteration < iterationLimit) { + iteration++; + harness.prepareForNextRebalance(); + harness.recordBefore(iteration); + rebalancePending = new HighAvailabilityTaskAssignor().assign( + harness.clientStates, + allTasks, + harness.statefulTaskEndOffsetSums.keySet(), + configs + ); + harness.recordAfter(iteration, rebalancePending); + } + + if (rebalancePending) { + final StringBuilder message = + new StringBuilder().append("Rebalances have not converged after iteration cutoff: ") + .append(iterationLimit) + .append(harness.history); + fail(message.toString()); + } + } + + +} diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/TaskMovementTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/TaskMovementTest.java new file mode 100644 index 0000000..9b58d18 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/assignment/TaskMovementTest.java @@ -0,0 +1,224 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.processor.internals.assignment; + +import org.apache.kafka.streams.processor.TaskId; +import org.junit.Test; + +import java.util.Collection; +import java.util.HashMap; +import java.util.Map; +import java.util.Set; +import java.util.SortedSet; +import java.util.TreeMap; +import java.util.UUID; +import java.util.concurrent.atomic.AtomicInteger; + +import static java.util.Arrays.asList; +import static java.util.Collections.emptyList; +import static java.util.Collections.emptySortedSet; +import static java.util.Collections.singletonList; +import static org.apache.kafka.common.utils.Utils.mkEntry; +import static org.apache.kafka.common.utils.Utils.mkMap; +import static org.apache.kafka.common.utils.Utils.mkSet; +import static org.apache.kafka.common.utils.Utils.mkSortedSet; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_0_0; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_0_1; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_0_2; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_1_0; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_1_1; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.TASK_1_2; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.UUID_1; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.UUID_2; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.UUID_3; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.getClientStatesMap; +import static org.apache.kafka.streams.processor.internals.assignment.AssignmentTestUtils.hasProperty; +import static org.apache.kafka.streams.processor.internals.assignment.TaskMovement.assignActiveTaskMovements; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.is; + +public class TaskMovementTest { + @Test + public void shouldAssignTasksToClientsAndReturnFalseWhenAllClientsCaughtUp() { + final int maxWarmupReplicas = Integer.MAX_VALUE; + final Set allTasks = mkSet(TASK_0_0, TASK_0_1, TASK_0_2, TASK_1_0, TASK_1_1, TASK_1_2); + + final Map> tasksToCaughtUpClients = new HashMap<>(); + for (final TaskId task : allTasks) { + tasksToCaughtUpClients.put(task, mkSortedSet(UUID_1, UUID_2, UUID_3)); + } + + final ClientState client1 = getClientStateWithActiveAssignment(asList(TASK_0_0, TASK_1_0)); + final ClientState client2 = getClientStateWithActiveAssignment(asList(TASK_0_1, TASK_1_1)); + final ClientState client3 = getClientStateWithActiveAssignment(asList(TASK_0_2, TASK_1_2)); + + assertThat( + assignActiveTaskMovements( + tasksToCaughtUpClients, + getClientStatesMap(client1, client2, client3), + new TreeMap<>(), + new AtomicInteger(maxWarmupReplicas) + ), + is(0) + ); + } + + @Test + public void shouldAssignAllTasksToClientsAndReturnFalseIfNoClientsAreCaughtUp() { + final int maxWarmupReplicas = Integer.MAX_VALUE; + + final ClientState client1 = getClientStateWithActiveAssignment(asList(TASK_0_0, TASK_1_0)); + final ClientState client2 = getClientStateWithActiveAssignment(asList(TASK_0_1, TASK_1_1)); + final ClientState client3 = getClientStateWithActiveAssignment(asList(TASK_0_2, TASK_1_2)); + + final Map> tasksToCaughtUpClients = mkMap( + 
mkEntry(TASK_0_0, emptySortedSet()), + mkEntry(TASK_0_1, emptySortedSet()), + mkEntry(TASK_0_2, emptySortedSet()), + mkEntry(TASK_1_0, emptySortedSet()), + mkEntry(TASK_1_1, emptySortedSet()), + mkEntry(TASK_1_2, emptySortedSet()) + ); + assertThat( + assignActiveTaskMovements( + tasksToCaughtUpClients, + getClientStatesMap(client1, client2, client3), + new TreeMap<>(), + new AtomicInteger(maxWarmupReplicas) + ), + is(0) + ); + } + + @Test + public void shouldMoveTasksToCaughtUpClientsAndAssignWarmupReplicasInTheirPlace() { + final int maxWarmupReplicas = Integer.MAX_VALUE; + final ClientState client1 = getClientStateWithActiveAssignment(singletonList(TASK_0_0)); + final ClientState client2 = getClientStateWithActiveAssignment(singletonList(TASK_0_1)); + final ClientState client3 = getClientStateWithActiveAssignment(singletonList(TASK_0_2)); + final Map<UUID, ClientState> clientStates = getClientStatesMap(client1, client2, client3); + + final Map<TaskId, SortedSet<UUID>> tasksToCaughtUpClients = mkMap( + mkEntry(TASK_0_0, mkSortedSet(UUID_1)), + mkEntry(TASK_0_1, mkSortedSet(UUID_3)), + mkEntry(TASK_0_2, mkSortedSet(UUID_2)) + ); + + assertThat( + "should have assigned movements", + assignActiveTaskMovements( + tasksToCaughtUpClients, + clientStates, + new TreeMap<>(), + new AtomicInteger(maxWarmupReplicas) + ), + is(2) + ); + // The active tasks have changed to the ones that each client is caught up on + assertThat(client1, hasProperty("activeTasks", ClientState::activeTasks, mkSet(TASK_0_0))); + assertThat(client2, hasProperty("activeTasks", ClientState::activeTasks, mkSet(TASK_0_2))); + assertThat(client3, hasProperty("activeTasks", ClientState::activeTasks, mkSet(TASK_0_1))); + + // we assigned warmups to migrate to the input active assignment + assertThat(client1, hasProperty("standbyTasks", ClientState::standbyTasks, mkSet())); + assertThat(client2, hasProperty("standbyTasks", ClientState::standbyTasks, mkSet(TASK_0_1))); + assertThat(client3, hasProperty("standbyTasks", ClientState::standbyTasks, mkSet(TASK_0_2))); + } + + @Test + public void shouldOnlyGetUpToMaxWarmupReplicasAndReturnTrue() { + final int maxWarmupReplicas = 1; + final ClientState client1 = getClientStateWithActiveAssignment(singletonList(TASK_0_0)); + final ClientState client2 = getClientStateWithActiveAssignment(singletonList(TASK_0_1)); + final ClientState client3 = getClientStateWithActiveAssignment(singletonList(TASK_0_2)); + final Map<UUID, ClientState> clientStates = getClientStatesMap(client1, client2, client3); + + final Map<TaskId, SortedSet<UUID>> tasksToCaughtUpClients = mkMap( + mkEntry(TASK_0_0, mkSortedSet(UUID_1)), + mkEntry(TASK_0_1, mkSortedSet(UUID_3)), + mkEntry(TASK_0_2, mkSortedSet(UUID_2)) + ); + + assertThat( + "should have assigned movements", + assignActiveTaskMovements( + tasksToCaughtUpClients, + clientStates, + new TreeMap<>(), + new AtomicInteger(maxWarmupReplicas) + ), + is(2) + ); + // The active tasks have changed to the ones that each client is caught up on + assertThat(client1, hasProperty("activeTasks", ClientState::activeTasks, mkSet(TASK_0_0))); + assertThat(client2, hasProperty("activeTasks", ClientState::activeTasks, mkSet(TASK_0_2))); + assertThat(client3, hasProperty("activeTasks", ClientState::activeTasks, mkSet(TASK_0_1))); + + // we should only assign one warmup, but it could be either one that needs to be migrated. 
+        assertThat(client1, hasProperty("standbyTasks", ClientState::standbyTasks, mkSet()));
+        try {
+            assertThat(client2, hasProperty("standbyTasks", ClientState::standbyTasks, mkSet(TASK_0_1)));
+            assertThat(client3, hasProperty("standbyTasks", ClientState::standbyTasks, mkSet()));
+        } catch (final AssertionError ignored) {
+            assertThat(client2, hasProperty("standbyTasks", ClientState::standbyTasks, mkSet()));
+            assertThat(client3, hasProperty("standbyTasks", ClientState::standbyTasks, mkSet(TASK_0_2)));
+        }
+    }
+
+    @Test
+    public void shouldNotCountPreviousStandbyTasksTowardsMaxWarmupReplicas() {
+        final int maxWarmupReplicas = 0;
+        final ClientState client1 = getClientStateWithActiveAssignment(emptyList());
+        client1.assignStandby(TASK_0_0);
+        final ClientState client2 = getClientStateWithActiveAssignment(singletonList(TASK_0_0));
+        final Map<UUID, ClientState> clientStates = getClientStatesMap(client1, client2);
+
+        final Map<TaskId, SortedSet<UUID>> tasksToCaughtUpClients = mkMap(
+            mkEntry(TASK_0_0, mkSortedSet(UUID_1))
+        );
+
+        assertThat(
+            "should have assigned movements",
+            assignActiveTaskMovements(
+                tasksToCaughtUpClients,
+                clientStates,
+                new TreeMap<>(),
+                new AtomicInteger(maxWarmupReplicas)
+            ),
+            is(1)
+        );
+        // Even though we have no warmups allowed, we still let client1 take over active processing while
+        // client2 "warms up" because client1 was a caught-up standby, so it can "trade" standby status with
+        // the not-caught-up active client2.
+
+        // I.e., when you have a caught-up standby and a not-caught-up active, you can just swap their roles
+        // and not call it a "warmup".
+        assertThat(client1, hasProperty("activeTasks", ClientState::activeTasks, mkSet(TASK_0_0)));
+        assertThat(client2, hasProperty("activeTasks", ClientState::activeTasks, mkSet()));
+
+        assertThat(client1, hasProperty("standbyTasks", ClientState::standbyTasks, mkSet()));
+        assertThat(client2, hasProperty("standbyTasks", ClientState::standbyTasks, mkSet(TASK_0_0)));
+
+    }
+
+    private static ClientState getClientStateWithActiveAssignment(final Collection<TaskId> activeTasks) {
+        final ClientState client1 = new ClientState(1);
+        client1.assignActiveTasks(activeTasks);
+        return client1;
+    }
+
+}
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/metrics/ProcessorNodeMetricsTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/metrics/ProcessorNodeMetricsTest.java
new file mode 100644
index 0000000..0ae1a99
--- /dev/null
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/metrics/ProcessorNodeMetricsTest.java
@@ -0,0 +1,183 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +package org.apache.kafka.streams.processor.internals.metrics; + +import org.apache.kafka.common.metrics.Sensor; +import org.apache.kafka.common.metrics.Sensor.RecordingLevel; +import org.junit.Test; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + + +import java.util.Collections; +import java.util.Map; +import java.util.function.Supplier; + +import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.PROCESSOR_NODE_LEVEL_GROUP; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.MatcherAssert.assertThat; + +public class ProcessorNodeMetricsTest { + + private static final String THREAD_ID = "test-thread"; + private static final String TASK_ID = "test-task"; + private static final String PROCESSOR_NODE_ID = "test-processor"; + + private final Map tagMap = Collections.singletonMap("hello", "world"); + private final Map parentTagMap = Collections.singletonMap("hi", "universe"); + + private final Sensor expectedSensor = mock(Sensor.class); + private final StreamsMetricsImpl streamsMetrics = mock(StreamsMetricsImpl.class); + private final Sensor expectedParentSensor = mock(Sensor.class); + + @Test + public void shouldGetSuppressionEmitSensor() { + final String metricNamePrefix = "suppression-emit"; + final String descriptionOfCount = "The total number of emitted records from the suppression buffer"; + final String descriptionOfRate = "The average number of emitted records from the suppression buffer per second"; + when(streamsMetrics.nodeLevelSensor(THREAD_ID, TASK_ID, PROCESSOR_NODE_ID, metricNamePrefix, RecordingLevel.DEBUG)) + .thenReturn(expectedSensor); + when(streamsMetrics.nodeLevelTagMap(THREAD_ID, TASK_ID, PROCESSOR_NODE_ID)).thenReturn(tagMap); + StreamsMetricsImpl.addInvocationRateAndCountToSensor( + expectedSensor, + PROCESSOR_NODE_LEVEL_GROUP, + tagMap, + metricNamePrefix, + descriptionOfRate, + descriptionOfCount + ); + + verifySensor( + () -> ProcessorNodeMetrics.suppressionEmitSensor(THREAD_ID, TASK_ID, PROCESSOR_NODE_ID, streamsMetrics)); + } + + @Test + public void shouldGetIdempotentUpdateSkipSensor() { + final String metricNamePrefix = "idempotent-update-skip"; + final String descriptionOfCount = "The total number of skipped idempotent updates"; + final String descriptionOfRate = "The average number of skipped idempotent updates per second"; + when(streamsMetrics.nodeLevelSensor(THREAD_ID, TASK_ID, PROCESSOR_NODE_ID, metricNamePrefix, RecordingLevel.DEBUG)) + .thenReturn(expectedSensor); + when(streamsMetrics.nodeLevelTagMap(THREAD_ID, TASK_ID, PROCESSOR_NODE_ID)).thenReturn(tagMap); + StreamsMetricsImpl.addInvocationRateAndCountToSensor( + expectedSensor, + PROCESSOR_NODE_LEVEL_GROUP, + tagMap, + metricNamePrefix, + descriptionOfRate, + descriptionOfCount + ); + verifySensor( + () -> ProcessorNodeMetrics.skippedIdempotentUpdatesSensor(THREAD_ID, TASK_ID, PROCESSOR_NODE_ID, streamsMetrics) + ); + } + + @Test + public void shouldGetProcessAtSourceSensor() { + final String metricNamePrefix = "process"; + final String descriptionOfCount = "The total number of calls to process"; + final String descriptionOfRate = "The average number of calls to process per second"; + when(streamsMetrics.taskLevelSensor(THREAD_ID, TASK_ID, metricNamePrefix, RecordingLevel.DEBUG)) + .thenReturn(expectedParentSensor); + when(streamsMetrics.taskLevelTagMap(THREAD_ID, TASK_ID)) + .thenReturn(parentTagMap); + StreamsMetricsImpl.addInvocationRateAndCountToSensor( + expectedParentSensor, + 
StreamsMetricsImpl.TASK_LEVEL_GROUP, + parentTagMap, + metricNamePrefix, + descriptionOfRate, + descriptionOfCount + ); + setUpThroughputSensor( + metricNamePrefix, + descriptionOfRate, + descriptionOfCount, + RecordingLevel.DEBUG, + expectedParentSensor + ); + + verifySensor(() -> ProcessorNodeMetrics.processAtSourceSensor(THREAD_ID, TASK_ID, PROCESSOR_NODE_ID, streamsMetrics)); + } + + @Test + public void shouldGetForwardSensor() { + final String metricNamePrefix = "forward"; + final String descriptionOfCount = "The total number of calls to forward"; + final String descriptionOfRate = "The average number of calls to forward per second"; + setUpThroughputParentSensor( + metricNamePrefix, + descriptionOfRate, + descriptionOfCount + ); + setUpThroughputSensor( + metricNamePrefix, + descriptionOfRate, + descriptionOfCount, + RecordingLevel.DEBUG, + expectedParentSensor + ); + + verifySensor(() -> ProcessorNodeMetrics.forwardSensor(THREAD_ID, TASK_ID, PROCESSOR_NODE_ID, streamsMetrics)); + } + + private void setUpThroughputParentSensor(final String metricNamePrefix, + final String descriptionOfRate, + final String descriptionOfCount) { + when(streamsMetrics.taskLevelSensor(THREAD_ID, TASK_ID, metricNamePrefix, RecordingLevel.DEBUG)) + .thenReturn(expectedParentSensor); + when(streamsMetrics.nodeLevelTagMap(THREAD_ID, TASK_ID, StreamsMetricsImpl.ROLLUP_VALUE)) + .thenReturn(parentTagMap); + StreamsMetricsImpl.addInvocationRateAndCountToSensor( + expectedParentSensor, + PROCESSOR_NODE_LEVEL_GROUP, + parentTagMap, + metricNamePrefix, + descriptionOfRate, + descriptionOfCount + ); + } + + private void setUpThroughputSensor(final String metricNamePrefix, + final String descriptionOfRate, + final String descriptionOfCount, + final RecordingLevel recordingLevel, + final Sensor... parentSensors) { + when(streamsMetrics.nodeLevelSensor( + THREAD_ID, + TASK_ID, + PROCESSOR_NODE_ID, + metricNamePrefix, + recordingLevel, + parentSensors + )).thenReturn(expectedSensor); + when(streamsMetrics.nodeLevelTagMap(THREAD_ID, TASK_ID, PROCESSOR_NODE_ID)).thenReturn(tagMap); + StreamsMetricsImpl.addInvocationRateAndCountToSensor( + expectedSensor, + PROCESSOR_NODE_LEVEL_GROUP, + tagMap, + metricNamePrefix, + descriptionOfRate, + descriptionOfCount + ); + } + + private void verifySensor(final Supplier sensorSupplier) { + final Sensor sensor = sensorSupplier.get(); + assertThat(sensor, is(expectedSensor)); + } +} \ No newline at end of file diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/metrics/StreamsMetricsImplTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/metrics/StreamsMetricsImplTest.java new file mode 100644 index 0000000..24cf8c7 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/metrics/StreamsMetricsImplTest.java @@ -0,0 +1,1302 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.processor.internals.metrics; + +import org.apache.kafka.common.MetricName; +import org.apache.kafka.common.metrics.Gauge; +import org.apache.kafka.common.metrics.KafkaMetric; +import org.apache.kafka.common.metrics.MetricConfig; +import org.apache.kafka.common.metrics.Metrics; +import org.apache.kafka.common.metrics.Sensor; +import org.apache.kafka.common.metrics.Sensor.RecordingLevel; +import org.apache.kafka.common.metrics.stats.Rate; +import org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.ImmutableMetricValue; +import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.Version; +import org.apache.kafka.test.StreamsTestUtils; +import org.easymock.Capture; +import org.easymock.CaptureType; +import org.easymock.EasyMock; +import org.easymock.IArgumentMatcher; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.powermock.core.classloader.annotations.PrepareForTest; +import org.powermock.modules.junit4.PowerMockRunner; + +import java.time.Duration; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.concurrent.TimeUnit; + +import static org.apache.kafka.common.utils.Utils.mkEntry; +import static org.apache.kafka.common.utils.Utils.mkMap; +import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.AVG_SUFFIX; +import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.CLIENT_ID_TAG; +import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.CLIENT_LEVEL_GROUP; +import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.LATENCY_SUFFIX; +import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.MAX_SUFFIX; +import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.PROCESSOR_NODE_LEVEL_GROUP; +import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.RATE_SUFFIX; +import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.ROLLUP_VALUE; +import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.STATE_STORE_LEVEL_GROUP; +import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.THREAD_LEVEL_GROUP; +import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.TOTAL_SUFFIX; +import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.addAvgAndMaxLatencyToSensor; +import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.addInvocationRateAndCountToSensor; +import static org.easymock.EasyMock.anyObject; +import static org.easymock.EasyMock.anyString; +import static org.easymock.EasyMock.capture; +import static org.easymock.EasyMock.expect; +import static org.easymock.EasyMock.mock; +import static org.easymock.EasyMock.newCapture; 
+import static org.easymock.EasyMock.niceMock; +import static org.easymock.EasyMock.replay; +import static org.easymock.EasyMock.resetToDefault; +import static org.easymock.EasyMock.verify; +import static org.easymock.EasyMock.eq; +import static org.hamcrest.CoreMatchers.equalToObject; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.CoreMatchers.not; +import static org.hamcrest.CoreMatchers.notNullValue; +import static org.hamcrest.CoreMatchers.nullValue; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.greaterThan; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; +import static org.powermock.api.easymock.PowerMock.createMock; + +@RunWith(PowerMockRunner.class) +@PrepareForTest({Sensor.class, KafkaMetric.class}) +public class StreamsMetricsImplTest { + + private final static String SENSOR_PREFIX_DELIMITER = "."; + private final static String SENSOR_NAME_DELIMITER = ".s."; + private final static String SENSOR_NAME_1 = "sensor1"; + private final static String SENSOR_NAME_2 = "sensor2"; + private final static String INTERNAL_PREFIX = "internal"; + private final static String VERSION = StreamsConfig.METRICS_LATEST; + private final static String CLIENT_ID = "test-client"; + private final static String THREAD_ID1 = "test-thread-1"; + private final static String TASK_ID1 = "test-task-1"; + private final static String TASK_ID2 = "test-task-2"; + private final static String METRIC_NAME1 = "test-metric1"; + private final static String METRIC_NAME2 = "test-metric2"; + private final static String THREAD_ID_TAG = "thread-id"; + private final static String TASK_ID_TAG = "task-id"; + private final static String SCOPE_NAME = "test-scope"; + private final static String STORE_ID_TAG = "-state-id"; + private final static String STORE_NAME1 = "store1"; + private final static String STORE_NAME2 = "store2"; + private final static Map STORE_LEVEL_TAG_MAP = mkMap( + mkEntry(THREAD_ID_TAG, Thread.currentThread().getName()), + mkEntry(TASK_ID_TAG, TASK_ID1), + mkEntry(SCOPE_NAME + STORE_ID_TAG, STORE_NAME1) + ); + private final static String RECORD_CACHE_ID_TAG = "record-cache-id"; + private final static String ENTITY_NAME = "test-entity"; + private final static String OPERATION_NAME = "test-operation"; + private final static String CUSTOM_TAG_KEY1 = "test-key1"; + private final static String CUSTOM_TAG_VALUE1 = "test-value1"; + private final static String CUSTOM_TAG_KEY2 = "test-key2"; + private final static String CUSTOM_TAG_VALUE2 = "test-value2"; + private final static RecordingLevel INFO_RECORDING_LEVEL = RecordingLevel.INFO; + private final static String DESCRIPTION1 = "description number one"; + private final static String DESCRIPTION2 = "description number two"; + private final static String DESCRIPTION3 = "description number three"; + private final static Gauge VALUE_PROVIDER = (config, now) -> "mutable-value"; + + private final Metrics metrics = new Metrics(); + private final Sensor sensor = metrics.sensor("dummy"); + private final String metricNamePrefix = "metric"; + private final String group = "group"; + private final Map tags = mkMap(mkEntry("tag", "value")); + private final Map clientLevelTags = mkMap(mkEntry(CLIENT_ID_TAG, CLIENT_ID)); + private final MetricName metricName1 = + new MetricName(METRIC_NAME1, CLIENT_LEVEL_GROUP, DESCRIPTION1, clientLevelTags); + private final MetricName metricName2 = + 
new MetricName(METRIC_NAME1, CLIENT_LEVEL_GROUP, DESCRIPTION2, clientLevelTags); + private final MockTime time = new MockTime(0); + private final StreamsMetricsImpl streamsMetrics = new StreamsMetricsImpl(metrics, CLIENT_ID, VERSION, time); + + private static MetricConfig eqMetricConfig(final MetricConfig metricConfig) { + EasyMock.reportMatcher(new IArgumentMatcher() { + private final StringBuffer message = new StringBuffer(); + + @Override + public boolean matches(final Object argument) { + if (argument instanceof MetricConfig) { + final MetricConfig otherMetricConfig = (MetricConfig) argument; + final boolean equalsComparisons = + (otherMetricConfig.quota() == metricConfig.quota() || + otherMetricConfig.quota().equals(metricConfig.quota())) && + otherMetricConfig.tags().equals(metricConfig.tags()); + if (otherMetricConfig.eventWindow() == metricConfig.eventWindow() && + otherMetricConfig.recordLevel() == metricConfig.recordLevel() && + equalsComparisons && + otherMetricConfig.samples() == metricConfig.samples() && + otherMetricConfig.timeWindowMs() == metricConfig.timeWindowMs()) { + + return true; + } else { + message.append("{ "); + message.append("eventWindow="); + message.append(otherMetricConfig.eventWindow()); + message.append(", "); + message.append("recordLevel="); + message.append(otherMetricConfig.recordLevel()); + message.append(", "); + message.append("quota="); + message.append(otherMetricConfig.quota().toString()); + message.append(", "); + message.append("samples="); + message.append(otherMetricConfig.samples()); + message.append(", "); + message.append("tags="); + message.append(otherMetricConfig.tags().toString()); + message.append(", "); + message.append("timeWindowMs="); + message.append(otherMetricConfig.timeWindowMs()); + message.append(" }"); + } + } + message.append("not a MetricConfig object"); + return false; + } + + @Override + public void appendTo(final StringBuffer buffer) { + buffer.append(message); + } + }); + return null; + } + + private Capture addSensorsOnAllLevels(final Metrics metrics, final StreamsMetricsImpl streamsMetrics) { + final Capture sensorKeys = newCapture(CaptureType.ALL); + final Sensor[] parents = {}; + expect(metrics.sensor(capture(sensorKeys), eq(INFO_RECORDING_LEVEL), parents)) + .andStubReturn(sensor); + expect(metrics.metricName(METRIC_NAME1, CLIENT_LEVEL_GROUP, DESCRIPTION1, clientLevelTags)) + .andReturn(metricName1); + expect(metrics.metricName(METRIC_NAME2, CLIENT_LEVEL_GROUP, DESCRIPTION2, clientLevelTags)) + .andReturn(metricName2); + replay(metrics); + streamsMetrics.addClientLevelImmutableMetric(METRIC_NAME1, DESCRIPTION1, INFO_RECORDING_LEVEL, "value"); + streamsMetrics.addClientLevelImmutableMetric(METRIC_NAME2, DESCRIPTION2, INFO_RECORDING_LEVEL, "value"); + streamsMetrics.clientLevelSensor(SENSOR_NAME_1, INFO_RECORDING_LEVEL); + streamsMetrics.clientLevelSensor(SENSOR_NAME_2, INFO_RECORDING_LEVEL); + streamsMetrics.threadLevelSensor(THREAD_ID1, SENSOR_NAME_1, INFO_RECORDING_LEVEL); + streamsMetrics.threadLevelSensor(THREAD_ID1, SENSOR_NAME_2, INFO_RECORDING_LEVEL); + streamsMetrics.taskLevelSensor(THREAD_ID1, TASK_ID1, SENSOR_NAME_1, INFO_RECORDING_LEVEL); + streamsMetrics.taskLevelSensor(THREAD_ID1, TASK_ID1, SENSOR_NAME_2, INFO_RECORDING_LEVEL); + streamsMetrics.storeLevelSensor(TASK_ID1, STORE_NAME1, SENSOR_NAME_1, INFO_RECORDING_LEVEL); + streamsMetrics.storeLevelSensor(TASK_ID1, STORE_NAME1, SENSOR_NAME_2, INFO_RECORDING_LEVEL); + streamsMetrics.storeLevelSensor(TASK_ID1, STORE_NAME2, SENSOR_NAME_1, 
INFO_RECORDING_LEVEL); + streamsMetrics.addStoreLevelMutableMetric( + TASK_ID1, + SCOPE_NAME, + STORE_NAME1, + METRIC_NAME1, + DESCRIPTION1, + INFO_RECORDING_LEVEL, + VALUE_PROVIDER + ); + streamsMetrics.addStoreLevelMutableMetric( + TASK_ID1, + SCOPE_NAME, + STORE_NAME1, + METRIC_NAME2, + DESCRIPTION2, + INFO_RECORDING_LEVEL, + VALUE_PROVIDER + ); + streamsMetrics.addStoreLevelMutableMetric( + TASK_ID1, + SCOPE_NAME, + STORE_NAME2, + METRIC_NAME1, + DESCRIPTION1, + INFO_RECORDING_LEVEL, + VALUE_PROVIDER + ); + return sensorKeys; + } + + private Capture setupGetNewSensorTest(final Metrics metrics, + final RecordingLevel recordingLevel) { + final Capture sensorKey = newCapture(CaptureType.ALL); + expect(metrics.getSensor(capture(sensorKey))).andStubReturn(null); + final Sensor[] parents = {}; + expect(metrics.sensor(capture(sensorKey), eq(recordingLevel), parents)).andReturn(sensor); + replay(metrics); + return sensorKey; + } + + private void setupGetExistingSensorTest(final Metrics metrics) { + expect(metrics.getSensor(anyString())).andStubReturn(sensor); + replay(metrics); + } + + @Test + public void shouldGetNewThreadLevelSensor() { + final Metrics metrics = mock(Metrics.class); + final RecordingLevel recordingLevel = RecordingLevel.INFO; + setupGetNewSensorTest(metrics, recordingLevel); + final StreamsMetricsImpl streamsMetrics = new StreamsMetricsImpl(metrics, CLIENT_ID, VERSION, time); + + final Sensor actualSensor = streamsMetrics.threadLevelSensor(THREAD_ID1, SENSOR_NAME_1, recordingLevel); + + verify(metrics); + assertThat(actualSensor, is(equalToObject(sensor))); + } + + @Test + public void shouldGetExistingThreadLevelSensor() { + final Metrics metrics = mock(Metrics.class); + final RecordingLevel recordingLevel = RecordingLevel.INFO; + setupGetExistingSensorTest(metrics); + final StreamsMetricsImpl streamsMetrics = new StreamsMetricsImpl(metrics, CLIENT_ID, VERSION, time); + + final Sensor actualSensor = streamsMetrics.threadLevelSensor(THREAD_ID1, SENSOR_NAME_1, recordingLevel); + + verify(metrics); + assertThat(actualSensor, is(equalToObject(sensor))); + } + + @Test + public void shouldGetNewTaskLevelSensor() { + final Metrics metrics = mock(Metrics.class); + final RecordingLevel recordingLevel = RecordingLevel.INFO; + setupGetNewSensorTest(metrics, recordingLevel); + final StreamsMetricsImpl streamsMetrics = new StreamsMetricsImpl(metrics, CLIENT_ID, VERSION, time); + + final Sensor actualSensor = streamsMetrics.taskLevelSensor( + THREAD_ID1, + TASK_ID1, + SENSOR_NAME_1, + recordingLevel + ); + + verify(metrics); + assertThat(actualSensor, is(equalToObject(sensor))); + } + + @Test + public void shouldGetExistingTaskLevelSensor() { + final Metrics metrics = mock(Metrics.class); + final RecordingLevel recordingLevel = RecordingLevel.INFO; + setupGetExistingSensorTest(metrics); + final StreamsMetricsImpl streamsMetrics = new StreamsMetricsImpl(metrics, CLIENT_ID, VERSION, time); + + final Sensor actualSensor = streamsMetrics.taskLevelSensor( + THREAD_ID1, + TASK_ID1, + SENSOR_NAME_1, + recordingLevel + ); + + verify(metrics); + assertThat(actualSensor, is(equalToObject(sensor))); + } + + @Test + public void shouldGetNewStoreLevelSensorIfNoneExists() { + final Metrics metrics = mock(Metrics.class); + final RecordingLevel recordingLevel = RecordingLevel.INFO; + final Capture sensorKeys = setupGetNewSensorTest(metrics, recordingLevel); + final StreamsMetricsImpl streamsMetrics = new StreamsMetricsImpl(metrics, CLIENT_ID, VERSION, time); + + final Sensor actualSensor = 
streamsMetrics.storeLevelSensor( + TASK_ID1, + STORE_NAME1, + SENSOR_NAME_1, + recordingLevel + ); + + verify(metrics); + assertThat(actualSensor, is(equalToObject(sensor))); + assertThat(sensorKeys.getValues().get(0), is(sensorKeys.getValues().get(1))); + } + + @Test + public void shouldGetExistingStoreLevelSensor() { + final Metrics metrics = mock(Metrics.class); + final RecordingLevel recordingLevel = RecordingLevel.INFO; + setupGetExistingSensorTest(metrics); + final StreamsMetricsImpl streamsMetrics = new StreamsMetricsImpl(metrics, CLIENT_ID, VERSION, time); + + final Sensor actualSensor = streamsMetrics.storeLevelSensor( + TASK_ID1, + STORE_NAME1, + SENSOR_NAME_1, + recordingLevel + ); + + verify(metrics); + assertThat(actualSensor, is(equalToObject(sensor))); + } + + @Test + public void shouldUseSameStoreLevelSensorKeyWithTwoDifferentSensorNames() { + final Metrics metrics = niceMock(Metrics.class); + final Capture sensorKeys = setUpSensorKeyTests(metrics); + final StreamsMetricsImpl streamsMetrics = new StreamsMetricsImpl(metrics, CLIENT_ID, VERSION, time); + + streamsMetrics.storeLevelSensor(TASK_ID1, STORE_NAME1, SENSOR_NAME_1, INFO_RECORDING_LEVEL); + streamsMetrics.storeLevelSensor(TASK_ID1, STORE_NAME1, SENSOR_NAME_2, INFO_RECORDING_LEVEL); + + assertThat(sensorKeys.getValues().get(0), not(sensorKeys.getValues().get(1))); + } + + @Test + public void shouldNotUseSameStoreLevelSensorKeyWithDifferentTaskIds() { + final Metrics metrics = niceMock(Metrics.class); + final Capture sensorKeys = setUpSensorKeyTests(metrics); + final StreamsMetricsImpl streamsMetrics = new StreamsMetricsImpl(metrics, CLIENT_ID, VERSION, time); + + streamsMetrics.storeLevelSensor(TASK_ID1, STORE_NAME1, SENSOR_NAME_1, INFO_RECORDING_LEVEL); + streamsMetrics.storeLevelSensor(TASK_ID2, STORE_NAME1, SENSOR_NAME_1, INFO_RECORDING_LEVEL); + + assertThat(sensorKeys.getValues().get(0), not(sensorKeys.getValues().get(1))); + } + + @Test + public void shouldNotUseSameStoreLevelSensorKeyWithDifferentStoreNames() { + final Metrics metrics = niceMock(Metrics.class); + final Capture sensorKeys = setUpSensorKeyTests(metrics); + final StreamsMetricsImpl streamsMetrics = new StreamsMetricsImpl(metrics, CLIENT_ID, VERSION, time); + + streamsMetrics.storeLevelSensor(TASK_ID1, STORE_NAME1, SENSOR_NAME_1, INFO_RECORDING_LEVEL); + streamsMetrics.storeLevelSensor(TASK_ID1, STORE_NAME2, SENSOR_NAME_1, INFO_RECORDING_LEVEL); + + assertThat(sensorKeys.getValues().get(0), not(sensorKeys.getValues().get(1))); + } + + @Test + public void shouldNotUseSameStoreLevelSensorKeyWithDifferentThreadIds() throws InterruptedException { + final Metrics metrics = niceMock(Metrics.class); + final Capture sensorKeys = setUpSensorKeyTests(metrics); + final StreamsMetricsImpl streamsMetrics = new StreamsMetricsImpl(metrics, CLIENT_ID, VERSION, time); + + streamsMetrics.storeLevelSensor(TASK_ID1, STORE_NAME1, SENSOR_NAME_1, INFO_RECORDING_LEVEL); + final Thread otherThread = + new Thread(() -> streamsMetrics.storeLevelSensor(TASK_ID1, STORE_NAME1, SENSOR_NAME_1, INFO_RECORDING_LEVEL)); + otherThread.start(); + otherThread.join(); + + assertThat(sensorKeys.getValues().get(0), not(sensorKeys.getValues().get(1))); + } + + @Test + public void shouldUseSameStoreLevelSensorKeyWithSameSensorNames() { + final Metrics metrics = niceMock(Metrics.class); + final Capture sensorKeys = setUpSensorKeyTests(metrics); + final StreamsMetricsImpl streamsMetrics = new StreamsMetricsImpl(metrics, CLIENT_ID, VERSION, time); + + streamsMetrics.storeLevelSensor(TASK_ID1, 
STORE_NAME1, SENSOR_NAME_1, INFO_RECORDING_LEVEL); + streamsMetrics.storeLevelSensor(TASK_ID1, STORE_NAME1, SENSOR_NAME_1, INFO_RECORDING_LEVEL); + + assertThat(sensorKeys.getValues().get(0), is(sensorKeys.getValues().get(1))); + } + + private Capture setUpSensorKeyTests(final Metrics metrics) { + final Capture sensorKeys = newCapture(CaptureType.ALL); + expect(metrics.getSensor(capture(sensorKeys))).andStubReturn(sensor); + replay(metrics); + return sensorKeys; + } + + @Test + public void shouldAddNewStoreLevelMutableMetric() { + final Metrics metrics = mock(Metrics.class); + final MetricName metricName = + new MetricName(METRIC_NAME1, STATE_STORE_LEVEL_GROUP, DESCRIPTION1, STORE_LEVEL_TAG_MAP); + final MetricConfig metricConfig = new MetricConfig().recordLevel(INFO_RECORDING_LEVEL); + expect(metrics.metricName(METRIC_NAME1, STATE_STORE_LEVEL_GROUP, DESCRIPTION1, STORE_LEVEL_TAG_MAP)) + .andReturn(metricName); + expect(metrics.metric(metricName)).andReturn(null); + metrics.addMetric(eq(metricName), eqMetricConfig(metricConfig), eq(VALUE_PROVIDER)); + replay(metrics); + final StreamsMetricsImpl streamsMetrics = new StreamsMetricsImpl(metrics, CLIENT_ID, VERSION, time); + + streamsMetrics.addStoreLevelMutableMetric( + TASK_ID1, + SCOPE_NAME, + STORE_NAME1, + METRIC_NAME1, + DESCRIPTION1, + INFO_RECORDING_LEVEL, + VALUE_PROVIDER + ); + + verify(metrics); + } + + @Test + public void shouldNotAddStoreLevelMutableMetricIfAlreadyExists() { + final Metrics metrics = mock(Metrics.class); + final MetricName metricName = + new MetricName(METRIC_NAME1, STATE_STORE_LEVEL_GROUP, DESCRIPTION1, STORE_LEVEL_TAG_MAP); + expect(metrics.metricName(METRIC_NAME1, STATE_STORE_LEVEL_GROUP, DESCRIPTION1, STORE_LEVEL_TAG_MAP)) + .andReturn(metricName); + expect(metrics.metric(metricName)).andReturn(mock(KafkaMetric.class)); + replay(metrics); + final StreamsMetricsImpl streamsMetrics = new StreamsMetricsImpl(metrics, CLIENT_ID, VERSION, time); + + streamsMetrics.addStoreLevelMutableMetric( + TASK_ID1, + SCOPE_NAME, + STORE_NAME1, + METRIC_NAME1, + DESCRIPTION1, + INFO_RECORDING_LEVEL, + VALUE_PROVIDER + ); + + verify(metrics); + } + + @Test + public void shouldRemoveStateStoreLevelSensors() { + final Metrics metrics = niceMock(Metrics.class); + final StreamsMetricsImpl streamsMetrics = new StreamsMetricsImpl(metrics, CLIENT_ID, VERSION, time); + final MetricName metricName1 = + new MetricName(METRIC_NAME1, STATE_STORE_LEVEL_GROUP, DESCRIPTION1, STORE_LEVEL_TAG_MAP); + final MetricName metricName2 = + new MetricName(METRIC_NAME2, STATE_STORE_LEVEL_GROUP, DESCRIPTION2, STORE_LEVEL_TAG_MAP); + expect(metrics.metricName(METRIC_NAME1, STATE_STORE_LEVEL_GROUP, DESCRIPTION1, STORE_LEVEL_TAG_MAP)) + .andReturn(metricName1); + expect(metrics.metricName(METRIC_NAME2, STATE_STORE_LEVEL_GROUP, DESCRIPTION2, STORE_LEVEL_TAG_MAP)) + .andReturn(metricName2); + final Capture sensorKeys = addSensorsOnAllLevels(metrics, streamsMetrics); + resetToDefault(metrics); + metrics.removeSensor(sensorKeys.getValues().get(6)); + metrics.removeSensor(sensorKeys.getValues().get(7)); + expect(metrics.removeMetric(metricName1)).andReturn(mock(KafkaMetric.class)); + expect(metrics.removeMetric(metricName2)).andReturn(mock(KafkaMetric.class)); + replay(metrics); + + streamsMetrics.removeAllStoreLevelSensorsAndMetrics(TASK_ID1, STORE_NAME1); + + verify(metrics); + } + + @Test + public void shouldGetNewNodeLevelSensor() { + final Metrics metrics = mock(Metrics.class); + final RecordingLevel recordingLevel = RecordingLevel.INFO; + final String 
processorNodeName = "processorNodeName"; + setupGetNewSensorTest(metrics, recordingLevel); + final StreamsMetricsImpl streamsMetrics = new StreamsMetricsImpl(metrics, CLIENT_ID, VERSION, time); + + final Sensor actualSensor = streamsMetrics.nodeLevelSensor( + THREAD_ID1, + TASK_ID1, + processorNodeName, + SENSOR_NAME_1, + recordingLevel + ); + + verify(metrics); + assertThat(actualSensor, is(equalToObject(sensor))); + } + + @Test + public void shouldGetExistingNodeLevelSensor() { + final Metrics metrics = mock(Metrics.class); + final RecordingLevel recordingLevel = RecordingLevel.INFO; + final String processorNodeName = "processorNodeName"; + setupGetExistingSensorTest(metrics); + final StreamsMetricsImpl streamsMetrics = new StreamsMetricsImpl(metrics, CLIENT_ID, VERSION, time); + + final Sensor actualSensor = streamsMetrics.nodeLevelSensor( + THREAD_ID1, + TASK_ID1, + processorNodeName, + SENSOR_NAME_1, + recordingLevel + ); + + verify(metrics); + assertThat(actualSensor, is(equalToObject(sensor))); + } + + @Test + public void shouldGetNewCacheLevelSensor() { + final Metrics metrics = mock(Metrics.class); + final RecordingLevel recordingLevel = RecordingLevel.INFO; + final String processorCacheName = "processorNodeName"; + setupGetNewSensorTest(metrics, recordingLevel); + final StreamsMetricsImpl streamsMetrics = new StreamsMetricsImpl(metrics, CLIENT_ID, VERSION, time); + + final Sensor actualSensor = streamsMetrics.cacheLevelSensor( + THREAD_ID1, + TASK_ID1, + processorCacheName, + SENSOR_NAME_1, + recordingLevel + ); + + verify(metrics); + assertThat(actualSensor, is(equalToObject(sensor))); + } + + @Test + public void shouldGetExistingCacheLevelSensor() { + final Metrics metrics = mock(Metrics.class); + final RecordingLevel recordingLevel = RecordingLevel.INFO; + final String processorCacheName = "processorNodeName"; + setupGetExistingSensorTest(metrics); + final StreamsMetricsImpl streamsMetrics = new StreamsMetricsImpl(metrics, CLIENT_ID, VERSION, time); + + final Sensor actualSensor = streamsMetrics.cacheLevelSensor( + THREAD_ID1, TASK_ID1, + processorCacheName, + SENSOR_NAME_1, + recordingLevel + ); + + verify(metrics); + assertThat(actualSensor, is(equalToObject(sensor))); + } + + @Test + public void shouldGetNewClientLevelSensor() { + final Metrics metrics = mock(Metrics.class); + final RecordingLevel recordingLevel = RecordingLevel.INFO; + setupGetNewSensorTest(metrics, recordingLevel); + final StreamsMetricsImpl streamsMetrics = new StreamsMetricsImpl(metrics, CLIENT_ID, VERSION, time); + + final Sensor actualSensor = streamsMetrics.clientLevelSensor(SENSOR_NAME_1, recordingLevel); + + verify(metrics); + assertThat(actualSensor, is(equalToObject(sensor))); + } + + @Test + public void shouldGetExistingClientLevelSensor() { + final Metrics metrics = mock(Metrics.class); + final RecordingLevel recordingLevel = RecordingLevel.INFO; + setupGetExistingSensorTest(metrics); + final StreamsMetricsImpl streamsMetrics = new StreamsMetricsImpl(metrics, CLIENT_ID, VERSION, time); + + final Sensor actualSensor = streamsMetrics.clientLevelSensor(SENSOR_NAME_1, recordingLevel); + + verify(metrics); + assertThat(actualSensor, is(equalToObject(sensor))); + } + + @Test + public void shouldAddClientLevelImmutableMetric() { + final Metrics metrics = mock(Metrics.class); + final RecordingLevel recordingLevel = RecordingLevel.INFO; + final MetricConfig metricConfig = new MetricConfig().recordLevel(recordingLevel); + final String value = "immutable-value"; + final ImmutableMetricValue immutableValue 
= new ImmutableMetricValue<>(value); + expect(metrics.metricName(METRIC_NAME1, CLIENT_LEVEL_GROUP, DESCRIPTION1, clientLevelTags)) + .andReturn(metricName1); + metrics.addMetric(eq(metricName1), eqMetricConfig(metricConfig), eq(immutableValue)); + replay(metrics); + final StreamsMetricsImpl streamsMetrics = new StreamsMetricsImpl(metrics, CLIENT_ID, VERSION, time); + + streamsMetrics.addClientLevelImmutableMetric(METRIC_NAME1, DESCRIPTION1, recordingLevel, value); + + verify(metrics); + } + + @Test + public void shouldAddClientLevelMutableMetric() { + final Metrics metrics = mock(Metrics.class); + final RecordingLevel recordingLevel = RecordingLevel.INFO; + final MetricConfig metricConfig = new MetricConfig().recordLevel(recordingLevel); + final Gauge valueProvider = (config, now) -> "mutable-value"; + expect(metrics.metricName(METRIC_NAME1, CLIENT_LEVEL_GROUP, DESCRIPTION1, clientLevelTags)) + .andReturn(metricName1); + metrics.addMetric(EasyMock.eq(metricName1), eqMetricConfig(metricConfig), eq(valueProvider)); + replay(metrics); + final StreamsMetricsImpl streamsMetrics = new StreamsMetricsImpl(metrics, CLIENT_ID, VERSION, time); + + streamsMetrics.addClientLevelMutableMetric(METRIC_NAME1, DESCRIPTION1, recordingLevel, valueProvider); + + verify(metrics); + } + + @Test + public void shouldProvideCorrectStrings() { + assertThat(LATENCY_SUFFIX, is("-latency")); + assertThat(ROLLUP_VALUE, is("all")); + } + + private void setupRemoveSensorsTest(final Metrics metrics, + final String level) { + final String fullSensorNamePrefix = INTERNAL_PREFIX + SENSOR_PREFIX_DELIMITER + level + SENSOR_NAME_DELIMITER; + resetToDefault(metrics); + metrics.removeSensor(fullSensorNamePrefix + SENSOR_NAME_1); + metrics.removeSensor(fullSensorNamePrefix + SENSOR_NAME_2); + replay(metrics); + } + + @Test + public void shouldRemoveClientLevelMetricsAndSensors() { + final Metrics metrics = niceMock(Metrics.class); + final StreamsMetricsImpl streamsMetrics = new StreamsMetricsImpl(metrics, CLIENT_ID, VERSION, time); + final Capture sensorKeys = addSensorsOnAllLevels(metrics, streamsMetrics); + resetToDefault(metrics); + + metrics.removeSensor(sensorKeys.getValues().get(0)); + metrics.removeSensor(sensorKeys.getValues().get(1)); + expect(metrics.removeMetric(metricName1)).andStubReturn(null); + expect(metrics.removeMetric(metricName2)).andStubReturn(null); + replay(metrics); + streamsMetrics.removeAllClientLevelSensorsAndMetrics(); + + verify(metrics); + } + + @Test + public void shouldRemoveThreadLevelSensors() { + final Metrics metrics = niceMock(Metrics.class); + final StreamsMetricsImpl streamsMetrics = new StreamsMetricsImpl(metrics, CLIENT_ID, VERSION, time); + addSensorsOnAllLevels(metrics, streamsMetrics); + setupRemoveSensorsTest(metrics, THREAD_ID1); + + streamsMetrics.removeAllThreadLevelSensors(THREAD_ID1); + + verify(metrics); + } + + @Test + public void testNullMetrics() { + assertThrows(NullPointerException.class, () -> new StreamsMetricsImpl(null, "", VERSION, time)); + } + + @Test + public void testRemoveNullSensor() { + assertThrows(NullPointerException.class, () -> streamsMetrics.removeSensor(null)); + } + + @Test + public void testRemoveSensor() { + final String sensorName = "sensor1"; + final String scope = "scope"; + final String entity = "entity"; + final String operation = "put"; + + final Sensor sensor1 = streamsMetrics.addSensor(sensorName, RecordingLevel.DEBUG); + streamsMetrics.removeSensor(sensor1); + + final Sensor sensor1a = streamsMetrics.addSensor(sensorName, RecordingLevel.DEBUG, 
sensor1); + streamsMetrics.removeSensor(sensor1a); + + final Sensor sensor2 = streamsMetrics.addLatencyRateTotalSensor(scope, entity, operation, RecordingLevel.DEBUG); + streamsMetrics.removeSensor(sensor2); + + final Sensor sensor3 = streamsMetrics.addRateTotalSensor(scope, entity, operation, RecordingLevel.DEBUG); + streamsMetrics.removeSensor(sensor3); + + assertEquals(Collections.emptyMap(), streamsMetrics.parentSensors()); + } + + @Test + public void testMultiLevelSensorRemoval() { + final Metrics registry = new Metrics(); + final StreamsMetricsImpl metrics = new StreamsMetricsImpl(registry, THREAD_ID1, VERSION, time); + for (final MetricName defaultMetric : registry.metrics().keySet()) { + registry.removeMetric(defaultMetric); + } + + final String taskName = "taskName"; + final String operation = "operation"; + final Map taskTags = mkMap(mkEntry("tkey", "value")); + + final String processorNodeName = "processorNodeName"; + final Map nodeTags = mkMap(mkEntry("nkey", "value")); + + final Sensor parent1 = metrics.taskLevelSensor(THREAD_ID1, taskName, operation, RecordingLevel.DEBUG); + addAvgAndMaxLatencyToSensor(parent1, PROCESSOR_NODE_LEVEL_GROUP, taskTags, operation); + addInvocationRateAndCountToSensor(parent1, PROCESSOR_NODE_LEVEL_GROUP, taskTags, operation, "", ""); + + final int numberOfTaskMetrics = registry.metrics().size(); + + final Sensor sensor1 = metrics.nodeLevelSensor(THREAD_ID1, taskName, processorNodeName, operation, RecordingLevel.DEBUG, parent1); + addAvgAndMaxLatencyToSensor(sensor1, PROCESSOR_NODE_LEVEL_GROUP, nodeTags, operation); + addInvocationRateAndCountToSensor(sensor1, PROCESSOR_NODE_LEVEL_GROUP, nodeTags, operation, "", ""); + + assertThat(registry.metrics().size(), greaterThan(numberOfTaskMetrics)); + + metrics.removeAllNodeLevelSensors(THREAD_ID1, taskName, processorNodeName); + + assertThat(registry.metrics().size(), equalTo(numberOfTaskMetrics)); + + final Sensor parent2 = metrics.taskLevelSensor(THREAD_ID1, taskName, operation, RecordingLevel.DEBUG); + addAvgAndMaxLatencyToSensor(parent2, PROCESSOR_NODE_LEVEL_GROUP, taskTags, operation); + addInvocationRateAndCountToSensor(parent2, PROCESSOR_NODE_LEVEL_GROUP, taskTags, operation, "", ""); + + assertThat(registry.metrics().size(), equalTo(numberOfTaskMetrics)); + + final Sensor sensor2 = metrics.nodeLevelSensor(THREAD_ID1, taskName, processorNodeName, operation, RecordingLevel.DEBUG, parent2); + addAvgAndMaxLatencyToSensor(sensor2, PROCESSOR_NODE_LEVEL_GROUP, nodeTags, operation); + addInvocationRateAndCountToSensor(sensor2, PROCESSOR_NODE_LEVEL_GROUP, nodeTags, operation, "", ""); + + assertThat(registry.metrics().size(), greaterThan(numberOfTaskMetrics)); + + metrics.removeAllNodeLevelSensors(THREAD_ID1, taskName, processorNodeName); + + assertThat(registry.metrics().size(), equalTo(numberOfTaskMetrics)); + + metrics.removeAllTaskLevelSensors(THREAD_ID1, taskName); + + assertThat(registry.metrics().size(), equalTo(0)); + } + + @Test + public void testLatencyMetrics() { + final int defaultMetrics = streamsMetrics.metrics().size(); + + final String scope = "scope"; + final String entity = "entity"; + final String operation = "put"; + + final Sensor sensor1 = streamsMetrics.addLatencyRateTotalSensor(scope, entity, operation, RecordingLevel.DEBUG); + + final int meterMetricsCount = 2; // Each Meter is a combination of a Rate and a Total + final int otherMetricsCount = 2; // Latency-max and Latency-avg + // 2 meters and 2 non-meter metrics plus a common metric that keeps track of total registered metrics 
in Metrics() constructor + assertEquals(defaultMetrics + meterMetricsCount + otherMetricsCount, streamsMetrics.metrics().size()); + + streamsMetrics.removeSensor(sensor1); + assertEquals(defaultMetrics, streamsMetrics.metrics().size()); + } + + @Test + public void testThroughputMetrics() { + final int defaultMetrics = streamsMetrics.metrics().size(); + + final String scope = "scope"; + final String entity = "entity"; + final String operation = "put"; + + final Sensor sensor1 = streamsMetrics.addRateTotalSensor(scope, entity, operation, RecordingLevel.DEBUG); + + final int meterMetricsCount = 2; // Each Meter is a combination of a Rate and a Total + // 2 meter metrics plus a common metric that keeps track of total registered metrics in Metrics() constructor + assertEquals(defaultMetrics + meterMetricsCount, streamsMetrics.metrics().size()); + + streamsMetrics.removeSensor(sensor1); + assertEquals(defaultMetrics, streamsMetrics.metrics().size()); + } + + @Test + public void testTotalMetricDoesntDecrease() { + final MockTime time = new MockTime(1); + final MetricConfig config = new MetricConfig().timeWindow(1, TimeUnit.MILLISECONDS); + final Metrics metrics = new Metrics(config, time); + final StreamsMetricsImpl streamsMetrics = new StreamsMetricsImpl(metrics, "", VERSION, time); + + final String scope = "scope"; + final String entity = "entity"; + final String operation = "op"; + + final Sensor sensor = streamsMetrics.addLatencyRateTotalSensor( + scope, + entity, + operation, + RecordingLevel.INFO + ); + + final double latency = 100.0; + final MetricName totalMetricName = metrics.metricName( + "op-total", + "stream-scope-metrics", + "", + "thread-id", + Thread.currentThread().getName(), + "scope-id", + "entity" + ); + + final KafkaMetric totalMetric = metrics.metric(totalMetricName); + + for (int i = 0; i < 10; i++) { + assertEquals(i, Math.round(totalMetric.measurable().measure(config, time.milliseconds()))); + sensor.record(latency, time.milliseconds()); + } + } + + @Test + public void shouldAddLatencyRateTotalSensor() { + final StreamsMetricsImpl streamsMetrics = new StreamsMetricsImpl(metrics, CLIENT_ID, VERSION, time); + shouldAddCustomSensor( + streamsMetrics.addLatencyRateTotalSensor(SCOPE_NAME, ENTITY_NAME, OPERATION_NAME, RecordingLevel.DEBUG), + streamsMetrics, + Arrays.asList( + OPERATION_NAME + LATENCY_SUFFIX + AVG_SUFFIX, + OPERATION_NAME + LATENCY_SUFFIX + MAX_SUFFIX, + OPERATION_NAME + TOTAL_SUFFIX, + OPERATION_NAME + RATE_SUFFIX + ) + ); + } + + @Test + public void shouldAddRateTotalSensor() { + final StreamsMetricsImpl streamsMetrics = new StreamsMetricsImpl(metrics, CLIENT_ID, VERSION, time); + shouldAddCustomSensor( + streamsMetrics.addRateTotalSensor(SCOPE_NAME, ENTITY_NAME, OPERATION_NAME, RecordingLevel.DEBUG), + streamsMetrics, + Arrays.asList(OPERATION_NAME + TOTAL_SUFFIX, OPERATION_NAME + RATE_SUFFIX) + ); + } + + @Test + public void shouldAddLatencyRateTotalSensorWithCustomTags() { + final Sensor sensor = streamsMetrics.addLatencyRateTotalSensor( + SCOPE_NAME, + ENTITY_NAME, + OPERATION_NAME, + RecordingLevel.DEBUG, + CUSTOM_TAG_KEY1, + CUSTOM_TAG_VALUE1, + CUSTOM_TAG_KEY2, + CUSTOM_TAG_VALUE2 + ); + final Map tags = customTags(streamsMetrics); + shouldAddCustomSensorWithTags( + sensor, + Arrays.asList( + OPERATION_NAME + LATENCY_SUFFIX + AVG_SUFFIX, + OPERATION_NAME + LATENCY_SUFFIX + MAX_SUFFIX, + OPERATION_NAME + TOTAL_SUFFIX, + OPERATION_NAME + RATE_SUFFIX + ), + tags + ); + } + + @Test + public void shouldAddRateTotalSensorWithCustomTags() { + final Sensor 
sensor = streamsMetrics.addRateTotalSensor( + SCOPE_NAME, + ENTITY_NAME, + OPERATION_NAME, + RecordingLevel.DEBUG, + CUSTOM_TAG_KEY1, + CUSTOM_TAG_VALUE1, + CUSTOM_TAG_KEY2, + CUSTOM_TAG_VALUE2 + ); + final Map tags = customTags(streamsMetrics); + shouldAddCustomSensorWithTags( + sensor, + Arrays.asList( + OPERATION_NAME + TOTAL_SUFFIX, + OPERATION_NAME + RATE_SUFFIX + ), + tags + ); + } + + private void shouldAddCustomSensor(final Sensor sensor, + final StreamsMetricsImpl streamsMetrics, + final List metricsNames) { + final Map tags = tags(streamsMetrics); + shouldAddCustomSensorWithTags(sensor, metricsNames, tags); + } + + private void shouldAddCustomSensorWithTags(final Sensor sensor, + final List metricsNames, + final Map tags) { + final String group = "stream-" + SCOPE_NAME + "-metrics"; + assertTrue(sensor.hasMetrics()); + assertThat( + sensor.name(), + is("external." + Thread.currentThread().getName() + ".entity." + ENTITY_NAME + ".s." + OPERATION_NAME) + ); + for (final String name : metricsNames) { + assertTrue(StreamsTestUtils.containsMetric(metrics, name, group, tags)); + } + } + + private Map tags(final StreamsMetricsImpl streamsMetrics) { + return mkMap( + mkEntry( + streamsMetrics.version() == Version.LATEST ? THREAD_ID_TAG : CLIENT_ID_TAG, + Thread.currentThread().getName() + ), + mkEntry(SCOPE_NAME + "-id", ENTITY_NAME) + ); + } + + private Map customTags(final StreamsMetricsImpl streamsMetrics) { + final Map tags = tags(streamsMetrics); + tags.put(CUSTOM_TAG_KEY1, CUSTOM_TAG_VALUE1); + tags.put(CUSTOM_TAG_KEY2, CUSTOM_TAG_VALUE2); + return tags; + } + + @Test + public void shouldThrowIfLatencyRateTotalSensorIsAddedWithOddTags() { + final IllegalArgumentException exception = assertThrows( + IllegalArgumentException.class, + () -> streamsMetrics.addLatencyRateTotalSensor( + SCOPE_NAME, + ENTITY_NAME, + OPERATION_NAME, + RecordingLevel.DEBUG, + "bad-tag") + ); + assertThat(exception.getMessage(), is("Tags needs to be specified in key-value pairs")); + } + + @Test + public void shouldThrowIfRateTotalSensorIsAddedWithOddTags() { + final IllegalArgumentException exception = assertThrows( + IllegalArgumentException.class, + () -> streamsMetrics.addRateTotalSensor( + SCOPE_NAME, + ENTITY_NAME, + OPERATION_NAME, + RecordingLevel.DEBUG, + "bad-tag") + ); + assertThat(exception.getMessage(), is("Tags needs to be specified in key-value pairs")); + } + + @Test + public void shouldGetClientLevelTagMap() { + final Map tagMap = streamsMetrics.clientLevelTagMap(); + + assertThat(tagMap.size(), equalTo(1)); + assertThat(tagMap.get(StreamsMetricsImpl.CLIENT_ID_TAG), equalTo(CLIENT_ID)); + } + + @Test + public void shouldGetStoreLevelTagMap() { + final String taskName = "test-task"; + final String storeType = "remote-window"; + final String storeName = "window-keeper"; + final StreamsMetricsImpl streamsMetrics = new StreamsMetricsImpl(metrics, THREAD_ID1, VERSION, time); + + final Map tagMap = streamsMetrics.storeLevelTagMap(taskName, storeType, storeName); + + assertThat(tagMap.size(), equalTo(3)); + assertThat( + tagMap.get(StreamsMetricsImpl.THREAD_ID_TAG), + equalTo(Thread.currentThread().getName())); + assertThat(tagMap.get(StreamsMetricsImpl.TASK_ID_TAG), equalTo(taskName)); + assertThat(tagMap.get(storeType + "-" + StreamsMetricsImpl.STORE_ID_TAG), equalTo(storeName)); + } + + @Test + public void shouldGetCacheLevelTagMap() { + final StreamsMetricsImpl streamsMetrics = + new StreamsMetricsImpl(metrics, THREAD_ID1, VERSION, time); + final String taskName = "taskName"; + final String 
storeName = "storeName"; + + final Map tagMap = streamsMetrics.cacheLevelTagMap(THREAD_ID1, taskName, storeName); + + assertThat(tagMap.size(), equalTo(3)); + assertThat( + tagMap.get(StreamsMetricsImpl.THREAD_ID_TAG), + equalTo(THREAD_ID1) + ); + assertThat(tagMap.get(TASK_ID_TAG), equalTo(taskName)); + assertThat(tagMap.get(RECORD_CACHE_ID_TAG), equalTo(storeName)); + } + + @Test + public void shouldGetThreadLevelTagMap() { + final StreamsMetricsImpl streamsMetrics = new StreamsMetricsImpl(metrics, THREAD_ID1, VERSION, time); + + final Map tagMap = streamsMetrics.threadLevelTagMap(THREAD_ID1); + + assertThat(tagMap.size(), equalTo(1)); + assertThat( + tagMap.get(THREAD_ID_TAG), + equalTo(THREAD_ID1) + ); + } + + @Test + public void shouldAddInvocationRateToSensor() { + final Sensor sensor = createMock(Sensor.class); + final MetricName expectedMetricName = new MetricName(METRIC_NAME1 + "-rate", group, DESCRIPTION1, tags); + expect(sensor.add(eq(expectedMetricName), anyObject(Rate.class))).andReturn(true); + replay(sensor); + + StreamsMetricsImpl.addInvocationRateToSensor(sensor, group, tags, METRIC_NAME1, DESCRIPTION1); + + verify(sensor); + } + + @Test + public void shouldAddAmountRateAndSum() { + StreamsMetricsImpl + .addRateOfSumAndSumMetricsToSensor(sensor, group, tags, metricNamePrefix, DESCRIPTION1, DESCRIPTION2); + + final double valueToRecord1 = 18.0; + final double valueToRecord2 = 72.0; + final long defaultWindowSizeInSeconds = Duration.ofMillis(new MetricConfig().timeWindowMs()).getSeconds(); + final double expectedRateMetricValue = (valueToRecord1 + valueToRecord2) / defaultWindowSizeInSeconds; + verifyMetric(metricNamePrefix + "-rate", DESCRIPTION1, valueToRecord1, valueToRecord2, expectedRateMetricValue); + final double expectedSumMetricValue = 2 * valueToRecord1 + 2 * valueToRecord2; // values are recorded once for each metric verification + verifyMetric(metricNamePrefix + "-total", DESCRIPTION2, valueToRecord1, valueToRecord2, expectedSumMetricValue); + assertThat(metrics.metrics().size(), equalTo(2 + 1)); // one metric is added automatically in the constructor of Metrics + } + + @Test + public void shouldAddSum() { + StreamsMetricsImpl.addSumMetricToSensor(sensor, group, tags, metricNamePrefix, DESCRIPTION1); + + final double valueToRecord1 = 18.0; + final double valueToRecord2 = 42.0; + final double expectedSumMetricValue = valueToRecord1 + valueToRecord2; + verifyMetric(metricNamePrefix + "-total", DESCRIPTION1, valueToRecord1, valueToRecord2, expectedSumMetricValue); + assertThat(metrics.metrics().size(), equalTo(1 + 1)); // one metric is added automatically in the constructor of Metrics + } + + @Test + public void shouldAddAmountRate() { + StreamsMetricsImpl.addRateOfSumMetricToSensor(sensor, group, tags, metricNamePrefix, DESCRIPTION1); + + final double valueToRecord1 = 18.0; + final double valueToRecord2 = 72.0; + final long defaultWindowSizeInSeconds = Duration.ofMillis(new MetricConfig().timeWindowMs()).getSeconds(); + final double expectedRateMetricValue = (valueToRecord1 + valueToRecord2) / defaultWindowSizeInSeconds; + verifyMetric(metricNamePrefix + "-rate", DESCRIPTION1, valueToRecord1, valueToRecord2, expectedRateMetricValue); + assertThat(metrics.metrics().size(), equalTo(1 + 1)); // one metric is added automatically in the constructor of Metrics + } + + @Test + public void shouldAddValue() { + StreamsMetricsImpl.addValueMetricToSensor(sensor, group, tags, metricNamePrefix, DESCRIPTION1); + + final KafkaMetric ratioMetric = metrics.metric(new 
MetricName(metricNamePrefix, group, DESCRIPTION1, tags)); + assertThat(ratioMetric, is(notNullValue())); + final MetricConfig metricConfig = new MetricConfig(); + final double value1 = 42.0; + sensor.record(value1); + assertThat(ratioMetric.measurable().measure(metricConfig, time.milliseconds()), equalTo(42.0)); + final double value2 = 18.0; + sensor.record(value2); + assertThat(ratioMetric.measurable().measure(metricConfig, time.milliseconds()), equalTo(18.0)); + assertThat(metrics.metrics().size(), equalTo(1 + 1)); // one metric is added automatically in the constructor of Metrics + } + + @Test + public void shouldAddAvgAndTotalMetricsToSensor() { + StreamsMetricsImpl + .addAvgAndSumMetricsToSensor(sensor, group, tags, metricNamePrefix, DESCRIPTION1, DESCRIPTION2); + + final double valueToRecord1 = 18.0; + final double valueToRecord2 = 42.0; + final double expectedAvgMetricValue = (valueToRecord1 + valueToRecord2) / 2; + verifyMetric(metricNamePrefix + "-avg", DESCRIPTION1, valueToRecord1, valueToRecord2, expectedAvgMetricValue); + final double expectedSumMetricValue = 2 * valueToRecord1 + 2 * valueToRecord2; // values are recorded once for each metric verification + verifyMetric(metricNamePrefix + "-total", DESCRIPTION2, valueToRecord1, valueToRecord2, expectedSumMetricValue); + assertThat(metrics.metrics().size(), equalTo(2 + 1)); // one metric is added automatically in the constructor of Metrics + } + + @Test + public void shouldAddAvgAndMinAndMaxMetricsToSensor() { + StreamsMetricsImpl + .addAvgAndMinAndMaxToSensor(sensor, group, tags, metricNamePrefix, DESCRIPTION1, DESCRIPTION2, DESCRIPTION3); + + final double valueToRecord1 = 18.0; + final double valueToRecord2 = 42.0; + final double expectedAvgMetricValue = (valueToRecord1 + valueToRecord2) / 2; + verifyMetric(metricNamePrefix + "-avg", DESCRIPTION1, valueToRecord1, valueToRecord2, expectedAvgMetricValue); + verifyMetric(metricNamePrefix + "-min", DESCRIPTION2, valueToRecord1, valueToRecord2, valueToRecord1); + verifyMetric(metricNamePrefix + "-max", DESCRIPTION3, valueToRecord1, valueToRecord2, valueToRecord2); + assertThat(metrics.metrics().size(), equalTo(3 + 1)); // one metric is added automatically in the constructor of Metrics + } + + @Test + public void shouldAddMinAndMaxMetricsToSensor() { + StreamsMetricsImpl + .addMinAndMaxToSensor(sensor, group, tags, metricNamePrefix, DESCRIPTION1, DESCRIPTION2); + + final double valueToRecord1 = 18.0; + final double valueToRecord2 = 42.0; + verifyMetric(metricNamePrefix + "-min", DESCRIPTION1, valueToRecord1, valueToRecord2, valueToRecord1); + verifyMetric(metricNamePrefix + "-max", DESCRIPTION2, valueToRecord1, valueToRecord2, valueToRecord2); + assertThat(metrics.metrics().size(), equalTo(2 + 1)); // one metric is added automatically in the constructor of Metrics + } + + @Test + public void shouldReturnMetricsVersionCurrent() { + assertThat( + new StreamsMetricsImpl(metrics, THREAD_ID1, StreamsConfig.METRICS_LATEST, time).version(), + equalTo(Version.LATEST) + ); + } + + private void verifyMetric(final String name, + final String description, + final double valueToRecord1, + final double valueToRecord2, + final double expectedMetricValue) { + final KafkaMetric metric = metrics + .metric(new MetricName(name, group, description, tags)); + assertThat(metric, is(notNullValue())); + assertThat(metric.metricName().description(), equalTo(description)); + sensor.record(valueToRecord1, time.milliseconds()); + sensor.record(valueToRecord2, time.milliseconds()); + assertThat( + 
metric.measurable().measure(new MetricConfig(), time.milliseconds()), + equalTo(expectedMetricValue) + ); + } + + @Test + public void shouldMeasureLatency() { + final long startTime = 6; + final long endTime = 10; + final Sensor sensor = createMock(Sensor.class); + expect(sensor.shouldRecord()).andReturn(true); + expect(sensor.hasMetrics()).andReturn(true); + sensor.record(endTime - startTime); + final Time time = mock(Time.class); + expect(time.nanoseconds()).andReturn(startTime); + expect(time.nanoseconds()).andReturn(endTime); + replay(sensor, time); + + StreamsMetricsImpl.maybeMeasureLatency(() -> { }, time, sensor); + + verify(sensor, time); + } + + @Test + public void shouldNotMeasureLatencyDueToRecordingLevel() { + final Sensor sensor = createMock(Sensor.class); + expect(sensor.shouldRecord()).andReturn(false); + final Time time = mock(Time.class); + replay(sensor); + + StreamsMetricsImpl.maybeMeasureLatency(() -> { }, time, sensor); + + verify(sensor); + } + + @Test + public void shouldNotMeasureLatencyBecauseSensorHasNoMetrics() { + final Sensor sensor = createMock(Sensor.class); + expect(sensor.shouldRecord()).andReturn(true); + expect(sensor.hasMetrics()).andReturn(false); + final Time time = mock(Time.class); + replay(sensor); + + StreamsMetricsImpl.maybeMeasureLatency(() -> { }, time, sensor); + + verify(sensor); + } + + @Test + public void shouldAddThreadLevelMutableMetric() { + final int measuredValue = 123; + final StreamsMetricsImpl streamsMetrics + = new StreamsMetricsImpl(metrics, THREAD_ID1, VERSION, time); + + streamsMetrics.addThreadLevelMutableMetric( + "foobar", + "test metric", + "t1", + (c, t) -> measuredValue + ); + + final MetricName name = metrics.metricName( + "foobar", + THREAD_LEVEL_GROUP, + Collections.singletonMap("thread-id", "t1") + ); + assertThat(metrics.metric(name), notNullValue()); + assertThat(metrics.metric(name).metricValue(), equalTo(measuredValue)); + } + + @Test + public void shouldCleanupThreadLevelMutableMetric() { + final int measuredValue = 123; + final StreamsMetricsImpl streamsMetrics + = new StreamsMetricsImpl(metrics, THREAD_ID1, VERSION, time); + streamsMetrics.addThreadLevelMutableMetric( + "foobar", + "test metric", + "t1", + (c, t) -> measuredValue + ); + + streamsMetrics.removeAllThreadLevelMetrics("t1"); + + final MetricName name = metrics.metricName( + "foobar", + THREAD_LEVEL_GROUP, + Collections.singletonMap("thread-id", "t1") + ); + assertThat(metrics.metric(name), nullValue()); + } + + @Test + public void shouldAddThreadLevelImmutableMetric() { + final int measuredValue = 123; + final StreamsMetricsImpl streamsMetrics + = new StreamsMetricsImpl(metrics, THREAD_ID1, VERSION, time); + + streamsMetrics.addThreadLevelImmutableMetric( + "foobar", + "test metric", + "t1", + measuredValue + ); + + final MetricName name = metrics.metricName( + "foobar", + THREAD_LEVEL_GROUP, + Collections.singletonMap("thread-id", "t1") + ); + assertThat(metrics.metric(name), notNullValue()); + assertThat(metrics.metric(name).metricValue(), equalTo(measuredValue)); + } + + @Test + public void shouldCleanupThreadLevelImmutableMetric() { + final int measuredValue = 123; + final StreamsMetricsImpl streamsMetrics + = new StreamsMetricsImpl(metrics, THREAD_ID1, VERSION, time); + streamsMetrics.addThreadLevelImmutableMetric( + "foobar", + "test metric", + "t1", + measuredValue + ); + + streamsMetrics.removeAllThreadLevelMetrics("t1"); + + final MetricName name = metrics.metricName( + "foobar", + THREAD_LEVEL_GROUP, + Collections.singletonMap("thread-id", 
"t1") + ); + assertThat(metrics.metric(name), nullValue()); + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/metrics/TaskMetricsTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/metrics/TaskMetricsTest.java new file mode 100644 index 0000000..1d33fea --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/metrics/TaskMetricsTest.java @@ -0,0 +1,229 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.processor.internals.metrics; + +import org.apache.kafka.common.metrics.Sensor; +import org.apache.kafka.common.metrics.Sensor.RecordingLevel; +import org.junit.Test; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import java.util.Collections; +import java.util.Map; + +import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.TASK_LEVEL_GROUP; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.MatcherAssert.assertThat; + +public class TaskMetricsTest { + + private final static String THREAD_ID = "test-thread"; + private final static String TASK_ID = "test-task"; + + private final StreamsMetricsImpl streamsMetrics = mock(StreamsMetricsImpl.class); + private final Sensor expectedSensor = mock(Sensor.class); + private final Map tagMap = Collections.singletonMap("hello", "world"); + + + @Test + public void shouldGetActiveProcessRatioSensor() { + final String operation = "active-process-ratio"; + when(streamsMetrics.taskLevelSensor(THREAD_ID, TASK_ID, operation, RecordingLevel.INFO)) + .thenReturn(expectedSensor); + + final String ratioDescription = "The fraction of time the thread spent " + + "on processing this task among all assigned active tasks"; + when(streamsMetrics.taskLevelTagMap(THREAD_ID, TASK_ID)).thenReturn(tagMap); + StreamsMetricsImpl.addValueMetricToSensor( + expectedSensor, + TASK_LEVEL_GROUP, + tagMap, + operation, + ratioDescription + ); + + + final Sensor sensor = TaskMetrics.activeProcessRatioSensor(THREAD_ID, TASK_ID, streamsMetrics); + + assertThat(sensor, is(expectedSensor)); + } + + @Test + public void shouldGetActiveBufferCountSensor() { + final String operation = "active-buffer-count"; + when(streamsMetrics.taskLevelSensor(THREAD_ID, TASK_ID, operation, RecordingLevel.DEBUG)) + .thenReturn(expectedSensor); + final String countDescription = "The count of buffered records that are polled " + + "from consumer and not yet processed for this active task"; + when(streamsMetrics.taskLevelTagMap(THREAD_ID, TASK_ID)).thenReturn(tagMap); + StreamsMetricsImpl.addValueMetricToSensor( + expectedSensor, + TASK_LEVEL_GROUP, + tagMap, + operation, + countDescription + ); + + + final Sensor sensor = 
TaskMetrics.activeBufferedRecordsSensor(THREAD_ID, TASK_ID, streamsMetrics); + + assertThat(sensor, is(expectedSensor)); + } + + @Test + public void shouldGetProcessLatencySensor() { + final String operation = "process-latency"; + when(streamsMetrics.taskLevelSensor(THREAD_ID, TASK_ID, operation, RecordingLevel.DEBUG)) + .thenReturn(expectedSensor); + final String avgLatencyDescription = "The average latency of calls to process"; + final String maxLatencyDescription = "The maximum latency of calls to process"; + when(streamsMetrics.taskLevelTagMap(THREAD_ID, TASK_ID)).thenReturn(tagMap); + StreamsMetricsImpl.addAvgAndMaxToSensor( + expectedSensor, + TASK_LEVEL_GROUP, + tagMap, + operation, + avgLatencyDescription, + maxLatencyDescription + ); + + final Sensor sensor = TaskMetrics.processLatencySensor(THREAD_ID, TASK_ID, streamsMetrics); + + assertThat(sensor, is(expectedSensor)); + } + + @Test + public void shouldGetPunctuateSensor() { + final String operation = "punctuate"; + when(streamsMetrics.taskLevelSensor(THREAD_ID, TASK_ID, operation, RecordingLevel.DEBUG)) + .thenReturn(expectedSensor); + final String operationLatency = operation + StreamsMetricsImpl.LATENCY_SUFFIX; + final String totalDescription = "The total number of calls to punctuate"; + final String rateDescription = "The average number of calls to punctuate per second"; + final String avgLatencyDescription = "The average latency of calls to punctuate"; + final String maxLatencyDescription = "The maximum latency of calls to punctuate"; + when(streamsMetrics.taskLevelTagMap(THREAD_ID, TASK_ID)).thenReturn(tagMap); + StreamsMetricsImpl.addInvocationRateAndCountToSensor( + expectedSensor, + TASK_LEVEL_GROUP, + tagMap, + operation, + rateDescription, + totalDescription + ); + StreamsMetricsImpl.addAvgAndMaxToSensor( + expectedSensor, + TASK_LEVEL_GROUP, + tagMap, + operationLatency, + avgLatencyDescription, + maxLatencyDescription + ); + + final Sensor sensor = TaskMetrics.punctuateSensor(THREAD_ID, TASK_ID, streamsMetrics); + + assertThat(sensor, is(expectedSensor)); + } + + @Test + public void shouldGetCommitSensor() { + final String operation = "commit"; + final String totalDescription = "The total number of calls to commit"; + final String rateDescription = "The average number of calls to commit per second"; + when(streamsMetrics.taskLevelSensor(THREAD_ID, TASK_ID, operation, RecordingLevel.DEBUG)).thenReturn(expectedSensor); + when(streamsMetrics.taskLevelTagMap(THREAD_ID, TASK_ID)).thenReturn(tagMap); + StreamsMetricsImpl.addInvocationRateAndCountToSensor( + expectedSensor, + TASK_LEVEL_GROUP, + tagMap, + operation, + rateDescription, + totalDescription + ); + + final Sensor sensor = TaskMetrics.commitSensor(THREAD_ID, TASK_ID, streamsMetrics); + + assertThat(sensor, is(expectedSensor)); + } + + @Test + public void shouldGetEnforcedProcessingSensor() { + final String operation = "enforced-processing"; + final String totalDescription = "The total number of occurrences of enforced-processing operations"; + final String rateDescription = "The average number of occurrences of enforced-processing operations per second"; + when(streamsMetrics.taskLevelSensor(THREAD_ID, TASK_ID, operation, RecordingLevel.DEBUG)).thenReturn(expectedSensor); + when(streamsMetrics.taskLevelTagMap(THREAD_ID, TASK_ID)).thenReturn(tagMap); + StreamsMetricsImpl.addInvocationRateAndCountToSensor( + expectedSensor, + TASK_LEVEL_GROUP, + tagMap, + operation, + rateDescription, + totalDescription + ); + + final Sensor sensor = 
TaskMetrics.enforcedProcessingSensor(THREAD_ID, TASK_ID, streamsMetrics); + + assertThat(sensor, is(expectedSensor)); + } + + @Test + public void shouldGetRecordLatenessSensor() { + final String operation = "record-lateness"; + final String avgDescription = + "The observed average lateness of records in milliseconds, measured by comparing the record timestamp with " + + "the current stream time"; + final String maxDescription = + "The observed maximum lateness of records in milliseconds, measured by comparing the record timestamp with " + + "the current stream time"; + when(streamsMetrics.taskLevelSensor(THREAD_ID, TASK_ID, operation, RecordingLevel.DEBUG)).thenReturn(expectedSensor); + when(streamsMetrics.taskLevelTagMap(THREAD_ID, TASK_ID)).thenReturn(tagMap); + StreamsMetricsImpl.addAvgAndMaxToSensor( + expectedSensor, + TASK_LEVEL_GROUP, + tagMap, + operation, + avgDescription, + maxDescription + ); + + final Sensor sensor = TaskMetrics.recordLatenessSensor(THREAD_ID, TASK_ID, streamsMetrics); + + assertThat(sensor, is(expectedSensor)); + } + + @Test + public void shouldGetDroppedRecordsSensor() { + final String operation = "dropped-records"; + final String totalDescription = "The total number of dropped records"; + final String rateDescription = "The average number of dropped records per second"; + when(streamsMetrics.taskLevelSensor(THREAD_ID, TASK_ID, operation, RecordingLevel.INFO)).thenReturn(expectedSensor); + when(streamsMetrics.taskLevelTagMap(THREAD_ID, TASK_ID)).thenReturn(tagMap); + StreamsMetricsImpl.addInvocationRateAndCountToSensor( + expectedSensor, + TASK_LEVEL_GROUP, + tagMap, + operation, + rateDescription, + totalDescription + ); + + final Sensor sensor = TaskMetrics.droppedRecordsSensor(THREAD_ID, TASK_ID, streamsMetrics); + + assertThat(sensor, is(expectedSensor)); + } +} \ No newline at end of file diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/metrics/ThreadMetricsTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/metrics/ThreadMetricsTest.java new file mode 100644 index 0000000..6ed97eb --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/metrics/ThreadMetricsTest.java @@ -0,0 +1,447 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
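As context for the thread-level metric assertions earlier in this patch, the sketch below (not part of the diff) shows roughly how a gauge is registered and read back through the public org.apache.kafka.common.metrics.Metrics API; the class name, group, tags, and values are illustrative only.

```java
import org.apache.kafka.common.MetricName;
import org.apache.kafka.common.metrics.Gauge;
import org.apache.kafka.common.metrics.Metrics;

import java.util.Collections;
import java.util.Map;

public class GaugeRegistrationSketch {
    public static void main(final String[] args) {
        try (Metrics metrics = new Metrics()) {
            final Map<String, String> tags = Collections.singletonMap("thread-id", "t1");
            final MetricName name =
                metrics.metricName("foobar", "stream-thread-metrics", "test metric", tags);

            // A "mutable" metric is simply a Gauge whose lambda is re-evaluated on every read.
            metrics.addMetric(name, (Gauge<Integer>) (config, now) -> 123);

            // metricValue() invokes the gauge and returns the latest value (123 here).
            System.out.println(metrics.metric(name).metricValue());

            // Removing the metric mirrors what removeAllThreadLevelMetrics(threadId) is
            // expected to do for each thread-level metric.
            metrics.removeMetric(name);
        }
    }
}
```

Because the gauge is re-evaluated on every read, the tests above can assert on the most recently measured value without re-registering anything.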
+ */ +package org.apache.kafka.streams.processor.internals.metrics; + +import org.apache.kafka.common.metrics.Gauge; +import org.apache.kafka.common.metrics.Sensor; +import org.apache.kafka.common.metrics.Sensor.RecordingLevel; +import org.apache.kafka.streams.processor.internals.StreamThreadTotalBlockedTime; +import org.junit.Test; + +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.util.Collections; +import java.util.Map; +import org.mockito.ArgumentCaptor; + +import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.LATENCY_SUFFIX; +import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.RATE_SUFFIX; +import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.ROLLUP_VALUE; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.MatcherAssert.assertThat; + +public class ThreadMetricsTest { + + private static final String THREAD_ID = "thread-id"; + private static final String THREAD_LEVEL_GROUP = "stream-thread-metrics"; + private static final String TASK_LEVEL_GROUP = "stream-task-metrics"; + + private final Sensor expectedSensor = mock(Sensor.class); + private final StreamsMetricsImpl streamsMetrics = mock(StreamsMetricsImpl.class); + private final Map tagMap = Collections.singletonMap("hello", "world"); + + + @Test + public void shouldGetProcessRatioSensor() { + final String operation = "process-ratio"; + final String ratioDescription = "The fraction of time the thread spent on processing active tasks"; + when(streamsMetrics.threadLevelSensor(THREAD_ID, operation, RecordingLevel.INFO)).thenReturn(expectedSensor); + when(streamsMetrics.threadLevelTagMap(THREAD_ID)).thenReturn(tagMap); + StreamsMetricsImpl.addValueMetricToSensor( + expectedSensor, + THREAD_LEVEL_GROUP, + tagMap, + operation, + ratioDescription + ); + + final Sensor sensor = ThreadMetrics.processRatioSensor(THREAD_ID, streamsMetrics); + + assertThat(sensor, is(expectedSensor)); + } + + @Test + public void shouldGetProcessRecordsSensor() { + final String operation = "process-records"; + final String avgDescription = "The average number of records processed within an iteration"; + final String maxDescription = "The maximum number of records processed within an iteration"; + when(streamsMetrics.threadLevelSensor(THREAD_ID, operation, RecordingLevel.INFO)).thenReturn(expectedSensor); + when(streamsMetrics.threadLevelTagMap(THREAD_ID)).thenReturn(tagMap); + StreamsMetricsImpl.addAvgAndMaxToSensor( + expectedSensor, + THREAD_LEVEL_GROUP, + tagMap, + operation, + avgDescription, + maxDescription + ); + + final Sensor sensor = ThreadMetrics.processRecordsSensor(THREAD_ID, streamsMetrics); + + assertThat(sensor, is(expectedSensor)); + } + + @Test + public void shouldGetProcessLatencySensor() { + final String operationLatency = "process" + LATENCY_SUFFIX; + final String avgLatencyDescription = "The average process latency"; + final String maxLatencyDescription = "The maximum process latency"; + when(streamsMetrics.threadLevelSensor(THREAD_ID, operationLatency, RecordingLevel.INFO)).thenReturn(expectedSensor); + when(streamsMetrics.threadLevelTagMap(THREAD_ID)).thenReturn(tagMap); + StreamsMetricsImpl.addAvgAndMaxToSensor( + expectedSensor, + THREAD_LEVEL_GROUP, + tagMap, + operationLatency, + avgLatencyDescription, + maxLatencyDescription + ); + + final Sensor sensor = 
ThreadMetrics.processLatencySensor(THREAD_ID, streamsMetrics); + + assertThat(sensor, is(expectedSensor)); + } + + @Test + public void shouldGetProcessRateSensor() { + final String operation = "process"; + final String operationRate = "process" + RATE_SUFFIX; + final String totalDescription = "The total number of calls to process"; + final String rateDescription = "The average per-second number of calls to process"; + when(streamsMetrics.threadLevelSensor(THREAD_ID, operationRate, RecordingLevel.INFO)).thenReturn(expectedSensor); + when(streamsMetrics.threadLevelTagMap(THREAD_ID)).thenReturn(tagMap); + StreamsMetricsImpl.addRateOfSumAndSumMetricsToSensor( + expectedSensor, + THREAD_LEVEL_GROUP, + tagMap, + operation, + rateDescription, + totalDescription + ); + + final Sensor sensor = ThreadMetrics.processRateSensor(THREAD_ID, streamsMetrics); + + assertThat(sensor, is(expectedSensor)); + } + + @Test + public void shouldGetPollRatioSensor() { + final String operation = "poll-ratio"; + final String ratioDescription = "The fraction of time the thread spent on polling records from consumer"; + when(streamsMetrics.threadLevelSensor(THREAD_ID, operation, RecordingLevel.INFO)).thenReturn(expectedSensor); + when(streamsMetrics.threadLevelTagMap(THREAD_ID)).thenReturn(tagMap); + StreamsMetricsImpl.addValueMetricToSensor( + expectedSensor, + THREAD_LEVEL_GROUP, + tagMap, + operation, + ratioDescription + ); + + final Sensor sensor = ThreadMetrics.pollRatioSensor(THREAD_ID, streamsMetrics); + + assertThat(sensor, is(expectedSensor)); + } + + @Test + public void shouldGetPollRecordsSensor() { + final String operation = "poll-records"; + final String avgDescription = "The average number of records polled from consumer within an iteration"; + final String maxDescription = "The maximum number of records polled from consumer within an iteration"; + when(streamsMetrics.threadLevelSensor(THREAD_ID, operation, RecordingLevel.INFO)).thenReturn(expectedSensor); + when(streamsMetrics.threadLevelTagMap(THREAD_ID)).thenReturn(tagMap); + StreamsMetricsImpl.addAvgAndMaxToSensor( + expectedSensor, + THREAD_LEVEL_GROUP, + tagMap, + operation, + avgDescription, + maxDescription + ); + + final Sensor sensor = ThreadMetrics.pollRecordsSensor(THREAD_ID, streamsMetrics); + + assertThat(sensor, is(expectedSensor)); + } + + @Test + public void shouldGetPollSensor() { + final String operation = "poll"; + final String operationLatency = operation + StreamsMetricsImpl.LATENCY_SUFFIX; + final String totalDescription = "The total number of calls to poll"; + final String rateDescription = "The average per-second number of calls to poll"; + final String avgLatencyDescription = "The average poll latency"; + final String maxLatencyDescription = "The maximum poll latency"; + when(streamsMetrics.threadLevelSensor(THREAD_ID, operation, RecordingLevel.INFO)).thenReturn(expectedSensor); + when(streamsMetrics.threadLevelTagMap(THREAD_ID)).thenReturn(tagMap); + StreamsMetricsImpl.addInvocationRateAndCountToSensor( + expectedSensor, + THREAD_LEVEL_GROUP, + tagMap, + operation, + rateDescription, + totalDescription + ); + StreamsMetricsImpl.addAvgAndMaxToSensor( + expectedSensor, + THREAD_LEVEL_GROUP, + tagMap, + operationLatency, + avgLatencyDescription, + maxLatencyDescription + ); + + final Sensor sensor = ThreadMetrics.pollSensor(THREAD_ID, streamsMetrics); + + assertThat(sensor, is(expectedSensor)); + } + + @Test + public void shouldGetCommitSensor() { + final String operation = "commit"; + final String operationLatency = operation + 
StreamsMetricsImpl.LATENCY_SUFFIX; + final String totalDescription = "The total number of calls to commit"; + final String rateDescription = "The average per-second number of calls to commit"; + final String avgLatencyDescription = "The average commit latency"; + final String maxLatencyDescription = "The maximum commit latency"; + when(streamsMetrics.threadLevelSensor(THREAD_ID, operation, RecordingLevel.INFO)).thenReturn(expectedSensor); + when(streamsMetrics.threadLevelTagMap(THREAD_ID)).thenReturn(tagMap); + StreamsMetricsImpl.addInvocationRateAndCountToSensor( + expectedSensor, + THREAD_LEVEL_GROUP, + tagMap, + operation, + rateDescription, + totalDescription + ); + StreamsMetricsImpl.addAvgAndMaxToSensor( + expectedSensor, + THREAD_LEVEL_GROUP, + tagMap, + operationLatency, + avgLatencyDescription, + maxLatencyDescription); + + final Sensor sensor = ThreadMetrics.commitSensor(THREAD_ID, streamsMetrics); + + assertThat(sensor, is(expectedSensor)); + } + + @Test + public void shouldGetCommitRatioSensor() { + final String operation = "commit-ratio"; + final String ratioDescription = "The fraction of time the thread spent on committing all tasks"; + when(streamsMetrics.threadLevelSensor(THREAD_ID, operation, RecordingLevel.INFO)).thenReturn(expectedSensor); + when(streamsMetrics.threadLevelTagMap(THREAD_ID)).thenReturn(tagMap); + StreamsMetricsImpl.addValueMetricToSensor( + expectedSensor, + THREAD_LEVEL_GROUP, + tagMap, + operation, + ratioDescription + ); + + final Sensor sensor = ThreadMetrics.commitRatioSensor(THREAD_ID, streamsMetrics); + + assertThat(sensor, is(expectedSensor)); + } + + @Test + public void shouldGetCommitOverTasksSensor() { + final String operation = "commit"; + final String totalDescription = + "The total number of calls to commit over all tasks assigned to one stream thread"; + final String rateDescription = + "The average per-second number of calls to commit over all tasks assigned to one stream thread"; + when(streamsMetrics.threadLevelSensor(THREAD_ID, operation, RecordingLevel.DEBUG)).thenReturn(expectedSensor); + when(streamsMetrics.taskLevelTagMap(THREAD_ID, ROLLUP_VALUE)).thenReturn(tagMap); + StreamsMetricsImpl.addInvocationRateAndCountToSensor( + expectedSensor, + TASK_LEVEL_GROUP, + tagMap, + operation, + rateDescription, + totalDescription + ); + + final Sensor sensor = ThreadMetrics.commitOverTasksSensor(THREAD_ID, streamsMetrics); + + assertThat(sensor, is(expectedSensor)); + } + + @Test + public void shouldGetPunctuateSensor() { + final String operation = "punctuate"; + final String operationLatency = operation + StreamsMetricsImpl.LATENCY_SUFFIX; + final String totalDescription = "The total number of calls to punctuate"; + final String rateDescription = "The average per-second number of calls to punctuate"; + final String avgLatencyDescription = "The average punctuate latency"; + final String maxLatencyDescription = "The maximum punctuate latency"; + when(streamsMetrics.threadLevelSensor(THREAD_ID, operation, RecordingLevel.INFO)).thenReturn(expectedSensor); + when(streamsMetrics.threadLevelTagMap(THREAD_ID)).thenReturn(tagMap); + StreamsMetricsImpl.addInvocationRateAndCountToSensor( + expectedSensor, + THREAD_LEVEL_GROUP, + tagMap, + operation, + rateDescription, + totalDescription + ); + StreamsMetricsImpl.addAvgAndMaxToSensor( + expectedSensor, + THREAD_LEVEL_GROUP, + tagMap, + operationLatency, + avgLatencyDescription, + maxLatencyDescription + ); + + final Sensor sensor = ThreadMetrics.punctuateSensor(THREAD_ID, streamsMetrics); + + 
assertThat(sensor, is(expectedSensor)); + } + + @Test + public void shouldGetPunctuateRatioSensor() { + final String operation = "punctuate-ratio"; + final String ratioDescription = "The fraction of time the thread spent on punctuating active tasks"; + when(streamsMetrics.threadLevelSensor(THREAD_ID, operation, RecordingLevel.INFO)).thenReturn(expectedSensor); + when(streamsMetrics.threadLevelTagMap(THREAD_ID)).thenReturn(tagMap); + StreamsMetricsImpl.addValueMetricToSensor( + expectedSensor, + THREAD_LEVEL_GROUP, + tagMap, + operation, + ratioDescription + ); + + final Sensor sensor = ThreadMetrics.punctuateRatioSensor(THREAD_ID, streamsMetrics); + + assertThat(sensor, is(expectedSensor)); + } + + @Test + public void shouldGetSkipRecordSensor() { + final String operation = "skipped-records"; + final String totalDescription = "The total number of skipped records"; + final String rateDescription = "The average per-second number of skipped records"; + when(streamsMetrics.threadLevelSensor(THREAD_ID, operation, RecordingLevel.INFO)) + .thenReturn(expectedSensor); + when(streamsMetrics.threadLevelTagMap(THREAD_ID)).thenReturn(tagMap); + StreamsMetricsImpl.addInvocationRateAndCountToSensor( + expectedSensor, + THREAD_LEVEL_GROUP, + tagMap, + operation, + rateDescription, + totalDescription + ); + + final Sensor sensor = ThreadMetrics.skipRecordSensor(THREAD_ID, streamsMetrics); + + assertThat(sensor, is(expectedSensor)); + } + + @Test + public void shouldGetCreateTaskSensor() { + final String operation = "task-created"; + final String totalDescription = "The total number of newly created tasks"; + final String rateDescription = "The average per-second number of newly created tasks"; + when(streamsMetrics.threadLevelSensor(THREAD_ID, operation, RecordingLevel.INFO)).thenReturn(expectedSensor); + when(streamsMetrics.threadLevelTagMap(THREAD_ID)).thenReturn(tagMap); + + StreamsMetricsImpl.addInvocationRateAndCountToSensor( + expectedSensor, + THREAD_LEVEL_GROUP, + tagMap, + operation, + rateDescription, + totalDescription + ); + + + final Sensor sensor = ThreadMetrics.createTaskSensor(THREAD_ID, streamsMetrics); + + assertThat(sensor, is(expectedSensor)); + } + + @Test + public void shouldGetCloseTaskSensor() { + final String operation = "task-closed"; + final String totalDescription = "The total number of closed tasks"; + final String rateDescription = "The average per-second number of closed tasks"; + when(streamsMetrics.threadLevelSensor(THREAD_ID, operation, RecordingLevel.INFO)).thenReturn(expectedSensor); + when(streamsMetrics.threadLevelTagMap(THREAD_ID)).thenReturn(tagMap); + StreamsMetricsImpl.addInvocationRateAndCountToSensor( + expectedSensor, + THREAD_LEVEL_GROUP, + tagMap, + operation, + rateDescription, + totalDescription + ); + + + final Sensor sensor = ThreadMetrics.closeTaskSensor(THREAD_ID, streamsMetrics); + + assertThat(sensor, is(expectedSensor)); + } + + @Test + public void shouldAddThreadStartTimeMetric() { + // Given: + final long startTime = 123L; + + // When: + ThreadMetrics.addThreadStartTimeMetric( + "bongo", + streamsMetrics, + startTime + ); + + // Then: + verify(streamsMetrics).addThreadLevelImmutableMetric( + "thread-start-time", + "The time that the thread was started", + "bongo", + startTime + ); + } + + @Test + public void shouldAddTotalBlockedTimeMetric() { + // Given: + final double startTime = 123.45; + final StreamThreadTotalBlockedTime blockedTime = mock(StreamThreadTotalBlockedTime.class); + when(blockedTime.compute()).thenReturn(startTime); + + // When: + 
ThreadMetrics.addThreadBlockedTimeMetric( + "burger", + blockedTime, + streamsMetrics + ); + + // Then: + final ArgumentCaptor> captor = gaugeCaptor(); + verify(streamsMetrics).addThreadLevelMutableMetric( + eq("blocked-time-ns-total"), + eq("The total time the thread spent blocked on kafka in nanoseconds"), + eq("burger"), + captor.capture() + ); + assertThat(captor.getValue().value(null, 678L), is(startTime)); + } + + @SuppressWarnings("unchecked") + private ArgumentCaptor> gaugeCaptor() { + return ArgumentCaptor.forClass(Gauge.class); + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/testutil/ConsumerRecordUtil.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/testutil/ConsumerRecordUtil.java new file mode 100644 index 0000000..4a6fbbd --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/testutil/ConsumerRecordUtil.java @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.processor.internals.testutil; + +import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.common.header.internals.RecordHeaders; +import org.apache.kafka.common.record.TimestampType; + +import java.util.Optional; + +public final class ConsumerRecordUtil { + private ConsumerRecordUtil() {} + + public static ConsumerRecord record(final String topic, + final int partition, + final long offset, + final K key, + final V value) { + // the no-time constructor in ConsumerRecord initializes the + // timestamp to -1, which is an invalid configuration. Here, + // we initialize it to 0. + return new ConsumerRecord<>( + topic, + partition, + offset, + 0L, + TimestampType.CREATE_TIME, + 0, + 0, + key, + value, + new RecordHeaders(), + Optional.empty() + ); + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/testutil/DummyStreamsConfig.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/testutil/DummyStreamsConfig.java new file mode 100644 index 0000000..0174619 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/testutil/DummyStreamsConfig.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.processor.internals.testutil; + +import org.apache.kafka.streams.StreamsConfig; + +import java.util.Properties; + +public class DummyStreamsConfig extends StreamsConfig { + + private final static Properties PROPS = dummyProps(); + + public DummyStreamsConfig() { + super(PROPS); + } + + private static Properties dummyProps() { + final Properties props = new Properties(); + props.put(StreamsConfig.APPLICATION_ID_CONFIG, "dummy-application"); + props.put(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:2171"); + return props; + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/testutil/LogCaptureAppender.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/testutil/LogCaptureAppender.java new file mode 100644 index 0000000..41d15da --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/testutil/LogCaptureAppender.java @@ -0,0 +1,130 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
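The ConsumerRecordUtil helper above exists because the short ConsumerRecord constructor leaves the timestamp at -1, which Streams treats as invalid. A rough usage sketch follows; the topic name, offset, and key/value strings are chosen purely for illustration.

```java
import org.apache.kafka.clients.consumer.ConsumerRecord;
import org.apache.kafka.streams.processor.internals.testutil.ConsumerRecordUtil;

public class ConsumerRecordUtilSketch {
    public static void main(final String[] args) {
        // The helper pins the timestamp to 0 with CREATE_TIME semantics instead of -1.
        final ConsumerRecord<String, String> record =
            ConsumerRecordUtil.record("input-topic", 0, 5L, "key", "value");

        System.out.println(record.topic() + "-" + record.partition() + "@" + record.offset()
            + ": " + record.key() + "=" + record.value() + " ts=" + record.timestamp());
    }
}
```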
+ */ +package org.apache.kafka.streams.processor.internals.testutil; + +import org.apache.log4j.AppenderSkeleton; +import org.apache.log4j.Level; +import org.apache.log4j.Logger; +import org.apache.log4j.spi.LoggingEvent; + +import java.util.LinkedList; +import java.util.List; +import java.util.Optional; + +public class LogCaptureAppender extends AppenderSkeleton implements AutoCloseable { + private final List events = new LinkedList<>(); + + @SuppressWarnings("OptionalUsedAsFieldOrParameterType") + public static class Event { + private final String level; + private final String message; + private final Optional throwableInfo; + + Event(final String level, final String message, final Optional throwableInfo) { + this.level = level; + this.message = message; + this.throwableInfo = throwableInfo; + } + + public String getLevel() { + return level; + } + + public String getMessage() { + return message; + } + + public Optional getThrowableInfo() { + return throwableInfo; + } + } + + public static LogCaptureAppender createAndRegister() { + final LogCaptureAppender logCaptureAppender = new LogCaptureAppender(); + Logger.getRootLogger().addAppender(logCaptureAppender); + return logCaptureAppender; + } + + public static LogCaptureAppender createAndRegister(final Class clazz) { + final LogCaptureAppender logCaptureAppender = new LogCaptureAppender(); + Logger.getLogger(clazz).addAppender(logCaptureAppender); + return logCaptureAppender; + } + + public static void setClassLoggerToDebug(final Class clazz) { + Logger.getLogger(clazz).setLevel(Level.DEBUG); + } + + public static void setClassLoggerToTrace(final Class clazz) { + Logger.getLogger(clazz).setLevel(Level.TRACE); + } + + public static void unregister(final LogCaptureAppender logCaptureAppender) { + Logger.getRootLogger().removeAppender(logCaptureAppender); + } + + @Override + protected void append(final LoggingEvent event) { + synchronized (events) { + events.add(event); + } + } + + public List getMessages() { + final LinkedList result = new LinkedList<>(); + synchronized (events) { + for (final LoggingEvent event : events) { + result.add(event.getRenderedMessage()); + } + } + return result; + } + + public List getEvents() { + final LinkedList result = new LinkedList<>(); + synchronized (events) { + for (final LoggingEvent event : events) { + final String[] throwableStrRep = event.getThrowableStrRep(); + final Optional throwableString; + if (throwableStrRep == null) { + throwableString = Optional.empty(); + } else { + final StringBuilder throwableStringBuilder = new StringBuilder(); + + for (final String s : throwableStrRep) { + throwableStringBuilder.append(s); + } + + throwableString = Optional.of(throwableStringBuilder.toString()); + } + + result.add(new Event(event.getLevel().toString(), event.getRenderedMessage(), throwableString)); + } + } + return result; + } + + @Override + public void close() { + unregister(this); + } + + @Override + public boolean requiresLayout() { + return false; + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/state/HostInfoTest.java b/streams/src/test/java/org/apache/kafka/streams/state/HostInfoTest.java new file mode 100644 index 0000000..0bee8e6 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/state/HostInfoTest.java @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.state; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.is; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThrows; + +import org.apache.kafka.common.config.ConfigException; +import org.junit.Test; + +public class HostInfoTest { + + @Test + public void shouldCreateHostInfo() { + final String endPoint = "host:9090"; + final HostInfo hostInfo = HostInfo.buildFromEndpoint(endPoint); + + assertThat(hostInfo.host(), is("host")); + assertThat(hostInfo.port(), is(9090)); + } + + @Test + public void shouldReturnNullHostInfoForNullEndPoint() { + assertNull(HostInfo.buildFromEndpoint(null)); + } + + @Test + public void shouldReturnNullHostInfoForEmptyEndPoint() { + assertNull(HostInfo.buildFromEndpoint(" ")); + } + + @Test + public void shouldThrowConfigExceptionForNonsenseEndPoint() { + assertThrows(ConfigException.class, () -> HostInfo.buildFromEndpoint("nonsense")); + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/state/KeyValueStoreTestDriver.java b/streams/src/test/java/org/apache/kafka/streams/state/KeyValueStoreTestDriver.java new file mode 100644 index 0000000..6a95ccb --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/state/KeyValueStoreTestDriver.java @@ -0,0 +1,424 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
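A hedged usage sketch for the LogCaptureAppender helper above, assuming log4j 1.x is on the classpath (the appender extends AppenderSkeleton, so it already requires that); the logger and log message are hypothetical.

```java
import org.apache.kafka.streams.processor.internals.testutil.LogCaptureAppender;
import org.apache.log4j.Logger;

public class LogCaptureSketch {
    private static final Logger LOG = Logger.getLogger(LogCaptureSketch.class);

    public static void main(final String[] args) {
        // createAndRegister() attaches the appender to the root log4j logger;
        // close() (via try-with-resources) unregisters it again.
        try (final LogCaptureAppender appender = LogCaptureAppender.createAndRegister()) {
            LOG.warn("Skipping record due to null key");

            // getMessages() returns the rendered messages captured while registered.
            System.out.println(appender.getMessages().contains("Skipping record due to null key"));
        }
    }
}
```

Using try-with-resources keeps the appender from leaking onto the root logger between tests.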
+ */ +package org.apache.kafka.streams.state; + +import org.apache.kafka.common.header.Headers; +import org.apache.kafka.common.metrics.Metrics; +import org.apache.kafka.common.serialization.Deserializer; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.serialization.Serializer; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.errors.DefaultProductionExceptionHandler; +import org.apache.kafka.streams.processor.ProcessorContext; +import org.apache.kafka.streams.processor.StateRestoreCallback; +import org.apache.kafka.streams.processor.StateStore; +import org.apache.kafka.streams.processor.StateStoreContext; +import org.apache.kafka.streams.processor.StreamPartitioner; +import org.apache.kafka.streams.processor.TaskId; +import org.apache.kafka.streams.processor.internals.MockStreamsMetrics; +import org.apache.kafka.streams.processor.internals.RecordCollector; +import org.apache.kafka.streams.processor.internals.RecordCollectorImpl; +import org.apache.kafka.streams.processor.internals.StreamsProducer; +import org.apache.kafka.streams.state.internals.MeteredKeyValueStore; +import org.apache.kafka.streams.state.internals.ThreadCache; +import org.apache.kafka.test.InternalMockProcessorContext; +import org.apache.kafka.test.MockClientSupplier; +import org.apache.kafka.test.MockRocksDbConfigSetter; +import org.apache.kafka.test.MockTimestampExtractor; +import org.apache.kafka.test.TestUtils; + +import java.io.File; +import java.util.HashMap; +import java.util.HashSet; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Properties; +import java.util.Set; + +/** + * A component that provides a {@link #context() ProcessingContext} that can be supplied to a {@link KeyValueStore} so that + * all entries written to the Kafka topic by the store during {@link KeyValueStore#flush()} are captured for testing purposes. + * This class simplifies testing of various {@link KeyValueStore} instances, especially those that use + * {@link MeteredKeyValueStore} to monitor and write its entries to the Kafka topic. + * + *
                 <h3>Basic usage</h3> + * This component can be used to help test a {@link KeyValueStore}'s ability to read and write entries. + * + * <pre>
                + * // Create the test driver ...
                + * KeyValueStoreTestDriver<Integer, String> driver = KeyValueStoreTestDriver.create();
                + * KeyValueStore<Integer, String> store = Stores.create("my-store", driver.context())
                + *                                              .withIntegerKeys().withStringKeys()
                + *                                              .inMemory().build();
                + *
                + * // Verify that the store reads and writes correctly ...
                + * store.put(0, "zero");
                + * store.put(1, "one");
                + * store.put(2, "two");
                + * store.put(4, "four");
                + * store.put(5, "five");
                + * assertEquals(5, driver.sizeOf(store));
                + * assertEquals("zero", store.get(0));
                + * assertEquals("one", store.get(1));
                + * assertEquals("two", store.get(2));
                + * assertEquals("four", store.get(4));
                + * assertEquals("five", store.get(5));
                + * assertNull(store.get(3));
                + * store.delete(5);
                + *
                + * // Flush the store and verify all current entries were properly flushed ...
                + * store.flush();
                + * assertEquals("zero", driver.flushedEntryStored(0));
                + * assertEquals("one", driver.flushedEntryStored(1));
                + * assertEquals("two", driver.flushedEntryStored(2));
                + * assertEquals("four", driver.flushedEntryStored(4));
                + * assertNull(driver.flushedEntryStored(5));
                + *
                + * assertEquals(false, driver.flushedEntryRemoved(0));
                + * assertEquals(false, driver.flushedEntryRemoved(1));
                + * assertEquals(false, driver.flushedEntryRemoved(2));
                + * assertEquals(false, driver.flushedEntryRemoved(4));
                + * assertEquals(true, driver.flushedEntryRemoved(5));
                 + * </pre> + * + * + * <h3>Restoring a store</h3>
                + * This component can be used to test whether a {@link KeyValueStore} implementation properly + * {@link ProcessorContext#register(StateStore, StateRestoreCallback) registers itself} with the {@link ProcessorContext}, so that + * the persisted contents of a store are properly restored from the flushed entries when the store instance is started. + *
                 <p>
                + * To do this, create an instance of this driver component, {@link #addEntryToRestoreLog(Object, Object) add entries} that will be + * passed to the store upon creation (simulating the entries that were previously flushed to the topic), and then create the store + * using this driver's {@link #context() ProcessorContext}: + * + *
                 <pre>
                + * // Create the test driver ...
                + * KeyValueStoreTestDriver<Integer, String> driver = KeyValueStoreTestDriver.create(Integer.class, String.class);
                + *
                + * // Add any entries that will be restored to any store that uses the driver's context ...
                + * driver.addRestoreEntry(0, "zero");
                + * driver.addRestoreEntry(1, "one");
                + * driver.addRestoreEntry(2, "two");
                + * driver.addRestoreEntry(4, "four");
                + *
                + * // Create the store, which should register with the context and automatically
                + * // receive the restore entries ...
                + * KeyValueStore<Integer, String> store = Stores.create("my-store", driver.context())
                + *                                              .withIntegerKeys().withStringKeys()
                + *                                              .inMemory().build();
                + *
                + * // Verify that the store's contents were properly restored ...
                + * assertEquals(0, driver.checkForRestoredEntries(store));
                + *
                + * // and there are no other entries ...
                + * assertEquals(4, driver.sizeOf(store));
                + * 
                + * + * @param the type of keys placed in the store + * @param the type of values placed in the store + */ +public class KeyValueStoreTestDriver { + + private final Properties props; + + /** + * Create a driver object that will have a {@link #context()} that records messages + * {@link ProcessorContext#forward(Object, Object) forwarded} by the store and that provides default serializers and + * deserializers for the given built-in key and value types (e.g., {@code String.class}, {@code Integer.class}, + * {@code Long.class}, and {@code byte[].class}). This can be used when store is created to rely upon the + * ProcessorContext's default key and value serializers and deserializers. + * + * @param keyClass the class for the keys; must be one of {@code String.class}, {@code Integer.class}, + * {@code Long.class}, or {@code byte[].class} + * @param valueClass the class for the values; must be one of {@code String.class}, {@code Integer.class}, + * {@code Long.class}, or {@code byte[].class} + * @return the test driver; never null + */ + public static KeyValueStoreTestDriver create(final Class keyClass, final Class valueClass) { + final StateSerdes serdes = StateSerdes.withBuiltinTypes("unexpected", keyClass, valueClass); + return new KeyValueStoreTestDriver<>(serdes); + } + + /** + * Create a driver object that will have a {@link #context()} that records messages + * {@link ProcessorContext#forward(Object, Object) forwarded} by the store and that provides the specified serializers and + * deserializers. This can be used when store is created to rely upon the ProcessorContext's default key and value serializers + * and deserializers. + * + * @param keySerializer the key serializer for the {@link ProcessorContext}; may not be null + * @param keyDeserializer the key deserializer for the {@link ProcessorContext}; may not be null + * @param valueSerializer the value serializer for the {@link ProcessorContext}; may not be null + * @param valueDeserializer the value deserializer for the {@link ProcessorContext}; may not be null + * @return the test driver; never null + */ + public static KeyValueStoreTestDriver create(final Serializer keySerializer, + final Deserializer keyDeserializer, + final Serializer valueSerializer, + final Deserializer valueDeserializer) { + final StateSerdes serdes = new StateSerdes<>( + "unexpected", + Serdes.serdeFrom(keySerializer, keyDeserializer), + Serdes.serdeFrom(valueSerializer, valueDeserializer)); + return new KeyValueStoreTestDriver<>(serdes); + } + + private final Map flushedEntries = new HashMap<>(); + private final Set flushedRemovals = new HashSet<>(); + private final List> restorableEntries = new LinkedList<>(); + + private final InternalMockProcessorContext context; + private final StateSerdes stateSerdes; + + @SuppressWarnings("unchecked") + private KeyValueStoreTestDriver(final StateSerdes serdes) { + props = new Properties(); + props.put(StreamsConfig.APPLICATION_ID_CONFIG, "application-id"); + props.put(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9092"); + props.put(StreamsConfig.DEFAULT_TIMESTAMP_EXTRACTOR_CLASS_CONFIG, MockTimestampExtractor.class); + props.put(StreamsConfig.DEFAULT_KEY_SERDE_CLASS_CONFIG, serdes.keySerde().getClass()); + props.put(StreamsConfig.DEFAULT_VALUE_SERDE_CLASS_CONFIG, serdes.valueSerde().getClass()); + props.put(StreamsConfig.ROCKSDB_CONFIG_SETTER_CLASS_CONFIG, MockRocksDbConfigSetter.class); + props.put(StreamsConfig.METRICS_RECORDING_LEVEL_CONFIG, "DEBUG"); + + final LogContext logContext = new 
LogContext("KeyValueStoreTestDriver "); + final RecordCollector recordCollector = new RecordCollectorImpl( + logContext, + new TaskId(0, 0), + new StreamsProducer( + new StreamsConfig(props), + "threadId", + new MockClientSupplier(), + null, + null, + logContext, + Time.SYSTEM), + new DefaultProductionExceptionHandler(), + new MockStreamsMetrics(new Metrics()) + ) { + @Override + public void send(final String topic, + final K1 key, + final V1 value, + final Headers headers, + final Integer partition, + final Long timestamp, + final Serializer keySerializer, + final Serializer valueSerializer) { + // for byte arrays we need to wrap it for comparison + + final K keyTest = serdes.keyFrom(keySerializer.serialize(topic, headers, key)); + final V valueTest = serdes.valueFrom(valueSerializer.serialize(topic, headers, value)); + + recordFlushed(keyTest, valueTest); + } + + @Override + public void send(final String topic, + final K1 key, + final V1 value, + final Headers headers, + final Long timestamp, + final Serializer keySerializer, + final Serializer valueSerializer, + final StreamPartitioner partitioner) { + throw new UnsupportedOperationException(); + } + }; + + final File stateDir = TestUtils.tempDirectory(); + //noinspection ResultOfMethodCallIgnored + stateDir.mkdirs(); + stateSerdes = serdes; + + context = new InternalMockProcessorContext(stateDir, serdes.keySerde(), serdes.valueSerde(), recordCollector, null) { + final ThreadCache cache = new ThreadCache(new LogContext("testCache "), 1024 * 1024L, metrics()); + + @Override + public ThreadCache cache() { + return cache; + } + + @Override + public Map appConfigs() { + return new StreamsConfig(props).originals(); + } + + @Override + public Map appConfigsWithPrefix(final String prefix) { + return new StreamsConfig(props).originalsWithPrefix(prefix); + } + }; + } + + private void recordFlushed(final K key, final V value) { + if (value == null) { + // This is a removal ... + flushedRemovals.add(key); + flushedEntries.remove(key); + } else { + // This is a normal add + flushedEntries.put(key, value); + flushedRemovals.remove(key); + } + } + + /** + * Get the entries that are restored to a KeyValueStore when it is constructed with this driver's {@link #context() + * ProcessorContext}. + * + * @return the restore entries; never null but possibly a null iterator + */ + public Iterable> restoredEntries() { + return restorableEntries; + } + + /** + * This method adds an entry to the "restore log" for the {@link KeyValueStore}, and is used only when testing the + * restore functionality of a {@link KeyValueStore} implementation. + *
                 <p>
                + * To create such a test, create the test driver, call this method one or more times, and then create the + * {@link KeyValueStore}. Your tests can then check whether the store contains the entries from the log. + * + *
                 <pre>
                +     * // Set up the driver and pre-populate the log ...
                +     * KeyValueStoreTestDriver<Integer, String> driver = KeyValueStoreTestDriver.create();
                +     * driver.addRestoreEntry(1,"value1");
                +     * driver.addRestoreEntry(2,"value2");
                +     * driver.addRestoreEntry(3,"value3");
                +     *
                +     * // Create the store using the driver's context ...
                +     * ProcessorContext context = driver.context();
                +     * KeyValueStore<Integer, String> store = ...
                +     *
                +     * // Verify that the store's contents were properly restored from the log ...
                +     * assertEquals(0, driver.checkForRestoredEntries(store));
                +     *
                +     * // and there are no other entries ...
                +     * assertEquals(3, driver.sizeOf(store));
                 +     * </pre>
                + * + * @param key the key for the entry + * @param value the value for the entry + * @see #checkForRestoredEntries(KeyValueStore) + */ + public void addEntryToRestoreLog(final K key, final V value) { + restorableEntries.add(new KeyValue<>(stateSerdes.rawKey(key), stateSerdes.rawValue(value))); + } + + /** + * Get the context that should be supplied to a {@link KeyValueStore}'s constructor. This context records any messages + * written by the store to the Kafka topic, making them available via the {@link #flushedEntryStored(Object)} and + * {@link #flushedEntryRemoved(Object)} methods. + *
                 <p>
                + * If the {@link KeyValueStore}'s are to be restored upon its startup, be sure to {@link #addEntryToRestoreLog(Object, Object) + * add the restore entries} before creating the store with the {@link ProcessorContext} returned by this method. + * + * @return the processing context; never null + * @see #addEntryToRestoreLog(Object, Object) + */ + public StateStoreContext context() { + return context; + } + + /** + * Utility method that will count the number of {@link #addEntryToRestoreLog(Object, Object) restore entries} missing from the + * supplied store. + * + * @param store the store that is to have all of the {@link #restoredEntries() restore entries} + * @return the number of restore entries missing from the store, or 0 if all restore entries were found + * @see #addEntryToRestoreLog(Object, Object) + */ + public int checkForRestoredEntries(final KeyValueStore store) { + int missing = 0; + for (final KeyValue kv : restorableEntries) { + if (kv != null) { + final V value = store.get(stateSerdes.keyFrom(kv.key)); + if (!Objects.equals(value, stateSerdes.valueFrom(kv.value))) { + ++missing; + } + } + } + return missing; + } + + /** + * Utility method to compute the number of entries within the store. + * + * @param store the key value store using this {@link #context()}. + * @return the number of entries + */ + public int sizeOf(final KeyValueStore store) { + int size = 0; + try (final KeyValueIterator iterator = store.all()) { + while (iterator.hasNext()) { + iterator.next(); + ++size; + } + } + return size; + } + + /** + * Retrieve the value that the store {@link KeyValueStore#flush() flushed} with the given key. + * + * @param key the key + * @return the value that was flushed with the key, or {@code null} if no such key was flushed or if the entry with this + * key was removed upon flush + */ + public V flushedEntryStored(final K key) { + return flushedEntries.get(key); + } + + /** + * Determine whether the store {@link KeyValueStore#flush() flushed} the removal of the given key. + * + * @param key the key + * @return {@code true} if the entry with the given key was removed when flushed, or {@code false} if the entry was not + * removed when last flushed + */ + public boolean flushedEntryRemoved(final K key) { + return flushedRemovals.contains(key); + } + + /** + * Return number of removed entry + */ + public int numFlushedEntryStored() { + return flushedEntries.size(); + } + + /** + * Return number of removed entry + */ + public int numFlushedEntryRemoved() { + return flushedRemovals.size(); + } + + /** + * Remove all {@link #flushedEntryStored(Object) flushed entries}, {@link #flushedEntryRemoved(Object) flushed removals}, + */ + public void clear() { + restorableEntries.clear(); + flushedEntries.clear(); + flushedRemovals.clear(); + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/state/NoOpWindowStore.java b/streams/src/test/java/org/apache/kafka/streams/state/NoOpWindowStore.java new file mode 100644 index 0000000..d64fd09 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/state/NoOpWindowStore.java @@ -0,0 +1,136 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
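The javadoc examples above appear to use an older Stores builder API; the store-free sketch below only exercises the driver's restore-log bookkeeping, with the entries and class name chosen for illustration.

```java
import org.apache.kafka.streams.KeyValue;
import org.apache.kafka.streams.state.KeyValueStoreTestDriver;

public class RestoreLogSketch {
    public static void main(final String[] args) {
        final KeyValueStoreTestDriver<Integer, String> driver =
            KeyValueStoreTestDriver.create(Integer.class, String.class);

        // Entries added here are serialized with the driver's StateSerdes and replayed
        // to any store that registers a restore callback against driver.context().
        driver.addEntryToRestoreLog(1, "one");
        driver.addEntryToRestoreLog(2, "two");

        for (final KeyValue<byte[], byte[]> entry : driver.restoredEntries()) {
            System.out.println(entry.key.length + " key bytes, " + entry.value.length + " value bytes");
        }
    }
}
```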
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.state; + +import java.time.Instant; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.processor.ProcessorContext; +import org.apache.kafka.streams.processor.StateStore; + +import java.util.NoSuchElementException; + +public class NoOpWindowStore implements ReadOnlyWindowStore, StateStore { + + private static class EmptyWindowStoreIterator implements WindowStoreIterator { + + @Override + public void close() { + } + + @Override + public Long peekNextKey() { + throw new NoSuchElementException(); + } + + @Override + public boolean hasNext() { + return false; + } + + @Override + public KeyValue next() { + throw new NoSuchElementException(); + } + } + + private static final WindowStoreIterator EMPTY_WINDOW_STORE_ITERATOR = new EmptyWindowStoreIterator(); + + @Override + public String name() { + return ""; + } + + @Deprecated + @Override + public void init(final ProcessorContext context, final StateStore root) { + + } + + @Override + public void flush() { + + } + + @Override + public void close() { + + } + + @Override + public boolean persistent() { + return false; + } + + @Override + public boolean isOpen() { + return false; + } + + @Override + public Object fetch(final Object key, final long time) { + return null; + } + + @Override + public WindowStoreIterator fetch(final Object key, final Instant timeFrom, final Instant timeTo) throws IllegalArgumentException { + return EMPTY_WINDOW_STORE_ITERATOR; + } + + @Override + public WindowStoreIterator backwardFetch(final Object key, + final Instant timeFrom, + final Instant timeTo) throws IllegalArgumentException { + return EMPTY_WINDOW_STORE_ITERATOR; + } + + @Override + public KeyValueIterator fetch(final Object keyFrom, + final Object keyTo, + final Instant timeFrom, + final Instant timeTo) throws IllegalArgumentException { + return EMPTY_WINDOW_STORE_ITERATOR; + } + + @Override + public KeyValueIterator backwardFetch(final Object from, + final Object keyTo, + final Instant timeFrom, + final Instant timeTo) throws IllegalArgumentException { + return EMPTY_WINDOW_STORE_ITERATOR; + } + + @Override + public WindowStoreIterator all() { + return EMPTY_WINDOW_STORE_ITERATOR; + } + + @Override + public WindowStoreIterator backwardAll() { + return EMPTY_WINDOW_STORE_ITERATOR; + } + + @Override + public KeyValueIterator fetchAll(final Instant timeFrom, final Instant timeTo) throws IllegalArgumentException { + return EMPTY_WINDOW_STORE_ITERATOR; + } + + @Override + public KeyValueIterator backwardFetchAll(final Instant timeFrom, + final Instant timeTo) throws IllegalArgumentException { + return EMPTY_WINDOW_STORE_ITERATOR; + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/state/StateSerdesTest.java b/streams/src/test/java/org/apache/kafka/streams/state/StateSerdesTest.java new file mode 100644 index 0000000..3f3c62d --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/state/StateSerdesTest.java @@ -0,0 +1,139 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
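NoOpWindowStore above gives read-side callers an always-empty view instead of nulls; a small sketch of what that looks like to a caller (the key and time range are arbitrary).

```java
import org.apache.kafka.streams.state.NoOpWindowStore;
import org.apache.kafka.streams.state.WindowStoreIterator;

import java.time.Instant;

public class NoOpWindowStoreSketch {
    public static void main(final String[] args) {
        final NoOpWindowStore store = new NoOpWindowStore();

        // Every fetch variant returns the shared empty iterator, so hasNext() is always false
        // and peekNextKey()/next() would throw NoSuchElementException.
        try (final WindowStoreIterator iterator =
                 store.fetch("any-key", Instant.EPOCH, Instant.now())) {
            System.out.println("has data: " + iterator.hasNext());
        }
    }
}
```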
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.state; + +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.streams.errors.StreamsException; +import org.apache.kafka.streams.state.internals.ValueAndTimestampSerde; +import org.junit.Assert; +import org.junit.Test; + +import java.nio.ByteBuffer; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.equalTo; +import static org.junit.Assert.assertThrows; + +@SuppressWarnings("unchecked") +public class StateSerdesTest { + + @Test + public void shouldThrowIfTopicNameIsNullForBuiltinTypes() { + assertThrows(NullPointerException.class, () -> StateSerdes.withBuiltinTypes(null, byte[].class, byte[].class)); + } + + @Test + public void shouldThrowIfKeyClassIsNullForBuiltinTypes() { + assertThrows(NullPointerException.class, () -> StateSerdes.withBuiltinTypes("anyName", null, byte[].class)); + } + + @Test + public void shouldThrowIfValueClassIsNullForBuiltinTypes() { + assertThrows(NullPointerException.class, () -> StateSerdes.withBuiltinTypes("anyName", byte[].class, null)); + } + + @Test + public void shouldReturnSerdesForBuiltInKeyAndValueTypesForBuiltinTypes() { + final Class[] supportedBuildInTypes = new Class[] { + String.class, + Short.class, + Integer.class, + Long.class, + Float.class, + Double.class, + byte[].class, + ByteBuffer.class, + Bytes.class + }; + + for (final Class keyClass : supportedBuildInTypes) { + for (final Class valueClass : supportedBuildInTypes) { + Assert.assertNotNull(StateSerdes.withBuiltinTypes("anyName", keyClass, valueClass)); + } + } + } + + @Test + public void shouldThrowForUnknownKeyTypeForBuiltinTypes() { + assertThrows(IllegalArgumentException.class, () -> StateSerdes.withBuiltinTypes("anyName", Class.class, byte[].class)); + } + + @Test + public void shouldThrowForUnknownValueTypeForBuiltinTypes() { + assertThrows(IllegalArgumentException.class, () -> StateSerdes.withBuiltinTypes("anyName", byte[].class, Class.class)); + } + + @Test + public void shouldThrowIfTopicNameIsNull() { + assertThrows(NullPointerException.class, () -> new StateSerdes<>(null, Serdes.ByteArray(), Serdes.ByteArray())); + } + + @Test + public void shouldThrowIfKeyClassIsNull() { + assertThrows(NullPointerException.class, () -> new StateSerdes<>("anyName", null, Serdes.ByteArray())); + } + + @Test + public void shouldThrowIfValueClassIsNull() { + assertThrows(NullPointerException.class, () -> new StateSerdes<>("anyName", Serdes.ByteArray(), null)); + } + + @Test + public void shouldThrowIfIncompatibleSerdeForValue() throws ClassNotFoundException { + final Class myClass = Class.forName("java.lang.String"); + final StateSerdes stateSerdes = new StateSerdes("anyName", Serdes.serdeFrom(myClass), Serdes.serdeFrom(myClass)); + final Integer myInt = 123; + final 
Exception e = assertThrows(StreamsException.class, () -> stateSerdes.rawValue(myInt)); + assertThat( + e.getMessage(), + equalTo( + "A serializer (org.apache.kafka.common.serialization.StringSerializer) " + + "is not compatible to the actual value type (value type: java.lang.Integer). " + + "Change the default Serdes in StreamConfig or provide correct Serdes via method parameters.")); + } + + @Test + public void shouldSkipValueAndTimestampeInformationForErrorOnTimestampAndValueSerialization() throws ClassNotFoundException { + final Class myClass = Class.forName("java.lang.String"); + final StateSerdes stateSerdes = + new StateSerdes("anyName", Serdes.serdeFrom(myClass), new ValueAndTimestampSerde(Serdes.serdeFrom(myClass))); + final Integer myInt = 123; + final Exception e = assertThrows(StreamsException.class, () -> stateSerdes.rawValue(ValueAndTimestamp.make(myInt, 0L))); + assertThat( + e.getMessage(), + equalTo( + "A serializer (org.apache.kafka.common.serialization.StringSerializer) " + + "is not compatible to the actual value type (value type: java.lang.Integer). " + + "Change the default Serdes in StreamConfig or provide correct Serdes via method parameters.")); + } + + @Test + public void shouldThrowIfIncompatibleSerdeForKey() throws ClassNotFoundException { + final Class myClass = Class.forName("java.lang.String"); + final StateSerdes stateSerdes = new StateSerdes("anyName", Serdes.serdeFrom(myClass), Serdes.serdeFrom(myClass)); + final Integer myInt = 123; + final Exception e = assertThrows(StreamsException.class, () -> stateSerdes.rawKey(myInt)); + assertThat( + e.getMessage(), + equalTo( + "A serializer (org.apache.kafka.common.serialization.StringSerializer) " + + "is not compatible to the actual key type (key type: java.lang.Integer). " + + "Change the default Serdes in StreamConfig or provide correct Serdes via method parameters.")); + } + +} diff --git a/streams/src/test/java/org/apache/kafka/streams/state/StoresTest.java b/streams/src/test/java/org/apache/kafka/streams/state/StoresTest.java new file mode 100644 index 0000000..90f019a --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/state/StoresTest.java @@ -0,0 +1,274 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.state; + +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.streams.processor.StateStore; +import org.apache.kafka.streams.state.internals.InMemoryKeyValueStore; +import org.apache.kafka.streams.state.internals.MemoryNavigableLRUCache; +import org.apache.kafka.streams.state.internals.RocksDBSegmentedBytesStore; +import org.apache.kafka.streams.state.internals.RocksDBSessionStore; +import org.apache.kafka.streams.state.internals.RocksDBStore; +import org.apache.kafka.streams.state.internals.RocksDBTimestampedSegmentedBytesStore; +import org.apache.kafka.streams.state.internals.RocksDBTimestampedStore; +import org.apache.kafka.streams.state.internals.RocksDBWindowStore; +import org.apache.kafka.streams.state.internals.WrappedStateStore; +import org.junit.Test; + +import static java.time.Duration.ZERO; +import static java.time.Duration.ofMillis; +import static org.hamcrest.CoreMatchers.allOf; +import static org.hamcrest.CoreMatchers.nullValue; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.core.IsInstanceOf.instanceOf; +import static org.hamcrest.core.IsNot.not; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; + +public class StoresTest { + + @Test + public void shouldThrowIfPersistentKeyValueStoreStoreNameIsNull() { + final Exception e = assertThrows(NullPointerException.class, () -> Stores.persistentKeyValueStore(null)); + assertEquals("name cannot be null", e.getMessage()); + } + + @Test + public void shouldThrowIfPersistentTimestampedKeyValueStoreStoreNameIsNull() { + final Exception e = assertThrows(NullPointerException.class, () -> Stores.persistentTimestampedKeyValueStore(null)); + assertEquals("name cannot be null", e.getMessage()); + } + + @Test + public void shouldThrowIfIMemoryKeyValueStoreStoreNameIsNull() { + final Exception e = assertThrows(NullPointerException.class, () -> Stores.inMemoryKeyValueStore(null)); + assertEquals("name cannot be null", e.getMessage()); + } + + @Test + public void shouldThrowIfILruMapStoreNameIsNull() { + final Exception e = assertThrows(NullPointerException.class, () -> Stores.lruMap(null, 0)); + assertEquals("name cannot be null", e.getMessage()); + } + + @Test + public void shouldThrowIfILruMapStoreCapacityIsNegative() { + final Exception e = assertThrows(IllegalArgumentException.class, () -> Stores.lruMap("anyName", -1)); + assertEquals("maxCacheSize cannot be negative", e.getMessage()); + } + + @Test + public void shouldThrowIfIPersistentWindowStoreStoreNameIsNull() { + final Exception e = assertThrows(NullPointerException.class, () -> Stores.persistentWindowStore(null, ZERO, ZERO, false)); + assertEquals("name cannot be null", e.getMessage()); + } + + @Test + public void shouldThrowIfIPersistentTimestampedWindowStoreStoreNameIsNull() { + final Exception e = assertThrows(NullPointerException.class, () -> Stores.persistentTimestampedWindowStore(null, ZERO, ZERO, false)); + assertEquals("name cannot be null", e.getMessage()); + } + + @Test + public void shouldThrowIfIPersistentWindowStoreRetentionPeriodIsNegative() { + final Exception e = assertThrows(IllegalArgumentException.class, () -> Stores.persistentWindowStore("anyName", ofMillis(-1L), ZERO, false)); + assertEquals("retentionPeriod cannot be negative", e.getMessage()); + } + + @Test + public void shouldThrowIfIPersistentTimestampedWindowStoreRetentionPeriodIsNegative() { + final Exception e = assertThrows(IllegalArgumentException.class, () -> 
Stores.persistentTimestampedWindowStore("anyName", ofMillis(-1L), ZERO, false)); + assertEquals("retentionPeriod cannot be negative", e.getMessage()); + } + + @Test + public void shouldThrowIfIPersistentWindowStoreIfWindowSizeIsNegative() { + final Exception e = assertThrows(IllegalArgumentException.class, () -> Stores.persistentWindowStore("anyName", ofMillis(0L), ofMillis(-1L), false)); + assertEquals("windowSize cannot be negative", e.getMessage()); + } + + @Test + public void shouldThrowIfIPersistentTimestampedWindowStoreIfWindowSizeIsNegative() { + final Exception e = assertThrows(IllegalArgumentException.class, () -> Stores.persistentTimestampedWindowStore("anyName", ofMillis(0L), ofMillis(-1L), false)); + assertEquals("windowSize cannot be negative", e.getMessage()); + } + + @Test + public void shouldThrowIfIPersistentSessionStoreStoreNameIsNull() { + final Exception e = assertThrows(NullPointerException.class, () -> Stores.persistentSessionStore(null, ofMillis(0))); + assertEquals("name cannot be null", e.getMessage()); + } + + @Test + public void shouldThrowIfIPersistentSessionStoreRetentionPeriodIsNegative() { + final Exception e = assertThrows(IllegalArgumentException.class, () -> Stores.persistentSessionStore("anyName", ofMillis(-1))); + assertEquals("retentionPeriod cannot be negative", e.getMessage()); + } + + @Test + public void shouldThrowIfSupplierIsNullForWindowStoreBuilder() { + final Exception e = assertThrows(NullPointerException.class, () -> Stores.windowStoreBuilder(null, Serdes.ByteArray(), Serdes.ByteArray())); + assertEquals("supplier cannot be null", e.getMessage()); + } + + @Test + public void shouldThrowIfSupplierIsNullForKeyValueStoreBuilder() { + final Exception e = assertThrows(NullPointerException.class, () -> Stores.keyValueStoreBuilder(null, Serdes.ByteArray(), Serdes.ByteArray())); + assertEquals("supplier cannot be null", e.getMessage()); + } + + @Test + public void shouldThrowIfSupplierIsNullForSessionStoreBuilder() { + final Exception e = assertThrows(NullPointerException.class, () -> Stores.sessionStoreBuilder(null, Serdes.ByteArray(), Serdes.ByteArray())); + assertEquals("supplier cannot be null", e.getMessage()); + } + + @Test + public void shouldCreateInMemoryKeyValueStore() { + assertThat(Stores.inMemoryKeyValueStore("memory").get(), instanceOf(InMemoryKeyValueStore.class)); + } + + @Test + public void shouldCreateMemoryNavigableCache() { + assertThat(Stores.lruMap("map", 10).get(), instanceOf(MemoryNavigableLRUCache.class)); + } + + @Test + public void shouldCreateRocksDbStore() { + assertThat( + Stores.persistentKeyValueStore("store").get(), + allOf(not(instanceOf(RocksDBTimestampedStore.class)), instanceOf(RocksDBStore.class))); + } + + @Test + public void shouldCreateRocksDbTimestampedStore() { + assertThat(Stores.persistentTimestampedKeyValueStore("store").get(), instanceOf(RocksDBTimestampedStore.class)); + } + + @Test + public void shouldCreateRocksDbWindowStore() { + final WindowStore store = Stores.persistentWindowStore("store", ofMillis(1L), ofMillis(1L), false).get(); + final StateStore wrapped = ((WrappedStateStore) store).wrapped(); + assertThat(store, instanceOf(RocksDBWindowStore.class)); + assertThat(wrapped, allOf(not(instanceOf(RocksDBTimestampedSegmentedBytesStore.class)), instanceOf(RocksDBSegmentedBytesStore.class))); + } + + @Test + public void shouldCreateRocksDbTimestampedWindowStore() { + final WindowStore store = Stores.persistentTimestampedWindowStore("store", ofMillis(1L), ofMillis(1L), false).get(); + final StateStore 
wrapped = ((WrappedStateStore) store).wrapped(); + assertThat(store, instanceOf(RocksDBWindowStore.class)); + assertThat(wrapped, instanceOf(RocksDBTimestampedSegmentedBytesStore.class)); + } + + @Test + public void shouldCreateRocksDbSessionStore() { + assertThat(Stores.persistentSessionStore("store", ofMillis(1)).get(), instanceOf(RocksDBSessionStore.class)); + } + + @Test + public void shouldBuildKeyValueStore() { + final KeyValueStore store = Stores.keyValueStoreBuilder( + Stores.persistentKeyValueStore("name"), + Serdes.String(), + Serdes.String() + ).build(); + assertThat(store, not(nullValue())); + } + + @Test + public void shouldBuildTimestampedKeyValueStore() { + final TimestampedKeyValueStore store = Stores.timestampedKeyValueStoreBuilder( + Stores.persistentTimestampedKeyValueStore("name"), + Serdes.String(), + Serdes.String() + ).build(); + assertThat(store, not(nullValue())); + } + + @Test + public void shouldBuildTimestampedKeyValueStoreThatWrapsKeyValueStore() { + final TimestampedKeyValueStore store = Stores.timestampedKeyValueStoreBuilder( + Stores.persistentKeyValueStore("name"), + Serdes.String(), + Serdes.String() + ).build(); + assertThat(store, not(nullValue())); + } + + @Test + public void shouldBuildTimestampedKeyValueStoreThatWrapsInMemoryKeyValueStore() { + final TimestampedKeyValueStore store = Stores.timestampedKeyValueStoreBuilder( + Stores.inMemoryKeyValueStore("name"), + Serdes.String(), + Serdes.String() + ).withLoggingDisabled().withCachingDisabled().build(); + assertThat(store, not(nullValue())); + assertThat(((WrappedStateStore) store).wrapped(), instanceOf(TimestampedBytesStore.class)); + } + + @Test + public void shouldBuildWindowStore() { + final WindowStore store = Stores.windowStoreBuilder( + Stores.persistentWindowStore("store", ofMillis(3L), ofMillis(3L), true), + Serdes.String(), + Serdes.String() + ).build(); + assertThat(store, not(nullValue())); + } + + @Test + public void shouldBuildTimestampedWindowStore() { + final TimestampedWindowStore store = Stores.timestampedWindowStoreBuilder( + Stores.persistentTimestampedWindowStore("store", ofMillis(3L), ofMillis(3L), true), + Serdes.String(), + Serdes.String() + ).build(); + assertThat(store, not(nullValue())); + } + + @Test + public void shouldBuildTimestampedWindowStoreThatWrapsWindowStore() { + final TimestampedWindowStore store = Stores.timestampedWindowStoreBuilder( + Stores.persistentWindowStore("store", ofMillis(3L), ofMillis(3L), true), + Serdes.String(), + Serdes.String() + ).build(); + assertThat(store, not(nullValue())); + } + + @Test + public void shouldBuildTimestampedWindowStoreThatWrapsInMemoryWindowStore() { + final TimestampedWindowStore store = Stores.timestampedWindowStoreBuilder( + Stores.inMemoryWindowStore("store", ofMillis(3L), ofMillis(3L), true), + Serdes.String(), + Serdes.String() + ).withLoggingDisabled().withCachingDisabled().build(); + assertThat(store, not(nullValue())); + assertThat(((WrappedStateStore) store).wrapped(), instanceOf(TimestampedBytesStore.class)); + } + + @Test + public void shouldBuildSessionStore() { + final SessionStore store = Stores.sessionStoreBuilder( + Stores.persistentSessionStore("name", ofMillis(10)), + Serdes.String(), + Serdes.String() + ).build(); + assertThat(store, not(nullValue())); + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/state/StreamsMetadataTest.java b/streams/src/test/java/org/apache/kafka/streams/state/StreamsMetadataTest.java new file mode 100644 index 0000000..d6862ce --- /dev/null +++ 
b/streams/src/test/java/org/apache/kafka/streams/state/StreamsMetadataTest.java @@ -0,0 +1,146 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.streams.state; + +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.streams.StreamsMetadata; +import org.apache.kafka.streams.state.internals.StreamsMetadataImpl; +import org.junit.Before; +import org.junit.Test; + +import java.util.Collection; +import java.util.Set; + +import static org.apache.kafka.common.utils.Utils.mkSet; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.not; + +public class StreamsMetadataTest { + + private static final HostInfo HOST_INFO = new HostInfo("local", 12); + public static final Set STATE_STORE_NAMES = mkSet("store1", "store2"); + private static final TopicPartition TP_0 = new TopicPartition("t", 0); + private static final TopicPartition TP_1 = new TopicPartition("t", 1); + public static final Set TOPIC_PARTITIONS = mkSet(TP_0, TP_1); + public static final Set STAND_BY_STORE_NAMES = mkSet("store2"); + public static final Set STANDBY_TOPIC_PARTITIONS = mkSet(TP_1); + + private StreamsMetadata streamsMetadata; + + @Before + public void setUp() { + streamsMetadata = new StreamsMetadataImpl( + HOST_INFO, + STATE_STORE_NAMES, + TOPIC_PARTITIONS, + STAND_BY_STORE_NAMES, + STANDBY_TOPIC_PARTITIONS + ); + } + + @Test + public void shouldNotAllowModificationOfInternalStateViaGetters() { + assertThat(isUnmodifiable(streamsMetadata.stateStoreNames()), is(true)); + assertThat(isUnmodifiable(streamsMetadata.topicPartitions()), is(true)); + assertThat(isUnmodifiable(streamsMetadata.standbyTopicPartitions()), is(true)); + assertThat(isUnmodifiable(streamsMetadata.standbyStateStoreNames()), is(true)); + } + + @Test + public void shouldBeEqualsIfSameObject() { + final StreamsMetadata same = new StreamsMetadataImpl( + HOST_INFO, + STATE_STORE_NAMES, + TOPIC_PARTITIONS, + STAND_BY_STORE_NAMES, + STANDBY_TOPIC_PARTITIONS); + assertThat(streamsMetadata, equalTo(same)); + assertThat(streamsMetadata.hashCode(), equalTo(same.hashCode())); + } + + @Test + public void shouldNotBeEqualIfDifferInHostInfo() { + final StreamsMetadata differHostInfo = new StreamsMetadataImpl( + new HostInfo("different", 122), + STATE_STORE_NAMES, + TOPIC_PARTITIONS, + STAND_BY_STORE_NAMES, + STANDBY_TOPIC_PARTITIONS); + assertThat(streamsMetadata, not(equalTo(differHostInfo))); + assertThat(streamsMetadata.hashCode(), not(equalTo(differHostInfo.hashCode()))); + } + + @Test + public void shouldNotBeEqualIfDifferStateStoreNames() { + final StreamsMetadata differStateStoreNames = new StreamsMetadataImpl( + HOST_INFO, + mkSet("store1"), + 
TOPIC_PARTITIONS, + STAND_BY_STORE_NAMES, + STANDBY_TOPIC_PARTITIONS); + assertThat(streamsMetadata, not(equalTo(differStateStoreNames))); + assertThat(streamsMetadata.hashCode(), not(equalTo(differStateStoreNames.hashCode()))); + } + + @Test + public void shouldNotBeEqualIfDifferInTopicPartitions() { + final StreamsMetadata differTopicPartitions = new StreamsMetadataImpl( + HOST_INFO, + STATE_STORE_NAMES, + mkSet(TP_0), + STAND_BY_STORE_NAMES, + STANDBY_TOPIC_PARTITIONS); + assertThat(streamsMetadata, not(equalTo(differTopicPartitions))); + assertThat(streamsMetadata.hashCode(), not(equalTo(differTopicPartitions.hashCode()))); + } + + @Test + public void shouldNotBeEqualIfDifferInStandByStores() { + final StreamsMetadata differStandByStores = new StreamsMetadataImpl( + HOST_INFO, + STATE_STORE_NAMES, + TOPIC_PARTITIONS, + mkSet("store1"), + STANDBY_TOPIC_PARTITIONS); + assertThat(streamsMetadata, not(equalTo(differStandByStores))); + assertThat(streamsMetadata.hashCode(), not(equalTo(differStandByStores.hashCode()))); + } + + @Test + public void shouldNotBeEqualIfDifferInStandByTopicPartitions() { + final StreamsMetadata differStandByTopicPartitions = new StreamsMetadataImpl( + HOST_INFO, + STATE_STORE_NAMES, + TOPIC_PARTITIONS, + STAND_BY_STORE_NAMES, + mkSet(TP_0)); + assertThat(streamsMetadata, not(equalTo(differStandByTopicPartitions))); + assertThat(streamsMetadata.hashCode(), not(equalTo(differStandByTopicPartitions.hashCode()))); + } + + private static boolean isUnmodifiable(final Collection collection) { + try { + collection.clear(); + return false; + } catch (final UnsupportedOperationException e) { + return true; + } + } +} \ No newline at end of file diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/AbstractKeyValueStoreTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/AbstractKeyValueStoreTest.java new file mode 100644 index 0000000..19b057a --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/AbstractKeyValueStoreTest.java @@ -0,0 +1,651 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.serialization.Serializer; +import org.apache.kafka.common.serialization.StringDeserializer; +import org.apache.kafka.common.serialization.StringSerializer; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.processor.StateStoreContext; +import org.apache.kafka.streams.processor.internals.testutil.LogCaptureAppender; +import org.apache.kafka.streams.state.KeyValueIterator; +import org.apache.kafka.streams.state.KeyValueStore; +import org.apache.kafka.streams.state.KeyValueStoreTestDriver; +import org.apache.kafka.test.InternalMockProcessorContext; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.Iterator; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; + +import static org.hamcrest.CoreMatchers.hasItem; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.core.IsEqual.equalTo; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +@SuppressWarnings("unchecked") +public abstract class AbstractKeyValueStoreTest { + + protected abstract KeyValueStore createKeyValueStore(final StateStoreContext context); + + protected InternalMockProcessorContext context; + protected KeyValueStore store; + protected KeyValueStoreTestDriver driver; + + @Before + public void before() { + driver = KeyValueStoreTestDriver.create(Integer.class, String.class); + context = (InternalMockProcessorContext) driver.context(); + context.setTime(10); + store = createKeyValueStore(context); + } + + @After + public void after() { + store.close(); + driver.clear(); + } + + private static Map getContents(final KeyValueIterator iter) { + final HashMap result = new HashMap<>(); + while (iter.hasNext()) { + final KeyValue entry = iter.next(); + result.put(entry.key, entry.value); + } + return result; + } + + @SuppressWarnings("unchecked") + @Test + public void shouldNotIncludeDeletedFromRangeResult() { + store.close(); + + final Serializer serializer = new StringSerializer() { + private int numCalls = 0; + + @Override + public byte[] serialize(final String topic, final String data) { + if (++numCalls > 3) { + fail("Value serializer is called; it should never happen"); + } + + return super.serialize(topic, data); + } + }; + + context.setValueSerde(Serdes.serdeFrom(serializer, new StringDeserializer())); + store = createKeyValueStore(driver.context()); + + store.put(0, "zero"); + store.put(1, "one"); + store.put(2, "two"); + store.delete(0); + store.delete(1); + + // should not include deleted records in iterator + final Map expectedContents = Collections.singletonMap(2, "two"); + assertEquals(expectedContents, getContents(store.all())); + } + + @Test + public void shouldDeleteIfSerializedValueIsNull() { + store.close(); + + final Serializer serializer = new StringSerializer() { + @Override + public byte[] serialize(final String topic, final String data) { + if (data.equals("null")) { + // will be serialized to null bytes, indicating deletes + return null; + } + return super.serialize(topic, data); + } + 
}; + + context.setValueSerde(Serdes.serdeFrom(serializer, new StringDeserializer())); + store = createKeyValueStore(driver.context()); + + store.put(0, "zero"); + store.put(1, "one"); + store.put(2, "two"); + store.put(0, "null"); + store.put(1, "null"); + + // should not include deleted records in iterator + final Map expectedContents = Collections.singletonMap(2, "two"); + assertEquals(expectedContents, getContents(store.all())); + } + + @Test + public void testPutGetRange() { + // Verify that the store reads and writes correctly ... + store.put(0, "zero"); + store.put(1, "one"); + store.put(2, "two"); + store.put(4, "four"); + store.put(5, "five"); + assertEquals(5, driver.sizeOf(store)); + assertEquals("zero", store.get(0)); + assertEquals("one", store.get(1)); + assertEquals("two", store.get(2)); + assertNull(store.get(3)); + assertEquals("four", store.get(4)); + assertEquals("five", store.get(5)); + // Flush now so that for caching store, we will not skip the deletion following an put + store.flush(); + store.delete(5); + assertEquals(4, driver.sizeOf(store)); + + // Flush the store and verify all current entries were properly flushed ... + store.flush(); + assertEquals("zero", driver.flushedEntryStored(0)); + assertEquals("one", driver.flushedEntryStored(1)); + assertEquals("two", driver.flushedEntryStored(2)); + assertEquals("four", driver.flushedEntryStored(4)); + assertNull(driver.flushedEntryStored(5)); + + assertFalse(driver.flushedEntryRemoved(0)); + assertFalse(driver.flushedEntryRemoved(1)); + assertFalse(driver.flushedEntryRemoved(2)); + assertFalse(driver.flushedEntryRemoved(4)); + assertTrue(driver.flushedEntryRemoved(5)); + + final HashMap expectedContents = new HashMap<>(); + expectedContents.put(2, "two"); + expectedContents.put(4, "four"); + + // Check range iteration ... + assertEquals(expectedContents, getContents(store.range(2, 4))); + assertEquals(expectedContents, getContents(store.range(2, 6))); + + // Check all iteration ... + expectedContents.put(0, "zero"); + expectedContents.put(1, "one"); + assertEquals(expectedContents, getContents(store.all())); + } + + @Test + public void testPutGetReverseRange() { + // Verify that the store reads and writes correctly ... + store.put(0, "zero"); + store.put(1, "one"); + store.put(2, "two"); + store.put(4, "four"); + store.put(5, "five"); + assertEquals(5, driver.sizeOf(store)); + assertEquals("zero", store.get(0)); + assertEquals("one", store.get(1)); + assertEquals("two", store.get(2)); + assertNull(store.get(3)); + assertEquals("four", store.get(4)); + assertEquals("five", store.get(5)); + // Flush now so that for caching store, we will not skip the deletion following an put + store.flush(); + store.delete(5); + assertEquals(4, driver.sizeOf(store)); + + // Flush the store and verify all current entries were properly flushed ... + store.flush(); + assertEquals("zero", driver.flushedEntryStored(0)); + assertEquals("one", driver.flushedEntryStored(1)); + assertEquals("two", driver.flushedEntryStored(2)); + assertEquals("four", driver.flushedEntryStored(4)); + assertNull(driver.flushedEntryStored(5)); + + assertFalse(driver.flushedEntryRemoved(0)); + assertFalse(driver.flushedEntryRemoved(1)); + assertFalse(driver.flushedEntryRemoved(2)); + assertFalse(driver.flushedEntryRemoved(4)); + assertTrue(driver.flushedEntryRemoved(5)); + + final HashMap expectedContents = new HashMap<>(); + expectedContents.put(2, "two"); + expectedContents.put(4, "four"); + + // Check range iteration ... 
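+ // Note: the following calls go through reverseRange, but getContents() collects the results into a HashMap,
+ // so only the returned contents are asserted here, not the reverse iteration order.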
+ assertEquals(expectedContents, getContents(store.reverseRange(2, 4))); + assertEquals(expectedContents, getContents(store.reverseRange(2, 6))); + + // Check all iteration ... + expectedContents.put(0, "zero"); + expectedContents.put(1, "one"); + assertEquals(expectedContents, getContents(store.reverseAll())); + } + + @Test + public void testPutGetWithDefaultSerdes() { + // Verify that the store reads and writes correctly ... + store.put(0, "zero"); + store.put(1, "one"); + store.put(2, "two"); + store.put(4, "four"); + store.put(5, "five"); + assertEquals(5, driver.sizeOf(store)); + assertEquals("zero", store.get(0)); + assertEquals("one", store.get(1)); + assertEquals("two", store.get(2)); + assertNull(store.get(3)); + assertEquals("four", store.get(4)); + assertEquals("five", store.get(5)); + store.flush(); + store.delete(5); + + // Flush the store and verify all current entries were properly flushed ... + store.flush(); + assertEquals("zero", driver.flushedEntryStored(0)); + assertEquals("one", driver.flushedEntryStored(1)); + assertEquals("two", driver.flushedEntryStored(2)); + assertEquals("four", driver.flushedEntryStored(4)); + assertNull(driver.flushedEntryStored(5)); + + assertFalse(driver.flushedEntryRemoved(0)); + assertFalse(driver.flushedEntryRemoved(1)); + assertFalse(driver.flushedEntryRemoved(2)); + assertFalse(driver.flushedEntryRemoved(4)); + assertTrue(driver.flushedEntryRemoved(5)); + } + + @Test + public void testRestore() { + store.close(); + // Add any entries that will be restored to any store + // that uses the driver's context ... + driver.addEntryToRestoreLog(0, "zero"); + driver.addEntryToRestoreLog(1, "one"); + driver.addEntryToRestoreLog(2, "two"); + driver.addEntryToRestoreLog(3, "three"); + + // Create the store, which should register with the context and automatically + // receive the restore entries ... + store = createKeyValueStore(driver.context()); + context.restore(store.name(), driver.restoredEntries()); + + // Verify that the store's contents were properly restored ... + assertEquals(0, driver.checkForRestoredEntries(store)); + + // and there are no other entries ... + assertEquals(4, driver.sizeOf(store)); + } + + @Test + public void testRestoreWithDefaultSerdes() { + store.close(); + // Add any entries that will be restored to any store + // that uses the driver's context ... + driver.addEntryToRestoreLog(0, "zero"); + driver.addEntryToRestoreLog(1, "one"); + driver.addEntryToRestoreLog(2, "two"); + driver.addEntryToRestoreLog(3, "three"); + + // Create the store, which should register with the context and automatically + // receive the restore entries ... + store = createKeyValueStore(driver.context()); + context.restore(store.name(), driver.restoredEntries()); + // Verify that the store's contents were properly restored ... + assertEquals(0, driver.checkForRestoredEntries(store)); + + // and there are no other entries ... + assertEquals(4, driver.sizeOf(store)); + } + + @Test + public void testPutIfAbsent() { + // Verify that the store reads and writes correctly ... 
+ assertNull(store.putIfAbsent(0, "zero")); + assertNull(store.putIfAbsent(1, "one")); + assertNull(store.putIfAbsent(2, "two")); + assertNull(store.putIfAbsent(4, "four")); + assertEquals("four", store.putIfAbsent(4, "unexpected value")); + assertEquals(4, driver.sizeOf(store)); + assertEquals("zero", store.get(0)); + assertEquals("one", store.get(1)); + assertEquals("two", store.get(2)); + assertNull(store.get(3)); + assertEquals("four", store.get(4)); + + // Flush the store and verify all current entries were properly flushed ... + store.flush(); + assertEquals("zero", driver.flushedEntryStored(0)); + assertEquals("one", driver.flushedEntryStored(1)); + assertEquals("two", driver.flushedEntryStored(2)); + assertEquals("four", driver.flushedEntryStored(4)); + + assertFalse(driver.flushedEntryRemoved(0)); + assertFalse(driver.flushedEntryRemoved(1)); + assertFalse(driver.flushedEntryRemoved(2)); + assertFalse(driver.flushedEntryRemoved(4)); + } + + @Test + public void shouldThrowNullPointerExceptionOnPutNullKey() { + assertThrows(NullPointerException.class, () -> store.put(null, "anyValue")); + } + + @Test + public void shouldNotThrowNullPointerExceptionOnPutNullValue() { + store.put(1, null); + } + + @Test + public void shouldThrowNullPointerExceptionOnPutIfAbsentNullKey() { + assertThrows(NullPointerException.class, () -> store.putIfAbsent(null, "anyValue")); + } + + @Test + public void shouldNotThrowNullPointerExceptionOnPutIfAbsentNullValue() { + store.putIfAbsent(1, null); + } + + @Test + public void shouldThrowNullPointerExceptionOnPutAllNullKey() { + assertThrows(NullPointerException.class, () -> store.putAll(Collections.singletonList(new KeyValue<>(null, "anyValue")))); + } + + @Test + public void shouldNotThrowNullPointerExceptionOnPutAllNullKey() { + store.putAll(Collections.singletonList(new KeyValue<>(1, null))); + } + + @Test + public void shouldThrowNullPointerExceptionOnDeleteNullKey() { + assertThrows(NullPointerException.class, () -> store.delete(null)); + } + + @Test + public void shouldThrowNullPointerExceptionOnGetNullKey() { + assertThrows(NullPointerException.class, () -> store.get(null)); + } + + @Test + public void shouldReturnValueOnRangeNullToKey() { + store.put(0, "zero"); + store.put(1, "one"); + store.put(2, "two"); + + final LinkedList> expectedContents = new LinkedList<>(); + expectedContents.add(new KeyValue<>(0, "zero")); + expectedContents.add(new KeyValue<>(1, "one")); + + try (final KeyValueIterator iterator = store.range(null, 1)) { + assertEquals(expectedContents, Utils.toList(iterator)); + } + } + + @Test + public void shouldReturnValueOnRangeKeyToNull() { + store.put(0, "zero"); + store.put(1, "one"); + store.put(2, "two"); + + final LinkedList> expectedContents = new LinkedList<>(); + expectedContents.add(new KeyValue<>(1, "one")); + expectedContents.add(new KeyValue<>(2, "two")); + + try (final KeyValueIterator iterator = store.range(1, null)) { + assertEquals(expectedContents, Utils.toList(iterator)); + } + } + + @Test + public void shouldReturnValueOnRangeNullToNull() { + store.put(0, "zero"); + store.put(1, "one"); + store.put(2, "two"); + + final LinkedList> expectedContents = new LinkedList<>(); + expectedContents.add(new KeyValue<>(0, "zero")); + expectedContents.add(new KeyValue<>(1, "one")); + expectedContents.add(new KeyValue<>(2, "two")); + + try (final KeyValueIterator iterator = store.range(null, null)) { + assertEquals(expectedContents, Utils.toList(iterator)); + } + } + + @Test + public void shouldReturnValueOnReverseRangeNullToKey() 
{ + store.put(0, "zero"); + store.put(1, "one"); + store.put(2, "two"); + + final LinkedList> expectedContents = new LinkedList<>(); + expectedContents.add(new KeyValue<>(1, "one")); + expectedContents.add(new KeyValue<>(0, "zero")); + + try (final KeyValueIterator iterator = store.reverseRange(null, 1)) { + assertEquals(expectedContents, Utils.toList(iterator)); + } + } + + @Test + public void shouldReturnValueOnReverseRangeKeyToNull() { + store.put(0, "zero"); + store.put(1, "one"); + store.put(2, "two"); + + final LinkedList> expectedContents = new LinkedList<>(); + expectedContents.add(new KeyValue<>(2, "two")); + expectedContents.add(new KeyValue<>(1, "one")); + + try (final KeyValueIterator iterator = store.reverseRange(1, null)) { + assertEquals(expectedContents, Utils.toList(iterator)); + } + } + + @Test + public void shouldReturnValueOnReverseRangeNullToNull() { + store.put(0, "zero"); + store.put(1, "one"); + store.put(2, "two"); + + final LinkedList> expectedContents = new LinkedList<>(); + expectedContents.add(new KeyValue<>(2, "two")); + expectedContents.add(new KeyValue<>(1, "one")); + expectedContents.add(new KeyValue<>(0, "zero")); + + try (final KeyValueIterator iterator = store.reverseRange(null, null)) { + assertEquals(expectedContents, Utils.toList(iterator)); + } + } + + @Test + public void testSize() { + assertEquals("A newly created store should have no entries", 0, store.approximateNumEntries()); + + store.put(0, "zero"); + store.put(1, "one"); + store.put(2, "two"); + store.put(4, "four"); + store.put(5, "five"); + store.flush(); + assertEquals(5, store.approximateNumEntries()); + } + + @Test + public void shouldPutAll() { + final List> entries = new ArrayList<>(); + entries.add(new KeyValue<>(1, "one")); + entries.add(new KeyValue<>(2, "two")); + + store.putAll(entries); + + final List> allReturned = new ArrayList<>(); + final List> expectedReturned = + Arrays.asList(KeyValue.pair(1, "one"), KeyValue.pair(2, "two")); + final Iterator> iterator = store.all(); + + while (iterator.hasNext()) { + allReturned.add(iterator.next()); + } + assertThat(allReturned, equalTo(expectedReturned)); + } + + @Test + public void shouldPutReverseAll() { + final List> entries = new ArrayList<>(); + entries.add(new KeyValue<>(1, "one")); + entries.add(new KeyValue<>(2, "two")); + + store.putAll(entries); + + final List> allReturned = new ArrayList<>(); + final List> expectedReturned = + Arrays.asList(KeyValue.pair(2, "two"), KeyValue.pair(1, "one")); + final Iterator> iterator = store.reverseAll(); + + while (iterator.hasNext()) { + allReturned.add(iterator.next()); + } + assertThat(allReturned, equalTo(expectedReturned)); + } + + @Test + public void shouldDeleteFromStore() { + store.put(1, "one"); + store.put(2, "two"); + store.delete(2); + assertNull(store.get(2)); + } + + @Test + public void shouldReturnSameResultsForGetAndRangeWithEqualKeys() { + final List> entries = new ArrayList<>(); + entries.add(new KeyValue<>(1, "one")); + entries.add(new KeyValue<>(2, "two")); + entries.add(new KeyValue<>(3, "three")); + + store.putAll(entries); + + final Iterator> iterator = store.range(2, 2); + + assertEquals(iterator.next().value, store.get(2)); + assertFalse(iterator.hasNext()); + } + + @Test + public void shouldReturnSameResultsForGetAndReverseRangeWithEqualKeys() { + final List> entries = new ArrayList<>(); + entries.add(new KeyValue<>(1, "one")); + entries.add(new KeyValue<>(2, "two")); + entries.add(new KeyValue<>(3, "three")); + + store.putAll(entries); + + final Iterator> iterator 
= store.reverseRange(2, 2); + + assertEquals(iterator.next().value, store.get(2)); + assertFalse(iterator.hasNext()); + } + + @Test + public void shouldNotThrowConcurrentModificationException() { + store.put(0, "zero"); + + try (final KeyValueIterator results = store.range(0, 2)) { + + store.put(1, "one"); + + assertEquals(new KeyValue<>(0, "zero"), results.next()); + } + } + + @Test + public void shouldNotThrowInvalidRangeExceptionWithNegativeFromKey() { + try (final LogCaptureAppender appender = LogCaptureAppender.createAndRegister()) { + try (final KeyValueIterator iterator = store.range(-1, 1)) { + assertFalse(iterator.hasNext()); + } + + final List messages = appender.getMessages(); + assertThat( + messages, + hasItem("Returning empty iterator for fetch with invalid key range: from > to." + + " This may be due to range arguments set in the wrong order, " + + "or serdes that don't preserve ordering when lexicographically comparing the serialized bytes." + + " Note that the built-in numerical serdes do not follow this for negative numbers") + ); + } + } + + @Test + public void shouldNotThrowInvalidReverseRangeExceptionWithNegativeFromKey() { + try (final LogCaptureAppender appender = LogCaptureAppender.createAndRegister()) { + try (final KeyValueIterator iterator = store.reverseRange(-1, 1)) { + assertFalse(iterator.hasNext()); + } + + final List messages = appender.getMessages(); + assertThat( + messages, + hasItem("Returning empty iterator for fetch with invalid key range: from > to." + + " This may be due to range arguments set in the wrong order, " + + "or serdes that don't preserve ordering when lexicographically comparing the serialized bytes." + + " Note that the built-in numerical serdes do not follow this for negative numbers") + ); + } + } + + @Test + public void shouldNotThrowInvalidRangeExceptionWithFromLargerThanTo() { + try (final LogCaptureAppender appender = LogCaptureAppender.createAndRegister()) { + try (final KeyValueIterator iterator = store.range(2, 1)) { + assertFalse(iterator.hasNext()); + } + + final List messages = appender.getMessages(); + assertThat( + messages, + hasItem("Returning empty iterator for fetch with invalid key range: from > to." + + " This may be due to range arguments set in the wrong order, " + + "or serdes that don't preserve ordering when lexicographically comparing the serialized bytes." + + " Note that the built-in numerical serdes do not follow this for negative numbers") + ); + } + } + + @Test + public void shouldNotThrowInvalidReverseRangeExceptionWithFromLargerThanTo() { + try (final LogCaptureAppender appender = LogCaptureAppender.createAndRegister()) { + try (final KeyValueIterator iterator = store.reverseRange(2, 1)) { + assertFalse(iterator.hasNext()); + } + + final List messages = appender.getMessages(); + assertThat( + messages, + hasItem("Returning empty iterator for fetch with invalid key range: from > to." + + " This may be due to range arguments set in the wrong order, " + + "or serdes that don't preserve ordering when lexicographically comparing the serialized bytes." 
+ + " Note that the built-in numerical serdes do not follow this for negative numbers") + ); + } + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/AbstractRocksDBSegmentedBytesStoreTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/AbstractRocksDBSegmentedBytesStoreTest.java new file mode 100644 index 0000000..0640a46 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/AbstractRocksDBSegmentedBytesStoreTest.java @@ -0,0 +1,627 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.common.Metric; +import org.apache.kafka.common.MetricName; +import org.apache.kafka.common.metrics.Metrics; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.common.utils.SystemTime; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.kstream.Window; +import org.apache.kafka.streams.kstream.Windowed; +import org.apache.kafka.streams.kstream.internals.SessionWindow; +import org.apache.kafka.streams.processor.StateStoreContext; +import org.apache.kafka.streams.processor.internals.MockStreamsMetrics; +import org.apache.kafka.streams.processor.internals.Task.TaskType; +import org.apache.kafka.streams.processor.internals.testutil.LogCaptureAppender; +import org.apache.kafka.streams.state.KeyValueIterator; +import org.apache.kafka.streams.state.StateSerdes; +import org.apache.kafka.test.InternalMockProcessorContext; +import org.apache.kafka.test.MockRecordCollector; +import org.apache.kafka.test.StreamsTestUtils; +import org.apache.kafka.test.TestUtils; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import org.junit.runners.Parameterized.Parameter; +import org.junit.runners.Parameterized.Parameters; +import org.rocksdb.WriteBatch; + +import java.io.File; +import java.text.SimpleDateFormat; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.Date; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Properties; +import java.util.Set; +import java.util.SimpleTimeZone; + +import static org.apache.kafka.common.utils.Utils.mkEntry; +import static org.apache.kafka.common.utils.Utils.mkMap; +import static org.apache.kafka.streams.state.internals.WindowKeySchema.timeWindowForSize; +import static 
org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.CoreMatchers.hasItem; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotEquals; +import static org.junit.Assert.assertTrue; + +@RunWith(Parameterized.class) +public abstract class AbstractRocksDBSegmentedBytesStoreTest { + + private final long windowSizeForTimeWindow = 500; + private InternalMockProcessorContext context; + private AbstractRocksDBSegmentedBytesStore bytesStore; + private File stateDir; + private final Window[] windows = new Window[4]; + private Window nextSegmentWindow; + + final long retention = 1000; + final long segmentInterval = 60_000L; + final String storeName = "bytes-store"; + + @Parameter + public SegmentedBytesStore.KeySchema schema; + + @Parameters(name = "{0}") + public static Object[] getKeySchemas() { + return new Object[] {new SessionKeySchema(), new WindowKeySchema()}; + } + + @Before + public void before() { + if (schema instanceof SessionKeySchema) { + windows[0] = new SessionWindow(10L, 10L); + windows[1] = new SessionWindow(500L, 1000L); + windows[2] = new SessionWindow(1_000L, 1_500L); + windows[3] = new SessionWindow(30_000L, 60_000L); + // All four of the previous windows will go into segment 1. + // The nextSegmentWindow is computed be a high enough time that when it gets written + // to the segment store, it will advance stream time past the first segment's retention time and + // expire it. + nextSegmentWindow = new SessionWindow(segmentInterval + retention, segmentInterval + retention); + } + if (schema instanceof WindowKeySchema) { + windows[0] = timeWindowForSize(10L, windowSizeForTimeWindow); + windows[1] = timeWindowForSize(500L, windowSizeForTimeWindow); + windows[2] = timeWindowForSize(1_000L, windowSizeForTimeWindow); + windows[3] = timeWindowForSize(60_000L, windowSizeForTimeWindow); + // All four of the previous windows will go into segment 1. + // The nextSegmentWindow is computed be a high enough time that when it gets written + // to the segment store, it will advance stream time past the first segment's retention time and + // expire it. 
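+ // With segmentInterval = 60_000 and retention = 1_000, this window starts at 61_000 ms,
+ // which should map to segment 1 and advance stream time far enough to expire segment 0.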
+ nextSegmentWindow = timeWindowForSize(segmentInterval + retention, windowSizeForTimeWindow); + } + + bytesStore = getBytesStore(); + + stateDir = TestUtils.tempDirectory(); + context = new InternalMockProcessorContext<>( + stateDir, + Serdes.String(), + Serdes.Long(), + new MockRecordCollector(), + new ThreadCache(new LogContext("testCache "), 0, new MockStreamsMetrics(new Metrics())) + ); + bytesStore.init((StateStoreContext) context, bytesStore); + } + + @After + public void close() { + bytesStore.close(); + } + + abstract AbstractRocksDBSegmentedBytesStore getBytesStore(); + + abstract AbstractSegments newSegments(); + + @Test + public void shouldPutAndFetch() { + final String keyA = "a"; + final String keyB = "b"; + final String keyC = "c"; + bytesStore.put(serializeKey(new Windowed<>(keyA, windows[0])), serializeValue(10)); + bytesStore.put(serializeKey(new Windowed<>(keyA, windows[1])), serializeValue(50)); + bytesStore.put(serializeKey(new Windowed<>(keyB, windows[2])), serializeValue(100)); + bytesStore.put(serializeKey(new Windowed<>(keyC, windows[3])), serializeValue(200)); + + try (final KeyValueIterator values = bytesStore.fetch( + Bytes.wrap(keyA.getBytes()), 0, windows[2].start())) { + + final List, Long>> expected = Arrays.asList( + KeyValue.pair(new Windowed<>(keyA, windows[0]), 10L), + KeyValue.pair(new Windowed<>(keyA, windows[1]), 50L) + ); + + assertEquals(expected, toList(values)); + } + + try (final KeyValueIterator values = bytesStore.fetch( + Bytes.wrap(keyA.getBytes()), Bytes.wrap(keyB.getBytes()), 0, windows[2].start())) { + + final List, Long>> expected = Arrays.asList( + KeyValue.pair(new Windowed<>(keyA, windows[0]), 10L), + KeyValue.pair(new Windowed<>(keyA, windows[1]), 50L), + KeyValue.pair(new Windowed<>(keyB, windows[2]), 100L) + ); + + assertEquals(expected, toList(values)); + } + + try (final KeyValueIterator values = bytesStore.fetch( + null, Bytes.wrap(keyB.getBytes()), 0, windows[2].start())) { + + final List, Long>> expected = Arrays.asList( + KeyValue.pair(new Windowed<>(keyA, windows[0]), 10L), + KeyValue.pair(new Windowed<>(keyA, windows[1]), 50L), + KeyValue.pair(new Windowed<>(keyB, windows[2]), 100L) + ); + + assertEquals(expected, toList(values)); + } + + try (final KeyValueIterator values = bytesStore.fetch( + Bytes.wrap(keyB.getBytes()), null, 0, windows[3].start())) { + + final List, Long>> expected = Arrays.asList( + KeyValue.pair(new Windowed<>(keyB, windows[2]), 100L), + KeyValue.pair(new Windowed<>(keyC, windows[3]), 200L) + ); + + assertEquals(expected, toList(values)); + } + + try (final KeyValueIterator values = bytesStore.fetch( + null, null, 0, windows[3].start())) { + + final List, Long>> expected = Arrays.asList( + KeyValue.pair(new Windowed<>(keyA, windows[0]), 10L), + KeyValue.pair(new Windowed<>(keyA, windows[1]), 50L), + KeyValue.pair(new Windowed<>(keyB, windows[2]), 100L), + KeyValue.pair(new Windowed<>(keyC, windows[3]), 200L) + ); + + assertEquals(expected, toList(values)); + } + } + + @Test + public void shouldPutAndBackwardFetch() { + final String keyA = "a"; + final String keyB = "b"; + final String keyC = "c"; + bytesStore.put(serializeKey(new Windowed<>(keyA, windows[0])), serializeValue(10)); + bytesStore.put(serializeKey(new Windowed<>(keyA, windows[1])), serializeValue(50)); + bytesStore.put(serializeKey(new Windowed<>(keyB, windows[2])), serializeValue(100)); + bytesStore.put(serializeKey(new Windowed<>(keyC, windows[3])), serializeValue(200)); + + try (final KeyValueIterator values = bytesStore.backwardFetch( 
+ Bytes.wrap(keyA.getBytes()), 0, windows[2].start())) { + + final List, Long>> expected = Arrays.asList( + KeyValue.pair(new Windowed<>(keyA, windows[1]), 50L), + KeyValue.pair(new Windowed<>(keyA, windows[0]), 10L) + ); + + assertEquals(expected, toList(values)); + } + + try (final KeyValueIterator values = bytesStore.backwardFetch( + Bytes.wrap(keyA.getBytes()), Bytes.wrap(keyB.getBytes()), 0, windows[2].start())) { + + final List, Long>> expected = Arrays.asList( + KeyValue.pair(new Windowed<>(keyB, windows[2]), 100L), + KeyValue.pair(new Windowed<>(keyA, windows[1]), 50L), + KeyValue.pair(new Windowed<>(keyA, windows[0]), 10L) + ); + + assertEquals(expected, toList(values)); + } + + try (final KeyValueIterator values = bytesStore.backwardFetch( + null, Bytes.wrap(keyB.getBytes()), 0, windows[2].start())) { + + final List, Long>> expected = Arrays.asList( + KeyValue.pair(new Windowed<>(keyB, windows[2]), 100L), + KeyValue.pair(new Windowed<>(keyA, windows[1]), 50L), + KeyValue.pair(new Windowed<>(keyA, windows[0]), 10L) + ); + + assertEquals(expected, toList(values)); + } + + try (final KeyValueIterator values = bytesStore.backwardFetch( + Bytes.wrap(keyB.getBytes()), null, 0, windows[3].start())) { + + final List, Long>> expected = Arrays.asList( + KeyValue.pair(new Windowed<>(keyC, windows[3]), 200L), + KeyValue.pair(new Windowed<>(keyB, windows[2]), 100L) + ); + + assertEquals(expected, toList(values)); + } + + try (final KeyValueIterator values = bytesStore.backwardFetch( + null, null, 0, windows[3].start())) { + + final List, Long>> expected = Arrays.asList( + KeyValue.pair(new Windowed<>(keyC, windows[3]), 200L), + KeyValue.pair(new Windowed<>(keyB, windows[2]), 100L), + KeyValue.pair(new Windowed<>(keyA, windows[1]), 50L), + KeyValue.pair(new Windowed<>(keyA, windows[0]), 10L) + ); + + assertEquals(expected, toList(values)); + } + } + + @Test + public void shouldFindValuesWithinRange() { + final String key = "a"; + bytesStore.put(serializeKey(new Windowed<>(key, windows[0])), serializeValue(10)); + bytesStore.put(serializeKey(new Windowed<>(key, windows[1])), serializeValue(50)); + bytesStore.put(serializeKey(new Windowed<>(key, windows[2])), serializeValue(100)); + try (final KeyValueIterator results = bytesStore.fetch(Bytes.wrap(key.getBytes()), 1, 999)) { + final List, Long>> expected = Arrays.asList( + KeyValue.pair(new Windowed<>(key, windows[0]), 10L), + KeyValue.pair(new Windowed<>(key, windows[1]), 50L) + ); + + assertEquals(expected, toList(results)); + } + } + + @Test + public void shouldRemove() { + bytesStore.put(serializeKey(new Windowed<>("a", windows[0])), serializeValue(30)); + bytesStore.put(serializeKey(new Windowed<>("a", windows[1])), serializeValue(50)); + + bytesStore.remove(serializeKey(new Windowed<>("a", windows[0]))); + try (final KeyValueIterator value = bytesStore.fetch(Bytes.wrap("a".getBytes()), 0, 100)) { + assertFalse(value.hasNext()); + } + } + + @Test + public void shouldRollSegments() { + // just to validate directories + final AbstractSegments segments = newSegments(); + final String key = "a"; + + bytesStore.put(serializeKey(new Windowed<>(key, windows[0])), serializeValue(50)); + bytesStore.put(serializeKey(new Windowed<>(key, windows[1])), serializeValue(100)); + bytesStore.put(serializeKey(new Windowed<>(key, windows[2])), serializeValue(500)); + assertEquals(Collections.singleton(segments.segmentName(0)), segmentDirs()); + + bytesStore.put(serializeKey(new Windowed<>(key, windows[3])), serializeValue(1000)); + 
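// windows[3] reaches 60_000 ms (the segment interval), so it should roll into a second segment;
+ // both segment directories are expected to exist from here on.
+ 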
assertEquals(Utils.mkSet(segments.segmentName(0), segments.segmentName(1)), segmentDirs()); + + final List, Long>> results = toList(bytesStore.fetch(Bytes.wrap(key.getBytes()), 0, 1500)); + + assertEquals( + Arrays.asList( + KeyValue.pair(new Windowed<>(key, windows[0]), 50L), + KeyValue.pair(new Windowed<>(key, windows[1]), 100L), + KeyValue.pair(new Windowed<>(key, windows[2]), 500L) + ), + results + ); + + segments.close(); + } + + @Test + public void shouldGetAllSegments() { + // just to validate directories + final AbstractSegments segments = newSegments(); + final String key = "a"; + + bytesStore.put(serializeKey(new Windowed<>(key, windows[0])), serializeValue(50L)); + assertEquals(Collections.singleton(segments.segmentName(0)), segmentDirs()); + + bytesStore.put(serializeKey(new Windowed<>(key, windows[3])), serializeValue(100L)); + assertEquals( + Utils.mkSet( + segments.segmentName(0), + segments.segmentName(1) + ), + segmentDirs() + ); + + final List, Long>> results = toList(bytesStore.all()); + assertEquals( + Arrays.asList( + KeyValue.pair(new Windowed<>(key, windows[0]), 50L), + KeyValue.pair(new Windowed<>(key, windows[3]), 100L) + ), + results + ); + + segments.close(); + } + + @Test + public void shouldFetchAllSegments() { + // just to validate directories + final AbstractSegments segments = newSegments(); + final String key = "a"; + + bytesStore.put(serializeKey(new Windowed<>(key, windows[0])), serializeValue(50L)); + assertEquals(Collections.singleton(segments.segmentName(0)), segmentDirs()); + + bytesStore.put(serializeKey(new Windowed<>(key, windows[3])), serializeValue(100L)); + assertEquals( + Utils.mkSet( + segments.segmentName(0), + segments.segmentName(1) + ), + segmentDirs() + ); + + final List, Long>> results = toList(bytesStore.fetchAll(0L, 60_000L)); + assertEquals( + Arrays.asList( + KeyValue.pair(new Windowed<>(key, windows[0]), 50L), + KeyValue.pair(new Windowed<>(key, windows[3]), 100L) + ), + results + ); + + segments.close(); + } + + @Test + public void shouldLoadSegmentsWithOldStyleDateFormattedName() { + final AbstractSegments segments = newSegments(); + final String key = "a"; + + bytesStore.put(serializeKey(new Windowed<>(key, windows[0])), serializeValue(50L)); + bytesStore.put(serializeKey(new Windowed<>(key, windows[3])), serializeValue(100L)); + bytesStore.close(); + + final String firstSegmentName = segments.segmentName(0); + final String[] nameParts = firstSegmentName.split("\\."); + final long segmentId = Long.parseLong(nameParts[1]); + final SimpleDateFormat formatter = new SimpleDateFormat("yyyyMMddHHmm"); + formatter.setTimeZone(new SimpleTimeZone(0, "UTC")); + final String formatted = formatter.format(new Date(segmentId * segmentInterval)); + final File parent = new File(stateDir, storeName); + final File oldStyleName = new File(parent, nameParts[0] + "-" + formatted); + assertTrue(new File(parent, firstSegmentName).renameTo(oldStyleName)); + + bytesStore = getBytesStore(); + + bytesStore.init((StateStoreContext) context, bytesStore); + final List, Long>> results = toList(bytesStore.fetch(Bytes.wrap(key.getBytes()), 0L, 60_000L)); + assertThat( + results, + equalTo( + Arrays.asList( + KeyValue.pair(new Windowed<>(key, windows[0]), 50L), + KeyValue.pair(new Windowed<>(key, windows[3]), 100L) + ) + ) + ); + + segments.close(); + } + + @Test + public void shouldLoadSegmentsWithOldStyleColonFormattedName() { + final AbstractSegments segments = newSegments(); + final String key = "a"; + + bytesStore.put(serializeKey(new Windowed<>(key, 
windows[0])), serializeValue(50L)); + bytesStore.put(serializeKey(new Windowed<>(key, windows[3])), serializeValue(100L)); + bytesStore.close(); + + final String firstSegmentName = segments.segmentName(0); + final String[] nameParts = firstSegmentName.split("\\."); + final File parent = new File(stateDir, storeName); + final File oldStyleName = new File(parent, nameParts[0] + ":" + Long.parseLong(nameParts[1])); + assertTrue(new File(parent, firstSegmentName).renameTo(oldStyleName)); + + bytesStore = getBytesStore(); + + bytesStore.init((StateStoreContext) context, bytesStore); + final List, Long>> results = toList(bytesStore.fetch(Bytes.wrap(key.getBytes()), 0L, 60_000L)); + assertThat( + results, + equalTo( + Arrays.asList( + KeyValue.pair(new Windowed<>(key, windows[0]), 50L), + KeyValue.pair(new Windowed<>(key, windows[3]), 100L) + ) + ) + ); + + segments.close(); + } + + @Test + public void shouldBeAbleToWriteToReInitializedStore() { + final String key = "a"; + // need to create a segment so we can attempt to write to it again. + bytesStore.put(serializeKey(new Windowed<>(key, windows[0])), serializeValue(50)); + bytesStore.close(); + bytesStore.init((StateStoreContext) context, bytesStore); + bytesStore.put(serializeKey(new Windowed<>(key, windows[1])), serializeValue(100)); + } + + @Test + public void shouldCreateWriteBatches() { + final String key = "a"; + final Collection> records = new ArrayList<>(); + records.add(new KeyValue<>(serializeKey(new Windowed<>(key, windows[0])).get(), serializeValue(50L))); + records.add(new KeyValue<>(serializeKey(new Windowed<>(key, windows[3])).get(), serializeValue(100L))); + final Map writeBatchMap = bytesStore.getWriteBatches(records); + assertEquals(2, writeBatchMap.size()); + for (final WriteBatch batch : writeBatchMap.values()) { + assertEquals(1, batch.count()); + } + } + + @Test + public void shouldRestoreToByteStoreForActiveTask() { + shouldRestoreToByteStore(TaskType.ACTIVE); + } + + @Test + public void shouldRestoreToByteStoreForStandbyTask() { + context.transitionToStandby(null); + shouldRestoreToByteStore(TaskType.STANDBY); + } + + private void shouldRestoreToByteStore(final TaskType taskType) { + bytesStore.init((StateStoreContext) context, bytesStore); + // 0 segments initially. + assertEquals(0, bytesStore.getSegments().size()); + final String key = "a"; + final Collection> records = new ArrayList<>(); + records.add(new KeyValue<>(serializeKey(new Windowed<>(key, windows[0])).get(), serializeValue(50L))); + records.add(new KeyValue<>(serializeKey(new Windowed<>(key, windows[3])).get(), serializeValue(100L))); + bytesStore.restoreAllInternal(records); + + // 2 segments are created during restoration. 
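+        // (windows[0] and windows[3] fall into different segment intervals, hence exactly two segments)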
+ assertEquals(2, bytesStore.getSegments().size()); + + final List, Long>> expected = new ArrayList<>(); + expected.add(new KeyValue<>(new Windowed<>(key, windows[0]), 50L)); + expected.add(new KeyValue<>(new Windowed<>(key, windows[3]), 100L)); + + final List, Long>> results = toList(bytesStore.all()); + assertEquals(expected, results); + } + + @Test + public void shouldLogAndMeasureExpiredRecords() { + final Properties streamsConfig = StreamsTestUtils.getStreamsConfig(); + final AbstractRocksDBSegmentedBytesStore bytesStore = getBytesStore(); + final InternalMockProcessorContext context = new InternalMockProcessorContext( + TestUtils.tempDirectory(), + new StreamsConfig(streamsConfig) + ); + final Time time = new SystemTime(); + context.setSystemTimeMs(time.milliseconds()); + bytesStore.init((StateStoreContext) context, bytesStore); + + try (final LogCaptureAppender appender = LogCaptureAppender.createAndRegister()) { + // write a record to advance stream time, with a high enough timestamp + // that the subsequent record in windows[0] will already be expired. + bytesStore.put(serializeKey(new Windowed<>("dummy", nextSegmentWindow)), serializeValue(0)); + + final Bytes key = serializeKey(new Windowed<>("a", windows[0])); + final byte[] value = serializeValue(5); + bytesStore.put(key, value); + + final List messages = appender.getMessages(); + assertThat(messages, hasItem("Skipping record for expired segment.")); + } + + final Map metrics = context.metrics().metrics(); + final String threadId = Thread.currentThread().getName(); + final Metric dropTotal; + final Metric dropRate; + dropTotal = metrics.get(new MetricName( + "dropped-records-total", + "stream-task-metrics", + "", + mkMap( + mkEntry("thread-id", threadId), + mkEntry("task-id", "0_0") + ) + )); + + dropRate = metrics.get(new MetricName( + "dropped-records-rate", + "stream-task-metrics", + "", + mkMap( + mkEntry("thread-id", threadId), + mkEntry("task-id", "0_0") + ) + )); + assertEquals(1.0, dropTotal.metricValue()); + assertNotEquals(0.0, dropRate.metricValue()); + + bytesStore.close(); + } + + private Set segmentDirs() { + final File windowDir = new File(stateDir, storeName); + + return Utils.mkSet(Objects.requireNonNull(windowDir.list())); + } + + private Bytes serializeKey(final Windowed key) { + final StateSerdes stateSerdes = StateSerdes.withBuiltinTypes("dummy", String.class, Long.class); + if (schema instanceof SessionKeySchema) { + return Bytes.wrap(SessionKeySchema.toBinary(key, stateSerdes.keySerializer(), "dummy")); + } else if (schema instanceof WindowKeySchema) { + return WindowKeySchema.toStoreKeyBinary(key, 0, stateSerdes); + } else { + throw new IllegalStateException("Unrecognized serde schema"); + } + } + + private byte[] serializeValue(final long value) { + return Serdes.Long().serializer().serialize("", value); + } + + private List, Long>> toList(final KeyValueIterator iterator) { + final List, Long>> results = new ArrayList<>(); + final StateSerdes stateSerdes = StateSerdes.withBuiltinTypes("dummy", String.class, Long.class); + while (iterator.hasNext()) { + final KeyValue next = iterator.next(); + if (schema instanceof WindowKeySchema) { + final KeyValue, Long> deserialized = KeyValue.pair( + WindowKeySchema.fromStoreKey( + next.key.get(), + windowSizeForTimeWindow, + stateSerdes.keyDeserializer(), + stateSerdes.topic() + ), + stateSerdes.valueDeserializer().deserialize("dummy", next.value) + ); + results.add(deserialized); + } else if (schema instanceof SessionKeySchema) { + final KeyValue, Long> 
deserialized = KeyValue.pair( + SessionKeySchema.from(next.key.get(), stateSerdes.keyDeserializer(), "dummy"), + stateSerdes.valueDeserializer().deserialize("dummy", next.value) + ); + results.add(deserialized); + } else { + throw new IllegalStateException("Unrecognized serde schema"); + } + } + return results; + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/AbstractSessionBytesStoreTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/AbstractSessionBytesStoreTest.java new file mode 100644 index 0000000..78b9d63 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/AbstractSessionBytesStoreTest.java @@ -0,0 +1,813 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.clients.producer.ProducerRecord; +import org.apache.kafka.common.Metric; +import org.apache.kafka.common.MetricName; +import org.apache.kafka.common.metrics.Metrics; +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.common.utils.SystemTime; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.kstream.Windowed; +import org.apache.kafka.streams.kstream.internals.SessionWindow; +import org.apache.kafka.streams.processor.StateStoreContext; +import org.apache.kafka.streams.processor.internals.MockStreamsMetrics; +import org.apache.kafka.streams.processor.internals.testutil.LogCaptureAppender; +import org.apache.kafka.streams.state.KeyValueIterator; +import org.apache.kafka.streams.state.SessionStore; +import org.apache.kafka.test.InternalMockProcessorContext; +import org.apache.kafka.test.MockRecordCollector; +import org.apache.kafka.test.StreamsTestUtils; +import org.apache.kafka.test.TestUtils; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.Properties; + +import static java.util.Arrays.asList; +import static org.apache.kafka.common.utils.Utils.mkEntry; +import static org.apache.kafka.common.utils.Utils.mkMap; +import static org.apache.kafka.common.utils.Utils.toList; +import static org.apache.kafka.test.StreamsTestUtils.valuesToSet; +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.CoreMatchers.hasItem; +import static org.hamcrest.MatcherAssert.assertThat; +import static 
org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotEquals; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; + + +public abstract class AbstractSessionBytesStoreTest { + + static final long SEGMENT_INTERVAL = 60_000L; + static final long RETENTION_PERIOD = 10_000L; + + SessionStore sessionStore; + + private MockRecordCollector recordCollector; + + private InternalMockProcessorContext context; + + abstract SessionStore buildSessionStore(final long retentionPeriod, + final Serde keySerde, + final Serde valueSerde); + + @Before + public void setUp() { + sessionStore = buildSessionStore(RETENTION_PERIOD, Serdes.String(), Serdes.Long()); + recordCollector = new MockRecordCollector(); + context = new InternalMockProcessorContext<>( + TestUtils.tempDirectory(), + Serdes.String(), + Serdes.Long(), + recordCollector, + new ThreadCache( + new LogContext("testCache"), + 0, + new MockStreamsMetrics(new Metrics()))); + context.setTime(1L); + + sessionStore.init((StateStoreContext) context, sessionStore); + } + + @After + public void after() { + sessionStore.close(); + } + + @Test + public void shouldPutAndFindSessionsInRange() { + final String key = "a"; + final Windowed a1 = new Windowed<>(key, new SessionWindow(10, 10L)); + final Windowed a2 = new Windowed<>(key, new SessionWindow(500L, 1000L)); + sessionStore.put(a1, 1L); + sessionStore.put(a2, 2L); + sessionStore.put(new Windowed<>(key, new SessionWindow(1500L, 2000L)), 1L); + sessionStore.put(new Windowed<>(key, new SessionWindow(2500L, 3000L)), 2L); + + final List, Long>> expected = + Arrays.asList(KeyValue.pair(a1, 1L), KeyValue.pair(a2, 2L)); + + try (final KeyValueIterator, Long> values = sessionStore.findSessions(key, 0, 1000L) + ) { + assertEquals(expected, toList(values)); + } + + final List, Long>> expected2 = + Collections.singletonList(KeyValue.pair(a2, 2L)); + + try (final KeyValueIterator, Long> values2 = sessionStore.findSessions(key, 400L, 600L) + ) { + assertEquals(expected2, toList(values2)); + } + } + + @Test + public void shouldPutAndBackwardFindSessionsInRange() { + final String key = "a"; + final Windowed a1 = new Windowed<>(key, new SessionWindow(10, 10L)); + final Windowed a2 = new Windowed<>(key, new SessionWindow(500L, 1000L)); + sessionStore.put(a1, 1L); + sessionStore.put(a2, 2L); + sessionStore.put(new Windowed<>(key, new SessionWindow(1500L, 2000L)), 1L); + sessionStore.put(new Windowed<>(key, new SessionWindow(2500L, 3000L)), 2L); + + final LinkedList, Long>> expected = new LinkedList<>(); + expected.add(KeyValue.pair(a1, 1L)); + expected.add(KeyValue.pair(a2, 2L)); + + try (final KeyValueIterator, Long> values = sessionStore.backwardFindSessions(key, 0, 1000L)) { + assertEquals(toList(expected.descendingIterator()), toList(values)); + } + + final List, Long>> expected2 = + Collections.singletonList(KeyValue.pair(a2, 2L)); + + try (final KeyValueIterator, Long> values2 = sessionStore.backwardFindSessions(key, 400L, 600L)) { + assertEquals(expected2, toList(values2)); + } + } + + @Test + public void shouldFetchAllSessionsWithSameRecordKey() { + final LinkedList, Long>> expected = new LinkedList<>(); + expected.add(KeyValue.pair(new Windowed<>("a", new SessionWindow(0, 0)), 1L)); + expected.add(KeyValue.pair(new Windowed<>("a", new SessionWindow(10, 10)), 2L)); + expected.add(KeyValue.pair(new Windowed<>("a", new SessionWindow(100, 100)), 3L)); + 
expected.add(KeyValue.pair(new Windowed<>("a", new SessionWindow(1000, 1000)), 4L)); + + for (final KeyValue, Long> kv : expected) { + sessionStore.put(kv.key, kv.value); + } + + // add one that shouldn't appear in the results + sessionStore.put(new Windowed<>("aa", new SessionWindow(0, 0)), 5L); + + try (final KeyValueIterator, Long> values = sessionStore.fetch("a")) { + assertEquals(expected, toList(values)); + } + } + + @Test + public void shouldBackwardFetchAllSessionsWithSameRecordKey() { + final LinkedList, Long>> expected = new LinkedList<>(); + expected.add(KeyValue.pair(new Windowed<>("a", new SessionWindow(0, 0)), 1L)); + expected.add(KeyValue.pair(new Windowed<>("a", new SessionWindow(10, 10)), 2L)); + expected.add(KeyValue.pair(new Windowed<>("a", new SessionWindow(100, 100)), 3L)); + expected.add(KeyValue.pair(new Windowed<>("a", new SessionWindow(1000, 1000)), 4L)); + + for (final KeyValue, Long> kv : expected) { + sessionStore.put(kv.key, kv.value); + } + + // add one that shouldn't appear in the results + sessionStore.put(new Windowed<>("aa", new SessionWindow(0, 0)), 5L); + + try (final KeyValueIterator, Long> values = sessionStore.backwardFetch("a")) { + assertEquals(toList(expected.descendingIterator()), toList(values)); + } + } + + @Test + public void shouldFetchAllSessionsWithinKeyRange() { + final List, Long>> expected = new LinkedList<>(); + expected.add(KeyValue.pair(new Windowed<>("aa", new SessionWindow(10, 10)), 2L)); + expected.add(KeyValue.pair(new Windowed<>("aaa", new SessionWindow(100, 100)), 3L)); + expected.add(KeyValue.pair(new Windowed<>("aaaa", new SessionWindow(100, 100)), 6L)); + expected.add(KeyValue.pair(new Windowed<>("b", new SessionWindow(1000, 1000)), 4L)); + expected.add(KeyValue.pair(new Windowed<>("bb", new SessionWindow(1500, 2000)), 5L)); + + for (final KeyValue, Long> kv : expected) { + sessionStore.put(kv.key, kv.value); + } + + // add some that should only be fetched in infinite fetch + sessionStore.put(new Windowed<>("a", new SessionWindow(0, 0)), 1L); + sessionStore.put(new Windowed<>("bbb", new SessionWindow(2500, 3000)), 6L); + + try (final KeyValueIterator, Long> values = sessionStore.fetch("aa", "bb")) { + assertEquals(expected, toList(values)); + } + + try (final KeyValueIterator, Long> values = sessionStore.findSessions("aa", "bb", 0L, Long.MAX_VALUE)) { + assertEquals(expected, toList(values)); + } + + // infinite keyFrom fetch case + expected.add(0, KeyValue.pair(new Windowed<>("a", new SessionWindow(0, 0)), 1L)); + + try (final KeyValueIterator, Long> values = sessionStore.fetch(null, "bb")) { + assertEquals(expected, toList(values)); + } + + // remove the one added for unlimited start fetch case + expected.remove(0); + // infinite keyTo fetch case + expected.add(KeyValue.pair(new Windowed<>("bbb", new SessionWindow(2500, 3000)), 6L)); + + try (final KeyValueIterator, Long> values = sessionStore.fetch("aa", null)) { + assertEquals(expected, toList(values)); + } + + // fetch all case + expected.add(0, KeyValue.pair(new Windowed<>("a", new SessionWindow(0, 0)), 1L)); + + try (final KeyValueIterator, Long> values = sessionStore.fetch(null, null)) { + assertEquals(expected, toList(values)); + } + } + + @Test + public void shouldBackwardFetchAllSessionsWithinKeyRange() { + final LinkedList, Long>> expected = new LinkedList<>(); + expected.add(KeyValue.pair(new Windowed<>("aa", new SessionWindow(10, 10)), 2L)); + expected.add(KeyValue.pair(new Windowed<>("aaa", new SessionWindow(100, 100)), 3L)); + expected.add(KeyValue.pair(new 
Windowed<>("aaaa", new SessionWindow(100, 100)), 6L)); + expected.add(KeyValue.pair(new Windowed<>("b", new SessionWindow(1000, 1000)), 4L)); + expected.add(KeyValue.pair(new Windowed<>("bb", new SessionWindow(1500, 2000)), 5L)); + + for (final KeyValue, Long> kv : expected) { + sessionStore.put(kv.key, kv.value); + } + + // add some that should only be fetched in infinite fetch + sessionStore.put(new Windowed<>("a", new SessionWindow(0, 0)), 1L); + sessionStore.put(new Windowed<>("bbb", new SessionWindow(2500, 3000)), 6L); + + try (final KeyValueIterator, Long> values = sessionStore.backwardFetch("aa", "bb")) { + assertEquals(toList(expected.descendingIterator()), toList(values)); + } + + try (final KeyValueIterator, Long> values = sessionStore.backwardFindSessions("aa", "bb", 0L, Long.MAX_VALUE)) { + assertEquals(toList(expected.descendingIterator()), toList(values)); + } + + // infinite keyFrom fetch case + expected.add(0, KeyValue.pair(new Windowed<>("a", new SessionWindow(0, 0)), 1L)); + + try (final KeyValueIterator, Long> values = sessionStore.backwardFetch(null, "bb")) { + assertEquals(toList(expected.descendingIterator()), toList(values)); + } + + // remove the one added for unlimited start fetch case + expected.remove(0); + // infinite keyTo fetch case + expected.add(KeyValue.pair(new Windowed<>("bbb", new SessionWindow(2500, 3000)), 6L)); + + try (final KeyValueIterator, Long> values = sessionStore.backwardFetch("aa", null)) { + assertEquals(toList(expected.descendingIterator()), toList(values)); + } + + // fetch all case + expected.add(0, KeyValue.pair(new Windowed<>("a", new SessionWindow(0, 0)), 1L)); + + try (final KeyValueIterator, Long> values = sessionStore.backwardFetch(null, null)) { + assertEquals(toList(expected.descendingIterator()), toList(values)); + } + } + + @Test + public void shouldFetchExactSession() { + sessionStore.put(new Windowed<>("a", new SessionWindow(0, 4)), 1L); + sessionStore.put(new Windowed<>("aa", new SessionWindow(0, 3)), 2L); + sessionStore.put(new Windowed<>("aa", new SessionWindow(0, 4)), 3L); + sessionStore.put(new Windowed<>("aa", new SessionWindow(1, 4)), 4L); + sessionStore.put(new Windowed<>("aaa", new SessionWindow(0, 4)), 5L); + + final long result = sessionStore.fetchSession("aa", 0, 4); + assertEquals(3L, result); + } + + @Test + public void shouldReturnNullOnSessionNotFound() { + assertNull(sessionStore.fetchSession("any key", 0L, 5L)); + } + + @Test + public void shouldFindValuesWithinMergingSessionWindowRange() { + final String key = "a"; + sessionStore.put(new Windowed<>(key, new SessionWindow(0L, 0L)), 1L); + sessionStore.put(new Windowed<>(key, new SessionWindow(1000L, 1000L)), 2L); + + final List, Long>> expected = Arrays.asList( + KeyValue.pair(new Windowed<>(key, new SessionWindow(0L, 0L)), 1L), + KeyValue.pair(new Windowed<>(key, new SessionWindow(1000L, 1000L)), 2L)); + + try (final KeyValueIterator, Long> results = sessionStore.findSessions(key, -1, 1000L)) { + assertEquals(expected, toList(results)); + } + } + + @Test + public void shouldBackwardFindValuesWithinMergingSessionWindowRange() { + final String key = "a"; + sessionStore.put(new Windowed<>(key, new SessionWindow(0L, 0L)), 1L); + sessionStore.put(new Windowed<>(key, new SessionWindow(1000L, 1000L)), 2L); + + final LinkedList, Long>> expected = new LinkedList<>(); + expected.add(KeyValue.pair(new Windowed<>(key, new SessionWindow(0L, 0L)), 1L)); + expected.add(KeyValue.pair(new Windowed<>(key, new SessionWindow(1000L, 1000L)), 2L)); + + try (final 
KeyValueIterator, Long> results = sessionStore.backwardFindSessions(key, -1, 1000L)) { + assertEquals(toList(expected.descendingIterator()), toList(results)); + } + } + + @Test + public void shouldRemove() { + sessionStore.put(new Windowed<>("a", new SessionWindow(0, 1000)), 1L); + sessionStore.put(new Windowed<>("a", new SessionWindow(1500, 2500)), 2L); + + sessionStore.remove(new Windowed<>("a", new SessionWindow(0, 1000))); + + try (final KeyValueIterator, Long> results = sessionStore.findSessions("a", 0L, 1000L)) { + assertFalse(results.hasNext()); + } + + try (final KeyValueIterator, Long> results = sessionStore.findSessions("a", 1500L, 2500L)) { + assertTrue(results.hasNext()); + } + } + + @Test + public void shouldRemoveOnNullAggValue() { + sessionStore.put(new Windowed<>("a", new SessionWindow(0, 1000)), 1L); + sessionStore.put(new Windowed<>("a", new SessionWindow(1500, 2500)), 2L); + + sessionStore.put(new Windowed<>("a", new SessionWindow(0, 1000)), null); + + try (final KeyValueIterator, Long> results = sessionStore.findSessions("a", 0L, 1000L)) { + assertFalse(results.hasNext()); + } + + try (final KeyValueIterator, Long> results = sessionStore.findSessions("a", 1500L, 2500L)) { + assertTrue(results.hasNext()); + } + } + + @Test + public void shouldFindSessionsToMerge() { + final Windowed session1 = new Windowed<>("a", new SessionWindow(0, 100)); + final Windowed session2 = new Windowed<>("a", new SessionWindow(101, 200)); + final Windowed session3 = new Windowed<>("a", new SessionWindow(201, 300)); + final Windowed session4 = new Windowed<>("a", new SessionWindow(301, 400)); + final Windowed session5 = new Windowed<>("a", new SessionWindow(401, 500)); + sessionStore.put(session1, 1L); + sessionStore.put(session2, 2L); + sessionStore.put(session3, 3L); + sessionStore.put(session4, 4L); + sessionStore.put(session5, 5L); + + final List, Long>> expected = + Arrays.asList(KeyValue.pair(session2, 2L), KeyValue.pair(session3, 3L)); + + try (final KeyValueIterator, Long> results = sessionStore.findSessions("a", 150, 300)) { + assertEquals(expected, toList(results)); + } + } + + @Test + public void shouldBackwardFindSessionsToMerge() { + final Windowed session1 = new Windowed<>("a", new SessionWindow(0, 100)); + final Windowed session2 = new Windowed<>("a", new SessionWindow(101, 200)); + final Windowed session3 = new Windowed<>("a", new SessionWindow(201, 300)); + final Windowed session4 = new Windowed<>("a", new SessionWindow(301, 400)); + final Windowed session5 = new Windowed<>("a", new SessionWindow(401, 500)); + sessionStore.put(session1, 1L); + sessionStore.put(session2, 2L); + sessionStore.put(session3, 3L); + sessionStore.put(session4, 4L); + sessionStore.put(session5, 5L); + + final List, Long>> expected = + asList(KeyValue.pair(session3, 3L), KeyValue.pair(session2, 2L)); + + try (final KeyValueIterator, Long> results = sessionStore.backwardFindSessions("a", 150, 300)) { + assertEquals(expected, toList(results)); + } + } + + @Test + public void shouldFetchExactKeys() { + sessionStore.close(); + sessionStore = buildSessionStore(0x7a00000000000000L, Serdes.String(), Serdes.Long()); + sessionStore.init((StateStoreContext) context, sessionStore); + + sessionStore.put(new Windowed<>("a", new SessionWindow(0, 0)), 1L); + sessionStore.put(new Windowed<>("aa", new SessionWindow(0, 10)), 2L); + sessionStore.put(new Windowed<>("a", new SessionWindow(10, 20)), 3L); + sessionStore.put(new Windowed<>("aa", new SessionWindow(10, 20)), 4L); + sessionStore.put(new Windowed<>("a", + new 
SessionWindow(0x7a00000000000000L - 2, 0x7a00000000000000L - 1)), 5L); + + try (final KeyValueIterator, Long> iterator = + sessionStore.findSessions("a", 0, Long.MAX_VALUE) + ) { + assertThat(valuesToSet(iterator), equalTo(new HashSet<>(asList(1L, 3L, 5L)))); + } + + try (final KeyValueIterator, Long> iterator = + sessionStore.findSessions("aa", 0, Long.MAX_VALUE) + ) { + assertThat(valuesToSet(iterator), equalTo(new HashSet<>(asList(2L, 4L)))); + } + + try (final KeyValueIterator, Long> iterator = + sessionStore.findSessions("a", "aa", 0, Long.MAX_VALUE) + ) { + assertThat(valuesToSet(iterator), equalTo(new HashSet<>(asList(1L, 2L, 3L, 4L, 5L)))); + } + + try (final KeyValueIterator, Long> iterator = + sessionStore.findSessions("a", "aa", 10, 0) + ) { + assertThat(valuesToSet(iterator), equalTo(new HashSet<>(asList(2L)))); + } + + try (final KeyValueIterator, Long> iterator = + sessionStore.findSessions(null, "aa", 0, Long.MAX_VALUE) + ) { + assertThat(valuesToSet(iterator), equalTo(new HashSet<>(asList(1L, 2L, 3L, 4L, 5L)))); + } + + try (final KeyValueIterator, Long> iterator = + sessionStore.findSessions("a", null, 0, Long.MAX_VALUE) + ) { + assertThat(valuesToSet(iterator), equalTo(new HashSet<>(asList(1L, 2L, 3L, 4L, 5L)))); + } + + try (final KeyValueIterator, Long> iterator = + sessionStore.findSessions(null, null, 0, Long.MAX_VALUE) + ) { + assertThat(valuesToSet(iterator), equalTo(new HashSet<>(asList(1L, 2L, 3L, 4L, 5L)))); + } + } + + @Test + public void shouldBackwardFetchExactKeys() { + sessionStore.close(); + sessionStore = buildSessionStore(0x7a00000000000000L, Serdes.String(), Serdes.Long()); + sessionStore.init((StateStoreContext) context, sessionStore); + + sessionStore.put(new Windowed<>("a", new SessionWindow(0, 0)), 1L); + sessionStore.put(new Windowed<>("aa", new SessionWindow(0, 10)), 2L); + sessionStore.put(new Windowed<>("a", new SessionWindow(10, 20)), 3L); + sessionStore.put(new Windowed<>("aa", new SessionWindow(10, 20)), 4L); + sessionStore.put(new Windowed<>("a", + new SessionWindow(0x7a00000000000000L - 2, 0x7a00000000000000L - 1)), 5L); + + try (final KeyValueIterator, Long> iterator = + sessionStore.backwardFindSessions("a", 0, Long.MAX_VALUE) + ) { + assertThat(valuesToSet(iterator), equalTo(new HashSet<>(asList(1L, 3L, 5L)))); + } + + try (final KeyValueIterator, Long> iterator = + sessionStore.backwardFindSessions("aa", 0, Long.MAX_VALUE) + ) { + assertThat(valuesToSet(iterator), equalTo(new HashSet<>(asList(2L, 4L)))); + } + + try (final KeyValueIterator, Long> iterator = + sessionStore.backwardFindSessions("a", "aa", 0, Long.MAX_VALUE) + ) { + assertThat(valuesToSet(iterator), equalTo(new HashSet<>(asList(1L, 2L, 3L, 4L, 5L)))); + } + + try (final KeyValueIterator, Long> iterator = + sessionStore.backwardFindSessions("a", "aa", 10, 0) + ) { + assertThat(valuesToSet(iterator), equalTo(new HashSet<>(asList(2L)))); + } + + try (final KeyValueIterator, Long> iterator = + sessionStore.backwardFindSessions(null, "aa", 0, Long.MAX_VALUE) + ) { + assertThat(valuesToSet(iterator), equalTo(new HashSet<>(asList(1L, 2L, 3L, 4L, 5L)))); + } + + try (final KeyValueIterator, Long> iterator = + sessionStore.backwardFindSessions("a", null, 0, Long.MAX_VALUE) + ) { + assertThat(valuesToSet(iterator), equalTo(new HashSet<>(asList(1L, 2L, 3L, 4L, 5L)))); + } + + try (final KeyValueIterator, Long> iterator = + sessionStore.backwardFindSessions(null, null, 0, Long.MAX_VALUE) + ) { + assertThat(valuesToSet(iterator), equalTo(new HashSet<>(asList(1L, 2L, 3L, 4L, 5L)))); + } 
+ } + + @Test + public void shouldFetchAndIterateOverExactBinaryKeys() { + final SessionStore sessionStore = + buildSessionStore(RETENTION_PERIOD, Serdes.Bytes(), Serdes.String()); + + sessionStore.init((StateStoreContext) context, sessionStore); + + final Bytes key1 = Bytes.wrap(new byte[] {0}); + final Bytes key2 = Bytes.wrap(new byte[] {0, 0}); + final Bytes key3 = Bytes.wrap(new byte[] {0, 0, 0}); + + sessionStore.put(new Windowed<>(key1, new SessionWindow(1, 100)), "1"); + sessionStore.put(new Windowed<>(key2, new SessionWindow(2, 100)), "2"); + sessionStore.put(new Windowed<>(key3, new SessionWindow(3, 100)), "3"); + sessionStore.put(new Windowed<>(key1, new SessionWindow(4, 100)), "4"); + sessionStore.put(new Windowed<>(key2, new SessionWindow(5, 100)), "5"); + sessionStore.put(new Windowed<>(key3, new SessionWindow(6, 100)), "6"); + sessionStore.put(new Windowed<>(key1, new SessionWindow(7, 100)), "7"); + sessionStore.put(new Windowed<>(key2, new SessionWindow(8, 100)), "8"); + sessionStore.put(new Windowed<>(key3, new SessionWindow(9, 100)), "9"); + + final List expectedKey1 = asList("1", "4", "7"); + try (KeyValueIterator, String> iterator = sessionStore.findSessions(key1, 0L, Long.MAX_VALUE)) { + assertThat(valuesToSet(iterator), equalTo(new HashSet<>(expectedKey1))); + } + + final List expectedKey2 = asList("2", "5", "8"); + try (KeyValueIterator, String> iterator = sessionStore.findSessions(key2, 0L, Long.MAX_VALUE)) { + assertThat(valuesToSet(iterator), equalTo(new HashSet<>(expectedKey2))); + } + + final List expectedKey3 = asList("3", "6", "9"); + try (KeyValueIterator, String> iterator = sessionStore.findSessions(key3, 0L, Long.MAX_VALUE)) { + assertThat(valuesToSet(iterator), equalTo(new HashSet<>(expectedKey3))); + } + + sessionStore.close(); + } + + @Test + public void shouldBackwardFetchAndIterateOverExactBinaryKeys() { + final SessionStore sessionStore = + buildSessionStore(RETENTION_PERIOD, Serdes.Bytes(), Serdes.String()); + + sessionStore.init((StateStoreContext) context, sessionStore); + + final Bytes key1 = Bytes.wrap(new byte[] {0}); + final Bytes key2 = Bytes.wrap(new byte[] {0, 0}); + final Bytes key3 = Bytes.wrap(new byte[] {0, 0, 0}); + + sessionStore.put(new Windowed<>(key1, new SessionWindow(1, 100)), "1"); + sessionStore.put(new Windowed<>(key2, new SessionWindow(2, 100)), "2"); + sessionStore.put(new Windowed<>(key3, new SessionWindow(3, 100)), "3"); + sessionStore.put(new Windowed<>(key1, new SessionWindow(4, 100)), "4"); + sessionStore.put(new Windowed<>(key2, new SessionWindow(5, 100)), "5"); + sessionStore.put(new Windowed<>(key3, new SessionWindow(6, 100)), "6"); + sessionStore.put(new Windowed<>(key1, new SessionWindow(7, 100)), "7"); + sessionStore.put(new Windowed<>(key2, new SessionWindow(8, 100)), "8"); + sessionStore.put(new Windowed<>(key3, new SessionWindow(9, 100)), "9"); + + + final List expectedKey1 = asList("7", "4", "1"); + try (KeyValueIterator, String> iterator = sessionStore.backwardFindSessions(key1, 0L, Long.MAX_VALUE)) { + assertThat(valuesToSet(iterator), equalTo(new HashSet<>(expectedKey1))); + } + + final List expectedKey2 = asList("8", "5", "2"); + try (KeyValueIterator, String> iterator = sessionStore.backwardFindSessions(key2, 0L, Long.MAX_VALUE)) { + assertThat(valuesToSet(iterator), equalTo(new HashSet<>(expectedKey2))); + } + + final List expectedKey3 = asList("9", "6", "3"); + try (KeyValueIterator, String> iterator = sessionStore.backwardFindSessions(key3, 0L, Long.MAX_VALUE)) { + assertThat(valuesToSet(iterator), 
equalTo(new HashSet<>(expectedKey3))); + } + + sessionStore.close(); + } + + @Test + public void testIteratorPeek() { + sessionStore.put(new Windowed<>("a", new SessionWindow(0, 0)), 1L); + sessionStore.put(new Windowed<>("aa", new SessionWindow(0, 10)), 2L); + sessionStore.put(new Windowed<>("a", new SessionWindow(10, 20)), 3L); + sessionStore.put(new Windowed<>("aa", new SessionWindow(10, 20)), 4L); + + try (final KeyValueIterator, Long> iterator = sessionStore.findSessions("a", 0L, 20)) { + + assertEquals(iterator.peekNextKey(), new Windowed<>("a", new SessionWindow(0L, 0L))); + assertEquals(iterator.peekNextKey(), iterator.next().key); + assertEquals(iterator.peekNextKey(), iterator.next().key); + assertFalse(iterator.hasNext()); + } + } + + @Test + public void testIteratorPeekBackward() { + sessionStore.put(new Windowed<>("a", new SessionWindow(0, 0)), 1L); + sessionStore.put(new Windowed<>("aa", new SessionWindow(0, 10)), 2L); + sessionStore.put(new Windowed<>("a", new SessionWindow(10, 20)), 3L); + sessionStore.put(new Windowed<>("aa", new SessionWindow(10, 20)), 4L); + + try (final KeyValueIterator, Long> iterator = sessionStore.backwardFindSessions("a", 0L, 20)) { + + assertEquals(iterator.peekNextKey(), new Windowed<>("a", new SessionWindow(10L, 20L))); + assertEquals(iterator.peekNextKey(), iterator.next().key); + assertEquals(iterator.peekNextKey(), iterator.next().key); + assertFalse(iterator.hasNext()); + } + } + + @SuppressWarnings("unchecked") + @Test + public void shouldRestore() { + final List, Long>> expected = Arrays.asList( + KeyValue.pair(new Windowed<>("a", new SessionWindow(0, 0)), 1L), + KeyValue.pair(new Windowed<>("a", new SessionWindow(10, 10)), 2L), + KeyValue.pair(new Windowed<>("a", new SessionWindow(100, 100)), 3L), + KeyValue.pair(new Windowed<>("a", new SessionWindow(1000, 1000)), 4L)); + + for (final KeyValue, Long> kv : expected) { + sessionStore.put(kv.key, kv.value); + } + + try (final KeyValueIterator, Long> values = sessionStore.fetch("a")) { + assertEquals(expected, toList(values)); + } + + sessionStore.close(); + + try (final KeyValueIterator, Long> values = sessionStore.fetch("a")) { + assertEquals(Collections.emptyList(), toList(values)); + } + + + final List> changeLog = new ArrayList<>(); + for (final ProducerRecord record : recordCollector.collected()) { + changeLog.add(new KeyValue<>(((Bytes) record.key()).get(), (byte[]) record.value())); + } + + context.restore(sessionStore.name(), changeLog); + + try (final KeyValueIterator, Long> values = sessionStore.fetch("a")) { + assertEquals(expected, toList(values)); + } + } + + @Test + public void shouldCloseOpenIteratorsWhenStoreIsClosedAndNotThrowInvalidStateStoreExceptionOnHasNext() { + sessionStore.put(new Windowed<>("a", new SessionWindow(0, 0)), 1L); + sessionStore.put(new Windowed<>("b", new SessionWindow(10, 50)), 2L); + sessionStore.put(new Windowed<>("c", new SessionWindow(100, 500)), 3L); + + try (final KeyValueIterator, Long> iterator = sessionStore.fetch("a")) { + assertTrue(iterator.hasNext()); + sessionStore.close(); + + assertFalse(iterator.hasNext()); + } + } + + @Test + public void shouldReturnSameResultsForSingleKeyFindSessionsAndEqualKeyRangeFindSessions() { + sessionStore.put(new Windowed<>("a", new SessionWindow(0, 1)), 0L); + sessionStore.put(new Windowed<>("aa", new SessionWindow(2, 3)), 1L); + sessionStore.put(new Windowed<>("aa", new SessionWindow(4, 5)), 2L); + sessionStore.put(new Windowed<>("aaa", new SessionWindow(6, 7)), 3L); + + try (final KeyValueIterator, Long> 
singleKeyIterator = sessionStore.findSessions("aa", 0L, 10L); + final KeyValueIterator, Long> rangeIterator = sessionStore.findSessions("aa", "aa", 0L, 10L)) { + + assertEquals(singleKeyIterator.next(), rangeIterator.next()); + assertEquals(singleKeyIterator.next(), rangeIterator.next()); + assertFalse(singleKeyIterator.hasNext()); + assertFalse(rangeIterator.hasNext()); + } + } + + @Test + public void shouldLogAndMeasureExpiredRecords() { + final Properties streamsConfig = StreamsTestUtils.getStreamsConfig(); + final SessionStore sessionStore = buildSessionStore(RETENTION_PERIOD, Serdes.String(), Serdes.Long()); + final InternalMockProcessorContext context = new InternalMockProcessorContext( + TestUtils.tempDirectory(), + new StreamsConfig(streamsConfig), + recordCollector + ); + final Time time = new SystemTime(); + context.setTime(1L); + context.setSystemTimeMs(time.milliseconds()); + sessionStore.init((StateStoreContext) context, sessionStore); + + try (final LogCaptureAppender appender = LogCaptureAppender.createAndRegister()) { + // Advance stream time by inserting record with large enough timestamp that records with timestamp 0 are expired + // Note that rocksdb will only expire segments at a time (where segment interval = 60,000 for this retention period) + sessionStore.put(new Windowed<>("initial record", new SessionWindow(0, 2 * SEGMENT_INTERVAL)), 0L); + + // Try inserting a record with timestamp 0 -- should be dropped + sessionStore.put(new Windowed<>("late record", new SessionWindow(0, 0)), 0L); + sessionStore.put(new Windowed<>("another on-time record", new SessionWindow(0, 2 * SEGMENT_INTERVAL)), 0L); + + final List messages = appender.getMessages(); + assertThat(messages, hasItem("Skipping record for expired segment.")); + } + + final Map metrics = context.metrics().metrics(); + final String threadId = Thread.currentThread().getName(); + final Metric dropTotal; + final Metric dropRate; + dropTotal = metrics.get(new MetricName( + "dropped-records-total", + "stream-task-metrics", + "", + mkMap( + mkEntry("thread-id", threadId), + mkEntry("task-id", "0_0") + ) + )); + + dropRate = metrics.get(new MetricName( + "dropped-records-rate", + "stream-task-metrics", + "", + mkMap( + mkEntry("thread-id", threadId), + mkEntry("task-id", "0_0") + ) + )); + assertEquals(1.0, dropTotal.metricValue()); + assertNotEquals(0.0, dropRate.metricValue()); + + sessionStore.close(); + } + + @Test + public void shouldNotThrowExceptionRemovingNonexistentKey() { + sessionStore.remove(new Windowed<>("a", new SessionWindow(0, 1))); + } + + @Test + public void shouldThrowNullPointerExceptionOnFindSessionsNullKey() { + assertThrows(NullPointerException.class, () -> sessionStore.findSessions(null, 1L, 2L)); + } + + @Test + public void shouldThrowNullPointerExceptionOnFetchNullKey() { + assertThrows(NullPointerException.class, () -> sessionStore.fetch(null)); + } + + @Test + public void shouldThrowNullPointerExceptionOnRemoveNullKey() { + assertThrows(NullPointerException.class, () -> sessionStore.remove(null)); + } + + @Test + public void shouldThrowNullPointerExceptionOnPutNullKey() { + assertThrows(NullPointerException.class, () -> sessionStore.put(null, 1L)); + } + + @Test + public void shouldNotThrowInvalidRangeExceptionWithNegativeFromKey() { + final String keyFrom = Serdes.String().deserializer() + .deserialize("", Serdes.Integer().serializer().serialize("", -1)); + final String keyTo = Serdes.String().deserializer() + .deserialize("", Serdes.Integer().serializer().serialize("", 1)); + + try (final 
LogCaptureAppender appender = LogCaptureAppender.createAndRegister();
+             final KeyValueIterator<Windowed<String>, Long> iterator = sessionStore.findSessions(keyFrom, keyTo, 0L, 10L)) {
+            assertFalse(iterator.hasNext());
+
+            final List<String> messages = appender.getMessages();
+            assertThat(
+                messages,
+                hasItem("Returning empty iterator for fetch with invalid key range: from > to." +
+                    " This may be due to range arguments set in the wrong order, " +
+                    "or serdes that don't preserve ordering when lexicographically comparing the serialized bytes." +
+                    " Note that the built-in numerical serdes do not follow this for negative numbers")
+            );
+        }
+    }
+}
diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/AbstractWindowBytesStoreTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/AbstractWindowBytesStoreTest.java
new file mode 100644
index 0000000..e93f758
--- /dev/null
+++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/AbstractWindowBytesStoreTest.java
@@ -0,0 +1,1199 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.clients.producer.ProducerRecord; +import org.apache.kafka.common.Metric; +import org.apache.kafka.common.MetricName; +import org.apache.kafka.common.header.internals.RecordHeaders; +import org.apache.kafka.common.metrics.Metrics; +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.common.utils.SystemTime; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.kstream.Windowed; +import org.apache.kafka.streams.processor.StateStoreContext; +import org.apache.kafka.streams.processor.internals.MockStreamsMetrics; +import org.apache.kafka.streams.processor.internals.ProcessorRecordContext; +import org.apache.kafka.streams.processor.internals.testutil.LogCaptureAppender; +import org.apache.kafka.streams.state.KeyValueIterator; +import org.apache.kafka.streams.state.StateSerdes; +import org.apache.kafka.streams.state.WindowStore; +import org.apache.kafka.streams.state.WindowStoreIterator; +import org.apache.kafka.test.InternalMockProcessorContext; +import org.apache.kafka.test.MockRecordCollector; +import org.apache.kafka.test.StreamsTestUtils; +import org.apache.kafka.test.TestUtils; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import java.io.File; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Properties; +import java.util.Set; + +import static java.time.Instant.ofEpochMilli; +import static java.util.Arrays.asList; +import static org.apache.kafka.common.utils.Utils.mkEntry; +import static org.apache.kafka.common.utils.Utils.mkMap; +import static org.apache.kafka.test.StreamsTestUtils.toList; +import static org.apache.kafka.test.StreamsTestUtils.toSet; +import static org.apache.kafka.test.StreamsTestUtils.valuesToSet; +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.CoreMatchers.hasItem; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotEquals; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; + +public abstract class AbstractWindowBytesStoreTest { + + static final long WINDOW_SIZE = 3L; + static final long SEGMENT_INTERVAL = 60_000L; + static final long RETENTION_PERIOD = 2 * SEGMENT_INTERVAL; + + final long defaultStartTime = SEGMENT_INTERVAL - 4L; + + final KeyValue, String> zero = windowedPair(0, "zero", defaultStartTime); + final KeyValue, String> one = windowedPair(1, "one", defaultStartTime + 1); + final KeyValue, String> two = windowedPair(2, "two", defaultStartTime + 2); + final KeyValue, String> three = windowedPair(3, "three", defaultStartTime + 2); + final KeyValue, String> four = windowedPair(4, "four", defaultStartTime + 4); + final KeyValue, String> five = windowedPair(5, "five", defaultStartTime + 5); + + WindowStore windowStore; + InternalMockProcessorContext context; + MockRecordCollector recordCollector; + + final File baseDir = TestUtils.tempDirectory("test"); + 
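+    // state serdes matching the Integer keys and String values of the window store built in setup()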
private final StateSerdes serdes = new StateSerdes<>("", Serdes.Integer(), Serdes.String()); + + abstract WindowStore buildWindowStore(final long retentionPeriod, + final long windowSize, + final boolean retainDuplicates, + final Serde keySerde, + final Serde valueSerde); + + @Before + public void setup() { + windowStore = buildWindowStore(RETENTION_PERIOD, WINDOW_SIZE, false, Serdes.Integer(), Serdes.String()); + + recordCollector = new MockRecordCollector(); + context = new InternalMockProcessorContext<>( + baseDir, + Serdes.String(), + Serdes.Integer(), + recordCollector, + new ThreadCache( + new LogContext("testCache"), + 0, + new MockStreamsMetrics(new Metrics()))); + context.setTime(1L); + + windowStore.init((StateStoreContext) context, windowStore); + } + + @After + public void after() { + windowStore.close(); + } + + @Test + public void testRangeAndSinglePointFetch() { + putFirstBatch(windowStore, defaultStartTime, context); + + assertEquals( + new HashSet<>(Collections.singletonList("zero")), + valuesToSet(windowStore.fetch( + 0, + ofEpochMilli(defaultStartTime + 0 - WINDOW_SIZE), + ofEpochMilli(defaultStartTime + 0 + WINDOW_SIZE)))); + + putSecondBatch(windowStore, defaultStartTime, context); + + assertEquals("two+1", windowStore.fetch(2, defaultStartTime + 3L)); + assertEquals("two+2", windowStore.fetch(2, defaultStartTime + 4L)); + assertEquals("two+3", windowStore.fetch(2, defaultStartTime + 5L)); + assertEquals("two+4", windowStore.fetch(2, defaultStartTime + 6L)); + assertEquals("two+5", windowStore.fetch(2, defaultStartTime + 7L)); + assertEquals("two+6", windowStore.fetch(2, defaultStartTime + 8L)); + + assertEquals( + new HashSet<>(Collections.emptyList()), + valuesToSet(windowStore.fetch( + 2, + ofEpochMilli(defaultStartTime - 2L - WINDOW_SIZE), + ofEpochMilli(defaultStartTime - 2L + WINDOW_SIZE)))); + assertEquals( + new HashSet<>(Collections.singletonList("two")), + valuesToSet(windowStore.fetch( + 2, + ofEpochMilli(defaultStartTime - 1L - WINDOW_SIZE), + ofEpochMilli(defaultStartTime - 1L + WINDOW_SIZE)))); + assertEquals( + new HashSet<>(asList("two", "two+1")), + valuesToSet(windowStore.fetch( + 2, + ofEpochMilli(defaultStartTime - WINDOW_SIZE), + ofEpochMilli(defaultStartTime + WINDOW_SIZE)))); + assertEquals( + new HashSet<>(asList("two", "two+1", "two+2")), + valuesToSet(windowStore.fetch( + 2, + ofEpochMilli(defaultStartTime + 1L - WINDOW_SIZE), + ofEpochMilli(defaultStartTime + 1L + WINDOW_SIZE)))); + assertEquals( + new HashSet<>(asList("two", "two+1", "two+2", "two+3")), + valuesToSet(windowStore.fetch( + 2, + ofEpochMilli(defaultStartTime + 2L - WINDOW_SIZE), + ofEpochMilli(defaultStartTime + 2L + WINDOW_SIZE)))); + assertEquals( + new HashSet<>(asList("two", "two+1", "two+2", "two+3", "two+4")), + valuesToSet(windowStore.fetch( + 2, + ofEpochMilli(defaultStartTime + 3L - WINDOW_SIZE), + ofEpochMilli(defaultStartTime + 3L + WINDOW_SIZE)))); + assertEquals( + new HashSet<>(asList("two", "two+1", "two+2", "two+3", "two+4", "two+5")), + valuesToSet(windowStore.fetch( + 2, + ofEpochMilli(defaultStartTime + 4L - WINDOW_SIZE), + ofEpochMilli(defaultStartTime + 4L + WINDOW_SIZE)))); + assertEquals( + new HashSet<>(asList("two", "two+1", "two+2", "two+3", "two+4", "two+5", "two+6")), + valuesToSet(windowStore.fetch( + 2, + ofEpochMilli(defaultStartTime + 5L - WINDOW_SIZE), + ofEpochMilli(defaultStartTime + 5L + WINDOW_SIZE)))); + assertEquals( + new HashSet<>(asList("two+1", "two+2", "two+3", "two+4", "two+5", "two+6")), + valuesToSet(windowStore.fetch( + 2, + 
ofEpochMilli(defaultStartTime + 6L - WINDOW_SIZE), + ofEpochMilli(defaultStartTime + 6L + WINDOW_SIZE)))); + assertEquals( + new HashSet<>(asList("two+2", "two+3", "two+4", "two+5", "two+6")), + valuesToSet(windowStore.fetch( + 2, + ofEpochMilli(defaultStartTime + 7L - WINDOW_SIZE), + ofEpochMilli(defaultStartTime + 7L + WINDOW_SIZE)))); + assertEquals( + new HashSet<>(asList("two+3", "two+4", "two+5", "two+6")), + valuesToSet(windowStore.fetch( + 2, + ofEpochMilli(defaultStartTime + 8L - WINDOW_SIZE), + ofEpochMilli(defaultStartTime + 8L + WINDOW_SIZE)))); + assertEquals( + new HashSet<>(asList("two+4", "two+5", "two+6")), + valuesToSet(windowStore.fetch( + 2, + ofEpochMilli(defaultStartTime + 9L - WINDOW_SIZE), + ofEpochMilli(defaultStartTime + 9L + WINDOW_SIZE)))); + assertEquals( + new HashSet<>(asList("two+5", "two+6")), + valuesToSet(windowStore.fetch( + 2, + ofEpochMilli(defaultStartTime + 10L - WINDOW_SIZE), + ofEpochMilli(defaultStartTime + 10L + WINDOW_SIZE)))); + assertEquals( + new HashSet<>(Collections.singletonList("two+6")), + valuesToSet(windowStore.fetch( + 2, + ofEpochMilli(defaultStartTime + 11L - WINDOW_SIZE), + ofEpochMilli(defaultStartTime + 11L + WINDOW_SIZE)))); + assertEquals( + new HashSet<>(Collections.emptyList()), + valuesToSet(windowStore.fetch( + 2, + ofEpochMilli(defaultStartTime + 12L - WINDOW_SIZE), + ofEpochMilli(defaultStartTime + 12L + WINDOW_SIZE)))); + + // Flush the store and verify all current entries were properly flushed ... + windowStore.flush(); + + final List> changeLog = new ArrayList<>(); + for (final ProducerRecord record : recordCollector.collected()) { + changeLog.add(new KeyValue<>(((Bytes) record.key()).get(), (byte[]) record.value())); + } + + final Map> entriesByKey = entriesByKey(changeLog, defaultStartTime); + + assertEquals(Utils.mkSet("zero@0"), entriesByKey.get(0)); + assertEquals(Utils.mkSet("one@1"), entriesByKey.get(1)); + assertEquals( + Utils.mkSet("two@2", "two+1@3", "two+2@4", "two+3@5", "two+4@6", "two+5@7", "two+6@8"), + entriesByKey.get(2)); + assertEquals(Utils.mkSet("three@2"), entriesByKey.get(3)); + assertEquals(Utils.mkSet("four@4"), entriesByKey.get(4)); + assertEquals(Utils.mkSet("five@5"), entriesByKey.get(5)); + assertNull(entriesByKey.get(6)); + } + + @Test + public void shouldGetAll() { + putFirstBatch(windowStore, defaultStartTime, context); + + assertEquals( + asList(zero, one, two, three, four, five), + toList(windowStore.all()) + ); + } + + @Test + public void shouldGetAllNonDeletedRecords() { + // Add some records + windowStore.put(0, "zero", defaultStartTime + 0); + windowStore.put(1, "one", defaultStartTime + 1); + windowStore.put(2, "two", defaultStartTime + 2); + windowStore.put(3, "three", defaultStartTime + 3); + windowStore.put(4, "four", defaultStartTime + 4); + + // Delete some records + windowStore.put(1, null, defaultStartTime + 1); + windowStore.put(3, null, defaultStartTime + 3); + + // Only non-deleted records should appear in the all() iterator + assertEquals( + asList(zero, two, four), + toList(windowStore.all()) + ); + } + + @Test + public void shouldGetAllReturnTimestampOrderedRecords() { + // Add some records in different order + windowStore.put(4, "four", defaultStartTime + 4); + windowStore.put(0, "zero", defaultStartTime + 0); + windowStore.put(2, "two", defaultStartTime + 2); + windowStore.put(3, "three", defaultStartTime + 3); + windowStore.put(1, "one", defaultStartTime + 1); + + // Only non-deleted records should appear in the all() iterator + final KeyValue, String> three = 
windowedPair(3, "three", defaultStartTime + 3); + + assertEquals( + asList(zero, one, two, three, four), + toList(windowStore.all()) + ); + } + + @Test + public void shouldEarlyClosedIteratorStillGetAllRecords() { + windowStore.put(0, "zero", defaultStartTime + 0); + windowStore.put(1, "one", defaultStartTime + 1); + + final KeyValueIterator, String> it = windowStore.all(); + assertEquals(zero, it.next()); + it.close(); + + // A new all() iterator after a previous all() iterator was closed should return all elements. + assertEquals( + asList(zero, one), + toList(windowStore.all()) + ); + } + + @Test + public void shouldGetBackwardAll() { + putFirstBatch(windowStore, defaultStartTime, context); + + assertEquals( + asList(five, four, three, two, one, zero), + toList(windowStore.backwardAll()) + ); + } + + @Test + public void shouldFetchAllInTimeRange() { + putFirstBatch(windowStore, defaultStartTime, context); + + assertEquals( + asList(one, two, three, four), + toList(windowStore.fetchAll(ofEpochMilli(defaultStartTime + 1), ofEpochMilli(defaultStartTime + 4))) + ); + assertEquals( + asList(zero, one, two, three), + toList(windowStore.fetchAll(ofEpochMilli(defaultStartTime + 0), ofEpochMilli(defaultStartTime + 3))) + ); + assertEquals( + asList(one, two, three, four, five), + toList(windowStore.fetchAll(ofEpochMilli(defaultStartTime + 1), ofEpochMilli(defaultStartTime + 5))) + ); + } + + @Test + public void shouldBackwardFetchAllInTimeRange() { + putFirstBatch(windowStore, defaultStartTime, context); + + assertEquals( + asList(four, three, two, one), + toList(windowStore.backwardFetchAll(ofEpochMilli(defaultStartTime + 1), ofEpochMilli(defaultStartTime + 4))) + ); + assertEquals( + asList(three, two, one, zero), + toList(windowStore.backwardFetchAll(ofEpochMilli(defaultStartTime + 0), ofEpochMilli(defaultStartTime + 3))) + ); + assertEquals( + asList(five, four, three, two, one), + toList(windowStore.backwardFetchAll(ofEpochMilli(defaultStartTime + 1), ofEpochMilli(defaultStartTime + 5))) + ); + } + + @Test + public void testFetchRange() { + putFirstBatch(windowStore, defaultStartTime, context); + + assertEquals( + asList(zero, one), + toList(windowStore.fetch( + 0, + 1, + ofEpochMilli(defaultStartTime + 0L - WINDOW_SIZE), + ofEpochMilli(defaultStartTime + 0L + WINDOW_SIZE))) + ); + assertEquals( + Collections.singletonList(one), + toList(windowStore.fetch( + 1, + 1, + ofEpochMilli(defaultStartTime + 0L - WINDOW_SIZE), + ofEpochMilli(defaultStartTime + 0L + WINDOW_SIZE))) + ); + assertEquals( + asList(one, two, three), + toList(windowStore.fetch( + 1, + 3, + ofEpochMilli(defaultStartTime + 0L - WINDOW_SIZE), + ofEpochMilli(defaultStartTime + 0L + WINDOW_SIZE))) + ); + assertEquals( + asList(zero, one, two, three), + toList(windowStore.fetch( + 0, + 5, + ofEpochMilli(defaultStartTime + 0L - WINDOW_SIZE), + ofEpochMilli(defaultStartTime + 0L + WINDOW_SIZE))) + ); + assertEquals( + asList(zero, one, two, three, four, five), + toList(windowStore.fetch( + 0, + 5, + ofEpochMilli(defaultStartTime + 0L - WINDOW_SIZE), + ofEpochMilli(defaultStartTime + 0L + WINDOW_SIZE + 5L))) + ); + assertEquals( + asList(two, three, four, five), + toList(windowStore.fetch( + 0, + 5, + ofEpochMilli(defaultStartTime + 2L), + ofEpochMilli(defaultStartTime + 0L + WINDOW_SIZE + 5L))) + ); + assertEquals( + Collections.emptyList(), + toList(windowStore.fetch( + 4, + 5, + ofEpochMilli(defaultStartTime + 2L), + ofEpochMilli(defaultStartTime + WINDOW_SIZE))) + ); + assertEquals( + Collections.emptyList(), + 
toList(windowStore.fetch( + 0, + 3, + ofEpochMilli(defaultStartTime + 3L), + ofEpochMilli(defaultStartTime + WINDOW_SIZE + 5))) + ); + assertEquals( + asList(zero, one, two), + toList(windowStore.fetch( + null, + 2, + ofEpochMilli(defaultStartTime + 0L - WINDOW_SIZE), + ofEpochMilli(defaultStartTime + WINDOW_SIZE + 2L))) + ); + assertEquals( + asList(two, three, four, five), + toList(windowStore.fetch( + 2, + null, + ofEpochMilli(defaultStartTime + 0L - WINDOW_SIZE), + ofEpochMilli(defaultStartTime + WINDOW_SIZE + 5L))) + ); + assertEquals( + asList(zero, one, two, three, four, five), + toList(windowStore.fetch( + null, + null, + ofEpochMilli(defaultStartTime + 0L - WINDOW_SIZE), + ofEpochMilli(defaultStartTime + WINDOW_SIZE + 5L))) + ); + } + + @Test + public void testBackwardFetchRange() { + putFirstBatch(windowStore, defaultStartTime, context); + + assertEquals( + asList(one, zero), + toList(windowStore.backwardFetch( + 0, + 1, + ofEpochMilli(defaultStartTime + 0L - WINDOW_SIZE), + ofEpochMilli(defaultStartTime + 0L + WINDOW_SIZE))) + ); + assertEquals( + Collections.singletonList(one), + toList(windowStore.backwardFetch( + 1, + 1, + ofEpochMilli(defaultStartTime + 0L - WINDOW_SIZE), + ofEpochMilli(defaultStartTime + 0L + WINDOW_SIZE))) + ); + assertEquals( + asList(three, two, one), + toList(windowStore.backwardFetch( + 1, + 3, + ofEpochMilli(defaultStartTime + 0L - WINDOW_SIZE), + ofEpochMilli(defaultStartTime + 0L + WINDOW_SIZE))) + ); + assertEquals( + asList(three, two, one, zero), + toList(windowStore.backwardFetch( + 0, + 5, + ofEpochMilli(defaultStartTime + 0L - WINDOW_SIZE), + ofEpochMilli(defaultStartTime + 0L + WINDOW_SIZE))) + ); + assertEquals( + asList(five, four, three, two, one, zero), + toList(windowStore.backwardFetch( + 0, + 5, + ofEpochMilli(defaultStartTime + 0L - WINDOW_SIZE), + ofEpochMilli(defaultStartTime + 0L + WINDOW_SIZE + 5L))) + ); + assertEquals( + asList(five, four, three, two), + toList(windowStore.backwardFetch( + 0, + 5, + ofEpochMilli(defaultStartTime + 2L), + ofEpochMilli(defaultStartTime + 0L + WINDOW_SIZE + 5L))) + ); + assertEquals( + Collections.emptyList(), + toList(windowStore.backwardFetch( + 4, + 5, + ofEpochMilli(defaultStartTime + 2L), + ofEpochMilli(defaultStartTime + WINDOW_SIZE))) + ); + assertEquals( + Collections.emptyList(), + toList(windowStore.backwardFetch( + 0, + 3, + ofEpochMilli(defaultStartTime + 3L), + ofEpochMilli(defaultStartTime + WINDOW_SIZE + 5))) + ); + assertEquals( + asList(two, one, zero), + toList(windowStore.backwardFetch( + null, + 2, + ofEpochMilli(defaultStartTime + 0L - WINDOW_SIZE), + ofEpochMilli(defaultStartTime + WINDOW_SIZE + 2L))) + ); + assertEquals( + asList(five, four, three, two), + toList(windowStore.backwardFetch( + 2, + null, + ofEpochMilli(defaultStartTime + 0L - WINDOW_SIZE), + ofEpochMilli(defaultStartTime + WINDOW_SIZE + 5L))) + ); + assertEquals( + asList(five, four, three, two, one, zero), + toList(windowStore.backwardFetch( + null, + null, + ofEpochMilli(defaultStartTime + 0L - WINDOW_SIZE), + ofEpochMilli(defaultStartTime + WINDOW_SIZE + 5L))) + ); + } + + @Test + public void testPutAndFetchBefore() { + putFirstBatch(windowStore, defaultStartTime, context); + + assertEquals( + new HashSet<>(Collections.singletonList("zero")), + valuesToSet(windowStore.fetch(0, ofEpochMilli(defaultStartTime + 0L - WINDOW_SIZE), ofEpochMilli(defaultStartTime + 0L)))); + assertEquals( + new HashSet<>(Collections.singletonList("one")), + valuesToSet(windowStore.fetch(1, ofEpochMilli(defaultStartTime + 1L - 
WINDOW_SIZE), ofEpochMilli(defaultStartTime + 1L)))); + assertEquals( + new HashSet<>(Collections.singletonList("two")), + valuesToSet(windowStore.fetch(2, ofEpochMilli(defaultStartTime + 2L - WINDOW_SIZE), ofEpochMilli(defaultStartTime + 2L)))); + assertEquals( + new HashSet<>(Collections.singletonList("three")), + valuesToSet(windowStore.fetch(3, ofEpochMilli(defaultStartTime + 3L - WINDOW_SIZE), ofEpochMilli(defaultStartTime + 3L)))); + assertEquals( + new HashSet<>(Collections.singletonList("four")), + valuesToSet(windowStore.fetch(4, ofEpochMilli(defaultStartTime + 4L - WINDOW_SIZE), ofEpochMilli(defaultStartTime + 4L)))); + assertEquals( + new HashSet<>(Collections.singletonList("five")), + valuesToSet(windowStore.fetch(5, ofEpochMilli(defaultStartTime + 5L - WINDOW_SIZE), ofEpochMilli(defaultStartTime + 5L)))); + + putSecondBatch(windowStore, defaultStartTime, context); + + assertEquals( + new HashSet<>(Collections.emptyList()), + valuesToSet(windowStore.fetch(2, ofEpochMilli(defaultStartTime - 1L - WINDOW_SIZE), ofEpochMilli(defaultStartTime - 1L)))); + assertEquals( + new HashSet<>(Collections.emptyList()), + valuesToSet(windowStore.fetch(2, ofEpochMilli(defaultStartTime + 0L - WINDOW_SIZE), ofEpochMilli(defaultStartTime + 0L)))); + assertEquals( + new HashSet<>(Collections.emptyList()), + valuesToSet(windowStore.fetch(2, ofEpochMilli(defaultStartTime + 1L - WINDOW_SIZE), ofEpochMilli(defaultStartTime + 1L)))); + assertEquals( + new HashSet<>(Collections.singletonList("two")), + valuesToSet(windowStore.fetch(2, ofEpochMilli(defaultStartTime + 2L - WINDOW_SIZE), ofEpochMilli(defaultStartTime + 2L)))); + assertEquals( + new HashSet<>(asList("two", "two+1")), + valuesToSet(windowStore.fetch(2, ofEpochMilli(defaultStartTime + 3L - WINDOW_SIZE), ofEpochMilli(defaultStartTime + 3L)))); + assertEquals( + new HashSet<>(asList("two", "two+1", "two+2")), + valuesToSet(windowStore.fetch(2, ofEpochMilli(defaultStartTime + 4L - WINDOW_SIZE), ofEpochMilli(defaultStartTime + 4L)))); + assertEquals( + new HashSet<>(asList("two", "two+1", "two+2", "two+3")), + valuesToSet(windowStore.fetch(2, ofEpochMilli(defaultStartTime + 5L - WINDOW_SIZE), ofEpochMilli(defaultStartTime + 5L)))); + assertEquals( + new HashSet<>(asList("two+1", "two+2", "two+3", "two+4")), + valuesToSet(windowStore.fetch(2, ofEpochMilli(defaultStartTime + 6L - WINDOW_SIZE), ofEpochMilli(defaultStartTime + 6L)))); + assertEquals( + new HashSet<>(asList("two+2", "two+3", "two+4", "two+5")), + valuesToSet(windowStore.fetch(2, ofEpochMilli(defaultStartTime + 7L - WINDOW_SIZE), ofEpochMilli(defaultStartTime + 7L)))); + assertEquals( + new HashSet<>(asList("two+3", "two+4", "two+5", "two+6")), + valuesToSet(windowStore.fetch(2, ofEpochMilli(defaultStartTime + 8L - WINDOW_SIZE), ofEpochMilli(defaultStartTime + 8L)))); + assertEquals( + new HashSet<>(asList("two+4", "two+5", "two+6")), + valuesToSet(windowStore.fetch(2, ofEpochMilli(defaultStartTime + 9L - WINDOW_SIZE), ofEpochMilli(defaultStartTime + 9L)))); + assertEquals( + new HashSet<>(asList("two+5", "two+6")), + valuesToSet(windowStore.fetch(2, ofEpochMilli(defaultStartTime + 10L - WINDOW_SIZE), ofEpochMilli(defaultStartTime + 10L)))); + assertEquals( + new HashSet<>(Collections.singletonList("two+6")), + valuesToSet(windowStore.fetch(2, ofEpochMilli(defaultStartTime + 11L - WINDOW_SIZE), ofEpochMilli(defaultStartTime + 11L)))); + assertEquals( + new HashSet<>(Collections.emptyList()), + valuesToSet(windowStore.fetch(2, ofEpochMilli(defaultStartTime + 12L - WINDOW_SIZE), 
ofEpochMilli(defaultStartTime + 12L)))); + assertEquals( + new HashSet<>(Collections.emptyList()), + valuesToSet(windowStore.fetch(2, ofEpochMilli(defaultStartTime + 13L - WINDOW_SIZE), ofEpochMilli(defaultStartTime + 13L)))); + + // Flush the store and verify all current entries were properly flushed ... + windowStore.flush(); + + final List> changeLog = new ArrayList<>(); + for (final ProducerRecord record : recordCollector.collected()) { + changeLog.add(new KeyValue<>(((Bytes) record.key()).get(), (byte[]) record.value())); + } + + final Map> entriesByKey = entriesByKey(changeLog, defaultStartTime); + assertEquals(Utils.mkSet("zero@0"), entriesByKey.get(0)); + assertEquals(Utils.mkSet("one@1"), entriesByKey.get(1)); + assertEquals(Utils.mkSet("two@2", "two+1@3", "two+2@4", "two+3@5", "two+4@6", "two+5@7", "two+6@8"), entriesByKey.get(2)); + assertEquals(Utils.mkSet("three@2"), entriesByKey.get(3)); + assertEquals(Utils.mkSet("four@4"), entriesByKey.get(4)); + assertEquals(Utils.mkSet("five@5"), entriesByKey.get(5)); + assertNull(entriesByKey.get(6)); + } + + @Test + public void testPutAndFetchAfter() { + putFirstBatch(windowStore, defaultStartTime, context); + + assertEquals( + new HashSet<>(Collections.singletonList("zero")), + valuesToSet(windowStore.fetch(0, ofEpochMilli(defaultStartTime + 0L), + ofEpochMilli(defaultStartTime + 0L + WINDOW_SIZE)))); + assertEquals( + new HashSet<>(Collections.singletonList("one")), + valuesToSet(windowStore.fetch(1, ofEpochMilli(defaultStartTime + 1L), + ofEpochMilli(defaultStartTime + 1L + WINDOW_SIZE)))); + assertEquals( + new HashSet<>(Collections.singletonList("two")), + valuesToSet(windowStore.fetch(2, ofEpochMilli(defaultStartTime + 2L), + ofEpochMilli(defaultStartTime + 2L + WINDOW_SIZE)))); + assertEquals( + new HashSet<>(Collections.emptyList()), + valuesToSet(windowStore.fetch(3, ofEpochMilli(defaultStartTime + 3L), + ofEpochMilli(defaultStartTime + 3L + WINDOW_SIZE)))); + assertEquals( + new HashSet<>(Collections.singletonList("four")), + valuesToSet(windowStore.fetch(4, ofEpochMilli(defaultStartTime + 4L), + ofEpochMilli(defaultStartTime + 4L + WINDOW_SIZE)))); + assertEquals( + new HashSet<>(Collections.singletonList("five")), + valuesToSet(windowStore.fetch(5, ofEpochMilli(defaultStartTime + 5L), + ofEpochMilli(defaultStartTime + 5L + WINDOW_SIZE)))); + + putSecondBatch(windowStore, defaultStartTime, context); + + assertEquals( + new HashSet<>(Collections.emptyList()), + valuesToSet(windowStore.fetch(2, ofEpochMilli(defaultStartTime - 2L), + ofEpochMilli(defaultStartTime - 2L + WINDOW_SIZE)))); + assertEquals( + new HashSet<>(Collections.singletonList("two")), + valuesToSet(windowStore.fetch(2, ofEpochMilli(defaultStartTime - 1L), + ofEpochMilli(defaultStartTime - 1L + WINDOW_SIZE)))); + assertEquals( + new HashSet<>(asList("two", "two+1")), + valuesToSet(windowStore + .fetch(2, ofEpochMilli(defaultStartTime), ofEpochMilli(defaultStartTime + WINDOW_SIZE)))); + assertEquals( + new HashSet<>(asList("two", "two+1", "two+2")), + valuesToSet(windowStore.fetch(2, ofEpochMilli(defaultStartTime + 1L), + ofEpochMilli(defaultStartTime + 1L + WINDOW_SIZE)))); + assertEquals( + new HashSet<>(asList("two", "two+1", "two+2", "two+3")), + valuesToSet(windowStore.fetch(2, ofEpochMilli(defaultStartTime + 2L), + ofEpochMilli(defaultStartTime + 2L + WINDOW_SIZE)))); + assertEquals( + new HashSet<>(asList("two+1", "two+2", "two+3", "two+4")), + valuesToSet(windowStore.fetch(2, ofEpochMilli(defaultStartTime + 3L), + ofEpochMilli(defaultStartTime + 3L + 
WINDOW_SIZE)))); + assertEquals( + new HashSet<>(asList("two+2", "two+3", "two+4", "two+5")), + valuesToSet(windowStore.fetch(2, ofEpochMilli(defaultStartTime + 4L), + ofEpochMilli(defaultStartTime + 4L + WINDOW_SIZE)))); + assertEquals( + new HashSet<>(asList("two+3", "two+4", "two+5", "two+6")), + valuesToSet(windowStore.fetch(2, ofEpochMilli(defaultStartTime + 5L), + ofEpochMilli(defaultStartTime + 5L + WINDOW_SIZE)))); + assertEquals( + new HashSet<>(asList("two+4", "two+5", "two+6")), + valuesToSet(windowStore.fetch(2, ofEpochMilli(defaultStartTime + 6L), + ofEpochMilli(defaultStartTime + 6L + WINDOW_SIZE)))); + assertEquals( + new HashSet<>(asList("two+5", "two+6")), + valuesToSet(windowStore.fetch(2, ofEpochMilli(defaultStartTime + 7L), + ofEpochMilli(defaultStartTime + 7L + WINDOW_SIZE)))); + assertEquals( + new HashSet<>(Collections.singletonList("two+6")), + valuesToSet(windowStore.fetch(2, ofEpochMilli(defaultStartTime + 8L), + ofEpochMilli(defaultStartTime + 8L + WINDOW_SIZE)))); + assertEquals( + new HashSet<>(Collections.emptyList()), + valuesToSet(windowStore.fetch(2, ofEpochMilli(defaultStartTime + 9L), + ofEpochMilli(defaultStartTime + 9L + WINDOW_SIZE)))); + assertEquals( + new HashSet<>(Collections.emptyList()), + valuesToSet(windowStore.fetch(2, ofEpochMilli(defaultStartTime + 10L), + ofEpochMilli(defaultStartTime + 10L + WINDOW_SIZE)))); + assertEquals( + new HashSet<>(Collections.emptyList()), + valuesToSet(windowStore.fetch(2, ofEpochMilli(defaultStartTime + 11L), + ofEpochMilli(defaultStartTime + 11L + WINDOW_SIZE)))); + assertEquals( + new HashSet<>(Collections.emptyList()), + valuesToSet(windowStore.fetch(2, ofEpochMilli(defaultStartTime + 12L), + ofEpochMilli(defaultStartTime + 12L + WINDOW_SIZE)))); + + // Flush the store and verify all current entries were properly flushed ... 
+ windowStore.flush(); + + final List> changeLog = new ArrayList<>(); + for (final ProducerRecord record : recordCollector.collected()) { + changeLog.add(new KeyValue<>(((Bytes) record.key()).get(), (byte[]) record.value())); + } + + final Map> entriesByKey = entriesByKey(changeLog, defaultStartTime); + + assertEquals(Utils.mkSet("zero@0"), entriesByKey.get(0)); + assertEquals(Utils.mkSet("one@1"), entriesByKey.get(1)); + assertEquals( + Utils.mkSet("two@2", "two+1@3", "two+2@4", "two+3@5", "two+4@6", "two+5@7", "two+6@8"), + entriesByKey.get(2)); + assertEquals(Utils.mkSet("three@2"), entriesByKey.get(3)); + assertEquals(Utils.mkSet("four@4"), entriesByKey.get(4)); + assertEquals(Utils.mkSet("five@5"), entriesByKey.get(5)); + assertNull(entriesByKey.get(6)); + } + + @Test + public void testPutSameKeyTimestamp() { + windowStore.close(); + windowStore = buildWindowStore(RETENTION_PERIOD, WINDOW_SIZE, true, Serdes.Integer(), Serdes.String()); + windowStore.init((StateStoreContext) context, windowStore); + + windowStore.put(0, "zero", defaultStartTime); + + assertEquals( + new HashSet<>(Collections.singletonList("zero")), + valuesToSet(windowStore.fetch(0, ofEpochMilli(defaultStartTime - WINDOW_SIZE), + ofEpochMilli(defaultStartTime + WINDOW_SIZE)))); + + windowStore.put(0, "zero", defaultStartTime); + windowStore.put(0, "zero+", defaultStartTime); + windowStore.put(0, "zero++", defaultStartTime); + + assertEquals( + new HashSet<>(asList("zero", "zero", "zero+", "zero++")), + valuesToSet(windowStore.fetch( + 0, + ofEpochMilli(defaultStartTime - WINDOW_SIZE), + ofEpochMilli(defaultStartTime + WINDOW_SIZE)))); + assertEquals( + new HashSet<>(asList("zero", "zero", "zero+", "zero++")), + valuesToSet(windowStore.fetch( + 0, + ofEpochMilli(defaultStartTime + 1L - WINDOW_SIZE), + ofEpochMilli(defaultStartTime + 1L + WINDOW_SIZE)))); + assertEquals( + new HashSet<>(asList("zero", "zero", "zero+", "zero++")), + valuesToSet(windowStore.fetch( + 0, + ofEpochMilli(defaultStartTime + 2L - WINDOW_SIZE), + ofEpochMilli(defaultStartTime + 2L + WINDOW_SIZE)))); + assertEquals( + new HashSet<>(asList("zero", "zero", "zero+", "zero++")), + valuesToSet(windowStore.fetch( + 0, + ofEpochMilli(defaultStartTime + 3L - WINDOW_SIZE), + ofEpochMilli(defaultStartTime + 3L + WINDOW_SIZE)))); + assertEquals( + new HashSet<>(Collections.emptyList()), + valuesToSet(windowStore.fetch( + 0, + ofEpochMilli(defaultStartTime + 4L - WINDOW_SIZE), + ofEpochMilli(defaultStartTime + 4L + WINDOW_SIZE)))); + + // Flush the store and verify all current entries were properly flushed ... 
+        windowStore.flush();
+
+        final List<KeyValue<byte[], byte[]>> changeLog = new ArrayList<>();
+        for (final ProducerRecord<Object, Object> record : recordCollector.collected()) {
+            changeLog.add(new KeyValue<>(((Bytes) record.key()).get(), (byte[]) record.value()));
+        }
+
+        final Map<Integer, Set<String>> entriesByKey = entriesByKey(changeLog, defaultStartTime);
+
+        assertEquals(Utils.mkSet("zero@0", "zero@0", "zero+@0", "zero++@0"), entriesByKey.get(0));
+    }
+
+    @Test
+    public void shouldCloseOpenIteratorsWhenStoreIsClosedAndNotThrowInvalidStateStoreExceptionOnHasNext() {
+        windowStore.put(1, "one", 1L);
+        windowStore.put(1, "two", 2L);
+        windowStore.put(1, "three", 3L);
+
+        try (final WindowStoreIterator<String> iterator = windowStore.fetch(1, ofEpochMilli(1L), ofEpochMilli(3L))) {
+            assertTrue(iterator.hasNext());
+            windowStore.close();
+
+            assertFalse(iterator.hasNext());
+        }
+    }
+
+    @Test
+    public void shouldFetchAndIterateOverExactKeys() {
+        final long windowSize = 0x7a00000000000000L;
+        final long retentionPeriod = 0x7a00000000000000L;
+        final WindowStore<String, String> windowStore = buildWindowStore(retentionPeriod,
+                                                                         windowSize,
+                                                                         false,
+                                                                         Serdes.String(),
+                                                                         Serdes.String());
+
+        windowStore.init((StateStoreContext) context, windowStore);
+
+        windowStore.put("a", "0001", 0);
+        windowStore.put("aa", "0002", 0);
+        windowStore.put("a", "0003", 1);
+        windowStore.put("aa", "0004", 1);
+        windowStore.put("a", "0005", 0x7a00000000000000L - 1);
+
+        final Set<String> expected = new HashSet<>(asList("0001", "0003", "0005"));
+        assertThat(
+            valuesToSet(windowStore.fetch("a", ofEpochMilli(0), ofEpochMilli(Long.MAX_VALUE))),
+            equalTo(expected)
+        );
+
+        Set<KeyValue<Windowed<String>, String>> set =
+            toSet(windowStore.fetch("a", "a", ofEpochMilli(0), ofEpochMilli(Long.MAX_VALUE)));
+        assertThat(
+            set,
+            equalTo(new HashSet<>(asList(
+                windowedPair("a", "0001", 0, windowSize),
+                windowedPair("a", "0003", 1, windowSize),
+                windowedPair("a", "0005", 0x7a00000000000000L - 1, windowSize)
+            )))
+        );
+
+        set = toSet(windowStore.fetch("aa", "aa", ofEpochMilli(0), ofEpochMilli(Long.MAX_VALUE)));
+        assertThat(
+            set,
+            equalTo(new HashSet<>(asList(
+                windowedPair("aa", "0002", 0, windowSize),
+                windowedPair("aa", "0004", 1, windowSize)
+            )))
+        );
+        windowStore.close();
+    }
+
+    @Test
+    public void testDeleteAndUpdate() {
+        final long currentTime = 0;
+        windowStore.put(1, "one", currentTime);
+        windowStore.put(1, "one v2", currentTime);
+
+        WindowStoreIterator<String> iterator = windowStore.fetch(1, 0, currentTime);
+        assertEquals(new KeyValue<>(currentTime, "one v2"), iterator.next());
+
+        windowStore.put(1, null, currentTime);
+        iterator = windowStore.fetch(1, 0, currentTime);
+        assertFalse(iterator.hasNext());
+    }
+
+    @Test
+    public void shouldReturnNullOnWindowNotFound() {
+        assertNull(windowStore.fetch(1, 0L));
+    }
+
+    @Test
+    public void shouldThrowNullPointerExceptionOnPutNullKey() {
+        assertThrows(NullPointerException.class, () -> windowStore.put(null, "anyValue", 0L));
+    }
+
+    @Test
+    public void shouldThrowNullPointerExceptionOnGetNullKey() {
+        assertThrows(NullPointerException.class, () -> windowStore.fetch(null, ofEpochMilli(1L), ofEpochMilli(2L)));
+    }
+
+    @Test
+    public void shouldFetchAndIterateOverExactBinaryKeys() {
+        final WindowStore<Bytes, String> windowStore = buildWindowStore(RETENTION_PERIOD,
+                                                                        WINDOW_SIZE,
+                                                                        true,
+                                                                        Serdes.Bytes(),
+                                                                        Serdes.String());
+        windowStore.init((StateStoreContext) context, windowStore);
+
+        final Bytes key1 = Bytes.wrap(new byte[] {0});
+        final Bytes key2 = Bytes.wrap(new byte[] {0, 0});
+        final Bytes key3 = Bytes.wrap(new byte[] {0, 0, 0});
+        windowStore.put(key1, "1", 0);
+
windowStore.put(key2, "2", 0); + windowStore.put(key3, "3", 0); + windowStore.put(key1, "4", 1); + windowStore.put(key2, "5", 1); + windowStore.put(key3, "6", 59999); + windowStore.put(key1, "7", 59999); + windowStore.put(key2, "8", 59999); + windowStore.put(key3, "9", 59999); + + final Set expectedKey1 = new HashSet<>(asList("1", "4", "7")); + assertThat( + valuesToSet(windowStore.fetch(key1, ofEpochMilli(0), ofEpochMilli(Long.MAX_VALUE))), + equalTo(expectedKey1) + ); + final Set expectedKey2 = new HashSet<>(asList("2", "5", "8")); + assertThat( + valuesToSet(windowStore.fetch(key2, ofEpochMilli(0), ofEpochMilli(Long.MAX_VALUE))), + equalTo(expectedKey2) + ); + final Set expectedKey3 = new HashSet<>(asList("3", "6", "9")); + assertThat( + valuesToSet(windowStore.fetch(key3, ofEpochMilli(0), ofEpochMilli(Long.MAX_VALUE))), + equalTo(expectedKey3) + ); + + windowStore.close(); + } + + @Test + public void shouldReturnSameResultsForSingleKeyFetchAndEqualKeyRangeFetch() { + windowStore.put(1, "one", 0L); + windowStore.put(2, "two", 1L); + windowStore.put(2, "two", 2L); + windowStore.put(3, "three", 3L); + + try (final WindowStoreIterator singleKeyIterator = windowStore.fetch(2, 0L, 5L); + final KeyValueIterator, String> keyRangeIterator = windowStore.fetch(2, 2, 0L, 5L)) { + + assertEquals(singleKeyIterator.next().value, keyRangeIterator.next().value); + assertEquals(singleKeyIterator.next().value, keyRangeIterator.next().value); + assertFalse(singleKeyIterator.hasNext()); + assertFalse(keyRangeIterator.hasNext()); + } + } + + @Test + public void shouldNotThrowInvalidRangeExceptionWithNegativeFromKey() { + try (final LogCaptureAppender appender = LogCaptureAppender.createAndRegister(); + final KeyValueIterator, String> iterator = windowStore.fetch(-1, 1, 0L, 10L)) { + assertFalse(iterator.hasNext()); + + final List messages = appender.getMessages(); + assertThat( + messages, + hasItem("Returning empty iterator for fetch with invalid key range: from > to." + + " This may be due to range arguments set in the wrong order, " + + "or serdes that don't preserve ordering when lexicographically comparing the serialized bytes." 
+ + " Note that the built-in numerical serdes do not follow this for negative numbers") + ); + } + } + + @Test + public void shouldLogAndMeasureExpiredRecords() { + final Properties streamsConfig = StreamsTestUtils.getStreamsConfig(); + final WindowStore windowStore = + buildWindowStore(RETENTION_PERIOD, WINDOW_SIZE, false, Serdes.Integer(), Serdes.String()); + final InternalMockProcessorContext context = new InternalMockProcessorContext( + TestUtils.tempDirectory(), + new StreamsConfig(streamsConfig), + recordCollector + ); + final Time time = new SystemTime(); + context.setSystemTimeMs(time.milliseconds()); + context.setTime(1L); + windowStore.init((StateStoreContext) context, windowStore); + + try (final LogCaptureAppender appender = LogCaptureAppender.createAndRegister()) { + // Advance stream time by inserting record with large enough timestamp that records with timestamp 0 are expired + windowStore.put(1, "initial record", 2 * RETENTION_PERIOD); + + // Try inserting a record with timestamp 0 -- should be dropped + windowStore.put(1, "late record", 0L); + windowStore.put(1, "another on-time record", RETENTION_PERIOD + 1); + + final List messages = appender.getMessages(); + assertThat(messages, hasItem("Skipping record for expired segment.")); + } + + final Map metrics = context.metrics().metrics(); + + final String threadId = Thread.currentThread().getName(); + final Metric dropTotal; + final Metric dropRate; + dropTotal = metrics.get(new MetricName( + "dropped-records-total", + "stream-task-metrics", + "", + mkMap( + mkEntry("thread-id", threadId), + mkEntry("task-id", "0_0") + ) + )); + + dropRate = metrics.get(new MetricName( + "dropped-records-rate", + "stream-task-metrics", + "", + mkMap( + mkEntry("thread-id", threadId), + mkEntry("task-id", "0_0") + ) + )); + assertEquals(1.0, dropTotal.metricValue()); + assertNotEquals(0.0, dropRate.metricValue()); + + windowStore.close(); + } + + @Test + public void shouldNotThrowExceptionWhenFetchRangeIsExpired() { + windowStore.put(1, "one", 0L); + windowStore.put(1, "two", 4 * RETENTION_PERIOD); + + try (final WindowStoreIterator iterator = windowStore.fetch(1, 0L, 10L)) { + + assertFalse(iterator.hasNext()); + } + } + + @Test + public void testWindowIteratorPeek() { + final long currentTime = 0; + windowStore.put(1, "one", currentTime); + + try (final KeyValueIterator, String> iterator = windowStore.fetchAll(0L, currentTime)) { + + assertTrue(iterator.hasNext()); + final Windowed nextKey = iterator.peekNextKey(); + + assertEquals(iterator.peekNextKey(), nextKey); + assertEquals(iterator.peekNextKey(), iterator.next().key); + assertFalse(iterator.hasNext()); + } + } + + @Test + public void testValueIteratorPeek() { + windowStore.put(1, "one", 0L); + + try (final WindowStoreIterator iterator = windowStore.fetch(1, 0L, 10L)) { + + assertTrue(iterator.hasNext()); + final Long nextKey = iterator.peekNextKey(); + + assertEquals(iterator.peekNextKey(), nextKey); + assertEquals(iterator.peekNextKey(), iterator.next().key); + assertFalse(iterator.hasNext()); + } + } + + @Test + public void shouldNotThrowConcurrentModificationException() { + long currentTime = 0; + windowStore.put(1, "one", currentTime); + + currentTime += WINDOW_SIZE * 10; + windowStore.put(1, "two", currentTime); + + try (final KeyValueIterator, String> iterator = windowStore.all()) { + + currentTime += WINDOW_SIZE * 10; + windowStore.put(1, "three", currentTime); + + currentTime += WINDOW_SIZE * 10; + windowStore.put(2, "four", currentTime); + + // Iterator should return all 
records in store and not throw exception b/c some were added after fetch
+            assertEquals(windowedPair(1, "one", 0), iterator.next());
+            assertEquals(windowedPair(1, "two", WINDOW_SIZE * 10), iterator.next());
+            assertEquals(windowedPair(1, "three", WINDOW_SIZE * 20), iterator.next());
+            assertEquals(windowedPair(2, "four", WINDOW_SIZE * 30), iterator.next());
+            assertFalse(iterator.hasNext());
+        }
+    }
+
+    @Test
+    public void testFetchDuplicates() {
+        windowStore.close();
+        windowStore = buildWindowStore(RETENTION_PERIOD, WINDOW_SIZE, true, Serdes.Integer(), Serdes.String());
+        windowStore.init((StateStoreContext) context, windowStore);
+
+        long currentTime = 0;
+        windowStore.put(1, "one", currentTime);
+        windowStore.put(1, "one-2", currentTime);
+
+        currentTime += WINDOW_SIZE * 10;
+        windowStore.put(1, "two", currentTime);
+        windowStore.put(1, "two-2", currentTime);
+
+        currentTime += WINDOW_SIZE * 10;
+        windowStore.put(1, "three", currentTime);
+        windowStore.put(1, "three-2", currentTime);
+
+        try (final WindowStoreIterator<String> iterator = windowStore.fetch(1, 0, WINDOW_SIZE * 10)) {
+
+            assertEquals(new KeyValue<>(0L, "one"), iterator.next());
+            assertEquals(new KeyValue<>(0L, "one-2"), iterator.next());
+            assertEquals(new KeyValue<>(WINDOW_SIZE * 10, "two"), iterator.next());
+            assertEquals(new KeyValue<>(WINDOW_SIZE * 10, "two-2"), iterator.next());
+            assertFalse(iterator.hasNext());
+        }
+    }
+
+
+    private void putFirstBatch(final WindowStore<Integer, String> store,
+                               @SuppressWarnings("SameParameterValue") final long startTime,
+                               final InternalMockProcessorContext context) {
+        context.setRecordContext(createRecordContext(startTime));
+        store.put(0, "zero", startTime);
+        store.put(1, "one", startTime + 1L);
+        store.put(2, "two", startTime + 2L);
+        store.put(3, "three", startTime + 2L);
+        store.put(4, "four", startTime + 4L);
+        store.put(5, "five", startTime + 5L);
+    }
+
+    private void putSecondBatch(final WindowStore<Integer, String> store,
+                                @SuppressWarnings("SameParameterValue") final long startTime,
+                                final InternalMockProcessorContext context) {
+        store.put(2, "two+1", startTime + 3L);
+        store.put(2, "two+2", startTime + 4L);
+        store.put(2, "two+3", startTime + 5L);
+        store.put(2, "two+4", startTime + 6L);
+        store.put(2, "two+5", startTime + 7L);
+        store.put(2, "two+6", startTime + 8L);
+    }
+
+    long extractStoreTimestamp(final byte[] binaryKey) {
+        return WindowKeySchema.extractStoreTimestamp(binaryKey);
+    }
+
+    <K> K extractStoreKey(final byte[] binaryKey,
+                          final StateSerdes<K, ?> serdes) {
+        return WindowKeySchema.extractStoreKey(binaryKey, serdes);
+    }
+
+    private Map<Integer, Set<String>> entriesByKey(final List<KeyValue<byte[], byte[]>> changeLog,
+                                                   @SuppressWarnings("SameParameterValue") final long startTime) {
+        final HashMap<Integer, Set<String>> entriesByKey = new HashMap<>();
+
+        for (final KeyValue<byte[], byte[]> entry : changeLog) {
+            final long timestamp = extractStoreTimestamp(entry.key);
+
+            final Integer key = extractStoreKey(entry.key, serdes);
+            final String value = entry.value == null ? null : serdes.valueFrom(entry.value);
+
+            final Set<String> entries = entriesByKey.computeIfAbsent(key, k -> new HashSet<>());
+            entries.add(value + "@" + (timestamp - startTime));
+        }
+
+        return entriesByKey;
+    }
+
+    protected static <K, V> KeyValue<Windowed<K>, V> windowedPair(final K key, final V value, final long timestamp) {
+        return windowedPair(key, value, timestamp, WINDOW_SIZE);
+    }
+
+    private static <K, V> KeyValue<Windowed<K>, V> windowedPair(final K key, final V value, final long timestamp, final long windowSize) {
+        return KeyValue.pair(new Windowed<>(key, WindowKeySchema.timeWindowForSize(timestamp, windowSize)), value);
+    }
+
+    private ProcessorRecordContext createRecordContext(final long time) {
+        return new ProcessorRecordContext(time, 0, 0, "topic", new RecordHeaders());
+    }
+}
diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/BlockBasedTableConfigWithAccessibleCacheTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/BlockBasedTableConfigWithAccessibleCacheTest.java
new file mode 100644
index 0000000..1904789
--- /dev/null
+++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/BlockBasedTableConfigWithAccessibleCacheTest.java
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +package org.apache.kafka.streams.state.internals; + +import org.junit.Test; +import org.rocksdb.BlockBasedTableConfig; +import org.rocksdb.Cache; +import org.rocksdb.LRUCache; +import org.rocksdb.RocksDB; + +import static org.hamcrest.CoreMatchers.nullValue; +import static org.hamcrest.CoreMatchers.sameInstance; +import static org.hamcrest.MatcherAssert.assertThat; + +public class BlockBasedTableConfigWithAccessibleCacheTest { + + static { + RocksDB.loadLibrary(); + } + + @Test + public void shouldReturnNoBlockCacheIfNoneIsSet() { + final BlockBasedTableConfigWithAccessibleCache configWithAccessibleCache = + new BlockBasedTableConfigWithAccessibleCache(); + + assertThat(configWithAccessibleCache.blockCache(), nullValue()); + } + + @Test + public void shouldSetBlockCacheAndMakeItAccessible() { + final BlockBasedTableConfigWithAccessibleCache configWithAccessibleCache = + new BlockBasedTableConfigWithAccessibleCache(); + final Cache blockCache = new LRUCache(1024); + + final BlockBasedTableConfig updatedConfig = configWithAccessibleCache.setBlockCache(blockCache); + + assertThat(updatedConfig, sameInstance(configWithAccessibleCache)); + assertThat(configWithAccessibleCache.blockCache(), sameInstance(blockCache)); + } +} \ No newline at end of file diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/BufferValueTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/BufferValueTest.java new file mode 100644 index 0000000..a8cc5ac --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/BufferValueTest.java @@ -0,0 +1,208 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.common.header.internals.RecordHeaders; +import org.apache.kafka.streams.processor.internals.ProcessorRecordContext; +import org.junit.Test; + +import java.nio.ByteBuffer; +import java.util.Arrays; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.is; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotEquals; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertSame; + +public class BufferValueTest { + @Test + public void shouldDeduplicateNullValues() { + final BufferValue bufferValue = new BufferValue(null, null, null, null); + assertSame(bufferValue.priorValue(), bufferValue.oldValue()); + } + + @Test + public void shouldDeduplicateIndenticalValues() { + final byte[] bytes = {(byte) 0}; + final BufferValue bufferValue = new BufferValue(bytes, bytes, null, null); + assertSame(bufferValue.priorValue(), bufferValue.oldValue()); + } + + @Test + public void shouldDeduplicateEqualValues() { + final BufferValue bufferValue = new BufferValue(new byte[] {(byte) 0}, new byte[] {(byte) 0}, null, null); + assertSame(bufferValue.priorValue(), bufferValue.oldValue()); + } + + @Test + public void shouldStoreDifferentValues() { + final byte[] priorValue = {(byte) 0}; + final byte[] oldValue = {(byte) 1}; + final BufferValue bufferValue = new BufferValue(priorValue, oldValue, null, null); + assertSame(priorValue, bufferValue.priorValue()); + assertSame(oldValue, bufferValue.oldValue()); + assertNotEquals(bufferValue.priorValue(), bufferValue.oldValue()); + } + + @Test + public void shouldStoreDifferentValuesWithPriorNull() { + final byte[] priorValue = null; + final byte[] oldValue = {(byte) 1}; + final BufferValue bufferValue = new BufferValue(priorValue, oldValue, null, null); + assertNull(bufferValue.priorValue()); + assertSame(oldValue, bufferValue.oldValue()); + assertNotEquals(bufferValue.priorValue(), bufferValue.oldValue()); + } + + @Test + public void shouldStoreDifferentValuesWithOldNull() { + final byte[] priorValue = {(byte) 0}; + final byte[] oldValue = null; + final BufferValue bufferValue = new BufferValue(priorValue, oldValue, null, null); + assertSame(priorValue, bufferValue.priorValue()); + assertNull(bufferValue.oldValue()); + assertNotEquals(bufferValue.priorValue(), bufferValue.oldValue()); + } + + @Test + public void shouldAccountForDeduplicationInSizeEstimate() { + final ProcessorRecordContext context = new ProcessorRecordContext(0L, 0L, 0, "topic", new RecordHeaders()); + assertEquals(25L, new BufferValue(null, null, null, context).residentMemorySizeEstimate()); + assertEquals(26L, new BufferValue(new byte[] {(byte) 0}, null, null, context).residentMemorySizeEstimate()); + assertEquals(26L, new BufferValue(null, new byte[] {(byte) 0}, null, context).residentMemorySizeEstimate()); + assertEquals(26L, new BufferValue(new byte[] {(byte) 0}, new byte[] {(byte) 0}, null, context).residentMemorySizeEstimate()); + assertEquals(27L, new BufferValue(new byte[] {(byte) 0}, new byte[] {(byte) 1}, null, context).residentMemorySizeEstimate()); + + // new value should get counted, but doesn't get deduplicated + assertEquals(28L, new BufferValue(new byte[] {(byte) 0}, new byte[] {(byte) 1}, new byte[] {(byte) 0}, context).residentMemorySizeEstimate()); + } + + @Test + public void shouldSerializeNulls() { + final ProcessorRecordContext context = new ProcessorRecordContext(0L, 0L, 0, "topic", new 
RecordHeaders()); + final byte[] serializedContext = context.serialize(); + final byte[] bytes = new BufferValue(null, null, null, context).serialize(0).array(); + final byte[] withoutContext = Arrays.copyOfRange(bytes, serializedContext.length, bytes.length); + + assertThat(withoutContext, is(ByteBuffer.allocate(Integer.BYTES * 3).putInt(-1).putInt(-1).putInt(-1).array())); + } + + @Test + public void shouldSerializePrior() { + final ProcessorRecordContext context = new ProcessorRecordContext(0L, 0L, 0, "topic", new RecordHeaders()); + final byte[] serializedContext = context.serialize(); + final byte[] priorValue = {(byte) 5}; + final byte[] bytes = new BufferValue(priorValue, null, null, context).serialize(0).array(); + final byte[] withoutContext = Arrays.copyOfRange(bytes, serializedContext.length, bytes.length); + + assertThat(withoutContext, is(ByteBuffer.allocate(Integer.BYTES * 3 + 1).putInt(1).put(priorValue).putInt(-1).putInt(-1).array())); + } + + @Test + public void shouldSerializeOld() { + final ProcessorRecordContext context = new ProcessorRecordContext(0L, 0L, 0, "topic", new RecordHeaders()); + final byte[] serializedContext = context.serialize(); + final byte[] oldValue = {(byte) 5}; + final byte[] bytes = new BufferValue(null, oldValue, null, context).serialize(0).array(); + final byte[] withoutContext = Arrays.copyOfRange(bytes, serializedContext.length, bytes.length); + + assertThat(withoutContext, is(ByteBuffer.allocate(Integer.BYTES * 3 + 1).putInt(-1).putInt(1).put(oldValue).putInt(-1).array())); + } + + @Test + public void shouldSerializeNew() { + final ProcessorRecordContext context = new ProcessorRecordContext(0L, 0L, 0, "topic", new RecordHeaders()); + final byte[] serializedContext = context.serialize(); + final byte[] newValue = {(byte) 5}; + final byte[] bytes = new BufferValue(null, null, newValue, context).serialize(0).array(); + final byte[] withoutContext = Arrays.copyOfRange(bytes, serializedContext.length, bytes.length); + + assertThat(withoutContext, is(ByteBuffer.allocate(Integer.BYTES * 3 + 1).putInt(-1).putInt(-1).putInt(1).put(newValue).array())); + } + + @Test + public void shouldCompactDuplicates() { + final ProcessorRecordContext context = new ProcessorRecordContext(0L, 0L, 0, "topic", new RecordHeaders()); + final byte[] serializedContext = context.serialize(); + final byte[] duplicate = {(byte) 5}; + final byte[] bytes = new BufferValue(duplicate, duplicate, null, context).serialize(0).array(); + final byte[] withoutContext = Arrays.copyOfRange(bytes, serializedContext.length, bytes.length); + + assertThat(withoutContext, is(ByteBuffer.allocate(Integer.BYTES * 3 + 1).putInt(1).put(duplicate).putInt(-2).putInt(-1).array())); + } + + @Test + public void shouldDeserializePrior() { + final ProcessorRecordContext context = new ProcessorRecordContext(0L, 0L, 0, "topic", new RecordHeaders()); + final byte[] serializedContext = context.serialize(); + final byte[] priorValue = {(byte) 5}; + final ByteBuffer serialValue = + ByteBuffer + .allocate(serializedContext.length + Integer.BYTES * 3 + priorValue.length) + .put(serializedContext).putInt(1).put(priorValue).putInt(-1).putInt(-1); + serialValue.position(0); + + final BufferValue deserialize = BufferValue.deserialize(serialValue); + assertThat(deserialize, is(new BufferValue(priorValue, null, null, context))); + } + + @Test + public void shouldDeserializeOld() { + final ProcessorRecordContext context = new ProcessorRecordContext(0L, 0L, 0, "topic", new RecordHeaders()); + final byte[] 
serializedContext = context.serialize(); + final byte[] oldValue = {(byte) 5}; + final ByteBuffer serialValue = + ByteBuffer + .allocate(serializedContext.length + Integer.BYTES * 3 + oldValue.length) + .put(serializedContext).putInt(-1).putInt(1).put(oldValue).putInt(-1); + serialValue.position(0); + + assertThat(BufferValue.deserialize(serialValue), is(new BufferValue(null, oldValue, null, context))); + } + + @Test + public void shouldDeserializeNew() { + final ProcessorRecordContext context = new ProcessorRecordContext(0L, 0L, 0, "topic", new RecordHeaders()); + final byte[] serializedContext = context.serialize(); + final byte[] newValue = {(byte) 5}; + final ByteBuffer serialValue = + ByteBuffer + .allocate(serializedContext.length + Integer.BYTES * 3 + newValue.length) + .put(serializedContext).putInt(-1).putInt(-1).putInt(1).put(newValue); + serialValue.position(0); + + assertThat(BufferValue.deserialize(serialValue), is(new BufferValue(null, null, newValue, context))); + } + + @Test + public void shouldDeserializeCompactedDuplicates() { + final ProcessorRecordContext context = new ProcessorRecordContext(0L, 0L, 0, "topic", new RecordHeaders()); + final byte[] serializedContext = context.serialize(); + final byte[] duplicate = {(byte) 5}; + final ByteBuffer serialValue = + ByteBuffer + .allocate(serializedContext.length + Integer.BYTES * 3 + duplicate.length) + .put(serializedContext).putInt(1).put(duplicate).putInt(-2).putInt(-1); + serialValue.position(0); + + final BufferValue bufferValue = BufferValue.deserialize(serialValue); + assertThat(bufferValue, is(new BufferValue(duplicate, duplicate, null, context))); + assertSame(bufferValue.priorValue(), bufferValue.oldValue()); + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/CacheFlushListenerStub.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/CacheFlushListenerStub.java new file mode 100644 index 0000000..b214739 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/CacheFlushListenerStub.java @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+package org.apache.kafka.streams.state.internals;
+
+import org.apache.kafka.common.serialization.Deserializer;
+import org.apache.kafka.streams.kstream.internals.Change;
+import org.apache.kafka.streams.processor.api.Record;
+
+import java.util.HashMap;
+import java.util.Map;
+
+public class CacheFlushListenerStub<K, V> implements CacheFlushListener<byte[], byte[]> {
+    private final Deserializer<K> keyDeserializer;
+    private final Deserializer<V> valueDeserializer;
+    final Map<K, Change<V>> forwarded = new HashMap<>();
+
+    CacheFlushListenerStub(final Deserializer<K> keyDeserializer,
+                           final Deserializer<V> valueDeserializer) {
+        this.keyDeserializer = keyDeserializer;
+        this.valueDeserializer = valueDeserializer;
+    }
+
+    @Override
+    public void apply(final byte[] key,
+                      final byte[] newValue,
+                      final byte[] oldValue,
+                      final long timestamp) {
+        forwarded.put(
+            keyDeserializer.deserialize(null, key),
+            new Change<>(
+                valueDeserializer.deserialize(null, newValue),
+                valueDeserializer.deserialize(null, oldValue)
+            )
+        );
+    }
+
+    @Override
+    public void apply(final Record<byte[], Change<byte[]>> record) {
+        forwarded.put(
+            keyDeserializer.deserialize(null, record.key()),
+            new Change<>(
+                valueDeserializer.deserialize(null, record.value().newValue),
+                valueDeserializer.deserialize(null, record.value().oldValue)
+            )
+        );
+    }
+}
diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/CachingInMemoryKeyValueStoreTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/CachingInMemoryKeyValueStoreTest.java
new file mode 100644
index 0000000..dadd3de
--- /dev/null
+++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/CachingInMemoryKeyValueStoreTest.java
@@ -0,0 +1,602 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.common.header.internals.RecordHeaders; +import org.apache.kafka.common.metrics.Metrics; +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.common.serialization.StringDeserializer; +import org.apache.kafka.common.serialization.StringSerializer; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.errors.InvalidStateStoreException; +import org.apache.kafka.streams.processor.ProcessorContext; +import org.apache.kafka.streams.processor.StateStoreContext; +import org.apache.kafka.streams.processor.internals.MockStreamsMetrics; +import org.apache.kafka.streams.processor.internals.ProcessorRecordContext; +import org.apache.kafka.streams.state.KeyValueIterator; +import org.apache.kafka.streams.state.KeyValueStore; +import org.apache.kafka.streams.state.StoreBuilder; +import org.apache.kafka.streams.state.Stores; +import org.apache.kafka.test.InternalMockProcessorContext; +import org.apache.kafka.test.TestUtils; +import org.easymock.EasyMock; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +import static org.apache.kafka.streams.state.internals.ThreadCacheTest.memoryCacheEntrySize; +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; + +public class CachingInMemoryKeyValueStoreTest extends AbstractKeyValueStoreTest { + + private final static String TOPIC = "topic"; + private static final String CACHE_NAMESPACE = "0_0-store-name"; + private final int maxCacheSizeBytes = 150; + private InternalMockProcessorContext context; + private CachingKeyValueStore store; + private KeyValueStore underlyingStore; + private ThreadCache cache; + private CacheFlushListenerStub cacheFlushListener; + + @Before + public void setUp() { + final String storeName = "store"; + underlyingStore = new InMemoryKeyValueStore(storeName); + cacheFlushListener = new CacheFlushListenerStub<>(new StringDeserializer(), new StringDeserializer()); + store = new CachingKeyValueStore(underlyingStore); + store.setFlushListener(cacheFlushListener, false); + cache = new ThreadCache(new LogContext("testCache "), maxCacheSizeBytes, new MockStreamsMetrics(new Metrics())); + context = new InternalMockProcessorContext<>(null, null, null, null, cache); + context.setRecordContext(new ProcessorRecordContext(10, 0, 0, TOPIC, new RecordHeaders())); + store.init((StateStoreContext) context, null); + } + + @After + public void after() { + super.after(); + } + + @SuppressWarnings("unchecked") + @Override + protected KeyValueStore createKeyValueStore(final StateStoreContext context) { + final StoreBuilder> storeBuilder = Stores.keyValueStoreBuilder( + Stores.persistentKeyValueStore("cache-store"), + (Serde) context.keySerde(), + (Serde) context.valueSerde()) + .withCachingEnabled(); + + final KeyValueStore store = storeBuilder.build(); + store.init(context, store); + return store; + } + + @SuppressWarnings("deprecation") + @Test + 
public void shouldDelegateDeprecatedInit() { + final KeyValueStore inner = EasyMock.mock(InMemoryKeyValueStore.class); + final CachingKeyValueStore outer = new CachingKeyValueStore(inner); + EasyMock.expect(inner.name()).andStubReturn("store"); + inner.init((ProcessorContext) context, outer); + EasyMock.expectLastCall(); + EasyMock.replay(inner); + outer.init((ProcessorContext) context, outer); + EasyMock.verify(inner); + } + + @Test + public void shouldDelegateInit() { + final KeyValueStore inner = EasyMock.mock(InMemoryKeyValueStore.class); + final CachingKeyValueStore outer = new CachingKeyValueStore(inner); + EasyMock.expect(inner.name()).andStubReturn("store"); + inner.init((StateStoreContext) context, outer); + EasyMock.expectLastCall(); + EasyMock.replay(inner); + outer.init((StateStoreContext) context, outer); + EasyMock.verify(inner); + } + + @Test + public void shouldSetFlushListener() { + assertTrue(store.setFlushListener(null, true)); + assertTrue(store.setFlushListener(null, false)); + } + + @Test + public void shouldAvoidFlushingDeletionsWithoutDirtyKeys() { + final int added = addItemsToCache(); + // all dirty entries should have been flushed + assertEquals(added, underlyingStore.approximateNumEntries()); + assertEquals(added, cacheFlushListener.forwarded.size()); + + store.put(bytesKey("key"), bytesValue("value")); + assertEquals(added, underlyingStore.approximateNumEntries()); + assertEquals(added, cacheFlushListener.forwarded.size()); + + store.put(bytesKey("key"), null); + store.flush(); + assertEquals(added, underlyingStore.approximateNumEntries()); + assertEquals(added, cacheFlushListener.forwarded.size()); + } + + @Test + public void shouldCloseWrappedStoreAndCacheAfterErrorDuringCacheFlush() { + setUpCloseTests(); + EasyMock.reset(cache); + cache.flush(CACHE_NAMESPACE); + EasyMock.expectLastCall().andThrow(new RuntimeException("Simulating an error on flush")); + EasyMock.replay(cache); + EasyMock.reset(underlyingStore); + underlyingStore.close(); + EasyMock.replay(underlyingStore); + + assertThrows(RuntimeException.class, store::close); + EasyMock.verify(cache, underlyingStore); + } + + @Test + public void shouldCloseWrappedStoreAfterErrorDuringCacheClose() { + setUpCloseTests(); + EasyMock.reset(cache); + cache.flush(CACHE_NAMESPACE); + cache.close(CACHE_NAMESPACE); + EasyMock.expectLastCall().andThrow(new RuntimeException("Simulating an error on close")); + EasyMock.replay(cache); + EasyMock.reset(underlyingStore); + underlyingStore.close(); + EasyMock.replay(underlyingStore); + + assertThrows(RuntimeException.class, store::close); + EasyMock.verify(cache, underlyingStore); + } + + @Test + public void shouldCloseCacheAfterErrorDuringStateStoreClose() { + setUpCloseTests(); + EasyMock.reset(cache); + cache.flush(CACHE_NAMESPACE); + cache.close(CACHE_NAMESPACE); + EasyMock.replay(cache); + EasyMock.reset(underlyingStore); + underlyingStore.close(); + EasyMock.expectLastCall().andThrow(new RuntimeException("Simulating an error on close")); + EasyMock.replay(underlyingStore); + + assertThrows(RuntimeException.class, store::close); + EasyMock.verify(cache, underlyingStore); + } + + private void setUpCloseTests() { + underlyingStore = EasyMock.createNiceMock(KeyValueStore.class); + EasyMock.expect(underlyingStore.name()).andStubReturn("store-name"); + EasyMock.expect(underlyingStore.isOpen()).andStubReturn(true); + EasyMock.replay(underlyingStore); + store = new CachingKeyValueStore(underlyingStore); + cache = EasyMock.niceMock(ThreadCache.class); + context = new 
InternalMockProcessorContext<>(TestUtils.tempDirectory(), null, null, null, cache); + context.setRecordContext(new ProcessorRecordContext(10, 0, 0, TOPIC, new RecordHeaders())); + store.init((StateStoreContext) context, store); + } + + @Test + public void shouldPutGetToFromCache() { + store.put(bytesKey("key"), bytesValue("value")); + store.put(bytesKey("key2"), bytesValue("value2")); + assertThat(store.get(bytesKey("key")), equalTo(bytesValue("value"))); + assertThat(store.get(bytesKey("key2")), equalTo(bytesValue("value2"))); + // nothing evicted so underlying store should be empty + assertEquals(2, cache.size()); + assertEquals(0, underlyingStore.approximateNumEntries()); + } + + private byte[] bytesValue(final String value) { + return value.getBytes(); + } + + private Bytes bytesKey(final String key) { + return Bytes.wrap(key.getBytes()); + } + + @Test + public void shouldFlushEvictedItemsIntoUnderlyingStore() { + final int added = addItemsToCache(); + // all dirty entries should have been flushed + assertEquals(added, underlyingStore.approximateNumEntries()); + assertEquals(added, store.approximateNumEntries()); + assertNotNull(underlyingStore.get(Bytes.wrap("0".getBytes()))); + } + + @Test + public void shouldForwardDirtyItemToListenerWhenEvicted() { + final int numRecords = addItemsToCache(); + assertEquals(numRecords, cacheFlushListener.forwarded.size()); + } + + @Test + public void shouldForwardDirtyItemsWhenFlushCalled() { + store.put(bytesKey("1"), bytesValue("a")); + store.flush(); + assertEquals("a", cacheFlushListener.forwarded.get("1").newValue); + assertNull(cacheFlushListener.forwarded.get("1").oldValue); + } + + @Test + public void shouldForwardOldValuesWhenEnabled() { + store.setFlushListener(cacheFlushListener, true); + store.put(bytesKey("1"), bytesValue("a")); + store.flush(); + assertEquals("a", cacheFlushListener.forwarded.get("1").newValue); + assertNull(cacheFlushListener.forwarded.get("1").oldValue); + store.put(bytesKey("1"), bytesValue("b")); + store.put(bytesKey("1"), bytesValue("c")); + store.flush(); + assertEquals("c", cacheFlushListener.forwarded.get("1").newValue); + assertEquals("a", cacheFlushListener.forwarded.get("1").oldValue); + store.put(bytesKey("1"), null); + store.flush(); + assertNull(cacheFlushListener.forwarded.get("1").newValue); + assertEquals("c", cacheFlushListener.forwarded.get("1").oldValue); + cacheFlushListener.forwarded.clear(); + store.put(bytesKey("1"), bytesValue("a")); + store.put(bytesKey("1"), bytesValue("b")); + store.put(bytesKey("1"), null); + store.flush(); + assertNull(cacheFlushListener.forwarded.get("1")); + cacheFlushListener.forwarded.clear(); + } + + @Test + public void shouldNotForwardOldValuesWhenDisabled() { + store.put(bytesKey("1"), bytesValue("a")); + store.flush(); + assertEquals("a", cacheFlushListener.forwarded.get("1").newValue); + assertNull(cacheFlushListener.forwarded.get("1").oldValue); + store.put(bytesKey("1"), bytesValue("b")); + store.flush(); + assertEquals("b", cacheFlushListener.forwarded.get("1").newValue); + assertNull(cacheFlushListener.forwarded.get("1").oldValue); + store.put(bytesKey("1"), null); + store.flush(); + assertNull(cacheFlushListener.forwarded.get("1").newValue); + assertNull(cacheFlushListener.forwarded.get("1").oldValue); + cacheFlushListener.forwarded.clear(); + store.put(bytesKey("1"), bytesValue("a")); + store.put(bytesKey("1"), bytesValue("b")); + store.put(bytesKey("1"), null); + store.flush(); + assertNull(cacheFlushListener.forwarded.get("1")); + 
cacheFlushListener.forwarded.clear(); + } + + @Test + public void shouldIterateAllStoredItems() { + final int items = addItemsToCache(); + final List results = new ArrayList<>(); + + try (final KeyValueIterator all = store.all()) { + while (all.hasNext()) { + results.add(all.next().key); + } + } + + assertEquals(items, results.size()); + assertEquals(Arrays.asList( + Bytes.wrap("0".getBytes()), + Bytes.wrap("1".getBytes()), + Bytes.wrap("2".getBytes()) + ), results); + + } + + @Test + public void shouldReverseIterateAllStoredItems() { + final int items = addItemsToCache(); + final List results = new ArrayList<>(); + + try (final KeyValueIterator all = store.reverseAll()) { + while (all.hasNext()) { + results.add(all.next().key); + } + } + + assertEquals(items, results.size()); + assertEquals(Arrays.asList( + Bytes.wrap("2".getBytes()), + Bytes.wrap("1".getBytes()), + Bytes.wrap("0".getBytes()) + ), results); + + } + + @Test + public void shouldIterateOverRange() { + final int items = addItemsToCache(); + final List results = new ArrayList<>(); + + try (final KeyValueIterator range = + store.range(bytesKey(String.valueOf(0)), bytesKey(String.valueOf(items)))) { + while (range.hasNext()) { + results.add(range.next().key); + } + } + + assertEquals(items, results.size()); + assertEquals(Arrays.asList( + Bytes.wrap("0".getBytes()), + Bytes.wrap("1".getBytes()), + Bytes.wrap("2".getBytes()) + ), results); + } + + @Test + public void shouldReverseIterateOverRange() { + final int items = addItemsToCache(); + final List results = new ArrayList<>(); + + try (final KeyValueIterator range = + store.reverseRange(bytesKey(String.valueOf(0)), bytesKey(String.valueOf(items)))) { + while (range.hasNext()) { + results.add(range.next().key); + } + } + + assertEquals(items, results.size()); + assertEquals(Arrays.asList( + Bytes.wrap("2".getBytes()), + Bytes.wrap("1".getBytes()), + Bytes.wrap("0".getBytes()) + ), results); + } + + @Test + public void shouldGetRecordsWithPrefixKey() { + final List> entries = new ArrayList<>(); + entries.add(new KeyValue<>(bytesKey("p11"), bytesValue("2"))); + entries.add(new KeyValue<>(bytesKey("k1"), bytesValue("1"))); + entries.add(new KeyValue<>(bytesKey("k2"), bytesValue("2"))); + entries.add(new KeyValue<>(bytesKey("p2"), bytesValue("2"))); + entries.add(new KeyValue<>(bytesKey("p1"), bytesValue("2"))); + entries.add(new KeyValue<>(bytesKey("p0"), bytesValue("2"))); + + store.putAll(entries); + + final List keys = new ArrayList<>(); + final List values = new ArrayList<>(); + int numberOfKeysReturned = 0; + + try (final KeyValueIterator keysWithPrefix = store.prefixScan("p1", new StringSerializer())) { + while (keysWithPrefix.hasNext()) { + final KeyValue next = keysWithPrefix.next(); + keys.add(next.key.toString()); + values.add(new String(next.value)); + numberOfKeysReturned++; + } + } + + assertThat(numberOfKeysReturned, is(2)); + assertThat(keys, is(Arrays.asList("p1", "p11"))); + assertThat(values, is(Arrays.asList("2", "2"))); + + } + + @Test + public void shouldGetRecordsWithPrefixKeyExcludingNextLargestKey() { + final List> entries = new ArrayList<>(); + entries.add(new KeyValue<>(bytesKey("abcd"), bytesValue("2"))); + entries.add(new KeyValue<>(bytesKey("abcdd"), bytesValue("1"))); + entries.add(new KeyValue<>(bytesKey("abce"), bytesValue("2"))); + entries.add(new KeyValue<>(bytesKey("abc"), bytesValue("2"))); + + store.putAll(entries); + + final List keys = new ArrayList<>(); + final List values = new ArrayList<>(); + int numberOfKeysReturned = 0; + + try (final 
KeyValueIterator keysWithPrefix = store.prefixScan("abcd", new StringSerializer())) { + while (keysWithPrefix.hasNext()) { + final KeyValue next = keysWithPrefix.next(); + keys.add(next.key.toString()); + values.add(new String(next.value)); + numberOfKeysReturned++; + } + } + + assertThat(numberOfKeysReturned, is(2)); + assertThat(keys, is(Arrays.asList("abcd", "abcdd"))); + assertThat(values, is(Arrays.asList("2", "1"))); + } + + @Test + public void shouldDeleteItemsFromCache() { + store.put(bytesKey("a"), bytesValue("a")); + store.delete(bytesKey("a")); + assertNull(store.get(bytesKey("a"))); + assertFalse(store.range(bytesKey("a"), bytesKey("b")).hasNext()); + assertFalse(store.reverseRange(bytesKey("a"), bytesKey("b")).hasNext()); + assertFalse(store.all().hasNext()); + assertFalse(store.reverseAll().hasNext()); + } + + @Test + public void shouldNotShowItemsDeletedFromCacheButFlushedToStoreBeforeDelete() { + store.put(bytesKey("a"), bytesValue("a")); + store.flush(); + store.delete(bytesKey("a")); + assertNull(store.get(bytesKey("a"))); + assertFalse(store.range(bytesKey("a"), bytesKey("b")).hasNext()); + assertFalse(store.reverseRange(bytesKey("a"), bytesKey("b")).hasNext()); + assertFalse(store.all().hasNext()); + assertFalse(store.reverseAll().hasNext()); + } + + @Test + public void shouldClearNamespaceCacheOnClose() { + store.put(bytesKey("a"), bytesValue("a")); + assertEquals(1, cache.size()); + store.close(); + assertEquals(0, cache.size()); + } + + @Test + public void shouldThrowIfTryingToGetFromClosedCachingStore() { + assertThrows(InvalidStateStoreException.class, () -> { + store.close(); + store.get(bytesKey("a")); + }); + } + + @Test + public void shouldThrowIfTryingToWriteToClosedCachingStore() { + assertThrows(InvalidStateStoreException.class, () -> { + store.close(); + store.put(bytesKey("a"), bytesValue("a")); + }); + } + + @Test + public void shouldThrowIfTryingToDoRangeQueryOnClosedCachingStore() { + assertThrows(InvalidStateStoreException.class, () -> { + store.close(); + store.range(bytesKey("a"), bytesKey("b")); + }); + } + + @Test + public void shouldThrowIfTryingToDoReverseRangeQueryOnClosedCachingStore() { + assertThrows(InvalidStateStoreException.class, () -> { + store.close(); + store.reverseRange(bytesKey("a"), bytesKey("b")); + }); + } + + @Test + public void shouldThrowIfTryingToDoAllQueryOnClosedCachingStore() { + assertThrows(InvalidStateStoreException.class, () -> { + store.close(); + store.all(); + }); + } + + @Test + public void shouldThrowIfTryingToDoReverseAllQueryOnClosedCachingStore() { + assertThrows(InvalidStateStoreException.class, () -> { + store.close(); + store.reverseAll(); + }); + } + + @Test + public void shouldThrowIfTryingToDoGetApproxSizeOnClosedCachingStore() { + assertThrows(InvalidStateStoreException.class, () -> { + store.close(); + store.close(); + store.approximateNumEntries(); + }); + } + + @Test + public void shouldThrowIfTryingToDoPutAllClosedCachingStore() { + assertThrows(InvalidStateStoreException.class, () -> { + store.close(); + store.putAll(Collections.singletonList(KeyValue.pair(bytesKey("a"), bytesValue("a")))); + }); + } + + @Test + public void shouldThrowIfTryingToDoPutIfAbsentClosedCachingStore() { + assertThrows(InvalidStateStoreException.class, () -> { + store.close(); + store.putIfAbsent(bytesKey("b"), bytesValue("c")); + }); + } + + @Test + public void shouldThrowNullPointerExceptionOnPutWithNullKey() { + assertThrows(NullPointerException.class, () -> store.put(null, bytesValue("c"))); + } + + @Test + public void 
shouldThrowNullPointerExceptionOnPutIfAbsentWithNullKey() {
+        assertThrows(NullPointerException.class, () -> store.putIfAbsent(null, bytesValue("c")));
+    }
+
+    @Test
+    public void shouldThrowNullPointerExceptionOnPutAllWithNullKey() {
+        final List<KeyValue<Bytes, byte[]>> entries = new ArrayList<>();
+        entries.add(new KeyValue<>(null, bytesValue("a")));
+        assertThrows(NullPointerException.class, () -> store.putAll(entries));
+    }
+
+    @Test
+    public void shouldPutIfAbsent() {
+        store.putIfAbsent(bytesKey("b"), bytesValue("2"));
+        assertThat(store.get(bytesKey("b")), equalTo(bytesValue("2")));
+
+        store.putIfAbsent(bytesKey("b"), bytesValue("3"));
+        assertThat(store.get(bytesKey("b")), equalTo(bytesValue("2")));
+    }
+
+    @Test
+    public void shouldPutAll() {
+        final List<KeyValue<Bytes, byte[]>> entries = new ArrayList<>();
+        entries.add(new KeyValue<>(bytesKey("a"), bytesValue("1")));
+        entries.add(new KeyValue<>(bytesKey("b"), bytesValue("2")));
+        store.putAll(entries);
+        assertThat(store.get(bytesKey("a")), equalTo(bytesValue("1")));
+        assertThat(store.get(bytesKey("b")), equalTo(bytesValue("2")));
+    }
+
+    @Test
+    public void shouldReturnUnderlying() {
+        assertEquals(underlyingStore, store.wrapped());
+    }
+
+    @Test
+    public void shouldThrowIfTryingToDeleteFromClosedCachingStore() {
+        assertThrows(InvalidStateStoreException.class, () -> {
+            store.close();
+            store.delete(bytesKey("key"));
+        });
+    }
+
+    private int addItemsToCache() {
+        int cachedSize = 0;
+        int i = 0;
+        while (cachedSize < maxCacheSizeBytes) {
+            final String kv = String.valueOf(i++);
+            store.put(bytesKey(kv), bytesValue(kv));
+            cachedSize += memoryCacheEntrySize(kv.getBytes(), kv.getBytes(), TOPIC);
+        }
+        return i;
+    }
+
+}
diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/CachingInMemorySessionStoreTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/CachingInMemorySessionStoreTest.java
new file mode 100644
index 0000000..4116df5
--- /dev/null
+++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/CachingInMemorySessionStoreTest.java
@@ -0,0 +1,860 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.common.header.internals.RecordHeaders; +import org.apache.kafka.common.metrics.Metrics; +import org.apache.kafka.common.serialization.Deserializer; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.serialization.StringDeserializer; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.KeyValueTimestamp; +import org.apache.kafka.streams.errors.InvalidStateStoreException; +import org.apache.kafka.streams.kstream.SessionWindowedDeserializer; +import org.apache.kafka.streams.kstream.Windowed; +import org.apache.kafka.streams.kstream.internals.Change; +import org.apache.kafka.streams.kstream.internals.SessionWindow; +import org.apache.kafka.streams.processor.ProcessorContext; +import org.apache.kafka.streams.processor.StateStoreContext; +import org.apache.kafka.streams.processor.api.Record; +import org.apache.kafka.streams.processor.internals.MockStreamsMetrics; +import org.apache.kafka.streams.processor.internals.ProcessorRecordContext; +import org.apache.kafka.streams.processor.internals.testutil.LogCaptureAppender; +import org.apache.kafka.streams.state.KeyValueIterator; +import org.apache.kafka.streams.state.SessionStore; +import org.apache.kafka.test.InternalMockProcessorContext; +import org.apache.kafka.test.TestUtils; +import org.easymock.EasyMock; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.Collections; +import java.util.LinkedList; +import java.util.List; +import java.util.Random; + +import static java.util.Arrays.asList; +import static org.apache.kafka.test.StreamsTestUtils.toList; +import static org.apache.kafka.test.StreamsTestUtils.verifyKeyValueList; +import static org.apache.kafka.test.StreamsTestUtils.verifyWindowedKeyValue; +import static org.hamcrest.CoreMatchers.hasItem; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.equalTo; +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; + +@SuppressWarnings("PointlessArithmeticExpression") +public class CachingInMemorySessionStoreTest { + + private static final int MAX_CACHE_SIZE_BYTES = 600; + private static final Long DEFAULT_TIMESTAMP = 10L; + private static final long SEGMENT_INTERVAL = 100L; + private static final String TOPIC = "topic"; + private static final String CACHE_NAMESPACE = "0_0-store-name"; + + private final Bytes keyA = Bytes.wrap("a".getBytes()); + private final Bytes keyAA = Bytes.wrap("aa".getBytes()); + private final Bytes keyB = Bytes.wrap("b".getBytes()); + + private SessionStore underlyingStore; + private InternalMockProcessorContext context; + private CachingSessionStore cachingStore; + private ThreadCache cache; + + @Before + public void before() { + underlyingStore = new InMemorySessionStore("store-name", Long.MAX_VALUE, "metric-scope"); + cachingStore = new CachingSessionStore(underlyingStore, SEGMENT_INTERVAL); + cache = new ThreadCache(new LogContext("testCache "), MAX_CACHE_SIZE_BYTES, new MockStreamsMetrics(new Metrics())); + context = new 
InternalMockProcessorContext<>(TestUtils.tempDirectory(), null, null, null, cache); + context.setRecordContext(new ProcessorRecordContext(DEFAULT_TIMESTAMP, 0, 0, TOPIC, new RecordHeaders())); + cachingStore.init((StateStoreContext) context, cachingStore); + } + + @After + public void after() { + cachingStore.close(); + } + + @SuppressWarnings("deprecation") + @Test + public void shouldDelegateDeprecatedInit() { + final SessionStore inner = EasyMock.mock(InMemorySessionStore.class); + final CachingSessionStore outer = new CachingSessionStore(inner, SEGMENT_INTERVAL); + EasyMock.expect(inner.name()).andStubReturn("store"); + inner.init((ProcessorContext) context, outer); + EasyMock.expectLastCall(); + EasyMock.replay(inner); + outer.init((ProcessorContext) context, outer); + EasyMock.verify(inner); + } + + @Test + public void shouldDelegateInit() { + final SessionStore inner = EasyMock.mock(InMemorySessionStore.class); + final CachingSessionStore outer = new CachingSessionStore(inner, SEGMENT_INTERVAL); + EasyMock.expect(inner.name()).andStubReturn("store"); + inner.init((StateStoreContext) context, outer); + EasyMock.expectLastCall(); + EasyMock.replay(inner); + outer.init((StateStoreContext) context, outer); + EasyMock.verify(inner); + } + + @Test + public void shouldPutFetchFromCache() { + cachingStore.put(new Windowed<>(keyA, new SessionWindow(0, 0)), "1".getBytes()); + cachingStore.put(new Windowed<>(keyAA, new SessionWindow(0, 0)), "1".getBytes()); + cachingStore.put(new Windowed<>(keyB, new SessionWindow(0, 0)), "1".getBytes()); + + assertEquals(3, cache.size()); + + try (final KeyValueIterator, byte[]> a = cachingStore.findSessions(keyA, 0, 0); + final KeyValueIterator, byte[]> b = cachingStore.findSessions(keyB, 0, 0)) { + + verifyWindowedKeyValue(a.next(), new Windowed<>(keyA, new SessionWindow(0, 0)), "1"); + verifyWindowedKeyValue(b.next(), new Windowed<>(keyB, new SessionWindow(0, 0)), "1"); + assertFalse(a.hasNext()); + assertFalse(b.hasNext()); + } + } + + @Test + public void shouldPutFetchAllKeysFromCache() { + cachingStore.put(new Windowed<>(keyA, new SessionWindow(0, 0)), "1".getBytes()); + cachingStore.put(new Windowed<>(keyAA, new SessionWindow(0, 0)), "1".getBytes()); + cachingStore.put(new Windowed<>(keyB, new SessionWindow(0, 0)), "1".getBytes()); + + assertEquals(3, cache.size()); + + try (final KeyValueIterator, byte[]> all = cachingStore.fetch(keyA, keyB)) { + verifyWindowedKeyValue(all.next(), new Windowed<>(keyA, new SessionWindow(0, 0)), "1"); + verifyWindowedKeyValue(all.next(), new Windowed<>(keyAA, new SessionWindow(0, 0)), "1"); + verifyWindowedKeyValue(all.next(), new Windowed<>(keyB, new SessionWindow(0, 0)), "1"); + assertFalse(all.hasNext()); + } + + // infinite keyFrom fetch + try (final KeyValueIterator, byte[]> all = cachingStore.fetch(null, keyB)) { + verifyWindowedKeyValue(all.next(), new Windowed<>(keyA, new SessionWindow(0, 0)), "1"); + verifyWindowedKeyValue(all.next(), new Windowed<>(keyAA, new SessionWindow(0, 0)), "1"); + verifyWindowedKeyValue(all.next(), new Windowed<>(keyB, new SessionWindow(0, 0)), "1"); + assertFalse(all.hasNext()); + } + + // infinite keyTo fetch + try (final KeyValueIterator, byte[]> all = cachingStore.fetch(null, keyB)) { + verifyWindowedKeyValue(all.next(), new Windowed<>(keyA, new SessionWindow(0, 0)), "1"); + verifyWindowedKeyValue(all.next(), new Windowed<>(keyAA, new SessionWindow(0, 0)), "1"); + verifyWindowedKeyValue(all.next(), new Windowed<>(keyB, new SessionWindow(0, 0)), "1"); + assertFalse(all.hasNext()); + 
} + + // infinite keyFrom and keyTo fetch + try (final KeyValueIterator, byte[]> all = cachingStore.fetch(null, keyB)) { + verifyWindowedKeyValue(all.next(), new Windowed<>(keyA, new SessionWindow(0, 0)), "1"); + verifyWindowedKeyValue(all.next(), new Windowed<>(keyAA, new SessionWindow(0, 0)), "1"); + verifyWindowedKeyValue(all.next(), new Windowed<>(keyB, new SessionWindow(0, 0)), "1"); + assertFalse(all.hasNext()); + } + } + + @Test + public void shouldPutBackwardFetchAllKeysFromCache() { + cachingStore.put(new Windowed<>(keyA, new SessionWindow(0, 0)), "1".getBytes()); + cachingStore.put(new Windowed<>(keyAA, new SessionWindow(0, 0)), "1".getBytes()); + cachingStore.put(new Windowed<>(keyB, new SessionWindow(0, 0)), "1".getBytes()); + + assertEquals(3, cache.size()); + + try (final KeyValueIterator, byte[]> all = cachingStore.backwardFetch(keyA, keyB)) { + verifyWindowedKeyValue(all.next(), new Windowed<>(keyB, new SessionWindow(0, 0)), "1"); + verifyWindowedKeyValue(all.next(), new Windowed<>(keyAA, new SessionWindow(0, 0)), "1"); + verifyWindowedKeyValue(all.next(), new Windowed<>(keyA, new SessionWindow(0, 0)), "1"); + assertFalse(all.hasNext()); + } + + // infinite keyFrom fetch + try (final KeyValueIterator, byte[]> all = cachingStore.backwardFetch(null, keyB)) { + verifyWindowedKeyValue(all.next(), new Windowed<>(keyB, new SessionWindow(0, 0)), "1"); + verifyWindowedKeyValue(all.next(), new Windowed<>(keyAA, new SessionWindow(0, 0)), "1"); + verifyWindowedKeyValue(all.next(), new Windowed<>(keyA, new SessionWindow(0, 0)), "1"); + assertFalse(all.hasNext()); + } + + // infinite keyTo fetch + try (final KeyValueIterator, byte[]> all = cachingStore.backwardFetch(null, keyB)) { + verifyWindowedKeyValue(all.next(), new Windowed<>(keyB, new SessionWindow(0, 0)), "1"); + verifyWindowedKeyValue(all.next(), new Windowed<>(keyAA, new SessionWindow(0, 0)), "1"); + verifyWindowedKeyValue(all.next(), new Windowed<>(keyA, new SessionWindow(0, 0)), "1"); + assertFalse(all.hasNext()); + } + + // infinite keyFrom and keyTo fetch + try (final KeyValueIterator, byte[]> all = cachingStore.backwardFetch(null, null)) { + verifyWindowedKeyValue(all.next(), new Windowed<>(keyB, new SessionWindow(0, 0)), "1"); + verifyWindowedKeyValue(all.next(), new Windowed<>(keyAA, new SessionWindow(0, 0)), "1"); + verifyWindowedKeyValue(all.next(), new Windowed<>(keyA, new SessionWindow(0, 0)), "1"); + assertFalse(all.hasNext()); + } + } + + @Test + public void shouldCloseWrappedStoreAndCacheAfterErrorDuringCacheFlush() { + setUpCloseTests(); + EasyMock.reset(cache); + cache.flush(CACHE_NAMESPACE); + EasyMock.expectLastCall().andThrow(new RuntimeException("Simulating an error on flush")); + EasyMock.replay(cache); + EasyMock.reset(underlyingStore); + underlyingStore.close(); + EasyMock.replay(underlyingStore); + + assertThrows(RuntimeException.class, cachingStore::close); + EasyMock.verify(cache, underlyingStore); + } + + @Test + public void shouldCloseWrappedStoreAfterErrorDuringCacheClose() { + setUpCloseTests(); + EasyMock.reset(cache); + cache.flush(CACHE_NAMESPACE); + cache.close(CACHE_NAMESPACE); + EasyMock.expectLastCall().andThrow(new RuntimeException("Simulating an error on close")); + EasyMock.replay(cache); + EasyMock.reset(underlyingStore); + underlyingStore.close(); + EasyMock.replay(underlyingStore); + + assertThrows(RuntimeException.class, cachingStore::close); + EasyMock.verify(cache, underlyingStore); + } + + @Test + public void shouldCloseCacheAfterErrorDuringWrappedStoreClose() { + 
setUpCloseTests(); + EasyMock.reset(cache); + cache.flush(CACHE_NAMESPACE); + cache.close(CACHE_NAMESPACE); + EasyMock.replay(cache); + EasyMock.reset(underlyingStore); + underlyingStore.close(); + EasyMock.expectLastCall().andThrow(new RuntimeException("Simulating an error on close")); + EasyMock.replay(underlyingStore); + + assertThrows(RuntimeException.class, cachingStore::close); + EasyMock.verify(cache, underlyingStore); + } + + private void setUpCloseTests() { + underlyingStore = EasyMock.createNiceMock(SessionStore.class); + EasyMock.expect(underlyingStore.name()).andStubReturn("store-name"); + EasyMock.expect(underlyingStore.isOpen()).andStubReturn(true); + EasyMock.replay(underlyingStore); + cachingStore = new CachingSessionStore(underlyingStore, SEGMENT_INTERVAL); + cache = EasyMock.niceMock(ThreadCache.class); + final InternalMockProcessorContext context = new InternalMockProcessorContext<>(TestUtils.tempDirectory(), null, null, null, cache); + context.setRecordContext(new ProcessorRecordContext(10, 0, 0, TOPIC, new RecordHeaders())); + cachingStore.init((StateStoreContext) context, cachingStore); + } + + @Test + public void shouldPutFetchRangeFromCache() { + cachingStore.put(new Windowed<>(keyA, new SessionWindow(0, 0)), "1".getBytes()); + cachingStore.put(new Windowed<>(keyAA, new SessionWindow(0, 0)), "1".getBytes()); + cachingStore.put(new Windowed<>(keyB, new SessionWindow(0, 0)), "1".getBytes()); + + assertEquals(3, cache.size()); + + try (final KeyValueIterator, byte[]> some = + cachingStore.findSessions(keyAA, keyB, 0, 0)) { + verifyWindowedKeyValue(some.next(), new Windowed<>(keyAA, new SessionWindow(0, 0)), "1"); + verifyWindowedKeyValue(some.next(), new Windowed<>(keyB, new SessionWindow(0, 0)), "1"); + assertFalse(some.hasNext()); + } + + // infinite keyFrom case + try (final KeyValueIterator, byte[]> some = + cachingStore.findSessions(null, keyAA, 0, 0)) { + verifyWindowedKeyValue(some.next(), new Windowed<>(keyA, new SessionWindow(0, 0)), "1"); + verifyWindowedKeyValue(some.next(), new Windowed<>(keyAA, new SessionWindow(0, 0)), "1"); + assertFalse(some.hasNext()); + } + + // infinite keyTo case + try (final KeyValueIterator, byte[]> some = + cachingStore.findSessions(keyAA, keyB, 0, 0)) { + verifyWindowedKeyValue(some.next(), new Windowed<>(keyAA, new SessionWindow(0, 0)), "1"); + verifyWindowedKeyValue(some.next(), new Windowed<>(keyB, new SessionWindow(0, 0)), "1"); + assertFalse(some.hasNext()); + } + + // infinite keyFrom and keyTo case + try (final KeyValueIterator, byte[]> some = + cachingStore.findSessions(null, null, 0, 0)) { + verifyWindowedKeyValue(some.next(), new Windowed<>(keyA, new SessionWindow(0, 0)), "1"); + verifyWindowedKeyValue(some.next(), new Windowed<>(keyAA, new SessionWindow(0, 0)), "1"); + verifyWindowedKeyValue(some.next(), new Windowed<>(keyB, new SessionWindow(0, 0)), "1"); + assertFalse(some.hasNext()); + } + } + + @Test + public void shouldPutBackwardFetchRangeFromCache() { + cachingStore.put(new Windowed<>(keyA, new SessionWindow(0, 0)), "1".getBytes()); + cachingStore.put(new Windowed<>(keyAA, new SessionWindow(0, 0)), "1".getBytes()); + cachingStore.put(new Windowed<>(keyB, new SessionWindow(0, 0)), "1".getBytes()); + + assertEquals(3, cache.size()); + + try (final KeyValueIterator, byte[]> some = + cachingStore.backwardFindSessions(keyAA, keyB, 0, 0)) { + verifyWindowedKeyValue(some.next(), new Windowed<>(keyB, new SessionWindow(0, 0)), "1"); + verifyWindowedKeyValue(some.next(), new Windowed<>(keyAA, new SessionWindow(0, 0)), 
"1"); + assertFalse(some.hasNext()); + } + + // infinite keyFrom case + try (final KeyValueIterator, byte[]> some = + cachingStore.backwardFindSessions(null, keyAA, 0, 0)) { + verifyWindowedKeyValue(some.next(), new Windowed<>(keyAA, new SessionWindow(0, 0)), "1"); + verifyWindowedKeyValue(some.next(), new Windowed<>(keyA, new SessionWindow(0, 0)), "1"); + assertFalse(some.hasNext()); + } + + // infinite keyTo case + try (final KeyValueIterator, byte[]> some = + cachingStore.backwardFindSessions(keyAA, keyB, 0, 0)) { + verifyWindowedKeyValue(some.next(), new Windowed<>(keyB, new SessionWindow(0, 0)), "1"); + verifyWindowedKeyValue(some.next(), new Windowed<>(keyAA, new SessionWindow(0, 0)), "1"); + assertFalse(some.hasNext()); + } + + // infinite keyFrom and keyTo case + try (final KeyValueIterator, byte[]> some = + cachingStore.backwardFindSessions(null, null, 0, 0)) { + verifyWindowedKeyValue(some.next(), new Windowed<>(keyB, new SessionWindow(0, 0)), "1"); + verifyWindowedKeyValue(some.next(), new Windowed<>(keyAA, new SessionWindow(0, 0)), "1"); + verifyWindowedKeyValue(some.next(), new Windowed<>(keyA, new SessionWindow(0, 0)), "1"); + assertFalse(some.hasNext()); + } + } + + @Test + public void shouldFetchAllSessionsWithSameRecordKey() { + final List, byte[]>> expected = asList( + KeyValue.pair(new Windowed<>(keyA, new SessionWindow(0, 0)), "1".getBytes()), + KeyValue.pair(new Windowed<>(keyA, new SessionWindow(10, 10)), "2".getBytes()), + KeyValue.pair(new Windowed<>(keyA, new SessionWindow(100, 100)), "3".getBytes()), + KeyValue.pair(new Windowed<>(keyA, new SessionWindow(1000, 1000)), "4".getBytes()) + ); + for (final KeyValue, byte[]> kv : expected) { + cachingStore.put(kv.key, kv.value); + } + + // add one that shouldn't appear in the results + cachingStore.put(new Windowed<>(keyAA, new SessionWindow(0, 0)), "5".getBytes()); + + final List, byte[]>> results = toList(cachingStore.fetch(keyA)); + verifyKeyValueList(expected, results); + } + + @Test + public void shouldBackwardFetchAllSessionsWithSameRecordKey() { + final List, byte[]>> expected = asList( + KeyValue.pair(new Windowed<>(keyA, new SessionWindow(0, 0)), "1".getBytes()), + KeyValue.pair(new Windowed<>(keyA, new SessionWindow(10, 10)), "2".getBytes()), + KeyValue.pair(new Windowed<>(keyA, new SessionWindow(100, 100)), "3".getBytes()), + KeyValue.pair(new Windowed<>(keyA, new SessionWindow(1000, 1000)), "4".getBytes()) + ); + for (final KeyValue, byte[]> kv : expected) { + cachingStore.put(kv.key, kv.value); + } + + // add one that shouldn't appear in the results + cachingStore.put(new Windowed<>(keyAA, new SessionWindow(0, 0)), "5".getBytes()); + + final List, byte[]>> results = toList(cachingStore.backwardFetch(keyA)); + Collections.reverse(results); + verifyKeyValueList(expected, results); + } + + @Test + public void shouldFlushItemsToStoreOnEviction() { + final List, byte[]>> added = addSessionsUntilOverflow("a", "b", "c", "d"); + assertEquals(added.size() - 1, cache.size()); + try (final KeyValueIterator, byte[]> iterator = cachingStore.findSessions(added.get(0).key.key(), 0, 0)) { + final KeyValue, byte[]> next = iterator.next(); + assertEquals(added.get(0).key, next.key); + assertArrayEquals(added.get(0).value, next.value); + } + } + + @Test + public void shouldQueryItemsInCacheAndStore() { + final List, byte[]>> added = addSessionsUntilOverflow("a"); + final List, byte[]>> actual = toList(cachingStore.findSessions( + Bytes.wrap("a".getBytes(StandardCharsets.UTF_8)), + 0, + added.size() * 10L)); + 
verifyKeyValueList(added, actual); + } + + @Test + public void shouldRemove() { + final Windowed a = new Windowed<>(keyA, new SessionWindow(0, 0)); + final Windowed b = new Windowed<>(keyB, new SessionWindow(0, 0)); + cachingStore.put(a, "2".getBytes()); + cachingStore.put(b, "2".getBytes()); + cachingStore.remove(a); + + try (final KeyValueIterator, byte[]> rangeIter = + cachingStore.findSessions(keyA, 0, 0)) { + assertFalse(rangeIter.hasNext()); + + assertNull(cachingStore.fetchSession(keyA, 0, 0)); + assertThat(cachingStore.fetchSession(keyB, 0, 0), equalTo("2".getBytes())); + } + } + + @Test + public void shouldFetchCorrectlyAcrossSegments() { + final Windowed a1 = new Windowed<>(keyA, new SessionWindow(SEGMENT_INTERVAL * 0, SEGMENT_INTERVAL * 0)); + final Windowed a2 = new Windowed<>(keyA, new SessionWindow(SEGMENT_INTERVAL * 1, SEGMENT_INTERVAL * 1)); + final Windowed a3 = new Windowed<>(keyA, new SessionWindow(SEGMENT_INTERVAL * 2, SEGMENT_INTERVAL * 2)); + final Windowed a4 = new Windowed<>(keyA, new SessionWindow(SEGMENT_INTERVAL * 3, SEGMENT_INTERVAL * 3)); + final Windowed a5 = new Windowed<>(keyA, new SessionWindow(SEGMENT_INTERVAL * 4, SEGMENT_INTERVAL * 4)); + final Windowed a6 = new Windowed<>(keyA, new SessionWindow(SEGMENT_INTERVAL * 5, SEGMENT_INTERVAL * 5)); + cachingStore.put(a1, "1".getBytes()); + cachingStore.put(a2, "2".getBytes()); + cachingStore.put(a3, "3".getBytes()); + cachingStore.flush(); + cachingStore.put(a4, "4".getBytes()); + cachingStore.put(a5, "5".getBytes()); + cachingStore.put(a6, "6".getBytes()); + try (final KeyValueIterator, byte[]> results = + cachingStore.findSessions(keyA, 0, SEGMENT_INTERVAL * 5)) { + assertEquals(a1, results.next().key); + assertEquals(a2, results.next().key); + assertEquals(a3, results.next().key); + assertEquals(a4, results.next().key); + assertEquals(a5, results.next().key); + assertEquals(a6, results.next().key); + assertFalse(results.hasNext()); + } + } + + @Test + public void shouldBackwardFetchCorrectlyAcrossSegments() { + final Windowed a1 = new Windowed<>(keyA, new SessionWindow(SEGMENT_INTERVAL * 0, SEGMENT_INTERVAL * 0)); + final Windowed a2 = new Windowed<>(keyA, new SessionWindow(SEGMENT_INTERVAL * 1, SEGMENT_INTERVAL * 1)); + final Windowed a3 = new Windowed<>(keyA, new SessionWindow(SEGMENT_INTERVAL * 2, SEGMENT_INTERVAL * 2)); + final Windowed a4 = new Windowed<>(keyA, new SessionWindow(SEGMENT_INTERVAL * 3, SEGMENT_INTERVAL * 3)); + final Windowed a5 = new Windowed<>(keyA, new SessionWindow(SEGMENT_INTERVAL * 4, SEGMENT_INTERVAL * 4)); + final Windowed a6 = new Windowed<>(keyA, new SessionWindow(SEGMENT_INTERVAL * 5, SEGMENT_INTERVAL * 5)); + cachingStore.put(a1, "1".getBytes()); + cachingStore.put(a2, "2".getBytes()); + cachingStore.put(a3, "3".getBytes()); + cachingStore.flush(); + cachingStore.put(a4, "4".getBytes()); + cachingStore.put(a5, "5".getBytes()); + cachingStore.put(a6, "6".getBytes()); + try (final KeyValueIterator, byte[]> results = + cachingStore.backwardFindSessions(keyA, 0, SEGMENT_INTERVAL * 5)) { + assertEquals(a6, results.next().key); + assertEquals(a5, results.next().key); + assertEquals(a4, results.next().key); + assertEquals(a3, results.next().key); + assertEquals(a2, results.next().key); + assertEquals(a1, results.next().key); + assertFalse(results.hasNext()); + } + } + + @Test + public void shouldFetchRangeCorrectlyAcrossSegments() { + final Windowed a1 = new Windowed<>(keyA, new SessionWindow(SEGMENT_INTERVAL * 0, SEGMENT_INTERVAL * 0)); + final Windowed aa1 = new Windowed<>(keyAA, 
new SessionWindow(SEGMENT_INTERVAL * 0, SEGMENT_INTERVAL * 0)); + final Windowed a2 = new Windowed<>(keyA, new SessionWindow(SEGMENT_INTERVAL * 1, SEGMENT_INTERVAL * 1)); + final Windowed a3 = new Windowed<>(keyA, new SessionWindow(SEGMENT_INTERVAL * 2, SEGMENT_INTERVAL * 2)); + final Windowed aa3 = new Windowed<>(keyAA, new SessionWindow(SEGMENT_INTERVAL * 2, SEGMENT_INTERVAL * 2)); + cachingStore.put(a1, "1".getBytes()); + cachingStore.put(aa1, "1".getBytes()); + cachingStore.put(a2, "2".getBytes()); + cachingStore.put(a3, "3".getBytes()); + cachingStore.put(aa3, "3".getBytes()); + + final KeyValueIterator, byte[]> rangeResults = + cachingStore.findSessions(keyA, keyAA, 0, SEGMENT_INTERVAL * 2); + final List> keys = new ArrayList<>(); + while (rangeResults.hasNext()) { + keys.add(rangeResults.next().key); + } + rangeResults.close(); + assertEquals(asList(a1, aa1, a2, a3, aa3), keys); + } + + @Test + public void shouldBackwardFetchRangeCorrectlyAcrossSegments() { + final Windowed a1 = new Windowed<>(keyA, new SessionWindow(SEGMENT_INTERVAL * 0, SEGMENT_INTERVAL * 0)); + final Windowed aa1 = new Windowed<>(keyAA, new SessionWindow(SEGMENT_INTERVAL * 0, SEGMENT_INTERVAL * 0)); + final Windowed a2 = new Windowed<>(keyA, new SessionWindow(SEGMENT_INTERVAL * 1, SEGMENT_INTERVAL * 1)); + final Windowed a3 = new Windowed<>(keyA, new SessionWindow(SEGMENT_INTERVAL * 2, SEGMENT_INTERVAL * 2)); + final Windowed aa3 = new Windowed<>(keyAA, new SessionWindow(SEGMENT_INTERVAL * 2, SEGMENT_INTERVAL * 2)); + cachingStore.put(a1, "1".getBytes()); + cachingStore.put(aa1, "1".getBytes()); + cachingStore.put(a2, "2".getBytes()); + cachingStore.put(a3, "3".getBytes()); + cachingStore.put(aa3, "3".getBytes()); + + final KeyValueIterator, byte[]> rangeResults = + cachingStore.backwardFindSessions(keyA, keyAA, 0, SEGMENT_INTERVAL * 2); + final List> keys = new ArrayList<>(); + while (rangeResults.hasNext()) { + keys.add(rangeResults.next().key); + } + rangeResults.close(); + assertEquals(asList(aa3, a3, a2, aa1, a1), keys); + } + + @Test + public void shouldSetFlushListener() { + assertTrue(cachingStore.setFlushListener(null, true)); + assertTrue(cachingStore.setFlushListener(null, false)); + } + + @Test + public void shouldForwardChangedValuesDuringFlush() { + final Windowed a = new Windowed<>(keyA, new SessionWindow(2, 4)); + final Windowed b = new Windowed<>(keyA, new SessionWindow(1, 2)); + final Windowed aDeserialized = new Windowed<>("a", new SessionWindow(2, 4)); + final Windowed bDeserialized = new Windowed<>("a", new SessionWindow(1, 2)); + final CacheFlushListenerStub, String> flushListener = + new CacheFlushListenerStub<>( + new SessionWindowedDeserializer<>(new StringDeserializer()), + new StringDeserializer()); + cachingStore.setFlushListener(flushListener, true); + + cachingStore.put(b, "1".getBytes()); + cachingStore.flush(); + + assertEquals( + Collections.singletonList( + new KeyValueTimestamp<>( + bDeserialized, + new Change<>("1", null), + DEFAULT_TIMESTAMP)), + flushListener.forwarded + ); + flushListener.forwarded.clear(); + + cachingStore.put(a, "1".getBytes()); + cachingStore.flush(); + + assertEquals( + Collections.singletonList( + new KeyValueTimestamp<>( + aDeserialized, + new Change<>("1", null), + DEFAULT_TIMESTAMP)), + flushListener.forwarded + ); + flushListener.forwarded.clear(); + + cachingStore.put(a, "2".getBytes()); + cachingStore.flush(); + + assertEquals( + Collections.singletonList( + new KeyValueTimestamp<>( + aDeserialized, + new Change<>("2", "1"), + DEFAULT_TIMESTAMP)), 
+ flushListener.forwarded + ); + flushListener.forwarded.clear(); + + cachingStore.remove(a); + cachingStore.flush(); + + assertEquals( + Collections.singletonList( + new KeyValueTimestamp<>( + aDeserialized, + new Change<>(null, "2"), + DEFAULT_TIMESTAMP)), + flushListener.forwarded + ); + flushListener.forwarded.clear(); + + cachingStore.put(a, "1".getBytes()); + cachingStore.put(a, "2".getBytes()); + cachingStore.remove(a); + cachingStore.flush(); + + assertEquals( + Collections.emptyList(), + flushListener.forwarded + ); + flushListener.forwarded.clear(); + } + + @Test + public void shouldNotForwardChangedValuesDuringFlushWhenSendOldValuesDisabled() { + final Windowed a = new Windowed<>(keyA, new SessionWindow(0, 0)); + final Windowed aDeserialized = new Windowed<>("a", new SessionWindow(0, 0)); + final CacheFlushListenerStub, String> flushListener = + new CacheFlushListenerStub<>( + new SessionWindowedDeserializer<>(new StringDeserializer()), + new StringDeserializer()); + cachingStore.setFlushListener(flushListener, false); + + cachingStore.put(a, "1".getBytes()); + cachingStore.flush(); + + cachingStore.put(a, "2".getBytes()); + cachingStore.flush(); + + cachingStore.remove(a); + cachingStore.flush(); + + assertEquals( + asList(new KeyValueTimestamp<>( + aDeserialized, + new Change<>("1", null), + DEFAULT_TIMESTAMP), + new KeyValueTimestamp<>( + aDeserialized, + new Change<>("2", null), + DEFAULT_TIMESTAMP), + new KeyValueTimestamp<>( + aDeserialized, + new Change<>(null, null), + DEFAULT_TIMESTAMP)), + flushListener.forwarded + ); + flushListener.forwarded.clear(); + + cachingStore.put(a, "1".getBytes()); + cachingStore.put(a, "2".getBytes()); + cachingStore.remove(a); + cachingStore.flush(); + + assertEquals( + Collections.emptyList(), + flushListener.forwarded + ); + flushListener.forwarded.clear(); + } + + @Test + public void shouldReturnSameResultsForSingleKeyFindSessionsAndEqualKeyRangeFindSessions() { + cachingStore.put(new Windowed<>(keyA, new SessionWindow(0, 1)), "1".getBytes()); + cachingStore.put(new Windowed<>(keyAA, new SessionWindow(2, 3)), "2".getBytes()); + cachingStore.put(new Windowed<>(keyAA, new SessionWindow(4, 5)), "3".getBytes()); + cachingStore.put(new Windowed<>(keyB, new SessionWindow(6, 7)), "4".getBytes()); + + try (final KeyValueIterator, byte[]> singleKeyIterator = cachingStore.findSessions(keyAA, 0L, 10L); + final KeyValueIterator, byte[]> keyRangeIterator = cachingStore.findSessions(keyAA, keyAA, 0L, 10L)) { + + assertEquals(singleKeyIterator.next(), keyRangeIterator.next()); + assertEquals(singleKeyIterator.next(), keyRangeIterator.next()); + assertFalse(singleKeyIterator.hasNext()); + assertFalse(keyRangeIterator.hasNext()); + } + } + + @Test + public void shouldReturnSameResultsForSingleKeyFindSessionsBackwardsAndEqualKeyRangeFindSessions() { + cachingStore.put(new Windowed<>(keyA, new SessionWindow(0, 1)), "1".getBytes()); + cachingStore.put(new Windowed<>(keyAA, new SessionWindow(2, 3)), "2".getBytes()); + cachingStore.put(new Windowed<>(keyAA, new SessionWindow(4, 5)), "3".getBytes()); + cachingStore.put(new Windowed<>(keyB, new SessionWindow(6, 7)), "4".getBytes()); + + try (final KeyValueIterator, byte[]> singleKeyIterator = + cachingStore.backwardFindSessions(keyAA, 0L, 10L); + final KeyValueIterator, byte[]> keyRangeIterator = + cachingStore.backwardFindSessions(keyAA, keyAA, 0L, 10L)) { + + assertEquals(singleKeyIterator.next(), keyRangeIterator.next()); + assertEquals(singleKeyIterator.next(), keyRangeIterator.next()); + 
assertFalse(singleKeyIterator.hasNext()); + assertFalse(keyRangeIterator.hasNext()); + } + } + + @Test + public void shouldClearNamespaceCacheOnClose() { + final Windowed a1 = new Windowed<>(keyA, new SessionWindow(0, 0)); + cachingStore.put(a1, "1".getBytes()); + assertEquals(1, cache.size()); + cachingStore.close(); + assertEquals(0, cache.size()); + } + + @Test + public void shouldThrowIfTryingToFetchFromClosedCachingStore() { + cachingStore.close(); + assertThrows(InvalidStateStoreException.class, () -> cachingStore.fetch(keyA)); + } + + @Test + public void shouldThrowIfTryingToFindMergeSessionFromClosedCachingStore() { + cachingStore.close(); + assertThrows(InvalidStateStoreException.class, () -> cachingStore.findSessions(keyA, 0, Long.MAX_VALUE)); + } + + @Test + public void shouldThrowIfTryingToRemoveFromClosedCachingStore() { + cachingStore.close(); + assertThrows(InvalidStateStoreException.class, () -> cachingStore.remove(new Windowed<>(keyA, new SessionWindow(0, 0)))); + } + + @Test + public void shouldThrowIfTryingToPutIntoClosedCachingStore() { + cachingStore.close(); + assertThrows(InvalidStateStoreException.class, () -> cachingStore.put(new Windowed<>(keyA, new SessionWindow(0, 0)), "1".getBytes())); + } + + @Test + public void shouldThrowNullPointerExceptionOnFindSessionsNullKey() { + assertThrows(NullPointerException.class, () -> cachingStore.findSessions(null, 1L, 2L)); + } + + @Test + public void shouldThrowNullPointerExceptionOnFetchNullKey() { + assertThrows(NullPointerException.class, () -> cachingStore.fetch(null)); + } + + @Test + public void shouldThrowNullPointerExceptionOnRemoveNullKey() { + assertThrows(NullPointerException.class, () -> cachingStore.remove(null)); + } + + @Test + public void shouldThrowNullPointerExceptionOnPutNullKey() { + assertThrows(NullPointerException.class, () -> cachingStore.put(null, "1".getBytes())); + } + + @Test + public void shouldNotThrowInvalidRangeExceptionWhenBackwardWithNegativeFromKey() { + final Bytes keyFrom = Bytes.wrap(Serdes.Integer().serializer().serialize("", -1)); + final Bytes keyTo = Bytes.wrap(Serdes.Integer().serializer().serialize("", 1)); + + try (final LogCaptureAppender appender = LogCaptureAppender.createAndRegister(CachingSessionStore.class); + final KeyValueIterator, byte[]> iterator = cachingStore.backwardFindSessions(keyFrom, keyTo, 0L, 10L)) { + assertFalse(iterator.hasNext()); + + final List messages = appender.getMessages(); + assertThat( + messages, + hasItem( + "Returning empty iterator for fetch with invalid key range: from > to." + + " This may be due to range arguments set in the wrong order, " + + "or serdes that don't preserve ordering when lexicographically comparing the serialized bytes." + + " Note that the built-in numerical serdes do not follow this for negative numbers" + ) + ); + } + } + + @Test + public void shouldNotThrowInvalidRangeExceptionWithNegativeFromKey() { + final Bytes keyFrom = Bytes.wrap(Serdes.Integer().serializer().serialize("", -1)); + final Bytes keyTo = Bytes.wrap(Serdes.Integer().serializer().serialize("", 1)); + + try (final LogCaptureAppender appender = LogCaptureAppender.createAndRegister(CachingSessionStore.class); + final KeyValueIterator, byte[]> iterator = cachingStore.findSessions(keyFrom, keyTo, 0L, 10L)) { + assertFalse(iterator.hasNext()); + + final List messages = appender.getMessages(); + assertThat( + messages, + hasItem("Returning empty iterator for fetch with invalid key range: from > to." 
+ + " This may be due to range arguments set in the wrong order, " + + "or serdes that don't preserve ordering when lexicographically comparing the serialized bytes." + + " Note that the built-in numerical serdes do not follow this for negative numbers") + ); + } + }
+
+    private List<KeyValue<Windowed<Bytes>, byte[]>> addSessionsUntilOverflow(final String... sessionIds) {
+        final Random random = new Random();
+        final List<KeyValue<Windowed<Bytes>, byte[]>> results = new ArrayList<>();
+        while (cache.size() == results.size()) {
+            final String sessionId = sessionIds[random.nextInt(sessionIds.length)];
+            addSingleSession(sessionId, results);
+        }
+        return results;
+    }
+
+    private void addSingleSession(final String sessionId, final List<KeyValue<Windowed<Bytes>, byte[]>> allSessions) {
+        final int timestamp = allSessions.size() * 10;
+        final Windowed<Bytes> key = new Windowed<>(Bytes.wrap(sessionId.getBytes()), new SessionWindow(timestamp, timestamp));
+        final byte[] value = "1".getBytes();
+        cachingStore.put(key, value);
+        allSessions.add(KeyValue.pair(key, value));
+    }
+
+    public static class CacheFlushListenerStub<K, V> implements CacheFlushListener<byte[], byte[]> {
+        final Deserializer<K> keyDeserializer;
+        final Deserializer<V> valueDesializer;
+        final List<KeyValueTimestamp<K, Change<V>>> forwarded = new LinkedList<>();
+
+        CacheFlushListenerStub(final Deserializer<K> keyDeserializer,
+                               final Deserializer<V> valueDesializer) {
+            this.keyDeserializer = keyDeserializer;
+            this.valueDesializer = valueDesializer;
+        }
+
+        @Override
+        public void apply(final byte[] key,
+                          final byte[] newValue,
+                          final byte[] oldValue,
+                          final long timestamp) {
+            forwarded.add(
+                new KeyValueTimestamp<>(
+                    keyDeserializer.deserialize(null, key),
+                    new Change<>(
+                        valueDesializer.deserialize(null, newValue),
+                        valueDesializer.deserialize(null, oldValue)),
+                    timestamp));
+        }
+
+        @Override
+        public void apply(final Record<byte[], Change<byte[]>> record) {
+            forwarded.add(
+                new KeyValueTimestamp<>(
+                    keyDeserializer.deserialize(null, record.key()),
+                    new Change<>(
+                        valueDesializer.deserialize(null, record.value().newValue),
+                        valueDesializer.deserialize(null, record.value().oldValue)),
+                    record.timestamp()
+                )
+            );
+        }
+    }
+}
diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/CachingPersistentSessionStoreTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/CachingPersistentSessionStoreTest.java
new file mode 100644
index 0000000..83a514b
--- /dev/null
+++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/CachingPersistentSessionStoreTest.java
@@ -0,0 +1,876 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.common.header.internals.RecordHeaders; +import org.apache.kafka.common.metrics.Metrics; +import org.apache.kafka.common.serialization.Deserializer; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.serialization.StringDeserializer; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.KeyValueTimestamp; +import org.apache.kafka.streams.errors.InvalidStateStoreException; +import org.apache.kafka.streams.kstream.SessionWindowedDeserializer; +import org.apache.kafka.streams.kstream.Windowed; +import org.apache.kafka.streams.kstream.internals.Change; +import org.apache.kafka.streams.kstream.internals.SessionWindow; +import org.apache.kafka.streams.processor.StateStoreContext; +import org.apache.kafka.streams.processor.api.Record; +import org.apache.kafka.streams.processor.internals.MockStreamsMetrics; +import org.apache.kafka.streams.processor.internals.ProcessorRecordContext; +import org.apache.kafka.streams.processor.internals.testutil.LogCaptureAppender; +import org.apache.kafka.streams.state.KeyValueIterator; +import org.apache.kafka.streams.state.SessionStore; +import org.apache.kafka.test.InternalMockProcessorContext; +import org.apache.kafka.test.TestUtils; +import org.easymock.EasyMock; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.Collections; +import java.util.LinkedList; +import java.util.List; +import java.util.Random; + +import static java.util.Arrays.asList; +import static org.apache.kafka.test.StreamsTestUtils.toList; +import static org.apache.kafka.test.StreamsTestUtils.verifyKeyValueList; +import static org.apache.kafka.test.StreamsTestUtils.verifyWindowedKeyValue; +import static org.hamcrest.CoreMatchers.hasItem; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.equalTo; +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; + +public class CachingPersistentSessionStoreTest { + + private static final int MAX_CACHE_SIZE_BYTES = 600; + private static final Long DEFAULT_TIMESTAMP = 10L; + private static final long SEGMENT_INTERVAL = 100L; + private static final String TOPIC = "topic"; + private static final String CACHE_NAMESPACE = "0_0-store-name"; + + private final Bytes keyA = Bytes.wrap("a".getBytes()); + private final Bytes keyAA = Bytes.wrap("aa".getBytes()); + private final Bytes keyB = Bytes.wrap("b".getBytes()); + + private SessionStore underlyingStore; + private CachingSessionStore cachingStore; + private ThreadCache cache; + + @Before + public void before() { + final RocksDBSegmentedBytesStore segmented = new RocksDBSegmentedBytesStore( + "store-name", + "metric-scope", + Long.MAX_VALUE, + SEGMENT_INTERVAL, + new SessionKeySchema() + ); + underlyingStore = new RocksDBSessionStore(segmented); + cachingStore = new CachingSessionStore(underlyingStore, SEGMENT_INTERVAL); + cache = new ThreadCache(new LogContext("testCache "), MAX_CACHE_SIZE_BYTES, new MockStreamsMetrics(new Metrics())); + final InternalMockProcessorContext context = + new 
InternalMockProcessorContext<>(TestUtils.tempDirectory(), null, null, null, cache); + context.setRecordContext(new ProcessorRecordContext(DEFAULT_TIMESTAMP, 0, 0, TOPIC, new RecordHeaders())); + cachingStore.init((StateStoreContext) context, cachingStore); + } + + @After + public void after() { + cachingStore.close(); + } + + @Test + public void shouldPutFetchFromCache() { + cachingStore.put(new Windowed<>(keyA, new SessionWindow(0, 0)), "1".getBytes()); + cachingStore.put(new Windowed<>(keyAA, new SessionWindow(0, 0)), "1".getBytes()); + cachingStore.put(new Windowed<>(keyB, new SessionWindow(0, 0)), "1".getBytes()); + + assertEquals(3, cache.size()); + + try (final KeyValueIterator, byte[]> a = + cachingStore.findSessions(keyA, 0, 0); + final KeyValueIterator, byte[]> b = + cachingStore.findSessions(keyB, 0, 0)) { + + verifyWindowedKeyValue(a.next(), new Windowed<>(keyA, new SessionWindow(0, 0)), "1"); + verifyWindowedKeyValue(b.next(), new Windowed<>(keyB, new SessionWindow(0, 0)), "1"); + assertFalse(a.hasNext()); + assertFalse(b.hasNext()); + } + } + + @Test + public void shouldPutFetchAllKeysFromCache() { + cachingStore.put(new Windowed<>(keyA, new SessionWindow(0, 0)), "1".getBytes()); + cachingStore.put(new Windowed<>(keyAA, new SessionWindow(0, 0)), "1".getBytes()); + cachingStore.put(new Windowed<>(keyB, new SessionWindow(0, 0)), "1".getBytes()); + + assertEquals(3, cache.size()); + + try (final KeyValueIterator, byte[]> all = + cachingStore.fetch(keyA, keyB)) { + verifyWindowedKeyValue(all.next(), new Windowed<>(keyA, new SessionWindow(0, 0)), "1"); + verifyWindowedKeyValue(all.next(), new Windowed<>(keyAA, new SessionWindow(0, 0)), "1"); + verifyWindowedKeyValue(all.next(), new Windowed<>(keyB, new SessionWindow(0, 0)), "1"); + assertFalse(all.hasNext()); + } + + // infinite keyFrom fetch + try (final KeyValueIterator, byte[]> all = + cachingStore.fetch(null, keyB)) { + verifyWindowedKeyValue(all.next(), new Windowed<>(keyA, new SessionWindow(0, 0)), "1"); + verifyWindowedKeyValue(all.next(), new Windowed<>(keyAA, new SessionWindow(0, 0)), "1"); + verifyWindowedKeyValue(all.next(), new Windowed<>(keyB, new SessionWindow(0, 0)), "1"); + assertFalse(all.hasNext()); + } + + // infinite keyTo fetch + try (final KeyValueIterator, byte[]> all = + cachingStore.fetch(keyA, null)) { + verifyWindowedKeyValue(all.next(), new Windowed<>(keyA, new SessionWindow(0, 0)), "1"); + verifyWindowedKeyValue(all.next(), new Windowed<>(keyAA, new SessionWindow(0, 0)), "1"); + verifyWindowedKeyValue(all.next(), new Windowed<>(keyB, new SessionWindow(0, 0)), "1"); + assertFalse(all.hasNext()); + } + + // infinite keyFrom and keyTo fetch + try (final KeyValueIterator, byte[]> all = + cachingStore.fetch(null, null)) { + verifyWindowedKeyValue(all.next(), new Windowed<>(keyA, new SessionWindow(0, 0)), "1"); + verifyWindowedKeyValue(all.next(), new Windowed<>(keyAA, new SessionWindow(0, 0)), "1"); + verifyWindowedKeyValue(all.next(), new Windowed<>(keyB, new SessionWindow(0, 0)), "1"); + assertFalse(all.hasNext()); + } + } + + @Test + public void shouldPutBackwardFetchAllKeysFromCache() { + cachingStore.put(new Windowed<>(keyA, new SessionWindow(0, 0)), "1".getBytes()); + cachingStore.put(new Windowed<>(keyAA, new SessionWindow(0, 0)), "1".getBytes()); + cachingStore.put(new Windowed<>(keyB, new SessionWindow(0, 0)), "1".getBytes()); + + assertEquals(3, cache.size()); + + try (final KeyValueIterator, byte[]> all = + cachingStore.backwardFetch(keyA, keyB)) { + verifyWindowedKeyValue(all.next(), new 
Windowed<>(keyB, new SessionWindow(0, 0)), "1"); + verifyWindowedKeyValue(all.next(), new Windowed<>(keyAA, new SessionWindow(0, 0)), "1"); + verifyWindowedKeyValue(all.next(), new Windowed<>(keyA, new SessionWindow(0, 0)), "1"); + assertFalse(all.hasNext()); + } + + // infinite keyFrom fetch + try (final KeyValueIterator, byte[]> all = + cachingStore.backwardFetch(null, keyB)) { + verifyWindowedKeyValue(all.next(), new Windowed<>(keyB, new SessionWindow(0, 0)), "1"); + verifyWindowedKeyValue(all.next(), new Windowed<>(keyAA, new SessionWindow(0, 0)), "1"); + verifyWindowedKeyValue(all.next(), new Windowed<>(keyA, new SessionWindow(0, 0)), "1"); + assertFalse(all.hasNext()); + } + + // infinite keyTo fetch + try (final KeyValueIterator, byte[]> all = + cachingStore.backwardFetch(keyA, null)) { + verifyWindowedKeyValue(all.next(), new Windowed<>(keyB, new SessionWindow(0, 0)), "1"); + verifyWindowedKeyValue(all.next(), new Windowed<>(keyAA, new SessionWindow(0, 0)), "1"); + verifyWindowedKeyValue(all.next(), new Windowed<>(keyA, new SessionWindow(0, 0)), "1"); + assertFalse(all.hasNext()); + } + + // infinite keyFrom and keyTo fetch + try (final KeyValueIterator, byte[]> all = + cachingStore.backwardFetch(null, null)) { + verifyWindowedKeyValue(all.next(), new Windowed<>(keyB, new SessionWindow(0, 0)), "1"); + verifyWindowedKeyValue(all.next(), new Windowed<>(keyAA, new SessionWindow(0, 0)), "1"); + verifyWindowedKeyValue(all.next(), new Windowed<>(keyA, new SessionWindow(0, 0)), "1"); + assertFalse(all.hasNext()); + } + } + + @Test + public void shouldCloseWrappedStoreAndCacheAfterErrorDuringCacheFlush() { + setUpCloseTests(); + EasyMock.reset(cache); + cache.flush(CACHE_NAMESPACE); + EasyMock.expectLastCall().andThrow(new RuntimeException("Simulating an error on flush")); + EasyMock.replay(cache); + EasyMock.reset(underlyingStore); + underlyingStore.close(); + EasyMock.replay(underlyingStore); + + assertThrows(RuntimeException.class, cachingStore::close); + EasyMock.verify(cache, underlyingStore); + } + + @Test + public void shouldCloseWrappedStoreAfterErrorDuringCacheClose() { + setUpCloseTests(); + EasyMock.reset(cache); + cache.flush(CACHE_NAMESPACE); + cache.close(CACHE_NAMESPACE); + EasyMock.expectLastCall().andThrow(new RuntimeException("Simulating an error on close")); + EasyMock.replay(cache); + EasyMock.reset(underlyingStore); + underlyingStore.close(); + EasyMock.replay(underlyingStore); + + assertThrows(RuntimeException.class, cachingStore::close); + EasyMock.verify(cache, underlyingStore); + } + + @Test + public void shouldCloseCacheAfterErrorDuringWrappedStoreClose() { + setUpCloseTests(); + EasyMock.reset(cache); + cache.flush(CACHE_NAMESPACE); + cache.close(CACHE_NAMESPACE); + EasyMock.replay(cache); + EasyMock.reset(underlyingStore); + underlyingStore.close(); + EasyMock.expectLastCall().andThrow(new RuntimeException("Simulating an error on close")); + EasyMock.replay(underlyingStore); + + assertThrows(RuntimeException.class, cachingStore::close); + EasyMock.verify(cache, underlyingStore); + } + + private void setUpCloseTests() { + underlyingStore.close(); + underlyingStore = EasyMock.createNiceMock(SessionStore.class); + EasyMock.expect(underlyingStore.name()).andStubReturn("store-name"); + EasyMock.expect(underlyingStore.isOpen()).andStubReturn(true); + EasyMock.replay(underlyingStore); + cachingStore = new CachingSessionStore(underlyingStore, SEGMENT_INTERVAL); + cache = EasyMock.niceMock(ThreadCache.class); + final InternalMockProcessorContext context = + new 
InternalMockProcessorContext<>(TestUtils.tempDirectory(), null, null, null, cache); + context.setRecordContext(new ProcessorRecordContext(10, 0, 0, TOPIC, new RecordHeaders())); + cachingStore.init((StateStoreContext) context, cachingStore); + } + + @Test + public void shouldPutFetchRangeFromCache() { + cachingStore.put(new Windowed<>(keyA, new SessionWindow(0, 0)), "1".getBytes()); + cachingStore.put(new Windowed<>(keyAA, new SessionWindow(0, 0)), "1".getBytes()); + cachingStore.put(new Windowed<>(keyB, new SessionWindow(0, 0)), "1".getBytes()); + + assertEquals(3, cache.size()); + + try (final KeyValueIterator, byte[]> some = + cachingStore.findSessions(keyAA, keyB, 0, 0)) { + verifyWindowedKeyValue(some.next(), new Windowed<>(keyAA, new SessionWindow(0, 0)), "1"); + verifyWindowedKeyValue(some.next(), new Windowed<>(keyB, new SessionWindow(0, 0)), "1"); + assertFalse(some.hasNext()); + } + + // infinite keyFrom case + try (final KeyValueIterator, byte[]> some = + cachingStore.findSessions(null, keyAA, 0, 0)) { + verifyWindowedKeyValue(some.next(), new Windowed<>(keyA, new SessionWindow(0, 0)), "1"); + verifyWindowedKeyValue(some.next(), new Windowed<>(keyAA, new SessionWindow(0, 0)), "1"); + assertFalse(some.hasNext()); + } + + // infinite keyTo case + try (final KeyValueIterator, byte[]> some = + cachingStore.findSessions(keyAA, keyB, 0, 0)) { + verifyWindowedKeyValue(some.next(), new Windowed<>(keyAA, new SessionWindow(0, 0)), "1"); + verifyWindowedKeyValue(some.next(), new Windowed<>(keyB, new SessionWindow(0, 0)), "1"); + assertFalse(some.hasNext()); + } + + // infinite keyFrom and keyTo case + try (final KeyValueIterator, byte[]> some = + cachingStore.findSessions(null, null, 0, 0)) { + verifyWindowedKeyValue(some.next(), new Windowed<>(keyA, new SessionWindow(0, 0)), "1"); + verifyWindowedKeyValue(some.next(), new Windowed<>(keyAA, new SessionWindow(0, 0)), "1"); + verifyWindowedKeyValue(some.next(), new Windowed<>(keyB, new SessionWindow(0, 0)), "1"); + assertFalse(some.hasNext()); + } + } + + @Test + public void shouldPutBackwardFetchRangeFromCache() { + cachingStore.put(new Windowed<>(keyA, new SessionWindow(0, 0)), "1".getBytes()); + cachingStore.put(new Windowed<>(keyAA, new SessionWindow(0, 0)), "1".getBytes()); + cachingStore.put(new Windowed<>(keyB, new SessionWindow(0, 0)), "1".getBytes()); + + assertEquals(3, cache.size()); + + try (final KeyValueIterator, byte[]> some = + cachingStore.backwardFindSessions(keyAA, keyB, 0, 0)) { + verifyWindowedKeyValue(some.next(), new Windowed<>(keyB, new SessionWindow(0, 0)), "1"); + verifyWindowedKeyValue(some.next(), new Windowed<>(keyAA, new SessionWindow(0, 0)), "1"); + assertFalse(some.hasNext()); + } + + // infinite keyFrom case + try (final KeyValueIterator, byte[]> some = + cachingStore.backwardFindSessions(null, keyAA, 0, 0)) { + verifyWindowedKeyValue(some.next(), new Windowed<>(keyAA, new SessionWindow(0, 0)), "1"); + verifyWindowedKeyValue(some.next(), new Windowed<>(keyA, new SessionWindow(0, 0)), "1"); + assertFalse(some.hasNext()); + } + + // infinite keyTo case + try (final KeyValueIterator, byte[]> some = + cachingStore.backwardFindSessions(keyAA, keyB, 0, 0)) { + verifyWindowedKeyValue(some.next(), new Windowed<>(keyB, new SessionWindow(0, 0)), "1"); + verifyWindowedKeyValue(some.next(), new Windowed<>(keyAA, new SessionWindow(0, 0)), "1"); + assertFalse(some.hasNext()); + } + + // infinite keyFrom and keyTo case + try (final KeyValueIterator, byte[]> some = + cachingStore.backwardFindSessions(null, null, 0, 0)) { + 
verifyWindowedKeyValue(some.next(), new Windowed<>(keyB, new SessionWindow(0, 0)), "1"); + verifyWindowedKeyValue(some.next(), new Windowed<>(keyAA, new SessionWindow(0, 0)), "1"); + verifyWindowedKeyValue(some.next(), new Windowed<>(keyA, new SessionWindow(0, 0)), "1"); + assertFalse(some.hasNext()); + } + } + + @Test + public void shouldFetchAllSessionsWithSameRecordKey() { + final List, byte[]>> expected = asList( + KeyValue.pair(new Windowed<>(keyA, new SessionWindow(0, 0)), "1".getBytes()), + KeyValue.pair(new Windowed<>(keyA, new SessionWindow(10, 10)), "2".getBytes()), + KeyValue.pair(new Windowed<>(keyA, new SessionWindow(100, 100)), "3".getBytes()), + KeyValue.pair(new Windowed<>(keyA, new SessionWindow(1000, 1000)), "4".getBytes()) + ); + for (final KeyValue, byte[]> kv : expected) { + cachingStore.put(kv.key, kv.value); + } + + // add one that shouldn't appear in the results + cachingStore.put(new Windowed<>(keyAA, new SessionWindow(0, 0)), "5".getBytes()); + + final List, byte[]>> results = toList(cachingStore.fetch(keyA)); + verifyKeyValueList(expected, results); + } + + @Test + public void shouldBackwardFetchAllSessionsWithSameRecordKey() { + final List, byte[]>> expected = asList( + KeyValue.pair(new Windowed<>(keyA, new SessionWindow(0, 0)), "1".getBytes()), + KeyValue.pair(new Windowed<>(keyA, new SessionWindow(10, 10)), "2".getBytes()), + KeyValue.pair(new Windowed<>(keyA, new SessionWindow(100, 100)), "3".getBytes()), + KeyValue.pair(new Windowed<>(keyA, new SessionWindow(1000, 1000)), "4".getBytes()) + ); + for (final KeyValue, byte[]> kv : expected) { + cachingStore.put(kv.key, kv.value); + } + + // add one that shouldn't appear in the results + cachingStore.put(new Windowed<>(keyAA, new SessionWindow(0, 0)), "5".getBytes()); + + final List, byte[]>> results = toList(cachingStore.backwardFetch(keyA)); + Collections.reverse(results); + verifyKeyValueList(expected, results); + } + + @Test + public void shouldFlushItemsToStoreOnEviction() { + final List, byte[]>> added = addSessionsUntilOverflow("a", "b", "c", "d"); + assertEquals(added.size() - 1, cache.size()); + try (final KeyValueIterator, byte[]> iterator = + cachingStore.findSessions(added.get(0).key.key(), 0, 0)) { + final KeyValue, byte[]> next = iterator.next(); + assertEquals(added.get(0).key, next.key); + assertArrayEquals(added.get(0).value, next.value); + } + } + + @Test + public void shouldQueryItemsInCacheAndStore() { + final List, byte[]>> added = addSessionsUntilOverflow("a"); + final List, byte[]>> actual = toList(cachingStore.findSessions( + Bytes.wrap("a".getBytes(StandardCharsets.UTF_8)), + 0, + added.size() * 10L + )); + verifyKeyValueList(added, actual); + } + + @Test + public void shouldRemove() { + final Windowed a = new Windowed<>(keyA, new SessionWindow(0, 0)); + final Windowed b = new Windowed<>(keyB, new SessionWindow(0, 0)); + cachingStore.put(a, "2".getBytes()); + cachingStore.put(b, "2".getBytes()); + cachingStore.remove(a); + + try (final KeyValueIterator, byte[]> rangeIter = + cachingStore.findSessions(keyA, 0, 0)) { + assertFalse(rangeIter.hasNext()); + } + + assertNull(cachingStore.fetchSession(keyA, 0, 0)); + assertThat(cachingStore.fetchSession(keyB, 0, 0), equalTo("2".getBytes())); + + } + + @Test + public void shouldFetchCorrectlyAcrossSegments() { + final Windowed a1 = new Windowed<>(keyA, new SessionWindow(SEGMENT_INTERVAL * 0, SEGMENT_INTERVAL * 0)); + final Windowed a2 = new Windowed<>(keyA, new SessionWindow(SEGMENT_INTERVAL * 1, SEGMENT_INTERVAL * 1)); + final Windowed a3 = 
new Windowed<>(keyA, new SessionWindow(SEGMENT_INTERVAL * 2, SEGMENT_INTERVAL * 2)); + final Windowed a4 = new Windowed<>(keyA, new SessionWindow(SEGMENT_INTERVAL * 3, SEGMENT_INTERVAL * 3)); + final Windowed a5 = new Windowed<>(keyA, new SessionWindow(SEGMENT_INTERVAL * 4, SEGMENT_INTERVAL * 4)); + final Windowed a6 = new Windowed<>(keyA, new SessionWindow(SEGMENT_INTERVAL * 5, SEGMENT_INTERVAL * 5)); + cachingStore.put(a1, "1".getBytes()); + cachingStore.put(a2, "2".getBytes()); + cachingStore.put(a3, "3".getBytes()); + cachingStore.flush(); + cachingStore.put(a4, "4".getBytes()); + cachingStore.put(a5, "5".getBytes()); + cachingStore.put(a6, "6".getBytes()); + try (final KeyValueIterator, byte[]> results = + cachingStore.findSessions(keyA, 0, SEGMENT_INTERVAL * 5)) { + assertEquals(a1, results.next().key); + assertEquals(a2, results.next().key); + assertEquals(a3, results.next().key); + assertEquals(a4, results.next().key); + assertEquals(a5, results.next().key); + assertEquals(a6, results.next().key); + assertFalse(results.hasNext()); + } + } + + @Test + public void shouldBackwardFetchCorrectlyAcrossSegments() { + final Windowed a1 = new Windowed<>(keyA, new SessionWindow(SEGMENT_INTERVAL * 0, SEGMENT_INTERVAL * 0)); + final Windowed a2 = new Windowed<>(keyA, new SessionWindow(SEGMENT_INTERVAL * 1, SEGMENT_INTERVAL * 1)); + final Windowed a3 = new Windowed<>(keyA, new SessionWindow(SEGMENT_INTERVAL * 2, SEGMENT_INTERVAL * 2)); + final Windowed a4 = new Windowed<>(keyA, new SessionWindow(SEGMENT_INTERVAL * 3, SEGMENT_INTERVAL * 3)); + final Windowed a5 = new Windowed<>(keyA, new SessionWindow(SEGMENT_INTERVAL * 4, SEGMENT_INTERVAL * 4)); + final Windowed a6 = new Windowed<>(keyA, new SessionWindow(SEGMENT_INTERVAL * 5, SEGMENT_INTERVAL * 5)); + cachingStore.put(a1, "1".getBytes()); + cachingStore.put(a2, "2".getBytes()); + cachingStore.put(a3, "3".getBytes()); + cachingStore.flush(); + cachingStore.put(a4, "4".getBytes()); + cachingStore.put(a5, "5".getBytes()); + cachingStore.put(a6, "6".getBytes()); + try (final KeyValueIterator, byte[]> results = + cachingStore.backwardFindSessions(keyA, 0, SEGMENT_INTERVAL * 5)) { + assertEquals(a6, results.next().key); + assertEquals(a5, results.next().key); + assertEquals(a4, results.next().key); + assertEquals(a3, results.next().key); + assertEquals(a2, results.next().key); + assertEquals(a1, results.next().key); + assertFalse(results.hasNext()); + } + } + + @Test + public void shouldFetchRangeCorrectlyAcrossSegments() { + final Windowed a1 = new Windowed<>(keyA, new SessionWindow(SEGMENT_INTERVAL * 0, SEGMENT_INTERVAL * 0)); + final Windowed aa1 = new Windowed<>(keyAA, new SessionWindow(SEGMENT_INTERVAL * 0, SEGMENT_INTERVAL * 0)); + final Windowed a2 = new Windowed<>(keyA, new SessionWindow(SEGMENT_INTERVAL * 1, SEGMENT_INTERVAL * 1)); + final Windowed a3 = new Windowed<>(keyA, new SessionWindow(SEGMENT_INTERVAL * 2, SEGMENT_INTERVAL * 2)); + final Windowed aa3 = new Windowed<>(keyAA, new SessionWindow(SEGMENT_INTERVAL * 2, SEGMENT_INTERVAL * 2)); + cachingStore.put(a1, "1".getBytes()); + cachingStore.put(aa1, "1".getBytes()); + cachingStore.put(a2, "2".getBytes()); + cachingStore.put(a3, "3".getBytes()); + cachingStore.put(aa3, "3".getBytes()); + + final KeyValueIterator, byte[]> rangeResults = + cachingStore.findSessions(keyA, keyAA, 0, SEGMENT_INTERVAL * 2); + final List> keys = new ArrayList<>(); + while (rangeResults.hasNext()) { + keys.add(rangeResults.next().key); + } + rangeResults.close(); + assertEquals(asList(a1, aa1, a2, a3, aa3), 
keys); + } + + @Test + public void shouldBackwardFetchRangeCorrectlyAcrossSegments() { + final Windowed a1 = new Windowed<>(keyA, new SessionWindow(SEGMENT_INTERVAL * 0, SEGMENT_INTERVAL * 0)); + final Windowed aa1 = new Windowed<>(keyAA, new SessionWindow(SEGMENT_INTERVAL * 0, SEGMENT_INTERVAL * 0)); + final Windowed a2 = new Windowed<>(keyA, new SessionWindow(SEGMENT_INTERVAL * 1, SEGMENT_INTERVAL * 1)); + final Windowed a3 = new Windowed<>(keyA, new SessionWindow(SEGMENT_INTERVAL * 2, SEGMENT_INTERVAL * 2)); + final Windowed aa3 = new Windowed<>(keyAA, new SessionWindow(SEGMENT_INTERVAL * 2, SEGMENT_INTERVAL * 2)); + cachingStore.put(a1, "1".getBytes()); + cachingStore.put(aa1, "1".getBytes()); + cachingStore.put(a2, "2".getBytes()); + cachingStore.put(a3, "3".getBytes()); + cachingStore.put(aa3, "3".getBytes()); + + final KeyValueIterator, byte[]> rangeResults = + cachingStore.backwardFindSessions(keyA, keyAA, 0, SEGMENT_INTERVAL * 2); + final List> keys = new ArrayList<>(); + while (rangeResults.hasNext()) { + keys.add(rangeResults.next().key); + } + rangeResults.close(); + assertEquals(asList(aa3, a3, a2, aa1, a1), keys); + } + + @Test + public void shouldSetFlushListener() { + assertTrue(cachingStore.setFlushListener(null, true)); + assertTrue(cachingStore.setFlushListener(null, false)); + } + + @Test + public void shouldForwardChangedValuesDuringFlush() { + final Windowed a = new Windowed<>(keyA, new SessionWindow(2, 4)); + final Windowed b = new Windowed<>(keyA, new SessionWindow(1, 2)); + final Windowed aDeserialized = new Windowed<>("a", new SessionWindow(2, 4)); + final Windowed bDeserialized = new Windowed<>("a", new SessionWindow(1, 2)); + final CacheFlushListenerStub, String> flushListener = + new CacheFlushListenerStub<>( + new SessionWindowedDeserializer<>(new StringDeserializer()), + new StringDeserializer() + ); + cachingStore.setFlushListener(flushListener, true); + + cachingStore.put(b, "1".getBytes()); + cachingStore.flush(); + + assertEquals( + Collections.singletonList( + new KeyValueTimestamp<>( + bDeserialized, + new Change<>("1", null), + DEFAULT_TIMESTAMP + ) + ), + flushListener.forwarded + ); + flushListener.forwarded.clear(); + + cachingStore.put(a, "1".getBytes()); + cachingStore.flush(); + + assertEquals( + Collections.singletonList( + new KeyValueTimestamp<>( + aDeserialized, + new Change<>("1", null), + DEFAULT_TIMESTAMP + ) + ), + flushListener.forwarded + ); + flushListener.forwarded.clear(); + + cachingStore.put(a, "2".getBytes()); + cachingStore.flush(); + + assertEquals( + Collections.singletonList( + new KeyValueTimestamp<>( + aDeserialized, + new Change<>("2", "1"), + DEFAULT_TIMESTAMP + ) + ), + flushListener.forwarded + ); + flushListener.forwarded.clear(); + + cachingStore.remove(a); + cachingStore.flush(); + + assertEquals( + Collections.singletonList( + new KeyValueTimestamp<>( + aDeserialized, + new Change<>(null, "2"), + DEFAULT_TIMESTAMP + ) + ), + flushListener.forwarded + ); + flushListener.forwarded.clear(); + + cachingStore.put(a, "1".getBytes()); + cachingStore.put(a, "2".getBytes()); + cachingStore.remove(a); + cachingStore.flush(); + + assertEquals( + Collections.emptyList(), + flushListener.forwarded + ); + flushListener.forwarded.clear(); + } + + @Test + public void shouldNotForwardChangedValuesDuringFlushWhenSendOldValuesDisabled() { + final Windowed a = new Windowed<>(keyA, new SessionWindow(0, 0)); + final Windowed aDeserialized = new Windowed<>("a", new SessionWindow(0, 0)); + final CacheFlushListenerStub, String> flushListener 
= + new CacheFlushListenerStub<>( + new SessionWindowedDeserializer<>(new StringDeserializer()), + new StringDeserializer()); + cachingStore.setFlushListener(flushListener, false); + + cachingStore.put(a, "1".getBytes()); + cachingStore.flush(); + + cachingStore.put(a, "2".getBytes()); + cachingStore.flush(); + + cachingStore.remove(a); + cachingStore.flush(); + + assertEquals( + asList( + new KeyValueTimestamp<>( + aDeserialized, + new Change<>("1", null), + DEFAULT_TIMESTAMP + ), + new KeyValueTimestamp<>( + aDeserialized, + new Change<>("2", null), + DEFAULT_TIMESTAMP + ), + new KeyValueTimestamp<>( + aDeserialized, + new Change<>(null, null), + DEFAULT_TIMESTAMP + ) + ), + flushListener.forwarded + ); + flushListener.forwarded.clear(); + + cachingStore.put(a, "1".getBytes()); + cachingStore.put(a, "2".getBytes()); + cachingStore.remove(a); + cachingStore.flush(); + + assertEquals( + Collections.emptyList(), + flushListener.forwarded + ); + flushListener.forwarded.clear(); + } + + @Test + public void shouldReturnSameResultsForSingleKeyFindSessionsAndEqualKeyRangeFindSessions() { + cachingStore.put(new Windowed<>(keyA, new SessionWindow(0, 1)), "1".getBytes()); + cachingStore.put(new Windowed<>(keyAA, new SessionWindow(2, 3)), "2".getBytes()); + cachingStore.put(new Windowed<>(keyAA, new SessionWindow(4, 5)), "3".getBytes()); + cachingStore.put(new Windowed<>(keyB, new SessionWindow(6, 7)), "4".getBytes()); + + try (final KeyValueIterator, byte[]> singleKeyIterator = + cachingStore.findSessions(keyAA, 0L, 10L); + final KeyValueIterator, byte[]> keyRangeIterator = + cachingStore.findSessions(keyAA, keyAA, 0L, 10L)) { + + assertEquals(singleKeyIterator.next(), keyRangeIterator.next()); + assertEquals(singleKeyIterator.next(), keyRangeIterator.next()); + assertFalse(singleKeyIterator.hasNext()); + assertFalse(keyRangeIterator.hasNext()); + } + } + + @Test + public void shouldReturnSameResultsForSingleKeyFindSessionsBackwardsAndEqualKeyRangeFindSessions() { + cachingStore.put(new Windowed<>(keyA, new SessionWindow(0, 1)), "1".getBytes()); + cachingStore.put(new Windowed<>(keyAA, new SessionWindow(2, 3)), "2".getBytes()); + cachingStore.put(new Windowed<>(keyAA, new SessionWindow(4, 5)), "3".getBytes()); + cachingStore.put(new Windowed<>(keyB, new SessionWindow(6, 7)), "4".getBytes()); + + try (final KeyValueIterator, byte[]> singleKeyIterator = + cachingStore.backwardFindSessions(keyAA, 0L, 10L); + final KeyValueIterator, byte[]> keyRangeIterator = + cachingStore.backwardFindSessions(keyAA, keyAA, 0L, 10L)) { + + assertEquals(singleKeyIterator.next(), keyRangeIterator.next()); + assertEquals(singleKeyIterator.next(), keyRangeIterator.next()); + assertFalse(singleKeyIterator.hasNext()); + assertFalse(keyRangeIterator.hasNext()); + } + } + + @Test + public void shouldClearNamespaceCacheOnClose() { + final Windowed a1 = new Windowed<>(keyA, new SessionWindow(0, 0)); + cachingStore.put(a1, "1".getBytes()); + assertEquals(1, cache.size()); + cachingStore.close(); + assertEquals(0, cache.size()); + } + + @Test + public void shouldThrowIfTryingToFetchFromClosedCachingStore() { + cachingStore.close(); + assertThrows(InvalidStateStoreException.class, () -> cachingStore.fetch(keyA)); + } + + @Test + public void shouldThrowIfTryingToFindMergeSessionFromClosedCachingStore() { + cachingStore.close(); + assertThrows(InvalidStateStoreException.class, () -> cachingStore.findSessions(keyA, 0, Long.MAX_VALUE)); + } + + @Test + public void shouldThrowIfTryingToRemoveFromClosedCachingStore() { + 
cachingStore.close();
+        assertThrows(InvalidStateStoreException.class, () -> cachingStore.remove(new Windowed<>(keyA, new SessionWindow(0, 0))));
+    }
+
+    @Test
+    public void shouldThrowIfTryingToPutIntoClosedCachingStore() {
+        cachingStore.close();
+        assertThrows(InvalidStateStoreException.class, () -> cachingStore.put(new Windowed<>(keyA, new SessionWindow(0, 0)), "1".getBytes()));
+    }
+
+    @Test
+    public void shouldThrowNullPointerExceptionOnFindSessionsNullKey() {
+        assertThrows(NullPointerException.class, () -> cachingStore.findSessions(null, 1L, 2L));
+    }
+
+    @Test
+    public void shouldThrowNullPointerExceptionOnFetchNullKey() {
+        assertThrows(NullPointerException.class, () -> cachingStore.fetch(null));
+    }
+
+    @Test
+    public void shouldThrowNullPointerExceptionOnRemoveNullKey() {
+        assertThrows(NullPointerException.class, () -> cachingStore.remove(null));
+    }
+
+    @Test
+    public void shouldThrowNullPointerExceptionOnPutNullKey() {
+        assertThrows(NullPointerException.class, () -> cachingStore.put(null, "1".getBytes()));
+    }
+
+    @Test
+    public void shouldNotThrowInvalidRangeExceptionWhenBackwardWithNegativeFromKey() {
+        final Bytes keyFrom = Bytes.wrap(Serdes.Integer().serializer().serialize("", -1));
+        final Bytes keyTo = Bytes.wrap(Serdes.Integer().serializer().serialize("", 1));
+
+        try (final LogCaptureAppender appender = LogCaptureAppender.createAndRegister(CachingSessionStore.class);
+             final KeyValueIterator<Windowed<Bytes>, byte[]> iterator =
+                 cachingStore.backwardFindSessions(keyFrom, keyTo, 0L, 10L)) {
+            assertFalse(iterator.hasNext());
+
+            final List<String> messages = appender.getMessages();
+            assertThat(
+                messages,
+                hasItem(
+                    "Returning empty iterator for fetch with invalid key range: from > to." +
+                        " This may be due to range arguments set in the wrong order, " +
+                        "or serdes that don't preserve ordering when lexicographically comparing the serialized bytes." +
+                        " Note that the built-in numerical serdes do not follow this for negative numbers"
+                )
+            );
+        }
+    }
+
+    @Test
+    public void shouldNotThrowInvalidRangeExceptionWithNegativeFromKey() {
+        final Bytes keyFrom = Bytes.wrap(Serdes.Integer().serializer().serialize("", -1));
+        final Bytes keyTo = Bytes.wrap(Serdes.Integer().serializer().serialize("", 1));
+
+        try (final LogCaptureAppender appender = LogCaptureAppender.createAndRegister(CachingSessionStore.class);
+             final KeyValueIterator<Windowed<Bytes>, byte[]> iterator = cachingStore.findSessions(keyFrom, keyTo, 0L, 10L)) {
+            assertFalse(iterator.hasNext());
+
+            final List<String> messages = appender.getMessages();
+            assertThat(
+                messages,
+                hasItem(
+                    "Returning empty iterator for fetch with invalid key range: from > to." +
+                        " This may be due to range arguments set in the wrong order, " +
+                        "or serdes that don't preserve ordering when lexicographically comparing the serialized bytes." +
+                        " Note that the built-in numerical serdes do not follow this for negative numbers"
+                )
+            );
+        }
+    }
+
+    private List<KeyValue<Windowed<Bytes>, byte[]>> addSessionsUntilOverflow(final String... sessionIds) {
+        final Random random = new Random();
+        final List<KeyValue<Windowed<Bytes>, byte[]>> results = new ArrayList<>();
+        while (cache.size() == results.size()) {
+            final String sessionId = sessionIds[random.nextInt(sessionIds.length)];
+            addSingleSession(sessionId, results);
+        }
+        return results;
+    }
+
+    private void addSingleSession(final String sessionId, final List<KeyValue<Windowed<Bytes>, byte[]>> allSessions) {
+        final int timestamp = allSessions.size() * 10;
+        final Windowed<Bytes> key = new Windowed<>(Bytes.wrap(sessionId.getBytes()), new SessionWindow(timestamp, timestamp));
+        final byte[] value = "1".getBytes();
+        cachingStore.put(key, value);
+        allSessions.add(KeyValue.pair(key, value));
+    }
+
+    public static class CacheFlushListenerStub<K, V> implements CacheFlushListener<byte[], byte[]> {
+        private final Deserializer<K> keyDeserializer;
+        private final Deserializer<V> valueDeserializer;
+        private final List<KeyValueTimestamp<K, Change<V>>> forwarded = new LinkedList<>();
+
+        CacheFlushListenerStub(final Deserializer<K> keyDeserializer,
+                               final Deserializer<V> valueDeserializer) {
+            this.keyDeserializer = keyDeserializer;
+            this.valueDeserializer = valueDeserializer;
+        }
+
+        @Override
+        public void apply(final byte[] key,
+                          final byte[] newValue,
+                          final byte[] oldValue,
+                          final long timestamp) {
+            forwarded.add(
+                new KeyValueTimestamp<>(
+                    keyDeserializer.deserialize(null, key),
+                    new Change<>(
+                        valueDeserializer.deserialize(null, newValue),
+                        valueDeserializer.deserialize(null, oldValue)),
+                    timestamp
+                )
+            );
+        }
+
+        @Override
+        public void apply(final Record<byte[], Change<byte[]>> record) {
+            forwarded.add(
+                new KeyValueTimestamp<>(
+                    keyDeserializer.deserialize(null, record.key()),
+                    new Change<>(
+                        valueDeserializer.deserialize(null, record.value().newValue),
+                        valueDeserializer.deserialize(null, record.value().oldValue)),
+                    record.timestamp()
+                )
+            );
+        }
+    }
+}
diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/CachingPersistentWindowStoreTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/CachingPersistentWindowStoreTest.java
new file mode 100644
index 0000000..2d64a44
--- /dev/null
+++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/CachingPersistentWindowStoreTest.java
@@ -0,0 +1,1074 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.clients.consumer.ConsumerConfig; +import org.apache.kafka.common.header.internals.RecordHeaders; +import org.apache.kafka.common.metrics.Metrics; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.serialization.StringDeserializer; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.TestInputTopic; +import org.apache.kafka.streams.TopologyTestDriver; +import org.apache.kafka.streams.errors.InvalidStateStoreException; +import org.apache.kafka.streams.kstream.Consumed; +import org.apache.kafka.streams.kstream.TimeWindowedDeserializer; +import org.apache.kafka.streams.kstream.Transformer; +import org.apache.kafka.streams.kstream.Windowed; +import org.apache.kafka.streams.kstream.internals.TimeWindow; +import org.apache.kafka.streams.processor.ProcessorContext; +import org.apache.kafka.streams.processor.StateStoreContext; +import org.apache.kafka.streams.processor.internals.MockStreamsMetrics; +import org.apache.kafka.streams.processor.internals.ProcessorRecordContext; +import org.apache.kafka.streams.processor.internals.testutil.LogCaptureAppender; +import org.apache.kafka.streams.state.KeyValueIterator; +import org.apache.kafka.streams.state.StoreBuilder; +import org.apache.kafka.streams.state.Stores; +import org.apache.kafka.streams.state.WindowStore; +import org.apache.kafka.streams.state.WindowStoreIterator; +import org.apache.kafka.test.InternalMockProcessorContext; +import org.apache.kafka.test.TestUtils; +import org.easymock.EasyMock; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import java.nio.charset.StandardCharsets; +import java.time.Duration; +import java.time.Instant; +import java.util.Arrays; +import java.util.List; +import java.util.Properties; +import java.util.UUID; + +import static java.time.Duration.ofHours; +import static java.time.Duration.ofMinutes; +import static java.time.Instant.ofEpochMilli; +import static java.util.Arrays.asList; +import static org.apache.kafka.streams.state.internals.ThreadCacheTest.memoryCacheEntrySize; +import static org.apache.kafka.test.StreamsTestUtils.toList; +import static org.apache.kafka.test.StreamsTestUtils.verifyAllWindowedKeyValues; +import static org.apache.kafka.test.StreamsTestUtils.verifyKeyValueList; +import static org.apache.kafka.test.StreamsTestUtils.verifyWindowedKeyValue; +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.CoreMatchers.hasItem; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; + +public class CachingPersistentWindowStoreTest { + + private static final int MAX_CACHE_SIZE_BYTES = 150; + private static final long DEFAULT_TIMESTAMP = 10L; + private static final Long WINDOW_SIZE = 10L; + private static final long SEGMENT_INTERVAL = 100L; + private final static String TOPIC = "topic"; + private static final String CACHE_NAMESPACE = "0_0-store-name"; + + private InternalMockProcessorContext context; + private RocksDBSegmentedBytesStore bytesStore; + private 
WindowStore<Bytes, byte[]> underlyingStore;
+    private CachingWindowStore cachingStore;
+    private CacheFlushListenerStub<Windowed<String>, String> cacheListener;
+    private ThreadCache cache;
+    private WindowKeySchema keySchema;
+
+    @Before
+    public void setUp() {
+        keySchema = new WindowKeySchema();
+        bytesStore = new RocksDBSegmentedBytesStore("test", "metrics-scope", 0, SEGMENT_INTERVAL, keySchema);
+        underlyingStore = new RocksDBWindowStore(bytesStore, false, WINDOW_SIZE);
+        final TimeWindowedDeserializer<String> keyDeserializer = new TimeWindowedDeserializer<>(new StringDeserializer(), WINDOW_SIZE);
+        keyDeserializer.setIsChangelogTopic(true);
+        cacheListener = new CacheFlushListenerStub<>(keyDeserializer, new StringDeserializer());
+        cachingStore = new CachingWindowStore(underlyingStore, WINDOW_SIZE, SEGMENT_INTERVAL);
+        cachingStore.setFlushListener(cacheListener, false);
+        cache = new ThreadCache(new LogContext("testCache "), MAX_CACHE_SIZE_BYTES, new MockStreamsMetrics(new Metrics()));
+        context = new InternalMockProcessorContext<>(TestUtils.tempDirectory(), null, null, null, cache);
+        context.setRecordContext(new ProcessorRecordContext(DEFAULT_TIMESTAMP, 0, 0, TOPIC, new RecordHeaders()));
+        cachingStore.init((StateStoreContext) context, cachingStore);
+    }
+
+    @After
+    public void closeStore() {
+        cachingStore.close();
+    }
+
+    @SuppressWarnings("deprecation")
+    @Test
+    public void shouldDelegateDeprecatedInit() {
+        final WindowStore<Bytes, byte[]> inner = EasyMock.mock(WindowStore.class);
+        final CachingWindowStore outer = new CachingWindowStore(inner, WINDOW_SIZE, SEGMENT_INTERVAL);
+        EasyMock.expect(inner.name()).andStubReturn("store");
+        inner.init((ProcessorContext) context, outer);
+        EasyMock.expectLastCall();
+        EasyMock.replay(inner);
+        outer.init((ProcessorContext) context, outer);
+        EasyMock.verify(inner);
+    }
+
+    @Test
+    public void shouldDelegateInit() {
+        final WindowStore<Bytes, byte[]> inner = EasyMock.mock(WindowStore.class);
+        final CachingWindowStore outer = new CachingWindowStore(inner, WINDOW_SIZE, SEGMENT_INTERVAL);
+        EasyMock.expect(inner.name()).andStubReturn("store");
+        inner.init((StateStoreContext) context, outer);
+        EasyMock.expectLastCall();
+        EasyMock.replay(inner);
+        outer.init((StateStoreContext) context, outer);
+        EasyMock.verify(inner);
+    }
+
+    @Test
+    public void shouldNotReturnDuplicatesInRanges() {
+        final StreamsBuilder builder = new StreamsBuilder();
+
+        final StoreBuilder<WindowStore<String, String>> storeBuilder = Stores.windowStoreBuilder(
+            Stores.persistentWindowStore("store-name", ofHours(1L), ofMinutes(1L), false),
+            Serdes.String(),
+            Serdes.String())
+            .withCachingEnabled();
+
+        builder.addStateStore(storeBuilder);
+
+        builder.stream(TOPIC,
+            Consumed.with(Serdes.String(), Serdes.String()))
+            .transform(() -> new Transformer<String, String, KeyValue<String, String>>() {
+                private WindowStore<String, String> store;
+                private int numRecordsProcessed;
+                private ProcessorContext context;
+
+                @SuppressWarnings("unchecked")
+                @Override
+                public void init(final ProcessorContext processorContext) {
+                    this.context = processorContext;
+                    this.store = (WindowStore<String, String>) processorContext.getStateStore("store-name");
+                    int count = 0;
+
+                    try (final KeyValueIterator<Windowed<String>, String> all = store.all()) {
+                        while (all.hasNext()) {
+                            count++;
+                            all.next();
+                        }
+                    }
+
+                    assertThat(count, equalTo(0));
+                }
+
+                @Override
+                public KeyValue<String, String> transform(final String key, final String value) {
+                    int count = 0;
+
+                    try (final KeyValueIterator<Windowed<String>, String> all = store.all()) {
+                        while (all.hasNext()) {
+                            count++;
+                            all.next();
+                        }
+                    }
+
+                    assertThat(count, equalTo(numRecordsProcessed));
+
+                    store.put(value, value, context.timestamp());
+ + numRecordsProcessed++; + + return new KeyValue<>(key, value); + } + + @Override + public void close() { + } + }, "store-name"); + + final Properties streamsConfiguration = new Properties(); + streamsConfiguration.put(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "earliest"); + streamsConfiguration.put(StreamsConfig.DEFAULT_KEY_SERDE_CLASS_CONFIG, Serdes.String().getClass().getName()); + streamsConfiguration.put(StreamsConfig.DEFAULT_VALUE_SERDE_CLASS_CONFIG, Serdes.String().getClass().getName()); + streamsConfiguration.put(StreamsConfig.STATE_DIR_CONFIG, TestUtils.tempDirectory().getPath()); + streamsConfiguration.put(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG, 10 * 1000L); + + final Instant initialWallClockTime = Instant.ofEpochMilli(0L); + final TopologyTestDriver driver = new TopologyTestDriver(builder.build(), streamsConfiguration, initialWallClockTime); + + final TestInputTopic inputTopic = driver.createInputTopic(TOPIC, + Serdes.String().serializer(), + Serdes.String().serializer(), + initialWallClockTime, + Duration.ZERO); + + for (int i = 0; i < 5; i++) { + inputTopic.pipeInput(UUID.randomUUID().toString(), UUID.randomUUID().toString()); + } + driver.advanceWallClockTime(Duration.ofSeconds(10)); + inputTopic.advanceTime(Duration.ofSeconds(10)); + for (int i = 0; i < 5; i++) { + inputTopic.pipeInput(UUID.randomUUID().toString(), UUID.randomUUID().toString()); + } + driver.advanceWallClockTime(Duration.ofSeconds(10)); + inputTopic.advanceTime(Duration.ofSeconds(10)); + for (int i = 0; i < 5; i++) { + inputTopic.pipeInput(UUID.randomUUID().toString(), UUID.randomUUID().toString()); + } + driver.advanceWallClockTime(Duration.ofSeconds(10)); + inputTopic.advanceTime(Duration.ofSeconds(10)); + for (int i = 0; i < 5; i++) { + inputTopic.pipeInput(UUID.randomUUID().toString(), UUID.randomUUID().toString()); + } + + driver.close(); + } + + @Test + public void shouldPutFetchFromCache() { + cachingStore.put(bytesKey("a"), bytesValue("a"), DEFAULT_TIMESTAMP); + cachingStore.put(bytesKey("b"), bytesValue("b"), DEFAULT_TIMESTAMP); + + assertThat(cachingStore.fetch(bytesKey("a"), 10), equalTo(bytesValue("a"))); + assertThat(cachingStore.fetch(bytesKey("b"), 10), equalTo(bytesValue("b"))); + assertThat(cachingStore.fetch(bytesKey("c"), 10), equalTo(null)); + assertThat(cachingStore.fetch(bytesKey("a"), 0), equalTo(null)); + + try (final WindowStoreIterator a = cachingStore.fetch(bytesKey("a"), ofEpochMilli(10), ofEpochMilli(10)); + final WindowStoreIterator b = cachingStore.fetch(bytesKey("b"), ofEpochMilli(10), ofEpochMilli(10))) { + verifyKeyValue(a.next(), DEFAULT_TIMESTAMP, "a"); + verifyKeyValue(b.next(), DEFAULT_TIMESTAMP, "b"); + assertFalse(a.hasNext()); + assertFalse(b.hasNext()); + assertEquals(2, cache.size()); + } + } + + private void verifyKeyValue(final KeyValue next, + final long expectedKey, + final String expectedValue) { + assertThat(next.key, equalTo(expectedKey)); + assertThat(next.value, equalTo(bytesValue(expectedValue))); + } + + private static byte[] bytesValue(final String value) { + return value.getBytes(); + } + + private static Bytes bytesKey(final String key) { + return Bytes.wrap(key.getBytes()); + } + + private String stringFrom(final byte[] from) { + return Serdes.String().deserializer().deserialize("", from); + } + + @Test + public void shouldPutFetchRangeFromCache() { + cachingStore.put(bytesKey("a"), bytesValue("a"), DEFAULT_TIMESTAMP); + cachingStore.put(bytesKey("b"), bytesValue("b"), DEFAULT_TIMESTAMP); + + try (final KeyValueIterator, byte[]> iterator = + 
cachingStore.fetch(bytesKey("a"), bytesKey("b"), ofEpochMilli(DEFAULT_TIMESTAMP), ofEpochMilli(DEFAULT_TIMESTAMP))) { + final List> expectedKeys = Arrays.asList( + new Windowed<>(bytesKey("a"), new TimeWindow(DEFAULT_TIMESTAMP, DEFAULT_TIMESTAMP + WINDOW_SIZE)), + new Windowed<>(bytesKey("b"), new TimeWindow(DEFAULT_TIMESTAMP, DEFAULT_TIMESTAMP + WINDOW_SIZE)) + ); + + final List expectedValues = Arrays.asList("a", "b"); + + verifyAllWindowedKeyValues(iterator, expectedKeys, expectedValues); + assertEquals(2, cache.size()); + } + } + + @Test + public void shouldPutFetchRangeFromCacheForNullKeyFrom() { + cachingStore.put(bytesKey("a"), bytesValue("a"), DEFAULT_TIMESTAMP); + cachingStore.put(bytesKey("b"), bytesValue("b"), DEFAULT_TIMESTAMP); + cachingStore.put(bytesKey("c"), bytesValue("c"), DEFAULT_TIMESTAMP + 10L); + cachingStore.put(bytesKey("d"), bytesValue("d"), DEFAULT_TIMESTAMP + 20L); + cachingStore.put(bytesKey("e"), bytesValue("e"), DEFAULT_TIMESTAMP + 20L); + + try (final KeyValueIterator, byte[]> iterator = + cachingStore.fetch(null, bytesKey("d"), ofEpochMilli(DEFAULT_TIMESTAMP), ofEpochMilli(DEFAULT_TIMESTAMP + 20L))) { + final List> expectedKeys = Arrays.asList( + new Windowed<>(bytesKey("a"), new TimeWindow(DEFAULT_TIMESTAMP, DEFAULT_TIMESTAMP + WINDOW_SIZE)), + new Windowed<>(bytesKey("b"), new TimeWindow(DEFAULT_TIMESTAMP, DEFAULT_TIMESTAMP + WINDOW_SIZE)), + new Windowed<>(bytesKey("c"), new TimeWindow(DEFAULT_TIMESTAMP + 10L, DEFAULT_TIMESTAMP + 10L + WINDOW_SIZE)), + new Windowed<>(bytesKey("d"), new TimeWindow(DEFAULT_TIMESTAMP + 20L, DEFAULT_TIMESTAMP + 20L + WINDOW_SIZE)) + ); + + final List expectedValues = Arrays.asList("a", "b", "c", "d"); + + verifyAllWindowedKeyValues(iterator, expectedKeys, expectedValues); + } + } + + @Test + public void shouldPutFetchRangeFromCacheForNullKeyTo() { + cachingStore.put(bytesKey("a"), bytesValue("a"), DEFAULT_TIMESTAMP); + cachingStore.put(bytesKey("b"), bytesValue("b"), DEFAULT_TIMESTAMP); + cachingStore.put(bytesKey("c"), bytesValue("c"), DEFAULT_TIMESTAMP + 10L); + cachingStore.put(bytesKey("d"), bytesValue("d"), DEFAULT_TIMESTAMP + 20L); + cachingStore.put(bytesKey("e"), bytesValue("e"), DEFAULT_TIMESTAMP + 20L); + + try (final KeyValueIterator, byte[]> iterator = + cachingStore.fetch(bytesKey("b"), null, ofEpochMilli(DEFAULT_TIMESTAMP), ofEpochMilli(DEFAULT_TIMESTAMP + 20L))) { + final List> expectedKeys = Arrays.asList( + new Windowed<>(bytesKey("b"), new TimeWindow(DEFAULT_TIMESTAMP, DEFAULT_TIMESTAMP + WINDOW_SIZE)), + new Windowed<>(bytesKey("c"), new TimeWindow(DEFAULT_TIMESTAMP + 10L, DEFAULT_TIMESTAMP + 10L + WINDOW_SIZE)), + new Windowed<>(bytesKey("d"), new TimeWindow(DEFAULT_TIMESTAMP + 20L, DEFAULT_TIMESTAMP + 20L + WINDOW_SIZE)), + new Windowed<>(bytesKey("e"), new TimeWindow(DEFAULT_TIMESTAMP + 20L, DEFAULT_TIMESTAMP + 20L + WINDOW_SIZE)) + ); + + final List expectedValues = Arrays.asList("b", "c", "d", "e"); + + verifyAllWindowedKeyValues(iterator, expectedKeys, expectedValues); + } + } + + @Test + public void shouldPutFetchRangeFromCacheForNullKeyFromKeyTo() { + cachingStore.put(bytesKey("a"), bytesValue("a"), DEFAULT_TIMESTAMP); + cachingStore.put(bytesKey("b"), bytesValue("b"), DEFAULT_TIMESTAMP); + cachingStore.put(bytesKey("c"), bytesValue("c"), DEFAULT_TIMESTAMP + 10L); + cachingStore.put(bytesKey("d"), bytesValue("d"), DEFAULT_TIMESTAMP + 20L); + cachingStore.put(bytesKey("e"), bytesValue("e"), DEFAULT_TIMESTAMP + 20L); + + try (final KeyValueIterator, byte[]> iterator = + cachingStore.fetch(null, null, 
ofEpochMilli(DEFAULT_TIMESTAMP), ofEpochMilli(DEFAULT_TIMESTAMP + 20L))) { + final List> expectedKeys = Arrays.asList( + new Windowed<>(bytesKey("a"), new TimeWindow(DEFAULT_TIMESTAMP, DEFAULT_TIMESTAMP + WINDOW_SIZE)), + new Windowed<>(bytesKey("b"), new TimeWindow(DEFAULT_TIMESTAMP, DEFAULT_TIMESTAMP + WINDOW_SIZE)), + new Windowed<>(bytesKey("c"), new TimeWindow(DEFAULT_TIMESTAMP + 10L, DEFAULT_TIMESTAMP + 10L + WINDOW_SIZE)), + new Windowed<>(bytesKey("d"), new TimeWindow(DEFAULT_TIMESTAMP + 20L, DEFAULT_TIMESTAMP + 20L + WINDOW_SIZE)), + new Windowed<>(bytesKey("e"), new TimeWindow(DEFAULT_TIMESTAMP + 20L, DEFAULT_TIMESTAMP + 20L + WINDOW_SIZE)) + ); + + final List expectedValues = Arrays.asList("a", "b", "c", "d", "e"); + + verifyAllWindowedKeyValues(iterator, expectedKeys, expectedValues); + } + } + + @Test + public void shouldPutBackwardFetchRangeFromCacheForNullKeyFrom() { + cachingStore.put(bytesKey("a"), bytesValue("a"), DEFAULT_TIMESTAMP); + cachingStore.put(bytesKey("b"), bytesValue("b"), DEFAULT_TIMESTAMP); + cachingStore.put(bytesKey("c"), bytesValue("c"), DEFAULT_TIMESTAMP + 10L); + cachingStore.put(bytesKey("d"), bytesValue("d"), DEFAULT_TIMESTAMP + 20L); + cachingStore.put(bytesKey("e"), bytesValue("e"), DEFAULT_TIMESTAMP + 20L); + + try (final KeyValueIterator, byte[]> iterator = + cachingStore.backwardFetch(null, bytesKey("c"), ofEpochMilli(DEFAULT_TIMESTAMP), ofEpochMilli(DEFAULT_TIMESTAMP + 20L))) { + final List> expectedKeys = Arrays.asList( + new Windowed<>(bytesKey("c"), new TimeWindow(DEFAULT_TIMESTAMP + 10L, DEFAULT_TIMESTAMP + 10L + WINDOW_SIZE)), + new Windowed<>(bytesKey("b"), new TimeWindow(DEFAULT_TIMESTAMP, DEFAULT_TIMESTAMP + WINDOW_SIZE)), + new Windowed<>(bytesKey("a"), new TimeWindow(DEFAULT_TIMESTAMP, DEFAULT_TIMESTAMP + WINDOW_SIZE)) + ); + + final List expectedValues = Arrays.asList("c", "b", "a"); + + verifyAllWindowedKeyValues(iterator, expectedKeys, expectedValues); + } + } + + @Test + public void shouldPutBackwardFetchRangeFromCacheForNullKeyTo() { + cachingStore.put(bytesKey("a"), bytesValue("a"), DEFAULT_TIMESTAMP); + cachingStore.put(bytesKey("b"), bytesValue("b"), DEFAULT_TIMESTAMP); + cachingStore.put(bytesKey("c"), bytesValue("c"), DEFAULT_TIMESTAMP + 10L); + cachingStore.put(bytesKey("d"), bytesValue("d"), DEFAULT_TIMESTAMP + 20L); + cachingStore.put(bytesKey("e"), bytesValue("e"), DEFAULT_TIMESTAMP + 20L); + + try (final KeyValueIterator, byte[]> iterator = + cachingStore.backwardFetch(bytesKey("c"), null, ofEpochMilli(DEFAULT_TIMESTAMP), ofEpochMilli(DEFAULT_TIMESTAMP + 20L))) { + final List> expectedKeys = Arrays.asList( + new Windowed<>(bytesKey("e"), new TimeWindow(DEFAULT_TIMESTAMP + 20L, DEFAULT_TIMESTAMP + 20L + WINDOW_SIZE)), + new Windowed<>(bytesKey("d"), new TimeWindow(DEFAULT_TIMESTAMP + 20L, DEFAULT_TIMESTAMP + 20L + WINDOW_SIZE)), + new Windowed<>(bytesKey("c"), new TimeWindow(DEFAULT_TIMESTAMP + 10L, DEFAULT_TIMESTAMP + 10L + WINDOW_SIZE)) + ); + + final List expectedValues = Arrays.asList("e", "d", "c"); + + verifyAllWindowedKeyValues(iterator, expectedKeys, expectedValues); + } + } + + @Test + public void shouldPutBackwardFetchRangeFromCacheForNullKeyFromKeyTo() { + cachingStore.put(bytesKey("a"), bytesValue("a"), DEFAULT_TIMESTAMP); + cachingStore.put(bytesKey("b"), bytesValue("b"), DEFAULT_TIMESTAMP); + cachingStore.put(bytesKey("c"), bytesValue("c"), DEFAULT_TIMESTAMP + 10L); + cachingStore.put(bytesKey("d"), bytesValue("d"), DEFAULT_TIMESTAMP + 20L); + cachingStore.put(bytesKey("e"), bytesValue("e"), DEFAULT_TIMESTAMP 
+ 20L); + + try (final KeyValueIterator, byte[]> iterator = + cachingStore.backwardFetch(null, null, ofEpochMilli(DEFAULT_TIMESTAMP), ofEpochMilli(DEFAULT_TIMESTAMP + 20L))) { + final List> expectedKeys = Arrays.asList( + new Windowed<>(bytesKey("e"), new TimeWindow(DEFAULT_TIMESTAMP + 20L, DEFAULT_TIMESTAMP + 20L + WINDOW_SIZE)), + new Windowed<>(bytesKey("d"), new TimeWindow(DEFAULT_TIMESTAMP + 20L, DEFAULT_TIMESTAMP + 20L + WINDOW_SIZE)), + new Windowed<>(bytesKey("c"), new TimeWindow(DEFAULT_TIMESTAMP + 10L, DEFAULT_TIMESTAMP + 10L + WINDOW_SIZE)), + new Windowed<>(bytesKey("b"), new TimeWindow(DEFAULT_TIMESTAMP, DEFAULT_TIMESTAMP + WINDOW_SIZE)), + new Windowed<>(bytesKey("a"), new TimeWindow(DEFAULT_TIMESTAMP, DEFAULT_TIMESTAMP + WINDOW_SIZE)) + ); + + final List expectedValues = Arrays.asList("e", "d", "c", "b", "a"); + + verifyAllWindowedKeyValues(iterator, expectedKeys, expectedValues); + } + } + + @Test + public void shouldGetAllFromCache() { + cachingStore.put(bytesKey("a"), bytesValue("a"), DEFAULT_TIMESTAMP); + cachingStore.put(bytesKey("b"), bytesValue("b"), DEFAULT_TIMESTAMP); + cachingStore.put(bytesKey("c"), bytesValue("c"), DEFAULT_TIMESTAMP); + cachingStore.put(bytesKey("d"), bytesValue("d"), DEFAULT_TIMESTAMP); + cachingStore.put(bytesKey("e"), bytesValue("e"), DEFAULT_TIMESTAMP); + cachingStore.put(bytesKey("f"), bytesValue("f"), DEFAULT_TIMESTAMP); + cachingStore.put(bytesKey("g"), bytesValue("g"), DEFAULT_TIMESTAMP); + cachingStore.put(bytesKey("h"), bytesValue("h"), DEFAULT_TIMESTAMP); + + try (final KeyValueIterator, byte[]> iterator = cachingStore.all()) { + final String[] array = {"a", "b", "c", "d", "e", "f", "g", "h"}; + for (final String s : array) { + verifyWindowedKeyValue( + iterator.next(), + new Windowed<>(bytesKey(s), new TimeWindow(DEFAULT_TIMESTAMP, DEFAULT_TIMESTAMP + WINDOW_SIZE)), + s); + } + assertFalse(iterator.hasNext()); + } + } + + @Test + public void shouldGetAllBackwardFromCache() { + cachingStore.put(bytesKey("a"), bytesValue("a"), DEFAULT_TIMESTAMP); + cachingStore.put(bytesKey("b"), bytesValue("b"), DEFAULT_TIMESTAMP); + cachingStore.put(bytesKey("c"), bytesValue("c"), DEFAULT_TIMESTAMP); + cachingStore.put(bytesKey("d"), bytesValue("d"), DEFAULT_TIMESTAMP); + cachingStore.put(bytesKey("e"), bytesValue("e"), DEFAULT_TIMESTAMP); + cachingStore.put(bytesKey("f"), bytesValue("f"), DEFAULT_TIMESTAMP); + cachingStore.put(bytesKey("g"), bytesValue("g"), DEFAULT_TIMESTAMP); + cachingStore.put(bytesKey("h"), bytesValue("h"), DEFAULT_TIMESTAMP); + + try (final KeyValueIterator, byte[]> iterator = cachingStore.backwardAll()) { + final String[] array = {"h", "g", "f", "e", "d", "c", "b", "a"}; + for (final String s : array) { + verifyWindowedKeyValue( + iterator.next(), + new Windowed<>(bytesKey(s), new TimeWindow(DEFAULT_TIMESTAMP, DEFAULT_TIMESTAMP + WINDOW_SIZE)), + s); + } + assertFalse(iterator.hasNext()); + } + } + + @Test + public void shouldFetchAllWithinTimestampRange() { + final String[] array = {"a", "b", "c", "d", "e", "f", "g", "h"}; + for (int i = 0; i < array.length; i++) { + cachingStore.put(bytesKey(array[i]), bytesValue(array[i]), i); + } + + try (final KeyValueIterator, byte[]> iterator = + cachingStore.fetchAll(ofEpochMilli(0), ofEpochMilli(7))) { + for (int i = 0; i < array.length; i++) { + final String str = array[i]; + verifyWindowedKeyValue( + iterator.next(), + new Windowed<>(bytesKey(str), new TimeWindow(i, i + WINDOW_SIZE)), + str); + } + assertFalse(iterator.hasNext()); + } + + try (final KeyValueIterator, byte[]> 
iterator1 = + cachingStore.fetchAll(ofEpochMilli(2), ofEpochMilli(4))) { + for (int i = 2; i <= 4; i++) { + final String str = array[i]; + verifyWindowedKeyValue( + iterator1.next(), + new Windowed<>(bytesKey(str), new TimeWindow(i, i + WINDOW_SIZE)), + str); + } + assertFalse(iterator1.hasNext()); + } + + try (final KeyValueIterator, byte[]> iterator2 = + cachingStore.fetchAll(ofEpochMilli(5), ofEpochMilli(7))) { + for (int i = 5; i <= 7; i++) { + final String str = array[i]; + verifyWindowedKeyValue( + iterator2.next(), + new Windowed<>(bytesKey(str), new TimeWindow(i, i + WINDOW_SIZE)), + str); + } + assertFalse(iterator2.hasNext()); + } + } + + @Test + public void shouldFetchAllBackwardWithinTimestampRange() { + final String[] array = {"a", "b", "c", "d", "e", "f", "g", "h"}; + for (int i = 0; i < array.length; i++) { + cachingStore.put(bytesKey(array[i]), bytesValue(array[i]), i); + } + + try (final KeyValueIterator, byte[]> iterator = + cachingStore.backwardFetchAll(ofEpochMilli(0), ofEpochMilli(7))) { + for (int i = array.length - 1; i >= 0; i--) { + final String str = array[i]; + verifyWindowedKeyValue( + iterator.next(), + new Windowed<>(bytesKey(str), new TimeWindow(i, i + WINDOW_SIZE)), + str); + } + assertFalse(iterator.hasNext()); + } + + try (final KeyValueIterator, byte[]> iterator1 = + cachingStore.backwardFetchAll(ofEpochMilli(2), ofEpochMilli(4))) { + for (int i = 4; i >= 2; i--) { + final String str = array[i]; + verifyWindowedKeyValue( + iterator1.next(), + new Windowed<>(bytesKey(str), new TimeWindow(i, i + WINDOW_SIZE)), + str); + } + assertFalse(iterator1.hasNext()); + } + + try (final KeyValueIterator, byte[]> iterator2 = + cachingStore.backwardFetchAll(ofEpochMilli(5), ofEpochMilli(7))) { + for (int i = 7; i >= 5; i--) { + final String str = array[i]; + verifyWindowedKeyValue( + iterator2.next(), + new Windowed<>(bytesKey(str), new TimeWindow(i, i + WINDOW_SIZE)), + str); + } + assertFalse(iterator2.hasNext()); + } + } + + @Test + public void shouldFlushEvictedItemsIntoUnderlyingStore() { + final int added = addItemsToCache(); + // all dirty entries should have been flushed + try (final KeyValueIterator iter = bytesStore.fetch( + Bytes.wrap("0".getBytes(StandardCharsets.UTF_8)), + DEFAULT_TIMESTAMP, + DEFAULT_TIMESTAMP)) { + final KeyValue next = iter.next(); + assertEquals(DEFAULT_TIMESTAMP, keySchema.segmentTimestamp(next.key)); + assertArrayEquals("0".getBytes(), next.value); + assertFalse(iter.hasNext()); + assertEquals(added - 1, cache.size()); + } + } + + @Test + public void shouldForwardDirtyItemsWhenFlushCalled() { + final Windowed windowedKey = + new Windowed<>("1", new TimeWindow(DEFAULT_TIMESTAMP, DEFAULT_TIMESTAMP + WINDOW_SIZE)); + cachingStore.put(bytesKey("1"), bytesValue("a"), DEFAULT_TIMESTAMP); + cachingStore.flush(); + assertEquals("a", cacheListener.forwarded.get(windowedKey).newValue); + assertNull(cacheListener.forwarded.get(windowedKey).oldValue); + } + + @Test + public void shouldSetFlushListener() { + assertTrue(cachingStore.setFlushListener(null, true)); + assertTrue(cachingStore.setFlushListener(null, false)); + } + + @Test + public void shouldForwardOldValuesWhenEnabled() { + cachingStore.setFlushListener(cacheListener, true); + final Windowed windowedKey = + new Windowed<>("1", new TimeWindow(DEFAULT_TIMESTAMP, DEFAULT_TIMESTAMP + WINDOW_SIZE)); + cachingStore.put(bytesKey("1"), bytesValue("a"), DEFAULT_TIMESTAMP); + cachingStore.put(bytesKey("1"), bytesValue("b"), DEFAULT_TIMESTAMP); + cachingStore.flush(); + assertEquals("b", 
cacheListener.forwarded.get(windowedKey).newValue); + assertNull(cacheListener.forwarded.get(windowedKey).oldValue); + cacheListener.forwarded.clear(); + cachingStore.put(bytesKey("1"), bytesValue("c"), DEFAULT_TIMESTAMP); + cachingStore.flush(); + assertEquals("c", cacheListener.forwarded.get(windowedKey).newValue); + assertEquals("b", cacheListener.forwarded.get(windowedKey).oldValue); + cachingStore.put(bytesKey("1"), null, DEFAULT_TIMESTAMP); + cachingStore.flush(); + assertNull(cacheListener.forwarded.get(windowedKey).newValue); + assertEquals("c", cacheListener.forwarded.get(windowedKey).oldValue); + cacheListener.forwarded.clear(); + cachingStore.put(bytesKey("1"), bytesValue("a"), DEFAULT_TIMESTAMP); + cachingStore.put(bytesKey("1"), bytesValue("b"), DEFAULT_TIMESTAMP); + cachingStore.put(bytesKey("1"), null, DEFAULT_TIMESTAMP); + cachingStore.flush(); + assertNull(cacheListener.forwarded.get(windowedKey)); + cacheListener.forwarded.clear(); + } + + @Test + public void shouldForwardOldValuesWhenDisabled() { + final Windowed windowedKey = + new Windowed<>("1", new TimeWindow(DEFAULT_TIMESTAMP, DEFAULT_TIMESTAMP + WINDOW_SIZE)); + cachingStore.put(bytesKey("1"), bytesValue("a"), DEFAULT_TIMESTAMP); + cachingStore.put(bytesKey("1"), bytesValue("b"), DEFAULT_TIMESTAMP); + cachingStore.flush(); + assertEquals("b", cacheListener.forwarded.get(windowedKey).newValue); + assertNull(cacheListener.forwarded.get(windowedKey).oldValue); + cachingStore.put(bytesKey("1"), bytesValue("c"), DEFAULT_TIMESTAMP); + cachingStore.flush(); + assertEquals("c", cacheListener.forwarded.get(windowedKey).newValue); + assertNull(cacheListener.forwarded.get(windowedKey).oldValue); + cachingStore.put(bytesKey("1"), null, DEFAULT_TIMESTAMP); + cachingStore.flush(); + assertNull(cacheListener.forwarded.get(windowedKey).newValue); + assertNull(cacheListener.forwarded.get(windowedKey).oldValue); + cacheListener.forwarded.clear(); + cachingStore.put(bytesKey("1"), bytesValue("a"), DEFAULT_TIMESTAMP); + cachingStore.put(bytesKey("1"), bytesValue("b"), DEFAULT_TIMESTAMP); + cachingStore.put(bytesKey("1"), null, DEFAULT_TIMESTAMP); + cachingStore.flush(); + assertNull(cacheListener.forwarded.get(windowedKey)); + cacheListener.forwarded.clear(); + } + + @Test + public void shouldForwardDirtyItemToListenerWhenEvicted() { + final int numRecords = addItemsToCache(); + assertEquals(numRecords, cacheListener.forwarded.size()); + } + + @Test + public void shouldTakeValueFromCacheIfSameTimestampFlushedToRocks() { + cachingStore.put(bytesKey("1"), bytesValue("a"), DEFAULT_TIMESTAMP); + cachingStore.flush(); + cachingStore.put(bytesKey("1"), bytesValue("b"), DEFAULT_TIMESTAMP); + + try (final WindowStoreIterator fetch = + cachingStore.fetch(bytesKey("1"), ofEpochMilli(DEFAULT_TIMESTAMP), ofEpochMilli(DEFAULT_TIMESTAMP))) { + verifyKeyValue(fetch.next(), DEFAULT_TIMESTAMP, "b"); + assertFalse(fetch.hasNext()); + } + } + + @Test + public void shouldIterateAcrossWindows() { + cachingStore.put(bytesKey("1"), bytesValue("a"), DEFAULT_TIMESTAMP); + cachingStore.put(bytesKey("1"), bytesValue("b"), DEFAULT_TIMESTAMP + WINDOW_SIZE); + + try (final WindowStoreIterator fetch = + cachingStore.fetch(bytesKey("1"), ofEpochMilli(DEFAULT_TIMESTAMP), ofEpochMilli(DEFAULT_TIMESTAMP + WINDOW_SIZE))) { + verifyKeyValue(fetch.next(), DEFAULT_TIMESTAMP, "a"); + verifyKeyValue(fetch.next(), DEFAULT_TIMESTAMP + WINDOW_SIZE, "b"); + assertFalse(fetch.hasNext()); + } + } + + @Test + public void shouldIterateBackwardAcrossWindows() { + 
cachingStore.put(bytesKey("1"), bytesValue("a"), DEFAULT_TIMESTAMP); + cachingStore.put(bytesKey("1"), bytesValue("b"), DEFAULT_TIMESTAMP + WINDOW_SIZE); + + try (final WindowStoreIterator fetch = + cachingStore.backwardFetch(bytesKey("1"), ofEpochMilli(DEFAULT_TIMESTAMP), ofEpochMilli(DEFAULT_TIMESTAMP + WINDOW_SIZE))) { + verifyKeyValue(fetch.next(), DEFAULT_TIMESTAMP + WINDOW_SIZE, "b"); + verifyKeyValue(fetch.next(), DEFAULT_TIMESTAMP, "a"); + assertFalse(fetch.hasNext()); + } + } + + @Test + public void shouldIterateCacheAndStore() { + final Bytes key = Bytes.wrap("1".getBytes()); + bytesStore.put(WindowKeySchema.toStoreKeyBinary(key, DEFAULT_TIMESTAMP, 0), "a".getBytes()); + cachingStore.put(key, bytesValue("b"), DEFAULT_TIMESTAMP + WINDOW_SIZE); + try (final WindowStoreIterator fetch = + cachingStore.fetch(bytesKey("1"), ofEpochMilli(DEFAULT_TIMESTAMP), ofEpochMilli(DEFAULT_TIMESTAMP + WINDOW_SIZE))) { + verifyKeyValue(fetch.next(), DEFAULT_TIMESTAMP, "a"); + verifyKeyValue(fetch.next(), DEFAULT_TIMESTAMP + WINDOW_SIZE, "b"); + assertFalse(fetch.hasNext()); + } + } + + @Test + public void shouldIterateBackwardCacheAndStore() { + final Bytes key = Bytes.wrap("1".getBytes()); + bytesStore.put(WindowKeySchema.toStoreKeyBinary(key, DEFAULT_TIMESTAMP, 0), "a".getBytes()); + cachingStore.put(key, bytesValue("b"), DEFAULT_TIMESTAMP + WINDOW_SIZE); + try (final WindowStoreIterator fetch = + cachingStore.backwardFetch(bytesKey("1"), ofEpochMilli(DEFAULT_TIMESTAMP), ofEpochMilli(DEFAULT_TIMESTAMP + WINDOW_SIZE))) { + verifyKeyValue(fetch.next(), DEFAULT_TIMESTAMP + WINDOW_SIZE, "b"); + verifyKeyValue(fetch.next(), DEFAULT_TIMESTAMP, "a"); + assertFalse(fetch.hasNext()); + } + } + + @Test + public void shouldIterateCacheAndStoreKeyRange() { + final Bytes key = Bytes.wrap("1".getBytes()); + bytesStore.put(WindowKeySchema.toStoreKeyBinary(key, DEFAULT_TIMESTAMP, 0), "a".getBytes()); + cachingStore.put(key, bytesValue("b"), DEFAULT_TIMESTAMP + WINDOW_SIZE); + + try (final KeyValueIterator, byte[]> fetchRange = + cachingStore.fetch(key, bytesKey("2"), ofEpochMilli(DEFAULT_TIMESTAMP), ofEpochMilli(DEFAULT_TIMESTAMP + WINDOW_SIZE))) { + verifyWindowedKeyValue( + fetchRange.next(), + new Windowed<>(key, new TimeWindow(DEFAULT_TIMESTAMP, DEFAULT_TIMESTAMP + WINDOW_SIZE)), + "a"); + verifyWindowedKeyValue( + fetchRange.next(), + new Windowed<>(key, new TimeWindow(DEFAULT_TIMESTAMP + WINDOW_SIZE, DEFAULT_TIMESTAMP + WINDOW_SIZE + WINDOW_SIZE)), + "b"); + assertFalse(fetchRange.hasNext()); + } + } + + @Test + public void shouldIterateBackwardCacheAndStoreKeyRange() { + final Bytes key = Bytes.wrap("1".getBytes()); + bytesStore.put(WindowKeySchema.toStoreKeyBinary(key, DEFAULT_TIMESTAMP, 0), "a".getBytes()); + cachingStore.put(key, bytesValue("b"), DEFAULT_TIMESTAMP + WINDOW_SIZE); + + try (final KeyValueIterator, byte[]> fetchRange = + cachingStore.backwardFetch(key, bytesKey("2"), ofEpochMilli(DEFAULT_TIMESTAMP), ofEpochMilli(DEFAULT_TIMESTAMP + WINDOW_SIZE))) { + verifyWindowedKeyValue( + fetchRange.next(), + new Windowed<>(key, new TimeWindow(DEFAULT_TIMESTAMP + WINDOW_SIZE, DEFAULT_TIMESTAMP + WINDOW_SIZE + WINDOW_SIZE)), + "b"); + verifyWindowedKeyValue( + fetchRange.next(), + new Windowed<>(key, new TimeWindow(DEFAULT_TIMESTAMP, DEFAULT_TIMESTAMP + WINDOW_SIZE)), + "a"); + assertFalse(fetchRange.hasNext()); + } + } + + @Test + public void shouldClearNamespaceCacheOnClose() { + cachingStore.put(bytesKey("a"), bytesValue("a"), 0L); + assertEquals(1, cache.size()); + cachingStore.close(); + 
assertEquals(0, cache.size()); + } + + @Test + public void shouldThrowIfTryingToFetchFromClosedCachingStore() { + cachingStore.close(); + assertThrows(InvalidStateStoreException.class, () -> cachingStore.fetch(bytesKey("a"), ofEpochMilli(0), ofEpochMilli(10))); + } + + @Test + public void shouldThrowIfTryingToFetchRangeFromClosedCachingStore() { + cachingStore.close(); + assertThrows(InvalidStateStoreException.class, () -> cachingStore.fetch(bytesKey("a"), bytesKey("b"), ofEpochMilli(0), ofEpochMilli(10))); + } + + @Test + public void shouldThrowIfTryingToWriteToClosedCachingStore() { + cachingStore.close(); + assertThrows(InvalidStateStoreException.class, () -> cachingStore.put(bytesKey("a"), bytesValue("a"), 0L)); + } + + @Test + public void shouldFetchAndIterateOverExactKeys() { + cachingStore.put(bytesKey("a"), bytesValue("0001"), 0); + cachingStore.put(bytesKey("aa"), bytesValue("0002"), 0); + cachingStore.put(bytesKey("a"), bytesValue("0003"), 1); + cachingStore.put(bytesKey("aa"), bytesValue("0004"), 1); + cachingStore.put(bytesKey("a"), bytesValue("0005"), SEGMENT_INTERVAL); + + final List> expected = asList( + KeyValue.pair(0L, bytesValue("0001")), + KeyValue.pair(1L, bytesValue("0003")), + KeyValue.pair(SEGMENT_INTERVAL, bytesValue("0005")) + ); + final List> actual = + toList(cachingStore.fetch(bytesKey("a"), ofEpochMilli(0), ofEpochMilli(Long.MAX_VALUE))); + verifyKeyValueList(expected, actual); + } + + @Test + public void shouldBackwardFetchAndIterateOverExactKeys() { + cachingStore.put(bytesKey("a"), bytesValue("0001"), 0); + cachingStore.put(bytesKey("aa"), bytesValue("0002"), 0); + cachingStore.put(bytesKey("a"), bytesValue("0003"), 1); + cachingStore.put(bytesKey("aa"), bytesValue("0004"), 1); + cachingStore.put(bytesKey("a"), bytesValue("0005"), SEGMENT_INTERVAL); + + final List> expected = asList( + KeyValue.pair(SEGMENT_INTERVAL, bytesValue("0005")), + KeyValue.pair(1L, bytesValue("0003")), + KeyValue.pair(0L, bytesValue("0001")) + ); + final List> actual = + toList(cachingStore.backwardFetch(bytesKey("a"), ofEpochMilli(0), ofEpochMilli(Long.MAX_VALUE))); + verifyKeyValueList(expected, actual); + } + + @Test + public void shouldFetchAndIterateOverKeyRange() { + cachingStore.put(bytesKey("a"), bytesValue("0001"), 0); + cachingStore.put(bytesKey("aa"), bytesValue("0002"), 0); + cachingStore.put(bytesKey("a"), bytesValue("0003"), 1); + cachingStore.put(bytesKey("aa"), bytesValue("0004"), 1); + cachingStore.put(bytesKey("a"), bytesValue("0005"), SEGMENT_INTERVAL); + + verifyKeyValueList( + asList( + windowedPair("a", "0001", 0), + windowedPair("a", "0003", 1), + windowedPair("a", "0005", SEGMENT_INTERVAL) + ), + toList(cachingStore.fetch(bytesKey("a"), bytesKey("a"), ofEpochMilli(0), ofEpochMilli(Long.MAX_VALUE))) + ); + + verifyKeyValueList( + asList( + windowedPair("aa", "0002", 0), + windowedPair("aa", "0004", 1)), + toList(cachingStore.fetch(bytesKey("aa"), bytesKey("aa"), ofEpochMilli(0), ofEpochMilli(Long.MAX_VALUE))) + ); + + verifyKeyValueList( + asList( + windowedPair("a", "0001", 0), + windowedPair("a", "0003", 1), + windowedPair("aa", "0002", 0), + windowedPair("aa", "0004", 1), + windowedPair("a", "0005", SEGMENT_INTERVAL) + ), + toList(cachingStore.fetch(bytesKey("a"), bytesKey("aa"), ofEpochMilli(0), ofEpochMilli(Long.MAX_VALUE))) + ); + } + + @Test + public void shouldFetchAndIterateOverKeyBackwardRange() { + cachingStore.put(bytesKey("a"), bytesValue("0001"), 0); + cachingStore.put(bytesKey("aa"), bytesValue("0002"), 0); + cachingStore.put(bytesKey("a"), 
bytesValue("0003"), 1); + cachingStore.put(bytesKey("aa"), bytesValue("0004"), 1); + cachingStore.put(bytesKey("a"), bytesValue("0005"), SEGMENT_INTERVAL); + + verifyKeyValueList( + asList( + windowedPair("a", "0005", SEGMENT_INTERVAL), + windowedPair("a", "0003", 1), + windowedPair("a", "0001", 0) + ), + toList(cachingStore.backwardFetch(bytesKey("a"), bytesKey("a"), ofEpochMilli(0), ofEpochMilli(Long.MAX_VALUE))) + ); + + verifyKeyValueList( + asList( + windowedPair("aa", "0004", 1), + windowedPair("aa", "0002", 0)), + toList(cachingStore.backwardFetch(bytesKey("aa"), bytesKey("aa"), ofEpochMilli(0), ofEpochMilli(Long.MAX_VALUE))) + ); + + verifyKeyValueList( + asList( + windowedPair("a", "0005", SEGMENT_INTERVAL), + windowedPair("aa", "0004", 1), + windowedPair("aa", "0002", 0), + windowedPair("a", "0003", 1), + windowedPair("a", "0001", 0) + ), + toList(cachingStore.backwardFetch(bytesKey("a"), bytesKey("aa"), ofEpochMilli(0), ofEpochMilli(Long.MAX_VALUE))) + ); + } + + @Test + public void shouldReturnSameResultsForSingleKeyFetchAndEqualKeyRangeFetch() { + cachingStore.put(bytesKey("a"), bytesValue("0001"), 0); + cachingStore.put(bytesKey("aa"), bytesValue("0002"), 1); + cachingStore.put(bytesKey("aa"), bytesValue("0003"), 2); + cachingStore.put(bytesKey("aaa"), bytesValue("0004"), 3); + + try (final WindowStoreIterator singleKeyIterator = cachingStore.fetch(bytesKey("aa"), 0L, 5L); + final KeyValueIterator, byte[]> keyRangeIterator = cachingStore.fetch(bytesKey("aa"), bytesKey("aa"), 0L, 5L)) { + + assertEquals(stringFrom(singleKeyIterator.next().value), stringFrom(keyRangeIterator.next().value)); + assertEquals(stringFrom(singleKeyIterator.next().value), stringFrom(keyRangeIterator.next().value)); + assertFalse(singleKeyIterator.hasNext()); + assertFalse(keyRangeIterator.hasNext()); + } + } + + @Test + public void shouldReturnSameResultsForSingleKeyFetchAndEqualKeyRangeBackwardFetch() { + cachingStore.put(bytesKey("a"), bytesValue("0001"), 0); + cachingStore.put(bytesKey("aa"), bytesValue("0002"), 1); + cachingStore.put(bytesKey("aa"), bytesValue("0003"), 2); + cachingStore.put(bytesKey("aaa"), bytesValue("0004"), 3); + + try (final WindowStoreIterator singleKeyIterator = + cachingStore.backwardFetch(bytesKey("aa"), Instant.ofEpochMilli(0L), Instant.ofEpochMilli(5L)); + final KeyValueIterator, byte[]> keyRangeIterator = + cachingStore.backwardFetch(bytesKey("aa"), bytesKey("aa"), Instant.ofEpochMilli(0L), Instant.ofEpochMilli(5L))) { + + assertEquals(stringFrom(singleKeyIterator.next().value), stringFrom(keyRangeIterator.next().value)); + assertEquals(stringFrom(singleKeyIterator.next().value), stringFrom(keyRangeIterator.next().value)); + assertFalse(singleKeyIterator.hasNext()); + assertFalse(keyRangeIterator.hasNext()); + } + } + + @Test + public void shouldThrowNullPointerExceptionOnPutNullKey() { + assertThrows(NullPointerException.class, () -> cachingStore.put(null, bytesValue("anyValue"), 0L)); + } + + @Test + public void shouldNotThrowNullPointerExceptionOnPutNullValue() { + cachingStore.put(bytesKey("a"), null, 0L); + } + + @Test + public void shouldThrowNullPointerExceptionOnFetchNullKey() { + assertThrows(NullPointerException.class, () -> cachingStore.fetch(null, ofEpochMilli(1L), ofEpochMilli(2L))); + } + + @Test + public void shouldNotThrowInvalidRangeExceptionWithNegativeFromKey() { + final Bytes keyFrom = Bytes.wrap(Serdes.Integer().serializer().serialize("", -1)); + final Bytes keyTo = Bytes.wrap(Serdes.Integer().serializer().serialize("", 1)); + + try (final 
LogCaptureAppender appender = LogCaptureAppender.createAndRegister(CachingWindowStore.class); + final KeyValueIterator, byte[]> iterator = cachingStore.fetch(keyFrom, keyTo, 0L, 10L)) { + assertFalse(iterator.hasNext()); + + final List messages = appender.getMessages(); + assertThat( + messages, + hasItem("Returning empty iterator for fetch with invalid key range: from > to." + + " This may be due to range arguments set in the wrong order, " + + "or serdes that don't preserve ordering when lexicographically comparing the serialized bytes." + + " Note that the built-in numerical serdes do not follow this for negative numbers") + ); + } + } + + @Test + public void shouldNotThrowInvalidBackwardRangeExceptionWithNegativeFromKey() { + final Bytes keyFrom = Bytes.wrap(Serdes.Integer().serializer().serialize("", -1)); + final Bytes keyTo = Bytes.wrap(Serdes.Integer().serializer().serialize("", 1)); + + try (final LogCaptureAppender appender = LogCaptureAppender.createAndRegister(CachingWindowStore.class); + final KeyValueIterator, byte[]> iterator = + cachingStore.backwardFetch(keyFrom, keyTo, Instant.ofEpochMilli(0L), Instant.ofEpochMilli(10L))) { + assertFalse(iterator.hasNext()); + + final List messages = appender.getMessages(); + assertThat( + messages, + hasItem("Returning empty iterator for fetch with invalid key range: from > to." + + " This may be due to serdes that don't preserve ordering when lexicographically comparing the serialized bytes." + + " Note that the built-in numerical serdes do not follow this for negative numbers") + ); + } + } + + @Test + public void shouldCloseCacheAndWrappedStoreAfterErrorDuringCacheFlush() { + setUpCloseTests(); + EasyMock.reset(cache); + cache.flush(CACHE_NAMESPACE); + EasyMock.expectLastCall().andThrow(new RuntimeException("Simulating an error on flush")); + cache.close(CACHE_NAMESPACE); + EasyMock.replay(cache); + EasyMock.reset(underlyingStore); + underlyingStore.close(); + EasyMock.replay(underlyingStore); + + assertThrows(RuntimeException.class, cachingStore::close); + EasyMock.verify(cache, underlyingStore); + } + + @Test + public void shouldCloseWrappedStoreAfterErrorDuringCacheClose() { + setUpCloseTests(); + EasyMock.reset(cache); + cache.flush(CACHE_NAMESPACE); + cache.close(CACHE_NAMESPACE); + EasyMock.expectLastCall().andThrow(new RuntimeException("Simulating an error on close")); + EasyMock.replay(cache); + EasyMock.reset(underlyingStore); + underlyingStore.close(); + EasyMock.replay(underlyingStore); + + assertThrows(RuntimeException.class, cachingStore::close); + EasyMock.verify(cache, underlyingStore); + } + + @Test + public void shouldCloseCacheAfterErrorDuringStateStoreClose() { + setUpCloseTests(); + EasyMock.reset(cache); + cache.flush(CACHE_NAMESPACE); + cache.close(CACHE_NAMESPACE); + EasyMock.replay(cache); + EasyMock.reset(underlyingStore); + underlyingStore.close(); + EasyMock.expectLastCall().andThrow(new RuntimeException("Simulating an error on close")); + EasyMock.replay(underlyingStore); + + assertThrows(RuntimeException.class, cachingStore::close); + EasyMock.verify(cache, underlyingStore); + } + + private void setUpCloseTests() { + underlyingStore = EasyMock.createNiceMock(WindowStore.class); + EasyMock.expect(underlyingStore.name()).andStubReturn("store-name"); + EasyMock.expect(underlyingStore.isOpen()).andStubReturn(true); + EasyMock.replay(underlyingStore); + cachingStore = new CachingWindowStore(underlyingStore, WINDOW_SIZE, SEGMENT_INTERVAL); + cache = EasyMock.createNiceMock(ThreadCache.class); + context = new 
InternalMockProcessorContext<>(TestUtils.tempDirectory(), null, null, null, cache); + context.setRecordContext(new ProcessorRecordContext(10, 0, 0, TOPIC, new RecordHeaders())); + cachingStore.init((StateStoreContext) context, cachingStore); + } + + private static KeyValue, byte[]> windowedPair(final String key, final String value, final long timestamp) { + return KeyValue.pair( + new Windowed<>(bytesKey(key), new TimeWindow(timestamp, timestamp + WINDOW_SIZE)), + bytesValue(value)); + } + + private int addItemsToCache() { + int cachedSize = 0; + int i = 0; + while (cachedSize < MAX_CACHE_SIZE_BYTES) { + final String kv = String.valueOf(i++); + cachingStore.put(bytesKey(kv), bytesValue(kv), DEFAULT_TIMESTAMP); + cachedSize += memoryCacheEntrySize(kv.getBytes(), kv.getBytes(), TOPIC) + + 8 + // timestamp + 4; // sequenceNumber + } + return i; + } + +} diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/ChangeLoggingKeyValueBytesStoreTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/ChangeLoggingKeyValueBytesStoreTest.java new file mode 100644 index 0000000..6e8e979 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/ChangeLoggingKeyValueBytesStoreTest.java @@ -0,0 +1,232 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.common.metrics.Metrics; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.serialization.StringSerializer; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.processor.ProcessorContext; +import org.apache.kafka.streams.processor.StateStore; +import org.apache.kafka.streams.processor.StateStoreContext; +import org.apache.kafka.streams.processor.internals.MockStreamsMetrics; +import org.apache.kafka.streams.state.KeyValueIterator; +import org.apache.kafka.streams.state.KeyValueStore; +import org.apache.kafka.test.InternalMockProcessorContext; +import org.apache.kafka.test.MockRecordCollector; +import org.apache.kafka.test.TestUtils; +import org.easymock.EasyMock; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.CoreMatchers.nullValue; +import static org.hamcrest.MatcherAssert.assertThat; + +@SuppressWarnings("rawtypes") +public class ChangeLoggingKeyValueBytesStoreTest { + + private final MockRecordCollector collector = new MockRecordCollector(); + private final InMemoryKeyValueStore inner = new InMemoryKeyValueStore("kv"); + private final ChangeLoggingKeyValueBytesStore store = new ChangeLoggingKeyValueBytesStore(inner); + private final Bytes hi = Bytes.wrap("hi".getBytes()); + private final Bytes hello = Bytes.wrap("hello".getBytes()); + private final byte[] there = "there".getBytes(); + private final byte[] world = "world".getBytes(); + + @Before + public void before() { + final InternalMockProcessorContext context = mockContext(); + context.setTime(0); + store.init((StateStoreContext) context, store); + } + + private InternalMockProcessorContext mockContext() { + return new InternalMockProcessorContext<>( + TestUtils.tempDirectory(), + Serdes.String(), + Serdes.Long(), + collector, + new ThreadCache(new LogContext("testCache "), 0, new MockStreamsMetrics(new Metrics())) + ); + } + + @After + public void after() { + store.close(); + } + + @SuppressWarnings("deprecation") + @Test + public void shouldDelegateDeprecatedInit() { + final InternalMockProcessorContext context = mockContext(); + final KeyValueStore innerMock = EasyMock.mock(InMemoryKeyValueStore.class); + final StateStore outer = new ChangeLoggingKeyValueBytesStore(innerMock); + innerMock.init((ProcessorContext) context, outer); + EasyMock.expectLastCall(); + EasyMock.replay(innerMock); + outer.init((ProcessorContext) context, outer); + EasyMock.verify(innerMock); + } + + @Test + public void shouldDelegateInit() { + final InternalMockProcessorContext context = mockContext(); + final KeyValueStore innerMock = EasyMock.mock(InMemoryKeyValueStore.class); + final StateStore outer = new ChangeLoggingKeyValueBytesStore(innerMock); + innerMock.init((StateStoreContext) context, outer); + EasyMock.expectLastCall(); + EasyMock.replay(innerMock); + outer.init((StateStoreContext) context, outer); + EasyMock.verify(innerMock); + } + + @Test + public void shouldWriteKeyValueBytesToInnerStoreOnPut() { + store.put(hi, there); + assertThat(inner.get(hi), equalTo(there)); + assertThat(collector.collected().size(), equalTo(1)); + 
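// The change-logging store writes through to the wrapped inner store and also sends the change to the record collector, so a single put should produce exactly one collected changelog record. +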
assertThat(collector.collected().get(0).key(), equalTo(hi)); + assertThat(collector.collected().get(0).value(), equalTo(there)); + } + + @Test + public void shouldWriteAllKeyValueToInnerStoreOnPutAll() { + store.putAll(Arrays.asList(KeyValue.pair(hi, there), + KeyValue.pair(hello, world))); + assertThat(inner.get(hi), equalTo(there)); + assertThat(inner.get(hello), equalTo(world)); + + assertThat(collector.collected().size(), equalTo(2)); + assertThat(collector.collected().get(0).key(), equalTo(hi)); + assertThat(collector.collected().get(0).value(), equalTo(there)); + assertThat(collector.collected().get(1).key(), equalTo(hello)); + assertThat(collector.collected().get(1).value(), equalTo(world)); + } + + @Test + public void shouldPropagateDelete() { + store.put(hi, there); + store.delete(hi); + assertThat(inner.approximateNumEntries(), equalTo(0L)); + assertThat(inner.get(hi), nullValue()); + } + + @Test + public void shouldReturnOldValueOnDelete() { + store.put(hi, there); + assertThat(store.delete(hi), equalTo(there)); + } + + @Test + public void shouldLogKeyNullOnDelete() { + store.put(hi, there); + assertThat(store.delete(hi), equalTo(there)); + + assertThat(collector.collected().size(), equalTo(2)); + assertThat(collector.collected().get(0).key(), equalTo(hi)); + assertThat(collector.collected().get(0).value(), equalTo(there)); + assertThat(collector.collected().get(1).key(), equalTo(hi)); + assertThat(collector.collected().get(1).value(), nullValue()); + } + + @Test + public void shouldWriteToInnerOnPutIfAbsentNoPreviousValue() { + store.putIfAbsent(hi, there); + assertThat(inner.get(hi), equalTo(there)); + } + + @Test + public void shouldNotWriteToInnerOnPutIfAbsentWhenValueForKeyExists() { + store.put(hi, there); + store.putIfAbsent(hi, world); + assertThat(inner.get(hi), equalTo(there)); + } + + @Test + public void shouldWriteToChangelogOnPutIfAbsentWhenNoPreviousValue() { + store.putIfAbsent(hi, there); + + assertThat(collector.collected().size(), equalTo(1)); + assertThat(collector.collected().get(0).key(), equalTo(hi)); + assertThat(collector.collected().get(0).value(), equalTo(there)); + } + + @Test + public void shouldNotWriteToChangeLogOnPutIfAbsentWhenValueForKeyExists() { + store.put(hi, there); + store.putIfAbsent(hi, world); + + assertThat(collector.collected().size(), equalTo(1)); + assertThat(collector.collected().get(0).key(), equalTo(hi)); + assertThat(collector.collected().get(0).value(), equalTo(there)); + } + + @Test + public void shouldReturnCurrentValueOnPutIfAbsent() { + store.put(hi, there); + assertThat(store.putIfAbsent(hi, world), equalTo(there)); + } + + @Test + public void shouldReturnNullOnPutIfAbsentWhenNoPreviousValue() { + assertThat(store.putIfAbsent(hi, there), is(nullValue())); + } + + @Test + public void shouldReturnValueOnGetWhenExists() { + store.put(hello, world); + assertThat(store.get(hello), equalTo(world)); + } + + @Test + public void shouldGetRecordsWithPrefixKey() { + store.put(hi, there); + store.put(Bytes.increment(hi), world); + + final List keys = new ArrayList<>(); + final List values = new ArrayList<>(); + int numberOfKeysReturned = 0; + + try (final KeyValueIterator keysWithPrefix = store.prefixScan(hi.toString(), new StringSerializer())) { + while (keysWithPrefix.hasNext()) { + final KeyValue next = keysWithPrefix.next(); + keys.add(next.key); + values.add(Bytes.wrap(next.value)); + numberOfKeysReturned++; + } + } + + assertThat(numberOfKeysReturned, is(1)); + assertThat(keys, is(Collections.singletonList(hi))); + 
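// Only "hi" carries the scanned prefix; Bytes.increment(hi) falls just outside that prefix, so the scan should return the single matching key. +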
assertThat(values, is(Collections.singletonList(Bytes.wrap(there)))); + } + + @Test + public void shouldReturnNullOnGetWhenDoesntExist() { + assertThat(store.get(hello), is(nullValue())); + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/ChangeLoggingSessionBytesStoreTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/ChangeLoggingSessionBytesStoreTest.java new file mode 100644 index 0000000..8fdbd33 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/ChangeLoggingSessionBytesStoreTest.java @@ -0,0 +1,226 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.streams.kstream.Windowed; +import org.apache.kafka.streams.kstream.internals.SessionWindow; +import org.apache.kafka.streams.processor.ProcessorContext; +import org.apache.kafka.streams.processor.StateStoreContext; +import org.apache.kafka.streams.processor.TaskId; +import org.apache.kafka.streams.processor.internals.ProcessorContextImpl; +import org.apache.kafka.streams.state.SessionStore; +import org.apache.kafka.test.MockRecordCollector; +import org.easymock.EasyMock; +import org.easymock.EasyMockRunner; +import org.easymock.Mock; +import org.easymock.MockType; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; + +@RunWith(EasyMockRunner.class) +public class ChangeLoggingSessionBytesStoreTest { + + private final TaskId taskId = new TaskId(0, 0); + private final MockRecordCollector collector = new MockRecordCollector(); + + @Mock(type = MockType.NICE) + private SessionStore inner; + @Mock(type = MockType.NICE) + private ProcessorContextImpl context; + + private ChangeLoggingSessionBytesStore store; + private final byte[] value1 = {0}; + private final Bytes bytesKey = Bytes.wrap(value1); + private final Windowed key1 = new Windowed<>(bytesKey, new SessionWindow(0, 0)); + + @Before + public void setUp() { + store = new ChangeLoggingSessionBytesStore(inner); + } + + private void init() { + EasyMock.expect(context.taskId()).andReturn(taskId); + EasyMock.expect(context.recordCollector()).andReturn(collector); + inner.init((StateStoreContext) context, store); + EasyMock.expectLastCall(); + EasyMock.replay(inner, context); + + store.init((StateStoreContext) context, store); + } + + @SuppressWarnings("deprecation") + @Test + public void shouldDelegateDeprecatedInit() { + inner.init((ProcessorContext) context, store); + EasyMock.expectLastCall(); + EasyMock.replay(inner); + store.init((ProcessorContext) context, store); + EasyMock.verify(inner); + } + + @Test + public void shouldDelegateInit() { + inner.init((StateStoreContext) context, store); + 
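// Recorded in EasyMock's record phase; replay() switches the mock over, and verify() then checks that init() was delegated to the inner store. +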
EasyMock.expectLastCall(); + EasyMock.replay(inner); + store.init((StateStoreContext) context, store); + EasyMock.verify(inner); + } + + @Test + public void shouldLogPuts() { + inner.put(key1, value1); + EasyMock.expectLastCall(); + + init(); + + final Bytes binaryKey = SessionKeySchema.toBinary(key1); + + EasyMock.reset(context); + context.logChange(store.name(), binaryKey, value1, 0L); + + EasyMock.replay(context); + store.put(key1, value1); + + EasyMock.verify(inner, context); + } + + @Test + public void shouldLogRemoves() { + inner.remove(key1); + EasyMock.expectLastCall(); + + init(); + store.remove(key1); + + final Bytes binaryKey = SessionKeySchema.toBinary(key1); + + EasyMock.reset(context); + context.logChange(store.name(), binaryKey, null, 0L); + + EasyMock.replay(context); + store.remove(key1); + + EasyMock.verify(inner, context); + } + + @Test + public void shouldDelegateToUnderlyingStoreWhenFetching() { + EasyMock.expect(inner.fetch(bytesKey)).andReturn(KeyValueIterators.emptyIterator()); + + init(); + + store.fetch(bytesKey); + EasyMock.verify(inner); + } + + @Test + public void shouldDelegateToUnderlyingStoreWhenBackwardFetching() { + EasyMock.expect(inner.backwardFetch(bytesKey)).andReturn(KeyValueIterators.emptyIterator()); + + init(); + + store.backwardFetch(bytesKey); + EasyMock.verify(inner); + } + + @Test + public void shouldDelegateToUnderlyingStoreWhenFetchingRange() { + EasyMock.expect(inner.fetch(bytesKey, bytesKey)).andReturn(KeyValueIterators.emptyIterator()); + + init(); + + store.fetch(bytesKey, bytesKey); + EasyMock.verify(inner); + } + + @Test + public void shouldDelegateToUnderlyingStoreWhenBackwardFetchingRange() { + EasyMock.expect(inner.backwardFetch(bytesKey, bytesKey)).andReturn(KeyValueIterators.emptyIterator()); + + init(); + + store.backwardFetch(bytesKey, bytesKey); + EasyMock.verify(inner); + } + + @Test + public void shouldDelegateToUnderlyingStoreWhenFindingSessions() { + EasyMock.expect(inner.findSessions(bytesKey, 0, 1)).andReturn(KeyValueIterators.emptyIterator()); + + init(); + + store.findSessions(bytesKey, 0, 1); + EasyMock.verify(inner); + } + + @Test + public void shouldDelegateToUnderlyingStoreWhenBackwardFindingSessions() { + EasyMock.expect(inner.backwardFindSessions(bytesKey, 0, 1)).andReturn(KeyValueIterators.emptyIterator()); + + init(); + + store.backwardFindSessions(bytesKey, 0, 1); + EasyMock.verify(inner); + } + + @Test + public void shouldDelegateToUnderlyingStoreWhenFindingSessionRange() { + EasyMock.expect(inner.findSessions(bytesKey, bytesKey, 0, 1)).andReturn(KeyValueIterators.emptyIterator()); + + init(); + + store.findSessions(bytesKey, bytesKey, 0, 1); + EasyMock.verify(inner); + } + + @Test + public void shouldDelegateToUnderlyingStoreWhenBackwardFindingSessionRange() { + EasyMock.expect(inner.backwardFindSessions(bytesKey, bytesKey, 0, 1)).andReturn(KeyValueIterators.emptyIterator()); + + init(); + + store.backwardFindSessions(bytesKey, bytesKey, 0, 1); + EasyMock.verify(inner); + } + + @Test + public void shouldFlushUnderlyingStore() { + inner.flush(); + EasyMock.expectLastCall(); + + init(); + + store.flush(); + EasyMock.verify(inner); + } + + @Test + public void shouldCloseUnderlyingStore() { + inner.close(); + EasyMock.expectLastCall(); + + init(); + + store.close(); + EasyMock.verify(inner); + } + + +} \ No newline at end of file diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/ChangeLoggingTimestampedKeyValueBytesStoreTest.java 
b/streams/src/test/java/org/apache/kafka/streams/state/internals/ChangeLoggingTimestampedKeyValueBytesStoreTest.java new file mode 100644 index 0000000..d65d948 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/ChangeLoggingTimestampedKeyValueBytesStoreTest.java @@ -0,0 +1,224 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.common.metrics.Metrics; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.processor.ProcessorContext; +import org.apache.kafka.streams.processor.StateStore; +import org.apache.kafka.streams.processor.StateStoreContext; +import org.apache.kafka.streams.processor.internals.MockStreamsMetrics; +import org.apache.kafka.streams.state.KeyValueStore; +import org.apache.kafka.streams.state.ValueAndTimestamp; +import org.apache.kafka.test.InternalMockProcessorContext; +import org.apache.kafka.test.MockRecordCollector; +import org.apache.kafka.test.TestUtils; +import org.easymock.EasyMock; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import java.util.Arrays; + +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.CoreMatchers.nullValue; +import static org.hamcrest.MatcherAssert.assertThat; + +@SuppressWarnings("rawtypes") +public class ChangeLoggingTimestampedKeyValueBytesStoreTest { + + private final MockRecordCollector collector = new MockRecordCollector(); + private final InMemoryKeyValueStore root = new InMemoryKeyValueStore("kv"); + private final ChangeLoggingTimestampedKeyValueBytesStore store = new ChangeLoggingTimestampedKeyValueBytesStore(root); + private final Bytes hi = Bytes.wrap("hi".getBytes()); + private final Bytes hello = Bytes.wrap("hello".getBytes()); + private final ValueAndTimestamp there = ValueAndTimestamp.make("there".getBytes(), 97L); + // timestamp is 97 what is ASCII of 'a' + private final byte[] rawThere = "\0\0\0\0\0\0\0athere".getBytes(); + private final ValueAndTimestamp world = ValueAndTimestamp.make("world".getBytes(), 98L); + // timestamp is 98 what is ASCII of 'b' + private final byte[] rawWorld = "\0\0\0\0\0\0\0bworld".getBytes(); + + @Before + public void before() { + final InternalMockProcessorContext context = mockContext(); + context.setTime(0); + store.init((StateStoreContext) context, store); + } + + private InternalMockProcessorContext mockContext() { + return new InternalMockProcessorContext<>( + TestUtils.tempDirectory(), + Serdes.String(), + Serdes.Long(), + collector, + new ThreadCache(new 
LogContext("testCache "), 0, new MockStreamsMetrics(new Metrics())) + ); + } + + @After + public void after() { + store.close(); + } + + @SuppressWarnings("deprecation") + @Test + public void shouldDelegateDeprecatedInit() { + final InternalMockProcessorContext context = mockContext(); + final KeyValueStore inner = EasyMock.mock(InMemoryKeyValueStore.class); + final StateStore outer = new ChangeLoggingTimestampedKeyValueBytesStore(inner); + inner.init((ProcessorContext) context, outer); + EasyMock.expectLastCall(); + EasyMock.replay(inner); + outer.init((ProcessorContext) context, outer); + EasyMock.verify(inner); + } + + @Test + public void shouldDelegateInit() { + final InternalMockProcessorContext context = mockContext(); + final KeyValueStore inner = EasyMock.mock(InMemoryKeyValueStore.class); + final StateStore outer = new ChangeLoggingTimestampedKeyValueBytesStore(inner); + inner.init((StateStoreContext) context, outer); + EasyMock.expectLastCall(); + EasyMock.replay(inner); + outer.init((StateStoreContext) context, outer); + EasyMock.verify(inner); + } + + @Test + public void shouldWriteKeyValueBytesToInnerStoreOnPut() { + store.put(hi, rawThere); + + assertThat(root.get(hi), equalTo(rawThere)); + assertThat(collector.collected().size(), equalTo(1)); + assertThat(collector.collected().get(0).key(), equalTo(hi)); + assertThat(collector.collected().get(0).value(), equalTo(there.value())); + assertThat(collector.collected().get(0).timestamp(), equalTo(there.timestamp())); + } + + @Test + public void shouldWriteAllKeyValueToInnerStoreOnPutAll() { + store.putAll(Arrays.asList(KeyValue.pair(hi, rawThere), + KeyValue.pair(hello, rawWorld))); + assertThat(root.get(hi), equalTo(rawThere)); + assertThat(root.get(hello), equalTo(rawWorld)); + } + + @Test + public void shouldLogChangesOnPutAll() { + store.putAll(Arrays.asList(KeyValue.pair(hi, rawThere), + KeyValue.pair(hello, rawWorld))); + + assertThat(collector.collected().size(), equalTo(2)); + assertThat(collector.collected().get(0).key(), equalTo(hi)); + assertThat(collector.collected().get(0).value(), equalTo(there.value())); + assertThat(collector.collected().get(0).timestamp(), equalTo(there.timestamp())); + assertThat(collector.collected().get(1).key(), equalTo(hello)); + assertThat(collector.collected().get(1).value(), equalTo(world.value())); + assertThat(collector.collected().get(1).timestamp(), equalTo(world.timestamp())); + } + + @Test + public void shouldPropagateDelete() { + store.put(hi, rawThere); + store.delete(hi); + assertThat(root.approximateNumEntries(), equalTo(0L)); + assertThat(root.get(hi), nullValue()); + } + + @Test + public void shouldReturnOldValueOnDelete() { + store.put(hi, rawThere); + assertThat(store.delete(hi), equalTo(rawThere)); + } + + @Test + public void shouldLogKeyNullOnDelete() { + store.put(hi, rawThere); + store.delete(hi); + + assertThat(collector.collected().size(), equalTo(2)); + assertThat(collector.collected().get(0).key(), equalTo(hi)); + assertThat(collector.collected().get(0).value(), equalTo(there.value())); + assertThat(collector.collected().get(0).timestamp(), equalTo(there.timestamp())); + assertThat(collector.collected().get(1).key(), equalTo(hi)); + assertThat(collector.collected().get(1).value(), nullValue()); + assertThat(collector.collected().get(1).timestamp(), equalTo(0L)); + + } + + @Test + public void shouldWriteToInnerOnPutIfAbsentNoPreviousValue() { + store.putIfAbsent(hi, rawThere); + assertThat(root.get(hi), equalTo(rawThere)); + } + + @Test + public void 
shouldNotWriteToInnerOnPutIfAbsentWhenValueForKeyExists() { + store.put(hi, rawThere); + store.putIfAbsent(hi, rawWorld); + assertThat(root.get(hi), equalTo(rawThere)); + } + + @Test + public void shouldWriteToChangelogOnPutIfAbsentWhenNoPreviousValue() { + store.putIfAbsent(hi, rawThere); + + assertThat(collector.collected().size(), equalTo(1)); + assertThat(collector.collected().get(0).key(), equalTo(hi)); + assertThat(collector.collected().get(0).value(), equalTo(there.value())); + assertThat(collector.collected().get(0).timestamp(), equalTo(there.timestamp())); + } + + @Test + public void shouldNotWriteToChangeLogOnPutIfAbsentWhenValueForKeyExists() { + store.put(hi, rawThere); + store.putIfAbsent(hi, rawWorld); + + assertThat(collector.collected().size(), equalTo(1)); + assertThat(collector.collected().get(0).key(), equalTo(hi)); + assertThat(collector.collected().get(0).value(), equalTo(there.value())); + assertThat(collector.collected().get(0).timestamp(), equalTo(there.timestamp())); + } + + @Test + public void shouldReturnCurrentValueOnPutIfAbsent() { + store.put(hi, rawThere); + assertThat(store.putIfAbsent(hi, rawWorld), equalTo(rawThere)); + } + + @Test + public void shouldReturnNullOnPutIfAbsentWhenNoPreviousValue() { + assertThat(store.putIfAbsent(hi, rawThere), is(nullValue())); + } + + @Test + public void shouldReturnValueOnGetWhenExists() { + store.put(hello, rawWorld); + assertThat(store.get(hello), equalTo(rawWorld)); + } + + @Test + public void shouldReturnNullOnGetWhenDoesntExist() { + assertThat(store.get(hello), is(nullValue())); + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/ChangeLoggingTimestampedWindowBytesStoreTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/ChangeLoggingTimestampedWindowBytesStoreTest.java new file mode 100644 index 0000000..50c18fe --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/ChangeLoggingTimestampedWindowBytesStoreTest.java @@ -0,0 +1,155 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.streams.processor.ProcessorContext; +import org.apache.kafka.streams.processor.StateStoreContext; +import org.apache.kafka.streams.processor.TaskId; +import org.apache.kafka.streams.processor.internals.ProcessorContextImpl; +import org.apache.kafka.streams.state.WindowStore; +import org.apache.kafka.test.MockRecordCollector; +import org.easymock.EasyMock; +import org.easymock.EasyMockRunner; +import org.easymock.Mock; +import org.easymock.MockType; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; + +import static java.time.Instant.ofEpochMilli; + +@RunWith(EasyMockRunner.class) +public class ChangeLoggingTimestampedWindowBytesStoreTest { + + private final TaskId taskId = new TaskId(0, 0); + private final MockRecordCollector collector = new MockRecordCollector(); + + private final byte[] value = {0}; + private final byte[] valueAndTimestamp = {0, 0, 0, 0, 0, 0, 0, 42, 0}; + private final Bytes bytesKey = Bytes.wrap(value); + + @Mock(type = MockType.NICE) + private WindowStore inner; + @Mock(type = MockType.NICE) + private ProcessorContextImpl context; + private ChangeLoggingTimestampedWindowBytesStore store; + + + @Before + public void setUp() { + store = new ChangeLoggingTimestampedWindowBytesStore(inner, false); + } + + private void init() { + EasyMock.expect(context.taskId()).andReturn(taskId); + EasyMock.expect(context.recordCollector()).andReturn(collector); + inner.init((StateStoreContext) context, store); + EasyMock.expectLastCall(); + EasyMock.replay(inner, context); + + store.init((StateStoreContext) context, store); + } + + @SuppressWarnings("deprecation") + @Test + public void shouldDelegateDeprecatedInit() { + inner.init((ProcessorContext) context, store); + EasyMock.expectLastCall(); + EasyMock.replay(inner); + store.init((ProcessorContext) context, store); + EasyMock.verify(inner); + } + + @Test + public void shouldDelegateInit() { + inner.init((StateStoreContext) context, store); + EasyMock.expectLastCall(); + EasyMock.replay(inner); + store.init((StateStoreContext) context, store); + EasyMock.verify(inner); + } + + @Test + @SuppressWarnings("deprecation") + public void shouldLogPuts() { + inner.put(bytesKey, valueAndTimestamp, 0); + EasyMock.expectLastCall(); + + init(); + + final Bytes key = WindowKeySchema.toStoreKeyBinary(bytesKey, 0, 0); + + EasyMock.reset(context); + context.logChange(store.name(), key, value, 42); + + EasyMock.replay(context); + store.put(bytesKey, valueAndTimestamp, context.timestamp()); + + EasyMock.verify(inner, context); + } + + @Test + public void shouldDelegateToUnderlyingStoreWhenFetching() { + EasyMock + .expect(inner.fetch(bytesKey, 0, 10)) + .andReturn(KeyValueIterators.emptyWindowStoreIterator()); + + init(); + + store.fetch(bytesKey, ofEpochMilli(0), ofEpochMilli(10)); + EasyMock.verify(inner); + } + + @Test + public void shouldDelegateToUnderlyingStoreWhenFetchingRange() { + EasyMock + .expect(inner.fetch(bytesKey, bytesKey, 0, 1)) + .andReturn(KeyValueIterators.emptyIterator()); + + init(); + + store.fetch(bytesKey, bytesKey, ofEpochMilli(0), ofEpochMilli(1)); + EasyMock.verify(inner); + } + + @Test + @SuppressWarnings("deprecation") + public void shouldRetainDuplicatesWhenSet() { + store = new ChangeLoggingTimestampedWindowBytesStore(inner, true); + inner.put(bytesKey, valueAndTimestamp, 0); + EasyMock.expectLastCall().times(2); + + init(); + + final Bytes key1 = 
WindowKeySchema.toStoreKeyBinary(bytesKey, 0, 1); + final Bytes key2 = WindowKeySchema.toStoreKeyBinary(bytesKey, 0, 2); + + EasyMock.reset(context); + context.logChange(store.name(), key1, value, 42L); + context.logChange(store.name(), key2, value, 42L); + + EasyMock.replay(context); + + store.put(bytesKey, valueAndTimestamp, context.timestamp()); + store.put(bytesKey, valueAndTimestamp, context.timestamp()); + + EasyMock.verify(inner, context); + } + +} diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/ChangeLoggingWindowBytesStoreTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/ChangeLoggingWindowBytesStoreTest.java new file mode 100644 index 0000000..36e3297 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/ChangeLoggingWindowBytesStoreTest.java @@ -0,0 +1,178 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.streams.processor.ProcessorContext; +import org.apache.kafka.streams.processor.StateStoreContext; +import org.apache.kafka.streams.processor.TaskId; +import org.apache.kafka.streams.processor.internals.ProcessorContextImpl; +import org.apache.kafka.streams.state.WindowStore; +import org.apache.kafka.test.MockRecordCollector; +import org.easymock.EasyMock; +import org.easymock.EasyMockRunner; +import org.easymock.Mock; +import org.easymock.MockType; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; + +import static java.time.Instant.ofEpochMilli; + +@RunWith(EasyMockRunner.class) +public class ChangeLoggingWindowBytesStoreTest { + + private final TaskId taskId = new TaskId(0, 0); + private final MockRecordCollector collector = new MockRecordCollector(); + + private final byte[] value = {0}; + private final Bytes bytesKey = Bytes.wrap(value); + + @Mock(type = MockType.NICE) + private WindowStore inner; + @Mock(type = MockType.NICE) + private ProcessorContextImpl context; + private ChangeLoggingWindowBytesStore store; + + @Before + public void setUp() { + store = new ChangeLoggingWindowBytesStore(inner, false, WindowKeySchema::toStoreKeyBinary); + } + + private void init() { + EasyMock.expect(context.taskId()).andReturn(taskId); + EasyMock.expect(context.recordCollector()).andReturn(collector); + inner.init((StateStoreContext) context, store); + EasyMock.expectLastCall(); + EasyMock.replay(inner, context); + + store.init((StateStoreContext) context, store); + } + + @SuppressWarnings("deprecation") + @Test + public void shouldDelegateDeprecatedInit() { + inner.init((ProcessorContext) context, store); + EasyMock.expectLastCall(); + EasyMock.replay(inner); + store.init((ProcessorContext) context, 
store); + EasyMock.verify(inner); + } + + @Test + public void shouldDelegateInit() { + inner.init((StateStoreContext) context, store); + EasyMock.expectLastCall(); + EasyMock.replay(inner); + store.init((StateStoreContext) context, store); + EasyMock.verify(inner); + } + + @Test + public void shouldLogPuts() { + inner.put(bytesKey, value, 0); + EasyMock.expectLastCall(); + + init(); + + final Bytes key = WindowKeySchema.toStoreKeyBinary(bytesKey, 0, 0); + + EasyMock.reset(context); + EasyMock.expect(context.timestamp()).andStubReturn(0L); + context.logChange(store.name(), key, value, 0L); + + EasyMock.replay(context); + store.put(bytesKey, value, context.timestamp()); + + EasyMock.verify(inner, context); + } + + @Test + public void shouldDelegateToUnderlyingStoreWhenFetching() { + EasyMock + .expect(inner.fetch(bytesKey, 0, 10)) + .andReturn(KeyValueIterators.emptyWindowStoreIterator()); + + init(); + + store.fetch(bytesKey, ofEpochMilli(0), ofEpochMilli(10)); + EasyMock.verify(inner); + } + + @Test + public void shouldDelegateToUnderlyingStoreWhenBackwardFetching() { + EasyMock + .expect(inner.backwardFetch(bytesKey, 0, 10)) + .andReturn(KeyValueIterators.emptyWindowStoreIterator()); + + init(); + + store.backwardFetch(bytesKey, ofEpochMilli(0), ofEpochMilli(10)); + EasyMock.verify(inner); + } + + @Test + public void shouldDelegateToUnderlyingStoreWhenFetchingRange() { + EasyMock + .expect(inner.fetch(bytesKey, bytesKey, 0, 1)) + .andReturn(KeyValueIterators.emptyIterator()); + + init(); + + store.fetch(bytesKey, bytesKey, ofEpochMilli(0), ofEpochMilli(1)); + EasyMock.verify(inner); + } + + @Test + public void shouldDelegateToUnderlyingStoreWhenBackwardFetchingRange() { + EasyMock + .expect(inner.backwardFetch(bytesKey, bytesKey, 0, 1)) + .andReturn(KeyValueIterators.emptyIterator()); + + init(); + + store.backwardFetch(bytesKey, bytesKey, ofEpochMilli(0), ofEpochMilli(1)); + EasyMock.verify(inner); + } + + @Test + public void shouldRetainDuplicatesWhenSet() { + store = new ChangeLoggingWindowBytesStore(inner, true, WindowKeySchema::toStoreKeyBinary); + + inner.put(bytesKey, value, 0); + EasyMock.expectLastCall().times(2); + + init(); + + final Bytes key1 = WindowKeySchema.toStoreKeyBinary(bytesKey, 0, 1); + final Bytes key2 = WindowKeySchema.toStoreKeyBinary(bytesKey, 0, 2); + + EasyMock.reset(context); + EasyMock.expect(context.timestamp()).andStubReturn(0L); + context.logChange(store.name(), key1, value, 0L); + context.logChange(store.name(), key2, value, 0L); + + EasyMock.replay(context); + + store.put(bytesKey, value, context.timestamp()); + store.put(bytesKey, value, context.timestamp()); + + EasyMock.verify(inner, context); + } + +} diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/CompositeReadOnlyKeyValueStoreTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/CompositeReadOnlyKeyValueStoreTest.java new file mode 100644 index 0000000..88785ed --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/CompositeReadOnlyKeyValueStoreTest.java @@ -0,0 +1,527 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.serialization.StringSerializer; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.StoreQueryParameters; +import org.apache.kafka.streams.errors.InvalidStateStoreException; +import org.apache.kafka.streams.processor.StateStoreContext; +import org.apache.kafka.streams.processor.internals.ProcessorStateManager; +import org.apache.kafka.streams.state.KeyValueIterator; +import org.apache.kafka.streams.state.KeyValueStore; +import org.apache.kafka.streams.state.QueryableStoreTypes; +import org.apache.kafka.streams.state.StateSerdes; +import org.apache.kafka.streams.state.Stores; +import org.apache.kafka.test.InternalMockProcessorContext; +import org.apache.kafka.test.MockRecordCollector; +import org.apache.kafka.test.NoOpReadOnlyStore; +import org.apache.kafka.test.StateStoreProviderStub; +import org.junit.Before; +import org.junit.Test; + +import java.util.LinkedList; +import java.util.List; +import java.util.NoSuchElementException; + +import static java.util.Arrays.asList; +import static java.util.Collections.singletonList; +import static org.apache.kafka.test.StreamsTestUtils.toList; +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; + +public class CompositeReadOnlyKeyValueStoreTest { + + private final String storeName = "my-store"; + private StateStoreProviderStub stubProviderTwo; + private KeyValueStore stubOneUnderlying; + private KeyValueStore otherUnderlyingStore; + private CompositeReadOnlyKeyValueStore theStore; + + @Before + public void before() { + final StateStoreProviderStub stubProviderOne = new StateStoreProviderStub(false); + stubProviderTwo = new StateStoreProviderStub(false); + + stubOneUnderlying = newStoreInstance(); + stubProviderOne.addStore(storeName, stubOneUnderlying); + otherUnderlyingStore = newStoreInstance(); + stubProviderOne.addStore("other-store", otherUnderlyingStore); + theStore = new CompositeReadOnlyKeyValueStore<>( + new WrappingStoreProvider(asList(stubProviderOne, stubProviderTwo), StoreQueryParameters.fromNameAndType(storeName, QueryableStoreTypes.keyValueStore())), + QueryableStoreTypes.keyValueStore(), + storeName + ); + } + + private KeyValueStore newStoreInstance() { + final KeyValueStore store = Stores.keyValueStoreBuilder(Stores.inMemoryKeyValueStore(storeName), + Serdes.String(), + Serdes.String()) + .build(); + + @SuppressWarnings("rawtypes") final InternalMockProcessorContext context = + new InternalMockProcessorContext<>( + new StateSerdes<>( + ProcessorStateManager.storeChangelogTopic("appId", storeName, null), + Serdes.String(), + Serdes.String() + ), + new MockRecordCollector() + ); + context.setTime(1L); + + store.init((StateStoreContext) context, store); + + return store; + } + + @Test + public void shouldReturnNullIfKeyDoesNotExist() { + 
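// The composite store consults every underlying provider; a key absent from all of them should come back as null rather than throwing. +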
assertNull(theStore.get("whatever")); + } + + @Test + public void shouldThrowNullPointerExceptionOnGetNullKey() { + assertThrows(NullPointerException.class, () -> theStore.get(null)); + } + + @Test + public void shouldReturnValueOnRangeNullFromKey() { + stubOneUnderlying.put("0", "zero"); + stubOneUnderlying.put("1", "one"); + stubOneUnderlying.put("2", "two"); + + final LinkedList> expectedContents = new LinkedList<>(); + expectedContents.add(new KeyValue<>("0", "zero")); + expectedContents.add(new KeyValue<>("1", "one")); + + try (final KeyValueIterator iterator = theStore.range(null, "1")) { + assertEquals(expectedContents, Utils.toList(iterator)); + } + } + + @Test + public void shouldReturnValueOnRangeNullToKey() { + stubOneUnderlying.put("0", "zero"); + stubOneUnderlying.put("1", "one"); + stubOneUnderlying.put("2", "two"); + + final LinkedList> expectedContents = new LinkedList<>(); + expectedContents.add(new KeyValue<>("1", "one")); + expectedContents.add(new KeyValue<>("2", "two")); + + try (final KeyValueIterator iterator = theStore.range("1", null)) { + assertEquals(expectedContents, Utils.toList(iterator)); + } + } + + @Test + public void shouldThrowNullPointerExceptionOnPrefixScanNullPrefix() { + assertThrows(NullPointerException.class, () -> theStore.prefixScan(null, new StringSerializer())); + } + + @Test + public void shouldThrowNullPointerExceptionOnPrefixScanNullPrefixKeySerializer() { + assertThrows(NullPointerException.class, () -> theStore.prefixScan("aa", null)); + } + + @Test + public void shouldReturnValueOnReverseRangeNullFromKey() { + stubOneUnderlying.put("0", "zero"); + stubOneUnderlying.put("1", "one"); + stubOneUnderlying.put("2", "two"); + + final LinkedList> expectedContents = new LinkedList<>(); + expectedContents.add(new KeyValue<>("1", "one")); + expectedContents.add(new KeyValue<>("0", "zero")); + + try (final KeyValueIterator iterator = theStore.reverseRange(null, "1")) { + assertEquals(expectedContents, Utils.toList(iterator)); + } + } + + @Test + public void shouldReturnValueOnReverseRangeNullToKey() { + stubOneUnderlying.put("0", "zero"); + stubOneUnderlying.put("1", "one"); + stubOneUnderlying.put("2", "two"); + + final LinkedList> expectedContents = new LinkedList<>(); + expectedContents.add(new KeyValue<>("2", "two")); + expectedContents.add(new KeyValue<>("1", "one")); + + try (final KeyValueIterator iterator = theStore.reverseRange("1", null)) { + assertEquals(expectedContents, Utils.toList(iterator)); + } + } + + @Test + public void shouldReturnValueIfExists() { + stubOneUnderlying.put("key", "value"); + assertEquals("value", theStore.get("key")); + } + + @Test + public void shouldNotGetValuesFromOtherStores() { + otherUnderlyingStore.put("otherKey", "otherValue"); + assertNull(theStore.get("otherKey")); + } + + @Test + public void shouldThrowNoSuchElementExceptionWhileNext() { + stubOneUnderlying.put("a", "1"); + try (final KeyValueIterator keyValueIterator = theStore.range("a", "b")) { + keyValueIterator.next(); + assertThrows(NoSuchElementException.class, keyValueIterator::next); + } + } + + @Test + public void shouldThrowNoSuchElementExceptionWhilePeekNext() { + stubOneUnderlying.put("a", "1"); + try (final KeyValueIterator keyValueIterator = theStore.range("a", "b")) { + keyValueIterator.next(); + assertThrows(NoSuchElementException.class, keyValueIterator::peekNextKey); + } + } + + @Test + public void shouldThrowNoSuchElementExceptionWhileNextForPrefixScan() { + stubOneUnderlying.put("a", "1"); + try (final KeyValueIterator 
keyValueIterator = theStore.prefixScan("a", new StringSerializer())) { + keyValueIterator.next(); + assertThrows(NoSuchElementException.class, keyValueIterator::next); + } + } + + @Test + public void shouldThrowNoSuchElementExceptionWhilePeekNextForPrefixScan() { + stubOneUnderlying.put("a", "1"); + try (final KeyValueIterator keyValueIterator = theStore.prefixScan("a", new StringSerializer())) { + keyValueIterator.next(); + assertThrows(NoSuchElementException.class, keyValueIterator::peekNextKey); + } + } + + @Test + public void shouldThrowUnsupportedOperationExceptionWhileRemove() { + try (final KeyValueIterator keyValueIterator = theStore.all()) { + assertThrows(UnsupportedOperationException.class, keyValueIterator::remove); + } + } + + @Test + public void shouldThrowUnsupportedOperationExceptionWhileReverseRange() { + stubOneUnderlying.put("a", "1"); + stubOneUnderlying.put("b", "1"); + try (final KeyValueIterator keyValueIterator = theStore.reverseRange("a", "b")) { + assertThrows(UnsupportedOperationException.class, keyValueIterator::remove); + } + } + + @Test + public void shouldThrowUnsupportedOperationExceptionWhileRange() { + stubOneUnderlying.put("a", "1"); + stubOneUnderlying.put("b", "1"); + try (final KeyValueIterator keyValueIterator = theStore.range("a", "b")) { + assertThrows(UnsupportedOperationException.class, keyValueIterator::remove); + } + } + + @Test + public void shouldThrowUnsupportedOperationExceptionWhilePrefixScan() { + stubOneUnderlying.put("a", "1"); + stubOneUnderlying.put("b", "1"); + try (final KeyValueIterator keyValueIterator = theStore.prefixScan("a", new StringSerializer())) { + assertThrows(UnsupportedOperationException.class, keyValueIterator::remove); + } + } + + @Test + public void shouldFindValueForKeyWhenMultiStores() { + final KeyValueStore cache = newStoreInstance(); + stubProviderTwo.addStore(storeName, cache); + + cache.put("key-two", "key-two-value"); + stubOneUnderlying.put("key-one", "key-one-value"); + + assertEquals("key-two-value", theStore.get("key-two")); + assertEquals("key-one-value", theStore.get("key-one")); + } + + @Test + public void shouldSupportRange() { + stubOneUnderlying.put("a", "a"); + stubOneUnderlying.put("b", "b"); + stubOneUnderlying.put("c", "c"); + + final List> results = toList(theStore.range("a", "b")); + assertTrue(results.contains(new KeyValue<>("a", "a"))); + assertTrue(results.contains(new KeyValue<>("b", "b"))); + assertEquals(2, results.size()); + } + + @Test + public void shouldSupportReverseRange() { + stubOneUnderlying.put("a", "a"); + stubOneUnderlying.put("b", "b"); + stubOneUnderlying.put("c", "c"); + + final List> results = toList(theStore.reverseRange("a", "b")); + assertArrayEquals( + asList( + new KeyValue<>("b", "b"), + new KeyValue<>("a", "a") + ).toArray(), + results.toArray()); + } + + @Test + public void shouldReturnKeysWithGivenPrefixExcludingNextKeyLargestKey() { + stubOneUnderlying.put("abc", "a"); + stubOneUnderlying.put("abcd", "b"); + stubOneUnderlying.put("abce", "c"); + + final List> results = toList(theStore.prefixScan("abcd", new StringSerializer())); + assertTrue(results.contains(new KeyValue<>("abcd", "b"))); + assertEquals(1, results.size()); + } + + @Test + public void shouldSupportPrefixScan() { + stubOneUnderlying.put("a", "a"); + stubOneUnderlying.put("aa", "b"); + stubOneUnderlying.put("b", "c"); + + final List> results = toList(theStore.prefixScan("a", new StringSerializer())); + assertTrue(results.contains(new KeyValue<>("a", "a"))); + assertTrue(results.contains(new 
KeyValue<>("aa", "b"))); + assertEquals(2, results.size()); + } + + @Test + public void shouldSupportRangeAcrossMultipleKVStores() { + final KeyValueStore cache = newStoreInstance(); + stubProviderTwo.addStore(storeName, cache); + + stubOneUnderlying.put("a", "a"); + stubOneUnderlying.put("b", "b"); + stubOneUnderlying.put("z", "z"); + + cache.put("c", "c"); + cache.put("d", "d"); + cache.put("x", "x"); + + final List> results = toList(theStore.range("a", "e")); + assertArrayEquals( + asList( + new KeyValue<>("a", "a"), + new KeyValue<>("b", "b"), + new KeyValue<>("c", "c"), + new KeyValue<>("d", "d") + ).toArray(), + results.toArray()); + } + + @Test + public void shouldSupportPrefixScanAcrossMultipleKVStores() { + final KeyValueStore cache = newStoreInstance(); + stubProviderTwo.addStore(storeName, cache); + + stubOneUnderlying.put("a", "a"); + stubOneUnderlying.put("b", "b"); + stubOneUnderlying.put("z", "z"); + + cache.put("aa", "c"); + cache.put("ab", "d"); + cache.put("x", "x"); + + final List> results = toList(theStore.prefixScan("a", new StringSerializer())); + assertArrayEquals( + asList( + new KeyValue<>("a", "a"), + new KeyValue<>("aa", "c"), + new KeyValue<>("ab", "d") + ).toArray(), + results.toArray()); + } + + @Test + public void shouldSupportReverseRangeAcrossMultipleKVStores() { + final KeyValueStore cache = newStoreInstance(); + stubProviderTwo.addStore(storeName, cache); + + stubOneUnderlying.put("a", "a"); + stubOneUnderlying.put("b", "b"); + stubOneUnderlying.put("z", "z"); + + cache.put("c", "c"); + cache.put("d", "d"); + cache.put("x", "x"); + + final List> results = toList(theStore.reverseRange("a", "e")); + assertTrue(results.contains(new KeyValue<>("a", "a"))); + assertTrue(results.contains(new KeyValue<>("b", "b"))); + assertTrue(results.contains(new KeyValue<>("c", "c"))); + assertTrue(results.contains(new KeyValue<>("d", "d"))); + assertEquals(4, results.size()); + } + + @Test + public void shouldSupportAllAcrossMultipleStores() { + final KeyValueStore cache = newStoreInstance(); + stubProviderTwo.addStore(storeName, cache); + + stubOneUnderlying.put("a", "a"); + stubOneUnderlying.put("b", "b"); + stubOneUnderlying.put("z", "z"); + + cache.put("c", "c"); + cache.put("d", "d"); + cache.put("x", "x"); + + final List> results = toList(theStore.all()); + assertTrue(results.contains(new KeyValue<>("a", "a"))); + assertTrue(results.contains(new KeyValue<>("b", "b"))); + assertTrue(results.contains(new KeyValue<>("c", "c"))); + assertTrue(results.contains(new KeyValue<>("d", "d"))); + assertTrue(results.contains(new KeyValue<>("x", "x"))); + assertTrue(results.contains(new KeyValue<>("z", "z"))); + assertEquals(6, results.size()); + } + + @Test + public void shouldSupportReverseAllAcrossMultipleStores() { + final KeyValueStore cache = newStoreInstance(); + stubProviderTwo.addStore(storeName, cache); + + stubOneUnderlying.put("a", "a"); + stubOneUnderlying.put("b", "b"); + stubOneUnderlying.put("z", "z"); + + cache.put("c", "c"); + cache.put("d", "d"); + cache.put("x", "x"); + + final List> results = toList(theStore.reverseAll()); + assertTrue(results.contains(new KeyValue<>("a", "a"))); + assertTrue(results.contains(new KeyValue<>("b", "b"))); + assertTrue(results.contains(new KeyValue<>("c", "c"))); + assertTrue(results.contains(new KeyValue<>("d", "d"))); + assertTrue(results.contains(new KeyValue<>("x", "x"))); + assertTrue(results.contains(new KeyValue<>("z", "z"))); + assertEquals(6, results.size()); + } + + @Test + public void 
shouldThrowInvalidStoreExceptionDuringRebalance() { + assertThrows(InvalidStateStoreException.class, () -> rebalancing().get("anything")); + } + + @Test + public void shouldThrowInvalidStoreExceptionOnApproximateNumEntriesDuringRebalance() { + assertThrows(InvalidStateStoreException.class, () -> rebalancing().approximateNumEntries()); + } + + @Test + public void shouldThrowInvalidStoreExceptionOnRangeDuringRebalance() { + assertThrows(InvalidStateStoreException.class, () -> rebalancing().range("anything", "something")); + } + + @Test + public void shouldThrowInvalidStoreExceptionOnReverseRangeDuringRebalance() { + assertThrows(InvalidStateStoreException.class, () -> rebalancing().reverseRange("anything", "something")); + } + + @Test + public void shouldThrowInvalidStoreExceptionOnPrefixScanDuringRebalance() { + assertThrows(InvalidStateStoreException.class, () -> rebalancing().prefixScan("anything", new StringSerializer())); + } + + @Test + public void shouldThrowInvalidStoreExceptionOnAllDuringRebalance() { + assertThrows(InvalidStateStoreException.class, () -> rebalancing().all()); + } + + @Test + public void shouldThrowInvalidStoreExceptionOnReverseAllDuringRebalance() { + assertThrows(InvalidStateStoreException.class, () -> rebalancing().reverseAll()); + } + + @Test + public void shouldGetApproximateEntriesAcrossAllStores() { + final KeyValueStore cache = newStoreInstance(); + stubProviderTwo.addStore(storeName, cache); + + stubOneUnderlying.put("a", "a"); + stubOneUnderlying.put("b", "b"); + stubOneUnderlying.put("z", "z"); + + cache.put("c", "c"); + cache.put("d", "d"); + cache.put("x", "x"); + + assertEquals(6, theStore.approximateNumEntries()); + } + + @Test + public void shouldReturnLongMaxValueOnOverflow() { + stubProviderTwo.addStore(storeName, new NoOpReadOnlyStore() { + @Override + public long approximateNumEntries() { + return Long.MAX_VALUE; + } + }); + + stubOneUnderlying.put("overflow", "me"); + assertEquals(Long.MAX_VALUE, theStore.approximateNumEntries()); + } + + @Test + public void shouldReturnLongMaxValueOnUnderflow() { + stubProviderTwo.addStore(storeName, new NoOpReadOnlyStore() { + @Override + public long approximateNumEntries() { + return Long.MAX_VALUE; + } + }); + stubProviderTwo.addStore("my-storeA", new NoOpReadOnlyStore() { + @Override + public long approximateNumEntries() { + return Long.MAX_VALUE; + } + }); + + assertEquals(Long.MAX_VALUE, theStore.approximateNumEntries()); + } + + private CompositeReadOnlyKeyValueStore rebalancing() { + return new CompositeReadOnlyKeyValueStore<>( + new WrappingStoreProvider( + singletonList(new StateStoreProviderStub(true)), + StoreQueryParameters.fromNameAndType(storeName, QueryableStoreTypes.keyValueStore())), + QueryableStoreTypes.keyValueStore(), + storeName + ); + } +} \ No newline at end of file diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/CompositeReadOnlySessionStoreTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/CompositeReadOnlySessionStoreTest.java new file mode 100644 index 0000000..c2d38de --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/CompositeReadOnlySessionStoreTest.java @@ -0,0 +1,197 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.StoreQueryParameters; +import org.apache.kafka.streams.errors.InvalidStateStoreException; +import org.apache.kafka.streams.kstream.Windowed; +import org.apache.kafka.streams.kstream.internals.SessionWindow; +import org.apache.kafka.streams.state.ReadOnlySessionStore; +import org.apache.kafka.streams.state.KeyValueIterator; +import org.apache.kafka.streams.state.QueryableStoreType; +import org.apache.kafka.streams.state.QueryableStoreTypes; +import org.apache.kafka.test.ReadOnlySessionStoreStub; +import org.apache.kafka.test.StateStoreProviderStub; +import org.apache.kafka.test.StreamsTestUtils; +import org.junit.Before; +import org.junit.Test; + +import java.util.Arrays; +import java.util.List; + +import static java.util.Collections.singletonList; +import static org.apache.kafka.test.StreamsTestUtils.toList; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.core.IsEqual.equalTo; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.fail; + +public class CompositeReadOnlySessionStoreTest { + + private final String storeName = "session-store"; + private final StateStoreProviderStub stubProviderOne = new StateStoreProviderStub(false); + private final StateStoreProviderStub stubProviderTwo = new StateStoreProviderStub(false); + private final ReadOnlySessionStoreStub underlyingSessionStore = new ReadOnlySessionStoreStub<>(); + private final ReadOnlySessionStoreStub otherUnderlyingStore = new ReadOnlySessionStoreStub<>(); + private CompositeReadOnlySessionStore sessionStore; + + @Before + public void before() { + stubProviderOne.addStore(storeName, underlyingSessionStore); + stubProviderOne.addStore("other-session-store", otherUnderlyingStore); + final QueryableStoreType> queryableStoreType = QueryableStoreTypes.sessionStore(); + + sessionStore = new CompositeReadOnlySessionStore<>( + new WrappingStoreProvider(Arrays.asList(stubProviderOne, stubProviderTwo), StoreQueryParameters.fromNameAndType(storeName, queryableStoreType)), + QueryableStoreTypes.sessionStore(), storeName); + } + + @Test + public void shouldFetchResulstFromUnderlyingSessionStore() { + underlyingSessionStore.put(new Windowed<>("a", new SessionWindow(0, 0)), 1L); + underlyingSessionStore.put(new Windowed<>("a", new SessionWindow(10, 10)), 2L); + + final List, Long>> results = toList(sessionStore.fetch("a")); + assertEquals(Arrays.asList(KeyValue.pair(new Windowed<>("a", new SessionWindow(0, 0)), 1L), + KeyValue.pair(new Windowed<>("a", new SessionWindow(10, 10)), 2L)), + results); + } + + @Test + public void shouldReturnEmptyIteratorIfNoData() { + try (final KeyValueIterator, Long> result = sessionStore.fetch("b")) { + assertFalse(result.hasNext()); + } + } + + @Test + public 
void shouldFindValueForKeyWhenMultiStores() { + final ReadOnlySessionStoreStub secondUnderlying = new + ReadOnlySessionStoreStub<>(); + stubProviderTwo.addStore(storeName, secondUnderlying); + + final Windowed keyOne = new Windowed<>("key-one", new SessionWindow(0, 0)); + final Windowed keyTwo = new Windowed<>("key-two", new SessionWindow(0, 0)); + underlyingSessionStore.put(keyOne, 0L); + secondUnderlying.put(keyTwo, 10L); + + final List, Long>> keyOneResults = toList(sessionStore.fetch("key-one")); + final List, Long>> keyTwoResults = toList(sessionStore.fetch("key-two")); + + assertEquals(singletonList(KeyValue.pair(keyOne, 0L)), keyOneResults); + assertEquals(singletonList(KeyValue.pair(keyTwo, 10L)), keyTwoResults); + } + + @Test + public void shouldNotGetValueFromOtherStores() { + final Windowed expectedKey = new Windowed<>("foo", new SessionWindow(0, 0)); + otherUnderlyingStore.put(new Windowed<>("foo", new SessionWindow(10, 10)), 10L); + underlyingSessionStore.put(expectedKey, 1L); + + try (final KeyValueIterator, Long> result = sessionStore.fetch("foo")) { + assertEquals(KeyValue.pair(expectedKey, 1L), result.next()); + assertFalse(result.hasNext()); + } + } + + @Test + public void shouldThrowInvalidStateStoreExceptionOnRebalance() { + final QueryableStoreType> queryableStoreType = QueryableStoreTypes.sessionStore(); + final CompositeReadOnlySessionStore store = + new CompositeReadOnlySessionStore<>( + new WrappingStoreProvider(singletonList(new StateStoreProviderStub(true)), StoreQueryParameters.fromNameAndType("whateva", queryableStoreType)), + QueryableStoreTypes.sessionStore(), + "whateva" + ); + + assertThrows(InvalidStateStoreException.class, () -> store.fetch("a")); + } + + @Test + public void shouldThrowInvalidStateStoreExceptionIfSessionFetchThrows() { + underlyingSessionStore.setOpen(false); + try { + sessionStore.fetch("key"); + fail("Should have thrown InvalidStateStoreException with session store"); + } catch (final InvalidStateStoreException e) { } + } + + @Test + public void shouldThrowNullPointerExceptionIfFetchingNullKey() { + assertThrows(NullPointerException.class, () -> sessionStore.fetch(null)); + } + + @Test + public void shouldFetchKeyRangeAcrossStores() { + final ReadOnlySessionStoreStub secondUnderlying = new + ReadOnlySessionStoreStub<>(); + stubProviderTwo.addStore(storeName, secondUnderlying); + underlyingSessionStore.put(new Windowed<>("a", new SessionWindow(0, 0)), 0L); + secondUnderlying.put(new Windowed<>("b", new SessionWindow(0, 0)), 10L); + final List, Long>> results = StreamsTestUtils.toList(sessionStore.fetch("a", "b")); + assertThat(results, equalTo(Arrays.asList( + KeyValue.pair(new Windowed<>("a", new SessionWindow(0, 0)), 0L), + KeyValue.pair(new Windowed<>("b", new SessionWindow(0, 0)), 10L)))); + } + + @Test + public void shouldFetchKeyRangeAcrossStoresWithNullKeyFrom() { + final ReadOnlySessionStoreStub secondUnderlying = new + ReadOnlySessionStoreStub<>(); + stubProviderTwo.addStore(storeName, secondUnderlying); + underlyingSessionStore.put(new Windowed<>("a", new SessionWindow(0, 0)), 0L); + secondUnderlying.put(new Windowed<>("b", new SessionWindow(0, 0)), 10L); + final List, Long>> results = StreamsTestUtils.toList(sessionStore.fetch(null, "b")); + assertThat(results, equalTo(Arrays.asList( + KeyValue.pair(new Windowed<>("a", new SessionWindow(0, 0)), 0L), + KeyValue.pair(new Windowed<>("b", new SessionWindow(0, 0)), 10L)))); + } + + @Test + public void shouldFetchKeyRangeAcrossStoresWithNullKeyTo() { + final ReadOnlySessionStoreStub 
secondUnderlying = new + ReadOnlySessionStoreStub<>(); + stubProviderTwo.addStore(storeName, secondUnderlying); + underlyingSessionStore.put(new Windowed<>("a", new SessionWindow(0, 0)), 0L); + secondUnderlying.put(new Windowed<>("b", new SessionWindow(0, 0)), 10L); + final List, Long>> results = StreamsTestUtils.toList(sessionStore.fetch("a", null)); + assertThat(results, equalTo(Arrays.asList( + KeyValue.pair(new Windowed<>("a", new SessionWindow(0, 0)), 0L), + KeyValue.pair(new Windowed<>("b", new SessionWindow(0, 0)), 10L)))); + } + + @Test + public void shouldFetchKeyRangeAcrossStoresWithNullKeyFromKeyTo() { + final ReadOnlySessionStoreStub secondUnderlying = new + ReadOnlySessionStoreStub<>(); + stubProviderTwo.addStore(storeName, secondUnderlying); + underlyingSessionStore.put(new Windowed<>("a", new SessionWindow(0, 0)), 0L); + secondUnderlying.put(new Windowed<>("b", new SessionWindow(0, 0)), 10L); + final List, Long>> results = StreamsTestUtils.toList(sessionStore.fetch(null, null)); + assertThat(results, equalTo(Arrays.asList( + KeyValue.pair(new Windowed<>("a", new SessionWindow(0, 0)), 0L), + KeyValue.pair(new Windowed<>("b", new SessionWindow(0, 0)), 10L)))); + } + + @Test + public void shouldThrowNPEIfKeyIsNull() { + assertThrows(NullPointerException.class, () -> underlyingSessionStore.fetch(null)); + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/CompositeReadOnlyWindowStoreTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/CompositeReadOnlyWindowStoreTest.java new file mode 100644 index 0000000..db15276 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/CompositeReadOnlyWindowStoreTest.java @@ -0,0 +1,531 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.StoreQueryParameters; +import org.apache.kafka.streams.errors.InvalidStateStoreException; +import org.apache.kafka.streams.kstream.Windowed; +import org.apache.kafka.streams.kstream.internals.TimeWindow; +import org.apache.kafka.streams.state.QueryableStoreTypes; +import org.apache.kafka.streams.state.WindowStoreIterator; +import org.apache.kafka.test.StateStoreProviderStub; +import org.apache.kafka.test.StreamsTestUtils; +import org.easymock.EasyMock; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.NoSuchElementException; + +import static java.time.Instant.ofEpochMilli; +import static java.util.Arrays.asList; +import static java.util.Collections.emptyList; +import static java.util.Collections.singletonList; +import static org.easymock.EasyMock.anyObject; +import static org.easymock.EasyMock.anyString; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.core.IsEqual.equalTo; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertThrows; + +public class CompositeReadOnlyWindowStoreTest { + + private static final long WINDOW_SIZE = 30_000; + + private final String storeName = "window-store"; + private StateStoreProviderStub stubProviderOne; + private StateStoreProviderStub stubProviderTwo; + private CompositeReadOnlyWindowStore windowStore; + private ReadOnlyWindowStoreStub underlyingWindowStore; + private ReadOnlyWindowStoreStub otherUnderlyingStore; + + @Before + public void before() { + stubProviderOne = new StateStoreProviderStub(false); + stubProviderTwo = new StateStoreProviderStub(false); + underlyingWindowStore = new ReadOnlyWindowStoreStub<>(WINDOW_SIZE); + stubProviderOne.addStore(storeName, underlyingWindowStore); + + otherUnderlyingStore = new ReadOnlyWindowStoreStub<>(WINDOW_SIZE); + stubProviderOne.addStore("other-window-store", otherUnderlyingStore); + + windowStore = new CompositeReadOnlyWindowStore<>( + new WrappingStoreProvider(asList(stubProviderOne, stubProviderTwo), StoreQueryParameters.fromNameAndType(storeName, QueryableStoreTypes.windowStore())), + QueryableStoreTypes.windowStore(), + storeName + ); + } + + @Test + public void shouldFetchValuesFromWindowStore() { + underlyingWindowStore.put("my-key", "my-value", 0L); + underlyingWindowStore.put("my-key", "my-later-value", 10L); + + assertEquals( + asList(new KeyValue<>(0L, "my-value"), new KeyValue<>(10L, "my-later-value")), + StreamsTestUtils.toList(windowStore.fetch("my-key", ofEpochMilli(0L), ofEpochMilli(25L))) + ); + } + + + @Test + public void shouldBackwardFetchValuesFromWindowStore() { + underlyingWindowStore.put("my-key", "my-value", 0L); + underlyingWindowStore.put("my-key", "my-later-value", 10L); + + assertEquals( + asList(new KeyValue<>(10L, "my-later-value"), new KeyValue<>(0L, "my-value")), + StreamsTestUtils.toList(windowStore.backwardFetch("my-key", ofEpochMilli(0L), ofEpochMilli(25L))) + ); + } + + @Test + public void shouldReturnEmptyIteratorIfNoData() { + try (final WindowStoreIterator iterator = + windowStore.fetch("my-key", ofEpochMilli(0L), ofEpochMilli(25L))) { + assertFalse(iterator.hasNext()); + } + } + + @Test + public void shouldReturnBackwardEmptyIteratorIfNoData() { + try (final WindowStoreIterator iterator = + 
windowStore.backwardFetch("my-key", ofEpochMilli(0L), ofEpochMilli(25L))) { + assertFalse(iterator.hasNext()); + } + } + + @Test + public void shouldFindValueForKeyWhenMultiStores() { + final ReadOnlyWindowStoreStub secondUnderlying = new + ReadOnlyWindowStoreStub<>(WINDOW_SIZE); + stubProviderTwo.addStore(storeName, secondUnderlying); + + underlyingWindowStore.put("key-one", "value-one", 0L); + secondUnderlying.put("key-two", "value-two", 10L); + + final List> keyOneResults = + StreamsTestUtils.toList(windowStore.fetch("key-one", ofEpochMilli(0L), ofEpochMilli(1L))); + final List> keyTwoResults = + StreamsTestUtils.toList(windowStore.fetch("key-two", ofEpochMilli(10L), ofEpochMilli(11L))); + + assertEquals(Collections.singletonList(KeyValue.pair(0L, "value-one")), keyOneResults); + assertEquals(Collections.singletonList(KeyValue.pair(10L, "value-two")), keyTwoResults); + } + + @Test + public void shouldFindValueForKeyWhenMultiStoresBackwards() { + final ReadOnlyWindowStoreStub secondUnderlying = new + ReadOnlyWindowStoreStub<>(WINDOW_SIZE); + stubProviderTwo.addStore(storeName, secondUnderlying); + + underlyingWindowStore.put("key-one", "value-one", 0L); + secondUnderlying.put("key-two", "value-two", 10L); + + final List> keyOneResults = + StreamsTestUtils.toList(windowStore.backwardFetch("key-one", ofEpochMilli(0L), ofEpochMilli(1L))); + final List> keyTwoResults = + StreamsTestUtils.toList(windowStore.backwardFetch("key-two", ofEpochMilli(10L), ofEpochMilli(11L))); + + assertEquals(Collections.singletonList(KeyValue.pair(0L, "value-one")), keyOneResults); + assertEquals(Collections.singletonList(KeyValue.pair(10L, "value-two")), keyTwoResults); + } + + @Test + public void shouldNotGetValuesFromOtherStores() { + otherUnderlyingStore.put("some-key", "some-value", 0L); + underlyingWindowStore.put("some-key", "my-value", 1L); + + final List> results = + StreamsTestUtils.toList(windowStore.fetch("some-key", ofEpochMilli(0L), ofEpochMilli(2L))); + assertEquals(Collections.singletonList(new KeyValue<>(1L, "my-value")), results); + } + + @Test + public void shouldNotGetValuesBackwardFromOtherStores() { + otherUnderlyingStore.put("some-key", "some-value", 0L); + underlyingWindowStore.put("some-key", "my-value", 1L); + + final List> results = + StreamsTestUtils.toList(windowStore.backwardFetch("some-key", ofEpochMilli(0L), ofEpochMilli(2L))); + assertEquals(Collections.singletonList(new KeyValue<>(1L, "my-value")), results); + } + + @Test + public void shouldThrowInvalidStateStoreExceptionOnRebalance() { + final StateStoreProvider storeProvider = EasyMock.createNiceMock(StateStoreProvider.class); + EasyMock.expect(storeProvider.stores(anyString(), anyObject())) + .andThrow(new InvalidStateStoreException("store is unavailable")); + EasyMock.replay(storeProvider); + + final CompositeReadOnlyWindowStore store = new CompositeReadOnlyWindowStore<>( + storeProvider, + QueryableStoreTypes.windowStore(), + "foo" + ); + + assertThrows(InvalidStateStoreException.class, () -> store.fetch("key", ofEpochMilli(1), ofEpochMilli(10))); + } + + @Test + public void shouldThrowInvalidStateStoreExceptionOnRebalanceWhenBackwards() { + final StateStoreProvider storeProvider = EasyMock.createNiceMock(StateStoreProvider.class); + EasyMock.expect(storeProvider.stores(anyString(), anyObject())) + .andThrow(new InvalidStateStoreException("store is unavailable")); + EasyMock.replay(storeProvider); + + final CompositeReadOnlyWindowStore store = new CompositeReadOnlyWindowStore<>( + storeProvider, + 
QueryableStoreTypes.windowStore(), + "foo" + ); + assertThrows(InvalidStateStoreException.class, () -> store.backwardFetch("key", ofEpochMilli(1), ofEpochMilli(10))); + } + + @Test + public void shouldThrowInvalidStateStoreExceptionIfFetchThrows() { + underlyingWindowStore.setOpen(false); + final CompositeReadOnlyWindowStore store = + new CompositeReadOnlyWindowStore<>( + new WrappingStoreProvider(singletonList(stubProviderOne), StoreQueryParameters.fromNameAndType("window-store", QueryableStoreTypes.windowStore())), + QueryableStoreTypes.windowStore(), + "window-store" + ); + try { + store.fetch("key", ofEpochMilli(1), ofEpochMilli(10)); + Assert.fail("InvalidStateStoreException was expected"); + } catch (final InvalidStateStoreException e) { + Assert.assertEquals("State store is not available anymore and may have been migrated to another instance; " + + "please re-discover its location from the state metadata.", e.getMessage()); + } + } + + @Test + public void shouldThrowInvalidStateStoreExceptionIfBackwardFetchThrows() { + underlyingWindowStore.setOpen(false); + final CompositeReadOnlyWindowStore store = + new CompositeReadOnlyWindowStore<>( + new WrappingStoreProvider(singletonList(stubProviderOne), StoreQueryParameters.fromNameAndType("window-store", QueryableStoreTypes.windowStore())), + QueryableStoreTypes.windowStore(), + "window-store" + ); + try { + store.backwardFetch("key", ofEpochMilli(1), ofEpochMilli(10)); + Assert.fail("InvalidStateStoreException was expected"); + } catch (final InvalidStateStoreException e) { + Assert.assertEquals("State store is not available anymore and may have been migrated to another instance; " + + "please re-discover its location from the state metadata.", e.getMessage()); + } + } + + @Test + public void emptyBackwardIteratorAlwaysReturnsFalse() { + final StateStoreProvider storeProvider = EasyMock.createNiceMock(StateStoreProvider.class); + EasyMock.expect(storeProvider.stores(anyString(), anyObject())).andReturn(emptyList()); + EasyMock.replay(storeProvider); + + final CompositeReadOnlyWindowStore store = new CompositeReadOnlyWindowStore<>( + storeProvider, + QueryableStoreTypes.windowStore(), + "foo" + ); + try (final WindowStoreIterator windowStoreIterator = + store.backwardFetch("key", ofEpochMilli(1), ofEpochMilli(10))) { + + Assert.assertFalse(windowStoreIterator.hasNext()); + } + } + + @Test + public void emptyIteratorAlwaysReturnsFalse() { + final StateStoreProvider storeProvider = EasyMock.createNiceMock(StateStoreProvider.class); + EasyMock.expect(storeProvider.stores(anyString(), anyObject())).andReturn(emptyList()); + EasyMock.replay(storeProvider); + + final CompositeReadOnlyWindowStore store = new CompositeReadOnlyWindowStore<>( + storeProvider, + QueryableStoreTypes.windowStore(), + "foo" + ); + try (final WindowStoreIterator windowStoreIterator = + store.fetch("key", ofEpochMilli(1), ofEpochMilli(10))) { + + Assert.assertFalse(windowStoreIterator.hasNext()); + } + } + + @Test + public void emptyBackwardIteratorPeekNextKeyShouldThrowNoSuchElementException() { + final StateStoreProvider storeProvider = EasyMock.createNiceMock(StateStoreProvider.class); + EasyMock.expect(storeProvider.stores(anyString(), anyObject())).andReturn(emptyList()); + EasyMock.replay(storeProvider); + + final CompositeReadOnlyWindowStore store = new CompositeReadOnlyWindowStore<>( + storeProvider, + QueryableStoreTypes.windowStore(), + "foo" + ); + try (final WindowStoreIterator windowStoreIterator = store.backwardFetch("key", ofEpochMilli(1), 
ofEpochMilli(10))) { + assertThrows(NoSuchElementException.class, windowStoreIterator::peekNextKey); + } + } + + + @Test + public void emptyIteratorPeekNextKeyShouldThrowNoSuchElementException() { + final StateStoreProvider storeProvider = EasyMock.createNiceMock(StateStoreProvider.class); + EasyMock.expect(storeProvider.stores(anyString(), anyObject())).andReturn(emptyList()); + EasyMock.replay(storeProvider); + + final CompositeReadOnlyWindowStore store = new CompositeReadOnlyWindowStore<>( + storeProvider, + QueryableStoreTypes.windowStore(), + "foo" + ); + try (final WindowStoreIterator windowStoreIterator = + store.fetch("key", ofEpochMilli(1), ofEpochMilli(10))) { + assertThrows(NoSuchElementException.class, windowStoreIterator::peekNextKey); + } + } + + @Test + public void emptyIteratorNextShouldThrowNoSuchElementException() { + final StateStoreProvider storeProvider = EasyMock.createNiceMock(StateStoreProvider.class); + EasyMock.expect(storeProvider.stores(anyString(), anyObject())).andReturn(emptyList()); + EasyMock.replay(storeProvider); + + final CompositeReadOnlyWindowStore store = new CompositeReadOnlyWindowStore<>( + storeProvider, + QueryableStoreTypes.windowStore(), + "foo" + ); + try (final WindowStoreIterator windowStoreIterator = + store.fetch("key", ofEpochMilli(1), ofEpochMilli(10))) { + assertThrows(NoSuchElementException.class, windowStoreIterator::next); + } + } + + @Test + public void emptyBackwardIteratorNextShouldThrowNoSuchElementException() { + final StateStoreProvider storeProvider = EasyMock.createNiceMock(StateStoreProvider.class); + EasyMock.expect(storeProvider.stores(anyString(), anyObject())).andReturn(emptyList()); + EasyMock.replay(storeProvider); + + final CompositeReadOnlyWindowStore store = new CompositeReadOnlyWindowStore<>( + storeProvider, + QueryableStoreTypes.windowStore(), + "foo" + ); + try (final WindowStoreIterator windowStoreIterator = + store.backwardFetch("key", ofEpochMilli(1), ofEpochMilli(10))) { + assertThrows(NoSuchElementException.class, windowStoreIterator::next); + } + } + + @Test + public void shouldFetchKeyRangeAcrossStores() { + final ReadOnlyWindowStoreStub secondUnderlying = new ReadOnlyWindowStoreStub<>(WINDOW_SIZE); + stubProviderTwo.addStore(storeName, secondUnderlying); + underlyingWindowStore.put("a", "a", 0L); + secondUnderlying.put("b", "b", 10L); + final List, String>> results = + StreamsTestUtils.toList(windowStore.fetch("a", "b", ofEpochMilli(0), ofEpochMilli(10))); + assertThat(results, equalTo(Arrays.asList( + KeyValue.pair(new Windowed<>("a", new TimeWindow(0, WINDOW_SIZE)), "a"), + KeyValue.pair(new Windowed<>("b", new TimeWindow(10, 10 + WINDOW_SIZE)), "b")))); + } + + @Test + public void shouldFetchKeyRangeAcrossStoresWithNullKeyTo() { + final ReadOnlyWindowStoreStub secondUnderlying = new ReadOnlyWindowStoreStub<>(WINDOW_SIZE); + stubProviderTwo.addStore(storeName, secondUnderlying); + underlyingWindowStore.put("a", "a", 0L); + secondUnderlying.put("b", "b", 10L); + secondUnderlying.put("c", "c", 10L); + final List, String>> results = + StreamsTestUtils.toList(windowStore.fetch("a", null, ofEpochMilli(0), ofEpochMilli(10))); + assertThat(results, equalTo(Arrays.asList( + KeyValue.pair(new Windowed<>("a", new TimeWindow(0, WINDOW_SIZE)), "a"), + KeyValue.pair(new Windowed<>("b", new TimeWindow(10, 10 + WINDOW_SIZE)), "b"), + KeyValue.pair(new Windowed<>("c", new TimeWindow(10, 10 + WINDOW_SIZE)), "c")))); + } + + @Test + public void shouldFetchKeyRangeAcrossStoresWithNullKeyFrom() { + final 
ReadOnlyWindowStoreStub secondUnderlying = new ReadOnlyWindowStoreStub<>(WINDOW_SIZE); + stubProviderTwo.addStore(storeName, secondUnderlying); + underlyingWindowStore.put("a", "a", 0L); + secondUnderlying.put("b", "b", 10L); + secondUnderlying.put("c", "c", 10L); + final List, String>> results = + StreamsTestUtils.toList(windowStore.fetch(null, "c", ofEpochMilli(0), ofEpochMilli(10))); + assertThat(results, equalTo(Arrays.asList( + KeyValue.pair(new Windowed<>("a", new TimeWindow(0, WINDOW_SIZE)), "a"), + KeyValue.pair(new Windowed<>("b", new TimeWindow(10, 10 + WINDOW_SIZE)), "b"), + KeyValue.pair(new Windowed<>("c", new TimeWindow(10, 10 + WINDOW_SIZE)), "c")))); + } + + @Test + public void shouldFetchKeyRangeAcrossStoresWithNullKeyFromKeyTo() { + final ReadOnlyWindowStoreStub secondUnderlying = new ReadOnlyWindowStoreStub<>(WINDOW_SIZE); + stubProviderTwo.addStore(storeName, secondUnderlying); + underlyingWindowStore.put("a", "a", 0L); + secondUnderlying.put("b", "b", 10L); + secondUnderlying.put("c", "c", 10L); + final List, String>> results = + StreamsTestUtils.toList(windowStore.fetch(null, null, ofEpochMilli(0), ofEpochMilli(10))); + assertThat(results, equalTo(Arrays.asList( + KeyValue.pair(new Windowed<>("a", new TimeWindow(0, WINDOW_SIZE)), "a"), + KeyValue.pair(new Windowed<>("b", new TimeWindow(10, 10 + WINDOW_SIZE)), "b"), + KeyValue.pair(new Windowed<>("c", new TimeWindow(10, 10 + WINDOW_SIZE)), "c")))); + } + + @Test + public void shouldBackwardFetchKeyRangeAcrossStoresWithNullKeyTo() { + final ReadOnlyWindowStoreStub secondUnderlying = new ReadOnlyWindowStoreStub<>(WINDOW_SIZE); + stubProviderTwo.addStore(storeName, secondUnderlying); + underlyingWindowStore.put("a", "a", 0L); + secondUnderlying.put("b", "b", 10L); + secondUnderlying.put("c", "c", 10L); + final List, String>> results = + StreamsTestUtils.toList(windowStore.backwardFetch("a", null, ofEpochMilli(0), ofEpochMilli(10))); + assertThat(results, equalTo(Arrays.asList( + KeyValue.pair(new Windowed<>("a", new TimeWindow(0, WINDOW_SIZE)), "a"), + KeyValue.pair(new Windowed<>("c", new TimeWindow(10, 10 + WINDOW_SIZE)), "c"), + KeyValue.pair(new Windowed<>("b", new TimeWindow(10, 10 + WINDOW_SIZE)), "b")))); + } + + @Test + public void shouldBackwardFetchKeyRangeAcrossStoresWithNullKeyFrom() { + final ReadOnlyWindowStoreStub secondUnderlying = new ReadOnlyWindowStoreStub<>(WINDOW_SIZE); + stubProviderTwo.addStore(storeName, secondUnderlying); + underlyingWindowStore.put("a", "a", 0L); + secondUnderlying.put("b", "b", 10L); + secondUnderlying.put("c", "c", 10L); + final List, String>> results = + StreamsTestUtils.toList(windowStore.backwardFetch(null, "c", ofEpochMilli(0), ofEpochMilli(10))); + assertThat(results, equalTo(Arrays.asList( + KeyValue.pair(new Windowed<>("a", new TimeWindow(0, WINDOW_SIZE)), "a"), + KeyValue.pair(new Windowed<>("c", new TimeWindow(10, 10 + WINDOW_SIZE)), "c"), + KeyValue.pair(new Windowed<>("b", new TimeWindow(10, 10 + WINDOW_SIZE)), "b") + ))); + } + + @Test + public void shouldBackwardFetchKeyRangeAcrossStoresWithNullKeyFromKeyTo() { + final ReadOnlyWindowStoreStub secondUnderlying = new ReadOnlyWindowStoreStub<>(WINDOW_SIZE); + stubProviderTwo.addStore(storeName, secondUnderlying); + underlyingWindowStore.put("a", "a", 0L); + secondUnderlying.put("b", "b", 10L); + secondUnderlying.put("c", "c", 10L); + final List, String>> results = + StreamsTestUtils.toList(windowStore.backwardFetch(null, null, ofEpochMilli(0), ofEpochMilli(10))); + assertThat(results, equalTo(Arrays.asList( + 
KeyValue.pair(new Windowed<>("a", new TimeWindow(0, WINDOW_SIZE)), "a"), + KeyValue.pair(new Windowed<>("c", new TimeWindow(10, 10 + WINDOW_SIZE)), "c"), + KeyValue.pair(new Windowed<>("b", new TimeWindow(10, 10 + WINDOW_SIZE)), "b")))); + } + + @Test + public void shouldBackwardFetchKeyRangeAcrossStores() { + final ReadOnlyWindowStoreStub secondUnderlying = new ReadOnlyWindowStoreStub<>(WINDOW_SIZE); + stubProviderTwo.addStore(storeName, secondUnderlying); + underlyingWindowStore.put("a", "a", 0L); + secondUnderlying.put("b", "b", 10L); + final List, String>> results = + StreamsTestUtils.toList(windowStore.backwardFetch("a", "b", ofEpochMilli(0), ofEpochMilli(10))); + assertThat(results, equalTo(Arrays.asList( + KeyValue.pair(new Windowed<>("a", new TimeWindow(0, WINDOW_SIZE)), "a"), + KeyValue.pair(new Windowed<>("b", new TimeWindow(10, 10 + WINDOW_SIZE)), "b")))); + } + + @Test + public void shouldFetchKeyValueAcrossStores() { + final ReadOnlyWindowStoreStub secondUnderlyingWindowStore = new ReadOnlyWindowStoreStub<>(WINDOW_SIZE); + stubProviderTwo.addStore(storeName, secondUnderlyingWindowStore); + underlyingWindowStore.put("a", "a", 0L); + secondUnderlyingWindowStore.put("b", "b", 10L); + assertThat(windowStore.fetch("a", 0L), equalTo("a")); + assertThat(windowStore.fetch("b", 10L), equalTo("b")); + assertThat(windowStore.fetch("c", 10L), equalTo(null)); + assertThat(windowStore.fetch("a", 10L), equalTo(null)); + } + + @Test + public void shouldGetAllAcrossStores() { + final ReadOnlyWindowStoreStub secondUnderlying = new + ReadOnlyWindowStoreStub<>(WINDOW_SIZE); + stubProviderTwo.addStore(storeName, secondUnderlying); + underlyingWindowStore.put("a", "a", 0L); + secondUnderlying.put("b", "b", 10L); + final List, String>> results = StreamsTestUtils.toList(windowStore.all()); + assertThat(results, equalTo(Arrays.asList( + KeyValue.pair(new Windowed<>("a", new TimeWindow(0, WINDOW_SIZE)), "a"), + KeyValue.pair(new Windowed<>("b", new TimeWindow(10, 10 + WINDOW_SIZE)), "b")))); + } + + @Test + public void shouldGetBackwardAllAcrossStores() { + final ReadOnlyWindowStoreStub secondUnderlying = new + ReadOnlyWindowStoreStub<>(WINDOW_SIZE); + stubProviderTwo.addStore(storeName, secondUnderlying); + underlyingWindowStore.put("a", "a", 0L); + secondUnderlying.put("b", "b", 10L); + final List, String>> results = StreamsTestUtils.toList(windowStore.backwardAll()); + assertThat(results, equalTo(Arrays.asList( + KeyValue.pair(new Windowed<>("a", new TimeWindow(0, WINDOW_SIZE)), "a"), + KeyValue.pair(new Windowed<>("b", new TimeWindow(10, 10 + WINDOW_SIZE)), "b")))); + } + + @Test + public void shouldFetchAllAcrossStores() { + final ReadOnlyWindowStoreStub secondUnderlying = new + ReadOnlyWindowStoreStub<>(WINDOW_SIZE); + stubProviderTwo.addStore(storeName, secondUnderlying); + underlyingWindowStore.put("a", "a", 0L); + secondUnderlying.put("b", "b", 10L); + final List, String>> results = + StreamsTestUtils.toList(windowStore.fetchAll(ofEpochMilli(0), ofEpochMilli(10))); + assertThat(results, equalTo(Arrays.asList( + KeyValue.pair(new Windowed<>("a", new TimeWindow(0, WINDOW_SIZE)), "a"), + KeyValue.pair(new Windowed<>("b", new TimeWindow(10, 10 + WINDOW_SIZE)), "b")))); + } + + @Test + public void shouldBackwardFetchAllAcrossStores() { + final ReadOnlyWindowStoreStub secondUnderlying = new + ReadOnlyWindowStoreStub<>(WINDOW_SIZE); + stubProviderTwo.addStore(storeName, secondUnderlying); + underlyingWindowStore.put("a", "a", 0L); + secondUnderlying.put("b", "b", 10L); + final List, String>> results 
= + StreamsTestUtils.toList(windowStore.backwardFetchAll(ofEpochMilli(0), ofEpochMilli(10))); + assertThat(results, equalTo(Arrays.asList( + KeyValue.pair(new Windowed<>("a", new TimeWindow(0, WINDOW_SIZE)), "a"), + KeyValue.pair(new Windowed<>("b", new TimeWindow(10, 10 + WINDOW_SIZE)), "b")))); + } + + @Test + public void shouldThrowNPEIfKeyIsNull() { + assertThrows(NullPointerException.class, () -> windowStore.fetch(null, ofEpochMilli(0), ofEpochMilli(0))); + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/DelegatingPeekingKeyValueIteratorTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/DelegatingPeekingKeyValueIteratorTest.java new file mode 100644 index 0000000..9e0ef8d --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/DelegatingPeekingKeyValueIteratorTest.java @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.state.KeyValueStore; +import org.apache.kafka.test.GenericInMemoryKeyValueStore; +import org.junit.Before; +import org.junit.Test; + +import java.util.NoSuchElementException; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; + +public class DelegatingPeekingKeyValueIteratorTest { + + private final String name = "name"; + private KeyValueStore store; + + @Before + public void setUp() { + store = new GenericInMemoryKeyValueStore<>(name); + } + + @Test + public void shouldPeekNextKey() { + store.put("A", "A"); + final DelegatingPeekingKeyValueIterator peekingIterator = new DelegatingPeekingKeyValueIterator<>(name, store.all()); + assertEquals("A", peekingIterator.peekNextKey()); + assertEquals("A", peekingIterator.peekNextKey()); + assertTrue(peekingIterator.hasNext()); + peekingIterator.close(); + } + + @Test + public void shouldPeekNext() { + store.put("A", "A"); + final DelegatingPeekingKeyValueIterator peekingIterator = new DelegatingPeekingKeyValueIterator<>(name, store.all()); + assertEquals(KeyValue.pair("A", "A"), peekingIterator.peekNext()); + assertEquals(KeyValue.pair("A", "A"), peekingIterator.peekNext()); + assertTrue(peekingIterator.hasNext()); + peekingIterator.close(); + } + + @Test + public void shouldPeekAndIterate() { + final String[] kvs = {"a", "b", "c", "d", "e", "f"}; + for (final String kv : kvs) { + store.put(kv, kv); + } + + final DelegatingPeekingKeyValueIterator peekingIterator = new DelegatingPeekingKeyValueIterator<>(name, store.all()); + int index = 0; + while (peekingIterator.hasNext()) { + final String peekNext = peekingIterator.peekNextKey(); + final String key = 
peekingIterator.next().key;
+            assertEquals(kvs[index], peekNext);
+            assertEquals(kvs[index], key);
+            index++;
+        }
+        assertEquals(kvs.length, index);
+        peekingIterator.close();
+    }
+
+    @Test
+    public void shouldThrowNoSuchElementWhenNoMoreItemsLeftAndNextCalled() {
+        try (final DelegatingPeekingKeyValueIterator<String, String> peekingIterator =
+                 new DelegatingPeekingKeyValueIterator<>(name, store.all())) {
+            assertThrows(NoSuchElementException.class, peekingIterator::next);
+        }
+    }
+
+    @Test
+    public void shouldThrowNoSuchElementWhenNoMoreItemsLeftAndPeekNextCalled() {
+        try (final DelegatingPeekingKeyValueIterator<String, String> peekingIterator =
+                 new DelegatingPeekingKeyValueIterator<>(name, store.all())) {
+            assertThrows(NoSuchElementException.class, peekingIterator::peekNextKey);
+        }
+    }
+
+
+}
\ No newline at end of file
diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/FilteredCacheIteratorTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/FilteredCacheIteratorTest.java
new file mode 100644
index 0000000..bd79433
--- /dev/null
+++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/FilteredCacheIteratorTest.java
@@ -0,0 +1,133 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.state.KeyValueIterator; +import org.apache.kafka.streams.state.KeyValueStore; +import org.apache.kafka.test.GenericInMemoryKeyValueStore; +import org.junit.Before; +import org.junit.Test; + +import java.util.Collections; +import java.util.List; + +import static java.util.Arrays.asList; +import static org.apache.kafka.test.StreamsTestUtils.toList; +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; + +public class FilteredCacheIteratorTest { + + private static final CacheFunction IDENTITY_FUNCTION = new CacheFunction() { + @Override + public Bytes key(final Bytes cacheKey) { + return cacheKey; + } + + @Override + public Bytes cacheKey(final Bytes key) { + return key; + } + }; + + private final KeyValueStore store = new GenericInMemoryKeyValueStore<>("my-store"); + private final KeyValue firstEntry = KeyValue.pair(Bytes.wrap("a".getBytes()), + new LRUCacheEntry("1".getBytes())); + private final List> entries = asList( + firstEntry, + KeyValue.pair(Bytes.wrap("b".getBytes()), + new LRUCacheEntry("2".getBytes())), + KeyValue.pair(Bytes.wrap("c".getBytes()), + new LRUCacheEntry("3".getBytes()))); + + private FilteredCacheIterator allIterator; + private FilteredCacheIterator firstEntryIterator; + + @Before + public void before() { + store.putAll(entries); + final HasNextCondition allCondition = new HasNextCondition() { + @Override + public boolean hasNext(final KeyValueIterator iterator) { + return iterator.hasNext(); + } + }; + allIterator = new FilteredCacheIterator( + new DelegatingPeekingKeyValueIterator<>("", + store.all()), allCondition, IDENTITY_FUNCTION); + + final HasNextCondition firstEntryCondition = new HasNextCondition() { + @Override + public boolean hasNext(final KeyValueIterator iterator) { + return iterator.hasNext() && iterator.peekNextKey().equals(firstEntry.key); + } + }; + firstEntryIterator = new FilteredCacheIterator( + new DelegatingPeekingKeyValueIterator<>("", + store.all()), firstEntryCondition, IDENTITY_FUNCTION); + + } + + @Test + public void shouldAllowEntryMatchingHasNextCondition() { + final List> keyValues = toList(allIterator); + assertThat(keyValues, equalTo(entries)); + } + + @Test + public void shouldPeekNextKey() { + while (allIterator.hasNext()) { + final Bytes nextKey = allIterator.peekNextKey(); + final KeyValue next = allIterator.next(); + assertThat(next.key, equalTo(nextKey)); + } + } + + @Test + public void shouldPeekNext() { + while (allIterator.hasNext()) { + final KeyValue peeked = allIterator.peekNext(); + final KeyValue next = allIterator.next(); + assertThat(peeked, equalTo(next)); + } + } + + @Test + public void shouldNotHaveNextIfHasNextConditionNotMet() { + assertTrue(firstEntryIterator.hasNext()); + firstEntryIterator.next(); + assertFalse(firstEntryIterator.hasNext()); + } + + @Test + public void shouldFilterEntriesNotMatchingHasNextCondition() { + final List> keyValues = toList(firstEntryIterator); + assertThat(keyValues, equalTo(Collections.singletonList(firstEntry))); + } + + @Test + public void shouldThrowUnsupportedOperationExeceptionOnRemove() { + assertThrows(UnsupportedOperationException.class, () -> allIterator.remove()); + } + +} diff --git 
a/streams/src/test/java/org/apache/kafka/streams/state/internals/GlobalStateStoreProviderTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/GlobalStateStoreProviderTest.java new file mode 100644 index 0000000..bd23173 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/GlobalStateStoreProviderTest.java @@ -0,0 +1,239 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.common.metrics.Metrics; +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.errors.InvalidStateStoreException; +import org.apache.kafka.streams.processor.StateStore; +import org.apache.kafka.streams.processor.StateStoreContext; +import org.apache.kafka.streams.processor.TaskId; +import org.apache.kafka.streams.processor.internals.ProcessorContextImpl; +import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl; +import org.apache.kafka.streams.state.QueryableStoreTypes; +import org.apache.kafka.streams.state.ReadOnlyKeyValueStore; +import org.apache.kafka.streams.state.ReadOnlySessionStore; +import org.apache.kafka.streams.state.ReadOnlyWindowStore; +import org.apache.kafka.streams.state.Stores; +import org.apache.kafka.streams.state.TimestampedKeyValueStore; +import org.apache.kafka.streams.state.TimestampedWindowStore; +import org.apache.kafka.streams.state.ValueAndTimestamp; +import org.apache.kafka.test.NoOpReadOnlyStore; +import org.junit.Before; +import org.junit.Test; + +import java.time.Duration; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.easymock.EasyMock.expect; +import static org.easymock.EasyMock.niceMock; +import static org.easymock.EasyMock.replay; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.not; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; + + +public class GlobalStateStoreProviderTest { + private final Map stores = new HashMap<>(); + + @Before + public void before() { + stores.put( + "kv-store", + Stores.keyValueStoreBuilder( + Stores.inMemoryKeyValueStore("kv-store"), + Serdes.String(), + Serdes.String()).build()); + stores.put( + "ts-kv-store", + Stores.timestampedKeyValueStoreBuilder( + Stores.inMemoryKeyValueStore("ts-kv-store"), + Serdes.String(), + Serdes.String()).build()); + stores.put( + "w-store", + Stores.windowStoreBuilder( + Stores.inMemoryWindowStore( + 
"w-store", + Duration.ofMillis(10L), + Duration.ofMillis(2L), + false), + Serdes.String(), + Serdes.String()).build()); + stores.put( + "ts-w-store", + Stores.timestampedWindowStoreBuilder( + Stores.inMemoryWindowStore( + "ts-w-store", + Duration.ofMillis(10L), + Duration.ofMillis(2L), + false), + Serdes.String(), + Serdes.String()).build()); + stores.put( + "s-store", + Stores.sessionStoreBuilder( + Stores.inMemorySessionStore( + "s-store", + Duration.ofMillis(10L)), + Serdes.String(), + Serdes.String()).build()); + + final ProcessorContextImpl mockContext = niceMock(ProcessorContextImpl.class); + expect(mockContext.applicationId()).andStubReturn("appId"); + expect(mockContext.metrics()) + .andStubReturn( + new StreamsMetricsImpl(new Metrics(), "threadName", StreamsConfig.METRICS_LATEST, new MockTime()) + ); + expect(mockContext.taskId()).andStubReturn(new TaskId(0, 0)); + expect(mockContext.recordCollector()).andStubReturn(null); + expectSerdes(mockContext); + replay(mockContext); + for (final StateStore store : stores.values()) { + store.init((StateStoreContext) mockContext, null); + } + } + + @SuppressWarnings({"unchecked", "rawtypes"}) + private static void expectSerdes(final ProcessorContextImpl context) { + expect(context.keySerde()).andStubReturn((Serde) Serdes.String()); + expect(context.valueSerde()).andStubReturn((Serde) Serdes.Long()); + } + + @Test + public void shouldReturnSingleItemListIfStoreExists() { + final GlobalStateStoreProvider provider = + new GlobalStateStoreProvider(Collections.singletonMap("global", new NoOpReadOnlyStore<>())); + final List> stores = + provider.stores("global", QueryableStoreTypes.keyValueStore()); + assertEquals(stores.size(), 1); + } + + @Test + public void shouldReturnEmptyItemListIfStoreDoesntExist() { + final GlobalStateStoreProvider provider = new GlobalStateStoreProvider(Collections.emptyMap()); + final List> stores = + provider.stores("global", QueryableStoreTypes.keyValueStore()); + assertTrue(stores.isEmpty()); + } + + @Test + public void shouldThrowExceptionIfStoreIsntOpen() { + final NoOpReadOnlyStore store = new NoOpReadOnlyStore<>(); + store.close(); + final GlobalStateStoreProvider provider = + new GlobalStateStoreProvider(Collections.singletonMap("global", store)); + assertThrows(InvalidStateStoreException.class, () -> provider.stores("global", + QueryableStoreTypes.keyValueStore())); + } + + @Test + public void shouldReturnKeyValueStore() { + final GlobalStateStoreProvider provider = new GlobalStateStoreProvider(stores); + final List> stores = + provider.stores("kv-store", QueryableStoreTypes.keyValueStore()); + assertEquals(1, stores.size()); + for (final ReadOnlyKeyValueStore store : stores) { + assertThat(store, instanceOf(ReadOnlyKeyValueStore.class)); + assertThat(store, not(instanceOf(TimestampedKeyValueStore.class))); + } + } + + @Test + public void shouldReturnTimestampedKeyValueStore() { + final GlobalStateStoreProvider provider = new GlobalStateStoreProvider(stores); + final List>> stores = + provider.stores("ts-kv-store", QueryableStoreTypes.timestampedKeyValueStore()); + assertEquals(1, stores.size()); + for (final ReadOnlyKeyValueStore> store : stores) { + assertThat(store, instanceOf(ReadOnlyKeyValueStore.class)); + assertThat(store, instanceOf(TimestampedKeyValueStore.class)); + } + } + + @Test + public void shouldNotReturnKeyValueStoreAsTimestampedStore() { + final GlobalStateStoreProvider provider = new GlobalStateStoreProvider(stores); + final List>> stores = + provider.stores("kv-store", 
QueryableStoreTypes.timestampedKeyValueStore()); + assertEquals(0, stores.size()); + } + + @Test + public void shouldReturnTimestampedKeyValueStoreAsKeyValueStore() { + final GlobalStateStoreProvider provider = new GlobalStateStoreProvider(stores); + final List>> stores = + provider.stores("ts-kv-store", QueryableStoreTypes.keyValueStore()); + assertEquals(1, stores.size()); + for (final ReadOnlyKeyValueStore> store : stores) { + assertThat(store, instanceOf(ReadOnlyKeyValueStore.class)); + assertThat(store, not(instanceOf(TimestampedKeyValueStore.class))); + } + } + + @Test + public void shouldReturnWindowStore() { + final GlobalStateStoreProvider provider = new GlobalStateStoreProvider(stores); + final List> stores = + provider.stores("w-store", QueryableStoreTypes.windowStore()); + assertEquals(1, stores.size()); + for (final ReadOnlyWindowStore store : stores) { + assertThat(store, instanceOf(ReadOnlyWindowStore.class)); + assertThat(store, not(instanceOf(TimestampedWindowStore.class))); + } + } + + @Test + public void shouldNotReturnWindowStoreAsTimestampedStore() { + final GlobalStateStoreProvider provider = new GlobalStateStoreProvider(stores); + final List>> stores = + provider.stores("w-store", QueryableStoreTypes.timestampedWindowStore()); + assertEquals(0, stores.size()); + } + + @Test + public void shouldReturnTimestampedWindowStoreAsWindowStore() { + final GlobalStateStoreProvider provider = new GlobalStateStoreProvider(stores); + final List>> stores = + provider.stores("ts-w-store", QueryableStoreTypes.windowStore()); + assertEquals(1, stores.size()); + for (final ReadOnlyWindowStore> store : stores) { + assertThat(store, instanceOf(ReadOnlyWindowStore.class)); + assertThat(store, not(instanceOf(TimestampedWindowStore.class))); + } + } + + @Test + public void shouldReturnSessionStore() { + final GlobalStateStoreProvider provider = new GlobalStateStoreProvider(stores); + final List> stores = + provider.stores("s-store", QueryableStoreTypes.sessionStore()); + assertEquals(1, stores.size()); + for (final ReadOnlySessionStore store : stores) { + assertThat(store, instanceOf(ReadOnlySessionStore.class)); + } + } +} \ No newline at end of file diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/InMemoryKeyValueLoggedStoreTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/InMemoryKeyValueLoggedStoreTest.java new file mode 100644 index 0000000..f54fca1 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/InMemoryKeyValueLoggedStoreTest.java @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+package org.apache.kafka.streams.state.internals;
+
+import org.apache.kafka.common.serialization.Serde;
+import org.apache.kafka.streams.processor.StateStoreContext;
+import org.apache.kafka.streams.state.KeyValueStore;
+import org.apache.kafka.streams.state.StoreBuilder;
+import org.apache.kafka.streams.state.Stores;
+
+import java.util.Collections;
+
+public class InMemoryKeyValueLoggedStoreTest extends AbstractKeyValueStoreTest {
+
+    @SuppressWarnings("unchecked")
+    @Override
+    protected <K, V> KeyValueStore<K, V> createKeyValueStore(final StateStoreContext context) {
+        final StoreBuilder<KeyValueStore<K, V>> storeBuilder = Stores.keyValueStoreBuilder(
+            Stores.inMemoryKeyValueStore("my-store"),
+            (Serde<K>) context.keySerde(),
+            (Serde<V>) context.valueSerde())
+            .withLoggingEnabled(Collections.singletonMap("retention.ms", "1000"));
+
+        final KeyValueStore<K, V> store = storeBuilder.build();
+        store.init(context, store);
+
+        return store;
+    }
+}
diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/InMemoryKeyValueStoreTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/InMemoryKeyValueStoreTest.java
new file mode 100644
index 0000000..d67d665
--- /dev/null
+++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/InMemoryKeyValueStoreTest.java
@@ -0,0 +1,240 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.serialization.Serializer; +import org.apache.kafka.common.serialization.StringSerializer; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.processor.StateStoreContext; +import org.apache.kafka.streams.state.KeyValueStore; +import org.apache.kafka.streams.state.KeyValueStoreTestDriver; +import org.apache.kafka.streams.state.Stores; +import org.apache.kafka.streams.state.StoreBuilder; +import org.apache.kafka.streams.state.KeyValueIterator; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import java.util.ArrayList; +import java.util.List; +import java.util.UUID; + +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.CoreMatchers.nullValue; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; + +public class InMemoryKeyValueStoreTest extends AbstractKeyValueStoreTest { + + private KeyValueStore byteStore; + private final Serializer stringSerializer = new StringSerializer(); + private final KeyValueStoreTestDriver byteStoreDriver = KeyValueStoreTestDriver.create(Bytes.class, byte[].class); + + @Before + public void createStringKeyValueStore() { + super.before(); + final StateStoreContext byteStoreContext = byteStoreDriver.context(); + final StoreBuilder> storeBuilder = Stores.keyValueStoreBuilder( + Stores.inMemoryKeyValueStore("in-memory-byte-store"), + new Serdes.BytesSerde(), + new Serdes.ByteArraySerde()); + byteStore = storeBuilder.build(); + byteStore.init(byteStoreContext, byteStore); + } + + @After + public void after() { + super.after(); + byteStore.close(); + byteStoreDriver.clear(); + } + + @SuppressWarnings("unchecked") + @Override + protected KeyValueStore createKeyValueStore(final StateStoreContext context) { + final StoreBuilder> storeBuilder = Stores.keyValueStoreBuilder( + Stores.inMemoryKeyValueStore("my-store"), + (Serde) context.keySerde(), + (Serde) context.valueSerde()); + + final KeyValueStore store = storeBuilder.build(); + store.init(context, store); + return store; + } + + @SuppressWarnings("unchecked") + @Test + public void shouldRemoveKeysWithNullValues() { + store.close(); + // Add any entries that will be restored to any store + // that uses the driver's context ... 
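+        // The null value recorded for key 0 below acts as a tombstone: after the store is
+        // recreated and the restore log is replayed, only the three non-null entries remain
+        // and get(0) returns null, as asserted at the end of this test.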
+ driver.addEntryToRestoreLog(0, "zero"); + driver.addEntryToRestoreLog(1, "one"); + driver.addEntryToRestoreLog(2, "two"); + driver.addEntryToRestoreLog(3, "three"); + driver.addEntryToRestoreLog(0, null); + + store = createKeyValueStore(driver.context()); + context.restore(store.name(), driver.restoredEntries()); + + assertEquals(3, driver.sizeOf(store)); + + assertThat(store.get(0), nullValue()); + } + + + @Test + public void shouldReturnKeysWithGivenPrefix() { + + final List> entries = new ArrayList<>(); + entries.add(new KeyValue<>( + new Bytes(stringSerializer.serialize(null, "k1")), + stringSerializer.serialize(null, "a"))); + entries.add(new KeyValue<>( + new Bytes(stringSerializer.serialize(null, "prefix_3")), + stringSerializer.serialize(null, "b"))); + entries.add(new KeyValue<>( + new Bytes(stringSerializer.serialize(null, "k2")), + stringSerializer.serialize(null, "c"))); + entries.add(new KeyValue<>( + new Bytes(stringSerializer.serialize(null, "prefix_2")), + stringSerializer.serialize(null, "d"))); + entries.add(new KeyValue<>( + new Bytes(stringSerializer.serialize(null, "k3")), + stringSerializer.serialize(null, "e"))); + entries.add(new KeyValue<>( + new Bytes(stringSerializer.serialize(null, "prefix_1")), + stringSerializer.serialize(null, "f"))); + + byteStore.putAll(entries); + byteStore.flush(); + + final List valuesWithPrefix = new ArrayList<>(); + int numberOfKeysReturned = 0; + + try (final KeyValueIterator keysWithPrefix = byteStore.prefixScan("prefix", stringSerializer)) { + while (keysWithPrefix.hasNext()) { + final KeyValue next = keysWithPrefix.next(); + valuesWithPrefix.add(new String(next.value)); + numberOfKeysReturned++; + } + } + + assertThat(numberOfKeysReturned, is(3)); + assertThat(valuesWithPrefix.get(0), is("f")); + assertThat(valuesWithPrefix.get(1), is("d")); + assertThat(valuesWithPrefix.get(2), is("b")); + } + + @Test + public void shouldReturnKeysWithGivenPrefixExcludingNextKeyLargestKey() { + final List> entries = new ArrayList<>(); + entries.add(new KeyValue<>( + new Bytes(stringSerializer.serialize(null, "abc")), + stringSerializer.serialize(null, "f"))); + + entries.add(new KeyValue<>( + new Bytes(stringSerializer.serialize(null, "abcd")), + stringSerializer.serialize(null, "f"))); + + entries.add(new KeyValue<>( + new Bytes(stringSerializer.serialize(null, "abce")), + stringSerializer.serialize(null, "f"))); + + byteStore.putAll(entries); + byteStore.flush(); + + try (final KeyValueIterator keysWithPrefixAsabcd = byteStore.prefixScan("abcd", stringSerializer)) { + int numberOfKeysReturned = 0; + + while (keysWithPrefixAsabcd.hasNext()) { + keysWithPrefixAsabcd.next().key.get(); + numberOfKeysReturned++; + } + + assertThat(numberOfKeysReturned, is(1)); + } + } + + @Test + public void shouldReturnUUIDsWithStringPrefix() { + final List> entries = new ArrayList<>(); + final Serializer uuidSerializer = Serdes.UUID().serializer(); + final UUID uuid1 = UUID.randomUUID(); + final UUID uuid2 = UUID.randomUUID(); + final String prefix = uuid1.toString().substring(0, 4); + entries.add(new KeyValue<>( + new Bytes(uuidSerializer.serialize(null, uuid1)), + stringSerializer.serialize(null, "a"))); + entries.add(new KeyValue<>( + new Bytes(uuidSerializer.serialize(null, uuid2)), + stringSerializer.serialize(null, "b"))); + + byteStore.putAll(entries); + byteStore.flush(); + + final List valuesWithPrefix = new ArrayList<>(); + int numberOfKeysReturned = 0; + + try (final KeyValueIterator keysWithPrefix = byteStore.prefixScan(prefix, stringSerializer)) { + 
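+            // The scan prefix is taken from uuid1, so only that entry is expected to match.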
while (keysWithPrefix.hasNext()) { + final KeyValue next = keysWithPrefix.next(); + valuesWithPrefix.add(new String(next.value)); + numberOfKeysReturned++; + } + } + + assertThat(numberOfKeysReturned, is(1)); + assertThat(valuesWithPrefix.get(0), is("a")); + } + + @Test + public void shouldReturnNoKeys() { + final List> entries = new ArrayList<>(); + entries.add(new KeyValue<>( + new Bytes(stringSerializer.serialize(null, "a")), + stringSerializer.serialize(null, "a"))); + entries.add(new KeyValue<>( + new Bytes(stringSerializer.serialize(null, "b")), + stringSerializer.serialize(null, "c"))); + entries.add(new KeyValue<>( + new Bytes(stringSerializer.serialize(null, "c")), + stringSerializer.serialize(null, "e"))); + byteStore.putAll(entries); + byteStore.flush(); + + int numberOfKeysReturned = 0; + + try (final KeyValueIterator keysWithPrefix = byteStore.prefixScan("bb", stringSerializer)) { + while (keysWithPrefix.hasNext()) { + keysWithPrefix.next(); + numberOfKeysReturned++; + } + } + + assertThat(numberOfKeysReturned, is(0)); + } + + @Test + public void shouldThrowNullPointerIfPrefixKeySerializerIsNull() { + assertThrows(NullPointerException.class, () -> byteStore.prefixScan("bb", null)); + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/InMemoryLRUCacheStoreTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/InMemoryLRUCacheStoreTest.java new file mode 100644 index 0000000..53057b9 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/InMemoryLRUCacheStoreTest.java @@ -0,0 +1,164 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.processor.StateStoreContext; +import org.apache.kafka.streams.state.KeyValueStore; +import org.apache.kafka.streams.state.StoreBuilder; +import org.apache.kafka.streams.state.Stores; +import org.junit.Test; + +import java.util.Arrays; +import java.util.List; + +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +public class InMemoryLRUCacheStoreTest extends AbstractKeyValueStoreTest { + + @SuppressWarnings("unchecked") + @Override + protected KeyValueStore createKeyValueStore(final StateStoreContext context) { + final StoreBuilder> storeBuilder = Stores.keyValueStoreBuilder( + Stores.lruMap("my-store", 10), + (Serde) context.keySerde(), + (Serde) context.valueSerde()); + + final KeyValueStore store = storeBuilder.build(); + store.init(context, store); + + return store; + } + + @Test + public void shouldPutAllKeyValuePairs() { + final List> kvPairs = Arrays.asList(KeyValue.pair(1, "1"), + KeyValue.pair(2, "2"), + KeyValue.pair(3, "3")); + + store.putAll(kvPairs); + + assertThat(store.approximateNumEntries(), equalTo(3L)); + + for (final KeyValue kvPair : kvPairs) { + assertThat(store.get(kvPair.key), equalTo(kvPair.value)); + } + } + + @Test + public void shouldUpdateValuesForExistingKeysOnPutAll() { + final List> kvPairs = Arrays.asList(KeyValue.pair(1, "1"), + KeyValue.pair(2, "2"), + KeyValue.pair(3, "3")); + + store.putAll(kvPairs); + + + final List> updatedKvPairs = Arrays.asList(KeyValue.pair(1, "ONE"), + KeyValue.pair(2, "TWO"), + KeyValue.pair(3, "THREE")); + + store.putAll(updatedKvPairs); + + assertThat(store.approximateNumEntries(), equalTo(3L)); + + for (final KeyValue kvPair : updatedKvPairs) { + assertThat(store.get(kvPair.key), equalTo(kvPair.value)); + } + } + + @Test + public void testEvict() { + // Create the test driver ... + store.put(0, "zero"); + store.put(1, "one"); + store.put(2, "two"); + store.put(3, "three"); + store.put(4, "four"); + store.put(5, "five"); + store.put(6, "six"); + store.put(7, "seven"); + store.put(8, "eight"); + store.put(9, "nine"); + assertEquals(10, driver.sizeOf(store)); + + store.put(10, "ten"); + store.flush(); + assertEquals(10, driver.sizeOf(store)); + assertTrue(driver.flushedEntryRemoved(0)); + assertEquals(1, driver.numFlushedEntryRemoved()); + + store.delete(1); + store.flush(); + assertEquals(9, driver.sizeOf(store)); + assertTrue(driver.flushedEntryRemoved(0)); + assertTrue(driver.flushedEntryRemoved(1)); + assertEquals(2, driver.numFlushedEntryRemoved()); + + store.put(11, "eleven"); + store.flush(); + assertEquals(10, driver.sizeOf(store)); + assertEquals(2, driver.numFlushedEntryRemoved()); + + store.put(2, "two-again"); + store.flush(); + assertEquals(10, driver.sizeOf(store)); + assertEquals(2, driver.numFlushedEntryRemoved()); + + store.put(12, "twelve"); + store.flush(); + assertEquals(10, driver.sizeOf(store)); + assertTrue(driver.flushedEntryRemoved(0)); + assertTrue(driver.flushedEntryRemoved(1)); + assertTrue(driver.flushedEntryRemoved(3)); + assertEquals(3, driver.numFlushedEntryRemoved()); + } + + @SuppressWarnings("unchecked") + @Test + public void testRestoreEvict() { + store.close(); + // Add any entries that will be restored to any store + // that uses the driver's context ... 
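+        // Eleven entries are replayed into an LRU map capped at ten, so the eldest restored
+        // entry is evicted during restoration without appending anything to the changelog,
+        // as the flushed-entry assertions below verify.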
+ driver.addEntryToRestoreLog(0, "zero"); + driver.addEntryToRestoreLog(1, "one"); + driver.addEntryToRestoreLog(2, "two"); + driver.addEntryToRestoreLog(3, "three"); + driver.addEntryToRestoreLog(4, "four"); + driver.addEntryToRestoreLog(5, "five"); + driver.addEntryToRestoreLog(6, "fix"); + driver.addEntryToRestoreLog(7, "seven"); + driver.addEntryToRestoreLog(8, "eight"); + driver.addEntryToRestoreLog(9, "nine"); + driver.addEntryToRestoreLog(10, "ten"); + + // Create the store, which should register with the context and automatically + // receive the restore entries ... + store = createKeyValueStore(driver.context()); + context.restore(store.name(), driver.restoredEntries()); + // Verify that the store's changelog does not get more appends ... + assertEquals(0, driver.numFlushedEntryStored()); + assertEquals(0, driver.numFlushedEntryRemoved()); + + // and there are no other entries ... + assertEquals(10, driver.sizeOf(store)); + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/InMemorySessionStoreTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/InMemorySessionStoreTest.java new file mode 100644 index 0000000..a6ea780 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/InMemorySessionStoreTest.java @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.streams.kstream.Windowed; +import org.apache.kafka.streams.kstream.internals.SessionWindow; +import org.apache.kafka.streams.state.KeyValueIterator; +import org.apache.kafka.streams.state.SessionStore; +import org.apache.kafka.streams.state.Stores; +import org.junit.Test; + +import java.util.Arrays; +import java.util.HashSet; + +import static java.time.Duration.ofMillis; +import static org.apache.kafka.test.StreamsTestUtils.valuesToSet; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; + +public class InMemorySessionStoreTest extends AbstractSessionBytesStoreTest { + + private static final String STORE_NAME = "in-memory session store"; + + @Override + SessionStore buildSessionStore(final long retentionPeriod, + final Serde keySerde, + final Serde valueSerde) { + return Stores.sessionStoreBuilder( + Stores.inMemorySessionStore( + STORE_NAME, + ofMillis(retentionPeriod)), + keySerde, + valueSerde).build(); + } + + @Test + public void shouldRemoveExpired() { + sessionStore.put(new Windowed<>("a", new SessionWindow(0, 0)), 1L); + sessionStore.put(new Windowed<>("aa", new SessionWindow(0, 10)), 2L); + sessionStore.put(new Windowed<>("a", new SessionWindow(10, 20)), 3L); + + // Advance stream time to expire the first record + sessionStore.put(new Windowed<>("aa", new SessionWindow(10, RETENTION_PERIOD)), 4L); + + try (final KeyValueIterator, Long> iterator = + sessionStore.findSessions("a", "b", 0L, Long.MAX_VALUE) + ) { + assertEquals(valuesToSet(iterator), new HashSet<>(Arrays.asList(2L, 3L, 4L))); + } + } + + @Test + public void shouldNotExpireFromOpenIterator() { + + sessionStore.put(new Windowed<>("a", new SessionWindow(0, 0)), 1L); + sessionStore.put(new Windowed<>("aa", new SessionWindow(0, 10)), 2L); + sessionStore.put(new Windowed<>("a", new SessionWindow(10, 20)), 3L); + + final KeyValueIterator, Long> iterator = sessionStore.findSessions("a", "b", 0L, RETENTION_PERIOD); + + // Advance stream time to expire the first three record + sessionStore.put(new Windowed<>("aa", new SessionWindow(100, 2 * RETENTION_PERIOD)), 4L); + + assertEquals(valuesToSet(iterator), new HashSet<>(Arrays.asList(1L, 2L, 3L, 4L))); + assertFalse(iterator.hasNext()); + + iterator.close(); + assertFalse(sessionStore.findSessions("a", "b", 0L, 20L).hasNext()); + } + +} \ No newline at end of file diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/InMemoryTimeOrderedKeyValueBufferTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/InMemoryTimeOrderedKeyValueBufferTest.java new file mode 100644 index 0000000..90f4850 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/InMemoryTimeOrderedKeyValueBufferTest.java @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.streams.state.StoreBuilder; +import org.junit.Test; + +import java.util.HashMap; +import java.util.Map; + +import static java.util.Collections.emptyMap; +import static java.util.Collections.singletonMap; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.is; + +public class InMemoryTimeOrderedKeyValueBufferTest { + + @Test + public void bufferShouldAllowCacheEnablement() { + new InMemoryTimeOrderedKeyValueBuffer.Builder<>(null, null, null).withCachingEnabled(); + } + + @Test + public void bufferShouldAllowCacheDisablement() { + new InMemoryTimeOrderedKeyValueBuffer.Builder<>(null, null, null).withCachingDisabled(); + } + + @Test + public void bufferShouldAllowLoggingEnablement() { + final String expect = "3"; + final Map logConfig = new HashMap<>(); + logConfig.put("min.insync.replicas", expect); + final StoreBuilder> builder = + new InMemoryTimeOrderedKeyValueBuffer.Builder<>(null, null, null) + .withLoggingEnabled(logConfig); + + assertThat(builder.logConfig(), is(singletonMap("min.insync.replicas", expect))); + assertThat(builder.loggingEnabled(), is(true)); + } + + @Test + public void bufferShouldAllowLoggingDisablement() { + final StoreBuilder> builder + = new InMemoryTimeOrderedKeyValueBuffer.Builder<>(null, null, null) + .withLoggingDisabled(); + + assertThat(builder.logConfig(), is(emptyMap())); + assertThat(builder.loggingEnabled(), is(false)); + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/InMemoryWindowStoreTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/InMemoryWindowStoreTest.java new file mode 100644 index 0000000..5150839 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/InMemoryWindowStoreTest.java @@ -0,0 +1,163 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.kstream.Windowed; +import org.apache.kafka.streams.state.KeyValueIterator; +import org.apache.kafka.streams.state.StateSerdes; +import org.apache.kafka.streams.state.Stores; +import org.apache.kafka.streams.state.WindowStore; +import org.apache.kafka.streams.state.WindowStoreIterator; +import org.junit.Test; + +import java.util.LinkedList; +import java.util.List; + +import static java.time.Duration.ofMillis; +import static org.apache.kafka.streams.state.internals.WindowKeySchema.toStoreKeyBinary; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; + +public class InMemoryWindowStoreTest extends AbstractWindowBytesStoreTest { + + private final static String STORE_NAME = "InMemoryWindowStore"; + + @Override + WindowStore buildWindowStore(final long retentionPeriod, + final long windowSize, + final boolean retainDuplicates, + final Serde keySerde, + final Serde valueSerde) { + return Stores.windowStoreBuilder( + Stores.inMemoryWindowStore( + STORE_NAME, + ofMillis(retentionPeriod), + ofMillis(windowSize), + retainDuplicates), + keySerde, + valueSerde) + .build(); + } + + @SuppressWarnings("unchecked") + @Test + public void shouldRestore() { + // should be empty initially + assertFalse(windowStore.all().hasNext()); + + final StateSerdes serdes = new StateSerdes<>("", Serdes.Integer(), + Serdes.String()); + + final List> restorableEntries = new LinkedList<>(); + + restorableEntries + .add(new KeyValue<>(toStoreKeyBinary(1, 0L, 0, serdes).get(), serdes.rawValue("one"))); + restorableEntries.add(new KeyValue<>(toStoreKeyBinary(2, WINDOW_SIZE, 0, serdes).get(), + serdes.rawValue("two"))); + restorableEntries.add(new KeyValue<>(toStoreKeyBinary(3, 2 * WINDOW_SIZE, 0, serdes).get(), + serdes.rawValue("three"))); + + context.restore(STORE_NAME, restorableEntries); + try (final KeyValueIterator, String> iterator = windowStore + .fetchAll(0L, 2 * WINDOW_SIZE)) { + + assertEquals(windowedPair(1, "one", 0L), iterator.next()); + assertEquals(windowedPair(2, "two", WINDOW_SIZE), iterator.next()); + assertEquals(windowedPair(3, "three", 2 * WINDOW_SIZE), iterator.next()); + assertFalse(iterator.hasNext()); + } + } + + @Test + public void shouldNotExpireFromOpenIterator() { + + windowStore.put(1, "one", 0L); + windowStore.put(1, "two", 10L); + + windowStore.put(2, "one", 5L); + windowStore.put(2, "two", 15L); + + final WindowStoreIterator iterator1 = windowStore.fetch(1, 0L, 50L); + final WindowStoreIterator iterator2 = windowStore.fetch(2, 0L, 50L); + + // This put expires all four previous records, but they should still be returned from already open iterators + windowStore.put(1, "four", 2 * RETENTION_PERIOD); + + assertEquals(new KeyValue<>(0L, "one"), iterator1.next()); + assertEquals(new KeyValue<>(5L, "one"), iterator2.next()); + + assertEquals(new KeyValue<>(15L, "two"), iterator2.next()); + assertEquals(new KeyValue<>(10L, "two"), iterator1.next()); + + assertFalse(iterator1.hasNext()); + assertFalse(iterator2.hasNext()); + + iterator1.close(); + iterator2.close(); + + // Make sure expired records are removed now that open iterators are closed + assertFalse(windowStore.fetch(1, 0L, 50L).hasNext()); + } + + @Test + public void testExpiration() { + + long currentTime = 0; + windowStore.put(1, "one", currentTime); + + 
currentTime += RETENTION_PERIOD / 4; + windowStore.put(1, "two", currentTime); + + currentTime += RETENTION_PERIOD / 4; + windowStore.put(1, "three", currentTime); + + currentTime += RETENTION_PERIOD / 4; + windowStore.put(1, "four", currentTime); + + // increase current time to the full RETENTION_PERIOD to expire first record + currentTime = currentTime + RETENTION_PERIOD / 4; + windowStore.put(1, "five", currentTime); + + KeyValueIterator, String> iterator = windowStore + .fetchAll(0L, currentTime); + + // effect of this put (expires next oldest record, adds new one) should not be reflected in the already fetched results + currentTime = currentTime + RETENTION_PERIOD / 4; + windowStore.put(1, "six", currentTime); + + // should only have middle 4 values, as (only) the first record was expired at the time of the fetch + // and the last was inserted after the fetch + assertEquals(windowedPair(1, "two", RETENTION_PERIOD / 4), iterator.next()); + assertEquals(windowedPair(1, "three", RETENTION_PERIOD / 2), iterator.next()); + assertEquals(windowedPair(1, "four", 3 * (RETENTION_PERIOD / 4)), iterator.next()); + assertEquals(windowedPair(1, "five", RETENTION_PERIOD), iterator.next()); + assertFalse(iterator.hasNext()); + + iterator = windowStore.fetchAll(0L, currentTime); + + // If we fetch again after the last put, the second oldest record should have expired and newest should appear in results + assertEquals(windowedPair(1, "three", RETENTION_PERIOD / 2), iterator.next()); + assertEquals(windowedPair(1, "four", 3 * (RETENTION_PERIOD / 4)), iterator.next()); + assertEquals(windowedPair(1, "five", RETENTION_PERIOD), iterator.next()); + assertEquals(windowedPair(1, "six", 5 * (RETENTION_PERIOD / 4)), iterator.next()); + assertFalse(iterator.hasNext()); + } + +} diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/KeyValueIteratorFacadeTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/KeyValueIteratorFacadeTest.java new file mode 100644 index 0000000..19566e9 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/KeyValueIteratorFacadeTest.java @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+package org.apache.kafka.streams.state.internals;
+
+import org.apache.kafka.streams.KeyValue;
+import org.apache.kafka.streams.state.KeyValueIterator;
+import org.apache.kafka.streams.state.ValueAndTimestamp;
+import org.easymock.EasyMockRunner;
+import org.easymock.Mock;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+
+import static org.easymock.EasyMock.expect;
+import static org.easymock.EasyMock.expectLastCall;
+import static org.easymock.EasyMock.replay;
+import static org.easymock.EasyMock.verify;
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.hamcrest.Matchers.is;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
+
+@RunWith(EasyMockRunner.class)
+public class KeyValueIteratorFacadeTest {
+    @Mock
+    private KeyValueIterator<String, ValueAndTimestamp<String>> mockedKeyValueIterator;
+
+    private KeyValueIteratorFacade<String, String> keyValueIteratorFacade;
+
+    @Before
+    public void setup() {
+        keyValueIteratorFacade = new KeyValueIteratorFacade<>(mockedKeyValueIterator);
+    }
+
+    @Test
+    public void shouldForwardHasNext() {
+        expect(mockedKeyValueIterator.hasNext()).andReturn(true).andReturn(false);
+        replay(mockedKeyValueIterator);
+
+        assertTrue(keyValueIteratorFacade.hasNext());
+        assertFalse(keyValueIteratorFacade.hasNext());
+        verify(mockedKeyValueIterator);
+    }
+
+    @Test
+    public void shouldForwardPeekNextKey() {
+        expect(mockedKeyValueIterator.peekNextKey()).andReturn("key");
+        replay(mockedKeyValueIterator);
+
+        assertThat(keyValueIteratorFacade.peekNextKey(), is("key"));
+        verify(mockedKeyValueIterator);
+    }
+
+    @Test
+    public void shouldReturnPlainKeyValuePairOnGet() {
+        expect(mockedKeyValueIterator.next()).andReturn(
+            new KeyValue<>("key", ValueAndTimestamp.make("value", 42L)));
+        replay(mockedKeyValueIterator);
+
+        assertThat(keyValueIteratorFacade.next(), is(KeyValue.pair("key", "value")));
+        verify(mockedKeyValueIterator);
+    }
+
+    @Test
+    public void shouldCloseInnerIterator() {
+        mockedKeyValueIterator.close();
+        expectLastCall();
+        replay(mockedKeyValueIterator);
+
+        keyValueIteratorFacade.close();
+        verify(mockedKeyValueIterator);
+    }
+}
diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/KeyValueSegmentTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/KeyValueSegmentTest.java
new file mode 100644
index 0000000..859aea1
--- /dev/null
+++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/KeyValueSegmentTest.java
@@ -0,0 +1,131 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.common.metrics.Metrics; +import org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.processor.ProcessorContext; +import org.apache.kafka.streams.processor.TaskId; +import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl; +import org.apache.kafka.streams.state.internals.metrics.RocksDBMetricsRecorder; +import org.apache.kafka.test.TestUtils; +import org.junit.Before; +import org.junit.Test; + +import java.io.File; +import java.util.HashSet; +import java.util.Set; + +import static org.apache.kafka.common.utils.Utils.mkEntry; +import static org.apache.kafka.common.utils.Utils.mkMap; +import static org.apache.kafka.streams.StreamsConfig.METRICS_RECORDING_LEVEL_CONFIG; +import static org.easymock.EasyMock.expect; +import static org.easymock.EasyMock.mock; +import static org.easymock.EasyMock.replay; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.not; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +public class KeyValueSegmentTest { + + private final RocksDBMetricsRecorder metricsRecorder = + new RocksDBMetricsRecorder("metrics-scope", "store-name"); + + @Before + public void setUp() { + metricsRecorder.init( + new StreamsMetricsImpl(new Metrics(), "test-client", StreamsConfig.METRICS_LATEST, new MockTime()), + new TaskId(0, 0) + ); + } + + @Test + public void shouldDeleteStateDirectoryOnDestroy() throws Exception { + final KeyValueSegment segment = new KeyValueSegment("segment", "window", 0L, metricsRecorder); + final String directoryPath = TestUtils.tempDirectory().getAbsolutePath(); + final File directory = new File(directoryPath); + + final ProcessorContext mockContext = mock(ProcessorContext.class); + expect(mockContext.appConfigs()).andReturn(mkMap(mkEntry(METRICS_RECORDING_LEVEL_CONFIG, "INFO"))); + expect(mockContext.stateDir()).andReturn(directory); + replay(mockContext); + + segment.openDB(mockContext.appConfigs(), mockContext.stateDir()); + + assertTrue(new File(directoryPath, "window").exists()); + assertTrue(new File(directoryPath + File.separator + "window", "segment").exists()); + assertTrue(new File(directoryPath + File.separator + "window", "segment").list().length > 0); + segment.destroy(); + assertFalse(new File(directoryPath + File.separator + "window", "segment").exists()); + assertTrue(new File(directoryPath, "window").exists()); + + segment.close(); + } + + @Test + public void shouldBeEqualIfIdIsEqual() { + final KeyValueSegment segment = new KeyValueSegment("anyName", "anyName", 0L, metricsRecorder); + final KeyValueSegment segmentSameId = + new KeyValueSegment("someOtherName", "someOtherName", 0L, metricsRecorder); + final KeyValueSegment segmentDifferentId = new KeyValueSegment("anyName", "anyName", 1L, metricsRecorder); + + assertThat(segment, equalTo(segment)); + assertThat(segment, equalTo(segmentSameId)); + assertThat(segment, not(equalTo(segmentDifferentId))); + assertThat(segment, not(equalTo(null))); + assertThat(segment, not(equalTo("anyName"))); + + segment.close(); + } + + @Test + public void shouldHashOnSegmentIdOnly() { + final KeyValueSegment segment = new KeyValueSegment("anyName", "anyName", 0L, metricsRecorder); + final KeyValueSegment segmentSameId = + new KeyValueSegment("someOtherName", "someOtherName", 0L, metricsRecorder); + final KeyValueSegment 
segmentDifferentId = new KeyValueSegment("anyName", "anyName", 1L, metricsRecorder); + + final Set set = new HashSet<>(); + assertTrue(set.add(segment)); + assertFalse(set.add(segmentSameId)); + assertTrue(set.add(segmentDifferentId)); + + segment.close(); + } + + @Test + public void shouldCompareSegmentIdOnly() { + final KeyValueSegment segment1 = new KeyValueSegment("a", "C", 50L, metricsRecorder); + final KeyValueSegment segment2 = new KeyValueSegment("b", "B", 100L, metricsRecorder); + final KeyValueSegment segment3 = new KeyValueSegment("c", "A", 0L, metricsRecorder); + + assertThat(segment1.compareTo(segment1), equalTo(0)); + assertThat(segment1.compareTo(segment2), equalTo(-1)); + assertThat(segment2.compareTo(segment1), equalTo(1)); + assertThat(segment1.compareTo(segment3), equalTo(1)); + assertThat(segment3.compareTo(segment1), equalTo(-1)); + assertThat(segment2.compareTo(segment3), equalTo(1)); + assertThat(segment3.compareTo(segment2), equalTo(-1)); + + segment1.close(); + segment2.close(); + segment3.close(); + } +} \ No newline at end of file diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/KeyValueSegmentsTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/KeyValueSegmentsTest.java new file mode 100644 index 0000000..c8f1a0e --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/KeyValueSegmentsTest.java @@ -0,0 +1,353 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.common.metrics.Metrics; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.streams.processor.internals.MockStreamsMetrics; +import org.apache.kafka.test.InternalMockProcessorContext; +import org.apache.kafka.test.MockRecordCollector; +import org.apache.kafka.test.TestUtils; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import java.io.File; +import java.text.SimpleDateFormat; +import java.util.Date; +import java.util.List; +import java.util.SimpleTimeZone; + +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.CoreMatchers.nullValue; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; + +public class KeyValueSegmentsTest { + + private static final int NUM_SEGMENTS = 5; + private static final long SEGMENT_INTERVAL = 100L; + private static final long RETENTION_PERIOD = 4 * SEGMENT_INTERVAL; + private static final String METRICS_SCOPE = "test-state-id"; + private InternalMockProcessorContext context; + private KeyValueSegments segments; + private File stateDirectory; + private final String storeName = "test"; + + @Before + public void createContext() { + stateDirectory = TestUtils.tempDirectory(); + context = new InternalMockProcessorContext<>( + stateDirectory, + Serdes.String(), + Serdes.Long(), + new MockRecordCollector(), + new ThreadCache(new LogContext("testCache "), 0, new MockStreamsMetrics(new Metrics())) + ); + segments = new KeyValueSegments(storeName, METRICS_SCOPE, RETENTION_PERIOD, SEGMENT_INTERVAL); + segments.openExisting(context, -1L); + } + + @After + public void close() { + segments.close(); + } + + @Test + public void shouldGetSegmentIdsFromTimestamp() { + assertEquals(0, segments.segmentId(0)); + assertEquals(1, segments.segmentId(SEGMENT_INTERVAL)); + assertEquals(2, segments.segmentId(2 * SEGMENT_INTERVAL)); + assertEquals(3, segments.segmentId(3 * SEGMENT_INTERVAL)); + } + + @Test + public void shouldBaseSegmentIntervalOnRetentionAndNumSegments() { + final KeyValueSegments segments = new KeyValueSegments("test", METRICS_SCOPE, 8 * SEGMENT_INTERVAL, 2 * SEGMENT_INTERVAL); + assertEquals(0, segments.segmentId(0)); + assertEquals(0, segments.segmentId(SEGMENT_INTERVAL)); + assertEquals(1, segments.segmentId(2 * SEGMENT_INTERVAL)); + } + + @Test + public void shouldGetSegmentNameFromId() { + assertEquals("test.0", segments.segmentName(0)); + assertEquals("test." + SEGMENT_INTERVAL, segments.segmentName(1)); + assertEquals("test." + 2 * SEGMENT_INTERVAL, segments.segmentName(2)); + } + + @Test + public void shouldCreateSegments() { + final KeyValueSegment segment1 = segments.getOrCreateSegmentIfLive(0, context, -1L); + final KeyValueSegment segment2 = segments.getOrCreateSegmentIfLive(1, context, -1L); + final KeyValueSegment segment3 = segments.getOrCreateSegmentIfLive(2, context, -1L); + assertTrue(new File(context.stateDir(), "test/test.0").isDirectory()); + assertTrue(new File(context.stateDir(), "test/test." + SEGMENT_INTERVAL).isDirectory()); + assertTrue(new File(context.stateDir(), "test/test." 
+ 2 * SEGMENT_INTERVAL).isDirectory()); + assertTrue(segment1.isOpen()); + assertTrue(segment2.isOpen()); + assertTrue(segment3.isOpen()); + } + + @Test + public void shouldNotCreateSegmentThatIsAlreadyExpired() { + final long streamTime = updateStreamTimeAndCreateSegment(7); + assertNull(segments.getOrCreateSegmentIfLive(0, context, streamTime)); + assertFalse(new File(context.stateDir(), "test/test.0").exists()); + } + + @Test + public void shouldCleanupSegmentsThatHaveExpired() { + final KeyValueSegment segment1 = segments.getOrCreateSegmentIfLive(0, context, -1L); + final KeyValueSegment segment2 = segments.getOrCreateSegmentIfLive(1, context, -1L); + final KeyValueSegment segment3 = segments.getOrCreateSegmentIfLive(7, context, SEGMENT_INTERVAL * 7L); + assertFalse(segment1.isOpen()); + assertFalse(segment2.isOpen()); + assertTrue(segment3.isOpen()); + assertFalse(new File(context.stateDir(), "test/test.0").exists()); + assertFalse(new File(context.stateDir(), "test/test." + SEGMENT_INTERVAL).exists()); + assertTrue(new File(context.stateDir(), "test/test." + 7 * SEGMENT_INTERVAL).exists()); + } + + @Test + public void shouldGetSegmentForTimestamp() { + final KeyValueSegment segment = segments.getOrCreateSegmentIfLive(0, context, -1L); + segments.getOrCreateSegmentIfLive(1, context, -1L); + assertEquals(segment, segments.getSegmentForTimestamp(0L)); + } + + @Test + public void shouldGetCorrectSegmentString() { + final KeyValueSegment segment = segments.getOrCreateSegmentIfLive(0, context, -1L); + assertEquals("KeyValueSegment(id=0, name=test.0)", segment.toString()); + } + + @Test + public void shouldCloseAllOpenSegments() { + final KeyValueSegment first = segments.getOrCreateSegmentIfLive(0, context, -1L); + final KeyValueSegment second = segments.getOrCreateSegmentIfLive(1, context, -1L); + final KeyValueSegment third = segments.getOrCreateSegmentIfLive(2, context, -1L); + segments.close(); + + assertFalse(first.isOpen()); + assertFalse(second.isOpen()); + assertFalse(third.isOpen()); + } + + @Test + public void shouldOpenExistingSegments() { + segments = new KeyValueSegments("test", METRICS_SCOPE, 4, 1); + segments.openExisting(context, -1L); + segments.getOrCreateSegmentIfLive(0, context, -1L); + segments.getOrCreateSegmentIfLive(1, context, -1L); + segments.getOrCreateSegmentIfLive(2, context, -1L); + segments.getOrCreateSegmentIfLive(3, context, -1L); + segments.getOrCreateSegmentIfLive(4, context, -1L); + // close existing. 
+ segments.close(); + + segments = new KeyValueSegments("test", METRICS_SCOPE, 4, 1); + segments.openExisting(context, -1L); + + assertTrue(segments.getSegmentForTimestamp(0).isOpen()); + assertTrue(segments.getSegmentForTimestamp(1).isOpen()); + assertTrue(segments.getSegmentForTimestamp(2).isOpen()); + assertTrue(segments.getSegmentForTimestamp(3).isOpen()); + assertTrue(segments.getSegmentForTimestamp(4).isOpen()); + } + + @Test + public void shouldGetSegmentsWithinTimeRange() { + updateStreamTimeAndCreateSegment(0); + updateStreamTimeAndCreateSegment(1); + updateStreamTimeAndCreateSegment(2); + updateStreamTimeAndCreateSegment(3); + final long streamTime = updateStreamTimeAndCreateSegment(4); + segments.getOrCreateSegmentIfLive(0, context, streamTime); + segments.getOrCreateSegmentIfLive(1, context, streamTime); + segments.getOrCreateSegmentIfLive(2, context, streamTime); + segments.getOrCreateSegmentIfLive(3, context, streamTime); + segments.getOrCreateSegmentIfLive(4, context, streamTime); + + final List segments = this.segments.segments(0, 2 * SEGMENT_INTERVAL, true); + assertEquals(3, segments.size()); + assertEquals(0, segments.get(0).id); + assertEquals(1, segments.get(1).id); + assertEquals(2, segments.get(2).id); + } + + @Test + public void shouldGetSegmentsWithinBackwardTimeRange() { + updateStreamTimeAndCreateSegment(0); + updateStreamTimeAndCreateSegment(1); + updateStreamTimeAndCreateSegment(2); + updateStreamTimeAndCreateSegment(3); + final long streamTime = updateStreamTimeAndCreateSegment(4); + segments.getOrCreateSegmentIfLive(0, context, streamTime); + segments.getOrCreateSegmentIfLive(1, context, streamTime); + segments.getOrCreateSegmentIfLive(2, context, streamTime); + segments.getOrCreateSegmentIfLive(3, context, streamTime); + segments.getOrCreateSegmentIfLive(4, context, streamTime); + + final List segments = this.segments.segments(0, 2 * SEGMENT_INTERVAL, false); + assertEquals(3, segments.size()); + assertEquals(0, segments.get(2).id); + assertEquals(1, segments.get(1).id); + assertEquals(2, segments.get(0).id); + } + + @Test + public void shouldGetSegmentsWithinTimeRangeOutOfOrder() { + updateStreamTimeAndCreateSegment(4); + updateStreamTimeAndCreateSegment(2); + updateStreamTimeAndCreateSegment(0); + updateStreamTimeAndCreateSegment(1); + updateStreamTimeAndCreateSegment(3); + + final List segments = this.segments.segments(0, 2 * SEGMENT_INTERVAL, true); + assertEquals(3, segments.size()); + assertEquals(0, segments.get(0).id); + assertEquals(1, segments.get(1).id); + assertEquals(2, segments.get(2).id); + } + + @Test + public void shouldGetSegmentsWithinTimeBackwardRangeOutOfOrder() { + updateStreamTimeAndCreateSegment(4); + updateStreamTimeAndCreateSegment(2); + updateStreamTimeAndCreateSegment(0); + updateStreamTimeAndCreateSegment(1); + updateStreamTimeAndCreateSegment(3); + + final List segments = this.segments.segments(0, 2 * SEGMENT_INTERVAL, false); + assertEquals(3, segments.size()); + assertEquals(2, segments.get(0).id); + assertEquals(1, segments.get(1).id); + assertEquals(0, segments.get(2).id); + } + + @Test + public void shouldRollSegments() { + updateStreamTimeAndCreateSegment(0); + verifyCorrectSegments(0, 1); + updateStreamTimeAndCreateSegment(1); + verifyCorrectSegments(0, 2); + updateStreamTimeAndCreateSegment(2); + verifyCorrectSegments(0, 3); + updateStreamTimeAndCreateSegment(3); + verifyCorrectSegments(0, 4); + updateStreamTimeAndCreateSegment(4); + verifyCorrectSegments(0, 5); + updateStreamTimeAndCreateSegment(5); + 
verifyCorrectSegments(1, 5); + updateStreamTimeAndCreateSegment(6); + verifyCorrectSegments(2, 5); + } + + @Test + public void futureEventsShouldNotCauseSegmentRoll() { + updateStreamTimeAndCreateSegment(0); + verifyCorrectSegments(0, 1); + updateStreamTimeAndCreateSegment(1); + verifyCorrectSegments(0, 2); + updateStreamTimeAndCreateSegment(2); + verifyCorrectSegments(0, 3); + updateStreamTimeAndCreateSegment(3); + verifyCorrectSegments(0, 4); + final long streamTime = updateStreamTimeAndCreateSegment(4); + verifyCorrectSegments(0, 5); + segments.getOrCreateSegmentIfLive(5, context, streamTime); + verifyCorrectSegments(0, 6); + segments.getOrCreateSegmentIfLive(6, context, streamTime); + verifyCorrectSegments(0, 7); + } + + private long updateStreamTimeAndCreateSegment(final int segment) { + final long streamTime = SEGMENT_INTERVAL * segment; + segments.getOrCreateSegmentIfLive(segment, context, streamTime); + return streamTime; + } + + @Test + public void shouldUpdateSegmentFileNameFromOldDateFormatToNewFormat() throws Exception { + final long segmentInterval = 60_000L; // the old segment file's naming system maxes out at 1 minute granularity. + + segments = new KeyValueSegments(storeName, METRICS_SCOPE, NUM_SEGMENTS * segmentInterval, segmentInterval); + + final String storeDirectoryPath = stateDirectory.getAbsolutePath() + File.separator + storeName; + final File storeDirectory = new File(storeDirectoryPath); + //noinspection ResultOfMethodCallIgnored + storeDirectory.mkdirs(); + + final SimpleDateFormat formatter = new SimpleDateFormat("yyyyMMddHHmm"); + formatter.setTimeZone(new SimpleTimeZone(0, "UTC")); + + for (int segmentId = 0; segmentId < NUM_SEGMENTS; ++segmentId) { + final File oldSegment = new File(storeDirectoryPath + File.separator + storeName + "-" + formatter.format(new Date(segmentId * segmentInterval))); + //noinspection ResultOfMethodCallIgnored + oldSegment.createNewFile(); + } + + segments.openExisting(context, -1L); + + for (int segmentId = 0; segmentId < NUM_SEGMENTS; ++segmentId) { + final String segmentName = storeName + "." + (long) segmentId * segmentInterval; + final File newSegment = new File(storeDirectoryPath + File.separator + segmentName); + assertTrue(newSegment.exists()); + } + } + + @Test + public void shouldUpdateSegmentFileNameFromOldColonFormatToNewFormat() throws Exception { + final String storeDirectoryPath = stateDirectory.getAbsolutePath() + File.separator + storeName; + final File storeDirectory = new File(storeDirectoryPath); + //noinspection ResultOfMethodCallIgnored + storeDirectory.mkdirs(); + + for (int segmentId = 0; segmentId < NUM_SEGMENTS; ++segmentId) { + final File oldSegment = new File(storeDirectoryPath + File.separator + storeName + ":" + segmentId * (RETENTION_PERIOD / (NUM_SEGMENTS - 1))); + //noinspection ResultOfMethodCallIgnored + oldSegment.createNewFile(); + } + + segments.openExisting(context, -1L); + + for (int segmentId = 0; segmentId < NUM_SEGMENTS; ++segmentId) { + final File newSegment = new File(storeDirectoryPath + File.separator + storeName + "." 
+ segmentId * (RETENTION_PERIOD / (NUM_SEGMENTS - 1))); + assertTrue(newSegment.exists()); + } + } + + @Test + public void shouldClearSegmentsOnClose() { + segments.getOrCreateSegmentIfLive(0, context, -1L); + segments.close(); + assertThat(segments.getSegmentForTimestamp(0), is(nullValue())); + } + + private void verifyCorrectSegments(final long first, final int numSegments) { + final List result = this.segments.segments(0, Long.MAX_VALUE, true); + assertEquals(numSegments, result.size()); + for (int i = 0; i < numSegments; i++) { + assertEquals(i + first, result.get(i).id); + } + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/KeyValueStoreBuilderTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/KeyValueStoreBuilderTest.java new file mode 100644 index 0000000..465b734 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/KeyValueStoreBuilderTest.java @@ -0,0 +1,156 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.streams.processor.StateStore; +import org.apache.kafka.streams.state.KeyValueBytesStoreSupplier; +import org.apache.kafka.streams.state.KeyValueStore; +import org.easymock.EasyMock; +import org.easymock.EasyMockRunner; +import org.easymock.Mock; +import org.easymock.MockType; +import org.hamcrest.CoreMatchers; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; + +import java.util.Collections; + +import static org.easymock.EasyMock.expect; +import static org.easymock.EasyMock.replay; +import static org.easymock.EasyMock.reset; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.core.IsInstanceOf.instanceOf; +import static org.junit.Assert.assertThrows; + +@RunWith(EasyMockRunner.class) +public class KeyValueStoreBuilderTest { + + @Mock(type = MockType.NICE) + private KeyValueBytesStoreSupplier supplier; + @Mock(type = MockType.NICE) + private KeyValueStore inner; + private KeyValueStoreBuilder builder; + + @Before + public void setUp() { + EasyMock.expect(supplier.get()).andReturn(inner); + EasyMock.expect(supplier.name()).andReturn("name"); + expect(supplier.metricsScope()).andReturn("metricScope"); + EasyMock.replay(supplier); + builder = new KeyValueStoreBuilder<>( + supplier, + Serdes.String(), + Serdes.String(), + new MockTime() + ); + } + + @Test + public void shouldHaveMeteredStoreAsOuterStore() { + final KeyValueStore store = builder.build(); + assertThat(store, instanceOf(MeteredKeyValueStore.class)); + } 
+ + @Test + public void shouldHaveChangeLoggingStoreByDefault() { + final KeyValueStore store = builder.build(); + assertThat(store, instanceOf(MeteredKeyValueStore.class)); + final StateStore next = ((WrappedStateStore) store).wrapped(); + assertThat(next, instanceOf(ChangeLoggingKeyValueBytesStore.class)); + } + + @Test + public void shouldNotHaveChangeLoggingStoreWhenDisabled() { + final KeyValueStore store = builder.withLoggingDisabled().build(); + final StateStore next = ((WrappedStateStore) store).wrapped(); + assertThat(next, CoreMatchers.equalTo(inner)); + } + + @Test + public void shouldHaveCachingStoreWhenEnabled() { + final KeyValueStore store = builder.withCachingEnabled().build(); + final StateStore wrapped = ((WrappedStateStore) store).wrapped(); + assertThat(store, instanceOf(MeteredKeyValueStore.class)); + assertThat(wrapped, instanceOf(CachingKeyValueStore.class)); + } + + @Test + public void shouldHaveChangeLoggingStoreWhenLoggingEnabled() { + final KeyValueStore store = builder + .withLoggingEnabled(Collections.emptyMap()) + .build(); + final StateStore wrapped = ((WrappedStateStore) store).wrapped(); + assertThat(store, instanceOf(MeteredKeyValueStore.class)); + assertThat(wrapped, instanceOf(ChangeLoggingKeyValueBytesStore.class)); + assertThat(((WrappedStateStore) wrapped).wrapped(), CoreMatchers.equalTo(inner)); + } + + @Test + public void shouldHaveCachingAndChangeLoggingWhenBothEnabled() { + final KeyValueStore store = builder + .withLoggingEnabled(Collections.emptyMap()) + .withCachingEnabled() + .build(); + final WrappedStateStore caching = (WrappedStateStore) ((WrappedStateStore) store).wrapped(); + final WrappedStateStore changeLogging = (WrappedStateStore) caching.wrapped(); + assertThat(store, instanceOf(MeteredKeyValueStore.class)); + assertThat(caching, instanceOf(CachingKeyValueStore.class)); + assertThat(changeLogging, instanceOf(ChangeLoggingKeyValueBytesStore.class)); + assertThat(changeLogging.wrapped(), CoreMatchers.equalTo(inner)); + } + + @SuppressWarnings("all") + @Test + public void shouldThrowNullPointerIfInnerIsNull() { + assertThrows(NullPointerException.class, () -> new KeyValueStoreBuilder<>(null, Serdes.String(), + Serdes.String(), new MockTime())); + } + + @Test + public void shouldThrowNullPointerIfKeySerdeIsNull() { + assertThrows(NullPointerException.class, () -> new KeyValueStoreBuilder<>(supplier, null, Serdes.String(), new MockTime())); + } + + @Test + public void shouldThrowNullPointerIfValueSerdeIsNull() { + assertThrows(NullPointerException.class, () -> new KeyValueStoreBuilder<>(supplier, Serdes.String(), null, new MockTime())); + } + + @Test + public void shouldThrowNullPointerIfTimeIsNull() { + assertThrows(NullPointerException.class, () -> new KeyValueStoreBuilder<>(supplier, Serdes.String(), Serdes.String(), null)); + } + + @Test + public void shouldThrowNullPointerIfMetricsScopeIsNull() { + reset(supplier); + expect(supplier.get()).andReturn(new RocksDBStore("name", null)); + expect(supplier.name()).andReturn("name"); + replay(supplier); + + final Exception e = assertThrows(NullPointerException.class, + () -> new KeyValueStoreBuilder<>(supplier, Serdes.String(), Serdes.String(), new MockTime())); + assertThat(e.getMessage(), equalTo("storeSupplier's metricsScope can't be null")); + } + +} \ No newline at end of file diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/LeftOrRightValueSerializerTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/LeftOrRightValueSerializerTest.java 
new file mode 100644 index 0000000..f1ebda6 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/LeftOrRightValueSerializerTest.java @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.common.serialization.Serdes; +import org.junit.Test; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.notNullValue; +import static org.junit.Assert.assertThrows; + +public class LeftOrRightValueSerializerTest { + private static final String TOPIC = "some-topic"; + + private static final LeftOrRightValueSerde STRING_OR_INTEGER_SERDE = + new LeftOrRightValueSerde<>(Serdes.String(), Serdes.Integer()); + + @Test + public void shouldSerializeStringValue() { + final String value = "some-string"; + + final LeftOrRightValue leftOrRightValue = LeftOrRightValue.makeLeftValue(value); + + final byte[] serialized = + STRING_OR_INTEGER_SERDE.serializer().serialize(TOPIC, leftOrRightValue); + + assertThat(serialized, is(notNullValue())); + + final LeftOrRightValue deserialized = + STRING_OR_INTEGER_SERDE.deserializer().deserialize(TOPIC, serialized); + + assertThat(deserialized, is(leftOrRightValue)); + } + + @Test + public void shouldSerializeIntegerValue() { + final int value = 5; + + final LeftOrRightValue leftOrRightValue = LeftOrRightValue.makeRightValue(value); + + final byte[] serialized = + STRING_OR_INTEGER_SERDE.serializer().serialize(TOPIC, leftOrRightValue); + + assertThat(serialized, is(notNullValue())); + + final LeftOrRightValue deserialized = + STRING_OR_INTEGER_SERDE.deserializer().deserialize(TOPIC, serialized); + + assertThat(deserialized, is(leftOrRightValue)); + } + + @Test + public void shouldThrowIfSerializeValueAsNull() { + assertThrows(NullPointerException.class, + () -> STRING_OR_INTEGER_SERDE.serializer().serialize(TOPIC, LeftOrRightValue.makeLeftValue(null))); + } + + @Test + public void shouldThrowIfSerializeOtherValueAsNull() { + assertThrows(NullPointerException.class, + () -> STRING_OR_INTEGER_SERDE.serializer().serialize(TOPIC, LeftOrRightValue.makeRightValue(null))); + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/ListValueStoreTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/ListValueStoreTest.java new file mode 100644 index 0000000..220eb9f --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/ListValueStoreTest.java @@ -0,0 +1,216 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.common.metrics.Metrics; +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.errors.InvalidStateStoreException; +import org.apache.kafka.streams.processor.StateStoreContext; +import org.apache.kafka.streams.processor.internals.MockStreamsMetrics; +import org.apache.kafka.streams.state.KeyValueIterator; +import org.apache.kafka.streams.state.KeyValueStore; +import org.apache.kafka.streams.state.Stores; +import org.apache.kafka.test.InternalMockProcessorContext; +import org.apache.kafka.test.MockRecordCollector; +import org.apache.kafka.test.TestUtils; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import java.io.File; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.List; + +import static java.util.Arrays.asList; +import static org.apache.kafka.test.StreamsTestUtils.toList; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; + +@RunWith(Parameterized.class) +public class ListValueStoreTest { + private enum StoreType { InMemory, RocksDB } + + private final StoreType storeType; + private KeyValueStore listStore; + + final File baseDir = TestUtils.tempDirectory("test"); + + public ListValueStoreTest(final StoreType type) { + this.storeType = type; + } + + @Parameterized.Parameters(name = "store type = {0}") + public static Collection data() { + final List values = new ArrayList<>(); + for (final StoreType type : Arrays.asList(StoreType.InMemory, StoreType.RocksDB)) { + values.add(new Object[]{type}); + } + return values; + } + + @Before + public void setup() { + listStore = buildStore(Serdes.Integer(), Serdes.String()); + + final MockRecordCollector recordCollector = new MockRecordCollector(); + final InternalMockProcessorContext context = new InternalMockProcessorContext<>( + baseDir, + Serdes.String(), + Serdes.Integer(), + recordCollector, + new ThreadCache( + new LogContext("testCache"), + 0, + new MockStreamsMetrics(new Metrics()))); + context.setTime(1L); + + listStore.init((StateStoreContext) context, listStore); + } + + @After + public void after() { + listStore.close(); + } + + KeyValueStore buildStore(final Serde keySerde, + final Serde valueSerde) { + return new ListValueStoreBuilder<>( + storeType == StoreType.RocksDB ? 
Stores.persistentKeyValueStore("rocksDB list store") + : Stores.inMemoryKeyValueStore("in-memory list store"), + keySerde, + valueSerde, + Time.SYSTEM) + .build(); + } + + @Test + public void shouldGetAll() { + listStore.put(0, "zero"); + // should retain duplicates + listStore.put(0, "zero again"); + listStore.put(1, "one"); + listStore.put(2, "two"); + + final KeyValue zero = KeyValue.pair(0, "zero"); + final KeyValue zeroAgain = KeyValue.pair(0, "zero again"); + final KeyValue one = KeyValue.pair(1, "one"); + final KeyValue two = KeyValue.pair(2, "two"); + + assertEquals( + asList(zero, zeroAgain, one, two), + toList(listStore.all()) + ); + } + + @Test + public void shouldGetAllNonDeletedRecords() { + // Add some records + listStore.put(0, "zero"); + listStore.put(1, "one"); + listStore.put(1, "one again"); + listStore.put(2, "two"); + listStore.put(3, "three"); + listStore.put(4, "four"); + + // Delete some records + listStore.put(1, null); + listStore.put(3, null); + + // Only non-deleted records should appear in the all() iterator + final KeyValue zero = KeyValue.pair(0, "zero"); + final KeyValue two = KeyValue.pair(2, "two"); + final KeyValue four = KeyValue.pair(4, "four"); + + assertEquals( + asList(zero, two, four), + toList(listStore.all()) + ); + } + + @Test + public void shouldGetAllReturnTimestampOrderedRecords() { + // Add some records in different order + listStore.put(4, "four"); + listStore.put(0, "zero"); + listStore.put(2, "two1"); + listStore.put(3, "three"); + listStore.put(1, "one"); + + // Add duplicates + listStore.put(2, "two2"); + + // Only non-deleted records should appear in the all() iterator + final KeyValue zero = KeyValue.pair(0, "zero"); + final KeyValue one = KeyValue.pair(1, "one"); + final KeyValue two1 = KeyValue.pair(2, "two1"); + final KeyValue two2 = KeyValue.pair(2, "two2"); + final KeyValue three = KeyValue.pair(3, "three"); + final KeyValue four = KeyValue.pair(4, "four"); + + assertEquals( + asList(zero, one, two1, two2, three, four), + toList(listStore.all()) + ); + } + + @Test + public void shouldAllowDeleteWhileIterateRecords() { + listStore.put(0, "zero1"); + listStore.put(0, "zero2"); + listStore.put(1, "one"); + + final KeyValue zero1 = KeyValue.pair(0, "zero1"); + final KeyValue zero2 = KeyValue.pair(0, "zero2"); + final KeyValue one = KeyValue.pair(1, "one"); + + final KeyValueIterator it = listStore.all(); + assertEquals(zero1, it.next()); + + listStore.put(0, null); + + // zero2 should still be returned from the iterator after the delete call + assertEquals(zero2, it.next()); + + it.close(); + + // A new all() iterator after a previous all() iterator was closed should not return deleted records. + assertEquals(Collections.singletonList(one), toList(listStore.all())); + } + + @Test + public void shouldNotReturnMoreDataWhenIteratorClosed() { + listStore.put(0, "zero1"); + listStore.put(0, "zero2"); + listStore.put(1, "one"); + + final KeyValueIterator it = listStore.all(); + + it.close(); + + // A new all() iterator after a previous all() iterator was closed should not return deleted records. 
+ assertThrows(InvalidStateStoreException.class, it::next); + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/MaybeTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/MaybeTest.java new file mode 100644 index 0000000..7a29e13 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/MaybeTest.java @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.state.internals; + +import org.junit.Test; + +import java.util.NoSuchElementException; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.nullValue; +import static org.junit.Assert.fail; + +public class MaybeTest { + @Test + public void shouldReturnDefinedValue() { + assertThat(Maybe.defined(null).getNullableValue(), nullValue()); + assertThat(Maybe.defined("ASDF").getNullableValue(), is("ASDF")); + } + + @Test + public void shouldAnswerIsDefined() { + assertThat(Maybe.defined(null).isDefined(), is(true)); + assertThat(Maybe.defined("ASDF").isDefined(), is(true)); + assertThat(Maybe.undefined().isDefined(), is(false)); + } + + @Test + public void shouldThrowOnGetUndefinedValue() { + final Maybe undefined = Maybe.undefined(); + try { + undefined.getNullableValue(); + fail(); + } catch (final NoSuchElementException e) { + // no assertion necessary + } + } + + @Test + public void shouldUpholdEqualityCorrectness() { + assertThat(Maybe.undefined().equals(Maybe.undefined()), is(true)); + assertThat(Maybe.defined(null).equals(Maybe.defined(null)), is(true)); + assertThat(Maybe.defined("q").equals(Maybe.defined("q")), is(true)); + + assertThat(Maybe.undefined().equals(Maybe.defined(null)), is(false)); + assertThat(Maybe.undefined().equals(Maybe.defined("x")), is(false)); + + assertThat(Maybe.defined(null).equals(Maybe.undefined()), is(false)); + assertThat(Maybe.defined(null).equals(Maybe.defined("x")), is(false)); + + assertThat(Maybe.defined("a").equals(Maybe.undefined()), is(false)); + assertThat(Maybe.defined("a").equals(Maybe.defined(null)), is(false)); + assertThat(Maybe.defined("a").equals(Maybe.defined("b")), is(false)); + } + + @Test + public void shouldUpholdHashCodeCorrectness() { + // This specifies the current implementation, which is simpler to write than an exhaustive test. + // As long as this implementation doesn't change, then the equals/hashcode contract is upheld. 
+ + assertThat(Maybe.undefined().hashCode(), is(-1)); + assertThat(Maybe.defined(null).hashCode(), is(0)); + assertThat(Maybe.defined("a").hashCode(), is("a".hashCode())); + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/MergedSortedCacheKeyValueBytesStoreIteratorTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/MergedSortedCacheKeyValueBytesStoreIteratorTest.java new file mode 100644 index 0000000..1716ac1 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/MergedSortedCacheKeyValueBytesStoreIteratorTest.java @@ -0,0 +1,238 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.common.metrics.Metrics; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.streams.processor.internals.MockStreamsMetrics; +import org.apache.kafka.streams.state.KeyValueIterator; +import org.apache.kafka.streams.state.KeyValueStore; +import org.junit.Before; +import org.junit.Test; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertFalse; + +public class MergedSortedCacheKeyValueBytesStoreIteratorTest { + + private final String namespace = "0.0-one"; + private KeyValueStore store; + private ThreadCache cache; + + @Before + public void setUp() { + store = new InMemoryKeyValueStore(namespace); + cache = new ThreadCache(new LogContext("testCache "), 10000L, new MockStreamsMetrics(new Metrics())); + } + @Test + public void shouldIterateOverRange() { + final byte[][] bytes = {{0}, {1}, {2}, {3}, {4}, {5}, {6}, {7}, {8}, {9}, {10}, {11}}; + for (int i = 0; i < bytes.length; i += 2) { + store.put(Bytes.wrap(bytes[i]), bytes[i]); + cache.put(namespace, Bytes.wrap(bytes[i + 1]), new LRUCacheEntry(bytes[i + 1])); + } + + final Bytes from = Bytes.wrap(new byte[] {2}); + final Bytes to = Bytes.wrap(new byte[] {9}); + final KeyValueIterator storeIterator = + new DelegatingPeekingKeyValueIterator<>("store", store.range(from, to)); + final ThreadCache.MemoryLRUCacheBytesIterator cacheIterator = cache.range(namespace, from, to); + + final MergedSortedCacheKeyValueBytesStoreIterator iterator = + new MergedSortedCacheKeyValueBytesStoreIterator(cacheIterator, storeIterator, true); + final byte[][] values = new byte[8][]; + int index = 0; + int bytesIndex = 2; + while (iterator.hasNext()) { + final byte[] value = iterator.next().value; + values[index++] = value; + assertArrayEquals(bytes[bytesIndex++], value); + } + iterator.close(); + } + + + @Test + public void shouldReverseIterateOverRange() { + final byte[][] bytes = {{0}, {1}, {2}, {3}, {4}, {5}, {6}, {7}, {8}, {9}, {10}, {11}}; + for (int i = 0; i < 
bytes.length; i += 2) { + store.put(Bytes.wrap(bytes[i]), bytes[i]); + cache.put(namespace, Bytes.wrap(bytes[i + 1]), new LRUCacheEntry(bytes[i + 1])); + } + + final Bytes from = Bytes.wrap(new byte[] {2}); + final Bytes to = Bytes.wrap(new byte[] {9}); + final KeyValueIterator storeIterator = + new DelegatingPeekingKeyValueIterator<>("store", store.reverseRange(from, to)); + final ThreadCache.MemoryLRUCacheBytesIterator cacheIterator = cache.reverseRange(namespace, from, to); + + final MergedSortedCacheKeyValueBytesStoreIterator iterator = + new MergedSortedCacheKeyValueBytesStoreIterator(cacheIterator, storeIterator, false); + final byte[][] values = new byte[8][]; + int index = 0; + int bytesIndex = 9; + while (iterator.hasNext()) { + final byte[] value = iterator.next().value; + values[index++] = value; + assertArrayEquals(bytes[bytesIndex--], value); + } + iterator.close(); + } + + @Test + public void shouldSkipLargerDeletedCacheValue() { + final byte[][] bytes = {{0}, {1}}; + store.put(Bytes.wrap(bytes[0]), bytes[0]); + cache.put(namespace, Bytes.wrap(bytes[1]), new LRUCacheEntry(null)); + try (final MergedSortedCacheKeyValueBytesStoreIterator iterator = createIterator()) { + assertArrayEquals(bytes[0], iterator.next().key.get()); + assertFalse(iterator.hasNext()); + } + } + + @Test + public void shouldSkipSmallerDeletedCachedValue() { + final byte[][] bytes = {{0}, {1}}; + cache.put(namespace, Bytes.wrap(bytes[0]), new LRUCacheEntry(null)); + store.put(Bytes.wrap(bytes[1]), bytes[1]); + try (final MergedSortedCacheKeyValueBytesStoreIterator iterator = createIterator()) { + assertArrayEquals(bytes[1], iterator.next().key.get()); + assertFalse(iterator.hasNext()); + } + } + + @Test + public void shouldIgnoreIfDeletedInCacheButExistsInStore() { + final byte[][] bytes = {{0}}; + cache.put(namespace, Bytes.wrap(bytes[0]), new LRUCacheEntry(null)); + store.put(Bytes.wrap(bytes[0]), bytes[0]); + try (final MergedSortedCacheKeyValueBytesStoreIterator iterator = createIterator()) { + assertFalse(iterator.hasNext()); + } + } + + @Test + public void shouldNotHaveNextIfAllCachedItemsDeleted() { + final byte[][] bytes = {{0}, {1}, {2}}; + for (final byte[] aByte : bytes) { + final Bytes aBytes = Bytes.wrap(aByte); + store.put(aBytes, aByte); + cache.put(namespace, aBytes, new LRUCacheEntry(null)); + } + assertFalse(createIterator().hasNext()); + } + + @Test + public void shouldNotHaveNextIfOnlyCacheItemsAndAllDeleted() { + final byte[][] bytes = {{0}, {1}, {2}}; + for (final byte[] aByte : bytes) { + cache.put(namespace, Bytes.wrap(aByte), new LRUCacheEntry(null)); + } + assertFalse(createIterator().hasNext()); + } + + @Test + public void shouldSkipAllDeletedFromCache() { + final byte[][] bytes = {{0}, {1}, {2}, {3}, {4}, {5}, {6}, {7}, {8}, {9}, {10}, {11}}; + for (final byte[] aByte : bytes) { + final Bytes aBytes = Bytes.wrap(aByte); + store.put(aBytes, aByte); + cache.put(namespace, aBytes, new LRUCacheEntry(aByte)); + } + cache.put(namespace, Bytes.wrap(bytes[1]), new LRUCacheEntry(null)); + cache.put(namespace, Bytes.wrap(bytes[2]), new LRUCacheEntry(null)); + cache.put(namespace, Bytes.wrap(bytes[3]), new LRUCacheEntry(null)); + cache.put(namespace, Bytes.wrap(bytes[8]), new LRUCacheEntry(null)); + cache.put(namespace, Bytes.wrap(bytes[11]), new LRUCacheEntry(null)); + + try (final MergedSortedCacheKeyValueBytesStoreIterator iterator = createIterator()) { + assertArrayEquals(bytes[0], iterator.next().key.get()); + assertArrayEquals(bytes[4], iterator.next().key.get()); + 
assertArrayEquals(bytes[5], iterator.next().key.get()); + assertArrayEquals(bytes[6], iterator.next().key.get()); + assertArrayEquals(bytes[7], iterator.next().key.get()); + assertArrayEquals(bytes[9], iterator.next().key.get()); + assertArrayEquals(bytes[10], iterator.next().key.get()); + assertFalse(iterator.hasNext()); + } + } + + @Test + public void shouldPeekNextKey() { + final KeyValueStore kv = new InMemoryKeyValueStore("one"); + final ThreadCache cache = new ThreadCache(new LogContext("testCache "), 1000000L, new MockStreamsMetrics(new Metrics())); + final byte[][] bytes = {{0}, {1}, {2}, {3}, {4}, {5}, {6}, {7}, {8}, {9}, {10}}; + for (int i = 0; i < bytes.length - 1; i += 2) { + kv.put(Bytes.wrap(bytes[i]), bytes[i]); + cache.put(namespace, Bytes.wrap(bytes[i + 1]), new LRUCacheEntry(bytes[i + 1])); + } + + final Bytes from = Bytes.wrap(new byte[] {2}); + final Bytes to = Bytes.wrap(new byte[] {9}); + final KeyValueIterator storeIterator = kv.range(from, to); + final ThreadCache.MemoryLRUCacheBytesIterator cacheIterator = cache.range(namespace, from, to); + + final MergedSortedCacheKeyValueBytesStoreIterator iterator = + new MergedSortedCacheKeyValueBytesStoreIterator(cacheIterator, storeIterator, true); + final byte[][] values = new byte[8][]; + int index = 0; + int bytesIndex = 2; + while (iterator.hasNext()) { + final byte[] keys = iterator.peekNextKey().get(); + values[index++] = keys; + assertArrayEquals(bytes[bytesIndex++], keys); + iterator.next(); + } + iterator.close(); + } + + @Test + public void shouldPeekNextKeyReverse() { + final KeyValueStore kv = new InMemoryKeyValueStore("one"); + final ThreadCache cache = new ThreadCache(new LogContext("testCache "), 1000000L, new MockStreamsMetrics(new Metrics())); + final byte[][] bytes = {{0}, {1}, {2}, {3}, {4}, {5}, {6}, {7}, {8}, {9}, {10}}; + for (int i = 0; i < bytes.length - 1; i += 2) { + kv.put(Bytes.wrap(bytes[i]), bytes[i]); + cache.put(namespace, Bytes.wrap(bytes[i + 1]), new LRUCacheEntry(bytes[i + 1])); + } + + final Bytes from = Bytes.wrap(new byte[] {2}); + final Bytes to = Bytes.wrap(new byte[] {9}); + final KeyValueIterator storeIterator = kv.reverseRange(from, to); + final ThreadCache.MemoryLRUCacheBytesIterator cacheIterator = cache.reverseRange(namespace, from, to); + + final MergedSortedCacheKeyValueBytesStoreIterator iterator = + new MergedSortedCacheKeyValueBytesStoreIterator(cacheIterator, storeIterator, false); + final byte[][] values = new byte[8][]; + int index = 0; + int bytesIndex = 9; + while (iterator.hasNext()) { + final byte[] keys = iterator.peekNextKey().get(); + values[index++] = keys; + assertArrayEquals(bytes[bytesIndex--], keys); + iterator.next(); + } + iterator.close(); + } + + private MergedSortedCacheKeyValueBytesStoreIterator createIterator() { + final ThreadCache.MemoryLRUCacheBytesIterator cacheIterator = cache.all(namespace); + final KeyValueIterator storeIterator = new DelegatingPeekingKeyValueIterator<>("store", store.all()); + return new MergedSortedCacheKeyValueBytesStoreIterator(cacheIterator, storeIterator, true); + } +} \ No newline at end of file diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/MergedSortedCacheWrappedSessionStoreIteratorTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/MergedSortedCacheWrappedSessionStoreIteratorTest.java new file mode 100644 index 0000000..4bd125a --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/MergedSortedCacheWrappedSessionStoreIteratorTest.java @@ 
-0,0 +1,155 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.kstream.Windowed; +import org.apache.kafka.streams.kstream.internals.SessionWindow; +import org.apache.kafka.test.KeyValueIteratorStub; +import org.junit.Test; + +import java.util.Collections; +import java.util.Iterator; + +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +public class MergedSortedCacheWrappedSessionStoreIteratorTest { + + private static final SegmentedCacheFunction SINGLE_SEGMENT_CACHE_FUNCTION = new SegmentedCacheFunction(null, -1) { + @Override + public long segmentId(final Bytes key) { + return 0; + } + }; + + private final Bytes storeKey = Bytes.wrap("a".getBytes()); + private final Bytes cacheKey = Bytes.wrap("b".getBytes()); + + private final SessionWindow storeWindow = new SessionWindow(0, 1); + private final Iterator, byte[]>> storeKvs = Collections.singleton( + KeyValue.pair(new Windowed<>(storeKey, storeWindow), storeKey.get())).iterator(); + private final SessionWindow cacheWindow = new SessionWindow(10, 20); + private final Iterator> cacheKvs = Collections.singleton( + KeyValue.pair( + SINGLE_SEGMENT_CACHE_FUNCTION.cacheKey(SessionKeySchema.toBinary(new Windowed<>(cacheKey, cacheWindow))), + new LRUCacheEntry(cacheKey.get()) + )).iterator(); + + @Test + public void shouldHaveNextFromStore() { + final MergedSortedCacheSessionStoreIterator mergeIterator = createIterator(storeKvs, Collections.emptyIterator(), false); + assertTrue(mergeIterator.hasNext()); + } + + @Test + public void shouldHaveNextFromReverseStore() { + final MergedSortedCacheSessionStoreIterator mergeIterator = createIterator(storeKvs, Collections.emptyIterator(), true); + assertTrue(mergeIterator.hasNext()); + } + + @Test + public void shouldGetNextFromStore() { + final MergedSortedCacheSessionStoreIterator mergeIterator = createIterator(storeKvs, Collections.emptyIterator(), false); + assertThat(mergeIterator.next(), equalTo(KeyValue.pair(new Windowed<>(storeKey, storeWindow), storeKey.get()))); + } + + @Test + public void shouldGetNextFromReverseStore() { + final MergedSortedCacheSessionStoreIterator mergeIterator = createIterator(storeKvs, Collections.emptyIterator(), true); + assertThat(mergeIterator.next(), equalTo(KeyValue.pair(new Windowed<>(storeKey, storeWindow), storeKey.get()))); + } + + @Test + public void shouldPeekNextKeyFromStore() { + final MergedSortedCacheSessionStoreIterator mergeIterator = createIterator(storeKvs, Collections.emptyIterator(), false); + 
assertThat(mergeIterator.peekNextKey(), equalTo(new Windowed<>(storeKey, storeWindow))); + } + + @Test + public void shouldPeekNextKeyFromReverseStore() { + final MergedSortedCacheSessionStoreIterator mergeIterator = createIterator(storeKvs, Collections.emptyIterator(), true); + assertThat(mergeIterator.peekNextKey(), equalTo(new Windowed<>(storeKey, storeWindow))); + } + + @Test + public void shouldHaveNextFromCache() { + final MergedSortedCacheSessionStoreIterator mergeIterator = createIterator(Collections.emptyIterator(), cacheKvs, false); + assertTrue(mergeIterator.hasNext()); + } + + @Test + public void shouldHaveNextFromReverseCache() { + final MergedSortedCacheSessionStoreIterator mergeIterator = createIterator(Collections.emptyIterator(), cacheKvs, true); + assertTrue(mergeIterator.hasNext()); + } + + @Test + public void shouldGetNextFromCache() { + final MergedSortedCacheSessionStoreIterator mergeIterator = createIterator(Collections.emptyIterator(), cacheKvs, false); + assertThat(mergeIterator.next(), equalTo(KeyValue.pair(new Windowed<>(cacheKey, cacheWindow), cacheKey.get()))); + } + + @Test + public void shouldGetNextFromReverseCache() { + final MergedSortedCacheSessionStoreIterator mergeIterator = createIterator(Collections.emptyIterator(), cacheKvs, true); + assertThat(mergeIterator.next(), equalTo(KeyValue.pair(new Windowed<>(cacheKey, cacheWindow), cacheKey.get()))); + } + + @Test + public void shouldPeekNextKeyFromCache() { + final MergedSortedCacheSessionStoreIterator mergeIterator = createIterator(Collections.emptyIterator(), cacheKvs, false); + assertThat(mergeIterator.peekNextKey(), equalTo(new Windowed<>(cacheKey, cacheWindow))); + } + + @Test + public void shouldPeekNextKeyFromReverseCache() { + final MergedSortedCacheSessionStoreIterator mergeIterator = createIterator(Collections.emptyIterator(), cacheKvs, true); + assertThat(mergeIterator.peekNextKey(), equalTo(new Windowed<>(cacheKey, cacheWindow))); + } + + @Test + public void shouldIterateBothStoreAndCache() { + final MergedSortedCacheSessionStoreIterator iterator = createIterator(storeKvs, cacheKvs, true); + assertThat(iterator.next(), equalTo(KeyValue.pair(new Windowed<>(storeKey, storeWindow), storeKey.get()))); + assertThat(iterator.next(), equalTo(KeyValue.pair(new Windowed<>(cacheKey, cacheWindow), cacheKey.get()))); + assertFalse(iterator.hasNext()); + } + + @Test + public void shouldReverseIterateBothStoreAndCache() { + final MergedSortedCacheSessionStoreIterator iterator = createIterator(storeKvs, cacheKvs, false); + assertThat(iterator.next(), equalTo(KeyValue.pair(new Windowed<>(cacheKey, cacheWindow), cacheKey.get()))); + assertThat(iterator.next(), equalTo(KeyValue.pair(new Windowed<>(storeKey, storeWindow), storeKey.get()))); + assertFalse(iterator.hasNext()); + } + + private MergedSortedCacheSessionStoreIterator createIterator(final Iterator, byte[]>> storeKvs, + final Iterator> cacheKvs, + final boolean forward) { + final DelegatingPeekingKeyValueIterator, byte[]> storeIterator = + new DelegatingPeekingKeyValueIterator<>("store", new KeyValueIteratorStub<>(storeKvs)); + + final PeekingKeyValueIterator cacheIterator = + new DelegatingPeekingKeyValueIterator<>("cache", new KeyValueIteratorStub<>(cacheKvs)); + return new MergedSortedCacheSessionStoreIterator(cacheIterator, storeIterator, SINGLE_SEGMENT_CACHE_FUNCTION, forward); + } + +} diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/MergedSortedCacheWrappedWindowStoreIteratorTest.java 
b/streams/src/test/java/org/apache/kafka/streams/state/internals/MergedSortedCacheWrappedWindowStoreIteratorTest.java new file mode 100644 index 0000000..0d69d93 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/MergedSortedCacheWrappedWindowStoreIteratorTest.java @@ -0,0 +1,214 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.common.metrics.Metrics; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.processor.internals.MockStreamsMetrics; +import org.apache.kafka.streams.state.KeyValueIterator; +import org.apache.kafka.streams.state.StateSerdes; +import org.apache.kafka.test.KeyValueIteratorStub; +import org.junit.Test; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; + +public class MergedSortedCacheWrappedWindowStoreIteratorTest { + + private static final SegmentedCacheFunction SINGLE_SEGMENT_CACHE_FUNCTION = new SegmentedCacheFunction(null, -1) { + @Override + public long segmentId(final Bytes key) { + return 0; + } + }; + + private final List> windowStoreKvPairs = new ArrayList<>(); + private final ThreadCache cache = new ThreadCache(new LogContext("testCache "), 1000000L, new MockStreamsMetrics(new Metrics())); + private final String namespace = "0.0-one"; + private final StateSerdes stateSerdes = new StateSerdes<>("foo", Serdes.String(), Serdes.String()); + + @Test + public void shouldIterateOverValueFromBothIterators() { + final List> expectedKvPairs = new ArrayList<>(); + for (long t = 0; t < 100; t += 20) { + final byte[] v1Bytes = String.valueOf(t).getBytes(); + final KeyValue v1 = KeyValue.pair(t, v1Bytes); + windowStoreKvPairs.add(v1); + expectedKvPairs.add(KeyValue.pair(t, v1Bytes)); + final Bytes keyBytes = WindowKeySchema.toStoreKeyBinary("a", t + 10, 0, stateSerdes); + final byte[] valBytes = String.valueOf(t + 10).getBytes(); + expectedKvPairs.add(KeyValue.pair(t + 10, valBytes)); + cache.put(namespace, SINGLE_SEGMENT_CACHE_FUNCTION.cacheKey(keyBytes), new LRUCacheEntry(valBytes)); + } + + final Bytes fromBytes = WindowKeySchema.toStoreKeyBinary("a", 0, 0, stateSerdes); + final Bytes toBytes = WindowKeySchema.toStoreKeyBinary("a", 100, 0, stateSerdes); + final KeyValueIterator storeIterator = new DelegatingPeekingKeyValueIterator<>("store", new KeyValueIteratorStub<>(windowStoreKvPairs.iterator())); + + 
final ThreadCache.MemoryLRUCacheBytesIterator cacheIterator = cache.range( + namespace, SINGLE_SEGMENT_CACHE_FUNCTION.cacheKey(fromBytes), SINGLE_SEGMENT_CACHE_FUNCTION.cacheKey(toBytes) + ); + + final MergedSortedCacheWindowStoreIterator iterator = new MergedSortedCacheWindowStoreIterator( + cacheIterator, storeIterator, true + ); + int index = 0; + while (iterator.hasNext()) { + final KeyValue next = iterator.next(); + final KeyValue expected = expectedKvPairs.get(index++); + assertArrayEquals(expected.value, next.value); + assertEquals(expected.key, next.key); + } + iterator.close(); + } + + + @Test + public void shouldReverseIterateOverValueFromBothIterators() { + final List> expectedKvPairs = new ArrayList<>(); + for (long t = 0; t < 100; t += 20) { + final byte[] v1Bytes = String.valueOf(t).getBytes(); + final KeyValue v1 = KeyValue.pair(t, v1Bytes); + windowStoreKvPairs.add(v1); + expectedKvPairs.add(KeyValue.pair(t, v1Bytes)); + final Bytes keyBytes = WindowKeySchema.toStoreKeyBinary("a", t + 10, 0, stateSerdes); + final byte[] valBytes = String.valueOf(t + 10).getBytes(); + expectedKvPairs.add(KeyValue.pair(t + 10, valBytes)); + cache.put(namespace, SINGLE_SEGMENT_CACHE_FUNCTION.cacheKey(keyBytes), new LRUCacheEntry(valBytes)); + } + + final Bytes fromBytes = WindowKeySchema.toStoreKeyBinary("a", 0, 0, stateSerdes); + final Bytes toBytes = WindowKeySchema.toStoreKeyBinary("a", 100, 0, stateSerdes); + Collections.reverse(windowStoreKvPairs); + final KeyValueIterator storeIterator = + new DelegatingPeekingKeyValueIterator<>("store", new KeyValueIteratorStub<>(windowStoreKvPairs.iterator())); + + final ThreadCache.MemoryLRUCacheBytesIterator cacheIterator = cache.reverseRange( + namespace, SINGLE_SEGMENT_CACHE_FUNCTION.cacheKey(fromBytes), SINGLE_SEGMENT_CACHE_FUNCTION.cacheKey(toBytes) + ); + + final MergedSortedCacheWindowStoreIterator iterator = new MergedSortedCacheWindowStoreIterator( + cacheIterator, storeIterator, false + ); + int index = 0; + Collections.reverse(expectedKvPairs); + while (iterator.hasNext()) { + final KeyValue next = iterator.next(); + final KeyValue expected = expectedKvPairs.get(index++); + assertArrayEquals(expected.value, next.value); + assertEquals(expected.key, next.key); + } + iterator.close(); + } + + @Test + public void shouldPeekNextStoreKey() { + windowStoreKvPairs.add(KeyValue.pair(10L, "a".getBytes())); + cache.put(namespace, SINGLE_SEGMENT_CACHE_FUNCTION.cacheKey(WindowKeySchema.toStoreKeyBinary("a", 0, 0, stateSerdes)), new LRUCacheEntry("b".getBytes())); + final Bytes fromBytes = WindowKeySchema.toStoreKeyBinary("a", 0, 0, stateSerdes); + final Bytes toBytes = WindowKeySchema.toStoreKeyBinary("a", 100, 0, stateSerdes); + final KeyValueIterator storeIterator = new DelegatingPeekingKeyValueIterator<>("store", new KeyValueIteratorStub<>(windowStoreKvPairs.iterator())); + final ThreadCache.MemoryLRUCacheBytesIterator cacheIterator = cache.range( + namespace, SINGLE_SEGMENT_CACHE_FUNCTION.cacheKey(fromBytes), SINGLE_SEGMENT_CACHE_FUNCTION.cacheKey(toBytes) + ); + final MergedSortedCacheWindowStoreIterator iterator = new MergedSortedCacheWindowStoreIterator( + cacheIterator, storeIterator, true + ); + assertThat(iterator.peekNextKey(), equalTo(0L)); + iterator.next(); + assertThat(iterator.peekNextKey(), equalTo(10L)); + iterator.close(); + } + + @Test + public void shouldPeekNextStoreKeyReverse() { + windowStoreKvPairs.add(KeyValue.pair(10L, "a".getBytes())); + cache.put(namespace, 
SINGLE_SEGMENT_CACHE_FUNCTION.cacheKey(WindowKeySchema.toStoreKeyBinary("a", 0, 0, stateSerdes)), new LRUCacheEntry("b".getBytes())); + final Bytes fromBytes = WindowKeySchema.toStoreKeyBinary("a", 0, 0, stateSerdes); + final Bytes toBytes = WindowKeySchema.toStoreKeyBinary("a", 100, 0, stateSerdes); + final KeyValueIterator storeIterator = + new DelegatingPeekingKeyValueIterator<>("store", new KeyValueIteratorStub<>(windowStoreKvPairs.iterator())); + final ThreadCache.MemoryLRUCacheBytesIterator cacheIterator = cache.reverseRange( + namespace, SINGLE_SEGMENT_CACHE_FUNCTION.cacheKey(fromBytes), + SINGLE_SEGMENT_CACHE_FUNCTION.cacheKey(toBytes) + ); + final MergedSortedCacheWindowStoreIterator iterator = new MergedSortedCacheWindowStoreIterator( + cacheIterator, storeIterator, false + ); + assertThat(iterator.peekNextKey(), equalTo(10L)); + iterator.next(); + assertThat(iterator.peekNextKey(), equalTo(0L)); + iterator.close(); + } + + @Test + public void shouldPeekNextCacheKey() { + windowStoreKvPairs.add(KeyValue.pair(0L, "a".getBytes())); + cache.put(namespace, SINGLE_SEGMENT_CACHE_FUNCTION.cacheKey(WindowKeySchema.toStoreKeyBinary("a", 10L, 0, stateSerdes)), new LRUCacheEntry("b".getBytes())); + final Bytes fromBytes = WindowKeySchema.toStoreKeyBinary("a", 0, 0, stateSerdes); + final Bytes toBytes = WindowKeySchema.toStoreKeyBinary("a", 100, 0, stateSerdes); + final KeyValueIterator storeIterator = + new DelegatingPeekingKeyValueIterator<>("store", new KeyValueIteratorStub<>(windowStoreKvPairs.iterator())); + final ThreadCache.MemoryLRUCacheBytesIterator cacheIterator = cache.range( + namespace, + SINGLE_SEGMENT_CACHE_FUNCTION.cacheKey(fromBytes), + SINGLE_SEGMENT_CACHE_FUNCTION.cacheKey(toBytes) + ); + final MergedSortedCacheWindowStoreIterator iterator = new MergedSortedCacheWindowStoreIterator( + cacheIterator, + storeIterator, + true + ); + assertThat(iterator.peekNextKey(), equalTo(0L)); + iterator.next(); + assertThat(iterator.peekNextKey(), equalTo(10L)); + iterator.close(); + } + + @Test + public void shouldPeekNextCacheKeyReverse() { + windowStoreKvPairs.add(KeyValue.pair(0L, "a".getBytes())); + cache.put(namespace, SINGLE_SEGMENT_CACHE_FUNCTION.cacheKey(WindowKeySchema.toStoreKeyBinary("a", 10L, 0, stateSerdes)), new LRUCacheEntry("b".getBytes())); + final Bytes fromBytes = WindowKeySchema.toStoreKeyBinary("a", 0, 0, stateSerdes); + final Bytes toBytes = WindowKeySchema.toStoreKeyBinary("a", 100, 0, stateSerdes); + final KeyValueIterator storeIterator = + new DelegatingPeekingKeyValueIterator<>("store", new KeyValueIteratorStub<>(windowStoreKvPairs.iterator())); + final ThreadCache.MemoryLRUCacheBytesIterator cacheIterator = cache.reverseRange( + namespace, + SINGLE_SEGMENT_CACHE_FUNCTION.cacheKey(fromBytes), + SINGLE_SEGMENT_CACHE_FUNCTION.cacheKey(toBytes) + ); + final MergedSortedCacheWindowStoreIterator iterator = new MergedSortedCacheWindowStoreIterator( + cacheIterator, + storeIterator, + false + ); + assertThat(iterator.peekNextKey(), equalTo(10L)); + iterator.next(); + assertThat(iterator.peekNextKey(), equalTo(0L)); + iterator.close(); + } +} \ No newline at end of file diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/MergedSortedCacheWrappedWindowStoreKeyValueIteratorTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/MergedSortedCacheWrappedWindowStoreKeyValueIteratorTest.java new file mode 100644 index 0000000..8cfe9b8 --- /dev/null +++ 
b/streams/src/test/java/org/apache/kafka/streams/state/internals/MergedSortedCacheWrappedWindowStoreKeyValueIteratorTest.java @@ -0,0 +1,191 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.common.serialization.Deserializer; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.kstream.Windowed; +import org.apache.kafka.streams.kstream.internals.TimeWindow; +import org.apache.kafka.streams.state.StateSerdes; +import org.apache.kafka.test.KeyValueIteratorStub; +import org.junit.Test; + +import java.util.Collections; +import java.util.Iterator; + +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +public class MergedSortedCacheWrappedWindowStoreKeyValueIteratorTest { + private static final SegmentedCacheFunction SINGLE_SEGMENT_CACHE_FUNCTION = new SegmentedCacheFunction(null, -1) { + @Override + public long segmentId(final Bytes key) { + return 0; + } + }; + private static final int WINDOW_SIZE = 10; + + private final String storeKey = "a"; + private final String cacheKey = "b"; + + private final TimeWindow storeWindow = new TimeWindow(0, 1); + private final Iterator, byte[]>> storeKvs = Collections.singleton( + KeyValue.pair(new Windowed<>(Bytes.wrap(storeKey.getBytes()), storeWindow), storeKey.getBytes())).iterator(); + private final TimeWindow cacheWindow = new TimeWindow(10, 20); + private final Iterator> cacheKvs = Collections.singleton( + KeyValue.pair( + SINGLE_SEGMENT_CACHE_FUNCTION.cacheKey(WindowKeySchema.toStoreKeyBinary( + new Windowed<>(cacheKey, cacheWindow), 0, new StateSerdes<>("dummy", Serdes.String(), Serdes.ByteArray())) + ), + new LRUCacheEntry(cacheKey.getBytes()) + )).iterator(); + final private Deserializer deserializer = Serdes.String().deserializer(); + + @Test + public void shouldHaveNextFromStore() { + final MergedSortedCacheWindowStoreKeyValueIterator mergeIterator = + createIterator(storeKvs, Collections.emptyIterator(), false); + assertTrue(mergeIterator.hasNext()); + } + + @Test + public void shouldHaveNextFromReverseStore() { + final MergedSortedCacheWindowStoreKeyValueIterator mergeIterator = + createIterator(storeKvs, Collections.emptyIterator(), true); + assertTrue(mergeIterator.hasNext()); + } + + @Test + public void shouldGetNextFromStore() { + final MergedSortedCacheWindowStoreKeyValueIterator mergeIterator = + createIterator(storeKvs, Collections.emptyIterator(), false); + assertThat(convertKeyValuePair(mergeIterator.next()), equalTo(KeyValue.pair(new 
Windowed<>(storeKey, storeWindow), storeKey))); + } + + @Test + public void shouldGetNextFromReverseStore() { + final MergedSortedCacheWindowStoreKeyValueIterator mergeIterator = + createIterator(storeKvs, Collections.emptyIterator(), true); + assertThat(convertKeyValuePair(mergeIterator.next()), equalTo(KeyValue.pair(new Windowed<>(storeKey, storeWindow), storeKey))); + } + + @Test + public void shouldPeekNextKeyFromStore() { + final MergedSortedCacheWindowStoreKeyValueIterator mergeIterator = + createIterator(storeKvs, Collections.emptyIterator(), false); + assertThat(convertWindowedKey(mergeIterator.peekNextKey()), equalTo(new Windowed<>(storeKey, storeWindow))); + } + + @Test + public void shouldPeekNextKeyFromReverseStore() { + final MergedSortedCacheWindowStoreKeyValueIterator mergeIterator = + createIterator(storeKvs, Collections.emptyIterator(), true); + assertThat(convertWindowedKey(mergeIterator.peekNextKey()), equalTo(new Windowed<>(storeKey, storeWindow))); + } + + @Test + public void shouldHaveNextFromCache() { + final MergedSortedCacheWindowStoreKeyValueIterator mergeIterator = + createIterator(Collections.emptyIterator(), cacheKvs, false); + assertTrue(mergeIterator.hasNext()); + } + + @Test + public void shouldHaveNextFromReverseCache() { + final MergedSortedCacheWindowStoreKeyValueIterator mergeIterator = + createIterator(Collections.emptyIterator(), cacheKvs, true); + assertTrue(mergeIterator.hasNext()); + } + + @Test + public void shouldGetNextFromCache() { + final MergedSortedCacheWindowStoreKeyValueIterator mergeIterator = + createIterator(Collections.emptyIterator(), cacheKvs, false); + assertThat(convertKeyValuePair(mergeIterator.next()), equalTo(KeyValue.pair(new Windowed<>(cacheKey, cacheWindow), cacheKey))); + } + + @Test + public void shouldGetNextFromReverseCache() { + final MergedSortedCacheWindowStoreKeyValueIterator mergeIterator = + createIterator(Collections.emptyIterator(), cacheKvs, true); + assertThat(convertKeyValuePair(mergeIterator.next()), equalTo(KeyValue.pair(new Windowed<>(cacheKey, cacheWindow), cacheKey))); + } + + @Test + public void shouldPeekNextKeyFromCache() { + final MergedSortedCacheWindowStoreKeyValueIterator mergeIterator = + createIterator(Collections.emptyIterator(), cacheKvs, false); + assertThat(convertWindowedKey(mergeIterator.peekNextKey()), equalTo(new Windowed<>(cacheKey, cacheWindow))); + } + + @Test + public void shouldPeekNextKeyFromReverseCache() { + final MergedSortedCacheWindowStoreKeyValueIterator mergeIterator = + createIterator(Collections.emptyIterator(), cacheKvs, true); + assertThat(convertWindowedKey(mergeIterator.peekNextKey()), equalTo(new Windowed<>(cacheKey, cacheWindow))); + } + + @Test + public void shouldIterateBothStoreAndCache() { + final MergedSortedCacheWindowStoreKeyValueIterator iterator = createIterator(storeKvs, cacheKvs, true); + assertThat(convertKeyValuePair(iterator.next()), equalTo(KeyValue.pair(new Windowed<>(storeKey, storeWindow), storeKey))); + assertThat(convertKeyValuePair(iterator.next()), equalTo(KeyValue.pair(new Windowed<>(cacheKey, cacheWindow), cacheKey))); + assertFalse(iterator.hasNext()); + } + + @Test + public void shouldReverseIterateBothStoreAndCache() { + final MergedSortedCacheWindowStoreKeyValueIterator iterator = createIterator(storeKvs, cacheKvs, false); + assertThat(convertKeyValuePair(iterator.next()), equalTo(KeyValue.pair(new Windowed<>(cacheKey, cacheWindow), cacheKey))); + assertThat(convertKeyValuePair(iterator.next()), equalTo(KeyValue.pair(new Windowed<>(storeKey, 
storeWindow), storeKey))); + assertFalse(iterator.hasNext()); + } + + private KeyValue, String> convertKeyValuePair(final KeyValue, byte[]> next) { + final String value = deserializer.deserialize("", next.value); + return KeyValue.pair(convertWindowedKey(next.key), value); + } + + private Windowed convertWindowedKey(final Windowed bytesWindowed) { + final String key = deserializer.deserialize("", bytesWindowed.key().get()); + return new Windowed<>(key, bytesWindowed.window()); + } + + + private MergedSortedCacheWindowStoreKeyValueIterator createIterator(final Iterator, byte[]>> storeKvs, + final Iterator> cacheKvs, + final boolean forward) { + final DelegatingPeekingKeyValueIterator, byte[]> storeIterator = + new DelegatingPeekingKeyValueIterator<>("store", new KeyValueIteratorStub<>(storeKvs)); + + final PeekingKeyValueIterator cacheIterator = + new DelegatingPeekingKeyValueIterator<>("cache", new KeyValueIteratorStub<>(cacheKvs)); + return new MergedSortedCacheWindowStoreKeyValueIterator( + cacheIterator, + storeIterator, + new StateSerdes<>("name", Serdes.Bytes(), Serdes.ByteArray()), + WINDOW_SIZE, + SINGLE_SEGMENT_CACHE_FUNCTION, + forward + ); + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/MeteredKeyValueStoreTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/MeteredKeyValueStoreTest.java new file mode 100644 index 0000000..3b1a314 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/MeteredKeyValueStoreTest.java @@ -0,0 +1,496 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.common.MetricName; +import org.apache.kafka.common.metrics.JmxReporter; +import org.apache.kafka.common.metrics.KafkaMetric; +import org.apache.kafka.common.metrics.KafkaMetricsContext; +import org.apache.kafka.common.metrics.Metrics; +import org.apache.kafka.common.metrics.MetricsContext; +import org.apache.kafka.common.metrics.Sensor; +import org.apache.kafka.common.serialization.Deserializer; +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.serialization.Serializer; +import org.apache.kafka.common.serialization.StringSerializer; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.processor.ProcessorContext; +import org.apache.kafka.streams.processor.StateStoreContext; +import org.apache.kafka.streams.processor.TaskId; +import org.apache.kafka.streams.processor.internals.InternalProcessorContext; +import org.apache.kafka.streams.processor.internals.ProcessorStateManager; +import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl; +import org.apache.kafka.streams.state.KeyValueIterator; +import org.apache.kafka.streams.state.KeyValueStore; +import org.apache.kafka.test.KeyValueIteratorStub; +import org.easymock.EasyMockRule; +import org.easymock.Mock; +import org.easymock.MockType; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; + +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +import static org.apache.kafka.common.utils.Utils.mkEntry; +import static org.apache.kafka.common.utils.Utils.mkMap; +import static org.easymock.EasyMock.anyObject; +import static org.easymock.EasyMock.aryEq; +import static org.easymock.EasyMock.eq; +import static org.easymock.EasyMock.expect; +import static org.easymock.EasyMock.expectLastCall; +import static org.easymock.EasyMock.mock; +import static org.easymock.EasyMock.niceMock; +import static org.easymock.EasyMock.replay; +import static org.easymock.EasyMock.verify; +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.empty; +import static org.hamcrest.Matchers.greaterThan; +import static org.hamcrest.Matchers.not; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; + +public class MeteredKeyValueStoreTest { + + @Rule + public EasyMockRule rule = new EasyMockRule(this); + + private static final String APPLICATION_ID = "test-app"; + private static final String STORE_NAME = "store-name"; + private static final String STORE_TYPE = "scope"; + private static final String STORE_LEVEL_GROUP = "stream-state-metrics"; + private static final String CHANGELOG_TOPIC = "changelog-topic"; + private static final String THREAD_ID_TAG_KEY = "thread-id"; + private static final String KEY = "key"; + private static final Bytes KEY_BYTES = Bytes.wrap(KEY.getBytes()); + private static final String VALUE = "value"; + private static final byte[] VALUE_BYTES = VALUE.getBytes(); + private static final KeyValue BYTE_KEY_VALUE_PAIR = KeyValue.pair(KEY_BYTES, VALUE_BYTES); + + private 
final String threadId = Thread.currentThread().getName(); + private final TaskId taskId = new TaskId(0, 0, "My-Topology"); + + @Mock(type = MockType.NICE) + private KeyValueStore inner; + @Mock(type = MockType.NICE) + private InternalProcessorContext context; + + private MeteredKeyValueStore metered; + private final Metrics metrics = new Metrics(); + private Map tags; + + @Before + public void before() { + final Time mockTime = new MockTime(); + metered = new MeteredKeyValueStore<>( + inner, + STORE_TYPE, + mockTime, + Serdes.String(), + Serdes.String() + ); + metrics.config().recordLevel(Sensor.RecordingLevel.DEBUG); + expect(context.applicationId()).andStubReturn(APPLICATION_ID); + expect(context.metrics()).andStubReturn( + new StreamsMetricsImpl(metrics, "test", StreamsConfig.METRICS_LATEST, mockTime) + ); + expect(context.taskId()).andStubReturn(taskId); + expect(context.changelogFor(STORE_NAME)).andStubReturn(CHANGELOG_TOPIC); + expect(inner.name()).andStubReturn(STORE_NAME); + tags = mkMap( + mkEntry(THREAD_ID_TAG_KEY, threadId), + mkEntry("task-id", taskId.toString()), + mkEntry(STORE_TYPE + "-state-id", STORE_NAME) + ); + } + + private void init() { + replay(inner, context); + metered.init((StateStoreContext) context, metered); + } + + @SuppressWarnings("deprecation") + @Test + public void shouldDelegateDeprecatedInit() { + final KeyValueStore inner = mock(KeyValueStore.class); + final MeteredKeyValueStore outer = new MeteredKeyValueStore<>( + inner, + STORE_TYPE, + new MockTime(), + Serdes.String(), + Serdes.String() + ); + expect(inner.name()).andStubReturn("store"); + inner.init((ProcessorContext) context, outer); + expectLastCall(); + replay(inner, context); + outer.init((ProcessorContext) context, outer); + verify(inner); + } + + @Test + public void shouldDelegateInit() { + final KeyValueStore inner = mock(KeyValueStore.class); + final MeteredKeyValueStore outer = new MeteredKeyValueStore<>( + inner, + STORE_TYPE, + new MockTime(), + Serdes.String(), + Serdes.String() + ); + expect(inner.name()).andStubReturn("store"); + inner.init((StateStoreContext) context, outer); + expectLastCall(); + replay(inner, context); + outer.init((StateStoreContext) context, outer); + verify(inner); + } + + @Test + public void shouldPassChangelogTopicNameToStateStoreSerde() { + doShouldPassChangelogTopicNameToStateStoreSerde(CHANGELOG_TOPIC); + } + + @Test + public void shouldPassDefaultChangelogTopicNameToStateStoreSerdeIfLoggingDisabled() { + final String defaultChangelogTopicName = ProcessorStateManager.storeChangelogTopic(APPLICATION_ID, STORE_NAME, taskId.topologyName()); + expect(context.changelogFor(STORE_NAME)).andReturn(null); + doShouldPassChangelogTopicNameToStateStoreSerde(defaultChangelogTopicName); + } + + private void doShouldPassChangelogTopicNameToStateStoreSerde(final String topic) { + final Serde keySerde = niceMock(Serde.class); + final Serializer keySerializer = mock(Serializer.class); + final Serde valueSerde = niceMock(Serde.class); + final Deserializer valueDeserializer = mock(Deserializer.class); + final Serializer valueSerializer = mock(Serializer.class); + expect(keySerde.serializer()).andStubReturn(keySerializer); + expect(keySerializer.serialize(topic, KEY)).andStubReturn(KEY.getBytes()); + expect(valueSerde.deserializer()).andStubReturn(valueDeserializer); + expect(valueDeserializer.deserialize(topic, VALUE_BYTES)).andStubReturn(VALUE); + expect(valueSerde.serializer()).andStubReturn(valueSerializer); + expect(valueSerializer.serialize(topic, 
VALUE)).andStubReturn(VALUE_BYTES); + expect(inner.get(KEY_BYTES)).andStubReturn(VALUE_BYTES); + replay(inner, context, keySerializer, keySerde, valueDeserializer, valueSerializer, valueSerde); + metered = new MeteredKeyValueStore<>( + inner, + STORE_TYPE, + new MockTime(), + keySerde, + valueSerde + ); + metered.init((StateStoreContext) context, metered); + + metered.get(KEY); + metered.put(KEY, VALUE); + + verify(keySerializer, valueDeserializer, valueSerializer); + } + + @Test + public void testMetrics() { + init(); + final JmxReporter reporter = new JmxReporter(); + final MetricsContext metricsContext = new KafkaMetricsContext("kafka.streams"); + reporter.contextChange(metricsContext); + + metrics.addReporter(reporter); + assertTrue(reporter.containsMbean(String.format( + "kafka.streams:type=%s,%s=%s,task-id=%s,%s-state-id=%s", + STORE_LEVEL_GROUP, + THREAD_ID_TAG_KEY, + threadId, + taskId.toString(), + STORE_TYPE, + STORE_NAME + ))); + } + + @Test + public void shouldRecordRestoreLatencyOnInit() { + inner.init((StateStoreContext) context, metered); + + init(); + + // it suffices to verify one restore metric since all restore metrics are recorded by the same sensor + // and the sensor is tested elsewhere + final KafkaMetric metric = metric("restore-rate"); + assertThat((Double) metric.metricValue(), greaterThan(0.0)); + verify(inner); + } + + @Test + public void shouldWriteBytesToInnerStoreAndRecordPutMetric() { + inner.put(eq(KEY_BYTES), aryEq(VALUE_BYTES)); + expectLastCall(); + init(); + + metered.put(KEY, VALUE); + + final KafkaMetric metric = metric("put-rate"); + assertTrue((Double) metric.metricValue() > 0); + verify(inner); + } + + @Test + public void shouldGetBytesFromInnerStoreAndReturnGetMetric() { + expect(inner.get(KEY_BYTES)).andReturn(VALUE_BYTES); + init(); + + assertThat(metered.get(KEY), equalTo(VALUE)); + + final KafkaMetric metric = metric("get-rate"); + assertTrue((Double) metric.metricValue() > 0); + verify(inner); + } + + @Test + public void shouldPutIfAbsentAndRecordPutIfAbsentMetric() { + expect(inner.putIfAbsent(eq(KEY_BYTES), aryEq(VALUE_BYTES))).andReturn(null); + init(); + + metered.putIfAbsent(KEY, VALUE); + + final KafkaMetric metric = metric("put-if-absent-rate"); + assertTrue((Double) metric.metricValue() > 0); + verify(inner); + } + + @SuppressWarnings("unchecked") + @Test + public void shouldPutAllToInnerStoreAndRecordPutAllMetric() { + inner.putAll(anyObject(List.class)); + expectLastCall(); + init(); + + metered.putAll(Collections.singletonList(KeyValue.pair(KEY, VALUE))); + + final KafkaMetric metric = metric("put-all-rate"); + assertTrue((Double) metric.metricValue() > 0); + verify(inner); + } + + @Test + public void shouldDeleteFromInnerStoreAndRecordDeleteMetric() { + expect(inner.delete(KEY_BYTES)).andReturn(VALUE_BYTES); + init(); + + metered.delete(KEY); + + final KafkaMetric metric = metric("delete-rate"); + assertTrue((Double) metric.metricValue() > 0); + verify(inner); + } + + @Test + public void shouldGetRangeFromInnerStoreAndRecordRangeMetric() { + expect(inner.range(KEY_BYTES, KEY_BYTES)) + .andReturn(new KeyValueIteratorStub<>(Collections.singletonList(BYTE_KEY_VALUE_PAIR).iterator())); + init(); + + final KeyValueIterator iterator = metered.range(KEY, KEY); + assertThat(iterator.next().value, equalTo(VALUE)); + assertFalse(iterator.hasNext()); + iterator.close(); + + final KafkaMetric metric = metric("range-rate"); + assertTrue((Double) metric.metricValue() > 0); + verify(inner); + } + + @Test + public void 
shouldGetAllFromInnerStoreAndRecordAllMetric() { + expect(inner.all()).andReturn(new KeyValueIteratorStub<>(Collections.singletonList(BYTE_KEY_VALUE_PAIR).iterator())); + init(); + + final KeyValueIterator iterator = metered.all(); + assertThat(iterator.next().value, equalTo(VALUE)); + assertFalse(iterator.hasNext()); + iterator.close(); + + final KafkaMetric metric = metric(new MetricName("all-rate", STORE_LEVEL_GROUP, "", tags)); + assertTrue((Double) metric.metricValue() > 0); + verify(inner); + } + + @Test + public void shouldFlushInnerWhenFlushTimeRecords() { + inner.flush(); + expectLastCall().once(); + init(); + + metered.flush(); + + final KafkaMetric metric = metric("flush-rate"); + assertTrue((Double) metric.metricValue() > 0); + verify(inner); + } + + private interface CachedKeyValueStore extends KeyValueStore, CachedStateStore { } + + @SuppressWarnings("unchecked") + @Test + public void shouldSetFlushListenerOnWrappedCachingStore() { + final CachedKeyValueStore cachedKeyValueStore = mock(CachedKeyValueStore.class); + + expect(cachedKeyValueStore.setFlushListener(anyObject(CacheFlushListener.class), eq(false))).andReturn(true); + replay(cachedKeyValueStore); + + metered = new MeteredKeyValueStore<>( + cachedKeyValueStore, + STORE_TYPE, + new MockTime(), + Serdes.String(), + Serdes.String() + ); + assertTrue(metered.setFlushListener(null, false)); + + verify(cachedKeyValueStore); + } + + @Test + public void shouldNotThrowNullPointerExceptionIfGetReturnsNull() { + expect(inner.get(Bytes.wrap("a".getBytes()))).andReturn(null); + + init(); + assertNull(metered.get("a")); + } + + @Test + public void shouldNotSetFlushListenerOnWrappedNoneCachingStore() { + assertFalse(metered.setFlushListener(null, false)); + } + + @Test + public void shouldRemoveMetricsOnClose() { + inner.close(); + expectLastCall(); + init(); // replays "inner" + + // There's always a "count" metric registered + assertThat(storeMetrics(), not(empty())); + metered.close(); + assertThat(storeMetrics(), empty()); + verify(inner); + } + + @Test + public void shouldRemoveMetricsEvenIfWrappedStoreThrowsOnClose() { + inner.close(); + expectLastCall().andThrow(new RuntimeException("Oops!")); + init(); // replays "inner" + + assertThat(storeMetrics(), not(empty())); + assertThrows(RuntimeException.class, metered::close); + assertThat(storeMetrics(), empty()); + verify(inner); + } + + @Test + public void shouldThrowNullPointerOnGetIfKeyIsNull() { + assertThrows(NullPointerException.class, () -> metered.get(null)); + } + + @Test + public void shouldThrowNullPointerOnPutIfKeyIsNull() { + assertThrows(NullPointerException.class, () -> metered.put(null, VALUE)); + } + + @Test + public void shouldThrowNullPointerOnPutIfAbsentIfKeyIsNull() { + assertThrows(NullPointerException.class, () -> metered.putIfAbsent(null, VALUE)); + } + + @Test + public void shouldThrowNullPointerOnDeleteIfKeyIsNull() { + assertThrows(NullPointerException.class, () -> metered.delete(null)); + } + + @Test + public void shouldThrowNullPointerOnPutAllIfAnyKeyIsNull() { + assertThrows(NullPointerException.class, () -> metered.putAll(Collections.singletonList(KeyValue.pair(null, VALUE)))); + } + + @Test + public void shouldThrowNullPointerOnPrefixScanIfPrefixIsNull() { + final StringSerializer stringSerializer = new StringSerializer(); + assertThrows(NullPointerException.class, () -> metered.prefixScan(null, stringSerializer)); + } + + @Test + public void shouldThrowNullPointerOnRangeIfFromIsNull() { + assertThrows(NullPointerException.class, () -> 
metered.range(null, "to")); + } + + @Test + public void shouldThrowNullPointerOnRangeIfToIsNull() { + assertThrows(NullPointerException.class, () -> metered.range("from", null)); + } + + @Test + public void shouldThrowNullPointerOnReverseRangeIfFromIsNull() { + assertThrows(NullPointerException.class, () -> metered.reverseRange(null, "to")); + } + + @Test + public void shouldThrowNullPointerOnReverseRangeIfToIsNull() { + assertThrows(NullPointerException.class, () -> metered.reverseRange("from", null)); + } + + @Test + public void shouldGetRecordsWithPrefixKey() { + final StringSerializer stringSerializer = new StringSerializer(); + expect(inner.prefixScan(KEY, stringSerializer)) + .andReturn(new KeyValueIteratorStub<>(Collections.singletonList(BYTE_KEY_VALUE_PAIR).iterator())); + init(); + + final KeyValueIterator iterator = metered.prefixScan(KEY, stringSerializer); + assertThat(iterator.next().value, equalTo(VALUE)); + iterator.close(); + + final KafkaMetric metric = metrics.metric(new MetricName("prefix-scan-rate", STORE_LEVEL_GROUP, "", tags)); + assertTrue((Double) metric.metricValue() > 0); + verify(inner); + } + + private KafkaMetric metric(final MetricName metricName) { + return this.metrics.metric(metricName); + } + + private KafkaMetric metric(final String name) { + return metrics.metric(new MetricName(name, STORE_LEVEL_GROUP, "", tags)); + } + + private List storeMetrics() { + return metrics.metrics() + .keySet() + .stream() + .filter(name -> name.group().equals(STORE_LEVEL_GROUP) && name.tags().equals(tags)) + .collect(Collectors.toList()); + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/MeteredSessionStoreTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/MeteredSessionStoreTest.java new file mode 100644 index 0000000..922608d --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/MeteredSessionStoreTest.java @@ -0,0 +1,607 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.common.MetricName; +import org.apache.kafka.common.metrics.JmxReporter; +import org.apache.kafka.common.metrics.KafkaMetric; +import org.apache.kafka.common.metrics.KafkaMetricsContext; +import org.apache.kafka.common.metrics.Metrics; +import org.apache.kafka.common.metrics.MetricsContext; +import org.apache.kafka.common.metrics.Sensor; +import org.apache.kafka.common.serialization.Deserializer; +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.serialization.Serializer; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.kstream.Windowed; +import org.apache.kafka.streams.kstream.internals.SessionWindow; +import org.apache.kafka.streams.processor.ProcessorContext; +import org.apache.kafka.streams.processor.StateStoreContext; +import org.apache.kafka.streams.processor.TaskId; +import org.apache.kafka.streams.processor.internals.InternalProcessorContext; +import org.apache.kafka.streams.processor.internals.ProcessorStateManager; +import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl; +import org.apache.kafka.streams.state.KeyValueIterator; +import org.apache.kafka.streams.state.SessionStore; +import org.apache.kafka.test.KeyValueIteratorStub; +import org.easymock.EasyMockRule; +import org.easymock.Mock; +import org.easymock.MockType; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; + +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +import static org.apache.kafka.common.utils.Utils.mkEntry; +import static org.apache.kafka.common.utils.Utils.mkMap; +import static org.easymock.EasyMock.anyObject; +import static org.easymock.EasyMock.aryEq; +import static org.easymock.EasyMock.eq; +import static org.easymock.EasyMock.expect; +import static org.easymock.EasyMock.expectLastCall; +import static org.easymock.EasyMock.mock; +import static org.easymock.EasyMock.niceMock; +import static org.easymock.EasyMock.replay; +import static org.easymock.EasyMock.verify; +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.empty; +import static org.hamcrest.Matchers.not; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; + +public class MeteredSessionStoreTest { + + @Rule + public EasyMockRule rule = new EasyMockRule(this); + + private static final String APPLICATION_ID = "test-app"; + private static final String STORE_TYPE = "scope"; + private static final String STORE_NAME = "mocked-store"; + private static final String STORE_LEVEL_GROUP = "stream-state-metrics"; + private static final String THREAD_ID_TAG_KEY = "thread-id"; + private static final String CHANGELOG_TOPIC = "changelog-topic"; + private static final String KEY = "key"; + private static final Bytes KEY_BYTES = Bytes.wrap(KEY.getBytes()); + private static final Windowed WINDOWED_KEY = new Windowed<>(KEY, new SessionWindow(0, 0)); + private static final Windowed WINDOWED_KEY_BYTES = new Windowed<>(KEY_BYTES, new SessionWindow(0, 0)); + private 
static final String VALUE = "value"; + private static final byte[] VALUE_BYTES = VALUE.getBytes(); + private static final long START_TIMESTAMP = 24L; + private static final long END_TIMESTAMP = 42L; + + private final String threadId = Thread.currentThread().getName(); + private final TaskId taskId = new TaskId(0, 0, "My-Topology"); + private final Metrics metrics = new Metrics(); + private MeteredSessionStore store; + @Mock(type = MockType.NICE) + private SessionStore innerStore; + @Mock(type = MockType.NICE) + private InternalProcessorContext context; + + private Map tags; + + @Before + public void before() { + final Time mockTime = new MockTime(); + store = new MeteredSessionStore<>( + innerStore, + STORE_TYPE, + Serdes.String(), + Serdes.String(), + mockTime + ); + metrics.config().recordLevel(Sensor.RecordingLevel.DEBUG); + expect(context.applicationId()).andStubReturn(APPLICATION_ID); + expect(context.metrics()) + .andStubReturn(new StreamsMetricsImpl(metrics, "test", StreamsConfig.METRICS_LATEST, mockTime)); + expect(context.taskId()).andStubReturn(taskId); + expect(context.changelogFor(STORE_NAME)).andStubReturn(CHANGELOG_TOPIC); + expect(innerStore.name()).andStubReturn(STORE_NAME); + tags = mkMap( + mkEntry(THREAD_ID_TAG_KEY, threadId), + mkEntry("task-id", taskId.toString()), + mkEntry(STORE_TYPE + "-state-id", STORE_NAME) + ); + } + + private void init() { + replay(innerStore, context); + store.init((StateStoreContext) context, store); + } + + @SuppressWarnings("deprecation") + @Test + public void shouldDelegateDeprecatedInit() { + final SessionStore inner = mock(SessionStore.class); + final MeteredSessionStore outer = new MeteredSessionStore<>( + inner, + STORE_TYPE, + Serdes.String(), + Serdes.String(), + new MockTime() + ); + expect(inner.name()).andStubReturn("store"); + inner.init((ProcessorContext) context, outer); + expectLastCall(); + replay(inner, context); + outer.init((ProcessorContext) context, outer); + verify(inner); + } + + @Test + public void shouldDelegateInit() { + final SessionStore inner = mock(SessionStore.class); + final MeteredSessionStore outer = new MeteredSessionStore<>( + inner, + STORE_TYPE, + Serdes.String(), + Serdes.String(), + new MockTime() + ); + expect(inner.name()).andStubReturn("store"); + inner.init((StateStoreContext) context, outer); + expectLastCall(); + replay(inner, context); + outer.init((StateStoreContext) context, outer); + verify(inner); + } + + @Test + public void shouldPassChangelogTopicNameToStateStoreSerde() { + doShouldPassChangelogTopicNameToStateStoreSerde(CHANGELOG_TOPIC); + } + + @Test + public void shouldPassDefaultChangelogTopicNameToStateStoreSerdeIfLoggingDisabled() { + final String defaultChangelogTopicName = + ProcessorStateManager.storeChangelogTopic(APPLICATION_ID, STORE_NAME, taskId.topologyName()); + expect(context.changelogFor(STORE_NAME)).andReturn(null); + doShouldPassChangelogTopicNameToStateStoreSerde(defaultChangelogTopicName); + } + + private void doShouldPassChangelogTopicNameToStateStoreSerde(final String topic) { + final Serde keySerde = niceMock(Serde.class); + final Serializer keySerializer = mock(Serializer.class); + final Serde valueSerde = niceMock(Serde.class); + final Deserializer valueDeserializer = mock(Deserializer.class); + final Serializer valueSerializer = mock(Serializer.class); + expect(keySerde.serializer()).andStubReturn(keySerializer); + expect(keySerializer.serialize(topic, KEY)).andStubReturn(KEY.getBytes()); + expect(valueSerde.deserializer()).andStubReturn(valueDeserializer); + 
expect(valueDeserializer.deserialize(topic, VALUE_BYTES)).andStubReturn(VALUE); + expect(valueSerde.serializer()).andStubReturn(valueSerializer); + expect(valueSerializer.serialize(topic, VALUE)).andStubReturn(VALUE_BYTES); + expect(innerStore.fetchSession(KEY_BYTES, START_TIMESTAMP, END_TIMESTAMP)).andStubReturn(VALUE_BYTES); + replay(innerStore, context, keySerializer, keySerde, valueDeserializer, valueSerializer, valueSerde); + store = new MeteredSessionStore<>( + innerStore, + STORE_TYPE, + keySerde, + valueSerde, + new MockTime() + ); + store.init((StateStoreContext) context, store); + + store.fetchSession(KEY, START_TIMESTAMP, END_TIMESTAMP); + store.put(WINDOWED_KEY, VALUE); + + verify(keySerializer, valueDeserializer, valueSerializer); + } + + @Test + public void testMetrics() { + init(); + final JmxReporter reporter = new JmxReporter(); + final MetricsContext metricsContext = new KafkaMetricsContext("kafka.streams"); + reporter.contextChange(metricsContext); + + metrics.addReporter(reporter); + assertTrue(reporter.containsMbean(String.format( + "kafka.streams:type=%s,%s=%s,task-id=%s,%s-state-id=%s", + STORE_LEVEL_GROUP, + THREAD_ID_TAG_KEY, + threadId, + taskId.toString(), + STORE_TYPE, + STORE_NAME + ))); + } + + @Test + public void shouldWriteBytesToInnerStoreAndRecordPutMetric() { + innerStore.put(eq(WINDOWED_KEY_BYTES), aryEq(VALUE_BYTES)); + expectLastCall(); + init(); + + store.put(WINDOWED_KEY, VALUE); + + // it suffices to verify one put metric since all put metrics are recorded by the same sensor + // and the sensor is tested elsewhere + final KafkaMetric metric = metric("put-rate"); + assertTrue(((Double) metric.metricValue()) > 0); + verify(innerStore); + } + + @Test + public void shouldFindSessionsFromStoreAndRecordFetchMetric() { + expect(innerStore.findSessions(KEY_BYTES, 0, 0)) + .andReturn(new KeyValueIteratorStub<>( + Collections.singleton(KeyValue.pair(WINDOWED_KEY_BYTES, VALUE_BYTES)).iterator())); + init(); + + final KeyValueIterator<Windowed<String>, String> iterator = store.findSessions(KEY, 0, 0); + assertThat(iterator.next().value, equalTo(VALUE)); + assertFalse(iterator.hasNext()); + iterator.close(); + + // it suffices to verify one fetch metric since all fetch metrics are recorded by the same sensor + // and the sensor is tested elsewhere + final KafkaMetric metric = metric("fetch-rate"); + assertTrue((Double) metric.metricValue() > 0); + verify(innerStore); + } + + @Test + public void shouldBackwardFindSessionsFromStoreAndRecordFetchMetric() { + expect(innerStore.backwardFindSessions(KEY_BYTES, 0, 0)) + .andReturn( + new KeyValueIteratorStub<>( + Collections.singleton(KeyValue.pair(WINDOWED_KEY_BYTES, VALUE_BYTES)).iterator() + ) + ); + init(); + + final KeyValueIterator<Windowed<String>, String> iterator = store.backwardFindSessions(KEY, 0, 0); + assertThat(iterator.next().value, equalTo(VALUE)); + assertFalse(iterator.hasNext()); + iterator.close(); + + // it suffices to verify one fetch metric since all fetch metrics are recorded by the same sensor + // and the sensor is tested elsewhere + final KafkaMetric metric = metric("fetch-rate"); + assertTrue((Double) metric.metricValue() > 0); + verify(innerStore); + } + + @Test + public void shouldFindSessionRangeFromStoreAndRecordFetchMetric() { + expect(innerStore.findSessions(KEY_BYTES, KEY_BYTES, 0, 0)) + .andReturn(new KeyValueIteratorStub<>( + Collections.singleton(KeyValue.pair(WINDOWED_KEY_BYTES, VALUE_BYTES)).iterator())); + init(); + + final KeyValueIterator<Windowed<String>, String> iterator = store.findSessions(KEY, KEY, 0, 0); + 
assertThat(iterator.next().value, equalTo(VALUE)); + assertFalse(iterator.hasNext()); + iterator.close(); + + // it suffices to verify one fetch metric since all fetch metrics are recorded by the same sensor + // and the sensor is tested elsewhere + final KafkaMetric metric = metric("fetch-rate"); + assertTrue((Double) metric.metricValue() > 0); + verify(innerStore); + } + + @Test + public void shouldBackwardFindSessionRangeFromStoreAndRecordFetchMetric() { + expect(innerStore.backwardFindSessions(KEY_BYTES, KEY_BYTES, 0, 0)) + .andReturn( + new KeyValueIteratorStub<>( + Collections.singleton(KeyValue.pair(WINDOWED_KEY_BYTES, VALUE_BYTES)).iterator() + ) + ); + init(); + + final KeyValueIterator<Windowed<String>, String> iterator = store.backwardFindSessions(KEY, KEY, 0, 0); + assertThat(iterator.next().value, equalTo(VALUE)); + assertFalse(iterator.hasNext()); + iterator.close(); + + // it suffices to verify one fetch metric since all fetch metrics are recorded by the same sensor + // and the sensor is tested elsewhere + final KafkaMetric metric = metric("fetch-rate"); + assertTrue((Double) metric.metricValue() > 0); + verify(innerStore); + } + + @Test + public void shouldRemoveFromStoreAndRecordRemoveMetric() { + innerStore.remove(WINDOWED_KEY_BYTES); + expectLastCall(); + + init(); + + store.remove(new Windowed<>(KEY, new SessionWindow(0, 0))); + + // it suffices to verify one remove metric since all remove metrics are recorded by the same sensor + // and the sensor is tested elsewhere + final KafkaMetric metric = metric("remove-rate"); + assertTrue((Double) metric.metricValue() > 0); + verify(innerStore); + } + + @Test + public void shouldFetchForKeyAndRecordFetchMetric() { + expect(innerStore.fetch(KEY_BYTES)) + .andReturn(new KeyValueIteratorStub<>( + Collections.singleton(KeyValue.pair(WINDOWED_KEY_BYTES, VALUE_BYTES)).iterator())); + init(); + + final KeyValueIterator<Windowed<String>, String> iterator = store.fetch(KEY); + assertThat(iterator.next().value, equalTo(VALUE)); + assertFalse(iterator.hasNext()); + iterator.close(); + + // it suffices to verify one fetch metric since all fetch metrics are recorded by the same sensor + // and the sensor is tested elsewhere + final KafkaMetric metric = metric("fetch-rate"); + assertTrue((Double) metric.metricValue() > 0); + verify(innerStore); + } + + @Test + public void shouldBackwardFetchForKeyAndRecordFetchMetric() { + expect(innerStore.backwardFetch(KEY_BYTES)) + .andReturn( + new KeyValueIteratorStub<>( + Collections.singleton(KeyValue.pair(WINDOWED_KEY_BYTES, VALUE_BYTES)).iterator() + ) + ); + init(); + + final KeyValueIterator<Windowed<String>, String> iterator = store.backwardFetch(KEY); + assertThat(iterator.next().value, equalTo(VALUE)); + assertFalse(iterator.hasNext()); + iterator.close(); + + // it suffices to verify one fetch metric since all fetch metrics are recorded by the same sensor + // and the sensor is tested elsewhere + final KafkaMetric metric = metric("fetch-rate"); + assertTrue((Double) metric.metricValue() > 0); + verify(innerStore); + } + + @Test + public void shouldFetchRangeFromStoreAndRecordFetchMetric() { + expect(innerStore.fetch(KEY_BYTES, KEY_BYTES)) + .andReturn(new KeyValueIteratorStub<>( + Collections.singleton(KeyValue.pair(WINDOWED_KEY_BYTES, VALUE_BYTES)).iterator())); + init(); + + final KeyValueIterator<Windowed<String>, String> iterator = store.fetch(KEY, KEY); + assertThat(iterator.next().value, equalTo(VALUE)); + assertFalse(iterator.hasNext()); + iterator.close(); + + // it suffices to verify one fetch metric since all fetch metrics are recorded by the same 
sensor + // and the sensor is tested elsewhere + final KafkaMetric metric = metric("fetch-rate"); + assertTrue((Double) metric.metricValue() > 0); + verify(innerStore); + } + + @Test + public void shouldBackwardFetchRangeFromStoreAndRecordFetchMetric() { + expect(innerStore.backwardFetch(KEY_BYTES, KEY_BYTES)) + .andReturn( + new KeyValueIteratorStub<>( + Collections.singleton(KeyValue.pair(WINDOWED_KEY_BYTES, VALUE_BYTES)).iterator() + ) + ); + init(); + + final KeyValueIterator, String> iterator = store.backwardFetch(KEY, KEY); + assertThat(iterator.next().value, equalTo(VALUE)); + assertFalse(iterator.hasNext()); + iterator.close(); + + // it suffices to verify one fetch metric since all fetch metrics are recorded by the same sensor + // and the sensor is tested elsewhere + final KafkaMetric metric = metric("fetch-rate"); + assertTrue((Double) metric.metricValue() > 0); + verify(innerStore); + } + + @Test + public void shouldRecordRestoreTimeOnInit() { + init(); + + // it suffices to verify one restore metric since all restore metrics are recorded by the same sensor + // and the sensor is tested elsewhere + final KafkaMetric metric = metric("restore-rate"); + assertTrue((Double) metric.metricValue() > 0); + } + + @Test + public void shouldNotThrowNullPointerExceptionIfFetchSessionReturnsNull() { + expect(innerStore.fetchSession(Bytes.wrap("a".getBytes()), 0, Long.MAX_VALUE)).andReturn(null); + + init(); + assertNull(store.fetchSession("a", 0, Long.MAX_VALUE)); + } + + @Test + public void shouldThrowNullPointerOnPutIfKeyIsNull() { + assertThrows(NullPointerException.class, () -> store.put(null, "a")); + } + + @Test + public void shouldThrowNullPointerOnRemoveIfKeyIsNull() { + assertThrows(NullPointerException.class, () -> store.remove(null)); + } + + @Test + public void shouldThrowNullPointerOnPutIfWrappedKeyIsNull() { + assertThrows(NullPointerException.class, () -> store.put(new Windowed<>(null, new SessionWindow(0, 0)), "a")); + } + + @Test + public void shouldThrowNullPointerOnRemoveIfWrappedKeyIsNull() { + assertThrows(NullPointerException.class, () -> store.remove(new Windowed<>(null, new SessionWindow(0, 0)))); + } + + @Test + public void shouldThrowNullPointerOnPutIfWindowIsNull() { + assertThrows(NullPointerException.class, () -> store.put(new Windowed<>(KEY, null), "a")); + } + + @Test + public void shouldThrowNullPointerOnRemoveIfWindowIsNull() { + assertThrows(NullPointerException.class, () -> store.remove(new Windowed<>(KEY, null))); + } + + @Test + public void shouldThrowNullPointerOnFetchIfKeyIsNull() { + assertThrows(NullPointerException.class, () -> store.fetch(null)); + } + + @Test + public void shouldThrowNullPointerOnFetchSessionIfKeyIsNull() { + assertThrows(NullPointerException.class, () -> store.fetchSession(null, 0, Long.MAX_VALUE)); + } + + @Test + public void shouldThrowNullPointerOnFetchRangeIfFromIsNull() { + assertThrows(NullPointerException.class, () -> store.fetch(null, "to")); + } + + @Test + public void shouldThrowNullPointerOnFetchRangeIfToIsNull() { + assertThrows(NullPointerException.class, () -> store.fetch("from", null)); + } + + @Test + public void shouldThrowNullPointerOnBackwardFetchIfKeyIsNull() { + assertThrows(NullPointerException.class, () -> store.backwardFetch(null)); + } + + @Test + public void shouldThrowNullPointerOnBackwardFetchIfFromIsNull() { + assertThrows(NullPointerException.class, () -> store.backwardFetch(null, "to")); + } + + @Test + public void shouldThrowNullPointerOnBackwardFetchIfToIsNull() { + 
assertThrows(NullPointerException.class, () -> store.backwardFetch("from", null)); + } + + @Test + public void shouldThrowNullPointerOnFindSessionsIfKeyIsNull() { + assertThrows(NullPointerException.class, () -> store.findSessions(null, 0, 0)); + } + + @Test + public void shouldThrowNullPointerOnFindSessionsRangeIfFromIsNull() { + assertThrows(NullPointerException.class, () -> store.findSessions(null, "a", 0, 0)); + } + + @Test + public void shouldThrowNullPointerOnFindSessionsRangeIfToIsNull() { + assertThrows(NullPointerException.class, () -> store.findSessions("a", null, 0, 0)); + } + + @Test + public void shouldThrowNullPointerOnBackwardFindSessionsIfKeyIsNull() { + assertThrows(NullPointerException.class, () -> store.backwardFindSessions(null, 0, 0)); + } + + @Test + public void shouldThrowNullPointerOnBackwardFindSessionsRangeIfFromIsNull() { + assertThrows(NullPointerException.class, () -> store.backwardFindSessions(null, "a", 0, 0)); + } + + @Test + public void shouldThrowNullPointerOnBackwardFindSessionsRangeIfToIsNull() { + assertThrows(NullPointerException.class, () -> store.backwardFindSessions("a", null, 0, 0)); + } + + private interface CachedSessionStore extends SessionStore, CachedStateStore { } + + @SuppressWarnings("unchecked") + @Test + public void shouldSetFlushListenerOnWrappedCachingStore() { + final CachedSessionStore cachedSessionStore = mock(CachedSessionStore.class); + + expect(cachedSessionStore.setFlushListener(anyObject(CacheFlushListener.class), eq(false))).andReturn(true); + replay(cachedSessionStore); + + store = new MeteredSessionStore<>( + cachedSessionStore, + STORE_TYPE, + Serdes.String(), + Serdes.String(), + new MockTime()); + assertTrue(store.setFlushListener(null, false)); + + verify(cachedSessionStore); + } + + @Test + public void shouldNotSetFlushListenerOnWrappedNoneCachingStore() { + assertFalse(store.setFlushListener(null, false)); + } + + @Test + public void shouldRemoveMetricsOnClose() { + innerStore.close(); + expectLastCall(); + init(); // replays "inner" + + // There's always a "count" metric registered + assertThat(storeMetrics(), not(empty())); + store.close(); + assertThat(storeMetrics(), empty()); + verify(innerStore); + } + + @Test + public void shouldRemoveMetricsEvenIfWrappedStoreThrowsOnClose() { + innerStore.close(); + expectLastCall().andThrow(new RuntimeException("Oops!")); + init(); // replays "inner" + + assertThat(storeMetrics(), not(empty())); + assertThrows(RuntimeException.class, store::close); + assertThat(storeMetrics(), empty()); + verify(innerStore); + } + + private KafkaMetric metric(final String name) { + return this.metrics.metric(new MetricName(name, STORE_LEVEL_GROUP, "", this.tags)); + } + + private List storeMetrics() { + return metrics.metrics() + .keySet() + .stream() + .filter(name -> name.group().equals(STORE_LEVEL_GROUP) && name.tags().equals(tags)) + .collect(Collectors.toList()); + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/MeteredTimestampedKeyValueStoreTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/MeteredTimestampedKeyValueStoreTest.java new file mode 100644 index 0000000..c24a52c --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/MeteredTimestampedKeyValueStoreTest.java @@ -0,0 +1,478 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.common.MetricName; +import org.apache.kafka.common.metrics.JmxReporter; +import org.apache.kafka.common.metrics.KafkaMetric; +import org.apache.kafka.common.metrics.KafkaMetricsContext; +import org.apache.kafka.common.metrics.Metrics; +import org.apache.kafka.common.metrics.MetricsContext; +import org.apache.kafka.common.metrics.Sensor; +import org.apache.kafka.common.serialization.Deserializer; +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.serialization.Serializer; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.errors.StreamsException; +import org.apache.kafka.streams.processor.ProcessorContext; +import org.apache.kafka.streams.processor.StateStoreContext; +import org.apache.kafka.streams.processor.TaskId; +import org.apache.kafka.streams.processor.internals.InternalProcessorContext; +import org.apache.kafka.streams.processor.internals.ProcessorStateManager; +import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl; +import org.apache.kafka.streams.state.KeyValueIterator; +import org.apache.kafka.streams.state.KeyValueStore; +import org.apache.kafka.streams.state.ValueAndTimestamp; +import org.apache.kafka.streams.state.internals.MeteredTimestampedKeyValueStore.RawAndDeserializedValue; +import org.apache.kafka.test.KeyValueIteratorStub; +import org.easymock.EasyMockRule; +import org.easymock.Mock; +import org.easymock.MockType; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; + +import java.util.Collections; +import java.util.List; +import java.util.Map; + +import static org.apache.kafka.common.utils.Utils.mkEntry; +import static org.apache.kafka.common.utils.Utils.mkMap; +import static org.easymock.EasyMock.anyObject; +import static org.easymock.EasyMock.aryEq; +import static org.easymock.EasyMock.eq; +import static org.easymock.EasyMock.expect; +import static org.easymock.EasyMock.expectLastCall; +import static org.easymock.EasyMock.mock; +import static org.easymock.EasyMock.niceMock; +import static org.easymock.EasyMock.replay; +import static org.easymock.EasyMock.verify; +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +public class MeteredTimestampedKeyValueStoreTest { + + @Rule + public EasyMockRule rule = new EasyMockRule(this); + + private static final String APPLICATION_ID = "test-app"; + private static final String STORE_NAME = "store-name"; + 
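+    // Reviewer note, not part of the original change: the constants below mirror the store-level metric
+    // naming that testMetrics() asserts further down, i.e. an MBean of the form
+    //   kafka.streams:type=stream-state-metrics,thread-id=<thread>,task-id=<task>,scope-state-id=store-name
+    // and the same thread-id / task-id / <store-type>-state-id triple is reused as the tag map for every
+    // metric lookup in this class.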
private static final String STORE_TYPE = "scope"; + private static final String STORE_LEVEL_GROUP = "stream-state-metrics"; + private static final String CHANGELOG_TOPIC = "changelog-topic-name"; + private static final String THREAD_ID_TAG_KEY = "thread-id"; + private static final String KEY = "key"; + private static final Bytes KEY_BYTES = Bytes.wrap(KEY.getBytes()); + private static final ValueAndTimestamp VALUE_AND_TIMESTAMP = + ValueAndTimestamp.make("value", 97L); + // timestamp is 97 what is ASCII of 'a' + private static final byte[] VALUE_AND_TIMESTAMP_BYTES = "\0\0\0\0\0\0\0avalue".getBytes(); + + + private final String threadId = Thread.currentThread().getName(); + private final TaskId taskId = new TaskId(0, 0, "My-Topology"); + @Mock(type = MockType.NICE) + private KeyValueStore inner; + @Mock(type = MockType.NICE) + private InternalProcessorContext context; + + private MeteredTimestampedKeyValueStore metered; + private final KeyValue byteKeyValueTimestampPair = KeyValue.pair(KEY_BYTES, + VALUE_AND_TIMESTAMP_BYTES + ); + private final Metrics metrics = new Metrics(); + private Map tags; + + @Before + public void before() { + final Time mockTime = new MockTime(); + metered = new MeteredTimestampedKeyValueStore<>( + inner, + "scope", + mockTime, + Serdes.String(), + new ValueAndTimestampSerde<>(Serdes.String()) + ); + metrics.config().recordLevel(Sensor.RecordingLevel.DEBUG); + expect(context.applicationId()).andStubReturn(APPLICATION_ID); + expect(context.metrics()) + .andStubReturn(new StreamsMetricsImpl(metrics, "test", StreamsConfig.METRICS_LATEST, mockTime)); + expect(context.taskId()).andStubReturn(taskId); + expect(context.changelogFor(STORE_NAME)).andStubReturn(CHANGELOG_TOPIC); + expectSerdes(); + expect(inner.name()).andStubReturn(STORE_NAME); + tags = mkMap( + mkEntry(THREAD_ID_TAG_KEY, threadId), + mkEntry("task-id", taskId.toString()), + mkEntry(STORE_TYPE + "-state-id", STORE_NAME) + ); + + } + + @SuppressWarnings({"unchecked", "rawtypes"}) + private void expectSerdes() { + expect(context.keySerde()).andStubReturn((Serde) Serdes.String()); + expect(context.valueSerde()).andStubReturn((Serde) Serdes.Long()); + } + + private void init() { + replay(inner, context); + metered.init((StateStoreContext) context, metered); + } + + @SuppressWarnings("deprecation") + @Test + public void shouldDelegateDeprecatedInit() { + final KeyValueStore inner = mock(InMemoryKeyValueStore.class); + final MeteredTimestampedKeyValueStore outer = new MeteredTimestampedKeyValueStore<>( + inner, + STORE_TYPE, + new MockTime(), + Serdes.String(), + new ValueAndTimestampSerde<>(Serdes.String()) + ); + expect(inner.name()).andStubReturn("store"); + inner.init((ProcessorContext) context, outer); + expectLastCall(); + replay(inner, context); + outer.init((ProcessorContext) context, outer); + verify(inner); + } + + @Test + public void shouldDelegateInit() { + final KeyValueStore inner = mock(InMemoryKeyValueStore.class); + final MeteredTimestampedKeyValueStore outer = new MeteredTimestampedKeyValueStore<>( + inner, + STORE_TYPE, + new MockTime(), + Serdes.String(), + new ValueAndTimestampSerde<>(Serdes.String()) + ); + expect(inner.name()).andStubReturn("store"); + inner.init((StateStoreContext) context, outer); + expectLastCall(); + replay(inner, context); + outer.init((StateStoreContext) context, outer); + verify(inner); + } + + @Test + public void shouldPassChangelogTopicNameToStateStoreSerde() { + doShouldPassChangelogTopicNameToStateStoreSerde(CHANGELOG_TOPIC); + } + + @Test + public void 
shouldPassDefaultChangelogTopicNameToStateStoreSerdeIfLoggingDisabled() { + final String defaultChangelogTopicName = ProcessorStateManager.storeChangelogTopic(APPLICATION_ID, STORE_NAME, taskId.topologyName()); + expect(context.changelogFor(STORE_NAME)).andReturn(null); + doShouldPassChangelogTopicNameToStateStoreSerde(defaultChangelogTopicName); + } + + private void doShouldPassChangelogTopicNameToStateStoreSerde(final String topic) { + final Serde keySerde = niceMock(Serde.class); + final Serializer keySerializer = mock(Serializer.class); + final Serde> valueSerde = niceMock(Serde.class); + final Deserializer> valueDeserializer = mock(Deserializer.class); + final Serializer> valueSerializer = mock(Serializer.class); + expect(keySerde.serializer()).andStubReturn(keySerializer); + expect(keySerializer.serialize(topic, KEY)).andStubReturn(KEY.getBytes()); + expect(valueSerde.deserializer()).andStubReturn(valueDeserializer); + expect(valueDeserializer.deserialize(topic, VALUE_AND_TIMESTAMP_BYTES)).andStubReturn(VALUE_AND_TIMESTAMP); + expect(valueSerde.serializer()).andStubReturn(valueSerializer); + expect(valueSerializer.serialize(topic, VALUE_AND_TIMESTAMP)).andStubReturn(VALUE_AND_TIMESTAMP_BYTES); + expect(inner.get(KEY_BYTES)).andStubReturn(VALUE_AND_TIMESTAMP_BYTES); + replay(inner, context, keySerializer, keySerde, valueDeserializer, valueSerializer, valueSerde); + metered = new MeteredTimestampedKeyValueStore<>( + inner, + STORE_TYPE, + new MockTime(), + keySerde, + valueSerde + ); + metered.init((StateStoreContext) context, metered); + + metered.get(KEY); + metered.put(KEY, VALUE_AND_TIMESTAMP); + + verify(keySerializer, valueDeserializer, valueSerializer); + } + + @Test + public void testMetrics() { + init(); + final JmxReporter reporter = new JmxReporter(); + final MetricsContext metricsContext = new KafkaMetricsContext("kafka.streams"); + reporter.contextChange(metricsContext); + + metrics.addReporter(reporter); + assertTrue(reporter.containsMbean(String.format( + "kafka.streams:type=%s,%s=%s,task-id=%s,%s-state-id=%s", + STORE_LEVEL_GROUP, + THREAD_ID_TAG_KEY, + threadId, + taskId.toString(), + STORE_TYPE, + STORE_NAME + ))); + } + @Test + public void shouldWriteBytesToInnerStoreAndRecordPutMetric() { + inner.put(eq(KEY_BYTES), aryEq(VALUE_AND_TIMESTAMP_BYTES)); + expectLastCall(); + init(); + + metered.put(KEY, VALUE_AND_TIMESTAMP); + + final KafkaMetric metric = metric("put-rate"); + assertTrue((Double) metric.metricValue() > 0); + verify(inner); + } + + @Test + public void shouldGetWithBinary() { + expect(inner.get(KEY_BYTES)).andReturn(VALUE_AND_TIMESTAMP_BYTES); + + inner.put(eq(KEY_BYTES), aryEq(VALUE_AND_TIMESTAMP_BYTES)); + expectLastCall(); + init(); + + final RawAndDeserializedValue valueWithBinary = metered.getWithBinary(KEY); + assertEquals(valueWithBinary.value, VALUE_AND_TIMESTAMP); + assertEquals(valueWithBinary.serializedValue, VALUE_AND_TIMESTAMP_BYTES); + } + + @SuppressWarnings("resource") + @Test + public void shouldNotPutIfSameValuesAndGreaterTimestamp() { + init(); + + metered.put(KEY, VALUE_AND_TIMESTAMP); + final ValueAndTimestampSerde stringSerde = new ValueAndTimestampSerde<>(Serdes.String()); + final byte[] encodedOldValue = stringSerde.serializer().serialize("TOPIC", VALUE_AND_TIMESTAMP); + + final ValueAndTimestamp newValueAndTimestamp = ValueAndTimestamp.make("value", 98L); + assertFalse(metered.putIfDifferentValues(KEY, newValueAndTimestamp, encodedOldValue)); + verify(inner); + } + + @SuppressWarnings("resource") + @Test + public void 
shouldPutIfOutOfOrder() { + inner.put(eq(KEY_BYTES), aryEq(VALUE_AND_TIMESTAMP_BYTES)); + expectLastCall(); + init(); + + metered.put(KEY, VALUE_AND_TIMESTAMP); + + final ValueAndTimestampSerde stringSerde = new ValueAndTimestampSerde<>(Serdes.String()); + final byte[] encodedOldValue = stringSerde.serializer().serialize("TOPIC", VALUE_AND_TIMESTAMP); + + final ValueAndTimestamp outOfOrderValueAndTimestamp = ValueAndTimestamp.make("value", 95L); + assertTrue(metered.putIfDifferentValues(KEY, outOfOrderValueAndTimestamp, encodedOldValue)); + verify(inner); + } + + @Test + public void shouldGetBytesFromInnerStoreAndReturnGetMetric() { + expect(inner.get(KEY_BYTES)).andReturn(VALUE_AND_TIMESTAMP_BYTES); + init(); + + assertThat(metered.get(KEY), equalTo(VALUE_AND_TIMESTAMP)); + + final KafkaMetric metric = metric("get-rate"); + assertTrue((Double) metric.metricValue() > 0); + verify(inner); + } + + @Test + public void shouldPutIfAbsentAndRecordPutIfAbsentMetric() { + expect(inner.putIfAbsent(eq(KEY_BYTES), aryEq(VALUE_AND_TIMESTAMP_BYTES))).andReturn(null); + init(); + + metered.putIfAbsent(KEY, VALUE_AND_TIMESTAMP); + + final KafkaMetric metric = metric("put-if-absent-rate"); + assertTrue((Double) metric.metricValue() > 0); + verify(inner); + } + + private KafkaMetric metric(final String name) { + return this.metrics.metric(new MetricName(name, STORE_LEVEL_GROUP, "", tags)); + } + + @SuppressWarnings("unchecked") + @Test + public void shouldPutAllToInnerStoreAndRecordPutAllMetric() { + inner.putAll(anyObject(List.class)); + expectLastCall(); + init(); + + metered.putAll(Collections.singletonList(KeyValue.pair(KEY, VALUE_AND_TIMESTAMP))); + + final KafkaMetric metric = metric("put-all-rate"); + assertTrue((Double) metric.metricValue() > 0); + verify(inner); + } + + @Test + public void shouldDeleteFromInnerStoreAndRecordDeleteMetric() { + expect(inner.delete(KEY_BYTES)).andReturn(VALUE_AND_TIMESTAMP_BYTES); + init(); + + metered.delete(KEY); + + final KafkaMetric metric = metric("delete-rate"); + assertTrue((Double) metric.metricValue() > 0); + verify(inner); + } + + @Test + public void shouldGetRangeFromInnerStoreAndRecordRangeMetric() { + expect(inner.range(KEY_BYTES, KEY_BYTES)).andReturn( + new KeyValueIteratorStub<>(Collections.singletonList(byteKeyValueTimestampPair).iterator())); + init(); + + final KeyValueIterator> iterator = metered.range(KEY, KEY); + assertThat(iterator.next().value, equalTo(VALUE_AND_TIMESTAMP)); + assertFalse(iterator.hasNext()); + iterator.close(); + + final KafkaMetric metric = metric("range-rate"); + assertTrue((Double) metric.metricValue() > 0); + verify(inner); + } + + @Test + public void shouldGetAllFromInnerStoreAndRecordAllMetric() { + expect(inner.all()) + .andReturn(new KeyValueIteratorStub<>(Collections.singletonList(byteKeyValueTimestampPair).iterator())); + init(); + + final KeyValueIterator> iterator = metered.all(); + assertThat(iterator.next().value, equalTo(VALUE_AND_TIMESTAMP)); + assertFalse(iterator.hasNext()); + iterator.close(); + + final KafkaMetric metric = metric(new MetricName("all-rate", STORE_LEVEL_GROUP, "", tags)); + assertTrue((Double) metric.metricValue() > 0); + verify(inner); + } + + @Test + public void shouldFlushInnerWhenFlushTimeRecords() { + inner.flush(); + expectLastCall().once(); + init(); + + metered.flush(); + + final KafkaMetric metric = metric("flush-rate"); + assertTrue((Double) metric.metricValue() > 0); + verify(inner); + } + + private interface CachedKeyValueStore extends KeyValueStore, CachedStateStore { } + + 
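+    // Reviewer note, not part of the original change: VALUE_AND_TIMESTAMP_BYTES above assumes the
+    // ValueAndTimestamp wire layout of an 8-byte big-endian timestamp followed by the raw value bytes,
+    // which is why 97L ('a' in ASCII) serializes to "\0\0\0\0\0\0\0a" + "value". A minimal sketch of that
+    // assumption, using a hypothetical helper that is not part of this PR:
+    //
+    //   private static byte[] timestampedBytes(final long timestamp, final String value) {
+    //       final byte[] valueBytes = value.getBytes(StandardCharsets.UTF_8);
+    //       final ByteBuffer buffer = ByteBuffer.allocate(Long.BYTES + valueBytes.length);
+    //       buffer.putLong(timestamp);   // ByteBuffer writes big-endian by default
+    //       buffer.put(valueBytes);
+    //       return buffer.array();       // timestampedBytes(97L, "value") should equal VALUE_AND_TIMESTAMP_BYTES
+    //   }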
@SuppressWarnings("unchecked") + @Test + public void shouldSetFlushListenerOnWrappedCachingStore() { + final CachedKeyValueStore cachedKeyValueStore = mock(CachedKeyValueStore.class); + + expect(cachedKeyValueStore.setFlushListener(anyObject(CacheFlushListener.class), eq(false))).andReturn(true); + replay(cachedKeyValueStore); + + metered = new MeteredTimestampedKeyValueStore<>( + cachedKeyValueStore, + STORE_TYPE, + new MockTime(), + Serdes.String(), + new ValueAndTimestampSerde<>(Serdes.String())); + assertTrue(metered.setFlushListener(null, false)); + + verify(cachedKeyValueStore); + } + + @Test + public void shouldNotSetFlushListenerOnWrappedNoneCachingStore() { + assertFalse(metered.setFlushListener(null, false)); + } + + private KafkaMetric metric(final MetricName metricName) { + return this.metrics.metric(metricName); + } + + @Test + public void shouldNotThrowExceptionIfSerdesCorrectlySetFromProcessorContext() { + final MeteredTimestampedKeyValueStore store = new MeteredTimestampedKeyValueStore<>( + inner, + STORE_TYPE, + new MockTime(), + null, + null + ); + replay(inner, context); + store.init((StateStoreContext) context, inner); + + try { + store.put("key", ValueAndTimestamp.make(42L, 60000)); + } catch (final StreamsException exception) { + if (exception.getCause() instanceof ClassCastException) { + throw new AssertionError( + "Serdes are not correctly set from processor context.", + exception + ); + } else { + throw exception; + } + } + } + + @Test + @SuppressWarnings("unchecked") + public void shouldNotThrowExceptionIfSerdesCorrectlySetFromConstructorParameters() { + expect(context.keySerde()).andStubReturn((Serde) Serdes.String()); + expect(context.valueSerde()).andStubReturn((Serde) Serdes.Long()); + final MeteredTimestampedKeyValueStore store = new MeteredTimestampedKeyValueStore<>( + inner, + STORE_TYPE, + new MockTime(), + Serdes.String(), + new ValueAndTimestampSerde<>(Serdes.Long()) + ); + replay(inner, context); + store.init((StateStoreContext) context, inner); + + try { + store.put("key", ValueAndTimestamp.make(42L, 60000)); + } catch (final StreamsException exception) { + if (exception.getCause() instanceof ClassCastException) { + fail("Serdes are not correctly set from constructor parameters."); + } + throw exception; + } + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/MeteredTimestampedWindowStoreTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/MeteredTimestampedWindowStoreTest.java new file mode 100644 index 0000000..92ba38a --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/MeteredTimestampedWindowStoreTest.java @@ -0,0 +1,254 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.common.metrics.MetricConfig; +import org.apache.kafka.common.metrics.Metrics; +import org.apache.kafka.common.metrics.Sensor; +import org.apache.kafka.common.serialization.Deserializer; +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.serialization.Serializer; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.errors.StreamsException; +import org.apache.kafka.streams.processor.ProcessorContext; +import org.apache.kafka.streams.processor.StateStoreContext; +import org.apache.kafka.streams.processor.TaskId; +import org.apache.kafka.streams.processor.internals.ProcessorStateManager; +import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl; +import org.apache.kafka.streams.state.ValueAndTimestamp; +import org.apache.kafka.streams.state.WindowStore; +import org.apache.kafka.test.InternalMockProcessorContext; +import org.apache.kafka.test.MockRecordCollector; +import org.apache.kafka.test.StreamsTestUtils; +import org.apache.kafka.test.TestUtils; +import org.easymock.EasyMock; +import org.junit.Before; +import org.junit.Test; + +import static org.easymock.EasyMock.expect; +import static org.easymock.EasyMock.expectLastCall; +import static org.easymock.EasyMock.mock; +import static org.easymock.EasyMock.niceMock; +import static org.easymock.EasyMock.replay; +import static org.easymock.EasyMock.verify; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.fail; + +public class MeteredTimestampedWindowStoreTest { + + private static final String STORE_NAME = "mocked-store"; + private static final String STORE_TYPE = "scope"; + private static final String CHANGELOG_TOPIC = "changelog-topic"; + private static final String KEY = "key"; + private static final Bytes KEY_BYTES = Bytes.wrap(KEY.getBytes()); + // timestamp is 97 what is ASCII of 'a' + private static final long TIMESTAMP = 97L; + private static final ValueAndTimestamp VALUE_AND_TIMESTAMP = + ValueAndTimestamp.make("value", TIMESTAMP); + private static final byte[] VALUE_AND_TIMESTAMP_BYTES = "\0\0\0\0\0\0\0avalue".getBytes(); + private static final int WINDOW_SIZE_MS = 10; + + private InternalMockProcessorContext context; + private final TaskId taskId = new TaskId(0, 0, "My-Topology"); + private final WindowStore innerStoreMock = EasyMock.createNiceMock(WindowStore.class); + private final Metrics metrics = new Metrics(new MetricConfig().recordLevel(Sensor.RecordingLevel.DEBUG)); + private MeteredTimestampedWindowStore store = new MeteredTimestampedWindowStore<>( + innerStoreMock, + WINDOW_SIZE_MS, // any size + STORE_TYPE, + new MockTime(), + Serdes.String(), + new ValueAndTimestampSerde<>(new SerdeThatDoesntHandleNull()) + ); + + { + EasyMock.expect(innerStoreMock.name()).andStubReturn(STORE_NAME); + } + + @Before + public void setUp() { + final StreamsMetricsImpl streamsMetrics = + new StreamsMetricsImpl(metrics, "test", StreamsConfig.METRICS_LATEST, new MockTime()); + + context = new InternalMockProcessorContext<>( + TestUtils.tempDirectory(), + Serdes.String(), + Serdes.Long(), + streamsMetrics, + new StreamsConfig(StreamsTestUtils.getStreamsConfig()), + MockRecordCollector::new, + new ThreadCache(new 
LogContext("testCache "), 0, streamsMetrics), + Time.SYSTEM, + taskId + ); + } + + @SuppressWarnings("deprecation") + @Test + public void shouldDelegateDeprecatedInit() { + final WindowStore inner = mock(WindowStore.class); + final MeteredTimestampedWindowStore outer = new MeteredTimestampedWindowStore<>( + inner, + WINDOW_SIZE_MS, // any size + STORE_TYPE, + new MockTime(), + Serdes.String(), + new ValueAndTimestampSerde<>(new SerdeThatDoesntHandleNull()) + ); + expect(inner.name()).andStubReturn("store"); + inner.init((ProcessorContext) context, outer); + expectLastCall(); + replay(inner); + outer.init((ProcessorContext) context, outer); + verify(inner); + } + + @Test + public void shouldDelegateInit() { + final WindowStore inner = mock(WindowStore.class); + final MeteredTimestampedWindowStore outer = new MeteredTimestampedWindowStore<>( + inner, + WINDOW_SIZE_MS, // any size + STORE_TYPE, + new MockTime(), + Serdes.String(), + new ValueAndTimestampSerde<>(new SerdeThatDoesntHandleNull()) + ); + expect(inner.name()).andStubReturn("store"); + inner.init((StateStoreContext) context, outer); + expectLastCall(); + replay(inner); + outer.init((StateStoreContext) context, outer); + verify(inner); + } + + @Test + public void shouldPassChangelogTopicNameToStateStoreSerde() { + context.addChangelogForStore(STORE_NAME, CHANGELOG_TOPIC); + doShouldPassChangelogTopicNameToStateStoreSerde(CHANGELOG_TOPIC); + } + + @Test + public void shouldPassDefaultChangelogTopicNameToStateStoreSerdeIfLoggingDisabled() { + final String defaultChangelogTopicName = + ProcessorStateManager.storeChangelogTopic(context.applicationId(), STORE_NAME, taskId.topologyName()); + doShouldPassChangelogTopicNameToStateStoreSerde(defaultChangelogTopicName); + } + + private void doShouldPassChangelogTopicNameToStateStoreSerde(final String topic) { + final Serde keySerde = niceMock(Serde.class); + final Serializer keySerializer = mock(Serializer.class); + final Serde> valueSerde = niceMock(Serde.class); + final Deserializer> valueDeserializer = mock(Deserializer.class); + final Serializer> valueSerializer = mock(Serializer.class); + expect(keySerde.serializer()).andStubReturn(keySerializer); + expect(keySerializer.serialize(topic, KEY)).andStubReturn(KEY.getBytes()); + expect(valueSerde.deserializer()).andStubReturn(valueDeserializer); + expect(valueDeserializer.deserialize(topic, VALUE_AND_TIMESTAMP_BYTES)).andStubReturn(VALUE_AND_TIMESTAMP); + expect(valueSerde.serializer()).andStubReturn(valueSerializer); + expect(valueSerializer.serialize(topic, VALUE_AND_TIMESTAMP)).andStubReturn(VALUE_AND_TIMESTAMP_BYTES); + expect(innerStoreMock.fetch(KEY_BYTES, TIMESTAMP)).andStubReturn(VALUE_AND_TIMESTAMP_BYTES); + replay(innerStoreMock, keySerializer, keySerde, valueDeserializer, valueSerializer, valueSerde); + store = new MeteredTimestampedWindowStore<>( + innerStoreMock, + WINDOW_SIZE_MS, + STORE_TYPE, + new MockTime(), + keySerde, + valueSerde + ); + store.init((StateStoreContext) context, store); + + store.fetch(KEY, TIMESTAMP); + store.put(KEY, VALUE_AND_TIMESTAMP, TIMESTAMP); + + verify(keySerializer, valueDeserializer, valueSerializer); + } + + @Test + public void shouldCloseUnderlyingStore() { + innerStoreMock.close(); + EasyMock.expectLastCall(); + EasyMock.replay(innerStoreMock); + + store.init((StateStoreContext) context, store); + store.close(); + EasyMock.verify(innerStoreMock); + } + + @Test + public void shouldNotExceptionIfFetchReturnsNull() { + EasyMock.expect(innerStoreMock.fetch(Bytes.wrap("a".getBytes()), 
0)).andReturn(null); + EasyMock.replay(innerStoreMock); + + store.init((StateStoreContext) context, store); + assertNull(store.fetch("a", 0)); + } + + @Test + public void shouldNotThrowExceptionIfSerdesCorrectlySetFromProcessorContext() { + EasyMock.expect(innerStoreMock.name()).andStubReturn("mocked-store"); + EasyMock.replay(innerStoreMock); + final MeteredTimestampedWindowStore store = new MeteredTimestampedWindowStore<>( + innerStoreMock, + 10L, // any size + "scope", + new MockTime(), + null, + null + ); + store.init((StateStoreContext) context, innerStoreMock); + + try { + store.put("key", ValueAndTimestamp.make(42L, 60000), 60000L); + } catch (final StreamsException exception) { + if (exception.getCause() instanceof ClassCastException) { + fail("Serdes are not correctly set from processor context."); + } + throw exception; + } + } + + @Test + public void shouldNotThrowExceptionIfSerdesCorrectlySetFromConstructorParameters() { + EasyMock.expect(innerStoreMock.name()).andStubReturn("mocked-store"); + EasyMock.replay(innerStoreMock); + final MeteredTimestampedWindowStore store = new MeteredTimestampedWindowStore<>( + innerStoreMock, + 10L, // any size + "scope", + new MockTime(), + Serdes.String(), + new ValueAndTimestampSerde<>(Serdes.Long()) + ); + store.init((StateStoreContext) context, innerStoreMock); + + try { + store.put("key", ValueAndTimestamp.make(42L, 60000), 60000L); + } catch (final StreamsException exception) { + if (exception.getCause() instanceof ClassCastException) { + fail("Serdes are not correctly set from constructor parameters."); + } + throw exception; + } + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/MeteredWindowStoreTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/MeteredWindowStoreTest.java new file mode 100644 index 0000000..ca6a518 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/MeteredWindowStoreTest.java @@ -0,0 +1,486 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.common.MetricName; +import org.apache.kafka.common.metrics.JmxReporter; +import org.apache.kafka.common.metrics.KafkaMetric; +import org.apache.kafka.common.metrics.KafkaMetricsContext; +import org.apache.kafka.common.metrics.MetricConfig; +import org.apache.kafka.common.metrics.Metrics; +import org.apache.kafka.common.metrics.MetricsContext; +import org.apache.kafka.common.metrics.Sensor; +import org.apache.kafka.common.serialization.Deserializer; +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.serialization.Serializer; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.processor.ProcessorContext; +import org.apache.kafka.streams.processor.StateStoreContext; +import org.apache.kafka.streams.processor.internals.ProcessorStateManager; +import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl; +import org.apache.kafka.streams.state.WindowStore; +import org.apache.kafka.test.InternalMockProcessorContext; +import org.apache.kafka.test.MockRecordCollector; +import org.apache.kafka.test.StreamsTestUtils; +import org.apache.kafka.test.TestUtils; +import org.junit.Before; +import org.junit.Test; + +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +import static java.time.Instant.ofEpochMilli; +import static org.apache.kafka.common.utils.Utils.mkEntry; +import static org.apache.kafka.common.utils.Utils.mkMap; +import static org.easymock.EasyMock.anyObject; +import static org.easymock.EasyMock.createNiceMock; +import static org.easymock.EasyMock.eq; +import static org.easymock.EasyMock.expect; +import static org.easymock.EasyMock.expectLastCall; +import static org.easymock.EasyMock.mock; +import static org.easymock.EasyMock.niceMock; +import static org.easymock.EasyMock.replay; +import static org.easymock.EasyMock.verify; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.empty; +import static org.hamcrest.Matchers.greaterThan; +import static org.hamcrest.Matchers.not; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; + +public class MeteredWindowStoreTest { + + private static final String STORE_TYPE = "scope"; + private static final String STORE_LEVEL_GROUP = "stream-state-metrics"; + private static final String THREAD_ID_TAG_KEY = "thread-id"; + private static final String STORE_NAME = "mocked-store"; + private static final String CHANGELOG_TOPIC = "changelog-topic"; + private static final String KEY = "key"; + private static final Bytes KEY_BYTES = Bytes.wrap(KEY.getBytes()); + private static final String VALUE = "value"; + private static final byte[] VALUE_BYTES = VALUE.getBytes(); + private static final int WINDOW_SIZE_MS = 10; + private static final long TIMESTAMP = 42L; + + private final String threadId = Thread.currentThread().getName(); + private InternalMockProcessorContext context; + private final WindowStore innerStoreMock = createNiceMock(WindowStore.class); + private MeteredWindowStore store = new MeteredWindowStore<>( + innerStoreMock, + WINDOW_SIZE_MS, // any size + STORE_TYPE, 
+ new MockTime(), + Serdes.String(), + new SerdeThatDoesntHandleNull() + ); + private final Metrics metrics = new Metrics(new MetricConfig().recordLevel(Sensor.RecordingLevel.DEBUG)); + private Map tags; + + { + expect(innerStoreMock.name()).andReturn(STORE_NAME).anyTimes(); + } + + @Before + public void setUp() { + final StreamsMetricsImpl streamsMetrics = + new StreamsMetricsImpl(metrics, "test", StreamsConfig.METRICS_LATEST, new MockTime()); + context = new InternalMockProcessorContext<>( + TestUtils.tempDirectory(), + Serdes.String(), + Serdes.Long(), + streamsMetrics, + new StreamsConfig(StreamsTestUtils.getStreamsConfig()), + MockRecordCollector::new, + new ThreadCache(new LogContext("testCache "), 0, streamsMetrics), + Time.SYSTEM + ); + tags = mkMap( + mkEntry(THREAD_ID_TAG_KEY, threadId), + mkEntry("task-id", context.taskId().toString()), + mkEntry(STORE_TYPE + "-state-id", STORE_NAME) + ); + } + + @SuppressWarnings("deprecation") + @Test + public void shouldDelegateDeprecatedInit() { + final WindowStore inner = mock(WindowStore.class); + final MeteredWindowStore outer = new MeteredWindowStore<>( + inner, + WINDOW_SIZE_MS, // any size + STORE_TYPE, + new MockTime(), + Serdes.String(), + new SerdeThatDoesntHandleNull() + ); + expect(inner.name()).andStubReturn("store"); + inner.init((ProcessorContext) context, outer); + expectLastCall(); + replay(inner); + outer.init((ProcessorContext) context, outer); + verify(inner); + } + + @Test + public void shouldDelegateInit() { + final WindowStore inner = mock(WindowStore.class); + final MeteredWindowStore outer = new MeteredWindowStore<>( + inner, + WINDOW_SIZE_MS, // any size + STORE_TYPE, + new MockTime(), + Serdes.String(), + new SerdeThatDoesntHandleNull() + ); + expect(inner.name()).andStubReturn("store"); + inner.init((StateStoreContext) context, outer); + expectLastCall(); + replay(inner); + outer.init((StateStoreContext) context, outer); + verify(inner); + } + + @Test + public void shouldPassChangelogTopicNameToStateStoreSerde() { + context.addChangelogForStore(STORE_NAME, CHANGELOG_TOPIC); + doShouldPassChangelogTopicNameToStateStoreSerde(CHANGELOG_TOPIC); + } + + @Test + public void shouldPassDefaultChangelogTopicNameToStateStoreSerdeIfLoggingDisabled() { + final String defaultChangelogTopicName = + ProcessorStateManager.storeChangelogTopic(context.applicationId(), STORE_NAME, context.taskId().topologyName()); + doShouldPassChangelogTopicNameToStateStoreSerde(defaultChangelogTopicName); + } + + private void doShouldPassChangelogTopicNameToStateStoreSerde(final String topic) { + final Serde keySerde = niceMock(Serde.class); + final Serializer keySerializer = mock(Serializer.class); + final Serde valueSerde = niceMock(Serde.class); + final Deserializer valueDeserializer = mock(Deserializer.class); + final Serializer valueSerializer = mock(Serializer.class); + expect(keySerde.serializer()).andStubReturn(keySerializer); + expect(keySerializer.serialize(topic, KEY)).andStubReturn(KEY.getBytes()); + expect(valueSerde.deserializer()).andStubReturn(valueDeserializer); + expect(valueDeserializer.deserialize(topic, VALUE_BYTES)).andStubReturn(VALUE); + expect(valueSerde.serializer()).andStubReturn(valueSerializer); + expect(valueSerializer.serialize(topic, VALUE)).andStubReturn(VALUE_BYTES); + expect(innerStoreMock.fetch(KEY_BYTES, TIMESTAMP)).andStubReturn(VALUE_BYTES); + replay(innerStoreMock, keySerializer, keySerde, valueDeserializer, valueSerializer, valueSerde); + store = new MeteredWindowStore<>( + innerStoreMock, + WINDOW_SIZE_MS, 
+ STORE_TYPE, + new MockTime(), + keySerde, + valueSerde + ); + store.init((StateStoreContext) context, store); + + store.fetch(KEY, TIMESTAMP); + store.put(KEY, VALUE, TIMESTAMP); + + verify(keySerializer, valueDeserializer, valueSerializer); + } + + @Test + public void testMetrics() { + replay(innerStoreMock); + store.init((StateStoreContext) context, store); + final JmxReporter reporter = new JmxReporter(); + final MetricsContext metricsContext = new KafkaMetricsContext("kafka.streams"); + reporter.contextChange(metricsContext); + + metrics.addReporter(reporter); + assertTrue(reporter.containsMbean(String.format( + "kafka.streams:type=%s,%s=%s,task-id=%s,%s-state-id=%s", + STORE_LEVEL_GROUP, + THREAD_ID_TAG_KEY, + threadId, + context.taskId().toString(), + STORE_TYPE, + STORE_NAME + ))); + } + + @Test + public void shouldRecordRestoreLatencyOnInit() { + innerStoreMock.init((StateStoreContext) context, store); + replay(innerStoreMock); + store.init((StateStoreContext) context, store); + + // it suffices to verify one restore metric since all restore metrics are recorded by the same sensor + // and the sensor is tested elsewhere + final KafkaMetric metric = metric("restore-rate"); + assertThat((Double) metric.metricValue(), greaterThan(0.0)); + verify(innerStoreMock); + } + + @Test + public void shouldPutToInnerStoreAndRecordPutMetrics() { + final byte[] bytes = "a".getBytes(); + innerStoreMock.put(eq(Bytes.wrap(bytes)), anyObject(), eq(context.timestamp())); + replay(innerStoreMock); + + store.init((StateStoreContext) context, store); + store.put("a", "a", context.timestamp()); + + // it suffices to verify one put metric since all put metrics are recorded by the same sensor + // and the sensor is tested elsewhere + final KafkaMetric metric = metric("put-rate"); + assertThat((Double) metric.metricValue(), greaterThan(0.0)); + verify(innerStoreMock); + } + + @Test + public void shouldFetchFromInnerStoreAndRecordFetchMetrics() { + expect(innerStoreMock.fetch(Bytes.wrap("a".getBytes()), 1, 1)) + .andReturn(KeyValueIterators.emptyWindowStoreIterator()); + replay(innerStoreMock); + + store.init((StateStoreContext) context, store); + store.fetch("a", ofEpochMilli(1), ofEpochMilli(1)).close(); // recorded on close; + + // it suffices to verify one fetch metric since all fetch metrics are recorded by the same sensor + // and the sensor is tested elsewhere + final KafkaMetric metric = metric("fetch-rate"); + assertThat((Double) metric.metricValue(), greaterThan(0.0)); + verify(innerStoreMock); + } + + @Test + public void shouldFetchRangeFromInnerStoreAndRecordFetchMetrics() { + expect(innerStoreMock.fetch(Bytes.wrap("a".getBytes()), Bytes.wrap("b".getBytes()), 1, 1)) + .andReturn(KeyValueIterators.emptyIterator()); + expect(innerStoreMock.fetch(null, Bytes.wrap("b".getBytes()), 1, 1)) + .andReturn(KeyValueIterators.emptyIterator()); + expect(innerStoreMock.fetch(Bytes.wrap("a".getBytes()), null, 1, 1)) + .andReturn(KeyValueIterators.emptyIterator()); + expect(innerStoreMock.fetch(null, null, 1, 1)) + .andReturn(KeyValueIterators.emptyIterator()); + replay(innerStoreMock); + + store.init((StateStoreContext) context, store); + store.fetch("a", "b", ofEpochMilli(1), ofEpochMilli(1)).close(); // recorded on close; + store.fetch(null, "b", ofEpochMilli(1), ofEpochMilli(1)).close(); // recorded on close; + store.fetch("a", null, ofEpochMilli(1), ofEpochMilli(1)).close(); // recorded on close; + store.fetch(null, null, ofEpochMilli(1), ofEpochMilli(1)).close(); // recorded on close; + + // it suffices to 
verify one fetch metric since all fetch metrics are recorded by the same sensor + // and the sensor is tested elsewhere + final KafkaMetric metric = metric("fetch-rate"); + assertThat((Double) metric.metricValue(), greaterThan(0.0)); + verify(innerStoreMock); + } + + @Test + public void shouldBackwardFetchFromInnerStoreAndRecordFetchMetrics() { + expect(innerStoreMock.backwardFetch(Bytes.wrap("a".getBytes()), Bytes.wrap("b".getBytes()), 1, 1)) + .andReturn(KeyValueIterators.emptyIterator()); + replay(innerStoreMock); + + store.init((StateStoreContext) context, store); + store.backwardFetch("a", "b", ofEpochMilli(1), ofEpochMilli(1)).close(); // recorded on close; + + // it suffices to verify one fetch metric since all fetch metrics are recorded by the same sensor + // and the sensor is tested elsewhere + final KafkaMetric metric = metric("fetch-rate"); + assertThat((Double) metric.metricValue(), greaterThan(0.0)); + verify(innerStoreMock); + } + + @Test + public void shouldBackwardFetchRangeFromInnerStoreAndRecordFetchMetrics() { + expect(innerStoreMock.backwardFetch(Bytes.wrap("a".getBytes()), Bytes.wrap("b".getBytes()), 1, 1)) + .andReturn(KeyValueIterators.emptyIterator()); + expect(innerStoreMock.backwardFetch(null, Bytes.wrap("b".getBytes()), 1, 1)) + .andReturn(KeyValueIterators.emptyIterator()); + expect(innerStoreMock.backwardFetch(Bytes.wrap("a".getBytes()), null, 1, 1)) + .andReturn(KeyValueIterators.emptyIterator()); + expect(innerStoreMock.backwardFetch(null, null, 1, 1)) + .andReturn(KeyValueIterators.emptyIterator()); + replay(innerStoreMock); + + store.init((StateStoreContext) context, store); + store.backwardFetch("a", "b", ofEpochMilli(1), ofEpochMilli(1)).close(); // recorded on close; + store.backwardFetch(null, "b", ofEpochMilli(1), ofEpochMilli(1)).close(); // recorded on close; + store.backwardFetch("a", null, ofEpochMilli(1), ofEpochMilli(1)).close(); // recorded on close; + store.backwardFetch(null, null, ofEpochMilli(1), ofEpochMilli(1)).close(); // recorded on close; + + // it suffices to verify one fetch metric since all fetch metrics are recorded by the same sensor + // and the sensor is tested elsewhere + final KafkaMetric metric = metric("fetch-rate"); + assertThat((Double) metric.metricValue(), greaterThan(0.0)); + verify(innerStoreMock); + } + + @Test + public void shouldFetchAllFromInnerStoreAndRecordFetchMetrics() { + expect(innerStoreMock.fetchAll(1, 1)).andReturn(KeyValueIterators.emptyIterator()); + replay(innerStoreMock); + + store.init((StateStoreContext) context, store); + store.fetchAll(ofEpochMilli(1), ofEpochMilli(1)).close(); // recorded on close; + + // it suffices to verify one fetch metric since all fetch metrics are recorded by the same sensor + // and the sensor is tested elsewhere + final KafkaMetric metric = metric("fetch-rate"); + assertThat((Double) metric.metricValue(), greaterThan(0.0)); + verify(innerStoreMock); + } + + @Test + public void shouldBackwardFetchAllFromInnerStoreAndRecordFetchMetrics() { + expect(innerStoreMock.backwardFetchAll(1, 1)).andReturn(KeyValueIterators.emptyIterator()); + replay(innerStoreMock); + + store.init((StateStoreContext) context, store); + store.backwardFetchAll(ofEpochMilli(1), ofEpochMilli(1)).close(); // recorded on close; + + // it suffices to verify one fetch metric since all fetch metrics are recorded by the same sensor + // and the sensor is tested elsewhere + final KafkaMetric metric = metric("fetch-rate"); + assertThat((Double) metric.metricValue(), greaterThan(0.0)); + verify(innerStoreMock); 
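+        // Reviewer note, not part of the original change: as in the other fetch tests above, the returned
+        // iterator is closed before the metric is read because the fetch sensor is recorded when the metered
+        // iterator is closed (see the "recorded on close" comments), so the fetch-rate assertion relies on close().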
+ } + + @Test + public void shouldRecordFlushLatency() { + innerStoreMock.flush(); + replay(innerStoreMock); + + store.init((StateStoreContext) context, store); + store.flush(); + + // it suffices to verify one flush metric since all flush metrics are recorded by the same sensor + // and the sensor is tested elsewhere + final KafkaMetric metric = metric("flush-rate"); + assertTrue((Double) metric.metricValue() > 0); + verify(innerStoreMock); + } + + @Test + public void shouldNotThrowNullPointerExceptionIfFetchReturnsNull() { + expect(innerStoreMock.fetch(Bytes.wrap("a".getBytes()), 0)).andReturn(null); + replay(innerStoreMock); + + store.init((StateStoreContext) context, store); + assertNull(store.fetch("a", 0)); + } + + private interface CachedWindowStore extends WindowStore, CachedStateStore { + } + + @SuppressWarnings("unchecked") + @Test + public void shouldSetFlushListenerOnWrappedCachingStore() { + final CachedWindowStore cachedWindowStore = mock(CachedWindowStore.class); + + expect(cachedWindowStore.setFlushListener(anyObject(CacheFlushListener.class), eq(false))).andReturn(true); + replay(cachedWindowStore); + + final MeteredWindowStore metered = new MeteredWindowStore<>( + cachedWindowStore, + 10L, // any size + STORE_TYPE, + new MockTime(), + Serdes.String(), + new SerdeThatDoesntHandleNull() + ); + assertTrue(metered.setFlushListener(null, false)); + + verify(cachedWindowStore); + } + + @Test + public void shouldNotSetFlushListenerOnWrappedNoneCachingStore() { + assertFalse(store.setFlushListener(null, false)); + } + + @Test + public void shouldCloseUnderlyingStore() { + innerStoreMock.close(); + expectLastCall(); + replay(innerStoreMock); + store.init((StateStoreContext) context, store); + + store.close(); + verify(innerStoreMock); + } + + @Test + public void shouldRemoveMetricsOnClose() { + innerStoreMock.close(); + expectLastCall(); + replay(innerStoreMock); + store.init((StateStoreContext) context, store); + + assertThat(storeMetrics(), not(empty())); + store.close(); + assertThat(storeMetrics(), empty()); + verify(innerStoreMock); + } + + @Test + public void shouldRemoveMetricsEvenIfWrappedStoreThrowsOnClose() { + innerStoreMock.close(); + expectLastCall().andThrow(new RuntimeException("Oops!")); + replay(innerStoreMock); + store.init((StateStoreContext) context, store); + + // There's always a "count" metric registered + assertThat(storeMetrics(), not(empty())); + assertThrows(RuntimeException.class, store::close); + assertThat(storeMetrics(), empty()); + verify(innerStoreMock); + } + + @Test + public void shouldThrowNullPointerOnPutIfKeyIsNull() { + assertThrows(NullPointerException.class, () -> store.put(null, "a", 1L)); + } + + @Test + public void shouldThrowNullPointerOnFetchIfKeyIsNull() { + assertThrows(NullPointerException.class, () -> store.fetch(null, 0L, 1L)); + } + + @Test + public void shouldThrowNullPointerOnBackwardFetchIfKeyIsNull() { + assertThrows(NullPointerException.class, () -> store.backwardFetch(null, 0L, 1L)); + } + + private KafkaMetric metric(final String name) { + return metrics.metric(new MetricName(name, STORE_LEVEL_GROUP, "", tags)); + } + + private List storeMetrics() { + return metrics.metrics() + .keySet() + .stream() + .filter(name -> name.group().equals(STORE_LEVEL_GROUP) && name.tags().equals(tags)) + .collect(Collectors.toList()); + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/Murmur3Test.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/Murmur3Test.java new file mode 100644 index 
0000000..d0759bf --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/Murmur3Test.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.state.internals; + +import static org.junit.Assert.*; + +import org.junit.Test; + +import java.util.Map; + +/** + * This class was taken from Hive org.apache.hive.common.util; + * https://github.com/apache/hive/blob/master/storage-api/src/test/org/apache/hive/common/util/TestMurmur3.java + * Commit: dffa3a16588bc8e95b9d0ab5af295a74e06ef702 + * + * + * Tests for Murmur3 variants. + */ +public class Murmur3Test { + + @Test + public void testMurmur3_32() { + Map cases = new java.util.HashMap<>(); + cases.put("21".getBytes(), 896581614); + cases.put("foobar".getBytes(), -328928243); + cases.put("a-little-bit-long-string".getBytes(), -1479816207); + cases.put("a-little-bit-longer-string".getBytes(), -153232333); + cases.put("lkjh234lh9fiuh90y23oiuhsafujhadof229phr9h19h89h8".getBytes(), 13417721); + cases.put(new byte[]{'a', 'b', 'c'}, 461137560); + + int seed = 123; + for (Map.Entry c : cases.entrySet()) { + byte[] b = (byte[]) c.getKey(); + assertEquals(c.getValue(), Murmur3.hash32(b, b.length, seed)); + } + } + + @Test + public void testMurmur3_128() { + Map cases = new java.util.HashMap<>(); + cases.put("21".getBytes(), new long[]{5857341059704281894L, -5288187638297930763L}); + cases.put("foobar".getBytes(), new long[]{-351361463397418609L, 8959716011862540668L}); + cases.put("a-little-bit-long-string".getBytes(), new long[]{8836256500583638442L, -198172363548498523L}); + cases.put("a-little-bit-longer-string".getBytes(), new long[]{1838346159335108511L, 8794688210320490705L}); + cases.put("lkjh234lh9fiuh90y23oiuhsafujhadof229phr9h19h89h8".getBytes(), new long[]{-4024021876037397259L, -1482317706335141238L}); + cases.put(new byte[]{'a', 'b', 'c'}, new long[]{1489494923063836066L, -5440978547625122829L}); + + int seed = 123; + + for (Map.Entry c : cases.entrySet()) { + byte[] b = (byte[]) c.getKey(); + long[] result = Murmur3.hash128(b, 0, b.length, seed); + assertArrayEquals((long[]) c.getValue(), result); + } + } +} \ No newline at end of file diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/NamedCacheTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/NamedCacheTest.java new file mode 100644 index 0000000..6d43b4c --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/NamedCacheTest.java @@ -0,0 +1,250 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.common.header.Header; +import org.apache.kafka.common.header.Headers; +import org.apache.kafka.common.header.internals.RecordHeader; +import org.apache.kafka.common.header.internals.RecordHeaders; +import org.apache.kafka.common.metrics.Metrics; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.processor.internals.MockStreamsMetrics; +import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl; +import org.junit.Before; +import org.junit.Test; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertThrows; + +public class NamedCacheTest { + + private final Headers headers = new RecordHeaders(new Header[]{new RecordHeader("key", "value".getBytes())}); + private NamedCache cache; + + @Before + public void setUp() { + final Metrics innerMetrics = new Metrics(); + final StreamsMetricsImpl metrics = new MockStreamsMetrics(innerMetrics); + cache = new NamedCache("dummy-name", metrics); + } + + @Test + public void shouldKeepTrackOfMostRecentlyAndLeastRecentlyUsed() { + final List> toInsert = Arrays.asList( + new KeyValue<>("K1", "V1"), + new KeyValue<>("K2", "V2"), + new KeyValue<>("K3", "V3"), + new KeyValue<>("K4", "V4"), + new KeyValue<>("K5", "V5")); + for (final KeyValue stringStringKeyValue : toInsert) { + final byte[] key = stringStringKeyValue.key.getBytes(); + final byte[] value = stringStringKeyValue.value.getBytes(); + cache.put(Bytes.wrap(key), + new LRUCacheEntry(value, new RecordHeaders(), true, 1, 1, 1, "")); + final LRUCacheEntry head = cache.first(); + final LRUCacheEntry tail = cache.last(); + assertEquals(new String(head.value()), stringStringKeyValue.value); + assertEquals(new String(tail.value()), toInsert.get(0).value); + assertEquals(cache.flushes(), 0); + assertEquals(cache.hits(), 0); + assertEquals(cache.misses(), 0); + assertEquals(cache.overwrites(), 0); + } + } + + @Test + public void shouldKeepTrackOfSize() { + final LRUCacheEntry value = new LRUCacheEntry(new byte[]{0}); + cache.put(Bytes.wrap(new byte[]{0}), value); + cache.put(Bytes.wrap(new byte[]{1}), value); + cache.put(Bytes.wrap(new byte[]{2}), value); + final long size = cache.sizeInBytes(); + // 1 byte key + 24 bytes overhead + assertEquals((value.size() + 25) * 3, size); + } + + @Test + public void shouldPutGet() { + cache.put(Bytes.wrap(new byte[]{0}), new LRUCacheEntry(new byte[]{10})); + cache.put(Bytes.wrap(new byte[]{1}), new LRUCacheEntry(new byte[]{11})); + cache.put(Bytes.wrap(new byte[]{2}), new LRUCacheEntry(new byte[]{12})); + + assertArrayEquals(new byte[] 
{10}, cache.get(Bytes.wrap(new byte[] {0})).value()); + assertArrayEquals(new byte[] {11}, cache.get(Bytes.wrap(new byte[] {1})).value()); + assertArrayEquals(new byte[] {12}, cache.get(Bytes.wrap(new byte[] {2})).value()); + assertEquals(cache.hits(), 3); + } + + @Test + public void shouldPutIfAbsent() { + cache.put(Bytes.wrap(new byte[]{0}), new LRUCacheEntry(new byte[]{10})); + cache.putIfAbsent(Bytes.wrap(new byte[]{0}), new LRUCacheEntry(new byte[]{20})); + cache.putIfAbsent(Bytes.wrap(new byte[]{1}), new LRUCacheEntry(new byte[]{30})); + + assertArrayEquals(new byte[] {10}, cache.get(Bytes.wrap(new byte[] {0})).value()); + assertArrayEquals(new byte[] {30}, cache.get(Bytes.wrap(new byte[] {1})).value()); + } + + @Test + public void shouldDeleteAndUpdateSize() { + cache.put(Bytes.wrap(new byte[]{0}), new LRUCacheEntry(new byte[]{10})); + final LRUCacheEntry deleted = cache.delete(Bytes.wrap(new byte[]{0})); + assertArrayEquals(new byte[] {10}, deleted.value()); + assertEquals(0, cache.sizeInBytes()); + } + + @Test + public void shouldPutAll() { + cache.putAll(Arrays.asList(KeyValue.pair(new byte[] {0}, new LRUCacheEntry(new byte[]{0})), + KeyValue.pair(new byte[] {1}, new LRUCacheEntry(new byte[]{1})), + KeyValue.pair(new byte[] {2}, new LRUCacheEntry(new byte[]{2})))); + + assertArrayEquals(new byte[]{0}, cache.get(Bytes.wrap(new byte[]{0})).value()); + assertArrayEquals(new byte[]{1}, cache.get(Bytes.wrap(new byte[]{1})).value()); + assertArrayEquals(new byte[]{2}, cache.get(Bytes.wrap(new byte[]{2})).value()); + } + + @Test + public void shouldOverwriteAll() { + cache.putAll(Arrays.asList(KeyValue.pair(new byte[] {0}, new LRUCacheEntry(new byte[]{0})), + KeyValue.pair(new byte[] {0}, new LRUCacheEntry(new byte[]{1})), + KeyValue.pair(new byte[] {0}, new LRUCacheEntry(new byte[]{2})))); + + assertArrayEquals(new byte[]{2}, cache.get(Bytes.wrap(new byte[]{0})).value()); + assertEquals(cache.overwrites(), 2); + } + + @Test + public void shouldEvictEldestEntry() { + cache.put(Bytes.wrap(new byte[]{0}), new LRUCacheEntry(new byte[]{10})); + cache.put(Bytes.wrap(new byte[]{1}), new LRUCacheEntry(new byte[]{20})); + cache.put(Bytes.wrap(new byte[]{2}), new LRUCacheEntry(new byte[]{30})); + + cache.evict(); + assertNull(cache.get(Bytes.wrap(new byte[]{0}))); + assertEquals(2, cache.size()); + } + + @Test + public void shouldFlushDirtEntriesOnEviction() { + final List flushed = new ArrayList<>(); + cache.put(Bytes.wrap(new byte[]{0}), new LRUCacheEntry(new byte[]{10}, headers, true, 0, 0, 0, "")); + cache.put(Bytes.wrap(new byte[]{1}), new LRUCacheEntry(new byte[]{20})); + cache.put(Bytes.wrap(new byte[]{2}), new LRUCacheEntry(new byte[]{30}, headers, true, 0, 0, 0, "")); + + cache.setListener(flushed::addAll); + + cache.evict(); + + assertEquals(2, flushed.size()); + assertEquals(Bytes.wrap(new byte[] {0}), flushed.get(0).key()); + assertEquals(headers, flushed.get(0).entry().context().headers()); + assertArrayEquals(new byte[] {10}, flushed.get(0).newValue()); + assertEquals(Bytes.wrap(new byte[] {2}), flushed.get(1).key()); + assertArrayEquals(new byte[] {30}, flushed.get(1).newValue()); + assertEquals(cache.flushes(), 1); + } + + @Test + public void shouldNotThrowNullPointerWhenCacheIsEmptyAndEvictionCalled() { + cache.evict(); + } + + @Test + public void shouldThrowIllegalStateExceptionWhenTryingToOverwriteDirtyEntryWithCleanEntry() { + cache.put(Bytes.wrap(new byte[]{0}), new LRUCacheEntry(new byte[]{10}, headers, true, 0, 0, 0, "")); + assertThrows(IllegalStateException.class, () -> 
cache.put(Bytes.wrap(new byte[]{0}), + new LRUCacheEntry(new byte[]{10}, new RecordHeaders(), false, 0, 0, 0, ""))); + } + + @Test + public void shouldRemoveDeletedValuesOnFlush() { + cache.setListener(dirty -> { /* no-op */ }); + cache.put(Bytes.wrap(new byte[]{0}), new LRUCacheEntry(null, headers, true, 0, 0, 0, "")); + cache.put(Bytes.wrap(new byte[]{1}), new LRUCacheEntry(new byte[]{20}, new RecordHeaders(), true, 0, 0, 0, "")); + cache.flush(); + assertEquals(1, cache.size()); + assertNotNull(cache.get(Bytes.wrap(new byte[]{1}))); + } + + @Test + public void shouldBeReentrantAndNotBreakLRU() { + final LRUCacheEntry dirty = new LRUCacheEntry(new byte[]{3}, new RecordHeaders(), true, 0, 0, 0, ""); + final LRUCacheEntry clean = new LRUCacheEntry(new byte[]{3}); + cache.put(Bytes.wrap(new byte[]{0}), dirty); + cache.put(Bytes.wrap(new byte[]{1}), clean); + cache.put(Bytes.wrap(new byte[]{2}), clean); + assertEquals(3 * cache.head().size(), cache.sizeInBytes()); + cache.setListener(dirty1 -> { + cache.put(Bytes.wrap(new byte[]{3}), clean); + // evict key 1 + cache.evict(); + // evict key 2 + cache.evict(); + }); + + assertEquals(3 * cache.head().size(), cache.sizeInBytes()); + // Evict key 0 + cache.evict(); + final Bytes entryFour = Bytes.wrap(new byte[]{4}); + cache.put(entryFour, dirty); + + // check that the LRU is still correct + final NamedCache.LRUNode head = cache.head(); + final NamedCache.LRUNode tail = cache.tail(); + assertEquals(2, cache.size()); + assertEquals(2 * head.size(), cache.sizeInBytes()); + // dirty should be the newest + assertEquals(entryFour, head.key()); + assertEquals(Bytes.wrap(new byte[] {3}), tail.key()); + assertSame(tail, head.next()); + assertNull(head.previous()); + assertSame(head, tail.previous()); + assertNull(tail.next()); + + // evict key 3 + cache.evict(); + assertSame(cache.head(), cache.tail()); + assertEquals(entryFour, cache.head().key()); + assertNull(cache.head().next()); + assertNull(cache.head().previous()); + } + + @Test + public void shouldNotThrowIllegalArgumentAfterEvictingDirtyRecordAndThenPuttingNewRecordWithSameKey() { + final LRUCacheEntry dirty = new LRUCacheEntry(new byte[]{3}, new RecordHeaders(), true, 0, 0, 0, ""); + final LRUCacheEntry clean = new LRUCacheEntry(new byte[]{3}); + final Bytes key = Bytes.wrap(new byte[] {3}); + cache.setListener(dirty1 -> cache.put(key, clean)); + cache.put(key, dirty); + cache.evict(); + } + + @Test + public void shouldReturnNullIfKeyIsNull() { + assertNull(cache.get(null)); + } +} \ No newline at end of file diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/OffsetCheckpointTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/OffsetCheckpointTest.java new file mode 100644 index 0000000..9970a1b --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/OffsetCheckpointTest.java @@ -0,0 +1,189 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.state.internals; + +import java.io.BufferedWriter; +import java.io.File; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.OutputStreamWriter; +import java.nio.charset.StandardCharsets; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.test.TestUtils; +import org.junit.Test; + +import static org.apache.kafka.streams.state.internals.OffsetCheckpoint.writeEntry; +import static org.apache.kafka.streams.state.internals.OffsetCheckpoint.writeIntLine; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.CoreMatchers.containsString; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; + +public class OffsetCheckpointTest { + + private final String topic = "topic"; + + @Test + public void testReadWrite() throws IOException { + final File f = TestUtils.tempFile(); + final OffsetCheckpoint checkpoint = new OffsetCheckpoint(f); + + try { + final Map offsets = new HashMap<>(); + offsets.put(new TopicPartition(topic, 0), 0L); + offsets.put(new TopicPartition(topic, 1), 1L); + offsets.put(new TopicPartition(topic, 2), 2L); + + checkpoint.write(offsets); + assertEquals(offsets, checkpoint.read()); + + checkpoint.delete(); + assertFalse(f.exists()); + + offsets.put(new TopicPartition(topic, 3), 3L); + checkpoint.write(offsets); + assertEquals(offsets, checkpoint.read()); + } finally { + checkpoint.delete(); + } + } + + @Test + public void shouldNotWriteCheckpointWhenNoOffsets() throws IOException { + // we do not need to worry about file name uniqueness since this file should not be created + final File f = new File(TestUtils.tempDirectory().getAbsolutePath(), "kafka.tmp"); + final OffsetCheckpoint checkpoint = new OffsetCheckpoint(f); + + checkpoint.write(Collections.emptyMap()); + + assertFalse(f.exists()); + + assertEquals(Collections.emptyMap(), checkpoint.read()); + + // deleting a non-exist checkpoint file should be fine + checkpoint.delete(); + } + + @Test + public void shouldDeleteExistingCheckpointWhenNoOffsets() throws IOException { + final File file = TestUtils.tempFile(); + final OffsetCheckpoint checkpoint = new OffsetCheckpoint(file); + + final Map offsets = Collections.singletonMap(new TopicPartition(topic, 0), 1L); + + checkpoint.write(offsets); + + assertThat(file.exists(), is(true)); + assertThat(offsets, is(checkpoint.read())); + + checkpoint.write(Collections.emptyMap()); + + assertThat(file.exists(), is(false)); + assertThat(Collections.emptyMap(), is(checkpoint.read())); + } + + @Test + public void shouldSkipInvalidOffsetsDuringRead() throws IOException { + final File file = TestUtils.tempFile(); + final OffsetCheckpoint checkpoint = new OffsetCheckpoint(file); + + try { + final Map offsets = new HashMap<>(); + offsets.put(new TopicPartition(topic, 0), -1L); + 
+ writeVersion0(offsets, file); + assertTrue(checkpoint.read().isEmpty()); + } finally { + checkpoint.delete(); + } + } + + @Test + public void shouldReadAndWriteSentinelOffset() throws IOException { + final File f = TestUtils.tempFile(); + final OffsetCheckpoint checkpoint = new OffsetCheckpoint(f); + final long sentinelOffset = -4L; + + try { + final Map offsetsToWrite = new HashMap<>(); + offsetsToWrite.put(new TopicPartition(topic, 1), sentinelOffset); + checkpoint.write(offsetsToWrite); + + final Map readOffsets = checkpoint.read(); + assertThat(readOffsets.get(new TopicPartition(topic, 1)), equalTo(sentinelOffset)); + } finally { + checkpoint.delete(); + } + } + + @Test + public void shouldThrowOnInvalidOffsetInWrite() throws IOException { + final File f = TestUtils.tempFile(); + final OffsetCheckpoint checkpoint = new OffsetCheckpoint(f); + + try { + final Map offsets = new HashMap<>(); + offsets.put(new TopicPartition(topic, 0), 0L); + offsets.put(new TopicPartition(topic, 1), -1L); // invalid + offsets.put(new TopicPartition(topic, 2), 2L); + + assertThrows(IllegalStateException.class, () -> checkpoint.write(offsets)); + } finally { + checkpoint.delete(); + } + } + + @Test + public void shouldThrowIOExceptionWhenWritingToNotExistedFile() { + final Map offsetsToWrite = Collections.singletonMap(new TopicPartition(topic, 0), 0L); + + final File notExistedFile = new File("/not_existed_dir/not_existed_file"); + final OffsetCheckpoint checkpoint = new OffsetCheckpoint(notExistedFile); + + final IOException e = assertThrows(IOException.class, () -> checkpoint.write(offsetsToWrite)); + assertThat(e.getMessage(), containsString("No such file or directory")); + } + + /** + * Write all the offsets following the version 0 format without any verification (eg enforcing offsets >= 0) + */ + static void writeVersion0(final Map offsets, final File file) throws IOException { + final FileOutputStream fileOutputStream = new FileOutputStream(file); + try (final BufferedWriter writer = new BufferedWriter( + new OutputStreamWriter(fileOutputStream, StandardCharsets.UTF_8))) { + writeIntLine(writer, 0); + writeIntLine(writer, offsets.size()); + + for (final Map.Entry entry : offsets.entrySet()) { + final TopicPartition tp = entry.getKey(); + final Long offset = entry.getValue(); + writeEntry(writer, tp, offset); + } + + writer.flush(); + fileOutputStream.getFD().sync(); + } + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/QueryableStoreProviderTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/QueryableStoreProviderTest.java new file mode 100644 index 0000000..79cfb65 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/QueryableStoreProviderTest.java @@ -0,0 +1,133 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.state.internals; + + +import org.apache.kafka.streams.StoreQueryParameters; +import org.apache.kafka.streams.errors.InvalidStateStoreException; +import org.apache.kafka.streams.processor.StateStore; +import org.apache.kafka.streams.state.NoOpWindowStore; +import org.apache.kafka.streams.state.QueryableStoreTypes; +import org.apache.kafka.test.NoOpReadOnlyStore; +import org.apache.kafka.test.StateStoreProviderStub; +import org.junit.Before; +import org.junit.Test; + +import java.util.HashMap; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.core.IsEqual.equalTo; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertNotNull; + +public class QueryableStoreProviderTest { + + private final String keyValueStore = "key-value"; + private final String windowStore = "window-store"; + private QueryableStoreProvider storeProvider; + private HashMap globalStateStores; + private final int numStateStorePartitions = 2; + + @Before + public void before() { + final StateStoreProviderStub theStoreProvider = new StateStoreProviderStub(false); + for (int partition = 0; partition < numStateStorePartitions; partition++) { + theStoreProvider.addStore(keyValueStore, partition, new NoOpReadOnlyStore<>()); + theStoreProvider.addStore(windowStore, partition, new NoOpWindowStore()); + } + globalStateStores = new HashMap<>(); + storeProvider = + new QueryableStoreProvider( + new GlobalStateStoreProvider(globalStateStores) + ); + storeProvider.addStoreProviderForThread("thread1", theStoreProvider); + } + + @Test + public void shouldThrowExceptionIfKVStoreDoesntExist() { + assertThrows(InvalidStateStoreException.class, () -> storeProvider.getStore( + StoreQueryParameters.fromNameAndType("not-a-store", QueryableStoreTypes.keyValueStore())).get("1")); + } + + @Test + public void shouldThrowExceptionIfWindowStoreDoesntExist() { + assertThrows(InvalidStateStoreException.class, () -> storeProvider.getStore( + StoreQueryParameters.fromNameAndType("not-a-store", QueryableStoreTypes.windowStore())).fetch("1", System.currentTimeMillis())); + } + + @Test + public void shouldReturnKVStoreWhenItExists() { + assertNotNull(storeProvider.getStore(StoreQueryParameters.fromNameAndType(keyValueStore, QueryableStoreTypes.keyValueStore()))); + } + + @Test + public void shouldReturnWindowStoreWhenItExists() { + assertNotNull(storeProvider.getStore(StoreQueryParameters.fromNameAndType(windowStore, QueryableStoreTypes.windowStore()))); + } + + @Test + public void shouldThrowExceptionWhenLookingForWindowStoreWithDifferentType() { + assertThrows(InvalidStateStoreException.class, () -> storeProvider.getStore(StoreQueryParameters.fromNameAndType(windowStore, + QueryableStoreTypes.keyValueStore())).get("1")); + } + + @Test + public void shouldThrowExceptionWhenLookingForKVStoreWithDifferentType() { + assertThrows(InvalidStateStoreException.class, () -> storeProvider.getStore(StoreQueryParameters.fromNameAndType(keyValueStore, + QueryableStoreTypes.windowStore())).fetch("1", System.currentTimeMillis())); + } + + @Test + public void shouldFindGlobalStores() { + globalStateStores.put("global", new NoOpReadOnlyStore<>()); + assertNotNull(storeProvider.getStore(StoreQueryParameters.fromNameAndType("global", QueryableStoreTypes.keyValueStore()))); + } + + @Test + public void shouldReturnKVStoreWithPartitionWhenItExists() { + 
assertNotNull(storeProvider.getStore(StoreQueryParameters.fromNameAndType(keyValueStore, QueryableStoreTypes.keyValueStore()).withPartition(numStateStorePartitions - 1))); + } + + @Test + public void shouldThrowExceptionWhenKVStoreWithPartitionDoesntExists() { + final int partition = numStateStorePartitions + 1; + final InvalidStateStoreException thrown = assertThrows(InvalidStateStoreException.class, () -> + storeProvider.getStore( + StoreQueryParameters + .fromNameAndType(keyValueStore, QueryableStoreTypes.keyValueStore()) + .withPartition(partition)).get("1") + ); + assertThat(thrown.getMessage(), equalTo(String.format("The specified partition %d for store %s does not exist.", partition, keyValueStore))); + } + + @Test + public void shouldReturnWindowStoreWithPartitionWhenItExists() { + assertNotNull(storeProvider.getStore(StoreQueryParameters.fromNameAndType(windowStore, QueryableStoreTypes.windowStore()).withPartition(numStateStorePartitions - 1))); + } + + @Test + public void shouldThrowExceptionWhenWindowStoreWithPartitionDoesntExists() { + final int partition = numStateStorePartitions + 1; + final InvalidStateStoreException thrown = assertThrows(InvalidStateStoreException.class, () -> + storeProvider.getStore( + StoreQueryParameters + .fromNameAndType(windowStore, QueryableStoreTypes.windowStore()) + .withPartition(partition)).fetch("1", System.currentTimeMillis()) + ); + assertThat(thrown.getMessage(), equalTo(String.format("The specified partition %d for store %s does not exist.", partition, windowStore))); + } +} \ No newline at end of file diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/ReadOnlyKeyValueStoreFacadeTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/ReadOnlyKeyValueStoreFacadeTest.java new file mode 100644 index 0000000..ffb5ab3 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/ReadOnlyKeyValueStoreFacadeTest.java @@ -0,0 +1,115 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.common.serialization.StringSerializer; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.state.KeyValueIterator; +import org.apache.kafka.streams.state.TimestampedKeyValueStore; +import org.apache.kafka.streams.state.ValueAndTimestamp; +import org.easymock.EasyMockRunner; +import org.easymock.Mock; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; + +import static org.easymock.EasyMock.expect; +import static org.easymock.EasyMock.replay; +import static org.easymock.EasyMock.verify; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.is; +import static org.junit.Assert.assertNull; + +@RunWith(EasyMockRunner.class) +public class ReadOnlyKeyValueStoreFacadeTest { + @Mock + private TimestampedKeyValueStore mockedKeyValueTimestampStore; + @Mock + private KeyValueIterator> mockedKeyValueTimestampIterator; + + private ReadOnlyKeyValueStoreFacade readOnlyKeyValueStoreFacade; + + @Before + public void setup() { + readOnlyKeyValueStoreFacade = new ReadOnlyKeyValueStoreFacade<>(mockedKeyValueTimestampStore); + } + + @Test + public void shouldReturnPlainValueOnGet() { + expect(mockedKeyValueTimestampStore.get("key")) + .andReturn(ValueAndTimestamp.make("value", 42L)); + expect(mockedKeyValueTimestampStore.get("unknownKey")) + .andReturn(null); + replay(mockedKeyValueTimestampStore); + + assertThat(readOnlyKeyValueStoreFacade.get("key"), is("value")); + assertNull(readOnlyKeyValueStoreFacade.get("unknownKey")); + verify(mockedKeyValueTimestampStore); + } + + @Test + public void shouldReturnPlainKeyValuePairsForRangeIterator() { + expect(mockedKeyValueTimestampIterator.next()) + .andReturn(KeyValue.pair("key1", ValueAndTimestamp.make("value1", 21L))) + .andReturn(KeyValue.pair("key2", ValueAndTimestamp.make("value2", 42L))); + expect(mockedKeyValueTimestampStore.range("key1", "key2")).andReturn(mockedKeyValueTimestampIterator); + replay(mockedKeyValueTimestampIterator, mockedKeyValueTimestampStore); + + final KeyValueIterator iterator = readOnlyKeyValueStoreFacade.range("key1", "key2"); + assertThat(iterator.next(), is(KeyValue.pair("key1", "value1"))); + assertThat(iterator.next(), is(KeyValue.pair("key2", "value2"))); + verify(mockedKeyValueTimestampIterator, mockedKeyValueTimestampStore); + } + + @Test + public void shouldReturnPlainKeyValuePairsForPrefixScan() { + final StringSerializer stringSerializer = new StringSerializer(); + expect(mockedKeyValueTimestampIterator.next()) + .andReturn(KeyValue.pair("key1", ValueAndTimestamp.make("value1", 21L))) + .andReturn(KeyValue.pair("key2", ValueAndTimestamp.make("value2", 42L))); + expect(mockedKeyValueTimestampStore.prefixScan("key", stringSerializer)).andReturn(mockedKeyValueTimestampIterator); + replay(mockedKeyValueTimestampIterator, mockedKeyValueTimestampStore); + + final KeyValueIterator iterator = readOnlyKeyValueStoreFacade.prefixScan("key", stringSerializer); + assertThat(iterator.next(), is(KeyValue.pair("key1", "value1"))); + assertThat(iterator.next(), is(KeyValue.pair("key2", "value2"))); + verify(mockedKeyValueTimestampIterator, mockedKeyValueTimestampStore); + } + + @Test + public void shouldReturnPlainKeyValuePairsForAllIterator() { + expect(mockedKeyValueTimestampIterator.next()) + .andReturn(KeyValue.pair("key1", ValueAndTimestamp.make("value1", 21L))) + .andReturn(KeyValue.pair("key2", ValueAndTimestamp.make("value2", 42L))); + 
expect(mockedKeyValueTimestampStore.all()).andReturn(mockedKeyValueTimestampIterator); + replay(mockedKeyValueTimestampIterator, mockedKeyValueTimestampStore); + + final KeyValueIterator iterator = readOnlyKeyValueStoreFacade.all(); + assertThat(iterator.next(), is(KeyValue.pair("key1", "value1"))); + assertThat(iterator.next(), is(KeyValue.pair("key2", "value2"))); + verify(mockedKeyValueTimestampIterator, mockedKeyValueTimestampStore); + } + + @Test + public void shouldForwardApproximateNumEntries() { + expect(mockedKeyValueTimestampStore.approximateNumEntries()).andReturn(42L); + replay(mockedKeyValueTimestampStore); + + assertThat(readOnlyKeyValueStoreFacade.approximateNumEntries(), is(42L)); + verify(mockedKeyValueTimestampStore); + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/ReadOnlyWindowStoreFacadeTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/ReadOnlyWindowStoreFacadeTest.java new file mode 100644 index 0000000..fa8128c --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/ReadOnlyWindowStoreFacadeTest.java @@ -0,0 +1,207 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.kstream.Windowed; +import org.apache.kafka.streams.kstream.internals.TimeWindow; +import org.apache.kafka.streams.state.KeyValueIterator; +import org.apache.kafka.streams.state.TimestampedWindowStore; +import org.apache.kafka.streams.state.ValueAndTimestamp; +import org.apache.kafka.streams.state.WindowStoreIterator; +import org.easymock.EasyMockRunner; +import org.easymock.Mock; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; + +import java.time.Instant; + +import static org.easymock.EasyMock.expect; +import static org.easymock.EasyMock.replay; +import static org.easymock.EasyMock.verify; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.is; +import static org.junit.Assert.assertNull; + +@RunWith(EasyMockRunner.class) +public class ReadOnlyWindowStoreFacadeTest { + @Mock + private TimestampedWindowStore mockedWindowTimestampStore; + @Mock + private WindowStoreIterator> mockedWindowTimestampIterator; + @Mock + private KeyValueIterator, ValueAndTimestamp> mockedKeyValueWindowTimestampIterator; + + private ReadOnlyWindowStoreFacade readOnlyWindowStoreFacade; + + @Before + public void setup() { + readOnlyWindowStoreFacade = new ReadOnlyWindowStoreFacade<>(mockedWindowTimestampStore); + } + + @Test + public void shouldReturnPlainKeyValuePairsOnSingleKeyFetch() { + expect(mockedWindowTimestampStore.fetch("key1", 21L)) + .andReturn(ValueAndTimestamp.make("value1", 42L)); + expect(mockedWindowTimestampStore.fetch("unknownKey", 21L)) + .andReturn(null); + replay(mockedWindowTimestampStore); + + assertThat(readOnlyWindowStoreFacade.fetch("key1", 21L), is("value1")); + assertNull(readOnlyWindowStoreFacade.fetch("unknownKey", 21L)); + + verify(mockedWindowTimestampStore); + } + + @Test + public void shouldReturnPlainKeyValuePairsOnSingleKeyFetchLongParameters() { + expect(mockedWindowTimestampIterator.next()) + .andReturn(KeyValue.pair(21L, ValueAndTimestamp.make("value1", 22L))) + .andReturn(KeyValue.pair(42L, ValueAndTimestamp.make("value2", 23L))); + expect(mockedWindowTimestampStore.fetch("key1", Instant.ofEpochMilli(21L), Instant.ofEpochMilli(42L))) + .andReturn(mockedWindowTimestampIterator); + replay(mockedWindowTimestampIterator, mockedWindowTimestampStore); + + final WindowStoreIterator iterator = + readOnlyWindowStoreFacade.fetch("key1", Instant.ofEpochMilli(21L), Instant.ofEpochMilli(42L)); + + assertThat(iterator.next(), is(KeyValue.pair(21L, "value1"))); + assertThat(iterator.next(), is(KeyValue.pair(42L, "value2"))); + verify(mockedWindowTimestampIterator, mockedWindowTimestampStore); + } + + @Test + public void shouldReturnPlainKeyValuePairsOnSingleKeyFetchInstantParameters() { + expect(mockedWindowTimestampIterator.next()) + .andReturn(KeyValue.pair(21L, ValueAndTimestamp.make("value1", 22L))) + .andReturn(KeyValue.pair(42L, ValueAndTimestamp.make("value2", 23L))); + expect(mockedWindowTimestampStore.fetch("key1", Instant.ofEpochMilli(21L), Instant.ofEpochMilli(42L))) + .andReturn(mockedWindowTimestampIterator); + replay(mockedWindowTimestampIterator, mockedWindowTimestampStore); + + final WindowStoreIterator iterator = + readOnlyWindowStoreFacade.fetch("key1", Instant.ofEpochMilli(21L), Instant.ofEpochMilli(42L)); + + assertThat(iterator.next(), is(KeyValue.pair(21L, "value1"))); + assertThat(iterator.next(), is(KeyValue.pair(42L, "value2"))); + 
verify(mockedWindowTimestampIterator, mockedWindowTimestampStore); + } + + @Test + public void shouldReturnPlainKeyValuePairsOnRangeFetchLongParameters() { + expect(mockedKeyValueWindowTimestampIterator.next()) + .andReturn(KeyValue.pair( + new Windowed<>("key1", new TimeWindow(21L, 22L)), + ValueAndTimestamp.make("value1", 22L))) + .andReturn(KeyValue.pair( + new Windowed<>("key2", new TimeWindow(42L, 43L)), + ValueAndTimestamp.make("value2", 100L))); + expect(mockedWindowTimestampStore.fetch("key1", "key2", Instant.ofEpochMilli(21L), Instant.ofEpochMilli(42L))) + .andReturn(mockedKeyValueWindowTimestampIterator); + replay(mockedKeyValueWindowTimestampIterator, mockedWindowTimestampStore); + + final KeyValueIterator, String> iterator = + readOnlyWindowStoreFacade.fetch("key1", "key2", Instant.ofEpochMilli(21L), Instant.ofEpochMilli(42L)); + + assertThat(iterator.next(), is(KeyValue.pair(new Windowed<>("key1", new TimeWindow(21L, 22L)), "value1"))); + assertThat(iterator.next(), is(KeyValue.pair(new Windowed<>("key2", new TimeWindow(42L, 43L)), "value2"))); + verify(mockedKeyValueWindowTimestampIterator, mockedWindowTimestampStore); + } + + @Test + public void shouldReturnPlainKeyValuePairsOnRangeFetchInstantParameters() { + expect(mockedKeyValueWindowTimestampIterator.next()) + .andReturn(KeyValue.pair( + new Windowed<>("key1", new TimeWindow(21L, 22L)), + ValueAndTimestamp.make("value1", 22L))) + .andReturn(KeyValue.pair( + new Windowed<>("key2", new TimeWindow(42L, 43L)), + ValueAndTimestamp.make("value2", 100L))); + expect(mockedWindowTimestampStore.fetch("key1", "key2", Instant.ofEpochMilli(21L), Instant.ofEpochMilli(42L))) + .andReturn(mockedKeyValueWindowTimestampIterator); + replay(mockedKeyValueWindowTimestampIterator, mockedWindowTimestampStore); + + final KeyValueIterator, String> iterator = + readOnlyWindowStoreFacade.fetch("key1", "key2", Instant.ofEpochMilli(21L), Instant.ofEpochMilli(42L)); + + assertThat(iterator.next(), is(KeyValue.pair(new Windowed<>("key1", new TimeWindow(21L, 22L)), "value1"))); + assertThat(iterator.next(), is(KeyValue.pair(new Windowed<>("key2", new TimeWindow(42L, 43L)), "value2"))); + verify(mockedKeyValueWindowTimestampIterator, mockedWindowTimestampStore); + } + + @Test + public void shouldReturnPlainKeyValuePairsOnFetchAllLongParameters() { + expect(mockedKeyValueWindowTimestampIterator.next()) + .andReturn(KeyValue.pair( + new Windowed<>("key1", new TimeWindow(21L, 22L)), + ValueAndTimestamp.make("value1", 22L))) + .andReturn(KeyValue.pair( + new Windowed<>("key2", new TimeWindow(42L, 43L)), + ValueAndTimestamp.make("value2", 100L))); + expect(mockedWindowTimestampStore.fetchAll(Instant.ofEpochMilli(21L), Instant.ofEpochMilli(42L))) + .andReturn(mockedKeyValueWindowTimestampIterator); + replay(mockedKeyValueWindowTimestampIterator, mockedWindowTimestampStore); + + final KeyValueIterator, String> iterator = + readOnlyWindowStoreFacade.fetchAll(Instant.ofEpochMilli(21L), Instant.ofEpochMilli(42L)); + + assertThat(iterator.next(), is(KeyValue.pair(new Windowed<>("key1", new TimeWindow(21L, 22L)), "value1"))); + assertThat(iterator.next(), is(KeyValue.pair(new Windowed<>("key2", new TimeWindow(42L, 43L)), "value2"))); + verify(mockedKeyValueWindowTimestampIterator, mockedWindowTimestampStore); + } + + @Test + public void shouldReturnPlainKeyValuePairsOnFetchAllInstantParameters() { + expect(mockedKeyValueWindowTimestampIterator.next()) + .andReturn(KeyValue.pair( + new Windowed<>("key1", new TimeWindow(21L, 22L)), + ValueAndTimestamp.make("value1", 
22L))) + .andReturn(KeyValue.pair( + new Windowed<>("key2", new TimeWindow(42L, 43L)), + ValueAndTimestamp.make("value2", 100L))); + expect(mockedWindowTimestampStore.fetchAll(Instant.ofEpochMilli(21L), Instant.ofEpochMilli(42L))) + .andReturn(mockedKeyValueWindowTimestampIterator); + replay(mockedKeyValueWindowTimestampIterator, mockedWindowTimestampStore); + + final KeyValueIterator, String> iterator = + readOnlyWindowStoreFacade.fetchAll(Instant.ofEpochMilli(21L), Instant.ofEpochMilli(42L)); + + assertThat(iterator.next(), is(KeyValue.pair(new Windowed<>("key1", new TimeWindow(21L, 22L)), "value1"))); + assertThat(iterator.next(), is(KeyValue.pair(new Windowed<>("key2", new TimeWindow(42L, 43L)), "value2"))); + verify(mockedKeyValueWindowTimestampIterator, mockedWindowTimestampStore); + } + + @Test + public void shouldReturnPlainKeyValuePairsOnAll() { + expect(mockedKeyValueWindowTimestampIterator.next()) + .andReturn(KeyValue.pair( + new Windowed<>("key1", new TimeWindow(21L, 22L)), + ValueAndTimestamp.make("value1", 22L))) + .andReturn(KeyValue.pair( + new Windowed<>("key2", new TimeWindow(42L, 43L)), + ValueAndTimestamp.make("value2", 100L))); + expect(mockedWindowTimestampStore.all()).andReturn(mockedKeyValueWindowTimestampIterator); + replay(mockedKeyValueWindowTimestampIterator, mockedWindowTimestampStore); + + final KeyValueIterator, String> iterator = readOnlyWindowStoreFacade.all(); + + assertThat(iterator.next(), is(KeyValue.pair(new Windowed<>("key1", new TimeWindow(21L, 22L)), "value1"))); + assertThat(iterator.next(), is(KeyValue.pair(new Windowed<>("key2", new TimeWindow(42L, 43L)), "value2"))); + verify(mockedKeyValueWindowTimestampIterator, mockedWindowTimestampStore); + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/ReadOnlyWindowStoreStub.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/ReadOnlyWindowStoreStub.java new file mode 100644 index 0000000..752334d --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/ReadOnlyWindowStoreStub.java @@ -0,0 +1,433 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.errors.InvalidStateStoreException; +import org.apache.kafka.streams.internals.ApiUtils; +import org.apache.kafka.streams.kstream.Windowed; +import org.apache.kafka.streams.kstream.internals.TimeWindow; +import org.apache.kafka.streams.processor.ProcessorContext; +import org.apache.kafka.streams.processor.StateStore; +import org.apache.kafka.streams.state.KeyValueIterator; +import org.apache.kafka.streams.state.ReadOnlyWindowStore; +import org.apache.kafka.streams.state.WindowStoreIterator; + +import java.time.Instant; +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; +import java.util.NavigableMap; +import java.util.TreeMap; + +import static org.apache.kafka.streams.internals.ApiUtils.prepareMillisCheckFailMsgPrefix; + +/** + * A very simple window store stub for testing purposes. + */ +public class ReadOnlyWindowStoreStub implements ReadOnlyWindowStore, StateStore { + + private final long windowSize; + private final NavigableMap> data = new TreeMap<>(); + private boolean open = true; + + ReadOnlyWindowStoreStub(final long windowSize) { + this.windowSize = windowSize; + } + + @Override + public V fetch(final K key, final long time) { + final Map kvMap = data.get(time); + if (kvMap != null) { + return kvMap.get(key); + } else { + return null; + } + } + + @Override + public WindowStoreIterator fetch(final K key, final Instant timeFrom, final Instant timeTo) { + if (!open) { + throw new InvalidStateStoreException("Store is not open"); + } + final List> results = new ArrayList<>(); + for (long now = timeFrom.toEpochMilli(); now <= timeTo.toEpochMilli(); now++) { + final Map kvMap = data.get(now); + if (kvMap != null && kvMap.containsKey(key)) { + results.add(new KeyValue<>(now, kvMap.get(key))); + } + } + return new TheWindowStoreIterator<>(results.iterator()); + } + + @Override + public WindowStoreIterator backwardFetch(final K key, final Instant timeFrom, final Instant timeTo) throws IllegalArgumentException { + final long timeFromTs = ApiUtils.validateMillisecondInstant(timeFrom, prepareMillisCheckFailMsgPrefix(timeFrom, "timeFrom")); + final long timeToTs = ApiUtils.validateMillisecondInstant(timeTo, prepareMillisCheckFailMsgPrefix(timeTo, "timeTo")); + if (!open) { + throw new InvalidStateStoreException("Store is not open"); + } + final List> results = new ArrayList<>(); + for (long now = timeToTs; now >= timeFromTs; now--) { + final Map kvMap = data.get(now); + if (kvMap != null && kvMap.containsKey(key)) { + results.add(new KeyValue<>(now, kvMap.get(key))); + } + } + return new TheWindowStoreIterator<>(results.iterator()); + } + + @Override + public KeyValueIterator, V> all() { + if (!open) { + throw new InvalidStateStoreException("Store is not open"); + } + final List, V>> results = new ArrayList<>(); + for (final long now : data.keySet()) { + final NavigableMap kvMap = data.get(now); + if (kvMap != null) { + for (final Entry entry : kvMap.entrySet()) { + results.add(new KeyValue<>(new Windowed<>(entry.getKey(), new TimeWindow(now, now + windowSize)), entry.getValue())); + } + } + } + final Iterator, V>> iterator = results.iterator(); + + return new KeyValueIterator, V>() { + @Override + public void close() { + } + + @Override + public Windowed peekNextKey() { + throw new UnsupportedOperationException("peekNextKey() not supported in " + getClass().getName()); + } + + 
@Override + public boolean hasNext() { + return iterator.hasNext(); + } + + @Override + public KeyValue, V> next() { + return iterator.next(); + } + + }; + } + + @Override + public KeyValueIterator, V> backwardAll() { + if (!open) { + throw new InvalidStateStoreException("Store is not open"); + } + final List, V>> results = new ArrayList<>(); + for (final long now : data.descendingKeySet()) { + final NavigableMap kvMap = data.get(now); + if (kvMap != null) { + for (final Entry entry : kvMap.descendingMap().entrySet()) { + results.add(new KeyValue<>(new Windowed<>(entry.getKey(), new TimeWindow(now, now + windowSize)), entry.getValue())); + } + } + } + final Iterator, V>> iterator = results.iterator(); + + return new KeyValueIterator, V>() { + @Override + public void close() { + } + + @Override + public Windowed peekNextKey() { + throw new UnsupportedOperationException("peekNextKey() not supported in " + getClass().getName()); + } + + @Override + public boolean hasNext() { + return iterator.hasNext(); + } + + @Override + public KeyValue, V> next() { + return iterator.next(); + } + + }; + } + + @Override + public KeyValueIterator, V> fetchAll(final Instant timeFrom, final Instant timeTo) { + if (!open) { + throw new InvalidStateStoreException("Store is not open"); + } + final List, V>> results = new ArrayList<>(); + for (final long now : data.keySet()) { + if (!(now >= timeFrom.toEpochMilli() && now <= timeTo.toEpochMilli())) { + continue; + } + final NavigableMap kvMap = data.get(now); + if (kvMap != null) { + for (final Entry entry : kvMap.entrySet()) { + results.add(new KeyValue<>(new Windowed<>(entry.getKey(), new TimeWindow(now, now + windowSize)), entry.getValue())); + } + } + } + final Iterator, V>> iterator = results.iterator(); + + return new KeyValueIterator, V>() { + @Override + public void close() { + } + + @Override + public Windowed peekNextKey() { + throw new UnsupportedOperationException("peekNextKey() not supported in " + getClass().getName()); + } + + @Override + public boolean hasNext() { + return iterator.hasNext(); + } + + @Override + public KeyValue, V> next() { + return iterator.next(); + } + + }; + } + + @Override + public KeyValueIterator, V> backwardFetchAll(final Instant timeFrom, final Instant timeTo) throws IllegalArgumentException { + final long timeFromTs = ApiUtils.validateMillisecondInstant(timeFrom, prepareMillisCheckFailMsgPrefix(timeFrom, "timeFrom")); + final long timeToTs = ApiUtils.validateMillisecondInstant(timeTo, prepareMillisCheckFailMsgPrefix(timeTo, "timeTo")); + if (!open) { + throw new InvalidStateStoreException("Store is not open"); + } + final List, V>> results = new ArrayList<>(); + for (final long now : data.descendingKeySet()) { + if (!(now >= timeFromTs && now <= timeToTs)) { + continue; + } + final NavigableMap kvMap = data.get(now); + if (kvMap != null) { + for (final Entry entry : kvMap.descendingMap().entrySet()) { + results.add(new KeyValue<>(new Windowed<>(entry.getKey(), new TimeWindow(now, now + windowSize)), entry.getValue())); + } + } + } + final Iterator, V>> iterator = results.iterator(); + + return new KeyValueIterator, V>() { + @Override + public void close() { + } + + @Override + public Windowed peekNextKey() { + throw new UnsupportedOperationException("peekNextKey() not supported in " + getClass().getName()); + } + + @Override + public boolean hasNext() { + return iterator.hasNext(); + } + + @Override + public KeyValue, V> next() { + return iterator.next(); + } + + }; + } + + @Override + public KeyValueIterator, V> 
fetch(final K keyFrom, final K keyTo, final Instant timeFrom, final Instant timeTo) { + if (!open) { + throw new InvalidStateStoreException("Store is not open"); + } + final List, V>> results = new ArrayList<>(); + for (long now = timeFrom.toEpochMilli(); now <= timeTo.toEpochMilli(); now++) { + final NavigableMap kvMap = data.get(now); + if (kvMap != null) { + final NavigableMap kvSubMap; + if (keyFrom == null && keyTo == null) { + kvSubMap = kvMap; + } else if (keyFrom == null) { + kvSubMap = kvMap.headMap(keyTo, true); + } else if (keyTo == null) { + kvSubMap = kvMap.tailMap(keyFrom, true); + } else { + // keyFrom != null and keyTo != null + kvSubMap = kvMap.subMap(keyFrom, true, keyTo, true); + } + + for (final Entry entry : kvSubMap.entrySet()) { + results.add(new KeyValue<>(new Windowed<>(entry.getKey(), new TimeWindow(now, now + windowSize)), entry.getValue())); + } + } + } + final Iterator, V>> iterator = results.iterator(); + + return new KeyValueIterator, V>() { + @Override + public void close() { + } + + @Override + public Windowed peekNextKey() { + throw new UnsupportedOperationException("peekNextKey() not supported in " + getClass().getName()); + } + + @Override + public boolean hasNext() { + return iterator.hasNext(); + } + + @Override + public KeyValue, V> next() { + return iterator.next(); + } + + }; + } + + @Override + public KeyValueIterator, V> backwardFetch(final K keyFrom, + final K keyTo, + final Instant timeFrom, + final Instant timeTo) throws IllegalArgumentException { + final long timeFromTs = ApiUtils.validateMillisecondInstant(timeFrom, prepareMillisCheckFailMsgPrefix(timeFrom, "timeFrom")); + final long timeToTs = ApiUtils.validateMillisecondInstant(timeTo, prepareMillisCheckFailMsgPrefix(timeTo, "timeTo")); + if (!open) { + throw new InvalidStateStoreException("Store is not open"); + } + final List, V>> results = new ArrayList<>(); + for (long now = timeToTs; now >= timeFromTs; now--) { + final NavigableMap kvMap = data.get(now); + if (kvMap != null) { + final NavigableMap kvSubMap; + if (keyFrom == null && keyTo == null) { + kvSubMap = kvMap; + } else if (keyFrom == null) { + kvSubMap = kvMap.headMap(keyTo, true); + } else if (keyTo == null) { + kvSubMap = kvMap.tailMap(keyFrom, true); + } else { + // keyFrom != null and keyTo != null + kvSubMap = kvMap.subMap(keyFrom, true, keyTo, true); + } + + for (final Entry entry : kvSubMap.descendingMap().entrySet()) { + results.add(new KeyValue<>(new Windowed<>(entry.getKey(), new TimeWindow(now, now + windowSize)), entry.getValue())); + } + } + } + final Iterator, V>> iterator = results.iterator(); + + return new KeyValueIterator, V>() { + @Override + public void close() { + } + + @Override + public Windowed peekNextKey() { + throw new UnsupportedOperationException("peekNextKey() not supported in " + getClass().getName()); + } + + @Override + public boolean hasNext() { + return iterator.hasNext(); + } + + @Override + public KeyValue, V> next() { + return iterator.next(); + } + + }; + } + + public void put(final K key, final V value, final long timestamp) { + if (!data.containsKey(timestamp)) { + data.put(timestamp, new TreeMap<>()); + } + data.get(timestamp).put(key, value); + } + + @Override + public String name() { + return null; + } + + @Deprecated + @Override + public void init(final ProcessorContext context, final StateStore root) { + } + + @Override + public void flush() { + } + + @Override + public void close() { + } + + @Override + public boolean persistent() { + return false; + } + + @Override + public 
boolean isOpen() { + return open; + } + + void setOpen(final boolean open) { + this.open = open; + } + + private static class TheWindowStoreIterator<V> implements WindowStoreIterator<V> { + + private final Iterator<KeyValue<Long, V>> underlying; + + TheWindowStoreIterator(final Iterator<KeyValue<Long, V>> underlying) { + this.underlying = underlying; + } + + @Override + public void close() { + } + + @Override + public Long peekNextKey() { + throw new UnsupportedOperationException("peekNextKey() not supported in " + getClass().getName()); + } + + @Override + public boolean hasNext() { + return underlying.hasNext(); + } + + @Override + public KeyValue<Long, V> next() { + return underlying.next(); + } + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/RecordConvertersTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/RecordConvertersTest.java new file mode 100644 index 0000000..e409ca1 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/RecordConvertersTest.java @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.common.header.internals.RecordHeaders; +import org.apache.kafka.common.record.TimestampType; +import org.junit.Test; + +import java.nio.ByteBuffer; +import java.util.Optional; + +import static org.apache.kafka.streams.state.internals.RecordConverters.rawValueToTimestampedValue; +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertNull; + +public class RecordConvertersTest { + + private final RecordConverter timestampedValueConverter = rawValueToTimestampedValue(); + + @Test + public void shouldPreserveNullValueOnConversion() { + final ConsumerRecord<byte[], byte[]> nullValueRecord = new ConsumerRecord<>("", 0, 0L, new byte[0], null); + assertNull(timestampedValueConverter.convert(nullValueRecord).value()); + } + + @Test + public void shouldAddTimestampToValueOnConversionWhenValueIsNotNull() { + final long timestamp = 10L; + final byte[] value = new byte[1]; + final ConsumerRecord<byte[], byte[]> inputRecord = new ConsumerRecord<>( + "topic", 1, 0, timestamp, TimestampType.CREATE_TIME, 0, 0, new byte[0], value, + new RecordHeaders(), Optional.empty()); + final byte[] expectedValue = ByteBuffer.allocate(9).putLong(timestamp).put(value).array(); + final byte[] actualValue = timestampedValueConverter.convert(inputRecord).value(); + assertArrayEquals(expectedValue, actualValue); + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/RocksDBGenericOptionsToDbOptionsColumnFamilyOptionsAdapterTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/RocksDBGenericOptionsToDbOptionsColumnFamilyOptionsAdapterTest.java new file mode 100644 index 0000000..4cafecf --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/RocksDBGenericOptionsToDbOptionsColumnFamilyOptionsAdapterTest.java @@ -0,0 +1,356 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.streams.processor.internals.testutil.LogCaptureAppender; +import org.easymock.EasyMockRunner; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.rocksdb.AbstractCompactionFilter; +import org.rocksdb.AbstractCompactionFilter.Context; +import org.rocksdb.AbstractCompactionFilterFactory; +import org.rocksdb.AbstractWalFilter; +import org.rocksdb.AccessHint; +import org.rocksdb.BuiltinComparator; +import org.rocksdb.ColumnFamilyOptions; +import org.rocksdb.CompactionPriority; +import org.rocksdb.CompactionStyle; +import org.rocksdb.ComparatorOptions; +import org.rocksdb.CompressionType; +import org.rocksdb.DBOptions; +import org.rocksdb.Env; +import org.rocksdb.InfoLogLevel; +import org.rocksdb.LRUCache; +import org.rocksdb.Logger; +import org.rocksdb.Options; +import org.rocksdb.PlainTableConfig; +import org.rocksdb.RateLimiter; +import org.rocksdb.RemoveEmptyValueCompactionFilter; +import org.rocksdb.RocksDB; +import org.rocksdb.SstFileManager; +import org.rocksdb.StringAppendOperator; +import org.rocksdb.VectorMemTableConfig; +import org.rocksdb.WALRecoveryMode; +import org.rocksdb.WalProcessingOption; +import org.rocksdb.WriteBatch; +import org.rocksdb.WriteBufferManager; +import org.rocksdb.util.BytewiseComparator; + +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; + +import java.util.Arrays; +import java.util.Set; +import java.util.ArrayList; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +import static org.easymock.EasyMock.mock; +import static org.easymock.EasyMock.replay; +import static org.easymock.EasyMock.reset; +import static org.easymock.EasyMock.verify; +import static org.hamcrest.CoreMatchers.hasItem; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.matchesPattern; +import static org.junit.Assert.fail; + +/** + * The purpose of this test is, to catch interface changes if we upgrade {@link RocksDB}. + * Using reflections, we make sure the {@link RocksDBGenericOptionsToDbOptionsColumnFamilyOptionsAdapter} maps all + * methods from {@link DBOptions} and {@link ColumnFamilyOptions} to/from {@link Options} correctly. 
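+ * A public method added to {@link Options} by a future RocksDB upgrade that is not also declared on the adapter
+ * will surface as a {@code NoSuchMethodException} from {@code shouldOverwriteAllOptionsMethods}.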
+ */ +@RunWith(EasyMockRunner.class) +public class RocksDBGenericOptionsToDbOptionsColumnFamilyOptionsAdapterTest { + + private final List walRelatedMethods = new LinkedList() { + { + add("setManualWalFlush"); + add("setMaxTotalWalSize"); + add("setWalBytesPerSync"); + add("setWalDir"); + add("setWalFilter"); + add("setWalRecoveryMode"); + add("setWalSizeLimitMB"); + add("setWalTtlSeconds"); + } + }; + + private final List ignoreMethods = new LinkedList() { + { + add("isOwningHandle"); + add("getNativeHandle"); + add("dispose"); + add("wait"); + add("equals"); + add("getClass"); + add("hashCode"); + add("notify"); + add("notifyAll"); + add("toString"); + add("getOptionStringFromProps"); + addAll(walRelatedMethods); + } + }; + + @Test + public void shouldOverwriteAllOptionsMethods() throws Exception { + for (final Method method : Options.class.getMethods()) { + if (!ignoreMethods.contains(method.getName())) { + RocksDBGenericOptionsToDbOptionsColumnFamilyOptionsAdapter.class + .getDeclaredMethod(method.getName(), method.getParameterTypes()); + } + } + } + + @Test + public void shouldForwardAllDbOptionsCalls() throws Exception { + for (final Method method : Options.class.getMethods()) { + if (!ignoreMethods.contains(method.getName())) { + try { + DBOptions.class.getMethod(method.getName(), method.getParameterTypes()); + verifyDBOptionsMethodCall(method); + } catch (final NoSuchMethodException expectedAndSwallow) { } + } + } + } + + private void verifyDBOptionsMethodCall(final Method method) throws Exception { + final DBOptions mockedDbOptions = mock(DBOptions.class); + final RocksDBGenericOptionsToDbOptionsColumnFamilyOptionsAdapter optionsFacadeDbOptions + = new RocksDBGenericOptionsToDbOptionsColumnFamilyOptionsAdapter(mockedDbOptions, new ColumnFamilyOptions()); + + final Object[] parameters = getDBOptionsParameters(method.getParameterTypes()); + + try { + reset(mockedDbOptions); + replay(mockedDbOptions); + method.invoke(optionsFacadeDbOptions, parameters); + verify(); + fail("Should have called DBOptions." + method.getName() + "()"); + } catch (final InvocationTargetException undeclaredMockMethodCall) { + assertThat(undeclaredMockMethodCall.getCause(), instanceOf(AssertionError.class)); + assertThat(undeclaredMockMethodCall.getCause().getMessage().trim(), + matchesPattern("Unexpected method call DBOptions\\." 
+ method.getName() + "((.*\n*)*):")); + } + } + + private Object[] getDBOptionsParameters(final Class[] parameterTypes) throws Exception { + final Object[] parameters = new Object[parameterTypes.length]; + + for (int i = 0; i < parameterTypes.length; ++i) { + switch (parameterTypes[i].getName()) { + case "boolean": + parameters[i] = true; + break; + case "int": + parameters[i] = 0; + break; + case "long": + parameters[i] = 0L; + break; + case "java.util.Collection": + parameters[i] = new ArrayList<>(); + break; + case "org.rocksdb.AccessHint": + parameters[i] = AccessHint.NONE; + break; + case "org.rocksdb.Cache": + parameters[i] = new LRUCache(1L); + break; + case "org.rocksdb.Env": + parameters[i] = Env.getDefault(); + break; + case "org.rocksdb.InfoLogLevel": + parameters[i] = InfoLogLevel.FATAL_LEVEL; + break; + case "org.rocksdb.Logger": + parameters[i] = new Logger(new Options()) { + @Override + protected void log(final InfoLogLevel infoLogLevel, final String logMsg) {} + }; + break; + case "org.rocksdb.RateLimiter": + parameters[i] = new RateLimiter(1L); + break; + case "org.rocksdb.SstFileManager": + parameters[i] = new SstFileManager(Env.getDefault()); + break; + case "org.rocksdb.WALRecoveryMode": + parameters[i] = WALRecoveryMode.AbsoluteConsistency; + break; + case "org.rocksdb.WriteBufferManager": + parameters[i] = new WriteBufferManager(1L, new LRUCache(1L)); + break; + case "org.rocksdb.AbstractWalFilter": + class TestWalFilter extends AbstractWalFilter { + @Override + public void columnFamilyLogNumberMap(final Map cfLognumber, final Map cfNameId) { + } + + @Override + public LogRecordFoundResult logRecordFound(final long logNumber, final String logFileName, final WriteBatch batch, final WriteBatch newBatch) { + return new LogRecordFoundResult(WalProcessingOption.CONTINUE_PROCESSING, false); + } + + @Override + public String name() { + return "TestWalFilter"; + } + } + parameters[i] = new TestWalFilter(); + break; + default: + parameters[i] = parameterTypes[i].getConstructor().newInstance(); + } + } + + return parameters; + } + + @Test + public void shouldForwardAllColumnFamilyCalls() throws Exception { + for (final Method method : Options.class.getMethods()) { + if (!ignoreMethods.contains(method.getName())) { + try { + ColumnFamilyOptions.class.getMethod(method.getName(), method.getParameterTypes()); + verifyColumnFamilyOptionsMethodCall(method); + } catch (final NoSuchMethodException expectedAndSwallow) { } + } + } + } + + private void verifyColumnFamilyOptionsMethodCall(final Method method) throws Exception { + final ColumnFamilyOptions mockedColumnFamilyOptions = mock(ColumnFamilyOptions.class); + final RocksDBGenericOptionsToDbOptionsColumnFamilyOptionsAdapter optionsFacadeColumnFamilyOptions + = new RocksDBGenericOptionsToDbOptionsColumnFamilyOptionsAdapter(new DBOptions(), mockedColumnFamilyOptions); + + final Object[] parameters = getColumnFamilyOptionsParameters(method.getParameterTypes()); + + try { + reset(mockedColumnFamilyOptions); + replay(mockedColumnFamilyOptions); + method.invoke(optionsFacadeColumnFamilyOptions, parameters); + verify(); + fail("Should have called ColumnFamilyOptions." + method.getName() + "()"); + } catch (final InvocationTargetException undeclaredMockMethodCall) { + assertThat(undeclaredMockMethodCall.getCause(), instanceOf(AssertionError.class)); + assertThat(undeclaredMockMethodCall.getCause().getMessage().trim(), + matchesPattern("Unexpected method call ColumnFamilyOptions\\." 
+ method.getName() + "(.*)")); + } + } + + private Object[] getColumnFamilyOptionsParameters(final Class[] parameterTypes) throws Exception { + final Object[] parameters = new Object[parameterTypes.length]; + + for (int i = 0; i < parameterTypes.length; ++i) { + switch (parameterTypes[i].getName()) { + case "boolean": + parameters[i] = true; + break; + case "double": + parameters[i] = 0.0d; + break; + case "int": + parameters[i] = 0; + break; + case "long": + parameters[i] = 0L; + break; + case "[I": + parameters[i] = new int[0]; + break; + case "java.util.List": + parameters[i] = new ArrayList<>(); + break; + case "org.rocksdb.AbstractCompactionFilter": + parameters[i] = new RemoveEmptyValueCompactionFilter(); + break; + case "org.rocksdb.AbstractCompactionFilterFactory": + parameters[i] = new AbstractCompactionFilterFactory>() { + + @Override + public AbstractCompactionFilter createCompactionFilter(final Context context) { + return null; + } + + @Override + public String name() { + return "AbstractCompactionFilterFactory"; + } + }; + break; + case "org.rocksdb.AbstractComparator": + parameters[i] = new BytewiseComparator(new ComparatorOptions()); + break; + case "org.rocksdb.BuiltinComparator": + parameters[i] = BuiltinComparator.BYTEWISE_COMPARATOR; + break; + case "org.rocksdb.CompactionPriority": + parameters[i] = CompactionPriority.ByCompensatedSize; + break; + case "org.rocksdb.CompactionStyle": + parameters[i] = CompactionStyle.UNIVERSAL; + break; + case "org.rocksdb.CompressionType": + parameters[i] = CompressionType.NO_COMPRESSION; + break; + case "org.rocksdb.MemTableConfig": + parameters[i] = new VectorMemTableConfig(); + break; + case "org.rocksdb.MergeOperator": + parameters[i] = new StringAppendOperator(); + break; + case "org.rocksdb.TableFormatConfig": + parameters[i] = new PlainTableConfig(); + break; + default: + parameters[i] = parameterTypes[i].getConstructor().newInstance(); + } + } + + return parameters; + } + + @Test + public void shouldLogWarningWhenSettingWalOptions() throws Exception { + + try (final LogCaptureAppender appender = LogCaptureAppender.createAndRegister(RocksDBGenericOptionsToDbOptionsColumnFamilyOptionsAdapter.class)) { + + final RocksDBGenericOptionsToDbOptionsColumnFamilyOptionsAdapter adapter + = new RocksDBGenericOptionsToDbOptionsColumnFamilyOptionsAdapter(new DBOptions(), new ColumnFamilyOptions()); + + for (final Method method : RocksDBGenericOptionsToDbOptionsColumnFamilyOptionsAdapter.class.getDeclaredMethods()) { + if (walRelatedMethods.contains(method.getName())) { + method.invoke(adapter, getDBOptionsParameters(method.getParameterTypes())); + } + } + + final List walOptions = Arrays.asList("walDir", "walFilter", "walRecoveryMode", "walBytesPerSync", "walSizeLimitMB", "manualWalFlush", "maxTotalWalSize", "walTtlSeconds"); + + final Set logMessages = appender.getEvents().stream() + .filter(e -> e.getLevel().equals("WARN")) + .map(LogCaptureAppender.Event::getMessage) + .collect(Collectors.toSet()); + + walOptions.forEach(option -> assertThat(logMessages, hasItem(String.format("WAL is explicitly disabled by Streams in RocksDB. 
Setting option '%s' will be ignored", option)))); + + } + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/RocksDBRangeIteratorTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/RocksDBRangeIteratorTest.java new file mode 100644 index 0000000..b4c7d79 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/RocksDBRangeIteratorTest.java @@ -0,0 +1,440 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.common.utils.Bytes; +import org.junit.Test; +import org.rocksdb.RocksIterator; + +import java.util.Collections; +import java.util.NoSuchElementException; + +import static org.easymock.EasyMock.mock; +import static org.easymock.EasyMock.expect; +import static org.easymock.EasyMock.expectLastCall; +import static org.easymock.EasyMock.replay; +import static org.easymock.EasyMock.verify; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.core.Is.is; +import static org.junit.Assert.assertThrows; + +public class RocksDBRangeIteratorTest { + + private final String storeName = "store"; + private final String key1 = "a"; + private final String key2 = "b"; + private final String key3 = "c"; + private final String key4 = "d"; + + private final String value = "value"; + private final Bytes key1Bytes = Bytes.wrap(key1.getBytes()); + private final Bytes key2Bytes = Bytes.wrap(key2.getBytes()); + private final Bytes key3Bytes = Bytes.wrap(key3.getBytes()); + private final Bytes key4Bytes = Bytes.wrap(key4.getBytes()); + private final byte[] valueBytes = value.getBytes(); + + @Test + public void shouldReturnAllKeysInTheRangeInForwardDirection() { + final RocksIterator rocksIterator = mock(RocksIterator.class); + rocksIterator.seek(key1Bytes.get()); + expect(rocksIterator.isValid()) + .andReturn(true) + .andReturn(true) + .andReturn(true) + .andReturn(false); + expect(rocksIterator.key()) + .andReturn(key1Bytes.get()) + .andReturn(key2Bytes.get()) + .andReturn(key3Bytes.get()); + expect(rocksIterator.value()).andReturn(valueBytes).times(3); + rocksIterator.next(); + expectLastCall().times(3); + replay(rocksIterator); + final RocksDBRangeIterator rocksDBRangeIterator = new RocksDBRangeIterator( + storeName, + rocksIterator, + Collections.emptySet(), + key1Bytes, + key3Bytes, + true, + true + ); + assertThat(rocksDBRangeIterator.hasNext(), is(true)); + assertThat(rocksDBRangeIterator.next().key, is(key1Bytes)); + assertThat(rocksDBRangeIterator.hasNext(), is(true)); + assertThat(rocksDBRangeIterator.next().key, is(key2Bytes)); + assertThat(rocksDBRangeIterator.hasNext(), is(true)); + assertThat(rocksDBRangeIterator.next().key, is(key3Bytes)); + 
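+        // isValid() is mocked to return false on its fourth call, so the iterator is exhausted after key3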
assertThat(rocksDBRangeIterator.hasNext(), is(false)); + verify(rocksIterator); + } + + @Test + public void shouldReturnAllKeysInTheRangeReverseDirection() { + final RocksIterator rocksIterator = mock(RocksIterator.class); + rocksIterator.seekForPrev(key3Bytes.get()); + expect(rocksIterator.isValid()) + .andReturn(true) + .andReturn(true) + .andReturn(true) + .andReturn(false); + expect(rocksIterator.key()) + .andReturn(key3Bytes.get()) + .andReturn(key2Bytes.get()) + .andReturn(key1Bytes.get()); + expect(rocksIterator.value()).andReturn(valueBytes).times(3); + rocksIterator.prev(); + expectLastCall().times(3); + replay(rocksIterator); + final RocksDBRangeIterator rocksDBRangeIterator = new RocksDBRangeIterator( + storeName, + rocksIterator, + Collections.emptySet(), + key1Bytes, + key3Bytes, + false, + true + ); + assertThat(rocksDBRangeIterator.hasNext(), is(true)); + assertThat(rocksDBRangeIterator.next().key, is(key3Bytes)); + assertThat(rocksDBRangeIterator.hasNext(), is(true)); + assertThat(rocksDBRangeIterator.next().key, is(key2Bytes)); + assertThat(rocksDBRangeIterator.hasNext(), is(true)); + assertThat(rocksDBRangeIterator.next().key, is(key1Bytes)); + assertThat(rocksDBRangeIterator.hasNext(), is(false)); + verify(rocksIterator); + } + + @Test + public void shouldReturnAllKeysWhenLastKeyIsGreaterThanLargestKeyInStateStoreInForwardDirection() { + final Bytes toBytes = Bytes.increment(key4Bytes); + final RocksIterator rocksIterator = mock(RocksIterator.class); + rocksIterator.seek(key1Bytes.get()); + expect(rocksIterator.isValid()) + .andReturn(true) + .andReturn(true) + .andReturn(true) + .andReturn(true) + .andReturn(false); + expect(rocksIterator.key()) + .andReturn(key1Bytes.get()) + .andReturn(key2Bytes.get()) + .andReturn(key3Bytes.get()) + .andReturn(key4Bytes.get()); + expect(rocksIterator.value()).andReturn(valueBytes).times(4); + rocksIterator.next(); + expectLastCall().times(4); + replay(rocksIterator); + final RocksDBRangeIterator rocksDBRangeIterator = new RocksDBRangeIterator( + storeName, + rocksIterator, + Collections.emptySet(), + key1Bytes, + toBytes, + true, + true + ); + assertThat(rocksDBRangeIterator.hasNext(), is(true)); + assertThat(rocksDBRangeIterator.next().key, is(key1Bytes)); + assertThat(rocksDBRangeIterator.hasNext(), is(true)); + assertThat(rocksDBRangeIterator.next().key, is(key2Bytes)); + assertThat(rocksDBRangeIterator.hasNext(), is(true)); + assertThat(rocksDBRangeIterator.next().key, is(key3Bytes)); + assertThat(rocksDBRangeIterator.hasNext(), is(true)); + assertThat(rocksDBRangeIterator.next().key, is(key4Bytes)); + assertThat(rocksDBRangeIterator.hasNext(), is(false)); + verify(rocksIterator); + } + + + @Test + public void shouldReturnAllKeysWhenLastKeyIsSmallerThanSmallestKeyInStateStoreInReverseDirection() { + final RocksIterator rocksIterator = mock(RocksIterator.class); + rocksIterator.seekForPrev(key4Bytes.get()); + expect(rocksIterator.isValid()) + .andReturn(true) + .andReturn(true) + .andReturn(true) + .andReturn(true) + .andReturn(false); + expect(rocksIterator.key()) + .andReturn(key4Bytes.get()) + .andReturn(key3Bytes.get()) + .andReturn(key2Bytes.get()) + .andReturn(key1Bytes.get()); + expect(rocksIterator.value()).andReturn(valueBytes).times(4); + rocksIterator.prev(); + expectLastCall().times(4); + replay(rocksIterator); + final RocksDBRangeIterator rocksDBRangeIterator = new RocksDBRangeIterator( + storeName, + rocksIterator, + Collections.emptySet(), + key1Bytes, + key4Bytes, + false, + true + ); + 
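+        // seekForPrev(key4) starts the mocked iterator at the largest key and prev() walks back to key1 before isValid() returns false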
assertThat(rocksDBRangeIterator.hasNext(), is(true)); + assertThat(rocksDBRangeIterator.next().key, is(key4Bytes)); + assertThat(rocksDBRangeIterator.hasNext(), is(true)); + assertThat(rocksDBRangeIterator.next().key, is(key3Bytes)); + assertThat(rocksDBRangeIterator.hasNext(), is(true)); + assertThat(rocksDBRangeIterator.next().key, is(key2Bytes)); + assertThat(rocksDBRangeIterator.hasNext(), is(true)); + assertThat(rocksDBRangeIterator.next().key, is(key1Bytes)); + assertThat(rocksDBRangeIterator.hasNext(), is(false)); + verify(rocksIterator); + } + + + @Test + public void shouldReturnNoKeysWhenLastKeyIsSmallerThanSmallestKeyInStateStoreForwardDirection() { + // key range in state store: [c-f] + final RocksIterator rocksIterator = mock(RocksIterator.class); + rocksIterator.seek(key1Bytes.get()); + expect(rocksIterator.isValid()).andReturn(false); + replay(rocksIterator); + final RocksDBRangeIterator rocksDBRangeIterator = new RocksDBRangeIterator( + storeName, + rocksIterator, + Collections.emptySet(), + key1Bytes, + key2Bytes, + true, + true + ); + assertThat(rocksDBRangeIterator.hasNext(), is(false)); + verify(rocksIterator); + } + + @Test + public void shouldReturnNoKeysWhenLastKeyIsLargerThanLargestKeyInStateStoreReverseDirection() { + // key range in state store: [c-f] + final String from = "g"; + final String to = "h"; + final Bytes fromBytes = Bytes.wrap(from.getBytes()); + final Bytes toBytes = Bytes.wrap(to.getBytes()); + final RocksIterator rocksIterator = mock(RocksIterator.class); + rocksIterator.seekForPrev(toBytes.get()); + expect(rocksIterator.isValid()) + .andReturn(false); + replay(rocksIterator); + final RocksDBRangeIterator rocksDBRangeIterator = new RocksDBRangeIterator( + storeName, + rocksIterator, + Collections.emptySet(), + fromBytes, + toBytes, + false, + true + ); + assertThat(rocksDBRangeIterator.hasNext(), is(false)); + verify(rocksIterator); + } + + @Test + public void shouldReturnAllKeysInPartiallyOverlappingRangeInForwardDirection() { + final RocksIterator rocksIterator = mock(RocksIterator.class); + rocksIterator.seek(key1Bytes.get()); + expect(rocksIterator.isValid()) + .andReturn(true) + .andReturn(true) + .andReturn(false); + expect(rocksIterator.key()) + .andReturn(key2Bytes.get()) + .andReturn(key3Bytes.get()); + expect(rocksIterator.value()).andReturn(valueBytes).times(2); + rocksIterator.next(); + expectLastCall().times(2); + replay(rocksIterator); + final RocksDBRangeIterator rocksDBRangeIterator = new RocksDBRangeIterator( + storeName, + rocksIterator, + Collections.emptySet(), + key1Bytes, + key3Bytes, + true, + true + ); + assertThat(rocksDBRangeIterator.hasNext(), is(true)); + assertThat(rocksDBRangeIterator.next().key, is(key2Bytes)); + assertThat(rocksDBRangeIterator.hasNext(), is(true)); + assertThat(rocksDBRangeIterator.next().key, is(key3Bytes)); + assertThat(rocksDBRangeIterator.hasNext(), is(false)); + verify(rocksIterator); + } + + @Test + public void shouldReturnAllKeysInPartiallyOverlappingRangeInReverseDirection() { + final RocksIterator rocksIterator = mock(RocksIterator.class); + final String to = "e"; + final Bytes toBytes = Bytes.wrap(to.getBytes()); + rocksIterator.seekForPrev(toBytes.get()); + expect(rocksIterator.isValid()) + .andReturn(true) + .andReturn(true) + .andReturn(false); + expect(rocksIterator.key()) + .andReturn(key4Bytes.get()) + .andReturn(key3Bytes.get()); + expect(rocksIterator.value()).andReturn(valueBytes).times(2); + rocksIterator.prev(); + expectLastCall().times(2); + replay(rocksIterator); + final 
RocksDBRangeIterator rocksDBRangeIterator = new RocksDBRangeIterator( + storeName, + rocksIterator, + Collections.emptySet(), + key3Bytes, + toBytes, + false, + true + ); + assertThat(rocksDBRangeIterator.hasNext(), is(true)); + assertThat(rocksDBRangeIterator.next().key, is(key4Bytes)); + assertThat(rocksDBRangeIterator.hasNext(), is(true)); + assertThat(rocksDBRangeIterator.next().key, is(key3Bytes)); + assertThat(rocksDBRangeIterator.hasNext(), is(false)); + verify(rocksIterator); + } + + @Test + public void shouldReturnTheCurrentKeyOnInvokingPeekNextKeyInForwardDirection() { + final RocksIterator rocksIterator = mock(RocksIterator.class); + rocksIterator.seek(key1Bytes.get()); + expect(rocksIterator.isValid()) + .andReturn(true) + .andReturn(true) + .andReturn(false); + expect(rocksIterator.key()) + .andReturn(key2Bytes.get()) + .andReturn(key3Bytes.get()); + expect(rocksIterator.value()).andReturn(valueBytes).times(2); + rocksIterator.next(); + expectLastCall().times(2); + replay(rocksIterator); + final RocksDBRangeIterator rocksDBRangeIterator = new RocksDBRangeIterator( + storeName, + rocksIterator, + Collections.emptySet(), + key1Bytes, + key3Bytes, + true, + true + ); + assertThat(rocksDBRangeIterator.hasNext(), is(true)); + assertThat(rocksDBRangeIterator.peekNextKey(), is(key2Bytes)); + assertThat(rocksDBRangeIterator.peekNextKey(), is(key2Bytes)); + assertThat(rocksDBRangeIterator.next().key, is(key2Bytes)); + assertThat(rocksDBRangeIterator.hasNext(), is(true)); + assertThat(rocksDBRangeIterator.peekNextKey(), is(key3Bytes)); + assertThat(rocksDBRangeIterator.peekNextKey(), is(key3Bytes)); + assertThat(rocksDBRangeIterator.next().key, is(key3Bytes)); + assertThat(rocksDBRangeIterator.hasNext(), is(false)); + assertThrows(NoSuchElementException.class, rocksDBRangeIterator::peekNextKey); + verify(rocksIterator); + } + + @Test + public void shouldReturnTheCurrentKeyOnInvokingPeekNextKeyInReverseDirection() { + final RocksIterator rocksIterator = mock(RocksIterator.class); + final Bytes toBytes = Bytes.increment(key4Bytes); + rocksIterator.seekForPrev(toBytes.get()); + expect(rocksIterator.isValid()) + .andReturn(true) + .andReturn(true) + .andReturn(false); + expect(rocksIterator.key()) + .andReturn(key4Bytes.get()) + .andReturn(key3Bytes.get()); + expect(rocksIterator.value()).andReturn(valueBytes).times(2); + rocksIterator.prev(); + expectLastCall().times(2); + replay(rocksIterator); + final RocksDBRangeIterator rocksDBRangeIterator = new RocksDBRangeIterator( + storeName, + rocksIterator, + Collections.emptySet(), + key3Bytes, + toBytes, + false, + true + ); + assertThat(rocksDBRangeIterator.hasNext(), is(true)); + assertThat(rocksDBRangeIterator.peekNextKey(), is(key4Bytes)); + assertThat(rocksDBRangeIterator.peekNextKey(), is(key4Bytes)); + assertThat(rocksDBRangeIterator.next().key, is(key4Bytes)); + assertThat(rocksDBRangeIterator.hasNext(), is(true)); + assertThat(rocksDBRangeIterator.peekNextKey(), is(key3Bytes)); + assertThat(rocksDBRangeIterator.peekNextKey(), is(key3Bytes)); + assertThat(rocksDBRangeIterator.next().key, is(key3Bytes)); + assertThat(rocksDBRangeIterator.hasNext(), is(false)); + assertThrows(NoSuchElementException.class, rocksDBRangeIterator::peekNextKey); + verify(rocksIterator); + } + + @Test + public void shouldCloseIterator() { + final RocksIterator rocksIterator = mock(RocksIterator.class); + rocksIterator.seek(key1Bytes.get()); + rocksIterator.close(); + expectLastCall().times(1); + replay(rocksIterator); + final RocksDBRangeIterator 
rocksDBRangeIterator = new RocksDBRangeIterator( + storeName, + rocksIterator, + Collections.emptySet(), + key1Bytes, + key2Bytes, + true, + true + ); + rocksDBRangeIterator.close(); + verify(rocksIterator); + } + + @Test + public void shouldExcludeEndOfRange() { + final RocksIterator rocksIterator = mock(RocksIterator.class); + rocksIterator.seek(key1Bytes.get()); + expect(rocksIterator.isValid()) + .andReturn(true) + .andReturn(true); + expect(rocksIterator.key()) + .andReturn(key1Bytes.get()) + .andReturn(key2Bytes.get()); + expect(rocksIterator.value()).andReturn(valueBytes).times(2); + rocksIterator.next(); + expectLastCall().times(2); + replay(rocksIterator); + final RocksDBRangeIterator rocksDBRangeIterator = new RocksDBRangeIterator( + storeName, + rocksIterator, + Collections.emptySet(), + key1Bytes, + key2Bytes, + true, + false + ); + assertThat(rocksDBRangeIterator.hasNext(), is(true)); + assertThat(rocksDBRangeIterator.next().key, is(key1Bytes)); + assertThat(rocksDBRangeIterator.hasNext(), is(false)); + verify(rocksIterator); + } + +} diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/RocksDBSegmentedBytesStoreTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/RocksDBSegmentedBytesStoreTest.java new file mode 100644 index 0000000..3b6904f --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/RocksDBSegmentedBytesStoreTest.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.state.internals; + +public class RocksDBSegmentedBytesStoreTest extends AbstractRocksDBSegmentedBytesStoreTest { + + private final static String METRICS_SCOPE = "metrics-scope"; + + @Override + RocksDBSegmentedBytesStore getBytesStore() { + return new RocksDBSegmentedBytesStore( + storeName, + METRICS_SCOPE, + retention, + segmentInterval, + schema + ); + } + + @Override + KeyValueSegments newSegments() { + return new KeyValueSegments(storeName, METRICS_SCOPE, retention, segmentInterval); + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/RocksDBSessionStoreTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/RocksDBSessionStoreTest.java new file mode 100644 index 0000000..e3b98b7 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/RocksDBSessionStoreTest.java @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.streams.kstream.Windowed; +import org.apache.kafka.streams.kstream.internals.SessionWindow; +import org.apache.kafka.streams.state.KeyValueIterator; +import org.apache.kafka.streams.state.SessionStore; +import org.apache.kafka.streams.state.Stores; +import org.junit.Test; + +import java.util.Arrays; +import java.util.HashSet; + +import static java.time.Duration.ofMillis; +import static org.apache.kafka.test.StreamsTestUtils.valuesToSet; +import static org.junit.Assert.assertEquals; + + +public class RocksDBSessionStoreTest extends AbstractSessionBytesStoreTest { + + private static final String STORE_NAME = "rocksDB session store"; + + @Override + <K, V> SessionStore<K, V> buildSessionStore(final long retentionPeriod, + final Serde<K> keySerde, + final Serde<V> valueSerde) { + return Stores.sessionStoreBuilder( + Stores.persistentSessionStore( + STORE_NAME, + ofMillis(retentionPeriod)), + keySerde, + valueSerde).build(); + } + + @Test + public void shouldRemoveExpired() { + sessionStore.put(new Windowed<>("a", new SessionWindow(0, 0)), 1L); + sessionStore.put(new Windowed<>("aa", new SessionWindow(0, SEGMENT_INTERVAL)), 2L); + sessionStore.put(new Windowed<>("a", new SessionWindow(10, SEGMENT_INTERVAL)), 3L); + + // Advance stream time to expire the first record + sessionStore.put(new Windowed<>("aa", new SessionWindow(10, 2 * SEGMENT_INTERVAL)), 4L); + + try (final KeyValueIterator<Windowed<String>, Long> iterator = + sessionStore.findSessions("a", "b", 0L, Long.MAX_VALUE) + ) { + assertEquals(valuesToSet(iterator), new HashSet<>(Arrays.asList(2L, 3L, 4L))); + } + } +} \ No newline at end of file diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/RocksDBStoreTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/RocksDBStoreTest.java new file mode 100644 index 0000000..066f080 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/RocksDBStoreTest.java @@ -0,0 +1,984 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.common.Metric; +import org.apache.kafka.common.MetricName; +import org.apache.kafka.common.metrics.MetricConfig; +import org.apache.kafka.common.metrics.Metrics; +import org.apache.kafka.common.metrics.Sensor.RecordingLevel; +import org.apache.kafka.common.serialization.Deserializer; +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.serialization.Serializer; +import org.apache.kafka.common.serialization.StringDeserializer; +import org.apache.kafka.common.serialization.StringSerializer; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.errors.InvalidStateStoreException; +import org.apache.kafka.streams.errors.ProcessorStateException; +import org.apache.kafka.streams.processor.StateStoreContext; +import org.apache.kafka.streams.processor.TaskId; +import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl; +import org.apache.kafka.streams.state.KeyValueIterator; +import org.apache.kafka.streams.state.KeyValueStore; +import org.apache.kafka.streams.state.RocksDBConfigSetter; +import org.apache.kafka.streams.state.StoreBuilder; +import org.apache.kafka.streams.state.Stores; +import org.apache.kafka.streams.state.internals.metrics.RocksDBMetricsRecorder; +import org.apache.kafka.test.InternalMockProcessorContext; +import org.apache.kafka.test.MockRocksDbConfigSetter; +import org.apache.kafka.test.StreamsTestUtils; +import org.apache.kafka.test.TestUtils; +import org.easymock.EasyMock; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.junit.jupiter.api.Assertions; +import org.rocksdb.BlockBasedTableConfig; +import org.rocksdb.BloomFilter; +import org.rocksdb.Cache; +import org.rocksdb.Filter; +import org.rocksdb.LRUCache; +import org.rocksdb.Options; +import org.rocksdb.PlainTableConfig; +import org.rocksdb.Statistics; + +import java.io.File; +import java.io.IOException; +import java.math.BigInteger; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashSet; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.Properties; +import java.util.Set; +import java.util.UUID; +import java.util.stream.Collectors; + +import static java.nio.charset.StandardCharsets.UTF_8; +import static org.easymock.EasyMock.eq; +import static org.easymock.EasyMock.isNull; +import static org.easymock.EasyMock.mock; +import static org.easymock.EasyMock.notNull; +import static org.easymock.EasyMock.reset; +import static org.hamcrest.CoreMatchers.either; +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.CoreMatchers.notNullValue; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.greaterThan; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; +import static org.powermock.api.easymock.PowerMock.replay; +import static org.powermock.api.easymock.PowerMock.verify; + +@SuppressWarnings("unchecked") +public class RocksDBStoreTest extends 
AbstractKeyValueStoreTest { + private static boolean enableBloomFilters = false; + final static String DB_NAME = "db-name"; + final static String METRICS_SCOPE = "metrics-scope"; + + private File dir; + private final Time time = new MockTime(); + private final Serializer stringSerializer = new StringSerializer(); + private final Deserializer stringDeserializer = new StringDeserializer(); + + private final RocksDBMetricsRecorder metricsRecorder = mock(RocksDBMetricsRecorder.class); + + InternalMockProcessorContext context; + RocksDBStore rocksDBStore; + + @Before + public void setUp() { + final Properties props = StreamsTestUtils.getStreamsConfig(); + props.put(StreamsConfig.ROCKSDB_CONFIG_SETTER_CLASS_CONFIG, MockRocksDbConfigSetter.class); + dir = TestUtils.tempDirectory(); + context = new InternalMockProcessorContext<>( + dir, + Serdes.String(), + Serdes.String(), + new StreamsConfig(props) + ); + rocksDBStore = getRocksDBStore(); + } + + @After + public void tearDown() { + rocksDBStore.close(); + } + + @Override + protected KeyValueStore createKeyValueStore(final StateStoreContext context) { + final StoreBuilder> storeBuilder = Stores.keyValueStoreBuilder( + Stores.persistentKeyValueStore("my-store"), + (Serde) context.keySerde(), + (Serde) context.valueSerde()); + + final KeyValueStore store = storeBuilder.build(); + store.init(context, store); + return store; + } + + RocksDBStore getRocksDBStore() { + return new RocksDBStore(DB_NAME, METRICS_SCOPE); + } + + private RocksDBStore getRocksDBStoreWithRocksDBMetricsRecorder() { + return new RocksDBStore(DB_NAME, METRICS_SCOPE, metricsRecorder); + } + + private InternalMockProcessorContext getProcessorContext(final Properties streamsProps) { + return new InternalMockProcessorContext( + TestUtils.tempDirectory(), + new StreamsConfig(streamsProps) + ); + } + + private InternalMockProcessorContext getProcessorContext( + final RecordingLevel recordingLevel, + final Class rocksDBConfigSetterClass) { + + final Properties streamsProps = StreamsTestUtils.getStreamsConfig(); + streamsProps.setProperty(StreamsConfig.METRICS_RECORDING_LEVEL_CONFIG, recordingLevel.name()); + streamsProps.put(StreamsConfig.ROCKSDB_CONFIG_SETTER_CLASS_CONFIG, rocksDBConfigSetterClass); + return getProcessorContext(streamsProps); + } + + private InternalMockProcessorContext getProcessorContext(final RecordingLevel recordingLevel) { + final Properties streamsProps = StreamsTestUtils.getStreamsConfig(); + streamsProps.setProperty(StreamsConfig.METRICS_RECORDING_LEVEL_CONFIG, recordingLevel.name()); + return getProcessorContext(streamsProps); + } + + @Test + public void shouldAddValueProvidersWithoutStatisticsToInjectedMetricsRecorderWhenRecordingLevelInfo() { + rocksDBStore = getRocksDBStoreWithRocksDBMetricsRecorder(); + context = getProcessorContext(RecordingLevel.INFO); + reset(metricsRecorder); + metricsRecorder.addValueProviders(eq(DB_NAME), notNull(), notNull(), isNull()); + replay(metricsRecorder); + + rocksDBStore.openDB(context.appConfigs(), context.stateDir()); + + verify(metricsRecorder); + reset(metricsRecorder); + } + + @Test + public void shouldAddValueProvidersWithStatisticsToInjectedMetricsRecorderWhenRecordingLevelDebug() { + rocksDBStore = getRocksDBStoreWithRocksDBMetricsRecorder(); + context = getProcessorContext(RecordingLevel.DEBUG); + reset(metricsRecorder); + metricsRecorder.addValueProviders(eq(DB_NAME), notNull(), notNull(), notNull()); + replay(metricsRecorder); + + rocksDBStore.openDB(context.appConfigs(), context.stateDir()); + + 
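+        // opening the store should trigger the addValueProviders() call expected on the mocked recorder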
verify(metricsRecorder); + reset(metricsRecorder); + } + + @Test + public void shouldRemoveValueProvidersFromInjectedMetricsRecorderOnClose() { + rocksDBStore = getRocksDBStoreWithRocksDBMetricsRecorder(); + try { + context = getProcessorContext(RecordingLevel.DEBUG); + rocksDBStore.openDB(context.appConfigs(), context.stateDir()); + reset(metricsRecorder); + metricsRecorder.removeValueProviders(DB_NAME); + replay(metricsRecorder); + } finally { + rocksDBStore.close(); + } + + verify(metricsRecorder); + } + + public static class RocksDBConfigSetterWithUserProvidedStatistics implements RocksDBConfigSetter { + public RocksDBConfigSetterWithUserProvidedStatistics(){} + + public void setConfig(final String storeName, final Options options, final Map configs) { + options.setStatistics(new Statistics()); + } + + public void close(final String storeName, final Options options) { + options.statistics().close(); + } + } + + @Test + public void shouldNotSetStatisticsInValueProvidersWhenUserProvidesStatistics() { + rocksDBStore = getRocksDBStoreWithRocksDBMetricsRecorder(); + context = getProcessorContext(RecordingLevel.DEBUG, RocksDBConfigSetterWithUserProvidedStatistics.class); + metricsRecorder.addValueProviders(eq(DB_NAME), notNull(), notNull(), isNull()); + replay(metricsRecorder); + + rocksDBStore.openDB(context.appConfigs(), context.stateDir()); + verify(metricsRecorder); + reset(metricsRecorder); + } + + public static class RocksDBConfigSetterWithUserProvidedNewBlockBasedTableFormatConfig implements RocksDBConfigSetter { + public RocksDBConfigSetterWithUserProvidedNewBlockBasedTableFormatConfig(){} + + public void setConfig(final String storeName, final Options options, final Map configs) { + options.setTableFormatConfig(new BlockBasedTableConfig()); + } + + public void close(final String storeName, final Options options) { + options.statistics().close(); + } + } + + @Test + public void shouldThrowWhenUserProvidesNewBlockBasedTableFormatConfig() { + rocksDBStore = getRocksDBStoreWithRocksDBMetricsRecorder(); + context = getProcessorContext( + RecordingLevel.DEBUG, + RocksDBConfigSetterWithUserProvidedNewBlockBasedTableFormatConfig.class + ); + assertThrows( + "The used block-based table format configuration does not expose the " + + "block cache. Use the BlockBasedTableConfig instance provided by Options#tableFormatConfig() to configure " + + "the block-based table format of RocksDB. 
Do not provide a new instance of BlockBasedTableConfig to " + + "the RocksDB options.", + ProcessorStateException.class, + () -> rocksDBStore.openDB(context.appConfigs(), context.stateDir()) + ); + } + + public static class RocksDBConfigSetterWithUserProvidedNewPlainTableFormatConfig implements RocksDBConfigSetter { + public RocksDBConfigSetterWithUserProvidedNewPlainTableFormatConfig(){} + + public void setConfig(final String storeName, final Options options, final Map configs) { + options.setTableFormatConfig(new PlainTableConfig()); + } + + public void close(final String storeName, final Options options) { + options.statistics().close(); + } + } + + @Test + public void shouldNotSetCacheInValueProvidersWhenUserProvidesPlainTableFormatConfig() { + rocksDBStore = getRocksDBStoreWithRocksDBMetricsRecorder(); + context = getProcessorContext( + RecordingLevel.DEBUG, + RocksDBConfigSetterWithUserProvidedNewPlainTableFormatConfig.class + ); + metricsRecorder.addValueProviders(eq(DB_NAME), notNull(), isNull(), notNull()); + replay(metricsRecorder); + + rocksDBStore.openDB(context.appConfigs(), context.stateDir()); + verify(metricsRecorder); + reset(metricsRecorder); + } + + @Test + public void shouldNotThrowExceptionOnRestoreWhenThereIsPreExistingRocksDbFiles() { + rocksDBStore.init((StateStoreContext) context, rocksDBStore); + rocksDBStore.put(new Bytes("existingKey".getBytes(UTF_8)), "existingValue".getBytes(UTF_8)); + rocksDBStore.flush(); + + final List> restoreBytes = new ArrayList<>(); + + final byte[] restoredKey = "restoredKey".getBytes(UTF_8); + final byte[] restoredValue = "restoredValue".getBytes(UTF_8); + restoreBytes.add(KeyValue.pair(restoredKey, restoredValue)); + + context.restore(DB_NAME, restoreBytes); + + assertThat( + stringDeserializer.deserialize( + null, + rocksDBStore.get(new Bytes(stringSerializer.serialize(null, "restoredKey")))), + equalTo("restoredValue")); + } + + @Test + public void shouldCallRocksDbConfigSetter() { + MockRocksDbConfigSetter.called = false; + + final Properties props = StreamsTestUtils.getStreamsConfig(); + props.put(StreamsConfig.ROCKSDB_CONFIG_SETTER_CLASS_CONFIG, MockRocksDbConfigSetter.class); + final Object param = new Object(); + props.put("abc.def", param); + final InternalMockProcessorContext context = new InternalMockProcessorContext( + dir, + Serdes.String(), + Serdes.String(), + new StreamsConfig(props) + ); + + rocksDBStore.init((StateStoreContext) context, rocksDBStore); + + assertTrue(MockRocksDbConfigSetter.called); + assertThat(MockRocksDbConfigSetter.configMap.get("abc.def"), equalTo(param)); + } + + @Test + public void shouldThrowProcessorStateExceptionOnOpeningReadOnlyDir() { + final File tmpDir = TestUtils.tempDirectory(); + final InternalMockProcessorContext tmpContext = new InternalMockProcessorContext(tmpDir, new StreamsConfig(StreamsTestUtils.getStreamsConfig())); + + assertTrue(tmpDir.setReadOnly()); + + assertThrows(ProcessorStateException.class, () -> rocksDBStore.openDB(tmpContext.appConfigs(), tmpContext.stateDir())); + } + + @Test + public void shouldPutAll() { + final List> entries = new ArrayList<>(); + entries.add(new KeyValue<>( + new Bytes(stringSerializer.serialize(null, "1")), + stringSerializer.serialize(null, "a"))); + entries.add(new KeyValue<>( + new Bytes(stringSerializer.serialize(null, "2")), + stringSerializer.serialize(null, "b"))); + entries.add(new KeyValue<>( + new Bytes(stringSerializer.serialize(null, "3")), + stringSerializer.serialize(null, "c"))); + + rocksDBStore.init((StateStoreContext) 
context, rocksDBStore); + rocksDBStore.putAll(entries); + rocksDBStore.flush(); + + assertEquals( + "a", + stringDeserializer.deserialize( + null, + rocksDBStore.get(new Bytes(stringSerializer.serialize(null, "1"))))); + assertEquals( + "b", + stringDeserializer.deserialize( + null, + rocksDBStore.get(new Bytes(stringSerializer.serialize(null, "2"))))); + assertEquals( + "c", + stringDeserializer.deserialize( + null, + rocksDBStore.get(new Bytes(stringSerializer.serialize(null, "3"))))); + } + + @Test + public void shouldReturnKeysWithGivenPrefix() { + final List> entries = new ArrayList<>(); + entries.add(new KeyValue<>( + new Bytes(stringSerializer.serialize(null, "k1")), + stringSerializer.serialize(null, "a"))); + entries.add(new KeyValue<>( + new Bytes(stringSerializer.serialize(null, "prefix_3")), + stringSerializer.serialize(null, "b"))); + entries.add(new KeyValue<>( + new Bytes(stringSerializer.serialize(null, "k2")), + stringSerializer.serialize(null, "c"))); + entries.add(new KeyValue<>( + new Bytes(stringSerializer.serialize(null, "prefix_2")), + stringSerializer.serialize(null, "d"))); + entries.add(new KeyValue<>( + new Bytes(stringSerializer.serialize(null, "k3")), + stringSerializer.serialize(null, "e"))); + entries.add(new KeyValue<>( + new Bytes(stringSerializer.serialize(null, "prefix_1")), + stringSerializer.serialize(null, "f"))); + + rocksDBStore.init((StateStoreContext) context, rocksDBStore); + rocksDBStore.putAll(entries); + rocksDBStore.flush(); + + try (final KeyValueIterator keysWithPrefix = rocksDBStore.prefixScan("prefix", stringSerializer)) { + final List valuesWithPrefix = new ArrayList<>(); + int numberOfKeysReturned = 0; + + while (keysWithPrefix.hasNext()) { + final KeyValue next = keysWithPrefix.next(); + valuesWithPrefix.add(new String(next.value)); + numberOfKeysReturned++; + } + assertThat(numberOfKeysReturned, is(3)); + assertThat(valuesWithPrefix.get(0), is("f")); + assertThat(valuesWithPrefix.get(1), is("d")); + assertThat(valuesWithPrefix.get(2), is("b")); + } + } + + @Test + public void shouldReturnKeysWithGivenPrefixExcludingNextKeyLargestKey() { + final List> entries = new ArrayList<>(); + entries.add(new KeyValue<>( + new Bytes(stringSerializer.serialize(null, "abc")), + stringSerializer.serialize(null, "f"))); + + entries.add(new KeyValue<>( + new Bytes(stringSerializer.serialize(null, "abcd")), + stringSerializer.serialize(null, "f"))); + + entries.add(new KeyValue<>( + new Bytes(stringSerializer.serialize(null, "abce")), + stringSerializer.serialize(null, "f"))); + + rocksDBStore.init((StateStoreContext) context, rocksDBStore); + rocksDBStore.putAll(entries); + rocksDBStore.flush(); + + try (final KeyValueIterator keysWithPrefixAsabcd = rocksDBStore.prefixScan("abcd", stringSerializer)) { + int numberOfKeysReturned = 0; + + while (keysWithPrefixAsabcd.hasNext()) { + keysWithPrefixAsabcd.next().key.get(); + numberOfKeysReturned++; + } + + assertThat(numberOfKeysReturned, is(1)); + } + } + + @Test + public void shouldReturnUUIDsWithStringPrefix() { + final List> entries = new ArrayList<>(); + final Serializer uuidSerializer = Serdes.UUID().serializer(); + final UUID uuid1 = UUID.randomUUID(); + final UUID uuid2 = UUID.randomUUID(); + final String prefix = uuid1.toString().substring(0, 4); + final int numMatches = uuid2.toString().substring(0, 4).equals(prefix) ? 
2 : 1; + + entries.add(new KeyValue<>( + new Bytes(uuidSerializer.serialize(null, uuid1)), + stringSerializer.serialize(null, "a"))); + entries.add(new KeyValue<>( + new Bytes(uuidSerializer.serialize(null, uuid2)), + stringSerializer.serialize(null, "b"))); + + rocksDBStore.init((StateStoreContext) context, rocksDBStore); + rocksDBStore.putAll(entries); + rocksDBStore.flush(); + + try (final KeyValueIterator keysWithPrefix = rocksDBStore.prefixScan(prefix, stringSerializer)) { + final List valuesWithPrefix = new ArrayList<>(); + int numberOfKeysReturned = 0; + + while (keysWithPrefix.hasNext()) { + final KeyValue next = keysWithPrefix.next(); + valuesWithPrefix.add(new String(next.value)); + numberOfKeysReturned++; + } + + assertThat(numberOfKeysReturned, is(numMatches)); + if (numMatches == 2) { + assertThat(valuesWithPrefix.get(0), either(is("a")).or(is("b"))); + } else { + assertThat(valuesWithPrefix.get(0), is("a")); + } + } + } + + @Test + public void shouldReturnNoKeys() { + final List> entries = new ArrayList<>(); + entries.add(new KeyValue<>( + new Bytes(stringSerializer.serialize(null, "a")), + stringSerializer.serialize(null, "a"))); + entries.add(new KeyValue<>( + new Bytes(stringSerializer.serialize(null, "b")), + stringSerializer.serialize(null, "c"))); + entries.add(new KeyValue<>( + new Bytes(stringSerializer.serialize(null, "c")), + stringSerializer.serialize(null, "e"))); + rocksDBStore.init((StateStoreContext) context, rocksDBStore); + rocksDBStore.putAll(entries); + rocksDBStore.flush(); + + try (final KeyValueIterator keysWithPrefix = rocksDBStore.prefixScan("d", stringSerializer)) { + int numberOfKeysReturned = 0; + + while (keysWithPrefix.hasNext()) { + keysWithPrefix.next(); + numberOfKeysReturned++; + } + assertThat(numberOfKeysReturned, is(0)); + } + } + + @Test + public void shouldRestoreAll() { + final List> entries = getKeyValueEntries(); + + rocksDBStore.init((StateStoreContext) context, rocksDBStore); + context.restore(rocksDBStore.name(), entries); + + assertEquals( + "a", + stringDeserializer.deserialize( + null, + rocksDBStore.get(new Bytes(stringSerializer.serialize(null, "1"))))); + assertEquals( + "b", + stringDeserializer.deserialize( + null, + rocksDBStore.get(new Bytes(stringSerializer.serialize(null, "2"))))); + assertEquals( + "c", + stringDeserializer.deserialize( + null, + rocksDBStore.get(new Bytes(stringSerializer.serialize(null, "3"))))); + } + + @Test + public void shouldPutOnlyIfAbsentValue() { + rocksDBStore.init((StateStoreContext) context, rocksDBStore); + final Bytes keyBytes = new Bytes(stringSerializer.serialize(null, "one")); + final byte[] valueBytes = stringSerializer.serialize(null, "A"); + final byte[] valueBytesUpdate = stringSerializer.serialize(null, "B"); + + rocksDBStore.putIfAbsent(keyBytes, valueBytes); + rocksDBStore.putIfAbsent(keyBytes, valueBytesUpdate); + + final String retrievedValue = stringDeserializer.deserialize(null, rocksDBStore.get(keyBytes)); + assertEquals("A", retrievedValue); + } + + @Test + public void shouldHandleDeletesOnRestoreAll() { + final List> entries = getKeyValueEntries(); + entries.add(new KeyValue<>("1".getBytes(UTF_8), null)); + + rocksDBStore.init((StateStoreContext) context, rocksDBStore); + context.restore(rocksDBStore.name(), entries); + + try (final KeyValueIterator iterator = rocksDBStore.all()) { + final Set keys = new HashSet<>(); + + while (iterator.hasNext()) { + keys.add(stringDeserializer.deserialize(null, iterator.next().key.get())); + } + + assertThat(keys, 
equalTo(Utils.mkSet("2", "3"))); + } + } + + @Test + public void shouldHandleDeletesAndPutBackOnRestoreAll() { + final List> entries = new ArrayList<>(); + entries.add(new KeyValue<>("1".getBytes(UTF_8), "a".getBytes(UTF_8))); + entries.add(new KeyValue<>("2".getBytes(UTF_8), "b".getBytes(UTF_8))); + // this will be deleted + entries.add(new KeyValue<>("1".getBytes(UTF_8), null)); + entries.add(new KeyValue<>("3".getBytes(UTF_8), "c".getBytes(UTF_8))); + // this will restore key "1" as WriteBatch applies updates in order + entries.add(new KeyValue<>("1".getBytes(UTF_8), "restored".getBytes(UTF_8))); + + rocksDBStore.init((StateStoreContext) context, rocksDBStore); + context.restore(rocksDBStore.name(), entries); + + try (final KeyValueIterator iterator = rocksDBStore.all()) { + final Set keys = new HashSet<>(); + + while (iterator.hasNext()) { + keys.add(stringDeserializer.deserialize(null, iterator.next().key.get())); + } + + assertThat(keys, equalTo(Utils.mkSet("1", "2", "3"))); + + assertEquals( + "restored", + stringDeserializer.deserialize( + null, + rocksDBStore.get(new Bytes(stringSerializer.serialize(null, "1"))))); + assertEquals( + "b", + stringDeserializer.deserialize( + null, + rocksDBStore.get(new Bytes(stringSerializer.serialize(null, "2"))))); + assertEquals( + "c", + stringDeserializer.deserialize( + null, + rocksDBStore.get(new Bytes(stringSerializer.serialize(null, "3"))))); + } + } + + @Test + public void shouldRestoreThenDeleteOnRestoreAll() { + final List> entries = getKeyValueEntries(); + + rocksDBStore.init((StateStoreContext) context, rocksDBStore); + + context.restore(rocksDBStore.name(), entries); + + assertEquals( + "a", + stringDeserializer.deserialize( + null, + rocksDBStore.get(new Bytes(stringSerializer.serialize(null, "1"))))); + assertEquals( + "b", + stringDeserializer.deserialize( + null, + rocksDBStore.get(new Bytes(stringSerializer.serialize(null, "2"))))); + assertEquals( + "c", + stringDeserializer.deserialize( + null, + rocksDBStore.get(new Bytes(stringSerializer.serialize(null, "3"))))); + + entries.clear(); + + entries.add(new KeyValue<>("2".getBytes(UTF_8), "b".getBytes(UTF_8))); + entries.add(new KeyValue<>("3".getBytes(UTF_8), "c".getBytes(UTF_8))); + entries.add(new KeyValue<>("1".getBytes(UTF_8), null)); + + context.restore(rocksDBStore.name(), entries); + + try (final KeyValueIterator iterator = rocksDBStore.all()) { + final Set keys = new HashSet<>(); + + while (iterator.hasNext()) { + keys.add(stringDeserializer.deserialize(null, iterator.next().key.get())); + } + + assertThat(keys, equalTo(Utils.mkSet("2", "3"))); + } + } + + @Test + public void shouldThrowNullPointerExceptionOnNullPut() { + rocksDBStore.init((StateStoreContext) context, rocksDBStore); + assertThrows( + NullPointerException.class, + () -> rocksDBStore.put(null, stringSerializer.serialize(null, "someVal"))); + } + + @Test + public void shouldThrowNullPointerExceptionOnNullPutAll() { + rocksDBStore.init((StateStoreContext) context, rocksDBStore); + assertThrows( + NullPointerException.class, + () -> rocksDBStore.put(null, stringSerializer.serialize(null, "someVal"))); + } + + @Test + public void shouldThrowNullPointerExceptionOnNullGet() { + rocksDBStore.init((StateStoreContext) context, rocksDBStore); + assertThrows( + NullPointerException.class, + () -> rocksDBStore.get(null)); + } + + @Test + public void shouldThrowNullPointerExceptionOnDelete() { + rocksDBStore.init((StateStoreContext) context, rocksDBStore); + assertThrows( + NullPointerException.class, + () -> 
rocksDBStore.delete(null)); + } + + @Test + public void shouldReturnValueOnRange() { + rocksDBStore.init((StateStoreContext) context, rocksDBStore); + + final KeyValue kv0 = new KeyValue<>("0", "zero"); + final KeyValue kv1 = new KeyValue<>("1", "one"); + final KeyValue kv2 = new KeyValue<>("2", "two"); + + rocksDBStore.put(new Bytes(kv0.key.getBytes(UTF_8)), kv0.value.getBytes(UTF_8)); + rocksDBStore.put(new Bytes(kv1.key.getBytes(UTF_8)), kv1.value.getBytes(UTF_8)); + rocksDBStore.put(new Bytes(kv2.key.getBytes(UTF_8)), kv2.value.getBytes(UTF_8)); + + final LinkedList> expectedContents = new LinkedList<>(); + expectedContents.add(kv0); + expectedContents.add(kv1); + + try (final KeyValueIterator iterator = rocksDBStore.range(null, new Bytes(stringSerializer.serialize(null, "1")))) { + assertEquals(expectedContents, getDeserializedList(iterator)); + } + } + + @Test + public void shouldThrowProcessorStateExceptionOnPutDeletedDir() throws IOException { + rocksDBStore.init((StateStoreContext) context, rocksDBStore); + Utils.delete(dir); + rocksDBStore.put( + new Bytes(stringSerializer.serialize(null, "anyKey")), + stringSerializer.serialize(null, "anyValue")); + assertThrows(ProcessorStateException.class, () -> rocksDBStore.flush()); + } + + @Test + public void shouldHandleToggleOfEnablingBloomFilters() { + final Properties props = StreamsTestUtils.getStreamsConfig(); + props.put(StreamsConfig.ROCKSDB_CONFIG_SETTER_CLASS_CONFIG, TestingBloomFilterRocksDBConfigSetter.class); + dir = TestUtils.tempDirectory(); + context = new InternalMockProcessorContext(dir, + Serdes.String(), + Serdes.String(), + new StreamsConfig(props)); + + enableBloomFilters = false; + rocksDBStore.init((StateStoreContext) context, rocksDBStore); + + final List expectedValues = new ArrayList<>(); + expectedValues.add("a"); + expectedValues.add("b"); + expectedValues.add("c"); + + final List> keyValues = getKeyValueEntries(); + for (final KeyValue keyValue : keyValues) { + rocksDBStore.put(new Bytes(keyValue.key), keyValue.value); + } + + int expectedIndex = 0; + for (final KeyValue keyValue : keyValues) { + final byte[] valBytes = rocksDBStore.get(new Bytes(keyValue.key)); + assertThat(new String(valBytes, UTF_8), is(expectedValues.get(expectedIndex++))); + } + assertFalse(TestingBloomFilterRocksDBConfigSetter.bloomFiltersSet); + + rocksDBStore.close(); + expectedIndex = 0; + + // reopen with Bloom Filters enabled + // should open fine without errors + enableBloomFilters = true; + rocksDBStore.init((StateStoreContext) context, rocksDBStore); + + for (final KeyValue keyValue : keyValues) { + final byte[] valBytes = rocksDBStore.get(new Bytes(keyValue.key)); + assertThat(new String(valBytes, UTF_8), is(expectedValues.get(expectedIndex++))); + } + + assertTrue(TestingBloomFilterRocksDBConfigSetter.bloomFiltersSet); + } + + @Test + public void shouldVerifyThatMetricsRecordedFromStatisticsGetMeasurementsFromRocksDB() { + final TaskId taskId = new TaskId(0, 0); + + final Metrics metrics = new Metrics(new MetricConfig().recordLevel(RecordingLevel.DEBUG)); + final StreamsMetricsImpl streamsMetrics = + new StreamsMetricsImpl(metrics, "test-application", StreamsConfig.METRICS_LATEST, time); + + context = EasyMock.niceMock(InternalMockProcessorContext.class); + EasyMock.expect(context.metrics()).andStubReturn(streamsMetrics); + EasyMock.expect(context.taskId()).andStubReturn(taskId); + EasyMock.expect(context.appConfigs()) + .andStubReturn(new StreamsConfig(StreamsTestUtils.getStreamsConfig()).originals()); + 
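+        // Setup note: this context is an EasyMock niceMock, so anything not stubbed explicitly just returns
+        // a default. The stubs around this point (metrics registry, task id, app configs and, just below,
+        // the state directory) appear to be the minimum the store needs during init() to register its
+        // statistics-based RocksDB metrics and to record them via the recording trigger further down.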
EasyMock.expect(context.stateDir()).andStubReturn(dir); + EasyMock.replay(context); + + rocksDBStore.init((StateStoreContext) context, rocksDBStore); + final byte[] key = "hello".getBytes(); + final byte[] value = "world".getBytes(); + rocksDBStore.put(Bytes.wrap(key), value); + + streamsMetrics.rocksDBMetricsRecordingTrigger().run(); + + final Metric bytesWrittenTotal = metrics.metric(new MetricName( + "bytes-written-total", + StreamsMetricsImpl.STATE_STORE_LEVEL_GROUP, + "description is not verified", + streamsMetrics.storeLevelTagMap(taskId.toString(), METRICS_SCOPE, DB_NAME) + )); + assertThat((double) bytesWrittenTotal.metricValue(), greaterThan(0d)); + } + + @Test + public void shouldVerifyThatMetricsRecordedFromPropertiesGetMeasurementsFromRocksDB() { + final TaskId taskId = new TaskId(0, 0); + + final Metrics metrics = new Metrics(new MetricConfig().recordLevel(RecordingLevel.INFO)); + final StreamsMetricsImpl streamsMetrics = + new StreamsMetricsImpl(metrics, "test-application", StreamsConfig.METRICS_LATEST, time); + + context = EasyMock.niceMock(InternalMockProcessorContext.class); + EasyMock.expect(context.metrics()).andStubReturn(streamsMetrics); + EasyMock.expect(context.taskId()).andStubReturn(taskId); + EasyMock.expect(context.appConfigs()) + .andStubReturn(new StreamsConfig(StreamsTestUtils.getStreamsConfig()).originals()); + EasyMock.expect(context.stateDir()).andStubReturn(dir); + EasyMock.replay(context); + + rocksDBStore.init((StateStoreContext) context, rocksDBStore); + final byte[] key = "hello".getBytes(); + final byte[] value = "world".getBytes(); + rocksDBStore.put(Bytes.wrap(key), value); + + final Metric numberOfEntriesActiveMemTable = metrics.metric(new MetricName( + "num-entries-active-mem-table", + StreamsMetricsImpl.STATE_STORE_LEVEL_GROUP, + "description is not verified", + streamsMetrics.storeLevelTagMap(taskId.toString(), METRICS_SCOPE, DB_NAME) + )); + assertThat(numberOfEntriesActiveMemTable, notNullValue()); + assertThat((BigInteger) numberOfEntriesActiveMemTable.metricValue(), greaterThan(BigInteger.valueOf(0))); + } + + @Test + public void shouldVerifyThatPropertyBasedMetricsUseValidPropertyName() { + final TaskId taskId = new TaskId(0, 0); + + final Metrics metrics = new Metrics(new MetricConfig().recordLevel(RecordingLevel.INFO)); + final StreamsMetricsImpl streamsMetrics = + new StreamsMetricsImpl(metrics, "test-application", StreamsConfig.METRICS_LATEST, time); + + final Properties props = StreamsTestUtils.getStreamsConfig(); + context = EasyMock.niceMock(InternalMockProcessorContext.class); + EasyMock.expect(context.metrics()).andStubReturn(streamsMetrics); + EasyMock.expect(context.taskId()).andStubReturn(taskId); + EasyMock.expect(context.appConfigs()).andStubReturn(new StreamsConfig(props).originals()); + EasyMock.expect(context.stateDir()).andStubReturn(dir); + EasyMock.replay(context); + + rocksDBStore.init((StateStoreContext) context, rocksDBStore); + + final List propertyNames = Arrays.asList( + "num-entries-active-mem-table", + "num-deletes-active-mem-table", + "num-entries-imm-mem-tables", + "num-deletes-imm-mem-tables", + "num-immutable-mem-table", + "cur-size-active-mem-table", + "cur-size-all-mem-tables", + "size-all-mem-tables", + "mem-table-flush-pending", + "num-running-flushes", + "compaction-pending", + "num-running-compactions", + "estimate-pending-compaction-bytes", + "total-sst-files-size", + "live-sst-files-size", + "num-live-versions", + "block-cache-capacity", + "block-cache-usage", + "block-cache-pinned-usage", + 
"estimate-num-keys", + "estimate-table-readers-mem", + "background-errors" + ); + for (final String propertyname : propertyNames) { + final Metric metric = metrics.metric(new MetricName( + propertyname, + StreamsMetricsImpl.STATE_STORE_LEVEL_GROUP, + "description is not verified", + streamsMetrics.storeLevelTagMap(taskId.toString(), METRICS_SCOPE, DB_NAME) + )); + assertThat("Metric " + propertyname + " not found!", metric, notNullValue()); + metric.metricValue(); + } + } + + @Test + public void shouldPerformRangeQueriesWithCachingDisabled() { + context.setTime(1L); + store.put(1, "hi"); + store.put(2, "goodbye"); + try (final KeyValueIterator range = store.range(1, 2)) { + assertEquals("hi", range.next().value); + assertEquals("goodbye", range.next().value); + assertFalse(range.hasNext()); + } + } + + @Test + public void shouldPerformAllQueriesWithCachingDisabled() { + context.setTime(1L); + store.put(1, "hi"); + store.put(2, "goodbye"); + try (final KeyValueIterator range = store.all()) { + assertEquals("hi", range.next().value); + assertEquals("goodbye", range.next().value); + assertFalse(range.hasNext()); + } + } + + @Test + public void shouldCloseOpenRangeIteratorsWhenStoreClosedAndThrowInvalidStateStoreOnHasNextAndNext() { + context.setTime(1L); + store.put(1, "hi"); + store.put(2, "goodbye"); + try (final KeyValueIterator iteratorOne = store.range(1, 5); + final KeyValueIterator iteratorTwo = store.range(1, 4)) { + + assertTrue(iteratorOne.hasNext()); + assertTrue(iteratorTwo.hasNext()); + + store.close(); + + Assertions.assertThrows(InvalidStateStoreException.class, () -> iteratorOne.hasNext()); + Assertions.assertThrows(InvalidStateStoreException.class, () -> iteratorOne.next()); + Assertions.assertThrows(InvalidStateStoreException.class, () -> iteratorTwo.hasNext()); + Assertions.assertThrows(InvalidStateStoreException.class, () -> iteratorTwo.next()); + } + } + + public static class TestingBloomFilterRocksDBConfigSetter implements RocksDBConfigSetter { + + static boolean bloomFiltersSet; + static Filter filter; + static Cache cache; + + @Override + public void setConfig(final String storeName, final Options options, final Map configs) { + final BlockBasedTableConfig tableConfig = (BlockBasedTableConfig) options.tableFormatConfig(); + cache = new LRUCache(50 * 1024 * 1024L); + tableConfig.setBlockCache(cache); + tableConfig.setBlockSize(4096L); + if (enableBloomFilters) { + filter = new BloomFilter(); + tableConfig.setFilterPolicy(filter); + options.optimizeFiltersForHits(); + bloomFiltersSet = true; + } else { + options.setOptimizeFiltersForHits(false); + bloomFiltersSet = false; + } + + options.setTableFormatConfig(tableConfig); + } + + @Override + public void close(final String storeName, final Options options) { + if (filter != null) { + filter.close(); + } + cache.close(); + } + } + + private List> getKeyValueEntries() { + final List> entries = new ArrayList<>(); + entries.add(new KeyValue<>("1".getBytes(UTF_8), "a".getBytes(UTF_8))); + entries.add(new KeyValue<>("2".getBytes(UTF_8), "b".getBytes(UTF_8))); + entries.add(new KeyValue<>("3".getBytes(UTF_8), "c".getBytes(UTF_8))); + return entries; + } + + private List> getDeserializedList(final KeyValueIterator iter) { + final List> bytes = Utils.toList(iter); + final List> result = bytes.stream().map(kv -> new KeyValue(kv.key.toString(), stringDeserializer.deserialize(null, kv.value))).collect(Collectors.toList()); + return result; + } + +} diff --git 
a/streams/src/test/java/org/apache/kafka/streams/state/internals/RocksDBTimestampedSegmentedBytesStoreTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/RocksDBTimestampedSegmentedBytesStoreTest.java new file mode 100644 index 0000000..814a04c --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/RocksDBTimestampedSegmentedBytesStoreTest.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.state.internals; + +public class RocksDBTimestampedSegmentedBytesStoreTest + extends AbstractRocksDBSegmentedBytesStoreTest { + + private final static String METRICS_SCOPE = "metrics-scope"; + + RocksDBTimestampedSegmentedBytesStore getBytesStore() { + return new RocksDBTimestampedSegmentedBytesStore( + storeName, + METRICS_SCOPE, + retention, + segmentInterval, + schema + ); + } + + @Override + TimestampedSegments newSegments() { + return new TimestampedSegments(storeName, METRICS_SCOPE, retention, segmentInterval); + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/RocksDBTimestampedStoreTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/RocksDBTimestampedStoreTest.java new file mode 100644 index 0000000..a1d511a --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/RocksDBTimestampedStoreTest.java @@ -0,0 +1,483 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.common.serialization.Serializer; +import org.apache.kafka.common.serialization.StringSerializer; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.processor.StateStoreContext; +import org.apache.kafka.streams.processor.internals.testutil.LogCaptureAppender; +import org.apache.kafka.streams.state.KeyValueIterator; +import org.hamcrest.core.IsNull; +import org.junit.Test; +import org.rocksdb.ColumnFamilyDescriptor; +import org.rocksdb.ColumnFamilyHandle; +import org.rocksdb.ColumnFamilyOptions; +import org.rocksdb.DBOptions; +import org.rocksdb.RocksDB; + +import java.io.File; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.List; + +import static java.util.Arrays.asList; +import static org.hamcrest.CoreMatchers.hasItem; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertFalse; + +public class RocksDBTimestampedStoreTest extends RocksDBStoreTest { + + private final Serializer stringSerializer = new StringSerializer(); + + RocksDBStore getRocksDBStore() { + return new RocksDBTimestampedStore(DB_NAME, METRICS_SCOPE); + } + + @Test + public void shouldOpenNewStoreInRegularMode() { + try (final LogCaptureAppender appender = LogCaptureAppender.createAndRegister(RocksDBTimestampedStore.class)) { + rocksDBStore.init((StateStoreContext) context, rocksDBStore); + + assertThat(appender.getMessages(), hasItem("Opening store " + DB_NAME + " in regular mode")); + } + + try (final KeyValueIterator iterator = rocksDBStore.all()) { + assertThat(iterator.hasNext(), is(false)); + } + } + + @Test + public void shouldOpenExistingStoreInRegularMode() throws Exception { + // prepare store + rocksDBStore.init((StateStoreContext) context, rocksDBStore); + rocksDBStore.put(new Bytes("key".getBytes()), "timestamped".getBytes()); + rocksDBStore.close(); + + // re-open store + try (final LogCaptureAppender appender = LogCaptureAppender.createAndRegister(RocksDBTimestampedStore.class)) { + rocksDBStore.init((StateStoreContext) context, rocksDBStore); + + assertThat(appender.getMessages(), hasItem("Opening store " + DB_NAME + " in regular mode")); + } finally { + rocksDBStore.close(); + } + + // verify store + final DBOptions dbOptions = new DBOptions(); + final ColumnFamilyOptions columnFamilyOptions = new ColumnFamilyOptions(); + + final List columnFamilyDescriptors = asList( + new ColumnFamilyDescriptor(RocksDB.DEFAULT_COLUMN_FAMILY, columnFamilyOptions), + new ColumnFamilyDescriptor("keyValueWithTimestamp".getBytes(StandardCharsets.UTF_8), columnFamilyOptions)); + final List columnFamilies = new ArrayList<>(columnFamilyDescriptors.size()); + + RocksDB db = null; + ColumnFamilyHandle noTimestampColumnFamily = null, withTimestampColumnFamily = null; + try { + db = RocksDB.open( + dbOptions, + new File(new File(context.stateDir(), "rocksdb"), DB_NAME).getAbsolutePath(), + columnFamilyDescriptors, + columnFamilies); + + noTimestampColumnFamily = columnFamilies.get(0); + withTimestampColumnFamily = columnFamilies.get(1); + + assertThat(db.get(noTimestampColumnFamily, "key".getBytes()), new IsNull<>()); + assertThat(db.getLongProperty(noTimestampColumnFamily, "rocksdb.estimate-num-keys"), is(0L)); + assertThat(db.get(withTimestampColumnFamily, "key".getBytes()).length, is(11)); + 
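+            // "timestamped".getBytes() is 11 bytes and the byte store keeps the value as-is, so this length
+            // check only proves that the record landed in the keyValueWithTimestamp column family. In the
+            // migration tests below, values in that column family are an 8-byte timestamp followed by the
+            // payload (hence the 8 + n length assertions), and lazily migrated entries carry the surrogate
+            // timestamp -1, i.e. eight 0xFF bytes. Roughly (a sketch of that layout, not the store's code):
+            //     byte[] prefixed = ByteBuffer.allocate(8 + payload.length)
+            //         .putLong(-1L)      // surrogate "unknown" timestamp
+            //         .put(payload)
+            //         .array();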
assertThat(db.getLongProperty(withTimestampColumnFamily, "rocksdb.estimate-num-keys"), is(1L)); + } finally { + // Order of closing must follow: ColumnFamilyHandle > RocksDB > DBOptions > ColumnFamilyOptions + if (noTimestampColumnFamily != null) { + noTimestampColumnFamily.close(); + } + if (withTimestampColumnFamily != null) { + withTimestampColumnFamily.close(); + } + if (db != null) { + db.close(); + } + dbOptions.close(); + columnFamilyOptions.close(); + } + } + + @Test + public void shouldMigrateDataFromDefaultToTimestampColumnFamily() throws Exception { + prepareOldStore(); + + try (final LogCaptureAppender appender = LogCaptureAppender.createAndRegister(RocksDBTimestampedStore.class)) { + rocksDBStore.init((StateStoreContext) context, rocksDBStore); + + assertThat(appender.getMessages(), hasItem("Opening store " + DB_NAME + " in upgrade mode")); + } + + // approx: 7 entries on old CF, 0 in new CF + assertThat(rocksDBStore.approximateNumEntries(), is(7L)); + + // get() + + // should be no-op on both CF + assertThat(rocksDBStore.get(new Bytes("unknown".getBytes())), new IsNull<>()); + // approx: 7 entries on old CF, 0 in new CF + assertThat(rocksDBStore.approximateNumEntries(), is(7L)); + + // should migrate key1 from old to new CF + // must return timestamp plus value, ie, it's not 1 byte but 9 bytes + assertThat(rocksDBStore.get(new Bytes("key1".getBytes())).length, is(8 + 1)); + // one delete on old CF, one put on new CF + // approx: 6 entries on old CF, 1 in new CF + assertThat(rocksDBStore.approximateNumEntries(), is(7L)); + + // put() + + // should migrate key2 from old to new CF with new value + rocksDBStore.put(new Bytes("key2".getBytes()), "timestamp+22".getBytes()); + // one delete on old CF, one put on new CF + // approx: 5 entries on old CF, 2 in new CF + assertThat(rocksDBStore.approximateNumEntries(), is(7L)); + + // should delete key3 from old and new CF + rocksDBStore.put(new Bytes("key3".getBytes()), null); + // count is off by one, due to two delete operations (even if one does not delete anything) + // approx: 4 entries on old CF, 1 in new CF + assertThat(rocksDBStore.approximateNumEntries(), is(5L)); + + // should add new key8 to new CF + rocksDBStore.put(new Bytes("key8".getBytes()), "timestamp+88888888".getBytes()); + // one delete on old CF, one put on new CF + // approx: 3 entries on old CF, 2 in new CF + assertThat(rocksDBStore.approximateNumEntries(), is(5L)); + + // putIfAbsent() + + // should migrate key4 from old to new CF with old value + assertThat(rocksDBStore.putIfAbsent(new Bytes("key4".getBytes()), "timestamp+4444".getBytes()).length, is(8 + 4)); + // one delete on old CF, one put on new CF + // approx: 2 entries on old CF, 3 in new CF + assertThat(rocksDBStore.approximateNumEntries(), is(5L)); + + // should add new key11 to new CF + assertThat(rocksDBStore.putIfAbsent(new Bytes("key11".getBytes()), "timestamp+11111111111".getBytes()), new IsNull<>()); + // one delete on old CF, one put on new CF + // approx: 1 entries on old CF, 4 in new CF + assertThat(rocksDBStore.approximateNumEntries(), is(5L)); + + // should not delete key5 but migrate to new CF + assertThat(rocksDBStore.putIfAbsent(new Bytes("key5".getBytes()), null).length, is(8 + 5)); + // one delete on old CF, one put on new CF + // approx: 0 entries on old CF, 5 in new CF + assertThat(rocksDBStore.approximateNumEntries(), is(5L)); + + // should be no-op on both CF + assertThat(rocksDBStore.putIfAbsent(new Bytes("key12".getBytes()), null), new IsNull<>()); + // two delete operation, 
however, only one is counted because old CF count was zero before already + // approx: 0 entries on old CF, 4 in new CF + assertThat(rocksDBStore.approximateNumEntries(), is(4L)); + + // delete() + + // should delete key6 from old and new CF + assertThat(rocksDBStore.delete(new Bytes("key6".getBytes())).length, is(8 + 6)); + // two delete operation, however, only one is counted because old CF count was zero before already + // approx: 0 entries on old CF, 3 in new CF + assertThat(rocksDBStore.approximateNumEntries(), is(3L)); + + iteratorsShouldNotMigrateData(); + assertThat(rocksDBStore.approximateNumEntries(), is(3L)); + + rocksDBStore.close(); + + verifyOldAndNewColumnFamily(); + } + + private void iteratorsShouldNotMigrateData() { + // iterating should not migrate any data, but return all key over both CF (plus surrogate timestamps for old CF) + try (final KeyValueIterator itAll = rocksDBStore.all()) { + { + final KeyValue keyValue = itAll.next(); + assertArrayEquals("key1".getBytes(), keyValue.key.get()); + // unknown timestamp == -1 plus value == 1 + assertArrayEquals(new byte[]{-1, -1, -1, -1, -1, -1, -1, -1, '1'}, keyValue.value); + } + { + final KeyValue keyValue = itAll.next(); + assertArrayEquals("key11".getBytes(), keyValue.key.get()); + assertArrayEquals(new byte[]{'t', 'i', 'm', 'e', 's', 't', 'a', 'm', 'p', '+', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1'}, keyValue.value); + } + { + final KeyValue keyValue = itAll.next(); + assertArrayEquals("key2".getBytes(), keyValue.key.get()); + assertArrayEquals(new byte[]{'t', 'i', 'm', 'e', 's', 't', 'a', 'm', 'p', '+', '2', '2'}, keyValue.value); + } + { + final KeyValue keyValue = itAll.next(); + assertArrayEquals("key4".getBytes(), keyValue.key.get()); + // unknown timestamp == -1 plus value == 4444 + assertArrayEquals(new byte[]{-1, -1, -1, -1, -1, -1, -1, -1, '4', '4', '4', '4'}, keyValue.value); + } + { + final KeyValue keyValue = itAll.next(); + assertArrayEquals("key5".getBytes(), keyValue.key.get()); + // unknown timestamp == -1 plus value == 55555 + assertArrayEquals(new byte[]{-1, -1, -1, -1, -1, -1, -1, -1, '5', '5', '5', '5', '5'}, keyValue.value); + } + { + final KeyValue keyValue = itAll.next(); + assertArrayEquals("key7".getBytes(), keyValue.key.get()); + // unknown timestamp == -1 plus value == 7777777 + assertArrayEquals(new byte[]{-1, -1, -1, -1, -1, -1, -1, -1, '7', '7', '7', '7', '7', '7', '7'}, keyValue.value); + } + { + final KeyValue keyValue = itAll.next(); + assertArrayEquals("key8".getBytes(), keyValue.key.get()); + assertArrayEquals(new byte[]{'t', 'i', 'm', 'e', 's', 't', 'a', 'm', 'p', '+', '8', '8', '8', '8', '8', '8', '8', '8'}, keyValue.value); + } + assertFalse(itAll.hasNext()); + } + + try (final KeyValueIterator it = + rocksDBStore.range(new Bytes("key2".getBytes()), new Bytes("key5".getBytes()))) { + { + final KeyValue keyValue = it.next(); + assertArrayEquals("key2".getBytes(), keyValue.key.get()); + assertArrayEquals(new byte[]{'t', 'i', 'm', 'e', 's', 't', 'a', 'm', 'p', '+', '2', '2'}, keyValue.value); + } + { + final KeyValue keyValue = it.next(); + assertArrayEquals("key4".getBytes(), keyValue.key.get()); + // unknown timestamp == -1 plus value == 4444 + assertArrayEquals(new byte[]{-1, -1, -1, -1, -1, -1, -1, -1, '4', '4', '4', '4'}, keyValue.value); + } + { + final KeyValue keyValue = it.next(); + assertArrayEquals("key5".getBytes(), keyValue.key.get()); + // unknown timestamp == -1 plus value == 55555 + assertArrayEquals(new byte[]{-1, -1, -1, -1, -1, -1, -1, -1, '5', '5', '5', 
'5', '5'}, keyValue.value); + } + assertFalse(it.hasNext()); + } + + try (final KeyValueIterator itAll = rocksDBStore.reverseAll()) { + { + final KeyValue keyValue = itAll.next(); + assertArrayEquals("key8".getBytes(), keyValue.key.get()); + assertArrayEquals(new byte[]{'t', 'i', 'm', 'e', 's', 't', 'a', 'm', 'p', '+', '8', '8', '8', '8', '8', '8', '8', '8'}, keyValue.value); + } + { + final KeyValue keyValue = itAll.next(); + assertArrayEquals("key7".getBytes(), keyValue.key.get()); + // unknown timestamp == -1 plus value == 7777777 + assertArrayEquals(new byte[]{-1, -1, -1, -1, -1, -1, -1, -1, '7', '7', '7', '7', '7', '7', '7'}, keyValue.value); + } + { + final KeyValue keyValue = itAll.next(); + assertArrayEquals("key5".getBytes(), keyValue.key.get()); + // unknown timestamp == -1 plus value == 55555 + assertArrayEquals(new byte[]{-1, -1, -1, -1, -1, -1, -1, -1, '5', '5', '5', '5', '5'}, keyValue.value); + } + { + final KeyValue keyValue = itAll.next(); + assertArrayEquals("key4".getBytes(), keyValue.key.get()); + // unknown timestamp == -1 plus value == 4444 + assertArrayEquals(new byte[]{-1, -1, -1, -1, -1, -1, -1, -1, '4', '4', '4', '4'}, keyValue.value); + } + { + final KeyValue keyValue = itAll.next(); + assertArrayEquals("key2".getBytes(), keyValue.key.get()); + assertArrayEquals(new byte[]{'t', 'i', 'm', 'e', 's', 't', 'a', 'm', 'p', '+', '2', '2'}, keyValue.value); + } + { + final KeyValue keyValue = itAll.next(); + assertArrayEquals("key11".getBytes(), keyValue.key.get()); + assertArrayEquals(new byte[]{'t', 'i', 'm', 'e', 's', 't', 'a', 'm', 'p', '+', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1'}, keyValue.value); + } + { + final KeyValue keyValue = itAll.next(); + assertArrayEquals("key1".getBytes(), keyValue.key.get()); + // unknown timestamp == -1 plus value == 1 + assertArrayEquals(new byte[]{-1, -1, -1, -1, -1, -1, -1, -1, '1'}, keyValue.value); + } + assertFalse(itAll.hasNext()); + } + + try (final KeyValueIterator it = + rocksDBStore.reverseRange(new Bytes("key2".getBytes()), new Bytes("key5".getBytes()))) { + { + final KeyValue keyValue = it.next(); + assertArrayEquals("key5".getBytes(), keyValue.key.get()); + // unknown timestamp == -1 plus value == 55555 + assertArrayEquals(new byte[]{-1, -1, -1, -1, -1, -1, -1, -1, '5', '5', '5', '5', '5'}, keyValue.value); + } + { + final KeyValue keyValue = it.next(); + assertArrayEquals("key4".getBytes(), keyValue.key.get()); + // unknown timestamp == -1 plus value == 4444 + assertArrayEquals(new byte[]{-1, -1, -1, -1, -1, -1, -1, -1, '4', '4', '4', '4'}, keyValue.value); + } + { + final KeyValue keyValue = it.next(); + assertArrayEquals("key2".getBytes(), keyValue.key.get()); + assertArrayEquals(new byte[]{'t', 'i', 'm', 'e', 's', 't', 'a', 'm', 'p', '+', '2', '2'}, keyValue.value); + } + assertFalse(it.hasNext()); + } + + try (final KeyValueIterator it = rocksDBStore.prefixScan("key1", stringSerializer)) { + { + final KeyValue keyValue = it.next(); + assertArrayEquals("key1".getBytes(), keyValue.key.get()); + // unknown timestamp == -1 plus value == 1 + assertArrayEquals(new byte[]{-1, -1, -1, -1, -1, -1, -1, -1, '1'}, keyValue.value); + } + { + final KeyValue keyValue = it.next(); + assertArrayEquals("key11".getBytes(), keyValue.key.get()); + assertArrayEquals(new byte[]{'t', 'i', 'm', 'e', 's', 't', 'a', 'm', 'p', '+', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1', '1'}, keyValue.value); + } + assertFalse(it.hasNext()); + } + } + + private void verifyOldAndNewColumnFamily() throws Exception { + final DBOptions 
dbOptions = new DBOptions(); + final ColumnFamilyOptions columnFamilyOptions = new ColumnFamilyOptions(); + + final List columnFamilyDescriptors = asList( + new ColumnFamilyDescriptor(RocksDB.DEFAULT_COLUMN_FAMILY, columnFamilyOptions), + new ColumnFamilyDescriptor("keyValueWithTimestamp".getBytes(StandardCharsets.UTF_8), columnFamilyOptions)); + final List columnFamilies = new ArrayList<>(columnFamilyDescriptors.size()); + + RocksDB db = null; + ColumnFamilyHandle noTimestampColumnFamily = null, withTimestampColumnFamily = null; + boolean errorOccurred = false; + try { + db = RocksDB.open( + dbOptions, + new File(new File(context.stateDir(), "rocksdb"), DB_NAME).getAbsolutePath(), + columnFamilyDescriptors, + columnFamilies); + + noTimestampColumnFamily = columnFamilies.get(0); + withTimestampColumnFamily = columnFamilies.get(1); + + assertThat(db.get(noTimestampColumnFamily, "unknown".getBytes()), new IsNull<>()); + assertThat(db.get(noTimestampColumnFamily, "key1".getBytes()), new IsNull<>()); + assertThat(db.get(noTimestampColumnFamily, "key2".getBytes()), new IsNull<>()); + assertThat(db.get(noTimestampColumnFamily, "key3".getBytes()), new IsNull<>()); + assertThat(db.get(noTimestampColumnFamily, "key4".getBytes()), new IsNull<>()); + assertThat(db.get(noTimestampColumnFamily, "key5".getBytes()), new IsNull<>()); + assertThat(db.get(noTimestampColumnFamily, "key6".getBytes()), new IsNull<>()); + assertThat(db.get(noTimestampColumnFamily, "key7".getBytes()).length, is(7)); + assertThat(db.get(noTimestampColumnFamily, "key8".getBytes()), new IsNull<>()); + assertThat(db.get(noTimestampColumnFamily, "key11".getBytes()), new IsNull<>()); + assertThat(db.get(noTimestampColumnFamily, "key12".getBytes()), new IsNull<>()); + + assertThat(db.get(withTimestampColumnFamily, "unknown".getBytes()), new IsNull<>()); + assertThat(db.get(withTimestampColumnFamily, "key1".getBytes()).length, is(8 + 1)); + assertThat(db.get(withTimestampColumnFamily, "key2".getBytes()).length, is(12)); + assertThat(db.get(withTimestampColumnFamily, "key3".getBytes()), new IsNull<>()); + assertThat(db.get(withTimestampColumnFamily, "key4".getBytes()).length, is(8 + 4)); + assertThat(db.get(withTimestampColumnFamily, "key5".getBytes()).length, is(8 + 5)); + assertThat(db.get(withTimestampColumnFamily, "key6".getBytes()), new IsNull<>()); + assertThat(db.get(withTimestampColumnFamily, "key7".getBytes()), new IsNull<>()); + assertThat(db.get(withTimestampColumnFamily, "key8".getBytes()).length, is(18)); + assertThat(db.get(withTimestampColumnFamily, "key11".getBytes()).length, is(21)); + assertThat(db.get(withTimestampColumnFamily, "key12".getBytes()), new IsNull<>()); + } catch (final RuntimeException fatal) { + errorOccurred = true; + } finally { + // Order of closing must follow: ColumnFamilyHandle > RocksDB > DBOptions > ColumnFamilyOptions + if (noTimestampColumnFamily != null) { + noTimestampColumnFamily.close(); + } + if (withTimestampColumnFamily != null) { + withTimestampColumnFamily.close(); + } + if (db != null) { + db.close(); + } + if (errorOccurred) { + dbOptions.close(); + columnFamilyOptions.close(); + } + } + + // check that still in upgrade mode + try (LogCaptureAppender appender = LogCaptureAppender.createAndRegister(RocksDBTimestampedStore.class)) { + rocksDBStore.init((StateStoreContext) context, rocksDBStore); + + assertThat(appender.getMessages(), hasItem("Opening store " + DB_NAME + " in upgrade mode")); + } finally { + rocksDBStore.close(); + } + + // clear old CF + columnFamilies.clear(); + db = 
null; + noTimestampColumnFamily = null; + try { + db = RocksDB.open( + dbOptions, + new File(new File(context.stateDir(), "rocksdb"), DB_NAME).getAbsolutePath(), + columnFamilyDescriptors, + columnFamilies); + + noTimestampColumnFamily = columnFamilies.get(0); + db.delete(noTimestampColumnFamily, "key7".getBytes()); + } finally { + // Order of closing must follow: ColumnFamilyHandle > RocksDB > DBOptions > ColumnFamilyOptions + if (noTimestampColumnFamily != null) { + noTimestampColumnFamily.close(); + } + if (db != null) { + db.close(); + } + dbOptions.close(); + columnFamilyOptions.close(); + } + + // check that still in regular mode + try (LogCaptureAppender appender = LogCaptureAppender.createAndRegister(RocksDBTimestampedStore.class)) { + rocksDBStore.init((StateStoreContext) context, rocksDBStore); + + assertThat(appender.getMessages(), hasItem("Opening store " + DB_NAME + " in regular mode")); + } + } + + private void prepareOldStore() { + final RocksDBStore keyValueStore = new RocksDBStore(DB_NAME, METRICS_SCOPE); + try { + keyValueStore.init((StateStoreContext) context, keyValueStore); + + keyValueStore.put(new Bytes("key1".getBytes()), "1".getBytes()); + keyValueStore.put(new Bytes("key2".getBytes()), "22".getBytes()); + keyValueStore.put(new Bytes("key3".getBytes()), "333".getBytes()); + keyValueStore.put(new Bytes("key4".getBytes()), "4444".getBytes()); + keyValueStore.put(new Bytes("key5".getBytes()), "55555".getBytes()); + keyValueStore.put(new Bytes("key6".getBytes()), "666666".getBytes()); + keyValueStore.put(new Bytes("key7".getBytes()), "7777777".getBytes()); + } finally { + keyValueStore.close(); + } + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/RocksDBWindowStoreTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/RocksDBWindowStoreTest.java new file mode 100644 index 0000000..7da5245 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/RocksDBWindowStoreTest.java @@ -0,0 +1,646 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.state.internals; + +import java.io.File; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.HashSet; +import java.util.Set; + +import org.apache.kafka.clients.producer.ProducerRecord; +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.processor.StateStoreContext; +import org.apache.kafka.streams.state.Stores; +import org.apache.kafka.streams.state.WindowStore; +import org.apache.kafka.streams.state.WindowStoreIterator; +import org.junit.Test; + +import static java.time.Duration.ofMillis; +import static java.time.Instant.ofEpochMilli; +import static java.util.Arrays.asList; +import static java.util.Objects.requireNonNull; +import static org.apache.kafka.test.StreamsTestUtils.valuesToSet; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; + +public class RocksDBWindowStoreTest extends AbstractWindowBytesStoreTest { + + private static final String STORE_NAME = "rocksDB window store"; + private static final String METRICS_SCOPE = "test-state-id"; + + private final KeyValueSegments segments = + new KeyValueSegments(STORE_NAME, METRICS_SCOPE, RETENTION_PERIOD, SEGMENT_INTERVAL); + + @Override + WindowStore buildWindowStore(final long retentionPeriod, + final long windowSize, + final boolean retainDuplicates, + final Serde keySerde, + final Serde valueSerde) { + return Stores.windowStoreBuilder( + Stores.persistentWindowStore( + STORE_NAME, + ofMillis(retentionPeriod), + ofMillis(windowSize), + retainDuplicates), + keySerde, + valueSerde) + .build(); + } + + @Test + public void shouldOnlyIterateOpenSegments() { + long currentTime = 0; + windowStore.put(1, "one", currentTime); + + currentTime = currentTime + SEGMENT_INTERVAL; + windowStore.put(1, "two", currentTime); + currentTime = currentTime + SEGMENT_INTERVAL; + + windowStore.put(1, "three", currentTime); + + try (final WindowStoreIterator iterator = windowStore.fetch(1, 0L, currentTime)) { + + // roll to the next segment that will close the first + currentTime = currentTime + SEGMENT_INTERVAL; + windowStore.put(1, "four", currentTime); + + // should only have 2 values as the first segment is no longer open + assertEquals(new KeyValue<>(SEGMENT_INTERVAL, "two"), iterator.next()); + assertEquals(new KeyValue<>(2 * SEGMENT_INTERVAL, "three"), iterator.next()); + assertFalse(iterator.hasNext()); + } + } + + @Test + public void testRolling() { + + // to validate segments + final long startTime = SEGMENT_INTERVAL * 2; + final long increment = SEGMENT_INTERVAL / 2; + windowStore.put(0, "zero", startTime); + assertEquals(Utils.mkSet(segments.segmentName(2)), segmentDirs(baseDir)); + + windowStore.put(1, "one", startTime + increment); + assertEquals(Utils.mkSet(segments.segmentName(2)), segmentDirs(baseDir)); + + windowStore.put(2, "two", startTime + increment * 2); + assertEquals( + Utils.mkSet( + segments.segmentName(2), + segments.segmentName(3) + ), + segmentDirs(baseDir) + ); + + windowStore.put(4, "four", startTime + increment * 4); + assertEquals( + Utils.mkSet( + segments.segmentName(2), + segments.segmentName(3), + segments.segmentName(4) + ), + segmentDirs(baseDir) + ); + + windowStore.put(5, "five", startTime + increment * 5); + assertEquals( + Utils.mkSet( + segments.segmentName(2), + 
segments.segmentName(3), + segments.segmentName(4) + ), + segmentDirs(baseDir) + ); + + assertEquals( + new HashSet<>(Collections.singletonList("zero")), + valuesToSet(windowStore.fetch( + 0, + ofEpochMilli(startTime - WINDOW_SIZE), + ofEpochMilli(startTime + WINDOW_SIZE)))); + assertEquals( + new HashSet<>(Collections.singletonList("one")), + valuesToSet(windowStore.fetch( + 1, + ofEpochMilli(startTime + increment - WINDOW_SIZE), + ofEpochMilli(startTime + increment + WINDOW_SIZE)))); + assertEquals( + new HashSet<>(Collections.singletonList("two")), + valuesToSet(windowStore.fetch( + 2, + ofEpochMilli(startTime + increment * 2 - WINDOW_SIZE), + ofEpochMilli(startTime + increment * 2 + WINDOW_SIZE)))); + assertEquals( + new HashSet<>(Collections.emptyList()), + valuesToSet(windowStore.fetch( + 3, + ofEpochMilli(startTime + increment * 3 - WINDOW_SIZE), + ofEpochMilli(startTime + increment * 3 + WINDOW_SIZE)))); + assertEquals( + new HashSet<>(Collections.singletonList("four")), + valuesToSet(windowStore.fetch( + 4, + ofEpochMilli(startTime + increment * 4 - WINDOW_SIZE), + ofEpochMilli(startTime + increment * 4 + WINDOW_SIZE)))); + assertEquals( + new HashSet<>(Collections.singletonList("five")), + valuesToSet(windowStore.fetch( + 5, + ofEpochMilli(startTime + increment * 5 - WINDOW_SIZE), + ofEpochMilli(startTime + increment * 5 + WINDOW_SIZE)))); + + windowStore.put(6, "six", startTime + increment * 6); + assertEquals( + Utils.mkSet( + segments.segmentName(3), + segments.segmentName(4), + segments.segmentName(5) + ), + segmentDirs(baseDir) + ); + + assertEquals( + new HashSet<>(Collections.emptyList()), + valuesToSet(windowStore.fetch( + 0, + ofEpochMilli(startTime - WINDOW_SIZE), + ofEpochMilli(startTime + WINDOW_SIZE)))); + assertEquals( + new HashSet<>(Collections.emptyList()), + valuesToSet(windowStore.fetch( + 1, + ofEpochMilli(startTime + increment - WINDOW_SIZE), + ofEpochMilli(startTime + increment + WINDOW_SIZE)))); + assertEquals( + new HashSet<>(Collections.singletonList("two")), + valuesToSet(windowStore.fetch( + 2, + ofEpochMilli(startTime + increment * 2 - WINDOW_SIZE), + ofEpochMilli(startTime + increment * 2 + WINDOW_SIZE)))); + assertEquals( + new HashSet<>(Collections.emptyList()), + valuesToSet(windowStore.fetch( + 3, + ofEpochMilli(startTime + increment * 3 - WINDOW_SIZE), + ofEpochMilli(startTime + increment * 3 + WINDOW_SIZE)))); + assertEquals( + new HashSet<>(Collections.singletonList("four")), + valuesToSet(windowStore.fetch( + 4, + ofEpochMilli(startTime + increment * 4 - WINDOW_SIZE), + ofEpochMilli(startTime + increment * 4 + WINDOW_SIZE)))); + assertEquals( + new HashSet<>(Collections.singletonList("five")), + valuesToSet(windowStore.fetch( + 5, + ofEpochMilli(startTime + increment * 5 - WINDOW_SIZE), + ofEpochMilli(startTime + increment * 5 + WINDOW_SIZE)))); + assertEquals( + new HashSet<>(Collections.singletonList("six")), + valuesToSet(windowStore.fetch( + 6, + ofEpochMilli(startTime + increment * 6 - WINDOW_SIZE), + ofEpochMilli(startTime + increment * 6 + WINDOW_SIZE)))); + + windowStore.put(7, "seven", startTime + increment * 7); + assertEquals( + Utils.mkSet( + segments.segmentName(3), + segments.segmentName(4), + segments.segmentName(5) + ), + segmentDirs(baseDir) + ); + + assertEquals( + new HashSet<>(Collections.emptyList()), + valuesToSet(windowStore.fetch( + 0, + ofEpochMilli(startTime - WINDOW_SIZE), + ofEpochMilli(startTime + WINDOW_SIZE)))); + assertEquals( + new HashSet<>(Collections.emptyList()), + valuesToSet(windowStore.fetch( + 1, + 
ofEpochMilli(startTime + increment - WINDOW_SIZE), + ofEpochMilli(startTime + increment + WINDOW_SIZE)))); + assertEquals( + new HashSet<>(Collections.singletonList("two")), + valuesToSet(windowStore.fetch( + 2, + ofEpochMilli(startTime + increment * 2 - WINDOW_SIZE), + ofEpochMilli(startTime + increment * 2 + WINDOW_SIZE)))); + assertEquals( + new HashSet<>(Collections.emptyList()), + valuesToSet(windowStore.fetch( + 3, + ofEpochMilli(startTime + increment * 3 - WINDOW_SIZE), + ofEpochMilli(startTime + increment * 3 + WINDOW_SIZE)))); + assertEquals( + new HashSet<>(Collections.singletonList("four")), + valuesToSet(windowStore.fetch( + 4, + ofEpochMilli(startTime + increment * 4 - WINDOW_SIZE), + ofEpochMilli(startTime + increment * 4 + WINDOW_SIZE)))); + assertEquals( + new HashSet<>(Collections.singletonList("five")), + valuesToSet(windowStore.fetch( + 5, + ofEpochMilli(startTime + increment * 5 - WINDOW_SIZE), + ofEpochMilli(startTime + increment * 5 + WINDOW_SIZE)))); + assertEquals( + new HashSet<>(Collections.singletonList("six")), + valuesToSet(windowStore.fetch( + 6, + ofEpochMilli(startTime + increment * 6 - WINDOW_SIZE), + ofEpochMilli(startTime + increment * 6 + WINDOW_SIZE)))); + assertEquals( + new HashSet<>(Collections.singletonList("seven")), + valuesToSet(windowStore.fetch( + 7, + ofEpochMilli(startTime + increment * 7 - WINDOW_SIZE), + ofEpochMilli(startTime + increment * 7 + WINDOW_SIZE)))); + + windowStore.put(8, "eight", startTime + increment * 8); + assertEquals( + Utils.mkSet( + segments.segmentName(4), + segments.segmentName(5), + segments.segmentName(6) + ), + segmentDirs(baseDir) + ); + + assertEquals( + new HashSet<>(Collections.emptyList()), + valuesToSet(windowStore.fetch( + 0, + ofEpochMilli(startTime - WINDOW_SIZE), + ofEpochMilli(startTime + WINDOW_SIZE)))); + assertEquals( + new HashSet<>(Collections.emptyList()), + valuesToSet(windowStore.fetch( + 1, + ofEpochMilli(startTime + increment - WINDOW_SIZE), + ofEpochMilli(startTime + increment + WINDOW_SIZE)))); + assertEquals( + new HashSet<>(Collections.emptyList()), + valuesToSet(windowStore.fetch( + 2, + ofEpochMilli(startTime + increment * 2 - WINDOW_SIZE), + ofEpochMilli(startTime + increment * 2 + WINDOW_SIZE)))); + assertEquals( + new HashSet<>(Collections.emptyList()), + valuesToSet(windowStore.fetch( + 3, + ofEpochMilli(startTime + increment * 3 - WINDOW_SIZE), + ofEpochMilli(startTime + increment * 3 + WINDOW_SIZE)))); + assertEquals( + new HashSet<>(Collections.singletonList("four")), + valuesToSet(windowStore.fetch( + 4, + ofEpochMilli(startTime + increment * 4 - WINDOW_SIZE), + ofEpochMilli(startTime + increment * 4 + WINDOW_SIZE)))); + assertEquals( + new HashSet<>(Collections.singletonList("five")), + valuesToSet(windowStore.fetch( + 5, + ofEpochMilli(startTime + increment * 5 - WINDOW_SIZE), + ofEpochMilli(startTime + increment * 5 + WINDOW_SIZE)))); + assertEquals( + new HashSet<>(Collections.singletonList("six")), + valuesToSet(windowStore.fetch( + 6, + ofEpochMilli(startTime + increment * 6 - WINDOW_SIZE), + ofEpochMilli(startTime + increment * 6 + WINDOW_SIZE)))); + assertEquals( + new HashSet<>(Collections.singletonList("seven")), + valuesToSet(windowStore.fetch( + 7, + ofEpochMilli(startTime + increment * 7 - WINDOW_SIZE), + ofEpochMilli(startTime + increment * 7 + WINDOW_SIZE)))); + assertEquals( + new HashSet<>(Collections.singletonList("eight")), + valuesToSet(windowStore.fetch( + 8, + ofEpochMilli(startTime + increment * 8 - WINDOW_SIZE), + ofEpochMilli(startTime + increment * 8 + 
WINDOW_SIZE)))); + + // check segment directories + windowStore.flush(); + assertEquals( + Utils.mkSet( + segments.segmentName(4), + segments.segmentName(5), + segments.segmentName(6) + ), + segmentDirs(baseDir) + ); + } + + @Test + public void testSegmentMaintenance() { + windowStore.close(); + windowStore = buildWindowStore(RETENTION_PERIOD, WINDOW_SIZE, true, Serdes.Integer(), + Serdes.String()); + windowStore.init((StateStoreContext) context, windowStore); + + context.setTime(0L); + windowStore.put(0, "v", 0); + assertEquals( + Utils.mkSet(segments.segmentName(0L)), + segmentDirs(baseDir) + ); + + windowStore.put(0, "v", SEGMENT_INTERVAL - 1); + windowStore.put(0, "v", SEGMENT_INTERVAL - 1); + assertEquals( + Utils.mkSet(segments.segmentName(0L)), + segmentDirs(baseDir) + ); + + windowStore.put(0, "v", SEGMENT_INTERVAL); + assertEquals( + Utils.mkSet(segments.segmentName(0L), segments.segmentName(1L)), + segmentDirs(baseDir) + ); + + WindowStoreIterator iter; + int fetchedCount; + + iter = windowStore.fetch(0, ofEpochMilli(0L), ofEpochMilli(SEGMENT_INTERVAL * 4)); + fetchedCount = 0; + while (iter.hasNext()) { + iter.next(); + fetchedCount++; + } + assertEquals(4, fetchedCount); + + assertEquals( + Utils.mkSet(segments.segmentName(0L), segments.segmentName(1L)), + segmentDirs(baseDir) + ); + + windowStore.put(0, "v", SEGMENT_INTERVAL * 3); + + iter = windowStore.fetch(0, ofEpochMilli(0L), ofEpochMilli(SEGMENT_INTERVAL * 4)); + fetchedCount = 0; + while (iter.hasNext()) { + iter.next(); + fetchedCount++; + } + assertEquals(2, fetchedCount); + + assertEquals( + Utils.mkSet(segments.segmentName(1L), segments.segmentName(3L)), + segmentDirs(baseDir) + ); + + windowStore.put(0, "v", SEGMENT_INTERVAL * 5); + + iter = windowStore.fetch(0, ofEpochMilli(SEGMENT_INTERVAL * 4), ofEpochMilli(SEGMENT_INTERVAL * 10)); + fetchedCount = 0; + while (iter.hasNext()) { + iter.next(); + fetchedCount++; + } + assertEquals(1, fetchedCount); + + assertEquals( + Utils.mkSet(segments.segmentName(3L), segments.segmentName(5L)), + segmentDirs(baseDir) + ); + + } + + @SuppressWarnings("ResultOfMethodCallIgnored") + @Test + public void testInitialLoading() { + final File storeDir = new File(baseDir, STORE_NAME); + + new File(storeDir, segments.segmentName(0L)).mkdir(); + new File(storeDir, segments.segmentName(1L)).mkdir(); + new File(storeDir, segments.segmentName(2L)).mkdir(); + new File(storeDir, segments.segmentName(3L)).mkdir(); + new File(storeDir, segments.segmentName(4L)).mkdir(); + new File(storeDir, segments.segmentName(5L)).mkdir(); + new File(storeDir, segments.segmentName(6L)).mkdir(); + windowStore.close(); + + windowStore = buildWindowStore(RETENTION_PERIOD, WINDOW_SIZE, false, Serdes.Integer(), Serdes.String()); + windowStore.init((StateStoreContext) context, windowStore); + + // put something in the store to advance its stream time and expire the old segments + windowStore.put(1, "v", 6L * SEGMENT_INTERVAL); + + final List expected = asList( + segments.segmentName(4L), + segments.segmentName(5L), + segments.segmentName(6L)); + expected.sort(String::compareTo); + + final List actual = Utils.toList(segmentDirs(baseDir).iterator()); + actual.sort(String::compareTo); + + assertEquals(expected, actual); + + try (final WindowStoreIterator iter = windowStore.fetch(0, ofEpochMilli(0L), ofEpochMilli(1000000L))) { + while (iter.hasNext()) { + iter.next(); + } + } + + assertEquals( + Utils.mkSet( + segments.segmentName(4L), + segments.segmentName(5L), + segments.segmentName(6L)), + segmentDirs(baseDir) + ); 
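+        // Why only segments 4, 5 and 6 survive: segment directories appear to be keyed by
+        // segmentId = recordTimestamp / SEGMENT_INTERVAL, so the put above (at 6 * SEGMENT_INTERVAL) lands in
+        // segment 6. Once the store's observed stream time reaches that segment, everything older than the
+        // retention period (the three most recent segments, as the assertions above show) is dropped,
+        // including the stale directories 0 through 3 that were created by hand before the store was reopened.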
+ } + + @SuppressWarnings("unchecked") + @Test + public void testRestore() throws Exception { + final long startTime = SEGMENT_INTERVAL * 2; + final long increment = SEGMENT_INTERVAL / 2; + + windowStore.put(0, "zero", startTime); + windowStore.put(1, "one", startTime + increment); + windowStore.put(2, "two", startTime + increment * 2); + windowStore.put(3, "three", startTime + increment * 3); + windowStore.put(4, "four", startTime + increment * 4); + windowStore.put(5, "five", startTime + increment * 5); + windowStore.put(6, "six", startTime + increment * 6); + windowStore.put(7, "seven", startTime + increment * 7); + windowStore.put(8, "eight", startTime + increment * 8); + windowStore.flush(); + + windowStore.close(); + + // remove local store image + Utils.delete(baseDir); + + windowStore = buildWindowStore(RETENTION_PERIOD, + WINDOW_SIZE, + false, + Serdes.Integer(), + Serdes.String()); + windowStore.init((StateStoreContext) context, windowStore); + + assertEquals( + new HashSet<>(Collections.emptyList()), + valuesToSet(windowStore.fetch( + 0, + ofEpochMilli(startTime - WINDOW_SIZE), + ofEpochMilli(startTime + WINDOW_SIZE)))); + assertEquals( + new HashSet<>(Collections.emptyList()), + valuesToSet(windowStore.fetch( + 1, + ofEpochMilli(startTime + increment - WINDOW_SIZE), + ofEpochMilli(startTime + increment + WINDOW_SIZE)))); + assertEquals( + new HashSet<>(Collections.emptyList()), + valuesToSet(windowStore.fetch( + 2, + ofEpochMilli(startTime + increment * 2 - WINDOW_SIZE), + ofEpochMilli(startTime + increment * 2 + WINDOW_SIZE)))); + assertEquals( + new HashSet<>(Collections.emptyList()), + valuesToSet(windowStore.fetch( + 3, + ofEpochMilli(startTime + increment * 3 - WINDOW_SIZE), + ofEpochMilli(startTime + increment * 3 + WINDOW_SIZE)))); + assertEquals( + new HashSet<>(Collections.emptyList()), + valuesToSet(windowStore.fetch( + 4, + ofEpochMilli(startTime + increment * 4 - WINDOW_SIZE), + ofEpochMilli(startTime + increment * 4 + WINDOW_SIZE)))); + assertEquals( + new HashSet<>(Collections.emptyList()), + valuesToSet(windowStore.fetch( + 5, + ofEpochMilli(startTime + increment * 5 - WINDOW_SIZE), + ofEpochMilli(startTime + increment * 5 + WINDOW_SIZE)))); + assertEquals( + new HashSet<>(Collections.emptyList()), + valuesToSet(windowStore.fetch( + 6, + ofEpochMilli(startTime + increment * 6 - WINDOW_SIZE), + ofEpochMilli(startTime + increment * 6 + WINDOW_SIZE)))); + assertEquals( + new HashSet<>(Collections.emptyList()), + valuesToSet(windowStore.fetch( + 7, + ofEpochMilli(startTime + increment * 7 - WINDOW_SIZE), + ofEpochMilli(startTime + increment * 7 + WINDOW_SIZE)))); + assertEquals( + new HashSet<>(Collections.emptyList()), + valuesToSet(windowStore.fetch( + 8, + ofEpochMilli(startTime + increment * 8 - WINDOW_SIZE), + ofEpochMilli(startTime + increment * 8 + WINDOW_SIZE)))); + + final List> changeLog = new ArrayList<>(); + for (final ProducerRecord record : recordCollector.collected()) { + changeLog.add(new KeyValue<>(((Bytes) record.key()).get(), (byte[]) record.value())); + } + + context.restore(STORE_NAME, changeLog); + + assertEquals( + new HashSet<>(Collections.emptyList()), + valuesToSet(windowStore.fetch( + 0, + ofEpochMilli(startTime - WINDOW_SIZE), + ofEpochMilli(startTime + WINDOW_SIZE)))); + assertEquals( + new HashSet<>(Collections.emptyList()), + valuesToSet(windowStore.fetch( + 1, + ofEpochMilli(startTime + increment - WINDOW_SIZE), + ofEpochMilli(startTime + increment + WINDOW_SIZE)))); + assertEquals( + new HashSet<>(Collections.emptyList()), + 
valuesToSet(windowStore.fetch( + 2, + ofEpochMilli(startTime + increment * 2 - WINDOW_SIZE), + ofEpochMilli(startTime + increment * 2 + WINDOW_SIZE)))); + assertEquals( + new HashSet<>(Collections.emptyList()), + valuesToSet(windowStore.fetch( + 3, + ofEpochMilli(startTime + increment * 3 - WINDOW_SIZE), + ofEpochMilli(startTime + increment * 3 + WINDOW_SIZE)))); + assertEquals( + new HashSet<>(Collections.singletonList("four")), + valuesToSet(windowStore.fetch( + 4, + ofEpochMilli(startTime + increment * 4 - WINDOW_SIZE), + ofEpochMilli(startTime + increment * 4 + WINDOW_SIZE)))); + assertEquals( + new HashSet<>(Collections.singletonList("five")), + valuesToSet(windowStore.fetch( + 5, + ofEpochMilli(startTime + increment * 5 - WINDOW_SIZE), + ofEpochMilli(startTime + increment * 5 + WINDOW_SIZE)))); + assertEquals( + new HashSet<>(Collections.singletonList("six")), + valuesToSet(windowStore.fetch( + 6, + ofEpochMilli(startTime + increment * 6 - WINDOW_SIZE), + ofEpochMilli(startTime + increment * 6 + WINDOW_SIZE)))); + assertEquals( + new HashSet<>(Collections.singletonList("seven")), + valuesToSet(windowStore.fetch( + 7, + ofEpochMilli(startTime + increment * 7 - WINDOW_SIZE), + ofEpochMilli(startTime + increment * 7 + WINDOW_SIZE)))); + assertEquals( + new HashSet<>(Collections.singletonList("eight")), + valuesToSet(windowStore.fetch( + 8, + ofEpochMilli(startTime + increment * 8 - WINDOW_SIZE), + ofEpochMilli(startTime + increment * 8 + WINDOW_SIZE)))); + + // check segment directories + windowStore.flush(); + assertEquals( + Utils.mkSet( + segments.segmentName(4L), + segments.segmentName(5L), + segments.segmentName(6L)), + segmentDirs(baseDir) + ); + } + + private Set segmentDirs(final File baseDir) { + final File windowDir = new File(baseDir, windowStore.name()); + + return new HashSet<>(asList(requireNonNull(windowDir.list()))); + } + +} diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/SegmentIteratorTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/SegmentIteratorTest.java new file mode 100644 index 0000000..31ce3c6 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/SegmentIteratorTest.java @@ -0,0 +1,368 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.common.metrics.Metrics; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.processor.StateStoreContext; +import org.apache.kafka.streams.processor.internals.MockStreamsMetrics; +import org.apache.kafka.streams.state.internals.metrics.RocksDBMetricsRecorder; +import org.apache.kafka.test.InternalMockProcessorContext; +import org.apache.kafka.test.MockRecordCollector; +import org.apache.kafka.test.TestUtils; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import java.util.Arrays; +import java.util.Collections; +import java.util.Iterator; +import java.util.NoSuchElementException; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; + +public class SegmentIteratorTest { + + private final RocksDBMetricsRecorder rocksDBMetricsRecorder = + new RocksDBMetricsRecorder("metrics-scope", "store-name"); + private final KeyValueSegment segmentOne = + new KeyValueSegment("one", "one", 0, rocksDBMetricsRecorder); + private final KeyValueSegment segmentTwo = + new KeyValueSegment("two", "window", 1, rocksDBMetricsRecorder); + private final HasNextCondition hasNextCondition = Iterator::hasNext; + + private SegmentIterator iterator = null; + + @SuppressWarnings("rawtypes") + @Before + public void before() { + final InternalMockProcessorContext context = new InternalMockProcessorContext<>( + TestUtils.tempDirectory(), + Serdes.String(), + Serdes.String(), + new MockRecordCollector(), + new ThreadCache( + new LogContext("testCache "), + 0, + new MockStreamsMetrics(new Metrics()))); + segmentOne.init((StateStoreContext) context, segmentOne); + segmentTwo.init((StateStoreContext) context, segmentTwo); + segmentOne.put(Bytes.wrap("a".getBytes()), "1".getBytes()); + segmentOne.put(Bytes.wrap("b".getBytes()), "2".getBytes()); + segmentTwo.put(Bytes.wrap("c".getBytes()), "3".getBytes()); + segmentTwo.put(Bytes.wrap("d".getBytes()), "4".getBytes()); + } + + @After + public void closeSegments() { + if (iterator != null) { + iterator.close(); + iterator = null; + } + segmentOne.close(); + segmentTwo.close(); + } + + @Test + public void shouldIterateOverAllSegments() { + iterator = new SegmentIterator<>( + Arrays.asList(segmentOne, segmentTwo).iterator(), + hasNextCondition, + Bytes.wrap("a".getBytes()), + Bytes.wrap("z".getBytes()), + true); + + assertTrue(iterator.hasNext()); + assertEquals("a", new String(iterator.peekNextKey().get())); + assertEquals(KeyValue.pair("a", "1"), toStringKeyValue(iterator.next())); + + assertTrue(iterator.hasNext()); + assertEquals("b", new String(iterator.peekNextKey().get())); + assertEquals(KeyValue.pair("b", "2"), toStringKeyValue(iterator.next())); + + assertTrue(iterator.hasNext()); + assertEquals("c", new String(iterator.peekNextKey().get())); + assertEquals(KeyValue.pair("c", "3"), toStringKeyValue(iterator.next())); + + assertTrue(iterator.hasNext()); + assertEquals("d", new String(iterator.peekNextKey().get())); + assertEquals(KeyValue.pair("d", "4"), toStringKeyValue(iterator.next())); + + assertFalse(iterator.hasNext()); + } + + @Test + public void shouldIterateOverAllSegmentsWhenNullKeyFromKeyTo() { + iterator = new SegmentIterator<>( + Arrays.asList(segmentOne, 
segmentTwo).iterator(), + hasNextCondition, + null, + null, + true); + + assertTrue(iterator.hasNext()); + assertEquals("a", new String(iterator.peekNextKey().get())); + assertEquals(KeyValue.pair("a", "1"), toStringKeyValue(iterator.next())); + + assertTrue(iterator.hasNext()); + assertEquals("b", new String(iterator.peekNextKey().get())); + assertEquals(KeyValue.pair("b", "2"), toStringKeyValue(iterator.next())); + + assertTrue(iterator.hasNext()); + assertEquals("c", new String(iterator.peekNextKey().get())); + assertEquals(KeyValue.pair("c", "3"), toStringKeyValue(iterator.next())); + + assertTrue(iterator.hasNext()); + assertEquals("d", new String(iterator.peekNextKey().get())); + assertEquals(KeyValue.pair("d", "4"), toStringKeyValue(iterator.next())); + + assertFalse(iterator.hasNext()); + } + + @Test + public void shouldIterateBackwardOverAllSegments() { + iterator = new SegmentIterator<>( + Arrays.asList(segmentTwo, segmentOne).iterator(), //store should pass the segments in the right order + hasNextCondition, + Bytes.wrap("a".getBytes()), + Bytes.wrap("z".getBytes()), + false); + + assertTrue(iterator.hasNext()); + assertEquals("d", new String(iterator.peekNextKey().get())); + assertEquals(KeyValue.pair("d", "4"), toStringKeyValue(iterator.next())); + + assertTrue(iterator.hasNext()); + assertEquals("c", new String(iterator.peekNextKey().get())); + assertEquals(KeyValue.pair("c", "3"), toStringKeyValue(iterator.next())); + + assertTrue(iterator.hasNext()); + assertEquals("b", new String(iterator.peekNextKey().get())); + assertEquals(KeyValue.pair("b", "2"), toStringKeyValue(iterator.next())); + + assertTrue(iterator.hasNext()); + assertEquals("a", new String(iterator.peekNextKey().get())); + assertEquals(KeyValue.pair("a", "1"), toStringKeyValue(iterator.next())); + + assertFalse(iterator.hasNext()); + } + + @Test + public void shouldIterateBackwardOverAllSegmentsWhenNullKeyFromKeyTo() { + iterator = new SegmentIterator<>( + Arrays.asList(segmentTwo, segmentOne).iterator(), //store should pass the segments in the right order + hasNextCondition, + null, + null, + false); + + assertTrue(iterator.hasNext()); + assertEquals("d", new String(iterator.peekNextKey().get())); + assertEquals(KeyValue.pair("d", "4"), toStringKeyValue(iterator.next())); + + assertTrue(iterator.hasNext()); + assertEquals("c", new String(iterator.peekNextKey().get())); + assertEquals(KeyValue.pair("c", "3"), toStringKeyValue(iterator.next())); + + assertTrue(iterator.hasNext()); + assertEquals("b", new String(iterator.peekNextKey().get())); + assertEquals(KeyValue.pair("b", "2"), toStringKeyValue(iterator.next())); + + assertTrue(iterator.hasNext()); + assertEquals("a", new String(iterator.peekNextKey().get())); + assertEquals(KeyValue.pair("a", "1"), toStringKeyValue(iterator.next())); + + assertFalse(iterator.hasNext()); + } + + @Test + public void shouldNotThrowExceptionOnHasNextWhenStoreClosed() { + iterator = new SegmentIterator<>( + Collections.singletonList(segmentOne).iterator(), + hasNextCondition, + Bytes.wrap("a".getBytes()), + Bytes.wrap("z".getBytes()), + true); + + iterator.currentIterator = segmentOne.all(); + segmentOne.close(); + assertFalse(iterator.hasNext()); + } + + @Test + public void shouldOnlyIterateOverSegmentsInBackwardRange() { + iterator = new SegmentIterator<>( + Arrays.asList(segmentOne, segmentTwo).iterator(), + hasNextCondition, + Bytes.wrap("a".getBytes()), + Bytes.wrap("b".getBytes()), + false); + + assertTrue(iterator.hasNext()); + assertEquals("b", new 
String(iterator.peekNextKey().get())); + assertEquals(KeyValue.pair("b", "2"), toStringKeyValue(iterator.next())); + + assertTrue(iterator.hasNext()); + assertEquals("a", new String(iterator.peekNextKey().get())); + assertEquals(KeyValue.pair("a", "1"), toStringKeyValue(iterator.next())); + + assertFalse(iterator.hasNext()); + } + + @Test + public void shouldOnlyIterateOverSegmentsInBackwardRangeWhenNullKeyFrom() { + iterator = new SegmentIterator<>( + Arrays.asList(segmentOne, segmentTwo).iterator(), + hasNextCondition, + null, + Bytes.wrap("b".getBytes()), + false); + + assertTrue(iterator.hasNext()); + assertEquals("b", new String(iterator.peekNextKey().get())); + assertEquals(KeyValue.pair("b", "2"), toStringKeyValue(iterator.next())); + + + assertTrue(iterator.hasNext()); + assertEquals("a", new String(iterator.peekNextKey().get())); + assertEquals(KeyValue.pair("a", "1"), toStringKeyValue(iterator.next())); + + assertFalse(iterator.hasNext()); + } + + @Test + public void shouldOnlyIterateOverSegmentsInBackwardRangeWhenNullKeyTo() { + iterator = new SegmentIterator<>( + Arrays.asList(segmentOne, segmentTwo).iterator(), + hasNextCondition, + Bytes.wrap("c".getBytes()), + null, + false); + + assertTrue(iterator.hasNext()); + assertEquals("d", new String(iterator.peekNextKey().get())); + assertEquals(KeyValue.pair("d", "4"), toStringKeyValue(iterator.next())); + + assertTrue(iterator.hasNext()); + assertEquals("c", new String(iterator.peekNextKey().get())); + assertEquals(KeyValue.pair("c", "3"), toStringKeyValue(iterator.next())); + + assertFalse(iterator.hasNext()); + } + + @Test + public void shouldOnlyIterateOverSegmentsInRange() { + iterator = new SegmentIterator<>( + Arrays.asList(segmentOne, segmentTwo).iterator(), + hasNextCondition, + Bytes.wrap("a".getBytes()), + Bytes.wrap("b".getBytes()), + true); + + assertTrue(iterator.hasNext()); + assertEquals("a", new String(iterator.peekNextKey().get())); + assertEquals(KeyValue.pair("a", "1"), toStringKeyValue(iterator.next())); + + assertTrue(iterator.hasNext()); + assertEquals("b", new String(iterator.peekNextKey().get())); + assertEquals(KeyValue.pair("b", "2"), toStringKeyValue(iterator.next())); + + assertFalse(iterator.hasNext()); + } + + @Test + public void shouldOnlyIterateOverSegmentsInRangeWhenNullKeyFrom() { + iterator = new SegmentIterator<>( + Arrays.asList(segmentOne, segmentTwo).iterator(), + hasNextCondition, + null, + Bytes.wrap("c".getBytes()), + true); + + assertTrue(iterator.hasNext()); + assertEquals("a", new String(iterator.peekNextKey().get())); + assertEquals(KeyValue.pair("a", "1"), toStringKeyValue(iterator.next())); + + assertTrue(iterator.hasNext()); + assertEquals("b", new String(iterator.peekNextKey().get())); + assertEquals(KeyValue.pair("b", "2"), toStringKeyValue(iterator.next())); + + assertTrue(iterator.hasNext()); + assertEquals("c", new String(iterator.peekNextKey().get())); + assertEquals(KeyValue.pair("c", "3"), toStringKeyValue(iterator.next())); + + assertFalse(iterator.hasNext()); + } + + @Test + public void shouldOnlyIterateOverSegmentsInRangeWhenNullKeyTo() { + iterator = new SegmentIterator<>( + Arrays.asList(segmentOne, segmentTwo).iterator(), + hasNextCondition, + Bytes.wrap("b".getBytes()), + null, + true); + + assertTrue(iterator.hasNext()); + assertEquals("b", new String(iterator.peekNextKey().get())); + assertEquals(KeyValue.pair("b", "2"), toStringKeyValue(iterator.next())); + + assertTrue(iterator.hasNext()); + assertEquals("c", new String(iterator.peekNextKey().get())); + 
assertEquals(KeyValue.pair("c", "3"), toStringKeyValue(iterator.next())); + + assertTrue(iterator.hasNext()); + assertEquals("d", new String(iterator.peekNextKey().get())); + assertEquals(KeyValue.pair("d", "4"), toStringKeyValue(iterator.next())); + + assertFalse(iterator.hasNext()); + } + + @Test + public void shouldThrowNoSuchElementOnPeekNextKeyIfNoNext() { + iterator = new SegmentIterator<>( + Arrays.asList(segmentOne, segmentTwo).iterator(), + hasNextCondition, + Bytes.wrap("f".getBytes()), + Bytes.wrap("h".getBytes()), + true); + + assertThrows(NoSuchElementException.class, () -> iterator.peekNextKey()); + } + + @Test + public void shouldThrowNoSuchElementOnNextIfNoNext() { + iterator = new SegmentIterator<>( + Arrays.asList(segmentOne, segmentTwo).iterator(), + hasNextCondition, + Bytes.wrap("f".getBytes()), + Bytes.wrap("h".getBytes()), + true); + + assertThrows(NoSuchElementException.class, () -> iterator.next()); + } + + private KeyValue toStringKeyValue(final KeyValue binaryKv) { + return KeyValue.pair(new String(binaryKv.key.get()), new String(binaryKv.value)); + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/SegmentedCacheFunctionTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/SegmentedCacheFunctionTest.java new file mode 100644 index 0000000..1f6a747 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/SegmentedCacheFunctionTest.java @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.common.utils.Bytes; + +import org.junit.Test; + +import java.nio.ByteBuffer; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.core.IsEqual.equalTo; + +// TODO: this test coverage does not consider session serde yet +public class SegmentedCacheFunctionTest { + + private static final int SEGMENT_INTERVAL = 17; + private static final int TIMESTAMP = 736213517; + + private static final Bytes THE_KEY = WindowKeySchema.toStoreKeyBinary(new byte[]{0xA, 0xB, 0xC}, TIMESTAMP, 42); + private final static Bytes THE_CACHE_KEY = Bytes.wrap( + ByteBuffer.allocate(8 + THE_KEY.get().length) + .putLong(TIMESTAMP / SEGMENT_INTERVAL) + .put(THE_KEY.get()).array() + ); + + private final SegmentedCacheFunction cacheFunction = new SegmentedCacheFunction(new WindowKeySchema(), SEGMENT_INTERVAL); + + @Test + public void key() { + assertThat( + cacheFunction.key(THE_CACHE_KEY), + equalTo(THE_KEY) + ); + } + + @Test + public void cacheKey() { + final long segmentId = TIMESTAMP / SEGMENT_INTERVAL; + + final Bytes actualCacheKey = cacheFunction.cacheKey(THE_KEY); + final ByteBuffer buffer = ByteBuffer.wrap(actualCacheKey.get()); + + assertThat(buffer.getLong(), equalTo(segmentId)); + + final byte[] actualKey = new byte[buffer.remaining()]; + buffer.get(actualKey); + assertThat(Bytes.wrap(actualKey), equalTo(THE_KEY)); + } + + @Test + public void testRoundTripping() { + assertThat( + cacheFunction.key(cacheFunction.cacheKey(THE_KEY)), + equalTo(THE_KEY) + ); + + assertThat( + cacheFunction.cacheKey(cacheFunction.key(THE_CACHE_KEY)), + equalTo(THE_CACHE_KEY) + ); + } + + @Test + public void compareSegmentedKeys() { + assertThat( + "same key in same segment should be ranked the same", + cacheFunction.compareSegmentedKeys( + cacheFunction.cacheKey(THE_KEY), + THE_KEY + ) == 0 + ); + + final Bytes sameKeyInPriorSegment = WindowKeySchema.toStoreKeyBinary(new byte[]{0xA, 0xB, 0xC}, 1234, 42); + + assertThat( + "same keys in different segments should be ordered according to segment", + cacheFunction.compareSegmentedKeys( + cacheFunction.cacheKey(sameKeyInPriorSegment), + THE_KEY + ) < 0 + ); + + assertThat( + "same keys in different segments should be ordered according to segment", + cacheFunction.compareSegmentedKeys( + cacheFunction.cacheKey(THE_KEY), + sameKeyInPriorSegment + ) > 0 + ); + + final Bytes lowerKeyInSameSegment = WindowKeySchema.toStoreKeyBinary(new byte[]{0xA, 0xB, 0xB}, TIMESTAMP - 1, 0); + + assertThat( + "different keys in same segments should be ordered according to key", + cacheFunction.compareSegmentedKeys( + cacheFunction.cacheKey(THE_KEY), + lowerKeyInSameSegment + ) > 0 + ); + + assertThat( + "different keys in same segments should be ordered according to key", + cacheFunction.compareSegmentedKeys( + cacheFunction.cacheKey(lowerKeyInSameSegment), + THE_KEY + ) < 0 + ); + } + +} diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/SerdeThatDoesntHandleNull.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/SerdeThatDoesntHandleNull.java new file mode 100644 index 0000000..03e0c3a --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/SerdeThatDoesntHandleNull.java @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.streams.state.internals;
+
+import org.apache.kafka.common.serialization.Deserializer;
+import org.apache.kafka.common.serialization.Serde;
+import org.apache.kafka.common.serialization.Serializer;
+import org.apache.kafka.common.serialization.StringDeserializer;
+import org.apache.kafka.common.serialization.StringSerializer;
+
+class SerdeThatDoesntHandleNull implements Serde<String> {
+    @Override
+    public Serializer<String> serializer() {
+        return new StringSerializer();
+    }
+
+    @Override
+    public Deserializer<String> deserializer() {
+        return new StringDeserializer() {
+            @Override
+            public String deserialize(final String topic, final byte[] data) {
+                if (data == null) {
+                    throw new NullPointerException();
+                }
+                return super.deserialize(topic, data);
+            }
+        };
+    }
+}
diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/SessionKeySchemaTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/SessionKeySchemaTest.java
new file mode 100644
index 0000000..40b06c0
--- /dev/null
+++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/SessionKeySchemaTest.java
@@ -0,0 +1,251 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.kstream.Window; +import org.apache.kafka.streams.kstream.Windowed; +import org.apache.kafka.streams.kstream.WindowedSerdes; +import org.apache.kafka.streams.kstream.internals.SessionWindow; +import org.apache.kafka.test.KeyValueIteratorStub; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.core.IsEqual.equalTo; +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; + +public class SessionKeySchemaTest { + + private final String key = "key"; + private final String topic = "topic"; + private final long startTime = 50L; + private final long endTime = 100L; + private final Serde serde = Serdes.String(); + + private final Window window = new SessionWindow(startTime, endTime); + private final Windowed windowedKey = new Windowed<>(key, window); + private final Serde> keySerde = new WindowedSerdes.SessionWindowedSerde<>(serde); + + private final SessionKeySchema sessionKeySchema = new SessionKeySchema(); + private DelegatingPeekingKeyValueIterator iterator; + + @After + public void after() { + if (iterator != null) { + iterator.close(); + } + } + + @Before + public void before() { + final List> keys = Arrays.asList(KeyValue.pair(SessionKeySchema.toBinary(new Windowed<>(Bytes.wrap(new byte[]{0, 0}), new SessionWindow(0, 0))), 1), + KeyValue.pair(SessionKeySchema.toBinary(new Windowed<>(Bytes.wrap(new byte[]{0}), new SessionWindow(0, 0))), 2), + KeyValue.pair(SessionKeySchema.toBinary(new Windowed<>(Bytes.wrap(new byte[]{0, 0, 0}), new SessionWindow(0, 0))), 3), + KeyValue.pair(SessionKeySchema.toBinary(new Windowed<>(Bytes.wrap(new byte[]{0}), new SessionWindow(10, 20))), 4), + KeyValue.pair(SessionKeySchema.toBinary(new Windowed<>(Bytes.wrap(new byte[]{0, 0}), new SessionWindow(10, 20))), 5), + KeyValue.pair(SessionKeySchema.toBinary(new Windowed<>(Bytes.wrap(new byte[]{0, 0, 0}), new SessionWindow(10, 20))), 6)); + iterator = new DelegatingPeekingKeyValueIterator<>("foo", new KeyValueIteratorStub<>(keys.iterator())); + } + + @Test + public void shouldFetchExactKeysSkippingLongerKeys() { + final Bytes key = Bytes.wrap(new byte[]{0}); + final List result = getValues(sessionKeySchema.hasNextCondition(key, key, 0, Long.MAX_VALUE)); + assertThat(result, equalTo(Arrays.asList(2, 4))); + } + + @Test + public void shouldFetchExactKeySkippingShorterKeys() { + final Bytes key = Bytes.wrap(new byte[]{0, 0}); + final HasNextCondition hasNextCondition = sessionKeySchema.hasNextCondition(key, key, 0, Long.MAX_VALUE); + final List results = getValues(hasNextCondition); + assertThat(results, equalTo(Arrays.asList(1, 5))); + } + + @Test + public void shouldFetchAllKeysUsingNullKeys() { + final HasNextCondition hasNextCondition = sessionKeySchema.hasNextCondition(null, null, 0, Long.MAX_VALUE); + final List results = getValues(hasNextCondition); + assertThat(results, equalTo(Arrays.asList(1, 2, 3, 4, 5, 6))); + } + + @Test + public void testUpperBoundWithLargeTimestamps() { + final Bytes upper = sessionKeySchema.upperRange(Bytes.wrap(new byte[]{0xA, 0xB, 
0xC}), Long.MAX_VALUE); + + assertThat( + "shorter key with max timestamp should be in range", + upper.compareTo(SessionKeySchema.toBinary( + new Windowed<>( + Bytes.wrap(new byte[]{0xA}), + new SessionWindow(Long.MAX_VALUE, Long.MAX_VALUE)) + )) >= 0 + ); + + assertThat( + "shorter key with max timestamp should be in range", + upper.compareTo(SessionKeySchema.toBinary( + new Windowed<>( + Bytes.wrap(new byte[]{0xA, 0xB}), + new SessionWindow(Long.MAX_VALUE, Long.MAX_VALUE)) + + )) >= 0 + ); + + assertThat(upper, equalTo(SessionKeySchema.toBinary( + new Windowed<>(Bytes.wrap(new byte[]{0xA}), new SessionWindow(Long.MAX_VALUE, Long.MAX_VALUE)))) + ); + } + + @Test + public void testUpperBoundWithKeyBytesLargerThanFirstTimestampByte() { + final Bytes upper = sessionKeySchema.upperRange(Bytes.wrap(new byte[]{0xA, (byte) 0x8F, (byte) 0x9F}), Long.MAX_VALUE); + + assertThat( + "shorter key with max timestamp should be in range", + upper.compareTo(SessionKeySchema.toBinary( + new Windowed<>( + Bytes.wrap(new byte[]{0xA, (byte) 0x8F}), + new SessionWindow(Long.MAX_VALUE, Long.MAX_VALUE)) + ) + ) >= 0 + ); + + assertThat(upper, equalTo(SessionKeySchema.toBinary( + new Windowed<>(Bytes.wrap(new byte[]{0xA, (byte) 0x8F, (byte) 0x9F}), new SessionWindow(Long.MAX_VALUE, Long.MAX_VALUE)))) + ); + } + + @Test + public void testUpperBoundWithZeroTimestamp() { + final Bytes upper = sessionKeySchema.upperRange(Bytes.wrap(new byte[]{0xA, 0xB, 0xC}), 0); + + assertThat(upper, equalTo(SessionKeySchema.toBinary( + new Windowed<>(Bytes.wrap(new byte[]{0xA}), new SessionWindow(0, Long.MAX_VALUE)))) + ); + } + + @Test + public void testLowerBoundWithZeroTimestamp() { + final Bytes lower = sessionKeySchema.lowerRange(Bytes.wrap(new byte[]{0xA, 0xB, 0xC}), 0); + assertThat(lower, equalTo(SessionKeySchema.toBinary(new Windowed<>(Bytes.wrap(new byte[]{0xA, 0xB, 0xC}), new SessionWindow(0, 0))))); + } + + @Test + public void testLowerBoundMatchesTrailingZeros() { + final Bytes lower = sessionKeySchema.lowerRange(Bytes.wrap(new byte[]{0xA, 0xB, 0xC}), Long.MAX_VALUE); + + assertThat( + "appending zeros to key should still be in range", + lower.compareTo(SessionKeySchema.toBinary( + new Windowed<>( + Bytes.wrap(new byte[]{0xA, 0xB, 0xC, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}), + new SessionWindow(Long.MAX_VALUE, Long.MAX_VALUE)) + )) < 0 + ); + + assertThat(lower, equalTo(SessionKeySchema.toBinary(new Windowed<>(Bytes.wrap(new byte[]{0xA, 0xB, 0xC}), new SessionWindow(0, 0))))); + } + + @Test + public void shouldSerializeDeserialize() { + final byte[] bytes = keySerde.serializer().serialize(topic, windowedKey); + final Windowed result = keySerde.deserializer().deserialize(topic, bytes); + assertEquals(windowedKey, result); + } + + @Test + public void shouldSerializeNullToNull() { + assertNull(keySerde.serializer().serialize(topic, null)); + } + + @Test + public void shouldDeSerializeEmtpyByteArrayToNull() { + assertNull(keySerde.deserializer().deserialize(topic, new byte[0])); + } + + @Test + public void shouldDeSerializeNullToNull() { + assertNull(keySerde.deserializer().deserialize(topic, null)); + } + + @Test + public void shouldConvertToBinaryAndBack() { + final byte[] serialized = SessionKeySchema.toBinary(windowedKey, serde.serializer(), "dummy"); + final Windowed result = SessionKeySchema.from(serialized, Serdes.String().deserializer(), "dummy"); + assertEquals(windowedKey, result); + } + + @Test + public void shouldExtractEndTimeFromBinary() { + final byte[] serialized = 
SessionKeySchema.toBinary(windowedKey, serde.serializer(), "dummy"); + assertEquals(endTime, SessionKeySchema.extractEndTimestamp(serialized)); + } + + @Test + public void shouldExtractStartTimeFromBinary() { + final byte[] serialized = SessionKeySchema.toBinary(windowedKey, serde.serializer(), "dummy"); + assertEquals(startTime, SessionKeySchema.extractStartTimestamp(serialized)); + } + + @Test + public void shouldExtractWindowFromBindary() { + final byte[] serialized = SessionKeySchema.toBinary(windowedKey, serde.serializer(), "dummy"); + assertEquals(window, SessionKeySchema.extractWindow(serialized)); + } + + @Test + public void shouldExtractKeyBytesFromBinary() { + final byte[] serialized = SessionKeySchema.toBinary(windowedKey, serde.serializer(), "dummy"); + assertArrayEquals(key.getBytes(), SessionKeySchema.extractKeyBytes(serialized)); + } + + @Test + public void shouldExtractKeyFromBinary() { + final byte[] serialized = SessionKeySchema.toBinary(windowedKey, serde.serializer(), "dummy"); + assertEquals(windowedKey, SessionKeySchema.from(serialized, serde.deserializer(), "dummy")); + } + + @Test + public void shouldExtractBytesKeyFromBinary() { + final Bytes bytesKey = Bytes.wrap(key.getBytes()); + final Windowed windowedBytesKey = new Windowed<>(bytesKey, window); + final Bytes serialized = SessionKeySchema.toBinary(windowedBytesKey); + assertEquals(windowedBytesKey, SessionKeySchema.from(serialized)); + } + + private List getValues(final HasNextCondition hasNextCondition) { + final List results = new ArrayList<>(); + while (hasNextCondition.hasNext(iterator)) { + results.add(iterator.next().value); + } + return results; + } + +} diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/SessionStoreBuilderTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/SessionStoreBuilderTest.java new file mode 100644 index 0000000..352b0cb --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/SessionStoreBuilderTest.java @@ -0,0 +1,166 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.streams.processor.StateStore; +import org.apache.kafka.streams.state.SessionBytesStoreSupplier; +import org.apache.kafka.streams.state.SessionStore; +import org.easymock.EasyMockRunner; +import org.easymock.Mock; +import org.easymock.MockType; +import org.hamcrest.CoreMatchers; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; + +import java.util.Collections; + +import static org.easymock.EasyMock.expect; +import static org.easymock.EasyMock.replay; +import static org.easymock.EasyMock.reset; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.core.IsInstanceOf.instanceOf; +import static org.junit.Assert.assertThrows; + +@RunWith(EasyMockRunner.class) +public class SessionStoreBuilderTest { + + @Mock(type = MockType.NICE) + private SessionBytesStoreSupplier supplier; + @Mock(type = MockType.NICE) + private SessionStore inner; + private SessionStoreBuilder builder; + + @Before + public void setUp() { + expect(supplier.get()).andReturn(inner); + expect(supplier.name()).andReturn("name"); + expect(supplier.metricsScope()).andReturn("metricScope"); + replay(supplier); + + builder = new SessionStoreBuilder<>( + supplier, + Serdes.String(), + Serdes.String(), + new MockTime()); + } + + @Test + public void shouldHaveMeteredStoreAsOuterStore() { + final SessionStore store = builder.build(); + assertThat(store, instanceOf(MeteredSessionStore.class)); + } + + @Test + public void shouldHaveChangeLoggingStoreByDefault() { + final SessionStore store = builder.build(); + final StateStore next = ((WrappedStateStore) store).wrapped(); + assertThat(next, instanceOf(ChangeLoggingSessionBytesStore.class)); + } + + @Test + public void shouldNotHaveChangeLoggingStoreWhenDisabled() { + final SessionStore store = builder.withLoggingDisabled().build(); + final StateStore next = ((WrappedStateStore) store).wrapped(); + assertThat(next, CoreMatchers.equalTo(inner)); + } + + @Test + public void shouldHaveCachingStoreWhenEnabled() { + final SessionStore store = builder.withCachingEnabled().build(); + final StateStore wrapped = ((WrappedStateStore) store).wrapped(); + assertThat(store, instanceOf(MeteredSessionStore.class)); + assertThat(wrapped, instanceOf(CachingSessionStore.class)); + } + + @Test + public void shouldHaveChangeLoggingStoreWhenLoggingEnabled() { + final SessionStore store = builder + .withLoggingEnabled(Collections.emptyMap()) + .build(); + final StateStore wrapped = ((WrappedStateStore) store).wrapped(); + assertThat(store, instanceOf(MeteredSessionStore.class)); + assertThat(wrapped, instanceOf(ChangeLoggingSessionBytesStore.class)); + assertThat(((WrappedStateStore) wrapped).wrapped(), CoreMatchers.equalTo(inner)); + } + + @Test + public void shouldHaveCachingAndChangeLoggingWhenBothEnabled() { + final SessionStore store = builder + .withLoggingEnabled(Collections.emptyMap()) + .withCachingEnabled() + .build(); + final WrappedStateStore caching = (WrappedStateStore) ((WrappedStateStore) store).wrapped(); + final WrappedStateStore changeLogging = (WrappedStateStore) caching.wrapped(); + assertThat(store, instanceOf(MeteredSessionStore.class)); + assertThat(caching, instanceOf(CachingSessionStore.class)); + assertThat(changeLogging, 
instanceOf(ChangeLoggingSessionBytesStore.class)); + assertThat(changeLogging.wrapped(), CoreMatchers.equalTo(inner)); + } + + @Test + public void shouldThrowNullPointerIfInnerIsNull() { + final Exception e = assertThrows(NullPointerException.class, () -> new SessionStoreBuilder<>(null, Serdes.String(), Serdes.String(), new MockTime())); + assertThat(e.getMessage(), equalTo("storeSupplier cannot be null")); + } + + @Test + public void shouldThrowNullPointerIfKeySerdeIsNull() { + final Exception e = assertThrows(NullPointerException.class, () -> new SessionStoreBuilder<>(supplier, null, Serdes.String(), new MockTime())); + assertThat(e.getMessage(), equalTo("name cannot be null")); + } + + @Test + public void shouldThrowNullPointerIfValueSerdeIsNull() { + final Exception e = assertThrows(NullPointerException.class, () -> new SessionStoreBuilder<>(supplier, Serdes.String(), null, new MockTime())); + assertThat(e.getMessage(), equalTo("name cannot be null")); + } + + @Test + public void shouldThrowNullPointerIfTimeIsNull() { + reset(supplier); + expect(supplier.name()).andReturn("name"); + replay(supplier); + final Exception e = assertThrows(NullPointerException.class, () -> new SessionStoreBuilder<>(supplier, Serdes.String(), Serdes.String(), null)); + assertThat(e.getMessage(), equalTo("time cannot be null")); + } + + @Test + public void shouldThrowNullPointerIfMetricsScopeIsNull() { + reset(supplier); + expect(supplier.get()).andReturn(new RocksDBSessionStore( + new RocksDBSegmentedBytesStore( + "name", + null, + 10L, + 5L, + new SessionKeySchema()) + )); + expect(supplier.name()).andReturn("name"); + replay(supplier); + + final Exception e = assertThrows(NullPointerException.class, + () -> new SessionStoreBuilder<>(supplier, Serdes.String(), Serdes.String(), new MockTime())); + assertThat(e.getMessage(), equalTo("storeSupplier's metricsScope can't be null")); + } + +} \ No newline at end of file diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/SessionStoreFetchTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/SessionStoreFetchTest.java new file mode 100644 index 0000000..fbb6e00 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/SessionStoreFetchTest.java @@ -0,0 +1,333 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.serialization.StringSerializer; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.TestInputTopic; +import org.apache.kafka.streams.Topology; +import org.apache.kafka.streams.TopologyTestDriver; +import org.apache.kafka.streams.kstream.Consumed; +import org.apache.kafka.streams.kstream.Grouped; +import org.apache.kafka.streams.kstream.KStream; +import org.apache.kafka.streams.kstream.Materialized; +import org.apache.kafka.streams.kstream.SessionWindows; +import org.apache.kafka.streams.kstream.Windowed; +import org.apache.kafka.streams.kstream.internals.SessionWindow; +import org.apache.kafka.streams.state.KeyValueIterator; +import org.apache.kafka.streams.state.SessionBytesStoreSupplier; +import org.apache.kafka.streams.state.SessionStore; +import org.apache.kafka.streams.state.Stores; +import org.apache.kafka.test.TestUtils; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TestName; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import java.time.Duration; +import java.util.Arrays; +import java.util.Collection; +import java.util.HashMap; +import java.util.Iterator; +import java.util.LinkedList; +import java.util.List; +import java.util.Properties; +import java.util.function.Predicate; +import java.util.function.Supplier; + +import static java.time.Duration.ofMillis; +import static org.apache.kafka.common.utils.Utils.mkEntry; +import static org.apache.kafka.common.utils.Utils.mkMap; +import static org.apache.kafka.common.utils.Utils.mkProperties; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.MatcherAssert.assertThat; + +@RunWith(Parameterized.class) +public class SessionStoreFetchTest { + private enum StoreType { InMemory, RocksDB }; + private static final String STORE_NAME = "store"; + private static final int DATA_SIZE = 5; + private static final long WINDOW_SIZE = 500L; + private static final long RETENTION_MS = 10000L; + + private StoreType storeType; + private boolean enableLogging; + private boolean enableCaching; + private boolean forward; + + private LinkedList, Long>> expectedRecords; + private LinkedList> records; + private Properties streamsConfig; + private String low; + private String high; + private String middle; + private String innerLow; + private String innerHigh; + private String innerLowBetween; + private String innerHighBetween; + + public SessionStoreFetchTest(final StoreType storeType, final boolean enableLogging, final boolean enableCaching, final boolean forward) { + this.storeType = storeType; + this.enableLogging = enableLogging; + this.enableCaching = enableCaching; + this.forward = forward; + + this.records = new LinkedList<>(); + this.expectedRecords = new LinkedList<>(); + final int m = DATA_SIZE / 2; + for (int i = 0; i < DATA_SIZE; i++) { + final String keyStr = i < m ? 
"a" : "b"; + final String key = "key-" + keyStr; + final String key2 = "key-" + keyStr + keyStr; + final String value = "val-" + i; + final KeyValue r = new KeyValue<>(key, value); + final KeyValue r2 = new KeyValue<>(key2, value); + records.add(r); + records.add(r2); + high = key; + if (low == null) { + low = key; + } + if (i == m) { + middle = key; + } + if (i == 1) { + innerLow = key; + final int index = i * 2 - 1; + innerLowBetween = "key-" + index; + } + if (i == DATA_SIZE - 2) { + innerHigh = key; + final int index = i * 2 + 1; + innerHighBetween = "key-" + index; + } + } + Assert.assertNotNull(low); + Assert.assertNotNull(high); + Assert.assertNotNull(middle); + Assert.assertNotNull(innerLow); + Assert.assertNotNull(innerHigh); + Assert.assertNotNull(innerLowBetween); + Assert.assertNotNull(innerHighBetween); + + expectedRecords.add(new KeyValue<>(new Windowed<>("key-a", new SessionWindow(0, 500)), 4L)); + expectedRecords.add(new KeyValue<>(new Windowed<>("key-aa", new SessionWindow(0, 500)), 4L)); + expectedRecords.add(new KeyValue<>(new Windowed<>("key-b", new SessionWindow(1500, 2000)), 6L)); + expectedRecords.add(new KeyValue<>(new Windowed<>("key-bb", new SessionWindow(1500, 2000)), 6L)); + } + + @Rule + public TestName testName = new TestName(); + + @Parameterized.Parameters(name = "storeType={0}, enableLogging={1}, enableCaching={2}, forward={3}") + public static Collection data() { + final List types = Arrays.asList(StoreType.InMemory, StoreType.RocksDB); + final List logging = Arrays.asList(true, false); + final List caching = Arrays.asList(true, false); + final List forward = Arrays.asList(true, false); + return buildParameters(types, logging, caching, forward); + } + + @Before + public void setup() { + streamsConfig = mkProperties(mkMap( + mkEntry(StreamsConfig.STATE_DIR_CONFIG, TestUtils.tempDirectory().getPath()) + )); + } + + private void verifyNormalQuery(final SessionStore stateStore) { + try (final KeyValueIterator, Long> scanIterator = forward ? + stateStore.fetch("key-a", "key-bb") : + stateStore.backwardFetch("key-a", "key-bb")) { + + final Iterator, Long>> dataIterator = forward ? + expectedRecords.iterator() : + expectedRecords.descendingIterator(); + + TestUtils.checkEquals(scanIterator, dataIterator); + } + + try (final KeyValueIterator, Long> scanIterator = forward ? + stateStore.findSessions("key-a", "key-bb", 0L, Long.MAX_VALUE) : + stateStore.backwardFindSessions("key-a", "key-bb", 0L, Long.MAX_VALUE)) { + + final Iterator, Long>> dataIterator = forward ? + expectedRecords.iterator() : + expectedRecords.descendingIterator(); + + TestUtils.checkEquals(scanIterator, dataIterator); + } + } + + private void verifyInfiniteQuery(final SessionStore stateStore) { + try (final KeyValueIterator, Long> scanIterator = forward ? + stateStore.fetch(null, null) : + stateStore.backwardFetch(null, null)) { + + final Iterator, Long>> dataIterator = forward ? + expectedRecords.iterator() : + expectedRecords.descendingIterator(); + + TestUtils.checkEquals(scanIterator, dataIterator); + } + + try (final KeyValueIterator, Long> scanIterator = forward ? + stateStore.findSessions(null, null, 0L, Long.MAX_VALUE) : + stateStore.backwardFindSessions(null, null, 0L, Long.MAX_VALUE)) { + + final Iterator, Long>> dataIterator = forward ? 
+ expectedRecords.iterator() : + expectedRecords.descendingIterator(); + + TestUtils.checkEquals(scanIterator, dataIterator); + } + } + + private void verifyRangeQuery(final SessionStore stateStore) { + testRange("range", stateStore, innerLow, innerHigh, forward); + testRange("until", stateStore, null, middle, forward); + testRange("from", stateStore, middle, null, forward); + + testRange("untilBetween", stateStore, null, innerHighBetween, forward); + testRange("fromBetween", stateStore, innerLowBetween, null, forward); + } + + @Test + public void testStoreConfig() { + final Materialized> stateStoreConfig = getStoreConfig(storeType, STORE_NAME, enableLogging, enableCaching); + final StreamsBuilder builder = new StreamsBuilder(); + + final KStream stream = builder.stream("input", Consumed.with(Serdes.String(), Serdes.String())); + stream. + groupByKey(Grouped.with(Serdes.String(), Serdes.String())) + .windowedBy(SessionWindows.ofInactivityGapWithNoGrace(ofMillis(WINDOW_SIZE))) + .count(stateStoreConfig) + .toStream() + .to("output"); + + final Topology topology = builder.build(); + + try (final TopologyTestDriver driver = new TopologyTestDriver(topology)) { + //get input topic and stateStore + final TestInputTopic input = driver + .createInputTopic("input", new StringSerializer(), new StringSerializer()); + final SessionStore stateStore = driver.getSessionStore(STORE_NAME); + + //write some data + final int medium = DATA_SIZE / 2 * 2; + for (int i = 0; i < records.size(); i++) { + final KeyValue kv = records.get(i); + final long windowStartTime = i < medium ? 0 : 1500; + input.pipeInput(kv.key, kv.value, windowStartTime); + input.pipeInput(kv.key, kv.value, windowStartTime + WINDOW_SIZE); + } + + verifyNormalQuery(stateStore); + verifyInfiniteQuery(stateStore); + verifyRangeQuery(stateStore); + } + } + + + private List, Long>> filterList(final KeyValueIterator, Long> iterator, final String from, final String to) { + final Predicate, Long>> pred = new Predicate, Long>>() { + @Override + public boolean test(final KeyValue, Long> elem) { + if (from != null && elem.key.key().compareTo(from) < 0) { + return false; + } + if (to != null && elem.key.key().compareTo(to) > 0) { + return false; + } + return elem != null; + } + }; + + return Utils.toList(iterator, pred); + } + + private void testRange(final String name, final SessionStore store, final String from, final String to, final boolean forward) { + try (final KeyValueIterator, Long> resultIterator = forward ? store.fetch(from, to) : store.backwardFetch(from, to); + final KeyValueIterator, Long> expectedIterator = forward ? store.fetch(null, null) : store.backwardFetch(null, null)) { + final List, Long>> result = Utils.toList(resultIterator); + final List, Long>> expected = filterList(expectedIterator, from, to); + assertThat(result, is(expected)); + } + } + + private static Collection buildParameters(final List... 
argOptions) {
+        List<Object[]> result = new LinkedList<>();
+        result.add(new Object[0]);
+
+        for (final List argOption : argOptions) {
+            result = times(result, argOption);
+        }
+
+        return result;
+    }
+
+    private static List<Object[]> times(final List<Object[]> left, final List<Object> right) {
+        final List<Object[]> result = new LinkedList<>();
+        for (final Object[] args : left) {
+            for (final Object rightElem : right) {
+                final Object[] resArgs = new Object[args.length + 1];
+                System.arraycopy(args, 0, resArgs, 0, args.length);
+                resArgs[args.length] = rightElem;
+                result.add(resArgs);
+            }
+        }
+        return result;
+    }
+
+    private Materialized<String, Long, SessionStore<Bytes, byte[]>> getStoreConfig(final StoreType type, final String name, final boolean loggingEnabled, final boolean cachingEnabled) {
+        final Supplier<SessionBytesStoreSupplier> createStore = () -> {
+            if (type == StoreType.InMemory) {
+                return Stores.inMemorySessionStore(STORE_NAME, Duration.ofMillis(RETENTION_MS));
+            } else if (type == StoreType.RocksDB) {
+                return Stores.persistentSessionStore(STORE_NAME, Duration.ofMillis(RETENTION_MS));
+            } else {
+                return Stores.inMemorySessionStore(STORE_NAME, Duration.ofMillis(RETENTION_MS));
+            }
+        };
+
+        final SessionBytesStoreSupplier stateStoreSupplier = createStore.get();
+        final Materialized<String, Long, SessionStore<Bytes, byte[]>> stateStoreConfig = Materialized
+            .<String, Long, SessionStore<Bytes, byte[]>>as(stateStoreSupplier)
+            .withKeySerde(Serdes.String())
+            .withValueSerde(Serdes.Long());
+        if (cachingEnabled) {
+            stateStoreConfig.withCachingEnabled();
+        } else {
+            stateStoreConfig.withCachingDisabled();
+        }
+        if (loggingEnabled) {
+            stateStoreConfig.withLoggingEnabled(new HashMap<>());
+        } else {
+            stateStoreConfig.withLoggingDisabled();
+        }
+        return stateStoreConfig;
+    }
+}
diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/StreamThreadStateStoreProviderTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/StreamThreadStateStoreProviderTest.java
new file mode 100644
index 0000000..c56f7bc
--- /dev/null
+++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/StreamThreadStateStoreProviderTest.java
@@ -0,0 +1,487 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.clients.consumer.ConsumerConfig; +import org.apache.kafka.common.PartitionInfo; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.metrics.Metrics; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.streams.StoreQueryParameters; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.TopologyWrapper; +import org.apache.kafka.streams.errors.InvalidStateStoreException; +import org.apache.kafka.streams.processor.TaskId; +import org.apache.kafka.streams.processor.internals.InternalProcessorContext; +import org.apache.kafka.streams.processor.internals.InternalTopologyBuilder; +import org.apache.kafka.streams.processor.internals.MockStreamsMetrics; +import org.apache.kafka.streams.processor.internals.ProcessorContextImpl; +import org.apache.kafka.streams.processor.internals.ProcessorStateManager; +import org.apache.kafka.streams.processor.internals.ProcessorTopology; +import org.apache.kafka.streams.processor.internals.RecordCollector; +import org.apache.kafka.streams.processor.internals.RecordCollectorImpl; +import org.apache.kafka.streams.processor.internals.StateDirectory; +import org.apache.kafka.streams.processor.internals.StoreChangelogReader; +import org.apache.kafka.streams.processor.internals.StreamTask; +import org.apache.kafka.streams.processor.internals.StreamThread; +import org.apache.kafka.streams.processor.internals.StreamsProducer; +import org.apache.kafka.streams.processor.internals.Task; +import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl; +import org.apache.kafka.streams.state.QueryableStoreTypes; +import org.apache.kafka.streams.state.ReadOnlyKeyValueStore; +import org.apache.kafka.streams.state.ReadOnlySessionStore; +import org.apache.kafka.streams.state.ReadOnlyWindowStore; +import org.apache.kafka.streams.state.Stores; +import org.apache.kafka.streams.state.TimestampedKeyValueStore; +import org.apache.kafka.streams.state.TimestampedWindowStore; +import org.apache.kafka.streams.state.ValueAndTimestamp; +import org.apache.kafka.test.MockApiProcessorSupplier; +import org.apache.kafka.test.MockClientSupplier; +import org.apache.kafka.test.MockStateRestoreListener; +import org.apache.kafka.test.TestUtils; +import org.easymock.EasyMock; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import java.io.File; +import java.io.IOException; +import java.time.Duration; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Properties; +import java.util.Set; +import java.util.UUID; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.not; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; + +public class StreamThreadStateStoreProviderTest { + + private StreamTask taskOne; + private StreamThreadStateStoreProvider provider; + private StateDirectory stateDirectory; + private File stateDir; + private final String topicName = "topic"; + private StreamThread threadMock; + private Map tasks; + + @Before + 
public void before() { + final TopologyWrapper topology = new TopologyWrapper(); + topology.addSource("the-source", topicName); + topology.addProcessor("the-processor", new MockApiProcessorSupplier<>(), "the-source"); + topology.addStateStore( + Stores.keyValueStoreBuilder( + Stores.inMemoryKeyValueStore("kv-store"), + Serdes.String(), + Serdes.String()), + "the-processor"); + topology.addStateStore( + Stores.timestampedKeyValueStoreBuilder( + Stores.inMemoryKeyValueStore("timestamped-kv-store"), + Serdes.String(), + Serdes.String()), + "the-processor"); + topology.addStateStore( + Stores.windowStoreBuilder( + Stores.inMemoryWindowStore( + "window-store", + Duration.ofMillis(10L), + Duration.ofMillis(2L), + false), + Serdes.String(), + Serdes.String()), + "the-processor"); + topology.addStateStore( + Stores.timestampedWindowStoreBuilder( + Stores.inMemoryWindowStore( + "timestamped-window-store", + Duration.ofMillis(10L), + Duration.ofMillis(2L), + false), + Serdes.String(), + Serdes.String()), + "the-processor"); + topology.addStateStore( + Stores.sessionStoreBuilder( + Stores.inMemorySessionStore( + "session-store", + Duration.ofMillis(10L)), + Serdes.String(), + Serdes.String()), + "the-processor"); + + final Properties properties = new Properties(); + final String applicationId = "applicationId"; + properties.put(StreamsConfig.APPLICATION_ID_CONFIG, applicationId); + properties.put(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9092"); + stateDir = TestUtils.tempDirectory(); + properties.put(StreamsConfig.STATE_DIR_CONFIG, stateDir.getPath()); + + final StreamsConfig streamsConfig = new StreamsConfig(properties); + final MockClientSupplier clientSupplier = new MockClientSupplier(); + configureClients(clientSupplier, "applicationId-kv-store-changelog"); + configureClients(clientSupplier, "applicationId-window-store-changelog"); + + final InternalTopologyBuilder internalTopologyBuilder = topology.getInternalBuilder(applicationId); + final ProcessorTopology processorTopology = internalTopologyBuilder.buildTopology(); + + tasks = new HashMap<>(); + stateDirectory = new StateDirectory(streamsConfig, new MockTime(), true, false); + + taskOne = createStreamsTask( + streamsConfig, + clientSupplier, + processorTopology, + new TaskId(0, 0)); + taskOne.initializeIfNeeded(); + tasks.put(new TaskId(0, 0), taskOne); + + final StreamTask taskTwo = createStreamsTask( + streamsConfig, + clientSupplier, + processorTopology, + new TaskId(0, 1)); + taskTwo.initializeIfNeeded(); + tasks.put(new TaskId(0, 1), taskTwo); + + threadMock = EasyMock.createNiceMock(StreamThread.class); + provider = new StreamThreadStateStoreProvider(threadMock); + + } + + @After + public void cleanUp() throws IOException { + Utils.delete(stateDir); + } + + @Test + public void shouldFindKeyValueStores() { + mockThread(true); + final List> kvStores = + provider.stores(StoreQueryParameters.fromNameAndType("kv-store", QueryableStoreTypes.keyValueStore())); + assertEquals(2, kvStores.size()); + for (final ReadOnlyKeyValueStore store: kvStores) { + assertThat(store, instanceOf(ReadOnlyKeyValueStore.class)); + assertThat(store, not(instanceOf(TimestampedKeyValueStore.class))); + } + } + + @Test + public void shouldFindTimestampedKeyValueStores() { + mockThread(true); + final List>> tkvStores = + provider.stores(StoreQueryParameters.fromNameAndType("timestamped-kv-store", QueryableStoreTypes.timestampedKeyValueStore())); + assertEquals(2, tkvStores.size()); + for (final ReadOnlyKeyValueStore> store: tkvStores) { + assertThat(store, 
instanceOf(ReadOnlyKeyValueStore.class)); + assertThat(store, instanceOf(TimestampedKeyValueStore.class)); + } + } + + @Test + public void shouldNotFindKeyValueStoresAsTimestampedStore() { + mockThread(true); + final InvalidStateStoreException exception = assertThrows( + InvalidStateStoreException.class, + () -> provider.stores(StoreQueryParameters.fromNameAndType("kv-store", QueryableStoreTypes.timestampedKeyValueStore())) + ); + assertThat( + exception.getMessage(), + is( + "Cannot get state store kv-store because the queryable store type " + + "[class org.apache.kafka.streams.state.QueryableStoreTypes$TimestampedKeyValueStoreType] " + + "does not accept the actual store type " + + "[class org.apache.kafka.streams.state.internals.MeteredKeyValueStore]." + ) + ); + } + + @Test + public void shouldFindTimestampedKeyValueStoresAsKeyValueStores() { + mockThread(true); + final List>> tkvStores = + provider.stores(StoreQueryParameters.fromNameAndType("timestamped-kv-store", QueryableStoreTypes.keyValueStore())); + assertEquals(2, tkvStores.size()); + for (final ReadOnlyKeyValueStore> store: tkvStores) { + assertThat(store, instanceOf(ReadOnlyKeyValueStore.class)); + assertThat(store, not(instanceOf(TimestampedKeyValueStore.class))); + } + } + + @Test + public void shouldFindWindowStores() { + mockThread(true); + final List> windowStores = + provider.stores(StoreQueryParameters.fromNameAndType("window-store", QueryableStoreTypes.windowStore())); + assertEquals(2, windowStores.size()); + for (final ReadOnlyWindowStore store: windowStores) { + assertThat(store, instanceOf(ReadOnlyWindowStore.class)); + assertThat(store, not(instanceOf(TimestampedWindowStore.class))); + } + } + + @Test + public void shouldFindTimestampedWindowStores() { + mockThread(true); + final List>> windowStores = + provider.stores(StoreQueryParameters.fromNameAndType("timestamped-window-store", QueryableStoreTypes.timestampedWindowStore())); + assertEquals(2, windowStores.size()); + for (final ReadOnlyWindowStore> store: windowStores) { + assertThat(store, instanceOf(ReadOnlyWindowStore.class)); + assertThat(store, instanceOf(TimestampedWindowStore.class)); + } + } + + @Test + public void shouldNotFindWindowStoresAsTimestampedStore() { + mockThread(true); + final InvalidStateStoreException exception = assertThrows( + InvalidStateStoreException.class, + () -> provider.stores(StoreQueryParameters.fromNameAndType("window-store", QueryableStoreTypes.timestampedWindowStore())) + ); + assertThat( + exception.getMessage(), + is( + "Cannot get state store window-store because the queryable store type " + + "[class org.apache.kafka.streams.state.QueryableStoreTypes$TimestampedWindowStoreType] " + + "does not accept the actual store type " + + "[class org.apache.kafka.streams.state.internals.MeteredWindowStore]." 
+ ) + ); + } + + @Test + public void shouldFindTimestampedWindowStoresAsWindowStore() { + mockThread(true); + final List>> windowStores = + provider.stores(StoreQueryParameters.fromNameAndType("timestamped-window-store", QueryableStoreTypes.windowStore())); + assertEquals(2, windowStores.size()); + for (final ReadOnlyWindowStore> store: windowStores) { + assertThat(store, instanceOf(ReadOnlyWindowStore.class)); + assertThat(store, not(instanceOf(TimestampedWindowStore.class))); + } + } + + @Test + public void shouldFindSessionStores() { + mockThread(true); + final List> sessionStores = + provider.stores(StoreQueryParameters.fromNameAndType("session-store", QueryableStoreTypes.sessionStore())); + assertEquals(2, sessionStores.size()); + for (final ReadOnlySessionStore store: sessionStores) { + assertThat(store, instanceOf(ReadOnlySessionStore.class)); + } + } + + @Test + public void shouldThrowInvalidStoreExceptionIfKVStoreClosed() { + mockThread(true); + taskOne.getStore("kv-store").close(); + assertThrows(InvalidStateStoreException.class, () -> provider.stores(StoreQueryParameters.fromNameAndType("kv-store", + QueryableStoreTypes.keyValueStore()))); + } + + @Test + public void shouldThrowInvalidStoreExceptionIfTsKVStoreClosed() { + mockThread(true); + taskOne.getStore("timestamped-kv-store").close(); + assertThrows(InvalidStateStoreException.class, () -> provider.stores(StoreQueryParameters.fromNameAndType("timestamped-kv-store", + QueryableStoreTypes.timestampedKeyValueStore()))); + } + + @Test + public void shouldThrowInvalidStoreExceptionIfWindowStoreClosed() { + mockThread(true); + taskOne.getStore("window-store").close(); + assertThrows(InvalidStateStoreException.class, () -> provider.stores(StoreQueryParameters.fromNameAndType("window-store", + QueryableStoreTypes.windowStore()))); + } + + @Test + public void shouldThrowInvalidStoreExceptionIfTsWindowStoreClosed() { + mockThread(true); + taskOne.getStore("timestamped-window-store").close(); + assertThrows(InvalidStateStoreException.class, () -> provider.stores(StoreQueryParameters.fromNameAndType("timestamped-window-store", + QueryableStoreTypes.timestampedWindowStore()))); + } + + @Test + public void shouldThrowInvalidStoreExceptionIfSessionStoreClosed() { + mockThread(true); + taskOne.getStore("session-store").close(); + assertThrows(InvalidStateStoreException.class, () -> provider.stores(StoreQueryParameters.fromNameAndType("session-store", + QueryableStoreTypes.sessionStore()))); + } + + @Test + public void shouldReturnEmptyListIfNoStoresFoundWithName() { + mockThread(true); + assertEquals( + Collections.emptyList(), + provider.stores(StoreQueryParameters.fromNameAndType("not-a-store", QueryableStoreTypes.keyValueStore()))); + } + + @Test + public void shouldReturnSingleStoreForPartition() { + mockThread(true); + { + final List> kvStores = + provider.stores( + StoreQueryParameters + .fromNameAndType("kv-store", QueryableStoreTypes.keyValueStore()) + .withPartition(0)); + assertEquals(1, kvStores.size()); + for (final ReadOnlyKeyValueStore store : kvStores) { + assertThat(store, instanceOf(ReadOnlyKeyValueStore.class)); + assertThat(store, not(instanceOf(TimestampedKeyValueStore.class))); + } + } + { + final List> kvStores = + provider.stores( + StoreQueryParameters + .fromNameAndType("kv-store", QueryableStoreTypes.keyValueStore()) + .withPartition(1)); + assertEquals(1, kvStores.size()); + for (final ReadOnlyKeyValueStore store : kvStores) { + assertThat(store, instanceOf(ReadOnlyKeyValueStore.class)); + assertThat(store, 
not(instanceOf(TimestampedKeyValueStore.class))); + } + } + } + + @Test + public void shouldReturnEmptyListForInvalidPartitions() { + mockThread(true); + assertEquals( + Collections.emptyList(), + provider.stores(StoreQueryParameters.fromNameAndType("kv-store", QueryableStoreTypes.keyValueStore()).withPartition(2)) + ); + } + + @Test + public void shouldThrowInvalidStoreExceptionIfNotAllStoresAvailable() { + mockThread(false); + assertThrows(InvalidStateStoreException.class, () -> provider.stores(StoreQueryParameters.fromNameAndType("kv-store", + QueryableStoreTypes.keyValueStore()))); + } + + private StreamTask createStreamsTask(final StreamsConfig streamsConfig, + final MockClientSupplier clientSupplier, + final ProcessorTopology topology, + final TaskId taskId) { + final Metrics metrics = new Metrics(); + final LogContext logContext = new LogContext("test-stream-task "); + final Set partitions = Collections.singleton(new TopicPartition(topicName, taskId.partition())); + final ProcessorStateManager stateManager = new ProcessorStateManager( + taskId, + Task.TaskType.ACTIVE, + StreamThread.eosEnabled(streamsConfig), + logContext, + stateDirectory, + new StoreChangelogReader( + new MockTime(), + streamsConfig, + logContext, + clientSupplier.adminClient, + clientSupplier.restoreConsumer, + new MockStateRestoreListener()), + topology.storeToChangelogTopic(), partitions); + final RecordCollector recordCollector = new RecordCollectorImpl( + logContext, + taskId, + new StreamsProducer( + streamsConfig, + "threadId", + clientSupplier, + new TaskId(0, 0), + UUID.randomUUID(), + logContext, + Time.SYSTEM + ), + streamsConfig.defaultProductionExceptionHandler(), + new MockStreamsMetrics(metrics)); + final StreamsMetricsImpl streamsMetrics = new MockStreamsMetrics(metrics); + final InternalProcessorContext context = new ProcessorContextImpl( + taskId, + streamsConfig, + stateManager, + streamsMetrics, + null + ); + return new StreamTask( + taskId, + partitions, + topology, + clientSupplier.consumer, + streamsConfig, + streamsMetrics, + stateDirectory, + EasyMock.createNiceMock(ThreadCache.class), + new MockTime(), + stateManager, + recordCollector, + context, logContext); + } + + private void mockThread(final boolean initialized) { + EasyMock.expect(threadMock.isRunning()).andReturn(initialized); + EasyMock.expect(threadMock.allTasks()).andStubReturn(tasks); + EasyMock.expect(threadMock.activeTaskMap()).andStubReturn(tasks); + EasyMock.expect(threadMock.activeTasks()).andStubReturn(new ArrayList<>(tasks.values())); + EasyMock.expect(threadMock.state()).andReturn( + initialized ? 
StreamThread.State.RUNNING : StreamThread.State.PARTITIONS_ASSIGNED + ).anyTimes(); + EasyMock.replay(threadMock); + } + + private void configureClients(final MockClientSupplier clientSupplier, final String topic) { + final List partitions = Arrays.asList( + new PartitionInfo(topic, 0, null, null, null), + new PartitionInfo(topic, 1, null, null, null) + ); + clientSupplier.restoreConsumer.updatePartitions(topic, partitions); + final TopicPartition tp1 = new TopicPartition(topic, 0); + final TopicPartition tp2 = new TopicPartition(topic, 1); + + clientSupplier.restoreConsumer.assign(Arrays.asList(tp1, tp2)); + + final Map offsets = new HashMap<>(); + offsets.put(tp1, 0L); + offsets.put(tp2, 0L); + + clientSupplier.restoreConsumer.updateBeginningOffsets(offsets); + clientSupplier.restoreConsumer.updateEndOffsets(offsets); + + clientSupplier.adminClient.updateBeginningOffsets(offsets); + clientSupplier.adminClient.updateEndOffsets(offsets); + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/ThreadCacheTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/ThreadCacheTest.java new file mode 100644 index 0000000..805d295 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/ThreadCacheTest.java @@ -0,0 +1,627 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.state.internals; + + +import org.apache.kafka.common.header.internals.RecordHeaders; +import org.apache.kafka.common.metrics.Metrics; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.processor.internals.MockStreamsMetrics; +import static org.hamcrest.MatcherAssert.assertThat; +import org.junit.Test; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.NoSuchElementException; +import java.util.function.Supplier; + +import static org.hamcrest.core.Is.is; +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; + +public class ThreadCacheTest { + final String namespace = "0.0-namespace"; + final String namespace1 = "0.1-namespace"; + final String namespace2 = "0.2-namespace"; + private final LogContext logContext = new LogContext("testCache "); + private final byte[][] bytes = new byte[][]{{0}, {1}, {2}, {3}, {4}, {5}, {6}, {7}, {8}, {9}, {10}}; + + @Test + public void basicPutGet() { + final List> toInsert = Arrays.asList( + new KeyValue<>("K1", "V1"), + new KeyValue<>("K2", "V2"), + new KeyValue<>("K3", "V3"), + new KeyValue<>("K4", "V4"), + new KeyValue<>("K5", "V5")); + final KeyValue kv = toInsert.get(0); + final ThreadCache cache = new ThreadCache(logContext, + toInsert.size() * memoryCacheEntrySize(kv.key.getBytes(), kv.value.getBytes(), ""), + new MockStreamsMetrics(new Metrics())); + + for (final KeyValue kvToInsert : toInsert) { + final Bytes key = Bytes.wrap(kvToInsert.key.getBytes()); + final byte[] value = kvToInsert.value.getBytes(); + cache.put(namespace, key, new LRUCacheEntry(value, new RecordHeaders(), true, 1L, 1L, 1, "")); + } + + for (final KeyValue kvToInsert : toInsert) { + final Bytes key = Bytes.wrap(kvToInsert.key.getBytes()); + final LRUCacheEntry entry = cache.get(namespace, key); + assertTrue(entry.isDirty()); + assertEquals(new String(entry.value()), kvToInsert.value); + } + assertEquals(cache.gets(), 5); + assertEquals(cache.puts(), 5); + assertEquals(cache.evicts(), 0); + assertEquals(cache.flushes(), 0); + } + + private void checkOverheads(final double entryFactor, + final double systemFactor, + final long desiredCacheSize, + final int keySizeBytes, + final int valueSizeBytes) { + final Runtime runtime = Runtime.getRuntime(); + final long numElements = desiredCacheSize / memoryCacheEntrySize(new byte[keySizeBytes], new byte[valueSizeBytes], ""); + + System.gc(); + final long prevRuntimeMemory = runtime.totalMemory() - runtime.freeMemory(); + + final ThreadCache cache = new ThreadCache(logContext, desiredCacheSize, new MockStreamsMetrics(new Metrics())); + final long size = cache.sizeBytes(); + assertEquals(size, 0); + for (int i = 0; i < numElements; i++) { + final String keyStr = "K" + i; + final Bytes key = Bytes.wrap(keyStr.getBytes()); + final byte[] value = new byte[valueSizeBytes]; + cache.put(namespace, key, new LRUCacheEntry(value, new RecordHeaders(), true, 1L, 1L, 1, "")); + } + + + System.gc(); + final double ceiling = desiredCacheSize + desiredCacheSize * entryFactor; + final long usedRuntimeMemory = runtime.totalMemory() - runtime.freeMemory() - prevRuntimeMemory; + assertTrue((double) 
cache.sizeBytes() <= ceiling); + + assertTrue("Used memory size " + usedRuntimeMemory + " greater than expected " + cache.sizeBytes() * systemFactor, + cache.sizeBytes() * systemFactor >= usedRuntimeMemory); + } + + @Test + public void cacheOverheadsSmallValues() { + final Runtime runtime = Runtime.getRuntime(); + final double factor = 0.05; + final double systemFactor = 3; // if I ask for a cache size of 10 MB, accept an overhead of 3x, i.e., 30 MBs might be allocated + final long desiredCacheSize = Math.min(100 * 1024 * 1024L, runtime.maxMemory()); + final int keySizeBytes = 8; + final int valueSizeBytes = 100; + + checkOverheads(factor, systemFactor, desiredCacheSize, keySizeBytes, valueSizeBytes); + } + + @Test + public void cacheOverheadsLargeValues() { + final Runtime runtime = Runtime.getRuntime(); + final double factor = 0.05; + final double systemFactor = 2; // if I ask for a cache size of 10 MB, accept an overhead of 2x, i.e., 20 MBs might be allocated + final long desiredCacheSize = Math.min(100 * 1024 * 1024L, runtime.maxMemory()); + final int keySizeBytes = 8; + final int valueSizeBytes = 1000; + + checkOverheads(factor, systemFactor, desiredCacheSize, keySizeBytes, valueSizeBytes); + } + + + static long memoryCacheEntrySize(final byte[] key, final byte[] value, final String topic) { + return key.length + + value.length + + 1 + // isDirty + 8 + // timestamp + 8 + // offset + 4 + + topic.length() + + // LRU Node entries + key.length + + 8 + // entry + 8 + // previous + 8; // next + } + + @Test + public void evict() { + final List> received = new ArrayList<>(); + final List> expected = Collections.singletonList( + new KeyValue<>("K1", "V1")); + + final List> toInsert = Arrays.asList( + new KeyValue<>("K1", "V1"), + new KeyValue<>("K2", "V2"), + new KeyValue<>("K3", "V3"), + new KeyValue<>("K4", "V4"), + new KeyValue<>("K5", "V5")); + final KeyValue kv = toInsert.get(0); + final ThreadCache cache = new ThreadCache(logContext, + memoryCacheEntrySize(kv.key.getBytes(), kv.value.getBytes(), ""), + new MockStreamsMetrics(new Metrics())); + cache.addDirtyEntryFlushListener(namespace, dirty -> { + for (final ThreadCache.DirtyEntry dirtyEntry : dirty) { + received.add(new KeyValue<>(dirtyEntry.key().toString(), new String(dirtyEntry.newValue()))); + } + }); + + for (final KeyValue kvToInsert : toInsert) { + final Bytes key = Bytes.wrap(kvToInsert.key.getBytes()); + final byte[] value = kvToInsert.value.getBytes(); + cache.put(namespace, key, new LRUCacheEntry(value, new RecordHeaders(), true, 1, 1, 1, "")); + } + + for (int i = 0; i < expected.size(); i++) { + final KeyValue expectedRecord = expected.get(i); + final KeyValue actualRecord = received.get(i); + assertEquals(expectedRecord, actualRecord); + } + assertEquals(cache.evicts(), 4); + } + + @Test + public void shouldDelete() { + final ThreadCache cache = new ThreadCache(logContext, 10000L, new MockStreamsMetrics(new Metrics())); + final Bytes key = Bytes.wrap(new byte[]{0}); + + cache.put(namespace, key, dirtyEntry(key.get())); + assertEquals(key.get(), cache.delete(namespace, key).value()); + assertNull(cache.get(namespace, key)); + } + + @Test + public void shouldNotFlushAfterDelete() { + final Bytes key = Bytes.wrap(new byte[]{0}); + final ThreadCache cache = new ThreadCache(logContext, 10000L, new MockStreamsMetrics(new Metrics())); + final List received = new ArrayList<>(); + cache.addDirtyEntryFlushListener(namespace, received::addAll); + cache.put(namespace, key, dirtyEntry(key.get())); + assertEquals(key.get(), 
cache.delete(namespace, key).value()); + + // flushing should have no further effect + cache.flush(namespace); + assertEquals(0, received.size()); + assertEquals(cache.flushes(), 1); + } + + @Test + public void shouldNotBlowUpOnNonExistentKeyWhenDeleting() { + final Bytes key = Bytes.wrap(new byte[]{0}); + final ThreadCache cache = new ThreadCache(logContext, 10000L, new MockStreamsMetrics(new Metrics())); + + cache.put(namespace, key, dirtyEntry(key.get())); + assertNull(cache.delete(namespace, Bytes.wrap(new byte[]{1}))); + } + + @Test + public void shouldNotBlowUpOnNonExistentNamespaceWhenDeleting() { + final ThreadCache cache = new ThreadCache(logContext, 10000L, new MockStreamsMetrics(new Metrics())); + assertNull(cache.delete(namespace, Bytes.wrap(new byte[]{1}))); + } + + @Test + public void shouldNotClashWithOverlappingNames() { + final ThreadCache cache = new ThreadCache(logContext, 10000L, new MockStreamsMetrics(new Metrics())); + final Bytes nameByte = Bytes.wrap(new byte[]{0}); + final Bytes name1Byte = Bytes.wrap(new byte[]{1}); + cache.put(namespace1, nameByte, dirtyEntry(nameByte.get())); + cache.put(namespace2, nameByte, dirtyEntry(name1Byte.get())); + + assertArrayEquals(nameByte.get(), cache.get(namespace1, nameByte).value()); + assertArrayEquals(name1Byte.get(), cache.get(namespace2, nameByte).value()); + } + + private ThreadCache setupThreadCache(final int first, final int last, final long entrySize, final boolean reverse) { + final ThreadCache cache = new ThreadCache(logContext, entrySize, new MockStreamsMetrics(new Metrics())); + cache.addDirtyEntryFlushListener(namespace, dirty -> { }); + int index = first; + while ((!reverse && index < last) || (reverse && index >= last)) { + cache.put(namespace, Bytes.wrap(bytes[index]), dirtyEntry(bytes[index])); + if (!reverse) + index++; + else + index--; + } + return cache; + } + + @Test + public void shouldPeekNextKey() { + final ThreadCache cache = setupThreadCache(0, 1, 10000L, false); + final Bytes theByte = Bytes.wrap(new byte[]{0}); + final ThreadCache.MemoryLRUCacheBytesIterator iterator = cache.range(namespace, theByte, Bytes.wrap(new byte[]{1})); + assertEquals(theByte, iterator.peekNextKey()); + assertEquals(theByte, iterator.peekNextKey()); + } + + @Test + public void shouldPeekNextKeyReverseRange() { + final ThreadCache cache = setupThreadCache(1, 1, 10000L, true); + final Bytes theByte = Bytes.wrap(new byte[]{1}); + final ThreadCache.MemoryLRUCacheBytesIterator iterator = cache.reverseRange(namespace, Bytes.wrap(new byte[]{0}), theByte); + assertThat(iterator.peekNextKey(), is(theByte)); + assertThat(iterator.peekNextKey(), is(theByte)); + } + + @Test + public void shouldGetSameKeyAsPeekNext() { + final ThreadCache cache = setupThreadCache(0, 1, 10000L, false); + final Bytes theByte = Bytes.wrap(new byte[]{0}); + final ThreadCache.MemoryLRUCacheBytesIterator iterator = cache.range(namespace, theByte, Bytes.wrap(new byte[]{1})); + assertThat(iterator.peekNextKey(), is(iterator.next().key)); + } + + @Test + public void shouldGetSameKeyAsPeekNextReverseRange() { + final ThreadCache cache = setupThreadCache(1, 1, 10000L, true); + final Bytes theByte = Bytes.wrap(new byte[]{1}); + final ThreadCache.MemoryLRUCacheBytesIterator iterator = cache.reverseRange(namespace, Bytes.wrap(new byte[]{0}), theByte); + assertThat(iterator.peekNextKey(), is(iterator.next().key)); + } + + private void shouldThrowIfNoPeekNextKey(final Supplier methodUnderTest) { + final ThreadCache.MemoryLRUCacheBytesIterator iterator = 
methodUnderTest.get(); + assertThrows(NoSuchElementException.class, iterator::peekNextKey); + } + + @Test + public void shouldThrowIfNoPeekNextKeyRange() { + final ThreadCache cache = setupThreadCache(0, 0, 10000L, false); + shouldThrowIfNoPeekNextKey(() -> cache.range(namespace, Bytes.wrap(new byte[]{0}), Bytes.wrap(new byte[]{1}))); + } + + @Test + public void shouldThrowIfNoPeekNextKeyReverseRange() { + final ThreadCache cache = setupThreadCache(-1, 0, 10000L, true); + shouldThrowIfNoPeekNextKey(() -> cache.reverseRange(namespace, Bytes.wrap(new byte[]{0}), Bytes.wrap(new byte[]{1}))); + } + + @Test + public void shouldReturnFalseIfNoNextKey() { + final ThreadCache cache = setupThreadCache(0, 0, 10000L, false); + final ThreadCache.MemoryLRUCacheBytesIterator iterator = cache.range(namespace, Bytes.wrap(new byte[]{0}), Bytes.wrap(new byte[]{1})); + assertFalse(iterator.hasNext()); + } + + @Test + public void shouldReturnFalseIfNoNextKeyReverseRange() { + final ThreadCache cache = setupThreadCache(-1, 0, 10000L, true); + final ThreadCache.MemoryLRUCacheBytesIterator iterator = cache.reverseRange(namespace, Bytes.wrap(new byte[]{0}), Bytes.wrap(new byte[]{1})); + assertFalse(iterator.hasNext()); + } + + @Test + public void shouldPeekAndIterateOverRange() { + final ThreadCache cache = setupThreadCache(0, 10, 10000L, false); + final ThreadCache.MemoryLRUCacheBytesIterator iterator = cache.range(namespace, Bytes.wrap(new byte[]{1}), Bytes.wrap(new byte[]{4})); + int bytesIndex = 1; + while (iterator.hasNext()) { + final Bytes peekedKey = iterator.peekNextKey(); + final KeyValue next = iterator.next(); + assertArrayEquals(bytes[bytesIndex], peekedKey.get()); + assertArrayEquals(bytes[bytesIndex], next.key.get()); + bytesIndex++; + } + assertEquals(5, bytesIndex); + } + + @Test + public void shouldSkipToEntryWhenToInclusiveIsFalseInRange() { + final ThreadCache cache = setupThreadCache(0, 10, 10000L, false); + final ThreadCache.MemoryLRUCacheBytesIterator iterator = cache.range(namespace, Bytes.wrap(new byte[]{1}), Bytes.wrap(new byte[]{4}), false); + int bytesIndex = 1; + while (iterator.hasNext()) { + final Bytes peekedKey = iterator.peekNextKey(); + final KeyValue next = iterator.next(); + assertArrayEquals(bytes[bytesIndex], peekedKey.get()); + assertArrayEquals(bytes[bytesIndex], next.key.get()); + bytesIndex++; + } + assertEquals(4, bytesIndex); + } + + @Test + public void shouldPeekAndIterateOverReverseRange() { + final ThreadCache cache = setupThreadCache(10, 0, 10000L, true); + final ThreadCache.MemoryLRUCacheBytesIterator iterator = cache.reverseRange(namespace, Bytes.wrap(new byte[]{1}), Bytes.wrap(new byte[]{4})); + int bytesIndex = 4; + while (iterator.hasNext()) { + final Bytes peekedKey = iterator.peekNextKey(); + final KeyValue next = iterator.next(); + assertArrayEquals(bytes[bytesIndex], peekedKey.get()); + assertArrayEquals(bytes[bytesIndex], next.key.get()); + bytesIndex--; + } + assertEquals(0, bytesIndex); + } + + @Test + public void shouldSkipEntriesWhereValueHasBeenEvictedFromCache() { + final long entrySize = memoryCacheEntrySize(new byte[1], new byte[1], ""); + final ThreadCache cache = setupThreadCache(0, 5, entrySize * 5L, false); + assertEquals(5, cache.size()); + // should evict byte[] {0} + cache.put(namespace, Bytes.wrap(new byte[]{6}), dirtyEntry(new byte[]{6})); + final ThreadCache.MemoryLRUCacheBytesIterator range = cache.range(namespace, Bytes.wrap(new byte[]{0}), Bytes.wrap(new byte[]{5})); + assertEquals(Bytes.wrap(new byte[]{1}), range.peekNextKey()); + } 
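+    // Worked example of the sizing arithmetic behind the eviction tests above and below:
+    // memoryCacheEntrySize(new byte[1], new byte[1], "") =
+    //   1 (key) + 1 (value) + 1 (isDirty) + 8 (timestamp) + 8 (offset) + 4 + 0 (topic)
+    //   + 1 (LRU node key) + 8 (entry) + 8 (previous) + 8 (next) = 48 bytes,
+    // so a cache capped at entrySize * 5 = 240 bytes holds exactly five such entries and the
+    // sixth put evicts the eldest one, which is what the range/reverseRange assertions rely on.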
+ + @Test + public void shouldSkipEntriesWhereValueHasBeenEvictedFromCacheReverseRange() { + final long entrySize = memoryCacheEntrySize(new byte[1], new byte[1], ""); + final ThreadCache cache = setupThreadCache(4, 0, entrySize * 5L, true); + assertEquals(5, cache.size()); + // should evict byte[] {4} + cache.put(namespace, Bytes.wrap(new byte[]{6}), dirtyEntry(new byte[]{6})); + final ThreadCache.MemoryLRUCacheBytesIterator range = cache.reverseRange(namespace, Bytes.wrap(new byte[]{0}), Bytes.wrap(new byte[]{5})); + assertEquals(Bytes.wrap(new byte[]{3}), range.peekNextKey()); + } + + @Test + public void shouldFetchAllEntriesInCache() { + final ThreadCache cache = setupThreadCache(0, 11, 10000L, false); + final ThreadCache.MemoryLRUCacheBytesIterator iterator = cache.all(namespace); + int bytesIndex = 0; + while (iterator.hasNext()) { + final Bytes peekedKey = iterator.peekNextKey(); + final KeyValue next = iterator.next(); + assertArrayEquals(bytes[bytesIndex], peekedKey.get()); + assertArrayEquals(bytes[bytesIndex], next.key.get()); + bytesIndex++; + } + assertEquals(11, bytesIndex); + } + + @Test + public void shouldFetchAllEntriesInCacheInReverseOrder() { + final ThreadCache cache = setupThreadCache(10, 0, 10000L, true); + final ThreadCache.MemoryLRUCacheBytesIterator iterator = cache.reverseAll(namespace); + int bytesIndex = 10; + while (iterator.hasNext()) { + final Bytes peekedKey = iterator.peekNextKey(); + final KeyValue next = iterator.next(); + assertArrayEquals(bytes[bytesIndex], peekedKey.get()); + assertArrayEquals(bytes[bytesIndex], next.key.get()); + bytesIndex--; + } + assertEquals(-1, bytesIndex); + } + + @Test + public void shouldReturnAllUnevictedValuesFromCache() { + final long entrySize = memoryCacheEntrySize(new byte[1], new byte[1], ""); + final ThreadCache cache = setupThreadCache(0, 5, entrySize * 5L, false); + assertEquals(5, cache.size()); + // should evict byte[] {0} + cache.put(namespace, Bytes.wrap(new byte[]{6}), dirtyEntry(new byte[]{6})); + final ThreadCache.MemoryLRUCacheBytesIterator range = cache.all(namespace); + assertEquals(Bytes.wrap(new byte[]{1}), range.peekNextKey()); + } + + @Test + public void shouldReturnAllUnevictedValuesFromCacheInReverseOrder() { + final long entrySize = memoryCacheEntrySize(new byte[1], new byte[1], ""); + final ThreadCache cache = setupThreadCache(4, 0, entrySize * 5L, true); + assertEquals(5, cache.size()); + // should evict byte[] {4} + cache.put(namespace, Bytes.wrap(new byte[]{6}), dirtyEntry(new byte[]{6})); + final ThreadCache.MemoryLRUCacheBytesIterator range = cache.reverseAll(namespace); + assertEquals(Bytes.wrap(new byte[]{6}), range.peekNextKey()); + } + + @Test + public void shouldFlushDirtyEntriesForNamespace() { + final ThreadCache cache = new ThreadCache(logContext, 100000, new MockStreamsMetrics(new Metrics())); + final List received = new ArrayList<>(); + cache.addDirtyEntryFlushListener(namespace1, dirty -> { + for (final ThreadCache.DirtyEntry dirtyEntry : dirty) { + received.add(dirtyEntry.key().get()); + } + }); + final List expected = Arrays.asList(new byte[]{0}, new byte[]{1}, new byte[]{2}); + for (final byte[] bytes : expected) { + cache.put(namespace1, Bytes.wrap(bytes), dirtyEntry(bytes)); + } + cache.put(namespace2, Bytes.wrap(new byte[]{4}), dirtyEntry(new byte[]{4})); + + cache.flush(namespace1); + assertEquals(expected, received); + } + + @Test + public void shouldNotFlushCleanEntriesForNamespace() { + final ThreadCache cache = new ThreadCache(logContext, 100000, new 
MockStreamsMetrics(new Metrics())); + final List received = new ArrayList<>(); + cache.addDirtyEntryFlushListener(namespace1, dirty -> { + for (final ThreadCache.DirtyEntry dirtyEntry : dirty) { + received.add(dirtyEntry.key().get()); + } + }); + final List toInsert = Arrays.asList(new byte[]{0}, new byte[]{1}, new byte[]{2}); + for (final byte[] bytes : toInsert) { + cache.put(namespace1, Bytes.wrap(bytes), cleanEntry(bytes)); + } + cache.put(namespace2, Bytes.wrap(new byte[]{4}), cleanEntry(new byte[]{4})); + + cache.flush(namespace1); + assertEquals(Collections.emptyList(), received); + } + + + private void shouldEvictImmediatelyIfCacheSizeIsZeroOrVerySmall(final ThreadCache cache) { + final List received = new ArrayList<>(); + + cache.addDirtyEntryFlushListener(namespace, received::addAll); + cache.put(namespace, Bytes.wrap(new byte[]{0}), dirtyEntry(new byte[]{0})); + assertEquals(1, received.size()); + + // flushing should have no further effect + cache.flush(namespace); + assertEquals(1, received.size()); + } + + @Test + public void shouldEvictImmediatelyIfCacheSizeIsVerySmall() { + final ThreadCache cache = new ThreadCache(logContext, 1, new MockStreamsMetrics(new Metrics())); + shouldEvictImmediatelyIfCacheSizeIsZeroOrVerySmall(cache); + } + + @Test + public void shouldEvictImmediatelyIfCacheSizeIsZero() { + final ThreadCache cache = new ThreadCache(logContext, 0, new MockStreamsMetrics(new Metrics())); + shouldEvictImmediatelyIfCacheSizeIsZeroOrVerySmall(cache); + } + + @Test + public void shouldEvictAfterPutAll() { + final List received = new ArrayList<>(); + final ThreadCache cache = new ThreadCache(logContext, 1, new MockStreamsMetrics(new Metrics())); + cache.addDirtyEntryFlushListener(namespace, received::addAll); + + cache.putAll(namespace, Arrays.asList(KeyValue.pair(Bytes.wrap(new byte[]{0}), dirtyEntry(new byte[]{5})), + KeyValue.pair(Bytes.wrap(new byte[]{1}), dirtyEntry(new byte[]{6})))); + + assertEquals(cache.evicts(), 2); + assertEquals(received.size(), 2); + } + + @Test + public void shouldPutAll() { + final ThreadCache cache = new ThreadCache(logContext, 100000, new MockStreamsMetrics(new Metrics())); + + cache.putAll(namespace, Arrays.asList(KeyValue.pair(Bytes.wrap(new byte[]{0}), dirtyEntry(new byte[]{5})), + KeyValue.pair(Bytes.wrap(new byte[]{1}), dirtyEntry(new byte[]{6})))); + + assertArrayEquals(new byte[]{5}, cache.get(namespace, Bytes.wrap(new byte[]{0})).value()); + assertArrayEquals(new byte[]{6}, cache.get(namespace, Bytes.wrap(new byte[]{1})).value()); + } + + @Test + public void shouldNotForwardCleanEntryOnEviction() { + final ThreadCache cache = new ThreadCache(logContext, 0, new MockStreamsMetrics(new Metrics())); + final List received = new ArrayList<>(); + cache.addDirtyEntryFlushListener(namespace, received::addAll); + cache.put(namespace, Bytes.wrap(new byte[]{1}), cleanEntry(new byte[]{0})); + assertEquals(0, received.size()); + } + @Test + public void shouldPutIfAbsent() { + final ThreadCache cache = new ThreadCache(logContext, 100000, new MockStreamsMetrics(new Metrics())); + final Bytes key = Bytes.wrap(new byte[]{10}); + final byte[] value = {30}; + assertNull(cache.putIfAbsent(namespace, key, dirtyEntry(value))); + assertArrayEquals(value, cache.putIfAbsent(namespace, key, dirtyEntry(new byte[]{8})).value()); + assertArrayEquals(value, cache.get(namespace, key).value()); + } + + @Test + public void shouldEvictAfterPutIfAbsent() { + final List received = new ArrayList<>(); + final ThreadCache cache = new ThreadCache(logContext, 1, new 
MockStreamsMetrics(new Metrics())); + cache.addDirtyEntryFlushListener(namespace, received::addAll); + + cache.putIfAbsent(namespace, Bytes.wrap(new byte[]{0}), dirtyEntry(new byte[]{5})); + cache.putIfAbsent(namespace, Bytes.wrap(new byte[]{1}), dirtyEntry(new byte[]{6})); + cache.putIfAbsent(namespace, Bytes.wrap(new byte[]{1}), dirtyEntry(new byte[]{6})); + + assertEquals(cache.evicts(), 3); + assertEquals(received.size(), 3); + } + + @Test + public void shouldNotLoopForEverWhenEvictingAndCurrentCacheIsEmpty() { + final int maxCacheSizeInBytes = 100; + final ThreadCache threadCache = new ThreadCache(logContext, maxCacheSizeInBytes, new MockStreamsMetrics(new Metrics())); + // trigger a put into another cache on eviction from "name" + threadCache.addDirtyEntryFlushListener(namespace, dirty -> { + // put an item into an empty cache when the total cache size + // is already > than maxCacheSizeBytes + threadCache.put(namespace1, Bytes.wrap(new byte[]{0}), dirtyEntry(new byte[2])); + }); + threadCache.addDirtyEntryFlushListener(namespace1, dirty -> { }); + threadCache.addDirtyEntryFlushListener(namespace2, dirty -> { }); + + threadCache.put(namespace2, Bytes.wrap(new byte[]{1}), dirtyEntry(new byte[1])); + threadCache.put(namespace, Bytes.wrap(new byte[]{1}), dirtyEntry(new byte[1])); + // Put a large item such that when the eldest item is removed + // cache sizeInBytes() > maxCacheSizeBytes + final int remaining = (int) (maxCacheSizeInBytes - threadCache.sizeBytes()); + threadCache.put(namespace, Bytes.wrap(new byte[]{2}), dirtyEntry(new byte[remaining + 100])); + } + + @Test + public void shouldCleanupNamedCacheOnClose() { + final ThreadCache cache = new ThreadCache(logContext, 100000, new MockStreamsMetrics(new Metrics())); + cache.put(namespace1, Bytes.wrap(new byte[]{1}), cleanEntry(new byte[] {1})); + cache.put(namespace2, Bytes.wrap(new byte[]{1}), cleanEntry(new byte[] {1})); + assertEquals(cache.size(), 2); + cache.close(namespace2); + assertEquals(cache.size(), 1); + assertNull(cache.get(namespace2, Bytes.wrap(new byte[]{1}))); + } + + @Test + public void shouldReturnNullIfKeyIsNull() { + final ThreadCache threadCache = new ThreadCache(logContext, 10, new MockStreamsMetrics(new Metrics())); + threadCache.put(namespace, Bytes.wrap(new byte[]{1}), cleanEntry(new byte[] {1})); + assertNull(threadCache.get(namespace, null)); + } + + @Test + public void shouldCalculateSizeInBytes() { + final ThreadCache cache = new ThreadCache(logContext, 100000, new MockStreamsMetrics(new Metrics())); + final NamedCache.LRUNode node = new NamedCache.LRUNode(Bytes.wrap(new byte[]{1}), dirtyEntry(new byte[]{0})); + cache.put(namespace1, Bytes.wrap(new byte[]{1}), cleanEntry(new byte[]{0})); + assertEquals(cache.sizeBytes(), node.size()); + } + + @Test + public void shouldResizeAndShrink() { + final ThreadCache cache = new ThreadCache(logContext, 10000, new MockStreamsMetrics(new Metrics())); + cache.put(namespace, Bytes.wrap(new byte[]{1}), cleanEntry(new byte[]{0})); + cache.put(namespace, Bytes.wrap(new byte[]{2}), cleanEntry(new byte[]{0})); + cache.put(namespace, Bytes.wrap(new byte[]{3}), cleanEntry(new byte[]{0})); + assertEquals(141, cache.sizeBytes()); + cache.resize(100); + assertEquals(94, cache.sizeBytes()); + cache.put(namespace1, Bytes.wrap(new byte[]{4}), cleanEntry(new byte[]{0})); + assertEquals(94, cache.sizeBytes()); + } + + private LRUCacheEntry dirtyEntry(final byte[] key) { + return new LRUCacheEntry(key, new RecordHeaders(), true, -1, -1, -1, ""); + } + + private LRUCacheEntry 
cleanEntry(final byte[] key) { + return new LRUCacheEntry(key); + } + + +} diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/TimeOrderedKeyValueBufferTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/TimeOrderedKeyValueBufferTest.java new file mode 100644 index 0000000..4c74df2 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/TimeOrderedKeyValueBufferTest.java @@ -0,0 +1,1054 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.clients.producer.ProducerRecord; +import org.apache.kafka.common.header.Header; +import org.apache.kafka.common.header.internals.RecordHeader; +import org.apache.kafka.common.header.internals.RecordHeaders; +import org.apache.kafka.common.record.TimestampType; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.serialization.StringDeserializer; +import org.apache.kafka.common.serialization.StringSerializer; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.kstream.internals.Change; +import org.apache.kafka.streams.processor.StateStoreContext; +import org.apache.kafka.streams.processor.TaskId; +import org.apache.kafka.streams.processor.api.Record; +import org.apache.kafka.streams.processor.internals.ProcessorRecordContext; +import org.apache.kafka.streams.processor.internals.RecordBatchingStateRestoreCallback; +import org.apache.kafka.streams.state.ValueAndTimestamp; +import org.apache.kafka.streams.state.internals.TimeOrderedKeyValueBuffer.Eviction; +import org.apache.kafka.test.MockInternalProcessorContext; +import org.apache.kafka.test.MockRecordCollector; +import org.apache.kafka.test.TestUtils; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.Collection; +import java.util.LinkedList; +import java.util.List; +import java.util.Optional; +import java.util.Properties; +import java.util.Random; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Function; +import java.util.stream.Collectors; + +import static java.nio.charset.StandardCharsets.UTF_8; +import static java.util.Arrays.asList; +import static java.util.Collections.singletonList; +import static org.apache.kafka.streams.state.internals.InMemoryTimeOrderedKeyValueBuffer.CHANGELOG_HEADERS; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.is; +import static org.junit.Assert.fail; + 
+@RunWith(Parameterized.class) +public class TimeOrderedKeyValueBufferTest> { + + private static final String APP_ID = "test-app"; + private final Function bufferSupplier; + private final String testName; + + public static final class NullRejectingStringSerializer extends StringSerializer { + @Override + public byte[] serialize(final String topic, final String data) { + if (data == null) { + throw new IllegalArgumentException("null data not allowed"); + } + return super.serialize(topic, data); + } + } + + // As we add more buffer implementations/configurations, we can add them here + @Parameterized.Parameters(name = "{index}: test={0}") + public static Collection parameters() { + return singletonList( + new Object[] { + "in-memory buffer", + (Function>) name -> + new InMemoryTimeOrderedKeyValueBuffer + .Builder<>(name, Serdes.String(), Serdes.serdeFrom(new NullRejectingStringSerializer(), new StringDeserializer())) + .build() + } + ); + } + + public TimeOrderedKeyValueBufferTest(final String testName, final Function bufferSupplier) { + this.testName = testName + "_" + new Random().nextInt(Integer.MAX_VALUE); + this.bufferSupplier = bufferSupplier; + } + + private static MockInternalProcessorContext makeContext() { + final Properties properties = new Properties(); + properties.setProperty(StreamsConfig.APPLICATION_ID_CONFIG, APP_ID); + properties.setProperty(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, ""); + + final TaskId taskId = new TaskId(0, 0); + + final MockInternalProcessorContext context = new MockInternalProcessorContext(properties, taskId, TestUtils.tempDirectory()); + context.setRecordCollector(new MockRecordCollector()); + + return context; + } + + + private static void cleanup(final MockInternalProcessorContext context, final TimeOrderedKeyValueBuffer buffer) { + try { + buffer.close(); + Utils.delete(context.stateDir()); + } catch (final IOException e) { + throw new RuntimeException(e); + } + } + + @Test + public void shouldInit() { + final TimeOrderedKeyValueBuffer buffer = bufferSupplier.apply(testName); + final MockInternalProcessorContext context = makeContext(); + buffer.init((StateStoreContext) context, buffer); + cleanup(context, buffer); + } + + @Test + public void shouldAcceptData() { + final TimeOrderedKeyValueBuffer buffer = bufferSupplier.apply(testName); + final MockInternalProcessorContext context = makeContext(); + buffer.init((StateStoreContext) context, buffer); + putRecord(buffer, context, 0L, 0L, "asdf", "2p93nf"); + cleanup(context, buffer); + } + + @Test + public void shouldRejectNullValues() { + final TimeOrderedKeyValueBuffer buffer = bufferSupplier.apply(testName); + final MockInternalProcessorContext context = makeContext(); + buffer.init((StateStoreContext) context, buffer); + try { + buffer.put(0, new Record<>("asdf", null, 0L), getContext(0)); + fail("expected an exception"); + } catch (final NullPointerException expected) { + // expected + } + cleanup(context, buffer); + } + + @Test + public void shouldRemoveData() { + final TimeOrderedKeyValueBuffer buffer = bufferSupplier.apply(testName); + final MockInternalProcessorContext context = makeContext(); + buffer.init((StateStoreContext) context, buffer); + putRecord(buffer, context, 0L, 0L, "asdf", "qwer"); + assertThat(buffer.numRecords(), is(1)); + buffer.evictWhile(() -> true, kv -> { }); + assertThat(buffer.numRecords(), is(0)); + cleanup(context, buffer); + } + + @Test + public void shouldRespectEvictionPredicate() { + final TimeOrderedKeyValueBuffer buffer = bufferSupplier.apply(testName); + final 
MockInternalProcessorContext context = makeContext(); + buffer.init((StateStoreContext) context, buffer); + putRecord(buffer, context, 0L, 0L, "asdf", "eyt"); + putRecord(buffer, context, 1L, 0L, "zxcv", "rtg"); + assertThat(buffer.numRecords(), is(2)); + final List> evicted = new LinkedList<>(); + buffer.evictWhile(() -> buffer.numRecords() > 1, evicted::add); + assertThat(buffer.numRecords(), is(1)); + assertThat(evicted, is(singletonList( + new Eviction<>("asdf", new Change<>("eyt", null), getContext(0L)) + ))); + cleanup(context, buffer); + } + + @Test + public void shouldTrackCount() { + final TimeOrderedKeyValueBuffer buffer = bufferSupplier.apply(testName); + final MockInternalProcessorContext context = makeContext(); + buffer.init((StateStoreContext) context, buffer); + putRecord(buffer, context, 0L, 0L, "asdf", "oin"); + assertThat(buffer.numRecords(), is(1)); + putRecord(buffer, context, 1L, 0L, "asdf", "wekjn"); + assertThat(buffer.numRecords(), is(1)); + putRecord(buffer, context, 0L, 0L, "zxcv", "24inf"); + assertThat(buffer.numRecords(), is(2)); + cleanup(context, buffer); + } + + @Test + public void shouldTrackSize() { + final TimeOrderedKeyValueBuffer buffer = bufferSupplier.apply(testName); + final MockInternalProcessorContext context = makeContext(); + buffer.init((StateStoreContext) context, buffer); + putRecord(buffer, context, 0L, 0L, "asdf", "23roni"); + assertThat(buffer.bufferSize(), is(43L)); + putRecord(buffer, context, 1L, 0L, "asdf", "3l"); + assertThat(buffer.bufferSize(), is(39L)); + putRecord(buffer, context, 0L, 0L, "zxcv", "qfowin"); + assertThat(buffer.bufferSize(), is(82L)); + cleanup(context, buffer); + } + + @Test + public void shouldTrackMinTimestamp() { + final TimeOrderedKeyValueBuffer buffer = bufferSupplier.apply(testName); + final MockInternalProcessorContext context = makeContext(); + buffer.init((StateStoreContext) context, buffer); + putRecord(buffer, context, 1L, 0L, "asdf", "2093j"); + assertThat(buffer.minTimestamp(), is(1L)); + putRecord(buffer, context, 0L, 0L, "zxcv", "3gon4i"); + assertThat(buffer.minTimestamp(), is(0L)); + cleanup(context, buffer); + } + + @Test + public void shouldEvictOldestAndUpdateSizeAndCountAndMinTimestamp() { + final TimeOrderedKeyValueBuffer buffer = bufferSupplier.apply(testName); + final MockInternalProcessorContext context = makeContext(); + buffer.init((StateStoreContext) context, buffer); + + putRecord(buffer, context, 1L, 0L, "zxcv", "o23i4"); + assertThat(buffer.numRecords(), is(1)); + assertThat(buffer.bufferSize(), is(42L)); + assertThat(buffer.minTimestamp(), is(1L)); + + putRecord(buffer, context, 0L, 0L, "asdf", "3ng"); + assertThat(buffer.numRecords(), is(2)); + assertThat(buffer.bufferSize(), is(82L)); + assertThat(buffer.minTimestamp(), is(0L)); + + final AtomicInteger callbackCount = new AtomicInteger(0); + buffer.evictWhile(() -> true, kv -> { + switch (callbackCount.incrementAndGet()) { + case 1: { + assertThat(kv.key(), is("asdf")); + assertThat(buffer.numRecords(), is(2)); + assertThat(buffer.bufferSize(), is(82L)); + assertThat(buffer.minTimestamp(), is(0L)); + break; + } + case 2: { + assertThat(kv.key(), is("zxcv")); + assertThat(buffer.numRecords(), is(1)); + assertThat(buffer.bufferSize(), is(42L)); + assertThat(buffer.minTimestamp(), is(1L)); + break; + } + default: { + fail("too many invocations"); + break; + } + } + }); + assertThat(callbackCount.get(), is(2)); + assertThat(buffer.numRecords(), is(0)); + assertThat(buffer.bufferSize(), is(0L)); + assertThat(buffer.minTimestamp(), 
is(Long.MAX_VALUE)); + cleanup(context, buffer); + } + + @Test + public void shouldReturnUndefinedOnPriorValueForNotBufferedKey() { + final TimeOrderedKeyValueBuffer buffer = bufferSupplier.apply(testName); + final MockInternalProcessorContext context = makeContext(); + buffer.init((StateStoreContext) context, buffer); + + assertThat(buffer.priorValueForBuffered("ASDF"), is(Maybe.undefined())); + } + + @Test + public void shouldReturnPriorValueForBufferedKey() { + final TimeOrderedKeyValueBuffer buffer = bufferSupplier.apply(testName); + final MockInternalProcessorContext context = makeContext(); + buffer.init((StateStoreContext) context, buffer); + + final ProcessorRecordContext recordContext = getContext(0L); + context.setRecordContext(recordContext); + buffer.put(1L, new Record<>("A", new Change<>("new-value", "old-value"), 0L), recordContext); + buffer.put(1L, new Record<>("B", new Change<>("new-value", null), 0L), recordContext); + assertThat(buffer.priorValueForBuffered("A"), is(Maybe.defined(ValueAndTimestamp.make("old-value", -1)))); + assertThat(buffer.priorValueForBuffered("B"), is(Maybe.defined(null))); + } + + @Test + public void shouldFlush() { + final TimeOrderedKeyValueBuffer buffer = bufferSupplier.apply(testName); + final MockInternalProcessorContext context = makeContext(); + buffer.init((StateStoreContext) context, buffer); + putRecord(buffer, context, 2L, 0L, "asdf", "2093j"); + putRecord(buffer, context, 1L, 1L, "zxcv", "3gon4i"); + putRecord(buffer, context, 0L, 2L, "deleteme", "deadbeef"); + + // replace "deleteme" with a tombstone + buffer.evictWhile(() -> buffer.minTimestamp() < 1, kv -> { }); + + // flush everything to the changelog + buffer.flush(); + + // the buffer should serialize the buffer time and the value as byte[], + // which we can't compare for equality using ProducerRecord. + // As a workaround, I'm deserializing them and shoving them in a KeyValue, just for ease of testing. 
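+        // Concretely, each collected changelog value is laid out as the serialized BufferValue
+        // followed by the 8-byte (long) buffer time, so the mapping below deserializes the
+        // BufferValue first and then reads the remaining long from the same ByteBuffer.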
+ + final List>> collected = + ((MockRecordCollector) context.recordCollector()) + .collected() + .stream() + .map(pr -> { + final KeyValue niceValue; + if (pr.value() == null) { + niceValue = null; + } else { + final byte[] serializedValue = (byte[]) pr.value(); + final ByteBuffer valueBuffer = ByteBuffer.wrap(serializedValue); + final BufferValue contextualRecord = BufferValue.deserialize(valueBuffer); + final long timestamp = valueBuffer.getLong(); + niceValue = new KeyValue<>(timestamp, contextualRecord); + } + + return new ProducerRecord<>(pr.topic(), + pr.partition(), + pr.timestamp(), + pr.key().toString(), + niceValue, + pr.headers()); + }) + .collect(Collectors.toList()); + + assertThat(collected, is(asList( + new ProducerRecord<>(APP_ID + "-" + testName + "-changelog", + 0, // Producer will assign + null, + "deleteme", + null, + new RecordHeaders() + ), + new ProducerRecord<>(APP_ID + "-" + testName + "-changelog", + 0, + null, + "zxcv", + new KeyValue<>(1L, getBufferValue("3gon4i", 1)), + CHANGELOG_HEADERS + ), + new ProducerRecord<>(APP_ID + "-" + testName + "-changelog", + 0, + null, + "asdf", + new KeyValue<>(2L, getBufferValue("2093j", 0)), + CHANGELOG_HEADERS + ) + ))); + + cleanup(context, buffer); + } + + @Test + public void shouldRestoreOldUnversionedFormat() { + final TimeOrderedKeyValueBuffer buffer = bufferSupplier.apply(testName); + final MockInternalProcessorContext context = makeContext(); + buffer.init((StateStoreContext) context, buffer); + + final RecordBatchingStateRestoreCallback stateRestoreCallback = + (RecordBatchingStateRestoreCallback) context.stateRestoreCallback(testName); + + context.setRecordContext(new ProcessorRecordContext(0, 0, 0, "", new RecordHeaders())); + + // These serialized formats were captured by running version 2.1 code. + // They verify that an upgrade from 2.1 will work. + // Do not change them. 
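+        // The fixtures are plain hex, two characters per byte (e.g. "646F6F6D6564" is ASCII for
+        // "doomed"), and are decoded with the hexStringToByteArray helper. A minimal sketch of
+        // such a decoder, assuming no delimiters or "0x" prefixes, would be:
+        //   final byte[] out = new byte[hex.length() / 2];
+        //   for (int i = 0; i < out.length; i++) {
+        //       out[i] = (byte) Integer.parseInt(hex.substring(2 * i, 2 * i + 2), 16);
+        //   }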
+ final String toDeleteBinaryValue = "0000000000000000FFFFFFFF00000006646F6F6D6564"; + final String asdfBinaryValue = "0000000000000002FFFFFFFF0000000471776572"; + final String zxcvBinaryValue1 = "00000000000000010000000870726576696F757300000005656F34696D"; + final String zxcvBinaryValue2 = "000000000000000100000005656F34696D000000046E657874"; + + stateRestoreCallback.restoreBatch(asList( + new ConsumerRecord<>("changelog-topic", + 0, + 0, + 0, + TimestampType.CREATE_TIME, + -1, + -1, + "todelete".getBytes(UTF_8), + hexStringToByteArray(toDeleteBinaryValue), + new RecordHeaders(), + Optional.empty()), + new ConsumerRecord<>("changelog-topic", + 0, + 1, + 1, + TimestampType.CREATE_TIME, + -1, + -1, + "asdf".getBytes(UTF_8), + hexStringToByteArray(asdfBinaryValue), + new RecordHeaders(), + Optional.empty()), + new ConsumerRecord<>("changelog-topic", + 0, + 2, + 2, + TimestampType.CREATE_TIME, + -1, + -1, + "zxcv".getBytes(UTF_8), + hexStringToByteArray(zxcvBinaryValue1), + new RecordHeaders(), + Optional.empty()), + new ConsumerRecord<>("changelog-topic", + 0, + 3, + 3, + TimestampType.CREATE_TIME, + -1, + -1, + "zxcv".getBytes(UTF_8), + hexStringToByteArray(zxcvBinaryValue2), + new RecordHeaders(), + Optional.empty()) + )); + + assertThat(buffer.numRecords(), is(3)); + assertThat(buffer.minTimestamp(), is(0L)); + assertThat(buffer.bufferSize(), is(172L)); + + stateRestoreCallback.restoreBatch(singletonList( + new ConsumerRecord<>("changelog-topic", + 0, + 3, + 3, + TimestampType.CREATE_TIME, + -1, + -1, + "todelete".getBytes(UTF_8), + null, + new RecordHeaders(), + Optional.empty()) + )); + + assertThat(buffer.numRecords(), is(2)); + assertThat(buffer.minTimestamp(), is(1L)); + assertThat(buffer.bufferSize(), is(115L)); + + assertThat(buffer.priorValueForBuffered("todelete"), is(Maybe.undefined())); + assertThat(buffer.priorValueForBuffered("asdf"), is(Maybe.defined(null))); + assertThat(buffer.priorValueForBuffered("zxcv"), is(Maybe.defined(ValueAndTimestamp.make("previous", -1)))); + + // flush the buffer into a list in buffer order so we can make assertions about the contents. + + final List> evicted = new LinkedList<>(); + buffer.evictWhile(() -> true, evicted::add); + + // Several things to note: + // * The buffered records are ordered according to their buffer time (serialized in the value of the changelog) + // * The record timestamps are properly restored, and not conflated with the record's buffer time. + // * The keys and values are properly restored + // * The record topic is set to the changelog topic. This was an oversight in the original implementation, + // which is fixed in changelog format v1. But upgraded applications still need to be able to handle the + // original format. 
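+        // Hence the expected eviction contexts below carry "changelog-topic" and the changelog
+        // record's offset/timestamp, rather than the original input topic: this is the pre-v1
+        // behavior the assertions pin down.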
+ + assertThat(evicted, is(asList( + new Eviction<>( + "zxcv", + new Change<>("next", "eo4im"), + new ProcessorRecordContext(3L, 3, 0, "changelog-topic", new RecordHeaders())), + new Eviction<>( + "asdf", + new Change<>("qwer", null), + new ProcessorRecordContext(1L, 1, 0, "changelog-topic", new RecordHeaders())) + ))); + + cleanup(context, buffer); + } + + @Test + public void shouldRestoreV1Format() { + final TimeOrderedKeyValueBuffer buffer = bufferSupplier.apply(testName); + final MockInternalProcessorContext context = makeContext(); + buffer.init((StateStoreContext) context, buffer); + + final RecordBatchingStateRestoreCallback stateRestoreCallback = + (RecordBatchingStateRestoreCallback) context.stateRestoreCallback(testName); + + context.setRecordContext(new ProcessorRecordContext(0, 0, 0, "", new RecordHeaders())); + + final RecordHeaders v1FlagHeaders = new RecordHeaders(new Header[] {new RecordHeader("v", new byte[] {(byte) 1})}); + + // These serialized formats were captured by running version 2.2 code. + // They verify that an upgrade from 2.2 will work. + // Do not change them. + final String toDeleteBinary = "00000000000000000000000000000000000000000000000000000005746F70696300000000FFFFFFFF0000000EFFFFFFFF00000006646F6F6D6564"; + final String asdfBinary = "00000000000000020000000000000001000000000000000000000005746F70696300000000FFFFFFFF0000000CFFFFFFFF0000000471776572"; + final String zxcvBinary1 = "00000000000000010000000000000002000000000000000000000005746F70696300000000FFFFFFFF000000150000000870726576696F757300000005336F34696D"; + final String zxcvBinary2 = "00000000000000010000000000000003000000000000000000000005746F70696300000000FFFFFFFF0000001100000005336F34696D000000046E657874"; + + stateRestoreCallback.restoreBatch(asList( + new ConsumerRecord<>("changelog-topic", + 0, + 0, + 999, + TimestampType.CREATE_TIME, + -1, + -1, + "todelete".getBytes(UTF_8), + hexStringToByteArray(toDeleteBinary), + v1FlagHeaders, + Optional.empty()), + new ConsumerRecord<>("changelog-topic", + 0, + 1, + 9999, + TimestampType.CREATE_TIME, + -1, + -1, + "asdf".getBytes(UTF_8), + hexStringToByteArray(asdfBinary), + v1FlagHeaders, + Optional.empty()), + new ConsumerRecord<>("changelog-topic", + 0, + 2, + 99, + TimestampType.CREATE_TIME, + -1, + -1, + "zxcv".getBytes(UTF_8), + hexStringToByteArray(zxcvBinary1), + v1FlagHeaders, + Optional.empty()), + new ConsumerRecord<>("changelog-topic", + 0, + 3, + 100, + TimestampType.CREATE_TIME, + -1, + -1, + "zxcv".getBytes(UTF_8), + hexStringToByteArray(zxcvBinary2), + v1FlagHeaders, + Optional.empty()) + )); + + assertThat(buffer.numRecords(), is(3)); + assertThat(buffer.minTimestamp(), is(0L)); + assertThat(buffer.bufferSize(), is(142L)); + + stateRestoreCallback.restoreBatch(singletonList( + new ConsumerRecord<>("changelog-topic", + 0, + 3, + 3, + TimestampType.CREATE_TIME, + -1, + -1, + "todelete".getBytes(UTF_8), + null, + new RecordHeaders(), + Optional.empty()) + )); + + assertThat(buffer.numRecords(), is(2)); + assertThat(buffer.minTimestamp(), is(1L)); + assertThat(buffer.bufferSize(), is(95L)); + + assertThat(buffer.priorValueForBuffered("todelete"), is(Maybe.undefined())); + assertThat(buffer.priorValueForBuffered("asdf"), is(Maybe.defined(null))); + assertThat(buffer.priorValueForBuffered("zxcv"), is(Maybe.defined(ValueAndTimestamp.make("previous", -1)))); + + // flush the buffer into a list in buffer order so we can make assertions about the contents. 
+ + final List> evicted = new LinkedList<>(); + buffer.evictWhile(() -> true, evicted::add); + + // Several things to note: + // * The buffered records are ordered according to their buffer time (serialized in the value of the changelog) + // * The record timestamps are properly restored, and not conflated with the record's buffer time. + // * The keys and values are properly restored + // * The record topic is set to the original input topic, *not* the changelog topic + // * The record offset preserves the original input record's offset, *not* the offset of the changelog record + + + assertThat(evicted, is(asList( + new Eviction<>( + "zxcv", + new Change<>("next", "3o4im"), + getContext(3L)), + new Eviction<>( + "asdf", + new Change<>("qwer", null), + getContext(1L) + )))); + + cleanup(context, buffer); + } + + + @Test + public void shouldRestoreV2Format() { + final TimeOrderedKeyValueBuffer buffer = bufferSupplier.apply(testName); + final MockInternalProcessorContext context = makeContext(); + buffer.init((StateStoreContext) context, buffer); + + final RecordBatchingStateRestoreCallback stateRestoreCallback = + (RecordBatchingStateRestoreCallback) context.stateRestoreCallback(testName); + + context.setRecordContext(new ProcessorRecordContext(0, 0, 0, "", new RecordHeaders())); + + final RecordHeaders v2FlagHeaders = new RecordHeaders(new Header[] {new RecordHeader("v", new byte[] {(byte) 2})}); + + // These serialized formats were captured by running version 2.3 code. + // They verify that an upgrade from 2.3 will work. + // Do not change them. + final String toDeleteBinary = "0000000000000000000000000000000000000005746F70696300000000FFFFFFFF0000000EFFFFFFFF00000006646F6F6D6564FFFFFFFF0000000000000000"; + final String asdfBinary = "0000000000000001000000000000000000000005746F70696300000000FFFFFFFF0000000CFFFFFFFF0000000471776572FFFFFFFF0000000000000002"; + final String zxcvBinary1 = "0000000000000002000000000000000000000005746F70696300000000FFFFFFFF000000140000000749474E4F52454400000005336F34696D0000000870726576696F75730000000000000001"; + final String zxcvBinary2 = "0000000000000003000000000000000000000005746F70696300000000FFFFFFFF0000001100000005336F34696D000000046E6578740000000870726576696F75730000000000000001"; + + stateRestoreCallback.restoreBatch(asList( + new ConsumerRecord<>("changelog-topic", + 0, + 0, + 999, + TimestampType.CREATE_TIME, + -1, + -1, + "todelete".getBytes(UTF_8), + hexStringToByteArray(toDeleteBinary), + v2FlagHeaders, + Optional.empty()), + new ConsumerRecord<>("changelog-topic", + 0, + 1, + 9999, + TimestampType.CREATE_TIME, + -1, + -1, + "asdf".getBytes(UTF_8), + hexStringToByteArray(asdfBinary), + v2FlagHeaders, + Optional.empty()), + new ConsumerRecord<>("changelog-topic", + 0, + 2, + 99, + TimestampType.CREATE_TIME, + -1, + -1, + "zxcv".getBytes(UTF_8), + hexStringToByteArray(zxcvBinary1), + v2FlagHeaders, + Optional.empty()), + new ConsumerRecord<>("changelog-topic", + 0, + 2, + 100, + TimestampType.CREATE_TIME, + -1, + -1, + "zxcv".getBytes(UTF_8), + hexStringToByteArray(zxcvBinary2), + v2FlagHeaders, + Optional.empty()) + )); + + assertThat(buffer.numRecords(), is(3)); + assertThat(buffer.minTimestamp(), is(0L)); + assertThat(buffer.bufferSize(), is(142L)); + + stateRestoreCallback.restoreBatch(singletonList( + new ConsumerRecord<>("changelog-topic", + 0, + 3, + 3, + TimestampType.CREATE_TIME, + -1, + -1, + "todelete".getBytes(UTF_8), + null, + new RecordHeaders(), + Optional.empty()) + )); + + assertThat(buffer.numRecords(), is(2)); + 
assertThat(buffer.minTimestamp(), is(1L));
+        assertThat(buffer.bufferSize(), is(95L));
+
+        assertThat(buffer.priorValueForBuffered("todelete"), is(Maybe.undefined()));
+        assertThat(buffer.priorValueForBuffered("asdf"), is(Maybe.defined(null)));
+        assertThat(buffer.priorValueForBuffered("zxcv"), is(Maybe.defined(ValueAndTimestamp.make("previous", -1))));
+
+        // flush the buffer into a list in buffer order so we can make assertions about the contents.
+
+        final List<Eviction<String, String>> evicted = new LinkedList<>();
+        buffer.evictWhile(() -> true, evicted::add);
+
+        // Several things to note:
+        // * The buffered records are ordered according to their buffer time (serialized in the value of the changelog)
+        // * The record timestamps are properly restored, and not conflated with the record's buffer time.
+        // * The keys and values are properly restored
+        // * The record topic is set to the original input topic, *not* the changelog topic
+        // * The record offset preserves the original input record's offset, *not* the offset of the changelog record
+
+
+        assertThat(evicted, is(asList(
+            new Eviction<>(
+                "zxcv",
+                new Change<>("next", "3o4im"),
+                getContext(3L)),
+            new Eviction<>(
+                "asdf",
+                new Change<>("qwer", null),
+                getContext(1L)
+            ))));
+
+        cleanup(context, buffer);
+    }
+
+    @Test
+    public void shouldRestoreV3FormatWithV2Header() {
+        // versions 2.4.0, 2.4.1, and 2.5.0 would have erroneously encoded a V3 record with the
+        // V2 header, so we need to be sure to handle this case as well.
+        // Note the data is the same as the V3 test.
+        final TimeOrderedKeyValueBuffer<String, String> buffer = bufferSupplier.apply(testName);
+        final MockInternalProcessorContext context = makeContext();
+        buffer.init((StateStoreContext) context, buffer);
+
+        final RecordBatchingStateRestoreCallback stateRestoreCallback =
+            (RecordBatchingStateRestoreCallback) context.stateRestoreCallback(testName);
+
+        context.setRecordContext(new ProcessorRecordContext(0, 0, 0, "", new RecordHeaders()));
+
+        final RecordHeaders headers = new RecordHeaders(new Header[] {new RecordHeader("v", new byte[] {(byte) 2})});
+
+        // These serialized formats were captured by running version 2.4 code.
+        // They verify that an upgrade from 2.4 will work.
+        // Do not change them.
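+        // (The hex fixtures below can be regenerated with the commented-out bytesToHex helper at the
+        // bottom of this class; hexStringToByteArray converts them back to bytes for restoreBatch.)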
+ final String toDeleteBinary = "0000000000000000000000000000000000000005746F70696300000000FFFFFFFFFFFFFFFFFFFFFFFF00000006646F6F6D65640000000000000000"; + final String asdfBinary = "0000000000000001000000000000000000000005746F70696300000000FFFFFFFFFFFFFFFFFFFFFFFF00000004717765720000000000000002"; + final String zxcvBinary1 = "0000000000000002000000000000000000000005746F70696300000000FFFFFFFF0000000870726576696F75730000000749474E4F52454400000005336F34696D0000000000000001"; + final String zxcvBinary2 = "0000000000000003000000000000000000000005746F70696300000000FFFFFFFF0000000870726576696F757300000005336F34696D000000046E6578740000000000000001"; + + stateRestoreCallback.restoreBatch(asList( + new ConsumerRecord<>("changelog-topic", + 0, + 0, + 999, + TimestampType.CREATE_TIME, + -1, + -1, + "todelete".getBytes(UTF_8), + hexStringToByteArray(toDeleteBinary), + headers, + Optional.empty()), + new ConsumerRecord<>("changelog-topic", + 0, + 1, + 9999, + TimestampType.CREATE_TIME, + -1, + -1, + "asdf".getBytes(UTF_8), + hexStringToByteArray(asdfBinary), + headers, + Optional.empty()), + new ConsumerRecord<>("changelog-topic", + 0, + 2, + 99, + TimestampType.CREATE_TIME, + -1, + -1, + "zxcv".getBytes(UTF_8), + hexStringToByteArray(zxcvBinary1), + headers, + Optional.empty()), + new ConsumerRecord<>("changelog-topic", + 0, + 2, + 100, + TimestampType.CREATE_TIME, + -1, + -1, + "zxcv".getBytes(UTF_8), + hexStringToByteArray(zxcvBinary2), + headers, + Optional.empty()) + )); + + assertThat(buffer.numRecords(), is(3)); + assertThat(buffer.minTimestamp(), is(0L)); + assertThat(buffer.bufferSize(), is(142L)); + + stateRestoreCallback.restoreBatch(singletonList( + new ConsumerRecord<>("changelog-topic", + 0, + 3, + 3, + TimestampType.CREATE_TIME, + -1, + -1, + "todelete".getBytes(UTF_8), + null, + new RecordHeaders(), + Optional.empty()) + )); + + assertThat(buffer.numRecords(), is(2)); + assertThat(buffer.minTimestamp(), is(1L)); + assertThat(buffer.bufferSize(), is(95L)); + + assertThat(buffer.priorValueForBuffered("todelete"), is(Maybe.undefined())); + assertThat(buffer.priorValueForBuffered("asdf"), is(Maybe.defined(null))); + assertThat(buffer.priorValueForBuffered("zxcv"), is(Maybe.defined(ValueAndTimestamp.make("previous", -1)))); + + // flush the buffer into a list in buffer order so we can make assertions about the contents. + + final List> evicted = new LinkedList<>(); + buffer.evictWhile(() -> true, evicted::add); + + // Several things to note: + // * The buffered records are ordered according to their buffer time (serialized in the value of the changelog) + // * The record timestamps are properly restored, and not conflated with the record's buffer time. 
+ // * The keys and values are properly restored + // * The record topic is set to the original input topic, *not* the changelog topic + // * The record offset preserves the original input record's offset, *not* the offset of the changelog record + + + assertThat(evicted, is(asList( + new Eviction<>( + "zxcv", + new Change<>("next", "3o4im"), + getContext(3L)), + new Eviction<>( + "asdf", + new Change<>("qwer", null), + getContext(1L) + )))); + + cleanup(context, buffer); + } + + @Test + public void shouldRestoreV3Format() { + final TimeOrderedKeyValueBuffer buffer = bufferSupplier.apply(testName); + final MockInternalProcessorContext context = makeContext(); + buffer.init((StateStoreContext) context, buffer); + + final RecordBatchingStateRestoreCallback stateRestoreCallback = + (RecordBatchingStateRestoreCallback) context.stateRestoreCallback(testName); + + context.setRecordContext(new ProcessorRecordContext(0, 0, 0, "", new RecordHeaders())); + + final RecordHeaders headers = new RecordHeaders(new Header[] {new RecordHeader("v", new byte[] {(byte) 3})}); + + // These serialized formats were captured by running version 2.4 code. + // They verify that an upgrade from 2.4 will work. + // Do not change them. + final String toDeleteBinary = "0000000000000000000000000000000000000005746F70696300000000FFFFFFFFFFFFFFFFFFFFFFFF00000006646F6F6D65640000000000000000"; + final String asdfBinary = "0000000000000001000000000000000000000005746F70696300000000FFFFFFFFFFFFFFFFFFFFFFFF00000004717765720000000000000002"; + final String zxcvBinary1 = "0000000000000002000000000000000000000005746F70696300000000FFFFFFFF0000000870726576696F75730000000749474E4F52454400000005336F34696D0000000000000001"; + final String zxcvBinary2 = "0000000000000003000000000000000000000005746F70696300000000FFFFFFFF0000000870726576696F757300000005336F34696D000000046E6578740000000000000001"; + + stateRestoreCallback.restoreBatch(asList( + new ConsumerRecord<>("changelog-topic", + 0, + 0, + 999, + TimestampType.CREATE_TIME, + -1, + -1, + "todelete".getBytes(UTF_8), + hexStringToByteArray(toDeleteBinary), + headers, + Optional.empty()), + new ConsumerRecord<>("changelog-topic", + 0, + 1, + 9999, + TimestampType.CREATE_TIME, + -1, + -1, + "asdf".getBytes(UTF_8), + hexStringToByteArray(asdfBinary), + headers, + Optional.empty()), + new ConsumerRecord<>("changelog-topic", + 0, + 2, + 99, + TimestampType.CREATE_TIME, + -1, + -1, + "zxcv".getBytes(UTF_8), + hexStringToByteArray(zxcvBinary1), + headers, + Optional.empty()), + new ConsumerRecord<>("changelog-topic", + 0, + 2, + 100, + TimestampType.CREATE_TIME, + -1, + -1, + "zxcv".getBytes(UTF_8), + hexStringToByteArray(zxcvBinary2), + headers, + Optional.empty()) + )); + + assertThat(buffer.numRecords(), is(3)); + assertThat(buffer.minTimestamp(), is(0L)); + assertThat(buffer.bufferSize(), is(142L)); + + stateRestoreCallback.restoreBatch(singletonList( + new ConsumerRecord<>("changelog-topic", + 0, + 3, + 3, + TimestampType.CREATE_TIME, + -1, + -1, + "todelete".getBytes(UTF_8), + null, + new RecordHeaders(), + Optional.empty()) + )); + + assertThat(buffer.numRecords(), is(2)); + assertThat(buffer.minTimestamp(), is(1L)); + assertThat(buffer.bufferSize(), is(95L)); + + assertThat(buffer.priorValueForBuffered("todelete"), is(Maybe.undefined())); + assertThat(buffer.priorValueForBuffered("asdf"), is(Maybe.defined(null))); + assertThat(buffer.priorValueForBuffered("zxcv"), is(Maybe.defined(ValueAndTimestamp.make("previous", -1)))); + + // flush the buffer into a list in buffer order so we can 
make assertions about the contents.
+
+        final List<Eviction<String, String>> evicted = new LinkedList<>();
+        buffer.evictWhile(() -> true, evicted::add);
+
+        // Several things to note:
+        // * The buffered records are ordered according to their buffer time (serialized in the value of the changelog)
+        // * The record timestamps are properly restored, and not conflated with the record's buffer time.
+        // * The keys and values are properly restored
+        // * The record topic is set to the original input topic, *not* the changelog topic
+        // * The record offset preserves the original input record's offset, *not* the offset of the changelog record
+
+
+        assertThat(evicted, is(asList(
+            new Eviction<>(
+                "zxcv",
+                new Change<>("next", "3o4im"),
+                getContext(3L)),
+            new Eviction<>(
+                "asdf",
+                new Change<>("qwer", null),
+                getContext(1L)
+            ))));
+
+        cleanup(context, buffer);
+    }
+
+    @Test
+    public void shouldNotRestoreUnrecognizedVersionRecord() {
+        final TimeOrderedKeyValueBuffer<String, String> buffer = bufferSupplier.apply(testName);
+        final MockInternalProcessorContext context = makeContext();
+        buffer.init((StateStoreContext) context, buffer);
+
+        final RecordBatchingStateRestoreCallback stateRestoreCallback =
+            (RecordBatchingStateRestoreCallback) context.stateRestoreCallback(testName);
+
+        context.setRecordContext(new ProcessorRecordContext(0, 0, 0, "", new RecordHeaders()));
+
+        final RecordHeaders unknownFlagHeaders = new RecordHeaders(new Header[] {new RecordHeader("v", new byte[] {(byte) -1})});
+
+        final byte[] todeleteValue = getBufferValue("doomed", 0).serialize(0).array();
+        try {
+            stateRestoreCallback.restoreBatch(singletonList(
+                new ConsumerRecord<>("changelog-topic",
+                                     0,
+                                     0,
+                                     999,
+                                     TimestampType.CREATE_TIME,
+                                     -1,
+                                     -1,
+                                     "todelete".getBytes(UTF_8),
+                                     ByteBuffer.allocate(Long.BYTES + todeleteValue.length).putLong(0L).put(todeleteValue).array(),
+                                     unknownFlagHeaders,
+                                     Optional.empty())
+            ));
+            fail("expected an exception");
+        } catch (final IllegalArgumentException expected) {
+            // nothing to do.
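+            // an unrecognized "v" header flag should abort the restore with IllegalArgumentException
+            // rather than silently misinterpreting the changelog bytes.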
+ } finally { + cleanup(context, buffer); + } + } + + private static void putRecord(final TimeOrderedKeyValueBuffer buffer, + final MockInternalProcessorContext context, + final long streamTime, + final long recordTimestamp, + final String key, + final String value) { + final ProcessorRecordContext recordContext = getContext(recordTimestamp); + context.setRecordContext(recordContext); + buffer.put(streamTime, new Record<>(key, new Change<>(value, null), 0L), recordContext); + } + + private static BufferValue getBufferValue(final String value, final long timestamp) { + return new BufferValue( + null, + null, + Serdes.String().serializer().serialize(null, value), + getContext(timestamp) + ); + } + + private static ProcessorRecordContext getContext(final long recordTimestamp) { + return new ProcessorRecordContext(recordTimestamp, 0, 0, "topic", new RecordHeaders()); + } + + + // to be used to generate future hex-encoded values +// private static final char[] HEX_ARRAY = "0123456789ABCDEF".toCharArray(); +// private static String bytesToHex(final byte[] bytes) { +// final char[] hexChars = new char[bytes.length * 2]; +// for (int j = 0; j < bytes.length; j++) { +// final int v = bytes[j] & 0xFF; +// hexChars[j * 2] = HEX_ARRAY[v >>> 4]; +// hexChars[j * 2 + 1] = HEX_ARRAY[v & 0x0F]; +// } +// return new String(hexChars); +// } + + private static byte[] hexStringToByteArray(final String hexString) { + final int len = hexString.length(); + final byte[] data = new byte[len / 2]; + for (int i = 0; i < len; i += 2) { + data[i / 2] = (byte) ((Character.digit(hexString.charAt(i), 16) << 4) + + Character.digit(hexString.charAt(i + 1), 16)); + } + return data; + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/TimestampedKeyAndJoinSideSerializerTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/TimestampedKeyAndJoinSideSerializerTest.java new file mode 100644 index 0000000..5cca8f6 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/TimestampedKeyAndJoinSideSerializerTest.java @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+package org.apache.kafka.streams.state.internals;
+
+import org.apache.kafka.common.serialization.Serdes;
+import org.junit.Test;
+
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.hamcrest.Matchers.is;
+import static org.hamcrest.Matchers.notNullValue;
+import static org.junit.Assert.assertThrows;
+
+public class TimestampedKeyAndJoinSideSerializerTest {
+    private static final String TOPIC = "some-topic";
+
+    private static final TimestampedKeyAndJoinSideSerde<String> STRING_SERDE =
+        new TimestampedKeyAndJoinSideSerde<>(Serdes.String());
+
+    @Test
+    public void shouldSerializeKeyWithJoinSideAsTrue() {
+        final String value = "some-string";
+
+        final TimestampedKeyAndJoinSide<String> timestampedKeyAndJoinSide = TimestampedKeyAndJoinSide.make(true, value, 10);
+
+        final byte[] serialized =
+            STRING_SERDE.serializer().serialize(TOPIC, timestampedKeyAndJoinSide);
+
+        assertThat(serialized, is(notNullValue()));
+
+        final TimestampedKeyAndJoinSide<String> deserialized =
+            STRING_SERDE.deserializer().deserialize(TOPIC, serialized);
+
+        assertThat(deserialized, is(timestampedKeyAndJoinSide));
+    }
+
+    @Test
+    public void shouldSerializeKeyWithJoinSideAsFalse() {
+        final String value = "some-string";
+
+        final TimestampedKeyAndJoinSide<String> timestampedKeyAndJoinSide = TimestampedKeyAndJoinSide.make(false, value, 20);
+
+        final byte[] serialized =
+            STRING_SERDE.serializer().serialize(TOPIC, timestampedKeyAndJoinSide);
+
+        assertThat(serialized, is(notNullValue()));
+
+        final TimestampedKeyAndJoinSide<String> deserialized =
+            STRING_SERDE.deserializer().deserialize(TOPIC, serialized);
+
+        assertThat(deserialized, is(timestampedKeyAndJoinSide));
+    }
+
+    @Test
+    public void shouldThrowIfSerializeNullData() {
+        assertThrows(NullPointerException.class,
+            () -> STRING_SERDE.serializer().serialize(TOPIC, TimestampedKeyAndJoinSide.make(true, null, 0)));
+    }
+}
diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/TimestampedKeyValueStoreBuilderTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/TimestampedKeyValueStoreBuilderTest.java
new file mode 100644
index 0000000..b79d67e
--- /dev/null
+++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/TimestampedKeyValueStoreBuilderTest.java
@@ -0,0 +1,183 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.streams.processor.StateStore; +import org.apache.kafka.streams.state.KeyValueBytesStoreSupplier; +import org.apache.kafka.streams.state.TimestampedKeyValueStore; +import org.easymock.EasyMockRunner; +import org.easymock.Mock; +import org.easymock.MockType; +import org.hamcrest.CoreMatchers; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; + +import java.util.Collections; + +import static org.easymock.EasyMock.expect; +import static org.easymock.EasyMock.replay; +import static org.easymock.EasyMock.reset; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.core.IsInstanceOf.instanceOf; +import static org.junit.Assert.assertThrows; + +@RunWith(EasyMockRunner.class) +public class TimestampedKeyValueStoreBuilderTest { + + @Mock(type = MockType.NICE) + private KeyValueBytesStoreSupplier supplier; + @Mock(type = MockType.NICE) + private RocksDBTimestampedStore inner; + private TimestampedKeyValueStoreBuilder builder; + + @Before + public void setUp() { + expect(supplier.get()).andReturn(inner); + expect(supplier.name()).andReturn("name"); + expect(supplier.metricsScope()).andReturn("metricScope"); + expect(inner.persistent()).andReturn(true).anyTimes(); + replay(supplier, inner); + + builder = new TimestampedKeyValueStoreBuilder<>( + supplier, + Serdes.String(), + Serdes.String(), + new MockTime() + ); + } + + @Test + public void shouldHaveMeteredStoreAsOuterStore() { + final TimestampedKeyValueStore store = builder.build(); + assertThat(store, instanceOf(MeteredTimestampedKeyValueStore.class)); + } + + @Test + public void shouldHaveChangeLoggingStoreByDefault() { + final TimestampedKeyValueStore store = builder.build(); + assertThat(store, instanceOf(MeteredTimestampedKeyValueStore.class)); + final StateStore next = ((WrappedStateStore) store).wrapped(); + assertThat(next, instanceOf(ChangeLoggingTimestampedKeyValueBytesStore.class)); + } + + @Test + public void shouldNotHaveChangeLoggingStoreWhenDisabled() { + final TimestampedKeyValueStore store = builder.withLoggingDisabled().build(); + final StateStore next = ((WrappedStateStore) store).wrapped(); + assertThat(next, CoreMatchers.equalTo(inner)); + } + + @Test + public void shouldHaveCachingStoreWhenEnabled() { + final TimestampedKeyValueStore store = builder.withCachingEnabled().build(); + final StateStore wrapped = ((WrappedStateStore) store).wrapped(); + assertThat(store, instanceOf(MeteredTimestampedKeyValueStore.class)); + assertThat(wrapped, instanceOf(CachingKeyValueStore.class)); + } + + @Test + public void shouldHaveChangeLoggingStoreWhenLoggingEnabled() { + final TimestampedKeyValueStore store = builder + .withLoggingEnabled(Collections.emptyMap()) + .build(); + final StateStore wrapped = ((WrappedStateStore) store).wrapped(); + assertThat(store, instanceOf(MeteredTimestampedKeyValueStore.class)); + assertThat(wrapped, instanceOf(ChangeLoggingTimestampedKeyValueBytesStore.class)); + assertThat(((WrappedStateStore) wrapped).wrapped(), CoreMatchers.equalTo(inner)); + } + + @Test + public void shouldHaveCachingAndChangeLoggingWhenBothEnabled() { + final TimestampedKeyValueStore store = builder + .withLoggingEnabled(Collections.emptyMap()) + .withCachingEnabled() + .build(); + final WrappedStateStore caching = (WrappedStateStore) 
((WrappedStateStore) store).wrapped(); + final WrappedStateStore changeLogging = (WrappedStateStore) caching.wrapped(); + assertThat(store, instanceOf(MeteredTimestampedKeyValueStore.class)); + assertThat(caching, instanceOf(CachingKeyValueStore.class)); + assertThat(changeLogging, instanceOf(ChangeLoggingTimestampedKeyValueBytesStore.class)); + assertThat(changeLogging.wrapped(), CoreMatchers.equalTo(inner)); + } + + @Test + public void shouldNotWrapTimestampedByteStore() { + reset(supplier); + expect(supplier.get()).andReturn(new RocksDBTimestampedStore("name", "metrics-scope")); + expect(supplier.name()).andReturn("name"); + replay(supplier); + + final TimestampedKeyValueStore store = builder + .withLoggingDisabled() + .withCachingDisabled() + .build(); + assertThat(((WrappedStateStore) store).wrapped(), instanceOf(RocksDBTimestampedStore.class)); + } + + @Test + public void shouldWrapPlainKeyValueStoreAsTimestampStore() { + reset(supplier); + expect(supplier.get()).andReturn(new RocksDBStore("name", "metrics-scope")); + expect(supplier.name()).andReturn("name"); + replay(supplier); + + final TimestampedKeyValueStore store = builder + .withLoggingDisabled() + .withCachingDisabled() + .build(); + assertThat(((WrappedStateStore) store).wrapped(), instanceOf(KeyValueToTimestampedKeyValueByteStoreAdapter.class)); + } + + @SuppressWarnings("all") + @Test + public void shouldThrowNullPointerIfInnerIsNull() { + assertThrows(NullPointerException.class, () -> new TimestampedKeyValueStoreBuilder<>(null, Serdes.String(), Serdes.String(), new MockTime())); + } + + @Test + public void shouldThrowNullPointerIfKeySerdeIsNull() { + assertThrows(NullPointerException.class, () -> new TimestampedKeyValueStoreBuilder<>(supplier, null, Serdes.String(), new MockTime())); + } + + @Test + public void shouldThrowNullPointerIfValueSerdeIsNull() { + assertThrows(NullPointerException.class, () -> new TimestampedKeyValueStoreBuilder<>(supplier, Serdes.String(), null, new MockTime())); + } + + @Test + public void shouldThrowNullPointerIfTimeIsNull() { + assertThrows(NullPointerException.class, () -> new TimestampedKeyValueStoreBuilder<>(supplier, Serdes.String(), Serdes.String(), null)); + } + + @Test + public void shouldThrowNullPointerIfMetricsScopeIsNull() { + reset(supplier); + expect(supplier.get()).andReturn(new RocksDBTimestampedStore("name", null)); + expect(supplier.name()).andReturn("name"); + replay(supplier); + + final Exception e = assertThrows(NullPointerException.class, + () -> new TimestampedKeyValueStoreBuilder<>(supplier, Serdes.String(), Serdes.String(), new MockTime())); + assertThat(e.getMessage(), equalTo("storeSupplier's metricsScope can't be null")); + } + +} \ No newline at end of file diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/TimestampedSegmentTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/TimestampedSegmentTest.java new file mode 100644 index 0000000..9d339b7 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/TimestampedSegmentTest.java @@ -0,0 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.common.metrics.Metrics; +import org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.processor.ProcessorContext; +import org.apache.kafka.streams.processor.TaskId; +import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl; +import org.apache.kafka.streams.state.internals.metrics.RocksDBMetricsRecorder; +import org.apache.kafka.test.TestUtils; +import org.junit.Before; +import org.junit.Test; + +import java.io.File; +import java.util.HashSet; +import java.util.Set; + +import static org.apache.kafka.common.utils.Utils.mkEntry; +import static org.apache.kafka.common.utils.Utils.mkMap; +import static org.apache.kafka.streams.StreamsConfig.METRICS_RECORDING_LEVEL_CONFIG; +import static org.easymock.EasyMock.expect; +import static org.easymock.EasyMock.mock; +import static org.easymock.EasyMock.replay; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.not; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +public class TimestampedSegmentTest { + + private final RocksDBMetricsRecorder metricsRecorder = + new RocksDBMetricsRecorder("metrics-scope", "store-name"); + + @Before + public void setUp() { + metricsRecorder.init( + new StreamsMetricsImpl(new Metrics(), "test-client", StreamsConfig.METRICS_LATEST, new MockTime()), + new TaskId(0, 0) + ); + } + + @Test + public void shouldDeleteStateDirectoryOnDestroy() throws Exception { + final TimestampedSegment segment = new TimestampedSegment("segment", "window", 0L, metricsRecorder); + final String directoryPath = TestUtils.tempDirectory().getAbsolutePath(); + final File directory = new File(directoryPath); + + final ProcessorContext mockContext = mock(ProcessorContext.class); + expect(mockContext.appConfigs()).andReturn(mkMap(mkEntry(METRICS_RECORDING_LEVEL_CONFIG, "INFO"))); + expect(mockContext.stateDir()).andReturn(directory); + replay(mockContext); + + segment.openDB(mockContext.appConfigs(), mockContext.stateDir()); + + assertTrue(new File(directoryPath, "window").exists()); + assertTrue(new File(directoryPath + File.separator + "window", "segment").exists()); + assertTrue(new File(directoryPath + File.separator + "window", "segment").list().length > 0); + segment.destroy(); + assertFalse(new File(directoryPath + File.separator + "window", "segment").exists()); + assertTrue(new File(directoryPath, "window").exists()); + + segment.close(); + } + + @Test + public void shouldBeEqualIfIdIsEqual() { + final TimestampedSegment segment = new TimestampedSegment("anyName", "anyName", 0L, metricsRecorder); + final TimestampedSegment segmentSameId = + new TimestampedSegment("someOtherName", "someOtherName", 0L, metricsRecorder); + final TimestampedSegment segmentDifferentId = + new TimestampedSegment("anyName", "anyName", 1L, metricsRecorder); + + assertThat(segment, equalTo(segment)); + assertThat(segment, equalTo(segmentSameId)); + 
assertThat(segment, not(equalTo(segmentDifferentId))); + assertThat(segment, not(equalTo(null))); + assertThat(segment, not(equalTo("anyName"))); + + segment.close(); + segmentSameId.close(); + segmentDifferentId.close(); + } + + @Test + public void shouldHashOnSegmentIdOnly() { + final TimestampedSegment segment = new TimestampedSegment("anyName", "anyName", 0L, metricsRecorder); + final TimestampedSegment segmentSameId = + new TimestampedSegment("someOtherName", "someOtherName", 0L, metricsRecorder); + final TimestampedSegment segmentDifferentId = + new TimestampedSegment("anyName", "anyName", 1L, metricsRecorder); + + final Set set = new HashSet<>(); + assertTrue(set.add(segment)); + assertFalse(set.add(segmentSameId)); + assertTrue(set.add(segmentDifferentId)); + + segment.close(); + segmentSameId.close(); + segmentDifferentId.close(); + } + + @Test + public void shouldCompareSegmentIdOnly() { + final TimestampedSegment segment1 = new TimestampedSegment("a", "C", 50L, metricsRecorder); + final TimestampedSegment segment2 = new TimestampedSegment("b", "B", 100L, metricsRecorder); + final TimestampedSegment segment3 = new TimestampedSegment("c", "A", 0L, metricsRecorder); + + assertThat(segment1.compareTo(segment1), equalTo(0)); + assertThat(segment1.compareTo(segment2), equalTo(-1)); + assertThat(segment2.compareTo(segment1), equalTo(1)); + assertThat(segment1.compareTo(segment3), equalTo(1)); + assertThat(segment3.compareTo(segment1), equalTo(-1)); + assertThat(segment2.compareTo(segment3), equalTo(1)); + assertThat(segment3.compareTo(segment2), equalTo(-1)); + + segment1.close(); + segment2.close(); + segment3.close(); + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/TimestampedSegmentsTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/TimestampedSegmentsTest.java new file mode 100644 index 0000000..722cb69 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/TimestampedSegmentsTest.java @@ -0,0 +1,354 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.common.metrics.Metrics; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.streams.processor.internals.MockStreamsMetrics; +import org.apache.kafka.test.InternalMockProcessorContext; +import org.apache.kafka.test.MockRecordCollector; +import org.apache.kafka.test.TestUtils; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import java.io.File; +import java.text.SimpleDateFormat; +import java.util.Date; +import java.util.List; +import java.util.SimpleTimeZone; + +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.CoreMatchers.nullValue; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; + +public class TimestampedSegmentsTest { + + private static final int NUM_SEGMENTS = 5; + private static final long SEGMENT_INTERVAL = 100L; + private static final long RETENTION_PERIOD = 4 * SEGMENT_INTERVAL; + private static final String METRICS_SCOPE = "test-state-id"; + private InternalMockProcessorContext context; + private TimestampedSegments segments; + private File stateDirectory; + private final String storeName = "test"; + + @Before + public void createContext() { + stateDirectory = TestUtils.tempDirectory(); + context = new InternalMockProcessorContext<>( + stateDirectory, + Serdes.String(), + Serdes.Long(), + new MockRecordCollector(), + new ThreadCache(new LogContext("testCache "), 0, new MockStreamsMetrics(new Metrics())) + ); + segments = new TimestampedSegments(storeName, METRICS_SCOPE, RETENTION_PERIOD, SEGMENT_INTERVAL); + segments.openExisting(context, -1L); + } + + @After + public void close() { + segments.close(); + } + + @Test + public void shouldGetSegmentIdsFromTimestamp() { + assertEquals(0, segments.segmentId(0)); + assertEquals(1, segments.segmentId(SEGMENT_INTERVAL)); + assertEquals(2, segments.segmentId(2 * SEGMENT_INTERVAL)); + assertEquals(3, segments.segmentId(3 * SEGMENT_INTERVAL)); + } + + @Test + public void shouldBaseSegmentIntervalOnRetentionAndNumSegments() { + final TimestampedSegments segments = + new TimestampedSegments("test", METRICS_SCOPE, 8 * SEGMENT_INTERVAL, 2 * SEGMENT_INTERVAL); + assertEquals(0, segments.segmentId(0)); + assertEquals(0, segments.segmentId(SEGMENT_INTERVAL)); + assertEquals(1, segments.segmentId(2 * SEGMENT_INTERVAL)); + } + + @Test + public void shouldGetSegmentNameFromId() { + assertEquals("test.0", segments.segmentName(0)); + assertEquals("test." + SEGMENT_INTERVAL, segments.segmentName(1)); + assertEquals("test." + 2 * SEGMENT_INTERVAL, segments.segmentName(2)); + } + + @Test + public void shouldCreateSegments() { + final TimestampedSegment segment1 = segments.getOrCreateSegmentIfLive(0, context, -1L); + final TimestampedSegment segment2 = segments.getOrCreateSegmentIfLive(1, context, -1L); + final TimestampedSegment segment3 = segments.getOrCreateSegmentIfLive(2, context, -1L); + assertTrue(new File(context.stateDir(), "test/test.0").isDirectory()); + assertTrue(new File(context.stateDir(), "test/test." + SEGMENT_INTERVAL).isDirectory()); + assertTrue(new File(context.stateDir(), "test/test." 
+ 2 * SEGMENT_INTERVAL).isDirectory()); + assertTrue(segment1.isOpen()); + assertTrue(segment2.isOpen()); + assertTrue(segment3.isOpen()); + } + + @Test + public void shouldNotCreateSegmentThatIsAlreadyExpired() { + final long streamTime = updateStreamTimeAndCreateSegment(7); + assertNull(segments.getOrCreateSegmentIfLive(0, context, streamTime)); + assertFalse(new File(context.stateDir(), "test/test.0").exists()); + } + + @Test + public void shouldCleanupSegmentsThatHaveExpired() { + final TimestampedSegment segment1 = segments.getOrCreateSegmentIfLive(0, context, -1L); + final TimestampedSegment segment2 = segments.getOrCreateSegmentIfLive(1, context, -1L); + final TimestampedSegment segment3 = segments.getOrCreateSegmentIfLive(7, context, SEGMENT_INTERVAL * 7L); + assertFalse(segment1.isOpen()); + assertFalse(segment2.isOpen()); + assertTrue(segment3.isOpen()); + assertFalse(new File(context.stateDir(), "test/test.0").exists()); + assertFalse(new File(context.stateDir(), "test/test." + SEGMENT_INTERVAL).exists()); + assertTrue(new File(context.stateDir(), "test/test." + 7 * SEGMENT_INTERVAL).exists()); + } + + @Test + public void shouldGetSegmentForTimestamp() { + final TimestampedSegment segment = segments.getOrCreateSegmentIfLive(0, context, -1L); + segments.getOrCreateSegmentIfLive(1, context, -1L); + assertEquals(segment, segments.getSegmentForTimestamp(0L)); + } + + @Test + public void shouldGetCorrectSegmentString() { + final TimestampedSegment segment = segments.getOrCreateSegmentIfLive(0, context, -1L); + assertEquals("TimestampedSegment(id=0, name=test.0)", segment.toString()); + } + + @Test + public void shouldCloseAllOpenSegments() { + final TimestampedSegment first = segments.getOrCreateSegmentIfLive(0, context, -1L); + final TimestampedSegment second = segments.getOrCreateSegmentIfLive(1, context, -1L); + final TimestampedSegment third = segments.getOrCreateSegmentIfLive(2, context, -1L); + segments.close(); + + assertFalse(first.isOpen()); + assertFalse(second.isOpen()); + assertFalse(third.isOpen()); + } + + @Test + public void shouldOpenExistingSegments() { + segments = new TimestampedSegments("test", METRICS_SCOPE, 4, 1); + segments.openExisting(context, -1L); + segments.getOrCreateSegmentIfLive(0, context, -1L); + segments.getOrCreateSegmentIfLive(1, context, -1L); + segments.getOrCreateSegmentIfLive(2, context, -1L); + segments.getOrCreateSegmentIfLive(3, context, -1L); + segments.getOrCreateSegmentIfLive(4, context, -1L); + // close existing. 
+ segments.close(); + + segments = new TimestampedSegments("test", METRICS_SCOPE, 4, 1); + segments.openExisting(context, -1L); + + assertTrue(segments.getSegmentForTimestamp(0).isOpen()); + assertTrue(segments.getSegmentForTimestamp(1).isOpen()); + assertTrue(segments.getSegmentForTimestamp(2).isOpen()); + assertTrue(segments.getSegmentForTimestamp(3).isOpen()); + assertTrue(segments.getSegmentForTimestamp(4).isOpen()); + } + + @Test + public void shouldGetSegmentsWithinTimeRange() { + updateStreamTimeAndCreateSegment(0); + updateStreamTimeAndCreateSegment(1); + updateStreamTimeAndCreateSegment(2); + updateStreamTimeAndCreateSegment(3); + final long streamTime = updateStreamTimeAndCreateSegment(4); + segments.getOrCreateSegmentIfLive(0, context, streamTime); + segments.getOrCreateSegmentIfLive(1, context, streamTime); + segments.getOrCreateSegmentIfLive(2, context, streamTime); + segments.getOrCreateSegmentIfLive(3, context, streamTime); + segments.getOrCreateSegmentIfLive(4, context, streamTime); + + final List segments = this.segments.segments(0, 2 * SEGMENT_INTERVAL, true); + assertEquals(3, segments.size()); + assertEquals(0, segments.get(0).id); + assertEquals(1, segments.get(1).id); + assertEquals(2, segments.get(2).id); + } + + @Test + public void shouldGetSegmentsWithinBackwardTimeRange() { + updateStreamTimeAndCreateSegment(0); + updateStreamTimeAndCreateSegment(1); + updateStreamTimeAndCreateSegment(2); + updateStreamTimeAndCreateSegment(3); + final long streamTime = updateStreamTimeAndCreateSegment(4); + segments.getOrCreateSegmentIfLive(0, context, streamTime); + segments.getOrCreateSegmentIfLive(1, context, streamTime); + segments.getOrCreateSegmentIfLive(2, context, streamTime); + segments.getOrCreateSegmentIfLive(3, context, streamTime); + segments.getOrCreateSegmentIfLive(4, context, streamTime); + + final List segments = this.segments.segments(0, 2 * SEGMENT_INTERVAL, false); + assertEquals(3, segments.size()); + assertEquals(0, segments.get(2).id); + assertEquals(1, segments.get(1).id); + assertEquals(2, segments.get(0).id); + } + + @Test + public void shouldGetSegmentsWithinTimeRangeOutOfOrder() { + updateStreamTimeAndCreateSegment(4); + updateStreamTimeAndCreateSegment(2); + updateStreamTimeAndCreateSegment(0); + updateStreamTimeAndCreateSegment(1); + updateStreamTimeAndCreateSegment(3); + + final List segments = this.segments.segments(0, 2 * SEGMENT_INTERVAL, true); + assertEquals(3, segments.size()); + assertEquals(0, segments.get(0).id); + assertEquals(1, segments.get(1).id); + assertEquals(2, segments.get(2).id); + } + + @Test + public void shouldGetSegmentsWithinBackwardTimeRangeOutOfOrder() { + updateStreamTimeAndCreateSegment(4); + updateStreamTimeAndCreateSegment(2); + updateStreamTimeAndCreateSegment(0); + updateStreamTimeAndCreateSegment(1); + updateStreamTimeAndCreateSegment(3); + + final List segments = this.segments.segments(0, 2 * SEGMENT_INTERVAL, false); + assertEquals(3, segments.size()); + assertEquals(0, segments.get(2).id); + assertEquals(1, segments.get(1).id); + assertEquals(2, segments.get(0).id); + } + + @Test + public void shouldRollSegments() { + updateStreamTimeAndCreateSegment(0); + verifyCorrectSegments(0, 1); + updateStreamTimeAndCreateSegment(1); + verifyCorrectSegments(0, 2); + updateStreamTimeAndCreateSegment(2); + verifyCorrectSegments(0, 3); + updateStreamTimeAndCreateSegment(3); + verifyCorrectSegments(0, 4); + updateStreamTimeAndCreateSegment(4); + verifyCorrectSegments(0, 5); + updateStreamTimeAndCreateSegment(5); + 
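+        // with RETENTION_PERIOD = 4 * SEGMENT_INTERVAL, advancing stream time to segment 5 expires
+        // segment 0, so only segments 1..5 are expected to remain live.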
verifyCorrectSegments(1, 5); + updateStreamTimeAndCreateSegment(6); + verifyCorrectSegments(2, 5); + } + + @Test + public void futureEventsShouldNotCauseSegmentRoll() { + updateStreamTimeAndCreateSegment(0); + verifyCorrectSegments(0, 1); + updateStreamTimeAndCreateSegment(1); + verifyCorrectSegments(0, 2); + updateStreamTimeAndCreateSegment(2); + verifyCorrectSegments(0, 3); + updateStreamTimeAndCreateSegment(3); + verifyCorrectSegments(0, 4); + final long streamTime = updateStreamTimeAndCreateSegment(4); + verifyCorrectSegments(0, 5); + segments.getOrCreateSegmentIfLive(5, context, streamTime); + verifyCorrectSegments(0, 6); + segments.getOrCreateSegmentIfLive(6, context, streamTime); + verifyCorrectSegments(0, 7); + } + + private long updateStreamTimeAndCreateSegment(final int segment) { + final long streamTime = SEGMENT_INTERVAL * segment; + segments.getOrCreateSegmentIfLive(segment, context, streamTime); + return streamTime; + } + + @Test + public void shouldUpdateSegmentFileNameFromOldDateFormatToNewFormat() throws Exception { + final long segmentInterval = 60_000L; // the old segment file's naming system maxes out at 1 minute granularity. + + segments = new TimestampedSegments(storeName, METRICS_SCOPE, NUM_SEGMENTS * segmentInterval, segmentInterval); + + final String storeDirectoryPath = stateDirectory.getAbsolutePath() + File.separator + storeName; + final File storeDirectory = new File(storeDirectoryPath); + //noinspection ResultOfMethodCallIgnored + storeDirectory.mkdirs(); + + final SimpleDateFormat formatter = new SimpleDateFormat("yyyyMMddHHmm"); + formatter.setTimeZone(new SimpleTimeZone(0, "UTC")); + + for (int segmentId = 0; segmentId < NUM_SEGMENTS; ++segmentId) { + final File oldSegment = new File(storeDirectoryPath + File.separator + storeName + "-" + formatter.format(new Date(segmentId * segmentInterval))); + //noinspection ResultOfMethodCallIgnored + oldSegment.createNewFile(); + } + + segments.openExisting(context, -1L); + + for (int segmentId = 0; segmentId < NUM_SEGMENTS; ++segmentId) { + final String segmentName = storeName + "." + (long) segmentId * segmentInterval; + final File newSegment = new File(storeDirectoryPath + File.separator + segmentName); + assertTrue(newSegment.exists()); + } + } + + @Test + public void shouldUpdateSegmentFileNameFromOldColonFormatToNewFormat() throws Exception { + final String storeDirectoryPath = stateDirectory.getAbsolutePath() + File.separator + storeName; + final File storeDirectory = new File(storeDirectoryPath); + //noinspection ResultOfMethodCallIgnored + storeDirectory.mkdirs(); + + for (int segmentId = 0; segmentId < NUM_SEGMENTS; ++segmentId) { + final File oldSegment = new File(storeDirectoryPath + File.separator + storeName + ":" + segmentId * (RETENTION_PERIOD / (NUM_SEGMENTS - 1))); + //noinspection ResultOfMethodCallIgnored + oldSegment.createNewFile(); + } + + segments.openExisting(context, -1L); + + for (int segmentId = 0; segmentId < NUM_SEGMENTS; ++segmentId) { + final File newSegment = new File(storeDirectoryPath + File.separator + storeName + "." 
+ segmentId * (RETENTION_PERIOD / (NUM_SEGMENTS - 1))); + assertTrue(newSegment.exists()); + } + } + + @Test + public void shouldClearSegmentsOnClose() { + segments.getOrCreateSegmentIfLive(0, context, -1L); + segments.close(); + assertThat(segments.getSegmentForTimestamp(0), is(nullValue())); + } + + private void verifyCorrectSegments(final long first, final int numSegments) { + final List result = this.segments.segments(0, Long.MAX_VALUE, true); + assertEquals(numSegments, result.size()); + for (int i = 0; i < numSegments; i++) { + assertEquals(i + first, result.get(i).id); + } + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/TimestampedWindowStoreBuilderTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/TimestampedWindowStoreBuilderTest.java new file mode 100644 index 0000000..586ec73 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/TimestampedWindowStoreBuilderTest.java @@ -0,0 +1,224 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.streams.state.internals; + +import java.time.Duration; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.streams.processor.StateStore; +import org.apache.kafka.streams.state.StoreBuilder; +import org.apache.kafka.streams.state.Stores; +import org.apache.kafka.streams.state.TimestampedWindowStore; +import org.apache.kafka.streams.state.WindowBytesStoreSupplier; +import org.easymock.EasyMockRunner; +import org.easymock.Mock; +import org.easymock.MockType; +import org.hamcrest.CoreMatchers; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; + +import java.util.Collections; + +import static org.easymock.EasyMock.expect; +import static org.easymock.EasyMock.replay; +import static org.easymock.EasyMock.reset; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.core.IsInstanceOf.instanceOf; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertThrows; + +@RunWith(EasyMockRunner.class) +public class TimestampedWindowStoreBuilderTest { + + @Mock(type = MockType.NICE) + private WindowBytesStoreSupplier supplier; + @Mock(type = MockType.NICE) + private RocksDBTimestampedWindowStore inner; + private TimestampedWindowStoreBuilder builder; + + @Before + public void setUp() { + expect(supplier.get()).andReturn(inner); + expect(supplier.name()).andReturn("name"); + expect(supplier.metricsScope()).andReturn("metricScope"); + expect(inner.persistent()).andReturn(true).anyTimes(); + replay(supplier, inner); + + builder = new TimestampedWindowStoreBuilder<>( + supplier, + Serdes.String(), + Serdes.String(), + new MockTime()); + } + + @Test + public void shouldHaveMeteredStoreAsOuterStore() { + final TimestampedWindowStore store = builder.build(); + assertThat(store, instanceOf(MeteredTimestampedWindowStore.class)); + } + + @Test + public void shouldHaveChangeLoggingStoreByDefault() { + final TimestampedWindowStore store = builder.build(); + final StateStore next = ((WrappedStateStore) store).wrapped(); + assertThat(next, instanceOf(ChangeLoggingTimestampedWindowBytesStore.class)); + } + + @Test + public void shouldNotHaveChangeLoggingStoreWhenDisabled() { + final TimestampedWindowStore store = builder.withLoggingDisabled().build(); + final StateStore next = ((WrappedStateStore) store).wrapped(); + assertThat(next, CoreMatchers.equalTo(inner)); + } + + @Test + public void shouldHaveCachingStoreWhenEnabled() { + final TimestampedWindowStore store = builder.withCachingEnabled().build(); + final StateStore wrapped = ((WrappedStateStore) store).wrapped(); + assertThat(store, instanceOf(MeteredTimestampedWindowStore.class)); + assertThat(wrapped, instanceOf(CachingWindowStore.class)); + } + + @Test + public void shouldHaveChangeLoggingStoreWhenLoggingEnabled() { + final TimestampedWindowStore store = builder + .withLoggingEnabled(Collections.emptyMap()) + .build(); + final StateStore wrapped = ((WrappedStateStore) store).wrapped(); + assertThat(store, instanceOf(MeteredTimestampedWindowStore.class)); + assertThat(wrapped, instanceOf(ChangeLoggingTimestampedWindowBytesStore.class)); + assertThat(((WrappedStateStore) wrapped).wrapped(), CoreMatchers.equalTo(inner)); + } + + @Test + public void shouldHaveCachingAndChangeLoggingWhenBothEnabled() { + final TimestampedWindowStore store = builder + .withLoggingEnabled(Collections.emptyMap()) + .withCachingEnabled() + .build(); + 
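+        // with both features enabled the expected layering is
+        // Metered -> Caching -> ChangeLogging -> inner store, verified below by unwrapping one layer at a time.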
final WrappedStateStore caching = (WrappedStateStore) ((WrappedStateStore) store).wrapped(); + final WrappedStateStore changeLogging = (WrappedStateStore) caching.wrapped(); + assertThat(store, instanceOf(MeteredTimestampedWindowStore.class)); + assertThat(caching, instanceOf(CachingWindowStore.class)); + assertThat(changeLogging, instanceOf(ChangeLoggingTimestampedWindowBytesStore.class)); + assertThat(changeLogging.wrapped(), CoreMatchers.equalTo(inner)); + } + + @Test + public void shouldNotWrapTimestampedByteStore() { + reset(supplier); + expect(supplier.get()).andReturn(new RocksDBTimestampedWindowStore( + new RocksDBTimestampedSegmentedBytesStore( + "name", + "metric-scope", + 10L, + 5L, + new WindowKeySchema()), + false, + 1L)); + expect(supplier.name()).andReturn("name"); + replay(supplier); + + final TimestampedWindowStore store = builder + .withLoggingDisabled() + .withCachingDisabled() + .build(); + assertThat(((WrappedStateStore) store).wrapped(), instanceOf(RocksDBTimestampedWindowStore.class)); + } + + @Test + public void shouldWrapPlainKeyValueStoreAsTimestampStore() { + reset(supplier); + expect(supplier.get()).andReturn(new RocksDBWindowStore( + new RocksDBSegmentedBytesStore( + "name", + "metric-scope", + 10L, + 5L, + new WindowKeySchema()), + false, + 1L)); + expect(supplier.name()).andReturn("name"); + replay(supplier); + + final TimestampedWindowStore store = builder + .withLoggingDisabled() + .withCachingDisabled() + .build(); + assertThat(((WrappedStateStore) store).wrapped(), instanceOf(WindowToTimestampedWindowByteStoreAdapter.class)); + } + + @SuppressWarnings("unchecked") + @Test + public void shouldDisableCachingWithRetainDuplicates() { + supplier = Stores.persistentTimestampedWindowStore("name", Duration.ofMillis(10L), Duration.ofMillis(10L), true); + final StoreBuilder> builder = new TimestampedWindowStoreBuilder<>( + supplier, + Serdes.String(), + Serdes.String(), + new MockTime() + ).withCachingEnabled(); + + builder.build(); + + assertFalse(((AbstractStoreBuilder>) builder).enableCaching); + } + + @SuppressWarnings("all") + @Test + public void shouldThrowNullPointerIfInnerIsNull() { + assertThrows(NullPointerException.class, () -> new TimestampedWindowStoreBuilder<>(null, Serdes.String(), Serdes.String(), new MockTime())); + } + + @Test + public void shouldThrowNullPointerIfKeySerdeIsNull() { + assertThrows(NullPointerException.class, () -> new TimestampedWindowStoreBuilder<>(supplier, null, Serdes.String(), new MockTime())); + } + + @Test + public void shouldThrowNullPointerIfValueSerdeIsNull() { + assertThrows(NullPointerException.class, () -> new TimestampedWindowStoreBuilder<>(supplier, Serdes.String(), null, new MockTime())); + } + + @Test + public void shouldThrowNullPointerIfTimeIsNull() { + assertThrows(NullPointerException.class, () -> new TimestampedWindowStoreBuilder<>(supplier, Serdes.String(), Serdes.String(), null)); + } + + @Test + public void shouldThrowNullPointerIfMetricsScopeIsNull() { + reset(supplier); + expect(supplier.get()).andReturn(new RocksDBTimestampedWindowStore( + new RocksDBTimestampedSegmentedBytesStore( + "name", + null, + 10L, + 5L, + new WindowKeySchema()), + false, + 1L)); + expect(supplier.name()).andReturn("name"); + replay(supplier); + final Exception e = assertThrows(NullPointerException.class, + () -> new TimestampedWindowStoreBuilder<>(supplier, Serdes.String(), Serdes.String(), new MockTime())); + assertThat(e.getMessage(), equalTo("storeSupplier's metricsScope can't be null")); + } + +} \ No newline at end of file 
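The store-builder tests above (for both the key-value and the window variants) all verify this layering by casting to the internal WrappedStateStore and unwrapping one level at a time. Purely as an illustrative sketch, not part of the patch: the class and method names below are hypothetical, and it assumes only the WrappedStateStore.wrapped() call already used in the tests (placed in, or importing from, the same org.apache.kafka.streams.state.internals package), a small helper could walk the whole wrapper chain:

import java.util.ArrayList;
import java.util.List;

import org.apache.kafka.streams.processor.StateStore;

final class StoreWrapperChains {

    // Walks a built store from the outermost wrapper to the innermost store,
    // collecting the simple class name of each layer, using the same
    // WrappedStateStore.wrapped() call (and raw cast) that the tests above use.
    @SuppressWarnings("rawtypes")
    static List<String> wrapperChain(final StateStore outermost) {
        final List<String> layers = new ArrayList<>();
        StateStore current = outermost;
        while (current instanceof WrappedStateStore) {
            layers.add(current.getClass().getSimpleName());
            current = ((WrappedStateStore) current).wrapped();
        }
        layers.add(current.getClass().getSimpleName()); // innermost, e.g. the RocksDB-backed store
        return layers;
    }
}

For a store built with caching and logging enabled, the chain asserted by these tests is the metered store, then the caching store, then the change-logging store, then the inner (RocksDB-backed) store.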
diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/ValueAndTimestampSerializerTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/ValueAndTimestampSerializerTest.java new file mode 100644 index 0000000..599ce3d --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/ValueAndTimestampSerializerTest.java @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.serialization.Serializer; +import org.apache.kafka.streams.state.ValueAndTimestamp; +import org.junit.Test; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.notNullValue; +import static org.hamcrest.Matchers.nullValue; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +public class ValueAndTimestampSerializerTest { + private static final String TOPIC = "some-topic"; + private static final long TIMESTAMP = 23; + + private static final ValueAndTimestampSerde STRING_SERDE = + new ValueAndTimestampSerde<>(Serdes.String()); + + @Test + public void shouldSerializeNonNullDataUsingTheInternalSerializer() { + final String value = "some-string"; + + final ValueAndTimestamp valueAndTimestamp = ValueAndTimestamp.make(value, TIMESTAMP); + + final byte[] serialized = + STRING_SERDE.serializer().serialize(TOPIC, valueAndTimestamp); + + assertThat(serialized, is(notNullValue())); + + final ValueAndTimestamp deserialized = + STRING_SERDE.deserializer().deserialize(TOPIC, serialized); + + assertThat(deserialized, is(valueAndTimestamp)); + } + + @Test + public void shouldDropSerializedValueIfEqualWithGreaterTimestamp() { + final String value = "food"; + + final ValueAndTimestamp oldValueAndTimestamp = ValueAndTimestamp.make(value, TIMESTAMP); + final byte[] oldSerializedValue = STRING_SERDE.serializer().serialize(TOPIC, oldValueAndTimestamp); + final ValueAndTimestamp newValueAndTimestamp = ValueAndTimestamp.make(value, TIMESTAMP + 1); + final byte[] newSerializedValue = STRING_SERDE.serializer().serialize(TOPIC, newValueAndTimestamp); + assertTrue(ValueAndTimestampSerializer.valuesAreSameAndTimeIsIncreasing(oldSerializedValue, newSerializedValue)); + } + + @Test + public void shouldKeepSerializedValueIfOutOfOrder() { + final String value = "balls"; + + final ValueAndTimestamp oldValueAndTimestamp = ValueAndTimestamp.make(value, TIMESTAMP); + final byte[] oldSerializedValue = STRING_SERDE.serializer().serialize(TOPIC, oldValueAndTimestamp); + final ValueAndTimestamp outOfOrderValueAndTimestamp = ValueAndTimestamp.make(value, TIMESTAMP - 1); + final byte[] 
outOfOrderSerializedValue = STRING_SERDE.serializer().serialize(TOPIC, outOfOrderValueAndTimestamp); + assertFalse(ValueAndTimestampSerializer.valuesAreSameAndTimeIsIncreasing(oldSerializedValue, outOfOrderSerializedValue)); + } + + @Test + public void shouldSerializeNullDataAsNull() { + final byte[] serialized = + STRING_SERDE.serializer().serialize(TOPIC, ValueAndTimestamp.make(null, TIMESTAMP)); + + assertThat(serialized, is(nullValue())); + } + + @Test + public void shouldReturnNullWhenTheInternalSerializerReturnsNull() { + // Testing against regressions with respect to https://github.com/apache/kafka/pull/7679 + + final Serializer alwaysNullSerializer = (topic, data) -> null; + + final ValueAndTimestampSerializer serializer = + new ValueAndTimestampSerializer<>(alwaysNullSerializer); + + final byte[] serialized = serializer.serialize(TOPIC, "non-null-data", TIMESTAMP); + + assertThat(serialized, is(nullValue())); + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/WindowKeySchemaTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/WindowKeySchemaTest.java new file mode 100644 index 0000000..dc88410 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/WindowKeySchemaTest.java @@ -0,0 +1,282 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.kstream.TimeWindowedDeserializer; +import org.apache.kafka.streams.kstream.Window; +import org.apache.kafka.streams.kstream.Windowed; +import org.apache.kafka.streams.kstream.WindowedSerdes; +import org.apache.kafka.streams.kstream.internals.TimeWindow; +import org.apache.kafka.streams.state.StateSerdes; +import org.apache.kafka.test.KeyValueIteratorStub; +import org.junit.Test; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.core.IsEqual.equalTo; +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; + +public class WindowKeySchemaTest { + + final private String key = "key"; + final private String topic = "topic"; + final private long startTime = 50L; + final private long endTime = 100L; + final private Serde serde = Serdes.String(); + + final private Window window = new TimeWindow(startTime, endTime); + final private Windowed windowedKey = new Windowed<>(key, window); + final private WindowKeySchema windowKeySchema = new WindowKeySchema(); + final private Serde> keySerde = new WindowedSerdes.TimeWindowedSerde<>(serde, Long.MAX_VALUE); + final private StateSerdes stateSerdes = new StateSerdes<>("dummy", serde, Serdes.ByteArray()); + + @Test + public void testHasNextConditionUsingNullKeys() { + final List> keys = Arrays.asList( + KeyValue.pair(WindowKeySchema.toStoreKeyBinary(new Windowed<>(Bytes.wrap(new byte[] {0, 0}), new TimeWindow(0, 1)), 0), 1), + KeyValue.pair(WindowKeySchema.toStoreKeyBinary(new Windowed<>(Bytes.wrap(new byte[] {0}), new TimeWindow(0, 1)), 0), 2), + KeyValue.pair(WindowKeySchema.toStoreKeyBinary(new Windowed<>(Bytes.wrap(new byte[] {0, 0, 0}), new TimeWindow(0, 1)), 0), 3), + KeyValue.pair(WindowKeySchema.toStoreKeyBinary(new Windowed<>(Bytes.wrap(new byte[] {0}), new TimeWindow(10, 20)), 4), 4), + KeyValue.pair(WindowKeySchema.toStoreKeyBinary(new Windowed<>(Bytes.wrap(new byte[] {0, 0}), new TimeWindow(10, 20)), 5), 5), + KeyValue.pair(WindowKeySchema.toStoreKeyBinary(new Windowed<>(Bytes.wrap(new byte[] {0, 0, 0}), new TimeWindow(10, 20)), 6), 6)); + try (final DelegatingPeekingKeyValueIterator iterator = new DelegatingPeekingKeyValueIterator<>("foo", new KeyValueIteratorStub<>(keys.iterator()))) { + + final HasNextCondition hasNextCondition = windowKeySchema.hasNextCondition(null, null, 0, Long.MAX_VALUE); + final List results = new ArrayList<>(); + while (hasNextCondition.hasNext(iterator)) { + results.add(iterator.next().value); + } + + assertThat(results, equalTo(Arrays.asList(1, 2, 3, 4, 5, 6))); + } + } + + @Test + public void testUpperBoundWithLargeTimestamps() { + final Bytes upper = windowKeySchema.upperRange(Bytes.wrap(new byte[] {0xA, 0xB, 0xC}), Long.MAX_VALUE); + + assertThat( + "shorter key with max timestamp should be in range", + upper.compareTo( + WindowKeySchema.toStoreKeyBinary( + new byte[] {0xA}, + Long.MAX_VALUE, + Integer.MAX_VALUE + ) + ) >= 0 + ); + + assertThat( + "shorter key with max timestamp should be in range", + upper.compareTo( + WindowKeySchema.toStoreKeyBinary( + new byte[] {0xA, 0xB}, + Long.MAX_VALUE, + Integer.MAX_VALUE + ) + ) >= 0 + ); + + 
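+ // Note (assumed rationale, not in the original test): only the first key byte is expected to survive in the range end here, since the remaining key bytes (0xB, 0xC) sort below 0x7F, the first byte of the max-timestamp suffix.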
assertThat(upper, equalTo(WindowKeySchema.toStoreKeyBinary(new byte[] {0xA}, Long.MAX_VALUE, Integer.MAX_VALUE))); + } + + @Test + public void testUpperBoundWithKeyBytesLargerThanFirstTimestampByte() { + final Bytes upper = windowKeySchema.upperRange(Bytes.wrap(new byte[] {0xA, (byte) 0x8F, (byte) 0x9F}), Long.MAX_VALUE); + + assertThat( + "shorter key with max timestamp should be in range", + upper.compareTo( + WindowKeySchema.toStoreKeyBinary( + new byte[] {0xA, (byte) 0x8F}, + Long.MAX_VALUE, + Integer.MAX_VALUE + ) + ) >= 0 + ); + + assertThat(upper, equalTo(WindowKeySchema.toStoreKeyBinary(new byte[] {0xA, (byte) 0x8F, (byte) 0x9F}, Long.MAX_VALUE, Integer.MAX_VALUE))); + } + + + @Test + public void testUpperBoundWithKeyBytesLargerAndSmallerThanFirstTimestampByte() { + final Bytes upper = windowKeySchema.upperRange(Bytes.wrap(new byte[] {0xC, 0xC, 0x9}), 0x0AffffffffffffffL); + + assertThat( + "shorter key with customized timestamp should be in range", + upper.compareTo( + WindowKeySchema.toStoreKeyBinary( + new byte[] {0xC, 0xC}, + 0x0AffffffffffffffL, + Integer.MAX_VALUE + ) + ) >= 0 + ); + + assertThat(upper, equalTo(WindowKeySchema.toStoreKeyBinary(new byte[] {0xC, 0xC}, 0x0AffffffffffffffL, Integer.MAX_VALUE))); + } + + @Test + public void testUpperBoundWithZeroTimestamp() { + final Bytes upper = windowKeySchema.upperRange(Bytes.wrap(new byte[] {0xA, 0xB, 0xC}), 0); + assertThat(upper, equalTo(WindowKeySchema.toStoreKeyBinary(new byte[] {0xA, 0xB, 0xC}, 0, Integer.MAX_VALUE))); + } + + @Test + public void testLowerBoundWithZeroTimestamp() { + final Bytes lower = windowKeySchema.lowerRange(Bytes.wrap(new byte[] {0xA, 0xB, 0xC}), 0); + assertThat(lower, equalTo(WindowKeySchema.toStoreKeyBinary(new byte[] {0xA, 0xB, 0xC}, 0, 0))); + } + + @Test + public void testLowerBoundWithNonZeroTimestamp() { + final Bytes lower = windowKeySchema.lowerRange(Bytes.wrap(new byte[] {0xA, 0xB, 0xC}), 42); + assertThat(lower, equalTo(WindowKeySchema.toStoreKeyBinary(new byte[] {0xA, 0xB, 0xC}, 0, 0))); + } + + @Test + public void testLowerBoundMatchesTrailingZeros() { + final Bytes lower = windowKeySchema.lowerRange(Bytes.wrap(new byte[] {0xA, 0xB, 0xC}), Long.MAX_VALUE - 1); + + assertThat( + "appending zeros to key should still be in range", + lower.compareTo( + WindowKeySchema.toStoreKeyBinary( + new byte[] {0xA, 0xB, 0xC, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, + Long.MAX_VALUE - 1, + 0 + ) + ) < 0 + ); + + assertThat(lower, equalTo(WindowKeySchema.toStoreKeyBinary(new byte[] {0xA, 0xB, 0xC}, 0, 0))); + } + + @Test + public void shouldSerializeDeserialize() { + final byte[] bytes = keySerde.serializer().serialize(topic, windowedKey); + final Windowed result = keySerde.deserializer().deserialize(topic, bytes); + // TODO: fix this part as last bits of KAFKA-4468 + assertEquals(new Windowed<>(key, new TimeWindow(startTime, Long.MAX_VALUE)), result); + } + + @Test + public void testSerializeDeserializeOverflowWindowSize() { + final byte[] bytes = keySerde.serializer().serialize(topic, windowedKey); + final Windowed result = new TimeWindowedDeserializer<>(serde.deserializer(), Long.MAX_VALUE - 1) + .deserialize(topic, bytes); + assertEquals(new Windowed<>(key, new TimeWindow(startTime, Long.MAX_VALUE)), result); + } + + @Test + public void shouldSerializeDeserializeExpectedWindowSize() { + final byte[] bytes = keySerde.serializer().serialize(topic, windowedKey); + final Windowed result = new TimeWindowedDeserializer<>(serde.deserializer(), endTime - startTime) + .deserialize(topic, bytes); + 
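+ // Deserializing with the actual window size (endTime - startTime) should reconstruct the original window exactly, since only the window start time is carried in the serialized key.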
assertEquals(windowedKey, result); + } + + @Test + public void shouldSerializeDeserializeExpectedChangelogWindowSize() { + // Key-value containing serialized store key binary and the key's window size + final List> keys = Arrays.asList( + KeyValue.pair(WindowKeySchema.toStoreKeyBinary(new Windowed<>(Bytes.wrap(new byte[] {0}), new TimeWindow(0, 1)), 0), 1), + KeyValue.pair(WindowKeySchema.toStoreKeyBinary(new Windowed<>(Bytes.wrap(new byte[] {0, 0}), new TimeWindow(0, 10)), 0), 10), + KeyValue.pair(WindowKeySchema.toStoreKeyBinary(new Windowed<>(Bytes.wrap(new byte[] {0, 0, 0}), new TimeWindow(10, 30)), 6), 20)); + + final List results = new ArrayList<>(); + for (final KeyValue keyValue : keys) { + // Let the deserializer know that it's deserializing a changelog windowed key + final Serde> keySerde = new WindowedSerdes.TimeWindowedSerde<>(serde, keyValue.value).forChangelog(true); + final Windowed result = keySerde.deserializer().deserialize(topic, keyValue.key.get()); + final Window resultWindow = result.window(); + results.add(resultWindow.end() - resultWindow.start()); + } + + assertThat(results, equalTo(Arrays.asList(1L, 10L, 20L))); + } + + @Test + public void shouldSerializeNullToNull() { + assertNull(keySerde.serializer().serialize(topic, null)); + } + + @Test + public void shouldDeserializeEmptyByteArrayToNull() { + assertNull(keySerde.deserializer().deserialize(topic, new byte[0])); + } + + @Test + public void shouldDeserializeNullToNull() { + assertNull(keySerde.deserializer().deserialize(topic, null)); + } + + @Test + public void shouldConvertToBinaryAndBack() { + final Bytes serialized = WindowKeySchema.toStoreKeyBinary(windowedKey, 0, stateSerdes); + final Windowed result = WindowKeySchema.fromStoreKey(serialized.get(), endTime - startTime, stateSerdes.keyDeserializer(), stateSerdes.topic()); + assertEquals(windowedKey, result); + } + + @Test + public void shouldExtractSequenceFromBinary() { + final Bytes serialized = WindowKeySchema.toStoreKeyBinary(windowedKey, 0, stateSerdes); + assertEquals(0, WindowKeySchema.extractStoreSequence(serialized.get())); + } + + @Test + public void shouldExtractStartTimeFromBinary() { + final Bytes serialized = WindowKeySchema.toStoreKeyBinary(windowedKey, 0, stateSerdes); + assertEquals(startTime, WindowKeySchema.extractStoreTimestamp(serialized.get())); + } + + @Test + public void shouldExtractWindowFromBinary() { + final Bytes serialized = WindowKeySchema.toStoreKeyBinary(windowedKey, 0, stateSerdes); + assertEquals(window, WindowKeySchema.extractStoreWindow(serialized.get(), endTime - startTime)); + } + + @Test + public void shouldExtractKeyBytesFromBinary() { + final Bytes serialized = WindowKeySchema.toStoreKeyBinary(windowedKey, 0, stateSerdes); + assertArrayEquals(key.getBytes(), WindowKeySchema.extractStoreKeyBytes(serialized.get())); + } + + @Test + public void shouldExtractKeyFromBinary() { + final Bytes serialized = WindowKeySchema.toStoreKeyBinary(windowedKey, 0, stateSerdes); + assertEquals(windowedKey, WindowKeySchema.fromStoreKey(serialized.get(), endTime - startTime, stateSerdes.keyDeserializer(), stateSerdes.topic())); + } + + @Test + public void shouldExtractBytesKeyFromBinary() { + final Windowed windowedBytesKey = new Windowed<>(Bytes.wrap(key.getBytes()), window); + final Bytes serialized = WindowKeySchema.toStoreKeyBinary(windowedBytesKey, 0); + assertEquals(windowedBytesKey, WindowKeySchema.fromStoreBytesKey(serialized.get(), endTime - startTime)); + } +} diff --git 
a/streams/src/test/java/org/apache/kafka/streams/state/internals/WindowStoreBuilderTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/WindowStoreBuilderTest.java new file mode 100644 index 0000000..6442da0 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/WindowStoreBuilderTest.java @@ -0,0 +1,203 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.streams.state.internals; + +import java.time.Duration; +import org.apache.kafka.common.config.TopicConfig; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.streams.processor.StateStore; +import org.apache.kafka.streams.state.StoreBuilder; +import org.apache.kafka.streams.state.Stores; +import org.apache.kafka.streams.state.WindowBytesStoreSupplier; +import org.apache.kafka.streams.state.WindowStore; +import org.easymock.EasyMockRunner; +import org.easymock.Mock; +import org.easymock.MockType; +import org.hamcrest.CoreMatchers; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; + +import java.util.Collections; + +import static org.easymock.EasyMock.expect; +import static org.easymock.EasyMock.replay; +import static org.easymock.EasyMock.reset; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.core.IsInstanceOf.instanceOf; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertThrows; + +@RunWith(EasyMockRunner.class) +public class WindowStoreBuilderTest { + + @Mock(type = MockType.NICE) + private WindowBytesStoreSupplier supplier; + @Mock(type = MockType.NICE) + private WindowStore inner; + private WindowStoreBuilder builder; + + @Before + public void setUp() { + expect(supplier.get()).andReturn(inner); + expect(supplier.name()).andReturn("name"); + expect(supplier.metricsScope()).andReturn("metricScope"); + replay(supplier); + + builder = new WindowStoreBuilder<>( + supplier, + Serdes.String(), + Serdes.String(), + new MockTime()); + } + + @Test + public void shouldHaveMeteredStoreAsOuterStore() { + final WindowStore store = builder.build(); + assertThat(store, instanceOf(MeteredWindowStore.class)); + } + + @Test + public void shouldHaveChangeLoggingStoreByDefault() { + final WindowStore store = builder.build(); + final StateStore next = ((WrappedStateStore) store).wrapped(); + assertThat(next, instanceOf(ChangeLoggingWindowBytesStore.class)); + } + + @Test + public void shouldNotHaveChangeLoggingStoreWhenDisabled() { + final WindowStore store = builder.withLoggingDisabled().build(); + final StateStore next = ((WrappedStateStore) store).wrapped(); + 
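+ // With logging disabled, the metered outer store is expected to wrap the inner store directly, i.e. no change-logging layer is inserted.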
assertThat(next, CoreMatchers.equalTo(inner)); + } + + @Test + public void shouldHaveCachingStoreWhenEnabled() { + final WindowStore store = builder.withCachingEnabled().build(); + final StateStore wrapped = ((WrappedStateStore) store).wrapped(); + assertThat(store, instanceOf(MeteredWindowStore.class)); + assertThat(wrapped, instanceOf(CachingWindowStore.class)); + } + + @Test + public void shouldHaveChangeLoggingStoreWhenLoggingEnabled() { + final WindowStore store = builder + .withLoggingEnabled(Collections.emptyMap()) + .build(); + final StateStore wrapped = ((WrappedStateStore) store).wrapped(); + assertThat(store, instanceOf(MeteredWindowStore.class)); + assertThat(wrapped, instanceOf(ChangeLoggingWindowBytesStore.class)); + assertThat(((WrappedStateStore) wrapped).wrapped(), CoreMatchers.equalTo(inner)); + } + + @Test + public void shouldHaveCachingAndChangeLoggingWhenBothEnabled() { + final WindowStore store = builder + .withLoggingEnabled(Collections.emptyMap()) + .withCachingEnabled() + .build(); + final WrappedStateStore caching = (WrappedStateStore) ((WrappedStateStore) store).wrapped(); + final WrappedStateStore changeLogging = (WrappedStateStore) caching.wrapped(); + assertThat(store, instanceOf(MeteredWindowStore.class)); + assertThat(caching, instanceOf(CachingWindowStore.class)); + assertThat(changeLogging, instanceOf(ChangeLoggingWindowBytesStore.class)); + assertThat(changeLogging.wrapped(), CoreMatchers.equalTo(inner)); + } + + @SuppressWarnings("unchecked") + @Test + public void shouldDisableCachingWithRetainDuplicates() { + supplier = Stores.persistentWindowStore("name", Duration.ofMillis(10L), Duration.ofMillis(10L), true); + final StoreBuilder> builder = new WindowStoreBuilder<>( + supplier, + Serdes.String(), + Serdes.String(), + new MockTime() + ).withCachingEnabled(); + + builder.build(); + + assertFalse(((AbstractStoreBuilder>) builder).enableCaching); + } + + @Test + public void shouldDisableLogCompactionWithRetainDuplicates() { + supplier = Stores.persistentWindowStore( + "name", + Duration.ofMillis(10L), + Duration.ofMillis(10L), + true); + final StoreBuilder> builder = new WindowStoreBuilder<>( + supplier, + Serdes.String(), + Serdes.String(), + new MockTime() + ).withCachingEnabled(); + + assertThat( + builder.logConfig().get(TopicConfig.CLEANUP_POLICY_CONFIG), + equalTo(TopicConfig.CLEANUP_POLICY_DELETE) + ); + } + + @SuppressWarnings("null") + @Test + public void shouldThrowNullPointerIfInnerIsNull() { + assertThrows(NullPointerException.class, () -> new WindowStoreBuilder<>(null, Serdes.String(), Serdes.String(), new MockTime())); + } + + @Test + public void shouldThrowNullPointerIfKeySerdeIsNull() { + assertThrows(NullPointerException.class, () -> new WindowStoreBuilder<>(supplier, null, Serdes.String(), new MockTime())); + } + + @Test + public void shouldThrowNullPointerIfValueSerdeIsNull() { + assertThrows(NullPointerException.class, () -> new WindowStoreBuilder<>(supplier, Serdes.String(), + null, new MockTime())); + } + + @Test + public void shouldThrowNullPointerIfTimeIsNull() { + assertThrows(NullPointerException.class, () -> new WindowStoreBuilder<>(supplier, Serdes.String(), + Serdes.String(), null)); + } + + @Test + public void shouldThrowNullPointerIfMetricsScopeIsNull() { + reset(supplier); + expect(supplier.get()).andReturn(new RocksDBWindowStore( + new RocksDBSegmentedBytesStore( + "name", + null, + 10L, + 5L, + new WindowKeySchema()), + false, + 1L)); + expect(supplier.name()).andReturn("name"); + replay(supplier); + + final Exception e = 
assertThrows(NullPointerException.class, + () -> new WindowStoreBuilder<>(supplier, Serdes.String(), Serdes.String(), new MockTime())); + assertThat(e.getMessage(), equalTo("storeSupplier's metricsScope can't be null")); + } +} \ No newline at end of file diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/WindowStoreFetchTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/WindowStoreFetchTest.java new file mode 100644 index 0000000..a429256 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/WindowStoreFetchTest.java @@ -0,0 +1,310 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.streams.state.internals; + +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.serialization.StringSerializer; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.TestInputTopic; +import org.apache.kafka.streams.Topology; +import org.apache.kafka.streams.TopologyTestDriver; +import org.apache.kafka.streams.kstream.Consumed; +import org.apache.kafka.streams.kstream.Grouped; +import org.apache.kafka.streams.kstream.KStream; +import org.apache.kafka.streams.kstream.Materialized; +import org.apache.kafka.streams.kstream.TimeWindowedKStream; +import org.apache.kafka.streams.kstream.TimeWindows; +import org.apache.kafka.streams.kstream.Windowed; +import org.apache.kafka.streams.kstream.internals.TimeWindow; +import org.apache.kafka.streams.state.KeyValueIterator; +import org.apache.kafka.streams.state.Stores; +import org.apache.kafka.streams.state.WindowBytesStoreSupplier; +import org.apache.kafka.streams.state.WindowStore; +import org.apache.kafka.test.TestUtils; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TestName; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import java.time.Duration; +import java.util.Arrays; +import java.util.Collection; +import java.util.HashMap; +import java.util.Iterator; +import java.util.LinkedList; +import java.util.List; +import java.util.Properties; +import java.util.function.Predicate; +import java.util.function.Supplier; + +import static java.time.Duration.ofMillis; +import static org.apache.kafka.common.utils.Utils.mkEntry; +import static org.apache.kafka.common.utils.Utils.mkMap; +import static org.apache.kafka.common.utils.Utils.mkProperties; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.MatcherAssert.assertThat; + +@RunWith(Parameterized.class) 
+public class WindowStoreFetchTest { + private enum StoreType { InMemory, RocksDB, Timed }; + private static final String STORE_NAME = "store"; + private static final int DATA_SIZE = 5; + private static final long WINDOW_SIZE = 500L; + private static final long RETENTION_MS = 10000L; + + private StoreType storeType; + private boolean enableLogging; + private boolean enableCaching; + private boolean forward; + + private LinkedList, Long>> expectedRecords; + private LinkedList> records; + private Properties streamsConfig; + private String low; + private String high; + private String middle; + private String innerLow; + private String innerHigh; + private String innerLowBetween; + private String innerHighBetween; + private String storeName; + + private TimeWindowedKStream windowedStream; + + public WindowStoreFetchTest(final StoreType storeType, final boolean enableLogging, final boolean enableCaching, final boolean forward) { + this.storeType = storeType; + this.enableLogging = enableLogging; + this.enableCaching = enableCaching; + this.forward = forward; + + this.records = new LinkedList<>(); + this.expectedRecords = new LinkedList<>(); + final int m = DATA_SIZE / 2; + for (int i = 0; i < DATA_SIZE; i++) { + final String key = "key-" + i * 2; + final String value = "val-" + i * 2; + final KeyValue r = new KeyValue<>(key, value); + records.add(r); + records.add(r); + // expected the count of each key is 2 + final long windowStartTime = i < m ? 0 : WINDOW_SIZE; + expectedRecords.add(new KeyValue<>(new Windowed<>(key, new TimeWindow(windowStartTime, windowStartTime + WINDOW_SIZE)), 2L)); + high = key; + if (low == null) { + low = key; + } + if (i == m) { + middle = key; + } + if (i == 1) { + innerLow = key; + final int index = i * 2 - 1; + innerLowBetween = "key-" + index; + } + if (i == DATA_SIZE - 2) { + innerHigh = key; + final int index = i * 2 + 1; + innerHighBetween = "key-" + index; + } + } + Assert.assertNotNull(low); + Assert.assertNotNull(high); + Assert.assertNotNull(middle); + Assert.assertNotNull(innerLow); + Assert.assertNotNull(innerHigh); + Assert.assertNotNull(innerLowBetween); + Assert.assertNotNull(innerHighBetween); + } + + @Rule + public TestName testName = new TestName(); + + @Parameterized.Parameters(name = "storeType={0}, enableLogging={1}, enableCaching={2}, forward={3}") + public static Collection data() { + final List types = Arrays.asList(StoreType.InMemory, StoreType.RocksDB, StoreType.Timed); + final List logging = Arrays.asList(true, false); + final List caching = Arrays.asList(true, false); + final List forward = Arrays.asList(true, false); + return buildParameters(types, logging, caching, forward); + } + + @Before + public void setup() { + streamsConfig = mkProperties(mkMap( + mkEntry(StreamsConfig.STATE_DIR_CONFIG, TestUtils.tempDirectory().getPath()) + )); + } + + @Test + public void testStoreConfig() { + final Materialized> stateStoreConfig = getStoreConfig(storeType, STORE_NAME, enableLogging, enableCaching); + //Create topology: table from input topic + final StreamsBuilder builder = new StreamsBuilder(); + + final KStream stream = builder.stream("input", Consumed.with(Serdes.String(), Serdes.String())); + stream. 
+ groupByKey(Grouped.with(Serdes.String(), Serdes.String())) + .windowedBy(TimeWindows.ofSizeWithNoGrace(ofMillis(WINDOW_SIZE))) + .count(stateStoreConfig) + .toStream() + .to("output"); + + final Topology topology = builder.build(); + + try (final TopologyTestDriver driver = new TopologyTestDriver(topology)) { + //get input topic and stateStore + final TestInputTopic input = driver + .createInputTopic("input", new StringSerializer(), new StringSerializer()); + final WindowStore stateStore = driver.getWindowStore(STORE_NAME); + + //write some data + final int medium = DATA_SIZE / 2 * 2; + for (int i = 0; i < records.size(); i++) { + final KeyValue kv = records.get(i); + final long windowStartTime = i < medium ? 0 : WINDOW_SIZE; + input.pipeInput(kv.key, kv.value, windowStartTime + i); + } + + // query the state store + try (final KeyValueIterator, Long> scanIterator = forward ? + stateStore.fetchAll(0, Long.MAX_VALUE) : + stateStore.backwardFetchAll(0, Long.MAX_VALUE)) { + + final Iterator, Long>> dataIterator = forward ? + expectedRecords.iterator() : + expectedRecords.descendingIterator(); + + TestUtils.checkEquals(scanIterator, dataIterator); + } + + try (final KeyValueIterator, Long> scanIterator = forward ? + stateStore.fetch(null, null, 0, Long.MAX_VALUE) : + stateStore.backwardFetch(null, null, 0, Long.MAX_VALUE)) { + + final Iterator, Long>> dataIterator = forward ? + expectedRecords.iterator() : + expectedRecords.descendingIterator(); + + TestUtils.checkEquals(scanIterator, dataIterator); + } + + testRange("range", stateStore, innerLow, innerHigh, forward); + testRange("until", stateStore, null, middle, forward); + testRange("from", stateStore, middle, null, forward); + + testRange("untilBetween", stateStore, null, innerHighBetween, forward); + testRange("fromBetween", stateStore, innerLowBetween, null, forward); + } + } + + private List, Long>> filterList(final KeyValueIterator, Long> iterator, final String from, final String to) { + final Predicate, Long>> pred = new Predicate, Long>>() { + @Override + public boolean test(final KeyValue, Long> elem) { + if (from != null && elem.key.key().compareTo(from) < 0) { + return false; + } + if (to != null && elem.key.key().compareTo(to) > 0) { + return false; + } + return elem != null; + } + }; + + return Utils.toList(iterator, pred); + } + + private void testRange(final String name, final WindowStore store, final String from, final String to, final boolean forward) { + try (final KeyValueIterator, Long> resultIterator = forward ? store.fetch(from, to, 0, Long.MAX_VALUE) : store.backwardFetch(from, to, 0, Long.MAX_VALUE); + final KeyValueIterator, Long> expectedIterator = forward ? store.fetchAll(0, Long.MAX_VALUE) : store.backwardFetchAll(0, Long.MAX_VALUE)) { + final List, Long>> result = Utils.toList(resultIterator); + final List, Long>> expected = filterList(expectedIterator, from, to); + assertThat(result, is(expected)); + } + } + + private static Collection buildParameters(final List... 
argOptions) { + List result = new LinkedList<>(); + result.add(new Object[0]); + + for (final List argOption : argOptions) { + result = times(result, argOption); + } + + return result; + } + + private static List times(final List left, final List right) { + final List result = new LinkedList<>(); + for (final Object[] args : left) { + for (final Object rightElem : right) { + final Object[] resArgs = new Object[args.length + 1]; + System.arraycopy(args, 0, resArgs, 0, args.length); + resArgs[args.length] = rightElem; + result.add(resArgs); + } + } + return result; + } + + private Materialized> getStoreConfig(final StoreType type, final String name, final boolean cachingEnabled, final boolean loggingEnabled) { + final Supplier createStore = () -> { + if (type == StoreType.InMemory) { + return Stores.inMemoryWindowStore(STORE_NAME, Duration.ofMillis(RETENTION_MS), + Duration.ofMillis(WINDOW_SIZE), + false); + } else if (type == StoreType.RocksDB) { + return Stores.persistentWindowStore(STORE_NAME, Duration.ofMillis(RETENTION_MS), + Duration.ofMillis(WINDOW_SIZE), + false); + } else if (type == StoreType.Timed) { + return Stores.persistentTimestampedWindowStore(STORE_NAME, Duration.ofMillis(RETENTION_MS), + Duration.ofMillis(WINDOW_SIZE), + false); + } else { + return Stores.inMemoryWindowStore(STORE_NAME, Duration.ofMillis(RETENTION_MS), + Duration.ofMillis(WINDOW_SIZE), + false); + } + }; + + final WindowBytesStoreSupplier stateStoreSupplier = createStore.get(); + final Materialized> stateStoreConfig = Materialized + .as(stateStoreSupplier) + .withKeySerde(Serdes.String()) + .withValueSerde(Serdes.Long()); + if (cachingEnabled) { + stateStoreConfig.withCachingEnabled(); + } else { + stateStoreConfig.withCachingDisabled(); + } + if (loggingEnabled) { + stateStoreConfig.withLoggingEnabled(new HashMap()); + } else { + stateStoreConfig.withLoggingDisabled(); + } + return stateStoreConfig; + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/WrappingStoreProviderTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/WrappingStoreProviderTest.java new file mode 100644 index 0000000..2a5551a --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/WrappingStoreProviderTest.java @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.state.internals; + + +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.streams.StoreQueryParameters; +import org.apache.kafka.streams.errors.InvalidStateStoreException; +import org.apache.kafka.streams.errors.InvalidStateStorePartitionException; +import org.apache.kafka.streams.state.NoOpWindowStore; +import org.apache.kafka.streams.state.QueryableStoreTypes; +import org.apache.kafka.streams.state.ReadOnlyKeyValueStore; +import org.apache.kafka.streams.state.Stores; +import org.apache.kafka.streams.state.ReadOnlyWindowStore; +import org.apache.kafka.test.StateStoreProviderStub; +import org.junit.Before; +import org.junit.Test; + +import java.util.Arrays; +import java.util.List; + +import static org.apache.kafka.streams.state.QueryableStoreTypes.windowStore; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; + +public class WrappingStoreProviderTest { + + private WrappingStoreProvider wrappingStoreProvider; + + private final int numStateStorePartitions = 2; + + @Before + public void before() { + final StateStoreProviderStub stubProviderOne = new StateStoreProviderStub(false); + final StateStoreProviderStub stubProviderTwo = new StateStoreProviderStub(false); + + for (int partition = 0; partition < numStateStorePartitions; partition++) { + stubProviderOne.addStore("kv", partition, Stores.keyValueStoreBuilder(Stores.inMemoryKeyValueStore("kv"), + Serdes.serdeFrom(String.class), + Serdes.serdeFrom(String.class)) + .build()); + stubProviderOne.addStore("window", partition, new NoOpWindowStore()); + wrappingStoreProvider = new WrappingStoreProvider( + Arrays.asList(stubProviderOne, stubProviderTwo), + StoreQueryParameters.fromNameAndType("kv", QueryableStoreTypes.keyValueStore()) + ); + } + } + + @Test + public void shouldFindKeyValueStores() { + final List> results = + wrappingStoreProvider.stores("kv", QueryableStoreTypes.keyValueStore()); + assertEquals(2, results.size()); + } + + @Test + public void shouldFindWindowStores() { + wrappingStoreProvider.setStoreQueryParameters(StoreQueryParameters.fromNameAndType("window", windowStore())); + final List> + windowStores = + wrappingStoreProvider.stores("window", windowStore()); + assertEquals(2, windowStores.size()); + } + + @Test + public void shouldThrowInvalidStoreExceptionIfNoStoreOfTypeFound() { + wrappingStoreProvider.setStoreQueryParameters(StoreQueryParameters.fromNameAndType("doesn't exist", QueryableStoreTypes.keyValueStore())); + assertThrows(InvalidStateStoreException.class, () -> wrappingStoreProvider.stores("doesn't exist", QueryableStoreTypes.keyValueStore())); + } + + @Test + public void shouldThrowInvalidStoreExceptionIfNoPartitionFound() { + final int invalidPartition = numStateStorePartitions + 1; + wrappingStoreProvider.setStoreQueryParameters(StoreQueryParameters.fromNameAndType("kv", QueryableStoreTypes.keyValueStore()).withPartition(invalidPartition)); + assertThrows(InvalidStateStorePartitionException.class, () -> wrappingStoreProvider.stores("kv", QueryableStoreTypes.keyValueStore())); + } + + @Test + public void shouldReturnAllStoreWhenQueryWithoutPartition() { + wrappingStoreProvider.setStoreQueryParameters(StoreQueryParameters.fromNameAndType("kv", QueryableStoreTypes.keyValueStore())); + final List> results = + wrappingStoreProvider.stores("kv", QueryableStoreTypes.keyValueStore()); + assertEquals(numStateStorePartitions, results.size()); + } + + @Test + public void 
shouldReturnSingleStoreWhenQueryWithPartition() { + wrappingStoreProvider.setStoreQueryParameters(StoreQueryParameters.fromNameAndType("kv", QueryableStoreTypes.keyValueStore()).withPartition(numStateStorePartitions - 1)); + final List> results = + wrappingStoreProvider.stores("kv", QueryableStoreTypes.keyValueStore()); + assertEquals(1, results.size()); + } +} \ No newline at end of file diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/metrics/NamedCacheMetricsTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/metrics/NamedCacheMetricsTest.java new file mode 100644 index 0000000..0b525db --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/metrics/NamedCacheMetricsTest.java @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.state.internals.metrics; + +import org.apache.kafka.common.metrics.Sensor; +import org.apache.kafka.common.metrics.Sensor.RecordingLevel; +import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl; +import org.junit.Test; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import java.util.Map; + +import static org.apache.kafka.common.utils.Utils.mkEntry; +import static org.apache.kafka.common.utils.Utils.mkMap; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.MatcherAssert.assertThat; + +public class NamedCacheMetricsTest { + + private static final String THREAD_ID = "test-thread"; + private static final String TASK_ID = "test-task"; + private static final String STORE_NAME = "storeName"; + private static final String HIT_RATIO_AVG_DESCRIPTION = "The average cache hit ratio"; + private static final String HIT_RATIO_MIN_DESCRIPTION = "The minimum cache hit ratio"; + private static final String HIT_RATIO_MAX_DESCRIPTION = "The maximum cache hit ratio"; + + private final StreamsMetricsImpl streamsMetrics = mock(StreamsMetricsImpl.class); + private final Sensor expectedSensor = mock(Sensor.class); + private final Map tagMap = mkMap(mkEntry("key", "value")); + + @Test + public void shouldGetHitRatioSensorWithBuiltInMetricsVersionCurrent() { + final String hitRatio = "hit-ratio"; + when(streamsMetrics.cacheLevelSensor(THREAD_ID, TASK_ID, STORE_NAME, hitRatio, RecordingLevel.DEBUG)).thenReturn(expectedSensor); + when(streamsMetrics.cacheLevelTagMap(THREAD_ID, TASK_ID, STORE_NAME)).thenReturn(tagMap); + StreamsMetricsImpl.addAvgAndMinAndMaxToSensor( + expectedSensor, + StreamsMetricsImpl.CACHE_LEVEL_GROUP, + tagMap, + hitRatio, + HIT_RATIO_AVG_DESCRIPTION, + HIT_RATIO_MIN_DESCRIPTION, + HIT_RATIO_MAX_DESCRIPTION); + + final Sensor sensor = NamedCacheMetrics.hitRatioSensor(streamsMetrics, THREAD_ID, TASK_ID, STORE_NAME); + + 
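+ // NamedCacheMetrics.hitRatioSensor should hand back the sensor obtained from StreamsMetricsImpl, with the avg/min/max hit-ratio metrics registered on it.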
assertThat(sensor, is(expectedSensor)); + } + +} diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/metrics/RocksDBMetricsRecorderGaugesTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/metrics/RocksDBMetricsRecorderGaugesTest.java new file mode 100644 index 0000000..2695b86 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/metrics/RocksDBMetricsRecorderGaugesTest.java @@ -0,0 +1,267 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.state.internals.metrics; + +import org.apache.kafka.common.Metric; +import org.apache.kafka.common.MetricName; +import org.apache.kafka.common.metrics.KafkaMetric; +import org.apache.kafka.common.metrics.Metrics; +import org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.processor.TaskId; +import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl; +import org.junit.Test; +import org.rocksdb.Cache; +import org.rocksdb.RocksDB; +import org.rocksdb.Statistics; + +import java.math.BigInteger; +import java.util.Map; + +import static org.apache.kafka.common.utils.Utils.mkEntry; +import static org.apache.kafka.common.utils.Utils.mkMap; +import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.STATE_STORE_LEVEL_GROUP; +import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.STORE_ID_TAG; +import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.TASK_ID_TAG; +import static org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.THREAD_ID_TAG; +import static org.apache.kafka.streams.state.internals.metrics.RocksDBMetrics.CAPACITY_OF_BLOCK_CACHE; +import static org.apache.kafka.streams.state.internals.metrics.RocksDBMetrics.COMPACTION_PENDING; +import static org.apache.kafka.streams.state.internals.metrics.RocksDBMetrics.CURRENT_SIZE_OF_ACTIVE_MEMTABLE; +import static org.apache.kafka.streams.state.internals.metrics.RocksDBMetrics.CURRENT_SIZE_OF_ALL_MEMTABLES; +import static org.apache.kafka.streams.state.internals.metrics.RocksDBMetrics.ESTIMATED_BYTES_OF_PENDING_COMPACTION; +import static org.apache.kafka.streams.state.internals.metrics.RocksDBMetrics.ESTIMATED_MEMORY_OF_TABLE_READERS; +import static org.apache.kafka.streams.state.internals.metrics.RocksDBMetrics.ESTIMATED_NUMBER_OF_KEYS; +import static org.apache.kafka.streams.state.internals.metrics.RocksDBMetrics.LIVE_SST_FILES_SIZE; +import static org.apache.kafka.streams.state.internals.metrics.RocksDBMetrics.MEMTABLE_FLUSH_PENDING; +import static org.apache.kafka.streams.state.internals.metrics.RocksDBMetrics.NUMBER_OF_DELETES_ACTIVE_MEMTABLE; +import 
static org.apache.kafka.streams.state.internals.metrics.RocksDBMetrics.NUMBER_OF_DELETES_IMMUTABLE_MEMTABLES; +import static org.apache.kafka.streams.state.internals.metrics.RocksDBMetrics.NUMBER_OF_ENTRIES_ACTIVE_MEMTABLE; +import static org.apache.kafka.streams.state.internals.metrics.RocksDBMetrics.NUMBER_OF_ENTRIES_IMMUTABLE_MEMTABLES; +import static org.apache.kafka.streams.state.internals.metrics.RocksDBMetrics.NUMBER_OF_IMMUTABLE_MEMTABLES; +import static org.apache.kafka.streams.state.internals.metrics.RocksDBMetrics.NUMBER_OF_RUNNING_COMPACTIONS; +import static org.apache.kafka.streams.state.internals.metrics.RocksDBMetrics.NUMBER_OF_RUNNING_FLUSHES; +import static org.apache.kafka.streams.state.internals.metrics.RocksDBMetrics.PINNED_USAGE_OF_BLOCK_CACHE; +import static org.apache.kafka.streams.state.internals.metrics.RocksDBMetrics.SIZE_OF_ALL_MEMTABLES; +import static org.apache.kafka.streams.state.internals.metrics.RocksDBMetrics.NUMBER_OF_BACKGROUND_ERRORS; +import static org.apache.kafka.streams.state.internals.metrics.RocksDBMetrics.TOTAL_SST_FILES_SIZE; +import static org.apache.kafka.streams.state.internals.metrics.RocksDBMetrics.USAGE_OF_BLOCK_CACHE; +import static org.easymock.EasyMock.expect; +import static org.easymock.EasyMock.mock; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.CoreMatchers.notNullValue; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.powermock.api.easymock.PowerMock.replay; + +public class RocksDBMetricsRecorderGaugesTest { + private static final String METRICS_SCOPE = "metrics-scope"; + private static final TaskId TASK_ID = new TaskId(0, 0); + private static final String STORE_NAME = "store-name"; + private static final String SEGMENT_STORE_NAME_1 = "segment-store-name-1"; + private static final String SEGMENT_STORE_NAME_2 = "segment-store-name-2"; + private static final String ROCKSDB_PROPERTIES_PREFIX = "rocksdb."; + + private final RocksDB dbToAdd1 = mock(RocksDB.class); + private final RocksDB dbToAdd2 = mock(RocksDB.class); + private final Cache cacheToAdd1 = mock(Cache.class); + private final Cache cacheToAdd2 = mock(Cache.class); + private final Statistics statisticsToAdd1 = mock(Statistics.class); + private final Statistics statisticsToAdd2 = mock(Statistics.class); + + @Test + public void shouldGetNumberOfImmutableMemTables() throws Exception { + runAndVerifySumOfProperties(NUMBER_OF_IMMUTABLE_MEMTABLES); + } + + @Test + public void shouldGetCurrentSizeofActiveMemTable() throws Exception { + runAndVerifySumOfProperties(CURRENT_SIZE_OF_ACTIVE_MEMTABLE); + } + + @Test + public void shouldGetCurrentSizeofAllMemTables() throws Exception { + runAndVerifySumOfProperties(CURRENT_SIZE_OF_ALL_MEMTABLES); + } + + @Test + public void shouldGetSizeofAllMemTables() throws Exception { + runAndVerifySumOfProperties(SIZE_OF_ALL_MEMTABLES); + } + + @Test + public void shouldGetNumberOfEntriesActiveMemTable() throws Exception { + runAndVerifySumOfProperties(NUMBER_OF_ENTRIES_ACTIVE_MEMTABLE); + } + + @Test + public void shouldGetNumberOfDeletesActiveMemTable() throws Exception { + runAndVerifySumOfProperties(NUMBER_OF_DELETES_ACTIVE_MEMTABLE); + } + + @Test + public void shouldGetNumberOfEntriesImmutableMemTables() throws Exception { + runAndVerifySumOfProperties(NUMBER_OF_ENTRIES_IMMUTABLE_MEMTABLES); + } + + @Test + public void shouldGetNumberOfDeletesImmutableMemTables() throws Exception { + runAndVerifySumOfProperties(NUMBER_OF_DELETES_IMMUTABLE_MEMTABLES); + } + + @Test + public void 
shouldGetMemTableFlushPending() throws Exception { + runAndVerifySumOfProperties(MEMTABLE_FLUSH_PENDING); + } + + @Test + public void shouldGetNumberOfRunningFlushes() throws Exception { + runAndVerifySumOfProperties(NUMBER_OF_RUNNING_FLUSHES); + } + + @Test + public void shouldGetCompactionPending() throws Exception { + runAndVerifySumOfProperties(COMPACTION_PENDING); + } + + @Test + public void shouldGetNumberOfRunningCompactions() throws Exception { + runAndVerifySumOfProperties(NUMBER_OF_RUNNING_COMPACTIONS); + } + + @Test + public void shouldGetEstimatedBytesOfPendingCompactions() throws Exception { + runAndVerifySumOfProperties(ESTIMATED_BYTES_OF_PENDING_COMPACTION); + } + + @Test + public void shouldGetTotalSstFilesSize() throws Exception { + runAndVerifySumOfProperties(TOTAL_SST_FILES_SIZE); + } + + @Test + public void shouldGetLiveSstFilesSize() throws Exception { + runAndVerifySumOfProperties(LIVE_SST_FILES_SIZE); + } + + @Test + public void shouldGetEstimatedNumberOfKeys() throws Exception { + runAndVerifySumOfProperties(ESTIMATED_NUMBER_OF_KEYS); + } + + @Test + public void shouldGetEstimatedMemoryOfTableReaders() throws Exception { + runAndVerifySumOfProperties(ESTIMATED_MEMORY_OF_TABLE_READERS); + } + + @Test + public void shouldGetNumberOfBackgroundErrors() throws Exception { + runAndVerifySumOfProperties(NUMBER_OF_BACKGROUND_ERRORS); + } + + @Test + public void shouldGetCapacityOfBlockCacheWithMultipleCaches() throws Exception { + runAndVerifyBlockCacheMetricsWithMultipleCaches(CAPACITY_OF_BLOCK_CACHE); + } + + @Test + public void shouldGetCapacityOfBlockCacheWithSingleCache() throws Exception { + runAndVerifyBlockCacheMetricsWithSingleCache(CAPACITY_OF_BLOCK_CACHE); + } + + @Test + public void shouldGetUsageOfBlockCacheWithMultipleCaches() throws Exception { + runAndVerifyBlockCacheMetricsWithMultipleCaches(USAGE_OF_BLOCK_CACHE); + } + + @Test + public void shouldGetUsageOfBlockCacheWithSingleCache() throws Exception { + runAndVerifyBlockCacheMetricsWithSingleCache(USAGE_OF_BLOCK_CACHE); + } + + @Test + public void shouldGetPinnedUsageOfBlockCacheWithMultipleCaches() throws Exception { + runAndVerifyBlockCacheMetricsWithMultipleCaches(PINNED_USAGE_OF_BLOCK_CACHE); + } + + @Test + public void shouldGetPinnedUsageOfBlockCacheWithSingleCache() throws Exception { + runAndVerifyBlockCacheMetricsWithSingleCache(PINNED_USAGE_OF_BLOCK_CACHE); + } + + private void runAndVerifySumOfProperties(final String propertyName) throws Exception { + final StreamsMetricsImpl streamsMetrics = + new StreamsMetricsImpl(new Metrics(), "test-client", StreamsConfig.METRICS_LATEST, new MockTime()); + final RocksDBMetricsRecorder recorder = new RocksDBMetricsRecorder(METRICS_SCOPE, STORE_NAME); + + recorder.init(streamsMetrics, TASK_ID); + recorder.addValueProviders(SEGMENT_STORE_NAME_1, dbToAdd1, cacheToAdd1, statisticsToAdd1); + recorder.addValueProviders(SEGMENT_STORE_NAME_2, dbToAdd2, cacheToAdd2, statisticsToAdd2); + + final long recordedValue1 = 5L; + final long recordedValue2 = 3L; + expect(dbToAdd1.getAggregatedLongProperty(ROCKSDB_PROPERTIES_PREFIX + propertyName)) + .andStubReturn(recordedValue1); + expect(dbToAdd2.getAggregatedLongProperty(ROCKSDB_PROPERTIES_PREFIX + propertyName)) + .andStubReturn(recordedValue2); + replay(dbToAdd1, dbToAdd2); + + verifyMetrics(streamsMetrics, propertyName, recordedValue1 + recordedValue2); + } + + private void runAndVerifyBlockCacheMetricsWithMultipleCaches(final String propertyName) throws Exception { + runAndVerifySumOfProperties(propertyName); + } + + 
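+ // Unlike the multiple-caches case above, a single cache shared by all segments should be reported once rather than summed across segments, so the expected value equals the per-cache reading.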
private void runAndVerifyBlockCacheMetricsWithSingleCache(final String propertyName) throws Exception { + final StreamsMetricsImpl streamsMetrics = + new StreamsMetricsImpl(new Metrics(), "test-client", StreamsConfig.METRICS_LATEST, new MockTime()); + final RocksDBMetricsRecorder recorder = new RocksDBMetricsRecorder(METRICS_SCOPE, STORE_NAME); + + recorder.init(streamsMetrics, TASK_ID); + recorder.addValueProviders(SEGMENT_STORE_NAME_1, dbToAdd1, cacheToAdd1, statisticsToAdd1); + recorder.addValueProviders(SEGMENT_STORE_NAME_2, dbToAdd2, cacheToAdd1, statisticsToAdd2); + + final long recordedValue = 5L; + expect(dbToAdd1.getAggregatedLongProperty(ROCKSDB_PROPERTIES_PREFIX + propertyName)) + .andStubReturn(recordedValue); + expect(dbToAdd2.getAggregatedLongProperty(ROCKSDB_PROPERTIES_PREFIX + propertyName)) + .andStubReturn(recordedValue); + replay(dbToAdd1, dbToAdd2); + + verifyMetrics(streamsMetrics, propertyName, recordedValue); + } + + private void verifyMetrics(final StreamsMetricsImpl streamsMetrics, + final String propertyName, + final long expectedValue) { + + final Map metrics = streamsMetrics.metrics(); + final Map tagMap = mkMap( + mkEntry(THREAD_ID_TAG, Thread.currentThread().getName()), + mkEntry(TASK_ID_TAG, TASK_ID.toString()), + mkEntry(METRICS_SCOPE + "-" + STORE_ID_TAG, STORE_NAME) + ); + final KafkaMetric metric = (KafkaMetric) metrics.get(new MetricName( + propertyName, + STATE_STORE_LEVEL_GROUP, + "description is ignored", + tagMap + )); + + assertThat(metric, notNullValue()); + assertThat(metric.metricValue(), is(BigInteger.valueOf(expectedValue))); + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/metrics/RocksDBMetricsRecorderTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/metrics/RocksDBMetricsRecorderTest.java new file mode 100644 index 0000000..dc08f84 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/metrics/RocksDBMetricsRecorderTest.java @@ -0,0 +1,638 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.state.internals.metrics; + +import org.apache.kafka.common.metrics.Metrics; +import org.apache.kafka.common.metrics.Sensor; +import org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.processor.TaskId; +import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl; +import org.apache.kafka.streams.state.internals.metrics.RocksDBMetrics.RocksDBMetricContext; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.powermock.core.classloader.annotations.PrepareForTest; +import org.powermock.modules.junit4.PowerMockRunner; +import org.rocksdb.Cache; +import org.rocksdb.RocksDB; +import org.rocksdb.Statistics; +import org.rocksdb.StatsLevel; +import org.rocksdb.TickerType; + +import static org.easymock.EasyMock.anyObject; +import static org.easymock.EasyMock.eq; +import static org.easymock.EasyMock.expect; +import static org.easymock.EasyMock.mock; +import static org.easymock.EasyMock.niceMock; +import static org.easymock.EasyMock.resetToNice; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.Assert.assertThrows; +import static org.powermock.api.easymock.PowerMock.reset; +import static org.powermock.api.easymock.PowerMock.createMock; +import static org.powermock.api.easymock.PowerMock.mockStatic; +import static org.powermock.api.easymock.PowerMock.replay; +import static org.powermock.api.easymock.PowerMock.verify; + +@RunWith(PowerMockRunner.class) +@PrepareForTest({RocksDBMetrics.class, Sensor.class}) +public class RocksDBMetricsRecorderTest { + private final static String METRICS_SCOPE = "metrics-scope"; + private final static String THREAD_ID = "thread-id"; + private final static TaskId TASK_ID1 = new TaskId(0, 0); + private final static TaskId TASK_ID2 = new TaskId(0, 1); + private final static String STORE_NAME = "store-name"; + private final static String SEGMENT_STORE_NAME_1 = "segment-store-name-1"; + private final static String SEGMENT_STORE_NAME_2 = "segment-store-name-2"; + private final static String SEGMENT_STORE_NAME_3 = "segment-store-name-3"; + + private final RocksDB dbToAdd1 = mock(RocksDB.class); + private final RocksDB dbToAdd2 = mock(RocksDB.class); + private final RocksDB dbToAdd3 = mock(RocksDB.class); + private final Cache cacheToAdd1 = mock(Cache.class); + private final Cache cacheToAdd2 = mock(Cache.class); + private final Statistics statisticsToAdd1 = mock(Statistics.class); + private final Statistics statisticsToAdd2 = mock(Statistics.class); + private final Statistics statisticsToAdd3 = mock(Statistics.class); + + private final Sensor bytesWrittenToDatabaseSensor = createMock(Sensor.class); + private final Sensor bytesReadFromDatabaseSensor = createMock(Sensor.class); + private final Sensor memtableBytesFlushedSensor = createMock(Sensor.class); + private final Sensor memtableHitRatioSensor = createMock(Sensor.class); + private final Sensor writeStallDurationSensor = createMock(Sensor.class); + private final Sensor blockCacheDataHitRatioSensor = createMock(Sensor.class); + private final Sensor blockCacheIndexHitRatioSensor = createMock(Sensor.class); + private final Sensor blockCacheFilterHitRatioSensor = createMock(Sensor.class); + private final Sensor bytesReadDuringCompactionSensor = createMock(Sensor.class); + private final Sensor bytesWrittenDuringCompactionSensor = createMock(Sensor.class); + private final Sensor 
numberOfOpenFilesSensor = createMock(Sensor.class); + private final Sensor numberOfFileErrorsSensor = createMock(Sensor.class); + + private final StreamsMetricsImpl streamsMetrics = niceMock(StreamsMetricsImpl.class); + private final RocksDBMetricsRecordingTrigger recordingTrigger = mock(RocksDBMetricsRecordingTrigger.class); + + private final RocksDBMetricsRecorder recorder = new RocksDBMetricsRecorder(METRICS_SCOPE, STORE_NAME); + + @Before + public void setUp() { + setUpMetricsStubMock(); + expect(streamsMetrics.rocksDBMetricsRecordingTrigger()).andStubReturn(recordingTrigger); + replay(streamsMetrics); + recorder.init(streamsMetrics, TASK_ID1); + } + + @Test + public void shouldInitMetricsRecorder() { + setUpMetricsMock(); + + recorder.init(streamsMetrics, TASK_ID1); + + verify(RocksDBMetrics.class); + assertThat(recorder.taskId(), is(TASK_ID1)); + } + + @Test + public void shouldThrowIfMetricRecorderIsReInitialisedWithDifferentTask() { + setUpMetricsStubMock(); + recorder.init(streamsMetrics, TASK_ID1); + + assertThrows( + IllegalStateException.class, + () -> recorder.init(streamsMetrics, TASK_ID2) + ); + } + + @Test + public void shouldThrowIfMetricRecorderIsReInitialisedWithDifferentStreamsMetrics() { + setUpMetricsStubMock(); + recorder.init(streamsMetrics, TASK_ID1); + + assertThrows( + IllegalStateException.class, + () -> recorder.init( + new StreamsMetricsImpl(new Metrics(), "test-client", StreamsConfig.METRICS_LATEST, new MockTime()), + TASK_ID1 + ) + ); + } + + @Test + public void shouldSetStatsLevelToExceptDetailedTimersWhenValueProvidersWithStatisticsAreAdded() { + statisticsToAdd1.setStatsLevel(StatsLevel.EXCEPT_DETAILED_TIMERS); + replay(statisticsToAdd1); + + recorder.addValueProviders(SEGMENT_STORE_NAME_1, dbToAdd1, cacheToAdd1, statisticsToAdd1); + + verify(statisticsToAdd1); + } + + @Test + public void shouldNotSetStatsLevelToExceptDetailedTimersWhenValueProvidersWithoutStatisticsAreAdded() { + replay(statisticsToAdd1); + + recorder.addValueProviders(SEGMENT_STORE_NAME_1, dbToAdd1, cacheToAdd1, null); + + verify(statisticsToAdd1); + } + + @Test + public void shouldThrowIfValueProvidersForASegmentHasBeenAlreadyAdded() { + recorder.addValueProviders(SEGMENT_STORE_NAME_1, dbToAdd1, cacheToAdd1, statisticsToAdd1); + + final Throwable exception = assertThrows( + IllegalStateException.class, + () -> recorder.addValueProviders(SEGMENT_STORE_NAME_1, dbToAdd1, cacheToAdd1, statisticsToAdd2) + ); + assertThat( + exception.getMessage(), + is("Value providers for store " + SEGMENT_STORE_NAME_1 + " of task " + TASK_ID1 + + " has been already added. This is a bug in Kafka Streams. " + + "Please open a bug report under https://issues.apache.org/jira/projects/KAFKA/issues") + ); + } + + @Test + public void shouldThrowIfStatisticsToAddIsNotNullButExsitingStatisticsAreNull() { + recorder.addValueProviders(SEGMENT_STORE_NAME_1, dbToAdd1, cacheToAdd1, null); + + final Throwable exception = assertThrows( + IllegalStateException.class, + () -> recorder.addValueProviders(SEGMENT_STORE_NAME_2, dbToAdd2, cacheToAdd2, statisticsToAdd2) + ); + assertThat( + exception.getMessage(), + is("Statistics for segment " + SEGMENT_STORE_NAME_2 + " of task " + TASK_ID1 + + " is not null although the statistics of another segment in this metrics recorder is null. " + + "This is a bug in Kafka Streams. 
" + + "Please open a bug report under https://issues.apache.org/jira/projects/KAFKA/issues") + ); + } + + @Test + public void shouldThrowIfStatisticsToAddIsNullButExsitingStatisticsAreNotNull() { + recorder.addValueProviders(SEGMENT_STORE_NAME_1, dbToAdd1, cacheToAdd1, statisticsToAdd1); + + final Throwable exception = assertThrows( + IllegalStateException.class, + () -> recorder.addValueProviders(SEGMENT_STORE_NAME_2, dbToAdd2, cacheToAdd2, null) + ); + assertThat( + exception.getMessage(), + is("Statistics for segment " + SEGMENT_STORE_NAME_2 + " of task " + TASK_ID1 + + " is null although the statistics of another segment in this metrics recorder is not null. " + + "This is a bug in Kafka Streams. " + + "Please open a bug report under https://issues.apache.org/jira/projects/KAFKA/issues") + ); + } + + @Test + public void shouldThrowIfCacheToAddIsNullButExsitingCacheIsNotNull() { + recorder.addValueProviders(SEGMENT_STORE_NAME_1, dbToAdd1, null, statisticsToAdd1); + + final Throwable exception = assertThrows( + IllegalStateException.class, + () -> recorder.addValueProviders(SEGMENT_STORE_NAME_2, dbToAdd2, cacheToAdd1, statisticsToAdd1) + ); + assertThat( + exception.getMessage(), + is("Cache for segment " + SEGMENT_STORE_NAME_2 + " of task " + TASK_ID1 + + " is not null although the cache of another segment in this metrics recorder is null. " + + "This is a bug in Kafka Streams. " + + "Please open a bug report under https://issues.apache.org/jira/projects/KAFKA/issues") + ); + } + + @Test + public void shouldThrowIfCacheToAddIsNotNullButExistingCacheIsNull() { + recorder.addValueProviders(SEGMENT_STORE_NAME_1, dbToAdd1, cacheToAdd1, statisticsToAdd1); + + final Throwable exception = assertThrows( + IllegalStateException.class, + () -> recorder.addValueProviders(SEGMENT_STORE_NAME_2, dbToAdd2, null, statisticsToAdd2) + ); + assertThat( + exception.getMessage(), + is("Cache for segment " + SEGMENT_STORE_NAME_2 + " of task " + TASK_ID1 + + " is null although the cache of another segment in this metrics recorder is not null. " + + "This is a bug in Kafka Streams. " + + "Please open a bug report under https://issues.apache.org/jira/projects/KAFKA/issues") + ); + } + + @Test + public void shouldThrowIfCacheToAddIsNotSameAsAllExistingCaches() { + recorder.addValueProviders(SEGMENT_STORE_NAME_1, dbToAdd1, cacheToAdd1, statisticsToAdd1); + recorder.addValueProviders(SEGMENT_STORE_NAME_2, dbToAdd2, cacheToAdd1, statisticsToAdd2); + + final Throwable exception = assertThrows( + IllegalStateException.class, + () -> recorder.addValueProviders(SEGMENT_STORE_NAME_3, dbToAdd3, cacheToAdd2, statisticsToAdd3) + ); + assertThat( + exception.getMessage(), + is("Caches for store " + STORE_NAME + " of task " + TASK_ID1 + + " are either not all distinct or do not all refer to the same cache. This is a bug in Kafka Streams. 
" + + "Please open a bug report under https://issues.apache.org/jira/projects/KAFKA/issues") + ); + } + + @Test + public void shouldThrowIfCacheToAddIsSameAsOnlyOneOfMultipleCaches() { + recorder.addValueProviders(SEGMENT_STORE_NAME_1, dbToAdd1, cacheToAdd1, statisticsToAdd1); + recorder.addValueProviders(SEGMENT_STORE_NAME_2, dbToAdd2, cacheToAdd2, statisticsToAdd2); + + final Throwable exception = assertThrows( + IllegalStateException.class, + () -> recorder.addValueProviders(SEGMENT_STORE_NAME_3, dbToAdd3, cacheToAdd1, statisticsToAdd3) + ); + assertThat( + exception.getMessage(), + is("Caches for store " + STORE_NAME + " of task " + TASK_ID1 + + " are either not all distinct or do not all refer to the same cache. This is a bug in Kafka Streams. " + + "Please open a bug report under https://issues.apache.org/jira/projects/KAFKA/issues") + ); + } + + @Test + public void shouldThrowIfDbToAddWasAlreadyAddedForOtherSegment() { + recorder.addValueProviders(SEGMENT_STORE_NAME_1, dbToAdd1, cacheToAdd1, statisticsToAdd1); + + final Throwable exception = assertThrows( + IllegalStateException.class, + () -> recorder.addValueProviders(SEGMENT_STORE_NAME_2, dbToAdd1, cacheToAdd2, statisticsToAdd2) + ); + assertThat( + exception.getMessage(), + is("DB instance for store " + SEGMENT_STORE_NAME_2 + " of task " + TASK_ID1 + + " was already added for another segment as a value provider. This is a bug in Kafka Streams. " + + "Please open a bug report under https://issues.apache.org/jira/projects/KAFKA/issues") + ); + } + + @Test + public void shouldAddItselfToRecordingTriggerWhenFirstValueProvidersAreAddedToNewlyCreatedRecorder() { + recordingTrigger.addMetricsRecorder(recorder); + replay(recordingTrigger); + + recorder.addValueProviders(SEGMENT_STORE_NAME_1, dbToAdd1, cacheToAdd1, statisticsToAdd1); + + verify(recordingTrigger); + } + + @Test + public void shouldAddItselfToRecordingTriggerWhenFirstValueProvidersAreAddedAfterLastValueProvidersWereRemoved() { + recorder.addValueProviders(SEGMENT_STORE_NAME_1, dbToAdd1, cacheToAdd1, statisticsToAdd1); + recorder.removeValueProviders(SEGMENT_STORE_NAME_1); + reset(recordingTrigger); + recordingTrigger.addMetricsRecorder(recorder); + replay(recordingTrigger); + + recorder.addValueProviders(SEGMENT_STORE_NAME_2, dbToAdd2, cacheToAdd2, statisticsToAdd2); + + verify(recordingTrigger); + } + + @Test + public void shouldNotAddItselfToRecordingTriggerWhenNotEmpty2() { + recorder.addValueProviders(SEGMENT_STORE_NAME_1, dbToAdd1, cacheToAdd1, statisticsToAdd1); + reset(recordingTrigger); + replay(recordingTrigger); + + recorder.addValueProviders(SEGMENT_STORE_NAME_2, dbToAdd2, cacheToAdd2, statisticsToAdd2); + + verify(recordingTrigger); + } + + @Test + public void shouldCloseStatisticsWhenValueProvidersAreRemoved() { + recorder.addValueProviders(SEGMENT_STORE_NAME_1, dbToAdd1, cacheToAdd1, statisticsToAdd1); + reset(statisticsToAdd1); + statisticsToAdd1.close(); + replay(statisticsToAdd1); + + recorder.removeValueProviders(SEGMENT_STORE_NAME_1); + + verify(statisticsToAdd1); + } + + @Test + public void shouldNotCloseStatisticsWhenValueProvidersWithoutStatisticsAreRemoved() { + recorder.addValueProviders(SEGMENT_STORE_NAME_1, dbToAdd1, cacheToAdd1, null); + reset(statisticsToAdd1); + replay(statisticsToAdd1); + + recorder.removeValueProviders(SEGMENT_STORE_NAME_1); + + verify(statisticsToAdd1); + } + + @Test + public void shouldRemoveItselfFromRecordingTriggerWhenLastValueProvidersAreRemoved() { + recorder.addValueProviders(SEGMENT_STORE_NAME_1, dbToAdd1, 
cacheToAdd1, statisticsToAdd1); + recorder.addValueProviders(SEGMENT_STORE_NAME_2, dbToAdd2, cacheToAdd2, statisticsToAdd2); + reset(recordingTrigger); + replay(recordingTrigger); + + recorder.removeValueProviders(SEGMENT_STORE_NAME_1); + + verify(recordingTrigger); + + reset(recordingTrigger); + recordingTrigger.removeMetricsRecorder(recorder); + replay(recordingTrigger); + + recorder.removeValueProviders(SEGMENT_STORE_NAME_2); + + verify(recordingTrigger); + } + + @Test + public void shouldThrowIfValueProvidersToRemoveNotFound() { + recorder.addValueProviders(SEGMENT_STORE_NAME_1, dbToAdd1, cacheToAdd1, statisticsToAdd1); + + assertThrows( + IllegalStateException.class, + () -> recorder.removeValueProviders(SEGMENT_STORE_NAME_2) + ); + } + + @Test + public void shouldRecordStatisticsBasedMetrics() { + recorder.addValueProviders(SEGMENT_STORE_NAME_1, dbToAdd1, cacheToAdd1, statisticsToAdd1); + recorder.addValueProviders(SEGMENT_STORE_NAME_2, dbToAdd2, cacheToAdd2, statisticsToAdd2); + reset(statisticsToAdd1); + reset(statisticsToAdd2); + + expect(statisticsToAdd1.getAndResetTickerCount(TickerType.BYTES_WRITTEN)).andReturn(1L); + expect(statisticsToAdd2.getAndResetTickerCount(TickerType.BYTES_WRITTEN)).andReturn(2L); + bytesWrittenToDatabaseSensor.record(1 + 2, 0L); + replay(bytesWrittenToDatabaseSensor); + + expect(statisticsToAdd1.getAndResetTickerCount(TickerType.BYTES_READ)).andReturn(2L); + expect(statisticsToAdd2.getAndResetTickerCount(TickerType.BYTES_READ)).andReturn(3L); + bytesReadFromDatabaseSensor.record(2 + 3, 0L); + replay(bytesReadFromDatabaseSensor); + + expect(statisticsToAdd1.getAndResetTickerCount(TickerType.FLUSH_WRITE_BYTES)).andReturn(3L); + expect(statisticsToAdd2.getAndResetTickerCount(TickerType.FLUSH_WRITE_BYTES)).andReturn(4L); + memtableBytesFlushedSensor.record(3 + 4, 0L); + replay(memtableBytesFlushedSensor); + + expect(statisticsToAdd1.getAndResetTickerCount(TickerType.MEMTABLE_HIT)).andReturn(1L); + expect(statisticsToAdd1.getAndResetTickerCount(TickerType.MEMTABLE_MISS)).andReturn(2L); + expect(statisticsToAdd2.getAndResetTickerCount(TickerType.MEMTABLE_HIT)).andReturn(3L); + expect(statisticsToAdd2.getAndResetTickerCount(TickerType.MEMTABLE_MISS)).andReturn(4L); + memtableHitRatioSensor.record((double) 4 / (4 + 6), 0L); + replay(memtableHitRatioSensor); + + expect(statisticsToAdd1.getAndResetTickerCount(TickerType.STALL_MICROS)).andReturn(4L); + expect(statisticsToAdd2.getAndResetTickerCount(TickerType.STALL_MICROS)).andReturn(5L); + writeStallDurationSensor.record(4 + 5, 0L); + replay(writeStallDurationSensor); + + expect(statisticsToAdd1.getAndResetTickerCount(TickerType.BLOCK_CACHE_DATA_HIT)).andReturn(5L); + expect(statisticsToAdd1.getAndResetTickerCount(TickerType.BLOCK_CACHE_DATA_MISS)).andReturn(4L); + expect(statisticsToAdd2.getAndResetTickerCount(TickerType.BLOCK_CACHE_DATA_HIT)).andReturn(3L); + expect(statisticsToAdd2.getAndResetTickerCount(TickerType.BLOCK_CACHE_DATA_MISS)).andReturn(2L); + blockCacheDataHitRatioSensor.record((double) 8 / (8 + 6), 0L); + replay(blockCacheDataHitRatioSensor); + + expect(statisticsToAdd1.getAndResetTickerCount(TickerType.BLOCK_CACHE_INDEX_HIT)).andReturn(4L); + expect(statisticsToAdd1.getAndResetTickerCount(TickerType.BLOCK_CACHE_INDEX_MISS)).andReturn(2L); + expect(statisticsToAdd2.getAndResetTickerCount(TickerType.BLOCK_CACHE_INDEX_HIT)).andReturn(2L); + expect(statisticsToAdd2.getAndResetTickerCount(TickerType.BLOCK_CACHE_INDEX_MISS)).andReturn(4L); + blockCacheIndexHitRatioSensor.record((double) 6 / (6 + 6), 
0L); + replay(blockCacheIndexHitRatioSensor); + + expect(statisticsToAdd1.getAndResetTickerCount(TickerType.BLOCK_CACHE_FILTER_HIT)).andReturn(2L); + expect(statisticsToAdd1.getAndResetTickerCount(TickerType.BLOCK_CACHE_FILTER_MISS)).andReturn(4L); + expect(statisticsToAdd2.getAndResetTickerCount(TickerType.BLOCK_CACHE_FILTER_HIT)).andReturn(3L); + expect(statisticsToAdd2.getAndResetTickerCount(TickerType.BLOCK_CACHE_FILTER_MISS)).andReturn(5L); + blockCacheFilterHitRatioSensor.record((double) 5 / (5 + 9), 0L); + replay(blockCacheFilterHitRatioSensor); + + expect(statisticsToAdd1.getAndResetTickerCount(TickerType.COMPACT_WRITE_BYTES)).andReturn(2L); + expect(statisticsToAdd2.getAndResetTickerCount(TickerType.COMPACT_WRITE_BYTES)).andReturn(4L); + bytesWrittenDuringCompactionSensor.record(2 + 4, 0L); + replay(bytesWrittenDuringCompactionSensor); + + expect(statisticsToAdd1.getAndResetTickerCount(TickerType.COMPACT_READ_BYTES)).andReturn(5L); + expect(statisticsToAdd2.getAndResetTickerCount(TickerType.COMPACT_READ_BYTES)).andReturn(6L); + bytesReadDuringCompactionSensor.record(5 + 6, 0L); + replay(bytesReadDuringCompactionSensor); + + expect(statisticsToAdd1.getAndResetTickerCount(TickerType.NO_FILE_OPENS)).andReturn(5L); + expect(statisticsToAdd1.getAndResetTickerCount(TickerType.NO_FILE_CLOSES)).andReturn(3L); + expect(statisticsToAdd2.getAndResetTickerCount(TickerType.NO_FILE_OPENS)).andReturn(7L); + expect(statisticsToAdd2.getAndResetTickerCount(TickerType.NO_FILE_CLOSES)).andReturn(4L); + numberOfOpenFilesSensor.record((5 + 7) - (3 + 4), 0L); + replay(numberOfOpenFilesSensor); + + expect(statisticsToAdd1.getAndResetTickerCount(TickerType.NO_FILE_ERRORS)).andReturn(34L); + expect(statisticsToAdd2.getAndResetTickerCount(TickerType.NO_FILE_ERRORS)).andReturn(11L); + numberOfFileErrorsSensor.record(11 + 34, 0L); + replay(numberOfFileErrorsSensor); + + replay(statisticsToAdd1); + replay(statisticsToAdd2); + + recorder.record(0L); + + verify(statisticsToAdd1); + verify(statisticsToAdd2); + verify( + bytesWrittenToDatabaseSensor, + bytesReadFromDatabaseSensor, + memtableBytesFlushedSensor, + memtableHitRatioSensor, + writeStallDurationSensor, + blockCacheDataHitRatioSensor, + blockCacheIndexHitRatioSensor, + blockCacheFilterHitRatioSensor, + bytesWrittenDuringCompactionSensor, + bytesReadDuringCompactionSensor, + numberOfOpenFilesSensor, + numberOfFileErrorsSensor + ); + } + + @Test + public void shouldNotRecordStatisticsBasedMetricsIfStatisticsIsNull() { + recorder.addValueProviders(SEGMENT_STORE_NAME_1, dbToAdd1, cacheToAdd1, null); + replay( + bytesWrittenToDatabaseSensor, + bytesReadFromDatabaseSensor, + memtableBytesFlushedSensor, + memtableHitRatioSensor, + writeStallDurationSensor, + blockCacheDataHitRatioSensor, + blockCacheIndexHitRatioSensor, + blockCacheFilterHitRatioSensor, + bytesWrittenDuringCompactionSensor, + bytesReadDuringCompactionSensor, + numberOfOpenFilesSensor, + numberOfFileErrorsSensor + ); + + recorder.record(0L); + + verify( + bytesWrittenToDatabaseSensor, + bytesReadFromDatabaseSensor, + memtableBytesFlushedSensor, + memtableHitRatioSensor, + writeStallDurationSensor, + blockCacheDataHitRatioSensor, + blockCacheIndexHitRatioSensor, + blockCacheFilterHitRatioSensor, + bytesWrittenDuringCompactionSensor, + bytesReadDuringCompactionSensor, + numberOfOpenFilesSensor, + numberOfFileErrorsSensor + ); + } + + @Test + public void shouldCorrectlyHandleHitRatioRecordingsWithZeroHitsAndMisses() { + resetToNice(statisticsToAdd1); + 
recorder.addValueProviders(SEGMENT_STORE_NAME_1, dbToAdd1, cacheToAdd1, statisticsToAdd1); + expect(statisticsToAdd1.getTickerCount(anyObject())).andStubReturn(0L); + replay(statisticsToAdd1); + memtableHitRatioSensor.record(0, 0L); + blockCacheDataHitRatioSensor.record(0, 0L); + blockCacheIndexHitRatioSensor.record(0, 0L); + blockCacheFilterHitRatioSensor.record(0, 0L); + replay(memtableHitRatioSensor); + replay(blockCacheDataHitRatioSensor); + replay(blockCacheIndexHitRatioSensor); + replay(blockCacheFilterHitRatioSensor); + + recorder.record(0L); + + verify(memtableHitRatioSensor); + verify(blockCacheDataHitRatioSensor); + verify(blockCacheIndexHitRatioSensor); + verify(blockCacheFilterHitRatioSensor); + } + + private void setUpMetricsMock() { + mockStatic(RocksDBMetrics.class); + final RocksDBMetricContext metricsContext = + new RocksDBMetricContext(TASK_ID1.toString(), METRICS_SCOPE, STORE_NAME); + expect(RocksDBMetrics.bytesWrittenToDatabaseSensor(eq(streamsMetrics), eq(metricsContext))) + .andReturn(bytesWrittenToDatabaseSensor); + expect(RocksDBMetrics.bytesReadFromDatabaseSensor(eq(streamsMetrics), eq(metricsContext))) + .andReturn(bytesReadFromDatabaseSensor); + expect(RocksDBMetrics.memtableBytesFlushedSensor(eq(streamsMetrics), eq(metricsContext))) + .andReturn(memtableBytesFlushedSensor); + expect(RocksDBMetrics.memtableHitRatioSensor(eq(streamsMetrics), eq(metricsContext))) + .andReturn(memtableHitRatioSensor); + expect(RocksDBMetrics.writeStallDurationSensor(eq(streamsMetrics), eq(metricsContext))) + .andReturn(writeStallDurationSensor); + expect(RocksDBMetrics.blockCacheDataHitRatioSensor(eq(streamsMetrics), eq(metricsContext))) + .andReturn(blockCacheDataHitRatioSensor); + expect(RocksDBMetrics.blockCacheIndexHitRatioSensor(eq(streamsMetrics), eq(metricsContext))) + .andReturn(blockCacheIndexHitRatioSensor); + expect(RocksDBMetrics.blockCacheFilterHitRatioSensor(eq(streamsMetrics), eq(metricsContext))) + .andReturn(blockCacheFilterHitRatioSensor); + expect(RocksDBMetrics.bytesWrittenDuringCompactionSensor(eq(streamsMetrics), eq(metricsContext))) + .andReturn(bytesWrittenDuringCompactionSensor); + expect(RocksDBMetrics.bytesReadDuringCompactionSensor(eq(streamsMetrics), eq(metricsContext))) + .andReturn(bytesReadDuringCompactionSensor); + expect(RocksDBMetrics.numberOfOpenFilesSensor(eq(streamsMetrics), eq(metricsContext))) + .andReturn(numberOfOpenFilesSensor); + expect(RocksDBMetrics.numberOfFileErrorsSensor(eq(streamsMetrics), eq(metricsContext))) + .andReturn(numberOfFileErrorsSensor); + RocksDBMetrics.addNumImmutableMemTableMetric(eq(streamsMetrics), eq(metricsContext), anyObject()); + RocksDBMetrics.addCurSizeActiveMemTable(eq(streamsMetrics), eq(metricsContext), anyObject()); + RocksDBMetrics.addCurSizeAllMemTables(eq(streamsMetrics), eq(metricsContext), anyObject()); + RocksDBMetrics.addSizeAllMemTables(eq(streamsMetrics), eq(metricsContext), anyObject()); + RocksDBMetrics.addNumEntriesActiveMemTableMetric(eq(streamsMetrics), eq(metricsContext), anyObject()); + RocksDBMetrics.addNumEntriesImmMemTablesMetric(eq(streamsMetrics), eq(metricsContext), anyObject()); + RocksDBMetrics.addNumDeletesActiveMemTableMetric(eq(streamsMetrics), eq(metricsContext), anyObject()); + RocksDBMetrics.addNumDeletesImmMemTablesMetric(eq(streamsMetrics), eq(metricsContext), anyObject()); + RocksDBMetrics.addMemTableFlushPending(eq(streamsMetrics), eq(metricsContext), anyObject()); + RocksDBMetrics.addNumRunningFlushesMetric(eq(streamsMetrics), eq(metricsContext), anyObject()); + 
RocksDBMetrics.addCompactionPendingMetric(eq(streamsMetrics), eq(metricsContext), anyObject()); + RocksDBMetrics.addNumRunningCompactionsMetric(eq(streamsMetrics), eq(metricsContext), anyObject()); + RocksDBMetrics.addEstimatePendingCompactionBytesMetric(eq(streamsMetrics), eq(metricsContext), anyObject()); + RocksDBMetrics.addTotalSstFilesSizeMetric(eq(streamsMetrics), eq(metricsContext), anyObject()); + RocksDBMetrics.addLiveSstFilesSizeMetric(eq(streamsMetrics), eq(metricsContext), anyObject()); + RocksDBMetrics.addNumLiveVersionMetric(eq(streamsMetrics), eq(metricsContext), anyObject()); + RocksDBMetrics.addBlockCacheCapacityMetric(eq(streamsMetrics), eq(metricsContext), anyObject()); + RocksDBMetrics.addBlockCacheUsageMetric(eq(streamsMetrics), eq(metricsContext), anyObject()); + RocksDBMetrics.addBlockCachePinnedUsageMetric(eq(streamsMetrics), eq(metricsContext), anyObject()); + RocksDBMetrics.addEstimateNumKeysMetric(eq(streamsMetrics), eq(metricsContext), anyObject()); + RocksDBMetrics.addEstimateTableReadersMemMetric(eq(streamsMetrics), eq(metricsContext), anyObject()); + RocksDBMetrics.addBackgroundErrorsMetric(eq(streamsMetrics), eq(metricsContext), anyObject()); + replay(RocksDBMetrics.class); + } + + private void setUpMetricsStubMock() { + mockStatic(RocksDBMetrics.class); + final RocksDBMetricContext metricsContext = + new RocksDBMetricContext(TASK_ID1.toString(), METRICS_SCOPE, STORE_NAME); + expect(RocksDBMetrics.bytesWrittenToDatabaseSensor(streamsMetrics, metricsContext)) + .andStubReturn(bytesWrittenToDatabaseSensor); + expect(RocksDBMetrics.bytesReadFromDatabaseSensor(streamsMetrics, metricsContext)) + .andStubReturn(bytesReadFromDatabaseSensor); + expect(RocksDBMetrics.memtableBytesFlushedSensor(streamsMetrics, metricsContext)) + .andStubReturn(memtableBytesFlushedSensor); + expect(RocksDBMetrics.memtableHitRatioSensor(streamsMetrics, metricsContext)) + .andStubReturn(memtableHitRatioSensor); + expect(RocksDBMetrics.writeStallDurationSensor(streamsMetrics, metricsContext)) + .andStubReturn(writeStallDurationSensor); + expect(RocksDBMetrics.blockCacheDataHitRatioSensor(streamsMetrics, metricsContext)) + .andStubReturn(blockCacheDataHitRatioSensor); + expect(RocksDBMetrics.blockCacheIndexHitRatioSensor(streamsMetrics, metricsContext)) + .andStubReturn(blockCacheIndexHitRatioSensor); + expect(RocksDBMetrics.blockCacheFilterHitRatioSensor(streamsMetrics, metricsContext)) + .andStubReturn(blockCacheFilterHitRatioSensor); + expect(RocksDBMetrics.bytesWrittenDuringCompactionSensor(streamsMetrics, metricsContext)) + .andStubReturn(bytesWrittenDuringCompactionSensor); + expect(RocksDBMetrics.bytesReadDuringCompactionSensor(streamsMetrics, metricsContext)) + .andStubReturn(bytesReadDuringCompactionSensor); + expect(RocksDBMetrics.numberOfOpenFilesSensor(streamsMetrics, metricsContext)) + .andStubReturn(numberOfOpenFilesSensor); + expect(RocksDBMetrics.numberOfFileErrorsSensor(streamsMetrics, metricsContext)) + .andStubReturn(numberOfFileErrorsSensor); + RocksDBMetrics.addNumImmutableMemTableMetric(eq(streamsMetrics), eq(metricsContext), anyObject()); + RocksDBMetrics.addCurSizeActiveMemTable(eq(streamsMetrics), eq(metricsContext), anyObject()); + RocksDBMetrics.addCurSizeAllMemTables(eq(streamsMetrics), eq(metricsContext), anyObject()); + RocksDBMetrics.addSizeAllMemTables(eq(streamsMetrics), eq(metricsContext), anyObject()); + RocksDBMetrics.addNumEntriesActiveMemTableMetric(eq(streamsMetrics), eq(metricsContext), anyObject()); + 
RocksDBMetrics.addNumEntriesImmMemTablesMetric(eq(streamsMetrics), eq(metricsContext), anyObject()); + RocksDBMetrics.addNumDeletesActiveMemTableMetric(eq(streamsMetrics), eq(metricsContext), anyObject()); + RocksDBMetrics.addNumDeletesImmMemTablesMetric(eq(streamsMetrics), eq(metricsContext), anyObject()); + RocksDBMetrics.addMemTableFlushPending(eq(streamsMetrics), eq(metricsContext), anyObject()); + RocksDBMetrics.addNumRunningFlushesMetric(eq(streamsMetrics), eq(metricsContext), anyObject()); + RocksDBMetrics.addCompactionPendingMetric(eq(streamsMetrics), eq(metricsContext), anyObject()); + RocksDBMetrics.addNumRunningCompactionsMetric(eq(streamsMetrics), eq(metricsContext), anyObject()); + RocksDBMetrics.addEstimatePendingCompactionBytesMetric(eq(streamsMetrics), eq(metricsContext), anyObject()); + RocksDBMetrics.addTotalSstFilesSizeMetric(eq(streamsMetrics), eq(metricsContext), anyObject()); + RocksDBMetrics.addLiveSstFilesSizeMetric(eq(streamsMetrics), eq(metricsContext), anyObject()); + RocksDBMetrics.addNumLiveVersionMetric(eq(streamsMetrics), eq(metricsContext), anyObject()); + RocksDBMetrics.addBlockCacheCapacityMetric(eq(streamsMetrics), eq(metricsContext), anyObject()); + RocksDBMetrics.addBlockCacheUsageMetric(eq(streamsMetrics), eq(metricsContext), anyObject()); + RocksDBMetrics.addBlockCachePinnedUsageMetric(eq(streamsMetrics), eq(metricsContext), anyObject()); + RocksDBMetrics.addEstimateNumKeysMetric(eq(streamsMetrics), eq(metricsContext), anyObject()); + RocksDBMetrics.addEstimateTableReadersMemMetric(eq(streamsMetrics), eq(metricsContext), anyObject()); + RocksDBMetrics.addBackgroundErrorsMetric(eq(streamsMetrics), eq(metricsContext), anyObject()); + replay(RocksDBMetrics.class); + } +} \ No newline at end of file diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/metrics/RocksDBMetricsRecordingTriggerTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/metrics/RocksDBMetricsRecordingTriggerTest.java new file mode 100644 index 0000000..d7c5c97 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/metrics/RocksDBMetricsRecordingTriggerTest.java @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.state.internals.metrics; + +import org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.streams.processor.TaskId; +import org.junit.Before; +import org.junit.Test; + +import static org.easymock.EasyMock.expect; +import static org.easymock.EasyMock.niceMock; +import static org.easymock.EasyMock.replay; +import static org.easymock.EasyMock.resetToDefault; +import static org.easymock.EasyMock.verify; +import static org.junit.Assert.assertThrows; + +public class RocksDBMetricsRecordingTriggerTest { + + private final static String STORE_NAME1 = "store-name1"; + private final static String STORE_NAME2 = "store-name2"; + private final static TaskId TASK_ID1 = new TaskId(1, 2); + private final static TaskId TASK_ID2 = new TaskId(2, 4); + private final RocksDBMetricsRecorder recorder1 = niceMock(RocksDBMetricsRecorder.class); + private final RocksDBMetricsRecorder recorder2 = niceMock(RocksDBMetricsRecorder.class); + + private final Time time = new MockTime(); + private final RocksDBMetricsRecordingTrigger recordingTrigger = new RocksDBMetricsRecordingTrigger(time); + + @Before + public void setUp() { + expect(recorder1.storeName()).andStubReturn(STORE_NAME1); + expect(recorder1.taskId()).andStubReturn(TASK_ID1); + replay(recorder1); + expect(recorder2.storeName()).andStubReturn(STORE_NAME2); + expect(recorder2.taskId()).andStubReturn(TASK_ID2); + replay(recorder2); + } + + @Test + public void shouldTriggerAddedMetricsRecorders() { + recordingTrigger.addMetricsRecorder(recorder1); + recordingTrigger.addMetricsRecorder(recorder2); + + resetToDefault(recorder1); + recorder1.record(time.milliseconds()); + replay(recorder1); + resetToDefault(recorder2); + recorder2.record(time.milliseconds()); + replay(recorder2); + + recordingTrigger.run(); + + verify(recorder1); + verify(recorder2); + } + + @Test + public void shouldThrowIfRecorderToAddHasBeenAlreadyAdded() { + recordingTrigger.addMetricsRecorder(recorder1); + + assertThrows( + IllegalStateException.class, + () -> recordingTrigger.addMetricsRecorder(recorder1) + ); + } + + @Test + public void shouldThrowIfRecorderToRemoveCouldNotBeFound() { + recordingTrigger.addMetricsRecorder(recorder1); + assertThrows( + IllegalStateException.class, + () -> recordingTrigger.removeMetricsRecorder(recorder2) + ); + } +} \ No newline at end of file diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/metrics/RocksDBMetricsTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/metrics/RocksDBMetricsTest.java new file mode 100644 index 0000000..c8d6388 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/metrics/RocksDBMetricsTest.java @@ -0,0 +1,548 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.state.internals.metrics; + +import org.apache.kafka.common.metrics.Gauge; +import org.apache.kafka.common.metrics.Metrics; +import org.apache.kafka.common.metrics.Sensor; +import org.apache.kafka.common.metrics.Sensor.RecordingLevel; +import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl; +import org.apache.kafka.streams.state.internals.metrics.RocksDBMetrics.RocksDBMetricContext; +import org.junit.Test; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.eq; + +import java.math.BigInteger; +import java.util.Collections; +import java.util.Map; + + +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.MatcherAssert.assertThat; + +public class RocksDBMetricsTest { + + private static final String STATE_LEVEL_GROUP = "stream-state-metrics"; + private static final String TASK_ID = "test-task"; + private static final String STORE_TYPE = "test-store-type"; + private static final String STORE_NAME = "store"; + private static final RocksDBMetricContext ROCKSDB_METRIC_CONTEXT = + new RocksDBMetricContext(TASK_ID, STORE_TYPE, STORE_NAME); + private static final Gauge VALUE_PROVIDER = (config, now) -> BigInteger.valueOf(10); + + private final Metrics metrics = new Metrics(); + private final Sensor sensor = metrics.sensor("dummy"); + private final StreamsMetricsImpl streamsMetrics = mock(StreamsMetricsImpl.class); + private final Map tags = Collections.singletonMap("hello", "world"); + + private interface SensorCreator { + Sensor sensor(final StreamsMetricsImpl streamsMetrics, final RocksDBMetricContext metricContext); + } + + @Test + public void shouldGetBytesWrittenSensor() { + final String metricNamePrefix = "bytes-written"; + final String descriptionOfTotal = "Total number of bytes written to the RocksDB state store"; + final String descriptionOfRate = "Average number of bytes written per second to the RocksDB state store"; + verifyRateAndTotalSensor( + metricNamePrefix, + descriptionOfTotal, + descriptionOfRate, + RocksDBMetrics::bytesWrittenToDatabaseSensor + ); + } + + @Test + public void shouldGetBytesReadSensor() { + final String metricNamePrefix = "bytes-read"; + final String descriptionOfTotal = "Total number of bytes read from the RocksDB state store"; + final String descriptionOfRate = "Average number of bytes read per second from the RocksDB state store"; + verifyRateAndTotalSensor( + metricNamePrefix, + descriptionOfTotal, + descriptionOfRate, + RocksDBMetrics::bytesReadFromDatabaseSensor + ); + } + + @Test + public void shouldGetMemtableHitRatioSensor() { + final String metricNamePrefix = "memtable-hit-ratio"; + final String description = "Ratio of memtable hits relative to all lookups to the memtable"; + verifyValueSensor(metricNamePrefix, description, RocksDBMetrics::memtableHitRatioSensor); + } + + @Test + public void shouldGetMemtableBytesFlushedSensor() { + final String metricNamePrefix = "memtable-bytes-flushed"; + final String descriptionOfTotal = "Total number of bytes flushed from the memtable to disk"; + final String descriptionOfRate = "Average number of bytes flushed per second from the memtable to disk"; + verifyRateAndTotalSensor( + metricNamePrefix, + descriptionOfTotal, + descriptionOfRate, + RocksDBMetrics::memtableBytesFlushedSensor + ); + } + + @Test + public void 
shouldGetMemtableAvgFlushTimeSensor() { + final String metricNamePrefix = "memtable-flush-time-avg"; + final String description = "Average time spent on flushing the memtable to disk in ms"; + verifyValueSensor(metricNamePrefix, description, RocksDBMetrics::memtableAvgFlushTimeSensor); + } + + @Test + public void shouldGetMemtableMinFlushTimeSensor() { + final String metricNamePrefix = "memtable-flush-time-min"; + final String description = "Minimum time spent on flushing the memtable to disk in ms"; + verifyValueSensor(metricNamePrefix, description, RocksDBMetrics::memtableMinFlushTimeSensor); + } + + @Test + public void shouldGetMemtableMaxFlushTimeSensor() { + final String metricNamePrefix = "memtable-flush-time-max"; + final String description = "Maximum time spent on flushing the memtable to disk in ms"; + verifyValueSensor(metricNamePrefix, description, RocksDBMetrics::memtableMaxFlushTimeSensor); + } + + @Test + public void shouldGetWriteStallDurationSensor() { + final String metricNamePrefix = "write-stall-duration"; + final String descriptionOfAvg = "Average duration of write stalls in ms"; + final String descriptionOfTotal = "Total duration of write stalls in ms"; + setupStreamsMetricsMock(metricNamePrefix); + StreamsMetricsImpl.addAvgAndSumMetricsToSensor( + sensor, + STATE_LEVEL_GROUP, + tags, + metricNamePrefix, + descriptionOfAvg, + descriptionOfTotal + ); + + replayCallAndVerify(RocksDBMetrics::writeStallDurationSensor); + } + + @Test + public void shouldGetBlockCacheDataHitRatioSensor() { + final String metricNamePrefix = "block-cache-data-hit-ratio"; + final String description = + "Ratio of block cache hits for data relative to all lookups for data to the block cache"; + verifyValueSensor(metricNamePrefix, description, RocksDBMetrics::blockCacheDataHitRatioSensor); + } + + @Test + public void shouldGetBlockCacheIndexHitRatioSensor() { + final String metricNamePrefix = "block-cache-index-hit-ratio"; + final String description = + "Ratio of block cache hits for indexes relative to all lookups for indexes to the block cache"; + verifyValueSensor(metricNamePrefix, description, RocksDBMetrics::blockCacheIndexHitRatioSensor); + } + + @Test + public void shouldGetBlockCacheFilterHitRatioSensor() { + final String metricNamePrefix = "block-cache-filter-hit-ratio"; + final String description = + "Ratio of block cache hits for filters relative to all lookups for filters to the block cache"; + verifyValueSensor(metricNamePrefix, description, RocksDBMetrics::blockCacheFilterHitRatioSensor); + } + + @Test + public void shouldGetBytesReadDuringCompactionSensor() { + final String metricNamePrefix = "bytes-read-compaction"; + final String description = "Average number of bytes read per second during compaction"; + verifyRateSensor(metricNamePrefix, description, RocksDBMetrics::bytesReadDuringCompactionSensor); + } + + @Test + public void shouldGetBytesWrittenDuringCompactionSensor() { + final String metricNamePrefix = "bytes-written-compaction"; + final String description = "Average number of bytes written per second during compaction"; + verifyRateSensor(metricNamePrefix, description, RocksDBMetrics::bytesWrittenDuringCompactionSensor); + } + + @Test + public void shouldGetCompactionTimeAvgSensor() { + final String metricNamePrefix = "compaction-time-avg"; + final String description = "Average time spent on compaction in ms"; + verifyValueSensor(metricNamePrefix, description, RocksDBMetrics::compactionTimeAvgSensor); + } + + @Test + public void shouldGetCompactionTimeMinSensor() { + final 
String metricNamePrefix = "compaction-time-min"; + final String description = "Minimum time spent on compaction in ms"; + verifyValueSensor(metricNamePrefix, description, RocksDBMetrics::compactionTimeMinSensor); + } + + @Test + public void shouldGetCompactionTimeMaxSensor() { + final String metricNamePrefix = "compaction-time-max"; + final String description = "Maximum time spent on compaction in ms"; + verifyValueSensor(metricNamePrefix, description, RocksDBMetrics::compactionTimeMaxSensor); + } + + @Test + public void shouldGetNumberOfOpenFilesSensor() { + final String metricNamePrefix = "number-open-files"; + final String description = "Number of currently open files"; + verifySumSensor(metricNamePrefix, false, description, RocksDBMetrics::numberOfOpenFilesSensor); + } + + @Test + public void shouldGetNumberOfFilesErrors() { + final String metricNamePrefix = "number-file-errors"; + final String description = "Total number of file errors occurred"; + verifySumSensor(metricNamePrefix, true, description, RocksDBMetrics::numberOfFileErrorsSensor); + } + + @Test + public void shouldAddNumEntriesActiveMemTableMetric() { + final String name = "num-entries-active-mem-table"; + final String description = "Total number of entries in the active memtable"; + runAndVerifyMutableMetric( + name, + description, + () -> RocksDBMetrics.addNumEntriesActiveMemTableMetric(streamsMetrics, ROCKSDB_METRIC_CONTEXT, VALUE_PROVIDER) + ); + } + + @Test + public void shouldAddNumberDeletesActiveTableMetric() { + final String name = "num-deletes-active-mem-table"; + final String description = "Total number of delete entries in the active memtable"; + runAndVerifyMutableMetric( + name, + description, + () -> RocksDBMetrics.addNumDeletesActiveMemTableMetric(streamsMetrics, ROCKSDB_METRIC_CONTEXT, VALUE_PROVIDER) + ); + } + + @Test + public void shouldAddNumEntriesImmutableMemTablesMetric() { + final String name = "num-entries-imm-mem-tables"; + final String description = "Total number of entries in the unflushed immutable memtables"; + runAndVerifyMutableMetric( + name, + description, + () -> RocksDBMetrics.addNumEntriesImmMemTablesMetric(streamsMetrics, ROCKSDB_METRIC_CONTEXT, VALUE_PROVIDER) + ); + } + + @Test + public void shouldAddNumDeletesImmutableMemTablesMetric() { + final String name = "num-deletes-imm-mem-tables"; + final String description = "Total number of delete entries in the unflushed immutable memtables"; + runAndVerifyMutableMetric( + name, + description, + () -> RocksDBMetrics.addNumDeletesImmMemTablesMetric(streamsMetrics, ROCKSDB_METRIC_CONTEXT, VALUE_PROVIDER) + ); + } + + @Test + public void shouldAddNumImmutableMemTablesMetric() { + final String name = "num-immutable-mem-table"; + final String description = "Number of immutable memtables that have not yet been flushed"; + runAndVerifyMutableMetric( + name, + description, + () -> RocksDBMetrics.addNumImmutableMemTableMetric(streamsMetrics, ROCKSDB_METRIC_CONTEXT, VALUE_PROVIDER) + ); + } + + @Test + public void shouldAddCurSizeActiveMemTableMetric() { + final String name = "cur-size-active-mem-table"; + final String description = "Approximate size of active memtable in bytes"; + runAndVerifyMutableMetric( + name, + description, + () -> RocksDBMetrics.addCurSizeActiveMemTable(streamsMetrics, ROCKSDB_METRIC_CONTEXT, VALUE_PROVIDER) + ); + } + + @Test + public void shouldAddCurSizeAllMemTablesMetric() { + final String name = "cur-size-all-mem-tables"; + final String description = "Approximate size of active and unflushed immutable memtables in 
bytes"; + runAndVerifyMutableMetric( + name, + description, + () -> RocksDBMetrics.addCurSizeAllMemTables(streamsMetrics, ROCKSDB_METRIC_CONTEXT, VALUE_PROVIDER) + ); + } + + @Test + public void shouldAddSizeAllMemTablesMetric() { + final String name = "size-all-mem-tables"; + final String description = "Approximate size of active, unflushed immutable, and pinned immutable memtables in bytes"; + runAndVerifyMutableMetric( + name, + description, + () -> RocksDBMetrics.addSizeAllMemTables(streamsMetrics, ROCKSDB_METRIC_CONTEXT, VALUE_PROVIDER) + ); + } + + @Test + public void shouldAddMemTableFlushPendingMetric() { + final String name = "mem-table-flush-pending"; + final String description = "Reports 1 if a memtable flush is pending, otherwise it reports 0"; + runAndVerifyMutableMetric( + name, + description, + () -> RocksDBMetrics.addMemTableFlushPending(streamsMetrics, ROCKSDB_METRIC_CONTEXT, VALUE_PROVIDER) + ); + } + + @Test + public void shouldAddNumRunningFlushesMetric() { + final String name = "num-running-flushes"; + final String description = "Number of currently running flushes"; + runAndVerifyMutableMetric( + name, + description, + () -> RocksDBMetrics.addNumRunningFlushesMetric(streamsMetrics, ROCKSDB_METRIC_CONTEXT, VALUE_PROVIDER) + ); + } + + @Test + public void shouldAddCompactionPendingMetric() { + final String name = "compaction-pending"; + final String description = "Reports 1 if at least one compaction is pending, otherwise it reports 0"; + runAndVerifyMutableMetric( + name, + description, + () -> RocksDBMetrics.addCompactionPendingMetric(streamsMetrics, ROCKSDB_METRIC_CONTEXT, VALUE_PROVIDER) + ); + } + + @Test + public void shouldAddNumRunningCompactionsMetric() { + final String name = "num-running-compactions"; + final String description = "Number of currently running compactions"; + runAndVerifyMutableMetric( + name, + description, + () -> RocksDBMetrics.addNumRunningCompactionsMetric(streamsMetrics, ROCKSDB_METRIC_CONTEXT, VALUE_PROVIDER) + ); + } + + @Test + public void shouldAddEstimatePendingCompactionBytesMetric() { + final String name = "estimate-pending-compaction-bytes"; + final String description = + "Estimated total number of bytes a compaction needs to rewrite on disk to get all levels down to under target size"; + runAndVerifyMutableMetric( + name, + description, + () -> RocksDBMetrics.addEstimatePendingCompactionBytesMetric(streamsMetrics, ROCKSDB_METRIC_CONTEXT, VALUE_PROVIDER) + ); + } + + @Test + public void shouldAddTotalSstFilesSizeMetric() { + final String name = "total-sst-files-size"; + final String description = "Total size in bytes of all SST files"; + runAndVerifyMutableMetric( + name, + description, + () -> RocksDBMetrics.addTotalSstFilesSizeMetric(streamsMetrics, ROCKSDB_METRIC_CONTEXT, VALUE_PROVIDER) + ); + } + + @Test + public void shouldAddLiveSstFilesSizeMetric() { + final String name = "live-sst-files-size"; + final String description = "Total size in bytes of all SST files that belong to the latest LSM tree"; + runAndVerifyMutableMetric( + name, + description, + () -> RocksDBMetrics.addLiveSstFilesSizeMetric(streamsMetrics, ROCKSDB_METRIC_CONTEXT, VALUE_PROVIDER) + ); + } + + @Test + public void shouldAddNumLiveVersionMetric() { + final String name = "num-live-versions"; + final String description = "Number of live versions of the LSM tree"; + runAndVerifyMutableMetric( + name, + description, + () -> RocksDBMetrics.addNumLiveVersionMetric(streamsMetrics, ROCKSDB_METRIC_CONTEXT, VALUE_PROVIDER) + ); + } + + @Test + public void 
shouldAddBlockCacheCapacityMetric() { + final String name = "block-cache-capacity"; + final String description = "Capacity of the block cache in bytes"; + runAndVerifyMutableMetric( + name, + description, + () -> RocksDBMetrics.addBlockCacheCapacityMetric(streamsMetrics, ROCKSDB_METRIC_CONTEXT, VALUE_PROVIDER) + ); + } + + @Test + public void shouldAddBlockCacheUsageMetric() { + final String name = "block-cache-usage"; + final String description = "Memory size of the entries residing in block cache in bytes"; + runAndVerifyMutableMetric( + name, + description, + () -> RocksDBMetrics.addBlockCacheUsageMetric(streamsMetrics, ROCKSDB_METRIC_CONTEXT, VALUE_PROVIDER) + ); + } + + @Test + public void shouldAddBlockCachePinnedUsageMetric() { + final String name = "block-cache-pinned-usage"; + final String description = "Memory size for the entries being pinned in the block cache in bytes"; + runAndVerifyMutableMetric( + name, + description, + () -> RocksDBMetrics.addBlockCachePinnedUsageMetric(streamsMetrics, ROCKSDB_METRIC_CONTEXT, VALUE_PROVIDER) + ); + } + + @Test + public void shouldAddEstimateNumKeysMetric() { + final String name = "estimate-num-keys"; + final String description = + "Estimated number of keys in the active and unflushed immutable memtables and storage"; + runAndVerifyMutableMetric( + name, + description, + () -> RocksDBMetrics.addEstimateNumKeysMetric(streamsMetrics, ROCKSDB_METRIC_CONTEXT, VALUE_PROVIDER) + ); + } + + @Test + public void shouldAddEstimateTableReadersMemMetric() { + final String name = "estimate-table-readers-mem"; + final String description = + "Estimated memory in bytes used for reading SST tables, excluding memory used in block cache"; + runAndVerifyMutableMetric( + name, + description, + () -> RocksDBMetrics.addEstimateTableReadersMemMetric(streamsMetrics, ROCKSDB_METRIC_CONTEXT, VALUE_PROVIDER) + ); + } + + @Test + public void shouldAddBackgroundErrorsMetric() { + final String name = "background-errors"; + final String description = "Total number of background errors"; + runAndVerifyMutableMetric( + name, + description, + () -> RocksDBMetrics.addBackgroundErrorsMetric(streamsMetrics, ROCKSDB_METRIC_CONTEXT, VALUE_PROVIDER) + ); + } + + private void runAndVerifyMutableMetric(final String name, final String description, final Runnable metricAdder) { + + metricAdder.run(); + + verify(streamsMetrics).addStoreLevelMutableMetric( + eq(TASK_ID), + eq(STORE_TYPE), + eq(STORE_NAME), + eq(name), + eq(description), + eq(RecordingLevel.INFO), + eq(VALUE_PROVIDER) + ); + } + + private void verifyRateAndTotalSensor(final String metricNamePrefix, + final String descriptionOfTotal, + final String descriptionOfRate, + final SensorCreator sensorCreator) { + setupStreamsMetricsMock(metricNamePrefix); + StreamsMetricsImpl.addRateOfSumAndSumMetricsToSensor( + sensor, + STATE_LEVEL_GROUP, + tags, + metricNamePrefix, + descriptionOfRate, + descriptionOfTotal + ); + + replayCallAndVerify(sensorCreator); + } + + private void verifyRateSensor(final String metricNamePrefix, + final String description, + final SensorCreator sensorCreator) { + setupStreamsMetricsMock(metricNamePrefix); + StreamsMetricsImpl.addRateOfSumMetricToSensor(sensor, STATE_LEVEL_GROUP, tags, metricNamePrefix, description); + + replayCallAndVerify(sensorCreator); + } + + private void verifyValueSensor(final String metricNamePrefix, + final String description, + final SensorCreator sensorCreator) { + setupStreamsMetricsMock(metricNamePrefix); + StreamsMetricsImpl.addValueMetricToSensor(sensor, 
STATE_LEVEL_GROUP, tags, metricNamePrefix, description); + + replayCallAndVerify(sensorCreator); + } + + private void verifySumSensor(final String metricNamePrefix, + final boolean withSuffix, + final String description, + final SensorCreator sensorCreator) { + setupStreamsMetricsMock(metricNamePrefix); + if (withSuffix) { + StreamsMetricsImpl.addSumMetricToSensor(sensor, STATE_LEVEL_GROUP, tags, metricNamePrefix, description); + } else { + StreamsMetricsImpl + .addSumMetricToSensor(sensor, STATE_LEVEL_GROUP, tags, metricNamePrefix, withSuffix, description); + } + + replayCallAndVerify(sensorCreator); + } + + private void setupStreamsMetricsMock(final String metricNamePrefix) { + + when(streamsMetrics.storeLevelSensor( + TASK_ID, + STORE_NAME, + metricNamePrefix, + RecordingLevel.DEBUG + )).thenReturn(sensor); + when(streamsMetrics.storeLevelTagMap( + TASK_ID, + STORE_TYPE, + STORE_NAME + )).thenReturn(tags); + } + + private void replayCallAndVerify(final SensorCreator sensorCreator) { + + final Sensor sensor = sensorCreator.sensor(streamsMetrics, ROCKSDB_METRIC_CONTEXT); + + + assertThat(sensor, is(this.sensor)); + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/metrics/StateStoreMetricsTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/metrics/StateStoreMetricsTest.java new file mode 100644 index 0000000..f981316 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/metrics/StateStoreMetricsTest.java @@ -0,0 +1,382 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.state.internals.metrics; + +import org.apache.kafka.common.metrics.Sensor; +import org.apache.kafka.common.metrics.Sensor.RecordingLevel; +import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl; +import org.junit.Test; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import java.util.Collections; +import java.util.Map; +import java.util.function.Supplier; + +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.MatcherAssert.assertThat; + +public class StateStoreMetricsTest { + + private static final String TASK_ID = "test-task"; + private static final String STORE_NAME = "test-store"; + private static final String STORE_TYPE = "test-type"; + private static final String STORE_LEVEL_GROUP = "stream-state-metrics"; + private static final String BUFFER_NAME = "test-buffer"; + + private final Sensor expectedSensor = mock(Sensor.class); + private final StreamsMetricsImpl streamsMetrics = mock(StreamsMetricsImpl.class); + private final Map storeTagMap = Collections.singletonMap("hello", "world"); + + @Test + public void shouldGetPutSensor() { + final String metricName = "put"; + final String descriptionOfRate = "The average number of calls to put per second"; + final String descriptionOfAvg = "The average latency of calls to put"; + final String descriptionOfMax = "The maximum latency of calls to put"; + shouldGetSensor( + metricName, + descriptionOfRate, + descriptionOfAvg, + descriptionOfMax, + () -> StateStoreMetrics.putSensor(TASK_ID, STORE_TYPE, STORE_NAME, streamsMetrics) + ); + } + + @Test + public void shouldGetPutIfAbsentSensor() { + final String metricName = "put-if-absent"; + final String descriptionOfRate = "The average number of calls to put-if-absent per second"; + final String descriptionOfAvg = "The average latency of calls to put-if-absent"; + final String descriptionOfMax = "The maximum latency of calls to put-if-absent"; + shouldGetSensor( + metricName, + descriptionOfRate, + descriptionOfAvg, + descriptionOfMax, + () -> StateStoreMetrics.putIfAbsentSensor(TASK_ID, STORE_TYPE, STORE_NAME, streamsMetrics) + ); + } + + @Test + public void shouldGetPutAllSensor() { + final String metricName = "put-all"; + final String descriptionOfRate = "The average number of calls to put-all per second"; + final String descriptionOfAvg = "The average latency of calls to put-all"; + final String descriptionOfMax = "The maximum latency of calls to put-all"; + shouldGetSensor( + metricName, + descriptionOfRate, + descriptionOfAvg, + descriptionOfMax, + () -> StateStoreMetrics.putAllSensor(TASK_ID, STORE_TYPE, STORE_NAME, streamsMetrics) + ); + } + + @Test + public void shouldGetFetchSensor() { + final String metricName = "fetch"; + final String descriptionOfRate = "The average number of calls to fetch per second"; + final String descriptionOfAvg = "The average latency of calls to fetch"; + final String descriptionOfMax = "The maximum latency of calls to fetch"; + shouldGetSensor( + metricName, + descriptionOfRate, + descriptionOfAvg, + descriptionOfMax, + () -> StateStoreMetrics.fetchSensor(TASK_ID, STORE_TYPE, STORE_NAME, streamsMetrics) + ); + } + + @Test + public void shouldGetGetSensor() { + final String metricName = "get"; + final String descriptionOfRate = "The average number of calls to get per second"; + final String descriptionOfAvg = "The average latency of calls to get"; + final String descriptionOfMax = "The maximum latency of calls to get"; + shouldGetSensor( + 
metricName, + descriptionOfRate, + descriptionOfAvg, + descriptionOfMax, + () -> StateStoreMetrics.getSensor(TASK_ID, STORE_TYPE, STORE_NAME, streamsMetrics) + ); + } + + @Test + public void shouldGetAllSensor() { + final String metricName = "all"; + final String descriptionOfRate = "The average number of calls to all per second"; + final String descriptionOfAvg = "The average latency of calls to all"; + final String descriptionOfMax = "The maximum latency of calls to all"; + shouldGetSensor( + metricName, + descriptionOfRate, + descriptionOfAvg, + descriptionOfMax, + () -> StateStoreMetrics.allSensor(TASK_ID, STORE_TYPE, STORE_NAME, streamsMetrics) + ); + } + + @Test + public void shouldGetRangeSensor() { + final String metricName = "range"; + final String descriptionOfRate = "The average number of calls to range per second"; + final String descriptionOfAvg = "The average latency of calls to range"; + final String descriptionOfMax = "The maximum latency of calls to range"; + shouldGetSensor( + metricName, + descriptionOfRate, + descriptionOfAvg, + descriptionOfMax, + () -> StateStoreMetrics.rangeSensor(TASK_ID, STORE_TYPE, STORE_NAME, streamsMetrics) + ); + } + + @Test + public void shouldGetPrefixScanSensor() { + final String metricName = "prefix-scan"; + final String descriptionOfRate = "The average number of calls to prefix-scan per second"; + final String descriptionOfAvg = "The average latency of calls to prefix-scan"; + final String descriptionOfMax = "The maximum latency of calls to prefix-scan"; + when(streamsMetrics.storeLevelSensor(TASK_ID, STORE_NAME, metricName, RecordingLevel.DEBUG)) + .thenReturn(expectedSensor); + when(streamsMetrics.storeLevelTagMap(TASK_ID, STORE_TYPE, STORE_NAME)).thenReturn(storeTagMap); + StreamsMetricsImpl.addInvocationRateToSensor( + expectedSensor, + STORE_LEVEL_GROUP, + storeTagMap, + metricName, + descriptionOfRate + ); + StreamsMetricsImpl.addAvgAndMaxToSensor( + expectedSensor, + STORE_LEVEL_GROUP, + storeTagMap, + latencyMetricName(metricName), + descriptionOfAvg, + descriptionOfMax + ); + + final Sensor sensor = StateStoreMetrics.prefixScanSensor(TASK_ID, STORE_TYPE, STORE_NAME, streamsMetrics); + + assertThat(sensor, is(expectedSensor)); + } + + @Test + public void shouldGetFlushSensor() { + final String metricName = "flush"; + final String descriptionOfRate = "The average number of calls to flush per second"; + final String descriptionOfAvg = "The average latency of calls to flush"; + final String descriptionOfMax = "The maximum latency of calls to flush"; + shouldGetSensor( + metricName, + descriptionOfRate, + descriptionOfAvg, + descriptionOfMax, + () -> StateStoreMetrics.flushSensor(TASK_ID, STORE_TYPE, STORE_NAME, streamsMetrics) + ); + } + + @Test + public void shouldGetRemoveSensor() { + final String metricName = "remove"; + final String descriptionOfRate = "The average number of calls to remove per second"; + final String descriptionOfAvg = "The average latency of calls to remove"; + final String descriptionOfMax = "The maximum latency of calls to remove"; + shouldGetSensor( + metricName, + descriptionOfRate, + descriptionOfAvg, + descriptionOfMax, + () -> StateStoreMetrics.removeSensor(TASK_ID, STORE_TYPE, STORE_NAME, streamsMetrics) + ); + } + + @Test + public void shouldGetDeleteSensor() { + final String metricName = "delete"; + final String descriptionOfRate = "The average number of calls to delete per second"; + final String descriptionOfAvg = "The average latency of calls to delete"; + final String descriptionOfMax = "The 
maximum latency of calls to delete"; + shouldGetSensor( + metricName, + descriptionOfRate, + descriptionOfAvg, + descriptionOfMax, + () -> StateStoreMetrics.deleteSensor(TASK_ID, STORE_TYPE, STORE_NAME, streamsMetrics) + ); + } + + @Test + public void shouldGetRestoreSensor() { + final String metricName = "restore"; + final String descriptionOfRate = "The average number of restorations per second"; + final String descriptionOfAvg = "The average latency of restorations"; + final String descriptionOfMax = "The maximum latency of restorations"; + shouldGetSensor( + metricName, + descriptionOfRate, + descriptionOfAvg, + descriptionOfMax, + () -> StateStoreMetrics.restoreSensor(TASK_ID, STORE_TYPE, STORE_NAME, streamsMetrics) + ); + } + + @Test + public void shouldGetSuppressionBufferCountSensor() { + final String metricName = "suppression-buffer-count"; + final String descriptionOfAvg = "The average count of buffered records"; + final String descriptionOfMax = "The maximum count of buffered records"; + shouldGetSuppressionBufferSensor( + metricName, + descriptionOfAvg, + descriptionOfMax, + () -> StateStoreMetrics.suppressionBufferCountSensor(TASK_ID, STORE_TYPE, BUFFER_NAME, streamsMetrics) + ); + } + + @Test + public void shouldGetSuppressionBufferSizeSensor() { + final String metricName = "suppression-buffer-size"; + final String descriptionOfAvg = "The average size of buffered records"; + final String descriptionOfMax = "The maximum size of buffered records"; + shouldGetSuppressionBufferSensor( + metricName, + descriptionOfAvg, + descriptionOfMax, + () -> StateStoreMetrics.suppressionBufferSizeSensor(TASK_ID, STORE_TYPE, BUFFER_NAME, streamsMetrics) + ); + } + + @Test + public void shouldGetExpiredWindowRecordDropSensor() { + final String metricName = "expired-window-record-drop"; + final String descriptionOfRate = "The average number of dropped records due to an expired window per second"; + final String descriptionOfCount = "The total number of dropped records due to an expired window"; + when(streamsMetrics.storeLevelSensor(TASK_ID, STORE_NAME, metricName, RecordingLevel.INFO)) + .thenReturn(expectedSensor); + + when(streamsMetrics.storeLevelTagMap(TASK_ID, STORE_TYPE, STORE_NAME)).thenReturn(storeTagMap); + StreamsMetricsImpl.addInvocationRateAndCountToSensor( + expectedSensor, + "stream-" + STORE_TYPE + "-metrics", + storeTagMap, + metricName, + descriptionOfRate, + descriptionOfCount + ); + + final Sensor sensor = + StateStoreMetrics.expiredWindowRecordDropSensor(TASK_ID, STORE_TYPE, STORE_NAME, streamsMetrics); + + assertThat(sensor, is(expectedSensor)); + } + + @Test + public void shouldGetRecordE2ELatencySensor() { + final String metricName = "record-e2e-latency"; + + final String e2eLatencyDescription = + "end-to-end latency of a record, measuring by comparing the record timestamp with the " + + "system time when it has been fully processed by the node"; + final String descriptionOfAvg = "The average " + e2eLatencyDescription; + final String descriptionOfMin = "The minimum " + e2eLatencyDescription; + final String descriptionOfMax = "The maximum " + e2eLatencyDescription; + + when(streamsMetrics.storeLevelSensor(TASK_ID, STORE_NAME, metricName, RecordingLevel.TRACE)) + .thenReturn(expectedSensor); + when(streamsMetrics.storeLevelTagMap(TASK_ID, STORE_TYPE, STORE_NAME)).thenReturn(storeTagMap); + + StreamsMetricsImpl.addAvgAndMinAndMaxToSensor( + expectedSensor, + STORE_LEVEL_GROUP, + storeTagMap, + metricName, + descriptionOfAvg, + descriptionOfMin, + descriptionOfMax + ); + + 
final Sensor sensor = + StateStoreMetrics.e2ELatencySensor(TASK_ID, STORE_TYPE, STORE_NAME, streamsMetrics); + + assertThat(sensor, is(expectedSensor)); + } + + private void shouldGetSensor(final String metricName, + final String descriptionOfRate, + final String descriptionOfAvg, + final String descriptionOfMax, + final Supplier sensorSupplier) { + when(streamsMetrics.storeLevelSensor( + TASK_ID, + STORE_NAME, + metricName, + RecordingLevel.DEBUG + )).thenReturn(expectedSensor); + + StreamsMetricsImpl.addInvocationRateToSensor( + expectedSensor, + STORE_LEVEL_GROUP, + storeTagMap, + metricName, + descriptionOfRate + ); + when(streamsMetrics.storeLevelTagMap(TASK_ID, STORE_TYPE, STORE_NAME)).thenReturn(storeTagMap); + + StreamsMetricsImpl.addAvgAndMaxToSensor( + expectedSensor, + STORE_LEVEL_GROUP, + storeTagMap, + latencyMetricName(metricName), + descriptionOfAvg, + descriptionOfMax + ); + + final Sensor sensor = sensorSupplier.get(); + + assertThat(sensor, is(expectedSensor)); + } + + private String latencyMetricName(final String metricName) { + return metricName + StreamsMetricsImpl.LATENCY_SUFFIX; + } + + private void shouldGetSuppressionBufferSensor(final String metricName, + final String descriptionOfAvg, + final String descriptionOfMax, + final Supplier sensorSupplier) { + final Map tagMap; + when(streamsMetrics.storeLevelSensor(TASK_ID, BUFFER_NAME, metricName, RecordingLevel.DEBUG)).thenReturn(expectedSensor); + tagMap = storeTagMap; + when(streamsMetrics.storeLevelTagMap(TASK_ID, STORE_TYPE, BUFFER_NAME)).thenReturn(tagMap); + + StreamsMetricsImpl.addAvgAndMaxToSensor( + expectedSensor, + STORE_LEVEL_GROUP, + tagMap, + metricName, + descriptionOfAvg, + descriptionOfMax + ); + + final Sensor sensor = sensorSupplier.get(); + + assertThat(sensor, is(expectedSensor)); + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/tests/BrokerCompatibilityTest.java b/streams/src/test/java/org/apache/kafka/streams/tests/BrokerCompatibilityTest.java new file mode 100644 index 0000000..8a402ab --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/tests/BrokerCompatibilityTest.java @@ -0,0 +1,166 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.tests; + +import org.apache.kafka.clients.consumer.ConsumerConfig; +import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.clients.consumer.ConsumerRecords; +import org.apache.kafka.clients.consumer.KafkaConsumer; +import org.apache.kafka.clients.producer.KafkaProducer; +import org.apache.kafka.clients.producer.ProducerConfig; +import org.apache.kafka.clients.producer.ProducerRecord; +import org.apache.kafka.common.IsolationLevel; +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.serialization.StringDeserializer; +import org.apache.kafka.common.serialization.StringSerializer; +import org.apache.kafka.common.utils.Exit; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.streams.KafkaStreams; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.errors.StreamsException; +import org.apache.kafka.streams.errors.StreamsUncaughtExceptionHandler; +import org.apache.kafka.streams.kstream.Grouped; + +import java.io.IOException; +import java.time.Duration; +import java.util.Collections; +import java.util.Locale; +import java.util.Properties; + +public class BrokerCompatibilityTest { + + private static final String SOURCE_TOPIC = "brokerCompatibilitySourceTopic"; + private static final String SINK_TOPIC = "brokerCompatibilitySinkTopic"; + + public static void main(final String[] args) throws IOException { + if (args.length < 2) { + System.err.println("BrokerCompatibilityTest are expecting two parameters: propFile, processingMode; but only see " + args.length + " parameter"); + Exit.exit(1); + } + + System.out.println("StreamsTest instance started"); + + final String propFileName = args[0]; + final String processingMode = args[1]; + + final Properties streamsProperties = Utils.loadProps(propFileName); + final String kafka = streamsProperties.getProperty(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG); + + if (kafka == null) { + System.err.println("No bootstrap kafka servers specified in " + StreamsConfig.BOOTSTRAP_SERVERS_CONFIG); + Exit.exit(1); + } + + streamsProperties.put(StreamsConfig.APPLICATION_ID_CONFIG, "kafka-streams-system-test-broker-compatibility"); + streamsProperties.put(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "earliest"); + streamsProperties.put(StreamsConfig.DEFAULT_KEY_SERDE_CLASS_CONFIG, Serdes.String().getClass()); + streamsProperties.put(StreamsConfig.DEFAULT_VALUE_SERDE_CLASS_CONFIG, Serdes.String().getClass()); + streamsProperties.put(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG, 100L); + streamsProperties.put(StreamsConfig.CACHE_MAX_BYTES_BUFFERING_CONFIG, 0); + streamsProperties.put(StreamsConfig.PROCESSING_GUARANTEE_CONFIG, processingMode); + final int timeout = 6000; + streamsProperties.put(StreamsConfig.consumerPrefix(ConsumerConfig.SESSION_TIMEOUT_MS_CONFIG), timeout); + streamsProperties.put(StreamsConfig.consumerPrefix(ConsumerConfig.FETCH_MAX_WAIT_MS_CONFIG), timeout); + streamsProperties.put(StreamsConfig.REQUEST_TIMEOUT_MS_CONFIG, timeout + 1); + final Serde stringSerde = Serdes.String(); + + + final StreamsBuilder builder = new StreamsBuilder(); + builder.stream(SOURCE_TOPIC).groupByKey(Grouped.with(stringSerde, stringSerde)) + .count() + .toStream() + .mapValues(Object::toString) + .to(SINK_TOPIC); + + final KafkaStreams streams = new KafkaStreams(builder.build(), streamsProperties); + streams.setUncaughtExceptionHandler(e -> { + 
Throwable cause = e; + if (cause instanceof StreamsException) { + while (cause.getCause() != null) { + cause = cause.getCause(); + } + } + System.err.println("FATAL: An unexpected exception " + cause); + e.printStackTrace(System.err); + System.err.flush(); + return StreamsUncaughtExceptionHandler.StreamThreadExceptionResponse.SHUTDOWN_CLIENT; + }); + System.out.println("start Kafka Streams"); + streams.start(); + + final boolean eosEnabled = processingMode.startsWith("exactly_once"); + + System.out.println("send data"); + final Properties producerProperties = new Properties(); + producerProperties.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, kafka); + producerProperties.put(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG, StringSerializer.class); + producerProperties.put(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG, StringSerializer.class); + if (eosEnabled) { + producerProperties.put(ProducerConfig.TRANSACTIONAL_ID_CONFIG, "broker-compatibility-producer-tx"); + } + + try { + try (final KafkaProducer producer = new KafkaProducer<>(producerProperties)) { + if (eosEnabled) { + producer.initTransactions(); + producer.beginTransaction(); + } + producer.send(new ProducerRecord<>(SOURCE_TOPIC, "key", "value")); + if (eosEnabled) { + producer.commitTransaction(); + } + + System.out.println("wait for result"); + loopUntilRecordReceived(kafka, eosEnabled); + System.out.println("close Kafka Streams"); + streams.close(); + } + } catch (final RuntimeException e) { + System.err.println("Non-Streams exception occurred: "); + e.printStackTrace(System.err); + System.err.flush(); + } + } + + private static void loopUntilRecordReceived(final String kafka, final boolean eosEnabled) { + final Properties consumerProperties = new Properties(); + consumerProperties.put(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, kafka); + consumerProperties.put(ConsumerConfig.GROUP_ID_CONFIG, "broker-compatibility-consumer"); + consumerProperties.put(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "earliest"); + consumerProperties.put(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, StringDeserializer.class); + consumerProperties.put(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, StringDeserializer.class); + if (eosEnabled) { + consumerProperties.put(ConsumerConfig.ISOLATION_LEVEL_CONFIG, IsolationLevel.READ_COMMITTED.name().toLowerCase(Locale.ROOT)); + } + + try (final KafkaConsumer consumer = new KafkaConsumer<>(consumerProperties)) { + consumer.subscribe(Collections.singletonList(SINK_TOPIC)); + + while (true) { + final ConsumerRecords records = consumer.poll(Duration.ofMillis(100)); + for (final ConsumerRecord record : records) { + if (record.key().equals("key") && record.value().equals("1")) { + return; + } + } + } + } + } + +} diff --git a/streams/src/test/java/org/apache/kafka/streams/tests/EosTestClient.java b/streams/src/test/java/org/apache/kafka/streams/tests/EosTestClient.java new file mode 100644 index 0000000..b9b24da --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/tests/EosTestClient.java @@ -0,0 +1,176 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.tests; + +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.utils.Exit; +import org.apache.kafka.streams.KafkaStreams; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.errors.StreamsUncaughtExceptionHandler; +import org.apache.kafka.streams.kstream.KGroupedStream; +import org.apache.kafka.streams.kstream.KStream; +import org.apache.kafka.streams.kstream.Materialized; +import org.apache.kafka.streams.kstream.Produced; + +import java.time.Duration; +import java.util.Properties; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; + +public class EosTestClient extends SmokeTestUtil { + + static final String APP_ID = "EosTest"; + private final Properties properties; + private final boolean withRepartitioning; + private final AtomicBoolean notRunningCallbackReceived = new AtomicBoolean(false); + + private KafkaStreams streams; + private boolean uncaughtException; + + EosTestClient(final Properties properties, final boolean withRepartitioning) { + super(); + this.properties = properties; + this.withRepartitioning = withRepartitioning; + } + + private volatile boolean isRunning = true; + + public void start() { + Exit.addShutdownHook("streams-shutdown-hook", () -> { + isRunning = false; + streams.close(Duration.ofSeconds(300)); + + // need to wait for callback to avoid race condition + // -> make sure the callback printout to stdout is there as it is expected test output + waitForStateTransitionCallback(); + + // do not remove these printouts since they are needed for health scripts + if (!uncaughtException) { + System.out.println(System.currentTimeMillis()); + System.out.println("EOS-TEST-CLIENT-CLOSED"); + System.out.flush(); + } + }); + + while (isRunning) { + if (streams == null) { + uncaughtException = false; + + streams = createKafkaStreams(properties); + streams.setUncaughtExceptionHandler(e -> { + System.out.println(System.currentTimeMillis()); + System.out.println("EOS-TEST-CLIENT-EXCEPTION"); + e.printStackTrace(); + System.out.flush(); + uncaughtException = true; + return StreamsUncaughtExceptionHandler.StreamThreadExceptionResponse.SHUTDOWN_CLIENT; + }); + streams.setStateListener((newState, oldState) -> { + // don't remove this -- it's required test output + System.out.println(System.currentTimeMillis()); + System.out.println("StateChange: " + oldState + " -> " + newState); + System.out.flush(); + if (newState == KafkaStreams.State.NOT_RUNNING) { + notRunningCallbackReceived.set(true); + } + }); + streams.start(); + } + if (uncaughtException) { + streams.close(Duration.ofSeconds(60_000L)); + streams = null; + } + sleep(1000); + } + } + + @SuppressWarnings("deprecation") // Old PAPI. Needs to be migrated. 
+ private KafkaStreams createKafkaStreams(final Properties props) { + props.put(StreamsConfig.APPLICATION_ID_CONFIG, APP_ID); + props.put(StreamsConfig.NUM_STREAM_THREADS_CONFIG, 1); + props.put(StreamsConfig.NUM_STANDBY_REPLICAS_CONFIG, 2); + props.put(StreamsConfig.PROBING_REBALANCE_INTERVAL_MS_CONFIG, Duration.ofMinutes(1).toMillis()); + props.put(StreamsConfig.MAX_WARMUP_REPLICAS_CONFIG, Integer.MAX_VALUE); + props.put(StreamsConfig.REPLICATION_FACTOR_CONFIG, 3); + props.put(StreamsConfig.CACHE_MAX_BYTES_BUFFERING_CONFIG, 0); + props.put(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG, 5000L); // increase commit interval to make sure a client is killed having an open transaction + props.put(StreamsConfig.DEFAULT_KEY_SERDE_CLASS_CONFIG, Serdes.String().getClass()); + props.put(StreamsConfig.DEFAULT_VALUE_SERDE_CLASS_CONFIG, Serdes.Integer().getClass()); + + final StreamsBuilder builder = new StreamsBuilder(); + final KStream data = builder.stream("data"); + + data.to("echo"); + data.process(SmokeTestUtil.printProcessorSupplier("data")); + + final KGroupedStream groupedData = data.groupByKey(); + // min + groupedData + .aggregate( + () -> Integer.MAX_VALUE, + (aggKey, value, aggregate) -> (value < aggregate) ? value : aggregate, + Materialized.with(null, intSerde)) + .toStream() + .to("min", Produced.with(stringSerde, intSerde)); + + // sum + groupedData.aggregate( + () -> 0L, + (aggKey, value, aggregate) -> (long) value + aggregate, + Materialized.with(null, longSerde)) + .toStream() + .to("sum", Produced.with(stringSerde, longSerde)); + + if (withRepartitioning) { + data.to("repartition"); + final KStream repartitionedData = builder.stream("repartition"); + + repartitionedData.process(SmokeTestUtil.printProcessorSupplier("repartition")); + + final KGroupedStream groupedDataAfterRepartitioning = repartitionedData.groupByKey(); + // max + groupedDataAfterRepartitioning + .aggregate( + () -> Integer.MIN_VALUE, + (aggKey, value, aggregate) -> (value > aggregate) ? value : aggregate, + Materialized.with(null, intSerde)) + .toStream() + .to("max", Produced.with(stringSerde, intSerde)); + + // count + groupedDataAfterRepartitioning.count() + .toStream() + .to("cnt", Produced.with(stringSerde, longSerde)); + } + + return new KafkaStreams(builder.build(), props); + } + + private void waitForStateTransitionCallback() { + final long maxWaitTime = System.currentTimeMillis() + TimeUnit.SECONDS.toMillis(300); + while (!notRunningCallbackReceived.get() && System.currentTimeMillis() < maxWaitTime) { + try { + Thread.sleep(500); + } catch (final InterruptedException ignoreAndSwallow) { /* just keep waiting */ } + } + if (!notRunningCallbackReceived.get()) { + System.err.println("State transition callback to NOT_RUNNING never received. Timed out after 5 minutes."); + System.err.flush(); + } + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/tests/EosTestDriver.java b/streams/src/test/java/org/apache/kafka/streams/tests/EosTestDriver.java new file mode 100644 index 0000000..18822d3 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/tests/EosTestDriver.java @@ -0,0 +1,654 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.tests; + +import org.apache.kafka.clients.admin.Admin; +import org.apache.kafka.clients.admin.ConsumerGroupDescription; +import org.apache.kafka.clients.admin.ListConsumerGroupOffsetsResult; +import org.apache.kafka.clients.consumer.ConsumerConfig; +import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.clients.consumer.ConsumerRecords; +import org.apache.kafka.clients.consumer.KafkaConsumer; +import org.apache.kafka.clients.consumer.OffsetAndMetadata; +import org.apache.kafka.clients.producer.KafkaProducer; +import org.apache.kafka.clients.producer.ProducerConfig; +import org.apache.kafka.clients.producer.ProducerRecord; +import org.apache.kafka.common.IsolationLevel; +import org.apache.kafka.common.PartitionInfo; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.errors.TimeoutException; +import org.apache.kafka.common.serialization.ByteArrayDeserializer; +import org.apache.kafka.common.serialization.IntegerDeserializer; +import org.apache.kafka.common.serialization.IntegerSerializer; +import org.apache.kafka.common.serialization.LongDeserializer; +import org.apache.kafka.common.serialization.StringDeserializer; +import org.apache.kafka.common.serialization.StringSerializer; +import org.apache.kafka.common.utils.Exit; +import org.apache.kafka.common.utils.Utils; + +import java.time.Duration; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.Iterator; +import java.util.LinkedList; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Properties; +import java.util.Random; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; + +public class EosTestDriver extends SmokeTestUtil { + + private static final int MAX_NUMBER_OF_KEYS = 20000; + private static final long MAX_IDLE_TIME_MS = 600000L; + + private volatile static boolean isRunning = true; + private static CountDownLatch terminated = new CountDownLatch(1); + + private static int numRecordsProduced = 0; + + private static synchronized void updateNumRecordsProduces(final int delta) { + numRecordsProduced += delta; + } + + static void generate(final String kafka) { + Exit.addShutdownHook("streams-eos-test-driver-shutdown-hook", () -> { + System.out.println("Terminating"); + isRunning = false; + + try { + if (terminated.await(5L, TimeUnit.MINUTES)) { + System.out.println("Terminated"); + } else { + System.out.println("Terminated with timeout"); + } + } catch (final InterruptedException swallow) { + swallow.printStackTrace(System.err); + System.out.println("Terminated with error"); + } + System.err.flush(); + System.out.flush(); + }); + + final Properties producerProps = new Properties(); + producerProps.put(ProducerConfig.CLIENT_ID_CONFIG, "EosTest"); + producerProps.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, kafka); + producerProps.put(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG, StringSerializer.class); + 
producerProps.put(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG, IntegerSerializer.class); + producerProps.put(ProducerConfig.ENABLE_IDEMPOTENCE_CONFIG, true); + + final Map> offsets = new HashMap<>(); + + try { + try (final KafkaProducer producer = new KafkaProducer<>(producerProps)) { + final Random rand = new Random(System.currentTimeMillis()); + + while (isRunning) { + final String key = "" + rand.nextInt(MAX_NUMBER_OF_KEYS); + final int value = rand.nextInt(10000); + + final ProducerRecord record = new ProducerRecord<>("data", key, value); + + producer.send(record, (metadata, exception) -> { + if (exception != null) { + exception.printStackTrace(System.err); + System.err.flush(); + if (exception instanceof TimeoutException) { + try { + // message == org.apache.kafka.common.errors.TimeoutException: Expiring 4 record(s) for data-0: 30004 ms has passed since last attempt plus backoff time + final int expired = Integer.parseInt(exception.getMessage().split(" ")[2]); + updateNumRecordsProduces(-expired); + } catch (final Exception ignore) { + } + } + } else { + offsets.getOrDefault(metadata.partition(), new LinkedList<>()).add(metadata.offset()); + } + }); + + updateNumRecordsProduces(1); + if (numRecordsProduced % 1000 == 0) { + System.out.println(numRecordsProduced + " records produced"); + System.out.flush(); + } + Utils.sleep(rand.nextInt(10)); + } + } + System.out.println("Producer closed: " + numRecordsProduced + " records produced"); + System.out.flush(); + + // verify offsets + for (final Map.Entry> offsetsOfPartition : offsets.entrySet()) { + offsetsOfPartition.getValue().sort(Long::compareTo); + for (int i = 0; i < offsetsOfPartition.getValue().size() - 1; ++i) { + if (offsetsOfPartition.getValue().get(i) != i) { + System.err.println("Offset for partition " + offsetsOfPartition.getKey() + " is not " + i + " as expected but " + offsetsOfPartition.getValue().get(i)); + System.err.flush(); + } + } + System.out.println("Max offset of partition " + offsetsOfPartition.getKey() + " is " + offsetsOfPartition.getValue().get(offsetsOfPartition.getValue().size() - 1)); + } + + final Properties props = new Properties(); + props.put(ConsumerConfig.CLIENT_ID_CONFIG, "verifier"); + props.put(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, kafka); + props.put(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, ByteArrayDeserializer.class); + props.put(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, ByteArrayDeserializer.class); + props.put(ConsumerConfig.ISOLATION_LEVEL_CONFIG, IsolationLevel.READ_COMMITTED.toString().toLowerCase(Locale.ROOT)); + + try (final KafkaConsumer consumer = new KafkaConsumer<>(props)) { + final List partitions = getAllPartitions(consumer, "data"); + System.out.println("Partitions: " + partitions); + System.out.flush(); + consumer.assign(partitions); + consumer.seekToEnd(partitions); + + for (final TopicPartition tp : partitions) { + System.out.println("End-offset for " + tp + " is " + consumer.position(tp)); + System.out.flush(); + } + } + System.out.flush(); + } finally { + terminated.countDown(); + } + } + + public static void verify(final String kafka, final boolean withRepartitioning) { + final Properties props = new Properties(); + props.put(ConsumerConfig.CLIENT_ID_CONFIG, "verifier"); + props.put(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, kafka); + props.put(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, ByteArrayDeserializer.class); + props.put(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, ByteArrayDeserializer.class); + props.put(ConsumerConfig.ISOLATION_LEVEL_CONFIG, 
IsolationLevel.READ_COMMITTED.toString().toLowerCase(Locale.ROOT)); + + try (final KafkaConsumer consumer = new KafkaConsumer<>(props)) { + verifyAllTransactionFinished(consumer, kafka, withRepartitioning); + } catch (final Exception e) { + e.printStackTrace(System.err); + System.out.println("FAILED"); + return; + } + + final Map committedOffsets; + try (final Admin adminClient = Admin.create(props)) { + ensureStreamsApplicationDown(adminClient); + + committedOffsets = getCommittedOffsets(adminClient, withRepartitioning); + } + + final String[] allInputTopics; + final String[] allOutputTopics; + if (withRepartitioning) { + allInputTopics = new String[] {"data", "repartition"}; + allOutputTopics = new String[] {"echo", "min", "sum", "repartition", "max", "cnt"}; + } else { + allInputTopics = new String[] {"data"}; + allOutputTopics = new String[] {"echo", "min", "sum"}; + } + + final Map>>> inputRecordsPerTopicPerPartition; + try (final KafkaConsumer consumer = new KafkaConsumer<>(props)) { + final List partitions = getAllPartitions(consumer, allInputTopics); + consumer.assign(partitions); + consumer.seekToBeginning(partitions); + + inputRecordsPerTopicPerPartition = getRecords(consumer, committedOffsets, withRepartitioning, true); + } catch (final Exception e) { + e.printStackTrace(System.err); + System.out.println("FAILED"); + return; + } + + final Map>>> outputRecordsPerTopicPerPartition; + try (final KafkaConsumer consumer = new KafkaConsumer<>(props)) { + final List partitions = getAllPartitions(consumer, allOutputTopics); + consumer.assign(partitions); + consumer.seekToBeginning(partitions); + + outputRecordsPerTopicPerPartition = getRecords(consumer, consumer.endOffsets(partitions), withRepartitioning, false); + } catch (final Exception e) { + e.printStackTrace(System.err); + System.out.println("FAILED"); + return; + } + + verifyReceivedAllRecords(inputRecordsPerTopicPerPartition.get("data"), outputRecordsPerTopicPerPartition.get("echo")); + if (withRepartitioning) { + verifyReceivedAllRecords(inputRecordsPerTopicPerPartition.get("data"), outputRecordsPerTopicPerPartition.get("repartition")); + } + + verifyMin(inputRecordsPerTopicPerPartition.get("data"), outputRecordsPerTopicPerPartition.get("min")); + verifySum(inputRecordsPerTopicPerPartition.get("data"), outputRecordsPerTopicPerPartition.get("sum")); + + if (withRepartitioning) { + verifyMax(inputRecordsPerTopicPerPartition.get("repartition"), outputRecordsPerTopicPerPartition.get("max")); + verifyCnt(inputRecordsPerTopicPerPartition.get("repartition"), outputRecordsPerTopicPerPartition.get("cnt")); + } + + // do not modify: required test output + System.out.println("ALL-RECORDS-DELIVERED"); + System.out.flush(); + } + + private static void ensureStreamsApplicationDown(final Admin adminClient) { + + final long maxWaitTime = System.currentTimeMillis() + MAX_IDLE_TIME_MS; + ConsumerGroupDescription description; + do { + description = getConsumerGroupDescription(adminClient); + + if (System.currentTimeMillis() > maxWaitTime && !description.members().isEmpty()) { + throw new RuntimeException( + "Streams application not down after " + (MAX_IDLE_TIME_MS / 1000L) + " seconds. 
" + + "Group: " + description + ); + } + sleep(1000L); + } while (!description.members().isEmpty()); + } + + + private static Map getCommittedOffsets(final Admin adminClient, + final boolean withRepartitioning) { + final Map topicPartitionOffsetAndMetadataMap; + + try { + final ListConsumerGroupOffsetsResult listConsumerGroupOffsetsResult = adminClient.listConsumerGroupOffsets(EosTestClient.APP_ID); + topicPartitionOffsetAndMetadataMap = listConsumerGroupOffsetsResult.partitionsToOffsetAndMetadata().get(10, TimeUnit.SECONDS); + } catch (final Exception e) { + e.printStackTrace(); + throw new RuntimeException(e); + } + + final Map committedOffsets = new HashMap<>(); + + for (final Map.Entry entry : topicPartitionOffsetAndMetadataMap.entrySet()) { + final String topic = entry.getKey().topic(); + if (topic.equals("data") || withRepartitioning && topic.equals("repartition")) { + committedOffsets.put(entry.getKey(), entry.getValue().offset()); + } + } + + return committedOffsets; + } + + private static Map>>> getRecords(final KafkaConsumer consumer, + final Map readEndOffsets, + final boolean withRepartitioning, + final boolean isInputTopic) { + System.out.println("read end offset: " + readEndOffsets); + final Map>>> recordPerTopicPerPartition = new HashMap<>(); + final Map maxReceivedOffsetPerPartition = new HashMap<>(); + final Map maxConsumerPositionPerPartition = new HashMap<>(); + + long maxWaitTime = System.currentTimeMillis() + MAX_IDLE_TIME_MS; + boolean allRecordsReceived = false; + while (!allRecordsReceived && System.currentTimeMillis() < maxWaitTime) { + final ConsumerRecords receivedRecords = consumer.poll(Duration.ofSeconds(1L)); + + for (final ConsumerRecord record : receivedRecords) { + maxWaitTime = System.currentTimeMillis() + MAX_IDLE_TIME_MS; + final TopicPartition tp = new TopicPartition(record.topic(), record.partition()); + maxReceivedOffsetPerPartition.put(tp, record.offset()); + final long readEndOffset = readEndOffsets.get(tp); + if (record.offset() < readEndOffset) { + addRecord(record, recordPerTopicPerPartition, withRepartitioning); + } else if (!isInputTopic) { + throw new RuntimeException("FAIL: did receive more records than expected for " + tp + + " (expected EOL offset: " + readEndOffset + "; current offset: " + record.offset()); + } + } + + for (final TopicPartition tp : readEndOffsets.keySet()) { + maxConsumerPositionPerPartition.put(tp, consumer.position(tp)); + if (consumer.position(tp) >= readEndOffsets.get(tp)) { + consumer.pause(Collections.singletonList(tp)); + } + } + + allRecordsReceived = consumer.paused().size() == readEndOffsets.keySet().size(); + } + + if (!allRecordsReceived) { + System.err.println("Pause partitions (ie, received all data): " + consumer.paused()); + System.err.println("Max received offset per partition: " + maxReceivedOffsetPerPartition); + System.err.println("Max consumer position per partition: " + maxConsumerPositionPerPartition); + throw new RuntimeException("FAIL: did not receive all records after " + (MAX_IDLE_TIME_MS / 1000L) + " sec idle time."); + } + + return recordPerTopicPerPartition; + } + + private static void addRecord(final ConsumerRecord record, + final Map>>> recordPerTopicPerPartition, + final boolean withRepartitioning) { + + final String topic = record.topic(); + final TopicPartition partition = new TopicPartition(topic, record.partition()); + + if (verifyTopic(topic, withRepartitioning)) { + final Map>> topicRecordsPerPartition = + recordPerTopicPerPartition.computeIfAbsent(topic, k -> new HashMap<>()); + + 
final List> records = + topicRecordsPerPartition.computeIfAbsent(partition, k -> new ArrayList<>()); + + records.add(record); + } else { + throw new RuntimeException("FAIL: received data from unexpected topic: " + record); + } + } + + private static boolean verifyTopic(final String topic, + final boolean withRepartitioning) { + final boolean validTopic = "data".equals(topic) || "echo".equals(topic) || "min".equals(topic) || "sum".equals(topic); + + if (withRepartitioning) { + return validTopic || "repartition".equals(topic) || "max".equals(topic) || "cnt".equals(topic); + } + + return validTopic; + } + + private static void verifyReceivedAllRecords(final Map>> expectedRecords, + final Map>> receivedRecords) { + if (expectedRecords.size() != receivedRecords.size()) { + throw new RuntimeException("Result verification failed. Received " + receivedRecords.size() + " records but expected " + expectedRecords.size()); + } + + final StringDeserializer stringDeserializer = new StringDeserializer(); + final IntegerDeserializer integerDeserializer = new IntegerDeserializer(); + for (final Map.Entry>> partitionRecords : receivedRecords.entrySet()) { + final TopicPartition inputTopicPartition = new TopicPartition("data", partitionRecords.getKey().partition()); + final List> receivedRecordsForPartition = partitionRecords.getValue(); + final List> expectedRecordsForPartition = expectedRecords.get(inputTopicPartition); + + System.out.println(partitionRecords.getKey() + " with " + receivedRecordsForPartition.size() + ", " + + inputTopicPartition + " with " + expectedRecordsForPartition.size()); + + final Iterator> expectedRecord = expectedRecordsForPartition.iterator(); + RuntimeException exception = null; + for (final ConsumerRecord receivedRecord : receivedRecordsForPartition) { + if (!expectedRecord.hasNext()) { + exception = new RuntimeException("Result verification failed for " + receivedRecord + " since there's no more expected record"); + } + + final ConsumerRecord expected = expectedRecord.next(); + + final String receivedKey = stringDeserializer.deserialize(receivedRecord.topic(), receivedRecord.key()); + final int receivedValue = integerDeserializer.deserialize(receivedRecord.topic(), receivedRecord.value()); + final String expectedKey = stringDeserializer.deserialize(expected.topic(), expected.key()); + final int expectedValue = integerDeserializer.deserialize(expected.topic(), expected.value()); + + if (!receivedKey.equals(expectedKey) || receivedValue != expectedValue) { + exception = new RuntimeException("Result verification failed for " + receivedRecord + " expected <" + expectedKey + "," + expectedValue + "> but was <" + receivedKey + "," + receivedValue + ">"); + } + } + + if (exception != null) { + throw exception; + } + } + } + + private static void verifyMin(final Map>> inputPerTopicPerPartition, + final Map>> minPerTopicPerPartition) { + final StringDeserializer stringDeserializer = new StringDeserializer(); + final IntegerDeserializer integerDeserializer = new IntegerDeserializer(); + + final HashMap currentMinPerKey = new HashMap<>(); + for (final Map.Entry>> partitionRecords : minPerTopicPerPartition.entrySet()) { + final TopicPartition inputTopicPartition = new TopicPartition("data", partitionRecords.getKey().partition()); + final List> partitionInput = inputPerTopicPerPartition.get(inputTopicPartition); + final List> partitionMin = partitionRecords.getValue(); + + if (partitionInput.size() != partitionMin.size()) { + throw new RuntimeException("Result verification failed: expected 
" + partitionInput.size() + " records for " + + partitionRecords.getKey() + " but received " + partitionMin.size()); + } + + final Iterator> inputRecords = partitionInput.iterator(); + + for (final ConsumerRecord receivedRecord : partitionMin) { + final ConsumerRecord input = inputRecords.next(); + + final String receivedKey = stringDeserializer.deserialize(receivedRecord.topic(), receivedRecord.key()); + final int receivedValue = integerDeserializer.deserialize(receivedRecord.topic(), receivedRecord.value()); + final String key = stringDeserializer.deserialize(input.topic(), input.key()); + final int value = integerDeserializer.deserialize(input.topic(), input.value()); + + Integer min = currentMinPerKey.get(key); + if (min == null) { + min = value; + } else { + min = Math.min(min, value); + } + currentMinPerKey.put(key, min); + + if (!receivedKey.equals(key) || receivedValue != min) { + throw new RuntimeException("Result verification failed for " + receivedRecord + " expected <" + key + "," + min + "> but was <" + receivedKey + "," + receivedValue + ">"); + } + } + } + } + + private static void verifySum(final Map>> inputPerTopicPerPartition, + final Map>> minPerTopicPerPartition) { + final StringDeserializer stringDeserializer = new StringDeserializer(); + final IntegerDeserializer integerDeserializer = new IntegerDeserializer(); + final LongDeserializer longDeserializer = new LongDeserializer(); + + final HashMap currentSumPerKey = new HashMap<>(); + for (final Map.Entry>> partitionRecords : minPerTopicPerPartition.entrySet()) { + final TopicPartition inputTopicPartition = new TopicPartition("data", partitionRecords.getKey().partition()); + final List> partitionInput = inputPerTopicPerPartition.get(inputTopicPartition); + final List> partitionSum = partitionRecords.getValue(); + + if (partitionInput.size() != partitionSum.size()) { + throw new RuntimeException("Result verification failed: expected " + partitionInput.size() + " records for " + + partitionRecords.getKey() + " but received " + partitionSum.size()); + } + + final Iterator> inputRecords = partitionInput.iterator(); + + for (final ConsumerRecord receivedRecord : partitionSum) { + final ConsumerRecord input = inputRecords.next(); + + final String receivedKey = stringDeserializer.deserialize(receivedRecord.topic(), receivedRecord.key()); + final long receivedValue = longDeserializer.deserialize(receivedRecord.topic(), receivedRecord.value()); + final String key = stringDeserializer.deserialize(input.topic(), input.key()); + final int value = integerDeserializer.deserialize(input.topic(), input.value()); + + Long sum = currentSumPerKey.get(key); + if (sum == null) { + sum = (long) value; + } else { + sum += value; + } + currentSumPerKey.put(key, sum); + + if (!receivedKey.equals(key) || receivedValue != sum) { + throw new RuntimeException("Result verification failed for " + receivedRecord + " expected <" + key + "," + sum + "> but was <" + receivedKey + "," + receivedValue + ">"); + } + } + } + } + + private static void verifyMax(final Map>> inputPerTopicPerPartition, + final Map>> maxPerTopicPerPartition) { + final StringDeserializer stringDeserializer = new StringDeserializer(); + final IntegerDeserializer integerDeserializer = new IntegerDeserializer(); + + final HashMap currentMinPerKey = new HashMap<>(); + for (final Map.Entry>> partitionRecords : maxPerTopicPerPartition.entrySet()) { + final TopicPartition inputTopicPartition = new TopicPartition("repartition", partitionRecords.getKey().partition()); + final List> 
partitionInput = inputPerTopicPerPartition.get(inputTopicPartition); + final List> partitionMax = partitionRecords.getValue(); + + if (partitionInput.size() != partitionMax.size()) { + throw new RuntimeException("Result verification failed: expected " + partitionInput.size() + " records for " + + partitionRecords.getKey() + " but received " + partitionMax.size()); + } + + final Iterator> inputRecords = partitionInput.iterator(); + + for (final ConsumerRecord receivedRecord : partitionMax) { + final ConsumerRecord input = inputRecords.next(); + + final String receivedKey = stringDeserializer.deserialize(receivedRecord.topic(), receivedRecord.key()); + final int receivedValue = integerDeserializer.deserialize(receivedRecord.topic(), receivedRecord.value()); + final String key = stringDeserializer.deserialize(input.topic(), input.key()); + final int value = integerDeserializer.deserialize(input.topic(), input.value()); + + + Integer max = currentMinPerKey.get(key); + if (max == null) { + max = Integer.MIN_VALUE; + } + max = Math.max(max, value); + currentMinPerKey.put(key, max); + + if (!receivedKey.equals(key) || receivedValue != max) { + throw new RuntimeException("Result verification failed for " + receivedRecord + " expected <" + key + "," + max + "> but was <" + receivedKey + "," + receivedValue + ">"); + } + } + } + } + + private static void verifyCnt(final Map>> inputPerTopicPerPartition, + final Map>> cntPerTopicPerPartition) { + final StringDeserializer stringDeserializer = new StringDeserializer(); + final LongDeserializer longDeserializer = new LongDeserializer(); + + final HashMap currentSumPerKey = new HashMap<>(); + for (final Map.Entry>> partitionRecords : cntPerTopicPerPartition.entrySet()) { + final TopicPartition inputTopicPartition = new TopicPartition("repartition", partitionRecords.getKey().partition()); + final List> partitionInput = inputPerTopicPerPartition.get(inputTopicPartition); + final List> partitionCnt = partitionRecords.getValue(); + + if (partitionInput.size() != partitionCnt.size()) { + throw new RuntimeException("Result verification failed: expected " + partitionInput.size() + " records for " + + partitionRecords.getKey() + " but received " + partitionCnt.size()); + } + + final Iterator> inputRecords = partitionInput.iterator(); + + for (final ConsumerRecord receivedRecord : partitionCnt) { + final ConsumerRecord input = inputRecords.next(); + + final String receivedKey = stringDeserializer.deserialize(receivedRecord.topic(), receivedRecord.key()); + final long receivedValue = longDeserializer.deserialize(receivedRecord.topic(), receivedRecord.value()); + final String key = stringDeserializer.deserialize(input.topic(), input.key()); + + Long cnt = currentSumPerKey.get(key); + if (cnt == null) { + cnt = 0L; + } + currentSumPerKey.put(key, ++cnt); + + if (!receivedKey.equals(key) || receivedValue != cnt) { + throw new RuntimeException("Result verification failed for " + receivedRecord + " expected <" + key + "," + cnt + "> but was <" + receivedKey + "," + receivedValue + ">"); + } + } + } + } + + private static void verifyAllTransactionFinished(final KafkaConsumer consumer, + final String kafka, + final boolean withRepartitioning) { + final String[] topics; + if (withRepartitioning) { + topics = new String[] {"echo", "min", "sum", "repartition", "max", "cnt"}; + } else { + topics = new String[] {"echo", "min", "sum"}; + } + + final List partitions = getAllPartitions(consumer, topics); + consumer.assign(partitions); + consumer.seekToEnd(partitions); + for 
(final TopicPartition tp : partitions) { + System.out.println(tp + " at position " + consumer.position(tp)); + } + + final Properties consumerProps = new Properties(); + consumerProps.put(ConsumerConfig.CLIENT_ID_CONFIG, "consumer-uncommitted"); + consumerProps.put(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, kafka); + consumerProps.put(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, ByteArrayDeserializer.class); + consumerProps.put(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, ByteArrayDeserializer.class); + + + final long maxWaitTime = System.currentTimeMillis() + MAX_IDLE_TIME_MS; + try (final KafkaConsumer consumerUncommitted = new KafkaConsumer<>(consumerProps)) { + while (!partitions.isEmpty() && System.currentTimeMillis() < maxWaitTime) { + consumer.seekToEnd(partitions); + final Map topicEndOffsets = consumerUncommitted.endOffsets(partitions); + + final Iterator iterator = partitions.iterator(); + while (iterator.hasNext()) { + final TopicPartition topicPartition = iterator.next(); + final long position = consumer.position(topicPartition); + + if (position == topicEndOffsets.get(topicPartition)) { + iterator.remove(); + System.out.println("Removing " + topicPartition + " at position " + position); + } else if (consumer.position(topicPartition) > topicEndOffsets.get(topicPartition)) { + throw new IllegalStateException("Offset for partition " + topicPartition + " is larger than topic endOffset: " + position + " > " + topicEndOffsets.get(topicPartition)); + } else { + System.out.println("Retry " + topicPartition + " at position " + position); + } + } + sleep(1000L); + } + } + + if (!partitions.isEmpty()) { + throw new RuntimeException("Could not read all verification records. Did not receive any new record within the last " + (MAX_IDLE_TIME_MS / 1000L) + " sec."); + } + } + + private static List getAllPartitions(final KafkaConsumer consumer, + final String... topics) { + final ArrayList partitions = new ArrayList<>(); + + for (final String topic : topics) { + for (final PartitionInfo info : consumer.partitionsFor(topic)) { + partitions.add(new TopicPartition(info.topic(), info.partition())); + } + } + return partitions; + } + + + private static ConsumerGroupDescription getConsumerGroupDescription(final Admin adminClient) { + final ConsumerGroupDescription description; + try { + description = adminClient.describeConsumerGroups(Collections.singleton(EosTestClient.APP_ID)) + .describedGroups() + .get(EosTestClient.APP_ID) + .get(10, TimeUnit.SECONDS); + } catch (final InterruptedException | ExecutionException | java.util.concurrent.TimeoutException e) { + e.printStackTrace(); + throw new RuntimeException("Unexpected Exception getting group description", e); + } + return description; + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/tests/RelationalSmokeTest.java b/streams/src/test/java/org/apache/kafka/streams/tests/RelationalSmokeTest.java new file mode 100644 index 0000000..8802109 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/tests/RelationalSmokeTest.java @@ -0,0 +1,1004 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.tests; + +import org.apache.kafka.clients.consumer.ConsumerConfig; +import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.clients.consumer.ConsumerRecords; +import org.apache.kafka.clients.consumer.KafkaConsumer; +import org.apache.kafka.clients.producer.KafkaProducer; +import org.apache.kafka.clients.producer.ProducerConfig; +import org.apache.kafka.clients.producer.ProducerRecord; +import org.apache.kafka.common.PartitionInfo; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.serialization.ByteArrayDeserializer; +import org.apache.kafka.common.serialization.ByteArraySerializer; +import org.apache.kafka.common.serialization.Deserializer; +import org.apache.kafka.common.serialization.IntegerDeserializer; +import org.apache.kafka.common.serialization.IntegerSerializer; +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.serialization.Serializer; +import org.apache.kafka.streams.KafkaStreams; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.Topology; +import org.apache.kafka.streams.kstream.Consumed; +import org.apache.kafka.streams.kstream.Grouped; +import org.apache.kafka.streams.kstream.KTable; +import org.apache.kafka.streams.kstream.Materialized; +import org.apache.kafka.streams.kstream.ValueJoiner; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.nio.ByteBuffer; +import java.time.Duration; +import java.time.Instant; +import java.util.Arrays; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Properties; +import java.util.Random; +import java.util.TreeMap; +import java.util.UUID; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.stream.Stream; + +import static java.util.stream.Collectors.toList; +import static org.apache.kafka.common.utils.Utils.mkEntry; +import static org.apache.kafka.common.utils.Utils.mkMap; +import static org.apache.kafka.common.utils.Utils.mkProperties; + +/** + * This test builds on a basic relational data caricature: + * a set of articles, each with a set of comments on them. + *
<p> + * The particular relation in this data is that each Comment + * has an articleId foreign-key reference to an Article. + * <p>
                + * We traverse this relation in both directions to verify + * correct operation of Kafka Streams: + * - Embed on each Comment a prefix of the text in the Article it references + * aka SELECT * FROM Comment JOIN Article ON Comment.ArticleID = Article.ID + * - Embed on each Article a count of all the Comments it has + */ +public class RelationalSmokeTest extends SmokeTestUtil { + private static final Logger LOG = LoggerFactory.getLogger(RelationalSmokeTest.class); + + static final String ARTICLE_SOURCE = "in-article"; + static final String COMMENT_SOURCE = "in-comment"; + static final String ARTICLE_RESULT_SINK = "out-augmented-article"; + static final String COMMENT_RESULT_SINK = "out-augmented-comment"; + private static final String[] TOPICS = { + ARTICLE_SOURCE, + COMMENT_SOURCE, + ARTICLE_RESULT_SINK, + COMMENT_RESULT_SINK + }; + + public static String[] topics() { + return Arrays.copyOf(TOPICS, TOPICS.length); + } + + public static class Article { + private final int key; + private final long timestamp; + private final String text; + + private Article(final int key, final long timestamp, final String text) { + this.key = key; + this.timestamp = timestamp; + this.text = text; + } + + public int getKey() { + return key; + } + + public long getTimestamp() { + return timestamp; + } + + public String getText() { + return text; + } + + @Override + public String toString() { + return "Article{" + + "key=" + key + + ", timestamp=" + Instant.ofEpochMilli(timestamp) + + ", text='" + text + '\'' + + '}'; + } + + public static class ArticleSerializer implements Serializer
<Article> { + @Override + public byte[] serialize(final String topic, final Article data) { + final byte[] serialText = stringSerde.serializer().serialize(topic, data.getText()); + + final int length = Integer.BYTES + Long.BYTES + Integer.BYTES + serialText.length; + + final ByteBuffer buffer = + ByteBuffer.allocate(length) + .putInt(data.getKey()) + .putLong(data.getTimestamp()) + .putInt(serialText.length) + .put(serialText); + + return Serdes.ByteBuffer().serializer().serialize(topic, buffer); + } + } + + public static class ArticleDeserializer implements Deserializer<Article>
                { + + public static Article deserialize(final String topic, final ByteBuffer buffer) { + final int key = buffer.getInt(); + final long timestamp = buffer.getLong(); + final int textLength = buffer.getInt(); + final byte[] serialText = new byte[textLength]; + buffer.get(serialText); + final String text = stringSerde.deserializer().deserialize(topic, serialText); + return new Article(key, timestamp, text); + } + + @Override + public Article deserialize(final String topic, final byte[] data) { + final ByteBuffer buffer = Serdes.ByteBuffer().deserializer().deserialize(topic, data); + return deserialize(topic, buffer); + } + } + + public static class ArticleSerde implements Serde
<Article> { + @Override + public Serializer<Article>
serializer() { + return new ArticleSerializer(); + } + + @Override + public Deserializer<Article>
                deserializer() { + return new ArticleDeserializer(); + } + } + } + + public static class Comment { + private final int key; + private final long timestamp; + private final String text; + private final int articleId; + + private Comment(final int key, final long timestamp, final String text, final int articleId) { + this.key = key; + this.timestamp = timestamp; + this.text = text; + this.articleId = articleId; + } + + public int getKey() { + return key; + } + + public long getTimestamp() { + return timestamp; + } + + public String getText() { + return text; + } + + public int getArticleId() { + return articleId; + } + + @Override + public String toString() { + return "Comment{" + + "key=" + key + + ", timestamp=" + Instant.ofEpochMilli(timestamp) + + ", text='" + text + '\'' + + ", articleId=" + articleId + + '}'; + } + + public static class CommentSerializer implements Serializer { + + @Override + public byte[] serialize(final String topic, final Comment data) { + final byte[] serialText = stringSerde.serializer().serialize(topic, data.text); + + final int length = Integer.BYTES + Long.BYTES + (Integer.BYTES + serialText.length) + Integer.BYTES; + + final ByteBuffer buffer = + ByteBuffer.allocate(length) + .putInt(data.key) + .putLong(data.timestamp) + .putInt(serialText.length) + .put(serialText) + .putInt(data.articleId); + + return Serdes.ByteBuffer().serializer().serialize(topic, buffer); + } + } + + public static class CommentDeserializer implements Deserializer { + + public static Comment deserialize(final String topic, final ByteBuffer buffer) { + final int key = buffer.getInt(); + final long timestamp = buffer.getLong(); + final int textLength = buffer.getInt(); + final byte[] textBytes = new byte[textLength]; + buffer.get(textBytes); + final String text = stringSerde.deserializer().deserialize(topic, textBytes); + final int articleId = buffer.getInt(); + + return new Comment(key, timestamp, text, articleId); + } + + @Override + public Comment deserialize(final String topic, final byte[] data) { + final ByteBuffer buffer = Serdes.ByteBuffer().deserializer().deserialize(topic, data); + return deserialize(topic, buffer); + } + } + + public static class CommentSerde implements Serde { + + @Override + public Serializer serializer() { + return new CommentSerializer(); + } + + @Override + public Deserializer deserializer() { + return new CommentDeserializer(); + } + } + } + + public static final class DataSet { + private final Article[] articles; + private final Comment[] comments; + + private DataSet(final Article[] articles, final Comment[] comments) { + this.articles = articles; + this.comments = comments; + } + + public Article[] getArticles() { + return articles; + } + + public Comment[] getComments() { + return comments; + } + + @Override + public String toString() { + final StringBuilder stringBuilder = new StringBuilder(); + stringBuilder.append(articles.length).append(" Articles").append('\n'); + for (final Article article : articles) { + stringBuilder.append(" ").append(article).append('\n'); + } + stringBuilder.append(comments.length).append(" Comments").append("\n"); + for (final Comment comment : comments) { + stringBuilder.append(" ").append(comment).append('\n'); + } + return stringBuilder.toString(); + } + + public static DataSet generate(final int numArticles, final int numComments) { + // generate four days' worth of data, starting right now (to avoid broker retention/compaction) + final int timeSpan = 1000 * 60 * 60 * 24 * 4; + final long dataStartTime 
= System.currentTimeMillis(); + final long dataEndTime = dataStartTime + timeSpan; + + // Explicitly create a seed so we can we can log. + // If we are debugging a failed run, we can deterministically produce the same dataset + // by plugging in the seed from that run. + final long seed = new Random().nextLong(); + final Random random = new Random(seed); + LOG.info("Dataset PRNG seed: {}", seed); + final Iterator articlesToCommentOnSequence = zipfNormal(random, numArticles); + + final Article[] articles = new Article[numArticles]; + final Comment[] comments = new Comment[numComments]; + + // first generate the articles (note: out of order) + for (int i = 0; i < numArticles; i++) { + final long timestamp = random.nextInt(timeSpan) + dataStartTime; + final String text = randomText(random, 2000); + articles[i] = new Article(i, timestamp, text); + } + + // then spend the rest of the time generating the comments + for (int i = 0; i < numComments; i++) { + final int articleId = articlesToCommentOnSequence.next(); + final long articleTimestamp = articles[articleId].getTimestamp(); + // comments get written after articles + final long timestamp = random.nextInt((int) (dataEndTime - articleTimestamp)) + articleTimestamp; + final String text = randomText(random, 200); + final Comment comment = new Comment(i, timestamp, text, articleId); + comments[i] = comment; + } + return new DataSet(articles, comments); + } + + /** + * Rough-and-ready random text generator. Creates a text with a length normally + * distributed about {@code avgLength} with a standard deviation of 1/3rd {@code avgLength}. + * Each letter is drawn uniformly from a-z. + */ + private static String randomText(final Random random, final int avgLength) { + final int lowChar = 97; // letter 'a' + final int highChar = 122; // letter 'z' + + final int length = Math.max(0, (int) (random.nextGaussian() * avgLength / 3.0) + avgLength); + final char[] chars = new char[length]; + for (int i = 0; i < chars.length; i++) { + chars[i] = (char) (random.nextInt(highChar - lowChar) + lowChar); + } + return new String(chars); + } + + /** + * Generates a keySpace number of unique keys normally distributed, + * with the mean at 0 and stdDev 1/3 of the way through the keySpace + * any sample more than 3 standard deviations from the mean are assigned to the 0 key. + *
<p>
                + * This is designed to roughly balance the key properties of two dominant real-world + * data distribution: Zipfian and Normal, while also being efficient to generate. + */ + private static Iterator zipfNormal(final Random random, final int keySpace) { + return new Iterator() { + @Override + public boolean hasNext() { + return true; + } + + @Override + public Integer next() { + final double gaussian = Math.abs(random.nextGaussian()); + final double scaled = gaussian / 3.0; // + final double sample = scaled > 1.0 ? 0.0 : scaled; + final double keyDouble = sample * keySpace; + return (int) keyDouble; + } + }; + } + + public void produce(final String kafka, final Duration timeToSpend) throws InterruptedException { + final Properties producerProps = new Properties(); + final String id = "RelationalSmokeTestProducer" + UUID.randomUUID(); + producerProps.put(ProducerConfig.CLIENT_ID_CONFIG, id); + producerProps.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, kafka); + producerProps.put(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG, IntegerSerializer.class); + producerProps.put(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG, ByteArraySerializer.class); + producerProps.put(ProducerConfig.ACKS_CONFIG, "-1"); + + final Article.ArticleSerializer articleSerializer = new Article.ArticleSerializer(); + final Comment.CommentSerializer commentSerializer = new Comment.CommentSerializer(); + + final long pauseTime = timeToSpend.toMillis() / (articles.length + comments.length); + + try (final KafkaProducer producer = new KafkaProducer<>(producerProps)) { + int a = 0; + int c = 0; + while (a < articles.length || c < comments.length) { + final ProducerRecord producerRecord; + if (a < articles.length && c >= comments.length || + a < articles.length && articles[a].getTimestamp() <= comments[c].timestamp) { + producerRecord = new ProducerRecord<>( + ARTICLE_SOURCE, + null, + articles[a].getTimestamp(), + articles[a].getKey(), + articleSerializer.serialize("", articles[a]) + ); + a++; + } else { + producerRecord = new ProducerRecord<>( + COMMENT_SOURCE, + null, + comments[c].getTimestamp(), + comments[c].getKey(), + commentSerializer.serialize("", comments[c]) + ); + c++; + } + producer.send(producerRecord); + producer.flush(); + LOG.info("sent {} {}", producerRecord.topic(), producerRecord.key()); + Thread.sleep(pauseTime); + } + } + } + } + + public static final class AugmentedArticle extends Article { + private final long commentCount; + + private AugmentedArticle(final int key, final long timestamp, final String text, final long commentCount) { + super(key, timestamp, text); + this.commentCount = commentCount; + } + + public long getCommentCount() { + return commentCount; + } + + @Override + public String toString() { + return "AugmentedArticle{" + + "key=" + super.key + + ", timestamp=" + getTimestamp() + + ", text='" + getText() + '\'' + + ", commentCount=" + commentCount + + '}'; + } + + public static class AugmentedArticleSerializer implements Serializer { + private final ArticleSerializer articleSerializer = new ArticleSerializer(); + + @Override + public byte[] serialize(final String topic, final AugmentedArticle data) { + final byte[] serializedArticle = articleSerializer.serialize(topic, data); + final int length = serializedArticle.length + Long.BYTES; + final ByteBuffer buffer = + ByteBuffer.allocate(length) + .put(serializedArticle) + .putLong(data.getCommentCount()); + return Serdes.ByteBuffer().serializer().serialize(topic, buffer); + } + } + + public static class 
AugmentedArticleDeserializer implements Deserializer { + @Override + public AugmentedArticle deserialize(final String topic, final byte[] data) { + final ByteBuffer wrap = ByteBuffer.wrap(data); + + final Article article = ArticleDeserializer.deserialize(topic, wrap); + final long commentCount = wrap.getLong(); + return new AugmentedArticle(article.key, article.getTimestamp(), article.getText(), commentCount); + } + } + + public static class AugmentedArticleSerde implements Serde { + + @Override + public Serializer serializer() { + return new AugmentedArticleSerializer(); + } + + @Override + public Deserializer deserializer() { + return new AugmentedArticleDeserializer(); + } + } + + public static ValueJoiner joiner() { + return (article, commentCount) -> new AugmentedArticle( + article.getKey(), + article.getTimestamp(), + article.getText(), commentCount == null ? 0 : commentCount + ); + } + } + + public static final class AugmentedComment extends Comment { + private final String articlePrefix; + + private AugmentedComment(final int key, + final long timestamp, + final String text, + final int articleId, + final String articlePrefix) { + super(key, timestamp, text, articleId); + this.articlePrefix = articlePrefix; + } + + public String getArticlePrefix() { + return articlePrefix; + } + + @Override + public String toString() { + return "AugmentedComment{" + + "key=" + super.key + + ", timestamp=" + getTimestamp() + + ", text='" + getText() + '\'' + + ", articleId=" + getArticleId() + + ", articlePrefix='" + articlePrefix + '\'' + + '}'; + } + + public static class AugmentedCommentSerializer implements Serializer { + private final CommentSerializer commentSerializer = new CommentSerializer(); + + @Override + public byte[] serialize(final String topic, final AugmentedComment data) { + final byte[] serializedComment = commentSerializer.serialize(topic, data); + final byte[] serializedPrefix = stringSerde.serializer().serialize(topic, data.getArticlePrefix()); + final int length = serializedComment.length + Integer.BYTES + serializedPrefix.length; + final ByteBuffer buffer = + ByteBuffer.allocate(length) + .put(serializedComment) + .putInt(serializedPrefix.length) + .put(serializedPrefix); + return Serdes.ByteBuffer().serializer().serialize(topic, buffer); + } + } + + public static class AugmentedCommentDeserializer implements Deserializer { + @Override + public AugmentedComment deserialize(final String topic, final byte[] data) { + final ByteBuffer wrap = ByteBuffer.wrap(data); + + final Comment comment = CommentDeserializer.deserialize(topic, wrap); + final int prefixLength = wrap.getInt(); + final byte[] serializedPrefix = new byte[prefixLength]; + wrap.get(serializedPrefix); + final String articlePrefix = stringSerde.deserializer().deserialize(topic, serializedPrefix); + return new AugmentedComment( + comment.key, + comment.getTimestamp(), + comment.getText(), + comment.getArticleId(), + articlePrefix + ); + } + } + + public static class AugmentedCommentSerde implements Serde { + + @Override + public Serializer serializer() { + return new AugmentedCommentSerializer(); + } + + @Override + public Deserializer deserializer() { + return new AugmentedCommentDeserializer(); + } + } + + private static String prefix(final String text, final int length) { + return text.length() < length ? 
text : text.substring(0, length); + } + + public static ValueJoiner joiner() { + return (comment, article) -> new AugmentedComment( + comment.key, + comment.getTimestamp(), + comment.getText(), + comment.getArticleId(), + prefix(article.getText(), 10) + ); + } + } + + public static final class App { + public static Topology getTopology() { + final StreamsBuilder streamsBuilder = new StreamsBuilder(); + final KTable articles = + streamsBuilder.table(ARTICLE_SOURCE, Consumed.with(intSerde, new Article.ArticleSerde())); + + final KTable comments = + streamsBuilder.table(COMMENT_SOURCE, Consumed.with(intSerde, new Comment.CommentSerde())); + + + final KTable commentCounts = + comments.groupBy( + (key, value) -> new KeyValue<>(value.getArticleId(), (short) 1), + Grouped.with(Serdes.Integer(), Serdes.Short()) + ) + .count(); + + articles + .leftJoin( + commentCounts, + AugmentedArticle.joiner(), + Materialized.with(null, new AugmentedArticle.AugmentedArticleSerde()) + ) + .toStream() + .to(ARTICLE_RESULT_SINK); + + comments.join(articles, + Comment::getArticleId, + AugmentedComment.joiner(), + Materialized.with(null, new AugmentedComment.AugmentedCommentSerde())) + .toStream() + .to(COMMENT_RESULT_SINK); + + return streamsBuilder.build(); + } + + public static Properties getConfig(final String broker, + final String application, + final String id, + final String processingGuarantee, + final String stateDir) { + final Properties properties = + mkProperties( + mkMap( + mkEntry(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, broker), + mkEntry(StreamsConfig.APPLICATION_ID_CONFIG, application), + mkEntry(StreamsConfig.CLIENT_ID_CONFIG, id), + mkEntry(StreamsConfig.PROCESSING_GUARANTEE_CONFIG, processingGuarantee), + mkEntry(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "earliest"), + mkEntry(StreamsConfig.STATE_DIR_CONFIG, stateDir) + ) + ); + properties.put(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG, 1000L); + return properties; + } + + public static KafkaStreams startSync(final String broker, + final String application, + final String id, + final String processingGuarantee, + final String stateDir) throws InterruptedException { + final KafkaStreams kafkaStreams = + new KafkaStreams(getTopology(), getConfig(broker, application, id, processingGuarantee, stateDir)); + final CountDownLatch startUpLatch = new CountDownLatch(1); + kafkaStreams.setStateListener((newState, oldState) -> { + if (oldState == KafkaStreams.State.REBALANCING && newState == KafkaStreams.State.RUNNING) { + startUpLatch.countDown(); + } + }); + kafkaStreams.start(); + startUpLatch.await(); + LOG.info("Streams has started."); + return kafkaStreams; + } + + public static boolean verifySync(final String broker, final Instant deadline) throws InterruptedException { + final Deserializer keyDeserializer = intSerde.deserializer(); + + final Deserializer
                articleDeserializer = new Article.ArticleDeserializer(); + + final Deserializer augmentedArticleDeserializer = + new AugmentedArticle.AugmentedArticleDeserializer(); + + final Deserializer commentDeserializer = new Comment.CommentDeserializer(); + + final Deserializer augmentedCommentDeserializer = + new AugmentedComment.AugmentedCommentDeserializer(); + + + final Properties consumerProperties = new Properties(); + final String id = "RelationalSmokeTestConsumer" + UUID.randomUUID(); + consumerProperties.put(ConsumerConfig.CLIENT_ID_CONFIG, id); + consumerProperties.put(ConsumerConfig.GROUP_ID_CONFIG, id); + consumerProperties.put(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, broker); + consumerProperties.put(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, IntegerDeserializer.class); + consumerProperties.put(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, ByteArrayDeserializer.class); + consumerProperties.put(ConsumerConfig.ISOLATION_LEVEL_CONFIG, "read_committed"); + consumerProperties.put(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG, false); + + try (final KafkaConsumer consumer = new KafkaConsumer<>(consumerProperties)) { + final List articlePartitions = consumer.partitionsFor(ARTICLE_SOURCE); + final List augmentedArticlePartitions = consumer.partitionsFor(ARTICLE_RESULT_SINK); + final List commentPartitions = consumer.partitionsFor(COMMENT_SOURCE); + final List augmentedCommentPartitions = consumer.partitionsFor(COMMENT_RESULT_SINK); + final List assignment = + Stream.concat( + Stream.concat( + articlePartitions.stream().map(p -> new TopicPartition(p.topic(), p.partition())), + augmentedArticlePartitions.stream().map(p -> new TopicPartition(p.topic(), p.partition())) + ), + Stream.concat( + commentPartitions.stream().map(p -> new TopicPartition(p.topic(), p.partition())), + augmentedCommentPartitions.stream().map(p -> new TopicPartition(p.topic(), p.partition())) + ) + ).collect(toList()); + consumer.assign(assignment); + consumer.seekToBeginning(assignment); + + final Map consumedArticles = new TreeMap<>(); + final Map consumedAugmentedArticles = new TreeMap<>(); + final Map consumedComments = new TreeMap<>(); + final Map consumedAugmentedComments = new TreeMap<>(); + + boolean printedConsumedArticle = false; + boolean printedConsumedAugmentedArticle = false; + boolean printedConsumedComment = false; + boolean printedConsumedAugmentedComment = false; + boolean passed = false; + + while (!passed && Instant.now().isBefore(deadline)) { + boolean lastPollWasEmpty = false; + while (!lastPollWasEmpty) { + final ConsumerRecords poll = consumer.poll(Duration.ofSeconds(1)); + lastPollWasEmpty = poll.isEmpty(); + for (final ConsumerRecord record : poll) { + final Integer key = record.key(); + switch (record.topic()) { + case ARTICLE_SOURCE: { + final Article article = articleDeserializer.deserialize("", record.value()); + if (consumedArticles.containsKey(key)) { + LOG.warn("Duplicate article: {} and {}", consumedArticles.get(key), article); + } + consumedArticles.put(key, article); + break; + } + case COMMENT_SOURCE: { + final Comment comment = commentDeserializer.deserialize("", record.value()); + if (consumedComments.containsKey(key)) { + LOG.warn("Duplicate comment: {} and {}", consumedComments.get(key), comment); + } + consumedComments.put(key, comment); + break; + } + case ARTICLE_RESULT_SINK: { + final AugmentedArticle article = + augmentedArticleDeserializer.deserialize("", record.value()); + consumedAugmentedArticles.put(key, article); + break; + } + case COMMENT_RESULT_SINK: { + 
final AugmentedComment comment = + augmentedCommentDeserializer.deserialize("", record.value()); + consumedAugmentedComments.put(key, comment); + break; + } + default: + throw new IllegalArgumentException(record.toString()); + } + } + consumer.commitSync(); + } + + if (!printedConsumedArticle && !consumedArticles.isEmpty()) { + LOG.info("Consumed first Article"); + printedConsumedArticle = true; + } + if (!printedConsumedComment && !consumedComments.isEmpty()) { + LOG.info("Consumed first Comment"); + printedConsumedComment = true; + } + if (!printedConsumedAugmentedArticle && !consumedAugmentedArticles.isEmpty()) { + LOG.info("Consumed first AugmentedArticle"); + printedConsumedAugmentedArticle = true; + } + if (!printedConsumedAugmentedComment && !consumedAugmentedComments.isEmpty()) { + LOG.info("Consumed first AugmentedComment"); + printedConsumedAugmentedComment = true; + } + + passed = verifySync( + false, + consumedArticles, + consumedComments, + consumedAugmentedArticles, + consumedAugmentedComments + ); + if (!passed) { + LOG.info("Verification has not passed yet. "); + Thread.sleep(500); + } + } + return verifySync( + true, + consumedArticles, + consumedComments, + consumedAugmentedArticles, + consumedAugmentedComments + ); + } + } + + public static void assertThat(final AtomicBoolean pass, + final StringBuilder failures, + final String message, + final boolean passed) { + if (!passed) { + if (failures != null) { + failures.append("\n").append(message); + } + pass.set(false); + } + } + + static boolean verifySync(final boolean logResults, + final Map consumedArticles, + final Map consumedComments, + final Map consumedAugmentedArticles, + final Map consumedAugmentedComments) { + final AtomicBoolean pass = new AtomicBoolean(true); + final StringBuilder report = logResults ? 
new StringBuilder() : null; + + assertThat( + pass, + report, + "Expected 1 article, got " + consumedArticles.size(), + consumedArticles.size() > 0 + ); + assertThat( + pass, + report, + "Expected 1 comment, got " + consumedComments.size(), + consumedComments.size() > 0 + ); + + assertThat( + pass, + report, + "Mismatched article size between augmented articles (size " + + consumedAugmentedArticles.size() + + ") and consumed articles (size " + + consumedArticles.size() + ")", + consumedAugmentedArticles.size() == consumedArticles.size() + ); + assertThat( + pass, + report, + "Mismatched comments size between augmented comments (size " + + consumedAugmentedComments.size() + + ") and consumed comments (size " + + consumedComments.size() + ")", + consumedAugmentedComments.size() == consumedComments.size() + ); + + final Map commentCounts = new TreeMap<>(); + + for (final RelationalSmokeTest.AugmentedComment augmentedComment : consumedAugmentedComments.values()) { + final int key = augmentedComment.getKey(); + assertThat( + pass, + report, + "comment missing, but found in augmentedComment: " + key, + consumedComments.containsKey(key) + ); + + final Comment comment = consumedComments.get(key); + if (comment != null) { + assertThat( + pass, + report, + "comment missing, but found in augmentedComment: " + key, + consumedComments.containsKey(key) + ); + } + commentCounts.put( + augmentedComment.getArticleId(), + commentCounts.getOrDefault(augmentedComment.getArticleId(), 0L) + 1 + ); + + assertThat( + pass, + report, + "augmentedArticle [" + augmentedComment.getArticleId() + "] " + + "missing for augmentedComment [" + augmentedComment.getKey() + "]", + consumedAugmentedArticles.containsKey(augmentedComment.getArticleId()) + ); + final AugmentedArticle augmentedArticle = + consumedAugmentedArticles.get(augmentedComment.getArticleId()); + if (augmentedArticle != null) { + assertThat( + pass, + report, + "articlePrefix didn't match augmentedArticle: " + augmentedArticle.getText(), + augmentedArticle.getText().startsWith(augmentedComment.getArticlePrefix()) + ); + } + + assertThat( + pass, + report, + "article " + augmentedComment.getArticleId() + " missing from consumedArticles", + consumedArticles.containsKey(augmentedComment.getArticleId()) + ); + final Article article = consumedArticles.get(augmentedComment.getArticleId()); + if (article != null) { + assertThat( + pass, + report, + "articlePrefix didn't match article: " + article.getText(), + article.getText().startsWith(augmentedComment.getArticlePrefix()) + ); + } + } + + + for (final RelationalSmokeTest.AugmentedArticle augmentedArticle : consumedAugmentedArticles.values()) { + assertThat( + pass, + report, + "article " + augmentedArticle.getKey() + " comment count mismatch", + augmentedArticle.getCommentCount() == commentCounts.getOrDefault(augmentedArticle.getKey(), 0L) + ); + } + + if (logResults) { + if (pass.get()) { + LOG.info( + "Evaluation passed ({}/{}) articles and ({}/{}) comments", + consumedAugmentedArticles.size(), + consumedArticles.size(), + consumedAugmentedComments.size(), + consumedComments.size() + ); + } else { + LOG.error( + "Evaluation failed\nReport: {}\n" + + "Consumed Input Articles: {}\n" + + "Consumed Input Comments: {}\n" + + "Consumed Augmented Articles: {}\n" + + "Consumed Augmented Comments: {}", + report, + consumedArticles, + consumedComments, + consumedAugmentedArticles, + consumedAugmentedComments + ); + } + } + + return pass.get(); + } + } + + /* + * Used by the smoke tests. 
+ */ + public static void main(final String[] args) { + System.out.println(Arrays.toString(args)); + final String mode = args[0]; + final String kafka = args[1]; + + try { + switch (mode) { + case "driver": { + // this starts the driver (data generation and result verification) + final int numArticles = 1_000; + final int numComments = 10_000; + final DataSet dataSet = DataSet.generate(numArticles, numComments); + // publish the data for at least one minute + dataSet.produce(kafka, Duration.ofMinutes(1)); + LOG.info("Smoke test finished producing"); + // let it soak in + Thread.sleep(1000); + LOG.info("Smoke test starting verification"); + // wait for at most 10 minutes to get a passing result + final boolean pass = App.verifySync(kafka, Instant.now().plus(Duration.ofMinutes(10))); + if (pass) { + LOG.info("Smoke test complete: passed"); + } else { + LOG.error("Smoke test complete: failed"); + } + break; + } + case "application": { + final String nodeId = args[2]; + final String processingGuarantee = args[3]; + final String stateDir = args[4]; + App.startSync(kafka, UUID.randomUUID().toString(), nodeId, processingGuarantee, stateDir); + break; + } + default: + LOG.error("Unknown command: {}", mode); + throw new RuntimeException("Unknown command: " + mode); + } + } catch (final InterruptedException e) { + LOG.error("Interrupted", e); + } + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/tests/RelationalSmokeTestTest.java b/streams/src/test/java/org/apache/kafka/streams/tests/RelationalSmokeTestTest.java new file mode 100644 index 0000000..a8f1186 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/tests/RelationalSmokeTestTest.java @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.tests; + +import org.apache.kafka.common.serialization.IntegerDeserializer; +import org.apache.kafka.common.serialization.IntegerSerializer; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.TestInputTopic; +import org.apache.kafka.streams.TestOutputTopic; +import org.apache.kafka.streams.TopologyTestDriver; +import org.apache.kafka.test.TestUtils; +import org.junit.Test; + +import java.util.Map; +import java.util.TreeMap; + +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.MatcherAssert.assertThat; + +public class RelationalSmokeTestTest extends SmokeTestUtil { + + @Test + public void verifySmokeTestLogic() { + try (final TopologyTestDriver driver = + new TopologyTestDriver(RelationalSmokeTest.App.getTopology(), + RelationalSmokeTest.App.getConfig( + "nothing:0", + "test", + "test", + StreamsConfig.AT_LEAST_ONCE, + TestUtils.tempDirectory().getAbsolutePath() + ))) { + + final TestInputTopic articles = + driver.createInputTopic(RelationalSmokeTest.ARTICLE_SOURCE, + new IntegerSerializer(), + new RelationalSmokeTest.Article.ArticleSerializer()); + + final TestInputTopic comments = + driver.createInputTopic(RelationalSmokeTest.COMMENT_SOURCE, + new IntegerSerializer(), + new RelationalSmokeTest.Comment.CommentSerializer()); + + final TestOutputTopic augmentedArticles = + driver.createOutputTopic(RelationalSmokeTest.ARTICLE_RESULT_SINK, + new IntegerDeserializer(), + new RelationalSmokeTest.AugmentedArticle.AugmentedArticleDeserializer()); + + final TestOutputTopic augmentedComments = + driver.createOutputTopic(RelationalSmokeTest.COMMENT_RESULT_SINK, + new IntegerDeserializer(), + new RelationalSmokeTest.AugmentedComment.AugmentedCommentDeserializer()); + + final RelationalSmokeTest.DataSet dataSet = + RelationalSmokeTest.DataSet.generate(10, 30); + + final Map articleMap = new TreeMap<>(); + for (final RelationalSmokeTest.Article article : dataSet.getArticles()) { + articles.pipeInput(article.getKey(), article, article.getTimestamp()); + articleMap.put(article.getKey(), article); + } + + final Map commentCounts = new TreeMap<>(); + + final Map commentMap = new TreeMap<>(); + for (final RelationalSmokeTest.Comment comment : dataSet.getComments()) { + comments.pipeInput(comment.getKey(), comment, comment.getTimestamp()); + commentMap.put(comment.getKey(), comment); + commentCounts.put(comment.getArticleId(), + commentCounts.getOrDefault(comment.getArticleId(), 0L) + 1); + } + + final Map augmentedArticleResults = + augmentedArticles.readKeyValuesToMap(); + + final Map augmentedCommentResults = + augmentedComments.readKeyValuesToMap(); + + assertThat(augmentedArticleResults.size(), is(dataSet.getArticles().length)); + assertThat(augmentedCommentResults.size(), is(dataSet.getComments().length)); + + assertThat( + RelationalSmokeTest.App.verifySync(true, + articleMap, + commentMap, + augmentedArticleResults, + augmentedCommentResults), + is(true)); + } + } +} \ No newline at end of file diff --git a/streams/src/test/java/org/apache/kafka/streams/tests/ShutdownDeadlockTest.java b/streams/src/test/java/org/apache/kafka/streams/tests/ShutdownDeadlockTest.java new file mode 100644 index 0000000..9afd80a --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/tests/ShutdownDeadlockTest.java @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.tests; + +import java.time.Duration; +import org.apache.kafka.clients.producer.KafkaProducer; +import org.apache.kafka.clients.producer.ProducerConfig; +import org.apache.kafka.clients.producer.ProducerRecord; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.serialization.StringSerializer; +import org.apache.kafka.common.utils.Exit; +import org.apache.kafka.streams.errors.StreamsUncaughtExceptionHandler; +import org.apache.kafka.streams.kstream.Consumed; +import org.apache.kafka.streams.KafkaStreams; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.kstream.ForeachAction; +import org.apache.kafka.streams.kstream.KStream; + +import java.util.Properties; + +public class ShutdownDeadlockTest { + + private final String kafka; + + public ShutdownDeadlockTest(final String kafka) { + this.kafka = kafka; + } + + public void start() { + final String topic = "source"; + final Properties props = new Properties(); + props.setProperty(StreamsConfig.APPLICATION_ID_CONFIG, "shouldNotDeadlock"); + props.setProperty(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, kafka); + final StreamsBuilder builder = new StreamsBuilder(); + final KStream source = builder.stream(topic, Consumed.with(Serdes.String(), Serdes.String())); + + source.foreach(new ForeachAction() { + @Override + public void apply(final String key, final String value) { + throw new RuntimeException("KABOOM!"); + } + }); + final KafkaStreams streams = new KafkaStreams(builder.build(), props); + streams.setUncaughtExceptionHandler(e -> { + Exit.exit(1); + return StreamsUncaughtExceptionHandler.StreamThreadExceptionResponse.SHUTDOWN_CLIENT; + }); + + Exit.addShutdownHook("streams-shutdown-hook", () -> streams.close(Duration.ofSeconds(5))); + + final Properties producerProps = new Properties(); + producerProps.put(ProducerConfig.CLIENT_ID_CONFIG, "SmokeTest"); + producerProps.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, kafka); + producerProps.put(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG, StringSerializer.class); + producerProps.put(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG, StringSerializer.class); + + final KafkaProducer producer = new KafkaProducer<>(producerProps); + producer.send(new ProducerRecord<>(topic, "a", "a")); + producer.flush(); + + streams.start(); + + synchronized (this) { + try { + wait(); + } catch (final InterruptedException e) { + // ignored + } + } + + + } + + + +} diff --git a/streams/src/test/java/org/apache/kafka/streams/tests/SmokeTestClient.java b/streams/src/test/java/org/apache/kafka/streams/tests/SmokeTestClient.java new file mode 100644 index 0000000..86f7583 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/tests/SmokeTestClient.java @@ -0,0 +1,297 @@ +/* + * Licensed 
to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.tests; + +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.common.utils.KafkaThread; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.streams.KafkaStreams; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.Topology; +import org.apache.kafka.streams.errors.StreamsUncaughtExceptionHandler; +import org.apache.kafka.streams.kstream.Consumed; +import org.apache.kafka.streams.kstream.Grouped; +import org.apache.kafka.streams.kstream.KGroupedStream; +import org.apache.kafka.streams.kstream.KStream; +import org.apache.kafka.streams.kstream.KTable; +import org.apache.kafka.streams.kstream.Materialized; +import org.apache.kafka.streams.kstream.Produced; +import org.apache.kafka.streams.kstream.Suppressed.BufferConfig; +import org.apache.kafka.streams.kstream.TimeWindows; +import org.apache.kafka.streams.kstream.Windowed; +import org.apache.kafka.streams.state.Stores; +import org.apache.kafka.streams.state.WindowStore; + +import java.io.File; +import java.io.IOException; +import java.nio.file.Files; +import java.time.Duration; +import java.time.Instant; +import java.util.Properties; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; + +import static org.apache.kafka.streams.kstream.Suppressed.untilWindowCloses; + +@SuppressWarnings("deprecation") +public class SmokeTestClient extends SmokeTestUtil { + + private final String name; + + private KafkaStreams streams; + private boolean uncaughtException = false; + private volatile boolean closed; + + private static void addShutdownHook(final String name, final Runnable runnable) { + if (name != null) { + Runtime.getRuntime().addShutdownHook(KafkaThread.nonDaemon(name, runnable)); + } else { + Runtime.getRuntime().addShutdownHook(new Thread(runnable)); + } + } + + private static File tempDirectory() { + final String prefix = "kafka-"; + final File file; + try { + file = Files.createTempDirectory(prefix).toFile(); + } catch (final IOException ex) { + throw new RuntimeException("Failed to create a temp dir", ex); + } + file.deleteOnExit(); + + addShutdownHook("delete-temp-file-shutdown-hook", () -> { + try { + Utils.delete(file); + } catch (final IOException e) { + System.out.println("Error deleting " + file.getAbsolutePath()); + e.printStackTrace(System.out); + } + }); + + return file; + } + + public SmokeTestClient(final String name) { + this.name = name; + } + + public boolean closed() { + return closed; + } + + public void start(final Properties streamsProperties) { + final Topology build = 
getTopology(); + streams = new KafkaStreams(build, getStreamsConfig(streamsProperties)); + + final CountDownLatch countDownLatch = new CountDownLatch(1); + streams.setStateListener((newState, oldState) -> { + System.out.printf("%s %s: %s -> %s%n", name, Instant.now(), oldState, newState); + if (oldState == KafkaStreams.State.REBALANCING && newState == KafkaStreams.State.RUNNING) { + countDownLatch.countDown(); + } + + if (newState == KafkaStreams.State.NOT_RUNNING) { + closed = true; + } + }); + + streams.setUncaughtExceptionHandler(e -> { + System.out.println(name + ": SMOKE-TEST-CLIENT-EXCEPTION"); + System.out.println(name + ": FATAL: An unexpected exception is encountered on thread " + Thread.currentThread() + ": " + e); + e.printStackTrace(System.out); + uncaughtException = true; + return StreamsUncaughtExceptionHandler.StreamThreadExceptionResponse.SHUTDOWN_CLIENT; + }); + + addShutdownHook("streams-shutdown-hook", this::close); + + streams.start(); + try { + if (!countDownLatch.await(1, TimeUnit.MINUTES)) { + System.out.println(name + ": SMOKE-TEST-CLIENT-EXCEPTION: Didn't start in one minute"); + } else { + System.out.println(name + ": SMOKE-TEST-CLIENT-STARTED"); + System.out.println(name + " started at " + Instant.now()); + } + } catch (final InterruptedException e) { + System.out.println(name + ": SMOKE-TEST-CLIENT-EXCEPTION: " + e); + e.printStackTrace(System.out); + } + } + + public void closeAsync() { + streams.close(Duration.ZERO); + } + + public void close() { + final boolean wasClosed = streams.close(Duration.ofMinutes(1)); + + if (wasClosed && !uncaughtException) { + System.out.println(name + ": SMOKE-TEST-CLIENT-CLOSED"); + } else if (wasClosed) { + System.out.println(name + ": SMOKE-TEST-CLIENT-EXCEPTION: Got an uncaught exception"); + } else { + System.out.println(name + ": SMOKE-TEST-CLIENT-EXCEPTION: Didn't close in time."); + } + } + + private Properties getStreamsConfig(final Properties props) { + final Properties fullProps = new Properties(props); + fullProps.put(StreamsConfig.APPLICATION_ID_CONFIG, "SmokeTest"); + fullProps.put(StreamsConfig.CLIENT_ID_CONFIG, "SmokeTest-" + name); + fullProps.put(StreamsConfig.STATE_DIR_CONFIG, tempDirectory().getAbsolutePath()); + fullProps.put(StreamsConfig.PROCESSING_GUARANTEE_CONFIG, StreamsConfig.EXACTLY_ONCE_V2); + fullProps.putAll(props); + return fullProps; + } + + @SuppressWarnings("deprecation") // Old PAPI. Needs to be migrated. + public Topology getTopology() { + final StreamsBuilder builder = new StreamsBuilder(); + final Consumed stringIntConsumed = Consumed.with(stringSerde, intSerde); + final KStream source = builder.stream("data", stringIntConsumed); + source.filterNot((k, v) -> k.equals("flush")) + .to("echo", Produced.with(stringSerde, intSerde)); + final KStream data = source.filter((key, value) -> value == null || value != END); + data.process(SmokeTestUtil.printProcessorSupplier("data", name)); + + // min + final KGroupedStream groupedData = data.groupByKey(Grouped.with(stringSerde, intSerde)); + + final KTable, Integer> minAggregation = groupedData + .windowedBy(TimeWindows.of(Duration.ofDays(1)).grace(Duration.ofMinutes(1))) + .aggregate( + () -> Integer.MAX_VALUE, + (aggKey, value, aggregate) -> (value < aggregate) ? 
value : aggregate, + Materialized + .>as("uwin-min") + .withValueSerde(intSerde) + .withRetention(Duration.ofHours(25)) + ); + + streamify(minAggregation, "min-raw"); + + streamify(minAggregation.suppress(untilWindowCloses(BufferConfig.unbounded())), "min-suppressed"); + + minAggregation + .toStream(new Unwindow<>()) + .filterNot((k, v) -> k.equals("flush")) + .to("min", Produced.with(stringSerde, intSerde)); + + final KTable, Integer> smallWindowSum = groupedData + .windowedBy(TimeWindows.of(Duration.ofSeconds(2)).advanceBy(Duration.ofSeconds(1)).grace(Duration.ofSeconds(30))) + .reduce((l, r) -> l + r); + + streamify(smallWindowSum, "sws-raw"); + streamify(smallWindowSum.suppress(untilWindowCloses(BufferConfig.unbounded())), "sws-suppressed"); + + final KTable minTable = builder.table( + "min", + Consumed.with(stringSerde, intSerde), + Materialized.as("minStoreName")); + + minTable.toStream().process(SmokeTestUtil.printProcessorSupplier("min", name)); + + // max + groupedData + .windowedBy(TimeWindows.of(Duration.ofDays(2))) + .aggregate( + () -> Integer.MIN_VALUE, + (aggKey, value, aggregate) -> (value > aggregate) ? value : aggregate, + Materialized.>as("uwin-max").withValueSerde(intSerde)) + .toStream(new Unwindow<>()) + .filterNot((k, v) -> k.equals("flush")) + .to("max", Produced.with(stringSerde, intSerde)); + + final KTable maxTable = builder.table( + "max", + Consumed.with(stringSerde, intSerde), + Materialized.as("maxStoreName")); + maxTable.toStream().process(SmokeTestUtil.printProcessorSupplier("max", name)); + + // sum + groupedData + .windowedBy(TimeWindows.of(Duration.ofDays(2))) + .aggregate( + () -> 0L, + (aggKey, value, aggregate) -> (long) value + aggregate, + Materialized.>as("win-sum").withValueSerde(longSerde)) + .toStream(new Unwindow<>()) + .filterNot((k, v) -> k.equals("flush")) + .to("sum", Produced.with(stringSerde, longSerde)); + + final Consumed stringLongConsumed = Consumed.with(stringSerde, longSerde); + final KTable sumTable = builder.table("sum", stringLongConsumed); + sumTable.toStream().process(SmokeTestUtil.printProcessorSupplier("sum", name)); + + // cnt + groupedData + .windowedBy(TimeWindows.of(Duration.ofDays(2))) + .count(Materialized.as("uwin-cnt")) + .toStream(new Unwindow<>()) + .filterNot((k, v) -> k.equals("flush")) + .to("cnt", Produced.with(stringSerde, longSerde)); + + final KTable cntTable = builder.table( + "cnt", + Consumed.with(stringSerde, longSerde), + Materialized.as("cntStoreName")); + cntTable.toStream().process(SmokeTestUtil.printProcessorSupplier("cnt", name)); + + // dif + maxTable + .join( + minTable, + (value1, value2) -> value1 - value2) + .toStream() + .filterNot((k, v) -> k.equals("flush")) + .to("dif", Produced.with(stringSerde, intSerde)); + + // avg + sumTable + .join( + cntTable, + (value1, value2) -> (double) value1 / (double) value2) + .toStream() + .filterNot((k, v) -> k.equals("flush")) + .to("avg", Produced.with(stringSerde, doubleSerde)); + + // test repartition + final Agg agg = new Agg(); + cntTable.groupBy(agg.selector(), Grouped.with(stringSerde, longSerde)) + .aggregate(agg.init(), agg.adder(), agg.remover(), + Materialized.as(Stores.inMemoryKeyValueStore("cntByCnt")) + .withKeySerde(Serdes.String()) + .withValueSerde(Serdes.Long())) + .toStream() + .to("tagg", Produced.with(stringSerde, longSerde)); + + return builder.build(); + } + + private static void streamify(final KTable, Integer> windowedTable, final String topic) { + windowedTable + .toStream() + .filterNot((k, v) -> k.key().equals("flush")) + 
.map((key, value) -> new KeyValue<>(key.toString(), value)) + .to(topic, Produced.with(stringSerde, intSerde)); + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/tests/SmokeTestDriver.java b/streams/src/test/java/org/apache/kafka/streams/tests/SmokeTestDriver.java new file mode 100644 index 0000000..ac83cd9 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/tests/SmokeTestDriver.java @@ -0,0 +1,622 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.tests; + +import org.apache.kafka.clients.consumer.ConsumerConfig; +import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.clients.consumer.ConsumerRecords; +import org.apache.kafka.clients.consumer.KafkaConsumer; +import org.apache.kafka.clients.producer.Callback; +import org.apache.kafka.clients.producer.KafkaProducer; +import org.apache.kafka.clients.producer.ProducerConfig; +import org.apache.kafka.clients.producer.ProducerRecord; +import org.apache.kafka.clients.producer.RecordMetadata; +import org.apache.kafka.common.PartitionInfo; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.errors.TimeoutException; +import org.apache.kafka.common.serialization.ByteArraySerializer; +import org.apache.kafka.common.serialization.Deserializer; +import org.apache.kafka.common.serialization.StringDeserializer; +import org.apache.kafka.common.utils.Exit; +import org.apache.kafka.common.utils.Utils; + +import java.io.ByteArrayOutputStream; +import java.io.PrintStream; +import java.nio.charset.StandardCharsets; +import java.time.Duration; +import java.time.Instant; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.Properties; +import java.util.Random; +import java.util.Set; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Function; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import static java.util.Collections.emptyMap; +import static org.apache.kafka.common.utils.Utils.mkEntry; + +public class SmokeTestDriver extends SmokeTestUtil { + private static final String[] TOPICS = { + "data", + "echo", + "max", + "min", "min-suppressed", "min-raw", + "dif", + "sum", + "sws-raw", "sws-suppressed", + "cnt", + "avg", + "tagg" + }; + + private static final int MAX_RECORD_EMPTY_RETRIES = 30; + + private static class ValueList { + public final String key; + private final int[] values; + private int index; + + ValueList(final int min, final int max) { + key = min + "-" + max; + + values = new int[max - min + 1]; + for (int i 
= 0; i < values.length; i++) { + values[i] = min + i; + } + // We want to randomize the order of data to test not completely predictable processing order + // However, values are also use as a timestamp of the record. (TODO: separate data and timestamp) + // We keep some correlation of time and order. Thus, the shuffling is done with a sliding window + shuffle(values, 10); + + index = 0; + } + + int next() { + return (index < values.length) ? values[index++] : -1; + } + } + + public static String[] topics() { + return Arrays.copyOf(TOPICS, TOPICS.length); + } + + static void generatePerpetually(final String kafka, + final int numKeys, + final int maxRecordsPerKey) { + final Properties producerProps = generatorProperties(kafka); + + int numRecordsProduced = 0; + + final ValueList[] data = new ValueList[numKeys]; + for (int i = 0; i < numKeys; i++) { + data[i] = new ValueList(i, i + maxRecordsPerKey - 1); + } + + final Random rand = new Random(); + + try (final KafkaProducer producer = new KafkaProducer<>(producerProps)) { + while (true) { + final int index = rand.nextInt(numKeys); + final String key = data[index].key; + final int value = data[index].next(); + + final ProducerRecord record = + new ProducerRecord<>( + "data", + stringSerde.serializer().serialize("", key), + intSerde.serializer().serialize("", value) + ); + + producer.send(record); + + numRecordsProduced++; + if (numRecordsProduced % 100 == 0) { + System.out.println(Instant.now() + " " + numRecordsProduced + " records produced"); + } + Utils.sleep(2); + } + } + } + + public static Map> generate(final String kafka, + final int numKeys, + final int maxRecordsPerKey, + final Duration timeToSpend) { + final Properties producerProps = generatorProperties(kafka); + + + int numRecordsProduced = 0; + + final Map> allData = new HashMap<>(); + final ValueList[] data = new ValueList[numKeys]; + for (int i = 0; i < numKeys; i++) { + data[i] = new ValueList(i, i + maxRecordsPerKey - 1); + allData.put(data[i].key, new HashSet<>()); + } + final Random rand = new Random(); + + int remaining = data.length; + + final long recordPauseTime = timeToSpend.toMillis() / numKeys / maxRecordsPerKey; + + List> needRetry = new ArrayList<>(); + + try (final KafkaProducer producer = new KafkaProducer<>(producerProps)) { + while (remaining > 0) { + final int index = rand.nextInt(remaining); + final String key = data[index].key; + final int value = data[index].next(); + + if (value < 0) { + remaining--; + data[index] = data[remaining]; + } else { + + final ProducerRecord record = + new ProducerRecord<>( + "data", + stringSerde.serializer().serialize("", key), + intSerde.serializer().serialize("", value) + ); + + producer.send(record, new TestCallback(record, needRetry)); + + numRecordsProduced++; + allData.get(key).add(value); + if (numRecordsProduced % 100 == 0) { + System.out.println(Instant.now() + " " + numRecordsProduced + " records produced"); + } + Utils.sleep(Math.max(recordPauseTime, 2)); + } + } + producer.flush(); + + int remainingRetries = 5; + while (!needRetry.isEmpty()) { + final List> needRetry2 = new ArrayList<>(); + for (final ProducerRecord record : needRetry) { + System.out.println("retry producing " + stringSerde.deserializer().deserialize("", record.key())); + producer.send(record, new TestCallback(record, needRetry2)); + } + producer.flush(); + needRetry = needRetry2; + + if (--remainingRetries == 0 && !needRetry.isEmpty()) { + System.err.println("Failed to produce all records after multiple retries"); + Exit.exit(1); + } + } + + // 
now that we've sent everything, we'll send some final records with a timestamp high enough to flush out + // all suppressed records. + final List partitions = producer.partitionsFor("data"); + for (final PartitionInfo partition : partitions) { + producer.send(new ProducerRecord<>( + partition.topic(), + partition.partition(), + System.currentTimeMillis() + Duration.ofDays(2).toMillis(), + stringSerde.serializer().serialize("", "flush"), + intSerde.serializer().serialize("", 0) + )); + } + } + return Collections.unmodifiableMap(allData); + } + + private static Properties generatorProperties(final String kafka) { + final Properties producerProps = new Properties(); + producerProps.put(ProducerConfig.CLIENT_ID_CONFIG, "SmokeTest"); + producerProps.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, kafka); + producerProps.put(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG, ByteArraySerializer.class); + producerProps.put(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG, ByteArraySerializer.class); + producerProps.put(ProducerConfig.ACKS_CONFIG, "all"); + return producerProps; + } + + private static class TestCallback implements Callback { + private final ProducerRecord originalRecord; + private final List> needRetry; + + TestCallback(final ProducerRecord originalRecord, + final List> needRetry) { + this.originalRecord = originalRecord; + this.needRetry = needRetry; + } + + @Override + public void onCompletion(final RecordMetadata metadata, final Exception exception) { + if (exception != null) { + if (exception instanceof TimeoutException) { + needRetry.add(originalRecord); + } else { + exception.printStackTrace(); + Exit.exit(1); + } + } + } + } + + private static void shuffle(final int[] data, @SuppressWarnings("SameParameterValue") final int windowSize) { + final Random rand = new Random(); + for (int i = 0; i < data.length; i++) { + // we shuffle data within windowSize + final int j = rand.nextInt(Math.min(data.length - i, windowSize)) + i; + + // swap + final int tmp = data[i]; + data[i] = data[j]; + data[j] = tmp; + } + } + + public static class NumberDeserializer implements Deserializer { + @Override + public Number deserialize(final String topic, final byte[] data) { + final Number value; + switch (topic) { + case "data": + case "echo": + case "min": + case "min-raw": + case "min-suppressed": + case "sws-raw": + case "sws-suppressed": + case "max": + case "dif": + value = intSerde.deserializer().deserialize(topic, data); + break; + case "sum": + case "cnt": + case "tagg": + value = longSerde.deserializer().deserialize(topic, data); + break; + case "avg": + value = doubleSerde.deserializer().deserialize(topic, data); + break; + default: + throw new RuntimeException("unknown topic: " + topic); + } + return value; + } + } + + public static VerificationResult verify(final String kafka, + final Map> inputs, + final int maxRecordsPerKey) { + final Properties props = new Properties(); + props.put(ConsumerConfig.CLIENT_ID_CONFIG, "verifier"); + props.put(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, kafka); + props.put(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, StringDeserializer.class); + props.put(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, NumberDeserializer.class); + props.put(ConsumerConfig.ISOLATION_LEVEL_CONFIG, "read_committed"); + + final KafkaConsumer consumer = new KafkaConsumer<>(props); + final List partitions = getAllPartitions(consumer, TOPICS); + consumer.assign(partitions); + consumer.seekToBeginning(partitions); + + final int recordsGenerated = inputs.size() * maxRecordsPerKey; + int 
recordsProcessed = 0; + final Map processed = + Stream.of(TOPICS) + .collect(Collectors.toMap(t -> t, t -> new AtomicInteger(0))); + + final Map>>> events = new HashMap<>(); + + VerificationResult verificationResult = new VerificationResult(false, "no results yet"); + int retry = 0; + final long start = System.currentTimeMillis(); + while (System.currentTimeMillis() - start < TimeUnit.MINUTES.toMillis(6)) { + final ConsumerRecords records = consumer.poll(Duration.ofSeconds(5)); + if (records.isEmpty() && recordsProcessed >= recordsGenerated) { + verificationResult = verifyAll(inputs, events, false); + if (verificationResult.passed()) { + break; + } else if (retry++ > MAX_RECORD_EMPTY_RETRIES) { + System.out.println(Instant.now() + " Didn't get any more results, verification hasn't passed, and out of retries."); + break; + } else { + System.out.println(Instant.now() + " Didn't get any more results, but verification hasn't passed (yet). Retrying..." + retry); + } + } else { + System.out.println(Instant.now() + " Get some more results from " + records.partitions() + ", resetting retry."); + + retry = 0; + for (final ConsumerRecord record : records) { + final String key = record.key(); + + final String topic = record.topic(); + processed.get(topic).incrementAndGet(); + + if (topic.equals("echo")) { + recordsProcessed++; + if (recordsProcessed % 100 == 0) { + System.out.println("Echo records processed = " + recordsProcessed); + } + } + + events.computeIfAbsent(topic, t -> new HashMap<>()) + .computeIfAbsent(key, k -> new LinkedList<>()) + .add(record); + } + + System.out.println(processed); + } + } + consumer.close(); + final long finished = System.currentTimeMillis() - start; + System.out.println("Verification time=" + finished); + System.out.println("-------------------"); + System.out.println("Result Verification"); + System.out.println("-------------------"); + System.out.println("recordGenerated=" + recordsGenerated); + System.out.println("recordProcessed=" + recordsProcessed); + + if (recordsProcessed > recordsGenerated) { + System.out.println("PROCESSED-MORE-THAN-GENERATED"); + } else if (recordsProcessed < recordsGenerated) { + System.out.println("PROCESSED-LESS-THAN-GENERATED"); + } + + boolean success; + + final Map> received = + events.get("echo") + .entrySet() + .stream() + .map(entry -> mkEntry( + entry.getKey(), + entry.getValue().stream().map(ConsumerRecord::value).collect(Collectors.toSet())) + ) + .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); + + success = inputs.equals(received); + + if (success) { + System.out.println("ALL-RECORDS-DELIVERED"); + } else { + int missedCount = 0; + for (final Map.Entry> entry : inputs.entrySet()) { + missedCount += received.get(entry.getKey()).size(); + } + System.out.println("missedRecords=" + missedCount); + } + + // give it one more try if it's not already passing. + if (!verificationResult.passed()) { + verificationResult = verifyAll(inputs, events, true); + } + success &= verificationResult.passed(); + + System.out.println(verificationResult.result()); + + System.out.println(success ? 
"SUCCESS" : "FAILURE"); + return verificationResult; + } + + public static class VerificationResult { + private final boolean passed; + private final String result; + + VerificationResult(final boolean passed, final String result) { + this.passed = passed; + this.result = result; + } + + public boolean passed() { + return passed; + } + + public String result() { + return result; + } + } + + private static VerificationResult verifyAll(final Map> inputs, + final Map>>> events, + final boolean printResults) { + final ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream(); + boolean pass; + try (final PrintStream resultStream = new PrintStream(byteArrayOutputStream)) { + pass = verifyTAgg(resultStream, inputs, events.get("tagg"), printResults); + pass &= verifySuppressed(resultStream, "min-suppressed", events, printResults); + pass &= verify(resultStream, "min-suppressed", inputs, events, windowedKey -> { + final String unwindowedKey = windowedKey.substring(1, windowedKey.length() - 1).replaceAll("@.*", ""); + return getMin(unwindowedKey); + }, printResults); + pass &= verifySuppressed(resultStream, "sws-suppressed", events, printResults); + pass &= verify(resultStream, "min", inputs, events, SmokeTestDriver::getMin, printResults); + pass &= verify(resultStream, "max", inputs, events, SmokeTestDriver::getMax, printResults); + pass &= verify(resultStream, "dif", inputs, events, key -> getMax(key).intValue() - getMin(key).intValue(), printResults); + pass &= verify(resultStream, "sum", inputs, events, SmokeTestDriver::getSum, printResults); + pass &= verify(resultStream, "cnt", inputs, events, key1 -> getMax(key1).intValue() - getMin(key1).intValue() + 1L, printResults); + pass &= verify(resultStream, "avg", inputs, events, SmokeTestDriver::getAvg, printResults); + } + return new VerificationResult(pass, new String(byteArrayOutputStream.toByteArray(), StandardCharsets.UTF_8)); + } + + private static boolean verify(final PrintStream resultStream, + final String topic, + final Map> inputData, + final Map>>> events, + final Function keyToExpectation, + final boolean printResults) { + final Map>> observedInputEvents = events.get("data"); + final Map>> outputEvents = events.getOrDefault(topic, emptyMap()); + if (outputEvents.isEmpty()) { + resultStream.println(topic + " is empty"); + return false; + } else { + resultStream.printf("verifying %s with %d keys%n", topic, outputEvents.size()); + + if (outputEvents.size() != inputData.size()) { + resultStream.printf("fail: resultCount=%d expectedCount=%s%n\tresult=%s%n\texpected=%s%n", + outputEvents.size(), inputData.size(), outputEvents.keySet(), inputData.keySet()); + return false; + } + for (final Map.Entry>> entry : outputEvents.entrySet()) { + final String key = entry.getKey(); + final Number expected = keyToExpectation.apply(key); + final Number actual = entry.getValue().getLast().value(); + if (!expected.equals(actual)) { + resultStream.printf("%s fail: key=%s actual=%s expected=%s%n", topic, key, actual, expected); + + if (printResults) { + resultStream.printf("\t inputEvents=%n%s%n\t" + + "echoEvents=%n%s%n\tmaxEvents=%n%s%n\tminEvents=%n%s%n\tdifEvents=%n%s%n\tcntEvents=%n%s%n\ttaggEvents=%n%s%n", + indent("\t\t", observedInputEvents.get(key)), + indent("\t\t", events.getOrDefault("echo", emptyMap()).getOrDefault(key, new LinkedList<>())), + indent("\t\t", events.getOrDefault("max", emptyMap()).getOrDefault(key, new LinkedList<>())), + indent("\t\t", events.getOrDefault("min", emptyMap()).getOrDefault(key, new 
LinkedList<>())), + indent("\t\t", events.getOrDefault("dif", emptyMap()).getOrDefault(key, new LinkedList<>())), + indent("\t\t", events.getOrDefault("cnt", emptyMap()).getOrDefault(key, new LinkedList<>())), + indent("\t\t", events.getOrDefault("tagg", emptyMap()).getOrDefault(key, new LinkedList<>()))); + + if (!Utils.mkSet("echo", "max", "min", "dif", "cnt", "tagg").contains(topic)) + resultStream.printf("%sEvents=%n%s%n", topic, indent("\t\t", entry.getValue())); + } + + return false; + } + } + return true; + } + } + + + private static boolean verifySuppressed(final PrintStream resultStream, + @SuppressWarnings("SameParameterValue") final String topic, + final Map>>> events, + final boolean printResults) { + resultStream.println("verifying suppressed " + topic); + final Map>> topicEvents = events.getOrDefault(topic, emptyMap()); + for (final Map.Entry>> entry : topicEvents.entrySet()) { + if (entry.getValue().size() != 1) { + final String unsuppressedTopic = topic.replace("-suppressed", "-raw"); + final String key = entry.getKey(); + final String unwindowedKey = key.substring(1, key.length() - 1).replaceAll("@.*", ""); + resultStream.printf("fail: key=%s%n\tnon-unique result:%n%s%n", + key, + indent("\t\t", entry.getValue())); + + if (printResults) + resultStream.printf("\tresultEvents:%n%s%n\tinputEvents:%n%s%n", + indent("\t\t", events.get(unsuppressedTopic).get(key)), + indent("\t\t", events.get("data").get(unwindowedKey))); + + return false; + } + } + return true; + } + + private static String indent(@SuppressWarnings("SameParameterValue") final String prefix, + final Iterable> list) { + final StringBuilder stringBuilder = new StringBuilder(); + for (final ConsumerRecord record : list) { + stringBuilder.append(prefix).append(record).append('\n'); + } + return stringBuilder.toString(); + } + + private static Long getSum(final String key) { + final int min = getMin(key).intValue(); + final int max = getMax(key).intValue(); + return ((long) min + max) * (max - min + 1L) / 2L; + } + + private static Double getAvg(final String key) { + final int min = getMin(key).intValue(); + final int max = getMax(key).intValue(); + return ((long) min + max) / 2.0; + } + + + private static boolean verifyTAgg(final PrintStream resultStream, + final Map> allData, + final Map>> taggEvents, + final boolean printResults) { + if (taggEvents == null) { + resultStream.println("tagg is missing"); + return false; + } else if (taggEvents.isEmpty()) { + resultStream.println("tagg is empty"); + return false; + } else { + resultStream.println("verifying tagg"); + + // generate expected answer + final Map expected = new HashMap<>(); + for (final String key : allData.keySet()) { + final int min = getMin(key).intValue(); + final int max = getMax(key).intValue(); + final String cnt = Long.toString(max - min + 1L); + + expected.put(cnt, expected.getOrDefault(cnt, 0L) + 1); + } + + // check the result + for (final Map.Entry>> entry : taggEvents.entrySet()) { + final String key = entry.getKey(); + Long expectedCount = expected.remove(key); + if (expectedCount == null) { + expectedCount = 0L; + } + + if (entry.getValue().getLast().value().longValue() != expectedCount) { + resultStream.println("fail: key=" + key + " tagg=" + entry.getValue() + " expected=" + expectedCount); + + if (printResults) + resultStream.println("\t taggEvents: " + entry.getValue()); + return false; + } + } + + } + return true; + } + + private static Number getMin(final String key) { + return Integer.parseInt(key.split("-")[0]); + } + + private 
static Number getMax(final String key) { + return Integer.parseInt(key.split("-")[1]); + } + + private static List getAllPartitions(final KafkaConsumer consumer, final String... topics) { + final List partitions = new ArrayList<>(); + + for (final String topic : topics) { + for (final PartitionInfo info : consumer.partitionsFor(topic)) { + partitions.add(new TopicPartition(info.topic(), info.partition())); + } + } + return partitions; + } + +} diff --git a/streams/src/test/java/org/apache/kafka/streams/tests/SmokeTestUtil.java b/streams/src/test/java/org/apache/kafka/streams/tests/SmokeTestUtil.java new file mode 100644 index 0000000..1222a81 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/tests/SmokeTestUtil.java @@ -0,0 +1,132 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.tests; + +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.kstream.Aggregator; +import org.apache.kafka.streams.kstream.Initializer; +import org.apache.kafka.streams.kstream.KeyValueMapper; +import org.apache.kafka.streams.kstream.Windowed; +import org.apache.kafka.streams.processor.ProcessorContext; + +import java.time.Instant; + +@SuppressWarnings("deprecation") // Old PAPI. Needs to be migrated. 
+public class SmokeTestUtil { + + final static int END = Integer.MAX_VALUE; + + static org.apache.kafka.streams.processor.ProcessorSupplier printProcessorSupplier(final String topic) { + return printProcessorSupplier(topic, ""); + } + + static org.apache.kafka.streams.processor.ProcessorSupplier printProcessorSupplier(final String topic, final String name) { + return new org.apache.kafka.streams.processor.ProcessorSupplier() { + @Override + public org.apache.kafka.streams.processor.Processor get() { + return new org.apache.kafka.streams.processor.AbstractProcessor() { + private int numRecordsProcessed = 0; + private long smallestOffset = Long.MAX_VALUE; + private long largestOffset = Long.MIN_VALUE; + + @Override + public void init(final ProcessorContext context) { + super.init(context); + System.out.println("[DEV] initializing processor: topic=" + topic + " taskId=" + context.taskId()); + System.out.flush(); + numRecordsProcessed = 0; + smallestOffset = Long.MAX_VALUE; + largestOffset = Long.MIN_VALUE; + } + + @Override + public void process(final Object key, final Object value) { + numRecordsProcessed++; + if (numRecordsProcessed % 100 == 0) { + System.out.printf("%s: %s%n", name, Instant.now()); + System.out.println("processed " + numRecordsProcessed + " records from topic=" + topic); + } + + if (smallestOffset > context().offset()) { + smallestOffset = context().offset(); + } + if (largestOffset < context().offset()) { + largestOffset = context().offset(); + } + } + + @Override + public void close() { + System.out.printf("Close processor for task %s%n", context().taskId()); + System.out.println("processed " + numRecordsProcessed + " records"); + final long processed; + if (largestOffset >= smallestOffset) { + processed = 1L + largestOffset - smallestOffset; + } else { + processed = 0L; + } + System.out.println("offset " + smallestOffset + " to " + largestOffset + " -> processed " + processed); + System.out.flush(); + } + }; + } + }; + } + + public static final class Unwindow implements KeyValueMapper, V, K> { + @Override + public K apply(final Windowed winKey, final V value) { + return winKey.key(); + } + } + + public static class Agg { + + KeyValueMapper> selector() { + return (key, value) -> new KeyValue<>(value == null ? null : Long.toString(value), 1L); + } + + public Initializer init() { + return () -> 0L; + } + + Aggregator adder() { + return (aggKey, value, aggregate) -> aggregate + value; + } + + Aggregator remover() { + return (aggKey, value, aggregate) -> aggregate - value; + } + } + + public static Serde stringSerde = Serdes.String(); + + public static Serde intSerde = Serdes.Integer(); + + static Serde longSerde = Serdes.Long(); + + static Serde doubleSerde = Serdes.Double(); + + public static void sleep(final long duration) { + try { + Thread.sleep(duration); + } catch (final Exception ignore) { } + } + +} diff --git a/streams/src/test/java/org/apache/kafka/streams/tests/StaticMemberTestClient.java b/streams/src/test/java/org/apache/kafka/streams/tests/StaticMemberTestClient.java new file mode 100644 index 0000000..e4b96fe --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/tests/StaticMemberTestClient.java @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.tests; + +import org.apache.kafka.clients.consumer.ConsumerConfig; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.utils.Exit; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.streams.KafkaStreams; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.kstream.KStream; + +import java.util.Objects; +import java.util.Properties; + +public class StaticMemberTestClient { + + private static String testName = "StaticMemberTestClient"; + + @SuppressWarnings("unchecked") + public static void main(final String[] args) throws Exception { + if (args.length < 1) { + System.err.println(testName + " requires one argument (properties-file) but none provided: "); + } + + System.out.println("StreamsTest instance started"); + + final String propFileName = args[0]; + + final Properties streamsProperties = Utils.loadProps(propFileName); + + final String groupInstanceId = Objects.requireNonNull(streamsProperties.getProperty(ConsumerConfig.GROUP_INSTANCE_ID_CONFIG)); + + System.out.println(testName + " instance started with group.instance.id " + groupInstanceId); + System.out.println("props=" + streamsProperties); + System.out.flush(); + + final StreamsBuilder builder = new StreamsBuilder(); + final String inputTopic = (String) (Objects.requireNonNull(streamsProperties.remove("input.topic"))); + + final KStream dataStream = builder.stream(inputTopic); + dataStream.peek((k, v) -> System.out.println(String.format("PROCESSED key=%s value=%s", k, v))); + + final Properties config = new Properties(); + config.setProperty(StreamsConfig.APPLICATION_ID_CONFIG, testName); + config.put(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG, 1000L); + config.put(StreamsConfig.DEFAULT_KEY_SERDE_CLASS_CONFIG, Serdes.StringSerde.class); + config.put(StreamsConfig.DEFAULT_VALUE_SERDE_CLASS_CONFIG, Serdes.StringSerde.class); + + config.putAll(streamsProperties); + + final KafkaStreams streams = new KafkaStreams(builder.build(), config); + streams.setStateListener((newState, oldState) -> { + if (oldState == KafkaStreams.State.REBALANCING && newState == KafkaStreams.State.RUNNING) { + System.out.println("REBALANCING -> RUNNING"); + System.out.flush(); + } + }); + + streams.start(); + + Exit.addShutdownHook("streams-shutdown-hook", () -> { + System.out.println("closing Kafka Streams instance"); + System.out.flush(); + streams.close(); + System.out.println("Static membership test closed"); + System.out.flush(); + }); + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/tests/StreamsBrokerDownResilienceTest.java b/streams/src/test/java/org/apache/kafka/streams/tests/StreamsBrokerDownResilienceTest.java new file mode 100644 index 0000000..90c2bb9 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/tests/StreamsBrokerDownResilienceTest.java @@ -0,0 +1,150 @@ +/* + * Licensed 
to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.streams.tests; + +import org.apache.kafka.clients.consumer.ConsumerConfig; +import org.apache.kafka.clients.producer.ProducerConfig; +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.utils.Exit; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.streams.KafkaStreams; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.errors.StreamsUncaughtExceptionHandler; +import org.apache.kafka.streams.kstream.Consumed; +import org.apache.kafka.streams.kstream.ForeachAction; + +import java.io.IOException; +import java.time.Duration; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.Properties; + +public class StreamsBrokerDownResilienceTest { + + private static final int KEY = 0; + private static final int VALUE = 1; + + private static final String SOURCE_TOPIC_1 = "streamsResilienceSource"; + + private static final String SINK_TOPIC = "streamsResilienceSink"; + + public static void main(final String[] args) throws IOException { + if (args.length < 2) { + System.err.println("StreamsBrokerDownResilienceTest are expecting two parameters: propFile, additionalConfigs; but only see " + args.length + " parameter"); + Exit.exit(1); + } + + System.out.println("StreamsTest instance started"); + + final String propFileName = args[0]; + final String additionalConfigs = args[1]; + + final Properties streamsProperties = Utils.loadProps(propFileName); + final String kafka = streamsProperties.getProperty(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG); + + if (kafka == null) { + System.err.println("No bootstrap kafka servers specified in " + StreamsConfig.BOOTSTRAP_SERVERS_CONFIG); + Exit.exit(1); + } + + streamsProperties.put(StreamsConfig.APPLICATION_ID_CONFIG, "kafka-streams-resilience"); + streamsProperties.put(StreamsConfig.DEFAULT_KEY_SERDE_CLASS_CONFIG, Serdes.String().getClass()); + streamsProperties.put(StreamsConfig.DEFAULT_VALUE_SERDE_CLASS_CONFIG, Serdes.String().getClass()); + streamsProperties.put(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG, 100L); + + + // it is expected that max.poll.interval, retries, request.timeout and max.block.ms set + // streams_broker_down_resilience_test and passed as args + if (additionalConfigs != null && !additionalConfigs.equalsIgnoreCase("none")) { + final Map updated = updatedConfigs(additionalConfigs); + System.out.println("Updating configs with " + updated); + streamsProperties.putAll(updated); + } + + if (!confirmCorrectConfigs(streamsProperties)) { + System.err.println(String.format("ERROR: Did not have all required configs expected to 
contain %s %s %s %s", + StreamsConfig.consumerPrefix(ConsumerConfig.MAX_POLL_INTERVAL_MS_CONFIG), + StreamsConfig.producerPrefix(ProducerConfig.RETRIES_CONFIG), + StreamsConfig.producerPrefix(ProducerConfig.REQUEST_TIMEOUT_MS_CONFIG), + StreamsConfig.producerPrefix(ProducerConfig.MAX_BLOCK_MS_CONFIG))); + + Exit.exit(1); + } + + final StreamsBuilder builder = new StreamsBuilder(); + final Serde stringSerde = Serdes.String(); + + builder.stream(Collections.singletonList(SOURCE_TOPIC_1), Consumed.with(stringSerde, stringSerde)) + .peek(new ForeachAction() { + int messagesProcessed = 0; + @Override + public void apply(final String key, final String value) { + System.out.println("received key " + key + " and value " + value); + messagesProcessed++; + System.out.println("processed " + messagesProcessed + " messages"); + System.out.flush(); + } + }).to(SINK_TOPIC); + + final KafkaStreams streams = new KafkaStreams(builder.build(), streamsProperties); + + streams.setUncaughtExceptionHandler(e -> { + System.err.println("FATAL: An unexpected exception " + e); + System.err.flush(); + return StreamsUncaughtExceptionHandler.StreamThreadExceptionResponse.SHUTDOWN_CLIENT; + } + ); + + System.out.println("Start Kafka Streams"); + streams.start(); + + Exit.addShutdownHook("streams-shutdown-hook", () -> { + streams.close(Duration.ofSeconds(30)); + System.out.println("Complete shutdown of streams resilience test app now"); + System.out.flush(); + }); + } + + private static boolean confirmCorrectConfigs(final Properties properties) { + return properties.containsKey(StreamsConfig.consumerPrefix(ConsumerConfig.MAX_POLL_INTERVAL_MS_CONFIG)) && + properties.containsKey(StreamsConfig.producerPrefix(ProducerConfig.RETRIES_CONFIG)) && + properties.containsKey(StreamsConfig.producerPrefix(ProducerConfig.REQUEST_TIMEOUT_MS_CONFIG)) && + properties.containsKey(StreamsConfig.producerPrefix(ProducerConfig.MAX_BLOCK_MS_CONFIG)); + } + + /** + * Takes a string with keys and values separated by '=' and each key value pair + * separated by ',' for example max.block.ms=5000,retries=6,request.timeout.ms=6000 + * + * @param formattedConfigs the formatted config string + * @return HashMap with keys and values inserted + */ + private static Map updatedConfigs(final String formattedConfigs) { + final String[] parts = formattedConfigs.split(","); + final Map updatedConfigs = new HashMap<>(); + for (final String part : parts) { + final String[] keyValue = part.split("="); + updatedConfigs.put(keyValue[KEY], keyValue[VALUE]); + } + return updatedConfigs; + } + +} diff --git a/streams/src/test/java/org/apache/kafka/streams/tests/StreamsEosTest.java b/streams/src/test/java/org/apache/kafka/streams/tests/StreamsEosTest.java new file mode 100644 index 0000000..5ad0641 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/tests/StreamsEosTest.java @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.tests; + +import org.apache.kafka.common.utils.Exit; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.streams.StreamsConfig; + +import java.io.IOException; +import java.util.Properties; + +public class StreamsEosTest { + + /** + * args ::= kafka propFileName command + * command := "run" | "process" | "verify" + */ + @SuppressWarnings("deprecation") + public static void main(final String[] args) throws IOException { + if (args.length < 2) { + System.err.println("StreamsEosTest are expecting two parameters: propFile, command; but only see " + args.length + " parameter"); + Exit.exit(1); + } + + final String propFileName = args[0]; + final String command = args[1]; + + final Properties streamsProperties = Utils.loadProps(propFileName); + final String kafka = streamsProperties.getProperty(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG); + final String processingGuarantee = streamsProperties.getProperty(StreamsConfig.PROCESSING_GUARANTEE_CONFIG); + + if (kafka == null) { + System.err.println("No bootstrap kafka servers specified in " + StreamsConfig.BOOTSTRAP_SERVERS_CONFIG); + Exit.exit(1); + } + + if ("process".equals(command) || "process-complex".equals(command)) { + if (!StreamsConfig.EXACTLY_ONCE.equals(processingGuarantee) && + !StreamsConfig.EXACTLY_ONCE_BETA.equals(processingGuarantee) && + !StreamsConfig.EXACTLY_ONCE_V2.equals(processingGuarantee)) { + + System.err.println("processingGuarantee must be either " + StreamsConfig.EXACTLY_ONCE + " or " + + StreamsConfig.EXACTLY_ONCE_BETA + " or " + StreamsConfig.EXACTLY_ONCE_V2); + Exit.exit(1); + } + } + + System.out.println("StreamsTest instance started"); + System.out.println("kafka=" + kafka); + System.out.println("props=" + streamsProperties); + System.out.println("command=" + command); + System.out.flush(); + + if (command == null || propFileName == null) { + Exit.exit(-1); + } + + switch (command) { + case "run": + EosTestDriver.generate(kafka); + break; + case "process": + new EosTestClient(streamsProperties, false).start(); + break; + case "process-complex": + new EosTestClient(streamsProperties, true).start(); + break; + case "verify": + EosTestDriver.verify(kafka, false); + break; + case "verify-complex": + EosTestDriver.verify(kafka, true); + break; + default: + System.out.println("unknown command: " + command); + System.out.flush(); + Exit.exit(-1); + } + } + +} diff --git a/streams/src/test/java/org/apache/kafka/streams/tests/StreamsNamedRepartitionTest.java b/streams/src/test/java/org/apache/kafka/streams/tests/StreamsNamedRepartitionTest.java new file mode 100644 index 0000000..b98f861 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/tests/StreamsNamedRepartitionTest.java @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.streams.tests; + +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.common.utils.Exit; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.streams.KafkaStreams; +import org.apache.kafka.streams.KafkaStreams.State; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.Topology; +import org.apache.kafka.streams.kstream.Aggregator; +import org.apache.kafka.streams.kstream.Consumed; +import org.apache.kafka.streams.kstream.Grouped; +import org.apache.kafka.streams.kstream.Initializer; +import org.apache.kafka.streams.kstream.KStream; +import org.apache.kafka.streams.kstream.Materialized; +import org.apache.kafka.streams.kstream.Produced; +import org.apache.kafka.streams.state.KeyValueStore; + +import java.time.Duration; +import java.util.Objects; +import java.util.Properties; +import java.util.function.Function; + +public class StreamsNamedRepartitionTest { + + public static void main(final String[] args) throws Exception { + if (args.length < 1) { + System.err.println("StreamsNamedRepartitionTest requires one argument (properties-file) but none provided: "); + } + final String propFileName = args[0]; + + final Properties streamsProperties = Utils.loadProps(propFileName); + + System.out.println("StreamsTest instance started NAMED_REPARTITION_TEST"); + System.out.println("props=" + streamsProperties); + + final String inputTopic = (String) (Objects.requireNonNull(streamsProperties.remove("input.topic"))); + final String aggregationTopic = (String) (Objects.requireNonNull(streamsProperties.remove("aggregation.topic"))); + final boolean addOperators = Boolean.valueOf(Objects.requireNonNull((String) streamsProperties.remove("add.operations"))); + + + final Initializer initializer = () -> 0; + final Aggregator aggregator = (k, v, agg) -> agg + Integer.parseInt(v); + + final Function keyFunction = s -> Integer.toString(Integer.parseInt(s) % 9); + + final StreamsBuilder builder = new StreamsBuilder(); + + final KStream sourceStream = builder.stream(inputTopic, Consumed.with(Serdes.String(), Serdes.String())); + sourceStream.peek((k, v) -> System.out.println(String.format("input data key=%s, value=%s", k, v))); + + final KStream mappedStream = sourceStream.selectKey((k, v) -> keyFunction.apply(v)); + + final KStream maybeUpdatedStream; + + if (addOperators) { + maybeUpdatedStream = mappedStream.filter((k, v) -> true).mapValues(v -> Integer.toString(Integer.parseInt(v) + 1)); + } else { + maybeUpdatedStream = mappedStream; + } + + maybeUpdatedStream.groupByKey(Grouped.with("grouped-stream", Serdes.String(), Serdes.String())) + .aggregate(initializer, aggregator, Materialized.>as("count-store").withKeySerde(Serdes.String()).withValueSerde(Serdes.Integer())) + .toStream() + .peek((k, v) -> 
System.out.println(String.format("AGGREGATED key=%s value=%s", k, v))) + .to(aggregationTopic, Produced.with(Serdes.String(), Serdes.Integer())); + + final Properties config = new Properties(); + + config.setProperty(StreamsConfig.APPLICATION_ID_CONFIG, "StreamsNamedRepartitionTest"); + config.setProperty(StreamsConfig.CACHE_MAX_BYTES_BUFFERING_CONFIG, "0"); + config.setProperty(StreamsConfig.DEFAULT_KEY_SERDE_CLASS_CONFIG, Serdes.String().getClass().getName()); + config.setProperty(StreamsConfig.DEFAULT_VALUE_SERDE_CLASS_CONFIG, Serdes.String().getClass().getName()); + + + config.putAll(streamsProperties); + + final Topology topology = builder.build(config); + final KafkaStreams streams = new KafkaStreams(topology, config); + + + streams.setStateListener((newState, oldState) -> { + if (oldState == State.REBALANCING && newState == State.RUNNING) { + if (addOperators) { + System.out.println("UPDATED Topology"); + } else { + System.out.println("REBALANCING -> RUNNING"); + } + System.out.flush(); + } + }); + + streams.start(); + + Exit.addShutdownHook("streams-shutdown-hook", () -> { + System.out.println("closing Kafka Streams instance"); + System.out.flush(); + streams.close(Duration.ofMillis(5000)); + System.out.println("NAMED_REPARTITION_TEST Streams Stopped"); + System.out.flush(); + }); + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/tests/StreamsOptimizedTest.java b/streams/src/test/java/org/apache/kafka/streams/tests/StreamsOptimizedTest.java new file mode 100644 index 0000000..714aa11 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/tests/StreamsOptimizedTest.java @@ -0,0 +1,157 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.streams.tests; + +import org.apache.kafka.clients.admin.AdminClientConfig; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.utils.Exit; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.streams.KafkaStreams; +import org.apache.kafka.streams.KafkaStreams.State; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.Topology; +import org.apache.kafka.streams.kstream.Aggregator; +import org.apache.kafka.streams.kstream.Consumed; +import org.apache.kafka.streams.kstream.Initializer; +import org.apache.kafka.streams.kstream.JoinWindows; +import org.apache.kafka.streams.kstream.KStream; +import org.apache.kafka.streams.kstream.Materialized; +import org.apache.kafka.streams.kstream.Produced; +import org.apache.kafka.streams.kstream.Reducer; +import org.apache.kafka.streams.kstream.StreamJoined; + +import java.time.Duration; +import java.util.ArrayList; +import java.util.List; +import java.util.Objects; +import java.util.Properties; +import java.util.function.Function; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +import static java.time.Duration.ofMillis; + +@SuppressWarnings("deprecation") +public class StreamsOptimizedTest { + + public static void main(final String[] args) throws Exception { + if (args.length < 1) { + System.err.println("StreamsOptimizedTest requires one argument (properties-file) but no provided: "); + } + final String propFileName = args[0]; + + final Properties streamsProperties = Utils.loadProps(propFileName); + + System.out.println("StreamsTest instance started StreamsOptimizedTest"); + System.out.println("props=" + streamsProperties); + + final String inputTopic = (String) Objects.requireNonNull(streamsProperties.remove("input.topic")); + final String aggregationTopic = (String) Objects.requireNonNull(streamsProperties.remove("aggregation.topic")); + final String reduceTopic = (String) Objects.requireNonNull(streamsProperties.remove("reduce.topic")); + final String joinTopic = (String) Objects.requireNonNull(streamsProperties.remove("join.topic")); + + + final Pattern repartitionTopicPattern = Pattern.compile("Sink: .*-repartition"); + final Initializer initializer = () -> 0; + final Aggregator aggregator = (k, v, agg) -> agg + v.length(); + + final Reducer reducer = (v1, v2) -> Integer.toString(Integer.parseInt(v1) + Integer.parseInt(v2)); + + final Function keyFunction = s -> Integer.toString(Integer.parseInt(s) % 9); + + final StreamsBuilder builder = new StreamsBuilder(); + + final KStream sourceStream = builder.stream(inputTopic, Consumed.with(Serdes.String(), Serdes.String())); + + final KStream mappedStream = sourceStream.selectKey((k, v) -> keyFunction.apply(v)); + + final KStream countStream = mappedStream.groupByKey() + .count(Materialized.with(Serdes.String(), + Serdes.Long())).toStream(); + + mappedStream.groupByKey().aggregate( + initializer, + aggregator, + Materialized.with(Serdes.String(), Serdes.Integer())) + .toStream() + .peek((k, v) -> System.out.println(String.format("AGGREGATED key=%s value=%s", k, v))) + .to(aggregationTopic, Produced.with(Serdes.String(), Serdes.Integer())); + + + mappedStream.groupByKey() + .reduce(reducer, Materialized.with(Serdes.String(), Serdes.String())) + .toStream() + .peek((k, v) -> System.out.println(String.format("REDUCED key=%s value=%s", k, v))) + .to(reduceTopic, Produced.with(Serdes.String(), Serdes.String())); + + 
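+        // Stream-stream join of the re-keyed input with its own running count: values that
+        // land within the 500 ms join window are combined into "value:count" strings and
+        // written to the join topic. Like the aggregate and reduce above, this requires
+        // repartitioning of mappedStream; the repartition-topic count reported on the
+        // REBALANCING -> RUNNING transition below is what the system test asserts on.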
mappedStream.join(countStream, (v1, v2) -> v1 + ":" + v2.toString(), + JoinWindows.of(ofMillis(500)), + StreamJoined.with(Serdes.String(), Serdes.String(), Serdes.Long())) + .peek((k, v) -> System.out.println(String.format("JOINED key=%s value=%s", k, v))) + .to(joinTopic, Produced.with(Serdes.String(), Serdes.String())); + + final Properties config = new Properties(); + + + config.setProperty(StreamsConfig.APPLICATION_ID_CONFIG, "StreamsOptimizedTest"); + config.setProperty(StreamsConfig.CACHE_MAX_BYTES_BUFFERING_CONFIG, "0"); + config.setProperty(StreamsConfig.DEFAULT_KEY_SERDE_CLASS_CONFIG, Serdes.String().getClass().getName()); + config.setProperty(StreamsConfig.DEFAULT_VALUE_SERDE_CLASS_CONFIG, Serdes.String().getClass().getName()); + config.setProperty(StreamsConfig.adminClientPrefix(AdminClientConfig.RETRIES_CONFIG), "100"); + + + config.putAll(streamsProperties); + + final Topology topology = builder.build(config); + final KafkaStreams streams = new KafkaStreams(topology, config); + + + streams.setStateListener((newState, oldState) -> { + if (oldState == State.REBALANCING && newState == State.RUNNING) { + final int repartitionTopicCount = getCountOfRepartitionTopicsFound(topology.describe().toString(), repartitionTopicPattern); + System.out.println(String.format("REBALANCING -> RUNNING with REPARTITION TOPIC COUNT=%d", repartitionTopicCount)); + System.out.flush(); + } + }); + + streams.cleanUp(); + streams.start(); + + Exit.addShutdownHook("streams-shutdown-hook", () -> { + System.out.println("closing Kafka Streams instance"); + System.out.flush(); + streams.close(Duration.ofMillis(5000)); + System.out.println("OPTIMIZE_TEST Streams Stopped"); + System.out.flush(); + }); + + } + + private static int getCountOfRepartitionTopicsFound(final String topologyString, + final Pattern repartitionTopicPattern) { + final Matcher matcher = repartitionTopicPattern.matcher(topologyString); + final List repartitionTopicsFound = new ArrayList<>(); + while (matcher.find()) { + final String repartitionTopic = matcher.group(); + System.out.println(String.format("REPARTITION TOPIC found -> %s", repartitionTopic)); + repartitionTopicsFound.add(repartitionTopic); + } + return repartitionTopicsFound.size(); + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/tests/StreamsSmokeTest.java b/streams/src/test/java/org/apache/kafka/streams/tests/StreamsSmokeTest.java new file mode 100644 index 0000000..d87da74 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/tests/StreamsSmokeTest.java @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.tests; + +import org.apache.kafka.common.utils.Exit; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.streams.StreamsConfig; + +import java.io.IOException; +import java.time.Duration; +import java.util.Map; +import java.util.Properties; +import java.util.Set; +import java.util.UUID; + +import static org.apache.kafka.streams.tests.SmokeTestDriver.generate; +import static org.apache.kafka.streams.tests.SmokeTestDriver.generatePerpetually; + +public class StreamsSmokeTest { + + /** + * args ::= kafka propFileName command disableAutoTerminate + * command := "run" | "process" + * + * @param args + */ + @SuppressWarnings("deprecation") + public static void main(final String[] args) throws IOException { + if (args.length < 2) { + System.err.println("StreamsSmokeTest are expecting two parameters: propFile, command; but only see " + args.length + " parameter"); + Exit.exit(1); + } + + final String propFileName = args[0]; + final String command = args[1]; + final boolean disableAutoTerminate = args.length > 2; + + final Properties streamsProperties = Utils.loadProps(propFileName); + final String kafka = streamsProperties.getProperty(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG); + final String processingGuarantee = streamsProperties.getProperty(StreamsConfig.PROCESSING_GUARANTEE_CONFIG); + + if (kafka == null) { + System.err.println("No bootstrap kafka servers specified in " + StreamsConfig.BOOTSTRAP_SERVERS_CONFIG); + Exit.exit(1); + } + + if ("process".equals(command)) { + if (!StreamsConfig.AT_LEAST_ONCE.equals(processingGuarantee) && + !StreamsConfig.EXACTLY_ONCE.equals(processingGuarantee) && + !StreamsConfig.EXACTLY_ONCE_BETA.equals(processingGuarantee) && + !StreamsConfig.EXACTLY_ONCE_V2.equals(processingGuarantee)) { + + System.err.println("processingGuarantee must be either " + + StreamsConfig.AT_LEAST_ONCE + ", " + + StreamsConfig.EXACTLY_ONCE + ", or " + + StreamsConfig.EXACTLY_ONCE_BETA + ", or " + + StreamsConfig.EXACTLY_ONCE_V2); + + Exit.exit(1); + } + } + + System.out.println("StreamsTest instance started (StreamsSmokeTest)"); + System.out.println("command=" + command); + System.out.println("props=" + streamsProperties); + System.out.println("disableAutoTerminate=" + disableAutoTerminate); + + switch (command) { + case "run": + // this starts the driver (data generation and result verification) + final int numKeys = 20; + final int maxRecordsPerKey = 1000; + if (disableAutoTerminate) { + generatePerpetually(kafka, numKeys, maxRecordsPerKey); + } else { + // slow down data production so that system tests have time to + // do their bounces, etc. 
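+                    // generate() drives the load: it publishes up to maxRecordsPerKey records for
+                    // each of the numKeys keys, pacing itself over the given duration, and returns
+                    // the produced values keyed by record key so that SmokeTestDriver.verify() can
+                    // compare the application's output topics against that input afterwards
+                    // (see SmokeTestDriver for the full verification logic).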
+ final Map> allData = + generate(kafka, numKeys, maxRecordsPerKey, Duration.ofSeconds(90)); + SmokeTestDriver.verify(kafka, allData, maxRecordsPerKey); + } + break; + case "process": + // this starts the stream processing app + new SmokeTestClient(UUID.randomUUID().toString()).start(streamsProperties); + break; + case "close-deadlock-test": + final ShutdownDeadlockTest test = new ShutdownDeadlockTest(kafka); + test.start(); + break; + default: + System.out.println("unknown command: " + command); + } + } + +} diff --git a/streams/src/test/java/org/apache/kafka/streams/tests/StreamsStandByReplicaTest.java b/streams/src/test/java/org/apache/kafka/streams/tests/StreamsStandByReplicaTest.java new file mode 100644 index 0000000..3c693cc --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/tests/StreamsStandByReplicaTest.java @@ -0,0 +1,169 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.streams.tests; + +import org.apache.kafka.clients.consumer.ConsumerConfig; +import org.apache.kafka.clients.producer.ProducerConfig; +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.utils.Exit; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.streams.KafkaStreams; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.ThreadMetadata; +import org.apache.kafka.streams.errors.StreamsUncaughtExceptionHandler; +import org.apache.kafka.streams.kstream.Consumed; +import org.apache.kafka.streams.kstream.KStream; +import org.apache.kafka.streams.kstream.Materialized; +import org.apache.kafka.streams.kstream.Produced; +import org.apache.kafka.streams.kstream.ValueMapper; +import org.apache.kafka.streams.state.KeyValueBytesStoreSupplier; +import org.apache.kafka.streams.state.Stores; + +import java.io.IOException; +import java.time.Duration; +import java.util.Map; +import java.util.Properties; +import java.util.Set; + +public class StreamsStandByReplicaTest { + + public static void main(final String[] args) throws IOException { + if (args.length < 2) { + System.err.println("StreamsStandByReplicaTest are expecting two parameters: " + + "propFile, additionalConfigs; but only see " + args.length + " parameter"); + Exit.exit(1); + } + + System.out.println("StreamsTest instance started"); + + final String propFileName = args[0]; + final String additionalConfigs = args[1]; + + final Properties streamsProperties = Utils.loadProps(propFileName); + final String kafka = streamsProperties.getProperty(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG); + + if (kafka == null) { + System.err.println("No bootstrap kafka servers specified in " + 
StreamsConfig.BOOTSTRAP_SERVERS_CONFIG); + Exit.exit(1); + } + + streamsProperties.put(StreamsConfig.APPLICATION_ID_CONFIG, "kafka-streams-standby-tasks"); + streamsProperties.put(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG, 100L); + streamsProperties.put(StreamsConfig.NUM_STANDBY_REPLICAS_CONFIG, 1); + streamsProperties.put(StreamsConfig.CACHE_MAX_BYTES_BUFFERING_CONFIG, 0); + streamsProperties.put(StreamsConfig.DEFAULT_KEY_SERDE_CLASS_CONFIG, Serdes.String().getClass()); + streamsProperties.put(StreamsConfig.DEFAULT_VALUE_SERDE_CLASS_CONFIG, Serdes.String().getClass()); + streamsProperties.put(StreamsConfig.producerPrefix(ProducerConfig.ENABLE_IDEMPOTENCE_CONFIG), true); + + if (additionalConfigs == null) { + System.err.println("additional configs are not provided"); + System.err.flush(); + Exit.exit(1); + } + + final Map updated = SystemTestUtil.parseConfigs(additionalConfigs); + System.out.println("Updating configs with " + updated); + + final String sourceTopic = updated.remove("sourceTopic"); + final String sinkTopic1 = updated.remove("sinkTopic1"); + final String sinkTopic2 = updated.remove("sinkTopic2"); + + if (sourceTopic == null || sinkTopic1 == null || sinkTopic2 == null) { + System.err.println(String.format( + "one or more required topics null sourceTopic[%s], sinkTopic1[%s], sinkTopic2[%s]", + sourceTopic, + sinkTopic1, + sinkTopic2)); + System.err.flush(); + Exit.exit(1); + } + + streamsProperties.putAll(updated); + + if (!confirmCorrectConfigs(streamsProperties)) { + System.err.println(String.format("ERROR: Did not have all required configs expected to contain %s, %s, %s, %s", + StreamsConfig.consumerPrefix(ConsumerConfig.MAX_POLL_INTERVAL_MS_CONFIG), + StreamsConfig.producerPrefix(ProducerConfig.RETRIES_CONFIG), + StreamsConfig.producerPrefix(ProducerConfig.REQUEST_TIMEOUT_MS_CONFIG), + StreamsConfig.producerPrefix(ProducerConfig.MAX_BLOCK_MS_CONFIG))); + + Exit.exit(1); + } + + final StreamsBuilder builder = new StreamsBuilder(); + + final String inMemoryStoreName = "in-memory-store"; + final String persistentMemoryStoreName = "persistent-memory-store"; + + final KeyValueBytesStoreSupplier inMemoryStoreSupplier = Stores.inMemoryKeyValueStore(inMemoryStoreName); + final KeyValueBytesStoreSupplier persistentStoreSupplier = Stores.persistentKeyValueStore(persistentMemoryStoreName); + + final Serde stringSerde = Serdes.String(); + final ValueMapper countMapper = Object::toString; + + final KStream inputStream = builder.stream(sourceTopic, Consumed.with(stringSerde, stringSerde)); + + inputStream.groupByKey().count(Materialized.as(inMemoryStoreSupplier)).toStream().mapValues(countMapper) + .to(sinkTopic1, Produced.with(stringSerde, stringSerde)); + + inputStream.groupByKey().count(Materialized.as(persistentStoreSupplier)).toStream().mapValues(countMapper) + .to(sinkTopic2, Produced.with(stringSerde, stringSerde)); + + final KafkaStreams streams = new KafkaStreams(builder.build(), streamsProperties); + + streams.setUncaughtExceptionHandler(e -> { + System.err.println("FATAL: An unexpected exception " + e); + e.printStackTrace(System.err); + System.err.flush(); + return StreamsUncaughtExceptionHandler.StreamThreadExceptionResponse.SHUTDOWN_CLIENT; + }); + + streams.setStateListener((newState, oldState) -> { + if (newState == KafkaStreams.State.RUNNING && oldState == KafkaStreams.State.REBALANCING) { + final Set threadMetadata = streams.metadataForLocalThreads(); + for (final ThreadMetadata threadMetadatum : threadMetadata) { + System.out.println( + "ACTIVE_TASKS:" + 
threadMetadatum.activeTasks().size() + + " STANDBY_TASKS:" + threadMetadatum.standbyTasks().size()); + } + } + }); + + System.out.println("Start Kafka Streams"); + streams.start(); + + Exit.addShutdownHook("streams-shutdown-hook", () -> { + shutdown(streams); + System.out.println("Shut down streams now"); + }); + } + + private static void shutdown(final KafkaStreams streams) { + streams.close(Duration.ofSeconds(10)); + } + + private static boolean confirmCorrectConfigs(final Properties properties) { + return properties.containsKey(StreamsConfig.consumerPrefix(ConsumerConfig.MAX_POLL_INTERVAL_MS_CONFIG)) && + properties.containsKey(StreamsConfig.producerPrefix(ProducerConfig.RETRIES_CONFIG)) && + properties.containsKey(StreamsConfig.producerPrefix(ProducerConfig.REQUEST_TIMEOUT_MS_CONFIG)) && + properties.containsKey(StreamsConfig.producerPrefix(ProducerConfig.MAX_BLOCK_MS_CONFIG)); + } + +} diff --git a/streams/src/test/java/org/apache/kafka/streams/tests/StreamsUpgradeTest.java b/streams/src/test/java/org/apache/kafka/streams/tests/StreamsUpgradeTest.java new file mode 100644 index 0000000..f20a532 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/tests/StreamsUpgradeTest.java @@ -0,0 +1,393 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.tests; + +import org.apache.kafka.clients.consumer.Consumer; +import org.apache.kafka.clients.consumer.ConsumerConfig; +import org.apache.kafka.clients.consumer.ConsumerGroupMetadata; +import org.apache.kafka.clients.consumer.ConsumerPartitionAssignor; +import org.apache.kafka.clients.consumer.KafkaConsumer; +import org.apache.kafka.common.Cluster; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.serialization.ByteArrayDeserializer; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.utils.ByteBufferInputStream; +import org.apache.kafka.common.utils.Exit; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.streams.KafkaClientSupplier; +import org.apache.kafka.streams.KafkaStreams; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.errors.TaskAssignmentException; +import org.apache.kafka.streams.kstream.KStream; +import org.apache.kafka.streams.processor.TaskId; +import org.apache.kafka.streams.processor.internals.DefaultKafkaClientSupplier; +import org.apache.kafka.streams.processor.internals.StreamsPartitionAssignor; +import org.apache.kafka.streams.processor.internals.TaskManager; +import org.apache.kafka.streams.processor.internals.assignment.AssignmentInfo; +import org.apache.kafka.streams.processor.internals.assignment.AssignorConfiguration; +import org.apache.kafka.streams.processor.internals.assignment.LegacySubscriptionInfoSerde; +import org.apache.kafka.streams.processor.internals.assignment.SubscriptionInfo; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.ByteArrayOutputStream; +import java.io.DataInputStream; +import java.io.DataOutputStream; +import java.io.IOException; +import java.nio.BufferUnderflowException; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Properties; +import java.util.Set; +import java.util.UUID; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicLong; + +import static org.apache.kafka.streams.processor.internals.assignment.StreamsAssignmentProtocolVersions.LATEST_SUPPORTED_VERSION; + +public class StreamsUpgradeTest { + + public static void main(final String[] args) throws Exception { + if (args.length < 1) { + System.err.println("StreamsUpgradeTest requires one argument (properties-file) but no provided: "); + } + final String propFileName = args.length > 0 ? args[0] : null; + + final Properties streamsProperties = Utils.loadProps(propFileName); + + System.out.println("StreamsTest instance started (StreamsUpgradeTest trunk)"); + System.out.println("props=" + streamsProperties); + + final KafkaStreams streams = buildStreams(streamsProperties); + streams.start(); + + Exit.addShutdownHook("streams-shutdown-hook", () -> { + System.out.println("closing Kafka Streams instance"); + System.out.flush(); + streams.close(); + System.out.println("UPGRADE-TEST-CLIENT-CLOSED"); + System.out.flush(); + }); + } + + @SuppressWarnings("deprecation") // Old PAPI. Needs to be migrated. 
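+    // Topology under test: consume "data", run each record through SmokeTestUtil's print
+    // processor, and forward it unchanged to "echo". If the "test.future.metadata" property
+    // is set, a client supplier is used whose consumer plugs in FutureStreamsPartitionAssignor,
+    // so this instance pretends to speak a not-yet-released assignment protocol version.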
+ public static KafkaStreams buildStreams(final Properties streamsProperties) { + final StreamsBuilder builder = new StreamsBuilder(); + final KStream dataStream = builder.stream("data"); + dataStream.process(SmokeTestUtil.printProcessorSupplier("data")); + dataStream.to("echo"); + + final Properties config = new Properties(); + config.setProperty(StreamsConfig.APPLICATION_ID_CONFIG, "StreamsUpgradeTest"); + config.put(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG, 1000L); + config.put(StreamsConfig.DEFAULT_VALUE_SERDE_CLASS_CONFIG, Serdes.ByteArraySerde.class.getName()); + config.put(StreamsConfig.DEFAULT_KEY_SERDE_CLASS_CONFIG, Serdes.ByteArraySerde.class.getName()); + + final KafkaClientSupplier kafkaClientSupplier; + if (streamsProperties.containsKey("test.future.metadata")) { + kafkaClientSupplier = new FutureKafkaClientSupplier(); + } else { + kafkaClientSupplier = new DefaultKafkaClientSupplier(); + } + config.putAll(streamsProperties); + + return new KafkaStreams(builder.build(), config, kafkaClientSupplier); + } + + private static class FutureKafkaClientSupplier extends DefaultKafkaClientSupplier { + @Override + public Consumer getConsumer(final Map config) { + config.put(ConsumerConfig.PARTITION_ASSIGNMENT_STRATEGY_CONFIG, FutureStreamsPartitionAssignor.class.getName()); + return new KafkaConsumer<>(config, new ByteArrayDeserializer(), new ByteArrayDeserializer()); + } + } + + public static class FutureStreamsPartitionAssignor extends StreamsPartitionAssignor { + private final Logger log = LoggerFactory.getLogger(FutureStreamsPartitionAssignor.class); + + private AtomicInteger usedSubscriptionMetadataVersionPeek; + private AtomicLong nextScheduledRebalanceMs; + + public FutureStreamsPartitionAssignor() { + usedSubscriptionMetadataVersion = LATEST_SUPPORTED_VERSION + 1; + } + + @Override + public void configure(final Map configs) { + final Object o = configs.get("test.future.metadata"); + if (o instanceof AtomicInteger) { + usedSubscriptionMetadataVersionPeek = (AtomicInteger) o; + } else { + // will not be used, just adding a dummy container for simpler code paths + usedSubscriptionMetadataVersionPeek = new AtomicInteger(); + } + configs.remove("test.future.metadata"); + nextScheduledRebalanceMs = new AssignorConfiguration(configs).referenceContainer().nextScheduledRebalanceMs; + + super.configure(configs); + } + + @Override + public ByteBuffer subscriptionUserData(final Set topics) { + // Adds the following information to subscription + // 1. Client UUID (a unique id assigned to an instance of KafkaStreams) + // 2. Task ids of previously running tasks + // 3. Task ids of valid local states on the client's state directory. 
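+            // This test assignor always claims LATEST_SUPPORTED_VERSION + 1 as its latest
+            // supported version. As long as the used version is still a known one, it encodes
+            // a regular SubscriptionInfo (advertising the future version); otherwise it falls
+            // back to the hand-rolled FutureSubscriptionInfo encoding below.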
+ final TaskManager taskManager = taskManager(); + handleRebalanceStart(topics); + byte uniqueField = 0; + if (usedSubscriptionMetadataVersion <= LATEST_SUPPORTED_VERSION) { + uniqueField++; + return new SubscriptionInfo( + usedSubscriptionMetadataVersion, + LATEST_SUPPORTED_VERSION + 1, + taskManager.processId(), + userEndPoint(), + taskManager.getTaskOffsetSums(), + uniqueField, + 0 + ).encode(); + } else { + return new FutureSubscriptionInfo( + usedSubscriptionMetadataVersion, + taskManager.processId(), + SubscriptionInfo.getActiveTasksFromTaskOffsetSumMap(taskManager.getTaskOffsetSums()), + SubscriptionInfo.getStandbyTasksFromTaskOffsetSumMap(taskManager.getTaskOffsetSums()), + userEndPoint()) + .encode(); + } + } + + @Override + public void onAssignment(final ConsumerPartitionAssignor.Assignment assignment, + final ConsumerGroupMetadata metadata) { + try { + super.onAssignment(assignment, metadata); + usedSubscriptionMetadataVersionPeek.set(usedSubscriptionMetadataVersion); + return; + } catch (final TaskAssignmentException cannotProcessFutureVersion) { + // continue + } + + final ByteBuffer data = assignment.userData(); + data.rewind(); + + final int usedVersion; + try (final DataInputStream in = new DataInputStream(new ByteBufferInputStream(data))) { + usedVersion = in.readInt(); + } catch (final IOException ex) { + throw new TaskAssignmentException("Failed to decode AssignmentInfo", ex); + } + + if (usedVersion > LATEST_SUPPORTED_VERSION + 1) { + throw new IllegalStateException("Unknown metadata version: " + usedVersion + + "; latest supported version: " + LATEST_SUPPORTED_VERSION + 1); + } + + final AssignmentInfo info = AssignmentInfo.decode( + assignment.userData().putInt(0, LATEST_SUPPORTED_VERSION)); + + if (maybeUpdateSubscriptionVersion(usedVersion, info.commonlySupportedVersion())) { + log.info("Requested to schedule immediate rebalance due to version probing."); + nextScheduledRebalanceMs.set(0L); + usedSubscriptionMetadataVersionPeek.set(usedSubscriptionMetadataVersion); + } + + final List partitions = new ArrayList<>(assignment.partitions()); + partitions.sort(PARTITION_COMPARATOR); + + final Map> activeTasks = getActiveTasks(partitions, info); + + final TaskManager taskManager = taskManager(); + taskManager.handleAssignment(activeTasks, info.standbyTasks()); + usedSubscriptionMetadataVersionPeek.set(usedSubscriptionMetadataVersion); + } + + @Override + public GroupAssignment assign(final Cluster metadata, final GroupSubscription groupSubscription) { + final Map subscriptions = groupSubscription.groupSubscription(); + final Set supportedVersions = new HashSet<>(); + for (final Map.Entry entry : subscriptions.entrySet()) { + final Subscription subscription = entry.getValue(); + final SubscriptionInfo info = SubscriptionInfo.decode(subscription.userData()); + supportedVersions.add(info.latestSupportedVersion()); + } + Map assignment = null; + + final Map downgradedSubscriptions = new HashMap<>(); + for (final Subscription subscription : subscriptions.values()) { + final SubscriptionInfo info = SubscriptionInfo.decode(subscription.userData()); + if (info.version() < LATEST_SUPPORTED_VERSION + 1) { + assignment = super.assign(metadata, new GroupSubscription(subscriptions)).groupAssignment(); + break; + } + } + + boolean bumpUsedVersion = false; + final boolean bumpSupportedVersion; + if (assignment != null) { + bumpSupportedVersion = supportedVersions.size() == 1 && supportedVersions.iterator().next() == LATEST_SUPPORTED_VERSION + 1; + } else { + for (final Map.Entry entry : 
subscriptions.entrySet()) { + final Subscription subscription = entry.getValue(); + + final SubscriptionInfo info = SubscriptionInfo.decode(subscription.userData() + .putInt(0, LATEST_SUPPORTED_VERSION) + .putInt(4, LATEST_SUPPORTED_VERSION)); + + downgradedSubscriptions.put( + entry.getKey(), + new Subscription( + subscription.topics(), + new SubscriptionInfo( + LATEST_SUPPORTED_VERSION, + LATEST_SUPPORTED_VERSION, + info.processId(), + info.userEndPoint(), + taskManager().getTaskOffsetSums(), + (byte) 0, + 0 + ).encode(), + subscription.ownedPartitions() + )); + } + assignment = super.assign(metadata, new GroupSubscription(downgradedSubscriptions)).groupAssignment(); + bumpUsedVersion = true; + bumpSupportedVersion = true; + } + + final Map newAssignment = new HashMap<>(); + for (final Map.Entry entry : assignment.entrySet()) { + final Assignment singleAssignment = entry.getValue(); + newAssignment.put( + entry.getKey(), + new Assignment( + singleAssignment.partitions(), + new FutureAssignmentInfo( + bumpUsedVersion, + bumpSupportedVersion, + singleAssignment.userData()) + .encode())); + } + + return new GroupAssignment(newAssignment); + } + } + + private static class FutureSubscriptionInfo { + private final int version; + private final UUID processId; + private final Set activeTasks; + private final Set standbyTasks; + private final String userEndPoint; + + // for testing only; don't apply version checks + FutureSubscriptionInfo(final int version, + final UUID processId, + final Set activeTasks, + final Set standbyTasks, + final String userEndPoint) { + this.version = version; + this.processId = processId; + this.activeTasks = activeTasks; + this.standbyTasks = standbyTasks; + this.userEndPoint = userEndPoint; + if (version <= LATEST_SUPPORTED_VERSION) { + throw new IllegalArgumentException("this class can't be used with version " + version); + } + } + + private ByteBuffer encode() { + final byte[] endPointBytes = LegacySubscriptionInfoSerde.prepareUserEndPoint(userEndPoint); + + final ByteBuffer buf = ByteBuffer.allocate( + 4 + // used version + 4 + // latest supported version version + 16 + // client ID + 4 + activeTasks.size() * 8 + // length + active tasks + 4 + standbyTasks.size() * 8 + // length + standby tasks + 4 + endPointBytes.length + // length + endpoint + 4 + //uniqueField + 4 //assignment error code + ); + + buf.putInt(version); // used version + buf.putInt(version); // supported version + LegacySubscriptionInfoSerde.encodeClientUUID(buf, processId); + LegacySubscriptionInfoSerde.encodeTasks(buf, activeTasks, version); + LegacySubscriptionInfoSerde.encodeTasks(buf, standbyTasks, version); + LegacySubscriptionInfoSerde.encodeUserEndPoint(buf, endPointBytes); + + buf.rewind(); + + return buf; + } + } + + private static class FutureAssignmentInfo extends AssignmentInfo { + private final boolean bumpUsedVersion; + private final boolean bumpSupportedVersion; + final ByteBuffer originalUserMetadata; + + private FutureAssignmentInfo(final boolean bumpUsedVersion, + final boolean bumpSupportedVersion, + final ByteBuffer bytes) { + super(LATEST_SUPPORTED_VERSION, LATEST_SUPPORTED_VERSION); + this.bumpUsedVersion = bumpUsedVersion; + this.bumpSupportedVersion = bumpSupportedVersion; + originalUserMetadata = bytes; + } + + @Override + public ByteBuffer encode() { + final ByteArrayOutputStream baos = new ByteArrayOutputStream(); + + originalUserMetadata.rewind(); + + try (final DataOutputStream out = new DataOutputStream(baos)) { + if (bumpUsedVersion) { + 
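+                    // Drop the used-version int from the original encoding and write
+                    // LATEST_SUPPORTED_VERSION + 1 in its place, so the assignment appears
+                    // to have been produced by a future release.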
originalUserMetadata.getInt(); // discard original used version + out.writeInt(LATEST_SUPPORTED_VERSION + 1); + } else { + out.writeInt(originalUserMetadata.getInt()); + } + if (bumpSupportedVersion) { + originalUserMetadata.getInt(); // discard original supported version + out.writeInt(LATEST_SUPPORTED_VERSION + 1); + } + + try { + while (true) { + out.write(originalUserMetadata.get()); + } + } catch (final BufferUnderflowException expectedWhenAllDataCopied) { } + + out.flush(); + out.close(); + + return ByteBuffer.wrap(baos.toByteArray()); + } catch (final IOException ex) { + throw new TaskAssignmentException("Failed to encode AssignmentInfo", ex); + } + } + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/tests/StreamsUpgradeToCooperativeRebalanceTest.java b/streams/src/test/java/org/apache/kafka/streams/tests/StreamsUpgradeToCooperativeRebalanceTest.java new file mode 100644 index 0000000..6d7da29 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/tests/StreamsUpgradeToCooperativeRebalanceTest.java @@ -0,0 +1,136 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.tests; + +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.utils.Exit; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.streams.KafkaStreams; +import org.apache.kafka.streams.KafkaStreams.State; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.TaskMetadata; +import org.apache.kafka.streams.ThreadMetadata; +import org.apache.kafka.streams.kstream.ForeachAction; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Properties; +import java.util.Set; + +public class StreamsUpgradeToCooperativeRebalanceTest { + + + public static void main(final String[] args) throws Exception { + if (args.length < 1) { + System.err.println("StreamsUpgradeToCooperativeRebalanceTest requires one argument (properties-file) but no args provided"); + } + System.out.println("Args are " + Arrays.toString(args)); + final String propFileName = args[0]; + final Properties streamsProperties = Utils.loadProps(propFileName); + + final Properties config = new Properties(); + System.out.println("StreamsTest instance started (StreamsUpgradeToCooperativeRebalanceTest)"); + System.out.println("props=" + streamsProperties); + + config.put(StreamsConfig.APPLICATION_ID_CONFIG, "cooperative-rebalance-upgrade"); + config.put(StreamsConfig.DEFAULT_KEY_SERDE_CLASS_CONFIG, Serdes.String().getClass()); + config.put(StreamsConfig.DEFAULT_VALUE_SERDE_CLASS_CONFIG, Serdes.String().getClass()); + config.put(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG, 1000L); + config.putAll(streamsProperties); + + final String sourceTopic = streamsProperties.getProperty("source.topic", "source"); + final String sinkTopic = streamsProperties.getProperty("sink.topic", "sink"); + final String taskDelimiter = "#"; + final int reportInterval = Integer.parseInt(streamsProperties.getProperty("report.interval", "100")); + final String upgradePhase = streamsProperties.getProperty("upgrade.phase", ""); + + final StreamsBuilder builder = new StreamsBuilder(); + + builder.stream(sourceTopic) + .peek(new ForeachAction() { + int recordCounter = 0; + + @Override + public void apply(final String key, final String value) { + if (recordCounter++ % reportInterval == 0) { + System.out.println(String.format("%sProcessed %d records so far", upgradePhase, recordCounter)); + System.out.flush(); + } + } + } + ).to(sinkTopic); + + final KafkaStreams streams = new KafkaStreams(builder.build(), config); + + streams.setStateListener((newState, oldState) -> { + if (newState == State.RUNNING && oldState == State.REBALANCING) { + System.out.println(String.format("%sSTREAMS in a RUNNING State", upgradePhase)); + final Set allThreadMetadata = streams.metadataForLocalThreads(); + final StringBuilder taskReportBuilder = new StringBuilder(); + final List activeTasks = new ArrayList<>(); + final List standbyTasks = new ArrayList<>(); + for (final ThreadMetadata threadMetadata : allThreadMetadata) { + getTasks(threadMetadata.activeTasks(), activeTasks); + if (!threadMetadata.standbyTasks().isEmpty()) { + getTasks(threadMetadata.standbyTasks(), standbyTasks); + } + } + addTasksToBuilder(activeTasks, taskReportBuilder); + taskReportBuilder.append(taskDelimiter); + if (!standbyTasks.isEmpty()) { + addTasksToBuilder(standbyTasks, taskReportBuilder); + } + System.out.println("TASK-ASSIGNMENTS:" + taskReportBuilder); + } + + if (newState 
== State.REBALANCING) { + System.out.println(String.format("%sStarting a REBALANCE", upgradePhase)); + } + }); + + + streams.start(); + + Exit.addShutdownHook("streams-shutdown-hook", () -> { + streams.close(); + System.out.printf("%sCOOPERATIVE-REBALANCE-TEST-CLIENT-CLOSED%n", upgradePhase); + System.out.flush(); + }); + } + + private static void addTasksToBuilder(final List tasks, final StringBuilder builder) { + if (!tasks.isEmpty()) { + for (final String task : tasks) { + builder.append(task).append(","); + } + builder.setLength(builder.length() - 1); + } + } + + private static void getTasks(final Set taskMetadata, + final List taskList) { + for (final TaskMetadata task : taskMetadata) { + final Set topicPartitions = task.topicPartitions(); + for (final TopicPartition topicPartition : topicPartitions) { + taskList.add(topicPartition.toString()); + } + } + } +} diff --git a/streams/src/test/java/org/apache/kafka/streams/tests/SystemTestUtil.java b/streams/src/test/java/org/apache/kafka/streams/tests/SystemTestUtil.java new file mode 100644 index 0000000..4ddbf69 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/tests/SystemTestUtil.java @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.streams.tests; + +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; + +/** + * Class for common convenience methods for working on + * System tests + */ + +public class SystemTestUtil { + + private static final int KEY = 0; + private static final int VALUE = 1; + + /** + * Takes a string with keys and values separated by '=' and each key value pair + * separated by ',' for example max.block.ms=5000,retries=6,request.timeout.ms=6000 + * + * This class makes it easier to pass configs from the system test in python to the Java test. 
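+ *
+ * <p>Illustrative usage sketch (not part of the original file), using the example string above:
+ * <pre>{@code
+ * final Map<String, String> configs = SystemTestUtil.parseConfigs("max.block.ms=5000,retries=6,request.timeout.ms=6000");
+ * configs.get("retries"); // "6"
+ * }</pre>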
+ *
+ * @param formattedConfigs the formatted config string
+ * @return HashMap with keys and values inserted
+ */
+ public static Map<String, String> parseConfigs(final String formattedConfigs) {
+ Objects.requireNonNull(formattedConfigs, "Formatted config String can't be null");
+
+ if (formattedConfigs.indexOf('=') == -1) {
+ throw new IllegalStateException(String.format("Provided string [ %s ] does not have expected key-value separator of '='", formattedConfigs));
+ }
+
+ final String[] parts = formattedConfigs.split(",");
+ final Map<String, String> configs = new HashMap<>();
+ for (final String part : parts) {
+ final String[] keyValue = part.split("=");
+ if (keyValue.length > 2) {
+ throw new IllegalStateException(
+ String.format("Provided string [ %s ] does not have expected key-value pair separator of ','", formattedConfigs));
+ }
+ configs.put(keyValue[KEY], keyValue[VALUE]);
+ }
+ return configs;
+ }
+}
diff --git a/streams/src/test/java/org/apache/kafka/streams/tests/SystemTestUtilTest.java b/streams/src/test/java/org/apache/kafka/streams/tests/SystemTestUtilTest.java
new file mode 100644
index 0000000..a2a26a3
--- /dev/null
+++ b/streams/src/test/java/org/apache/kafka/streams/tests/SystemTestUtilTest.java
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */ + +package org.apache.kafka.streams.tests; + +import org.junit.Before; +import org.junit.Test; + +import java.util.HashMap; +import java.util.Map; +import java.util.TreeMap; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; + +public class SystemTestUtilTest { + + private final Map expectedParsedMap = new TreeMap<>(); + + @Before + public void setUp() { + expectedParsedMap.put("foo", "foo1"); + expectedParsedMap.put("bar", "bar1"); + expectedParsedMap.put("baz", "baz1"); + } + + @Test + public void shouldParseCorrectMap() { + final String formattedConfigs = "foo=foo1,bar=bar1,baz=baz1"; + final Map parsedMap = SystemTestUtil.parseConfigs(formattedConfigs); + final TreeMap sortedParsedMap = new TreeMap<>(parsedMap); + assertEquals(sortedParsedMap, expectedParsedMap); + } + + @Test + public void shouldThrowExceptionOnNull() { + assertThrows(NullPointerException.class, () -> SystemTestUtil.parseConfigs(null)); + } + + @Test + public void shouldThrowExceptionIfNotCorrectKeyValueSeparator() { + final String badString = "foo:bar,baz:boo"; + assertThrows(IllegalStateException.class, () -> SystemTestUtil.parseConfigs(badString)); + } + + @Test + public void shouldThrowExceptionIfNotCorrectKeyValuePairSeparator() { + final String badString = "foo=bar;baz=boo"; + assertThrows(IllegalStateException.class, () -> SystemTestUtil.parseConfigs(badString)); + } + + @Test + public void shouldParseSingleKeyValuePairString() { + final Map expectedSinglePairMap = new HashMap<>(); + expectedSinglePairMap.put("foo", "bar"); + final String singleValueString = "foo=bar"; + final Map parsedMap = SystemTestUtil.parseConfigs(singleValueString); + assertEquals(expectedSinglePairMap, parsedMap); + } + + +} \ No newline at end of file diff --git a/streams/src/test/java/org/apache/kafka/streams/tools/StreamsResetterTest.java b/streams/src/test/java/org/apache/kafka/streams/tools/StreamsResetterTest.java new file mode 100644 index 0000000..d4f7841 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/tools/StreamsResetterTest.java @@ -0,0 +1,321 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.tools; + +import kafka.tools.StreamsResetter; +import org.apache.kafka.clients.admin.MockAdminClient; +import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.clients.consumer.ConsumerRecords; +import org.apache.kafka.clients.consumer.MockConsumer; +import org.apache.kafka.clients.consumer.OffsetAndTimestamp; +import org.apache.kafka.clients.consumer.OffsetResetStrategy; +import org.apache.kafka.common.Cluster; +import org.apache.kafka.common.Node; +import org.apache.kafka.common.PartitionInfo; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.TopicPartitionInfo; +import org.junit.Before; +import org.junit.Test; + +import java.time.Duration; +import java.time.Instant; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.ExecutionException; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +public class StreamsResetterTest { + + private static final String TOPIC = "topic1"; + private final StreamsResetter streamsResetter = new StreamsResetter(); + private final MockConsumer consumer = new MockConsumer<>(OffsetResetStrategy.EARLIEST); + private final TopicPartition topicPartition = new TopicPartition(TOPIC, 0); + private final Set inputTopicPartitions = new HashSet<>(Collections.singletonList(topicPartition)); + + @Before + public void setUp() { + consumer.assign(Collections.singletonList(topicPartition)); + + consumer.addRecord(new ConsumerRecord<>(TOPIC, 0, 0L, new byte[] {}, new byte[] {})); + consumer.addRecord(new ConsumerRecord<>(TOPIC, 0, 1L, new byte[] {}, new byte[] {})); + consumer.addRecord(new ConsumerRecord<>(TOPIC, 0, 2L, new byte[] {}, new byte[] {})); + consumer.addRecord(new ConsumerRecord<>(TOPIC, 0, 3L, new byte[] {}, new byte[] {})); + consumer.addRecord(new ConsumerRecord<>(TOPIC, 0, 4L, new byte[] {}, new byte[] {})); + } + + @Test + public void testResetToSpecificOffsetWhenBetweenBeginningAndEndOffset() { + final Map endOffsets = new HashMap<>(); + endOffsets.put(topicPartition, 4L); + consumer.updateEndOffsets(endOffsets); + + final Map beginningOffsets = new HashMap<>(); + beginningOffsets.put(topicPartition, 0L); + consumer.updateBeginningOffsets(beginningOffsets); + + streamsResetter.resetOffsetsTo(consumer, inputTopicPartitions, 2L); + + final ConsumerRecords records = consumer.poll(Duration.ofMillis(500)); + assertEquals(3, records.count()); + } + + @Test + public void testResetOffsetToSpecificOffsetWhenAfterEndOffset() { + final long beginningOffset = 5L; + final long endOffset = 10L; + final MockConsumer emptyConsumer = new MockConsumer<>(OffsetResetStrategy.EARLIEST); + emptyConsumer.assign(Collections.singletonList(topicPartition)); + + final Map beginningOffsetsMap = new HashMap<>(); + beginningOffsetsMap.put(topicPartition, beginningOffset); + emptyConsumer.updateBeginningOffsets(beginningOffsetsMap); + + final Map endOffsetsMap = new HashMap<>(); + endOffsetsMap.put(topicPartition, endOffset); + emptyConsumer.updateEndOffsets(endOffsetsMap); + // resetOffsetsTo only seeks the offset, but does not commit. 
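+ // Since the requested offset (endOffset + 2) is past the end offset, the resetter is expected
+ // to clamp the seek to the end offset, which the assertion on `position` below verifies.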
+ streamsResetter.resetOffsetsTo(emptyConsumer, inputTopicPartitions, endOffset + 2L); + + final long position = emptyConsumer.position(topicPartition); + + assertEquals(endOffset, position); + } + + @Test + public void testResetToSpecificOffsetWhenBeforeBeginningOffset() { + final Map endOffsets = new HashMap<>(); + endOffsets.put(topicPartition, 4L); + consumer.updateEndOffsets(endOffsets); + + final Map beginningOffsets = new HashMap<>(); + beginningOffsets.put(topicPartition, 3L); + consumer.updateBeginningOffsets(beginningOffsets); + + streamsResetter.resetOffsetsTo(consumer, inputTopicPartitions, 2L); + + final ConsumerRecords records = consumer.poll(Duration.ofMillis(500)); + assertEquals(2, records.count()); + } + + @Test + public void testResetToSpecificOffsetWhenAfterEndOffset() { + final Map endOffsets = new HashMap<>(); + endOffsets.put(topicPartition, 3L); + consumer.updateEndOffsets(endOffsets); + + final Map beginningOffsets = new HashMap<>(); + beginningOffsets.put(topicPartition, 0L); + consumer.updateBeginningOffsets(beginningOffsets); + + streamsResetter.resetOffsetsTo(consumer, inputTopicPartitions, 4L); + + final ConsumerRecords records = consumer.poll(Duration.ofMillis(500)); + assertEquals(2, records.count()); + } + + @Test + public void testShiftOffsetByWhenBetweenBeginningAndEndOffset() { + final Map endOffsets = new HashMap<>(); + endOffsets.put(topicPartition, 4L); + consumer.updateEndOffsets(endOffsets); + + final Map beginningOffsets = new HashMap<>(); + beginningOffsets.put(topicPartition, 0L); + consumer.updateBeginningOffsets(beginningOffsets); + + streamsResetter.shiftOffsetsBy(consumer, inputTopicPartitions, 3L); + + final ConsumerRecords records = consumer.poll(Duration.ofMillis(500)); + assertEquals(2, records.count()); + } + + @Test + public void testShiftOffsetByWhenBeforeBeginningOffset() { + final Map endOffsets = new HashMap<>(); + endOffsets.put(topicPartition, 4L); + consumer.updateEndOffsets(endOffsets); + + final Map beginningOffsets = new HashMap<>(); + beginningOffsets.put(topicPartition, 0L); + consumer.updateBeginningOffsets(beginningOffsets); + + streamsResetter.shiftOffsetsBy(consumer, inputTopicPartitions, -3L); + + final ConsumerRecords records = consumer.poll(Duration.ofMillis(500)); + assertEquals(5, records.count()); + } + + @Test + public void testShiftOffsetByWhenAfterEndOffset() { + final Map endOffsets = new HashMap<>(); + endOffsets.put(topicPartition, 3L); + consumer.updateEndOffsets(endOffsets); + + final Map beginningOffsets = new HashMap<>(); + beginningOffsets.put(topicPartition, 0L); + consumer.updateBeginningOffsets(beginningOffsets); + + streamsResetter.shiftOffsetsBy(consumer, inputTopicPartitions, 5L); + + final ConsumerRecords records = consumer.poll(Duration.ofMillis(500)); + assertEquals(2, records.count()); + } + + @Test + public void testResetUsingPlanWhenBetweenBeginningAndEndOffset() { + final Map endOffsets = new HashMap<>(); + endOffsets.put(topicPartition, 4L); + consumer.updateEndOffsets(endOffsets); + + final Map beginningOffsets = new HashMap<>(); + beginningOffsets.put(topicPartition, 0L); + consumer.updateBeginningOffsets(beginningOffsets); + + final Map topicPartitionsAndOffset = new HashMap<>(); + topicPartitionsAndOffset.put(topicPartition, 3L); + streamsResetter.resetOffsetsFromResetPlan(consumer, inputTopicPartitions, topicPartitionsAndOffset); + + final ConsumerRecords records = consumer.poll(Duration.ofMillis(500)); + assertEquals(2, records.count()); + } + + @Test + public void 
testResetUsingPlanWhenBeforeBeginningOffset() { + final Map endOffsets = new HashMap<>(); + endOffsets.put(topicPartition, 4L); + consumer.updateEndOffsets(endOffsets); + + final Map beginningOffsets = new HashMap<>(); + beginningOffsets.put(topicPartition, 3L); + consumer.updateBeginningOffsets(beginningOffsets); + + final Map topicPartitionsAndOffset = new HashMap<>(); + topicPartitionsAndOffset.put(topicPartition, 1L); + streamsResetter.resetOffsetsFromResetPlan(consumer, inputTopicPartitions, topicPartitionsAndOffset); + + final ConsumerRecords records = consumer.poll(Duration.ofMillis(500)); + assertEquals(2, records.count()); + } + + @Test + public void testResetUsingPlanWhenAfterEndOffset() { + final Map endOffsets = new HashMap<>(); + endOffsets.put(topicPartition, 3L); + consumer.updateEndOffsets(endOffsets); + + final Map beginningOffsets = new HashMap<>(); + beginningOffsets.put(topicPartition, 0L); + consumer.updateBeginningOffsets(beginningOffsets); + + final Map topicPartitionsAndOffset = new HashMap<>(); + topicPartitionsAndOffset.put(topicPartition, 5L); + streamsResetter.resetOffsetsFromResetPlan(consumer, inputTopicPartitions, topicPartitionsAndOffset); + + final ConsumerRecords records = consumer.poll(Duration.ofMillis(500)); + assertEquals(2, records.count()); + } + + @Test + public void shouldSeekToEndOffset() { + final Map endOffsets = new HashMap<>(); + endOffsets.put(topicPartition, 3L); + consumer.updateEndOffsets(endOffsets); + + final Map beginningOffsets = new HashMap<>(); + beginningOffsets.put(topicPartition, 0L); + consumer.updateBeginningOffsets(beginningOffsets); + + final Set intermediateTopicPartitions = new HashSet<>(); + intermediateTopicPartitions.add(topicPartition); + streamsResetter.maybeSeekToEnd("g1", consumer, intermediateTopicPartitions); + + final ConsumerRecords records = consumer.poll(Duration.ofMillis(500)); + assertEquals(2, records.count()); + } + + @Test + public void shouldDeleteTopic() throws InterruptedException, ExecutionException { + final Cluster cluster = createCluster(1); + try (final MockAdminClient adminClient = new MockAdminClient(cluster.nodes(), cluster.nodeById(0))) { + final TopicPartitionInfo topicPartitionInfo = new TopicPartitionInfo(0, cluster.nodeById(0), cluster.nodes(), Collections.emptyList()); + adminClient.addTopic(false, TOPIC, Collections.singletonList(topicPartitionInfo), null); + streamsResetter.doDelete(Collections.singletonList(TOPIC), adminClient); + assertEquals(Collections.emptySet(), adminClient.listTopics().names().get()); + } + } + + @Test + public void shouldDetermineInternalTopicBasedOnTopicName1() { + assertTrue(StreamsResetter.matchesInternalTopicFormat("appId-named-subscription-response-topic")); + assertTrue(StreamsResetter.matchesInternalTopicFormat("appId-named-subscription-registration-topic")); + assertTrue(StreamsResetter.matchesInternalTopicFormat("appId-KTABLE-FK-JOIN-SUBSCRIPTION-RESPONSE-12323232-topic")); + assertTrue(StreamsResetter.matchesInternalTopicFormat("appId-KTABLE-FK-JOIN-SUBSCRIPTION-REGISTRATION-12323232-topic")); + } + + @Test + public void testResetToDatetimeWhenPartitionIsEmptyResetsToLatestOffset() { + final long beginningAndEndOffset = 5L; // Empty partition implies beginning offset == end offset + final MockConsumer emptyConsumer = new EmptyPartitionConsumer<>(OffsetResetStrategy.EARLIEST); + emptyConsumer.assign(Collections.singletonList(topicPartition)); + + final Map beginningOffsetsMap = new HashMap<>(); + beginningOffsetsMap.put(topicPartition, 
beginningAndEndOffset); + emptyConsumer.updateBeginningOffsets(beginningOffsetsMap); + + final Map endOffsetsMap = new HashMap<>(); + endOffsetsMap.put(topicPartition, beginningAndEndOffset); + emptyConsumer.updateEndOffsets(endOffsetsMap); + + final long yesterdayTimestamp = Instant.now().minus(Duration.ofDays(1)).toEpochMilli(); + // resetToDatetime only seeks the offset, but does not commit. + streamsResetter.resetToDatetime(emptyConsumer, inputTopicPartitions, yesterdayTimestamp); + + final long position = emptyConsumer.position(topicPartition); + + assertEquals(beginningAndEndOffset, position); + } + + private Cluster createCluster(final int numNodes) { + final HashMap nodes = new HashMap<>(); + for (int i = 0; i < numNodes; ++i) { + nodes.put(i, new Node(i, "localhost", 8121 + i)); + } + return new Cluster("mockClusterId", nodes.values(), + Collections.emptySet(), Collections.emptySet(), + Collections.emptySet(), nodes.get(0)); + } + + private static class EmptyPartitionConsumer extends MockConsumer { + + public EmptyPartitionConsumer(final OffsetResetStrategy offsetResetStrategy) { + super(offsetResetStrategy); + } + + @Override + public synchronized Map offsetsForTimes(final Map timestampsToSearch) { + final Map topicPartitionToOffsetAndTimestamp = new HashMap<>(); + timestampsToSearch.keySet().forEach(k -> topicPartitionToOffsetAndTimestamp.put(k, null)); + return topicPartitionToOffsetAndTimestamp; + } + } + +} \ No newline at end of file diff --git a/streams/src/test/java/org/apache/kafka/streams/utils/UniqueTopicSerdeScope.java b/streams/src/test/java/org/apache/kafka/streams/utils/UniqueTopicSerdeScope.java new file mode 100644 index 0000000..c385187 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/streams/utils/UniqueTopicSerdeScope.java @@ -0,0 +1,154 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.utils; + +import org.apache.kafka.common.header.Headers; +import org.apache.kafka.common.serialization.Deserializer; +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.common.serialization.Serializer; + +import java.util.Collections; +import java.util.Map; +import java.util.Properties; +import java.util.Set; +import java.util.TreeMap; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.stream.Collectors; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.equalTo; + +public class UniqueTopicSerdeScope { + private final Map> topicTypeRegistry = new TreeMap<>(); + + public UniqueTopicSerdeDecorator decorateSerde(final Serde delegate, + final Properties config, + final boolean isKey) { + final UniqueTopicSerdeDecorator decorator = new UniqueTopicSerdeDecorator<>(delegate); + decorator.configure(config.entrySet().stream().collect(Collectors.toMap(e -> e.getKey().toString(), Map.Entry::getValue)), isKey); + return decorator; + } + + public Set registeredTopics() { + return Collections.unmodifiableSet(topicTypeRegistry.keySet()); + } + + public class UniqueTopicSerdeDecorator implements Serde { + private final AtomicBoolean isKey = new AtomicBoolean(false); + private final Serde delegate; + + public UniqueTopicSerdeDecorator(final Serde delegate) { + this.delegate = delegate; + } + + @Override + public void configure(final Map configs, final boolean isKey) { + delegate.configure(configs, isKey); + this.isKey.set(isKey); + } + + @Override + public void close() { + delegate.close(); + } + + @Override + public Serializer serializer() { + return new UniqueTopicSerializerDecorator<>(isKey, delegate.serializer()); + } + + @Override + public Deserializer deserializer() { + return new UniqueTopicDeserializerDecorator<>(isKey, delegate.deserializer()); + } + } + + public class UniqueTopicSerializerDecorator implements Serializer { + private final AtomicBoolean isKey; + private final Serializer delegate; + + public UniqueTopicSerializerDecorator(final AtomicBoolean isKey, final Serializer delegate) { + this.isKey = isKey; + this.delegate = delegate; + } + + @Override + public void configure(final Map configs, final boolean isKey) { + delegate.configure(configs, isKey); + this.isKey.set(isKey); + } + + @Override + public byte[] serialize(final String topic, final T data) { + verifyTopic(topic, data); + return delegate.serialize(topic, data); + } + + @Override + public byte[] serialize(final String topic, final Headers headers, final T data) { + verifyTopic(topic, data); + return delegate.serialize(topic, headers, data); + } + + private void verifyTopic(final String topic, final T data) { + if (data != null) { + final String key = topic + (isKey.get() ? 
"--key" : "--value"); + if (topicTypeRegistry.containsKey(key)) { + assertThat(String.format("key[%s] data[%s][%s]", key, data, data.getClass()), topicTypeRegistry.get(key), equalTo(data.getClass())); + } else { + topicTypeRegistry.put(key, data.getClass()); + } + } + } + + @Override + public void close() { + delegate.close(); + } + } + + public class UniqueTopicDeserializerDecorator implements Deserializer { + private final AtomicBoolean isKey; + private final Deserializer delegate; + + public UniqueTopicDeserializerDecorator(final AtomicBoolean isKey, final Deserializer delegate) { + this.isKey = isKey; + this.delegate = delegate; + } + + @Override + public void configure(final Map configs, final boolean isKey) { + delegate.configure(configs, isKey); + this.isKey.set(isKey); + } + + @Override + public T deserialize(final String topic, final byte[] data) { + return delegate.deserialize(topic, data); + } + + @Override + public T deserialize(final String topic, final Headers headers, final byte[] data) { + return delegate.deserialize(topic, headers, data); + } + + @Override + public void close() { + delegate.close(); + } + } +} diff --git a/streams/src/test/java/org/apache/kafka/test/GenericInMemoryKeyValueStore.java b/streams/src/test/java/org/apache/kafka/test/GenericInMemoryKeyValueStore.java new file mode 100644 index 0000000..7c3af25 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/test/GenericInMemoryKeyValueStore.java @@ -0,0 +1,185 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.test; + +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.processor.ProcessorContext; +import org.apache.kafka.streams.processor.StateStore; +import org.apache.kafka.streams.state.KeyValueIterator; +import org.apache.kafka.streams.state.KeyValueStore; +import org.apache.kafka.streams.state.internals.CacheFlushListener; +import org.apache.kafka.streams.state.internals.DelegatingPeekingKeyValueIterator; +import org.apache.kafka.streams.state.internals.WrappedStateStore; + +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; +import java.util.NavigableMap; +import java.util.TreeMap; + +/** + * This class is a generic version of the in-memory key-value store that is useful for testing when you + * need a basic KeyValueStore for arbitrary types and don't have/want to write a serde + */ +public class GenericInMemoryKeyValueStore + extends WrappedStateStore + implements KeyValueStore { + + private final String name; + private final NavigableMap map; + private volatile boolean open = false; + + public GenericInMemoryKeyValueStore(final String name) { + // it's not really a `WrappedStateStore` so we pass `null` + // however, we need to implement `WrappedStateStore` to make the store usable + super(null); + this.name = name; + + this.map = new TreeMap<>(); + } + + @Override + public String name() { + return this.name; + } + + @SuppressWarnings("deprecation") + @Deprecated + @Override + /* This is a "dummy" store used for testing; + it does not support restoring from changelog since we allow it to be serde-ignorant */ + public void init(final ProcessorContext context, final StateStore root) { + if (root != null) { + context.register(root, null); + } + + this.open = true; + } + + @Override + public boolean setFlushListener(final CacheFlushListener listener, + final boolean sendOldValues) { + return false; + } + + @Override + public boolean persistent() { + return false; + } + + @Override + public boolean isOpen() { + return this.open; + } + + @Override + public synchronized V get(final K key) { + return this.map.get(key); + } + + @Override + public synchronized void put(final K key, + final V value) { + if (value == null) { + this.map.remove(key); + } else { + this.map.put(key, value); + } + } + + @Override + public synchronized V putIfAbsent(final K key, + final V value) { + final V originalValue = get(key); + if (originalValue == null) { + put(key, value); + } + return originalValue; + } + + @Override + public synchronized void putAll(final List> entries) { + for (final KeyValue entry : entries) { + put(entry.key, entry.value); + } + } + + @Override + public synchronized V delete(final K key) { + return this.map.remove(key); + } + + @Override + public synchronized KeyValueIterator range(final K from, + final K to) { + return new DelegatingPeekingKeyValueIterator<>( + name, + new GenericInMemoryKeyValueIterator<>(this.map.subMap(from, true, to, true).entrySet().iterator())); + } + + @Override + public synchronized KeyValueIterator all() { + final TreeMap copy = new TreeMap<>(this.map); + return new DelegatingPeekingKeyValueIterator<>(name, new GenericInMemoryKeyValueIterator<>(copy.entrySet().iterator())); + } + + @Override + public long approximateNumEntries() { + return this.map.size(); + } + + @Override + public void flush() { + // do-nothing since it is in-memory + } + + @Override + public void close() { + this.map.clear(); + this.open = false; + } + + private static class 
GenericInMemoryKeyValueIterator implements KeyValueIterator { + private final Iterator> iter; + + private GenericInMemoryKeyValueIterator(final Iterator> iter) { + this.iter = iter; + } + + @Override + public boolean hasNext() { + return iter.hasNext(); + } + + @Override + public KeyValue next() { + final Map.Entry entry = iter.next(); + return new KeyValue<>(entry.getKey(), entry.getValue()); + } + + @Override + public void close() { + // do nothing + } + + @Override + public K peekNextKey() { + throw new UnsupportedOperationException("peekNextKey() not supported in " + getClass().getName()); + } + } +} \ No newline at end of file diff --git a/streams/src/test/java/org/apache/kafka/test/GenericInMemoryTimestampedKeyValueStore.java b/streams/src/test/java/org/apache/kafka/test/GenericInMemoryTimestampedKeyValueStore.java new file mode 100644 index 0000000..114ea06 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/test/GenericInMemoryTimestampedKeyValueStore.java @@ -0,0 +1,186 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.test; + +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.processor.ProcessorContext; +import org.apache.kafka.streams.processor.StateStore; +import org.apache.kafka.streams.state.KeyValueIterator; +import org.apache.kafka.streams.state.TimestampedKeyValueStore; +import org.apache.kafka.streams.state.ValueAndTimestamp; +import org.apache.kafka.streams.state.internals.CacheFlushListener; +import org.apache.kafka.streams.state.internals.DelegatingPeekingKeyValueIterator; +import org.apache.kafka.streams.state.internals.WrappedStateStore; + +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; +import java.util.NavigableMap; +import java.util.TreeMap; + +/** + * This class is a generic version of the in-memory key-value store that is useful for testing when you + * need a basic KeyValueStore for arbitrary types and don't have/want to write a serde + */ +@SuppressWarnings("deprecation") +public class GenericInMemoryTimestampedKeyValueStore + extends WrappedStateStore> + implements TimestampedKeyValueStore { + + private final String name; + private final NavigableMap> map; + private volatile boolean open = false; + + public GenericInMemoryTimestampedKeyValueStore(final String name) { + // it's not really a `WrappedStateStore` so we pass `null` + // however, we need to implement `WrappedStateStore` to make the store usable + super(null); + this.name = name; + + this.map = new TreeMap<>(); + } + + @Override + public String name() { + return this.name; + } + + @Deprecated + @Override + /* This is a "dummy" store used for testing; + it does not support restoring from changelog since we allow it to be serde-ignorant */ + public void init(final ProcessorContext context, final StateStore root) { + if (root != null) { + context.register(root, null); + } + + this.open = true; + } + + @Override + public boolean setFlushListener(final CacheFlushListener> listener, + final boolean sendOldValues) { + return false; + } + + @Override + public boolean persistent() { + return false; + } + + @Override + public boolean isOpen() { + return this.open; + } + + @Override + public synchronized ValueAndTimestamp get(final K key) { + return this.map.get(key); + } + + @Override + public synchronized void put(final K key, + final ValueAndTimestamp value) { + if (value == null) { + this.map.remove(key); + } else { + this.map.put(key, value); + } + } + + @Override + public synchronized ValueAndTimestamp putIfAbsent(final K key, + final ValueAndTimestamp value) { + final ValueAndTimestamp originalValue = get(key); + if (originalValue == null) { + put(key, value); + } + return originalValue; + } + + @Override + public synchronized void putAll(final List>> entries) { + for (final KeyValue> entry : entries) { + put(entry.key, entry.value); + } + } + + @Override + public synchronized ValueAndTimestamp delete(final K key) { + return this.map.remove(key); + } + + @Override + public synchronized KeyValueIterator> range(final K from, + final K to) { + return new DelegatingPeekingKeyValueIterator<>( + name, + new GenericInMemoryKeyValueIterator<>(this.map.subMap(from, true, to, true).entrySet().iterator())); + } + + @Override + public synchronized KeyValueIterator> all() { + final TreeMap> copy = new TreeMap<>(this.map); + return new DelegatingPeekingKeyValueIterator<>(name, new GenericInMemoryKeyValueIterator<>(copy.entrySet().iterator())); + } + + @Override + public long approximateNumEntries() { + return this.map.size(); + } 
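+
+ // Illustrative usage sketch (comment only, not part of the original file):
+ // final GenericInMemoryTimestampedKeyValueStore<String, Long> store = new GenericInMemoryTimestampedKeyValueStore<>("store");
+ // store.put("a", ValueAndTimestamp.make(1L, 42L));
+ // store.get("a"); // ValueAndTimestamp.make(1L, 42L)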
+ + @Override + public void flush() { + // do-nothing since it is in-memory + } + + @Override + public void close() { + this.map.clear(); + this.open = false; + } + + private static class GenericInMemoryKeyValueIterator implements KeyValueIterator> { + private final Iterator>> iter; + + private GenericInMemoryKeyValueIterator(final Iterator>> iter) { + this.iter = iter; + } + + @Override + public boolean hasNext() { + return iter.hasNext(); + } + + @Override + public KeyValue> next() { + final Map.Entry> entry = iter.next(); + return new KeyValue<>(entry.getKey(), entry.getValue()); + } + + @Override + public void close() { + // do nothing + } + + @Override + public K peekNextKey() { + throw new UnsupportedOperationException("peekNextKey() not supported in " + getClass().getName()); + } + } +} \ No newline at end of file diff --git a/streams/src/test/java/org/apache/kafka/test/GlobalStateManagerStub.java b/streams/src/test/java/org/apache/kafka/test/GlobalStateManagerStub.java new file mode 100644 index 0000000..9ea836b --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/test/GlobalStateManagerStub.java @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.test; + +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.streams.processor.StateRestoreCallback; +import org.apache.kafka.streams.processor.StateStore; +import org.apache.kafka.streams.processor.internals.GlobalStateManager; +import org.apache.kafka.streams.processor.internals.InternalProcessorContext; +import org.apache.kafka.streams.processor.internals.Task.TaskType; + +import java.io.File; +import java.util.Map; +import java.util.Set; + +public class GlobalStateManagerStub implements GlobalStateManager { + + private final Set storeNames; + private final Map offsets; + private final File baseDirectory; + public boolean initialized; + public boolean closed; + + public GlobalStateManagerStub(final Set storeNames, + final Map offsets, + final File baseDirectory) { + this.storeNames = storeNames; + this.offsets = offsets; + this.baseDirectory = baseDirectory; + } + + @Override + public void setGlobalProcessorContext(final InternalProcessorContext processorContext) {} + + @Override + public Set initialize() { + initialized = true; + return storeNames; + } + + @Override + public File baseDir() { + return baseDirectory; + } + + @Override + public void registerStore(final StateStore store, final StateRestoreCallback stateRestoreCallback) {} + + @Override + public void flush() {} + + @Override + public void close() { + closed = true; + } + + @Override + public void updateChangelogOffsets(final Map writtenOffsets) { + this.offsets.putAll(writtenOffsets); + } + + @Override + public void checkpoint() {} + + @Override + public StateStore getStore(final String name) { + return null; + } + + @Override + public StateStore getGlobalStore(final String name) { + return null; + } + + @Override + public Map changelogOffsets() { + return offsets; + } + + @Override + public TaskType taskType() { + return TaskType.GLOBAL; + } + + @Override + public String changelogFor(final String storeName) { + return null; + } +} diff --git a/streams/src/test/java/org/apache/kafka/test/InternalMockProcessorContext.java b/streams/src/test/java/org/apache/kafka/test/InternalMockProcessorContext.java new file mode 100644 index 0000000..c3cf64a --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/test/InternalMockProcessorContext.java @@ -0,0 +1,473 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.test; + +import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.common.header.Headers; +import org.apache.kafka.common.header.internals.RecordHeaders; +import org.apache.kafka.common.metrics.Metrics; +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.processor.Cancellable; +import org.apache.kafka.streams.processor.PunctuationType; +import org.apache.kafka.streams.processor.Punctuator; +import org.apache.kafka.streams.processor.StateRestoreCallback; +import org.apache.kafka.streams.processor.StateStore; +import org.apache.kafka.streams.processor.TaskId; +import org.apache.kafka.streams.processor.To; +import org.apache.kafka.streams.processor.api.Record; +import org.apache.kafka.streams.processor.internals.AbstractProcessorContext; +import org.apache.kafka.streams.processor.internals.ProcessorNode; +import org.apache.kafka.streams.processor.internals.ProcessorRecordContext; +import org.apache.kafka.streams.processor.internals.RecordBatchingStateRestoreCallback; +import org.apache.kafka.streams.processor.internals.RecordCollector; +import org.apache.kafka.streams.processor.internals.StateManager; +import org.apache.kafka.streams.processor.internals.StateManagerStub; +import org.apache.kafka.streams.processor.internals.StreamTask; +import org.apache.kafka.streams.processor.internals.Task.TaskType; +import org.apache.kafka.streams.processor.internals.ToInternal; +import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl; +import org.apache.kafka.streams.state.StateSerdes; +import org.apache.kafka.streams.state.internals.ThreadCache; +import org.apache.kafka.streams.state.internals.ThreadCache.DirtyEntryFlushListener; + +import java.io.File; +import java.time.Duration; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; + +import static org.apache.kafka.streams.processor.internals.StateRestoreCallbackAdapter.adapt; + +public class InternalMockProcessorContext + extends AbstractProcessorContext + implements RecordCollector.Supplier { + + private StateManager stateManager = new StateManagerStub(); + private final File stateDir; + private final RecordCollector.Supplier recordCollectorSupplier; + private final Map storeMap = new LinkedHashMap<>(); + private final Map restoreFuncs = new HashMap<>(); + private final ToInternal toInternal = new ToInternal(); + + private TaskType taskType = TaskType.ACTIVE; + private Serde keySerde; + private Serde valueSerde; + private long timestamp = -1L; + private final Time time; + private final Map storeToChangelogTopic = new HashMap<>(); + + public InternalMockProcessorContext() { + this(null, + null, + null, + new StreamsMetricsImpl(new Metrics(), "mock", StreamsConfig.METRICS_LATEST, new MockTime()), + new StreamsConfig(StreamsTestUtils.getStreamsConfig()), + null, + null, + Time.SYSTEM + ); + } + + public InternalMockProcessorContext(final File stateDir, + final StreamsConfig config) { + this( + stateDir, + null, + null, + new StreamsMetricsImpl( + new Metrics(), + "mock", + config.getString(StreamsConfig.BUILT_IN_METRICS_VERSION_CONFIG), + new MockTime() + ), + config, + null, + null, + Time.SYSTEM + ); + } + + public 
InternalMockProcessorContext(final StreamsMetricsImpl streamsMetrics) { + this( + null, + null, + null, + streamsMetrics, + new StreamsConfig(StreamsTestUtils.getStreamsConfig()), + null, + null, + Time.SYSTEM + ); + } + + public InternalMockProcessorContext(final File stateDir, + final StreamsConfig config, + final RecordCollector collector) { + this( + stateDir, + null, + null, + new StreamsMetricsImpl( + new Metrics(), + "mock", + config.getString(StreamsConfig.BUILT_IN_METRICS_VERSION_CONFIG), + new MockTime() + ), + config, + () -> collector, + null, + Time.SYSTEM + ); + } + + public InternalMockProcessorContext(final File stateDir, + final Serde keySerde, + final Serde valueSerde, + final StreamsConfig config) { + this( + stateDir, + keySerde, + valueSerde, + new StreamsMetricsImpl(new Metrics(), "mock", StreamsConfig.METRICS_LATEST, new MockTime()), + config, + null, + null, + Time.SYSTEM + ); + } + + public InternalMockProcessorContext(final StateSerdes serdes, + final RecordCollector collector) { + this(null, serdes.keySerde(), serdes.valueSerde(), collector, null); + } + + public InternalMockProcessorContext(final StateSerdes serdes, + final RecordCollector collector, + final Metrics metrics) { + this( + null, + serdes.keySerde(), + serdes.valueSerde(), + new StreamsMetricsImpl(metrics, "mock", StreamsConfig.METRICS_LATEST, new MockTime()), + new StreamsConfig(StreamsTestUtils.getStreamsConfig()), + () -> collector, + null, + Time.SYSTEM + ); + } + + public InternalMockProcessorContext(final File stateDir, + final Serde keySerde, + final Serde valueSerde, + final RecordCollector collector, + final ThreadCache cache) { + this( + stateDir, + keySerde, + valueSerde, + new StreamsMetricsImpl(new Metrics(), "mock", StreamsConfig.METRICS_LATEST, new MockTime()), + new StreamsConfig(StreamsTestUtils.getStreamsConfig()), + () -> collector, + cache, + Time.SYSTEM + ); + } + + public InternalMockProcessorContext(final File stateDir, + final Serde keySerde, + final Serde valueSerde, + final StreamsMetricsImpl metrics, + final StreamsConfig config, + final RecordCollector.Supplier collectorSupplier, + final ThreadCache cache, + final Time time) { + this(stateDir, keySerde, valueSerde, metrics, config, collectorSupplier, cache, time, new TaskId(0, 0)); + } + + public InternalMockProcessorContext(final File stateDir, + final Serde keySerde, + final Serde valueSerde, + final StreamsMetricsImpl metrics, + final StreamsConfig config, + final RecordCollector.Supplier collectorSupplier, + final ThreadCache cache, + final Time time, + final TaskId taskId) { + super( + taskId, + config, + metrics, + cache + ); + super.setCurrentNode(new ProcessorNode<>("TESTING_NODE")); + this.stateDir = stateDir; + this.keySerde = keySerde; + this.valueSerde = valueSerde; + this.recordCollectorSupplier = collectorSupplier; + this.time = time; + } + + @Override + protected StateManager stateManager() { + return stateManager; + } + + public void setStateManger(final StateManager stateManger) { + this.stateManager = stateManger; + } + + @Override + public RecordCollector recordCollector() { + final RecordCollector recordCollector = recordCollectorSupplier.recordCollector(); + + if (recordCollector == null) { + throw new UnsupportedOperationException("No RecordCollector specified"); + } + return recordCollector; + } + + public void setKeySerde(final Serde keySerde) { + this.keySerde = keySerde; + } + + public void setValueSerde(final Serde valueSerde) { + this.valueSerde = valueSerde; + } + + @Override + public Serde 
keySerde() { + return keySerde; + } + + @Override + public Serde valueSerde() { + return valueSerde; + } + + // state mgr will be overridden by the state dir and store maps + @Override + public void initialize() {} + + @Override + public File stateDir() { + if (stateDir == null) { + throw new UnsupportedOperationException("State directory not specified"); + } + return stateDir; + } + + @Override + public void register(final StateStore store, + final StateRestoreCallback func) { + storeMap.put(store.name(), store); + restoreFuncs.put(store.name(), func); + stateManager().registerStore(store, func); + } + + @SuppressWarnings("unchecked") + @Override + public S getStateStore(final String name) { + return (S) storeMap.get(name); + } + + @Override + public Cancellable schedule(final Duration interval, + final PunctuationType type, + final Punctuator callback) throws IllegalArgumentException { + throw new UnsupportedOperationException("schedule() not supported."); + } + + @Override + public void commit() {} + + @Override + public void forward(final Record record) { + forward(record, null); + } + + @SuppressWarnings("unchecked") + @Override + public void forward(final Record record, final String childName) { + if (recordContext != null && record.timestamp() != recordContext.timestamp()) { + setTime(record.timestamp()); + } + final ProcessorNode thisNode = currentNode; + try { + for (final ProcessorNode childNode : thisNode.children()) { + currentNode = childNode; + ((ProcessorNode) childNode).process(record); + } + } finally { + currentNode = thisNode; + } + } + + @Override + public void forward(final Object key, final Object value) { + forward(key, value, To.all()); + } + + @SuppressWarnings("unchecked") + @Override + public void forward(final Object key, final Object value, final To to) { + toInternal.update(to); + if (toInternal.hasTimestamp()) { + setTime(toInternal.timestamp()); + } + final ProcessorNode thisNode = currentNode; + try { + for (final ProcessorNode childNode : thisNode.children()) { + if (toInternal.child() == null || toInternal.child().equals(childNode.name())) { + currentNode = childNode; + final Record record = new Record<>(key, value, toInternal.timestamp(), headers()); + ((ProcessorNode) childNode).process(record); + toInternal.update(to); // need to reset because MockProcessorContext is shared over multiple + // Processors and toInternal might have been modified + } + } + } finally { + currentNode = thisNode; + } + } + + // allow only setting time but not other fields in for record context, + // and also not throwing exceptions if record context is not available. 
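+ // For example (illustrative, not part of the original file): with no record context available,
+ // setTime(42L) only updates the locally tracked timestamp, so timestamp() falls back to 42L
+ // instead of throwing.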
+ public void setTime(final long timestamp) { + if (recordContext != null) { + recordContext = new ProcessorRecordContext( + timestamp, + recordContext.offset(), + recordContext.partition(), + recordContext.topic(), + recordContext.headers() + ); + } + this.timestamp = timestamp; + } + + @Override + public long timestamp() { + if (recordContext == null) { + return timestamp; + } + return recordContext.timestamp(); + } + + @Override + public long currentSystemTimeMs() { + return time.milliseconds(); + } + + @Override + public long currentStreamTimeMs() { + throw new UnsupportedOperationException("this method is not supported in InternalMockProcessorContext"); + } + + @Override + public String topic() { + if (recordContext == null) { + return null; + } + return recordContext.topic(); + } + + @Override + public int partition() { + if (recordContext == null) { + return -1; + } + return recordContext.partition(); + } + + @Override + public long offset() { + if (recordContext == null) { + return -1L; + } + return recordContext.offset(); + } + + @Override + public Headers headers() { + if (recordContext == null) { + return new RecordHeaders(); + } + return recordContext.headers(); + } + + @Override + public TaskType taskType() { + return taskType; + } + + @Override + public void logChange(final String storeName, + final Bytes key, + final byte[] value, + final long timestamp) { + recordCollector().send( + storeName + "-changelog", + key, + value, + null, + taskId().partition(), + timestamp, + BYTES_KEY_SERIALIZER, + BYTEARRAY_VALUE_SERIALIZER); + } + + @Override + public void transitionToActive(final StreamTask streamTask, final RecordCollector recordCollector, final ThreadCache newCache) { + taskType = TaskType.ACTIVE; + } + + @Override + public void transitionToStandby(final ThreadCache newCache) { + taskType = TaskType.STANDBY; + } + + @Override + public void registerCacheFlushListener(final String namespace, final DirtyEntryFlushListener listener) { + cache().addDirtyEntryFlushListener(namespace, listener); + } + + public void restore(final String storeName, final Iterable> changeLog) { + final RecordBatchingStateRestoreCallback restoreCallback = adapt(restoreFuncs.get(storeName)); + + final List> records = new ArrayList<>(); + for (final KeyValue keyValue : changeLog) { + records.add(new ConsumerRecord<>("", 0, 0L, keyValue.key, keyValue.value)); + } + restoreCallback.restoreBatch(records); + } + + public void addChangelogForStore(final String storeName, final String changelogTopic) { + storeToChangelogTopic.put(storeName, changelogTopic); + } + + @Override + public String changelogFor(final String storeName) { + return storeToChangelogTopic.get(storeName); + } +} diff --git a/streams/src/test/java/org/apache/kafka/test/KeyValueIteratorStub.java b/streams/src/test/java/org/apache/kafka/test/KeyValueIteratorStub.java new file mode 100644 index 0000000..aa3a4e9 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/test/KeyValueIteratorStub.java @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.test; + +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.state.KeyValueIterator; + +import java.util.Iterator; + +public class KeyValueIteratorStub implements KeyValueIterator { + + private final Iterator> iterator; + + public KeyValueIteratorStub(final Iterator> iterator) { + this.iterator = iterator; + } + + @Override + public void close() { + //no-op + } + + @Override + public K peekNextKey() { + return null; + } + + @Override + public boolean hasNext() { + return iterator.hasNext(); + } + + @Override + public KeyValue next() { + return iterator.next(); + } + +} diff --git a/streams/src/test/java/org/apache/kafka/test/MockAggregator.java b/streams/src/test/java/org/apache/kafka/test/MockAggregator.java new file mode 100644 index 0000000..42416ec --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/test/MockAggregator.java @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.test; + +import org.apache.kafka.streams.kstream.Aggregator; + +public class MockAggregator { + + public final static Aggregator TOSTRING_ADDER = toStringInstance("+"); + public final static Aggregator TOSTRING_REMOVER = toStringInstance("-"); + + public static Aggregator toStringInstance(final String sep) { + return (aggKey, value, aggregate) -> aggregate + sep + value; + } +} diff --git a/streams/src/test/java/org/apache/kafka/test/MockApiProcessor.java b/streams/src/test/java/org/apache/kafka/test/MockApiProcessor.java new file mode 100644 index 0000000..dd56bad --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/test/MockApiProcessor.java @@ -0,0 +1,146 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.test; + +import org.apache.kafka.streams.KeyValueTimestamp; +import org.apache.kafka.streams.processor.Cancellable; +import org.apache.kafka.streams.processor.PunctuationType; +import org.apache.kafka.streams.processor.api.Processor; +import org.apache.kafka.streams.processor.api.ProcessorContext; +import org.apache.kafka.streams.processor.api.Record; +import org.apache.kafka.streams.state.ValueAndTimestamp; + +import java.time.Duration; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.is; + +public class MockApiProcessor implements Processor { + + private final ArrayList> processed = new ArrayList<>(); + private final Map> lastValueAndTimestampPerKey = new HashMap<>(); + + private final ArrayList punctuatedStreamTime = new ArrayList<>(); + private final ArrayList punctuatedSystemTime = new ArrayList<>(); + + private Cancellable scheduleCancellable; + + private final PunctuationType punctuationType; + private final long scheduleInterval; + + private boolean commitRequested = false; + private ProcessorContext context; + + public MockApiProcessor(final PunctuationType punctuationType, + final long scheduleInterval) { + this.punctuationType = punctuationType; + this.scheduleInterval = scheduleInterval; + } + + public MockApiProcessor() { + this(PunctuationType.STREAM_TIME, -1); + } + + @Override + public void init(final ProcessorContext context) { + this.context = context; + if (scheduleInterval > 0L) { + scheduleCancellable = context.schedule( + Duration.ofMillis(scheduleInterval), + punctuationType, + (punctuationType == PunctuationType.STREAM_TIME ? punctuatedStreamTime : punctuatedSystemTime)::add + ); + } + } + + @Override + public void process(final Record record) { + final KIn key = record.key(); + final VIn value = record.value(); + final KeyValueTimestamp keyValueTimestamp = new KeyValueTimestamp<>(key, value, record.timestamp()); + + if (value != null) { + lastValueAndTimestampPerKey.put(key, ValueAndTimestamp.make(value, record.timestamp())); + } else { + lastValueAndTimestampPerKey.remove(key); + } + + processed.add(keyValueTimestamp); + + if (commitRequested) { + context.commit(); + commitRequested = false; + } + } + + public void checkAndClearProcessResult(final KeyValueTimestamp... expected) { + assertThat("the number of outputs:" + processed, processed.size(), is(expected.length)); + for (int i = 0; i < expected.length; i++) { + assertThat("output[" + i + "]:", processed.get(i), is(expected[i])); + } + + processed.clear(); + } + + public void requestCommit() { + commitRequested = true; + } + + public void checkEmptyAndClearProcessResult() { + assertThat("the number of outputs:", processed.size(), is(0)); + processed.clear(); + } + + public void checkAndClearPunctuateResult(final PunctuationType type, final long... expected) { + final ArrayList punctuated = type == PunctuationType.STREAM_TIME ? 
punctuatedStreamTime : punctuatedSystemTime; + assertThat("the number of outputs:", punctuated.size(), is(expected.length)); + + for (int i = 0; i < expected.length; i++) { + assertThat("output[" + i + "]:", punctuated.get(i), is(expected[i])); + } + + processed.clear(); + } + + public ArrayList> processed() { + return processed; + } + + public Map> lastValueAndTimestampPerKey() { + return lastValueAndTimestampPerKey; + } + + public List punctuatedStreamTime() { + return punctuatedStreamTime; + } + + public Cancellable scheduleCancellable() { + return scheduleCancellable; + } + + public ProcessorContext context() { + return context; + } + + public void context(final ProcessorContext context) { + this.context = context; + } +} diff --git a/streams/src/test/java/org/apache/kafka/test/MockApiProcessorSupplier.java b/streams/src/test/java/org/apache/kafka/test/MockApiProcessorSupplier.java new file mode 100644 index 0000000..af90ddc --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/test/MockApiProcessorSupplier.java @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.test; + +import org.apache.kafka.streams.processor.api.Processor; +import org.apache.kafka.streams.processor.api.ProcessorSupplier; +import org.apache.kafka.streams.processor.PunctuationType; + +import java.util.ArrayList; +import java.util.List; + +import static org.junit.Assert.assertEquals; + +public class MockApiProcessorSupplier implements ProcessorSupplier { + + private final long scheduleInterval; + private final PunctuationType punctuationType; + private final List> processors = new ArrayList<>(); + + public MockApiProcessorSupplier() { + this(-1L); + } + + public MockApiProcessorSupplier(final long scheduleInterval) { + this(scheduleInterval, PunctuationType.STREAM_TIME); + } + + public MockApiProcessorSupplier(final long scheduleInterval, final PunctuationType punctuationType) { + this.scheduleInterval = scheduleInterval; + this.punctuationType = punctuationType; + } + + @Override + public Processor get() { + final MockApiProcessor processor = new MockApiProcessor<>(punctuationType, scheduleInterval); + + // to keep tests simple, ignore calls from ApiUtils.checkSupplier + if (!StreamsTestUtils.isCheckSupplierCall()) { + processors.add(processor); + } + + return processor; + } + + // get the captured processor assuming that only one processor gets returned from this supplier + public MockApiProcessor theCapturedProcessor() { + return capturedProcessors(1).get(0); + } + + public int capturedProcessorsCount() { + return processors.size(); + } + + // get the captured processors with the expected number + public List> capturedProcessors(final int expectedNumberOfProcessors) { + assertEquals(expectedNumberOfProcessors, processors.size()); + + return processors; + } +} diff --git a/streams/src/test/java/org/apache/kafka/test/MockClientSupplier.java b/streams/src/test/java/org/apache/kafka/test/MockClientSupplier.java new file mode 100644 index 0000000..880f2cb --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/test/MockClientSupplier.java @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.test; + +import org.apache.kafka.clients.admin.Admin; +import org.apache.kafka.clients.admin.MockAdminClient; +import org.apache.kafka.clients.consumer.Consumer; +import org.apache.kafka.clients.consumer.MockConsumer; +import org.apache.kafka.clients.consumer.OffsetResetStrategy; +import org.apache.kafka.clients.producer.MockProducer; +import org.apache.kafka.clients.producer.Producer; +import org.apache.kafka.clients.producer.ProducerConfig; +import org.apache.kafka.clients.producer.internals.DefaultPartitioner; +import org.apache.kafka.common.Cluster; +import org.apache.kafka.common.serialization.ByteArraySerializer; +import org.apache.kafka.streams.KafkaClientSupplier; + +import java.util.LinkedList; +import java.util.List; +import java.util.Map; + +import static org.hamcrest.CoreMatchers.startsWith; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.Assert.assertFalse; + +public class MockClientSupplier implements KafkaClientSupplier { + private static final ByteArraySerializer BYTE_ARRAY_SERIALIZER = new ByteArraySerializer(); + + private Cluster cluster; + private String applicationId; + + public MockAdminClient adminClient = new MockAdminClient(); + public final List> producers = new LinkedList<>(); + public final MockConsumer consumer = new MockConsumer<>(OffsetResetStrategy.EARLIEST); + public final MockConsumer restoreConsumer = new MockConsumer<>(OffsetResetStrategy.LATEST); + + public void setApplicationIdForProducer(final String applicationId) { + this.applicationId = applicationId; + } + + public void setCluster(final Cluster cluster) { + this.cluster = cluster; + this.adminClient = new MockAdminClient(cluster.nodes(), cluster.nodeById(-1)); + } + + @Override + public Admin getAdmin(final Map config) { + return adminClient; + } + + @Override + public Producer getProducer(final Map config) { + if (applicationId != null) { + assertThat((String) config.get(ProducerConfig.TRANSACTIONAL_ID_CONFIG), startsWith(applicationId + "-")); + } else { + assertFalse(config.containsKey(ProducerConfig.TRANSACTIONAL_ID_CONFIG)); + } + final MockProducer producer = new MockProducer<>(cluster, true, new DefaultPartitioner(), BYTE_ARRAY_SERIALIZER, BYTE_ARRAY_SERIALIZER); + producers.add(producer); + return producer; + } + + @Override + public Consumer getConsumer(final Map config) { + return consumer; + } + + @Override + public Consumer getRestoreConsumer(final Map config) { + return restoreConsumer; + } + + @Override + public Consumer getGlobalConsumer(final Map config) { + return restoreConsumer; + } +} diff --git a/streams/src/test/java/org/apache/kafka/test/MockInitializer.java b/streams/src/test/java/org/apache/kafka/test/MockInitializer.java new file mode 100644 index 0000000..d5d69c8 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/test/MockInitializer.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.test; + +import org.apache.kafka.streams.kstream.Initializer; + +public class MockInitializer { + + private static class StringInit implements Initializer { + @Override + public String apply() { + return "0"; + } + } + + public final static Initializer STRING_INIT = new StringInit(); +} diff --git a/streams/src/test/java/org/apache/kafka/test/MockInternalNewProcessorContext.java b/streams/src/test/java/org/apache/kafka/test/MockInternalNewProcessorContext.java new file mode 100644 index 0000000..ffb503e --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/test/MockInternalNewProcessorContext.java @@ -0,0 +1,198 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.test; + +import org.apache.kafka.common.header.Headers; +import org.apache.kafka.common.header.internals.RecordHeaders; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.streams.processor.StateRestoreCallback; +import org.apache.kafka.streams.processor.StateStore; +import org.apache.kafka.streams.processor.TaskId; +import org.apache.kafka.streams.processor.To; +import org.apache.kafka.streams.processor.api.MockProcessorContext; +import org.apache.kafka.streams.processor.internals.InternalProcessorContext; +import org.apache.kafka.streams.processor.internals.ProcessorNode; +import org.apache.kafka.streams.processor.internals.ProcessorRecordContext; +import org.apache.kafka.streams.processor.internals.RecordCollector; +import org.apache.kafka.streams.processor.internals.StreamTask; +import org.apache.kafka.streams.processor.internals.Task.TaskType; +import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl; +import org.apache.kafka.streams.state.StoreBuilder; +import org.apache.kafka.streams.state.internals.ThreadCache; +import org.apache.kafka.streams.state.internals.ThreadCache.DirtyEntryFlushListener; + +import java.io.File; +import java.util.Properties; + +public class MockInternalNewProcessorContext extends MockProcessorContext implements InternalProcessorContext { + + private ProcessorNode currentNode; + private long currentSystemTimeMs; + private TaskType taskType = TaskType.ACTIVE; + + private long timestamp = 0; + private Headers headers = new RecordHeaders(); + + public MockInternalNewProcessorContext() { + } + + public MockInternalNewProcessorContext(final Properties config, final TaskId taskId, final File stateDir) { + super(config, taskId, stateDir); + } + + @Override + public void setSystemTimeMs(long timeMs) { + currentSystemTimeMs = timeMs; + } + + @Override + public long currentSystemTimeMs() { + return currentSystemTimeMs; + } + + @Override + public long currentStreamTimeMs() { + return 0; + } + + @Override + public StreamsMetricsImpl metrics() { + return (StreamsMetricsImpl) super.metrics(); + } + + @Override + public ProcessorRecordContext recordContext() { + return new ProcessorRecordContext(timestamp(), offset(), partition(), topic(), headers()); + } + + @Override + public void setRecordContext(final ProcessorRecordContext recordContext) { + setRecordMetadata( + recordContext.topic(), + recordContext.partition(), + recordContext.offset() + ); + this.headers = recordContext.headers(); + this.timestamp = recordContext.timestamp(); + } + + public void setTimestamp(final long timestamp) { + this.timestamp = timestamp; + } + + public void setHeaders(final Headers headers) { + this.headers = headers; + } + + @Override + public void setCurrentNode(final ProcessorNode currentNode) { + this.currentNode = currentNode; + } + + @Override + public ProcessorNode currentNode() { + return currentNode; + } + + @Override + public ThreadCache cache() { + return null; + } + + @Override + public void initialize() {} + + @Override + public void uninitialize() {} + + @Override + public void register(final StateStore store, final StateRestoreCallback stateRestoreCallback) { + addStateStore(store); + } + + @Override + public void forward(K key, V value) { + throw new UnsupportedOperationException("Migrate to new implementation"); + } + + @Override + public void forward(K key, V value, To to) { + throw new UnsupportedOperationException("Migrate to new implementation"); + } + + @Override + public String topic() { + if 
(recordMetadata().isPresent()) return recordMetadata().get().topic(); + else return null; + } + + @Override + public int partition() { + if (recordMetadata().isPresent()) return recordMetadata().get().partition(); + else return 0; + } + + @Override + public long offset() { + if (recordMetadata().isPresent()) return recordMetadata().get().offset(); + else return 0; + } + + @Override + public Headers headers() { + return headers; + } + + @Override + public long timestamp() { + return timestamp; + } + + @Override + public TaskType taskType() { + return taskType; + } + + @Override + public void logChange(final String storeName, + final Bytes key, + final byte[] value, + final long timestamp) { + } + + @Override + public void transitionToActive(final StreamTask streamTask, final RecordCollector recordCollector, final ThreadCache newCache) { + } + + @Override + public void transitionToStandby(final ThreadCache newCache) { + } + + @Override + public void registerCacheFlushListener(final String namespace, final DirtyEntryFlushListener listener) { + } + + @Override + public T getStateStore(StoreBuilder builder) { + return getStateStore(builder.name()); + } + + @Override + public String changelogFor(final String storeName) { + return "mock-changelog"; + } +} \ No newline at end of file diff --git a/streams/src/test/java/org/apache/kafka/test/MockInternalProcessorContext.java b/streams/src/test/java/org/apache/kafka/test/MockInternalProcessorContext.java new file mode 100644 index 0000000..c32c136 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/test/MockInternalProcessorContext.java @@ -0,0 +1,172 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.test; + +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.streams.processor.MockProcessorContext; +import org.apache.kafka.streams.processor.StateRestoreCallback; +import org.apache.kafka.streams.processor.StateStore; +import org.apache.kafka.streams.processor.TaskId; +import org.apache.kafka.streams.processor.To; +import org.apache.kafka.streams.processor.api.Record; +import org.apache.kafka.streams.processor.api.RecordMetadata; +import org.apache.kafka.streams.processor.internals.InternalProcessorContext; +import org.apache.kafka.streams.processor.internals.ProcessorNode; +import org.apache.kafka.streams.processor.internals.ProcessorRecordContext; +import org.apache.kafka.streams.processor.internals.RecordCollector; +import org.apache.kafka.streams.processor.internals.StreamTask; +import org.apache.kafka.streams.processor.internals.Task.TaskType; +import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl; +import org.apache.kafka.streams.state.internals.ThreadCache; + +import java.io.File; +import java.util.LinkedHashMap; +import java.util.Map; +import java.util.Optional; +import java.util.Properties; +import org.apache.kafka.streams.state.internals.ThreadCache.DirtyEntryFlushListener; + +public class MockInternalProcessorContext extends MockProcessorContext implements InternalProcessorContext { + + private final Map restoreCallbacks = new LinkedHashMap<>(); + private ProcessorNode currentNode; + private RecordCollector recordCollector; + private long currentSystemTimeMs; + private TaskType taskType = TaskType.ACTIVE; + + public MockInternalProcessorContext() { + } + + public MockInternalProcessorContext(final Properties config, final TaskId taskId, final File stateDir) { + super(config, taskId, stateDir); + } + + @Override + public void setSystemTimeMs(long timeMs) { + currentSystemTimeMs = timeMs; + } + + @Override + public long currentSystemTimeMs() { + return currentSystemTimeMs; + } + + @Override + public StreamsMetricsImpl metrics() { + return (StreamsMetricsImpl) super.metrics(); + } + + @Override + public void forward(final Record record) { + forward(record.key(), record.value(), To.all().withTimestamp(record.timestamp())); + } + + @Override + public void forward(final Record record, final String childName) { + forward(record.key(), record.value(), To.child(childName).withTimestamp(record.timestamp())); + } + + @Override + public ProcessorRecordContext recordContext() { + return new ProcessorRecordContext(timestamp(), offset(), partition(), topic(), headers()); + } + + @Override + public Optional recordMetadata() { + return Optional.of(recordContext()); + } + + @Override + public void setRecordContext(final ProcessorRecordContext recordContext) { + setRecordMetadata( + recordContext.topic(), + recordContext.partition(), + recordContext.offset(), + recordContext.headers(), + recordContext.timestamp() + ); + } + + @Override + public void setCurrentNode(final ProcessorNode currentNode) { + this.currentNode = currentNode; + } + + @Override + public ProcessorNode currentNode() { + return currentNode; + } + + @Override + public ThreadCache cache() { + return null; + } + + @Override + public void initialize() {} + + @Override + public void uninitialize() {} + + @Override + public RecordCollector recordCollector() { + return recordCollector; + } + + public void setRecordCollector(final RecordCollector recordCollector) { + this.recordCollector = recordCollector; + } + + @Override + public void register(final 
StateStore store, final StateRestoreCallback stateRestoreCallback) { + restoreCallbacks.put(store.name(), stateRestoreCallback); + super.register(store, stateRestoreCallback); + } + + public StateRestoreCallback stateRestoreCallback(final String storeName) { + return restoreCallbacks.get(storeName); + } + + @Override + public TaskType taskType() { + return taskType; + } + + @Override + public void logChange(final String storeName, + final Bytes key, + final byte[] value, + final long timestamp) { + } + + @Override + public void transitionToActive(final StreamTask streamTask, final RecordCollector recordCollector, final ThreadCache newCache) { + } + + @Override + public void transitionToStandby(final ThreadCache newCache) { + } + + @Override + public void registerCacheFlushListener(final String namespace, final DirtyEntryFlushListener listener) { + } + + @Override + public String changelogFor(final String storeName) { + return "mock-changelog"; + } +} \ No newline at end of file diff --git a/streams/src/test/java/org/apache/kafka/test/MockInternalTopicManager.java b/streams/src/test/java/org/apache/kafka/test/MockInternalTopicManager.java new file mode 100644 index 0000000..8b049a0 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/test/MockInternalTopicManager.java @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.test; + +import org.apache.kafka.clients.consumer.MockConsumer; +import org.apache.kafka.common.PartitionInfo; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.processor.internals.InternalTopicConfig; +import org.apache.kafka.streams.processor.internals.InternalTopicManager; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; + +public class MockInternalTopicManager extends InternalTopicManager { + + public final Map readyTopics = new HashMap<>(); + private final MockConsumer restoreConsumer; + private final boolean mockCreateInternalTopics; + + public MockInternalTopicManager(final Time time, + final StreamsConfig streamsConfig, + final MockConsumer restoreConsumer, + final boolean mockCreateInternalTopics) { + super(time, new MockClientSupplier().getAdmin(streamsConfig.originals()), streamsConfig); + + this.restoreConsumer = restoreConsumer; + this.mockCreateInternalTopics = mockCreateInternalTopics; + } + + @Override + public Set makeReady(final Map topics) { + for (final InternalTopicConfig topic : topics.values()) { + final String topicName = topic.name(); + final int numberOfPartitions = topic.numberOfPartitions().get(); + readyTopics.put(topicName, numberOfPartitions); + + final List partitions = new ArrayList<>(); + for (int i = 0; i < numberOfPartitions; i++) { + partitions.add(new PartitionInfo(topicName, i, null, null, null)); + } + + restoreConsumer.updatePartitions(topicName, partitions); + } + return mockCreateInternalTopics ? topics.keySet() : Collections.emptySet(); + } + + @Override + protected Map getNumPartitions(final Set topics, + final Set tempUnknownTopics) { + final Map partitions = new HashMap<>(); + for (final String topic : topics) { + partitions.put(topic, restoreConsumer.partitionsFor(topic) == null ? null : restoreConsumer.partitionsFor(topic).size()); + } + + return partitions; + } +} diff --git a/streams/src/test/java/org/apache/kafka/test/MockKeyValueStore.java b/streams/src/test/java/org/apache/kafka/test/MockKeyValueStore.java new file mode 100644 index 0000000..91f8a6f --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/test/MockKeyValueStore.java @@ -0,0 +1,143 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.test; + +import org.apache.kafka.common.serialization.Deserializer; +import org.apache.kafka.common.serialization.IntegerDeserializer; +import org.apache.kafka.common.serialization.Serializer; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.processor.ProcessorContext; +import org.apache.kafka.streams.processor.StateRestoreCallback; +import org.apache.kafka.streams.processor.StateStore; +import org.apache.kafka.streams.state.KeyValueIterator; +import org.apache.kafka.streams.state.KeyValueStore; + +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.atomic.AtomicInteger; + +public class MockKeyValueStore implements KeyValueStore { + // keep a global counter of flushes and a local reference to which store had which + // flush, so we can reason about the order in which stores get flushed. + private static final AtomicInteger GLOBAL_FLUSH_COUNTER = new AtomicInteger(0); + private final AtomicInteger instanceLastFlushCount = new AtomicInteger(-1); + private final String name; + private final boolean persistent; + + public boolean initialized = false; + public boolean flushed = false; + public boolean closed = true; + public final ArrayList keys = new ArrayList<>(); + public final ArrayList values = new ArrayList<>(); + + public MockKeyValueStore(final String name, + final boolean persistent) { + this.name = name; + this.persistent = persistent; + } + + @Override + public String name() { + return name; + } + + @Deprecated + @Override + public void init(final ProcessorContext context, + final StateStore root) { + context.register(root, stateRestoreCallback); + initialized = true; + closed = false; + } + + @Override + public void flush() { + instanceLastFlushCount.set(GLOBAL_FLUSH_COUNTER.getAndIncrement()); + flushed = true; + } + + public int getLastFlushCount() { + return instanceLastFlushCount.get(); + } + + @Override + public void close() { + closed = true; + } + + @Override + public boolean persistent() { + return persistent; + } + + @Override + public boolean isOpen() { + return !closed; + } + + public final StateRestoreCallback stateRestoreCallback = new StateRestoreCallback() { + private final Deserializer deserializer = new IntegerDeserializer(); + + @Override + public void restore(final byte[] key, + final byte[] value) { + keys.add(deserializer.deserialize("", key)); + values.add(value); + } + }; + + @Override + public void put(final Object key, final Object value) {} + + @Override + public Object putIfAbsent(final Object key, final Object value) { + return null; + } + + @Override + public Object delete(final Object key) { + return null; + } + + @Override + public void putAll(final List> entries) {} + + @Override + public Object get(final Object key) { + return null; + } + + @Override + public KeyValueIterator range(final Object from, final Object to) { + return null; + } + + @Override + public , P> KeyValueIterator prefixScan(P prefix, PS prefixKeySerializer) { + return null; + } + + @Override + public KeyValueIterator all() { + return null; + } + + @Override + public long approximateNumEntries() { + return 0; + } +} diff --git a/streams/src/test/java/org/apache/kafka/test/MockKeyValueStoreBuilder.java b/streams/src/test/java/org/apache/kafka/test/MockKeyValueStoreBuilder.java new file mode 100644 index 0000000..11c2f3d --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/test/MockKeyValueStoreBuilder.java @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or 
more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.test; + +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.streams.state.KeyValueStore; +import org.apache.kafka.streams.state.internals.AbstractStoreBuilder; + +public class MockKeyValueStoreBuilder extends AbstractStoreBuilder> { + + private final boolean persistent; + + public MockKeyValueStoreBuilder(final String storeName, final boolean persistent) { + super(storeName, Serdes.Integer(), Serdes.ByteArray(), new MockTime()); + + this.persistent = persistent; + } + + @Override + public KeyValueStore build() { + return new MockKeyValueStore(name, persistent); + } +} + diff --git a/streams/src/test/java/org/apache/kafka/test/MockMapper.java b/streams/src/test/java/org/apache/kafka/test/MockMapper.java new file mode 100644 index 0000000..c3d9e93 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/test/MockMapper.java @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.test; + +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.kstream.KeyValueMapper; +import org.apache.kafka.streams.kstream.ValueMapper; + +import java.util.Collections; + +public class MockMapper { + + private static class NoOpKeyValueMapper implements KeyValueMapper> { + @Override + public KeyValue apply(final K key, final V value) { + return KeyValue.pair(key, value); + } + } + + private static class NoOpFlatKeyValueMapper implements KeyValueMapper>> { + @Override + public Iterable> apply(final K key, final V value) { + return Collections.singletonList(KeyValue.pair(key, value)); + } + } + + private static class SelectValueKeyValueMapper implements KeyValueMapper> { + @Override + public KeyValue apply(final K key, final V value) { + return KeyValue.pair(value, value); + } + } + + private static class SelectValueMapper implements KeyValueMapper { + @Override + public V apply(final K key, final V value) { + return value; + } + } + + private static class SelectKeyMapper implements KeyValueMapper { + @Override + public K apply(final K key, final V value) { + return key; + } + } + + private static class NoOpValueMapper implements ValueMapper { + @Override + public V apply(final V value) { + return value; + } + } + + public static KeyValueMapper selectKeyKeyValueMapper() { + return new SelectKeyMapper<>(); + } + + public static KeyValueMapper>> noOpFlatKeyValueMapper() { + return new NoOpFlatKeyValueMapper<>(); + } + + public static KeyValueMapper> noOpKeyValueMapper() { + return new NoOpKeyValueMapper<>(); + } + + public static KeyValueMapper> selectValueKeyValueMapper() { + return new SelectValueKeyValueMapper<>(); + } + + public static KeyValueMapper selectValueMapper() { + return new SelectValueMapper<>(); + } + + public static ValueMapper noOpValueMapper() { + return new NoOpValueMapper<>(); + } +} \ No newline at end of file diff --git a/streams/src/test/java/org/apache/kafka/test/MockPredicate.java b/streams/src/test/java/org/apache/kafka/test/MockPredicate.java new file mode 100644 index 0000000..9d59bab --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/test/MockPredicate.java @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.test; + +import org.apache.kafka.streams.kstream.Predicate; + +public class MockPredicate { + + private static class AllGoodPredicate implements Predicate { + @Override + public boolean test(final K key, final V value) { + return true; + } + } + + public static Predicate allGoodPredicate() { + return new AllGoodPredicate<>(); + } +} diff --git a/streams/src/test/java/org/apache/kafka/test/MockProcessor.java b/streams/src/test/java/org/apache/kafka/test/MockProcessor.java new file mode 100644 index 0000000..a3bb87d --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/test/MockProcessor.java @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.test; + +import org.apache.kafka.streams.KeyValueTimestamp; +import org.apache.kafka.streams.processor.Cancellable; +import org.apache.kafka.streams.processor.ProcessorContext; +import org.apache.kafka.streams.processor.PunctuationType; +import org.apache.kafka.streams.processor.api.Record; +import org.apache.kafka.streams.state.ValueAndTimestamp; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +@SuppressWarnings("deprecation") // Old PAPI. Needs to be migrated. +public class MockProcessor extends org.apache.kafka.streams.processor.AbstractProcessor { + private final MockApiProcessor delegate; + + public MockProcessor(final PunctuationType punctuationType, + final long scheduleInterval) { + delegate = new MockApiProcessor<>(punctuationType, scheduleInterval); + } + + public MockProcessor() { + delegate = new MockApiProcessor<>(); + } + + @SuppressWarnings("unchecked") + @Override + public void init(final ProcessorContext context) { + super.init(context); + delegate.init((org.apache.kafka.streams.processor.api.ProcessorContext) context); + } + + @Override + public void process(final K key, final V value) { + delegate.process(new Record<>(key, value, context.timestamp(), context.headers())); + } + + public void checkAndClearProcessResult(final KeyValueTimestamp... expected) { + delegate.checkAndClearProcessResult(expected); + } + + public void requestCommit() { + delegate.requestCommit(); + } + + public void checkEmptyAndClearProcessResult() { + delegate.checkEmptyAndClearProcessResult(); + } + + public void checkAndClearPunctuateResult(final PunctuationType type, final long... 
expected) { + delegate.checkAndClearPunctuateResult(type, expected); + } + + public Map> lastValueAndTimestampPerKey() { + return delegate.lastValueAndTimestampPerKey(); + } + + public List punctuatedStreamTime() { + return delegate.punctuatedStreamTime(); + } + + public Cancellable scheduleCancellable() { + return delegate.scheduleCancellable(); + } + + public ArrayList> processed() { + return delegate.processed(); + } +} diff --git a/streams/src/test/java/org/apache/kafka/test/MockProcessorNode.java b/streams/src/test/java/org/apache/kafka/test/MockProcessorNode.java new file mode 100644 index 0000000..4ab4cb8 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/test/MockProcessorNode.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.test; + +import org.apache.kafka.streams.processor.PunctuationType; +import org.apache.kafka.streams.processor.api.Record; +import org.apache.kafka.streams.processor.internals.InternalProcessorContext; +import org.apache.kafka.streams.processor.internals.ProcessorNode; + +import java.util.Collections; +import java.util.concurrent.atomic.AtomicInteger; + +public class MockProcessorNode extends ProcessorNode { + + private static final String NAME = "MOCK-PROCESS-"; + private static final AtomicInteger INDEX = new AtomicInteger(1); + + public final MockProcessor mockProcessor; + + public boolean closed; + public boolean initialized; + + public MockProcessorNode(final long scheduleInterval) { + this(scheduleInterval, PunctuationType.STREAM_TIME); + } + + public MockProcessorNode(final long scheduleInterval, final PunctuationType punctuationType) { + this(new MockProcessor<>(punctuationType, scheduleInterval)); + } + + public MockProcessorNode() { + this(new MockProcessor<>()); + } + + private MockProcessorNode(final MockProcessor mockProcessor) { + super(NAME + INDEX.getAndIncrement(), mockProcessor, Collections.emptySet()); + + this.mockProcessor = mockProcessor; + } + + @Override + public void init(final InternalProcessorContext context) { + super.init(context); + initialized = true; + } + + @Override + public void process(final Record record) { + processor().process(record); + } + + @Override + public void close() { + super.close(); + this.closed = true; + } +} \ No newline at end of file diff --git a/streams/src/test/java/org/apache/kafka/test/MockProcessorSupplier.java b/streams/src/test/java/org/apache/kafka/test/MockProcessorSupplier.java new file mode 100644 index 0000000..c6b70f2 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/test/MockProcessorSupplier.java @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.test; + +import org.apache.kafka.streams.processor.PunctuationType; + +import java.util.ArrayList; +import java.util.List; + +import static org.junit.Assert.assertEquals; + +@SuppressWarnings("deprecation") // Old PAPI. Needs to be migrated. +public class MockProcessorSupplier implements org.apache.kafka.streams.processor.ProcessorSupplier { + + private final long scheduleInterval; + private final PunctuationType punctuationType; + private final List> processors = new ArrayList<>(); + + public MockProcessorSupplier() { + this(-1L); + } + + public MockProcessorSupplier(final long scheduleInterval) { + this(scheduleInterval, PunctuationType.STREAM_TIME); + } + + public MockProcessorSupplier(final long scheduleInterval, final PunctuationType punctuationType) { + this.scheduleInterval = scheduleInterval; + this.punctuationType = punctuationType; + } + + @Override + public org.apache.kafka.streams.processor.Processor get() { + final MockProcessor processor = new MockProcessor<>(punctuationType, scheduleInterval); + + // to keep tests simple, ignore calls from ApiUtils.checkSupplier + if (!StreamsTestUtils.isCheckSupplierCall()) { + processors.add(processor); + } + + return processor; + } + + // get the captured processor assuming that only one processor gets returned from this supplier + public MockProcessor theCapturedProcessor() { + return capturedProcessors(1).get(0); + } + + public int capturedProcessorsCount() { + return processors.size(); + } + + // get the captured processors with the expected number + public List> capturedProcessors(final int expectedNumberOfProcessors) { + assertEquals(expectedNumberOfProcessors, processors.size()); + + return processors; + } +} diff --git a/streams/src/test/java/org/apache/kafka/test/MockRecordCollector.java b/streams/src/test/java/org/apache/kafka/test/MockRecordCollector.java new file mode 100644 index 0000000..505ee68 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/test/MockRecordCollector.java @@ -0,0 +1,106 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.test; + +import org.apache.kafka.clients.producer.ProducerRecord; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.header.Headers; +import org.apache.kafka.common.serialization.Serializer; +import org.apache.kafka.streams.processor.StreamPartitioner; +import org.apache.kafka.streams.processor.internals.RecordCollector; + +import java.util.Collections; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; + +import static java.util.Collections.unmodifiableList; + +public class MockRecordCollector implements RecordCollector { + + // remember all records that are collected so far + private final List> collected = new LinkedList<>(); + + // remember if flushed is called + private boolean flushed = false; + + @Override + public void send(final String topic, + final K key, + final V value, + final Headers headers, + final Integer partition, + final Long timestamp, + final Serializer keySerializer, + final Serializer valueSerializer) { + collected.add(new ProducerRecord<>(topic, + partition, + timestamp, + key, + value, + headers)); + } + + @Override + public void send(final String topic, + final K key, + final V value, + final Headers headers, + final Long timestamp, + final Serializer keySerializer, + final Serializer valueSerializer, + final StreamPartitioner partitioner) { + collected.add(new ProducerRecord<>(topic, + 0, // partition id + timestamp, + key, + value, + headers)); + } + + @Override + public void initialize() {} + + @Override + public void flush() { + flushed = true; + } + + @Override + public void closeClean() {} + + @Override + public void closeDirty() {} + + @Override + public Map offsets() { + return Collections.emptyMap(); + } + + public List> collected() { + return unmodifiableList(collected); + } + + public boolean flushed() { + return flushed; + } + + public void clear() { + this.flushed = false; + this.collected.clear(); + } +} diff --git a/streams/src/test/java/org/apache/kafka/test/MockReducer.java b/streams/src/test/java/org/apache/kafka/test/MockReducer.java new file mode 100644 index 0000000..0ecb4a1 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/test/MockReducer.java @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.test; + +import org.apache.kafka.streams.kstream.Reducer; + +public class MockReducer { + + private static class StringAdd implements Reducer { + + @Override + public String apply(final String value1, final String value2) { + return value1 + "+" + value2; + } + } + + private static class StringRemove implements Reducer { + + @Override + public String apply(final String value1, final String value2) { + return value1 + "-" + value2; + } + } + + + private static class IntegerAdd implements Reducer { + + @Override + public Integer apply(final Integer value1, final Integer value2) { + return value1 + value2; + } + } + + private static class IntegerSubtract implements Reducer { + + @Override + public Integer apply(final Integer value1, final Integer value2) { + return value1 - value2; + } + } + + public final static Reducer STRING_ADDER = new StringAdd(); + + public final static Reducer STRING_REMOVER = new StringRemove(); + + public final static Reducer INTEGER_ADDER = new IntegerAdd(); + + public final static Reducer INTEGER_SUBTRACTOR = new IntegerSubtract(); +} \ No newline at end of file diff --git a/streams/src/test/java/org/apache/kafka/test/MockRestoreCallback.java b/streams/src/test/java/org/apache/kafka/test/MockRestoreCallback.java new file mode 100644 index 0000000..fa5b465 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/test/MockRestoreCallback.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.test; + +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.processor.StateRestoreCallback; + +import java.util.ArrayList; +import java.util.List; + +public class MockRestoreCallback implements StateRestoreCallback { + public List> restored = new ArrayList<>(); + + @Override + public void restore(final byte[] key, final byte[] value) { + restored.add(KeyValue.pair(key, value)); + } +} diff --git a/streams/src/test/java/org/apache/kafka/test/MockRestoreConsumer.java b/streams/src/test/java/org/apache/kafka/test/MockRestoreConsumer.java new file mode 100644 index 0000000..2398872 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/test/MockRestoreConsumer.java @@ -0,0 +1,159 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.test;
+
+import org.apache.kafka.clients.consumer.ConsumerRecord;
+import org.apache.kafka.clients.consumer.ConsumerRecords;
+import org.apache.kafka.clients.consumer.MockConsumer;
+import org.apache.kafka.clients.consumer.OffsetResetStrategy;
+import org.apache.kafka.common.TopicPartition;
+import org.apache.kafka.common.serialization.Serializer;
+
+import java.time.Duration;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.Map;
+import java.util.Optional;
+
+public class MockRestoreConsumer<K, V> extends MockConsumer<byte[], byte[]> {
+    private final Serializer<K> keySerializer;
+    private final Serializer<V> valueSerializer;
+
+    private TopicPartition assignedPartition = null;
+    private long seekOffset = -1L;
+    private long endOffset = 0L;
+    private long currentOffset = 0L;
+
+    private ArrayList<ConsumerRecord<byte[], byte[]>> recordBuffer = new ArrayList<>();
+
+    public MockRestoreConsumer(final Serializer<K> keySerializer, final Serializer<V> valueSerializer) {
+        super(OffsetResetStrategy.EARLIEST);
+
+        reset();
+        this.keySerializer = keySerializer;
+        this.valueSerializer = valueSerializer;
+    }
+
+    // reset this mock restore consumer for a state store registration
+    public void reset() {
+        assignedPartition = null;
+        seekOffset = -1L;
+        endOffset = 0L;
+        recordBuffer.clear();
+    }
+
+    // buffer a record (we cannot use addRecord because we need to add records before assigning a partition)
+    public void bufferRecord(final ConsumerRecord<K, V> record) {
+        recordBuffer.add(
+            new ConsumerRecord<>(record.topic(), record.partition(), record.offset(), record.timestamp(),
+                record.timestampType(), 0, 0,
+                keySerializer.serialize(record.topic(), record.headers(), record.key()),
+                valueSerializer.serialize(record.topic(), record.headers(), record.value()),
+                record.headers(), Optional.empty()));
+        endOffset = record.offset();
+
+        super.updateEndOffsets(Collections.singletonMap(assignedPartition, endOffset));
+    }
+
+    @Override
+    public synchronized void assign(final Collection<TopicPartition> partitions) {
+        final int numPartitions = partitions.size();
+        if (numPartitions > 1)
+            throw new IllegalArgumentException("RestoreConsumer: more than one partition specified");
+
+        if (numPartitions == 1) {
+            if (assignedPartition != null)
+                throw new IllegalStateException("RestoreConsumer: partition already assigned");
+            assignedPartition = partitions.iterator().next();
+
+            // set the beginning offset to 0
+            // NOTE: it is the user's responsibility to set the initial log-end offset (LEO).
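+            // The beginning offset of the newly assigned partition is pinned to 0 below, while
+            // bufferRecord() advances the end offset as records are buffered; tests typically
+            // buffer their records first and then seek to the offset restoration should start from.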
+            super.updateBeginningOffsets(Collections.singletonMap(assignedPartition, 0L));
+        }
+
+        super.assign(partitions);
+    }
+
+    @Override
+    public ConsumerRecords<byte[], byte[]> poll(final Duration timeout) {
+        // add buffered records to MockConsumer
+        for (final ConsumerRecord<byte[], byte[]> record : recordBuffer) {
+            super.addRecord(record);
+        }
+        recordBuffer.clear();
+
+        final ConsumerRecords<byte[], byte[]> records = super.poll(timeout);
+
+        // set the current offset
+        final Iterable<ConsumerRecord<byte[], byte[]>> partitionRecords = records.records(assignedPartition);
+        for (final ConsumerRecord<byte[], byte[]> record : partitionRecords) {
+            currentOffset = record.offset();
+        }
+
+        return records;
+    }
+
+    @Override
+    public synchronized long position(final TopicPartition partition) {
+        if (!partition.equals(assignedPartition))
+            throw new IllegalStateException("RestoreConsumer: unassigned partition");
+
+        return currentOffset;
+    }
+
+    @Override
+    public synchronized void seek(final TopicPartition partition, final long offset) {
+        if (offset < 0)
+            throw new IllegalArgumentException("RestoreConsumer: offset should not be negative");
+
+        if (seekOffset >= 0)
+            throw new IllegalStateException("RestoreConsumer: offset already seeked");
+
+        seekOffset = offset;
+        currentOffset = offset;
+        super.seek(partition, offset);
+    }
+
+    @Override
+    public synchronized void seekToBeginning(final Collection<TopicPartition> partitions) {
+        if (partitions.size() != 1)
+            throw new IllegalStateException("RestoreConsumer: other than one partition specified");
+
+        for (final TopicPartition partition : partitions) {
+            if (!partition.equals(assignedPartition))
+                throw new IllegalStateException("RestoreConsumer: seek-to-beginning not on the assigned partition");
+        }
+
+        currentOffset = 0L;
+    }
+
+    @Override
+    public Map<TopicPartition, Long> endOffsets(final Collection<TopicPartition> partitions) {
+        if (partitions.size() != 1)
+            throw new IllegalStateException("RestoreConsumer: other than one partition specified");
+
+        for (final TopicPartition partition : partitions) {
+            if (!partition.equals(assignedPartition))
+                throw new IllegalStateException("RestoreConsumer: end-offsets lookup not on the assigned partition");
+        }
+
+        currentOffset = endOffset;
+        return super.endOffsets(partitions);
+    }
+}
diff --git a/streams/src/test/java/org/apache/kafka/test/MockRocksDbConfigSetter.java b/streams/src/test/java/org/apache/kafka/test/MockRocksDbConfigSetter.java
new file mode 100644
index 0000000..49c782b
--- /dev/null
+++ b/streams/src/test/java/org/apache/kafka/test/MockRocksDbConfigSetter.java
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +package org.apache.kafka.test; + +import org.apache.kafka.streams.state.RocksDBConfigSetter; +import org.rocksdb.Options; + +import java.util.HashMap; +import java.util.Map; + +public class MockRocksDbConfigSetter implements RocksDBConfigSetter { + public static boolean called = false; + public static Map configMap = new HashMap<>(); + + @Override + public void setConfig(final String storeName, final Options options, final Map configs) { + called = true; + + configMap.putAll(configs); + } + + @Override + public void close(String storeName, Options options) { + } +} diff --git a/streams/src/test/java/org/apache/kafka/test/MockSourceNode.java b/streams/src/test/java/org/apache/kafka/test/MockSourceNode.java new file mode 100644 index 0000000..f52134e --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/test/MockSourceNode.java @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.test; + +import org.apache.kafka.common.serialization.Deserializer; +import org.apache.kafka.streams.processor.api.Record; +import org.apache.kafka.streams.processor.internals.InternalProcessorContext; +import org.apache.kafka.streams.processor.internals.SourceNode; + +import java.util.ArrayList; +import java.util.concurrent.atomic.AtomicInteger; + +public class MockSourceNode extends SourceNode { + + private static final String NAME = "MOCK-SOURCE-"; + private static final AtomicInteger INDEX = new AtomicInteger(1); + + public int numReceived = 0; + public final ArrayList keys = new ArrayList<>(); + public final ArrayList values = new ArrayList<>(); + public boolean initialized; + public boolean closed; + + public MockSourceNode(final Deserializer keyDeserializer, final Deserializer valDeserializer) { + super(NAME + INDEX.getAndIncrement(), keyDeserializer, valDeserializer); + } + + @Override + public void process(final Record record) { + numReceived++; + keys.add(record.key()); + values.add(record.value()); + } + + @Override + public void init(final InternalProcessorContext context) { + super.init(context); + initialized = true; + } + + @Override + public void close() { + super.close(); + closed = true; + } +} diff --git a/streams/src/test/java/org/apache/kafka/test/MockStateRestoreListener.java b/streams/src/test/java/org/apache/kafka/test/MockStateRestoreListener.java new file mode 100644 index 0000000..1026969 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/test/MockStateRestoreListener.java @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.test; + +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.streams.processor.StateRestoreListener; + +import java.util.HashMap; +import java.util.Map; + +public class MockStateRestoreListener implements StateRestoreListener { + + // verifies store name called for each state + public final Map storeNameCalledStates = new HashMap<>(); + public long restoreStartOffset; + public long restoreEndOffset; + public long restoredBatchOffset; + public long numBatchRestored; + public long totalNumRestored; + public TopicPartition restoreTopicPartition; + + public static final String RESTORE_START = "restore_start"; + public static final String RESTORE_BATCH = "restore_batch"; + public static final String RESTORE_END = "restore_end"; + + @Override + public void onRestoreStart(final TopicPartition topicPartition, + final String storeName, + final long startingOffset, + final long endingOffset) { + restoreTopicPartition = topicPartition; + storeNameCalledStates.put(RESTORE_START, storeName); + restoreStartOffset = startingOffset; + restoreEndOffset = endingOffset; + } + + @Override + public void onBatchRestored(final TopicPartition topicPartition, + final String storeName, + final long batchEndOffset, + final long numRestored) { + restoreTopicPartition = topicPartition; + storeNameCalledStates.put(RESTORE_BATCH, storeName); + restoredBatchOffset = batchEndOffset; + numBatchRestored = numRestored; + } + + @Override + public void onRestoreEnd(final TopicPartition topicPartition, + final String storeName, + final long totalRestored) { + restoreTopicPartition = topicPartition; + storeNameCalledStates.put(RESTORE_END, storeName); + totalNumRestored = totalRestored; + } + + @Override + public String toString() { + return "MockStateRestoreListener{" + + "storeNameCalledStates=" + storeNameCalledStates + + ", restoreStartOffset=" + restoreStartOffset + + ", restoreEndOffset=" + restoreEndOffset + + ", restoredBatchOffset=" + restoredBatchOffset + + ", numBatchRestored=" + numBatchRestored + + ", totalNumRestored=" + totalNumRestored + + ", restoreTopicPartition=" + restoreTopicPartition + + '}'; + } +} diff --git a/streams/src/test/java/org/apache/kafka/test/MockTimestampExtractor.java b/streams/src/test/java/org/apache/kafka/test/MockTimestampExtractor.java new file mode 100644 index 0000000..f437772 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/test/MockTimestampExtractor.java @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.test; + +import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.streams.processor.TimestampExtractor; + +/* Extract the timestamp as the offset of the record */ +public class MockTimestampExtractor implements TimestampExtractor { + + @Override + public long extract(final ConsumerRecord record, final long partitionTime) { + return record.offset(); + } +} diff --git a/streams/src/test/java/org/apache/kafka/test/MockValueJoiner.java b/streams/src/test/java/org/apache/kafka/test/MockValueJoiner.java new file mode 100644 index 0000000..b1842ff --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/test/MockValueJoiner.java @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.test; + +import org.apache.kafka.streams.kstream.ValueJoiner; + +public class MockValueJoiner { + + public final static ValueJoiner TOSTRING_JOINER = instance("+"); + + public static ValueJoiner instance(final String separator) { + return new ValueJoiner() { + @Override + public String apply(final V1 value1, final V2 value2) { + return value1 + separator + value2; + } + }; + } +} diff --git a/streams/src/test/java/org/apache/kafka/test/NoOpProcessorContext.java b/streams/src/test/java/org/apache/kafka/test/NoOpProcessorContext.java new file mode 100644 index 0000000..53d1040 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/test/NoOpProcessorContext.java @@ -0,0 +1,152 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.test; + +import org.apache.kafka.common.metrics.Metrics; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.processor.Cancellable; +import org.apache.kafka.streams.processor.PunctuationType; +import org.apache.kafka.streams.processor.Punctuator; +import org.apache.kafka.streams.processor.StateRestoreCallback; +import org.apache.kafka.streams.processor.StateStore; +import org.apache.kafka.streams.processor.TaskId; +import org.apache.kafka.streams.processor.To; +import org.apache.kafka.streams.processor.api.Record; +import org.apache.kafka.streams.processor.internals.AbstractProcessorContext; +import org.apache.kafka.streams.processor.internals.MockStreamsMetrics; + +import java.time.Duration; +import java.util.HashMap; +import java.util.Map; +import java.util.Properties; + +import org.apache.kafka.streams.processor.internals.ProcessorStateManager; +import org.apache.kafka.streams.processor.internals.RecordCollector; +import org.apache.kafka.streams.processor.internals.StateManager; +import org.apache.kafka.streams.processor.internals.StateManagerStub; +import org.apache.kafka.streams.processor.internals.StreamTask; +import org.apache.kafka.streams.processor.internals.Task.TaskType; +import org.apache.kafka.streams.state.internals.ThreadCache; +import org.apache.kafka.streams.state.internals.ThreadCache.DirtyEntryFlushListener; + +public class NoOpProcessorContext extends AbstractProcessorContext { + public boolean initialized; + @SuppressWarnings("WeakerAccess") + public Map forwardedValues = new HashMap<>(); + + public NoOpProcessorContext() { + super(new TaskId(1, 1), streamsConfig(), new MockStreamsMetrics(new Metrics()), null); + } + + private static StreamsConfig streamsConfig() { + final Properties props = new Properties(); + props.put(StreamsConfig.APPLICATION_ID_CONFIG, "appId"); + props.put(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, "boot"); + return new StreamsConfig(props); + } + + @Override + protected StateManager stateManager() { + return new StateManagerStub(); + } + + @Override + public S getStateStore(final String name) { + return null; + } + + @Override + public Cancellable schedule(final Duration interval, + final PunctuationType type, + final Punctuator callback) throws IllegalArgumentException { + return null; + } + + @Override + public void forward(final Record record) { + forward(record.key(), record.value()); + } + + @Override + public void forward(final Record record, final String childName) { + forward(record.key(), record.value()); + } + + @Override + public void forward(final K key, final V value) { + forwardedValues.put(key, value); + } + + @Override + public void forward(final K key, final V value, final To to) { + forward(key, value); + } + + @Override + public void commit() {} + + @Override + public long currentSystemTimeMs() { + throw new UnsupportedOperationException("Not implemented yet."); + } + + @Override + public long currentStreamTimeMs() { + throw new UnsupportedOperationException("Not implemented yet."); + } + + @Override + public void initialize() { + initialized = true; + } + + @Override + public void register(final StateStore store, + final StateRestoreCallback stateRestoreCallback) { + } + + @Override + public TaskType taskType() { + return TaskType.ACTIVE; + } + + @Override + public void logChange(final String storeName, + final Bytes key, + final byte[] value, + final long timestamp) { + } + + @Override + public void transitionToActive(final 
StreamTask streamTask, final RecordCollector recordCollector, final ThreadCache newCache) { + } + + @Override + public void transitionToStandby(final ThreadCache newCache) { + } + + @Override + public void registerCacheFlushListener(final String namespace, final DirtyEntryFlushListener listener) { + cache.addDirtyEntryFlushListener(namespace, listener); + } + + @Override + public String changelogFor(final String storeName) { + return ProcessorStateManager.storeChangelogTopic(applicationId(), storeName, taskId().topologyName()); + } +} diff --git a/streams/src/test/java/org/apache/kafka/test/NoOpReadOnlyStore.java b/streams/src/test/java/org/apache/kafka/test/NoOpReadOnlyStore.java new file mode 100644 index 0000000..7234231 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/test/NoOpReadOnlyStore.java @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.test; + +import org.apache.kafka.common.serialization.Serializer; +import org.apache.kafka.streams.processor.ProcessorContext; +import org.apache.kafka.streams.processor.StateStore; +import org.apache.kafka.streams.state.KeyValueIterator; +import org.apache.kafka.streams.state.ReadOnlyKeyValueStore; + +import java.io.File; + +public class NoOpReadOnlyStore implements ReadOnlyKeyValueStore, StateStore { + private final String name; + private final boolean rocksdbStore; + private boolean open = true; + public boolean initialized; + public boolean flushed; + + public NoOpReadOnlyStore() { + this("", false); + } + + public NoOpReadOnlyStore(final String name) { + this(name, false); + } + + public NoOpReadOnlyStore(final String name, + final boolean rocksdbStore) { + this.name = name; + this.rocksdbStore = rocksdbStore; + } + + @Override + public V get(final K key) { + return null; + } + + @Override + public KeyValueIterator range(final K from, final K to) { + return null; + } + + @Override + public , P> KeyValueIterator prefixScan(P prefix, PS prefixKeySerializer) { + return null; + } + + @Override + public KeyValueIterator all() { + return null; + } + + @Override + public long approximateNumEntries() { + return 0L; + } + + @Override + public String name() { + return name; + } + + @Deprecated + @Override + public void init(final ProcessorContext context, final StateStore root) { + if (rocksdbStore) { + // cf. 
RocksDBStore + new File(context.stateDir() + File.separator + "rocksdb" + File.separator + name).mkdirs(); + } else { + new File(context.stateDir() + File.separator + name).mkdir(); + } + this.initialized = true; + context.register(root, (k, v) -> { }); + } + + @Override + public void flush() { + flushed = true; + } + + @Override + public void close() { + open = false; + } + + @Override + public boolean persistent() { + return rocksdbStore; + } + + @Override + public boolean isOpen() { + return open; + } + +} diff --git a/streams/src/test/java/org/apache/kafka/test/NoOpValueTransformerWithKeySupplier.java b/streams/src/test/java/org/apache/kafka/test/NoOpValueTransformerWithKeySupplier.java new file mode 100644 index 0000000..b948abf --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/test/NoOpValueTransformerWithKeySupplier.java @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.test; + +import org.apache.kafka.streams.kstream.ValueTransformerWithKey; +import org.apache.kafka.streams.kstream.ValueTransformerWithKeySupplier; +import org.apache.kafka.streams.processor.ProcessorContext; + +public class NoOpValueTransformerWithKeySupplier implements ValueTransformerWithKeySupplier { + public ProcessorContext context; + + @Override + public ValueTransformerWithKey get() { + return new ValueTransformerWithKey() { + + @Override + public void init(final ProcessorContext context1) { + NoOpValueTransformerWithKeySupplier.this.context = context1; + } + + @Override + public V transform(final K readOnlyKey, final V value) { + return value; + } + + @Override + public void close() { + } + }; + } +} \ No newline at end of file diff --git a/streams/src/test/java/org/apache/kafka/test/NoopValueTransformer.java b/streams/src/test/java/org/apache/kafka/test/NoopValueTransformer.java new file mode 100644 index 0000000..f45c825 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/test/NoopValueTransformer.java @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.test; + +import org.apache.kafka.streams.kstream.ValueTransformer; +import org.apache.kafka.streams.processor.ProcessorContext; + +public class NoopValueTransformer implements ValueTransformer { + @Override + public void init(final ProcessorContext context) { + } + + @Override + public VR transform(final V value) { + return null; + } + + @Override + public void close() { + } +} diff --git a/streams/src/test/java/org/apache/kafka/test/NoopValueTransformerWithKey.java b/streams/src/test/java/org/apache/kafka/test/NoopValueTransformerWithKey.java new file mode 100644 index 0000000..9677a8f --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/test/NoopValueTransformerWithKey.java @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.test; + +import org.apache.kafka.streams.kstream.ValueTransformerWithKey; +import org.apache.kafka.streams.processor.ProcessorContext; + +public class NoopValueTransformerWithKey implements ValueTransformerWithKey { + @Override + public void init(final ProcessorContext context) { + } + + @Override + public VR transform(final K readOnlyKey, final V value) { + return null; + } + + @Override + public void close() { + } +} diff --git a/streams/src/test/java/org/apache/kafka/test/ReadOnlySessionStoreStub.java b/streams/src/test/java/org/apache/kafka/test/ReadOnlySessionStoreStub.java new file mode 100644 index 0000000..61ea1ff --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/test/ReadOnlySessionStoreStub.java @@ -0,0 +1,215 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.test; + +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.errors.InvalidStateStoreException; +import org.apache.kafka.streams.kstream.Windowed; +import org.apache.kafka.streams.processor.ProcessorContext; +import org.apache.kafka.streams.processor.StateStore; +import org.apache.kafka.streams.state.KeyValueIterator; +import org.apache.kafka.streams.state.ReadOnlySessionStore; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.Iterator; +import java.util.List; +import java.util.NavigableMap; +import java.util.TreeMap; + +public class ReadOnlySessionStoreStub implements ReadOnlySessionStore, StateStore { + private final NavigableMap, V>>> sessions = new TreeMap<>(); + private boolean open = true; + + public void put(final Windowed sessionKey, final V value) { + if (!sessions.containsKey(sessionKey.key())) { + sessions.put(sessionKey.key(), new ArrayList<>()); + } + sessions.get(sessionKey.key()).add(KeyValue.pair(sessionKey, value)); + } + + @Override + public KeyValueIterator, V> findSessions(K key, long earliestSessionEndTime, long latestSessionStartTime) { + throw new UnsupportedOperationException("Moved from Session Store. Implement if needed"); + } + + @Override + public KeyValueIterator, V> backwardFindSessions(K key, long earliestSessionEndTime, long latestSessionStartTime) { + throw new UnsupportedOperationException("Moved from Session Store. Implement if needed"); + } + + @Override + public KeyValueIterator, V> findSessions(K keyFrom, K keyTo, long earliestSessionEndTime, long latestSessionStartTime) { + throw new UnsupportedOperationException("Moved from Session Store. Implement if needed"); + } + + @Override + public KeyValueIterator, V> backwardFindSessions(K keyFrom, K keyTo, long earliestSessionEndTime, long latestSessionStartTime) { + throw new UnsupportedOperationException("Moved from Session Store. Implement if needed"); + } + + @Override + public V fetchSession(K key, long earliestSessionEndTime, long latestSessionStartTime) { + throw new UnsupportedOperationException("Moved from Session Store. 
Implement if needed"); + } + + @Override + public KeyValueIterator, V> fetch(final K key) { + if (!open) { + throw new InvalidStateStoreException("not open"); + } + if (!sessions.containsKey(key)) { + return new KeyValueIteratorStub<>(Collections., V>>emptyIterator()); + } + return new KeyValueIteratorStub<>(sessions.get(key).iterator()); + } + + @Override + public KeyValueIterator, V> backwardFetch(K key) { + if (!open) { + throw new InvalidStateStoreException("not open"); + } + if (!sessions.containsKey(key)) { + return new KeyValueIteratorStub<>(Collections.emptyIterator()); + } + return new KeyValueIteratorStub<>(sessions.descendingMap().get(key).iterator()); + } + + @Override + public KeyValueIterator, V> fetch(final K keyFrom, final K keyTo) { + if (!open) { + throw new InvalidStateStoreException("not open"); + } + + NavigableMap, V>>> subSessionsMap = getSubSessionsMap(keyFrom, keyTo); + + if (subSessionsMap.isEmpty()) { + return new KeyValueIteratorStub<>(Collections., V>>emptyIterator()); + } + final Iterator, V>>> keysIterator = subSessionsMap.values().iterator(); + return new KeyValueIteratorStub<>( + new Iterator, V>>() { + + Iterator, V>> it; + + @Override + public boolean hasNext() { + while (it == null || !it.hasNext()) { + if (!keysIterator.hasNext()) { + return false; + } + it = keysIterator.next().iterator(); + } + return true; + } + + @Override + public KeyValue, V> next() { + return it.next(); + } + + } + ); + } + + private NavigableMap, V>>> getSubSessionsMap(final K keyFrom, final K keyTo) { + final NavigableMap, V>>> subSessionsMap; + if (keyFrom == null && keyTo == null) { // fetch all + subSessionsMap = sessions; + } else if (keyFrom == null) { + subSessionsMap = sessions.headMap(keyTo, true); + } else if (keyTo == null) { + subSessionsMap = sessions.tailMap(keyFrom, true); + } else { + subSessionsMap = sessions.subMap(keyFrom, true, keyTo, true); + } + return subSessionsMap; + } + + @Override + public KeyValueIterator, V> backwardFetch(K keyFrom, K keyTo) { + if (!open) { + throw new InvalidStateStoreException("not open"); + } + + NavigableMap, V>>> subSessionsMap = getSubSessionsMap(keyFrom, keyTo); + + if (subSessionsMap.isEmpty()) { + return new KeyValueIteratorStub<>(Collections.emptyIterator()); + } + + final Iterator, V>>> keysIterator = subSessionsMap.descendingMap().values().iterator(); + return new KeyValueIteratorStub<>( + new Iterator, V>>() { + + Iterator, V>> it; + + @Override + public boolean hasNext() { + while (it == null || !it.hasNext()) { + if (!keysIterator.hasNext()) { + return false; + } + it = keysIterator.next().iterator(); + } + return true; + } + + @Override + public KeyValue, V> next() { + return it.next(); + } + } + ); + } + + @Override + public String name() { + return ""; + } + + @Deprecated + @Override + public void init(final ProcessorContext context, final StateStore root) { + + } + + @Override + public void flush() { + + } + + @Override + public void close() { + + } + + @Override + public boolean persistent() { + return false; + } + + @Override + public boolean isOpen() { + return open; + } + + + public void setOpen(final boolean open) { + this.open = open; + } +} diff --git a/streams/src/test/java/org/apache/kafka/test/StateStoreProviderStub.java b/streams/src/test/java/org/apache/kafka/test/StateStoreProviderStub.java new file mode 100644 index 0000000..9d89ae2 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/test/StateStoreProviderStub.java @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) 
under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.test; + +import org.apache.kafka.streams.StoreQueryParameters; +import org.apache.kafka.streams.errors.InvalidStateStoreException; +import org.apache.kafka.streams.processor.StateStore; +import org.apache.kafka.streams.state.QueryableStoreType; +import org.apache.kafka.streams.state.internals.StreamThreadStateStoreProvider; + +import java.util.AbstractMap.SimpleEntry; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; +import java.util.stream.Collectors; + +public class StateStoreProviderStub extends StreamThreadStateStoreProvider { + + // -> state store + private final Map, StateStore> stores = new HashMap<>(); + private final boolean throwException; + + private final int defaultStorePartition = 0; + + public StateStoreProviderStub(final boolean throwException) { + super(null); + this.throwException = throwException; + } + + @SuppressWarnings("unchecked") + @Override + public List stores(final StoreQueryParameters storeQueryParameters) { + final String storeName = storeQueryParameters.storeName(); + final QueryableStoreType queryableStoreType = storeQueryParameters.queryableStoreType(); + if (throwException) { + throw new InvalidStateStoreException("store is unavailable"); + } + if (storeQueryParameters.partition() != null) { + final Entry stateStoreKey = new SimpleEntry<>(storeName, storeQueryParameters.partition()); + if (stores.containsKey(stateStoreKey) && queryableStoreType.accepts(stores.get(stateStoreKey))) { + return (List) Collections.singletonList(stores.get(stateStoreKey)); + } + return Collections.emptyList(); + } + return (List) Collections.unmodifiableList( + stores.entrySet().stream(). + filter(entry -> entry.getKey().getKey().equals(storeName) && queryableStoreType.accepts(entry.getValue())). + map(Entry::getValue). + collect(Collectors.toList())); + } + + public void addStore(final String storeName, + final StateStore store) { + addStore(storeName, defaultStorePartition, store); + } + + public void addStore(final String storeName, + final int partition, + final StateStore store) { + stores.put(new SimpleEntry<>(storeName, partition), store); + } +} diff --git a/streams/src/test/java/org/apache/kafka/test/StreamsTestUtils.java b/streams/src/test/java/org/apache/kafka/test/StreamsTestUtils.java new file mode 100644 index 0000000..9008982 --- /dev/null +++ b/streams/src/test/java/org/apache/kafka/test/StreamsTestUtils.java @@ -0,0 +1,276 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.test; + +import org.apache.kafka.common.Metric; +import org.apache.kafka.common.MetricName; +import org.apache.kafka.common.metrics.Metrics; +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.streams.KafkaStreams; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.kstream.Windowed; +import org.apache.kafka.streams.state.KeyValueIterator; + +import java.io.Closeable; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashSet; +import java.util.Iterator; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Map; +import java.util.Properties; +import java.util.Set; +import java.util.UUID; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.function.Supplier; + +import static org.apache.kafka.common.metrics.Sensor.RecordingLevel.DEBUG; +import static org.apache.kafka.test.TestUtils.DEFAULT_MAX_WAIT_MS; +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.Assert.assertFalse; + +public final class StreamsTestUtils { + private StreamsTestUtils() {} + + public static Properties getStreamsConfig(final String applicationId, + final String bootstrapServers, + final String keySerdeClassName, + final String valueSerdeClassName, + final Properties additional) { + + final Properties props = new Properties(); + props.put(StreamsConfig.APPLICATION_ID_CONFIG, applicationId); + props.put(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, bootstrapServers); + props.put(StreamsConfig.DEFAULT_KEY_SERDE_CLASS_CONFIG, keySerdeClassName); + props.put(StreamsConfig.DEFAULT_VALUE_SERDE_CLASS_CONFIG, valueSerdeClassName); + props.put(StreamsConfig.STATE_DIR_CONFIG, TestUtils.tempDirectory().getPath()); + props.put(StreamsConfig.METRICS_RECORDING_LEVEL_CONFIG, DEBUG.name); + props.putAll(additional); + return props; + } + + public static Properties getStreamsConfig(final String applicationId, + final String bootstrapServers, + final Properties additional) { + + final Properties props = new Properties(); + props.put(StreamsConfig.APPLICATION_ID_CONFIG, applicationId); + props.put(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, bootstrapServers); + props.put(StreamsConfig.STATE_DIR_CONFIG, TestUtils.tempDirectory().getPath()); + props.put(StreamsConfig.METRICS_RECORDING_LEVEL_CONFIG, DEBUG.name); + props.putAll(additional); + return props; + } + + public static Properties getStreamsConfig(final Serde keyDeserializer, + final Serde valueDeserializer) { + return getStreamsConfig( + UUID.randomUUID().toString(), + "localhost:9091", + keyDeserializer.getClass().getName(), + valueDeserializer.getClass().getName(), + new Properties()); + } + + public static Properties getStreamsConfig(final String 
applicationId) { + return getStreamsConfig(applicationId, new Properties()); + } + + public static Properties getStreamsConfig(final String applicationId, final Properties additional) { + return getStreamsConfig( + applicationId, + "localhost:9091", + additional); + } + + public static Properties getStreamsConfig() { + return getStreamsConfig(UUID.randomUUID().toString()); + } + + public static void startKafkaStreamsAndWaitForRunningState(final KafkaStreams kafkaStreams) throws InterruptedException { + startKafkaStreamsAndWaitForRunningState(kafkaStreams, DEFAULT_MAX_WAIT_MS); + } + + public static void startKafkaStreamsAndWaitForRunningState(final KafkaStreams kafkaStreams, + final long timeoutMs) throws InterruptedException { + final CountDownLatch countDownLatch = new CountDownLatch(1); + kafkaStreams.setStateListener((newState, oldState) -> { + if (newState == KafkaStreams.State.RUNNING) { + countDownLatch.countDown(); + } + }); + + kafkaStreams.start(); + assertThat( + "KafkaStreams did not transit to RUNNING state within " + timeoutMs + " milli seconds.", + countDownLatch.await(timeoutMs, TimeUnit.MILLISECONDS), + equalTo(true) + ); + } + + public static List> toList(final Iterator> iterator) { + final List> results = new ArrayList<>(); + + while (iterator.hasNext()) { + results.add(iterator.next()); + } + + if (iterator instanceof Closeable) { + try { + ((Closeable) iterator).close(); + } catch (IOException e) { /* do nothing */ } + } + + return results; + } + + public static Set> toSet(final Iterator> iterator) { + final Set> results = new LinkedHashSet<>(); + + while (iterator.hasNext()) { + results.add(iterator.next()); + } + return results; + } + + public static Set valuesToSet(final Iterator> iterator) { + final Set results = new HashSet<>(); + + while (iterator.hasNext()) { + results.add(iterator.next().value); + } + return results; + } + + public static void verifyKeyValueList(final List> expected, final List> actual) { + assertThat(actual.size(), equalTo(expected.size())); + for (int i = 0; i < actual.size(); i++) { + final KeyValue expectedKv = expected.get(i); + final KeyValue actualKv = actual.get(i); + assertThat(actualKv.key, equalTo(expectedKv.key)); + assertThat(actualKv.value, equalTo(expectedKv.value)); + } + } + + public static void verifyAllWindowedKeyValues(final KeyValueIterator, byte[]> iterator, + final List> expectedKeys, + final List expectedValues) { + if (expectedKeys.size() != expectedValues.size()) { + throw new IllegalArgumentException("expectedKeys and expectedValues should have the same size. 
" + + "expectedKeys size: " + expectedKeys.size() + ", expectedValues size: " + expectedValues.size()); + } + + for (int i = 0; i < expectedKeys.size(); i++) { + verifyWindowedKeyValue( + iterator.next(), + expectedKeys.get(i), + expectedValues.get(i) + ); + } + assertFalse(iterator.hasNext()); + } + + public static void verifyWindowedKeyValue(final KeyValue, byte[]> actual, + final Windowed expectedKey, + final String expectedValue) { + assertThat(actual.key.window(), equalTo(expectedKey.window())); + assertThat(actual.key.key(), equalTo(expectedKey.key())); + assertThat(actual.value, equalTo(expectedValue.getBytes())); + } + + public static Metric getMetricByName(final Map metrics, + final String name, + final String group) { + Metric metric = null; + for (final Map.Entry entry : metrics.entrySet()) { + if (entry.getKey().name().equals(name) && entry.getKey().group().equals(group)) { + if (metric == null) { + metric = entry.getValue(); + } else { + throw new IllegalStateException( + "Found two metrics with name=[" + name + "]: \n" + + metric.metricName().toString() + + " AND \n" + + entry.getKey().toString() + ); + } + } + } + if (metric == null) { + throw new IllegalStateException("Didn't find metric with name=[" + name + "]"); + } else { + return metric; + } + } + + public static Metric getMetricByNameFilterByTags(final Map metrics, + final String name, + final String group, + final Map filterTags) { + Metric metric = null; + for (final Map.Entry entry : metrics.entrySet()) { + if (entry.getKey().name().equals(name) && entry.getKey().group().equals(group)) { + boolean filtersMatch = true; + for (final Map.Entry filter : filterTags.entrySet()) { + if (!filter.getValue().equals(entry.getKey().tags().get(filter.getKey()))) { + filtersMatch = false; + } + } + if (filtersMatch) { + if (metric == null) { + metric = entry.getValue(); + } else { + throw new IllegalStateException( + "Found two metrics with name=[" + name + "] and tags=[" + filterTags + "]: \n" + + metric.metricName().toString() + + " AND \n" + + entry.getKey().toString() + ); + } + } + } + } + if (metric == null) { + throw new IllegalStateException("Didn't find metric with name=[" + name + "] and tags=[" + filterTags + "]"); + } else { + return metric; + } + } + + public static boolean containsMetric(final Metrics metrics, + final String name, + final String group, + final Map tags) { + final MetricName metricName = metrics.metricName(name, group, tags); + return metrics.metric(metricName) != null; + } + + /** + * Used to keep tests simple, and ignore calls from {@link org.apache.kafka.streams.internals.ApiUtils#checkSupplier(Supplier)} )}. + * @return true if the stack context is within a {@link org.apache.kafka.streams.internals.ApiUtils#checkSupplier(Supplier)} )} call + */ + public static boolean isCheckSupplierCall() { + return Arrays.stream(Thread.currentThread().getStackTrace()) + .anyMatch(caller -> "org.apache.kafka.streams.internals.ApiUtils".equals(caller.getClassName()) && "checkSupplier".equals(caller.getMethodName())); + } +} diff --git a/streams/src/test/resources/kafka/kafka-streams-version.properties b/streams/src/test/resources/kafka/kafka-streams-version.properties new file mode 100644 index 0000000..333ba71 --- /dev/null +++ b/streams/src/test/resources/kafka/kafka-streams-version.properties @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. 
See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +commitId=test-commit-ID +version=test-version \ No newline at end of file diff --git a/streams/src/test/resources/log4j.properties b/streams/src/test/resources/log4j.properties new file mode 100644 index 0000000..eabbc54 --- /dev/null +++ b/streams/src/test/resources/log4j.properties @@ -0,0 +1,26 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +log4j.rootLogger=INFO, stdout + +log4j.appender.stdout=org.apache.log4j.ConsoleAppender +log4j.appender.stdout.layout=org.apache.log4j.PatternLayout +log4j.appender.stdout.layout.ConversionPattern=[%d] %p %m (%c:%L)%n + +log4j.logger.kafka=INFO +log4j.logger.org.apache.kafka=ERROR +log4j.logger.org.apache.zookeeper=ERROR + +log4j.logger.org.apache.kafka.clients=WARN +log4j.logger.org.apache.kafka.streams=INFO diff --git a/streams/streams-scala/.gitignore b/streams/streams-scala/.gitignore new file mode 100644 index 0000000..bf11921 --- /dev/null +++ b/streams/streams-scala/.gitignore @@ -0,0 +1 @@ +/logs/ \ No newline at end of file diff --git a/streams/streams-scala/src/main/scala/org/apache/kafka/streams/scala/FunctionsCompatConversions.scala b/streams/streams-scala/src/main/scala/org/apache/kafka/streams/scala/FunctionsCompatConversions.scala new file mode 100644 index 0000000..c3c6403 --- /dev/null +++ b/streams/streams-scala/src/main/scala/org/apache/kafka/streams/scala/FunctionsCompatConversions.scala @@ -0,0 +1,128 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.scala + +import org.apache.kafka.streams.KeyValue +import org.apache.kafka.streams.kstream._ +import scala.jdk.CollectionConverters._ +import java.lang.{Iterable => JIterable} + +import org.apache.kafka.streams.processor.ProcessorContext + +/** + * Implicit classes that offer conversions of Scala function literals to + * SAM (Single Abstract Method) objects in Java. These make the Scala APIs much + * more expressive, with less boilerplate and more succinct. + */ +private[scala] object FunctionsCompatConversions { + + implicit class ForeachActionFromFunction[K, V](val p: (K, V) => Unit) extends AnyVal { + def asForeachAction: ForeachAction[K, V] = (key: K, value: V) => p(key, value) + } + + implicit class PredicateFromFunction[K, V](val p: (K, V) => Boolean) extends AnyVal { + def asPredicate: Predicate[K, V] = (key: K, value: V) => p(key, value) + } + + implicit class MapperFromFunction[T, U, VR](val f: (T, U) => VR) extends AnyVal { + def asKeyValueMapper: KeyValueMapper[T, U, VR] = (key: T, value: U) => f(key, value) + def asValueJoiner: ValueJoiner[T, U, VR] = (value1: T, value2: U) => f(value1, value2) + } + + implicit class KeyValueMapperFromFunction[K, V, KR, VR](val f: (K, V) => (KR, VR)) extends AnyVal { + def asKeyValueMapper: KeyValueMapper[K, V, KeyValue[KR, VR]] = (key: K, value: V) => { + val (kr, vr) = f(key, value) + KeyValue.pair(kr, vr) + } + } + + implicit class FunctionFromFunction[V, VR](val f: V => VR) extends AnyVal { + def asJavaFunction: java.util.function.Function[V, VR] = (value: V) => f(value) + } + + implicit class ValueMapperFromFunction[V, VR](val f: V => VR) extends AnyVal { + def asValueMapper: ValueMapper[V, VR] = (value: V) => f(value) + } + + implicit class FlatValueMapperFromFunction[V, VR](val f: V => Iterable[VR]) extends AnyVal { + def asValueMapper: ValueMapper[V, JIterable[VR]] = (value: V) => f(value).asJava + } + + implicit class ValueMapperWithKeyFromFunction[K, V, VR](val f: (K, V) => VR) extends AnyVal { + def asValueMapperWithKey: ValueMapperWithKey[K, V, VR] = (readOnlyKey: K, value: V) => f(readOnlyKey, value) + } + + implicit class FlatValueMapperWithKeyFromFunction[K, V, VR](val f: (K, V) => Iterable[VR]) extends AnyVal { + def asValueMapperWithKey: ValueMapperWithKey[K, V, JIterable[VR]] = + (readOnlyKey: K, value: V) => f(readOnlyKey, value).asJava + } + + implicit class AggregatorFromFunction[K, V, VA](val f: (K, V, VA) => VA) extends AnyVal { + def asAggregator: Aggregator[K, V, VA] = (key: K, value: V, aggregate: VA) => f(key, value, aggregate) + } + + implicit class MergerFromFunction[K, VR](val f: (K, VR, VR) => VR) extends AnyVal { + def asMerger: Merger[K, VR] = (aggKey: K, aggOne: VR, aggTwo: VR) => f(aggKey, aggOne, aggTwo) + } + + implicit class ReducerFromFunction[V](val f: (V, V) => V) extends AnyVal { + def asReducer: Reducer[V] = (value1: V, value2: V) => f(value1, value2) + } + + implicit class InitializerFromFunction[VA](val f: () => VA) extends AnyVal { + def asInitializer: Initializer[VA] = () => f() + } + + implicit class TransformerSupplierFromFunction[K, V, VO](val f: () => Transformer[K, V, VO]) extends AnyVal { + def asTransformerSupplier: TransformerSupplier[K, V, VO] = () => f() + } + + implicit class TransformerSupplierAsJava[K, V, VO](val supplier: TransformerSupplier[K, V, Iterable[VO]]) + extends AnyVal { + def asJava: TransformerSupplier[K, V, JIterable[VO]] = () 
=> { + val innerTransformer = supplier.get() + new Transformer[K, V, JIterable[VO]] { + override def transform(key: K, value: V): JIterable[VO] = innerTransformer.transform(key, value).asJava + override def init(context: ProcessorContext): Unit = innerTransformer.init(context) + override def close(): Unit = innerTransformer.close() + } + } + } + implicit class ValueTransformerSupplierAsJava[V, VO](val supplier: ValueTransformerSupplier[V, Iterable[VO]]) + extends AnyVal { + def asJava: ValueTransformerSupplier[V, JIterable[VO]] = () => { + val innerTransformer = supplier.get() + new ValueTransformer[V, JIterable[VO]] { + override def transform(value: V): JIterable[VO] = innerTransformer.transform(value).asJava + override def init(context: ProcessorContext): Unit = innerTransformer.init(context) + override def close(): Unit = innerTransformer.close() + } + } + } + implicit class ValueTransformerSupplierWithKeyAsJava[K, V, VO]( + val supplier: ValueTransformerWithKeySupplier[K, V, Iterable[VO]] + ) extends AnyVal { + def asJava: ValueTransformerWithKeySupplier[K, V, JIterable[VO]] = () => { + val innerTransformer = supplier.get() + new ValueTransformerWithKey[K, V, JIterable[VO]] { + override def transform(key: K, value: V): JIterable[VO] = innerTransformer.transform(key, value).asJava + override def init(context: ProcessorContext): Unit = innerTransformer.init(context) + override def close(): Unit = innerTransformer.close() + } + } + } +} diff --git a/streams/streams-scala/src/main/scala/org/apache/kafka/streams/scala/ImplicitConversions.scala b/streams/streams-scala/src/main/scala/org/apache/kafka/streams/scala/ImplicitConversions.scala new file mode 100644 index 0000000..5f7064b --- /dev/null +++ b/streams/streams-scala/src/main/scala/org/apache/kafka/streams/scala/ImplicitConversions.scala @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.scala + +import org.apache.kafka.common.serialization.Serde +import org.apache.kafka.streams.KeyValue +import org.apache.kafka.streams.kstream.{ + KStream => KStreamJ, + KGroupedStream => KGroupedStreamJ, + TimeWindowedKStream => TimeWindowedKStreamJ, + SessionWindowedKStream => SessionWindowedKStreamJ, + CogroupedKStream => CogroupedKStreamJ, + TimeWindowedCogroupedKStream => TimeWindowedCogroupedKStreamJ, + SessionWindowedCogroupedKStream => SessionWindowedCogroupedKStreamJ, + KTable => KTableJ, + KGroupedTable => KGroupedTableJ +} +import org.apache.kafka.streams.processor.StateStore +import org.apache.kafka.streams.scala.kstream._ + +/** + * Implicit conversions between the Scala wrapper objects and the underlying Java + * objects. 
+ */ +object ImplicitConversions { + + implicit def wrapKStream[K, V](inner: KStreamJ[K, V]): KStream[K, V] = + new KStream[K, V](inner) + + implicit def wrapKGroupedStream[K, V](inner: KGroupedStreamJ[K, V]): KGroupedStream[K, V] = + new KGroupedStream[K, V](inner) + + implicit def wrapTimeWindowedKStream[K, V](inner: TimeWindowedKStreamJ[K, V]): TimeWindowedKStream[K, V] = + new TimeWindowedKStream[K, V](inner) + + implicit def wrapSessionWindowedKStream[K, V](inner: SessionWindowedKStreamJ[K, V]): SessionWindowedKStream[K, V] = + new SessionWindowedKStream[K, V](inner) + + implicit def wrapCogroupedKStream[K, V](inner: CogroupedKStreamJ[K, V]): CogroupedKStream[K, V] = + new CogroupedKStream[K, V](inner) + + implicit def wrapTimeWindowedCogroupedKStream[K, V]( + inner: TimeWindowedCogroupedKStreamJ[K, V] + ): TimeWindowedCogroupedKStream[K, V] = + new TimeWindowedCogroupedKStream[K, V](inner) + + implicit def wrapSessionWindowedCogroupedKStream[K, V]( + inner: SessionWindowedCogroupedKStreamJ[K, V] + ): SessionWindowedCogroupedKStream[K, V] = + new SessionWindowedCogroupedKStream[K, V](inner) + + implicit def wrapKTable[K, V](inner: KTableJ[K, V]): KTable[K, V] = + new KTable[K, V](inner) + + implicit def wrapKGroupedTable[K, V](inner: KGroupedTableJ[K, V]): KGroupedTable[K, V] = + new KGroupedTable[K, V](inner) + + implicit def tuple2ToKeyValue[K, V](tuple: (K, V)): KeyValue[K, V] = new KeyValue(tuple._1, tuple._2) + + // we would also like to allow users implicit serdes + // and these implicits will convert them to `Grouped`, `Produced` or `Consumed` + + implicit def consumedFromSerde[K, V](implicit keySerde: Serde[K], valueSerde: Serde[V]): Consumed[K, V] = + Consumed.`with`[K, V] + + implicit def groupedFromSerde[K, V](implicit keySerde: Serde[K], valueSerde: Serde[V]): Grouped[K, V] = + Grouped.`with`[K, V] + + implicit def joinedFromKeyValueOtherSerde[K, V, VO](implicit + keySerde: Serde[K], + valueSerde: Serde[V], + otherValueSerde: Serde[VO] + ): Joined[K, V, VO] = + Joined.`with`[K, V, VO] + + implicit def materializedFromSerde[K, V, S <: StateStore](implicit + keySerde: Serde[K], + valueSerde: Serde[V] + ): Materialized[K, V, S] = + Materialized.`with`[K, V, S] + + implicit def producedFromSerde[K, V](implicit keySerde: Serde[K], valueSerde: Serde[V]): Produced[K, V] = + Produced.`with`[K, V] + + implicit def repartitionedFromSerde[K, V](implicit keySerde: Serde[K], valueSerde: Serde[V]): Repartitioned[K, V] = + Repartitioned.`with`[K, V] + + implicit def streamJoinFromKeyValueOtherSerde[K, V, VO](implicit + keySerde: Serde[K], + valueSerde: Serde[V], + otherValueSerde: Serde[VO] + ): StreamJoined[K, V, VO] = + StreamJoined.`with`[K, V, VO] +} diff --git a/streams/streams-scala/src/main/scala/org/apache/kafka/streams/scala/Serdes.scala b/streams/streams-scala/src/main/scala/org/apache/kafka/streams/scala/Serdes.scala new file mode 100644 index 0000000..2e42090 --- /dev/null +++ b/streams/streams-scala/src/main/scala/org/apache/kafka/streams/scala/Serdes.scala @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.scala + +import java.util + +import org.apache.kafka.common.serialization.{Deserializer, Serde, Serdes => JSerdes, Serializer} +import org.apache.kafka.streams.kstream.WindowedSerdes + +@deprecated( + "Use org.apache.kafka.streams.scala.serialization.Serdes. For WindowedSerdes.TimeWindowedSerde, use explicit constructors.", + "2.7.0" +) +object Serdes { + implicit def String: Serde[String] = JSerdes.String() + implicit def Long: Serde[Long] = JSerdes.Long().asInstanceOf[Serde[Long]] + implicit def JavaLong: Serde[java.lang.Long] = JSerdes.Long() + implicit def ByteArray: Serde[Array[Byte]] = JSerdes.ByteArray() + implicit def Bytes: Serde[org.apache.kafka.common.utils.Bytes] = JSerdes.Bytes() + implicit def Float: Serde[Float] = JSerdes.Float().asInstanceOf[Serde[Float]] + implicit def JavaFloat: Serde[java.lang.Float] = JSerdes.Float() + implicit def Double: Serde[Double] = JSerdes.Double().asInstanceOf[Serde[Double]] + implicit def JavaDouble: Serde[java.lang.Double] = JSerdes.Double() + implicit def Integer: Serde[Int] = JSerdes.Integer().asInstanceOf[Serde[Int]] + implicit def JavaInteger: Serde[java.lang.Integer] = JSerdes.Integer() + + implicit def timeWindowedSerde[T](implicit tSerde: Serde[T]): WindowedSerdes.TimeWindowedSerde[T] = + new WindowedSerdes.TimeWindowedSerde[T](tSerde) + + implicit def sessionWindowedSerde[T](implicit tSerde: Serde[T]): WindowedSerdes.SessionWindowedSerde[T] = + new WindowedSerdes.SessionWindowedSerde[T](tSerde) + + def fromFn[T >: Null](serializer: T => Array[Byte], deserializer: Array[Byte] => Option[T]): Serde[T] = + JSerdes.serdeFrom( + new Serializer[T] { + override def serialize(topic: String, data: T): Array[Byte] = serializer(data) + override def configure(configs: util.Map[String, _], isKey: Boolean): Unit = () + override def close(): Unit = () + }, + new Deserializer[T] { + override def deserialize(topic: String, data: Array[Byte]): T = deserializer(data).orNull + override def configure(configs: util.Map[String, _], isKey: Boolean): Unit = () + override def close(): Unit = () + } + ) + + def fromFn[T >: Null]( + serializer: (String, T) => Array[Byte], + deserializer: (String, Array[Byte]) => Option[T] + ): Serde[T] = + JSerdes.serdeFrom( + new Serializer[T] { + override def serialize(topic: String, data: T): Array[Byte] = serializer(topic, data) + override def configure(configs: util.Map[String, _], isKey: Boolean): Unit = () + override def close(): Unit = () + }, + new Deserializer[T] { + override def deserialize(topic: String, data: Array[Byte]): T = deserializer(topic, data).orNull + override def configure(configs: util.Map[String, _], isKey: Boolean): Unit = () + override def close(): Unit = () + } + ) +} diff --git a/streams/streams-scala/src/main/scala/org/apache/kafka/streams/scala/StreamsBuilder.scala b/streams/streams-scala/src/main/scala/org/apache/kafka/streams/scala/StreamsBuilder.scala new file mode 100644 index 0000000..9430a51 --- /dev/null +++ b/streams/streams-scala/src/main/scala/org/apache/kafka/streams/scala/StreamsBuilder.scala @@ -0,0 +1,216 @@ +/* + * Licensed to the 
Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.scala + +import java.util.Properties +import java.util.regex.Pattern + +import org.apache.kafka.streams.kstream.GlobalKTable +import org.apache.kafka.streams.processor.{ProcessorSupplier, StateStore} +import org.apache.kafka.streams.state.StoreBuilder +import org.apache.kafka.streams.{StreamsBuilder => StreamsBuilderJ, Topology} +import org.apache.kafka.streams.scala.kstream.{Consumed, KStream, KTable, Materialized} + +import scala.jdk.CollectionConverters._ + +/** + * Wraps the Java class StreamsBuilder and delegates method calls to the underlying Java object. + */ +class StreamsBuilder(inner: StreamsBuilderJ = new StreamsBuilderJ) { + + /** + * Create a [[kstream.KStream]] from the specified topic. + *
                + * The `implicit Consumed` instance provides the values of `auto.offset.reset` strategy, `TimestampExtractor`, + * key and value deserializers etc. If the implicit is not found in scope, compiler error will result. + *
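+ * For illustration, a minimal sketch of supplying the `Consumed` explicitly instead; the topic name,
+ * the `WallclockTimestampExtractor` and the reset policy below are arbitrary choices, not requirements
+ * of this API:
+ * {{{
+ * import org.apache.kafka.streams.Topology
+ * import org.apache.kafka.streams.processor.WallclockTimestampExtractor
+ * import org.apache.kafka.streams.scala.StreamsBuilder
+ * import org.apache.kafka.streams.scala.kstream.{Consumed, KStream}
+ * import org.apache.kafka.streams.scala.serialization.Serdes._
+ *
+ * val builder = new StreamsBuilder()
+ *
+ * // Consumed built from the implicit String serdes plus an explicit extractor and reset policy
+ * implicit val consumed: Consumed[String, String] =
+ *   Consumed.`with`[String, String](new WallclockTimestampExtractor, Topology.AutoOffsetReset.EARLIEST)
+ *
+ * val userClicksStream: KStream[String, String] = builder.stream[String, String]("user-clicks")
+ * }}}
+ *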
                + * A convenient alternative is to have the necessary implicit serdes in scope, which will be implicitly + * converted to generate an instance of `Consumed`. @see [[ImplicitConversions]]. + * {{{ + * // Brings all implicit conversions in scope + * import ImplicitConversions._ + * + * // Bring implicit default serdes in scope + * import Serdes._ + * + * val builder = new StreamsBuilder() + * + * // stream function gets the implicit Consumed which is constructed automatically + * // from the serdes through the implicits in ImplicitConversions#consumedFromSerde + * val userClicksStream: KStream[String, Long] = builder.stream(userClicksTopic) + * }}} + * + * @param topic the topic name + * @return a [[kstream.KStream]] for the specified topic + */ + def stream[K, V](topic: String)(implicit consumed: Consumed[K, V]): KStream[K, V] = + new KStream(inner.stream[K, V](topic, consumed)) + + /** + * Create a [[kstream.KStream]] from the specified topics. + * + * @param topics the topic names + * @return a [[kstream.KStream]] for the specified topics + * @see #stream(String) + * @see `org.apache.kafka.streams.StreamsBuilder#stream` + */ + def stream[K, V](topics: Set[String])(implicit consumed: Consumed[K, V]): KStream[K, V] = + new KStream(inner.stream[K, V](topics.asJava, consumed)) + + /** + * Create a [[kstream.KStream]] from the specified topic pattern. + * + * @param topicPattern the topic name pattern + * @return a [[kstream.KStream]] for the specified topics + * @see #stream(String) + * @see `org.apache.kafka.streams.StreamsBuilder#stream` + */ + def stream[K, V](topicPattern: Pattern)(implicit consumed: Consumed[K, V]): KStream[K, V] = + new KStream(inner.stream[K, V](topicPattern, consumed)) + + /** + * Create a [[kstream.KTable]] from the specified topic. + *
                + * The `implicit Consumed` instance provides the values of `auto.offset.reset` strategy, `TimestampExtractor`, + * key and value deserializers etc. If the implicit is not found in scope, compiler error will result. + *
                + * A convenient alternative is to have the necessary implicit serdes in scope, which will be implicitly + * converted to generate an instance of `Consumed`. @see [[ImplicitConversions]]. + * {{{ + * // Brings all implicit conversions in scope + * import ImplicitConversions._ + * + * // Bring implicit default serdes in scope + * import Serdes._ + * + * val builder = new StreamsBuilder() + * + * // stream function gets the implicit Consumed which is constructed automatically + * // from the serdes through the implicits in ImplicitConversions#consumedFromSerde + * val userClicksStream: KTable[String, Long] = builder.table(userClicksTopic) + * }}} + * + * @param topic the topic name + * @return a [[kstream.KTable]] for the specified topic + * @see `org.apache.kafka.streams.StreamsBuilder#table` + */ + def table[K, V](topic: String)(implicit consumed: Consumed[K, V]): KTable[K, V] = + new KTable(inner.table[K, V](topic, consumed)) + + /** + * Create a [[kstream.KTable]] from the specified topic. + * + * @param topic the topic name + * @param materialized the instance of `Materialized` used to materialize a state store + * @return a [[kstream.KTable]] for the specified topic + * @see #table(String) + * @see `org.apache.kafka.streams.StreamsBuilder#table` + */ + def table[K, V](topic: String, materialized: Materialized[K, V, ByteArrayKeyValueStore])(implicit + consumed: Consumed[K, V] + ): KTable[K, V] = + new KTable(inner.table[K, V](topic, consumed, materialized)) + + /** + * Create a `GlobalKTable` from the specified topic. The serializers from the implicit `Consumed` + * instance will be used. Input records with `null` key will be dropped. + * + * @param topic the topic name + * @return a `GlobalKTable` for the specified topic + * @see `org.apache.kafka.streams.StreamsBuilder#globalTable` + */ + def globalTable[K, V](topic: String)(implicit consumed: Consumed[K, V]): GlobalKTable[K, V] = + inner.globalTable(topic, consumed) + + /** + * Create a `GlobalKTable` from the specified topic. The resulting `GlobalKTable` will be materialized + * in a local `KeyValueStore` configured with the provided instance of `Materialized`. The serializers + * from the implicit `Consumed` instance will be used. + * + * @param topic the topic name + * @param materialized the instance of `Materialized` used to materialize a state store + * @return a `GlobalKTable` for the specified topic + * @see `org.apache.kafka.streams.StreamsBuilder#globalTable` + */ + def globalTable[K, V](topic: String, materialized: Materialized[K, V, ByteArrayKeyValueStore])(implicit + consumed: Consumed[K, V] + ): GlobalKTable[K, V] = + inner.globalTable(topic, consumed, materialized) + + /** + * Adds a state store to the underlying `Topology`. The store must still be "connected" to a `Processor`, + * `Transformer`, or `ValueTransformer` before it can be used. + *
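+ * As a rough sketch, registering a persistent key-value store (the store name `"my-store"` and the
+ * serde choices below are placeholders):
+ * {{{
+ * import org.apache.kafka.common.serialization.Serdes
+ * import org.apache.kafka.streams.scala.StreamsBuilder
+ * import org.apache.kafka.streams.state.Stores
+ *
+ * val builder = new StreamsBuilder()
+ *
+ * // StoreBuilder for a persistent KeyValueStore[String, java.lang.Long]
+ * val storeBuilder = Stores.keyValueStoreBuilder(
+ *   Stores.persistentKeyValueStore("my-store"),
+ *   Serdes.String(),
+ *   Serdes.Long()
+ * )
+ *
+ * builder.addStateStore(storeBuilder)
+ * }}}
+ *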
                + * It is required to connect state stores to `Processor`, `Transformer`, or `ValueTransformer` before they can be used. + * + * @param builder the builder used to obtain this state store `StateStore` instance + * @return the underlying Java abstraction `StreamsBuilder` after adding the `StateStore` + * @throws org.apache.kafka.streams.errors.TopologyException if state store supplier is already added + * @see `org.apache.kafka.streams.StreamsBuilder#addStateStore` + */ + def addStateStore(builder: StoreBuilder[_ <: StateStore]): StreamsBuilderJ = inner.addStateStore(builder) + + /** + * Adds a global `StateStore` to the topology. Global stores should not be added to `Processor`, `Transformer`, + * or `ValueTransformer` (in contrast to regular stores). + *
                + * It is not required to connect a global store to `Processor`, `Transformer`, or `ValueTransformer`; + * those have read-only access to all global stores by default. + * + * @see `org.apache.kafka.streams.StreamsBuilder#addGlobalStore` + */ + @deprecated( + "Use #addGlobalStore(StoreBuilder, String, Consumed, org.apache.kafka.streams.processor.api.ProcessorSupplier) instead.", + "2.7.0" + ) + def addGlobalStore[K, V]( + storeBuilder: StoreBuilder[_ <: StateStore], + topic: String, + consumed: Consumed[K, V], + stateUpdateSupplier: ProcessorSupplier[K, V] + ): StreamsBuilderJ = + inner.addGlobalStore(storeBuilder, topic, consumed, stateUpdateSupplier) + + /** + * Adds a global `StateStore` to the topology. Global stores should not be added to `Processor`, `Transformer`, + * or `ValueTransformer` (in contrast to regular stores). + *
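+ * As an illustrative sketch only (the topic, store name and value types are made up), a global store
+ * that is kept up to date by writing every record of its source topic into the store:
+ * {{{
+ * import org.apache.kafka.common.serialization.{Serdes => JSerdes}
+ * import org.apache.kafka.streams.processor.api.{Processor, ProcessorContext, ProcessorSupplier, Record}
+ * import org.apache.kafka.streams.scala.StreamsBuilder
+ * import org.apache.kafka.streams.scala.kstream.Consumed
+ * import org.apache.kafka.streams.scala.serialization.Serdes._
+ * import org.apache.kafka.streams.state.{KeyValueStore, Stores}
+ *
+ * val builder = new StreamsBuilder()
+ *
+ * // global stores must not have changelogging enabled
+ * val storeBuilder = Stores
+ *   .keyValueStoreBuilder(Stores.inMemoryKeyValueStore("global-config"), JSerdes.String(), JSerdes.String())
+ *   .withLoggingDisabled()
+ *
+ * // processor that copies every record from the source topic into the global store
+ * val updateSupplier = new ProcessorSupplier[String, String, Void, Void] {
+ *   override def get(): Processor[String, String, Void, Void] =
+ *     new Processor[String, String, Void, Void] {
+ *       private var store: KeyValueStore[String, String] = _
+ *       override def init(context: ProcessorContext[Void, Void]): Unit =
+ *         store = context.getStateStore[KeyValueStore[String, String]]("global-config")
+ *       override def process(record: Record[String, String]): Unit =
+ *         store.put(record.key, record.value)
+ *     }
+ * }
+ *
+ * builder.addGlobalStore(storeBuilder, "config-topic", Consumed.`with`[String, String], updateSupplier)
+ * }}}
+ *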
                + * It is not required to connect a global store to `Processor`, `Transformer`, or `ValueTransformer`; + * those have read-only access to all global stores by default. + * + * @see `org.apache.kafka.streams.StreamsBuilder#addGlobalStore` + */ + def addGlobalStore[K, V]( + storeBuilder: StoreBuilder[_ <: StateStore], + topic: String, + consumed: Consumed[K, V], + stateUpdateSupplier: org.apache.kafka.streams.processor.api.ProcessorSupplier[K, V, Void, Void] + ): StreamsBuilderJ = + inner.addGlobalStore(storeBuilder, topic, consumed, stateUpdateSupplier) + + def build(): Topology = inner.build() + + /** + * Returns the `Topology` that represents the specified processing logic and accepts + * a `Properties` instance used to indicate whether to optimize topology or not. + * + * @param props the `Properties` used for building possibly optimized topology + * @return the `Topology` that represents the specified processing logic + * @see `org.apache.kafka.streams.StreamsBuilder#build` + */ + def build(props: Properties): Topology = inner.build(props) +} diff --git a/streams/streams-scala/src/main/scala/org/apache/kafka/streams/scala/kstream/Branched.scala b/streams/streams-scala/src/main/scala/org/apache/kafka/streams/scala/kstream/Branched.scala new file mode 100644 index 0000000..6ac1371 --- /dev/null +++ b/streams/streams-scala/src/main/scala/org/apache/kafka/streams/scala/kstream/Branched.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.scala.kstream + +import org.apache.kafka.streams.kstream.{Branched => BranchedJ, KStream => KStreamJ} + +object Branched { + + /** + * Create an instance of `Branched` with provided branch name suffix. + * + * @param name the branch name suffix to be used (see [[BranchedKStream]] description for details) + * @tparam K key type + * @tparam V value type + * @return a new instance of `Branched` + */ + def as[K, V](name: String): BranchedJ[K, V] = + BranchedJ.as[K, V](name) + + /** + * Create an instance of `Branched` with provided chain function and branch name suffix. + * + * @param chain A function that will be applied to the branch. If the provided function returns + * `null`, its result is ignored, otherwise it is added to the Map returned + * by [[BranchedKStream.defaultBranch]] or [[BranchedKStream.noDefaultBranch]] (see + * [[BranchedKStream]] description for details). + * @param name the branch name suffix to be used. 
If `null`, a default branch name suffix will be generated + * (see [[BranchedKStream]] description for details) + * @tparam K key type + * @tparam V value type + * @return a new instance of `Branched` + * @see `org.apache.kafka.streams.kstream.Branched#withFunction(java.util.function.Function, java.lang.String)` + */ + def withFunction[K, V](chain: KStream[K, V] => KStream[K, V], name: String = null): BranchedJ[K, V] = + BranchedJ.withFunction((f: KStreamJ[K, V]) => chain.apply(new KStream[K, V](f)).inner, name) + + /** + * Create an instance of `Branched` with provided chain consumer and branch name suffix. + * + * @param chain A consumer to which the branch will be sent. If a non-null consumer is provided here, + * the respective branch will not be added to the resulting Map returned + * by [[BranchedKStream.defaultBranch]] or [[BranchedKStream.noDefaultBranch]] (see + * [[BranchedKStream]] description for details). + * @param name the branch name suffix to be used. If `null`, a default branch name suffix will be generated + * (see [[BranchedKStream]] description for details) + * @tparam K key type + * @tparam V value type + * @return a new instance of `Branched` + * @see `org.apache.kafka.streams.kstream.Branched#withConsumer(java.util.function.Consumer, java.lang.String)` + */ + def withConsumer[K, V](chain: KStream[K, V] => Unit, name: String = null): BranchedJ[K, V] = + BranchedJ.withConsumer((c: KStreamJ[K, V]) => chain.apply(new KStream[K, V](c)), name) +} diff --git a/streams/streams-scala/src/main/scala/org/apache/kafka/streams/scala/kstream/BranchedKStream.scala b/streams/streams-scala/src/main/scala/org/apache/kafka/streams/scala/kstream/BranchedKStream.scala new file mode 100644 index 0000000..c606c00 --- /dev/null +++ b/streams/streams-scala/src/main/scala/org/apache/kafka/streams/scala/kstream/BranchedKStream.scala @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.scala.kstream + +import java.util + +import org.apache.kafka.streams.kstream +import org.apache.kafka.streams.kstream.{BranchedKStream => BranchedKStreamJ} +import org.apache.kafka.streams.scala.FunctionsCompatConversions.PredicateFromFunction + +import scala.jdk.CollectionConverters._ + +/** + * Branches the records in the original stream based on the predicates supplied for the branch definitions. + *
                + * Branches are defined with [[branch]] or [[defaultBranch]] methods. Each record is evaluated against the predicates + * supplied via [[Branched]] parameters, and is routed to the first branch for which its respective predicate + * evaluates to `true`. If a record does not match any predicates, it will be routed to the default branch, + * or dropped if no default branch is created. + *
                + * + * Each branch (which is a [[KStream]] instance) then can be processed either by + * a function or a consumer provided via a [[Branched]] + * parameter. If certain conditions are met, it also can be accessed from the `Map` returned by + * an optional [[defaultBranch]] or [[noDefaultBranch]] method call. + *
                + * The branching happens on a first match basis: A record in the original stream is assigned to the corresponding result + * stream for the first predicate that evaluates to true, and is assigned to this stream only. If you need + * to route a record to multiple streams, you can apply multiple [[KStream.filter]] operators to the same [[KStream]] + * instance, one for each predicate, instead of branching. + *
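+ * A small sketch of the intended usage; the topic, branch names and predicates below are
+ * illustrative only, and `split` is the [[KStream]] operation that produces a [[BranchedKStream]]:
+ * {{{
+ * import org.apache.kafka.streams.kstream.Named
+ * import org.apache.kafka.streams.scala.ImplicitConversions._
+ * import org.apache.kafka.streams.scala.StreamsBuilder
+ * import org.apache.kafka.streams.scala.kstream.{Branched, KStream}
+ * import org.apache.kafka.streams.scala.serialization.Serdes._
+ *
+ * val builder = new StreamsBuilder()
+ * val numbers: KStream[String, Int] = builder.stream[String, Int]("numbers")
+ *
+ * val branches: Map[String, KStream[String, Int]] =
+ *   numbers
+ *     .split(Named.as("split-"))
+ *     .branch((_, v) => v % 2 == 0, Branched.as("even"))
+ *     .branch((_, v) => v % 2 != 0, Branched.as("odd"))
+ *     .noDefaultBranch()
+ *
+ * // the resulting map contains the two branches under the keys "split-even" and "split-odd"
+ * }}}
+ *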
                + * The process of routing the records to different branches is a stateless record-by-record operation. + * + * @tparam K Type of keys + * @tparam V Type of values + */ +class BranchedKStream[K, V](val inner: BranchedKStreamJ[K, V]) { + + /** + * Define a branch for records that match the predicate. + * + * @param predicate A predicate against which each record will be evaluated. + * If this predicate returns `true` for a given record, the record will be + * routed to the current branch and will not be evaluated against the predicates + * for the remaining branches. + * @return `this` to facilitate method chaining + */ + def branch(predicate: (K, V) => Boolean): BranchedKStream[K, V] = { + inner.branch(predicate.asPredicate) + this + } + + /** + * Define a branch for records that match the predicate. + * + * @param predicate A predicate against which each record will be evaluated. + * If this predicate returns `true` for a given record, the record will be + * routed to the current branch and will not be evaluated against the predicates + * for the remaining branches. + * @param branched A [[Branched]] parameter, that allows to define a branch name, an in-place + * branch consumer or branch mapper (see code examples + * for [[BranchedKStream]]) + * @return `this` to facilitate method chaining + */ + def branch(predicate: (K, V) => Boolean, branched: Branched[K, V]): BranchedKStream[K, V] = { + inner.branch(predicate.asPredicate, branched) + this + } + + /** + * Finalize the construction of branches and defines the default branch for the messages not intercepted + * by other branches. Calling [[defaultBranch]] or [[noDefaultBranch]] is optional. + * + * @return Map of named branches. For rules of forming the resulting map, see [[BranchedKStream]] + * description. + */ + def defaultBranch(): Map[String, KStream[K, V]] = toScalaMap(inner.defaultBranch()) + + /** + * Finalize the construction of branches and defines the default branch for the messages not intercepted + * by other branches. Calling [[defaultBranch]] or [[noDefaultBranch]] is optional. + * + * @param branched A [[Branched]] parameter, that allows to define a branch name, an in-place + * branch consumer or branch mapper for [[BranchedKStream]]. + * @return Map of named branches. For rules of forming the resulting map, see [[BranchedKStream]] + * description. + */ + def defaultBranch(branched: Branched[K, V]): Map[String, KStream[K, V]] = toScalaMap(inner.defaultBranch(branched)) + + /** + * Finalizes the construction of branches without forming a default branch. + * + * @return Map of named branches. For rules of forming the resulting map, see [[BranchedKStream]] + * description. + */ + def noDefaultBranch(): Map[String, KStream[K, V]] = toScalaMap(inner.noDefaultBranch()) + + private def toScalaMap(m: util.Map[String, kstream.KStream[K, V]]): collection.immutable.Map[String, KStream[K, V]] = + m.asScala.map { case (name, kStreamJ) => + (name, new KStream(kStreamJ)) + }.toMap +} diff --git a/streams/streams-scala/src/main/scala/org/apache/kafka/streams/scala/kstream/CogroupedKStream.scala b/streams/streams-scala/src/main/scala/org/apache/kafka/streams/scala/kstream/CogroupedKStream.scala new file mode 100644 index 0000000..2bf58ca --- /dev/null +++ b/streams/streams-scala/src/main/scala/org/apache/kafka/streams/scala/kstream/CogroupedKStream.scala @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.scala +package kstream + +import org.apache.kafka.streams.kstream.{ + SessionWindows, + SlidingWindows, + Window, + Windows, + CogroupedKStream => CogroupedKStreamJ +} +import org.apache.kafka.streams.scala.FunctionsCompatConversions.{AggregatorFromFunction, InitializerFromFunction} + +/** + * Wraps the Java class CogroupedKStream and delegates method calls to the underlying Java object. + * + * @tparam KIn Type of keys + * @tparam VOut Type of values + * @param inner The underlying Java abstraction for CogroupedKStream + * @see `org.apache.kafka.streams.kstream.CogroupedKStream` + */ +class CogroupedKStream[KIn, VOut](val inner: CogroupedKStreamJ[KIn, VOut]) { + + /** + * Add an already [[KGroupedStream]] to this [[CogroupedKStream]]. + * + * @param groupedStream a group stream + * @param aggregator a function that computes a new aggregate result + * @return a [[CogroupedKStream]] + */ + def cogroup[VIn]( + groupedStream: KGroupedStream[KIn, VIn], + aggregator: (KIn, VIn, VOut) => VOut + ): CogroupedKStream[KIn, VOut] = + new CogroupedKStream(inner.cogroup(groupedStream.inner, aggregator.asAggregator)) + + /** + * Aggregate the values of records in these streams by the grouped key and defined window. + * + * @param initializer an `Initializer` that computes an initial intermediate aggregation result. + * Cannot be { @code null}. + * @param materialized an instance of `Materialized` used to materialize a state store. + * Cannot be { @code null}. + * @return a [[KTable]] that contains "update" records with unmodified keys, and values that represent the latest + * (rolling) aggregate for each key + * @see `org.apache.kafka.streams.kstream.CogroupedKStream#aggregate` + */ + def aggregate(initializer: => VOut)(implicit + materialized: Materialized[KIn, VOut, ByteArrayKeyValueStore] + ): KTable[KIn, VOut] = new KTable(inner.aggregate((() => initializer).asInitializer, materialized)) + + /** + * Aggregate the values of records in these streams by the grouped key and defined window. + * + * @param initializer an `Initializer` that computes an initial intermediate aggregation result. + * Cannot be { @code null}. + * @param named a [[Named]] config used to name the processor in the topology + * @param materialized an instance of `Materialized` used to materialize a state store. + * Cannot be { @code null}. 
+ * @return a [[KTable]] that contains "update" records with unmodified keys, and values that represent the latest + * (rolling) aggregate for each key + * @see `org.apache.kafka.streams.kstream.CogroupedKStream#aggregate` + */ + def aggregate(initializer: => VOut, named: Named)(implicit + materialized: Materialized[KIn, VOut, ByteArrayKeyValueStore] + ): KTable[KIn, VOut] = new KTable(inner.aggregate((() => initializer).asInitializer, named, materialized)) + + /** + * Create a new [[TimeWindowedCogroupedKStream]] instance that can be used to perform windowed aggregations. + * + * @param windows the specification of the aggregation `Windows` + * @return an instance of [[TimeWindowedCogroupedKStream]] + * @see `org.apache.kafka.streams.kstream.CogroupedKStream#windowedBy` + */ + def windowedBy[W <: Window](windows: Windows[W]): TimeWindowedCogroupedKStream[KIn, VOut] = + new TimeWindowedCogroupedKStream(inner.windowedBy(windows)) + + /** + * Create a new [[TimeWindowedCogroupedKStream]] instance that can be used to perform sliding windowed aggregations. + * + * @param windows the specification of the aggregation `SlidingWindows` + * @return an instance of [[TimeWindowedCogroupedKStream]] + * @see `org.apache.kafka.streams.kstream.CogroupedKStream#windowedBy` + */ + def windowedBy(windows: SlidingWindows): TimeWindowedCogroupedKStream[KIn, VOut] = + new TimeWindowedCogroupedKStream(inner.windowedBy(windows)) + + /** + * Create a new [[SessionWindowedKStream]] instance that can be used to perform session windowed aggregations. + * + * @param windows the specification of the aggregation `SessionWindows` + * @return an instance of [[SessionWindowedKStream]] + * @see `org.apache.kafka.streams.kstream.KGroupedStream#windowedBy` + */ + def windowedBy(windows: SessionWindows): SessionWindowedCogroupedKStream[KIn, VOut] = + new SessionWindowedCogroupedKStream(inner.windowedBy(windows)) + +} diff --git a/streams/streams-scala/src/main/scala/org/apache/kafka/streams/scala/kstream/Consumed.scala b/streams/streams-scala/src/main/scala/org/apache/kafka/streams/scala/kstream/Consumed.scala new file mode 100644 index 0000000..714df97 --- /dev/null +++ b/streams/streams-scala/src/main/scala/org/apache/kafka/streams/scala/kstream/Consumed.scala @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.scala.kstream + +import org.apache.kafka.common.serialization.Serde +import org.apache.kafka.streams.kstream.{Consumed => ConsumedJ} +import org.apache.kafka.streams.Topology +import org.apache.kafka.streams.processor.TimestampExtractor + +object Consumed { + + /** + * Create an instance of [[Consumed]] with the supplied arguments. `null` values are acceptable. 
+ * + * @tparam K key type + * @tparam V value type + * @param timestampExtractor the timestamp extractor to used. If `null` the default timestamp extractor from + * config will be used + * @param resetPolicy the offset reset policy to be used. If `null` the default reset policy from config + * will be used + * @param keySerde the key serde to use. + * @param valueSerde the value serde to use. + * @return a new instance of [[Consumed]] + */ + def `with`[K, V]( + timestampExtractor: TimestampExtractor, + resetPolicy: Topology.AutoOffsetReset + )(implicit keySerde: Serde[K], valueSerde: Serde[V]): ConsumedJ[K, V] = + ConsumedJ.`with`(keySerde, valueSerde, timestampExtractor, resetPolicy) + + /** + * Create an instance of [[Consumed]] with key and value [[Serde]]s. + * + * @tparam K key type + * @tparam V value type + * @return a new instance of [[Consumed]] + */ + def `with`[K, V](implicit keySerde: Serde[K], valueSerde: Serde[V]): ConsumedJ[K, V] = + ConsumedJ.`with`(keySerde, valueSerde) + + /** + * Create an instance of [[Consumed]] with a [[TimestampExtractor]]. + * + * @param timestampExtractor the timestamp extractor to used. If `null` the default timestamp extractor from + * config will be used + * @tparam K key type + * @tparam V value type + * @return a new instance of [[Consumed]] + */ + def `with`[K, V]( + timestampExtractor: TimestampExtractor + )(implicit keySerde: Serde[K], valueSerde: Serde[V]): ConsumedJ[K, V] = + ConsumedJ.`with`(timestampExtractor).withKeySerde(keySerde).withValueSerde(valueSerde) + + /** + * Create an instance of [[Consumed]] with a [[Topology.AutoOffsetReset]]. + * + * @tparam K key type + * @tparam V value type + * @param resetPolicy the offset reset policy to be used. If `null` the default reset policy from config will be used + * @return a new instance of [[Consumed]] + */ + def `with`[K, V]( + resetPolicy: Topology.AutoOffsetReset + )(implicit keySerde: Serde[K], valueSerde: Serde[V]): ConsumedJ[K, V] = + ConsumedJ.`with`(resetPolicy).withKeySerde(keySerde).withValueSerde(valueSerde) +} diff --git a/streams/streams-scala/src/main/scala/org/apache/kafka/streams/scala/kstream/Grouped.scala b/streams/streams-scala/src/main/scala/org/apache/kafka/streams/scala/kstream/Grouped.scala new file mode 100644 index 0000000..03dde16 --- /dev/null +++ b/streams/streams-scala/src/main/scala/org/apache/kafka/streams/scala/kstream/Grouped.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.scala.kstream + +import org.apache.kafka.common.serialization.Serde +import org.apache.kafka.streams.kstream.{Grouped => GroupedJ} + +object Grouped { + + /** + * Construct a `Grouped` instance with the provided key and value [[Serde]]s. 
+ * If the [[Serde]] params are `null` the default serdes defined in the configs will be used. + * + * @tparam K the key type + * @tparam V the value type + * @param keySerde keySerde that will be used to materialize a stream + * @param valueSerde valueSerde that will be used to materialize a stream + * @return a new instance of [[Grouped]] configured with the provided serdes + */ + def `with`[K, V](implicit keySerde: Serde[K], valueSerde: Serde[V]): GroupedJ[K, V] = + GroupedJ.`with`(keySerde, valueSerde) + + /** + * Construct a `Grouped` instance with the provided key and value [[Serde]]s. + * If the [[Serde]] params are `null` the default serdes defined in the configs will be used. + * + * @tparam K the key type + * @tparam V the value type + * @param name the name used as part of a potential repartition topic + * @param keySerde keySerde that will be used to materialize a stream + * @param valueSerde valueSerde that will be used to materialize a stream + * @return a new instance of [[Grouped]] configured with the provided serdes + */ + def `with`[K, V](name: String)(implicit keySerde: Serde[K], valueSerde: Serde[V]): GroupedJ[K, V] = + GroupedJ.`with`(name, keySerde, valueSerde) + +} diff --git a/streams/streams-scala/src/main/scala/org/apache/kafka/streams/scala/kstream/Joined.scala b/streams/streams-scala/src/main/scala/org/apache/kafka/streams/scala/kstream/Joined.scala new file mode 100644 index 0000000..c614e14 --- /dev/null +++ b/streams/streams-scala/src/main/scala/org/apache/kafka/streams/scala/kstream/Joined.scala @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.scala.kstream + +import org.apache.kafka.common.serialization.Serde +import org.apache.kafka.streams.kstream.{Joined => JoinedJ} + +object Joined { + + /** + * Create an instance of [[org.apache.kafka.streams.kstream.Joined]] with key, value, and otherValue [[Serde]] + * instances. + * `null` values are accepted and will be replaced by the default serdes as defined in config. + * + * @tparam K key type + * @tparam V value type + * @tparam VO other value type + * @param keySerde the key serde to use. + * @param valueSerde the value serde to use. + * @param otherValueSerde the otherValue serde to use. If `null` the default value serde from config will be used + * @return new [[org.apache.kafka.streams.kstream.Joined]] instance with the provided serdes + */ + def `with`[K, V, VO](implicit + keySerde: Serde[K], + valueSerde: Serde[V], + otherValueSerde: Serde[VO] + ): JoinedJ[K, V, VO] = + JoinedJ.`with`(keySerde, valueSerde, otherValueSerde) + + /** + * Create an instance of [[org.apache.kafka.streams.kstream.Joined]] with key, value, and otherValue [[Serde]] + * instances. 
+ * `null` values are accepted and will be replaced by the default serdes as defined in config. + * + * @tparam K key type + * @tparam V value type + * @tparam VO other value type + * @param name name of possible repartition topic + * @param keySerde the key serde to use. + * @param valueSerde the value serde to use. + * @param otherValueSerde the otherValue serde to use. If `null` the default value serde from config will be used + * @return new [[org.apache.kafka.streams.kstream.Joined]] instance with the provided serdes + */ + // disable spotless scala, which wants to make a mess of the argument lists + // format: off + def `with`[K, V, VO](name: String) + (implicit keySerde: Serde[K], + valueSerde: Serde[V], + otherValueSerde: Serde[VO]): JoinedJ[K, V, VO] = + JoinedJ.`with`(keySerde, valueSerde, otherValueSerde, name) + // format:on +} diff --git a/streams/streams-scala/src/main/scala/org/apache/kafka/streams/scala/kstream/KGroupedStream.scala b/streams/streams-scala/src/main/scala/org/apache/kafka/streams/scala/kstream/KGroupedStream.scala new file mode 100644 index 0000000..60a9c57 --- /dev/null +++ b/streams/streams-scala/src/main/scala/org/apache/kafka/streams/scala/kstream/KGroupedStream.scala @@ -0,0 +1,190 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.scala +package kstream + +import org.apache.kafka.streams.kstream.internals.KTableImpl +import org.apache.kafka.streams.scala.serialization.Serdes +import org.apache.kafka.streams.kstream.{ + SessionWindows, + SlidingWindows, + Window, + Windows, + KGroupedStream => KGroupedStreamJ, + KTable => KTableJ +} +import org.apache.kafka.streams.scala.FunctionsCompatConversions.{ + AggregatorFromFunction, + InitializerFromFunction, + ReducerFromFunction, + ValueMapperFromFunction +} + +/** + * Wraps the Java class KGroupedStream and delegates method calls to the underlying Java object. + * + * @tparam K Type of keys + * @tparam V Type of values + * @param inner The underlying Java abstraction for KGroupedStream + * @see `org.apache.kafka.streams.kstream.KGroupedStream` + */ +class KGroupedStream[K, V](val inner: KGroupedStreamJ[K, V]) { + + /** + * Count the number of records in this stream by the grouped key. + * The result is written into a local `KeyValueStore` (which is basically an ever-updating materialized view) + * provided by the given `materialized`. + * + * @param materialized an instance of `Materialized` used to materialize a state store. 
+ * @return a [[KTable]] that contains "update" records with unmodified keys and `Long` values that + * represent the latest (rolling) count (i.e., number of records) for each key + * @see `org.apache.kafka.streams.kstream.KGroupedStream#count` + */ + def count()(implicit materialized: Materialized[K, Long, ByteArrayKeyValueStore]): KTable[K, Long] = { + val javaCountTable: KTableJ[K, java.lang.Long] = + inner.count(materialized.asInstanceOf[Materialized[K, java.lang.Long, ByteArrayKeyValueStore]]) + val tableImpl = javaCountTable.asInstanceOf[KTableImpl[K, ByteArrayKeyValueStore, java.lang.Long]] + new KTable( + javaCountTable.mapValues[Long]( + ((l: java.lang.Long) => Long2long(l)).asValueMapper, + Materialized.`with`[K, Long, ByteArrayKeyValueStore](tableImpl.keySerde(), Serdes.longSerde) + ) + ) + } + + /** + * Count the number of records in this stream by the grouped key. + * The result is written into a local `KeyValueStore` (which is basically an ever-updating materialized view) + * provided by the given `materialized`. + * + * @param named a [[Named]] config used to name the processor in the topology + * @param materialized an instance of `Materialized` used to materialize a state store. + * @return a [[KTable]] that contains "update" records with unmodified keys and `Long` values that + * represent the latest (rolling) count (i.e., number of records) for each key + * @see `org.apache.kafka.streams.kstream.KGroupedStream#count` + */ + def count(named: Named)(implicit materialized: Materialized[K, Long, ByteArrayKeyValueStore]): KTable[K, Long] = { + val javaCountTable: KTableJ[K, java.lang.Long] = + inner.count(named, materialized.asInstanceOf[Materialized[K, java.lang.Long, ByteArrayKeyValueStore]]) + val tableImpl = javaCountTable.asInstanceOf[KTableImpl[K, ByteArrayKeyValueStore, java.lang.Long]] + new KTable( + javaCountTable.mapValues[Long]( + ((l: java.lang.Long) => Long2long(l)).asValueMapper, + Materialized.`with`[K, Long, ByteArrayKeyValueStore](tableImpl.keySerde(), Serdes.longSerde) + ) + ) + } + + /** + * Combine the values of records in this stream by the grouped key. + * + * @param reducer a function `(V, V) => V` that computes a new aggregate result. + * @param materialized an instance of `Materialized` used to materialize a state store. + * @return a [[KTable]] that contains "update" records with unmodified keys, and values that represent the + * latest (rolling) aggregate for each key + * @see `org.apache.kafka.streams.kstream.KGroupedStream#reduce` + */ + def reduce(reducer: (V, V) => V)(implicit materialized: Materialized[K, V, ByteArrayKeyValueStore]): KTable[K, V] = + new KTable(inner.reduce(reducer.asReducer, materialized)) + + /** + * Combine the values of records in this stream by the grouped key. + * + * @param reducer a function `(V, V) => V` that computes a new aggregate result. + * @param named a [[Named]] config used to name the processor in the topology + * @param materialized an instance of `Materialized` used to materialize a state store. + * @return a [[KTable]] that contains "update" records with unmodified keys, and values that represent the + * latest (rolling) aggregate for each key + * @see `org.apache.kafka.streams.kstream.KGroupedStream#reduce` + */ + def reduce(reducer: (V, V) => V, named: Named)(implicit + materialized: Materialized[K, V, ByteArrayKeyValueStore] + ): KTable[K, V] = + new KTable(inner.reduce(reducer.asReducer, materialized)) + + /** + * Aggregate the values of records in this stream by the grouped key. 
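+ * For example, a minimal sketch that sums up the value lengths per key; the topic name and the
+ * aggregation logic are illustrative, and the implicit serdes come from the imports shown:
+ * {{{
+ * import org.apache.kafka.streams.scala.ImplicitConversions._
+ * import org.apache.kafka.streams.scala.StreamsBuilder
+ * import org.apache.kafka.streams.scala.kstream.{KStream, KTable}
+ * import org.apache.kafka.streams.scala.serialization.Serdes._
+ *
+ * val builder = new StreamsBuilder()
+ * val words: KStream[String, String] = builder.stream[String, String]("words")
+ *
+ * // rolling total of characters seen per key
+ * val totalLength: KTable[String, Long] =
+ *   words.groupByKey.aggregate(0L)((_, word, agg) => agg + word.length)
+ * }}}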
+ * + * @param initializer an `Initializer` that computes an initial intermediate aggregation result + * @param aggregator an `Aggregator` that computes a new aggregate result + * @param materialized an instance of `Materialized` used to materialize a state store. + * @return a [[KTable]] that contains "update" records with unmodified keys, and values that represent the + * latest (rolling) aggregate for each key + * @see `org.apache.kafka.streams.kstream.KGroupedStream#aggregate` + */ + def aggregate[VR](initializer: => VR)(aggregator: (K, V, VR) => VR)(implicit + materialized: Materialized[K, VR, ByteArrayKeyValueStore] + ): KTable[K, VR] = + new KTable(inner.aggregate((() => initializer).asInitializer, aggregator.asAggregator, materialized)) + + /** + * Aggregate the values of records in this stream by the grouped key. + * + * @param initializer an `Initializer` that computes an initial intermediate aggregation result + * @param aggregator an `Aggregator` that computes a new aggregate result + * @param named a [[Named]] config used to name the processor in the topology + * @param materialized an instance of `Materialized` used to materialize a state store. + * @return a [[KTable]] that contains "update" records with unmodified keys, and values that represent the + * latest (rolling) aggregate for each key + * @see `org.apache.kafka.streams.kstream.KGroupedStream#aggregate` + */ + def aggregate[VR](initializer: => VR, named: Named)(aggregator: (K, V, VR) => VR)(implicit + materialized: Materialized[K, VR, ByteArrayKeyValueStore] + ): KTable[K, VR] = + new KTable(inner.aggregate((() => initializer).asInitializer, aggregator.asAggregator, named, materialized)) + + /** + * Create a new [[TimeWindowedKStream]] instance that can be used to perform windowed aggregations. + * + * @param windows the specification of the aggregation `Windows` + * @return an instance of [[TimeWindowedKStream]] + * @see `org.apache.kafka.streams.kstream.KGroupedStream#windowedBy` + */ + def windowedBy[W <: Window](windows: Windows[W]): TimeWindowedKStream[K, V] = + new TimeWindowedKStream(inner.windowedBy(windows)) + + /** + * Create a new [[TimeWindowedKStream]] instance that can be used to perform sliding windowed aggregations. + * + * @param windows the specification of the aggregation `SlidingWindows` + * @return an instance of [[TimeWindowedKStream]] + * @see `org.apache.kafka.streams.kstream.KGroupedStream#windowedBy` + */ + def windowedBy(windows: SlidingWindows): TimeWindowedKStream[K, V] = + new TimeWindowedKStream(inner.windowedBy(windows)) + + /** + * Create a new [[SessionWindowedKStream]] instance that can be used to perform session windowed aggregations. + * + * @param windows the specification of the aggregation `SessionWindows` + * @return an instance of [[SessionWindowedKStream]] + * @see `org.apache.kafka.streams.kstream.KGroupedStream#windowedBy` + */ + def windowedBy(windows: SessionWindows): SessionWindowedKStream[K, V] = + new SessionWindowedKStream(inner.windowedBy(windows)) + + /** + * Create a new [[CogroupedKStream]] from this grouped KStream to allow cogrouping other [[KGroupedStream]] to it. 
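+ * A rough sketch of cogrouping two input streams into one aggregate; the topic names are
+ * illustrative and the implicit serdes for `String` and `Long` come from the imports shown:
+ * {{{
+ * import org.apache.kafka.streams.scala.ImplicitConversions._
+ * import org.apache.kafka.streams.scala.StreamsBuilder
+ * import org.apache.kafka.streams.scala.kstream.{KGroupedStream, KTable}
+ * import org.apache.kafka.streams.scala.serialization.Serdes._
+ *
+ * val builder = new StreamsBuilder()
+ * val clicks: KGroupedStream[String, Long] = builder.stream[String, Long]("clicks").groupByKey
+ * val views: KGroupedStream[String, Long] = builder.stream[String, Long]("views").groupByKey
+ *
+ * // clicks and views are summed into a single rolling total per key
+ * val activity: KTable[String, Long] =
+ *   clicks
+ *     .cogroup[Long]((_, v, agg) => agg + v)
+ *     .cogroup(views, (_, v, agg) => agg + v)
+ *     .aggregate(0L)
+ * }}}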
+ * + * @param aggregator an `Aggregator` that computes a new aggregate result + * @return an instance of [[CogroupedKStream]] + * @see `org.apache.kafka.streams.kstream.KGroupedStream#cogroup` + */ + def cogroup[VR](aggregator: (K, V, VR) => VR): CogroupedKStream[K, VR] = + new CogroupedKStream(inner.cogroup(aggregator.asAggregator)) + +} diff --git a/streams/streams-scala/src/main/scala/org/apache/kafka/streams/scala/kstream/KGroupedTable.scala b/streams/streams-scala/src/main/scala/org/apache/kafka/streams/scala/kstream/KGroupedTable.scala new file mode 100644 index 0000000..3d9e052 --- /dev/null +++ b/streams/streams-scala/src/main/scala/org/apache/kafka/streams/scala/kstream/KGroupedTable.scala @@ -0,0 +1,145 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.scala +package kstream + +import org.apache.kafka.streams.kstream.{KGroupedTable => KGroupedTableJ} +import org.apache.kafka.streams.scala.FunctionsCompatConversions.{ + AggregatorFromFunction, + InitializerFromFunction, + ReducerFromFunction +} + +/** + * Wraps the Java class KGroupedTable and delegates method calls to the underlying Java object. + * + * @tparam K Type of keys + * @tparam V Type of values + * @param inner The underlying Java abstraction for KGroupedTable + * @see `org.apache.kafka.streams.kstream.KGroupedTable` + */ +class KGroupedTable[K, V](inner: KGroupedTableJ[K, V]) { + + /** + * Count number of records of the original [[KTable]] that got [[KTable#groupBy]] to + * the same key into a new instance of [[KTable]]. + * + * @param materialized an instance of `Materialized` used to materialize a state store. + * @return a [[KTable]] that contains "update" records with unmodified keys and `Long` values that + * represent the latest (rolling) count (i.e., number of records) for each key + * @see `org.apache.kafka.streams.kstream.KGroupedTable#count` + */ + def count()(implicit materialized: Materialized[K, Long, ByteArrayKeyValueStore]): KTable[K, Long] = { + val c: KTable[K, java.lang.Long] = + new KTable(inner.count(materialized.asInstanceOf[Materialized[K, java.lang.Long, ByteArrayKeyValueStore]])) + c.mapValues[Long](Long2long _) + } + + /** + * Count number of records of the original [[KTable]] that got [[KTable#groupBy]] to + * the same key into a new instance of [[KTable]]. + * + * @param named a [[Named]] config used to name the processor in the topology + * @param materialized an instance of `Materialized` used to materialize a state store. 
+ * @return a [[KTable]] that contains "update" records with unmodified keys and `Long` values that + * represent the latest (rolling) count (i.e., number of records) for each key + * @see `org.apache.kafka.streams.kstream.KGroupedTable#count` + */ + def count(named: Named)(implicit materialized: Materialized[K, Long, ByteArrayKeyValueStore]): KTable[K, Long] = { + val c: KTable[K, java.lang.Long] = + new KTable(inner.count(named, materialized.asInstanceOf[Materialized[K, java.lang.Long, ByteArrayKeyValueStore]])) + c.mapValues[Long](Long2long _) + } + + /** + * Combine the value of records of the original [[KTable]] that got [[KTable#groupBy]] + * to the same key into a new instance of [[KTable]]. + * + * @param adder a function that adds a new value to the aggregate result + * @param subtractor a function that removed an old value from the aggregate result + * @param materialized an instance of `Materialized` used to materialize a state store. + * @return a [[KTable]] that contains "update" records with unmodified keys, and values that represent the + * latest (rolling) aggregate for each key + * @see `org.apache.kafka.streams.kstream.KGroupedTable#reduce` + */ + def reduce(adder: (V, V) => V, subtractor: (V, V) => V)(implicit + materialized: Materialized[K, V, ByteArrayKeyValueStore] + ): KTable[K, V] = + new KTable(inner.reduce(adder.asReducer, subtractor.asReducer, materialized)) + + /** + * Combine the value of records of the original [[KTable]] that got [[KTable#groupBy]] + * to the same key into a new instance of [[KTable]]. + * + * @param adder a function that adds a new value to the aggregate result + * @param subtractor a function that removed an old value from the aggregate result + * @param named a [[Named]] config used to name the processor in the topology + * @param materialized an instance of `Materialized` used to materialize a state store. + * @return a [[KTable]] that contains "update" records with unmodified keys, and values that represent the + * latest (rolling) aggregate for each key + * @see `org.apache.kafka.streams.kstream.KGroupedTable#reduce` + */ + def reduce(adder: (V, V) => V, subtractor: (V, V) => V, named: Named)(implicit + materialized: Materialized[K, V, ByteArrayKeyValueStore] + ): KTable[K, V] = + new KTable(inner.reduce(adder.asReducer, subtractor.asReducer, named, materialized)) + + /** + * Aggregate the value of records of the original [[KTable]] that got [[KTable#groupBy]] + * to the same key into a new instance of [[KTable]] using default serializers and deserializers. + * + * @param initializer a function that provides an initial aggregate result value + * @param adder a function that adds a new record to the aggregate result + * @param subtractor an aggregator function that removed an old record from the aggregate result + * @param materialized an instance of `Materialized` used to materialize a state store. 
+ * @return a [[KTable]] that contains "update" records with unmodified keys, and values that represent the + * latest (rolling) aggregate for each key + * @see `org.apache.kafka.streams.kstream.KGroupedTable#aggregate` + */ + def aggregate[VR](initializer: => VR)(adder: (K, V, VR) => VR, subtractor: (K, V, VR) => VR)(implicit + materialized: Materialized[K, VR, ByteArrayKeyValueStore] + ): KTable[K, VR] = + new KTable( + inner.aggregate((() => initializer).asInitializer, adder.asAggregator, subtractor.asAggregator, materialized) + ) + + /** + * Aggregate the value of records of the original [[KTable]] that got [[KTable#groupBy]] + * to the same key into a new instance of [[KTable]] using default serializers and deserializers. + * + * @param initializer a function that provides an initial aggregate result value + * @param named a [[Named]] config used to name the processor in the topology + * @param adder a function that adds a new record to the aggregate result + * @param subtractor an aggregator function that removed an old record from the aggregate result + * @param materialized an instance of `Materialized` used to materialize a state store. + * @return a [[KTable]] that contains "update" records with unmodified keys, and values that represent the + * latest (rolling) aggregate for each key + * @see `org.apache.kafka.streams.kstream.KGroupedTable#aggregate` + */ + def aggregate[VR](initializer: => VR, named: Named)(adder: (K, V, VR) => VR, subtractor: (K, V, VR) => VR)(implicit + materialized: Materialized[K, VR, ByteArrayKeyValueStore] + ): KTable[K, VR] = + new KTable( + inner.aggregate( + (() => initializer).asInitializer, + adder.asAggregator, + subtractor.asAggregator, + named, + materialized + ) + ) +} diff --git a/streams/streams-scala/src/main/scala/org/apache/kafka/streams/scala/kstream/KStream.scala b/streams/streams-scala/src/main/scala/org/apache/kafka/streams/scala/kstream/KStream.scala new file mode 100644 index 0000000..dedb424 --- /dev/null +++ b/streams/streams-scala/src/main/scala/org/apache/kafka/streams/scala/kstream/KStream.scala @@ -0,0 +1,1193 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.scala +package kstream + +import org.apache.kafka.streams.KeyValue +import org.apache.kafka.streams.kstream.{ + GlobalKTable, + JoinWindows, + Printed, + TransformerSupplier, + ValueTransformerSupplier, + ValueTransformerWithKeySupplier, + KStream => KStreamJ +} +import org.apache.kafka.streams.processor.TopicNameExtractor +import org.apache.kafka.streams.processor.api.ProcessorSupplier +import org.apache.kafka.streams.scala.FunctionsCompatConversions.{ + FlatValueMapperFromFunction, + FlatValueMapperWithKeyFromFunction, + ForeachActionFromFunction, + KeyValueMapperFromFunction, + MapperFromFunction, + PredicateFromFunction, + TransformerSupplierAsJava, + ValueMapperFromFunction, + ValueMapperWithKeyFromFunction, + ValueTransformerSupplierAsJava, + ValueTransformerSupplierWithKeyAsJava +} + +import scala.jdk.CollectionConverters._ + +/** + * Wraps the Java class [[org.apache.kafka.streams.kstream.KStream KStream]] and delegates method calls to the + * underlying Java object. + * + * @tparam K Type of keys + * @tparam V Type of values + * @param inner The underlying Java abstraction for KStream + * @see `org.apache.kafka.streams.kstream.KStream` + */ +//noinspection ScalaDeprecation +class KStream[K, V](val inner: KStreamJ[K, V]) { + + /** + * Create a new [[KStream]] that consists all records of this stream which satisfies the given predicate. + * + * @param predicate a filter that is applied to each record + * @return a [[KStream]] that contains only those records that satisfy the given predicate + * @see `org.apache.kafka.streams.kstream.KStream#filter` + */ + def filter(predicate: (K, V) => Boolean): KStream[K, V] = + new KStream(inner.filter(predicate.asPredicate)) + + /** + * Create a new [[KStream]] that consists all records of this stream which satisfies the given predicate. + * + * @param predicate a filter that is applied to each record + * @param named a [[Named]] config used to name the processor in the topology + * @return a [[KStream]] that contains only those records that satisfy the given predicate + * @see `org.apache.kafka.streams.kstream.KStream#filter` + */ + def filter(predicate: (K, V) => Boolean, named: Named): KStream[K, V] = + new KStream(inner.filter(predicate.asPredicate, named)) + + /** + * Create a new [[KStream]] that consists all records of this stream which do not satisfy the given + * predicate. + * + * @param predicate a filter that is applied to each record + * @return a [[KStream]] that contains only those records that do not satisfy the given predicate + * @see `org.apache.kafka.streams.kstream.KStream#filterNot` + */ + def filterNot(predicate: (K, V) => Boolean): KStream[K, V] = + new KStream(inner.filterNot(predicate.asPredicate)) + + /** + * Create a new [[KStream]] that consists all records of this stream which do not satisfy the given + * predicate. + * + * @param predicate a filter that is applied to each record + * @param named a [[Named]] config used to name the processor in the topology + * @return a [[KStream]] that contains only those records that do not satisfy the given predicate + * @see `org.apache.kafka.streams.kstream.KStream#filterNot` + */ + def filterNot(predicate: (K, V) => Boolean, named: Named): KStream[K, V] = + new KStream(inner.filterNot(predicate.asPredicate, named)) + + /** + * Set a new key (with possibly new type) for each input record. + *
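+ * For instance (the topic and the key choice are illustrative):
+ * {{{
+ * import org.apache.kafka.streams.scala.ImplicitConversions._
+ * import org.apache.kafka.streams.scala.StreamsBuilder
+ * import org.apache.kafka.streams.scala.kstream.KStream
+ * import org.apache.kafka.streams.scala.serialization.Serdes._
+ *
+ * val builder = new StreamsBuilder()
+ * val orders: KStream[String, String] = builder.stream[String, String]("orders")
+ *
+ * // re-key each record by the first comma-separated field of its value (made-up format)
+ * val byCustomer: KStream[String, String] = orders.selectKey((_, value) => value.split(",")(0))
+ * }}}
+ *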
                + * The function `mapper` passed is applied to every record and results in the generation of a new + * key `KR`. The function outputs a new [[KStream]] where each record has this new key. + * + * @param mapper a function `(K, V) => KR` that computes a new key for each record + * @return a [[KStream]] that contains records with new key (possibly of different type) and unmodified value + * @see `org.apache.kafka.streams.kstream.KStream#selectKey` + */ + def selectKey[KR](mapper: (K, V) => KR): KStream[KR, V] = + new KStream(inner.selectKey[KR](mapper.asKeyValueMapper)) + + /** + * Set a new key (with possibly new type) for each input record. + *
                + * The function `mapper` passed is applied to every record and results in the generation of a new + * key `KR`. The function outputs a new [[KStream]] where each record has this new key. + * + * @param mapper a function `(K, V) => KR` that computes a new key for each record + * @param named a [[Named]] config used to name the processor in the topology + * @return a [[KStream]] that contains records with new key (possibly of different type) and unmodified value + * @see `org.apache.kafka.streams.kstream.KStream#selectKey` + */ + def selectKey[KR](mapper: (K, V) => KR, named: Named): KStream[KR, V] = + new KStream(inner.selectKey[KR](mapper.asKeyValueMapper, named)) + + /** + * Transform each record of the input stream into a new record in the output stream (both key and value type can be + * altered arbitrarily). + *
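+ * For instance (the topic and the mapping are illustrative):
+ * {{{
+ * import org.apache.kafka.streams.scala.ImplicitConversions._
+ * import org.apache.kafka.streams.scala.StreamsBuilder
+ * import org.apache.kafka.streams.scala.kstream.KStream
+ * import org.apache.kafka.streams.scala.serialization.Serdes._
+ *
+ * val builder = new StreamsBuilder()
+ * val words: KStream[String, String] = builder.stream[String, String]("words")
+ *
+ * // swap key and value, and keep only the length of the old key as the new value
+ * val swapped: KStream[String, Int] = words.map((key, value) => (value, key.length))
+ * }}}
+ *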
                + * The provided `mapper`, a function `(K, V) => (KR, VR)` is applied to each input record and computes a new output record. + * + * @param mapper a function `(K, V) => (KR, VR)` that computes a new output record + * @return a [[KStream]] that contains records with new key and value (possibly both of different type) + * @see `org.apache.kafka.streams.kstream.KStream#map` + */ + def map[KR, VR](mapper: (K, V) => (KR, VR)): KStream[KR, VR] = + new KStream(inner.map[KR, VR](mapper.asKeyValueMapper)) + + /** + * Transform each record of the input stream into a new record in the output stream (both key and value type can be + * altered arbitrarily). + *
                + * The provided `mapper`, a function `(K, V) => (KR, VR)` is applied to each input record and computes a new output record. + * + * @param mapper a function `(K, V) => (KR, VR)` that computes a new output record + * @param named a [[Named]] config used to name the processor in the topology + * @return a [[KStream]] that contains records with new key and value (possibly both of different type) + * @see `org.apache.kafka.streams.kstream.KStream#map` + */ + def map[KR, VR](mapper: (K, V) => (KR, VR), named: Named): KStream[KR, VR] = + new KStream(inner.map[KR, VR](mapper.asKeyValueMapper, named)) + + /** + * Transform the value of each input record into a new value (with possible new type) of the output record. + *
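+ * For instance (illustrative only):
+ * {{{
+ * import org.apache.kafka.streams.scala.ImplicitConversions._
+ * import org.apache.kafka.streams.scala.StreamsBuilder
+ * import org.apache.kafka.streams.scala.kstream.KStream
+ * import org.apache.kafka.streams.scala.serialization.Serdes._
+ *
+ * val builder = new StreamsBuilder()
+ * val textLines: KStream[String, String] = builder.stream[String, String]("text-lines")
+ *
+ * // keep the key, replace each value with its length
+ * val lineLengths: KStream[String, Int] = textLines.mapValues(_.length)
+ * }}}
+ *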
                + * The provided `mapper`, a function `V => VR` is applied to each input record value and computes a new value for it + * + * @param mapper , a function `V => VR` that computes a new output value + * @return a [[KStream]] that contains records with unmodified key and new values (possibly of different type) + * @see `org.apache.kafka.streams.kstream.KStream#mapValues` + */ + def mapValues[VR](mapper: V => VR): KStream[K, VR] = + new KStream(inner.mapValues[VR](mapper.asValueMapper)) + + /** + * Transform the value of each input record into a new value (with possible new type) of the output record. + *
                + * The provided `mapper`, a function `V => VR` is applied to each input record value and computes a new value for it + * + * @param mapper , a function `V => VR` that computes a new output value + * @param named a [[Named]] config used to name the processor in the topology + * @return a [[KStream]] that contains records with unmodified key and new values (possibly of different type) + * @see `org.apache.kafka.streams.kstream.KStream#mapValues` + */ + def mapValues[VR](mapper: V => VR, named: Named): KStream[K, VR] = + new KStream(inner.mapValues[VR](mapper.asValueMapper, named)) + + /** + * Transform the value of each input record into a new value (with possible new type) of the output record. + *
                + * The provided `mapper`, a function `(K, V) => VR` is applied to each input record value and computes a new value for it + * + * @param mapper , a function `(K, V) => VR` that computes a new output value + * @return a [[KStream]] that contains records with unmodified key and new values (possibly of different type) + * @see `org.apache.kafka.streams.kstream.KStream#mapValues` + */ + def mapValues[VR](mapper: (K, V) => VR): KStream[K, VR] = + new KStream(inner.mapValues[VR](mapper.asValueMapperWithKey)) + + /** + * Transform the value of each input record into a new value (with possible new type) of the output record. + *
                + * The provided `mapper`, a function `(K, V) => VR` is applied to each input record value and computes a new value for it + * + * @param mapper , a function `(K, V) => VR` that computes a new output value + * @param named a [[Named]] config used to name the processor in the topology + * @return a [[KStream]] that contains records with unmodified key and new values (possibly of different type) + * @see `org.apache.kafka.streams.kstream.KStream#mapValues` + */ + def mapValues[VR](mapper: (K, V) => VR, named: Named): KStream[K, VR] = + new KStream(inner.mapValues[VR](mapper.asValueMapperWithKey, named)) + + /** + * Transform each record of the input stream into zero or more records in the output stream (both key and value type + * can be altered arbitrarily). + *
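+ * A minimal usage sketch added for illustration (the stream and types are hypothetical):
+ * {{{
+ *   val sentences: KStream[String, String] = //..
+ *
+ *   // emit one (word, 1L) record per word of the input value
+ *   val wordOnes: KStream[String, Long] =
+ *     sentences.flatMap((_, sentence) => sentence.split(" ").map(word => (word, 1L)).toSeq)
+ * }}}
+ *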
                + * The provided `mapper`, function `(K, V) => Iterable[(KR, VR)]` is applied to each input record and computes zero or more output records. + * + * @param mapper function `(K, V) => Iterable[(KR, VR)]` that computes the new output records + * @return a [[KStream]] that contains more or less records with new key and value (possibly of different type) + * @see `org.apache.kafka.streams.kstream.KStream#flatMap` + */ + def flatMap[KR, VR](mapper: (K, V) => Iterable[(KR, VR)]): KStream[KR, VR] = { + val kvMapper = mapper.tupled.andThen(_.map(ImplicitConversions.tuple2ToKeyValue).asJava) + new KStream(inner.flatMap[KR, VR](((k: K, v: V) => kvMapper(k, v)).asKeyValueMapper)) + } + + /** + * Transform each record of the input stream into zero or more records in the output stream (both key and value type + * can be altered arbitrarily). + *
                + * The provided `mapper`, function `(K, V) => Iterable[(KR, VR)]` is applied to each input record and computes zero or more output records. + * + * @param mapper function `(K, V) => Iterable[(KR, VR)]` that computes the new output records + * @param named a [[Named]] config used to name the processor in the topology + * @return a [[KStream]] that contains more or less records with new key and value (possibly of different type) + * @see `org.apache.kafka.streams.kstream.KStream#flatMap` + */ + def flatMap[KR, VR](mapper: (K, V) => Iterable[(KR, VR)], named: Named): KStream[KR, VR] = { + val kvMapper = mapper.tupled.andThen(_.map(ImplicitConversions.tuple2ToKeyValue).asJava) + new KStream(inner.flatMap[KR, VR](((k: K, v: V) => kvMapper(k, v)).asKeyValueMapper, named)) + } + + /** + * Create a new [[KStream]] by transforming the value of each record in this stream into zero or more values + * with the same key in the new stream. + *
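+ * As a quick illustration, a minimal sketch (the stream and types are hypothetical):
+ * {{{
+ *   val sentences: KStream[String, String] = //..
+ *
+ *   // split each value into words, keeping the original key
+ *   val words: KStream[String, String] = sentences.flatMapValues(_.split(" ").toSeq)
+ * }}}
+ *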
                + * Transform the value of each input record into zero or more records with the same (unmodified) key in the output + * stream (value type can be altered arbitrarily). + * The provided `mapper`, a function `V => Iterable[VR]` is applied to each input record and computes zero or more output values. + * + * @param mapper a function `V => Iterable[VR]` that computes the new output values + * @return a [[KStream]] that contains more or less records with unmodified keys and new values of different type + * @see `org.apache.kafka.streams.kstream.KStream#flatMapValues` + */ + def flatMapValues[VR](mapper: V => Iterable[VR]): KStream[K, VR] = + new KStream(inner.flatMapValues[VR](mapper.asValueMapper)) + + /** + * Create a new [[KStream]] by transforming the value of each record in this stream into zero or more values + * with the same key in the new stream. + *
                + * Transform the value of each input record into zero or more records with the same (unmodified) key in the output + * stream (value type can be altered arbitrarily). + * The provided `mapper`, a function `V => Iterable[VR]` is applied to each input record and computes zero or more output values. + * + * @param mapper a function `V => Iterable[VR]` that computes the new output values + * @param named a [[Named]] config used to name the processor in the topology + * @return a [[KStream]] that contains more or less records with unmodified keys and new values of different type + * @see `org.apache.kafka.streams.kstream.KStream#flatMapValues` + */ + def flatMapValues[VR](mapper: V => Iterable[VR], named: Named): KStream[K, VR] = + new KStream(inner.flatMapValues[VR](mapper.asValueMapper, named)) + + /** + * Create a new [[KStream]] by transforming the value of each record in this stream into zero or more values + * with the same key in the new stream. + *
                + * Transform the value of each input record into zero or more records with the same (unmodified) key in the output + * stream (value type can be altered arbitrarily). + * The provided `mapper`, a function `(K, V) => Iterable[VR]` is applied to each input record and computes zero or more output values. + * + * @param mapper a function `(K, V) => Iterable[VR]` that computes the new output values + * @return a [[KStream]] that contains more or less records with unmodified keys and new values of different type + * @see `org.apache.kafka.streams.kstream.KStream#flatMapValues` + */ + def flatMapValues[VR](mapper: (K, V) => Iterable[VR]): KStream[K, VR] = + new KStream(inner.flatMapValues[VR](mapper.asValueMapperWithKey)) + + /** + * Create a new [[KStream]] by transforming the value of each record in this stream into zero or more values + * with the same key in the new stream. + *
                + * Transform the value of each input record into zero or more records with the same (unmodified) key in the output + * stream (value type can be altered arbitrarily). + * The provided `mapper`, a function `(K, V) => Iterable[VR]` is applied to each input record and computes zero or more output values. + * + * @param mapper a function `(K, V) => Iterable[VR]` that computes the new output values + * @param named a [[Named]] config used to name the processor in the topology + * @return a [[KStream]] that contains more or less records with unmodified keys and new values of different type + * @see `org.apache.kafka.streams.kstream.KStream#flatMapValues` + */ + def flatMapValues[VR](mapper: (K, V) => Iterable[VR], named: Named): KStream[K, VR] = + new KStream(inner.flatMapValues[VR](mapper.asValueMapperWithKey, named)) + + /** + * Print the records of this KStream using the options provided by `Printed` + * + * @param printed options for printing + * @see `org.apache.kafka.streams.kstream.KStream#print` + */ + def print(printed: Printed[K, V]): Unit = inner.print(printed) + + /** + * Perform an action on each record of `KStream` + * + * @param action an action to perform on each record + * @see `org.apache.kafka.streams.kstream.KStream#foreach` + */ + def foreach(action: (K, V) => Unit): Unit = + inner.foreach(action.asForeachAction) + + /** + * Perform an action on each record of `KStream` + * + * @param action an action to perform on each record + * @param named a [[Named]] config used to name the processor in the topology + * @see `org.apache.kafka.streams.kstream.KStream#foreach` + */ + def foreach(action: (K, V) => Unit, named: Named): Unit = + inner.foreach(action.asForeachAction, named) + + /** + * Creates an array of `KStream` from this stream by branching the records in the original stream based on + * the supplied predicates. + * + * @param predicates the ordered list of functions that return a Boolean + * @return multiple distinct substreams of this [[KStream]] + * @see `org.apache.kafka.streams.kstream.KStream#branch` + * @deprecated since 2.8. Use `split` instead. + */ + //noinspection ScalaUnnecessaryParentheses + @deprecated("use `split()` instead", "2.8") + def branch(predicates: ((K, V) => Boolean)*): Array[KStream[K, V]] = + inner.branch(predicates.map(_.asPredicate): _*).map(kstream => new KStream(kstream)) + + /** + * Split this stream. [[BranchedKStream]] can be used for routing the records to different branches depending + * on evaluation against the supplied predicates. + * Stream branching is a stateless record-by-record operation. + * + * @return [[BranchedKStream]] that provides methods for routing the records to different branches. + * @see `org.apache.kafka.streams.kstream.KStream#split` + */ + def split(): BranchedKStream[K, V] = + new BranchedKStream(inner.split()) + + /** + * Split this stream. [[BranchedKStream]] can be used for routing the records to different branches depending + * on evaluation against the supplied predicates. + * Stream branching is a stateless record-by-record operation. + * + * @param named a [[Named]] config used to name the processor in the topology and also to set the name prefix + * for the resulting branches (see [[BranchedKStream]]) + * @return [[BranchedKStream]] that provides methods for routing the records to different branches. 
+ * @see `org.apache.kafka.streams.kstream.KStream#split` + */ + def split(named: Named): BranchedKStream[K, V] = + new BranchedKStream(inner.split(named)) + + /** + * Materialize this stream to a topic and create a new [[KStream]] from the topic using the `Produced` instance for + * configuration of the `Serde key serde`, `Serde value serde`, and `StreamPartitioner` + *
                + * The user can either supply the `Produced` instance as an implicit in scope or they can also provide implicit + * key and value serdes that will be converted to a `Produced` instance implicitly. + *
                + * {{{ + * Example: + * + * // brings implicit serdes in scope + * import Serdes._ + * + * //.. + * val clicksPerRegion: KStream[String, Long] = //.. + * + * // Implicit serdes in scope will generate an implicit Produced instance, which + * // will be passed automatically to the call of through below + * clicksPerRegion.through(topic) + * + * // Similarly you can create an implicit Produced and it will be passed implicitly + * // to the through call + * }}} + * + * @param topic the topic name + * @param produced the instance of Produced that gives the serdes and `StreamPartitioner` + * @return a [[KStream]] that contains the exact same (and potentially repartitioned) records as this [[KStream]] + * @see `org.apache.kafka.streams.kstream.KStream#through` + * @deprecated use `repartition()` instead + */ + @deprecated("use `repartition()` instead", "2.6.0") + def through(topic: String)(implicit produced: Produced[K, V]): KStream[K, V] = + new KStream(inner.through(topic, produced)) + + /** + * Materialize this stream to a topic and creates a new [[KStream]] from the topic using the `Repartitioned` instance + * for configuration of the `Serde key serde`, `Serde value serde`, `StreamPartitioner`, number of partitions, and + * topic name part. + *
                + * The created topic is considered as an internal topic and is meant to be used only by the current Kafka Streams instance. + * Similar to auto-repartitioning, the topic will be created with infinite retention time and data will be automatically purged by Kafka Streams. + * The topic will be named as "${applicationId}-<name>-repartition", where "applicationId" is user-specified in + * `StreamsConfig` via parameter `APPLICATION_ID_CONFIG APPLICATION_ID_CONFIG`, + * "<name>" is either provided via `Repartitioned#as(String)` or an internally + * generated name, and "-repartition" is a fixed suffix. + *
                + * The user can either supply the `Repartitioned` instance as an implicit in scope or they can also provide implicit + * key and value serdes that will be converted to a `Repartitioned` instance implicitly. + *
                + * {{{ + * Example: + * + * // brings implicit serdes in scope + * import Serdes._ + * + * //.. + * val clicksPerRegion: KStream[String, Long] = //.. + * + * // Implicit serdes in scope will generate an implicit Produced instance, which + * // will be passed automatically to the call of through below + * clicksPerRegion.repartition + * + * // Similarly you can create an implicit Repartitioned and it will be passed implicitly + * // to the repartition call + * }}} + * + * @param repartitioned the `Repartitioned` instance used to specify `Serdes`, `StreamPartitioner` which determines + * how records are distributed among partitions of the topic, + * part of the topic name, and number of partitions for a repartition topic. + * @return a [[KStream]] that contains the exact same repartitioned records as this [[KStream]] + * @see `org.apache.kafka.streams.kstream.KStream#repartition` + */ + def repartition(implicit repartitioned: Repartitioned[K, V]): KStream[K, V] = + new KStream(inner.repartition(repartitioned)) + + /** + * Materialize this stream to a topic using the `Produced` instance for + * configuration of the `Serde key serde`, `Serde value serde`, and `StreamPartitioner` + *
                + * The user can either supply the `Produced` instance as an implicit in scope or they can also provide implicit + * key and value serdes that will be converted to a `Produced` instance implicitly. + *
                + * {{{ + * Example: + * + * // brings implicit serdes in scope + * import Serdes._ + * + * //.. + * val clicksPerRegion: KTable[String, Long] = //.. + * + * // Implicit serdes in scope will generate an implicit Produced instance, which + * // will be passed automatically to the call of through below + * clicksPerRegion.to(topic) + * + * // Similarly you can create an implicit Produced and it will be passed implicitly + * // to the through call + * }}} + * + * @param topic the topic name + * @param produced the instance of Produced that gives the serdes and `StreamPartitioner` + * @see `org.apache.kafka.streams.kstream.KStream#to` + */ + def to(topic: String)(implicit produced: Produced[K, V]): Unit = + inner.to(topic, produced) + + /** + * Dynamically materialize this stream to topics using the `Produced` instance for + * configuration of the `Serde key serde`, `Serde value serde`, and `StreamPartitioner`. + * The topic names for each record to send to is dynamically determined based on the given mapper. + *
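+ * A minimal sketch of a dynamic topic choice, added for illustration (the topic naming scheme and the stream
+ * are hypothetical; implicit serdes, e.g. via `import Serdes._`, are assumed to be in scope for `Produced`):
+ * {{{
+ *   val clicksPerRegion: KStream[String, Long] = //..
+ *
+ *   // TopicNameExtractor has a single abstract method, so a lambda can be used from Scala 2.12 on
+ *   val regionTopic: TopicNameExtractor[String, Long] = (region, _, _) => "clicks-" + region
+ *
+ *   clicksPerRegion.to(regionTopic)
+ * }}}
+ *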
                + * The user can either supply the `Produced` instance as an implicit in scope or they can also provide implicit + * key and value serdes that will be converted to a `Produced` instance implicitly. + *
                + * {{{ + * Example: + * + * // brings implicit serdes in scope + * import Serdes._ + * + * //.. + * val clicksPerRegion: KTable[String, Long] = //.. + * + * // Implicit serdes in scope will generate an implicit Produced instance, which + * // will be passed automatically to the call of through below + * clicksPerRegion.to(topicChooser) + * + * // Similarly you can create an implicit Produced and it will be passed implicitly + * // to the through call + * }}} + * + * @param extractor the extractor to determine the name of the Kafka topic to write to for reach record + * @param produced the instance of Produced that gives the serdes and `StreamPartitioner` + * @see `org.apache.kafka.streams.kstream.KStream#to` + */ + def to(extractor: TopicNameExtractor[K, V])(implicit produced: Produced[K, V]): Unit = + inner.to(extractor, produced) + + /** + * Convert this stream to a [[KTable]]. + * + * @return a [[KTable]] that contains the same records as this [[KStream]] + * @see `org.apache.kafka.streams.kstream.KStream#toTable` + */ + def toTable: KTable[K, V] = + new KTable(inner.toTable) + + /** + * Convert this stream to a [[KTable]]. + * + * @param named a [[Named]] config used to name the processor in the topology + * @return a [[KTable]] that contains the same records as this [[KStream]] + * @see `org.apache.kafka.streams.kstream.KStream#toTable` + */ + def toTable(named: Named): KTable[K, V] = + new KTable(inner.toTable(named)) + + /** + * Convert this stream to a [[KTable]]. + * + * @param materialized a `Materialized` that describes how the `StateStore` for the resulting [[KTable]] + * should be materialized. + * @return a [[KTable]] that contains the same records as this [[KStream]] + * @see `org.apache.kafka.streams.kstream.KStream#toTable` + */ + def toTable(materialized: Materialized[K, V, ByteArrayKeyValueStore]): KTable[K, V] = + new KTable(inner.toTable(materialized)) + + /** + * Convert this stream to a [[KTable]]. + * + * @param named a [[Named]] config used to name the processor in the topology + * @param materialized a `Materialized` that describes how the `StateStore` for the resulting [[KTable]] + * should be materialized. + * @return a [[KTable]] that contains the same records as this [[KStream]] + * @see `org.apache.kafka.streams.kstream.KStream#toTable` + */ + def toTable(named: Named, materialized: Materialized[K, V, ByteArrayKeyValueStore]): KTable[K, V] = + new KTable(inner.toTable(named, materialized)) + + /** + * Transform each record of the input stream into zero or more records in the output stream (both key and value type + * can be altered arbitrarily). + * A `Transformer` (provided by the given `TransformerSupplier`) is applied to each input record + * and computes zero or more output records. + * In order to assign a state, the state must be created and added via `addStateStore` before they can be connected + * to the `Transformer`. + * It's not required to connect global state stores that are added via `addGlobalStore`; + * read-only access to global state stores is available by default. 
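+ *
+ * A minimal, stateless sketch added for illustration (imports of `Transformer` and `ProcessorContext` are
+ * assumed; a real transformer would typically read and write a connected state store):
+ * {{{
+ *   // TransformerSupplier has a single abstract method, so a lambda supplies the Transformer
+ *   val upperCaseSupplier: TransformerSupplier[String, String, KeyValue[String, String]] = () =>
+ *     new Transformer[String, String, KeyValue[String, String]] {
+ *       override def init(context: ProcessorContext): Unit = ()
+ *       override def transform(key: String, value: String): KeyValue[String, String] =
+ *         KeyValue.pair(key, value.toUpperCase)
+ *       override def close(): Unit = ()
+ *     }
+ *
+ *   val input: KStream[String, String] = //..
+ *   val upperCased: KStream[String, String] = input.transform(upperCaseSupplier)
+ * }}}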
+ * + * @param transformerSupplier the `TransformerSuplier` that generates `Transformer` + * @param stateStoreNames the names of the state stores used by the processor + * @return a [[KStream]] that contains more or less records with new key and value (possibly of different type) + * @see `org.apache.kafka.streams.kstream.KStream#transform` + */ + def transform[K1, V1]( + transformerSupplier: TransformerSupplier[K, V, KeyValue[K1, V1]], + stateStoreNames: String* + ): KStream[K1, V1] = + new KStream(inner.transform(transformerSupplier, stateStoreNames: _*)) + + /** + * Transform each record of the input stream into zero or more records in the output stream (both key and value type + * can be altered arbitrarily). + * A `Transformer` (provided by the given `TransformerSupplier`) is applied to each input record + * and computes zero or more output records. + * In order to assign a state, the state must be created and added via `addStateStore` before they can be connected + * to the `Transformer`. + * It's not required to connect global state stores that are added via `addGlobalStore`; + * read-only access to global state stores is available by default. + * + * @param transformerSupplier the `TransformerSuplier` that generates `Transformer` + * @param named a [[Named]] config used to name the processor in the topology + * @param stateStoreNames the names of the state stores used by the processor + * @return a [[KStream]] that contains more or less records with new key and value (possibly of different type) + * @see `org.apache.kafka.streams.kstream.KStream#transform` + */ + def transform[K1, V1]( + transformerSupplier: TransformerSupplier[K, V, KeyValue[K1, V1]], + named: Named, + stateStoreNames: String* + ): KStream[K1, V1] = + new KStream(inner.transform(transformerSupplier, named, stateStoreNames: _*)) + + /** + * Transform each record of the input stream into zero or more records in the output stream (both key and value type + * can be altered arbitrarily). + * A `Transformer` (provided by the given `TransformerSupplier`) is applied to each input record + * and computes zero or more output records. + * In order to assign a state, the state must be created and added via `addStateStore` before they can be connected + * to the `Transformer`. + * It's not required to connect global state stores that are added via `addGlobalStore`; + * read-only access to global state stores is available by default. + * + * @param transformerSupplier the `TransformerSuplier` that generates `Transformer` + * @param stateStoreNames the names of the state stores used by the processor + * @return a [[KStream]] that contains more or less records with new key and value (possibly of different type) + * @see `org.apache.kafka.streams.kstream.KStream#transform` + */ + def flatTransform[K1, V1]( + transformerSupplier: TransformerSupplier[K, V, Iterable[KeyValue[K1, V1]]], + stateStoreNames: String* + ): KStream[K1, V1] = + new KStream(inner.flatTransform(transformerSupplier.asJava, stateStoreNames: _*)) + + /** + * Transform each record of the input stream into zero or more records in the output stream (both key and value type + * can be altered arbitrarily). + * A `Transformer` (provided by the given `TransformerSupplier`) is applied to each input record + * and computes zero or more output records. + * In order to assign a state, the state must be created and added via `addStateStore` before they can be connected + * to the `Transformer`. 
+ * It's not required to connect global state stores that are added via `addGlobalStore`; + * read-only access to global state stores is available by default. + * + * @param transformerSupplier the `TransformerSuplier` that generates `Transformer` + * @param named a [[Named]] config used to name the processor in the topology + * @param stateStoreNames the names of the state stores used by the processor + * @return a [[KStream]] that contains more or less records with new key and value (possibly of different type) + * @see `org.apache.kafka.streams.kstream.KStream#transform` + */ + def flatTransform[K1, V1]( + transformerSupplier: TransformerSupplier[K, V, Iterable[KeyValue[K1, V1]]], + named: Named, + stateStoreNames: String* + ): KStream[K1, V1] = + new KStream(inner.flatTransform(transformerSupplier.asJava, named, stateStoreNames: _*)) + + /** + * Transform the value of each input record into zero or more records (with possible new type) in the + * output stream. + * A `ValueTransformer` (provided by the given `ValueTransformerSupplier`) is applied to each input + * record value and computes a new value for it. + * In order to assign a state, the state must be created and added via `addStateStore` before they can be connected + * to the `ValueTransformer`. + * It's not required to connect global state stores that are added via `addGlobalStore`; + * read-only access to global state stores is available by default. + * + * @param valueTransformerSupplier a instance of `ValueTransformerSupplier` that generates a `ValueTransformer` + * @param stateStoreNames the names of the state stores used by the processor + * @return a [[KStream]] that contains records with unmodified key and new values (possibly of different type) + * @see `org.apache.kafka.streams.kstream.KStream#transformValues` + */ + def flatTransformValues[VR]( + valueTransformerSupplier: ValueTransformerSupplier[V, Iterable[VR]], + stateStoreNames: String* + ): KStream[K, VR] = + new KStream(inner.flatTransformValues[VR](valueTransformerSupplier.asJava, stateStoreNames: _*)) + + /** + * Transform the value of each input record into zero or more records (with possible new type) in the + * output stream. + * A `ValueTransformer` (provided by the given `ValueTransformerSupplier`) is applied to each input + * record value and computes a new value for it. + * In order to assign a state, the state must be created and added via `addStateStore` before they can be connected + * to the `ValueTransformer`. + * It's not required to connect global state stores that are added via `addGlobalStore`; + * read-only access to global state stores is available by default. + * + * @param valueTransformerSupplier a instance of `ValueTransformerSupplier` that generates a `ValueTransformer` + * @param named a [[Named]] config used to name the processor in the topology + * @param stateStoreNames the names of the state stores used by the processor + * @return a [[KStream]] that contains records with unmodified key and new values (possibly of different type) + * @see `org.apache.kafka.streams.kstream.KStream#transformValues` + */ + def flatTransformValues[VR]( + valueTransformerSupplier: ValueTransformerSupplier[V, Iterable[VR]], + named: Named, + stateStoreNames: String* + ): KStream[K, VR] = + new KStream(inner.flatTransformValues[VR](valueTransformerSupplier.asJava, named, stateStoreNames: _*)) + + /** + * Transform the value of each input record into zero or more records (with possible new type) in the + * output stream. 
+ * A `ValueTransformer` (provided by the given `ValueTransformerSupplier`) is applied to each input + * record value and computes a new value for it. + * In order to assign a state, the state must be created and added via `addStateStore` before they can be connected + * to the `ValueTransformer`. + * It's not required to connect global state stores that are added via `addGlobalStore`; + * read-only access to global state stores is available by default. + * + * @param valueTransformerSupplier a instance of `ValueTransformerWithKeySupplier` that generates a `ValueTransformerWithKey` + * @param stateStoreNames the names of the state stores used by the processor + * @return a [[KStream]] that contains records with unmodified key and new values (possibly of different type) + * @see `org.apache.kafka.streams.kstream.KStream#transformValues` + */ + def flatTransformValues[VR]( + valueTransformerSupplier: ValueTransformerWithKeySupplier[K, V, Iterable[VR]], + stateStoreNames: String* + ): KStream[K, VR] = + new KStream(inner.flatTransformValues[VR](valueTransformerSupplier.asJava, stateStoreNames: _*)) + + /** + * Transform the value of each input record into zero or more records (with possible new type) in the + * output stream. + * A `ValueTransformer` (provided by the given `ValueTransformerSupplier`) is applied to each input + * record value and computes a new value for it. + * In order to assign a state, the state must be created and added via `addStateStore` before they can be connected + * to the `ValueTransformer`. + * It's not required to connect global state stores that are added via `addGlobalStore`; + * read-only access to global state stores is available by default. + * + * @param valueTransformerSupplier a instance of `ValueTransformerWithKeySupplier` that generates a `ValueTransformerWithKey` + * @param named a [[Named]] config used to name the processor in the topology + * @param stateStoreNames the names of the state stores used by the processor + * @return a [[KStream]] that contains records with unmodified key and new values (possibly of different type) + * @see `org.apache.kafka.streams.kstream.KStream#transformValues` + */ + def flatTransformValues[VR]( + valueTransformerSupplier: ValueTransformerWithKeySupplier[K, V, Iterable[VR]], + named: Named, + stateStoreNames: String* + ): KStream[K, VR] = + new KStream(inner.flatTransformValues[VR](valueTransformerSupplier.asJava, named, stateStoreNames: _*)) + + /** + * Transform the value of each input record into a new value (with possible new type) of the output record. + * A `ValueTransformer` (provided by the given `ValueTransformerSupplier`) is applied to each input + * record value and computes a new value for it. + * In order to assign a state, the state must be created and added via `addStateStore` before they can be connected + * to the `ValueTransformer`. + * It's not required to connect global state stores that are added via `addGlobalStore`; + * read-only access to global state stores is available by default. 
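+ *
+ * A minimal, stateless sketch added for illustration (imports of `ValueTransformer` and `ProcessorContext` are
+ * assumed; a stateful transformer would instead read and write a connected state store):
+ * {{{
+ *   val trimSupplier: ValueTransformerSupplier[String, String] = () =>
+ *     new ValueTransformer[String, String] {
+ *       override def init(context: ProcessorContext): Unit = ()
+ *       override def transform(value: String): String = value.trim
+ *       override def close(): Unit = ()
+ *     }
+ *
+ *   val input: KStream[String, String] = //..
+ *   val trimmed: KStream[String, String] = input.transformValues(trimSupplier)
+ * }}}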
+ * + * @param valueTransformerSupplier a instance of `ValueTransformerSupplier` that generates a `ValueTransformer` + * @param stateStoreNames the names of the state stores used by the processor + * @return a [[KStream]] that contains records with unmodified key and new values (possibly of different type) + * @see `org.apache.kafka.streams.kstream.KStream#transformValues` + */ + def transformValues[VR]( + valueTransformerSupplier: ValueTransformerSupplier[V, VR], + stateStoreNames: String* + ): KStream[K, VR] = + new KStream(inner.transformValues[VR](valueTransformerSupplier, stateStoreNames: _*)) + + /** + * Transform the value of each input record into a new value (with possible new type) of the output record. + * A `ValueTransformer` (provided by the given `ValueTransformerSupplier`) is applied to each input + * record value and computes a new value for it. + * In order to assign a state, the state must be created and added via `addStateStore` before they can be connected + * to the `ValueTransformer`. + * It's not required to connect global state stores that are added via `addGlobalStore`; + * read-only access to global state stores is available by default. + * + * @param valueTransformerSupplier a instance of `ValueTransformerSupplier` that generates a `ValueTransformer` + * @param named a [[Named]] config used to name the processor in the topology + * @param stateStoreNames the names of the state stores used by the processor + * @return a [[KStream]] that contains records with unmodified key and new values (possibly of different type) + * @see `org.apache.kafka.streams.kstream.KStream#transformValues` + */ + def transformValues[VR]( + valueTransformerSupplier: ValueTransformerSupplier[V, VR], + named: Named, + stateStoreNames: String* + ): KStream[K, VR] = + new KStream(inner.transformValues[VR](valueTransformerSupplier, named, stateStoreNames: _*)) + + /** + * Transform the value of each input record into a new value (with possible new type) of the output record. + * A `ValueTransformer` (provided by the given `ValueTransformerSupplier`) is applied to each input + * record value and computes a new value for it. + * In order to assign a state, the state must be created and added via `addStateStore` before they can be connected + * to the `ValueTransformer`. + * It's not required to connect global state stores that are added via `addGlobalStore`; + * read-only access to global state stores is available by default. + * + * @param valueTransformerSupplier a instance of `ValueTransformerWithKeySupplier` that generates a `ValueTransformerWithKey` + * @param stateStoreNames the names of the state stores used by the processor + * @return a [[KStream]] that contains records with unmodified key and new values (possibly of different type) + * @see `org.apache.kafka.streams.kstream.KStream#transformValues` + */ + def transformValues[VR]( + valueTransformerSupplier: ValueTransformerWithKeySupplier[K, V, VR], + stateStoreNames: String* + ): KStream[K, VR] = + new KStream(inner.transformValues[VR](valueTransformerSupplier, stateStoreNames: _*)) + + /** + * Transform the value of each input record into a new value (with possible new type) of the output record. + * A `ValueTransformer` (provided by the given `ValueTransformerSupplier`) is applied to each input + * record value and computes a new value for it. + * In order to assign a state, the state must be created and added via `addStateStore` before they can be connected + * to the `ValueTransformer`. 
+ * It's not required to connect global state stores that are added via `addGlobalStore`; + * read-only access to global state stores is available by default. + * + * @param valueTransformerSupplier a instance of `ValueTransformerWithKeySupplier` that generates a `ValueTransformerWithKey` + * @param named a [[Named]] config used to name the processor in the topology + * @param stateStoreNames the names of the state stores used by the processor + * @return a [[KStream]] that contains records with unmodified key and new values (possibly of different type) + * @see `org.apache.kafka.streams.kstream.KStream#transformValues` + */ + def transformValues[VR]( + valueTransformerSupplier: ValueTransformerWithKeySupplier[K, V, VR], + named: Named, + stateStoreNames: String* + ): KStream[K, VR] = + new KStream(inner.transformValues[VR](valueTransformerSupplier, named, stateStoreNames: _*)) + + /** + * Process all records in this stream, one record at a time, by applying a `Processor` (provided by the given + * `processorSupplier`). + * In order to assign a state, the state must be created and added via `addStateStore` before they can be connected + * to the `Processor`. + * It's not required to connect global state stores that are added via `addGlobalStore`; + * read-only access to global state stores is available by default. + * + * @param processorSupplier a function that generates a [[org.apache.kafka.streams.processor.Processor]] + * @param stateStoreNames the names of the state store used by the processor + * @see `org.apache.kafka.streams.kstream.KStream#process` + */ + @deprecated(since = "3.0", message = "Use process(ProcessorSupplier, String*) instead.") + def process( + processorSupplier: () => org.apache.kafka.streams.processor.Processor[K, V], + stateStoreNames: String* + ): Unit = { + val processorSupplierJ: org.apache.kafka.streams.processor.ProcessorSupplier[K, V] = () => processorSupplier() + inner.process(processorSupplierJ, stateStoreNames: _*) + } + + /** + * Process all records in this stream, one record at a time, by applying a `Processor` (provided by the given + * `processorSupplier`). + * In order to assign a state, the state must be created and added via `addStateStore` before they can be connected + * to the `Processor`. + * It's not required to connect global state stores that are added via `addGlobalStore`; + * read-only access to global state stores is available by default. + * + * Note that this overload takes a ProcessorSupplier instead of a Function to avoid post-erasure ambiguity with + * the older (deprecated) overload. + * + * @param processorSupplier a supplier for [[org.apache.kafka.streams.processor.api.Processor]] + * @param stateStoreNames the names of the state store used by the processor + * @see `org.apache.kafka.streams.kstream.KStream#process` + */ + def process(processorSupplier: ProcessorSupplier[K, V, Void, Void], stateStoreNames: String*): Unit = + inner.process(processorSupplier, stateStoreNames: _*) + + /** + * Process all records in this stream, one record at a time, by applying a `Processor` (provided by the given + * `processorSupplier`). + * In order to assign a state, the state must be created and added via `addStateStore` before they can be connected + * to the `Processor`. + * It's not required to connect global state stores that are added via `addGlobalStore`; + * read-only access to global state stores is available by default. 
+ * + * @param processorSupplier a function that generates a [[org.apache.kafka.streams.processor.Processor]] + * @param named a [[Named]] config used to name the processor in the topology + * @param stateStoreNames the names of the state store used by the processor + * @see `org.apache.kafka.streams.kstream.KStream#process` + */ + @deprecated(since = "3.0", message = "Use process(ProcessorSupplier, String*) instead.") + def process( + processorSupplier: () => org.apache.kafka.streams.processor.Processor[K, V], + named: Named, + stateStoreNames: String* + ): Unit = { + val processorSupplierJ: org.apache.kafka.streams.processor.ProcessorSupplier[K, V] = () => processorSupplier() + inner.process(processorSupplierJ, named, stateStoreNames: _*) + } + + /** + * Process all records in this stream, one record at a time, by applying a `Processor` (provided by the given + * `processorSupplier`). + * In order to assign a state, the state must be created and added via `addStateStore` before they can be connected + * to the `Processor`. + * It's not required to connect global state stores that are added via `addGlobalStore`; + * read-only access to global state stores is available by default. + * + * Note that this overload takes a ProcessorSupplier instead of a Function to avoid post-erasure ambiguity with + * the older (deprecated) overload. + * + * @param processorSupplier a supplier for [[org.apache.kafka.streams.processor.api.Processor]] + * @param named a [[Named]] config used to name the processor in the topology + * @param stateStoreNames the names of the state store used by the processor + * @see `org.apache.kafka.streams.kstream.KStream#process` + */ + def process(processorSupplier: ProcessorSupplier[K, V, Void, Void], named: Named, stateStoreNames: String*): Unit = + inner.process(processorSupplier, named, stateStoreNames: _*) + + /** + * Group the records by their current key into a [[KGroupedStream]] + *
                + * The user can either supply the `Grouped` instance as an implicit in scope or they can also provide an implicit + * serdes that will be converted to a `Grouped` instance implicitly. + *
                + * {{{ + * Example: + * + * // brings implicit serdes in scope + * import Serdes._ + * + * val clicksPerRegion: KTable[String, Long] = + * userClicksStream + * .leftJoin(userRegionsTable, (clicks: Long, region: String) => (if (region == null) "UNKNOWN" else region, clicks)) + * .map((_, regionWithClicks) => regionWithClicks) + * + * // the groupByKey gets the Grouped instance through an implicit conversion of the + * // serdes brought into scope through the import Serdes._ above + * .groupByKey + * .reduce(_ + _) + * + * // Similarly you can create an implicit Grouped and it will be passed implicitly + * // to the groupByKey call + * }}} + * + * @param grouped the instance of Grouped that gives the serdes + * @return a [[KGroupedStream]] that contains the grouped records of the original [[KStream]] + * @see `org.apache.kafka.streams.kstream.KStream#groupByKey` + */ + def groupByKey(implicit grouped: Grouped[K, V]): KGroupedStream[K, V] = + new KGroupedStream(inner.groupByKey(grouped)) + + /** + * Group the records of this [[KStream]] on a new key that is selected using the provided key transformation function + * and the `Grouped` instance. + *
                + * The user can either supply the `Grouped` instance as an implicit in scope or they can also provide an implicit + * serdes that will be converted to a `Grouped` instance implicitly. + *
                + * {{{ + * Example: + * + * // brings implicit serdes in scope + * import Serdes._ + * + * val textLines = streamBuilder.stream[String, String](inputTopic) + * + * val pattern = Pattern.compile("\\W+", Pattern.UNICODE_CHARACTER_CLASS) + * + * val wordCounts: KTable[String, Long] = + * textLines.flatMapValues(v => pattern.split(v.toLowerCase)) + * + * // the groupBy gets the Grouped instance through an implicit conversion of the + * // serdes brought into scope through the import Serdes._ above + * .groupBy((k, v) => v) + * + * .count() + * }}} + * + * @param selector a function that computes a new key for grouping + * @return a [[KGroupedStream]] that contains the grouped records of the original [[KStream]] + * @see `org.apache.kafka.streams.kstream.KStream#groupBy` + */ + def groupBy[KR](selector: (K, V) => KR)(implicit grouped: Grouped[KR, V]): KGroupedStream[KR, V] = + new KGroupedStream(inner.groupBy(selector.asKeyValueMapper, grouped)) + + /** + * Join records of this stream with another [[KStream]]'s records using windowed inner equi join with + * serializers and deserializers supplied by the implicit `StreamJoined` instance. + * + * @param otherStream the [[KStream]] to be joined with this stream + * @param joiner a function that computes the join result for a pair of matching records + * @param windows the specification of the `JoinWindows` + * @param streamJoin an implicit `StreamJoin` instance that defines the serdes to be used to serialize/deserialize + * inputs and outputs of the joined streams. Instead of `StreamJoin`, the user can also supply + * key serde, value serde and other value serde in implicit scope and they will be + * converted to the instance of `Stream` through implicit conversion. The `StreamJoin` instance can + * also name the repartition topic (if required), the state stores for the join, and the join + * processor node. + * @return a [[KStream]] that contains join-records for each key and values computed by the given `joiner`, + * one for each matched record-pair with the same key and within the joining window intervals + * @see `org.apache.kafka.streams.kstream.KStream#join` + */ + def join[VO, VR](otherStream: KStream[K, VO])( + joiner: (V, VO) => VR, + windows: JoinWindows + )(implicit streamJoin: StreamJoined[K, V, VO]): KStream[K, VR] = + new KStream(inner.join[VO, VR](otherStream.inner, joiner.asValueJoiner, windows, streamJoin)) + + /** + * Join records of this stream with another [[KStream]]'s records using windowed left equi join with + * serializers and deserializers supplied by the implicit `StreamJoined` instance. + * + * @param otherStream the [[KStream]] to be joined with this stream + * @param joiner a function that computes the join result for a pair of matching records + * @param windows the specification of the `JoinWindows` + * @param streamJoin an implicit `StreamJoin` instance that defines the serdes to be used to serialize/deserialize + * inputs and outputs of the joined streams. Instead of `StreamJoin`, the user can also supply + * key serde, value serde and other value serde in implicit scope and they will be + * converted to the instance of `Stream` through implicit conversion. The `StreamJoin` instance can + * also name the repartition topic (if required), the state stores for the join, and the join + * processor node. 
+ * @return a [[KStream]] that contains join-records for each key and values computed by the given `joiner`, + * one for each matched record-pair with the same key and within the joining window intervals + * @see `org.apache.kafka.streams.kstream.KStream#leftJoin` + */ + def leftJoin[VO, VR](otherStream: KStream[K, VO])( + joiner: (V, VO) => VR, + windows: JoinWindows + )(implicit streamJoin: StreamJoined[K, V, VO]): KStream[K, VR] = + new KStream(inner.leftJoin[VO, VR](otherStream.inner, joiner.asValueJoiner, windows, streamJoin)) + + /** + * Join records of this stream with another [[KStream]]'s records using windowed outer equi join with + * serializers and deserializers supplied by the implicit `Joined` instance. + * + * @param otherStream the [[KStream]] to be joined with this stream + * @param joiner a function that computes the join result for a pair of matching records + * @param windows the specification of the `JoinWindows` + * @param streamJoin an implicit `StreamJoin` instance that defines the serdes to be used to serialize/deserialize + * inputs and outputs of the joined streams. Instead of `StreamJoin`, the user can also supply + * key serde, value serde and other value serde in implicit scope and they will be + * converted to the instance of `Stream` through implicit conversion. The `StreamJoin` instance can + * also name the repartition topic (if required), the state stores for the join, and the join + * processor node. + * @return a [[KStream]] that contains join-records for each key and values computed by the given `joiner`, + * one for each matched record-pair with the same key and within the joining window intervals + * @see `org.apache.kafka.streams.kstream.KStream#outerJoin` + */ + def outerJoin[VO, VR](otherStream: KStream[K, VO])( + joiner: (V, VO) => VR, + windows: JoinWindows + )(implicit streamJoin: StreamJoined[K, V, VO]): KStream[K, VR] = + new KStream(inner.outerJoin[VO, VR](otherStream.inner, joiner.asValueJoiner, windows, streamJoin)) + + /** + * Join records of this stream with another [[KTable]]'s records using inner equi join with + * serializers and deserializers supplied by the implicit `Joined` instance. + * + * @param table the [[KTable]] to be joined with this stream + * @param joiner a function that computes the join result for a pair of matching records + * @param joined an implicit `Joined` instance that defines the serdes to be used to serialize/deserialize + * inputs and outputs of the joined streams. Instead of `Joined`, the user can also supply + * key serde, value serde and other value serde in implicit scope and they will be + * converted to the instance of `Joined` through implicit conversion + * @return a [[KStream]] that contains join-records for each key and values computed by the given `joiner`, + * one for each matched record-pair with the same key + * @see `org.apache.kafka.streams.kstream.KStream#join` + */ + def join[VT, VR](table: KTable[K, VT])(joiner: (V, VT) => VR)(implicit joined: Joined[K, V, VT]): KStream[K, VR] = + new KStream(inner.join[VT, VR](table.inner, joiner.asValueJoiner, joined)) + + /** + * Join records of this stream with another [[KTable]]'s records using left equi join with + * serializers and deserializers supplied by the implicit `Joined` instance. 
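+ *
+ * A minimal sketch added for illustration (the stream, the table and their types are hypothetical; implicit
+ * serdes, e.g. via `import Serdes._`, are assumed to provide the `Joined`):
+ * {{{
+ *   val userClicks: KStream[String, Long] = //..
+ *   val userRegions: KTable[String, String] = //..
+ *
+ *   // for a left join the table-side value may be null when there is no match
+ *   val clicksWithRegion: KStream[String, (String, Long)] =
+ *     userClicks.leftJoin(userRegions)((clicks, region) => (Option(region).getOrElse("UNKNOWN"), clicks))
+ * }}}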
+ * + * @param table the [[KTable]] to be joined with this stream + * @param joiner a function that computes the join result for a pair of matching records + * @param joined an implicit `Joined` instance that defines the serdes to be used to serialize/deserialize + * inputs and outputs of the joined streams. Instead of `Joined`, the user can also supply + * key serde, value serde and other value serde in implicit scope and they will be + * converted to the instance of `Joined` through implicit conversion + * @return a [[KStream]] that contains join-records for each key and values computed by the given `joiner`, + * one for each matched record-pair with the same key + * @see `org.apache.kafka.streams.kstream.KStream#leftJoin` + */ + def leftJoin[VT, VR](table: KTable[K, VT])(joiner: (V, VT) => VR)(implicit joined: Joined[K, V, VT]): KStream[K, VR] = + new KStream(inner.leftJoin[VT, VR](table.inner, joiner.asValueJoiner, joined)) + + /** + * Join records of this stream with `GlobalKTable`'s records using non-windowed inner equi join. + * + * @param globalKTable the `GlobalKTable` to be joined with this stream + * @param keyValueMapper a function used to map from the (key, value) of this stream + * to the key of the `GlobalKTable` + * @param joiner a function that computes the join result for a pair of matching records + * @return a [[KStream]] that contains join-records for each key and values computed by the given `joiner`, + * one output for each input [[KStream]] record + * @see `org.apache.kafka.streams.kstream.KStream#join` + */ + def join[GK, GV, RV](globalKTable: GlobalKTable[GK, GV])( + keyValueMapper: (K, V) => GK, + joiner: (V, GV) => RV + ): KStream[K, RV] = + new KStream( + inner.join[GK, GV, RV]( + globalKTable, + ((k: K, v: V) => keyValueMapper(k, v)).asKeyValueMapper, + ((v: V, gv: GV) => joiner(v, gv)).asValueJoiner + ) + ) + + /** + * Join records of this stream with `GlobalKTable`'s records using non-windowed inner equi join. + * + * @param globalKTable the `GlobalKTable` to be joined with this stream + * @param named a [[Named]] config used to name the processor in the topology + * @param keyValueMapper a function used to map from the (key, value) of this stream + * to the key of the `GlobalKTable` + * @param joiner a function that computes the join result for a pair of matching records + * @return a [[KStream]] that contains join-records for each key and values computed by the given `joiner`, + * one output for each input [[KStream]] record + * @see `org.apache.kafka.streams.kstream.KStream#join` + */ + def join[GK, GV, RV](globalKTable: GlobalKTable[GK, GV], named: Named)( + keyValueMapper: (K, V) => GK, + joiner: (V, GV) => RV + ): KStream[K, RV] = + new KStream( + inner.join[GK, GV, RV]( + globalKTable, + ((k: K, v: V) => keyValueMapper(k, v)).asKeyValueMapper, + ((v: V, gv: GV) => joiner(v, gv)).asValueJoiner, + named + ) + ) + + /** + * Join records of this stream with `GlobalKTable`'s records using non-windowed left equi join. 
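+ *
+ * A minimal sketch added for illustration (the global table and the types are hypothetical):
+ * {{{
+ *   val orders: KStream[String, String] = //..            // value holds a customer id
+ *   val customers: GlobalKTable[String, String] = //..    // customer id -> customer name
+ *
+ *   val enriched: KStream[String, String] =
+ *     orders.leftJoin(customers)(
+ *       (_, customerId) => customerId,                    // map each record to the table key
+ *       (customerId, name) => customerId + ":" + Option(name).getOrElse("unknown")
+ *     )
+ * }}}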
+ * + * @param globalKTable the `GlobalKTable` to be joined with this stream + * @param keyValueMapper a function used to map from the (key, value) of this stream + * to the key of the `GlobalKTable` + * @param joiner a function that computes the join result for a pair of matching records + * @return a [[KStream]] that contains join-records for each key and values computed by the given `joiner`, + * one output for each input [[KStream]] record + * @see `org.apache.kafka.streams.kstream.KStream#leftJoin` + */ + def leftJoin[GK, GV, RV](globalKTable: GlobalKTable[GK, GV])( + keyValueMapper: (K, V) => GK, + joiner: (V, GV) => RV + ): KStream[K, RV] = + new KStream(inner.leftJoin[GK, GV, RV](globalKTable, keyValueMapper.asKeyValueMapper, joiner.asValueJoiner)) + + /** + * Join records of this stream with `GlobalKTable`'s records using non-windowed left equi join. + * + * @param globalKTable the `GlobalKTable` to be joined with this stream + * @param named a [[Named]] config used to name the processor in the topology + * @param keyValueMapper a function used to map from the (key, value) of this stream + * to the key of the `GlobalKTable` + * @param joiner a function that computes the join result for a pair of matching records + * @return a [[KStream]] that contains join-records for each key and values computed by the given `joiner`, + * one output for each input [[KStream]] record + * @see `org.apache.kafka.streams.kstream.KStream#leftJoin` + */ + def leftJoin[GK, GV, RV](globalKTable: GlobalKTable[GK, GV], named: Named)( + keyValueMapper: (K, V) => GK, + joiner: (V, GV) => RV + ): KStream[K, RV] = + new KStream(inner.leftJoin[GK, GV, RV](globalKTable, keyValueMapper.asKeyValueMapper, joiner.asValueJoiner, named)) + + /** + * Merge this stream and the given stream into one larger stream. + *
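+ * A minimal usage sketch added for illustration (the two input streams are hypothetical):
+ * {{{
+ *   val androidEvents: KStream[String, String] = //..
+ *   val iosEvents: KStream[String, String] = //..
+ *
+ *   val allEvents: KStream[String, String] = androidEvents.merge(iosEvents)
+ * }}}
+ *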
                + * There is no ordering guarantee between records from this `KStream` and records from the provided `KStream` + * in the merged stream. Relative order is preserved within each input stream though (ie, records within + * one input stream are processed in order). + * + * @param stream a stream which is to be merged into this stream + * @return a merged stream containing all records from this and the provided [[KStream]] + * @see `org.apache.kafka.streams.kstream.KStream#merge` + */ + def merge(stream: KStream[K, V]): KStream[K, V] = + new KStream(inner.merge(stream.inner)) + + /** + * Merge this stream and the given stream into one larger stream. + *
                + * There is no ordering guarantee between records from this `KStream` and records from the provided `KStream` + * in the merged stream. Relative order is preserved within each input stream though (ie, records within + * one input stream are processed in order). + * + * @param named a [[Named]] config used to name the processor in the topology + * @param stream a stream which is to be merged into this stream + * @return a merged stream containing all records from this and the provided [[KStream]] + * @see `org.apache.kafka.streams.kstream.KStream#merge` + */ + def merge(stream: KStream[K, V], named: Named): KStream[K, V] = + new KStream(inner.merge(stream.inner, named)) + + /** + * Perform an action on each record of `KStream`. + *
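+ * As a quick illustration, a minimal sketch (the stream is hypothetical):
+ * {{{
+ *   val events: KStream[String, String] = //..
+ *
+ *   // log every record and pass it on unchanged
+ *   val logged: KStream[String, String] = events.peek((key, value) => println("key=" + key + " value=" + value))
+ * }}}
+ *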
                + * Peek is a non-terminal operation that triggers a side effect (such as logging or statistics collection) + * and returns an unchanged stream. + * + * @param action an action to perform on each record + * @see `org.apache.kafka.streams.kstream.KStream#peek` + */ + def peek(action: (K, V) => Unit): KStream[K, V] = + new KStream(inner.peek(action.asForeachAction)) + + /** + * Perform an action on each record of `KStream`. + *
                + * Peek is a non-terminal operation that triggers a side effect (such as logging or statistics collection) + * and returns an unchanged stream. + * + * @param action an action to perform on each record + * @param named a [[Named]] config used to name the processor in the topology + * @see `org.apache.kafka.streams.kstream.KStream#peek` + */ + def peek(action: (K, V) => Unit, named: Named): KStream[K, V] = + new KStream(inner.peek(action.asForeachAction, named)) +} diff --git a/streams/streams-scala/src/main/scala/org/apache/kafka/streams/scala/kstream/KTable.scala b/streams/streams-scala/src/main/scala/org/apache/kafka/streams/scala/kstream/KTable.scala new file mode 100644 index 0000000..3a405b6 --- /dev/null +++ b/streams/streams-scala/src/main/scala/org/apache/kafka/streams/scala/kstream/KTable.scala @@ -0,0 +1,763 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.scala +package kstream + +import org.apache.kafka.common.utils.Bytes +import org.apache.kafka.streams.kstream.{TableJoined, ValueJoiner, ValueTransformerWithKeySupplier, KTable => KTableJ} +import org.apache.kafka.streams.scala.FunctionsCompatConversions.{ + FunctionFromFunction, + KeyValueMapperFromFunction, + MapperFromFunction, + PredicateFromFunction, + ValueMapperFromFunction, + ValueMapperWithKeyFromFunction +} +import org.apache.kafka.streams.state.KeyValueStore + +/** + * Wraps the Java class [[org.apache.kafka.streams.kstream.KTable]] and delegates method calls to the underlying Java object. 
+ * + * @tparam K Type of keys + * @tparam V Type of values + * @param inner The underlying Java abstraction for KTable + * @see `org.apache.kafka.streams.kstream.KTable` + */ +class KTable[K, V](val inner: KTableJ[K, V]) { + + /** + * Create a new [[KTable]] that consists all records of this [[KTable]] which satisfies the given + * predicate + * + * @param predicate a filter that is applied to each record + * @return a [[KTable]] that contains only those records that satisfy the given predicate + * @see `org.apache.kafka.streams.kstream.KTable#filter` + */ + def filter(predicate: (K, V) => Boolean): KTable[K, V] = + new KTable(inner.filter(predicate.asPredicate)) + + /** + * Create a new [[KTable]] that consists all records of this [[KTable]] which satisfies the given + * predicate + * + * @param predicate a filter that is applied to each record + * @param named a [[Named]] config used to name the processor in the topology + * @return a [[KTable]] that contains only those records that satisfy the given predicate + * @see `org.apache.kafka.streams.kstream.KTable#filter` + */ + def filter(predicate: (K, V) => Boolean, named: Named): KTable[K, V] = + new KTable(inner.filter(predicate.asPredicate, named)) + + /** + * Create a new [[KTable]] that consists all records of this [[KTable]] which satisfies the given + * predicate + * + * @param predicate a filter that is applied to each record + * @param materialized a `Materialized` that describes how the `StateStore` for the resulting [[KTable]] + * should be materialized. + * @return a [[KTable]] that contains only those records that satisfy the given predicate + * @see `org.apache.kafka.streams.kstream.KTable#filter` + */ + def filter(predicate: (K, V) => Boolean, materialized: Materialized[K, V, ByteArrayKeyValueStore]): KTable[K, V] = + new KTable(inner.filter(predicate.asPredicate, materialized)) + + /** + * Create a new [[KTable]] that consists all records of this [[KTable]] which satisfies the given + * predicate + * + * @param predicate a filter that is applied to each record + * @param named a [[Named]] config used to name the processor in the topology + * @param materialized a `Materialized` that describes how the `StateStore` for the resulting [[KTable]] + * should be materialized. 
+ * @return a [[KTable]] that contains only those records that satisfy the given predicate + * @see `org.apache.kafka.streams.kstream.KTable#filter` + */ + def filter( + predicate: (K, V) => Boolean, + named: Named, + materialized: Materialized[K, V, ByteArrayKeyValueStore] + ): KTable[K, V] = + new KTable(inner.filter(predicate.asPredicate, named, materialized)) + + /** + * Create a new [[KTable]] that consists all records of this [[KTable]] which do not satisfy the given + * predicate + * + * @param predicate a filter that is applied to each record + * @return a [[KTable]] that contains only those records that do not satisfy the given predicate + * @see `org.apache.kafka.streams.kstream.KTable#filterNot` + */ + def filterNot(predicate: (K, V) => Boolean): KTable[K, V] = + new KTable(inner.filterNot(predicate.asPredicate)) + + /** + * Create a new [[KTable]] that consists all records of this [[KTable]] which do not satisfy the given + * predicate + * + * @param predicate a filter that is applied to each record + * @param named a [[Named]] config used to name the processor in the topology + * @return a [[KTable]] that contains only those records that do not satisfy the given predicate + * @see `org.apache.kafka.streams.kstream.KTable#filterNot` + */ + def filterNot(predicate: (K, V) => Boolean, named: Named): KTable[K, V] = + new KTable(inner.filterNot(predicate.asPredicate, named)) + + /** + * Create a new [[KTable]] that consists all records of this [[KTable]] which do not satisfy the given + * predicate + * + * @param predicate a filter that is applied to each record + * @param materialized a `Materialized` that describes how the `StateStore` for the resulting [[KTable]] + * should be materialized. + * @return a [[KTable]] that contains only those records that do not satisfy the given predicate + * @see `org.apache.kafka.streams.kstream.KTable#filterNot` + */ + def filterNot(predicate: (K, V) => Boolean, materialized: Materialized[K, V, ByteArrayKeyValueStore]): KTable[K, V] = + new KTable(inner.filterNot(predicate.asPredicate, materialized)) + + /** + * Create a new [[KTable]] that consists all records of this [[KTable]] which do not satisfy the given + * predicate + * + * @param predicate a filter that is applied to each record + * @param named a [[Named]] config used to name the processor in the topology + * @param materialized a `Materialized` that describes how the `StateStore` for the resulting [[KTable]] + * should be materialized. + * @return a [[KTable]] that contains only those records that do not satisfy the given predicate + * @see `org.apache.kafka.streams.kstream.KTable#filterNot` + */ + def filterNot( + predicate: (K, V) => Boolean, + named: Named, + materialized: Materialized[K, V, ByteArrayKeyValueStore] + ): KTable[K, V] = + new KTable(inner.filterNot(predicate.asPredicate, named, materialized)) + + /** + * Create a new [[KTable]] by transforming the value of each record in this [[KTable]] into a new value + * (with possible new type) in the new [[KTable]]. + *
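+ * For example (assuming a hypothetical `userNames: KTable[String, String]`):
+ * {{{
+ * val nameLengths: KTable[String, Int] = userNames.mapValues(name => name.length)
+ * }}}
+ *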

                + * The provided `mapper`, a function `V => VR` is applied to each input record value and computes a new value for it + * + * @param mapper , a function `V => VR` that computes a new output value + * @return a [[KTable]] that contains records with unmodified key and new values (possibly of different type) + * @see `org.apache.kafka.streams.kstream.KTable#mapValues` + */ + def mapValues[VR](mapper: V => VR): KTable[K, VR] = + new KTable(inner.mapValues[VR](mapper.asValueMapper)) + + /** + * Create a new [[KTable]] by transforming the value of each record in this [[KTable]] into a new value + * (with possible new type) in the new [[KTable]]. + *

                + * The provided `mapper`, a function `V => VR` is applied to each input record value and computes a new value for it + * + * @param mapper , a function `V => VR` that computes a new output value + * @param named a [[Named]] config used to name the processor in the topology + * @return a [[KTable]] that contains records with unmodified key and new values (possibly of different type) + * @see `org.apache.kafka.streams.kstream.KTable#mapValues` + */ + def mapValues[VR](mapper: V => VR, named: Named): KTable[K, VR] = + new KTable(inner.mapValues[VR](mapper.asValueMapper, named)) + + /** + * Create a new [[KTable]] by transforming the value of each record in this [[KTable]] into a new value + * (with possible new type) in the new [[KTable]]. + *

                + * The provided `mapper`, a function `V => VR` is applied to each input record value and computes a new value for it + * + * @param mapper , a function `V => VR` that computes a new output value + * @param materialized a `Materialized` that describes how the `StateStore` for the resulting [[KTable]] + * should be materialized. + * @return a [[KTable]] that contains records with unmodified key and new values (possibly of different type) + * @see `org.apache.kafka.streams.kstream.KTable#mapValues` + */ + def mapValues[VR](mapper: V => VR, materialized: Materialized[K, VR, ByteArrayKeyValueStore]): KTable[K, VR] = + new KTable(inner.mapValues[VR](mapper.asValueMapper, materialized)) + + /** + * Create a new [[KTable]] by transforming the value of each record in this [[KTable]] into a new value + * (with possible new type) in the new [[KTable]]. + *

                + * The provided `mapper`, a function `V => VR` is applied to each input record value and computes a new value for it + * + * @param mapper , a function `V => VR` that computes a new output value + * @param named a [[Named]] config used to name the processor in the topology + * @param materialized a `Materialized` that describes how the `StateStore` for the resulting [[KTable]] + * should be materialized. + * @return a [[KTable]] that contains records with unmodified key and new values (possibly of different type) + * @see `org.apache.kafka.streams.kstream.KTable#mapValues` + */ + def mapValues[VR]( + mapper: V => VR, + named: Named, + materialized: Materialized[K, VR, ByteArrayKeyValueStore] + ): KTable[K, VR] = + new KTable(inner.mapValues[VR](mapper.asValueMapper, named, materialized)) + + /** + * Create a new [[KTable]] by transforming the value of each record in this [[KTable]] into a new value + * (with possible new type) in the new [[KTable]]. + *
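+ * For example, with a key-aware mapper (assuming a hypothetical `wordCounts: KTable[String, Long]`):
+ * {{{
+ * val descriptions: KTable[String, String] = wordCounts.mapValues((word, count) => s"$word seen $count times")
+ * }}}
+ *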

                + * The provided `mapper`, a function `(K, V) => VR` is applied to each input record value and computes a new value for it + * + * @param mapper , a function `(K, V) => VR` that computes a new output value + * @return a [[KTable]] that contains records with unmodified key and new values (possibly of different type) + * @see `org.apache.kafka.streams.kstream.KTable#mapValues` + */ + def mapValues[VR](mapper: (K, V) => VR): KTable[K, VR] = + new KTable(inner.mapValues[VR](mapper.asValueMapperWithKey)) + + /** + * Create a new [[KTable]] by transforming the value of each record in this [[KTable]] into a new value + * (with possible new type) in the new [[KTable]]. + *

                + * The provided `mapper`, a function `(K, V) => VR` is applied to each input record value and computes a new value for it + * + * @param mapper , a function `(K, V) => VR` that computes a new output value + * @param named a [[Named]] config used to name the processor in the topology + * @return a [[KTable]] that contains records with unmodified key and new values (possibly of different type) + * @see `org.apache.kafka.streams.kstream.KTable#mapValues` + */ + def mapValues[VR](mapper: (K, V) => VR, named: Named): KTable[K, VR] = + new KTable(inner.mapValues[VR](mapper.asValueMapperWithKey, named)) + + /** + * Create a new [[KTable]] by transforming the value of each record in this [[KTable]] into a new value + * (with possible new type) in the new [[KTable]]. + *

+ * The provided `mapper`, a function `(K, V) => VR` is applied to each input record value and computes a new value for it
+ *
+ * @param mapper , a function `(K, V) => VR` that computes a new output value
+ * @param materialized a `Materialized` that describes how the `StateStore` for the resulting [[KTable]]
+ * should be materialized.
+ * @return a [[KTable]] that contains records with unmodified key and new values (possibly of different type)
+ * @see `org.apache.kafka.streams.kstream.KTable#mapValues`
+ */
+ def mapValues[VR](mapper: (K, V) => VR, materialized: Materialized[K, VR, ByteArrayKeyValueStore]): KTable[K, VR] =
+ new KTable(inner.mapValues[VR](mapper.asValueMapperWithKey, materialized))
+
+ /**
+ * Create a new [[KTable]] by transforming the value of each record in this [[KTable]] into a new value
+ * (with possible new type) in the new [[KTable]].
+ *

                + * The provided `mapper`, a function `(K, V) => VR` is applied to each input record value and computes a new value for it + * + * @param mapper , a function `(K, V) => VR` that computes a new output value + * @param named a [[Named]] config used to name the processor in the topology + * @param materialized a `Materialized` that describes how the `StateStore` for the resulting [[KTable]] + * should be materialized. + * @return a [[KTable]] that contains records with unmodified key and new values (possibly of different type) + * @see `org.apache.kafka.streams.kstream.KTable#mapValues` + */ + def mapValues[VR]( + mapper: (K, V) => VR, + named: Named, + materialized: Materialized[K, VR, ByteArrayKeyValueStore] + ): KTable[K, VR] = + new KTable(inner.mapValues[VR](mapper.asValueMapperWithKey, named, materialized)) + + /** + * Convert this changelog stream to a [[KStream]]. + * + * @return a [[KStream]] that contains the same records as this [[KTable]] + * @see `org.apache.kafka.streams.kstream.KTable#toStream` + */ + def toStream: KStream[K, V] = + new KStream(inner.toStream) + + /** + * Convert this changelog stream to a [[KStream]]. + * + * @param named a [[Named]] config used to name the processor in the topology + * @return a [[KStream]] that contains the same records as this [[KTable]] + * @see `org.apache.kafka.streams.kstream.KTable#toStream` + */ + def toStream(named: Named): KStream[K, V] = + new KStream(inner.toStream(named)) + + /** + * Convert this changelog stream to a [[KStream]] using the given key/value mapper to select the new key + * + * @param mapper a function that computes a new key for each record + * @return a [[KStream]] that contains the same records as this [[KTable]] + * @see `org.apache.kafka.streams.kstream.KTable#toStream` + */ + def toStream[KR](mapper: (K, V) => KR): KStream[KR, V] = + new KStream(inner.toStream[KR](mapper.asKeyValueMapper)) + + /** + * Convert this changelog stream to a [[KStream]] using the given key/value mapper to select the new key + * + * @param mapper a function that computes a new key for each record + * @param named a [[Named]] config used to name the processor in the topology + * @return a [[KStream]] that contains the same records as this [[KTable]] + * @see `org.apache.kafka.streams.kstream.KTable#toStream` + */ + def toStream[KR](mapper: (K, V) => KR, named: Named): KStream[KR, V] = + new KStream(inner.toStream[KR](mapper.asKeyValueMapper, named)) + + /** + * Suppress some updates from this changelog stream, determined by the supplied [[org.apache.kafka.streams.kstream.Suppressed]] configuration. + * + * This controls what updates downstream table and stream operations will receive. + * + * @param suppressed Configuration object determining what, if any, updates to suppress. + * @return A new KTable with the desired suppression characteristics. + * @see `org.apache.kafka.streams.kstream.KTable#suppress` + */ + def suppress(suppressed: org.apache.kafka.streams.kstream.Suppressed[_ >: K]): KTable[K, V] = + new KTable(inner.suppress(suppressed)) + + /** + * Create a new `KTable` by transforming the value of each record in this `KTable` into a new value, (with possibly new type). + * Transform the value of each input record into a new value (with possible new type) of the output record. + * A `ValueTransformerWithKey` (provided by the given `ValueTransformerWithKeySupplier`) is applied to each input + * record value and computes a new value for it. 
+ * This is similar to `#mapValues(ValueMapperWithKey)`, but more flexible, allowing access to additional state-stores,
+ * and to the `ProcessorContext`.
+ * If the downstream topology uses aggregation functions (e.g. `KGroupedTable#reduce`, `KGroupedTable#aggregate`, etc.),
+ * care must be taken when dealing with state (either held in state-stores or transformer instances) to ensure correct
+ * aggregate results.
+ * In contrast, if the resulting KTable is materialized (cf. `#transformValues(ValueTransformerWithKeySupplier, Materialized, String...)`),
+ * such concerns are handled for you.
+ * Any state used by the transformer must be created and registered beforehand via `addStateStore` or
+ * `addGlobalStore` before it can be connected to the `Transformer`.
+ *
+ * @param valueTransformerWithKeySupplier an instance of `ValueTransformerWithKeySupplier` that generates a `ValueTransformerWithKey`.
+ * At least one transformer instance will be created per streaming task.
+ * Transformer implementations do not need to be thread-safe.
+ * @param stateStoreNames the names of the state stores used by the processor
+ * @return a [[KTable]] that contains records with unmodified key and new values (possibly of different type)
+ * @see `org.apache.kafka.streams.kstream.KTable#transformValues`
+ */
+ def transformValues[VR](
+ valueTransformerWithKeySupplier: ValueTransformerWithKeySupplier[K, V, VR],
+ stateStoreNames: String*
+ ): KTable[K, VR] =
+ new KTable(inner.transformValues[VR](valueTransformerWithKeySupplier, stateStoreNames: _*))
+
+ /**
+ * Create a new `KTable` by transforming the value of each record in this `KTable` into a new value (with a possibly new type).
+ * A `ValueTransformerWithKey` (provided by the given `ValueTransformerWithKeySupplier`) is applied to each input
+ * record value and computes a new value for it.
+ * This is similar to `#mapValues(ValueMapperWithKey)`, but more flexible, allowing access to additional state-stores,
+ * and to the `ProcessorContext`.
+ * If the downstream topology uses aggregation functions (e.g. `KGroupedTable#reduce`, `KGroupedTable#aggregate`, etc.),
+ * care must be taken when dealing with state (either held in state-stores or transformer instances) to ensure correct
+ * aggregate results.
+ * In contrast, if the resulting KTable is materialized (cf. `#transformValues(ValueTransformerWithKeySupplier, Materialized, String...)`),
+ * such concerns are handled for you.
+ * Any state used by the transformer must be created and registered beforehand via `addStateStore` or
+ * `addGlobalStore` before it can be connected to the `Transformer`.
+ *
+ * @param valueTransformerWithKeySupplier an instance of `ValueTransformerWithKeySupplier` that generates a `ValueTransformerWithKey`.
+ * At least one transformer instance will be created per streaming task.
+ * Transformer implementations do not need to be thread-safe.
+ * @param named a [[Named]] config used to name the processor in the topology
+ * @param stateStoreNames the names of the state stores used by the processor
+ * @return a [[KTable]] that contains records with unmodified key and new values (possibly of different type)
+ * @see `org.apache.kafka.streams.kstream.KTable#transformValues`
+ */
+ def transformValues[VR](
+ valueTransformerWithKeySupplier: ValueTransformerWithKeySupplier[K, V, VR],
+ named: Named,
+ stateStoreNames: String*
+ ): KTable[K, VR] =
+ new KTable(inner.transformValues[VR](valueTransformerWithKeySupplier, named, stateStoreNames: _*))
+
+ /**
+ * Create a new `KTable` by transforming the value of each record in this `KTable` into a new value (with a possibly new type).
+ * A `ValueTransformerWithKey` (provided by the given `ValueTransformerWithKeySupplier`) is applied to each input
+ * record value and computes a new value for it.
+ * This is similar to `#mapValues(ValueMapperWithKey)`, but more flexible, allowing stateful, rather than stateless,
+ * record-by-record operation, access to additional state-stores, and access to the `ProcessorContext`.
+ * Any state used by the transformer must be created and registered beforehand via `addStateStore` or
+ * `addGlobalStore` before it can be connected to the `Transformer`.
+ * The resulting `KTable` is materialized into another state store (additional to the provided state store names)
+ * as specified by the user via the `Materialized` parameter, and is queryable through its given name.
+ *
+ * @param valueTransformerWithKeySupplier an instance of `ValueTransformerWithKeySupplier` that generates a `ValueTransformerWithKey`.
+ * At least one transformer instance will be created per streaming task.
+ * Transformer implementations do not need to be thread-safe.
+ * @param materialized an instance of `Materialized` used to describe how the state store of the
+ * resulting table should be materialized.
+ * @param stateStoreNames the names of the state stores used by the processor
+ * @return a [[KTable]] that contains records with unmodified key and new values (possibly of different type)
+ * @see `org.apache.kafka.streams.kstream.KTable#transformValues`
+ */
+ def transformValues[VR](
+ valueTransformerWithKeySupplier: ValueTransformerWithKeySupplier[K, V, VR],
+ materialized: Materialized[K, VR, KeyValueStore[Bytes, Array[Byte]]],
+ stateStoreNames: String*
+ ): KTable[K, VR] =
+ new KTable(inner.transformValues[VR](valueTransformerWithKeySupplier, materialized, stateStoreNames: _*))
+
+ /**
+ * Create a new `KTable` by transforming the value of each record in this `KTable` into a new value (with a possibly new type).
+ * A `ValueTransformerWithKey` (provided by the given `ValueTransformerWithKeySupplier`) is applied to each input
+ * record value and computes a new value for it.
+ * This is similar to `#mapValues(ValueMapperWithKey)`, but more flexible, allowing stateful, rather than stateless,
+ * record-by-record operation, access to additional state-stores, and access to the `ProcessorContext`.
+ * Any state used by the transformer must be created and registered beforehand via `addStateStore` or
+ * `addGlobalStore` before it can be connected to the `Transformer`.
+ * The resulting `KTable` is materialized into another state store (additional to the provided state store names)
+ * as specified by the user via the `Materialized` parameter, and is queryable through its given name.
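+ *
+ * A minimal sketch of such a supplier (names, types, and the store name are illustrative; implicit serdes are
+ * assumed to be in scope):
+ * {{{
+ * val upperCaser = new ValueTransformerWithKeySupplier[String, String, String] {
+ *   override def get(): ValueTransformerWithKey[String, String, String] =
+ *     new ValueTransformerWithKey[String, String, String] {
+ *       override def init(context: ProcessorContext): Unit = ()
+ *       override def transform(readOnlyKey: String, value: String): String = value.toUpperCase
+ *       override def close(): Unit = ()
+ *     }
+ * }
+ * val shouted: KTable[String, String] =
+ *   table.transformValues(upperCaser, Materialized.as[String, String, ByteArrayKeyValueStore]("shouted-store"))
+ * }}}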
+ *
+ * @param valueTransformerWithKeySupplier an instance of `ValueTransformerWithKeySupplier` that generates a `ValueTransformerWithKey`.
+ * At least one transformer instance will be created per streaming task.
+ * Transformer implementations do not need to be thread-safe.
+ * @param materialized an instance of `Materialized` used to describe how the state store of the
+ * resulting table should be materialized.
+ * @param named a [[Named]] config used to name the processor in the topology
+ * @param stateStoreNames the names of the state stores used by the processor
+ * @return a [[KTable]] that contains records with unmodified key and new values (possibly of different type)
+ * @see `org.apache.kafka.streams.kstream.KTable#transformValues`
+ */
+ def transformValues[VR](
+ valueTransformerWithKeySupplier: ValueTransformerWithKeySupplier[K, V, VR],
+ materialized: Materialized[K, VR, KeyValueStore[Bytes, Array[Byte]]],
+ named: Named,
+ stateStoreNames: String*
+ ): KTable[K, VR] =
+ new KTable(inner.transformValues[VR](valueTransformerWithKeySupplier, materialized, named, stateStoreNames: _*))
+
+ /**
+ * Re-groups the records of this [[KTable]] using the provided key/value mapper
+ * and `Serde`s as specified by `Grouped`.
+ *
+ * @param selector a function that computes a new grouping key and value to be aggregated
+ * @param grouped the `Grouped` instance used to specify `Serdes`
+ * @return a [[KGroupedTable]] that contains the re-grouped records of the original [[KTable]]
+ * @see `org.apache.kafka.streams.kstream.KTable#groupBy`
+ */
+ def groupBy[KR, VR](selector: (K, V) => (KR, VR))(implicit grouped: Grouped[KR, VR]): KGroupedTable[KR, VR] =
+ new KGroupedTable(inner.groupBy(selector.asKeyValueMapper, grouped))
+
+ /**
+ * Join records of this [[KTable]] with another [[KTable]]'s records using non-windowed inner equi join.
+ *
+ * @param other the other [[KTable]] to be joined with this [[KTable]]
+ * @param joiner a function that computes the join result for a pair of matching records
+ * @return a [[KTable]] that contains join-records for each key and values computed by the given joiner,
+ * one for each matched record-pair with the same key
+ * @see `org.apache.kafka.streams.kstream.KTable#join`
+ */
+ def join[VO, VR](other: KTable[K, VO])(joiner: (V, VO) => VR): KTable[K, VR] =
+ new KTable(inner.join[VO, VR](other.inner, joiner.asValueJoiner))
+
+ /**
+ * Join records of this [[KTable]] with another [[KTable]]'s records using non-windowed inner equi join.
+ *
+ * @param other the other [[KTable]] to be joined with this [[KTable]]
+ * @param named a [[Named]] config used to name the processor in the topology
+ * @param joiner a function that computes the join result for a pair of matching records
+ * @return a [[KTable]] that contains join-records for each key and values computed by the given joiner,
+ * one for each matched record-pair with the same key
+ * @see `org.apache.kafka.streams.kstream.KTable#join`
+ */
+ def join[VO, VR](other: KTable[K, VO], named: Named)(joiner: (V, VO) => VR): KTable[K, VR] =
+ new KTable(inner.join[VO, VR](other.inner, joiner.asValueJoiner, named))
+
+ /**
+ * Join records of this [[KTable]] with another [[KTable]]'s records using non-windowed inner equi join.
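+ *
+ * For example, given hypothetical co-partitioned tables `totals: KTable[String, Long]` and
+ * `names: KTable[String, String]` (implicit serdes and the store name are illustrative assumptions):
+ * {{{
+ * val report: KTable[String, String] =
+ *   totals.join(names, Materialized.as[String, String, ByteArrayKeyValueStore]("totals-with-names-store")) {
+ *     (total, name) => s"$name: $total"
+ *   }
+ * }}}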
+ * + * @param other the other [[KTable]] to be joined with this [[KTable]] + * @param joiner a function that computes the join result for a pair of matching records + * @param materialized a `Materialized` that describes how the `StateStore` for the resulting [[KTable]] + * should be materialized. + * @return a [[KTable]] that contains join-records for each key and values computed by the given joiner, + * one for each matched record-pair with the same key + * @see `org.apache.kafka.streams.kstream.KTable#join` + */ + def join[VO, VR](other: KTable[K, VO], materialized: Materialized[K, VR, ByteArrayKeyValueStore])( + joiner: (V, VO) => VR + ): KTable[K, VR] = + new KTable(inner.join[VO, VR](other.inner, joiner.asValueJoiner, materialized)) + + /** + * Join records of this [[KTable]] with another [[KTable]]'s records using non-windowed inner equi join. + * + * @param other the other [[KTable]] to be joined with this [[KTable]] + * @param joiner a function that computes the join result for a pair of matching records + * @param named a [[Named]] config used to name the processor in the topology + * @param materialized a `Materialized` that describes how the `StateStore` for the resulting [[KTable]] + * should be materialized. + * @return a [[KTable]] that contains join-records for each key and values computed by the given joiner, + * one for each matched record-pair with the same key + * @see `org.apache.kafka.streams.kstream.KTable#join` + */ + def join[VO, VR](other: KTable[K, VO], named: Named, materialized: Materialized[K, VR, ByteArrayKeyValueStore])( + joiner: (V, VO) => VR + ): KTable[K, VR] = + new KTable(inner.join[VO, VR](other.inner, joiner.asValueJoiner, named, materialized)) + + /** + * Join records of this [[KTable]] with another [[KTable]]'s records using non-windowed left equi join. + * + * @param other the other [[KTable]] to be joined with this [[KTable]] + * @param joiner a function that computes the join result for a pair of matching records + * @return a [[KTable]] that contains join-records for each key and values computed by the given joiner, + * one for each matched record-pair with the same key + * @see `org.apache.kafka.streams.kstream.KTable#leftJoin` + */ + def leftJoin[VO, VR](other: KTable[K, VO])(joiner: (V, VO) => VR): KTable[K, VR] = + new KTable(inner.leftJoin[VO, VR](other.inner, joiner.asValueJoiner)) + + /** + * Join records of this [[KTable]] with another [[KTable]]'s records using non-windowed left equi join. + * + * @param other the other [[KTable]] to be joined with this [[KTable]] + * @param named a [[Named]] config used to name the processor in the topology + * @param joiner a function that computes the join result for a pair of matching records + * @return a [[KTable]] that contains join-records for each key and values computed by the given joiner, + * one for each matched record-pair with the same key + * @see `org.apache.kafka.streams.kstream.KTable#leftJoin` + */ + def leftJoin[VO, VR](other: KTable[K, VO], named: Named)(joiner: (V, VO) => VR): KTable[K, VR] = + new KTable(inner.leftJoin[VO, VR](other.inner, joiner.asValueJoiner, named)) + + /** + * Join records of this [[KTable]] with another [[KTable]]'s records using non-windowed left equi join. + * + * @param other the other [[KTable]] to be joined with this [[KTable]] + * @param joiner a function that computes the join result for a pair of matching records + * @param materialized a `Materialized` that describes how the `StateStore` for the resulting [[KTable]] + * should be materialized. 
+ * @return a [[KTable]] that contains join-records for each key and values computed by the given joiner, + * one for each matched record-pair with the same key + * @see `org.apache.kafka.streams.kstream.KTable#leftJoin` + */ + def leftJoin[VO, VR](other: KTable[K, VO], materialized: Materialized[K, VR, ByteArrayKeyValueStore])( + joiner: (V, VO) => VR + ): KTable[K, VR] = + new KTable(inner.leftJoin[VO, VR](other.inner, joiner.asValueJoiner, materialized)) + + /** + * Join records of this [[KTable]] with another [[KTable]]'s records using non-windowed left equi join. + * + * @param other the other [[KTable]] to be joined with this [[KTable]] + * @param named a [[Named]] config used to name the processor in the topology + * @param joiner a function that computes the join result for a pair of matching records + * @param materialized a `Materialized` that describes how the `StateStore` for the resulting [[KTable]] + * should be materialized. + * @return a [[KTable]] that contains join-records for each key and values computed by the given joiner, + * one for each matched record-pair with the same key + * @see `org.apache.kafka.streams.kstream.KTable#leftJoin` + */ + def leftJoin[VO, VR](other: KTable[K, VO], named: Named, materialized: Materialized[K, VR, ByteArrayKeyValueStore])( + joiner: (V, VO) => VR + ): KTable[K, VR] = + new KTable(inner.leftJoin[VO, VR](other.inner, joiner.asValueJoiner, named, materialized)) + + /** + * Join records of this [[KTable]] with another [[KTable]]'s records using non-windowed outer equi join. + * + * @param other the other [[KTable]] to be joined with this [[KTable]] + * @param joiner a function that computes the join result for a pair of matching records + * @return a [[KTable]] that contains join-records for each key and values computed by the given joiner, + * one for each matched record-pair with the same key + * @see `org.apache.kafka.streams.kstream.KTable#leftJoin` + */ + def outerJoin[VO, VR](other: KTable[K, VO])(joiner: (V, VO) => VR): KTable[K, VR] = + new KTable(inner.outerJoin[VO, VR](other.inner, joiner.asValueJoiner)) + + /** + * Join records of this [[KTable]] with another [[KTable]]'s records using non-windowed outer equi join. + * + * @param other the other [[KTable]] to be joined with this [[KTable]] + * @param named a [[Named]] config used to name the processor in the topology + * @param joiner a function that computes the join result for a pair of matching records + * @return a [[KTable]] that contains join-records for each key and values computed by the given joiner, + * one for each matched record-pair with the same key + * @see `org.apache.kafka.streams.kstream.KTable#leftJoin` + */ + def outerJoin[VO, VR](other: KTable[K, VO], named: Named)(joiner: (V, VO) => VR): KTable[K, VR] = + new KTable(inner.outerJoin[VO, VR](other.inner, joiner.asValueJoiner, named)) + + /** + * Join records of this [[KTable]] with another [[KTable]]'s records using non-windowed outer equi join. + * + * @param other the other [[KTable]] to be joined with this [[KTable]] + * @param joiner a function that computes the join result for a pair of matching records + * @param materialized a `Materialized` that describes how the `StateStore` for the resulting [[KTable]] + * should be materialized. 
+ * @return a [[KTable]] that contains join-records for each key and values computed by the given joiner, + * one for each matched record-pair with the same key + * @see `org.apache.kafka.streams.kstream.KTable#leftJoin` + */ + def outerJoin[VO, VR](other: KTable[K, VO], materialized: Materialized[K, VR, ByteArrayKeyValueStore])( + joiner: (V, VO) => VR + ): KTable[K, VR] = + new KTable(inner.outerJoin[VO, VR](other.inner, joiner.asValueJoiner, materialized)) + + /** + * Join records of this [[KTable]] with another [[KTable]]'s records using non-windowed outer equi join. + * + * @param other the other [[KTable]] to be joined with this [[KTable]] + * @param named a [[Named]] config used to name the processor in the topology + * @param joiner a function that computes the join result for a pair of matching records + * @param materialized a `Materialized` that describes how the `StateStore` for the resulting [[KTable]] + * should be materialized. + * @return a [[KTable]] that contains join-records for each key and values computed by the given joiner, + * one for each matched record-pair with the same key + * @see `org.apache.kafka.streams.kstream.KTable#leftJoin` + */ + def outerJoin[VO, VR](other: KTable[K, VO], named: Named, materialized: Materialized[K, VR, ByteArrayKeyValueStore])( + joiner: (V, VO) => VR + ): KTable[K, VR] = + new KTable(inner.outerJoin[VO, VR](other.inner, joiner.asValueJoiner, named, materialized)) + + /** + * Join records of this [[KTable]] with another [[KTable]]'s records using non-windowed inner join. Records from this + * table are joined according to the result of keyExtractor on the other KTable. + * + * @param other the other [[KTable]] to be joined with this [[KTable]], keyed on the value obtained from keyExtractor + * @param keyExtractor a function that extracts the foreign key from this table's value + * @param joiner a function that computes the join result for a pair of matching records + * @param materialized a `Materialized` that describes how the `StateStore` for the resulting [[KTable]] + * should be materialized. + * @return a [[KTable]] that contains join-records for each key and values computed by the given joiner, + * one for each matched record-pair with the same key + */ + def join[VR, KO, VO]( + other: KTable[KO, VO], + keyExtractor: Function[V, KO], + joiner: ValueJoiner[V, VO, VR], + materialized: Materialized[K, VR, KeyValueStore[Bytes, Array[Byte]]] + ): KTable[K, VR] = + new KTable(inner.join(other.inner, keyExtractor.asJavaFunction, joiner, materialized)) + + /** + * Join records of this [[KTable]] with another [[KTable]]'s records using non-windowed inner join. Records from this + * table are joined according to the result of keyExtractor on the other KTable. + * + * @param other the other [[KTable]] to be joined with this [[KTable]], keyed on the value obtained from keyExtractor + * @param keyExtractor a function that extracts the foreign key from this table's value + * @param joiner a function that computes the join result for a pair of matching records + * @param named a [[Named]] config used to name the processor in the topology + * @param materialized a `Materialized` that describes how the `StateStore` for the resulting [[KTable]] + * should be materialized. 
+ * @return a [[KTable]] that contains join-records for each key and values computed by the given joiner, + * one for each matched record-pair with the same key + */ + @deprecated("Use join(KTable, Function, ValueJoiner, TableJoined, Materialized) instead", since = "3.1") + def join[VR, KO, VO]( + other: KTable[KO, VO], + keyExtractor: Function[V, KO], + joiner: ValueJoiner[V, VO, VR], + named: Named, + materialized: Materialized[K, VR, KeyValueStore[Bytes, Array[Byte]]] + ): KTable[K, VR] = + new KTable(inner.join(other.inner, keyExtractor.asJavaFunction, joiner, named, materialized)) + + /** + * Join records of this [[KTable]] with another [[KTable]]'s records using non-windowed inner join. Records from this + * table are joined according to the result of keyExtractor on the other KTable. + * + * @param other the other [[KTable]] to be joined with this [[KTable]], keyed on the value obtained from keyExtractor + * @param keyExtractor a function that extracts the foreign key from this table's value + * @param joiner a function that computes the join result for a pair of matching records + * @param tableJoined a [[TableJoined]] used to configure partitioners and names of internal topics and stores + * @param materialized a `Materialized` that describes how the `StateStore` for the resulting [[KTable]] + * should be materialized. + * @return a [[KTable]] that contains join-records for each key and values computed by the given joiner, + * one for each matched record-pair with the same key + */ + def join[VR, KO, VO]( + other: KTable[KO, VO], + keyExtractor: Function[V, KO], + joiner: ValueJoiner[V, VO, VR], + tableJoined: TableJoined[K, KO], + materialized: Materialized[K, VR, KeyValueStore[Bytes, Array[Byte]]] + ): KTable[K, VR] = + new KTable(inner.join(other.inner, keyExtractor.asJavaFunction, joiner, tableJoined, materialized)) + + /** + * Join records of this [[KTable]] with another [[KTable]]'s records using non-windowed left join. Records from this + * table are joined according to the result of keyExtractor on the other KTable. + * + * @param other the other [[KTable]] to be joined with this [[KTable]], keyed on the value obtained from keyExtractor + * @param keyExtractor a function that extracts the foreign key from this table's value + * @param joiner a function that computes the join result for a pair of matching records + * @param materialized a `Materialized` that describes how the `StateStore` for the resulting [[KTable]] + * should be materialized. + * @return a [[KTable]] that contains join-records for each key and values computed by the given joiner, + * one for each matched record-pair with the same key + */ + def leftJoin[VR, KO, VO]( + other: KTable[KO, VO], + keyExtractor: Function[V, KO], + joiner: ValueJoiner[V, VO, VR], + materialized: Materialized[K, VR, KeyValueStore[Bytes, Array[Byte]]] + ): KTable[K, VR] = + new KTable(inner.leftJoin(other.inner, keyExtractor.asJavaFunction, joiner, materialized)) + + /** + * Join records of this [[KTable]] with another [[KTable]]'s records using non-windowed left join. Records from this + * table are joined according to the result of keyExtractor on the other KTable. 
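+ *
+ * A sketch of a foreign-key left join (table names, types, and the store name are illustrative; implicit serdes
+ * are assumed to be in scope). Here `orders` maps an order id to a customer id, and `customers` maps a customer
+ * id to a customer name:
+ * {{{
+ * val joiner: ValueJoiner[String, String, String] =
+ *   (customerId, customerName) => customerId + ":" + Option(customerName).getOrElse("unknown")
+ * val enriched: KTable[String, String] = orders.leftJoin(
+ *   customers,
+ *   (customerId: String) => customerId,
+ *   joiner,
+ *   Materialized.as[String, String, ByteArrayKeyValueStore]("orders-with-customer-names")
+ * )
+ * }}}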
+ * + * @param other the other [[KTable]] to be joined with this [[KTable]], keyed on the value obtained from keyExtractor + * @param keyExtractor a function that extracts the foreign key from this table's value + * @param joiner a function that computes the join result for a pair of matching records + * @param named a [[Named]] config used to name the processor in the topology + * @param materialized a `Materialized` that describes how the `StateStore` for the resulting [[KTable]] + * should be materialized. + * @return a [[KTable]] that contains join-records for each key and values computed by the given joiner, + * one for each matched record-pair with the same key + */ + @deprecated("Use leftJoin(KTable, Function, ValueJoiner, TableJoined, Materialized) instead", since = "3.1") + def leftJoin[VR, KO, VO]( + other: KTable[KO, VO], + keyExtractor: Function[V, KO], + joiner: ValueJoiner[V, VO, VR], + named: Named, + materialized: Materialized[K, VR, KeyValueStore[Bytes, Array[Byte]]] + ): KTable[K, VR] = + new KTable(inner.leftJoin(other.inner, keyExtractor.asJavaFunction, joiner, named, materialized)) + + /** + * Join records of this [[KTable]] with another [[KTable]]'s records using non-windowed left join. Records from this + * table are joined according to the result of keyExtractor on the other KTable. + * + * @param other the other [[KTable]] to be joined with this [[KTable]], keyed on the value obtained from keyExtractor + * @param keyExtractor a function that extracts the foreign key from this table's value + * @param joiner a function that computes the join result for a pair of matching records + * @param tableJoined a [[TableJoined]] used to configure partitioners and names of internal topics and stores + * @param materialized a `Materialized` that describes how the `StateStore` for the resulting [[KTable]] + * should be materialized. + * @return a [[KTable]] that contains join-records for each key and values computed by the given joiner, + * one for each matched record-pair with the same key + */ + def leftJoin[VR, KO, VO]( + other: KTable[KO, VO], + keyExtractor: Function[V, KO], + joiner: ValueJoiner[V, VO, VR], + tableJoined: TableJoined[K, KO], + materialized: Materialized[K, VR, KeyValueStore[Bytes, Array[Byte]]] + ): KTable[K, VR] = + new KTable(inner.leftJoin(other.inner, keyExtractor.asJavaFunction, joiner, tableJoined, materialized)) + + /** + * Get the name of the local state store used that can be used to query this [[KTable]]. + * + * @return the underlying state store name, or `null` if this [[KTable]] cannot be queried. + */ + def queryableStoreName: String = + inner.queryableStoreName +} diff --git a/streams/streams-scala/src/main/scala/org/apache/kafka/streams/scala/kstream/Materialized.scala b/streams/streams-scala/src/main/scala/org/apache/kafka/streams/scala/kstream/Materialized.scala new file mode 100644 index 0000000..421ac5a --- /dev/null +++ b/streams/streams-scala/src/main/scala/org/apache/kafka/streams/scala/kstream/Materialized.scala @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.scala.kstream + +import org.apache.kafka.common.serialization.Serde +import org.apache.kafka.streams.kstream.{Materialized => MaterializedJ} +import org.apache.kafka.streams.processor.StateStore +import org.apache.kafka.streams.scala.{ByteArrayKeyValueStore, ByteArraySessionStore, ByteArrayWindowStore} +import org.apache.kafka.streams.state.{KeyValueBytesStoreSupplier, SessionBytesStoreSupplier, WindowBytesStoreSupplier} + +object Materialized { + + /** + * Materialize a [[StateStore]] with the provided key and value [[Serde]]s. + * An internal name will be used for the store. + * + * @tparam K key type + * @tparam V value type + * @tparam S store type + * @param keySerde the key [[Serde]] to use. + * @param valueSerde the value [[Serde]] to use. + * @return a new [[Materialized]] instance with the given key and value serdes + */ + def `with`[K, V, S <: StateStore](implicit keySerde: Serde[K], valueSerde: Serde[V]): MaterializedJ[K, V, S] = + MaterializedJ.`with`(keySerde, valueSerde) + + /** + * Materialize a [[StateStore]] with the given name. + * + * @tparam K key type of the store + * @tparam V value type of the store + * @tparam S type of the [[StateStore]] + * @param storeName the name of the underlying [[org.apache.kafka.streams.scala.kstream.KTable]] state store; + * valid characters are ASCII alphanumerics, '.', '_' and '-'. + * @param keySerde the key serde to use. + * @param valueSerde the value serde to use. + * @return a new [[Materialized]] instance with the given storeName + */ + def as[K, V, S <: StateStore]( + storeName: String + )(implicit keySerde: Serde[K], valueSerde: Serde[V]): MaterializedJ[K, V, S] = + MaterializedJ.as(storeName).withKeySerde(keySerde).withValueSerde(valueSerde) + + /** + * Materialize a [[org.apache.kafka.streams.state.WindowStore]] using the provided [[WindowBytesStoreSupplier]]. + * + * Important: Custom subclasses are allowed here, but they should respect the retention contract: + * Window stores are required to retain windows at least as long as (window size + window grace period). + * Stores constructed via [[org.apache.kafka.streams.state.Stores]] already satisfy this contract. + * + * @tparam K key type of the store + * @tparam V value type of the store + * @param supplier the [[WindowBytesStoreSupplier]] used to materialize the store + * @param keySerde the key serde to use. + * @param valueSerde the value serde to use. + * @return a new [[Materialized]] instance with the given supplier + */ + def as[K, V]( + supplier: WindowBytesStoreSupplier + )(implicit keySerde: Serde[K], valueSerde: Serde[V]): MaterializedJ[K, V, ByteArrayWindowStore] = + MaterializedJ.as(supplier).withKeySerde(keySerde).withValueSerde(valueSerde) + + /** + * Materialize a [[org.apache.kafka.streams.state.SessionStore]] using the provided [[SessionBytesStoreSupplier]]. + * + * Important: Custom subclasses are allowed here, but they should respect the retention contract: + * Session stores are required to retain windows at least as long as (session inactivity gap + session grace period). 
+ * Stores constructed via [[org.apache.kafka.streams.state.Stores]] already satisfy this contract. + * + * @tparam K key type of the store + * @tparam V value type of the store + * @param supplier the [[SessionBytesStoreSupplier]] used to materialize the store + * @param keySerde the key serde to use. + * @param valueSerde the value serde to use. + * @return a new [[Materialized]] instance with the given supplier + */ + def as[K, V]( + supplier: SessionBytesStoreSupplier + )(implicit keySerde: Serde[K], valueSerde: Serde[V]): MaterializedJ[K, V, ByteArraySessionStore] = + MaterializedJ.as(supplier).withKeySerde(keySerde).withValueSerde(valueSerde) + + /** + * Materialize a [[org.apache.kafka.streams.state.KeyValueStore]] using the provided [[KeyValueBytesStoreSupplier]]. + * + * @tparam K key type of the store + * @tparam V value type of the store + * @param supplier the [[KeyValueBytesStoreSupplier]] used to materialize the store + * @param keySerde the key serde to use. + * @param valueSerde the value serde to use. + * @return a new [[Materialized]] instance with the given supplier + */ + def as[K, V]( + supplier: KeyValueBytesStoreSupplier + )(implicit keySerde: Serde[K], valueSerde: Serde[V]): MaterializedJ[K, V, ByteArrayKeyValueStore] = + MaterializedJ.as(supplier).withKeySerde(keySerde).withValueSerde(valueSerde) +} diff --git a/streams/streams-scala/src/main/scala/org/apache/kafka/streams/scala/kstream/Produced.scala b/streams/streams-scala/src/main/scala/org/apache/kafka/streams/scala/kstream/Produced.scala new file mode 100644 index 0000000..48f9178 --- /dev/null +++ b/streams/streams-scala/src/main/scala/org/apache/kafka/streams/scala/kstream/Produced.scala @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.scala.kstream + +import org.apache.kafka.common.serialization.Serde +import org.apache.kafka.streams.kstream.{Produced => ProducedJ} +import org.apache.kafka.streams.processor.StreamPartitioner + +object Produced { + + /** + * Create a Produced instance with provided keySerde and valueSerde. + * + * @tparam K key type + * @tparam V value type + * @param keySerde Serde to use for serializing the key + * @param valueSerde Serde to use for serializing the value + * @return A new [[Produced]] instance configured with keySerde and valueSerde + * @see KStream#through(String, Produced) + * @see KStream#to(String, Produced) + */ + def `with`[K, V](implicit keySerde: Serde[K], valueSerde: Serde[V]): ProducedJ[K, V] = + ProducedJ.`with`(keySerde, valueSerde) + + /** + * Create a Produced instance with provided keySerde, valueSerde, and partitioner. 
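+ *
+ * A possible usage sketch (the stream, topic name, and partitioner are hypothetical; implicit serdes are assumed
+ * to be in scope):
+ * {{{
+ * val byKeyHash: StreamPartitioner[String, Long] =
+ *   (_, key, _, numPartitions) => Int.box(math.abs(key.hashCode % numPartitions))
+ * wordCounts.to("word-count-output")(Produced.`with`(byKeyHash))
+ * }}}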
+ * + * @tparam K key type + * @tparam V value type + * @param partitioner the function used to determine how records are distributed among partitions of the topic, + * if not specified and `keySerde` provides a + * [[org.apache.kafka.streams.kstream.internals.WindowedSerializer]] for the key + * [[org.apache.kafka.streams.kstream.internals.WindowedStreamPartitioner]] will be + * used—otherwise [[org.apache.kafka.clients.producer.internals.DefaultPartitioner]] + * will be used + * @param keySerde Serde to use for serializing the key + * @param valueSerde Serde to use for serializing the value + * @return A new [[Produced]] instance configured with keySerde, valueSerde, and partitioner + * @see KStream#through(String, Produced) + * @see KStream#to(String, Produced) + */ + def `with`[K, V]( + partitioner: StreamPartitioner[K, V] + )(implicit keySerde: Serde[K], valueSerde: Serde[V]): ProducedJ[K, V] = + ProducedJ.`with`(keySerde, valueSerde, partitioner) +} diff --git a/streams/streams-scala/src/main/scala/org/apache/kafka/streams/scala/kstream/Repartitioned.scala b/streams/streams-scala/src/main/scala/org/apache/kafka/streams/scala/kstream/Repartitioned.scala new file mode 100644 index 0000000..5f33efa --- /dev/null +++ b/streams/streams-scala/src/main/scala/org/apache/kafka/streams/scala/kstream/Repartitioned.scala @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.scala.kstream + +import org.apache.kafka.common.serialization.Serde +import org.apache.kafka.streams.kstream.{Repartitioned => RepartitionedJ} +import org.apache.kafka.streams.processor.StreamPartitioner + +object Repartitioned { + + /** + * Create a Repartitioned instance with provided keySerde and valueSerde. + * + * @tparam K key type + * @tparam V value type + * @param keySerde Serde to use for serializing the key + * @param valueSerde Serde to use for serializing the value + * @return A new [[Repartitioned]] instance configured with keySerde and valueSerde + * @see KStream#repartition(Repartitioned) + */ + def `with`[K, V](implicit keySerde: Serde[K], valueSerde: Serde[V]): RepartitionedJ[K, V] = + RepartitionedJ.`with`(keySerde, valueSerde) + + /** + * Create a Repartitioned instance with provided keySerde, valueSerde, and name used as part of the repartition topic. + * + * @tparam K key type + * @tparam V value type + * @param name the name used as a processor named and part of the repartition topic name. 
+ * @param keySerde Serde to use for serializing the key + * @param valueSerde Serde to use for serializing the value + * @return A new [[Repartitioned]] instance configured with keySerde, valueSerde, and processor and repartition topic name + * @see KStream#repartition(Repartitioned) + */ + def `with`[K, V](name: String)(implicit keySerde: Serde[K], valueSerde: Serde[V]): RepartitionedJ[K, V] = + RepartitionedJ.`as`(name).withKeySerde(keySerde).withValueSerde(valueSerde) + + /** + * Create a Repartitioned instance with provided keySerde, valueSerde, and partitioner. + * + * @tparam K key type + * @tparam V value type + * @param partitioner the function used to determine how records are distributed among partitions of the topic, + * if not specified and `keySerde` provides a + * [[org.apache.kafka.streams.kstream.internals.WindowedSerializer]] for the key + * [[org.apache.kafka.streams.kstream.internals.WindowedStreamPartitioner]] will be + * used—otherwise [[org.apache.kafka.clients.producer.internals.DefaultPartitioner]] + * will be used + * @param keySerde Serde to use for serializing the key + * @param valueSerde Serde to use for serializing the value + * @return A new [[Repartitioned]] instance configured with keySerde, valueSerde, and partitioner + * @see KStream#repartition(Repartitioned) + */ + def `with`[K, V]( + partitioner: StreamPartitioner[K, V] + )(implicit keySerde: Serde[K], valueSerde: Serde[V]): RepartitionedJ[K, V] = + RepartitionedJ.`streamPartitioner`(partitioner).withKeySerde(keySerde).withValueSerde(valueSerde) + + /** + * Create a Repartitioned instance with provided keySerde, valueSerde, and number of partitions for repartition topic. + * + * @tparam K key type + * @tparam V value type + * @param numberOfPartitions number of partitions used when creating repartition topic + * @param keySerde Serde to use for serializing the key + * @param valueSerde Serde to use for serializing the value + * @return A new [[Repartitioned]] instance configured with keySerde, valueSerde, and number of partitions + * @see KStream#repartition(Repartitioned) + */ + def `with`[K, V](numberOfPartitions: Int)(implicit keySerde: Serde[K], valueSerde: Serde[V]): RepartitionedJ[K, V] = + RepartitionedJ.`numberOfPartitions`(numberOfPartitions).withKeySerde(keySerde).withValueSerde(valueSerde) + +} diff --git a/streams/streams-scala/src/main/scala/org/apache/kafka/streams/scala/kstream/SessionWindowedCogroupedKStream.scala b/streams/streams-scala/src/main/scala/org/apache/kafka/streams/scala/kstream/SessionWindowedCogroupedKStream.scala new file mode 100644 index 0000000..1b20179 --- /dev/null +++ b/streams/streams-scala/src/main/scala/org/apache/kafka/streams/scala/kstream/SessionWindowedCogroupedKStream.scala @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.scala +package kstream + +import org.apache.kafka.streams.kstream.{SessionWindowedCogroupedKStream => SessionWindowedCogroupedKStreamJ, Windowed} +import org.apache.kafka.streams.scala.FunctionsCompatConversions.{InitializerFromFunction, MergerFromFunction} + +/** + * Wraps the Java class SessionWindowedCogroupedKStream and delegates method calls to the underlying Java object. + * + * @tparam K Type of keys + * @tparam V Type of values + * @param inner The underlying Java abstraction for SessionWindowedCogroupedKStream + * @see `org.apache.kafka.streams.kstream.SessionWindowedCogroupedKStream` + */ +class SessionWindowedCogroupedKStream[K, V](val inner: SessionWindowedCogroupedKStreamJ[K, V]) { + + /** + * Aggregate the values of records in this stream by the grouped key and defined `SessionWindows`. + * + * @param initializer the initializer function + * @param merger a function that combines two aggregation results. + * @param materialized an instance of `Materialized` used to materialize a state store. + * @return a windowed [[KTable]] that contains "update" records with unmodified keys, and values that represent + * the latest (rolling) aggregate for each key within a window + * @see `org.apache.kafka.streams.kstream.SessionWindowedCogroupedKStream#aggregate` + */ + def aggregate(initializer: => V, merger: (K, V, V) => V)(implicit + materialized: Materialized[K, V, ByteArraySessionStore] + ): KTable[Windowed[K], V] = + new KTable(inner.aggregate((() => initializer).asInitializer, merger.asMerger, materialized)) + + /** + * Aggregate the values of records in this stream by the grouped key and defined `SessionWindows`. + * + * @param initializer the initializer function + * @param merger a function that combines two aggregation results. + * @param named a [[Named]] config used to name the processor in the topology + * @param materialized an instance of `Materialized` used to materialize a state store. + * @return a windowed [[KTable]] that contains "update" records with unmodified keys, and values that represent + * the latest (rolling) aggregate for each key within a window + * @see `org.apache.kafka.streams.kstream.SessionWindowedCogroupedKStream#aggregate` + */ + def aggregate(initializer: => V, merger: (K, V, V) => V, named: Named)(implicit + materialized: Materialized[K, V, ByteArraySessionStore] + ): KTable[Windowed[K], V] = + new KTable(inner.aggregate((() => initializer).asInitializer, merger.asMerger, named, materialized)) + +} diff --git a/streams/streams-scala/src/main/scala/org/apache/kafka/streams/scala/kstream/SessionWindowedKStream.scala b/streams/streams-scala/src/main/scala/org/apache/kafka/streams/scala/kstream/SessionWindowedKStream.scala new file mode 100644 index 0000000..3d6e157 --- /dev/null +++ b/streams/streams-scala/src/main/scala/org/apache/kafka/streams/scala/kstream/SessionWindowedKStream.scala @@ -0,0 +1,148 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.scala +package kstream + +import org.apache.kafka.streams.kstream.internals.KTableImpl +import org.apache.kafka.streams.scala.serialization.Serdes +import org.apache.kafka.streams.kstream.{KTable => KTableJ, SessionWindowedKStream => SessionWindowedKStreamJ, Windowed} +import org.apache.kafka.streams.scala.FunctionsCompatConversions.{ + AggregatorFromFunction, + InitializerFromFunction, + MergerFromFunction, + ReducerFromFunction, + ValueMapperFromFunction +} + +/** + * Wraps the Java class SessionWindowedKStream and delegates method calls to the underlying Java object. + * + * @tparam K Type of keys + * @tparam V Type of values + * @param inner The underlying Java abstraction for SessionWindowedKStream + * @see `org.apache.kafka.streams.kstream.SessionWindowedKStream` + */ +class SessionWindowedKStream[K, V](val inner: SessionWindowedKStreamJ[K, V]) { + + /** + * Aggregate the values of records in this stream by the grouped key and defined `SessionWindows`. + * + * @param initializer the initializer function + * @param aggregator the aggregator function + * @param merger the merger function + * @param materialized an instance of `Materialized` used to materialize a state store. + * @return a windowed [[KTable]] that contains "update" records with unmodified keys, and values that represent + * the latest (rolling) aggregate for each key within a window + * @see `org.apache.kafka.streams.kstream.SessionWindowedKStream#aggregate` + */ + def aggregate[VR](initializer: => VR)(aggregator: (K, V, VR) => VR, merger: (K, VR, VR) => VR)(implicit + materialized: Materialized[K, VR, ByteArraySessionStore] + ): KTable[Windowed[K], VR] = + new KTable( + inner.aggregate((() => initializer).asInitializer, aggregator.asAggregator, merger.asMerger, materialized) + ) + + /** + * Aggregate the values of records in this stream by the grouped key and defined `SessionWindows`. + * + * @param initializer the initializer function + * @param aggregator the aggregator function + * @param merger the merger function + * @param named a [[Named]] config used to name the processor in the topology + * @param materialized an instance of `Materialized` used to materialize a state store. + * @return a windowed [[KTable]] that contains "update" records with unmodified keys, and values that represent + * the latest (rolling) aggregate for each key within a window + * @see `org.apache.kafka.streams.kstream.SessionWindowedKStream#aggregate` + */ + def aggregate[VR](initializer: => VR, named: Named)(aggregator: (K, V, VR) => VR, merger: (K, VR, VR) => VR)(implicit + materialized: Materialized[K, VR, ByteArraySessionStore] + ): KTable[Windowed[K], VR] = + new KTable( + inner.aggregate((() => initializer).asInitializer, aggregator.asAggregator, merger.asMerger, named, materialized) + ) + + /** + * Count the number of records in this stream by the grouped key into `SessionWindows`. + * + * @param materialized an instance of `Materialized` used to materialize a state store. 
+ * @return a windowed [[KTable]] that contains "update" records with unmodified keys and `Long` values + * that represent the latest (rolling) count (i.e., number of records) for each key within a window + * @see `org.apache.kafka.streams.kstream.SessionWindowedKStream#count` + */ + def count()(implicit materialized: Materialized[K, Long, ByteArraySessionStore]): KTable[Windowed[K], Long] = { + val javaCountTable: KTableJ[Windowed[K], java.lang.Long] = + inner.count(materialized.asInstanceOf[Materialized[K, java.lang.Long, ByteArraySessionStore]]) + val tableImpl = javaCountTable.asInstanceOf[KTableImpl[Windowed[K], ByteArraySessionStore, java.lang.Long]] + new KTable( + javaCountTable.mapValues[Long]( + ((l: java.lang.Long) => Long2long(l)).asValueMapper, + Materialized.`with`[Windowed[K], Long, ByteArrayKeyValueStore](tableImpl.keySerde(), Serdes.longSerde) + ) + ) + } + + /** + * Count the number of records in this stream by the grouped key into `SessionWindows`. + * + * @param named a [[Named]] config used to name the processor in the topology + * @param materialized an instance of `Materialized` used to materialize a state store. + * @return a windowed [[KTable]] that contains "update" records with unmodified keys and `Long` values + * that represent the latest (rolling) count (i.e., number of records) for each key within a window + * @see `org.apache.kafka.streams.kstream.SessionWindowedKStream#count` + */ + def count( + named: Named + )(implicit materialized: Materialized[K, Long, ByteArraySessionStore]): KTable[Windowed[K], Long] = { + val javaCountTable: KTableJ[Windowed[K], java.lang.Long] = + inner.count(named, materialized.asInstanceOf[Materialized[K, java.lang.Long, ByteArraySessionStore]]) + val tableImpl = javaCountTable.asInstanceOf[KTableImpl[Windowed[K], ByteArraySessionStore, java.lang.Long]] + new KTable( + javaCountTable.mapValues[Long]( + ((l: java.lang.Long) => Long2long(l)).asValueMapper, + Materialized.`with`[Windowed[K], Long, ByteArrayKeyValueStore](tableImpl.keySerde(), Serdes.longSerde) + ) + ) + } + + /** + * Combine values of this stream by the grouped key into `SessionWindows`. + * + * @param reducer a reducer function that computes a new aggregate result. + * @param materialized an instance of `Materialized` used to materialize a state store. + * @return a windowed [[KTable]] that contains "update" records with unmodified keys, and values that represent + * the latest (rolling) aggregate for each key within a window + * @see `org.apache.kafka.streams.kstream.SessionWindowedKStream#reduce` + */ + def reduce(reducer: (V, V) => V)(implicit + materialized: Materialized[K, V, ByteArraySessionStore] + ): KTable[Windowed[K], V] = + new KTable(inner.reduce(reducer.asReducer, materialized)) + + /** + * Combine values of this stream by the grouped key into `SessionWindows`. + * + * @param reducer a reducer function that computes a new aggregate result. + * @param materialized an instance of `Materialized` used to materialize a state store. 
+ * @return a windowed [[KTable]] that contains "update" records with unmodified keys, and values that represent + * the latest (rolling) aggregate for each key within a window + * @see `org.apache.kafka.streams.kstream.SessionWindowedKStream#reduce` + */ + def reduce(reducer: (V, V) => V, named: Named)(implicit + materialized: Materialized[K, V, ByteArraySessionStore] + ): KTable[Windowed[K], V] = + new KTable(inner.reduce(reducer.asReducer, named, materialized)) +} diff --git a/streams/streams-scala/src/main/scala/org/apache/kafka/streams/scala/kstream/StreamJoined.scala b/streams/streams-scala/src/main/scala/org/apache/kafka/streams/scala/kstream/StreamJoined.scala new file mode 100644 index 0000000..9caad63 --- /dev/null +++ b/streams/streams-scala/src/main/scala/org/apache/kafka/streams/scala/kstream/StreamJoined.scala @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.scala.kstream + +import org.apache.kafka.common.serialization.Serde +import org.apache.kafka.streams.kstream.{StreamJoined => StreamJoinedJ} +import org.apache.kafka.streams.state.WindowBytesStoreSupplier + +object StreamJoined { + + /** + * Create an instance of [[StreamJoined]] with key, value, and otherValue [[Serde]] + * instances. + * `null` values are accepted and will be replaced by the default serdes as defined in config. + * + * @tparam K key type + * @tparam V value type + * @tparam VO other value type + * @param keySerde the key serde to use. + * @param valueSerde the value serde to use. + * @param otherValueSerde the otherValue serde to use. If `null` the default value serde from config will be used + * @return new [[StreamJoined]] instance with the provided serdes + */ + def `with`[K, V, VO](implicit + keySerde: Serde[K], + valueSerde: Serde[V], + otherValueSerde: Serde[VO] + ): StreamJoinedJ[K, V, VO] = + StreamJoinedJ.`with`(keySerde, valueSerde, otherValueSerde) + + /** + * Create an instance of [[StreamJoined]] with store suppliers for the calling stream + * and the other stream. Also adds the key, value, and otherValue [[Serde]] + * instances. + * `null` values are accepted and will be replaced by the default serdes as defined in config. + * + * @tparam K key type + * @tparam V value type + * @tparam VO other value type + * @param supplier store supplier to use + * @param otherSupplier other store supplier to use + * @param keySerde the key serde to use. + * @param valueSerde the value serde to use. + * @param otherValueSerde the otherValue serde to use. 
If `null` the default value serde from config will be used + * @return new [[StreamJoined]] instance with the provided store suppliers and serdes + */ + def `with`[K, V, VO]( + supplier: WindowBytesStoreSupplier, + otherSupplier: WindowBytesStoreSupplier + )(implicit keySerde: Serde[K], valueSerde: Serde[V], otherValueSerde: Serde[VO]): StreamJoinedJ[K, V, VO] = + StreamJoinedJ + .`with`(supplier, otherSupplier) + .withKeySerde(keySerde) + .withValueSerde(valueSerde) + .withOtherValueSerde(otherValueSerde) + + /** + * Create an instance of [[StreamJoined]] with the name used for naming + * the state stores involved in the join. Also adds the key, value, and otherValue [[Serde]] + * instances. + * `null` values are accepted and will be replaced by the default serdes as defined in config. + * + * @tparam K key type + * @tparam V value type + * @tparam VO other value type + * @param storeName the name to use as a base name for the state stores of the join + * @param keySerde the key serde to use. + * @param valueSerde the value serde to use. + * @param otherValueSerde the otherValue serde to use. If `null` the default value serde from config will be used + * @return new [[StreamJoined]] instance with the provided store suppliers and serdes + */ + def as[K, V, VO]( + storeName: String + )(implicit keySerde: Serde[K], valueSerde: Serde[V], otherValueSerde: Serde[VO]): StreamJoinedJ[K, V, VO] = + StreamJoinedJ.as(storeName).withKeySerde(keySerde).withValueSerde(valueSerde).withOtherValueSerde(otherValueSerde) + +} diff --git a/streams/streams-scala/src/main/scala/org/apache/kafka/streams/scala/kstream/TimeWindowedCogroupedKStream.scala b/streams/streams-scala/src/main/scala/org/apache/kafka/streams/scala/kstream/TimeWindowedCogroupedKStream.scala new file mode 100644 index 0000000..ad24228 --- /dev/null +++ b/streams/streams-scala/src/main/scala/org/apache/kafka/streams/scala/kstream/TimeWindowedCogroupedKStream.scala @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.scala +package kstream + +import org.apache.kafka.streams.kstream.{TimeWindowedCogroupedKStream => TimeWindowedCogroupedKStreamJ, Windowed} +import org.apache.kafka.streams.scala.FunctionsCompatConversions.InitializerFromFunction + +/** + * Wraps the Java class TimeWindowedCogroupedKStream and delegates method calls to the underlying Java object. 
+ * + * @tparam K Type of keys + * @tparam V Type of values + * @param inner The underlying Java abstraction for TimeWindowedCogroupedKStream + * @see `org.apache.kafka.streams.kstream.TimeWindowedCogroupedKStream` + */ +class TimeWindowedCogroupedKStream[K, V](val inner: TimeWindowedCogroupedKStreamJ[K, V]) { + + /** + * Aggregate the values of records in these streams by the grouped key and defined window. + * + * @param initializer an initializer function that computes an initial intermediate aggregation result + * @param materialized an instance of `Materialized` used to materialize a state store. + * @return a [[KTable]] that contains "update" records with unmodified keys, and values that represent the latest + * (rolling) aggregate for each key + * @see `org.apache.kafka.streams.kstream.TimeWindowedCogroupedKStream#aggregate` + */ + def aggregate(initializer: => V)(implicit + materialized: Materialized[K, V, ByteArrayWindowStore] + ): KTable[Windowed[K], V] = + new KTable(inner.aggregate((() => initializer).asInitializer, materialized)) + + /** + * Aggregate the values of records in these streams by the grouped key and defined window. + * + * @param initializer an initializer function that computes an initial intermediate aggregation result + * @param named a [[Named]] config used to name the processor in the topology + * @param materialized an instance of `Materialized` used to materialize a state store. + * @return a [[KTable]] that contains "update" records with unmodified keys, and values that represent the latest + * (rolling) aggregate for each key + * @see `org.apache.kafka.streams.kstream.TimeWindowedCogroupedKStream#aggregate` + */ + def aggregate(initializer: => V, named: Named)(implicit + materialized: Materialized[K, V, ByteArrayWindowStore] + ): KTable[Windowed[K], V] = + new KTable(inner.aggregate((() => initializer).asInitializer, named, materialized)) + +} diff --git a/streams/streams-scala/src/main/scala/org/apache/kafka/streams/scala/kstream/TimeWindowedKStream.scala b/streams/streams-scala/src/main/scala/org/apache/kafka/streams/scala/kstream/TimeWindowedKStream.scala new file mode 100644 index 0000000..4fcf227 --- /dev/null +++ b/streams/streams-scala/src/main/scala/org/apache/kafka/streams/scala/kstream/TimeWindowedKStream.scala @@ -0,0 +1,142 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.scala +package kstream + +import org.apache.kafka.streams.kstream.internals.KTableImpl +import org.apache.kafka.streams.scala.serialization.Serdes +import org.apache.kafka.streams.kstream.{KTable => KTableJ, TimeWindowedKStream => TimeWindowedKStreamJ, Windowed} +import org.apache.kafka.streams.scala.FunctionsCompatConversions.{ + AggregatorFromFunction, + InitializerFromFunction, + ReducerFromFunction, + ValueMapperFromFunction +} + +/** + * Wraps the Java class TimeWindowedKStream and delegates method calls to the underlying Java object. + * + * @tparam K Type of keys + * @tparam V Type of values + * @param inner The underlying Java abstraction for TimeWindowedKStream + * @see `org.apache.kafka.streams.kstream.TimeWindowedKStream` + */ +class TimeWindowedKStream[K, V](val inner: TimeWindowedKStreamJ[K, V]) { + + /** + * Aggregate the values of records in this stream by the grouped key. + * + * @param initializer an initializer function that computes an initial intermediate aggregation result + * @param aggregator an aggregator function that computes a new aggregate result + * @param materialized an instance of `Materialized` used to materialize a state store. + * @return a [[KTable]] that contains "update" records with unmodified keys, and values that represent the + * latest (rolling) aggregate for each key + * @see `org.apache.kafka.streams.kstream.TimeWindowedKStream#aggregate` + */ + def aggregate[VR](initializer: => VR)(aggregator: (K, V, VR) => VR)(implicit + materialized: Materialized[K, VR, ByteArrayWindowStore] + ): KTable[Windowed[K], VR] = + new KTable(inner.aggregate((() => initializer).asInitializer, aggregator.asAggregator, materialized)) + + /** + * Aggregate the values of records in this stream by the grouped key. + * + * @param initializer an initializer function that computes an initial intermediate aggregation result + * @param named a [[Named]] config used to name the processor in the topology + * @param aggregator an aggregator function that computes a new aggregate result + * @param materialized an instance of `Materialized` used to materialize a state store. + * @return a [[KTable]] that contains "update" records with unmodified keys, and values that represent the + * latest (rolling) aggregate for each key + * @see `org.apache.kafka.streams.kstream.TimeWindowedKStream#aggregate` + */ + def aggregate[VR](initializer: => VR, named: Named)(aggregator: (K, V, VR) => VR)(implicit + materialized: Materialized[K, VR, ByteArrayWindowStore] + ): KTable[Windowed[K], VR] = + new KTable(inner.aggregate((() => initializer).asInitializer, aggregator.asAggregator, named, materialized)) + + /** + * Count the number of records in this stream by the grouped key and the defined windows. + * + * @param materialized an instance of `Materialized` used to materialize a state store. 
+ * @return a [[KTable]] that contains "update" records with unmodified keys and `Long` values that + * represent the latest (rolling) count (i.e., number of records) for each key + * @see `org.apache.kafka.streams.kstream.TimeWindowedKStream#count` + */ + def count()(implicit materialized: Materialized[K, Long, ByteArrayWindowStore]): KTable[Windowed[K], Long] = { + val javaCountTable: KTableJ[Windowed[K], java.lang.Long] = + inner.count(materialized.asInstanceOf[Materialized[K, java.lang.Long, ByteArrayWindowStore]]) + val tableImpl = javaCountTable.asInstanceOf[KTableImpl[Windowed[K], ByteArrayWindowStore, java.lang.Long]] + new KTable( + javaCountTable.mapValues[Long]( + ((l: java.lang.Long) => Long2long(l)).asValueMapper, + Materialized.`with`[Windowed[K], Long, ByteArrayKeyValueStore](tableImpl.keySerde(), Serdes.longSerde) + ) + ) + } + + /** + * Count the number of records in this stream by the grouped key and the defined windows. + * + * @param named a [[Named]] config used to name the processor in the topology + * @param materialized an instance of `Materialized` used to materialize a state store. + * @return a [[KTable]] that contains "update" records with unmodified keys and `Long` values that + * represent the latest (rolling) count (i.e., number of records) for each key + * @see `org.apache.kafka.streams.kstream.TimeWindowedKStream#count` + */ + def count( + named: Named + )(implicit materialized: Materialized[K, Long, ByteArrayWindowStore]): KTable[Windowed[K], Long] = { + val javaCountTable: KTableJ[Windowed[K], java.lang.Long] = + inner.count(named, materialized.asInstanceOf[Materialized[K, java.lang.Long, ByteArrayWindowStore]]) + val tableImpl = javaCountTable.asInstanceOf[KTableImpl[Windowed[K], ByteArrayWindowStore, java.lang.Long]] + new KTable( + javaCountTable.mapValues[Long]( + ((l: java.lang.Long) => Long2long(l)).asValueMapper, + Materialized.`with`[Windowed[K], Long, ByteArrayKeyValueStore](tableImpl.keySerde(), Serdes.longSerde) + ) + ) + } + + /** + * Combine the values of records in this stream by the grouped key. + * + * @param reducer a function that computes a new aggregate result + * @param materialized an instance of `Materialized` used to materialize a state store. + * @return a [[KTable]] that contains "update" records with unmodified keys, and values that represent the + * latest (rolling) aggregate for each key + * @see `org.apache.kafka.streams.kstream.TimeWindowedKStream#reduce` + */ + def reduce(reducer: (V, V) => V)(implicit + materialized: Materialized[K, V, ByteArrayWindowStore] + ): KTable[Windowed[K], V] = + new KTable(inner.reduce(reducer.asReducer, materialized)) + + /** + * Combine the values of records in this stream by the grouped key. + * + * @param reducer a function that computes a new aggregate result + * @param named a [[Named]] config used to name the processor in the topology + * @param materialized an instance of `Materialized` used to materialize a state store. 
+ * @return a [[KTable]] that contains "update" records with unmodified keys, and values that represent the + * latest (rolling) aggregate for each key + * @see `org.apache.kafka.streams.kstream.TimeWindowedKStream#reduce` + */ + def reduce(reducer: (V, V) => V, named: Named)(implicit + materialized: Materialized[K, V, ByteArrayWindowStore] + ): KTable[Windowed[K], V] = + new KTable(inner.reduce(reducer.asReducer, named, materialized)) +} diff --git a/streams/streams-scala/src/main/scala/org/apache/kafka/streams/scala/kstream/package.scala b/streams/streams-scala/src/main/scala/org/apache/kafka/streams/scala/kstream/package.scala new file mode 100644 index 0000000..7365c68 --- /dev/null +++ b/streams/streams-scala/src/main/scala/org/apache/kafka/streams/scala/kstream/package.scala @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.scala + +import org.apache.kafka.streams.processor.StateStore + +package object kstream { + type Materialized[K, V, S <: StateStore] = org.apache.kafka.streams.kstream.Materialized[K, V, S] + type Grouped[K, V] = org.apache.kafka.streams.kstream.Grouped[K, V] + type Consumed[K, V] = org.apache.kafka.streams.kstream.Consumed[K, V] + type Produced[K, V] = org.apache.kafka.streams.kstream.Produced[K, V] + type Repartitioned[K, V] = org.apache.kafka.streams.kstream.Repartitioned[K, V] + type Joined[K, V, VO] = org.apache.kafka.streams.kstream.Joined[K, V, VO] + type StreamJoined[K, V, VO] = org.apache.kafka.streams.kstream.StreamJoined[K, V, VO] + type Named = org.apache.kafka.streams.kstream.Named + type Branched[K, V] = org.apache.kafka.streams.kstream.Branched[K, V] +} diff --git a/streams/streams-scala/src/main/scala/org/apache/kafka/streams/scala/package.scala b/streams/streams-scala/src/main/scala/org/apache/kafka/streams/scala/package.scala new file mode 100644 index 0000000..6a3906d --- /dev/null +++ b/streams/streams-scala/src/main/scala/org/apache/kafka/streams/scala/package.scala @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams + +import org.apache.kafka.streams.state.{KeyValueStore, SessionStore, WindowStore} +import org.apache.kafka.common.utils.Bytes + +package object scala { + type ByteArrayKeyValueStore = KeyValueStore[Bytes, Array[Byte]] + type ByteArraySessionStore = SessionStore[Bytes, Array[Byte]] + type ByteArrayWindowStore = WindowStore[Bytes, Array[Byte]] +} diff --git a/streams/streams-scala/src/main/scala/org/apache/kafka/streams/scala/serialization/Serdes.scala b/streams/streams-scala/src/main/scala/org/apache/kafka/streams/scala/serialization/Serdes.scala new file mode 100644 index 0000000..0c72358 --- /dev/null +++ b/streams/streams-scala/src/main/scala/org/apache/kafka/streams/scala/serialization/Serdes.scala @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.scala.serialization + +import java.nio.ByteBuffer +import java.util +import java.util.UUID + +import org.apache.kafka.common.serialization.{Deserializer, Serde, Serializer, Serdes => JSerdes} +import org.apache.kafka.streams.kstream.WindowedSerdes + +object Serdes extends LowPrioritySerdes { + implicit def stringSerde: Serde[String] = JSerdes.String() + implicit def longSerde: Serde[Long] = JSerdes.Long().asInstanceOf[Serde[Long]] + implicit def javaLongSerde: Serde[java.lang.Long] = JSerdes.Long() + implicit def byteArraySerde: Serde[Array[Byte]] = JSerdes.ByteArray() + implicit def bytesSerde: Serde[org.apache.kafka.common.utils.Bytes] = JSerdes.Bytes() + implicit def byteBufferSerde: Serde[ByteBuffer] = JSerdes.ByteBuffer() + implicit def shortSerde: Serde[Short] = JSerdes.Short().asInstanceOf[Serde[Short]] + implicit def javaShortSerde: Serde[java.lang.Short] = JSerdes.Short() + implicit def floatSerde: Serde[Float] = JSerdes.Float().asInstanceOf[Serde[Float]] + implicit def javaFloatSerde: Serde[java.lang.Float] = JSerdes.Float() + implicit def doubleSerde: Serde[Double] = JSerdes.Double().asInstanceOf[Serde[Double]] + implicit def javaDoubleSerde: Serde[java.lang.Double] = JSerdes.Double() + implicit def intSerde: Serde[Int] = JSerdes.Integer().asInstanceOf[Serde[Int]] + implicit def javaIntegerSerde: Serde[java.lang.Integer] = JSerdes.Integer() + implicit def uuidSerde: Serde[UUID] = JSerdes.UUID() + + implicit def sessionWindowedSerde[T](implicit tSerde: Serde[T]): WindowedSerdes.SessionWindowedSerde[T] = + new WindowedSerdes.SessionWindowedSerde[T](tSerde) + + def fromFn[T >: Null](serializer: T => Array[Byte], deserializer: Array[Byte] => Option[T]): Serde[T] = + JSerdes.serdeFrom( + new Serializer[T] { + override def serialize(topic: String, data: T): Array[Byte] = serializer(data) + 
override def configure(configs: util.Map[String, _], isKey: Boolean): Unit = () + override def close(): Unit = () + }, + new Deserializer[T] { + override def deserialize(topic: String, data: Array[Byte]): T = deserializer(data).orNull + override def configure(configs: util.Map[String, _], isKey: Boolean): Unit = () + override def close(): Unit = () + } + ) + + def fromFn[T >: Null]( + serializer: (String, T) => Array[Byte], + deserializer: (String, Array[Byte]) => Option[T] + ): Serde[T] = + JSerdes.serdeFrom( + new Serializer[T] { + override def serialize(topic: String, data: T): Array[Byte] = serializer(topic, data) + override def configure(configs: util.Map[String, _], isKey: Boolean): Unit = () + override def close(): Unit = () + }, + new Deserializer[T] { + override def deserialize(topic: String, data: Array[Byte]): T = deserializer(topic, data).orNull + override def configure(configs: util.Map[String, _], isKey: Boolean): Unit = () + override def close(): Unit = () + } + ) +} + +trait LowPrioritySerdes { + + implicit val nullSerde: Serde[Null] = + Serdes.fromFn[Null]( + { _: Null => + null + }, + { _: Array[Byte] => + None + } + ) +} diff --git a/streams/streams-scala/src/test/resources/log4j.properties b/streams/streams-scala/src/test/resources/log4j.properties new file mode 100644 index 0000000..93ffc16 --- /dev/null +++ b/streams/streams-scala/src/test/resources/log4j.properties @@ -0,0 +1,34 @@ +# Copyright (C) 2018 Lightbend Inc. +# Copyright (C) 2017-2018 Alexis Seigneurin. +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Set root logger level to DEBUG and its only appender to A1. +log4j.rootLogger=INFO, R + +# A1 is set to be a ConsoleAppender. +log4j.appender.A1=org.apache.log4j.ConsoleAppender + +log4j.appender.R=org.apache.log4j.RollingFileAppender +log4j.appender.R.File=logs/kafka-streams-scala.log + +log4j.appender.R.MaxFileSize=100KB +# Keep one backup file +log4j.appender.R.MaxBackupIndex=1 + +# A1 uses PatternLayout. +log4j.appender.R.layout=org.apache.log4j.PatternLayout +log4j.appender.R.layout.ConversionPattern=%-4r [%t] %-5p %c %x - %m%n diff --git a/streams/streams-scala/src/test/scala/org/apache/kafka/streams/scala/StreamToTableJoinScalaIntegrationTestImplicitSerdes.scala b/streams/streams-scala/src/test/scala/org/apache/kafka/streams/scala/StreamToTableJoinScalaIntegrationTestImplicitSerdes.scala new file mode 100644 index 0000000..e9577bc --- /dev/null +++ b/streams/streams-scala/src/test/scala/org/apache/kafka/streams/scala/StreamToTableJoinScalaIntegrationTestImplicitSerdes.scala @@ -0,0 +1,174 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.scala + +import java.util.Properties + +import org.apache.kafka.streams.{KafkaStreams, KeyValue, StreamsConfig} +import org.apache.kafka.streams.scala.serialization.{Serdes => NewSerdes} +import org.apache.kafka.streams.scala.ImplicitConversions._ +import org.apache.kafka.streams.scala.kstream._ +import org.apache.kafka.streams.scala.utils.StreamToTableJoinScalaIntegrationTestBase +import org.junit.jupiter.api._ +import org.junit.jupiter.api.Assertions._ + +/** + * Test suite that does an example to demonstrate stream-table joins in Kafka Streams + *
<p>
                + * The suite contains the test case using Scala APIs `testShouldCountClicksPerRegion` and the same test case using the + * Java APIs `testShouldCountClicksPerRegionJava`. The idea is to demonstrate that both generate the same result. + */ +@Tag("integration") +class StreamToTableJoinScalaIntegrationTestImplicitSerdes extends StreamToTableJoinScalaIntegrationTestBase { + + @Test def testShouldCountClicksPerRegion(): Unit = { + + // DefaultSerdes brings into scope implicit serdes (mostly for primitives) that will set up all Grouped, Produced, + // Consumed and Joined instances. So all APIs below that accept Grouped, Produced, Consumed or Joined will + // get these instances automatically + import org.apache.kafka.streams.scala.serialization.Serdes._ + + val streamsConfiguration: Properties = getStreamsConfiguration() + + val builder = new StreamsBuilder() + + val userClicksStream: KStream[String, Long] = builder.stream(userClicksTopic) + + val userRegionsTable: KTable[String, String] = builder.table(userRegionsTopic) + + // Compute the total per region by summing the individual click counts per region. + val clicksPerRegion: KTable[String, Long] = + userClicksStream + + // Join the stream against the table. + .leftJoin(userRegionsTable)((clicks, region) => (if (region == null) "UNKNOWN" else region, clicks)) + + // Change the stream from -> to -> + .map((_, regionWithClicks) => regionWithClicks) + + // Compute the total per region by summing the individual click counts per region. + .groupByKey + .reduce(_ + _) + + // Write the (continuously updating) results to the output topic. + clicksPerRegion.toStream.to(outputTopic) + + val streams: KafkaStreams = new KafkaStreams(builder.build(), streamsConfiguration) + streams.start() + + val actualClicksPerRegion: java.util.List[KeyValue[String, Long]] = + produceNConsume(userClicksTopic, userRegionsTopic, outputTopic) + + assertTrue(!actualClicksPerRegion.isEmpty, "Expected to process some data") + + streams.close() + } + + @Test + def testShouldCountClicksPerRegionWithNamedRepartitionTopic(): Unit = { + + // DefaultSerdes brings into scope implicit serdes (mostly for primitives) that will set up all Grouped, Produced, + // Consumed and Joined instances. So all APIs below that accept Grouped, Produced, Consumed or Joined will + // get these instances automatically + import org.apache.kafka.streams.scala.serialization.Serdes._ + + val streamsConfiguration: Properties = getStreamsConfiguration() + + val builder = new StreamsBuilder() + + val userClicksStream: KStream[String, Long] = builder.stream(userClicksTopic) + + val userRegionsTable: KTable[String, String] = builder.table(userRegionsTopic) + + // Compute the total per region by summing the individual click counts per region. + val clicksPerRegion: KTable[String, Long] = + userClicksStream + + // Join the stream against the table. + .leftJoin(userRegionsTable)((clicks, region) => (if (region == null) "UNKNOWN" else region, clicks)) + + // Change the stream from -> to -> + .map((_, regionWithClicks) => regionWithClicks) + + // Compute the total per region by summing the individual click counts per region. + .groupByKey + .reduce(_ + _) + + // Write the (continuously updating) results to the output topic. 
+ clicksPerRegion.toStream.to(outputTopic) + + val streams: KafkaStreams = new KafkaStreams(builder.build(), streamsConfiguration) + streams.start() + + val actualClicksPerRegion: java.util.List[KeyValue[String, Long]] = + produceNConsume(userClicksTopic, userRegionsTopic, outputTopic) + + assertTrue(!actualClicksPerRegion.isEmpty, "Expected to process some data") + + streams.close() + } + + @Test + def testShouldCountClicksPerRegionJava(): Unit = { + + import java.lang.{Long => JLong} + + import org.apache.kafka.streams.kstream.{KStream => KStreamJ, KTable => KTableJ, _} + import org.apache.kafka.streams.{KafkaStreams => KafkaStreamsJ, StreamsBuilder => StreamsBuilderJ} + + val streamsConfiguration: Properties = getStreamsConfiguration() + + streamsConfiguration.put(StreamsConfig.DEFAULT_KEY_SERDE_CLASS_CONFIG, NewSerdes.stringSerde.getClass.getName) + streamsConfiguration.put(StreamsConfig.DEFAULT_VALUE_SERDE_CLASS_CONFIG, NewSerdes.stringSerde.getClass.getName) + + val builder: StreamsBuilderJ = new StreamsBuilderJ() + + val userClicksStream: KStreamJ[String, JLong] = + builder.stream[String, JLong](userClicksTopicJ, Consumed.`with`(NewSerdes.stringSerde, NewSerdes.javaLongSerde)) + + val userRegionsTable: KTableJ[String, String] = + builder.table[String, String](userRegionsTopicJ, Consumed.`with`(NewSerdes.stringSerde, NewSerdes.stringSerde)) + + // Join the stream against the table. + val valueJoinerJ: ValueJoiner[JLong, String, (String, JLong)] = + (clicks: JLong, region: String) => (if (region == null) "UNKNOWN" else region, clicks) + val userClicksJoinRegion: KStreamJ[String, (String, JLong)] = userClicksStream.leftJoin( + userRegionsTable, + valueJoinerJ, + Joined.`with`[String, JLong, String](NewSerdes.stringSerde, NewSerdes.javaLongSerde, NewSerdes.stringSerde) + ) + + // Change the stream from -> to -> + val clicksByRegion: KStreamJ[String, JLong] = userClicksJoinRegion.map { (_, regionWithClicks) => + new KeyValue(regionWithClicks._1, regionWithClicks._2) + } + + // Compute the total per region by summing the individual click counts per region. + val clicksPerRegion: KTableJ[String, JLong] = clicksByRegion + .groupByKey(Grouped.`with`(NewSerdes.stringSerde, NewSerdes.javaLongSerde)) + .reduce((v1, v2) => v1 + v2) + + // Write the (continuously updating) results to the output topic. + clicksPerRegion.toStream.to(outputTopicJ, Produced.`with`(NewSerdes.stringSerde, NewSerdes.javaLongSerde)) + + val streams = new KafkaStreamsJ(builder.build(), streamsConfiguration) + + streams.start() + produceNConsume(userClicksTopicJ, userRegionsTopicJ, outputTopicJ) + streams.close() + } +} diff --git a/streams/streams-scala/src/test/scala/org/apache/kafka/streams/scala/TopologyTest.scala b/streams/streams-scala/src/test/scala/org/apache/kafka/streams/scala/TopologyTest.scala new file mode 100644 index 0000000..926ba43 --- /dev/null +++ b/streams/streams-scala/src/test/scala/org/apache/kafka/streams/scala/TopologyTest.scala @@ -0,0 +1,479 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.scala + +import java.time.Duration +import java.util +import java.util.{Locale, Properties} +import java.util.regex.Pattern +import org.apache.kafka.common.serialization.{Serdes => SerdesJ} +import org.apache.kafka.streams.kstream.{ + Aggregator, + Initializer, + JoinWindows, + KeyValueMapper, + Reducer, + Transformer, + ValueJoiner, + ValueMapper, + KGroupedStream => KGroupedStreamJ, + KStream => KStreamJ, + KTable => KTableJ, + Materialized => MaterializedJ, + StreamJoined => StreamJoinedJ +} +import org.apache.kafka.streams.processor.{api, ProcessorContext} +import org.apache.kafka.streams.processor.api.{Processor, ProcessorSupplier} +import org.apache.kafka.streams.scala.ImplicitConversions._ +import org.apache.kafka.streams.scala.serialization.{Serdes => NewSerdes} +import org.apache.kafka.streams.scala.serialization.Serdes._ +import org.apache.kafka.streams.scala.kstream._ +import org.apache.kafka.streams.{KeyValue, StreamsConfig, TopologyDescription, StreamsBuilder => StreamsBuilderJ} +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api._ + +import scala.jdk.CollectionConverters._ + +/** + * Test suite that verifies that the topology built by the Java and Scala APIs match. + */ +//noinspection ScalaDeprecation +class TopologyTest { + + private val inputTopic = "input-topic" + private val userClicksTopic = "user-clicks-topic" + private val userRegionsTopic = "user-regions-topic" + + private val pattern = Pattern.compile("\\W+", Pattern.UNICODE_CHARACTER_CLASS) + + @Test + def shouldBuildIdenticalTopologyInJavaNScalaSimple(): Unit = { + + // build the Scala topology + def getTopologyScala: TopologyDescription = { + + import org.apache.kafka.streams.scala.serialization.Serdes._ + + val streamBuilder = new StreamsBuilder + val textLines = streamBuilder.stream[String, String](inputTopic) + + val _: KStream[String, String] = textLines.flatMapValues(v => pattern.split(v.toLowerCase)) + + streamBuilder.build().describe() + } + + // build the Java topology + def getTopologyJava: TopologyDescription = { + val streamBuilder = new StreamsBuilderJ + val textLines = streamBuilder.stream[String, String](inputTopic) + val _: KStreamJ[String, String] = textLines.flatMapValues(s => pattern.split(s.toLowerCase).toIterable.asJava) + streamBuilder.build().describe() + } + + // should match + assertEquals(getTopologyScala, getTopologyJava) + } + + @Test + def shouldBuildIdenticalTopologyInJavaNScalaAggregate(): Unit = { + + // build the Scala topology + def getTopologyScala: TopologyDescription = { + + import org.apache.kafka.streams.scala.serialization.Serdes._ + + val streamBuilder = new StreamsBuilder + val textLines = streamBuilder.stream[String, String](inputTopic) + + textLines + .flatMapValues(v => pattern.split(v.toLowerCase)) + .groupBy((_, v) => v) + .count() + + streamBuilder.build().describe() + } + + // build the Java topology + def getTopologyJava: TopologyDescription = { + + val streamBuilder = new StreamsBuilderJ + val textLines: KStreamJ[String, String] = streamBuilder.stream[String, String](inputTopic) + + val 
splits: KStreamJ[String, String] = + textLines.flatMapValues(s => pattern.split(s.toLowerCase).toIterable.asJava) + + val grouped: KGroupedStreamJ[String, String] = splits.groupBy((_, v) => v) + + grouped.count() + + streamBuilder.build().describe() + } + + // should match + assertEquals(getTopologyScala, getTopologyJava) + } + + @Test def shouldBuildIdenticalTopologyInJavaNScalaCogroupSimple(): Unit = { + + // build the Scala topology + def getTopologyScala: TopologyDescription = { + + import org.apache.kafka.streams.scala.serialization.Serdes._ + + val streamBuilder = new StreamsBuilder + val textLines = streamBuilder.stream[String, String](inputTopic) + textLines + .mapValues(v => v.length) + .groupByKey + .cogroup((_, v1, v2: Long) => v1 + v2) + .aggregate(0L) + + streamBuilder.build().describe() + } + + // build the Java topology + def getTopologyJava: TopologyDescription = { + + val streamBuilder = new StreamsBuilderJ + val textLines: KStreamJ[String, String] = streamBuilder.stream[String, String](inputTopic) + + val splits: KStreamJ[String, Int] = textLines.mapValues( + new ValueMapper[String, Int] { + def apply(s: String): Int = s.length + } + ) + + splits.groupByKey + .cogroup((k: String, v: Int, a: Long) => a + v) + .aggregate(() => 0L) + + streamBuilder.build().describe() + } + + // should match + assertEquals(getTopologyScala, getTopologyJava) + } + + @Test def shouldBuildIdenticalTopologyInJavaNScalaCogroup(): Unit = { + + // build the Scala topology + def getTopologyScala: TopologyDescription = { + + import org.apache.kafka.streams.scala.serialization.Serdes._ + + val streamBuilder = new StreamsBuilder + val textLines1 = streamBuilder.stream[String, String](inputTopic) + val textLines2 = streamBuilder.stream[String, String]("inputTopic2") + + textLines1 + .mapValues(v => v.length) + .groupByKey + .cogroup((_, v1, v2: Long) => v1 + v2) + .cogroup(textLines2.groupByKey, (_, v: String, a) => v.length + a) + .aggregate(0L) + + streamBuilder.build().describe() + } + + // build the Java topology + def getTopologyJava: TopologyDescription = { + + val streamBuilder = new StreamsBuilderJ + val textLines1: KStreamJ[String, String] = streamBuilder.stream[String, String](inputTopic) + val textLines2: KStreamJ[String, String] = streamBuilder.stream[String, String]("inputTopic2") + + val splits: KStreamJ[String, Int] = textLines1.mapValues( + new ValueMapper[String, Int] { + def apply(s: String): Int = s.length + } + ) + + splits.groupByKey + .cogroup((k: String, v: Int, a: Long) => a + v) + .cogroup(textLines2.groupByKey(), (k: String, v: String, a: Long) => v.length + a) + .aggregate(() => 0L) + + streamBuilder.build().describe() + } + + // should match + assertEquals(getTopologyScala, getTopologyJava) + } + + @Test def shouldBuildIdenticalTopologyInJavaNScalaJoin(): Unit = { + + // build the Scala topology + def getTopologyScala: TopologyDescription = { + import org.apache.kafka.streams.scala.serialization.Serdes._ + + val builder = new StreamsBuilder() + + val userClicksStream: KStream[String, Long] = builder.stream(userClicksTopic) + + val userRegionsTable: KTable[String, String] = builder.table(userRegionsTopic) + + // clicks per region + userClicksStream + .leftJoin(userRegionsTable)((clicks, region) => (if (region == null) "UNKNOWN" else region, clicks)) + .map((_, regionWithClicks) => regionWithClicks) + .groupByKey + .reduce(_ + _) + + builder.build().describe() + } + + // build the Java topology + def getTopologyJava: TopologyDescription = { + + import java.lang.{Long => JLong} 
+ + val builder: StreamsBuilderJ = new StreamsBuilderJ() + + val userClicksStream: KStreamJ[String, JLong] = + builder.stream[String, JLong](userClicksTopic, Consumed.`with`[String, JLong]) + + val userRegionsTable: KTableJ[String, String] = + builder.table[String, String](userRegionsTopic, Consumed.`with`[String, String]) + + // Join the stream against the table. + val valueJoinerJ: ValueJoiner[JLong, String, (String, JLong)] = + (clicks: JLong, region: String) => (if (region == null) "UNKNOWN" else region, clicks) + val userClicksJoinRegion: KStreamJ[String, (String, JLong)] = userClicksStream.leftJoin( + userRegionsTable, + valueJoinerJ, + Joined.`with`[String, JLong, String] + ) + + // Change the stream from -> to -> + val clicksByRegion: KStreamJ[String, JLong] = userClicksJoinRegion.map { (_, regionWithClicks) => + new KeyValue(regionWithClicks._1, regionWithClicks._2) + } + + // Compute the total per region by summing the individual click counts per region. + clicksByRegion + .groupByKey(Grouped.`with`[String, JLong]) + .reduce((v1, v2) => v1 + v2) + + builder.build().describe() + } + + // should match + assertEquals(getTopologyScala, getTopologyJava) + } + + @Test + def shouldBuildIdenticalTopologyInJavaNScalaTransform(): Unit = { + + // build the Scala topology + def getTopologyScala: TopologyDescription = { + + import org.apache.kafka.streams.scala.serialization.Serdes._ + + val streamBuilder = new StreamsBuilder + val textLines = streamBuilder.stream[String, String](inputTopic) + + val _: KTable[String, Long] = textLines + .transform(() => + new Transformer[String, String, KeyValue[String, String]] { + override def init(context: ProcessorContext): Unit = () + override def transform(key: String, value: String): KeyValue[String, String] = + new KeyValue(key, value.toLowerCase) + override def close(): Unit = () + } + ) + .groupBy((_, v) => v) + .count() + + streamBuilder.build().describe() + } + + // build the Java topology + def getTopologyJava: TopologyDescription = { + + val streamBuilder = new StreamsBuilderJ + val textLines: KStreamJ[String, String] = streamBuilder.stream[String, String](inputTopic) + + val lowered: KStreamJ[String, String] = textLines.transform(() => + new Transformer[String, String, KeyValue[String, String]] { + override def init(context: ProcessorContext): Unit = () + override def transform(key: String, value: String): KeyValue[String, String] = + new KeyValue(key, value.toLowerCase) + override def close(): Unit = () + } + ) + + val grouped: KGroupedStreamJ[String, String] = lowered.groupBy((_, v) => v) + + // word counts + grouped.count() + + streamBuilder.build().describe() + } + + // should match + assertEquals(getTopologyScala, getTopologyJava) + } + + @Test + def shouldBuildIdenticalTopologyInJavaNScalaProperties(): Unit = { + + val props = new Properties() + props.put(StreamsConfig.TOPOLOGY_OPTIMIZATION_CONFIG, StreamsConfig.OPTIMIZE) + + val propsNoOptimization = new Properties() + propsNoOptimization.put(StreamsConfig.TOPOLOGY_OPTIMIZATION_CONFIG, StreamsConfig.NO_OPTIMIZATION) + + val AGGREGATION_TOPIC = "aggregationTopic" + val REDUCE_TOPIC = "reduceTopic" + val JOINED_TOPIC = "joinedTopic" + + // build the Scala topology + def getTopologyScala: StreamsBuilder = { + + val aggregator = (_: String, v: String, agg: Int) => agg + v.length + val reducer = (v1: String, v2: String) => v1 + ":" + v2 + val processorValueCollector: util.List[String] = new util.ArrayList[String] + + val builder: StreamsBuilder = new StreamsBuilder + + val sourceStream: 
KStream[String, String] = + builder.stream(inputTopic)(Consumed.`with`(NewSerdes.stringSerde, NewSerdes.stringSerde)) + + val mappedStream: KStream[String, String] = + sourceStream.map((k: String, v: String) => (k.toUpperCase(Locale.getDefault), v)) + mappedStream + .filter((k: String, _: String) => k == "B") + .mapValues((v: String) => v.toUpperCase(Locale.getDefault)) + .process(new SimpleProcessorSupplier(processorValueCollector)) + + val stream2 = mappedStream.groupByKey + .aggregate(0)(aggregator)(Materialized.`with`(NewSerdes.stringSerde, NewSerdes.intSerde)) + .toStream + stream2.to(AGGREGATION_TOPIC)(Produced.`with`(NewSerdes.stringSerde, NewSerdes.intSerde)) + + // adding operators for case where the repartition node is further downstream + val stream3 = mappedStream + .filter((_: String, _: String) => true) + .peek((k: String, v: String) => System.out.println(k + ":" + v)) + .groupByKey + .reduce(reducer)(Materialized.`with`(NewSerdes.stringSerde, NewSerdes.stringSerde)) + .toStream + stream3.to(REDUCE_TOPIC)(Produced.`with`(NewSerdes.stringSerde, NewSerdes.stringSerde)) + + mappedStream + .filter((k: String, _: String) => k == "A") + .join(stream2)( + (v1: String, v2: Int) => v1 + ":" + v2.toString, + JoinWindows.ofTimeDifferenceAndGrace(Duration.ofMillis(5000), Duration.ofHours(24)) + )( + StreamJoined.`with`(NewSerdes.stringSerde, NewSerdes.stringSerde, NewSerdes.intSerde) + ) + .to(JOINED_TOPIC) + + mappedStream + .filter((k: String, _: String) => k == "A") + .join(stream3)( + (v1: String, v2: String) => v1 + ":" + v2.toString, + JoinWindows.ofTimeDifferenceAndGrace(Duration.ofMillis(5000), Duration.ofHours(24)) + )( + StreamJoined.`with`(NewSerdes.stringSerde, NewSerdes.stringSerde, NewSerdes.stringSerde) + ) + .to(JOINED_TOPIC) + + builder + } + + // build the Java topology + def getTopologyJava: StreamsBuilderJ = { + + val keyValueMapper: KeyValueMapper[String, String, KeyValue[String, String]] = + (key, value) => KeyValue.pair(key.toUpperCase(Locale.getDefault), value) + val initializer: Initializer[Integer] = () => 0 + val aggregator: Aggregator[String, String, Integer] = (_, value, aggregate) => aggregate + value.length + val reducer: Reducer[String] = (v1, v2) => v1 + ":" + v2 + val valueMapper: ValueMapper[String, String] = v => v.toUpperCase(Locale.getDefault) + val processorValueCollector = new util.ArrayList[String] + val processorSupplier = new SimpleProcessorSupplier(processorValueCollector) + val valueJoiner2: ValueJoiner[String, Integer, String] = (value1, value2) => value1 + ":" + value2.toString + val valueJoiner3: ValueJoiner[String, String, String] = (value1, value2) => value1 + ":" + value2 + + val builder = new StreamsBuilderJ + + val sourceStream = builder.stream(inputTopic, Consumed.`with`(NewSerdes.stringSerde, NewSerdes.stringSerde)) + + val mappedStream: KStreamJ[String, String] = + sourceStream.map(keyValueMapper) + mappedStream + .filter((key, _) => key == "B") + .mapValues[String](valueMapper) + .process(processorSupplier) + + val stream2: KStreamJ[String, Integer] = mappedStream.groupByKey + .aggregate(initializer, aggregator, MaterializedJ.`with`(NewSerdes.stringSerde, SerdesJ.Integer)) + .toStream + stream2.to(AGGREGATION_TOPIC, Produced.`with`(NewSerdes.stringSerde, SerdesJ.Integer)) + + // adding operators for case where the repartition node is further downstream + val stream3 = mappedStream + .filter((_, _) => true) + .peek((k, v) => System.out.println(k + ":" + v)) + .groupByKey + .reduce(reducer, MaterializedJ.`with`(NewSerdes.stringSerde, 
NewSerdes.stringSerde)) + .toStream + stream3.to(REDUCE_TOPIC, Produced.`with`(NewSerdes.stringSerde, NewSerdes.stringSerde)) + + mappedStream + .filter((key, _) => key == "A") + .join[Integer, String]( + stream2, + valueJoiner2, + JoinWindows.ofTimeDifferenceAndGrace(Duration.ofMillis(5000), Duration.ofHours(24)), + StreamJoinedJ.`with`(NewSerdes.stringSerde, NewSerdes.stringSerde, SerdesJ.Integer) + ) + .to(JOINED_TOPIC) + + mappedStream + .filter((key, _) => key == "A") + .join( + stream3, + valueJoiner3, + JoinWindows.ofTimeDifferenceAndGrace(Duration.ofMillis(5000), Duration.ofHours(24)), + StreamJoinedJ.`with`(NewSerdes.stringSerde, NewSerdes.stringSerde, SerdesJ.String) + ) + .to(JOINED_TOPIC) + + builder + } + + assertNotEquals( + getTopologyScala.build(props).describe.toString, + getTopologyScala.build(propsNoOptimization).describe.toString + ) + assertEquals( + getTopologyScala.build(propsNoOptimization).describe.toString, + getTopologyJava.build(propsNoOptimization).describe.toString + ) + assertEquals(getTopologyScala.build(props).describe.toString, getTopologyJava.build(props).describe.toString) + } + + private class SimpleProcessorSupplier private[TopologyTest] (val valueList: util.List[String]) + extends ProcessorSupplier[String, String, Void, Void] { + + override def get(): Processor[String, String, Void, Void] = + (record: api.Record[String, String]) => valueList.add(record.value()) + } +} diff --git a/streams/streams-scala/src/test/scala/org/apache/kafka/streams/scala/WordCountTest.scala b/streams/streams-scala/src/test/scala/org/apache/kafka/streams/scala/WordCountTest.scala new file mode 100644 index 0000000..1fa364a --- /dev/null +++ b/streams/streams-scala/src/test/scala/org/apache/kafka/streams/scala/WordCountTest.scala @@ -0,0 +1,249 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.scala + +import java.util.Properties +import java.util.regex.Pattern +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api._ +import org.apache.kafka.streams.scala.serialization.{Serdes => NewSerdes} +import org.apache.kafka.streams.{KafkaStreams, KeyValue, StreamsConfig} +import org.apache.kafka.streams.scala.kstream._ +import org.apache.kafka.streams.integration.utils.{EmbeddedKafkaCluster, IntegrationTestUtils} +import org.apache.kafka.clients.consumer.ConsumerConfig +import org.apache.kafka.clients.producer.ProducerConfig +import org.apache.kafka.common.utils.{MockTime, Utils} +import ImplicitConversions._ +import org.apache.kafka.common.serialization.{LongDeserializer, StringDeserializer, StringSerializer} +import org.apache.kafka.test.TestUtils +import org.junit.jupiter.api.Tag + +import java.io.File + +/** + * Test suite that does a classic word count example. + *
<p>
                + * The suite contains the test case using Scala APIs `testShouldCountWords` and the same test case using the + * Java APIs `testShouldCountWordsJava`. The idea is to demonstrate that both generate the same result. + */ +@Tag("integration") +class WordCountTest extends WordCountTestData { + + private val cluster: EmbeddedKafkaCluster = new EmbeddedKafkaCluster(1) + + final private val alignedTime = (System.currentTimeMillis() / 1000 + 1) * 1000 + private val mockTime: MockTime = cluster.time + mockTime.setCurrentTimeMs(alignedTime) + + private val testFolder: File = TestUtils.tempDirectory() + + @BeforeEach + def startKafkaCluster(): Unit = { + cluster.start() + cluster.createTopic(inputTopic) + cluster.createTopic(outputTopic) + cluster.createTopic(inputTopicJ) + cluster.createTopic(outputTopicJ) + } + + @AfterEach + def stopKafkaCluster(): Unit = { + cluster.stop() + Utils.delete(testFolder) + } + + @Test + def testShouldCountWords(): Unit = { + import org.apache.kafka.streams.scala.serialization.Serdes._ + + val streamsConfiguration = getStreamsConfiguration() + + val streamBuilder = new StreamsBuilder + val textLines = streamBuilder.stream[String, String](inputTopic) + + val pattern = Pattern.compile("\\W+", Pattern.UNICODE_CHARACTER_CLASS) + + // generate word counts + val wordCounts: KTable[String, Long] = + textLines + .flatMapValues(v => pattern.split(v.toLowerCase)) + .groupBy((_, v) => v) + .count() + + // write to output topic + wordCounts.toStream.to(outputTopic) + + val streams = new KafkaStreams(streamBuilder.build(), streamsConfiguration) + streams.start() + + // produce and consume synchronously + val actualWordCounts: java.util.List[KeyValue[String, Long]] = produceNConsume(inputTopic, outputTopic) + + streams.close() + + import scala.jdk.CollectionConverters._ + assertEquals(actualWordCounts.asScala.take(expectedWordCounts.size).sortBy(_.key), expectedWordCounts.sortBy(_.key)) + } + + @Test + def testShouldCountWordsMaterialized(): Unit = { + import org.apache.kafka.streams.scala.serialization.Serdes._ + + val streamsConfiguration = getStreamsConfiguration() + + val streamBuilder = new StreamsBuilder + val textLines = streamBuilder.stream[String, String](inputTopic) + + val pattern = Pattern.compile("\\W+", Pattern.UNICODE_CHARACTER_CLASS) + + // generate word counts + val wordCounts: KTable[String, Long] = + textLines + .flatMapValues(v => pattern.split(v.toLowerCase)) + .groupBy((k, v) => v) + .count()(Materialized.as("word-count")) + + // write to output topic + wordCounts.toStream.to(outputTopic) + + val streams = new KafkaStreams(streamBuilder.build(), streamsConfiguration) + streams.start() + + // produce and consume synchronously + val actualWordCounts: java.util.List[KeyValue[String, Long]] = produceNConsume(inputTopic, outputTopic) + + streams.close() + + import scala.jdk.CollectionConverters._ + assertEquals(actualWordCounts.asScala.take(expectedWordCounts.size).sortBy(_.key), expectedWordCounts.sortBy(_.key)) + } + + @Test + def testShouldCountWordsJava(): Unit = { + + import org.apache.kafka.streams.{KafkaStreams => KafkaStreamsJ, StreamsBuilder => StreamsBuilderJ} + import org.apache.kafka.streams.kstream.{ + KTable => KTableJ, + KStream => KStreamJ, + KGroupedStream => KGroupedStreamJ, + _ + } + import scala.jdk.CollectionConverters._ + + val streamsConfiguration = getStreamsConfiguration() + streamsConfiguration.put(StreamsConfig.DEFAULT_KEY_SERDE_CLASS_CONFIG, NewSerdes.stringSerde.getClass.getName) + 
streamsConfiguration.put(StreamsConfig.DEFAULT_VALUE_SERDE_CLASS_CONFIG, NewSerdes.stringSerde.getClass.getName) + + val streamBuilder = new StreamsBuilderJ + val textLines: KStreamJ[String, String] = streamBuilder.stream[String, String](inputTopicJ) + + val pattern = Pattern.compile("\\W+", Pattern.UNICODE_CHARACTER_CLASS) + + val splits: KStreamJ[String, String] = textLines.flatMapValues { line => + pattern.split(line.toLowerCase).toIterable.asJava + } + + val grouped: KGroupedStreamJ[String, String] = splits.groupBy { (_, v) => + v + } + + val wordCounts: KTableJ[String, java.lang.Long] = grouped.count() + + wordCounts.toStream.to(outputTopicJ, Produced.`with`(NewSerdes.stringSerde, NewSerdes.javaLongSerde)) + + val streams: KafkaStreamsJ = new KafkaStreamsJ(streamBuilder.build(), streamsConfiguration) + streams.start() + + val actualWordCounts: java.util.List[KeyValue[String, Long]] = produceNConsume(inputTopicJ, outputTopicJ) + + streams.close() + + assertEquals(actualWordCounts.asScala.take(expectedWordCounts.size).sortBy(_.key), expectedWordCounts.sortBy(_.key)) + } + + private def getStreamsConfiguration(): Properties = { + val streamsConfiguration: Properties = new Properties() + + streamsConfiguration.put(StreamsConfig.APPLICATION_ID_CONFIG, "wordcount-test") + streamsConfiguration.put(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, cluster.bootstrapServers()) + streamsConfiguration.put(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG, "10000") + streamsConfiguration.put(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "earliest") + streamsConfiguration.put(StreamsConfig.STATE_DIR_CONFIG, testFolder.getPath) + streamsConfiguration + } + + private def getProducerConfig(): Properties = { + val p = new Properties() + p.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, cluster.bootstrapServers()) + p.put(ProducerConfig.ACKS_CONFIG, "all") + p.put(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG, classOf[StringSerializer]) + p.put(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG, classOf[StringSerializer]) + p + } + + private def getConsumerConfig(): Properties = { + val p = new Properties() + p.put(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, cluster.bootstrapServers()) + p.put(ConsumerConfig.GROUP_ID_CONFIG, "wordcount-scala-integration-test-standard-consumer") + p.put(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "earliest") + p.put(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, classOf[StringDeserializer]) + p.put(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, classOf[LongDeserializer]) + p + } + + private def produceNConsume(inputTopic: String, outputTopic: String): java.util.List[KeyValue[String, Long]] = { + + val linesProducerConfig: Properties = getProducerConfig() + + import scala.jdk.CollectionConverters._ + IntegrationTestUtils.produceValuesSynchronously(inputTopic, inputValues.asJava, linesProducerConfig, mockTime) + + val consumerConfig = getConsumerConfig() + + IntegrationTestUtils.waitUntilMinKeyValueRecordsReceived(consumerConfig, outputTopic, expectedWordCounts.size) + } +} + +trait WordCountTestData { + val inputTopic = s"inputTopic" + val outputTopic = s"outputTopic" + val inputTopicJ = s"inputTopicJ" + val outputTopicJ = s"outputTopicJ" + + val inputValues = List( + "Hello Kafka Streams", + "All streams lead to Kafka", + "Join Kafka Summit", + "И теперь пошли русские слова" + ) + + val expectedWordCounts: List[KeyValue[String, Long]] = List( + new KeyValue("hello", 1L), + new KeyValue("all", 1L), + new KeyValue("streams", 2L), + new KeyValue("lead", 1L), + new KeyValue("to", 1L), + new KeyValue("join", 1L), + 
new KeyValue("kafka", 3L), + new KeyValue("summit", 1L), + new KeyValue("и", 1L), + new KeyValue("теперь", 1L), + new KeyValue("пошли", 1L), + new KeyValue("русские", 1L), + new KeyValue("слова", 1L) + ) +} diff --git a/streams/streams-scala/src/test/scala/org/apache/kafka/streams/scala/kstream/ConsumedTest.scala b/streams/streams-scala/src/test/scala/org/apache/kafka/streams/scala/kstream/ConsumedTest.scala new file mode 100644 index 0000000..0b44165 --- /dev/null +++ b/streams/streams-scala/src/test/scala/org/apache/kafka/streams/scala/kstream/ConsumedTest.scala @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.scala.kstream + +import org.apache.kafka.streams.Topology +import org.apache.kafka.streams.kstream.internals.ConsumedInternal +import org.apache.kafka.streams.processor.FailOnInvalidTimestamp +import org.apache.kafka.streams.scala.serialization.Serdes +import org.apache.kafka.streams.scala.serialization.Serdes._ +import org.junit.jupiter.api.Assertions.assertEquals +import org.junit.jupiter.api.Test + +class ConsumedTest { + + @Test + def testCreateConsumed(): Unit = { + val consumed: Consumed[String, Long] = Consumed.`with`[String, Long] + + val internalConsumed = new ConsumedInternal(consumed) + assertEquals(Serdes.stringSerde.getClass, internalConsumed.keySerde.getClass) + assertEquals(Serdes.longSerde.getClass, internalConsumed.valueSerde.getClass) + } + + @Test + def testCreateConsumedWithTimestampExtractorAndResetPolicy(): Unit = { + val timestampExtractor = new FailOnInvalidTimestamp() + val resetPolicy = Topology.AutoOffsetReset.LATEST + val consumed: Consumed[String, Long] = + Consumed.`with`[String, Long](timestampExtractor, resetPolicy) + + val internalConsumed = new ConsumedInternal(consumed) + assertEquals(Serdes.stringSerde.getClass, internalConsumed.keySerde.getClass) + assertEquals(Serdes.longSerde.getClass, internalConsumed.valueSerde.getClass) + assertEquals(timestampExtractor, internalConsumed.timestampExtractor) + assertEquals(resetPolicy, internalConsumed.offsetResetPolicy) + } + + @Test + def testCreateConsumedWithTimestampExtractor(): Unit = { + val timestampExtractor = new FailOnInvalidTimestamp() + val consumed: Consumed[String, Long] = Consumed.`with`[String, Long](timestampExtractor) + + val internalConsumed = new ConsumedInternal(consumed) + assertEquals(Serdes.stringSerde.getClass, internalConsumed.keySerde.getClass) + assertEquals(Serdes.longSerde.getClass, internalConsumed.valueSerde.getClass) + assertEquals(timestampExtractor, internalConsumed.timestampExtractor) + } + @Test + def testCreateConsumedWithResetPolicy(): Unit = { + val resetPolicy = Topology.AutoOffsetReset.LATEST + val consumed: Consumed[String, Long] = Consumed.`with`[String, 
Long](resetPolicy) + + val internalConsumed = new ConsumedInternal(consumed) + assertEquals(Serdes.stringSerde.getClass, internalConsumed.keySerde.getClass) + assertEquals(Serdes.longSerde.getClass, internalConsumed.valueSerde.getClass) + assertEquals(resetPolicy, internalConsumed.offsetResetPolicy) + } +} diff --git a/streams/streams-scala/src/test/scala/org/apache/kafka/streams/scala/kstream/GroupedTest.scala b/streams/streams-scala/src/test/scala/org/apache/kafka/streams/scala/kstream/GroupedTest.scala new file mode 100644 index 0000000..02f333e --- /dev/null +++ b/streams/streams-scala/src/test/scala/org/apache/kafka/streams/scala/kstream/GroupedTest.scala @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.scala.kstream + +import org.apache.kafka.streams.kstream.internals.GroupedInternal +import org.apache.kafka.streams.scala.serialization.Serdes +import org.apache.kafka.streams.scala.serialization.Serdes._ +import org.junit.jupiter.api.Assertions.assertEquals +import org.junit.jupiter.api.Test + +class GroupedTest { + + @Test + def testCreateGrouped(): Unit = { + val grouped: Grouped[String, Long] = Grouped.`with`[String, Long] + + val internalGrouped = new GroupedInternal[String, Long](grouped) + assertEquals(Serdes.stringSerde.getClass, internalGrouped.keySerde.getClass) + assertEquals(Serdes.longSerde.getClass, internalGrouped.valueSerde.getClass) + } + + @Test + def testCreateGroupedWithRepartitionTopicName(): Unit = { + val repartitionTopicName = "repartition-topic" + val grouped: Grouped[String, Long] = Grouped.`with`(repartitionTopicName) + + val internalGrouped = new GroupedInternal[String, Long](grouped) + assertEquals(Serdes.stringSerde.getClass, internalGrouped.keySerde.getClass) + assertEquals(Serdes.longSerde.getClass, internalGrouped.valueSerde.getClass) + assertEquals(repartitionTopicName, internalGrouped.name()) + } +} diff --git a/streams/streams-scala/src/test/scala/org/apache/kafka/streams/scala/kstream/JoinedTest.scala b/streams/streams-scala/src/test/scala/org/apache/kafka/streams/scala/kstream/JoinedTest.scala new file mode 100644 index 0000000..4e6fa56 --- /dev/null +++ b/streams/streams-scala/src/test/scala/org/apache/kafka/streams/scala/kstream/JoinedTest.scala @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.scala.kstream + +import org.apache.kafka.streams.scala.serialization.Serdes +import org.apache.kafka.streams.scala.serialization.Serdes._ +import org.junit.jupiter.api.Assertions.assertEquals +import org.junit.jupiter.api.Test + +class JoinedTest { + + @Test + def testCreateJoined(): Unit = { + val joined: Joined[String, Long, Int] = Joined.`with`[String, Long, Int] + + assertEquals(joined.keySerde.getClass, Serdes.stringSerde.getClass) + assertEquals(joined.valueSerde.getClass, Serdes.longSerde.getClass) + assertEquals(joined.otherValueSerde.getClass, Serdes.intSerde.getClass) + } + + @Test + def testCreateJoinedWithSerdesAndRepartitionTopicName(): Unit = { + val repartitionTopicName = "repartition-topic" + val joined: Joined[String, Long, Int] = Joined.`with`(repartitionTopicName) + + assertEquals(joined.keySerde.getClass, Serdes.stringSerde.getClass) + assertEquals(joined.valueSerde.getClass, Serdes.longSerde.getClass) + assertEquals(joined.otherValueSerde.getClass, Serdes.intSerde.getClass) + } +} diff --git a/streams/streams-scala/src/test/scala/org/apache/kafka/streams/scala/kstream/KStreamSplitTest.scala b/streams/streams-scala/src/test/scala/org/apache/kafka/streams/scala/kstream/KStreamSplitTest.scala new file mode 100644 index 0000000..bbcc1b5 --- /dev/null +++ b/streams/streams-scala/src/test/scala/org/apache/kafka/streams/scala/kstream/KStreamSplitTest.scala @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.scala.kstream + +import org.apache.kafka.streams.kstream.Named +import org.apache.kafka.streams.scala.ImplicitConversions._ +import org.apache.kafka.streams.scala.StreamsBuilder +import org.apache.kafka.streams.scala.serialization.Serdes._ +import org.apache.kafka.streams.scala.utils.TestDriver +import org.junit.jupiter.api.Assertions._ +import org.junit.jupiter.api.Test + +import scala.jdk.CollectionConverters._ + +class KStreamSplitTest extends TestDriver { + + @Test + def testRouteMessagesAccordingToPredicates(): Unit = { + val builder = new StreamsBuilder() + val sourceTopic = "source" + val sinkTopic = Array("default", "even", "three"); + + val m = builder + .stream[Integer, Integer](sourceTopic) + .split(Named.as("_")) + .branch((_, v) => v % 2 == 0) + .branch((_, v) => v % 3 == 0) + .defaultBranch() + + m("_0").to(sinkTopic(0)) + m("_1").to(sinkTopic(1)) + m("_2").to(sinkTopic(2)) + + val testDriver = createTestDriver(builder) + val testInput = testDriver.createInput[Integer, Integer](sourceTopic) + val testOutput = sinkTopic.map(name => testDriver.createOutput[Integer, Integer](name)) + + testInput.pipeValueList( + List(1, 2, 3, 4, 5) + .map(Integer.valueOf) + .asJava + ) + assertEquals(List(1, 5), testOutput(0).readValuesToList().asScala) + assertEquals(List(2, 4), testOutput(1).readValuesToList().asScala) + assertEquals(List(3), testOutput(2).readValuesToList().asScala) + + testDriver.close() + } + + @Test + def testRouteMessagesToConsumers(): Unit = { + val builder = new StreamsBuilder() + val sourceTopic = "source" + + val m = builder + .stream[Integer, Integer](sourceTopic) + .split(Named.as("_")) + .branch((_, v) => v % 2 == 0, Branched.withConsumer(ks => ks.to("even"), "consumedEvens")) + .branch((_, v) => v % 3 == 0, Branched.withFunction(ks => ks.mapValues(x => x * x), "mapped")) + .noDefaultBranch() + + m("_mapped").to("mapped") + + val testDriver = createTestDriver(builder) + val testInput = testDriver.createInput[Integer, Integer](sourceTopic) + testInput.pipeValueList( + List(1, 2, 3, 4, 5, 9) + .map(Integer.valueOf) + .asJava + ) + + val even = testDriver.createOutput[Integer, Integer]("even") + val mapped = testDriver.createOutput[Integer, Integer]("mapped") + + assertEquals(List(2, 4), even.readValuesToList().asScala) + assertEquals(List(9, 81), mapped.readValuesToList().asScala) + + testDriver.close() + } + + @Test + def testRouteMessagesToAnonymousConsumers(): Unit = { + val builder = new StreamsBuilder() + val sourceTopic = "source" + + val m = builder + .stream[Integer, Integer](sourceTopic) + .split(Named.as("_")) + .branch((_, v) => v % 2 == 0, Branched.withConsumer(ks => ks.to("even"))) + .branch((_, v) => v % 3 == 0, Branched.withFunction(ks => ks.mapValues(x => x * x))) + .noDefaultBranch() + + m("_2").to("mapped") + + val testDriver = createTestDriver(builder) + val testInput = testDriver.createInput[Integer, Integer](sourceTopic) + testInput.pipeValueList( + List(1, 2, 3, 4, 5, 9) + .map(Integer.valueOf) + .asJava + ) + + val even = testDriver.createOutput[Integer, Integer]("even") + val mapped = testDriver.createOutput[Integer, Integer]("mapped") + + assertEquals(List(2, 4), even.readValuesToList().asScala) + assertEquals(List(9, 81), mapped.readValuesToList().asScala) + + testDriver.close() + } +} diff --git a/streams/streams-scala/src/test/scala/org/apache/kafka/streams/scala/kstream/KStreamTest.scala b/streams/streams-scala/src/test/scala/org/apache/kafka/streams/scala/kstream/KStreamTest.scala new file mode 
100644 index 0000000..0ec7b0e --- /dev/null +++ b/streams/streams-scala/src/test/scala/org/apache/kafka/streams/scala/kstream/KStreamTest.scala @@ -0,0 +1,468 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.scala.kstream + +import java.time.Duration.ofSeconds +import java.time.{Duration, Instant} +import org.apache.kafka.streams.KeyValue +import org.apache.kafka.streams.kstream.{ + JoinWindows, + Named, + Transformer, + ValueTransformer, + ValueTransformerSupplier, + ValueTransformerWithKey, + ValueTransformerWithKeySupplier +} +import org.apache.kafka.streams.processor.ProcessorContext +import org.apache.kafka.streams.scala.ImplicitConversions._ +import org.apache.kafka.streams.scala.serialization.Serdes._ +import org.apache.kafka.streams.scala.StreamsBuilder +import org.apache.kafka.streams.scala.utils.TestDriver +import org.junit.jupiter.api.Assertions.{assertEquals, assertTrue} +import org.junit.jupiter.api.Test + +import scala.jdk.CollectionConverters._ + +class KStreamTest extends TestDriver { + + @Test + def testFilterRecordsSatisfyingPredicate(): Unit = { + val builder = new StreamsBuilder() + val sourceTopic = "source" + val sinkTopic = "sink" + + builder.stream[String, String](sourceTopic).filter((_, value) => value != "value2").to(sinkTopic) + + val testDriver = createTestDriver(builder) + val testInput = testDriver.createInput[String, String](sourceTopic) + val testOutput = testDriver.createOutput[String, String](sinkTopic) + + testInput.pipeInput("1", "value1") + assertEquals("value1", testOutput.readValue) + + testInput.pipeInput("2", "value2") + assertTrue(testOutput.isEmpty) + + testInput.pipeInput("3", "value3") + assertEquals("value3", testOutput.readValue) + + assertTrue(testOutput.isEmpty) + + testDriver.close() + } + + @Test + def testFilterRecordsNotSatisfyingPredicate(): Unit = { + val builder = new StreamsBuilder() + val sourceTopic = "source" + val sinkTopic = "sink" + + builder.stream[String, String](sourceTopic).filterNot((_, value) => value == "value2").to(sinkTopic) + + val testDriver = createTestDriver(builder) + val testInput = testDriver.createInput[String, String](sourceTopic) + val testOutput = testDriver.createOutput[String, String](sinkTopic) + + testInput.pipeInput("1", "value1") + assertEquals("value1", testOutput.readValue) + + testInput.pipeInput("2", "value2") + assertTrue(testOutput.isEmpty) + + testInput.pipeInput("3", "value3") + assertEquals("value3", testOutput.readValue) + + assertTrue(testOutput.isEmpty) + + testDriver.close() + } + + @Test + def testForeachActionsOnRecords(): Unit = { + val builder = new StreamsBuilder() + val sourceTopic = "source" + + var acc = "" + builder.stream[String, String](sourceTopic).foreach((_, value) => acc += value) + + val 
testDriver = createTestDriver(builder) + val testInput = testDriver.createInput[String, String](sourceTopic) + + testInput.pipeInput("1", "value1") + assertEquals("value1", acc) + + testInput.pipeInput("2", "value2") + assertEquals("value1value2", acc) + + testDriver.close() + } + + @Test + def testPeekActionsOnRecords(): Unit = { + val builder = new StreamsBuilder() + val sourceTopic = "source" + val sinkTopic = "sink" + + var acc = "" + builder.stream[String, String](sourceTopic).peek((_, v) => acc += v).to(sinkTopic) + + val testDriver = createTestDriver(builder) + val testInput = testDriver.createInput[String, String](sourceTopic) + val testOutput = testDriver.createOutput[String, String](sinkTopic) + + testInput.pipeInput("1", "value1") + assertEquals("value1", acc) + assertEquals("value1", testOutput.readValue) + + testInput.pipeInput("2", "value2") + assertEquals("value1value2", acc) + assertEquals("value2", testOutput.readValue) + + testDriver.close() + } + + @Test + def testSelectNewKey(): Unit = { + val builder = new StreamsBuilder() + val sourceTopic = "source" + val sinkTopic = "sink" + + builder.stream[String, String](sourceTopic).selectKey((_, value) => value).to(sinkTopic) + + val testDriver = createTestDriver(builder) + val testInput = testDriver.createInput[String, String](sourceTopic) + val testOutput = testDriver.createOutput[String, String](sinkTopic) + + testInput.pipeInput("1", "value1") + assertEquals("value1", testOutput.readKeyValue.key) + + testInput.pipeInput("1", "value2") + assertEquals("value2", testOutput.readKeyValue.key) + + assertTrue(testOutput.isEmpty) + + testDriver.close() + } + + @Test + def testRepartitionKStream(): Unit = { + val builder = new StreamsBuilder() + val sourceTopic = "source" + val repartitionName = "repartition" + val sinkTopic = "sink" + + builder.stream[String, String](sourceTopic).repartition(Repartitioned.`with`(repartitionName)).to(sinkTopic) + + val testDriver = createTestDriver(builder) + val testInput = testDriver.createInput[String, String](sourceTopic) + val testOutput = testDriver.createOutput[String, String](sinkTopic) + + testInput.pipeInput("1", "value1") + val kv1 = testOutput.readKeyValue + assertEquals("1", kv1.key) + assertEquals("value1", kv1.value) + + testInput.pipeInput("2", "value2") + val kv2 = testOutput.readKeyValue + assertEquals("2", kv2.key) + assertEquals("value2", kv2.value) + + assertTrue(testOutput.isEmpty) + + // appId == "test" + testDriver.producedTopicNames() contains "test-" + repartitionName + "-repartition" + + testDriver.close() + } + + //noinspection ScalaDeprecation + @Test + def testJoinCorrectlyRecords(): Unit = { + val builder = new StreamsBuilder() + val sourceTopic1 = "source1" + val sourceTopic2 = "source2" + val sinkTopic = "sink" + + val stream1 = builder.stream[String, String](sourceTopic1) + val stream2 = builder.stream[String, String](sourceTopic2) + stream1 + .join(stream2)((a, b) => s"$a-$b", JoinWindows.ofTimeDifferenceAndGrace(ofSeconds(1), Duration.ofHours(24))) + .to(sinkTopic) + + val now = Instant.now() + + val testDriver = createTestDriver(builder, now) + val testInput1 = testDriver.createInput[String, String](sourceTopic1) + val testInput2 = testDriver.createInput[String, String](sourceTopic2) + val testOutput = testDriver.createOutput[String, String](sinkTopic) + + testInput1.pipeInput("1", "topic1value1", now) + testInput2.pipeInput("1", "topic2value1", now) + + assertEquals("topic1value1-topic2value1", testOutput.readValue) + + assertTrue(testOutput.isEmpty) + + 
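// no further output is expected: the two inputs share key "1" and the same timestamp, so the 1-second join window matches them exactly once +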
testDriver.close() + } + + @Test + def testTransformCorrectlyRecords(): Unit = { + class TestTransformer extends Transformer[String, String, KeyValue[String, String]] { + override def init(context: ProcessorContext): Unit = {} + + override def transform(key: String, value: String): KeyValue[String, String] = + new KeyValue(s"$key-transformed", s"$value-transformed") + + override def close(): Unit = {} + } + val builder = new StreamsBuilder() + val sourceTopic = "source" + val sinkTopic = "sink" + + val stream = builder.stream[String, String](sourceTopic) + stream + .transform(() => new TestTransformer) + .to(sinkTopic) + + val now = Instant.now() + val testDriver = createTestDriver(builder, now) + val testInput = testDriver.createInput[String, String](sourceTopic) + val testOutput = testDriver.createOutput[String, String](sinkTopic) + + testInput.pipeInput("1", "value", now) + + val result = testOutput.readKeyValue() + assertEquals("value-transformed", result.value) + assertEquals("1-transformed", result.key) + + assertTrue(testOutput.isEmpty) + + testDriver.close() + } + + @Test + def testFlatTransformCorrectlyRecords(): Unit = { + class TestTransformer extends Transformer[String, String, Iterable[KeyValue[String, String]]] { + override def init(context: ProcessorContext): Unit = {} + + override def transform(key: String, value: String): Iterable[KeyValue[String, String]] = + Array(new KeyValue(s"$key-transformed", s"$value-transformed")) + + override def close(): Unit = {} + } + val builder = new StreamsBuilder() + val sourceTopic = "source" + val sinkTopic = "sink" + + val stream = builder.stream[String, String](sourceTopic) + stream + .flatTransform(() => new TestTransformer) + .to(sinkTopic) + + val now = Instant.now() + val testDriver = createTestDriver(builder, now) + val testInput = testDriver.createInput[String, String](sourceTopic) + val testOutput = testDriver.createOutput[String, String](sinkTopic) + + testInput.pipeInput("1", "value", now) + + val result = testOutput.readKeyValue() + assertEquals("value-transformed", result.value) + assertEquals("1-transformed", result.key) + + assertTrue(testOutput.isEmpty) + + testDriver.close() + } + + @Test + def testCorrectlyFlatTransformValuesInRecords(): Unit = { + class TestTransformer extends ValueTransformer[String, Iterable[String]] { + override def init(context: ProcessorContext): Unit = {} + + override def transform(value: String): Iterable[String] = + Array(s"$value-transformed") + + override def close(): Unit = {} + } + val builder = new StreamsBuilder() + val sourceTopic = "source" + val sinkTopic = "sink" + + val stream = builder.stream[String, String](sourceTopic) + stream + .flatTransformValues(new ValueTransformerSupplier[String, Iterable[String]] { + def get(): ValueTransformer[String, Iterable[String]] = + new TestTransformer + }) + .to(sinkTopic) + + val now = Instant.now() + val testDriver = createTestDriver(builder, now) + val testInput = testDriver.createInput[String, String](sourceTopic) + val testOutput = testDriver.createOutput[String, String](sinkTopic) + + testInput.pipeInput("1", "value", now) + + assertEquals("value-transformed", testOutput.readValue) + + assertTrue(testOutput.isEmpty) + + testDriver.close() + } + + @Test + def testCorrectlyFlatTransformValuesInRecordsWithKey(): Unit = { + class TestTransformer extends ValueTransformerWithKey[String, String, Iterable[String]] { + override def init(context: ProcessorContext): Unit = {} + + override def transform(key: String, value: String): Iterable[String] = + 
Array(s"$value-transformed-$key") + + override def close(): Unit = {} + } + val builder = new StreamsBuilder() + val sourceTopic = "source" + val sinkTopic = "sink" + + val stream = builder.stream[String, String](sourceTopic) + stream + .flatTransformValues(new ValueTransformerWithKeySupplier[String, String, Iterable[String]] { + def get(): ValueTransformerWithKey[String, String, Iterable[String]] = + new TestTransformer + }) + .to(sinkTopic) + + val now = Instant.now() + val testDriver = createTestDriver(builder, now) + val testInput = testDriver.createInput[String, String](sourceTopic) + val testOutput = testDriver.createOutput[String, String](sinkTopic) + + testInput.pipeInput("1", "value", now) + + assertEquals("value-transformed-1", testOutput.readValue) + + assertTrue(testOutput.isEmpty) + + testDriver.close() + } + + @Test + def testJoinTwoKStreamToTables(): Unit = { + val builder = new StreamsBuilder() + val sourceTopic1 = "source1" + val sourceTopic2 = "source2" + val sinkTopic = "sink" + + val table1 = builder.stream[String, String](sourceTopic1).toTable + val table2 = builder.stream[String, String](sourceTopic2).toTable + table1.join(table2)((a, b) => a + b).toStream.to(sinkTopic) + + val testDriver = createTestDriver(builder) + val testInput1 = testDriver.createInput[String, String](sourceTopic1) + val testInput2 = testDriver.createInput[String, String](sourceTopic2) + val testOutput = testDriver.createOutput[String, String](sinkTopic) + + testInput1.pipeInput("1", "topic1value1") + testInput2.pipeInput("1", "topic2value1") + + assertEquals("topic1value1topic2value1", testOutput.readValue) + + assertTrue(testOutput.isEmpty) + + testDriver.close() + } + + @Test + def testSettingNameOnFilter(): Unit = { + val builder = new StreamsBuilder() + val sourceTopic = "source" + val sinkTopic = "sink" + + builder + .stream[String, String](sourceTopic) + .filter((_, value) => value != "value2", Named.as("my-name")) + .to(sinkTopic) + + import scala.jdk.CollectionConverters._ + + val filterNode = builder.build().describe().subtopologies().asScala.head.nodes().asScala.toList(1) + assertEquals("my-name", filterNode.name()) + } + + @Test + def testSettingNameOnOutputTable(): Unit = { + val builder = new StreamsBuilder() + val sourceTopic1 = "source1" + val sinkTopic = "sink" + + builder + .stream[String, String](sourceTopic1) + .toTable(Named.as("my-name")) + .toStream + .to(sinkTopic) + + import scala.jdk.CollectionConverters._ + + val tableNode = builder.build().describe().subtopologies().asScala.head.nodes().asScala.toList(1) + assertEquals("my-name", tableNode.name()) + } + + @Test + def testSettingNameOnJoin(): Unit = { + val builder = new StreamsBuilder() + val sourceTopic1 = "source" + val sourceGTable = "table" + val sinkTopic = "sink" + + val stream = builder.stream[String, String](sourceTopic1) + val table = builder.globalTable[String, String](sourceGTable) + stream + .join(table, Named.as("my-name"))((a, b) => s"$a-$b", (a, b) => a + b) + .to(sinkTopic) + + import scala.jdk.CollectionConverters._ + + val joinNode = builder.build().describe().subtopologies().asScala.head.nodes().asScala.toList(1) + assertEquals("my-name", joinNode.name()) + } + + @Test + def testSettingNameOnTransform(): Unit = { + class TestTransformer extends Transformer[String, String, KeyValue[String, String]] { + override def init(context: ProcessorContext): Unit = {} + + override def transform(key: String, value: String): KeyValue[String, String] = + new KeyValue(s"$key-transformed", s"$value-transformed") + + 
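// nothing to clean up for this stateless transformer, so close() stays empty +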
override def close(): Unit = {} + } + val builder = new StreamsBuilder() + val sourceTopic = "source" + val sinkTopic = "sink" + + val stream = builder.stream[String, String](sourceTopic) + stream + .transform(() => new TestTransformer, Named.as("my-name")) + .to(sinkTopic) + + val transformNode = builder.build().describe().subtopologies().asScala.head.nodes().asScala.toList(1) + assertEquals("my-name", transformNode.name()) + } +} diff --git a/streams/streams-scala/src/test/scala/org/apache/kafka/streams/scala/kstream/KTableTest.scala b/streams/streams-scala/src/test/scala/org/apache/kafka/streams/scala/kstream/KTableTest.scala new file mode 100644 index 0000000..09a3a7d --- /dev/null +++ b/streams/streams-scala/src/test/scala/org/apache/kafka/streams/scala/kstream/KTableTest.scala @@ -0,0 +1,499 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.scala.kstream + +import org.apache.kafka.streams.kstream.Suppressed.BufferConfig +import org.apache.kafka.streams.kstream.{ + Named, + SlidingWindows, + SessionWindows, + TimeWindows, + Windowed, + Suppressed => JSuppressed +} +import org.apache.kafka.streams.scala.ImplicitConversions._ +import org.apache.kafka.streams.scala.serialization.Serdes._ +import org.apache.kafka.streams.scala.utils.TestDriver +import org.apache.kafka.streams.scala.{ByteArrayKeyValueStore, StreamsBuilder} +import org.junit.jupiter.api.Assertions.{assertEquals, assertNull, assertTrue} +import org.junit.jupiter.api.Test +import java.time.Duration +import java.time.Duration.ofMillis + +import scala.jdk.CollectionConverters._ + +//noinspection ScalaDeprecation +class KTableTest extends TestDriver { + + @Test + def testFilterRecordsSatisfyingPredicate(): Unit = { + val builder = new StreamsBuilder() + val sourceTopic = "source" + val sinkTopic = "sink" + + val table = builder.stream[String, String](sourceTopic).groupBy((key, _) => key).count() + table.filter((key, value) => key.equals("a") && value == 1).toStream.to(sinkTopic) + + val testDriver = createTestDriver(builder) + val testInput = testDriver.createInput[String, String](sourceTopic) + val testOutput = testDriver.createOutput[String, Long](sinkTopic) + + { + testInput.pipeInput("a", "passes filter : add new row to table") + val record = testOutput.readKeyValue + assertEquals("a", record.key) + assertEquals(1, record.value) + } + { + testInput.pipeInput("a", "fails filter : remove existing row from table") + val record = testOutput.readKeyValue + assertEquals("a", record.key) + assertNull(record.value) + } + { + testInput.pipeInput("b", "fails filter : no output") + assertTrue(testOutput.isEmpty) + } + assertTrue(testOutput.isEmpty) + + testDriver.close() + } + + @Test + def testFilterRecordsNotSatisfyingPredicate(): Unit = { + val 
builder = new StreamsBuilder() + val sourceTopic = "source" + val sinkTopic = "sink" + + val table = builder.stream[String, String](sourceTopic).groupBy((key, _) => key).count() + table.filterNot((_, value) => value > 1).toStream.to(sinkTopic) + + val testDriver = createTestDriver(builder) + val testInput = testDriver.createInput[String, String](sourceTopic) + val testOutput = testDriver.createOutput[String, Long](sinkTopic) + + { + testInput.pipeInput("1", "value1") + val record = testOutput.readKeyValue + assertEquals("1", record.key) + assertEquals(1, record.value) + } + { + testInput.pipeInput("1", "value2") + val record = testOutput.readKeyValue + assertEquals("1", record.key) + assertNull(record.value) + } + { + testInput.pipeInput("2", "value1") + val record = testOutput.readKeyValue + assertEquals("2", record.key) + assertEquals(1, record.value) + } + assertTrue(testOutput.isEmpty) + + testDriver.close() + } + + @Test + def testJoinCorrectlyRecords(): Unit = { + val builder = new StreamsBuilder() + val sourceTopic1 = "source1" + val sourceTopic2 = "source2" + val sinkTopic = "sink" + + val table1 = builder.stream[String, String](sourceTopic1).groupBy((key, _) => key).count() + val table2 = builder.stream[String, String](sourceTopic2).groupBy((key, _) => key).count() + table1.join(table2)((a, b) => a + b).toStream.to(sinkTopic) + + val testDriver = createTestDriver(builder) + val testInput1 = testDriver.createInput[String, String](sourceTopic1) + val testInput2 = testDriver.createInput[String, String](sourceTopic2) + val testOutput = testDriver.createOutput[String, Long](sinkTopic) + + testInput1.pipeInput("1", "topic1value1") + testInput2.pipeInput("1", "topic2value1") + assertEquals(2, testOutput.readValue) + + assertTrue(testOutput.isEmpty) + + testDriver.close() + } + + @Test + def testJoinCorrectlyRecordsAndStateStore(): Unit = { + val builder = new StreamsBuilder() + val sourceTopic1 = "source1" + val sourceTopic2 = "source2" + val sinkTopic = "sink" + val stateStore = "store" + val materialized = Materialized.as[String, Long, ByteArrayKeyValueStore](stateStore) + + val table1 = builder.stream[String, String](sourceTopic1).groupBy((key, _) => key).count() + val table2 = builder.stream[String, String](sourceTopic2).groupBy((key, _) => key).count() + table1.join(table2, materialized)((a, b) => a + b).toStream.to(sinkTopic) + + val testDriver = createTestDriver(builder) + val testInput1 = testDriver.createInput[String, String](sourceTopic1) + val testInput2 = testDriver.createInput[String, String](sourceTopic2) + val testOutput = testDriver.createOutput[String, Long](sinkTopic) + + testInput1.pipeInput("1", "topic1value1") + testInput2.pipeInput("1", "topic2value1") + assertEquals(2, testOutput.readValue) + assertEquals(2, testDriver.getKeyValueStore[String, Long](stateStore).get("1")) + + assertTrue(testOutput.isEmpty) + + testDriver.close() + } + + @Test + def testCorrectlySuppressResultsUsingSuppressedUntilTimeLimit(): Unit = { + val builder = new StreamsBuilder() + val sourceTopic = "source" + val sinkTopic = "sink" + val window = TimeWindows.ofSizeAndGrace(Duration.ofSeconds(1L), Duration.ofHours(24)) + val suppression = JSuppressed.untilTimeLimit[Windowed[String]](Duration.ofSeconds(2L), BufferConfig.unbounded()) + + val table: KTable[Windowed[String], Long] = builder + .stream[String, String](sourceTopic) + .groupByKey + .windowedBy(window) + .count() + .suppress(suppression) + + table.toStream((k, _) => 
s"${k.window().start()}:${k.window().end()}:${k.key()}").to(sinkTopic) + + val testDriver = createTestDriver(builder) + val testInput = testDriver.createInput[String, String](sourceTopic) + val testOutput = testDriver.createOutput[String, Long](sinkTopic) + + { + // publish key=1 @ time 0 => count==1 + testInput.pipeInput("1", "value1", 0L) + assertTrue(testOutput.isEmpty) + } + { + // publish key=1 @ time 1 => count==2 + testInput.pipeInput("1", "value2", 1L) + assertTrue(testOutput.isEmpty) + } + { + // move event time past the first window, but before the suppression window + testInput.pipeInput("2", "value1", 1001L) + assertTrue(testOutput.isEmpty) + } + { + // move event time riiiight before suppression window ends + testInput.pipeInput("2", "value2", 1999L) + assertTrue(testOutput.isEmpty) + } + { + // publish a late event before suppression window terminates => count==3 + testInput.pipeInput("1", "value3", 999L) + assertTrue(testOutput.isEmpty) + } + { + // move event time right past the suppression window of the first window. + testInput.pipeInput("2", "value3", 2001L) + val record = testOutput.readKeyValue + assertEquals("0:1000:1", record.key) + assertEquals(3L, record.value) + } + assertTrue(testOutput.isEmpty) + + testDriver.close() + } + + @Test + def testCorrectlyGroupByKeyWindowedBySlidingWindow(): Unit = { + val builder = new StreamsBuilder() + val sourceTopic = "source" + val sinkTopic = "sink" + val window = SlidingWindows.ofTimeDifferenceAndGrace(ofMillis(1000L), ofMillis(1000L)) + val suppression = JSuppressed.untilWindowCloses(BufferConfig.unbounded()) + + val table: KTable[Windowed[String], Long] = builder + .stream[String, String](sourceTopic) + .groupByKey + .windowedBy(window) + .count() + .suppress(suppression) + + table.toStream((k, _) => s"${k.window().start()}:${k.window().end()}:${k.key()}").to(sinkTopic) + + val testDriver = createTestDriver(builder) + val testInput = testDriver.createInput[String, String](sourceTopic) + val testOutput = testDriver.createOutput[String, Long](sinkTopic) + + { + // publish key=1 @ time 0 => count==1 + testInput.pipeInput("1", "value1", 0L) + assertTrue(testOutput.isEmpty) + } + { + // move event time right past the grace period of the first window. 
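+ // the record at 5001 pushes stream time well past the window end plus the 1-second grace period, so the suppressed final count for the first window is emitted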
+ testInput.pipeInput("2", "value3", 5001L) + val record = testOutput.readKeyValue + assertEquals("0:1000:1", record.key) + assertEquals(1L, record.value) + } + assertTrue(testOutput.isEmpty) + + testDriver.close() + } + + @Test + def testCorrectlySuppressResultsUsingSuppressedUntilWindowClosesByWindowed(): Unit = { + val builder = new StreamsBuilder() + val sourceTopic = "source" + val sinkTopic = "sink" + val window = TimeWindows.ofSizeAndGrace(Duration.ofSeconds(1L), Duration.ofSeconds(1L)) + val suppression = JSuppressed.untilWindowCloses(BufferConfig.unbounded()) + + val table: KTable[Windowed[String], Long] = builder + .stream[String, String](sourceTopic) + .groupByKey + .windowedBy(window) + .count() + .suppress(suppression) + + table.toStream((k, _) => s"${k.window().start()}:${k.window().end()}:${k.key()}").to(sinkTopic) + + val testDriver = createTestDriver(builder) + val testInput = testDriver.createInput[String, String](sourceTopic) + val testOutput = testDriver.createOutput[String, Long](sinkTopic) + + { + // publish key=1 @ time 0 => count==1 + testInput.pipeInput("1", "value1", 0L) + assertTrue(testOutput.isEmpty) + } + { + // publish key=1 @ time 1 => count==2 + testInput.pipeInput("1", "value2", 1L) + assertTrue(testOutput.isEmpty) + } + { + // move event time past the window, but before the grace period + testInput.pipeInput("2", "value1", 1001L) + assertTrue(testOutput.isEmpty) + } + { + // move event time riiiight before grace period ends + testInput.pipeInput("2", "value2", 1999L) + assertTrue(testOutput.isEmpty) + } + { + // publish a late event before grace period terminates => count==3 + testInput.pipeInput("1", "value3", 999L) + assertTrue(testOutput.isEmpty) + } + { + // move event time right past the grace period of the first window. 
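+ // the [0, 1000) window plus its 1-second grace period closes at stream time 2000, so this record at 2001 releases the suppressed final count of 3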
+ testInput.pipeInput("2", "value3", 2001L) + val record = testOutput.readKeyValue + assertEquals("0:1000:1", record.key) + assertEquals(3L, record.value) + } + assertTrue(testOutput.isEmpty) + + testDriver.close() + } + + @Test + def testCorrectlySuppressResultsUsingSuppressedUntilWindowClosesBySession(): Unit = { + val builder = new StreamsBuilder() + val sourceTopic = "source" + val sinkTopic = "sink" + // Very similar to SuppressScenarioTest.shouldSupportFinalResultsForSessionWindows + val window = SessionWindows.ofInactivityGapAndGrace(Duration.ofMillis(5L), Duration.ofMillis(10L)) + val suppression = JSuppressed.untilWindowCloses(BufferConfig.unbounded()) + + val table: KTable[Windowed[String], Long] = builder + .stream[String, String](sourceTopic) + .groupByKey + .windowedBy(window) + .count() + .suppress(suppression) + + table.toStream((k, _) => s"${k.window().start()}:${k.window().end()}:${k.key()}").to(sinkTopic) + + val testDriver = createTestDriver(builder) + val testInput = testDriver.createInput[String, String](sourceTopic) + val testOutput = testDriver.createOutput[String, Long](sinkTopic) + + { + // first window + testInput.pipeInput("k1", "v1", 0L) + assertTrue(testOutput.isEmpty) + } + { + // first window + testInput.pipeInput("k1", "v1", 1L) + assertTrue(testOutput.isEmpty) + } + { + // new window, but grace period hasn't ended for first window + testInput.pipeInput("k1", "v1", 8L) + assertTrue(testOutput.isEmpty) + } + { + // out-of-order event for first window, included since grace period hasn't passed + testInput.pipeInput("k1", "v1", 2L) + assertTrue(testOutput.isEmpty) + } + { + // add to second window + testInput.pipeInput("k1", "v1", 13L) + assertTrue(testOutput.isEmpty) + } + { + // add out-of-order to second window + testInput.pipeInput("k1", "v1", 10L) + assertTrue(testOutput.isEmpty) + } + { + // push stream time forward to flush other events through + testInput.pipeInput("k1", "v1", 30L) + // late event should get dropped from the stream + testInput.pipeInput("k1", "v1", 3L) + // should now have two results + val r1 = testOutput.readRecord + assertEquals("0:2:k1", r1.key) + assertEquals(3L, r1.value) + assertEquals(2L, r1.timestamp) + val r2 = testOutput.readRecord + assertEquals("8:13:k1", r2.key) + assertEquals(3L, r2.value) + assertEquals(13L, r2.timestamp) + } + assertTrue(testOutput.isEmpty) + + testDriver.close() + } + + @Test + def testCorrectlySuppressResultsUsingSuppressedUntilTimeLimitByNonWindowed(): Unit = { + val builder = new StreamsBuilder() + val sourceTopic = "source" + val sinkTopic = "sink" + val suppression = JSuppressed.untilTimeLimit[String](Duration.ofSeconds(2L), BufferConfig.unbounded()) + + val table: KTable[String, Long] = builder + .stream[String, String](sourceTopic) + .groupByKey + .count() + .suppress(suppression) + + table.toStream.to(sinkTopic) + + val testDriver = createTestDriver(builder) + val testInput = testDriver.createInput[String, String](sourceTopic) + val testOutput = testDriver.createOutput[String, Long](sinkTopic) + + { + // publish key=1 @ time 0 => count==1 + testInput.pipeInput("1", "value1", 0L) + assertTrue(testOutput.isEmpty) + } + { + // publish key=1 @ time 1 => count==2 + testInput.pipeInput("1", "value2", 1L) + assertTrue(testOutput.isEmpty) + } + { + // move event time past the window, but before the grace period + testInput.pipeInput("2", "value1", 1001L) + assertTrue(testOutput.isEmpty) + } + { + // move event time right before grace period ends + testInput.pipeInput("2", "value2", 1999L) +
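// at stream time 1999 the 2-second suppression limit has not yet elapsed, so the updated count is still held back +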
assertTrue(testOutput.isEmpty) + } + { + // publish a late event before grace period terminates => count==3 + testInput.pipeInput("1", "value3", 999L) + assertTrue(testOutput.isEmpty) + } + { + // move event time right past the grace period of the first window. + testInput.pipeInput("2", "value3", 2001L) + val record = testOutput.readKeyValue + assertEquals("1", record.key) + assertEquals(3L, record.value) + } + assertTrue(testOutput.isEmpty) + + testDriver.close() + } + + @Test + def testSettingNameOnFilterProcessor(): Unit = { + val builder = new StreamsBuilder() + val sourceTopic = "source" + val sinkTopic = "sink" + + val table = builder.stream[String, String](sourceTopic).groupBy((key, _) => key).count() + table + .filter((key, value) => key.equals("a") && value == 1, Named.as("my-name")) + .toStream + .to(sinkTopic) + + import scala.jdk.CollectionConverters._ + + val filterNode = builder.build().describe().subtopologies().asScala.toList(1).nodes().asScala.toList(3) + assertEquals("my-name", filterNode.name()) + } + + @Test + def testSettingNameOnCountProcessor(): Unit = { + val builder = new StreamsBuilder() + val sourceTopic = "source" + val sinkTopic = "sink" + + val table = builder.stream[String, String](sourceTopic).groupBy((key, _) => key).count(Named.as("my-name")) + table.toStream.to(sinkTopic) + + import scala.jdk.CollectionConverters._ + + val countNode = builder.build().describe().subtopologies().asScala.toList(1).nodes().asScala.toList(1) + assertEquals("my-name", countNode.name()) + } + + @Test + def testSettingNameOnJoinProcessor(): Unit = { + val builder = new StreamsBuilder() + val sourceTopic1 = "source1" + val sourceTopic2 = "source2" + val sinkTopic = "sink" + + val table1 = builder.stream[String, String](sourceTopic1).groupBy((key, _) => key).count() + val table2 = builder.stream[String, String](sourceTopic2).groupBy((key, _) => key).count() + table1 + .join(table2, Named.as("my-name"))((a, b) => a + b) + .toStream + .to(sinkTopic) + + val joinNodeLeft = builder.build().describe().subtopologies().asScala.toList(1).nodes().asScala.toList(6) + val joinNodeRight = builder.build().describe().subtopologies().asScala.toList(1).nodes().asScala.toList(7) + assertTrue(joinNodeLeft.name().contains("my-name")) + assertTrue(joinNodeRight.name().contains("my-name")) + } +} diff --git a/streams/streams-scala/src/test/scala/org/apache/kafka/streams/scala/kstream/MaterializedTest.scala b/streams/streams-scala/src/test/scala/org/apache/kafka/streams/scala/kstream/MaterializedTest.scala new file mode 100644 index 0000000..9e0c466 --- /dev/null +++ b/streams/streams-scala/src/test/scala/org/apache/kafka/streams/scala/kstream/MaterializedTest.scala @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.scala.kstream + +import org.apache.kafka.streams.kstream.internals.MaterializedInternal +import org.apache.kafka.streams.scala._ +import org.apache.kafka.streams.scala.serialization.Serdes +import org.apache.kafka.streams.scala.serialization.Serdes._ +import org.apache.kafka.streams.state.Stores +import org.junit.jupiter.api.Assertions.assertEquals +import org.junit.jupiter.api.Test + +import java.time.Duration + +class MaterializedTest { + + @Test + def testCreateMaterializedWithSerdes(): Unit = { + val materialized: Materialized[String, Long, ByteArrayKeyValueStore] = + Materialized.`with`[String, Long, ByteArrayKeyValueStore] + + val internalMaterialized = new MaterializedInternal(materialized) + assertEquals(Serdes.stringSerde.getClass, internalMaterialized.keySerde.getClass) + assertEquals(Serdes.longSerde.getClass, internalMaterialized.valueSerde.getClass) + } + + @Test + def testCreateMaterializedWithSerdesAndStoreName(): Unit = { + val storeName = "store" + val materialized: Materialized[String, Long, ByteArrayKeyValueStore] = + Materialized.as[String, Long, ByteArrayKeyValueStore](storeName) + + val internalMaterialized = new MaterializedInternal(materialized) + assertEquals(Serdes.stringSerde.getClass, internalMaterialized.keySerde.getClass) + assertEquals(Serdes.longSerde.getClass, internalMaterialized.valueSerde.getClass) + assertEquals(storeName, internalMaterialized.storeName) + } + + @Test + def testCreateMaterializedWithSerdesAndWindowStoreSupplier(): Unit = { + val storeSupplier = Stores.persistentWindowStore("store", Duration.ofMillis(1), Duration.ofMillis(1), true) + val materialized: Materialized[String, Long, ByteArrayWindowStore] = + Materialized.as[String, Long](storeSupplier) + + val internalMaterialized = new MaterializedInternal(materialized) + assertEquals(Serdes.stringSerde.getClass, internalMaterialized.keySerde.getClass) + assertEquals(Serdes.longSerde.getClass, internalMaterialized.valueSerde.getClass) + assertEquals(storeSupplier, internalMaterialized.storeSupplier) + } + + @Test + def testCreateMaterializedWithSerdesAndKeyValueStoreSupplier(): Unit = { + val storeSupplier = Stores.persistentKeyValueStore("store") + val materialized: Materialized[String, Long, ByteArrayKeyValueStore] = + Materialized.as[String, Long](storeSupplier) + + val internalMaterialized = new MaterializedInternal(materialized) + assertEquals(Serdes.stringSerde.getClass, internalMaterialized.keySerde.getClass) + assertEquals(Serdes.longSerde.getClass, internalMaterialized.valueSerde.getClass) + assertEquals(storeSupplier, internalMaterialized.storeSupplier) + } + + @Test + def testCreateMaterializedWithSerdesAndSessionStoreSupplier(): Unit = { + val storeSupplier = Stores.persistentSessionStore("store", Duration.ofMillis(1)) + val materialized: Materialized[String, Long, ByteArraySessionStore] = + Materialized.as[String, Long](storeSupplier) + + val internalMaterialized = new MaterializedInternal(materialized) + assertEquals(Serdes.stringSerde.getClass, internalMaterialized.keySerde.getClass) + assertEquals(Serdes.longSerde.getClass, internalMaterialized.valueSerde.getClass) + assertEquals(storeSupplier, internalMaterialized.storeSupplier) + } +} diff --git a/streams/streams-scala/src/test/scala/org/apache/kafka/streams/scala/kstream/ProducedTest.scala b/streams/streams-scala/src/test/scala/org/apache/kafka/streams/scala/kstream/ProducedTest.scala new file mode 100644 index 0000000..69c4b17 --- /dev/null +++ 
b/streams/streams-scala/src/test/scala/org/apache/kafka/streams/scala/kstream/ProducedTest.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.scala.kstream + +import org.apache.kafka.streams.kstream.internals.ProducedInternal +import org.apache.kafka.streams.processor.StreamPartitioner +import org.apache.kafka.streams.scala.serialization.Serdes +import org.apache.kafka.streams.scala.serialization.Serdes._ +import org.junit.jupiter.api.Assertions.assertEquals +import org.junit.jupiter.api.Test + +class ProducedTest { + + @Test + def testCreateProducedWithSerdes(): Unit = { + val produced: Produced[String, Long] = Produced.`with`[String, Long] + + val internalProduced = new ProducedInternal(produced) + assertEquals(Serdes.stringSerde.getClass, internalProduced.keySerde.getClass) + assertEquals(Serdes.longSerde.getClass, internalProduced.valueSerde.getClass) + } + + @Test + def testCreateProducedWithSerdesAndStreamPartitioner(): Unit = { + val partitioner = new StreamPartitioner[String, Long] { + override def partition(topic: String, key: String, value: Long, numPartitions: Int): Integer = 0 + } + val produced: Produced[String, Long] = Produced.`with`(partitioner) + + val internalProduced = new ProducedInternal(produced) + assertEquals(Serdes.stringSerde.getClass, internalProduced.keySerde.getClass) + assertEquals(Serdes.longSerde.getClass, internalProduced.valueSerde.getClass) + assertEquals(partitioner, internalProduced.streamPartitioner) + } +} diff --git a/streams/streams-scala/src/test/scala/org/apache/kafka/streams/scala/kstream/RepartitionedTest.scala b/streams/streams-scala/src/test/scala/org/apache/kafka/streams/scala/kstream/RepartitionedTest.scala new file mode 100644 index 0000000..4c8d895 --- /dev/null +++ b/streams/streams-scala/src/test/scala/org/apache/kafka/streams/scala/kstream/RepartitionedTest.scala @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.scala.kstream + +import org.apache.kafka.streams.kstream.internals.RepartitionedInternal +import org.apache.kafka.streams.processor.StreamPartitioner +import org.apache.kafka.streams.scala.serialization.Serdes +import org.apache.kafka.streams.scala.serialization.Serdes._ +import org.junit.jupiter.api.Assertions.assertEquals +import org.junit.jupiter.api.Test + +class RepartitionedTest { + + @Test + def testCreateRepartitionedWithSerdes(): Unit = { + val repartitioned: Repartitioned[String, Long] = Repartitioned.`with`[String, Long] + + val internalRepartitioned = new RepartitionedInternal(repartitioned) + assertEquals(Serdes.stringSerde.getClass, internalRepartitioned.keySerde.getClass) + assertEquals(Serdes.longSerde.getClass, internalRepartitioned.valueSerde.getClass) + } + + @Test + def testCreateRepartitionedWithSerdesAndNumPartitions(): Unit = { + val repartitioned: Repartitioned[String, Long] = Repartitioned.`with`[String, Long](5) + + val internalRepartitioned = new RepartitionedInternal(repartitioned) + assertEquals(Serdes.stringSerde.getClass, internalRepartitioned.keySerde.getClass) + assertEquals(Serdes.longSerde.getClass, internalRepartitioned.valueSerde.getClass) + assertEquals(5, internalRepartitioned.numberOfPartitions) + + } + + @Test + def testCreateRepartitionedWithSerdesAndTopicName(): Unit = { + val repartitioned: Repartitioned[String, Long] = Repartitioned.`with`[String, Long]("repartitionTopic") + + val internalRepartitioned = new RepartitionedInternal(repartitioned) + assertEquals(Serdes.stringSerde.getClass, internalRepartitioned.keySerde.getClass) + assertEquals(Serdes.longSerde.getClass, internalRepartitioned.valueSerde.getClass) + assertEquals("repartitionTopic", internalRepartitioned.name) + } + + @Test + def testCreateRepartitionedWithSerdesAndTopicNameAndNumPartitionsAndStreamPartitioner(): Unit = { + val partitioner = new StreamPartitioner[String, Long] { + override def partition(topic: String, key: String, value: Long, numPartitions: Int): Integer = 0 + } + val repartitioned: Repartitioned[String, Long] = Repartitioned.`with`[String, Long](partitioner) + + val internalRepartitioned = new RepartitionedInternal(repartitioned) + assertEquals(Serdes.stringSerde.getClass, internalRepartitioned.keySerde.getClass) + assertEquals(Serdes.longSerde.getClass, internalRepartitioned.valueSerde.getClass) + assertEquals(partitioner, internalRepartitioned.streamPartitioner) + } + + @Test + def testCreateRepartitionedWithTopicNameAndNumPartitionsAndStreamPartitioner(): Unit = { + val partitioner = new StreamPartitioner[String, Long] { + override def partition(topic: String, key: String, value: Long, numPartitions: Int): Integer = 0 + } + val repartitioned: Repartitioned[String, Long] = + Repartitioned + .`with`[String, Long](5) + .withName("repartitionTopic") + .withStreamPartitioner(partitioner) + + val internalRepartitioned = new RepartitionedInternal(repartitioned) + assertEquals(Serdes.stringSerde.getClass, internalRepartitioned.keySerde.getClass) + assertEquals(Serdes.longSerde.getClass, internalRepartitioned.valueSerde.getClass) + assertEquals(5, internalRepartitioned.numberOfPartitions) + assertEquals("repartitionTopic", internalRepartitioned.name) + assertEquals(partitioner, internalRepartitioned.streamPartitioner) + } + +} diff --git a/streams/streams-scala/src/test/scala/org/apache/kafka/streams/scala/kstream/StreamJoinedTest.scala 
b/streams/streams-scala/src/test/scala/org/apache/kafka/streams/scala/kstream/StreamJoinedTest.scala new file mode 100644 index 0000000..0717d05 --- /dev/null +++ b/streams/streams-scala/src/test/scala/org/apache/kafka/streams/scala/kstream/StreamJoinedTest.scala @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.scala.kstream + +import org.apache.kafka.streams.kstream.internals.StreamJoinedInternal +import org.apache.kafka.streams.scala.serialization.Serdes +import org.apache.kafka.streams.scala.serialization.Serdes._ +import org.apache.kafka.streams.state.Stores +import org.junit.jupiter.api.Assertions.assertEquals +import org.junit.jupiter.api.Test + +import java.time.Duration + +class StreamJoinedTest { + + @Test + def testCreateStreamJoinedWithSerdes(): Unit = { + val streamJoined: StreamJoined[String, String, Long] = StreamJoined.`with`[String, String, Long] + + val streamJoinedInternal = new StreamJoinedInternal[String, String, Long](streamJoined) + assertEquals(Serdes.stringSerde.getClass, streamJoinedInternal.keySerde().getClass) + assertEquals(Serdes.stringSerde.getClass, streamJoinedInternal.valueSerde().getClass) + assertEquals(Serdes.longSerde.getClass, streamJoinedInternal.otherValueSerde().getClass) + } + + @Test + def testCreateStreamJoinedWithSerdesAndStoreSuppliers(): Unit = { + val storeSupplier = Stores.inMemoryWindowStore("myStore", Duration.ofMillis(500), Duration.ofMillis(250), false) + + val otherStoreSupplier = + Stores.inMemoryWindowStore("otherStore", Duration.ofMillis(500), Duration.ofMillis(250), false) + + val streamJoined: StreamJoined[String, String, Long] = + StreamJoined.`with`[String, String, Long](storeSupplier, otherStoreSupplier) + + val streamJoinedInternal = new StreamJoinedInternal[String, String, Long](streamJoined) + assertEquals(Serdes.stringSerde.getClass, streamJoinedInternal.keySerde().getClass) + assertEquals(Serdes.stringSerde.getClass, streamJoinedInternal.valueSerde().getClass) + assertEquals(Serdes.longSerde.getClass, streamJoinedInternal.otherValueSerde().getClass) + assertEquals(otherStoreSupplier, streamJoinedInternal.otherStoreSupplier()) + assertEquals(storeSupplier, streamJoinedInternal.thisStoreSupplier()) + } + + @Test + def testCreateStreamJoinedWithSerdesAndStateStoreName(): Unit = { + val streamJoined: StreamJoined[String, String, Long] = StreamJoined.as[String, String, Long]("myStoreName") + + val streamJoinedInternal = new StreamJoinedInternal[String, String, Long](streamJoined) + assertEquals(Serdes.stringSerde.getClass, streamJoinedInternal.keySerde().getClass) + assertEquals(Serdes.stringSerde.getClass, streamJoinedInternal.valueSerde().getClass) + assertEquals(Serdes.longSerde.getClass, 
streamJoinedInternal.otherValueSerde().getClass) + assertEquals("myStoreName", streamJoinedInternal.storeName()) + } + +} diff --git a/streams/streams-scala/src/test/scala/org/apache/kafka/streams/scala/utils/StreamToTableJoinScalaIntegrationTestBase.scala b/streams/streams-scala/src/test/scala/org/apache/kafka/streams/scala/utils/StreamToTableJoinScalaIntegrationTestBase.scala new file mode 100644 index 0000000..984cb74 --- /dev/null +++ b/streams/streams-scala/src/test/scala/org/apache/kafka/streams/scala/utils/StreamToTableJoinScalaIntegrationTestBase.scala @@ -0,0 +1,145 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.scala.utils + +import java.util.Properties +import org.apache.kafka.clients.consumer.ConsumerConfig +import org.apache.kafka.clients.producer.ProducerConfig +import org.apache.kafka.common.serialization._ +import org.apache.kafka.common.utils.{MockTime, Utils} +import org.apache.kafka.streams._ +import org.apache.kafka.streams.integration.utils.{EmbeddedKafkaCluster, IntegrationTestUtils} +import org.apache.kafka.test.TestUtils +import org.junit.jupiter.api._ + +import java.io.File + +/** + * Test suite base that prepares Kafka cluster for stream-table joins in Kafka Streams + *

                + */ +@Tag("integration") +class StreamToTableJoinScalaIntegrationTestBase extends StreamToTableJoinTestData { + + private val cluster: EmbeddedKafkaCluster = new EmbeddedKafkaCluster(1) + + final private val alignedTime = (System.currentTimeMillis() / 1000 + 1) * 1000 + private val mockTime: MockTime = cluster.time + mockTime.setCurrentTimeMs(alignedTime) + + private val testFolder: File = TestUtils.tempDirectory() + + @BeforeEach + def startKafkaCluster(): Unit = { + cluster.start() + cluster.createTopic(userClicksTopic) + cluster.createTopic(userRegionsTopic) + cluster.createTopic(outputTopic) + cluster.createTopic(userClicksTopicJ) + cluster.createTopic(userRegionsTopicJ) + cluster.createTopic(outputTopicJ) + } + + @AfterEach + def stopKafkaCluster(): Unit = { + cluster.stop() + Utils.delete(testFolder) + } + + def getStreamsConfiguration(): Properties = { + val streamsConfiguration: Properties = new Properties() + + streamsConfiguration.put(StreamsConfig.APPLICATION_ID_CONFIG, "stream-table-join-scala-integration-test") + streamsConfiguration.put(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, cluster.bootstrapServers()) + streamsConfiguration.put(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG, "1000") + streamsConfiguration.put(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "earliest") + streamsConfiguration.put(StreamsConfig.STATE_DIR_CONFIG, testFolder.getPath) + + streamsConfiguration + } + + private def getUserRegionsProducerConfig(): Properties = { + val p = new Properties() + p.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, cluster.bootstrapServers()) + p.put(ProducerConfig.ACKS_CONFIG, "all") + p.put(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG, classOf[StringSerializer]) + p.put(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG, classOf[StringSerializer]) + p + } + + private def getUserClicksProducerConfig(): Properties = { + val p = new Properties() + p.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, cluster.bootstrapServers()) + p.put(ProducerConfig.ACKS_CONFIG, "all") + p.put(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG, classOf[StringSerializer]) + p.put(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG, classOf[LongSerializer]) + p + } + + private def getConsumerConfig(): Properties = { + val p = new Properties() + p.put(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, cluster.bootstrapServers()) + p.put(ConsumerConfig.GROUP_ID_CONFIG, "join-scala-integration-test-standard-consumer") + p.put(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "earliest") + p.put(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, classOf[StringDeserializer]) + p.put(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, classOf[LongDeserializer]) + p + } + + def produceNConsume( + userClicksTopic: String, + userRegionsTopic: String, + outputTopic: String, + waitTillRecordsReceived: Boolean = true + ): java.util.List[KeyValue[String, Long]] = { + + import _root_.scala.jdk.CollectionConverters._ + + // Publish user-region information. + val userRegionsProducerConfig: Properties = getUserRegionsProducerConfig() + IntegrationTestUtils.produceKeyValuesSynchronously( + userRegionsTopic, + userRegions.asJava, + userRegionsProducerConfig, + mockTime, + false + ) + + // Publish user-click information. 
+ val userClicksProducerConfig: Properties = getUserClicksProducerConfig() + IntegrationTestUtils.produceKeyValuesSynchronously( + userClicksTopic, + userClicks.asJava, + userClicksProducerConfig, + mockTime, + false + ) + + if (waitTillRecordsReceived) { + // consume and verify result + val consumerConfig = getConsumerConfig() + + IntegrationTestUtils.waitUntilFinalKeyValueRecordsReceived( + consumerConfig, + outputTopic, + expectedClicksPerRegion.asJava + ) + } else { + java.util.Collections.emptyList() + } + } +} diff --git a/streams/streams-scala/src/test/scala/org/apache/kafka/streams/scala/utils/StreamToTableJoinTestData.scala b/streams/streams-scala/src/test/scala/org/apache/kafka/streams/scala/utils/StreamToTableJoinTestData.scala new file mode 100644 index 0000000..29d0695 --- /dev/null +++ b/streams/streams-scala/src/test/scala/org/apache/kafka/streams/scala/utils/StreamToTableJoinTestData.scala @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.scala.utils + +import org.apache.kafka.streams.KeyValue + +trait StreamToTableJoinTestData { + val brokers = "localhost:9092" + + val userClicksTopic = s"user-clicks" + val userRegionsTopic = s"user-regions" + val outputTopic = s"output-topic" + + val userClicksTopicJ = s"user-clicks-j" + val userRegionsTopicJ = s"user-regions-j" + val outputTopicJ = s"output-topic-j" + + // Input 1: Clicks per user (multiple records allowed per user). + val userClicks: Seq[KeyValue[String, Long]] = Seq( + new KeyValue("alice", 13L), + new KeyValue("bob", 4L), + new KeyValue("chao", 25L), + new KeyValue("bob", 19L), + new KeyValue("dave", 56L), + new KeyValue("eve", 78L), + new KeyValue("alice", 40L), + new KeyValue("fang", 99L) + ) + + // Input 2: Region per user (multiple records allowed per user). + val userRegions: Seq[KeyValue[String, String]] = Seq( + new KeyValue("alice", "asia"), /* Alice lived in Asia originally... */ + new KeyValue("bob", "americas"), + new KeyValue("chao", "asia"), + new KeyValue("dave", "europe"), + new KeyValue("alice", "europe"), /* ...but moved to Europe some time later. 
*/ + new KeyValue("eve", "americas"), + new KeyValue("fang", "asia") + ) + + val expectedClicksPerRegion: Seq[KeyValue[String, Long]] = Seq( + new KeyValue("americas", 101L), + new KeyValue("europe", 109L), + new KeyValue("asia", 124L) + ) +} diff --git a/streams/streams-scala/src/test/scala/org/apache/kafka/streams/scala/utils/TestDriver.scala b/streams/streams-scala/src/test/scala/org/apache/kafka/streams/scala/utils/TestDriver.scala new file mode 100644 index 0000000..23a2417 --- /dev/null +++ b/streams/streams-scala/src/test/scala/org/apache/kafka/streams/scala/utils/TestDriver.scala @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.scala.utils + +import java.time.Instant +import java.util.Properties + +import org.apache.kafka.common.serialization.Serde +import org.apache.kafka.streams.scala.StreamsBuilder +import org.apache.kafka.streams.{StreamsConfig, TestInputTopic, TestOutputTopic, TopologyTestDriver} +import org.apache.kafka.test.TestUtils + +trait TestDriver { + def createTestDriver(builder: StreamsBuilder, initialWallClockTime: Instant = Instant.now()): TopologyTestDriver = { + val config = new Properties() + config.put(StreamsConfig.STATE_DIR_CONFIG, TestUtils.tempDirectory().getPath) + new TopologyTestDriver(builder.build(), config, initialWallClockTime) + } + + implicit class TopologyTestDriverOps(inner: TopologyTestDriver) { + def createInput[K, V](topic: String)(implicit serdeKey: Serde[K], serdeValue: Serde[V]): TestInputTopic[K, V] = + inner.createInputTopic(topic, serdeKey.serializer, serdeValue.serializer) + + def createOutput[K, V](topic: String)(implicit serdeKey: Serde[K], serdeValue: Serde[V]): TestOutputTopic[K, V] = + inner.createOutputTopic(topic, serdeKey.deserializer, serdeValue.deserializer) + } +} diff --git a/streams/test-utils/src/main/java/org/apache/kafka/streams/TestInputTopic.java b/streams/test-utils/src/main/java/org/apache/kafka/streams/TestInputTopic.java new file mode 100644 index 0000000..c5966a7 --- /dev/null +++ b/streams/test-utils/src/main/java/org/apache/kafka/streams/TestInputTopic.java @@ -0,0 +1,263 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams; + +import org.apache.kafka.common.serialization.Serializer; +import org.apache.kafka.streams.test.TestRecord; + +import java.time.Duration; +import java.time.Instant; +import java.util.List; +import java.util.Objects; +import java.util.StringJoiner; + +/** + * {@code TestInputTopic} is used to pipe records to topic in {@link TopologyTestDriver}. + * To use {@code TestInputTopic} create a new instance via + * {@link TopologyTestDriver#createInputTopic(String, Serializer, Serializer)}. + * In actual test code, you can pipe new record values, keys and values or list of {@link KeyValue} pairs. + * If you have multiple source topics, you need to create a {@code TestInputTopic} for each. + * + *

+ * <h2>Processing messages</h2>
+ * <pre>{@code
+ *     private TestInputTopic<Long, String> inputTopic;
+ *     ...
+ *     inputTopic = testDriver.createInputTopic(INPUT_TOPIC, longSerializer, stringSerializer);
+ *     ...
+ *     inputTopic.pipeInput("Hello");
+ * }</pre>
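+ * As a further illustration (reusing the {@code inputTopic} handle from the snippet above, with
+ * assumed example values), a list of values can be piped with a generated, auto-advancing timestamp:
+ * <pre>{@code
+ *     inputTopic.pipeValueList(
+ *         Arrays.asList("Hello", "World"),
+ *         Instant.ofEpochMilli(0L),
+ *         Duration.ofMillis(100));
+ * }</pre>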
                + * + * @param the type of the record key + * @param the type of the record value + * @see TopologyTestDriver + */ + +public class TestInputTopic { + private final TopologyTestDriver driver; + private final String topic; + private final Serializer keySerializer; + private final Serializer valueSerializer; + + //Timing + private Instant currentTime; + private final Duration advanceDuration; + + TestInputTopic(final TopologyTestDriver driver, + final String topicName, + final Serializer keySerializer, + final Serializer valueSerializer, + final Instant startTimestamp, + final Duration autoAdvance) { + Objects.requireNonNull(driver, "TopologyTestDriver cannot be null"); + Objects.requireNonNull(topicName, "topicName cannot be null"); + Objects.requireNonNull(keySerializer, "keySerializer cannot be null"); + Objects.requireNonNull(valueSerializer, "valueSerializer cannot be null"); + Objects.requireNonNull(startTimestamp, "startTimestamp cannot be null"); + Objects.requireNonNull(autoAdvance, "autoAdvance cannot be null"); + this.driver = driver; + this.topic = topicName; + this.keySerializer = keySerializer; + this.valueSerializer = valueSerializer; + this.currentTime = startTimestamp; + if (autoAdvance.isNegative()) { + throw new IllegalArgumentException("autoAdvance must be positive"); + } + this.advanceDuration = autoAdvance; + } + + /** + * Advances the internally tracked event time of this input topic. + * Each time a record without explicitly defined timestamp is piped, + * the current topic event time is used as record timestamp. + *

                + * Note: advancing the event time on the input topic, does not advance the tracked stream time in + * {@link TopologyTestDriver} as long as no new input records are piped. + * Furthermore, it does not advance the wall-clock time of {@link TopologyTestDriver}. + * + * @param advance the duration of time to advance + */ + public void advanceTime(final Duration advance) { + if (advance.isNegative()) { + throw new IllegalArgumentException("advance must be positive"); + } + currentTime = currentTime.plus(advance); + } + + private Instant getTimestampAndAdvance() { + final Instant timestamp = currentTime; + currentTime = currentTime.plus(advanceDuration); + return timestamp; + } + + /** + * Send an input record with the given record on the topic and then commit the records. + * May auto advance topic time. + * + * @param record the record to sent + */ + public void pipeInput(final TestRecord record) { + //if record timestamp not set get timestamp and advance + final Instant timestamp = (record.getRecordTime() == null) ? getTimestampAndAdvance() : record.getRecordTime(); + driver.pipeRecord(topic, record, keySerializer, valueSerializer, timestamp); + } + + /** + * Send an input record with the given value on the topic and then commit the records. + * May auto advance topic time. + * + * @param value the record value + */ + public void pipeInput(final V value) { + pipeInput(new TestRecord<>(value)); + } + + /** + * Send an input record with the given key and value on the topic and then commit the records. + * May auto advance topic time + * + * @param key the record key + * @param value the record value + */ + public void pipeInput(final K key, + final V value) { + pipeInput(new TestRecord<>(key, value)); + } + + /** + * Send an input record with the given value and timestamp on the topic and then commit the records. + * Does not auto advance internally tracked time. + * + * @param value the record value + * @param timestamp the record timestamp + */ + public void pipeInput(final V value, + final Instant timestamp) { + pipeInput(new TestRecord<>(null, value, timestamp)); + } + + /** + * Send an input record with the given key, value and timestamp on the topic and then commit the records. + * Does not auto advance internally tracked time. + * + * @param key the record key + * @param value the record value + * @param timestampMs the record timestamp + */ + public void pipeInput(final K key, + final V value, + final long timestampMs) { + pipeInput(new TestRecord<>(key, value, null, timestampMs)); + } + + /** + * Send an input record with the given key, value and timestamp on the topic and then commit the records. + * Does not auto advance internally tracked time. + * + * @param key the record key + * @param value the record value + * @param timestamp the record timestamp + */ + public void pipeInput(final K key, + final V value, + final Instant timestamp) { + pipeInput(new TestRecord<>(key, value, timestamp)); + } + + /** + * Send input records with the given KeyValue list on the topic then commit each record individually. + * The timestamp will be generated based on the constructor provided start time and time will auto advance. + * + * @param records the list of TestRecord records + */ + public void pipeRecordList(final List> records) { + for (final TestRecord record : records) { + pipeInput(record); + } + } + + /** + * Send input records with the given KeyValue list on the topic then commit each record individually. 
+ * The timestamp will be generated based on the constructor provided start time and time will auto advance based on + * {@link #TestInputTopic(TopologyTestDriver, String, Serializer, Serializer, Instant, Duration) autoAdvance} setting. + * + * @param keyValues the {@link List} of {@link KeyValue} records + */ + public void pipeKeyValueList(final List> keyValues) { + for (final KeyValue keyValue : keyValues) { + pipeInput(keyValue.key, keyValue.value); + } + } + + /** + * Send input records with the given value list on the topic then commit each record individually. + * The timestamp will be generated based on the constructor provided start time and time will auto advance based on + * {@link #TestInputTopic(TopologyTestDriver, String, Serializer, Serializer, Instant, Duration) autoAdvance} setting. + * + * @param values the {@link List} of {@link KeyValue} records + */ + public void pipeValueList(final List values) { + for (final V value : values) { + pipeInput(value); + } + } + + /** + * Send input records with the given {@link KeyValue} list on the topic then commit each record individually. + * Does not auto advance internally tracked time. + * + * @param keyValues the {@link List} of {@link KeyValue} records + * @param startTimestamp the timestamp for the first generated record + * @param advance the time difference between two consecutive generated records + */ + public void pipeKeyValueList(final List> keyValues, + final Instant startTimestamp, + final Duration advance) { + Instant recordTime = startTimestamp; + for (final KeyValue keyValue : keyValues) { + pipeInput(keyValue.key, keyValue.value, recordTime); + recordTime = recordTime.plus(advance); + } + } + + /** + * Send input records with the given value list on the topic then commit each record individually. + * The timestamp will be generated based on the constructor provided start time and time will auto advance based on + * {@link #TestInputTopic(TopologyTestDriver, String, Serializer, Serializer, Instant, Duration) autoAdvance} setting. + * + * @param values the {@link List} of values + * @param startTimestamp the timestamp for the first generated record + * @param advance the time difference between two consecutive generated records + */ + public void pipeValueList(final List values, + final Instant startTimestamp, + final Duration advance) { + Instant recordTime = startTimestamp; + for (final V value : values) { + pipeInput(value, recordTime); + recordTime = recordTime.plus(advance); + } + } + + @Override + public String toString() { + return new StringJoiner(", ", TestInputTopic.class.getSimpleName() + "[", "]") + .add("topic='" + topic + "'") + .add("keySerializer=" + keySerializer.getClass().getSimpleName()) + .add("valueSerializer=" + valueSerializer.getClass().getSimpleName()) + .toString(); + } +} diff --git a/streams/test-utils/src/main/java/org/apache/kafka/streams/TestOutputTopic.java b/streams/test-utils/src/main/java/org/apache/kafka/streams/TestOutputTopic.java new file mode 100644 index 0000000..4296ce8 --- /dev/null +++ b/streams/test-utils/src/main/java/org/apache/kafka/streams/TestOutputTopic.java @@ -0,0 +1,201 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams; + +import org.apache.kafka.common.serialization.Deserializer; +import org.apache.kafka.streams.test.TestRecord; + +import java.util.HashMap; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.StringJoiner; + +/** + * {@code TestOutputTopic} is used to read records from a topic in {@link TopologyTestDriver}. + * To use {@code TestOutputTopic} create a new instance via + * {@link TopologyTestDriver#createOutputTopic(String, Deserializer, Deserializer)}. + * In actual test code, you can read record values, keys, {@link KeyValue} or {@link TestRecord} + * If you have multiple source topics, you need to create a {@code TestOutputTopic} for each. + *

+ * If you need to test key, value and headers, use {@link #readRecord()} methods.
+ * Using {@link #readKeyValue()} you get a {@link KeyValue} pair, and thus, don't get access to the record's
+ * timestamp or headers.
+ * Similarly using {@link #readValue()} you only get the value of a record.
+ *
+ * <h2>Processing records</h2>
+ * <pre>{@code
+ *     private TestOutputTopic<String, Long> outputTopic;
+ *     ...
+ *     outputTopic = testDriver.createOutputTopic(OUTPUT_TOPIC, stringDeserializer, longDeserializer);
+ *     ...
+ *     assertThat(outputTopic.readValue()).isEqual(1);
+ * }</pre>
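+ * As a further illustration (reusing the {@code outputTopic} handle from the snippet above), the final
+ * table state can be read into a map when only the last update per key matters:
+ * <pre>{@code
+ *     Map<String, Long> finalState = outputTopic.readKeyValuesToMap();
+ * }</pre>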
                + * + * @param the type of the record key + * @param the type of the record value + * @see TopologyTestDriver + */ +public class TestOutputTopic { + private final TopologyTestDriver driver; + private final String topic; + private final Deserializer keyDeserializer; + private final Deserializer valueDeserializer; + + TestOutputTopic(final TopologyTestDriver driver, + final String topicName, + final Deserializer keyDeserializer, + final Deserializer valueDeserializer) { + Objects.requireNonNull(driver, "TopologyTestDriver cannot be null"); + Objects.requireNonNull(topicName, "topicName cannot be null"); + Objects.requireNonNull(keyDeserializer, "keyDeserializer cannot be null"); + Objects.requireNonNull(valueDeserializer, "valueDeserializer cannot be null"); + this.driver = driver; + this.topic = topicName; + this.keyDeserializer = keyDeserializer; + this.valueDeserializer = valueDeserializer; + } + + /** + * Read one record from the output topic and return record's value. + * + * @return Next value for output topic. + */ + public V readValue() { + final TestRecord record = readRecord(); + return record.value(); + } + + /** + * Read one record from the output topic and return its key and value as pair. + * + * @return Next output as {@link KeyValue}. + */ + public KeyValue readKeyValue() { + final TestRecord record = readRecord(); + return new KeyValue<>(record.key(), record.value()); + } + + /** + * Read one Record from output topic. + * + * @return Next output as {@link TestRecord}. + */ + public TestRecord readRecord() { + return driver.readRecord(topic, keyDeserializer, valueDeserializer); + } + + /** + * Read output to List. + * This method can be used if the result is considered a stream. + * If the result is considered a table, the list will contain all updated, ie, a key might be contained multiple times. + * If you are only interested in the last table update (ie, the final table state), + * you can use {@link #readKeyValuesToMap()} instead. + * + * @return List of output. + */ + public List> readRecordsToList() { + final List> output = new LinkedList<>(); + while (!isEmpty()) { + output.add(readRecord()); + } + return output; + } + + + /** + * Read output to map. + * This method can be used if the result is considered a table, + * when you are only interested in the last table update (ie, the final table state). + * If the result is considered a stream, you can use {@link #readRecordsToList()} instead. + * The list will contain all updated, ie, a key might be contained multiple times. + * If the last update to a key is a delete/tombstone, the key will still be in the map (with null-value). + * + * @return Map of output by key. + */ + public Map readKeyValuesToMap() { + final Map output = new HashMap<>(); + TestRecord outputRow; + while (!isEmpty()) { + outputRow = readRecord(); + if (outputRow.key() == null) { + throw new IllegalStateException("Null keys not allowed with readKeyValuesToMap method"); + } + output.put(outputRow.key(), outputRow.value()); + } + return output; + } + + /** + * Read all KeyValues from topic to List. + * + * @return List of output KeyValues. + */ + public List> readKeyValuesToList() { + final List> output = new LinkedList<>(); + KeyValue outputRow; + while (!isEmpty()) { + outputRow = readKeyValue(); + output.add(outputRow); + } + return output; + } + + /** + * Read all values from topic to List. + * + * @return List of output values. 
+ */ + public List readValuesToList() { + final List output = new LinkedList<>(); + V outputValue; + while (!isEmpty()) { + outputValue = readValue(); + output.add(outputValue); + } + return output; + } + + /** + * Get size of unread record in the topic queue. + * + * @return size of topic queue. + */ + public final long getQueueSize() { + return driver.getQueueSize(topic); + } + + /** + * Verify if the topic queue is empty. + * + * @return {@code true} if no more record in the topic queue. + */ + public final boolean isEmpty() { + return driver.isEmpty(topic); + } + + @Override + public String toString() { + return new StringJoiner(", ", TestOutputTopic.class.getSimpleName() + "[", "]") + .add("topic='" + topic + "'") + .add("keyDeserializer=" + keyDeserializer.getClass().getSimpleName()) + .add("valueDeserializer=" + valueDeserializer.getClass().getSimpleName()) + .add("size=" + getQueueSize()) + .toString(); + } +} diff --git a/streams/test-utils/src/main/java/org/apache/kafka/streams/TopologyTestDriver.java b/streams/test-utils/src/main/java/org/apache/kafka/streams/TopologyTestDriver.java new file mode 100644 index 0000000..05f10e9 --- /dev/null +++ b/streams/test-utils/src/main/java/org/apache/kafka/streams/TopologyTestDriver.java @@ -0,0 +1,1349 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams; + +import org.apache.kafka.clients.consumer.Consumer; +import org.apache.kafka.clients.consumer.ConsumerGroupMetadata; +import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.clients.consumer.MockConsumer; +import org.apache.kafka.clients.consumer.OffsetAndMetadata; +import org.apache.kafka.clients.consumer.OffsetResetStrategy; +import org.apache.kafka.clients.producer.MockProducer; +import org.apache.kafka.clients.producer.Producer; +import org.apache.kafka.clients.producer.ProducerRecord; +import org.apache.kafka.common.Metric; +import org.apache.kafka.common.MetricName; +import org.apache.kafka.common.PartitionInfo; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.errors.ProducerFencedException; +import org.apache.kafka.common.header.Headers; +import org.apache.kafka.common.metrics.MetricConfig; +import org.apache.kafka.common.metrics.Metrics; +import org.apache.kafka.common.metrics.Sensor; +import org.apache.kafka.common.record.TimestampType; +import org.apache.kafka.common.serialization.ByteArraySerializer; +import org.apache.kafka.common.serialization.Deserializer; +import org.apache.kafka.common.serialization.Serializer; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.streams.errors.LogAndContinueExceptionHandler; +import org.apache.kafka.streams.errors.TopologyException; +import org.apache.kafka.streams.kstream.Windowed; +import org.apache.kafka.streams.processor.ProcessorContext; +import org.apache.kafka.streams.processor.PunctuationType; +import org.apache.kafka.streams.processor.Punctuator; +import org.apache.kafka.streams.processor.StateRestoreListener; +import org.apache.kafka.streams.processor.StateStore; +import org.apache.kafka.streams.processor.StateStoreContext; +import org.apache.kafka.streams.processor.TaskId; +import org.apache.kafka.streams.processor.internals.ChangelogRegister; +import org.apache.kafka.streams.processor.internals.ClientUtils; +import org.apache.kafka.streams.processor.internals.GlobalProcessorContextImpl; +import org.apache.kafka.streams.processor.internals.GlobalStateManager; +import org.apache.kafka.streams.processor.internals.GlobalStateManagerImpl; +import org.apache.kafka.streams.processor.internals.GlobalStateUpdateTask; +import org.apache.kafka.streams.processor.internals.InternalProcessorContext; +import org.apache.kafka.streams.processor.internals.InternalTopologyBuilder; +import org.apache.kafka.streams.processor.internals.ProcessorContextImpl; +import org.apache.kafka.streams.processor.internals.ProcessorStateManager; +import org.apache.kafka.streams.processor.internals.ProcessorTopology; +import org.apache.kafka.streams.processor.internals.RecordCollector; +import org.apache.kafka.streams.processor.internals.RecordCollectorImpl; +import org.apache.kafka.streams.processor.internals.StateDirectory; +import org.apache.kafka.streams.processor.internals.StreamTask; +import org.apache.kafka.streams.processor.internals.StreamThread; +import org.apache.kafka.streams.processor.internals.StreamsProducer; +import org.apache.kafka.streams.processor.internals.Task; +import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl; +import org.apache.kafka.streams.processor.internals.metrics.TaskMetrics; +import org.apache.kafka.streams.state.KeyValueIterator; +import org.apache.kafka.streams.state.KeyValueStore; +import 
org.apache.kafka.streams.state.ReadOnlyKeyValueStore; +import org.apache.kafka.streams.state.ReadOnlySessionStore; +import org.apache.kafka.streams.state.ReadOnlyWindowStore; +import org.apache.kafka.streams.state.SessionStore; +import org.apache.kafka.streams.state.TimestampedKeyValueStore; +import org.apache.kafka.streams.state.TimestampedWindowStore; +import org.apache.kafka.streams.state.ValueAndTimestamp; +import org.apache.kafka.streams.state.WindowStore; +import org.apache.kafka.streams.state.WindowStoreIterator; +import org.apache.kafka.streams.state.internals.ReadOnlyKeyValueStoreFacade; +import org.apache.kafka.streams.state.internals.ReadOnlyWindowStoreFacade; +import org.apache.kafka.streams.state.internals.ThreadCache; +import org.apache.kafka.streams.test.TestRecord; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.Closeable; +import java.io.IOException; +import java.time.Duration; +import java.time.Instant; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.NoSuchElementException; +import java.util.Objects; +import java.util.Optional; +import java.util.Properties; +import java.util.Queue; +import java.util.Set; +import java.util.UUID; +import java.util.concurrent.ThreadLocalRandom; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicLong; +import java.util.function.Supplier; +import java.util.regex.Pattern; + +import static org.apache.kafka.streams.processor.internals.StreamThread.ProcessingMode.AT_LEAST_ONCE; +import static org.apache.kafka.streams.processor.internals.StreamThread.ProcessingMode.EXACTLY_ONCE_ALPHA; +import static org.apache.kafka.streams.processor.internals.StreamThread.ProcessingMode.EXACTLY_ONCE_V2; +import static org.apache.kafka.streams.state.ValueAndTimestamp.getValueOrNull; + +/** + * This class makes it easier to write tests to verify the behavior of topologies created with {@link Topology} or + * {@link StreamsBuilder}. + * You can test simple topologies that have a single processor, or very complex topologies that have multiple sources, + * processors, sinks, or sub-topologies. + * Best of all, the class works without a real Kafka broker, so the tests execute very quickly with very little overhead. + *

+ * Using the {@code TopologyTestDriver} in tests is easy: simply instantiate the driver and provide a {@link Topology} + * (cf. {@link StreamsBuilder#build()}) and {@link Properties config}, {@link #createInputTopic(String, Serializer, Serializer) create} + * and use a {@link TestInputTopic} to supply input records to the topology, + * and then {@link #createOutputTopic(String, Deserializer, Deserializer) create} and use a {@link TestOutputTopic} to read and + * verify any output records produced by the topology. + *

+ * Although the driver doesn't use a real Kafka broker, it does simulate Kafka {@link Consumer consumers} and + * {@link Producer producers} that read and write raw {@code byte[]} messages. + * You can let {@link TestInputTopic} and {@link TestOutputTopic} handle the conversion + * from regular Java objects to raw bytes. + * + *

+ * <h2>Driver setup</h2>

                + * In order to create a {@code TopologyTestDriver} instance, you need a {@link Topology} and a {@link Properties config}. + * The configuration needs to be representative of what you'd supply to the real topology, so that means including + * several key properties (cf. {@link StreamsConfig}). + * For example, the following code fragment creates a configuration that specifies a timestamp extractor, + * default serializers and deserializers for string keys and values: + * + *
+ * <pre>{@code
+ * Properties props = new Properties();
+ * props.setProperty(StreamsConfig.DEFAULT_TIMESTAMP_EXTRACTOR_CLASS_CONFIG, CustomTimestampExtractor.class.getName());
+ * props.setProperty(StreamsConfig.DEFAULT_KEY_SERDE_CLASS_CONFIG, Serdes.String().getClass().getName());
+ * props.setProperty(StreamsConfig.DEFAULT_VALUE_SERDE_CLASS_CONFIG, Serdes.String().getClass().getName());
+ * Topology topology = ...
+ * TopologyTestDriver driver = new TopologyTestDriver(topology, props);
+ * }</pre>
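+ * Because the driver implements {@link Closeable}, one convenient pattern (reusing the {@code topology}
+ * and {@code props} from the sketch above) is a try-with-resources block, which guarantees {@link #close()}
+ * is called after the test:
+ * <pre>{@code
+ * try (TopologyTestDriver driver = new TopologyTestDriver(topology, props)) {
+ *     // create TestInputTopic / TestOutputTopic instances and pipe records here
+ * }
+ * }</pre>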
                + * + *

+ * Note that the {@code TopologyTestDriver} processes input records synchronously. + * This implies that the {@link StreamsConfig#COMMIT_INTERVAL_MS_CONFIG commit.interval.ms} and + * {@link StreamsConfig#CACHE_MAX_BYTES_BUFFERING_CONFIG cache.max.bytes.buffering} configurations have no effect. + * The driver behaves as if both configs were set to zero, i.e., as if a "commit" (and thus "flush") would happen + * after each input record. + * + *

+ * <h2>Processing messages</h2>

                + *

                + * Your test can supply new input records on any of the topics that the topology's sources consume. + * This test driver simulates single-partitioned input topics. + * Here's an example of an input message on the topic named {@code input-topic}: + * + *

+ * <pre>{@code
+ * TestInputTopic<String, String> inputTopic = driver.createInputTopic("input-topic", stringSerdeSerializer, stringSerializer);
+ * inputTopic.pipeInput("key1", "value1");
+ * }</pre>
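+ * An explicit event time can also be attached to a record (illustrative key, value, and timestamp,
+ * reusing the {@code inputTopic} handle above):
+ * <pre>{@code
+ * inputTopic.pipeInput("key2", "value2", Instant.ofEpochMilli(42L));
+ * }</pre>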
                + * + * When {@link TestInputTopic#pipeInput(Object, Object)} is called, the driver passes the input message through to the appropriate source that + * consumes the named topic, and will invoke the processor(s) downstream of the source. + * If your topology's processors forward messages to sinks, your test can then consume these output messages to verify + * they match the expected outcome. + * For example, if our topology should have generated 2 messages on {@code output-topic-1} and 1 message on + * {@code output-topic-2}, then our test can obtain these messages using the + * {@link TestOutputTopic#readKeyValue()} method: + * + *
+ * <pre>{@code
+ * TestOutputTopic<String, String> outputTopic1 = driver.createOutputTopic("output-topic-1", stringDeserializer, stringDeserializer);
+ * TestOutputTopic<String, String> outputTopic2 = driver.createOutputTopic("output-topic-2", stringDeserializer, stringDeserializer);
+ *
+ * KeyValue<String, String> record1 = outputTopic1.readKeyValue();
+ * KeyValue<String, String> record2 = outputTopic2.readKeyValue();
+ * KeyValue<String, String> record3 = outputTopic1.readKeyValue();
+ * }</pre>
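+ * Once the expected records have been read, it can also be useful to check that no further output is
+ * queued for a topic (a sketch; {@code assertTrue} here stands for the test framework's assertion):
+ * <pre>{@code
+ * assertTrue(outputTopic1.isEmpty());
+ * }</pre>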
                + * + * Again, our example topology generates messages with string keys and values, so we supply our string deserializer + * instance for use on both the keys and values. Your test logic can then verify whether these output records are + * correct. + *

+ * Note that calling {@code pipeInput()} will also trigger {@link PunctuationType#STREAM_TIME event-time} based + * {@link ProcessorContext#schedule(Duration, PunctuationType, Punctuator) punctuation} callbacks. + * However, you won't trigger {@link PunctuationType#WALL_CLOCK_TIME wall-clock} type punctuations that you must + * trigger manually via {@link #advanceWallClockTime(Duration)}. + *
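+ * For example, a punctuation registered with a 10-second wall-clock interval could be triggered in a
+ * test via:
+ * <pre>{@code
+ * driver.advanceWallClockTime(Duration.ofSeconds(10));
+ * }</pre>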

                + * Finally, when completed, make sure your tests {@link #close()} the driver to release all resources and + * {@link org.apache.kafka.streams.processor.api.Processor processors}. + * + *

+ * <h2>Processor state</h2>

                + *

                + * Some processors use Kafka {@link StateStore state storage}, so this driver class provides the generic + * {@link #getStateStore(String)} as well as store-type specific methods so that your tests can check the underlying + * state store(s) used by your topology's processors. + * In our previous example, after we supplied a single input message and checked the three output messages, our test + * could also check the key value store to verify the processor correctly added, removed, or updated internal state. + * Or, our test might have pre-populated some state before submitting the input message, and verified afterward + * that the processor(s) correctly updated the state. + * + * @see TestInputTopic + * @see TestOutputTopic + */ +public class TopologyTestDriver implements Closeable { + + private static final Logger log = LoggerFactory.getLogger(TopologyTestDriver.class); + + private final LogContext logContext; + private final Time mockWallClockTime; + private InternalTopologyBuilder internalTopologyBuilder; + + private final static int PARTITION_ID = 0; + private final static TaskId TASK_ID = new TaskId(0, PARTITION_ID); + StreamTask task; + private GlobalStateUpdateTask globalStateTask; + private GlobalStateManager globalStateManager; + + private StateDirectory stateDirectory; + private Metrics metrics; + ProcessorTopology processorTopology; + ProcessorTopology globalTopology; + + private final MockConsumer consumer; + private final MockProducer producer; + private final TestDriverProducer testDriverProducer; + + private final Map partitionsByInputTopic = new HashMap<>(); + private final Map globalPartitionsByInputTopic = new HashMap<>(); + private final Map offsetsByTopicOrPatternPartition = new HashMap<>(); + + private final Map>> outputRecordsByTopic = new HashMap<>(); + private final StreamThread.ProcessingMode processingMode; + + private final StateRestoreListener stateRestoreListener = new StateRestoreListener() { + @Override + public void onRestoreStart(final TopicPartition topicPartition, final String storeName, final long startingOffset, final long endingOffset) {} + + @Override + public void onBatchRestored(final TopicPartition topicPartition, final String storeName, final long batchEndOffset, final long numRestored) {} + + @Override + public void onRestoreEnd(final TopicPartition topicPartition, final String storeName, final long totalRestored) {} + }; + + /** + * Create a new test diver instance. + * Default test properties are used to initialize the driver instance + * + * @param topology the topology to be tested + */ + public TopologyTestDriver(final Topology topology) { + this(topology, new Properties()); + } + + /** + * Create a new test diver instance. + * Initialized the internally mocked wall-clock time with {@link System#currentTimeMillis() current system time}. + * + * @param topology the topology to be tested + * @param config the configuration for the topology + */ + public TopologyTestDriver(final Topology topology, + final Properties config) { + this(topology, config, null); + } + + /** + * Create a new test diver instance. + * + * @param topology the topology to be tested + * @param initialWallClockTimeMs the initial value of internally mocked wall-clock time + */ + public TopologyTestDriver(final Topology topology, + final Instant initialWallClockTimeMs) { + this(topology, new Properties(), initialWallClockTimeMs); + } + + /** + * Create a new test diver instance. 
+ * + * @param topology the topology to be tested + * @param config the configuration for the topology + * @param initialWallClockTime the initial value of internally mocked wall-clock time + */ + public TopologyTestDriver(final Topology topology, + final Properties config, + final Instant initialWallClockTime) { + this( + topology.internalTopologyBuilder, + config, + initialWallClockTime == null ? System.currentTimeMillis() : initialWallClockTime.toEpochMilli()); + } + + /** + * Create a new test diver instance. + * + * @param builder builder for the topology to be tested + * @param config the configuration for the topology + * @param initialWallClockTimeMs the initial value of internally mocked wall-clock time + */ + private TopologyTestDriver(final InternalTopologyBuilder builder, + final Properties config, + final long initialWallClockTimeMs) { + final Properties configCopy = new Properties(); + configCopy.putAll(config); + configCopy.putIfAbsent(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, "dummy-bootstrap-host:0"); + // provide randomized dummy app-id if it's not specified + configCopy.putIfAbsent(StreamsConfig.APPLICATION_ID_CONFIG, "dummy-topology-test-driver-app-id-" + ThreadLocalRandom.current().nextInt()); + final StreamsConfig streamsConfig = new ClientUtils.QuietStreamsConfig(configCopy); + logIfTaskIdleEnabled(streamsConfig); + + logContext = new LogContext("topology-test-driver "); + mockWallClockTime = new MockTime(initialWallClockTimeMs); + processingMode = StreamThread.processingMode(streamsConfig); + + final StreamsMetricsImpl streamsMetrics = setupMetrics(streamsConfig); + setupTopology(builder, streamsConfig); + + final ThreadCache cache = new ThreadCache( + logContext, + Math.max(0, streamsConfig.getLong(StreamsConfig.CACHE_MAX_BYTES_BUFFERING_CONFIG)), + streamsMetrics + ); + + consumer = new MockConsumer<>(OffsetResetStrategy.EARLIEST); + final Serializer bytesSerializer = new ByteArraySerializer(); + producer = new MockProducer(true, bytesSerializer, bytesSerializer) { + @Override + public List partitionsFor(final String topic) { + return Collections.singletonList(new PartitionInfo(topic, PARTITION_ID, null, null, null)); + } + }; + testDriverProducer = new TestDriverProducer( + streamsConfig, + new KafkaClientSupplier() { + @Override + public Producer getProducer(final Map config) { + return producer; + } + + @Override + public Consumer getConsumer(final Map config) { + throw new IllegalStateException(); + } + + @Override + public Consumer getRestoreConsumer(final Map config) { + throw new IllegalStateException(); + } + + @Override + public Consumer getGlobalConsumer(final Map config) { + throw new IllegalStateException(); + } + }, + logContext, + mockWallClockTime + ); + + setupGlobalTask(mockWallClockTime, streamsConfig, streamsMetrics, cache); + setupTask(streamsConfig, streamsMetrics, cache); + } + + private static void logIfTaskIdleEnabled(final StreamsConfig streamsConfig) { + final Long taskIdleTime = streamsConfig.getLong(StreamsConfig.MAX_TASK_IDLE_MS_CONFIG); + if (taskIdleTime > 0) { + log.info("Detected {} config in use with TopologyTestDriver (set to {}ms)." + + " This means you might need to use TopologyTestDriver#advanceWallClockTime()" + + " or enqueue records on all partitions to allow Steams to make progress." 
+ + " TopologyTestDriver will log a message each time it cannot process enqueued" + + " records due to {}.", + StreamsConfig.MAX_TASK_IDLE_MS_CONFIG, + taskIdleTime, + StreamsConfig.MAX_TASK_IDLE_MS_CONFIG); + } + } + + private StreamsMetricsImpl setupMetrics(final StreamsConfig streamsConfig) { + final String threadId = Thread.currentThread().getName(); + + final MetricConfig metricConfig = new MetricConfig() + .samples(streamsConfig.getInt(StreamsConfig.METRICS_NUM_SAMPLES_CONFIG)) + .recordLevel(Sensor.RecordingLevel.forName(streamsConfig.getString(StreamsConfig.METRICS_RECORDING_LEVEL_CONFIG))) + .timeWindow(streamsConfig.getLong(StreamsConfig.METRICS_SAMPLE_WINDOW_MS_CONFIG), TimeUnit.MILLISECONDS); + metrics = new Metrics(metricConfig, mockWallClockTime); + + final StreamsMetricsImpl streamsMetrics = new StreamsMetricsImpl( + metrics, + "test-client", + streamsConfig.getString(StreamsConfig.BUILT_IN_METRICS_VERSION_CONFIG), + mockWallClockTime + ); + TaskMetrics.droppedRecordsSensor(threadId, TASK_ID.toString(), streamsMetrics); + + return streamsMetrics; + } + + private void setupTopology(final InternalTopologyBuilder builder, + final StreamsConfig streamsConfig) { + internalTopologyBuilder = builder; + internalTopologyBuilder.rewriteTopology(streamsConfig); + + processorTopology = internalTopologyBuilder.buildTopology(); + globalTopology = internalTopologyBuilder.buildGlobalStateTopology(); + + for (final String topic : processorTopology.sourceTopics()) { + final TopicPartition tp = new TopicPartition(topic, PARTITION_ID); + partitionsByInputTopic.put(topic, tp); + offsetsByTopicOrPatternPartition.put(tp, new AtomicLong()); + } + + stateDirectory = new StateDirectory(streamsConfig, mockWallClockTime, internalTopologyBuilder.hasPersistentStores(), false); + } + + private void setupGlobalTask(final Time mockWallClockTime, + final StreamsConfig streamsConfig, + final StreamsMetricsImpl streamsMetrics, + final ThreadCache cache) { + if (globalTopology != null) { + final MockConsumer globalConsumer = new MockConsumer<>(OffsetResetStrategy.NONE); + for (final String topicName : globalTopology.sourceTopics()) { + final TopicPartition partition = new TopicPartition(topicName, 0); + globalPartitionsByInputTopic.put(topicName, partition); + offsetsByTopicOrPatternPartition.put(partition, new AtomicLong()); + globalConsumer.updatePartitions(topicName, Collections.singletonList( + new PartitionInfo(topicName, 0, null, null, null))); + globalConsumer.updateBeginningOffsets(Collections.singletonMap(partition, 0L)); + globalConsumer.updateEndOffsets(Collections.singletonMap(partition, 0L)); + } + + globalStateManager = new GlobalStateManagerImpl( + logContext, + mockWallClockTime, + globalTopology, + globalConsumer, + stateDirectory, + stateRestoreListener, + streamsConfig + ); + + final GlobalProcessorContextImpl globalProcessorContext = + new GlobalProcessorContextImpl(streamsConfig, globalStateManager, streamsMetrics, cache, mockWallClockTime); + globalStateManager.setGlobalProcessorContext(globalProcessorContext); + + globalStateTask = new GlobalStateUpdateTask( + logContext, + globalTopology, + globalProcessorContext, + globalStateManager, + new LogAndContinueExceptionHandler() + ); + globalStateTask.initialize(); + globalProcessorContext.setRecordContext(null); + } else { + globalStateManager = null; + globalStateTask = null; + } + } + + @SuppressWarnings("deprecation") + private void setupTask(final StreamsConfig streamsConfig, + final StreamsMetricsImpl streamsMetrics, + final ThreadCache 
cache) { + if (!partitionsByInputTopic.isEmpty()) { + consumer.assign(partitionsByInputTopic.values()); + final Map startOffsets = new HashMap<>(); + for (final TopicPartition topicPartition : partitionsByInputTopic.values()) { + startOffsets.put(topicPartition, 0L); + } + consumer.updateBeginningOffsets(startOffsets); + + final ProcessorStateManager stateManager = new ProcessorStateManager( + TASK_ID, + Task.TaskType.ACTIVE, + StreamsConfig.EXACTLY_ONCE.equals(streamsConfig.getString(StreamsConfig.PROCESSING_GUARANTEE_CONFIG)), + logContext, + stateDirectory, + new MockChangelogRegister(), + processorTopology.storeToChangelogTopic(), + new HashSet<>(partitionsByInputTopic.values()) + ); + final RecordCollector recordCollector = new RecordCollectorImpl( + logContext, + TASK_ID, + testDriverProducer, + streamsConfig.defaultProductionExceptionHandler(), + streamsMetrics + ); + + final InternalProcessorContext context = new ProcessorContextImpl( + TASK_ID, + streamsConfig, + stateManager, + streamsMetrics, + cache + ); + + task = new StreamTask( + TASK_ID, + new HashSet<>(partitionsByInputTopic.values()), + processorTopology, + consumer, + streamsConfig, + streamsMetrics, + stateDirectory, + cache, + mockWallClockTime, + stateManager, + recordCollector, + context, + logContext); + task.initializeIfNeeded(); + task.completeRestoration(noOpResetter -> { }); + task.processorContext().setRecordContext(null); + + } else { + task = null; + } + } + + /** + * Get read-only handle on global metrics registry. + * + * @return Map of all metrics. + */ + public Map metrics() { + return Collections.unmodifiableMap(metrics.metrics()); + } + + private void pipeRecord(final String topicName, + final long timestamp, + final byte[] key, + final byte[] value, + final Headers headers) { + final TopicPartition inputTopicOrPatternPartition = getInputTopicOrPatternPartition(topicName); + final TopicPartition globalInputTopicPartition = globalPartitionsByInputTopic.get(topicName); + + if (inputTopicOrPatternPartition == null && globalInputTopicPartition == null) { + throw new IllegalArgumentException("Unknown topic: " + topicName); + } + + if (inputTopicOrPatternPartition != null) { + enqueueTaskRecord(topicName, inputTopicOrPatternPartition, timestamp, key, value, headers); + completeAllProcessableWork(); + } + + if (globalInputTopicPartition != null) { + processGlobalRecord(globalInputTopicPartition, timestamp, key, value, headers); + } + } + + private void enqueueTaskRecord(final String inputTopic, + final TopicPartition topicOrPatternPartition, + final long timestamp, + final byte[] key, + final byte[] value, + final Headers headers) { + final long offset = offsetsByTopicOrPatternPartition.get(topicOrPatternPartition).incrementAndGet() - 1; + task.addRecords(topicOrPatternPartition, Collections.singleton(new ConsumerRecord<>( + inputTopic, + topicOrPatternPartition.partition(), + offset, + timestamp, + TimestampType.CREATE_TIME, + key == null ? ConsumerRecord.NULL_SIZE : key.length, + value == null ? ConsumerRecord.NULL_SIZE : value.length, + key, + value, + headers, + Optional.empty())) + ); + } + + private void completeAllProcessableWork() { + // for internally triggered processing (like wall-clock punctuations), + // we might have buffered some records to internal topics that need to + // be piped back in to kick-start the processing loop. This is idempotent + // and therefore harmless in the case where all we've done is enqueued an + // input record from the user. 
+ captureOutputsAndReEnqueueInternalResults(); + + // If the topology only has global tasks, then `task` would be null. + // For this method, it just means there's nothing to do. + if (task != null) { + while (task.hasRecordsQueued() && task.isProcessable(mockWallClockTime.milliseconds())) { + // Process the record ... + task.process(mockWallClockTime.milliseconds()); + task.maybePunctuateStreamTime(); + commit(task.prepareCommit()); + task.postCommit(true); + captureOutputsAndReEnqueueInternalResults(); + } + if (task.hasRecordsQueued()) { + log.info("Due to the {} configuration, there are currently some records" + + " that cannot be processed. Advancing wall-clock time or" + + " enqueuing records on the empty topics will allow" + + " Streams to process more.", + StreamsConfig.MAX_TASK_IDLE_MS_CONFIG); + } + } + } + + private void commit(final Map offsets) { + if (processingMode == EXACTLY_ONCE_ALPHA || processingMode == EXACTLY_ONCE_V2) { + testDriverProducer.commitTransaction(offsets, new ConsumerGroupMetadata("dummy-app-id")); + } else { + consumer.commitSync(offsets); + } + } + + private void processGlobalRecord(final TopicPartition globalInputTopicPartition, + final long timestamp, + final byte[] key, + final byte[] value, + final Headers headers) { + globalStateTask.update(new ConsumerRecord<>( + globalInputTopicPartition.topic(), + globalInputTopicPartition.partition(), + offsetsByTopicOrPatternPartition.get(globalInputTopicPartition).incrementAndGet() - 1, + timestamp, + TimestampType.CREATE_TIME, + key == null ? ConsumerRecord.NULL_SIZE : key.length, + value == null ? ConsumerRecord.NULL_SIZE : value.length, + key, + value, + headers, + Optional.empty()) + ); + globalStateTask.flushState(); + } + + private void validateSourceTopicNameRegexPattern(final String inputRecordTopic) { + for (final String sourceTopicName : internalTopologyBuilder.sourceTopicNames()) { + if (!sourceTopicName.equals(inputRecordTopic) && Pattern.compile(sourceTopicName).matcher(inputRecordTopic).matches()) { + throw new TopologyException("Topology add source of type String for topic: " + sourceTopicName + + " cannot contain regex pattern for input record topic: " + inputRecordTopic + + " and hence cannot process the message."); + } + } + } + + private TopicPartition getInputTopicOrPatternPartition(final String topicName) { + if (!internalTopologyBuilder.sourceTopicNames().isEmpty()) { + validateSourceTopicNameRegexPattern(topicName); + } + + final TopicPartition topicPartition = partitionsByInputTopic.get(topicName); + if (topicPartition == null) { + for (final Map.Entry entry : partitionsByInputTopic.entrySet()) { + if (Pattern.compile(entry.getKey()).matcher(topicName).matches()) { + return entry.getValue(); + } + } + } + return topicPartition; + } + + private void captureOutputsAndReEnqueueInternalResults() { + // Capture all the records sent to the producer ... + final List> output = producer.history(); + producer.clear(); + + for (final ProducerRecord record : output) { + outputRecordsByTopic.computeIfAbsent(record.topic(), k -> new LinkedList<>()).add(record); + + // Forward back into the topology if the produced record is to an internal or a source topic ... 
+ final String outputTopicName = record.topic(); + + final TopicPartition inputTopicOrPatternPartition = getInputTopicOrPatternPartition(outputTopicName); + final TopicPartition globalInputTopicPartition = globalPartitionsByInputTopic.get(outputTopicName); + + if (inputTopicOrPatternPartition != null) { + enqueueTaskRecord( + outputTopicName, + inputTopicOrPatternPartition, + record.timestamp(), + record.key(), + record.value(), + record.headers() + ); + } + + if (globalInputTopicPartition != null) { + processGlobalRecord( + globalInputTopicPartition, + record.timestamp(), + record.key(), + record.value(), + record.headers() + ); + } + } + } + + /** + * Advances the internally mocked wall-clock time. + * This might trigger a {@link PunctuationType#WALL_CLOCK_TIME wall-clock} type + * {@link ProcessorContext#schedule(Duration, PunctuationType, Punctuator) punctuations}. + * + * @param advance the amount of time to advance wall-clock time + */ + public void advanceWallClockTime(final Duration advance) { + Objects.requireNonNull(advance, "advance cannot be null"); + mockWallClockTime.sleep(advance.toMillis()); + if (task != null) { + task.maybePunctuateSystemTime(); + commit(task.prepareCommit()); + task.postCommit(true); + } + completeAllProcessableWork(); + } + + private Queue> getRecordsQueue(final String topicName) { + final Queue> outputRecords = outputRecordsByTopic.get(topicName); + if (outputRecords == null && !processorTopology.sinkTopics().contains(topicName)) { + log.warn("Unrecognized topic: {}, this can occur if dynamic routing is used and no output has been " + + "sent to this topic yet. If not using a TopicNameExtractor, check that the output topic " + + "is correct.", topicName); + } + return outputRecords; + } + + /** + * Create {@link TestInputTopic} to be used for piping records to topic + * Uses current system time as start timestamp for records. + * Auto-advance is disabled. 
+ * + * @param topicName the name of the topic + * @param keySerializer the Serializer for the key type + * @param valueSerializer the Serializer for the value type + * @param the key type + * @param the value type + * @return {@link TestInputTopic} object + */ + public final TestInputTopic createInputTopic(final String topicName, + final Serializer keySerializer, + final Serializer valueSerializer) { + return new TestInputTopic<>(this, topicName, keySerializer, valueSerializer, Instant.now(), Duration.ZERO); + } + + /** + * Create {@link TestInputTopic} to be used for piping records to topic + * Uses provided start timestamp and autoAdvance parameter for records + * + * @param topicName the name of the topic + * @param keySerializer the Serializer for the key type + * @param valueSerializer the Serializer for the value type + * @param startTimestamp Start timestamp for auto-generated record time + * @param autoAdvance autoAdvance duration for auto-generated record time + * @param the key type + * @param the value type + * @return {@link TestInputTopic} object + */ + public final TestInputTopic createInputTopic(final String topicName, + final Serializer keySerializer, + final Serializer valueSerializer, + final Instant startTimestamp, + final Duration autoAdvance) { + return new TestInputTopic<>(this, topicName, keySerializer, valueSerializer, startTimestamp, autoAdvance); + } + + /** + * Create {@link TestOutputTopic} to be used for reading records from topic + * + * @param topicName the name of the topic + * @param keyDeserializer the Deserializer for the key type + * @param valueDeserializer the Deserializer for the value type + * @param the key type + * @param the value type + * @return {@link TestOutputTopic} object + */ + public final TestOutputTopic createOutputTopic(final String topicName, + final Deserializer keyDeserializer, + final Deserializer valueDeserializer) { + return new TestOutputTopic<>(this, topicName, keyDeserializer, valueDeserializer); + } + + /** + * Get all the names of all the topics to which records have been produced during the test run. + *
                + * Call this method after piping the input into the test driver to retrieve the full set of topic names the topology + * produced records to. + *
                + * The returned set of topic names may include user (e.g., output) and internal (e.g., changelog, repartition) topic + * names. + * + * @return the set of topic names the topology has produced to + */ + public final Set producedTopicNames() { + return Collections.unmodifiableSet(outputRecordsByTopic.keySet()); + } + + ProducerRecord readRecord(final String topic) { + final Queue> outputRecords = getRecordsQueue(topic); + if (outputRecords == null) { + return null; + } + return outputRecords.poll(); + } + + TestRecord readRecord(final String topic, + final Deserializer keyDeserializer, + final Deserializer valueDeserializer) { + final Queue> outputRecords = getRecordsQueue(topic); + if (outputRecords == null) { + throw new NoSuchElementException("Uninitialized topic: " + topic); + } + final ProducerRecord record = outputRecords.poll(); + if (record == null) { + throw new NoSuchElementException("Empty topic: " + topic); + } + final K key = keyDeserializer.deserialize(record.topic(), record.headers(), record.key()); + final V value = valueDeserializer.deserialize(record.topic(), record.headers(), record.value()); + return new TestRecord<>(key, value, record.headers(), record.timestamp()); + } + + void pipeRecord(final String topic, + final TestRecord record, + final Serializer keySerializer, + final Serializer valueSerializer, + final Instant time) { + final byte[] serializedKey = keySerializer.serialize(topic, record.headers(), record.key()); + final byte[] serializedValue = valueSerializer.serialize(topic, record.headers(), record.value()); + final long timestamp; + if (time != null) { + timestamp = time.toEpochMilli(); + } else if (record.timestamp() != null) { + timestamp = record.timestamp(); + } else { + throw new IllegalStateException("Provided `TestRecord` does not have a timestamp and no timestamp overwrite was provided via `time` parameter."); + } + + pipeRecord(topic, timestamp, serializedKey, serializedValue, record.headers()); + } + + final long getQueueSize(final String topic) { + final Queue> queue = getRecordsQueue(topic); + if (queue == null) { + //Return 0 if not initialized, getRecordsQueue throw exception if non existing topic + return 0; + } + return queue.size(); + } + + final boolean isEmpty(final String topic) { + return getQueueSize(topic) == 0; + } + + /** + * Get all {@link StateStore StateStores} from the topology. + * The stores can be a "regular" or global stores. + *
+ * This is often useful in test cases to pre-populate the store before the test case instructs the topology to + * {@link TestInputTopic#pipeInput(TestRecord) process an input message}, and/or to check the store afterward. + *
+ * Note that {@code StateStore} might be {@code null} if a store is added but not connected to any processor. + *
+ * Caution: Using this method to access stores that are added by the DSL is unsafe as the store + * types may change. Stores added by the DSL should only be accessed via the corresponding typed methods + * like {@link #getKeyValueStore(String)} etc. + * + * @return all stores by name + * @see #getStateStore(String) + * @see #getKeyValueStore(String) + * @see #getTimestampedKeyValueStore(String) + * @see #getWindowStore(String) + * @see #getTimestampedWindowStore(String) + * @see #getSessionStore(String) + */ + public Map<String, StateStore> getAllStateStores() { + final Map<String, StateStore> allStores = new HashMap<>(); + for (final String storeName : internalTopologyBuilder.allStateStoreNames()) { + allStores.put(storeName, getStateStore(storeName, false)); + } + return allStores; + } + + /** + * Get the {@link StateStore} with the given name. + * The store can be a "regular" or global store. + *
                + * Should be used for custom stores only. + * For built-in stores, the corresponding typed methods like {@link #getKeyValueStore(String)} should be used. + *
                + * This is often useful in test cases to pre-populate the store before the test case instructs the topology to + * {@link TestInputTopic#pipeInput(TestRecord) process an input message}, and/or to check the store afterward. + * + * @param name the name of the store + * @return the state store, or {@code null} if no store has been registered with the given name + * @throws IllegalArgumentException if the store is a built-in store like {@link KeyValueStore}, + * {@link WindowStore}, or {@link SessionStore} + * + * @see #getAllStateStores() + * @see #getKeyValueStore(String) + * @see #getTimestampedKeyValueStore(String) + * @see #getWindowStore(String) + * @see #getTimestampedWindowStore(String) + * @see #getSessionStore(String) + */ + public StateStore getStateStore(final String name) throws IllegalArgumentException { + return getStateStore(name, true); + } + + private StateStore getStateStore(final String name, + final boolean throwForBuiltInStores) { + if (task != null) { + final StateStore stateStore = ((ProcessorContextImpl) task.processorContext()).stateManager().getStore(name); + if (stateStore != null) { + if (throwForBuiltInStores) { + throwIfBuiltInStore(stateStore); + } + return stateStore; + } + } + + if (globalStateManager != null) { + final StateStore stateStore = globalStateManager.getStore(name); + if (stateStore != null) { + if (throwForBuiltInStores) { + throwIfBuiltInStore(stateStore); + } + return stateStore; + } + + } + + return null; + } + + private void throwIfBuiltInStore(final StateStore stateStore) { + if (stateStore instanceof TimestampedKeyValueStore) { + throw new IllegalArgumentException("Store " + stateStore.name() + + " is a timestamped key-value store and should be accessed via `getTimestampedKeyValueStore()`"); + } + if (stateStore instanceof ReadOnlyKeyValueStore) { + throw new IllegalArgumentException("Store " + stateStore.name() + + " is a key-value store and should be accessed via `getKeyValueStore()`"); + } + if (stateStore instanceof TimestampedWindowStore) { + throw new IllegalArgumentException("Store " + stateStore.name() + + " is a timestamped window store and should be accessed via `getTimestampedWindowStore()`"); + } + if (stateStore instanceof ReadOnlyWindowStore) { + throw new IllegalArgumentException("Store " + stateStore.name() + + " is a window store and should be accessed via `getWindowStore()`"); + } + if (stateStore instanceof ReadOnlySessionStore) { + throw new IllegalArgumentException("Store " + stateStore.name() + + " is a session store and should be accessed via `getSessionStore()`"); + } + } + + /** + * Get the {@link KeyValueStore} or {@link TimestampedKeyValueStore} with the given name. + * The store can be a "regular" or global store. + *
                + * If the registered store is a {@link TimestampedKeyValueStore} this method will return a value-only query + * interface. It is highly recommended to update the code for this case to avoid bugs and to use + * {@link #getTimestampedKeyValueStore(String)} for full store access instead. + *
                + * This is often useful in test cases to pre-populate the store before the test case instructs the topology to + * {@link TestInputTopic#pipeInput(TestRecord) process an input message}, and/or to check the store afterward. + * + * @param name the name of the store + * @return the key value store, or {@code null} if no {@link KeyValueStore} or {@link TimestampedKeyValueStore} + * has been registered with the given name + * @see #getAllStateStores() + * @see #getStateStore(String) + * @see #getTimestampedKeyValueStore(String) + * @see #getWindowStore(String) + * @see #getTimestampedWindowStore(String) + * @see #getSessionStore(String) + */ + @SuppressWarnings("unchecked") + public KeyValueStore getKeyValueStore(final String name) { + final StateStore store = getStateStore(name, false); + if (store instanceof TimestampedKeyValueStore) { + log.info("Method #getTimestampedKeyValueStore() should be used to access a TimestampedKeyValueStore."); + return new KeyValueStoreFacade<>((TimestampedKeyValueStore) store); + } + return store instanceof KeyValueStore ? (KeyValueStore) store : null; + } + + /** + * Get the {@link TimestampedKeyValueStore} with the given name. + * The store can be a "regular" or global store. + *
                + * This is often useful in test cases to pre-populate the store before the test case instructs the topology to + * {@link TestInputTopic#pipeInput(TestRecord) process an input message}, and/or to check the store afterward. + * + * @param name the name of the store + * @return the key value store, or {@code null} if no {@link TimestampedKeyValueStore} has been registered with the given name + * @see #getAllStateStores() + * @see #getStateStore(String) + * @see #getKeyValueStore(String) + * @see #getWindowStore(String) + * @see #getTimestampedWindowStore(String) + * @see #getSessionStore(String) + */ + @SuppressWarnings("unchecked") + public KeyValueStore> getTimestampedKeyValueStore(final String name) { + final StateStore store = getStateStore(name, false); + return store instanceof TimestampedKeyValueStore ? (TimestampedKeyValueStore) store : null; + } + + /** + * Get the {@link WindowStore} or {@link TimestampedWindowStore} with the given name. + * The store can be a "regular" or global store. + *
                + * If the registered store is a {@link TimestampedWindowStore} this method will return a value-only query + * interface. It is highly recommended to update the code for this case to avoid bugs and to use + * {@link #getTimestampedWindowStore(String)} for full store access instead. + *
                + * This is often useful in test cases to pre-populate the store before the test case instructs the topology to + * {@link TestInputTopic#pipeInput(TestRecord) process an input message}, and/or to check the store afterward. + * + * @param name the name of the store + * @return the key value store, or {@code null} if no {@link WindowStore} or {@link TimestampedWindowStore} + * has been registered with the given name + * @see #getAllStateStores() + * @see #getStateStore(String) + * @see #getKeyValueStore(String) + * @see #getTimestampedKeyValueStore(String) + * @see #getTimestampedWindowStore(String) + * @see #getSessionStore(String) + */ + @SuppressWarnings("unchecked") + public WindowStore getWindowStore(final String name) { + final StateStore store = getStateStore(name, false); + if (store instanceof TimestampedWindowStore) { + log.info("Method #getTimestampedWindowStore() should be used to access a TimestampedWindowStore."); + return new WindowStoreFacade<>((TimestampedWindowStore) store); + } + return store instanceof WindowStore ? (WindowStore) store : null; + } + + /** + * Get the {@link TimestampedWindowStore} with the given name. + * The store can be a "regular" or global store. + *
                + * This is often useful in test cases to pre-populate the store before the test case instructs the topology to + * {@link TestInputTopic#pipeInput(TestRecord) process an input message}, and/or to check the store afterward. + * + * @param name the name of the store + * @return the key value store, or {@code null} if no {@link TimestampedWindowStore} has been registered with the given name + * @see #getAllStateStores() + * @see #getStateStore(String) + * @see #getKeyValueStore(String) + * @see #getTimestampedKeyValueStore(String) + * @see #getWindowStore(String) + * @see #getSessionStore(String) + */ + @SuppressWarnings("unchecked") + public WindowStore> getTimestampedWindowStore(final String name) { + final StateStore store = getStateStore(name, false); + return store instanceof TimestampedWindowStore ? (TimestampedWindowStore) store : null; + } + + /** + * Get the {@link SessionStore} with the given name. + * The store can be a "regular" or global store. + *
                + * This is often useful in test cases to pre-populate the store before the test case instructs the topology to + * {@link TestInputTopic#pipeInput(TestRecord) process an input message}, and/or to check the store afterward. + * + * @param name the name of the store + * @return the key value store, or {@code null} if no {@link SessionStore} has been registered with the given name + * @see #getAllStateStores() + * @see #getStateStore(String) + * @see #getKeyValueStore(String) + * @see #getTimestampedKeyValueStore(String) + * @see #getWindowStore(String) + * @see #getTimestampedWindowStore(String) + */ + @SuppressWarnings("unchecked") + public SessionStore getSessionStore(final String name) { + final StateStore store = getStateStore(name, false); + return store instanceof SessionStore ? (SessionStore) store : null; + } + + /** + * Close the driver, its topology, and all processors. + */ + public void close() { + if (task != null) { + task.suspend(); + task.prepareCommit(); + task.postCommit(true); + task.closeClean(); + } + if (globalStateTask != null) { + try { + globalStateTask.close(false); + } catch (final IOException e) { + // ignore + } + } + completeAllProcessableWork(); + if (task != null && task.hasRecordsQueued()) { + log.warn("Found some records that cannot be processed due to the" + + " {} configuration during TopologyTestDriver#close().", + StreamsConfig.MAX_TASK_IDLE_MS_CONFIG); + } + if (processingMode == AT_LEAST_ONCE) { + producer.close(); + } + stateDirectory.clean(); + } + + static class MockChangelogRegister implements ChangelogRegister { + private final Set restoringPartitions = new HashSet<>(); + + @Override + public void register(final TopicPartition partition, final ProcessorStateManager stateManager) { + restoringPartitions.add(partition); + } + + @Override + public void unregister(final Collection partitions) { + restoringPartitions.removeAll(partitions); + } + } + + static class MockTime implements Time { + private final AtomicLong timeMs; + private final AtomicLong highResTimeNs; + + MockTime(final long startTimestampMs) { + this.timeMs = new AtomicLong(startTimestampMs); + this.highResTimeNs = new AtomicLong(startTimestampMs * 1000L * 1000L); + } + + @Override + public long milliseconds() { + return timeMs.get(); + } + + @Override + public long nanoseconds() { + return highResTimeNs.get(); + } + + @Override + public long hiResClockMs() { + return TimeUnit.NANOSECONDS.toMillis(nanoseconds()); + } + + @Override + public void sleep(final long ms) { + if (ms < 0) { + throw new IllegalArgumentException("Sleep ms cannot be negative."); + } + timeMs.addAndGet(ms); + highResTimeNs.addAndGet(TimeUnit.MILLISECONDS.toNanos(ms)); + } + + @Override + public void waitObject(final Object obj, final Supplier condition, final long timeoutMs) { + throw new UnsupportedOperationException(); + } + } + + static class KeyValueStoreFacade extends ReadOnlyKeyValueStoreFacade implements KeyValueStore { + + public KeyValueStoreFacade(final TimestampedKeyValueStore inner) { + super(inner); + } + + @Deprecated + @Override + public void init(final ProcessorContext context, + final StateStore root) { + inner.init(context, root); + } + + @Override + public void init(final StateStoreContext context, final StateStore root) { + inner.init(context, root); + } + + @Override + public void put(final K key, + final V value) { + inner.put(key, ValueAndTimestamp.make(value, ConsumerRecord.NO_TIMESTAMP)); + } + + @Override + public V putIfAbsent(final K key, + final V value) { + return 
getValueOrNull(inner.putIfAbsent(key, ValueAndTimestamp.make(value, ConsumerRecord.NO_TIMESTAMP))); + } + + @Override + public void putAll(final List> entries) { + for (final KeyValue entry : entries) { + inner.put(entry.key, ValueAndTimestamp.make(entry.value, ConsumerRecord.NO_TIMESTAMP)); + } + } + + @Override + public V delete(final K key) { + return getValueOrNull(inner.delete(key)); + } + + @Override + public void flush() { + inner.flush(); + } + + @Override + public void close() { + inner.close(); + } + + @Override + public String name() { + return inner.name(); + } + + @Override + public boolean persistent() { + return inner.persistent(); + } + + @Override + public boolean isOpen() { + return inner.isOpen(); + } + } + + static class WindowStoreFacade extends ReadOnlyWindowStoreFacade implements WindowStore { + + public WindowStoreFacade(final TimestampedWindowStore store) { + super(store); + } + + @Deprecated + @Override + public void init(final ProcessorContext context, + final StateStore root) { + inner.init(context, root); + } + + @Override + public void init(final StateStoreContext context, final StateStore root) { + inner.init(context, root); + } + + @Override + public void put(final K key, + final V value, + final long windowStartTimestamp) { + inner.put(key, ValueAndTimestamp.make(value, ConsumerRecord.NO_TIMESTAMP), windowStartTimestamp); + } + + @Override + public WindowStoreIterator fetch(final K key, + final long timeFrom, + final long timeTo) { + return fetch(key, Instant.ofEpochMilli(timeFrom), Instant.ofEpochMilli(timeTo)); + } + + @Override + public WindowStoreIterator backwardFetch(final K key, + final long timeFrom, + final long timeTo) { + return backwardFetch(key, Instant.ofEpochMilli(timeFrom), Instant.ofEpochMilli(timeTo)); + } + + @Override + public KeyValueIterator, V> fetch(final K keyFrom, + final K keyTo, + final long timeFrom, + final long timeTo) { + return fetch(keyFrom, keyTo, Instant.ofEpochMilli(timeFrom), + Instant.ofEpochMilli(timeTo)); + } + + @Override + public KeyValueIterator, V> backwardFetch(final K keyFrom, + final K keyTo, + final long timeFrom, + final long timeTo) { + return backwardFetch(keyFrom, keyTo, Instant.ofEpochMilli(timeFrom), Instant.ofEpochMilli(timeTo)); + } + + @Override + public KeyValueIterator, V> fetchAll(final long timeFrom, + final long timeTo) { + return fetchAll(Instant.ofEpochMilli(timeFrom), Instant.ofEpochMilli(timeTo)); + } + + @Override + public KeyValueIterator, V> backwardFetchAll(final long timeFrom, + final long timeTo) { + return backwardFetchAll(Instant.ofEpochMilli(timeFrom), Instant.ofEpochMilli(timeTo)); + } + + @Override + public void flush() { + inner.flush(); + } + + @Override + public void close() { + inner.close(); + } + + @Override + public String name() { + return inner.name(); + } + + @Override + public boolean persistent() { + return inner.persistent(); + } + + @Override + public boolean isOpen() { + return inner.isOpen(); + } + } + + private static class TestDriverProducer extends StreamsProducer { + + public TestDriverProducer(final StreamsConfig config, + final KafkaClientSupplier clientSupplier, + final LogContext logContext, + final Time time) { + super(config, "TopologyTestDriver-StreamThread-1", clientSupplier, new TaskId(0, 0), UUID.randomUUID(), logContext, time); + } + + @Override + public void commitTransaction(final Map offsets, + final ConsumerGroupMetadata consumerGroupMetadata) throws ProducerFencedException { + super.commitTransaction(offsets, consumerGroupMetadata); + } + } +} 
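A minimal usage sketch of the driver API defined above may help orient readers before the mock-context classes that follow. The word-count style topology, the topic names ("input-topic", "output-topic"), and the store name ("counts") are illustrative assumptions and not part of this patch; the driver calls themselves (createInputTopic, createOutputTopic, advanceWallClockTime, getKeyValueStore, and close via try-with-resources) correspond to the methods implemented above.

import java.time.Duration;
import java.util.Properties;

import org.apache.kafka.common.serialization.Serdes;
import org.apache.kafka.streams.StreamsConfig;
import org.apache.kafka.streams.TestInputTopic;
import org.apache.kafka.streams.TestOutputTopic;
import org.apache.kafka.streams.Topology;
import org.apache.kafka.streams.TopologyTestDriver;
import org.apache.kafka.streams.state.KeyValueStore;

public class WordCountTopologyTest {

    public void shouldCountWords() {
        final Topology topology = buildWordCountTopology(); // assumed helper; builds the topology under test

        final Properties props = new Properties();
        props.put(StreamsConfig.APPLICATION_ID_CONFIG, "topology-test-driver-example");
        props.put(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, "dummy:1234"); // never contacted by the driver

        try (final TopologyTestDriver driver = new TopologyTestDriver(topology, props)) {
            final TestInputTopic<String, String> input = driver.createInputTopic(
                "input-topic", Serdes.String().serializer(), Serdes.String().serializer());
            final TestOutputTopic<String, Long> output = driver.createOutputTopic(
                "output-topic", Serdes.String().deserializer(), Serdes.Long().deserializer());

            // Records are processed synchronously as they are piped in (see completeAllProcessableWork()).
            input.pipeInput("key", "hello world hello");

            // Wall-clock punctuators only fire when the mocked clock is advanced explicitly.
            driver.advanceWallClockTime(Duration.ofSeconds(30));

            // Results can be read back record by record, and state stores inspected directly.
            final KeyValueStore<String, Long> counts = driver.getKeyValueStore("counts");
            System.out.println(output.readKeyValue() + " / store count for 'hello': " + counts.get("hello"));
        }
    }

    private Topology buildWordCountTopology() {
        // Placeholder for the topology under test; intentionally not implemented in this sketch.
        throw new UnsupportedOperationException("illustrative sketch only");
    }
}

Closing the driver matters: as implemented above, close() suspends and commits the task, closes the global state task, and cleans the state directory, so try-with-resources keeps test runs isolated.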
diff --git a/streams/test-utils/src/main/java/org/apache/kafka/streams/processor/MockProcessorContext.java b/streams/test-utils/src/main/java/org/apache/kafka/streams/processor/MockProcessorContext.java new file mode 100644 index 0000000..061c1fb --- /dev/null +++ b/streams/test-utils/src/main/java/org/apache/kafka/streams/processor/MockProcessorContext.java @@ -0,0 +1,586 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.processor; + +import org.apache.kafka.common.header.Headers; +import org.apache.kafka.common.metrics.MetricConfig; +import org.apache.kafka.common.metrics.Metrics; +import org.apache.kafka.common.metrics.Sensor; +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.StreamsMetrics; +import org.apache.kafka.streams.Topology; +import org.apache.kafka.streams.TopologyTestDriver; +import org.apache.kafka.streams.internals.ApiUtils; +import org.apache.kafka.streams.kstream.Transformer; +import org.apache.kafka.streams.kstream.ValueTransformer; +import org.apache.kafka.streams.processor.internals.ClientUtils; +import org.apache.kafka.streams.processor.internals.RecordCollector; +import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl; +import org.apache.kafka.streams.processor.internals.metrics.TaskMetrics; +import org.apache.kafka.streams.state.internals.InMemoryKeyValueStore; + +import java.io.File; +import java.time.Duration; +import java.util.HashMap; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.Properties; + +/** + * {@link MockProcessorContext} is a mock of {@link ProcessorContext} for users to test their {@link Processor}, + * {@link Transformer}, and {@link ValueTransformer} implementations. + *
                + * The tests for this class (org.apache.kafka.streams.MockProcessorContextTest) include several behavioral + * tests that serve as example usage. + *
                + * Note that this class does not take any automated actions (such as firing scheduled punctuators). + * It simply captures any data it witnesses. + * If you require more automated tests, we recommend wrapping your {@link Processor} in a minimal source-processor-sink + * {@link Topology} and using the {@link TopologyTestDriver}. + */ +@SuppressWarnings("deprecation") // not deprecating old PAPI Context, since it is still in use by Transformers. +public class MockProcessorContext implements ProcessorContext, RecordCollector.Supplier { + // Immutable fields ================================================ + private final StreamsMetricsImpl metrics; + private final TaskId taskId; + private final StreamsConfig config; + private final File stateDir; + + // settable record metadata ================================================ + private String topic; + private Integer partition; + private Long offset; + private Headers headers; + private Long recordTimestamp; + private Long currentSystemTimeMs; + private Long currentStreamTimeMs; + + // mocks ================================================ + private final Map stateStores = new HashMap<>(); + private final List punctuators = new LinkedList<>(); + private final List capturedForwards = new LinkedList<>(); + private boolean committed = false; + + + /** + * {@link CapturedPunctuator} holds captured punctuators, along with their scheduling information. + */ + public static class CapturedPunctuator { + private final long intervalMs; + private final PunctuationType type; + private final Punctuator punctuator; + private boolean cancelled = false; + + private CapturedPunctuator(final long intervalMs, final PunctuationType type, final Punctuator punctuator) { + this.intervalMs = intervalMs; + this.type = type; + this.punctuator = punctuator; + } + + @SuppressWarnings({"WeakerAccess", "unused"}) + public long getIntervalMs() { + return intervalMs; + } + + @SuppressWarnings({"WeakerAccess", "unused"}) + public PunctuationType getType() { + return type; + } + + @SuppressWarnings({"WeakerAccess", "unused"}) + public Punctuator getPunctuator() { + return punctuator; + } + + @SuppressWarnings({"WeakerAccess", "unused"}) + public void cancel() { + cancelled = true; + } + + @SuppressWarnings({"WeakerAccess", "unused"}) + public boolean cancelled() { + return cancelled; + } + } + + + public static class CapturedForward { + private final String childName; + private final long timestamp; + private final Headers headers; + private final KeyValue keyValue; + + private CapturedForward(final KeyValue keyValue, final To to, final Headers headers) { + if (keyValue == null) { + throw new IllegalArgumentException(); + } + + this.childName = to.childName; + this.timestamp = to.timestamp; + this.keyValue = keyValue; + this.headers = headers; + } + + /** + * The child this data was forwarded to. + * + * @return The child name, or {@code null} if it was broadcast. + */ + @SuppressWarnings({"WeakerAccess", "unused"}) + public String childName() { + return childName; + } + + /** + * The timestamp attached to the forwarded record. + * + * @return A timestamp, or {@code -1} if none was forwarded. + */ + @SuppressWarnings({"WeakerAccess", "unused"}) + public long timestamp() { + return timestamp; + } + + /** + * The data forwarded. + * + * @return A key/value pair. Not null. 
+ */ + @SuppressWarnings({"WeakerAccess", "unused"}) + public KeyValue keyValue() { + return keyValue; + } + + @Override + public String toString() { + return "CapturedForward{" + + "childName='" + childName + '\'' + + ", timestamp=" + timestamp + + ", keyValue=" + keyValue + + '}'; + } + + public Headers headers() { + return this.headers; + } + } + + // constructors ================================================ + + /** + * Create a {@link MockProcessorContext} with dummy {@code config} and {@code taskId} and {@code null} {@code stateDir}. + * Most unit tests using this mock won't need to know the taskId, + * and most unit tests should be able to get by with the + * {@link InMemoryKeyValueStore}, so the stateDir won't matter. + */ + @SuppressWarnings({"WeakerAccess", "unused"}) + public MockProcessorContext() { + //noinspection DoubleBraceInitialization + this( + new Properties() { + { + put(StreamsConfig.APPLICATION_ID_CONFIG, ""); + put(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, ""); + } + }, + new TaskId(0, 0), + null); + } + + /** + * Create a {@link MockProcessorContext} with dummy {@code taskId} and {@code null} {@code stateDir}. + * Most unit tests using this mock won't need to know the taskId, + * and most unit tests should be able to get by with the + * {@link InMemoryKeyValueStore}, so the stateDir won't matter. + * + * @param config a Properties object, used to configure the context and the processor. + */ + @SuppressWarnings({"WeakerAccess", "unused"}) + public MockProcessorContext(final Properties config) { + this(config, new TaskId(0, 0), null); + } + + /** + * Create a {@link MockProcessorContext} with a specified taskId and null stateDir. + * + * @param config a {@link Properties} object, used to configure the context and the processor. + * @param taskId a {@link TaskId}, which the context makes available via {@link MockProcessorContext#taskId()}. + * @param stateDir a {@link File}, which the context makes available viw {@link MockProcessorContext#stateDir()}. 
+ */ + @SuppressWarnings({"WeakerAccess", "unused"}) + public MockProcessorContext(final Properties config, final TaskId taskId, final File stateDir) { + final Properties configCopy = new Properties(); + configCopy.putAll(config); + configCopy.putIfAbsent(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, "dummy-bootstrap-host:0"); + configCopy.putIfAbsent(StreamsConfig.APPLICATION_ID_CONFIG, "dummy-mock-app-id"); + final StreamsConfig streamsConfig = new ClientUtils.QuietStreamsConfig(configCopy); + this.taskId = taskId; + this.config = streamsConfig; + this.stateDir = stateDir; + final MetricConfig metricConfig = new MetricConfig(); + metricConfig.recordLevel(Sensor.RecordingLevel.DEBUG); + final String threadId = Thread.currentThread().getName(); + this.metrics = new StreamsMetricsImpl( + new Metrics(metricConfig), + threadId, + streamsConfig.getString(StreamsConfig.BUILT_IN_METRICS_VERSION_CONFIG), + Time.SYSTEM + ); + TaskMetrics.droppedRecordsSensor(threadId, taskId.toString(), metrics); + } + + @Override + public String applicationId() { + return config.getString(StreamsConfig.APPLICATION_ID_CONFIG); + } + + @Override + public TaskId taskId() { + return taskId; + } + + @Override + public Map appConfigs() { + final Map combined = new HashMap<>(); + combined.putAll(config.originals()); + combined.putAll(config.values()); + return combined; + } + + @Override + public Map appConfigsWithPrefix(final String prefix) { + return config.originalsWithPrefix(prefix); + } + + @Override + public long currentSystemTimeMs() { + if (currentSystemTimeMs == null) { + throw new IllegalStateException("System time must be set before use via setCurrentSystemTimeMs()."); + } + return currentSystemTimeMs; + } + + @Override + public long currentStreamTimeMs() { + if (currentStreamTimeMs == null) { + throw new IllegalStateException("Stream time must be set before use via setCurrentStreamTimeMs()."); + } + return currentStreamTimeMs; + } + + @Override + public Serde keySerde() { + return config.defaultKeySerde(); + } + + @Override + public Serde valueSerde() { + return config.defaultValueSerde(); + } + + @Override + public File stateDir() { + return stateDir; + } + + @Override + public StreamsMetrics metrics() { + return metrics; + } + + // settable record metadata ================================================ + + /** + * The context exposes these metadata for use in the processor. Normally, they are set by the Kafka Streams framework, + * but for the purpose of driving unit tests, you can set them directly. + * + * @param topic A topic name + * @param partition A partition number + * @param offset A record offset + * @param timestamp A record timestamp + */ + @SuppressWarnings({"WeakerAccess", "unused"}) + public void setRecordMetadata(final String topic, + final int partition, + final long offset, + final Headers headers, + final long timestamp) { + this.topic = topic; + this.partition = partition; + this.offset = offset; + this.headers = headers; + this.recordTimestamp = timestamp; + } + + /** + * The context exposes this metadata for use in the processor. Normally, they are set by the Kafka Streams framework, + * but for the purpose of driving unit tests, you can set it directly. Setting this attribute doesn't affect the others. + * + * @param topic A topic name + */ + @SuppressWarnings({"WeakerAccess", "unused"}) + public void setTopic(final String topic) { + this.topic = topic; + } + + /** + * The context exposes this metadata for use in the processor. 
Normally, they are set by the Kafka Streams framework, + * but for the purpose of driving unit tests, you can set it directly. Setting this attribute doesn't affect the others. + * + * @param partition A partition number + */ + @SuppressWarnings({"WeakerAccess", "unused"}) + public void setPartition(final int partition) { + this.partition = partition; + } + + /** + * The context exposes this metadata for use in the processor. Normally, they are set by the Kafka Streams framework, + * but for the purpose of driving unit tests, you can set it directly. Setting this attribute doesn't affect the others. + * + * @param offset A record offset + */ + @SuppressWarnings({"WeakerAccess", "unused"}) + public void setOffset(final long offset) { + this.offset = offset; + } + + /** + * The context exposes this metadata for use in the processor. Normally, they are set by the Kafka Streams framework, + * but for the purpose of driving unit tests, you can set it directly. Setting this attribute doesn't affect the others. + * + * @param headers Record headers + */ + @SuppressWarnings({"WeakerAccess", "unused"}) + public void setHeaders(final Headers headers) { + this.headers = headers; + } + + /** + * The context exposes this metadata for use in the processor. Normally, they are set by the Kafka Streams framework, + * but for the purpose of driving unit tests, you can set it directly. Setting this attribute doesn't affect the others. + * + * @param timestamp A record timestamp + * @deprecated Since 3.0.0; use {@link MockProcessorContext#setRecordTimestamp(long)} instead. + */ + @Deprecated + @SuppressWarnings({"WeakerAccess", "unused"}) + public void setTimestamp(final long timestamp) { + this.recordTimestamp = timestamp; + } + + /** + * The context exposes this metadata for use in the processor. Normally, they are set by the Kafka Streams framework, + * but for the purpose of driving unit tests, you can set it directly. Setting this attribute doesn't affect the others. + * + * @param recordTimestamp A record timestamp + */ + @SuppressWarnings({"WeakerAccess"}) + public void setRecordTimestamp(final long recordTimestamp) { + this.recordTimestamp = recordTimestamp; + } + + public void setCurrentSystemTimeMs(final long currentSystemTimeMs) { + this.currentSystemTimeMs = currentSystemTimeMs; + } + + public void setCurrentStreamTimeMs(final long currentStreamTimeMs) { + this.currentStreamTimeMs = currentStreamTimeMs; + } + + @Override + public String topic() { + if (topic == null) { + throw new IllegalStateException("Topic must be set before use via setRecordMetadata() or setTopic()."); + } + return topic; + } + + @Override + public int partition() { + if (partition == null) { + throw new IllegalStateException("Partition must be set before use via setRecordMetadata() or setPartition()."); + } + return partition; + } + + @Override + public long offset() { + if (offset == null) { + throw new IllegalStateException("Offset must be set before use via setRecordMetadata() or setOffset()."); + } + return offset; + } + + /** + * Returns the headers of the current input record; could be {@code null} if it is not + * available. + * + *
                Note, that headers should never be {@code null} in the actual Kafka Streams runtime, + * even if they could be empty. However, this mock does not guarantee non-{@code null} headers. + * Thus, you either need to add a {@code null} check to your production code to use this mock + * for testing or you always need to set headers manually via {@link #setHeaders(Headers)} to + * avoid a {@link NullPointerException} from your {@link Processor} implementation. + * + * @return the headers + */ + @Override + public Headers headers() { + return headers; + } + + @Override + public long timestamp() { + if (recordTimestamp == null) { + throw new IllegalStateException("Timestamp must be set before use via setRecordMetadata() or setTimestamp()."); + } + return recordTimestamp; + } + + // mocks ================================================ + + @Override + public void register(final StateStore store, + final StateRestoreCallback stateRestoreCallbackIsIgnoredInMock) { + stateStores.put(store.name(), store); + } + + @SuppressWarnings("unchecked") + @Override + public S getStateStore(final String name) { + return (S) stateStores.get(name); + } + + @SuppressWarnings("deprecation") // removing #schedule(final long intervalMs,...) will fix this + @Override + public Cancellable schedule(final Duration interval, + final PunctuationType type, + final Punctuator callback) throws IllegalArgumentException { + final long intervalMs = ApiUtils.validateMillisecondDuration(interval, "interval"); + if (intervalMs < 1) { + throw new IllegalArgumentException("The minimum supported scheduling interval is 1 millisecond."); + } + final CapturedPunctuator capturedPunctuator = new CapturedPunctuator(intervalMs, type, callback); + + punctuators.add(capturedPunctuator); + + return capturedPunctuator::cancel; + } + + /** + * Get the punctuators scheduled so far. The returned list is not affected by subsequent calls to {@code schedule(...)}. + * + * @return A list of captured punctuators. + */ + @SuppressWarnings({"WeakerAccess", "unused"}) + public List scheduledPunctuators() { + return new LinkedList<>(punctuators); + } + + @Override + public void forward(final K key, final V value) { + forward(key, value, To.all()); + } + + @Override + public void forward(final K key, final V value, final To to) { + capturedForwards.add( + new CapturedForward( + new KeyValue<>(key, value), + to.timestamp == -1 ? to.withTimestamp(recordTimestamp == null ? -1 : recordTimestamp) : to, + headers + ) + ); + } + + /** + * Get all the forwarded data this context has observed. The returned list will not be + * affected by subsequent interactions with the context. The data in the list is in the same order as the calls to + * {@code forward(...)}. + * + * @return A list of key/value pairs that were previously passed to the context. + */ + public List forwarded() { + return new LinkedList<>(capturedForwards); + } + + /** + * Get all the forwarded data this context has observed for a specific child by name. + * The returned list will not be affected by subsequent interactions with the context. + * The data in the list is in the same order as the calls to {@code forward(...)}. + * + * @param childName The child name to retrieve forwards for + * @return A list of key/value pairs that were previously passed to the context. 
+ */ + @SuppressWarnings({"WeakerAccess", "unused"}) + public List forwarded(final String childName) { + final LinkedList result = new LinkedList<>(); + for (final CapturedForward capture : capturedForwards) { + if (capture.childName() == null || capture.childName().equals(childName)) { + result.add(capture); + } + } + return result; + } + + /** + * Clear the captured forwarded data. + */ + @SuppressWarnings({"WeakerAccess", "unused"}) + public void resetForwards() { + capturedForwards.clear(); + } + + @Override + public void commit() { + committed = true; + } + + /** + * Whether {@link ProcessorContext#commit()} has been called in this context. + * + * @return {@code true} iff {@link ProcessorContext#commit()} has been called in this context since construction or reset. + */ + @SuppressWarnings("WeakerAccess") + public boolean committed() { + return committed; + } + + /** + * Reset the commit capture to {@code false} (whether or not it was previously {@code true}). + */ + @SuppressWarnings({"WeakerAccess", "unused"}) + public void resetCommit() { + committed = false; + } + + @Override + public RecordCollector recordCollector() { + // This interface is assumed by state stores that add change-logging. + // Rather than risk a mysterious ClassCastException during unit tests, throw an explanatory exception. + + throw new UnsupportedOperationException( + "MockProcessorContext does not provide record collection. " + + "For processor unit tests, use an in-memory state store with change-logging disabled. " + + "Alternatively, use the TopologyTestDriver for testing processor/store/topology integration." + ); + } +} diff --git a/streams/test-utils/src/main/java/org/apache/kafka/streams/processor/api/MockProcessorContext.java b/streams/test-utils/src/main/java/org/apache/kafka/streams/processor/api/MockProcessorContext.java new file mode 100644 index 0000000..ffc2940 --- /dev/null +++ b/streams/test-utils/src/main/java/org/apache/kafka/streams/processor/api/MockProcessorContext.java @@ -0,0 +1,498 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.processor.api; + +import org.apache.kafka.common.metrics.MetricConfig; +import org.apache.kafka.common.metrics.Metrics; +import org.apache.kafka.common.metrics.Sensor; +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.StreamsMetrics; +import org.apache.kafka.streams.Topology; +import org.apache.kafka.streams.TopologyTestDriver; +import org.apache.kafka.streams.kstream.Transformer; +import org.apache.kafka.streams.kstream.ValueTransformer; +import org.apache.kafka.streams.processor.Cancellable; +import org.apache.kafka.streams.processor.PunctuationType; +import org.apache.kafka.streams.processor.Punctuator; +import org.apache.kafka.streams.processor.StateRestoreCallback; +import org.apache.kafka.streams.processor.StateStore; +import org.apache.kafka.streams.processor.StateStoreContext; +import org.apache.kafka.streams.processor.TaskId; +import org.apache.kafka.streams.processor.internals.ClientUtils; +import org.apache.kafka.streams.processor.internals.RecordCollector; +import org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl; +import org.apache.kafka.streams.processor.internals.metrics.TaskMetrics; +import org.apache.kafka.streams.state.internals.InMemoryKeyValueStore; + +import java.io.File; +import java.time.Duration; +import java.util.HashMap; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.Properties; + +import static org.apache.kafka.common.utils.Utils.mkEntry; +import static org.apache.kafka.common.utils.Utils.mkMap; +import static org.apache.kafka.common.utils.Utils.mkProperties; + +/** + * {@link MockProcessorContext} is a mock of {@link ProcessorContext} for users to test their {@link Processor}, + * {@link Transformer}, and {@link ValueTransformer} implementations. + *
                + * The tests for this class (org.apache.kafka.streams.MockProcessorContextTest) include several behavioral + * tests that serve as example usage. + *
                + * Note that this class does not take any automated actions (such as firing scheduled punctuators). + * It simply captures any data it witnesses. + * If you require more automated tests, we recommend wrapping your {@link Processor} in a minimal source-processor-sink + * {@link Topology} and using the {@link TopologyTestDriver}. + */ +public class MockProcessorContext implements ProcessorContext, RecordCollector.Supplier { + // Immutable fields ================================================ + private final StreamsMetricsImpl metrics; + private final TaskId taskId; + private final StreamsConfig config; + private final File stateDir; + + // settable record metadata ================================================ + private MockRecordMetadata recordMetadata; + + // mocks ================================================ + private final Map stateStores = new HashMap<>(); + private final List punctuators = new LinkedList<>(); + private final List> capturedForwards = new LinkedList<>(); + private boolean committed = false; + + private static final class MockRecordMetadata implements RecordMetadata { + private final String topic; + private final int partition; + private final long offset; + + private MockRecordMetadata(final String topic, final int partition, final long offset) { + this.topic = topic; + this.partition = partition; + this.offset = offset; + } + + @Override + public String topic() { + return topic; + } + + @Override + public int partition() { + return partition; + } + + @Override + public long offset() { + return offset; + } + } + + /** + * {@link CapturedPunctuator} holds captured punctuators, along with their scheduling information. + */ + public static final class CapturedPunctuator { + private final Duration interval; + private final PunctuationType type; + private final Punctuator punctuator; + private boolean cancelled = false; + + private CapturedPunctuator(final Duration interval, final PunctuationType type, final Punctuator punctuator) { + this.interval = interval; + this.type = type; + this.punctuator = punctuator; + } + + public Duration getInterval() { + return interval; + } + + public PunctuationType getType() { + return type; + } + + public Punctuator getPunctuator() { + return punctuator; + } + + public void cancel() { + cancelled = true; + } + + public boolean cancelled() { + return cancelled; + } + } + + public static final class CapturedForward { + + private final Record record; + private final Optional childName; + + public CapturedForward(final Record record) { + this(record, Optional.empty()); + } + + public CapturedForward(final Record record, final Optional childName) { + this.record = Objects.requireNonNull(record); + this.childName = Objects.requireNonNull(childName); + } + + /** + * The child this data was forwarded to. + * + * @return If present, the child name the record was forwarded to. + * If empty, the forward was a broadcast. + */ + public Optional childName() { + return childName; + } + + /** + * The record that was forwarded. + * + * @return The forwarded record. Not null. 
+ */ + public Record record() { + return record; + } + + @Override + public String toString() { + return "CapturedForward{" + + "record=" + record + + ", childName=" + childName + + '}'; + } + + @Override + public boolean equals(final Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + final CapturedForward that = (CapturedForward) o; + return Objects.equals(record, that.record) && + Objects.equals(childName, that.childName); + } + + @Override + public int hashCode() { + return Objects.hash(record, childName); + } + } + + // constructors ================================================ + + /** + * Create a {@link MockProcessorContext} with dummy {@code config} and {@code taskId} and {@code null} {@code stateDir}. + * Most unit tests using this mock won't need to know the taskId, + * and most unit tests should be able to get by with the + * {@link InMemoryKeyValueStore}, so the stateDir won't matter. + */ + public MockProcessorContext() { + this( + mkProperties(mkMap( + mkEntry(StreamsConfig.APPLICATION_ID_CONFIG, ""), + mkEntry(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, "") + )), + new TaskId(0, 0), + null + ); + } + + /** + * Create a {@link MockProcessorContext} with dummy {@code taskId} and {@code null} {@code stateDir}. + * Most unit tests using this mock won't need to know the taskId, + * and most unit tests should be able to get by with the + * {@link InMemoryKeyValueStore}, so the stateDir won't matter. + * + * @param config a Properties object, used to configure the context and the processor. + */ + public MockProcessorContext(final Properties config) { + this(config, new TaskId(0, 0), null); + } + + /** + * Create a {@link MockProcessorContext} with a specified taskId and null stateDir. + * + * @param config a {@link Properties} object, used to configure the context and the processor. + * @param taskId a {@link TaskId}, which the context makes available via {@link MockProcessorContext#taskId()}. + * @param stateDir a {@link File}, which the context makes available viw {@link MockProcessorContext#stateDir()}. 
+ */ + public MockProcessorContext(final Properties config, final TaskId taskId, final File stateDir) { + final Properties configCopy = new Properties(); + configCopy.putAll(config); + configCopy.putIfAbsent(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, "dummy-bootstrap-host:0"); + configCopy.putIfAbsent(StreamsConfig.APPLICATION_ID_CONFIG, "dummy-mock-app-id"); + final StreamsConfig streamsConfig = new ClientUtils.QuietStreamsConfig(configCopy); + this.taskId = taskId; + this.config = streamsConfig; + this.stateDir = stateDir; + final MetricConfig metricConfig = new MetricConfig(); + metricConfig.recordLevel(Sensor.RecordingLevel.DEBUG); + final String threadId = Thread.currentThread().getName(); + metrics = new StreamsMetricsImpl( + new Metrics(metricConfig), + threadId, + streamsConfig.getString(StreamsConfig.BUILT_IN_METRICS_VERSION_CONFIG), + Time.SYSTEM + ); + TaskMetrics.droppedRecordsSensor(threadId, taskId.toString(), metrics); + } + + @Override + public String applicationId() { + return config.getString(StreamsConfig.APPLICATION_ID_CONFIG); + } + + @Override + public TaskId taskId() { + return taskId; + } + + @Override + public Map appConfigs() { + final Map combined = new HashMap<>(); + combined.putAll(config.originals()); + combined.putAll(config.values()); + return combined; + } + + @Override + public Map appConfigsWithPrefix(final String prefix) { + return config.originalsWithPrefix(prefix); + } + + @Override + public Serde keySerde() { + return config.defaultKeySerde(); + } + + @Override + public Serde valueSerde() { + return config.defaultValueSerde(); + } + + @Override + public File stateDir() { + return Objects.requireNonNull( + stateDir, + "The stateDir constructor argument was needed (probably for a state store) but not supplied. " + + "You can either reconfigure your test so that it doesn't need access to the disk " + + "(such as using an in-memory store), or use the full MockProcessorContext constructor to supply " + + "a non-null stateDir argument." + ); + } + + @Override + public StreamsMetrics metrics() { + return metrics; + } + + // settable record metadata ================================================ + + /** + * The context exposes these metadata for use in the processor. Normally, they are set by the Kafka Streams framework, + * but for the purpose of driving unit tests, you can set them directly. + * + * @param topic A topic name + * @param partition A partition number + * @param offset A record offset + */ + public void setRecordMetadata(final String topic, + final int partition, + final long offset) { + recordMetadata = new MockRecordMetadata(topic, partition, offset); + } + + @Override + public Optional recordMetadata() { + return Optional.ofNullable(recordMetadata); + } + + // mocks ================================================ + + @SuppressWarnings("unchecked") + @Override + public S getStateStore(final String name) { + return (S) stateStores.get(name); + } + + public void addStateStore(final S stateStore) { + stateStores.put(stateStore.name(), stateStore); + } + + @Override + public Cancellable schedule(final Duration interval, + final PunctuationType type, + final Punctuator callback) { + final CapturedPunctuator capturedPunctuator = new CapturedPunctuator(interval, type, callback); + + punctuators.add(capturedPunctuator); + + return capturedPunctuator::cancel; + } + + /** + * Get the punctuators scheduled so far. The returned list is not affected by subsequent calls to {@code schedule(...)}. + * + * @return A list of captured punctuators. 
+ */ + public List scheduledPunctuators() { + return new LinkedList<>(punctuators); + } + + @Override + public void forward(final Record record) { + forward(record, null); + } + + @Override + public void forward(final Record record, final String childName) { + capturedForwards.add(new CapturedForward<>(record, Optional.ofNullable(childName))); + } + + /** + * Get all the forwarded data this context has observed. The returned list will not be + * affected by subsequent interactions with the context. The data in the list is in the same order as the calls to + * {@code forward(...)}. + * + * @return A list of records that were previously passed to the context. + */ + public List> forwarded() { + return new LinkedList<>(capturedForwards); + } + + /** + * Get all the forwarded data this context has observed for a specific child by name. + * The returned list will not be affected by subsequent interactions with the context. + * The data in the list is in the same order as the calls to {@code forward(...)}. + * + * @param childName The child name to retrieve forwards for + * @return A list of records that were previously passed to the context. + */ + public List> forwarded(final String childName) { + final LinkedList> result = new LinkedList<>(); + for (final CapturedForward capture : capturedForwards) { + if (!capture.childName().isPresent() || capture.childName().equals(Optional.of(childName))) { + result.add(capture); + } + } + return result; + } + + /** + * Clear the captured forwarded data. + */ + public void resetForwards() { + capturedForwards.clear(); + } + + @Override + public void commit() { + committed = true; + } + + /** + * Whether {@link ProcessorContext#commit()} has been called in this context. + * + * @return {@code true} iff {@link ProcessorContext#commit()} has been called in this context since construction or reset. + */ + public boolean committed() { + return committed; + } + + /** + * Reset the commit capture to {@code false} (whether or not it was previously {@code true}). + */ + public void resetCommit() { + committed = false; + } + + @Override + public RecordCollector recordCollector() { + // This interface is assumed by state stores that add change-logging. + // Rather than risk a mysterious ClassCastException during unit tests, throw an explanatory exception. + + throw new UnsupportedOperationException( + "MockProcessorContext does not provide record collection. " + + "For processor unit tests, use an in-memory state store with change-logging disabled. " + + "Alternatively, use the TopologyTestDriver for testing processor/store/topology integration." + ); + } + + /** + * Used to get a {@link StateStoreContext} for use with + * {@link StateStore#init(StateStoreContext, StateStore)} + * if you need to initialize a store for your tests. + * @return a {@link StateStoreContext} that delegates to this ProcessorContext. 
+     */
+    public StateStoreContext getStateStoreContext() {
+        return new StateStoreContext() {
+            @Override
+            public String applicationId() {
+                return MockProcessorContext.this.applicationId();
+            }
+
+            @Override
+            public TaskId taskId() {
+                return MockProcessorContext.this.taskId();
+            }
+
+            @Override
+            public Serde<?> keySerde() {
+                return MockProcessorContext.this.keySerde();
+            }
+
+            @Override
+            public Serde<?> valueSerde() {
+                return MockProcessorContext.this.valueSerde();
+            }
+
+            @Override
+            public File stateDir() {
+                return MockProcessorContext.this.stateDir();
+            }
+
+            @Override
+            public StreamsMetrics metrics() {
+                return MockProcessorContext.this.metrics();
+            }
+
+            @Override
+            public void register(final StateStore store, final StateRestoreCallback stateRestoreCallback) {
+                stateStores.put(store.name(), store);
+            }
+
+            @Override
+            public Map<String, Object> appConfigs() {
+                return MockProcessorContext.this.appConfigs();
+            }
+
+            @Override
+            public Map<String, Object> appConfigsWithPrefix(final String prefix) {
+                return MockProcessorContext.this.appConfigsWithPrefix(prefix);
+            }
+        };
+    }
+}
diff --git a/streams/test-utils/src/main/java/org/apache/kafka/streams/test/TestRecord.java b/streams/test-utils/src/main/java/org/apache/kafka/streams/test/TestRecord.java
new file mode 100644
index 0000000..63f6921
--- /dev/null
+++ b/streams/test-utils/src/main/java/org/apache/kafka/streams/test/TestRecord.java
@@ -0,0 +1,237 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.streams.test;
+
+import org.apache.kafka.clients.consumer.ConsumerRecord;
+import org.apache.kafka.clients.producer.ProducerRecord;
+import org.apache.kafka.common.header.Headers;
+import org.apache.kafka.common.header.internals.RecordHeaders;
+import org.apache.kafka.streams.TestInputTopic;
+import org.apache.kafka.streams.TopologyTestDriver;
+
+import java.time.Instant;
+import java.util.Objects;
+import java.util.StringJoiner;
+
+/**
+ * A key/value pair, including timestamp and record headers, to be sent to or received from {@link TopologyTestDriver}.
+ * If a record does not contain a timestamp,
+ * {@link TestInputTopic} will auto-advance its time when the record is piped.
+ */
+public class TestRecord<K, V> {
+    private final Headers headers;
+    private final K key;
+    private final V value;
+    private final Instant recordTime;
+
+    /**
+     * Creates a record.
+     *
+     * @param key The key that will be included in the record
+     * @param value The value of the record
+     * @param headers the record headers that will be included in the record
+     * @param recordTime The timestamp of the record.
+     */
+    public TestRecord(final K key, final V value, final Headers headers, final Instant recordTime) {
+        this.key = key;
+        this.value = value;
+        this.recordTime = recordTime;
+        this.headers = new RecordHeaders(headers);
+    }
+
+    /**
+     * Creates a record.
+     *
+     * @param key The key that will be included in the record
+     * @param value The value of the record
+     * @param headers the record headers that will be included in the record
+     * @param timestampMs The timestamp of the record, in milliseconds since the beginning of the epoch.
+     */
+    public TestRecord(final K key, final V value, final Headers headers, final Long timestampMs) {
+        if (timestampMs != null) {
+            if (timestampMs < 0) {
+                throw new IllegalArgumentException(
+                    String.format("Invalid timestamp: %d. Timestamp should always be non-negative or null.", timestampMs));
+            }
+            this.recordTime = Instant.ofEpochMilli(timestampMs);
+        } else {
+            this.recordTime = null;
+        }
+        this.key = key;
+        this.value = value;
+        this.headers = new RecordHeaders(headers);
+    }
+
+    /**
+     * Creates a record.
+     *
+     * @param key The key of the record
+     * @param value The value of the record
+     * @param recordTime The timestamp of the record as Instant.
+     */
+    public TestRecord(final K key, final V value, final Instant recordTime) {
+        this(key, value, null, recordTime);
+    }
+
+    /**
+     * Creates a record.
+     *
+     * @param key The key of the record
+     * @param value The value of the record
+     * @param headers The record headers that will be included in the record
+     */
+    public TestRecord(final K key, final V value, final Headers headers) {
+        this.key = key;
+        this.value = value;
+        this.headers = new RecordHeaders(headers);
+        this.recordTime = null;
+    }
+
+    /**
+     * Creates a record.
+     *
+     * @param key The key of the record
+     * @param value The value of the record
+     */
+    public TestRecord(final K key, final V value) {
+        this.key = key;
+        this.value = value;
+        this.headers = new RecordHeaders();
+        this.recordTime = null;
+    }
+
+    /**
+     * Create a record with {@code null} key.
+     *
+     * @param value The value of the record
+     */
+    public TestRecord(final V value) {
+        this(null, value);
+    }
+
+    /**
+     * Create a {@code TestRecord} from a {@link ConsumerRecord}.
+     *
+     * @param record The record contents
+     */
+    public TestRecord(final ConsumerRecord<K, V> record) {
+        Objects.requireNonNull(record);
+        this.key = record.key();
+        this.value = record.value();
+        this.headers = record.headers();
+        this.recordTime = Instant.ofEpochMilli(record.timestamp());
+    }
+
+    /**
+     * Create a {@code TestRecord} from a {@link ProducerRecord}.
+     *
+     * @param record The record contents
+     */
+    public TestRecord(final ProducerRecord<K, V> record) {
+        Objects.requireNonNull(record);
+        this.key = record.key();
+        this.value = record.value();
+        this.headers = record.headers();
+        this.recordTime = Instant.ofEpochMilli(record.timestamp());
+    }
+
+    /**
+     * @return The headers.
+     */
+    public Headers headers() {
+        return headers;
+    }
+
+    /**
+     * @return The key (or {@code null} if no key is specified).
+     */
+    public K key() {
+        return key;
+    }
+
+    /**
+     * @return The value.
+     */
+    public V value() {
+        return value;
+    }
+
+    /**
+     * @return The timestamp, which is in milliseconds since epoch.
+     */
+    public Long timestamp() {
+        return this.recordTime == null ? null : this.recordTime.toEpochMilli();
+    }
+
+    /**
+     * @return The headers.
+     */
+    public Headers getHeaders() {
+        return headers;
+    }
+
+    /**
+     * @return The key (or null if no key is specified)
+     */
+    public K getKey() {
+        return key;
+    }
+
+    /**
+     * @return The value.
+ */ + public V getValue() { + return value; + } + + /** + * @return The timestamp. + */ + public Instant getRecordTime() { + return recordTime; + } + + @Override + public String toString() { + return new StringJoiner(", ", TestRecord.class.getSimpleName() + "[", "]") + .add("key=" + key) + .add("value=" + value) + .add("headers=" + headers) + .add("recordTime=" + recordTime) + .toString(); + } + + @Override + public boolean equals(final Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + final TestRecord that = (TestRecord) o; + return Objects.equals(headers, that.headers) && + Objects.equals(key, that.key) && + Objects.equals(value, that.value) && + Objects.equals(recordTime, that.recordTime); + } + + @Override + public int hashCode() { + return Objects.hash(headers, key, value, recordTime); + } +} diff --git a/streams/test-utils/src/test/java/org/apache/kafka/streams/KeyValueStoreFacadeTest.java b/streams/test-utils/src/test/java/org/apache/kafka/streams/KeyValueStoreFacadeTest.java new file mode 100644 index 0000000..b79a2de --- /dev/null +++ b/streams/test-utils/src/test/java/org/apache/kafka/streams/KeyValueStoreFacadeTest.java @@ -0,0 +1,149 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams; + +import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.streams.TopologyTestDriver.KeyValueStoreFacade; +import org.apache.kafka.streams.processor.ProcessorContext; +import org.apache.kafka.streams.processor.StateStore; +import org.apache.kafka.streams.processor.StateStoreContext; +import org.apache.kafka.streams.state.TimestampedKeyValueStore; +import org.apache.kafka.streams.state.ValueAndTimestamp; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import static java.util.Arrays.asList; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +public class KeyValueStoreFacadeTest { + @SuppressWarnings("unchecked") + private final TimestampedKeyValueStore mockedKeyValueTimestampStore = mock(TimestampedKeyValueStore.class); + + private KeyValueStoreFacade keyValueStoreFacade; + + @BeforeEach + public void setup() { + keyValueStoreFacade = new KeyValueStoreFacade<>(mockedKeyValueTimestampStore); + } + + @SuppressWarnings("deprecation") // test of deprecated method + @Test + public void shouldForwardDeprecatedInit() { + final ProcessorContext context = mock(ProcessorContext.class); + final StateStore store = mock(StateStore.class); + + keyValueStoreFacade.init(context, store); + verify(mockedKeyValueTimestampStore).init(context, store); + } + + @Test + public void shouldForwardInit() { + final StateStoreContext context = mock(StateStoreContext.class); + final StateStore store = mock(StateStore.class); + + keyValueStoreFacade.init(context, store); + verify(mockedKeyValueTimestampStore).init(context, store); + } + + @Test + public void shouldPutWithUnknownTimestamp() { + keyValueStoreFacade.put("key", "value"); + verify(mockedKeyValueTimestampStore) + .put("key", ValueAndTimestamp.make("value", ConsumerRecord.NO_TIMESTAMP)); + } + + @Test + public void shouldPutIfAbsentWithUnknownTimestamp() { + doReturn(null, ValueAndTimestamp.make("oldValue", 42L)) + .when(mockedKeyValueTimestampStore) + .putIfAbsent("key", ValueAndTimestamp.make("value", ConsumerRecord.NO_TIMESTAMP)); + + assertNull(keyValueStoreFacade.putIfAbsent("key", "value")); + assertThat(keyValueStoreFacade.putIfAbsent("key", "value"), is("oldValue")); + verify(mockedKeyValueTimestampStore, times(2)) + .putIfAbsent("key", ValueAndTimestamp.make("value", ConsumerRecord.NO_TIMESTAMP)); + } + + @Test + public void shouldPutAllWithUnknownTimestamp() { + keyValueStoreFacade.putAll(asList( + KeyValue.pair("key1", "value1"), + KeyValue.pair("key2", "value2") + )); + verify(mockedKeyValueTimestampStore) + .put("key1", ValueAndTimestamp.make("value1", ConsumerRecord.NO_TIMESTAMP)); + verify(mockedKeyValueTimestampStore) + .put("key2", ValueAndTimestamp.make("value2", ConsumerRecord.NO_TIMESTAMP)); + } + + @Test + public void shouldDeleteAndReturnPlainValue() { + doReturn(null, ValueAndTimestamp.make("oldValue", 42L)) + .when(mockedKeyValueTimestampStore).delete("key"); + + assertNull(keyValueStoreFacade.delete("key")); + assertThat(keyValueStoreFacade.delete("key"), is("oldValue")); + verify(mockedKeyValueTimestampStore, times(2)).delete("key"); + } + + @Test + public void shouldForwardFlush() { + keyValueStoreFacade.flush(); + 
verify(mockedKeyValueTimestampStore).flush(); + } + + @Test + public void shouldForwardClose() { + keyValueStoreFacade.close(); + verify(mockedKeyValueTimestampStore).close(); + } + + @Test + public void shouldReturnName() { + when(mockedKeyValueTimestampStore.name()).thenReturn("name"); + + assertThat(keyValueStoreFacade.name(), is("name")); + verify(mockedKeyValueTimestampStore).name(); + } + + @Test + public void shouldReturnIsPersistent() { + when(mockedKeyValueTimestampStore.persistent()) + .thenReturn(true, false); + + assertThat(keyValueStoreFacade.persistent(), is(true)); + assertThat(keyValueStoreFacade.persistent(), is(false)); + verify(mockedKeyValueTimestampStore, times(2)).persistent(); + } + + @Test + public void shouldReturnIsOpen() { + when(mockedKeyValueTimestampStore.isOpen()) + .thenReturn(true, false); + + assertThat(keyValueStoreFacade.isOpen(), is(true)); + assertThat(keyValueStoreFacade.isOpen(), is(false)); + verify(mockedKeyValueTimestampStore, times(2)).isOpen(); + } +} diff --git a/streams/test-utils/src/test/java/org/apache/kafka/streams/MockProcessorContextTest.java b/streams/test-utils/src/test/java/org/apache/kafka/streams/MockProcessorContextTest.java new file mode 100644 index 0000000..e76cb4f --- /dev/null +++ b/streams/test-utils/src/test/java/org/apache/kafka/streams/MockProcessorContextTest.java @@ -0,0 +1,366 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams; + +import org.apache.kafka.common.header.internals.RecordHeaders; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.streams.processor.MockProcessorContext; +import org.apache.kafka.streams.processor.MockProcessorContext.CapturedForward; +import org.apache.kafka.streams.processor.ProcessorContext; +import org.apache.kafka.streams.processor.PunctuationType; +import org.apache.kafka.streams.processor.Punctuator; +import org.apache.kafka.streams.processor.TaskId; +import org.apache.kafka.streams.processor.To; +import org.apache.kafka.streams.state.KeyValueStore; + +import org.apache.kafka.streams.state.StoreBuilder; +import org.apache.kafka.streams.state.Stores; +import org.junit.jupiter.api.Test; + +import java.io.File; +import java.time.Duration; +import java.util.Iterator; +import java.util.Properties; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; + +@SuppressWarnings("deprecation") // this is a test of a deprecated API +public class MockProcessorContextTest { + @Test + public void shouldCaptureOutputRecords() { + final org.apache.kafka.streams.processor.AbstractProcessor processor = new org.apache.kafka.streams.processor.AbstractProcessor() { + @Override + public void process(final String key, final Long value) { + context().forward(key + value, key.length() + value); + } + }; + + final MockProcessorContext context = new MockProcessorContext(); + processor.init(context); + + processor.process("foo", 5L); + processor.process("barbaz", 50L); + + final Iterator forwarded = context.forwarded().iterator(); + assertEquals(new KeyValue<>("foo5", 8L), forwarded.next().keyValue()); + assertEquals(new KeyValue<>("barbaz50", 56L), forwarded.next().keyValue()); + assertFalse(forwarded.hasNext()); + + context.resetForwards(); + + assertEquals(0, context.forwarded().size()); + } + + @Test + public void shouldCaptureOutputRecordsUsingTo() { + final org.apache.kafka.streams.processor.AbstractProcessor processor = new org.apache.kafka.streams.processor.AbstractProcessor() { + @Override + public void process(final String key, final Long value) { + context().forward(key + value, key.length() + value, To.all()); + } + }; + + final MockProcessorContext context = new MockProcessorContext(); + + processor.init(context); + + processor.process("foo", 5L); + processor.process("barbaz", 50L); + + final Iterator forwarded = context.forwarded().iterator(); + assertEquals(new KeyValue<>("foo5", 8L), forwarded.next().keyValue()); + assertEquals(new KeyValue<>("barbaz50", 56L), forwarded.next().keyValue()); + assertFalse(forwarded.hasNext()); + + context.resetForwards(); + + assertEquals(0, context.forwarded().size()); + } + + @Test + public void shouldCaptureRecordsOutputToChildByName() { + final org.apache.kafka.streams.processor.AbstractProcessor processor = new org.apache.kafka.streams.processor.AbstractProcessor() { + private int count = 0; + + @Override + public void process(final String key, final Long value) { + if (count == 0) { + context().forward("start", -1L, To.all()); // broadcast + } + final To toChild = count % 2 == 0 ? 
To.child("george") : To.child("pete"); + context().forward(key + value, key.length() + value, toChild); + count++; + } + }; + + final MockProcessorContext context = new MockProcessorContext(); + + processor.init(context); + + processor.process("foo", 5L); + processor.process("barbaz", 50L); + + { + final Iterator forwarded = context.forwarded().iterator(); + + final CapturedForward forward1 = forwarded.next(); + assertEquals(new KeyValue<>("start", -1L), forward1.keyValue()); + assertNull(forward1.childName()); + + final CapturedForward forward2 = forwarded.next(); + assertEquals(new KeyValue<>("foo5", 8L), forward2.keyValue()); + assertEquals("george", forward2.childName()); + + final CapturedForward forward3 = forwarded.next(); + assertEquals(new KeyValue<>("barbaz50", 56L), forward3.keyValue()); + assertEquals("pete", forward3.childName()); + + assertFalse(forwarded.hasNext()); + } + + { + final Iterator forwarded = context.forwarded("george").iterator(); + assertEquals(new KeyValue<>("start", -1L), forwarded.next().keyValue()); + assertEquals(new KeyValue<>("foo5", 8L), forwarded.next().keyValue()); + assertFalse(forwarded.hasNext()); + } + + { + final Iterator forwarded = context.forwarded("pete").iterator(); + assertEquals(new KeyValue<>("start", -1L), forwarded.next().keyValue()); + assertEquals(new KeyValue<>("barbaz50", 56L), forwarded.next().keyValue()); + assertFalse(forwarded.hasNext()); + } + + { + final Iterator forwarded = context.forwarded("steve").iterator(); + assertEquals(new KeyValue<>("start", -1L), forwarded.next().keyValue()); + assertFalse(forwarded.hasNext()); + } + } + + @Test + public void shouldCaptureCommitsAndAllowReset() { + final org.apache.kafka.streams.processor.AbstractProcessor processor = new org.apache.kafka.streams.processor.AbstractProcessor() { + private int count = 0; + + @Override + public void process(final String key, final Long value) { + if (++count > 2) { + context().commit(); + } + } + }; + + final MockProcessorContext context = new MockProcessorContext(); + + processor.init(context); + + processor.process("foo", 5L); + processor.process("barbaz", 50L); + + assertFalse(context.committed()); + + processor.process("foobar", 500L); + + assertTrue(context.committed()); + + context.resetCommit(); + + assertFalse(context.committed()); + } + + @Test + public void shouldStoreAndReturnStateStores() { + final org.apache.kafka.streams.processor.AbstractProcessor processor = new org.apache.kafka.streams.processor.AbstractProcessor() { + @Override + public void process(final String key, final Long value) { + final KeyValueStore stateStore = context().getStateStore("my-state"); + stateStore.put(key, (stateStore.get(key) == null ? 0 : stateStore.get(key)) + value); + stateStore.put("all", (stateStore.get("all") == null ? 
0 : stateStore.get("all")) + value); + } + }; + + final MockProcessorContext context = new MockProcessorContext(); + + final StoreBuilder> storeBuilder = Stores.keyValueStoreBuilder( + Stores.inMemoryKeyValueStore("my-state"), + Serdes.String(), + Serdes.Long()).withLoggingDisabled(); + + final KeyValueStore store = storeBuilder.build(); + + store.init(context, store); + + processor.init(context); + + processor.process("foo", 5L); + processor.process("bar", 50L); + + assertEquals(5L, (long) store.get("foo")); + assertEquals(50L, (long) store.get("bar")); + assertEquals(55L, (long) store.get("all")); + } + + @Test + public void shouldCaptureApplicationAndRecordMetadata() { + final Properties config = new Properties(); + config.put(StreamsConfig.APPLICATION_ID_CONFIG, "testMetadata"); + + final org.apache.kafka.streams.processor.AbstractProcessor processor = new org.apache.kafka.streams.processor.AbstractProcessor() { + @Override + public void process(final String key, final Object value) { + context().forward("appId", context().applicationId()); + context().forward("taskId", context().taskId()); + + context().forward("topic", context().topic()); + context().forward("partition", context().partition()); + context().forward("offset", context().offset()); + context().forward("timestamp", context().timestamp()); + + context().forward("key", key); + context().forward("value", value); + } + }; + + final MockProcessorContext context = new MockProcessorContext(config); + processor.init(context); + + try { + processor.process("foo", 5L); + fail("Should have thrown an exception."); + } catch (final IllegalStateException expected) { + // expected, since the record metadata isn't initialized + } + + context.resetForwards(); + context.setRecordMetadata("t1", 0, 0L, new RecordHeaders(), 0L); + + { + processor.process("foo", 5L); + final Iterator forwarded = context.forwarded().iterator(); + assertEquals(new KeyValue<>("appId", "testMetadata"), forwarded.next().keyValue()); + assertEquals(new KeyValue<>("taskId", new TaskId(0, 0)), forwarded.next().keyValue()); + assertEquals(new KeyValue<>("topic", "t1"), forwarded.next().keyValue()); + assertEquals(new KeyValue<>("partition", 0), forwarded.next().keyValue()); + assertEquals(new KeyValue<>("offset", 0L), forwarded.next().keyValue()); + assertEquals(new KeyValue<>("timestamp", 0L), forwarded.next().keyValue()); + assertEquals(new KeyValue<>("key", "foo"), forwarded.next().keyValue()); + assertEquals(new KeyValue<>("value", 5L), forwarded.next().keyValue()); + } + + context.resetForwards(); + + // record metadata should be "sticky" + context.setOffset(1L); + context.setRecordTimestamp(10L); + context.setCurrentSystemTimeMs(20L); + context.setCurrentStreamTimeMs(30L); + + { + processor.process("bar", 50L); + final Iterator forwarded = context.forwarded().iterator(); + assertEquals(new KeyValue<>("appId", "testMetadata"), forwarded.next().keyValue()); + assertEquals(new KeyValue<>("taskId", new TaskId(0, 0)), forwarded.next().keyValue()); + assertEquals(new KeyValue<>("topic", "t1"), forwarded.next().keyValue()); + assertEquals(new KeyValue<>("partition", 0), forwarded.next().keyValue()); + assertEquals(new KeyValue<>("offset", 1L), forwarded.next().keyValue()); + assertEquals(new KeyValue<>("timestamp", 10L), forwarded.next().keyValue()); + assertEquals(new KeyValue<>("key", "bar"), forwarded.next().keyValue()); + assertEquals(new KeyValue<>("value", 50L), forwarded.next().keyValue()); + assertEquals(20L, context.currentSystemTimeMs()); + assertEquals(30L, 
context.currentStreamTimeMs()); + } + + context.resetForwards(); + // record metadata should be "sticky" + context.setTopic("t2"); + context.setPartition(30); + + { + processor.process("baz", 500L); + final Iterator forwarded = context.forwarded().iterator(); + assertEquals(new KeyValue<>("appId", "testMetadata"), forwarded.next().keyValue()); + assertEquals(new KeyValue<>("taskId", new TaskId(0, 0)), forwarded.next().keyValue()); + assertEquals(new KeyValue<>("topic", "t2"), forwarded.next().keyValue()); + assertEquals(new KeyValue<>("partition", 30), forwarded.next().keyValue()); + assertEquals(new KeyValue<>("offset", 1L), forwarded.next().keyValue()); + assertEquals(new KeyValue<>("timestamp", 10L), forwarded.next().keyValue()); + assertEquals(new KeyValue<>("key", "baz"), forwarded.next().keyValue()); + assertEquals(new KeyValue<>("value", 500L), forwarded.next().keyValue()); + } + } + + @Test + public void shouldCapturePunctuator() { + final org.apache.kafka.streams.processor.Processor processor = new org.apache.kafka.streams.processor.Processor() { + @Override + public void init(final ProcessorContext context) { + context.schedule( + Duration.ofSeconds(1L), + PunctuationType.WALL_CLOCK_TIME, + timestamp -> context.commit() + ); + } + + @Override + public void process(final String key, final Long value) { + } + + @Override + public void close() { + } + }; + + final MockProcessorContext context = new MockProcessorContext(); + + processor.init(context); + + final MockProcessorContext.CapturedPunctuator capturedPunctuator = context.scheduledPunctuators().get(0); + assertEquals(1000L, capturedPunctuator.getIntervalMs()); + assertEquals(PunctuationType.WALL_CLOCK_TIME, capturedPunctuator.getType()); + assertFalse(capturedPunctuator.cancelled()); + + final Punctuator punctuator = capturedPunctuator.getPunctuator(); + assertFalse(context.committed()); + punctuator.punctuate(1234L); + assertTrue(context.committed()); + } + + @Test + public void fullConstructorShouldSetAllExpectedAttributes() { + final Properties config = new Properties(); + config.put(StreamsConfig.APPLICATION_ID_CONFIG, "testFullConstructor"); + config.put(StreamsConfig.DEFAULT_KEY_SERDE_CLASS_CONFIG, Serdes.String().getClass()); + config.put(StreamsConfig.DEFAULT_VALUE_SERDE_CLASS_CONFIG, Serdes.Long().getClass()); + + final File dummyFile = new File(""); + final MockProcessorContext context = new MockProcessorContext(config, new TaskId(1, 1), dummyFile); + + assertEquals("testFullConstructor", context.applicationId()); + assertEquals(new TaskId(1, 1), context.taskId()); + assertEquals("testFullConstructor", context.appConfigs().get(StreamsConfig.APPLICATION_ID_CONFIG)); + assertEquals("testFullConstructor", context.appConfigsWithPrefix("application.").get("id")); + assertEquals(Serdes.String().getClass(), context.keySerde().getClass()); + assertEquals(Serdes.Long().getClass(), context.valueSerde().getClass()); + assertEquals(dummyFile, context.stateDir()); + } +} diff --git a/streams/test-utils/src/test/java/org/apache/kafka/streams/MockTimeTest.java b/streams/test-utils/src/test/java/org/apache/kafka/streams/MockTimeTest.java new file mode 100644 index 0000000..47b625c --- /dev/null +++ b/streams/test-utils/src/test/java/org/apache/kafka/streams/MockTimeTest.java @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams; + +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +public class MockTimeTest { + + @Test + public void shouldSetStartTime() { + final TopologyTestDriver.MockTime time = new TopologyTestDriver.MockTime(42L); + assertEquals(42L, time.milliseconds()); + assertEquals(42L * 1000L * 1000L, time.nanoseconds()); + } + + @Test + public void shouldGetNanosAsMillis() { + final TopologyTestDriver.MockTime time = new TopologyTestDriver.MockTime(42L); + assertEquals(42L, time.hiResClockMs()); + } + + @Test + public void shouldNotAllowNegativeSleep() { + assertThrows(IllegalArgumentException.class, + () -> new TopologyTestDriver.MockTime(42).sleep(-1L)); + } + + @Test + public void shouldAdvanceTimeOnSleep() { + final TopologyTestDriver.MockTime time = new TopologyTestDriver.MockTime(42L); + + assertEquals(42L, time.milliseconds()); + time.sleep(1L); + assertEquals(43L, time.milliseconds()); + time.sleep(0L); + assertEquals(43L, time.milliseconds()); + time.sleep(3L); + assertEquals(46L, time.milliseconds()); + } + +} diff --git a/streams/test-utils/src/test/java/org/apache/kafka/streams/TestTopicsTest.java b/streams/test-utils/src/test/java/org/apache/kafka/streams/TestTopicsTest.java new file mode 100644 index 0000000..30e16d7 --- /dev/null +++ b/streams/test-utils/src/test/java/org/apache/kafka/streams/TestTopicsTest.java @@ -0,0 +1,459 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams; + +import org.apache.kafka.common.errors.SerializationException; +import org.apache.kafka.common.header.Header; +import org.apache.kafka.common.header.Headers; +import org.apache.kafka.common.header.internals.RecordHeader; +import org.apache.kafka.common.header.internals.RecordHeaders; +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.streams.errors.StreamsException; +import org.apache.kafka.streams.kstream.Consumed; +import org.apache.kafka.streams.kstream.KStream; +import org.apache.kafka.streams.kstream.Produced; +import org.apache.kafka.streams.test.TestRecord; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.time.Duration; +import java.time.Instant; +import java.util.Arrays; +import java.util.HashMap; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.NoSuchElementException; +import java.util.Properties; + +import static org.hamcrest.CoreMatchers.containsString; +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.CoreMatchers.hasItems; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.allOf; +import static org.hamcrest.Matchers.hasProperty; +import static org.junit.jupiter.api.Assertions.assertThrows; + +public class TestTopicsTest { + private static final Logger log = LoggerFactory.getLogger(TestTopicsTest.class); + + private final static String INPUT_TOPIC = "input"; + private final static String OUTPUT_TOPIC = "output1"; + private final static String INPUT_TOPIC_MAP = OUTPUT_TOPIC; + private final static String OUTPUT_TOPIC_MAP = "output2"; + + private TopologyTestDriver testDriver; + private final Serde stringSerde = new Serdes.StringSerde(); + private final Serde longSerde = new Serdes.LongSerde(); + + private final Instant testBaseTime = Instant.parse("2019-06-01T10:00:00Z"); + + @BeforeEach + public void setup() { + final StreamsBuilder builder = new StreamsBuilder(); + //Create Actual Stream Processing pipeline + builder.stream(INPUT_TOPIC).to(OUTPUT_TOPIC); + final KStream source = builder.stream(INPUT_TOPIC_MAP, Consumed.with(longSerde, stringSerde)); + final KStream mapped = source.map((key, value) -> new KeyValue<>(value, key)); + mapped.to(OUTPUT_TOPIC_MAP, Produced.with(stringSerde, longSerde)); + + final Properties properties = new Properties(); + properties.put(StreamsConfig.DEFAULT_KEY_SERDE_CLASS_CONFIG, Serdes.ByteArraySerde.class); + properties.put(StreamsConfig.DEFAULT_VALUE_SERDE_CLASS_CONFIG, Serdes.ByteArraySerde.class); + testDriver = new TopologyTestDriver(builder.build(), properties); + } + + @AfterEach + public void tearDown() { + try { + testDriver.close(); + } catch (final RuntimeException e) { + // https://issues.apache.org/jira/browse/KAFKA-6647 causes exception when executed in Windows, ignoring it + // Logged stacktrace cannot be avoided + log.warn("Ignoring exception, test failing in Windows due this exception: {}", e.getLocalizedMessage()); + } + } + + @Test + public void testValue() { + final TestInputTopic inputTopic = + testDriver.createInputTopic(INPUT_TOPIC, stringSerde.serializer(), stringSerde.serializer()); + final TestOutputTopic outputTopic = + testDriver.createOutputTopic(OUTPUT_TOPIC, stringSerde.deserializer(), 
stringSerde.deserializer()); + //Feed word "Hello" to inputTopic and no kafka key, timestamp is irrelevant in this case + inputTopic.pipeInput("Hello"); + assertThat(outputTopic.readValue(), equalTo("Hello")); + //No more output in topic + assertThat(outputTopic.isEmpty(), is(true)); + } + + @Test + public void testValueList() { + final TestInputTopic inputTopic = + testDriver.createInputTopic(INPUT_TOPIC, stringSerde.serializer(), stringSerde.serializer()); + final TestOutputTopic outputTopic = + testDriver.createOutputTopic(OUTPUT_TOPIC, stringSerde.deserializer(), stringSerde.deserializer()); + final List inputList = Arrays.asList("This", "is", "an", "example"); + //Feed list of words to inputTopic and no kafka key, timestamp is irrelevant in this case + inputTopic.pipeValueList(inputList); + final List output = outputTopic.readValuesToList(); + assertThat(output, hasItems("This", "is", "an", "example")); + assertThat(output, is(equalTo(inputList))); + } + + @Test + public void testKeyValue() { + final TestInputTopic inputTopic = + testDriver.createInputTopic(INPUT_TOPIC, longSerde.serializer(), stringSerde.serializer()); + final TestOutputTopic outputTopic = + testDriver.createOutputTopic(OUTPUT_TOPIC, longSerde.deserializer(), stringSerde.deserializer()); + inputTopic.pipeInput(1L, "Hello"); + assertThat(outputTopic.readKeyValue(), equalTo(new KeyValue<>(1L, "Hello"))); + assertThat(outputTopic.isEmpty(), is(true)); + } + + @Test + public void testKeyValueList() { + final TestInputTopic inputTopic = + testDriver.createInputTopic(INPUT_TOPIC_MAP, longSerde.serializer(), stringSerde.serializer()); + final TestOutputTopic outputTopic = + testDriver.createOutputTopic(OUTPUT_TOPIC_MAP, stringSerde.deserializer(), longSerde.deserializer()); + final List inputList = Arrays.asList("This", "is", "an", "example"); + final List> input = new LinkedList<>(); + final List> expected = new LinkedList<>(); + long i = 0; + for (final String s : inputList) { + input.add(new KeyValue<>(i, s)); + expected.add(new KeyValue<>(s, i)); + i++; + } + inputTopic.pipeKeyValueList(input); + final List> output = outputTopic.readKeyValuesToList(); + assertThat(output, is(equalTo(expected))); + } + + @Test + public void testKeyValuesToMap() { + final TestInputTopic inputTopic = + testDriver.createInputTopic(INPUT_TOPIC_MAP, longSerde.serializer(), stringSerde.serializer()); + final TestOutputTopic outputTopic = + testDriver.createOutputTopic(OUTPUT_TOPIC_MAP, stringSerde.deserializer(), longSerde.deserializer()); + final List inputList = Arrays.asList("This", "is", "an", "example"); + final List> input = new LinkedList<>(); + final Map expected = new HashMap<>(); + long i = 0; + for (final String s : inputList) { + input.add(new KeyValue<>(i, s)); + expected.put(s, i); + i++; + } + inputTopic.pipeKeyValueList(input); + final Map output = outputTopic.readKeyValuesToMap(); + assertThat(output, is(equalTo(expected))); + } + + @Test + public void testKeyValuesToMapWithNull() { + final TestInputTopic inputTopic = + testDriver.createInputTopic(INPUT_TOPIC, longSerde.serializer(), stringSerde.serializer()); + final TestOutputTopic outputTopic = + testDriver.createOutputTopic(OUTPUT_TOPIC, longSerde.deserializer(), stringSerde.deserializer()); + inputTopic.pipeInput("value"); + assertThrows(IllegalStateException.class, outputTopic::readKeyValuesToMap); + } + + + @Test + public void testKeyValueListDuration() { + final TestInputTopic inputTopic = + testDriver.createInputTopic(INPUT_TOPIC_MAP, longSerde.serializer(), 
stringSerde.serializer()); + final TestOutputTopic outputTopic = + testDriver.createOutputTopic(OUTPUT_TOPIC_MAP, stringSerde.deserializer(), longSerde.deserializer()); + final List inputList = Arrays.asList("This", "is", "an", "example"); + final List> input = new LinkedList<>(); + final List> expected = new LinkedList<>(); + long i = 0; + final Duration advance = Duration.ofSeconds(15); + Instant recordInstant = testBaseTime; + for (final String s : inputList) { + input.add(new KeyValue<>(i, s)); + expected.add(new TestRecord<>(s, i, recordInstant)); + i++; + recordInstant = recordInstant.plus(advance); + } + inputTopic.pipeKeyValueList(input, testBaseTime, advance); + final List> output = outputTopic.readRecordsToList(); + assertThat(output, is(equalTo(expected))); + } + + @Test + public void testRecordList() { + final TestInputTopic inputTopic = + testDriver.createInputTopic(INPUT_TOPIC_MAP, longSerde.serializer(), stringSerde.serializer()); + final TestOutputTopic outputTopic = + testDriver.createOutputTopic(OUTPUT_TOPIC_MAP, stringSerde.deserializer(), longSerde.deserializer()); + final List inputList = Arrays.asList("This", "is", "an", "example"); + final List> input = new LinkedList<>(); + final List> expected = new LinkedList<>(); + final Duration advance = Duration.ofSeconds(15); + Instant recordInstant = testBaseTime; + Long i = 0L; + for (final String s : inputList) { + input.add(new TestRecord<>(i, s, recordInstant)); + expected.add(new TestRecord<>(s, i, recordInstant)); + i++; + recordInstant = recordInstant.plus(advance); + } + inputTopic.pipeRecordList(input); + final List> output = outputTopic.readRecordsToList(); + assertThat(output, is(equalTo(expected))); + } + + @Test + public void testTimestamp() { + long baseTime = 3; + final TestInputTopic inputTopic = + testDriver.createInputTopic(INPUT_TOPIC, longSerde.serializer(), stringSerde.serializer()); + final TestOutputTopic outputTopic = + testDriver.createOutputTopic(OUTPUT_TOPIC, longSerde.deserializer(), stringSerde.deserializer()); + inputTopic.pipeInput(null, "Hello", baseTime); + assertThat(outputTopic.readRecord(), is(equalTo(new TestRecord<>(null, "Hello", null, baseTime)))); + + inputTopic.pipeInput(2L, "Kafka", ++baseTime); + assertThat(outputTopic.readRecord(), is(equalTo(new TestRecord<>(2L, "Kafka", null, baseTime)))); + + inputTopic.pipeInput(2L, "Kafka", testBaseTime); + assertThat(outputTopic.readRecord(), is(equalTo(new TestRecord<>(2L, "Kafka", testBaseTime)))); + + final List inputList = Arrays.asList("Advancing", "time"); + //Feed list of words to inputTopic and no kafka key, timestamp advancing from testInstant + final Duration advance = Duration.ofSeconds(15); + final Instant recordInstant = testBaseTime.plus(Duration.ofDays(1)); + inputTopic.pipeValueList(inputList, recordInstant, advance); + assertThat(outputTopic.readRecord(), is(equalTo(new TestRecord<>(null, "Advancing", recordInstant)))); + assertThat(outputTopic.readRecord(), is(equalTo(new TestRecord<>(null, "time", null, recordInstant.plus(advance))))); + } + + @Test + public void testWithHeaders() { + long baseTime = 3; + final Headers headers = new RecordHeaders( + new Header[]{ + new RecordHeader("foo", "value".getBytes()), + new RecordHeader("bar", null), + new RecordHeader("\"A\\u00ea\\u00f1\\u00fcC\"", "value".getBytes()) + }); + final TestInputTopic inputTopic = + testDriver.createInputTopic(INPUT_TOPIC, longSerde.serializer(), stringSerde.serializer()); + final TestOutputTopic outputTopic = + 
testDriver.createOutputTopic(OUTPUT_TOPIC, longSerde.deserializer(), stringSerde.deserializer()); + inputTopic.pipeInput(new TestRecord<>(1L, "Hello", headers)); + assertThat(outputTopic.readRecord(), allOf( + hasProperty("key", equalTo(1L)), + hasProperty("value", equalTo("Hello")), + hasProperty("headers", equalTo(headers)))); + inputTopic.pipeInput(new TestRecord<>(2L, "Kafka", headers, ++baseTime)); + assertThat(outputTopic.readRecord(), is(equalTo(new TestRecord<>(2L, "Kafka", headers, baseTime)))); + } + + @Test + public void testStartTimestamp() { + final Duration advance = Duration.ofSeconds(2); + final TestInputTopic inputTopic = + testDriver.createInputTopic(INPUT_TOPIC, longSerde.serializer(), stringSerde.serializer(), testBaseTime, Duration.ZERO); + final TestOutputTopic outputTopic = + testDriver.createOutputTopic(OUTPUT_TOPIC, longSerde.deserializer(), stringSerde.deserializer()); + inputTopic.pipeInput(1L, "Hello"); + assertThat(outputTopic.readRecord(), is(equalTo(new TestRecord<>(1L, "Hello", testBaseTime)))); + inputTopic.pipeInput(2L, "World"); + assertThat(outputTopic.readRecord(), is(equalTo(new TestRecord<>(2L, "World", null, testBaseTime.toEpochMilli())))); + inputTopic.advanceTime(advance); + inputTopic.pipeInput(3L, "Kafka"); + assertThat(outputTopic.readRecord(), is(equalTo(new TestRecord<>(3L, "Kafka", testBaseTime.plus(advance))))); + } + + @Test + public void testTimestampAutoAdvance() { + final Duration advance = Duration.ofSeconds(2); + final TestInputTopic inputTopic = + testDriver.createInputTopic(INPUT_TOPIC, longSerde.serializer(), stringSerde.serializer(), testBaseTime, advance); + final TestOutputTopic outputTopic = + testDriver.createOutputTopic(OUTPUT_TOPIC, longSerde.deserializer(), stringSerde.deserializer()); + inputTopic.pipeInput("Hello"); + assertThat(outputTopic.readRecord(), is(equalTo(new TestRecord<>(null, "Hello", testBaseTime)))); + inputTopic.pipeInput(2L, "Kafka"); + assertThat(outputTopic.readRecord(), is(equalTo(new TestRecord<>(2L, "Kafka", testBaseTime.plus(advance))))); + } + + + @Test + public void testMultipleTopics() { + final TestInputTopic inputTopic1 = + testDriver.createInputTopic(INPUT_TOPIC, longSerde.serializer(), stringSerde.serializer()); + final TestInputTopic inputTopic2 = + testDriver.createInputTopic(INPUT_TOPIC_MAP, longSerde.serializer(), stringSerde.serializer()); + final TestOutputTopic outputTopic1 = + testDriver.createOutputTopic(OUTPUT_TOPIC, longSerde.deserializer(), stringSerde.deserializer()); + final TestOutputTopic outputTopic2 = + testDriver.createOutputTopic(OUTPUT_TOPIC_MAP, stringSerde.deserializer(), longSerde.deserializer()); + inputTopic1.pipeInput(1L, "Hello"); + assertThat(outputTopic1.readKeyValue(), equalTo(new KeyValue<>(1L, "Hello"))); + assertThat(outputTopic2.readKeyValue(), equalTo(new KeyValue<>("Hello", 1L))); + assertThat(outputTopic1.isEmpty(), is(true)); + assertThat(outputTopic2.isEmpty(), is(true)); + inputTopic2.pipeInput(1L, "Hello"); + //This is not visible in outputTopic1 even it is the same topic + assertThat(outputTopic2.readKeyValue(), equalTo(new KeyValue<>("Hello", 1L))); + assertThat(outputTopic1.isEmpty(), is(true)); + assertThat(outputTopic2.isEmpty(), is(true)); + } + + @Test + public void testNonExistingOutputTopic() { + final TestOutputTopic outputTopic = + testDriver.createOutputTopic("no-exist", longSerde.deserializer(), stringSerde.deserializer()); + assertThrows(NoSuchElementException.class, outputTopic::readRecord, "Uninitialized topic"); + } + + @Test + public 
void testNonUsedOutputTopic() { + final TestOutputTopic outputTopic = + testDriver.createOutputTopic(OUTPUT_TOPIC, longSerde.deserializer(), stringSerde.deserializer()); + assertThrows(NoSuchElementException.class, outputTopic::readRecord, "Uninitialized topic"); + } + + @Test + public void testEmptyTopic() { + final TestInputTopic inputTopic = + testDriver.createInputTopic(INPUT_TOPIC, stringSerde.serializer(), stringSerde.serializer()); + final TestOutputTopic outputTopic = + testDriver.createOutputTopic(OUTPUT_TOPIC, stringSerde.deserializer(), stringSerde.deserializer()); + //Feed word "Hello" to inputTopic and no kafka key, timestamp is irrelevant in this case + inputTopic.pipeInput("Hello"); + assertThat(outputTopic.readValue(), equalTo("Hello")); + //No more output in topic + assertThrows(NoSuchElementException.class, outputTopic::readRecord, "Empty topic"); + } + + @Test + public void testNonExistingInputTopic() { + final TestInputTopic inputTopic = + testDriver.createInputTopic("no-exist", longSerde.serializer(), stringSerde.serializer()); + assertThrows(IllegalArgumentException.class, () -> inputTopic.pipeInput(1L, "Hello"), "Unknown topic"); + } + + @Test + public void shouldNotAllowToCreateTopicWithNullTopicName() { + assertThrows(NullPointerException.class, () -> testDriver.createInputTopic(null, stringSerde.serializer(), stringSerde.serializer())); + } + + @Test + public void shouldNotAllowToCreateWithNullDriver() { + assertThrows(NullPointerException.class, + () -> new TestInputTopic<>(null, INPUT_TOPIC, stringSerde.serializer(), stringSerde.serializer(), Instant.now(), Duration.ZERO)); + } + + + @Test + public void testWrongSerde() { + final TestInputTopic inputTopic = + testDriver.createInputTopic(INPUT_TOPIC_MAP, stringSerde.serializer(), stringSerde.serializer()); + assertThrows(StreamsException.class, () -> inputTopic.pipeInput("1L", "Hello")); + } + + @Test + public void testDuration() { + assertThrows(IllegalArgumentException.class, + () -> testDriver.createInputTopic(INPUT_TOPIC_MAP, stringSerde.serializer(), stringSerde.serializer(), testBaseTime, Duration.ofDays(-1))); + } + + @Test + public void testNegativeAdvance() { + final TestInputTopic inputTopic = testDriver.createInputTopic(INPUT_TOPIC_MAP, stringSerde.serializer(), stringSerde.serializer()); + assertThrows(IllegalArgumentException.class, () -> inputTopic.advanceTime(Duration.ofDays(-1))); + } + + @Test + public void testInputToString() { + final TestInputTopic inputTopic = + testDriver.createInputTopic("topicName", stringSerde.serializer(), stringSerde.serializer()); + assertThat(inputTopic.toString(), allOf( + containsString("TestInputTopic"), + containsString("topic='topicName'"), + containsString("StringSerializer"))); + } + + @Test + public void shouldNotAllowToCreateOutputTopicWithNullTopicName() { + assertThrows(NullPointerException.class, () -> testDriver.createOutputTopic(null, stringSerde.deserializer(), stringSerde.deserializer())); + } + + @Test + public void shouldNotAllowToCreateOutputWithNullDriver() { + assertThrows(NullPointerException.class, () -> new TestOutputTopic<>(null, OUTPUT_TOPIC, stringSerde.deserializer(), stringSerde.deserializer())); + } + + @Test + public void testOutputWrongSerde() { + final TestInputTopic inputTopic = + testDriver.createInputTopic(INPUT_TOPIC_MAP, longSerde.serializer(), stringSerde.serializer()); + final TestOutputTopic outputTopic = + testDriver.createOutputTopic(OUTPUT_TOPIC_MAP, longSerde.deserializer(), stringSerde.deserializer()); + 
inputTopic.pipeInput(1L, "Hello"); + assertThrows(SerializationException.class, outputTopic::readKeyValue); + } + + @Test + public void testOutputToString() { + final TestOutputTopic outputTopic = + testDriver.createOutputTopic(OUTPUT_TOPIC, stringSerde.deserializer(), stringSerde.deserializer()); + assertThat(outputTopic.toString(), allOf( + containsString("TestOutputTopic"), + containsString("topic='output1'"), + containsString("size=0"), + containsString("StringDeserializer"))); + } + + @Test + public void testRecordsToList() { + final TestInputTopic inputTopic = + testDriver.createInputTopic(INPUT_TOPIC_MAP, longSerde.serializer(), stringSerde.serializer()); + final TestOutputTopic outputTopic = + testDriver.createOutputTopic(OUTPUT_TOPIC_MAP, stringSerde.deserializer(), longSerde.deserializer()); + final List inputList = Arrays.asList("This", "is", "an", "example"); + final List> input = new LinkedList<>(); + final List> expected = new LinkedList<>(); + long i = 0; + final Duration advance = Duration.ofSeconds(15); + Instant recordInstant = Instant.parse("2019-06-01T10:00:00Z"); + for (final String s : inputList) { + input.add(new KeyValue<>(i, s)); + expected.add(new TestRecord<>(s, i, recordInstant)); + i++; + recordInstant = recordInstant.plus(advance); + } + inputTopic.pipeKeyValueList(input, Instant.parse("2019-06-01T10:00:00Z"), advance); + final List> output = outputTopic.readRecordsToList(); + assertThat(output, is(equalTo(expected))); + } + +} diff --git a/streams/test-utils/src/test/java/org/apache/kafka/streams/TopologyTestDriverAtLeastOnceTest.java b/streams/test-utils/src/test/java/org/apache/kafka/streams/TopologyTestDriverAtLeastOnceTest.java new file mode 100644 index 0000000..39a70da --- /dev/null +++ b/streams/test-utils/src/test/java/org/apache/kafka/streams/TopologyTestDriverAtLeastOnceTest.java @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.streams; + +import java.util.Collections; + +public class TopologyTestDriverAtLeastOnceTest extends TopologyTestDriverTest { + TopologyTestDriverAtLeastOnceTest() { + super(Collections.singletonMap(StreamsConfig.PROCESSING_GUARANTEE_CONFIG, StreamsConfig.AT_LEAST_ONCE)); + } +} diff --git a/streams/test-utils/src/test/java/org/apache/kafka/streams/TopologyTestDriverEosTest.java b/streams/test-utils/src/test/java/org/apache/kafka/streams/TopologyTestDriverEosTest.java new file mode 100644 index 0000000..a2c2139 --- /dev/null +++ b/streams/test-utils/src/test/java/org/apache/kafka/streams/TopologyTestDriverEosTest.java @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.streams; + +import java.util.Collections; + +public class TopologyTestDriverEosTest extends TopologyTestDriverTest { + TopologyTestDriverEosTest() { + super(Collections.singletonMap(StreamsConfig.PROCESSING_GUARANTEE_CONFIG, StreamsConfig.EXACTLY_ONCE_V2)); + } +} diff --git a/streams/test-utils/src/test/java/org/apache/kafka/streams/TopologyTestDriverTest.java b/streams/test-utils/src/test/java/org/apache/kafka/streams/TopologyTestDriverTest.java new file mode 100644 index 0000000..c541650 --- /dev/null +++ b/streams/test-utils/src/test/java/org/apache/kafka/streams/TopologyTestDriverTest.java @@ -0,0 +1,1709 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams; + +import org.apache.kafka.clients.producer.ProducerRecord; +import org.apache.kafka.common.header.Header; +import org.apache.kafka.common.header.Headers; +import org.apache.kafka.common.header.internals.RecordHeader; +import org.apache.kafka.common.header.internals.RecordHeaders; +import org.apache.kafka.common.serialization.ByteArrayDeserializer; +import org.apache.kafka.common.serialization.ByteArraySerializer; +import org.apache.kafka.common.serialization.Deserializer; +import org.apache.kafka.common.serialization.LongDeserializer; +import org.apache.kafka.common.serialization.LongSerializer; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.serialization.Serializer; +import org.apache.kafka.common.serialization.StringDeserializer; +import org.apache.kafka.common.serialization.StringSerializer; +import org.apache.kafka.common.utils.SystemTime; +import org.apache.kafka.streams.errors.TopologyException; +import org.apache.kafka.streams.kstream.Consumed; +import org.apache.kafka.streams.kstream.KTable; +import org.apache.kafka.streams.kstream.Materialized; +import org.apache.kafka.streams.kstream.TableJoined; +import org.apache.kafka.streams.processor.PunctuationType; +import org.apache.kafka.streams.processor.Punctuator; +import org.apache.kafka.streams.processor.StateStore; +import org.apache.kafka.streams.processor.TaskId; +import org.apache.kafka.streams.processor.api.Processor; +import org.apache.kafka.streams.processor.api.ProcessorContext; +import org.apache.kafka.streams.processor.api.ProcessorSupplier; +import org.apache.kafka.streams.processor.api.Record; +import org.apache.kafka.streams.processor.api.RecordMetadata; +import org.apache.kafka.streams.state.KeyValueBytesStoreSupplier; +import org.apache.kafka.streams.state.KeyValueIterator; +import org.apache.kafka.streams.state.KeyValueStore; +import org.apache.kafka.streams.state.Stores; +import org.apache.kafka.streams.state.internals.KeyValueStoreBuilder; +import org.apache.kafka.streams.test.TestRecord; +import org.apache.kafka.test.TestUtils; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Test; + +import java.io.File; +import java.time.Duration; +import java.time.Instant; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.HashSet; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.NoSuchElementException; +import java.util.Objects; +import java.util.Properties; +import java.util.Set; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.Supplier; +import java.util.regex.Pattern; + +import static org.apache.kafka.common.utils.Utils.mkEntry; +import static org.apache.kafka.common.utils.Utils.mkMap; +import static org.apache.kafka.common.utils.Utils.mkProperties; +import static org.apache.kafka.common.utils.Utils.mkSet; +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.CoreMatchers.hasItem; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.CoreMatchers.notNullValue; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import 
static org.junit.jupiter.api.Assertions.assertTrue; + +public abstract class TopologyTestDriverTest { + + TopologyTestDriverTest(final Map overrides) { + config = mkProperties(mkMap( + mkEntry(StreamsConfig.APPLICATION_ID_CONFIG, "test-TopologyTestDriver"), + mkEntry(StreamsConfig.STATE_DIR_CONFIG, TestUtils.tempDirectory().getAbsolutePath()) + )); + config.put(StreamsConfig.DEFAULT_KEY_SERDE_CLASS_CONFIG, Serdes.ByteArraySerde.class); + config.put(StreamsConfig.DEFAULT_VALUE_SERDE_CLASS_CONFIG, Serdes.ByteArraySerde.class); + config.putAll(overrides); + } + + private final static String SOURCE_TOPIC_1 = "source-topic-1"; + private final static String SOURCE_TOPIC_2 = "source-topic-2"; + private final static String SINK_TOPIC_1 = "sink-topic-1"; + private final static String SINK_TOPIC_2 = "sink-topic-2"; + + private final Headers headers = new RecordHeaders(new Header[]{new RecordHeader("key", "value".getBytes())}); + + private final byte[] key1 = new byte[0]; + private final byte[] value1 = new byte[0]; + private final long timestamp1 = 42L; + private final TestRecord testRecord1 = new TestRecord<>(key1, value1, headers, timestamp1); + + private final byte[] key2 = new byte[0]; + private final byte[] value2 = new byte[0]; + private final long timestamp2 = 43L; + + private TopologyTestDriver testDriver; + private final Properties config; + private KeyValueStore store; + + private final StringDeserializer stringDeserializer = new StringDeserializer(); + private final LongDeserializer longDeserializer = new LongDeserializer(); + + private final static class TTDTestRecord { + private final Object key; + private final Object value; + private final long timestamp; + private final long offset; + private final String topic; + private final Headers headers; + + TTDTestRecord(final String newTopic, + final TestRecord consumerRecord, + final long newOffset) { + key = consumerRecord.key(); + value = consumerRecord.value(); + timestamp = consumerRecord.timestamp(); + offset = newOffset; + topic = newTopic; + headers = consumerRecord.headers(); + } + + TTDTestRecord(final Object key, + final Object value, + final Headers headers, + final long timestamp, + final long offset, + final String topic) { + this.key = key; + this.value = value; + this.headers = headers; + this.timestamp = timestamp; + this.offset = offset; + this.topic = topic; + } + + @Override + public String toString() { + return "key: " + key + + ", value: " + value + + ", timestamp: " + timestamp + + ", offset: " + offset + + ", topic: " + topic + + ", num.headers: " + (headers == null ? 
"null" : headers.toArray().length); + } + + @Override + public boolean equals(final Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + final TTDTestRecord record = (TTDTestRecord) o; + return timestamp == record.timestamp && + offset == record.offset && + Objects.equals(key, record.key) && + Objects.equals(value, record.value) && + Objects.equals(topic, record.topic) && + Objects.equals(headers, record.headers); + } + + @Override + public int hashCode() { + return Objects.hash(key, value, headers, timestamp, offset, topic); + } + } + + private final static class Punctuation { + private final long intervalMs; + private final PunctuationType punctuationType; + private final Punctuator callback; + + Punctuation(final long intervalMs, + final PunctuationType punctuationType, + final Punctuator callback) { + this.intervalMs = intervalMs; + this.punctuationType = punctuationType; + this.callback = callback; + } + } + + private final static class MockPunctuator implements Punctuator { + private final List punctuatedAt = new LinkedList<>(); + + @Override + public void punctuate(final long timestamp) { + punctuatedAt.add(timestamp); + } + } + + private final static class MockProcessor implements Processor { + private final Collection punctuations; + private ProcessorContext context; + + private boolean initialized = false; + private boolean closed = false; + private final List processedRecords = new ArrayList<>(); + + MockProcessor(final Collection punctuations) { + this.punctuations = punctuations; + } + + @Override + public void init(final ProcessorContext context) { + initialized = true; + this.context = context; + for (final Punctuation punctuation : punctuations) { + this.context.schedule(Duration.ofMillis(punctuation.intervalMs), punctuation.punctuationType, punctuation.callback); + } + } + + @Override + public void process(final Record record) { + processedRecords.add(new TTDTestRecord( + record.key(), + record.value(), + record.headers(), + record.timestamp(), + context.recordMetadata().map(RecordMetadata::offset).orElse(-1L), + context.recordMetadata().map(RecordMetadata::topic).orElse(null) + )); + context.forward(record); + } + + @Override + public void close() { + closed = true; + } + } + + private final List mockProcessors = new ArrayList<>(); + + private final class MockProcessorSupplier implements ProcessorSupplier { + private final Collection punctuations; + + private MockProcessorSupplier() { + this(Collections.emptySet()); + } + + private MockProcessorSupplier(final Collection punctuations) { + this.punctuations = punctuations; + } + + @Override + public Processor get() { + final MockProcessor mockProcessor = new MockProcessor(punctuations); + + // to keep tests simple, ignore calls from ApiUtils.checkSupplier + if (!isCheckSupplierCall()) { + mockProcessors.add(mockProcessor); + } + + return mockProcessor; + } + + /** + * Used to keep tests simple, and ignore calls from {@link org.apache.kafka.streams.internals.ApiUtils#checkSupplier(Supplier)} )}. 
+ * @return true if the stack context is within a {@link org.apache.kafka.streams.internals.ApiUtils#checkSupplier(Supplier)} )} call + */ + public boolean isCheckSupplierCall() { + return Arrays.stream(Thread.currentThread().getStackTrace()) + .anyMatch(caller -> "org.apache.kafka.streams.internals.ApiUtils".equals(caller.getClassName()) && "checkSupplier".equals(caller.getMethodName())); + } + } + + @AfterEach + public void tearDown() { + if (testDriver != null) { + testDriver.close(); + } + } + + private Topology setupSourceSinkTopology() { + final Topology topology = new Topology(); + + final String sourceName = "source"; + + topology.addSource(sourceName, SOURCE_TOPIC_1); + topology.addSink("sink", SINK_TOPIC_1, sourceName); + + return topology; + } + + private Topology setupTopologyWithTwoSubtopologies() { + final Topology topology = new Topology(); + + final String sourceName1 = "source-1"; + final String sourceName2 = "source-2"; + + topology.addSource(sourceName1, SOURCE_TOPIC_1); + topology.addSink("sink-1", SINK_TOPIC_1, sourceName1); + topology.addSource(sourceName2, SINK_TOPIC_1); + topology.addSink("sink-2", SINK_TOPIC_2, sourceName2); + + return topology; + } + + + private Topology setupSingleProcessorTopology() { + return setupSingleProcessorTopology(-1, null, null); + } + + private Topology setupSingleProcessorTopology(final long punctuationIntervalMs, + final PunctuationType punctuationType, + final Punctuator callback) { + final Collection punctuations; + if (punctuationIntervalMs > 0 && punctuationType != null && callback != null) { + punctuations = Collections.singleton(new Punctuation(punctuationIntervalMs, punctuationType, callback)); + } else { + punctuations = Collections.emptySet(); + } + + final Topology topology = new Topology(); + + final String sourceName = "source"; + + topology.addSource(sourceName, SOURCE_TOPIC_1); + topology.addProcessor("processor", new MockProcessorSupplier(punctuations), sourceName); + + return topology; + } + + private Topology setupMultipleSourceTopology(final String... sourceTopicNames) { + final Topology topology = new Topology(); + + final String[] processorNames = new String[sourceTopicNames.length]; + int i = 0; + for (final String sourceTopicName : sourceTopicNames) { + final String sourceName = sourceTopicName + "-source"; + final String processorName = sourceTopicName + "-processor"; + topology.addSource(sourceName, sourceTopicName); + processorNames[i++] = processorName; + topology.addProcessor(processorName, new MockProcessorSupplier(), sourceName); + } + topology.addSink("sink-topic", SINK_TOPIC_1, processorNames); + + return topology; + } + + private Topology setupGlobalStoreTopology(final String... 
sourceTopicNames) { + if (sourceTopicNames.length == 0) { + throw new IllegalArgumentException("sourceTopicNames cannot be empty"); + } + final Topology topology = new Topology(); + + for (final String sourceTopicName : sourceTopicNames) { + topology.addGlobalStore( + Stores.keyValueStoreBuilder( + Stores.inMemoryKeyValueStore( + sourceTopicName + "-globalStore"), + null, + null) + .withLoggingDisabled(), + sourceTopicName, + null, + null, + sourceTopicName, + sourceTopicName + "-processor", + () -> new Processor() { + KeyValueStore store; + + @SuppressWarnings("unchecked") + @Override + public void init(final ProcessorContext context) { + store = context.getStateStore(sourceTopicName + "-globalStore"); + } + + @Override + public void process(final Record record) { + store.put(record.key(), record.value()); + } + } + ); + } + + return topology; + } + + private Topology setupTopologyWithInternalTopic(final String firstTableName, + final String secondTableName, + final String joinName) { + final StreamsBuilder builder = new StreamsBuilder(); + + final KTable t1 = builder.stream(SOURCE_TOPIC_1) + .selectKey((k, v) -> v) + .groupByKey() + .count(Materialized.as(firstTableName)); + + builder.table(SOURCE_TOPIC_2, Materialized.as(secondTableName)) + .join(t1, v -> v, (v1, v2) -> v2, TableJoined.as(joinName)); + + return builder.build(config); + } + + @Test + public void shouldNotRequireParameters() { + new TopologyTestDriver(setupSingleProcessorTopology(), config); + } + + @Test + public void shouldInitProcessor() { + testDriver = new TopologyTestDriver(setupSingleProcessorTopology(), config); + assertTrue(mockProcessors.get(0).initialized); + } + + @Test + public void shouldCloseProcessor() { + testDriver = new TopologyTestDriver(setupSingleProcessorTopology(), config); + testDriver.close(); + assertTrue(mockProcessors.get(0).closed); + // As testDriver is already closed, bypassing @AfterEach tearDown testDriver.close(). 
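+ // Setting the field to null below makes the null-check in tearDown() a no-op, so close() is not invoked a second time.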
+ testDriver = null; + } + + @Test + public void shouldThrowForUnknownTopic() { + testDriver = new TopologyTestDriver(new Topology()); + assertThrows( + IllegalArgumentException.class, + () -> testDriver.pipeRecord( + "unknownTopic", + new TestRecord<>((byte[]) null), + new ByteArraySerializer(), + new ByteArraySerializer(), + Instant.now()) + ); + } + + @Test + public void shouldThrowForMissingTime() { + testDriver = new TopologyTestDriver(new Topology()); + assertThrows( + IllegalStateException.class, + () -> testDriver.pipeRecord( + SINK_TOPIC_1, + new TestRecord<>("value"), + new StringSerializer(), + new StringSerializer(), + null)); + } + + @Test + public void shouldThrowNoSuchElementExceptionForUnusedOutputTopicWithDynamicRouting() { + testDriver = new TopologyTestDriver(setupSourceSinkTopology(), config); + final TestOutputTopic outputTopic = new TestOutputTopic<>( + testDriver, + "unused-topic", + new StringDeserializer(), + new StringDeserializer() + ); + + assertTrue(outputTopic.isEmpty()); + assertThrows(NoSuchElementException.class, outputTopic::readRecord); + } + + @Test + public void shouldCaptureSinkTopicNamesIfWrittenInto() { + testDriver = new TopologyTestDriver(setupSourceSinkTopology(), config); + + assertThat(testDriver.producedTopicNames(), is(Collections.emptySet())); + + pipeRecord(SOURCE_TOPIC_1, testRecord1); + assertThat(testDriver.producedTopicNames(), hasItem(SINK_TOPIC_1)); + } + + @Test + public void shouldCaptureInternalTopicNamesIfWrittenInto() { + testDriver = new TopologyTestDriver( + setupTopologyWithInternalTopic("table1", "table2", "join"), + config + ); + + assertThat(testDriver.producedTopicNames(), is(Collections.emptySet())); + + pipeRecord(SOURCE_TOPIC_1, testRecord1); + assertThat( + testDriver.producedTopicNames(), + equalTo(mkSet( + config.getProperty(StreamsConfig.APPLICATION_ID_CONFIG) + "-table1-repartition", + config.getProperty(StreamsConfig.APPLICATION_ID_CONFIG) + "-table1-changelog" + )) + ); + + pipeRecord(SOURCE_TOPIC_2, testRecord1); + assertThat( + testDriver.producedTopicNames(), + equalTo(mkSet( + config.getProperty(StreamsConfig.APPLICATION_ID_CONFIG) + "-table1-repartition", + config.getProperty(StreamsConfig.APPLICATION_ID_CONFIG) + "-table1-changelog", + config.getProperty(StreamsConfig.APPLICATION_ID_CONFIG) + "-table2-changelog", + config.getProperty(StreamsConfig.APPLICATION_ID_CONFIG) + "-join-subscription-registration-topic", + config.getProperty(StreamsConfig.APPLICATION_ID_CONFIG) + "-join-subscription-store-changelog", + config.getProperty(StreamsConfig.APPLICATION_ID_CONFIG) + "-join-subscription-response-topic" + )) + ); + } + + @Test + public void shouldCaptureGlobalTopicNameIfWrittenInto() { + final StreamsBuilder builder = new StreamsBuilder(); + builder.globalTable(SOURCE_TOPIC_1, Materialized.as("globalTable")); + builder.stream(SOURCE_TOPIC_2).to(SOURCE_TOPIC_1); + + testDriver = new TopologyTestDriver(builder.build(), config); + + assertThat(testDriver.producedTopicNames(), is(Collections.emptySet())); + + pipeRecord(SOURCE_TOPIC_2, testRecord1); + assertThat( + testDriver.producedTopicNames(), + equalTo(Collections.singleton(SOURCE_TOPIC_1)) + ); + } + + @Test + public void shouldProcessRecordForTopic() { + testDriver = new TopologyTestDriver(setupSourceSinkTopology(), config); + + pipeRecord(SOURCE_TOPIC_1, testRecord1); + final ProducerRecord outputRecord = testDriver.readRecord(SINK_TOPIC_1); + + assertEquals(key1, outputRecord.key()); + assertEquals(value1, outputRecord.value()); + 
assertEquals(SINK_TOPIC_1, outputRecord.topic()); + } + + @Test + public void shouldSetRecordMetadata() { + testDriver = new TopologyTestDriver(setupSingleProcessorTopology(), config); + + pipeRecord(SOURCE_TOPIC_1, testRecord1); + + final List processedRecords = mockProcessors.get(0).processedRecords; + assertEquals(1, processedRecords.size()); + + final TTDTestRecord record = processedRecords.get(0); + final TTDTestRecord expectedResult = new TTDTestRecord(SOURCE_TOPIC_1, testRecord1, 0L); + + assertThat(record, equalTo(expectedResult)); + } + + private void pipeRecord(final String topic, final TestRecord record) { + testDriver.pipeRecord(topic, record, new ByteArraySerializer(), new ByteArraySerializer(), null); + } + + + @Test + public void shouldSendRecordViaCorrectSourceTopic() { + testDriver = new TopologyTestDriver(setupMultipleSourceTopology(SOURCE_TOPIC_1, SOURCE_TOPIC_2), config); + + final List processedRecords1 = mockProcessors.get(0).processedRecords; + final List processedRecords2 = mockProcessors.get(1).processedRecords; + + final TestInputTopic inputTopic1 = testDriver.createInputTopic(SOURCE_TOPIC_1, + new ByteArraySerializer(), new ByteArraySerializer()); + final TestInputTopic inputTopic2 = testDriver.createInputTopic(SOURCE_TOPIC_2, + new ByteArraySerializer(), new ByteArraySerializer()); + + inputTopic1.pipeInput(new TestRecord<>(key1, value1, headers, timestamp1)); + + assertEquals(1, processedRecords1.size()); + assertEquals(0, processedRecords2.size()); + + TTDTestRecord record = processedRecords1.get(0); + TTDTestRecord expectedResult = new TTDTestRecord(key1, value1, headers, timestamp1, 0L, SOURCE_TOPIC_1); + assertThat(record, equalTo(expectedResult)); + + inputTopic2.pipeInput(new TestRecord<>(key2, value2, Instant.ofEpochMilli(timestamp2))); + + assertEquals(1, processedRecords1.size()); + assertEquals(1, processedRecords2.size()); + + record = processedRecords2.get(0); + expectedResult = new TTDTestRecord(key2, value2, new RecordHeaders((Iterable
<Header>
                ) null), timestamp2, 0L, SOURCE_TOPIC_2); + assertThat(record, equalTo(expectedResult)); + } + + @Test + public void shouldUseSourceSpecificDeserializers() { + final Topology topology = new Topology(); + + final String sourceName1 = "source-1"; + final String sourceName2 = "source-2"; + final String processor = "processor"; + + topology.addSource(sourceName1, Serdes.Long().deserializer(), Serdes.String().deserializer(), SOURCE_TOPIC_1); + topology.addSource(sourceName2, Serdes.Integer().deserializer(), Serdes.Double().deserializer(), SOURCE_TOPIC_2); + topology.addProcessor(processor, new MockProcessorSupplier(), sourceName1, sourceName2); + topology.addSink( + "sink", + SINK_TOPIC_1, + (topic, data) -> { + if (data instanceof Long) { + return Serdes.Long().serializer().serialize(topic, (Long) data); + } + return Serdes.Integer().serializer().serialize(topic, (Integer) data); + }, + (topic, data) -> { + if (data instanceof String) { + return Serdes.String().serializer().serialize(topic, (String) data); + } + return Serdes.Double().serializer().serialize(topic, (Double) data); + }, + processor); + + testDriver = new TopologyTestDriver(topology); + + final Long source1Key = 42L; + final String source1Value = "anyString"; + final Integer source2Key = 73; + final Double source2Value = 3.14; + + final TestRecord consumerRecord1 = new TestRecord<>(source1Key, source1Value); + final TestRecord consumerRecord2 = new TestRecord<>(source2Key, source2Value); + + testDriver.pipeRecord(SOURCE_TOPIC_1, + consumerRecord1, + Serdes.Long().serializer(), + Serdes.String().serializer(), + Instant.now()); + final TestRecord result1 = + testDriver.readRecord(SINK_TOPIC_1, Serdes.Long().deserializer(), Serdes.String().deserializer()); + assertThat(result1.getKey(), equalTo(source1Key)); + assertThat(result1.getValue(), equalTo(source1Value)); + + testDriver.pipeRecord(SOURCE_TOPIC_2, + consumerRecord2, + Serdes.Integer().serializer(), + Serdes.Double().serializer(), + Instant.now()); + final TestRecord result2 = + testDriver.readRecord(SINK_TOPIC_1, Serdes.Integer().deserializer(), Serdes.Double().deserializer()); + assertThat(result2.getKey(), equalTo(source2Key)); + assertThat(result2.getValue(), equalTo(source2Value)); + } + + @Test + public void shouldPassRecordHeadersIntoSerializersAndDeserializers() { + testDriver = new TopologyTestDriver(setupSourceSinkTopology(), config); + + final AtomicBoolean passedHeadersToKeySerializer = new AtomicBoolean(false); + final AtomicBoolean passedHeadersToValueSerializer = new AtomicBoolean(false); + final AtomicBoolean passedHeadersToKeyDeserializer = new AtomicBoolean(false); + final AtomicBoolean passedHeadersToValueDeserializer = new AtomicBoolean(false); + + final Serializer keySerializer = new ByteArraySerializer() { + @Override + public byte[] serialize(final String topic, final Headers headers, final byte[] data) { + passedHeadersToKeySerializer.set(true); + return serialize(topic, data); + } + }; + final Serializer valueSerializer = new ByteArraySerializer() { + @Override + public byte[] serialize(final String topic, final Headers headers, final byte[] data) { + passedHeadersToValueSerializer.set(true); + return serialize(topic, data); + } + }; + + final Deserializer keyDeserializer = new ByteArrayDeserializer() { + @Override + public byte[] deserialize(final String topic, final Headers headers, final byte[] data) { + passedHeadersToKeyDeserializer.set(true); + return deserialize(topic, data); + } + }; + final Deserializer valueDeserializer = new 
ByteArrayDeserializer() { + @Override + public byte[] deserialize(final String topic, final Headers headers, final byte[] data) { + passedHeadersToValueDeserializer.set(true); + return deserialize(topic, data); + } + }; + + final TestInputTopic inputTopic = testDriver.createInputTopic(SOURCE_TOPIC_1, keySerializer, valueSerializer); + final TestOutputTopic outputTopic = testDriver.createOutputTopic(SINK_TOPIC_1, keyDeserializer, valueDeserializer); + inputTopic.pipeInput(testRecord1); + outputTopic.readRecord(); + + assertThat(passedHeadersToKeySerializer.get(), equalTo(true)); + assertThat(passedHeadersToValueSerializer.get(), equalTo(true)); + assertThat(passedHeadersToKeyDeserializer.get(), equalTo(true)); + assertThat(passedHeadersToValueDeserializer.get(), equalTo(true)); + } + + @Test + public void shouldUseSinkSpecificSerializers() { + final Topology topology = new Topology(); + + final String sourceName1 = "source-1"; + final String sourceName2 = "source-2"; + + topology.addSource(sourceName1, Serdes.Long().deserializer(), Serdes.String().deserializer(), SOURCE_TOPIC_1); + topology.addSource(sourceName2, Serdes.Integer().deserializer(), Serdes.Double().deserializer(), SOURCE_TOPIC_2); + topology.addSink("sink-1", SINK_TOPIC_1, Serdes.Long().serializer(), Serdes.String().serializer(), sourceName1); + topology.addSink("sink-2", SINK_TOPIC_2, Serdes.Integer().serializer(), Serdes.Double().serializer(), sourceName2); + + testDriver = new TopologyTestDriver(topology); + + final Long source1Key = 42L; + final String source1Value = "anyString"; + final Integer source2Key = 73; + final Double source2Value = 3.14; + + final TestRecord consumerRecord1 = new TestRecord<>(source1Key, source1Value); + final TestRecord consumerRecord2 = new TestRecord<>(source2Key, source2Value); + + testDriver.pipeRecord(SOURCE_TOPIC_1, + consumerRecord1, + Serdes.Long().serializer(), + Serdes.String().serializer(), + Instant.now()); + final TestRecord result1 = + testDriver.readRecord(SINK_TOPIC_1, Serdes.Long().deserializer(), Serdes.String().deserializer()); + assertThat(result1.getKey(), equalTo(source1Key)); + assertThat(result1.getValue(), equalTo(source1Value)); + + testDriver.pipeRecord(SOURCE_TOPIC_2, + consumerRecord2, + Serdes.Integer().serializer(), + Serdes.Double().serializer(), + Instant.now()); + final TestRecord result2 = + testDriver.readRecord(SINK_TOPIC_2, Serdes.Integer().deserializer(), Serdes.Double().deserializer()); + assertThat(result2.getKey(), equalTo(source2Key)); + assertThat(result2.getValue(), equalTo(source2Value)); + } + + @Test + public void shouldForwardRecordsFromSubtopologyToSubtopology() { + testDriver = new TopologyTestDriver(setupTopologyWithTwoSubtopologies(), config); + + pipeRecord(SOURCE_TOPIC_1, testRecord1); + + ProducerRecord outputRecord = testDriver.readRecord(SINK_TOPIC_1); + assertEquals(key1, outputRecord.key()); + assertEquals(value1, outputRecord.value()); + assertEquals(SINK_TOPIC_1, outputRecord.topic()); + + outputRecord = testDriver.readRecord(SINK_TOPIC_2); + assertEquals(key1, outputRecord.key()); + assertEquals(value1, outputRecord.value()); + assertEquals(SINK_TOPIC_2, outputRecord.topic()); + } + + @Test + public void shouldPopulateGlobalStore() { + testDriver = new TopologyTestDriver(setupGlobalStoreTopology(SOURCE_TOPIC_1), config); + + final KeyValueStore globalStore = testDriver.getKeyValueStore(SOURCE_TOPIC_1 + "-globalStore"); + assertNotNull(globalStore); + assertNotNull(testDriver.getAllStateStores().get(SOURCE_TOPIC_1 + "-globalStore")); + 
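+ // A record piped into the global store's source topic is applied by the store's update processor and should be immediately visible through the store handle obtained above.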
+ pipeRecord(SOURCE_TOPIC_1, testRecord1); + + assertThat(globalStore.get(testRecord1.key()), is(testRecord1.value())); + } + + @Test + public void shouldPunctuateOnStreamsTime() { + final MockPunctuator mockPunctuator = new MockPunctuator(); + testDriver = new TopologyTestDriver( + setupSingleProcessorTopology(10L, PunctuationType.STREAM_TIME, mockPunctuator), + config + ); + + final List expectedPunctuations = new LinkedList<>(); + + expectedPunctuations.add(42L); + pipeRecord(SOURCE_TOPIC_1, new TestRecord<>(key1, value1, null, 42L)); + assertThat(mockPunctuator.punctuatedAt, equalTo(expectedPunctuations)); + + pipeRecord(SOURCE_TOPIC_1, new TestRecord<>(key1, value1, null, 42L)); + assertThat(mockPunctuator.punctuatedAt, equalTo(expectedPunctuations)); + + expectedPunctuations.add(51L); + pipeRecord(SOURCE_TOPIC_1, new TestRecord<>(key1, value1, null, 51L)); + assertThat(mockPunctuator.punctuatedAt, equalTo(expectedPunctuations)); + + pipeRecord(SOURCE_TOPIC_1, new TestRecord<>(key1, value1, null, 52L)); + assertThat(mockPunctuator.punctuatedAt, equalTo(expectedPunctuations)); + + expectedPunctuations.add(61L); + pipeRecord(SOURCE_TOPIC_1, new TestRecord<>(key1, value1, null, 61L)); + assertThat(mockPunctuator.punctuatedAt, equalTo(expectedPunctuations)); + + pipeRecord(SOURCE_TOPIC_1, new TestRecord<>(key1, value1, null, 65L)); + assertThat(mockPunctuator.punctuatedAt, equalTo(expectedPunctuations)); + + expectedPunctuations.add(71L); + pipeRecord(SOURCE_TOPIC_1, new TestRecord<>(key1, value1, null, 71L)); + assertThat(mockPunctuator.punctuatedAt, equalTo(expectedPunctuations)); + + pipeRecord(SOURCE_TOPIC_1, new TestRecord<>(key1, value1, null, 72L)); + assertThat(mockPunctuator.punctuatedAt, equalTo(expectedPunctuations)); + + expectedPunctuations.add(95L); + pipeRecord(SOURCE_TOPIC_1, new TestRecord<>(key1, value1, null, 95L)); + assertThat(mockPunctuator.punctuatedAt, equalTo(expectedPunctuations)); + + expectedPunctuations.add(101L); + pipeRecord(SOURCE_TOPIC_1, new TestRecord<>(key1, value1, null, 101L)); + assertThat(mockPunctuator.punctuatedAt, equalTo(expectedPunctuations)); + + pipeRecord(SOURCE_TOPIC_1, new TestRecord<>(key1, value1, null, 102L)); + assertThat(mockPunctuator.punctuatedAt, equalTo(expectedPunctuations)); + } + + @Test + public void shouldPunctuateOnWallClockTime() { + final MockPunctuator mockPunctuator = new MockPunctuator(); + testDriver = new TopologyTestDriver( + setupSingleProcessorTopology(10L, PunctuationType.WALL_CLOCK_TIME, mockPunctuator), + config, Instant.ofEpochMilli(0L)); + + final List expectedPunctuations = new LinkedList<>(); + + testDriver.advanceWallClockTime(Duration.ofMillis(5L)); + assertThat(mockPunctuator.punctuatedAt, equalTo(expectedPunctuations)); + + expectedPunctuations.add(14L); + testDriver.advanceWallClockTime(Duration.ofMillis(9L)); + assertThat(mockPunctuator.punctuatedAt, equalTo(expectedPunctuations)); + + testDriver.advanceWallClockTime(Duration.ofMillis(1L)); + assertThat(mockPunctuator.punctuatedAt, equalTo(expectedPunctuations)); + + expectedPunctuations.add(35L); + testDriver.advanceWallClockTime(Duration.ofMillis(20L)); + assertThat(mockPunctuator.punctuatedAt, equalTo(expectedPunctuations)); + + expectedPunctuations.add(40L); + testDriver.advanceWallClockTime(Duration.ofMillis(5L)); + assertThat(mockPunctuator.punctuatedAt, equalTo(expectedPunctuations)); + } + + @Test + public void shouldReturnAllStores() { + final Topology topology = setupSourceSinkTopology(); + topology.addProcessor("processor", new 
MockProcessorSupplier(), "source"); + topology.addStateStore( + new KeyValueStoreBuilder<>( + Stores.inMemoryKeyValueStore("store"), + Serdes.ByteArray(), + Serdes.ByteArray(), + new SystemTime()), + "processor"); + topology.addGlobalStore( + new KeyValueStoreBuilder<>( + Stores.inMemoryKeyValueStore("globalStore"), + Serdes.ByteArray(), + Serdes.ByteArray(), + new SystemTime()).withLoggingDisabled(), + "sourceProcessorName", + Serdes.ByteArray().deserializer(), + Serdes.ByteArray().deserializer(), + "globalTopicName", + "globalProcessorName", + voidProcessorSupplier); + + testDriver = new TopologyTestDriver(topology, config); + + final Set expectedStoreNames = new HashSet<>(); + expectedStoreNames.add("store"); + expectedStoreNames.add("globalStore"); + final Map allStores = testDriver.getAllStateStores(); + assertThat(allStores.keySet(), equalTo(expectedStoreNames)); + for (final StateStore store : allStores.values()) { + assertNotNull(store); + } + } + + @Test + public void shouldReturnCorrectPersistentStoreTypeOnly() { + shouldReturnCorrectStoreTypeOnly(true); + } + + @Test + public void shouldReturnCorrectInMemoryStoreTypeOnly() { + shouldReturnCorrectStoreTypeOnly(false); + } + + private void shouldReturnCorrectStoreTypeOnly(final boolean persistent) { + final String keyValueStoreName = "keyValueStore"; + final String timestampedKeyValueStoreName = "keyValueTimestampStore"; + final String windowStoreName = "windowStore"; + final String timestampedWindowStoreName = "windowTimestampStore"; + final String sessionStoreName = "sessionStore"; + final String globalKeyValueStoreName = "globalKeyValueStore"; + final String globalTimestampedKeyValueStoreName = "globalKeyValueTimestampStore"; + + final Topology topology = setupSingleProcessorTopology(); + addStoresToTopology( + topology, + persistent, + keyValueStoreName, + timestampedKeyValueStoreName, + windowStoreName, + timestampedWindowStoreName, + sessionStoreName, + globalKeyValueStoreName, + globalTimestampedKeyValueStoreName); + + + testDriver = new TopologyTestDriver(topology, config); + + // verify state stores + assertNotNull(testDriver.getKeyValueStore(keyValueStoreName)); + assertNull(testDriver.getTimestampedKeyValueStore(keyValueStoreName)); + assertNull(testDriver.getWindowStore(keyValueStoreName)); + assertNull(testDriver.getTimestampedWindowStore(keyValueStoreName)); + assertNull(testDriver.getSessionStore(keyValueStoreName)); + + assertNotNull(testDriver.getKeyValueStore(timestampedKeyValueStoreName)); + assertNotNull(testDriver.getTimestampedKeyValueStore(timestampedKeyValueStoreName)); + assertNull(testDriver.getWindowStore(timestampedKeyValueStoreName)); + assertNull(testDriver.getTimestampedWindowStore(timestampedKeyValueStoreName)); + assertNull(testDriver.getSessionStore(timestampedKeyValueStoreName)); + + assertNull(testDriver.getKeyValueStore(windowStoreName)); + assertNull(testDriver.getTimestampedKeyValueStore(windowStoreName)); + assertNotNull(testDriver.getWindowStore(windowStoreName)); + assertNull(testDriver.getTimestampedWindowStore(windowStoreName)); + assertNull(testDriver.getSessionStore(windowStoreName)); + + assertNull(testDriver.getKeyValueStore(timestampedWindowStoreName)); + assertNull(testDriver.getTimestampedKeyValueStore(timestampedWindowStoreName)); + assertNotNull(testDriver.getWindowStore(timestampedWindowStoreName)); + assertNotNull(testDriver.getTimestampedWindowStore(timestampedWindowStoreName)); + assertNull(testDriver.getSessionStore(timestampedWindowStoreName)); + + 
assertNull(testDriver.getKeyValueStore(sessionStoreName)); + assertNull(testDriver.getTimestampedKeyValueStore(sessionStoreName)); + assertNull(testDriver.getWindowStore(sessionStoreName)); + assertNull(testDriver.getTimestampedWindowStore(sessionStoreName)); + assertNotNull(testDriver.getSessionStore(sessionStoreName)); + + // verify global stores + assertNotNull(testDriver.getKeyValueStore(globalKeyValueStoreName)); + assertNull(testDriver.getTimestampedKeyValueStore(globalKeyValueStoreName)); + assertNull(testDriver.getWindowStore(globalKeyValueStoreName)); + assertNull(testDriver.getTimestampedWindowStore(globalKeyValueStoreName)); + assertNull(testDriver.getSessionStore(globalKeyValueStoreName)); + + assertNotNull(testDriver.getKeyValueStore(globalTimestampedKeyValueStoreName)); + assertNotNull(testDriver.getTimestampedKeyValueStore(globalTimestampedKeyValueStoreName)); + assertNull(testDriver.getWindowStore(globalTimestampedKeyValueStoreName)); + assertNull(testDriver.getTimestampedWindowStore(globalTimestampedKeyValueStoreName)); + assertNull(testDriver.getSessionStore(globalTimestampedKeyValueStoreName)); + } + + @Test + public void shouldThrowIfInMemoryBuiltInStoreIsAccessedWithUntypedMethod() { + shouldThrowIfBuiltInStoreIsAccessedWithUntypedMethod(false); + } + + @Test + public void shouldThrowIfPersistentBuiltInStoreIsAccessedWithUntypedMethod() { + shouldThrowIfBuiltInStoreIsAccessedWithUntypedMethod(true); + } + + private void shouldThrowIfBuiltInStoreIsAccessedWithUntypedMethod(final boolean persistent) { + final String keyValueStoreName = "keyValueStore"; + final String timestampedKeyValueStoreName = "keyValueTimestampStore"; + final String windowStoreName = "windowStore"; + final String timestampedWindowStoreName = "windowTimestampStore"; + final String sessionStoreName = "sessionStore"; + final String globalKeyValueStoreName = "globalKeyValueStore"; + final String globalTimestampedKeyValueStoreName = "globalKeyValueTimestampStore"; + + final Topology topology = setupSingleProcessorTopology(); + addStoresToTopology( + topology, + persistent, + keyValueStoreName, + timestampedKeyValueStoreName, + windowStoreName, + timestampedWindowStoreName, + sessionStoreName, + globalKeyValueStoreName, + globalTimestampedKeyValueStoreName); + + + testDriver = new TopologyTestDriver(topology, config); + + { + final IllegalArgumentException e = assertThrows( + IllegalArgumentException.class, + () -> testDriver.getStateStore(keyValueStoreName)); + assertThat( + e.getMessage(), + equalTo("Store " + keyValueStoreName + + " is a key-value store and should be accessed via `getKeyValueStore()`")); + } + { + final IllegalArgumentException e = assertThrows( + IllegalArgumentException.class, + () -> testDriver.getStateStore(timestampedKeyValueStoreName)); + assertThat( + e.getMessage(), + equalTo("Store " + timestampedKeyValueStoreName + + " is a timestamped key-value store and should be accessed via `getTimestampedKeyValueStore()`")); + } + { + final IllegalArgumentException e = assertThrows( + IllegalArgumentException.class, + () -> testDriver.getStateStore(windowStoreName)); + assertThat( + e.getMessage(), + equalTo("Store " + windowStoreName + + " is a window store and should be accessed via `getWindowStore()`")); + } + { + final IllegalArgumentException e = assertThrows( + IllegalArgumentException.class, + () -> testDriver.getStateStore(timestampedWindowStoreName)); + assertThat( + e.getMessage(), + equalTo("Store " + timestampedWindowStoreName + + " is a timestamped window store and should be 
accessed via `getTimestampedWindowStore()`")); + } + { + final IllegalArgumentException e = assertThrows( + IllegalArgumentException.class, + () -> testDriver.getStateStore(sessionStoreName)); + assertThat( + e.getMessage(), + equalTo("Store " + sessionStoreName + + " is a session store and should be accessed via `getSessionStore()`")); + } + { + final IllegalArgumentException e = assertThrows( + IllegalArgumentException.class, + () -> testDriver.getStateStore(globalKeyValueStoreName)); + assertThat( + e.getMessage(), + equalTo("Store " + globalKeyValueStoreName + + " is a key-value store and should be accessed via `getKeyValueStore()`")); + } + { + final IllegalArgumentException e = assertThrows( + IllegalArgumentException.class, + () -> testDriver.getStateStore(globalTimestampedKeyValueStoreName)); + assertThat( + e.getMessage(), + equalTo("Store " + globalTimestampedKeyValueStoreName + + " is a timestamped key-value store and should be accessed via `getTimestampedKeyValueStore()`")); + } + } + + final ProcessorSupplier voidProcessorSupplier = () -> new Processor() { + @Override + public void process(final Record record) { + } + }; + + private void addStoresToTopology(final Topology topology, + final boolean persistent, + final String keyValueStoreName, + final String timestampedKeyValueStoreName, + final String windowStoreName, + final String timestampedWindowStoreName, + final String sessionStoreName, + final String globalKeyValueStoreName, + final String globalTimestampedKeyValueStoreName) { + + // add state stores + topology.addStateStore( + Stores.keyValueStoreBuilder( + persistent ? + Stores.persistentKeyValueStore(keyValueStoreName) : + Stores.inMemoryKeyValueStore(keyValueStoreName), + Serdes.ByteArray(), + Serdes.ByteArray() + ), + "processor"); + topology.addStateStore( + Stores.timestampedKeyValueStoreBuilder( + persistent ? + Stores.persistentTimestampedKeyValueStore(timestampedKeyValueStoreName) : + Stores.inMemoryKeyValueStore(timestampedKeyValueStoreName), + Serdes.ByteArray(), + Serdes.ByteArray() + ), + "processor"); + topology.addStateStore( + Stores.windowStoreBuilder( + persistent ? + Stores.persistentWindowStore(windowStoreName, Duration.ofMillis(1000L), Duration.ofMillis(100L), false) : + Stores.inMemoryWindowStore(windowStoreName, Duration.ofMillis(1000L), Duration.ofMillis(100L), false), + Serdes.ByteArray(), + Serdes.ByteArray() + ), + "processor"); + topology.addStateStore( + Stores.timestampedWindowStoreBuilder( + persistent ? + Stores.persistentTimestampedWindowStore(timestampedWindowStoreName, Duration.ofMillis(1000L), Duration.ofMillis(100L), false) : + Stores.inMemoryWindowStore(timestampedWindowStoreName, Duration.ofMillis(1000L), Duration.ofMillis(100L), false), + Serdes.ByteArray(), + Serdes.ByteArray() + ), + "processor"); + topology.addStateStore( + persistent ? + Stores.sessionStoreBuilder( + Stores.persistentSessionStore(sessionStoreName, Duration.ofMillis(1000L)), + Serdes.ByteArray(), + Serdes.ByteArray()) : + Stores.sessionStoreBuilder( + Stores.inMemorySessionStore(sessionStoreName, Duration.ofMillis(1000L)), + Serdes.ByteArray(), + Serdes.ByteArray()), + "processor"); + // add global stores + topology.addGlobalStore( + persistent ? 
+ Stores.keyValueStoreBuilder( + Stores.persistentKeyValueStore(globalKeyValueStoreName), + Serdes.ByteArray(), + Serdes.ByteArray() + ).withLoggingDisabled() : + Stores.keyValueStoreBuilder( + Stores.inMemoryKeyValueStore(globalKeyValueStoreName), + Serdes.ByteArray(), + Serdes.ByteArray() + ).withLoggingDisabled(), + "sourceDummy1", + Serdes.ByteArray().deserializer(), + Serdes.ByteArray().deserializer(), + "topicDummy1", + "processorDummy1", + voidProcessorSupplier); + topology.addGlobalStore( + persistent ? + Stores.timestampedKeyValueStoreBuilder( + Stores.persistentTimestampedKeyValueStore(globalTimestampedKeyValueStoreName), + Serdes.ByteArray(), + Serdes.ByteArray() + ).withLoggingDisabled() : + Stores.timestampedKeyValueStoreBuilder( + Stores.inMemoryKeyValueStore(globalTimestampedKeyValueStoreName), + Serdes.ByteArray(), + Serdes.ByteArray() + ).withLoggingDisabled(), + "sourceDummy2", + Serdes.ByteArray().deserializer(), + Serdes.ByteArray().deserializer(), + "topicDummy2", + "processorDummy2", + voidProcessorSupplier); + } + + @Test + public void shouldReturnAllStoresNames() { + final Topology topology = setupSourceSinkTopology(); + topology.addStateStore( + new KeyValueStoreBuilder<>( + Stores.inMemoryKeyValueStore("store"), + Serdes.ByteArray(), + Serdes.ByteArray(), + new SystemTime())); + topology.addGlobalStore( + new KeyValueStoreBuilder<>( + Stores.inMemoryKeyValueStore("globalStore"), + Serdes.ByteArray(), + Serdes.ByteArray(), + new SystemTime()).withLoggingDisabled(), + "sourceProcessorName", + Serdes.ByteArray().deserializer(), + Serdes.ByteArray().deserializer(), + "globalTopicName", + "globalProcessorName", + voidProcessorSupplier); + + testDriver = new TopologyTestDriver(topology, config); + + final Set expectedStoreNames = new HashSet<>(); + expectedStoreNames.add("store"); + expectedStoreNames.add("globalStore"); + assertThat(testDriver.getAllStateStores().keySet(), equalTo(expectedStoreNames)); + } + + private void setup() { + setup(Stores.inMemoryKeyValueStore("aggStore")); + } + + private void setup(final KeyValueBytesStoreSupplier storeSupplier) { + final Topology topology = new Topology(); + topology.addSource("sourceProcessor", "input-topic"); + topology.addProcessor("aggregator", new CustomMaxAggregatorSupplier(), "sourceProcessor"); + topology.addStateStore(Stores.keyValueStoreBuilder( + storeSupplier, + Serdes.String(), + Serdes.Long()), + "aggregator"); + topology.addSink("sinkProcessor", "result-topic", "aggregator"); + + config.setProperty(StreamsConfig.DEFAULT_KEY_SERDE_CLASS_CONFIG, Serdes.String().getClass().getName()); + config.setProperty(StreamsConfig.DEFAULT_VALUE_SERDE_CLASS_CONFIG, Serdes.Long().getClass().getName()); + testDriver = new TopologyTestDriver(topology, config); + + store = testDriver.getKeyValueStore("aggStore"); + store.put("a", 21L); + } + + private void pipeInput(final String topic, final String key, final Long value, final Long time) { + testDriver.pipeRecord(topic, new TestRecord<>(key, value, null, time), + new StringSerializer(), new LongSerializer(), null); + } + + private void compareKeyValue(final TestRecord record, final String key, final Long value) { + assertThat(record.getKey(), equalTo(key)); + assertThat(record.getValue(), equalTo(value)); + } + + @Test + public void shouldFlushStoreForFirstInput() { + setup(); + pipeInput("input-topic", "a", 1L, 9999L); + compareKeyValue(testDriver.readRecord("result-topic", stringDeserializer, longDeserializer), "a", 21L); + assertTrue(testDriver.isEmpty("result-topic")); + } + 
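+ // For reference, a minimal sketch of the same flow written against the typed TestInputTopic/TestOutputTopic
+ // API instead of the pipeInput()/readRecord() helpers used by the surrounding tests. It assumes a driver
+ // prepared exactly as in setup(), i.e. "aggStore" already holds ("a", 21L) and results go to "result-topic":
+ //
+ //     final TestInputTopic<String, Long> input =
+ //         testDriver.createInputTopic("input-topic", new StringSerializer(), new LongSerializer());
+ //     final TestOutputTopic<String, Long> output =
+ //         testDriver.createOutputTopic("result-topic", new StringDeserializer(), new LongDeserializer());
+ //     input.pipeInput("a", 1L, 9999L);
+ //     assertThat(output.readKeyValue(), equalTo(new KeyValue<>("a", 21L)));
+ //     assertTrue(output.isEmpty());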
+ @Test + public void shouldNotUpdateStoreForSmallerValue() { + setup(); + pipeInput("input-topic", "a", 1L, 9999L); + assertThat(store.get("a"), equalTo(21L)); + compareKeyValue(testDriver.readRecord("result-topic", stringDeserializer, longDeserializer), "a", 21L); + assertTrue(testDriver.isEmpty("result-topic")); + } + + @Test + public void shouldNotUpdateStoreForLargerValue() { + setup(); + pipeInput("input-topic", "a", 42L, 9999L); + assertThat(store.get("a"), equalTo(42L)); + compareKeyValue(testDriver.readRecord("result-topic", stringDeserializer, longDeserializer), "a", 42L); + assertTrue(testDriver.isEmpty("result-topic")); + } + + @Test + public void shouldUpdateStoreForNewKey() { + setup(); + pipeInput("input-topic", "b", 21L, 9999L); + assertThat(store.get("b"), equalTo(21L)); + compareKeyValue(testDriver.readRecord("result-topic", stringDeserializer, longDeserializer), "a", 21L); + compareKeyValue(testDriver.readRecord("result-topic", stringDeserializer, longDeserializer), "b", 21L); + assertTrue(testDriver.isEmpty("result-topic")); + } + + @Test + public void shouldPunctuateIfEvenTimeAdvances() { + setup(); + pipeInput("input-topic", "a", 1L, 9999L); + compareKeyValue(testDriver.readRecord("result-topic", stringDeserializer, longDeserializer), "a", 21L); + + pipeInput("input-topic", "a", 1L, 9999L); + assertTrue(testDriver.isEmpty("result-topic")); + + pipeInput("input-topic", "a", 1L, 10000L); + compareKeyValue(testDriver.readRecord("result-topic", stringDeserializer, longDeserializer), "a", 21L); + assertTrue(testDriver.isEmpty("result-topic")); + } + + @Test + public void shouldPunctuateIfWallClockTimeAdvances() { + setup(); + testDriver.advanceWallClockTime(Duration.ofMillis(60000)); + compareKeyValue(testDriver.readRecord("result-topic", stringDeserializer, longDeserializer), "a", 21L); + assertTrue(testDriver.isEmpty("result-topic")); + } + + private static class CustomMaxAggregatorSupplier implements ProcessorSupplier { + @Override + public Processor get() { + return new CustomMaxAggregator(); + } + } + + private static class CustomMaxAggregator implements Processor { + ProcessorContext context; + private KeyValueStore store; + + @Override + public void init(final ProcessorContext context) { + this.context = context; + context.schedule(Duration.ofMinutes(1), PunctuationType.WALL_CLOCK_TIME, this::flushStore); + context.schedule(Duration.ofSeconds(10), PunctuationType.STREAM_TIME, this::flushStore); + store = context.getStateStore("aggStore"); + } + + @Override + public void process(final Record record) { + final Long oldValue = store.get(record.key()); + if (oldValue == null || record.value() > oldValue) { + store.put(record.key(), record.value()); + } + } + + private void flushStore(final long timestamp) { + try (final KeyValueIterator it = store.all()) { + while (it.hasNext()) { + final KeyValue next = it.next(); + context.forward(new Record<>(next.key, next.value, timestamp)); + } + } + } + } + + @Test + public void shouldAllowPrePopulatingStatesStoresWithCachingEnabled() { + final Topology topology = new Topology(); + topology.addSource("sourceProcessor", "input-topic"); + topology.addProcessor("aggregator", new CustomMaxAggregatorSupplier(), "sourceProcessor"); + topology.addStateStore(Stores.keyValueStoreBuilder( + Stores.inMemoryKeyValueStore("aggStore"), + Serdes.String(), + Serdes.Long()).withCachingEnabled(), // intentionally turn on caching to achieve better test coverage + "aggregator"); + + testDriver = new TopologyTestDriver(topology, config); + + store = 
testDriver.getKeyValueStore("aggStore"); + store.put("a", 21L); + } + + @Test + public void shouldCleanUpPersistentStateStoresOnClose() { + final Topology topology = new Topology(); + topology.addSource("sourceProcessor", "input-topic"); + topology.addProcessor( + "storeProcessor", + new ProcessorSupplier() { + @Override + public Processor get() { + return new Processor() { + private KeyValueStore store; + + @Override + public void init(final ProcessorContext context) { + this.store = context.getStateStore("storeProcessorStore"); + } + + @Override + public void process(final Record record) { + store.put(record.key(), record.value()); + } + }; + } + }, + "sourceProcessor" + ); + topology.addStateStore(Stores.keyValueStoreBuilder( + Stores.persistentKeyValueStore("storeProcessorStore"), Serdes.String(), Serdes.Long()), + "storeProcessor"); + + final Properties config = new Properties(); + config.put(StreamsConfig.APPLICATION_ID_CONFIG, "test-TopologyTestDriver-cleanup"); + config.put(StreamsConfig.STATE_DIR_CONFIG, TestUtils.tempDirectory().getAbsolutePath()); + config.put(StreamsConfig.DEFAULT_KEY_SERDE_CLASS_CONFIG, Serdes.String().getClass().getName()); + config.put(StreamsConfig.DEFAULT_VALUE_SERDE_CLASS_CONFIG, Serdes.Long().getClass().getName()); + + try (final TopologyTestDriver testDriver = new TopologyTestDriver(topology, config)) { + assertNull(testDriver.getKeyValueStore("storeProcessorStore").get("a")); + testDriver.pipeRecord("input-topic", new TestRecord<>("a", 1L), + new StringSerializer(), new LongSerializer(), Instant.now()); + assertEquals(1L, testDriver.getKeyValueStore("storeProcessorStore").get("a")); + } + + + try (final TopologyTestDriver testDriver = new TopologyTestDriver(topology, config)) { + assertNull(testDriver.getKeyValueStore("storeProcessorStore").get("a"), + "Closing the prior test driver should have cleaned up this store and value."); + } + + } + + @Test + public void shouldFeedStoreFromGlobalKTable() { + final StreamsBuilder builder = new StreamsBuilder(); + builder.globalTable("topic", + Consumed.with(Serdes.String(), Serdes.String()), + Materialized.as("globalStore")); + try (final TopologyTestDriver testDriver = new TopologyTestDriver(builder.build(), config)) { + final KeyValueStore globalStore = testDriver.getKeyValueStore("globalStore"); + assertNotNull(globalStore); + assertNotNull(testDriver.getAllStateStores().get("globalStore")); + testDriver.pipeRecord( + "topic", + new TestRecord<>("k1", "value1"), + new StringSerializer(), + new StringSerializer(), + Instant.now()); + // we expect to have both in the global store, the one from pipeInput and the one from the producer + assertEquals("value1", globalStore.get("k1")); + } + } + + private Topology setupMultipleSourcesPatternTopology(final Pattern... 
sourceTopicPatternNames) { + final Topology topology = new Topology(); + + final String[] processorNames = new String[sourceTopicPatternNames.length]; + int i = 0; + for (final Pattern sourceTopicPatternName : sourceTopicPatternNames) { + final String sourceName = sourceTopicPatternName + "-source"; + final String processorName = sourceTopicPatternName + "-processor"; + topology.addSource(sourceName, sourceTopicPatternName); + processorNames[i++] = processorName; + topology.addProcessor(processorName, new MockProcessorSupplier(), sourceName); + } + topology.addSink("sink-topic", SINK_TOPIC_1, processorNames); + return topology; + } + + @Test + public void shouldProcessFromSourcesThatMatchMultiplePattern() { + + final Pattern pattern2Source1 = Pattern.compile("source-topic-\\d"); + final Pattern pattern2Source2 = Pattern.compile("source-topic-[A-Z]"); + final String consumerTopic2 = "source-topic-Z"; + + final TestRecord consumerRecord2 = new TestRecord<>(key2, value2, null, timestamp2); + + testDriver = new TopologyTestDriver(setupMultipleSourcesPatternTopology(pattern2Source1, pattern2Source2), config); + + final List processedRecords1 = mockProcessors.get(0).processedRecords; + final List processedRecords2 = mockProcessors.get(1).processedRecords; + + pipeRecord(SOURCE_TOPIC_1, testRecord1); + + assertEquals(1, processedRecords1.size()); + assertEquals(0, processedRecords2.size()); + + final TTDTestRecord record1 = processedRecords1.get(0); + final TTDTestRecord expectedResult1 = new TTDTestRecord(SOURCE_TOPIC_1, testRecord1, 0L); + assertThat(record1, equalTo(expectedResult1)); + + pipeRecord(consumerTopic2, consumerRecord2); + + assertEquals(1, processedRecords1.size()); + assertEquals(1, processedRecords2.size()); + + final TTDTestRecord record2 = processedRecords2.get(0); + final TTDTestRecord expectedResult2 = new TTDTestRecord(consumerTopic2, consumerRecord2, 0L); + assertThat(record2, equalTo(expectedResult2)); + } + + @Test + public void shouldProcessFromSourceThatMatchPattern() { + final String sourceName = "source"; + final Pattern pattern2Source1 = Pattern.compile("source-topic-\\d"); + + final Topology topology = new Topology(); + + topology.addSource(sourceName, pattern2Source1); + topology.addSink("sink", SINK_TOPIC_1, sourceName); + + testDriver = new TopologyTestDriver(topology, config); + pipeRecord(SOURCE_TOPIC_1, testRecord1); + + final ProducerRecord outputRecord = testDriver.readRecord(SINK_TOPIC_1); + assertEquals(key1, outputRecord.key()); + assertEquals(value1, outputRecord.value()); + assertEquals(SINK_TOPIC_1, outputRecord.topic()); + } + + @Test + public void shouldThrowPatternNotValidForTopicNameException() { + final String sourceName = "source"; + final String pattern2Source1 = "source-topic-\\d"; + + final Topology topology = new Topology(); + + topology.addSource(sourceName, pattern2Source1); + topology.addSink("sink", SINK_TOPIC_1, sourceName); + + testDriver = new TopologyTestDriver(topology, config); + try { + pipeRecord(SOURCE_TOPIC_1, testRecord1); + } catch (final TopologyException exception) { + final String str = + String.format( + "Invalid topology: Topology add source of type String for topic: %s cannot contain regex pattern for " + + "input record topic: %s and hence cannot process the message.", + pattern2Source1, + SOURCE_TOPIC_1); + assertEquals(str, exception.getMessage()); + } + } + + @Test + public void shouldNotCreateStateDirectoryForStatelessTopology() { + setup(); + final String stateDir = 
config.getProperty(StreamsConfig.STATE_DIR_CONFIG); + final File appDir = new File(stateDir, config.getProperty(StreamsConfig.APPLICATION_ID_CONFIG)); + assertFalse(appDir.exists()); + } + + @Test + public void shouldCreateStateDirectoryForStatefulTopology() { + setup(Stores.persistentKeyValueStore("aggStore")); + final String stateDir = config.getProperty(StreamsConfig.STATE_DIR_CONFIG); + final File appDir = new File(stateDir, config.getProperty(StreamsConfig.APPLICATION_ID_CONFIG)); + + assertTrue(appDir.exists()); + assertTrue(appDir.isDirectory()); + + final TaskId taskId = new TaskId(0, 0); + assertTrue(new File(appDir, taskId.toString()).exists()); + } + + @Test + public void shouldEnqueueLaterOutputsAfterEarlierOnes() { + final Topology topology = new Topology(); + topology.addSource("source", new StringDeserializer(), new StringDeserializer(), "input"); + topology.addProcessor( + "recursiveProcessor", + () -> new Processor() { + private ProcessorContext context; + + @Override + public void init(final ProcessorContext context) { + this.context = context; + } + + @Override + public void process(final Record record) { + final String value = record.value(); + if (!value.startsWith("recurse-")) { + context.forward(record.withValue("recurse-" + value), "recursiveSink"); + } + context.forward(record, "sink"); + } + }, + "source" + ); + topology.addSink("recursiveSink", "input", new StringSerializer(), new StringSerializer(), "recursiveProcessor"); + topology.addSink("sink", "output", new StringSerializer(), new StringSerializer(), "recursiveProcessor"); + + try (final TopologyTestDriver topologyTestDriver = new TopologyTestDriver(topology)) { + final TestInputTopic in = topologyTestDriver.createInputTopic("input", new StringSerializer(), new StringSerializer()); + final TestOutputTopic out = topologyTestDriver.createOutputTopic("output", new StringDeserializer(), new StringDeserializer()); + + // given the topology above, we expect to see the output _first_ echo the input + // and _then_ print it with "recurse-" prepended. 
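+ // (The "recurse-" record is routed back to the "input" topic via recursiveSink, re-consumed by the source, and only then forwarded to "output", which is why it shows up second.)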
+ + in.pipeInput("B", "beta"); + final List> events = out.readKeyValuesToList(); + assertThat( + events, + is(Arrays.asList( + new KeyValue<>("B", "beta"), + new KeyValue<>("B", "recurse-beta") + )) + ); + + } + } + + @Test + public void shouldApplyGlobalUpdatesCorrectlyInRecursiveTopologies() { + final Topology topology = new Topology(); + topology.addSource("source", new StringDeserializer(), new StringDeserializer(), "input"); + topology.addGlobalStore( + Stores.keyValueStoreBuilder(Stores.inMemoryKeyValueStore("global-store"), Serdes.String(), Serdes.String()).withLoggingDisabled(), + "globalSource", + new StringDeserializer(), + new StringDeserializer(), + "global-topic", + "globalProcessor", + () -> new Processor() { + private KeyValueStore stateStore; + + @Override + public void init(final ProcessorContext context) { + stateStore = context.getStateStore("global-store"); + } + + @Override + public void process(final Record record) { + stateStore.put(record.key(), record.value()); + } + } + ); + topology.addProcessor( + "recursiveProcessor", + () -> new Processor() { + private ProcessorContext context; + + @Override + public void init(final ProcessorContext context) { + this.context = context; + } + + @Override + public void process(final Record record) { + final String value = record.value(); + if (!value.startsWith("recurse-")) { + context.forward(record.withValue("recurse-" + value), "recursiveSink"); + } + context.forward(record, "sink"); + context.forward(record, "globalSink"); + } + }, + "source" + ); + topology.addSink("recursiveSink", "input", new StringSerializer(), new StringSerializer(), "recursiveProcessor"); + topology.addSink("sink", "output", new StringSerializer(), new StringSerializer(), "recursiveProcessor"); + topology.addSink("globalSink", "global-topic", new StringSerializer(), new StringSerializer(), "recursiveProcessor"); + + try (final TopologyTestDriver topologyTestDriver = new TopologyTestDriver(topology)) { + final TestInputTopic in = topologyTestDriver.createInputTopic("input", new StringSerializer(), new StringSerializer()); + final TestOutputTopic globalTopic = topologyTestDriver.createOutputTopic("global-topic", new StringDeserializer(), new StringDeserializer()); + + in.pipeInput("A", "alpha"); + + // expect the global store to correctly reflect the last update + final KeyValueStore keyValueStore = topologyTestDriver.getKeyValueStore("global-store"); + assertThat(keyValueStore, notNullValue()); + assertThat(keyValueStore.get("A"), is("recurse-alpha")); + + // and also just make sure the test really sent both events to the topic. + final List> events = globalTopic.readKeyValuesToList(); + assertThat( + events, + is(Arrays.asList( + new KeyValue<>("A", "alpha"), + new KeyValue<>("A", "recurse-alpha") + )) + ); + } + } + + @Test + public void shouldRespectTaskIdling() { + final Properties properties = new Properties(); + // This is the key to this test. Wall-clock time doesn't advance automatically in TopologyTestDriver, + // so with an idle time specified, TTD can't just expect all enqueued records to be processable. 
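+ // With max.task.idle.ms set, a buffered record only becomes processable once every input topic of the task has data queued, or once the configured idle time has elapsed on the driver's manually advanced wall clock.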
+ properties.setProperty(StreamsConfig.MAX_TASK_IDLE_MS_CONFIG, "1000"); + + final Topology topology = new Topology(); + topology.addSource("source1", new StringDeserializer(), new StringDeserializer(), "input1"); + topology.addSource("source2", new StringDeserializer(), new StringDeserializer(), "input2"); + topology.addSink("sink", "output", new StringSerializer(), new StringSerializer(), "source1", "source2"); + + try (final TopologyTestDriver topologyTestDriver = new TopologyTestDriver(topology, properties)) { + final TestInputTopic in1 = topologyTestDriver.createInputTopic("input1", new StringSerializer(), new StringSerializer()); + final TestInputTopic in2 = topologyTestDriver.createInputTopic("input2", new StringSerializer(), new StringSerializer()); + final TestOutputTopic out = topologyTestDriver.createOutputTopic("output", new StringDeserializer(), new StringDeserializer()); + + in1.pipeInput("A", "alpha"); + topologyTestDriver.advanceWallClockTime(Duration.ofMillis(1)); + + // only one input has records, and it's only been one ms + assertThat(out.readKeyValuesToList(), is(Collections.emptyList())); + + in2.pipeInput("B", "beta"); + + // because both topics have records, we can process (even though it's only been one ms) + // but after processing A (the earlier record), we now only have one input queued, so + // task idling takes effect again + assertThat( + out.readKeyValuesToList(), + is(Collections.singletonList( + new KeyValue<>("A", "alpha") + )) + ); + + topologyTestDriver.advanceWallClockTime(Duration.ofSeconds(1)); + + // now that one second has elapsed, the idle time has expired, and we can process B + assertThat( + out.readKeyValuesToList(), + is(Collections.singletonList( + new KeyValue<>("B", "beta") + )) + ); + } + } +} diff --git a/streams/test-utils/src/test/java/org/apache/kafka/streams/WindowStoreFacadeTest.java b/streams/test-utils/src/test/java/org/apache/kafka/streams/WindowStoreFacadeTest.java new file mode 100644 index 0000000..a0d575e --- /dev/null +++ b/streams/test-utils/src/test/java/org/apache/kafka/streams/WindowStoreFacadeTest.java @@ -0,0 +1,115 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams; + +import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.streams.TopologyTestDriver.WindowStoreFacade; +import org.apache.kafka.streams.processor.ProcessorContext; +import org.apache.kafka.streams.processor.StateStore; +import org.apache.kafka.streams.processor.StateStoreContext; +import org.apache.kafka.streams.state.TimestampedWindowStore; +import org.apache.kafka.streams.state.ValueAndTimestamp; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +public class WindowStoreFacadeTest { + @SuppressWarnings("unchecked") + private final TimestampedWindowStore mockedWindowTimestampStore = mock(TimestampedWindowStore.class); + + private WindowStoreFacade windowStoreFacade; + + @BeforeEach + public void setup() { + windowStoreFacade = new WindowStoreFacade<>(mockedWindowTimestampStore); + } + + @SuppressWarnings("deprecation") // test of deprecated method + @Test + public void shouldForwardDeprecatedInit() { + final ProcessorContext context = mock(ProcessorContext.class); + final StateStore store = mock(StateStore.class); + + windowStoreFacade.init(context, store); + verify(mockedWindowTimestampStore) + .init(context, store); + } + + @Test + public void shouldForwardInit() { + final StateStoreContext context = mock(StateStoreContext.class); + final StateStore store = mock(StateStore.class); + + windowStoreFacade.init(context, store); + verify(mockedWindowTimestampStore) + .init(context, store); + } + + @Test + public void shouldPutWindowStartTimestampWithUnknownTimestamp() { + windowStoreFacade.put("key", "value", 21L); + verify(mockedWindowTimestampStore) + .put("key", ValueAndTimestamp.make("value", ConsumerRecord.NO_TIMESTAMP), 21L); + } + + @Test + public void shouldForwardFlush() { + windowStoreFacade.flush(); + verify(mockedWindowTimestampStore).flush(); + } + + @Test + public void shouldForwardClose() { + windowStoreFacade.close(); + verify(mockedWindowTimestampStore).close(); + } + + @Test + public void shouldReturnName() { + when(mockedWindowTimestampStore.name()).thenReturn("name"); + + assertThat(windowStoreFacade.name(), is("name")); + verify(mockedWindowTimestampStore).name(); + } + + @Test + public void shouldReturnIsPersistent() { + when(mockedWindowTimestampStore.persistent()) + .thenReturn(true, false); + + assertThat(windowStoreFacade.persistent(), is(true)); + assertThat(windowStoreFacade.persistent(), is(false)); + verify(mockedWindowTimestampStore, times(2)).persistent(); + } + + @Test + public void shouldReturnIsOpen() { + when(mockedWindowTimestampStore.isOpen()) + .thenReturn(true, false); + + assertThat(windowStoreFacade.isOpen(), is(true)); + assertThat(windowStoreFacade.isOpen(), is(false)); + verify(mockedWindowTimestampStore, times(2)).isOpen(); + } + +} diff --git a/streams/test-utils/src/test/java/org/apache/kafka/streams/test/MockProcessorContextAPITest.java b/streams/test-utils/src/test/java/org/apache/kafka/streams/test/MockProcessorContextAPITest.java new file mode 100644 index 0000000..88c67d5 --- /dev/null +++ b/streams/test-utils/src/test/java/org/apache/kafka/streams/test/MockProcessorContextAPITest.java @@ -0,0 +1,353 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * 
contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.test; + +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.processor.PunctuationType; +import org.apache.kafka.streams.processor.Punctuator; +import org.apache.kafka.streams.processor.TaskId; +import org.apache.kafka.streams.processor.api.MockProcessorContext; +import org.apache.kafka.streams.processor.api.MockProcessorContext.CapturedForward; +import org.apache.kafka.streams.processor.api.Processor; +import org.apache.kafka.streams.processor.api.ProcessorContext; +import org.apache.kafka.streams.processor.api.Record; +import org.apache.kafka.streams.processor.api.RecordMetadata; +import org.apache.kafka.streams.state.KeyValueStore; +import org.apache.kafka.streams.state.StoreBuilder; +import org.apache.kafka.streams.state.Stores; +import org.junit.jupiter.api.Test; + +import java.io.File; +import java.time.Duration; +import java.util.List; +import java.util.Optional; +import java.util.Properties; + +import static java.util.Arrays.asList; +import static java.util.Collections.singletonList; +import static org.apache.kafka.common.utils.Utils.mkEntry; +import static org.apache.kafka.common.utils.Utils.mkMap; +import static org.apache.kafka.common.utils.Utils.mkProperties; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.empty; +import static org.hamcrest.Matchers.is; + +public class MockProcessorContextAPITest { + @Test + public void shouldCaptureOutputRecords() { + final Processor processor = new Processor() { + private ProcessorContext context; + + @Override + public void init(final ProcessorContext context) { + this.context = context; + } + + @Override + public void process(final Record record) { + final String key = record.key(); + final Long value = record.value(); + context.forward(record.withKey(key + value).withValue(key.length() + value)); + } + }; + + final MockProcessorContext context = new MockProcessorContext<>(); + processor.init(context); + + processor.process(new Record<>("foo", 5L, 0L)); + processor.process(new Record<>("barbaz", 50L, 0L)); + + final List> actual = context.forwarded(); + final List> expected = asList( + new CapturedForward<>(new Record<>("foo5", 8L, 0L)), + new CapturedForward<>(new Record<>("barbaz50", 56L, 0L)) + ); + assertThat(actual, is(expected)); + + context.resetForwards(); + + assertThat(context.forwarded(), empty()); + } + + @Test + public void shouldCaptureRecordsOutputToChildByName() { + final Processor processor = new Processor() { + private ProcessorContext context; + + @Override + public void process(final Record record) { + final String key = record.key(); + final Long value = record.value(); + if (count == 0) { + context.forward(new Record<>("start", -1L, 0L)); // 
broadcast + } + final String toChild = count % 2 == 0 ? "george" : "pete"; + context.forward(new Record<>(key + value, key.length() + value, 0L), toChild); + count++; + } + + @Override + public void init(final ProcessorContext context) { + this.context = context; + } + + private int count = 0; + + }; + + final MockProcessorContext context = new MockProcessorContext<>(); + + processor.init(context); + + processor.process(new Record<>("foo", 5L, 0L)); + processor.process(new Record<>("barbaz", 50L, 0L)); + + { + final List> forwarded = context.forwarded(); + final List> expected = asList( + new CapturedForward<>(new Record<>("start", -1L, 0L), Optional.empty()), + new CapturedForward<>(new Record<>("foo5", 8L, 0L), Optional.of("george")), + new CapturedForward<>(new Record<>("barbaz50", 56L, 0L), Optional.of("pete")) + ); + + assertThat(forwarded, is(expected)); + } + { + final List> forwarded = context.forwarded("george"); + final List> expected = asList( + new CapturedForward<>(new Record<>("start", -1L, 0L), Optional.empty()), + new CapturedForward<>(new Record<>("foo5", 8L, 0L), Optional.of("george")) + ); + + assertThat(forwarded, is(expected)); + } + { + final List> forwarded = context.forwarded("pete"); + final List> expected = asList( + new CapturedForward<>(new Record<>("start", -1L, 0L), Optional.empty()), + new CapturedForward<>(new Record<>("barbaz50", 56L, 0L), Optional.of("pete")) + ); + + assertThat(forwarded, is(expected)); + } + { + final List> forwarded = context.forwarded("steve"); + final List> expected = singletonList( + new CapturedForward<>(new Record<>("start", -1L, 0L)) + ); + + assertThat(forwarded, is(expected)); + } + } + + @Test + public void shouldCaptureCommitsAndAllowReset() { + final Processor processor = new Processor() { + private ProcessorContext context; + private int count = 0; + + @Override + public void init(final ProcessorContext context) { + this.context = context; + } + + @Override + public void process(final Record record) { + if (++count > 2) { + context.commit(); + } + } + }; + + final MockProcessorContext context = new MockProcessorContext<>(); + + processor.init(context); + + processor.process(new Record<>("foo", 5L, 0L)); + processor.process(new Record<>("barbaz", 50L, 0L)); + + assertThat(context.committed(), is(false)); + + processor.process(new Record<>("foobar", 500L, 0L)); + + assertThat(context.committed(), is(true)); + + context.resetCommit(); + + assertThat(context.committed(), is(false)); + } + + @Test + public void shouldStoreAndReturnStateStores() { + final Processor processor = new Processor() { + private ProcessorContext context; + + @Override + public void init(final ProcessorContext context) { + this.context = context; + } + + @Override + public void process(final Record record) { + final String key = record.key(); + final Long value = record.value(); + final KeyValueStore stateStore = context.getStateStore("my-state"); + + stateStore.put(key, (stateStore.get(key) == null ? 0 : stateStore.get(key)) + value); + stateStore.put("all", (stateStore.get("all") == null ? 
0 : stateStore.get("all")) + value); + } + + }; + + final MockProcessorContext context = new MockProcessorContext<>(); + + final StoreBuilder> storeBuilder = Stores.keyValueStoreBuilder( + Stores.inMemoryKeyValueStore("my-state"), + Serdes.String(), + Serdes.Long()).withLoggingDisabled(); + + final KeyValueStore store = storeBuilder.build(); + + store.init(context.getStateStoreContext(), store); + + processor.init(context); + + processor.process(new Record<>("foo", 5L, 0L)); + processor.process(new Record<>("bar", 50L, 0L)); + + assertThat(store.get("foo"), is(5L)); + assertThat(store.get("bar"), is(50L)); + assertThat(store.get("all"), is(55L)); + } + + + @Test + public void shouldCaptureApplicationAndRecordMetadata() { + final Properties config = mkProperties( + mkMap( + mkEntry(StreamsConfig.APPLICATION_ID_CONFIG, "testMetadata"), + mkEntry(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, "") + ) + ); + + final Processor processor = new Processor() { + private ProcessorContext context; + + @Override + public void init(final ProcessorContext context) { + this.context = context; + } + + @Override + public void process(final Record record) { + context.forward(new Record("appId", context.applicationId(), 0L)); + context.forward(new Record("taskId", context.taskId(), 0L)); + + if (context.recordMetadata().isPresent()) { + final RecordMetadata recordMetadata = context.recordMetadata().get(); + context.forward(new Record("topic", recordMetadata.topic(), 0L)); + context.forward(new Record("partition", recordMetadata.partition(), 0L)); + context.forward(new Record("offset", recordMetadata.offset(), 0L)); + } + + context.forward(new Record("record", record, 0L)); + } + }; + + final MockProcessorContext context = new MockProcessorContext<>(config); + processor.init(context); + + processor.process(new Record<>("foo", 5L, 0L)); + { + final List> forwarded = context.forwarded(); + final List> expected = asList( + new CapturedForward<>(new Record<>("appId", "testMetadata", 0L)), + new CapturedForward<>(new Record<>("taskId", new TaskId(0, 0), 0L)), + new CapturedForward<>(new Record<>("record", new Record<>("foo", 5L, 0L), 0L)) + ); + assertThat(forwarded, is(expected)); + } + context.resetForwards(); + context.setRecordMetadata("t1", 0, 0L); + processor.process(new Record<>("foo", 5L, 0L)); + { + final List> forwarded = context.forwarded(); + final List> expected = asList( + new CapturedForward<>(new Record<>("appId", "testMetadata", 0L)), + new CapturedForward<>(new Record<>("taskId", new TaskId(0, 0), 0L)), + new CapturedForward<>(new Record<>("topic", "t1", 0L)), + new CapturedForward<>(new Record<>("partition", 0, 0L)), + new CapturedForward<>(new Record<>("offset", 0L, 0L)), + new CapturedForward<>(new Record<>("record", new Record<>("foo", 5L, 0L), 0L)) + ); + assertThat(forwarded, is(expected)); + } + } + + @Test + public void shouldCapturePunctuator() { + final Processor processor = new Processor() { + @Override + public void init(final ProcessorContext context) { + context.schedule( + Duration.ofSeconds(1L), + PunctuationType.WALL_CLOCK_TIME, + timestamp -> context.commit() + ); + } + + @Override + public void process(final Record record) {} + }; + + final MockProcessorContext context = new MockProcessorContext<>(); + + processor.init(context); + + final MockProcessorContext.CapturedPunctuator capturedPunctuator = context.scheduledPunctuators().get(0); + assertThat(capturedPunctuator.getInterval(), is(Duration.ofMillis(1000L))); + assertThat(capturedPunctuator.getType(), 
is(PunctuationType.WALL_CLOCK_TIME)); + assertThat(capturedPunctuator.cancelled(), is(false)); + + final Punctuator punctuator = capturedPunctuator.getPunctuator(); + assertThat(context.committed(), is(false)); + punctuator.punctuate(1234L); + assertThat(context.committed(), is(true)); + } + + @Test + public void fullConstructorShouldSetAllExpectedAttributes() { + final Properties config = new Properties(); + config.put(StreamsConfig.APPLICATION_ID_CONFIG, "testFullConstructor"); + config.put(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, ""); + config.put(StreamsConfig.DEFAULT_KEY_SERDE_CLASS_CONFIG, Serdes.String().getClass()); + config.put(StreamsConfig.DEFAULT_VALUE_SERDE_CLASS_CONFIG, Serdes.Long().getClass()); + + final File dummyFile = new File(""); + final MockProcessorContext context = + new MockProcessorContext<>(config, new TaskId(1, 1), dummyFile); + + assertThat(context.applicationId(), is("testFullConstructor")); + assertThat(context.taskId(), is(new TaskId(1, 1))); + assertThat(context.appConfigs().get(StreamsConfig.APPLICATION_ID_CONFIG), is("testFullConstructor")); + assertThat(context.appConfigsWithPrefix("application.").get("id"), is("testFullConstructor")); + assertThat(context.keySerde().getClass(), is(Serdes.String().getClass())); + assertThat(context.valueSerde().getClass(), is(Serdes.Long().getClass())); + assertThat(context.stateDir(), is(dummyFile)); + } +} diff --git a/streams/test-utils/src/test/java/org/apache/kafka/streams/test/MockProcessorContextStateStoreTest.java b/streams/test-utils/src/test/java/org/apache/kafka/streams/test/MockProcessorContextStateStoreTest.java new file mode 100644 index 0000000..ca64266 --- /dev/null +++ b/streams/test-utils/src/test/java/org/apache/kafka/streams/test/MockProcessorContextStateStoreTest.java @@ -0,0 +1,187 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.test; + +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.processor.StateStore; +import org.apache.kafka.streams.processor.TaskId; +import org.apache.kafka.streams.processor.api.MockProcessorContext; +import org.apache.kafka.streams.state.KeyValueBytesStoreSupplier; +import org.apache.kafka.streams.state.KeyValueStore; +import org.apache.kafka.streams.state.SessionBytesStoreSupplier; +import org.apache.kafka.streams.state.SessionStore; +import org.apache.kafka.streams.state.StoreBuilder; +import org.apache.kafka.streams.state.Stores; +import org.apache.kafka.streams.state.WindowBytesStoreSupplier; +import org.apache.kafka.streams.state.WindowStore; +import org.apache.kafka.test.TestUtils; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; + +import java.io.File; +import java.io.IOException; +import java.time.Duration; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.stream.Stream; + +import static java.util.Arrays.asList; +import static org.apache.kafka.common.utils.Utils.mkEntry; +import static org.apache.kafka.common.utils.Utils.mkMap; +import static org.apache.kafka.common.utils.Utils.mkProperties; +import static org.junit.jupiter.api.Assertions.assertThrows; + +public class MockProcessorContextStateStoreTest { + + public static Stream parameters() { + final List booleans = asList(true, false); + + final List values = new ArrayList<>(); + + for (final Boolean timestamped : booleans) { + for (final Boolean caching : booleans) { + for (final Boolean logging : booleans) { + final List keyValueBytesStoreSuppliers = asList( + Stores.inMemoryKeyValueStore("kv" + timestamped + caching + logging), + Stores.persistentKeyValueStore("kv" + timestamped + caching + logging), + Stores.persistentTimestampedKeyValueStore("kv" + timestamped + caching + logging) + ); + for (final KeyValueBytesStoreSupplier supplier : keyValueBytesStoreSuppliers) { + final StoreBuilder> builder; + if (timestamped) { + builder = Stores.timestampedKeyValueStoreBuilder(supplier, Serdes.String(), Serdes.Long()); + } else { + builder = Stores.keyValueStoreBuilder(supplier, Serdes.String(), Serdes.Long()); + } + if (caching) { + builder.withCachingEnabled(); + } else { + builder.withCachingDisabled(); + } + if (logging) { + builder.withLoggingEnabled(Collections.emptyMap()); + } else { + builder.withLoggingDisabled(); + } + + values.add(Arguments.of(builder, timestamped, caching, logging)); + } + } + } + } + + for (final Boolean timestamped : booleans) { + for (final Boolean caching : booleans) { + for (final Boolean logging : booleans) { + final List windowBytesStoreSuppliers = asList( + Stores.inMemoryWindowStore("w" + timestamped + caching + logging, Duration.ofSeconds(1), Duration.ofSeconds(1), false), + Stores.persistentWindowStore("w" + timestamped + caching + logging, Duration.ofSeconds(1), Duration.ofSeconds(1), false), + Stores.persistentTimestampedWindowStore("w" + timestamped + caching + logging, Duration.ofSeconds(1), Duration.ofSeconds(1), false) + ); + + for (final WindowBytesStoreSupplier supplier : windowBytesStoreSuppliers) { + final StoreBuilder> builder; + if (timestamped) { + builder = Stores.timestampedWindowStoreBuilder(supplier, Serdes.String(), Serdes.Long()); + } else { + 
builder = Stores.windowStoreBuilder(supplier, Serdes.String(), Serdes.Long()); + } + if (caching) { + builder.withCachingEnabled(); + } else { + builder.withCachingDisabled(); + } + if (logging) { + builder.withLoggingEnabled(Collections.emptyMap()); + } else { + builder.withLoggingDisabled(); + } + + values.add(Arguments.of(builder, timestamped, caching, logging)); + } + } + } + } + + for (final Boolean caching : booleans) { + for (final Boolean logging : booleans) { + final List sessionBytesStoreSuppliers = asList( + Stores.inMemorySessionStore("s" + caching + logging, Duration.ofSeconds(1)), + Stores.persistentSessionStore("s" + caching + logging, Duration.ofSeconds(1)) + ); + + for (final SessionBytesStoreSupplier supplier : sessionBytesStoreSuppliers) { + final StoreBuilder> builder = + Stores.sessionStoreBuilder(supplier, Serdes.String(), Serdes.Long()); + if (caching) { + builder.withCachingEnabled(); + } else { + builder.withCachingDisabled(); + } + if (logging) { + builder.withLoggingEnabled(Collections.emptyMap()); + } else { + builder.withLoggingDisabled(); + } + + values.add(Arguments.of(builder, false, caching, logging)); + } + } + } + + return values.stream(); + } + + @ParameterizedTest(name = "builder = {0}, timestamped = {1}, caching = {2}, logging = {3}") + @MethodSource(value = "parameters") + public void shouldEitherInitOrThrow(final StoreBuilder builder, + final boolean timestamped, + final boolean caching, + final boolean logging) { + final File stateDir = TestUtils.tempDirectory(); + try { + final MockProcessorContext context = new MockProcessorContext<>( + mkProperties(mkMap( + mkEntry(StreamsConfig.APPLICATION_ID_CONFIG, ""), + mkEntry(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, "") + )), + new TaskId(0, 0), + stateDir + ); + final StateStore store = builder.build(); + if (caching || logging) { + assertThrows( + IllegalArgumentException.class, + () -> store.init(context.getStateStoreContext(), store) + ); + } else { + store.init(context.getStateStoreContext(), store); + store.close(); + } + } finally { + try { + Utils.delete(stateDir); + } catch (final IOException e) { + // Failed to clean up the state dir. The JVM hooks will try again later. + } + } + } +} diff --git a/streams/test-utils/src/test/java/org/apache/kafka/streams/test/TestRecordTest.java b/streams/test-utils/src/test/java/org/apache/kafka/streams/test/TestRecordTest.java new file mode 100644 index 0000000..ad3b1a2 --- /dev/null +++ b/streams/test-utils/src/test/java/org/apache/kafka/streams/test/TestRecordTest.java @@ -0,0 +1,172 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.test; + +import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.clients.producer.ProducerRecord; +import org.apache.kafka.common.header.Header; +import org.apache.kafka.common.header.Headers; +import org.apache.kafka.common.header.internals.RecordHeader; +import org.apache.kafka.common.header.internals.RecordHeaders; +import org.apache.kafka.common.record.TimestampType; + +import org.junit.jupiter.api.Test; + +import java.time.Instant; +import java.util.Optional; + +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.allOf; +import static org.hamcrest.Matchers.hasProperty; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +public class TestRecordTest { + private final String key = "testKey"; + private final int value = 1; + private final Headers headers = new RecordHeaders( + new Header[]{ + new RecordHeader("foo", "value".getBytes()), + new RecordHeader("bar", null), + new RecordHeader("\"A\\u00ea\\u00f1\\u00fcC\"", "value".getBytes()) + }); + private final Instant recordTime = Instant.parse("2019-06-01T10:00:00Z"); + private final long recordMs = recordTime.toEpochMilli(); + + @Test + public void testFields() { + final TestRecord testRecord = new TestRecord<>(key, value, headers, recordTime); + assertThat(testRecord.key(), equalTo(key)); + assertThat(testRecord.value(), equalTo(value)); + assertThat(testRecord.headers(), equalTo(headers)); + assertThat(testRecord.timestamp(), equalTo(recordMs)); + + assertThat(testRecord.getKey(), equalTo(key)); + assertThat(testRecord.getValue(), equalTo(value)); + assertThat(testRecord.getHeaders(), equalTo(headers)); + assertThat(testRecord.getRecordTime(), equalTo(recordTime)); + } + + @Test + public void testMultiFieldMatcher() { + final TestRecord testRecord = new TestRecord<>(key, value, headers, recordTime); + + assertThat(testRecord, allOf( + hasProperty("key", equalTo(key)), + hasProperty("value", equalTo(value)), + hasProperty("headers", equalTo(headers)))); + + assertThat(testRecord, allOf( + hasProperty("key", equalTo(key)), + hasProperty("value", equalTo(value)), + hasProperty("headers", equalTo(headers)), + hasProperty("recordTime", equalTo(recordTime)))); + + assertThat(testRecord, allOf( + hasProperty("key", equalTo(key)), + hasProperty("value", equalTo(value)))); + } + + + @Test + public void testEqualsAndHashCode() { + final TestRecord testRecord = new TestRecord<>(key, value, headers, recordTime); + assertEquals(testRecord, testRecord); + assertEquals(testRecord.hashCode(), testRecord.hashCode()); + + final TestRecord equalRecord = new TestRecord<>(key, value, headers, recordTime); + assertEquals(testRecord, equalRecord); + assertEquals(testRecord.hashCode(), equalRecord.hashCode()); + + final TestRecord equalRecordMs = new TestRecord<>(key, value, headers, recordMs); + assertEquals(testRecord, equalRecordMs); + assertEquals(testRecord.hashCode(), equalRecordMs.hashCode()); + + final Headers headers2 = new RecordHeaders( + new Header[]{ + new RecordHeader("foo", "value".getBytes()), + new RecordHeader("bar", null), + }); + final TestRecord headerMismatch = new TestRecord<>(key, value, headers2, recordTime); + assertNotEquals(testRecord, headerMismatch); + + final TestRecord keyMisMatch = new TestRecord<>("test-mismatch", value, headers, 
recordTime); + assertNotEquals(testRecord, keyMisMatch); + + final TestRecord valueMisMatch = new TestRecord<>(key, 2, headers, recordTime); + assertNotEquals(testRecord, valueMisMatch); + + final TestRecord timeMisMatch = new TestRecord<>(key, value, headers, recordTime.plusMillis(1)); + assertNotEquals(testRecord, timeMisMatch); + + final TestRecord nullFieldsRecord = new TestRecord<>(null, null, null, (Instant) null); + assertEquals(nullFieldsRecord, nullFieldsRecord); + assertEquals(nullFieldsRecord.hashCode(), nullFieldsRecord.hashCode()); + } + + @Test + public void testPartialConstructorEquals() { + final TestRecord record1 = new TestRecord<>(value); + assertThat(record1, equalTo(new TestRecord<>(null, value, null, (Instant) null))); + + final TestRecord record2 = new TestRecord<>(key, value); + assertThat(record2, equalTo(new TestRecord<>(key, value, null, (Instant) null))); + + final TestRecord record3 = new TestRecord<>(key, value, headers); + assertThat(record3, equalTo(new TestRecord<>(key, value, headers, (Long) null))); + + final TestRecord record4 = new TestRecord<>(key, value, recordTime); + assertThat(record4, equalTo(new TestRecord<>(key, value, null, recordMs))); + } + + @Test + public void testInvalidRecords() { + assertThrows(IllegalArgumentException.class, + () -> new TestRecord<>(key, value, headers, -1L)); + } + + @Test + public void testToString() { + final TestRecord testRecord = new TestRecord<>(key, value, headers, recordTime); + assertThat(testRecord.toString(), equalTo("TestRecord[key=testKey, value=1, " + + "headers=RecordHeaders(headers = [RecordHeader(key = foo, value = [118, 97, 108, 117, 101]), " + + "RecordHeader(key = bar, value = null), RecordHeader(key = \"A\\u00ea\\u00f1\\u00fcC\", value = [118, 97, 108, 117, 101])], isReadOnly = false), " + + "recordTime=2019-06-01T10:00:00Z]")); + } + + @Test + public void testConsumerRecord() { + final String topicName = "topic"; + final ConsumerRecord consumerRecord = new ConsumerRecord<>(topicName, 1, 0, recordMs, + TimestampType.CREATE_TIME, 0, 0, key, value, headers, Optional.empty()); + final TestRecord testRecord = new TestRecord<>(consumerRecord); + final TestRecord expectedRecord = new TestRecord<>(key, value, headers, recordTime); + assertEquals(expectedRecord, testRecord); + } + + @Test + public void testProducerRecord() { + final String topicName = "topic"; + final ProducerRecord producerRecord = + new ProducerRecord<>(topicName, 1, recordMs, key, value, headers); + final TestRecord testRecord = new TestRecord<>(producerRecord); + final TestRecord expectedRecord = new TestRecord<>(key, value, headers, recordTime); + assertEquals(expectedRecord, testRecord); + } +} diff --git a/streams/test-utils/src/test/java/org/apache/kafka/streams/test/wordcount/WindowedWordCountProcessorSupplier.java b/streams/test-utils/src/test/java/org/apache/kafka/streams/test/wordcount/WindowedWordCountProcessorSupplier.java new file mode 100644 index 0000000..403e453 --- /dev/null +++ b/streams/test-utils/src/test/java/org/apache/kafka/streams/test/wordcount/WindowedWordCountProcessorSupplier.java @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.test.wordcount; + +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.kstream.Windowed; +import org.apache.kafka.streams.processor.PunctuationType; +import org.apache.kafka.streams.processor.api.Processor; +import org.apache.kafka.streams.processor.api.ProcessorContext; +import org.apache.kafka.streams.processor.api.ProcessorSupplier; +import org.apache.kafka.streams.processor.api.Record; +import org.apache.kafka.streams.state.KeyValueIterator; +import org.apache.kafka.streams.state.WindowStore; + +import java.time.Duration; +import java.util.Locale; + +public final class WindowedWordCountProcessorSupplier implements ProcessorSupplier { + + @Override + public Processor get() { + return new Processor() { + private WindowStore windowStore; + + @Override + public void init(final ProcessorContext context) { + context.schedule(Duration.ofSeconds(1), PunctuationType.STREAM_TIME, timestamp -> { + try (final KeyValueIterator, Integer> iter = windowStore.all()) { + while (iter.hasNext()) { + final KeyValue, Integer> entry = iter.next(); + context.forward(new Record<>(entry.key.toString(), entry.value.toString(), timestamp)); + } + } + }); + windowStore = context.getStateStore("WindowedCounts"); + } + + @Override + public void process(final Record record) { + final String[] words = record.value().toLowerCase(Locale.getDefault()).split(" "); + final long timestamp = record.timestamp(); + + // calculate the window as every 100 ms + // Note this has to be aligned with the configuration for the window store you register separately + final long windowStart = timestamp / 100 * 100; + + for (final String word : words) { + final Integer oldValue = windowStore.fetch(word, windowStart); + + if (oldValue == null) { + windowStore.put(word, 1, windowStart); + } else { + windowStore.put(word, oldValue + 1, windowStart); + } + } + } + }; + } +} diff --git a/streams/test-utils/src/test/java/org/apache/kafka/streams/test/wordcount/WindowedWordCountProcessorTest.java b/streams/test-utils/src/test/java/org/apache/kafka/streams/test/wordcount/WindowedWordCountProcessorTest.java new file mode 100644 index 0000000..203c51d --- /dev/null +++ b/streams/test-utils/src/test/java/org/apache/kafka/streams/test/wordcount/WindowedWordCountProcessorTest.java @@ -0,0 +1,149 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.test.wordcount; + +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.streams.processor.TaskId; +import org.apache.kafka.streams.processor.api.MockProcessorContext; +import org.apache.kafka.streams.processor.api.MockProcessorContext.CapturedForward; +import org.apache.kafka.streams.processor.api.Processor; +import org.apache.kafka.streams.processor.api.Record; +import org.apache.kafka.streams.state.Stores; +import org.apache.kafka.streams.state.WindowStore; +import org.apache.kafka.test.TestUtils; +import org.junit.jupiter.api.Test; + +import java.io.File; +import java.io.IOException; +import java.time.Duration; +import java.util.List; +import java.util.Properties; + +import static java.util.Arrays.asList; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.is; + +public class WindowedWordCountProcessorTest { + @Test + public void shouldWorkWithInMemoryStore() { + final MockProcessorContext context = new MockProcessorContext<>(); + + // Create, initialize, and register the state store. + final WindowStore store = + Stores.windowStoreBuilder(Stores.inMemoryWindowStore("WindowedCounts", + Duration.ofDays(24), + Duration.ofMillis(100), + false), + Serdes.String(), + Serdes.Integer()) + .withLoggingDisabled() // Changelog is not supported by MockProcessorContext. + .withCachingDisabled() // Caching is not supported by MockProcessorContext. + .build(); + store.init(context.getStateStoreContext(), store); + context.getStateStoreContext().register(store, null); + + // Create and initialize the processor under test + final Processor processor = new WindowedWordCountProcessorSupplier().get(); + processor.init(context); + + // send a record to the processor + processor.process(new Record<>("key", "alpha beta gamma alpha", 101L)); + + // send a record to the processor in a new window + processor.process(new Record<>("key", "gamma delta", 221L)); + + // note that the processor does not forward during process() + assertThat(context.forwarded().isEmpty(), is(true)); + + // now, we trigger the punctuator, which iterates over the state store and forwards the contents. + context.scheduledPunctuators().get(0).getPunctuator().punctuate(1_000L); + + // finally, we can verify the output. + final List> capturedForwards = context.forwarded(); + final List> expected = asList( + new CapturedForward<>(new Record<>("[alpha@100/200]", "2", 1_000L)), + new CapturedForward<>(new Record<>("[beta@100/200]", "1", 1_000L)), + new CapturedForward<>(new Record<>("[gamma@100/200]", "1", 1_000L)), + new CapturedForward<>(new Record<>("[delta@200/300]", "1", 1_000L)), + new CapturedForward<>(new Record<>("[gamma@200/300]", "1", 1_000L)) + ); + + assertThat(capturedForwards, is(expected)); + + store.close(); + } + + @Test + public void shouldWorkWithPersistentStore() throws IOException { + final File stateDir = TestUtils.tempDirectory(); + + try { + final MockProcessorContext context = new MockProcessorContext<>( + new Properties(), + new TaskId(0, 0), + stateDir + ); + + // Create, initialize, and register the state store. 
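+            // (Same wiring as the in-memory variant above, spelled out for clarity: MockProcessorContext
+            //  does not manage state stores itself, so the test builds the store, disables caching and
+            //  change-logging, initializes it with store.init(context.getStateStoreContext(), store),
+            //  and registers it so the processor's own init() can later look it up via
+            //  context.getStateStore("WindowedCounts").)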
+ final WindowStore store = + Stores.windowStoreBuilder(Stores.persistentWindowStore("WindowedCounts", + Duration.ofDays(24), + Duration.ofMillis(100), + false), + Serdes.String(), + Serdes.Integer()) + .withLoggingDisabled() // Changelog is not supported by MockProcessorContext. + .withCachingDisabled() // Caching is not supported by MockProcessorContext. + .build(); + store.init(context.getStateStoreContext(), store); + context.getStateStoreContext().register(store, null); + + // Create and initialize the processor under test + final Processor processor = new WindowedWordCountProcessorSupplier().get(); + processor.init(context); + + // send a record to the processor + processor.process(new Record<>("key", "alpha beta gamma alpha", 101L)); + + // send a record to the processor in a new window + processor.process(new Record<>("key", "gamma delta", 221L)); + + // note that the processor does not forward during process() + assertThat(context.forwarded().isEmpty(), is(true)); + + // now, we trigger the punctuator, which iterates over the state store and forwards the contents. + context.scheduledPunctuators().get(0).getPunctuator().punctuate(1_000L); + + // finally, we can verify the output. + final List> capturedForwards = context.forwarded(); + final List> expected = asList( + new CapturedForward<>(new Record<>("[alpha@100/200]", "2", 1_000L)), + new CapturedForward<>(new Record<>("[beta@100/200]", "1", 1_000L)), + new CapturedForward<>(new Record<>("[delta@200/300]", "1", 1_000L)), + new CapturedForward<>(new Record<>("[gamma@100/200]", "1", 1_000L)), + new CapturedForward<>(new Record<>("[gamma@200/300]", "1", 1_000L)) + ); + + assertThat(capturedForwards, is(expected)); + + store.close(); + } finally { + Utils.delete(stateDir); + } + } +} diff --git a/streams/test-utils/src/test/resources/log4j.properties b/streams/test-utils/src/test/resources/log4j.properties new file mode 100644 index 0000000..be36f90 --- /dev/null +++ b/streams/test-utils/src/test/resources/log4j.properties @@ -0,0 +1,21 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+log4j.rootLogger=INFO, stdout + +log4j.appender.stdout=org.apache.log4j.ConsoleAppender +log4j.appender.stdout.layout=org.apache.log4j.PatternLayout +log4j.appender.stdout.layout.ConversionPattern=[%d] %p %m (%c:%L)%n + +log4j.logger.org.apache.kafka=INFO diff --git a/streams/upgrade-system-tests-0100/src/test/java/org/apache/kafka/streams/tests/StreamsUpgradeTest.java b/streams/upgrade-system-tests-0100/src/test/java/org/apache/kafka/streams/tests/StreamsUpgradeTest.java new file mode 100644 index 0000000..27712cc --- /dev/null +++ b/streams/upgrade-system-tests-0100/src/test/java/org/apache/kafka/streams/tests/StreamsUpgradeTest.java @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.tests; + +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.streams.KafkaStreams; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.kstream.KStream; +import org.apache.kafka.streams.kstream.KStreamBuilder; +import org.apache.kafka.streams.processor.AbstractProcessor; +import org.apache.kafka.streams.processor.Processor; +import org.apache.kafka.streams.processor.ProcessorContext; +import org.apache.kafka.streams.processor.ProcessorSupplier; + +import java.util.Properties; + +public class StreamsUpgradeTest { + + @SuppressWarnings("unchecked") + public static void main(final String[] args) throws Exception { + if (args.length < 2) { + System.err.println("StreamsUpgradeTest requires two arguments (zookeeper-url, properties-file) but only " + args.length + " provided: " + + (args.length > 0 ? 
args[0] + " " : "")); + } + final String zookeeper = args[0]; + final String propFileName = args[1]; + + final Properties streamsProperties = Utils.loadProps(propFileName); + + System.out.println("StreamsTest instance started (StreamsUpgradeTest v0.10.0)"); + System.out.println("zookeeper=" + zookeeper); + System.out.println("props=" + streamsProperties); + + final KStreamBuilder builder = new KStreamBuilder(); + final KStream dataStream = builder.stream("data"); + dataStream.process(printProcessorSupplier()); + dataStream.to("echo"); + + final Properties config = new Properties(); + config.setProperty(StreamsConfig.APPLICATION_ID_CONFIG, "StreamsUpgradeTest"); + config.setProperty(StreamsConfig.ZOOKEEPER_CONNECT_CONFIG, zookeeper); + config.put(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG, 1000L); + config.putAll(streamsProperties); + + final KafkaStreams streams = new KafkaStreams(builder, config); + streams.start(); + + Runtime.getRuntime().addShutdownHook(new Thread() { + @Override + public void run() { + System.out.println("closing Kafka Streams instance"); + System.out.flush(); + streams.close(); + System.out.println("UPGRADE-TEST-CLIENT-CLOSED"); + System.out.flush(); + } + }); + } + + private static ProcessorSupplier printProcessorSupplier() { + return new ProcessorSupplier() { + public Processor get() { + return new AbstractProcessor() { + private int numRecordsProcessed = 0; + + @Override + public void init(final ProcessorContext context) { + System.out.println("[0.10.0] initializing processor: topic=data taskId=" + context.taskId()); + numRecordsProcessed = 0; + } + + @Override + public void process(final K key, final V value) { + numRecordsProcessed++; + if (numRecordsProcessed % 100 == 0) { + System.out.println("processed " + numRecordsProcessed + " records from topic=data"); + } + } + + @Override + public void punctuate(final long timestamp) {} + + @Override + public void close() {} + }; + } + }; + } +} diff --git a/streams/upgrade-system-tests-0100/src/test/java/org/apache/kafka/streams/tests/StreamsUpgradeToCooperativeRebalanceTest.java b/streams/upgrade-system-tests-0100/src/test/java/org/apache/kafka/streams/tests/StreamsUpgradeToCooperativeRebalanceTest.java new file mode 100644 index 0000000..ee15b1d --- /dev/null +++ b/streams/upgrade-system-tests-0100/src/test/java/org/apache/kafka/streams/tests/StreamsUpgradeToCooperativeRebalanceTest.java @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.tests; + +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.streams.KafkaStreams; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.kstream.ForeachAction; +import org.apache.kafka.streams.kstream.KStream; +import org.apache.kafka.streams.kstream.KStreamBuilder; + +import java.util.Properties; + +public class StreamsUpgradeToCooperativeRebalanceTest { + + + @SuppressWarnings("unchecked") + public static void main(final String[] args) throws Exception { + if (args.length < 2) { + System.err.println("StreamsUpgradeToCooperativeRebalanceTest requires two arguments (zookeeper-url, properties-file) but only " + args.length + " provided: " + + (args.length > 0 ? args[0] : "")); + } + + final String zookeeper = args[0]; + final String propFileName = args[1]; + + final Properties streamsProperties = Utils.loadProps(propFileName); + final Properties config = new Properties(); + + System.out.println("StreamsTest instance started (StreamsUpgradeToCooperativeRebalanceTest v0.10.0)"); + System.out.println("zookeeper=" + zookeeper); + System.out.println("props=" + config); + + config.put(StreamsConfig.APPLICATION_ID_CONFIG, "cooperative-rebalance-upgrade"); + config.put(StreamsConfig.KEY_SERDE_CLASS_CONFIG, Serdes.String().getClass()); + config.put(StreamsConfig.VALUE_SERDE_CLASS_CONFIG, Serdes.String().getClass()); + config.setProperty(StreamsConfig.ZOOKEEPER_CONNECT_CONFIG, zookeeper); + config.put(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG, 1000L); + config.putAll(streamsProperties); + + final String sourceTopic = config.getProperty("source.topic", "source"); + final String sinkTopic = config.getProperty("sink.topic", "sink"); + final int reportInterval = Integer.parseInt(config.getProperty("report.interval", "100")); + final String upgradePhase = config.getProperty("upgrade.phase", ""); + + final KStreamBuilder builder = new KStreamBuilder(); + + final KStream upgradeStream = builder.stream(sourceTopic); + upgradeStream.foreach(new ForeachAction() { + int recordCounter = 0; + + @Override + public void apply(final String key, final String value) { + if (recordCounter++ % reportInterval == 0) { + System.out.println(String.format("%sProcessed %d records so far", upgradePhase, recordCounter)); + System.out.flush(); + } + } + } + ); + upgradeStream.to(sinkTopic); + + final KafkaStreams streams = new KafkaStreams(builder, config); + + + streams.start(); + + Runtime.getRuntime().addShutdownHook(new Thread(() -> { + streams.close(); + System.out.println(String.format("%sCOOPERATIVE-REBALANCE-TEST-CLIENT-CLOSED", upgradePhase)); + System.out.flush(); + })); + } +} diff --git a/streams/upgrade-system-tests-0101/src/test/java/org/apache/kafka/streams/tests/StreamsUpgradeTest.java b/streams/upgrade-system-tests-0101/src/test/java/org/apache/kafka/streams/tests/StreamsUpgradeTest.java new file mode 100644 index 0000000..379720b --- /dev/null +++ b/streams/upgrade-system-tests-0101/src/test/java/org/apache/kafka/streams/tests/StreamsUpgradeTest.java @@ -0,0 +1,106 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.tests; + +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.streams.KafkaStreams; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.kstream.KStream; +import org.apache.kafka.streams.kstream.KStreamBuilder; +import org.apache.kafka.streams.processor.AbstractProcessor; +import org.apache.kafka.streams.processor.Processor; +import org.apache.kafka.streams.processor.ProcessorContext; +import org.apache.kafka.streams.processor.ProcessorSupplier; + +import java.util.Properties; + +public class StreamsUpgradeTest { + + /** + * This test cannot be executed, as long as Kafka 0.10.1.2 is not released + */ + @SuppressWarnings("unchecked") + public static void main(final String[] args) throws Exception { + if (args.length < 2) { + System.err.println("StreamsUpgradeTest requires two arguments (zookeeper-url, properties-file) but only " + args.length + " provided: " + + (args.length > 0 ? args[0] + " " : "")); + } + final String zookeeper = args[0]; + final String propFileName = args[1]; + + final Properties streamsProperties = Utils.loadProps(propFileName); + + System.out.println("StreamsTest instance started (StreamsUpgradeTest v0.10.1)"); + System.out.println("zookeeper=" + zookeeper); + System.out.println("props=" + streamsProperties); + + final KStreamBuilder builder = new KStreamBuilder(); + final KStream dataStream = builder.stream("data"); + dataStream.process(printProcessorSupplier()); + dataStream.to("echo"); + + final Properties config = new Properties(); + config.setProperty(StreamsConfig.APPLICATION_ID_CONFIG, "StreamsUpgradeTest"); + config.setProperty(StreamsConfig.ZOOKEEPER_CONNECT_CONFIG, zookeeper); + config.put(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG, 1000L); + config.putAll(streamsProperties); + + final KafkaStreams streams = new KafkaStreams(builder, config); + streams.start(); + + Runtime.getRuntime().addShutdownHook(new Thread() { + @Override + public void run() { + System.out.println("closing Kafka Streams instance"); + System.out.flush(); + streams.close(); + System.out.println("UPGRADE-TEST-CLIENT-CLOSED"); + System.out.flush(); + } + }); + } + + private static ProcessorSupplier printProcessorSupplier() { + return new ProcessorSupplier() { + public Processor get() { + return new AbstractProcessor() { + private int numRecordsProcessed = 0; + + @Override + public void init(final ProcessorContext context) { + System.out.println("[0.10.1] initializing processor: topic=data taskId=" + context.taskId()); + numRecordsProcessed = 0; + } + + @Override + public void process(final K key, final V value) { + numRecordsProcessed++; + if (numRecordsProcessed % 100 == 0) { + System.out.println("processed " + numRecordsProcessed + " records from topic=data"); + } + } + + @Override + public void punctuate(final long timestamp) {} + + @Override + public void close() {} + }; + } + }; + } +} diff --git a/streams/upgrade-system-tests-0101/src/test/java/org/apache/kafka/streams/tests/StreamsUpgradeToCooperativeRebalanceTest.java 
b/streams/upgrade-system-tests-0101/src/test/java/org/apache/kafka/streams/tests/StreamsUpgradeToCooperativeRebalanceTest.java new file mode 100644 index 0000000..6b339b6 --- /dev/null +++ b/streams/upgrade-system-tests-0101/src/test/java/org/apache/kafka/streams/tests/StreamsUpgradeToCooperativeRebalanceTest.java @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.tests; + +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.streams.KafkaStreams; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.kstream.ForeachAction; +import org.apache.kafka.streams.kstream.KStream; +import org.apache.kafka.streams.kstream.KStreamBuilder; + +import java.util.Properties; + +public class StreamsUpgradeToCooperativeRebalanceTest { + + + @SuppressWarnings("unchecked") + public static void main(final String[] args) throws Exception { + if (args.length < 2) { + System.err.println("StreamsUpgradeToCooperativeRebalanceTest requires two arguments (zookeeper-url, properties-file) but only " + args.length + " provided: " + + (args.length > 0 ? 
args[0] : "")); + } + final String zookeeper = args[0]; + final String propFileName = args[1]; + + final Properties streamsProperties = Utils.loadProps(propFileName); + final Properties config = new Properties(); + + System.out.println("StreamsTest instance started (StreamsUpgradeToCooperativeRebalanceTest v0.10.1)"); + System.out.println("zookeeper=" + zookeeper); + System.out.println("props=" + config); + + config.put(StreamsConfig.APPLICATION_ID_CONFIG, "cooperative-rebalance-upgrade"); + config.put(StreamsConfig.KEY_SERDE_CLASS_CONFIG, Serdes.String().getClass()); + config.put(StreamsConfig.VALUE_SERDE_CLASS_CONFIG, Serdes.String().getClass()); + config.setProperty(StreamsConfig.ZOOKEEPER_CONNECT_CONFIG, zookeeper); + config.put(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG, 1000L); + config.putAll(streamsProperties); + + final String sourceTopic = config.getProperty("source.topic", "source"); + final String sinkTopic = config.getProperty("sink.topic", "sink"); + final int reportInterval = Integer.parseInt(config.getProperty("report.interval", "100")); + final String upgradePhase = config.getProperty("upgrade.phase", ""); + + final KStreamBuilder builder = new KStreamBuilder(); + + final KStream upgradeStream = builder.stream(sourceTopic); + upgradeStream.foreach(new ForeachAction() { + int recordCounter = 0; + + @Override + public void apply(final String key, final String value) { + if (recordCounter++ % reportInterval == 0) { + System.out.println(String.format("%sProcessed %d records so far", upgradePhase, recordCounter)); + System.out.flush(); + } + } + } + ); + upgradeStream.to(sinkTopic); + + final KafkaStreams streams = new KafkaStreams(builder, config); + + + streams.start(); + + Runtime.getRuntime().addShutdownHook(new Thread(() -> { + streams.close(); + System.out.println(String.format("%sCOOPERATIVE-REBALANCE-TEST-CLIENT-CLOSED", upgradePhase)); + System.out.flush(); + })); + } +} diff --git a/streams/upgrade-system-tests-0102/src/test/java/org/apache/kafka/streams/tests/StreamsUpgradeTest.java b/streams/upgrade-system-tests-0102/src/test/java/org/apache/kafka/streams/tests/StreamsUpgradeTest.java new file mode 100644 index 0000000..75e5484 --- /dev/null +++ b/streams/upgrade-system-tests-0102/src/test/java/org/apache/kafka/streams/tests/StreamsUpgradeTest.java @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.tests; + +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.streams.KafkaStreams; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.kstream.KStream; +import org.apache.kafka.streams.kstream.KStreamBuilder; +import org.apache.kafka.streams.processor.AbstractProcessor; +import org.apache.kafka.streams.processor.Processor; +import org.apache.kafka.streams.processor.ProcessorContext; +import org.apache.kafka.streams.processor.ProcessorSupplier; + +import java.util.Properties; + +public class StreamsUpgradeTest { + + @SuppressWarnings("unchecked") + public static void main(final String[] args) throws Exception { + if (args.length < 1) { + System.err.println("StreamsUpgradeTest requires one argument (properties-file) but provided none"); + } + final String propFileName = args[0]; + + final Properties streamsProperties = Utils.loadProps(propFileName); + + System.out.println("StreamsTest instance started (StreamsUpgradeTest v0.10.2)"); + System.out.println("props=" + streamsProperties); + + final KStreamBuilder builder = new KStreamBuilder(); + final KStream dataStream = builder.stream("data"); + dataStream.process(printProcessorSupplier()); + dataStream.to("echo"); + + final Properties config = new Properties(); + config.setProperty(StreamsConfig.APPLICATION_ID_CONFIG, "StreamsUpgradeTest"); + config.put(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG, 1000L); + config.putAll(streamsProperties); + + final KafkaStreams streams = new KafkaStreams(builder, config); + streams.start(); + + Runtime.getRuntime().addShutdownHook(new Thread() { + @Override + public void run() { + streams.close(); + System.out.println("UPGRADE-TEST-CLIENT-CLOSED"); + System.out.flush(); + } + }); + } + + private static ProcessorSupplier printProcessorSupplier() { + return new ProcessorSupplier() { + public Processor get() { + return new AbstractProcessor() { + private int numRecordsProcessed = 0; + + @Override + public void init(final ProcessorContext context) { + System.out.println("[0.10.2] initializing processor: topic=data taskId=" + context.taskId()); + numRecordsProcessed = 0; + } + + @Override + public void process(final K key, final V value) { + numRecordsProcessed++; + if (numRecordsProcessed % 100 == 0) { + System.out.println("processed " + numRecordsProcessed + " records from topic=data"); + } + } + + @Override + public void punctuate(final long timestamp) {} + + @Override + public void close() {} + }; + } + }; + } +} diff --git a/streams/upgrade-system-tests-0102/src/test/java/org/apache/kafka/streams/tests/StreamsUpgradeToCooperativeRebalanceTest.java b/streams/upgrade-system-tests-0102/src/test/java/org/apache/kafka/streams/tests/StreamsUpgradeToCooperativeRebalanceTest.java new file mode 100644 index 0000000..32ef2eb --- /dev/null +++ b/streams/upgrade-system-tests-0102/src/test/java/org/apache/kafka/streams/tests/StreamsUpgradeToCooperativeRebalanceTest.java @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.tests; + +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.streams.KafkaStreams; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.kstream.ForeachAction; +import org.apache.kafka.streams.kstream.KStream; +import org.apache.kafka.streams.kstream.KStreamBuilder; + +import java.util.Properties; + +public class StreamsUpgradeToCooperativeRebalanceTest { + + + @SuppressWarnings("unchecked") + public static void main(final String[] args) throws Exception { + if (args.length < 1) { + System.err.println("StreamsUpgradeToCooperativeRebalanceTest requires one argument (properties-file) but none provided"); + } + final String propFileName = args[0]; + + final Properties streamsProperties = Utils.loadProps(propFileName); + final Properties config = new Properties(); + + System.out.println("StreamsTest instance started (StreamsUpgradeToCooperativeRebalanceTest v0.10.2)"); + System.out.println("props=" + config); + + config.put(StreamsConfig.APPLICATION_ID_CONFIG, "cooperative-rebalance-upgrade"); + config.put(StreamsConfig.KEY_SERDE_CLASS_CONFIG, Serdes.String().getClass()); + config.put(StreamsConfig.VALUE_SERDE_CLASS_CONFIG, Serdes.String().getClass()); + config.put(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG, 1000L); + config.putAll(streamsProperties); + + final String sourceTopic = config.getProperty("source.topic", "source"); + final String sinkTopic = config.getProperty("sink.topic", "sink"); + final int reportInterval = Integer.parseInt(config.getProperty("report.interval", "100")); + final String upgradePhase = config.getProperty("upgrade.phase", ""); + + final KStreamBuilder builder = new KStreamBuilder(); + + final KStream upgradeStream = builder.stream(sourceTopic); + upgradeStream.foreach(new ForeachAction() { + int recordCounter = 0; + + @Override + public void apply(final String key, final String value) { + if (recordCounter++ % reportInterval == 0) { + System.out.println(String.format("%sProcessed %d records so far", upgradePhase, recordCounter)); + System.out.flush(); + } + } + } + ); + upgradeStream.to(sinkTopic); + + final KafkaStreams streams = new KafkaStreams(builder, config); + + + streams.start(); + + Runtime.getRuntime().addShutdownHook(new Thread(() -> { + streams.close(); + System.out.println(String.format("%sCOOPERATIVE-REBALANCE-TEST-CLIENT-CLOSED", upgradePhase)); + System.out.flush(); + })); + } +} diff --git a/streams/upgrade-system-tests-0110/src/test/java/org/apache/kafka/streams/tests/StreamsUpgradeTest.java b/streams/upgrade-system-tests-0110/src/test/java/org/apache/kafka/streams/tests/StreamsUpgradeTest.java new file mode 100644 index 0000000..e029161 --- /dev/null +++ b/streams/upgrade-system-tests-0110/src/test/java/org/apache/kafka/streams/tests/StreamsUpgradeTest.java @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
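Each of these main() methods assembles its Streams configuration the same way: hard-coded defaults go into config first, and config.putAll(streamsProperties) runs last, so anything in the supplied properties file (including bootstrap.servers, which is never set in code) overrides the defaults. A minimal sketch of that precedence, using illustrative keys rather than the StreamsConfig constants:

import java.util.Properties;

// Minimal sketch of the override order the upgrade tests rely on:
// code-level defaults first, then putAll() with the loaded properties file, which wins.
public class ConfigPrecedenceSketch {
    public static void main(final String[] args) {
        final Properties config = new Properties();
        config.put("application.id", "cooperative-rebalance-upgrade"); // default set in code
        config.put("commit.interval.ms", "1000");                      // default set in code

        final Properties fromPropsFile = new Properties();             // stands in for Utils.loadProps(...)
        fromPropsFile.put("commit.interval.ms", "500");

        config.putAll(fromPropsFile);
        System.out.println(config.getProperty("commit.interval.ms")); // prints 500
    }
}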
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.tests; + +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.streams.KafkaStreams; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.kstream.KStream; +import org.apache.kafka.streams.kstream.KStreamBuilder; +import org.apache.kafka.streams.processor.AbstractProcessor; +import org.apache.kafka.streams.processor.Processor; +import org.apache.kafka.streams.processor.ProcessorContext; +import org.apache.kafka.streams.processor.ProcessorSupplier; + +import java.util.Properties; + +public class StreamsUpgradeTest { + + @SuppressWarnings("unchecked") + public static void main(final String[] args) throws Exception { + if (args.length < 1) { + System.err.println("StreamsUpgradeTest requires one argument (properties-file) but provided none"); + } + final String propFileName = args[0]; + + final Properties streamsProperties = Utils.loadProps(propFileName); + + System.out.println("StreamsTest instance started (StreamsUpgradeTest v0.11.0)"); + System.out.println("props=" + streamsProperties); + + final KStreamBuilder builder = new KStreamBuilder(); + final KStream dataStream = builder.stream("data"); + dataStream.process(printProcessorSupplier()); + dataStream.to("echo"); + + final Properties config = new Properties(); + config.setProperty(StreamsConfig.APPLICATION_ID_CONFIG, "StreamsUpgradeTest"); + config.put(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG, 1000L); + config.putAll(streamsProperties); + + final KafkaStreams streams = new KafkaStreams(builder, config); + streams.start(); + + Runtime.getRuntime().addShutdownHook(new Thread() { + @Override + public void run() { + streams.close(); + System.out.println("UPGRADE-TEST-CLIENT-CLOSED"); + System.out.flush(); + } + }); + } + + private static ProcessorSupplier printProcessorSupplier() { + return new ProcessorSupplier() { + public Processor get() { + return new AbstractProcessor() { + private int numRecordsProcessed = 0; + + @Override + public void init(final ProcessorContext context) { + System.out.println("[0.11.0] initializing processor: topic=data taskId=" + context.taskId()); + numRecordsProcessed = 0; + } + + @Override + public void process(final K key, final V value) { + numRecordsProcessed++; + if (numRecordsProcessed % 100 == 0) { + System.out.println("processed " + numRecordsProcessed + " records from topic=data"); + } + } + + @Override + public void punctuate(final long timestamp) {} + + @Override + public void close() {} + }; + } + }; + } +} diff --git a/streams/upgrade-system-tests-0110/src/test/java/org/apache/kafka/streams/tests/StreamsUpgradeToCooperativeRebalanceTest.java b/streams/upgrade-system-tests-0110/src/test/java/org/apache/kafka/streams/tests/StreamsUpgradeToCooperativeRebalanceTest.java new file mode 100644 index 0000000..a2ffa9d --- /dev/null +++ 
b/streams/upgrade-system-tests-0110/src/test/java/org/apache/kafka/streams/tests/StreamsUpgradeToCooperativeRebalanceTest.java @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.tests; + +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.streams.KafkaStreams; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.kstream.ForeachAction; +import org.apache.kafka.streams.kstream.KStream; +import org.apache.kafka.streams.kstream.KStreamBuilder; + +import java.util.Properties; + +public class StreamsUpgradeToCooperativeRebalanceTest { + + + @SuppressWarnings("unchecked") + public static void main(final String[] args) throws Exception { + if (args.length < 1) { + System.err.println("StreamsUpgradeToCooperativeRebalanceTest requires one argument (properties-file) but none provided"); + } + final String propFileName = args[0]; + final Properties streamsProperties = Utils.loadProps(propFileName); + final Properties config = new Properties(); + + System.out.println("StreamsTest instance started (StreamsUpgradeToCooperativeRebalanceTest v0.11.0)"); + System.out.println("props=" + config); + + config.put(StreamsConfig.APPLICATION_ID_CONFIG, "cooperative-rebalance-upgrade"); + config.put(StreamsConfig.DEFAULT_KEY_SERDE_CLASS_CONFIG, Serdes.String().getClass()); + config.put(StreamsConfig.DEFAULT_VALUE_SERDE_CLASS_CONFIG, Serdes.String().getClass()); + config.put(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG, 1000L); + config.putAll(streamsProperties); + + final String sourceTopic = config.getProperty("source.topic", "source"); + final String sinkTopic = config.getProperty("sink.topic", "sink"); + final int reportInterval = Integer.parseInt(config.getProperty("report.interval", "100")); + final String upgradePhase = config.getProperty("upgrade.phase", ""); + + final KStreamBuilder builder = new KStreamBuilder(); + + final KStream upgradeStream = builder.stream(sourceTopic); + upgradeStream.foreach(new ForeachAction() { + int recordCounter = 0; + + @Override + public void apply(final String key, final String value) { + if (recordCounter++ % reportInterval == 0) { + System.out.println(String.format("%sProcessed %d records so far", upgradePhase, recordCounter)); + System.out.flush(); + } + } + } + ); + upgradeStream.to(sinkTopic); + + final KafkaStreams streams = new KafkaStreams(builder, config); + + + streams.start(); + + Runtime.getRuntime().addShutdownHook(new Thread(() -> { + streams.close(); + System.out.println(String.format("%sCOOPERATIVE-REBALANCE-TEST-CLIENT-CLOSED", upgradePhase)); + System.out.flush(); + })); + } +} diff --git 
a/streams/upgrade-system-tests-10/src/test/java/org/apache/kafka/streams/tests/StreamsUpgradeTest.java b/streams/upgrade-system-tests-10/src/test/java/org/apache/kafka/streams/tests/StreamsUpgradeTest.java new file mode 100644 index 0000000..fc7cdd0 --- /dev/null +++ b/streams/upgrade-system-tests-10/src/test/java/org/apache/kafka/streams/tests/StreamsUpgradeTest.java @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.tests; + +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.streams.KafkaStreams; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.kstream.KStream; +import org.apache.kafka.streams.processor.AbstractProcessor; +import org.apache.kafka.streams.processor.Processor; +import org.apache.kafka.streams.processor.ProcessorContext; +import org.apache.kafka.streams.processor.ProcessorSupplier; + +import java.util.Properties; + +public class StreamsUpgradeTest { + + @SuppressWarnings("unchecked") + public static void main(final String[] args) throws Exception { + if (args.length < 1) { + System.err.println("StreamsUpgradeTest requires one argument (properties-file) but provided none"); + } + final String propFileName = args[0]; + + final Properties streamsProperties = Utils.loadProps(propFileName); + + System.out.println("StreamsTest instance started (StreamsUpgradeTest v1.0)"); + System.out.println("props=" + streamsProperties); + + final StreamsBuilder builder = new StreamsBuilder(); + final KStream dataStream = builder.stream("data"); + dataStream.process(printProcessorSupplier()); + dataStream.to("echo"); + + final Properties config = new Properties(); + config.setProperty(StreamsConfig.APPLICATION_ID_CONFIG, "StreamsUpgradeTest"); + config.put(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG, 1000L); + config.putAll(streamsProperties); + + final KafkaStreams streams = new KafkaStreams(builder.build(), config); + streams.start(); + + Runtime.getRuntime().addShutdownHook(new Thread() { + @Override + public void run() { + streams.close(); + System.out.println("UPGRADE-TEST-CLIENT-CLOSED"); + System.out.flush(); + } + }); + } + + private static ProcessorSupplier printProcessorSupplier() { + return new ProcessorSupplier() { + public Processor get() { + return new AbstractProcessor() { + private int numRecordsProcessed = 0; + + @Override + public void init(final ProcessorContext context) { + System.out.println("[1.0] initializing processor: topic=data taskId=" + context.taskId()); + numRecordsProcessed = 0; + } + + @Override + public void process(final K key, final V value) { + numRecordsProcessed++; + if (numRecordsProcessed % 100 == 0) { + System.out.println("processed " + numRecordsProcessed + " 
records from topic=data"); + } + } + + @Override + public void punctuate(final long timestamp) {} + + @Override + public void close() {} + }; + } + }; + } +} diff --git a/streams/upgrade-system-tests-10/src/test/java/org/apache/kafka/streams/tests/StreamsUpgradeToCooperativeRebalanceTest.java b/streams/upgrade-system-tests-10/src/test/java/org/apache/kafka/streams/tests/StreamsUpgradeToCooperativeRebalanceTest.java new file mode 100644 index 0000000..bda6ac4 --- /dev/null +++ b/streams/upgrade-system-tests-10/src/test/java/org/apache/kafka/streams/tests/StreamsUpgradeToCooperativeRebalanceTest.java @@ -0,0 +1,135 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.tests; + +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.streams.KafkaStreams; +import org.apache.kafka.streams.KafkaStreams.State; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.kstream.ForeachAction; +import org.apache.kafka.streams.processor.TaskMetadata; +import org.apache.kafka.streams.processor.ThreadMetadata; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Properties; +import java.util.Set; + +public class StreamsUpgradeToCooperativeRebalanceTest { + + + @SuppressWarnings("unchecked") + public static void main(final String[] args) throws Exception { + if (args.length < 1) { + System.err.println("StreamsUpgradeToCooperativeRebalanceTest requires one argument (properties-file) but none provided"); + } + System.out.println("Args are " + Arrays.toString(args)); + final String propFileName = args[0]; + final Properties streamsProperties = Utils.loadProps(propFileName); + + final Properties config = new Properties(); + System.out.println("StreamsTest instance started (StreamsUpgradeToCooperativeRebalanceTest v1.0)"); + System.out.println("props=" + streamsProperties); + + config.put(StreamsConfig.APPLICATION_ID_CONFIG, "cooperative-rebalance-upgrade"); + config.put(StreamsConfig.DEFAULT_KEY_SERDE_CLASS_CONFIG, Serdes.String().getClass()); + config.put(StreamsConfig.DEFAULT_VALUE_SERDE_CLASS_CONFIG, Serdes.String().getClass()); + config.put(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG, 1000L); + config.putAll(streamsProperties); + + final String sourceTopic = streamsProperties.getProperty("source.topic", "source"); + final String sinkTopic = streamsProperties.getProperty("sink.topic", "sink"); + final String taskDelimiter = streamsProperties.getProperty("task.delimiter", "#"); + final int reportInterval = Integer.parseInt(streamsProperties.getProperty("report.interval", "100")); + final String 
upgradePhase = streamsProperties.getProperty("upgrade.phase", ""); + + final StreamsBuilder builder = new StreamsBuilder(); + + builder.stream(sourceTopic) + .peek(new ForeachAction() { + int recordCounter = 0; + + @Override + public void apply(final String key, final String value) { + if (recordCounter++ % reportInterval == 0) { + System.out.println(String.format("%sProcessed %d records so far", upgradePhase, recordCounter)); + System.out.flush(); + } + } + } + ).to(sinkTopic); + + final KafkaStreams streams = new KafkaStreams(builder.build(), config); + + streams.setStateListener((newState, oldState) -> { + if (newState == State.RUNNING && oldState == State.REBALANCING) { + System.out.println(String.format("%sSTREAMS in a RUNNING State", upgradePhase)); + final Set allThreadMetadata = streams.localThreadsMetadata(); + final StringBuilder taskReportBuilder = new StringBuilder(); + final List activeTasks = new ArrayList<>(); + final List standbyTasks = new ArrayList<>(); + for (final ThreadMetadata threadMetadata : allThreadMetadata) { + getTasks(threadMetadata.activeTasks(), activeTasks); + if (!threadMetadata.standbyTasks().isEmpty()) { + getTasks(threadMetadata.standbyTasks(), standbyTasks); + } + } + addTasksToBuilder(activeTasks, taskReportBuilder); + taskReportBuilder.append(taskDelimiter); + if (!standbyTasks.isEmpty()) { + addTasksToBuilder(standbyTasks, taskReportBuilder); + } + System.out.println("TASK-ASSIGNMENTS:" + taskReportBuilder); + } + + if (newState == State.REBALANCING) { + System.out.println(String.format("%sStarting a REBALANCE", upgradePhase)); + } + }); + + + streams.start(); + + Runtime.getRuntime().addShutdownHook(new Thread(() -> { + streams.close(); + System.out.println(String.format("%sCOOPERATIVE-REBALANCE-TEST-CLIENT-CLOSED", upgradePhase)); + System.out.flush(); + })); + } + + private static void addTasksToBuilder(final List tasks, final StringBuilder builder) { + if (!tasks.isEmpty()) { + for (final String task : tasks) { + builder.append(task).append(","); + } + builder.setLength(builder.length() - 1); + } + } + private static void getTasks(final Set taskMetadata, + final List taskList) { + for (final TaskMetadata task : taskMetadata) { + final Set topicPartitions = task.topicPartitions(); + for (final TopicPartition topicPartition : topicPartitions) { + taskList.add(topicPartition.toString()); + } + } + } +} diff --git a/streams/upgrade-system-tests-11/src/test/java/org/apache/kafka/streams/tests/StreamsUpgradeTest.java b/streams/upgrade-system-tests-11/src/test/java/org/apache/kafka/streams/tests/StreamsUpgradeTest.java new file mode 100644 index 0000000..ebe59ab --- /dev/null +++ b/streams/upgrade-system-tests-11/src/test/java/org/apache/kafka/streams/tests/StreamsUpgradeTest.java @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
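When the 1.0+ cooperative-rebalance clients move from REBALANCING back to RUNNING, they emit a single "TASK-ASSIGNMENTS:" line: the active task partitions comma-joined, then the task delimiter (default "#"), then the standby partitions if any. A compact reformulation of addTasksToBuilder() under the same inputs, just to make that output format explicit:

import java.util.Arrays;
import java.util.List;

// Produces the same report string as addTasksToBuilder() in the test above:
// "<active csv><delimiter><standby csv>", with the standby part omitted when empty.
public class TaskReportFormatSketch {
    static String format(final List<String> active, final List<String> standby, final String delimiter) {
        final StringBuilder report = new StringBuilder(String.join(",", active));
        report.append(delimiter);
        if (!standby.isEmpty()) {
            report.append(String.join(",", standby));
        }
        return report.toString();
    }

    public static void main(final String[] args) {
        // e.g. TASK-ASSIGNMENTS:source-0,source-1#source-2
        System.out.println("TASK-ASSIGNMENTS:"
                + format(Arrays.asList("source-0", "source-1"), Arrays.asList("source-2"), "#"));
    }
}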
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.tests; + +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.streams.KafkaStreams; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.kstream.KStream; +import org.apache.kafka.streams.processor.AbstractProcessor; +import org.apache.kafka.streams.processor.Processor; +import org.apache.kafka.streams.processor.ProcessorContext; +import org.apache.kafka.streams.processor.ProcessorSupplier; + +import java.util.Properties; + +public class StreamsUpgradeTest { + + @SuppressWarnings("unchecked") + public static void main(final String[] args) throws Exception { + if (args.length < 1) { + System.err.println("StreamsUpgradeTest requires one argument (properties-file) but provided none"); + } + final String propFileName = args[0]; + + final Properties streamsProperties = Utils.loadProps(propFileName); + + System.out.println("StreamsTest instance started (StreamsUpgradeTest v1.1)"); + System.out.println("props=" + streamsProperties); + + final StreamsBuilder builder = new StreamsBuilder(); + final KStream dataStream = builder.stream("data"); + dataStream.process(printProcessorSupplier()); + dataStream.to("echo"); + + final Properties config = new Properties(); + config.setProperty(StreamsConfig.APPLICATION_ID_CONFIG, "StreamsUpgradeTest"); + config.put(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG, 1000L); + config.putAll(streamsProperties); + + final KafkaStreams streams = new KafkaStreams(builder.build(), config); + streams.start(); + + Runtime.getRuntime().addShutdownHook(new Thread() { + @Override + public void run() { + streams.close(); + System.out.println("UPGRADE-TEST-CLIENT-CLOSED"); + System.out.flush(); + } + }); + } + + private static ProcessorSupplier printProcessorSupplier() { + return new ProcessorSupplier() { + public Processor get() { + return new AbstractProcessor() { + private int numRecordsProcessed = 0; + + @Override + public void init(final ProcessorContext context) { + System.out.println("[1.1] initializing processor: topic=data taskId=" + context.taskId()); + numRecordsProcessed = 0; + } + + @Override + public void process(final K key, final V value) { + numRecordsProcessed++; + if (numRecordsProcessed % 100 == 0) { + System.out.println("processed " + numRecordsProcessed + " records from topic=data"); + } + } + + @Override + public void punctuate(final long timestamp) {} + + @Override + public void close() {} + }; + } + }; + } +} diff --git a/streams/upgrade-system-tests-11/src/test/java/org/apache/kafka/streams/tests/StreamsUpgradeToCooperativeRebalanceTest.java b/streams/upgrade-system-tests-11/src/test/java/org/apache/kafka/streams/tests/StreamsUpgradeToCooperativeRebalanceTest.java new file mode 100644 index 0000000..6643d29 --- /dev/null +++ b/streams/upgrade-system-tests-11/src/test/java/org/apache/kafka/streams/tests/StreamsUpgradeToCooperativeRebalanceTest.java @@ -0,0 +1,135 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.tests; + +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.streams.KafkaStreams; +import org.apache.kafka.streams.KafkaStreams.State; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.kstream.ForeachAction; +import org.apache.kafka.streams.processor.TaskMetadata; +import org.apache.kafka.streams.processor.ThreadMetadata; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Properties; +import java.util.Set; + +public class StreamsUpgradeToCooperativeRebalanceTest { + + + @SuppressWarnings("unchecked") + public static void main(final String[] args) throws Exception { + if (args.length < 1) { + System.err.println("StreamsUpgradeToCooperativeRebalanceTest requires one argument (properties-file) but none provided"); + } + System.out.println("Args are " + Arrays.toString(args)); + final String propFileName = args[0]; + final Properties streamsProperties = Utils.loadProps(propFileName); + + final Properties config = new Properties(); + System.out.println("StreamsTest instance started (StreamsUpgradeToCooperativeRebalanceTest v1.1)"); + System.out.println("props=" + streamsProperties); + + config.put(StreamsConfig.APPLICATION_ID_CONFIG, "cooperative-rebalance-upgrade"); + config.put(StreamsConfig.DEFAULT_KEY_SERDE_CLASS_CONFIG, Serdes.String().getClass()); + config.put(StreamsConfig.DEFAULT_VALUE_SERDE_CLASS_CONFIG, Serdes.String().getClass()); + config.put(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG, 1000L); + config.putAll(streamsProperties); + + final String sourceTopic = streamsProperties.getProperty("source.topic", "source"); + final String sinkTopic = streamsProperties.getProperty("sink.topic", "sink"); + final String taskDelimiter = streamsProperties.getProperty("task.delimiter", "#"); + final int reportInterval = Integer.parseInt(streamsProperties.getProperty("report.interval", "100")); + final String upgradePhase = streamsProperties.getProperty("upgrade.phase", ""); + + final StreamsBuilder builder = new StreamsBuilder(); + + builder.stream(sourceTopic) + .peek(new ForeachAction() { + int recordCounter = 0; + + @Override + public void apply(final String key, final String value) { + if (recordCounter++ % reportInterval == 0) { + System.out.println(String.format("%sProcessed %d records so far", upgradePhase, recordCounter)); + System.out.flush(); + } + } + } + ).to(sinkTopic); + + final KafkaStreams streams = new KafkaStreams(builder.build(), config); + + streams.setStateListener((newState, oldState) -> { + if (newState == State.RUNNING && oldState == State.REBALANCING) { + System.out.println(String.format("%sSTREAMS in a RUNNING State", upgradePhase)); + final Set allThreadMetadata = streams.localThreadsMetadata(); + final StringBuilder taskReportBuilder = new StringBuilder(); + final List activeTasks = new ArrayList<>(); + final List standbyTasks = new ArrayList<>(); + for (final ThreadMetadata 
threadMetadata : allThreadMetadata) { + getTasks(threadMetadata.activeTasks(), activeTasks); + if (!threadMetadata.standbyTasks().isEmpty()) { + getTasks(threadMetadata.standbyTasks(), standbyTasks); + } + } + addTasksToBuilder(activeTasks, taskReportBuilder); + taskReportBuilder.append(taskDelimiter); + if (!standbyTasks.isEmpty()) { + addTasksToBuilder(standbyTasks, taskReportBuilder); + } + System.out.println("TASK-ASSIGNMENTS:" + taskReportBuilder); + } + + if (newState == State.REBALANCING) { + System.out.println(String.format("%sStarting a REBALANCE", upgradePhase)); + } + }); + + + streams.start(); + + Runtime.getRuntime().addShutdownHook(new Thread(() -> { + streams.close(); + System.out.println(String.format("%sCOOPERATIVE-REBALANCE-TEST-CLIENT-CLOSED", upgradePhase)); + System.out.flush(); + })); + } + + private static void addTasksToBuilder(final List tasks, final StringBuilder builder) { + if (!tasks.isEmpty()) { + for (final String task : tasks) { + builder.append(task).append(","); + } + builder.setLength(builder.length() - 1); + } + } + private static void getTasks(final Set taskMetadata, + final List taskList) { + for (final TaskMetadata task : taskMetadata) { + final Set topicPartitions = task.topicPartitions(); + for (final TopicPartition topicPartition : topicPartitions) { + taskList.add(topicPartition.toString()); + } + } + } +} diff --git a/streams/upgrade-system-tests-20/src/test/java/org/apache/kafka/streams/tests/StreamsUpgradeTest.java b/streams/upgrade-system-tests-20/src/test/java/org/apache/kafka/streams/tests/StreamsUpgradeTest.java new file mode 100644 index 0000000..5becd29 --- /dev/null +++ b/streams/upgrade-system-tests-20/src/test/java/org/apache/kafka/streams/tests/StreamsUpgradeTest.java @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.tests; + +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.streams.KafkaStreams; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.kstream.KStream; +import org.apache.kafka.streams.processor.AbstractProcessor; +import org.apache.kafka.streams.processor.ProcessorContext; +import org.apache.kafka.streams.processor.ProcessorSupplier; + +import java.util.Properties; + +public class StreamsUpgradeTest { + + @SuppressWarnings("unchecked") + public static void main(final String[] args) throws Exception { + if (args.length < 1) { + System.err.println("StreamsUpgradeTest requires one argument (properties-file) but provided none"); + } + final String propFileName = args[0]; + + final Properties streamsProperties = Utils.loadProps(propFileName); + + System.out.println("StreamsTest instance started (StreamsUpgradeTest v2.0)"); + System.out.println("props=" + streamsProperties); + + final StreamsBuilder builder = new StreamsBuilder(); + final KStream dataStream = builder.stream("data"); + dataStream.process(printProcessorSupplier()); + dataStream.to("echo"); + + final Properties config = new Properties(); + config.setProperty(StreamsConfig.APPLICATION_ID_CONFIG, "StreamsUpgradeTest"); + config.put(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG, 1000L); + config.putAll(streamsProperties); + + final KafkaStreams streams = new KafkaStreams(builder.build(), config); + streams.start(); + + Runtime.getRuntime().addShutdownHook(new Thread(() -> { + streams.close(); + System.out.println("UPGRADE-TEST-CLIENT-CLOSED"); + System.out.flush(); + })); + } + + private static ProcessorSupplier printProcessorSupplier() { + return () -> new AbstractProcessor() { + private int numRecordsProcessed = 0; + + @Override + public void init(final ProcessorContext context) { + System.out.println("[2.0] initializing processor: topic=data taskId=" + context.taskId()); + numRecordsProcessed = 0; + } + + @Override + public void process(final K key, final V value) { + numRecordsProcessed++; + if (numRecordsProcessed % 100 == 0) { + System.out.println("processed " + numRecordsProcessed + " records from topic=data"); + } + } + + @Override + public void close() {} + }; + } +} diff --git a/streams/upgrade-system-tests-20/src/test/java/org/apache/kafka/streams/tests/StreamsUpgradeToCooperativeRebalanceTest.java b/streams/upgrade-system-tests-20/src/test/java/org/apache/kafka/streams/tests/StreamsUpgradeToCooperativeRebalanceTest.java new file mode 100644 index 0000000..0c697f6 --- /dev/null +++ b/streams/upgrade-system-tests-20/src/test/java/org/apache/kafka/streams/tests/StreamsUpgradeToCooperativeRebalanceTest.java @@ -0,0 +1,135 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
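The processor each StreamsUpgradeTest attaches to the "data" stream only logs: it prints a version tag in init() and a progress line every 100 records, forwarding nothing. A self-contained sketch of that pattern against the old ProcessorSupplier/AbstractProcessor API these modules compile against; the versionTag parameter is illustrative:

import org.apache.kafka.streams.processor.AbstractProcessor;
import org.apache.kafka.streams.processor.ProcessorContext;
import org.apache.kafka.streams.processor.ProcessorSupplier;

// Sketch of the logging-only processor used by the upgrade tests: no state, no forwarding,
// just a start-up line and a progress line every 100 records.
public class PrintProcessorSketch {
    static <K, V> ProcessorSupplier<K, V> printProcessorSupplier(final String versionTag) {
        return () -> new AbstractProcessor<K, V>() {
            private int numRecordsProcessed = 0;

            @Override
            public void init(final ProcessorContext context) {
                System.out.println("[" + versionTag + "] initializing processor: topic=data taskId=" + context.taskId());
            }

            @Override
            public void process(final K key, final V value) {
                if (++numRecordsProcessed % 100 == 0) {
                    System.out.println("processed " + numRecordsProcessed + " records from topic=data");
                }
            }
        };
    }
}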
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.tests; + +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.streams.KafkaStreams; +import org.apache.kafka.streams.KafkaStreams.State; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.kstream.ForeachAction; +import org.apache.kafka.streams.processor.TaskMetadata; +import org.apache.kafka.streams.processor.ThreadMetadata; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Properties; +import java.util.Set; + +public class StreamsUpgradeToCooperativeRebalanceTest { + + + @SuppressWarnings("unchecked") + public static void main(final String[] args) throws Exception { + if (args.length < 1) { + System.err.println("StreamsUpgradeToCooperativeRebalanceTest requires one argument (properties-file) but none provided"); + } + System.out.println("Args are " + Arrays.toString(args)); + final String propFileName = args[0]; + final Properties streamsProperties = Utils.loadProps(propFileName); + + final Properties config = new Properties(); + System.out.println("StreamsTest instance started (StreamsUpgradeToCooperativeRebalanceTest v2.0)"); + System.out.println("props=" + streamsProperties); + + config.put(StreamsConfig.APPLICATION_ID_CONFIG, "cooperative-rebalance-upgrade"); + config.put(StreamsConfig.DEFAULT_KEY_SERDE_CLASS_CONFIG, Serdes.String().getClass()); + config.put(StreamsConfig.DEFAULT_VALUE_SERDE_CLASS_CONFIG, Serdes.String().getClass()); + config.put(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG, 1000L); + config.putAll(streamsProperties); + + final String sourceTopic = streamsProperties.getProperty("source.topic", "source"); + final String sinkTopic = streamsProperties.getProperty("sink.topic", "sink"); + final String taskDelimiter = streamsProperties.getProperty("task.delimiter", "#"); + final int reportInterval = Integer.parseInt(streamsProperties.getProperty("report.interval", "100")); + final String upgradePhase = streamsProperties.getProperty("upgrade.phase", ""); + + final StreamsBuilder builder = new StreamsBuilder(); + + builder.stream(sourceTopic) + .peek(new ForeachAction() { + int recordCounter = 0; + + @Override + public void apply(final String key, final String value) { + if (recordCounter++ % reportInterval == 0) { + System.out.println(String.format("%sProcessed %d records so far", upgradePhase, recordCounter)); + System.out.flush(); + } + } + } + ).to(sinkTopic); + + final KafkaStreams streams = new KafkaStreams(builder.build(), config); + + streams.setStateListener((newState, oldState) -> { + if (newState == State.RUNNING && oldState == State.REBALANCING) { + System.out.println(String.format("%sSTREAMS in a RUNNING State", upgradePhase)); + final Set allThreadMetadata = streams.localThreadsMetadata(); + final StringBuilder taskReportBuilder = new StringBuilder(); + final List activeTasks = new ArrayList<>(); + final List standbyTasks = new ArrayList<>(); + for (final ThreadMetadata threadMetadata : allThreadMetadata) { + getTasks(threadMetadata.activeTasks(), activeTasks); + if (!threadMetadata.standbyTasks().isEmpty()) { + getTasks(threadMetadata.standbyTasks(), standbyTasks); + } + } + addTasksToBuilder(activeTasks, taskReportBuilder); + taskReportBuilder.append(taskDelimiter); + if 
(!standbyTasks.isEmpty()) { + addTasksToBuilder(standbyTasks, taskReportBuilder); + } + System.out.println("TASK-ASSIGNMENTS:" + taskReportBuilder); + } + + if (newState == State.REBALANCING) { + System.out.println(String.format("%sStarting a REBALANCE", upgradePhase)); + } + }); + + + streams.start(); + + Runtime.getRuntime().addShutdownHook(new Thread(() -> { + streams.close(); + System.out.println(String.format("%sCOOPERATIVE-REBALANCE-TEST-CLIENT-CLOSED", upgradePhase)); + System.out.flush(); + })); + } + + private static void addTasksToBuilder(final List tasks, final StringBuilder builder) { + if (!tasks.isEmpty()) { + for (final String task : tasks) { + builder.append(task).append(","); + } + builder.setLength(builder.length() - 1); + } + } + private static void getTasks(final Set taskMetadata, + final List taskList) { + for (final TaskMetadata task : taskMetadata) { + final Set topicPartitions = task.topicPartitions(); + for (final TopicPartition topicPartition : topicPartitions) { + taskList.add(topicPartition.toString()); + } + } + } +} diff --git a/streams/upgrade-system-tests-21/src/test/java/org/apache/kafka/streams/tests/StreamsUpgradeTest.java b/streams/upgrade-system-tests-21/src/test/java/org/apache/kafka/streams/tests/StreamsUpgradeTest.java new file mode 100644 index 0000000..1f0a17d --- /dev/null +++ b/streams/upgrade-system-tests-21/src/test/java/org/apache/kafka/streams/tests/StreamsUpgradeTest.java @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.tests; + +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.streams.KafkaStreams; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.kstream.KStream; +import org.apache.kafka.streams.processor.AbstractProcessor; +import org.apache.kafka.streams.processor.ProcessorContext; +import org.apache.kafka.streams.processor.ProcessorSupplier; + +import java.util.Properties; + +public class StreamsUpgradeTest { + + @SuppressWarnings("unchecked") + public static void main(final String[] args) throws Exception { + if (args.length < 1) { + System.err.println("StreamsUpgradeTest requires one argument (properties-file) but provided none"); + } + final String propFileName = args[0]; + + final Properties streamsProperties = Utils.loadProps(propFileName); + + System.out.println("StreamsTest instance started (StreamsUpgradeTest v2.1)"); + System.out.println("props=" + streamsProperties); + + final StreamsBuilder builder = new StreamsBuilder(); + final KStream dataStream = builder.stream("data"); + dataStream.process(printProcessorSupplier()); + dataStream.to("echo"); + + final Properties config = new Properties(); + config.setProperty(StreamsConfig.APPLICATION_ID_CONFIG, "StreamsUpgradeTest"); + config.put(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG, 1000L); + config.putAll(streamsProperties); + + final KafkaStreams streams = new KafkaStreams(builder.build(), config); + streams.start(); + + Runtime.getRuntime().addShutdownHook(new Thread(() -> { + streams.close(); + System.out.println("UPGRADE-TEST-CLIENT-CLOSED"); + System.out.flush(); + })); + } + + private static ProcessorSupplier printProcessorSupplier() { + return () -> new AbstractProcessor() { + private int numRecordsProcessed = 0; + + @Override + public void init(final ProcessorContext context) { + System.out.println("[2.1] initializing processor: topic=data taskId=" + context.taskId()); + numRecordsProcessed = 0; + } + + @Override + public void process(final K key, final V value) { + numRecordsProcessed++; + if (numRecordsProcessed % 100 == 0) { + System.out.println("processed " + numRecordsProcessed + " records from topic=data"); + } + } + + @Override + public void close() {} + }; + } +} diff --git a/streams/upgrade-system-tests-21/src/test/java/org/apache/kafka/streams/tests/StreamsUpgradeToCooperativeRebalanceTest.java b/streams/upgrade-system-tests-21/src/test/java/org/apache/kafka/streams/tests/StreamsUpgradeToCooperativeRebalanceTest.java new file mode 100644 index 0000000..299fffa --- /dev/null +++ b/streams/upgrade-system-tests-21/src/test/java/org/apache/kafka/streams/tests/StreamsUpgradeToCooperativeRebalanceTest.java @@ -0,0 +1,135 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
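Beyond the topology, every upgrade client follows the same lifecycle: build KafkaStreams, start it, and rely on a JVM shutdown hook to close the instance and print the marker the harness waits for. A minimal sketch of that wiring, assuming a prebuilt Topology and config (the helper name is illustrative):

import java.util.Properties;

import org.apache.kafka.streams.KafkaStreams;
import org.apache.kafka.streams.Topology;

// Sketch of the start/close pattern shared by the StreamsUpgradeTest variants:
// the shutdown hook closes the instance and emits the "closed" marker.
public class UpgradeLifecycleSketch {
    static KafkaStreams startAndRegisterShutdownHook(final Topology topology, final Properties config) {
        final KafkaStreams streams = new KafkaStreams(topology, config);
        streams.start();

        Runtime.getRuntime().addShutdownHook(new Thread(() -> {
            streams.close();
            System.out.println("UPGRADE-TEST-CLIENT-CLOSED");
            System.out.flush();
        }));
        return streams;
    }
}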
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.tests; + +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.streams.KafkaStreams; +import org.apache.kafka.streams.KafkaStreams.State; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.kstream.ForeachAction; +import org.apache.kafka.streams.processor.TaskMetadata; +import org.apache.kafka.streams.processor.ThreadMetadata; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Properties; +import java.util.Set; + +public class StreamsUpgradeToCooperativeRebalanceTest { + + + @SuppressWarnings("unchecked") + public static void main(final String[] args) throws Exception { + if (args.length < 1) { + System.err.println("StreamsUpgradeToCooperativeRebalanceTest requires one argument (properties-file) but none provided"); + } + System.out.println("Args are " + Arrays.toString(args)); + final String propFileName = args[0]; + final Properties streamsProperties = Utils.loadProps(propFileName); + + final Properties config = new Properties(); + System.out.println("StreamsTest instance started (StreamsUpgradeToCooperativeRebalanceTest v2.2)"); + System.out.println("props=" + streamsProperties); + + config.put(StreamsConfig.APPLICATION_ID_CONFIG, "cooperative-rebalance-upgrade"); + config.put(StreamsConfig.DEFAULT_KEY_SERDE_CLASS_CONFIG, Serdes.String().getClass()); + config.put(StreamsConfig.DEFAULT_VALUE_SERDE_CLASS_CONFIG, Serdes.String().getClass()); + config.put(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG, 1000L); + config.putAll(streamsProperties); + + final String sourceTopic = streamsProperties.getProperty("source.topic", "source"); + final String sinkTopic = streamsProperties.getProperty("sink.topic", "sink"); + final String taskDelimiter = streamsProperties.getProperty("task.delimiter", "#"); + final int reportInterval = Integer.parseInt(streamsProperties.getProperty("report.interval", "100")); + final String upgradePhase = streamsProperties.getProperty("upgrade.phase", ""); + + final StreamsBuilder builder = new StreamsBuilder(); + + builder.stream(sourceTopic) + .peek(new ForeachAction() { + int recordCounter = 0; + + @Override + public void apply(final String key, final String value) { + if (recordCounter++ % reportInterval == 0) { + System.out.println(String.format("%sProcessed %d records so far", upgradePhase, recordCounter)); + System.out.flush(); + } + } + } + ).to(sinkTopic); + + final KafkaStreams streams = new KafkaStreams(builder.build(), config); + + streams.setStateListener((newState, oldState) -> { + if (newState == State.RUNNING && oldState == State.REBALANCING) { + System.out.println(String.format("%sSTREAMS in a RUNNING State", upgradePhase)); + final Set allThreadMetadata = streams.localThreadsMetadata(); + final StringBuilder taskReportBuilder = new StringBuilder(); + final List activeTasks = new ArrayList<>(); + final List standbyTasks = new ArrayList<>(); + for (final ThreadMetadata threadMetadata : allThreadMetadata) { + getTasks(threadMetadata.activeTasks(), activeTasks); + if (!threadMetadata.standbyTasks().isEmpty()) { + getTasks(threadMetadata.standbyTasks(), standbyTasks); + } + } + addTasksToBuilder(activeTasks, taskReportBuilder); + taskReportBuilder.append(taskDelimiter); + if 
(!standbyTasks.isEmpty()) { + addTasksToBuilder(standbyTasks, taskReportBuilder); + } + System.out.println("TASK-ASSIGNMENTS:" + taskReportBuilder); + } + + if (newState == State.REBALANCING) { + System.out.println(String.format("%sStarting a REBALANCE", upgradePhase)); + } + }); + + + streams.start(); + + Runtime.getRuntime().addShutdownHook(new Thread(() -> { + streams.close(); + System.out.println(String.format("%sCOOPERATIVE-REBALANCE-TEST-CLIENT-CLOSED", upgradePhase)); + System.out.flush(); + })); + } + + private static void addTasksToBuilder(final List tasks, final StringBuilder builder) { + if (!tasks.isEmpty()) { + for (final String task : tasks) { + builder.append(task).append(","); + } + builder.setLength(builder.length() - 1); + } + } + private static void getTasks(final Set taskMetadata, + final List taskList) { + for (final TaskMetadata task : taskMetadata) { + final Set topicPartitions = task.topicPartitions(); + for (final TopicPartition topicPartition : topicPartitions) { + taskList.add(topicPartition.toString()); + } + } + } +} diff --git a/streams/upgrade-system-tests-22/src/test/java/org/apache/kafka/streams/tests/SmokeTestClient.java b/streams/upgrade-system-tests-22/src/test/java/org/apache/kafka/streams/tests/SmokeTestClient.java new file mode 100644 index 0000000..ced1369 --- /dev/null +++ b/streams/upgrade-system-tests-22/src/test/java/org/apache/kafka/streams/tests/SmokeTestClient.java @@ -0,0 +1,298 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
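The cooperative-rebalance clients also install a KafkaStreams.StateListener: they announce when a rebalance starts, and once the instance settles from REBALANCING back into RUNNING they walk localThreadsMetadata() to assemble the TASK-ASSIGNMENTS line. A minimal sketch of that listener shape, with the task-report assembly elided:

import org.apache.kafka.streams.KafkaStreams;
import org.apache.kafka.streams.KafkaStreams.State;

// Sketch of the rebalance-tracking listener used by StreamsUpgradeToCooperativeRebalanceTest:
// one line when a rebalance starts, one (plus the task report) when it completes.
public class RebalanceListenerSketch {
    static KafkaStreams.StateListener listener(final String upgradePhase) {
        return (newState, oldState) -> {
            if (newState == State.REBALANCING) {
                System.out.println(String.format("%sStarting a REBALANCE", upgradePhase));
            }
            if (newState == State.RUNNING && oldState == State.REBALANCING) {
                System.out.println(String.format("%sSTREAMS in a RUNNING State", upgradePhase));
                // the real test then collects active/standby task partitions from
                // streams.localThreadsMetadata() and prints "TASK-ASSIGNMENTS:<active>#<standby>"
            }
        };
    }
}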
+ */ +package org.apache.kafka.streams.tests; + +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.common.utils.KafkaThread; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.streams.KafkaStreams; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.Topology; +import org.apache.kafka.streams.kstream.Consumed; +import org.apache.kafka.streams.kstream.Grouped; +import org.apache.kafka.streams.kstream.KGroupedStream; +import org.apache.kafka.streams.kstream.KStream; +import org.apache.kafka.streams.kstream.KTable; +import org.apache.kafka.streams.kstream.Materialized; +import org.apache.kafka.streams.kstream.Produced; +import org.apache.kafka.streams.kstream.Suppressed.BufferConfig; +import org.apache.kafka.streams.kstream.TimeWindows; +import org.apache.kafka.streams.kstream.Windowed; +import org.apache.kafka.streams.state.Stores; +import org.apache.kafka.streams.state.WindowStore; + +import java.io.File; +import java.io.IOException; +import java.nio.file.Files; +import java.time.Duration; +import java.time.Instant; +import java.util.Properties; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; + +import static org.apache.kafka.streams.kstream.Suppressed.untilWindowCloses; + +public class SmokeTestClient extends SmokeTestUtil { + + private final String name; + + private KafkaStreams streams; + private boolean uncaughtException = false; + private boolean started; + private volatile boolean closed; + + private static void addShutdownHook(final String name, final Runnable runnable) { + if (name != null) { + Runtime.getRuntime().addShutdownHook(KafkaThread.nonDaemon(name, runnable)); + } else { + Runtime.getRuntime().addShutdownHook(new Thread(runnable)); + } + } + + private static File tempDirectory() { + final String prefix = "kafka-"; + final File file; + try { + file = Files.createTempDirectory(prefix).toFile(); + } catch (final IOException ex) { + throw new RuntimeException("Failed to create a temp dir", ex); + } + file.deleteOnExit(); + + addShutdownHook("delete-temp-file-shutdown-hook", () -> { + try { + Utils.delete(file); + } catch (final IOException e) { + System.out.println("Error deleting " + file.getAbsolutePath()); + e.printStackTrace(System.out); + } + }); + + return file; + } + + public SmokeTestClient(final String name) { + this.name = name; + } + + public boolean started() { + return started; + } + + public boolean closed() { + return closed; + } + + public void start(final Properties streamsProperties) { + final Topology build = getTopology(); + streams = new KafkaStreams(build, getStreamsConfig(streamsProperties)); + + final CountDownLatch countDownLatch = new CountDownLatch(1); + streams.setStateListener((newState, oldState) -> { + System.out.printf("%s %s: %s -> %s%n", name, Instant.now(), oldState, newState); + if (oldState == KafkaStreams.State.REBALANCING && newState == KafkaStreams.State.RUNNING) { + started = true; + countDownLatch.countDown(); + } + + if (newState == KafkaStreams.State.NOT_RUNNING) { + closed = true; + } + }); + + streams.setUncaughtExceptionHandler((t, e) -> { + System.out.println(name + ": SMOKE-TEST-CLIENT-EXCEPTION"); + System.out.println(name + ": FATAL: An unexpected exception is encountered on thread " + t + ": " + e); + e.printStackTrace(System.out); + uncaughtException = true; + 
streams.close(Duration.ofSeconds(30)); + }); + + addShutdownHook("streams-shutdown-hook", this::close); + + streams.start(); + try { + if (!countDownLatch.await(1, TimeUnit.MINUTES)) { + System.out.println(name + ": SMOKE-TEST-CLIENT-EXCEPTION: Didn't start in one minute"); + } + } catch (final InterruptedException e) { + System.out.println(name + ": SMOKE-TEST-CLIENT-EXCEPTION: " + e); + e.printStackTrace(System.out); + } + System.out.println(name + ": SMOKE-TEST-CLIENT-STARTED"); + System.out.println(name + " started at " + Instant.now()); + } + + public void closeAsync() { + streams.close(Duration.ZERO); + } + + public void close() { + final boolean closed = streams.close(Duration.ofMinutes(1)); + + if (closed && !uncaughtException) { + System.out.println(name + ": SMOKE-TEST-CLIENT-CLOSED"); + } else if (closed) { + System.out.println(name + ": SMOKE-TEST-CLIENT-EXCEPTION"); + } else { + System.out.println(name + ": SMOKE-TEST-CLIENT-EXCEPTION: Didn't close"); + } + } + + private Properties getStreamsConfig(final Properties props) { + final Properties fullProps = new Properties(props); + fullProps.put(StreamsConfig.APPLICATION_ID_CONFIG, "SmokeTest"); + fullProps.put(StreamsConfig.CLIENT_ID_CONFIG, "SmokeTest-" + name); + fullProps.put(StreamsConfig.STATE_DIR_CONFIG, tempDirectory().getAbsolutePath()); + fullProps.putAll(props); + return fullProps; + } + + public Topology getTopology() { + final StreamsBuilder builder = new StreamsBuilder(); + final Consumed stringIntConsumed = Consumed.with(stringSerde, intSerde); + final KStream source = builder.stream("data", stringIntConsumed); + source.filterNot((k, v) -> k.equals("flush")) + .to("echo", Produced.with(stringSerde, intSerde)); + final KStream data = source.filter((key, value) -> value == null || value != END); + data.process(SmokeTestUtil.printProcessorSupplier("data", name)); + + // min + final KGroupedStream groupedData = data.groupByKey(Grouped.with(stringSerde, intSerde)); + + final KTable, Integer> minAggregation = groupedData + .windowedBy(TimeWindows.of(Duration.ofDays(1)).grace(Duration.ofMinutes(1))) + .aggregate( + () -> Integer.MAX_VALUE, + (aggKey, value, aggregate) -> (value < aggregate) ? value : aggregate, + Materialized + .>as("uwin-min") + .withValueSerde(intSerde) + .withRetention(Duration.ofHours(25)) + ); + + streamify(minAggregation, "min-raw"); + + streamify(minAggregation.suppress(untilWindowCloses(BufferConfig.unbounded())), "min-suppressed"); + + minAggregation + .toStream(new Unwindow<>()) + .filterNot((k, v) -> k.equals("flush")) + .to("min", Produced.with(stringSerde, intSerde)); + + final KTable, Integer> smallWindowSum = groupedData + .windowedBy(TimeWindows.of(Duration.ofSeconds(2)).advanceBy(Duration.ofSeconds(1)).grace(Duration.ofSeconds(30))) + .reduce((l, r) -> l + r); + + streamify(smallWindowSum, "sws-raw"); + streamify(smallWindowSum.suppress(untilWindowCloses(BufferConfig.unbounded())), "sws-suppressed"); + + final KTable minTable = builder.table( + "min", + Consumed.with(stringSerde, intSerde), + Materialized.as("minStoreName")); + + minTable.toStream().process(SmokeTestUtil.printProcessorSupplier("min", name)); + + // max + groupedData + .windowedBy(TimeWindows.of(Duration.ofDays(2))) + .aggregate( + () -> Integer.MIN_VALUE, + (aggKey, value, aggregate) -> (value > aggregate) ? 
value : aggregate, + Materialized.>as("uwin-max").withValueSerde(intSerde)) + .toStream(new Unwindow<>()) + .filterNot((k, v) -> k.equals("flush")) + .to("max", Produced.with(stringSerde, intSerde)); + + final KTable maxTable = builder.table( + "max", + Consumed.with(stringSerde, intSerde), + Materialized.as("maxStoreName")); + maxTable.toStream().process(SmokeTestUtil.printProcessorSupplier("max", name)); + + // sum + groupedData + .windowedBy(TimeWindows.of(Duration.ofDays(2))) + .aggregate( + () -> 0L, + (aggKey, value, aggregate) -> (long) value + aggregate, + Materialized.>as("win-sum").withValueSerde(longSerde)) + .toStream(new Unwindow<>()) + .filterNot((k, v) -> k.equals("flush")) + .to("sum", Produced.with(stringSerde, longSerde)); + + final Consumed stringLongConsumed = Consumed.with(stringSerde, longSerde); + final KTable sumTable = builder.table("sum", stringLongConsumed); + sumTable.toStream().process(SmokeTestUtil.printProcessorSupplier("sum", name)); + + // cnt + groupedData + .windowedBy(TimeWindows.of(Duration.ofDays(2))) + .count(Materialized.as("uwin-cnt")) + .toStream(new Unwindow<>()) + .filterNot((k, v) -> k.equals("flush")) + .to("cnt", Produced.with(stringSerde, longSerde)); + + final KTable cntTable = builder.table( + "cnt", + Consumed.with(stringSerde, longSerde), + Materialized.as("cntStoreName")); + cntTable.toStream().process(SmokeTestUtil.printProcessorSupplier("cnt", name)); + + // dif + maxTable + .join( + minTable, + (value1, value2) -> value1 - value2) + .toStream() + .filterNot((k, v) -> k.equals("flush")) + .to("dif", Produced.with(stringSerde, intSerde)); + + // avg + sumTable + .join( + cntTable, + (value1, value2) -> (double) value1 / (double) value2) + .toStream() + .filterNot((k, v) -> k.equals("flush")) + .to("avg", Produced.with(stringSerde, doubleSerde)); + + // test repartition + final Agg agg = new Agg(); + cntTable.groupBy(agg.selector(), Grouped.with(stringSerde, longSerde)) + .aggregate(agg.init(), agg.adder(), agg.remover(), + Materialized.as(Stores.inMemoryKeyValueStore("cntByCnt")) + .withKeySerde(Serdes.String()) + .withValueSerde(Serdes.Long())) + .toStream() + .to("tagg", Produced.with(stringSerde, longSerde)); + + return builder.build(); + } + + private static void streamify(final KTable, Integer> windowedTable, final String topic) { + windowedTable + .toStream() + .filterNot((k, v) -> k.key().equals("flush")) + .map((key, value) -> new KeyValue<>(key.toString(), value)) + .to(topic, Produced.with(stringSerde, intSerde)); + } +} diff --git a/streams/upgrade-system-tests-22/src/test/java/org/apache/kafka/streams/tests/SmokeTestDriver.java b/streams/upgrade-system-tests-22/src/test/java/org/apache/kafka/streams/tests/SmokeTestDriver.java new file mode 100644 index 0000000..d0a7d22 --- /dev/null +++ b/streams/upgrade-system-tests-22/src/test/java/org/apache/kafka/streams/tests/SmokeTestDriver.java @@ -0,0 +1,632 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.tests; + +import org.apache.kafka.clients.consumer.ConsumerConfig; +import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.clients.consumer.ConsumerRecords; +import org.apache.kafka.clients.consumer.KafkaConsumer; +import org.apache.kafka.clients.producer.Callback; +import org.apache.kafka.clients.producer.KafkaProducer; +import org.apache.kafka.clients.producer.ProducerConfig; +import org.apache.kafka.clients.producer.ProducerRecord; +import org.apache.kafka.clients.producer.RecordMetadata; +import org.apache.kafka.common.PartitionInfo; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.errors.TimeoutException; +import org.apache.kafka.common.serialization.ByteArraySerializer; +import org.apache.kafka.common.serialization.Deserializer; +import org.apache.kafka.common.serialization.StringDeserializer; +import org.apache.kafka.common.utils.Exit; +import org.apache.kafka.common.utils.Utils; + +import java.io.ByteArrayOutputStream; +import java.io.PrintStream; +import java.nio.charset.StandardCharsets; +import java.time.Duration; +import java.time.Instant; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.Properties; +import java.util.Random; +import java.util.Set; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Function; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import static java.util.Collections.emptyMap; +import static org.apache.kafka.common.utils.Utils.mkEntry; + +public class SmokeTestDriver extends SmokeTestUtil { + private static final String[] TOPICS = { + "data", + "echo", + "max", + "min", "min-suppressed", "min-raw", + "dif", + "sum", + "sws-raw", "sws-suppressed", + "cnt", + "avg", + "tagg" + }; + + private static final int MAX_RECORD_EMPTY_RETRIES = 30; + + private static class ValueList { + public final String key; + private final int[] values; + private int index; + + ValueList(final int min, final int max) { + key = min + "-" + max; + + values = new int[max - min + 1]; + for (int i = 0; i < values.length; i++) { + values[i] = min + i; + } + // We want to randomize the order of data to test not completely predictable processing order + // However, values are also use as a timestamp of the record. (TODO: separate data and timestamp) + // We keep some correlation of time and order. Thus, the shuffling is done with a sliding window + shuffle(values, 10); + + index = 0; + } + + int next() { + return (index < values.length) ? 
values[index++] : -1; + } + } + + public static String[] topics() { + return Arrays.copyOf(TOPICS, TOPICS.length); + } + + static void generatePerpetually(final String kafka, + final int numKeys, + final int maxRecordsPerKey) { + final Properties producerProps = generatorProperties(kafka); + + int numRecordsProduced = 0; + + final ValueList[] data = new ValueList[numKeys]; + for (int i = 0; i < numKeys; i++) { + data[i] = new ValueList(i, i + maxRecordsPerKey - 1); + } + + final Random rand = new Random(); + + try (final KafkaProducer producer = new KafkaProducer<>(producerProps)) { + while (true) { + final int index = rand.nextInt(numKeys); + final String key = data[index].key; + final int value = data[index].next(); + + final ProducerRecord record = + new ProducerRecord<>( + "data", + stringSerde.serializer().serialize("", key), + intSerde.serializer().serialize("", value) + ); + + producer.send(record); + + numRecordsProduced++; + if (numRecordsProduced % 100 == 0) { + System.out.println(Instant.now() + " " + numRecordsProduced + " records produced"); + } + Utils.sleep(2); + } + } + } + + public static Map> generate(final String kafka, + final int numKeys, + final int maxRecordsPerKey, + final Duration timeToSpend) { + final Properties producerProps = generatorProperties(kafka); + + + int numRecordsProduced = 0; + + final Map> allData = new HashMap<>(); + final ValueList[] data = new ValueList[numKeys]; + for (int i = 0; i < numKeys; i++) { + data[i] = new ValueList(i, i + maxRecordsPerKey - 1); + allData.put(data[i].key, new HashSet<>()); + } + final Random rand = new Random(); + + int remaining = data.length; + + final long recordPauseTime = timeToSpend.toMillis() / numKeys / maxRecordsPerKey; + + List> needRetry = new ArrayList<>(); + + try (final KafkaProducer producer = new KafkaProducer<>(producerProps)) { + while (remaining > 0) { + final int index = rand.nextInt(remaining); + final String key = data[index].key; + final int value = data[index].next(); + + if (value < 0) { + remaining--; + data[index] = data[remaining]; + } else { + + final ProducerRecord record = + new ProducerRecord<>( + "data", + stringSerde.serializer().serialize("", key), + intSerde.serializer().serialize("", value) + ); + + producer.send(record, new TestCallback(record, needRetry)); + + numRecordsProduced++; + allData.get(key).add(value); + if (numRecordsProduced % 100 == 0) { + System.out.println(Instant.now() + " " + numRecordsProduced + " records produced"); + } + Utils.sleep(Math.max(recordPauseTime, 2)); + } + } + producer.flush(); + + int remainingRetries = 5; + while (!needRetry.isEmpty()) { + final List> needRetry2 = new ArrayList<>(); + for (final ProducerRecord record : needRetry) { + System.out.println("retry producing " + stringSerde.deserializer().deserialize("", record.key())); + producer.send(record, new TestCallback(record, needRetry2)); + } + producer.flush(); + needRetry = needRetry2; + + if (--remainingRetries == 0 && !needRetry.isEmpty()) { + System.err.println("Failed to produce all records after multiple retries"); + Exit.exit(1); + } + } + + // now that we've sent everything, we'll send some final records with a timestamp high enough to flush out + // all suppressed records. 
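+ // One record per partition, with a timestamp ~2 days in the future: this advances stream time on
+ // every task well past the largest window (1 day) plus its grace period, so the downstream
+ // untilWindowCloses() suppression buffers emit their final results for verification.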
+ final List partitions = producer.partitionsFor("data"); + for (final PartitionInfo partition : partitions) { + producer.send(new ProducerRecord<>( + partition.topic(), + partition.partition(), + System.currentTimeMillis() + Duration.ofDays(2).toMillis(), + stringSerde.serializer().serialize("", "flush"), + intSerde.serializer().serialize("", 0) + )); + } + } + return Collections.unmodifiableMap(allData); + } + + private static Properties generatorProperties(final String kafka) { + final Properties producerProps = new Properties(); + producerProps.put(ProducerConfig.CLIENT_ID_CONFIG, "SmokeTest"); + producerProps.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, kafka); + producerProps.put(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG, ByteArraySerializer.class); + producerProps.put(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG, ByteArraySerializer.class); + producerProps.put(ProducerConfig.ACKS_CONFIG, "all"); + return producerProps; + } + + private static class TestCallback implements Callback { + private final ProducerRecord originalRecord; + private final List> needRetry; + + TestCallback(final ProducerRecord originalRecord, + final List> needRetry) { + this.originalRecord = originalRecord; + this.needRetry = needRetry; + } + + @Override + public void onCompletion(final RecordMetadata metadata, final Exception exception) { + if (exception != null) { + if (exception instanceof TimeoutException) { + needRetry.add(originalRecord); + } else { + exception.printStackTrace(); + Exit.exit(1); + } + } + } + } + + private static void shuffle(final int[] data, @SuppressWarnings("SameParameterValue") final int windowSize) { + final Random rand = new Random(); + for (int i = 0; i < data.length; i++) { + // we shuffle data within windowSize + final int j = rand.nextInt(Math.min(data.length - i, windowSize)) + i; + + // swap + final int tmp = data[i]; + data[i] = data[j]; + data[j] = tmp; + } + } + + public static class NumberDeserializer implements Deserializer { + @Override + public void configure(final Map configs, final boolean isKey) { + + } + + @Override + public Number deserialize(final String topic, final byte[] data) { + final Number value; + switch (topic) { + case "data": + case "echo": + case "min": + case "min-raw": + case "min-suppressed": + case "sws-raw": + case "sws-suppressed": + case "max": + case "dif": + value = intSerde.deserializer().deserialize(topic, data); + break; + case "sum": + case "cnt": + case "tagg": + value = longSerde.deserializer().deserialize(topic, data); + break; + case "avg": + value = doubleSerde.deserializer().deserialize(topic, data); + break; + default: + throw new RuntimeException("unknown topic: " + topic); + } + return value; + } + + @Override + public void close() { + + } + } + + public static VerificationResult verify(final String kafka, + final Map> inputs, + final int maxRecordsPerKey) { + final Properties props = new Properties(); + props.put(ConsumerConfig.CLIENT_ID_CONFIG, "verifier"); + props.put(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, kafka); + props.put(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, StringDeserializer.class); + props.put(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, NumberDeserializer.class); + props.put(ConsumerConfig.ISOLATION_LEVEL_CONFIG, "read_committed"); + + final KafkaConsumer consumer = new KafkaConsumer<>(props); + final List partitions = getAllPartitions(consumer, TOPICS); + consumer.assign(partitions); + consumer.seekToBeginning(partitions); + + final int recordsGenerated = inputs.size() * maxRecordsPerKey; + int 
recordsProcessed = 0; + final Map processed = + Stream.of(TOPICS) + .collect(Collectors.toMap(t -> t, t -> new AtomicInteger(0))); + + final Map>>> events = new HashMap<>(); + + VerificationResult verificationResult = new VerificationResult(false, "no results yet"); + int retry = 0; + final long start = System.currentTimeMillis(); + while (System.currentTimeMillis() - start < TimeUnit.MINUTES.toMillis(6)) { + final ConsumerRecords records = consumer.poll(Duration.ofSeconds(5)); + if (records.isEmpty() && recordsProcessed >= recordsGenerated) { + verificationResult = verifyAll(inputs, events, false); + if (verificationResult.passed()) { + break; + } else if (retry++ > MAX_RECORD_EMPTY_RETRIES) { + System.out.println(Instant.now() + " Didn't get any more results, verification hasn't passed, and out of retries."); + break; + } else { + System.out.println(Instant.now() + " Didn't get any more results, but verification hasn't passed (yet). Retrying..." + retry); + } + } else { + System.out.println(Instant.now() + " Get some more results from " + records.partitions() + ", resetting retry."); + + retry = 0; + for (final ConsumerRecord record : records) { + final String key = record.key(); + + final String topic = record.topic(); + processed.get(topic).incrementAndGet(); + + if (topic.equals("echo")) { + recordsProcessed++; + if (recordsProcessed % 100 == 0) { + System.out.println("Echo records processed = " + recordsProcessed); + } + } + + events.computeIfAbsent(topic, t -> new HashMap<>()) + .computeIfAbsent(key, k -> new LinkedList<>()) + .add(record); + } + + System.out.println(processed); + } + } + consumer.close(); + final long finished = System.currentTimeMillis() - start; + System.out.println("Verification time=" + finished); + System.out.println("-------------------"); + System.out.println("Result Verification"); + System.out.println("-------------------"); + System.out.println("recordGenerated=" + recordsGenerated); + System.out.println("recordProcessed=" + recordsProcessed); + + if (recordsProcessed > recordsGenerated) { + System.out.println("PROCESSED-MORE-THAN-GENERATED"); + } else if (recordsProcessed < recordsGenerated) { + System.out.println("PROCESSED-LESS-THAN-GENERATED"); + } + + boolean success; + + final Map> received = + events.get("echo") + .entrySet() + .stream() + .map(entry -> mkEntry( + entry.getKey(), + entry.getValue().stream().map(ConsumerRecord::value).collect(Collectors.toSet())) + ) + .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); + + success = inputs.equals(received); + + if (success) { + System.out.println("ALL-RECORDS-DELIVERED"); + } else { + int missedCount = 0; + for (final Map.Entry> entry : inputs.entrySet()) { + missedCount += received.get(entry.getKey()).size(); + } + System.out.println("missedRecords=" + missedCount); + } + + // give it one more try if it's not already passing. + if (!verificationResult.passed()) { + verificationResult = verifyAll(inputs, events, true); + } + success &= verificationResult.passed(); + + System.out.println(verificationResult.result()); + + System.out.println(success ? 
"SUCCESS" : "FAILURE"); + return verificationResult; + } + + public static class VerificationResult { + private final boolean passed; + private final String result; + + VerificationResult(final boolean passed, final String result) { + this.passed = passed; + this.result = result; + } + + public boolean passed() { + return passed; + } + + public String result() { + return result; + } + } + + private static VerificationResult verifyAll(final Map> inputs, + final Map>>> events, + final boolean printResults) { + final ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream(); + boolean pass; + try (final PrintStream resultStream = new PrintStream(byteArrayOutputStream)) { + pass = verifyTAgg(resultStream, inputs, events.get("tagg"), printResults); + pass &= verifySuppressed(resultStream, "min-suppressed", events, printResults); + pass &= verify(resultStream, "min-suppressed", inputs, events, windowedKey -> { + final String unwindowedKey = windowedKey.substring(1, windowedKey.length() - 1).replaceAll("@.*", ""); + return getMin(unwindowedKey); + }, printResults); + pass &= verifySuppressed(resultStream, "sws-suppressed", events, printResults); + pass &= verify(resultStream, "min", inputs, events, SmokeTestDriver::getMin, printResults); + pass &= verify(resultStream, "max", inputs, events, SmokeTestDriver::getMax, printResults); + pass &= verify(resultStream, "dif", inputs, events, key -> getMax(key).intValue() - getMin(key).intValue(), printResults); + pass &= verify(resultStream, "sum", inputs, events, SmokeTestDriver::getSum, printResults); + pass &= verify(resultStream, "cnt", inputs, events, key1 -> getMax(key1).intValue() - getMin(key1).intValue() + 1L, printResults); + pass &= verify(resultStream, "avg", inputs, events, SmokeTestDriver::getAvg, printResults); + } + return new VerificationResult(pass, new String(byteArrayOutputStream.toByteArray(), StandardCharsets.UTF_8)); + } + + private static boolean verify(final PrintStream resultStream, + final String topic, + final Map> inputData, + final Map>>> events, + final Function keyToExpectation, + final boolean printResults) { + final Map>> observedInputEvents = events.get("data"); + final Map>> outputEvents = events.getOrDefault(topic, emptyMap()); + if (outputEvents.isEmpty()) { + resultStream.println(topic + " is empty"); + return false; + } else { + resultStream.printf("verifying %s with %d keys%n", topic, outputEvents.size()); + + if (outputEvents.size() != inputData.size()) { + resultStream.printf("fail: resultCount=%d expectedCount=%s%n\tresult=%s%n\texpected=%s%n", + outputEvents.size(), inputData.size(), outputEvents.keySet(), inputData.keySet()); + return false; + } + for (final Map.Entry>> entry : outputEvents.entrySet()) { + final String key = entry.getKey(); + final Number expected = keyToExpectation.apply(key); + final Number actual = entry.getValue().getLast().value(); + if (!expected.equals(actual)) { + resultStream.printf("%s fail: key=%s actual=%s expected=%s%n", topic, key, actual, expected); + + if (printResults) { + resultStream.printf("\t inputEvents=%n%s%n\t" + + "echoEvents=%n%s%n\tmaxEvents=%n%s%n\tminEvents=%n%s%n\tdifEvents=%n%s%n\tcntEvents=%n%s%n\ttaggEvents=%n%s%n", + indent("\t\t", observedInputEvents.get(key)), + indent("\t\t", events.getOrDefault("echo", emptyMap()).getOrDefault(key, new LinkedList<>())), + indent("\t\t", events.getOrDefault("max", emptyMap()).getOrDefault(key, new LinkedList<>())), + indent("\t\t", events.getOrDefault("min", emptyMap()).getOrDefault(key, new 
LinkedList<>())), + indent("\t\t", events.getOrDefault("dif", emptyMap()).getOrDefault(key, new LinkedList<>())), + indent("\t\t", events.getOrDefault("cnt", emptyMap()).getOrDefault(key, new LinkedList<>())), + indent("\t\t", events.getOrDefault("tagg", emptyMap()).getOrDefault(key, new LinkedList<>()))); + + if (!Utils.mkSet("echo", "max", "min", "dif", "cnt", "tagg").contains(topic)) + resultStream.printf("%sEvents=%n%s%n", topic, indent("\t\t", entry.getValue())); + } + + return false; + } + } + return true; + } + } + + + private static boolean verifySuppressed(final PrintStream resultStream, + @SuppressWarnings("SameParameterValue") final String topic, + final Map>>> events, + final boolean printResults) { + resultStream.println("verifying suppressed " + topic); + final Map>> topicEvents = events.getOrDefault(topic, emptyMap()); + for (final Map.Entry>> entry : topicEvents.entrySet()) { + if (entry.getValue().size() != 1) { + final String unsuppressedTopic = topic.replace("-suppressed", "-raw"); + final String key = entry.getKey(); + final String unwindowedKey = key.substring(1, key.length() - 1).replaceAll("@.*", ""); + resultStream.printf("fail: key=%s%n\tnon-unique result:%n%s%n", + key, + indent("\t\t", entry.getValue())); + + if (printResults) + resultStream.printf("\tresultEvents:%n%s%n\tinputEvents:%n%s%n", + indent("\t\t", events.get(unsuppressedTopic).get(key)), + indent("\t\t", events.get("data").get(unwindowedKey))); + + return false; + } + } + return true; + } + + private static String indent(@SuppressWarnings("SameParameterValue") final String prefix, + final Iterable> list) { + final StringBuilder stringBuilder = new StringBuilder(); + for (final ConsumerRecord record : list) { + stringBuilder.append(prefix).append(record).append('\n'); + } + return stringBuilder.toString(); + } + + private static Long getSum(final String key) { + final int min = getMin(key).intValue(); + final int max = getMax(key).intValue(); + return ((long) min + max) * (max - min + 1L) / 2L; + } + + private static Double getAvg(final String key) { + final int min = getMin(key).intValue(); + final int max = getMax(key).intValue(); + return ((long) min + max) / 2.0; + } + + + private static boolean verifyTAgg(final PrintStream resultStream, + final Map> allData, + final Map>> taggEvents, + final boolean printResults) { + if (taggEvents == null) { + resultStream.println("tagg is missing"); + return false; + } else if (taggEvents.isEmpty()) { + resultStream.println("tagg is empty"); + return false; + } else { + resultStream.println("verifying tagg"); + + // generate expected answer + final Map expected = new HashMap<>(); + for (final String key : allData.keySet()) { + final int min = getMin(key).intValue(); + final int max = getMax(key).intValue(); + final String cnt = Long.toString(max - min + 1L); + + expected.put(cnt, expected.getOrDefault(cnt, 0L) + 1); + } + + // check the result + for (final Map.Entry>> entry : taggEvents.entrySet()) { + final String key = entry.getKey(); + Long expectedCount = expected.remove(key); + if (expectedCount == null) { + expectedCount = 0L; + } + + if (entry.getValue().getLast().value().longValue() != expectedCount) { + resultStream.println("fail: key=" + key + " tagg=" + entry.getValue() + " expected=" + expectedCount); + + if (printResults) + resultStream.println("\t taggEvents: " + entry.getValue()); + return false; + } + } + + } + return true; + } + + private static Number getMin(final String key) { + return Integer.parseInt(key.split("-")[0]); + } + + private 
static Number getMax(final String key) { + return Integer.parseInt(key.split("-")[1]); + } + + private static List getAllPartitions(final KafkaConsumer consumer, final String... topics) { + final List partitions = new ArrayList<>(); + + for (final String topic : topics) { + for (final PartitionInfo info : consumer.partitionsFor(topic)) { + partitions.add(new TopicPartition(info.topic(), info.partition())); + } + } + return partitions; + } + +} diff --git a/streams/upgrade-system-tests-22/src/test/java/org/apache/kafka/streams/tests/SmokeTestUtil.java b/streams/upgrade-system-tests-22/src/test/java/org/apache/kafka/streams/tests/SmokeTestUtil.java new file mode 100644 index 0000000..e8ec04c --- /dev/null +++ b/streams/upgrade-system-tests-22/src/test/java/org/apache/kafka/streams/tests/SmokeTestUtil.java @@ -0,0 +1,134 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.tests; + +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.kstream.Aggregator; +import org.apache.kafka.streams.kstream.Initializer; +import org.apache.kafka.streams.kstream.KeyValueMapper; +import org.apache.kafka.streams.kstream.Windowed; +import org.apache.kafka.streams.processor.AbstractProcessor; +import org.apache.kafka.streams.processor.Processor; +import org.apache.kafka.streams.processor.ProcessorContext; +import org.apache.kafka.streams.processor.ProcessorSupplier; + +import java.time.Instant; + +public class SmokeTestUtil { + + final static int END = Integer.MAX_VALUE; + + static ProcessorSupplier printProcessorSupplier(final String topic) { + return printProcessorSupplier(topic, ""); + } + + static ProcessorSupplier printProcessorSupplier(final String topic, final String name) { + return new ProcessorSupplier() { + @Override + public Processor get() { + return new AbstractProcessor() { + private int numRecordsProcessed = 0; + private long smallestOffset = Long.MAX_VALUE; + private long largestOffset = Long.MIN_VALUE; + + @Override + public void init(final ProcessorContext context) { + super.init(context); + System.out.println("[DEV] initializing processor: topic=" + topic + " taskId=" + context.taskId()); + System.out.flush(); + numRecordsProcessed = 0; + smallestOffset = Long.MAX_VALUE; + largestOffset = Long.MIN_VALUE; + } + + @Override + public void process(final Object key, final Object value) { + numRecordsProcessed++; + if (numRecordsProcessed % 100 == 0) { + System.out.printf("%s: %s%n", name, Instant.now()); + System.out.println("processed " + numRecordsProcessed + " records from topic=" + topic); + } + + if (smallestOffset > context().offset()) { + smallestOffset = context().offset(); + } + 
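+ // Track the smallest and largest offsets seen so close() can report how many offsets this task covered.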
if (largestOffset < context().offset()) { + largestOffset = context().offset(); + } + } + + @Override + public void close() { + System.out.printf("Close processor for task %s%n", context().taskId()); + System.out.println("processed " + numRecordsProcessed + " records"); + final long processed; + if (largestOffset >= smallestOffset) { + processed = 1L + largestOffset - smallestOffset; + } else { + processed = 0L; + } + System.out.println("offset " + smallestOffset + " to " + largestOffset + " -> processed " + processed); + System.out.flush(); + } + }; + } + }; + } + + public static final class Unwindow implements KeyValueMapper, V, K> { + @Override + public K apply(final Windowed winKey, final V value) { + return winKey.key(); + } + } + + public static class Agg { + + KeyValueMapper> selector() { + return (key, value) -> new KeyValue<>(value == null ? null : Long.toString(value), 1L); + } + + public Initializer init() { + return () -> 0L; + } + + Aggregator adder() { + return (aggKey, value, aggregate) -> aggregate + value; + } + + Aggregator remover() { + return (aggKey, value, aggregate) -> aggregate - value; + } + } + + public static Serde stringSerde = Serdes.String(); + + public static Serde intSerde = Serdes.Integer(); + + static Serde longSerde = Serdes.Long(); + + static Serde doubleSerde = Serdes.Double(); + + public static void sleep(final long duration) { + try { + Thread.sleep(duration); + } catch (final Exception ignore) { } + } + +} diff --git a/streams/upgrade-system-tests-22/src/test/java/org/apache/kafka/streams/tests/StreamsSmokeTest.java b/streams/upgrade-system-tests-22/src/test/java/org/apache/kafka/streams/tests/StreamsSmokeTest.java new file mode 100644 index 0000000..f280eb0 --- /dev/null +++ b/streams/upgrade-system-tests-22/src/test/java/org/apache/kafka/streams/tests/StreamsSmokeTest.java @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.tests; + +import org.apache.kafka.common.utils.Exit; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.streams.StreamsConfig; + +import java.io.IOException; +import java.time.Duration; +import java.util.Map; +import java.util.Properties; +import java.util.Set; +import java.util.UUID; + +import static org.apache.kafka.streams.tests.SmokeTestDriver.generate; +import static org.apache.kafka.streams.tests.SmokeTestDriver.generatePerpetually; + +public class StreamsSmokeTest { + + /** + * args ::= kafka propFileName command disableAutoTerminate + * command := "run" | "process" + * + * @param args + */ + public static void main(final String[] args) throws IOException { + if (args.length < 2) { + System.err.println("StreamsSmokeTest are expecting two parameters: propFile, command; but only see " + args.length + " parameter"); + Exit.exit(1); + } + + final String propFileName = args[0]; + final String command = args[1]; + final boolean disableAutoTerminate = args.length > 2; + + final Properties streamsProperties = Utils.loadProps(propFileName); + final String kafka = streamsProperties.getProperty(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG); + final String processingGuarantee = streamsProperties.getProperty(StreamsConfig.PROCESSING_GUARANTEE_CONFIG); + + if (kafka == null) { + System.err.println("No bootstrap kafka servers specified in " + StreamsConfig.BOOTSTRAP_SERVERS_CONFIG); + Exit.exit(1); + } + + if ("process".equals(command)) { + if (!StreamsConfig.AT_LEAST_ONCE.equals(processingGuarantee) && + !StreamsConfig.EXACTLY_ONCE.equals(processingGuarantee)) { + + System.err.println("processingGuarantee must be either " + StreamsConfig.AT_LEAST_ONCE + " or " + + StreamsConfig.EXACTLY_ONCE); + + Exit.exit(1); + } + } + + System.out.println("StreamsTest instance started (StreamsSmokeTest)"); + System.out.println("command=" + command); + System.out.println("props=" + streamsProperties); + System.out.println("disableAutoTerminate=" + disableAutoTerminate); + + switch (command) { + case "run": + // this starts the driver (data generation and result verification) + final int numKeys = 10; + final int maxRecordsPerKey = 500; + if (disableAutoTerminate) { + generatePerpetually(kafka, numKeys, maxRecordsPerKey); + } else { + // slow down data production to span 30 seconds so that system tests have time to + // do their bounces, etc. + final Map> allData = + generate(kafka, numKeys, maxRecordsPerKey, Duration.ofSeconds(30)); + SmokeTestDriver.verify(kafka, allData, maxRecordsPerKey); + } + break; + case "process": + // this starts the stream processing app + new SmokeTestClient(UUID.randomUUID().toString()).start(streamsProperties); + break; + default: + System.out.println("unknown command: " + command); + } + } + +} diff --git a/streams/upgrade-system-tests-22/src/test/java/org/apache/kafka/streams/tests/StreamsUpgradeTest.java b/streams/upgrade-system-tests-22/src/test/java/org/apache/kafka/streams/tests/StreamsUpgradeTest.java new file mode 100644 index 0000000..82d102d --- /dev/null +++ b/streams/upgrade-system-tests-22/src/test/java/org/apache/kafka/streams/tests/StreamsUpgradeTest.java @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.tests; + +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.streams.KafkaStreams; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.kstream.KStream; +import org.apache.kafka.streams.processor.AbstractProcessor; +import org.apache.kafka.streams.processor.ProcessorContext; +import org.apache.kafka.streams.processor.ProcessorSupplier; + +import java.util.Properties; + +public class StreamsUpgradeTest { + + @SuppressWarnings("unchecked") + public static void main(final String[] args) throws Exception { + if (args.length < 1) { + System.err.println("StreamsUpgradeTest requires one argument (properties-file) but provided none"); + } + final String propFileName = args[0]; + + final Properties streamsProperties = Utils.loadProps(propFileName); + + System.out.println("StreamsTest instance started (StreamsUpgradeTest v2.2)"); + System.out.println("props=" + streamsProperties); + + final StreamsBuilder builder = new StreamsBuilder(); + final KStream dataStream = builder.stream("data"); + dataStream.process(printProcessorSupplier()); + dataStream.to("echo"); + + final Properties config = new Properties(); + config.setProperty(StreamsConfig.APPLICATION_ID_CONFIG, "StreamsUpgradeTest"); + config.put(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG, 1000L); + config.putAll(streamsProperties); + + final KafkaStreams streams = new KafkaStreams(builder.build(), config); + streams.start(); + + Runtime.getRuntime().addShutdownHook(new Thread(() -> { + streams.close(); + System.out.println("UPGRADE-TEST-CLIENT-CLOSED"); + System.out.flush(); + })); + } + + private static ProcessorSupplier printProcessorSupplier() { + return () -> new AbstractProcessor() { + private int numRecordsProcessed = 0; + + @Override + public void init(final ProcessorContext context) { + System.out.println("[2.2] initializing processor: topic=data taskId=" + context.taskId()); + numRecordsProcessed = 0; + } + + @Override + public void process(final K key, final V value) { + numRecordsProcessed++; + if (numRecordsProcessed % 100 == 0) { + System.out.println("processed " + numRecordsProcessed + " records from topic=data"); + } + } + + @Override + public void close() {} + }; + } +} diff --git a/streams/upgrade-system-tests-22/src/test/java/org/apache/kafka/streams/tests/StreamsUpgradeToCooperativeRebalanceTest.java b/streams/upgrade-system-tests-22/src/test/java/org/apache/kafka/streams/tests/StreamsUpgradeToCooperativeRebalanceTest.java new file mode 100644 index 0000000..299fffa --- /dev/null +++ b/streams/upgrade-system-tests-22/src/test/java/org/apache/kafka/streams/tests/StreamsUpgradeToCooperativeRebalanceTest.java @@ -0,0 +1,135 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.tests; + +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.streams.KafkaStreams; +import org.apache.kafka.streams.KafkaStreams.State; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.kstream.ForeachAction; +import org.apache.kafka.streams.processor.TaskMetadata; +import org.apache.kafka.streams.processor.ThreadMetadata; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Properties; +import java.util.Set; + +public class StreamsUpgradeToCooperativeRebalanceTest { + + + @SuppressWarnings("unchecked") + public static void main(final String[] args) throws Exception { + if (args.length < 1) { + System.err.println("StreamsUpgradeToCooperativeRebalanceTest requires one argument (properties-file) but none provided"); + } + System.out.println("Args are " + Arrays.toString(args)); + final String propFileName = args[0]; + final Properties streamsProperties = Utils.loadProps(propFileName); + + final Properties config = new Properties(); + System.out.println("StreamsTest instance started (StreamsUpgradeToCooperativeRebalanceTest v2.2)"); + System.out.println("props=" + streamsProperties); + + config.put(StreamsConfig.APPLICATION_ID_CONFIG, "cooperative-rebalance-upgrade"); + config.put(StreamsConfig.DEFAULT_KEY_SERDE_CLASS_CONFIG, Serdes.String().getClass()); + config.put(StreamsConfig.DEFAULT_VALUE_SERDE_CLASS_CONFIG, Serdes.String().getClass()); + config.put(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG, 1000L); + config.putAll(streamsProperties); + + final String sourceTopic = streamsProperties.getProperty("source.topic", "source"); + final String sinkTopic = streamsProperties.getProperty("sink.topic", "sink"); + final String taskDelimiter = streamsProperties.getProperty("task.delimiter", "#"); + final int reportInterval = Integer.parseInt(streamsProperties.getProperty("report.interval", "100")); + final String upgradePhase = streamsProperties.getProperty("upgrade.phase", ""); + + final StreamsBuilder builder = new StreamsBuilder(); + + builder.stream(sourceTopic) + .peek(new ForeachAction() { + int recordCounter = 0; + + @Override + public void apply(final String key, final String value) { + if (recordCounter++ % reportInterval == 0) { + System.out.println(String.format("%sProcessed %d records so far", upgradePhase, recordCounter)); + System.out.flush(); + } + } + } + ).to(sinkTopic); + + final KafkaStreams streams = new KafkaStreams(builder.build(), config); + + streams.setStateListener((newState, oldState) -> { + if (newState == State.RUNNING && oldState == State.REBALANCING) { + System.out.println(String.format("%sSTREAMS in a RUNNING 
State", upgradePhase)); + final Set allThreadMetadata = streams.localThreadsMetadata(); + final StringBuilder taskReportBuilder = new StringBuilder(); + final List activeTasks = new ArrayList<>(); + final List standbyTasks = new ArrayList<>(); + for (final ThreadMetadata threadMetadata : allThreadMetadata) { + getTasks(threadMetadata.activeTasks(), activeTasks); + if (!threadMetadata.standbyTasks().isEmpty()) { + getTasks(threadMetadata.standbyTasks(), standbyTasks); + } + } + addTasksToBuilder(activeTasks, taskReportBuilder); + taskReportBuilder.append(taskDelimiter); + if (!standbyTasks.isEmpty()) { + addTasksToBuilder(standbyTasks, taskReportBuilder); + } + System.out.println("TASK-ASSIGNMENTS:" + taskReportBuilder); + } + + if (newState == State.REBALANCING) { + System.out.println(String.format("%sStarting a REBALANCE", upgradePhase)); + } + }); + + + streams.start(); + + Runtime.getRuntime().addShutdownHook(new Thread(() -> { + streams.close(); + System.out.println(String.format("%sCOOPERATIVE-REBALANCE-TEST-CLIENT-CLOSED", upgradePhase)); + System.out.flush(); + })); + } + + private static void addTasksToBuilder(final List tasks, final StringBuilder builder) { + if (!tasks.isEmpty()) { + for (final String task : tasks) { + builder.append(task).append(","); + } + builder.setLength(builder.length() - 1); + } + } + private static void getTasks(final Set taskMetadata, + final List taskList) { + for (final TaskMetadata task : taskMetadata) { + final Set topicPartitions = task.topicPartitions(); + for (final TopicPartition topicPartition : topicPartitions) { + taskList.add(topicPartition.toString()); + } + } + } +} diff --git a/streams/upgrade-system-tests-23/src/test/java/org/apache/kafka/streams/tests/SmokeTestClient.java b/streams/upgrade-system-tests-23/src/test/java/org/apache/kafka/streams/tests/SmokeTestClient.java new file mode 100644 index 0000000..ced1369 --- /dev/null +++ b/streams/upgrade-system-tests-23/src/test/java/org/apache/kafka/streams/tests/SmokeTestClient.java @@ -0,0 +1,298 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.tests; + +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.common.utils.KafkaThread; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.streams.KafkaStreams; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.Topology; +import org.apache.kafka.streams.kstream.Consumed; +import org.apache.kafka.streams.kstream.Grouped; +import org.apache.kafka.streams.kstream.KGroupedStream; +import org.apache.kafka.streams.kstream.KStream; +import org.apache.kafka.streams.kstream.KTable; +import org.apache.kafka.streams.kstream.Materialized; +import org.apache.kafka.streams.kstream.Produced; +import org.apache.kafka.streams.kstream.Suppressed.BufferConfig; +import org.apache.kafka.streams.kstream.TimeWindows; +import org.apache.kafka.streams.kstream.Windowed; +import org.apache.kafka.streams.state.Stores; +import org.apache.kafka.streams.state.WindowStore; + +import java.io.File; +import java.io.IOException; +import java.nio.file.Files; +import java.time.Duration; +import java.time.Instant; +import java.util.Properties; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; + +import static org.apache.kafka.streams.kstream.Suppressed.untilWindowCloses; + +public class SmokeTestClient extends SmokeTestUtil { + + private final String name; + + private KafkaStreams streams; + private boolean uncaughtException = false; + private boolean started; + private volatile boolean closed; + + private static void addShutdownHook(final String name, final Runnable runnable) { + if (name != null) { + Runtime.getRuntime().addShutdownHook(KafkaThread.nonDaemon(name, runnable)); + } else { + Runtime.getRuntime().addShutdownHook(new Thread(runnable)); + } + } + + private static File tempDirectory() { + final String prefix = "kafka-"; + final File file; + try { + file = Files.createTempDirectory(prefix).toFile(); + } catch (final IOException ex) { + throw new RuntimeException("Failed to create a temp dir", ex); + } + file.deleteOnExit(); + + addShutdownHook("delete-temp-file-shutdown-hook", () -> { + try { + Utils.delete(file); + } catch (final IOException e) { + System.out.println("Error deleting " + file.getAbsolutePath()); + e.printStackTrace(System.out); + } + }); + + return file; + } + + public SmokeTestClient(final String name) { + this.name = name; + } + + public boolean started() { + return started; + } + + public boolean closed() { + return closed; + } + + public void start(final Properties streamsProperties) { + final Topology build = getTopology(); + streams = new KafkaStreams(build, getStreamsConfig(streamsProperties)); + + final CountDownLatch countDownLatch = new CountDownLatch(1); + streams.setStateListener((newState, oldState) -> { + System.out.printf("%s %s: %s -> %s%n", name, Instant.now(), oldState, newState); + if (oldState == KafkaStreams.State.REBALANCING && newState == KafkaStreams.State.RUNNING) { + started = true; + countDownLatch.countDown(); + } + + if (newState == KafkaStreams.State.NOT_RUNNING) { + closed = true; + } + }); + + streams.setUncaughtExceptionHandler((t, e) -> { + System.out.println(name + ": SMOKE-TEST-CLIENT-EXCEPTION"); + System.out.println(name + ": FATAL: An unexpected exception is encountered on thread " + t + ": " + e); + e.printStackTrace(System.out); + uncaughtException = true; + 
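+ // close() checks this flag afterwards and reports SMOKE-TEST-CLIENT-EXCEPTION instead of
+ // SMOKE-TEST-CLIENT-CLOSED when an uncaught exception was recorded.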
streams.close(Duration.ofSeconds(30)); + }); + + addShutdownHook("streams-shutdown-hook", this::close); + + streams.start(); + try { + if (!countDownLatch.await(1, TimeUnit.MINUTES)) { + System.out.println(name + ": SMOKE-TEST-CLIENT-EXCEPTION: Didn't start in one minute"); + } + } catch (final InterruptedException e) { + System.out.println(name + ": SMOKE-TEST-CLIENT-EXCEPTION: " + e); + e.printStackTrace(System.out); + } + System.out.println(name + ": SMOKE-TEST-CLIENT-STARTED"); + System.out.println(name + " started at " + Instant.now()); + } + + public void closeAsync() { + streams.close(Duration.ZERO); + } + + public void close() { + final boolean closed = streams.close(Duration.ofMinutes(1)); + + if (closed && !uncaughtException) { + System.out.println(name + ": SMOKE-TEST-CLIENT-CLOSED"); + } else if (closed) { + System.out.println(name + ": SMOKE-TEST-CLIENT-EXCEPTION"); + } else { + System.out.println(name + ": SMOKE-TEST-CLIENT-EXCEPTION: Didn't close"); + } + } + + private Properties getStreamsConfig(final Properties props) { + final Properties fullProps = new Properties(props); + fullProps.put(StreamsConfig.APPLICATION_ID_CONFIG, "SmokeTest"); + fullProps.put(StreamsConfig.CLIENT_ID_CONFIG, "SmokeTest-" + name); + fullProps.put(StreamsConfig.STATE_DIR_CONFIG, tempDirectory().getAbsolutePath()); + fullProps.putAll(props); + return fullProps; + } + + public Topology getTopology() { + final StreamsBuilder builder = new StreamsBuilder(); + final Consumed stringIntConsumed = Consumed.with(stringSerde, intSerde); + final KStream source = builder.stream("data", stringIntConsumed); + source.filterNot((k, v) -> k.equals("flush")) + .to("echo", Produced.with(stringSerde, intSerde)); + final KStream data = source.filter((key, value) -> value == null || value != END); + data.process(SmokeTestUtil.printProcessorSupplier("data", name)); + + // min + final KGroupedStream groupedData = data.groupByKey(Grouped.with(stringSerde, intSerde)); + + final KTable, Integer> minAggregation = groupedData + .windowedBy(TimeWindows.of(Duration.ofDays(1)).grace(Duration.ofMinutes(1))) + .aggregate( + () -> Integer.MAX_VALUE, + (aggKey, value, aggregate) -> (value < aggregate) ? value : aggregate, + Materialized + .>as("uwin-min") + .withValueSerde(intSerde) + .withRetention(Duration.ofHours(25)) + ); + + streamify(minAggregation, "min-raw"); + + streamify(minAggregation.suppress(untilWindowCloses(BufferConfig.unbounded())), "min-suppressed"); + + minAggregation + .toStream(new Unwindow<>()) + .filterNot((k, v) -> k.equals("flush")) + .to("min", Produced.with(stringSerde, intSerde)); + + final KTable, Integer> smallWindowSum = groupedData + .windowedBy(TimeWindows.of(Duration.ofSeconds(2)).advanceBy(Duration.ofSeconds(1)).grace(Duration.ofSeconds(30))) + .reduce((l, r) -> l + r); + + streamify(smallWindowSum, "sws-raw"); + streamify(smallWindowSum.suppress(untilWindowCloses(BufferConfig.unbounded())), "sws-suppressed"); + + final KTable minTable = builder.table( + "min", + Consumed.with(stringSerde, intSerde), + Materialized.as("minStoreName")); + + minTable.toStream().process(SmokeTestUtil.printProcessorSupplier("min", name)); + + // max + groupedData + .windowedBy(TimeWindows.of(Duration.ofDays(2))) + .aggregate( + () -> Integer.MIN_VALUE, + (aggKey, value, aggregate) -> (value > aggregate) ? 
value : aggregate, + Materialized.>as("uwin-max").withValueSerde(intSerde)) + .toStream(new Unwindow<>()) + .filterNot((k, v) -> k.equals("flush")) + .to("max", Produced.with(stringSerde, intSerde)); + + final KTable maxTable = builder.table( + "max", + Consumed.with(stringSerde, intSerde), + Materialized.as("maxStoreName")); + maxTable.toStream().process(SmokeTestUtil.printProcessorSupplier("max", name)); + + // sum + groupedData + .windowedBy(TimeWindows.of(Duration.ofDays(2))) + .aggregate( + () -> 0L, + (aggKey, value, aggregate) -> (long) value + aggregate, + Materialized.>as("win-sum").withValueSerde(longSerde)) + .toStream(new Unwindow<>()) + .filterNot((k, v) -> k.equals("flush")) + .to("sum", Produced.with(stringSerde, longSerde)); + + final Consumed stringLongConsumed = Consumed.with(stringSerde, longSerde); + final KTable sumTable = builder.table("sum", stringLongConsumed); + sumTable.toStream().process(SmokeTestUtil.printProcessorSupplier("sum", name)); + + // cnt + groupedData + .windowedBy(TimeWindows.of(Duration.ofDays(2))) + .count(Materialized.as("uwin-cnt")) + .toStream(new Unwindow<>()) + .filterNot((k, v) -> k.equals("flush")) + .to("cnt", Produced.with(stringSerde, longSerde)); + + final KTable cntTable = builder.table( + "cnt", + Consumed.with(stringSerde, longSerde), + Materialized.as("cntStoreName")); + cntTable.toStream().process(SmokeTestUtil.printProcessorSupplier("cnt", name)); + + // dif + maxTable + .join( + minTable, + (value1, value2) -> value1 - value2) + .toStream() + .filterNot((k, v) -> k.equals("flush")) + .to("dif", Produced.with(stringSerde, intSerde)); + + // avg + sumTable + .join( + cntTable, + (value1, value2) -> (double) value1 / (double) value2) + .toStream() + .filterNot((k, v) -> k.equals("flush")) + .to("avg", Produced.with(stringSerde, doubleSerde)); + + // test repartition + final Agg agg = new Agg(); + cntTable.groupBy(agg.selector(), Grouped.with(stringSerde, longSerde)) + .aggregate(agg.init(), agg.adder(), agg.remover(), + Materialized.as(Stores.inMemoryKeyValueStore("cntByCnt")) + .withKeySerde(Serdes.String()) + .withValueSerde(Serdes.Long())) + .toStream() + .to("tagg", Produced.with(stringSerde, longSerde)); + + return builder.build(); + } + + private static void streamify(final KTable, Integer> windowedTable, final String topic) { + windowedTable + .toStream() + .filterNot((k, v) -> k.key().equals("flush")) + .map((key, value) -> new KeyValue<>(key.toString(), value)) + .to(topic, Produced.with(stringSerde, intSerde)); + } +} diff --git a/streams/upgrade-system-tests-23/src/test/java/org/apache/kafka/streams/tests/SmokeTestDriver.java b/streams/upgrade-system-tests-23/src/test/java/org/apache/kafka/streams/tests/SmokeTestDriver.java new file mode 100644 index 0000000..ac83cd9 --- /dev/null +++ b/streams/upgrade-system-tests-23/src/test/java/org/apache/kafka/streams/tests/SmokeTestDriver.java @@ -0,0 +1,622 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.tests; + +import org.apache.kafka.clients.consumer.ConsumerConfig; +import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.clients.consumer.ConsumerRecords; +import org.apache.kafka.clients.consumer.KafkaConsumer; +import org.apache.kafka.clients.producer.Callback; +import org.apache.kafka.clients.producer.KafkaProducer; +import org.apache.kafka.clients.producer.ProducerConfig; +import org.apache.kafka.clients.producer.ProducerRecord; +import org.apache.kafka.clients.producer.RecordMetadata; +import org.apache.kafka.common.PartitionInfo; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.errors.TimeoutException; +import org.apache.kafka.common.serialization.ByteArraySerializer; +import org.apache.kafka.common.serialization.Deserializer; +import org.apache.kafka.common.serialization.StringDeserializer; +import org.apache.kafka.common.utils.Exit; +import org.apache.kafka.common.utils.Utils; + +import java.io.ByteArrayOutputStream; +import java.io.PrintStream; +import java.nio.charset.StandardCharsets; +import java.time.Duration; +import java.time.Instant; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.Properties; +import java.util.Random; +import java.util.Set; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Function; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import static java.util.Collections.emptyMap; +import static org.apache.kafka.common.utils.Utils.mkEntry; + +public class SmokeTestDriver extends SmokeTestUtil { + private static final String[] TOPICS = { + "data", + "echo", + "max", + "min", "min-suppressed", "min-raw", + "dif", + "sum", + "sws-raw", "sws-suppressed", + "cnt", + "avg", + "tagg" + }; + + private static final int MAX_RECORD_EMPTY_RETRIES = 30; + + private static class ValueList { + public final String key; + private final int[] values; + private int index; + + ValueList(final int min, final int max) { + key = min + "-" + max; + + values = new int[max - min + 1]; + for (int i = 0; i < values.length; i++) { + values[i] = min + i; + } + // We want to randomize the order of data to test not completely predictable processing order + // However, values are also use as a timestamp of the record. (TODO: separate data and timestamp) + // We keep some correlation of time and order. Thus, the shuffling is done with a sliding window + shuffle(values, 10); + + index = 0; + } + + int next() { + return (index < values.length) ? 
values[index++] : -1; + } + } + + public static String[] topics() { + return Arrays.copyOf(TOPICS, TOPICS.length); + } + + static void generatePerpetually(final String kafka, + final int numKeys, + final int maxRecordsPerKey) { + final Properties producerProps = generatorProperties(kafka); + + int numRecordsProduced = 0; + + final ValueList[] data = new ValueList[numKeys]; + for (int i = 0; i < numKeys; i++) { + data[i] = new ValueList(i, i + maxRecordsPerKey - 1); + } + + final Random rand = new Random(); + + try (final KafkaProducer producer = new KafkaProducer<>(producerProps)) { + while (true) { + final int index = rand.nextInt(numKeys); + final String key = data[index].key; + final int value = data[index].next(); + + final ProducerRecord record = + new ProducerRecord<>( + "data", + stringSerde.serializer().serialize("", key), + intSerde.serializer().serialize("", value) + ); + + producer.send(record); + + numRecordsProduced++; + if (numRecordsProduced % 100 == 0) { + System.out.println(Instant.now() + " " + numRecordsProduced + " records produced"); + } + Utils.sleep(2); + } + } + } + + public static Map> generate(final String kafka, + final int numKeys, + final int maxRecordsPerKey, + final Duration timeToSpend) { + final Properties producerProps = generatorProperties(kafka); + + + int numRecordsProduced = 0; + + final Map> allData = new HashMap<>(); + final ValueList[] data = new ValueList[numKeys]; + for (int i = 0; i < numKeys; i++) { + data[i] = new ValueList(i, i + maxRecordsPerKey - 1); + allData.put(data[i].key, new HashSet<>()); + } + final Random rand = new Random(); + + int remaining = data.length; + + final long recordPauseTime = timeToSpend.toMillis() / numKeys / maxRecordsPerKey; + + List> needRetry = new ArrayList<>(); + + try (final KafkaProducer producer = new KafkaProducer<>(producerProps)) { + while (remaining > 0) { + final int index = rand.nextInt(remaining); + final String key = data[index].key; + final int value = data[index].next(); + + if (value < 0) { + remaining--; + data[index] = data[remaining]; + } else { + + final ProducerRecord record = + new ProducerRecord<>( + "data", + stringSerde.serializer().serialize("", key), + intSerde.serializer().serialize("", value) + ); + + producer.send(record, new TestCallback(record, needRetry)); + + numRecordsProduced++; + allData.get(key).add(value); + if (numRecordsProduced % 100 == 0) { + System.out.println(Instant.now() + " " + numRecordsProduced + " records produced"); + } + Utils.sleep(Math.max(recordPauseTime, 2)); + } + } + producer.flush(); + + int remainingRetries = 5; + while (!needRetry.isEmpty()) { + final List> needRetry2 = new ArrayList<>(); + for (final ProducerRecord record : needRetry) { + System.out.println("retry producing " + stringSerde.deserializer().deserialize("", record.key())); + producer.send(record, new TestCallback(record, needRetry2)); + } + producer.flush(); + needRetry = needRetry2; + + if (--remainingRetries == 0 && !needRetry.isEmpty()) { + System.err.println("Failed to produce all records after multiple retries"); + Exit.exit(1); + } + } + + // now that we've sent everything, we'll send some final records with a timestamp high enough to flush out + // all suppressed records. 
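+            // Editorial note: these "flush" records are stamped roughly two days in the future so that
+            // stream time advances past every window's grace period and the suppress(untilWindowCloses(...))
+            // buffers in SmokeTestClient emit their final results; the "flush" key is filtered out of every
+            // output topic downstream, so it never appears in the verified aggregates.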
+ final List partitions = producer.partitionsFor("data"); + for (final PartitionInfo partition : partitions) { + producer.send(new ProducerRecord<>( + partition.topic(), + partition.partition(), + System.currentTimeMillis() + Duration.ofDays(2).toMillis(), + stringSerde.serializer().serialize("", "flush"), + intSerde.serializer().serialize("", 0) + )); + } + } + return Collections.unmodifiableMap(allData); + } + + private static Properties generatorProperties(final String kafka) { + final Properties producerProps = new Properties(); + producerProps.put(ProducerConfig.CLIENT_ID_CONFIG, "SmokeTest"); + producerProps.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, kafka); + producerProps.put(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG, ByteArraySerializer.class); + producerProps.put(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG, ByteArraySerializer.class); + producerProps.put(ProducerConfig.ACKS_CONFIG, "all"); + return producerProps; + } + + private static class TestCallback implements Callback { + private final ProducerRecord originalRecord; + private final List> needRetry; + + TestCallback(final ProducerRecord originalRecord, + final List> needRetry) { + this.originalRecord = originalRecord; + this.needRetry = needRetry; + } + + @Override + public void onCompletion(final RecordMetadata metadata, final Exception exception) { + if (exception != null) { + if (exception instanceof TimeoutException) { + needRetry.add(originalRecord); + } else { + exception.printStackTrace(); + Exit.exit(1); + } + } + } + } + + private static void shuffle(final int[] data, @SuppressWarnings("SameParameterValue") final int windowSize) { + final Random rand = new Random(); + for (int i = 0; i < data.length; i++) { + // we shuffle data within windowSize + final int j = rand.nextInt(Math.min(data.length - i, windowSize)) + i; + + // swap + final int tmp = data[i]; + data[i] = data[j]; + data[j] = tmp; + } + } + + public static class NumberDeserializer implements Deserializer { + @Override + public Number deserialize(final String topic, final byte[] data) { + final Number value; + switch (topic) { + case "data": + case "echo": + case "min": + case "min-raw": + case "min-suppressed": + case "sws-raw": + case "sws-suppressed": + case "max": + case "dif": + value = intSerde.deserializer().deserialize(topic, data); + break; + case "sum": + case "cnt": + case "tagg": + value = longSerde.deserializer().deserialize(topic, data); + break; + case "avg": + value = doubleSerde.deserializer().deserialize(topic, data); + break; + default: + throw new RuntimeException("unknown topic: " + topic); + } + return value; + } + } + + public static VerificationResult verify(final String kafka, + final Map> inputs, + final int maxRecordsPerKey) { + final Properties props = new Properties(); + props.put(ConsumerConfig.CLIENT_ID_CONFIG, "verifier"); + props.put(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, kafka); + props.put(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, StringDeserializer.class); + props.put(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, NumberDeserializer.class); + props.put(ConsumerConfig.ISOLATION_LEVEL_CONFIG, "read_committed"); + + final KafkaConsumer consumer = new KafkaConsumer<>(props); + final List partitions = getAllPartitions(consumer, TOPICS); + consumer.assign(partitions); + consumer.seekToBeginning(partitions); + + final int recordsGenerated = inputs.size() * maxRecordsPerKey; + int recordsProcessed = 0; + final Map processed = + Stream.of(TOPICS) + .collect(Collectors.toMap(t -> t, t -> new AtomicInteger(0))); + + 
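+        // The polling loop below runs for at most six minutes: once a poll returns nothing and the "echo"
+        // count has caught up with the generated record count, verifyAll() is attempted, tolerating up to
+        // MAX_RECORD_EMPTY_RETRIES further empty polls before giving up. The per-topic counters above are
+        // used only for progress logging.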
final Map>>> events = new HashMap<>(); + + VerificationResult verificationResult = new VerificationResult(false, "no results yet"); + int retry = 0; + final long start = System.currentTimeMillis(); + while (System.currentTimeMillis() - start < TimeUnit.MINUTES.toMillis(6)) { + final ConsumerRecords records = consumer.poll(Duration.ofSeconds(5)); + if (records.isEmpty() && recordsProcessed >= recordsGenerated) { + verificationResult = verifyAll(inputs, events, false); + if (verificationResult.passed()) { + break; + } else if (retry++ > MAX_RECORD_EMPTY_RETRIES) { + System.out.println(Instant.now() + " Didn't get any more results, verification hasn't passed, and out of retries."); + break; + } else { + System.out.println(Instant.now() + " Didn't get any more results, but verification hasn't passed (yet). Retrying..." + retry); + } + } else { + System.out.println(Instant.now() + " Get some more results from " + records.partitions() + ", resetting retry."); + + retry = 0; + for (final ConsumerRecord record : records) { + final String key = record.key(); + + final String topic = record.topic(); + processed.get(topic).incrementAndGet(); + + if (topic.equals("echo")) { + recordsProcessed++; + if (recordsProcessed % 100 == 0) { + System.out.println("Echo records processed = " + recordsProcessed); + } + } + + events.computeIfAbsent(topic, t -> new HashMap<>()) + .computeIfAbsent(key, k -> new LinkedList<>()) + .add(record); + } + + System.out.println(processed); + } + } + consumer.close(); + final long finished = System.currentTimeMillis() - start; + System.out.println("Verification time=" + finished); + System.out.println("-------------------"); + System.out.println("Result Verification"); + System.out.println("-------------------"); + System.out.println("recordGenerated=" + recordsGenerated); + System.out.println("recordProcessed=" + recordsProcessed); + + if (recordsProcessed > recordsGenerated) { + System.out.println("PROCESSED-MORE-THAN-GENERATED"); + } else if (recordsProcessed < recordsGenerated) { + System.out.println("PROCESSED-LESS-THAN-GENERATED"); + } + + boolean success; + + final Map> received = + events.get("echo") + .entrySet() + .stream() + .map(entry -> mkEntry( + entry.getKey(), + entry.getValue().stream().map(ConsumerRecord::value).collect(Collectors.toSet())) + ) + .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); + + success = inputs.equals(received); + + if (success) { + System.out.println("ALL-RECORDS-DELIVERED"); + } else { + int missedCount = 0; + for (final Map.Entry> entry : inputs.entrySet()) { + missedCount += received.get(entry.getKey()).size(); + } + System.out.println("missedRecords=" + missedCount); + } + + // give it one more try if it's not already passing. + if (!verificationResult.passed()) { + verificationResult = verifyAll(inputs, events, true); + } + success &= verificationResult.passed(); + + System.out.println(verificationResult.result()); + + System.out.println(success ? 
"SUCCESS" : "FAILURE"); + return verificationResult; + } + + public static class VerificationResult { + private final boolean passed; + private final String result; + + VerificationResult(final boolean passed, final String result) { + this.passed = passed; + this.result = result; + } + + public boolean passed() { + return passed; + } + + public String result() { + return result; + } + } + + private static VerificationResult verifyAll(final Map> inputs, + final Map>>> events, + final boolean printResults) { + final ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream(); + boolean pass; + try (final PrintStream resultStream = new PrintStream(byteArrayOutputStream)) { + pass = verifyTAgg(resultStream, inputs, events.get("tagg"), printResults); + pass &= verifySuppressed(resultStream, "min-suppressed", events, printResults); + pass &= verify(resultStream, "min-suppressed", inputs, events, windowedKey -> { + final String unwindowedKey = windowedKey.substring(1, windowedKey.length() - 1).replaceAll("@.*", ""); + return getMin(unwindowedKey); + }, printResults); + pass &= verifySuppressed(resultStream, "sws-suppressed", events, printResults); + pass &= verify(resultStream, "min", inputs, events, SmokeTestDriver::getMin, printResults); + pass &= verify(resultStream, "max", inputs, events, SmokeTestDriver::getMax, printResults); + pass &= verify(resultStream, "dif", inputs, events, key -> getMax(key).intValue() - getMin(key).intValue(), printResults); + pass &= verify(resultStream, "sum", inputs, events, SmokeTestDriver::getSum, printResults); + pass &= verify(resultStream, "cnt", inputs, events, key1 -> getMax(key1).intValue() - getMin(key1).intValue() + 1L, printResults); + pass &= verify(resultStream, "avg", inputs, events, SmokeTestDriver::getAvg, printResults); + } + return new VerificationResult(pass, new String(byteArrayOutputStream.toByteArray(), StandardCharsets.UTF_8)); + } + + private static boolean verify(final PrintStream resultStream, + final String topic, + final Map> inputData, + final Map>>> events, + final Function keyToExpectation, + final boolean printResults) { + final Map>> observedInputEvents = events.get("data"); + final Map>> outputEvents = events.getOrDefault(topic, emptyMap()); + if (outputEvents.isEmpty()) { + resultStream.println(topic + " is empty"); + return false; + } else { + resultStream.printf("verifying %s with %d keys%n", topic, outputEvents.size()); + + if (outputEvents.size() != inputData.size()) { + resultStream.printf("fail: resultCount=%d expectedCount=%s%n\tresult=%s%n\texpected=%s%n", + outputEvents.size(), inputData.size(), outputEvents.keySet(), inputData.keySet()); + return false; + } + for (final Map.Entry>> entry : outputEvents.entrySet()) { + final String key = entry.getKey(); + final Number expected = keyToExpectation.apply(key); + final Number actual = entry.getValue().getLast().value(); + if (!expected.equals(actual)) { + resultStream.printf("%s fail: key=%s actual=%s expected=%s%n", topic, key, actual, expected); + + if (printResults) { + resultStream.printf("\t inputEvents=%n%s%n\t" + + "echoEvents=%n%s%n\tmaxEvents=%n%s%n\tminEvents=%n%s%n\tdifEvents=%n%s%n\tcntEvents=%n%s%n\ttaggEvents=%n%s%n", + indent("\t\t", observedInputEvents.get(key)), + indent("\t\t", events.getOrDefault("echo", emptyMap()).getOrDefault(key, new LinkedList<>())), + indent("\t\t", events.getOrDefault("max", emptyMap()).getOrDefault(key, new LinkedList<>())), + indent("\t\t", events.getOrDefault("min", emptyMap()).getOrDefault(key, new 
LinkedList<>())), + indent("\t\t", events.getOrDefault("dif", emptyMap()).getOrDefault(key, new LinkedList<>())), + indent("\t\t", events.getOrDefault("cnt", emptyMap()).getOrDefault(key, new LinkedList<>())), + indent("\t\t", events.getOrDefault("tagg", emptyMap()).getOrDefault(key, new LinkedList<>()))); + + if (!Utils.mkSet("echo", "max", "min", "dif", "cnt", "tagg").contains(topic)) + resultStream.printf("%sEvents=%n%s%n", topic, indent("\t\t", entry.getValue())); + } + + return false; + } + } + return true; + } + } + + + private static boolean verifySuppressed(final PrintStream resultStream, + @SuppressWarnings("SameParameterValue") final String topic, + final Map>>> events, + final boolean printResults) { + resultStream.println("verifying suppressed " + topic); + final Map>> topicEvents = events.getOrDefault(topic, emptyMap()); + for (final Map.Entry>> entry : topicEvents.entrySet()) { + if (entry.getValue().size() != 1) { + final String unsuppressedTopic = topic.replace("-suppressed", "-raw"); + final String key = entry.getKey(); + final String unwindowedKey = key.substring(1, key.length() - 1).replaceAll("@.*", ""); + resultStream.printf("fail: key=%s%n\tnon-unique result:%n%s%n", + key, + indent("\t\t", entry.getValue())); + + if (printResults) + resultStream.printf("\tresultEvents:%n%s%n\tinputEvents:%n%s%n", + indent("\t\t", events.get(unsuppressedTopic).get(key)), + indent("\t\t", events.get("data").get(unwindowedKey))); + + return false; + } + } + return true; + } + + private static String indent(@SuppressWarnings("SameParameterValue") final String prefix, + final Iterable> list) { + final StringBuilder stringBuilder = new StringBuilder(); + for (final ConsumerRecord record : list) { + stringBuilder.append(prefix).append(record).append('\n'); + } + return stringBuilder.toString(); + } + + private static Long getSum(final String key) { + final int min = getMin(key).intValue(); + final int max = getMax(key).intValue(); + return ((long) min + max) * (max - min + 1L) / 2L; + } + + private static Double getAvg(final String key) { + final int min = getMin(key).intValue(); + final int max = getMax(key).intValue(); + return ((long) min + max) / 2.0; + } + + + private static boolean verifyTAgg(final PrintStream resultStream, + final Map> allData, + final Map>> taggEvents, + final boolean printResults) { + if (taggEvents == null) { + resultStream.println("tagg is missing"); + return false; + } else if (taggEvents.isEmpty()) { + resultStream.println("tagg is empty"); + return false; + } else { + resultStream.println("verifying tagg"); + + // generate expected answer + final Map expected = new HashMap<>(); + for (final String key : allData.keySet()) { + final int min = getMin(key).intValue(); + final int max = getMax(key).intValue(); + final String cnt = Long.toString(max - min + 1L); + + expected.put(cnt, expected.getOrDefault(cnt, 0L) + 1); + } + + // check the result + for (final Map.Entry>> entry : taggEvents.entrySet()) { + final String key = entry.getKey(); + Long expectedCount = expected.remove(key); + if (expectedCount == null) { + expectedCount = 0L; + } + + if (entry.getValue().getLast().value().longValue() != expectedCount) { + resultStream.println("fail: key=" + key + " tagg=" + entry.getValue() + " expected=" + expectedCount); + + if (printResults) + resultStream.println("\t taggEvents: " + entry.getValue()); + return false; + } + } + + } + return true; + } + + private static Number getMin(final String key) { + return Integer.parseInt(key.split("-")[0]); + } + + private 
static Number getMax(final String key) { + return Integer.parseInt(key.split("-")[1]); + } + + private static List getAllPartitions(final KafkaConsumer consumer, final String... topics) { + final List partitions = new ArrayList<>(); + + for (final String topic : topics) { + for (final PartitionInfo info : consumer.partitionsFor(topic)) { + partitions.add(new TopicPartition(info.topic(), info.partition())); + } + } + return partitions; + } + +} diff --git a/streams/upgrade-system-tests-23/src/test/java/org/apache/kafka/streams/tests/SmokeTestUtil.java b/streams/upgrade-system-tests-23/src/test/java/org/apache/kafka/streams/tests/SmokeTestUtil.java new file mode 100644 index 0000000..e8ec04c --- /dev/null +++ b/streams/upgrade-system-tests-23/src/test/java/org/apache/kafka/streams/tests/SmokeTestUtil.java @@ -0,0 +1,134 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.tests; + +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.kstream.Aggregator; +import org.apache.kafka.streams.kstream.Initializer; +import org.apache.kafka.streams.kstream.KeyValueMapper; +import org.apache.kafka.streams.kstream.Windowed; +import org.apache.kafka.streams.processor.AbstractProcessor; +import org.apache.kafka.streams.processor.Processor; +import org.apache.kafka.streams.processor.ProcessorContext; +import org.apache.kafka.streams.processor.ProcessorSupplier; + +import java.time.Instant; + +public class SmokeTestUtil { + + final static int END = Integer.MAX_VALUE; + + static ProcessorSupplier printProcessorSupplier(final String topic) { + return printProcessorSupplier(topic, ""); + } + + static ProcessorSupplier printProcessorSupplier(final String topic, final String name) { + return new ProcessorSupplier() { + @Override + public Processor get() { + return new AbstractProcessor() { + private int numRecordsProcessed = 0; + private long smallestOffset = Long.MAX_VALUE; + private long largestOffset = Long.MIN_VALUE; + + @Override + public void init(final ProcessorContext context) { + super.init(context); + System.out.println("[DEV] initializing processor: topic=" + topic + " taskId=" + context.taskId()); + System.out.flush(); + numRecordsProcessed = 0; + smallestOffset = Long.MAX_VALUE; + largestOffset = Long.MIN_VALUE; + } + + @Override + public void process(final Object key, final Object value) { + numRecordsProcessed++; + if (numRecordsProcessed % 100 == 0) { + System.out.printf("%s: %s%n", name, Instant.now()); + System.out.println("processed " + numRecordsProcessed + " records from topic=" + topic); + } + + if (smallestOffset > context().offset()) { + smallestOffset = context().offset(); + } + 
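+                        // Track the largest offset seen as well; close() reports the
+                        // [smallestOffset, largestOffset] range and derives the processed-record count from it.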
if (largestOffset < context().offset()) { + largestOffset = context().offset(); + } + } + + @Override + public void close() { + System.out.printf("Close processor for task %s%n", context().taskId()); + System.out.println("processed " + numRecordsProcessed + " records"); + final long processed; + if (largestOffset >= smallestOffset) { + processed = 1L + largestOffset - smallestOffset; + } else { + processed = 0L; + } + System.out.println("offset " + smallestOffset + " to " + largestOffset + " -> processed " + processed); + System.out.flush(); + } + }; + } + }; + } + + public static final class Unwindow implements KeyValueMapper, V, K> { + @Override + public K apply(final Windowed winKey, final V value) { + return winKey.key(); + } + } + + public static class Agg { + + KeyValueMapper> selector() { + return (key, value) -> new KeyValue<>(value == null ? null : Long.toString(value), 1L); + } + + public Initializer init() { + return () -> 0L; + } + + Aggregator adder() { + return (aggKey, value, aggregate) -> aggregate + value; + } + + Aggregator remover() { + return (aggKey, value, aggregate) -> aggregate - value; + } + } + + public static Serde stringSerde = Serdes.String(); + + public static Serde intSerde = Serdes.Integer(); + + static Serde longSerde = Serdes.Long(); + + static Serde doubleSerde = Serdes.Double(); + + public static void sleep(final long duration) { + try { + Thread.sleep(duration); + } catch (final Exception ignore) { } + } + +} diff --git a/streams/upgrade-system-tests-23/src/test/java/org/apache/kafka/streams/tests/StreamsSmokeTest.java b/streams/upgrade-system-tests-23/src/test/java/org/apache/kafka/streams/tests/StreamsSmokeTest.java new file mode 100644 index 0000000..f280eb0 --- /dev/null +++ b/streams/upgrade-system-tests-23/src/test/java/org/apache/kafka/streams/tests/StreamsSmokeTest.java @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.tests; + +import org.apache.kafka.common.utils.Exit; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.streams.StreamsConfig; + +import java.io.IOException; +import java.time.Duration; +import java.util.Map; +import java.util.Properties; +import java.util.Set; +import java.util.UUID; + +import static org.apache.kafka.streams.tests.SmokeTestDriver.generate; +import static org.apache.kafka.streams.tests.SmokeTestDriver.generatePerpetually; + +public class StreamsSmokeTest { + + /** + * args ::= kafka propFileName command disableAutoTerminate + * command := "run" | "process" + * + * @param args + */ + public static void main(final String[] args) throws IOException { + if (args.length < 2) { + System.err.println("StreamsSmokeTest are expecting two parameters: propFile, command; but only see " + args.length + " parameter"); + Exit.exit(1); + } + + final String propFileName = args[0]; + final String command = args[1]; + final boolean disableAutoTerminate = args.length > 2; + + final Properties streamsProperties = Utils.loadProps(propFileName); + final String kafka = streamsProperties.getProperty(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG); + final String processingGuarantee = streamsProperties.getProperty(StreamsConfig.PROCESSING_GUARANTEE_CONFIG); + + if (kafka == null) { + System.err.println("No bootstrap kafka servers specified in " + StreamsConfig.BOOTSTRAP_SERVERS_CONFIG); + Exit.exit(1); + } + + if ("process".equals(command)) { + if (!StreamsConfig.AT_LEAST_ONCE.equals(processingGuarantee) && + !StreamsConfig.EXACTLY_ONCE.equals(processingGuarantee)) { + + System.err.println("processingGuarantee must be either " + StreamsConfig.AT_LEAST_ONCE + " or " + + StreamsConfig.EXACTLY_ONCE); + + Exit.exit(1); + } + } + + System.out.println("StreamsTest instance started (StreamsSmokeTest)"); + System.out.println("command=" + command); + System.out.println("props=" + streamsProperties); + System.out.println("disableAutoTerminate=" + disableAutoTerminate); + + switch (command) { + case "run": + // this starts the driver (data generation and result verification) + final int numKeys = 10; + final int maxRecordsPerKey = 500; + if (disableAutoTerminate) { + generatePerpetually(kafka, numKeys, maxRecordsPerKey); + } else { + // slow down data production to span 30 seconds so that system tests have time to + // do their bounces, etc. + final Map> allData = + generate(kafka, numKeys, maxRecordsPerKey, Duration.ofSeconds(30)); + SmokeTestDriver.verify(kafka, allData, maxRecordsPerKey); + } + break; + case "process": + // this starts the stream processing app + new SmokeTestClient(UUID.randomUUID().toString()).start(streamsProperties); + break; + default: + System.out.println("unknown command: " + command); + } + } + +} diff --git a/streams/upgrade-system-tests-23/src/test/java/org/apache/kafka/streams/tests/StreamsUpgradeTest.java b/streams/upgrade-system-tests-23/src/test/java/org/apache/kafka/streams/tests/StreamsUpgradeTest.java new file mode 100644 index 0000000..45e0628 --- /dev/null +++ b/streams/upgrade-system-tests-23/src/test/java/org/apache/kafka/streams/tests/StreamsUpgradeTest.java @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.tests; + +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.streams.KafkaStreams; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.kstream.KStream; +import org.apache.kafka.streams.processor.AbstractProcessor; +import org.apache.kafka.streams.processor.ProcessorContext; +import org.apache.kafka.streams.processor.ProcessorSupplier; + +import java.util.Properties; + +public class StreamsUpgradeTest { + + @SuppressWarnings("unchecked") + public static void main(final String[] args) throws Exception { + if (args.length < 1) { + System.err.println("StreamsUpgradeTest requires one argument (properties-file) but provided none"); + } + final String propFileName = args[0]; + + final Properties streamsProperties = Utils.loadProps(propFileName); + + System.out.println("StreamsTest instance started (StreamsUpgradeTest v2.3)"); + System.out.println("props=" + streamsProperties); + + final StreamsBuilder builder = new StreamsBuilder(); + final KStream dataStream = builder.stream("data"); + dataStream.process(printProcessorSupplier()); + dataStream.to("echo"); + + final Properties config = new Properties(); + config.setProperty(StreamsConfig.APPLICATION_ID_CONFIG, "StreamsUpgradeTest"); + config.put(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG, 1000L); + config.putAll(streamsProperties); + + final KafkaStreams streams = new KafkaStreams(builder.build(), config); + streams.start(); + + Runtime.getRuntime().addShutdownHook(new Thread(() -> { + streams.close(); + System.out.println("UPGRADE-TEST-CLIENT-CLOSED"); + System.out.flush(); + })); + } + + private static ProcessorSupplier printProcessorSupplier() { + return () -> new AbstractProcessor() { + private int numRecordsProcessed = 0; + + @Override + public void init(final ProcessorContext context) { + System.out.println("[2.3] initializing processor: topic=data taskId=" + context.taskId()); + numRecordsProcessed = 0; + } + + @Override + public void process(final K key, final V value) { + numRecordsProcessed++; + if (numRecordsProcessed % 100 == 0) { + System.out.println("processed " + numRecordsProcessed + " records from topic=data"); + } + } + + @Override + public void close() {} + }; + } +} diff --git a/streams/upgrade-system-tests-23/src/test/java/org/apache/kafka/streams/tests/StreamsUpgradeToCooperativeRebalanceTest.java b/streams/upgrade-system-tests-23/src/test/java/org/apache/kafka/streams/tests/StreamsUpgradeToCooperativeRebalanceTest.java new file mode 100644 index 0000000..0a7a48f --- /dev/null +++ b/streams/upgrade-system-tests-23/src/test/java/org/apache/kafka/streams/tests/StreamsUpgradeToCooperativeRebalanceTest.java @@ -0,0 +1,124 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.tests; + +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.streams.KafkaStreams; +import org.apache.kafka.streams.KafkaStreams.State; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.kstream.ForeachAction; +import org.apache.kafka.streams.processor.TaskMetadata; +import org.apache.kafka.streams.processor.ThreadMetadata; + +import java.util.Arrays; +import java.util.Properties; +import java.util.Set; + +public class StreamsUpgradeToCooperativeRebalanceTest { + + + @SuppressWarnings("unchecked") + public static void main(final String[] args) throws Exception { + if (args.length < 1) { + System.err.println("StreamsUpgradeToCooperativeRebalanceTest requires one argument (kafka-url, properties-file) but none provided"); + } + System.out.println("Args are " + Arrays.toString(args)); + final String propFileName = args[0]; + final Properties streamsProperties = Utils.loadProps(propFileName); + + final Properties config = new Properties(); + System.out.println("StreamsTest instance started (StreamsUpgradeToCooperativeRebalanceTest v2.3)"); + System.out.println("props=" + streamsProperties); + + config.put(StreamsConfig.APPLICATION_ID_CONFIG, "cooperative-rebalance-upgrade"); + config.put(StreamsConfig.DEFAULT_KEY_SERDE_CLASS_CONFIG, Serdes.String().getClass()); + config.put(StreamsConfig.DEFAULT_VALUE_SERDE_CLASS_CONFIG, Serdes.String().getClass()); + config.put(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG, 1000L); + config.putAll(streamsProperties); + + final String sourceTopic = streamsProperties.getProperty("source.topic", "source"); + final String sinkTopic = streamsProperties.getProperty("sink.topic", "sink"); + final String threadDelimiter = streamsProperties.getProperty("thread.delimiter", "&"); + final String taskDelimiter = streamsProperties.getProperty("task.delimiter", "#"); + final int reportInterval = Integer.parseInt(streamsProperties.getProperty("report.interval", "100")); + + final StreamsBuilder builder = new StreamsBuilder(); + + builder.stream(sourceTopic) + .peek(new ForeachAction() { + int recordCounter = 0; + + @Override + public void apply(final String key, final String value) { + if (recordCounter++ % reportInterval == 0) { + System.out.println(String.format("Processed %d records so far", recordCounter)); + System.out.flush(); + } + } + } + ).to(sinkTopic); + + final KafkaStreams streams = new KafkaStreams(builder.build(), config); + + streams.setStateListener((newState, oldState) -> { + if (newState == State.RUNNING && oldState == State.REBALANCING) { + System.out.println("STREAMS in a RUNNING State"); + final Set allThreadMetadata = 
streams.localThreadsMetadata(); + final StringBuilder taskReportBuilder = new StringBuilder(); + for (final ThreadMetadata threadMetadata : allThreadMetadata) { + buildTaskAssignmentReport(taskReportBuilder, threadMetadata.activeTasks(), "ACTIVE-TASKS:"); + if (!threadMetadata.standbyTasks().isEmpty()) { + taskReportBuilder.append(taskDelimiter); + buildTaskAssignmentReport(taskReportBuilder, threadMetadata.standbyTasks(), "STANDBY-TASKS:"); + } + taskReportBuilder.append(threadDelimiter); + } + taskReportBuilder.setLength(taskReportBuilder.length() - 1); + System.out.println("TASK-ASSIGNMENTS:" + taskReportBuilder); + } + + if (newState == State.REBALANCING) { + System.out.println("Starting a REBALANCE"); + } + }); + + + streams.start(); + + Runtime.getRuntime().addShutdownHook(new Thread(() -> { + streams.close(); + System.out.println("COOPERATIVE-REBALANCE-TEST-CLIENT-CLOSED"); + System.out.flush(); + })); + } + + private static void buildTaskAssignmentReport(final StringBuilder taskReportBuilder, + final Set taskMetadata, + final String taskType) { + taskReportBuilder.append(taskType); + for (final TaskMetadata task : taskMetadata) { + final Set topicPartitions = task.topicPartitions(); + for (final TopicPartition topicPartition : topicPartitions) { + taskReportBuilder.append(topicPartition.toString()).append(","); + } + } + taskReportBuilder.setLength(taskReportBuilder.length() - 1); + } +} diff --git a/streams/upgrade-system-tests-24/src/test/java/org/apache/kafka/streams/tests/SmokeTestClient.java b/streams/upgrade-system-tests-24/src/test/java/org/apache/kafka/streams/tests/SmokeTestClient.java new file mode 100644 index 0000000..ced1369 --- /dev/null +++ b/streams/upgrade-system-tests-24/src/test/java/org/apache/kafka/streams/tests/SmokeTestClient.java @@ -0,0 +1,298 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.tests; + +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.common.utils.KafkaThread; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.streams.KafkaStreams; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.Topology; +import org.apache.kafka.streams.kstream.Consumed; +import org.apache.kafka.streams.kstream.Grouped; +import org.apache.kafka.streams.kstream.KGroupedStream; +import org.apache.kafka.streams.kstream.KStream; +import org.apache.kafka.streams.kstream.KTable; +import org.apache.kafka.streams.kstream.Materialized; +import org.apache.kafka.streams.kstream.Produced; +import org.apache.kafka.streams.kstream.Suppressed.BufferConfig; +import org.apache.kafka.streams.kstream.TimeWindows; +import org.apache.kafka.streams.kstream.Windowed; +import org.apache.kafka.streams.state.Stores; +import org.apache.kafka.streams.state.WindowStore; + +import java.io.File; +import java.io.IOException; +import java.nio.file.Files; +import java.time.Duration; +import java.time.Instant; +import java.util.Properties; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; + +import static org.apache.kafka.streams.kstream.Suppressed.untilWindowCloses; + +public class SmokeTestClient extends SmokeTestUtil { + + private final String name; + + private KafkaStreams streams; + private boolean uncaughtException = false; + private boolean started; + private volatile boolean closed; + + private static void addShutdownHook(final String name, final Runnable runnable) { + if (name != null) { + Runtime.getRuntime().addShutdownHook(KafkaThread.nonDaemon(name, runnable)); + } else { + Runtime.getRuntime().addShutdownHook(new Thread(runnable)); + } + } + + private static File tempDirectory() { + final String prefix = "kafka-"; + final File file; + try { + file = Files.createTempDirectory(prefix).toFile(); + } catch (final IOException ex) { + throw new RuntimeException("Failed to create a temp dir", ex); + } + file.deleteOnExit(); + + addShutdownHook("delete-temp-file-shutdown-hook", () -> { + try { + Utils.delete(file); + } catch (final IOException e) { + System.out.println("Error deleting " + file.getAbsolutePath()); + e.printStackTrace(System.out); + } + }); + + return file; + } + + public SmokeTestClient(final String name) { + this.name = name; + } + + public boolean started() { + return started; + } + + public boolean closed() { + return closed; + } + + public void start(final Properties streamsProperties) { + final Topology build = getTopology(); + streams = new KafkaStreams(build, getStreamsConfig(streamsProperties)); + + final CountDownLatch countDownLatch = new CountDownLatch(1); + streams.setStateListener((newState, oldState) -> { + System.out.printf("%s %s: %s -> %s%n", name, Instant.now(), oldState, newState); + if (oldState == KafkaStreams.State.REBALANCING && newState == KafkaStreams.State.RUNNING) { + started = true; + countDownLatch.countDown(); + } + + if (newState == KafkaStreams.State.NOT_RUNNING) { + closed = true; + } + }); + + streams.setUncaughtExceptionHandler((t, e) -> { + System.out.println(name + ": SMOKE-TEST-CLIENT-EXCEPTION"); + System.out.println(name + ": FATAL: An unexpected exception is encountered on thread " + t + ": " + e); + e.printStackTrace(System.out); + uncaughtException = true; + 
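+            // Presumably the bounded close() below keeps a fatal stream-thread error from hanging the system
+            // test; with uncaughtException set, the regular close() later reports SMOKE-TEST-CLIENT-EXCEPTION
+            // rather than SMOKE-TEST-CLIENT-CLOSED.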
streams.close(Duration.ofSeconds(30)); + }); + + addShutdownHook("streams-shutdown-hook", this::close); + + streams.start(); + try { + if (!countDownLatch.await(1, TimeUnit.MINUTES)) { + System.out.println(name + ": SMOKE-TEST-CLIENT-EXCEPTION: Didn't start in one minute"); + } + } catch (final InterruptedException e) { + System.out.println(name + ": SMOKE-TEST-CLIENT-EXCEPTION: " + e); + e.printStackTrace(System.out); + } + System.out.println(name + ": SMOKE-TEST-CLIENT-STARTED"); + System.out.println(name + " started at " + Instant.now()); + } + + public void closeAsync() { + streams.close(Duration.ZERO); + } + + public void close() { + final boolean closed = streams.close(Duration.ofMinutes(1)); + + if (closed && !uncaughtException) { + System.out.println(name + ": SMOKE-TEST-CLIENT-CLOSED"); + } else if (closed) { + System.out.println(name + ": SMOKE-TEST-CLIENT-EXCEPTION"); + } else { + System.out.println(name + ": SMOKE-TEST-CLIENT-EXCEPTION: Didn't close"); + } + } + + private Properties getStreamsConfig(final Properties props) { + final Properties fullProps = new Properties(props); + fullProps.put(StreamsConfig.APPLICATION_ID_CONFIG, "SmokeTest"); + fullProps.put(StreamsConfig.CLIENT_ID_CONFIG, "SmokeTest-" + name); + fullProps.put(StreamsConfig.STATE_DIR_CONFIG, tempDirectory().getAbsolutePath()); + fullProps.putAll(props); + return fullProps; + } + + public Topology getTopology() { + final StreamsBuilder builder = new StreamsBuilder(); + final Consumed stringIntConsumed = Consumed.with(stringSerde, intSerde); + final KStream source = builder.stream("data", stringIntConsumed); + source.filterNot((k, v) -> k.equals("flush")) + .to("echo", Produced.with(stringSerde, intSerde)); + final KStream data = source.filter((key, value) -> value == null || value != END); + data.process(SmokeTestUtil.printProcessorSupplier("data", name)); + + // min + final KGroupedStream groupedData = data.groupByKey(Grouped.with(stringSerde, intSerde)); + + final KTable, Integer> minAggregation = groupedData + .windowedBy(TimeWindows.of(Duration.ofDays(1)).grace(Duration.ofMinutes(1))) + .aggregate( + () -> Integer.MAX_VALUE, + (aggKey, value, aggregate) -> (value < aggregate) ? value : aggregate, + Materialized + .>as("uwin-min") + .withValueSerde(intSerde) + .withRetention(Duration.ofHours(25)) + ); + + streamify(minAggregation, "min-raw"); + + streamify(minAggregation.suppress(untilWindowCloses(BufferConfig.unbounded())), "min-suppressed"); + + minAggregation + .toStream(new Unwindow<>()) + .filterNot((k, v) -> k.equals("flush")) + .to("min", Produced.with(stringSerde, intSerde)); + + final KTable, Integer> smallWindowSum = groupedData + .windowedBy(TimeWindows.of(Duration.ofSeconds(2)).advanceBy(Duration.ofSeconds(1)).grace(Duration.ofSeconds(30))) + .reduce((l, r) -> l + r); + + streamify(smallWindowSum, "sws-raw"); + streamify(smallWindowSum.suppress(untilWindowCloses(BufferConfig.unbounded())), "sws-suppressed"); + + final KTable minTable = builder.table( + "min", + Consumed.with(stringSerde, intSerde), + Materialized.as("minStoreName")); + + minTable.toStream().process(SmokeTestUtil.printProcessorSupplier("min", name)); + + // max + groupedData + .windowedBy(TimeWindows.of(Duration.ofDays(2))) + .aggregate( + () -> Integer.MIN_VALUE, + (aggKey, value, aggregate) -> (value > aggregate) ? 
value : aggregate, + Materialized.>as("uwin-max").withValueSerde(intSerde)) + .toStream(new Unwindow<>()) + .filterNot((k, v) -> k.equals("flush")) + .to("max", Produced.with(stringSerde, intSerde)); + + final KTable maxTable = builder.table( + "max", + Consumed.with(stringSerde, intSerde), + Materialized.as("maxStoreName")); + maxTable.toStream().process(SmokeTestUtil.printProcessorSupplier("max", name)); + + // sum + groupedData + .windowedBy(TimeWindows.of(Duration.ofDays(2))) + .aggregate( + () -> 0L, + (aggKey, value, aggregate) -> (long) value + aggregate, + Materialized.>as("win-sum").withValueSerde(longSerde)) + .toStream(new Unwindow<>()) + .filterNot((k, v) -> k.equals("flush")) + .to("sum", Produced.with(stringSerde, longSerde)); + + final Consumed stringLongConsumed = Consumed.with(stringSerde, longSerde); + final KTable sumTable = builder.table("sum", stringLongConsumed); + sumTable.toStream().process(SmokeTestUtil.printProcessorSupplier("sum", name)); + + // cnt + groupedData + .windowedBy(TimeWindows.of(Duration.ofDays(2))) + .count(Materialized.as("uwin-cnt")) + .toStream(new Unwindow<>()) + .filterNot((k, v) -> k.equals("flush")) + .to("cnt", Produced.with(stringSerde, longSerde)); + + final KTable cntTable = builder.table( + "cnt", + Consumed.with(stringSerde, longSerde), + Materialized.as("cntStoreName")); + cntTable.toStream().process(SmokeTestUtil.printProcessorSupplier("cnt", name)); + + // dif + maxTable + .join( + minTable, + (value1, value2) -> value1 - value2) + .toStream() + .filterNot((k, v) -> k.equals("flush")) + .to("dif", Produced.with(stringSerde, intSerde)); + + // avg + sumTable + .join( + cntTable, + (value1, value2) -> (double) value1 / (double) value2) + .toStream() + .filterNot((k, v) -> k.equals("flush")) + .to("avg", Produced.with(stringSerde, doubleSerde)); + + // test repartition + final Agg agg = new Agg(); + cntTable.groupBy(agg.selector(), Grouped.with(stringSerde, longSerde)) + .aggregate(agg.init(), agg.adder(), agg.remover(), + Materialized.as(Stores.inMemoryKeyValueStore("cntByCnt")) + .withKeySerde(Serdes.String()) + .withValueSerde(Serdes.Long())) + .toStream() + .to("tagg", Produced.with(stringSerde, longSerde)); + + return builder.build(); + } + + private static void streamify(final KTable, Integer> windowedTable, final String topic) { + windowedTable + .toStream() + .filterNot((k, v) -> k.key().equals("flush")) + .map((key, value) -> new KeyValue<>(key.toString(), value)) + .to(topic, Produced.with(stringSerde, intSerde)); + } +} diff --git a/streams/upgrade-system-tests-24/src/test/java/org/apache/kafka/streams/tests/SmokeTestDriver.java b/streams/upgrade-system-tests-24/src/test/java/org/apache/kafka/streams/tests/SmokeTestDriver.java new file mode 100644 index 0000000..ac83cd9 --- /dev/null +++ b/streams/upgrade-system-tests-24/src/test/java/org/apache/kafka/streams/tests/SmokeTestDriver.java @@ -0,0 +1,622 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.tests; + +import org.apache.kafka.clients.consumer.ConsumerConfig; +import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.clients.consumer.ConsumerRecords; +import org.apache.kafka.clients.consumer.KafkaConsumer; +import org.apache.kafka.clients.producer.Callback; +import org.apache.kafka.clients.producer.KafkaProducer; +import org.apache.kafka.clients.producer.ProducerConfig; +import org.apache.kafka.clients.producer.ProducerRecord; +import org.apache.kafka.clients.producer.RecordMetadata; +import org.apache.kafka.common.PartitionInfo; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.errors.TimeoutException; +import org.apache.kafka.common.serialization.ByteArraySerializer; +import org.apache.kafka.common.serialization.Deserializer; +import org.apache.kafka.common.serialization.StringDeserializer; +import org.apache.kafka.common.utils.Exit; +import org.apache.kafka.common.utils.Utils; + +import java.io.ByteArrayOutputStream; +import java.io.PrintStream; +import java.nio.charset.StandardCharsets; +import java.time.Duration; +import java.time.Instant; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.Properties; +import java.util.Random; +import java.util.Set; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Function; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import static java.util.Collections.emptyMap; +import static org.apache.kafka.common.utils.Utils.mkEntry; + +public class SmokeTestDriver extends SmokeTestUtil { + private static final String[] TOPICS = { + "data", + "echo", + "max", + "min", "min-suppressed", "min-raw", + "dif", + "sum", + "sws-raw", "sws-suppressed", + "cnt", + "avg", + "tagg" + }; + + private static final int MAX_RECORD_EMPTY_RETRIES = 30; + + private static class ValueList { + public final String key; + private final int[] values; + private int index; + + ValueList(final int min, final int max) { + key = min + "-" + max; + + values = new int[max - min + 1]; + for (int i = 0; i < values.length; i++) { + values[i] = min + i; + } + // We want to randomize the order of data to test not completely predictable processing order + // However, values are also use as a timestamp of the record. (TODO: separate data and timestamp) + // We keep some correlation of time and order. Thus, the shuffling is done with a sliding window + shuffle(values, 10); + + index = 0; + } + + int next() { + return (index < values.length) ? 
values[index++] : -1; + } + } + + public static String[] topics() { + return Arrays.copyOf(TOPICS, TOPICS.length); + } + + static void generatePerpetually(final String kafka, + final int numKeys, + final int maxRecordsPerKey) { + final Properties producerProps = generatorProperties(kafka); + + int numRecordsProduced = 0; + + final ValueList[] data = new ValueList[numKeys]; + for (int i = 0; i < numKeys; i++) { + data[i] = new ValueList(i, i + maxRecordsPerKey - 1); + } + + final Random rand = new Random(); + + try (final KafkaProducer producer = new KafkaProducer<>(producerProps)) { + while (true) { + final int index = rand.nextInt(numKeys); + final String key = data[index].key; + final int value = data[index].next(); + + final ProducerRecord record = + new ProducerRecord<>( + "data", + stringSerde.serializer().serialize("", key), + intSerde.serializer().serialize("", value) + ); + + producer.send(record); + + numRecordsProduced++; + if (numRecordsProduced % 100 == 0) { + System.out.println(Instant.now() + " " + numRecordsProduced + " records produced"); + } + Utils.sleep(2); + } + } + } + + public static Map> generate(final String kafka, + final int numKeys, + final int maxRecordsPerKey, + final Duration timeToSpend) { + final Properties producerProps = generatorProperties(kafka); + + + int numRecordsProduced = 0; + + final Map> allData = new HashMap<>(); + final ValueList[] data = new ValueList[numKeys]; + for (int i = 0; i < numKeys; i++) { + data[i] = new ValueList(i, i + maxRecordsPerKey - 1); + allData.put(data[i].key, new HashSet<>()); + } + final Random rand = new Random(); + + int remaining = data.length; + + final long recordPauseTime = timeToSpend.toMillis() / numKeys / maxRecordsPerKey; + + List> needRetry = new ArrayList<>(); + + try (final KafkaProducer producer = new KafkaProducer<>(producerProps)) { + while (remaining > 0) { + final int index = rand.nextInt(remaining); + final String key = data[index].key; + final int value = data[index].next(); + + if (value < 0) { + remaining--; + data[index] = data[remaining]; + } else { + + final ProducerRecord record = + new ProducerRecord<>( + "data", + stringSerde.serializer().serialize("", key), + intSerde.serializer().serialize("", value) + ); + + producer.send(record, new TestCallback(record, needRetry)); + + numRecordsProduced++; + allData.get(key).add(value); + if (numRecordsProduced % 100 == 0) { + System.out.println(Instant.now() + " " + numRecordsProduced + " records produced"); + } + Utils.sleep(Math.max(recordPauseTime, 2)); + } + } + producer.flush(); + + int remainingRetries = 5; + while (!needRetry.isEmpty()) { + final List> needRetry2 = new ArrayList<>(); + for (final ProducerRecord record : needRetry) { + System.out.println("retry producing " + stringSerde.deserializer().deserialize("", record.key())); + producer.send(record, new TestCallback(record, needRetry2)); + } + producer.flush(); + needRetry = needRetry2; + + if (--remainingRetries == 0 && !needRetry.isEmpty()) { + System.err.println("Failed to produce all records after multiple retries"); + Exit.exit(1); + } + } + + // now that we've sent everything, we'll send some final records with a timestamp high enough to flush out + // all suppressed records. 
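+            // Editorial note: these "flush" records are stamped roughly two days in the future so that
+            // stream time advances past every window's grace period and the suppress(untilWindowCloses(...))
+            // buffers in SmokeTestClient emit their final results; the "flush" key is filtered out of every
+            // output topic downstream, so it never appears in the verified aggregates.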
+ final List partitions = producer.partitionsFor("data"); + for (final PartitionInfo partition : partitions) { + producer.send(new ProducerRecord<>( + partition.topic(), + partition.partition(), + System.currentTimeMillis() + Duration.ofDays(2).toMillis(), + stringSerde.serializer().serialize("", "flush"), + intSerde.serializer().serialize("", 0) + )); + } + } + return Collections.unmodifiableMap(allData); + } + + private static Properties generatorProperties(final String kafka) { + final Properties producerProps = new Properties(); + producerProps.put(ProducerConfig.CLIENT_ID_CONFIG, "SmokeTest"); + producerProps.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, kafka); + producerProps.put(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG, ByteArraySerializer.class); + producerProps.put(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG, ByteArraySerializer.class); + producerProps.put(ProducerConfig.ACKS_CONFIG, "all"); + return producerProps; + } + + private static class TestCallback implements Callback { + private final ProducerRecord originalRecord; + private final List> needRetry; + + TestCallback(final ProducerRecord originalRecord, + final List> needRetry) { + this.originalRecord = originalRecord; + this.needRetry = needRetry; + } + + @Override + public void onCompletion(final RecordMetadata metadata, final Exception exception) { + if (exception != null) { + if (exception instanceof TimeoutException) { + needRetry.add(originalRecord); + } else { + exception.printStackTrace(); + Exit.exit(1); + } + } + } + } + + private static void shuffle(final int[] data, @SuppressWarnings("SameParameterValue") final int windowSize) { + final Random rand = new Random(); + for (int i = 0; i < data.length; i++) { + // we shuffle data within windowSize + final int j = rand.nextInt(Math.min(data.length - i, windowSize)) + i; + + // swap + final int tmp = data[i]; + data[i] = data[j]; + data[j] = tmp; + } + } + + public static class NumberDeserializer implements Deserializer { + @Override + public Number deserialize(final String topic, final byte[] data) { + final Number value; + switch (topic) { + case "data": + case "echo": + case "min": + case "min-raw": + case "min-suppressed": + case "sws-raw": + case "sws-suppressed": + case "max": + case "dif": + value = intSerde.deserializer().deserialize(topic, data); + break; + case "sum": + case "cnt": + case "tagg": + value = longSerde.deserializer().deserialize(topic, data); + break; + case "avg": + value = doubleSerde.deserializer().deserialize(topic, data); + break; + default: + throw new RuntimeException("unknown topic: " + topic); + } + return value; + } + } + + public static VerificationResult verify(final String kafka, + final Map> inputs, + final int maxRecordsPerKey) { + final Properties props = new Properties(); + props.put(ConsumerConfig.CLIENT_ID_CONFIG, "verifier"); + props.put(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, kafka); + props.put(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, StringDeserializer.class); + props.put(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, NumberDeserializer.class); + props.put(ConsumerConfig.ISOLATION_LEVEL_CONFIG, "read_committed"); + + final KafkaConsumer consumer = new KafkaConsumer<>(props); + final List partitions = getAllPartitions(consumer, TOPICS); + consumer.assign(partitions); + consumer.seekToBeginning(partitions); + + final int recordsGenerated = inputs.size() * maxRecordsPerKey; + int recordsProcessed = 0; + final Map processed = + Stream.of(TOPICS) + .collect(Collectors.toMap(t -> t, t -> new AtomicInteger(0))); + + 
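+        // The polling loop below runs for at most six minutes: once a poll returns nothing and the "echo"
+        // count has caught up with the generated record count, verifyAll() is attempted, tolerating up to
+        // MAX_RECORD_EMPTY_RETRIES further empty polls before giving up. The per-topic counters above are
+        // used only for progress logging.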
final Map>>> events = new HashMap<>(); + + VerificationResult verificationResult = new VerificationResult(false, "no results yet"); + int retry = 0; + final long start = System.currentTimeMillis(); + while (System.currentTimeMillis() - start < TimeUnit.MINUTES.toMillis(6)) { + final ConsumerRecords records = consumer.poll(Duration.ofSeconds(5)); + if (records.isEmpty() && recordsProcessed >= recordsGenerated) { + verificationResult = verifyAll(inputs, events, false); + if (verificationResult.passed()) { + break; + } else if (retry++ > MAX_RECORD_EMPTY_RETRIES) { + System.out.println(Instant.now() + " Didn't get any more results, verification hasn't passed, and out of retries."); + break; + } else { + System.out.println(Instant.now() + " Didn't get any more results, but verification hasn't passed (yet). Retrying..." + retry); + } + } else { + System.out.println(Instant.now() + " Get some more results from " + records.partitions() + ", resetting retry."); + + retry = 0; + for (final ConsumerRecord record : records) { + final String key = record.key(); + + final String topic = record.topic(); + processed.get(topic).incrementAndGet(); + + if (topic.equals("echo")) { + recordsProcessed++; + if (recordsProcessed % 100 == 0) { + System.out.println("Echo records processed = " + recordsProcessed); + } + } + + events.computeIfAbsent(topic, t -> new HashMap<>()) + .computeIfAbsent(key, k -> new LinkedList<>()) + .add(record); + } + + System.out.println(processed); + } + } + consumer.close(); + final long finished = System.currentTimeMillis() - start; + System.out.println("Verification time=" + finished); + System.out.println("-------------------"); + System.out.println("Result Verification"); + System.out.println("-------------------"); + System.out.println("recordGenerated=" + recordsGenerated); + System.out.println("recordProcessed=" + recordsProcessed); + + if (recordsProcessed > recordsGenerated) { + System.out.println("PROCESSED-MORE-THAN-GENERATED"); + } else if (recordsProcessed < recordsGenerated) { + System.out.println("PROCESSED-LESS-THAN-GENERATED"); + } + + boolean success; + + final Map> received = + events.get("echo") + .entrySet() + .stream() + .map(entry -> mkEntry( + entry.getKey(), + entry.getValue().stream().map(ConsumerRecord::value).collect(Collectors.toSet())) + ) + .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); + + success = inputs.equals(received); + + if (success) { + System.out.println("ALL-RECORDS-DELIVERED"); + } else { + int missedCount = 0; + for (final Map.Entry> entry : inputs.entrySet()) { + missedCount += received.get(entry.getKey()).size(); + } + System.out.println("missedRecords=" + missedCount); + } + + // give it one more try if it's not already passing. + if (!verificationResult.passed()) { + verificationResult = verifyAll(inputs, events, true); + } + success &= verificationResult.passed(); + + System.out.println(verificationResult.result()); + + System.out.println(success ? 
"SUCCESS" : "FAILURE"); + return verificationResult; + } + + public static class VerificationResult { + private final boolean passed; + private final String result; + + VerificationResult(final boolean passed, final String result) { + this.passed = passed; + this.result = result; + } + + public boolean passed() { + return passed; + } + + public String result() { + return result; + } + } + + private static VerificationResult verifyAll(final Map> inputs, + final Map>>> events, + final boolean printResults) { + final ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream(); + boolean pass; + try (final PrintStream resultStream = new PrintStream(byteArrayOutputStream)) { + pass = verifyTAgg(resultStream, inputs, events.get("tagg"), printResults); + pass &= verifySuppressed(resultStream, "min-suppressed", events, printResults); + pass &= verify(resultStream, "min-suppressed", inputs, events, windowedKey -> { + final String unwindowedKey = windowedKey.substring(1, windowedKey.length() - 1).replaceAll("@.*", ""); + return getMin(unwindowedKey); + }, printResults); + pass &= verifySuppressed(resultStream, "sws-suppressed", events, printResults); + pass &= verify(resultStream, "min", inputs, events, SmokeTestDriver::getMin, printResults); + pass &= verify(resultStream, "max", inputs, events, SmokeTestDriver::getMax, printResults); + pass &= verify(resultStream, "dif", inputs, events, key -> getMax(key).intValue() - getMin(key).intValue(), printResults); + pass &= verify(resultStream, "sum", inputs, events, SmokeTestDriver::getSum, printResults); + pass &= verify(resultStream, "cnt", inputs, events, key1 -> getMax(key1).intValue() - getMin(key1).intValue() + 1L, printResults); + pass &= verify(resultStream, "avg", inputs, events, SmokeTestDriver::getAvg, printResults); + } + return new VerificationResult(pass, new String(byteArrayOutputStream.toByteArray(), StandardCharsets.UTF_8)); + } + + private static boolean verify(final PrintStream resultStream, + final String topic, + final Map> inputData, + final Map>>> events, + final Function keyToExpectation, + final boolean printResults) { + final Map>> observedInputEvents = events.get("data"); + final Map>> outputEvents = events.getOrDefault(topic, emptyMap()); + if (outputEvents.isEmpty()) { + resultStream.println(topic + " is empty"); + return false; + } else { + resultStream.printf("verifying %s with %d keys%n", topic, outputEvents.size()); + + if (outputEvents.size() != inputData.size()) { + resultStream.printf("fail: resultCount=%d expectedCount=%s%n\tresult=%s%n\texpected=%s%n", + outputEvents.size(), inputData.size(), outputEvents.keySet(), inputData.keySet()); + return false; + } + for (final Map.Entry>> entry : outputEvents.entrySet()) { + final String key = entry.getKey(); + final Number expected = keyToExpectation.apply(key); + final Number actual = entry.getValue().getLast().value(); + if (!expected.equals(actual)) { + resultStream.printf("%s fail: key=%s actual=%s expected=%s%n", topic, key, actual, expected); + + if (printResults) { + resultStream.printf("\t inputEvents=%n%s%n\t" + + "echoEvents=%n%s%n\tmaxEvents=%n%s%n\tminEvents=%n%s%n\tdifEvents=%n%s%n\tcntEvents=%n%s%n\ttaggEvents=%n%s%n", + indent("\t\t", observedInputEvents.get(key)), + indent("\t\t", events.getOrDefault("echo", emptyMap()).getOrDefault(key, new LinkedList<>())), + indent("\t\t", events.getOrDefault("max", emptyMap()).getOrDefault(key, new LinkedList<>())), + indent("\t\t", events.getOrDefault("min", emptyMap()).getOrDefault(key, new 
LinkedList<>())), + indent("\t\t", events.getOrDefault("dif", emptyMap()).getOrDefault(key, new LinkedList<>())), + indent("\t\t", events.getOrDefault("cnt", emptyMap()).getOrDefault(key, new LinkedList<>())), + indent("\t\t", events.getOrDefault("tagg", emptyMap()).getOrDefault(key, new LinkedList<>()))); + + if (!Utils.mkSet("echo", "max", "min", "dif", "cnt", "tagg").contains(topic)) + resultStream.printf("%sEvents=%n%s%n", topic, indent("\t\t", entry.getValue())); + } + + return false; + } + } + return true; + } + } + + + private static boolean verifySuppressed(final PrintStream resultStream, + @SuppressWarnings("SameParameterValue") final String topic, + final Map>>> events, + final boolean printResults) { + resultStream.println("verifying suppressed " + topic); + final Map>> topicEvents = events.getOrDefault(topic, emptyMap()); + for (final Map.Entry>> entry : topicEvents.entrySet()) { + if (entry.getValue().size() != 1) { + final String unsuppressedTopic = topic.replace("-suppressed", "-raw"); + final String key = entry.getKey(); + final String unwindowedKey = key.substring(1, key.length() - 1).replaceAll("@.*", ""); + resultStream.printf("fail: key=%s%n\tnon-unique result:%n%s%n", + key, + indent("\t\t", entry.getValue())); + + if (printResults) + resultStream.printf("\tresultEvents:%n%s%n\tinputEvents:%n%s%n", + indent("\t\t", events.get(unsuppressedTopic).get(key)), + indent("\t\t", events.get("data").get(unwindowedKey))); + + return false; + } + } + return true; + } + + private static String indent(@SuppressWarnings("SameParameterValue") final String prefix, + final Iterable> list) { + final StringBuilder stringBuilder = new StringBuilder(); + for (final ConsumerRecord record : list) { + stringBuilder.append(prefix).append(record).append('\n'); + } + return stringBuilder.toString(); + } + + private static Long getSum(final String key) { + final int min = getMin(key).intValue(); + final int max = getMax(key).intValue(); + return ((long) min + max) * (max - min + 1L) / 2L; + } + + private static Double getAvg(final String key) { + final int min = getMin(key).intValue(); + final int max = getMax(key).intValue(); + return ((long) min + max) / 2.0; + } + + + private static boolean verifyTAgg(final PrintStream resultStream, + final Map> allData, + final Map>> taggEvents, + final boolean printResults) { + if (taggEvents == null) { + resultStream.println("tagg is missing"); + return false; + } else if (taggEvents.isEmpty()) { + resultStream.println("tagg is empty"); + return false; + } else { + resultStream.println("verifying tagg"); + + // generate expected answer + final Map expected = new HashMap<>(); + for (final String key : allData.keySet()) { + final int min = getMin(key).intValue(); + final int max = getMax(key).intValue(); + final String cnt = Long.toString(max - min + 1L); + + expected.put(cnt, expected.getOrDefault(cnt, 0L) + 1); + } + + // check the result + for (final Map.Entry>> entry : taggEvents.entrySet()) { + final String key = entry.getKey(); + Long expectedCount = expected.remove(key); + if (expectedCount == null) { + expectedCount = 0L; + } + + if (entry.getValue().getLast().value().longValue() != expectedCount) { + resultStream.println("fail: key=" + key + " tagg=" + entry.getValue() + " expected=" + expectedCount); + + if (printResults) + resultStream.println("\t taggEvents: " + entry.getValue()); + return false; + } + } + + } + return true; + } + + private static Number getMin(final String key) { + return Integer.parseInt(key.split("-")[0]); + } + + private 
static Number getMax(final String key) { + return Integer.parseInt(key.split("-")[1]); + } + + private static List getAllPartitions(final KafkaConsumer consumer, final String... topics) { + final List partitions = new ArrayList<>(); + + for (final String topic : topics) { + for (final PartitionInfo info : consumer.partitionsFor(topic)) { + partitions.add(new TopicPartition(info.topic(), info.partition())); + } + } + return partitions; + } + +} diff --git a/streams/upgrade-system-tests-24/src/test/java/org/apache/kafka/streams/tests/SmokeTestUtil.java b/streams/upgrade-system-tests-24/src/test/java/org/apache/kafka/streams/tests/SmokeTestUtil.java new file mode 100644 index 0000000..e8ec04c --- /dev/null +++ b/streams/upgrade-system-tests-24/src/test/java/org/apache/kafka/streams/tests/SmokeTestUtil.java @@ -0,0 +1,134 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.tests; + +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.kstream.Aggregator; +import org.apache.kafka.streams.kstream.Initializer; +import org.apache.kafka.streams.kstream.KeyValueMapper; +import org.apache.kafka.streams.kstream.Windowed; +import org.apache.kafka.streams.processor.AbstractProcessor; +import org.apache.kafka.streams.processor.Processor; +import org.apache.kafka.streams.processor.ProcessorContext; +import org.apache.kafka.streams.processor.ProcessorSupplier; + +import java.time.Instant; + +public class SmokeTestUtil { + + final static int END = Integer.MAX_VALUE; + + static ProcessorSupplier printProcessorSupplier(final String topic) { + return printProcessorSupplier(topic, ""); + } + + static ProcessorSupplier printProcessorSupplier(final String topic, final String name) { + return new ProcessorSupplier() { + @Override + public Processor get() { + return new AbstractProcessor() { + private int numRecordsProcessed = 0; + private long smallestOffset = Long.MAX_VALUE; + private long largestOffset = Long.MIN_VALUE; + + @Override + public void init(final ProcessorContext context) { + super.init(context); + System.out.println("[DEV] initializing processor: topic=" + topic + " taskId=" + context.taskId()); + System.out.flush(); + numRecordsProcessed = 0; + smallestOffset = Long.MAX_VALUE; + largestOffset = Long.MIN_VALUE; + } + + @Override + public void process(final Object key, final Object value) { + numRecordsProcessed++; + if (numRecordsProcessed % 100 == 0) { + System.out.printf("%s: %s%n", name, Instant.now()); + System.out.println("processed " + numRecordsProcessed + " records from topic=" + topic); + } + + if (smallestOffset > context().offset()) { + smallestOffset = context().offset(); + } + 
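+ // Also track the largest offset seen, so that close() can report the offset range this
+ // task processed.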
if (largestOffset < context().offset()) { + largestOffset = context().offset(); + } + } + + @Override + public void close() { + System.out.printf("Close processor for task %s%n", context().taskId()); + System.out.println("processed " + numRecordsProcessed + " records"); + final long processed; + if (largestOffset >= smallestOffset) { + processed = 1L + largestOffset - smallestOffset; + } else { + processed = 0L; + } + System.out.println("offset " + smallestOffset + " to " + largestOffset + " -> processed " + processed); + System.out.flush(); + } + }; + } + }; + } + + public static final class Unwindow implements KeyValueMapper, V, K> { + @Override + public K apply(final Windowed winKey, final V value) { + return winKey.key(); + } + } + + public static class Agg { + + KeyValueMapper> selector() { + return (key, value) -> new KeyValue<>(value == null ? null : Long.toString(value), 1L); + } + + public Initializer init() { + return () -> 0L; + } + + Aggregator adder() { + return (aggKey, value, aggregate) -> aggregate + value; + } + + Aggregator remover() { + return (aggKey, value, aggregate) -> aggregate - value; + } + } + + public static Serde stringSerde = Serdes.String(); + + public static Serde intSerde = Serdes.Integer(); + + static Serde longSerde = Serdes.Long(); + + static Serde doubleSerde = Serdes.Double(); + + public static void sleep(final long duration) { + try { + Thread.sleep(duration); + } catch (final Exception ignore) { } + } + +} diff --git a/streams/upgrade-system-tests-24/src/test/java/org/apache/kafka/streams/tests/StreamsSmokeTest.java b/streams/upgrade-system-tests-24/src/test/java/org/apache/kafka/streams/tests/StreamsSmokeTest.java new file mode 100644 index 0000000..f280eb0 --- /dev/null +++ b/streams/upgrade-system-tests-24/src/test/java/org/apache/kafka/streams/tests/StreamsSmokeTest.java @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.tests; + +import org.apache.kafka.common.utils.Exit; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.streams.StreamsConfig; + +import java.io.IOException; +import java.time.Duration; +import java.util.Map; +import java.util.Properties; +import java.util.Set; +import java.util.UUID; + +import static org.apache.kafka.streams.tests.SmokeTestDriver.generate; +import static org.apache.kafka.streams.tests.SmokeTestDriver.generatePerpetually; + +public class StreamsSmokeTest { + + /** + * args ::= kafka propFileName command disableAutoTerminate + * command := "run" | "process" + * + * @param args + */ + public static void main(final String[] args) throws IOException { + if (args.length < 2) { + System.err.println("StreamsSmokeTest are expecting two parameters: propFile, command; but only see " + args.length + " parameter"); + Exit.exit(1); + } + + final String propFileName = args[0]; + final String command = args[1]; + final boolean disableAutoTerminate = args.length > 2; + + final Properties streamsProperties = Utils.loadProps(propFileName); + final String kafka = streamsProperties.getProperty(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG); + final String processingGuarantee = streamsProperties.getProperty(StreamsConfig.PROCESSING_GUARANTEE_CONFIG); + + if (kafka == null) { + System.err.println("No bootstrap kafka servers specified in " + StreamsConfig.BOOTSTRAP_SERVERS_CONFIG); + Exit.exit(1); + } + + if ("process".equals(command)) { + if (!StreamsConfig.AT_LEAST_ONCE.equals(processingGuarantee) && + !StreamsConfig.EXACTLY_ONCE.equals(processingGuarantee)) { + + System.err.println("processingGuarantee must be either " + StreamsConfig.AT_LEAST_ONCE + " or " + + StreamsConfig.EXACTLY_ONCE); + + Exit.exit(1); + } + } + + System.out.println("StreamsTest instance started (StreamsSmokeTest)"); + System.out.println("command=" + command); + System.out.println("props=" + streamsProperties); + System.out.println("disableAutoTerminate=" + disableAutoTerminate); + + switch (command) { + case "run": + // this starts the driver (data generation and result verification) + final int numKeys = 10; + final int maxRecordsPerKey = 500; + if (disableAutoTerminate) { + generatePerpetually(kafka, numKeys, maxRecordsPerKey); + } else { + // slow down data production to span 30 seconds so that system tests have time to + // do their bounces, etc. + final Map> allData = + generate(kafka, numKeys, maxRecordsPerKey, Duration.ofSeconds(30)); + SmokeTestDriver.verify(kafka, allData, maxRecordsPerKey); + } + break; + case "process": + // this starts the stream processing app + new SmokeTestClient(UUID.randomUUID().toString()).start(streamsProperties); + break; + default: + System.out.println("unknown command: " + command); + } + } + +} diff --git a/streams/upgrade-system-tests-24/src/test/java/org/apache/kafka/streams/tests/StreamsUpgradeTest.java b/streams/upgrade-system-tests-24/src/test/java/org/apache/kafka/streams/tests/StreamsUpgradeTest.java new file mode 100644 index 0000000..c0c8c72 --- /dev/null +++ b/streams/upgrade-system-tests-24/src/test/java/org/apache/kafka/streams/tests/StreamsUpgradeTest.java @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.tests; + +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.streams.KafkaStreams; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.kstream.KStream; +import org.apache.kafka.streams.processor.AbstractProcessor; +import org.apache.kafka.streams.processor.ProcessorContext; +import org.apache.kafka.streams.processor.ProcessorSupplier; + +import java.util.Properties; + +public class StreamsUpgradeTest { + + @SuppressWarnings("unchecked") + public static void main(final String[] args) throws Exception { + if (args.length < 1) { + System.err.println("StreamsUpgradeTest requires one argument (properties-file) but provided none"); + } + final String propFileName = args[0]; + + final Properties streamsProperties = Utils.loadProps(propFileName); + + System.out.println("StreamsTest instance started (StreamsUpgradeTest v2.4)"); + System.out.println("props=" + streamsProperties); + + final StreamsBuilder builder = new StreamsBuilder(); + final KStream dataStream = builder.stream("data"); + dataStream.process(printProcessorSupplier()); + dataStream.to("echo"); + + final Properties config = new Properties(); + config.setProperty(StreamsConfig.APPLICATION_ID_CONFIG, "StreamsUpgradeTest"); + config.put(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG, 1000L); + config.putAll(streamsProperties); + + final KafkaStreams streams = new KafkaStreams(builder.build(), config); + streams.start(); + + Runtime.getRuntime().addShutdownHook(new Thread(() -> { + streams.close(); + System.out.println("UPGRADE-TEST-CLIENT-CLOSED"); + System.out.flush(); + })); + } + + private static ProcessorSupplier printProcessorSupplier() { + return () -> new AbstractProcessor() { + private int numRecordsProcessed = 0; + + @Override + public void init(final ProcessorContext context) { + System.out.println("[2.4] initializing processor: topic=data taskId=" + context.taskId()); + numRecordsProcessed = 0; + } + + @Override + public void process(final K key, final V value) { + numRecordsProcessed++; + if (numRecordsProcessed % 100 == 0) { + System.out.println("processed " + numRecordsProcessed + " records from topic=data"); + } + } + + @Override + public void close() {} + }; + } +} diff --git a/streams/upgrade-system-tests-25/src/test/java/org/apache/kafka/streams/tests/SmokeTestClient.java b/streams/upgrade-system-tests-25/src/test/java/org/apache/kafka/streams/tests/SmokeTestClient.java new file mode 100644 index 0000000..ced1369 --- /dev/null +++ b/streams/upgrade-system-tests-25/src/test/java/org/apache/kafka/streams/tests/SmokeTestClient.java @@ -0,0 +1,298 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.tests; + +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.common.utils.KafkaThread; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.streams.KafkaStreams; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.Topology; +import org.apache.kafka.streams.kstream.Consumed; +import org.apache.kafka.streams.kstream.Grouped; +import org.apache.kafka.streams.kstream.KGroupedStream; +import org.apache.kafka.streams.kstream.KStream; +import org.apache.kafka.streams.kstream.KTable; +import org.apache.kafka.streams.kstream.Materialized; +import org.apache.kafka.streams.kstream.Produced; +import org.apache.kafka.streams.kstream.Suppressed.BufferConfig; +import org.apache.kafka.streams.kstream.TimeWindows; +import org.apache.kafka.streams.kstream.Windowed; +import org.apache.kafka.streams.state.Stores; +import org.apache.kafka.streams.state.WindowStore; + +import java.io.File; +import java.io.IOException; +import java.nio.file.Files; +import java.time.Duration; +import java.time.Instant; +import java.util.Properties; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; + +import static org.apache.kafka.streams.kstream.Suppressed.untilWindowCloses; + +public class SmokeTestClient extends SmokeTestUtil { + + private final String name; + + private KafkaStreams streams; + private boolean uncaughtException = false; + private boolean started; + private volatile boolean closed; + + private static void addShutdownHook(final String name, final Runnable runnable) { + if (name != null) { + Runtime.getRuntime().addShutdownHook(KafkaThread.nonDaemon(name, runnable)); + } else { + Runtime.getRuntime().addShutdownHook(new Thread(runnable)); + } + } + + private static File tempDirectory() { + final String prefix = "kafka-"; + final File file; + try { + file = Files.createTempDirectory(prefix).toFile(); + } catch (final IOException ex) { + throw new RuntimeException("Failed to create a temp dir", ex); + } + file.deleteOnExit(); + + addShutdownHook("delete-temp-file-shutdown-hook", () -> { + try { + Utils.delete(file); + } catch (final IOException e) { + System.out.println("Error deleting " + file.getAbsolutePath()); + e.printStackTrace(System.out); + } + }); + + return file; + } + + public SmokeTestClient(final String name) { + this.name = name; + } + + public boolean started() { + return started; + } + + public boolean closed() { + return closed; + } + + public void start(final Properties streamsProperties) { + final Topology build = getTopology(); + streams = new KafkaStreams(build, getStreamsConfig(streamsProperties)); + + final CountDownLatch countDownLatch = new CountDownLatch(1); + streams.setStateListener((newState, oldState) -> { + System.out.printf("%s %s: 
%s -> %s%n", name, Instant.now(), oldState, newState); + if (oldState == KafkaStreams.State.REBALANCING && newState == KafkaStreams.State.RUNNING) { + started = true; + countDownLatch.countDown(); + } + + if (newState == KafkaStreams.State.NOT_RUNNING) { + closed = true; + } + }); + + streams.setUncaughtExceptionHandler((t, e) -> { + System.out.println(name + ": SMOKE-TEST-CLIENT-EXCEPTION"); + System.out.println(name + ": FATAL: An unexpected exception is encountered on thread " + t + ": " + e); + e.printStackTrace(System.out); + uncaughtException = true; + streams.close(Duration.ofSeconds(30)); + }); + + addShutdownHook("streams-shutdown-hook", this::close); + + streams.start(); + try { + if (!countDownLatch.await(1, TimeUnit.MINUTES)) { + System.out.println(name + ": SMOKE-TEST-CLIENT-EXCEPTION: Didn't start in one minute"); + } + } catch (final InterruptedException e) { + System.out.println(name + ": SMOKE-TEST-CLIENT-EXCEPTION: " + e); + e.printStackTrace(System.out); + } + System.out.println(name + ": SMOKE-TEST-CLIENT-STARTED"); + System.out.println(name + " started at " + Instant.now()); + } + + public void closeAsync() { + streams.close(Duration.ZERO); + } + + public void close() { + final boolean closed = streams.close(Duration.ofMinutes(1)); + + if (closed && !uncaughtException) { + System.out.println(name + ": SMOKE-TEST-CLIENT-CLOSED"); + } else if (closed) { + System.out.println(name + ": SMOKE-TEST-CLIENT-EXCEPTION"); + } else { + System.out.println(name + ": SMOKE-TEST-CLIENT-EXCEPTION: Didn't close"); + } + } + + private Properties getStreamsConfig(final Properties props) { + final Properties fullProps = new Properties(props); + fullProps.put(StreamsConfig.APPLICATION_ID_CONFIG, "SmokeTest"); + fullProps.put(StreamsConfig.CLIENT_ID_CONFIG, "SmokeTest-" + name); + fullProps.put(StreamsConfig.STATE_DIR_CONFIG, tempDirectory().getAbsolutePath()); + fullProps.putAll(props); + return fullProps; + } + + public Topology getTopology() { + final StreamsBuilder builder = new StreamsBuilder(); + final Consumed stringIntConsumed = Consumed.with(stringSerde, intSerde); + final KStream source = builder.stream("data", stringIntConsumed); + source.filterNot((k, v) -> k.equals("flush")) + .to("echo", Produced.with(stringSerde, intSerde)); + final KStream data = source.filter((key, value) -> value == null || value != END); + data.process(SmokeTestUtil.printProcessorSupplier("data", name)); + + // min + final KGroupedStream groupedData = data.groupByKey(Grouped.with(stringSerde, intSerde)); + + final KTable, Integer> minAggregation = groupedData + .windowedBy(TimeWindows.of(Duration.ofDays(1)).grace(Duration.ofMinutes(1))) + .aggregate( + () -> Integer.MAX_VALUE, + (aggKey, value, aggregate) -> (value < aggregate) ? 
value : aggregate, + Materialized + .>as("uwin-min") + .withValueSerde(intSerde) + .withRetention(Duration.ofHours(25)) + ); + + streamify(minAggregation, "min-raw"); + + streamify(minAggregation.suppress(untilWindowCloses(BufferConfig.unbounded())), "min-suppressed"); + + minAggregation + .toStream(new Unwindow<>()) + .filterNot((k, v) -> k.equals("flush")) + .to("min", Produced.with(stringSerde, intSerde)); + + final KTable, Integer> smallWindowSum = groupedData + .windowedBy(TimeWindows.of(Duration.ofSeconds(2)).advanceBy(Duration.ofSeconds(1)).grace(Duration.ofSeconds(30))) + .reduce((l, r) -> l + r); + + streamify(smallWindowSum, "sws-raw"); + streamify(smallWindowSum.suppress(untilWindowCloses(BufferConfig.unbounded())), "sws-suppressed"); + + final KTable minTable = builder.table( + "min", + Consumed.with(stringSerde, intSerde), + Materialized.as("minStoreName")); + + minTable.toStream().process(SmokeTestUtil.printProcessorSupplier("min", name)); + + // max + groupedData + .windowedBy(TimeWindows.of(Duration.ofDays(2))) + .aggregate( + () -> Integer.MIN_VALUE, + (aggKey, value, aggregate) -> (value > aggregate) ? value : aggregate, + Materialized.>as("uwin-max").withValueSerde(intSerde)) + .toStream(new Unwindow<>()) + .filterNot((k, v) -> k.equals("flush")) + .to("max", Produced.with(stringSerde, intSerde)); + + final KTable maxTable = builder.table( + "max", + Consumed.with(stringSerde, intSerde), + Materialized.as("maxStoreName")); + maxTable.toStream().process(SmokeTestUtil.printProcessorSupplier("max", name)); + + // sum + groupedData + .windowedBy(TimeWindows.of(Duration.ofDays(2))) + .aggregate( + () -> 0L, + (aggKey, value, aggregate) -> (long) value + aggregate, + Materialized.>as("win-sum").withValueSerde(longSerde)) + .toStream(new Unwindow<>()) + .filterNot((k, v) -> k.equals("flush")) + .to("sum", Produced.with(stringSerde, longSerde)); + + final Consumed stringLongConsumed = Consumed.with(stringSerde, longSerde); + final KTable sumTable = builder.table("sum", stringLongConsumed); + sumTable.toStream().process(SmokeTestUtil.printProcessorSupplier("sum", name)); + + // cnt + groupedData + .windowedBy(TimeWindows.of(Duration.ofDays(2))) + .count(Materialized.as("uwin-cnt")) + .toStream(new Unwindow<>()) + .filterNot((k, v) -> k.equals("flush")) + .to("cnt", Produced.with(stringSerde, longSerde)); + + final KTable cntTable = builder.table( + "cnt", + Consumed.with(stringSerde, longSerde), + Materialized.as("cntStoreName")); + cntTable.toStream().process(SmokeTestUtil.printProcessorSupplier("cnt", name)); + + // dif + maxTable + .join( + minTable, + (value1, value2) -> value1 - value2) + .toStream() + .filterNot((k, v) -> k.equals("flush")) + .to("dif", Produced.with(stringSerde, intSerde)); + + // avg + sumTable + .join( + cntTable, + (value1, value2) -> (double) value1 / (double) value2) + .toStream() + .filterNot((k, v) -> k.equals("flush")) + .to("avg", Produced.with(stringSerde, doubleSerde)); + + // test repartition + final Agg agg = new Agg(); + cntTable.groupBy(agg.selector(), Grouped.with(stringSerde, longSerde)) + .aggregate(agg.init(), agg.adder(), agg.remover(), + Materialized.as(Stores.inMemoryKeyValueStore("cntByCnt")) + .withKeySerde(Serdes.String()) + .withValueSerde(Serdes.Long())) + .toStream() + .to("tagg", Produced.with(stringSerde, longSerde)); + + return builder.build(); + } + + private static void streamify(final KTable, Integer> windowedTable, final String topic) { + windowedTable + .toStream() + .filterNot((k, v) -> k.key().equals("flush")) + 
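+ // Render the windowed key as a plain String so the verifier can read this topic's keys
+ // with an ordinary String deserializer.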
.map((key, value) -> new KeyValue<>(key.toString(), value)) + .to(topic, Produced.with(stringSerde, intSerde)); + } +} diff --git a/streams/upgrade-system-tests-25/src/test/java/org/apache/kafka/streams/tests/SmokeTestDriver.java b/streams/upgrade-system-tests-25/src/test/java/org/apache/kafka/streams/tests/SmokeTestDriver.java new file mode 100644 index 0000000..ac83cd9 --- /dev/null +++ b/streams/upgrade-system-tests-25/src/test/java/org/apache/kafka/streams/tests/SmokeTestDriver.java @@ -0,0 +1,622 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.tests; + +import org.apache.kafka.clients.consumer.ConsumerConfig; +import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.clients.consumer.ConsumerRecords; +import org.apache.kafka.clients.consumer.KafkaConsumer; +import org.apache.kafka.clients.producer.Callback; +import org.apache.kafka.clients.producer.KafkaProducer; +import org.apache.kafka.clients.producer.ProducerConfig; +import org.apache.kafka.clients.producer.ProducerRecord; +import org.apache.kafka.clients.producer.RecordMetadata; +import org.apache.kafka.common.PartitionInfo; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.errors.TimeoutException; +import org.apache.kafka.common.serialization.ByteArraySerializer; +import org.apache.kafka.common.serialization.Deserializer; +import org.apache.kafka.common.serialization.StringDeserializer; +import org.apache.kafka.common.utils.Exit; +import org.apache.kafka.common.utils.Utils; + +import java.io.ByteArrayOutputStream; +import java.io.PrintStream; +import java.nio.charset.StandardCharsets; +import java.time.Duration; +import java.time.Instant; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.Properties; +import java.util.Random; +import java.util.Set; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Function; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import static java.util.Collections.emptyMap; +import static org.apache.kafka.common.utils.Utils.mkEntry; + +public class SmokeTestDriver extends SmokeTestUtil { + private static final String[] TOPICS = { + "data", + "echo", + "max", + "min", "min-suppressed", "min-raw", + "dif", + "sum", + "sws-raw", "sws-suppressed", + "cnt", + "avg", + "tagg" + }; + + private static final int MAX_RECORD_EMPTY_RETRIES = 30; + + private static class ValueList { + public final String key; + private final int[] values; + private int index; + + ValueList(final int min, final int max) { + 
key = min + "-" + max; + + values = new int[max - min + 1]; + for (int i = 0; i < values.length; i++) { + values[i] = min + i; + } + // We want to randomize the order of data to test not completely predictable processing order + // However, values are also use as a timestamp of the record. (TODO: separate data and timestamp) + // We keep some correlation of time and order. Thus, the shuffling is done with a sliding window + shuffle(values, 10); + + index = 0; + } + + int next() { + return (index < values.length) ? values[index++] : -1; + } + } + + public static String[] topics() { + return Arrays.copyOf(TOPICS, TOPICS.length); + } + + static void generatePerpetually(final String kafka, + final int numKeys, + final int maxRecordsPerKey) { + final Properties producerProps = generatorProperties(kafka); + + int numRecordsProduced = 0; + + final ValueList[] data = new ValueList[numKeys]; + for (int i = 0; i < numKeys; i++) { + data[i] = new ValueList(i, i + maxRecordsPerKey - 1); + } + + final Random rand = new Random(); + + try (final KafkaProducer producer = new KafkaProducer<>(producerProps)) { + while (true) { + final int index = rand.nextInt(numKeys); + final String key = data[index].key; + final int value = data[index].next(); + + final ProducerRecord record = + new ProducerRecord<>( + "data", + stringSerde.serializer().serialize("", key), + intSerde.serializer().serialize("", value) + ); + + producer.send(record); + + numRecordsProduced++; + if (numRecordsProduced % 100 == 0) { + System.out.println(Instant.now() + " " + numRecordsProduced + " records produced"); + } + Utils.sleep(2); + } + } + } + + public static Map> generate(final String kafka, + final int numKeys, + final int maxRecordsPerKey, + final Duration timeToSpend) { + final Properties producerProps = generatorProperties(kafka); + + + int numRecordsProduced = 0; + + final Map> allData = new HashMap<>(); + final ValueList[] data = new ValueList[numKeys]; + for (int i = 0; i < numKeys; i++) { + data[i] = new ValueList(i, i + maxRecordsPerKey - 1); + allData.put(data[i].key, new HashSet<>()); + } + final Random rand = new Random(); + + int remaining = data.length; + + final long recordPauseTime = timeToSpend.toMillis() / numKeys / maxRecordsPerKey; + + List> needRetry = new ArrayList<>(); + + try (final KafkaProducer producer = new KafkaProducer<>(producerProps)) { + while (remaining > 0) { + final int index = rand.nextInt(remaining); + final String key = data[index].key; + final int value = data[index].next(); + + if (value < 0) { + remaining--; + data[index] = data[remaining]; + } else { + + final ProducerRecord record = + new ProducerRecord<>( + "data", + stringSerde.serializer().serialize("", key), + intSerde.serializer().serialize("", value) + ); + + producer.send(record, new TestCallback(record, needRetry)); + + numRecordsProduced++; + allData.get(key).add(value); + if (numRecordsProduced % 100 == 0) { + System.out.println(Instant.now() + " " + numRecordsProduced + " records produced"); + } + Utils.sleep(Math.max(recordPauseTime, 2)); + } + } + producer.flush(); + + int remainingRetries = 5; + while (!needRetry.isEmpty()) { + final List> needRetry2 = new ArrayList<>(); + for (final ProducerRecord record : needRetry) { + System.out.println("retry producing " + stringSerde.deserializer().deserialize("", record.key())); + producer.send(record, new TestCallback(record, needRetry2)); + } + producer.flush(); + needRetry = needRetry2; + + if (--remainingRetries == 0 && !needRetry.isEmpty()) { + System.err.println("Failed to 
produce all records after multiple retries"); + Exit.exit(1); + } + } + + // now that we've sent everything, we'll send some final records with a timestamp high enough to flush out + // all suppressed records. + final List partitions = producer.partitionsFor("data"); + for (final PartitionInfo partition : partitions) { + producer.send(new ProducerRecord<>( + partition.topic(), + partition.partition(), + System.currentTimeMillis() + Duration.ofDays(2).toMillis(), + stringSerde.serializer().serialize("", "flush"), + intSerde.serializer().serialize("", 0) + )); + } + } + return Collections.unmodifiableMap(allData); + } + + private static Properties generatorProperties(final String kafka) { + final Properties producerProps = new Properties(); + producerProps.put(ProducerConfig.CLIENT_ID_CONFIG, "SmokeTest"); + producerProps.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, kafka); + producerProps.put(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG, ByteArraySerializer.class); + producerProps.put(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG, ByteArraySerializer.class); + producerProps.put(ProducerConfig.ACKS_CONFIG, "all"); + return producerProps; + } + + private static class TestCallback implements Callback { + private final ProducerRecord originalRecord; + private final List> needRetry; + + TestCallback(final ProducerRecord originalRecord, + final List> needRetry) { + this.originalRecord = originalRecord; + this.needRetry = needRetry; + } + + @Override + public void onCompletion(final RecordMetadata metadata, final Exception exception) { + if (exception != null) { + if (exception instanceof TimeoutException) { + needRetry.add(originalRecord); + } else { + exception.printStackTrace(); + Exit.exit(1); + } + } + } + } + + private static void shuffle(final int[] data, @SuppressWarnings("SameParameterValue") final int windowSize) { + final Random rand = new Random(); + for (int i = 0; i < data.length; i++) { + // we shuffle data within windowSize + final int j = rand.nextInt(Math.min(data.length - i, windowSize)) + i; + + // swap + final int tmp = data[i]; + data[i] = data[j]; + data[j] = tmp; + } + } + + public static class NumberDeserializer implements Deserializer { + @Override + public Number deserialize(final String topic, final byte[] data) { + final Number value; + switch (topic) { + case "data": + case "echo": + case "min": + case "min-raw": + case "min-suppressed": + case "sws-raw": + case "sws-suppressed": + case "max": + case "dif": + value = intSerde.deserializer().deserialize(topic, data); + break; + case "sum": + case "cnt": + case "tagg": + value = longSerde.deserializer().deserialize(topic, data); + break; + case "avg": + value = doubleSerde.deserializer().deserialize(topic, data); + break; + default: + throw new RuntimeException("unknown topic: " + topic); + } + return value; + } + } + + public static VerificationResult verify(final String kafka, + final Map> inputs, + final int maxRecordsPerKey) { + final Properties props = new Properties(); + props.put(ConsumerConfig.CLIENT_ID_CONFIG, "verifier"); + props.put(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, kafka); + props.put(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, StringDeserializer.class); + props.put(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, NumberDeserializer.class); + props.put(ConsumerConfig.ISOLATION_LEVEL_CONFIG, "read_committed"); + + final KafkaConsumer consumer = new KafkaConsumer<>(props); + final List partitions = getAllPartitions(consumer, TOPICS); + consumer.assign(partitions); + consumer.seekToBeginning(partitions); + 
+ final int recordsGenerated = inputs.size() * maxRecordsPerKey; + int recordsProcessed = 0; + final Map processed = + Stream.of(TOPICS) + .collect(Collectors.toMap(t -> t, t -> new AtomicInteger(0))); + + final Map>>> events = new HashMap<>(); + + VerificationResult verificationResult = new VerificationResult(false, "no results yet"); + int retry = 0; + final long start = System.currentTimeMillis(); + while (System.currentTimeMillis() - start < TimeUnit.MINUTES.toMillis(6)) { + final ConsumerRecords records = consumer.poll(Duration.ofSeconds(5)); + if (records.isEmpty() && recordsProcessed >= recordsGenerated) { + verificationResult = verifyAll(inputs, events, false); + if (verificationResult.passed()) { + break; + } else if (retry++ > MAX_RECORD_EMPTY_RETRIES) { + System.out.println(Instant.now() + " Didn't get any more results, verification hasn't passed, and out of retries."); + break; + } else { + System.out.println(Instant.now() + " Didn't get any more results, but verification hasn't passed (yet). Retrying..." + retry); + } + } else { + System.out.println(Instant.now() + " Get some more results from " + records.partitions() + ", resetting retry."); + + retry = 0; + for (final ConsumerRecord record : records) { + final String key = record.key(); + + final String topic = record.topic(); + processed.get(topic).incrementAndGet(); + + if (topic.equals("echo")) { + recordsProcessed++; + if (recordsProcessed % 100 == 0) { + System.out.println("Echo records processed = " + recordsProcessed); + } + } + + events.computeIfAbsent(topic, t -> new HashMap<>()) + .computeIfAbsent(key, k -> new LinkedList<>()) + .add(record); + } + + System.out.println(processed); + } + } + consumer.close(); + final long finished = System.currentTimeMillis() - start; + System.out.println("Verification time=" + finished); + System.out.println("-------------------"); + System.out.println("Result Verification"); + System.out.println("-------------------"); + System.out.println("recordGenerated=" + recordsGenerated); + System.out.println("recordProcessed=" + recordsProcessed); + + if (recordsProcessed > recordsGenerated) { + System.out.println("PROCESSED-MORE-THAN-GENERATED"); + } else if (recordsProcessed < recordsGenerated) { + System.out.println("PROCESSED-LESS-THAN-GENERATED"); + } + + boolean success; + + final Map> received = + events.get("echo") + .entrySet() + .stream() + .map(entry -> mkEntry( + entry.getKey(), + entry.getValue().stream().map(ConsumerRecord::value).collect(Collectors.toSet())) + ) + .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); + + success = inputs.equals(received); + + if (success) { + System.out.println("ALL-RECORDS-DELIVERED"); + } else { + int missedCount = 0; + for (final Map.Entry> entry : inputs.entrySet()) { + missedCount += received.get(entry.getKey()).size(); + } + System.out.println("missedRecords=" + missedCount); + } + + // give it one more try if it's not already passing. + if (!verificationResult.passed()) { + verificationResult = verifyAll(inputs, events, true); + } + success &= verificationResult.passed(); + + System.out.println(verificationResult.result()); + + System.out.println(success ? 
"SUCCESS" : "FAILURE"); + return verificationResult; + } + + public static class VerificationResult { + private final boolean passed; + private final String result; + + VerificationResult(final boolean passed, final String result) { + this.passed = passed; + this.result = result; + } + + public boolean passed() { + return passed; + } + + public String result() { + return result; + } + } + + private static VerificationResult verifyAll(final Map> inputs, + final Map>>> events, + final boolean printResults) { + final ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream(); + boolean pass; + try (final PrintStream resultStream = new PrintStream(byteArrayOutputStream)) { + pass = verifyTAgg(resultStream, inputs, events.get("tagg"), printResults); + pass &= verifySuppressed(resultStream, "min-suppressed", events, printResults); + pass &= verify(resultStream, "min-suppressed", inputs, events, windowedKey -> { + final String unwindowedKey = windowedKey.substring(1, windowedKey.length() - 1).replaceAll("@.*", ""); + return getMin(unwindowedKey); + }, printResults); + pass &= verifySuppressed(resultStream, "sws-suppressed", events, printResults); + pass &= verify(resultStream, "min", inputs, events, SmokeTestDriver::getMin, printResults); + pass &= verify(resultStream, "max", inputs, events, SmokeTestDriver::getMax, printResults); + pass &= verify(resultStream, "dif", inputs, events, key -> getMax(key).intValue() - getMin(key).intValue(), printResults); + pass &= verify(resultStream, "sum", inputs, events, SmokeTestDriver::getSum, printResults); + pass &= verify(resultStream, "cnt", inputs, events, key1 -> getMax(key1).intValue() - getMin(key1).intValue() + 1L, printResults); + pass &= verify(resultStream, "avg", inputs, events, SmokeTestDriver::getAvg, printResults); + } + return new VerificationResult(pass, new String(byteArrayOutputStream.toByteArray(), StandardCharsets.UTF_8)); + } + + private static boolean verify(final PrintStream resultStream, + final String topic, + final Map> inputData, + final Map>>> events, + final Function keyToExpectation, + final boolean printResults) { + final Map>> observedInputEvents = events.get("data"); + final Map>> outputEvents = events.getOrDefault(topic, emptyMap()); + if (outputEvents.isEmpty()) { + resultStream.println(topic + " is empty"); + return false; + } else { + resultStream.printf("verifying %s with %d keys%n", topic, outputEvents.size()); + + if (outputEvents.size() != inputData.size()) { + resultStream.printf("fail: resultCount=%d expectedCount=%s%n\tresult=%s%n\texpected=%s%n", + outputEvents.size(), inputData.size(), outputEvents.keySet(), inputData.keySet()); + return false; + } + for (final Map.Entry>> entry : outputEvents.entrySet()) { + final String key = entry.getKey(); + final Number expected = keyToExpectation.apply(key); + final Number actual = entry.getValue().getLast().value(); + if (!expected.equals(actual)) { + resultStream.printf("%s fail: key=%s actual=%s expected=%s%n", topic, key, actual, expected); + + if (printResults) { + resultStream.printf("\t inputEvents=%n%s%n\t" + + "echoEvents=%n%s%n\tmaxEvents=%n%s%n\tminEvents=%n%s%n\tdifEvents=%n%s%n\tcntEvents=%n%s%n\ttaggEvents=%n%s%n", + indent("\t\t", observedInputEvents.get(key)), + indent("\t\t", events.getOrDefault("echo", emptyMap()).getOrDefault(key, new LinkedList<>())), + indent("\t\t", events.getOrDefault("max", emptyMap()).getOrDefault(key, new LinkedList<>())), + indent("\t\t", events.getOrDefault("min", emptyMap()).getOrDefault(key, new 
LinkedList<>())), + indent("\t\t", events.getOrDefault("dif", emptyMap()).getOrDefault(key, new LinkedList<>())), + indent("\t\t", events.getOrDefault("cnt", emptyMap()).getOrDefault(key, new LinkedList<>())), + indent("\t\t", events.getOrDefault("tagg", emptyMap()).getOrDefault(key, new LinkedList<>()))); + + if (!Utils.mkSet("echo", "max", "min", "dif", "cnt", "tagg").contains(topic)) + resultStream.printf("%sEvents=%n%s%n", topic, indent("\t\t", entry.getValue())); + } + + return false; + } + } + return true; + } + } + + + private static boolean verifySuppressed(final PrintStream resultStream, + @SuppressWarnings("SameParameterValue") final String topic, + final Map>>> events, + final boolean printResults) { + resultStream.println("verifying suppressed " + topic); + final Map>> topicEvents = events.getOrDefault(topic, emptyMap()); + for (final Map.Entry>> entry : topicEvents.entrySet()) { + if (entry.getValue().size() != 1) { + final String unsuppressedTopic = topic.replace("-suppressed", "-raw"); + final String key = entry.getKey(); + final String unwindowedKey = key.substring(1, key.length() - 1).replaceAll("@.*", ""); + resultStream.printf("fail: key=%s%n\tnon-unique result:%n%s%n", + key, + indent("\t\t", entry.getValue())); + + if (printResults) + resultStream.printf("\tresultEvents:%n%s%n\tinputEvents:%n%s%n", + indent("\t\t", events.get(unsuppressedTopic).get(key)), + indent("\t\t", events.get("data").get(unwindowedKey))); + + return false; + } + } + return true; + } + + private static String indent(@SuppressWarnings("SameParameterValue") final String prefix, + final Iterable> list) { + final StringBuilder stringBuilder = new StringBuilder(); + for (final ConsumerRecord record : list) { + stringBuilder.append(prefix).append(record).append('\n'); + } + return stringBuilder.toString(); + } + + private static Long getSum(final String key) { + final int min = getMin(key).intValue(); + final int max = getMax(key).intValue(); + return ((long) min + max) * (max - min + 1L) / 2L; + } + + private static Double getAvg(final String key) { + final int min = getMin(key).intValue(); + final int max = getMax(key).intValue(); + return ((long) min + max) / 2.0; + } + + + private static boolean verifyTAgg(final PrintStream resultStream, + final Map> allData, + final Map>> taggEvents, + final boolean printResults) { + if (taggEvents == null) { + resultStream.println("tagg is missing"); + return false; + } else if (taggEvents.isEmpty()) { + resultStream.println("tagg is empty"); + return false; + } else { + resultStream.println("verifying tagg"); + + // generate expected answer + final Map expected = new HashMap<>(); + for (final String key : allData.keySet()) { + final int min = getMin(key).intValue(); + final int max = getMax(key).intValue(); + final String cnt = Long.toString(max - min + 1L); + + expected.put(cnt, expected.getOrDefault(cnt, 0L) + 1); + } + + // check the result + for (final Map.Entry>> entry : taggEvents.entrySet()) { + final String key = entry.getKey(); + Long expectedCount = expected.remove(key); + if (expectedCount == null) { + expectedCount = 0L; + } + + if (entry.getValue().getLast().value().longValue() != expectedCount) { + resultStream.println("fail: key=" + key + " tagg=" + entry.getValue() + " expected=" + expectedCount); + + if (printResults) + resultStream.println("\t taggEvents: " + entry.getValue()); + return false; + } + } + + } + return true; + } + + private static Number getMin(final String key) { + return Integer.parseInt(key.split("-")[0]); + } + + private 
static Number getMax(final String key) { + return Integer.parseInt(key.split("-")[1]); + } + + private static List getAllPartitions(final KafkaConsumer consumer, final String... topics) { + final List partitions = new ArrayList<>(); + + for (final String topic : topics) { + for (final PartitionInfo info : consumer.partitionsFor(topic)) { + partitions.add(new TopicPartition(info.topic(), info.partition())); + } + } + return partitions; + } + +} diff --git a/streams/upgrade-system-tests-25/src/test/java/org/apache/kafka/streams/tests/SmokeTestUtil.java b/streams/upgrade-system-tests-25/src/test/java/org/apache/kafka/streams/tests/SmokeTestUtil.java new file mode 100644 index 0000000..e8ec04c --- /dev/null +++ b/streams/upgrade-system-tests-25/src/test/java/org/apache/kafka/streams/tests/SmokeTestUtil.java @@ -0,0 +1,134 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.tests; + +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.kstream.Aggregator; +import org.apache.kafka.streams.kstream.Initializer; +import org.apache.kafka.streams.kstream.KeyValueMapper; +import org.apache.kafka.streams.kstream.Windowed; +import org.apache.kafka.streams.processor.AbstractProcessor; +import org.apache.kafka.streams.processor.Processor; +import org.apache.kafka.streams.processor.ProcessorContext; +import org.apache.kafka.streams.processor.ProcessorSupplier; + +import java.time.Instant; + +public class SmokeTestUtil { + + final static int END = Integer.MAX_VALUE; + + static ProcessorSupplier printProcessorSupplier(final String topic) { + return printProcessorSupplier(topic, ""); + } + + static ProcessorSupplier printProcessorSupplier(final String topic, final String name) { + return new ProcessorSupplier() { + @Override + public Processor get() { + return new AbstractProcessor() { + private int numRecordsProcessed = 0; + private long smallestOffset = Long.MAX_VALUE; + private long largestOffset = Long.MIN_VALUE; + + @Override + public void init(final ProcessorContext context) { + super.init(context); + System.out.println("[DEV] initializing processor: topic=" + topic + " taskId=" + context.taskId()); + System.out.flush(); + numRecordsProcessed = 0; + smallestOffset = Long.MAX_VALUE; + largestOffset = Long.MIN_VALUE; + } + + @Override + public void process(final Object key, final Object value) { + numRecordsProcessed++; + if (numRecordsProcessed % 100 == 0) { + System.out.printf("%s: %s%n", name, Instant.now()); + System.out.println("processed " + numRecordsProcessed + " records from topic=" + topic); + } + + if (smallestOffset > context().offset()) { + smallestOffset = context().offset(); + } + 
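+ // As above for the smallest offset: widen the tracked range so close() can print
+ // "offset <first> to <last> -> processed <n>" for this task.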
if (largestOffset < context().offset()) { + largestOffset = context().offset(); + } + } + + @Override + public void close() { + System.out.printf("Close processor for task %s%n", context().taskId()); + System.out.println("processed " + numRecordsProcessed + " records"); + final long processed; + if (largestOffset >= smallestOffset) { + processed = 1L + largestOffset - smallestOffset; + } else { + processed = 0L; + } + System.out.println("offset " + smallestOffset + " to " + largestOffset + " -> processed " + processed); + System.out.flush(); + } + }; + } + }; + } + + public static final class Unwindow implements KeyValueMapper, V, K> { + @Override + public K apply(final Windowed winKey, final V value) { + return winKey.key(); + } + } + + public static class Agg { + + KeyValueMapper> selector() { + return (key, value) -> new KeyValue<>(value == null ? null : Long.toString(value), 1L); + } + + public Initializer init() { + return () -> 0L; + } + + Aggregator adder() { + return (aggKey, value, aggregate) -> aggregate + value; + } + + Aggregator remover() { + return (aggKey, value, aggregate) -> aggregate - value; + } + } + + public static Serde stringSerde = Serdes.String(); + + public static Serde intSerde = Serdes.Integer(); + + static Serde longSerde = Serdes.Long(); + + static Serde doubleSerde = Serdes.Double(); + + public static void sleep(final long duration) { + try { + Thread.sleep(duration); + } catch (final Exception ignore) { } + } + +} diff --git a/streams/upgrade-system-tests-25/src/test/java/org/apache/kafka/streams/tests/StreamsSmokeTest.java b/streams/upgrade-system-tests-25/src/test/java/org/apache/kafka/streams/tests/StreamsSmokeTest.java new file mode 100644 index 0000000..f280eb0 --- /dev/null +++ b/streams/upgrade-system-tests-25/src/test/java/org/apache/kafka/streams/tests/StreamsSmokeTest.java @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+package org.apache.kafka.streams.tests;
+
+import org.apache.kafka.common.utils.Exit;
+import org.apache.kafka.common.utils.Utils;
+import org.apache.kafka.streams.StreamsConfig;
+
+import java.io.IOException;
+import java.time.Duration;
+import java.util.Map;
+import java.util.Properties;
+import java.util.Set;
+import java.util.UUID;
+
+import static org.apache.kafka.streams.tests.SmokeTestDriver.generate;
+import static org.apache.kafka.streams.tests.SmokeTestDriver.generatePerpetually;
+
+public class StreamsSmokeTest {
+
+    /**
+     * args ::= kafka propFileName command disableAutoTerminate
+     * command := "run" | "process"
+     *
+     * @param args
+     */
+    public static void main(final String[] args) throws IOException {
+        if (args.length < 2) {
+            System.err.println("StreamsSmokeTest expects two parameters (propFile, command) but got only " + args.length);
+            Exit.exit(1);
+        }
+
+        final String propFileName = args[0];
+        final String command = args[1];
+        final boolean disableAutoTerminate = args.length > 2;
+
+        final Properties streamsProperties = Utils.loadProps(propFileName);
+        final String kafka = streamsProperties.getProperty(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG);
+        final String processingGuarantee = streamsProperties.getProperty(StreamsConfig.PROCESSING_GUARANTEE_CONFIG);
+
+        if (kafka == null) {
+            System.err.println("No bootstrap kafka servers specified in " + StreamsConfig.BOOTSTRAP_SERVERS_CONFIG);
+            Exit.exit(1);
+        }
+
+        if ("process".equals(command)) {
+            if (!StreamsConfig.AT_LEAST_ONCE.equals(processingGuarantee) &&
+                !StreamsConfig.EXACTLY_ONCE.equals(processingGuarantee)) {
+
+                System.err.println("processingGuarantee must be either " + StreamsConfig.AT_LEAST_ONCE + " or " +
+                    StreamsConfig.EXACTLY_ONCE);
+
+                Exit.exit(1);
+            }
+        }
+
+        System.out.println("StreamsTest instance started (StreamsSmokeTest)");
+        System.out.println("command=" + command);
+        System.out.println("props=" + streamsProperties);
+        System.out.println("disableAutoTerminate=" + disableAutoTerminate);
+
+        switch (command) {
+            case "run":
+                // this starts the driver (data generation and result verification)
+                final int numKeys = 10;
+                final int maxRecordsPerKey = 500;
+                if (disableAutoTerminate) {
+                    generatePerpetually(kafka, numKeys, maxRecordsPerKey);
+                } else {
+                    // slow down data production to span 30 seconds so that system tests have time to
+                    // do their bounces, etc.
+                    final Map<String, Set<Integer>> allData =
+                        generate(kafka, numKeys, maxRecordsPerKey, Duration.ofSeconds(30));
+                    SmokeTestDriver.verify(kafka, allData, maxRecordsPerKey);
+                }
+                break;
+            case "process":
+                // this starts the stream processing app
+                new SmokeTestClient(UUID.randomUUID().toString()).start(streamsProperties);
+                break;
+            default:
+                System.out.println("unknown command: " + command);
+        }
+    }
+
+}
diff --git a/streams/upgrade-system-tests-25/src/test/java/org/apache/kafka/streams/tests/StreamsUpgradeTest.java b/streams/upgrade-system-tests-25/src/test/java/org/apache/kafka/streams/tests/StreamsUpgradeTest.java
new file mode 100644
index 0000000..0fea040
--- /dev/null
+++ b/streams/upgrade-system-tests-25/src/test/java/org/apache/kafka/streams/tests/StreamsUpgradeTest.java
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.tests; + +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.streams.KafkaStreams; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.kstream.KStream; +import org.apache.kafka.streams.processor.AbstractProcessor; +import org.apache.kafka.streams.processor.ProcessorContext; +import org.apache.kafka.streams.processor.ProcessorSupplier; + +import java.util.Properties; + +public class StreamsUpgradeTest { + + @SuppressWarnings("unchecked") + public static void main(final String[] args) throws Exception { + if (args.length < 1) { + System.err.println("StreamsUpgradeTest requires one argument (properties-file) but provided none"); + } + final String propFileName = args[0]; + + final Properties streamsProperties = Utils.loadProps(propFileName); + + System.out.println("StreamsTest instance started (StreamsUpgradeTest v2.5)"); + System.out.println("props=" + streamsProperties); + + final StreamsBuilder builder = new StreamsBuilder(); + final KStream dataStream = builder.stream("data"); + dataStream.process(printProcessorSupplier()); + dataStream.to("echo"); + + final Properties config = new Properties(); + config.setProperty(StreamsConfig.APPLICATION_ID_CONFIG, "StreamsUpgradeTest"); + config.put(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG, 1000L); + config.putAll(streamsProperties); + + final KafkaStreams streams = new KafkaStreams(builder.build(), config); + streams.start(); + + Runtime.getRuntime().addShutdownHook(new Thread(() -> { + streams.close(); + System.out.println("UPGRADE-TEST-CLIENT-CLOSED"); + System.out.flush(); + })); + } + + private static ProcessorSupplier printProcessorSupplier() { + return () -> new AbstractProcessor() { + private int numRecordsProcessed = 0; + + @Override + public void init(final ProcessorContext context) { + System.out.println("[2.5] initializing processor: topic=data taskId=" + context.taskId()); + numRecordsProcessed = 0; + } + + @Override + public void process(final K key, final V value) { + numRecordsProcessed++; + if (numRecordsProcessed % 100 == 0) { + System.out.println("processed " + numRecordsProcessed + " records from topic=data"); + } + } + + @Override + public void close() {} + }; + } +} diff --git a/streams/upgrade-system-tests-26/src/test/java/org/apache/kafka/streams/tests/SmokeTestClient.java b/streams/upgrade-system-tests-26/src/test/java/org/apache/kafka/streams/tests/SmokeTestClient.java new file mode 100644 index 0000000..ced1369 --- /dev/null +++ b/streams/upgrade-system-tests-26/src/test/java/org/apache/kafka/streams/tests/SmokeTestClient.java @@ -0,0 +1,298 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.tests; + +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.common.utils.KafkaThread; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.streams.KafkaStreams; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.Topology; +import org.apache.kafka.streams.kstream.Consumed; +import org.apache.kafka.streams.kstream.Grouped; +import org.apache.kafka.streams.kstream.KGroupedStream; +import org.apache.kafka.streams.kstream.KStream; +import org.apache.kafka.streams.kstream.KTable; +import org.apache.kafka.streams.kstream.Materialized; +import org.apache.kafka.streams.kstream.Produced; +import org.apache.kafka.streams.kstream.Suppressed.BufferConfig; +import org.apache.kafka.streams.kstream.TimeWindows; +import org.apache.kafka.streams.kstream.Windowed; +import org.apache.kafka.streams.state.Stores; +import org.apache.kafka.streams.state.WindowStore; + +import java.io.File; +import java.io.IOException; +import java.nio.file.Files; +import java.time.Duration; +import java.time.Instant; +import java.util.Properties; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; + +import static org.apache.kafka.streams.kstream.Suppressed.untilWindowCloses; + +public class SmokeTestClient extends SmokeTestUtil { + + private final String name; + + private KafkaStreams streams; + private boolean uncaughtException = false; + private boolean started; + private volatile boolean closed; + + private static void addShutdownHook(final String name, final Runnable runnable) { + if (name != null) { + Runtime.getRuntime().addShutdownHook(KafkaThread.nonDaemon(name, runnable)); + } else { + Runtime.getRuntime().addShutdownHook(new Thread(runnable)); + } + } + + private static File tempDirectory() { + final String prefix = "kafka-"; + final File file; + try { + file = Files.createTempDirectory(prefix).toFile(); + } catch (final IOException ex) { + throw new RuntimeException("Failed to create a temp dir", ex); + } + file.deleteOnExit(); + + addShutdownHook("delete-temp-file-shutdown-hook", () -> { + try { + Utils.delete(file); + } catch (final IOException e) { + System.out.println("Error deleting " + file.getAbsolutePath()); + e.printStackTrace(System.out); + } + }); + + return file; + } + + public SmokeTestClient(final String name) { + this.name = name; + } + + public boolean started() { + return started; + } + + public boolean closed() { + return closed; + } + + public void start(final Properties streamsProperties) { + final Topology build = getTopology(); + streams = new KafkaStreams(build, getStreamsConfig(streamsProperties)); + + final CountDownLatch countDownLatch = new CountDownLatch(1); + streams.setStateListener((newState, oldState) -> { + System.out.printf("%s %s: 
%s -> %s%n", name, Instant.now(), oldState, newState); + if (oldState == KafkaStreams.State.REBALANCING && newState == KafkaStreams.State.RUNNING) { + started = true; + countDownLatch.countDown(); + } + + if (newState == KafkaStreams.State.NOT_RUNNING) { + closed = true; + } + }); + + streams.setUncaughtExceptionHandler((t, e) -> { + System.out.println(name + ": SMOKE-TEST-CLIENT-EXCEPTION"); + System.out.println(name + ": FATAL: An unexpected exception is encountered on thread " + t + ": " + e); + e.printStackTrace(System.out); + uncaughtException = true; + streams.close(Duration.ofSeconds(30)); + }); + + addShutdownHook("streams-shutdown-hook", this::close); + + streams.start(); + try { + if (!countDownLatch.await(1, TimeUnit.MINUTES)) { + System.out.println(name + ": SMOKE-TEST-CLIENT-EXCEPTION: Didn't start in one minute"); + } + } catch (final InterruptedException e) { + System.out.println(name + ": SMOKE-TEST-CLIENT-EXCEPTION: " + e); + e.printStackTrace(System.out); + } + System.out.println(name + ": SMOKE-TEST-CLIENT-STARTED"); + System.out.println(name + " started at " + Instant.now()); + } + + public void closeAsync() { + streams.close(Duration.ZERO); + } + + public void close() { + final boolean closed = streams.close(Duration.ofMinutes(1)); + + if (closed && !uncaughtException) { + System.out.println(name + ": SMOKE-TEST-CLIENT-CLOSED"); + } else if (closed) { + System.out.println(name + ": SMOKE-TEST-CLIENT-EXCEPTION"); + } else { + System.out.println(name + ": SMOKE-TEST-CLIENT-EXCEPTION: Didn't close"); + } + } + + private Properties getStreamsConfig(final Properties props) { + final Properties fullProps = new Properties(props); + fullProps.put(StreamsConfig.APPLICATION_ID_CONFIG, "SmokeTest"); + fullProps.put(StreamsConfig.CLIENT_ID_CONFIG, "SmokeTest-" + name); + fullProps.put(StreamsConfig.STATE_DIR_CONFIG, tempDirectory().getAbsolutePath()); + fullProps.putAll(props); + return fullProps; + } + + public Topology getTopology() { + final StreamsBuilder builder = new StreamsBuilder(); + final Consumed stringIntConsumed = Consumed.with(stringSerde, intSerde); + final KStream source = builder.stream("data", stringIntConsumed); + source.filterNot((k, v) -> k.equals("flush")) + .to("echo", Produced.with(stringSerde, intSerde)); + final KStream data = source.filter((key, value) -> value == null || value != END); + data.process(SmokeTestUtil.printProcessorSupplier("data", name)); + + // min + final KGroupedStream groupedData = data.groupByKey(Grouped.with(stringSerde, intSerde)); + + final KTable, Integer> minAggregation = groupedData + .windowedBy(TimeWindows.of(Duration.ofDays(1)).grace(Duration.ofMinutes(1))) + .aggregate( + () -> Integer.MAX_VALUE, + (aggKey, value, aggregate) -> (value < aggregate) ? 
value : aggregate, + Materialized + .>as("uwin-min") + .withValueSerde(intSerde) + .withRetention(Duration.ofHours(25)) + ); + + streamify(minAggregation, "min-raw"); + + streamify(minAggregation.suppress(untilWindowCloses(BufferConfig.unbounded())), "min-suppressed"); + + minAggregation + .toStream(new Unwindow<>()) + .filterNot((k, v) -> k.equals("flush")) + .to("min", Produced.with(stringSerde, intSerde)); + + final KTable, Integer> smallWindowSum = groupedData + .windowedBy(TimeWindows.of(Duration.ofSeconds(2)).advanceBy(Duration.ofSeconds(1)).grace(Duration.ofSeconds(30))) + .reduce((l, r) -> l + r); + + streamify(smallWindowSum, "sws-raw"); + streamify(smallWindowSum.suppress(untilWindowCloses(BufferConfig.unbounded())), "sws-suppressed"); + + final KTable minTable = builder.table( + "min", + Consumed.with(stringSerde, intSerde), + Materialized.as("minStoreName")); + + minTable.toStream().process(SmokeTestUtil.printProcessorSupplier("min", name)); + + // max + groupedData + .windowedBy(TimeWindows.of(Duration.ofDays(2))) + .aggregate( + () -> Integer.MIN_VALUE, + (aggKey, value, aggregate) -> (value > aggregate) ? value : aggregate, + Materialized.>as("uwin-max").withValueSerde(intSerde)) + .toStream(new Unwindow<>()) + .filterNot((k, v) -> k.equals("flush")) + .to("max", Produced.with(stringSerde, intSerde)); + + final KTable maxTable = builder.table( + "max", + Consumed.with(stringSerde, intSerde), + Materialized.as("maxStoreName")); + maxTable.toStream().process(SmokeTestUtil.printProcessorSupplier("max", name)); + + // sum + groupedData + .windowedBy(TimeWindows.of(Duration.ofDays(2))) + .aggregate( + () -> 0L, + (aggKey, value, aggregate) -> (long) value + aggregate, + Materialized.>as("win-sum").withValueSerde(longSerde)) + .toStream(new Unwindow<>()) + .filterNot((k, v) -> k.equals("flush")) + .to("sum", Produced.with(stringSerde, longSerde)); + + final Consumed stringLongConsumed = Consumed.with(stringSerde, longSerde); + final KTable sumTable = builder.table("sum", stringLongConsumed); + sumTable.toStream().process(SmokeTestUtil.printProcessorSupplier("sum", name)); + + // cnt + groupedData + .windowedBy(TimeWindows.of(Duration.ofDays(2))) + .count(Materialized.as("uwin-cnt")) + .toStream(new Unwindow<>()) + .filterNot((k, v) -> k.equals("flush")) + .to("cnt", Produced.with(stringSerde, longSerde)); + + final KTable cntTable = builder.table( + "cnt", + Consumed.with(stringSerde, longSerde), + Materialized.as("cntStoreName")); + cntTable.toStream().process(SmokeTestUtil.printProcessorSupplier("cnt", name)); + + // dif + maxTable + .join( + minTable, + (value1, value2) -> value1 - value2) + .toStream() + .filterNot((k, v) -> k.equals("flush")) + .to("dif", Produced.with(stringSerde, intSerde)); + + // avg + sumTable + .join( + cntTable, + (value1, value2) -> (double) value1 / (double) value2) + .toStream() + .filterNot((k, v) -> k.equals("flush")) + .to("avg", Produced.with(stringSerde, doubleSerde)); + + // test repartition + final Agg agg = new Agg(); + cntTable.groupBy(agg.selector(), Grouped.with(stringSerde, longSerde)) + .aggregate(agg.init(), agg.adder(), agg.remover(), + Materialized.as(Stores.inMemoryKeyValueStore("cntByCnt")) + .withKeySerde(Serdes.String()) + .withValueSerde(Serdes.Long())) + .toStream() + .to("tagg", Produced.with(stringSerde, longSerde)); + + return builder.build(); + } + + private static void streamify(final KTable, Integer> windowedTable, final String topic) { + windowedTable + .toStream() + .filterNot((k, v) -> k.key().equals("flush")) + 
.map((key, value) -> new KeyValue<>(key.toString(), value)) + .to(topic, Produced.with(stringSerde, intSerde)); + } +} diff --git a/streams/upgrade-system-tests-26/src/test/java/org/apache/kafka/streams/tests/SmokeTestDriver.java b/streams/upgrade-system-tests-26/src/test/java/org/apache/kafka/streams/tests/SmokeTestDriver.java new file mode 100644 index 0000000..ac83cd9 --- /dev/null +++ b/streams/upgrade-system-tests-26/src/test/java/org/apache/kafka/streams/tests/SmokeTestDriver.java @@ -0,0 +1,622 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.tests; + +import org.apache.kafka.clients.consumer.ConsumerConfig; +import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.clients.consumer.ConsumerRecords; +import org.apache.kafka.clients.consumer.KafkaConsumer; +import org.apache.kafka.clients.producer.Callback; +import org.apache.kafka.clients.producer.KafkaProducer; +import org.apache.kafka.clients.producer.ProducerConfig; +import org.apache.kafka.clients.producer.ProducerRecord; +import org.apache.kafka.clients.producer.RecordMetadata; +import org.apache.kafka.common.PartitionInfo; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.errors.TimeoutException; +import org.apache.kafka.common.serialization.ByteArraySerializer; +import org.apache.kafka.common.serialization.Deserializer; +import org.apache.kafka.common.serialization.StringDeserializer; +import org.apache.kafka.common.utils.Exit; +import org.apache.kafka.common.utils.Utils; + +import java.io.ByteArrayOutputStream; +import java.io.PrintStream; +import java.nio.charset.StandardCharsets; +import java.time.Duration; +import java.time.Instant; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.Properties; +import java.util.Random; +import java.util.Set; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Function; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import static java.util.Collections.emptyMap; +import static org.apache.kafka.common.utils.Utils.mkEntry; + +public class SmokeTestDriver extends SmokeTestUtil { + private static final String[] TOPICS = { + "data", + "echo", + "max", + "min", "min-suppressed", "min-raw", + "dif", + "sum", + "sws-raw", "sws-suppressed", + "cnt", + "avg", + "tagg" + }; + + private static final int MAX_RECORD_EMPTY_RETRIES = 30; + + private static class ValueList { + public final String key; + private final int[] values; + private int index; + + ValueList(final int min, final int max) { + 
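+            // the key encodes this generator's value range as "<min>-<max>"; the verifier's
+            // getMin()/getMax() helpers later parse these bounds back out of the key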
key = min + "-" + max; + + values = new int[max - min + 1]; + for (int i = 0; i < values.length; i++) { + values[i] = min + i; + } + // We want to randomize the order of data to test not completely predictable processing order + // However, values are also use as a timestamp of the record. (TODO: separate data and timestamp) + // We keep some correlation of time and order. Thus, the shuffling is done with a sliding window + shuffle(values, 10); + + index = 0; + } + + int next() { + return (index < values.length) ? values[index++] : -1; + } + } + + public static String[] topics() { + return Arrays.copyOf(TOPICS, TOPICS.length); + } + + static void generatePerpetually(final String kafka, + final int numKeys, + final int maxRecordsPerKey) { + final Properties producerProps = generatorProperties(kafka); + + int numRecordsProduced = 0; + + final ValueList[] data = new ValueList[numKeys]; + for (int i = 0; i < numKeys; i++) { + data[i] = new ValueList(i, i + maxRecordsPerKey - 1); + } + + final Random rand = new Random(); + + try (final KafkaProducer producer = new KafkaProducer<>(producerProps)) { + while (true) { + final int index = rand.nextInt(numKeys); + final String key = data[index].key; + final int value = data[index].next(); + + final ProducerRecord record = + new ProducerRecord<>( + "data", + stringSerde.serializer().serialize("", key), + intSerde.serializer().serialize("", value) + ); + + producer.send(record); + + numRecordsProduced++; + if (numRecordsProduced % 100 == 0) { + System.out.println(Instant.now() + " " + numRecordsProduced + " records produced"); + } + Utils.sleep(2); + } + } + } + + public static Map> generate(final String kafka, + final int numKeys, + final int maxRecordsPerKey, + final Duration timeToSpend) { + final Properties producerProps = generatorProperties(kafka); + + + int numRecordsProduced = 0; + + final Map> allData = new HashMap<>(); + final ValueList[] data = new ValueList[numKeys]; + for (int i = 0; i < numKeys; i++) { + data[i] = new ValueList(i, i + maxRecordsPerKey - 1); + allData.put(data[i].key, new HashSet<>()); + } + final Random rand = new Random(); + + int remaining = data.length; + + final long recordPauseTime = timeToSpend.toMillis() / numKeys / maxRecordsPerKey; + + List> needRetry = new ArrayList<>(); + + try (final KafkaProducer producer = new KafkaProducer<>(producerProps)) { + while (remaining > 0) { + final int index = rand.nextInt(remaining); + final String key = data[index].key; + final int value = data[index].next(); + + if (value < 0) { + remaining--; + data[index] = data[remaining]; + } else { + + final ProducerRecord record = + new ProducerRecord<>( + "data", + stringSerde.serializer().serialize("", key), + intSerde.serializer().serialize("", value) + ); + + producer.send(record, new TestCallback(record, needRetry)); + + numRecordsProduced++; + allData.get(key).add(value); + if (numRecordsProduced % 100 == 0) { + System.out.println(Instant.now() + " " + numRecordsProduced + " records produced"); + } + Utils.sleep(Math.max(recordPauseTime, 2)); + } + } + producer.flush(); + + int remainingRetries = 5; + while (!needRetry.isEmpty()) { + final List> needRetry2 = new ArrayList<>(); + for (final ProducerRecord record : needRetry) { + System.out.println("retry producing " + stringSerde.deserializer().deserialize("", record.key())); + producer.send(record, new TestCallback(record, needRetry2)); + } + producer.flush(); + needRetry = needRetry2; + + if (--remainingRetries == 0 && !needRetry.isEmpty()) { + System.err.println("Failed to 
produce all records after multiple retries"); + Exit.exit(1); + } + } + + // now that we've sent everything, we'll send some final records with a timestamp high enough to flush out + // all suppressed records. + final List partitions = producer.partitionsFor("data"); + for (final PartitionInfo partition : partitions) { + producer.send(new ProducerRecord<>( + partition.topic(), + partition.partition(), + System.currentTimeMillis() + Duration.ofDays(2).toMillis(), + stringSerde.serializer().serialize("", "flush"), + intSerde.serializer().serialize("", 0) + )); + } + } + return Collections.unmodifiableMap(allData); + } + + private static Properties generatorProperties(final String kafka) { + final Properties producerProps = new Properties(); + producerProps.put(ProducerConfig.CLIENT_ID_CONFIG, "SmokeTest"); + producerProps.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, kafka); + producerProps.put(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG, ByteArraySerializer.class); + producerProps.put(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG, ByteArraySerializer.class); + producerProps.put(ProducerConfig.ACKS_CONFIG, "all"); + return producerProps; + } + + private static class TestCallback implements Callback { + private final ProducerRecord originalRecord; + private final List> needRetry; + + TestCallback(final ProducerRecord originalRecord, + final List> needRetry) { + this.originalRecord = originalRecord; + this.needRetry = needRetry; + } + + @Override + public void onCompletion(final RecordMetadata metadata, final Exception exception) { + if (exception != null) { + if (exception instanceof TimeoutException) { + needRetry.add(originalRecord); + } else { + exception.printStackTrace(); + Exit.exit(1); + } + } + } + } + + private static void shuffle(final int[] data, @SuppressWarnings("SameParameterValue") final int windowSize) { + final Random rand = new Random(); + for (int i = 0; i < data.length; i++) { + // we shuffle data within windowSize + final int j = rand.nextInt(Math.min(data.length - i, windowSize)) + i; + + // swap + final int tmp = data[i]; + data[i] = data[j]; + data[j] = tmp; + } + } + + public static class NumberDeserializer implements Deserializer { + @Override + public Number deserialize(final String topic, final byte[] data) { + final Number value; + switch (topic) { + case "data": + case "echo": + case "min": + case "min-raw": + case "min-suppressed": + case "sws-raw": + case "sws-suppressed": + case "max": + case "dif": + value = intSerde.deserializer().deserialize(topic, data); + break; + case "sum": + case "cnt": + case "tagg": + value = longSerde.deserializer().deserialize(topic, data); + break; + case "avg": + value = doubleSerde.deserializer().deserialize(topic, data); + break; + default: + throw new RuntimeException("unknown topic: " + topic); + } + return value; + } + } + + public static VerificationResult verify(final String kafka, + final Map> inputs, + final int maxRecordsPerKey) { + final Properties props = new Properties(); + props.put(ConsumerConfig.CLIENT_ID_CONFIG, "verifier"); + props.put(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, kafka); + props.put(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, StringDeserializer.class); + props.put(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, NumberDeserializer.class); + props.put(ConsumerConfig.ISOLATION_LEVEL_CONFIG, "read_committed"); + + final KafkaConsumer consumer = new KafkaConsumer<>(props); + final List partitions = getAllPartitions(consumer, TOPICS); + consumer.assign(partitions); + consumer.seekToBeginning(partitions); + 
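+        // drain every output topic from the beginning for up to six minutes, counting the records
+        // that come back on "echo"; once all generated records have been echoed and a poll returns
+        // nothing new, re-run verifyAll() until it passes or MAX_RECORD_EMPTY_RETRIES is exhausted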
+ final int recordsGenerated = inputs.size() * maxRecordsPerKey; + int recordsProcessed = 0; + final Map processed = + Stream.of(TOPICS) + .collect(Collectors.toMap(t -> t, t -> new AtomicInteger(0))); + + final Map>>> events = new HashMap<>(); + + VerificationResult verificationResult = new VerificationResult(false, "no results yet"); + int retry = 0; + final long start = System.currentTimeMillis(); + while (System.currentTimeMillis() - start < TimeUnit.MINUTES.toMillis(6)) { + final ConsumerRecords records = consumer.poll(Duration.ofSeconds(5)); + if (records.isEmpty() && recordsProcessed >= recordsGenerated) { + verificationResult = verifyAll(inputs, events, false); + if (verificationResult.passed()) { + break; + } else if (retry++ > MAX_RECORD_EMPTY_RETRIES) { + System.out.println(Instant.now() + " Didn't get any more results, verification hasn't passed, and out of retries."); + break; + } else { + System.out.println(Instant.now() + " Didn't get any more results, but verification hasn't passed (yet). Retrying..." + retry); + } + } else { + System.out.println(Instant.now() + " Get some more results from " + records.partitions() + ", resetting retry."); + + retry = 0; + for (final ConsumerRecord record : records) { + final String key = record.key(); + + final String topic = record.topic(); + processed.get(topic).incrementAndGet(); + + if (topic.equals("echo")) { + recordsProcessed++; + if (recordsProcessed % 100 == 0) { + System.out.println("Echo records processed = " + recordsProcessed); + } + } + + events.computeIfAbsent(topic, t -> new HashMap<>()) + .computeIfAbsent(key, k -> new LinkedList<>()) + .add(record); + } + + System.out.println(processed); + } + } + consumer.close(); + final long finished = System.currentTimeMillis() - start; + System.out.println("Verification time=" + finished); + System.out.println("-------------------"); + System.out.println("Result Verification"); + System.out.println("-------------------"); + System.out.println("recordGenerated=" + recordsGenerated); + System.out.println("recordProcessed=" + recordsProcessed); + + if (recordsProcessed > recordsGenerated) { + System.out.println("PROCESSED-MORE-THAN-GENERATED"); + } else if (recordsProcessed < recordsGenerated) { + System.out.println("PROCESSED-LESS-THAN-GENERATED"); + } + + boolean success; + + final Map> received = + events.get("echo") + .entrySet() + .stream() + .map(entry -> mkEntry( + entry.getKey(), + entry.getValue().stream().map(ConsumerRecord::value).collect(Collectors.toSet())) + ) + .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); + + success = inputs.equals(received); + + if (success) { + System.out.println("ALL-RECORDS-DELIVERED"); + } else { + int missedCount = 0; + for (final Map.Entry> entry : inputs.entrySet()) { + missedCount += received.get(entry.getKey()).size(); + } + System.out.println("missedRecords=" + missedCount); + } + + // give it one more try if it's not already passing. + if (!verificationResult.passed()) { + verificationResult = verifyAll(inputs, events, true); + } + success &= verificationResult.passed(); + + System.out.println(verificationResult.result()); + + System.out.println(success ? 
"SUCCESS" : "FAILURE"); + return verificationResult; + } + + public static class VerificationResult { + private final boolean passed; + private final String result; + + VerificationResult(final boolean passed, final String result) { + this.passed = passed; + this.result = result; + } + + public boolean passed() { + return passed; + } + + public String result() { + return result; + } + } + + private static VerificationResult verifyAll(final Map> inputs, + final Map>>> events, + final boolean printResults) { + final ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream(); + boolean pass; + try (final PrintStream resultStream = new PrintStream(byteArrayOutputStream)) { + pass = verifyTAgg(resultStream, inputs, events.get("tagg"), printResults); + pass &= verifySuppressed(resultStream, "min-suppressed", events, printResults); + pass &= verify(resultStream, "min-suppressed", inputs, events, windowedKey -> { + final String unwindowedKey = windowedKey.substring(1, windowedKey.length() - 1).replaceAll("@.*", ""); + return getMin(unwindowedKey); + }, printResults); + pass &= verifySuppressed(resultStream, "sws-suppressed", events, printResults); + pass &= verify(resultStream, "min", inputs, events, SmokeTestDriver::getMin, printResults); + pass &= verify(resultStream, "max", inputs, events, SmokeTestDriver::getMax, printResults); + pass &= verify(resultStream, "dif", inputs, events, key -> getMax(key).intValue() - getMin(key).intValue(), printResults); + pass &= verify(resultStream, "sum", inputs, events, SmokeTestDriver::getSum, printResults); + pass &= verify(resultStream, "cnt", inputs, events, key1 -> getMax(key1).intValue() - getMin(key1).intValue() + 1L, printResults); + pass &= verify(resultStream, "avg", inputs, events, SmokeTestDriver::getAvg, printResults); + } + return new VerificationResult(pass, new String(byteArrayOutputStream.toByteArray(), StandardCharsets.UTF_8)); + } + + private static boolean verify(final PrintStream resultStream, + final String topic, + final Map> inputData, + final Map>>> events, + final Function keyToExpectation, + final boolean printResults) { + final Map>> observedInputEvents = events.get("data"); + final Map>> outputEvents = events.getOrDefault(topic, emptyMap()); + if (outputEvents.isEmpty()) { + resultStream.println(topic + " is empty"); + return false; + } else { + resultStream.printf("verifying %s with %d keys%n", topic, outputEvents.size()); + + if (outputEvents.size() != inputData.size()) { + resultStream.printf("fail: resultCount=%d expectedCount=%s%n\tresult=%s%n\texpected=%s%n", + outputEvents.size(), inputData.size(), outputEvents.keySet(), inputData.keySet()); + return false; + } + for (final Map.Entry>> entry : outputEvents.entrySet()) { + final String key = entry.getKey(); + final Number expected = keyToExpectation.apply(key); + final Number actual = entry.getValue().getLast().value(); + if (!expected.equals(actual)) { + resultStream.printf("%s fail: key=%s actual=%s expected=%s%n", topic, key, actual, expected); + + if (printResults) { + resultStream.printf("\t inputEvents=%n%s%n\t" + + "echoEvents=%n%s%n\tmaxEvents=%n%s%n\tminEvents=%n%s%n\tdifEvents=%n%s%n\tcntEvents=%n%s%n\ttaggEvents=%n%s%n", + indent("\t\t", observedInputEvents.get(key)), + indent("\t\t", events.getOrDefault("echo", emptyMap()).getOrDefault(key, new LinkedList<>())), + indent("\t\t", events.getOrDefault("max", emptyMap()).getOrDefault(key, new LinkedList<>())), + indent("\t\t", events.getOrDefault("min", emptyMap()).getOrDefault(key, new 
LinkedList<>())), + indent("\t\t", events.getOrDefault("dif", emptyMap()).getOrDefault(key, new LinkedList<>())), + indent("\t\t", events.getOrDefault("cnt", emptyMap()).getOrDefault(key, new LinkedList<>())), + indent("\t\t", events.getOrDefault("tagg", emptyMap()).getOrDefault(key, new LinkedList<>()))); + + if (!Utils.mkSet("echo", "max", "min", "dif", "cnt", "tagg").contains(topic)) + resultStream.printf("%sEvents=%n%s%n", topic, indent("\t\t", entry.getValue())); + } + + return false; + } + } + return true; + } + } + + + private static boolean verifySuppressed(final PrintStream resultStream, + @SuppressWarnings("SameParameterValue") final String topic, + final Map>>> events, + final boolean printResults) { + resultStream.println("verifying suppressed " + topic); + final Map>> topicEvents = events.getOrDefault(topic, emptyMap()); + for (final Map.Entry>> entry : topicEvents.entrySet()) { + if (entry.getValue().size() != 1) { + final String unsuppressedTopic = topic.replace("-suppressed", "-raw"); + final String key = entry.getKey(); + final String unwindowedKey = key.substring(1, key.length() - 1).replaceAll("@.*", ""); + resultStream.printf("fail: key=%s%n\tnon-unique result:%n%s%n", + key, + indent("\t\t", entry.getValue())); + + if (printResults) + resultStream.printf("\tresultEvents:%n%s%n\tinputEvents:%n%s%n", + indent("\t\t", events.get(unsuppressedTopic).get(key)), + indent("\t\t", events.get("data").get(unwindowedKey))); + + return false; + } + } + return true; + } + + private static String indent(@SuppressWarnings("SameParameterValue") final String prefix, + final Iterable> list) { + final StringBuilder stringBuilder = new StringBuilder(); + for (final ConsumerRecord record : list) { + stringBuilder.append(prefix).append(record).append('\n'); + } + return stringBuilder.toString(); + } + + private static Long getSum(final String key) { + final int min = getMin(key).intValue(); + final int max = getMax(key).intValue(); + return ((long) min + max) * (max - min + 1L) / 2L; + } + + private static Double getAvg(final String key) { + final int min = getMin(key).intValue(); + final int max = getMax(key).intValue(); + return ((long) min + max) / 2.0; + } + + + private static boolean verifyTAgg(final PrintStream resultStream, + final Map> allData, + final Map>> taggEvents, + final boolean printResults) { + if (taggEvents == null) { + resultStream.println("tagg is missing"); + return false; + } else if (taggEvents.isEmpty()) { + resultStream.println("tagg is empty"); + return false; + } else { + resultStream.println("verifying tagg"); + + // generate expected answer + final Map expected = new HashMap<>(); + for (final String key : allData.keySet()) { + final int min = getMin(key).intValue(); + final int max = getMax(key).intValue(); + final String cnt = Long.toString(max - min + 1L); + + expected.put(cnt, expected.getOrDefault(cnt, 0L) + 1); + } + + // check the result + for (final Map.Entry>> entry : taggEvents.entrySet()) { + final String key = entry.getKey(); + Long expectedCount = expected.remove(key); + if (expectedCount == null) { + expectedCount = 0L; + } + + if (entry.getValue().getLast().value().longValue() != expectedCount) { + resultStream.println("fail: key=" + key + " tagg=" + entry.getValue() + " expected=" + expectedCount); + + if (printResults) + resultStream.println("\t taggEvents: " + entry.getValue()); + return false; + } + } + + } + return true; + } + + private static Number getMin(final String key) { + return Integer.parseInt(key.split("-")[0]); + } + + private 
static Number getMax(final String key) { + return Integer.parseInt(key.split("-")[1]); + } + + private static List getAllPartitions(final KafkaConsumer consumer, final String... topics) { + final List partitions = new ArrayList<>(); + + for (final String topic : topics) { + for (final PartitionInfo info : consumer.partitionsFor(topic)) { + partitions.add(new TopicPartition(info.topic(), info.partition())); + } + } + return partitions; + } + +} diff --git a/streams/upgrade-system-tests-26/src/test/java/org/apache/kafka/streams/tests/SmokeTestUtil.java b/streams/upgrade-system-tests-26/src/test/java/org/apache/kafka/streams/tests/SmokeTestUtil.java new file mode 100644 index 0000000..e8ec04c --- /dev/null +++ b/streams/upgrade-system-tests-26/src/test/java/org/apache/kafka/streams/tests/SmokeTestUtil.java @@ -0,0 +1,134 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.tests; + +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.kstream.Aggregator; +import org.apache.kafka.streams.kstream.Initializer; +import org.apache.kafka.streams.kstream.KeyValueMapper; +import org.apache.kafka.streams.kstream.Windowed; +import org.apache.kafka.streams.processor.AbstractProcessor; +import org.apache.kafka.streams.processor.Processor; +import org.apache.kafka.streams.processor.ProcessorContext; +import org.apache.kafka.streams.processor.ProcessorSupplier; + +import java.time.Instant; + +public class SmokeTestUtil { + + final static int END = Integer.MAX_VALUE; + + static ProcessorSupplier printProcessorSupplier(final String topic) { + return printProcessorSupplier(topic, ""); + } + + static ProcessorSupplier printProcessorSupplier(final String topic, final String name) { + return new ProcessorSupplier() { + @Override + public Processor get() { + return new AbstractProcessor() { + private int numRecordsProcessed = 0; + private long smallestOffset = Long.MAX_VALUE; + private long largestOffset = Long.MIN_VALUE; + + @Override + public void init(final ProcessorContext context) { + super.init(context); + System.out.println("[DEV] initializing processor: topic=" + topic + " taskId=" + context.taskId()); + System.out.flush(); + numRecordsProcessed = 0; + smallestOffset = Long.MAX_VALUE; + largestOffset = Long.MIN_VALUE; + } + + @Override + public void process(final Object key, final Object value) { + numRecordsProcessed++; + if (numRecordsProcessed % 100 == 0) { + System.out.printf("%s: %s%n", name, Instant.now()); + System.out.println("processed " + numRecordsProcessed + " records from topic=" + topic); + } + + if (smallestOffset > context().offset()) { + smallestOffset = context().offset(); + } + 
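+                        // also track the largest offset seen; close() uses the
+                        // [smallestOffset, largestOffset] range to report how many offsets this task covered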
if (largestOffset < context().offset()) { + largestOffset = context().offset(); + } + } + + @Override + public void close() { + System.out.printf("Close processor for task %s%n", context().taskId()); + System.out.println("processed " + numRecordsProcessed + " records"); + final long processed; + if (largestOffset >= smallestOffset) { + processed = 1L + largestOffset - smallestOffset; + } else { + processed = 0L; + } + System.out.println("offset " + smallestOffset + " to " + largestOffset + " -> processed " + processed); + System.out.flush(); + } + }; + } + }; + } + + public static final class Unwindow implements KeyValueMapper, V, K> { + @Override + public K apply(final Windowed winKey, final V value) { + return winKey.key(); + } + } + + public static class Agg { + + KeyValueMapper> selector() { + return (key, value) -> new KeyValue<>(value == null ? null : Long.toString(value), 1L); + } + + public Initializer init() { + return () -> 0L; + } + + Aggregator adder() { + return (aggKey, value, aggregate) -> aggregate + value; + } + + Aggregator remover() { + return (aggKey, value, aggregate) -> aggregate - value; + } + } + + public static Serde stringSerde = Serdes.String(); + + public static Serde intSerde = Serdes.Integer(); + + static Serde longSerde = Serdes.Long(); + + static Serde doubleSerde = Serdes.Double(); + + public static void sleep(final long duration) { + try { + Thread.sleep(duration); + } catch (final Exception ignore) { } + } + +} diff --git a/streams/upgrade-system-tests-26/src/test/java/org/apache/kafka/streams/tests/StreamsSmokeTest.java b/streams/upgrade-system-tests-26/src/test/java/org/apache/kafka/streams/tests/StreamsSmokeTest.java new file mode 100644 index 0000000..f280eb0 --- /dev/null +++ b/streams/upgrade-system-tests-26/src/test/java/org/apache/kafka/streams/tests/StreamsSmokeTest.java @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.tests; + +import org.apache.kafka.common.utils.Exit; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.streams.StreamsConfig; + +import java.io.IOException; +import java.time.Duration; +import java.util.Map; +import java.util.Properties; +import java.util.Set; +import java.util.UUID; + +import static org.apache.kafka.streams.tests.SmokeTestDriver.generate; +import static org.apache.kafka.streams.tests.SmokeTestDriver.generatePerpetually; + +public class StreamsSmokeTest { + + /** + * args ::= kafka propFileName command disableAutoTerminate + * command := "run" | "process" + * + * @param args + */ + public static void main(final String[] args) throws IOException { + if (args.length < 2) { + System.err.println("StreamsSmokeTest are expecting two parameters: propFile, command; but only see " + args.length + " parameter"); + Exit.exit(1); + } + + final String propFileName = args[0]; + final String command = args[1]; + final boolean disableAutoTerminate = args.length > 2; + + final Properties streamsProperties = Utils.loadProps(propFileName); + final String kafka = streamsProperties.getProperty(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG); + final String processingGuarantee = streamsProperties.getProperty(StreamsConfig.PROCESSING_GUARANTEE_CONFIG); + + if (kafka == null) { + System.err.println("No bootstrap kafka servers specified in " + StreamsConfig.BOOTSTRAP_SERVERS_CONFIG); + Exit.exit(1); + } + + if ("process".equals(command)) { + if (!StreamsConfig.AT_LEAST_ONCE.equals(processingGuarantee) && + !StreamsConfig.EXACTLY_ONCE.equals(processingGuarantee)) { + + System.err.println("processingGuarantee must be either " + StreamsConfig.AT_LEAST_ONCE + " or " + + StreamsConfig.EXACTLY_ONCE); + + Exit.exit(1); + } + } + + System.out.println("StreamsTest instance started (StreamsSmokeTest)"); + System.out.println("command=" + command); + System.out.println("props=" + streamsProperties); + System.out.println("disableAutoTerminate=" + disableAutoTerminate); + + switch (command) { + case "run": + // this starts the driver (data generation and result verification) + final int numKeys = 10; + final int maxRecordsPerKey = 500; + if (disableAutoTerminate) { + generatePerpetually(kafka, numKeys, maxRecordsPerKey); + } else { + // slow down data production to span 30 seconds so that system tests have time to + // do their bounces, etc. + final Map> allData = + generate(kafka, numKeys, maxRecordsPerKey, Duration.ofSeconds(30)); + SmokeTestDriver.verify(kafka, allData, maxRecordsPerKey); + } + break; + case "process": + // this starts the stream processing app + new SmokeTestClient(UUID.randomUUID().toString()).start(streamsProperties); + break; + default: + System.out.println("unknown command: " + command); + } + } + +} diff --git a/streams/upgrade-system-tests-26/src/test/java/org/apache/kafka/streams/tests/StreamsUpgradeTest.java b/streams/upgrade-system-tests-26/src/test/java/org/apache/kafka/streams/tests/StreamsUpgradeTest.java new file mode 100644 index 0000000..e1b294f --- /dev/null +++ b/streams/upgrade-system-tests-26/src/test/java/org/apache/kafka/streams/tests/StreamsUpgradeTest.java @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.tests; + +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.streams.KafkaStreams; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.kstream.KStream; +import org.apache.kafka.streams.processor.AbstractProcessor; +import org.apache.kafka.streams.processor.ProcessorContext; +import org.apache.kafka.streams.processor.ProcessorSupplier; + +import java.util.Properties; + +public class StreamsUpgradeTest { + + @SuppressWarnings("unchecked") + public static void main(final String[] args) throws Exception { + if (args.length < 1) { + System.err.println("StreamsUpgradeTest requires one argument (properties-file) but provided none"); + } + final String propFileName = args[0]; + + final Properties streamsProperties = Utils.loadProps(propFileName); + + System.out.println("StreamsTest instance started (StreamsUpgradeTest v2.6)"); + System.out.println("props=" + streamsProperties); + + final StreamsBuilder builder = new StreamsBuilder(); + final KStream dataStream = builder.stream("data"); + dataStream.process(printProcessorSupplier()); + dataStream.to("echo"); + + final Properties config = new Properties(); + config.setProperty(StreamsConfig.APPLICATION_ID_CONFIG, "StreamsUpgradeTest"); + config.put(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG, 1000L); + config.putAll(streamsProperties); + + final KafkaStreams streams = new KafkaStreams(builder.build(), config); + streams.start(); + + Runtime.getRuntime().addShutdownHook(new Thread(() -> { + streams.close(); + System.out.println("UPGRADE-TEST-CLIENT-CLOSED"); + System.out.flush(); + })); + } + + private static ProcessorSupplier printProcessorSupplier() { + return () -> new AbstractProcessor() { + private int numRecordsProcessed = 0; + + @Override + public void init(final ProcessorContext context) { + System.out.println("[2.6] initializing processor: topic=data taskId=" + context.taskId()); + numRecordsProcessed = 0; + } + + @Override + public void process(final K key, final V value) { + numRecordsProcessed++; + if (numRecordsProcessed % 100 == 0) { + System.out.println("processed " + numRecordsProcessed + " records from topic=data"); + } + } + + @Override + public void close() {} + }; + } +} diff --git a/streams/upgrade-system-tests-27/src/test/java/org/apache/kafka/streams/tests/SmokeTestClient.java b/streams/upgrade-system-tests-27/src/test/java/org/apache/kafka/streams/tests/SmokeTestClient.java new file mode 100644 index 0000000..ced1369 --- /dev/null +++ b/streams/upgrade-system-tests-27/src/test/java/org/apache/kafka/streams/tests/SmokeTestClient.java @@ -0,0 +1,298 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.tests; + +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.common.utils.KafkaThread; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.streams.KafkaStreams; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.Topology; +import org.apache.kafka.streams.kstream.Consumed; +import org.apache.kafka.streams.kstream.Grouped; +import org.apache.kafka.streams.kstream.KGroupedStream; +import org.apache.kafka.streams.kstream.KStream; +import org.apache.kafka.streams.kstream.KTable; +import org.apache.kafka.streams.kstream.Materialized; +import org.apache.kafka.streams.kstream.Produced; +import org.apache.kafka.streams.kstream.Suppressed.BufferConfig; +import org.apache.kafka.streams.kstream.TimeWindows; +import org.apache.kafka.streams.kstream.Windowed; +import org.apache.kafka.streams.state.Stores; +import org.apache.kafka.streams.state.WindowStore; + +import java.io.File; +import java.io.IOException; +import java.nio.file.Files; +import java.time.Duration; +import java.time.Instant; +import java.util.Properties; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; + +import static org.apache.kafka.streams.kstream.Suppressed.untilWindowCloses; + +public class SmokeTestClient extends SmokeTestUtil { + + private final String name; + + private KafkaStreams streams; + private boolean uncaughtException = false; + private boolean started; + private volatile boolean closed; + + private static void addShutdownHook(final String name, final Runnable runnable) { + if (name != null) { + Runtime.getRuntime().addShutdownHook(KafkaThread.nonDaemon(name, runnable)); + } else { + Runtime.getRuntime().addShutdownHook(new Thread(runnable)); + } + } + + private static File tempDirectory() { + final String prefix = "kafka-"; + final File file; + try { + file = Files.createTempDirectory(prefix).toFile(); + } catch (final IOException ex) { + throw new RuntimeException("Failed to create a temp dir", ex); + } + file.deleteOnExit(); + + addShutdownHook("delete-temp-file-shutdown-hook", () -> { + try { + Utils.delete(file); + } catch (final IOException e) { + System.out.println("Error deleting " + file.getAbsolutePath()); + e.printStackTrace(System.out); + } + }); + + return file; + } + + public SmokeTestClient(final String name) { + this.name = name; + } + + public boolean started() { + return started; + } + + public boolean closed() { + return closed; + } + + public void start(final Properties streamsProperties) { + final Topology build = getTopology(); + streams = new KafkaStreams(build, getStreamsConfig(streamsProperties)); + + final CountDownLatch countDownLatch = new CountDownLatch(1); + streams.setStateListener((newState, oldState) -> { + System.out.printf("%s %s: 
%s -> %s%n", name, Instant.now(), oldState, newState); + if (oldState == KafkaStreams.State.REBALANCING && newState == KafkaStreams.State.RUNNING) { + started = true; + countDownLatch.countDown(); + } + + if (newState == KafkaStreams.State.NOT_RUNNING) { + closed = true; + } + }); + + streams.setUncaughtExceptionHandler((t, e) -> { + System.out.println(name + ": SMOKE-TEST-CLIENT-EXCEPTION"); + System.out.println(name + ": FATAL: An unexpected exception is encountered on thread " + t + ": " + e); + e.printStackTrace(System.out); + uncaughtException = true; + streams.close(Duration.ofSeconds(30)); + }); + + addShutdownHook("streams-shutdown-hook", this::close); + + streams.start(); + try { + if (!countDownLatch.await(1, TimeUnit.MINUTES)) { + System.out.println(name + ": SMOKE-TEST-CLIENT-EXCEPTION: Didn't start in one minute"); + } + } catch (final InterruptedException e) { + System.out.println(name + ": SMOKE-TEST-CLIENT-EXCEPTION: " + e); + e.printStackTrace(System.out); + } + System.out.println(name + ": SMOKE-TEST-CLIENT-STARTED"); + System.out.println(name + " started at " + Instant.now()); + } + + public void closeAsync() { + streams.close(Duration.ZERO); + } + + public void close() { + final boolean closed = streams.close(Duration.ofMinutes(1)); + + if (closed && !uncaughtException) { + System.out.println(name + ": SMOKE-TEST-CLIENT-CLOSED"); + } else if (closed) { + System.out.println(name + ": SMOKE-TEST-CLIENT-EXCEPTION"); + } else { + System.out.println(name + ": SMOKE-TEST-CLIENT-EXCEPTION: Didn't close"); + } + } + + private Properties getStreamsConfig(final Properties props) { + final Properties fullProps = new Properties(props); + fullProps.put(StreamsConfig.APPLICATION_ID_CONFIG, "SmokeTest"); + fullProps.put(StreamsConfig.CLIENT_ID_CONFIG, "SmokeTest-" + name); + fullProps.put(StreamsConfig.STATE_DIR_CONFIG, tempDirectory().getAbsolutePath()); + fullProps.putAll(props); + return fullProps; + } + + public Topology getTopology() { + final StreamsBuilder builder = new StreamsBuilder(); + final Consumed stringIntConsumed = Consumed.with(stringSerde, intSerde); + final KStream source = builder.stream("data", stringIntConsumed); + source.filterNot((k, v) -> k.equals("flush")) + .to("echo", Produced.with(stringSerde, intSerde)); + final KStream data = source.filter((key, value) -> value == null || value != END); + data.process(SmokeTestUtil.printProcessorSupplier("data", name)); + + // min + final KGroupedStream groupedData = data.groupByKey(Grouped.with(stringSerde, intSerde)); + + final KTable, Integer> minAggregation = groupedData + .windowedBy(TimeWindows.of(Duration.ofDays(1)).grace(Duration.ofMinutes(1))) + .aggregate( + () -> Integer.MAX_VALUE, + (aggKey, value, aggregate) -> (value < aggregate) ? 
value : aggregate, + Materialized + .>as("uwin-min") + .withValueSerde(intSerde) + .withRetention(Duration.ofHours(25)) + ); + + streamify(minAggregation, "min-raw"); + + streamify(minAggregation.suppress(untilWindowCloses(BufferConfig.unbounded())), "min-suppressed"); + + minAggregation + .toStream(new Unwindow<>()) + .filterNot((k, v) -> k.equals("flush")) + .to("min", Produced.with(stringSerde, intSerde)); + + final KTable, Integer> smallWindowSum = groupedData + .windowedBy(TimeWindows.of(Duration.ofSeconds(2)).advanceBy(Duration.ofSeconds(1)).grace(Duration.ofSeconds(30))) + .reduce((l, r) -> l + r); + + streamify(smallWindowSum, "sws-raw"); + streamify(smallWindowSum.suppress(untilWindowCloses(BufferConfig.unbounded())), "sws-suppressed"); + + final KTable minTable = builder.table( + "min", + Consumed.with(stringSerde, intSerde), + Materialized.as("minStoreName")); + + minTable.toStream().process(SmokeTestUtil.printProcessorSupplier("min", name)); + + // max + groupedData + .windowedBy(TimeWindows.of(Duration.ofDays(2))) + .aggregate( + () -> Integer.MIN_VALUE, + (aggKey, value, aggregate) -> (value > aggregate) ? value : aggregate, + Materialized.>as("uwin-max").withValueSerde(intSerde)) + .toStream(new Unwindow<>()) + .filterNot((k, v) -> k.equals("flush")) + .to("max", Produced.with(stringSerde, intSerde)); + + final KTable maxTable = builder.table( + "max", + Consumed.with(stringSerde, intSerde), + Materialized.as("maxStoreName")); + maxTable.toStream().process(SmokeTestUtil.printProcessorSupplier("max", name)); + + // sum + groupedData + .windowedBy(TimeWindows.of(Duration.ofDays(2))) + .aggregate( + () -> 0L, + (aggKey, value, aggregate) -> (long) value + aggregate, + Materialized.>as("win-sum").withValueSerde(longSerde)) + .toStream(new Unwindow<>()) + .filterNot((k, v) -> k.equals("flush")) + .to("sum", Produced.with(stringSerde, longSerde)); + + final Consumed stringLongConsumed = Consumed.with(stringSerde, longSerde); + final KTable sumTable = builder.table("sum", stringLongConsumed); + sumTable.toStream().process(SmokeTestUtil.printProcessorSupplier("sum", name)); + + // cnt + groupedData + .windowedBy(TimeWindows.of(Duration.ofDays(2))) + .count(Materialized.as("uwin-cnt")) + .toStream(new Unwindow<>()) + .filterNot((k, v) -> k.equals("flush")) + .to("cnt", Produced.with(stringSerde, longSerde)); + + final KTable cntTable = builder.table( + "cnt", + Consumed.with(stringSerde, longSerde), + Materialized.as("cntStoreName")); + cntTable.toStream().process(SmokeTestUtil.printProcessorSupplier("cnt", name)); + + // dif + maxTable + .join( + minTable, + (value1, value2) -> value1 - value2) + .toStream() + .filterNot((k, v) -> k.equals("flush")) + .to("dif", Produced.with(stringSerde, intSerde)); + + // avg + sumTable + .join( + cntTable, + (value1, value2) -> (double) value1 / (double) value2) + .toStream() + .filterNot((k, v) -> k.equals("flush")) + .to("avg", Produced.with(stringSerde, doubleSerde)); + + // test repartition + final Agg agg = new Agg(); + cntTable.groupBy(agg.selector(), Grouped.with(stringSerde, longSerde)) + .aggregate(agg.init(), agg.adder(), agg.remover(), + Materialized.as(Stores.inMemoryKeyValueStore("cntByCnt")) + .withKeySerde(Serdes.String()) + .withValueSerde(Serdes.Long())) + .toStream() + .to("tagg", Produced.with(stringSerde, longSerde)); + + return builder.build(); + } + + private static void streamify(final KTable, Integer> windowedTable, final String topic) { + windowedTable + .toStream() + .filterNot((k, v) -> k.key().equals("flush")) + 
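// drop the "flush" sentinel records, then flatten the windowed key to its string form before writing to the output topic +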
.map((key, value) -> new KeyValue<>(key.toString(), value)) + .to(topic, Produced.with(stringSerde, intSerde)); + } +} diff --git a/streams/upgrade-system-tests-27/src/test/java/org/apache/kafka/streams/tests/SmokeTestDriver.java b/streams/upgrade-system-tests-27/src/test/java/org/apache/kafka/streams/tests/SmokeTestDriver.java new file mode 100644 index 0000000..ac83cd9 --- /dev/null +++ b/streams/upgrade-system-tests-27/src/test/java/org/apache/kafka/streams/tests/SmokeTestDriver.java @@ -0,0 +1,622 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.tests; + +import org.apache.kafka.clients.consumer.ConsumerConfig; +import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.clients.consumer.ConsumerRecords; +import org.apache.kafka.clients.consumer.KafkaConsumer; +import org.apache.kafka.clients.producer.Callback; +import org.apache.kafka.clients.producer.KafkaProducer; +import org.apache.kafka.clients.producer.ProducerConfig; +import org.apache.kafka.clients.producer.ProducerRecord; +import org.apache.kafka.clients.producer.RecordMetadata; +import org.apache.kafka.common.PartitionInfo; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.errors.TimeoutException; +import org.apache.kafka.common.serialization.ByteArraySerializer; +import org.apache.kafka.common.serialization.Deserializer; +import org.apache.kafka.common.serialization.StringDeserializer; +import org.apache.kafka.common.utils.Exit; +import org.apache.kafka.common.utils.Utils; + +import java.io.ByteArrayOutputStream; +import java.io.PrintStream; +import java.nio.charset.StandardCharsets; +import java.time.Duration; +import java.time.Instant; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.Properties; +import java.util.Random; +import java.util.Set; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Function; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import static java.util.Collections.emptyMap; +import static org.apache.kafka.common.utils.Utils.mkEntry; + +public class SmokeTestDriver extends SmokeTestUtil { + private static final String[] TOPICS = { + "data", + "echo", + "max", + "min", "min-suppressed", "min-raw", + "dif", + "sum", + "sws-raw", "sws-suppressed", + "cnt", + "avg", + "tagg" + }; + + private static final int MAX_RECORD_EMPTY_RETRIES = 30; + + private static class ValueList { + public final String key; + private final int[] values; + private int index; + + ValueList(final int min, final int max) { + 
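// the key records the inclusive value range this list produces, e.g. "0-499" +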
key = min + "-" + max; + + values = new int[max - min + 1]; + for (int i = 0; i < values.length; i++) { + values[i] = min + i; + } + // We want to randomize the order of data to test not completely predictable processing order + // However, values are also use as a timestamp of the record. (TODO: separate data and timestamp) + // We keep some correlation of time and order. Thus, the shuffling is done with a sliding window + shuffle(values, 10); + + index = 0; + } + + int next() { + return (index < values.length) ? values[index++] : -1; + } + } + + public static String[] topics() { + return Arrays.copyOf(TOPICS, TOPICS.length); + } + + static void generatePerpetually(final String kafka, + final int numKeys, + final int maxRecordsPerKey) { + final Properties producerProps = generatorProperties(kafka); + + int numRecordsProduced = 0; + + final ValueList[] data = new ValueList[numKeys]; + for (int i = 0; i < numKeys; i++) { + data[i] = new ValueList(i, i + maxRecordsPerKey - 1); + } + + final Random rand = new Random(); + + try (final KafkaProducer producer = new KafkaProducer<>(producerProps)) { + while (true) { + final int index = rand.nextInt(numKeys); + final String key = data[index].key; + final int value = data[index].next(); + + final ProducerRecord record = + new ProducerRecord<>( + "data", + stringSerde.serializer().serialize("", key), + intSerde.serializer().serialize("", value) + ); + + producer.send(record); + + numRecordsProduced++; + if (numRecordsProduced % 100 == 0) { + System.out.println(Instant.now() + " " + numRecordsProduced + " records produced"); + } + Utils.sleep(2); + } + } + } + + public static Map> generate(final String kafka, + final int numKeys, + final int maxRecordsPerKey, + final Duration timeToSpend) { + final Properties producerProps = generatorProperties(kafka); + + + int numRecordsProduced = 0; + + final Map> allData = new HashMap<>(); + final ValueList[] data = new ValueList[numKeys]; + for (int i = 0; i < numKeys; i++) { + data[i] = new ValueList(i, i + maxRecordsPerKey - 1); + allData.put(data[i].key, new HashSet<>()); + } + final Random rand = new Random(); + + int remaining = data.length; + + final long recordPauseTime = timeToSpend.toMillis() / numKeys / maxRecordsPerKey; + + List> needRetry = new ArrayList<>(); + + try (final KafkaProducer producer = new KafkaProducer<>(producerProps)) { + while (remaining > 0) { + final int index = rand.nextInt(remaining); + final String key = data[index].key; + final int value = data[index].next(); + + if (value < 0) { + remaining--; + data[index] = data[remaining]; + } else { + + final ProducerRecord record = + new ProducerRecord<>( + "data", + stringSerde.serializer().serialize("", key), + intSerde.serializer().serialize("", value) + ); + + producer.send(record, new TestCallback(record, needRetry)); + + numRecordsProduced++; + allData.get(key).add(value); + if (numRecordsProduced % 100 == 0) { + System.out.println(Instant.now() + " " + numRecordsProduced + " records produced"); + } + Utils.sleep(Math.max(recordPauseTime, 2)); + } + } + producer.flush(); + + int remainingRetries = 5; + while (!needRetry.isEmpty()) { + final List> needRetry2 = new ArrayList<>(); + for (final ProducerRecord record : needRetry) { + System.out.println("retry producing " + stringSerde.deserializer().deserialize("", record.key())); + producer.send(record, new TestCallback(record, needRetry2)); + } + producer.flush(); + needRetry = needRetry2; + + if (--remainingRetries == 0 && !needRetry.isEmpty()) { + System.err.println("Failed to 
produce all records after multiple retries"); + Exit.exit(1); + } + } + + // now that we've sent everything, we'll send some final records with a timestamp high enough to flush out + // all suppressed records. + final List partitions = producer.partitionsFor("data"); + for (final PartitionInfo partition : partitions) { + producer.send(new ProducerRecord<>( + partition.topic(), + partition.partition(), + System.currentTimeMillis() + Duration.ofDays(2).toMillis(), + stringSerde.serializer().serialize("", "flush"), + intSerde.serializer().serialize("", 0) + )); + } + } + return Collections.unmodifiableMap(allData); + } + + private static Properties generatorProperties(final String kafka) { + final Properties producerProps = new Properties(); + producerProps.put(ProducerConfig.CLIENT_ID_CONFIG, "SmokeTest"); + producerProps.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, kafka); + producerProps.put(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG, ByteArraySerializer.class); + producerProps.put(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG, ByteArraySerializer.class); + producerProps.put(ProducerConfig.ACKS_CONFIG, "all"); + return producerProps; + } + + private static class TestCallback implements Callback { + private final ProducerRecord originalRecord; + private final List> needRetry; + + TestCallback(final ProducerRecord originalRecord, + final List> needRetry) { + this.originalRecord = originalRecord; + this.needRetry = needRetry; + } + + @Override + public void onCompletion(final RecordMetadata metadata, final Exception exception) { + if (exception != null) { + if (exception instanceof TimeoutException) { + needRetry.add(originalRecord); + } else { + exception.printStackTrace(); + Exit.exit(1); + } + } + } + } + + private static void shuffle(final int[] data, @SuppressWarnings("SameParameterValue") final int windowSize) { + final Random rand = new Random(); + for (int i = 0; i < data.length; i++) { + // we shuffle data within windowSize + final int j = rand.nextInt(Math.min(data.length - i, windowSize)) + i; + + // swap + final int tmp = data[i]; + data[i] = data[j]; + data[j] = tmp; + } + } + + public static class NumberDeserializer implements Deserializer { + @Override + public Number deserialize(final String topic, final byte[] data) { + final Number value; + switch (topic) { + case "data": + case "echo": + case "min": + case "min-raw": + case "min-suppressed": + case "sws-raw": + case "sws-suppressed": + case "max": + case "dif": + value = intSerde.deserializer().deserialize(topic, data); + break; + case "sum": + case "cnt": + case "tagg": + value = longSerde.deserializer().deserialize(topic, data); + break; + case "avg": + value = doubleSerde.deserializer().deserialize(topic, data); + break; + default: + throw new RuntimeException("unknown topic: " + topic); + } + return value; + } + } + + public static VerificationResult verify(final String kafka, + final Map> inputs, + final int maxRecordsPerKey) { + final Properties props = new Properties(); + props.put(ConsumerConfig.CLIENT_ID_CONFIG, "verifier"); + props.put(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, kafka); + props.put(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, StringDeserializer.class); + props.put(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, NumberDeserializer.class); + props.put(ConsumerConfig.ISOLATION_LEVEL_CONFIG, "read_committed"); + + final KafkaConsumer consumer = new KafkaConsumer<>(props); + final List partitions = getAllPartitions(consumer, TOPICS); + consumer.assign(partitions); + consumer.seekToBeginning(partitions); + 
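+ // the driver expects one "echo" record per generated input record before final verification can pass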
+ final int recordsGenerated = inputs.size() * maxRecordsPerKey; + int recordsProcessed = 0; + final Map processed = + Stream.of(TOPICS) + .collect(Collectors.toMap(t -> t, t -> new AtomicInteger(0))); + + final Map>>> events = new HashMap<>(); + + VerificationResult verificationResult = new VerificationResult(false, "no results yet"); + int retry = 0; + final long start = System.currentTimeMillis(); + while (System.currentTimeMillis() - start < TimeUnit.MINUTES.toMillis(6)) { + final ConsumerRecords records = consumer.poll(Duration.ofSeconds(5)); + if (records.isEmpty() && recordsProcessed >= recordsGenerated) { + verificationResult = verifyAll(inputs, events, false); + if (verificationResult.passed()) { + break; + } else if (retry++ > MAX_RECORD_EMPTY_RETRIES) { + System.out.println(Instant.now() + " Didn't get any more results, verification hasn't passed, and out of retries."); + break; + } else { + System.out.println(Instant.now() + " Didn't get any more results, but verification hasn't passed (yet). Retrying..." + retry); + } + } else { + System.out.println(Instant.now() + " Get some more results from " + records.partitions() + ", resetting retry."); + + retry = 0; + for (final ConsumerRecord record : records) { + final String key = record.key(); + + final String topic = record.topic(); + processed.get(topic).incrementAndGet(); + + if (topic.equals("echo")) { + recordsProcessed++; + if (recordsProcessed % 100 == 0) { + System.out.println("Echo records processed = " + recordsProcessed); + } + } + + events.computeIfAbsent(topic, t -> new HashMap<>()) + .computeIfAbsent(key, k -> new LinkedList<>()) + .add(record); + } + + System.out.println(processed); + } + } + consumer.close(); + final long finished = System.currentTimeMillis() - start; + System.out.println("Verification time=" + finished); + System.out.println("-------------------"); + System.out.println("Result Verification"); + System.out.println("-------------------"); + System.out.println("recordGenerated=" + recordsGenerated); + System.out.println("recordProcessed=" + recordsProcessed); + + if (recordsProcessed > recordsGenerated) { + System.out.println("PROCESSED-MORE-THAN-GENERATED"); + } else if (recordsProcessed < recordsGenerated) { + System.out.println("PROCESSED-LESS-THAN-GENERATED"); + } + + boolean success; + + final Map> received = + events.get("echo") + .entrySet() + .stream() + .map(entry -> mkEntry( + entry.getKey(), + entry.getValue().stream().map(ConsumerRecord::value).collect(Collectors.toSet())) + ) + .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); + + success = inputs.equals(received); + + if (success) { + System.out.println("ALL-RECORDS-DELIVERED"); + } else { + int missedCount = 0; + for (final Map.Entry> entry : inputs.entrySet()) { + missedCount += received.get(entry.getKey()).size(); + } + System.out.println("missedRecords=" + missedCount); + } + + // give it one more try if it's not already passing. + if (!verificationResult.passed()) { + verificationResult = verifyAll(inputs, events, true); + } + success &= verificationResult.passed(); + + System.out.println(verificationResult.result()); + + System.out.println(success ? 
"SUCCESS" : "FAILURE"); + return verificationResult; + } + + public static class VerificationResult { + private final boolean passed; + private final String result; + + VerificationResult(final boolean passed, final String result) { + this.passed = passed; + this.result = result; + } + + public boolean passed() { + return passed; + } + + public String result() { + return result; + } + } + + private static VerificationResult verifyAll(final Map> inputs, + final Map>>> events, + final boolean printResults) { + final ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream(); + boolean pass; + try (final PrintStream resultStream = new PrintStream(byteArrayOutputStream)) { + pass = verifyTAgg(resultStream, inputs, events.get("tagg"), printResults); + pass &= verifySuppressed(resultStream, "min-suppressed", events, printResults); + pass &= verify(resultStream, "min-suppressed", inputs, events, windowedKey -> { + final String unwindowedKey = windowedKey.substring(1, windowedKey.length() - 1).replaceAll("@.*", ""); + return getMin(unwindowedKey); + }, printResults); + pass &= verifySuppressed(resultStream, "sws-suppressed", events, printResults); + pass &= verify(resultStream, "min", inputs, events, SmokeTestDriver::getMin, printResults); + pass &= verify(resultStream, "max", inputs, events, SmokeTestDriver::getMax, printResults); + pass &= verify(resultStream, "dif", inputs, events, key -> getMax(key).intValue() - getMin(key).intValue(), printResults); + pass &= verify(resultStream, "sum", inputs, events, SmokeTestDriver::getSum, printResults); + pass &= verify(resultStream, "cnt", inputs, events, key1 -> getMax(key1).intValue() - getMin(key1).intValue() + 1L, printResults); + pass &= verify(resultStream, "avg", inputs, events, SmokeTestDriver::getAvg, printResults); + } + return new VerificationResult(pass, new String(byteArrayOutputStream.toByteArray(), StandardCharsets.UTF_8)); + } + + private static boolean verify(final PrintStream resultStream, + final String topic, + final Map> inputData, + final Map>>> events, + final Function keyToExpectation, + final boolean printResults) { + final Map>> observedInputEvents = events.get("data"); + final Map>> outputEvents = events.getOrDefault(topic, emptyMap()); + if (outputEvents.isEmpty()) { + resultStream.println(topic + " is empty"); + return false; + } else { + resultStream.printf("verifying %s with %d keys%n", topic, outputEvents.size()); + + if (outputEvents.size() != inputData.size()) { + resultStream.printf("fail: resultCount=%d expectedCount=%s%n\tresult=%s%n\texpected=%s%n", + outputEvents.size(), inputData.size(), outputEvents.keySet(), inputData.keySet()); + return false; + } + for (final Map.Entry>> entry : outputEvents.entrySet()) { + final String key = entry.getKey(); + final Number expected = keyToExpectation.apply(key); + final Number actual = entry.getValue().getLast().value(); + if (!expected.equals(actual)) { + resultStream.printf("%s fail: key=%s actual=%s expected=%s%n", topic, key, actual, expected); + + if (printResults) { + resultStream.printf("\t inputEvents=%n%s%n\t" + + "echoEvents=%n%s%n\tmaxEvents=%n%s%n\tminEvents=%n%s%n\tdifEvents=%n%s%n\tcntEvents=%n%s%n\ttaggEvents=%n%s%n", + indent("\t\t", observedInputEvents.get(key)), + indent("\t\t", events.getOrDefault("echo", emptyMap()).getOrDefault(key, new LinkedList<>())), + indent("\t\t", events.getOrDefault("max", emptyMap()).getOrDefault(key, new LinkedList<>())), + indent("\t\t", events.getOrDefault("min", emptyMap()).getOrDefault(key, new 
LinkedList<>())), + indent("\t\t", events.getOrDefault("dif", emptyMap()).getOrDefault(key, new LinkedList<>())), + indent("\t\t", events.getOrDefault("cnt", emptyMap()).getOrDefault(key, new LinkedList<>())), + indent("\t\t", events.getOrDefault("tagg", emptyMap()).getOrDefault(key, new LinkedList<>()))); + + if (!Utils.mkSet("echo", "max", "min", "dif", "cnt", "tagg").contains(topic)) + resultStream.printf("%sEvents=%n%s%n", topic, indent("\t\t", entry.getValue())); + } + + return false; + } + } + return true; + } + } + + + private static boolean verifySuppressed(final PrintStream resultStream, + @SuppressWarnings("SameParameterValue") final String topic, + final Map>>> events, + final boolean printResults) { + resultStream.println("verifying suppressed " + topic); + final Map>> topicEvents = events.getOrDefault(topic, emptyMap()); + for (final Map.Entry>> entry : topicEvents.entrySet()) { + if (entry.getValue().size() != 1) { + final String unsuppressedTopic = topic.replace("-suppressed", "-raw"); + final String key = entry.getKey(); + final String unwindowedKey = key.substring(1, key.length() - 1).replaceAll("@.*", ""); + resultStream.printf("fail: key=%s%n\tnon-unique result:%n%s%n", + key, + indent("\t\t", entry.getValue())); + + if (printResults) + resultStream.printf("\tresultEvents:%n%s%n\tinputEvents:%n%s%n", + indent("\t\t", events.get(unsuppressedTopic).get(key)), + indent("\t\t", events.get("data").get(unwindowedKey))); + + return false; + } + } + return true; + } + + private static String indent(@SuppressWarnings("SameParameterValue") final String prefix, + final Iterable> list) { + final StringBuilder stringBuilder = new StringBuilder(); + for (final ConsumerRecord record : list) { + stringBuilder.append(prefix).append(record).append('\n'); + } + return stringBuilder.toString(); + } + + private static Long getSum(final String key) { + final int min = getMin(key).intValue(); + final int max = getMax(key).intValue(); + return ((long) min + max) * (max - min + 1L) / 2L; + } + + private static Double getAvg(final String key) { + final int min = getMin(key).intValue(); + final int max = getMax(key).intValue(); + return ((long) min + max) / 2.0; + } + + + private static boolean verifyTAgg(final PrintStream resultStream, + final Map> allData, + final Map>> taggEvents, + final boolean printResults) { + if (taggEvents == null) { + resultStream.println("tagg is missing"); + return false; + } else if (taggEvents.isEmpty()) { + resultStream.println("tagg is empty"); + return false; + } else { + resultStream.println("verifying tagg"); + + // generate expected answer + final Map expected = new HashMap<>(); + for (final String key : allData.keySet()) { + final int min = getMin(key).intValue(); + final int max = getMax(key).intValue(); + final String cnt = Long.toString(max - min + 1L); + + expected.put(cnt, expected.getOrDefault(cnt, 0L) + 1); + } + + // check the result + for (final Map.Entry>> entry : taggEvents.entrySet()) { + final String key = entry.getKey(); + Long expectedCount = expected.remove(key); + if (expectedCount == null) { + expectedCount = 0L; + } + + if (entry.getValue().getLast().value().longValue() != expectedCount) { + resultStream.println("fail: key=" + key + " tagg=" + entry.getValue() + " expected=" + expectedCount); + + if (printResults) + resultStream.println("\t taggEvents: " + entry.getValue()); + return false; + } + } + + } + return true; + } + + private static Number getMin(final String key) { + return Integer.parseInt(key.split("-")[0]); + } + + private 
static Number getMax(final String key) { + return Integer.parseInt(key.split("-")[1]); + } + + private static List getAllPartitions(final KafkaConsumer consumer, final String... topics) { + final List partitions = new ArrayList<>(); + + for (final String topic : topics) { + for (final PartitionInfo info : consumer.partitionsFor(topic)) { + partitions.add(new TopicPartition(info.topic(), info.partition())); + } + } + return partitions; + } + +} diff --git a/streams/upgrade-system-tests-27/src/test/java/org/apache/kafka/streams/tests/SmokeTestUtil.java b/streams/upgrade-system-tests-27/src/test/java/org/apache/kafka/streams/tests/SmokeTestUtil.java new file mode 100644 index 0000000..e8ec04c --- /dev/null +++ b/streams/upgrade-system-tests-27/src/test/java/org/apache/kafka/streams/tests/SmokeTestUtil.java @@ -0,0 +1,134 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.tests; + +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.kstream.Aggregator; +import org.apache.kafka.streams.kstream.Initializer; +import org.apache.kafka.streams.kstream.KeyValueMapper; +import org.apache.kafka.streams.kstream.Windowed; +import org.apache.kafka.streams.processor.AbstractProcessor; +import org.apache.kafka.streams.processor.Processor; +import org.apache.kafka.streams.processor.ProcessorContext; +import org.apache.kafka.streams.processor.ProcessorSupplier; + +import java.time.Instant; + +public class SmokeTestUtil { + + final static int END = Integer.MAX_VALUE; + + static ProcessorSupplier printProcessorSupplier(final String topic) { + return printProcessorSupplier(topic, ""); + } + + static ProcessorSupplier printProcessorSupplier(final String topic, final String name) { + return new ProcessorSupplier() { + @Override + public Processor get() { + return new AbstractProcessor() { + private int numRecordsProcessed = 0; + private long smallestOffset = Long.MAX_VALUE; + private long largestOffset = Long.MIN_VALUE; + + @Override + public void init(final ProcessorContext context) { + super.init(context); + System.out.println("[DEV] initializing processor: topic=" + topic + " taskId=" + context.taskId()); + System.out.flush(); + numRecordsProcessed = 0; + smallestOffset = Long.MAX_VALUE; + largestOffset = Long.MIN_VALUE; + } + + @Override + public void process(final Object key, final Object value) { + numRecordsProcessed++; + if (numRecordsProcessed % 100 == 0) { + System.out.printf("%s: %s%n", name, Instant.now()); + System.out.println("processed " + numRecordsProcessed + " records from topic=" + topic); + } + + if (smallestOffset > context().offset()) { + smallestOffset = context().offset(); + } + 
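// smallest/largest offsets are tracked so close() can report how many offsets this task covered +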
if (largestOffset < context().offset()) { + largestOffset = context().offset(); + } + } + + @Override + public void close() { + System.out.printf("Close processor for task %s%n", context().taskId()); + System.out.println("processed " + numRecordsProcessed + " records"); + final long processed; + if (largestOffset >= smallestOffset) { + processed = 1L + largestOffset - smallestOffset; + } else { + processed = 0L; + } + System.out.println("offset " + smallestOffset + " to " + largestOffset + " -> processed " + processed); + System.out.flush(); + } + }; + } + }; + } + + public static final class Unwindow implements KeyValueMapper, V, K> { + @Override + public K apply(final Windowed winKey, final V value) { + return winKey.key(); + } + } + + public static class Agg { + + KeyValueMapper> selector() { + return (key, value) -> new KeyValue<>(value == null ? null : Long.toString(value), 1L); + } + + public Initializer init() { + return () -> 0L; + } + + Aggregator adder() { + return (aggKey, value, aggregate) -> aggregate + value; + } + + Aggregator remover() { + return (aggKey, value, aggregate) -> aggregate - value; + } + } + + public static Serde stringSerde = Serdes.String(); + + public static Serde intSerde = Serdes.Integer(); + + static Serde longSerde = Serdes.Long(); + + static Serde doubleSerde = Serdes.Double(); + + public static void sleep(final long duration) { + try { + Thread.sleep(duration); + } catch (final Exception ignore) { } + } + +} diff --git a/streams/upgrade-system-tests-27/src/test/java/org/apache/kafka/streams/tests/StreamsSmokeTest.java b/streams/upgrade-system-tests-27/src/test/java/org/apache/kafka/streams/tests/StreamsSmokeTest.java new file mode 100644 index 0000000..f280eb0 --- /dev/null +++ b/streams/upgrade-system-tests-27/src/test/java/org/apache/kafka/streams/tests/StreamsSmokeTest.java @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.tests; + +import org.apache.kafka.common.utils.Exit; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.streams.StreamsConfig; + +import java.io.IOException; +import java.time.Duration; +import java.util.Map; +import java.util.Properties; +import java.util.Set; +import java.util.UUID; + +import static org.apache.kafka.streams.tests.SmokeTestDriver.generate; +import static org.apache.kafka.streams.tests.SmokeTestDriver.generatePerpetually; + +public class StreamsSmokeTest { + + /** + * args ::= kafka propFileName command disableAutoTerminate + * command := "run" | "process" + * + * @param args + */ + public static void main(final String[] args) throws IOException { + if (args.length < 2) { + System.err.println("StreamsSmokeTest are expecting two parameters: propFile, command; but only see " + args.length + " parameter"); + Exit.exit(1); + } + + final String propFileName = args[0]; + final String command = args[1]; + final boolean disableAutoTerminate = args.length > 2; + + final Properties streamsProperties = Utils.loadProps(propFileName); + final String kafka = streamsProperties.getProperty(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG); + final String processingGuarantee = streamsProperties.getProperty(StreamsConfig.PROCESSING_GUARANTEE_CONFIG); + + if (kafka == null) { + System.err.println("No bootstrap kafka servers specified in " + StreamsConfig.BOOTSTRAP_SERVERS_CONFIG); + Exit.exit(1); + } + + if ("process".equals(command)) { + if (!StreamsConfig.AT_LEAST_ONCE.equals(processingGuarantee) && + !StreamsConfig.EXACTLY_ONCE.equals(processingGuarantee)) { + + System.err.println("processingGuarantee must be either " + StreamsConfig.AT_LEAST_ONCE + " or " + + StreamsConfig.EXACTLY_ONCE); + + Exit.exit(1); + } + } + + System.out.println("StreamsTest instance started (StreamsSmokeTest)"); + System.out.println("command=" + command); + System.out.println("props=" + streamsProperties); + System.out.println("disableAutoTerminate=" + disableAutoTerminate); + + switch (command) { + case "run": + // this starts the driver (data generation and result verification) + final int numKeys = 10; + final int maxRecordsPerKey = 500; + if (disableAutoTerminate) { + generatePerpetually(kafka, numKeys, maxRecordsPerKey); + } else { + // slow down data production to span 30 seconds so that system tests have time to + // do their bounces, etc. + final Map> allData = + generate(kafka, numKeys, maxRecordsPerKey, Duration.ofSeconds(30)); + SmokeTestDriver.verify(kafka, allData, maxRecordsPerKey); + } + break; + case "process": + // this starts the stream processing app + new SmokeTestClient(UUID.randomUUID().toString()).start(streamsProperties); + break; + default: + System.out.println("unknown command: " + command); + } + } + +} diff --git a/streams/upgrade-system-tests-27/src/test/java/org/apache/kafka/streams/tests/StreamsUpgradeTest.java b/streams/upgrade-system-tests-27/src/test/java/org/apache/kafka/streams/tests/StreamsUpgradeTest.java new file mode 100644 index 0000000..6f485e6 --- /dev/null +++ b/streams/upgrade-system-tests-27/src/test/java/org/apache/kafka/streams/tests/StreamsUpgradeTest.java @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.tests; + +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.streams.KafkaStreams; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.kstream.KStream; +import org.apache.kafka.streams.processor.AbstractProcessor; +import org.apache.kafka.streams.processor.ProcessorContext; +import org.apache.kafka.streams.processor.ProcessorSupplier; + +import java.util.Properties; + +public class StreamsUpgradeTest { + + @SuppressWarnings("unchecked") + public static void main(final String[] args) throws Exception { + if (args.length < 1) { + System.err.println("StreamsUpgradeTest requires one argument (properties-file) but provided none"); + } + final String propFileName = args[0]; + + final Properties streamsProperties = Utils.loadProps(propFileName); + + System.out.println("StreamsTest instance started (StreamsUpgradeTest v2.7)"); + System.out.println("props=" + streamsProperties); + + final StreamsBuilder builder = new StreamsBuilder(); + final KStream dataStream = builder.stream("data"); + dataStream.process(printProcessorSupplier()); + dataStream.to("echo"); + + final Properties config = new Properties(); + config.setProperty(StreamsConfig.APPLICATION_ID_CONFIG, "StreamsUpgradeTest"); + config.put(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG, 1000L); + config.putAll(streamsProperties); + + final KafkaStreams streams = new KafkaStreams(builder.build(), config); + streams.start(); + + Runtime.getRuntime().addShutdownHook(new Thread(() -> { + streams.close(); + System.out.println("UPGRADE-TEST-CLIENT-CLOSED"); + System.out.flush(); + })); + } + + private static ProcessorSupplier printProcessorSupplier() { + return () -> new AbstractProcessor() { + private int numRecordsProcessed = 0; + + @Override + public void init(final ProcessorContext context) { + System.out.println("[2.7] initializing processor: topic=data taskId=" + context.taskId()); + numRecordsProcessed = 0; + } + + @Override + public void process(final K key, final V value) { + numRecordsProcessed++; + if (numRecordsProcessed % 100 == 0) { + System.out.println("processed " + numRecordsProcessed + " records from topic=data"); + } + } + + @Override + public void close() {} + }; + } +} diff --git a/streams/upgrade-system-tests-28/src/test/java/org/apache/kafka/streams/tests/SmokeTestClient.java b/streams/upgrade-system-tests-28/src/test/java/org/apache/kafka/streams/tests/SmokeTestClient.java new file mode 100644 index 0000000..55aebf4 --- /dev/null +++ b/streams/upgrade-system-tests-28/src/test/java/org/apache/kafka/streams/tests/SmokeTestClient.java @@ -0,0 +1,299 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.tests; + +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.common.utils.Bytes; +import org.apache.kafka.common.utils.KafkaThread; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.streams.KafkaStreams; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.Topology; +import org.apache.kafka.streams.errors.StreamsUncaughtExceptionHandler; +import org.apache.kafka.streams.kstream.Consumed; +import org.apache.kafka.streams.kstream.Grouped; +import org.apache.kafka.streams.kstream.KGroupedStream; +import org.apache.kafka.streams.kstream.KStream; +import org.apache.kafka.streams.kstream.KTable; +import org.apache.kafka.streams.kstream.Materialized; +import org.apache.kafka.streams.kstream.Produced; +import org.apache.kafka.streams.kstream.Suppressed.BufferConfig; +import org.apache.kafka.streams.kstream.TimeWindows; +import org.apache.kafka.streams.kstream.Windowed; +import org.apache.kafka.streams.state.Stores; +import org.apache.kafka.streams.state.WindowStore; + +import java.io.File; +import java.io.IOException; +import java.nio.file.Files; +import java.time.Duration; +import java.time.Instant; +import java.util.Properties; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; + +import static org.apache.kafka.streams.kstream.Suppressed.untilWindowCloses; + +public class SmokeTestClient extends SmokeTestUtil { + + private final String name; + + private KafkaStreams streams; + private boolean uncaughtException = false; + private boolean started; + private volatile boolean closed; + + private static void addShutdownHook(final String name, final Runnable runnable) { + if (name != null) { + Runtime.getRuntime().addShutdownHook(KafkaThread.nonDaemon(name, runnable)); + } else { + Runtime.getRuntime().addShutdownHook(new Thread(runnable)); + } + } + + private static File tempDirectory() { + final String prefix = "kafka-"; + final File file; + try { + file = Files.createTempDirectory(prefix).toFile(); + } catch (final IOException ex) { + throw new RuntimeException("Failed to create a temp dir", ex); + } + file.deleteOnExit(); + + addShutdownHook("delete-temp-file-shutdown-hook", () -> { + try { + Utils.delete(file); + } catch (final IOException e) { + System.out.println("Error deleting " + file.getAbsolutePath()); + e.printStackTrace(System.out); + } + }); + + return file; + } + + public SmokeTestClient(final String name) { + this.name = name; + } + + public boolean started() { + return started; + } + + public boolean closed() { + return closed; + } + + public void start(final Properties streamsProperties) { + final Topology build = getTopology(); + streams = new KafkaStreams(build, getStreamsConfig(streamsProperties)); + + final CountDownLatch countDownLatch = new CountDownLatch(1); + 
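// the latch is released on the first REBALANCING -> RUNNING transition; start() waits up to one minute for it below +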
streams.setStateListener((newState, oldState) -> { + System.out.printf("%s %s: %s -> %s%n", name, Instant.now(), oldState, newState); + if (oldState == KafkaStreams.State.REBALANCING && newState == KafkaStreams.State.RUNNING) { + started = true; + countDownLatch.countDown(); + } + + if (newState == KafkaStreams.State.NOT_RUNNING) { + closed = true; + } + }); + + streams.setUncaughtExceptionHandler(e -> { + System.out.println(name + ": SMOKE-TEST-CLIENT-EXCEPTION"); + System.out.println(name + ": FATAL: An unexpected exception is encountered: " + e); + e.printStackTrace(System.out); + uncaughtException = true; + return StreamsUncaughtExceptionHandler.StreamThreadExceptionResponse.SHUTDOWN_CLIENT; + }); + + addShutdownHook("streams-shutdown-hook", this::close); + + streams.start(); + try { + if (!countDownLatch.await(1, TimeUnit.MINUTES)) { + System.out.println(name + ": SMOKE-TEST-CLIENT-EXCEPTION: Didn't start in one minute"); + } + } catch (final InterruptedException e) { + System.out.println(name + ": SMOKE-TEST-CLIENT-EXCEPTION: " + e); + e.printStackTrace(System.out); + } + System.out.println(name + ": SMOKE-TEST-CLIENT-STARTED"); + System.out.println(name + " started at " + Instant.now()); + } + + public void closeAsync() { + streams.close(Duration.ZERO); + } + + public void close() { + final boolean closed = streams.close(Duration.ofMinutes(1)); + + if (closed && !uncaughtException) { + System.out.println(name + ": SMOKE-TEST-CLIENT-CLOSED"); + } else if (closed) { + System.out.println(name + ": SMOKE-TEST-CLIENT-EXCEPTION"); + } else { + System.out.println(name + ": SMOKE-TEST-CLIENT-EXCEPTION: Didn't close"); + } + } + + private Properties getStreamsConfig(final Properties props) { + final Properties fullProps = new Properties(props); + fullProps.put(StreamsConfig.APPLICATION_ID_CONFIG, "SmokeTest"); + fullProps.put(StreamsConfig.CLIENT_ID_CONFIG, "SmokeTest-" + name); + fullProps.put(StreamsConfig.STATE_DIR_CONFIG, tempDirectory().getAbsolutePath()); + fullProps.putAll(props); + return fullProps; + } + + public Topology getTopology() { + final StreamsBuilder builder = new StreamsBuilder(); + final Consumed stringIntConsumed = Consumed.with(stringSerde, intSerde); + final KStream source = builder.stream("data", stringIntConsumed); + source.filterNot((k, v) -> k.equals("flush")) + .to("echo", Produced.with(stringSerde, intSerde)); + final KStream data = source.filter((key, value) -> value == null || value != END); + data.process(SmokeTestUtil.printProcessorSupplier("data", name)); + + // min + final KGroupedStream groupedData = data.groupByKey(Grouped.with(stringSerde, intSerde)); + + final KTable, Integer> minAggregation = groupedData + .windowedBy(TimeWindows.of(Duration.ofDays(1)).grace(Duration.ofMinutes(1))) + .aggregate( + () -> Integer.MAX_VALUE, + (aggKey, value, aggregate) -> (value < aggregate) ? 
value : aggregate, + Materialized + .>as("uwin-min") + .withValueSerde(intSerde) + .withRetention(Duration.ofHours(25)) + ); + + streamify(minAggregation, "min-raw"); + + streamify(minAggregation.suppress(untilWindowCloses(BufferConfig.unbounded())), "min-suppressed"); + + minAggregation + .toStream(new Unwindow<>()) + .filterNot((k, v) -> k.equals("flush")) + .to("min", Produced.with(stringSerde, intSerde)); + + final KTable, Integer> smallWindowSum = groupedData + .windowedBy(TimeWindows.of(Duration.ofSeconds(2)).advanceBy(Duration.ofSeconds(1)).grace(Duration.ofSeconds(30))) + .reduce((l, r) -> l + r); + + streamify(smallWindowSum, "sws-raw"); + streamify(smallWindowSum.suppress(untilWindowCloses(BufferConfig.unbounded())), "sws-suppressed"); + + final KTable minTable = builder.table( + "min", + Consumed.with(stringSerde, intSerde), + Materialized.as("minStoreName")); + + minTable.toStream().process(SmokeTestUtil.printProcessorSupplier("min", name)); + + // max + groupedData + .windowedBy(TimeWindows.of(Duration.ofDays(2))) + .aggregate( + () -> Integer.MIN_VALUE, + (aggKey, value, aggregate) -> (value > aggregate) ? value : aggregate, + Materialized.>as("uwin-max").withValueSerde(intSerde)) + .toStream(new Unwindow<>()) + .filterNot((k, v) -> k.equals("flush")) + .to("max", Produced.with(stringSerde, intSerde)); + + final KTable maxTable = builder.table( + "max", + Consumed.with(stringSerde, intSerde), + Materialized.as("maxStoreName")); + maxTable.toStream().process(SmokeTestUtil.printProcessorSupplier("max", name)); + + // sum + groupedData + .windowedBy(TimeWindows.of(Duration.ofDays(2))) + .aggregate( + () -> 0L, + (aggKey, value, aggregate) -> (long) value + aggregate, + Materialized.>as("win-sum").withValueSerde(longSerde)) + .toStream(new Unwindow<>()) + .filterNot((k, v) -> k.equals("flush")) + .to("sum", Produced.with(stringSerde, longSerde)); + + final Consumed stringLongConsumed = Consumed.with(stringSerde, longSerde); + final KTable sumTable = builder.table("sum", stringLongConsumed); + sumTable.toStream().process(SmokeTestUtil.printProcessorSupplier("sum", name)); + + // cnt + groupedData + .windowedBy(TimeWindows.of(Duration.ofDays(2))) + .count(Materialized.as("uwin-cnt")) + .toStream(new Unwindow<>()) + .filterNot((k, v) -> k.equals("flush")) + .to("cnt", Produced.with(stringSerde, longSerde)); + + final KTable cntTable = builder.table( + "cnt", + Consumed.with(stringSerde, longSerde), + Materialized.as("cntStoreName")); + cntTable.toStream().process(SmokeTestUtil.printProcessorSupplier("cnt", name)); + + // dif + maxTable + .join( + minTable, + (value1, value2) -> value1 - value2) + .toStream() + .filterNot((k, v) -> k.equals("flush")) + .to("dif", Produced.with(stringSerde, intSerde)); + + // avg + sumTable + .join( + cntTable, + (value1, value2) -> (double) value1 / (double) value2) + .toStream() + .filterNot((k, v) -> k.equals("flush")) + .to("avg", Produced.with(stringSerde, doubleSerde)); + + // test repartition + final Agg agg = new Agg(); + cntTable.groupBy(agg.selector(), Grouped.with(stringSerde, longSerde)) + .aggregate(agg.init(), agg.adder(), agg.remover(), + Materialized.as(Stores.inMemoryKeyValueStore("cntByCnt")) + .withKeySerde(Serdes.String()) + .withValueSerde(Serdes.Long())) + .toStream() + .to("tagg", Produced.with(stringSerde, longSerde)); + + return builder.build(); + } + + private static void streamify(final KTable, Integer> windowedTable, final String topic) { + windowedTable + .toStream() + .filterNot((k, v) -> k.key().equals("flush")) + 
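// skip the "flush" marker records and re-key by the windowed key's string representation +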
.map((key, value) -> new KeyValue<>(key.toString(), value)) + .to(topic, Produced.with(stringSerde, intSerde)); + } +} diff --git a/streams/upgrade-system-tests-28/src/test/java/org/apache/kafka/streams/tests/SmokeTestDriver.java b/streams/upgrade-system-tests-28/src/test/java/org/apache/kafka/streams/tests/SmokeTestDriver.java new file mode 100644 index 0000000..ac83cd9 --- /dev/null +++ b/streams/upgrade-system-tests-28/src/test/java/org/apache/kafka/streams/tests/SmokeTestDriver.java @@ -0,0 +1,622 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.tests; + +import org.apache.kafka.clients.consumer.ConsumerConfig; +import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.clients.consumer.ConsumerRecords; +import org.apache.kafka.clients.consumer.KafkaConsumer; +import org.apache.kafka.clients.producer.Callback; +import org.apache.kafka.clients.producer.KafkaProducer; +import org.apache.kafka.clients.producer.ProducerConfig; +import org.apache.kafka.clients.producer.ProducerRecord; +import org.apache.kafka.clients.producer.RecordMetadata; +import org.apache.kafka.common.PartitionInfo; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.errors.TimeoutException; +import org.apache.kafka.common.serialization.ByteArraySerializer; +import org.apache.kafka.common.serialization.Deserializer; +import org.apache.kafka.common.serialization.StringDeserializer; +import org.apache.kafka.common.utils.Exit; +import org.apache.kafka.common.utils.Utils; + +import java.io.ByteArrayOutputStream; +import java.io.PrintStream; +import java.nio.charset.StandardCharsets; +import java.time.Duration; +import java.time.Instant; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.Properties; +import java.util.Random; +import java.util.Set; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Function; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import static java.util.Collections.emptyMap; +import static org.apache.kafka.common.utils.Utils.mkEntry; + +public class SmokeTestDriver extends SmokeTestUtil { + private static final String[] TOPICS = { + "data", + "echo", + "max", + "min", "min-suppressed", "min-raw", + "dif", + "sum", + "sws-raw", "sws-suppressed", + "cnt", + "avg", + "tagg" + }; + + private static final int MAX_RECORD_EMPTY_RETRIES = 30; + + private static class ValueList { + public final String key; + private final int[] values; + private int index; + + ValueList(final int min, final int max) { + 
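// key names the inclusive range of values generated for this entry +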
key = min + "-" + max; + + values = new int[max - min + 1]; + for (int i = 0; i < values.length; i++) { + values[i] = min + i; + } + // We want to randomize the order of data to test not completely predictable processing order + // However, values are also use as a timestamp of the record. (TODO: separate data and timestamp) + // We keep some correlation of time and order. Thus, the shuffling is done with a sliding window + shuffle(values, 10); + + index = 0; + } + + int next() { + return (index < values.length) ? values[index++] : -1; + } + } + + public static String[] topics() { + return Arrays.copyOf(TOPICS, TOPICS.length); + } + + static void generatePerpetually(final String kafka, + final int numKeys, + final int maxRecordsPerKey) { + final Properties producerProps = generatorProperties(kafka); + + int numRecordsProduced = 0; + + final ValueList[] data = new ValueList[numKeys]; + for (int i = 0; i < numKeys; i++) { + data[i] = new ValueList(i, i + maxRecordsPerKey - 1); + } + + final Random rand = new Random(); + + try (final KafkaProducer producer = new KafkaProducer<>(producerProps)) { + while (true) { + final int index = rand.nextInt(numKeys); + final String key = data[index].key; + final int value = data[index].next(); + + final ProducerRecord record = + new ProducerRecord<>( + "data", + stringSerde.serializer().serialize("", key), + intSerde.serializer().serialize("", value) + ); + + producer.send(record); + + numRecordsProduced++; + if (numRecordsProduced % 100 == 0) { + System.out.println(Instant.now() + " " + numRecordsProduced + " records produced"); + } + Utils.sleep(2); + } + } + } + + public static Map> generate(final String kafka, + final int numKeys, + final int maxRecordsPerKey, + final Duration timeToSpend) { + final Properties producerProps = generatorProperties(kafka); + + + int numRecordsProduced = 0; + + final Map> allData = new HashMap<>(); + final ValueList[] data = new ValueList[numKeys]; + for (int i = 0; i < numKeys; i++) { + data[i] = new ValueList(i, i + maxRecordsPerKey - 1); + allData.put(data[i].key, new HashSet<>()); + } + final Random rand = new Random(); + + int remaining = data.length; + + final long recordPauseTime = timeToSpend.toMillis() / numKeys / maxRecordsPerKey; + + List> needRetry = new ArrayList<>(); + + try (final KafkaProducer producer = new KafkaProducer<>(producerProps)) { + while (remaining > 0) { + final int index = rand.nextInt(remaining); + final String key = data[index].key; + final int value = data[index].next(); + + if (value < 0) { + remaining--; + data[index] = data[remaining]; + } else { + + final ProducerRecord record = + new ProducerRecord<>( + "data", + stringSerde.serializer().serialize("", key), + intSerde.serializer().serialize("", value) + ); + + producer.send(record, new TestCallback(record, needRetry)); + + numRecordsProduced++; + allData.get(key).add(value); + if (numRecordsProduced % 100 == 0) { + System.out.println(Instant.now() + " " + numRecordsProduced + " records produced"); + } + Utils.sleep(Math.max(recordPauseTime, 2)); + } + } + producer.flush(); + + int remainingRetries = 5; + while (!needRetry.isEmpty()) { + final List> needRetry2 = new ArrayList<>(); + for (final ProducerRecord record : needRetry) { + System.out.println("retry producing " + stringSerde.deserializer().deserialize("", record.key())); + producer.send(record, new TestCallback(record, needRetry2)); + } + producer.flush(); + needRetry = needRetry2; + + if (--remainingRetries == 0 && !needRetry.isEmpty()) { + System.err.println("Failed to 
produce all records after multiple retries"); + Exit.exit(1); + } + } + + // now that we've sent everything, we'll send some final records with a timestamp high enough to flush out + // all suppressed records. + final List partitions = producer.partitionsFor("data"); + for (final PartitionInfo partition : partitions) { + producer.send(new ProducerRecord<>( + partition.topic(), + partition.partition(), + System.currentTimeMillis() + Duration.ofDays(2).toMillis(), + stringSerde.serializer().serialize("", "flush"), + intSerde.serializer().serialize("", 0) + )); + } + } + return Collections.unmodifiableMap(allData); + } + + private static Properties generatorProperties(final String kafka) { + final Properties producerProps = new Properties(); + producerProps.put(ProducerConfig.CLIENT_ID_CONFIG, "SmokeTest"); + producerProps.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, kafka); + producerProps.put(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG, ByteArraySerializer.class); + producerProps.put(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG, ByteArraySerializer.class); + producerProps.put(ProducerConfig.ACKS_CONFIG, "all"); + return producerProps; + } + + private static class TestCallback implements Callback { + private final ProducerRecord originalRecord; + private final List> needRetry; + + TestCallback(final ProducerRecord originalRecord, + final List> needRetry) { + this.originalRecord = originalRecord; + this.needRetry = needRetry; + } + + @Override + public void onCompletion(final RecordMetadata metadata, final Exception exception) { + if (exception != null) { + if (exception instanceof TimeoutException) { + needRetry.add(originalRecord); + } else { + exception.printStackTrace(); + Exit.exit(1); + } + } + } + } + + private static void shuffle(final int[] data, @SuppressWarnings("SameParameterValue") final int windowSize) { + final Random rand = new Random(); + for (int i = 0; i < data.length; i++) { + // we shuffle data within windowSize + final int j = rand.nextInt(Math.min(data.length - i, windowSize)) + i; + + // swap + final int tmp = data[i]; + data[i] = data[j]; + data[j] = tmp; + } + } + + public static class NumberDeserializer implements Deserializer { + @Override + public Number deserialize(final String topic, final byte[] data) { + final Number value; + switch (topic) { + case "data": + case "echo": + case "min": + case "min-raw": + case "min-suppressed": + case "sws-raw": + case "sws-suppressed": + case "max": + case "dif": + value = intSerde.deserializer().deserialize(topic, data); + break; + case "sum": + case "cnt": + case "tagg": + value = longSerde.deserializer().deserialize(topic, data); + break; + case "avg": + value = doubleSerde.deserializer().deserialize(topic, data); + break; + default: + throw new RuntimeException("unknown topic: " + topic); + } + return value; + } + } + + public static VerificationResult verify(final String kafka, + final Map> inputs, + final int maxRecordsPerKey) { + final Properties props = new Properties(); + props.put(ConsumerConfig.CLIENT_ID_CONFIG, "verifier"); + props.put(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, kafka); + props.put(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, StringDeserializer.class); + props.put(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, NumberDeserializer.class); + props.put(ConsumerConfig.ISOLATION_LEVEL_CONFIG, "read_committed"); + + final KafkaConsumer consumer = new KafkaConsumer<>(props); + final List partitions = getAllPartitions(consumer, TOPICS); + consumer.assign(partitions); + consumer.seekToBeginning(partitions); + 
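+ // verification keeps polling until at least this many "echo" records have been observed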
+ final int recordsGenerated = inputs.size() * maxRecordsPerKey; + int recordsProcessed = 0; + final Map processed = + Stream.of(TOPICS) + .collect(Collectors.toMap(t -> t, t -> new AtomicInteger(0))); + + final Map>>> events = new HashMap<>(); + + VerificationResult verificationResult = new VerificationResult(false, "no results yet"); + int retry = 0; + final long start = System.currentTimeMillis(); + while (System.currentTimeMillis() - start < TimeUnit.MINUTES.toMillis(6)) { + final ConsumerRecords records = consumer.poll(Duration.ofSeconds(5)); + if (records.isEmpty() && recordsProcessed >= recordsGenerated) { + verificationResult = verifyAll(inputs, events, false); + if (verificationResult.passed()) { + break; + } else if (retry++ > MAX_RECORD_EMPTY_RETRIES) { + System.out.println(Instant.now() + " Didn't get any more results, verification hasn't passed, and out of retries."); + break; + } else { + System.out.println(Instant.now() + " Didn't get any more results, but verification hasn't passed (yet). Retrying..." + retry); + } + } else { + System.out.println(Instant.now() + " Get some more results from " + records.partitions() + ", resetting retry."); + + retry = 0; + for (final ConsumerRecord record : records) { + final String key = record.key(); + + final String topic = record.topic(); + processed.get(topic).incrementAndGet(); + + if (topic.equals("echo")) { + recordsProcessed++; + if (recordsProcessed % 100 == 0) { + System.out.println("Echo records processed = " + recordsProcessed); + } + } + + events.computeIfAbsent(topic, t -> new HashMap<>()) + .computeIfAbsent(key, k -> new LinkedList<>()) + .add(record); + } + + System.out.println(processed); + } + } + consumer.close(); + final long finished = System.currentTimeMillis() - start; + System.out.println("Verification time=" + finished); + System.out.println("-------------------"); + System.out.println("Result Verification"); + System.out.println("-------------------"); + System.out.println("recordGenerated=" + recordsGenerated); + System.out.println("recordProcessed=" + recordsProcessed); + + if (recordsProcessed > recordsGenerated) { + System.out.println("PROCESSED-MORE-THAN-GENERATED"); + } else if (recordsProcessed < recordsGenerated) { + System.out.println("PROCESSED-LESS-THAN-GENERATED"); + } + + boolean success; + + final Map> received = + events.get("echo") + .entrySet() + .stream() + .map(entry -> mkEntry( + entry.getKey(), + entry.getValue().stream().map(ConsumerRecord::value).collect(Collectors.toSet())) + ) + .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); + + success = inputs.equals(received); + + if (success) { + System.out.println("ALL-RECORDS-DELIVERED"); + } else { + int missedCount = 0; + for (final Map.Entry> entry : inputs.entrySet()) { + missedCount += received.get(entry.getKey()).size(); + } + System.out.println("missedRecords=" + missedCount); + } + + // give it one more try if it's not already passing. + if (!verificationResult.passed()) { + verificationResult = verifyAll(inputs, events, true); + } + success &= verificationResult.passed(); + + System.out.println(verificationResult.result()); + + System.out.println(success ? 
"SUCCESS" : "FAILURE"); + return verificationResult; + } + + public static class VerificationResult { + private final boolean passed; + private final String result; + + VerificationResult(final boolean passed, final String result) { + this.passed = passed; + this.result = result; + } + + public boolean passed() { + return passed; + } + + public String result() { + return result; + } + } + + private static VerificationResult verifyAll(final Map> inputs, + final Map>>> events, + final boolean printResults) { + final ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream(); + boolean pass; + try (final PrintStream resultStream = new PrintStream(byteArrayOutputStream)) { + pass = verifyTAgg(resultStream, inputs, events.get("tagg"), printResults); + pass &= verifySuppressed(resultStream, "min-suppressed", events, printResults); + pass &= verify(resultStream, "min-suppressed", inputs, events, windowedKey -> { + final String unwindowedKey = windowedKey.substring(1, windowedKey.length() - 1).replaceAll("@.*", ""); + return getMin(unwindowedKey); + }, printResults); + pass &= verifySuppressed(resultStream, "sws-suppressed", events, printResults); + pass &= verify(resultStream, "min", inputs, events, SmokeTestDriver::getMin, printResults); + pass &= verify(resultStream, "max", inputs, events, SmokeTestDriver::getMax, printResults); + pass &= verify(resultStream, "dif", inputs, events, key -> getMax(key).intValue() - getMin(key).intValue(), printResults); + pass &= verify(resultStream, "sum", inputs, events, SmokeTestDriver::getSum, printResults); + pass &= verify(resultStream, "cnt", inputs, events, key1 -> getMax(key1).intValue() - getMin(key1).intValue() + 1L, printResults); + pass &= verify(resultStream, "avg", inputs, events, SmokeTestDriver::getAvg, printResults); + } + return new VerificationResult(pass, new String(byteArrayOutputStream.toByteArray(), StandardCharsets.UTF_8)); + } + + private static boolean verify(final PrintStream resultStream, + final String topic, + final Map> inputData, + final Map>>> events, + final Function keyToExpectation, + final boolean printResults) { + final Map>> observedInputEvents = events.get("data"); + final Map>> outputEvents = events.getOrDefault(topic, emptyMap()); + if (outputEvents.isEmpty()) { + resultStream.println(topic + " is empty"); + return false; + } else { + resultStream.printf("verifying %s with %d keys%n", topic, outputEvents.size()); + + if (outputEvents.size() != inputData.size()) { + resultStream.printf("fail: resultCount=%d expectedCount=%s%n\tresult=%s%n\texpected=%s%n", + outputEvents.size(), inputData.size(), outputEvents.keySet(), inputData.keySet()); + return false; + } + for (final Map.Entry>> entry : outputEvents.entrySet()) { + final String key = entry.getKey(); + final Number expected = keyToExpectation.apply(key); + final Number actual = entry.getValue().getLast().value(); + if (!expected.equals(actual)) { + resultStream.printf("%s fail: key=%s actual=%s expected=%s%n", topic, key, actual, expected); + + if (printResults) { + resultStream.printf("\t inputEvents=%n%s%n\t" + + "echoEvents=%n%s%n\tmaxEvents=%n%s%n\tminEvents=%n%s%n\tdifEvents=%n%s%n\tcntEvents=%n%s%n\ttaggEvents=%n%s%n", + indent("\t\t", observedInputEvents.get(key)), + indent("\t\t", events.getOrDefault("echo", emptyMap()).getOrDefault(key, new LinkedList<>())), + indent("\t\t", events.getOrDefault("max", emptyMap()).getOrDefault(key, new LinkedList<>())), + indent("\t\t", events.getOrDefault("min", emptyMap()).getOrDefault(key, new 
LinkedList<>())), + indent("\t\t", events.getOrDefault("dif", emptyMap()).getOrDefault(key, new LinkedList<>())), + indent("\t\t", events.getOrDefault("cnt", emptyMap()).getOrDefault(key, new LinkedList<>())), + indent("\t\t", events.getOrDefault("tagg", emptyMap()).getOrDefault(key, new LinkedList<>()))); + + if (!Utils.mkSet("echo", "max", "min", "dif", "cnt", "tagg").contains(topic)) + resultStream.printf("%sEvents=%n%s%n", topic, indent("\t\t", entry.getValue())); + } + + return false; + } + } + return true; + } + } + + + private static boolean verifySuppressed(final PrintStream resultStream, + @SuppressWarnings("SameParameterValue") final String topic, + final Map>>> events, + final boolean printResults) { + resultStream.println("verifying suppressed " + topic); + final Map>> topicEvents = events.getOrDefault(topic, emptyMap()); + for (final Map.Entry>> entry : topicEvents.entrySet()) { + if (entry.getValue().size() != 1) { + final String unsuppressedTopic = topic.replace("-suppressed", "-raw"); + final String key = entry.getKey(); + final String unwindowedKey = key.substring(1, key.length() - 1).replaceAll("@.*", ""); + resultStream.printf("fail: key=%s%n\tnon-unique result:%n%s%n", + key, + indent("\t\t", entry.getValue())); + + if (printResults) + resultStream.printf("\tresultEvents:%n%s%n\tinputEvents:%n%s%n", + indent("\t\t", events.get(unsuppressedTopic).get(key)), + indent("\t\t", events.get("data").get(unwindowedKey))); + + return false; + } + } + return true; + } + + private static String indent(@SuppressWarnings("SameParameterValue") final String prefix, + final Iterable> list) { + final StringBuilder stringBuilder = new StringBuilder(); + for (final ConsumerRecord record : list) { + stringBuilder.append(prefix).append(record).append('\n'); + } + return stringBuilder.toString(); + } + + private static Long getSum(final String key) { + final int min = getMin(key).intValue(); + final int max = getMax(key).intValue(); + return ((long) min + max) * (max - min + 1L) / 2L; + } + + private static Double getAvg(final String key) { + final int min = getMin(key).intValue(); + final int max = getMax(key).intValue(); + return ((long) min + max) / 2.0; + } + + + private static boolean verifyTAgg(final PrintStream resultStream, + final Map> allData, + final Map>> taggEvents, + final boolean printResults) { + if (taggEvents == null) { + resultStream.println("tagg is missing"); + return false; + } else if (taggEvents.isEmpty()) { + resultStream.println("tagg is empty"); + return false; + } else { + resultStream.println("verifying tagg"); + + // generate expected answer + final Map expected = new HashMap<>(); + for (final String key : allData.keySet()) { + final int min = getMin(key).intValue(); + final int max = getMax(key).intValue(); + final String cnt = Long.toString(max - min + 1L); + + expected.put(cnt, expected.getOrDefault(cnt, 0L) + 1); + } + + // check the result + for (final Map.Entry>> entry : taggEvents.entrySet()) { + final String key = entry.getKey(); + Long expectedCount = expected.remove(key); + if (expectedCount == null) { + expectedCount = 0L; + } + + if (entry.getValue().getLast().value().longValue() != expectedCount) { + resultStream.println("fail: key=" + key + " tagg=" + entry.getValue() + " expected=" + expectedCount); + + if (printResults) + resultStream.println("\t taggEvents: " + entry.getValue()); + return false; + } + } + + } + return true; + } + + private static Number getMin(final String key) { + return Integer.parseInt(key.split("-")[0]); + } + + private 
static Number getMax(final String key) { + return Integer.parseInt(key.split("-")[1]); + } + + private static List getAllPartitions(final KafkaConsumer consumer, final String... topics) { + final List partitions = new ArrayList<>(); + + for (final String topic : topics) { + for (final PartitionInfo info : consumer.partitionsFor(topic)) { + partitions.add(new TopicPartition(info.topic(), info.partition())); + } + } + return partitions; + } + +} diff --git a/streams/upgrade-system-tests-28/src/test/java/org/apache/kafka/streams/tests/SmokeTestUtil.java b/streams/upgrade-system-tests-28/src/test/java/org/apache/kafka/streams/tests/SmokeTestUtil.java new file mode 100644 index 0000000..519b5d5 --- /dev/null +++ b/streams/upgrade-system-tests-28/src/test/java/org/apache/kafka/streams/tests/SmokeTestUtil.java @@ -0,0 +1,134 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.tests; + +import org.apache.kafka.common.serialization.Serde; +import org.apache.kafka.common.serialization.Serdes; +import org.apache.kafka.streams.KeyValue; +import org.apache.kafka.streams.kstream.Aggregator; +import org.apache.kafka.streams.kstream.Initializer; +import org.apache.kafka.streams.kstream.KeyValueMapper; +import org.apache.kafka.streams.kstream.Windowed; +import org.apache.kafka.streams.processor.AbstractProcessor; +import org.apache.kafka.streams.processor.Processor; +import org.apache.kafka.streams.processor.ProcessorContext; +import org.apache.kafka.streams.processor.ProcessorSupplier; + +import java.time.Instant; + +public class SmokeTestUtil { + + final static int END = Integer.MAX_VALUE; + + static ProcessorSupplier printProcessorSupplier(final String topic) { + return printProcessorSupplier(topic, ""); + } + + static ProcessorSupplier printProcessorSupplier(final String topic, final String name) { + return new ProcessorSupplier() { + @Override + public Processor get() { + return new AbstractProcessor() { + private int numRecordsProcessed = 0; + private long smallestOffset = Long.MAX_VALUE; + private long largestOffset = Long.MIN_VALUE; + + @Override + public void init(final ProcessorContext context) { + super.init(context); + System.out.println("[2.8] initializing processor: topic=" + topic + " taskId=" + context.taskId()); + System.out.flush(); + numRecordsProcessed = 0; + smallestOffset = Long.MAX_VALUE; + largestOffset = Long.MIN_VALUE; + } + + @Override + public void process(final Object key, final Object value) { + numRecordsProcessed++; + if (numRecordsProcessed % 100 == 0) { + System.out.printf("%s: %s%n", name, Instant.now()); + System.out.println("processed " + numRecordsProcessed + " records from topic=" + topic); + } + + if (smallestOffset > context().offset()) { + smallestOffset = context().offset(); + } + 
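+                        // Likewise track the largest offset seen; close() uses the smallest/largest bounds to report how many offsets this task processed.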
if (largestOffset < context().offset()) { + largestOffset = context().offset(); + } + } + + @Override + public void close() { + System.out.printf("Close processor for task %s%n", context().taskId()); + System.out.println("processed " + numRecordsProcessed + " records"); + final long processed; + if (largestOffset >= smallestOffset) { + processed = 1L + largestOffset - smallestOffset; + } else { + processed = 0L; + } + System.out.println("offset " + smallestOffset + " to " + largestOffset + " -> processed " + processed); + System.out.flush(); + } + }; + } + }; + } + + public static final class Unwindow implements KeyValueMapper, V, K> { + @Override + public K apply(final Windowed winKey, final V value) { + return winKey.key(); + } + } + + public static class Agg { + + KeyValueMapper> selector() { + return (key, value) -> new KeyValue<>(value == null ? null : Long.toString(value), 1L); + } + + public Initializer init() { + return () -> 0L; + } + + Aggregator adder() { + return (aggKey, value, aggregate) -> aggregate + value; + } + + Aggregator remover() { + return (aggKey, value, aggregate) -> aggregate - value; + } + } + + public static Serde stringSerde = Serdes.String(); + + public static Serde intSerde = Serdes.Integer(); + + static Serde longSerde = Serdes.Long(); + + static Serde doubleSerde = Serdes.Double(); + + public static void sleep(final long duration) { + try { + Thread.sleep(duration); + } catch (final Exception ignore) { } + } + +} diff --git a/streams/upgrade-system-tests-28/src/test/java/org/apache/kafka/streams/tests/StreamsSmokeTest.java b/streams/upgrade-system-tests-28/src/test/java/org/apache/kafka/streams/tests/StreamsSmokeTest.java new file mode 100644 index 0000000..f280eb0 --- /dev/null +++ b/streams/upgrade-system-tests-28/src/test/java/org/apache/kafka/streams/tests/StreamsSmokeTest.java @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.streams.tests; + +import org.apache.kafka.common.utils.Exit; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.streams.StreamsConfig; + +import java.io.IOException; +import java.time.Duration; +import java.util.Map; +import java.util.Properties; +import java.util.Set; +import java.util.UUID; + +import static org.apache.kafka.streams.tests.SmokeTestDriver.generate; +import static org.apache.kafka.streams.tests.SmokeTestDriver.generatePerpetually; + +public class StreamsSmokeTest { + + /** + * args ::= kafka propFileName command disableAutoTerminate + * command := "run" | "process" + * + * @param args + */ + public static void main(final String[] args) throws IOException { + if (args.length < 2) { + System.err.println("StreamsSmokeTest are expecting two parameters: propFile, command; but only see " + args.length + " parameter"); + Exit.exit(1); + } + + final String propFileName = args[0]; + final String command = args[1]; + final boolean disableAutoTerminate = args.length > 2; + + final Properties streamsProperties = Utils.loadProps(propFileName); + final String kafka = streamsProperties.getProperty(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG); + final String processingGuarantee = streamsProperties.getProperty(StreamsConfig.PROCESSING_GUARANTEE_CONFIG); + + if (kafka == null) { + System.err.println("No bootstrap kafka servers specified in " + StreamsConfig.BOOTSTRAP_SERVERS_CONFIG); + Exit.exit(1); + } + + if ("process".equals(command)) { + if (!StreamsConfig.AT_LEAST_ONCE.equals(processingGuarantee) && + !StreamsConfig.EXACTLY_ONCE.equals(processingGuarantee)) { + + System.err.println("processingGuarantee must be either " + StreamsConfig.AT_LEAST_ONCE + " or " + + StreamsConfig.EXACTLY_ONCE); + + Exit.exit(1); + } + } + + System.out.println("StreamsTest instance started (StreamsSmokeTest)"); + System.out.println("command=" + command); + System.out.println("props=" + streamsProperties); + System.out.println("disableAutoTerminate=" + disableAutoTerminate); + + switch (command) { + case "run": + // this starts the driver (data generation and result verification) + final int numKeys = 10; + final int maxRecordsPerKey = 500; + if (disableAutoTerminate) { + generatePerpetually(kafka, numKeys, maxRecordsPerKey); + } else { + // slow down data production to span 30 seconds so that system tests have time to + // do their bounces, etc. + final Map> allData = + generate(kafka, numKeys, maxRecordsPerKey, Duration.ofSeconds(30)); + SmokeTestDriver.verify(kafka, allData, maxRecordsPerKey); + } + break; + case "process": + // this starts the stream processing app + new SmokeTestClient(UUID.randomUUID().toString()).start(streamsProperties); + break; + default: + System.out.println("unknown command: " + command); + } + } + +} diff --git a/streams/upgrade-system-tests-28/src/test/java/org/apache/kafka/streams/tests/StreamsUpgradeTest.java b/streams/upgrade-system-tests-28/src/test/java/org/apache/kafka/streams/tests/StreamsUpgradeTest.java new file mode 100644 index 0000000..4f2825d --- /dev/null +++ b/streams/upgrade-system-tests-28/src/test/java/org/apache/kafka/streams/tests/StreamsUpgradeTest.java @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.streams.tests; + +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.streams.KafkaStreams; +import org.apache.kafka.streams.StreamsBuilder; +import org.apache.kafka.streams.StreamsConfig; +import org.apache.kafka.streams.kstream.KStream; +import org.apache.kafka.streams.processor.AbstractProcessor; +import org.apache.kafka.streams.processor.ProcessorContext; +import org.apache.kafka.streams.processor.ProcessorSupplier; + +import java.util.Properties; + +public class StreamsUpgradeTest { + + @SuppressWarnings("unchecked") + public static void main(final String[] args) throws Exception { + if (args.length < 1) { + System.err.println("StreamsUpgradeTest requires one argument (properties-file) but provided none"); + } + final String propFileName = args[0]; + + final Properties streamsProperties = Utils.loadProps(propFileName); + + System.out.println("StreamsTest instance started (StreamsUpgradeTest v2.8)"); + System.out.println("props=" + streamsProperties); + + final StreamsBuilder builder = new StreamsBuilder(); + final KStream dataStream = builder.stream("data"); + dataStream.process(printProcessorSupplier()); + dataStream.to("echo"); + + final Properties config = new Properties(); + config.setProperty(StreamsConfig.APPLICATION_ID_CONFIG, "StreamsUpgradeTest"); + config.put(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG, 1000); + config.putAll(streamsProperties); + + final KafkaStreams streams = new KafkaStreams(builder.build(), config); + streams.start(); + + Runtime.getRuntime().addShutdownHook(new Thread(() -> { + streams.close(); + System.out.println("UPGRADE-TEST-CLIENT-CLOSED"); + System.out.flush(); + })); + } + + private static ProcessorSupplier printProcessorSupplier() { + return () -> new AbstractProcessor() { + private int numRecordsProcessed = 0; + + @Override + public void init(final ProcessorContext context) { + System.out.println("[2.8] initializing processor: topic=data taskId=" + context.taskId()); + numRecordsProcessed = 0; + } + + @Override + public void process(final K key, final V value) { + numRecordsProcessed++; + if (numRecordsProcessed % 100 == 0) { + System.out.println("processed " + numRecordsProcessed + " records from topic=data"); + } + } + + @Override + public void close() {} + }; + } +} diff --git a/tests/.gitignore b/tests/.gitignore new file mode 100644 index 0000000..5eb6164 --- /dev/null +++ b/tests/.gitignore @@ -0,0 +1,7 @@ +Vagrantfile.local +.idea/ +*.pyc +*.ipynb +.DS_Store +.ducktape +results/ diff --git a/tests/MANIFEST.in b/tests/MANIFEST.in new file mode 100644 index 0000000..3164ec6 --- /dev/null +++ b/tests/MANIFEST.in @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. 
+# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +recursive-include kafkatest */templates/* diff --git a/tests/README.md b/tests/README.md new file mode 100644 index 0000000..e54c60e --- /dev/null +++ b/tests/README.md @@ -0,0 +1,595 @@ +System Integration & Performance Testing +======================================== + +This directory contains Kafka system integration and performance tests. +[ducktape](https://github.com/confluentinc/ducktape) is used to run the tests. +(ducktape is a distributed testing framework which provides test runner, +result reporter and utilities to pull up and tear down services.) + +Running tests using docker +-------------------------- +Docker containers can be used for running kafka system tests locally. +* Requirements + - Docker 1.12.3 (or higher) is installed and running on the machine. + - Test requires that Kafka, including system test libs, is built. This can be done by running +``` +./gradlew clean systemTestLibs +``` +* Run all tests +``` +bash tests/docker/run_tests.sh +``` +* Run all tests with debug on (warning will produce log of logs) +``` +_DUCKTAPE_OPTIONS="--debug" bash tests/docker/run_tests.sh | tee debug_logs.txt +``` +* Run a subset of tests +``` +TC_PATHS="tests/kafkatest/tests/streams tests/kafkatest/tests/tools" bash tests/docker/run_tests.sh +``` +* Run a specific tests file +``` +TC_PATHS="tests/kafkatest/tests/client/pluggable_test.py" bash tests/docker/run_tests.sh +``` +* Run a specific test class +``` +TC_PATHS="tests/kafkatest/tests/client/pluggable_test.py::PluggableConsumerTest" bash tests/docker/run_tests.sh +``` +* Run a specific test method +``` +TC_PATHS="tests/kafkatest/tests/client/pluggable_test.py::PluggableConsumerTest.test_start_stop" bash tests/docker/run_tests.sh +``` +* Run a specific test method with specific parameters +``` +TC_PATHS="tests/kafkatest/tests/streams/streams_upgrade_test.py::StreamsUpgradeTest.test_metadata_upgrade" _DUCKTAPE_OPTIONS='--parameters '\''{"from_version":"0.10.1.1","to_version":"2.6.0-SNAPSHOT"}'\' bash tests/docker/run_tests.sh +``` +* Run tests with a different JVM +``` +bash tests/docker/ducker-ak up -j 'openjdk:11'; tests/docker/run_tests.sh +``` +* Rebuild first and then run tests +``` +REBUILD="t" bash tests/docker/run_tests.sh +``` +* Debug tests in VS Code: + - Run test with `--debug` flag (can be before or after file name): + ``` + tests/docker/ducker-ak up; tests/docker/ducker-ak test tests/kafkatest/tests/core/security_test.py --debug + ``` + - Test will run in debug mode and wait for a debugger to attach. + - Launch VS Code debugger with `"attach"` request - here's an example: + ```json + { + "version": "0.2.0", + "configurations": [ + { + "name": "Python: Attach to Ducker", + "type": "python", + "request": "attach", + "connect": { + "host": "localhost", + "port": 5678 + }, + "justMyCode": false, + "pathMappings": [ + { + "localRoot": "${workspaceFolder}", + "remoteRoot": "." 
+ } + ] + } + ] + } + ``` + - To pass `--debug` flag to ducktape itself, use `--`: + ``` + tests/docker/ducker-ak test tests/kafkatest/tests/core/security_test.py --debug -- --debug + ``` + +* Notes + - The scripts to run tests creates and destroys docker network named *knw*. + This network can't be used for any other purpose. + - The docker containers are named knode01, knode02 etc. + These nodes can't be used for any other purpose. + +* Exposing ports using --expose-ports option of `ducker-ak up` command + + If `--expose-ports` is specified then we will expose those ports to random ephemeral ports + on the host. The argument can be a single port (like 5005), a port range like (5005-5009) + or a combination of port/port-range separated by comma (like 2181,9092 or 2181,5005-5008). + By default no port is exposed. + + The exposed port mapping can be seen by executing `docker ps` command. The PORT column + of the output shows the mapping like this (maps port 33891 on host to port 2182 in container): + + 0.0.0.0:33891->2182/tcp + + Behind the scene Docker is setting up a DNAT rule for the mapping and it is visible in + the DOCKER section of iptables command (`sudo iptables -t nat -L -n`), something like: + +
                DNAT       tcp  --  0.0.0.0/0      0.0.0.0/0      tcp       dpt:33891       to:172.22.0.2:2182
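+
+  As a quick sketch tying the pieces above together (the port values here are only examples), bring
+  up the cluster with the ports exposed and then read the resulting mapping back from `docker ps`:
+
+      $ tests/docker/ducker-ak up --expose-ports 2181,5005-5008
+      $ docker ps --format '{{.Names}}: {{.Ports}}'
+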
                + + The exposed port(s) are useful to attach a remote debugger to the process running + in the docker image. For example if port 5005 was exposed and is mapped to an ephemeral + port (say 33891), then a debugger attaching to port 33891 on host will be connecting to + a debug session started at port 5005 in the docker image. As an example, for above port + numbers, run following commands in the docker image (say by ssh using `./docker/ducker-ak ssh ducker02`): + + > $ export KAFKA_OPTS="-agentlib:jdwp=transport=dt_socket,server=y,suspend=y,address=5005" + + > $ /opt/kafka-dev/bin/kafka-topics.sh --bootstrap-server ducker03:9095 --topic __consumer_offsets --describe + + This will run the TopicCommand to describe the __consumer-offset topic. The java process + will stop and wait for debugger to attach as `suspend=y` option was specified. Now starting + a debugger on host with host `localhost` and following parameter as JVM setting: + + `-agentlib:jdwp=transport=dt_socket,server=y,suspend=n,address=33891` + + will attach it to the TopicCommand process running in the docker image. + +Examining CI run +---------------- +* Set BUILD_ID is travis ci's build id. E.g. build id is 169519874 for the following build +```bash +https://travis-ci.org/apache/kafka/builds/169519874 +``` + +* Getting number of tests that were actually run +```bash +for id in $(curl -sSL https://api.travis-ci.org/builds/$BUILD_ID | jq '.matrix|map(.id)|.[]'); do curl -sSL "https://api.travis-ci.org/jobs/$id/log.txt?deansi=true" ; done | grep -cE 'RunnerClient: Loading test' +``` + +* Getting number of tests that passed +```bash +for id in $(curl -sSL https://api.travis-ci.org/builds/$BUILD_ID | jq '.matrix|map(.id)|.[]'); do curl -sSL "https://api.travis-ci.org/jobs/$id/log.txt?deansi=true" ; done | grep -cE 'RunnerClient.*PASS' +``` +* Getting all the logs produced from a run +```bash +for id in $(curl -sSL https://api.travis-ci.org/builds/$BUILD_ID | jq '.matrix|map(.id)|.[]'); do curl -sSL "https://api.travis-ci.org/jobs/$id/log.txt?deansi=true" ; done +``` +* Explanation of curl calls to travis-ci & jq commands + - We get json information of the build using the following command +```bash +curl -sSL https://api.travis-ci.org/apache/kafka/builds/169519874 +``` +This produces a json about the build which looks like: +```json +{ + "id": 169519874, + "repository_id": 6097916, + "number": "19", + "config": { + "sudo": "required", + "dist": "trusty", + "language": "java", + "env": [ + "TC_PATHS=\"tests/kafkatest/tests/client\"", + "TC_PATHS=\"tests/kafkatest/tests/connect tests/kafkatest/tests/streams tests/kafkatest/tests/tools\"", + "TC_PATHS=\"tests/kafkatest/tests/mirror_maker\"", + "TC_PATHS=\"tests/kafkatest/tests/replication\"", + "TC_PATHS=\"tests/kafkatest/tests/upgrade\"", + "TC_PATHS=\"tests/kafkatest/tests/security\"", + "TC_PATHS=\"tests/kafkatest/tests/core\"" + ], + "jdk": [ + "oraclejdk8" + ], + "before_install": null, + "script": [ + "./gradlew systemTestLibs && /bin/bash ./tests/travis/run_tests.sh" + ], + "services": [ + "docker" + ], + "before_cache": [ + "rm -f $HOME/.gradle/caches/modules-2/modules-2.lock", + "rm -fr $HOME/.gradle/caches/*/plugin-resolution/" + ], + "cache": { + "directories": [ + "$HOME/.m2/repository", + "$HOME/.gradle/caches/", + "$HOME/.gradle/wrapper/" + ] + }, + ".result": "configured", + "group": "stable" + }, + "state": "finished", + "result": null, + "status": null, + "started_at": "2016-10-21T13:35:43Z", + "finished_at": "2016-10-21T14:46:03Z", + "duration": 16514, 
+ "commit": "7e583d9ea08c70dbbe35a3adde72ed203a797f64", + "branch": "trunk", + "message": "respect _DUCK_OPTIONS", + "committed_at": "2016-10-21T00:12:36Z", + "author_name": "Raghav Kumar Gautam", + "author_email": "raghav@apache.org", + "committer_name": "Raghav Kumar Gautam", + "committer_email": "raghav@apache.org", + "compare_url": "https://github.com/raghavgautam/kafka/compare/cc788ac99ca7...7e583d9ea08c", + "event_type": "push", + "matrix": [ + { + "id": 169519875, + "repository_id": 6097916, + "number": "19.1", + "config": { + "sudo": "required", + "dist": "trusty", + "language": "java", + "env": "TC_PATHS=\"tests/kafkatest/tests/client\"", + "jdk": "oraclejdk8", + "before_install": null, + "script": [ + "./gradlew systemTestLibs && /bin/bash ./tests/travis/run_tests.sh" + ], + "services": [ + "docker" + ], + "before_cache": [ + "rm -f $HOME/.gradle/caches/modules-2/modules-2.lock", + "rm -fr $HOME/.gradle/caches/*/plugin-resolution/" + ], + "cache": { + "directories": [ + "$HOME/.m2/repository", + "$HOME/.gradle/caches/", + "$HOME/.gradle/wrapper/" + ] + }, + ".result": "configured", + "group": "stable", + "os": "linux" + }, + "result": null, + "started_at": "2016-10-21T13:35:43Z", + "finished_at": "2016-10-21T14:24:50Z", + "allow_failure": false + }, + { + "id": 169519876, + "repository_id": 6097916, + "number": "19.2", + "config": { + "sudo": "required", + "dist": "trusty", + "language": "java", + "env": "TC_PATHS=\"tests/kafkatest/tests/connect tests/kafkatest/tests/streams tests/kafkatest/tests/tools\"", + "jdk": "oraclejdk8", + "before_install": null, + "script": [ + "./gradlew systemTestLibs && /bin/bash ./tests/travis/run_tests.sh" + ], + "services": [ + "docker" + ], + "before_cache": [ + "rm -f $HOME/.gradle/caches/modules-2/modules-2.lock", + "rm -fr $HOME/.gradle/caches/*/plugin-resolution/" + ], + "cache": { + "directories": [ + "$HOME/.m2/repository", + "$HOME/.gradle/caches/", + "$HOME/.gradle/wrapper/" + ] + }, + ".result": "configured", + "group": "stable", + "os": "linux" + }, + "result": 1, + "started_at": "2016-10-21T13:35:46Z", + "finished_at": "2016-10-21T14:22:05Z", + "allow_failure": false + }, + + ... 
+ ] +} + +``` + - By passing this through jq filter `.matrix` we extract the matrix part of the json +```bash +curl -sSL https://api.travis-ci.org/apache/kafka/builds/169519874 | jq '.matrix' +``` +The resulting json looks like: +```json +[ + { + "id": 169519875, + "repository_id": 6097916, + "number": "19.1", + "config": { + "sudo": "required", + "dist": "trusty", + "language": "java", + "env": "TC_PATHS=\"tests/kafkatest/tests/client\"", + "jdk": "oraclejdk8", + "before_install": null, + "script": [ + "./gradlew systemTestLibs && /bin/bash ./tests/travis/run_tests.sh" + ], + "services": [ + "docker" + ], + "before_cache": [ + "rm -f $HOME/.gradle/caches/modules-2/modules-2.lock", + "rm -fr $HOME/.gradle/caches/*/plugin-resolution/" + ], + "cache": { + "directories": [ + "$HOME/.m2/repository", + "$HOME/.gradle/caches/", + "$HOME/.gradle/wrapper/" + ] + }, + ".result": "configured", + "group": "stable", + "os": "linux" + }, + "result": null, + "started_at": "2016-10-21T13:35:43Z", + "finished_at": "2016-10-21T14:24:50Z", + "allow_failure": false + }, + { + "id": 169519876, + "repository_id": 6097916, + "number": "19.2", + "config": { + "sudo": "required", + "dist": "trusty", + "language": "java", + "env": "TC_PATHS=\"tests/kafkatest/tests/connect tests/kafkatest/tests/streams tests/kafkatest/tests/tools\"", + "jdk": "oraclejdk8", + "before_install": null, + "script": [ + "./gradlew systemTestLibs && /bin/bash ./tests/travis/run_tests.sh" + ], + "services": [ + "docker" + ], + "before_cache": [ + "rm -f $HOME/.gradle/caches/modules-2/modules-2.lock", + "rm -fr $HOME/.gradle/caches/*/plugin-resolution/" + ], + "cache": { + "directories": [ + "$HOME/.m2/repository", + "$HOME/.gradle/caches/", + "$HOME/.gradle/wrapper/" + ] + }, + ".result": "configured", + "group": "stable", + "os": "linux" + }, + "result": 1, + "started_at": "2016-10-21T13:35:46Z", + "finished_at": "2016-10-21T14:22:05Z", + "allow_failure": false + }, + + ... 
+] + +``` + - By further passing this through jq filter `map(.id)` we extract the id of + the builds for each of the splits +```bash +curl -sSL https://api.travis-ci.org/apache/kafka/builds/169519874 | jq '.matrix|map(.id)' +``` +The resulting json looks like: +```json +[ + 169519875, + 169519876, + 169519877, + 169519878, + 169519879, + 169519880, + 169519881 +] +``` + - To use these ids in for loop we want to get rid of `[]` which is done by + passing it through `.[]` filter +```bash +curl -sSL https://api.travis-ci.org/apache/kafka/builds/169519874 | jq '.matrix|map(.id)|.[]' +``` +And we get +```text +169519875 +169519876 +169519877 +169519878 +169519879 +169519880 +169519881 +``` + - In the for loop we have made calls to fetch logs +```bash +curl -sSL "https://api.travis-ci.org/jobs/169519875/log.txt?deansi=true" | tail +``` +which gives us +```text +[INFO:2016-10-21 14:21:12,538]: SerialTestRunner: kafkatest.tests.client.consumer_test.OffsetValidationTest.test_consumer_bounce.clean_shutdown=False.bounce_mode=rolling: test 16 of 28 +[INFO:2016-10-21 14:21:12,538]: SerialTestRunner: kafkatest.tests.client.consumer_test.OffsetValidationTest.test_consumer_bounce.clean_shutdown=False.bounce_mode=rolling: setting up +[INFO:2016-10-21 14:21:30,810]: SerialTestRunner: kafkatest.tests.client.consumer_test.OffsetValidationTest.test_consumer_bounce.clean_shutdown=False.bounce_mode=rolling: running +[INFO:2016-10-21 14:24:35,519]: SerialTestRunner: kafkatest.tests.client.consumer_test.OffsetValidationTest.test_consumer_bounce.clean_shutdown=False.bounce_mode=rolling: PASS +[INFO:2016-10-21 14:24:35,519]: SerialTestRunner: kafkatest.tests.client.consumer_test.OffsetValidationTest.test_consumer_bounce.clean_shutdown=False.bounce_mode=rolling: tearing down + + +The job exceeded the maximum time limit for jobs, and has been terminated. + +``` +* Links + - [Travis-CI REST api documentation](https://docs.travis-ci.com/api) + - [jq Manual](https://stedolan.github.io/jq/manual/) + +Local Quickstart +---------------- +This quickstart will help you run the Kafka system tests on your local machine. Note this requires bringing up a cluster of virtual machines on your local computer, which is memory intensive; it currently requires around 10G RAM. +For a tutorial on how to setup and run the Kafka system tests, see +https://cwiki.apache.org/confluence/display/KAFKA/tutorial+-+set+up+and+run+Kafka+system+tests+with+ducktape + +* Install Virtual Box from [https://www.virtualbox.org/](https://www.virtualbox.org/) (run `$ vboxmanage --version` to check if it's installed). +* Install Vagrant >= 1.6.4 from [https://www.vagrantup.com/](https://www.vagrantup.com/) (run `vagrant --version` to check if it's installed). +* Install system test dependencies, including ducktape, a command-line tool and library for testing distributed systems. We recommend to use virtual env for system test development + + $ cd kafka/tests + $ virtualenv -p python3 venv + $ . ./venv/bin/activate + $ python3 setup.py develop + $ cd .. 
# back to base kafka directory + +* Run the bootstrap script to set up Vagrant for testing + + $ tests/bootstrap-test-env.sh + +* Bring up the test cluster + + $ vagrant/vagrant-up.sh + $ # When using Virtualbox, it also works to run: vagrant up + +* Build the desired branch of Kafka + + $ git checkout $BRANCH + $ gradle # (only if necessary) + $ ./gradlew systemTestLibs + +* Run the system tests using ducktape: + + $ ducktape tests/kafkatest/tests + +EC2 Quickstart +-------------- +This quickstart will help you run the Kafka system tests on EC2. In this setup, all logic is run +on EC2 and none on your local machine. + +There are a lot of steps here, but the basic goals are to create one distinguished EC2 instance that +will be our "test driver", and to set up the security groups and iam role so that the test driver +can create, destroy, and run ssh commands on any number of "workers". + +As a convention, we'll use "kafkatest" in most names, but you can use whatever name you want. + +Preparation +----------- +In these steps, we will create an IAM role which has permission to create and destroy EC2 instances, +set up a keypair used for ssh access to the test driver and worker machines, and create a security group to allow the test driver and workers to all communicate via TCP. + +* [Create an IAM role](https://docs.aws.amazon.com/IAM/latest/UserGuide/id_roles_create_for-user.html). We'll give this role the ability to launch or kill additional EC2 machines. + - Create role "kafkatest-master" + - Role type: Amazon EC2 + - Attach policy: AmazonEC2FullAccess (this will allow our test-driver to create and destroy EC2 instances) + +* If you haven't already, [set up a keypair to use for SSH access](https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/ec2-key-pairs.html). For the purpose +of this quickstart, let's say the keypair name is kafkatest, and you've saved the private key in kafktest.pem + +* Next, create a EC2 security group called "kafkatest". + - After creating the group, inbound rules: allow SSH on port 22 from anywhere; also, allow access on all ports (0-65535) from other machines in the kafkatest group. + +Create the Test Driver +---------------------- +* Launch a new test driver machine + - OS: Ubuntu server is recommended + - Instance type: t2.medium is easily enough since this machine is just a driver + - Instance details: Most defaults are fine. + - IAM role -> kafkatest-master + - Tagging the instance with a useful name is recommended. + - Security group -> 'kafkatest' + + +* Once the machine is started, upload the SSH key to your test driver: + + $ scp -i /path/to/kafkatest.pem \ + /path/to/kafkatest.pem ubuntu@public.hostname.amazonaws.com:kafkatest.pem + +* Grab the public hostname/IP (available for example by navigating to your EC2 dashboard and viewing running instances) of your test driver and SSH into it: + + $ ssh -i /path/to/kafkatest.pem ubuntu@public.hostname.amazonaws.com + +Set Up the Test Driver +---------------------- +The following steps assume you have ssh'd into +the test driver machine. 
+ +* Start by making sure you're up to date, and install git and ducktape: + + $ sudo apt-get update && sudo apt-get -y upgrade && sudo apt-get install -y python3-pip git + $ pip install ducktape + +* Get Kafka: + + $ git clone https://git-wip-us.apache.org/repos/asf/kafka.git kafka + +* Update your AWS credentials: + + export AWS_IAM_ROLE=$(curl -s http://169.254.169.254/latest/meta-data/iam/info | grep InstanceProfileArn | cut -d '"' -f 4 | cut -d '/' -f 2) + export AWS_ACCESS_KEY=$(curl -s http://169.254.169.254/latest/meta-data/iam/security-credentials/$AWS_IAM_ROLE | grep AccessKeyId | awk -F\" '{ print $4 }') + export AWS_SECRET_KEY=$(curl -s http://169.254.169.254/latest/meta-data/iam/security-credentials/$AWS_IAM_ROLE | grep SecretAccessKey | awk -F\" '{ print $4 }') + export AWS_SESSION_TOKEN=$(curl -s http://169.254.169.254/latest/meta-data/iam/security-credentials/$AWS_IAM_ROLE | grep Token | awk -F\" '{ print $4 }') + +* Install some dependencies: + + $ cd kafka + $ ./vagrant/aws/aws-init.sh + $ . ~/.bashrc + +* An example Vagrantfile.local has been created by aws-init.sh which looks something like: + + # Vagrantfile.local + ec2_instance_type = "..." # Pick something appropriate for your + # test. Note that the default m3.medium has + # a small disk. + ec2_spot_max_price = "0.123" # On-demand price for instance type + enable_hostmanager = false + num_zookeepers = 0 + num_kafka = 0 + num_workers = 9 + ec2_keypair_name = 'kafkatest' + ec2_keypair_file = '/home/ubuntu/kafkatest.pem' + ec2_security_groups = ['kafkatest'] + ec2_region = 'us-west-2' + ec2_ami = "ami-29ebb519" + +* Start up the instances: + + # This will brink up worker machines in small parallel batches + $ vagrant/vagrant-up.sh --aws + +* Now you should be able to run tests: + + $ cd kafka/tests + $ ducktape kafkatest/tests + +* Update Worker VM + +If you change code in a branch on your driver VM, you need to update your worker VM to pick up this change: + + $ ./gradlew systemTestLibs + $ vagrant rsync + +* To halt your workers without destroying persistent state, run `vagrant halt`. Run `vagrant destroy -f` to destroy all traces of your workers. + +Unit Tests +---------- +The system tests have unit tests! The various services in the python `kafkatest` module are reasonably complex, and intended to be reusable. Hence we have unit tests +for the system service classes. + +Where are the unit tests? +* The kafkatest unit tests are located under kafka/tests/unit + +How do I run the unit tests? +```bash +$ cd kafka/tests # The base system test directory +$ python3 setup.py test +``` + +How can I add a unit test? +* Follow the naming conventions - module name starts with "check", class name begins with "Check", test method name begins with "check" +* These naming conventions are defined in "setup.cfg". We use "check" to distinguish unit tests from system tests, which use "test" in the various names. + diff --git a/tests/bin/external_trogdor_command_example.py b/tests/bin/external_trogdor_command_example.py new file mode 100755 index 0000000..1254b82 --- /dev/null +++ b/tests/bin/external_trogdor_command_example.py @@ -0,0 +1,41 @@ +#!/usr/bin/env python +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import sys +import time + +# +# This is an example of an external script which can be run through Trogdor's +# ExternalCommandWorker. It sleeps for the given amount of time expressed by the delayMs field in the ExternalCommandSpec +# + +if __name__ == '__main__': + # Read the ExternalCommandWorker start message. + line = sys.stdin.readline() + start_message = json.loads(line) + workload = start_message["workload"] + print("Starting external_trogdor_command_example with task id %s, workload %s" + % (start_message["id"], workload)) + sys.stdout.flush() + + # pretend to start some workload + print(json.dumps({"status": "running"})) + sys.stdout.flush() + time.sleep(0.001 * workload["delayMs"]) + + print(json.dumps({"status": "exiting after %s delayMs" % workload["delayMs"]})) + sys.stdout.flush() diff --git a/tests/bin/flatten_html.sh b/tests/bin/flatten_html.sh new file mode 100755 index 0000000..fbcad50 --- /dev/null +++ b/tests/bin/flatten_html.sh @@ -0,0 +1,78 @@ +#!/bin/bash +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +usage() { + cat < /tmp/my-protocol.html + firefox /tmp/my-protocol.html & + +usage: +$0 [flags] + +flags: +-f [filename] The HTML file to process. +-h Print this help message. +EOF +} + +die() { + echo $@ + exit 1 +} + +realpath() { + [[ $1 = /* ]] && echo "$1" || echo "$PWD/${1#./}" +} + +process_file() { + local CUR_FILE="${1}" + [[ -f "${CUR_FILE}" ]] || die "Unable to open input file ${CUR_FILE}" + while IFS= read -r LINE; do + if [[ $LINE =~ \#include\ virtual=\"(.*)\" ]]; then + local INCLUDED_FILE="${BASH_REMATCH[1]}" + if [[ $INCLUDED_FILE =~ ../includes/ ]]; then + : # ignore ../includes + else + pushd "$(dirname "${CUR_FILE}")" &> /dev/null \ + || die "failed to change directory to directory of ${CUR_FILE}" + process_file "${INCLUDED_FILE}" + popd &> /dev/null + fi + else + echo "${LINE}" + fi + done < "${CUR_FILE}" +} + +FILE="" +while getopts "f:h" arg; do + case $arg in + f) FILE=$OPTARG;; + h) usage; exit 0;; + *) echo "Error parsing command-line arguments." + usage + exit 1;; + esac +done + +[[ -z "${FILE}" ]] && die "You must specify which file to process. -h for help." 
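+# Flatten the requested file: recursively inline each '#include virtual' reference (references under
+# ../includes are ignored) and write the result to stdout.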
+process_file "${FILE}" diff --git a/tests/bootstrap-test-env.sh b/tests/bootstrap-test-env.sh new file mode 100755 index 0000000..8e70bbf --- /dev/null +++ b/tests/bootstrap-test-env.sh @@ -0,0 +1,87 @@ +#!/usr/bin/env bash +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This script automates the process of setting up a local machine for running Kafka system tests +export GREP_OPTIONS='--color=never' + +# Helper function which prints version numbers so they can be compared lexically or numerically +function version { echo "$@" | awk -F. '{ printf("%03d%03d%03d%03d\n", $1,$2,$3,$4); }'; } + +base_dir=`dirname $0`/.. +cd $base_dir + +echo "Checking Virtual Box installation..." +bad_vb=false +if [ -z `vboxmanage --version` ]; then + echo "It appears that Virtual Box is not installed. Please install and try again (see https://www.virtualbox.org/ for details)" + bad_vb=true +else + echo "Virtual Box looks good." +fi + +echo "Checking Vagrant installation..." +vagrant_version=`vagrant --version | egrep -o "[0-9]+\.[0-9]+\.[0-9]+"` +bad_vagrant=false +if [ "$(version $vagrant_version)" -lt "$(version 1.6.4)" ]; then + echo "Found Vagrant version $vagrant_version. Please upgrade to 1.6.4 or higher (see https://www.vagrantup.com for details)" + bad_vagrant=true +else + echo "Vagrant installation looks good." +fi + +if [ "x$bad_vagrant" == "xtrue" -o "x$bad_vb" == "xtrue" ]; then + exit 1 +fi + +echo "Checking for necessary Vagrant plugins..." +hostmanager_version=`vagrant plugin list | grep vagrant-hostmanager | egrep -o "[0-9]+\.[0-9]+\.[0-9]+"` +if [ -z "$hostmanager_version" ]; then + vagrant plugin install vagrant-hostmanager +fi + +echo "Creating and packaging a reusable base box for Vagrant..." +vagrant/package-base-box.sh + +# Set up Vagrantfile.local if necessary +if [ ! -e Vagrantfile.local ]; then + echo "Creating Vagrantfile.local..." + cp vagrant/system-test-Vagrantfile.local Vagrantfile.local +else + echo "Found an existing Vagrantfile.local. Keeping without overwriting..." +fi + +# Sanity check contents of Vagrantfile.local +echo "Checking Vagrantfile.local..." +vagrantfile_ok=true +num_brokers=`egrep -o "num_brokers\s*=\s*[0-9]+" Vagrantfile.local | cut -d '=' -f 2 | xargs` +num_zookeepers=`egrep -o "num_zookeepers\s*=\s*[0-9]+" Vagrantfile.local | cut -d '=' -f 2 | xargs` +num_workers=`egrep -o "num_workers\s*=\s*[0-9]+" Vagrantfile.local | cut -d '=' -f 2 | xargs` +if [ "x$num_brokers" == "x" -o "$num_brokers" != 0 ]; then + echo "Vagrantfile.local: bad num_brokers. Update to: num_brokers = 0" + vagrantfile_ok=false +fi +if [ "x$num_zookeepers" == "x" -o "$num_zookeepers" != 0 ]; then + echo "Vagrantfile.local: bad num_zookeepers. 
Update to: num_zookeepers = 0" + vagrantfile_ok=false +fi +if [ "x$num_workers" == "x" -o "$num_workers" == 0 ]; then + echo "Vagrantfile.local: bad num_workers (size of test cluster). Set num_workers high enough to run your tests." + vagrantfile_ok=false +fi + +if [ "$vagrantfile_ok" == "true" ]; then + echo "Vagrantfile.local looks good." +fi diff --git a/tests/docker/Dockerfile b/tests/docker/Dockerfile new file mode 100644 index 0000000..51e8afd --- /dev/null +++ b/tests/docker/Dockerfile @@ -0,0 +1,107 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +ARG jdk_version=openjdk:8 +FROM $jdk_version + +MAINTAINER Apache Kafka dev@kafka.apache.org +VOLUME ["/opt/kafka-dev"] + +# Set the timezone. +ENV TZ="/usr/share/zoneinfo/America/Los_Angeles" + +# Do not ask for confirmations when running apt-get, etc. +ENV DEBIAN_FRONTEND noninteractive + +# Set the ducker.creator label so that we know that this is a ducker image. This will make it +# visible to 'ducker purge'. The ducker.creator label also lets us know what UNIX user built this +# image. +ARG ducker_creator=default +LABEL ducker.creator=$ducker_creator + +# Update Linux and install necessary utilities. +# we have to install git since it is included in openjdk:8 but not openjdk:11 +RUN apt update && apt install -y sudo git netcat iptables rsync unzip wget curl jq coreutils openssh-server net-tools vim python3-pip python3-dev libffi-dev libssl-dev cmake pkg-config libfuse-dev iperf traceroute && apt-get -y clean +RUN python3 -m pip install -U pip==21.1.1; +RUN pip3 install --upgrade cffi virtualenv pyasn1 boto3 pycrypto pywinrm ipaddress enum34 debugpy && pip3 install --upgrade ducktape==0.8.1 + +# Set up ssh +COPY ./ssh-config /root/.ssh/config +# NOTE: The paramiko library supports the PEM-format private key, but does not support the RFC4716 format. +RUN ssh-keygen -m PEM -q -t rsa -N '' -f /root/.ssh/id_rsa && cp -f /root/.ssh/id_rsa.pub /root/.ssh/authorized_keys +RUN echo 'PermitUserEnvironment yes' >> /etc/ssh/sshd_config + +# Install binary test dependencies. 
+# we use the same versions as in vagrant/base.sh +ARG KAFKA_MIRROR="https://s3-us-west-2.amazonaws.com/kafka-packages" +RUN mkdir -p "/opt/kafka-0.8.2.2" && chmod a+rw /opt/kafka-0.8.2.2 && curl -s "$KAFKA_MIRROR/kafka_2.11-0.8.2.2.tgz" | tar xz --strip-components=1 -C "/opt/kafka-0.8.2.2" +RUN mkdir -p "/opt/kafka-0.9.0.1" && chmod a+rw /opt/kafka-0.9.0.1 && curl -s "$KAFKA_MIRROR/kafka_2.11-0.9.0.1.tgz" | tar xz --strip-components=1 -C "/opt/kafka-0.9.0.1" +RUN mkdir -p "/opt/kafka-0.10.0.1" && chmod a+rw /opt/kafka-0.10.0.1 && curl -s "$KAFKA_MIRROR/kafka_2.11-0.10.0.1.tgz" | tar xz --strip-components=1 -C "/opt/kafka-0.10.0.1" +RUN mkdir -p "/opt/kafka-0.10.1.1" && chmod a+rw /opt/kafka-0.10.1.1 && curl -s "$KAFKA_MIRROR/kafka_2.11-0.10.1.1.tgz" | tar xz --strip-components=1 -C "/opt/kafka-0.10.1.1" +RUN mkdir -p "/opt/kafka-0.10.2.2" && chmod a+rw /opt/kafka-0.10.2.2 && curl -s "$KAFKA_MIRROR/kafka_2.11-0.10.2.2.tgz" | tar xz --strip-components=1 -C "/opt/kafka-0.10.2.2" +RUN mkdir -p "/opt/kafka-0.11.0.3" && chmod a+rw /opt/kafka-0.11.0.3 && curl -s "$KAFKA_MIRROR/kafka_2.11-0.11.0.3.tgz" | tar xz --strip-components=1 -C "/opt/kafka-0.11.0.3" +RUN mkdir -p "/opt/kafka-1.0.2" && chmod a+rw /opt/kafka-1.0.2 && curl -s "$KAFKA_MIRROR/kafka_2.11-1.0.2.tgz" | tar xz --strip-components=1 -C "/opt/kafka-1.0.2" +RUN mkdir -p "/opt/kafka-1.1.1" && chmod a+rw /opt/kafka-1.1.1 && curl -s "$KAFKA_MIRROR/kafka_2.11-1.1.1.tgz" | tar xz --strip-components=1 -C "/opt/kafka-1.1.1" +RUN mkdir -p "/opt/kafka-2.0.1" && chmod a+rw /opt/kafka-2.0.1 && curl -s "$KAFKA_MIRROR/kafka_2.12-2.0.1.tgz" | tar xz --strip-components=1 -C "/opt/kafka-2.0.1" +RUN mkdir -p "/opt/kafka-2.1.1" && chmod a+rw /opt/kafka-2.1.1 && curl -s "$KAFKA_MIRROR/kafka_2.12-2.1.1.tgz" | tar xz --strip-components=1 -C "/opt/kafka-2.1.1" +RUN mkdir -p "/opt/kafka-2.2.2" && chmod a+rw /opt/kafka-2.2.2 && curl -s "$KAFKA_MIRROR/kafka_2.12-2.2.2.tgz" | tar xz --strip-components=1 -C "/opt/kafka-2.2.2" +RUN mkdir -p "/opt/kafka-2.3.1" && chmod a+rw /opt/kafka-2.3.1 && curl -s "$KAFKA_MIRROR/kafka_2.12-2.3.1.tgz" | tar xz --strip-components=1 -C "/opt/kafka-2.3.1" +RUN mkdir -p "/opt/kafka-2.4.1" && chmod a+rw /opt/kafka-2.4.1 && curl -s "$KAFKA_MIRROR/kafka_2.12-2.4.1.tgz" | tar xz --strip-components=1 -C "/opt/kafka-2.4.1" +RUN mkdir -p "/opt/kafka-2.5.1" && chmod a+rw /opt/kafka-2.5.1 && curl -s "$KAFKA_MIRROR/kafka_2.12-2.5.1.tgz" | tar xz --strip-components=1 -C "/opt/kafka-2.5.1" +RUN mkdir -p "/opt/kafka-2.6.2" && chmod a+rw /opt/kafka-2.6.2 && curl -s "$KAFKA_MIRROR/kafka_2.12-2.6.2.tgz" | tar xz --strip-components=1 -C "/opt/kafka-2.6.2" +RUN mkdir -p "/opt/kafka-2.7.1" && chmod a+rw /opt/kafka-2.7.1 && curl -s "$KAFKA_MIRROR/kafka_2.12-2.7.1.tgz" | tar xz --strip-components=1 -C "/opt/kafka-2.7.1" +RUN mkdir -p "/opt/kafka-2.8.1" && chmod a+rw /opt/kafka-2.8.1 && curl -s "$KAFKA_MIRROR/kafka_2.12-2.8.1.tgz" | tar xz --strip-components=1 -C "/opt/kafka-2.8.1" + +# Streams test dependencies +RUN curl -s "$KAFKA_MIRROR/kafka-streams-0.10.0.1-test.jar" -o /opt/kafka-0.10.0.1/libs/kafka-streams-0.10.0.1-test.jar +RUN curl -s "$KAFKA_MIRROR/kafka-streams-0.10.1.1-test.jar" -o /opt/kafka-0.10.1.1/libs/kafka-streams-0.10.1.1-test.jar +RUN curl -s "$KAFKA_MIRROR/kafka-streams-0.10.2.2-test.jar" -o /opt/kafka-0.10.2.2/libs/kafka-streams-0.10.2.2-test.jar +RUN curl -s "$KAFKA_MIRROR/kafka-streams-0.11.0.3-test.jar" -o /opt/kafka-0.11.0.3/libs/kafka-streams-0.11.0.3-test.jar +RUN curl -s "$KAFKA_MIRROR/kafka-streams-1.0.2-test.jar" -o 
/opt/kafka-1.0.2/libs/kafka-streams-1.0.2-test.jar +RUN curl -s "$KAFKA_MIRROR/kafka-streams-1.1.1-test.jar" -o /opt/kafka-1.1.1/libs/kafka-streams-1.1.1-test.jar +RUN curl -s "$KAFKA_MIRROR/kafka-streams-2.0.1-test.jar" -o /opt/kafka-2.0.1/libs/kafka-streams-2.0.1-test.jar +RUN curl -s "$KAFKA_MIRROR/kafka-streams-2.1.1-test.jar" -o /opt/kafka-2.1.1/libs/kafka-streams-2.1.1-test.jar +RUN curl -s "$KAFKA_MIRROR/kafka-streams-2.2.2-test.jar" -o /opt/kafka-2.2.2/libs/kafka-streams-2.2.2-test.jar +RUN curl -s "$KAFKA_MIRROR/kafka-streams-2.3.1-test.jar" -o /opt/kafka-2.3.1/libs/kafka-streams-2.3.1-test.jar +RUN curl -s "$KAFKA_MIRROR/kafka-streams-2.4.1-test.jar" -o /opt/kafka-2.4.1/libs/kafka-streams-2.4.1-test.jar +RUN curl -s "$KAFKA_MIRROR/kafka-streams-2.5.1-test.jar" -o /opt/kafka-2.5.1/libs/kafka-streams-2.5.1-test.jar +RUN curl -s "$KAFKA_MIRROR/kafka-streams-2.6.2-test.jar" -o /opt/kafka-2.6.2/libs/kafka-streams-2.6.2-test.jar +RUN curl -s "$KAFKA_MIRROR/kafka-streams-2.7.1-test.jar" -o /opt/kafka-2.7.1/libs/kafka-streams-2.7.1-test.jar +RUN curl -s "$KAFKA_MIRROR/kafka-streams-2.8.1-test.jar" -o /opt/kafka-2.8.1/libs/kafka-streams-2.8.1-test.jar + +# The version of Kibosh to use for testing. +# If you update this, also update vagrant/base.sh +ARG KIBOSH_VERSION="8841dd392e6fbf02986e2fb1f1ebf04df344b65a" + +# Aligning uid inside/outside docker enables containers to modify files of kafka source (mounted in /opt/kafka-dev") +# By default, the outside user id is 1000 (UID_MIN). The known exception in QA is travis which gives non-1000 id. +ARG UID="1000" + +# Install Kibosh +RUN apt-get install fuse +RUN cd /opt && git clone -q https://github.com/confluentinc/kibosh.git && cd "/opt/kibosh" && git reset --hard $KIBOSH_VERSION && mkdir "/opt/kibosh/build" && cd "/opt/kibosh/build" && ../configure && make -j 2 + +# Set up the ducker user. +RUN useradd -u $UID -ms /bin/bash ducker \ + && mkdir -p /home/ducker/ \ + && rsync -aiq /root/.ssh/ /home/ducker/.ssh \ + && chown -R ducker /home/ducker/ /mnt/ /var/log/ \ + && echo "PATH=$(runuser -l ducker -c 'echo $PATH'):$JAVA_HOME/bin" >> /home/ducker/.ssh/environment \ + && echo 'PATH=$PATH:'"$JAVA_HOME/bin" >> /home/ducker/.profile \ + && echo 'ducker ALL=(ALL) NOPASSWD: ALL' >> /etc/sudoers + +USER ducker + +CMD sudo service ssh start && tail -f /dev/null diff --git a/tests/docker/ducker-ak b/tests/docker/ducker-ak new file mode 100755 index 0000000..8047b8f --- /dev/null +++ b/tests/docker/ducker-ak @@ -0,0 +1,627 @@ +#!/usr/bin/env bash + +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# +# Ducker-AK: a tool for running Apache Kafka system tests inside Docker images. +# +# Note: this should be compatible with the version of bash that ships on most +# Macs, bash 3.2.57. 
+# + +script_path="${0}" + +# The absolute path to the directory which this script is in. This will also be the directory +# which we run docker build from. +ducker_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" + +# The absolute path to the root Kafka directory +kafka_dir="$( cd "${ducker_dir}/../.." && pwd )" + +# The memory consumption to allow during the docker build. +# This does not include swap. +docker_build_memory_limit="3200m" + +# The maximum mmemory consumption to allow in containers. +docker_run_memory_limit="2000m" + +# The default number of cluster nodes to bring up if a number is not specified. +default_num_nodes=14 + +# The default OpenJDK base image. +default_jdk="openjdk:8" + +# The default ducker-ak image name. +default_image_name="ducker-ak" + +# Port to listen on when debugging +debugpy_port=5678 + +# Display a usage message on the terminal and exit. +# +# $1: The exit status to use +usage() { + local exit_status="${1}" + cat < /dev/null || die "You must install ${cmd} to run this script." + done +} + +# Set a global variable to a value. +# +# $1: The variable name to set. This function will die if the variable already has a value. The +# variable will be made readonly to prevent any future modifications. +# $2: The value to set the variable to. This function will die if the value is empty or starts +# with a dash. +# $3: A human-readable description of the variable. +set_once() { + local key="${1}" + local value="${2}" + local what="${3}" + [[ -n "${!key}" ]] && die "Error: more than one value specified for ${what}." + verify_command_line_argument "${value}" "${what}" + # It would be better to use declare -g, but older bash versions don't support it. + export ${key}="${value}" +} + +# Verify that a command-line argument is present and does not start with a slash. +# +# $1: The command-line argument to verify. +# $2: A human-readable description of the variable. +verify_command_line_argument() { + local value="${1}" + local what="${2}" + [[ -n "${value}" ]] || die "Error: no value specified for ${what}" + [[ ${value} == -* ]] && die "Error: invalid value ${value} specified for ${what}" +} + +# Echo a message if a flag is set. +# +# $1: If this is 1, the message will be echoed. +# $@: The message +maybe_echo() { + local verbose="${1}" + shift + [[ "${verbose}" -eq 1 ]] && echo "${@}" +} + +# Counts the number of elements passed to this subroutine. +count() { + echo $# +} + +# Push a new directory on to the bash directory stack, or exit with a failure message. +# +# $1: The directory push on to the directory stack. +must_pushd() { + local target_dir="${1}" + pushd -- "${target_dir}" &> /dev/null || die "failed to change directory to ${target_dir}" +} + +# Pop a directory from the bash directory stack, or exit with a failure message. +must_popd() { + popd &> /dev/null || die "failed to popd" +} + +echo_and_do() { + local cmd="${@}" + echo "${cmd}" + ${cmd} +} + +# Run a command and die if it fails. +# +# Optional flags: +# -v: print the command before running it. +# -o: display the command output. +# $@: The command to run. +must_do() { + local verbose=0 + local output="/dev/null" + while true; do + case ${1} in + -v) verbose=1; shift;; + -o) output="/dev/stdout"; shift;; + *) break;; + esac + done + local cmd="${@}" + [[ "${verbose}" -eq 1 ]] && echo "${cmd}" + ${cmd} >${output} || die "${1} failed" +} + +# Ask the user a yes/no question. +# +# $1: The prompt to use +# $_return: 0 if the user answered no; 1 if the user answered yes. 
+ask_yes_no() { + local prompt="${1}" + while true; do + read -r -p "${prompt} " response + case "${response}" in + [yY]|[yY][eE][sS]) _return=1; return;; + [nN]|[nN][oO]) _return=0; return;; + *);; + esac + echo "Please respond 'yes' or 'no'." + echo + done +} + +# Build a docker image. +# +# $1: The name of the image to build. +ducker_build() { + local image_name="${1}" + + # Use SECONDS, a builtin bash variable that gets incremented each second, to measure the docker + # build duration. + SECONDS=0 + + must_pushd "${ducker_dir}" + # Tip: if you are scratching your head for some dependency problems that are referring to an old code version + # (for example java.lang.NoClassDefFoundError), add --no-cache flag to the build shall give you a clean start. + echo_and_do docker build --memory="${docker_build_memory_limit}" \ + --build-arg "ducker_creator=${user_name}" \ + --build-arg "jdk_version=${jdk_version}" \ + --build-arg "UID=${UID}" \ + -t "${image_name}" \ + -f "${ducker_dir}/Dockerfile" ${docker_args} -- . + docker_status=$? + must_popd + duration="${SECONDS}" + if [[ ${docker_status} -ne 0 ]]; then + echo "ERROR: Failed to build ${what} image after $((${duration} / 60))m \ +$((${duration} % 60))s." + echo "If this error is unexpected, consider running 'docker system prune -a' \ +to clear old images from your local cache." + exit 1 + fi + echo "Successfully built ${what} image in $((${duration} / 60))m \ +$((${duration} % 60))s. See ${build_log} for details." +} + +docker_run() { + local node=${1} + local image_name=${2} + local ports_option=${3} + local port_mapping=${4} + + local expose_ports="" + if [[ -n ${ports_option} ]]; then + expose_ports="-P" + for expose_port in ${ports_option//,/ }; do + expose_ports="${expose_ports} --expose ${expose_port}" + done + fi + if [[ -n ${port_mapping} ]]; then + expose_ports="${expose_ports} -p ${port_mapping}:${port_mapping}" + fi + + # Invoke docker-run. We need privileged mode to be able to run iptables + # and mount FUSE filesystems inside the container. We also need it to + # run iptables inside the container. + must_do -v docker run --privileged \ + -d -t -h "${node}" --network ducknet "${expose_ports}" \ + --memory=${docker_run_memory_limit} --memory-swappiness=1 \ + -v "${kafka_dir}:/opt/kafka-dev" --name "${node}" -- "${image_name}" +} + +setup_custom_ducktape() { + local custom_ducktape="${1}" + local image_name="${2}" + + [[ -f "${custom_ducktape}/ducktape/__init__.py" ]] || \ + die "You must supply a valid ducktape directory to --custom-ducktape" + docker_run ducker01 "${image_name}" + local running_container="$(docker ps -f=network=ducknet -q)" + must_do -v -o docker cp "${custom_ducktape}" "${running_container}:/opt/ducktape" + docker exec --user=root ducker01 bash -c 'set -x && cd /opt/kafka-dev/tests && sudo python3 ./setup.py develop install && cd /opt/ducktape && sudo python3 ./setup.py develop install' + [[ $? -ne 0 ]] && die "failed to install the new ducktape." 
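# The container with the freshly installed ducktape is committed back over the image below, so
# every node started afterwards inherits the custom ducktape checkout. An illustrative invocation
# that exercises this path (the checkout path is hypothetical):
#
#   ./tests/docker/ducker-ak up --custom-ducktape /path/to/ducktape ducker-ak-openjdk-8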
+ must_do -v -o docker commit ducker01 "${image_name}" + must_do -v docker kill "${running_container}" + must_do -v docker rm ducker01 +} + +ducker_up() { + require_commands docker + while [[ $# -ge 1 ]]; do + case "${1}" in + -C|--custom-ducktape) set_once custom_ducktape "${2}" "the custom ducktape directory"; shift 2;; + -f|--force) force=1; shift;; + -n|--num-nodes) set_once num_nodes "${2}" "number of nodes"; shift 2;; + -j|--jdk) set_once jdk_version "${2}" "the OpenJDK base image"; shift 2;; + -e|--expose-ports) set_once expose_ports "${2}" "the ports to expose"; shift 2;; + *) set_once image_name "${1}" "docker image name"; shift;; + esac + done + [[ -n "${num_nodes}" ]] || num_nodes="${default_num_nodes}" + [[ -n "${jdk_version}" ]] || jdk_version="${default_jdk}" + [[ -n "${image_name}" ]] || image_name="${default_image_name}-${jdk_version/:/-}" + [[ "${num_nodes}" =~ ^-?[0-9]+$ ]] || \ + die "ducker_up: the number of nodes must be an integer." + [[ "${num_nodes}" -gt 0 ]] || die "ducker_up: the number of nodes must be greater than 0." + if [[ "${num_nodes}" -lt 2 ]]; then + if [[ "${force}" -ne 1 ]]; then + echo "ducker_up: It is recommended to run at least 2 nodes, since ducker01 is only \ +used to run ducktape itself. If you want to do it anyway, you can use --force to attempt to \ +use only ${num_nodes}." + exit 1 + fi + fi + + docker ps >/dev/null || die "ducker_up: failed to run docker. Please check that the daemon is started." + + ducker_build "${image_name}" + + docker inspect --format='{{.Config.Labels}}' --type=image "${image_name}" | grep -q 'ducker.type' + local docker_status=${PIPESTATUS[0]} + local grep_status=${PIPESTATUS[1]} + [[ "${docker_status}" -eq 0 ]] || die "ducker_up: failed to inspect image ${image_name}. \ +Please check that it exists." + if [[ "${grep_status}" -ne 0 ]]; then + if [[ "${force}" -ne 1 ]]; then + echo "ducker_up: ${image_name} does not appear to be a ducker image. It lacks the \ +ducker.type label. If you think this is a mistake, you can use --force to attempt to bring \ +it up anyway." + exit 1 + fi + fi + local running_containers="$(docker ps -f=network=ducknet -q)" + local num_running_containers=$(count ${running_containers}) + if [[ ${num_running_containers} -gt 0 ]]; then + die "ducker_up: there are ${num_running_containers} ducker containers \ +running already. Use ducker down to bring down these containers before \ +attempting to start new ones." + fi + + echo "ducker_up: Bringing up ${image_name} with ${num_nodes} nodes..." + if docker network inspect ducknet &>/dev/null; then + must_do -v docker network rm ducknet + fi + must_do -v docker network create ducknet + if [[ -n "${custom_ducktape}" ]]; then + setup_custom_ducktape "${custom_ducktape}" "${image_name}" + fi + docker_run ducker01 "${image_name}" "${expose_ports}" "${debugpy_port}" + for n in $(seq -f %02g 2 ${num_nodes}); do + local node="ducker${n}" + docker_run "${node}" "${image_name}" "${expose_ports}" + done + mkdir -p "${ducker_dir}/build" + exec 3<> "${ducker_dir}/build/node_hosts" + for n in $(seq -f %02g 1 ${num_nodes}); do + local node="ducker${n}" + docker exec --user=root "${node}" grep "${node}" /etc/hosts >&3 + [[ $? -ne 0 ]] && die "failed to find the /etc/hosts entry for ${node}" + done + exec 3>&- + for n in $(seq -f %02g 1 ${num_nodes}); do + local node="ducker${n}" + docker exec --user=root "${node}" \ + bash -c "grep -v ${node} /opt/kafka-dev/tests/docker/build/node_hosts >> /etc/hosts" + [[ $? 
-ne 0 ]] && die "failed to append to the /etc/hosts file on ${node}" + done + + echo "ducker_up: added the latest entries to /etc/hosts on each node." + generate_cluster_json_file "${num_nodes}" "${ducker_dir}/build/cluster.json" + echo "ducker_up: successfully wrote ${ducker_dir}/build/cluster.json" + echo "** ducker_up: successfully brought up ${num_nodes} nodes." +} + +# Generate the cluster.json file used by ducktape to identify cluster nodes. +# +# $1: The number of cluster nodes. +# $2: The path to write the cluster.json file to. +generate_cluster_json_file() { + local num_nodes="${1}" + local path="${2}" + exec 3<> "${path}" +cat<&3 +{ + "_comment": [ + "Licensed to the Apache Software Foundation (ASF) under one or more", + "contributor license agreements. See the NOTICE file distributed with", + "this work for additional information regarding copyright ownership.", + "The ASF licenses this file to You under the Apache License, Version 2.0", + "(the \"License\"); you may not use this file except in compliance with", + "the License. You may obtain a copy of the License at", + "", + "http://www.apache.org/licenses/LICENSE-2.0", + "", + "Unless required by applicable law or agreed to in writing, software", + "distributed under the License is distributed on an \"AS IS\" BASIS,", + "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.", + "See the License for the specific language governing permissions and", + "limitations under the License." + ], + "nodes": [ +EOF + for n in $(seq 2 ${num_nodes}); do + if [[ ${n} -eq ${num_nodes} ]]; then + suffix="" + else + suffix="," + fi + local node=$(printf ducker%02d ${n}) +cat<&3 + { + "externally_routable_ip": "${node}", + "ssh_config": { + "host": "${node}", + "hostname": "${node}", + "identityfile": "/home/ducker/.ssh/id_rsa", + "password": "", + "port": 22, + "user": "ducker" + } + }${suffix} +EOF + done +cat<&3 + ] +} +EOF + exec 3>&- +} + +ducker_test() { + require_commands docker + docker inspect ducker01 &>/dev/null || \ + die "ducker_test: the ducker01 instance appears to be down. Did you run 'ducker up'?" + declare -a test_name_args=() + local debug=0 + while [[ $# -ge 1 ]]; do + case "${1}" in + -d|--debug) debug=1; shift;; + --) shift; break;; + *) test_name_args+=("${1}"); shift;; + esac + done + local ducktape_args=${*} + + [[ ${#test_name_args} -lt 1 ]] && \ + die "ducker_test: you must supply at least one system test to run. Type --help for help." + local test_names="" + + for test_name in ${test_name_args[*]}; do + local regex=".*\/kafkatest\/(.*)" + if [[ $test_name =~ $regex ]]; then + local kpath=${BASH_REMATCH[1]} + test_names="${test_names} ./tests/kafkatest/${kpath}" + else + test_names="${test_names} ${test_name}" + fi + done + + must_pushd "${kafka_dir}" + (test -f ./gradlew || gradle) && ./gradlew systemTestLibs + must_popd + if [[ "${debug}" -eq 1 ]]; then + local ducktape_cmd="python3 -m debugpy --listen 0.0.0.0:${debugpy_port} --wait-for-client /usr/local/bin/ducktape" + else + local ducktape_cmd="ducktape" + fi + + cmd="cd /opt/kafka-dev && ${ducktape_cmd} --cluster-file /opt/kafka-dev/tests/docker/build/cluster.json $test_names $ducktape_args" + echo "docker exec ducker01 bash -c \"${cmd}\"" + exec docker exec --user=ducker ducker01 bash -c "${cmd}" +} + +ducker_ssh() { + require_commands docker + [[ $# -eq 0 ]] && die "ducker_ssh: Please specify a container name to log into. 
\ +Currently active containers: $(echo_running_container_names)" + local node_info="${1}" + shift + local guest_command="$*" + local user_name="ducker" + if [[ "${node_info}" =~ @ ]]; then + user_name="${node_info%%@*}" + local node_name="${node_info##*@}" + else + local node_name="${node_info}" + fi + local docker_flags="" + if [[ -z "${guest_command}" ]]; then + local docker_flags="${docker_flags} -t" + local guest_command_prefix="" + guest_command=bash + else + local guest_command_prefix="bash -c" + fi + if [[ "${node_name}" == "all" ]]; then + local nodes=$(echo_running_container_names) + [[ "${nodes}" == "(none)" ]] && die "ducker_ssh: can't locate any running ducker nodes." + for node in ${nodes}; do + docker exec --user=${user_name} -i ${docker_flags} "${node}" \ + ${guest_command_prefix} "${guest_command}" || die "docker exec ${node} failed" + done + else + docker inspect --type=container -- "${node_name}" &>/dev/null || \ + die "ducker_ssh: can't locate node ${node_name}. Currently running nodes: \ +$(echo_running_container_names)" + exec docker exec --user=${user_name} -i ${docker_flags} "${node_name}" \ + ${guest_command_prefix} "${guest_command}" + fi +} + +# Echo all the running Ducker container names, or (none) if there are no running Ducker containers. +echo_running_container_names() { + node_names="$(docker ps -f=network=ducknet -q --format '{{.Names}}' | sort)" + if [[ -z "${node_names}" ]]; then + echo "(none)" + else + echo ${node_names//$'\n'/ } + fi +} + +ducker_down() { + require_commands docker + local verbose=1 + local force_str="" + while [[ $# -ge 1 ]]; do + case "${1}" in + -q|--quiet) verbose=0; shift;; + -f|--force) force_str="-f"; shift;; + *) die "ducker_down: unexpected command-line argument ${1}";; + esac + done + local running_containers + running_containers="$(docker ps -f=network=ducknet -q)" + [[ $? -eq 0 ]] || die "ducker_down: docker command failed. Is the docker daemon running?" + running_containers=${running_containers//$'\n'/ } + local all_containers="$(docker ps -a -f=network=ducknet -q)" + all_containers=${all_containers//$'\n'/ } + if [[ -z "${all_containers}" ]]; then + maybe_echo "${verbose}" "No ducker containers found." + return + fi + verbose_flag="" + if [[ ${verbose} == 1 ]]; then + verbose_flag="-v" + fi + if [[ -n "${running_containers}" ]]; then + must_do ${verbose_flag} docker kill "${running_containers}" + fi + must_do ${verbose_flag} docker rm ${force_str} "${all_containers}" + must_do ${verbose_flag} -o rm -f -- "${ducker_dir}/build/node_hosts" "${ducker_dir}/build/cluster.json" + if docker network inspect ducknet &>/dev/null; then + must_do -v docker network rm ducknet + fi + maybe_echo "${verbose}" "ducker_down: removed $(count ${all_containers}) containers." +} + +ducker_purge() { + require_commands docker + local force_str="" + while [[ $# -ge 1 ]]; do + case "${1}" in + -f|--force) force_str="-f"; shift;; + *) die "ducker_purge: unknown argument ${1}";; + esac + done + echo "** ducker_purge: attempting to locate ducker images to purge" + local images + images=$(docker images -q -a -f label=ducker.creator) + [[ $? -ne 0 ]] && die "docker images command failed" + images=${images//$'\n'/ } + declare -a purge_images=() + if [[ -z "${images}" ]]; then + echo "** ducker_purge: no images found to purge." + exit 0 + fi + echo "** ducker_purge: images to delete:" + for image in ${images}; do + echo -n "${image} " + docker inspect --format='{{.Config.Labels}} {{.Created}}' --type=image "${image}" + [[ $? 
-ne 0 ]] && die "docker inspect ${image} failed" + done + ask_yes_no "Delete these docker images? [y/n]" + [[ "${_return}" -eq 0 ]] && exit 0 + must_do -v -o docker rmi ${force_str} ${images} +} + +# Parse command-line arguments +[[ $# -lt 1 ]] && usage 0 +# Display the help text if -h or --help appears in the command line +for arg in ${@}; do + case "${arg}" in + -h|--help) usage 0;; + --) break;; + *);; + esac +done +action="${1}" +shift +case "${action}" in + help) usage 0;; + + up|test|ssh|down|purge) + ducker_${action} "${@}"; exit 0;; + + *) echo "Unknown command '${action}'. Type '${script_path} --help' for usage information." + exit 1;; +esac diff --git a/tests/docker/run_tests.sh b/tests/docker/run_tests.sh new file mode 100755 index 0000000..0128fd6 --- /dev/null +++ b/tests/docker/run_tests.sh @@ -0,0 +1,38 @@ +#!/usr/bin/env bash + +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" +KAFKA_NUM_CONTAINERS=${KAFKA_NUM_CONTAINERS:-14} +TC_PATHS=${TC_PATHS:-./kafkatest/} +REBUILD=${REBUILD:f} + +die() { + echo $@ + exit 1 +} + +if [ "$REBUILD" == "t" ]; then + ./gradlew clean systemTestLibs +fi + +if ${SCRIPT_DIR}/ducker-ak ssh | grep -q '(none)'; then + ${SCRIPT_DIR}/ducker-ak up -n "${KAFKA_NUM_CONTAINERS}" || die "ducker-ak up failed" +fi + +[[ -n ${_DUCKTAPE_OPTIONS} ]] && _DUCKTAPE_OPTIONS="-- ${_DUCKTAPE_OPTIONS}" + +${SCRIPT_DIR}/ducker-ak test ${TC_PATHS} ${_DUCKTAPE_OPTIONS} || die "ducker-ak test failed" diff --git a/tests/docker/ssh-config b/tests/docker/ssh-config new file mode 100644 index 0000000..1f87417 --- /dev/null +++ b/tests/docker/ssh-config @@ -0,0 +1,21 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +Host * + ControlMaster auto + ControlPath ~/.ssh/master-%r@%h:%p + StrictHostKeyChecking no + ConnectTimeout=10 + IdentityFile ~/.ssh/id_rsa diff --git a/tests/docker/ssh/authorized_keys b/tests/docker/ssh/authorized_keys new file mode 100644 index 0000000..9f9da1f --- /dev/null +++ b/tests/docker/ssh/authorized_keys @@ -0,0 +1,15 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQC0qDT9kEPWc8JQ53b4KnT/ZJOLwb+3c//jpLW/2ofjDyIsPW4FohLpicfouch/zsRpN4G38lua+2BsGls9sMIZc6PXY2L+NIGCkqEMdCoU1Ym8SMtyJklfzp3m/0PeK9s2dLlR3PFRYvyFA4btQK5hkbYDNZPzf4airvzdRzLkrFf81+RemaMI2EtONwJRcbLViPaTXVKJdbFwJTJ1u7yu9wDYWHKBMA92mHTQeP6bhVYCqxJn3to/RfZYd+sHw6mfxVg5OrAlUOYpSV4pDNCAsIHdtZ56V8NQlJL6NJ2vzzSSYUwLMqe88fhrC8yYHoxC07QPy1EdkSTHdohAicyT root@knode01.knw diff --git a/tests/docker/ssh/config b/tests/docker/ssh/config new file mode 100644 index 0000000..1f87417 --- /dev/null +++ b/tests/docker/ssh/config @@ -0,0 +1,21 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +Host * + ControlMaster auto + ControlPath ~/.ssh/master-%r@%h:%p + StrictHostKeyChecking no + ConnectTimeout=10 + IdentityFile ~/.ssh/id_rsa diff --git a/tests/docker/ssh/id_rsa b/tests/docker/ssh/id_rsa new file mode 100644 index 0000000..276e07b --- /dev/null +++ b/tests/docker/ssh/id_rsa @@ -0,0 +1,27 @@ +-----BEGIN RSA PRIVATE KEY----- +MIIEpQIBAAKCAQEAtKg0/ZBD1nPCUOd2+Cp0/2STi8G/t3P/46S1v9qH4w8iLD1u +BaIS6YnH6LnIf87EaTeBt/JbmvtgbBpbPbDCGXOj12Ni/jSBgpKhDHQqFNWJvEjL +ciZJX86d5v9D3ivbNnS5UdzxUWL8hQOG7UCuYZG2AzWT83+Goq783Ucy5KxX/Nfk +XpmjCNhLTjcCUXGy1Yj2k11SiXWxcCUydbu8rvcA2FhygTAPdph00Hj+m4VWAqsS +Z97aP0X2WHfrB8Opn8VYOTqwJVDmKUleKQzQgLCB3bWeelfDUJSS+jSdr880kmFM +CzKnvPH4awvMmB6MQtO0D8tRHZEkx3aIQInMkwIDAQABAoIBAQCz6EMFNNLp0NP1 +X9yRXS6wW4e4CRWUazesiw3YZpcmnp6IchCMGZA99FEZyVILPW1J3tYWyotBdw7Z ++RFeCRXy5L+IMtiVkNJcpwss7M4ve0w0LkY0gj5V49xJ+3Gp4gDnZSxcguvrAem5 +yP5obR572fDpl0SknB4HCr6U2l+rauzrLyevy5eeDT/vmXbuM1cdHpNIXmmElz4L +t31n+exQRn6tP1h516iXbcYbopxDgdv2qKGAqzWKE6TyWpzF5x7kjOEYt0bZ5QO3 +Lwh7AAqE/3mwxlYwng1L4WAT7RtcP19W+9JDIc7ENInMGxq6q46p1S3IPZsf1cj/ +aAJ9q3LBAoGBAOVJr0+WkR786n3BuswpGQWBgVxfai4y9Lf90vuGKawdQUzXv0/c +EB/CFqP/dIsquukA8PfzjNMyTNmEHXi4Sf16H8Rg4EGhIYMEqIQojx1t/yLLm0aU +YPEvW/02Umtlg3pJw9fQAAzFVqCasw2E2lUdAUkydGRwDUJZmv2/b3NzAoGBAMm0 +Jo7Et7ochH8Vku6uA+hG+RdwlKFm5JA7/Ci3DOdQ1zmJNrvBBFQLo7AjA4iSCoBd +s9+y0nrSPcF4pM3l6ghLheaqbnIi2HqIMH9mjDbrOZiWvbnjvjpOketgNX8vV3Ye +GUkSjoNcmvRmdsICmUjeML8bGOmq4zF9W/GIfTphAoGBAKGRo8R8f/SLGh3VtvCI +gUY89NAHuEWnyIQii1qMNq8+yjYAzaHTm1UVqmiT6SbrzFvGOwcuCu0Dw91+2Fmp +2xGPzfTOoxf8GCY/0ROXlQmS6jc1rEw24Hzz92ldrwRYuyYf9q4Ltw1IvXtcp5F+ +LW/OiYpv0E66Gs3HYI0wKbP7AoGBAJMZWeFW37LQJ2TTJAQDToAwemq4xPxsoJX7 +2SsMTFHKKBwi0JLe8jwk/OxwrJwF/bieHZcvv8ao2zbkuDQcz6/a/D074C5G8V9z +QQM4k1td8vQwQw91Yv782/gvgvRNX1iaHNCowtxURgGlVEirQoTc3eoRZfrLkMM/ +7DTa2JEhAoGACEu3zHJ1sgyeOEgLArUJXlQM30A/ulMrnCd4MEyIE+ReyWAUevUQ +0lYdVNva0/W4C5e2lUOJL41jjIPLqI7tcFR2PZE6n0xTTkxNH5W2u1WpFeKjx+O3 +czv7Bt6wYyLHIMy1JEqAQ7pw1mtJ5s76UDvXUhciF+DU2pWYc6APKR0= +-----END RSA PRIVATE KEY----- diff --git a/tests/docker/ssh/id_rsa.pub b/tests/docker/ssh/id_rsa.pub new file mode 100644 index 0000000..76e8f5f --- /dev/null +++ b/tests/docker/ssh/id_rsa.pub @@ -0,0 +1 @@ +ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQC0qDT9kEPWc8JQ53b4KnT/ZJOLwb+3c//jpLW/2ofjDyIsPW4FohLpicfouch/zsRpN4G38lua+2BsGls9sMIZc6PXY2L+NIGCkqEMdCoU1Ym8SMtyJklfzp3m/0PeK9s2dLlR3PFRYvyFA4btQK5hkbYDNZPzf4airvzdRzLkrFf81+RemaMI2EtONwJRcbLViPaTXVKJdbFwJTJ1u7yu9wDYWHKBMA92mHTQeP6bhVYCqxJn3to/RfZYd+sHw6mfxVg5OrAlUOYpSV4pDNCAsIHdtZ56V8NQlJL6NJ2vzzSSYUwLMqe88fhrC8yYHoxC07QPy1EdkSTHdohAicyT root@knode01.knw diff --git a/tests/kafkatest/__init__.py b/tests/kafkatest/__init__.py new file mode 100644 index 0000000..7f331f9 --- /dev/null +++ b/tests/kafkatest/__init__.py @@ -0,0 +1,25 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +# This determines the version of kafkatest that can be published to PyPi and installed with pip +# +# Note that in development, this version name can't follow Kafka's convention of having a trailing "-SNAPSHOT" +# due to python version naming restrictions, which are enforced by python packaging tools +# (see https://www.python.org/dev/peps/pep-0440/) +# +# Instead, in development branches, the version should have a suffix of the form ".devN" +# +# For example, when Kafka is at version 1.0.0-SNAPSHOT, this should be something like "1.0.0.dev0" +__version__ = '3.1.0' diff --git a/tests/kafkatest/benchmarks/__init__.py b/tests/kafkatest/benchmarks/__init__.py new file mode 100644 index 0000000..ec20143 --- /dev/null +++ b/tests/kafkatest/benchmarks/__init__.py @@ -0,0 +1,14 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/kafkatest/benchmarks/core/__init__.py b/tests/kafkatest/benchmarks/core/__init__.py new file mode 100644 index 0000000..ec20143 --- /dev/null +++ b/tests/kafkatest/benchmarks/core/__init__.py @@ -0,0 +1,14 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/kafkatest/benchmarks/core/benchmark_test.py b/tests/kafkatest/benchmarks/core/benchmark_test.py new file mode 100644 index 0000000..5e7bdde --- /dev/null +++ b/tests/kafkatest/benchmarks/core/benchmark_test.py @@ -0,0 +1,281 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ducktape.mark import matrix +from ducktape.mark import parametrize +from ducktape.mark.resource import cluster +from ducktape.services.service import Service +from ducktape.tests.test import Test + +from kafkatest.services.kafka import KafkaService +from kafkatest.services.performance import ProducerPerformanceService, EndToEndLatencyService, ConsumerPerformanceService, throughput, latency, compute_aggregate_throughput +from kafkatest.services.zookeeper import ZookeeperService +from kafkatest.version import DEV_BRANCH, KafkaVersion + +TOPIC_REP_ONE = "topic-replication-factor-one" +TOPIC_REP_THREE = "topic-replication-factor-three" +DEFAULT_RECORD_SIZE = 100 # bytes + + +class Benchmark(Test): + """A benchmark of Kafka producer/consumer performance. This replicates the test + run here: + https://engineering.linkedin.com/kafka/benchmarking-apache-kafka-2-million-writes-second-three-cheap-machines + """ + def __init__(self, test_context): + super(Benchmark, self).__init__(test_context) + self.num_zk = 1 + self.num_brokers = 3 + self.topics = { + TOPIC_REP_ONE: {'partitions': 6, 'replication-factor': 1}, + TOPIC_REP_THREE: {'partitions': 6, 'replication-factor': 3} + } + + self.zk = ZookeeperService(test_context, self.num_zk) + + self.msgs_large = 10000000 + self.batch_size = 8*1024 + self.buffer_memory = 64*1024*1024 + self.msg_sizes = [10, 100, 1000, 10000, 100000] + self.target_data_size = 128*1024*1024 + self.target_data_size_gb = self.target_data_size/float(1024*1024*1024) + + def setUp(self): + self.zk.start() + + def start_kafka(self, security_protocol, interbroker_security_protocol, version, tls_version=None): + self.kafka = KafkaService( + self.test_context, self.num_brokers, + self.zk, security_protocol=security_protocol, + interbroker_security_protocol=interbroker_security_protocol, topics=self.topics, + version=version, tls_version=tls_version) + self.kafka.log_level = "INFO" # We don't DEBUG logging here + self.kafka.start() + + @cluster(num_nodes=5) + @parametrize(acks=1, topic=TOPIC_REP_ONE) + @parametrize(acks=1, topic=TOPIC_REP_THREE) + @parametrize(acks=-1, topic=TOPIC_REP_THREE) + @matrix(acks=[1], topic=[TOPIC_REP_THREE], message_size=[10, 100, 1000, 10000, 100000], compression_type=["none", "snappy"], security_protocol=['SSL'], tls_version=['TLSv1.2', 'TLSv1.3']) + @matrix(acks=[1], topic=[TOPIC_REP_THREE], message_size=[10, 100, 1000, 10000, 100000], compression_type=["none", "snappy"], security_protocol=['PLAINTEXT']) + @cluster(num_nodes=7) + @parametrize(acks=1, topic=TOPIC_REP_THREE, num_producers=3) + def test_producer_throughput(self, acks, topic, num_producers=1, message_size=DEFAULT_RECORD_SIZE, + compression_type="none", security_protocol='PLAINTEXT', tls_version=None, client_version=str(DEV_BRANCH), + broker_version=str(DEV_BRANCH)): + """ + Setup: 1 node zk + 3 node kafka cluster + Produce ~128MB worth of messages to a topic with 6 partitions. Required acks, topic replication factor, + security protocol and message size are varied depending on arguments injected into this test. 
+ + Collect and return aggregate throughput statistics after all messages have been acknowledged. + (This runs ProducerPerformance.java under the hood) + """ + client_version = KafkaVersion(client_version) + broker_version = KafkaVersion(broker_version) + self.validate_versions(client_version, broker_version) + self.start_kafka(security_protocol, security_protocol, broker_version, tls_version) + # Always generate the same total amount of data + nrecords = int(self.target_data_size / message_size) + + self.producer = ProducerPerformanceService( + self.test_context, num_producers, self.kafka, topic=topic, + num_records=nrecords, record_size=message_size, throughput=-1, version=client_version, + settings={ + 'acks': acks, + 'compression.type': compression_type, + 'batch.size': self.batch_size, + 'buffer.memory': self.buffer_memory}) + self.producer.run() + return compute_aggregate_throughput(self.producer) + + @cluster(num_nodes=5) + @matrix(security_protocol=['SSL'], interbroker_security_protocol=['PLAINTEXT'], tls_version=['TLSv1.2', 'TLSv1.3'], compression_type=["none", "snappy"]) + @matrix(security_protocol=['PLAINTEXT'], compression_type=["none", "snappy"]) + def test_long_term_producer_throughput(self, compression_type="none", + security_protocol='PLAINTEXT', tls_version=None, + interbroker_security_protocol=None, client_version=str(DEV_BRANCH), + broker_version=str(DEV_BRANCH)): + """ + Setup: 1 node zk + 3 node kafka cluster + Produce 10e6 100 byte messages to a topic with 6 partitions, replication-factor 3, and acks=1. + + Collect and return aggregate throughput statistics after all messages have been acknowledged. + + (This runs ProducerPerformance.java under the hood) + """ + client_version = KafkaVersion(client_version) + broker_version = KafkaVersion(broker_version) + self.validate_versions(client_version, broker_version) + if interbroker_security_protocol is None: + interbroker_security_protocol = security_protocol + self.start_kafka(security_protocol, interbroker_security_protocol, broker_version, tls_version) + self.producer = ProducerPerformanceService( + self.test_context, 1, self.kafka, + topic=TOPIC_REP_THREE, num_records=self.msgs_large, record_size=DEFAULT_RECORD_SIZE, + throughput=-1, version=client_version, settings={ + 'acks': 1, + 'compression.type': compression_type, + 'batch.size': self.batch_size, + 'buffer.memory': self.buffer_memory + }, + intermediate_stats=True + ) + self.producer.run() + + summary = ["Throughput over long run, data > memory:"] + data = {} + # FIXME we should be generating a graph too + # Try to break it into 5 blocks, but fall back to a smaller number if + # there aren't even 5 elements + block_size = max(len(self.producer.stats[0]) // 5, 1) + nblocks = len(self.producer.stats[0]) // block_size + + for i in range(nblocks): + subset = self.producer.stats[0][i*block_size:min((i+1)*block_size, len(self.producer.stats[0]))] + if len(subset) == 0: + summary.append(" Time block %d: (empty)" % i) + data[i] = None + else: + records_per_sec = sum([stat['records_per_sec'] for stat in subset])/float(len(subset)) + mb_per_sec = sum([stat['mbps'] for stat in subset])/float(len(subset)) + + summary.append(" Time block %d: %f rec/sec (%f MB/s)" % (i, records_per_sec, mb_per_sec)) + data[i] = throughput(records_per_sec, mb_per_sec) + + self.logger.info("\n".join(summary)) + return data + + @cluster(num_nodes=5) + @matrix(security_protocol=['SSL'], interbroker_security_protocol=['PLAINTEXT'], tls_version=['TLSv1.2', 'TLSv1.3'], compression_type=["none", 
"snappy"]) + @matrix(security_protocol=['PLAINTEXT'], compression_type=["none", "snappy"]) + @cluster(num_nodes=6) + @matrix(security_protocol=['SASL_PLAINTEXT', 'SASL_SSL'], compression_type=["none", "snappy"]) + def test_end_to_end_latency(self, compression_type="none", security_protocol="PLAINTEXT", tls_version=None, + interbroker_security_protocol=None, client_version=str(DEV_BRANCH), + broker_version=str(DEV_BRANCH)): + """ + Setup: 1 node zk + 3 node kafka cluster + Produce (acks = 1) and consume 10e3 messages to a topic with 6 partitions and replication-factor 3, + measuring the latency between production and consumption of each message. + + Return aggregate latency statistics. + + (Under the hood, this simply runs EndToEndLatency.scala) + """ + client_version = KafkaVersion(client_version) + broker_version = KafkaVersion(broker_version) + self.validate_versions(client_version, broker_version) + if interbroker_security_protocol is None: + interbroker_security_protocol = security_protocol + self.start_kafka(security_protocol, interbroker_security_protocol, broker_version, tls_version) + self.logger.info("BENCHMARK: End to end latency") + self.perf = EndToEndLatencyService( + self.test_context, 1, self.kafka, + topic=TOPIC_REP_THREE, num_records=10000, + compression_type=compression_type, version=client_version + ) + self.perf.run() + return latency(self.perf.results[0]['latency_50th_ms'], self.perf.results[0]['latency_99th_ms'], self.perf.results[0]['latency_999th_ms']) + + @cluster(num_nodes=6) + @matrix(security_protocol=['SSL'], interbroker_security_protocol=['PLAINTEXT'], tls_version=['TLSv1.2', 'TLSv1.3'], compression_type=["none", "snappy"]) + @matrix(security_protocol=['PLAINTEXT'], compression_type=["none", "snappy"]) + def test_producer_and_consumer(self, compression_type="none", security_protocol="PLAINTEXT", tls_version=None, + interbroker_security_protocol=None, + client_version=str(DEV_BRANCH), broker_version=str(DEV_BRANCH)): + """ + Setup: 1 node zk + 3 node kafka cluster + Concurrently produce and consume 10e6 messages with a single producer and a single consumer, + + Return aggregate throughput statistics for both producer and consumer. 
+ + (Under the hood, this runs ProducerPerformance.java, and ConsumerPerformance.scala) + """ + client_version = KafkaVersion(client_version) + broker_version = KafkaVersion(broker_version) + self.validate_versions(client_version, broker_version) + if interbroker_security_protocol is None: + interbroker_security_protocol = security_protocol + self.start_kafka(security_protocol, interbroker_security_protocol, broker_version, tls_version) + num_records = 10 * 1000 * 1000 # 10e6 + + self.producer = ProducerPerformanceService( + self.test_context, 1, self.kafka, + topic=TOPIC_REP_THREE, + num_records=num_records, record_size=DEFAULT_RECORD_SIZE, throughput=-1, version=client_version, + settings={ + 'acks': 1, + 'compression.type': compression_type, + 'batch.size': self.batch_size, + 'buffer.memory': self.buffer_memory + } + ) + self.consumer = ConsumerPerformanceService( + self.test_context, 1, self.kafka, topic=TOPIC_REP_THREE, messages=num_records) + Service.run_parallel(self.producer, self.consumer) + + data = { + "producer": compute_aggregate_throughput(self.producer), + "consumer": compute_aggregate_throughput(self.consumer) + } + summary = [ + "Producer + consumer:", + str(data)] + self.logger.info("\n".join(summary)) + return data + + @cluster(num_nodes=6) + @matrix(security_protocol=['SSL'], interbroker_security_protocol=['PLAINTEXT'], tls_version=['TLSv1.2', 'TLSv1.3'], compression_type=["none", "snappy"]) + @matrix(security_protocol=['PLAINTEXT'], compression_type=["none", "snappy"]) + def test_consumer_throughput(self, compression_type="none", security_protocol="PLAINTEXT", tls_version=None, + interbroker_security_protocol=None, num_consumers=1, + client_version=str(DEV_BRANCH), broker_version=str(DEV_BRANCH)): + """ + Consume 10e6 100-byte messages with 1 or more consumers from a topic with 6 partitions + and report throughput. + """ + client_version = KafkaVersion(client_version) + broker_version = KafkaVersion(broker_version) + self.validate_versions(client_version, broker_version) + if interbroker_security_protocol is None: + interbroker_security_protocol = security_protocol + self.start_kafka(security_protocol, interbroker_security_protocol, broker_version, tls_version) + num_records = 10 * 1000 * 1000 # 10e6 + + # seed kafka w/messages + self.producer = ProducerPerformanceService( + self.test_context, 1, self.kafka, + topic=TOPIC_REP_THREE, + num_records=num_records, record_size=DEFAULT_RECORD_SIZE, throughput=-1, version=client_version, + settings={ + 'acks': 1, + 'compression.type': compression_type, + 'batch.size': self.batch_size, + 'buffer.memory': self.buffer_memory + } + ) + self.producer.run() + + # consume + self.consumer = ConsumerPerformanceService( + self.test_context, num_consumers, self.kafka, + topic=TOPIC_REP_THREE, messages=num_records) + self.consumer.group = "test-consumer-group" + self.consumer.run() + return compute_aggregate_throughput(self.consumer) + + def validate_versions(self, client_version, broker_version): + assert client_version <= broker_version, "Client version %s should be <= than broker version %s" (client_version, broker_version) diff --git a/tests/kafkatest/directory_layout/__init__.py b/tests/kafkatest/directory_layout/__init__.py new file mode 100644 index 0000000..ec20143 --- /dev/null +++ b/tests/kafkatest/directory_layout/__init__.py @@ -0,0 +1,14 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. 
See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/kafkatest/directory_layout/kafka_path.py b/tests/kafkatest/directory_layout/kafka_path.py new file mode 100644 index 0000000..40dda22 --- /dev/null +++ b/tests/kafkatest/directory_layout/kafka_path.py @@ -0,0 +1,137 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import importlib +import os + +from kafkatest.version import get_version, KafkaVersion, DEV_BRANCH + + +"""This module serves a few purposes: + +First, it gathers information about path layout in a single place, and second, it +makes the layout of the Kafka installation pluggable, so that users are not forced +to use the layout assumed in the KafkaPathResolver class. + +To run system tests using your own path resolver, use for example: + +ducktape --globals '{"kafka-path-resolver": "my.path.resolver.CustomResolverClass"}' +""" + +SCRATCH_ROOT = "/mnt" +KAFKA_INSTALL_ROOT = "/opt" +KAFKA_PATH_RESOLVER_KEY = "kafka-path-resolver" +KAFKA_PATH_RESOLVER = "kafkatest.directory_layout.kafka_path.KafkaSystemTestPathResolver" + +# Variables for jar path resolution +CORE_JAR_NAME = "core" +CORE_LIBS_JAR_NAME = "core-libs" +CORE_DEPENDANT_TEST_LIBS_JAR_NAME = "core-dependant-testlibs" +TOOLS_JAR_NAME = "tools" +TOOLS_DEPENDANT_TEST_LIBS_JAR_NAME = "tools-dependant-libs" + +JARS = { + "dev": { + CORE_JAR_NAME: "core/build/*/*.jar", + CORE_LIBS_JAR_NAME: "core/build/libs/*.jar", + CORE_DEPENDANT_TEST_LIBS_JAR_NAME: "core/build/dependant-testlibs/*.jar", + TOOLS_JAR_NAME: "tools/build/libs/kafka-tools*.jar", + TOOLS_DEPENDANT_TEST_LIBS_JAR_NAME: "tools/build/dependant-libs*/*.jar" + } +} + + +def create_path_resolver(context, project="kafka"): + """Factory for generating a path resolver class + + This will first check for a fully qualified path resolver classname in context.globals. 
+ + If present, construct a new instance, else default to KafkaSystemTestPathResolver + """ + assert project is not None + + if KAFKA_PATH_RESOLVER_KEY in context.globals: + resolver_fully_qualified_classname = context.globals[KAFKA_PATH_RESOLVER_KEY] + else: + resolver_fully_qualified_classname = KAFKA_PATH_RESOLVER + + # Using the fully qualified classname, import the resolver class + (module_name, resolver_class_name) = resolver_fully_qualified_classname.rsplit('.', 1) + cluster_mod = importlib.import_module(module_name) + path_resolver_class = getattr(cluster_mod, resolver_class_name) + path_resolver = path_resolver_class(context, project) + + return path_resolver + + +class KafkaPathResolverMixin(object): + """Mixin to automatically provide pluggable path resolution functionality to any class using it. + + Keep life simple, and don't add a constructor to this class: + Since use of a mixin entails multiple inheritence, it is *much* simpler to reason about the interaction of this + class with subclasses if we don't have to worry about method resolution order, constructor signatures etc. + """ + + @property + def path(self): + if not hasattr(self, "_path"): + setattr(self, "_path", create_path_resolver(self.context, "kafka")) + if hasattr(self.context, "logger") and self.context.logger is not None: + self.context.logger.debug("Using path resolver %s" % self._path.__class__.__name__) + + return self._path + + +class KafkaSystemTestPathResolver(object): + """Path resolver for Kafka system tests which assumes the following layout: + + /opt/kafka-dev # Current version of kafka under test + /opt/kafka-0.9.0.1 # Example of an older version of kafka installed from tarball + /opt/kafka- # Other previous versions of kafka + ... + """ + def __init__(self, context, project="kafka"): + self.context = context + self.project = project + + def home(self, node_or_version=DEV_BRANCH, project=None): + version = self._version(node_or_version) + home_dir = project or self.project + if version is not None: + home_dir += "-%s" % str(version) + + return os.path.join(KAFKA_INSTALL_ROOT, home_dir) + + def bin(self, node_or_version=DEV_BRANCH, project=None): + version = self._version(node_or_version) + return os.path.join(self.home(version, project=project), "bin") + + def script(self, script_name, node_or_version=DEV_BRANCH, project=None): + version = self._version(node_or_version) + return os.path.join(self.bin(version, project=project), script_name) + + def jar(self, jar_name, node_or_version=DEV_BRANCH, project=None): + version = self._version(node_or_version) + return os.path.join(self.home(version, project=project), JARS[str(version)][jar_name]) + + def scratch_space(self, service_instance): + return os.path.join(SCRATCH_ROOT, service_instance.service_id) + + def _version(self, node_or_version): + if isinstance(node_or_version, KafkaVersion): + return node_or_version + else: + return get_version(node_or_version) + diff --git a/tests/kafkatest/sanity_checks/__init__.py b/tests/kafkatest/sanity_checks/__init__.py new file mode 100644 index 0000000..91eacc9 --- /dev/null +++ b/tests/kafkatest/sanity_checks/__init__.py @@ -0,0 +1,14 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. \ No newline at end of file diff --git a/tests/kafkatest/sanity_checks/test_bounce.py b/tests/kafkatest/sanity_checks/test_bounce.py new file mode 100644 index 0000000..5c9cd7f --- /dev/null +++ b/tests/kafkatest/sanity_checks/test_bounce.py @@ -0,0 +1,78 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from ducktape.mark import matrix +from ducktape.mark.resource import cluster +from ducktape.tests.test import Test +from ducktape.utils.util import wait_until + +from kafkatest.services.kafka import KafkaService, quorum +from kafkatest.services.verifiable_producer import VerifiableProducer +from kafkatest.services.zookeeper import ZookeeperService + + +class TestBounce(Test): + """Sanity checks on verifiable producer service class with cluster roll.""" + def __init__(self, test_context): + super(TestBounce, self).__init__(test_context) + + quorum_size_arg_name = 'quorum_size' + default_quorum_size = 1 + quorum_size = default_quorum_size if not test_context.injected_args else test_context.injected_args.get(quorum_size_arg_name, default_quorum_size) + if quorum_size < 1: + raise Exception("Illegal %s value provided for the test: %s" % (quorum_size_arg_name, quorum_size)) + self.topic = "topic" + self.zk = ZookeeperService(test_context, num_nodes=quorum_size) if quorum.for_test(test_context) == quorum.zk else None + num_kafka_nodes = quorum_size if quorum.for_test(test_context) == quorum.colocated_kraft else 1 + self.kafka = KafkaService(test_context, num_nodes=num_kafka_nodes, zk=self.zk, + topics={self.topic: {"partitions": 1, "replication-factor": 1}}, + controller_num_nodes_override=quorum_size) + self.num_messages = 1000 + + def create_producer(self): + # This will produce to source kafka cluster + self.producer = VerifiableProducer(self.test_context, num_nodes=1, kafka=self.kafka, topic=self.topic, + max_messages=self.num_messages, throughput=self.num_messages // 10) + def setUp(self): + if self.zk: + self.zk.start() + + # ZooKeeper and KRaft, quorum size = 1 + @cluster(num_nodes=4) + @matrix(metadata_quorum=quorum.all, quorum_size=[1]) + # Remote and Co-located KRaft, quorum size = 3 + @cluster(num_nodes=6) + @matrix(metadata_quorum=quorum.all_kraft, quorum_size=[3]) + def test_simple_run(self, metadata_quorum, quorum_size): + """ + Test that we can start VerifiableProducer on the current branch snapshot version, and + verify that we can produce a 
small number of messages both before and after a subsequent roll. + """ + self.kafka.start() + for first_time in [True, False]: + self.create_producer() + self.producer.start() + wait_until(lambda: self.producer.num_acked > 5, timeout_sec=15, + err_msg="Producer failed to start in a reasonable amount of time.") + + self.producer.wait() + num_produced = self.producer.num_acked + assert num_produced == self.num_messages, "num_produced: %d, num_messages: %d" % (num_produced, self.num_messages) + if first_time: + self.producer.stop() + if self.kafka.quorum_info.using_kraft and self.kafka.remote_controller_quorum: + self.kafka.remote_controller_quorum.restart_cluster() + self.kafka.restart_cluster() diff --git a/tests/kafkatest/sanity_checks/test_console_consumer.py b/tests/kafkatest/sanity_checks/test_console_consumer.py new file mode 100644 index 0000000..675920a --- /dev/null +++ b/tests/kafkatest/sanity_checks/test_console_consumer.py @@ -0,0 +1,100 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time + +from ducktape.mark import matrix +from ducktape.mark.resource import cluster +from ducktape.tests.test import Test +from ducktape.utils.util import wait_until + +from kafkatest.services.console_consumer import ConsoleConsumer +from kafkatest.services.kafka import KafkaService, quorum +from kafkatest.services.verifiable_producer import VerifiableProducer +from kafkatest.services.zookeeper import ZookeeperService +from kafkatest.utils.remote_account import line_count, file_exists +from kafkatest.version import LATEST_0_8_2 + + +class ConsoleConsumerTest(Test): + """Sanity checks on console consumer service class.""" + def __init__(self, test_context): + super(ConsoleConsumerTest, self).__init__(test_context) + + self.topic = "topic" + self.zk = ZookeeperService(test_context, num_nodes=1) if quorum.for_test(test_context) == quorum.zk else None + self.kafka = KafkaService(self.test_context, num_nodes=1, zk=self.zk, zk_chroot="/kafka", + topics={self.topic: {"partitions": 1, "replication-factor": 1}}) + self.consumer = ConsoleConsumer(self.test_context, num_nodes=1, kafka=self.kafka, topic=self.topic) + + def setUp(self): + if self.zk: + self.zk.start() + + @cluster(num_nodes=3) + @matrix(security_protocol=['PLAINTEXT', 'SSL'], metadata_quorum=quorum.all_kraft) + @cluster(num_nodes=4) + @matrix(security_protocol=['SASL_SSL'], sasl_mechanism=['PLAIN'], metadata_quorum=quorum.all_kraft) + @matrix(security_protocol=['SASL_SSL'], sasl_mechanism=['SCRAM-SHA-256', 'SCRAM-SHA-512']) # SCRAM not yet supported with KRaft + @matrix(security_protocol=['SASL_PLAINTEXT', 'SASL_SSL'], metadata_quorum=quorum.all_kraft) + def test_lifecycle(self, security_protocol, sasl_mechanism='GSSAPI', metadata_quorum=quorum.zk): + """Check that console consumer starts/stops properly, 
and that we are capturing log output.""" + + self.kafka.security_protocol = security_protocol + self.kafka.client_sasl_mechanism = sasl_mechanism + self.kafka.interbroker_sasl_mechanism = sasl_mechanism + self.kafka.start() + + self.consumer.security_protocol = security_protocol + + t0 = time.time() + self.consumer.start() + node = self.consumer.nodes[0] + + wait_until(lambda: self.consumer.alive(node), + timeout_sec=20, backoff_sec=.2, err_msg="Consumer was too slow to start") + self.logger.info("consumer started in %s seconds " % str(time.time() - t0)) + + # Verify that log output is happening + wait_until(lambda: file_exists(node, ConsoleConsumer.LOG_FILE), timeout_sec=10, + err_msg="Timed out waiting for consumer log file to exist.") + wait_until(lambda: line_count(node, ConsoleConsumer.LOG_FILE) > 0, timeout_sec=1, + backoff_sec=.25, err_msg="Timed out waiting for log entries to start.") + + # Verify no consumed messages + assert line_count(node, ConsoleConsumer.STDOUT_CAPTURE) == 0 + + self.consumer.stop_node(node) + + @cluster(num_nodes=4) + def test_version(self): + """Check that console consumer v0.8.2.X successfully starts and consumes messages.""" + self.kafka.start() + + num_messages = 1000 + self.producer = VerifiableProducer(self.test_context, num_nodes=1, kafka=self.kafka, topic=self.topic, + max_messages=num_messages, throughput=1000) + self.producer.start() + self.producer.wait() + + self.consumer.nodes[0].version = LATEST_0_8_2 + self.consumer.new_consumer = False + self.consumer.consumer_timeout_ms = 1000 + self.consumer.start() + self.consumer.wait() + + num_consumed = len(self.consumer.messages_consumed[1]) + num_produced = self.producer.num_acked + assert num_produced == num_consumed, "num_produced: %d, num_consumed: %d" % (num_produced, num_consumed) diff --git a/tests/kafkatest/sanity_checks/test_kafka_version.py b/tests/kafkatest/sanity_checks/test_kafka_version.py new file mode 100644 index 0000000..0265ecd --- /dev/null +++ b/tests/kafkatest/sanity_checks/test_kafka_version.py @@ -0,0 +1,62 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
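# The sanity checks above lean heavily on one idiom: poll a condition with a small
# backoff until a timeout, then fail with a descriptive message. Below is a minimal,
# self-contained sketch of that pattern (a simplified stand-in for
# ducktape.utils.util.wait_until, not part of the patch itself); poll_until and
# consumer_is_alive are hypothetical names used only for illustration.
import time

def poll_until(condition, timeout_sec, backoff_sec=0.5, err_msg="condition not met"):
    deadline = time.time() + timeout_sec
    while time.time() < deadline:
        if condition():
            return
        time.sleep(backoff_sec)
    raise AssertionError(err_msg)

def consumer_is_alive():
    # stand-in for ConsoleConsumer.alive(node), which checks for a live Java process
    return True

# Usage mirroring test_lifecycle above:
poll_until(consumer_is_alive, timeout_sec=20, backoff_sec=0.2,
           err_msg="Consumer was too slow to start")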
+ +from ducktape.tests.test import Test +from ducktape.mark.resource import cluster + +from kafkatest.services.kafka import KafkaService, config_property +from kafkatest.services.zookeeper import ZookeeperService +from kafkatest.utils import is_version +from kafkatest.version import LATEST_0_8_2, DEV_BRANCH + + +class KafkaVersionTest(Test): + """Sanity checks on kafka versioning.""" + def __init__(self, test_context): + super(KafkaVersionTest, self).__init__(test_context) + + self.topic = "topic" + self.zk = ZookeeperService(test_context, num_nodes=1) + + def setUp(self): + self.zk.start() + + @cluster(num_nodes=2) + def test_0_8_2(self): + """Test kafka service node-versioning api - verify that we can bring up a single-node 0.8.2.X cluster.""" + self.kafka = KafkaService(self.test_context, num_nodes=1, zk=self.zk, + topics={self.topic: {"partitions": 1, "replication-factor": 1}}) + node = self.kafka.nodes[0] + node.version = LATEST_0_8_2 + self.kafka.start() + + assert is_version(node, [LATEST_0_8_2], logger=self.logger) + + @cluster(num_nodes=3) + def test_multi_version(self): + """Test kafka service node-versioning api - ensure we can bring up a 2-node cluster, one on version 0.8.2.X, + the other on the current development branch.""" + self.kafka = KafkaService(self.test_context, num_nodes=2, zk=self.zk, + topics={self.topic: {"partitions": 1, "replication-factor": 2}}) + # Be sure to make node[0] the one with v0.8.2 because the topic will be created using the --zookeeper option + # since not all nodes support the --bootstrap-server option; the --zookeeper option is removed as of v3.0, + # and the topic will be created against the broker on node[0], so that node has to be the one running the older + # version (otherwise the kafka-topics --zookeeper command will fail). + self.kafka.nodes[0].version = LATEST_0_8_2 + self.kafka.nodes[0].config[config_property.INTER_BROKER_PROTOCOL_VERSION] = "0.8.2.X" + self.kafka.start() + + assert is_version(self.kafka.nodes[0], [LATEST_0_8_2], logger=self.logger) + assert is_version(self.kafka.nodes[1], [DEV_BRANCH.vstring], logger=self.logger) diff --git a/tests/kafkatest/sanity_checks/test_performance_services.py b/tests/kafkatest/sanity_checks/test_performance_services.py new file mode 100644 index 0000000..f0d1a48 --- /dev/null +++ b/tests/kafkatest/sanity_checks/test_performance_services.py @@ -0,0 +1,94 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
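# The node-versioning checks above (node.version = LATEST_0_8_2, is_version(...),
# comparisons such as node.version <= LATEST_0_8_2) rely on version strings that
# order numerically, component by component, in the style of
# distutils.version.LooseVersion. A tiny sketch of that ordering, not part of the
# patch, using a plain tuple key so it runs on any Python version; the version
# literals below are illustrative only.
def version_key(version_string):
    # "0.8.2.2" -> (0, 8, 2, 2); purely numeric release strings compare element-wise
    return tuple(int(part) for part in version_string.split("."))

assert version_key("0.8.2.2") < version_key("0.9.0.1")
assert version_key("0.10.1.0") < version_key("1.1.0")

# A node pinned to the legacy line can then be detected with a plain comparison:
if version_key("0.8.2.2") <= version_key("0.8.2.2"):
    print("node runs the 0.8.2.x line")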
+ +from ducktape.mark import matrix, parametrize +from ducktape.mark.resource import cluster +from ducktape.tests.test import Test + +from kafkatest.services.kafka import KafkaService, quorum +from kafkatest.services.performance import ProducerPerformanceService, ConsumerPerformanceService, EndToEndLatencyService +from kafkatest.services.performance import latency, compute_aggregate_throughput +from kafkatest.services.zookeeper import ZookeeperService +from kafkatest.version import DEV_BRANCH, LATEST_0_8_2, LATEST_0_9, LATEST_1_1, KafkaVersion + + +class PerformanceServiceTest(Test): + def __init__(self, test_context): + super(PerformanceServiceTest, self).__init__(test_context) + self.record_size = 100 + self.num_records = 10000 + self.topic = "topic" + + self.zk = ZookeeperService(test_context, 1) if quorum.for_test(test_context) == quorum.zk else None + + def setUp(self): + if self.zk: + self.zk.start() + + @cluster(num_nodes=5) + # We are keeping 0.8.2 here so that we don't inadvertently break support for it. Since this is just a sanity check, + # the overhead should be manageable. + @parametrize(version=str(LATEST_0_8_2), new_consumer=False) + @parametrize(version=str(LATEST_0_9), new_consumer=False) + @parametrize(version=str(LATEST_0_9)) + @parametrize(version=str(LATEST_1_1), new_consumer=False) + @cluster(num_nodes=5) + @matrix(version=[str(DEV_BRANCH)], metadata_quorum=quorum.all) + def test_version(self, version=str(LATEST_0_9), new_consumer=True, metadata_quorum=quorum.zk): + """ + Sanity check out producer performance service - verify that we can run the service with a small + number of messages. The actual stats here are pretty meaningless since the number of messages is quite small. + """ + version = KafkaVersion(version) + self.kafka = KafkaService( + self.test_context, 1, + self.zk, topics={self.topic: {'partitions': 1, 'replication-factor': 1}}, version=version) + self.kafka.start() + + # check basic run of producer performance + self.producer_perf = ProducerPerformanceService( + self.test_context, 1, self.kafka, topic=self.topic, + num_records=self.num_records, record_size=self.record_size, + throughput=1000000000, # Set impossibly for no throttling for equivalent behavior between 0.8.X and 0.9.X + version=version, + settings={ + 'acks': 1, + 'batch.size': 8*1024, + 'buffer.memory': 64*1024*1024}) + self.producer_perf.run() + producer_perf_data = compute_aggregate_throughput(self.producer_perf) + assert producer_perf_data['records_per_sec'] > 0 + + # check basic run of end to end latency + self.end_to_end = EndToEndLatencyService( + self.test_context, 1, self.kafka, + topic=self.topic, num_records=self.num_records, version=version) + self.end_to_end.run() + end_to_end_data = latency(self.end_to_end.results[0]['latency_50th_ms'], self.end_to_end.results[0]['latency_99th_ms'], self.end_to_end.results[0]['latency_999th_ms']) + + # check basic run of consumer performance service + self.consumer_perf = ConsumerPerformanceService( + self.test_context, 1, self.kafka, new_consumer=new_consumer, + topic=self.topic, version=version, messages=self.num_records) + self.consumer_perf.group = "test-consumer-group" + self.consumer_perf.run() + consumer_perf_data = compute_aggregate_throughput(self.consumer_perf) + assert consumer_perf_data['records_per_sec'] > 0 + + return { + "producer_performance": producer_perf_data, + "end_to_end_latency": end_to_end_data, + "consumer_performance": consumer_perf_data + } diff --git a/tests/kafkatest/sanity_checks/test_verifiable_producer.py 
b/tests/kafkatest/sanity_checks/test_verifiable_producer.py new file mode 100644 index 0000000..1aa2109 --- /dev/null +++ b/tests/kafkatest/sanity_checks/test_verifiable_producer.py @@ -0,0 +1,176 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from ducktape.mark import matrix, parametrize +from ducktape.mark.resource import cluster +from ducktape.tests.test import Test +from ducktape.utils.util import wait_until + +from kafkatest.services.kafka import KafkaService, quorum +from kafkatest.services.verifiable_producer import VerifiableProducer +from kafkatest.services.zookeeper import ZookeeperService +from kafkatest.utils import is_version +from kafkatest.version import LATEST_0_8_2, LATEST_0_9, LATEST_0_10_0, LATEST_0_10_1, DEV_BRANCH, KafkaVersion + + +class TestVerifiableProducer(Test): + """Sanity checks on verifiable producer service class.""" + def __init__(self, test_context): + super(TestVerifiableProducer, self).__init__(test_context) + + self.topic = "topic" + self.zk = ZookeeperService(test_context, num_nodes=1) if quorum.for_test(test_context) == quorum.zk else None + self.kafka = KafkaService(test_context, num_nodes=1, zk=self.zk, + topics={self.topic: {"partitions": 1, "replication-factor": 1}}) + + self.num_messages = 1000 + # This will produce to source kafka cluster + self.producer = VerifiableProducer(test_context, num_nodes=1, kafka=self.kafka, topic=self.topic, + max_messages=self.num_messages, throughput=self.num_messages // 10) + def setUp(self): + if self.zk: + self.zk.start() + + @cluster(num_nodes=3) + @parametrize(producer_version=str(LATEST_0_8_2)) + @parametrize(producer_version=str(LATEST_0_9)) + @parametrize(producer_version=str(LATEST_0_10_0)) + @parametrize(producer_version=str(LATEST_0_10_1)) + @matrix(producer_version=[str(DEV_BRANCH)], acks=["0", "1", "-1"], enable_idempotence=[False]) + @matrix(producer_version=[str(DEV_BRANCH)], acks=["-1"], enable_idempotence=[True]) + @matrix(producer_version=[str(DEV_BRANCH)], security_protocol=['PLAINTEXT', 'SSL'], metadata_quorum=quorum.all) + @cluster(num_nodes=4) + @matrix(producer_version=[str(DEV_BRANCH)], security_protocol=['SASL_SSL'], sasl_mechanism=['PLAIN', 'GSSAPI'], + metadata_quorum=quorum.all) + def test_simple_run(self, producer_version, acks=None, enable_idempotence=False, security_protocol = 'PLAINTEXT', + sasl_mechanism='PLAIN', metadata_quorum=quorum.zk): + """ + Test that we can start VerifiableProducer on the current branch snapshot version or against the 0.8.2 jar, and + verify that we can produce a small number of messages. 
+ """ + self.kafka.security_protocol = security_protocol + self.kafka.client_sasl_mechanism = sasl_mechanism + self.kafka.interbroker_security_protocol = security_protocol + self.kafka.interbroker_sasl_mechanism = sasl_mechanism + if self.kafka.quorum_info.using_kraft: + controller_quorum = self.kafka.controller_quorum + controller_quorum.controller_security_protocol = security_protocol + controller_quorum.controller_sasl_mechanism = sasl_mechanism + controller_quorum.intercontroller_security_protocol = security_protocol + controller_quorum.intercontroller_sasl_mechanism = sasl_mechanism + self.kafka.start() + + node = self.producer.nodes[0] + self.producer.enable_idempotence = enable_idempotence + self.producer.acks = acks + node.version = KafkaVersion(producer_version) + self.producer.start() + wait_until(lambda: self.producer.num_acked > 5, timeout_sec=15, + err_msg="Producer failed to start in a reasonable amount of time.") + + # using version.vstring (distutils.version.LooseVersion) is a tricky way of ensuring + # that this check works with DEV_BRANCH + # When running VerifiableProducer 0.8.X, both the current branch version and 0.8.X should show up because of the + # way verifiable producer pulls in some development directories into its classpath + # + # If the test fails here because 'ps .. | grep' couldn't find the process, it means + # the login and grep that is_version() performs are slower than + # the time it takes the producer to produce its messages. + # An easy fix is to decrease throughput= above; the proper fix is to make the producer + # not terminate until explicitly killed in this case. + if node.version <= LATEST_0_8_2: + assert is_version(node, [node.version.vstring, DEV_BRANCH.vstring], logger=self.logger) + else: + assert is_version(node, [node.version.vstring], logger=self.logger) + + self.producer.wait() + num_produced = self.producer.num_acked + assert num_produced == self.num_messages, "num_produced: %d, num_messages: %d" % (num_produced, self.num_messages) + + @cluster(num_nodes=4) + @matrix(inter_broker_security_protocol=['PLAINTEXT', 'SSL'], metadata_quorum=[quorum.remote_kraft]) + @matrix(inter_broker_security_protocol=['SASL_SSL'], inter_broker_sasl_mechanism=['PLAIN', 'GSSAPI'], + metadata_quorum=[quorum.remote_kraft]) + def test_multiple_kraft_security_protocols( + self, inter_broker_security_protocol, inter_broker_sasl_mechanism='GSSAPI', metadata_quorum=quorum.remote_kraft): + """ + Test for remote KRaft cases that we can start VerifiableProducer on the current branch snapshot version, and + verify that we can produce a small number of messages. The inter-controller and broker-to-controller + security protocols are defined to be different (which differs from the above test, where they were the same). 
+ """ + self.kafka.security_protocol = self.kafka.interbroker_security_protocol = inter_broker_security_protocol + self.kafka.client_sasl_mechanism = self.kafka.interbroker_sasl_mechanism = inter_broker_sasl_mechanism + controller_quorum = self.kafka.controller_quorum + sasl_mechanism = 'PLAIN' if inter_broker_sasl_mechanism == 'GSSAPI' else 'GSSAPI' + if inter_broker_security_protocol == 'PLAINTEXT': + controller_security_protocol = 'SSL' + intercontroller_security_protocol = 'SASL_SSL' + elif inter_broker_security_protocol == 'SSL': + controller_security_protocol = 'SASL_SSL' + intercontroller_security_protocol = 'PLAINTEXT' + else: # inter_broker_security_protocol == 'SASL_SSL' + controller_security_protocol = 'PLAINTEXT' + intercontroller_security_protocol = 'SSL' + controller_quorum.controller_security_protocol = controller_security_protocol + controller_quorum.controller_sasl_mechanism = sasl_mechanism + controller_quorum.intercontroller_security_protocol = intercontroller_security_protocol + controller_quorum.intercontroller_sasl_mechanism = sasl_mechanism + self.kafka.start() + + node = self.producer.nodes[0] + node.version = KafkaVersion(str(DEV_BRANCH)) + self.producer.start() + wait_until(lambda: self.producer.num_acked > 5, timeout_sec=15, + err_msg="Producer failed to start in a reasonable amount of time.") + + # See the comment above regarding use of version.vstring (distutils.version.LooseVersion) + assert is_version(node, [node.version.vstring], logger=self.logger) + + self.producer.wait() + num_produced = self.producer.num_acked + assert num_produced == self.num_messages, "num_produced: %d, num_messages: %d" % (num_produced, self.num_messages) + + @cluster(num_nodes=4) + @parametrize(metadata_quorum=quorum.remote_kraft) + def test_multiple_kraft_sasl_mechanisms(self, metadata_quorum): + """ + Test for remote KRaft cases that we can start VerifiableProducer on the current branch snapshot version, and + verify that we can produce a small number of messages. The inter-controller and broker-to-controller + security protocols are both SASL_PLAINTEXT but the SASL mechanisms are different (we set + GSSAPI for the inter-controller mechanism and PLAIN for the broker-to-controller mechanism). + This test differs from the above tests -- the ones above used the same SASL mechanism for both paths. 
+ """ + self.kafka.security_protocol = self.kafka.interbroker_security_protocol = 'PLAINTEXT' + controller_quorum = self.kafka.controller_quorum + controller_quorum.controller_security_protocol = 'SASL_PLAINTEXT' + controller_quorum.controller_sasl_mechanism = 'PLAIN' + controller_quorum.intercontroller_security_protocol = 'SASL_PLAINTEXT' + controller_quorum.intercontroller_sasl_mechanism = 'GSSAPI' + self.kafka.start() + + node = self.producer.nodes[0] + node.version = KafkaVersion(str(DEV_BRANCH)) + self.producer.start() + wait_until(lambda: self.producer.num_acked > 5, timeout_sec=15, + err_msg="Producer failed to start in a reasonable amount of time.") + + # See the comment above regarding use of version.vstring (distutils.version.LooseVersion) + assert is_version(node, [node.version.vstring], logger=self.logger) + + self.producer.wait() + num_produced = self.producer.num_acked + assert num_produced == self.num_messages, "num_produced: %d, num_messages: %d" % (num_produced, self.num_messages) + diff --git a/tests/kafkatest/services/__init__.py b/tests/kafkatest/services/__init__.py new file mode 100644 index 0000000..ec20143 --- /dev/null +++ b/tests/kafkatest/services/__init__.py @@ -0,0 +1,14 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/kafkatest/services/connect.py b/tests/kafkatest/services/connect.py new file mode 100644 index 0000000..26c0d92 --- /dev/null +++ b/tests/kafkatest/services/connect.py @@ -0,0 +1,525 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
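# Every producer sanity check above ends with the same accounting: produce a fixed
# number of messages, wait for the producer to drain, then assert that the count of
# acknowledged messages matches the count requested. A compact sketch of that
# bookkeeping, not part of the patch, using a purely hypothetical in-memory producer
# (a real VerifiableProducer sends to Kafka and counts broker acknowledgements):
class FakeVerifiableProducer:
    def __init__(self, max_messages):
        self.max_messages = max_messages
        self.num_acked = 0

    def run(self):
        for _ in range(self.max_messages):
            # each successful send would be acknowledged by the broker
            self.num_acked += 1

producer = FakeVerifiableProducer(max_messages=1000)
producer.run()
assert producer.num_acked == producer.max_messages, \
    "num_produced: %d, num_messages: %d" % (producer.num_acked, producer.max_messages)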
+ +import json +import os.path +import random +import signal +import time + +import requests +from ducktape.errors import DucktapeError +from ducktape.services.service import Service +from ducktape.utils.util import wait_until + +from kafkatest.directory_layout.kafka_path import KafkaPathResolverMixin +from kafkatest.services.kafka.util import fix_opts_for_new_jvm + + +class ConnectServiceBase(KafkaPathResolverMixin, Service): + """Base class for Kafka Connect services providing some common settings and functionality""" + + PERSISTENT_ROOT = "/mnt/connect" + CONFIG_FILE = os.path.join(PERSISTENT_ROOT, "connect.properties") + # The log file contains normal log4j logs written using a file appender. stdout and stderr are handled separately + # so they can be used for other output, e.g. verifiable source & sink. + LOG_FILE = os.path.join(PERSISTENT_ROOT, "connect.log") + STDOUT_FILE = os.path.join(PERSISTENT_ROOT, "connect.stdout") + STDERR_FILE = os.path.join(PERSISTENT_ROOT, "connect.stderr") + LOG4J_CONFIG_FILE = os.path.join(PERSISTENT_ROOT, "connect-log4j.properties") + PID_FILE = os.path.join(PERSISTENT_ROOT, "connect.pid") + EXTERNAL_CONFIGS_FILE = os.path.join(PERSISTENT_ROOT, "connect-external-configs.properties") + CONNECT_REST_PORT = 8083 + HEAP_DUMP_FILE = os.path.join(PERSISTENT_ROOT, "connect_heap_dump.bin") + + # Currently the Connect worker supports waiting on four modes: + STARTUP_MODE_INSTANT = 'INSTANT' + """STARTUP_MODE_INSTANT: Start Connect worker and return immediately""" + STARTUP_MODE_LOAD = 'LOAD' + """STARTUP_MODE_LOAD: Start Connect worker and return after discovering and loading plugins""" + STARTUP_MODE_LISTEN = 'LISTEN' + """STARTUP_MODE_LISTEN: Start Connect worker and return after opening the REST port.""" + STARTUP_MODE_JOIN = 'JOIN' + """STARTUP_MODE_JOIN: Start Connect worker and return after joining the group.""" + + logs = { + "connect_log": { + "path": LOG_FILE, + "collect_default": True}, + "connect_stdout": { + "path": STDOUT_FILE, + "collect_default": False}, + "connect_stderr": { + "path": STDERR_FILE, + "collect_default": True}, + "connect_heap_dump_file": { + "path": HEAP_DUMP_FILE, + "collect_default": True} + } + + def __init__(self, context, num_nodes, kafka, files, startup_timeout_sec = 60): + super(ConnectServiceBase, self).__init__(context, num_nodes) + self.kafka = kafka + self.security_config = kafka.security_config.client_config() + self.files = files + self.startup_mode = self.STARTUP_MODE_LISTEN + self.startup_timeout_sec = startup_timeout_sec + self.environment = {} + self.external_config_template_func = None + + def pids(self, node): + """Return process ids for Kafka Connect processes.""" + try: + return [pid for pid in node.account.ssh_capture("cat " + self.PID_FILE, callback=int)] + except: + return [] + + def set_configs(self, config_template_func, connector_config_templates=None): + """ + Set configurations for the worker and the connector to run on + it. These are not provided in the constructor because the worker + config generally needs access to ZK/Kafka services to + create the configuration. + """ + self.config_template_func = config_template_func + self.connector_config_templates = connector_config_templates + + def set_external_configs(self, external_config_template_func): + """ + Set the properties that will be written in the external file properties + as used by the org.apache.kafka.common.config.provider.FileConfigProvider. + When this is used, the worker configuration must also enable the FileConfigProvider. 
+ This is not provided in the constructor in case the worker + config generally needs access to ZK/Kafka services to + create the configuration. + """ + self.external_config_template_func = external_config_template_func + + def listening(self, node): + try: + self.list_connectors(node) + self.logger.debug("Connect worker started serving REST at: '%s:%s')", node.account.hostname, + self.CONNECT_REST_PORT) + return True + except requests.exceptions.ConnectionError: + self.logger.debug("REST resources are not loaded yet") + return False + + def start(self, mode=None): + if mode: + self.startup_mode = mode + super(ConnectServiceBase, self).start() + + def start_and_return_immediately(self, node, worker_type, remote_connector_configs): + cmd = self.start_cmd(node, remote_connector_configs) + self.logger.debug("Connect %s command: %s", worker_type, cmd) + node.account.ssh(cmd) + + def start_and_wait_to_load_plugins(self, node, worker_type, remote_connector_configs): + with node.account.monitor_log(self.LOG_FILE) as monitor: + self.start_and_return_immediately(node, worker_type, remote_connector_configs) + monitor.wait_until('Kafka version', timeout_sec=self.startup_timeout_sec, + err_msg="Never saw message indicating Kafka Connect finished startup on node: " + + "%s in condition mode: %s" % (str(node.account), self.startup_mode)) + + def start_and_wait_to_start_listening(self, node, worker_type, remote_connector_configs): + self.start_and_return_immediately(node, worker_type, remote_connector_configs) + wait_until(lambda: self.listening(node), timeout_sec=self.startup_timeout_sec, + err_msg="Kafka Connect failed to start on node: %s in condition mode: %s" % + (str(node.account), self.startup_mode)) + + def start_and_wait_to_join_group(self, node, worker_type, remote_connector_configs): + if worker_type != 'distributed': + raise RuntimeError("Cannot wait for joined group message for %s" % worker_type) + with node.account.monitor_log(self.LOG_FILE) as monitor: + self.start_and_return_immediately(node, worker_type, remote_connector_configs) + monitor.wait_until('Joined group', timeout_sec=self.startup_timeout_sec, + err_msg="Never saw message indicating Kafka Connect joined group on node: " + + "%s in condition mode: %s" % (str(node.account), self.startup_mode)) + + def stop_node(self, node, clean_shutdown=True): + self.logger.info((clean_shutdown and "Cleanly" or "Forcibly") + " stopping Kafka Connect on " + str(node.account)) + pids = self.pids(node) + sig = signal.SIGTERM if clean_shutdown else signal.SIGKILL + + for pid in pids: + node.account.signal(pid, sig, allow_fail=True) + if clean_shutdown: + for pid in pids: + wait_until(lambda: not node.account.alive(pid), timeout_sec=self.startup_timeout_sec, err_msg="Kafka Connect process on " + str( + node.account) + " took too long to exit") + + node.account.ssh("rm -f " + self.PID_FILE, allow_fail=False) + + def restart(self, clean_shutdown=True): + # We don't want to do any clean up here, just restart the process. 
+ for node in self.nodes: + self.logger.info("Restarting Kafka Connect on " + str(node.account)) + self.restart_node(node, clean_shutdown) + + def restart_node(self, node, clean_shutdown=True): + self.stop_node(node, clean_shutdown) + self.start_node(node) + + def clean_node(self, node): + node.account.kill_process("connect", clean_shutdown=False, allow_fail=True) + self.security_config.clean_node(node) + other_files = " ".join(self.config_filenames() + self.files) + node.account.ssh("rm -rf -- %s %s" % (ConnectServiceBase.PERSISTENT_ROOT, other_files), allow_fail=False) + + def config_filenames(self): + return [os.path.join(self.PERSISTENT_ROOT, "connect-connector-" + str(idx) + ".properties") for idx, template in enumerate(self.connector_config_templates or [])] + + def list_connectors(self, node=None, **kwargs): + return self._rest_with_retry('/connectors', node=node, **kwargs) + + def create_connector(self, config, node=None, **kwargs): + create_request = { + 'name': config['name'], + 'config': config + } + return self._rest_with_retry('/connectors', create_request, node=node, method="POST", **kwargs) + + def get_connector(self, name, node=None, **kwargs): + return self._rest_with_retry('/connectors/' + name, node=node, **kwargs) + + def get_connector_config(self, name, node=None, **kwargs): + return self._rest_with_retry('/connectors/' + name + '/config', node=node, **kwargs) + + def set_connector_config(self, name, config, node=None, **kwargs): + # Unlike many other calls, a 409 when setting a connector config is expected if the connector already exists. + # However, we also might see 409s for other reasons (e.g. rebalancing). So we still perform retries at the cost + # of tests possibly taking longer to ultimately fail. Tests that care about this can explicitly override the + # number of retries. 
+ return self._rest_with_retry('/connectors/' + name + '/config', config, node=node, method="PUT", **kwargs) + + def get_connector_tasks(self, name, node=None, **kwargs): + return self._rest_with_retry('/connectors/' + name + '/tasks', node=node, **kwargs) + + def delete_connector(self, name, node=None, **kwargs): + return self._rest_with_retry('/connectors/' + name, node=node, method="DELETE", **kwargs) + + def get_connector_status(self, name, node=None): + return self._rest('/connectors/' + name + '/status', node=node) + + def restart_connector(self, name, node=None, **kwargs): + return self._rest_with_retry('/connectors/' + name + '/restart', node=node, method="POST", **kwargs) + + def restart_connector_and_tasks(self, name, only_failed, include_tasks, node=None, **kwargs): + return self._rest_with_retry('/connectors/' + name + '/restart?onlyFailed=' + only_failed + '&includeTasks=' + include_tasks, node=node, method="POST", **kwargs) + + def restart_task(self, connector_name, task_id, node=None): + return self._rest('/connectors/' + connector_name + '/tasks/' + str(task_id) + '/restart', node=node, method="POST") + + def pause_connector(self, name, node=None): + return self._rest('/connectors/' + name + '/pause', node=node, method="PUT") + + def resume_connector(self, name, node=None): + return self._rest('/connectors/' + name + '/resume', node=node, method="PUT") + + def list_connector_plugins(self, node=None): + return self._rest('/connector-plugins/', node=node) + + def validate_config(self, connector_type, validate_request, node=None): + return self._rest('/connector-plugins/' + connector_type + '/config/validate', validate_request, node=node, method="PUT") + + def _rest(self, path, body=None, node=None, method="GET"): + if node is None: + node = random.choice(self.nodes) + + meth = getattr(requests, method.lower()) + url = self._base_url(node) + path + self.logger.debug("Kafka Connect REST request: %s %s %s %s", node.account.hostname, url, method, body) + resp = meth(url, json=body) + self.logger.debug("%s %s response: %d", url, method, resp.status_code) + if resp.status_code > 400: + self.logger.debug("Connect REST API error for %s: %d %s", resp.url, resp.status_code, resp.text) + raise ConnectRestError(resp.status_code, resp.text, resp.url) + if resp.status_code == 204 or resp.status_code == 202: + return None + else: + return resp.json() + + def _rest_with_retry(self, path, body=None, node=None, method="GET", retries=40, retry_backoff=.25): + """ + Invokes a REST API with retries for errors that may occur during normal operation (notably 409 CONFLICT + responses that can occur due to rebalancing or 404 when the connect resources are not initialized yet). 
+ """ + exception_to_throw = None + for i in range(0, retries + 1): + try: + return self._rest(path, body, node, method) + except ConnectRestError as e: + exception_to_throw = e + if e.status != 409 and e.status != 404: + break + time.sleep(retry_backoff) + raise exception_to_throw + + def _base_url(self, node): + return 'http://' + node.account.externally_routable_ip + ':' + str(self.CONNECT_REST_PORT) + + def append_to_environment_variable(self, envvar, value): + env_opts = self.environment[envvar] + if env_opts is None: + env_opts = "\"%s\"" % value + else: + env_opts = "\"%s %s\"" % (env_opts.strip('\"'), value) + self.environment[envvar] = env_opts + + +class ConnectStandaloneService(ConnectServiceBase): + """Runs Kafka Connect in standalone mode.""" + + def __init__(self, context, kafka, files, startup_timeout_sec = 60): + super(ConnectStandaloneService, self).__init__(context, 1, kafka, files, startup_timeout_sec) + + # For convenience since this service only makes sense with a single node + @property + def node(self): + return self.nodes[0] + + def start_cmd(self, node, connector_configs): + cmd = "( export KAFKA_LOG4J_OPTS=\"-Dlog4j.configuration=file:%s\"; " % self.LOG4J_CONFIG_FILE + heap_kafka_opts = "-XX:+HeapDumpOnOutOfMemoryError -XX:HeapDumpPath=%s" % \ + self.logs["connect_heap_dump_file"]["path"] + other_kafka_opts = self.security_config.kafka_opts.strip('\"') + + cmd += fix_opts_for_new_jvm(node) + cmd += "export KAFKA_OPTS=\"%s %s\"; " % (heap_kafka_opts, other_kafka_opts) + for envvar in self.environment: + cmd += "export %s=%s; " % (envvar, str(self.environment[envvar])) + cmd += "%s %s " % (self.path.script("connect-standalone.sh", node), self.CONFIG_FILE) + cmd += " ".join(connector_configs) + cmd += " & echo $! >&3 ) 1>> %s 2>> %s 3> %s" % (self.STDOUT_FILE, self.STDERR_FILE, self.PID_FILE) + return cmd + + def start_node(self, node): + node.account.ssh("mkdir -p %s" % self.PERSISTENT_ROOT, allow_fail=False) + + self.security_config.setup_node(node) + if self.external_config_template_func: + node.account.create_file(self.EXTERNAL_CONFIGS_FILE, self.external_config_template_func(node)) + node.account.create_file(self.CONFIG_FILE, self.config_template_func(node)) + node.account.create_file(self.LOG4J_CONFIG_FILE, self.render('connect_log4j.properties', log_file=self.LOG_FILE)) + remote_connector_configs = [] + for idx, template in enumerate(self.connector_config_templates): + target_file = os.path.join(self.PERSISTENT_ROOT, "connect-connector-" + str(idx) + ".properties") + node.account.create_file(target_file, template) + remote_connector_configs.append(target_file) + + self.logger.info("Starting Kafka Connect standalone process on " + str(node.account)) + if self.startup_mode == self.STARTUP_MODE_LOAD: + self.start_and_wait_to_load_plugins(node, 'standalone', remote_connector_configs) + elif self.startup_mode == self.STARTUP_MODE_INSTANT: + self.start_and_return_immediately(node, 'standalone', remote_connector_configs) + elif self.startup_mode == self.STARTUP_MODE_JOIN: + self.start_and_wait_to_join_group(node, 'standalone', remote_connector_configs) + else: + # The default mode is to wait until the complete startup of the worker + self.start_and_wait_to_start_listening(node, 'standalone', remote_connector_configs) + + if len(self.pids(node)) == 0: + raise RuntimeError("No process ids recorded") + + +class ConnectDistributedService(ConnectServiceBase): + """Runs Kafka Connect in distributed mode.""" + + def __init__(self, context, num_nodes, kafka, files, 
offsets_topic="connect-offsets", + configs_topic="connect-configs", status_topic="connect-status", startup_timeout_sec = 60): + super(ConnectDistributedService, self).__init__(context, num_nodes, kafka, files, startup_timeout_sec) + self.startup_mode = self.STARTUP_MODE_JOIN + self.offsets_topic = offsets_topic + self.configs_topic = configs_topic + self.status_topic = status_topic + + # connector_configs argument is intentionally ignored in distributed service. + def start_cmd(self, node, connector_configs): + cmd = "( export KAFKA_LOG4J_OPTS=\"-Dlog4j.configuration=file:%s\"; " % self.LOG4J_CONFIG_FILE + heap_kafka_opts = "-XX:+HeapDumpOnOutOfMemoryError -XX:HeapDumpPath=%s" % \ + self.logs["connect_heap_dump_file"]["path"] + other_kafka_opts = self.security_config.kafka_opts.strip('\"') + cmd += "export KAFKA_OPTS=\"%s %s\"; " % (heap_kafka_opts, other_kafka_opts) + for envvar in self.environment: + cmd += "export %s=%s; " % (envvar, str(self.environment[envvar])) + cmd += "%s %s " % (self.path.script("connect-distributed.sh", node), self.CONFIG_FILE) + cmd += " & echo $! >&3 ) 1>> %s 2>> %s 3> %s" % (self.STDOUT_FILE, self.STDERR_FILE, self.PID_FILE) + return cmd + + def start_node(self, node): + node.account.ssh("mkdir -p %s" % self.PERSISTENT_ROOT, allow_fail=False) + + self.security_config.setup_node(node) + if self.external_config_template_func: + node.account.create_file(self.EXTERNAL_CONFIGS_FILE, self.external_config_template_func(node)) + node.account.create_file(self.CONFIG_FILE, self.config_template_func(node)) + node.account.create_file(self.LOG4J_CONFIG_FILE, self.render('connect_log4j.properties', log_file=self.LOG_FILE)) + if self.connector_config_templates: + raise DucktapeError("Config files are not valid in distributed mode, submit connectors via the REST API") + + self.logger.info("Starting Kafka Connect distributed process on " + str(node.account)) + if self.startup_mode == self.STARTUP_MODE_LOAD: + self.start_and_wait_to_load_plugins(node, 'distributed', '') + elif self.startup_mode == self.STARTUP_MODE_INSTANT: + self.start_and_return_immediately(node, 'distributed', '') + elif self.startup_mode == self.STARTUP_MODE_LISTEN: + self.start_and_wait_to_start_listening(node, 'distributed', '') + else: + # The default mode is to wait until the complete startup of the worker + self.start_and_wait_to_join_group(node, 'distributed', '') + + if len(self.pids(node)) == 0: + raise RuntimeError("No process ids recorded") + + +class ErrorTolerance(object): + ALL = "all" + NONE = "none" + + +class ConnectRestError(RuntimeError): + def __init__(self, status, msg, url): + self.status = status + self.message = msg + self.url = url + + def __unicode__(self): + return "Kafka Connect REST call failed: returned " + self.status + " for " + self.url + ". Response: " + self.message + + +class VerifiableConnector(object): + def messages(self): + """ + Collect and parse the logs from Kafka Connect nodes. Return a list containing all parsed JSON messages generated by + this source. 
+ """ + self.logger.info("Collecting messages from log of %s %s", type(self).__name__, self.name) + records = [] + for node in self.cc.nodes: + for line in node.account.ssh_capture('cat ' + self.cc.STDOUT_FILE): + try: + data = json.loads(line) + except ValueError: + self.logger.debug("Ignoring unparseable line: %s", line) + continue + # Filter to only ones matching our name to support multiple verifiable producers + if data['name'] != self.name: + continue + data['node'] = node + records.append(data) + return records + + def stop(self): + self.logger.info("Destroying connector %s %s", type(self).__name__, self.name) + self.cc.delete_connector(self.name) + + +class VerifiableSource(VerifiableConnector): + """ + Helper class for running a verifiable source connector on a Kafka Connect cluster and analyzing the output. + """ + + def __init__(self, cc, name="verifiable-source", tasks=1, topic="verifiable", throughput=1000): + self.cc = cc + self.logger = self.cc.logger + self.name = name + self.tasks = tasks + self.topic = topic + self.throughput = throughput + + def committed_messages(self): + return list(filter(lambda m: 'committed' in m and m['committed'], self.messages())) + + def sent_messages(self): + return list(filter(lambda m: 'committed' not in m or not m['committed'], self.messages())) + + def start(self): + self.logger.info("Creating connector VerifiableSourceConnector %s", self.name) + self.cc.create_connector({ + 'name': self.name, + 'connector.class': 'org.apache.kafka.connect.tools.VerifiableSourceConnector', + 'tasks.max': self.tasks, + 'topic': self.topic, + 'throughput': self.throughput + }) + + +class VerifiableSink(VerifiableConnector): + """ + Helper class for running a verifiable sink connector on a Kafka Connect cluster and analyzing the output. 
+ """ + + def __init__(self, cc, name="verifiable-sink", tasks=1, topics=["verifiable"]): + self.cc = cc + self.logger = self.cc.logger + self.name = name + self.tasks = tasks + self.topics = topics + + def flushed_messages(self): + return list(filter(lambda m: 'flushed' in m and m['flushed'], self.messages())) + + def received_messages(self): + return list(filter(lambda m: 'flushed' not in m or not m['flushed'], self.messages())) + + def start(self): + self.logger.info("Creating connector VerifiableSinkConnector %s", self.name) + self.cc.create_connector({ + 'name': self.name, + 'connector.class': 'org.apache.kafka.connect.tools.VerifiableSinkConnector', + 'tasks.max': self.tasks, + 'topics': ",".join(self.topics) + }) + +class MockSink(object): + + def __init__(self, cc, topics, mode=None, delay_sec=10, name="mock-sink"): + self.cc = cc + self.logger = self.cc.logger + self.name = name + self.mode = mode + self.delay_sec = delay_sec + self.topics = topics + + def start(self): + self.logger.info("Creating connector MockSinkConnector %s", self.name) + self.cc.create_connector({ + 'name': self.name, + 'connector.class': 'org.apache.kafka.connect.tools.MockSinkConnector', + 'tasks.max': 1, + 'topics': ",".join(self.topics), + 'mock_mode': self.mode, + 'delay_ms': self.delay_sec * 1000 + }) + +class MockSource(object): + + def __init__(self, cc, mode=None, delay_sec=10, name="mock-source"): + self.cc = cc + self.logger = self.cc.logger + self.name = name + self.mode = mode + self.delay_sec = delay_sec + + def start(self): + self.logger.info("Creating connector MockSourceConnector %s", self.name) + self.cc.create_connector({ + 'name': self.name, + 'connector.class': 'org.apache.kafka.connect.tools.MockSourceConnector', + 'tasks.max': 1, + 'mock_mode': self.mode, + 'delay_ms': self.delay_sec * 1000 + }) diff --git a/tests/kafkatest/services/console_consumer.py b/tests/kafkatest/services/console_consumer.py new file mode 100644 index 0000000..32e7145 --- /dev/null +++ b/tests/kafkatest/services/console_consumer.py @@ -0,0 +1,340 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +from ducktape.cluster.remoteaccount import RemoteCommandError +from ducktape.services.background_thread import BackgroundThreadService +from ducktape.utils.util import wait_until + +from kafkatest.directory_layout.kafka_path import KafkaPathResolverMixin +from kafkatest.services.monitor.jmx import JmxMixin, JmxTool +from kafkatest.version import DEV_BRANCH, LATEST_0_8_2, LATEST_0_9, LATEST_0_10_0, V_0_10_0_0, V_0_11_0_0, V_2_0_0 +from kafkatest.services.kafka.util import fix_opts_for_new_jvm + +""" +The console consumer is a tool that reads data from Kafka and outputs it to standard output. 
+""" + + +class ConsoleConsumer(KafkaPathResolverMixin, JmxMixin, BackgroundThreadService): + # Root directory for persistent output + PERSISTENT_ROOT = "/mnt/console_consumer" + STDOUT_CAPTURE = os.path.join(PERSISTENT_ROOT, "console_consumer.stdout") + STDERR_CAPTURE = os.path.join(PERSISTENT_ROOT, "console_consumer.stderr") + LOG_DIR = os.path.join(PERSISTENT_ROOT, "logs") + LOG_FILE = os.path.join(LOG_DIR, "console_consumer.log") + LOG4J_CONFIG = os.path.join(PERSISTENT_ROOT, "tools-log4j.properties") + CONFIG_FILE = os.path.join(PERSISTENT_ROOT, "console_consumer.properties") + JMX_TOOL_LOG = os.path.join(PERSISTENT_ROOT, "jmx_tool.log") + JMX_TOOL_ERROR_LOG = os.path.join(PERSISTENT_ROOT, "jmx_tool.err.log") + + logs = { + "consumer_stdout": { + "path": STDOUT_CAPTURE, + "collect_default": False}, + "consumer_stderr": { + "path": STDERR_CAPTURE, + "collect_default": False}, + "consumer_log": { + "path": LOG_FILE, + "collect_default": True}, + "jmx_log": { + "path" : JMX_TOOL_LOG, + "collect_default": False}, + "jmx_err_log": { + "path": JMX_TOOL_ERROR_LOG, + "collect_default": False} + } + + def __init__(self, context, num_nodes, kafka, topic, group_id="test-consumer-group", new_consumer=True, + message_validator=None, from_beginning=True, consumer_timeout_ms=None, version=DEV_BRANCH, + client_id="console-consumer", print_key=False, jmx_object_names=None, jmx_attributes=None, + enable_systest_events=False, stop_timeout_sec=35, print_timestamp=False, print_partition=False, + isolation_level="read_uncommitted", jaas_override_variables=None, + kafka_opts_override="", client_prop_file_override="", consumer_properties={}, + wait_until_partitions_assigned=False): + """ + Args: + context: standard context + num_nodes: number of nodes to use (this should be 1) + kafka: kafka service + topic: consume from this topic + new_consumer: use new Kafka consumer if True + message_validator: function which returns message or None + from_beginning: consume from beginning if True, else from the end + consumer_timeout_ms: corresponds to consumer.timeout.ms. consumer process ends if time between + successively consumed messages exceeds this timeout. Setting this and + waiting for the consumer to stop is a pretty good way to consume all messages + in a topic. + print_timestamp if True, print each message's timestamp as well + print_key if True, print each message's key as well + print_partition if True, print each message's partition as well + enable_systest_events if True, console consumer will print additional lifecycle-related information + only available in 0.10.0 and later. + stop_timeout_sec After stopping a node, wait up to stop_timeout_sec for the node to stop, + and the corresponding background thread to finish successfully. + isolation_level How to handle transactional messages. 
+ jaas_override_variables A dict of variables to be used in the jaas.conf template file + kafka_opts_override Override parameters of the KAFKA_OPTS environment variable + client_prop_file_override Override client.properties file used by the consumer + consumer_properties A dict of values to pass in as --consumer-property key=value + """ + JmxMixin.__init__(self, num_nodes=num_nodes, jmx_object_names=jmx_object_names, jmx_attributes=(jmx_attributes or []), + root=ConsoleConsumer.PERSISTENT_ROOT) + BackgroundThreadService.__init__(self, context, num_nodes) + self.kafka = kafka + self.new_consumer = new_consumer + self.group_id = group_id + self.args = { + 'topic': topic, + } + + self.consumer_timeout_ms = consumer_timeout_ms + for node in self.nodes: + node.version = version + + self.from_beginning = from_beginning + self.message_validator = message_validator + self.messages_consumed = {idx: [] for idx in range(1, num_nodes + 1)} + self.clean_shutdown_nodes = set() + self.client_id = client_id + self.print_key = print_key + self.print_partition = print_partition + self.log_level = "TRACE" + self.stop_timeout_sec = stop_timeout_sec + + self.isolation_level = isolation_level + self.enable_systest_events = enable_systest_events + if self.enable_systest_events: + # Only available in 0.10.0 and up + assert version >= V_0_10_0_0 + + self.print_timestamp = print_timestamp + self.jaas_override_variables = jaas_override_variables or {} + self.kafka_opts_override = kafka_opts_override + self.client_prop_file_override = client_prop_file_override + self.consumer_properties = consumer_properties + self.wait_until_partitions_assigned = wait_until_partitions_assigned + + + def prop_file(self, node): + """Return a string which can be used to create a configuration file appropriate for the given node.""" + # Process client configuration + prop_file = self.render('console_consumer.properties') + if hasattr(node, "version") and node.version <= LATEST_0_8_2: + # in 0.8.2.X and earlier, console consumer does not have --timeout-ms option + # instead, we have to pass it through the config file + prop_file += "\nconsumer.timeout.ms=%s\n" % str(self.consumer_timeout_ms) + + # Add security properties to the config. If security protocol is not specified, + # use the default in the template properties. 
+ self.security_config = self.kafka.security_config.client_config(prop_file, node, self.jaas_override_variables) + self.security_config.setup_node(node) + + prop_file += str(self.security_config) + return prop_file + + + def start_cmd(self, node): + """Return the start command appropriate for the given node.""" + args = self.args.copy() + args['broker_list'] = self.kafka.bootstrap_servers(self.security_config.security_protocol) + if not self.new_consumer: + args['zk_connect'] = self.kafka.zk_connect_setting() + args['stdout'] = ConsoleConsumer.STDOUT_CAPTURE + args['stderr'] = ConsoleConsumer.STDERR_CAPTURE + args['log_dir'] = ConsoleConsumer.LOG_DIR + args['log4j_config'] = ConsoleConsumer.LOG4J_CONFIG + args['config_file'] = ConsoleConsumer.CONFIG_FILE + args['stdout'] = ConsoleConsumer.STDOUT_CAPTURE + args['jmx_port'] = self.jmx_port + args['console_consumer'] = self.path.script("kafka-console-consumer.sh", node) + + if self.kafka_opts_override: + args['kafka_opts'] = "\"%s\"" % self.kafka_opts_override + else: + args['kafka_opts'] = self.security_config.kafka_opts + + cmd = fix_opts_for_new_jvm(node) + cmd += "export JMX_PORT=%(jmx_port)s; " \ + "export LOG_DIR=%(log_dir)s; " \ + "export KAFKA_LOG4J_OPTS=\"-Dlog4j.configuration=file:%(log4j_config)s\"; " \ + "export KAFKA_OPTS=%(kafka_opts)s; " \ + "%(console_consumer)s " \ + "--topic %(topic)s " \ + "--consumer.config %(config_file)s " % args + + if self.new_consumer: + assert node.version.consumer_supports_bootstrap_server(), \ + "new_consumer is only supported if version >= 0.9.0.0, version %s" % str(node.version) + if node.version <= LATEST_0_10_0: + cmd += " --new-consumer" + cmd += " --bootstrap-server %(broker_list)s" % args + if node.version >= V_0_11_0_0: + cmd += " --isolation-level %s" % self.isolation_level + else: + assert node.version < V_2_0_0, \ + "new_consumer==false is only supported if version < 2.0.0, version %s" % str(node.version) + cmd += " --zookeeper %(zk_connect)s" % args + + if self.from_beginning: + cmd += " --from-beginning" + + if self.consumer_timeout_ms is not None: + # version 0.8.X and below do not support --timeout-ms option + # This will be added in the properties file instead + if node.version > LATEST_0_8_2: + cmd += " --timeout-ms %s" % self.consumer_timeout_ms + + if self.print_timestamp: + cmd += " --property print.timestamp=true" + + if self.print_key: + cmd += " --property print.key=true" + + if self.print_partition: + cmd += " --property print.partition=true" + + # LoggingMessageFormatter was introduced after 0.9 + if node.version > LATEST_0_9: + cmd += " --formatter kafka.tools.LoggingMessageFormatter" + + if self.enable_systest_events: + # enable systest events is only available in 0.10.0 and later + # check the assertion here as well, in case node.version has been modified + assert node.version >= V_0_10_0_0 + cmd += " --enable-systest-events" + + if self.consumer_properties is not None: + for k, v in self.consumer_properties.items(): + cmd += " --consumer-property %s=%s" % (k, v) + + cmd += " 2>> %(stderr)s | tee -a %(stdout)s &" % args + return cmd + + def pids(self, node): + return node.account.java_pids(self.java_class_name()) + + def alive(self, node): + return len(self.pids(node)) > 0 + + def _worker(self, idx, node): + node.account.ssh("mkdir -p %s" % ConsoleConsumer.PERSISTENT_ROOT, allow_fail=False) + + # Create and upload config file + self.logger.info("console_consumer.properties:") + + self.security_config = self.kafka.security_config.client_config(node=node, + 
jaas_override_variables=self.jaas_override_variables) + self.security_config.setup_node(node) + + if self.client_prop_file_override: + prop_file = self.client_prop_file_override + else: + prop_file = self.prop_file(node) + + self.logger.info(prop_file) + node.account.create_file(ConsoleConsumer.CONFIG_FILE, prop_file) + + # Create and upload log properties + log_config = self.render('tools_log4j.properties', log_file=ConsoleConsumer.LOG_FILE) + node.account.create_file(ConsoleConsumer.LOG4J_CONFIG, log_config) + + # Run and capture output + cmd = self.start_cmd(node) + self.logger.debug("Console consumer %d command: %s", idx, cmd) + + consumer_output = node.account.ssh_capture(cmd, allow_fail=False) + + with self.lock: + self.logger.debug("collecting following jmx objects: %s", self.jmx_object_names) + self.start_jmx_tool(idx, node) + + for line in consumer_output: + msg = line.strip() + if msg == "shutdown_complete": + # Note that we can only rely on shutdown_complete message if running 0.10.0 or greater + if node in self.clean_shutdown_nodes: + raise Exception("Unexpected shutdown event from consumer, already shutdown. Consumer index: %d" % idx) + self.clean_shutdown_nodes.add(node) + else: + if self.message_validator is not None: + msg = self.message_validator(msg) + if msg is not None: + self.messages_consumed[idx].append(msg) + + with self.lock: + self.read_jmx_output(idx, node) + + def _wait_until_partitions_assigned(self, node, timeout_sec=60): + if self.jmx_object_names is not None: + raise Exception("'wait_until_partitions_assigned' is not supported while using 'jmx_object_names'/'jmx_attributes'") + jmx_tool = JmxTool(self.context, jmx_poll_ms=100) + jmx_tool.jmx_object_names = ["kafka.consumer:type=consumer-coordinator-metrics,client-id=%s" % self.client_id] + jmx_tool.jmx_attributes = ["assigned-partitions"] + jmx_tool.assigned_partitions_jmx_attr = "kafka.consumer:type=consumer-coordinator-metrics,client-id=%s:assigned-partitions" % self.client_id + jmx_tool.start_jmx_tool(self.idx(node), node) + assigned_partitions_jmx_attr = "kafka.consumer:type=consumer-coordinator-metrics,client-id=%s:assigned-partitions" % self.client_id + + def read_and_check(): + jmx_tool.read_jmx_output(self.idx(node), node) + return assigned_partitions_jmx_attr in jmx_tool.maximum_jmx_value + + wait_until(lambda: read_and_check(), + timeout_sec=timeout_sec, + backoff_sec=.5, + err_msg="consumer was not assigned partitions within %d seconds" % timeout_sec) + + def start_node(self, node): + BackgroundThreadService.start_node(self, node) + if self.wait_until_partitions_assigned: + self._wait_until_partitions_assigned(node) + + def stop_node(self, node): + self.logger.info("%s Stopping node %s" % (self.__class__.__name__, str(node.account))) + node.account.kill_java_processes(self.java_class_name(), + clean_shutdown=True, allow_fail=True) + + stopped = self.wait_node(node, timeout_sec=self.stop_timeout_sec) + assert stopped, "Node %s: did not stop within the specified timeout of %s seconds" % \ + (str(node.account), str(self.stop_timeout_sec)) + + def clean_node(self, node): + if self.alive(node): + self.logger.warn("%s %s was still alive at cleanup time. Killing forcefully..." 
% + (self.__class__.__name__, node.account)) + JmxMixin.clean_node(self, node) + node.account.kill_java_processes(self.java_class_name(), clean_shutdown=False, allow_fail=True) + node.account.ssh("rm -rf %s" % ConsoleConsumer.PERSISTENT_ROOT, allow_fail=False) + self.security_config.clean_node(node) + + def java_class_name(self): + return "ConsoleConsumer" + + def has_log_message(self, node, message): + try: + node.account.ssh("grep '%s' %s" % (message, ConsoleConsumer.LOG_FILE)) + except RemoteCommandError: + return False + return True + + def wait_for_offset_reset(self, node, topic, num_partitions): + for partition in range(num_partitions): + message = "Resetting offset for partition %s-%d" % (topic, partition) + wait_until(lambda: self.has_log_message(node, message), + timeout_sec=60, + err_msg="Offset not reset for partition %s-%d" % (topic, partition)) + diff --git a/tests/kafkatest/services/consumer_property.py b/tests/kafkatest/services/consumer_property.py new file mode 100644 index 0000000..0a9756a --- /dev/null +++ b/tests/kafkatest/services/consumer_property.py @@ -0,0 +1,21 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Define Consumer configuration property names here. +""" + +GROUP_INSTANCE_ID = "group.instance.id" +SESSION_TIMEOUT_MS = "session.timeout.ms" diff --git a/tests/kafkatest/services/delegation_tokens.py b/tests/kafkatest/services/delegation_tokens.py new file mode 100644 index 0000000..34da16b --- /dev/null +++ b/tests/kafkatest/services/delegation_tokens.py @@ -0,0 +1,102 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os.path +from kafkatest.directory_layout.kafka_path import KafkaPathResolverMixin + +""" +Delegation tokens is a tool to manage the lifecycle of delegation tokens. +All commands are executed on a secured Kafka node reusing its generated jaas.conf and krb5.conf. 
+""" + +class DelegationTokens(KafkaPathResolverMixin): + def __init__(self, kafka, context): + self.client_properties_content = """ +security.protocol=SASL_PLAINTEXT +sasl.kerberos.service.name=kafka +""" + self.context = context + self.command_path = self.path.script("kafka-delegation-tokens.sh") + self.kafka_opts = "KAFKA_OPTS=\"-Djava.security.auth.login.config=/mnt/security/jaas.conf " \ + "-Djava.security.krb5.conf=/mnt/security/krb5.conf\" " + self.kafka = kafka + self.bootstrap_server = " --bootstrap-server " + self.kafka.bootstrap_servers('SASL_PLAINTEXT') + self.base_cmd = self.kafka_opts + self.command_path + self.bootstrap_server + self.client_prop_path = os.path.join(self.kafka.PERSISTENT_ROOT, "client.properties") + self.jaas_deleg_conf_path = os.path.join(self.kafka.PERSISTENT_ROOT, "jaas_deleg.conf") + self.token_hmac_path = os.path.join(self.kafka.PERSISTENT_ROOT, "deleg_token_hmac.out") + self.delegation_token_out = os.path.join(self.kafka.PERSISTENT_ROOT, "delegation_token.out") + self.expire_delegation_token_out = os.path.join(self.kafka.PERSISTENT_ROOT, "expire_delegation_token.out") + self.renew_delegation_token_out = os.path.join(self.kafka.PERSISTENT_ROOT, "renew_delegation_token.out") + + self.node = self.kafka.nodes[0] + + def generate_delegation_token(self, maxlifetimeperiod=-1): + self.node.account.create_file(self.client_prop_path, self.client_properties_content) + + cmd = self.base_cmd + " --create" \ + " --max-life-time-period %s" \ + " --command-config %s > %s" % (maxlifetimeperiod, self.client_prop_path, self.delegation_token_out) + self.node.account.ssh(cmd, allow_fail=False) + + def expire_delegation_token(self, hmac): + cmd = self.base_cmd + " --expire" \ + " --expiry-time-period -1" \ + " --hmac %s" \ + " --command-config %s > %s" % (hmac, self.client_prop_path, self.expire_delegation_token_out) + self.node.account.ssh(cmd, allow_fail=False) + + def renew_delegation_token(self, hmac, renew_time_period=-1): + cmd = self.base_cmd + " --renew" \ + " --renew-time-period %s" \ + " --hmac %s" \ + " --command-config %s > %s" \ + % (renew_time_period, hmac, self.client_prop_path, self.renew_delegation_token_out) + return self.node.account.ssh_capture(cmd, allow_fail=False) + + def create_jaas_conf_with_delegation_token(self): + dt = self.parse_delegation_token_out() + jaas_deleg_content = """ +KafkaClient { + org.apache.kafka.common.security.scram.ScramLoginModule required + username="%s" + password="%s" + tokenauth=true; +}; +""" % (dt["tokenid"], dt["hmac"]) + self.node.account.create_file(self.jaas_deleg_conf_path, jaas_deleg_content) + + return jaas_deleg_content + + def token_hmac(self): + dt = self.parse_delegation_token_out() + return dt["hmac"] + + def parse_delegation_token_out(self): + cmd = "tail -1 %s" % self.delegation_token_out + + output_iter = self.node.account.ssh_capture(cmd, allow_fail=False) + output = "" + for line in output_iter: + output += line + + tokenid, hmac, owner, renewers, issuedate, expirydate, maxdate = output.split() + return {"tokenid" : tokenid, + "hmac" : hmac, + "owner" : owner, + "renewers" : renewers, + "issuedate" : issuedate, + "expirydate" :expirydate, + "maxdate" : maxdate} \ No newline at end of file diff --git a/tests/kafkatest/services/kafka/__init__.py b/tests/kafkatest/services/kafka/__init__.py new file mode 100644 index 0000000..5ae879d --- /dev/null +++ b/tests/kafkatest/services/kafka/__init__.py @@ -0,0 +1,18 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license 
agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .kafka import KafkaService +from .util import TopicPartition +from .config import KafkaConfig diff --git a/tests/kafkatest/services/kafka/config.py b/tests/kafkatest/services/kafka/config.py new file mode 100644 index 0000000..da5b4a2 --- /dev/null +++ b/tests/kafkatest/services/kafka/config.py @@ -0,0 +1,51 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from . import config_property + + +class KafkaConfig(dict): + """A dictionary-like container class which allows for definition of overridable default values, + which is also capable of "rendering" itself as a useable server.properties file. + """ + + DEFAULTS = { + config_property.SOCKET_RECEIVE_BUFFER_BYTES: 65536, + config_property.LOG_DIRS: "/mnt/kafka/kafka-data-logs-1,/mnt/kafka/kafka-data-logs-2", + config_property.METADATA_LOG_DIR: "/mnt/kafka/kafka-metadata-logs", + config_property.METADATA_LOG_SEGMENT_BYTES: str(9*1024*1024), # 9 MB + config_property.METADATA_LOG_BYTES_BETWEEN_SNAPSHOTS: str(10*1024*1024), # 10 MB + config_property.METADATA_LOG_RETENTION_BYTES: str(10*1024*1024), # 10 MB + config_property.METADATA_LOG_SEGMENT_MS: str(1*60*1000) # one minute + } + + def __init__(self, **kwargs): + super(KafkaConfig, self).__init__(**kwargs) + + # Set defaults + for key, val in self.DEFAULTS.items(): + if key not in self: + self[key] = val + + def render(self): + """Render self as a series of lines key=val\n, and do so in a consistent order. """ + keys = [k for k in self.keys()] + keys.sort() + + s = "" + for k in keys: + s += "%s=%s\n" % (k, str(self[k])) + return s + diff --git a/tests/kafkatest/services/kafka/config_property.py b/tests/kafkatest/services/kafka/config_property.py new file mode 100644 index 0000000..de11422 --- /dev/null +++ b/tests/kafkatest/services/kafka/config_property.py @@ -0,0 +1,204 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. 
+# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Define Kafka configuration property names here. +""" + +BROKER_ID = "broker.id" +NODE_ID = "node.id" +FIRST_BROKER_PORT = 9092 +FIRST_CONTROLLER_PORT = FIRST_BROKER_PORT + 500 +FIRST_CONTROLLER_ID = 3001 +CLUSTER_ID = "I2eXt9rvSnyhct8BYmW6-w" +PORT = "port" +ADVERTISED_HOSTNAME = "advertised.host.name" + +NUM_NETWORK_THREADS = "num.network.threads" +NUM_IO_THREADS = "num.io.threads" +SOCKET_SEND_BUFFER_BYTES = "socket.send.buffer.bytes" +SOCKET_RECEIVE_BUFFER_BYTES = "socket.receive.buffer.bytes" +SOCKET_REQUEST_MAX_BYTES = "socket.request.max.bytes" +LOG_DIRS = "log.dirs" +NUM_PARTITIONS = "num.partitions" +NUM_RECOVERY_THREADS_PER_DATA_DIR = "num.recovery.threads.per.data.dir" + +LOG_RETENTION_HOURS = "log.retention.hours" +LOG_SEGMENT_BYTES = "log.segment.bytes" +LOG_RETENTION_CHECK_INTERVAL_MS = "log.retention.check.interval.ms" +LOG_RETENTION_MS = "log.retention.ms" +LOG_CLEANER_ENABLE = "log.cleaner.enable" + +METADATA_LOG_DIR = "metadata.log.dir" +METADATA_LOG_SEGMENT_BYTES = "metadata.log.segment.bytes" +METADATA_LOG_RETENTION_BYTES = "metadata.max.retention.bytes" +METADATA_LOG_SEGMENT_MS = "metadata.log.segment.ms" +METADATA_LOG_BYTES_BETWEEN_SNAPSHOTS = "metadata.log.max.record.bytes.between.snapshots" + +AUTO_CREATE_TOPICS_ENABLE = "auto.create.topics.enable" + +ZOOKEEPER_CONNECT = "zookeeper.connect" +ZOOKEEPER_SSL_CLIENT_ENABLE = "zookeeper.ssl.client.enable" +ZOOKEEPER_CLIENT_CNXN_SOCKET = "zookeeper.clientCnxnSocket" +ZOOKEEPER_CONNECTION_TIMEOUT_MS = "zookeeper.connection.timeout.ms" +ZOOKEEPER_SESSION_TIMEOUT_MS = "zookeeper.session.timeout.ms" +INTER_BROKER_PROTOCOL_VERSION = "inter.broker.protocol.version" +MESSAGE_FORMAT_VERSION = "log.message.format.version" +MESSAGE_TIMESTAMP_TYPE = "message.timestamp.type" +THROTTLING_REPLICATION_RATE_LIMIT = "replication.quota.throttled.rate" + +LOG_FLUSH_INTERVAL_MESSAGE = "log.flush.interval.messages" +REPLICA_HIGHWATERMARK_CHECKPOINT_INTERVAL_MS = "replica.high.watermark.checkpoint.interval.ms" +LOG_ROLL_TIME_MS = "log.roll.ms" +OFFSETS_TOPIC_NUM_PARTITIONS = "offsets.topic.num.partitions" + +DELEGATION_TOKEN_MAX_LIFETIME_MS="delegation.token.max.lifetime.ms" +DELEGATION_TOKEN_EXPIRY_TIME_MS="delegation.token.expiry.time.ms" +DELEGATION_TOKEN_SECRET_KEY="delegation.token.secret.key" +SASL_ENABLED_MECHANISMS="sasl.enabled.mechanisms" + + +""" +From KafkaConfig.scala + + /** ********* General Configuration ***********/ + val MaxReservedBrokerIdProp = "reserved.broker.max.id" + val MessageMaxBytesProp = "message.max.bytes" + val NumIoThreadsProp = "num.io.threads" + val BackgroundThreadsProp = "background.threads" + val QueuedMaxRequestsProp = "queued.max.requests" + /** ********* Socket Server Configuration ***********/ + val PortProp = "port" + val HostNameProp = "host.name" + val ListenersProp = "listeners" + val AdvertisedPortProp = "advertised.port" + val AdvertisedListenersProp = "advertised.listeners" + val SocketSendBufferBytesProp = 
"socket.send.buffer.bytes" + val SocketReceiveBufferBytesProp = "socket.receive.buffer.bytes" + val SocketRequestMaxBytesProp = "socket.request.max.bytes" + val MaxConnectionsPerIpProp = "max.connections.per.ip" + val MaxConnectionsPerIpOverridesProp = "max.connections.per.ip.overrides" + val ConnectionsMaxIdleMsProp = "connections.max.idle.ms" + /** ********* Log Configuration ***********/ + val NumPartitionsProp = "num.partitions" + val LogDirsProp = "log.dirs" + val LogDirProp = "log.dir" + val LogSegmentBytesProp = "log.segment.bytes" + + val LogRollTimeMillisProp = "log.roll.ms" + val LogRollTimeHoursProp = "log.roll.hours" + + val LogRollTimeJitterMillisProp = "log.roll.jitter.ms" + val LogRollTimeJitterHoursProp = "log.roll.jitter.hours" + + val LogRetentionTimeMillisProp = "log.retention.ms" + val LogRetentionTimeMinutesProp = "log.retention.minutes" + val LogRetentionTimeHoursProp = "log.retention.hours" + + val LogRetentionBytesProp = "log.retention.bytes" + val LogCleanupIntervalMsProp = "log.retention.check.interval.ms" + val LogCleanupPolicyProp = "log.cleanup.policy" + val LogCleanerThreadsProp = "log.cleaner.threads" + val LogCleanerIoMaxBytesPerSecondProp = "log.cleaner.io.max.bytes.per.second" + val LogCleanerDedupeBufferSizeProp = "log.cleaner.dedupe.buffer.size" + val LogCleanerIoBufferSizeProp = "log.cleaner.io.buffer.size" + val LogCleanerDedupeBufferLoadFactorProp = "log.cleaner.io.buffer.load.factor" + val LogCleanerBackoffMsProp = "log.cleaner.backoff.ms" + val LogCleanerMinCleanRatioProp = "log.cleaner.min.cleanable.ratio" + val LogCleanerEnableProp = "log.cleaner.enable" + val LogCleanerDeleteRetentionMsProp = "log.cleaner.delete.retention.ms" + val LogIndexSizeMaxBytesProp = "log.index.size.max.bytes" + val LogIndexIntervalBytesProp = "log.index.interval.bytes" + val LogFlushIntervalMessagesProp = "log.flush.interval.messages" + val LogDeleteDelayMsProp = "log.segment.delete.delay.ms" + val LogFlushSchedulerIntervalMsProp = "log.flush.scheduler.interval.ms" + val LogFlushIntervalMsProp = "log.flush.interval.ms" + val LogFlushOffsetCheckpointIntervalMsProp = "log.flush.offset.checkpoint.interval.ms" + val LogPreAllocateProp = "log.preallocate" + val NumRecoveryThreadsPerDataDirProp = "num.recovery.threads.per.data.dir" + val MinInSyncReplicasProp = "min.insync.replicas" + /** ********* Replication configuration ***********/ + val ControllerSocketTimeoutMsProp = "controller.socket.timeout.ms" + val DefaultReplicationFactorProp = "default.replication.factor" + val ReplicaLagTimeMaxMsProp = "replica.lag.time.max.ms" + val ReplicaSocketTimeoutMsProp = "replica.socket.timeout.ms" + val ReplicaSocketReceiveBufferBytesProp = "replica.socket.receive.buffer.bytes" + val ReplicaFetchMaxBytesProp = "replica.fetch.max.bytes" + val ReplicaFetchWaitMaxMsProp = "replica.fetch.wait.max.ms" + val ReplicaFetchMinBytesProp = "replica.fetch.min.bytes" + val ReplicaFetchBackoffMsProp = "replica.fetch.backoff.ms" + val NumReplicaFetchersProp = "num.replica.fetchers" + val ReplicaHighWatermarkCheckpointIntervalMsProp = "replica.high.watermark.checkpoint.interval.ms" + val FetchPurgatoryPurgeIntervalRequestsProp = "fetch.purgatory.purge.interval.requests" + val ProducerPurgatoryPurgeIntervalRequestsProp = "producer.purgatory.purge.interval.requests" + val AutoLeaderRebalanceEnableProp = "auto.leader.rebalance.enable" + val LeaderImbalancePerBrokerPercentageProp = "leader.imbalance.per.broker.percentage" + val LeaderImbalanceCheckIntervalSecondsProp = 
"leader.imbalance.check.interval.seconds" + val UncleanLeaderElectionEnableProp = "unclean.leader.election.enable" + val InterBrokerSecurityProtocolProp = "security.inter.broker.protocol" + val InterBrokerProtocolVersionProp = "inter.broker.protocol.version" + /** ********* Controlled shutdown configuration ***********/ + val ControlledShutdownMaxRetriesProp = "controlled.shutdown.max.retries" + val ControlledShutdownRetryBackoffMsProp = "controlled.shutdown.retry.backoff.ms" + val ControlledShutdownEnableProp = "controlled.shutdown.enable" + /** ********* Consumer coordinator configuration ***********/ + val ConsumerMinSessionTimeoutMsProp = "consumer.min.session.timeout.ms" + val ConsumerMaxSessionTimeoutMsProp = "consumer.max.session.timeout.ms" + /** ********* Offset management configuration ***********/ + val OffsetMetadataMaxSizeProp = "offset.metadata.max.bytes" + val OffsetsLoadBufferSizeProp = "offsets.load.buffer.size" + val OffsetsTopicReplicationFactorProp = "offsets.topic.replication.factor" + val OffsetsTopicPartitionsProp = "offsets.topic.num.partitions" + val OffsetsTopicSegmentBytesProp = "offsets.topic.segment.bytes" + val OffsetsTopicCompressionCodecProp = "offsets.topic.compression.codec" + val OffsetsRetentionMinutesProp = "offsets.retention.minutes" + val OffsetsRetentionCheckIntervalMsProp = "offsets.retention.check.interval.ms" + val OffsetCommitTimeoutMsProp = "offsets.commit.timeout.ms" + val OffsetCommitRequiredAcksProp = "offsets.commit.required.acks" + /** ********* Quota Configuration ***********/ + val ProducerQuotaBytesPerSecondDefaultProp = "quota.producer.default" + val ConsumerQuotaBytesPerSecondDefaultProp = "quota.consumer.default" + val NumQuotaSamplesProp = "quota.window.num" + val QuotaWindowSizeSecondsProp = "quota.window.size.seconds" + + val DeleteTopicEnableProp = "delete.topic.enable" + val CompressionTypeProp = "compression.type" + + /** ********* Kafka Metrics Configuration ***********/ + val MetricSampleWindowMsProp = CommonClientConfigs.METRICS_SAMPLE_WINDOW_MS_CONFIG + val MetricNumSamplesProp: String = CommonClientConfigs.METRICS_NUM_SAMPLES_CONFIG + val MetricReporterClassesProp: String = CommonClientConfigs.METRIC_REPORTER_CLASSES_CONFIG + + /** ********* SSL Configuration ****************/ + val PrincipalBuilderClassProp = SSLConfigs.PRINCIPAL_BUILDER_CLASS_CONFIG + val SSLProtocolProp = SSLConfigs.SSL_PROTOCOL_CONFIG + val SSLProviderProp = SSLConfigs.SSL_PROVIDER_CONFIG + val SSLCipherSuitesProp = SSLConfigs.SSL_CIPHER_SUITES_CONFIG + val SSLEnabledProtocolsProp = SSLConfigs.SSL_ENABLED_PROTOCOLS_CONFIG + val SSLKeystoreTypeProp = SSLConfigs.SSL_KEYSTORE_TYPE_CONFIG + val SSLKeystoreLocationProp = SSLConfigs.SSL_KEYSTORE_LOCATION_CONFIG + val SSLKeystorePasswordProp = SSLConfigs.SSL_KEYSTORE_PASSWORD_CONFIG + val SSLKeyPasswordProp = SSLConfigs.SSL_KEY_PASSWORD_CONFIG + val SSLTruststoreTypeProp = SSLConfigs.SSL_TRUSTSTORE_TYPE_CONFIG + val SSLTruststoreLocationProp = SSLConfigs.SSL_TRUSTSTORE_LOCATION_CONFIG + val SSLTruststorePasswordProp = SSLConfigs.SSL_TRUSTSTORE_PASSWORD_CONFIG + val SSLKeyManagerAlgorithmProp = SSLConfigs.SSL_KEYMANAGER_ALGORITHM_CONFIG + val SSLTrustManagerAlgorithmProp = SSLConfigs.SSL_TRUSTMANAGER_ALGORITHM_CONFIG + val SSLEndpointIdentificationAlgorithmProp = SSLConfigs.SSL_ENDPOINT_IDENTIFICATION_ALGORITHM_CONFIG + val SSLSecureRandomImplementationProp = SSLConfigs.SSL_SECURE_RANDOM_IMPLEMENTATION_CONFIG + val SSLClientAuthProp = SSLConfigs.SSL_CLIENT_AUTH_CONFIG +""" + + diff --git 
a/tests/kafkatest/services/kafka/kafka.py b/tests/kafkatest/services/kafka/kafka.py new file mode 100644 index 0000000..55b5b7b --- /dev/null +++ b/tests/kafkatest/services/kafka/kafka.py @@ -0,0 +1,1733 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os.path +import re +import signal +import time + +from ducktape.services.service import Service +from ducktape.utils.util import wait_until +from ducktape.cluster.remoteaccount import RemoteCommandError + +from .config import KafkaConfig +from kafkatest.directory_layout.kafka_path import KafkaPathResolverMixin +from kafkatest.services.kafka import config_property, quorum +from kafkatest.services.monitor.jmx import JmxMixin +from kafkatest.services.security.minikdc import MiniKdc +from kafkatest.services.security.listener_security_config import ListenerSecurityConfig +from kafkatest.services.security.security_config import SecurityConfig +from kafkatest.version import DEV_BRANCH +from kafkatest.version import KafkaVersion +from kafkatest.services.kafka.util import fix_opts_for_new_jvm + + +class KafkaListener: + + def __init__(self, name, port_number, security_protocol, open=False, sasl_mechanism = None): + self.name = name + self.port_number = port_number + self.security_protocol = security_protocol + self.open = open + self.sasl_mechanism = sasl_mechanism + + def listener(self): + return "%s://:%s" % (self.name, str(self.port_number)) + + def advertised_listener(self, node): + return "%s://%s:%s" % (self.name, node.account.hostname, str(self.port_number)) + + def listener_security_protocol(self): + return "%s:%s" % (self.name, self.security_protocol) + +class KafkaService(KafkaPathResolverMixin, JmxMixin, Service): + """ + Ducktape system test service for Brokers and KRaft Controllers + + Metadata Quorums + ---------------- + Kafka can use either ZooKeeper or a KRaft Controller quorum for its + metadata. See the kafkatest.services.kafka.quorum.ServiceQuorumInfo + class for details. 
+ + Attributes + ---------- + + quorum_info : kafkatest.services.kafka.quorum.ServiceQuorumInfo + Information about the service and its metadata quorum + num_nodes_broker_role : int + The number of nodes in the service that include 'broker' + in process.roles (0 when using ZooKeeper) + num_nodes_controller_role : int + The number of nodes in the service that include 'controller' + in process.roles (0 when using ZooKeeper) + controller_quorum : KafkaService + None when using ZooKeeper, otherwise the Kafka service for the + co-located case or the remote controller quorum service + instance for the remote case + remote_controller_quorum : KafkaService + None for the co-located case or when using ZooKeeper, otherwise + the remote controller quorum service instance + + Kafka Security Protocols + ------------------------ + The security protocol advertised to clients and the inter-broker + security protocol can be set in the constructor and can be changed + afterwards as well. Set these attributes to make changes; they + take effect when starting each node: + + security_protocol : str + default PLAINTEXT + client_sasl_mechanism : str + default GSSAPI, ignored unless using SASL_PLAINTEXT or SASL_SSL + interbroker_security_protocol : str + default PLAINTEXT + interbroker_sasl_mechanism : str + default GSSAPI, ignored unless using SASL_PLAINTEXT or SASL_SSL + + ZooKeeper + --------- + Create an instance of ZookeeperService when metadata_quorum is ZK + (ZK is the default if metadata_quorum is not a test parameter). + + KRaft Quorums + ------------- + Set metadata_quorum accordingly (to COLOCATED_KRAFT or REMOTE_KRAFT). + Do not instantiate a ZookeeperService instance. + + Starting Kafka will cause any remote controller quorum to + automatically start first. Explicitly stopping Kafka does not stop + any remote controller quorum, but Ducktape will stop both when + tearing down the test (it will stop Kafka first). + + KRaft Security Protocols + ------------------------ + The broker-to-controller and inter-controller security protocols + will both initially be set to the inter-broker security protocol. + The broker-to-controller and inter-controller security protocols + must be identical for the co-located case (an exception will be + thrown when trying to start the service if they are not identical). + The broker-to-controller and inter-controller security protocols + can differ in the remote case. + + Set these attributes for the co-located case. Changes take effect + when starting each node: + + controller_security_protocol : str + default PLAINTEXT + controller_sasl_mechanism : str + default GSSAPI, ignored unless using SASL_PLAINTEXT or SASL_SSL + intercontroller_security_protocol : str + default PLAINTEXT + intercontroller_sasl_mechanism : str + default GSSAPI, ignored unless using SASL_PLAINTEXT or SASL_SSL + + Set the same attributes for the remote case (changes take effect + when starting each quorum node), but you must first obtain the + service instance for the remote quorum via one of the + 'controller_quorum' or 'remote_controller_quorum' attributes as + defined above.
+ + """ + PERSISTENT_ROOT = "/mnt/kafka" + STDOUT_STDERR_CAPTURE = os.path.join(PERSISTENT_ROOT, "server-start-stdout-stderr.log") + LOG4J_CONFIG = os.path.join(PERSISTENT_ROOT, "kafka-log4j.properties") + # Logs such as controller.log, server.log, etc all go here + OPERATIONAL_LOG_DIR = os.path.join(PERSISTENT_ROOT, "kafka-operational-logs") + OPERATIONAL_LOG_INFO_DIR = os.path.join(OPERATIONAL_LOG_DIR, "info") + OPERATIONAL_LOG_DEBUG_DIR = os.path.join(OPERATIONAL_LOG_DIR, "debug") + # Kafka log segments etc go here + DATA_LOG_DIR_PREFIX = os.path.join(PERSISTENT_ROOT, "kafka-data-logs") + DATA_LOG_DIR_1 = "%s-1" % (DATA_LOG_DIR_PREFIX) + DATA_LOG_DIR_2 = "%s-2" % (DATA_LOG_DIR_PREFIX) + CONFIG_FILE = os.path.join(PERSISTENT_ROOT, "kafka.properties") + METADATA_LOG_DIR = os.path.join (PERSISTENT_ROOT, "kafka-metadata-logs") + METADATA_SNAPSHOT_SEARCH_STR = "%s/__cluster_metadata-0/*.checkpoint" % METADATA_LOG_DIR + METADATA_FIRST_LOG = "%s/__cluster_metadata-0/00000000000000000000.log" % METADATA_LOG_DIR + # Kafka Authorizer + ACL_AUTHORIZER = "kafka.security.authorizer.AclAuthorizer" + HEAP_DUMP_FILE = os.path.join(PERSISTENT_ROOT, "kafka_heap_dump.bin") + INTERBROKER_LISTENER_NAME = 'INTERNAL' + JAAS_CONF_PROPERTY = "java.security.auth.login.config=/mnt/security/jaas.conf" + ADMIN_CLIENT_AS_BROKER_JAAS_CONF_PROPERTY = "java.security.auth.login.config=/mnt/security/admin_client_as_broker_jaas.conf" + KRB5_CONF = "java.security.krb5.conf=/mnt/security/krb5.conf" + SECURITY_PROTOCOLS = [SecurityConfig.PLAINTEXT, SecurityConfig.SSL, SecurityConfig.SASL_PLAINTEXT, SecurityConfig.SASL_SSL] + + logs = { + "kafka_server_start_stdout_stderr": { + "path": STDOUT_STDERR_CAPTURE, + "collect_default": True}, + "kafka_operational_logs_info": { + "path": OPERATIONAL_LOG_INFO_DIR, + "collect_default": True}, + "kafka_operational_logs_debug": { + "path": OPERATIONAL_LOG_DEBUG_DIR, + "collect_default": False}, + "kafka_data_1": { + "path": DATA_LOG_DIR_1, + "collect_default": False}, + "kafka_data_2": { + "path": DATA_LOG_DIR_2, + "collect_default": False}, + "kafka_heap_dump_file": { + "path": HEAP_DUMP_FILE, + "collect_default": True} + } + + def __init__(self, context, num_nodes, zk, security_protocol=SecurityConfig.PLAINTEXT, + interbroker_security_protocol=SecurityConfig.PLAINTEXT, + client_sasl_mechanism=SecurityConfig.SASL_MECHANISM_GSSAPI, interbroker_sasl_mechanism=SecurityConfig.SASL_MECHANISM_GSSAPI, + authorizer_class_name=None, topics=None, version=DEV_BRANCH, jmx_object_names=None, + jmx_attributes=None, zk_connect_timeout=18000, zk_session_timeout=18000, server_prop_overrides=None, zk_chroot=None, + zk_client_secure=False, + listener_security_config=ListenerSecurityConfig(), per_node_server_prop_overrides=None, + extra_kafka_opts="", tls_version=None, + remote_kafka=None, + controller_num_nodes_override=0, + allow_zk_with_kraft=False, + ): + """ + :param context: test context + :param int num_nodes: the number of nodes in the service. There are 4 possibilities: + 1) Zookeeper quorum: + The number of brokers is defined by this parameter. + The broker.id values will be 1..num_nodes. + 2) Co-located KRaft quorum: + The number of nodes having a broker role is defined by this parameter. + The node.id values will be 1..num_nodes + The number of nodes having a controller role will by default be 1, 3, or 5 depending on num_nodes + (1 if num_nodes < 3, otherwise 3 if num_nodes < 5, otherwise 5). 
This calculation + can be overridden via controller_num_nodes_override, which must be between 1 and num_nodes, + inclusive, when non-zero. Here are some possibilities: + num_nodes = 1: + broker having node.id=1: process.roles=broker+controller + num_nodes = 2: + broker having node.id=1: process.roles=broker+controller + broker having node.id=2: process.roles=broker + num_nodes = 3: + broker having node.id=1: process.roles=broker+controller + broker having node.id=2: process.roles=broker+controller + broker having node.id=3: process.roles=broker+controller + num_nodes = 3, controller_num_nodes_override = 1 + broker having node.id=1: process.roles=broker+controller + broker having node.id=2: process.roles=broker + broker having node.id=3: process.roles=broker + 3) Remote KRaft quorum when instantiating the broker service: + The number of nodes, all of which will have process.roles=broker, is defined by this parameter. + The node.id values will be 1..num_nodes + 4) Remote KRaft quorum when instantiating the controller service: + The number of nodes, all of which will have process.roles=controller, is defined by this parameter. + The node.id values will be 3001..(3000 + num_nodes) + The value passed in is determined by the broker service when that is instantiated, and it uses the + same algorithm as described above: 1, 3, or 5 unless controller_num_nodes_override is provided. + :param ZookeeperService zk: + :param dict topics: which topics to create automatically + :param str security_protocol: security protocol for clients to use + :param str tls_version: version of the TLS protocol. + :param str interbroker_security_protocol: security protocol to use for broker-to-broker (and KRaft controller-to-controller) communication + :param str client_sasl_mechanism: sasl mechanism for clients to use + :param str interbroker_sasl_mechanism: sasl mechanism to use for broker-to-broker (and to-controller) communication + :param str authorizer_class_name: which authorizer class to use + :param str version: which kafka version to use.
Defaults to "dev" branch + :param jmx_object_names: + :param jmx_attributes: + :param int zk_connect_timeout: + :param int zk_session_timeout: + :param list[list] server_prop_overrides: overrides for kafka.properties file + e.g.: [["config1", "true"], ["config2", "1000"]] + :param str zk_chroot: + :param bool zk_client_secure: connect to ZooKeeper over secure client port (TLS) when True + :param ListenerSecurityConfig listener_security_config: listener config to use + :param dict per_node_server_prop_overrides: overrides for kafka.properties file keyed by 1-based node number + e.g.: {1: [["config1", "true"], ["config2", "1000"]], 2: [["config1", "false"], ["config2", "0"]]} + :param str extra_kafka_opts: jvm args to add to KAFKA_OPTS variable + :param KafkaService remote_kafka: process.roles=controller for this cluster when not None; ignored when using ZooKeeper + :param int controller_num_nodes_override: the number of nodes to use in the cluster, instead of 5, 3, or 1 based on num_nodes, if positive, not using ZooKeeper, and remote_kafka is not None; ignored otherwise + :param bool allow_zk_with_kraft: if True, then allow a KRaft broker or controller to also use ZooKeeper + + """ + + self.zk = zk + self.remote_kafka = remote_kafka + self.allow_zk_with_kraft = allow_zk_with_kraft + self.quorum_info = quorum.ServiceQuorumInfo(self, context) + self.controller_quorum = None # will define below if necessary + self.remote_controller_quorum = None # will define below if necessary + + if num_nodes < 1: + raise Exception("Must set a positive number of nodes: %i" % num_nodes) + self.num_nodes_broker_role = 0 + self.num_nodes_controller_role = 0 + + if self.quorum_info.using_kraft: + if self.quorum_info.has_brokers: + num_nodes_broker_role = num_nodes + if self.quorum_info.has_controllers: + self.num_nodes_controller_role = self.num_kraft_controllers(num_nodes_broker_role, controller_num_nodes_override) + if self.remote_kafka: + raise Exception("Must not specify remote Kafka service with co-located Controller quorum") + else: + self.num_nodes_controller_role = num_nodes + if not self.remote_kafka: + raise Exception("Must specify remote Kafka service when instantiating remote Controller service (should not happen)") + + # Initially use the inter-broker security protocol for both + # broker-to-controller and inter-controller communication. Both can be explicitly changed later if desired. + # Note, however, that the two must be the same if the controller quorum is co-located with the + # brokers. Different security protocols for the two are only supported with a remote controller quorum. + self.controller_security_protocol = interbroker_security_protocol + self.controller_sasl_mechanism = interbroker_sasl_mechanism + self.intercontroller_security_protocol = interbroker_security_protocol + self.intercontroller_sasl_mechanism = interbroker_sasl_mechanism + + # Ducktape tears down services in the reverse order in which they are created, + # so create a service for the remote controller quorum (if we need one) first, before + # invoking Service.__init__(), so that Ducktape will tear down the quorum last; otherwise + # Ducktape will tear down the controller quorum first, which could lead to problems in + # Kafka and delays in tearing it down (and who knows what else -- it's simply better + # to correctly tear down Kafka first, before tearing down the remote controller).
+ if self.quorum_info.has_controllers: + self.controller_quorum = self + else: + num_remote_controller_nodes = self.num_kraft_controllers(num_nodes, controller_num_nodes_override) + self.remote_controller_quorum = KafkaService( + context, num_remote_controller_nodes, self.zk, security_protocol=self.controller_security_protocol, + interbroker_security_protocol=self.intercontroller_security_protocol, + client_sasl_mechanism=self.controller_sasl_mechanism, interbroker_sasl_mechanism=self.intercontroller_sasl_mechanism, + authorizer_class_name=authorizer_class_name, version=version, jmx_object_names=jmx_object_names, + jmx_attributes=jmx_attributes, + listener_security_config=listener_security_config, + extra_kafka_opts=extra_kafka_opts, tls_version=tls_version, + remote_kafka=self, allow_zk_with_kraft=self.allow_zk_with_kraft, + server_prop_overrides=server_prop_overrides + ) + self.controller_quorum = self.remote_controller_quorum + + Service.__init__(self, context, num_nodes) + JmxMixin.__init__(self, num_nodes=num_nodes, jmx_object_names=jmx_object_names, jmx_attributes=(jmx_attributes or []), + root=KafkaService.PERSISTENT_ROOT) + + self.security_protocol = security_protocol + self.tls_version = tls_version + self.client_sasl_mechanism = client_sasl_mechanism + self.topics = topics + self.minikdc = None + self.concurrent_start = True # start concurrently by default + self.authorizer_class_name = authorizer_class_name + self.zk_set_acl = False + if server_prop_overrides is None: + self.server_prop_overrides = [] + else: + self.server_prop_overrides = server_prop_overrides + if per_node_server_prop_overrides is None: + self.per_node_server_prop_overrides = {} + else: + self.per_node_server_prop_overrides = per_node_server_prop_overrides + self.log_level = "DEBUG" + self.zk_chroot = zk_chroot + self.zk_client_secure = zk_client_secure + self.listener_security_config = listener_security_config + self.extra_kafka_opts = extra_kafka_opts + + # + # In a heavily loaded and not very fast machine, it is + # sometimes necessary to give more time for the zk client + # to have its session established, especially if the client + # is authenticating and waiting for the SaslAuthenticated + # in addition to the SyncConnected event. + # + # The default value for zookeeper.connect.timeout.ms is + # 2 seconds and here we increase it to 5 seconds, but + # it can be overridden by setting the corresponding parameter + # for this constructor. + self.zk_connect_timeout = zk_connect_timeout + + # Also allow the session timeout to be provided explicitly, + # primarily so that test cases can depend on it when waiting + # e.g. brokers to deregister after a hard kill. 
+ self.zk_session_timeout = zk_session_timeout + + broker_only_port_mappings = { + KafkaService.INTERBROKER_LISTENER_NAME: + KafkaListener(KafkaService.INTERBROKER_LISTENER_NAME, config_property.FIRST_BROKER_PORT + 7, None, False) + } + controller_only_port_mappings = {} + for idx, sec_protocol in enumerate(KafkaService.SECURITY_PROTOCOLS): + name_for_controller = self.controller_listener_name(sec_protocol) + broker_only_port_mappings[sec_protocol] = KafkaListener(sec_protocol, config_property.FIRST_BROKER_PORT + idx, sec_protocol, False) + controller_only_port_mappings[name_for_controller] = KafkaListener(name_for_controller, config_property.FIRST_CONTROLLER_PORT + idx, sec_protocol, False) + + if self.quorum_info.using_zk or self.quorum_info.has_brokers and not self.quorum_info.has_controllers: # ZK or KRaft broker-only + self.port_mappings = broker_only_port_mappings + elif self.quorum_info.has_brokers_and_controllers: # KRaft broker+controller + self.port_mappings = broker_only_port_mappings.copy() + self.port_mappings.update(controller_only_port_mappings) + else: # KRaft controller-only + self.port_mappings = controller_only_port_mappings + + self.interbroker_listener = None + if self.quorum_info.using_zk or self.quorum_info.has_brokers: + self.setup_interbroker_listener(interbroker_security_protocol, self.listener_security_config.use_separate_interbroker_listener) + self.interbroker_sasl_mechanism = interbroker_sasl_mechanism + self._security_config = None + + for node in self.nodes: + node_quorum_info = quorum.NodeQuorumInfo(self.quorum_info, node) + + node.version = version + zk_broker_configs = { + config_property.PORT: config_property.FIRST_BROKER_PORT, + config_property.BROKER_ID: self.idx(node), + config_property.ZOOKEEPER_CONNECTION_TIMEOUT_MS: zk_connect_timeout, + config_property.ZOOKEEPER_SESSION_TIMEOUT_MS: zk_session_timeout + } + kraft_broker_configs = { + config_property.PORT: config_property.FIRST_BROKER_PORT, + config_property.NODE_ID: self.idx(node), + } + kraft_broker_plus_zk_configs = kraft_broker_configs.copy() + kraft_broker_plus_zk_configs.update(zk_broker_configs) + kraft_broker_plus_zk_configs.pop(config_property.BROKER_ID) + controller_only_configs = { + config_property.NODE_ID: self.idx(node) + config_property.FIRST_CONTROLLER_ID - 1, + } + kraft_controller_plus_zk_configs = controller_only_configs.copy() + kraft_controller_plus_zk_configs.update(zk_broker_configs) + kraft_controller_plus_zk_configs.pop(config_property.BROKER_ID) + if node_quorum_info.service_quorum_info.using_zk: + node.config = KafkaConfig(**zk_broker_configs) + elif not node_quorum_info.has_broker_role: # KRaft controller-only role + if self.zk: + node.config = KafkaConfig(**kraft_controller_plus_zk_configs) + else: + node.config = KafkaConfig(**controller_only_configs) + else: # KRaft broker-only role or combined broker+controller roles + if self.zk: + node.config = KafkaConfig(**kraft_broker_plus_zk_configs) + else: + node.config = KafkaConfig(**kraft_broker_configs) + self.colocated_nodes_started = 0 + self.nodes_to_start = self.nodes + + def num_kraft_controllers(self, num_nodes_broker_role, controller_num_nodes_override): + if controller_num_nodes_override < 0: + raise Exception("controller_num_nodes_override must not be negative: %i" % controller_num_nodes_override) + if controller_num_nodes_override > num_nodes_broker_role and self.quorum_info.quorum_type == quorum.colocated_kraft: + raise Exception("controller_num_nodes_override must not exceed the service's node count in the 
co-located case: %i > %i" % + (controller_num_nodes_override, num_nodes_broker_role)) + if controller_num_nodes_override: + return controller_num_nodes_override + if num_nodes_broker_role < 3: + return 1 + if num_nodes_broker_role < 5: + return 3 + return 5 + + def set_version(self, version): + for node in self.nodes: + node.version = version + + def controller_listener_name(self, security_protocol_name): + return "CONTROLLER_%s" % security_protocol_name + + @property + def interbroker_security_protocol(self): + # TODO: disentangle interbroker and intercontroller protocol information + return self.interbroker_listener.security_protocol if self.quorum_info.using_zk or self.quorum_info.has_brokers else self.intercontroller_security_protocol + + # this is required for backwards compatibility - there are a lot of tests that set this property explicitly + # meaning 'use one of the existing listeners that match given security protocol, do not use custom listener' + @interbroker_security_protocol.setter + def interbroker_security_protocol(self, security_protocol): + self.setup_interbroker_listener(security_protocol, use_separate_listener=False) + + def setup_interbroker_listener(self, security_protocol, use_separate_listener=False): + self.listener_security_config.use_separate_interbroker_listener = use_separate_listener + + if self.listener_security_config.use_separate_interbroker_listener: + # do not close existing port here since it is not used exclusively for interbroker communication + self.interbroker_listener = self.port_mappings[KafkaService.INTERBROKER_LISTENER_NAME] + self.interbroker_listener.security_protocol = security_protocol + else: + # close dedicated interbroker port, so it's not dangling in 'listeners' and 'advertised.listeners' + self.close_port(KafkaService.INTERBROKER_LISTENER_NAME) + self.interbroker_listener = self.port_mappings[security_protocol] + + @property + def security_config(self): + if not self._security_config: + # we will later change the security protocols to PLAINTEXT if this is a remote KRaft controller case since + # those security protocols are irrelevant there and we don't want to falsely indicate the use of SASL or TLS + security_protocol_to_use=self.security_protocol + interbroker_security_protocol_to_use=self.interbroker_security_protocol + # determine uses/serves controller sasl mechanisms + serves_controller_sasl_mechanism=None + serves_intercontroller_sasl_mechanism=None + uses_controller_sasl_mechanism=None + if self.quorum_info.has_brokers: + if self.controller_quorum.controller_security_protocol in SecurityConfig.SASL_SECURITY_PROTOCOLS: + uses_controller_sasl_mechanism = self.controller_quorum.controller_sasl_mechanism + if self.quorum_info.has_controllers: + if self.intercontroller_security_protocol in SecurityConfig.SASL_SECURITY_PROTOCOLS: + serves_intercontroller_sasl_mechanism = self.intercontroller_sasl_mechanism + uses_controller_sasl_mechanism = self.intercontroller_sasl_mechanism # won't change from above in co-located case + if self.controller_security_protocol in SecurityConfig.SASL_SECURITY_PROTOCOLS: + serves_controller_sasl_mechanism = self.controller_sasl_mechanism + # determine if KRaft uses TLS + kraft_tls = False + if self.quorum_info.has_brokers and not self.quorum_info.has_controllers: + # KRaft broker only + kraft_tls = self.controller_quorum.controller_security_protocol in SecurityConfig.SSL_SECURITY_PROTOCOLS + if self.quorum_info.has_controllers: + # remote or co-located KRaft controller + kraft_tls = 
self.controller_security_protocol in SecurityConfig.SSL_SECURITY_PROTOCOLS \ + or self.intercontroller_security_protocol in SecurityConfig.SSL_SECURITY_PROTOCOLS + # clear irrelevant security protocols of SASL/TLS implications for remote controller quorum case + if self.quorum_info.has_controllers and not self.quorum_info.has_brokers: + security_protocol_to_use=SecurityConfig.PLAINTEXT + interbroker_security_protocol_to_use=SecurityConfig.PLAINTEXT + + self._security_config = SecurityConfig(self.context, security_protocol_to_use, interbroker_security_protocol_to_use, + zk_sasl=self.zk.zk_sasl if self.quorum_info.using_zk else False, zk_tls=self.zk_client_secure, + client_sasl_mechanism=self.client_sasl_mechanism, + interbroker_sasl_mechanism=self.interbroker_sasl_mechanism, + listener_security_config=self.listener_security_config, + tls_version=self.tls_version, + serves_controller_sasl_mechanism=serves_controller_sasl_mechanism, + serves_intercontroller_sasl_mechanism=serves_intercontroller_sasl_mechanism, + uses_controller_sasl_mechanism=uses_controller_sasl_mechanism, + kraft_tls=kraft_tls) + # Ensure we have the correct client security protocol and SASL mechanism because they may have been mutated + self._security_config.properties['security.protocol'] = self.security_protocol + self._security_config.properties['sasl.mechanism'] = self.client_sasl_mechanism + # Ensure we have the right inter-broker security protocol because it may have been mutated + # since we cached our security config (ignore if this is a remote KRaft controller quorum case; the + # inter-broker security protocol is not used there). + if (self.quorum_info.using_zk or self.quorum_info.has_brokers): + # in case inter-broker SASL mechanism has changed without changing the inter-broker security protocol + self._security_config.properties['sasl.mechanism.inter.broker.protocol'] = self.interbroker_sasl_mechanism + if self._security_config.interbroker_security_protocol != self.interbroker_security_protocol: + self._security_config.interbroker_security_protocol = self.interbroker_security_protocol + self._security_config.calc_has_sasl() + self._security_config.calc_has_ssl() + for port in self.port_mappings.values(): + if port.open: + self._security_config.enable_security_protocol(port.security_protocol, port.sasl_mechanism) + if self.quorum_info.using_zk: + if self.zk.zk_sasl: + self._security_config.enable_sasl() + self._security_config.zk_sasl = self.zk.zk_sasl + if self.zk_client_secure: + self._security_config.enable_ssl() + self._security_config.zk_tls = self.zk_client_secure + return self._security_config + + def open_port(self, listener_name): + self.port_mappings[listener_name].open = True + + def close_port(self, listener_name): + self.port_mappings[listener_name].open = False + + def start_minikdc_if_necessary(self, add_principals=""): + has_sasl = self.security_config.has_sasl + if has_sasl: + if self.minikdc is None: + other_service = self.remote_kafka if self.remote_kafka else self.controller_quorum if self.quorum_info.using_kraft else None + if not other_service or not other_service.minikdc: + nodes_for_kdc = self.nodes.copy() + if other_service and other_service != self: + nodes_for_kdc += other_service.nodes + self.minikdc = MiniKdc(self.context, nodes_for_kdc, extra_principals = add_principals) + self.minikdc.start() + else: + self.minikdc = None + if self.quorum_info.using_kraft: + self.controller_quorum.minikdc = None + if self.remote_kafka: + self.remote_kafka.minikdc = None + + def alive(self, node): + 
return len(self.pids(node)) > 0 + + def start(self, add_principals="", nodes_to_skip=[], timeout_sec=60): + """ + Start the Kafka broker and wait until it registers its ID in ZooKeeper + Startup will be skipped for any nodes in nodes_to_skip. These nodes can be started later via add_broker + """ + if self.quorum_info.using_zk and self.zk_client_secure and not self.zk.zk_client_secure_port: + raise Exception("Unable to start Kafka: TLS to Zookeeper requested but Zookeeper secure port not enabled") + + if not all([node in self.nodes for node in nodes_to_skip]): + raise Exception("nodes_to_skip should be a subset of this service's nodes") + + if self.quorum_info.has_brokers_and_controllers and ( + self.controller_security_protocol != self.intercontroller_security_protocol or + self.controller_security_protocol in SecurityConfig.SASL_SECURITY_PROTOCOLS and self.controller_sasl_mechanism != self.intercontroller_sasl_mechanism): + # This is not supported because both the broker and the controller take the first entry from + # controller.listener.names and the value from sasl.mechanism.controller.protocol; + # they share a single config, so they must both see/use identical values. + raise Exception("Co-located KRaft Brokers (%s/%s) and Controllers (%s/%s) cannot talk to Controllers via different security protocols" % + (self.controller_security_protocol, self.controller_sasl_mechanism, + self.intercontroller_security_protocol, self.intercontroller_sasl_mechanism)) + if self.quorum_info.using_zk or self.quorum_info.has_brokers: + self.open_port(self.security_protocol) + self.interbroker_listener.open = True + # we have to wait to decide whether to open the controller port(s) + # because it could be dependent on the particular node in the + # co-located case where the number of controllers could be less + # than the number of nodes in the service + + self.start_minikdc_if_necessary(add_principals) + + # save the nodes we want to start in a member variable so we know which nodes to start and which to skip + # in start_node + self.nodes_to_start = [node for node in self.nodes if node not in nodes_to_skip] + + if self.quorum_info.using_zk: + self._ensure_zk_chroot() + + if self.remote_controller_quorum: + self.remote_controller_quorum.start() + Service.start(self) + if self.concurrent_start: + # We didn't wait while starting each individual node, so wait for them all now + for node in self.nodes_to_start: + with node.account.monitor_log(KafkaService.STDOUT_STDERR_CAPTURE) as monitor: + monitor.offset = 0 + self.wait_for_start(node, monitor, timeout_sec) + + if self.quorum_info.using_zk: + self.logger.info("Waiting for brokers to register at ZK") + + expected_broker_ids = set(self.nodes_to_start) + wait_until(lambda: {node for node in self.nodes_to_start if self.is_registered(node)} == expected_broker_ids, + timeout_sec=30, backoff_sec=1, err_msg="Kafka servers didn't register at ZK within 30 seconds") + + # Create topics if necessary + if self.topics is not None: + for topic, topic_cfg in self.topics.items(): + if topic_cfg is None: + topic_cfg = {} + + topic_cfg["topic"] = topic + self.create_topic(topic_cfg) + self.concurrent_start = False # in case it was True and this method was invoked directly instead of via start_concurrently() + + def start_concurrently(self, add_principals="", timeout_sec=60): + self.concurrent_start = True # ensure it is True in case it has been explicitly disabled elsewhere + self.start(add_principals = add_principals, timeout_sec=timeout_sec) + self.concurrent_start = 
False + + def add_broker(self, node): + """ + Starts an individual node. add_broker should only be used for nodes skipped during initial kafka service startup + """ + if node in self.nodes_to_start: + raise Exception("Add broker should only be used for nodes that haven't already been started") + + self.logger.debug(self.who_am_i() + ": killing processes and attempting to clean up before starting") + # Added precaution - kill running processes, clean persistent files + # try/except for each step, since each of these steps may fail if there are no processes + # to kill or no files to remove + try: + self.stop_node(node) + except Exception: + pass + + try: + self.clean_node(node) + except Exception: + pass + + if node not in self.nodes_to_start: + self.nodes_to_start += [node] + self.logger.debug("%s: starting node" % self.who_am_i(node)) + # ensure we wait for the broker to start by setting concurrent start to False for the invocation of start_node() + orig_concurrent_start = self.concurrent_start + self.concurrent_start = False + self.start_node(node) + self.concurrent_start = orig_concurrent_start + wait_until(lambda: self.is_registered(node), 30, 1) + + def _ensure_zk_chroot(self): + self.logger.info("Ensuring zk_chroot %s exists", self.zk_chroot) + if self.zk_chroot: + if not self.zk_chroot.startswith('/'): + raise Exception("Zookeeper chroot must start with '/' but found " + self.zk_chroot) + + parts = self.zk_chroot.split('/')[1:] + for i in range(len(parts)): + self.zk.create('/' + '/'.join(parts[:i+1])) + + def set_protocol_and_port(self, node): + listeners = [] + advertised_listeners = [] + protocol_map = [] + + controller_listener_names = self.controller_listener_name_list(node) + + for port in self.port_mappings.values(): + if port.open: + listeners.append(port.listener()) + if not port.name in controller_listener_names: + advertised_listeners.append(port.advertised_listener(node)) + protocol_map.append(port.listener_security_protocol()) + controller_sec_protocol = self.remote_controller_quorum.controller_security_protocol if self.remote_controller_quorum \ + else self.controller_security_protocol if self.quorum_info.has_brokers_and_controllers and not quorum.NodeQuorumInfo(self.quorum_info, node).has_controller_role \ + else None + if controller_sec_protocol: + protocol_map.append("%s:%s" % (self.controller_listener_name(controller_sec_protocol), controller_sec_protocol)) + + self.listeners = ','.join(listeners) + self.advertised_listeners = ','.join(advertised_listeners) + self.listener_security_protocol_map = ','.join(protocol_map) + if self.quorum_info.using_zk or self.quorum_info.has_brokers: + self.interbroker_bootstrap_servers = self.__bootstrap_servers(self.interbroker_listener, True) + + def prop_file(self, node): + self.set_protocol_and_port(node) + + #load template configs as dictionary + config_template = self.render('kafka.properties', node=node, broker_id=self.idx(node), + security_config=self.security_config, num_nodes=self.num_nodes, + listener_security_config=self.listener_security_config) + + configs = dict( l.rstrip().split('=', 1) for l in config_template.split('\n') + if not l.startswith("#") and "=" in l ) + + #load specific test override configs + override_configs = KafkaConfig(**node.config) + if self.quorum_info.using_zk or self.quorum_info.has_brokers: + override_configs[config_property.ADVERTISED_HOSTNAME] = node.account.hostname + if self.quorum_info.using_zk or self.zk: + override_configs[config_property.ZOOKEEPER_CONNECT] = self.zk_connect_setting() + 
if self.zk_client_secure: + override_configs[config_property.ZOOKEEPER_SSL_CLIENT_ENABLE] = 'true' + override_configs[config_property.ZOOKEEPER_CLIENT_CNXN_SOCKET] = 'org.apache.zookeeper.ClientCnxnSocketNetty' + else: + override_configs[config_property.ZOOKEEPER_SSL_CLIENT_ENABLE] = 'false' + + for prop in self.server_prop_overrides: + override_configs[prop[0]] = prop[1] + + for prop in self.per_node_server_prop_overrides.get(self.idx(node), []): + override_configs[prop[0]] = prop[1] + + #update template configs with test override configs + configs.update(override_configs) + + prop_file = self.render_configs(configs) + return prop_file + + def render_configs(self, configs): + """Render self as a series of lines key=val\n, and do so in a consistent order. """ + keys = [k for k in configs.keys()] + keys.sort() + + s = "" + for k in keys: + s += "%s=%s\n" % (k, str(configs[k])) + return s + + def start_cmd(self, node): + cmd = "export JMX_PORT=%d; " % self.jmx_port + cmd += "export KAFKA_LOG4J_OPTS=\"-Dlog4j.configuration=file:%s\"; " % self.LOG4J_CONFIG + heap_kafka_opts = "-XX:+HeapDumpOnOutOfMemoryError -XX:HeapDumpPath=%s" % \ + self.logs["kafka_heap_dump_file"]["path"] + security_kafka_opts = self.security_config.kafka_opts.strip('\"') + + cmd += fix_opts_for_new_jvm(node) + cmd += "export KAFKA_OPTS=\"%s %s %s\"; " % (heap_kafka_opts, security_kafka_opts, self.extra_kafka_opts) + cmd += "%s %s 1>> %s 2>> %s &" % \ + (self.path.script("kafka-server-start.sh", node), + KafkaService.CONFIG_FILE, + KafkaService.STDOUT_STDERR_CAPTURE, + KafkaService.STDOUT_STDERR_CAPTURE) + return cmd + + def controller_listener_name_list(self, node): + if self.quorum_info.using_zk: + return [] + broker_to_controller_listener_name = self.controller_listener_name(self.controller_quorum.controller_security_protocol) + # Brokers always use the first controller listener, so include a second, inter-controller listener if and only if: + # 1) the node is a controller node + # 2) the inter-controller listener name differs from the broker-to-controller listener name + return [broker_to_controller_listener_name, self.controller_listener_name(self.controller_quorum.intercontroller_security_protocol)] \ + if (quorum.NodeQuorumInfo(self.quorum_info, node).has_controller_role and + self.controller_quorum.intercontroller_security_protocol != self.controller_quorum.controller_security_protocol) \ + else [broker_to_controller_listener_name] + + def start_node(self, node, timeout_sec=60): + if node not in self.nodes_to_start: + return + node.account.mkdirs(KafkaService.PERSISTENT_ROOT) + + self.node_quorum_info = quorum.NodeQuorumInfo(self.quorum_info, node) + if self.quorum_info.has_controllers: + for controller_listener in self.controller_listener_name_list(node): + if self.node_quorum_info.has_controller_role: + self.open_port(controller_listener) + else: # co-located case where node doesn't have a controller + self.close_port(controller_listener) + + self.security_config.setup_node(node) + if self.quorum_info.using_zk or self.quorum_info.has_brokers: # TODO: SCRAM currently unsupported for controller quorum + self.maybe_setup_broker_scram_credentials(node) + + if self.quorum_info.using_kraft: + # define controller.quorum.voters text + security_protocol_to_use = self.controller_quorum.controller_security_protocol + first_node_id = 1 if self.quorum_info.has_brokers_and_controllers else config_property.FIRST_CONTROLLER_ID + self.controller_quorum_voters = ','.join(["%s@%s:%s" % + (self.controller_quorum.idx(node) + 
first_node_id - 1, + node.account.hostname, + config_property.FIRST_CONTROLLER_PORT + + KafkaService.SECURITY_PROTOCOLS.index(security_protocol_to_use)) + for node in self.controller_quorum.nodes[:self.controller_quorum.num_nodes_controller_role]]) + # define controller.listener.names + self.controller_listener_names = ','.join(self.controller_listener_name_list(node)) + # define sasl.mechanism.controller.protocol to match remote quorum if one exists + if self.remote_controller_quorum: + self.controller_sasl_mechanism = self.remote_controller_quorum.controller_sasl_mechanism + + prop_file = self.prop_file(node) + self.logger.info("kafka.properties:") + self.logger.info(prop_file) + node.account.create_file(KafkaService.CONFIG_FILE, prop_file) + node.account.create_file(self.LOG4J_CONFIG, self.render('log4j.properties', log_dir=KafkaService.OPERATIONAL_LOG_DIR)) + + if self.quorum_info.using_kraft: + # format log directories if necessary + kafka_storage_script = self.path.script("kafka-storage.sh", node) + cmd = "%s format --ignore-formatted --config %s --cluster-id %s" % (kafka_storage_script, KafkaService.CONFIG_FILE, config_property.CLUSTER_ID) + self.logger.info("Running log directory format command...\n%s" % cmd) + node.account.ssh(cmd) + + cmd = self.start_cmd(node) + self.logger.debug("Attempting to start KafkaService %s on %s with command: %s" %\ + ("concurrently" if self.concurrent_start else "serially", str(node.account), cmd)) + if self.node_quorum_info.has_controller_role and self.node_quorum_info.has_broker_role: + self.colocated_nodes_started += 1 + if self.concurrent_start: + node.account.ssh(cmd) # and then don't wait for the startup message + else: + with node.account.monitor_log(KafkaService.STDOUT_STDERR_CAPTURE) as monitor: + node.account.ssh(cmd) + self.wait_for_start(node, monitor, timeout_sec) + + def wait_for_start(self, node, monitor, timeout_sec=60): + # Kafka 1.0.0 and higher don't have a space between "Kafka" and "Server" + monitor.wait_until("Kafka\s*Server.*started", timeout_sec=timeout_sec, backoff_sec=.25, + err_msg="Kafka server didn't finish startup in %d seconds" % timeout_sec) + + if self.quorum_info.using_zk or self.quorum_info.has_brokers: # TODO: SCRAM currently unsupported for controller quorum + # Credentials for inter-broker communication are created before starting Kafka. + # Client credentials are created after starting Kafka so that both loading of + # existing credentials from ZK and dynamic update of credentials in Kafka are tested. + # We use the admin client and connect as the broker user when creating the client (non-broker) credentials + # if Kafka supports KIP-554, otherwise we use ZooKeeper. 
+ self.maybe_setup_client_scram_credentials(node) + + self.start_jmx_tool(self.idx(node), node) + if len(self.pids(node)) == 0: + raise Exception("No process ids recorded on node %s" % node.account.hostname) + + def pids(self, node): + """Return process ids associated with running processes on the given node.""" + try: + cmd = "jcmd | grep -e %s | awk '{print $1}'" % self.java_class_name() + pid_arr = [pid for pid in node.account.ssh_capture(cmd, allow_fail=True, callback=int)] + return pid_arr + except (RemoteCommandError, ValueError) as e: + return [] + + def signal_node(self, node, sig=signal.SIGTERM): + pids = self.pids(node) + for pid in pids: + node.account.signal(pid, sig) + + def signal_leader(self, topic, partition=0, sig=signal.SIGTERM): + leader = self.leader(topic, partition) + self.signal_node(leader, sig) + + def stop_node(self, node, clean_shutdown=True, timeout_sec=60): + pids = self.pids(node) + cluster_has_colocated_controllers = self.quorum_info.has_brokers and self.quorum_info.has_controllers + force_sigkill_due_to_too_few_colocated_controllers =\ + clean_shutdown and cluster_has_colocated_controllers\ + and self.colocated_nodes_started < round(self.num_nodes_controller_role / 2) + if force_sigkill_due_to_too_few_colocated_controllers: + self.logger.info("Forcing node to stop via SIGKILL due to too few co-located KRaft controllers: %i/%i" %\ + (self.colocated_nodes_started, self.num_nodes_controller_role)) + + sig = signal.SIGTERM if clean_shutdown and not force_sigkill_due_to_too_few_colocated_controllers else signal.SIGKILL + + for pid in pids: + node.account.signal(pid, sig, allow_fail=False) + + node_quorum_info = quorum.NodeQuorumInfo(self.quorum_info, node) + node_has_colocated_controllers = node_quorum_info.has_controller_role and node_quorum_info.has_broker_role + if pids and node_has_colocated_controllers: + self.colocated_nodes_started -= 1 + + try: + wait_until(lambda: len(self.pids(node)) == 0, timeout_sec=timeout_sec, + err_msg="Kafka node failed to stop in %d seconds" % timeout_sec) + except Exception: + if node_has_colocated_controllers: + # it didn't stop + self.colocated_nodes_started += 1 + self.thread_dump(node) + raise + + def thread_dump(self, node): + for pid in self.pids(node): + try: + node.account.signal(pid, signal.SIGQUIT, allow_fail=True) + except: + self.logger.warn("Could not dump threads on node") + + def clean_node(self, node): + JmxMixin.clean_node(self, node) + self.security_config.clean_node(node) + node.account.kill_java_processes(self.java_class_name(), + clean_shutdown=False, allow_fail=True) + node.account.ssh("sudo rm -rf -- %s" % KafkaService.PERSISTENT_ROOT, allow_fail=False) + + def kafka_topics_cmd_with_optional_security_settings(self, node, force_use_zk_connection, kafka_security_protocol=None, offline_nodes=[]): + if self.quorum_info.using_kraft and not self.quorum_info.has_brokers: + raise Exception("Must invoke kafka-topics against a broker, not a KRaft controller") + if force_use_zk_connection: + bootstrap_server_or_zookeeper = "--zookeeper %s" % (self.zk_connect_setting()) + skip_optional_security_settings = True + else: + if kafka_security_protocol is None: + # it wasn't specified, so use the inter-broker security protocol if it is PLAINTEXT, + # otherwise use the client security protocol + if self.interbroker_security_protocol == SecurityConfig.PLAINTEXT: + security_protocol_to_use = SecurityConfig.PLAINTEXT + else: + security_protocol_to_use = self.security_protocol + else: + security_protocol_to_use = 
kafka_security_protocol + bootstrap_server_or_zookeeper = "--bootstrap-server %s" % (self.bootstrap_servers(security_protocol_to_use, offline_nodes=offline_nodes)) + skip_optional_security_settings = security_protocol_to_use == SecurityConfig.PLAINTEXT + if skip_optional_security_settings: + optional_jass_krb_system_props_prefix = "" + optional_command_config_suffix = "" + else: + # we need security configs because aren't going to ZooKeeper and we aren't using PLAINTEXT + if (security_protocol_to_use == self.interbroker_security_protocol): + # configure JAAS to provide the broker's credentials + # since this is an authenticating cluster and we are going to use the inter-broker security protocol + jaas_conf_prop = KafkaService.ADMIN_CLIENT_AS_BROKER_JAAS_CONF_PROPERTY + use_inter_broker_mechanism_for_client = True + else: + # configure JAAS to provide the typical client credentials + jaas_conf_prop = KafkaService.JAAS_CONF_PROPERTY + use_inter_broker_mechanism_for_client = False + # We are either using SASL (SASL_SSL or SASL_PLAINTEXT) or we are using SSL + using_sasl = security_protocol_to_use != "SSL" + optional_jass_krb_system_props_prefix = "KAFKA_OPTS='-D%s -D%s' " % (jaas_conf_prop, KafkaService.KRB5_CONF) if using_sasl else "" + optional_command_config_suffix = " --command-config <(echo '%s')" % (self.security_config.client_config(use_inter_broker_mechanism_for_client = use_inter_broker_mechanism_for_client)) + kafka_topic_script = self.path.script("kafka-topics.sh", node) + return "%s%s %s%s" % \ + (optional_jass_krb_system_props_prefix, kafka_topic_script, + bootstrap_server_or_zookeeper, optional_command_config_suffix) + + def kafka_configs_cmd_with_optional_security_settings(self, node, force_use_zk_connection, kafka_security_protocol = None): + if self.quorum_info.using_kraft and not self.quorum_info.has_brokers: + raise Exception("Must invoke kafka-configs against a broker, not a KRaft controller") + if force_use_zk_connection: + # kafka-configs supports a TLS config file, so include it if there is one + bootstrap_server_or_zookeeper = "--zookeeper %s %s" % (self.zk_connect_setting(), self.zk.zkTlsConfigFileOption()) + skip_optional_security_settings = True + else: + if kafka_security_protocol is None: + # it wasn't specified, so use the inter-broker security protocol if it is PLAINTEXT, + # otherwise use the client security protocol + if self.interbroker_security_protocol == SecurityConfig.PLAINTEXT: + security_protocol_to_use = SecurityConfig.PLAINTEXT + else: + security_protocol_to_use = self.security_protocol + else: + security_protocol_to_use = kafka_security_protocol + bootstrap_server_or_zookeeper = "--bootstrap-server %s" % (self.bootstrap_servers(security_protocol_to_use)) + skip_optional_security_settings = security_protocol_to_use == SecurityConfig.PLAINTEXT + if skip_optional_security_settings: + optional_jass_krb_system_props_prefix = "" + optional_command_config_suffix = "" + else: + # we need security configs because aren't going to ZooKeeper and we aren't using PLAINTEXT + if (security_protocol_to_use == self.interbroker_security_protocol): + # configure JAAS to provide the broker's credentials + # since this is an authenticating cluster and we are going to use the inter-broker security protocol + jaas_conf_prop = KafkaService.ADMIN_CLIENT_AS_BROKER_JAAS_CONF_PROPERTY + use_inter_broker_mechanism_for_client = True + else: + # configure JAAS to provide the typical client credentials + jaas_conf_prop = KafkaService.JAAS_CONF_PROPERTY + 
use_inter_broker_mechanism_for_client = False + # We are either using SASL (SASL_SSL or SASL_PLAINTEXT) or we are using SSL + using_sasl = security_protocol_to_use != "SSL" + optional_jass_krb_system_props_prefix = "KAFKA_OPTS='-D%s -D%s' " % (jaas_conf_prop, KafkaService.KRB5_CONF) if using_sasl else "" + optional_command_config_suffix = " --command-config <(echo '%s')" % (self.security_config.client_config(use_inter_broker_mechanism_for_client = use_inter_broker_mechanism_for_client)) + kafka_config_script = self.path.script("kafka-configs.sh", node) + return "%s%s %s%s" % \ + (optional_jass_krb_system_props_prefix, kafka_config_script, + bootstrap_server_or_zookeeper, optional_command_config_suffix) + + def maybe_setup_broker_scram_credentials(self, node): + security_config = self.security_config + # we only need to create broker credentials when the broker mechanism is SASL/SCRAM + if security_config.is_sasl(self.interbroker_security_protocol) and security_config.is_sasl_scram(self.interbroker_sasl_mechanism): + force_use_zk_connection = True # we are bootstrapping these credentials before Kafka is started + cmd = fix_opts_for_new_jvm(node) + cmd += "%(kafka_configs_cmd)s --entity-name %(user)s --entity-type users --alter --add-config %(mechanism)s=[password=%(password)s]" % { + 'kafka_configs_cmd': self.kafka_configs_cmd_with_optional_security_settings(node, force_use_zk_connection), + 'user': SecurityConfig.SCRAM_BROKER_USER, + 'mechanism': self.interbroker_sasl_mechanism, + 'password': SecurityConfig.SCRAM_BROKER_PASSWORD + } + node.account.ssh(cmd) + + def maybe_setup_client_scram_credentials(self, node): + security_config = self.security_config + # we only need to create client credentials when the client mechanism is SASL/SCRAM + if security_config.is_sasl(self.security_protocol) and security_config.is_sasl_scram(self.client_sasl_mechanism): + force_use_zk_connection = not self.all_nodes_configs_command_uses_bootstrap_server_scram() + # ignored if forcing the use of Zookeeper, but we need a value to send, so calculate it anyway + if self.interbroker_security_protocol == SecurityConfig.PLAINTEXT: + kafka_security_protocol = self.interbroker_security_protocol + else: + kafka_security_protocol = self.security_protocol + cmd = fix_opts_for_new_jvm(node) + cmd += "%(kafka_configs_cmd)s --entity-name %(user)s --entity-type users --alter --add-config %(mechanism)s=[password=%(password)s]" % { + 'kafka_configs_cmd': self.kafka_configs_cmd_with_optional_security_settings(node, force_use_zk_connection, kafka_security_protocol), + 'user': SecurityConfig.SCRAM_CLIENT_USER, + 'mechanism': self.client_sasl_mechanism, + 'password': SecurityConfig.SCRAM_CLIENT_PASSWORD + } + node.account.ssh(cmd) + + def node_inter_broker_protocol_version(self, node): + if config_property.INTER_BROKER_PROTOCOL_VERSION in node.config: + return KafkaVersion(node.config[config_property.INTER_BROKER_PROTOCOL_VERSION]) + return node.version + + def all_nodes_topic_command_supports_bootstrap_server(self): + for node in self.nodes: + if not node.version.topic_command_supports_bootstrap_server(): + return False + return True + + def all_nodes_topic_command_supports_if_not_exists_with_bootstrap_server(self): + for node in self.nodes: + if not node.version.topic_command_supports_if_not_exists_with_bootstrap_server(): + return False + return True + + def all_nodes_configs_command_uses_bootstrap_server(self): + for node in self.nodes: + if not node.version.kafka_configs_command_uses_bootstrap_server(): + return False + 
return True + + def all_nodes_configs_command_uses_bootstrap_server_scram(self): + for node in self.nodes: + if not node.version.kafka_configs_command_uses_bootstrap_server_scram(): + return False + return True + + def all_nodes_acl_command_supports_bootstrap_server(self): + for node in self.nodes: + if not node.version.acl_command_supports_bootstrap_server(): + return False + return True + + def all_nodes_reassign_partitions_command_supports_bootstrap_server(self): + for node in self.nodes: + if not node.version.reassign_partitions_command_supports_bootstrap_server(): + return False + return True + + def all_nodes_support_topic_ids(self): + if self.quorum_info.using_kraft: return True + for node in self.nodes: + if not self.node_inter_broker_protocol_version(node).supports_topic_ids_when_using_zk(): + return False + return True + + def create_topic(self, topic_cfg, node=None): + """Run the admin tool create topic command. + Specifying node is optional, and may be done if for different kafka nodes have different versions, + and we care where command gets run. + + If the node is not specified, run the command from self.nodes[0] + """ + if node is None: + node = self.nodes[0] + self.logger.info("Creating topic %s with settings %s", + topic_cfg["topic"], topic_cfg) + + force_use_zk_connection = not self.all_nodes_topic_command_supports_bootstrap_server() or\ + (topic_cfg.get('if-not-exists', False) and not self.all_nodes_topic_command_supports_if_not_exists_with_bootstrap_server()) + + cmd = fix_opts_for_new_jvm(node) + cmd += "%(kafka_topics_cmd)s --create --topic %(topic)s " % { + 'kafka_topics_cmd': self.kafka_topics_cmd_with_optional_security_settings(node, force_use_zk_connection), + 'topic': topic_cfg.get("topic"), + } + if 'replica-assignment' in topic_cfg: + cmd += " --replica-assignment %(replica-assignment)s" % { + 'replica-assignment': topic_cfg.get('replica-assignment') + } + else: + cmd += " --partitions %(partitions)d --replication-factor %(replication-factor)d" % { + 'partitions': topic_cfg.get('partitions', 1), + 'replication-factor': topic_cfg.get('replication-factor', 1) + } + + if topic_cfg.get('if-not-exists', False): + cmd += ' --if-not-exists' + + if "configs" in topic_cfg.keys() and topic_cfg["configs"] is not None: + for config_name, config_value in topic_cfg["configs"].items(): + cmd += " --config %s=%s" % (config_name, str(config_value)) + + self.logger.info("Running topic creation command...\n%s" % cmd) + node.account.ssh(cmd) + + def delete_topic(self, topic, node=None): + """ + Delete a topic with the topics command + :param topic: + :param node: + :return: + """ + if node is None: + node = self.nodes[0] + self.logger.info("Deleting topic %s" % topic) + + force_use_zk_connection = not self.all_nodes_topic_command_supports_bootstrap_server() + + cmd = fix_opts_for_new_jvm(node) + cmd += "%s --topic %s --delete" % \ + (self.kafka_topics_cmd_with_optional_security_settings(node, force_use_zk_connection), topic) + self.logger.info("Running topic delete command...\n%s" % cmd) + node.account.ssh(cmd) + + def has_under_replicated_partitions(self): + """ + Check whether the cluster has under-replicated partitions. + + :return True if there are under-replicated partitions, False otherwise. + """ + return len(self.describe_under_replicated_partitions()) > 0 + + def await_no_under_replicated_partitions(self, timeout_sec=30): + """ + Wait for all under-replicated partitions to clear. 
+ + :param timeout_sec: the maximum time in seconds to wait + """ + wait_until(lambda: not self.has_under_replicated_partitions(), + timeout_sec = timeout_sec, + err_msg="Timed out waiting for under-replicated-partitions to clear") + + def describe_under_replicated_partitions(self): + """ + Use the topic tool to find the under-replicated partitions in the cluster. + + :return the under-replicated partitions as a list of dictionaries + (e.g. [{"topic": "foo", "partition": 1}, {"topic": "bar", "partition": 0}, ... ]) + """ + + node = self.nodes[0] + force_use_zk_connection = not node.version.topic_command_supports_bootstrap_server() + + cmd = fix_opts_for_new_jvm(node) + cmd += "%s --describe --under-replicated-partitions" % \ + self.kafka_topics_cmd_with_optional_security_settings(node, force_use_zk_connection) + + self.logger.debug("Running topic command to describe under-replicated partitions\n%s" % cmd) + output = "" + for line in node.account.ssh_capture(cmd): + output += line + + under_replicated_partitions = self.parse_describe_topic(output)["partitions"] + self.logger.debug("Found %d under-replicated-partitions" % len(under_replicated_partitions)) + + return under_replicated_partitions + + def describe_topic(self, topic, node=None, offline_nodes=[]): + if node is None: + node = self.nodes[0] + + force_use_zk_connection = not self.all_nodes_topic_command_supports_bootstrap_server() + + cmd = fix_opts_for_new_jvm(node) + cmd += "%s --topic %s --describe" % \ + (self.kafka_topics_cmd_with_optional_security_settings(node, force_use_zk_connection, offline_nodes=offline_nodes), topic) + + self.logger.info("Running topic describe command...\n%s" % cmd) + output = "" + for line in node.account.ssh_capture(cmd): + output += line + return output + + def list_topics(self, node=None): + if node is None: + node = self.nodes[0] + + force_use_zk_connection = not self.all_nodes_topic_command_supports_bootstrap_server() + + cmd = fix_opts_for_new_jvm(node) + cmd += "%s --list" % (self.kafka_topics_cmd_with_optional_security_settings(node, force_use_zk_connection)) + for line in node.account.ssh_capture(cmd): + if not line.startswith("SLF4J"): + yield line.rstrip() + + def alter_message_format(self, topic, msg_format_version, node=None): + if node is None: + node = self.nodes[0] + self.logger.info("Altering message format version for topic %s with format %s", topic, msg_format_version) + + force_use_zk_connection = not self.all_nodes_configs_command_uses_bootstrap_server() + + cmd = fix_opts_for_new_jvm(node) + cmd += "%s --entity-name %s --entity-type topics --alter --add-config message.format.version=%s" % \ + (self.kafka_configs_cmd_with_optional_security_settings(node, force_use_zk_connection), topic, msg_format_version) + self.logger.info("Running alter message format command...\n%s" % cmd) + node.account.ssh(cmd) + + def set_unclean_leader_election(self, topic, value=True, node=None): + if node is None: + node = self.nodes[0] + if value is True: + self.logger.info("Enabling unclean leader election for topic %s", topic) + else: + self.logger.info("Disabling unclean leader election for topic %s", topic) + + force_use_zk_connection = not self.all_nodes_configs_command_uses_bootstrap_server() + + cmd = fix_opts_for_new_jvm(node) + cmd += "%s --entity-name %s --entity-type topics --alter --add-config unclean.leader.election.enable=%s" % \ + (self.kafka_configs_cmd_with_optional_security_settings(node, force_use_zk_connection), topic, str(value).lower()) + self.logger.info("Running alter unclean 
leader command...\n%s" % cmd) + node.account.ssh(cmd) + + def kafka_acls_cmd_with_optional_security_settings(self, node, force_use_zk_connection, kafka_security_protocol = None, override_command_config = None): + if self.quorum_info.using_kraft and not self.quorum_info.has_brokers: + raise Exception("Must invoke kafka-acls against a broker, not a KRaft controller") + force_use_zk_connection = force_use_zk_connection or not self.all_nodes_acl_command_supports_bootstrap_server + if force_use_zk_connection: + bootstrap_server_or_authorizer_zk_props = "--authorizer-properties zookeeper.connect=%s" % (self.zk_connect_setting()) + skip_optional_security_settings = True + else: + if kafka_security_protocol is None: + # it wasn't specified, so use the inter-broker security protocol if it is PLAINTEXT, + # otherwise use the client security protocol + if self.interbroker_security_protocol == SecurityConfig.PLAINTEXT: + security_protocol_to_use = SecurityConfig.PLAINTEXT + else: + security_protocol_to_use = self.security_protocol + else: + security_protocol_to_use = kafka_security_protocol + bootstrap_server_or_authorizer_zk_props = "--bootstrap-server %s" % (self.bootstrap_servers(security_protocol_to_use)) + skip_optional_security_settings = security_protocol_to_use == SecurityConfig.PLAINTEXT + if skip_optional_security_settings: + optional_jass_krb_system_props_prefix = "" + optional_command_config_suffix = "" + else: + # we need security configs because aren't going to ZooKeeper and we aren't using PLAINTEXT + if (security_protocol_to_use == self.interbroker_security_protocol): + # configure JAAS to provide the broker's credentials + # since this is an authenticating cluster and we are going to use the inter-broker security protocol + jaas_conf_prop = KafkaService.ADMIN_CLIENT_AS_BROKER_JAAS_CONF_PROPERTY + use_inter_broker_mechanism_for_client = True + else: + # configure JAAS to provide the typical client credentials + jaas_conf_prop = KafkaService.JAAS_CONF_PROPERTY + use_inter_broker_mechanism_for_client = False + # We are either using SASL (SASL_SSL or SASL_PLAINTEXT) or we are using SSL + using_sasl = security_protocol_to_use != "SSL" + optional_jass_krb_system_props_prefix = "KAFKA_OPTS='-D%s -D%s' " % (jaas_conf_prop, KafkaService.KRB5_CONF) if using_sasl else "" + if override_command_config is None: + optional_command_config_suffix = " --command-config <(echo '%s')" % (self.security_config.client_config(use_inter_broker_mechanism_for_client = use_inter_broker_mechanism_for_client)) + else: + optional_command_config_suffix = " --command-config %s" % (override_command_config) + kafka_acls_script = self.path.script("kafka-acls.sh", node) + return "%s%s %s%s" % \ + (optional_jass_krb_system_props_prefix, kafka_acls_script, + bootstrap_server_or_authorizer_zk_props, optional_command_config_suffix) + + def run_cli_tool(self, node, cmd): + output = "" + self.logger.debug(cmd) + for line in node.account.ssh_capture(cmd): + if not line.startswith("SLF4J"): + output += line + self.logger.debug(output) + return output + + def parse_describe_topic(self, topic_description): + """Parse output of kafka-topics.sh --describe (or describe_topic() method above), which is a string of form + Topic: test_topic\tTopicId: \tPartitionCount: 2\tReplicationFactor: 2\tConfigs: + Topic: test_topic\tPartition: 0\tLeader: 3\tReplicas: 3,1\tIsr: 3,1 + Topic: test_topic\tPartition: 1\tLeader: 1\tReplicas: 1,2\tIsr: 1,2 + into a dictionary structure appropriate for use with reassign-partitions tool: + { + "partitions": 
[ + {"topic": "test_topic", "partition": 0, "replicas": [3, 1]}, + {"topic": "test_topic", "partition": 1, "replicas": [1, 2]} + ] + } + """ + lines = map(lambda x: x.strip(), topic_description.split("\n")) + partitions = [] + for line in lines: + m = re.match(".*Leader:.*", line) + if m is None: + continue + + fields = line.split("\t") + # ["Partition: 4", "Leader: 0"] -> ["4", "0"] + fields = list(map(lambda x: x.split(" ")[1], fields)) + partitions.append( + {"topic": fields[0], + "partition": int(fields[1]), + "replicas": list(map(int, fields[3].split(',')))}) + return {"partitions": partitions} + + + def _connect_setting_reassign_partitions(self, node): + if self.all_nodes_reassign_partitions_command_supports_bootstrap_server(): + return "--bootstrap-server %s " % self.bootstrap_servers(self.security_protocol) + else: + return "--zookeeper %s " % self.zk_connect_setting() + + def verify_reassign_partitions(self, reassignment, node=None): + """Run the reassign partitions admin tool in "verify" mode + """ + if node is None: + node = self.nodes[0] + + json_file = "/tmp/%s_reassign.json" % str(time.time()) + + # reassignment to json + json_str = json.dumps(reassignment) + json_str = json.dumps(json_str) + + # create command + cmd = fix_opts_for_new_jvm(node) + cmd += "echo %s > %s && " % (json_str, json_file) + cmd += "%s " % self.path.script("kafka-reassign-partitions.sh", node) + cmd += self._connect_setting_reassign_partitions(node) + cmd += "--reassignment-json-file %s " % json_file + cmd += "--verify " + cmd += "&& sleep 1 && rm -f %s" % json_file + + # send command + self.logger.info("Verifying partition reassignment...") + self.logger.debug(cmd) + output = "" + for line in node.account.ssh_capture(cmd): + output += line + + self.logger.debug(output) + + if re.match(".*Reassignment of partition.*failed.*", + output.replace('\n', '')) is not None: + return False + + if re.match(".*is still in progress.*", + output.replace('\n', '')) is not None: + return False + + return True + + def execute_reassign_partitions(self, reassignment, node=None, + throttle=None): + """Run the reassign partitions admin tool in "verify" mode + """ + if node is None: + node = self.nodes[0] + json_file = "/tmp/%s_reassign.json" % str(time.time()) + + # reassignment to json + json_str = json.dumps(reassignment) + json_str = json.dumps(json_str) + + # create command + cmd = fix_opts_for_new_jvm(node) + cmd += "echo %s > %s && " % (json_str, json_file) + cmd += "%s " % self.path.script( "kafka-reassign-partitions.sh", node) + cmd += self._connect_setting_reassign_partitions(node) + cmd += "--reassignment-json-file %s " % json_file + cmd += "--execute" + if throttle is not None: + cmd += " --throttle %d" % throttle + cmd += " && sleep 1 && rm -f %s" % json_file + + # send command + self.logger.info("Executing parition reassignment...") + self.logger.debug(cmd) + output = "" + for line in node.account.ssh_capture(cmd): + output += line + + self.logger.debug("Verify partition reassignment:") + self.logger.debug(output) + + def search_data_files(self, topic, messages): + """Check if a set of messages made it into the Kakfa data files. Note that + this method takes no account of replication. It simply looks for the + payload in all the partition files of the specified topic. 'messages' should be + an array of numbers. The list of missing messages is returned. 
+ """ + payload_match = "payload: " + "$|payload: ".join(str(x) for x in messages) + "$" + found = set([]) + self.logger.debug("number of unique missing messages we will search for: %d", + len(messages)) + for node in self.nodes: + # Grab all .log files in directories prefixed with this topic + files = node.account.ssh_capture("find %s* -regex '.*/%s-.*/[^/]*.log'" % (KafkaService.DATA_LOG_DIR_PREFIX, topic)) + + # Check each data file to see if it contains the messages we want + for log in files: + cmd = fix_opts_for_new_jvm(node) + cmd += "%s kafka.tools.DumpLogSegments --print-data-log --files %s | grep -E \"%s\"" % \ + (self.path.script("kafka-run-class.sh", node), log.strip(), payload_match) + + for line in node.account.ssh_capture(cmd, allow_fail=True): + for val in messages: + if line.strip().endswith("payload: "+str(val)): + self.logger.debug("Found %s in data-file [%s] in line: [%s]" % (val, log.strip(), line.strip())) + found.add(val) + + self.logger.debug("Number of unique messages found in the log: %d", + len(found)) + missing = list(set(messages) - found) + + if len(missing) > 0: + self.logger.warn("The following values were not found in the data files: " + str(missing)) + + return missing + + def restart_cluster(self, clean_shutdown=True, timeout_sec=60, after_each_broker_restart=None, *args): + # We do not restart the remote controller quorum if it exists. + # This is not widely used -- it typically appears in rolling upgrade tests -- + # so we will let tests explicitly decide if/when to restart any remote controller quorum. + for node in self.nodes: + self.restart_node(node, clean_shutdown=clean_shutdown, timeout_sec=timeout_sec) + if after_each_broker_restart is not None: + after_each_broker_restart(*args) + + def restart_node(self, node, clean_shutdown=True, timeout_sec=60): + """Restart the given node.""" + # ensure we wait for the broker to start by setting concurrent start to False for the invocation of start_node() + orig_concurrent_start = self.concurrent_start + self.concurrent_start = False + self.stop_node(node, clean_shutdown, timeout_sec) + self.start_node(node, timeout_sec) + self.concurrent_start = orig_concurrent_start + + def _describe_topic_line_for_partition(self, partition, describe_topic_output): + # Lines look like this: Topic: test_topic Partition: 0 Leader: 3 Replicas: 3,2 Isr: 3,2 + grep_for = "Partition: %i\t" % (partition) # be sure to include trailing tab, otherwise 1 might match 10 (for example) + found_lines = [line for line in describe_topic_output.splitlines() if grep_for in line] + return None if not found_lines else found_lines[0] + + def isr_idx_list(self, topic, partition=0, node=None, offline_nodes=[]): + """ Get in-sync replica list the given topic and partition. + """ + if node is None: + node = self.nodes[0] + if not self.all_nodes_topic_command_supports_bootstrap_server(): + self.logger.debug("Querying zookeeper to find in-sync replicas for topic %s and partition %d" % (topic, partition)) + zk_path = "/brokers/topics/%s/partitions/%d/state" % (topic, partition) + partition_state = self.zk.query(zk_path, chroot=self.zk_chroot) + + if partition_state is None: + raise Exception("Error finding partition state for topic %s and partition %d." 
% (topic, partition)) + + partition_state = json.loads(partition_state) + self.logger.info(partition_state) + + isr_idx_list = partition_state["isr"] + else: + self.logger.debug("Querying Kafka Admin API to find in-sync replicas for topic %s and partition %d" % (topic, partition)) + describe_output = self.describe_topic(topic, node, offline_nodes=offline_nodes) + self.logger.debug(describe_output) + requested_partition_line = self._describe_topic_line_for_partition(partition, describe_output) + # e.g. Topic: test_topic Partition: 0 Leader: 3 Replicas: 3,2 Isr: 3,2 + if not requested_partition_line: + raise Exception("Error finding partition state for topic %s and partition %d." % (topic, partition)) + isr_csv = requested_partition_line.split()[9] # 10th column from above + isr_idx_list = [int(i) for i in isr_csv.split(",")] + + self.logger.info("Isr for topic %s and partition %d is now: %s" % (topic, partition, isr_idx_list)) + return isr_idx_list + + def replicas(self, topic, partition=0): + """ Get the assigned replicas for the given topic and partition. + """ + node = self.nodes[0] + if not self.all_nodes_topic_command_supports_bootstrap_server(): + self.logger.debug("Querying zookeeper to find assigned replicas for topic %s and partition %d" % (topic, partition)) + zk_path = "/brokers/topics/%s" % (topic) + assignment = self.zk.query(zk_path, chroot=self.zk_chroot) + + if assignment is None: + raise Exception("Error finding partition state for topic %s and partition %d." % (topic, partition)) + + assignment = json.loads(assignment) + self.logger.info(assignment) + + replicas = assignment["partitions"][str(partition)] + else: + self.logger.debug("Querying Kafka Admin API to find replicas for topic %s and partition %d" % (topic, partition)) + describe_output = self.describe_topic(topic, node) + self.logger.debug(describe_output) + requested_partition_line = self._describe_topic_line_for_partition(partition, describe_output) + # e.g. Topic: test_topic Partition: 0 Leader: 3 Replicas: 3,2 Isr: 3,2 + if not requested_partition_line: + raise Exception("Error finding partition state for topic %s and partition %d." % (topic, partition)) + isr_csv = requested_partition_line.split()[7] # 8th column from above + replicas = [int(i) for i in isr_csv.split(",")] + + self.logger.info("Assigned replicas for topic %s and partition %d is now: %s" % (topic, partition, replicas)) + return [self.get_node(replica) for replica in replicas] + + def leader(self, topic, partition=0): + """ Get the leader replica for the given topic and partition. + """ + node = self.nodes[0] + if not self.all_nodes_topic_command_supports_bootstrap_server(): + self.logger.debug("Querying zookeeper to find leader replica for topic %s and partition %d" % (topic, partition)) + zk_path = "/brokers/topics/%s/partitions/%d/state" % (topic, partition) + partition_state = self.zk.query(zk_path, chroot=self.zk_chroot) + + if partition_state is None: + raise Exception("Error finding partition state for topic %s and partition %d." % (topic, partition)) + + partition_state = json.loads(partition_state) + self.logger.info(partition_state) + + leader_idx = int(partition_state["leader"]) + else: + self.logger.debug("Querying Kafka Admin API to find leader for topic %s and partition %d" % (topic, partition)) + describe_output = self.describe_topic(topic, node) + self.logger.debug(describe_output) + requested_partition_line = self._describe_topic_line_for_partition(partition, describe_output) + # e.g. 
Topic: test_topic Partition: 0 Leader: 3 Replicas: 3,2 Isr: 3,2 + if not requested_partition_line: + raise Exception("Error finding partition state for topic %s and partition %d." % (topic, partition)) + leader_idx = int(requested_partition_line.split()[5]) # 6th column from above + + self.logger.info("Leader for topic %s and partition %d is now: %d" % (topic, partition, leader_idx)) + return self.get_node(leader_idx) + + def cluster_id(self): + """ Get the current cluster id + """ + if self.quorum_info.using_kraft: + return config_property.CLUSTER_ID + + self.logger.debug("Querying ZooKeeper to retrieve cluster id") + cluster = self.zk.query("/cluster/id", chroot=self.zk_chroot) + + try: + return json.loads(cluster)['id'] if cluster else None + except: + self.logger.debug("Data in /cluster/id znode could not be parsed. Data = %s" % cluster) + raise + + def topic_id(self, topic): + if self.all_nodes_support_topic_ids(): + node = self.nodes[0] + + force_use_zk_connection = not self.all_nodes_topic_command_supports_bootstrap_server() + + cmd = fix_opts_for_new_jvm(node) + cmd += "%s --topic %s --describe" % \ + (self.kafka_topics_cmd_with_optional_security_settings(node, force_use_zk_connection), topic) + + self.logger.debug( + "Querying topic ID by using describe topic command ...\n%s" % cmd + ) + output = "" + for line in node.account.ssh_capture(cmd): + output += line + + lines = map(lambda x: x.strip(), output.split("\n")) + for line in lines: + m = re.match(".*TopicId:.*", line) + if m is None: + continue + + fields = line.split("\t") + # [Topic: test_topic, TopicId: , PartitionCount: 2, ReplicationFactor: 2, ...] + # -> [test_topic, , 2, 2, ...] + # -> + topic_id = list(map(lambda x: x.split(" ")[1], fields))[1] + self.logger.info("Topic ID assigned for topic %s is %s" % (topic, topic_id)) + + return topic_id + raise Exception("Error finding topic ID for topic %s." % topic) + else: + self.logger.info("No topic ID assigned for topic %s" % topic) + return None + + def check_protocol_errors(self, node): + """ Checks for common protocol exceptions due to invalid inter broker protocol handling. + While such errors can and should be checked in other ways, checking the logs is a worthwhile failsafe. + """ + for node in self.nodes: + exit_code = node.account.ssh("grep -e 'java.lang.IllegalArgumentException: Invalid version' -e SchemaException %s/*" + % KafkaService.OPERATIONAL_LOG_DEBUG_DIR, allow_fail=True) + if exit_code != 1: + return False + return True + + def list_consumer_groups(self, node=None, command_config=None): + """ Get list of consumer groups. + """ + if node is None: + node = self.nodes[0] + consumer_group_script = self.path.script("kafka-consumer-groups.sh", node) + + if command_config is None: + command_config = "" + else: + command_config = "--command-config " + command_config + + cmd = fix_opts_for_new_jvm(node) + cmd += "%s --bootstrap-server %s %s --list" % \ + (consumer_group_script, + self.bootstrap_servers(self.security_protocol), + command_config) + return self.run_cli_tool(node, cmd) + + def describe_consumer_group(self, group, node=None, command_config=None): + """ Describe a consumer group. 
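+ The captured output is the raw kafka-consumer-groups.sh --describe table, typically one row per topic partition with its current offset, log-end offset and lag.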
+ """ + if node is None: + node = self.nodes[0] + consumer_group_script = self.path.script("kafka-consumer-groups.sh", node) + + if command_config is None: + command_config = "" + else: + command_config = "--command-config " + command_config + + cmd = fix_opts_for_new_jvm(node) + cmd += "%s --bootstrap-server %s %s --group %s --describe" % \ + (consumer_group_script, + self.bootstrap_servers(self.security_protocol), + command_config, group) + + output = "" + self.logger.debug(cmd) + for line in node.account.ssh_capture(cmd): + if not (line.startswith("SLF4J") or line.startswith("TOPIC") or line.startswith("Could not fetch offset")): + output += line + self.logger.debug(output) + return output + + def zk_connect_setting(self): + if self.quorum_info.using_kraft and not self.zk: + raise Exception("No zookeeper connect string available with KRaft unless ZooKeeper is explicitly enabled") + return self.zk.connect_setting(self.zk_chroot, self.zk_client_secure) + + def __bootstrap_servers(self, port, validate=True, offline_nodes=[]): + if validate and not port.open: + raise ValueError("We are retrieving bootstrap servers for the port: %s which is not currently open. - " % + str(port.port_number)) + + return ','.join([node.account.hostname + ":" + str(port.port_number) + for node in self.nodes + if node not in offline_nodes]) + + def bootstrap_servers(self, protocol='PLAINTEXT', validate=True, offline_nodes=[]): + """Return comma-delimited list of brokers in this cluster formatted as HOSTNAME1:PORT1,HOSTNAME:PORT2,... + + This is the format expected by many config files. + """ + port_mapping = self.port_mappings[protocol] + self.logger.info("Bootstrap client port is: " + str(port_mapping.port_number)) + return self.__bootstrap_servers(port_mapping, validate, offline_nodes) + + def controller(self): + """ Get the controller node + """ + if self.quorum_info.using_kraft: + raise Exception("Cannot obtain Controller node when using KRaft instead of ZooKeeper") + self.logger.debug("Querying zookeeper to find controller broker") + controller_info = self.zk.query("/controller", chroot=self.zk_chroot) + + if controller_info is None: + raise Exception("Error finding controller info") + + controller_info = json.loads(controller_info) + self.logger.debug(controller_info) + + controller_idx = int(controller_info["brokerid"]) + self.logger.info("Controller's ID: %d" % (controller_idx)) + return self.get_node(controller_idx) + + def is_registered(self, node): + """ + Check whether a broker is registered in Zookeeper + """ + if self.quorum_info.using_kraft: + raise Exception("Cannot obtain broker registration information when using KRaft instead of ZooKeeper") + self.logger.debug("Querying zookeeper to see if broker %s is registered", str(node)) + broker_info = self.zk.query("/brokers/ids/%s" % self.idx(node), chroot=self.zk_chroot) + self.logger.debug("Broker info: %s", broker_info) + return broker_info is not None + + def get_offset_shell(self, time=None, topic=None, partitions=None, topic_partitions=None, exclude_internal_topics=False): + node = self.nodes[0] + + cmd = fix_opts_for_new_jvm(node) + cmd += self.path.script("kafka-run-class.sh", node) + cmd += " kafka.tools.GetOffsetShell" + cmd += " --bootstrap-server %s" % self.bootstrap_servers(self.security_protocol) + + if time: + cmd += ' --time %s' % time + if topic_partitions: + cmd += ' --topic-partitions %s' % topic_partitions + if topic: + cmd += ' --topic %s' % topic + if partitions: + cmd += ' --partitions %s' % partitions + if 
exclude_internal_topics: + cmd += ' --exclude-internal-topics' + + cmd += " 2>> %s/get_offset_shell.log" % KafkaService.PERSISTENT_ROOT + cmd += " | tee -a %s/get_offset_shell.log &" % KafkaService.PERSISTENT_ROOT + output = "" + self.logger.debug(cmd) + for line in node.account.ssh_capture(cmd): + output += line + self.logger.debug(output) + return output + + def java_class_name(self): + return "kafka.Kafka" diff --git a/tests/kafkatest/services/kafka/quorum.py b/tests/kafkatest/services/kafka/quorum.py new file mode 100644 index 0000000..d188eec --- /dev/null +++ b/tests/kafkatest/services/kafka/quorum.py @@ -0,0 +1,144 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# the types of metadata quorums we support +zk = 'ZK' # ZooKeeper, used before/during the KIP-500 bridge release(s) +colocated_kraft = 'COLOCATED_KRAFT' # co-located Controllers in KRaft mode, used during/after the KIP-500 bridge release(s) +remote_kraft = 'REMOTE_KRAFT' # separate Controllers in KRaft mode, used during/after the KIP-500 bridge release(s) + +# How we will parameterize tests that exercise all quorum styles +# [“ZK”, “REMOTE_KRAFT”, "COLOCATED_KRAFT"] during the KIP-500 bridge release(s) +# [“REMOTE_KRAFT”, "COLOCATED_KRAFT”] after the KIP-500 bridge release(s) +all = [zk, remote_kraft, colocated_kraft] +# How we will parameterize tests that exercise all KRaft quorum styles +all_kraft = [remote_kraft, colocated_kraft] +# How we will parameterize tests that are unrelated to upgrades: +# [“ZK”] before the KIP-500 bridge release(s) +# [“ZK”, “REMOTE_KRAFT”] during the KIP-500 bridge release(s) and in preview releases +# [“REMOTE_KRAFT”] after the KIP-500 bridge release(s) +all_non_upgrade = [zk, remote_kraft] + +def for_test(test_context): + # A test uses ZooKeeper if it doesn't specify a metadata quorum or if it explicitly specifies ZooKeeper + default_quorum_type = zk + arg_name = 'metadata_quorum' + retval = default_quorum_type if not test_context.injected_args else test_context.injected_args.get(arg_name, default_quorum_type) + if retval not in all: + raise Exception("Unknown %s value provided for the test: %s" % (arg_name, retval)) + return retval + +class ServiceQuorumInfo: + """ + Exposes quorum-related information for a KafkaService + + Kafka can use either ZooKeeper or a KRaft (Kafka Raft) Controller quorum for + its metadata. KRaft Controllers can either be co-located with Kafka in + the same JVM or remote in separate JVMs. The choice is made via + the 'metadata_quorum' parameter defined for the system test: if it + is not explicitly defined, or if it is set to 'ZK', then ZooKeeper + is used. If it is explicitly set to 'COLOCATED_KRAFT' then KRaft + controllers will be co-located with the brokers; the value + `REMOTE_KRAFT` indicates remote controllers. 
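+ For example, a system test parameterized via ducktape's @matrix (e.g. @matrix(metadata_quorum=quorum.all_non_upgrade)) has that injected argument resolved by for_test() above into one of the module-level constants zk, colocated_kraft or remote_kraft; omitting the argument defaults to ZooKeeper.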
+ + Attributes + ---------- + + kafka : KafkaService + The service for which this instance exposes quorum-related + information + quorum_type : str + COLOCATED_KRAFT, REMOTE_KRAFT, or ZK + using_zk : bool + True iff quorum_type==ZK + using_kraft : bool + False iff quorum_type==ZK + has_brokers : bool + Whether there is at least one node with process.roles + containing 'broker'. True iff using_kraft and the Kafka + service doesn't itself have a remote Kafka service (meaning + it is not a remote controller quorum). + has_controllers : bool + Whether there is at least one node with process.roles + containing 'controller'. True iff quorum_type == + COLOCATED_KRAFT or the Kafka service itself has a remote Kafka + service (meaning it is a remote controller quorum). + has_brokers_and_controllers : + True iff quorum_type==COLOCATED_KRAFT + """ + + def __init__(self, kafka, context): + """ + + :param kafka : KafkaService + The service for which this instance exposes quorum-related + information + :param context : TestContext + The test context within which the this instance and the + given Kafka service is being instantiated + """ + + quorum_type = for_test(context) + if quorum_type != zk and kafka.zk and not kafka.allow_zk_with_kraft: + raise Exception("Cannot use ZooKeeper while specifying a KRaft metadata quorum unless explicitly allowing it") + if kafka.remote_kafka and quorum_type != remote_kraft: + raise Exception("Cannot specify a remote Kafka service unless using a remote KRaft metadata quorum (should not happen)") + self.kafka = kafka + self.quorum_type = quorum_type + self.using_zk = quorum_type == zk + self.using_kraft = not self.using_zk + self.has_brokers = self.using_kraft and not kafka.remote_kafka + self.has_controllers = quorum_type == colocated_kraft or kafka.remote_kafka + self.has_brokers_and_controllers = quorum_type == colocated_kraft + +class NodeQuorumInfo: + """ + Exposes quorum-related information for a node in a KafkaService + + Attributes + ---------- + service_quorum_info : ServiceQuorumInfo + The quorum information about the service to which the node + belongs + has_broker_role : bool + True iff using_kraft and the Kafka service doesn't itself have + a remote Kafka service (meaning it is not a remote controller) + has_controller_role : bool + True iff quorum_type==COLOCATED_KRAFT and the node is one of + the first N in the cluster where N is the number of nodes + that have a controller role; or the Kafka service itself has a + remote Kafka service (meaning it is a remote controller + quorum). + has_combined_broker_and_controller_roles : + True iff has_broker_role==True and has_controller_role==true + """ + + def __init__(self, service_quorum_info, node): + """ + :param service_quorum_info : ServiceQuorumInfo + The quorum information about the service to which the node + belongs + :param node : Node + The particular node for which this information applies. + In the co-located case, whether or not a node's broker's + process.roles contains 'controller' may vary based on the + particular node if the number of controller nodes is less + than the number of nodes in the service. 
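+ For example, in a co-located cluster of five nodes where num_nodes_controller_role is 3, only the first three nodes report has_controller_role=True, while all five report has_broker_role=True.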
+ """ + + self.service_quorum_info = service_quorum_info + self.has_broker_role = self.service_quorum_info.has_brokers + idx = self.service_quorum_info.kafka.nodes.index(node) + self.has_controller_role = self.service_quorum_info.kafka.num_nodes_controller_role > idx + self.has_combined_broker_and_controller_roles = self.has_broker_role and self.has_controller_role diff --git a/tests/kafkatest/services/kafka/templates/kafka.properties b/tests/kafkatest/services/kafka/templates/kafka.properties new file mode 100644 index 0000000..65bf389 --- /dev/null +++ b/tests/kafkatest/services/kafka/templates/kafka.properties @@ -0,0 +1,129 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# see kafka.server.KafkaConfig for additional details and defaults +{% if quorum_info.using_kraft %} +# The role(s) of this server. Setting this puts us in KRaft metadata quorum mode +{% if node_quorum_info.has_combined_broker_and_controller_roles %} +process.roles=broker,controller +{% elif node_quorum_info.has_controller_role %} +process.roles=controller +{% else %} +process.roles=broker +{% endif %} +# The connect string for the controller quorum +controller.quorum.voters={{ controller_quorum_voters }} + +controller.listener.names={{ controller_listener_names }} + +{% endif %} + +listeners={{ listeners }} + +listener.security.protocol.map={{ listener_security_protocol_map }} + +{% if quorum_info.using_zk or quorum_info.has_brokers %} +advertised.host.name={{ node.account.hostname }} +advertised.listeners={{ advertised_listeners }} + +{% if node.version.supports_named_listeners() %} +inter.broker.listener.name={{ interbroker_listener.name }} +{% else %} +security.inter.broker.protocol={{ interbroker_listener.security_protocol }} +{% endif %} +{% endif %} + +{% for k, v in listener_security_config.client_listener_overrides.items() %} +{% if listener_security_config.requires_sasl_mechanism_prefix(k) %} +listener.name.{{ security_protocol.lower() }}.{{ security_config.client_sasl_mechanism.lower() }}.{{ k }}={{ v }} +{% else %} +listener.name.{{ security_protocol.lower() }}.{{ k }}={{ v }} +{% endif %} +{% endfor %} + +{% if quorum_info.using_zk or quorum_info.has_brokers %} +{% if interbroker_listener.name != security_protocol %} +{% for k, v in listener_security_config.interbroker_listener_overrides.items() %} +{% if listener_security_config.requires_sasl_mechanism_prefix(k) %} +listener.name.{{ interbroker_listener.name.lower() }}.{{ security_config.interbroker_sasl_mechanism.lower() }}.{{ k }}={{ v }} +{% else %} +listener.name.{{ interbroker_listener.name.lower() }}.{{ k }}={{ v }} +{% endif %} +{% endfor %} +{% endif %} +{% endif %} + +{% if security_config.tls_version is not none %} +ssl.enabled.protocols={{ security_config.tls_version }} +ssl.protocol={{ 
security_config.tls_version }} +{% endif %} +ssl.keystore.location=/mnt/security/test.keystore.jks +ssl.keystore.password=test-ks-passwd +ssl.key.password=test-ks-passwd +ssl.keystore.type=JKS +ssl.truststore.location=/mnt/security/test.truststore.jks +ssl.truststore.password=test-ts-passwd +ssl.truststore.type=JKS +ssl.endpoint.identification.algorithm=HTTPS + +{% if quorum_info.using_zk %} +# Zookeeper TLS settings +# +# Note that zookeeper.ssl.client.enable will be set to true or false elsewhere, as appropriate. +# If it is false then these ZK keystore/truststore settings will have no effect. If it is true then +# zookeeper.clientCnxnSocket will also be set elsewhere (to org.apache.zookeeper.ClientCnxnSocketNetty) +{% if not zk.zk_tls_encrypt_only %} +zookeeper.ssl.keystore.location=/mnt/security/test.keystore.jks +zookeeper.ssl.keystore.password=test-ks-passwd +{% endif %} +zookeeper.ssl.truststore.location=/mnt/security/test.truststore.jks +zookeeper.ssl.truststore.password=test-ts-passwd +{% endif %} +# +{% if quorum_info.using_zk or quorum_info.has_brokers %} +sasl.mechanism.inter.broker.protocol={{ security_config.interbroker_sasl_mechanism }} +{% endif %} +{% if quorum_info.using_kraft %} +{% if not quorum_info.has_brokers %} +sasl.mechanism.controller.protocol={{ intercontroller_sasl_mechanism }} +{% else %} +sasl.mechanism.controller.protocol={{ controller_quorum.controller_sasl_mechanism }} +{% endif %} +{% endif %} +sasl.enabled.mechanisms={{ ",".join(security_config.enabled_sasl_mechanisms) }} +sasl.kerberos.service.name=kafka +{% if authorizer_class_name is not none %} +ssl.client.auth=required +authorizer.class.name={{ authorizer_class_name }} +{% endif %} + +{% if quorum_info.using_zk %} +zookeeper.set.acl={{"true" if zk_set_acl else "false"}} + +zookeeper.connection.timeout.ms={{ zk_connect_timeout }} +zookeeper.session.timeout.ms={{ zk_session_timeout }} +{% endif %} + +{% if replica_lag is defined %} +replica.lag.time.max.ms={{replica_lag}} +{% endif %} + +{% if auto_create_topics_enable is defined and auto_create_topics_enable is not none %} +auto.create.topics.enable={{ auto_create_topics_enable }} +{% endif %} +offsets.topic.num.partitions={{ num_nodes }} +offsets.topic.replication.factor={{ 3 if num_nodes > 3 else num_nodes }} +# Set to a low, but non-zero value to exercise this path without making tests much slower +group.initial.rebalance.delay.ms=100 diff --git a/tests/kafkatest/services/kafka/templates/log4j.properties b/tests/kafkatest/services/kafka/templates/log4j.properties new file mode 100644 index 0000000..5963c39 --- /dev/null +++ b/tests/kafkatest/services/kafka/templates/log4j.properties @@ -0,0 +1,136 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
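+ # Layout note: the root logger below writes to the console, while the kafka.* loggers are additionally wired to per-component DailyRollingFileAppenders split between {{ log_dir }}/info and {{ log_dir }}/debug (server, state-change, request, log-cleaner, controller and authorizer logs), so both verbosity levels are collected for each run.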
+ +log4j.rootLogger={{ log_level|default("DEBUG") }}, stdout + +log4j.appender.stdout=org.apache.log4j.ConsoleAppender +log4j.appender.stdout.layout=org.apache.log4j.PatternLayout +log4j.appender.stdout.layout.ConversionPattern=[%d] %p %m (%c)%n + +# INFO level appenders +log4j.appender.kafkaInfoAppender=org.apache.log4j.DailyRollingFileAppender +log4j.appender.kafkaInfoAppender.DatePattern='.'yyyy-MM-dd-HH +log4j.appender.kafkaInfoAppender.File={{ log_dir }}/info/server.log +log4j.appender.kafkaInfoAppender.layout=org.apache.log4j.PatternLayout +log4j.appender.kafkaInfoAppender.layout.ConversionPattern=[%d] %p %m (%c)%n +log4j.appender.kafkaInfoAppender.Threshold=INFO + +log4j.appender.stateChangeInfoAppender=org.apache.log4j.DailyRollingFileAppender +log4j.appender.stateChangeInfoAppender.DatePattern='.'yyyy-MM-dd-HH +log4j.appender.stateChangeInfoAppender.File={{ log_dir }}/info/state-change.log +log4j.appender.stateChangeInfoAppender.layout=org.apache.log4j.PatternLayout +log4j.appender.stateChangeInfoAppender.layout.ConversionPattern=[%d] %p %m (%c)%n +log4j.appender.stateChangeInfoAppender.Threshold=INFO + +log4j.appender.requestInfoAppender=org.apache.log4j.DailyRollingFileAppender +log4j.appender.requestInfoAppender.DatePattern='.'yyyy-MM-dd-HH +log4j.appender.requestInfoAppender.File={{ log_dir }}/info/kafka-request.log +log4j.appender.requestInfoAppender.layout=org.apache.log4j.PatternLayout +log4j.appender.requestInfoAppender.layout.ConversionPattern=[%d] %p %m (%c)%n +log4j.appender.requestInfoAppender.Threshold=INFO + +log4j.appender.cleanerInfoAppender=org.apache.log4j.DailyRollingFileAppender +log4j.appender.cleanerInfoAppender.DatePattern='.'yyyy-MM-dd-HH +log4j.appender.cleanerInfoAppender.File={{ log_dir }}/info/log-cleaner.log +log4j.appender.cleanerInfoAppender.layout=org.apache.log4j.PatternLayout +log4j.appender.cleanerInfoAppender.layout.ConversionPattern=[%d] %p %m (%c)%n +log4j.appender.cleanerInfoAppender.Threshold=INFO + +log4j.appender.controllerInfoAppender=org.apache.log4j.DailyRollingFileAppender +log4j.appender.controllerInfoAppender.DatePattern='.'yyyy-MM-dd-HH +log4j.appender.controllerInfoAppender.File={{ log_dir }}/info/controller.log +log4j.appender.controllerInfoAppender.layout=org.apache.log4j.PatternLayout +log4j.appender.controllerInfoAppender.layout.ConversionPattern=[%d] %p %m (%c)%n +log4j.appender.controllerInfoAppender.Threshold=INFO + +log4j.appender.authorizerInfoAppender=org.apache.log4j.DailyRollingFileAppender +log4j.appender.authorizerInfoAppender.DatePattern='.'yyyy-MM-dd-HH +log4j.appender.authorizerInfoAppender.File={{ log_dir }}/info/kafka-authorizer.log +log4j.appender.authorizerInfoAppender.layout=org.apache.log4j.PatternLayout +log4j.appender.authorizerInfoAppender.layout.ConversionPattern=[%d] %p %m (%c)%n +log4j.appender.authorizerInfoAppender.Threshold=INFO + +# DEBUG level appenders +log4j.appender.kafkaDebugAppender=org.apache.log4j.DailyRollingFileAppender +log4j.appender.kafkaDebugAppender.DatePattern='.'yyyy-MM-dd-HH +log4j.appender.kafkaDebugAppender.File={{ log_dir }}/debug/server.log +log4j.appender.kafkaDebugAppender.layout=org.apache.log4j.PatternLayout +log4j.appender.kafkaDebugAppender.layout.ConversionPattern=[%d] %p %m (%c)%n +log4j.appender.kafkaDebugAppender.Threshold=DEBUG + +log4j.appender.stateChangeDebugAppender=org.apache.log4j.DailyRollingFileAppender +log4j.appender.stateChangeDebugAppender.DatePattern='.'yyyy-MM-dd-HH +log4j.appender.stateChangeDebugAppender.File={{ log_dir }}/debug/state-change.log 
+log4j.appender.stateChangeDebugAppender.layout=org.apache.log4j.PatternLayout +log4j.appender.stateChangeDebugAppender.layout.ConversionPattern=[%d] %p %m (%c)%n +log4j.appender.stateChangeDebugAppender.Threshold=DEBUG + +log4j.appender.requestDebugAppender=org.apache.log4j.DailyRollingFileAppender +log4j.appender.requestDebugAppender.DatePattern='.'yyyy-MM-dd-HH +log4j.appender.requestDebugAppender.File={{ log_dir }}/debug/kafka-request.log +log4j.appender.requestDebugAppender.layout=org.apache.log4j.PatternLayout +log4j.appender.requestDebugAppender.layout.ConversionPattern=[%d] %p %m (%c)%n +log4j.appender.requestDebugAppender.Threshold=DEBUG + +log4j.appender.cleanerDebugAppender=org.apache.log4j.DailyRollingFileAppender +log4j.appender.cleanerDebugAppender.DatePattern='.'yyyy-MM-dd-HH +log4j.appender.cleanerDebugAppender.File={{ log_dir }}/debug/log-cleaner.log +log4j.appender.cleanerDebugAppender.layout=org.apache.log4j.PatternLayout +log4j.appender.cleanerDebugAppender.layout.ConversionPattern=[%d] %p %m (%c)%n +log4j.appender.cleanerDebugAppender.Threshold=DEBUG + +log4j.appender.controllerDebugAppender=org.apache.log4j.DailyRollingFileAppender +log4j.appender.controllerDebugAppender.DatePattern='.'yyyy-MM-dd-HH +log4j.appender.controllerDebugAppender.File={{ log_dir }}/debug/controller.log +log4j.appender.controllerDebugAppender.layout=org.apache.log4j.PatternLayout +log4j.appender.controllerDebugAppender.layout.ConversionPattern=[%d] %p %m (%c)%n +log4j.appender.controllerDebugAppender.Threshold=DEBUG + +log4j.appender.authorizerDebugAppender=org.apache.log4j.DailyRollingFileAppender +log4j.appender.authorizerDebugAppender.DatePattern='.'yyyy-MM-dd-HH +log4j.appender.authorizerDebugAppender.File={{ log_dir }}/debug/kafka-authorizer.log +log4j.appender.authorizerDebugAppender.layout=org.apache.log4j.PatternLayout +log4j.appender.authorizerDebugAppender.layout.ConversionPattern=[%d] %p %m (%c)%n +log4j.appender.authorizerDebugAppender.Threshold=DEBUG + +# Turn on all our debugging info +log4j.logger.kafka.producer.async.DefaultEventHandler={{ log_level|default("DEBUG") }}, kafkaInfoAppender, kafkaDebugAppender +log4j.logger.kafka.client.ClientUtils={{ log_level|default("DEBUG") }}, kafkaInfoAppender, kafkaDebugAppender +log4j.logger.kafka.perf={{ log_level|default("DEBUG") }}, kafkaInfoAppender, kafkaDebugAppender +log4j.logger.kafka.perf.ProducerPerformance$ProducerThread={{ log_level|default("DEBUG") }}, kafkaInfoAppender, kafkaDebugAppender +log4j.logger.kafka={{ log_level|default("DEBUG") }}, kafkaInfoAppender, kafkaDebugAppender + +log4j.logger.kafka.network.RequestChannel$={{ log_level|default("DEBUG") }}, requestInfoAppender, requestDebugAppender +log4j.additivity.kafka.network.RequestChannel$=false + +log4j.logger.kafka.network.Processor={{ log_level|default("DEBUG") }}, requestInfoAppender, requestDebugAppender +log4j.logger.kafka.server.KafkaApis={{ log_level|default("DEBUG") }}, requestInfoAppender, requestDebugAppender +log4j.additivity.kafka.server.KafkaApis=false +log4j.logger.kafka.request.logger={{ log_level|default("DEBUG") }}, requestInfoAppender, requestDebugAppender +log4j.additivity.kafka.request.logger=false + +log4j.logger.kafka.controller={{ log_level|default("DEBUG") }}, controllerInfoAppender, controllerDebugAppender +log4j.additivity.kafka.controller=false + +log4j.logger.kafka.log.LogCleaner={{ log_level|default("DEBUG") }}, cleanerInfoAppender, cleanerDebugAppender +log4j.additivity.kafka.log.LogCleaner=false + +log4j.logger.state.change.logger={{ 
log_level|default("DEBUG") }}, stateChangeInfoAppender, stateChangeDebugAppender
+log4j.additivity.state.change.logger=false
+
+# Change this to DEBUG to get the actual audit log for the authorizer.
+log4j.logger.kafka.authorizer.logger={{ log_level|default("DEBUG") }}, authorizerInfoAppender, authorizerDebugAppender
+log4j.additivity.kafka.authorizer.logger=false
+
diff --git a/tests/kafkatest/services/kafka/util.py b/tests/kafkatest/services/kafka/util.py
new file mode 100644
index 0000000..8782ebe
--- /dev/null
+++ b/tests/kafkatest/services/kafka/util.py
@@ -0,0 +1,41 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from collections import namedtuple
+
+from kafkatest.utils.remote_account import java_version
+from kafkatest.version import LATEST_0_8_2, LATEST_0_9, LATEST_0_10_0, LATEST_0_10_1, LATEST_0_10_2, LATEST_0_11_0, LATEST_1_0
+
+TopicPartition = namedtuple('TopicPartition', ['topic', 'partition'])
+
+new_jdk_not_supported = frozenset([str(LATEST_0_8_2), str(LATEST_0_9), str(LATEST_0_10_0), str(LATEST_0_10_1), str(LATEST_0_10_2), str(LATEST_0_11_0), str(LATEST_1_0)])
+
+
+def fix_opts_for_new_jvm(node):
+    # Startup scripts for early versions of Kafka contain options that are not supported on
+    # newer JVMs, such as -XX:+PrintGCDateStamps or -XX:+UseParNewGC.
+    # When a system test runs one of those versions on a JVM that no longer supports these
+    # options, we export environment variables that replace them with supported settings.
+    java_ver = java_version(node)
+    if java_ver <= 9:
+        return ""
+
+    cmd = ""
+    if node.version == LATEST_0_8_2 or node.version == LATEST_0_9 or node.version == LATEST_0_10_0 or node.version == LATEST_0_10_1 or node.version == LATEST_0_10_2 or node.version == LATEST_0_11_0 or node.version == LATEST_1_0:
+        cmd += "export KAFKA_GC_LOG_OPTS=\"-Xlog:gc*:file=kafka-gc.log:time,tags:filecount=10,filesize=102400\"; "
+        cmd += "export KAFKA_JVM_PERFORMANCE_OPTS=\"-server -XX:+UseG1GC -XX:MaxGCPauseMillis=20 -XX:InitiatingHeapOccupancyPercent=35 -XX:+ExplicitGCInvokesConcurrent -XX:MaxInlineLevel=15 -Djava.awt.headless=true\"; "
+    return cmd
+
+
diff --git a/tests/kafkatest/services/kafka_log4j_appender.py b/tests/kafkatest/services/kafka_log4j_appender.py
new file mode 100644
index 0000000..1212a7d
--- /dev/null
+++ b/tests/kafkatest/services/kafka_log4j_appender.py
@@ -0,0 +1,88 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ducktape.services.background_thread import BackgroundThreadService + +from kafkatest.directory_layout.kafka_path import KafkaPathResolverMixin +from kafkatest.services.security.security_config import SecurityConfig +from kafkatest.services.kafka.util import fix_opts_for_new_jvm + + +class KafkaLog4jAppender(KafkaPathResolverMixin, BackgroundThreadService): + + logs = { + "producer_log": { + "path": "/mnt/kafka_log4j_appender.log", + "collect_default": False} + } + + def __init__(self, context, num_nodes, kafka, topic, max_messages=-1, security_protocol="PLAINTEXT", tls_version=None): + super(KafkaLog4jAppender, self).__init__(context, num_nodes) + + self.kafka = kafka + self.topic = topic + self.max_messages = max_messages + self.security_protocol = security_protocol + self.security_config = SecurityConfig(self.context, security_protocol, tls_version=tls_version) + self.stop_timeout_sec = 30 + + for node in self.nodes: + node.version = kafka.nodes[0].version + + def _worker(self, idx, node): + cmd = self.start_cmd(node) + self.logger.debug("VerifiableLog4jAppender %d command: %s" % (idx, cmd)) + self.security_config.setup_node(node) + node.account.ssh(cmd) + + def start_cmd(self, node): + cmd = fix_opts_for_new_jvm(node) + cmd += self.path.script("kafka-run-class.sh", node) + cmd += " " + cmd += self.java_class_name() + cmd += " --topic %s --broker-list %s" % (self.topic, self.kafka.bootstrap_servers(self.security_protocol)) + + if self.max_messages > 0: + cmd += " --max-messages %s" % str(self.max_messages) + if self.security_protocol != SecurityConfig.PLAINTEXT: + cmd += " --security-protocol %s" % str(self.security_protocol) + if self.security_protocol == SecurityConfig.SSL or self.security_protocol == SecurityConfig.SASL_SSL: + cmd += " --ssl-truststore-location %s" % str(SecurityConfig.TRUSTSTORE_PATH) + cmd += " --ssl-truststore-password %s" % str(SecurityConfig.ssl_stores.truststore_passwd) + if self.security_protocol == SecurityConfig.SASL_PLAINTEXT or \ + self.security_protocol == SecurityConfig.SASL_SSL or \ + self.security_protocol == SecurityConfig.SASL_MECHANISM_GSSAPI or \ + self.security_protocol == SecurityConfig.SASL_MECHANISM_PLAIN: + cmd += " --sasl-kerberos-service-name %s" % str('kafka') + cmd += " --client-jaas-conf-path %s" % str(SecurityConfig.JAAS_CONF_PATH) + cmd += " --kerb5-conf-path %s" % str(SecurityConfig.KRB5CONF_PATH) + + cmd += " 2>> /mnt/kafka_log4j_appender.log | tee -a /mnt/kafka_log4j_appender.log &" + return cmd + + def stop_node(self, node): + node.account.kill_java_processes(self.java_class_name(), allow_fail=False) + + stopped = self.wait_node(node, timeout_sec=self.stop_timeout_sec) + assert stopped, "Node %s: did not stop within the specified timeout of %s seconds" % \ + (str(node.account), str(self.stop_timeout_sec)) + + def clean_node(self, node): + node.account.kill_java_processes(self.java_class_name(), clean_shutdown=False, + allow_fail=False) + node.account.ssh("rm -rf /mnt/kafka_log4j_appender.log", allow_fail=False) + + def java_class_name(self): + return "org.apache.kafka.tools.VerifiableLog4jAppender" diff --git 
a/tests/kafkatest/services/log_compaction_tester.py b/tests/kafkatest/services/log_compaction_tester.py new file mode 100644 index 0000000..cc6bf4f --- /dev/null +++ b/tests/kafkatest/services/log_compaction_tester.py @@ -0,0 +1,89 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +from ducktape.services.background_thread import BackgroundThreadService + +from kafkatest.directory_layout.kafka_path import KafkaPathResolverMixin, CORE_LIBS_JAR_NAME, CORE_DEPENDANT_TEST_LIBS_JAR_NAME +from kafkatest.services.security.security_config import SecurityConfig +from kafkatest.version import DEV_BRANCH + +class LogCompactionTester(KafkaPathResolverMixin, BackgroundThreadService): + + OUTPUT_DIR = "/mnt/logcompaction_tester" + LOG_PATH = os.path.join(OUTPUT_DIR, "logcompaction_tester_stdout.log") + VERIFICATION_STRING = "Data verification is completed" + + logs = { + "tool_logs": { + "path": LOG_PATH, + "collect_default": True} + } + + def __init__(self, context, kafka, security_protocol="PLAINTEXT", stop_timeout_sec=30, tls_version=None): + super(LogCompactionTester, self).__init__(context, 1) + + self.kafka = kafka + self.security_protocol = security_protocol + self.tls_version = tls_version + self.security_config = SecurityConfig(self.context, security_protocol, tls_version=tls_version) + self.stop_timeout_sec = stop_timeout_sec + self.log_compaction_completed = False + + def _worker(self, idx, node): + node.account.ssh("mkdir -p %s" % LogCompactionTester.OUTPUT_DIR) + cmd = self.start_cmd(node) + self.logger.info("LogCompactionTester %d command: %s" % (idx, cmd)) + self.security_config.setup_node(node) + for line in node.account.ssh_capture(cmd): + self.logger.debug("Checking line:{}".format(line)) + + if line.startswith(LogCompactionTester.VERIFICATION_STRING): + self.log_compaction_completed = True + + def start_cmd(self, node): + core_libs_jar = self.path.jar(CORE_LIBS_JAR_NAME, DEV_BRANCH) + core_dependant_test_libs_jar = self.path.jar(CORE_DEPENDANT_TEST_LIBS_JAR_NAME, DEV_BRANCH) + + cmd = "for file in %s; do CLASSPATH=$CLASSPATH:$file; done;" % core_libs_jar + cmd += " for file in %s; do CLASSPATH=$CLASSPATH:$file; done;" % core_dependant_test_libs_jar + cmd += " export CLASSPATH;" + cmd += self.path.script("kafka-run-class.sh", node) + cmd += " %s" % self.java_class_name() + cmd += " --bootstrap-server %s --messages 1000000 --sleep 20 --duplicates 10 --percent-deletes 10" % (self.kafka.bootstrap_servers(self.security_protocol)) + + cmd += " 2>> %s | tee -a %s &" % (self.logs["tool_logs"]["path"], self.logs["tool_logs"]["path"]) + return cmd + + def stop_node(self, node): + node.account.kill_java_processes(self.java_class_name(), clean_shutdown=True, + allow_fail=True) + + stopped = self.wait_node(node, timeout_sec=self.stop_timeout_sec) + assert 
stopped, "Node %s: did not stop within the specified timeout of %s seconds" % \ + (str(node.account), str(self.stop_timeout_sec)) + + def clean_node(self, node): + node.account.kill_java_processes(self.java_class_name(), clean_shutdown=False, + allow_fail=True) + node.account.ssh("rm -rf %s" % LogCompactionTester.OUTPUT_DIR, allow_fail=False) + + def java_class_name(self): + return "kafka.tools.LogCompactionTester" + + @property + def is_done(self): + return self.log_compaction_completed diff --git a/tests/kafkatest/services/mirror_maker.py b/tests/kafkatest/services/mirror_maker.py new file mode 100644 index 0000000..340aa16 --- /dev/null +++ b/tests/kafkatest/services/mirror_maker.py @@ -0,0 +1,164 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +from ducktape.services.service import Service +from ducktape.utils.util import wait_until + +from kafkatest.directory_layout.kafka_path import KafkaPathResolverMixin + +""" +MirrorMaker is a tool for mirroring data between two Kafka clusters. +""" + +class MirrorMaker(KafkaPathResolverMixin, Service): + + # Root directory for persistent output + PERSISTENT_ROOT = "/mnt/mirror_maker" + LOG_DIR = os.path.join(PERSISTENT_ROOT, "logs") + LOG_FILE = os.path.join(LOG_DIR, "mirror_maker.log") + LOG4J_CONFIG = os.path.join(PERSISTENT_ROOT, "tools-log4j.properties") + PRODUCER_CONFIG = os.path.join(PERSISTENT_ROOT, "producer.properties") + CONSUMER_CONFIG = os.path.join(PERSISTENT_ROOT, "consumer.properties") + + logs = { + "mirror_maker_log": { + "path": LOG_FILE, + "collect_default": True} + } + + def __init__(self, context, num_nodes, source, target, whitelist=None, num_streams=1, + consumer_timeout_ms=None, offsets_storage="kafka", + offset_commit_interval_ms=60000, log_level="DEBUG", producer_interceptor_classes=None): + """ + MirrorMaker mirrors messages from one or more source clusters to a single destination cluster. 
+ + Args: + context: standard context + source: source Kafka cluster + target: target Kafka cluster to which data will be mirrored + whitelist: whitelist regex for topics to mirror + blacklist: blacklist regex for topics not to mirror + num_streams: number of consumer threads to create; can be a single int, or a list with + one value per node, allowing num_streams to be the same for each node, + or configured independently per-node + consumer_timeout_ms: consumer stops if t > consumer_timeout_ms elapses between consecutive messages + offsets_storage: used for consumer offsets.storage property + offset_commit_interval_ms: how frequently the mirror maker consumer commits offsets + """ + super(MirrorMaker, self).__init__(context, num_nodes=num_nodes) + self.log_level = log_level + self.consumer_timeout_ms = consumer_timeout_ms + self.num_streams = num_streams + if not isinstance(num_streams, int): + # if not an integer, num_streams should be configured per-node + assert len(num_streams) == num_nodes + self.whitelist = whitelist + self.source = source + self.target = target + + self.offsets_storage = offsets_storage.lower() + if not (self.offsets_storage in ["kafka", "zookeeper"]): + raise Exception("offsets_storage should be 'kafka' or 'zookeeper'. Instead found %s" % self.offsets_storage) + + self.offset_commit_interval_ms = offset_commit_interval_ms + self.producer_interceptor_classes = producer_interceptor_classes + self.external_jars = None + + # These properties are potentially used by third-party tests. + self.source_auto_offset_reset = None + self.partition_assignment_strategy = None + + def start_cmd(self, node): + cmd = "export LOG_DIR=%s;" % MirrorMaker.LOG_DIR + cmd += " export KAFKA_LOG4J_OPTS=\"-Dlog4j.configuration=file:%s\";" % MirrorMaker.LOG4J_CONFIG + cmd += " export KAFKA_OPTS=%s;" % self.security_config.kafka_opts + # add external dependencies, for instance for interceptors + if self.external_jars is not None: + cmd += "for file in %s; do CLASSPATH=$CLASSPATH:$file; done; " % self.external_jars + cmd += "export CLASSPATH; " + cmd += " %s %s" % (self.path.script("kafka-run-class.sh", node), + self.java_class_name()) + cmd += " --consumer.config %s" % MirrorMaker.CONSUMER_CONFIG + cmd += " --producer.config %s" % MirrorMaker.PRODUCER_CONFIG + cmd += " --offset.commit.interval.ms %s" % str(self.offset_commit_interval_ms) + if isinstance(self.num_streams, int): + cmd += " --num.streams %d" % self.num_streams + else: + # config num_streams separately on each node + cmd += " --num.streams %d" % self.num_streams[self.idx(node) - 1] + if self.whitelist is not None: + cmd += " --whitelist=\"%s\"" % self.whitelist + + cmd += " 1>> %s 2>> %s &" % (MirrorMaker.LOG_FILE, MirrorMaker.LOG_FILE) + return cmd + + def pids(self, node): + return node.account.java_pids(self.java_class_name()) + + def alive(self, node): + return len(self.pids(node)) > 0 + + def start_node(self, node): + node.account.ssh("mkdir -p %s" % MirrorMaker.PERSISTENT_ROOT, allow_fail=False) + node.account.ssh("mkdir -p %s" % MirrorMaker.LOG_DIR, allow_fail=False) + + self.security_config = self.source.security_config.client_config() + self.security_config.setup_node(node) + + # Create, upload one consumer config file for source cluster + consumer_props = self.render("mirror_maker_consumer.properties") + consumer_props += str(self.security_config) + + node.account.create_file(MirrorMaker.CONSUMER_CONFIG, consumer_props) + self.logger.info("Mirrormaker consumer props:\n" + consumer_props) + + # Create, upload producer 
properties file for target cluster
+        producer_props = self.render('mirror_maker_producer.properties')
+        producer_props += str(self.security_config)
+        self.logger.info("Mirrormaker producer props:\n" + producer_props)
+        node.account.create_file(MirrorMaker.PRODUCER_CONFIG, producer_props)
+
+
+        # Create and upload log properties
+        log_config = self.render('tools_log4j.properties', log_file=MirrorMaker.LOG_FILE)
+        node.account.create_file(MirrorMaker.LOG4J_CONFIG, log_config)
+
+        # Run mirror maker
+        cmd = self.start_cmd(node)
+        self.logger.debug("Mirror maker command: %s", cmd)
+        node.account.ssh(cmd, allow_fail=False)
+        wait_until(lambda: self.alive(node), timeout_sec=30, backoff_sec=.5,
+                   err_msg="Mirror maker took too long to start.")
+        self.logger.debug("Mirror maker is alive")
+
+    def stop_node(self, node, clean_shutdown=True):
+        node.account.kill_java_processes(self.java_class_name(), allow_fail=True,
+                                         clean_shutdown=clean_shutdown)
+        wait_until(lambda: not self.alive(node), timeout_sec=30, backoff_sec=.5,
+                   err_msg="Mirror maker took too long to stop.")
+
+    def clean_node(self, node):
+        if self.alive(node):
+            self.logger.warn("%s %s was still alive at cleanup time. Killing forcefully..." %
+                             (self.__class__.__name__, node.account))
+            node.account.kill_java_processes(self.java_class_name(), clean_shutdown=False,
+                                             allow_fail=True)
+        node.account.ssh("rm -rf %s" % MirrorMaker.PERSISTENT_ROOT, allow_fail=False)
+        self.security_config.clean_node(node)
+
+    def java_class_name(self):
+        return "kafka.tools.MirrorMaker"
diff --git a/tests/kafkatest/services/monitor/__init__.py b/tests/kafkatest/services/monitor/__init__.py
new file mode 100644
index 0000000..ec20143
--- /dev/null
+++ b/tests/kafkatest/services/monitor/__init__.py
@@ -0,0 +1,14 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/tests/kafkatest/services/monitor/http.py b/tests/kafkatest/services/monitor/http.py
new file mode 100644
index 0000000..0293fbd
--- /dev/null
+++ b/tests/kafkatest/services/monitor/http.py
@@ -0,0 +1,228 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and +# limitations under the License. + +from http.server import BaseHTTPRequestHandler, HTTPServer +from collections import defaultdict, namedtuple +import json +from threading import Thread +from select import select +import socket + +MetricKey = namedtuple('MetricKey', ['host', 'client_id', 'name', 'group', 'tags']) +MetricValue = namedtuple('MetricValue', ['time', 'value']) + +# Python's logging library doesn't define anything more detailed than DEBUG, but we'd like a finer-grained setting for +# for highly detailed messages, e.g. logging every single incoming request. +TRACE = 5 + + +class HttpMetricsCollector(object): + """ + HttpMetricsCollector enables collection of metrics from various Kafka clients instrumented with the + PushHttpMetricsReporter. It starts a web server locally and provides the necessary configuration for clients + to automatically report metrics data to this server. It also provides basic functionality for querying the + recorded metrics. This class can be used either as a mixin or standalone object. + """ + + # The port to listen on on the worker node, which will be forwarded to the port listening on this driver node + REMOTE_PORT = 6789 + + def __init__(self, **kwargs): + """ + Create a new HttpMetricsCollector + :param period the period, in seconds, between updates that the metrics reporter configuration should define. + defaults to reporting once per second + :param args: + :param kwargs: + """ + self._http_metrics_period = kwargs.pop('period', 1) + + super(HttpMetricsCollector, self).__init__(**kwargs) + + # TODO: currently we maintain just a simple map from all key info -> value. However, some key fields are far + # more common to filter on, so we'd want to index by them, e.g. host, client.id, metric name. + self._http_metrics = defaultdict(list) + + self._httpd = HTTPServer(('', 0), _MetricsReceiver) + self._httpd.parent = self + self._httpd.metrics = self._http_metrics + + self._http_metrics_thread = Thread(target=self._run_http_metrics_httpd, + name='http-metrics-thread[%s]' % str(self)) + self._http_metrics_thread.start() + + self._forwarders = {} + + @property + def http_metrics_url(self): + """ + :return: the URL to use when reporting metrics + """ + return "http://%s:%d" % ("localhost", self.REMOTE_PORT) + + @property + def http_metrics_client_configs(self): + """ + Get client configurations that can be used to report data to this collector. Put these in a properties file for + clients (e.g. console producer or consumer) to have them push metrics to this driver. Note that in some cases + (e.g. streams, connect) these settings may need to be prefixed. 
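+
+        For example, with the defaults in this class (REMOTE_PORT = 6789 and a one-second
+        reporting period), the returned dictionary is:
+
+            {
+                "metric.reporters": "org.apache.kafka.tools.PushHttpMetricsReporter",
+                "metrics.url": "http://localhost:6789",
+                "metrics.period": 1,
+            }
+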
+ :return: a dictionary of client configurations that will direct a client to report metrics to this collector + """ + return { + "metric.reporters": "org.apache.kafka.tools.PushHttpMetricsReporter", + "metrics.url": self.http_metrics_url, + "metrics.period": self._http_metrics_period, + } + + def start_node(self, node): + local_port = self._httpd.socket.getsockname()[1] + self.logger.debug('HttpMetricsCollector listening on %s', local_port) + self._forwarders[self.idx(node)] = _ReverseForwarder(self.logger, node, self.REMOTE_PORT, local_port) + + super(HttpMetricsCollector, self).start_node(node) + + def stop(self): + super(HttpMetricsCollector, self).stop() + + if self._http_metrics_thread: + self.logger.debug("Shutting down metrics httpd") + self._httpd.shutdown() + self._http_metrics_thread.join() + self.logger.debug("Finished shutting down metrics httpd") + + def stop_node(self, node): + super(HttpMetricsCollector, self).stop_node(node) + + idx = self.idx(node) + self._forwarders[idx].stop() + del self._forwarders[idx] + + def metrics(self, host=None, client_id=None, name=None, group=None, tags=None): + """ + Get any collected metrics that match the specified parameters, yielding each as a tuple of + (key, [, ...]) values. + """ + for k, values in self._http_metrics.items(): + if ((host is None or host == k.host) and + (client_id is None or client_id == k.client_id) and + (name is None or name == k.name) and + (group is None or group == k.group) and + (tags is None or tags == k.tags)): + yield (k, values) + + def _run_http_metrics_httpd(self): + self._httpd.serve_forever() + + +class _MetricsReceiver(BaseHTTPRequestHandler): + """ + HTTP request handler that accepts requests from the PushHttpMetricsReporter and stores them back into the parent + HttpMetricsCollector + """ + + def log_message(self, format, *args, **kwargs): + # Don't do any logging here so we get rid of the mostly useless per-request Apache log-style info that spams + # the debug log + pass + + def do_POST(self): + data = self.rfile.read(int(self.headers['Content-Length'])) + data = json.loads(data) + self.server.parent.logger.log(TRACE, "POST %s\n\n%s\n%s", self.path, self.headers, + json.dumps(data, indent=4, separators=(',', ': '))) + self.send_response(204) + self.end_headers() + + client = data['client'] + host = client['host'] + client_id = client['client_id'] + ts = client['time'] + metrics = data['metrics'] + for raw_metric in metrics: + name = raw_metric['name'] + group = raw_metric['group'] + # Convert to tuple of pairs because dicts & lists are unhashable + tags = tuple((k, v) for k, v in raw_metric['tags'].items()), + value = raw_metric['value'] + + key = MetricKey(host=host, client_id=client_id, name=name, group=group, tags=tags) + metric_value = MetricValue(time=ts, value=value) + + self.server.metrics[key].append(metric_value) + + +class _ReverseForwarder(object): + """ + Runs reverse forwarding of a port on a node to a local port. This allows you to setup a server on the test driver + that only assumes we have basic SSH access that ducktape guarantees is available for worker nodes. 
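+
+    In rough terms the data path is: a client on the worker node connects to localhost:<remote_port>,
+    paramiko's reverse port forward carries that connection back over the existing SSH transport to
+    the driver, and _handler() relays the bytes to the local server listening on <local_port>.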
+ """ + + def __init__(self, logger, node, remote_port, local_port): + self.logger = logger + self._node = node + self._local_port = local_port + self._remote_port = remote_port + + self.logger.debug('Forwarding %s port %d to driver port %d', node, remote_port, local_port) + + self._stopping = False + + self._transport = node.account.ssh_client.get_transport() + self._transport.request_port_forward('', remote_port) + + self._accept_thread = Thread(target=self._accept) + self._accept_thread.start() + + def stop(self): + self._stopping = True + self._accept_thread.join(30) + if self._accept_thread.isAlive(): + raise RuntimeError("Failed to stop reverse forwarder on %s", self._node) + self._transport.cancel_port_forward('', self._remote_port) + + def _accept(self): + while not self._stopping: + chan = self._transport.accept(1) + if chan is None: + continue + thr = Thread(target=self._handler, args=(chan,)) + thr.setDaemon(True) + thr.start() + + def _handler(self, chan): + sock = socket.socket() + try: + sock.connect(("localhost", self._local_port)) + except Exception as e: + self.logger.error('Forwarding request to port %d failed: %r', self._local_port, e) + return + + self.logger.log(TRACE, 'Connected! Tunnel open %r -> %r -> %d', chan.origin_addr, chan.getpeername(), + self._local_port) + while True: + r, w, x = select([sock, chan], [], []) + if sock in r: + data = sock.recv(1024) + if len(data) == 0: + break + chan.send(data) + if chan in r: + data = chan.recv(1024) + if len(data) == 0: + break + sock.send(data) + chan.close() + sock.close() + self.logger.log(TRACE, 'Tunnel closed from %r', chan.origin_addr) diff --git a/tests/kafkatest/services/monitor/jmx.py b/tests/kafkatest/services/monitor/jmx.py new file mode 100644 index 0000000..bff1878 --- /dev/null +++ b/tests/kafkatest/services/monitor/jmx.py @@ -0,0 +1,156 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +from ducktape.cluster.remoteaccount import RemoteCommandError +from ducktape.utils.util import wait_until + +from kafkatest.directory_layout.kafka_path import KafkaPathResolverMixin +from kafkatest.version import get_version, V_0_11_0_0, DEV_BRANCH + +class JmxMixin(object): + """This mixin helps existing service subclasses start JmxTool on their worker nodes and collect jmx stats. + + A couple things worth noting: + - this is not a service in its own right. 
+ - we assume the service using JmxMixin also uses KafkaPathResolverMixin + - this uses the --wait option for JmxTool, so the list of object names must be explicit; no patterns are permitted + """ + def __init__(self, num_nodes, jmx_object_names=None, jmx_attributes=None, jmx_poll_ms=1000, root="/mnt"): + self.jmx_object_names = jmx_object_names + self.jmx_attributes = jmx_attributes or [] + self.jmx_poll_ms = jmx_poll_ms + self.jmx_port = 9192 + + self.started = [False] * num_nodes + self.jmx_stats = [{} for x in range(num_nodes)] + self.maximum_jmx_value = {} # map from object_attribute_name to maximum value observed over time + self.average_jmx_value = {} # map from object_attribute_name to average value observed over time + + self.jmx_tool_log = os.path.join(root, "jmx_tool.log") + self.jmx_tool_err_log = os.path.join(root, "jmx_tool.err.log") + + def clean_node(self, node, idx=None): + node.account.kill_java_processes(self.jmx_class_name(), clean_shutdown=False, + allow_fail=True) + if idx is None: + idx = self.idx(node) + self.started[idx-1] = False + node.account.ssh("rm -f -- %s %s" % (self.jmx_tool_log, self.jmx_tool_err_log), allow_fail=False) + + def start_jmx_tool(self, idx, node): + if self.jmx_object_names is None: + self.logger.debug("%s: Not starting jmx tool because no jmx objects are defined" % node.account) + return + + if self.started[idx-1]: + self.logger.debug("%s: jmx tool has been started already on this node" % node.account) + return + + # JmxTool is not particularly robust to slow-starting processes. In order to ensure JmxTool doesn't fail if the + # process we're trying to monitor takes awhile before listening on the JMX port, wait until we can see that port + # listening before even launching JmxTool + def check_jmx_port_listening(): + return 0 == node.account.ssh("nc -z 127.0.0.1 %d" % self.jmx_port, allow_fail=True) + + wait_until(check_jmx_port_listening, timeout_sec=30, backoff_sec=.1, + err_msg="%s: Never saw JMX port for %s start listening" % (node.account, self)) + + # To correctly wait for requested JMX metrics to be added we need the --wait option for JmxTool. This option was + # not added until 0.11.0.1, so any earlier versions need to use JmxTool from a newer version. 
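+        # For illustration, with the defaults above (jmx_poll_ms=1000, jmx_port=9192) and a single
+        # object name and attribute, the command assembled below comes out roughly as (wrapped here
+        # for readability; it is built as one line):
+        #   .../kafka-run-class.sh kafka.tools.JmxTool --reporting-interval 1000
+        #       --jmx-url service:jmx:rmi:///jndi/rmi://127.0.0.1:9192/jmxrmi --wait
+        #       --object-name <mbean.object.name> --attributes <Attribute>,
+        #       1>> /mnt/jmx_tool.log 2>> /mnt/jmx_tool.err.log &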
+ use_jmxtool_version = get_version(node) + if use_jmxtool_version <= V_0_11_0_0: + use_jmxtool_version = DEV_BRANCH + cmd = "%s %s " % (self.path.script("kafka-run-class.sh", use_jmxtool_version), self.jmx_class_name()) + cmd += "--reporting-interval %d --jmx-url service:jmx:rmi:///jndi/rmi://127.0.0.1:%d/jmxrmi" % (self.jmx_poll_ms, self.jmx_port) + cmd += " --wait" + for jmx_object_name in self.jmx_object_names: + cmd += " --object-name %s" % jmx_object_name + cmd += " --attributes " + for jmx_attribute in self.jmx_attributes: + cmd += "%s," % jmx_attribute + cmd += " 1>> %s" % self.jmx_tool_log + cmd += " 2>> %s &" % self.jmx_tool_err_log + + self.logger.debug("%s: Start JmxTool %d command: %s" % (node.account, idx, cmd)) + node.account.ssh(cmd, allow_fail=False) + wait_until(lambda: self._jmx_has_output(node), timeout_sec=30, backoff_sec=.5, err_msg="%s: Jmx tool took too long to start" % node.account) + self.started[idx-1] = True + + def _jmx_has_output(self, node): + """Helper used as a proxy to determine whether jmx is running by that jmx_tool_log contains output.""" + try: + node.account.ssh("test -s %s" % self.jmx_tool_log, allow_fail=False) + return True + except RemoteCommandError: + return False + + def read_jmx_output(self, idx, node): + if not self.started[idx-1]: + return + + object_attribute_names = [] + + cmd = "cat %s" % self.jmx_tool_log + self.logger.debug("Read jmx output %d command: %s", idx, cmd) + lines = [line for line in node.account.ssh_capture(cmd, allow_fail=False)] + assert len(lines) > 1, "There don't appear to be any samples in the jmx tool log: %s" % lines + + for line in lines: + if "time" in line: + object_attribute_names = line.strip()[1:-1].split("\",\"")[1:] + continue + stats = [float(field) for field in line.split(',')] + time_sec = int(stats[0]/1000) + self.jmx_stats[idx-1][time_sec] = {name: stats[i+1] for i, name in enumerate(object_attribute_names)} + + # do not calculate average and maximum of jmx stats until we have read output from all nodes + # If the service is multithreaded, this means that the results will be aggregated only when the last + # service finishes + if any(len(time_to_stats) == 0 for time_to_stats in self.jmx_stats): + return + + start_time_sec = min([min(time_to_stats.keys()) for time_to_stats in self.jmx_stats]) + end_time_sec = max([max(time_to_stats.keys()) for time_to_stats in self.jmx_stats]) + + for name in object_attribute_names: + aggregates_per_time = [] + for time_sec in range(start_time_sec, end_time_sec + 1): + # assume that value is 0 if it is not read by jmx tool at the given time. This is appropriate for metrics such as bandwidth + values_per_node = [time_to_stats.get(time_sec, {}).get(name, 0) for time_to_stats in self.jmx_stats] + # assume that value is aggregated across nodes by sum. 
This is appropriate for metrics such as bandwidth + aggregates_per_time.append(sum(values_per_node)) + self.average_jmx_value[name] = sum(aggregates_per_time) / len(aggregates_per_time) + self.maximum_jmx_value[name] = max(aggregates_per_time) + + def read_jmx_output_all_nodes(self): + for node in self.nodes: + self.read_jmx_output(self.idx(node), node) + + def jmx_class_name(self): + return "kafka.tools.JmxTool" + +class JmxTool(JmxMixin, KafkaPathResolverMixin): + """ + Simple helper class for using the JmxTool directly instead of as a mix-in + """ + def __init__(self, text_context, *args, **kwargs): + JmxMixin.__init__(self, num_nodes=1, *args, **kwargs) + self.context = text_context + + @property + def logger(self): + return self.context.logger diff --git a/tests/kafkatest/services/performance/__init__.py b/tests/kafkatest/services/performance/__init__.py new file mode 100644 index 0000000..69686f7 --- /dev/null +++ b/tests/kafkatest/services/performance/__init__.py @@ -0,0 +1,19 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .performance import PerformanceService, throughput, latency, compute_aggregate_throughput +from .end_to_end_latency import EndToEndLatencyService +from .producer_performance import ProducerPerformanceService +from .consumer_performance import ConsumerPerformanceService diff --git a/tests/kafkatest/services/performance/consumer_performance.py b/tests/kafkatest/services/performance/consumer_performance.py new file mode 100644 index 0000000..6df8dfb --- /dev/null +++ b/tests/kafkatest/services/performance/consumer_performance.py @@ -0,0 +1,187 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
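+
+# Illustrative usage sketch for the service defined below. It is hedged: `test_context` and a
+# running `kafka` KafkaService are assumed to be provided by the surrounding ducktape test, and
+# compute_aggregate_throughput comes from kafkatest.services.performance.
+#
+#     perf = ConsumerPerformanceService(test_context, num_nodes=1, kafka=kafka,
+#                                       topic="test-topic", messages=1000000)
+#     perf.run()                       # ducktape Service.run(): start, wait, then stop
+#     summary = compute_aggregate_throughput(perf)
+#     # perf.results holds one dict per node, with keys 'total_mb', 'mbps', 'records_per_sec'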
+ + +import os + +from kafkatest.services.performance import PerformanceService +from kafkatest.services.security.security_config import SecurityConfig +from kafkatest.version import DEV_BRANCH, V_2_0_0, LATEST_0_10_0 + + +class ConsumerPerformanceService(PerformanceService): + """ + See ConsumerPerformance.scala as the source of truth on these settings, but for reference: + + "zookeeper" "The connection string for the zookeeper connection in the form host:port. Multiple URLS can + be given to allow fail-over. This option is only used with the old consumer." + + "broker-list", "A broker list to use for connecting if using the new consumer." + + "topic", "REQUIRED: The topic to consume from." + + "group", "The group id to consume on." + + "fetch-size", "The amount of data to fetch in a single request." + + "from-latest", "If the consumer does not already have an establishedoffset to consume from, + start with the latest message present in the log rather than the earliest message." + + "socket-buffer-size", "The size of the tcp RECV size." + + "threads", "Number of processing threads." + + "num-fetch-threads", "Number of fetcher threads. Defaults to 1" + + "new-consumer", "Use the new consumer implementation." + "consumer.config", "Consumer config properties file." + """ + + # Root directory for persistent output + PERSISTENT_ROOT = "/mnt/consumer_performance" + LOG_DIR = os.path.join(PERSISTENT_ROOT, "logs") + STDOUT_CAPTURE = os.path.join(PERSISTENT_ROOT, "consumer_performance.stdout") + STDERR_CAPTURE = os.path.join(PERSISTENT_ROOT, "consumer_performance.stderr") + LOG_FILE = os.path.join(LOG_DIR, "consumer_performance.log") + LOG4J_CONFIG = os.path.join(PERSISTENT_ROOT, "tools-log4j.properties") + CONFIG_FILE = os.path.join(PERSISTENT_ROOT, "consumer.properties") + + logs = { + "consumer_performance_output": { + "path": STDOUT_CAPTURE, + "collect_default": True}, + "consumer_performance_stderr": { + "path": STDERR_CAPTURE, + "collect_default": True}, + "consumer_performance_log": { + "path": LOG_FILE, + "collect_default": True} + } + + def __init__(self, context, num_nodes, kafka, topic, messages, version=DEV_BRANCH, new_consumer=True, settings={}): + super(ConsumerPerformanceService, self).__init__(context, num_nodes) + self.kafka = kafka + self.security_config = kafka.security_config.client_config() + self.topic = topic + self.messages = messages + self.new_consumer = new_consumer + self.settings = settings + + assert version.consumer_supports_bootstrap_server() or (not new_consumer), \ + "new_consumer is only supported if version >= 0.9.0.0, version %s" % str(version) + + assert version < V_2_0_0 or new_consumer, \ + "new_consumer==false is only supported if version < 2.0.0, version %s" % str(version) + + security_protocol = self.security_config.security_protocol + assert version.consumer_supports_bootstrap_server() or security_protocol == SecurityConfig.PLAINTEXT, \ + "Security protocol %s is only supported if version >= 0.9.0.0, version %s" % (self.security_config, str(version)) + + # These less-frequently used settings can be updated manually after instantiation + self.fetch_size = None + self.socket_buffer_size = None + self.threads = None + self.num_fetch_threads = None + self.group = None + self.from_latest = None + + for node in self.nodes: + node.version = version + + def args(self, version): + """Dictionary of arguments used to start the Consumer Performance script.""" + args = { + 'topic': self.topic, + 'messages': self.messages, + } + + if self.new_consumer: + if version <= 
LATEST_0_10_0: + args['new-consumer'] = "" + args['broker-list'] = self.kafka.bootstrap_servers(self.security_config.security_protocol) + else: + args['zookeeper'] = self.kafka.zk_connect_setting() + + if self.fetch_size is not None: + args['fetch-size'] = self.fetch_size + + if self.socket_buffer_size is not None: + args['socket-buffer-size'] = self.socket_buffer_size + + if self.threads is not None: + args['threads'] = self.threads + + if self.num_fetch_threads is not None: + args['num-fetch-threads'] = self.num_fetch_threads + + if self.group is not None: + args['group'] = self.group + + if self.from_latest: + args['from-latest'] = "" + + return args + + def start_cmd(self, node): + cmd = "export LOG_DIR=%s;" % ConsumerPerformanceService.LOG_DIR + cmd += " export KAFKA_OPTS=%s;" % self.security_config.kafka_opts + cmd += " export KAFKA_LOG4J_OPTS=\"-Dlog4j.configuration=file:%s\";" % ConsumerPerformanceService.LOG4J_CONFIG + cmd += " %s" % self.path.script("kafka-consumer-perf-test.sh", node) + for key, value in self.args(node.version).items(): + cmd += " --%s %s" % (key, value) + + if node.version.consumer_supports_bootstrap_server(): + # This is only used for security settings + cmd += " --consumer.config %s" % ConsumerPerformanceService.CONFIG_FILE + + for key, value in self.settings.items(): + cmd += " %s=%s" % (str(key), str(value)) + + cmd += " 2>> %(stderr)s | tee -a %(stdout)s" % {'stdout': ConsumerPerformanceService.STDOUT_CAPTURE, + 'stderr': ConsumerPerformanceService.STDERR_CAPTURE} + return cmd + + def parse_results(self, line, version): + parts = line.split(',') + if version.consumer_supports_bootstrap_server(): + result = { + 'total_mb': float(parts[2]), + 'mbps': float(parts[3]), + 'records_per_sec': float(parts[5]), + } + else: + result = { + 'total_mb': float(parts[3]), + 'mbps': float(parts[4]), + 'records_per_sec': float(parts[6]), + } + return result + + def _worker(self, idx, node): + node.account.ssh("mkdir -p %s" % ConsumerPerformanceService.PERSISTENT_ROOT, allow_fail=False) + + log_config = self.render('tools_log4j.properties', log_file=ConsumerPerformanceService.LOG_FILE) + node.account.create_file(ConsumerPerformanceService.LOG4J_CONFIG, log_config) + node.account.create_file(ConsumerPerformanceService.CONFIG_FILE, str(self.security_config)) + self.security_config.setup_node(node) + + cmd = self.start_cmd(node) + self.logger.debug("Consumer performance %d command: %s", idx, cmd) + last = None + for line in node.account.ssh_capture(cmd): + last = line + + # Parse and save the last line's information + self.results[idx-1] = self.parse_results(last, node.version) diff --git a/tests/kafkatest/services/performance/end_to_end_latency.py b/tests/kafkatest/services/performance/end_to_end_latency.py new file mode 100644 index 0000000..3cde3ef --- /dev/null +++ b/tests/kafkatest/services/performance/end_to_end_latency.py @@ -0,0 +1,127 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +from kafkatest.services.performance import PerformanceService +from kafkatest.services.security.security_config import SecurityConfig +from kafkatest.version import DEV_BRANCH + + + +class EndToEndLatencyService(PerformanceService): + MESSAGE_BYTES = 21 # 0.8.X messages are fixed at 21 bytes, so we'll match that for other versions + + # Root directory for persistent output + PERSISTENT_ROOT = "/mnt/end_to_end_latency" + LOG_DIR = os.path.join(PERSISTENT_ROOT, "logs") + STDOUT_CAPTURE = os.path.join(PERSISTENT_ROOT, "end_to_end_latency.stdout") + STDERR_CAPTURE = os.path.join(PERSISTENT_ROOT, "end_to_end_latency.stderr") + LOG_FILE = os.path.join(LOG_DIR, "end_to_end_latency.log") + LOG4J_CONFIG = os.path.join(PERSISTENT_ROOT, "tools-log4j.properties") + CONFIG_FILE = os.path.join(PERSISTENT_ROOT, "client.properties") + + logs = { + "end_to_end_latency_output": { + "path": STDOUT_CAPTURE, + "collect_default": True}, + "end_to_end_latency_stderr": { + "path": STDERR_CAPTURE, + "collect_default": True}, + "end_to_end_latency_log": { + "path": LOG_FILE, + "collect_default": True} + } + + def __init__(self, context, num_nodes, kafka, topic, num_records, compression_type="none", version=DEV_BRANCH, acks=1): + super(EndToEndLatencyService, self).__init__(context, num_nodes, + root=EndToEndLatencyService.PERSISTENT_ROOT) + self.kafka = kafka + self.security_config = kafka.security_config.client_config() + + security_protocol = self.security_config.security_protocol + + if not version.consumer_supports_bootstrap_server(): + assert security_protocol == SecurityConfig.PLAINTEXT, \ + "Security protocol %s is only supported if version >= 0.9.0.0, version %s" % (self.security_config, str(version)) + assert compression_type == "none", \ + "Compression type %s is only supported if version >= 0.9.0.0, version %s" % (compression_type, str(version)) + + self.args = { + 'topic': topic, + 'num_records': num_records, + 'acks': acks, + 'compression_type': compression_type, + 'kafka_opts': self.security_config.kafka_opts, + 'message_bytes': EndToEndLatencyService.MESSAGE_BYTES + } + + for node in self.nodes: + node.version = version + + def start_cmd(self, node): + args = self.args.copy() + args.update({ + 'bootstrap_servers': self.kafka.bootstrap_servers(self.security_config.security_protocol), + 'config_file': EndToEndLatencyService.CONFIG_FILE, + 'kafka_run_class': self.path.script("kafka-run-class.sh", node), + 'java_class_name': self.java_class_name() + }) + if not node.version.consumer_supports_bootstrap_server(): + args.update({ + 'zk_connect': self.kafka.zk_connect_setting(), + }) + + cmd = "export KAFKA_LOG4J_OPTS=\"-Dlog4j.configuration=file:%s\"; " % EndToEndLatencyService.LOG4J_CONFIG + if node.version.consumer_supports_bootstrap_server(): + cmd += "KAFKA_OPTS=%(kafka_opts)s %(kafka_run_class)s %(java_class_name)s " % args + cmd += "%(bootstrap_servers)s %(topic)s %(num_records)d %(acks)d %(message_bytes)d %(config_file)s" % args + else: + # Set fetch max wait to 0 to match behavior in later versions + cmd += "KAFKA_OPTS=%(kafka_opts)s %(kafka_run_class)s 
kafka.tools.TestEndToEndLatency " % args + cmd += "%(bootstrap_servers)s %(zk_connect)s %(topic)s %(num_records)d 0 %(acks)d" % args + + cmd += " 2>> %(stderr)s | tee -a %(stdout)s" % {'stdout': EndToEndLatencyService.STDOUT_CAPTURE, + 'stderr': EndToEndLatencyService.STDERR_CAPTURE} + + return cmd + + def _worker(self, idx, node): + node.account.ssh("mkdir -p %s" % EndToEndLatencyService.PERSISTENT_ROOT, allow_fail=False) + + log_config = self.render('tools_log4j.properties', log_file=EndToEndLatencyService.LOG_FILE) + + node.account.create_file(EndToEndLatencyService.LOG4J_CONFIG, log_config) + client_config = str(self.security_config) + if node.version.consumer_supports_bootstrap_server(): + client_config += "compression_type=%(compression_type)s" % self.args + node.account.create_file(EndToEndLatencyService.CONFIG_FILE, client_config) + + self.security_config.setup_node(node) + + cmd = self.start_cmd(node) + self.logger.debug("End-to-end latency %d command: %s", idx, cmd) + results = {} + for line in node.account.ssh_capture(cmd): + if line.startswith("Avg latency:"): + results['latency_avg_ms'] = float(line.split()[2]) + if line.startswith("Percentiles"): + results['latency_50th_ms'] = float(line.split()[3][:-1]) + results['latency_99th_ms'] = float(line.split()[6][:-1]) + results['latency_999th_ms'] = float(line.split()[9]) + self.results[idx-1] = results + + def java_class_name(self): + return "kafka.tools.EndToEndLatency" diff --git a/tests/kafkatest/services/performance/performance.py b/tests/kafkatest/services/performance/performance.py new file mode 100644 index 0000000..0d1f5b0 --- /dev/null +++ b/tests/kafkatest/services/performance/performance.py @@ -0,0 +1,72 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ducktape.services.background_thread import BackgroundThreadService +from kafkatest.directory_layout.kafka_path import KafkaPathResolverMixin + + +class PerformanceService(KafkaPathResolverMixin, BackgroundThreadService): + + def __init__(self, context=None, num_nodes=0, root="/mnt/*", stop_timeout_sec=30): + super(PerformanceService, self).__init__(context, num_nodes) + self.results = [None] * self.num_nodes + self.stats = [[] for x in range(self.num_nodes)] + self.stop_timeout_sec = stop_timeout_sec + self.root = root + + def java_class_name(self): + """ + Returns the name of the Java class which this service creates. Subclasses should override + this method, so that we know the name of the java process to stop. If it is not + overridden, we will kill all java processes in PerformanceService#stop_node (for backwards + compatibility.) 
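+
+        For example, the end-to-end latency service above overrides it as:
+
+            def java_class_name(self):
+                return "kafka.tools.EndToEndLatency"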
+ """ + return "" + + def stop_node(self, node): + node.account.kill_java_processes(self.java_class_name(), clean_shutdown=True, allow_fail=True) + + stopped = self.wait_node(node, timeout_sec=self.stop_timeout_sec) + assert stopped, "Node %s: did not stop within the specified timeout of %s seconds" % \ + (str(node.account), str(self.stop_timeout_sec)) + + def clean_node(self, node): + node.account.kill_java_processes(self.java_class_name(), clean_shutdown=False, allow_fail=True) + node.account.ssh("rm -rf -- %s" % self.root, allow_fail=False) + + +def throughput(records_per_sec, mb_per_sec): + """Helper method to ensure uniform representation of throughput data""" + return { + "records_per_sec": records_per_sec, + "mb_per_sec": mb_per_sec + } + + +def latency(latency_50th_ms, latency_99th_ms, latency_999th_ms): + """Helper method to ensure uniform representation of latency data""" + return { + "latency_50th_ms": latency_50th_ms, + "latency_99th_ms": latency_99th_ms, + "latency_999th_ms": latency_999th_ms + } + + +def compute_aggregate_throughput(perf): + """Helper method for computing throughput after running a performance service.""" + aggregate_rate = sum([r['records_per_sec'] for r in perf.results]) + aggregate_mbps = sum([r['mbps'] for r in perf.results]) + + return throughput(aggregate_rate, aggregate_mbps) diff --git a/tests/kafkatest/services/performance/producer_performance.py b/tests/kafkatest/services/performance/producer_performance.py new file mode 100644 index 0000000..a990d4f --- /dev/null +++ b/tests/kafkatest/services/performance/producer_performance.py @@ -0,0 +1,174 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
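+
+# Worked example (hypothetical numbers) of how compute_aggregate_throughput() in performance.py
+# combines the per-node results that services such as the one below populate; aggregation is a
+# plain sum across nodes:
+#
+#     class _FakePerf(object):
+#         results = [{'records_per_sec': 50000.0, 'mbps': 47.5},
+#                    {'records_per_sec': 52000.0, 'mbps': 49.5}]
+#
+#     compute_aggregate_throughput(_FakePerf())
+#     # -> {'records_per_sec': 102000.0, 'mb_per_sec': 97.0}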
+ +import os +import time +from ducktape.utils.util import wait_until +from ducktape.cluster.remoteaccount import RemoteCommandError + +from kafkatest.directory_layout.kafka_path import TOOLS_JAR_NAME, TOOLS_DEPENDANT_TEST_LIBS_JAR_NAME +from kafkatest.services.monitor.http import HttpMetricsCollector +from kafkatest.services.performance import PerformanceService +from kafkatest.services.security.security_config import SecurityConfig +from kafkatest.version import DEV_BRANCH + + +class ProducerPerformanceService(HttpMetricsCollector, PerformanceService): + + PERSISTENT_ROOT = "/mnt/producer_performance" + STDOUT_CAPTURE = os.path.join(PERSISTENT_ROOT, "producer_performance.stdout") + STDERR_CAPTURE = os.path.join(PERSISTENT_ROOT, "producer_performance.stderr") + LOG_DIR = os.path.join(PERSISTENT_ROOT, "logs") + LOG_FILE = os.path.join(LOG_DIR, "producer_performance.log") + LOG4J_CONFIG = os.path.join(PERSISTENT_ROOT, "tools-log4j.properties") + + def __init__(self, context, num_nodes, kafka, topic, num_records, record_size, throughput, version=DEV_BRANCH, settings=None, + intermediate_stats=False, client_id="producer-performance"): + + super(ProducerPerformanceService, self).__init__(context=context, num_nodes=num_nodes) + + self.logs = { + "producer_performance_stdout": { + "path": ProducerPerformanceService.STDOUT_CAPTURE, + "collect_default": True}, + "producer_performance_stderr": { + "path": ProducerPerformanceService.STDERR_CAPTURE, + "collect_default": True}, + "producer_performance_log": { + "path": ProducerPerformanceService.LOG_FILE, + "collect_default": True} + } + + self.kafka = kafka + self.security_config = kafka.security_config.client_config() + + security_protocol = self.security_config.security_protocol + assert version.consumer_supports_bootstrap_server() or security_protocol == SecurityConfig.PLAINTEXT, \ + "Security protocol %s is only supported if version >= 0.9.0.0, version %s" % (self.security_config, str(version)) + + self.args = { + 'topic': topic, + 'kafka_opts': self.security_config.kafka_opts, + 'num_records': num_records, + 'record_size': record_size, + 'throughput': throughput + } + self.settings = settings or {} + self.intermediate_stats = intermediate_stats + self.client_id = client_id + + for node in self.nodes: + node.version = version + + def start_cmd(self, node): + args = self.args.copy() + args.update({ + 'bootstrap_servers': self.kafka.bootstrap_servers(self.security_config.security_protocol), + 'client_id': self.client_id, + 'kafka_run_class': self.path.script("kafka-run-class.sh", node), + 'metrics_props': ' '.join("%s=%s" % (k, v) for k, v in self.http_metrics_client_configs.items()) + }) + + cmd = "" + + if node.version < DEV_BRANCH: + # In order to ensure more consistent configuration between versions, always use the ProducerPerformance + # tool from the development branch + tools_jar = self.path.jar(TOOLS_JAR_NAME, DEV_BRANCH) + tools_dependant_libs_jar = self.path.jar(TOOLS_DEPENDANT_TEST_LIBS_JAR_NAME, DEV_BRANCH) + + for jar in (tools_jar, tools_dependant_libs_jar): + cmd += "for file in %s; do CLASSPATH=$CLASSPATH:$file; done; " % jar + cmd += "export CLASSPATH; " + + cmd += " export KAFKA_LOG4J_OPTS=\"-Dlog4j.configuration=file:%s\"; " % ProducerPerformanceService.LOG4J_CONFIG + cmd += "KAFKA_OPTS=%(kafka_opts)s KAFKA_HEAP_OPTS=\"-XX:+HeapDumpOnOutOfMemoryError\" %(kafka_run_class)s org.apache.kafka.tools.ProducerPerformance " \ + "--topic %(topic)s --num-records %(num_records)d --record-size %(record_size)d --throughput %(throughput)d 
--producer-props bootstrap.servers=%(bootstrap_servers)s client.id=%(client_id)s %(metrics_props)s" % args + + self.security_config.setup_node(node) + if self.security_config.security_protocol != SecurityConfig.PLAINTEXT: + self.settings.update(self.security_config.properties) + + for key, value in self.settings.items(): + cmd += " %s=%s" % (str(key), str(value)) + + cmd += " 2>>%s | tee %s" % (ProducerPerformanceService.STDERR_CAPTURE, ProducerPerformanceService.STDOUT_CAPTURE) + return cmd + + def pids(self, node): + try: + cmd = "jps | grep -i ProducerPerformance | awk '{print $1}'" + pid_arr = [pid for pid in node.account.ssh_capture(cmd, allow_fail=True, callback=int)] + return pid_arr + except (RemoteCommandError, ValueError) as e: + return [] + + def alive(self, node): + return len(self.pids(node)) > 0 + + def _worker(self, idx, node): + node.account.ssh("mkdir -p %s" % ProducerPerformanceService.PERSISTENT_ROOT, allow_fail=False) + + # Create and upload log properties + log_config = self.render('tools_log4j.properties', log_file=ProducerPerformanceService.LOG_FILE) + node.account.create_file(ProducerPerformanceService.LOG4J_CONFIG, log_config) + + cmd = self.start_cmd(node) + self.logger.debug("Producer performance %d command: %s", idx, cmd) + + # start ProducerPerformance process + start = time.time() + producer_output = node.account.ssh_capture(cmd) + wait_until(lambda: self.alive(node), timeout_sec=20, err_msg="ProducerPerformance failed to start") + # block until there is at least one line of output + first_line = next(producer_output, None) + if first_line is None: + raise Exception("No output from ProducerPerformance") + + wait_until(lambda: not self.alive(node), timeout_sec=1200, backoff_sec=2, err_msg="ProducerPerformance failed to finish") + elapsed = time.time() - start + self.logger.debug("ProducerPerformance process ran for %s seconds" % elapsed) + + # parse producer output from file + last = None + producer_output = node.account.ssh_capture("cat %s" % ProducerPerformanceService.STDOUT_CAPTURE) + for line in producer_output: + if self.intermediate_stats: + try: + self.stats[idx-1].append(self.parse_stats(line)) + except: + # Sometimes there are extraneous log messages + pass + + last = line + try: + self.results[idx-1] = self.parse_stats(last) + except: + raise Exception("Unable to parse aggregate performance statistics on node %d: %s" % (idx, last)) + + def parse_stats(self, line): + + parts = line.split(',') + return { + 'records': int(parts[0].split()[0]), + 'records_per_sec': float(parts[1].split()[0]), + 'mbps': float(parts[1].split('(')[1].split()[0]), + 'latency_avg_ms': float(parts[2].split()[0]), + 'latency_max_ms': float(parts[3].split()[0]), + 'latency_50th_ms': float(parts[4].split()[0]), + 'latency_95th_ms': float(parts[5].split()[0]), + 'latency_99th_ms': float(parts[6].split()[0]), + 'latency_999th_ms': float(parts[7].split()[0]), + } diff --git a/tests/kafkatest/services/performance/templates/tools_log4j.properties b/tests/kafkatest/services/performance/templates/tools_log4j.properties new file mode 100644 index 0000000..df10d88 --- /dev/null +++ b/tests/kafkatest/services/performance/templates/tools_log4j.properties @@ -0,0 +1,25 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. 
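For orientation, the sketch below walks one summary line through the same splitting logic used by parse_stats above. The sample line is an assumption about the ProducerPerformance output format, not captured tool output.

# Illustrative only: the sample line mimics the summary format that parse_stats assumes.
sample = ("100000 records sent, 49975.0 records/sec (47.66 MB/sec), "
          "20.1 ms avg latency, 310.0 ms max latency, "
          "18 ms 50th, 30 ms 95th, 45 ms 99th, 60 ms 99.9th.")

parts = sample.split(',')
stats = {
    'records': int(parts[0].split()[0]),               # 100000
    'records_per_sec': float(parts[1].split()[0]),     # 49975.0
    'mbps': float(parts[1].split('(')[1].split()[0]),  # 47.66
    'latency_avg_ms': float(parts[2].split()[0]),      # 20.1
    'latency_max_ms': float(parts[3].split()[0]),      # 310.0
    'latency_50th_ms': float(parts[4].split()[0]),     # 18.0
    'latency_95th_ms': float(parts[5].split()[0]),     # 30.0
    'latency_99th_ms': float(parts[6].split()[0]),     # 45.0
    'latency_999th_ms': float(parts[7].split()[0]),    # 60.0
}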
+# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Define the root logger with appender file +log4j.rootLogger = {{ log_level|default("INFO") }}, FILE + +log4j.appender.FILE=org.apache.log4j.FileAppender +log4j.appender.FILE.File={{ log_file }} +log4j.appender.FILE.ImmediateFlush=true +# Set the append to false, overwrite +log4j.appender.FILE.Append=false +log4j.appender.FILE.layout=org.apache.log4j.PatternLayout +log4j.appender.FILE.layout.conversionPattern=[%d] %p %m (%c)%n diff --git a/tests/kafkatest/services/replica_verification_tool.py b/tests/kafkatest/services/replica_verification_tool.py new file mode 100644 index 0000000..13a1288 --- /dev/null +++ b/tests/kafkatest/services/replica_verification_tool.py @@ -0,0 +1,95 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ducktape.services.background_thread import BackgroundThreadService + +from kafkatest.directory_layout.kafka_path import KafkaPathResolverMixin +from kafkatest.services.security.security_config import SecurityConfig + +import re + + +class ReplicaVerificationTool(KafkaPathResolverMixin, BackgroundThreadService): + + logs = { + "producer_log": { + "path": "/mnt/replica_verification_tool.log", + "collect_default": False} + } + + def __init__(self, context, num_nodes, kafka, topic, report_interval_ms, security_protocol="PLAINTEXT", + stop_timeout_sec=30, tls_version=None): + super(ReplicaVerificationTool, self).__init__(context, num_nodes) + + self.kafka = kafka + self.topic = topic + self.report_interval_ms = report_interval_ms + self.security_protocol = security_protocol + self.tls_version = tls_version + self.security_config = SecurityConfig(self.context, security_protocol, tls_version=tls_version) + self.partition_lag = {} + self.stop_timeout_sec = stop_timeout_sec + + def _worker(self, idx, node): + cmd = self.start_cmd(node) + self.logger.debug("ReplicaVerificationTool %d command: %s" % (idx, cmd)) + self.security_config.setup_node(node) + for line in node.account.ssh_capture(cmd): + self.logger.debug("Parsing line:{}".format(line)) + + parsed = re.search('.*max lag is (.+?) 
for partition ([a-zA-Z0-9._-]+-[0-9]+) at', line) + if parsed: + lag = int(parsed.group(1)) + topic_partition = parsed.group(2) + self.logger.debug("Setting max lag for {} as {}".format(topic_partition, lag)) + self.partition_lag[topic_partition] = lag + + def get_lag_for_partition(self, topic, partition): + """ + Get latest lag for given topic-partition + + Args: + topic: a topic + partition: a partition of the topic + """ + topic_partition = topic + '-' + str(partition) + lag = self.partition_lag.get(topic_partition, -1) + self.logger.debug("Returning lag for {} as {}".format(topic_partition, lag)) + + return lag + + def start_cmd(self, node): + cmd = self.path.script("kafka-run-class.sh", node) + cmd += " %s" % self.java_class_name() + cmd += " --broker-list %s --topic-white-list %s --time -2 --report-interval-ms %s" % (self.kafka.bootstrap_servers(self.security_protocol), self.topic, self.report_interval_ms) + + cmd += " 2>> /mnt/replica_verification_tool.log | tee -a /mnt/replica_verification_tool.log &" + return cmd + + def stop_node(self, node): + node.account.kill_java_processes(self.java_class_name(), clean_shutdown=True, + allow_fail=True) + + stopped = self.wait_node(node, timeout_sec=self.stop_timeout_sec) + assert stopped, "Node %s: did not stop within the specified timeout of %s seconds" % \ + (str(node.account), str(self.stop_timeout_sec)) + + def clean_node(self, node): + node.account.kill_java_processes(self.java_class_name(), clean_shutdown=False, + allow_fail=True) + node.account.ssh("rm -rf /mnt/replica_verification_tool.log", allow_fail=False) + + def java_class_name(self): + return "kafka.tools.ReplicaVerificationTool" diff --git a/tests/kafkatest/services/security/__init__.py b/tests/kafkatest/services/security/__init__.py new file mode 100644 index 0000000..e556dc9 --- /dev/null +++ b/tests/kafkatest/services/security/__init__.py @@ -0,0 +1,15 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + diff --git a/tests/kafkatest/services/security/kafka_acls.py b/tests/kafkatest/services/security/kafka_acls.py new file mode 100644 index 0000000..30fc343 --- /dev/null +++ b/tests/kafkatest/services/security/kafka_acls.py @@ -0,0 +1,153 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import time
+
+class ACLs:
+    def __init__(self, context):
+        self.context = context
+
+    def set_acls(self, protocol, kafka, topic, group, force_use_zk_connection=False, additional_cluster_operations_to_grant = []):
+        """
+        Creates ACLs for the Kafka Broker principal that brokers use in tests
+
+        :param protocol: the security protocol to use (e.g. PLAINTEXT, SASL_PLAINTEXT, etc.)
+        :param kafka: Kafka cluster upon which ClusterAction ACL is created
+        :param topic: topic for which produce and consume ACLs are created
+        :param group: consumer group for which consume ACL is created
+        :param force_use_zk_connection: forces the use of ZooKeeper when true, otherwise AdminClient is used when available.
+            This is necessary for the case where we are bootstrapping ACLs before Kafka is started or before authorizer is enabled
+        :param additional_cluster_operations_to_grant: may be set to ['Alter', 'Create'] if the cluster is secured since these are required
+            to create SCRAM credentials and topics, respectively
+        """
+        # Set server ACLs
+        kafka_principal = "User:CN=systemtest" if protocol == "SSL" else "User:kafka"
+        self.add_cluster_acl(kafka, kafka_principal, force_use_zk_connection=force_use_zk_connection, additional_cluster_operations_to_grant = additional_cluster_operations_to_grant)
+        self.add_read_acl(kafka, kafka_principal, "*", force_use_zk_connection=force_use_zk_connection)
+
+        # Set client ACLs
+        client_principal = "User:CN=systemtest" if protocol == "SSL" else "User:client"
+        self.add_produce_acl(kafka, client_principal, topic, force_use_zk_connection=force_use_zk_connection)
+        self.add_consume_acl(kafka, client_principal, topic, group, force_use_zk_connection=force_use_zk_connection)
+
+    def _add_acl_on_topic(self, kafka, principal, topic, operation_flag, node, force_use_zk_connection):
+        """
+        :param principal: principal for which ACL is created
+        :param topic: topic for which ACL is created
+        :param operation_flag: type of ACL created (e.g. --producer, --consumer, --operation=Read)
+        :param node: Node to use when determining connection settings
+        :param force_use_zk_connection: forces the use of ZooKeeper when true, otherwise AdminClient is used when available
+        """
+        cmd = "%(cmd_prefix)s --add --topic=%(topic)s %(operation_flag)s --allow-principal=%(principal)s" % {
+            'cmd_prefix': kafka.kafka_acls_cmd_with_optional_security_settings(node, force_use_zk_connection),
+            'topic': topic,
+            'operation_flag': operation_flag,
+            'principal': principal
+        }
+        kafka.run_cli_tool(node, cmd)
+
+    def add_cluster_acl(self, kafka, principal, force_use_zk_connection=False, additional_cluster_operations_to_grant = [], security_protocol=None):
+        """
+        :param kafka: Kafka cluster upon which ClusterAction ACL is created
+        :param principal: principal for which ClusterAction ACL is created
+        :param node: Node to use when determining connection settings
+        :param force_use_zk_connection: forces the use of ZooKeeper when true, otherwise AdminClient is used when available.
+            This is necessary for the case where we are bootstrapping ACLs before Kafka is started or before authorizer is enabled
+        :param additional_cluster_operations_to_grant: may be set to ['Alter', 'Create'] if the cluster is secured since these are required
+            to create SCRAM credentials and topics, respectively
+        :param security_protocol: set it to explicitly determine whether we use client or broker credentials, otherwise
+            we use the client security protocol unless inter-broker security protocol is PLAINTEXT, in which case we use PLAINTEXT.
+            Then we use the broker's credentials if the selected security protocol matches the inter-broker security protocol,
+            otherwise we use the client's credentials.
+        """
+        node = kafka.nodes[0]
+
+        for operation in ['ClusterAction'] + additional_cluster_operations_to_grant:
+            cmd = "%(cmd_prefix)s --add --cluster --operation=%(operation)s --allow-principal=%(principal)s" % {
+                'cmd_prefix': kafka.kafka_acls_cmd_with_optional_security_settings(node, force_use_zk_connection, security_protocol),
+                'operation': operation,
+                'principal': principal
+            }
+            kafka.run_cli_tool(node, cmd)
+
+    def remove_cluster_acl(self, kafka, principal, additional_cluster_operations_to_remove = [], security_protocol=None):
+        """
+        :param kafka: Kafka cluster upon which ClusterAction ACL is deleted
+        :param principal: principal for which ClusterAction ACL is deleted
+        :param node: Node to use when determining connection settings
+        :param additional_cluster_operations_to_remove: may be set to ['Alter', 'Create'] if the cluster is secured since these are required
+            to create SCRAM credentials and topics, respectively
+        :param security_protocol: set it to explicitly determine whether we use client or broker credentials, otherwise
+            we use the client security protocol unless inter-broker security protocol is PLAINTEXT, in which case we use PLAINTEXT.
+            Then we use the broker's credentials if the selected security protocol matches the inter-broker security protocol,
+            otherwise we use the client's credentials.
+        """
+        node = kafka.nodes[0]
+
+        for operation in ['ClusterAction'] + additional_cluster_operations_to_remove:
+            cmd = "%(cmd_prefix)s --remove --force --cluster --operation=%(operation)s --allow-principal=%(principal)s" % {
+                'cmd_prefix': kafka.kafka_acls_cmd_with_optional_security_settings(node, False, security_protocol),
+                'operation': operation,
+                'principal': principal
+            }
+            kafka.logger.info(cmd)
+            kafka.run_cli_tool(node, cmd)
+
+    def add_read_acl(self, kafka, principal, topic, force_use_zk_connection=False):
+        """
+        :param kafka: Kafka cluster upon which Read ACL is created
+        :param principal: principal for which Read ACL is created
+        :param topic: topic for which Read ACL is created
+        :param node: Node to use when determining connection settings
+        :param force_use_zk_connection: forces the use of ZooKeeper when true, otherwise AdminClient is used when available.
+            This is necessary for the case where we are bootstrapping ACLs before Kafka is started or before authorizer is enabled
+        """
+        node = kafka.nodes[0]
+
+        self._add_acl_on_topic(kafka, principal, topic, "--operation=Read", node, force_use_zk_connection)
+
+    def add_produce_acl(self, kafka, principal, topic, force_use_zk_connection=False):
+        """
+        :param kafka: Kafka cluster upon which Producer ACL is created
+        :param principal: principal for which Producer ACL is created
+        :param topic: topic for which Producer ACL is created
+        :param node: Node to use when determining connection settings
+        :param force_use_zk_connection: forces the use of ZooKeeper when true, otherwise AdminClient is used when available.
+            This is necessary for the case where we are bootstrapping ACLs before Kafka is started or before authorizer is enabled
+        """
+        node = kafka.nodes[0]
+
+        self._add_acl_on_topic(kafka, principal, topic, "--producer", node, force_use_zk_connection)
+
+    def add_consume_acl(self, kafka, principal, topic, group, force_use_zk_connection=False):
+        """
+        :param kafka: Kafka cluster upon which Consumer ACL is created
+        :param principal: principal for which Consumer ACL is created
+        :param topic: topic for which Consumer ACL is created
+        :param group: consumer group for which Consumer ACL is created
+        :param node: Node to use when determining connection settings
+        :param force_use_zk_connection: forces the use of ZooKeeper when true, otherwise AdminClient is used when available.
+            This is necessary for the case where we are bootstrapping ACLs before Kafka is started or before authorizer is enabled
+        """
+        node = kafka.nodes[0]
+
+        cmd = "%(cmd_prefix)s --add --topic=%(topic)s --group=%(group)s --consumer --allow-principal=%(principal)s" % {
+            'cmd_prefix': kafka.kafka_acls_cmd_with_optional_security_settings(node, force_use_zk_connection),
+            'topic': topic,
+            'group': group,
+            'principal': principal
+        }
+        kafka.run_cli_tool(node, cmd)
+
diff --git a/tests/kafkatest/services/security/listener_security_config.py b/tests/kafkatest/services/security/listener_security_config.py
new file mode 100644
index 0000000..119e9f3
--- /dev/null
+++ b/tests/kafkatest/services/security/listener_security_config.py
@@ -0,0 +1,43 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
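To make the intent of the ACLs helper above concrete, here is a hedged usage sketch; test_context and kafka are assumed to come from the surrounding ducktape test and are not defined in this patch.

# Hypothetical usage from inside a ducktape test; test_context and kafka are assumed fixtures.
acls = ACLs(test_context)

# One call grants the broker-side ClusterAction/Read ACLs plus client produce/consume ACLs.
# 'Alter' and 'Create' are only needed on a secured cluster (SCRAM credentials, topic creation).
acls.set_acls("SASL_SSL", kafka, topic="test-topic", group="test-group",
              additional_cluster_operations_to_grant=['Alter', 'Create'])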
+ +class ListenerSecurityConfig: + + SASL_MECHANISM_PREFIXED_CONFIGS = ["connections.max.reauth.ms", "sasl.jaas.config", + "sasl.login.callback.handler.class", "sasl.login.class", + "sasl.server.callback.handler.class"] + + def __init__(self, use_separate_interbroker_listener=False, + client_listener_overrides={}, interbroker_listener_overrides={}): + """ + :param bool use_separate_interbroker_listener - if set, will use a separate interbroker listener, + with security protocol set to interbroker_security_protocol value. If set, requires + interbroker_security_protocol to be provided. + Normally port name is the same as its security protocol, so setting security_protocol and + interbroker_security_protocol to the same value will lead to a single port being open and both client + and broker-to-broker communication will go over that port. This parameter allows + you to add an interbroker listener with the same security protocol as a client listener, but running on a + separate port. + :param dict client_listener_overrides - non-prefixed listener config overrides for named client listener + (for example 'sasl.jaas.config', 'ssl.keystore.location', 'sasl.login.callback.handler.class', etc). + :param dict interbroker_listener_overrides - non-prefixed listener config overrides for named interbroker + listener (for example 'sasl.jaas.config', 'ssl.keystore.location', 'sasl.login.callback.handler.class', etc). + """ + self.use_separate_interbroker_listener = use_separate_interbroker_listener + self.client_listener_overrides = client_listener_overrides + self.interbroker_listener_overrides = interbroker_listener_overrides + + def requires_sasl_mechanism_prefix(self, config): + return config in ListenerSecurityConfig.SASL_MECHANISM_PREFIXED_CONFIGS diff --git a/tests/kafkatest/services/security/minikdc.py b/tests/kafkatest/services/security/minikdc.py new file mode 100644 index 0000000..23327db --- /dev/null +++ b/tests/kafkatest/services/security/minikdc.py @@ -0,0 +1,134 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
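A short, hedged sketch of how the ListenerSecurityConfig class above might be configured; the override values below are placeholders, not settings used elsewhere in this patch.

# Hypothetical configuration; the override values are placeholders.
from kafkatest.services.security.listener_security_config import ListenerSecurityConfig

listener_config = ListenerSecurityConfig(
    use_separate_interbroker_listener=True,
    client_listener_overrides={
        'ssl.keystore.location': '/mnt/security/test.keystore.jks'},
    interbroker_listener_overrides={
        'connections.max.reauth.ms': 3600000})

# Overrides listed in SASL_MECHANISM_PREFIXED_CONFIGS must carry the SASL mechanism in their
# listener prefix; this predicate is how callers decide that.
assert listener_config.requires_sasl_mechanism_prefix('connections.max.reauth.ms')
assert not listener_config.requires_sasl_mechanism_prefix('ssl.keystore.location')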
+ +import os +from io import open +from os import remove, close +from shutil import move +from tempfile import mkstemp + +from ducktape.services.service import Service + +from kafkatest.directory_layout.kafka_path import KafkaPathResolverMixin, CORE_LIBS_JAR_NAME, CORE_DEPENDANT_TEST_LIBS_JAR_NAME +from kafkatest.version import DEV_BRANCH + + +class MiniKdc(KafkaPathResolverMixin, Service): + + logs = { + "minikdc_log": { + "path": "/mnt/minikdc/minikdc.log", + "collect_default": True} + } + + WORK_DIR = "/mnt/minikdc" + PROPS_FILE = "/mnt/minikdc/minikdc.properties" + KEYTAB_FILE = "/mnt/minikdc/keytab" + KRB5CONF_FILE = "/mnt/minikdc/krb5.conf" + LOG_FILE = "/mnt/minikdc/minikdc.log" + + LOCAL_KEYTAB_FILE = None + LOCAL_KRB5CONF_FILE = None + + @staticmethod + def _set_local_keytab_file(local_scratch_dir): + """Set MiniKdc.LOCAL_KEYTAB_FILE exactly once per test. + + LOCAL_KEYTAB_FILE is currently used like a global variable to provide a mechanism to share the + location of the local keytab file among all services which might need it. + + Since individual ducktape tests are each run in a subprocess forked from the ducktape main process, + class variables set at class load time are duplicated between test processes. This leads to collisions + if test subprocesses are run in parallel, so we defer setting these class variables until after the test itself + begins to run. + """ + if MiniKdc.LOCAL_KEYTAB_FILE is None: + MiniKdc.LOCAL_KEYTAB_FILE = os.path.join(local_scratch_dir, "keytab") + return MiniKdc.LOCAL_KEYTAB_FILE + + @staticmethod + def _set_local_krb5conf_file(local_scratch_dir): + """Set MiniKdc.LOCAL_KRB5CONF_FILE exactly once per test. + + See _set_local_keytab_file for details why we do this. + """ + + if MiniKdc.LOCAL_KRB5CONF_FILE is None: + MiniKdc.LOCAL_KRB5CONF_FILE = os.path.join(local_scratch_dir, "krb5conf") + return MiniKdc.LOCAL_KRB5CONF_FILE + + def __init__(self, context, kafka_nodes, extra_principals=""): + super(MiniKdc, self).__init__(context, 1) + self.kafka_nodes = kafka_nodes + self.extra_principals = extra_principals + + # context.local_scratch_dir uses a ducktape feature: + # each test_context object has a unique local scratch directory which is available for the duration of the test + # which is automatically garbage collected after the test finishes + MiniKdc._set_local_keytab_file(context.local_scratch_dir) + MiniKdc._set_local_krb5conf_file(context.local_scratch_dir) + + def replace_in_file(self, file_path, pattern, subst): + fh, abs_path = mkstemp() + with open(abs_path, 'w') as new_file: + with open(file_path) as old_file: + for line in old_file: + new_file.write(line.replace(pattern, subst)) + close(fh) + remove(file_path) + move(abs_path, file_path) + + def start_node(self, node): + node.account.ssh("mkdir -p %s" % MiniKdc.WORK_DIR, allow_fail=False) + props_file = self.render('minikdc.properties', node=node) + node.account.create_file(MiniKdc.PROPS_FILE, props_file) + self.logger.info("minikdc.properties") + self.logger.info(props_file) + + kafka_principals = ' '.join(['kafka/' + kafka_node.account.hostname for kafka_node in self.kafka_nodes]) + principals = 'client ' + kafka_principals + ' ' + self.extra_principals + self.logger.info("Starting MiniKdc with principals " + principals) + + core_libs_jar = self.path.jar(CORE_LIBS_JAR_NAME, DEV_BRANCH) + core_dependant_test_libs_jar = self.path.jar(CORE_DEPENDANT_TEST_LIBS_JAR_NAME, DEV_BRANCH) + + cmd = "for file in %s; do CLASSPATH=$CLASSPATH:$file; done;" % core_libs_jar + cmd += " for file in %s; do 
CLASSPATH=$CLASSPATH:$file; done;" % core_dependant_test_libs_jar + cmd += " export CLASSPATH;" + cmd += " %s kafka.security.minikdc.MiniKdc %s %s %s %s 1>> %s 2>> %s &" % (self.path.script("kafka-run-class.sh", node), MiniKdc.WORK_DIR, MiniKdc.PROPS_FILE, MiniKdc.KEYTAB_FILE, principals, MiniKdc.LOG_FILE, MiniKdc.LOG_FILE) + self.logger.debug("Attempting to start MiniKdc on %s with command: %s" % (str(node.account), cmd)) + with node.account.monitor_log(MiniKdc.LOG_FILE) as monitor: + node.account.ssh(cmd) + monitor.wait_until("MiniKdc Running", timeout_sec=60, backoff_sec=1, err_msg="MiniKdc didn't finish startup") + + node.account.copy_from(MiniKdc.KEYTAB_FILE, MiniKdc.LOCAL_KEYTAB_FILE) + node.account.copy_from(MiniKdc.KRB5CONF_FILE, MiniKdc.LOCAL_KRB5CONF_FILE) + + # KDC is set to bind openly (via 0.0.0.0). Change krb5.conf to hold the specific KDC address + self.replace_in_file(MiniKdc.LOCAL_KRB5CONF_FILE, '0.0.0.0', node.account.hostname) + + def stop_node(self, node): + self.logger.info("Stopping %s on %s" % (type(self).__name__, node.account.hostname)) + node.account.kill_java_processes("MiniKdc", clean_shutdown=True, allow_fail=False) + + def clean_node(self, node): + node.account.kill_java_processes("MiniKdc", clean_shutdown=False, allow_fail=True) + node.account.ssh("rm -rf " + MiniKdc.WORK_DIR, allow_fail=False) + if os.path.exists(MiniKdc.LOCAL_KEYTAB_FILE): + os.remove(MiniKdc.LOCAL_KEYTAB_FILE) + if os.path.exists(MiniKdc.LOCAL_KRB5CONF_FILE): + os.remove(MiniKdc.LOCAL_KRB5CONF_FILE) + + diff --git a/tests/kafkatest/services/security/security_config.py b/tests/kafkatest/services/security/security_config.py new file mode 100644 index 0000000..ccd3dd4 --- /dev/null +++ b/tests/kafkatest/services/security/security_config.py @@ -0,0 +1,432 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
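For context, a hedged sketch of how a test harness typically drives the MiniKdc service above; test_context, kafka.nodes, and the extra principal name are assumptions supplied by the surrounding ducktape test.

# Hypothetical usage; test_context and kafka.nodes come from the surrounding test.
from kafkatest.services.security.minikdc import MiniKdc

kdc = MiniKdc(test_context, kafka.nodes, extra_principals="client2")
kdc.start()  # renders minikdc.properties, launches the KDC, and copies keytab/krb5.conf locally

# Other services can then pick up the local copies when configuring Kerberos clients:
keytab_path = MiniKdc.LOCAL_KEYTAB_FILE
krb5conf_path = MiniKdc.LOCAL_KRB5CONF_FILE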
+ +import logging +import os +import subprocess +from tempfile import mkdtemp +from shutil import rmtree +from ducktape.template import TemplateRenderer + +from kafkatest.services.security.minikdc import MiniKdc +from kafkatest.services.security.listener_security_config import ListenerSecurityConfig +import itertools + +from kafkatest.utils.remote_account import java_version + + +class SslStores(object): + def __init__(self, local_scratch_dir, logger=None): + self.logger = logger + self.ca_crt_path = os.path.join(local_scratch_dir, "test.ca.crt") + self.ca_jks_path = os.path.join(local_scratch_dir, "test.ca.jks") + self.ca_passwd = "test-ca-passwd" + + self.truststore_path = os.path.join(local_scratch_dir, "test.truststore.jks") + self.truststore_passwd = "test-ts-passwd" + self.keystore_passwd = "test-ks-passwd" + # Zookeeper TLS (as of v3.5.6) does not support a key password different than the keystore password + self.key_passwd = self.keystore_passwd + # Allow upto one hour of clock skew between host and VMs + self.startdate = "-1H" + + for file in [self.ca_crt_path, self.ca_jks_path, self.truststore_path]: + if os.path.exists(file): + os.remove(file) + + def generate_ca(self): + """ + Generate CA private key and certificate. + """ + + self.runcmd("keytool -genkeypair -alias ca -keyalg RSA -keysize 2048 -keystore %s -storetype JKS -storepass %s -keypass %s -dname CN=SystemTestCA -startdate %s --ext bc=ca:true" % (self.ca_jks_path, self.ca_passwd, self.ca_passwd, self.startdate)) + self.runcmd("keytool -export -alias ca -keystore %s -storepass %s -storetype JKS -rfc -file %s" % (self.ca_jks_path, self.ca_passwd, self.ca_crt_path)) + + def generate_truststore(self): + """ + Generate JKS truststore containing CA certificate. + """ + + self.runcmd("keytool -importcert -alias ca -file %s -keystore %s -storepass %s -storetype JKS -noprompt" % (self.ca_crt_path, self.truststore_path, self.truststore_passwd)) + + def generate_and_copy_keystore(self, node): + """ + Generate JKS keystore with certificate signed by the test CA. + The generated certificate has the node's hostname as a DNS SubjectAlternativeName. 
+ """ + + ks_dir = mkdtemp(dir="/tmp") + ks_path = os.path.join(ks_dir, "test.keystore.jks") + csr_path = os.path.join(ks_dir, "test.kafka.csr") + crt_path = os.path.join(ks_dir, "test.kafka.crt") + + self.runcmd("keytool -genkeypair -alias kafka -keyalg RSA -keysize 2048 -keystore %s -storepass %s -storetype JKS -keypass %s -dname CN=systemtest -ext SAN=DNS:%s -startdate %s" % (ks_path, self.keystore_passwd, self.key_passwd, self.hostname(node), self.startdate)) + self.runcmd("keytool -certreq -keystore %s -storepass %s -storetype JKS -keypass %s -alias kafka -file %s" % (ks_path, self.keystore_passwd, self.key_passwd, csr_path)) + self.runcmd("keytool -gencert -keystore %s -storepass %s -storetype JKS -alias ca -infile %s -outfile %s -dname CN=systemtest -ext SAN=DNS:%s -startdate %s" % (self.ca_jks_path, self.ca_passwd, csr_path, crt_path, self.hostname(node), self.startdate)) + self.runcmd("keytool -importcert -keystore %s -storepass %s -storetype JKS -alias ca -file %s -noprompt" % (ks_path, self.keystore_passwd, self.ca_crt_path)) + self.runcmd("keytool -importcert -keystore %s -storepass %s -storetype JKS -keypass %s -alias kafka -file %s -noprompt" % (ks_path, self.keystore_passwd, self.key_passwd, crt_path)) + node.account.copy_to(ks_path, SecurityConfig.KEYSTORE_PATH) + + # generate ZooKeeper client TLS config file for encryption-only (no client cert) use case + str = """zookeeper.clientCnxnSocket=org.apache.zookeeper.ClientCnxnSocketNetty +zookeeper.ssl.client.enable=true +zookeeper.ssl.truststore.location=%s +zookeeper.ssl.truststore.password=%s +""" % (SecurityConfig.TRUSTSTORE_PATH, self.truststore_passwd) + node.account.create_file(SecurityConfig.ZK_CLIENT_TLS_ENCRYPT_ONLY_CONFIG_PATH, str) + + # also generate ZooKeeper client TLS config file for mutual authentication use case + str = """zookeeper.clientCnxnSocket=org.apache.zookeeper.ClientCnxnSocketNetty +zookeeper.ssl.client.enable=true +zookeeper.ssl.truststore.location=%s +zookeeper.ssl.truststore.password=%s +zookeeper.ssl.keystore.location=%s +zookeeper.ssl.keystore.password=%s +""" % (SecurityConfig.TRUSTSTORE_PATH, self.truststore_passwd, SecurityConfig.KEYSTORE_PATH, self.keystore_passwd) + node.account.create_file(SecurityConfig.ZK_CLIENT_MUTUAL_AUTH_CONFIG_PATH, str) + + rmtree(ks_dir) + + def hostname(self, node): + """ Hostname which may be overridden for testing validation failures + """ + return node.account.hostname + + def runcmd(self, cmd): + if self.logger: + self.logger.log(logging.DEBUG, cmd) + proc = subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) + stdout, stderr = proc.communicate() + + if proc.returncode != 0: + raise RuntimeError("Command '%s' returned non-zero exit status %d: %s" % (cmd, proc.returncode, stdout)) + + +class SecurityConfig(TemplateRenderer): + + PLAINTEXT = 'PLAINTEXT' + SSL = 'SSL' + SASL_PLAINTEXT = 'SASL_PLAINTEXT' + SASL_SSL = 'SASL_SSL' + SASL_SECURITY_PROTOCOLS = [SASL_PLAINTEXT, SASL_SSL] + SSL_SECURITY_PROTOCOLS = [SSL, SASL_SSL] + SASL_MECHANISM_GSSAPI = 'GSSAPI' + SASL_MECHANISM_PLAIN = 'PLAIN' + SASL_MECHANISM_SCRAM_SHA_256 = 'SCRAM-SHA-256' + SASL_MECHANISM_SCRAM_SHA_512 = 'SCRAM-SHA-512' + SCRAM_CLIENT_USER = "kafka-client" + SCRAM_CLIENT_PASSWORD = "client-secret" + SCRAM_BROKER_USER = "kafka-broker" + SCRAM_BROKER_PASSWORD = "broker-secret" + CONFIG_DIR = "/mnt/security" + KEYSTORE_PATH = "/mnt/security/test.keystore.jks" + TRUSTSTORE_PATH = "/mnt/security/test.truststore.jks" + ZK_CLIENT_TLS_ENCRYPT_ONLY_CONFIG_PATH = 
"/mnt/security/zk_client_tls_encrypt_only_config.properties" + ZK_CLIENT_MUTUAL_AUTH_CONFIG_PATH = "/mnt/security/zk_client_mutual_auth_config.properties" + JAAS_CONF_PATH = "/mnt/security/jaas.conf" + # allows admin client to connect with broker credentials to create User SCRAM credentials + ADMIN_CLIENT_AS_BROKER_JAAS_CONF_PATH = "/mnt/security/admin_client_as_broker_jaas.conf" + KRB5CONF_PATH = "/mnt/security/krb5.conf" + KEYTAB_PATH = "/mnt/security/keytab" + + # This is initialized only when the first instance of SecurityConfig is created + ssl_stores = None + + def __init__(self, context, security_protocol=None, interbroker_security_protocol=None, + client_sasl_mechanism=SASL_MECHANISM_GSSAPI, interbroker_sasl_mechanism=SASL_MECHANISM_GSSAPI, + zk_sasl=False, zk_tls=False, template_props="", static_jaas_conf=True, jaas_override_variables=None, + listener_security_config=ListenerSecurityConfig(), tls_version=None, + serves_controller_sasl_mechanism=None, # KRaft Controller does this + serves_intercontroller_sasl_mechanism=None, # KRaft Controller does this + uses_controller_sasl_mechanism=None, # communication to KRaft Controller (broker and controller both do this) + kraft_tls=False): + """ + Initialize the security properties for the node and copy + keystore and truststore to the remote node if the transport protocol + is SSL. If security_protocol is None, the protocol specified in the + template properties file is used. If no protocol is specified in the + template properties either, PLAINTEXT is used as default. + """ + + self.context = context + if not SecurityConfig.ssl_stores: + # This generates keystore/trustore files in a local scratch directory which gets + # automatically destroyed after the test is run + # Creating within the scratch directory allows us to run tests in parallel without fear of collision + SecurityConfig.ssl_stores = SslStores(context.local_scratch_dir, context.logger) + SecurityConfig.ssl_stores.generate_ca() + SecurityConfig.ssl_stores.generate_truststore() + + if security_protocol is None: + security_protocol = self.get_property('security.protocol', template_props) + if security_protocol is None: + security_protocol = SecurityConfig.PLAINTEXT + elif security_protocol not in [SecurityConfig.PLAINTEXT, SecurityConfig.SSL, SecurityConfig.SASL_PLAINTEXT, SecurityConfig.SASL_SSL]: + raise Exception("Invalid security.protocol in template properties: " + security_protocol) + + if interbroker_security_protocol is None: + interbroker_security_protocol = security_protocol + self.interbroker_security_protocol = interbroker_security_protocol + serves_kraft_sasl = [] + if serves_controller_sasl_mechanism is not None: + serves_kraft_sasl += [serves_controller_sasl_mechanism] + if serves_intercontroller_sasl_mechanism is not None: + serves_kraft_sasl += [serves_intercontroller_sasl_mechanism] + self.serves_kraft_sasl = set(serves_kraft_sasl) + uses_kraft_sasl = [] + if uses_controller_sasl_mechanism is not None: + uses_kraft_sasl += [uses_controller_sasl_mechanism] + self.uses_kraft_sasl = set(uses_kraft_sasl) + + self.zk_sasl = zk_sasl + self.zk_tls = zk_tls + self.static_jaas_conf = static_jaas_conf + self.listener_security_config = listener_security_config + self.additional_sasl_mechanisms = [] + self.properties = { + 'security.protocol' : security_protocol, + 'ssl.keystore.location' : SecurityConfig.KEYSTORE_PATH, + 'ssl.keystore.password' : SecurityConfig.ssl_stores.keystore_passwd, + 'ssl.key.password' : SecurityConfig.ssl_stores.key_passwd, + 
'ssl.truststore.location' : SecurityConfig.TRUSTSTORE_PATH, + 'ssl.truststore.password' : SecurityConfig.ssl_stores.truststore_passwd, + 'ssl.endpoint.identification.algorithm' : 'HTTPS', + 'sasl.mechanism' : client_sasl_mechanism, + 'sasl.mechanism.inter.broker.protocol' : interbroker_sasl_mechanism, + 'sasl.kerberos.service.name' : 'kafka' + } + self.kraft_tls = kraft_tls + + if tls_version is not None: + self.properties.update({'tls.version' : tls_version}) + + self.properties.update(self.listener_security_config.client_listener_overrides) + self.jaas_override_variables = jaas_override_variables or {} + + self.calc_has_sasl() + self.calc_has_ssl() + + def calc_has_sasl(self): + self.has_sasl = self.is_sasl(self.properties['security.protocol']) \ + or self.is_sasl(self.interbroker_security_protocol) \ + or self.zk_sasl \ + or self.serves_kraft_sasl or self.uses_kraft_sasl + + def calc_has_ssl(self): + self.has_ssl = self.is_ssl(self.properties['security.protocol']) \ + or self.is_ssl(self.interbroker_security_protocol) \ + or self.zk_tls \ + or self.kraft_tls + + def client_config(self, template_props="", node=None, jaas_override_variables=None, + use_inter_broker_mechanism_for_client = False): + # If node is not specified, use static jaas config which will be created later. + # Otherwise use static JAAS configuration files with SASL_SSL and sasl.jaas.config + # property with SASL_PLAINTEXT so that both code paths are tested by existing tests. + # Note that this is an arbitrary choice and it is possible to run all tests with + # either static or dynamic jaas config files if required. + static_jaas_conf = node is None or (self.has_sasl and self.has_ssl) + if use_inter_broker_mechanism_for_client: + client_sasl_mechanism_to_use = self.interbroker_sasl_mechanism + security_protocol_to_use = self.interbroker_security_protocol + else: + # csv is supported here, but client configs only supports a single mechanism, + # so arbitrarily take the first one defined in case it has multiple values + client_sasl_mechanism_to_use = self.client_sasl_mechanism.split(',')[0].strip() + security_protocol_to_use = self.security_protocol + + return SecurityConfig(self.context, security_protocol_to_use, + client_sasl_mechanism=client_sasl_mechanism_to_use, + template_props=template_props, + static_jaas_conf=static_jaas_conf, + jaas_override_variables=jaas_override_variables, + listener_security_config=self.listener_security_config, + tls_version=self.tls_version) + + def enable_sasl(self): + self.has_sasl = True + + def enable_ssl(self): + self.has_ssl = True + + def enable_security_protocol(self, security_protocol, sasl_mechanism = None): + self.has_sasl = self.has_sasl or self.is_sasl(security_protocol) + if sasl_mechanism is not None: + self.additional_sasl_mechanisms.append(sasl_mechanism) + self.has_ssl = self.has_ssl or self.is_ssl(security_protocol) + + def setup_ssl(self, node): + node.account.ssh("mkdir -p %s" % SecurityConfig.CONFIG_DIR, allow_fail=False) + node.account.copy_to(SecurityConfig.ssl_stores.truststore_path, SecurityConfig.TRUSTSTORE_PATH) + SecurityConfig.ssl_stores.generate_and_copy_keystore(node) + + def setup_sasl(self, node): + node.account.ssh("mkdir -p %s" % SecurityConfig.CONFIG_DIR, allow_fail=False) + jaas_conf_file = "jaas.conf" + java_version = node.account.ssh_capture("java -version") + + jaas_conf = None + if 'sasl.jaas.config' not in self.properties: + jaas_conf = self.render_jaas_config( + jaas_conf_file, + { + 'node': node, + 'is_ibm_jdk': any('IBM' in line for line in 
java_version), + 'SecurityConfig': SecurityConfig, + 'client_sasl_mechanism': self.client_sasl_mechanism, + 'enabled_sasl_mechanisms': self.enabled_sasl_mechanisms + } + ) + else: + jaas_conf = self.properties['sasl.jaas.config'] + + if self.static_jaas_conf: + node.account.create_file(SecurityConfig.JAAS_CONF_PATH, jaas_conf) + node.account.create_file(SecurityConfig.ADMIN_CLIENT_AS_BROKER_JAAS_CONF_PATH, + self.render_jaas_config( + "admin_client_as_broker_jaas.conf", + { + 'node': node, + 'is_ibm_jdk': any('IBM' in line for line in java_version), + 'SecurityConfig': SecurityConfig, + 'client_sasl_mechanism': self.client_sasl_mechanism, + 'enabled_sasl_mechanisms': self.enabled_sasl_mechanisms + } + )) + + elif 'sasl.jaas.config' not in self.properties: + self.properties['sasl.jaas.config'] = jaas_conf.replace("\n", " \\\n") + if self.has_sasl_kerberos: + node.account.copy_to(MiniKdc.LOCAL_KEYTAB_FILE, SecurityConfig.KEYTAB_PATH) + node.account.copy_to(MiniKdc.LOCAL_KRB5CONF_FILE, SecurityConfig.KRB5CONF_PATH) + + def render_jaas_config(self, jaas_conf_file, config_variables): + """ + Renders the JAAS config file contents + + :param jaas_conf_file: name of the JAAS config template file + :param config_variables: dict of variables used in the template + :return: the rendered template string + """ + variables = config_variables.copy() + variables.update(self.jaas_override_variables) # override variables + return self.render(jaas_conf_file, **variables) + + def setup_node(self, node): + if self.has_ssl: + self.setup_ssl(node) + + if self.has_sasl: + self.setup_sasl(node) + + if java_version(node) <= 11 and self.properties.get('tls.version') == 'TLSv1.3': + self.properties.update({'tls.version': 'TLSv1.2'}) + + def clean_node(self, node): + if self.security_protocol != SecurityConfig.PLAINTEXT: + node.account.ssh("rm -rf %s" % SecurityConfig.CONFIG_DIR, allow_fail=False) + + def get_property(self, prop_name, template_props=""): + """ + Get property value from the string representation of + a properties file. 
+ """ + value = None + for line in template_props.split("\n"): + items = line.split("=") + if len(items) == 2 and items[0].strip() == prop_name: + value = str(items[1].strip()) + return value + + def is_ssl(self, security_protocol): + return security_protocol in SecurityConfig.SSL_SECURITY_PROTOCOLS + + def is_sasl(self, security_protocol): + return security_protocol in SecurityConfig.SASL_SECURITY_PROTOCOLS + + def is_sasl_scram(self, sasl_mechanism): + return sasl_mechanism == SecurityConfig.SASL_MECHANISM_SCRAM_SHA_256 or sasl_mechanism == SecurityConfig.SASL_MECHANISM_SCRAM_SHA_512 + + @property + def security_protocol(self): + return self.properties['security.protocol'] + + @property + def tls_version(self): + return self.properties.get('tls.version') + + @property + def client_sasl_mechanism(self): + return self.properties['sasl.mechanism'] + + @property + def interbroker_sasl_mechanism(self): + return self.properties['sasl.mechanism.inter.broker.protocol'] + + @property + def enabled_sasl_mechanisms(self): + """ + :return: all the SASL mechanisms in use, including for brokers, clients, controllers, and ZooKeeper + """ + sasl_mechanisms = [] + if self.is_sasl(self.security_protocol): + # .csv is supported so be sure to account for that possibility + sasl_mechanisms += [mechanism.strip() for mechanism in self.client_sasl_mechanism.split(',')] + if self.is_sasl(self.interbroker_security_protocol): + sasl_mechanisms += [self.interbroker_sasl_mechanism] + if self.serves_kraft_sasl: + sasl_mechanisms += list(self.serves_kraft_sasl) + if self.uses_kraft_sasl: + sasl_mechanisms += list(self.uses_kraft_sasl) + if self.zk_sasl: + sasl_mechanisms += [SecurityConfig.SASL_MECHANISM_GSSAPI] + sasl_mechanisms.extend(self.additional_sasl_mechanisms) + return set(sasl_mechanisms) + + @property + def has_sasl_kerberos(self): + return self.has_sasl and (SecurityConfig.SASL_MECHANISM_GSSAPI in self.enabled_sasl_mechanisms) + + @property + def kafka_opts(self): + if self.has_sasl: + if self.static_jaas_conf: + return "\"-Djava.security.auth.login.config=%s -Djava.security.krb5.conf=%s\"" % (SecurityConfig.JAAS_CONF_PATH, SecurityConfig.KRB5CONF_PATH) + else: + return "\"-Djava.security.krb5.conf=%s\"" % SecurityConfig.KRB5CONF_PATH + else: + return "" + + def props(self, prefix=''): + """ + Return properties as string with line separators, optionally with a prefix. + This is used to append security config properties to + a properties file. + :param prefix: prefix to add to each property + :return: a string containing line-separated properties + """ + if self.security_protocol == SecurityConfig.PLAINTEXT: + return "" + if self.has_sasl and not self.static_jaas_conf and 'sasl.jaas.config' not in self.properties: + raise Exception("JAAS configuration property has not yet been initialized") + config_lines = (prefix + key + "=" + value for key, value in self.properties.items()) + # Extra blank lines ensure this can be appended/prepended safely + return "\n".join(itertools.chain([""], config_lines, [""])) + + def __str__(self): + """ + Return properties as a string with line separators. 
+ """ + return self.props() diff --git a/tests/kafkatest/services/security/templates/admin_client_as_broker_jaas.conf b/tests/kafkatest/services/security/templates/admin_client_as_broker_jaas.conf new file mode 100644 index 0000000..b53da11 --- /dev/null +++ b/tests/kafkatest/services/security/templates/admin_client_as_broker_jaas.conf @@ -0,0 +1,43 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more contributor license agreements. See the NOTICE + * file distributed with this work for additional information regarding copyright ownership. The ASF licenses this file + * to You under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the + * License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + + +KafkaClient { +{% if "GSSAPI" in enabled_sasl_mechanisms %} +{% if is_ibm_jdk %} + com.ibm.security.auth.module.Krb5LoginModule required debug=false + credsType=both + useKeytab="file:/mnt/security/keytab" + principal="kafka/{{ node.account.hostname }}@EXAMPLE.COM"; +{% else %} + com.sun.security.auth.module.Krb5LoginModule required debug=false + doNotPrompt=true + useKeyTab=true + storeKey=true + keyTab="/mnt/security/keytab" + principal="kafka/{{ node.account.hostname }}@EXAMPLE.COM"; +{% endif %} +{% endif %} +{% if "PLAIN" in enabled_sasl_mechanisms %} + org.apache.kafka.common.security.plain.PlainLoginModule required + username="kafka" + password="kafka-secret" + user_client="client-secret" + user_kafka="kafka-secret"; +{% endif %} +{% if "SCRAM-SHA-256" in client_sasl_mechanism or "SCRAM-SHA-512" in client_sasl_mechanism %} + org.apache.kafka.common.security.scram.ScramLoginModule required + username="{{ SecurityConfig.SCRAM_BROKER_USER }}" + password="{{ SecurityConfig.SCRAM_BROKER_PASSWORD }}"; +{% endif %} +}; diff --git a/tests/kafkatest/services/security/templates/jaas.conf b/tests/kafkatest/services/security/templates/jaas.conf new file mode 100644 index 0000000..3d6c93e --- /dev/null +++ b/tests/kafkatest/services/security/templates/jaas.conf @@ -0,0 +1,108 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more contributor license agreements. See the NOTICE + * file distributed with this work for additional information regarding copyright ownership. The ASF licenses this file + * to You under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the + * License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. 
+ */ + + +{% if static_jaas_conf %} +KafkaClient { +{% endif %} +{% if "GSSAPI" in client_sasl_mechanism %} +{% if is_ibm_jdk %} + com.ibm.security.auth.module.Krb5LoginModule required debug=false + credsType=both + useKeytab="file:/mnt/security/keytab" + principal="client@EXAMPLE.COM"; +{% else %} + com.sun.security.auth.module.Krb5LoginModule required debug=false + doNotPrompt=true + useKeyTab=true + storeKey=true + keyTab="/mnt/security/keytab" + principal="client@EXAMPLE.COM"; +{% endif %} +{% elif client_sasl_mechanism == "PLAIN" %} + org.apache.kafka.common.security.plain.PlainLoginModule required + username="client" + password="client-secret"; +{% elif "SCRAM-SHA-256" in client_sasl_mechanism or "SCRAM-SHA-512" in client_sasl_mechanism %} + org.apache.kafka.common.security.scram.ScramLoginModule required + username="{{ SecurityConfig.SCRAM_CLIENT_USER }}" + password="{{ SecurityConfig.SCRAM_CLIENT_PASSWORD }}"; +{% endif %} + +{% if static_jaas_conf %} +}; + +KafkaServer { +{% if "GSSAPI" in enabled_sasl_mechanisms %} +{% if is_ibm_jdk %} + com.ibm.security.auth.module.Krb5LoginModule required debug=false + credsType=both + useKeytab="file:/mnt/security/keytab" + principal="kafka/{{ node.account.hostname }}@EXAMPLE.COM"; +{% else %} + com.sun.security.auth.module.Krb5LoginModule required debug=false + doNotPrompt=true + useKeyTab=true + storeKey=true + keyTab="/mnt/security/keytab" + principal="kafka/{{ node.account.hostname }}@EXAMPLE.COM"; +{% endif %} +{% endif %} +{% if "PLAIN" in enabled_sasl_mechanisms %} + org.apache.kafka.common.security.plain.PlainLoginModule required + username="kafka" + password="kafka-secret" + user_client="client-secret" + user_kafka="kafka-secret"; +{% endif %} +{% if "SCRAM-SHA-256" in client_sasl_mechanism or "SCRAM-SHA-512" in client_sasl_mechanism %} + org.apache.kafka.common.security.scram.ScramLoginModule required + username="{{ SecurityConfig.SCRAM_BROKER_USER }}" + password="{{ SecurityConfig.SCRAM_BROKER_PASSWORD }}"; +{% endif %} +}; + +{% if zk_sasl %} +Client { +{% if is_ibm_jdk %} + com.ibm.security.auth.module.Krb5LoginModule required debug=false + credsType=both + useKeytab="file:/mnt/security/keytab" + principal="zkclient@EXAMPLE.COM"; +{% else %} + com.sun.security.auth.module.Krb5LoginModule required + useKeyTab=true + keyTab="/mnt/security/keytab" + storeKey=true + useTicketCache=false + principal="zkclient@EXAMPLE.COM"; +{% endif %} +}; + +Server { +{% if is_ibm_jdk %} + com.ibm.security.auth.module.Krb5LoginModule required debug=false + credsType=both + useKeyTab="file:/mnt/security/keytab" + principal="zookeeper/{{ node.account.hostname }}@EXAMPLE.COM"; +{% else %} + com.sun.security.auth.module.Krb5LoginModule required + useKeyTab=true + keyTab="/mnt/security/keytab" + storeKey=true + useTicketCache=false + principal="zookeeper/{{ node.account.hostname }}@EXAMPLE.COM"; +{% endif %} +}; +{% endif %} +{% endif %} diff --git a/tests/kafkatest/services/security/templates/minikdc.properties b/tests/kafkatest/services/security/templates/minikdc.properties new file mode 100644 index 0000000..0990a33 --- /dev/null +++ b/tests/kafkatest/services/security/templates/minikdc.properties @@ -0,0 +1,17 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. 
+# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +kdc.bind.address=0.0.0.0 + diff --git a/tests/kafkatest/services/streams.py b/tests/kafkatest/services/streams.py new file mode 100644 index 0000000..f4f6a6a --- /dev/null +++ b/tests/kafkatest/services/streams.py @@ -0,0 +1,783 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os.path +import signal +from . import streams_property +from . import consumer_property +from ducktape.services.service import Service +from ducktape.utils.util import wait_until +from kafkatest.directory_layout.kafka_path import KafkaPathResolverMixin +from kafkatest.services.kafka import KafkaConfig +from kafkatest.services.monitor.jmx import JmxMixin +from kafkatest.version import LATEST_0_10_0, LATEST_0_10_1 + +STATE_DIR = "state.dir" + +class StreamsTestBaseService(KafkaPathResolverMixin, JmxMixin, Service): + """Base class for Streams Test services providing some common settings and functionality""" + + PERSISTENT_ROOT = "/mnt/streams" + + # The log file contains normal log4j logs written using a file appender. 
stdout and stderr are handled separately + CONFIG_FILE = os.path.join(PERSISTENT_ROOT, "streams.properties") + LOG_FILE = os.path.join(PERSISTENT_ROOT, "streams.log") + STDOUT_FILE = os.path.join(PERSISTENT_ROOT, "streams.stdout") + STDERR_FILE = os.path.join(PERSISTENT_ROOT, "streams.stderr") + JMX_LOG_FILE = os.path.join(PERSISTENT_ROOT, "jmx_tool.log") + JMX_ERR_FILE = os.path.join(PERSISTENT_ROOT, "jmx_tool.err.log") + LOG4J_CONFIG_FILE = os.path.join(PERSISTENT_ROOT, "tools-log4j.properties") + PID_FILE = os.path.join(PERSISTENT_ROOT, "streams.pid") + + CLEAN_NODE_ENABLED = True + + logs = { + "streams_config": { + "path": CONFIG_FILE, + "collect_default": True}, + "streams_config.1": { + "path": CONFIG_FILE + ".1", + "collect_default": True}, + "streams_config.0-1": { + "path": CONFIG_FILE + ".0-1", + "collect_default": True}, + "streams_config.1-1": { + "path": CONFIG_FILE + ".1-1", + "collect_default": True}, + "streams_log": { + "path": LOG_FILE, + "collect_default": True}, + "streams_stdout": { + "path": STDOUT_FILE, + "collect_default": True}, + "streams_stderr": { + "path": STDERR_FILE, + "collect_default": True}, + "streams_log.1": { + "path": LOG_FILE + ".1", + "collect_default": True}, + "streams_stdout.1": { + "path": STDOUT_FILE + ".1", + "collect_default": True}, + "streams_stderr.1": { + "path": STDERR_FILE + ".1", + "collect_default": True}, + "streams_log.2": { + "path": LOG_FILE + ".2", + "collect_default": True}, + "streams_stdout.2": { + "path": STDOUT_FILE + ".2", + "collect_default": True}, + "streams_stderr.2": { + "path": STDERR_FILE + ".2", + "collect_default": True}, + "streams_log.3": { + "path": LOG_FILE + ".3", + "collect_default": True}, + "streams_stdout.3": { + "path": STDOUT_FILE + ".3", + "collect_default": True}, + "streams_stderr.3": { + "path": STDERR_FILE + ".3", + "collect_default": True}, + "streams_log.0-1": { + "path": LOG_FILE + ".0-1", + "collect_default": True}, + "streams_stdout.0-1": { + "path": STDOUT_FILE + ".0-1", + "collect_default": True}, + "streams_stderr.0-1": { + "path": STDERR_FILE + ".0-1", + "collect_default": True}, + "streams_log.0-2": { + "path": LOG_FILE + ".0-2", + "collect_default": True}, + "streams_stdout.0-2": { + "path": STDOUT_FILE + ".0-2", + "collect_default": True}, + "streams_stderr.0-2": { + "path": STDERR_FILE + ".0-2", + "collect_default": True}, + "streams_log.0-3": { + "path": LOG_FILE + ".0-3", + "collect_default": True}, + "streams_stdout.0-3": { + "path": STDOUT_FILE + ".0-3", + "collect_default": True}, + "streams_stderr.0-3": { + "path": STDERR_FILE + ".0-3", + "collect_default": True}, + "streams_log.0-4": { + "path": LOG_FILE + ".0-4", + "collect_default": True}, + "streams_stdout.0-4": { + "path": STDOUT_FILE + ".0-4", + "collect_default": True}, + "streams_stderr.0-4": { + "path": STDERR_FILE + ".0-4", + "collect_default": True}, + "streams_log.0-5": { + "path": LOG_FILE + ".0-5", + "collect_default": True}, + "streams_stdout.0-5": { + "path": STDOUT_FILE + ".0-5", + "collect_default": True}, + "streams_stderr.0-5": { + "path": STDERR_FILE + ".0-5", + "collect_default": True}, + "streams_log.0-6": { + "path": LOG_FILE + ".0-6", + "collect_default": True}, + "streams_stdout.0-6": { + "path": STDOUT_FILE + ".0-6", + "collect_default": True}, + "streams_stderr.0-6": { + "path": STDERR_FILE + ".0-6", + "collect_default": True}, + "streams_log.1-1": { + "path": LOG_FILE + ".1-1", + "collect_default": True}, + "streams_stdout.1-1": { + "path": STDOUT_FILE + ".1-1", + "collect_default": True}, + 
"streams_stderr.1-1": { + "path": STDERR_FILE + ".1-1", + "collect_default": True}, + "streams_log.1-2": { + "path": LOG_FILE + ".1-2", + "collect_default": True}, + "streams_stdout.1-2": { + "path": STDOUT_FILE + ".1-2", + "collect_default": True}, + "streams_stderr.1-2": { + "path": STDERR_FILE + ".1-2", + "collect_default": True}, + "streams_log.1-3": { + "path": LOG_FILE + ".1-3", + "collect_default": True}, + "streams_stdout.1-3": { + "path": STDOUT_FILE + ".1-3", + "collect_default": True}, + "streams_stderr.1-3": { + "path": STDERR_FILE + ".1-3", + "collect_default": True}, + "streams_log.1-4": { + "path": LOG_FILE + ".1-4", + "collect_default": True}, + "streams_stdout.1-4": { + "path": STDOUT_FILE + ".1-4", + "collect_default": True}, + "streams_stderr.1-4": { + "path": STDERR_FILE + ".1-4", + "collect_default": True}, + "streams_log.1-5": { + "path": LOG_FILE + ".1-5", + "collect_default": True}, + "streams_stdout.1-5": { + "path": STDOUT_FILE + ".1-5", + "collect_default": True}, + "streams_stderr.1-5": { + "path": STDERR_FILE + ".1-5", + "collect_default": True}, + "streams_log.1-6": { + "path": LOG_FILE + ".1-6", + "collect_default": True}, + "streams_stdout.1-6": { + "path": STDOUT_FILE + ".1-6", + "collect_default": True}, + "streams_stderr.1-6": { + "path": STDERR_FILE + ".1-6", + "collect_default": True}, + "jmx_log": { + "path": JMX_LOG_FILE, + "collect_default": True}, + "jmx_err": { + "path": JMX_ERR_FILE, + "collect_default": True}, + } + + def __init__(self, test_context, kafka, streams_class_name, user_test_args1, user_test_args2=None, user_test_args3=None, user_test_args4=None): + Service.__init__(self, test_context, num_nodes=1) + self.kafka = kafka + self.args = {'streams_class_name': streams_class_name, + 'user_test_args1': user_test_args1, + 'user_test_args2': user_test_args2, + 'user_test_args3': user_test_args3, + 'user_test_args4': user_test_args4} + self.log_level = "DEBUG" + + @property + def node(self): + return self.nodes[0] + + @property + def expectedMessage(self): + return 'StreamsTest instance started' + + def pids(self, node): + try: + pids = [pid for pid in node.account.ssh_capture("cat " + self.PID_FILE, callback=str)] + return [int(pid) for pid in pids] + except Exception as exception: + self.logger.debug(str(exception)) + return [] + + def stop_nodes(self, clean_shutdown=True): + for node in self.nodes: + self.stop_node(node, clean_shutdown) + + def stop_node(self, node, clean_shutdown=True): + self.logger.info((clean_shutdown and "Cleanly" or "Forcibly") + " stopping Streams Test on " + str(node.account)) + pids = self.pids(node) + sig = signal.SIGTERM if clean_shutdown else signal.SIGKILL + + for pid in pids: + node.account.signal(pid, sig, allow_fail=True) + if clean_shutdown: + for pid in pids: + wait_until(lambda: not node.account.alive(pid), timeout_sec=120, err_msg="Streams Test process on " + str(node.account) + " took too long to exit") + + node.account.ssh("rm -f " + self.PID_FILE, allow_fail=False) + + def restart(self): + # We don't want to do any clean up here, just restart the process. + for node in self.nodes: + self.logger.info("Restarting Kafka Streams on " + str(node.account)) + self.stop_node(node) + self.start_node(node) + + + def abortThenRestart(self): + # We don't want to do any clean up here, just abort then restart the process. The running service is killed immediately. 
+ for node in self.nodes: + self.logger.info("Aborting Kafka Streams on " + str(node.account)) + self.stop_node(node, False) + self.logger.info("Restarting Kafka Streams on " + str(node.account)) + self.start_node(node) + + def wait(self, timeout_sec=1440): + for node in self.nodes: + self.wait_node(node, timeout_sec) + + def wait_node(self, node, timeout_sec=None): + for pid in self.pids(node): + wait_until(lambda: not node.account.alive(pid), timeout_sec=timeout_sec, err_msg="Streams Test process on " + str(node.account) + " took too long to exit") + + def clean_node(self, node): + node.account.kill_process("streams", clean_shutdown=False, allow_fail=True) + if self.CLEAN_NODE_ENABLED: + node.account.ssh("rm -rf " + self.PERSISTENT_ROOT, allow_fail=False) + + def start_cmd(self, node): + args = self.args.copy() + args['config_file'] = self.CONFIG_FILE + args['stdout'] = self.STDOUT_FILE + args['stderr'] = self.STDERR_FILE + args['pidfile'] = self.PID_FILE + args['log4j'] = self.LOG4J_CONFIG_FILE + args['kafka_run_class'] = self.path.script("kafka-run-class.sh", node) + + cmd = "( export KAFKA_LOG4J_OPTS=\"-Dlog4j.configuration=file:%(log4j)s\"; " \ + "INCLUDE_TEST_JARS=true %(kafka_run_class)s %(streams_class_name)s " \ + " %(config_file)s %(user_test_args1)s %(user_test_args2)s %(user_test_args3)s" \ + " %(user_test_args4)s & echo $! >&3 ) 1>> %(stdout)s 2>> %(stderr)s 3> %(pidfile)s" % args + + self.logger.info("Executing streams cmd: " + cmd) + + return cmd + + def prop_file(self): + cfg = KafkaConfig(**{streams_property.STATE_DIR: self.PERSISTENT_ROOT, streams_property.KAFKA_SERVERS: self.kafka.bootstrap_servers()}) + return cfg.render() + + def start_node(self, node): + node.account.mkdirs(self.PERSISTENT_ROOT) + prop_file = self.prop_file() + node.account.create_file(self.CONFIG_FILE, prop_file) + node.account.create_file(self.LOG4J_CONFIG_FILE, self.render('tools_log4j.properties', log_file=self.LOG_FILE)) + + self.logger.info("Starting StreamsTest process on " + str(node.account)) + with node.account.monitor_log(self.STDOUT_FILE) as monitor: + node.account.ssh(self.start_cmd(node)) + monitor.wait_until(self.expectedMessage, timeout_sec=60, err_msg="Never saw message indicating StreamsTest finished startup on " + str(node.account)) + + if len(self.pids(node)) == 0: + raise RuntimeError("No process ids recorded") + + +class StreamsSmokeTestBaseService(StreamsTestBaseService): + """Base class for Streams Smoke Test services providing some common settings and functionality""" + + def __init__(self, test_context, kafka, command, processing_guarantee = 'at_least_once', num_threads = 3, replication_factor = 3): + super(StreamsSmokeTestBaseService, self).__init__(test_context, + kafka, + "org.apache.kafka.streams.tests.StreamsSmokeTest", + command) + self.NUM_THREADS = num_threads + self.PROCESSING_GUARANTEE = processing_guarantee + self.KAFKA_STREAMS_VERSION = "" + self.UPGRADE_FROM = None + self.REPLICATION_FACTOR = replication_factor + + def set_version(self, kafka_streams_version): + self.KAFKA_STREAMS_VERSION = kafka_streams_version + + def set_upgrade_from(self, upgrade_from): + self.UPGRADE_FROM = upgrade_from + + def prop_file(self): + properties = {streams_property.STATE_DIR: self.PERSISTENT_ROOT, + streams_property.KAFKA_SERVERS: self.kafka.bootstrap_servers(), + streams_property.PROCESSING_GUARANTEE: self.PROCESSING_GUARANTEE, + streams_property.NUM_THREADS: self.NUM_THREADS, + "replication.factor": self.REPLICATION_FACTOR, + "num.standby.replicas": 2, + 
"buffered.records.per.partition": 100, + "commit.interval.ms": 1000, + "auto.offset.reset": "earliest", + "acks": "all", + "acceptable.recovery.lag": "9223372036854775807", # enable a one-shot assignment + "session.timeout.ms": "10000" # set back to 10s for tests. See KIP-735 + } + + if self.UPGRADE_FROM is not None: + properties['upgrade.from'] = self.UPGRADE_FROM + + cfg = KafkaConfig(**properties) + return cfg.render() + + def start_cmd(self, node): + args = self.args.copy() + args['config_file'] = self.CONFIG_FILE + args['stdout'] = self.STDOUT_FILE + args['stderr'] = self.STDERR_FILE + args['pidfile'] = self.PID_FILE + args['log4j'] = self.LOG4J_CONFIG_FILE + args['version'] = self.KAFKA_STREAMS_VERSION + args['kafka_run_class'] = self.path.script("kafka-run-class.sh", node) + + cmd = "( export KAFKA_LOG4J_OPTS=\"-Dlog4j.configuration=file:%(log4j)s\";" \ + " INCLUDE_TEST_JARS=true UPGRADE_KAFKA_STREAMS_TEST_VERSION=%(version)s" \ + " %(kafka_run_class)s %(streams_class_name)s" \ + " %(config_file)s %(user_test_args1)s" \ + " & echo $! >&3 ) " \ + "1>> %(stdout)s 2>> %(stderr)s 3> %(pidfile)s" % args + + self.logger.info("Executing streams cmd: " + cmd) + + return cmd + +class StreamsEosTestBaseService(StreamsTestBaseService): + """Base class for Streams EOS Test services providing some common settings and functionality""" + + clean_node_enabled = True + + def __init__(self, test_context, kafka, processing_guarantee, command): + super(StreamsEosTestBaseService, self).__init__(test_context, + kafka, + "org.apache.kafka.streams.tests.StreamsEosTest", + command) + self.PROCESSING_GUARANTEE = processing_guarantee + + def prop_file(self): + properties = {streams_property.STATE_DIR: self.PERSISTENT_ROOT, + streams_property.KAFKA_SERVERS: self.kafka.bootstrap_servers(), + streams_property.PROCESSING_GUARANTEE: self.PROCESSING_GUARANTEE, + "acceptable.recovery.lag": "9223372036854775807", # enable a one-shot assignment + "session.timeout.ms": "10000" # set back to 10s for tests. See KIP-735 + } + + cfg = KafkaConfig(**properties) + return cfg.render() + + def clean_node(self, node): + if self.clean_node_enabled: + super(StreamsEosTestBaseService, self).clean_node(node) + + +class StreamsSmokeTestDriverService(StreamsSmokeTestBaseService): + def __init__(self, test_context, kafka): + super(StreamsSmokeTestDriverService, self).__init__(test_context, kafka, "run") + self.DISABLE_AUTO_TERMINATE = "" + + def disable_auto_terminate(self): + self.DISABLE_AUTO_TERMINATE = "disableAutoTerminate" + + def start_cmd(self, node): + args = self.args.copy() + args['config_file'] = self.CONFIG_FILE + args['stdout'] = self.STDOUT_FILE + args['stderr'] = self.STDERR_FILE + args['pidfile'] = self.PID_FILE + args['log4j'] = self.LOG4J_CONFIG_FILE + args['disable_auto_terminate'] = self.DISABLE_AUTO_TERMINATE + args['kafka_run_class'] = self.path.script("kafka-run-class.sh", node) + + cmd = "( export KAFKA_LOG4J_OPTS=\"-Dlog4j.configuration=file:%(log4j)s\"; " \ + "INCLUDE_TEST_JARS=true %(kafka_run_class)s %(streams_class_name)s " \ + " %(config_file)s %(user_test_args1)s %(disable_auto_terminate)s" \ + " & echo $! 
>&3 ) 1>> %(stdout)s 2>> %(stderr)s 3> %(pidfile)s" % args + + self.logger.info("Executing streams cmd: " + cmd) + + return cmd + +class StreamsSmokeTestJobRunnerService(StreamsSmokeTestBaseService): + def __init__(self, test_context, kafka, processing_guarantee, num_threads = 3, replication_factor = 3): + super(StreamsSmokeTestJobRunnerService, self).__init__(test_context, kafka, "process", processing_guarantee, num_threads, replication_factor) + +class StreamsEosTestDriverService(StreamsEosTestBaseService): + def __init__(self, test_context, kafka): + super(StreamsEosTestDriverService, self).__init__(test_context, kafka, "not-required", "run") + +class StreamsEosTestJobRunnerService(StreamsEosTestBaseService): + def __init__(self, test_context, kafka, processing_guarantee): + super(StreamsEosTestJobRunnerService, self).__init__(test_context, kafka, processing_guarantee, "process") + +class StreamsComplexEosTestJobRunnerService(StreamsEosTestBaseService): + def __init__(self, test_context, kafka, processing_guarantee): + super(StreamsComplexEosTestJobRunnerService, self).__init__(test_context, kafka, processing_guarantee, "process-complex") + +class StreamsEosTestVerifyRunnerService(StreamsEosTestBaseService): + def __init__(self, test_context, kafka): + super(StreamsEosTestVerifyRunnerService, self).__init__(test_context, kafka, "not-required", "verify") + + +class StreamsComplexEosTestVerifyRunnerService(StreamsEosTestBaseService): + def __init__(self, test_context, kafka): + super(StreamsComplexEosTestVerifyRunnerService, self).__init__(test_context, kafka, "not-required", "verify-complex") + + +class StreamsSmokeTestShutdownDeadlockService(StreamsSmokeTestBaseService): + def __init__(self, test_context, kafka): + super(StreamsSmokeTestShutdownDeadlockService, self).__init__(test_context, kafka, "close-deadlock-test") + + +class StreamsBrokerCompatibilityService(StreamsTestBaseService): + def __init__(self, test_context, kafka, processingMode): + super(StreamsBrokerCompatibilityService, self).__init__(test_context, + kafka, + "org.apache.kafka.streams.tests.BrokerCompatibilityTest", + processingMode) + + def prop_file(self): + properties = {streams_property.STATE_DIR: self.PERSISTENT_ROOT, + streams_property.KAFKA_SERVERS: self.kafka.bootstrap_servers(), + # the old broker (< 2.4) does not support configuration replication.factor=-1 + "replication.factor": 1, + "acceptable.recovery.lag": "9223372036854775807", # enable a one-shot assignment + "session.timeout.ms": "10000" # set back to 10s for tests. See KIP-735 + } + + cfg = KafkaConfig(**properties) + return cfg.render() + + +class StreamsBrokerDownResilienceService(StreamsTestBaseService): + def __init__(self, test_context, kafka, configs): + super(StreamsBrokerDownResilienceService, self).__init__(test_context, + kafka, + "org.apache.kafka.streams.tests.StreamsBrokerDownResilienceTest", + configs) + + def start_cmd(self, node): + args = self.args.copy() + args['config_file'] = self.CONFIG_FILE + args['stdout'] = self.STDOUT_FILE + args['stderr'] = self.STDERR_FILE + args['pidfile'] = self.PID_FILE + args['log4j'] = self.LOG4J_CONFIG_FILE + args['kafka_run_class'] = self.path.script("kafka-run-class.sh", node) + + cmd = "( export KAFKA_LOG4J_OPTS=\"-Dlog4j.configuration=file:%(log4j)s\"; " \ + "INCLUDE_TEST_JARS=true %(kafka_run_class)s %(streams_class_name)s " \ + " %(config_file)s %(user_test_args1)s %(user_test_args2)s %(user_test_args3)s" \ + " %(user_test_args4)s & echo $! 
>&3 ) 1>> %(stdout)s 2>> %(stderr)s 3> %(pidfile)s" % args + + self.logger.info("Executing: " + cmd) + + return cmd + + +class StreamsStandbyTaskService(StreamsTestBaseService): + def __init__(self, test_context, kafka, configs): + super(StreamsStandbyTaskService, self).__init__(test_context, + kafka, + "org.apache.kafka.streams.tests.StreamsStandByReplicaTest", + configs) + +class StreamsResetter(StreamsTestBaseService): + def __init__(self, test_context, kafka, topic, applicationId): + super(StreamsResetter, self).__init__(test_context, + kafka, + "kafka.tools.StreamsResetter", + "") + self.topic = topic + self.applicationId = applicationId + + @property + def expectedMessage(self): + return 'Done.' + + def start_cmd(self, node): + args = self.args.copy() + args['bootstrap.servers'] = self.kafka.bootstrap_servers() + args['stdout'] = self.STDOUT_FILE + args['stderr'] = self.STDERR_FILE + args['pidfile'] = self.PID_FILE + args['log4j'] = self.LOG4J_CONFIG_FILE + args['application.id'] = self.applicationId + args['input.topics'] = self.topic + args['kafka_run_class'] = self.path.script("kafka-run-class.sh", node) + + cmd = "(export KAFKA_LOG4J_OPTS=\"-Dlog4j.configuration=file:%(log4j)s\"; " \ + "%(kafka_run_class)s %(streams_class_name)s " \ + "--bootstrap-servers %(bootstrap.servers)s " \ + "--force " \ + "--application-id %(application.id)s " \ + "--input-topics %(input.topics)s " \ + "& echo $! >&3 ) " \ + "1>> %(stdout)s " \ + "2>> %(stderr)s " \ + "3> %(pidfile)s "% args + + self.logger.info("Executing: " + cmd) + + return cmd + + +class StreamsOptimizedUpgradeTestService(StreamsTestBaseService): + def __init__(self, test_context, kafka): + super(StreamsOptimizedUpgradeTestService, self).__init__(test_context, + kafka, + "org.apache.kafka.streams.tests.StreamsOptimizedTest", + "") + self.OPTIMIZED_CONFIG = 'none' + self.INPUT_TOPIC = None + self.AGGREGATION_TOPIC = None + self.REDUCE_TOPIC = None + self.JOIN_TOPIC = None + + def prop_file(self): + properties = {streams_property.STATE_DIR: self.PERSISTENT_ROOT, + streams_property.KAFKA_SERVERS: self.kafka.bootstrap_servers(), + 'topology.optimization': self.OPTIMIZED_CONFIG, + 'input.topic': self.INPUT_TOPIC, + 'aggregation.topic': self.AGGREGATION_TOPIC, + 'reduce.topic': self.REDUCE_TOPIC, + 'join.topic': self.JOIN_TOPIC, + "acceptable.recovery.lag": "9223372036854775807", # enable a one-shot assignment + "session.timeout.ms": "10000" # set back to 10s for tests. 
See KIP-735 + } + + + cfg = KafkaConfig(**properties) + return cfg.render() + + +class StreamsUpgradeTestJobRunnerService(StreamsTestBaseService): + def __init__(self, test_context, kafka): + super(StreamsUpgradeTestJobRunnerService, self).__init__(test_context, + kafka, + "org.apache.kafka.streams.tests.StreamsUpgradeTest", + "") + self.UPGRADE_FROM = None + self.UPGRADE_TO = None + self.extra_properties = {} + + def set_config(self, key, value): + self.extra_properties[key] = value + + def set_version(self, kafka_streams_version): + self.KAFKA_STREAMS_VERSION = kafka_streams_version + + def set_upgrade_from(self, upgrade_from): + self.UPGRADE_FROM = upgrade_from + + def set_upgrade_to(self, upgrade_to): + self.UPGRADE_TO = upgrade_to + + def prop_file(self): + properties = self.extra_properties.copy() + properties[streams_property.STATE_DIR] = self.PERSISTENT_ROOT + properties[streams_property.KAFKA_SERVERS] = self.kafka.bootstrap_servers() + + if self.UPGRADE_FROM is not None: + properties['upgrade.from'] = self.UPGRADE_FROM + if self.UPGRADE_TO == "future_version": + properties['test.future.metadata'] = "any_value" + + # Long.MAX_VALUE lets us do the assignment without a warmup + properties['acceptable.recovery.lag'] = "9223372036854775807" + properties["session.timeout.ms"] = "10000" # set back to 10s for tests. See KIP-735 + + cfg = KafkaConfig(**properties) + return cfg.render() + + def start_cmd(self, node): + args = self.args.copy() + + if self.KAFKA_STREAMS_VERSION == str(LATEST_0_10_0) or self.KAFKA_STREAMS_VERSION == str(LATEST_0_10_1): + args['zk'] = self.kafka.zk.connect_setting() + else: + args['zk'] = "" + args['config_file'] = self.CONFIG_FILE + args['stdout'] = self.STDOUT_FILE + args['stderr'] = self.STDERR_FILE + args['pidfile'] = self.PID_FILE + args['log4j'] = self.LOG4J_CONFIG_FILE + args['version'] = self.KAFKA_STREAMS_VERSION + args['kafka_run_class'] = self.path.script("kafka-run-class.sh", node) + + cmd = "( export KAFKA_LOG4J_OPTS=\"-Dlog4j.configuration=file:%(log4j)s\"; " \ + "INCLUDE_TEST_JARS=true UPGRADE_KAFKA_STREAMS_TEST_VERSION=%(version)s " \ + " %(kafka_run_class)s %(streams_class_name)s %(zk)s %(config_file)s " \ + " & echo $! >&3 ) 1>> %(stdout)s 2>> %(stderr)s 3> %(pidfile)s" % args + + self.logger.info("Executing: " + cmd) + + return cmd + + +class StreamsNamedRepartitionTopicService(StreamsTestBaseService): + def __init__(self, test_context, kafka): + super(StreamsNamedRepartitionTopicService, self).__init__(test_context, + kafka, + "org.apache.kafka.streams.tests.StreamsNamedRepartitionTest", + "") + self.ADD_ADDITIONAL_OPS = 'false' + self.INPUT_TOPIC = None + self.AGGREGATION_TOPIC = None + + def prop_file(self): + properties = {streams_property.STATE_DIR: self.PERSISTENT_ROOT, + streams_property.KAFKA_SERVERS: self.kafka.bootstrap_servers(), + 'input.topic': self.INPUT_TOPIC, + 'aggregation.topic': self.AGGREGATION_TOPIC, + 'add.operations': self.ADD_ADDITIONAL_OPS, + "acceptable.recovery.lag": "9223372036854775807", # enable a one-shot assignment + "session.timeout.ms": "10000" # set back to 10s for tests. 
See KIP-735 + } + + + cfg = KafkaConfig(**properties) + return cfg.render() + + +class StaticMemberTestService(StreamsTestBaseService): + def __init__(self, test_context, kafka, group_instance_id, num_threads): + super(StaticMemberTestService, self).__init__(test_context, + kafka, + "org.apache.kafka.streams.tests.StaticMemberTestClient", + "") + self.INPUT_TOPIC = None + self.GROUP_INSTANCE_ID = group_instance_id + self.NUM_THREADS = num_threads + def prop_file(self): + properties = {streams_property.STATE_DIR: self.PERSISTENT_ROOT, + streams_property.KAFKA_SERVERS: self.kafka.bootstrap_servers(), + streams_property.NUM_THREADS: self.NUM_THREADS, + consumer_property.GROUP_INSTANCE_ID: self.GROUP_INSTANCE_ID, + consumer_property.SESSION_TIMEOUT_MS: 60000, + 'input.topic': self.INPUT_TOPIC, + "acceptable.recovery.lag": "9223372036854775807", # enable a one-shot assignment + "session.timeout.ms": "10000" # set back to 10s for tests. See KIP-735 + } + + + cfg = KafkaConfig(**properties) + return cfg.render() + + +class CooperativeRebalanceUpgradeService(StreamsTestBaseService): + def __init__(self, test_context, kafka): + super(CooperativeRebalanceUpgradeService, self).__init__(test_context, + kafka, + "org.apache.kafka.streams.tests.StreamsUpgradeToCooperativeRebalanceTest", + "") + self.UPGRADE_FROM = None + # these properties will be overridden in test + self.SOURCE_TOPIC = None + self.SINK_TOPIC = None + self.TASK_DELIMITER = "#" + self.REPORT_INTERVAL = None + + self.standby_tasks = None + self.active_tasks = None + self.upgrade_phase = None + + def set_tasks(self, task_string): + label = "TASK-ASSIGNMENTS:" + task_string_substr = task_string[len(label):] + all_tasks = task_string_substr.split(self.TASK_DELIMITER) + self.active_tasks = set(all_tasks[0].split(",")) + if len(all_tasks) > 1: + self.standby_tasks = set(all_tasks[1].split(",")) + + def set_version(self, kafka_streams_version): + self.KAFKA_STREAMS_VERSION = kafka_streams_version + + def set_upgrade_phase(self, upgrade_phase): + self.upgrade_phase = upgrade_phase + + def start_cmd(self, node): + args = self.args.copy() + + if self.KAFKA_STREAMS_VERSION == str(LATEST_0_10_0) or self.KAFKA_STREAMS_VERSION == str(LATEST_0_10_1): + args['zk'] = self.kafka.zk.connect_setting() + else: + args['zk'] = "" + args['config_file'] = self.CONFIG_FILE + args['stdout'] = self.STDOUT_FILE + args['stderr'] = self.STDERR_FILE + args['pidfile'] = self.PID_FILE + args['log4j'] = self.LOG4J_CONFIG_FILE + args['version'] = self.KAFKA_STREAMS_VERSION + args['kafka_run_class'] = self.path.script("kafka-run-class.sh", node) + + cmd = "( export KAFKA_LOG4J_OPTS=\"-Dlog4j.configuration=file:%(log4j)s\"; " \ + "INCLUDE_TEST_JARS=true UPGRADE_KAFKA_STREAMS_TEST_VERSION=%(version)s " \ + " %(kafka_run_class)s %(streams_class_name)s %(zk)s %(config_file)s " \ + " & echo $! >&3 ) 1>> %(stdout)s 2>> %(stderr)s 3> %(pidfile)s" % args + + self.logger.info("Executing: " + cmd) + + return cmd + + def prop_file(self): + properties = {streams_property.STATE_DIR: self.PERSISTENT_ROOT, + streams_property.KAFKA_SERVERS: self.kafka.bootstrap_servers(), + 'source.topic': self.SOURCE_TOPIC, + 'sink.topic': self.SINK_TOPIC, + 'task.delimiter': self.TASK_DELIMITER, + 'report.interval': self.REPORT_INTERVAL, + "acceptable.recovery.lag": "9223372036854775807", # enable a one-shot assignment + "session.timeout.ms": "10000" # set back to 10s for tests. 
See KIP-735 + } + + if self.UPGRADE_FROM is not None: + properties['upgrade.from'] = self.UPGRADE_FROM + else: + try: + del properties['upgrade.from'] + except KeyError: + self.logger.info("Key 'upgrade.from' not there, better safe than sorry") + + if self.upgrade_phase is not None: + properties['upgrade.phase'] = self.upgrade_phase + + + cfg = KafkaConfig(**properties) + return cfg.render() diff --git a/tests/kafkatest/services/streams_property.py b/tests/kafkatest/services/streams_property.py new file mode 100644 index 0000000..8900adb --- /dev/null +++ b/tests/kafkatest/services/streams_property.py @@ -0,0 +1,23 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Define Streams configuration property names here. +""" + +STATE_DIR = "state.dir" +KAFKA_SERVERS = "bootstrap.servers" +NUM_THREADS = "num.stream.threads" +PROCESSING_GUARANTEE = "processing.guarantee" diff --git a/tests/kafkatest/services/templates/connect_log4j.properties b/tests/kafkatest/services/templates/connect_log4j.properties new file mode 100644 index 0000000..4894612 --- /dev/null +++ b/tests/kafkatest/services/templates/connect_log4j.properties @@ -0,0 +1,29 @@ +## +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
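As a usage sketch (not part of the patch): the streams services above assemble their `streams.properties` from the shared names in `streams_property.py` and hand them to `kafkatest.services.kafka.KafkaConfig` for rendering. The plain `key=value` join below is only meant to illustrate the on-disk shape that `prop_file()` produces; the import path follows the package layout added by this diff, and the hostname in the comment is made up.

```python
# Illustration only: mirrors how prop_file() builds its dict from the shared
# constants. The real services delegate rendering to KafkaConfig; the join
# below just shows the expected "key=value" lines in streams.properties.
from kafkatest.services import streams_property  # package path per this patch

def render_streams_props(state_dir, bootstrap_servers, num_threads=3):
    props = {
        streams_property.STATE_DIR: state_dir,               # state.dir
        streams_property.KAFKA_SERVERS: bootstrap_servers,   # bootstrap.servers
        streams_property.NUM_THREADS: num_threads,           # num.stream.threads
    }
    return "\n".join("%s=%s" % (k, v) for k, v in props.items())

# render_streams_props("/mnt/streams", "worker1:9092") ->
#   state.dir=/mnt/streams
#   bootstrap.servers=worker1:9092
#   num.stream.threads=3
```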
+## + +# Define the root logger with appender file +log4j.rootLogger = {{ log_level|default("INFO") }}, FILE + +log4j.appender.FILE=org.apache.log4j.FileAppender +log4j.appender.FILE.File={{ log_file }} +log4j.appender.FILE.ImmediateFlush=true +log4j.appender.FILE.Append=true +log4j.appender.FILE.layout=org.apache.log4j.PatternLayout +log4j.appender.FILE.layout.conversionPattern=[%d] %p %m (%c)%n + +log4j.logger.org.apache.zookeeper=ERROR +log4j.logger.org.reflections=ERROR diff --git a/tests/kafkatest/services/templates/console_consumer.properties b/tests/kafkatest/services/templates/console_consumer.properties new file mode 100644 index 0000000..40ed2f3 --- /dev/null +++ b/tests/kafkatest/services/templates/console_consumer.properties @@ -0,0 +1,24 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +group.id={{ group_id|default('test-consumer-group') }} + +{% if client_id is defined and client_id is not none %} +client.id={{ client_id }} +{% endif %} + +{% if consumer_metadata_max_age_ms is defined and consumer_metadata_max_age_ms is not none %} +metadata.max.age.ms={{ consumer_metadata_max_age_ms }} +{% endif %} diff --git a/tests/kafkatest/services/templates/mirror_maker_consumer.properties b/tests/kafkatest/services/templates/mirror_maker_consumer.properties new file mode 100644 index 0000000..2e66573 --- /dev/null +++ b/tests/kafkatest/services/templates/mirror_maker_consumer.properties @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
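A quick aside on the `.properties` templates in this patch (console_consumer, connect_log4j, and the ones that follow): optional settings are wrapped in `{% if x is defined and x is not none %}` guards so that options the test never sets disappear from the rendered file instead of showing up with empty values. The sketch below uses jinja2 directly as a stand-in for ducktape's `Service.render()`, which resolves these templates the same way; the trimmed template text is lifted from console_consumer.properties.

```python
# Minimal sketch: why the "is defined and is not none" guards matter.
from jinja2 import Template

TEMPLATE = """\
group.id={{ group_id|default('test-consumer-group') }}
{% if client_id is defined and client_id is not none %}
client.id={{ client_id }}
{% endif %}
"""

print(Template(TEMPLATE).render())
# renders group.id=test-consumer-group only; the client.id line is omitted

print(Template(TEMPLATE).render(client_id="console-1"))
# now a client.id=console-1 line is emitted as well
```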
+# see kafka.consumer.ConsumerConfig for more details + +bootstrap.servers={{ source.bootstrap_servers(security_config.security_protocol) }} + +{% if source_auto_offset_reset is defined and source_auto_offset_reset is not none %} +auto.offset.reset={{ source_auto_offset_reset|default('latest') }} +{% endif %} + +group.id={{ group_id|default('test-consumer-group') }} + +{% if partition_assignment_strategy is defined and partition_assignment_strategy is not none %} +partition.assignment.strategy={{ partition_assignment_strategy }} +{% endif %} diff --git a/tests/kafkatest/services/templates/mirror_maker_producer.properties b/tests/kafkatest/services/templates/mirror_maker_producer.properties new file mode 100644 index 0000000..fcfd24b --- /dev/null +++ b/tests/kafkatest/services/templates/mirror_maker_producer.properties @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +bootstrap.servers = {{ target.bootstrap_servers(security_config.security_protocol) }} + +{% if producer_interceptor_classes is defined and producer_interceptor_classes is not none %} +interceptor.classes={{ producer_interceptor_classes }} +{% endif %} diff --git a/tests/kafkatest/services/templates/producer.properties b/tests/kafkatest/services/templates/producer.properties new file mode 100644 index 0000000..c41a0ea --- /dev/null +++ b/tests/kafkatest/services/templates/producer.properties @@ -0,0 +1,17 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# see kafka.producer.ProducerConfig for more details + +request.timeout.ms={{ request_timeout_ms }} diff --git a/tests/kafkatest/services/templates/tools_log4j.properties b/tests/kafkatest/services/templates/tools_log4j.properties new file mode 100644 index 0000000..3f83b42 --- /dev/null +++ b/tests/kafkatest/services/templates/tools_log4j.properties @@ -0,0 +1,31 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. 
+# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Define the root logger with appender file +log4j.rootLogger = {{ log_level|default("INFO") }}, FILE + +{% if loggers is defined %} +{% for logger, log_level in loggers.items() %} +log4j.logger.{{ logger }}={{ log_level }} +{% endfor %} +{% endif %} + +log4j.appender.FILE=org.apache.log4j.FileAppender +log4j.appender.FILE.File={{ log_file }} +log4j.appender.FILE.ImmediateFlush=true +# Set the append to true +log4j.appender.FILE.Append=true +log4j.appender.FILE.layout=org.apache.log4j.PatternLayout +log4j.appender.FILE.layout.conversionPattern=[%d] %p %m (%c)%n diff --git a/tests/kafkatest/services/templates/zookeeper.properties b/tests/kafkatest/services/templates/zookeeper.properties new file mode 100644 index 0000000..0c38061 --- /dev/null +++ b/tests/kafkatest/services/templates/zookeeper.properties @@ -0,0 +1,44 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +dataDir=/mnt/zookeeper/data +{% if zk_client_port %} +clientPort=2181 +{% endif %} +{% if zk_client_secure_port %} +secureClientPort=2182 +serverCnxnFactory=org.apache.zookeeper.server.NettyServerCnxnFactory +authProvider.x509=org.apache.zookeeper.server.auth.X509AuthenticationProvider +ssl.keyStore.location=/mnt/security/test.keystore.jks +ssl.keyStore.password=test-ks-passwd +ssl.keyStore.type=JKS +ssl.trustStore.location=/mnt/security/test.truststore.jks +ssl.trustStore.password=test-ts-passwd +ssl.trustStore.type=JKS +{% if zk_tls_encrypt_only %} +ssl.clientAuth=none +{% endif %} +{% endif %} +maxClientCnxns=0 +initLimit=5 +syncLimit=2 +quorumListenOnAllIPs=true +{% for node in nodes %} +server.{{ loop.index }}={{ node.account.hostname }}:2888:3888 +{% endfor %} +# Configuration "snapshot.trust.empty" is ignored prior to ZooKeeper version 3.5.6, +# but it is needed thereafter for system test upgrades +# (see https://issues.apache.org/jira/browse/ZOOKEEPER-3056 for details). 
+snapshot.trust.empty=true diff --git a/tests/kafkatest/services/transactional_message_copier.py b/tests/kafkatest/services/transactional_message_copier.py new file mode 100644 index 0000000..0717463 --- /dev/null +++ b/tests/kafkatest/services/transactional_message_copier.py @@ -0,0 +1,208 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import json +import signal + +from ducktape.utils.util import wait_until +from ducktape.services.background_thread import BackgroundThreadService +from kafkatest.directory_layout.kafka_path import KafkaPathResolverMixin +from ducktape.cluster.remoteaccount import RemoteCommandError + +class TransactionalMessageCopier(KafkaPathResolverMixin, BackgroundThreadService): + """This service wraps org.apache.kafka.tools.TransactionalMessageCopier for + use in system testing. + """ + PERSISTENT_ROOT = "/mnt/transactional_message_copier" + STDOUT_CAPTURE = os.path.join(PERSISTENT_ROOT, "transactional_message_copier.stdout") + STDERR_CAPTURE = os.path.join(PERSISTENT_ROOT, "transactional_message_copier.stderr") + LOG_DIR = os.path.join(PERSISTENT_ROOT, "logs") + LOG_FILE = os.path.join(LOG_DIR, "transactional_message_copier.log") + LOG4J_CONFIG = os.path.join(PERSISTENT_ROOT, "tools-log4j.properties") + + logs = { + "transactional_message_copier_stdout": { + "path": STDOUT_CAPTURE, + "collect_default": True}, + "transactional_message_copier_stderr": { + "path": STDERR_CAPTURE, + "collect_default": True}, + "transactional_message_copier_log": { + "path": LOG_FILE, + "collect_default": True} + } + + def __init__(self, context, num_nodes, kafka, transactional_id, consumer_group, + input_topic, input_partition, output_topic, max_messages=-1, + transaction_size=1000, transaction_timeout=None, enable_random_aborts=True, + use_group_metadata=False, group_mode=False): + super(TransactionalMessageCopier, self).__init__(context, num_nodes) + self.kafka = kafka + self.transactional_id = transactional_id + self.consumer_group = consumer_group + self.transaction_size = transaction_size + self.transaction_timeout = transaction_timeout + self.input_topic = input_topic + self.input_partition = input_partition + self.output_topic = output_topic + self.max_messages = max_messages + self.message_copy_finished = False + self.consumed = -1 + self.remaining = -1 + self.stop_timeout_sec = 60 + self.enable_random_aborts = enable_random_aborts + self.use_group_metadata = use_group_metadata + self.group_mode = group_mode + self.loggers = { + "org.apache.kafka.clients.producer": "TRACE", + "org.apache.kafka.clients.consumer": "TRACE" + } + + def _worker(self, idx, node): + node.account.ssh("mkdir -p %s" % TransactionalMessageCopier.PERSISTENT_ROOT, + allow_fail=False) + # Create and upload log properties + log_config = 
self.render('tools_log4j.properties', + log_file=TransactionalMessageCopier.LOG_FILE) + node.account.create_file(TransactionalMessageCopier.LOG4J_CONFIG, log_config) + # Configure security + self.security_config = self.kafka.security_config.client_config(node=node) + self.security_config.setup_node(node) + cmd = self.start_cmd(node, idx) + self.logger.debug("TransactionalMessageCopier %d command: %s" % (idx, cmd)) + try: + for line in node.account.ssh_capture(cmd): + line = line.strip() + data = self.try_parse_json(line) + if data is not None: + with self.lock: + self.remaining = int(data["remaining"]) + self.consumed = int(data["consumed"]) + self.logger.info("%s: consumed %d, remaining %d" % + (self.transactional_id, self.consumed, self.remaining)) + if "shutdown_complete" in data: + if self.remaining == 0: + # We are only finished if the remaining + # messages at the time of shutdown is 0. + # + # Otherwise a clean shutdown would still print + # a 'shutdown complete' messages even though + # there are unprocessed messages, causing + # tests to fail. + self.logger.info("%s : Finished message copy" % self.transactional_id) + self.message_copy_finished = True + else: + self.logger.info("%s : Shut down without finishing message copy." %\ + self.transactional_id) + except RemoteCommandError as e: + self.logger.debug("Got exception while reading output from copier, \ + probably because it was SIGKILL'd (exit code 137): %s" % str(e)) + + def start_cmd(self, node, idx): + cmd = "export LOG_DIR=%s;" % TransactionalMessageCopier.LOG_DIR + cmd += " export KAFKA_OPTS=%s;" % self.security_config.kafka_opts + cmd += " export KAFKA_LOG4J_OPTS=\"-Dlog4j.configuration=file:%s\"; " % TransactionalMessageCopier.LOG4J_CONFIG + cmd += self.path.script("kafka-run-class.sh", node) + " org.apache.kafka.tools." 
+ "TransactionalMessageCopier" + cmd += " --broker-list %s" % self.kafka.bootstrap_servers(self.security_config.security_protocol) + cmd += " --transactional-id %s" % self.transactional_id + cmd += " --consumer-group %s" % self.consumer_group + cmd += " --input-topic %s" % self.input_topic + cmd += " --output-topic %s" % self.output_topic + cmd += " --input-partition %s" % str(self.input_partition) + cmd += " --transaction-size %s" % str(self.transaction_size) + + if self.transaction_timeout is not None: + cmd += " --transaction-timeout %s" % str(self.transaction_timeout) + + if self.enable_random_aborts: + cmd += " --enable-random-aborts" + + if self.use_group_metadata: + cmd += " --use-group-metadata" + + if self.group_mode: + cmd += " --group-mode" + + if self.max_messages > 0: + cmd += " --max-messages %s" % str(self.max_messages) + cmd += " 2>> %s | tee -a %s &" % (TransactionalMessageCopier.STDERR_CAPTURE, TransactionalMessageCopier.STDOUT_CAPTURE) + + return cmd + + def clean_node(self, node): + self.kill_node(node, clean_shutdown=False) + node.account.ssh("rm -rf " + self.PERSISTENT_ROOT, allow_fail=False) + self.security_config.clean_node(node) + + def pids(self, node): + try: + cmd = "jps | grep -i TransactionalMessageCopier | awk '{print $1}'" + pid_arr = [pid for pid in node.account.ssh_capture(cmd, allow_fail=True, callback=int)] + return pid_arr + except (RemoteCommandError, ValueError) as e: + self.logger.error("Could not list pids: %s" % str(e)) + return [] + + def alive(self, node): + return len(self.pids(node)) > 0 + + def start_node(self, node): + BackgroundThreadService.start_node(self, node) + wait_until(lambda: self.alive(node), timeout_sec=60, err_msg="Node %s: Message Copier failed to start" % str(node.account)) + + def kill_node(self, node, clean_shutdown=True): + pids = self.pids(node) + sig = signal.SIGTERM if clean_shutdown else signal.SIGKILL + for pid in pids: + node.account.signal(pid, sig) + wait_until(lambda: len(self.pids(node)) == 0, timeout_sec=60, err_msg="Node %s: Message Copier failed to stop" % str(node.account)) + + def stop_node(self, node, clean_shutdown=True): + self.kill_node(node, clean_shutdown) + stopped = self.wait_node(node, timeout_sec=self.stop_timeout_sec) + assert stopped, "Node %s: did not stop within the specified timeout of %s seconds" % \ + (str(node.account), str(self.stop_timeout_sec)) + + def restart(self, clean_shutdown): + if self.is_done: + return + node = self.nodes[0] + with self.lock: + self.consumed = -1 + self.remaining = -1 + self.stop_node(node, clean_shutdown) + self.start_node(node) + + def try_parse_json(self, string): + """Try to parse a string as json. Return None if not parseable.""" + try: + record = json.loads(string) + return record + except ValueError: + self.logger.debug("Could not parse as json: %s" % str(string)) + return None + + @property + def is_done(self): + return self.message_copy_finished + + def progress_percent(self): + with self.lock: + if self.remaining < 0: + return 0 + if self.consumed + self.remaining == 0: + return 100 + return (float(self.consumed)/float(self.consumed + self.remaining)) * 100 diff --git a/tests/kafkatest/services/trogdor/__init__.py b/tests/kafkatest/services/trogdor/__init__.py new file mode 100644 index 0000000..ec20143 --- /dev/null +++ b/tests/kafkatest/services/trogdor/__init__.py @@ -0,0 +1,14 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. 
See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/kafkatest/services/trogdor/consume_bench_workload.py b/tests/kafkatest/services/trogdor/consume_bench_workload.py new file mode 100644 index 0000000..79ba863 --- /dev/null +++ b/tests/kafkatest/services/trogdor/consume_bench_workload.py @@ -0,0 +1,56 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from ducktape.services.service import Service +from kafkatest.services.trogdor.task_spec import TaskSpec + + +class ConsumeBenchWorkloadSpec(TaskSpec): + def __init__(self, start_ms, duration_ms, consumer_node, bootstrap_servers, + target_messages_per_sec, max_messages, active_topics, + consumer_conf, common_client_conf, admin_client_conf, consumer_group=None, threads_per_worker=1): + super(ConsumeBenchWorkloadSpec, self).__init__(start_ms, duration_ms) + self.message["class"] = "org.apache.kafka.trogdor.workload.ConsumeBenchSpec" + self.message["consumerNode"] = consumer_node + self.message["bootstrapServers"] = bootstrap_servers + self.message["targetMessagesPerSec"] = target_messages_per_sec + self.message["maxMessages"] = max_messages + self.message["consumerConf"] = consumer_conf + self.message["adminClientConf"] = admin_client_conf + self.message["commonClientConf"] = common_client_conf + self.message["activeTopics"] = active_topics + self.message["threadsPerWorker"] = threads_per_worker + if consumer_group is not None: + self.message["consumerGroup"] = consumer_group + + +class ConsumeBenchWorkloadService(Service): + def __init__(self, context, kafka): + Service.__init__(self, context, num_nodes=1) + self.bootstrap_servers = kafka.bootstrap_servers(validate=False) + self.consumer_node = self.nodes[0].account.hostname + + def free(self): + Service.free(self) + + def wait_node(self, node, timeout_sec=None): + pass + + def stop_node(self, node): + pass + + def clean_node(self, node): + pass \ No newline at end of file diff --git a/tests/kafkatest/services/trogdor/degraded_network_fault_spec.py b/tests/kafkatest/services/trogdor/degraded_network_fault_spec.py new file mode 100644 index 0000000..2a3b142 --- /dev/null +++ 
b/tests/kafkatest/services/trogdor/degraded_network_fault_spec.py @@ -0,0 +1,48 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from kafkatest.services.trogdor.task_spec import TaskSpec + + +class DegradedNetworkFaultSpec(TaskSpec): + """ + The specification for a network degradation fault. + + Degrades the network so that traffic on a subset of nodes has higher latency + """ + + def __init__(self, start_ms, duration_ms): + """ + Create a new NetworkDegradeFaultSpec. + + :param start_ms: The start time, as described in task_spec.py + :param duration_ms: The duration in milliseconds. + """ + super(DegradedNetworkFaultSpec, self).__init__(start_ms, duration_ms) + self.message["class"] = "org.apache.kafka.trogdor.fault.DegradedNetworkFaultSpec" + self.message["nodeSpecs"] = {} + + def add_node_spec(self, node, networkDevice, latencyMs=0, rateLimitKbit=0): + """ + Add a node spec to this fault spec + :param node: The node name which is to be degraded + :param networkDevice: The network device name (e.g., eth0) to apply the degradation to + :param latencyMs: Optional. How much latency to add to each packet + :param rateLimitKbit: Optional. Maximum throughput in kilobits per second to allow + :return: + """ + self.message["nodeSpecs"][node] = { + "rateLimitKbit": rateLimitKbit, "latencyMs": latencyMs, "networkDevice": networkDevice + } diff --git a/tests/kafkatest/services/trogdor/files_unreadable_fault_spec.py b/tests/kafkatest/services/trogdor/files_unreadable_fault_spec.py new file mode 100644 index 0000000..618efd3 --- /dev/null +++ b/tests/kafkatest/services/trogdor/files_unreadable_fault_spec.py @@ -0,0 +1,46 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from kafkatest.services.trogdor.task_spec import TaskSpec + + +class FilesUnreadableFaultSpec(TaskSpec): + """ + The specification for a fault which makes files unreadable. + """ + + def __init__(self, start_ms, duration_ms, node_names, mount_path, + prefix, error_code): + """ + Create a new FilesUnreadableFaultSpec. 
+ + :param start_ms: The start time, as described in task_spec.py + :param duration_ms: The duration in milliseconds. + :param node_names: The names of the node(s) to create the fault on. + :param mount_path: The mount path. + :param prefix: The prefix within the mount point to make unreadable. + :param error_code: The error code to use. + """ + super(FilesUnreadableFaultSpec, self).__init__(start_ms, duration_ms) + self.message["class"] = "org.apache.kafka.trogdor.fault.FilesUnreadableFaultSpec" + self.message["nodeNames"] = node_names + self.message["mountPath"] = mount_path + self.message["prefix"] = prefix + self.message["errorCode"] = error_code + + self.kibosh_message = {} + self.kibosh_message["type"] = "unreadable" + self.kibosh_message["prefix"] = prefix + self.kibosh_message["code"] = error_code diff --git a/tests/kafkatest/services/trogdor/kibosh.py b/tests/kafkatest/services/trogdor/kibosh.py new file mode 100644 index 0000000..788486f --- /dev/null +++ b/tests/kafkatest/services/trogdor/kibosh.py @@ -0,0 +1,156 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os.path + +from ducktape.services.service import Service +from ducktape.utils import util + + +class KiboshService(Service): + """ + Kibosh is a fault-injecting FUSE filesystem. + + Attributes: + INSTALL_ROOT The path of where Kibosh is installed. + BINARY_NAME The Kibosh binary name. + BINARY_PATH The path to the kibosh binary. + """ + INSTALL_ROOT = "/opt/kibosh/build" + BINARY_NAME = "kibosh" + BINARY_PATH = os.path.join(INSTALL_ROOT, BINARY_NAME) + + def __init__(self, context, nodes, target, mirror, persist="/mnt/kibosh"): + """ + Create a Kibosh service. + + :param context: The TestContext object. + :param nodes: The nodes to put the Kibosh FS on. Kibosh allocates no + nodes of its own. + :param target: The target directory, which Kibosh exports a view of. + :param mirror: The mirror directory, where Kibosh injects faults. + :param persist: Where the log files and pid files will be created. 
+ """ + Service.__init__(self, context, num_nodes=0) + if (len(nodes) == 0): + raise RuntimeError("You must supply at least one node to run the service on.") + for node in nodes: + self.nodes.append(node) + + self.target = target + self.mirror = mirror + self.persist = persist + + self.control_path = os.path.join(self.mirror, "kibosh_control") + self.pidfile_path = os.path.join(self.persist, "pidfile") + self.stdout_stderr_path = os.path.join(self.persist, "kibosh-stdout-stderr.log") + self.log_path = os.path.join(self.persist, "kibosh.log") + self.logs = { + "kibosh-stdout-stderr.log": { + "path": self.stdout_stderr_path, + "collect_default": True}, + "kibosh.log": { + "path": self.log_path, + "collect_default": True} + } + + def free(self): + """Clear the nodes list.""" + # Because the filesystem runs on nodes which have been allocated by other services, those nodes + # are not deallocated here. + self.nodes = [] + Service.free(self) + + def kibosh_running(self, node): + return 0 == node.account.ssh("test -e '%s'" % self.control_path, allow_fail=True) + + def start_node(self, node): + node.account.mkdirs(self.persist) + cmd = "sudo -E " + cmd += " %s" % KiboshService.BINARY_PATH + cmd += " --target %s" % self.target + cmd += " --pidfile %s" % self.pidfile_path + cmd += " --log %s" % self.log_path + cmd += " --control-mode 666" + cmd += " --verbose" + cmd += " %s" % self.mirror + cmd += " &> %s" % self.stdout_stderr_path + node.account.ssh(cmd) + util.wait_until(lambda: self.kibosh_running(node), 20, backoff_sec=.1, + err_msg="Timed out waiting for kibosh to start on %s" % node.account.hostname) + + def pids(self, node): + return [pid for pid in node.account.ssh_capture("test -e '%s' && test -e /proc/$(cat '%s')" % + (self.pidfile_path, self.pidfile_path), allow_fail=True)] + + def wait_node(self, node, timeout_sec=None): + return len(self.pids(node)) == 0 + + def kibosh_process_running(self, node): + pids = self.pids(node) + if len(pids) == 0: + return True + return False + + def stop_node(self, node): + """Halt kibosh process(es) on this node.""" + node.account.logger.debug("stop_node(%s): unmounting %s" % (node.name, self.mirror)) + node.account.ssh("sudo fusermount -u %s" % self.mirror, allow_fail=True) + # Wait for the kibosh process to terminate. + try: + util.wait_until(lambda: self.kibosh_process_running(node), 20, backoff_sec=.1, + err_msg="Timed out waiting for kibosh to stop on %s" % node.account.hostname) + except TimeoutError: + # If the process won't terminate, use kill -9 to shut it down. + node.account.logger.debug("stop_node(%s): killing the kibosh process managing %s" % (node.name, self.mirror)) + node.account.ssh("sudo kill -9 %s" % (" ".join(self.pids(node))), allow_fail=True) + node.account.ssh("sudo fusermount -u %s" % self.mirror) + util.wait_until(lambda: self.kibosh_process_running(node), 20, backoff_sec=.1, + err_msg="Timed out waiting for kibosh to stop on %s" % node.account.hostname) + + def clean_node(self, node): + """Clean up persistent state on this node - e.g. service logs, configuration files etc.""" + self.stop_node(node) + node.account.ssh("rm -rf -- %s" % self.persist) + + def set_faults(self, node, specs): + """ + Set the currently active faults. + + :param node: The node. + :param spec: An array of FaultSpec objects describing the faults. 
+ """ + if len(specs) == 0: + obj_json = "{}" + else: + fault_array = [spec.kibosh_message for spec in specs] + obj = { 'faults': fault_array } + obj_json = json.dumps(obj) + node.account.create_file(self.control_path, obj_json) + + def get_fault_json(self, node): + """ + Return a JSON string which contains the currently active faults. + + :param node: The node. + + :returns: The fault JSON describing the faults. + """ + iter = node.account.ssh_capture("cat '%s'" % self.control_path) + text = "" + for line in iter: + text = "%s%s" % (text, line.rstrip("\r\n")) + return text diff --git a/tests/kafkatest/services/trogdor/network_partition_fault_spec.py b/tests/kafkatest/services/trogdor/network_partition_fault_spec.py new file mode 100644 index 0000000..4b8c9d3 --- /dev/null +++ b/tests/kafkatest/services/trogdor/network_partition_fault_spec.py @@ -0,0 +1,39 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from kafkatest.services.trogdor.task_spec import TaskSpec + + +class NetworkPartitionFaultSpec(TaskSpec): + """ + The specification for a network partition fault. + + Network partition faults fracture the network into different partitions + that cannot communicate with each other. + """ + + def __init__(self, start_ms, duration_ms, partitions): + """ + Create a new NetworkPartitionFaultSpec. + + :param start_ms: The start time, as described in task_spec.py + :param duration_ms: The duration in milliseconds. + :param partitions: An array of arrays describing the partitions. + The inner arrays may contain either node names, + or ClusterNode objects. + """ + super(NetworkPartitionFaultSpec, self).__init__(start_ms, duration_ms) + self.message["class"] = "org.apache.kafka.trogdor.fault.NetworkPartitionFaultSpec" + self.message["partitions"] = [TaskSpec.to_node_names(p) for p in partitions] diff --git a/tests/kafkatest/services/trogdor/no_op_task_spec.py b/tests/kafkatest/services/trogdor/no_op_task_spec.py new file mode 100644 index 0000000..9238af4 --- /dev/null +++ b/tests/kafkatest/services/trogdor/no_op_task_spec.py @@ -0,0 +1,35 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from kafkatest.services.trogdor.task_spec import TaskSpec + + +class NoOpTaskSpec(TaskSpec): + """ + The specification for a nop-op task. + + No-op faults are used to test Trogdor. They don't do anything, + but must be propagated to all Trogdor agents. + """ + + def __init__(self, start_ms, duration_ms): + """ + Create a new NoOpFault. + + :param start_ms: The start time, as described in task_spec.py + :param duration_ms: The duration in milliseconds. + """ + super(NoOpTaskSpec, self).__init__(start_ms, duration_ms) + self.message["class"] = "org.apache.kafka.trogdor.task.NoOpTaskSpec"; diff --git a/tests/kafkatest/services/trogdor/process_stop_fault_spec.py b/tests/kafkatest/services/trogdor/process_stop_fault_spec.py new file mode 100644 index 0000000..3315f1e --- /dev/null +++ b/tests/kafkatest/services/trogdor/process_stop_fault_spec.py @@ -0,0 +1,38 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from kafkatest.services.trogdor.task_spec import TaskSpec + + +class ProcessStopFaultSpec(TaskSpec): + """ + The specification for a process stop fault. + """ + + def __init__(self, start_ms, duration_ms, nodes, java_process_name): + """ + Create a new ProcessStopFaultSpec. + + :param start_ms: The start time, as described in task_spec.py + :param duration_ms: The duration in milliseconds. + :param node_names: An array describing the nodes to stop processes on. The array + may contain either node names, or ClusterNode objects. + :param java_process_name: The name of the java process to stop. This is the name which + is reported by jps, etc., not the OS-level process name. + """ + super(ProcessStopFaultSpec, self).__init__(start_ms, duration_ms) + self.message["class"] = "org.apache.kafka.trogdor.fault.ProcessStopFaultSpec" + self.message["nodeNames"] = TaskSpec.to_node_names(nodes) + self.message["javaProcessName"] = java_process_name diff --git a/tests/kafkatest/services/trogdor/produce_bench_workload.py b/tests/kafkatest/services/trogdor/produce_bench_workload.py new file mode 100644 index 0000000..9afc814 --- /dev/null +++ b/tests/kafkatest/services/trogdor/produce_bench_workload.py @@ -0,0 +1,56 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from ducktape.services.service import Service +from kafkatest.services.trogdor.task_spec import TaskSpec + + +class ProduceBenchWorkloadSpec(TaskSpec): + def __init__(self, start_ms, duration_ms, producer_node, bootstrap_servers, + target_messages_per_sec, max_messages, producer_conf, admin_client_conf, + common_client_conf, inactive_topics, active_topics, + transaction_generator=None): + super(ProduceBenchWorkloadSpec, self).__init__(start_ms, duration_ms) + self.message["class"] = "org.apache.kafka.trogdor.workload.ProduceBenchSpec" + self.message["producerNode"] = producer_node + self.message["bootstrapServers"] = bootstrap_servers + self.message["targetMessagesPerSec"] = target_messages_per_sec + self.message["maxMessages"] = max_messages + self.message["producerConf"] = producer_conf + self.message["transactionGenerator"] = transaction_generator + self.message["adminClientConf"] = admin_client_conf + self.message["commonClientConf"] = common_client_conf + self.message["inactiveTopics"] = inactive_topics + self.message["activeTopics"] = active_topics + + +class ProduceBenchWorkloadService(Service): + def __init__(self, context, kafka): + Service.__init__(self, context, num_nodes=1) + self.bootstrap_servers = kafka.bootstrap_servers(validate=False) + self.producer_node = self.nodes[0].account.hostname + + def free(self): + Service.free(self) + + def wait_node(self, node, timeout_sec=None): + pass + + def stop_node(self, node): + pass + + def clean_node(self, node): + pass diff --git a/tests/kafkatest/services/trogdor/round_trip_workload.py b/tests/kafkatest/services/trogdor/round_trip_workload.py new file mode 100644 index 0000000..86bc2d2 --- /dev/null +++ b/tests/kafkatest/services/trogdor/round_trip_workload.py @@ -0,0 +1,49 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +from ducktape.services.service import Service +from kafkatest.services.trogdor.task_spec import TaskSpec + + +class RoundTripWorkloadSpec(TaskSpec): + def __init__(self, start_ms, duration_ms, client_node, bootstrap_servers, + target_messages_per_sec, max_messages, active_topics): + super(RoundTripWorkloadSpec, self).__init__(start_ms, duration_ms) + self.message["class"] = "org.apache.kafka.trogdor.workload.RoundTripWorkloadSpec" + self.message["clientNode"] = client_node + self.message["bootstrapServers"] = bootstrap_servers + self.message["targetMessagesPerSec"] = target_messages_per_sec + self.message["maxMessages"] = max_messages + self.message["activeTopics"] = active_topics + + +class RoundTripWorkloadService(Service): + def __init__(self, context, kafka): + Service.__init__(self, context, num_nodes=1) + self.bootstrap_servers = kafka.bootstrap_servers(validate=False) + self.client_node = self.nodes[0].account.hostname + + def free(self): + Service.free(self) + + def wait_node(self, node, timeout_sec=None): + pass + + def stop_node(self, node): + pass + + def clean_node(self, node): + pass diff --git a/tests/kafkatest/services/trogdor/task_spec.py b/tests/kafkatest/services/trogdor/task_spec.py new file mode 100644 index 0000000..aa5766e --- /dev/null +++ b/tests/kafkatest/services/trogdor/task_spec.py @@ -0,0 +1,54 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json + + +class TaskSpec(object): + """ + The base class for a task specification. + + MAX_DURATION_MS The longest duration we should use for a task specification. + """ + + MAX_DURATION_MS=10000000 + + def __init__(self, start_ms, duration_ms): + """ + Create a new task specification. + + :param start_ms: The target start time in milliseconds since the epoch. + :param duration_ms: The duration in milliseconds. + """ + self.message = { + 'startMs': start_ms, + 'durationMs': duration_ms + } + + @staticmethod + def to_node_names(nodes): + """ + Convert an array of nodes or node names to an array of node names. + """ + node_names = [] + for obj in nodes: + if isinstance(obj, str): + node_names.append(obj) + else: + node_names.append(obj.name) + return node_names + + def __str__(self): + return json.dumps(self.message) diff --git a/tests/kafkatest/services/trogdor/templates/log4j.properties b/tests/kafkatest/services/trogdor/templates/log4j.properties new file mode 100644 index 0000000..252668e --- /dev/null +++ b/tests/kafkatest/services/trogdor/templates/log4j.properties @@ -0,0 +1,23 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. 
+# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +log4j.rootLogger=DEBUG, mylogger +log4j.logger.kafka=DEBUG +log4j.logger.org.apache.kafka=DEBUG +log4j.logger.org.eclipse=INFO +log4j.appender.mylogger=org.apache.log4j.FileAppender +log4j.appender.mylogger.File={{ log_path }} +log4j.appender.mylogger.layout=org.apache.log4j.PatternLayout +log4j.appender.mylogger.layout.ConversionPattern=[%d] %p %m (%c)%n diff --git a/tests/kafkatest/services/trogdor/trogdor.py b/tests/kafkatest/services/trogdor/trogdor.py new file mode 100644 index 0000000..bd18bdd --- /dev/null +++ b/tests/kafkatest/services/trogdor/trogdor.py @@ -0,0 +1,354 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os.path +import requests +from requests.adapters import HTTPAdapter +from requests.packages.urllib3 import Retry + +from ducktape.services.service import Service +from ducktape.utils.util import wait_until +from kafkatest.directory_layout.kafka_path import KafkaPathResolverMixin + + +class TrogdorService(KafkaPathResolverMixin, Service): + """ + A ducktape service for running the trogdor fault injection daemons. + + Attributes: + PERSISTENT_ROOT The root filesystem path to store service files under. + COORDINATOR_STDOUT_STDERR The path where we store the coordinator's stdout/stderr output. + AGENT_STDOUT_STDERR The path where we store the agents's stdout/stderr output. + COORDINATOR_LOG The path where we store the coordinator's log4j output. + AGENT_LOG The path where we store the agent's log4j output. + AGENT_LOG4J_PROPERTIES The path to the agent log4j.properties file for log config. + COORDINATOR_LOG4J_PROPERTIES The path to the coordinator log4j.properties file for log config. + CONFIG_PATH The path to the trogdor configuration file. + DEFAULT_AGENT_PORT The default port to use for trogdor_agent daemons. + DEFAULT_COORDINATOR_PORT The default port to use for trogdor_coordinator daemons. + REQUEST_TIMEOUT The request timeout in seconds to use for REST requests. + REQUEST_HEADERS The request headers to use when communicating with trogdor. 
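+
+    A typical test might drive the service roughly as follows (a sketch; the
+    task id and timeout are illustrative, and `kafka` is assumed to be an
+    already-started KafkaService whose nodes host the agents):
+
+        trogdor = TrogdorService(context, client_services=[kafka])
+        trogdor.start()
+        task = trogdor.create_task("example-noop", NoOpTaskSpec(0, TaskSpec.MAX_DURATION_MS))
+        task.wait_for_done(timeout_sec=600)
+        trogdor.stop()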
+ """ + + PERSISTENT_ROOT="/mnt/trogdor" + COORDINATOR_STDOUT_STDERR = os.path.join(PERSISTENT_ROOT, "trogdor-coordinator-stdout-stderr.log") + AGENT_STDOUT_STDERR = os.path.join(PERSISTENT_ROOT, "trogdor-agent-stdout-stderr.log") + COORDINATOR_LOG = os.path.join(PERSISTENT_ROOT, "trogdor-coordinator.log") + AGENT_LOG = os.path.join(PERSISTENT_ROOT, "trogdor-agent.log") + COORDINATOR_LOG4J_PROPERTIES = os.path.join(PERSISTENT_ROOT, "trogdor-coordinator-log4j.properties") + AGENT_LOG4J_PROPERTIES = os.path.join(PERSISTENT_ROOT, "trogdor-agent-log4j.properties") + CONFIG_PATH = os.path.join(PERSISTENT_ROOT, "trogdor.conf") + DEFAULT_AGENT_PORT=8888 + DEFAULT_COORDINATOR_PORT=8889 + REQUEST_TIMEOUT=5 + REQUEST_HEADERS = {"Content-type": "application/json"} + + logs = { + "trogdor_coordinator_stdout_stderr": { + "path": COORDINATOR_STDOUT_STDERR, + "collect_default": True}, + "trogdor_agent_stdout_stderr": { + "path": AGENT_STDOUT_STDERR, + "collect_default": True}, + "trogdor_coordinator_log": { + "path": COORDINATOR_LOG, + "collect_default": True}, + "trogdor_agent_log": { + "path": AGENT_LOG, + "collect_default": True}, + } + + + def __init__(self, context, agent_nodes=None, client_services=None, + agent_port=DEFAULT_AGENT_PORT, coordinator_port=DEFAULT_COORDINATOR_PORT): + """ + Create a Trogdor service. + + :param context: The test context. + :param agent_nodes: The nodes to run the agents on. + :param client_services: Services whose nodes we should run agents on. + :param agent_port: The port to use for the trogdor_agent daemons. + :param coordinator_port: The port to use for the trogdor_coordinator daemons. + """ + Service.__init__(self, context, num_nodes=1) + self.coordinator_node = self.nodes[0] + if client_services is not None: + for client_service in client_services: + for node in client_service.nodes: + self.nodes.append(node) + if agent_nodes is not None: + for agent_node in agent_nodes: + self.nodes.append(agent_node) + if (len(self.nodes) == 1): + raise RuntimeError("You must supply at least one agent node to run the service on.") + self.agent_port = agent_port + self.coordinator_port = coordinator_port + + def free(self): + # We only want to deallocate the coordinator node, not the agent nodes. So we + # change self.nodes to include only the coordinator node, and then invoke + # the base class' free method. + if self.coordinator_node is not None: + self.nodes = [self.coordinator_node] + self.coordinator_node = None + Service.free(self) + + def _create_config_dict(self): + """ + Create a dictionary with the Trogdor configuration. + + :return: The configuration dictionary. + """ + dict_nodes = {} + for node in self.nodes: + dict_nodes[node.name] = { + "hostname": node.account.ssh_hostname, + } + if node.name == self.coordinator_node.name: + dict_nodes[node.name]["trogdor.coordinator.port"] = self.coordinator_port + else: + dict_nodes[node.name]["trogdor.agent.port"] = self.agent_port + + return { + "platform": "org.apache.kafka.trogdor.basic.BasicPlatform", + "nodes": dict_nodes, + } + + def start_node(self, node): + node.account.mkdirs(TrogdorService.PERSISTENT_ROOT) + + # Create the configuration file on the node. 
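+        # The file lists every node in the Trogdor cluster (the coordinator plus all
+        # agents) together with the port each daemon listens on; see _create_config_dict().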
+ str = json.dumps(self._create_config_dict(), indent=2) + self.logger.info("Creating configuration file %s with %s" % (TrogdorService.CONFIG_PATH, str)) + node.account.create_file(TrogdorService.CONFIG_PATH, str) + + if self.is_coordinator(node): + self._start_coordinator_node(node) + else: + self._start_agent_node(node) + + def _start_coordinator_node(self, node): + node.account.create_file(TrogdorService.COORDINATOR_LOG4J_PROPERTIES, + self.render('log4j.properties', + log_path=TrogdorService.COORDINATOR_LOG)) + self._start_trogdor_daemon("coordinator", TrogdorService.COORDINATOR_STDOUT_STDERR, + TrogdorService.COORDINATOR_LOG4J_PROPERTIES, + TrogdorService.COORDINATOR_LOG, node) + self.logger.info("Started trogdor coordinator on %s." % node.name) + + def _start_agent_node(self, node): + node.account.create_file(TrogdorService.AGENT_LOG4J_PROPERTIES, + self.render('log4j.properties', + log_path=TrogdorService.AGENT_LOG)) + self._start_trogdor_daemon("agent", TrogdorService.AGENT_STDOUT_STDERR, + TrogdorService.AGENT_LOG4J_PROPERTIES, + TrogdorService.AGENT_LOG, node) + self.logger.info("Started trogdor agent on %s." % node.name) + + def _start_trogdor_daemon(self, daemon_name, stdout_stderr_capture_path, + log4j_properties_path, log_path, node): + cmd = "export KAFKA_LOG4J_OPTS='-Dlog4j.configuration=file:%s'; " % log4j_properties_path + cmd += "%s %s --%s.config %s --node-name %s 1>> %s 2>> %s &" % \ + (self.path.script("trogdor.sh", node), + daemon_name, + daemon_name, + TrogdorService.CONFIG_PATH, + node.name, + stdout_stderr_capture_path, + stdout_stderr_capture_path) + node.account.ssh(cmd) + with node.account.monitor_log(log_path) as monitor: + monitor.wait_until("Starting %s process." % daemon_name, timeout_sec=60, backoff_sec=.10, + err_msg=("%s on %s didn't finish startup" % (daemon_name, node.name))) + + def wait_node(self, node, timeout_sec=None): + if self.is_coordinator(node): + return len(node.account.java_pids(self.coordinator_class_name())) == 0 + else: + return len(node.account.java_pids(self.agent_class_name())) == 0 + + def stop_node(self, node): + """Halt trogdor processes on this node.""" + if self.is_coordinator(node): + node.account.kill_java_processes(self.coordinator_class_name()) + else: + node.account.kill_java_processes(self.agent_class_name()) + + def clean_node(self, node): + """Clean up persistent state on this node - e.g. service logs, configuration files etc.""" + self.stop_node(node) + node.account.ssh("rm -rf -- %s" % TrogdorService.PERSISTENT_ROOT) + + def _coordinator_url(self, path): + return "http://%s:%d/coordinator/%s" % \ + (self.coordinator_node.account.ssh_hostname, self.coordinator_port, path) + + def request_session(self): + """ + Creates a new request session which will retry for a while. + """ + session = requests.Session() + session.mount('http://', + HTTPAdapter(max_retries=Retry(total=5, backoff_factor=0.3))) + return session + + def _coordinator_post(self, path, message): + """ + Make a POST request to the Trogdor coordinator. + + :param path: The URL path to use. + :param message: The message object to send. + :return: The response as an object. + """ + url = self._coordinator_url(path) + self.logger.info("POST %s %s" % (url, message)) + response = self.request_session().post(url, json=message, + timeout=TrogdorService.REQUEST_TIMEOUT, + headers=TrogdorService.REQUEST_HEADERS) + response.raise_for_status() + return response.json() + + def _coordinator_put(self, path, message): + """ + Make a PUT request to the Trogdor coordinator. 
+ + :param path: The URL path to use. + :param message: The message object to send. + :return: The response as an object. + """ + url = self._coordinator_url(path) + self.logger.info("PUT %s %s" % (url, message)) + response = self.request_session().put(url, json=message, + timeout=TrogdorService.REQUEST_TIMEOUT, + headers=TrogdorService.REQUEST_HEADERS) + response.raise_for_status() + return response.json() + + def _coordinator_get(self, path, message): + """ + Make a GET request to the Trogdor coordinator. + + :param path: The URL path to use. + :param message: The message object to send. + :return: The response as an object. + """ + url = self._coordinator_url(path) + self.logger.info("GET %s %s" % (url, message)) + response = self.request_session().get(url, json=message, + timeout=TrogdorService.REQUEST_TIMEOUT, + headers=TrogdorService.REQUEST_HEADERS) + response.raise_for_status() + return response.json() + + def create_task(self, id, spec): + """ + Create a new task. + + :param id: The task id. + :param spec: The task spec. + """ + self._coordinator_post("task/create", { "id": id, "spec": spec.message}) + return TrogdorTask(id, self) + + def stop_task(self, id): + """ + Stop a task. + + :param id: The task id. + """ + self._coordinator_put("task/stop", { "id": id }) + + def tasks(self): + """ + Get the tasks which are on the coordinator. + + :returns: A map of task id strings to task state objects. + Task state objects contain a 'spec' field with the spec + and a 'state' field with the state. + """ + return self._coordinator_get("tasks", {}) + + def is_coordinator(self, node): + return node == self.coordinator_node + + def agent_class_name(self): + return "org.apache.kafka.trogdor.agent.Agent" + + def coordinator_class_name(self): + return "org.apache.kafka.trogdor.coordinator.Coordinator" + +class TrogdorTask(object): + PENDING_STATE = "PENDING" + RUNNING_STATE = "RUNNING" + STOPPING_STATE = "STOPPING" + DONE_STATE = "DONE" + + def __init__(self, id, trogdor): + self.id = id + self.trogdor = trogdor + + def task_state_or_error(self): + task_state = self.trogdor.tasks()["tasks"][self.id] + if task_state is None: + raise RuntimeError("Coordinator did not know about %s." % self.id) + error = task_state.get("error") + if error is None or error == "": + return task_state["state"], None + else: + return None, error + + def done(self): + """ + Check if this task is done. + + :raises RuntimeError: If the task encountered an error. + :returns: True if the task is in DONE_STATE; + False if it is in a different state. + """ + (task_state, error) = self.task_state_or_error() + if task_state is not None: + return task_state == TrogdorTask.DONE_STATE + else: + raise RuntimeError("Failed to gracefully stop %s: got task error: %s" % (self.id, error)) + + def running(self): + """ + Check if this task is running. + + :raises RuntimeError: If the task encountered an error. + :returns: True if the task is in RUNNING_STATE; + False if it is in a different state. + """ + (task_state, error) = self.task_state_or_error() + if task_state is not None: + return task_state == TrogdorTask.RUNNING_STATE + else: + raise RuntimeError("Failed to start %s: got task error: %s" % (self.id, error)) + + def stop(self): + """ + Stop this task. + + :raises RuntimeError: If the task encountered an error. 
+ """ + if self.done(): + return + self.trogdor.stop_task(self.id) + + def wait_for_done(self, timeout_sec=360): + wait_until(lambda: self.done(), + timeout_sec=timeout_sec, + err_msg="%s failed to finish in the expected amount of time." % self.id) diff --git a/tests/kafkatest/services/verifiable_client.py b/tests/kafkatest/services/verifiable_client.py new file mode 100644 index 0000000..a2cbf96 --- /dev/null +++ b/tests/kafkatest/services/verifiable_client.py @@ -0,0 +1,346 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from kafkatest.directory_layout.kafka_path import TOOLS_JAR_NAME, TOOLS_DEPENDANT_TEST_LIBS_JAR_NAME +from kafkatest.version import DEV_BRANCH, LATEST_0_8_2 +from ducktape.cluster.remoteaccount import RemoteCommandError + +import importlib +import subprocess +import signal +from kafkatest.services.kafka.util import fix_opts_for_new_jvm + + +"""This module abstracts the implementation of a verifiable client, allowing +client developers to plug in their own client for all kafkatests that make +use of either the VerifiableConsumer or VerifiableProducer classes. + +A verifiable client class must implement exec_cmd() and pids(). + +This file provides: + * VerifiableClientMixin class: to be used for creating new verifiable client classes + * VerifiableClientJava class: the default Java verifiable clients + * VerifiableClientApp class: uses global configuration to specify + the command to execute and optional "pids" command, deploy script, etc. + Config syntax (pass as --global ): + {"Verifiable(Producer|Consumer|Client)": { + "class": "kafkatest.services.verifiable_client.VerifiableClientApp", + "exec_cmd": "/vagrant/x/myclient --some --standard --args", + "pids": "pgrep -f ...", // optional + "deploy": "/vagrant/x/mydeploy.sh", // optional + "kill_signal": 2 // optional clean_shutdown kill signal (SIGINT in this case) + }} + * VerifiableClientDummy class: testing dummy + + + +============================== +Verifiable client requirements +============================== + +There are currently two verifiable client specifications: + * VerifiableConsumer + * VerifiableProducer + +Common requirements for both: + * One-way communication (client -> tests) through new-line delimited + JSON objects on stdout (details below). 
+ * Log/debug to stderr
+
+Common communication for both:
+ * `{ "name": "startup_complete" }` - Client successfully started
+ * `{ "name": "shutdown_complete" }` - Client successfully terminated (after receiving SIGINT/SIGTERM)
+
+
+==================
+VerifiableConsumer
+==================
+
+Command line arguments:
+ * `--group-id <group-id>`
+ * `--topic <topic>`
+ * `--broker-list <brokers>`
+ * `--session-timeout <n>`
+ * `--enable-autocommit`
+ * `--max-messages <int>`
+ * `--assignment-strategy <strategy>`
+ * `--consumer.config <config-file>` - consumer config properties (typically empty)
+
+Environment variables:
+ * `LOG_DIR` - log output directory. Typically not needed if logs are written to stderr.
+ * `KAFKA_OPTS` - Security config properties (Java client syntax)
+ * `KAFKA_LOG4J_OPTS` - Java log4j options (can be ignored)
+
+Client communication:
+ * `{ "name": "offsets_committed", "success": bool, "error": "<error message, or empty>", "offsets": [ { "topic": "<topic>", "partition": <partition>, "offset": <offset> } ] }` - offset commit results, should be emitted for each committed offset. Emit prior to partitions_revoked.
+ * `{ "name": "records_consumed", "partitions": [ { "topic": "<topic>", "partition": <partition>, "minOffset": <minOffset>, "maxOffset": <maxOffset> } ], "count": <count> }` - per-partition delta stats from last records_consumed. Emit every 1000 messages, or 1s. Emit prior to partitions_assigned, partitions_revoked and offsets_committed.
+ * `{ "name": "partitions_revoked", "partitions": [ { "topic": "<topic>", "partition": <partition> } ] }` - rebalance: revoked partitions
+ * `{ "name": "partitions_assigned", "partitions": [ { "topic": "<topic>", "partition": <partition> } ] }` - rebalance: assigned partitions
+
+
+==================
+VerifiableProducer
+==================
+
+Command line arguments:
+ * `--topic <topic>`
+ * `--broker-list <brokers>`
+ * `--max-messages <int>`
+ * `--throughput <msgs/sec>`
+ * `--producer.config <config-file>` - producer config properties (typically empty)
+
+Environment variables:
+ * `LOG_DIR` - log output directory. Typically not needed if logs are written to stderr.
+ * `KAFKA_OPTS` - Security config properties (Java client syntax)
+ * `KAFKA_LOG4J_OPTS` - Java log4j options (can be ignored)
+
+Client communication:
+ * `{ "name": "producer_send_error", "message": "<error message>", "topic": "<topic>", "key": "<key>", "value": "<value>" }` - emit on produce error.
+ * `{ "name": "producer_send_success", "topic": "<topic>", "partition": <partition>
                , "offset": , "key": "", "value": "" }` - emit on produce success. + + + +=========== +Development +=========== + +**Logs:** +During development of kafkatest clients it is generally a good idea to +enable collection of the client's stdout and stderr logs for troubleshooting. +Do this by setting "collect_default" to True for verifiable_consumder_stdout +and .._stderr in verifiable_consumer.py and verifiable_producer.py + + +**Deployment:** +There's currently no automatic way of deploying 3rd party kafkatest clients +on the VM instance so this needs to be done (at least partially) manually for +now. + +One way to do this is logging in to a worker (`vagrant ssh worker1`), downloading +and building the kafkatest client under /vagrant (which maps to the kafka root +directory on the host and is shared with all VM instances). +Also make sure to install any system-level dependencies on each instance. + +Then use /vagrant/..../yourkafkatestclient as your run-time path since it will +now be available on all instances. + +The VerifiableClientApp automates the per-worker deployment with the optional +"deploy": "/vagrant/../deploy_script.sh" globals configuration property, this +script will be called on the VM just prior to executing the client. +""" + +def create_verifiable_client_implementation(context, parent): + """Factory for generating a verifiable client implementation class instance + + :param parent: parent class instance, either VerifiableConsumer or VerifiableProducer + + This will first check for a fully qualified client implementation class name + in context.globals as "Verifiable" where is "Producer" or "Consumer", + followed by "VerifiableClient" (which should implement both). + The global object layout is: {"class": "", "..anything..": ..}. + + If present, construct a new instance, else defaults to VerifiableClientJava + """ + + # Default class + obj = {"class": "kafkatest.services.verifiable_client.VerifiableClientJava"} + + parent_name = parent.__class__.__name__.rsplit('.', 1)[-1] + for k in [parent_name, "VerifiableClient"]: + if k in context.globals: + obj = context.globals[k] + break + + if "class" not in obj: + raise SyntaxError('%s (or VerifiableClient) expected object format: {"class": "full.class.path", ..}' % parent_name) + + clname = obj["class"] + # Using the fully qualified classname, import the implementation class + if clname.find('.') == -1: + raise SyntaxError("%s (or VerifiableClient) must specify full class path (including module)" % parent_name) + + (module_name, clname) = clname.rsplit('.', 1) + cluster_mod = importlib.import_module(module_name) + impl_class = getattr(cluster_mod, clname) + return impl_class(parent, obj) + + + +class VerifiableClientMixin (object): + """ + Verifiable client mixin class which loads the actual VerifiableClient.. class. + """ + def __init__ (self, *args, **kwargs): + super(VerifiableClientMixin, self).__init__(*args, **kwargs) + if hasattr(self.impl, 'deploy'): + # Deploy client on node + self.context.logger.debug("Deploying %s on %s" % (self.impl, self.nodes)) + for node in self.nodes: + self.impl.deploy(node) + + @property + def impl (self): + """ + :return: Return (and create if necessary) the Verifiable client implementation object. + """ + # Add _impl attribute to parent Verifiable(Consumer|Producer) object. 
+ if not hasattr(self, "_impl"): + setattr(self, "_impl", create_verifiable_client_implementation(self.context, self)) + if hasattr(self.context, "logger") and self.context.logger is not None: + self.context.logger.debug("Using client implementation %s for %s" % (self._impl.__class__.__name__, self.__class__.__name__)) + return self._impl + + +class VerifiableClient (object): + """ + Verifiable client base class + """ + def __init__(self, *args, **kwargs): + super(VerifiableClient, self).__init__() + + def exec_cmd (self, node): + """ + :return: command string to execute client. + Environment variables will be prepended and command line arguments + appended to this string later by start_cmd(). + + This method should also take care of deploying the client on the instance, if necessary. + """ + raise NotImplementedError() + + def pids (self, node): + """ :return: list of pids for this client instance on node """ + raise NotImplementedError() + + def kill_signal (self, clean_shutdown=True): + """ :return: the kill signal to terminate the application. """ + if not clean_shutdown: + return signal.SIGKILL + + return self.conf.get("kill_signal", signal.SIGTERM) + + +class VerifiableClientJava (VerifiableClient): + """ + Verifiable Consumer and Producer using the official Java client. + """ + def __init__(self, parent, conf=None): + """ + :param parent: The parent instance, either VerifiableConsumer or VerifiableProducer + :param conf: Optional conf object (the --globals VerifiableX object) + """ + super(VerifiableClientJava, self).__init__() + self.parent = parent + self.java_class_name = parent.java_class_name() + self.conf = conf + + def exec_cmd (self, node): + """ :return: command to execute to start instance + Translates Verifiable* to the corresponding Java client class name """ + cmd = "" + if self.java_class_name == 'VerifiableProducer' and node.version <= LATEST_0_8_2: + # 0.8.2.X releases do not have VerifiableProducer.java, so cheat and add + # the tools jar from trunk to the classpath + tools_jar = self.parent.path.jar(TOOLS_JAR_NAME, DEV_BRANCH) + tools_dependant_libs_jar = self.parent.path.jar(TOOLS_DEPENDANT_TEST_LIBS_JAR_NAME, DEV_BRANCH) + cmd += "for file in %s; do CLASSPATH=$CLASSPATH:$file; done; " % tools_jar + cmd += "for file in %s; do CLASSPATH=$CLASSPATH:$file; done; " % tools_dependant_libs_jar + cmd += "export CLASSPATH; " + cmd += fix_opts_for_new_jvm(node) + cmd += self.parent.path.script("kafka-run-class.sh", node) + " org.apache.kafka.tools." 
+ self.java_class_name + return cmd + + def pids (self, node): + """ :return: pid(s) for this client intstance on node """ + try: + cmd = "jps | grep -i " + self.java_class_name + " | awk '{print $1}'" + pid_arr = [pid for pid in node.account.ssh_capture(cmd, allow_fail=True, callback=int)] + return pid_arr + except (RemoteCommandError, ValueError) as e: + return [] + + +class VerifiableClientDummy (VerifiableClient): + """ + Dummy class for testing the pluggable framework + """ + def __init__(self, parent, conf=None): + """ + :param parent: The parent instance, either VerifiableConsumer or VerifiableProducer + :param conf: Optional conf object (the --globals VerifiableX object) + """ + super(VerifiableClientDummy, self).__init__() + self.parent = parent + self.conf = conf + + def exec_cmd (self, node): + """ :return: command to execute to start instance """ + return 'echo -e \'{"name": "shutdown_complete" }\n\' ; echo ARGS:' + + def pids (self, node): + """ :return: pid(s) for this client intstance on node """ + return [] + + +class VerifiableClientApp (VerifiableClient): + """ + VerifiableClient using --global settings for exec_cmd, pids and deploy. + By using this a verifiable client application can be used through simple + --globals configuration rather than implementing a Python class. + """ + + def __init__(self, parent, conf): + """ + :param parent: The parent instance, either VerifiableConsumer or VerifiableProducer + :param conf: Optional conf object (the --globals VerifiableX object) + """ + super(VerifiableClientApp, self).__init__() + self.parent = parent + # "VerifiableConsumer" or "VerifiableProducer" + self.name = self.parent.__class__.__name__ + self.conf = conf + + if "exec_cmd" not in self.conf: + raise SyntaxError("%s requires \"exec_cmd\": .. to be set in --globals %s object" % \ + (self.__class__.__name__, self.name)) + + + def exec_cmd (self, node): + """ :return: command to execute to start instance """ + return self.conf["exec_cmd"] + + def pids (self, node): + """ :return: pid(s) for this client intstance on node """ + + cmd = self.conf.get("pids", "pgrep -f '" + self.conf["exec_cmd"] + "'") + try: + pid_arr = [pid for pid in node.account.ssh_capture(cmd, allow_fail=True, callback=int)] + self.parent.context.logger.info("%s pids are: %s" % (str(node.account), pid_arr)) + return pid_arr + except (subprocess.CalledProcessError, ValueError) as e: + return [] + + def deploy (self, node): + """ Call deploy script specified by "deploy" --global key + This optional script is run on the VM instance just prior to + executing `exec_cmd` to deploy the kafkatest client. + The script path must be as seen by the VM instance, e.g. /vagrant/.... """ + + if "deploy" not in self.conf: + return + + script_cmd = self.conf["deploy"] + self.parent.context.logger.debug("Deploying %s: %s" % (self, script_cmd)) + r = node.account.ssh(script_cmd) diff --git a/tests/kafkatest/services/verifiable_consumer.py b/tests/kafkatest/services/verifiable_consumer.py new file mode 100644 index 0000000..93d9446 --- /dev/null +++ b/tests/kafkatest/services/verifiable_consumer.py @@ -0,0 +1,418 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os + +from ducktape.services.background_thread import BackgroundThreadService + +from kafkatest.directory_layout.kafka_path import KafkaPathResolverMixin +from kafkatest.services.kafka import TopicPartition +from kafkatest.services.verifiable_client import VerifiableClientMixin +from kafkatest.version import DEV_BRANCH, V_2_3_0, V_2_3_1, V_0_10_0_0 + + +class ConsumerState: + Started = 1 + Dead = 2 + Rebalancing = 3 + Joined = 4 + + +class ConsumerEventHandler(object): + + def __init__(self, node, verify_offsets, idx): + self.node = node + self.idx = idx + self.state = ConsumerState.Dead + self.revoked_count = 0 + self.assigned_count = 0 + self.assignment = [] + self.position = {} + self.committed = {} + self.total_consumed = 0 + self.verify_offsets = verify_offsets + + def handle_shutdown_complete(self): + self.state = ConsumerState.Dead + self.assignment = [] + self.position = {} + + def handle_startup_complete(self): + self.state = ConsumerState.Started + + def handle_offsets_committed(self, event, node, logger): + if event["success"]: + for offset_commit in event["offsets"]: + if offset_commit.get("error", "") != "": + logger.debug("%s: Offset commit failed for: %s" % (str(node.account), offset_commit)) + continue + + topic = offset_commit["topic"] + partition = offset_commit["partition"] + tp = TopicPartition(topic, partition) + offset = offset_commit["offset"] + assert tp in self.assignment, \ + "Committed offsets for partition %s not assigned (current assignment: %s)" % \ + (str(tp), str(self.assignment)) + assert tp in self.position, "No previous position for %s: %s" % (str(tp), event) + assert self.position[tp] >= offset, \ + "The committed offset %d was greater than the current position %d for partition %s" % \ + (offset, self.position[tp], str(tp)) + self.committed[tp] = offset + + def handle_records_consumed(self, event, logger): + assert self.state == ConsumerState.Joined, \ + "Consumed records should only be received when joined (current state: %s)" % str(self.state) + + for record_batch in event["partitions"]: + tp = TopicPartition(topic=record_batch["topic"], + partition=record_batch["partition"]) + min_offset = record_batch["minOffset"] + max_offset = record_batch["maxOffset"] + + assert tp in self.assignment, \ + "Consumed records for partition %s which is not assigned (current assignment: %s)" % \ + (str(tp), str(self.assignment)) + if tp not in self.position or self.position[tp] == min_offset: + self.position[tp] = max_offset + 1 + else: + msg = "Consumed from an unexpected offset (%d, %d) for partition %s" % \ + (self.position.get(tp), min_offset, str(tp)) + if self.verify_offsets: + raise AssertionError(msg) + else: + if tp in self.position: + self.position[tp] = max_offset + 1 + logger.warn(msg) + self.total_consumed += event["count"] + + def handle_partitions_revoked(self, event): + self.revoked_count += 1 + self.state = ConsumerState.Rebalancing + self.position = {} + + def handle_partitions_assigned(self, event): + self.assigned_count += 1 + self.state = ConsumerState.Joined + assignment = [] + for topic_partition in 
event["partitions"]: + topic = topic_partition["topic"] + partition = topic_partition["partition"] + assignment.append(TopicPartition(topic, partition)) + self.assignment = assignment + + def handle_kill_process(self, clean_shutdown): + # if the shutdown was clean, then we expect the explicit + # shutdown event from the consumer + if not clean_shutdown: + self.handle_shutdown_complete() + + def current_assignment(self): + return list(self.assignment) + + def current_position(self, tp): + if tp in self.position: + return self.position[tp] + else: + return None + + def last_commit(self, tp): + if tp in self.committed: + return self.committed[tp] + else: + return None + + +class VerifiableConsumer(KafkaPathResolverMixin, VerifiableClientMixin, BackgroundThreadService): + """This service wraps org.apache.kafka.tools.VerifiableConsumer for use in + system testing. + + NOTE: this class should be treated as a PUBLIC API. Downstream users use + this service both directly and through class extension, so care must be + taken to ensure compatibility. + """ + + PERSISTENT_ROOT = "/mnt/verifiable_consumer" + STDOUT_CAPTURE = os.path.join(PERSISTENT_ROOT, "verifiable_consumer.stdout") + STDERR_CAPTURE = os.path.join(PERSISTENT_ROOT, "verifiable_consumer.stderr") + LOG_DIR = os.path.join(PERSISTENT_ROOT, "logs") + LOG_FILE = os.path.join(LOG_DIR, "verifiable_consumer.log") + LOG4J_CONFIG = os.path.join(PERSISTENT_ROOT, "tools-log4j.properties") + CONFIG_FILE = os.path.join(PERSISTENT_ROOT, "verifiable_consumer.properties") + + logs = { + "verifiable_consumer_stdout": { + "path": STDOUT_CAPTURE, + "collect_default": False}, + "verifiable_consumer_stderr": { + "path": STDERR_CAPTURE, + "collect_default": False}, + "verifiable_consumer_log": { + "path": LOG_FILE, + "collect_default": True} + } + + def __init__(self, context, num_nodes, kafka, topic, group_id, + static_membership=False, max_messages=-1, session_timeout_sec=30, enable_autocommit=False, + assignment_strategy=None, + version=DEV_BRANCH, stop_timeout_sec=30, log_level="INFO", jaas_override_variables=None, + on_record_consumed=None, reset_policy="earliest", verify_offsets=True): + """ + :param jaas_override_variables: A dict of variables to be used in the jaas.conf template file + """ + super(VerifiableConsumer, self).__init__(context, num_nodes) + self.log_level = log_level + self.kafka = kafka + self.topic = topic + self.group_id = group_id + self.reset_policy = reset_policy + self.static_membership = static_membership + self.max_messages = max_messages + self.session_timeout_sec = session_timeout_sec + self.enable_autocommit = enable_autocommit + self.assignment_strategy = assignment_strategy + self.prop_file = "" + self.stop_timeout_sec = stop_timeout_sec + self.on_record_consumed = on_record_consumed + self.verify_offsets = verify_offsets + + self.event_handlers = {} + self.global_position = {} + self.global_committed = {} + self.jaas_override_variables = jaas_override_variables or {} + + for node in self.nodes: + node.version = version + + def java_class_name(self): + return "VerifiableConsumer" + + def _worker(self, idx, node): + with self.lock: + if node not in self.event_handlers: + self.event_handlers[node] = ConsumerEventHandler(node, self.verify_offsets, idx) + handler = self.event_handlers[node] + + node.account.ssh("mkdir -p %s" % VerifiableConsumer.PERSISTENT_ROOT, allow_fail=False) + + # Create and upload log properties + log_config = self.render('tools_log4j.properties', log_file=VerifiableConsumer.LOG_FILE) + 
node.account.create_file(VerifiableConsumer.LOG4J_CONFIG, log_config) + + # Create and upload config file + self.security_config = self.kafka.security_config.client_config(self.prop_file, node, + self.jaas_override_variables) + self.security_config.setup_node(node) + self.prop_file += str(self.security_config) + self.logger.info("verifiable_consumer.properties:") + self.logger.info(self.prop_file) + node.account.create_file(VerifiableConsumer.CONFIG_FILE, self.prop_file) + self.security_config.setup_node(node) + # apply group.instance.id to the node for static membership validation + node.group_instance_id = None + if self.static_membership: + assert node.version >= V_2_3_0, \ + "Version %s does not support static membership (must be 2.3 or higher)" % str(node.version) + node.group_instance_id = self.group_id + "-instance-" + str(idx) + + if self.assignment_strategy: + assert node.version >= V_0_10_0_0, \ + "Version %s does not setting an assignment strategy (must be 0.10.0 or higher)" % str(node.version) + + cmd = self.start_cmd(node) + self.logger.debug("VerifiableConsumer %d command: %s" % (idx, cmd)) + + for line in node.account.ssh_capture(cmd): + event = self.try_parse_json(node, line.strip()) + if event is not None: + with self.lock: + name = event["name"] + if name == "shutdown_complete": + handler.handle_shutdown_complete() + elif name == "startup_complete": + handler.handle_startup_complete() + elif name == "offsets_committed": + handler.handle_offsets_committed(event, node, self.logger) + self._update_global_committed(event) + elif name == "records_consumed": + handler.handle_records_consumed(event, self.logger) + self._update_global_position(event, node) + elif name == "record_data" and self.on_record_consumed: + self.on_record_consumed(event, node) + elif name == "partitions_revoked": + handler.handle_partitions_revoked(event) + elif name == "partitions_assigned": + handler.handle_partitions_assigned(event) + else: + self.logger.debug("%s: ignoring unknown event: %s" % (str(node.account), event)) + + def _update_global_position(self, consumed_event, node): + for consumed_partition in consumed_event["partitions"]: + tp = TopicPartition(consumed_partition["topic"], consumed_partition["partition"]) + if tp in self.global_committed: + # verify that the position never gets behind the current commit. 
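+                # A minOffset below the committed offset means records that were already
+                # committed are being consumed again; this is fatal only when
+                # verify_offsets is enabled, otherwise it is logged as a warning.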
+ if self.global_committed[tp] > consumed_partition["minOffset"]: + msg = "Consumed position %d is behind the current committed offset %d for partition %s" % \ + (consumed_partition["minOffset"], self.global_committed[tp], str(tp)) + if self.verify_offsets: + raise AssertionError(msg) + else: + self.logger.warn(msg) + + # the consumer cannot generally guarantee that the position increases monotonically + # without gaps in the face of hard failures, so we only log a warning when this happens + if tp in self.global_position and self.global_position[tp] != consumed_partition["minOffset"]: + self.logger.warn("%s: Expected next consumed offset of %d for partition %s, but instead saw %d" % + (str(node.account), self.global_position[tp], str(tp), consumed_partition["minOffset"])) + + self.global_position[tp] = consumed_partition["maxOffset"] + 1 + + def _update_global_committed(self, commit_event): + if commit_event["success"]: + for offset_commit in commit_event["offsets"]: + tp = TopicPartition(offset_commit["topic"], offset_commit["partition"]) + offset = offset_commit["offset"] + assert self.global_position[tp] >= offset, \ + "Committed offset %d for partition %s is ahead of the current position %d" % \ + (offset, str(tp), self.global_position[tp]) + self.global_committed[tp] = offset + + def start_cmd(self, node): + cmd = "" + cmd += "export LOG_DIR=%s;" % VerifiableConsumer.LOG_DIR + cmd += " export KAFKA_OPTS=%s;" % self.security_config.kafka_opts + cmd += " export KAFKA_LOG4J_OPTS=\"-Dlog4j.configuration=file:%s\"; " % VerifiableConsumer.LOG4J_CONFIG + cmd += self.impl.exec_cmd(node) + if self.on_record_consumed: + cmd += " --verbose" + + if node.group_instance_id: + cmd += " --group-instance-id %s" % node.group_instance_id + elif node.version == V_2_3_0 or node.version == V_2_3_1: + # In 2.3, --group-instance-id was required, but would be left empty + # if `None` is passed as the argument value + cmd += " --group-instance-id None" + + if self.assignment_strategy: + cmd += " --assignment-strategy %s" % self.assignment_strategy + + if self.enable_autocommit: + cmd += " --enable-autocommit " + + cmd += " --reset-policy %s --group-id %s --topic %s --broker-list %s --session-timeout %s" % \ + (self.reset_policy, self.group_id, self.topic, + self.kafka.bootstrap_servers(self.security_config.security_protocol), + self.session_timeout_sec*1000) + + if self.max_messages > 0: + cmd += " --max-messages %s" % str(self.max_messages) + + cmd += " --consumer.config %s" % VerifiableConsumer.CONFIG_FILE + cmd += " 2>> %s | tee -a %s &" % (VerifiableConsumer.STDOUT_CAPTURE, VerifiableConsumer.STDOUT_CAPTURE) + return cmd + + def pids(self, node): + return self.impl.pids(node) + + def try_parse_json(self, node, string): + """Try to parse a string as json. 
Return None if not parseable.""" + try: + return json.loads(string) + except ValueError: + self.logger.debug("%s: Could not parse as json: %s" % (str(node.account), str(string))) + return None + + def stop_all(self): + for node in self.nodes: + self.stop_node(node) + + def kill_node(self, node, clean_shutdown=True, allow_fail=False): + sig = self.impl.kill_signal(clean_shutdown) + for pid in self.pids(node): + node.account.signal(pid, sig, allow_fail) + + with self.lock: + self.event_handlers[node].handle_kill_process(clean_shutdown) + + def stop_node(self, node, clean_shutdown=True): + self.kill_node(node, clean_shutdown=clean_shutdown) + + stopped = self.wait_node(node, timeout_sec=self.stop_timeout_sec) + assert stopped, "Node %s: did not stop within the specified timeout of %s seconds" % \ + (str(node.account), str(self.stop_timeout_sec)) + + def clean_node(self, node): + self.kill_node(node, clean_shutdown=False) + node.account.ssh("rm -rf " + self.PERSISTENT_ROOT, allow_fail=False) + self.security_config.clean_node(node) + + def current_assignment(self): + with self.lock: + return { handler.node: handler.current_assignment() for handler in self.event_handlers.values() } + + def current_position(self, tp): + with self.lock: + if tp in self.global_position: + return self.global_position[tp] + else: + return None + + def owner(self, tp): + with self.lock: + for handler in self.event_handlers.values(): + if tp in handler.current_assignment(): + return handler.node + return None + + def last_commit(self, tp): + with self.lock: + if tp in self.global_committed: + return self.global_committed[tp] + else: + return None + + def total_consumed(self): + with self.lock: + return sum(handler.total_consumed for handler in self.event_handlers.values()) + + def num_rebalances(self): + with self.lock: + return max(handler.assigned_count for handler in self.event_handlers.values()) + + def num_revokes_for_alive(self, keep_alive=1): + with self.lock: + return max(handler.revoked_count for handler in self.event_handlers.values() + if handler.idx <= keep_alive) + + def joined_nodes(self): + with self.lock: + return [handler.node for handler in self.event_handlers.values() + if handler.state == ConsumerState.Joined] + + def rebalancing_nodes(self): + with self.lock: + return [handler.node for handler in self.event_handlers.values() + if handler.state == ConsumerState.Rebalancing] + + def dead_nodes(self): + with self.lock: + return [handler.node for handler in self.event_handlers.values() + if handler.state == ConsumerState.Dead] + + def alive_nodes(self): + with self.lock: + return [handler.node for handler in self.event_handlers.values() + if handler.state != ConsumerState.Dead] diff --git a/tests/kafkatest/services/verifiable_producer.py b/tests/kafkatest/services/verifiable_producer.py new file mode 100644 index 0000000..a49c91c --- /dev/null +++ b/tests/kafkatest/services/verifiable_producer.py @@ -0,0 +1,317 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os + +import time +from ducktape.cluster.remoteaccount import RemoteCommandError +from ducktape.services.background_thread import BackgroundThreadService +from kafkatest.directory_layout.kafka_path import KafkaPathResolverMixin +from kafkatest.services.kafka import TopicPartition +from kafkatest.services.verifiable_client import VerifiableClientMixin +from kafkatest.utils import is_int, is_int_with_prefix +from kafkatest.version import DEV_BRANCH +from kafkatest.services.kafka.util import fix_opts_for_new_jvm + + +class VerifiableProducer(KafkaPathResolverMixin, VerifiableClientMixin, BackgroundThreadService): + """This service wraps org.apache.kafka.tools.VerifiableProducer for use in + system testing. + + NOTE: this class should be treated as a PUBLIC API. Downstream users use + this service both directly and through class extension, so care must be + taken to ensure compatibility. + """ + + PERSISTENT_ROOT = "/mnt/verifiable_producer" + STDOUT_CAPTURE = os.path.join(PERSISTENT_ROOT, "verifiable_producer.stdout") + STDERR_CAPTURE = os.path.join(PERSISTENT_ROOT, "verifiable_producer.stderr") + LOG_DIR = os.path.join(PERSISTENT_ROOT, "logs") + LOG_FILE = os.path.join(LOG_DIR, "verifiable_producer.log") + LOG4J_CONFIG = os.path.join(PERSISTENT_ROOT, "tools-log4j.properties") + CONFIG_FILE = os.path.join(PERSISTENT_ROOT, "verifiable_producer.properties") + + logs = { + "verifiable_producer_stdout": { + "path": STDOUT_CAPTURE, + "collect_default": False}, + "verifiable_producer_stderr": { + "path": STDERR_CAPTURE, + "collect_default": False}, + "verifiable_producer_log": { + "path": LOG_FILE, + "collect_default": True} + } + + def __init__(self, context, num_nodes, kafka, topic, max_messages=-1, throughput=100000, + message_validator=is_int, compression_types=None, version=DEV_BRANCH, acks=None, + stop_timeout_sec=150, request_timeout_sec=30, log_level="INFO", + enable_idempotence=False, offline_nodes=[], create_time=-1, repeating_keys=None, + jaas_override_variables=None, kafka_opts_override="", client_prop_file_override="", + retries=None): + """ + Args: + :param max_messages number of messages to be produced per producer + :param message_validator checks for an expected format of messages produced. There are + currently two: + * is_int is an integer format; this is default and expected to be used if + num_nodes = 1 + * is_int_with_prefix recommended if num_nodes > 1, because otherwise each producer + will produce exactly same messages, and validation may miss missing messages. + :param compression_types If None, all producers will not use compression; or a list of compression types, + one per producer (could be "none"). 
+ :param jaas_override_variables A dict of variables to be used in the jaas.conf template file + :param kafka_opts_override Override parameters of the KAFKA_OPTS environment variable + :param client_prop_file_override Override client.properties file used by the consumer + """ + super(VerifiableProducer, self).__init__(context, num_nodes) + self.log_level = log_level + + self.kafka = kafka + self.topic = topic + self.max_messages = max_messages + self.throughput = throughput + self.message_validator = message_validator + self.compression_types = compression_types + if self.compression_types is not None: + assert len(self.compression_types) == num_nodes, "Specify one compression type per node" + + for node in self.nodes: + node.version = version + self.acked_values = [] + self.acked_values_by_partition = {} + self._last_acked_offsets = {} + self.not_acked_values = [] + self.produced_count = {} + self.clean_shutdown_nodes = set() + self.acks = acks + self.stop_timeout_sec = stop_timeout_sec + self.request_timeout_sec = request_timeout_sec + self.enable_idempotence = enable_idempotence + self.offline_nodes = offline_nodes + self.create_time = create_time + self.repeating_keys = repeating_keys + self.jaas_override_variables = jaas_override_variables or {} + self.kafka_opts_override = kafka_opts_override + self.client_prop_file_override = client_prop_file_override + self.retries = retries + + def java_class_name(self): + return "VerifiableProducer" + + def prop_file(self, node): + idx = self.idx(node) + prop_file = self.render('producer.properties', request_timeout_ms=(self.request_timeout_sec * 1000)) + prop_file += "\n{}".format(str(self.security_config)) + if self.compression_types is not None: + compression_index = idx - 1 + self.logger.info("VerifiableProducer (index = %d) will use compression type = %s", idx, + self.compression_types[compression_index]) + prop_file += "\ncompression.type=%s\n" % self.compression_types[compression_index] + return prop_file + + def _worker(self, idx, node): + node.account.ssh("mkdir -p %s" % VerifiableProducer.PERSISTENT_ROOT, allow_fail=False) + + # Create and upload log properties + log_config = self.render('tools_log4j.properties', log_file=VerifiableProducer.LOG_FILE) + node.account.create_file(VerifiableProducer.LOG4J_CONFIG, log_config) + + # Configure security + self.security_config = self.kafka.security_config.client_config(node=node, + jaas_override_variables=self.jaas_override_variables) + self.security_config.setup_node(node) + + # Create and upload config file + if self.client_prop_file_override: + producer_prop_file = self.client_prop_file_override + else: + producer_prop_file = self.prop_file(node) + + if self.acks is not None: + self.logger.info("VerifiableProducer (index = %d) will use acks = %s", idx, self.acks) + producer_prop_file += "\nacks=%s\n" % self.acks + + if self.enable_idempotence: + self.logger.info("Setting up an idempotent producer") + producer_prop_file += "\nmax.in.flight.requests.per.connection=5\n" + producer_prop_file += "\nretries=1000000\n" + producer_prop_file += "\nenable.idempotence=true\n" + elif self.retries is not None: + self.logger.info("VerifiableProducer (index = %d) will use retries = %s", idx, self.retries) + producer_prop_file += "\nretries=%s\n" % self.retries + producer_prop_file += "\ndelivery.timeout.ms=%s\n" % (self.request_timeout_sec * 1000 * self.retries) + + self.logger.info("verifiable_producer.properties:") + self.logger.info(producer_prop_file) + 
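            # Editor's note, not part of the patch: for reference, the idempotence branch above
            # appends the following to the rendered producer.properties template:
            #     max.in.flight.requests.per.connection=5
            #     retries=1000000
            #     enable.idempotence=true
            # while the explicit-retries branch appends retries=<n> plus a delivery.timeout.ms of
            # request_timeout_sec * 1000 * retries, presumably so the delivery timeout is large
            # enough to cover every configured retry.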
node.account.create_file(VerifiableProducer.CONFIG_FILE, producer_prop_file) + + cmd = self.start_cmd(node, idx) + self.logger.debug("VerifiableProducer %d command: %s" % (idx, cmd)) + + self.produced_count[idx] = 0 + last_produced_time = time.time() + prev_msg = None + + for line in node.account.ssh_capture(cmd): + line = line.strip() + + data = self.try_parse_json(line) + if data is not None: + + with self.lock: + if data["name"] == "producer_send_error": + data["node"] = idx + self.not_acked_values.append(self.message_validator(data["value"])) + self.produced_count[idx] += 1 + + elif data["name"] == "producer_send_success": + partition = TopicPartition(data["topic"], data["partition"]) + value = self.message_validator(data["value"]) + self.acked_values.append(value) + + if partition not in self.acked_values_by_partition: + self.acked_values_by_partition[partition] = [] + self.acked_values_by_partition[partition].append(value) + + self._last_acked_offsets[partition] = data["offset"] + self.produced_count[idx] += 1 + + # Log information if there is a large gap between successively acknowledged messages + t = time.time() + time_delta_sec = t - last_produced_time + if time_delta_sec > 2 and prev_msg is not None: + self.logger.debug( + "Time delta between successively acked messages is large: " + + "delta_t_sec: %s, prev_message: %s, current_message: %s" % (str(time_delta_sec), str(prev_msg), str(data))) + + last_produced_time = t + prev_msg = data + + elif data["name"] == "shutdown_complete": + if node in self.clean_shutdown_nodes: + raise Exception("Unexpected shutdown event from producer, already shutdown. Producer index: %d" % idx) + self.clean_shutdown_nodes.add(node) + + def _has_output(self, node): + """Helper used as a proxy to determine whether jmx is running by that jmx_tool_log contains output.""" + try: + node.account.ssh("test -z \"$(cat %s)\"" % VerifiableProducer.STDOUT_CAPTURE, allow_fail=False) + return False + except RemoteCommandError: + return True + + def start_cmd(self, node, idx): + cmd = "export LOG_DIR=%s;" % VerifiableProducer.LOG_DIR + if self.kafka_opts_override: + cmd += " export KAFKA_OPTS=\"%s\";" % self.kafka_opts_override + else: + cmd += " export KAFKA_OPTS=%s;" % self.security_config.kafka_opts + + cmd += fix_opts_for_new_jvm(node) + cmd += " export KAFKA_LOG4J_OPTS=\"-Dlog4j.configuration=file:%s\"; " % VerifiableProducer.LOG4J_CONFIG + cmd += self.impl.exec_cmd(node) + cmd += " --topic %s --broker-list %s" % (self.topic, self.kafka.bootstrap_servers(self.security_config.security_protocol, True, self.offline_nodes)) + if self.max_messages > 0: + cmd += " --max-messages %s" % str(self.max_messages) + if self.throughput > 0: + cmd += " --throughput %s" % str(self.throughput) + if self.message_validator == is_int_with_prefix: + cmd += " --value-prefix %s" % str(idx) + if self.acks is not None: + cmd += " --acks %s " % str(self.acks) + if self.create_time > -1: + cmd += " --message-create-time %s " % str(self.create_time) + if self.repeating_keys is not None: + cmd += " --repeating-keys %s " % str(self.repeating_keys) + + cmd += " --producer.config %s" % VerifiableProducer.CONFIG_FILE + + cmd += " 2>> %s | tee -a %s &" % (VerifiableProducer.STDOUT_CAPTURE, VerifiableProducer.STDOUT_CAPTURE) + return cmd + + def kill_node(self, node, clean_shutdown=True, allow_fail=False): + sig = self.impl.kill_signal(clean_shutdown) + for pid in self.pids(node): + node.account.signal(pid, sig, allow_fail) + + def pids(self, node): + return self.impl.pids(node) + + def 
alive(self, node): + return len(self.pids(node)) > 0 + + @property + def last_acked_offsets(self): + with self.lock: + return self._last_acked_offsets + + @property + def acked(self): + with self.lock: + return self.acked_values + + @property + def acked_by_partition(self): + with self.lock: + return self.acked_values_by_partition + + @property + def not_acked(self): + with self.lock: + return self.not_acked_values + + @property + def num_acked(self): + with self.lock: + return len(self.acked_values) + + @property + def num_not_acked(self): + with self.lock: + return len(self.not_acked_values) + + def each_produced_at_least(self, count): + with self.lock: + for idx in range(1, self.num_nodes + 1): + if self.produced_count.get(idx) is None or self.produced_count[idx] < count: + return False + return True + + def stop_node(self, node): + # There is a race condition on shutdown if using `max_messages` since the + # VerifiableProducer will shutdown automatically when all messages have been + # written. In this case, the process will be gone and the signal will fail. + allow_fail = self.max_messages > 0 + self.kill_node(node, clean_shutdown=True, allow_fail=allow_fail) + + stopped = self.wait_node(node, timeout_sec=self.stop_timeout_sec) + assert stopped, "Node %s: did not stop within the specified timeout of %s seconds" % \ + (str(node.account), str(self.stop_timeout_sec)) + + def clean_node(self, node): + self.kill_node(node, clean_shutdown=False, allow_fail=False) + node.account.ssh("rm -rf " + self.PERSISTENT_ROOT, allow_fail=False) + self.security_config.clean_node(node) + + def try_parse_json(self, string): + """Try to parse a string as json. Return None if not parseable.""" + try: + record = json.loads(string) + return record + except ValueError: + self.logger.debug("Could not parse as json: %s" % str(string)) + return None diff --git a/tests/kafkatest/services/zookeeper.py b/tests/kafkatest/services/zookeeper.py new file mode 100644 index 0000000..c9d86e8 --- /dev/null +++ b/tests/kafkatest/services/zookeeper.py @@ -0,0 +1,256 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
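The VerifiableProducer service above is driven from tests through its thread-safe accessors (each_produced_at_least, num_acked, acked, not_acked), all of which read shared state under the service lock. Purely as an illustration, and not part of the patch itself, a ducktape test could use it roughly as follows; the helper name, topic name, and record count are invented for this sketch:

    from ducktape.utils.util import wait_until
    from kafkatest.services.verifiable_producer import VerifiableProducer

    def produce_some_records(test, kafka, topic="sketch-topic", num_records=1000):
        # One producer node writing a bounded number of integer-formatted records.
        producer = VerifiableProducer(test.test_context, num_nodes=1, kafka=kafka,
                                      topic=topic, max_messages=num_records, throughput=1000)
        producer.start()
        # each_produced_at_least() is safe to poll from the test thread while the
        # background worker parses producer_send_success / producer_send_error events.
        wait_until(lambda: producer.each_produced_at_least(num_records),
                   timeout_sec=180, backoff_sec=1,
                   err_msg="Producer did not produce all messages in a reasonable amount of time")
        producer.stop()
        # With no send errors every record should have been acked; real tests typically
        # compare producer.acked against what a consumer actually read.
        return producer.acked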
+ + +import os +import re + +from ducktape.services.service import Service +from ducktape.utils.util import wait_until +from ducktape.cluster.remoteaccount import RemoteCommandError + +from kafkatest.directory_layout.kafka_path import KafkaPathResolverMixin +from kafkatest.services.security.security_config import SecurityConfig +from kafkatest.version import DEV_BRANCH + + +class ZookeeperService(KafkaPathResolverMixin, Service): + ROOT = "/mnt/zookeeper" + DATA = os.path.join(ROOT, "data") + HEAP_DUMP_FILE = os.path.join(ROOT, "zk_heap_dump.bin") + + logs = { + "zk_log": { + "path": "%s/zk.log" % ROOT, + "collect_default": True}, + "zk_data": { + "path": DATA, + "collect_default": False}, + "zk_heap_dump_file": { + "path": HEAP_DUMP_FILE, + "collect_default": True} + } + + def __init__(self, context, num_nodes, zk_sasl = False, zk_client_port = True, zk_client_secure_port = False, + zk_tls_encrypt_only = False, version=DEV_BRANCH): + """ + :type context + """ + self.kafka_opts = "" + self.zk_sasl = zk_sasl + if (zk_client_secure_port or zk_tls_encrypt_only) and not version.supports_tls_to_zookeeper(): + raise Exception("Cannot use TLS with a ZooKeeper version that does not support it: %s" % str(version)) + if not zk_client_port and not zk_client_secure_port: + raise Exception("Cannot disable both ZK clientPort and clientSecurePort") + self.zk_client_port = zk_client_port + self.zk_client_secure_port = zk_client_secure_port + self.zk_tls_encrypt_only = zk_tls_encrypt_only + super(ZookeeperService, self).__init__(context, num_nodes) + self.set_version(version) + + def set_version(self, version): + for node in self.nodes: + node.version = version + + @property + def security_config(self): + return SecurityConfig(self.context, zk_sasl=self.zk_sasl, zk_tls=self.zk_client_secure_port) + + @property + def security_system_properties(self): + return "-Dzookeeper.authProvider.sasl=org.apache.zookeeper.server.auth.SASLAuthenticationProvider " \ + "-DjaasLoginRenew=3600000 " \ + "-Djava.security.auth.login.config=%s " \ + "-Djava.security.krb5.conf=%s " % (self.security_config.JAAS_CONF_PATH, self.security_config.KRB5CONF_PATH) + + @property + def zk_principals(self): + return " zkclient " + ' '.join(['zookeeper/' + zk_node.account.hostname for zk_node in self.nodes]) + + def restart_cluster(self): + for node in self.nodes: + self.restart_node(node) + + def restart_node(self, node): + """Restart the given node.""" + self.stop_node(node) + self.start_node(node) + + def start_node(self, node): + idx = self.idx(node) + self.logger.info("Starting ZK node %d on %s", idx, node.account.hostname) + + node.account.ssh("mkdir -p %s" % ZookeeperService.DATA) + node.account.ssh("echo %d > %s/myid" % (idx, ZookeeperService.DATA)) + + self.security_config.setup_node(node) + config_file = self.render('zookeeper.properties') + self.logger.info("zookeeper.properties:") + self.logger.info(config_file) + node.account.create_file("%s/zookeeper.properties" % ZookeeperService.ROOT, config_file) + + heap_kafka_opts = "-XX:+HeapDumpOnOutOfMemoryError -XX:HeapDumpPath=%s" % self.logs["zk_heap_dump_file"]["path"] + other_kafka_opts = self.kafka_opts + ' ' + self.security_system_properties \ + if self.security_config.zk_sasl else self.kafka_opts + start_cmd = "export KAFKA_OPTS=\"%s %s\";" % (heap_kafka_opts, other_kafka_opts) + start_cmd += "%s " % self.path.script("zookeeper-server-start.sh", node) + start_cmd += "%s/zookeeper.properties &>> %s &" % (ZookeeperService.ROOT, self.logs["zk_log"]["path"]) + 
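        # Editor's note, not part of the patch: without SASL the command assembled above
        # resolves to roughly the following (the script path depends on the path resolver):
        #     export KAFKA_OPTS="-XX:+HeapDumpOnOutOfMemoryError -XX:HeapDumpPath=/mnt/zookeeper/zk_heap_dump.bin ";
        #     <dist>/bin/zookeeper-server-start.sh /mnt/zookeeper/zookeeper.properties &>> /mnt/zookeeper/zk.log &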
node.account.ssh(start_cmd) + + wait_until(lambda: self.listening(node), timeout_sec=30, err_msg="Zookeeper node failed to start") + + def listening(self, node): + try: + port = 2181 if self.zk_client_port else 2182 + cmd = "nc -z %s %s" % (node.account.hostname, port) + node.account.ssh_output(cmd, allow_fail=False) + self.logger.debug("Zookeeper started accepting connections at: '%s:%s')", node.account.hostname, port) + return True + except (RemoteCommandError, ValueError) as e: + return False + + def pids(self, node): + return node.account.java_pids(self.java_class_name()) + + def alive(self, node): + return len(self.pids(node)) > 0 + + def stop_node(self, node): + idx = self.idx(node) + self.logger.info("Stopping %s node %d on %s" % (type(self).__name__, idx, node.account.hostname)) + node.account.kill_java_processes(self.java_class_name(), allow_fail=False) + node.account.kill_java_processes(self.java_cli_class_name(), allow_fail=False) + wait_until(lambda: not self.alive(node), timeout_sec=5, err_msg="Timed out waiting for zookeeper to stop.") + + def clean_node(self, node): + self.logger.info("Cleaning ZK node %d on %s", self.idx(node), node.account.hostname) + if self.alive(node): + self.logger.warn("%s %s was still alive at cleanup time. Killing forcefully..." % + (self.__class__.__name__, node.account)) + node.account.kill_java_processes(self.java_class_name(), + clean_shutdown=False, allow_fail=True) + node.account.kill_java_processes(self.java_cli_class_name(), + clean_shutdown=False, allow_fail=False) + node.account.ssh("rm -rf -- %s" % ZookeeperService.ROOT, allow_fail=False) + + + # force_tls is a necessary option for the case where we define both encrypted and non-encrypted ports + def connect_setting(self, chroot=None, force_tls=False): + if chroot and not chroot.startswith("/"): + raise Exception("ZK chroot must start with '/', invalid chroot: %s" % chroot) + + chroot = '' if chroot is None else chroot + return ','.join([node.account.hostname + (':2182' if not self.zk_client_port or force_tls else ':2181') + chroot + for node in self.nodes]) + + def zkTlsConfigFileOption(self, forZooKeeperMain=False): + if not self.zk_client_secure_port: + return "" + return ("-zk-tls-config-file " if forZooKeeperMain else "--zk-tls-config-file ") + \ + (SecurityConfig.ZK_CLIENT_TLS_ENCRYPT_ONLY_CONFIG_PATH if self.zk_tls_encrypt_only else SecurityConfig.ZK_CLIENT_MUTUAL_AUTH_CONFIG_PATH) + + # + # This call is used to simulate a rolling upgrade to enable/disable + # the use of ZooKeeper ACLs. 
+ # + def zookeeper_migration(self, node, zk_acl): + la_migra_cmd = "export KAFKA_OPTS=\"%s\";" % \ + self.security_system_properties if self.security_config.zk_sasl else "" + la_migra_cmd += "%s --zookeeper.acl=%s --zookeeper.connect=%s %s" % \ + (self.path.script("zookeeper-security-migration.sh", node), zk_acl, + self.connect_setting(force_tls=self.zk_client_secure_port), + self.zkTlsConfigFileOption()) + node.account.ssh(la_migra_cmd) + + def _check_chroot(self, chroot): + if chroot and not chroot.startswith("/"): + raise Exception("ZK chroot must start with '/', invalid chroot: %s" % chroot) + + def query(self, path, chroot=None): + """ + Queries zookeeper for data associated with 'path' and returns all fields in the schema + """ + self._check_chroot(chroot) + + chroot_path = ('' if chroot is None else chroot) + path + + kafka_run_class = self.path.script("kafka-run-class.sh", DEV_BRANCH) + cmd = "%s %s -server %s %s get %s" % \ + (kafka_run_class, self.java_cli_class_name(), self.connect_setting(force_tls=self.zk_client_secure_port), + self.zkTlsConfigFileOption(True), + chroot_path) + self.logger.debug(cmd) + + node = self.nodes[0] + result = None + for line in node.account.ssh_capture(cmd, allow_fail=True): + # loop through all lines in the output, but only hold on to the first match + if result is None: + match = re.match("^({.+})$", line) + if match is not None: + result = match.groups()[0] + return result + + def create(self, path, chroot=None, value=""): + """ + Create an znode at the given path + """ + self._check_chroot(chroot) + + chroot_path = ('' if chroot is None else chroot) + path + + kafka_run_class = self.path.script("kafka-run-class.sh", DEV_BRANCH) + cmd = "%s %s -server %s %s create %s '%s'" % \ + (kafka_run_class, self.java_cli_class_name(), self.connect_setting(force_tls=self.zk_client_secure_port), + self.zkTlsConfigFileOption(True), + chroot_path, value) + self.logger.debug(cmd) + output = self.nodes[0].account.ssh_output(cmd) + self.logger.debug(output) + + def describeUsers(self): + """ + Describe the default user using the ConfigCommand CLI + """ + + kafka_run_class = self.path.script("kafka-run-class.sh", DEV_BRANCH) + cmd = "%s kafka.admin.ConfigCommand --zookeeper %s %s --describe --entity-type users --entity-default" % \ + (kafka_run_class, self.connect_setting(force_tls=self.zk_client_secure_port), + self.zkTlsConfigFileOption()) + self.logger.debug(cmd) + output = self.nodes[0].account.ssh_output(cmd) + self.logger.debug(output) + + def list_acls(self, topic): + """ + List ACLs for the given topic using the AclCommand CLI + """ + + kafka_run_class = self.path.script("kafka-run-class.sh", DEV_BRANCH) + cmd = "%s kafka.admin.AclCommand --authorizer-properties zookeeper.connect=%s %s --list --topic %s" % \ + (kafka_run_class, self.connect_setting(force_tls=self.zk_client_secure_port), + self.zkTlsConfigFileOption(), + topic) + self.logger.debug(cmd) + output = self.nodes[0].account.ssh_output(cmd) + self.logger.debug(output) + + def java_class_name(self): + """ The class name of the Zookeeper quorum peers. """ + return "org.apache.zookeeper.server.quorum.QuorumPeerMain" + + def java_cli_class_name(self): + """ The class name of the Zookeeper tool within Kafka. 
""" + return "org.apache.zookeeper.ZooKeeperMainWithTlsSupportForKafka" diff --git a/tests/kafkatest/tests/__init__.py b/tests/kafkatest/tests/__init__.py new file mode 100644 index 0000000..ec20143 --- /dev/null +++ b/tests/kafkatest/tests/__init__.py @@ -0,0 +1,14 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/kafkatest/tests/client/__init__.py b/tests/kafkatest/tests/client/__init__.py new file mode 100644 index 0000000..ec20143 --- /dev/null +++ b/tests/kafkatest/tests/client/__init__.py @@ -0,0 +1,14 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/kafkatest/tests/client/client_compatibility_features_test.py b/tests/kafkatest/tests/client/client_compatibility_features_test.py new file mode 100644 index 0000000..15b6a93 --- /dev/null +++ b/tests/kafkatest/tests/client/client_compatibility_features_test.py @@ -0,0 +1,137 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os + +import errno +import time +from random import randint + +from ducktape.mark import matrix, parametrize +from ducktape.mark.resource import cluster +from ducktape.tests.test import TestContext + +from kafkatest.services.zookeeper import ZookeeperService +from kafkatest.services.kafka import KafkaService, quorum +from ducktape.tests.test import Test +from kafkatest.version import DEV_BRANCH, LATEST_0_10_0, LATEST_0_10_1, LATEST_0_10_2, LATEST_0_11_0, LATEST_1_0, LATEST_1_1, LATEST_2_0, LATEST_2_1, LATEST_2_2, LATEST_2_3, LATEST_2_4, LATEST_2_5, LATEST_2_6, LATEST_2_7, LATEST_2_8, V_0_11_0_0, V_0_10_1_0, KafkaVersion + +def get_broker_features(broker_version): + features = {} + if broker_version < V_0_10_1_0: + features["create-topics-supported"] = False + features["offsets-for-times-supported"] = False + features["cluster-id-supported"] = False + features["expect-record-too-large-exception"] = True + else: + features["create-topics-supported"] = True + features["offsets-for-times-supported"] = True + features["cluster-id-supported"] = True + features["expect-record-too-large-exception"] = False + if broker_version < V_0_11_0_0: + features["describe-acls-supported"] = False + features["describe-configs-supported"] = False + else: + features["describe-acls-supported"] = True + features["describe-configs-supported"] = True + return features + +def run_command(node, cmd, ssh_log_file): + with open(ssh_log_file, 'w') as f: + f.write("Running %s\n" % cmd) + try: + for line in node.account.ssh_capture(cmd): + f.write(line) + except Exception as e: + f.write("** Command failed!") + print(e, flush=True) + raise + + +class ClientCompatibilityFeaturesTest(Test): + """ + Tests clients for the presence or absence of specific features when communicating with brokers with various + versions. Relies on ClientCompatibilityTest.java for much of the functionality. + """ + + def __init__(self, test_context): + """:type test_context: ducktape.tests.test.TestContext""" + super(ClientCompatibilityFeaturesTest, self).__init__(test_context=test_context) + + self.zk = ZookeeperService(test_context, num_nodes=3) if quorum.for_test(test_context) == quorum.zk else None + + # Generate a unique topic name + topic_name = "client_compat_features_topic_%d%d" % (int(time.time()), randint(0, 2147483647)) + self.topics = { topic_name: { + "partitions": 1, # Use only one partition to avoid worrying about ordering + "replication-factor": 3 + }} + self.kafka = KafkaService(test_context, num_nodes=3, zk=self.zk, topics=self.topics) + # Always use the latest version of org.apache.kafka.tools.ClientCompatibilityTest + # so store away the path to the DEV version before we set the Kafka version + self.dev_script_path = self.kafka.path.script("kafka-run-class.sh", self.kafka.nodes[0]) + + def invoke_compatibility_program(self, features): + # Run the compatibility test on the first Kafka node. 
+ node = self.kafka.nodes[0] + cmd = ("%s org.apache.kafka.tools.ClientCompatibilityTest " + "--bootstrap-server %s " + "--num-cluster-nodes %d " + "--topic %s " % (self.dev_script_path, + self.kafka.bootstrap_servers(), + len(self.kafka.nodes), + list(self.topics.keys())[0])) + for k, v in features.items(): + cmd = cmd + ("--%s %s " % (k, v)) + results_dir = TestContext.results_dir(self.test_context, 0) + try: + os.makedirs(results_dir) + except OSError as e: + if e.errno == errno.EEXIST and os.path.isdir(results_dir): + pass + else: + raise + ssh_log_file = "%s/%s" % (results_dir, "client_compatibility_test_output.txt") + try: + self.logger.info("Running %s" % cmd) + run_command(node, cmd, ssh_log_file) + except Exception as e: + self.logger.info("** Command failed. See %s for log messages." % ssh_log_file) + raise + + @cluster(num_nodes=7) + @matrix(broker_version=[str(DEV_BRANCH)], metadata_quorum=quorum.all_non_upgrade) + @parametrize(broker_version=str(LATEST_0_10_0)) + @parametrize(broker_version=str(LATEST_0_10_1)) + @parametrize(broker_version=str(LATEST_0_10_2)) + @parametrize(broker_version=str(LATEST_0_11_0)) + @parametrize(broker_version=str(LATEST_1_0)) + @parametrize(broker_version=str(LATEST_1_1)) + @parametrize(broker_version=str(LATEST_2_0)) + @parametrize(broker_version=str(LATEST_2_1)) + @parametrize(broker_version=str(LATEST_2_2)) + @parametrize(broker_version=str(LATEST_2_3)) + @parametrize(broker_version=str(LATEST_2_4)) + @parametrize(broker_version=str(LATEST_2_5)) + @parametrize(broker_version=str(LATEST_2_6)) + @parametrize(broker_version=str(LATEST_2_7)) + @parametrize(broker_version=str(LATEST_2_8)) + def run_compatibility_test(self, broker_version, metadata_quorum=quorum.zk): + if self.zk: + self.zk.start() + self.kafka.set_version(KafkaVersion(broker_version)) + self.kafka.start() + features = get_broker_features(broker_version) + self.invoke_compatibility_program(features) diff --git a/tests/kafkatest/tests/client/client_compatibility_produce_consume_test.py b/tests/kafkatest/tests/client/client_compatibility_produce_consume_test.py new file mode 100644 index 0000000..1a46746 --- /dev/null +++ b/tests/kafkatest/tests/client/client_compatibility_produce_consume_test.py @@ -0,0 +1,91 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
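In ClientCompatibilityFeaturesTest above, invoke_compatibility_program() turns the dict returned by get_broker_features() into one CLI flag per entry. A small standalone illustration of that mapping for a post-0.11 broker (not part of the patch):

    # Feature flags for a broker at or above 0.11.0, as computed by get_broker_features().
    features = {
        "create-topics-supported": True,
        "offsets-for-times-supported": True,
        "cluster-id-supported": True,
        "expect-record-too-large-exception": False,
        "describe-acls-supported": True,
        "describe-configs-supported": True,
    }

    # Mirrors: cmd = cmd + ("--%s %s " % (k, v)) for each entry.
    flags = "".join("--%s %s " % (k, v) for k, v in features.items())
    print(flags)
    # --create-topics-supported True --offsets-for-times-supported True ... --describe-configs-supported True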
+ +from ducktape.mark import matrix, parametrize +from ducktape.mark.resource import cluster +from ducktape.utils.util import wait_until + +from kafkatest.services.zookeeper import ZookeeperService +from kafkatest.services.kafka import KafkaService, quorum +from kafkatest.services.verifiable_producer import VerifiableProducer +from kafkatest.services.console_consumer import ConsoleConsumer +from kafkatest.tests.produce_consume_validate import ProduceConsumeValidateTest +from kafkatest.utils import is_int_with_prefix +from kafkatest.version import DEV_BRANCH, LATEST_0_10_0, LATEST_0_10_1, LATEST_0_10_2, LATEST_0_11_0, LATEST_1_0, LATEST_1_1, LATEST_2_0, LATEST_2_1, LATEST_2_2, LATEST_2_3, LATEST_2_4, LATEST_2_5, LATEST_2_6, LATEST_2_7, LATEST_2_8, KafkaVersion + +class ClientCompatibilityProduceConsumeTest(ProduceConsumeValidateTest): + """ + These tests validate that we can use a new client to produce and consume from older brokers. + """ + + def __init__(self, test_context): + """:type test_context: ducktape.tests.test.TestContext""" + super(ClientCompatibilityProduceConsumeTest, self).__init__(test_context=test_context) + + self.topic = "test_topic" + self.zk = ZookeeperService(test_context, num_nodes=3) if quorum.for_test(test_context) == quorum.zk else None + self.kafka = KafkaService(test_context, num_nodes=3, zk=self.zk, topics={self.topic:{ + "partitions": 10, + "replication-factor": 2}}) + self.num_partitions = 10 + self.timeout_sec = 60 + self.producer_throughput = 1000 + self.num_producers = 2 + self.messages_per_producer = 1000 + self.num_consumers = 1 + + def setUp(self): + if self.zk: + self.zk.start() + + def min_cluster_size(self): + # Override this since we're adding services outside of the constructor + return super(ClientCompatibilityProduceConsumeTest, self).min_cluster_size() + self.num_producers + self.num_consumers + + @cluster(num_nodes=9) + @matrix(broker_version=[str(DEV_BRANCH)], metadata_quorum=quorum.all_non_upgrade) + @parametrize(broker_version=str(LATEST_0_10_0)) + @parametrize(broker_version=str(LATEST_0_10_1)) + @parametrize(broker_version=str(LATEST_0_10_2)) + @parametrize(broker_version=str(LATEST_0_11_0)) + @parametrize(broker_version=str(LATEST_1_0)) + @parametrize(broker_version=str(LATEST_1_1)) + @parametrize(broker_version=str(LATEST_2_0)) + @parametrize(broker_version=str(LATEST_2_1)) + @parametrize(broker_version=str(LATEST_2_2)) + @parametrize(broker_version=str(LATEST_2_3)) + @parametrize(broker_version=str(LATEST_2_4)) + @parametrize(broker_version=str(LATEST_2_5)) + @parametrize(broker_version=str(LATEST_2_6)) + @parametrize(broker_version=str(LATEST_2_7)) + @parametrize(broker_version=str(LATEST_2_8)) + def test_produce_consume(self, broker_version, metadata_quorum=quorum.zk): + print("running producer_consumer_compat with broker_version = %s" % broker_version, flush=True) + self.kafka.set_version(KafkaVersion(broker_version)) + self.kafka.security_protocol = "PLAINTEXT" + self.kafka.interbroker_security_protocol = self.kafka.security_protocol + self.producer = VerifiableProducer(self.test_context, self.num_producers, self.kafka, + self.topic, throughput=self.producer_throughput, + message_validator=is_int_with_prefix) + self.consumer = ConsoleConsumer(self.test_context, self.num_consumers, self.kafka, self.topic, + consumer_timeout_ms=60000, + message_validator=is_int_with_prefix) + self.kafka.start() + + self.run_produce_consume_validate(lambda: wait_until( + lambda: self.producer.each_produced_at_least(self.messages_per_producer) == True, + 
timeout_sec=120, backoff_sec=1, + err_msg="Producer did not produce all messages in reasonable amount of time")) + diff --git a/tests/kafkatest/tests/client/compression_test.py b/tests/kafkatest/tests/client/compression_test.py new file mode 100644 index 0000000..37ce52d --- /dev/null +++ b/tests/kafkatest/tests/client/compression_test.py @@ -0,0 +1,88 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ducktape.mark import matrix +from ducktape.utils.util import wait_until +from ducktape.mark.resource import cluster + +from kafkatest.services.zookeeper import ZookeeperService +from kafkatest.services.kafka import KafkaService, quorum +from kafkatest.services.verifiable_producer import VerifiableProducer +from kafkatest.services.console_consumer import ConsoleConsumer +from kafkatest.tests.produce_consume_validate import ProduceConsumeValidateTest +from kafkatest.utils import is_int_with_prefix + + +class CompressionTest(ProduceConsumeValidateTest): + """ + These tests validate produce / consume for compressed topics. + """ + COMPRESSION_TYPES = ["snappy", "gzip", "lz4", "zstd", "none"] + + def __init__(self, test_context): + """:type test_context: ducktape.tests.test.TestContext""" + super(CompressionTest, self).__init__(test_context=test_context) + + self.topic = "test_topic" + self.zk = ZookeeperService(test_context, num_nodes=1) if quorum.for_test(test_context) == quorum.zk else None + self.kafka = KafkaService(test_context, num_nodes=1, zk=self.zk, topics={self.topic: { + "partitions": 10, + "replication-factor": 1}}) + self.num_partitions = 10 + self.timeout_sec = 60 + self.producer_throughput = 1000 + self.num_producers = len(self.COMPRESSION_TYPES) + self.messages_per_producer = 1000 + self.num_consumers = 1 + + def setUp(self): + if self.zk: + self.zk.start() + + def min_cluster_size(self): + # Override this since we're adding services outside of the constructor + return super(CompressionTest, self).min_cluster_size() + self.num_producers + self.num_consumers + + @cluster(num_nodes=8) + @matrix(compression_types=[COMPRESSION_TYPES], metadata_quorum=quorum.all_non_upgrade) + def test_compressed_topic(self, compression_types, metadata_quorum=quorum.zk): + """Test produce => consume => validate for compressed topics + Setup: 1 zk, 1 kafka node, 1 topic with partitions=10, replication-factor=1 + + compression_types parameter gives a list of compression types (or no compression if + "none"). Each producer in a VerifiableProducer group (num_producers = number of compression + types) will use a compression type from the list based on producer's index in the group. 
+ + - Produce messages in the background + - Consume messages in the background + - Stop producing, and finish consuming + - Validate that every acked message was consumed + """ + + self.kafka.security_protocol = "PLAINTEXT" + self.kafka.interbroker_security_protocol = self.kafka.security_protocol + self.producer = VerifiableProducer(self.test_context, self.num_producers, self.kafka, + self.topic, throughput=self.producer_throughput, + message_validator=is_int_with_prefix, + compression_types=compression_types) + self.consumer = ConsoleConsumer(self.test_context, self.num_consumers, self.kafka, self.topic, + consumer_timeout_ms=60000, message_validator=is_int_with_prefix) + self.kafka.start() + + self.run_produce_consume_validate(lambda: wait_until( + lambda: self.producer.each_produced_at_least(self.messages_per_producer) == True, + timeout_sec=120, backoff_sec=1, + err_msg="Producer did not produce all messages in reasonable amount of time")) + diff --git a/tests/kafkatest/tests/client/consumer_rolling_upgrade_test.py b/tests/kafkatest/tests/client/consumer_rolling_upgrade_test.py new file mode 100644 index 0000000..5beacf2 --- /dev/null +++ b/tests/kafkatest/tests/client/consumer_rolling_upgrade_test.py @@ -0,0 +1,88 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
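CompressionTest above starts one VerifiableProducer worker per compression type, and VerifiableProducer.prop_file() (defined earlier in this patch) selects the type for each worker by its 1-based node index. A standalone illustration, not part of the patch:

    compression_types = ["snappy", "gzip", "lz4", "zstd", "none"]

    for idx in range(1, len(compression_types) + 1):
        # Mirrors prop_file(): compression_index = idx - 1
        print("producer %d -> compression.type=%s" % (idx, compression_types[idx - 1]))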
+ +from ducktape.mark import matrix +from ducktape.mark.resource import cluster + + +from kafkatest.tests.verifiable_consumer_test import VerifiableConsumerTest +from kafkatest.services.kafka import TopicPartition, quorum + +class ConsumerRollingUpgradeTest(VerifiableConsumerTest): + TOPIC = "test_topic" + NUM_PARTITIONS = 4 + RANGE = "org.apache.kafka.clients.consumer.RangeAssignor" + ROUND_ROBIN = "org.apache.kafka.clients.consumer.RoundRobinAssignor" + + def __init__(self, test_context): + super(ConsumerRollingUpgradeTest, self).__init__(test_context, num_consumers=2, num_producers=0, + num_zk=1, num_brokers=1, topics={ + self.TOPIC : { 'partitions': self.NUM_PARTITIONS, 'replication-factor': 1 } + }) + + def _verify_range_assignment(self, consumer): + # range assignment should give us two partition sets: (0, 1) and (2, 3) + assignment = set([frozenset(partitions) for partitions in consumer.current_assignment().values()]) + assert assignment == set([ + frozenset([TopicPartition(self.TOPIC, 0), TopicPartition(self.TOPIC, 1)]), + frozenset([TopicPartition(self.TOPIC, 2), TopicPartition(self.TOPIC, 3)])]), \ + "Mismatched assignment: %s" % assignment + + def _verify_roundrobin_assignment(self, consumer): + assignment = set([frozenset(x) for x in consumer.current_assignment().values()]) + assert assignment == set([ + frozenset([TopicPartition(self.TOPIC, 0), TopicPartition(self.TOPIC, 2)]), + frozenset([TopicPartition(self.TOPIC, 1), TopicPartition(self.TOPIC, 3)])]), \ + "Mismatched assignment: %s" % assignment + + @cluster(num_nodes=4) + @matrix(metadata_quorum=quorum.all_non_upgrade) + def rolling_update_test(self, metadata_quorum=quorum.zk): + """ + Verify rolling updates of partition assignment strategies works correctly. In this + test, we use a rolling restart to change the group's assignment strategy from "range" + to "roundrobin." We verify after every restart that all members are still in the group + and that the correct assignment strategy was used. + """ + + # initialize the consumer using range assignment + consumer = self.setup_consumer(self.TOPIC, assignment_strategy=self.RANGE) + + consumer.start() + self.await_all_members(consumer) + self._verify_range_assignment(consumer) + + # change consumer configuration to prefer round-robin assignment, but still support range assignment + consumer.assignment_strategy = self.ROUND_ROBIN + "," + self.RANGE + + # restart one of the nodes and verify that we are still using range assignment + consumer.stop_node(consumer.nodes[0]) + consumer.start_node(consumer.nodes[0]) + self.await_all_members(consumer) + self._verify_range_assignment(consumer) + + # now restart the other node and verify that we have switched to round-robin + consumer.stop_node(consumer.nodes[1]) + consumer.start_node(consumer.nodes[1]) + self.await_all_members(consumer) + self._verify_roundrobin_assignment(consumer) + + # if we want, we can now drop support for range assignment + consumer.assignment_strategy = self.ROUND_ROBIN + for node in consumer.nodes: + consumer.stop_node(node) + consumer.start_node(node) + self.await_all_members(consumer) + self._verify_roundrobin_assignment(consumer) diff --git a/tests/kafkatest/tests/client/consumer_test.py b/tests/kafkatest/tests/client/consumer_test.py new file mode 100644 index 0000000..49e9331 --- /dev/null +++ b/tests/kafkatest/tests/client/consumer_test.py @@ -0,0 +1,473 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. 
See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ducktape.mark import matrix +from ducktape.utils.util import wait_until +from ducktape.mark.resource import cluster + +from kafkatest.tests.verifiable_consumer_test import VerifiableConsumerTest +from kafkatest.services.kafka import TopicPartition, quorum + +import signal + + +class OffsetValidationTest(VerifiableConsumerTest): + TOPIC = "test_topic" + NUM_PARTITIONS = 1 + + def __init__(self, test_context): + super(OffsetValidationTest, self).__init__(test_context, num_consumers=3, num_producers=1, + num_zk=1, num_brokers=2, topics={ + self.TOPIC : { 'partitions': self.NUM_PARTITIONS, 'replication-factor': 2 } + }) + + def rolling_bounce_consumers(self, consumer, keep_alive=0, num_bounces=5, clean_shutdown=True): + for _ in range(num_bounces): + for node in consumer.nodes[keep_alive:]: + consumer.stop_node(node, clean_shutdown) + + wait_until(lambda: len(consumer.dead_nodes()) == 1, + timeout_sec=self.session_timeout_sec+5, + err_msg="Timed out waiting for the consumer to shutdown") + + consumer.start_node(node) + + self.await_all_members(consumer) + self.await_consumed_messages(consumer) + + def bounce_all_consumers(self, consumer, keep_alive=0, num_bounces=5, clean_shutdown=True): + for _ in range(num_bounces): + for node in consumer.nodes[keep_alive:]: + consumer.stop_node(node, clean_shutdown) + + wait_until(lambda: len(consumer.dead_nodes()) == self.num_consumers - keep_alive, timeout_sec=10, + err_msg="Timed out waiting for the consumers to shutdown") + + for node in consumer.nodes[keep_alive:]: + consumer.start_node(node) + + self.await_all_members(consumer) + self.await_consumed_messages(consumer) + + def rolling_bounce_brokers(self, consumer, num_bounces=5, clean_shutdown=True): + for _ in range(num_bounces): + for node in self.kafka.nodes: + self.kafka.restart_node(node, clean_shutdown=True) + self.await_all_members(consumer) + self.await_consumed_messages(consumer) + + def setup_consumer(self, topic, **kwargs): + # collect verifiable consumer events since this makes debugging much easier + consumer = super(OffsetValidationTest, self).setup_consumer(topic, **kwargs) + self.mark_for_collect(consumer, 'verifiable_consumer_stdout') + return consumer + + @cluster(num_nodes=7) + @matrix(metadata_quorum=quorum.all_non_upgrade) + def test_broker_rolling_bounce(self, metadata_quorum=quorum.zk): + """ + Verify correct consumer behavior when the brokers are consecutively restarted. + + Setup: single Kafka cluster with one producer writing messages to a single topic with one + partition, an a set of consumers in the same group reading from the same topic. + + - Start a producer which continues producing new messages throughout the test. + - Start up the consumers and wait until they've joined the group. 
+ - In a loop, restart each broker consecutively, waiting for the group to stabilize between + each broker restart. + - Verify delivery semantics according to the failure type and that the broker bounces + did not cause unexpected group rebalances. + """ + partition = TopicPartition(self.TOPIC, 0) + + producer = self.setup_producer(self.TOPIC) + # The consumers' session timeouts must exceed the time it takes for a broker to roll. Consumers are likely + # to see cluster metadata consisting of just a single alive broker in the case where the cluster has just 2 + # brokers and the cluster is rolling (which is what is happening here). When the consumer sees a single alive + # broker, and then that broker rolls, the consumer will be unable to connect to the cluster until that broker + # completes its roll. In the meantime, the consumer group will move to the group coordinator on the other + # broker, and that coordinator will fail the consumer and trigger a group rebalance if its session times out. + # This test is asserting that no rebalances occur, so we increase the session timeout for this to be the case. + self.session_timeout_sec = 30 + consumer = self.setup_consumer(self.TOPIC) + + producer.start() + self.await_produced_messages(producer) + + consumer.start() + self.await_all_members(consumer) + + num_rebalances = consumer.num_rebalances() + # TODO: make this test work with hard shutdowns, which probably requires + # pausing before the node is restarted to ensure that any ephemeral + # nodes have time to expire + self.rolling_bounce_brokers(consumer, clean_shutdown=True) + + unexpected_rebalances = consumer.num_rebalances() - num_rebalances + assert unexpected_rebalances == 0, \ + "Broker rolling bounce caused %d unexpected group rebalances" % unexpected_rebalances + + consumer.stop_all() + + assert consumer.current_position(partition) == consumer.total_consumed(), \ + "Total consumed records %d did not match consumed position %d" % \ + (consumer.total_consumed(), consumer.current_position(partition)) + + @cluster(num_nodes=7) + @matrix(clean_shutdown=[True], bounce_mode=["all", "rolling"], metadata_quorum=quorum.all_non_upgrade) + def test_consumer_bounce(self, clean_shutdown, bounce_mode, metadata_quorum=quorum.zk): + """ + Verify correct consumer behavior when the consumers in the group are consecutively restarted. + + Setup: single Kafka cluster with one producer and a set of consumers in one group. + + - Start a producer which continues producing new messages throughout the test. + - Start up the consumers and wait until they've joined the group. + - In a loop, restart each consumer, waiting for each one to rejoin the group before + restarting the rest. + - Verify delivery semantics according to the failure type. 
+ """ + partition = TopicPartition(self.TOPIC, 0) + + producer = self.setup_producer(self.TOPIC) + consumer = self.setup_consumer(self.TOPIC) + + producer.start() + self.await_produced_messages(producer) + + consumer.start() + self.await_all_members(consumer) + + if bounce_mode == "all": + self.bounce_all_consumers(consumer, clean_shutdown=clean_shutdown) + else: + self.rolling_bounce_consumers(consumer, clean_shutdown=clean_shutdown) + + consumer.stop_all() + if clean_shutdown: + # if the total records consumed matches the current position, we haven't seen any duplicates + # this can only be guaranteed with a clean shutdown + assert consumer.current_position(partition) == consumer.total_consumed(), \ + "Total consumed records %d did not match consumed position %d" % \ + (consumer.total_consumed(), consumer.current_position(partition)) + else: + # we may have duplicates in a hard failure + assert consumer.current_position(partition) <= consumer.total_consumed(), \ + "Current position %d greater than the total number of consumed records %d" % \ + (consumer.current_position(partition), consumer.total_consumed()) + + @cluster(num_nodes=7) + @matrix(clean_shutdown=[True], static_membership=[True, False], bounce_mode=["all", "rolling"], num_bounces=[5], metadata_quorum=quorum.all_non_upgrade) + def test_static_consumer_bounce(self, clean_shutdown, static_membership, bounce_mode, num_bounces, metadata_quorum=quorum.zk): + """ + Verify correct static consumer behavior when the consumers in the group are restarted. In order to make + sure the behavior of static members are different from dynamic ones, we take both static and dynamic + membership into this test suite. + + Setup: single Kafka cluster with one producer and a set of consumers in one group. + + - Start a producer which continues producing new messages throughout the test. + - Start up the consumers as static/dynamic members and wait until they've joined the group. + - In a loop, restart each consumer except the first member (note: may not be the leader), and expect no rebalance triggered + during this process if the group is in static membership. + """ + partition = TopicPartition(self.TOPIC, 0) + + producer = self.setup_producer(self.TOPIC) + + producer.start() + self.await_produced_messages(producer) + + self.session_timeout_sec = 60 + consumer = self.setup_consumer(self.TOPIC, static_membership=static_membership) + + consumer.start() + self.await_all_members(consumer) + + num_revokes_before_bounce = consumer.num_revokes_for_alive() + + num_keep_alive = 1 + + if bounce_mode == "all": + self.bounce_all_consumers(consumer, keep_alive=num_keep_alive, num_bounces=num_bounces) + else: + self.rolling_bounce_consumers(consumer, keep_alive=num_keep_alive, num_bounces=num_bounces) + + num_revokes_after_bounce = consumer.num_revokes_for_alive() - num_revokes_before_bounce + + check_condition = num_revokes_after_bounce != 0 + # under static membership, the live consumer shall not revoke any current running partitions, + # since there is no global rebalance being triggered. 
+ if static_membership: + check_condition = num_revokes_after_bounce == 0 + + assert check_condition, \ + "Total revoked count %d does not match the expectation of having 0 revokes as %d" % \ + (num_revokes_after_bounce, check_condition) + + consumer.stop_all() + if clean_shutdown: + # if the total records consumed matches the current position, we haven't seen any duplicates + # this can only be guaranteed with a clean shutdown + assert consumer.current_position(partition) == consumer.total_consumed(), \ + "Total consumed records %d did not match consumed position %d" % \ + (consumer.total_consumed(), consumer.current_position(partition)) + else: + # we may have duplicates in a hard failure + assert consumer.current_position(partition) <= consumer.total_consumed(), \ + "Current position %d greater than the total number of consumed records %d" % \ + (consumer.current_position(partition), consumer.total_consumed()) + + @cluster(num_nodes=7) + @matrix(bounce_mode=["all", "rolling"], metadata_quorum=quorum.all_non_upgrade) + def test_static_consumer_persisted_after_rejoin(self, bounce_mode, metadata_quorum=quorum.zk): + """ + Verify that the updated member.id(updated_member_id) caused by static member rejoin would be persisted. If not, + after the brokers rolling bounce, the migrated group coordinator would load the stale persisted member.id and + fence subsequent static member rejoin with updated_member_id. + + - Start a producer which continues producing new messages throughout the test. + - Start up a static consumer and wait until it's up + - Restart the consumer and wait until it up, its member.id is supposed to be updated and persisted. + - Rolling bounce all the brokers and verify that the static consumer can still join the group and consumer messages. + """ + producer = self.setup_producer(self.TOPIC) + producer.start() + self.await_produced_messages(producer) + self.session_timeout_sec = 60 + consumer = self.setup_consumer(self.TOPIC, static_membership=True) + consumer.start() + self.await_all_members(consumer) + + # bounce the static member to trigger its member.id updated + if bounce_mode == "all": + self.bounce_all_consumers(consumer, num_bounces=1) + else: + self.rolling_bounce_consumers(consumer, num_bounces=1) + + # rolling bounce all the brokers to trigger the group coordinator migration and verify updated member.id is persisted + # and reloaded successfully + self.rolling_bounce_brokers(consumer, num_bounces=1) + + @cluster(num_nodes=10) + @matrix(num_conflict_consumers=[1, 2], fencing_stage=["stable", "all"], metadata_quorum=quorum.all_non_upgrade) + def test_fencing_static_consumer(self, num_conflict_consumers, fencing_stage, metadata_quorum=quorum.zk): + """ + Verify correct static consumer behavior when there are conflicting consumers with same group.instance.id. + + - Start a producer which continues producing new messages throughout the test. + - Start up the consumers as static members and wait until they've joined the group. Some conflict consumers will be configured with + - the same group.instance.id. + - Let normal consumers and fencing consumers start at the same time, and expect only unique consumers left. 
+ """ + partition = TopicPartition(self.TOPIC, 0) + + producer = self.setup_producer(self.TOPIC) + + producer.start() + self.await_produced_messages(producer) + + self.session_timeout_sec = 60 + consumer = self.setup_consumer(self.TOPIC, static_membership=True) + + self.num_consumers = num_conflict_consumers + conflict_consumer = self.setup_consumer(self.TOPIC, static_membership=True) + + # wait original set of consumer to stable stage before starting conflict members. + if fencing_stage == "stable": + consumer.start() + self.await_members(consumer, len(consumer.nodes)) + + conflict_consumer.start() + self.await_members(conflict_consumer, num_conflict_consumers) + self.await_members(consumer, len(consumer.nodes) - num_conflict_consumers) + + wait_until(lambda: len(consumer.dead_nodes()) == num_conflict_consumers, + timeout_sec=10, + err_msg="Timed out waiting for the fenced consumers to stop") + else: + consumer.start() + conflict_consumer.start() + + wait_until(lambda: len(consumer.joined_nodes()) + len(conflict_consumer.joined_nodes()) == len(consumer.nodes), + timeout_sec=self.session_timeout_sec, + err_msg="Timed out waiting for consumers to join, expected total %d joined, but only see %d joined from" + "normal consumer group and %d from conflict consumer group" % \ + (len(consumer.nodes), len(consumer.joined_nodes()), len(conflict_consumer.joined_nodes())) + ) + wait_until(lambda: len(consumer.dead_nodes()) + len(conflict_consumer.dead_nodes()) == len(conflict_consumer.nodes), + timeout_sec=self.session_timeout_sec, + err_msg="Timed out waiting for fenced consumers to die, expected total %d dead, but only see %d dead in" + "normal consumer group and %d dead in conflict consumer group" % \ + (len(conflict_consumer.nodes), len(consumer.dead_nodes()), len(conflict_consumer.dead_nodes())) + ) + + @cluster(num_nodes=7) + @matrix(clean_shutdown=[True], enable_autocommit=[True, False], metadata_quorum=quorum.all_non_upgrade) + def test_consumer_failure(self, clean_shutdown, enable_autocommit, metadata_quorum=quorum.zk): + partition = TopicPartition(self.TOPIC, 0) + + consumer = self.setup_consumer(self.TOPIC, enable_autocommit=enable_autocommit) + producer = self.setup_producer(self.TOPIC) + + consumer.start() + self.await_all_members(consumer) + + partition_owner = consumer.owner(partition) + assert partition_owner is not None + + # startup the producer and ensure that some records have been written + producer.start() + self.await_produced_messages(producer) + + # stop the partition owner and await its shutdown + consumer.kill_node(partition_owner, clean_shutdown=clean_shutdown) + wait_until(lambda: len(consumer.joined_nodes()) == (self.num_consumers - 1) and consumer.owner(partition) != None, + timeout_sec=self.session_timeout_sec*2+5, + err_msg="Timed out waiting for consumer to close") + + # ensure that the remaining consumer does some work after rebalancing + self.await_consumed_messages(consumer, min_messages=1000) + + consumer.stop_all() + + if clean_shutdown: + # if the total records consumed matches the current position, we haven't seen any duplicates + # this can only be guaranteed with a clean shutdown + assert consumer.current_position(partition) == consumer.total_consumed(), \ + "Total consumed records %d did not match consumed position %d" % \ + (consumer.total_consumed(), consumer.current_position(partition)) + else: + # we may have duplicates in a hard failure + assert consumer.current_position(partition) <= consumer.total_consumed(), \ + "Current position %d greater than the 
total number of consumed records %d" % \ + (consumer.current_position(partition), consumer.total_consumed()) + + # if autocommit is not turned on, we can also verify the last committed offset + if not enable_autocommit: + assert consumer.last_commit(partition) == consumer.current_position(partition), \ + "Last committed offset %d did not match last consumed position %d" % \ + (consumer.last_commit(partition), consumer.current_position(partition)) + + @cluster(num_nodes=7) + @matrix(clean_shutdown=[True, False], enable_autocommit=[True, False], metadata_quorum=quorum.all_non_upgrade) + def test_broker_failure(self, clean_shutdown, enable_autocommit, metadata_quorum=quorum.zk): + partition = TopicPartition(self.TOPIC, 0) + + consumer = self.setup_consumer(self.TOPIC, enable_autocommit=enable_autocommit) + producer = self.setup_producer(self.TOPIC) + + producer.start() + consumer.start() + self.await_all_members(consumer) + + num_rebalances = consumer.num_rebalances() + + # shutdown one of the brokers + # TODO: we need a way to target the coordinator instead of picking arbitrarily + self.kafka.signal_node(self.kafka.nodes[0], signal.SIGTERM if clean_shutdown else signal.SIGKILL) + + # ensure that the consumers do some work after the broker failure + self.await_consumed_messages(consumer, min_messages=1000) + + # verify that there were no rebalances on failover + assert num_rebalances == consumer.num_rebalances(), "Broker failure should not cause a rebalance" + + consumer.stop_all() + + # if the total records consumed matches the current position, we haven't seen any duplicates + assert consumer.current_position(partition) == consumer.total_consumed(), \ + "Total consumed records %d did not match consumed position %d" % \ + (consumer.total_consumed(), consumer.current_position(partition)) + + # if autocommit is not turned on, we can also verify the last committed offset + if not enable_autocommit: + assert consumer.last_commit(partition) == consumer.current_position(partition), \ + "Last committed offset %d did not match last consumed position %d" % \ + (consumer.last_commit(partition), consumer.current_position(partition)) + + @cluster(num_nodes=7) + @matrix(metadata_quorum=quorum.all_non_upgrade) + def test_group_consumption(self, metadata_quorum=quorum.zk): + """ + Verifies correct group rebalance behavior as consumers are started and stopped. + In particular, this test verifies that the partition is readable after every + expected rebalance. + + Setup: single Kafka cluster with a group of consumers reading from one topic + with one partition while the verifiable producer writes to it. 
+ + - Start the consumers one by one, verifying consumption after each rebalance + - Shutdown the consumers one by one, verifying consumption after each rebalance + """ + consumer = self.setup_consumer(self.TOPIC) + producer = self.setup_producer(self.TOPIC) + + partition = TopicPartition(self.TOPIC, 0) + + producer.start() + + for num_started, node in enumerate(consumer.nodes, 1): + consumer.start_node(node) + self.await_members(consumer, num_started) + self.await_consumed_messages(consumer) + + for num_stopped, node in enumerate(consumer.nodes, 1): + consumer.stop_node(node) + + if num_stopped < self.num_consumers: + self.await_members(consumer, self.num_consumers - num_stopped) + self.await_consumed_messages(consumer) + + assert consumer.current_position(partition) == consumer.total_consumed(), \ + "Total consumed records %d did not match consumed position %d" % \ + (consumer.total_consumed(), consumer.current_position(partition)) + + assert consumer.last_commit(partition) == consumer.current_position(partition), \ + "Last committed offset %d did not match last consumed position %d" % \ + (consumer.last_commit(partition), consumer.current_position(partition)) + +class AssignmentValidationTest(VerifiableConsumerTest): + TOPIC = "test_topic" + NUM_PARTITIONS = 6 + + def __init__(self, test_context): + super(AssignmentValidationTest, self).__init__(test_context, num_consumers=3, num_producers=0, + num_zk=1, num_brokers=2, topics={ + self.TOPIC : { 'partitions': self.NUM_PARTITIONS, 'replication-factor': 1 }, + }) + + @cluster(num_nodes=6) + @matrix(assignment_strategy=["org.apache.kafka.clients.consumer.RangeAssignor", + "org.apache.kafka.clients.consumer.RoundRobinAssignor", + "org.apache.kafka.clients.consumer.StickyAssignor"], metadata_quorum=quorum.all_non_upgrade) + def test_valid_assignment(self, assignment_strategy, metadata_quorum=quorum.zk): + """ + Verify assignment strategy correctness: each partition is assigned to exactly + one consumer instance. + + Setup: single Kafka cluster with a set of consumers in the same group. + + - Start the consumers one by one + - Validate assignment after every expected rebalance + """ + consumer = self.setup_consumer(self.TOPIC, assignment_strategy=assignment_strategy) + for num_started, node in enumerate(consumer.nodes, 1): + consumer.start_node(node) + self.await_members(consumer, num_started) + assert self.valid_assignment(self.TOPIC, self.NUM_PARTITIONS, consumer.current_assignment()), \ + "expected valid assignments of %d partitions when num_started %d: %s" % \ + (self.NUM_PARTITIONS, num_started, \ + [(str(node.account), a) for node, a in consumer.current_assignment().items()]) diff --git a/tests/kafkatest/tests/client/message_format_change_test.py b/tests/kafkatest/tests/client/message_format_change_test.py new file mode 100644 index 0000000..cb6cf72 --- /dev/null +++ b/tests/kafkatest/tests/client/message_format_change_test.py @@ -0,0 +1,106 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
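+
+# Note: kafka.alter_message_format() below adjusts the message.format.version config of a
+# live topic. As a rough illustration only (the broker address and script path here are
+# assumed, and the helper may drive the change differently), the equivalent CLI call is
+# along the lines of:
+#
+#   bin/kafka-configs.sh --bootstrap-server localhost:9092 --alter \
+#       --entity-type topics --entity-name test_topic \
+#       --add-config message.format.version=0.10.0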
+ +from ducktape.mark import matrix +from ducktape.utils.util import wait_until +from ducktape.mark.resource import cluster + +from kafkatest.services.console_consumer import ConsoleConsumer +from kafkatest.services.kafka import config_property, KafkaService, quorum +from kafkatest.services.verifiable_producer import VerifiableProducer +from kafkatest.services.zookeeper import ZookeeperService +from kafkatest.tests.produce_consume_validate import ProduceConsumeValidateTest +from kafkatest.utils import is_int +from kafkatest.version import LATEST_0_9, LATEST_0_10, LATEST_0_11, V_2_8_0, DEV_BRANCH, KafkaVersion + + +class MessageFormatChangeTest(ProduceConsumeValidateTest): + + def __init__(self, test_context): + super(MessageFormatChangeTest, self).__init__(test_context=test_context) + + def setUp(self): + self.topic = "test_topic" + self.zk = ZookeeperService(self.test_context, num_nodes=1) if quorum.for_test(self.test_context) == quorum.zk else None + + if self.zk: + self.zk.start() + + # Producer and consumer + self.producer_throughput = 10000 + self.num_producers = 1 + self.num_consumers = 1 + self.messages_per_producer = 100 + + def produce_and_consume(self, producer_version, consumer_version, group): + self.producer = VerifiableProducer(self.test_context, self.num_producers, self.kafka, + self.topic, + throughput=self.producer_throughput, + message_validator=is_int, + version=KafkaVersion(producer_version)) + self.consumer = ConsoleConsumer(self.test_context, self.num_consumers, self.kafka, + self.topic, consumer_timeout_ms=30000, + message_validator=is_int, version=KafkaVersion(consumer_version)) + self.consumer.group_id = group + self.run_produce_consume_validate(lambda: wait_until( + lambda: self.producer.each_produced_at_least(self.messages_per_producer) == True, + timeout_sec=120, backoff_sec=1, + err_msg="Producer did not produce all messages in reasonable amount of time")) + + @cluster(num_nodes=12) + @matrix(producer_version=[str(DEV_BRANCH)], consumer_version=[str(DEV_BRANCH)], metadata_quorum=quorum.all_non_upgrade) + @matrix(producer_version=[str(LATEST_0_10)], consumer_version=[str(LATEST_0_10)], metadata_quorum=quorum.all_non_upgrade) + @matrix(producer_version=[str(LATEST_0_9)], consumer_version=[str(LATEST_0_9)], metadata_quorum=quorum.all_non_upgrade) + def test_compatibility(self, producer_version, consumer_version, metadata_quorum=quorum.zk): + """ This tests performs the following checks: + The workload is a mix of 0.9.x, 0.10.x and 0.11.x producers and consumers + that produce to and consume from a DEV_BRANCH cluster + 1. initially the topic is using message format 0.9.0 + 2. change the message format version for topic to 0.10.0 on the fly. + 3. change the message format version for topic to 0.11.0 on the fly. + 4. change the message format version for topic back to 0.10.0 on the fly (only if the client version is 0.11.0 or newer) + - The producers and consumers should not have any issue. + + Note regarding step number 4. Downgrading the message format version is generally unsupported as it breaks + older clients. More concretely, if we downgrade a topic from 0.11.0 to 0.10.0 after it contains messages with + version 0.11.0, we will return the 0.11.0 messages without down conversion due to an optimisation in the + handling of fetch requests. This will break any consumer that doesn't support 0.11.0. 
So, in practice, step 4 + is similar to step 2 and it didn't seem worth it to increase the cluster size to in order to add a step 5 that + would change the message format version for the topic back to 0.9.0.0. + """ + self.kafka = KafkaService(self.test_context, num_nodes=3, zk=self.zk, version=DEV_BRANCH, topics={self.topic: { + "partitions": 3, + "replication-factor": 3, + 'configs': {"min.insync.replicas": 2}}}, + controller_num_nodes_override=1) + for node in self.kafka.nodes: + node.config[config_property.INTER_BROKER_PROTOCOL_VERSION] = str(V_2_8_0) # required for writing old message formats + + self.kafka.start() + self.logger.info("First format change to 0.9.0") + self.kafka.alter_message_format(self.topic, str(LATEST_0_9)) + self.produce_and_consume(producer_version, consumer_version, "group1") + + self.logger.info("Second format change to 0.10.0") + self.kafka.alter_message_format(self.topic, str(LATEST_0_10)) + self.produce_and_consume(producer_version, consumer_version, "group2") + + self.logger.info("Third format change to 0.11.0") + self.kafka.alter_message_format(self.topic, str(LATEST_0_11)) + self.produce_and_consume(producer_version, consumer_version, "group3") + + if producer_version == str(DEV_BRANCH) and consumer_version == str(DEV_BRANCH): + self.logger.info("Fourth format change back to 0.10.0") + self.kafka.alter_message_format(self.topic, str(LATEST_0_10)) + self.produce_and_consume(producer_version, consumer_version, "group4") + + diff --git a/tests/kafkatest/tests/client/pluggable_test.py b/tests/kafkatest/tests/client/pluggable_test.py new file mode 100644 index 0000000..b2f726e --- /dev/null +++ b/tests/kafkatest/tests/client/pluggable_test.py @@ -0,0 +1,56 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ducktape.mark import matrix +from ducktape.mark.resource import cluster +from ducktape.utils.util import wait_until + +from kafkatest.services.kafka import quorum +from kafkatest.tests.verifiable_consumer_test import VerifiableConsumerTest + +class PluggableConsumerTest(VerifiableConsumerTest): + """ Verify that the pluggable client framework works. 
""" + + TOPIC = "test_topic" + NUM_PARTITIONS = 1 + + def __init__(self, test_context): + super(PluggableConsumerTest, self).__init__(test_context, num_consumers=1, num_producers=0, + num_zk=1, num_brokers=1, topics={ + self.TOPIC : { 'partitions': self.NUM_PARTITIONS, 'replication-factor': 1 }, + }) + + @cluster(num_nodes=4) + @matrix(metadata_quorum=quorum.all_non_upgrade) + def test_start_stop(self, metadata_quorum=quorum.zk): + """ + Test that a pluggable VerifiableConsumer module load works + """ + consumer = self.setup_consumer(self.TOPIC) + + for _, node in enumerate(consumer.nodes, 1): + consumer.start_node(node) + + self.logger.debug("Waiting for %d nodes to start" % len(consumer.nodes)) + wait_until(lambda: len(consumer.alive_nodes()) == len(consumer.nodes), + timeout_sec=60, + err_msg="Timed out waiting for consumers to start") + self.logger.debug("Started: %s" % str(consumer.alive_nodes())) + consumer.stop_all() + + self.logger.debug("Waiting for %d nodes to stop" % len(consumer.nodes)) + wait_until(lambda: len(consumer.dead_nodes()) == len(consumer.nodes), + timeout_sec=self.session_timeout_sec+5, + err_msg="Timed out waiting for consumers to shutdown") diff --git a/tests/kafkatest/tests/client/quota_test.py b/tests/kafkatest/tests/client/quota_test.py new file mode 100644 index 0000000..893bf75 --- /dev/null +++ b/tests/kafkatest/tests/client/quota_test.py @@ -0,0 +1,237 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from ducktape.tests.test import Test +from ducktape.mark import matrix, parametrize +from ducktape.mark.resource import cluster + +from kafkatest.services.zookeeper import ZookeeperService +from kafkatest.services.kafka import KafkaService +from kafkatest.services.performance import ProducerPerformanceService +from kafkatest.services.console_consumer import ConsoleConsumer +from kafkatest.version import DEV_BRANCH, LATEST_1_1 + +class QuotaConfig(object): + CLIENT_ID = 'client-id' + USER = 'user' + USER_CLIENT = '(user, client-id)' + + LARGE_QUOTA = 1000 * 1000 * 1000 + USER_PRINCIPAL = 'CN=systemtest' + + def __init__(self, quota_type, override_quota, kafka): + if quota_type == QuotaConfig.CLIENT_ID: + if override_quota: + self.client_id = 'overridden_id' + self.producer_quota = 3750000 + self.consumer_quota = 3000000 + self.configure_quota(kafka, self.producer_quota, self.consumer_quota, ['clients', self.client_id]) + self.configure_quota(kafka, QuotaConfig.LARGE_QUOTA, QuotaConfig.LARGE_QUOTA, ['clients', None]) + else: + self.client_id = 'default_id' + self.producer_quota = 2500000 + self.consumer_quota = 2000000 + self.configure_quota(kafka, self.producer_quota, self.consumer_quota, ['clients', None]) + self.configure_quota(kafka, QuotaConfig.LARGE_QUOTA, QuotaConfig.LARGE_QUOTA, ['clients', 'overridden_id']) + elif quota_type == QuotaConfig.USER: + if override_quota: + self.client_id = 'some_id' + self.producer_quota = 3750000 + self.consumer_quota = 3000000 + self.configure_quota(kafka, self.producer_quota, self.consumer_quota, ['users', QuotaConfig.USER_PRINCIPAL]) + self.configure_quota(kafka, QuotaConfig.LARGE_QUOTA, QuotaConfig.LARGE_QUOTA, ['users', None]) + self.configure_quota(kafka, QuotaConfig.LARGE_QUOTA, QuotaConfig.LARGE_QUOTA, ['clients', self.client_id]) + else: + self.client_id = 'some_id' + self.producer_quota = 2500000 + self.consumer_quota = 2000000 + self.configure_quota(kafka, self.producer_quota, self.consumer_quota, ['users', None]) + self.configure_quota(kafka, QuotaConfig.LARGE_QUOTA, QuotaConfig.LARGE_QUOTA, ['clients', None]) + elif quota_type == QuotaConfig.USER_CLIENT: + if override_quota: + self.client_id = 'overridden_id' + self.producer_quota = 3750000 + self.consumer_quota = 3000000 + self.configure_quota(kafka, self.producer_quota, self.consumer_quota, ['users', QuotaConfig.USER_PRINCIPAL, 'clients', self.client_id]) + self.configure_quota(kafka, QuotaConfig.LARGE_QUOTA, QuotaConfig.LARGE_QUOTA, ['users', QuotaConfig.USER_PRINCIPAL, 'clients', None]) + self.configure_quota(kafka, QuotaConfig.LARGE_QUOTA, QuotaConfig.LARGE_QUOTA, ['users', None]) + self.configure_quota(kafka, QuotaConfig.LARGE_QUOTA, QuotaConfig.LARGE_QUOTA, ['clients', self.client_id]) + else: + self.client_id = 'default_id' + self.producer_quota = 2500000 + self.consumer_quota = 2000000 + self.configure_quota(kafka, self.producer_quota, self.consumer_quota, ['users', None, 'clients', None]) + self.configure_quota(kafka, QuotaConfig.LARGE_QUOTA, QuotaConfig.LARGE_QUOTA, ['users', None]) + self.configure_quota(kafka, QuotaConfig.LARGE_QUOTA, QuotaConfig.LARGE_QUOTA, ['clients', None]) + + def configure_quota(self, kafka, producer_byte_rate, consumer_byte_rate, entity_args): + force_use_zk_conection = not kafka.all_nodes_configs_command_uses_bootstrap_server() + node = kafka.nodes[0] + cmd = "%s --alter --add-config producer_byte_rate=%d,consumer_byte_rate=%d" % \ + (kafka.kafka_configs_cmd_with_optional_security_settings(node, force_use_zk_conection), producer_byte_rate, 
consumer_byte_rate) + cmd += " --entity-type " + entity_args[0] + self.entity_name_opt(entity_args[1]) + if len(entity_args) > 2: + cmd += " --entity-type " + entity_args[2] + self.entity_name_opt(entity_args[3]) + node.account.ssh(cmd) + + def entity_name_opt(self, name): + return " --entity-default" if name is None else " --entity-name " + name + +class QuotaTest(Test): + """ + These tests verify that quota provides expected functionality -- they run + producer, broker, and consumer with different clientId and quota configuration and + check that the observed throughput is close to the value we expect. + """ + + def __init__(self, test_context): + """:type test_context: ducktape.tests.test.TestContext""" + super(QuotaTest, self).__init__(test_context=test_context) + + self.topic = 'test_topic' + self.logger.info('use topic ' + self.topic) + + self.maximum_client_deviation_percentage = 100.0 + self.maximum_broker_deviation_percentage = 5.0 + self.num_records = 50000 + self.record_size = 3000 + + self.zk = ZookeeperService(test_context, num_nodes=1) + self.kafka = KafkaService(test_context, num_nodes=1, zk=self.zk, + security_protocol='SSL', authorizer_class_name='', + interbroker_security_protocol='SSL', + topics={self.topic: {'partitions': 6, 'replication-factor': 1, 'configs': {'min.insync.replicas': 1}}}, + jmx_object_names=['kafka.server:type=BrokerTopicMetrics,name=BytesInPerSec', + 'kafka.server:type=BrokerTopicMetrics,name=BytesOutPerSec'], + jmx_attributes=['OneMinuteRate']) + self.num_producers = 1 + self.num_consumers = 2 + + def setUp(self): + self.zk.start() + + def min_cluster_size(self): + """Override this since we're adding services outside of the constructor""" + return super(QuotaTest, self).min_cluster_size() + self.num_producers + self.num_consumers + + @cluster(num_nodes=5) + @matrix(quota_type=[QuotaConfig.CLIENT_ID, QuotaConfig.USER, QuotaConfig.USER_CLIENT], override_quota=[True, False]) + @parametrize(quota_type=QuotaConfig.CLIENT_ID, consumer_num=2) + @parametrize(quota_type=QuotaConfig.CLIENT_ID, old_broker_throttling_behavior=True) + @parametrize(quota_type=QuotaConfig.CLIENT_ID, old_client_throttling_behavior=True) + def test_quota(self, quota_type, override_quota=True, producer_num=1, consumer_num=1, + old_broker_throttling_behavior=False, old_client_throttling_behavior=False): + # Old (pre-2.0) throttling behavior for broker throttles before sending a response to the client. + if old_broker_throttling_behavior: + self.kafka.set_version(LATEST_1_1) + self.kafka.start() + + self.quota_config = QuotaConfig(quota_type, override_quota, self.kafka) + producer_client_id = self.quota_config.client_id + consumer_client_id = self.quota_config.client_id + + # Old (pre-2.0) throttling behavior for client does not throttle upon receiving a response with a non-zero throttle time. 
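+        # (For reference, validate() at the end of this test allows generous client-side deviation:
+        # with the default client-id producer quota of 2500000 B/s and
+        # maximum_client_deviation_percentage = 100.0, the accepted peak producer throughput is
+        # 2500000 * (1 + 100.0 / 100) = 5000000 B/s.)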
+ if old_client_throttling_behavior: + client_version = LATEST_1_1 + else: + client_version = DEV_BRANCH + + # Produce all messages + producer = ProducerPerformanceService( + self.test_context, producer_num, self.kafka, + topic=self.topic, num_records=self.num_records, record_size=self.record_size, throughput=-1, + client_id=producer_client_id, version=client_version) + + producer.run() + + # Consume all messages + consumer = ConsoleConsumer(self.test_context, consumer_num, self.kafka, self.topic, + consumer_timeout_ms=60000, client_id=consumer_client_id, + jmx_object_names=['kafka.consumer:type=consumer-fetch-manager-metrics,client-id=%s' % consumer_client_id], + jmx_attributes=['bytes-consumed-rate'], version=client_version) + consumer.run() + + for idx, messages in consumer.messages_consumed.items(): + assert len(messages) > 0, "consumer %d didn't consume any message before timeout" % idx + + success, msg = self.validate(self.kafka, producer, consumer) + assert success, msg + + def validate(self, broker, producer, consumer): + """ + For each client_id we validate that: + 1) number of consumed messages equals number of produced messages + 2) maximum_producer_throughput <= producer_quota * (1 + maximum_client_deviation_percentage/100) + 3) maximum_broker_byte_in_rate <= producer_quota * (1 + maximum_broker_deviation_percentage/100) + 4) maximum_consumer_throughput <= consumer_quota * (1 + maximum_client_deviation_percentage/100) + 5) maximum_broker_byte_out_rate <= consumer_quota * (1 + maximum_broker_deviation_percentage/100) + """ + success = True + msg = '' + + self.kafka.read_jmx_output_all_nodes() + + # validate that number of consumed messages equals number of produced messages + produced_num = sum([value['records'] for value in producer.results]) + consumed_num = sum([len(value) for value in consumer.messages_consumed.values()]) + self.logger.info('producer produced %d messages' % produced_num) + self.logger.info('consumer consumed %d messages' % consumed_num) + if produced_num != consumed_num: + success = False + msg += "number of produced messages %d doesn't equal number of consumed messages %d" % (produced_num, consumed_num) + + # validate that maximum_producer_throughput <= producer_quota * (1 + maximum_client_deviation_percentage/100) + producer_maximum_bps = max( + metric.value for k, metrics in producer.metrics(group='producer-metrics', name='outgoing-byte-rate', client_id=producer.client_id) for metric in metrics + ) + producer_quota_bps = self.quota_config.producer_quota + self.logger.info('producer has maximum throughput %.2f bps with producer quota %.2f bps' % (producer_maximum_bps, producer_quota_bps)) + if producer_maximum_bps > producer_quota_bps*(self.maximum_client_deviation_percentage/100+1): + success = False + msg += 'maximum producer throughput %.2f bps exceeded producer quota %.2f bps by more than %.1f%%' % \ + (producer_maximum_bps, producer_quota_bps, self.maximum_client_deviation_percentage) + + # validate that maximum_broker_byte_in_rate <= producer_quota * (1 + maximum_broker_deviation_percentage/100) + broker_byte_in_attribute_name = 'kafka.server:type=BrokerTopicMetrics,name=BytesInPerSec:OneMinuteRate' + broker_maximum_byte_in_bps = broker.maximum_jmx_value[broker_byte_in_attribute_name] + self.logger.info('broker has maximum byte-in rate %.2f bps with producer quota %.2f bps' % + (broker_maximum_byte_in_bps, producer_quota_bps)) + if broker_maximum_byte_in_bps > producer_quota_bps*(self.maximum_broker_deviation_percentage/100+1): + success = False + 
msg += 'maximum broker byte-in rate %.2f bps exceeded producer quota %.2f bps by more than %.1f%%' % \ + (broker_maximum_byte_in_bps, producer_quota_bps, self.maximum_broker_deviation_percentage) + + # validate that maximum_consumer_throughput <= consumer_quota * (1 + maximum_client_deviation_percentage/100) + consumer_attribute_name = 'kafka.consumer:type=consumer-fetch-manager-metrics,client-id=%s:bytes-consumed-rate' % consumer.client_id + consumer_maximum_bps = consumer.maximum_jmx_value[consumer_attribute_name] + consumer_quota_bps = self.quota_config.consumer_quota + self.logger.info('consumer has maximum throughput %.2f bps with consumer quota %.2f bps' % (consumer_maximum_bps, consumer_quota_bps)) + if consumer_maximum_bps > consumer_quota_bps*(self.maximum_client_deviation_percentage/100+1): + success = False + msg += 'maximum consumer throughput %.2f bps exceeded consumer quota %.2f bps by more than %.1f%%' % \ + (consumer_maximum_bps, consumer_quota_bps, self.maximum_client_deviation_percentage) + + # validate that maximum_broker_byte_out_rate <= consumer_quota * (1 + maximum_broker_deviation_percentage/100) + broker_byte_out_attribute_name = 'kafka.server:type=BrokerTopicMetrics,name=BytesOutPerSec:OneMinuteRate' + broker_maximum_byte_out_bps = broker.maximum_jmx_value[broker_byte_out_attribute_name] + self.logger.info('broker has maximum byte-out rate %.2f bps with consumer quota %.2f bps' % + (broker_maximum_byte_out_bps, consumer_quota_bps)) + if broker_maximum_byte_out_bps > consumer_quota_bps*(self.maximum_broker_deviation_percentage/100+1): + success = False + msg += 'maximum broker byte-out rate %.2f bps exceeded consumer quota %.2f bps by more than %.1f%%' % \ + (broker_maximum_byte_out_bps, consumer_quota_bps, self.maximum_broker_deviation_percentage) + + return success, msg + diff --git a/tests/kafkatest/tests/client/truncation_test.py b/tests/kafkatest/tests/client/truncation_test.py new file mode 100644 index 0000000..523bcbc --- /dev/null +++ b/tests/kafkatest/tests/client/truncation_test.py @@ -0,0 +1,149 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
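+
+# Note: the truncation scenario below is induced by stopping both ISR members, restarting the
+# stale replica first, and then enabling unclean leader election for the topic via
+# kafka.set_unclean_leader_election(). As a rough illustration only (the broker address and
+# script path here are assumed, and the helper may apply the config differently), the
+# equivalent topic config change is along the lines of:
+#
+#   bin/kafka-configs.sh --bootstrap-server localhost:9092 --alter \
+#       --entity-type topics --entity-name test_topic \
+#       --add-config unclean.leader.election.enable=true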
+ +from ducktape.mark.resource import cluster +from ducktape.utils.util import wait_until + +from kafkatest.tests.verifiable_consumer_test import VerifiableConsumerTest +from kafkatest.services.kafka import TopicPartition +from kafkatest.services.verifiable_consumer import VerifiableConsumer + + +class TruncationTest(VerifiableConsumerTest): + TOPIC = "test_topic" + NUM_PARTITIONS = 1 + TOPICS = { + TOPIC: { + 'partitions': NUM_PARTITIONS, + 'replication-factor': 2 + } + } + GROUP_ID = "truncation-test" + + def __init__(self, test_context): + super(TruncationTest, self).__init__(test_context, num_consumers=1, num_producers=1, + num_zk=1, num_brokers=3, topics=self.TOPICS) + self.last_total = 0 + self.all_offsets_consumed = [] + self.all_values_consumed = [] + + def setup_consumer(self, topic, **kwargs): + consumer = super(TruncationTest, self).setup_consumer(topic, **kwargs) + self.mark_for_collect(consumer, 'verifiable_consumer_stdout') + + def print_record(event, node): + self.all_offsets_consumed.append(event['offset']) + self.all_values_consumed.append(event['value']) + consumer.on_record_consumed = print_record + + return consumer + + @cluster(num_nodes=7) + def test_offset_truncate(self): + """ + Verify correct consumer behavior when the brokers are consecutively restarted. + + Setup: single Kafka cluster with one producer writing messages to a single topic with one + partition, an a set of consumers in the same group reading from the same topic. + + - Start a producer which continues producing new messages throughout the test. + - Start up the consumers and wait until they've joined the group. + - In a loop, restart each broker consecutively, waiting for the group to stabilize between + each broker restart. + - Verify delivery semantics according to the failure type and that the broker bounces + did not cause unexpected group rebalances. + """ + tp = TopicPartition(self.TOPIC, 0) + + producer = self.setup_producer(self.TOPIC, throughput=10) + producer.start() + self.await_produced_messages(producer, min_messages=10) + + consumer = self.setup_consumer(self.TOPIC, reset_policy="earliest", verify_offsets=False) + consumer.start() + self.await_all_members(consumer) + + # Reduce ISR to one node + isr = self.kafka.isr_idx_list(self.TOPIC, 0) + node1 = self.kafka.get_node(isr[0]) + self.kafka.stop_node(node1) + self.logger.info("Reduced ISR to one node, consumer is at %s", consumer.current_position(tp)) + + # Ensure remaining ISR member has a little bit of data + current_total = consumer.total_consumed() + wait_until(lambda: consumer.total_consumed() > current_total + 10, + timeout_sec=30, + err_msg="Timed out waiting for consumer to move ahead by 10 messages") + + # Kill last ISR member + node2 = self.kafka.get_node(isr[1]) + self.kafka.stop_node(node2) + self.logger.info("No members in ISR, consumer is at %s", consumer.current_position(tp)) + + # Keep consuming until we've caught up to HW + def none_consumed(this, consumer): + new_total = consumer.total_consumed() + if new_total == this.last_total: + return True + else: + this.last_total = new_total + return False + + self.last_total = consumer.total_consumed() + wait_until(lambda: none_consumed(self, consumer), + timeout_sec=30, + err_msg="Timed out waiting for the consumer to catch up") + + self.kafka.start_node(node1) + self.logger.info("Out of sync replica is online, but not electable. 
Consumer is at %s", consumer.current_position(tp)) + + pre_truncation_pos = consumer.current_position(tp) + + self.kafka.set_unclean_leader_election(self.TOPIC) + self.logger.info("New unclean leader, consumer is at %s", consumer.current_position(tp)) + + # Wait for truncation to be detected + self.kafka.start_node(node2) + wait_until(lambda: consumer.current_position(tp) >= pre_truncation_pos, + timeout_sec=30, + err_msg="Timed out waiting for truncation") + + # Make sure we didn't reset to beginning of log + total_records_consumed = len(self.all_values_consumed) + assert total_records_consumed == len(set(self.all_values_consumed)), "Received duplicate records" + + consumer.stop() + producer.stop() + + # Re-consume all the records + consumer2 = VerifiableConsumer(self.test_context, 1, self.kafka, self.TOPIC, group_id="group2", + reset_policy="earliest", verify_offsets=True) + + consumer2.start() + self.await_all_members(consumer2) + + wait_until(lambda: consumer2.total_consumed() > 0, + timeout_sec=30, + err_msg="Timed out waiting for consumer to consume at least 10 messages") + + self.last_total = consumer2.total_consumed() + wait_until(lambda: none_consumed(self, consumer2), + timeout_sec=30, + err_msg="Timed out waiting for the consumer to fully consume data") + + second_total_consumed = consumer2.total_consumed() + assert second_total_consumed < total_records_consumed, "Expected fewer records with new consumer since we truncated" + self.logger.info("Second consumer saw only %s, meaning %s were truncated", + second_total_consumed, total_records_consumed - second_total_consumed) \ No newline at end of file diff --git a/tests/kafkatest/tests/connect/__init__.py b/tests/kafkatest/tests/connect/__init__.py new file mode 100644 index 0000000..ec20143 --- /dev/null +++ b/tests/kafkatest/tests/connect/__init__.py @@ -0,0 +1,14 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/kafkatest/tests/connect/connect_distributed_test.py b/tests/kafkatest/tests/connect/connect_distributed_test.py new file mode 100644 index 0000000..6bc52b0 --- /dev/null +++ b/tests/kafkatest/tests/connect/connect_distributed_test.py @@ -0,0 +1,644 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ducktape.tests.test import Test +from ducktape.mark.resource import cluster +from ducktape.utils.util import wait_until +from ducktape.mark import matrix, parametrize +from ducktape.cluster.remoteaccount import RemoteCommandError + +from kafkatest.services.zookeeper import ZookeeperService +from kafkatest.services.kafka import KafkaService, config_property +from kafkatest.services.connect import ConnectDistributedService, VerifiableSource, VerifiableSink, ConnectRestError, MockSink, MockSource +from kafkatest.services.console_consumer import ConsoleConsumer +from kafkatest.services.security.security_config import SecurityConfig +from kafkatest.version import DEV_BRANCH, LATEST_2_3, LATEST_2_2, LATEST_2_1, LATEST_2_0, LATEST_1_1, LATEST_1_0, LATEST_0_11_0, LATEST_0_10_2, LATEST_0_10_1, LATEST_0_10_0, LATEST_0_9, LATEST_0_8_2, KafkaVersion + +from functools import reduce +from collections import Counter, namedtuple +import itertools +import json +import operator +import time + +class ConnectDistributedTest(Test): + """ + Simple test of Kafka Connect in distributed mode, producing data from files on one cluster and consuming it on + another, validating the total output is identical to the input. + """ + + FILE_SOURCE_CONNECTOR = 'org.apache.kafka.connect.file.FileStreamSourceConnector' + FILE_SINK_CONNECTOR = 'org.apache.kafka.connect.file.FileStreamSinkConnector' + + INPUT_FILE = "/mnt/connect.input" + OUTPUT_FILE = "/mnt/connect.output" + + TOPIC = "test" + OFFSETS_TOPIC = "connect-offsets" + OFFSETS_REPLICATION_FACTOR = "1" + OFFSETS_PARTITIONS = "1" + CONFIG_TOPIC = "connect-configs" + CONFIG_REPLICATION_FACTOR = "1" + STATUS_TOPIC = "connect-status" + STATUS_REPLICATION_FACTOR = "1" + STATUS_PARTITIONS = "1" + SCHEDULED_REBALANCE_MAX_DELAY_MS = "60000" + CONNECT_PROTOCOL="sessioned" + + # Since tasks can be assigned to any node and we're testing with files, we need to make sure the content is the same + # across all nodes. 
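+    # For reference, the file connectors exercised in this class are created through the Connect
+    # REST API (see _start_connector() below); a minimal source config of the kind rendered from
+    # connect-file-source.properties would look roughly like the following (the connector name
+    # shown is an assumption):
+    #
+    #   {
+    #       "name": "local-file-source",
+    #       "connector.class": "org.apache.kafka.connect.file.FileStreamSourceConnector",
+    #       "tasks.max": "1",
+    #       "file": "/mnt/connect.input",
+    #       "topic": "test"
+    #   }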
+ FIRST_INPUT_LIST = ["foo", "bar", "baz"] + FIRST_INPUTS = "\n".join(FIRST_INPUT_LIST) + "\n" + SECOND_INPUT_LIST = ["razz", "ma", "tazz"] + SECOND_INPUTS = "\n".join(SECOND_INPUT_LIST) + "\n" + + SCHEMA = { "type": "string", "optional": False } + + def __init__(self, test_context): + super(ConnectDistributedTest, self).__init__(test_context) + self.num_zk = 1 + self.num_brokers = 1 + self.topics = { + self.TOPIC: {'partitions': 1, 'replication-factor': 1} + } + + self.zk = ZookeeperService(test_context, self.num_zk) + + self.key_converter = "org.apache.kafka.connect.json.JsonConverter" + self.value_converter = "org.apache.kafka.connect.json.JsonConverter" + self.schemas = True + + def setup_services(self, security_protocol=SecurityConfig.PLAINTEXT, timestamp_type=None, broker_version=DEV_BRANCH, auto_create_topics=False): + self.kafka = KafkaService(self.test_context, self.num_brokers, self.zk, + security_protocol=security_protocol, interbroker_security_protocol=security_protocol, + topics=self.topics, version=broker_version, + server_prop_overrides=[["auto.create.topics.enable", str(auto_create_topics)]]) + if timestamp_type is not None: + for node in self.kafka.nodes: + node.config[config_property.MESSAGE_TIMESTAMP_TYPE] = timestamp_type + + self.cc = ConnectDistributedService(self.test_context, 3, self.kafka, [self.INPUT_FILE, self.OUTPUT_FILE]) + self.cc.log_level = "DEBUG" + + self.zk.start() + self.kafka.start() + + def _start_connector(self, config_file): + connector_props = self.render(config_file) + connector_config = dict([line.strip().split('=', 1) for line in connector_props.split('\n') if line.strip() and not line.strip().startswith('#')]) + self.cc.create_connector(connector_config) + + def _connector_status(self, connector, node=None): + try: + return self.cc.get_connector_status(connector, node) + except ConnectRestError: + return None + + def _connector_has_state(self, status, state): + return status is not None and status['connector']['state'] == state + + def _task_has_state(self, task_id, status, state): + if not status: + return False + + tasks = status['tasks'] + if not tasks: + return False + + for task in tasks: + if task['id'] == task_id: + return task['state'] == state + + return False + + def _all_tasks_have_state(self, status, task_count, state): + if status is None: + return False + + tasks = status['tasks'] + if len(tasks) != task_count: + return False + + return reduce(operator.and_, [task['state'] == state for task in tasks], True) + + def is_running(self, connector, node=None): + status = self._connector_status(connector.name, node) + return self._connector_has_state(status, 'RUNNING') and self._all_tasks_have_state(status, connector.tasks, 'RUNNING') + + def is_paused(self, connector, node=None): + status = self._connector_status(connector.name, node) + return self._connector_has_state(status, 'PAUSED') and self._all_tasks_have_state(status, connector.tasks, 'PAUSED') + + def connector_is_running(self, connector, node=None): + status = self._connector_status(connector.name, node) + return self._connector_has_state(status, 'RUNNING') + + def connector_is_failed(self, connector, node=None): + status = self._connector_status(connector.name, node) + return self._connector_has_state(status, 'FAILED') + + def task_is_failed(self, connector, task_id, node=None): + status = self._connector_status(connector.name, node) + return self._task_has_state(task_id, status, 'FAILED') + + def task_is_running(self, connector, task_id, node=None): + status = 
self._connector_status(connector.name, node) + return self._task_has_state(task_id, status, 'RUNNING') + + @cluster(num_nodes=5) + @matrix(connect_protocol=['sessioned', 'compatible', 'eager']) + def test_restart_failed_connector(self, connect_protocol): + self.CONNECT_PROTOCOL = connect_protocol + self.setup_services() + self.cc.set_configs(lambda node: self.render("connect-distributed.properties", node=node)) + self.cc.start() + + self.sink = MockSink(self.cc, self.topics.keys(), mode='connector-failure', delay_sec=5) + self.sink.start() + + wait_until(lambda: self.connector_is_failed(self.sink), timeout_sec=15, + err_msg="Failed to see connector transition to the FAILED state") + + self.cc.restart_connector(self.sink.name) + + wait_until(lambda: self.connector_is_running(self.sink), timeout_sec=10, + err_msg="Failed to see connector transition to the RUNNING state") + + @cluster(num_nodes=5) + @matrix(connector_type=['source', 'sink'], connect_protocol=['sessioned', 'compatible', 'eager']) + def test_restart_failed_task(self, connector_type, connect_protocol): + self.CONNECT_PROTOCOL = connect_protocol + self.setup_services() + self.cc.set_configs(lambda node: self.render("connect-distributed.properties", node=node)) + self.cc.start() + + connector = None + if connector_type == "sink": + connector = MockSink(self.cc, self.topics.keys(), mode='task-failure', delay_sec=5) + else: + connector = MockSource(self.cc, mode='task-failure', delay_sec=5) + + connector.start() + + task_id = 0 + wait_until(lambda: self.task_is_failed(connector, task_id), timeout_sec=20, + err_msg="Failed to see task transition to the FAILED state") + + self.cc.restart_task(connector.name, task_id) + + wait_until(lambda: self.task_is_running(connector, task_id), timeout_sec=10, + err_msg="Failed to see task transition to the RUNNING state") + + @cluster(num_nodes=5) + @matrix(connect_protocol=['sessioned', 'compatible', 'eager']) + def test_restart_connector_and_tasks_failed_connector(self, connect_protocol): + self.CONNECT_PROTOCOL = connect_protocol + self.setup_services() + self.cc.set_configs(lambda node: self.render("connect-distributed.properties", node=node)) + self.cc.start() + + self.sink = MockSink(self.cc, self.topics.keys(), mode='connector-failure', delay_sec=5) + self.sink.start() + + wait_until(lambda: self.connector_is_failed(self.sink), timeout_sec=15, + err_msg="Failed to see connector transition to the FAILED state") + + self.cc.restart_connector_and_tasks(self.sink.name, only_failed = "true", include_tasks = "false") + + wait_until(lambda: self.connector_is_running(self.sink), timeout_sec=10, + err_msg="Failed to see connector transition to the RUNNING state") + + @cluster(num_nodes=5) + @matrix(connector_type=['source', 'sink'], connect_protocol=['sessioned', 'compatible', 'eager']) + def test_restart_connector_and_tasks_failed_task(self, connector_type, connect_protocol): + self.CONNECT_PROTOCOL = connect_protocol + self.setup_services() + self.cc.set_configs(lambda node: self.render("connect-distributed.properties", node=node)) + self.cc.start() + + connector = None + if connector_type == "sink": + connector = MockSink(self.cc, self.topics.keys(), mode='task-failure', delay_sec=5) + else: + connector = MockSource(self.cc, mode='task-failure', delay_sec=5) + + connector.start() + + task_id = 0 + wait_until(lambda: self.task_is_failed(connector, task_id), timeout_sec=20, + err_msg="Failed to see task transition to the FAILED state") + + self.cc.restart_connector_and_tasks(connector.name, 
only_failed = "false", include_tasks = "true") + + wait_until(lambda: self.task_is_running(connector, task_id), timeout_sec=10, + err_msg="Failed to see task transition to the RUNNING state") + + @cluster(num_nodes=5) + @matrix(connect_protocol=['sessioned', 'compatible', 'eager']) + def test_pause_and_resume_source(self, connect_protocol): + """ + Verify that source connectors stop producing records when paused and begin again after + being resumed. + """ + + self.CONNECT_PROTOCOL = connect_protocol + self.setup_services() + self.cc.set_configs(lambda node: self.render("connect-distributed.properties", node=node)) + self.cc.start() + + self.source = VerifiableSource(self.cc, topic=self.TOPIC) + self.source.start() + + wait_until(lambda: self.is_running(self.source), timeout_sec=30, + err_msg="Failed to see connector transition to the RUNNING state") + + self.cc.pause_connector(self.source.name) + + # wait until all nodes report the paused transition + for node in self.cc.nodes: + wait_until(lambda: self.is_paused(self.source, node), timeout_sec=30, + err_msg="Failed to see connector transition to the PAUSED state") + + # verify that we do not produce new messages while paused + num_messages = len(self.source.sent_messages()) + time.sleep(10) + assert num_messages == len(self.source.sent_messages()), "Paused source connector should not produce any messages" + + self.cc.resume_connector(self.source.name) + + for node in self.cc.nodes: + wait_until(lambda: self.is_running(self.source, node), timeout_sec=30, + err_msg="Failed to see connector transition to the RUNNING state") + + # after resuming, we should see records produced again + wait_until(lambda: len(self.source.sent_messages()) > num_messages, timeout_sec=30, + err_msg="Failed to produce messages after resuming source connector") + + @cluster(num_nodes=5) + @matrix(connect_protocol=['sessioned', 'compatible', 'eager']) + def test_pause_and_resume_sink(self, connect_protocol): + """ + Verify that sink connectors stop consuming records when paused and begin again after + being resumed. 
+ """ + + self.CONNECT_PROTOCOL = connect_protocol + self.setup_services() + self.cc.set_configs(lambda node: self.render("connect-distributed.properties", node=node)) + self.cc.start() + + # use the verifiable source to produce a steady stream of messages + self.source = VerifiableSource(self.cc, topic=self.TOPIC) + self.source.start() + + wait_until(lambda: len(self.source.committed_messages()) > 0, timeout_sec=30, + err_msg="Timeout expired waiting for source task to produce a message") + + self.sink = VerifiableSink(self.cc, topics=[self.TOPIC]) + self.sink.start() + + wait_until(lambda: self.is_running(self.sink), timeout_sec=30, + err_msg="Failed to see connector transition to the RUNNING state") + + self.cc.pause_connector(self.sink.name) + + # wait until all nodes report the paused transition + for node in self.cc.nodes: + wait_until(lambda: self.is_paused(self.sink, node), timeout_sec=30, + err_msg="Failed to see connector transition to the PAUSED state") + + # verify that we do not consume new messages while paused + num_messages = len(self.sink.received_messages()) + time.sleep(10) + assert num_messages == len(self.sink.received_messages()), "Paused sink connector should not consume any messages" + + self.cc.resume_connector(self.sink.name) + + for node in self.cc.nodes: + wait_until(lambda: self.is_running(self.sink, node), timeout_sec=30, + err_msg="Failed to see connector transition to the RUNNING state") + + # after resuming, we should see records consumed again + wait_until(lambda: len(self.sink.received_messages()) > num_messages, timeout_sec=30, + err_msg="Failed to consume messages after resuming sink connector") + + @cluster(num_nodes=5) + @matrix(connect_protocol=['sessioned', 'compatible', 'eager']) + def test_pause_state_persistent(self, connect_protocol): + """ + Verify that paused state is preserved after a cluster restart. + """ + + self.CONNECT_PROTOCOL = connect_protocol + self.setup_services() + self.cc.set_configs(lambda node: self.render("connect-distributed.properties", node=node)) + self.cc.start() + + self.source = VerifiableSource(self.cc, topic=self.TOPIC) + self.source.start() + + wait_until(lambda: self.is_running(self.source), timeout_sec=30, + err_msg="Failed to see connector transition to the RUNNING state") + + self.cc.pause_connector(self.source.name) + + self.cc.restart() + + # we should still be paused after restarting + for node in self.cc.nodes: + wait_until(lambda: self.is_paused(self.source, node), timeout_sec=120, + err_msg="Failed to see connector startup in PAUSED state") + + @cluster(num_nodes=6) + @matrix(security_protocol=[SecurityConfig.PLAINTEXT, SecurityConfig.SASL_SSL], connect_protocol=['sessioned', 'compatible', 'eager']) + def test_file_source_and_sink(self, security_protocol, connect_protocol): + """ + Tests that a basic file connector works across clean rolling bounces. This validates that the connector is + correctly created, tasks instantiated, and as nodes restart the work is rebalanced across nodes. + """ + + self.CONNECT_PROTOCOL = connect_protocol + self.setup_services(security_protocol=security_protocol) + self.cc.set_configs(lambda node: self.render("connect-distributed.properties", node=node)) + + self.cc.start() + + self.logger.info("Creating connectors") + self._start_connector("connect-file-source.properties") + self._start_connector("connect-file-sink.properties") + + # Generating data on the source node should generate new records and create new output on the sink node. 
Timeouts + # here need to be more generous than they are for standalone mode because a) it takes longer to write configs, + # do rebalancing of the group, etc, and b) without explicit leave group support, rebalancing takes awhile + for node in self.cc.nodes: + node.account.ssh("echo -e -n " + repr(self.FIRST_INPUTS) + " >> " + self.INPUT_FILE) + wait_until(lambda: self._validate_file_output(self.FIRST_INPUT_LIST), timeout_sec=70, err_msg="Data added to input file was not seen in the output file in a reasonable amount of time.") + + # Restarting both should result in them picking up where they left off, + # only processing new data. + self.cc.restart() + + for node in self.cc.nodes: + node.account.ssh("echo -e -n " + repr(self.SECOND_INPUTS) + " >> " + self.INPUT_FILE) + wait_until(lambda: self._validate_file_output(self.FIRST_INPUT_LIST + self.SECOND_INPUT_LIST), timeout_sec=150, err_msg="Sink output file never converged to the same state as the input file") + + @cluster(num_nodes=6) + @matrix(clean=[True, False], connect_protocol=['sessioned', 'compatible', 'eager']) + def test_bounce(self, clean, connect_protocol): + """ + Validates that source and sink tasks that run continuously and produce a predictable sequence of messages + run correctly and deliver messages exactly once when Kafka Connect workers undergo clean rolling bounces. + """ + num_tasks = 3 + + self.CONNECT_PROTOCOL = connect_protocol + self.setup_services() + self.cc.set_configs(lambda node: self.render("connect-distributed.properties", node=node)) + self.cc.start() + + self.source = VerifiableSource(self.cc, topic=self.TOPIC, tasks=num_tasks, throughput=100) + self.source.start() + self.sink = VerifiableSink(self.cc, tasks=num_tasks, topics=[self.TOPIC]) + self.sink.start() + + for _ in range(3): + for node in self.cc.nodes: + started = time.time() + self.logger.info("%s bouncing Kafka Connect on %s", clean and "Clean" or "Hard", str(node.account)) + self.cc.stop_node(node, clean_shutdown=clean) + with node.account.monitor_log(self.cc.LOG_FILE) as monitor: + self.cc.start_node(node) + monitor.wait_until("Starting connectors and tasks using config offset", timeout_sec=90, + err_msg="Kafka Connect worker didn't successfully join group and start work") + self.logger.info("Bounced Kafka Connect on %s and rejoined in %f seconds", node.account, time.time() - started) + + # Give additional time for the consumer groups to recover. Even if it is not a hard bounce, there are + # some cases where a restart can cause a rebalance to take the full length of the session timeout + # (e.g. if the client shuts down before it has received the memberId from its initial JoinGroup). + # If we don't give enough time for the group to stabilize, the next bounce may cause consumers to + # be shut down before they have any time to process data and we can end up with zero data making it + # through the test. + time.sleep(15) + + # Wait at least scheduled.rebalance.max.delay.ms to expire and rebalance + time.sleep(60) + + # Allow the connectors to startup, recover, and exit cleanly before + # ending the test. It's possible for the source connector to make + # uncommitted progress, and for the sink connector to read messages that + # have not been committed yet, and fail a later assertion. 
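+        # (is_running(), defined above, requires the connector and all of its tasks to report the
+        # RUNNING state via get_connector_status(), so these waits also cover task-level recovery
+        # after the bounces.)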
+ wait_until(lambda: self.is_running(self.source), timeout_sec=30, + err_msg="Failed to see connector transition to the RUNNING state") + time.sleep(15) + self.source.stop() + # Ensure that the sink connector has an opportunity to read all + # committed messages from the source connector. + wait_until(lambda: self.is_running(self.sink), timeout_sec=30, + err_msg="Failed to see connector transition to the RUNNING state") + time.sleep(15) + self.sink.stop() + self.cc.stop() + + # Validate at least once delivery of everything that was reported as written since we should have flushed and + # cleanly exited. Currently this only tests at least once delivery because the sink task may not have consumed + # all the messages generated by the source task. This needs to be done per-task since seqnos are not unique across + # tasks. + success = True + errors = [] + allow_dups = not clean + src_messages = self.source.committed_messages() + sink_messages = self.sink.flushed_messages() + for task in range(num_tasks): + # Validate source messages + src_seqnos = [msg['seqno'] for msg in src_messages if msg['task'] == task] + # Every seqno up to the largest one we ever saw should appear. Each seqno should only appear once because clean + # bouncing should commit on rebalance. + src_seqno_max = max(src_seqnos) if len(src_seqnos) else 0 + self.logger.debug("Max source seqno: %d", src_seqno_max) + src_seqno_counts = Counter(src_seqnos) + missing_src_seqnos = sorted(set(range(src_seqno_max)).difference(set(src_seqnos))) + duplicate_src_seqnos = sorted(seqno for seqno,count in src_seqno_counts.items() if count > 1) + + if missing_src_seqnos: + self.logger.error("Missing source sequence numbers for task " + str(task)) + errors.append("Found missing source sequence numbers for task %d: %s" % (task, missing_src_seqnos)) + success = False + if not allow_dups and duplicate_src_seqnos: + self.logger.error("Duplicate source sequence numbers for task " + str(task)) + errors.append("Found duplicate source sequence numbers for task %d: %s" % (task, duplicate_src_seqnos)) + success = False + + + # Validate sink messages + sink_seqnos = [msg['seqno'] for msg in sink_messages if msg['task'] == task] + # Every seqno up to the largest one we ever saw should appear. Each seqno should only appear once because + # clean bouncing should commit on rebalance. 
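+            # (Worked example: sink_seqnos of [0, 1, 1, 3] would give sink_seqno_max = 3,
+            # missing_sink_seqnos = [2] and duplicate_sink_seqnos = [1]; duplicates only fail the
+            # run when allow_dups is False, i.e. for clean bounces.)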
+ sink_seqno_max = max(sink_seqnos) if len(sink_seqnos) else 0 + self.logger.debug("Max sink seqno: %d", sink_seqno_max) + sink_seqno_counts = Counter(sink_seqnos) + missing_sink_seqnos = sorted(set(range(sink_seqno_max)).difference(set(sink_seqnos))) + duplicate_sink_seqnos = sorted(seqno for seqno,count in iter(sink_seqno_counts.items()) if count > 1) + + if missing_sink_seqnos: + self.logger.error("Missing sink sequence numbers for task " + str(task)) + errors.append("Found missing sink sequence numbers for task %d: %s" % (task, missing_sink_seqnos)) + success = False + if not allow_dups and duplicate_sink_seqnos: + self.logger.error("Duplicate sink sequence numbers for task " + str(task)) + errors.append("Found duplicate sink sequence numbers for task %d: %s" % (task, duplicate_sink_seqnos)) + success = False + + # Validate source and sink match + if sink_seqno_max > src_seqno_max: + self.logger.error("Found sink sequence number greater than any generated sink sequence number for task %d: %d > %d", task, sink_seqno_max, src_seqno_max) + errors.append("Found sink sequence number greater than any generated sink sequence number for task %d: %d > %d" % (task, sink_seqno_max, src_seqno_max)) + success = False + if src_seqno_max < 1000 or sink_seqno_max < 1000: + errors.append("Not enough messages were processed: source:%d sink:%d" % (src_seqno_max, sink_seqno_max)) + success = False + + if not success: + self.mark_for_collect(self.cc) + # Also collect the data in the topic to aid in debugging + consumer_validator = ConsoleConsumer(self.test_context, 1, self.kafka, self.source.topic, consumer_timeout_ms=1000, print_key=True) + consumer_validator.run() + self.mark_for_collect(consumer_validator, "consumer_stdout") + + assert success, "Found validation errors:\n" + "\n ".join(errors) + + @cluster(num_nodes=6) + @matrix(connect_protocol=['sessioned', 'compatible', 'eager']) + def test_transformations(self, connect_protocol): + self.CONNECT_PROTOCOL = connect_protocol + self.setup_services(timestamp_type='CreateTime') + self.cc.set_configs(lambda node: self.render("connect-distributed.properties", node=node)) + self.cc.start() + + ts_fieldname = 'the_timestamp' + + NamedConnector = namedtuple('Connector', ['name']) + + source_connector = NamedConnector(name='file-src') + + self.cc.create_connector({ + 'name': source_connector.name, + 'connector.class': 'org.apache.kafka.connect.file.FileStreamSourceConnector', + 'tasks.max': 1, + 'file': self.INPUT_FILE, + 'topic': self.TOPIC, + 'transforms': 'hoistToStruct,insertTimestampField', + 'transforms.hoistToStruct.type': 'org.apache.kafka.connect.transforms.HoistField$Value', + 'transforms.hoistToStruct.field': 'content', + 'transforms.insertTimestampField.type': 'org.apache.kafka.connect.transforms.InsertField$Value', + 'transforms.insertTimestampField.timestamp.field': ts_fieldname, + }) + + wait_until(lambda: self.connector_is_running(source_connector), timeout_sec=30, err_msg='Failed to see connector transition to the RUNNING state') + + for node in self.cc.nodes: + node.account.ssh("echo -e -n " + repr(self.FIRST_INPUTS) + " >> " + self.INPUT_FILE) + + consumer = ConsoleConsumer(self.test_context, 1, self.kafka, self.TOPIC, consumer_timeout_ms=15000, print_timestamp=True) + consumer.run() + + assert len(consumer.messages_consumed[1]) == len(self.FIRST_INPUT_LIST) + + expected_schema = { + 'type': 'struct', + 'fields': [ + {'field': 'content', 'type': 'string', 'optional': False}, + {'field': ts_fieldname, 'name': 
'org.apache.kafka.connect.data.Timestamp', 'type': 'int64', 'version': 1, 'optional': True}, + ], + 'optional': False + } + + for msg in consumer.messages_consumed[1]: + (ts_info, value) = msg.split('\t') + + assert ts_info.startswith('CreateTime:') + ts = int(ts_info[len('CreateTime:'):]) + + obj = json.loads(value) + assert obj['schema'] == expected_schema + assert obj['payload']['content'] in self.FIRST_INPUT_LIST + assert obj['payload'][ts_fieldname] == ts + + @cluster(num_nodes=5) + @parametrize(broker_version=str(DEV_BRANCH), auto_create_topics=False, security_protocol=SecurityConfig.PLAINTEXT, connect_protocol='sessioned') + @parametrize(broker_version=str(LATEST_0_11_0), auto_create_topics=False, security_protocol=SecurityConfig.PLAINTEXT, connect_protocol='sessioned') + @parametrize(broker_version=str(LATEST_0_10_2), auto_create_topics=False, security_protocol=SecurityConfig.PLAINTEXT, connect_protocol='sessioned') + @parametrize(broker_version=str(LATEST_0_10_1), auto_create_topics=False, security_protocol=SecurityConfig.PLAINTEXT, connect_protocol='sessioned') + @parametrize(broker_version=str(LATEST_0_10_0), auto_create_topics=True, security_protocol=SecurityConfig.PLAINTEXT, connect_protocol='sessioned') + @parametrize(broker_version=str(DEV_BRANCH), auto_create_topics=False, security_protocol=SecurityConfig.PLAINTEXT, connect_protocol='compatible') + @parametrize(broker_version=str(LATEST_2_3), auto_create_topics=False, security_protocol=SecurityConfig.PLAINTEXT, connect_protocol='compatible') + @parametrize(broker_version=str(LATEST_2_2), auto_create_topics=False, security_protocol=SecurityConfig.PLAINTEXT, connect_protocol='compatible') + @parametrize(broker_version=str(LATEST_2_1), auto_create_topics=False, security_protocol=SecurityConfig.PLAINTEXT, connect_protocol='compatible') + @parametrize(broker_version=str(LATEST_2_0), auto_create_topics=False, security_protocol=SecurityConfig.PLAINTEXT, connect_protocol='compatible') + @parametrize(broker_version=str(LATEST_1_1), auto_create_topics=False, security_protocol=SecurityConfig.PLAINTEXT, connect_protocol='compatible') + @parametrize(broker_version=str(LATEST_1_0), auto_create_topics=False, security_protocol=SecurityConfig.PLAINTEXT, connect_protocol='compatible') + @parametrize(broker_version=str(LATEST_0_11_0), auto_create_topics=False, security_protocol=SecurityConfig.PLAINTEXT, connect_protocol='compatible') + @parametrize(broker_version=str(LATEST_0_10_2), auto_create_topics=False, security_protocol=SecurityConfig.PLAINTEXT, connect_protocol='compatible') + @parametrize(broker_version=str(LATEST_0_10_1), auto_create_topics=False, security_protocol=SecurityConfig.PLAINTEXT, connect_protocol='compatible') + @parametrize(broker_version=str(LATEST_0_10_0), auto_create_topics=True, security_protocol=SecurityConfig.PLAINTEXT, connect_protocol='compatible') + @parametrize(broker_version=str(DEV_BRANCH), auto_create_topics=False, security_protocol=SecurityConfig.PLAINTEXT, connect_protocol='eager') + @parametrize(broker_version=str(LATEST_2_3), auto_create_topics=False, security_protocol=SecurityConfig.PLAINTEXT, connect_protocol='eager') + @parametrize(broker_version=str(LATEST_2_2), auto_create_topics=False, security_protocol=SecurityConfig.PLAINTEXT, connect_protocol='eager') + @parametrize(broker_version=str(LATEST_2_1), auto_create_topics=False, security_protocol=SecurityConfig.PLAINTEXT, connect_protocol='eager') + @parametrize(broker_version=str(LATEST_2_0), auto_create_topics=False, 
security_protocol=SecurityConfig.PLAINTEXT, connect_protocol='eager') + @parametrize(broker_version=str(LATEST_1_1), auto_create_topics=False, security_protocol=SecurityConfig.PLAINTEXT, connect_protocol='eager') + @parametrize(broker_version=str(LATEST_1_0), auto_create_topics=False, security_protocol=SecurityConfig.PLAINTEXT, connect_protocol='eager') + @parametrize(broker_version=str(LATEST_0_11_0), auto_create_topics=False, security_protocol=SecurityConfig.PLAINTEXT, connect_protocol='eager') + @parametrize(broker_version=str(LATEST_0_10_2), auto_create_topics=False, security_protocol=SecurityConfig.PLAINTEXT, connect_protocol='eager') + @parametrize(broker_version=str(LATEST_0_10_1), auto_create_topics=False, security_protocol=SecurityConfig.PLAINTEXT, connect_protocol='eager') + @parametrize(broker_version=str(LATEST_0_10_0), auto_create_topics=True, security_protocol=SecurityConfig.PLAINTEXT, connect_protocol='eager') + def test_broker_compatibility(self, broker_version, auto_create_topics, security_protocol, connect_protocol): + """ + Verify that Connect will start up with various broker versions with various configurations. + When Connect distributed starts up, it either creates internal topics (v0.10.1.0 and after) + or relies upon the broker to auto-create the topics (v0.10.0.x and before). + """ + self.CONNECT_PROTOCOL = connect_protocol + self.setup_services(broker_version=KafkaVersion(broker_version), auto_create_topics=auto_create_topics, security_protocol=security_protocol) + self.cc.set_configs(lambda node: self.render("connect-distributed.properties", node=node)) + + self.cc.start() + + self.logger.info("Creating connectors") + self._start_connector("connect-file-source.properties") + self._start_connector("connect-file-sink.properties") + + # Generating data on the source node should generate new records and create new output on the sink node. Timeouts + # here need to be more generous than they are for standalone mode because a) it takes longer to write configs, + # do rebalancing of the group, etc, and b) without explicit leave group support, rebalancing takes awhile + for node in self.cc.nodes: + node.account.ssh("echo -e -n " + repr(self.FIRST_INPUTS) + " >> " + self.INPUT_FILE) + wait_until(lambda: self._validate_file_output(self.FIRST_INPUT_LIST), timeout_sec=70, err_msg="Data added to input file was not seen in the output file in a reasonable amount of time.") + + def _validate_file_output(self, input): + input_set = set(input) + # Output needs to be collected from all nodes because we can't be sure where the tasks will be scheduled. + # Between the first and second rounds, we might even end up with half the data on each node. + output_set = set(itertools.chain(*[ + [line.strip() for line in self._file_contents(node, self.OUTPUT_FILE)] for node in self.cc.nodes + ])) + return input_set == output_set + + def _file_contents(self, node, file): + try: + # Convert to a list here or the RemoteCommandError may be returned during a call to the generator instead of + # immediately + return list(node.account.ssh_capture("cat " + file)) + except RemoteCommandError: + return [] diff --git a/tests/kafkatest/tests/connect/connect_rest_test.py b/tests/kafkatest/tests/connect/connect_rest_test.py new file mode 100644 index 0000000..4d978a2 --- /dev/null +++ b/tests/kafkatest/tests/connect/connect_rest_test.py @@ -0,0 +1,218 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. 
See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from kafkatest.tests.kafka_test import KafkaTest +from kafkatest.services.connect import ConnectDistributedService, ConnectRestError, ConnectServiceBase +from ducktape.utils.util import wait_until +from ducktape.mark import matrix +from ducktape.mark.resource import cluster +from ducktape.cluster.remoteaccount import RemoteCommandError + +import json +import itertools + + +class ConnectRestApiTest(KafkaTest): + """ + Test of Kafka Connect's REST API endpoints. + """ + + FILE_SOURCE_CONNECTOR = 'org.apache.kafka.connect.file.FileStreamSourceConnector' + FILE_SINK_CONNECTOR = 'org.apache.kafka.connect.file.FileStreamSinkConnector' + + FILE_SOURCE_CONFIGS = {'name', 'connector.class', 'tasks.max', 'key.converter', 'value.converter', 'header.converter', 'batch.size', + 'topic', 'file', 'transforms', 'config.action.reload', 'errors.retry.timeout', 'errors.retry.delay.max.ms', + 'errors.tolerance', 'errors.log.enable', 'errors.log.include.messages', 'predicates', 'topic.creation.groups'} + FILE_SINK_CONFIGS = {'name', 'connector.class', 'tasks.max', 'key.converter', 'value.converter', 'header.converter', 'topics', + 'file', 'transforms', 'topics.regex', 'config.action.reload', 'errors.retry.timeout', 'errors.retry.delay.max.ms', + 'errors.tolerance', 'errors.log.enable', 'errors.log.include.messages', 'errors.deadletterqueue.topic.name', + 'errors.deadletterqueue.topic.replication.factor', 'errors.deadletterqueue.context.headers.enable', 'predicates'} + + INPUT_FILE = "/mnt/connect.input" + INPUT_FILE2 = "/mnt/connect.input2" + OUTPUT_FILE = "/mnt/connect.output" + + TOPIC = "topic-${file:%s:topic.external}" % ConnectServiceBase.EXTERNAL_CONFIGS_FILE + TOPIC_TEST = "test" + + DEFAULT_BATCH_SIZE = "2000" + OFFSETS_TOPIC = "connect-offsets" + OFFSETS_REPLICATION_FACTOR = "1" + OFFSETS_PARTITIONS = "1" + CONFIG_TOPIC = "connect-configs" + CONFIG_REPLICATION_FACTOR = "1" + STATUS_TOPIC = "connect-status" + STATUS_REPLICATION_FACTOR = "1" + STATUS_PARTITIONS = "1" + + # Since tasks can be assigned to any node and we're testing with files, we need to make sure the content is the same + # across all nodes. 
+ INPUT_LIST = ["foo", "bar", "baz"] + INPUTS = "\n".join(INPUT_LIST) + "\n" + LONGER_INPUT_LIST = ["foo", "bar", "baz", "razz", "ma", "tazz"] + LONER_INPUTS = "\n".join(LONGER_INPUT_LIST) + "\n" + + SCHEMA = {"type": "string", "optional": False} + + CONNECT_PROTOCOL="compatible" + + def __init__(self, test_context): + super(ConnectRestApiTest, self).__init__(test_context, num_zk=1, num_brokers=1, topics={ + 'test': {'partitions': 1, 'replication-factor': 1} + }) + + self.cc = ConnectDistributedService(test_context, 2, self.kafka, [self.INPUT_FILE, self.INPUT_FILE2, self.OUTPUT_FILE]) + + @cluster(num_nodes=4) + @matrix(connect_protocol=['compatible', 'eager']) + def test_rest_api(self, connect_protocol): + # Template parameters + self.key_converter = "org.apache.kafka.connect.json.JsonConverter" + self.value_converter = "org.apache.kafka.connect.json.JsonConverter" + self.schemas = True + self.CONNECT_PROTOCOL = connect_protocol + + self.cc.set_configs(lambda node: self.render("connect-distributed.properties", node=node)) + self.cc.set_external_configs(lambda node: self.render("connect-file-external.properties", node=node)) + + self.cc.start() + + assert self.cc.list_connectors() == [] + + # After MM2 and the connector classes that it added, the assertion here checks that the registered + # Connect plugins are a superset of the connectors expected to be present. + assert set([connector_plugin['class'] for connector_plugin in self.cc.list_connector_plugins()]).issuperset( + {self.FILE_SOURCE_CONNECTOR, self.FILE_SINK_CONNECTOR}) + + source_connector_props = self.render("connect-file-source.properties") + sink_connector_props = self.render("connect-file-sink.properties") + + self.logger.info("Validating connector configurations") + source_connector_config = self._config_dict_from_props(source_connector_props) + configs = self.cc.validate_config(self.FILE_SOURCE_CONNECTOR, source_connector_config) + self.verify_config(self.FILE_SOURCE_CONNECTOR, self.FILE_SOURCE_CONFIGS, configs) + + sink_connector_config = self._config_dict_from_props(sink_connector_props) + configs = self.cc.validate_config(self.FILE_SINK_CONNECTOR, sink_connector_config) + self.verify_config(self.FILE_SINK_CONNECTOR, self.FILE_SINK_CONFIGS, configs) + + self.logger.info("Creating connectors") + self.cc.create_connector(source_connector_config) + self.cc.create_connector(sink_connector_config) + + # We should see the connectors appear + wait_until(lambda: set(self.cc.list_connectors()) == set(["local-file-source", "local-file-sink"]), + timeout_sec=10, err_msg="Connectors that were just created did not appear in connector listing") + + # We'll only do very simple validation that the connectors and tasks really ran. 
+ for node in self.cc.nodes: + node.account.ssh("echo -e -n " + repr(self.INPUTS) + " >> " + self.INPUT_FILE) + wait_until(lambda: self.validate_output(self.INPUT_LIST), timeout_sec=120, err_msg="Data added to input file was not seen in the output file in a reasonable amount of time.") + + # Trying to create the same connector again should cause an error + try: + self.cc.create_connector(self._config_dict_from_props(source_connector_props)) + assert False, "creating the same connector should have caused a conflict" + except ConnectRestError: + pass # expected + + # Validate that we can get info about connectors + expected_source_info = { + 'name': 'local-file-source', + 'config': self._config_dict_from_props(source_connector_props), + 'tasks': [{'connector': 'local-file-source', 'task': 0}], + 'type': 'source' + } + source_info = self.cc.get_connector("local-file-source") + assert expected_source_info == source_info, "Incorrect info:" + json.dumps(source_info) + source_config = self.cc.get_connector_config("local-file-source") + assert expected_source_info['config'] == source_config, "Incorrect config: " + json.dumps(source_config) + expected_sink_info = { + 'name': 'local-file-sink', + 'config': self._config_dict_from_props(sink_connector_props), + 'tasks': [{'connector': 'local-file-sink', 'task': 0}], + 'type': 'sink' + } + sink_info = self.cc.get_connector("local-file-sink") + assert expected_sink_info == sink_info, "Incorrect info:" + json.dumps(sink_info) + sink_config = self.cc.get_connector_config("local-file-sink") + assert expected_sink_info['config'] == sink_config, "Incorrect config: " + json.dumps(sink_config) + + # Validate that we can get info about tasks. This info should definitely be available now without waiting since + # we've already seen data appear in files. + # TODO: It would be nice to validate a complete listing, but that doesn't make sense for the file connectors + expected_source_task_info = [{ + 'id': {'connector': 'local-file-source', 'task': 0}, + 'config': { + 'task.class': 'org.apache.kafka.connect.file.FileStreamSourceTask', + 'file': self.INPUT_FILE, + 'topic': self.TOPIC, + 'batch.size': self.DEFAULT_BATCH_SIZE + } + }] + source_task_info = self.cc.get_connector_tasks("local-file-source") + assert expected_source_task_info == source_task_info, "Incorrect info:" + json.dumps(source_task_info) + expected_sink_task_info = [{ + 'id': {'connector': 'local-file-sink', 'task': 0}, + 'config': { + 'task.class': 'org.apache.kafka.connect.file.FileStreamSinkTask', + 'file': self.OUTPUT_FILE, + 'topics': self.TOPIC + } + }] + sink_task_info = self.cc.get_connector_tasks("local-file-sink") + assert expected_sink_task_info == sink_task_info, "Incorrect info:" + json.dumps(sink_task_info) + + file_source_config = self._config_dict_from_props(source_connector_props) + file_source_config['file'] = self.INPUT_FILE2 + self.cc.set_connector_config("local-file-source", file_source_config) + + # We should also be able to verify that the modified configs caused the tasks to move to the new file and pick up + # more data. 
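The set_connector_config call above is a thin wrapper around Connect's REST API. Purely as an illustration of the underlying request (the worker hostname below is a placeholder; the ducktape service resolves the real node, and 8083 is only Connect's default REST port), the same config update could be issued directly with the requests library:

import requests

# PUT /connectors/{name}/config replaces the connector configuration and
# returns the updated connector info.
worker_url = "http://worker-host:8083"   # placeholder, not a real test node
new_config = {
    "connector.class": "org.apache.kafka.connect.file.FileStreamSourceConnector",
    "tasks.max": "1",
    "file": "/mnt/connect.input2",       # point the source at the second input file
    "topic": "test",
}
resp = requests.put("%s/connectors/local-file-source/config" % worker_url,
                    json=new_config)
resp.raise_for_status()
print(resp.json())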
+ for node in self.cc.nodes: + node.account.ssh("echo -e -n " + repr(self.LONER_INPUTS) + " >> " + self.INPUT_FILE2) + wait_until(lambda: self.validate_output(self.LONGER_INPUT_LIST), timeout_sec=120, err_msg="Data added to input file was not seen in the output file in a reasonable amount of time.") + + self.cc.delete_connector("local-file-source") + self.cc.delete_connector("local-file-sink") + wait_until(lambda: len(self.cc.list_connectors()) == 0, timeout_sec=10, err_msg="Deleted connectors did not disappear from REST listing") + + def validate_output(self, input): + input_set = set(input) + # Output needs to be collected from all nodes because we can't be sure where the tasks will be scheduled. + output_set = set(itertools.chain(*[ + [line.strip() for line in self.file_contents(node, self.OUTPUT_FILE)] for node in self.cc.nodes + ])) + return input_set == output_set + + def file_contents(self, node, file): + try: + # Convert to a list here or the RemoteCommandError may be returned during a call to the generator instead of + # immediately + return list(node.account.ssh_capture("cat " + file)) + except RemoteCommandError: + return [] + + def _config_dict_from_props(self, connector_props): + return dict([line.strip().split('=', 1) for line in connector_props.split('\n') if line.strip() and not line.strip().startswith('#')]) + + def verify_config(self, name, config_def, configs): + # Should have zero errors + assert name == configs['name'] + # Should have zero errors + assert 0 == configs['error_count'] + # Should return all configuration + config_names = [config['definition']['name'] for config in configs['configs']] + assert config_def == set(config_names) diff --git a/tests/kafkatest/tests/connect/connect_test.py b/tests/kafkatest/tests/connect/connect_test.py new file mode 100644 index 0000000..1a7f6ab --- /dev/null +++ b/tests/kafkatest/tests/connect/connect_test.py @@ -0,0 +1,209 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
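The REST test above turns rendered .properties files into the dictionaries the Connect REST API expects via _config_dict_from_props. A minimal standalone sketch of that round trip (the property values here are illustrative, not the rendered templates):

import json

def props_to_dict(props_text):
    # Same idea as _config_dict_from_props above: one key=value pair per line,
    # splitting only on the first '=' and skipping blanks and '#' comments.
    return dict(line.strip().split('=', 1)
                for line in props_text.splitlines()
                if line.strip() and not line.strip().startswith('#'))

example_props = """
# rendered connector config (illustrative values)
name=local-file-source
connector.class=org.apache.kafka.connect.file.FileStreamSourceConnector
tasks.max=1
file=/mnt/connect.input
topic=test
"""

config = props_to_dict(example_props)
# POST /connectors expects a body of the form {"name": ..., "config": {...}}
payload = {"name": config["name"], "config": config}
print(json.dumps(payload, indent=2))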
+ +from ducktape.tests.test import Test +from ducktape.mark.resource import cluster +from ducktape.utils.util import wait_until +from ducktape.mark import matrix, parametrize +from ducktape.cluster.remoteaccount import RemoteCommandError +from ducktape.errors import TimeoutError + +from kafkatest.services.zookeeper import ZookeeperService +from kafkatest.services.kafka import KafkaService, quorum +from kafkatest.services.connect import ConnectServiceBase, ConnectStandaloneService, ErrorTolerance +from kafkatest.services.console_consumer import ConsoleConsumer +from kafkatest.services.security.security_config import SecurityConfig + +import hashlib +import json + + +class ConnectStandaloneFileTest(Test): + """ + Simple test of Kafka Connect that produces data from a file in one + standalone process and consumes it on another, validating the output is + identical to the input. + """ + + FILE_SOURCE_CONNECTOR = 'org.apache.kafka.connect.file.FileStreamSourceConnector' + FILE_SINK_CONNECTOR = 'org.apache.kafka.connect.file.FileStreamSinkConnector' + + INPUT_FILE = "/mnt/connect.input" + OUTPUT_FILE = "/mnt/connect.output" + + OFFSETS_FILE = "/mnt/connect.offsets" + + TOPIC = "${file:%s:topic.external}" % ConnectServiceBase.EXTERNAL_CONFIGS_FILE + TOPIC_TEST = "test" + + FIRST_INPUT_LIST = ["foo", "bar", "baz"] + FIRST_INPUT = "\n".join(FIRST_INPUT_LIST) + "\n" + SECOND_INPUT_LIST = ["razz", "ma", "tazz"] + SECOND_INPUT = "\n".join(SECOND_INPUT_LIST) + "\n" + + SCHEMA = { "type": "string", "optional": False } + + def __init__(self, test_context): + super(ConnectStandaloneFileTest, self).__init__(test_context) + self.num_zk = 1 + self.num_brokers = 1 + self.topics = { + 'test' : { 'partitions': 1, 'replication-factor': 1 } + } + + self.zk = ZookeeperService(test_context, self.num_zk) if quorum.for_test(test_context) == quorum.zk else None + + @cluster(num_nodes=5) + @parametrize(converter="org.apache.kafka.connect.json.JsonConverter", schemas=True) + @parametrize(converter="org.apache.kafka.connect.json.JsonConverter", schemas=False) + @parametrize(converter="org.apache.kafka.connect.storage.StringConverter", schemas=None) + @parametrize(security_protocol=SecurityConfig.PLAINTEXT) + @cluster(num_nodes=6) + @matrix(security_protocol=[SecurityConfig.SASL_SSL], metadata_quorum=quorum.all_non_upgrade) + def test_file_source_and_sink(self, converter="org.apache.kafka.connect.json.JsonConverter", schemas=True, security_protocol='PLAINTEXT', + metadata_quorum=quorum.zk): + """ + Validates basic end-to-end functionality of Connect standalone using the file source and sink converters. Includes + parameterizations to test different converters (which also test per-connector converter overrides), schema/schemaless + modes, and security support. + """ + assert converter != None, "converter type must be set" + # Template parameters. Note that we don't set key/value.converter. These default to JsonConverter and we validate + # converter overrides via the connector configuration. 
+ if converter != "org.apache.kafka.connect.json.JsonConverter": + self.override_key_converter = converter + self.override_value_converter = converter + self.schemas = schemas + + self.kafka = KafkaService(self.test_context, self.num_brokers, self.zk, + security_protocol=security_protocol, interbroker_security_protocol=security_protocol, + topics=self.topics, controller_num_nodes_override=self.num_zk) + + self.source = ConnectStandaloneService(self.test_context, self.kafka, [self.INPUT_FILE, self.OFFSETS_FILE]) + self.sink = ConnectStandaloneService(self.test_context, self.kafka, [self.OUTPUT_FILE, self.OFFSETS_FILE]) + self.consumer_validator = ConsoleConsumer(self.test_context, 1, self.kafka, self.TOPIC_TEST, + consumer_timeout_ms=10000) + + if self.zk: + self.zk.start() + self.kafka.start() + + self.source.set_configs(lambda node: self.render("connect-standalone.properties", node=node), [self.render("connect-file-source.properties")]) + self.sink.set_configs(lambda node: self.render("connect-standalone.properties", node=node), [self.render("connect-file-sink.properties")]) + + self.source.set_external_configs(lambda node: self.render("connect-file-external.properties", node=node)) + self.sink.set_external_configs(lambda node: self.render("connect-file-external.properties", node=node)) + + self.source.start() + self.sink.start() + + # Generating data on the source node should generate new records and create new output on the sink node + self.source.node.account.ssh("echo -e -n " + repr(self.FIRST_INPUT) + " >> " + self.INPUT_FILE) + wait_until(lambda: self.validate_output(self.FIRST_INPUT), timeout_sec=60, err_msg="Data added to input file was not seen in the output file in a reasonable amount of time.") + + # Restarting both should result in them picking up where they left off, + # only processing new data. 
+ self.source.restart() + self.sink.restart() + + self.source.node.account.ssh("echo -e -n " + repr(self.SECOND_INPUT) + " >> " + self.INPUT_FILE) + wait_until(lambda: self.validate_output(self.FIRST_INPUT + self.SECOND_INPUT), timeout_sec=60, err_msg="Sink output file never converged to the same state as the input file") + + # Validate the format of the data in the Kafka topic + self.consumer_validator.run() + expected = json.dumps([line if not self.schemas else { "schema": self.SCHEMA, "payload": line } for line in self.FIRST_INPUT_LIST + self.SECOND_INPUT_LIST]) + decoder = (json.loads if converter.endswith("JsonConverter") else str) + actual = json.dumps([decoder(x) for x in self.consumer_validator.messages_consumed[1]]) + assert expected == actual, "Expected %s but saw %s in Kafka" % (expected, actual) + + def validate_output(self, value): + try: + output_hash = list(self.sink.node.account.ssh_capture("md5sum " + self.OUTPUT_FILE))[0].strip().split()[0] + return output_hash == hashlib.md5(value.encode('utf-8')).hexdigest() + except RemoteCommandError: + return False + + @cluster(num_nodes=5) + @parametrize(error_tolerance=ErrorTolerance.ALL) + @parametrize(error_tolerance=ErrorTolerance.NONE) + def test_skip_and_log_to_dlq(self, error_tolerance): + self.kafka = KafkaService(self.test_context, self.num_brokers, self.zk, topics=self.topics) + + # set config props + self.override_error_tolerance_props = error_tolerance + self.enable_deadletterqueue = True + + successful_records = [] + faulty_records = [] + records = [] + for i in range(0, 1000): + if i % 2 == 0: + records.append('{"some_key":' + str(i) + '}') + successful_records.append('{some_key=' + str(i) + '}') + else: + # badly formatted json records (missing a quote after the key) + records.append('{"some_key:' + str(i) + '}') + faulty_records.append('{"some_key:' + str(i) + '}') + + records = "\n".join(records) + "\n" + successful_records = "\n".join(successful_records) + "\n" + if error_tolerance == ErrorTolerance.ALL: + faulty_records = ",".join(faulty_records) + else: + faulty_records = faulty_records[0] + + self.source = ConnectStandaloneService(self.test_context, self.kafka, [self.INPUT_FILE, self.OFFSETS_FILE]) + self.sink = ConnectStandaloneService(self.test_context, self.kafka, [self.OUTPUT_FILE, self.OFFSETS_FILE]) + + self.zk.start() + self.kafka.start() + + self.override_key_converter = "org.apache.kafka.connect.storage.StringConverter" + self.override_value_converter = "org.apache.kafka.connect.storage.StringConverter" + self.source.set_configs(lambda node: self.render("connect-standalone.properties", node=node), [self.render("connect-file-source.properties")]) + + self.override_key_converter = "org.apache.kafka.connect.json.JsonConverter" + self.override_value_converter = "org.apache.kafka.connect.json.JsonConverter" + self.override_key_converter_schemas_enable = False + self.override_value_converter_schemas_enable = False + self.sink.set_configs(lambda node: self.render("connect-standalone.properties", node=node), [self.render("connect-file-sink.properties")]) + + self.source.set_external_configs(lambda node: self.render("connect-file-external.properties", node=node)) + self.sink.set_external_configs(lambda node: self.render("connect-file-external.properties", node=node)) + + self.source.start() + self.sink.start() + + # Generating data on the source node should generate new records and create new output on the sink node + self.source.node.account.ssh("echo -e -n " + repr(records) + " >> " + self.INPUT_FILE) + + if 
error_tolerance == ErrorTolerance.NONE: + try: + wait_until(lambda: self.validate_output(successful_records), timeout_sec=15, + err_msg="Clean records added to input file were not seen in the output file in a reasonable amount of time.") + raise Exception("Expected to not find any results in this file.") + except TimeoutError: + self.logger.info("Caught expected exception") + else: + wait_until(lambda: self.validate_output(successful_records), timeout_sec=15, + err_msg="Clean records added to input file were not seen in the output file in a reasonable amount of time.") + + if self.enable_deadletterqueue: + self.logger.info("Reading records from deadletterqueue") + consumer_validator = ConsoleConsumer(self.test_context, 1, self.kafka, "my-connector-errors", + consumer_timeout_ms=10000) + consumer_validator.run() + actual = ",".join(consumer_validator.messages_consumed[1]) + assert faulty_records == actual, "Expected %s but saw %s in dead letter queue" % (faulty_records, actual) diff --git a/tests/kafkatest/tests/connect/templates/connect-distributed.properties b/tests/kafkatest/tests/connect/templates/connect-distributed.properties new file mode 100644 index 0000000..6d2d5e2 --- /dev/null +++ b/tests/kafkatest/tests/connect/templates/connect-distributed.properties @@ -0,0 +1,61 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
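The dead letter queue assertions above hinge on which generated records fail JSON parsing. As a self-contained illustration (plain json parsing, not the Connect runtime's converter path), the good/bad split produced by test_skip_and_log_to_dlq can be reproduced like this:

import json

# Same alternating pattern as the test, shortened: even records are valid JSON,
# odd records are missing the closing quote after the key.
records = ['{"some_key":%d}' % i if i % 2 == 0 else '{"some_key:%d}' % i
           for i in range(10)]

valid, faulty = [], []
for record in records:
    try:
        json.loads(record)
        valid.append(record)
    except ValueError:
        # With errors.tolerance=all the Connect runtime skips records that fail
        # conversion and, when a DLQ topic is configured, produces the raw
        # record to that topic instead of failing the task.
        faulty.append(record)

print("parsed successfully:", valid)
print("would be routed to the dead letter queue:", faulty)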
+ +bootstrap.servers={{ kafka.bootstrap_servers(kafka.security_config.security_protocol) }} +{{ kafka.security_config.client_config().props() }} +{{ kafka.security_config.client_config().props("producer.") }} +{{ kafka.security_config.client_config().props("consumer.") }} + +group.id={{ group|default("connect-cluster") }} + +connect.protocol={{ CONNECT_PROTOCOL|default("sessioned") }} +scheduled.rebalance.max.delay.ms={{ SCHEDULED_REBALANCE_MAX_DELAY_MS|default(60000) }} + +key.converter={{ key_converter|default("org.apache.kafka.connect.json.JsonConverter") }} +value.converter={{ value_converter|default("org.apache.kafka.connect.json.JsonConverter") }} +{% if key_converter is not defined or key_converter.endswith("JsonConverter") %} +key.converter.schemas.enable={{ schemas|default(True)|string|lower }} +{% endif %} +{% if value_converter is not defined or value_converter.endswith("JsonConverter") %} +value.converter.schemas.enable={{ schemas|default(True)|string|lower }} +{% endif %} + +offset.storage.topic={{ OFFSETS_TOPIC }} +offset.storage.replication.factor={{ OFFSETS_REPLICATION_FACTOR }} +offset.storage.partitions={{ OFFSETS_PARTITIONS }} +config.storage.topic={{ CONFIG_TOPIC }} +config.storage.replication.factor={{ CONFIG_REPLICATION_FACTOR }} +status.storage.topic={{ STATUS_TOPIC }} +status.storage.replication.factor={{ STATUS_REPLICATION_FACTOR }} +status.storage.partitions={{ STATUS_PARTITIONS }} + +# Make sure data gets flushed frequently so tests don't have to wait to ensure they see data in output systems +offset.flush.interval.ms=5000 + +rest.advertised.host.name = {{ node.account.hostname }} + + +# Reduce session timeouts so tests that kill workers don't need to wait as long to recover +session.timeout.ms=10000 +consumer.session.timeout.ms=10000 + +# Reduce the admin client request timeouts so that we don't wait the default 120 sec before failing to connect the admin client +request.timeout.ms=30000 + +# Allow connector configs to use externalized config values of the form: +# ${file:/mnt/connect/connect-external-configs.properties:topic.external} +# +config.providers=file +config.providers.file.class=org.apache.kafka.common.config.provider.FileConfigProvider diff --git a/tests/kafkatest/tests/connect/templates/connect-file-external.properties b/tests/kafkatest/tests/connect/templates/connect-file-external.properties new file mode 100644 index 0000000..8dccd25 --- /dev/null +++ b/tests/kafkatest/tests/connect/templates/connect-file-external.properties @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
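The worker properties above are Jinja2 templates that ducktape renders with the test's attributes; variables that a test does not set fall back to the default() filters. A quick illustration of the mechanics using jinja2 directly (the variable values are arbitrary):

from jinja2 import Template

template = Template(
    "connect.protocol={{ CONNECT_PROTOCOL|default('sessioned') }}\n"
    "group.id={{ group|default('connect-cluster') }}\n"
    "key.converter.schemas.enable={{ schemas|default(True)|string|lower }}\n"
)

# Only CONNECT_PROTOCOL is supplied, so the other two fall back to their defaults.
print(template.render(CONNECT_PROTOCOL="compatible"))
# Supplying schemas shows how the boolean is lowered to a Java-properties style value.
print(template.render(CONNECT_PROTOCOL="eager", schemas=False))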
+ +topic.external={{ TOPIC_TEST }} diff --git a/tests/kafkatest/tests/connect/templates/connect-file-sink.properties b/tests/kafkatest/tests/connect/templates/connect-file-sink.properties new file mode 100644 index 0000000..a58cc6b --- /dev/null +++ b/tests/kafkatest/tests/connect/templates/connect-file-sink.properties @@ -0,0 +1,44 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +name=local-file-sink +connector.class={{ FILE_SINK_CONNECTOR }} +tasks.max=1 +file={{ OUTPUT_FILE }} +topics={{ TOPIC }} + +# For testing per-connector converters +{% if override_key_converter is defined %} +key.converter={{ override_key_converter }} +{% endif %} +{% if override_key_converter is defined %} +value.converter={{ override_value_converter }} +{% endif %} + +key.converter.schemas.enable={{ override_key_converter_schemas_enable|default(True) }} +value.converter.schemas.enable={{ override_value_converter_schemas_enable|default(True) }} + +# log error context along with application logs +errors.log.enable=true +errors.log.include.messages=true + +{% if enable_deadletterqueue is defined %} +# produce error context into the Kafka topic +errors.deadletterqueue.topic.name={{ override_deadletterqueue_topic_name|default("my-connector-errors") }} +errors.deadletterqueue.topic.replication.factor={{ override_deadletterqueue_replication_factor|default(1) }} +{% endif %} + +# Tolerate all errors. +errors.tolerance={{ override_error_tolerance_props|default("none") }} diff --git a/tests/kafkatest/tests/connect/templates/connect-file-source.properties b/tests/kafkatest/tests/connect/templates/connect-file-source.properties new file mode 100644 index 0000000..147e85a --- /dev/null +++ b/tests/kafkatest/tests/connect/templates/connect-file-source.properties @@ -0,0 +1,35 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
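The topic value placed in these connector configs still contains a ${file:<path>:<key>} placeholder; the workers resolve it at runtime through the FileConfigProvider configured in the worker properties, reading connect-file-external.properties rendered above. Conceptually (this sketch is illustrative, not the provider's real implementation), the substitution works like this:

import re

def resolve_file_placeholders(value, read_properties):
    # Replace each ${file:<path>:<key>} with the value of <key> from the
    # properties file at <path>; read_properties is injected so the sketch
    # stays runnable without touching the filesystem.
    def substitute(match):
        path, key = match.group(1), match.group(2)
        return read_properties(path)[key]
    return re.sub(r"\$\{file:([^:}]+):([^}]+)\}", substitute, value)

external = {"/mnt/connect/connect-external-configs.properties": {"topic.external": "test"}}
resolved = resolve_file_placeholders(
    "topic-${file:/mnt/connect/connect-external-configs.properties:topic.external}",
    lambda path: external[path])
print(resolved)   # -> topic-test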
+ +name=local-file-source +connector.class={{ FILE_SOURCE_CONNECTOR }} +tasks.max=1 +file={{ INPUT_FILE }} +topic={{ TOPIC }} + +# For testing per-connector converters +{% if override_key_converter is defined %} +key.converter={{ override_key_converter }} +{% endif %} +{% if override_key_converter is defined %} +value.converter={{ override_value_converter }} +{% endif %} + +# log error context along with application logs +errors.log.enable=true +errors.log.include.messages=true + +# Tolerate all errors. +errors.tolerance={{ override_error_tolerance_props|default("none") }} diff --git a/tests/kafkatest/tests/connect/templates/connect-standalone.properties b/tests/kafkatest/tests/connect/templates/connect-standalone.properties new file mode 100644 index 0000000..a471dd5 --- /dev/null +++ b/tests/kafkatest/tests/connect/templates/connect-standalone.properties @@ -0,0 +1,39 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +bootstrap.servers={{ kafka.bootstrap_servers(kafka.security_config.security_protocol) }} +{{ kafka.security_config.client_config().props() }} +{{ kafka.security_config.client_config().props("producer.") }} +{{ kafka.security_config.client_config().props("consumer.") }} + +key.converter={{ key_converter|default("org.apache.kafka.connect.json.JsonConverter") }} +value.converter={{ value_converter|default("org.apache.kafka.connect.json.JsonConverter") }} +{% if key_converter is not defined or key_converter.endswith("JsonConverter") %} +key.converter.schemas.enable={{ schemas|default(True)|string|lower }} +{% endif %} +{% if value_converter is not defined or value_converter.endswith("JsonConverter") %} +value.converter.schemas.enable={{ schemas|default(True)|string|lower }} +{% endif %} + +offset.storage.file.filename={{ OFFSETS_FILE }} + +# Reduce the admin client request timeouts so that we don't wait the default 120 sec before failing to connect the admin client +request.timeout.ms=30000 + +# Allow connector configs to use externalized config values of the form: +# ${file:/mnt/connect/connect-external-configs.properties:topic.external} +# +config.providers=file +config.providers.file.class=org.apache.kafka.common.config.provider.FileConfigProvider diff --git a/tests/kafkatest/tests/core/__init__.py b/tests/kafkatest/tests/core/__init__.py new file mode 100644 index 0000000..ec20143 --- /dev/null +++ b/tests/kafkatest/tests/core/__init__.py @@ -0,0 +1,14 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/kafkatest/tests/core/compatibility_test_new_broker_test.py b/tests/kafkatest/tests/core/compatibility_test_new_broker_test.py new file mode 100644 index 0000000..56a0e27 --- /dev/null +++ b/tests/kafkatest/tests/core/compatibility_test_new_broker_test.py @@ -0,0 +1,95 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ducktape.mark import matrix, parametrize +from ducktape.utils.util import wait_until +from ducktape.mark.resource import cluster + +from kafkatest.services.console_consumer import ConsoleConsumer +from kafkatest.services.kafka import KafkaService, quorum +from kafkatest.services.kafka import config_property +from kafkatest.services.verifiable_producer import VerifiableProducer +from kafkatest.services.zookeeper import ZookeeperService +from kafkatest.tests.produce_consume_validate import ProduceConsumeValidateTest +from kafkatest.utils import is_int +from kafkatest.version import LATEST_0_8_2, LATEST_0_9, LATEST_0_10_0, LATEST_0_10_1, LATEST_0_10_2, LATEST_0_11_0, LATEST_1_0, LATEST_1_1, LATEST_2_0, LATEST_2_1, LATEST_2_2, LATEST_2_3, LATEST_2_4, LATEST_2_5, LATEST_2_6, LATEST_2_7, LATEST_2_8, DEV_BRANCH, KafkaVersion + +# Compatibility tests for moving to a new broker (e.g., 0.10.x) and using a mix of old and new clients (e.g., 0.9.x) +class ClientCompatibilityTestNewBroker(ProduceConsumeValidateTest): + + def __init__(self, test_context): + super(ClientCompatibilityTestNewBroker, self).__init__(test_context=test_context) + + def setUp(self): + self.topic = "test_topic" + self.zk = ZookeeperService(self.test_context, num_nodes=1) if quorum.for_test(self.test_context) == quorum.zk else None + + if self.zk: + self.zk.start() + + # Producer and consumer + self.producer_throughput = 10000 + self.num_producers = 1 + self.num_consumers = 1 + self.messages_per_producer = 1000 + + @cluster(num_nodes=6) + @matrix(producer_version=[str(DEV_BRANCH)], consumer_version=[str(DEV_BRANCH)], compression_types=[["snappy"]], timestamp_type=[str("LogAppendTime")], metadata_quorum=quorum.all_non_upgrade) + @matrix(producer_version=[str(DEV_BRANCH)], consumer_version=[str(DEV_BRANCH)], compression_types=[["none"]], timestamp_type=[str("LogAppendTime")], metadata_quorum=quorum.all_non_upgrade) + @parametrize(producer_version=str(DEV_BRANCH), consumer_version=str(LATEST_0_9), compression_types=["none"], new_consumer=False, timestamp_type=None) + @matrix(producer_version=[str(DEV_BRANCH)], consumer_version=[str(LATEST_0_9)], compression_types=[["snappy"]], timestamp_type=[str("CreateTime")], metadata_quorum=quorum.all_non_upgrade) + 
@matrix(producer_version=[str(LATEST_2_2)], consumer_version=[str(LATEST_2_2)], compression_types=[["none"]], timestamp_type=[str("CreateTime")], metadata_quorum=quorum.all_non_upgrade) + @matrix(producer_version=[str(LATEST_2_3)], consumer_version=[str(LATEST_2_3)], compression_types=[["none"]], timestamp_type=[str("CreateTime")], metadata_quorum=quorum.all_non_upgrade) + @matrix(producer_version=[str(LATEST_2_4)], consumer_version=[str(LATEST_2_4)], compression_types=[["none"]], timestamp_type=[str("CreateTime")], metadata_quorum=quorum.all_non_upgrade) + @matrix(producer_version=[str(LATEST_2_5)], consumer_version=[str(LATEST_2_5)], compression_types=[["none"]], timestamp_type=[str("CreateTime")], metadata_quorum=quorum.all_non_upgrade) + @matrix(producer_version=[str(LATEST_2_6)], consumer_version=[str(LATEST_2_6)], compression_types=[["none"]], timestamp_type=[str("CreateTime")], metadata_quorum=quorum.all_non_upgrade) + @matrix(producer_version=[str(LATEST_2_7)], consumer_version=[str(LATEST_2_7)], compression_types=[["none"]], timestamp_type=[str("CreateTime")], metadata_quorum=quorum.all_non_upgrade) + @matrix(producer_version=[str(LATEST_2_8)], consumer_version=[str(LATEST_2_8)], compression_types=[["none"]], timestamp_type=[str("CreateTime")], metadata_quorum=quorum.all_non_upgrade) + @matrix(producer_version=[str(LATEST_2_1)], consumer_version=[str(LATEST_2_1)], compression_types=[["zstd"]], timestamp_type=[str("CreateTime")], metadata_quorum=quorum.all_non_upgrade) + @matrix(producer_version=[str(LATEST_2_0)], consumer_version=[str(LATEST_2_0)], compression_types=[["snappy"]], timestamp_type=[str("CreateTime")], metadata_quorum=quorum.all_non_upgrade) + @matrix(producer_version=[str(LATEST_1_1)], consumer_version=[str(LATEST_1_1)], compression_types=[["lz4"]], timestamp_type=[str("CreateTime")], metadata_quorum=quorum.all_non_upgrade) + @matrix(producer_version=[str(LATEST_1_0)], consumer_version=[str(LATEST_1_0)], compression_types=[["none"]], timestamp_type=[str("CreateTime")], metadata_quorum=quorum.all_non_upgrade) + @matrix(producer_version=[str(LATEST_0_11_0)], consumer_version=[str(LATEST_0_11_0)], compression_types=[["gzip"]], timestamp_type=[str("CreateTime")], metadata_quorum=quorum.all_non_upgrade) + @matrix(producer_version=[str(LATEST_0_10_2)], consumer_version=[str(LATEST_0_10_2)], compression_types=[["lz4"]], timestamp_type=[str("CreateTime")], metadata_quorum=quorum.all_non_upgrade) + @matrix(producer_version=[str(LATEST_0_10_1)], consumer_version=[str(LATEST_0_10_1)], compression_types=[["snappy"]], timestamp_type=[str("LogAppendTime")], metadata_quorum=quorum.all_non_upgrade) + @matrix(producer_version=[str(LATEST_0_10_0)], consumer_version=[str(LATEST_0_10_0)], compression_types=[["snappy"]], timestamp_type=[str("LogAppendTime")], metadata_quorum=quorum.all_non_upgrade) + @matrix(producer_version=[str(LATEST_0_9)], consumer_version=[str(DEV_BRANCH)], compression_types=[["none"]], timestamp_type=[None], metadata_quorum=quorum.all_non_upgrade) + @matrix(producer_version=[str(LATEST_0_9)], consumer_version=[str(DEV_BRANCH)], compression_types=[["snappy"]], timestamp_type=[None], metadata_quorum=quorum.all_non_upgrade) + @matrix(producer_version=[str(LATEST_0_9)], consumer_version=[str(LATEST_0_9)], compression_types=[["snappy"]], timestamp_type=[str("LogAppendTime")], metadata_quorum=quorum.all_non_upgrade) + @parametrize(producer_version=str(LATEST_0_8_2), consumer_version=str(LATEST_0_8_2), compression_types=["none"], new_consumer=False, timestamp_type=None) + 
def test_compatibility(self, producer_version, consumer_version, compression_types, new_consumer=True, timestamp_type=None, metadata_quorum=quorum.zk): + if not new_consumer and metadata_quorum != quorum.zk: + raise Exception("ZooKeeper-based consumers are not supported when using a KRaft metadata quorum") + self.kafka = KafkaService(self.test_context, num_nodes=3, zk=self.zk, version=DEV_BRANCH, topics={self.topic: { + "partitions": 3, + "replication-factor": 3, + 'configs': {"min.insync.replicas": 2}}}, + controller_num_nodes_override=1) + for node in self.kafka.nodes: + if timestamp_type is not None: + node.config[config_property.MESSAGE_TIMESTAMP_TYPE] = timestamp_type + self.kafka.start() + + self.producer = VerifiableProducer(self.test_context, self.num_producers, self.kafka, + self.topic, throughput=self.producer_throughput, + message_validator=is_int, + compression_types=compression_types, + version=KafkaVersion(producer_version)) + + self.consumer = ConsoleConsumer(self.test_context, self.num_consumers, self.kafka, + self.topic, consumer_timeout_ms=30000, new_consumer=new_consumer, + message_validator=is_int, version=KafkaVersion(consumer_version)) + + self.run_produce_consume_validate(lambda: wait_until( + lambda: self.producer.each_produced_at_least(self.messages_per_producer) == True, + timeout_sec=120, backoff_sec=1, + err_msg="Producer did not produce all messages in reasonable amount of time")) diff --git a/tests/kafkatest/tests/core/consume_bench_test.py b/tests/kafkatest/tests/core/consume_bench_test.py new file mode 100644 index 0000000..ce08d80 --- /dev/null +++ b/tests/kafkatest/tests/core/consume_bench_test.py @@ -0,0 +1,218 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
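run_produce_consume_validate above ultimately boils down to checking that everything the VerifiableProducer reported as acked was also seen by the consumer. A toy version of that at-least-once comparison (the integer payloads stand in for what the is_int validator extracts from the real services):

# Sets of integer payloads, as the message validator would extract them.
acked = set(range(1000))                      # messages the producer reported as acknowledged
consumed = set(range(1000)) | {1000, 1001}    # duplicates and extras are tolerated

missing = acked - consumed
assert not missing, "%d acked messages were never consumed: %s" % (
    len(missing), sorted(missing)[:20])
print("at-least-once delivery validated: %d acked, %d consumed" %
      (len(acked), len(consumed)))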
+ +import json +from ducktape.mark import matrix +from ducktape.mark.resource import cluster +from ducktape.tests.test import Test +from kafkatest.services.kafka import KafkaService, quorum +from kafkatest.services.trogdor.produce_bench_workload import ProduceBenchWorkloadService, ProduceBenchWorkloadSpec +from kafkatest.services.trogdor.consume_bench_workload import ConsumeBenchWorkloadService, ConsumeBenchWorkloadSpec +from kafkatest.services.trogdor.task_spec import TaskSpec +from kafkatest.services.trogdor.trogdor import TrogdorService +from kafkatest.services.zookeeper import ZookeeperService + + +class ConsumeBenchTest(Test): + def __init__(self, test_context): + """:type test_context: ducktape.tests.test.TestContext""" + super(ConsumeBenchTest, self).__init__(test_context) + self.zk = ZookeeperService(test_context, num_nodes=3) if quorum.for_test(test_context) == quorum.zk else None + self.kafka = KafkaService(test_context, num_nodes=3, zk=self.zk) + self.producer_workload_service = ProduceBenchWorkloadService(test_context, self.kafka) + self.consumer_workload_service = ConsumeBenchWorkloadService(test_context, self.kafka) + self.consumer_workload_service_2 = ConsumeBenchWorkloadService(test_context, self.kafka) + self.active_topics = {"consume_bench_topic[0-5]": {"numPartitions": 5, "replicationFactor": 3}} + self.trogdor = TrogdorService(context=self.test_context, + client_services=[self.kafka, self.producer_workload_service, + self.consumer_workload_service, + self.consumer_workload_service_2]) + + def setUp(self): + self.trogdor.start() + if self.zk: + self.zk.start() + self.kafka.start() + + def teardown(self): + self.trogdor.stop() + self.kafka.stop() + if self.zk: + self.zk.stop() + + def produce_messages(self, topics, max_messages=10000): + produce_spec = ProduceBenchWorkloadSpec(0, TaskSpec.MAX_DURATION_MS, + self.producer_workload_service.producer_node, + self.producer_workload_service.bootstrap_servers, + target_messages_per_sec=1000, + max_messages=max_messages, + producer_conf={}, + admin_client_conf={}, + common_client_conf={}, + inactive_topics={}, + active_topics=topics) + produce_workload = self.trogdor.create_task("produce_workload", produce_spec) + produce_workload.wait_for_done(timeout_sec=180) + self.logger.debug("Produce workload finished") + + @cluster(num_nodes=10) + @matrix(topics=[["consume_bench_topic[0-5]"]], metadata_quorum=quorum.all_non_upgrade) # topic subscription + @matrix(topics=[["consume_bench_topic[0-5]:[0-4]"]], metadata_quorum=quorum.all_non_upgrade) # manual topic assignment + def test_consume_bench(self, topics, metadata_quorum=quorum.zk): + """ + Runs a ConsumeBench workload to consume messages + """ + self.produce_messages(self.active_topics) + consume_spec = ConsumeBenchWorkloadSpec(0, TaskSpec.MAX_DURATION_MS, + self.consumer_workload_service.consumer_node, + self.consumer_workload_service.bootstrap_servers, + target_messages_per_sec=1000, + max_messages=10000, + consumer_conf={}, + admin_client_conf={}, + common_client_conf={}, + active_topics=topics) + consume_workload = self.trogdor.create_task("consume_workload", consume_spec) + consume_workload.wait_for_done(timeout_sec=360) + self.logger.debug("Consume workload finished") + tasks = self.trogdor.tasks() + self.logger.info("TASKS: %s\n" % json.dumps(tasks, sort_keys=True, indent=2)) + + @cluster(num_nodes=10) + @matrix(metadata_quorum=quorum.all_non_upgrade) + def test_single_partition(self, metadata_quorum=quorum.zk): + """ + Run a ConsumeBench against a single partition + """ + 
active_topics = {"consume_bench_topic": {"numPartitions": 2, "replicationFactor": 3}} + self.produce_messages(active_topics, 5000) + consume_spec = ConsumeBenchWorkloadSpec(0, TaskSpec.MAX_DURATION_MS, + self.consumer_workload_service.consumer_node, + self.consumer_workload_service.bootstrap_servers, + target_messages_per_sec=1000, + max_messages=2500, + consumer_conf={}, + admin_client_conf={}, + common_client_conf={}, + active_topics=["consume_bench_topic:1"]) + consume_workload = self.trogdor.create_task("consume_workload", consume_spec) + consume_workload.wait_for_done(timeout_sec=180) + self.logger.debug("Consume workload finished") + tasks = self.trogdor.tasks() + self.logger.info("TASKS: %s\n" % json.dumps(tasks, sort_keys=True, indent=2)) + + @cluster(num_nodes=10) + @matrix(metadata_quorum=quorum.all_non_upgrade) + def test_multiple_consumers_random_group_topics(self, metadata_quorum=quorum.zk): + """ + Runs multiple consumers group to read messages from topics. + Since a consumerGroup isn't specified, each consumer should read from all topics independently + """ + self.produce_messages(self.active_topics, max_messages=5000) + consume_spec = ConsumeBenchWorkloadSpec(0, TaskSpec.MAX_DURATION_MS, + self.consumer_workload_service.consumer_node, + self.consumer_workload_service.bootstrap_servers, + target_messages_per_sec=1000, + max_messages=5000, # all should read exactly 5k messages + consumer_conf={}, + admin_client_conf={}, + common_client_conf={}, + threads_per_worker=5, + active_topics=["consume_bench_topic[0-5]"]) + consume_workload = self.trogdor.create_task("consume_workload", consume_spec) + consume_workload.wait_for_done(timeout_sec=360) + self.logger.debug("Consume workload finished") + tasks = self.trogdor.tasks() + self.logger.info("TASKS: %s\n" % json.dumps(tasks, sort_keys=True, indent=2)) + + @cluster(num_nodes=10) + @matrix(metadata_quorum=quorum.all_non_upgrade) + def test_two_consumers_specified_group_topics(self, metadata_quorum=quorum.zk): + """ + Runs two consumers in the same consumer group to read messages from topics. + Since a consumerGroup is specified, each consumer should dynamically get assigned a partition from group + """ + self.produce_messages(self.active_topics) + consume_spec = ConsumeBenchWorkloadSpec(0, TaskSpec.MAX_DURATION_MS, + self.consumer_workload_service.consumer_node, + self.consumer_workload_service.bootstrap_servers, + target_messages_per_sec=1000, + max_messages=2000, # both should read at least 2k messages + consumer_conf={}, + admin_client_conf={}, + common_client_conf={}, + threads_per_worker=2, + consumer_group="testGroup", + active_topics=["consume_bench_topic[0-5]"]) + consume_workload = self.trogdor.create_task("consume_workload", consume_spec) + consume_workload.wait_for_done(timeout_sec=360) + self.logger.debug("Consume workload finished") + tasks = self.trogdor.tasks() + self.logger.info("TASKS: %s\n" % json.dumps(tasks, sort_keys=True, indent=2)) + + @cluster(num_nodes=10) + @matrix(metadata_quorum=quorum.all_non_upgrade) + def test_multiple_consumers_random_group_partitions(self, metadata_quorum=quorum.zk): + """ + Runs multiple consumers in to read messages from specific partitions. 
+ Since a consumerGroup isn't specified, each consumer will get assigned a random group + and consume from all partitions + """ + self.produce_messages(self.active_topics, max_messages=20000) + consume_spec = ConsumeBenchWorkloadSpec(0, TaskSpec.MAX_DURATION_MS, + self.consumer_workload_service.consumer_node, + self.consumer_workload_service.bootstrap_servers, + target_messages_per_sec=1000, + max_messages=2000, + consumer_conf={}, + admin_client_conf={}, + common_client_conf={}, + threads_per_worker=4, + active_topics=["consume_bench_topic1:[0-4]"]) + consume_workload = self.trogdor.create_task("consume_workload", consume_spec) + consume_workload.wait_for_done(timeout_sec=360) + self.logger.debug("Consume workload finished") + tasks = self.trogdor.tasks() + self.logger.info("TASKS: %s\n" % json.dumps(tasks, sort_keys=True, indent=2)) + + @cluster(num_nodes=10) + @matrix(metadata_quorum=quorum.all_non_upgrade) + def test_multiple_consumers_specified_group_partitions_should_raise(self, metadata_quorum=quorum.zk): + """ + Runs multiple consumers in the same group to read messages from specific partitions. + It is an invalid configuration to provide a consumer group and specific partitions. + """ + expected_error_msg = 'explicit partition assignment' + self.produce_messages(self.active_topics, max_messages=20000) + consume_spec = ConsumeBenchWorkloadSpec(0, TaskSpec.MAX_DURATION_MS, + self.consumer_workload_service.consumer_node, + self.consumer_workload_service.bootstrap_servers, + target_messages_per_sec=1000, + max_messages=2000, + consumer_conf={}, + admin_client_conf={}, + common_client_conf={}, + threads_per_worker=4, + consumer_group="fail_group", + active_topics=["consume_bench_topic1:[0-4]"]) + consume_workload = self.trogdor.create_task("consume_workload", consume_spec) + try: + consume_workload.wait_for_done(timeout_sec=360) + raise Exception("Should have raised an exception due to an invalid configuration") + except RuntimeError as e: + if expected_error_msg not in str(e): + raise RuntimeError("Unexpected Exception - " + str(e)) + self.logger.info(e) + diff --git a/tests/kafkatest/tests/core/consumer_group_command_test.py b/tests/kafkatest/tests/core/consumer_group_command_test.py new file mode 100644 index 0000000..f81eec8 --- /dev/null +++ b/tests/kafkatest/tests/core/consumer_group_command_test.py @@ -0,0 +1,108 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
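The consume bench tests above distinguish plain group subscription ("consume_bench_topic[0-5]") from manual partition assignment ("consume_bench_topic[0-5]:[0-4]" or "consume_bench_topic:1") purely through the active_topics string syntax. The sketch below expands such patterns into concrete topics and partitions; the expansion rules are inferred from the test parameters rather than taken from the Trogdor source:

import re

def expand_topics(token):
    # "name[0-5]" -> name0..name5; plain names pass through unchanged.
    m = re.fullmatch(r"(.+)\[(\d+)-(\d+)\]", token)
    if not m:
        return [token]
    prefix, lo, hi = m.group(1), int(m.group(2)), int(m.group(3))
    return ["%s%d" % (prefix, i) for i in range(lo, hi + 1)]

def expand_partitions(token):
    # "[0-4]" -> [0, 1, 2, 3, 4]; "1" -> [1]
    m = re.fullmatch(r"\[(\d+)-(\d+)\]", token)
    if m:
        return list(range(int(m.group(1)), int(m.group(2)) + 1))
    return [int(token)]

def expand_active_topics(spec):
    topic_part, _, partition_part = spec.partition(":")
    topics = expand_topics(topic_part)
    if not partition_part:
        return {"subscribe": topics}                      # group subscription
    return {"assign": [(t, p) for t in topics             # manual assignment
                       for p in expand_partitions(partition_part)]}

print(expand_active_topics("consume_bench_topic[0-5]"))
print(expand_active_topics("consume_bench_topic:1"))
print(expand_active_topics("consume_bench_topic1:[0-4]"))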
+ + +from ducktape.utils.util import wait_until +from ducktape.tests.test import Test +from ducktape.mark import matrix +from ducktape.mark.resource import cluster + +from kafkatest.services.zookeeper import ZookeeperService +from kafkatest.services.kafka import KafkaService, quorum +from kafkatest.services.console_consumer import ConsoleConsumer +from kafkatest.services.security.security_config import SecurityConfig + +import os +import re + +TOPIC = "topic-consumer-group-command" + + +class ConsumerGroupCommandTest(Test): + """ + Tests ConsumerGroupCommand + """ + # Root directory for persistent output + PERSISTENT_ROOT = "/mnt/consumer_group_command" + COMMAND_CONFIG_FILE = os.path.join(PERSISTENT_ROOT, "command.properties") + + def __init__(self, test_context): + super(ConsumerGroupCommandTest, self).__init__(test_context) + self.num_zk = 1 + self.num_brokers = 1 + self.topics = { + TOPIC: {'partitions': 1, 'replication-factor': 1} + } + self.zk = ZookeeperService(test_context, self.num_zk) if quorum.for_test(test_context) == quorum.zk else None + + def setUp(self): + if self.zk: + self.zk.start() + + def start_kafka(self, security_protocol, interbroker_security_protocol): + self.kafka = KafkaService( + self.test_context, self.num_brokers, + self.zk, security_protocol=security_protocol, + interbroker_security_protocol=interbroker_security_protocol, topics=self.topics, + controller_num_nodes_override=self.num_zk) + self.kafka.start() + + def start_consumer(self): + self.consumer = ConsoleConsumer(self.test_context, num_nodes=self.num_brokers, kafka=self.kafka, topic=TOPIC, + consumer_timeout_ms=None) + self.consumer.start() + + def setup_and_verify(self, security_protocol, group=None): + self.start_kafka(security_protocol, security_protocol) + self.start_consumer() + consumer_node = self.consumer.nodes[0] + wait_until(lambda: self.consumer.alive(consumer_node), + timeout_sec=20, backoff_sec=.2, err_msg="Consumer was too slow to start") + kafka_node = self.kafka.nodes[0] + if security_protocol is not SecurityConfig.PLAINTEXT: + prop_file = str(self.kafka.security_config.client_config()) + self.logger.debug(prop_file) + kafka_node.account.ssh("mkdir -p %s" % self.PERSISTENT_ROOT, allow_fail=False) + kafka_node.account.create_file(self.COMMAND_CONFIG_FILE, prop_file) + + # Verify ConsumerGroupCommand lists expected consumer groups + command_config_file = self.COMMAND_CONFIG_FILE + + if group: + wait_until(lambda: re.search("topic-consumer-group-command",self.kafka.describe_consumer_group(group=group, node=kafka_node, command_config=command_config_file)), timeout_sec=10, + err_msg="Timed out waiting to list expected consumer groups.") + else: + wait_until(lambda: "test-consumer-group" in self.kafka.list_consumer_groups(node=kafka_node, command_config=command_config_file), timeout_sec=10, + err_msg="Timed out waiting to list expected consumer groups.") + + self.consumer.stop() + + @cluster(num_nodes=3) + @matrix(security_protocol=['PLAINTEXT', 'SSL'], metadata_quorum=quorum.all_non_upgrade) + def test_list_consumer_groups(self, security_protocol='PLAINTEXT', metadata_quorum=quorum.zk): + """ + Tests if ConsumerGroupCommand is listing correct consumer groups + :return: None + """ + self.setup_and_verify(security_protocol) + + @cluster(num_nodes=3) + @matrix(security_protocol=['PLAINTEXT', 'SSL'], metadata_quorum=quorum.all_non_upgrade) + def test_describe_consumer_group(self, security_protocol='PLAINTEXT', metadata_quorum=quorum.zk): + """ + Tests if ConsumerGroupCommand is describing a 
consumer group correctly + :return: None + """ + self.setup_and_verify(security_protocol, group="test-consumer-group") diff --git a/tests/kafkatest/tests/core/delegation_token_test.py b/tests/kafkatest/tests/core/delegation_token_test.py new file mode 100644 index 0000000..7b508bc --- /dev/null +++ b/tests/kafkatest/tests/core/delegation_token_test.py @@ -0,0 +1,132 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ducktape.mark.resource import cluster +from ducktape.tests.test import Test +from ducktape.utils.util import wait_until +from kafkatest.services.kafka import config_property, KafkaService +from kafkatest.services.zookeeper import ZookeeperService +from kafkatest.services.console_consumer import ConsoleConsumer +from kafkatest.services.delegation_tokens import DelegationTokens +from kafkatest.services.verifiable_producer import VerifiableProducer + +from datetime import datetime +import time + +""" +Basic tests to validate delegation token support +""" +class DelegationTokenTest(Test): + def __init__(self, test_context): + super(DelegationTokenTest, self).__init__(test_context) + + self.test_context = test_context + self.topic = "topic" + self.zk = ZookeeperService(test_context, num_nodes=1) + self.kafka = KafkaService(self.test_context, num_nodes=1, zk=self.zk, zk_chroot="/kafka", + topics={self.topic: {"partitions": 1, "replication-factor": 1}}, + server_prop_overrides=[ + [config_property.DELEGATION_TOKEN_MAX_LIFETIME_MS, "604800000"], + [config_property.DELEGATION_TOKEN_EXPIRY_TIME_MS, "86400000"], + [config_property.DELEGATION_TOKEN_SECRET_KEY, "test12345"], + [config_property.SASL_ENABLED_MECHANISMS, "GSSAPI,SCRAM-SHA-256"] + ]) + self.jaas_deleg_conf_path = "/tmp/jaas_deleg.conf" + self.jaas_deleg_conf = "" + self.client_properties_content = """ +security.protocol=SASL_PLAINTEXT +sasl.mechanism=SCRAM-SHA-256 +sasl.kerberos.service.name=kafka +client.id=console-consumer +""" + self.client_kafka_opts=' -Djava.security.auth.login.config=' + self.jaas_deleg_conf_path + + self.producer = VerifiableProducer(self.test_context, num_nodes=1, kafka=self.kafka, topic=self.topic, max_messages=1, + throughput=1, kafka_opts_override=self.client_kafka_opts, + client_prop_file_override=self.client_properties_content) + + self.consumer = ConsoleConsumer(self.test_context, num_nodes=1, kafka=self.kafka, topic=self.topic, + kafka_opts_override=self.client_kafka_opts, + client_prop_file_override=self.client_properties_content) + + self.kafka.security_protocol = 'SASL_PLAINTEXT' + self.kafka.client_sasl_mechanism = 'GSSAPI,SCRAM-SHA-256' + self.kafka.interbroker_sasl_mechanism = 'GSSAPI' + + + def setUp(self): + self.zk.start() + + def tearDown(self): + self.producer.nodes[0].account.remove(self.jaas_deleg_conf_path) + 
self.consumer.nodes[0].account.remove(self.jaas_deleg_conf_path) + + def generate_delegation_token(self): + self.logger.debug("Request delegation token") + self.delegation_tokens.generate_delegation_token() + self.jaas_deleg_conf = self.delegation_tokens.create_jaas_conf_with_delegation_token() + + def expire_delegation_token(self): + self.kafka.client_sasl_mechanism = 'GSSAPI,SCRAM-SHA-256' + token_hmac = self.delegation_tokens.token_hmac() + self.delegation_tokens.expire_delegation_token(token_hmac) + + + def produce_with_delegation_token(self): + self.producer.acked_values = [] + self.producer.nodes[0].account.create_file(self.jaas_deleg_conf_path, self.jaas_deleg_conf) + self.logger.debug(self.jaas_deleg_conf) + self.producer.start() + + def consume_with_delegation_token(self): + self.logger.debug("Consume messages with delegation token") + + self.consumer.nodes[0].account.create_file(self.jaas_deleg_conf_path, self.jaas_deleg_conf) + self.logger.debug(self.jaas_deleg_conf) + self.consumer.consumer_timeout_ms = 5000 + + self.consumer.start() + self.consumer.wait() + + def get_datetime_ms(self, input_date): + return int(time.mktime(datetime.strptime(input_date,"%Y-%m-%dT%H:%M").timetuple()) * 1000) + + def renew_delegation_token(self): + dt = self.delegation_tokens.parse_delegation_token_out() + orig_expiry_date_ms = self.get_datetime_ms(dt["expirydate"]) + new_expirydate_ms = orig_expiry_date_ms + 1000 + + self.delegation_tokens.renew_delegation_token(dt["hmac"], new_expirydate_ms) + + @cluster(num_nodes=5) + def test_delegation_token_lifecycle(self): + self.kafka.start() + self.delegation_tokens = DelegationTokens(self.kafka, self.test_context) + + self.generate_delegation_token() + self.renew_delegation_token() + self.produce_with_delegation_token() + wait_until(lambda: self.producer.num_acked > 0, timeout_sec=30, + err_msg="Expected producer to still be producing.") + assert 1 == self.producer.num_acked, "number of acked messages: %d" % self.producer.num_acked + + self.consume_with_delegation_token() + num_consumed = len(self.consumer.messages_consumed[1]) + assert 1 == num_consumed, "number of consumed messages: %d" % num_consumed + + self.expire_delegation_token() + + self.produce_with_delegation_token() + assert 0 == self.producer.num_acked, "number of acked messages: %d" % self.producer.num_acked \ No newline at end of file diff --git a/tests/kafkatest/tests/core/downgrade_test.py b/tests/kafkatest/tests/core/downgrade_test.py new file mode 100644 index 0000000..2ec453a --- /dev/null +++ b/tests/kafkatest/tests/core/downgrade_test.py @@ -0,0 +1,147 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
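The delegation token test above drives its producer and consumer through a JAAS file referenced via -Djava.security.auth.login.config. As a rough sketch (an approximation of what the DelegationTokens helper emits, not its literal output), that file is a SCRAM login entry whose username and password are the token ID and HMAC, with tokenauth=true so brokers authenticate it as a delegation token:

def render_delegation_token_jaas(token_id, token_hmac):
    # Approximate shape of /tmp/jaas_deleg.conf written by the test; the real
    # content comes from DelegationTokens.create_jaas_conf_with_delegation_token().
    return (
        'KafkaClient {\n'
        '    org.apache.kafka.common.security.scram.ScramLoginModule required\n'
        '    username="%s"\n'
        '    password="%s"\n'
        '    tokenauth=true;\n'
        '};\n'
    ) % (token_id, token_hmac)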
+ +from ducktape.mark import parametrize, matrix +from ducktape.mark.resource import cluster +from ducktape.utils.util import wait_until + +from kafkatest.services.kafka import config_property +from kafkatest.tests.end_to_end import EndToEndTest +from kafkatest.version import LATEST_1_1, LATEST_2_0, LATEST_2_1, LATEST_2_2, LATEST_2_3, LATEST_2_4, LATEST_2_5, LATEST_2_6, LATEST_2_7, LATEST_2_8, DEV_BRANCH, KafkaVersion + +class TestDowngrade(EndToEndTest): + PARTITIONS = 3 + REPLICATION_FACTOR = 3 + + TOPIC_CONFIG = { + "partitions": PARTITIONS, + "replication-factor": REPLICATION_FACTOR, + "configs": {"min.insync.replicas": 2} + } + + def __init__(self, test_context): + super(TestDowngrade, self).__init__(test_context=test_context, topic_config=self.TOPIC_CONFIG) + + def upgrade_from(self, kafka_version): + for node in self.kafka.nodes: + self.kafka.stop_node(node) + node.version = DEV_BRANCH + node.config[config_property.INTER_BROKER_PROTOCOL_VERSION] = str(kafka_version) + node.config[config_property.MESSAGE_FORMAT_VERSION] = str(kafka_version) + self.kafka.start_node(node) + self.wait_until_rejoin() + + def downgrade_to(self, kafka_version): + for node in self.kafka.nodes: + self.kafka.stop_node(node) + node.version = kafka_version + del node.config[config_property.INTER_BROKER_PROTOCOL_VERSION] + del node.config[config_property.MESSAGE_FORMAT_VERSION] + self.kafka.start_node(node) + self.wait_until_rejoin() + + def setup_services(self, kafka_version, compression_types, security_protocol, static_membership): + self.create_zookeeper_if_necessary() + self.zk.start() + + self.create_kafka(num_nodes=3, + security_protocol=security_protocol, + interbroker_security_protocol=security_protocol, + version=kafka_version) + self.kafka.start() + + self.create_producer(log_level="DEBUG", + compression_types=compression_types, + version=kafka_version) + self.producer.start() + + self.create_consumer(log_level="DEBUG", + version=kafka_version, + static_membership=static_membership) + + self.consumer.start() + + def wait_until_rejoin(self): + for partition in range(0, self.PARTITIONS): + wait_until(lambda: len(self.kafka.isr_idx_list(self.topic, partition)) == self.REPLICATION_FACTOR, + timeout_sec=60, backoff_sec=1, err_msg="Replicas did not rejoin the ISR in a reasonable amount of time") + + @cluster(num_nodes=7) + @parametrize(version=str(LATEST_2_8), compression_types=["snappy"]) + @parametrize(version=str(LATEST_2_8), compression_types=["zstd"], security_protocol="SASL_SSL") + @matrix(version=[str(LATEST_2_8)], compression_types=[["none"]], static_membership=[False, True]) + @parametrize(version=str(LATEST_2_7), compression_types=["lz4"]) + @parametrize(version=str(LATEST_2_7), compression_types=["zstd"], security_protocol="SASL_SSL") + @matrix(version=[str(LATEST_2_7)], compression_types=[["none"]], static_membership=[False, True]) + @parametrize(version=str(LATEST_2_6), compression_types=["lz4"]) + @parametrize(version=str(LATEST_2_6), compression_types=["zstd"], security_protocol="SASL_SSL") + @matrix(version=[str(LATEST_2_6)], compression_types=[["none"]], static_membership=[False, True]) + @matrix(version=[str(LATEST_2_5)], compression_types=[["none"]], static_membership=[False, True]) + @parametrize(version=str(LATEST_2_5), compression_types=["zstd"], security_protocol="SASL_SSL") + # static membership was introduced with a buggy verifiable console consumer which + # required static membership to be enabled + @parametrize(version=str(LATEST_2_4), compression_types=["none"], 
static_membership=True) + @parametrize(version=str(LATEST_2_4), compression_types=["zstd"], security_protocol="SASL_SSL", static_membership=True) + @parametrize(version=str(LATEST_2_3), compression_types=["none"]) + @parametrize(version=str(LATEST_2_3), compression_types=["zstd"], security_protocol="SASL_SSL") + @parametrize(version=str(LATEST_2_2), compression_types=["none"]) + @parametrize(version=str(LATEST_2_2), compression_types=["zstd"], security_protocol="SASL_SSL") + @parametrize(version=str(LATEST_2_1), compression_types=["none"]) + @parametrize(version=str(LATEST_2_1), compression_types=["lz4"], security_protocol="SASL_SSL") + @parametrize(version=str(LATEST_2_0), compression_types=["none"]) + @parametrize(version=str(LATEST_2_0), compression_types=["snappy"], security_protocol="SASL_SSL") + @parametrize(version=str(LATEST_1_1), compression_types=["none"]) + @parametrize(version=str(LATEST_1_1), compression_types=["lz4"], security_protocol="SASL_SSL") + def test_upgrade_and_downgrade(self, version, compression_types, security_protocol="PLAINTEXT", + static_membership=False): + """Test upgrade and downgrade of Kafka cluster from old versions to the current version + + `version` is the Kafka version to upgrade from and downgrade back to + + Downgrades are supported to any version which is at or above the current + `inter.broker.protocol.version` (IBP). For example, if a user upgrades from 1.1 to 2.3, + but they leave the IBP set to 1.1, then downgrading to any version at 1.1 or higher is + supported. + + This test case verifies that producers and consumers continue working during + the course of an upgrade and downgrade. + + - Start 3 node broker cluster on version 'kafka_version' + - Start producer and consumer in the background + - Roll the cluster to upgrade to the current version with IBP set to 'kafka_version' + - Roll the cluster to downgrade back to 'kafka_version' + - Finally, validate that every message acked by the producer was consumed by the consumer + """ + kafka_version = KafkaVersion(version) + + self.setup_services(kafka_version, compression_types, security_protocol, static_membership) + self.await_startup() + + start_topic_id = self.kafka.topic_id(self.topic) + + self.logger.info("First pass bounce - rolling upgrade") + self.upgrade_from(kafka_version) + self.run_validation() + + upgrade_topic_id = self.kafka.topic_id(self.topic) + assert start_topic_id == upgrade_topic_id + + self.logger.info("Second pass bounce - rolling downgrade") + self.downgrade_to(kafka_version) + self.run_validation() + + downgrade_topic_id = self.kafka.topic_id(self.topic) + assert upgrade_topic_id == downgrade_topic_id + assert self.kafka.check_protocol_errors(self) diff --git a/tests/kafkatest/tests/core/fetch_from_follower_test.py b/tests/kafkatest/tests/core/fetch_from_follower_test.py new file mode 100644 index 0000000..f720de1 --- /dev/null +++ b/tests/kafkatest/tests/core/fetch_from_follower_test.py @@ -0,0 +1,134 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time +from collections import defaultdict + +from ducktape.mark import matrix +from ducktape.mark.resource import cluster + +from kafkatest.services.console_consumer import ConsoleConsumer +from kafkatest.services.kafka import KafkaService, quorum +from kafkatest.services.monitor.jmx import JmxTool +from kafkatest.services.verifiable_producer import VerifiableProducer +from kafkatest.services.zookeeper import ZookeeperService +from kafkatest.tests.produce_consume_validate import ProduceConsumeValidateTest +from kafkatest.utils import is_int + + +class FetchFromFollowerTest(ProduceConsumeValidateTest): + + RACK_AWARE_REPLICA_SELECTOR = "org.apache.kafka.common.replica.RackAwareReplicaSelector" + METADATA_MAX_AGE_MS = 3000 + + def __init__(self, test_context): + super(FetchFromFollowerTest, self).__init__(test_context=test_context) + self.jmx_tool = JmxTool(test_context, jmx_poll_ms=100) + self.topic = "test_topic" + self.zk = ZookeeperService(test_context, num_nodes=1) if quorum.for_test(test_context) == quorum.zk else None + self.kafka = KafkaService(test_context, + num_nodes=3, + zk=self.zk, + topics={ + self.topic: { + "partitions": 1, + "replication-factor": 3, + "configs": {"min.insync.replicas": 1}}, + }, + server_prop_overrides=[ + ["replica.selector.class", self.RACK_AWARE_REPLICA_SELECTOR] + ], + per_node_server_prop_overrides={ + 1: [("broker.rack", "rack-a")], + 2: [("broker.rack", "rack-b")], + 3: [("broker.rack", "rack-c")] + }, + controller_num_nodes_override=1) + + self.producer_throughput = 1000 + self.num_producers = 1 + self.num_consumers = 1 + + def min_cluster_size(self): + return super(FetchFromFollowerTest, self).min_cluster_size() + self.num_producers * 2 + self.num_consumers * 2 + + def setUp(self): + if self.zk: + self.zk.start() + self.kafka.start() + + @cluster(num_nodes=9) + @matrix(metadata_quorum=quorum.all_non_upgrade) + def test_consumer_preferred_read_replica(self, metadata_quorum=quorum.zk): + """ + This test starts up brokers with "broker.rack" and "replica.selector.class" configurations set. The replica + selector is set to the rack-aware implementation. One of the brokers has a different rack than the other two. + We then use a console consumer with the "client.rack" set to the same value as the differing broker. After + producing some records, we verify that the client has been informed of the preferred replica and that all the + records are properly consumed. 
+ """ + + # Find the leader, configure consumer to be on a different rack + leader_node = self.kafka.leader(self.topic, 0) + leader_idx = self.kafka.idx(leader_node) + non_leader_idx = 2 if leader_idx != 2 else 1 + non_leader_rack = "rack-b" if leader_idx != 2 else "rack-a" + + self.logger.debug("Leader %d %s" % (leader_idx, leader_node)) + self.logger.debug("Non-Leader %d %s" % (non_leader_idx, non_leader_rack)) + + self.producer = VerifiableProducer(self.test_context, self.num_producers, self.kafka, self.topic, + throughput=self.producer_throughput) + self.consumer = ConsoleConsumer(self.test_context, self.num_consumers, self.kafka, self.topic, + client_id="console-consumer", group_id="test-consumer-group-1", + consumer_timeout_ms=60000, message_validator=is_int, + consumer_properties={"client.rack": non_leader_rack, "metadata.max.age.ms": self.METADATA_MAX_AGE_MS}) + + # Start up and let some data get produced + self.start_producer_and_consumer() + time.sleep(self.METADATA_MAX_AGE_MS * 2. / 1000) + + consumer_node = self.consumer.nodes[0] + consumer_idx = self.consumer.idx(consumer_node) + read_replica_attribute = "preferred-read-replica" + read_replica_mbean = "kafka.consumer:type=consumer-fetch-manager-metrics,client-id=%s,topic=%s,partition=%d" % \ + ("console-consumer", self.topic, 0) + self.jmx_tool.jmx_object_names = [read_replica_mbean] + self.jmx_tool.jmx_attributes = [read_replica_attribute] + self.jmx_tool.start_jmx_tool(consumer_idx, consumer_node) + + # Wait for at least one interval of "metadata.max.age.ms" + time.sleep(self.METADATA_MAX_AGE_MS * 2. / 1000) + + # Read the JMX output + self.jmx_tool.read_jmx_output(consumer_idx, consumer_node) + + all_captured_preferred_read_replicas = defaultdict(int) + self.logger.debug(self.jmx_tool.jmx_stats) + + for ts, data in self.jmx_tool.jmx_stats[0].items(): + for k, v in data.items(): + if k.endswith(read_replica_attribute): + all_captured_preferred_read_replicas[int(v)] += 1 + + self.logger.debug("Saw the following preferred read replicas %s", + dict(all_captured_preferred_read_replicas.items())) + + assert all_captured_preferred_read_replicas[non_leader_idx] > 0, \ + "Expected to see broker %d (%s) as a preferred replica" % (non_leader_idx, non_leader_rack) + + # Validate consumed messages + self.stop_producer_and_consumer() + self.validate() diff --git a/tests/kafkatest/tests/core/get_offset_shell_test.py b/tests/kafkatest/tests/core/get_offset_shell_test.py new file mode 100644 index 0000000..b24c5ac --- /dev/null +++ b/tests/kafkatest/tests/core/get_offset_shell_test.py @@ -0,0 +1,260 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +from ducktape.utils.util import wait_until +from ducktape.tests.test import Test +from ducktape.mark import matrix +from ducktape.mark.resource import cluster + +from kafkatest.services.verifiable_producer import VerifiableProducer +from kafkatest.services.zookeeper import ZookeeperService +from kafkatest.services.kafka import KafkaService, quorum +from kafkatest.services.console_consumer import ConsoleConsumer + +MAX_MESSAGES = 100 +NUM_PARTITIONS = 1 +REPLICATION_FACTOR = 1 + +TOPIC_TEST_NAME = "topic-get-offset-shell-topic-name" + +TOPIC_TEST_PATTERN_PREFIX = "topic-get-offset-shell-topic-pattern" +TOPIC_TEST_PATTERN_PATTERN = TOPIC_TEST_PATTERN_PREFIX + ".*" +TOPIC_TEST_PATTERN1 = TOPIC_TEST_PATTERN_PREFIX + "1" +TOPIC_TEST_PATTERN2 = TOPIC_TEST_PATTERN_PREFIX + "2" + +TOPIC_TEST_PARTITIONS = "topic-get-offset-shell-partitions" + +TOPIC_TEST_INTERNAL_FILTER = "topic-get-offset-shell-consumer_offsets" + +TOPIC_TEST_TOPIC_PARTITIONS_PREFIX = "topic-get-offset-shell-topic-partitions" +TOPIC_TEST_TOPIC_PARTITIONS_PATTERN = TOPIC_TEST_TOPIC_PARTITIONS_PREFIX + ".*" +TOPIC_TEST_TOPIC_PARTITIONS1 = TOPIC_TEST_TOPIC_PARTITIONS_PREFIX + "1" +TOPIC_TEST_TOPIC_PARTITIONS2 = TOPIC_TEST_TOPIC_PARTITIONS_PREFIX + "2" + + +class GetOffsetShellTest(Test): + """ + Tests GetOffsetShell tool + """ + def __init__(self, test_context): + super(GetOffsetShellTest, self).__init__(test_context) + self.num_zk = 1 + self.num_brokers = 1 + self.messages_received_count = 0 + self.topics = { + TOPIC_TEST_NAME: {'partitions': NUM_PARTITIONS, 'replication-factor': REPLICATION_FACTOR}, + TOPIC_TEST_PATTERN1: {'partitions': 1, 'replication-factor': REPLICATION_FACTOR}, + TOPIC_TEST_PATTERN2: {'partitions': 1, 'replication-factor': REPLICATION_FACTOR}, + TOPIC_TEST_PARTITIONS: {'partitions': 2, 'replication-factor': REPLICATION_FACTOR}, + TOPIC_TEST_INTERNAL_FILTER: {'partitions': 1, 'replication-factor': REPLICATION_FACTOR}, + TOPIC_TEST_TOPIC_PARTITIONS1: {'partitions': 2, 'replication-factor': REPLICATION_FACTOR}, + TOPIC_TEST_TOPIC_PARTITIONS2: {'partitions': 2, 'replication-factor': REPLICATION_FACTOR} + } + + self.zk = ZookeeperService(test_context, self.num_zk) if quorum.for_test(test_context) == quorum.zk else None + + def setUp(self): + if self.zk: + self.zk.start() + + def start_kafka(self, security_protocol, interbroker_security_protocol): + self.kafka = KafkaService( + self.test_context, self.num_brokers, + self.zk, security_protocol=security_protocol, + interbroker_security_protocol=interbroker_security_protocol, topics=self.topics) + self.kafka.start() + + def start_producer(self, topic): + # This will produce to kafka cluster + self.producer = VerifiableProducer(self.test_context, num_nodes=1, kafka=self.kafka, topic=topic, + throughput=1000, max_messages=MAX_MESSAGES, repeating_keys=MAX_MESSAGES) + self.producer.start() + current_acked = self.producer.num_acked + wait_until(lambda: self.producer.num_acked >= current_acked + MAX_MESSAGES, timeout_sec=10, + err_msg="Timeout awaiting messages to be produced and acked") + + def start_consumer(self, topic): + self.consumer = ConsoleConsumer(self.test_context, num_nodes=self.num_brokers, kafka=self.kafka, topic=topic, + consumer_timeout_ms=1000) + self.consumer.start() + + def check_message_count_sum_equals(self, message_count, **kwargs): + sum = self.extract_message_count_sum(**kwargs) + return sum == message_count + + def extract_message_count_sum(self, **kwargs): + offsets = self.kafka.get_offset_shell(**kwargs).split("\n") + sum = 0 + for offset in 
offsets: + if len(offset) == 0: + continue + sum += int(offset.split(":")[-1]) + return sum + + @cluster(num_nodes=3) + @matrix(metadata_quorum=quorum.all_non_upgrade) + def test_get_offset_shell_topic_name(self, security_protocol='PLAINTEXT', metadata_quorum=quorum.zk): + """ + Tests if GetOffsetShell handles --topic argument with a simple name correctly + :return: None + """ + self.start_kafka(security_protocol, security_protocol) + self.start_producer(TOPIC_TEST_NAME) + + # Assert that offset is correctly indicated by GetOffsetShell tool + wait_until(lambda: self.check_message_count_sum_equals(MAX_MESSAGES, topic=TOPIC_TEST_NAME), + timeout_sec=10, err_msg="Timed out waiting to reach expected offset.") + + @cluster(num_nodes=4) + @matrix(metadata_quorum=quorum.all_non_upgrade) + def test_get_offset_shell_topic_pattern(self, security_protocol='PLAINTEXT', metadata_quorum=quorum.zk): + """ + Tests if GetOffsetShell handles --topic argument with a pattern correctly + :return: None + """ + self.start_kafka(security_protocol, security_protocol) + self.start_producer(TOPIC_TEST_PATTERN1) + self.start_producer(TOPIC_TEST_PATTERN2) + + # Assert that offset is correctly indicated by GetOffsetShell tool + wait_until(lambda: self.check_message_count_sum_equals(2*MAX_MESSAGES, topic=TOPIC_TEST_PATTERN_PATTERN), + timeout_sec=10, err_msg="Timed out waiting to reach expected offset.") + + @cluster(num_nodes=3) + @matrix(metadata_quorum=quorum.all_non_upgrade) + def test_get_offset_shell_partitions(self, security_protocol='PLAINTEXT', metadata_quorum=quorum.zk): + """ + Tests if GetOffsetShell handles --partitions argument correctly + :return: None + """ + self.start_kafka(security_protocol, security_protocol) + self.start_producer(TOPIC_TEST_PARTITIONS) + + def fetch_and_sum_partitions_separately(): + partition_count0 = self.extract_message_count_sum(topic=TOPIC_TEST_PARTITIONS, partitions="0") + partition_count1 = self.extract_message_count_sum(topic=TOPIC_TEST_PARTITIONS, partitions="1") + return partition_count0 + partition_count1 == MAX_MESSAGES + + # Assert that offset is correctly indicated when fetching partitions one by one + wait_until(fetch_and_sum_partitions_separately, timeout_sec=10, err_msg="Timed out waiting to reach expected offset.") + + # Assert that offset is correctly indicated when fetching partitions together + wait_until(lambda: self.check_message_count_sum_equals(MAX_MESSAGES, topic=TOPIC_TEST_PARTITIONS), + timeout_sec=10, err_msg="Timed out waiting to reach expected offset.") + + @cluster(num_nodes=4) + @matrix(metadata_quorum=quorum.all_non_upgrade) + def test_get_offset_shell_topic_partitions(self, security_protocol='PLAINTEXT', metadata_quorum=quorum.zk): + """ + Tests if GetOffsetShell handles --topic-partitions argument correctly + :return: None + """ + self.start_kafka(security_protocol, security_protocol) + self.start_producer(TOPIC_TEST_TOPIC_PARTITIONS1) + self.start_producer(TOPIC_TEST_TOPIC_PARTITIONS2) + + # Assert that a single topic pattern matches all 4 partitions + wait_until(lambda: self.check_message_count_sum_equals(2*MAX_MESSAGES, topic_partitions=TOPIC_TEST_TOPIC_PARTITIONS_PATTERN), + timeout_sec=10, err_msg="Timed out waiting to reach expected offset.") + + # Assert that a topic pattern with partition range matches all 4 partitions + wait_until(lambda: self.check_message_count_sum_equals(2*MAX_MESSAGES, topic_partitions=TOPIC_TEST_TOPIC_PARTITIONS_PATTERN + ":0-2"), + timeout_sec=10, err_msg="Timed out waiting to reach expected offset.") + + # Assert 
that 2 separate topic patterns match all 4 partitions + wait_until(lambda: self.check_message_count_sum_equals(2*MAX_MESSAGES, topic_partitions=TOPIC_TEST_TOPIC_PARTITIONS1 + "," + TOPIC_TEST_TOPIC_PARTITIONS2), + timeout_sec=10, err_msg="Timed out waiting to reach expected offset.") + + # Assert that 4 separate topic-partition patterns match all 4 partitions + wait_until(lambda: self.check_message_count_sum_equals(2*MAX_MESSAGES, + topic_partitions=TOPIC_TEST_TOPIC_PARTITIONS1 + ":0," + + TOPIC_TEST_TOPIC_PARTITIONS1 + ":1," + + TOPIC_TEST_TOPIC_PARTITIONS2 + ":0," + + TOPIC_TEST_TOPIC_PARTITIONS2 + ":1"), + timeout_sec=10, err_msg="Timed out waiting to reach expected offset.") + + # Assert that only partitions #0 are matched with topic pattern and fix partition number + filtered_partitions = self.kafka.get_offset_shell(topic_partitions=TOPIC_TEST_TOPIC_PARTITIONS_PATTERN + ":0") + assert 1 == filtered_partitions.count("%s:%s" % (TOPIC_TEST_TOPIC_PARTITIONS1, 0)) + assert 0 == filtered_partitions.count("%s:%s" % (TOPIC_TEST_TOPIC_PARTITIONS1, 1)) + assert 1 == filtered_partitions.count("%s:%s" % (TOPIC_TEST_TOPIC_PARTITIONS2, 0)) + assert 0 == filtered_partitions.count("%s:%s" % (TOPIC_TEST_TOPIC_PARTITIONS2, 1)) + + # Assert that only partitions #1 are matched with topic pattern and partition lower bound + filtered_partitions = self.kafka.get_offset_shell(topic_partitions=TOPIC_TEST_TOPIC_PARTITIONS_PATTERN + ":1-") + assert 1 == filtered_partitions.count("%s:%s" % (TOPIC_TEST_TOPIC_PARTITIONS1, 1)) + assert 0 == filtered_partitions.count("%s:%s" % (TOPIC_TEST_TOPIC_PARTITIONS1, 0)) + assert 1 == filtered_partitions.count("%s:%s" % (TOPIC_TEST_TOPIC_PARTITIONS2, 1)) + assert 0 == filtered_partitions.count("%s:%s" % (TOPIC_TEST_TOPIC_PARTITIONS2, 0)) + + # Assert that only partitions #0 are matched with topic pattern and partition upper bound + filtered_partitions = self.kafka.get_offset_shell(topic_partitions=TOPIC_TEST_TOPIC_PARTITIONS_PATTERN + ":-1") + assert 1 == filtered_partitions.count("%s:%s" % (TOPIC_TEST_TOPIC_PARTITIONS1, 0)) + assert 0 == filtered_partitions.count("%s:%s" % (TOPIC_TEST_TOPIC_PARTITIONS1, 1)) + assert 1 == filtered_partitions.count("%s:%s" % (TOPIC_TEST_TOPIC_PARTITIONS2, 0)) + assert 0 == filtered_partitions.count("%s:%s" % (TOPIC_TEST_TOPIC_PARTITIONS2, 1)) + + @cluster(num_nodes=4) + @matrix(metadata_quorum=quorum.all_non_upgrade) + def test_get_offset_shell_internal_filter(self, security_protocol='PLAINTEXT', metadata_quorum=quorum.zk): + """ + Tests if GetOffsetShell handles --exclude-internal-topics flag correctly + :return: None + """ + self.start_kafka(security_protocol, security_protocol) + self.start_producer(TOPIC_TEST_INTERNAL_FILTER) + + # Create consumer and poll messages to create consumer offset record + self.start_consumer(TOPIC_TEST_INTERNAL_FILTER) + node = self.consumer.nodes[0] + wait_until(lambda: self.consumer.alive(node), timeout_sec=20, backoff_sec=.2, err_msg="Consumer was too slow to start") + + # Assert that a single topic pattern matches all 4 partitions + wait_until(lambda: self.check_message_count_sum_equals(MAX_MESSAGES, topic_partitions=TOPIC_TEST_INTERNAL_FILTER), + timeout_sec=10, err_msg="Timed out waiting to reach expected offset.") + + # No filters + # Assert that without exclusion, we can find both the test topic and the __consumer_offsets internal topic + offset_output = self.kafka.get_offset_shell() + assert "__consumer_offsets" in offset_output + assert TOPIC_TEST_INTERNAL_FILTER in offset_output + + # Assert that with 
exclusion, we can find the test topic but not the __consumer_offsets internal topic + offset_output = self.kafka.get_offset_shell(exclude_internal_topics=True) + assert "__consumer_offsets" not in offset_output + assert TOPIC_TEST_INTERNAL_FILTER in offset_output + + # Topic filter + # Assert that without exclusion, we can find both the test topic and the __consumer_offsets internal topic + offset_output = self.kafka.get_offset_shell(topic=".*consumer_offsets") + assert "__consumer_offsets" in offset_output + assert TOPIC_TEST_INTERNAL_FILTER in offset_output + + # Assert that with exclusion, we can find the test topic but not the __consumer_offsets internal topic + offset_output = self.kafka.get_offset_shell(topic=".*consumer_offsets", exclude_internal_topics=True) + assert "__consumer_offsets" not in offset_output + assert TOPIC_TEST_INTERNAL_FILTER in offset_output + + # Topic-partition filter + # Assert that without exclusion, we can find both the test topic and the __consumer_offsets internal topic + offset_output = self.kafka.get_offset_shell(topic_partitions=".*consumer_offsets:0") + assert "__consumer_offsets" in offset_output + assert TOPIC_TEST_INTERNAL_FILTER in offset_output + + # Assert that with exclusion, we can find the test topic but not the __consumer_offsets internal topic + offset_output = self.kafka.get_offset_shell(topic_partitions=".*consumer_offsets:0", exclude_internal_topics=True) + assert "__consumer_offsets" not in offset_output + assert TOPIC_TEST_INTERNAL_FILTER in offset_output diff --git a/tests/kafkatest/tests/core/group_mode_transactions_test.py b/tests/kafkatest/tests/core/group_mode_transactions_test.py new file mode 100644 index 0000000..37a6da3 --- /dev/null +++ b/tests/kafkatest/tests/core/group_mode_transactions_test.py @@ -0,0 +1,331 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from kafkatest.services.zookeeper import ZookeeperService +from kafkatest.services.kafka import KafkaService, quorum +from kafkatest.services.console_consumer import ConsoleConsumer +from kafkatest.services.verifiable_producer import VerifiableProducer +from kafkatest.services.transactional_message_copier import TransactionalMessageCopier +from kafkatest.utils import is_int + +from ducktape.tests.test import Test +from ducktape.mark import matrix +from ducktape.mark.resource import cluster +from ducktape.utils.util import wait_until + +import time + +class GroupModeTransactionsTest(Test): + """Essentially testing the same functionality as TransactionsTest by transactionally copying data + from a source topic to a destination topic and killing the copy process as well as the broker + randomly through the process. 
The major difference is that we choose to work as a collaborated + group with same topic subscription instead of individual copiers. + + In the end we verify that the final output topic contains exactly one committed copy of + each message from the original producer. + """ + def __init__(self, test_context): + """:type test_context: ducktape.tests.test.TestContext""" + super(GroupModeTransactionsTest, self).__init__(test_context=test_context) + + self.input_topic = "input-topic" + self.output_topic = "output-topic" + + self.num_brokers = 3 + + # Test parameters + self.num_input_partitions = 9 + self.num_output_partitions = 9 + self.num_copiers = 3 + self.num_seed_messages = 100000 + self.transaction_size = 750 + # The transaction timeout should be lower than the progress timeout, but at + # least as high as the request timeout (which is 30s by default). When the + # client is hard-bounced, progress may depend on the previous transaction + # being aborted. When the broker is hard-bounced, we may have to wait as + # long as the request timeout to get a `Produce` response and we do not + # want the coordinator timing out the transaction. + self.transaction_timeout = 40000 + self.progress_timeout_sec = 60 + self.consumer_group = "grouped-transactions-test-consumer-group" + + self.zk = ZookeeperService(test_context, num_nodes=1) if quorum.for_test(test_context) == quorum.zk else None + self.kafka = KafkaService(test_context, + num_nodes=self.num_brokers, + zk=self.zk, controller_num_nodes_override=1) + + def setUp(self): + if self.zk: + self.zk.start() + + def seed_messages(self, topic, num_seed_messages): + seed_timeout_sec = 10000 + seed_producer = VerifiableProducer(context=self.test_context, + num_nodes=1, + kafka=self.kafka, + topic=topic, + message_validator=is_int, + max_messages=num_seed_messages, + enable_idempotence=True, + repeating_keys=self.num_input_partitions) + seed_producer.start() + wait_until(lambda: seed_producer.num_acked >= num_seed_messages, + timeout_sec=seed_timeout_sec, + err_msg="Producer failed to produce messages %d in %ds." 
% \ + (self.num_seed_messages, seed_timeout_sec)) + return seed_producer.acked_by_partition + + def get_messages_from_topic(self, topic, num_messages): + consumer = self.start_consumer(topic, group_id="verifying_consumer") + return self.drain_consumer(consumer, num_messages) + + def bounce_brokers(self, clean_shutdown): + for node in self.kafka.nodes: + if clean_shutdown: + self.kafka.restart_node(node, clean_shutdown = True) + else: + self.kafka.stop_node(node, clean_shutdown = False) + gracePeriodSecs = 5 + if self.zk: + wait_until(lambda: len(self.kafka.pids(node)) == 0 and not self.kafka.is_registered(node), + timeout_sec=self.kafka.zk_session_timeout + gracePeriodSecs, + err_msg="Failed to see timely deregistration of hard-killed broker %s" % str(node.account)) + else: + brokerSessionTimeoutSecs = 18 + wait_until(lambda: len(self.kafka.pids(node)) == 0, + timeout_sec=brokerSessionTimeoutSecs + gracePeriodSecs, + err_msg="Failed to see timely disappearance of process for hard-killed broker %s" % str(node.account)) + time.sleep(brokerSessionTimeoutSecs + gracePeriodSecs) + self.kafka.start_node(node) + + self.kafka.await_no_under_replicated_partitions() + + def create_and_start_message_copier(self, input_topic, output_topic, transactional_id): + message_copier = TransactionalMessageCopier( + context=self.test_context, + num_nodes=1, + kafka=self.kafka, + transactional_id=transactional_id, + consumer_group=self.consumer_group, + input_topic=input_topic, + input_partition=-1, + output_topic=output_topic, + max_messages=-1, + transaction_size=self.transaction_size, + transaction_timeout=self.transaction_timeout, + use_group_metadata=True, + group_mode=True + ) + message_copier.start() + wait_until(lambda: message_copier.alive(message_copier.nodes[0]), + timeout_sec=10, + err_msg="Message copier failed to start after 10 s") + return message_copier + + def bounce_copiers(self, copiers, clean_shutdown, timeout_sec=240): + for _ in range(3): + for copier in copiers: + wait_until(lambda: copier.progress_percent() >= 20.0, + timeout_sec=self.progress_timeout_sec, + err_msg="%s : Message copier didn't make enough progress in %ds. Current progress: %s" \ + % (copier.transactional_id, self.progress_timeout_sec, str(copier.progress_percent()))) + self.logger.info("%s - progress: %s" % (copier.transactional_id, + str(copier.progress_percent()))) + copier.restart(clean_shutdown) + + def create_and_start_copiers(self, input_topic, output_topic, num_copiers): + copiers = [] + for i in range(0, num_copiers): + copiers.append(self.create_and_start_message_copier( + input_topic=input_topic, + output_topic=output_topic, + transactional_id="copier-" + str(i) + )) + return copiers + + @staticmethod + def valid_value_and_partition(msg): + """Method used to check whether the given message is a valid tab + separated value + partition + + return value and partition as a size-two array represented tuple: [value, partition] + """ + try: + splitted_msg = msg.split('\t') + value = int(splitted_msg[1]) + partition = int(splitted_msg[0].split(":")[1]) + return [value, partition] + + except ValueError: + raise Exception("Unexpected message format (expected a tab separated [value, partition] tuple). 
Message: %s" % (msg)) + + def start_consumer(self, topic_to_read, group_id): + consumer = ConsoleConsumer(context=self.test_context, + num_nodes=1, + kafka=self.kafka, + topic=topic_to_read, + group_id=group_id, + message_validator=self.valid_value_and_partition, + from_beginning=True, + print_partition=True, + isolation_level="read_committed") + consumer.start() + # ensure that the consumer is up. + wait_until(lambda: (len(consumer.messages_consumed[1]) > 0) == True, + timeout_sec=60, + err_msg="Consumer failed to consume any messages for %ds" % \ + 60) + return consumer + + @staticmethod + def split_by_partition(messages_consumed): + messages_by_partition = {} + + for msg in messages_consumed: + partition = msg[1] + if partition not in messages_by_partition: + messages_by_partition[partition] = [] + messages_by_partition[partition].append(msg[0]) + return messages_by_partition + + def drain_consumer(self, consumer, num_messages): + # wait until we read at least the expected number of messages. + # This is a safe check because both failure modes will be caught: + # 1. If we have 'num_seed_messages' but there are duplicates, then + # this is checked for later. + # + # 2. If we never reach 'num_seed_messages', then this will cause the + # test to fail. + wait_until(lambda: len(consumer.messages_consumed[1]) >= num_messages, + timeout_sec=90, + err_msg="Consumer consumed only %d out of %d messages in %ds" % \ + (len(consumer.messages_consumed[1]), num_messages, 90)) + consumer.stop() + return self.split_by_partition(consumer.messages_consumed[1]) + + def copy_messages_transactionally(self, failure_mode, bounce_target, + input_topic, output_topic, + num_copiers, num_messages_to_copy): + """Copies messages transactionally from the seeded input topic to the + output topic, either bouncing brokers or clients in a hard and soft + way as it goes. + + This method also consumes messages in read_committed mode from the + output topic while the bounces and copy is going on. + + It returns the concurrently consumed messages. + """ + copiers = self.create_and_start_copiers(input_topic=input_topic, + output_topic=output_topic, + num_copiers=num_copiers) + concurrent_consumer = self.start_consumer(output_topic, + group_id="concurrent_consumer") + clean_shutdown = False + if failure_mode == "clean_bounce": + clean_shutdown = True + + if bounce_target == "brokers": + self.bounce_brokers(clean_shutdown) + elif bounce_target == "clients": + self.bounce_copiers(copiers, clean_shutdown) + + copier_timeout_sec = 240 + for copier in copiers: + wait_until(lambda: copier.is_done, + timeout_sec=copier_timeout_sec, + err_msg="%s - Failed to copy all messages in %ds." 
% \ + (copier.transactional_id, copier_timeout_sec)) + self.logger.info("finished copying messages") + + return self.drain_consumer(concurrent_consumer, num_messages_to_copy) + + def setup_topics(self): + self.kafka.topics = { + self.input_topic: { + "partitions": self.num_input_partitions, + "replication-factor": 3, + "configs": { + "min.insync.replicas": 2 + } + }, + self.output_topic: { + "partitions": self.num_output_partitions, + "replication-factor": 3, + "configs": { + "min.insync.replicas": 2 + } + } + } + + @cluster(num_nodes=10) + @matrix(failure_mode=["hard_bounce", "clean_bounce"], + bounce_target=["brokers", "clients"]) + def test_transactions(self, failure_mode, bounce_target, metadata_quorum=quorum.zk): + security_protocol = 'PLAINTEXT' + self.kafka.security_protocol = security_protocol + self.kafka.interbroker_security_protocol = security_protocol + self.kafka.logs["kafka_data_1"]["collect_default"] = True + self.kafka.logs["kafka_data_2"]["collect_default"] = True + self.kafka.logs["kafka_operational_logs_debug"]["collect_default"] = True + + self.setup_topics() + self.kafka.start() + + input_messages_by_partition = self.seed_messages(self.input_topic, self.num_seed_messages) + concurrently_consumed_message_by_partition = self.copy_messages_transactionally( + failure_mode, bounce_target, input_topic=self.input_topic, + output_topic=self.output_topic, num_copiers=self.num_copiers, + num_messages_to_copy=self.num_seed_messages) + output_messages_by_partition = self.get_messages_from_topic(self.output_topic, self.num_seed_messages) + + assert len(input_messages_by_partition) == \ + len(concurrently_consumed_message_by_partition), "The lengths of partition count doesn't match: " \ + "input partitions count %d, " \ + "concurrently consumed partitions count %d" % \ + (len(input_messages_by_partition), len(concurrently_consumed_message_by_partition)) + + assert len(input_messages_by_partition) == \ + len(output_messages_by_partition), "The lengths of partition count doesn't match: " \ + "input partitions count %d, " \ + "output partitions count %d" % \ + (len(input_messages_by_partition), len(output_messages_by_partition)) + + for p in range(self.num_input_partitions): + if p not in input_messages_by_partition: + continue + + assert p in output_messages_by_partition, "Partition %d not in output messages" % p + assert p in concurrently_consumed_message_by_partition, "Partition %d not in concurrently consumed messages" % p + + output_message_set = set(output_messages_by_partition[p]) + input_message_set = set(input_messages_by_partition[p]) + + concurrently_consumed_message_set = set(concurrently_consumed_message_by_partition[p]) + + output_messages = output_messages_by_partition[p] + input_messages = input_messages_by_partition[p] + concurrently_consumed_messages = concurrently_consumed_message_by_partition[p] + + num_dups = abs(len(output_messages) - len(output_message_set)) + num_dups_in_concurrent_consumer = abs(len(concurrently_consumed_messages) + - len(concurrently_consumed_message_set)) + assert num_dups == 0, "Detected %d duplicates in the output stream" % num_dups + assert input_message_set == output_message_set, "Input and output message sets are not equal. Num input messages %d. Num output messages %d" % \ + (len(input_message_set), len(output_message_set)) + + assert num_dups_in_concurrent_consumer == 0, "Detected %d dups in concurrently consumed messages" % num_dups_in_concurrent_consumer + assert input_message_set == concurrently_consumed_message_set, \ + "Input and concurrently consumed output message sets are not equal. Num input messages: %d. 
Num concurrently_consumed_messages: %d" % \ + (len(input_message_set), len(concurrently_consumed_message_set)) + + assert input_messages == sorted(input_messages), "The seed messages themselves were not in order" + assert output_messages == input_messages, "Output messages are not in order" + assert concurrently_consumed_messages == output_messages, "Concurrently consumed messages are not in order" diff --git a/tests/kafkatest/tests/core/log_dir_failure_test.py b/tests/kafkatest/tests/core/log_dir_failure_test.py new file mode 100644 index 0000000..31a20fd --- /dev/null +++ b/tests/kafkatest/tests/core/log_dir_failure_test.py @@ -0,0 +1,185 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ducktape.utils.util import wait_until +from ducktape.mark import matrix +from ducktape.mark.resource import cluster +from kafkatest.services.kafka import config_property +from kafkatest.services.zookeeper import ZookeeperService +from kafkatest.services.kafka import KafkaService +from kafkatest.services.verifiable_producer import VerifiableProducer +from kafkatest.services.console_consumer import ConsoleConsumer +from kafkatest.tests.produce_consume_validate import ProduceConsumeValidateTest +from kafkatest.utils import is_int +from kafkatest.utils.remote_account import path_exists + +def select_node(test, broker_type, topic): + """ Discover node of requested type. For leader type, discovers leader for our topic and partition 0 + """ + if broker_type == "leader": + node = test.kafka.leader(topic, partition=0) + elif broker_type == "follower": + leader = test.kafka.leader(topic, partition=0) + node = [replica for replica in test.kafka.replicas(topic, partition=0) if replica != leader][0] + elif broker_type == "controller": + node = test.kafka.controller() + else: + raise Exception("Unexpected broker type %s." % (broker_type)) + + return node + + +class LogDirFailureTest(ProduceConsumeValidateTest): + """ + Note that consuming is a bit tricky, at least with console consumer. The goal is to consume all messages + (foreach partition) in the topic. In this case, waiting for the last message may cause the consumer to stop + too soon since console consumer is consuming multiple partitions from a single thread and therefore we lose + ordering guarantees. + + Waiting on a count of consumed messages can be unreliable: if we stop consuming when num_consumed == num_acked, + we might exit early if some messages are duplicated (though not an issue here since producer retries==0) + + Therefore rely here on the consumer.timeout.ms setting which times out on the interval between successively + consumed messages. Since we run the producer to completion before running the consumer, this is a reliable + indicator that nothing is left to consume. 
+ """ + + def __init__(self, test_context): + """:type test_context: ducktape.tests.test.TestContext""" + super(LogDirFailureTest, self).__init__(test_context=test_context) + + self.topic1 = "test_topic_1" + self.topic2 = "test_topic_2" + self.zk = ZookeeperService(test_context, num_nodes=1) + self.kafka = KafkaService(test_context, + num_nodes=3, + zk=self.zk, + topics={ + self.topic1: {"partitions": 1, "replication-factor": 3, "configs": {"min.insync.replicas": 1}}, + self.topic2: {"partitions": 1, "replication-factor": 3, "configs": {"min.insync.replicas": 2}} + }, + # Set log.roll.ms to 3 seconds so that broker will detect disk error sooner when it creates log segment + # Otherwise broker will still be able to read/write the log file even if the log directory is inaccessible. + server_prop_overrides=[ + [config_property.OFFSETS_TOPIC_NUM_PARTITIONS, "1"], + [config_property.LOG_FLUSH_INTERVAL_MESSAGE, "5"], + [config_property.REPLICA_HIGHWATERMARK_CHECKPOINT_INTERVAL_MS, "60000"], + [config_property.LOG_ROLL_TIME_MS, "3000"] + ]) + + self.producer_throughput = 1000 + self.num_producers = 1 + self.num_consumers = 1 + + def setUp(self): + self.zk.start() + + def min_cluster_size(self): + """Override this since we're adding services outside of the constructor""" + return super(LogDirFailureTest, self).min_cluster_size() + self.num_producers * 2 + self.num_consumers * 2 + + @cluster(num_nodes=9) + @matrix(bounce_broker=[False, True], broker_type=["leader", "follower"], security_protocol=["PLAINTEXT"]) + def test_replication_with_disk_failure(self, bounce_broker, security_protocol, broker_type): + """Replication tests. + These tests verify that replication provides simple durability guarantees by checking that data acked by + brokers is still available for consumption in the face of various failure scenarios. + + Setup: 1 zk, 3 kafka nodes, 1 topic with partitions=3, replication-factor=3, and min.insync.replicas=2 + and another topic with partitions=3, replication-factor=3, and min.insync.replicas=1 + - Produce messages in the background + - Consume messages in the background + - Drive broker failures (shutdown, or bounce repeatedly with kill -15 or kill -9) + - When done driving failures, stop producing, and finish consuming + - Validate that every acked message was consumed + """ + + self.kafka.security_protocol = security_protocol + self.kafka.interbroker_security_protocol = security_protocol + self.kafka.start() + + try: + # Initialize producer/consumer for topic2 + self.producer = VerifiableProducer(self.test_context, self.num_producers, self.kafka, self.topic2, + throughput=self.producer_throughput) + self.consumer = ConsoleConsumer(self.test_context, self.num_consumers, self.kafka, self.topic2, group_id="test-consumer-group-1", + consumer_timeout_ms=60000, message_validator=is_int) + self.start_producer_and_consumer() + + # Get a replica of the partition of topic2 and make its log directory offline by changing the log dir's permission. + # We assume that partition of topic2 is created in the second log directory of respective brokers. 
+ broker_node = select_node(self, broker_type, self.topic2) + broker_idx = self.kafka.idx(broker_node) + assert broker_idx in self.kafka.isr_idx_list(self.topic2), \ + "Broker %d should be in isr set %s" % (broker_idx, str(self.kafka.isr_idx_list(self.topic2))) + + # Verify that topic1 and the consumer offset topic is in the first log directory and topic2 is in the second log directory + topic_1_partition_0 = KafkaService.DATA_LOG_DIR_1 + "/test_topic_1-0" + topic_2_partition_0 = KafkaService.DATA_LOG_DIR_2 + "/test_topic_2-0" + offset_topic_partition_0 = KafkaService.DATA_LOG_DIR_1 + "/__consumer_offsets-0" + for path in [topic_1_partition_0, topic_2_partition_0, offset_topic_partition_0]: + assert path_exists(broker_node, path), "%s should exist" % path + + self.logger.debug("Making log dir %s inaccessible" % (KafkaService.DATA_LOG_DIR_2)) + cmd = "chmod a-w %s -R" % (KafkaService.DATA_LOG_DIR_2) + broker_node.account.ssh(cmd, allow_fail=False) + + if bounce_broker: + self.kafka.restart_node(broker_node, clean_shutdown=True) + + # Verify the following: + # 1) The broker with offline log directory is not the leader of the partition of topic2 + # 2) The broker with offline log directory is not in the ISR + # 3) The broker with offline log directory is still online + # 4) Messages can still be produced and consumed from topic2 + wait_until(lambda: self.kafka.leader(self.topic2, partition=0) != broker_node, + timeout_sec=60, + err_msg="Broker %d should not be leader of topic %s and partition 0" % (broker_idx, self.topic2)) + assert self.kafka.alive(broker_node), "Broker %d should be still online" % (broker_idx) + wait_until(lambda: broker_idx not in self.kafka.isr_idx_list(self.topic2), + timeout_sec=60, + err_msg="Broker %d should not be in isr set %s" % (broker_idx, str(self.kafka.isr_idx_list(self.topic2)))) + + self.stop_producer_and_consumer() + self.validate() + + # Shutdown all other brokers so that the broker with offline log dir is the only online broker + offline_nodes = [] + for node in self.kafka.nodes: + if broker_node != node: + offline_nodes.append(node) + self.logger.debug("Hard shutdown broker %d" % (self.kafka.idx(node))) + self.kafka.stop_node(node) + + # Verify the following: + # 1) The broker with offline directory is the only in-sync broker of the partition of topic1 + # 2) Messages can still be produced and consumed from topic1 + self.producer = VerifiableProducer(self.test_context, self.num_producers, self.kafka, self.topic1, + throughput=self.producer_throughput, offline_nodes=offline_nodes) + self.consumer = ConsoleConsumer(self.test_context, self.num_consumers, self.kafka, self.topic1, group_id="test-consumer-group-2", + consumer_timeout_ms=90000, message_validator=is_int) + self.consumer_start_timeout_sec = 90 + self.start_producer_and_consumer() + + assert self.kafka.isr_idx_list(self.topic1) == [broker_idx], \ + "In-sync replicas of topic %s and partition 0 should be %s" % (self.topic1, str([broker_idx])) + + self.stop_producer_and_consumer() + self.validate() + + except BaseException as e: + for s in self.test_context.services: + self.mark_for_collect(s) + raise diff --git a/tests/kafkatest/tests/core/mirror_maker_test.py b/tests/kafkatest/tests/core/mirror_maker_test.py new file mode 100644 index 0000000..05cdb4b --- /dev/null +++ b/tests/kafkatest/tests/core/mirror_maker_test.py @@ -0,0 +1,171 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. 
See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ducktape.utils.util import wait_until +from ducktape.mark import matrix +from ducktape.mark.resource import cluster + +from kafkatest.services.zookeeper import ZookeeperService +from kafkatest.services.kafka import KafkaService +from kafkatest.services.console_consumer import ConsoleConsumer +from kafkatest.services.verifiable_producer import VerifiableProducer +from kafkatest.services.mirror_maker import MirrorMaker +from kafkatest.services.security.minikdc import MiniKdc +from kafkatest.tests.produce_consume_validate import ProduceConsumeValidateTest +from kafkatest.utils import is_int + +import time + + +class TestMirrorMakerService(ProduceConsumeValidateTest): + """Sanity checks on mirror maker service class.""" + def __init__(self, test_context): + super(TestMirrorMakerService, self).__init__(test_context) + + self.topic = "topic" + self.source_zk = ZookeeperService(test_context, num_nodes=1) + self.target_zk = ZookeeperService(test_context, num_nodes=1) + + self.source_kafka = KafkaService(test_context, num_nodes=1, zk=self.source_zk, + topics={self.topic: {"partitions": 1, "replication-factor": 1}}) + self.target_kafka = KafkaService(test_context, num_nodes=1, zk=self.target_zk, + topics={self.topic: {"partitions": 1, "replication-factor": 1}}) + # This will produce to source kafka cluster + self.producer = VerifiableProducer(test_context, num_nodes=1, kafka=self.source_kafka, topic=self.topic, + throughput=1000) + self.mirror_maker = MirrorMaker(test_context, num_nodes=1, source=self.source_kafka, target=self.target_kafka, + whitelist=self.topic, offset_commit_interval_ms=1000) + # This will consume from target kafka cluster + self.consumer = ConsoleConsumer(test_context, num_nodes=1, kafka=self.target_kafka, topic=self.topic, + message_validator=is_int, consumer_timeout_ms=60000) + + def setUp(self): + # Source cluster + self.source_zk.start() + + # Target cluster + self.target_zk.start() + + def start_kafka(self, security_protocol): + self.source_kafka.security_protocol = security_protocol + self.source_kafka.interbroker_security_protocol = security_protocol + self.target_kafka.security_protocol = security_protocol + self.target_kafka.interbroker_security_protocol = security_protocol + if self.source_kafka.security_config.has_sasl_kerberos: + minikdc = MiniKdc(self.source_kafka.context, self.source_kafka.nodes + self.target_kafka.nodes) + self.source_kafka.minikdc = minikdc + self.target_kafka.minikdc = minikdc + minikdc.start() + self.source_kafka.start() + self.target_kafka.start() + + def bounce(self, clean_shutdown=True): + """Bounce mirror maker with a clean (kill -15) or hard (kill -9) shutdown""" + + # Wait until messages start appearing in the target cluster + wait_until(lambda: len(self.consumer.messages_consumed[1]) > 0, timeout_sec=15) + + # Wait for at least one offset to be committed. 
+ # + # This step is necessary to prevent data loss with default mirror maker settings: + # currently, if we don't have at least one committed offset, + # and we bounce mirror maker, the consumer internals will throw OffsetOutOfRangeException, and the default + # auto.offset.reset policy ("largest") will kick in, causing mirrormaker to start consuming from the largest + # offset. As a result, any messages produced to the source cluster while mirrormaker was dead won't get + # mirrored to the target cluster. + # (see https://issues.apache.org/jira/browse/KAFKA-2759) + # + # This isn't necessary with kill -15 because mirror maker commits its offsets during graceful + # shutdown. + if not clean_shutdown: + time.sleep(self.mirror_maker.offset_commit_interval_ms / 1000.0 + .5) + + for i in range(3): + self.logger.info("Bringing mirror maker nodes down...") + for node in self.mirror_maker.nodes: + self.mirror_maker.stop_node(node, clean_shutdown=clean_shutdown) + + num_consumed = len(self.consumer.messages_consumed[1]) + self.logger.info("Bringing mirror maker nodes back up...") + for node in self.mirror_maker.nodes: + self.mirror_maker.start_node(node) + + # Ensure new messages are once again showing up on the target cluster + wait_until(lambda: len(self.consumer.messages_consumed[1]) > num_consumed + 100, timeout_sec=60) + + def wait_for_n_messages(self, n_messages=100): + """Wait for a minimum number of messages to be successfully produced.""" + wait_until(lambda: self.producer.num_acked > n_messages, timeout_sec=10, + err_msg="Producer failed to produce %d messages in a reasonable amount of time." % n_messages) + + @cluster(num_nodes=7) + @matrix(security_protocol=['PLAINTEXT', 'SSL']) + @cluster(num_nodes=8) + @matrix(security_protocol=['SASL_PLAINTEXT', 'SASL_SSL']) + def test_simple_end_to_end(self, security_protocol): + """ + Test end-to-end behavior under non-failure conditions. + + Setup: two single node Kafka clusters, each connected to its own single node zookeeper cluster. + One is source, and the other is target. Single-node mirror maker mirrors from source to target. + + - Start mirror maker. + - Produce a small number of messages to the source cluster. + - Consume messages from target. + - Verify that number of consumed messages matches the number produced. + """ + self.start_kafka(security_protocol) + self.mirror_maker.start() + + mm_node = self.mirror_maker.nodes[0] + with mm_node.account.monitor_log(self.mirror_maker.LOG_FILE) as monitor: + monitor.wait_until("Resetting offset for partition", timeout_sec=30, err_msg="Mirrormaker did not reset fetch offset in a reasonable amount of time.") + self.run_produce_consume_validate(core_test_action=self.wait_for_n_messages) + self.mirror_maker.stop() + + @cluster(num_nodes=7) + @matrix(clean_shutdown=[True, False], security_protocol=['PLAINTEXT', 'SSL']) + @cluster(num_nodes=8) + @matrix(clean_shutdown=[True, False], security_protocol=['SASL_PLAINTEXT', 'SASL_SSL']) + def test_bounce(self, offsets_storage="kafka", clean_shutdown=True, security_protocol='PLAINTEXT'): + """ + Test end-to-end behavior under failure conditions. + + Setup: two single node Kafka clusters, each connected to its own single node zookeeper cluster. + One is source, and the other is target. Single-node mirror maker mirrors from source to target. + + - Start mirror maker. + - Produce to source cluster, and consume from target cluster in the background. 
+ - Bounce MM process + - Verify every message acknowledged by the source producer is consumed by the target consumer + """ + if not clean_shutdown: + # Increase timeout on downstream console consumer; mirror maker takes extra time + # during hard bounce. This is because the restarted mirror maker consumer won't be able to rejoin + # the group until the previous session times out + self.consumer.consumer_timeout_ms = 60000 + + self.start_kafka(security_protocol) + + self.mirror_maker.offsets_storage = offsets_storage + self.mirror_maker.start() + + # Wait until mirror maker has reset fetch offset at least once before continuing with the rest of the test + mm_node = self.mirror_maker.nodes[0] + with mm_node.account.monitor_log(self.mirror_maker.LOG_FILE) as monitor: + monitor.wait_until("Resetting offset for partition", timeout_sec=30, err_msg="Mirrormaker did not reset fetch offset in a reasonable amount of time.") + + self.run_produce_consume_validate(core_test_action=lambda: self.bounce(clean_shutdown=clean_shutdown)) + self.mirror_maker.stop() diff --git a/tests/kafkatest/tests/core/network_degrade_test.py b/tests/kafkatest/tests/core/network_degrade_test.py new file mode 100644 index 0000000..68cce85 --- /dev/null +++ b/tests/kafkatest/tests/core/network_degrade_test.py @@ -0,0 +1,145 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re + +from ducktape.mark import parametrize +from ducktape.mark.resource import cluster +from ducktape.tests.test import Test +from ducktape.utils.util import wait_until + +from kafkatest.services.trogdor.degraded_network_fault_spec import DegradedNetworkFaultSpec +from kafkatest.services.trogdor.trogdor import TrogdorService +from kafkatest.services.zookeeper import ZookeeperService + + +class NetworkDegradeTest(Test): + """ + These tests ensure that the network degrade Trogdor specs (which use "tc") are working as expected in whatever + environment the system tests may be running in. The linux tools "ping" and "iperf" are used for validation + and need to be available along with "tc" in the test environment. 
+ """ + + def __init__(self, test_context): + super(NetworkDegradeTest, self).__init__(test_context) + self.zk = ZookeeperService(test_context, num_nodes=3) + self.trogdor = TrogdorService(context=self.test_context, client_services=[self.zk]) + + def setUp(self): + self.zk.start() + self.trogdor.start() + + def teardown(self): + self.trogdor.stop() + self.zk.stop() + + @cluster(num_nodes=5) + @parametrize(task_name="latency-100", device_name="eth0", latency_ms=50, rate_limit_kbit=0) + @parametrize(task_name="latency-100-rate-1000", device_name="eth0", latency_ms=50, rate_limit_kbit=1000) + def test_latency(self, task_name, device_name, latency_ms, rate_limit_kbit): + spec = DegradedNetworkFaultSpec(0, 10000) + for node in self.zk.nodes: + spec.add_node_spec(node.name, device_name, latency_ms, rate_limit_kbit) + + latency = self.trogdor.create_task(task_name, spec) + + zk0 = self.zk.nodes[0] + zk1 = self.zk.nodes[1] + + # Capture the ping times from the ping stdout + # 64 bytes from ducker01 (172.24.0.2): icmp_seq=1 ttl=64 time=0.325 ms + r = re.compile(r".*time=(?P

                + *     {@code
                + *      if (throttler.shouldThrottle(...)) {
                + *          throttler.throttle();
                + *      }
                + *     }
                + * 
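                + *
                + *     A slightly fuller, hypothetical sketch of a throttled send loop
                + *     (producer, record and numRecords are placeholder names, and the
                + *     target of 1000 messages/sec is only an example value):
                + *     {@code
                + *      long startMs = System.currentTimeMillis();
                + *      ThroughputThrottler throttler = new ThroughputThrottler(1000, startMs);
                + *      for (long i = 0; i < numRecords; i++) {
                + *          long sendStartMs = System.currentTimeMillis();
                + *          producer.send(record);
                + *          if (throttler.shouldThrottle(i, sendStartMs)) {
                + *              throttler.throttle();
                + *          }
                + *      }
                + *     }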
                + * + * Note that this can be used to throttle message throughput or data throughput. + */ +public class ThroughputThrottler { + + private static final long NS_PER_MS = 1000000L; + private static final long NS_PER_SEC = 1000 * NS_PER_MS; + private static final long MIN_SLEEP_NS = 2 * NS_PER_MS; + + private final long startMs; + private final long sleepTimeNs; + private final long targetThroughput; + + private long sleepDeficitNs = 0; + private boolean wakeup = false; + + /** + * @param targetThroughput Can be messages/sec or bytes/sec + * @param startMs When the very first message is sent + */ + public ThroughputThrottler(long targetThroughput, long startMs) { + this.startMs = startMs; + this.targetThroughput = targetThroughput; + this.sleepTimeNs = targetThroughput > 0 ? + NS_PER_SEC / targetThroughput : + Long.MAX_VALUE; + } + + /** + * @param amountSoFar bytes produced so far if you want to throttle data throughput, or + * messages produced so far if you want to throttle message throughput. + * @param sendStartMs timestamp of the most recently sent message + * @return + */ + public boolean shouldThrottle(long amountSoFar, long sendStartMs) { + if (this.targetThroughput < 0) { + // No throttling in this case + return false; + } + + float elapsedSec = (sendStartMs - startMs) / 1000.f; + return elapsedSec > 0 && (amountSoFar / elapsedSec) > this.targetThroughput; + } + + /** + * Occasionally blocks for small amounts of time to achieve targetThroughput. + * + * Note that if targetThroughput is 0, this will block extremely aggressively. + */ + public void throttle() { + if (targetThroughput == 0) { + try { + synchronized (this) { + while (!wakeup) { + this.wait(); + } + } + } catch (InterruptedException e) { + // do nothing + } + return; + } + + // throttle throughput by sleeping, on average, + // (1 / this.throughput) seconds between "things sent" + sleepDeficitNs += sleepTimeNs; + + // If enough sleep deficit has accumulated, sleep a little + if (sleepDeficitNs >= MIN_SLEEP_NS) { + long sleepStartNs = System.nanoTime(); + try { + synchronized (this) { + long remaining = sleepDeficitNs; + while (!wakeup && remaining > 0) { + long sleepMs = remaining / 1000000; + long sleepNs = remaining - sleepMs * 1000000; + this.wait(sleepMs, (int) sleepNs); + long elapsed = System.nanoTime() - sleepStartNs; + remaining = sleepDeficitNs - elapsed; + } + wakeup = false; + } + sleepDeficitNs = 0; + } catch (InterruptedException e) { + // If sleep is cut short, reduce deficit by the amount of + // time we actually spent sleeping + long sleepElapsedNs = System.nanoTime() - sleepStartNs; + if (sleepElapsedNs <= sleepDeficitNs) { + sleepDeficitNs -= sleepElapsedNs; + } + } + } + } + + /** + * Wakeup the throttler if its sleeping. + */ + public void wakeup() { + synchronized (this) { + wakeup = true; + this.notifyAll(); + } + } +} + diff --git a/tools/src/main/java/org/apache/kafka/tools/ToolsUtils.java b/tools/src/main/java/org/apache/kafka/tools/ToolsUtils.java new file mode 100644 index 0000000..3a80b58 --- /dev/null +++ b/tools/src/main/java/org/apache/kafka/tools/ToolsUtils.java @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.tools; + +import org.apache.kafka.common.Metric; +import org.apache.kafka.common.MetricName; + +import java.util.Map; +import java.util.TreeMap; + +public class ToolsUtils { + + /** + * print out the metrics in alphabetical order + * @param metrics the metrics to be printed out + */ + public static void printMetrics(Map metrics) { + if (metrics != null && !metrics.isEmpty()) { + int maxLengthOfDisplayName = 0; + TreeMap sortedMetrics = new TreeMap<>(); + for (Metric metric : metrics.values()) { + MetricName mName = metric.metricName(); + String mergedName = mName.group() + ":" + mName.name() + ":" + mName.tags(); + maxLengthOfDisplayName = maxLengthOfDisplayName < mergedName.length() ? mergedName.length() : maxLengthOfDisplayName; + sortedMetrics.put(mergedName, metric.metricValue()); + } + String doubleOutputFormat = "%-" + maxLengthOfDisplayName + "s : %.3f"; + String defaultOutputFormat = "%-" + maxLengthOfDisplayName + "s : %s"; + System.out.println(String.format("\n%-" + maxLengthOfDisplayName + "s %s", "Metric Name", "Value")); + + for (Map.Entry entry : sortedMetrics.entrySet()) { + String outputFormat; + if (entry.getValue() instanceof Double) + outputFormat = doubleOutputFormat; + else + outputFormat = defaultOutputFormat; + System.out.println(String.format(outputFormat, entry.getKey(), entry.getValue())); + } + } + } +} diff --git a/tools/src/main/java/org/apache/kafka/tools/TransactionalMessageCopier.java b/tools/src/main/java/org/apache/kafka/tools/TransactionalMessageCopier.java new file mode 100644 index 0000000..18c8099 --- /dev/null +++ b/tools/src/main/java/org/apache/kafka/tools/TransactionalMessageCopier.java @@ -0,0 +1,404 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.tools; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import net.sourceforge.argparse4j.ArgumentParsers; +import net.sourceforge.argparse4j.inf.ArgumentParser; +import net.sourceforge.argparse4j.inf.Namespace; +import org.apache.kafka.clients.consumer.ConsumerConfig; +import org.apache.kafka.clients.consumer.ConsumerGroupMetadata; +import org.apache.kafka.clients.consumer.ConsumerRebalanceListener; +import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.clients.consumer.ConsumerRecords; +import org.apache.kafka.clients.consumer.KafkaConsumer; +import org.apache.kafka.clients.consumer.OffsetAndMetadata; +import org.apache.kafka.clients.producer.KafkaProducer; +import org.apache.kafka.clients.producer.ProducerConfig; +import org.apache.kafka.clients.producer.ProducerRecord; +import org.apache.kafka.common.KafkaException; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.errors.ProducerFencedException; +import org.apache.kafka.common.errors.WakeupException; +import org.apache.kafka.common.utils.Exit; +import org.apache.kafka.common.utils.Utils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.text.DateFormat; +import java.text.SimpleDateFormat; +import java.time.Duration; +import java.util.Collection; +import java.util.Collections; +import java.util.Date; +import java.util.HashMap; +import java.util.Map; +import java.util.Properties; +import java.util.Random; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicLong; + +import static java.util.Collections.singleton; +import static net.sourceforge.argparse4j.impl.Arguments.store; +import static net.sourceforge.argparse4j.impl.Arguments.storeTrue; + +/** + * This class is primarily meant for use with system tests. It copies messages from an input partition to an output + * topic transactionally, committing the offsets and messages together. + */ +public class TransactionalMessageCopier { + private static final Logger log = LoggerFactory.getLogger(TransactionalMessageCopier.class); + private static final DateFormat FORMAT = new SimpleDateFormat("yyyy/MM/dd HH:mm:ss:SSS"); + + /** Get the command-line argument parser. 
*/ + private static ArgumentParser argParser() { + ArgumentParser parser = ArgumentParsers + .newArgumentParser("transactional-message-copier") + .defaultHelp(true) + .description("This tool copies messages transactionally from an input partition to an output topic, " + + "committing the consumed offsets along with the output messages"); + + parser.addArgument("--input-topic") + .action(store()) + .required(true) + .type(String.class) + .metavar("INPUT-TOPIC") + .dest("inputTopic") + .help("Consume messages from this topic"); + + parser.addArgument("--input-partition") + .action(store()) + .required(true) + .type(Integer.class) + .metavar("INPUT-PARTITION") + .dest("inputPartition") + .help("Consume messages from this partition of the input topic."); + + parser.addArgument("--output-topic") + .action(store()) + .required(true) + .type(String.class) + .metavar("OUTPUT-TOPIC") + .dest("outputTopic") + .help("Produce messages to this topic"); + + parser.addArgument("--broker-list") + .action(store()) + .required(true) + .type(String.class) + .metavar("HOST1:PORT1[,HOST2:PORT2[...]]") + .dest("brokerList") + .help("Comma-separated list of Kafka brokers in the form HOST1:PORT1,HOST2:PORT2,..."); + + parser.addArgument("--max-messages") + .action(store()) + .required(false) + .setDefault(-1) + .type(Integer.class) + .metavar("MAX-MESSAGES") + .dest("maxMessages") + .help("Process these many messages upto the end offset at the time this program was launched. If set to -1 " + + "we will just read to the end offset of the input partition (as of the time the program was launched)."); + + parser.addArgument("--consumer-group") + .action(store()) + .required(false) + .setDefault(-1) + .type(String.class) + .metavar("CONSUMER-GROUP") + .dest("consumerGroup") + .help("The consumer group id to use for storing the consumer offsets."); + + parser.addArgument("--transaction-size") + .action(store()) + .required(false) + .setDefault(200) + .type(Integer.class) + .metavar("TRANSACTION-SIZE") + .dest("messagesPerTransaction") + .help("The number of messages to put in each transaction. Default is 200."); + + parser.addArgument("--transaction-timeout") + .action(store()) + .required(false) + .setDefault(60000) + .type(Integer.class) + .metavar("TRANSACTION-TIMEOUT") + .dest("transactionTimeout") + .help("The transaction timeout in milliseconds. Default is 60000(1 minute)."); + + parser.addArgument("--transactional-id") + .action(store()) + .required(true) + .type(String.class) + .metavar("TRANSACTIONAL-ID") + .dest("transactionalId") + .help("The transactionalId to assign to the producer"); + + parser.addArgument("--enable-random-aborts") + .action(storeTrue()) + .type(Boolean.class) + .metavar("ENABLE-RANDOM-ABORTS") + .dest("enableRandomAborts") + .help("Whether or not to enable random transaction aborts (for system testing)"); + + parser.addArgument("--group-mode") + .action(storeTrue()) + .type(Boolean.class) + .metavar("GROUP-MODE") + .dest("groupMode") + .help("Whether to let consumer subscribe to the input topic or do manual assign. 
If we do" + + " subscription based consumption, the input partition shall be ignored"); + + parser.addArgument("--use-group-metadata") + .action(storeTrue()) + .type(Boolean.class) + .metavar("USE-GROUP-METADATA") + .dest("useGroupMetadata") + .help("Whether to use the new transactional commit API with group metadata"); + + return parser; + } + + private static KafkaProducer createProducer(Namespace parsedArgs) { + Properties props = new Properties(); + props.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, parsedArgs.getString("brokerList")); + props.put(ProducerConfig.TRANSACTIONAL_ID_CONFIG, parsedArgs.getString("transactionalId")); + props.put(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG, + "org.apache.kafka.common.serialization.StringSerializer"); + props.put(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG, + "org.apache.kafka.common.serialization.StringSerializer"); + // We set a small batch size to ensure that we have multiple inflight requests per transaction. + // If it is left at the default, each transaction will have only one batch per partition, hence not testing + // the case with multiple inflights. + props.put(ProducerConfig.BATCH_SIZE_CONFIG, "512"); + props.put(ProducerConfig.MAX_IN_FLIGHT_REQUESTS_PER_CONNECTION, "5"); + props.put(ProducerConfig.TRANSACTION_TIMEOUT_CONFIG, parsedArgs.getInt("transactionTimeout")); + + return new KafkaProducer<>(props); + } + + private static KafkaConsumer createConsumer(Namespace parsedArgs) { + String consumerGroup = parsedArgs.getString("consumerGroup"); + String brokerList = parsedArgs.getString("brokerList"); + Integer numMessagesPerTransaction = parsedArgs.getInt("messagesPerTransaction"); + + Properties props = new Properties(); + + props.put(ConsumerConfig.GROUP_ID_CONFIG, consumerGroup); + props.put(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, brokerList); + props.put(ConsumerConfig.ISOLATION_LEVEL_CONFIG, "read_committed"); + props.put(ConsumerConfig.MAX_POLL_RECORDS_CONFIG, numMessagesPerTransaction.toString()); + props.put(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG, "false"); + props.put(ConsumerConfig.SESSION_TIMEOUT_MS_CONFIG, "10000"); + props.put(ConsumerConfig.MAX_POLL_INTERVAL_MS_CONFIG, "180000"); + props.put(ConsumerConfig.HEARTBEAT_INTERVAL_MS_CONFIG, "3000"); + props.put(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "earliest"); + props.put(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, + "org.apache.kafka.common.serialization.StringDeserializer"); + props.put(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, + "org.apache.kafka.common.serialization.StringDeserializer"); + + return new KafkaConsumer<>(props); + } + + private static ProducerRecord producerRecordFromConsumerRecord(String topic, ConsumerRecord record) { + return new ProducerRecord<>(topic, record.partition(), record.key(), record.value()); + } + + private static Map consumerPositions(KafkaConsumer consumer) { + Map positions = new HashMap<>(); + for (TopicPartition topicPartition : consumer.assignment()) { + positions.put(topicPartition, new OffsetAndMetadata(consumer.position(topicPartition), null)); + } + return positions; + } + + private static void resetToLastCommittedPositions(KafkaConsumer consumer) { + final Map committed = consumer.committed(consumer.assignment()); + consumer.assignment().forEach(tp -> { + OffsetAndMetadata offsetAndMetadata = committed.get(tp); + if (offsetAndMetadata != null) + consumer.seek(tp, offsetAndMetadata.offset()); + else + consumer.seekToBeginning(singleton(tp)); + }); + } + + private static long messagesRemaining(KafkaConsumer consumer, 
TopicPartition partition) { + long currentPosition = consumer.position(partition); + Map endOffsets = consumer.endOffsets(singleton(partition)); + if (endOffsets.containsKey(partition)) { + return endOffsets.get(partition) - currentPosition; + } + return 0; + } + + private static String toJsonString(Map data) { + String json; + try { + ObjectMapper mapper = new ObjectMapper(); + json = mapper.writeValueAsString(data); + } catch (JsonProcessingException e) { + json = "Bad data can't be written as json: " + e.getMessage(); + } + return json; + } + + private static synchronized String statusAsJson(long totalProcessed, long consumedSinceLastRebalanced, long remaining, String transactionalId, String stage) { + Map statusData = new HashMap<>(); + statusData.put("progress", transactionalId); + statusData.put("totalProcessed", totalProcessed); + statusData.put("consumed", consumedSinceLastRebalanced); + statusData.put("remaining", remaining); + statusData.put("time", FORMAT.format(new Date())); + statusData.put("stage", stage); + return toJsonString(statusData); + } + + private static synchronized String shutDownString(long totalProcessed, long consumedSinceLastRebalanced, long remaining, String transactionalId) { + Map shutdownData = new HashMap<>(); + shutdownData.put("shutdown_complete", transactionalId); + shutdownData.put("totalProcessed", totalProcessed); + shutdownData.put("consumed", consumedSinceLastRebalanced); + shutdownData.put("remaining", remaining); + shutdownData.put("time", FORMAT.format(new Date())); + return toJsonString(shutdownData); + } + + private static void abortTransactionAndResetPosition( + KafkaProducer producer, + KafkaConsumer consumer + ) { + producer.abortTransaction(); + resetToLastCommittedPositions(consumer); + } + + public static void main(String[] args) { + Namespace parsedArgs = argParser().parseArgsOrFail(args); + try { + runEventLoop(parsedArgs); + Exit.exit(0); + } catch (Exception e) { + log.error("Shutting down after unexpected error in event loop", e); + System.err.println("Shutting down after unexpected error " + e.getClass().getSimpleName() + + ": " + e.getMessage() + " (see the log for additional detail)"); + Exit.exit(1); + } + } + + public static void runEventLoop(Namespace parsedArgs) { + final String transactionalId = parsedArgs.getString("transactionalId"); + final String outputTopic = parsedArgs.getString("outputTopic"); + + String consumerGroup = parsedArgs.getString("consumerGroup"); + + final KafkaProducer producer = createProducer(parsedArgs); + final KafkaConsumer consumer = createConsumer(parsedArgs); + + final AtomicLong remainingMessages = new AtomicLong( + parsedArgs.getInt("maxMessages") == -1 ? Long.MAX_VALUE : parsedArgs.getInt("maxMessages")); + + boolean groupMode = parsedArgs.getBoolean("groupMode"); + String topicName = parsedArgs.getString("inputTopic"); + final AtomicLong numMessagesProcessedSinceLastRebalance = new AtomicLong(0); + final AtomicLong totalMessageProcessed = new AtomicLong(0); + if (groupMode) { + consumer.subscribe(Collections.singleton(topicName), new ConsumerRebalanceListener() { + @Override + public void onPartitionsRevoked(Collection partitions) { + } + + @Override + public void onPartitionsAssigned(Collection partitions) { + remainingMessages.set(partitions.stream() + .mapToLong(partition -> messagesRemaining(consumer, partition)).sum()); + numMessagesProcessedSinceLastRebalance.set(0); + // We use message cap for remaining here as the remainingMessages are not set yet. 
+ System.out.println(statusAsJson(totalMessageProcessed.get(), + numMessagesProcessedSinceLastRebalance.get(), remainingMessages.get(), transactionalId, "RebalanceComplete")); + } + }); + } else { + TopicPartition inputPartition = new TopicPartition(topicName, parsedArgs.getInt("inputPartition")); + consumer.assign(singleton(inputPartition)); + remainingMessages.set(Math.min(messagesRemaining(consumer, inputPartition), remainingMessages.get())); + } + + final boolean enableRandomAborts = parsedArgs.getBoolean("enableRandomAborts"); + + producer.initTransactions(); + + final AtomicBoolean isShuttingDown = new AtomicBoolean(false); + + Exit.addShutdownHook("transactional-message-copier-shutdown-hook", () -> { + isShuttingDown.set(true); + consumer.wakeup(); + System.out.println(shutDownString(totalMessageProcessed.get(), + numMessagesProcessedSinceLastRebalance.get(), remainingMessages.get(), transactionalId)); + }); + + final boolean useGroupMetadata = parsedArgs.getBoolean("useGroupMetadata"); + try { + Random random = new Random(); + while (!isShuttingDown.get() && remainingMessages.get() > 0) { + System.out.println(statusAsJson(totalMessageProcessed.get(), + numMessagesProcessedSinceLastRebalance.get(), remainingMessages.get(), transactionalId, "ProcessLoop")); + + ConsumerRecords records = consumer.poll(Duration.ofMillis(200)); + if (records.count() > 0) { + try { + producer.beginTransaction(); + + for (ConsumerRecord record : records) { + producer.send(producerRecordFromConsumerRecord(outputTopic, record)); + } + + long messagesSentWithinCurrentTxn = records.count(); + + ConsumerGroupMetadata groupMetadata = useGroupMetadata ? consumer.groupMetadata() : new ConsumerGroupMetadata(consumerGroup); + producer.sendOffsetsToTransaction(consumerPositions(consumer), groupMetadata); + + if (enableRandomAborts && random.nextInt() % 3 == 0) { + abortTransactionAndResetPosition(producer, consumer); + } else { + producer.commitTransaction(); + remainingMessages.getAndAdd(-messagesSentWithinCurrentTxn); + numMessagesProcessedSinceLastRebalance.getAndAdd(messagesSentWithinCurrentTxn); + totalMessageProcessed.getAndAdd(messagesSentWithinCurrentTxn); + } + } catch (ProducerFencedException e) { + throw new KafkaException(String.format("The transactional.id %s has been claimed by another process", transactionalId), e); + } catch (KafkaException e) { + log.debug("Aborting transaction after catching exception", e); + abortTransactionAndResetPosition(producer, consumer); + } + } + } + } catch (WakeupException e) { + if (!isShuttingDown.get()) { + // Let the exception propagate if the exception was not raised + // as part of shutdown. + throw e; + } + } finally { + Utils.closeQuietly(producer, "producer"); + Utils.closeQuietly(consumer, "consumer"); + } + } +} diff --git a/tools/src/main/java/org/apache/kafka/tools/TransactionsCommand.java b/tools/src/main/java/org/apache/kafka/tools/TransactionsCommand.java new file mode 100644 index 0000000..92e713a --- /dev/null +++ b/tools/src/main/java/org/apache/kafka/tools/TransactionsCommand.java @@ -0,0 +1,1066 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.tools; + +import net.sourceforge.argparse4j.ArgumentParsers; +import net.sourceforge.argparse4j.inf.ArgumentGroup; +import net.sourceforge.argparse4j.inf.ArgumentParser; +import net.sourceforge.argparse4j.inf.ArgumentParserException; +import net.sourceforge.argparse4j.inf.Namespace; +import net.sourceforge.argparse4j.inf.Subparser; +import net.sourceforge.argparse4j.inf.Subparsers; +import org.apache.kafka.clients.admin.AbortTransactionSpec; +import org.apache.kafka.clients.admin.Admin; +import org.apache.kafka.clients.admin.AdminClientConfig; +import org.apache.kafka.clients.admin.DescribeProducersOptions; +import org.apache.kafka.clients.admin.DescribeProducersResult; +import org.apache.kafka.clients.admin.DescribeTransactionsResult; +import org.apache.kafka.clients.admin.ListTopicsOptions; +import org.apache.kafka.clients.admin.ListTransactionsOptions; +import org.apache.kafka.clients.admin.ProducerState; +import org.apache.kafka.clients.admin.TopicDescription; +import org.apache.kafka.clients.admin.TransactionDescription; +import org.apache.kafka.clients.admin.TransactionListing; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.TopicPartitionInfo; +import org.apache.kafka.common.errors.TransactionalIdNotFoundException; +import org.apache.kafka.common.utils.Exit; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.common.utils.Utils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.io.PrintStream; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.OptionalLong; +import java.util.Properties; +import java.util.Set; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.function.Function; +import java.util.stream.Collectors; + +import static java.util.Collections.singleton; +import static java.util.Collections.singletonList; +import static net.sourceforge.argparse4j.impl.Arguments.store; + +public abstract class TransactionsCommand { + private static final Logger log = LoggerFactory.getLogger(TransactionsCommand.class); + + protected final Time time; + + protected TransactionsCommand(Time time) { + this.time = time; + } + + /** + * Get the name of this command (e.g. `describe-producers`). + */ + abstract String name(); + + /** + * Specify the arguments needed for this command. + */ + abstract void addSubparser(Subparsers subparsers); + + /** + * Execute the command logic. 
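+     *
+     * As a purely illustrative sketch (this subcommand does not exist in the tool), a minimal
+     * command wires name(), addSubparser(Subparsers) and execute() together roughly as follows,
+     * and would additionally need to be registered in the command list of the static
+     * execute entry point:
+     * {@code
+     *  static class HelloCommand extends TransactionsCommand {
+     *      HelloCommand(Time time) {
+     *          super(time);
+     *      }
+     *
+     *      String name() {
+     *          return "hello";
+     *      }
+     *
+     *      void addSubparser(Subparsers subparsers) {
+     *          subparsers.addParser(name()).help("print a greeting (illustrative only)");
+     *      }
+     *
+     *      void execute(Admin admin, Namespace ns, PrintStream out) {
+     *          out.println("hello");
+     *      }
+     *  }
+     * }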
+ */ + abstract void execute(Admin admin, Namespace ns, PrintStream out) throws Exception; + + + static class AbortTransactionCommand extends TransactionsCommand { + + AbortTransactionCommand(Time time) { + super(time); + } + + @Override + String name() { + return "abort"; + } + + @Override + void addSubparser(Subparsers subparsers) { + Subparser subparser = subparsers.addParser(name()) + .help("abort a hanging transaction (requires administrative privileges)"); + + subparser.addArgument("--topic") + .help("topic name") + .action(store()) + .type(String.class) + .required(true); + + subparser.addArgument("--partition") + .help("partition number") + .action(store()) + .type(Integer.class) + .required(true); + + ArgumentGroup newBrokerArgumentGroup = subparser + .addArgumentGroup("Brokers on versions 3.0 and above") + .description("For newer brokers, only the start offset of the transaction " + + "to be aborted is required"); + + newBrokerArgumentGroup.addArgument("--start-offset") + .help("start offset of the transaction to abort") + .action(store()) + .type(Long.class); + + ArgumentGroup olderBrokerArgumentGroup = subparser + .addArgumentGroup("Brokers on versions older than 3.0") + .description("For older brokers, you must provide all of these arguments"); + + olderBrokerArgumentGroup.addArgument("--producer-id") + .help("producer id") + .action(store()) + .type(Long.class); + + olderBrokerArgumentGroup.addArgument("--producer-epoch") + .help("producer epoch") + .action(store()) + .type(Short.class); + + olderBrokerArgumentGroup.addArgument("--coordinator-epoch") + .help("coordinator epoch") + .action(store()) + .type(Integer.class); + } + + private AbortTransactionSpec buildAbortSpec( + Admin admin, + TopicPartition topicPartition, + long startOffset + ) throws Exception { + final DescribeProducersResult.PartitionProducerState result; + try { + result = admin.describeProducers(singleton(topicPartition)) + .partitionResult(topicPartition) + .get(); + } catch (ExecutionException e) { + printErrorAndExit("Failed to validate producer state for partition " + + topicPartition, e.getCause()); + return null; + } + + Optional foundProducerState = result.activeProducers().stream() + .filter(producerState -> { + OptionalLong txnStartOffsetOpt = producerState.currentTransactionStartOffset(); + return txnStartOffsetOpt.isPresent() && txnStartOffsetOpt.getAsLong() == startOffset; + }) + .findFirst(); + + if (!foundProducerState.isPresent()) { + printErrorAndExit("Could not find any open transactions starting at offset " + + startOffset + " on partition " + topicPartition); + return null; + } + + ProducerState producerState = foundProducerState.get(); + return new AbortTransactionSpec( + topicPartition, + producerState.producerId(), + (short) producerState.producerEpoch(), + producerState.coordinatorEpoch().orElse(0) + ); + } + + private void abortTransaction( + Admin admin, + AbortTransactionSpec abortSpec + ) throws Exception { + try { + admin.abortTransaction(abortSpec).all().get(); + } catch (ExecutionException e) { + TransactionsCommand.printErrorAndExit("Failed to abort transaction " + abortSpec, e.getCause()); + } + } + + @Override + void execute(Admin admin, Namespace ns, PrintStream out) throws Exception { + String topicName = ns.getString("topic"); + Integer partitionId = ns.getInt("partition"); + TopicPartition topicPartition = new TopicPartition(topicName, partitionId); + + Long startOffset = ns.getLong("start_offset"); + Long producerId = ns.getLong("producer_id"); + + if (startOffset == 
null && producerId == null) { + printErrorAndExit("The transaction to abort must be identified either with " + + "--start-offset (for brokers on 3.0 or above) or with " + + "--producer-id, --producer-epoch, and --coordinator-epoch (for older brokers)"); + return; + } + + final AbortTransactionSpec abortSpec; + if (startOffset == null) { + Short producerEpoch = ns.getShort("producer_epoch"); + if (producerEpoch == null) { + printErrorAndExit("Missing required argument --producer-epoch"); + return; + } + + Integer coordinatorEpoch = ns.getInt("coordinator_epoch"); + if (coordinatorEpoch == null) { + printErrorAndExit("Missing required argument --coordinator-epoch"); + return; + } + + // If a transaction was started by a new producerId and became hanging + // before the initial commit/abort, then the coordinator epoch will be -1 + // as seen in the `DescribeProducers` output. In this case, we conservatively + // use a coordinator epoch of 0, which is less than or equal to any possible + // leader epoch. + if (coordinatorEpoch < 0) { + coordinatorEpoch = 0; + } + + abortSpec = new AbortTransactionSpec( + topicPartition, + producerId, + producerEpoch, + coordinatorEpoch + ); + } else { + abortSpec = buildAbortSpec(admin, topicPartition, startOffset); + } + + abortTransaction(admin, abortSpec); + } + } + + static class DescribeProducersCommand extends TransactionsCommand { + static final String[] HEADERS = new String[]{ + "ProducerId", + "ProducerEpoch", + "LatestCoordinatorEpoch", + "LastSequence", + "LastTimestamp", + "CurrentTransactionStartOffset" + }; + + DescribeProducersCommand(Time time) { + super(time); + } + + @Override + public String name() { + return "describe-producers"; + } + + @Override + public void addSubparser(Subparsers subparsers) { + Subparser subparser = subparsers.addParser(name()) + .help("describe the states of active producers for a topic partition"); + + subparser.addArgument("--broker-id") + .help("optional broker id to describe the producer state on a specific replica") + .action(store()) + .type(Integer.class) + .required(false); + + subparser.addArgument("--topic") + .help("topic name") + .action(store()) + .type(String.class) + .required(true); + + subparser.addArgument("--partition") + .help("partition number") + .action(store()) + .type(Integer.class) + .required(true); + } + + @Override + public void execute(Admin admin, Namespace ns, PrintStream out) throws Exception { + DescribeProducersOptions options = new DescribeProducersOptions(); + Optional.ofNullable(ns.getInt("broker_id")).ifPresent(options::brokerId); + + String topicName = ns.getString("topic"); + Integer partitionId = ns.getInt("partition"); + TopicPartition topicPartition = new TopicPartition(topicName, partitionId); + + final DescribeProducersResult.PartitionProducerState result; + + try { + result = admin.describeProducers(singleton(topicPartition), options) + .partitionResult(topicPartition) + .get(); + } catch (ExecutionException e) { + String brokerClause = options.brokerId().isPresent() ? + "broker " + options.brokerId().getAsInt() : + "leader"; + printErrorAndExit("Failed to describe producers for partition " + + topicPartition + " on " + brokerClause, e.getCause()); + return; + } + + List rows = result.activeProducers().stream().map(producerState -> { + String currentTransactionStartOffsetColumnValue = + producerState.currentTransactionStartOffset().isPresent() ? 
+ String.valueOf(producerState.currentTransactionStartOffset().getAsLong()) : + "None"; + + return new String[] { + String.valueOf(producerState.producerId()), + String.valueOf(producerState.producerEpoch()), + String.valueOf(producerState.coordinatorEpoch().orElse(-1)), + String.valueOf(producerState.lastSequence()), + String.valueOf(producerState.lastTimestamp()), + currentTransactionStartOffsetColumnValue + }; + }).collect(Collectors.toList()); + + prettyPrintTable(HEADERS, rows, out); + } + } + + static class DescribeTransactionsCommand extends TransactionsCommand { + static final String[] HEADERS = new String[]{ + "CoordinatorId", + "TransactionalId", + "ProducerId", + "ProducerEpoch", + "TransactionState", + "TransactionTimeoutMs", + "CurrentTransactionStartTimeMs", + "TransactionDurationMs", + "TopicPartitions" + }; + + DescribeTransactionsCommand(Time time) { + super(time); + } + + @Override + public String name() { + return "describe"; + } + + @Override + public void addSubparser(Subparsers subparsers) { + Subparser subparser = subparsers.addParser(name()) + .description("Describe the state of an active transactional-id.") + .help("describe the state of an active transactional-id"); + + subparser.addArgument("--transactional-id") + .help("transactional id") + .action(store()) + .type(String.class) + .required(true); + } + + @Override + public void execute(Admin admin, Namespace ns, PrintStream out) throws Exception { + String transactionalId = ns.getString("transactional_id"); + + final TransactionDescription result; + try { + result = admin.describeTransactions(singleton(transactionalId)) + .description(transactionalId) + .get(); + } catch (ExecutionException e) { + printErrorAndExit("Failed to describe transaction state of " + + "transactional-id `" + transactionalId + "`", e.getCause()); + return; + } + + final String transactionDurationMsColumnValue; + final String transactionStartTimeMsColumnValue; + + if (result.transactionStartTimeMs().isPresent()) { + long transactionStartTimeMs = result.transactionStartTimeMs().getAsLong(); + transactionStartTimeMsColumnValue = String.valueOf(transactionStartTimeMs); + transactionDurationMsColumnValue = String.valueOf(time.milliseconds() - transactionStartTimeMs); + } else { + transactionStartTimeMsColumnValue = "None"; + transactionDurationMsColumnValue = "None"; + } + + String[] row = new String[]{ + String.valueOf(result.coordinatorId()), + transactionalId, + String.valueOf(result.producerId()), + String.valueOf(result.producerEpoch()), + result.state().toString(), + String.valueOf(result.transactionTimeoutMs()), + transactionStartTimeMsColumnValue, + transactionDurationMsColumnValue, + Utils.join(result.topicPartitions(), ",") + }; + + prettyPrintTable(HEADERS, singletonList(row), out); + } + } + + static class ListTransactionsCommand extends TransactionsCommand { + static final String[] HEADERS = new String[] { + "TransactionalId", + "Coordinator", + "ProducerId", + "TransactionState" + }; + + ListTransactionsCommand(Time time) { + super(time); + } + + @Override + public String name() { + return "list"; + } + + @Override + public void addSubparser(Subparsers subparsers) { + subparsers.addParser(name()) + .help("list transactions"); + } + + @Override + public void execute(Admin admin, Namespace ns, PrintStream out) throws Exception { + final Map> result; + + try { + result = admin.listTransactions(new ListTransactionsOptions()) + .allByBrokerId() + .get(); + } catch (ExecutionException e) { + printErrorAndExit("Failed to list 
transactions", e.getCause()); + return; + } + + List rows = new ArrayList<>(); + for (Map.Entry> brokerListingsEntry : result.entrySet()) { + String coordinatorIdString = brokerListingsEntry.getKey().toString(); + Collection listings = brokerListingsEntry.getValue(); + + for (TransactionListing listing : listings) { + rows.add(new String[] { + listing.transactionalId(), + coordinatorIdString, + String.valueOf(listing.producerId()), + listing.state().toString() + }); + } + } + + prettyPrintTable(HEADERS, rows, out); + } + } + + static class FindHangingTransactionsCommand extends TransactionsCommand { + private static final int MAX_BATCH_SIZE = 500; + + static final String[] HEADERS = new String[] { + "Topic", + "Partition", + "ProducerId", + "ProducerEpoch", + "CoordinatorEpoch", + "StartOffset", + "LastTimestamp", + "Duration(min)" + }; + + FindHangingTransactionsCommand(Time time) { + super(time); + } + + @Override + String name() { + return "find-hanging"; + } + + @Override + void addSubparser(Subparsers subparsers) { + Subparser subparser = subparsers.addParser(name()) + .help("find hanging transactions"); + + subparser.addArgument("--broker-id") + .help("broker id to search for hanging transactions") + .action(store()) + .type(Integer.class) + .required(false); + + subparser.addArgument("--max-transaction-timeout") + .help("maximum transaction timeout in minutes to limit the scope of the search (15 minutes by default)") + .action(store()) + .type(Integer.class) + .setDefault(15) + .required(false); + + subparser.addArgument("--topic") + .help("topic name to limit search to (required if --partition is specified)") + .action(store()) + .type(String.class) + .required(false); + + subparser.addArgument("--partition") + .help("partition number") + .action(store()) + .type(Integer.class) + .required(false); + } + + @Override + void execute(Admin admin, Namespace ns, PrintStream out) throws Exception { + Optional brokerId = Optional.ofNullable(ns.getInt("broker_id")); + Optional topic = Optional.ofNullable(ns.getString("topic")); + + if (!topic.isPresent() && !brokerId.isPresent()) { + printErrorAndExit("The `find-hanging` command requires either --topic " + + "or --broker-id to limit the scope of the search"); + return; + } + + Optional partition = Optional.ofNullable(ns.getInt("partition")); + if (partition.isPresent() && !topic.isPresent()) { + printErrorAndExit("The --partition argument requires --topic to be provided"); + return; + } + + long maxTransactionTimeoutMs = TimeUnit.MINUTES.toMillis( + ns.getInt("max_transaction_timeout")); + + List topicPartitions = collectTopicPartitionsToSearch( + admin, + topic, + partition, + brokerId + ); + + List candidates = collectCandidateOpenTransactions( + admin, + brokerId, + maxTransactionTimeoutMs, + topicPartitions + ); + + if (candidates.isEmpty()) { + printHangingTransactions(Collections.emptyList(), out); + } else { + Map> openTransactionsByProducerId = groupByProducerId(candidates); + + Map transactionalIds = lookupTransactionalIds( + admin, + openTransactionsByProducerId.keySet() + ); + + Map descriptions = describeTransactions( + admin, + transactionalIds.values() + ); + + List hangingTransactions = filterHangingTransactions( + openTransactionsByProducerId, + transactionalIds, + descriptions + ); + + printHangingTransactions(hangingTransactions, out); + } + } + + private List collectTopicPartitionsToSearch( + Admin admin, + Optional topic, + Optional partition, + Optional brokerId + ) throws Exception { + final List topics; + + if 
(topic.isPresent()) { + if (partition.isPresent()) { + return Collections.singletonList(new TopicPartition(topic.get(), partition.get())); + } else { + topics = Collections.singletonList(topic.get()); + } + } else { + topics = listTopics(admin); + } + + return findTopicPartitions( + admin, + brokerId, + topics + ); + } + + private List filterHangingTransactions( + Map> openTransactionsByProducerId, + Map transactionalIds, + Map descriptions + ) { + List hangingTransactions = new ArrayList<>(); + + openTransactionsByProducerId.forEach((producerId, openTransactions) -> { + String transactionalId = transactionalIds.get(producerId); + if (transactionalId == null) { + // If we could not find the transactionalId corresponding to the + // producerId of an open transaction, then the transaction is hanging. + hangingTransactions.addAll(openTransactions); + } else { + // Otherwise, we need to check the current transaction state. + TransactionDescription description = descriptions.get(transactionalId); + if (description == null) { + hangingTransactions.addAll(openTransactions); + } else { + for (OpenTransaction openTransaction : openTransactions) { + // The `DescribeTransactions` API returns all partitions being + // written to in an ongoing transaction and any partition which + // does not yet have markers written when in the `PendingAbort` or + // `PendingCommit` states. If the topic partition that we found is + // among these, then we can still expect the coordinator to write + // the marker. Otherwise, it is a hanging transaction. + if (!description.topicPartitions().contains(openTransaction.topicPartition)) { + hangingTransactions.add(openTransaction); + } + } + } + } + }); + + return hangingTransactions; + } + + private void printHangingTransactions( + List hangingTransactions, + PrintStream out + ) { + long currentTimeMs = time.milliseconds(); + List rows = new ArrayList<>(hangingTransactions.size()); + + for (OpenTransaction transaction : hangingTransactions) { + long transactionDurationMinutes = TimeUnit.MILLISECONDS.toMinutes( + currentTimeMs - transaction.producerState.lastTimestamp()); + + rows.add(new String[] { + transaction.topicPartition.topic(), + String.valueOf(transaction.topicPartition.partition()), + String.valueOf(transaction.producerState.producerId()), + String.valueOf(transaction.producerState.producerEpoch()), + String.valueOf(transaction.producerState.coordinatorEpoch().orElse(-1)), + String.valueOf(transaction.producerState.currentTransactionStartOffset().orElse(-1)), + String.valueOf(transaction.producerState.lastTimestamp()), + String.valueOf(transactionDurationMinutes) + }); + } + + prettyPrintTable(HEADERS, rows, out); + } + + private Map describeTransactions( + Admin admin, + Collection transactionalIds + ) throws Exception { + try { + DescribeTransactionsResult result = admin.describeTransactions(new HashSet<>(transactionalIds)); + Map descriptions = new HashMap<>(); + + for (String transactionalId : transactionalIds) { + try { + TransactionDescription description = result.description(transactionalId).get(); + descriptions.put(transactionalId, description); + } catch (ExecutionException e) { + if (e.getCause() instanceof TransactionalIdNotFoundException) { + descriptions.put(transactionalId, null); + } else { + throw e; + } + } + } + + return descriptions; + } catch (ExecutionException e) { + printErrorAndExit("Failed to describe " + transactionalIds.size() + + " transactions", e.getCause()); + return Collections.emptyMap(); + } + } + + private Map> 
groupByProducerId( + List openTransactions + ) { + Map> res = new HashMap<>(); + for (OpenTransaction transaction : openTransactions) { + List states = res.computeIfAbsent( + transaction.producerState.producerId(), + __ -> new ArrayList<>() + ); + states.add(transaction); + } + return res; + } + + private List listTopics( + Admin admin + ) throws Exception { + try { + ListTopicsOptions listOptions = new ListTopicsOptions().listInternal(true); + return new ArrayList<>(admin.listTopics(listOptions).names().get()); + } catch (ExecutionException e) { + printErrorAndExit("Failed to list topics", e.getCause()); + return Collections.emptyList(); + } + } + + private List findTopicPartitions( + Admin admin, + Optional brokerId, + List topics + ) throws Exception { + List topicPartitions = new ArrayList<>(); + consumeInBatches(topics, MAX_BATCH_SIZE, batch -> { + findTopicPartitions( + admin, + brokerId, + batch, + topicPartitions + ); + }); + return topicPartitions; + } + + private void findTopicPartitions( + Admin admin, + Optional brokerId, + List topics, + List topicPartitions + ) throws Exception { + try { + Map topicDescriptions = admin.describeTopics(topics).allTopicNames().get(); + topicDescriptions.forEach((topic, description) -> { + description.partitions().forEach(partitionInfo -> { + if (!brokerId.isPresent() || hasReplica(brokerId.get(), partitionInfo)) { + topicPartitions.add(new TopicPartition(topic, partitionInfo.partition())); + } + }); + }); + } catch (ExecutionException e) { + printErrorAndExit("Failed to describe " + topics.size() + " topics", e.getCause()); + } + } + + private boolean hasReplica( + int brokerId, + TopicPartitionInfo partitionInfo + ) { + return partitionInfo.replicas().stream().anyMatch(node -> node.id() == brokerId); + } + + private List collectCandidateOpenTransactions( + Admin admin, + Optional brokerId, + long maxTransactionTimeoutMs, + List topicPartitions + ) throws Exception { + // We have to check all partitions on the broker. In order to avoid + // overwhelming it with a giant request, we break the requests into + // smaller batches. 
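+            // For reference, the consumeInBatches helper used below (and defined later in this
+            // class) is just a subList loop; a standalone sketch of the same idea:
+            //
+            //   for (int start = 0; start < topicPartitions.size(); start += MAX_BATCH_SIZE) {
+            //       List<TopicPartition> batch = topicPartitions.subList(
+            //           start, Math.min(topicPartitions.size(), start + MAX_BATCH_SIZE));
+            //       // issue one DescribeProducers request for this batch only
+            //   }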
+ + List candidateTransactions = new ArrayList<>(); + + consumeInBatches(topicPartitions, MAX_BATCH_SIZE, batch -> { + collectCandidateOpenTransactions( + admin, + brokerId, + maxTransactionTimeoutMs, + batch, + candidateTransactions + ); + }); + + return candidateTransactions; + } + + private static class OpenTransaction { + private final TopicPartition topicPartition; + private final ProducerState producerState; + + private OpenTransaction( + TopicPartition topicPartition, + ProducerState producerState + ) { + this.topicPartition = topicPartition; + this.producerState = producerState; + } + } + + private void collectCandidateOpenTransactions( + Admin admin, + Optional brokerId, + long maxTransactionTimeoutMs, + List topicPartitions, + List candidateTransactions + ) throws Exception { + try { + DescribeProducersOptions describeOptions = new DescribeProducersOptions(); + brokerId.ifPresent(describeOptions::brokerId); + + Map producersByPartition = + admin.describeProducers(topicPartitions, describeOptions).all().get(); + + long currentTimeMs = time.milliseconds(); + + producersByPartition.forEach((topicPartition, producersStates) -> { + producersStates.activeProducers().forEach(activeProducer -> { + if (activeProducer.currentTransactionStartOffset().isPresent()) { + long transactionDurationMs = currentTimeMs - activeProducer.lastTimestamp(); + if (transactionDurationMs > maxTransactionTimeoutMs) { + candidateTransactions.add(new OpenTransaction( + topicPartition, + activeProducer + )); + } + } + }); + }); + } catch (ExecutionException e) { + printErrorAndExit("Failed to describe producers for " + topicPartitions.size() + + " partitions on broker " + brokerId, e.getCause()); + } + } + + private Map lookupTransactionalIds( + Admin admin, + Set producerIds + ) throws Exception { + try { + ListTransactionsOptions listTransactionsOptions = new ListTransactionsOptions() + .filterProducerIds(producerIds); + + Collection transactionListings = + admin.listTransactions(listTransactionsOptions).all().get(); + + Map transactionalIdMap = new HashMap<>(); + + transactionListings.forEach(listing -> { + if (!producerIds.contains(listing.producerId())) { + log.debug("Received transaction listing {} which has a producerId " + + "which was not requested", listing); + } else { + transactionalIdMap.put( + listing.producerId(), + listing.transactionalId() + ); + } + }); + + return transactionalIdMap; + } catch (ExecutionException e) { + printErrorAndExit("Failed to list transactions for " + producerIds.size() + + " producers", e.getCause()); + return Collections.emptyMap(); + } + } + + @FunctionalInterface + private interface ThrowableConsumer { + void accept(T t) throws Exception; + } + + private void consumeInBatches( + List list, + int batchSize, + ThrowableConsumer> consumer + ) throws Exception { + int batchStartIndex = 0; + int limitIndex = list.size(); + + while (batchStartIndex < limitIndex) { + int batchEndIndex = Math.min( + limitIndex, + batchStartIndex + batchSize + ); + + consumer.accept(list.subList(batchStartIndex, batchEndIndex)); + batchStartIndex = batchEndIndex; + } + } + } + + private static void appendColumnValue( + StringBuilder rowBuilder, + String value, + int length + ) { + int padLength = length - value.length(); + rowBuilder.append(value); + for (int i = 0; i < padLength; i++) + rowBuilder.append(' '); + } + + private static void printRow( + List columnLengths, + String[] row, + PrintStream out + ) { + StringBuilder rowBuilder = new StringBuilder(); + for (int i = 0; i < row.length; 
i++) { + Integer columnLength = columnLengths.get(i); + String columnValue = row[i]; + appendColumnValue(rowBuilder, columnValue, columnLength); + rowBuilder.append('\t'); + } + out.println(rowBuilder); + } + + private static void prettyPrintTable( + String[] headers, + List rows, + PrintStream out + ) { + List columnLengths = Arrays.stream(headers) + .map(String::length) + .collect(Collectors.toList()); + + for (String[] row : rows) { + for (int i = 0; i < headers.length; i++) { + columnLengths.set(i, Math.max(columnLengths.get(i), row[i].length())); + } + } + + printRow(columnLengths, headers, out); + rows.forEach(row -> printRow(columnLengths, row, out)); + } + + private static void printErrorAndExit(String message, Throwable t) { + log.debug(message, t); + + String exitMessage = message + ": " + t.getMessage() + "." + + " Enable debug logging for additional detail."; + + printErrorAndExit(exitMessage); + } + + private static void printErrorAndExit(String message) { + System.err.println(message); + Exit.exit(1, message); + } + + private static Admin buildAdminClient(Namespace ns) { + final Properties properties; + + String configFile = ns.getString("command_config"); + if (configFile == null) { + properties = new Properties(); + } else { + try { + properties = Utils.loadProps(configFile); + } catch (IOException e) { + printErrorAndExit("Failed to load admin client properties", e); + return null; + } + } + + String bootstrapServers = ns.getString("bootstrap_server"); + properties.put(AdminClientConfig.BOOTSTRAP_SERVERS_CONFIG, bootstrapServers); + + return Admin.create(properties); + } + + static ArgumentParser buildBaseParser() { + ArgumentParser parser = ArgumentParsers + .newArgumentParser("kafka-transactions.sh"); + + parser.description("This tool is used to analyze the transactional state of producers in the cluster. 
" + + "It can be used to detect and recover from hanging transactions."); + + parser.addArgument("-v", "--version") + .action(new PrintVersionAndExitAction()) + .help("show the version of this Kafka distribution and exit"); + + parser.addArgument("--command-config") + .help("property file containing configs to be passed to admin client") + .action(store()) + .type(String.class) + .metavar("FILE") + .required(false); + + parser.addArgument("--bootstrap-server") + .help("hostname and port for the broker to connect to, in the form `host:port` " + + "(multiple comma-separated entries can be given)") + .action(store()) + .type(String.class) + .metavar("host:port") + .required(true); + + return parser; + } + + static void execute( + String[] args, + Function adminSupplier, + PrintStream out, + Time time + ) throws Exception { + List commands = Arrays.asList( + new ListTransactionsCommand(time), + new DescribeTransactionsCommand(time), + new DescribeProducersCommand(time), + new AbortTransactionCommand(time), + new FindHangingTransactionsCommand(time) + ); + + ArgumentParser parser = buildBaseParser(); + Subparsers subparsers = parser.addSubparsers() + .dest("command") + .title("commands") + .metavar("COMMAND"); + commands.forEach(command -> command.addSubparser(subparsers)); + + final Namespace ns; + + try { + ns = parser.parseArgs(args); + } catch (ArgumentParserException e) { + parser.handleError(e); + Exit.exit(1); + return; + } + + Admin admin = adminSupplier.apply(ns); + String commandName = ns.getString("command"); + + Optional commandOpt = commands.stream() + .filter(cmd -> cmd.name().equals(commandName)) + .findFirst(); + + if (!commandOpt.isPresent()) { + printErrorAndExit("Unexpected command " + commandName); + } + + TransactionsCommand command = commandOpt.get(); + command.execute(admin, ns, out); + Exit.exit(0); + } + + public static void main(String[] args) throws Exception { + execute(args, TransactionsCommand::buildAdminClient, System.out, Time.SYSTEM); + } + +} diff --git a/tools/src/main/java/org/apache/kafka/tools/VerifiableConsumer.java b/tools/src/main/java/org/apache/kafka/tools/VerifiableConsumer.java new file mode 100644 index 0000000..e930580 --- /dev/null +++ b/tools/src/main/java/org/apache/kafka/tools/VerifiableConsumer.java @@ -0,0 +1,682 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.tools; + +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.annotation.JsonPropertyOrder; +import com.fasterxml.jackson.core.JsonGenerator; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.JsonSerializer; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.SerializerProvider; +import com.fasterxml.jackson.databind.module.SimpleModule; +import net.sourceforge.argparse4j.ArgumentParsers; +import net.sourceforge.argparse4j.inf.ArgumentParser; +import net.sourceforge.argparse4j.inf.ArgumentParserException; +import net.sourceforge.argparse4j.inf.MutuallyExclusiveGroup; +import net.sourceforge.argparse4j.inf.Namespace; +import org.apache.kafka.clients.consumer.ConsumerConfig; +import org.apache.kafka.clients.consumer.ConsumerRebalanceListener; +import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.clients.consumer.ConsumerRecords; +import org.apache.kafka.clients.consumer.KafkaConsumer; +import org.apache.kafka.clients.consumer.OffsetAndMetadata; +import org.apache.kafka.clients.consumer.OffsetCommitCallback; +import org.apache.kafka.clients.consumer.RangeAssignor; +import org.apache.kafka.clients.consumer.RoundRobinAssignor; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.errors.FencedInstanceIdException; +import org.apache.kafka.common.errors.WakeupException; +import org.apache.kafka.common.serialization.StringDeserializer; +import org.apache.kafka.common.utils.Utils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.Closeable; +import java.io.IOException; +import java.io.PrintStream; +import java.time.Duration; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Properties; +import java.util.concurrent.CountDownLatch; + +import static net.sourceforge.argparse4j.impl.Arguments.store; +import static net.sourceforge.argparse4j.impl.Arguments.storeTrue; + +/** + * Command line consumer designed for system testing. It outputs consumer events to STDOUT as JSON + * formatted objects. The "name" field in each JSON event identifies the event type. The following + * events are currently supported: + * + *
                  + *
+ * <ul>
+ * <li>partitions_revoked: outputs the partitions revoked through {@link ConsumerRebalanceListener#onPartitionsRevoked(Collection)}.
+ *     See {@link org.apache.kafka.tools.VerifiableConsumer.PartitionsRevoked}</li>
+ * <li>partitions_assigned: outputs the partitions assigned through {@link ConsumerRebalanceListener#onPartitionsAssigned(Collection)}.
+ *     See {@link org.apache.kafka.tools.VerifiableConsumer.PartitionsAssigned}</li>
+ * <li>records_consumed: contains a summary of records consumed in a single call to {@link KafkaConsumer#poll(Duration)}.
+ *     See {@link org.apache.kafka.tools.VerifiableConsumer.RecordsConsumed}</li>
+ * <li>record_data: contains the key, value, and offset of an individual consumed record (only included if verbose
+ *     output is enabled). See {@link org.apache.kafka.tools.VerifiableConsumer.RecordData}</li>
+ * <li>offsets_committed: the result of every offset commit (only included if auto-commit is not enabled).
+ *     See {@link org.apache.kafka.tools.VerifiableConsumer.OffsetsCommitted}</li>
+ * <li>shutdown_complete: emitted after the consumer returns from {@link KafkaConsumer#close()}.
+ *     See {@link org.apache.kafka.tools.VerifiableConsumer.ShutdownComplete}</li>
+ * </ul>
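+ * <p>
+ * For illustration only (all field values below are made up, and the exact field order depends on the
+ * Jackson annotations of the event classes defined in this file), a records_consumed event printed to
+ * STDOUT looks roughly like:
+ * <pre>
+ * {"timestamp":1625000000000,"name":"records_consumed","count":2,
+ *  "partitions":[{"topic":"test","partition":0,"count":2,"minOffset":10,"maxOffset":11}]}
+ * </pre>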
                + */ +public class VerifiableConsumer implements Closeable, OffsetCommitCallback, ConsumerRebalanceListener { + + private static final Logger log = LoggerFactory.getLogger(VerifiableConsumer.class); + + private final ObjectMapper mapper = new ObjectMapper(); + private final PrintStream out; + private final KafkaConsumer consumer; + private final String topic; + private final boolean useAutoCommit; + private final boolean useAsyncCommit; + private final boolean verbose; + private final int maxMessages; + private int consumedMessages = 0; + + private CountDownLatch shutdownLatch = new CountDownLatch(1); + + public VerifiableConsumer(KafkaConsumer consumer, + PrintStream out, + String topic, + int maxMessages, + boolean useAutoCommit, + boolean useAsyncCommit, + boolean verbose) { + this.consumer = consumer; + this.out = out; + this.topic = topic; + this.maxMessages = maxMessages; + this.useAutoCommit = useAutoCommit; + this.useAsyncCommit = useAsyncCommit; + this.verbose = verbose; + addKafkaSerializerModule(); + } + + private void addKafkaSerializerModule() { + SimpleModule kafka = new SimpleModule(); + kafka.addSerializer(TopicPartition.class, new JsonSerializer() { + @Override + public void serialize(TopicPartition tp, JsonGenerator gen, SerializerProvider serializers) throws IOException { + gen.writeStartObject(); + gen.writeObjectField("topic", tp.topic()); + gen.writeObjectField("partition", tp.partition()); + gen.writeEndObject(); + } + }); + mapper.registerModule(kafka); + } + + private boolean hasMessageLimit() { + return maxMessages >= 0; + } + + private boolean isFinished() { + return hasMessageLimit() && consumedMessages >= maxMessages; + } + + private Map onRecordsReceived(ConsumerRecords records) { + Map offsets = new HashMap<>(); + + List summaries = new ArrayList<>(); + for (TopicPartition tp : records.partitions()) { + List> partitionRecords = records.records(tp); + + if (hasMessageLimit() && consumedMessages + partitionRecords.size() > maxMessages) + partitionRecords = partitionRecords.subList(0, maxMessages - consumedMessages); + + if (partitionRecords.isEmpty()) + continue; + + long minOffset = partitionRecords.get(0).offset(); + long maxOffset = partitionRecords.get(partitionRecords.size() - 1).offset(); + + offsets.put(tp, new OffsetAndMetadata(maxOffset + 1)); + summaries.add(new RecordSetSummary(tp.topic(), tp.partition(), + partitionRecords.size(), minOffset, maxOffset)); + + if (verbose) { + for (ConsumerRecord record : partitionRecords) { + printJson(new RecordData(record)); + } + } + + consumedMessages += partitionRecords.size(); + if (isFinished()) + break; + } + + printJson(new RecordsConsumed(records.count(), summaries)); + return offsets; + } + + @Override + public void onComplete(Map offsets, Exception exception) { + List committedOffsets = new ArrayList<>(); + for (Map.Entry offsetEntry : offsets.entrySet()) { + TopicPartition tp = offsetEntry.getKey(); + committedOffsets.add(new CommitData(tp.topic(), tp.partition(), offsetEntry.getValue().offset())); + } + + boolean success = true; + String error = null; + if (exception != null) { + success = false; + error = exception.getMessage(); + } + printJson(new OffsetsCommitted(committedOffsets, error, success)); + } + + @Override + public void onPartitionsAssigned(Collection partitions) { + printJson(new PartitionsAssigned(partitions)); + } + + @Override + public void onPartitionsRevoked(Collection partitions) { + printJson(new PartitionsRevoked(partitions)); + } + + private void printJson(Object 
data) { + try { + out.println(mapper.writeValueAsString(data)); + } catch (JsonProcessingException e) { + out.println("Bad data can't be written as json: " + e.getMessage()); + } + } + + public void commitSync(Map offsets) { + try { + consumer.commitSync(offsets); + onComplete(offsets, null); + } catch (WakeupException e) { + // we only call wakeup() once to close the consumer, so this recursion should be safe + commitSync(offsets); + throw e; + } catch (FencedInstanceIdException e) { + throw e; + } catch (Exception e) { + onComplete(offsets, e); + } + } + + public void run() { + try { + printJson(new StartupComplete()); + consumer.subscribe(Collections.singletonList(topic), this); + + while (!isFinished()) { + ConsumerRecords records = consumer.poll(Duration.ofMillis(Long.MAX_VALUE)); + Map offsets = onRecordsReceived(records); + + if (!useAutoCommit) { + if (useAsyncCommit) + consumer.commitAsync(offsets, this); + else + commitSync(offsets); + } + } + } catch (WakeupException e) { + // ignore, we are closing + log.trace("Caught WakeupException because consumer is shutdown, ignore and terminate.", e); + } catch (Throwable t) { + // Log the error so it goes to the service log and not stdout + log.error("Error during processing, terminating consumer process: ", t); + } finally { + consumer.close(); + printJson(new ShutdownComplete()); + shutdownLatch.countDown(); + } + } + + public void close() { + boolean interrupted = false; + try { + consumer.wakeup(); + while (true) { + try { + shutdownLatch.await(); + return; + } catch (InterruptedException e) { + interrupted = true; + } + } + } finally { + if (interrupted) + Thread.currentThread().interrupt(); + } + } + + @JsonPropertyOrder({ "timestamp", "name" }) + private static abstract class ConsumerEvent { + private final long timestamp = System.currentTimeMillis(); + + @JsonProperty + public abstract String name(); + + @JsonProperty + public long timestamp() { + return timestamp; + } + } + + private static class StartupComplete extends ConsumerEvent { + + @Override + public String name() { + return "startup_complete"; + } + } + + private static class ShutdownComplete extends ConsumerEvent { + + @Override + public String name() { + return "shutdown_complete"; + } + } + + private static class PartitionsRevoked extends ConsumerEvent { + private final Collection partitions; + + public PartitionsRevoked(Collection partitions) { + this.partitions = partitions; + } + + @JsonProperty + public Collection partitions() { + return partitions; + } + + @Override + public String name() { + return "partitions_revoked"; + } + } + + private static class PartitionsAssigned extends ConsumerEvent { + private final Collection partitions; + + public PartitionsAssigned(Collection partitions) { + this.partitions = partitions; + } + + @JsonProperty + public Collection partitions() { + return partitions; + } + + @Override + public String name() { + return "partitions_assigned"; + } + } + + public static class RecordsConsumed extends ConsumerEvent { + private final long count; + private final List partitionSummaries; + + public RecordsConsumed(long count, List partitionSummaries) { + this.count = count; + this.partitionSummaries = partitionSummaries; + } + + @Override + public String name() { + return "records_consumed"; + } + + @JsonProperty + public long count() { + return count; + } + + @JsonProperty + public List partitions() { + return partitionSummaries; + } + } + + @JsonPropertyOrder({ "timestamp", "name", "key", "value", "topic", "partition", "offset" }) + public 
static class RecordData extends ConsumerEvent { + + private final ConsumerRecord record; + + public RecordData(ConsumerRecord record) { + this.record = record; + } + + @Override + public String name() { + return "record_data"; + } + + @JsonProperty + public String topic() { + return record.topic(); + } + + @JsonProperty + public int partition() { + return record.partition(); + } + + @JsonProperty + public String key() { + return record.key(); + } + + @JsonProperty + public String value() { + return record.value(); + } + + @JsonProperty + public long offset() { + return record.offset(); + } + + } + + private static class PartitionData { + private final String topic; + private final int partition; + + public PartitionData(String topic, int partition) { + this.topic = topic; + this.partition = partition; + } + + @JsonProperty + public String topic() { + return topic; + } + + @JsonProperty + public int partition() { + return partition; + } + } + + private static class OffsetsCommitted extends ConsumerEvent { + + private final List offsets; + private final String error; + private final boolean success; + + public OffsetsCommitted(List offsets, String error, boolean success) { + this.offsets = offsets; + this.error = error; + this.success = success; + } + + @Override + public String name() { + return "offsets_committed"; + } + + @JsonProperty + public List offsets() { + return offsets; + } + + @JsonProperty + @JsonInclude(JsonInclude.Include.NON_NULL) + public String error() { + return error; + } + + @JsonProperty + public boolean success() { + return success; + } + + } + + private static class CommitData extends PartitionData { + private final long offset; + + public CommitData(String topic, int partition, long offset) { + super(topic, partition); + this.offset = offset; + } + + @JsonProperty + public long offset() { + return offset; + } + } + + private static class RecordSetSummary extends PartitionData { + private final long count; + private final long minOffset; + private final long maxOffset; + + public RecordSetSummary(String topic, int partition, long count, long minOffset, long maxOffset) { + super(topic, partition); + this.count = count; + this.minOffset = minOffset; + this.maxOffset = maxOffset; + } + + @JsonProperty + public long count() { + return count; + } + + @JsonProperty + public long minOffset() { + return minOffset; + } + + @JsonProperty + public long maxOffset() { + return maxOffset; + } + + } + + private static ArgumentParser argParser() { + ArgumentParser parser = ArgumentParsers + .newArgumentParser("verifiable-consumer") + .defaultHelp(true) + .description("This tool consumes messages from a specific topic and emits consumer events (e.g. group rebalances, received messages, and offsets committed) as JSON objects to STDOUT."); + MutuallyExclusiveGroup connectionGroup = parser.addMutuallyExclusiveGroup("Connection Group") + .description("Group of arguments for connection to brokers") + .required(true); + connectionGroup.addArgument("--bootstrap-server") + .action(store()) + .required(false) + .type(String.class) + .metavar("HOST1:PORT1[,HOST2:PORT2[...]]") + .dest("bootstrapServer") + .help("REQUIRED unless --broker-list(deprecated) is specified. The server(s) to connect to. 
Comma-separated list of Kafka brokers in the form HOST1:PORT1,HOST2:PORT2,..."); + connectionGroup.addArgument("--broker-list") + .action(store()) + .required(false) + .type(String.class) + .metavar("HOST1:PORT1[,HOST2:PORT2[...]]") + .dest("brokerList") + .help("DEPRECATED, use --bootstrap-server instead; ignored if --bootstrap-server is specified. Comma-separated list of Kafka brokers in the form HOST1:PORT1,HOST2:PORT2,..."); + + parser.addArgument("--topic") + .action(store()) + .required(true) + .type(String.class) + .metavar("TOPIC") + .help("Consumes messages from this topic."); + + parser.addArgument("--group-id") + .action(store()) + .required(true) + .type(String.class) + .metavar("GROUP_ID") + .dest("groupId") + .help("The groupId shared among members of the consumer group"); + + parser.addArgument("--group-instance-id") + .action(store()) + .required(false) + .type(String.class) + .metavar("GROUP_INSTANCE_ID") + .dest("groupInstanceId") + .help("A unique identifier of the consumer instance"); + + parser.addArgument("--max-messages") + .action(store()) + .required(false) + .type(Integer.class) + .setDefault(-1) + .metavar("MAX-MESSAGES") + .dest("maxMessages") + .help("Consume this many messages. If -1 (the default), the consumer will consume until the process is killed externally"); + + parser.addArgument("--session-timeout") + .action(store()) + .required(false) + .setDefault(30000) + .type(Integer.class) + .metavar("TIMEOUT_MS") + .dest("sessionTimeout") + .help("Set the consumer's session timeout"); + + parser.addArgument("--verbose") + .action(storeTrue()) + .type(Boolean.class) + .metavar("VERBOSE") + .help("Enable to log individual consumed records"); + + parser.addArgument("--enable-autocommit") + .action(storeTrue()) + .type(Boolean.class) + .metavar("ENABLE-AUTOCOMMIT") + .dest("useAutoCommit") + .help("Enable offset auto-commit on consumer"); + + parser.addArgument("--reset-policy") + .action(store()) + .required(false) + .setDefault("earliest") + .type(String.class) + .dest("resetPolicy") + .help("Set reset policy (must be either 'earliest', 'latest', or 'none'"); + + parser.addArgument("--assignment-strategy") + .action(store()) + .required(false) + .setDefault(RangeAssignor.class.getName()) + .type(String.class) + .dest("assignmentStrategy") + .help("Set assignment strategy (e.g. 
" + RoundRobinAssignor.class.getName() + ")"); + + parser.addArgument("--consumer.config") + .action(store()) + .required(false) + .type(String.class) + .metavar("CONFIG_FILE") + .help("Consumer config properties file (config options shared with command line parameters will be overridden)."); + + return parser; + } + + public static VerifiableConsumer createFromArgs(ArgumentParser parser, String[] args) throws ArgumentParserException { + Namespace res = parser.parseArgs(args); + + boolean useAutoCommit = res.getBoolean("useAutoCommit"); + String configFile = res.getString("consumer.config"); + String brokerHostandPort = null; + + Properties consumerProps = new Properties(); + if (configFile != null) { + try { + consumerProps.putAll(Utils.loadProps(configFile)); + } catch (IOException e) { + throw new ArgumentParserException(e.getMessage(), parser); + } + } + + consumerProps.put(ConsumerConfig.GROUP_ID_CONFIG, res.getString("groupId")); + + String groupInstanceId = res.getString("groupInstanceId"); + if (groupInstanceId != null) { + consumerProps.put(ConsumerConfig.GROUP_INSTANCE_ID_CONFIG, groupInstanceId); + } + + + if (res.get("bootstrapServer") != null) { + brokerHostandPort = res.getString("bootstrapServer"); + } else if (res.getString("brokerList") != null) { + brokerHostandPort = res.getString("brokerList"); + } else { + parser.printHelp(); + // Can't use `Exit.exit` here because it didn't exist until 0.11.0.0. + System.exit(0); + } + consumerProps.put(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, brokerHostandPort); + + consumerProps.put(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG, useAutoCommit); + consumerProps.put(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, res.getString("resetPolicy")); + consumerProps.put(ConsumerConfig.SESSION_TIMEOUT_MS_CONFIG, Integer.toString(res.getInt("sessionTimeout"))); + consumerProps.put(ConsumerConfig.PARTITION_ASSIGNMENT_STRATEGY_CONFIG, res.getString("assignmentStrategy")); + + StringDeserializer deserializer = new StringDeserializer(); + KafkaConsumer consumer = new KafkaConsumer<>(consumerProps, deserializer, deserializer); + + String topic = res.getString("topic"); + int maxMessages = res.getInt("maxMessages"); + boolean verbose = res.getBoolean("verbose"); + + return new VerifiableConsumer( + consumer, + System.out, + topic, + maxMessages, + useAutoCommit, + false, + verbose); + } + + public static void main(String[] args) { + ArgumentParser parser = argParser(); + if (args.length == 0) { + parser.printHelp(); + // Can't use `Exit.exit` here because it didn't exist until 0.11.0.0. + System.exit(0); + } + try { + final VerifiableConsumer consumer = createFromArgs(parser, args); + // Can't use `Exit.addShutdownHook` here because it didn't exist until 2.5.0. + Runtime.getRuntime().addShutdownHook(new Thread(consumer::close, "verifiable-consumer-shutdown-hook")); + + consumer.run(); + } catch (ArgumentParserException e) { + parser.handleError(e); + // Can't use `Exit.exit` here because it didn't exist until 0.11.0.0. + System.exit(1); + } + } + +} diff --git a/tools/src/main/java/org/apache/kafka/tools/VerifiableLog4jAppender.java b/tools/src/main/java/org/apache/kafka/tools/VerifiableLog4jAppender.java new file mode 100644 index 0000000..d390926 --- /dev/null +++ b/tools/src/main/java/org/apache/kafka/tools/VerifiableLog4jAppender.java @@ -0,0 +1,259 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.tools; + +import net.sourceforge.argparse4j.ArgumentParsers; +import net.sourceforge.argparse4j.inf.ArgumentParser; +import net.sourceforge.argparse4j.inf.ArgumentParserException; +import net.sourceforge.argparse4j.inf.Namespace; +import org.apache.kafka.common.security.auth.SecurityProtocol; +import org.apache.kafka.common.utils.Exit; +import org.apache.log4j.Logger; +import org.apache.log4j.PropertyConfigurator; + +import java.io.IOException; +import java.io.InputStream; +import java.nio.file.Files; +import java.nio.file.Paths; +import java.util.Properties; + +import static net.sourceforge.argparse4j.impl.Arguments.store; + +/** + * Primarily intended for use with system testing, this appender produces message + * to Kafka on each "append" request. For example, this helps with end-to-end tests + * of KafkaLog4jAppender. + * + * When used as a command-line tool, it appends increasing integers. It will produce a + * fixed number of messages unless the default max-messages -1 is used, in which case + * it appends indefinitely. + */ + +public class VerifiableLog4jAppender { + Logger logger = Logger.getLogger(VerifiableLog4jAppender.class); + + // If maxMessages < 0, log until the process is killed externally + private long maxMessages = -1; + + // Hook to trigger logging thread to stop logging messages + private volatile boolean stopLogging = false; + + /** Get the command-line argument parser. */ + private static ArgumentParser argParser() { + ArgumentParser parser = ArgumentParsers + .newArgumentParser("verifiable-log4j-appender") + .defaultHelp(true) + .description("This tool produces increasing integers to the specified topic using KafkaLog4jAppender."); + + parser.addArgument("--topic") + .action(store()) + .required(true) + .type(String.class) + .metavar("TOPIC") + .help("Produce messages to this topic."); + + parser.addArgument("--broker-list") + .action(store()) + .required(true) + .type(String.class) + .metavar("HOST1:PORT1[,HOST2:PORT2[...]]") + .dest("brokerList") + .help("Comma-separated list of Kafka brokers in the form HOST1:PORT1,HOST2:PORT2,..."); + + parser.addArgument("--max-messages") + .action(store()) + .required(false) + .setDefault(-1) + .type(Integer.class) + .metavar("MAX-MESSAGES") + .dest("maxMessages") + .help("Produce this many messages. If -1, produce messages until the process is killed externally."); + + parser.addArgument("--acks") + .action(store()) + .required(false) + .setDefault("-1") + .type(String.class) + .choices("0", "1", "-1") + .metavar("ACKS") + .help("Acks required on each produced message. 
See Kafka docs on request.required.acks for details."); + + parser.addArgument("--security-protocol") + .action(store()) + .required(false) + .setDefault("PLAINTEXT") + .type(String.class) + .choices("PLAINTEXT", "SSL", "SASL_PLAINTEXT", "SASL_SSL") + .metavar("SECURITY-PROTOCOL") + .dest("securityProtocol") + .help("Security protocol to be used while communicating with Kafka brokers."); + + parser.addArgument("--ssl-truststore-location") + .action(store()) + .required(false) + .type(String.class) + .metavar("SSL-TRUSTSTORE-LOCATION") + .dest("sslTruststoreLocation") + .help("Location of SSL truststore to use."); + + parser.addArgument("--ssl-truststore-password") + .action(store()) + .required(false) + .type(String.class) + .metavar("SSL-TRUSTSTORE-PASSWORD") + .dest("sslTruststorePassword") + .help("Password for SSL truststore to use."); + + parser.addArgument("--appender.config") + .action(store()) + .required(false) + .type(String.class) + .metavar("CONFIG_FILE") + .help("Log4jAppender config properties file."); + + parser.addArgument("--sasl-kerberos-service-name") + .action(store()) + .required(false) + .type(String.class) + .metavar("SASL-KERBEROS-SERVICE-NAME") + .dest("saslKerberosServiceName") + .help("Name of sasl kerberos service."); + + parser.addArgument("--client-jaas-conf-path") + .action(store()) + .required(false) + .type(String.class) + .metavar("CLIENT-JAAS-CONF-PATH") + .dest("clientJaasConfPath") + .help("Path of JAAS config file of Kafka client."); + + parser.addArgument("--kerb5-conf-path") + .action(store()) + .required(false) + .type(String.class) + .metavar("KERB5-CONF-PATH") + .dest("kerb5ConfPath") + .help("Path of Kerb5 config file."); + + return parser; + } + + /** + * Read a properties file from the given path + * @param filename The path of the file to read + * + * Note: this duplication of org.apache.kafka.common.utils.Utils.loadProps is unfortunate + * but *intentional*. In order to use VerifiableProducer in compatibility and upgrade tests, + * we use VerifiableProducer from the development tools package, and run it against 0.8.X.X kafka jars. + * Since this method is not in Utils in the 0.8.X.X jars, we have to cheat a bit and duplicate. + */ + public static Properties loadProps(String filename) throws IOException { + Properties props = new Properties(); + try (InputStream propStream = Files.newInputStream(Paths.get(filename))) { + props.load(propStream); + } + return props; + } + + /** Construct a VerifiableLog4jAppender object from command-line arguments. 
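+ * <p>
+ * As a rough sketch (the broker list and topic shown here are placeholders supplied via the
+ * command-line arguments), the configuration assembled below is equivalent to a log4j.properties
+ * file along these lines:
+ * <pre>
+ * log4j.rootLogger=INFO, KAFKA
+ * log4j.appender.KAFKA=org.apache.kafka.log4jappender.KafkaLog4jAppender
+ * log4j.appender.KAFKA.layout=org.apache.log4j.PatternLayout
+ * log4j.appender.KAFKA.layout.ConversionPattern=%-5p: %c - %m%n
+ * log4j.appender.KAFKA.BrokerList=localhost:9092
+ * log4j.appender.KAFKA.Topic=test-topic
+ * log4j.appender.KAFKA.RequiredNumAcks=-1
+ * log4j.appender.KAFKA.SyncSend=true
+ * </pre>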
*/ + public static VerifiableLog4jAppender createFromArgs(String[] args) { + ArgumentParser parser = argParser(); + VerifiableLog4jAppender producer = null; + + try { + Namespace res = parser.parseArgs(args); + + int maxMessages = res.getInt("maxMessages"); + String topic = res.getString("topic"); + String configFile = res.getString("appender.config"); + + Properties props = new Properties(); + props.setProperty("log4j.rootLogger", "INFO, KAFKA"); + props.setProperty("log4j.appender.KAFKA", "org.apache.kafka.log4jappender.KafkaLog4jAppender"); + props.setProperty("log4j.appender.KAFKA.layout", "org.apache.log4j.PatternLayout"); + props.setProperty("log4j.appender.KAFKA.layout.ConversionPattern", "%-5p: %c - %m%n"); + props.setProperty("log4j.appender.KAFKA.BrokerList", res.getString("brokerList")); + props.setProperty("log4j.appender.KAFKA.Topic", topic); + props.setProperty("log4j.appender.KAFKA.RequiredNumAcks", res.getString("acks")); + props.setProperty("log4j.appender.KAFKA.SyncSend", "true"); + final String securityProtocol = res.getString("securityProtocol"); + if (securityProtocol != null && !securityProtocol.equals(SecurityProtocol.PLAINTEXT.toString())) { + props.setProperty("log4j.appender.KAFKA.SecurityProtocol", securityProtocol); + } + if (securityProtocol != null && securityProtocol.contains("SSL")) { + props.setProperty("log4j.appender.KAFKA.SslTruststoreLocation", res.getString("sslTruststoreLocation")); + props.setProperty("log4j.appender.KAFKA.SslTruststorePassword", res.getString("sslTruststorePassword")); + } + if (securityProtocol != null && securityProtocol.contains("SASL")) { + props.setProperty("log4j.appender.KAFKA.SaslKerberosServiceName", res.getString("saslKerberosServiceName")); + props.setProperty("log4j.appender.KAFKA.clientJaasConfPath", res.getString("clientJaasConfPath")); + props.setProperty("log4j.appender.KAFKA.kerb5ConfPath", res.getString("kerb5ConfPath")); + } + props.setProperty("log4j.logger.kafka.log4j", "INFO, KAFKA"); + // Changing log level from INFO to WARN as a temporary workaround for KAFKA-6415. This is to + // avoid deadlock in system tests when producer network thread appends to log while updating metadata. + props.setProperty("log4j.logger.org.apache.kafka.clients.Metadata", "WARN, KAFKA"); + + if (configFile != null) { + try { + props.putAll(loadProps(configFile)); + } catch (IOException e) { + throw new ArgumentParserException(e.getMessage(), parser); + } + } + + producer = new VerifiableLog4jAppender(props, maxMessages); + } catch (ArgumentParserException e) { + if (args.length == 0) { + parser.printHelp(); + Exit.exit(0); + } else { + parser.handleError(e); + Exit.exit(1); + } + } + + return producer; + } + + + public VerifiableLog4jAppender(Properties props, int maxMessages) { + this.maxMessages = maxMessages; + PropertyConfigurator.configure(props); + } + + public static void main(String[] args) { + + final VerifiableLog4jAppender appender = createFromArgs(args); + boolean infinite = appender.maxMessages < 0; + + // Trigger main thread to stop producing messages when shutting down + Exit.addShutdownHook("verifiable-log4j-appender-shutdown-hook", () -> appender.stopLogging = true); + + long maxMessages = infinite ? 
Long.MAX_VALUE : appender.maxMessages; + for (long i = 0; i < maxMessages; i++) { + if (appender.stopLogging) { + break; + } + appender.append(String.format("%d", i)); + } + } + + private void append(String msg) { + logger.info(msg); + } +} diff --git a/tools/src/main/java/org/apache/kafka/tools/VerifiableProducer.java b/tools/src/main/java/org/apache/kafka/tools/VerifiableProducer.java new file mode 100644 index 0000000..ee863d4 --- /dev/null +++ b/tools/src/main/java/org/apache/kafka/tools/VerifiableProducer.java @@ -0,0 +1,564 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.tools; + +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.annotation.JsonPropertyOrder; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; + +import net.sourceforge.argparse4j.ArgumentParsers; +import net.sourceforge.argparse4j.inf.ArgumentParser; +import net.sourceforge.argparse4j.inf.ArgumentParserException; +import net.sourceforge.argparse4j.inf.MutuallyExclusiveGroup; +import net.sourceforge.argparse4j.inf.Namespace; + +import org.apache.kafka.clients.producer.Callback; +import org.apache.kafka.clients.producer.KafkaProducer; +import org.apache.kafka.clients.producer.Producer; +import org.apache.kafka.clients.producer.ProducerConfig; +import org.apache.kafka.clients.producer.ProducerRecord; +import org.apache.kafka.clients.producer.RecordMetadata; +import org.apache.kafka.common.serialization.StringSerializer; + +import java.io.IOException; +import java.io.InputStream; +import java.nio.file.Files; +import java.nio.file.Paths; +import java.util.Properties; + +import static net.sourceforge.argparse4j.impl.Arguments.store; + +/** + * Primarily intended for use with system testing, this producer prints metadata + * in the form of JSON to stdout on each "send" request. For example, this helps + * with end-to-end correctness tests by making externally visible which messages have been + * acked and which have not. + * + * When used as a command-line tool, it produces increasing integers. It will produce a + * fixed number of messages unless the default max-messages -1 is used, in which case + * it produces indefinitely. + * + * If logging is left enabled, log output on stdout can be easily ignored by checking + * whether a given line is valid JSON. 
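+ * <p>
+ * For illustration only (values are made up; the field set mirrors the Jackson-annotated event
+ * classes defined in this file), a successful send is reported on stdout roughly as:
+ * <pre>
+ * {"timestamp":1625000000000,"name":"producer_send_success","key":null,"value":"0",
+ *  "topic":"test","partition":0,"offset":0}
+ * </pre>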
+ */ +public class VerifiableProducer implements AutoCloseable { + + private final ObjectMapper mapper = new ObjectMapper(); + private final String topic; + private final Producer producer; + // If maxMessages < 0, produce until the process is killed externally + private long maxMessages = -1; + + // Number of messages for which acks were received + private long numAcked = 0; + + // Number of send attempts + private long numSent = 0; + + // Throttle message throughput if this is set >= 0 + private final long throughput; + + // Hook to trigger producing thread to stop sending messages + private boolean stopProducing = false; + + // Prefix (plus a dot separator) added to every value produced by verifiable producer + // if null, then values are produced without a prefix + private final Integer valuePrefix; + + // Send messages with a key of 0 incrementing by 1 for + // each message produced when number specified is reached + // key is reset to 0 + private final Integer repeatingKeys; + + private int keyCounter; + + // The create time to set in messages, in milliseconds since epoch + private Long createTime; + + private final Long startTime; + + public VerifiableProducer(KafkaProducer producer, String topic, int throughput, int maxMessages, + Integer valuePrefix, Long createTime, Integer repeatingKeys) { + + this.topic = topic; + this.throughput = throughput; + this.maxMessages = maxMessages; + this.producer = producer; + this.valuePrefix = valuePrefix; + this.createTime = createTime; + this.startTime = System.currentTimeMillis(); + this.repeatingKeys = repeatingKeys; + + } + + /** Get the command-line argument parser. */ + private static ArgumentParser argParser() { + ArgumentParser parser = ArgumentParsers + .newArgumentParser("verifiable-producer") + .defaultHelp(true) + .description("This tool produces increasing integers to the specified topic and prints JSON metadata to stdout on each \"send\" request, making externally visible which messages have been acked and which have not."); + + parser.addArgument("--topic") + .action(store()) + .required(true) + .type(String.class) + .metavar("TOPIC") + .help("Produce messages to this topic."); + MutuallyExclusiveGroup connectionGroup = parser.addMutuallyExclusiveGroup("Connection Group") + .description("Group of arguments for connection to brokers") + .required(true); + connectionGroup.addArgument("--bootstrap-server") + .action(store()) + .required(false) + .type(String.class) + .metavar("HOST1:PORT1[,HOST2:PORT2[...]]") + .dest("bootstrapServer") + .help("REQUIRED: The server(s) to connect to. Comma-separated list of Kafka brokers in the form HOST1:PORT1,HOST2:PORT2,..."); + + connectionGroup.addArgument("--broker-list") + .action(store()) + .required(false) + .type(String.class) + .metavar("HOST1:PORT1[,HOST2:PORT2[...]]") + .dest("brokerList") + .help("DEPRECATED, use --bootstrap-server instead; ignored if --bootstrap-server is specified. Comma-separated list of Kafka brokers in the form HOST1:PORT1,HOST2:PORT2,..."); + + parser.addArgument("--max-messages") + .action(store()) + .required(false) + .setDefault(-1) + .type(Integer.class) + .metavar("MAX-MESSAGES") + .dest("maxMessages") + .help("Produce this many messages. 
If -1, produce messages until the process is killed externally."); + + parser.addArgument("--throughput") + .action(store()) + .required(false) + .setDefault(-1) + .type(Integer.class) + .metavar("THROUGHPUT") + .help("If set >= 0, throttle maximum message throughput to *approximately* THROUGHPUT messages/sec."); + + parser.addArgument("--acks") + .action(store()) + .required(false) + .setDefault(-1) + .type(Integer.class) + .choices(0, 1, -1) + .metavar("ACKS") + .help("Acks required on each produced message. See Kafka docs on acks for details."); + + parser.addArgument("--producer.config") + .action(store()) + .required(false) + .type(String.class) + .metavar("CONFIG_FILE") + .help("Producer config properties file."); + + parser.addArgument("--message-create-time") + .action(store()) + .required(false) + .setDefault(-1L) + .type(Long.class) + .metavar("CREATETIME") + .dest("createTime") + .help("Send messages with creation time starting at the arguments value, in milliseconds since epoch"); + + parser.addArgument("--value-prefix") + .action(store()) + .required(false) + .type(Integer.class) + .metavar("VALUE-PREFIX") + .dest("valuePrefix") + .help("If specified, each produced value will have this prefix with a dot separator"); + + parser.addArgument("--repeating-keys") + .action(store()) + .required(false) + .type(Integer.class) + .metavar("REPEATING-KEYS") + .dest("repeatingKeys") + .help("If specified, each produced record will have a key starting at 0 increment by 1 up to the number specified (exclusive), then the key is set to 0 again"); + + return parser; + } + + /** + * Read a properties file from the given path + * @param filename The path of the file to read + * + * Note: this duplication of org.apache.kafka.common.utils.Utils.loadProps is unfortunate + * but *intentional*. In order to use VerifiableProducer in compatibility and upgrade tests, + * we use VerifiableProducer from the development tools package, and run it against 0.8.X.X kafka jars. + * Since this method is not in Utils in the 0.8.X.X jars, we have to cheat a bit and duplicate. + */ + public static Properties loadProps(String filename) throws IOException { + Properties props = new Properties(); + try (InputStream propStream = Files.newInputStream(Paths.get(filename))) { + props.load(propStream); + } + return props; + } + + /** Construct a VerifiableProducer object from command-line arguments. */ + public static VerifiableProducer createFromArgs(ArgumentParser parser, String[] args) throws ArgumentParserException { + Namespace res = parser.parseArgs(args); + + int maxMessages = res.getInt("maxMessages"); + String topic = res.getString("topic"); + int throughput = res.getInt("throughput"); + String configFile = res.getString("producer.config"); + Integer valuePrefix = res.getInt("valuePrefix"); + Long createTime = res.getLong("createTime"); + Integer repeatingKeys = res.getInt("repeatingKeys"); + + if (createTime == -1L) + createTime = null; + + Properties producerProps = new Properties(); + + if (res.get("bootstrapServer") != null) { + producerProps.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, res.getString("bootstrapServer")); + } else if (res.getString("brokerList") != null) { + producerProps.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, res.getString("brokerList")); + } else { + parser.printHelp(); + // Can't use `Exit.exit` here because it didn't exist until 0.11.0.0. 
+ System.exit(0); + } + + producerProps.put(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG, + "org.apache.kafka.common.serialization.StringSerializer"); + producerProps.put(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG, + "org.apache.kafka.common.serialization.StringSerializer"); + producerProps.put(ProducerConfig.ACKS_CONFIG, Integer.toString(res.getInt("acks"))); + // No producer retries + producerProps.put(ProducerConfig.RETRIES_CONFIG, "0"); + if (configFile != null) { + try { + producerProps.putAll(loadProps(configFile)); + } catch (IOException e) { + throw new ArgumentParserException(e.getMessage(), parser); + } + } + + StringSerializer serializer = new StringSerializer(); + KafkaProducer producer = new KafkaProducer<>(producerProps, serializer, serializer); + + return new VerifiableProducer(producer, topic, throughput, maxMessages, valuePrefix, createTime, repeatingKeys); + } + + /** Produce a message with given key and value. */ + public void send(String key, String value) { + ProducerRecord record; + + // Older versions of ProducerRecord don't include the message create time in the constructor. So including + // even a 'null' argument results in a NoSuchMethodException. Thus we only include the create time if it is + // explicitly specified to remain fully backward compatible with older clients. + if (createTime != null) { + record = new ProducerRecord<>(topic, null, createTime, key, value); + createTime += System.currentTimeMillis() - startTime; + } else { + record = new ProducerRecord<>(topic, key, value); + } + + numSent++; + try { + producer.send(record, new PrintInfoCallback(key, value)); + } catch (Exception e) { + + synchronized (System.out) { + printJson(new FailedSend(key, value, topic, e)); + } + } + } + + /** Returns a string to publish: ether 'valuePrefix'.'val' or 'val' **/ + public String getValue(long val) { + if (this.valuePrefix != null) { + return String.format("%d.%d", this.valuePrefix, val); + } + return String.format("%d", val); + } + + public String getKey() { + String key = null; + if (repeatingKeys != null) { + key = Integer.toString(keyCounter++); + if (keyCounter == repeatingKeys) { + keyCounter = 0; + } + } + return key; + } + + /** Close the producer to flush any remaining messages. 
*/ + public void close() { + producer.close(); + printJson(new ShutdownComplete()); + } + + @JsonPropertyOrder({ "timestamp", "name" }) + private static abstract class ProducerEvent { + private final long timestamp = System.currentTimeMillis(); + + @JsonProperty + public abstract String name(); + + @JsonProperty + public long timestamp() { + return timestamp; + } + } + + private static class StartupComplete extends ProducerEvent { + + @Override + public String name() { + return "startup_complete"; + } + } + + private static class ShutdownComplete extends ProducerEvent { + + @Override + public String name() { + return "shutdown_complete"; + } + } + + private static class SuccessfulSend extends ProducerEvent { + + private String key; + private String value; + private RecordMetadata recordMetadata; + + public SuccessfulSend(String key, String value, RecordMetadata recordMetadata) { + assert recordMetadata != null : "Expected non-null recordMetadata object."; + this.key = key; + this.value = value; + this.recordMetadata = recordMetadata; + } + + @Override + public String name() { + return "producer_send_success"; + } + + @JsonProperty + public String key() { + return key; + } + + @JsonProperty + public String value() { + return value; + } + + @JsonProperty + public String topic() { + return recordMetadata.topic(); + } + + @JsonProperty + public int partition() { + return recordMetadata.partition(); + } + + @JsonProperty + public long offset() { + return recordMetadata.offset(); + } + } + + private static class FailedSend extends ProducerEvent { + + private String topic; + private String key; + private String value; + private Exception exception; + + public FailedSend(String key, String value, String topic, Exception exception) { + assert exception != null : "Expected non-null exception."; + this.key = key; + this.value = value; + this.topic = topic; + this.exception = exception; + } + + @Override + public String name() { + return "producer_send_error"; + } + + @JsonProperty + public String key() { + return key; + } + + @JsonProperty + public String value() { + return value; + } + + @JsonProperty + public String topic() { + return topic; + } + + @JsonProperty + public String exception() { + return exception.getClass().toString(); + } + + @JsonProperty + public String message() { + return exception.getMessage(); + } + } + + private static class ToolData extends ProducerEvent { + + private long sent; + private long acked; + private long targetThroughput; + private double avgThroughput; + + public ToolData(long sent, long acked, long targetThroughput, double avgThroughput) { + this.sent = sent; + this.acked = acked; + this.targetThroughput = targetThroughput; + this.avgThroughput = avgThroughput; + } + + @Override + public String name() { + return "tool_data"; + } + + @JsonProperty + public long sent() { + return this.sent; + } + + @JsonProperty + public long acked() { + return this.acked; + } + + @JsonProperty("target_throughput") + public long targetThroughput() { + return this.targetThroughput; + } + + @JsonProperty("avg_throughput") + public double avgThroughput() { + return this.avgThroughput; + } + } + + private void printJson(Object data) { + try { + System.out.println(mapper.writeValueAsString(data)); + } catch (JsonProcessingException e) { + System.out.println("Bad data can't be written as json: " + e.getMessage()); + } + } + + /** Callback which prints errors to stdout when the producer fails to send. 
*/ + private class PrintInfoCallback implements Callback { + + private String key; + private String value; + + PrintInfoCallback(String key, String value) { + this.key = key; + this.value = value; + } + + public void onCompletion(RecordMetadata recordMetadata, Exception e) { + synchronized (System.out) { + if (e == null) { + VerifiableProducer.this.numAcked++; + printJson(new SuccessfulSend(this.key, this.value, recordMetadata)); + } else { + printJson(new FailedSend(this.key, this.value, topic, e)); + } + } + } + } + + public void run(ThroughputThrottler throttler) { + + printJson(new StartupComplete()); + // negative maxMessages (-1) means "infinite" + long maxMessages = (this.maxMessages < 0) ? Long.MAX_VALUE : this.maxMessages; + + for (long i = 0; i < maxMessages; i++) { + if (this.stopProducing) { + break; + } + long sendStartMs = System.currentTimeMillis(); + + this.send(this.getKey(), this.getValue(i)); + + if (throttler.shouldThrottle(i, sendStartMs)) { + throttler.throttle(); + } + } + } + + public static void main(String[] args) { + ArgumentParser parser = argParser(); + if (args.length == 0) { + parser.printHelp(); + // Can't use `Exit.exit` here because it didn't exist until 0.11.0.0. + System.exit(0); + } + + try { + final VerifiableProducer producer = createFromArgs(parser, args); + + final long startMs = System.currentTimeMillis(); + ThroughputThrottler throttler = new ThroughputThrottler(producer.throughput, startMs); + + // Can't use `Exit.addShutdownHook` here because it didn't exist until 2.5.0. + Runtime.getRuntime().addShutdownHook(new Thread(() -> { + // Trigger main thread to stop producing messages + producer.stopProducing = true; + + // Flush any remaining messages + producer.close(); + + // Print a summary + long stopMs = System.currentTimeMillis(); + double avgThroughput = 1000 * ((producer.numAcked) / (double) (stopMs - startMs)); + + producer.printJson(new ToolData(producer.numSent, producer.numAcked, producer.throughput, avgThroughput)); + }, "verifiable-producer-shutdown-hook")); + + producer.run(throttler); + } catch (ArgumentParserException e) { + parser.handleError(e); + // Can't use `Exit.exit` here because it didn't exist until 0.11.0.0. + System.exit(1); + } + } + +} diff --git a/tools/src/test/java/org/apache/kafka/tools/ProducerPerformanceTest.java b/tools/src/test/java/org/apache/kafka/tools/ProducerPerformanceTest.java new file mode 100644 index 0000000..4c479b6 --- /dev/null +++ b/tools/src/test/java/org/apache/kafka/tools/ProducerPerformanceTest.java @@ -0,0 +1,177 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.tools; + +import org.apache.kafka.clients.producer.KafkaProducer; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.Spy; +import org.mockito.junit.jupiter.MockitoExtension; + +import net.sourceforge.argparse4j.inf.ArgumentParser; +import net.sourceforge.argparse4j.inf.ArgumentParserException; + +import java.io.File; +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Properties; +import java.util.Random; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +@ExtendWith(MockitoExtension.class) +public class ProducerPerformanceTest { + + @Mock + KafkaProducer producerMock; + + @Spy + ProducerPerformance producerPerformanceSpy; + + private File createTempFile(String contents) throws IOException { + File file = File.createTempFile("ProducerPerformanceTest", ".tmp"); + file.deleteOnExit(); + Files.write(file.toPath(), contents.getBytes()); + return file; + } + + @Test + public void testReadPayloadFile() throws Exception { + File payloadFile = createTempFile("Hello\nKafka"); + String payloadFilePath = payloadFile.getAbsolutePath(); + String payloadDelimiter = "\n"; + + List payloadByteList = ProducerPerformance.readPayloadFile(payloadFilePath, payloadDelimiter); + + assertEquals(2, payloadByteList.size()); + assertEquals("Hello", new String(payloadByteList.get(0))); + assertEquals("Kafka", new String(payloadByteList.get(1))); + } + + @Test + public void testReadProps() throws Exception { + List producerProps = Collections.singletonList("bootstrap.servers=localhost:9000"); + String producerConfig = createTempFile("acks=1").getAbsolutePath(); + String transactionalId = "1234"; + boolean transactionsEnabled = true; + + Properties prop = ProducerPerformance.readProps(producerProps, producerConfig, transactionalId, transactionsEnabled); + + assertNotNull(prop); + assertEquals(6, prop.size()); + } + + @Test + public void testNumberOfCallsForSendAndClose() throws IOException { + doReturn(null).when(producerMock).send(any(), any()); + doReturn(producerMock).when(producerPerformanceSpy).createKafkaProducer(any(Properties.class)); + + String[] args = new String[] { + "--topic", "Hello-Kafka", + "--num-records", "5", + "--throughput", "100", + "--record-size", "100", + "--producer-props", "bootstrap.servers=localhost:9000"}; + producerPerformanceSpy.start(args); + verify(producerMock, times(5)).send(any(), any()); + verify(producerMock, times(1)).close(); + } + + @Test + public void testUnexpectedArg() { + String[] args = new String[] { + "--test", "test", + "--topic", "Hello-Kafka", + "--num-records", "5", + "--throughput", "100", + "--record-size", "100", + "--producer-props", "bootstrap.servers=localhost:9000"}; + ArgumentParser parser = ProducerPerformance.argParser(); + ArgumentParserException thrown = assertThrows(ArgumentParserException.class, () -> parser.parseArgs(args)); + assertEquals("unrecognized arguments: '--test'", thrown.getMessage()); + } + + @Test + 
public void testGenerateRandomPayloadByPayloadFile() { + Integer recordSize = null; + String inputString = "Hello Kafka"; + byte[] byteArray = inputString.getBytes(StandardCharsets.UTF_8); + List payloadByteList = new ArrayList<>(); + payloadByteList.add(byteArray); + byte[] payload = null; + Random random = new Random(0); + + payload = ProducerPerformance.generateRandomPayload(recordSize, payloadByteList, payload, random); + assertEquals(inputString, new String(payload)); + } + + @Test + public void testGenerateRandomPayloadByRecordSize() { + Integer recordSize = 100; + byte[] payload = new byte[recordSize]; + List payloadByteList = new ArrayList<>(); + Random random = new Random(0); + + payload = ProducerPerformance.generateRandomPayload(recordSize, payloadByteList, payload, random); + for (byte b : payload) { + assertNotEquals(0, b); + } + } + + @Test + public void testGenerateRandomPayloadException() { + Integer recordSize = null; + byte[] payload = null; + List payloadByteList = new ArrayList<>(); + Random random = new Random(0); + + IllegalArgumentException thrown = assertThrows(IllegalArgumentException.class, () -> ProducerPerformance.generateRandomPayload(recordSize, payloadByteList, payload, random)); + assertEquals("no payload File Path or record Size provided", thrown.getMessage()); + } + + @Test + public void testClientIdOverride() throws Exception { + List producerProps = Collections.singletonList("client.id=producer-1"); + + Properties prop = ProducerPerformance.readProps(producerProps, null, "1234", true); + + assertNotNull(prop); + assertEquals("producer-1", prop.getProperty("client.id")); + } + + @Test + public void testDefaultClientId() throws Exception { + List producerProps = Collections.singletonList("acks=1"); + + Properties prop = ProducerPerformance.readProps(producerProps, null, "1234", true); + + assertNotNull(prop); + assertEquals("perf-producer-client", prop.getProperty("client.id")); + } +} diff --git a/tools/src/test/java/org/apache/kafka/tools/PushHttpMetricsReporterTest.java b/tools/src/test/java/org/apache/kafka/tools/PushHttpMetricsReporterTest.java new file mode 100644 index 0000000..e4ed958 --- /dev/null +++ b/tools/src/test/java/org/apache/kafka/tools/PushHttpMetricsReporterTest.java @@ -0,0 +1,347 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.tools; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyLong; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.inOrder; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import java.io.IOException; +import java.util.List; +import org.apache.kafka.common.MetricName; +import org.apache.kafka.common.config.ConfigException; +import org.apache.kafka.common.metrics.Gauge; +import org.apache.kafka.common.metrics.KafkaMetric; +import org.apache.kafka.common.metrics.MetricConfig; +import org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.common.utils.Time; + +import java.io.InputStream; +import java.io.OutputStream; +import java.net.HttpURLConnection; +import java.net.InetAddress; +import java.net.MalformedURLException; +import java.net.URL; +import java.net.UnknownHostException; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.TimeUnit; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; +import org.mockito.InOrder; +import org.mockito.MockedStatic; +import org.mockito.Mockito; + +public class PushHttpMetricsReporterTest { + + private static final URL URL; + static { + try { + URL = new URL("http://fake:80"); + } catch (MalformedURLException e) { + throw new RuntimeException(e); + } + } + + private final Time time = new MockTime(); + + private final ScheduledExecutorService executor = mock(ScheduledExecutorService.class); + private final HttpURLConnection httpReq = mock(HttpURLConnection.class); + private final OutputStream httpOut = mock(OutputStream.class); + private final InputStream httpErr = mock(InputStream.class); + private final ArgumentCaptor reportRunnableCaptor = ArgumentCaptor.forClass(Runnable.class); + private final ArgumentCaptor httpPayloadCaptor = ArgumentCaptor.forClass(byte[].class); + + private PushHttpMetricsReporter reporter; + private MockedStatic mockedStaticReporter; + + @BeforeEach + public void setUp() { + reporter = new PushHttpMetricsReporter(time, executor); + mockedStaticReporter = Mockito.mockStatic(PushHttpMetricsReporter.class); + } + + @AfterEach + public void tearDown() { + mockedStaticReporter.close(); + } + + @Test + public void testConfigureClose() throws Exception { + whenClose(); + + configure(); + reporter.close(); + + verifyConfigure(); + verifyClose(); + } + + @Test + public void testConfigureBadUrl() { + Map config = new HashMap<>(); + config.put(PushHttpMetricsReporter.METRICS_URL_CONFIG, "malformed;url"); + config.put(PushHttpMetricsReporter.METRICS_PERIOD_CONFIG, "5"); + assertThrows(ConfigException.class, () -> reporter.configure(config)); + } + + @Test + public void testConfigureMissingPeriod() { + Map config = new HashMap<>(); + config.put(PushHttpMetricsReporter.METRICS_URL_CONFIG, URL.toString()); + assertThrows(ConfigException.class, () -> reporter.configure(config)); 
+ } + + @Test + public void testNoMetrics() throws Exception { + whenRequest(200); + + configure(); + verifyConfigure(); + + reportRunnableCaptor.getValue().run(); + verifyResponse(); + + JsonNode payload = new ObjectMapper().readTree(httpPayloadCaptor.getValue()); + assertTrue(payload.isObject()); + + assertPayloadHasClientInfo(payload); + + // Should contain an empty list of metrics, i.e. we report updates even if there are no metrics to report to + // indicate liveness + JsonNode metrics = payload.get("metrics"); + assertTrue(metrics.isArray()); + assertEquals(0, metrics.size()); + + reporter.close(); + verifyClose(); + } + + // For error conditions, we expect them to come with a response body that we can read & log + @Test + public void testClientError() throws Exception { + whenRequest(400, true); + + configure(); + verifyConfigure(); + + reportRunnableCaptor.getValue().run(); + verifyResponse(); + + reporter.close(); + verifyClose(); + } + + @Test + public void testServerError() throws Exception { + whenRequest(500, true); + + configure(); + verifyConfigure(); + + reportRunnableCaptor.getValue().run(); + verifyResponse(); + + reporter.close(); + verifyClose(); + } + + @Test + public void testMetricValues() throws Exception { + whenRequest(200); + + configure(); + verifyConfigure(); + KafkaMetric metric1 = new KafkaMetric( + new Object(), + new MetricName("name1", "group1", "desc1", Collections.singletonMap("key1", "value1")), + new ImmutableValue<>(1.0), + null, + time + ); + KafkaMetric newMetric1 = new KafkaMetric( + new Object(), + new MetricName("name1", "group1", "desc1", Collections.singletonMap("key1", "value1")), + new ImmutableValue<>(-1.0), + null, + time + ); + KafkaMetric metric2 = new KafkaMetric( + new Object(), + new MetricName("name2", "group2", "desc2", Collections.singletonMap("key2", "value2")), + new ImmutableValue<>(2.0), + null, + time + ); + KafkaMetric metric3 = new KafkaMetric( + new Object(), + new MetricName("name3", "group3", "desc3", Collections.singletonMap("key3", "value3")), + new ImmutableValue<>(3.0), + null, + time + ); + KafkaMetric metric4 = new KafkaMetric( + new Object(), + new MetricName("name4", "group4", "desc4", Collections.singletonMap("key4", "value4")), + new ImmutableValue<>("value4"), + null, + time + ); + + reporter.init(Arrays.asList(metric1, metric2, metric4)); + reporter.metricChange(newMetric1); // added in init, modified + reporter.metricChange(metric3); // added by change + reporter.metricRemoval(metric2); // added in init, deleted by removal + + reportRunnableCaptor.getValue().run(); + verifyResponse(); + + JsonNode payload = new ObjectMapper().readTree(httpPayloadCaptor.getValue()); + assertTrue(payload.isObject()); + assertPayloadHasClientInfo(payload); + + // We should be left with the modified version of metric1 and metric3 + JsonNode metrics = payload.get("metrics"); + assertTrue(metrics.isArray()); + assertEquals(3, metrics.size()); + List metricsList = Arrays.asList(metrics.get(0), metrics.get(1), metrics.get(2)); + // Sort metrics based on name so that we can verify the value for each metric below + metricsList.sort((m1, m2) -> m1.get("name").textValue().compareTo(m2.get("name").textValue())); + + JsonNode m1 = metricsList.get(0); + assertEquals("name1", m1.get("name").textValue()); + assertEquals("group1", m1.get("group").textValue()); + JsonNode m1Tags = m1.get("tags"); + assertTrue(m1Tags.isObject()); + assertEquals(1, m1Tags.size()); + assertEquals("value1", m1Tags.get("key1").textValue()); + assertEquals(-1.0, 
m1.get("value").doubleValue(), 0.0); + + JsonNode m3 = metricsList.get(1); + assertEquals("name3", m3.get("name").textValue()); + assertEquals("group3", m3.get("group").textValue()); + JsonNode m3Tags = m3.get("tags"); + assertTrue(m3Tags.isObject()); + assertEquals(1, m3Tags.size()); + assertEquals("value3", m3Tags.get("key3").textValue()); + assertEquals(3.0, m3.get("value").doubleValue(), 0.0); + + JsonNode m4 = metricsList.get(2); + assertEquals("name4", m4.get("name").textValue()); + assertEquals("group4", m4.get("group").textValue()); + JsonNode m4Tags = m4.get("tags"); + assertTrue(m4Tags.isObject()); + assertEquals(1, m4Tags.size()); + assertEquals("value4", m4Tags.get("key4").textValue()); + assertEquals("value4", m4.get("value").textValue()); + + reporter.close(); + verifyClose(); + } + + private void configure() { + Map config = new HashMap<>(); + config.put(PushHttpMetricsReporter.METRICS_URL_CONFIG, URL.toString()); + config.put(PushHttpMetricsReporter.METRICS_PERIOD_CONFIG, "5"); + reporter.configure(config); + } + + private void whenRequest(int returnStatus) throws Exception { + whenRequest(returnStatus, false); + } + + // Expect that a request is made with the given response code + private void whenRequest(int returnStatus, boolean readResponse) throws Exception { + when(PushHttpMetricsReporter.newHttpConnection(URL)).thenReturn(httpReq); + when(httpReq.getOutputStream()).thenReturn(httpOut); + when(httpReq.getResponseCode()).thenReturn(returnStatus); + if (readResponse) + whenReadResponse(); + } + + private void assertPayloadHasClientInfo(JsonNode payload) throws UnknownHostException { + // Should contain client info... + JsonNode client = payload.get("client"); + assertTrue(client.isObject()); + assertEquals(InetAddress.getLocalHost().getCanonicalHostName(), client.get("host").textValue()); + assertEquals("", client.get("client_id").textValue()); + assertEquals(time.milliseconds(), client.get("time").longValue()); + } + + private void whenReadResponse() { + when(httpReq.getErrorStream()).thenReturn(httpErr); + when(PushHttpMetricsReporter.readResponse(httpErr)).thenReturn("error response message"); + } + + private void whenClose() throws Exception { + when(executor.awaitTermination(anyLong(), any())).thenReturn(true); + } + + private void verifyClose() throws InterruptedException { + InOrder inOrder = inOrder(executor); + inOrder.verify(executor).shutdown(); + inOrder.verify(executor).awaitTermination(30L, TimeUnit.SECONDS); + } + + private void verifyConfigure() { + verify(executor).scheduleAtFixedRate(reportRunnableCaptor.capture(), + eq(5L), eq(5L), eq(TimeUnit.SECONDS)); + } + + private void verifyResponse() throws IOException { + verify(httpReq).setRequestMethod("POST"); + verify(httpReq).setDoInput(true); + verify(httpReq).setRequestProperty("Content-Type", "application/json"); + verify(httpReq).setRequestProperty(eq("Content-Length"), anyString()); + verify(httpReq).setRequestProperty("Accept", "*/*"); + verify(httpReq).setUseCaches(false); + verify(httpReq).setDoOutput(true); + verify(httpReq).disconnect(); + + verify(httpOut).write(httpPayloadCaptor.capture()); + verify(httpOut).flush(); + verify(httpOut).close(); + } + + static class ImmutableValue implements Gauge { + private final T value; + + public ImmutableValue(T value) { + this.value = value; + } + + @Override + public T value(MetricConfig config, long now) { + return value; + } + } +} diff --git a/tools/src/test/java/org/apache/kafka/tools/TransactionsCommandTest.java 
b/tools/src/test/java/org/apache/kafka/tools/TransactionsCommandTest.java new file mode 100644 index 0000000..4b65dc8 --- /dev/null +++ b/tools/src/test/java/org/apache/kafka/tools/TransactionsCommandTest.java @@ -0,0 +1,1074 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.tools; + +import org.apache.kafka.clients.admin.AbortTransactionResult; +import org.apache.kafka.clients.admin.AbortTransactionSpec; +import org.apache.kafka.clients.admin.Admin; +import org.apache.kafka.clients.admin.DescribeProducersOptions; +import org.apache.kafka.clients.admin.DescribeProducersResult; +import org.apache.kafka.clients.admin.DescribeProducersResult.PartitionProducerState; +import org.apache.kafka.clients.admin.DescribeTopicsResult; +import org.apache.kafka.clients.admin.DescribeTransactionsResult; +import org.apache.kafka.clients.admin.ListTopicsOptions; +import org.apache.kafka.clients.admin.ListTopicsResult; +import org.apache.kafka.clients.admin.ListTransactionsOptions; +import org.apache.kafka.clients.admin.ListTransactionsResult; +import org.apache.kafka.clients.admin.ProducerState; +import org.apache.kafka.clients.admin.TopicDescription; +import org.apache.kafka.clients.admin.TransactionDescription; +import org.apache.kafka.clients.admin.TransactionListing; +import org.apache.kafka.clients.admin.TransactionState; +import org.apache.kafka.common.KafkaFuture; +import org.apache.kafka.common.Node; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.TopicPartitionInfo; +import org.apache.kafka.common.errors.TransactionalIdNotFoundException; +import org.apache.kafka.common.internals.KafkaFutureImpl; +import org.apache.kafka.common.utils.Exit; +import org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.common.utils.Utils; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; +import org.mockito.Mockito; + +import java.io.BufferedReader; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.InputStreamReader; +import java.io.PrintStream; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.OptionalInt; +import java.util.OptionalLong; +import java.util.Set; +import java.util.concurrent.TimeUnit; + +import static java.util.Arrays.asList; +import static java.util.Collections.emptyMap; +import static java.util.Collections.singleton; +import static 
java.util.Collections.singletonList; +import static java.util.Collections.singletonMap; +import static org.apache.kafka.common.KafkaFuture.completedFuture; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class TransactionsCommandTest { + + private final MockExitProcedure exitProcedure = new MockExitProcedure(); + private final ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); + private final PrintStream out = new PrintStream(outputStream); + private final MockTime time = new MockTime(); + private final Admin admin = Mockito.mock(Admin.class); + + @BeforeEach + public void setupExitProcedure() { + Exit.setExitProcedure(exitProcedure); + } + + @AfterEach + public void resetExitProcedure() { + Exit.resetExitProcedure(); + } + + @Test + public void testDescribeProducersTopicRequired() throws Exception { + assertCommandFailure(new String[]{ + "--bootstrap-server", + "localhost:9092", + "describe-producers", + "--partition", + "0" + }); + } + + @Test + public void testDescribeProducersPartitionRequired() throws Exception { + assertCommandFailure(new String[]{ + "--bootstrap-server", + "localhost:9092", + "describe-producers", + "--topic", + "foo" + }); + } + + @Test + public void testDescribeProducersLeader() throws Exception { + TopicPartition topicPartition = new TopicPartition("foo", 5); + String[] args = new String[] { + "--bootstrap-server", + "localhost:9092", + "describe-producers", + "--topic", + topicPartition.topic(), + "--partition", + String.valueOf(topicPartition.partition()) + }; + + testDescribeProducers(topicPartition, args, new DescribeProducersOptions()); + } + + @Test + public void testDescribeProducersSpecificReplica() throws Exception { + TopicPartition topicPartition = new TopicPartition("foo", 5); + int brokerId = 5; + + String[] args = new String[] { + "--bootstrap-server", + "localhost:9092", + "describe-producers", + "--topic", + topicPartition.topic(), + "--partition", + String.valueOf(topicPartition.partition()), + "--broker-id", + String.valueOf(brokerId) + }; + + testDescribeProducers(topicPartition, args, new DescribeProducersOptions().brokerId(brokerId)); + } + + private void testDescribeProducers( + TopicPartition topicPartition, + String[] args, + DescribeProducersOptions expectedOptions + ) throws Exception { + DescribeProducersResult describeResult = Mockito.mock(DescribeProducersResult.class); + KafkaFuture describeFuture = completedFuture( + new PartitionProducerState(asList( + new ProducerState(12345L, 15, 1300, 1599509565L, + OptionalInt.of(20), OptionalLong.of(990)), + new ProducerState(98765L, 30, 2300, 1599509599L, + OptionalInt.empty(), OptionalLong.empty()) + ))); + + + Mockito.when(describeResult.partitionResult(topicPartition)).thenReturn(describeFuture); + Mockito.when(admin.describeProducers(singleton(topicPartition), expectedOptions)).thenReturn(describeResult); + + execute(args); + assertNormalExit(); + + List> table = readOutputAsTable(); + assertEquals(3, table.size()); + + List expectedHeaders = asList(TransactionsCommand.DescribeProducersCommand.HEADERS); + assertEquals(expectedHeaders, table.get(0)); + + Set> expectedRows = Utils.mkSet( + asList("12345", "15", "20", "1300", "1599509565", "990"), + asList("98765", "30", "-1", "2300", "1599509599", "None") + ); + assertEquals(expectedRows, new HashSet<>(table.subList(1, table.size()))); + } + + @Test + public void testListTransactions() throws Exception { + String[] args = new String[] { + 
"--bootstrap-server", + "localhost:9092", + "list" + }; + + Map> transactions = new HashMap<>(); + transactions.put(0, asList( + new TransactionListing("foo", 12345L, TransactionState.ONGOING), + new TransactionListing("bar", 98765L, TransactionState.PREPARE_ABORT) + )); + transactions.put(1, singletonList( + new TransactionListing("baz", 13579L, TransactionState.COMPLETE_COMMIT) + )); + + expectListTransactions(transactions); + + execute(args); + assertNormalExit(); + + List> table = readOutputAsTable(); + assertEquals(4, table.size()); + + // Assert expected headers + List expectedHeaders = asList(TransactionsCommand.ListTransactionsCommand.HEADERS); + assertEquals(expectedHeaders, table.get(0)); + + Set> expectedRows = Utils.mkSet( + asList("foo", "0", "12345", "Ongoing"), + asList("bar", "0", "98765", "PrepareAbort"), + asList("baz", "1", "13579", "CompleteCommit") + ); + assertEquals(expectedRows, new HashSet<>(table.subList(1, table.size()))); + } + + @Test + public void testDescribeTransactionsTransactionalIdRequired() throws Exception { + assertCommandFailure(new String[]{ + "--bootstrap-server", + "localhost:9092", + "describe" + }); + } + + @Test + public void testDescribeTransaction() throws Exception { + String transactionalId = "foo"; + String[] args = new String[] { + "--bootstrap-server", + "localhost:9092", + "describe", + "--transactional-id", + transactionalId + }; + + DescribeTransactionsResult describeResult = Mockito.mock(DescribeTransactionsResult.class); + + int coordinatorId = 5; + long transactionStartTime = time.milliseconds(); + + KafkaFuture describeFuture = completedFuture( + new TransactionDescription( + coordinatorId, + TransactionState.ONGOING, + 12345L, + 15, + 10000, + OptionalLong.of(transactionStartTime), + singleton(new TopicPartition("bar", 0)) + )); + + Mockito.when(describeResult.description(transactionalId)).thenReturn(describeFuture); + Mockito.when(admin.describeTransactions(singleton(transactionalId))).thenReturn(describeResult); + + // Add a little time so that we can see a positive transaction duration in the output + time.sleep(5000); + + execute(args); + assertNormalExit(); + + List> table = readOutputAsTable(); + assertEquals(2, table.size()); + + List expectedHeaders = asList(TransactionsCommand.DescribeTransactionsCommand.HEADERS); + assertEquals(expectedHeaders, table.get(0)); + + List expectedRow = asList( + String.valueOf(coordinatorId), + transactionalId, + "12345", + "15", + "Ongoing", + "10000", + String.valueOf(transactionStartTime), + "5000", + "bar-0" + ); + assertEquals(expectedRow, table.get(1)); + } + + @Test + public void testDescribeTransactionsStartOffsetOrProducerIdRequired() throws Exception { + assertCommandFailure(new String[]{ + "--bootstrap-server", + "localhost:9092", + "abort", + "--topic", + "foo", + "--partition", + "0" + }); + } + + @Test + public void testDescribeTransactionsTopicRequired() throws Exception { + assertCommandFailure(new String[]{ + "--bootstrap-server", + "localhost:9092", + "abort", + "--partition", + "0", + "--start-offset", + "9990" + }); + } + + @Test + public void testDescribeTransactionsPartitionRequired() throws Exception { + assertCommandFailure(new String[]{ + "--bootstrap-server", + "localhost:9092", + "abort", + "--topic", + "foo", + "--start-offset", + "9990" + }); + } + + @Test + public void testDescribeTransactionsProducerEpochRequiredWithProducerId() throws Exception { + assertCommandFailure(new String[]{ + "--bootstrap-server", + "localhost:9092", + "abort", + "--topic", + "foo", + 
"--partition", + "0", + "--producer-id", + "12345" + }); + } + + @Test + public void testDescribeTransactionsCoordinatorEpochRequiredWithProducerId() throws Exception { + assertCommandFailure(new String[]{ + "--bootstrap-server", + "localhost:9092", + "abort", + "--topic", + "foo", + "--partition", + "0", + "--producer-id", + "12345", + "--producer-epoch", + "15" + }); + } + + @Test + public void testNewBrokerAbortTransaction() throws Exception { + TopicPartition topicPartition = new TopicPartition("foo", 5); + long startOffset = 9173; + long producerId = 12345L; + short producerEpoch = 15; + int coordinatorEpoch = 76; + + String[] args = new String[] { + "--bootstrap-server", + "localhost:9092", + "abort", + "--topic", + topicPartition.topic(), + "--partition", + String.valueOf(topicPartition.partition()), + "--start-offset", + String.valueOf(startOffset) + }; + + DescribeProducersResult describeResult = Mockito.mock(DescribeProducersResult.class); + KafkaFuture describeFuture = completedFuture( + new PartitionProducerState(singletonList( + new ProducerState(producerId, producerEpoch, 1300, 1599509565L, + OptionalInt.of(coordinatorEpoch), OptionalLong.of(startOffset)) + ))); + + AbortTransactionResult abortTransactionResult = Mockito.mock(AbortTransactionResult.class); + KafkaFuture abortFuture = completedFuture(null); + AbortTransactionSpec expectedAbortSpec = new AbortTransactionSpec( + topicPartition, producerId, producerEpoch, coordinatorEpoch); + + Mockito.when(describeResult.partitionResult(topicPartition)).thenReturn(describeFuture); + Mockito.when(admin.describeProducers(singleton(topicPartition))).thenReturn(describeResult); + + Mockito.when(abortTransactionResult.all()).thenReturn(abortFuture); + Mockito.when(admin.abortTransaction(expectedAbortSpec)).thenReturn(abortTransactionResult); + + execute(args); + assertNormalExit(); + } + + @ParameterizedTest + @ValueSource(ints = {29, -1}) + public void testOldBrokerAbortTransactionWithUnknownCoordinatorEpoch(int coordinatorEpoch) throws Exception { + TopicPartition topicPartition = new TopicPartition("foo", 5); + long producerId = 12345L; + short producerEpoch = 15; + + String[] args = new String[] { + "--bootstrap-server", + "localhost:9092", + "abort", + "--topic", + topicPartition.topic(), + "--partition", + String.valueOf(topicPartition.partition()), + "--producer-id", + String.valueOf(producerId), + "--producer-epoch", + String.valueOf(producerEpoch), + "--coordinator-epoch", + String.valueOf(coordinatorEpoch) + }; + + AbortTransactionResult abortTransactionResult = Mockito.mock(AbortTransactionResult.class); + KafkaFuture abortFuture = completedFuture(null); + + final int expectedCoordinatorEpoch; + if (coordinatorEpoch < 0) { + expectedCoordinatorEpoch = 0; + } else { + expectedCoordinatorEpoch = coordinatorEpoch; + } + + AbortTransactionSpec expectedAbortSpec = new AbortTransactionSpec( + topicPartition, producerId, producerEpoch, expectedCoordinatorEpoch); + + Mockito.when(abortTransactionResult.all()).thenReturn(abortFuture); + Mockito.when(admin.abortTransaction(expectedAbortSpec)).thenReturn(abortTransactionResult); + + execute(args); + assertNormalExit(); + } + + @Test + public void testFindHangingRequiresEitherBrokerIdOrTopic() throws Exception { + assertCommandFailure(new String[]{ + "--bootstrap-server", + "localhost:9092", + "find-hanging" + }); + } + + @Test + public void testFindHangingRequiresTopicIfPartitionIsSpecified() throws Exception { + assertCommandFailure(new String[]{ + "--bootstrap-server", + 
"localhost:9092", + "find-hanging", + "--broker-id", + "0", + "--partition", + "5" + }); + } + + private void expectListTransactions( + Map> listingsByBroker + ) { + expectListTransactions(new ListTransactionsOptions(), listingsByBroker); + } + + private void expectListTransactions( + ListTransactionsOptions options, + Map> listingsByBroker + ) { + ListTransactionsResult listResult = Mockito.mock(ListTransactionsResult.class); + Mockito.when(admin.listTransactions(options)).thenReturn(listResult); + + List allListings = new ArrayList<>(); + listingsByBroker.values().forEach(allListings::addAll); + + Mockito.when(listResult.all()).thenReturn(completedFuture(allListings)); + Mockito.when(listResult.allByBrokerId()).thenReturn(completedFuture(listingsByBroker)); + } + + private void expectDescribeProducers( + TopicPartition topicPartition, + long producerId, + short producerEpoch, + long lastTimestamp, + OptionalInt coordinatorEpoch, + OptionalLong txnStartOffset + ) { + PartitionProducerState partitionProducerState = new PartitionProducerState(singletonList( + new ProducerState( + producerId, + producerEpoch, + 500, + lastTimestamp, + coordinatorEpoch, + txnStartOffset + ) + )); + + DescribeProducersResult result = Mockito.mock(DescribeProducersResult.class); + Mockito.when(result.all()).thenReturn( + completedFuture(singletonMap(topicPartition, partitionProducerState)) + ); + + Mockito.when(admin.describeProducers( + Collections.singletonList(topicPartition), + new DescribeProducersOptions() + )).thenReturn(result); + } + + private void expectDescribeTransactions( + Map descriptions + ) { + DescribeTransactionsResult result = Mockito.mock(DescribeTransactionsResult.class); + descriptions.forEach((transactionalId, description) -> { + Mockito.when(result.description(transactionalId)) + .thenReturn(completedFuture(description)); + }); + Mockito.when(result.all()).thenReturn(completedFuture(descriptions)); + Mockito.when(admin.describeTransactions(descriptions.keySet())).thenReturn(result); + } + + private void expectListTopics( + Set topics + ) { + ListTopicsResult result = Mockito.mock(ListTopicsResult.class); + Mockito.when(result.names()).thenReturn(completedFuture(topics)); + ListTopicsOptions listOptions = new ListTopicsOptions().listInternal(true); + Mockito.when(admin.listTopics(listOptions)).thenReturn(result); + } + + private void expectDescribeTopics( + Map descriptions + ) { + DescribeTopicsResult result = Mockito.mock(DescribeTopicsResult.class); + Mockito.when(result.allTopicNames()).thenReturn(completedFuture(descriptions)); + Mockito.when(admin.describeTopics(new ArrayList<>(descriptions.keySet()))).thenReturn(result); + } + + @Test + public void testFindHangingLookupTopicPartitionsForBroker() throws Exception { + int brokerId = 5; + + String[] args = new String[]{ + "--bootstrap-server", + "localhost:9092", + "find-hanging", + "--broker-id", + String.valueOf(brokerId) + }; + + String topic = "foo"; + expectListTopics(singleton(topic)); + + Node node0 = new Node(0, "localhost", 9092); + Node node1 = new Node(1, "localhost", 9093); + Node node5 = new Node(5, "localhost", 9097); + + TopicPartitionInfo partition0 = new TopicPartitionInfo( + 0, + node0, + Arrays.asList(node0, node1), + Arrays.asList(node0, node1) + ); + TopicPartitionInfo partition1 = new TopicPartitionInfo( + 1, + node1, + Arrays.asList(node1, node5), + Arrays.asList(node1, node5) + ); + + TopicDescription description = new TopicDescription( + topic, + false, + Arrays.asList(partition0, partition1) + ); + 
expectDescribeTopics(singletonMap(topic, description)); + + DescribeProducersResult result = Mockito.mock(DescribeProducersResult.class); + Mockito.when(result.all()).thenReturn(completedFuture(emptyMap())); + + Mockito.when(admin.describeProducers( + Collections.singletonList(new TopicPartition(topic, 1)), + new DescribeProducersOptions().brokerId(brokerId) + )).thenReturn(result); + + execute(args); + assertNormalExit(); + assertNoHangingTransactions(); + } + + @Test + public void testFindHangingLookupTopicAndBrokerId() throws Exception { + int brokerId = 5; + String topic = "foo"; + + String[] args = new String[]{ + "--bootstrap-server", + "localhost:9092", + "find-hanging", + "--broker-id", + String.valueOf(brokerId), + "--topic", + topic + }; + + Node node0 = new Node(0, "localhost", 9092); + Node node1 = new Node(1, "localhost", 9093); + Node node5 = new Node(5, "localhost", 9097); + + TopicPartitionInfo partition0 = new TopicPartitionInfo( + 0, + node0, + Arrays.asList(node0, node1), + Arrays.asList(node0, node1) + ); + TopicPartitionInfo partition1 = new TopicPartitionInfo( + 1, + node1, + Arrays.asList(node1, node5), + Arrays.asList(node1, node5) + ); + + TopicDescription description = new TopicDescription( + topic, + false, + Arrays.asList(partition0, partition1) + ); + expectDescribeTopics(singletonMap(topic, description)); + + DescribeProducersResult result = Mockito.mock(DescribeProducersResult.class); + Mockito.when(result.all()).thenReturn(completedFuture(emptyMap())); + + Mockito.when(admin.describeProducers( + Collections.singletonList(new TopicPartition(topic, 1)), + new DescribeProducersOptions().brokerId(brokerId) + )).thenReturn(result); + + execute(args); + assertNormalExit(); + assertNoHangingTransactions(); + } + + @Test + public void testFindHangingLookupTopicPartitionsForTopic() throws Exception { + String topic = "foo"; + + String[] args = new String[]{ + "--bootstrap-server", + "localhost:9092", + "find-hanging", + "--topic", + topic + }; + + Node node0 = new Node(0, "localhost", 9092); + Node node1 = new Node(1, "localhost", 9093); + Node node5 = new Node(5, "localhost", 9097); + + TopicPartitionInfo partition0 = new TopicPartitionInfo( + 0, + node0, + Arrays.asList(node0, node1), + Arrays.asList(node0, node1) + ); + TopicPartitionInfo partition1 = new TopicPartitionInfo( + 1, + node1, + Arrays.asList(node1, node5), + Arrays.asList(node1, node5) + ); + + TopicDescription description = new TopicDescription( + topic, + false, + Arrays.asList(partition0, partition1) + ); + expectDescribeTopics(singletonMap(topic, description)); + + DescribeProducersResult result = Mockito.mock(DescribeProducersResult.class); + Mockito.when(result.all()).thenReturn(completedFuture(emptyMap())); + + Mockito.when(admin.describeProducers( + Arrays.asList(new TopicPartition(topic, 0), new TopicPartition(topic, 1)), + new DescribeProducersOptions() + )).thenReturn(result); + + execute(args); + assertNormalExit(); + assertNoHangingTransactions(); + } + + private void assertNoHangingTransactions() throws Exception { + List> table = readOutputAsTable(); + assertEquals(1, table.size()); + + List expectedHeaders = asList(TransactionsCommand.FindHangingTransactionsCommand.HEADERS); + assertEquals(expectedHeaders, table.get(0)); + } + + @Test + public void testFindHangingSpecifiedTopicPartition() throws Exception { + TopicPartition topicPartition = new TopicPartition("foo", 5); + + String[] args = new String[]{ + "--bootstrap-server", + "localhost:9092", + "find-hanging", + "--topic", + 
topicPartition.topic(), + "--partition", + String.valueOf(topicPartition.partition()) + }; + + long producerId = 132L; + short producerEpoch = 5; + long lastTimestamp = time.milliseconds(); + OptionalInt coordinatorEpoch = OptionalInt.of(19); + OptionalLong txnStartOffset = OptionalLong.of(29384L); + + expectDescribeProducers( + topicPartition, + producerId, + producerEpoch, + lastTimestamp, + coordinatorEpoch, + txnStartOffset + ); + + execute(args); + assertNormalExit(); + + List> table = readOutputAsTable(); + assertEquals(1, table.size()); + + List expectedHeaders = asList(TransactionsCommand.FindHangingTransactionsCommand.HEADERS); + assertEquals(expectedHeaders, table.get(0)); + } + + @Test + public void testFindHangingNoMappedTransactionalId() throws Exception { + TopicPartition topicPartition = new TopicPartition("foo", 5); + + String[] args = new String[]{ + "--bootstrap-server", + "localhost:9092", + "find-hanging", + "--topic", + topicPartition.topic(), + "--partition", + String.valueOf(topicPartition.partition()) + }; + + long producerId = 132L; + short producerEpoch = 5; + long lastTimestamp = time.milliseconds() - TimeUnit.MINUTES.toMillis(60); + int coordinatorEpoch = 19; + long txnStartOffset = 29384L; + + expectDescribeProducers( + topicPartition, + producerId, + producerEpoch, + lastTimestamp, + OptionalInt.of(coordinatorEpoch), + OptionalLong.of(txnStartOffset) + ); + + expectListTransactions( + new ListTransactionsOptions().filterProducerIds(singleton(producerId)), + singletonMap(1, Collections.emptyList()) + ); + + expectDescribeTransactions(Collections.emptyMap()); + + execute(args); + assertNormalExit(); + + assertHangingTransaction( + topicPartition, + producerId, + producerEpoch, + coordinatorEpoch, + txnStartOffset, + lastTimestamp + ); + } + + @Test + public void testFindHangingWithNoTransactionDescription() throws Exception { + TopicPartition topicPartition = new TopicPartition("foo", 5); + + String[] args = new String[]{ + "--bootstrap-server", + "localhost:9092", + "find-hanging", + "--topic", + topicPartition.topic(), + "--partition", + String.valueOf(topicPartition.partition()) + }; + + long producerId = 132L; + short producerEpoch = 5; + long lastTimestamp = time.milliseconds() - TimeUnit.MINUTES.toMillis(60); + int coordinatorEpoch = 19; + long txnStartOffset = 29384L; + + expectDescribeProducers( + topicPartition, + producerId, + producerEpoch, + lastTimestamp, + OptionalInt.of(coordinatorEpoch), + OptionalLong.of(txnStartOffset) + ); + + String transactionalId = "bar"; + TransactionListing listing = new TransactionListing( + transactionalId, + producerId, + TransactionState.ONGOING + ); + + expectListTransactions( + new ListTransactionsOptions().filterProducerIds(singleton(producerId)), + singletonMap(1, Collections.singletonList(listing)) + ); + + DescribeTransactionsResult result = Mockito.mock(DescribeTransactionsResult.class); + Mockito.when(result.description(transactionalId)) + .thenReturn(failedFuture(new TransactionalIdNotFoundException(transactionalId + " not found"))); + Mockito.when(admin.describeTransactions(singleton(transactionalId))).thenReturn(result); + + execute(args); + assertNormalExit(); + + assertHangingTransaction( + topicPartition, + producerId, + producerEpoch, + coordinatorEpoch, + txnStartOffset, + lastTimestamp + ); + } + + private KafkaFuture failedFuture(Exception e) { + KafkaFutureImpl future = new KafkaFutureImpl<>(); + future.completeExceptionally(e); + return future; + } + + @Test + public void 
testFindHangingDoesNotFilterByTransactionInProgressWithDifferentPartitions() throws Exception { + TopicPartition topicPartition = new TopicPartition("foo", 5); + + String[] args = new String[]{ + "--bootstrap-server", + "localhost:9092", + "find-hanging", + "--topic", + topicPartition.topic(), + "--partition", + String.valueOf(topicPartition.partition()) + }; + + long producerId = 132L; + short producerEpoch = 5; + long lastTimestamp = time.milliseconds() - TimeUnit.MINUTES.toMillis(60); + int coordinatorEpoch = 19; + long txnStartOffset = 29384L; + + expectDescribeProducers( + topicPartition, + producerId, + producerEpoch, + lastTimestamp, + OptionalInt.of(coordinatorEpoch), + OptionalLong.of(txnStartOffset) + ); + + String transactionalId = "bar"; + TransactionListing listing = new TransactionListing( + transactionalId, + producerId, + TransactionState.ONGOING + ); + + expectListTransactions( + new ListTransactionsOptions().filterProducerIds(singleton(producerId)), + singletonMap(1, Collections.singletonList(listing)) + ); + + // Although there is a transaction in progress from the same + // producer epoch, it does not include the topic partition we + // found when describing producers. + TransactionDescription description = new TransactionDescription( + 1, + TransactionState.ONGOING, + producerId, + producerEpoch, + 60000, + OptionalLong.of(time.milliseconds()), + singleton(new TopicPartition("foo", 10)) + ); + + expectDescribeTransactions(singletonMap(transactionalId, description)); + + execute(args); + assertNormalExit(); + + assertHangingTransaction( + topicPartition, + producerId, + producerEpoch, + coordinatorEpoch, + txnStartOffset, + lastTimestamp + ); + } + + private void assertHangingTransaction( + TopicPartition topicPartition, + long producerId, + short producerEpoch, + int coordinatorEpoch, + long txnStartOffset, + long lastTimestamp + ) throws Exception { + List> table = readOutputAsTable(); + assertEquals(2, table.size()); + + List expectedHeaders = asList(TransactionsCommand.FindHangingTransactionsCommand.HEADERS); + assertEquals(expectedHeaders, table.get(0)); + + long durationMinutes = TimeUnit.MILLISECONDS.toMinutes(time.milliseconds() - lastTimestamp); + + List expectedRow = asList( + topicPartition.topic(), + String.valueOf(topicPartition.partition()), + String.valueOf(producerId), + String.valueOf(producerEpoch), + String.valueOf(coordinatorEpoch), + String.valueOf(txnStartOffset), + String.valueOf(lastTimestamp), + String.valueOf(durationMinutes) + ); + assertEquals(expectedRow, table.get(1)); + } + + @Test + public void testFindHangingFilterByTransactionInProgressWithSamePartition() throws Exception { + TopicPartition topicPartition = new TopicPartition("foo", 5); + + String[] args = new String[]{ + "--bootstrap-server", + "localhost:9092", + "find-hanging", + "--topic", + topicPartition.topic(), + "--partition", + String.valueOf(topicPartition.partition()) + }; + + long producerId = 132L; + short producerEpoch = 5; + long lastTimestamp = time.milliseconds() - TimeUnit.MINUTES.toMillis(60); + int coordinatorEpoch = 19; + long txnStartOffset = 29384L; + + expectDescribeProducers( + topicPartition, + producerId, + producerEpoch, + lastTimestamp, + OptionalInt.of(coordinatorEpoch), + OptionalLong.of(txnStartOffset) + ); + + String transactionalId = "bar"; + TransactionListing listing = new TransactionListing( + transactionalId, + producerId, + TransactionState.ONGOING + ); + + expectListTransactions( + new 
ListTransactionsOptions().filterProducerIds(singleton(producerId)), + singletonMap(1, Collections.singletonList(listing)) + ); + + // The coordinator shows an active transaction with the same epoch + // which includes the partition, so no hanging transaction should + // be detected. + TransactionDescription description = new TransactionDescription( + 1, + TransactionState.ONGOING, + producerId, + producerEpoch, + 60000, + OptionalLong.of(lastTimestamp), + singleton(topicPartition) + ); + + expectDescribeTransactions(singletonMap(transactionalId, description)); + + execute(args); + assertNormalExit(); + assertNoHangingTransactions(); + } + + private void execute(String[] args) throws Exception { + TransactionsCommand.execute(args, ns -> admin, out, time); + } + + private List> readOutputAsTable() throws IOException { + List> table = new ArrayList<>(); + ByteArrayInputStream inputStream = new ByteArrayInputStream(outputStream.toByteArray()); + BufferedReader reader = new BufferedReader(new InputStreamReader(inputStream)); + + while (true) { + List row = readRow(reader); + if (row == null) { + break; + } + table.add(row); + } + return table; + } + + private List readRow(BufferedReader reader) throws IOException { + String line = reader.readLine(); + if (line == null) { + return null; + } else { + return asList(line.split("\\s+")); + } + } + + private void assertNormalExit() { + assertTrue(exitProcedure.hasExited); + assertEquals(0, exitProcedure.statusCode); + } + + private void assertCommandFailure(String[] args) throws Exception { + execute(args); + assertTrue(exitProcedure.hasExited); + assertEquals(1, exitProcedure.statusCode); + } + + private static class MockExitProcedure implements Exit.Procedure { + private boolean hasExited = false; + private int statusCode; + + @Override + public void execute(int statusCode, String message) { + if (!this.hasExited) { + this.hasExited = true; + this.statusCode = statusCode; + } + } + } + +} diff --git a/tools/src/test/resources/log4j.properties b/tools/src/test/resources/log4j.properties new file mode 100644 index 0000000..abeaf1e --- /dev/null +++ b/tools/src/test/resources/log4j.properties @@ -0,0 +1,22 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
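+# Send all test logging to stdout; Kafka packages log at TRACE while Jetty is kept at INFO (see below).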
+log4j.rootLogger=TRACE, stdout + +log4j.appender.stdout=org.apache.log4j.ConsoleAppender +log4j.appender.stdout.layout=org.apache.log4j.PatternLayout +log4j.appender.stdout.layout.ConversionPattern=[%d] %p %m (%c:%L)%n + +log4j.logger.org.apache.kafka=TRACE +log4j.logger.org.eclipse.jetty=INFO diff --git a/trogdor/src/main/java/org/apache/kafka/trogdor/agent/Agent.java b/trogdor/src/main/java/org/apache/kafka/trogdor/agent/Agent.java new file mode 100644 index 0000000..9a05f90 --- /dev/null +++ b/trogdor/src/main/java/org/apache/kafka/trogdor/agent/Agent.java @@ -0,0 +1,275 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.trogdor.agent; + +import com.fasterxml.jackson.databind.node.LongNode; +import com.fasterxml.jackson.databind.node.ObjectNode; +import net.sourceforge.argparse4j.ArgumentParsers; +import net.sourceforge.argparse4j.inf.ArgumentParser; +import net.sourceforge.argparse4j.inf.ArgumentParserException; +import net.sourceforge.argparse4j.inf.Namespace; +import org.apache.kafka.common.KafkaFuture; +import org.apache.kafka.common.utils.Exit; +import org.apache.kafka.common.utils.Scheduler; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.trogdor.common.JsonUtil; +import org.apache.kafka.trogdor.common.Node; +import org.apache.kafka.trogdor.common.Platform; +import org.apache.kafka.trogdor.rest.AgentStatusResponse; +import org.apache.kafka.trogdor.rest.CreateWorkerRequest; +import org.apache.kafka.trogdor.rest.DestroyWorkerRequest; +import org.apache.kafka.trogdor.rest.JsonRestServer; +import org.apache.kafka.trogdor.rest.StopWorkerRequest; +import org.apache.kafka.trogdor.task.TaskController; +import org.apache.kafka.trogdor.task.TaskSpec; +import org.apache.kafka.trogdor.rest.UptimeResponse; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.PrintStream; +import java.util.Set; + +import static net.sourceforge.argparse4j.impl.Arguments.store; + +/** + * The Trogdor agent. + * + * The agent process runs tasks. + */ +public final class Agent { + private static final Logger log = LoggerFactory.getLogger(Agent.class); + + /** + * The default Agent port. + */ + public static final int DEFAULT_PORT = 8888; + + /** + * The workerId to use in exec mode. + */ + private static final long EXEC_WORKER_ID = 1; + + /** + * The taskId to use in exec mode. + */ + private static final String EXEC_TASK_ID = "task0"; + + /** + * The platform object to use for this agent. + */ + private final Platform platform; + + /** + * The time at which this server was started. + */ + private final long serverStartMs; + + /** + * The WorkerManager. + */ + private final WorkerManager workerManager; + + /** + * The REST server. 
+ */ + private final JsonRestServer restServer; + + private final Time time; + + /** + * Create a new Agent. + * + * @param platform The platform object to use. + * @param scheduler The scheduler to use for this Agent. + * @param restServer The REST server to use. + * @param resource The AgentRestResource to use. + */ + public Agent(Platform platform, Scheduler scheduler, + JsonRestServer restServer, AgentRestResource resource) { + this.platform = platform; + this.time = scheduler.time(); + this.serverStartMs = time.milliseconds(); + this.workerManager = new WorkerManager(platform, scheduler); + this.restServer = restServer; + resource.setAgent(this); + } + + public int port() { + return this.restServer.port(); + } + + public void beginShutdown() throws Exception { + restServer.beginShutdown(); + workerManager.beginShutdown(); + } + + public void waitForShutdown() throws Exception { + restServer.waitForShutdown(); + workerManager.waitForShutdown(); + } + + public AgentStatusResponse status() throws Exception { + return new AgentStatusResponse(serverStartMs, workerManager.workerStates()); + } + + public UptimeResponse uptime() { + return new UptimeResponse(serverStartMs, time.milliseconds()); + } + + public void createWorker(CreateWorkerRequest req) throws Throwable { + workerManager.createWorker(req.workerId(), req.taskId(), req.spec()); + } + + public void stopWorker(StopWorkerRequest req) throws Throwable { + workerManager.stopWorker(req.workerId(), false); + } + + public void destroyWorker(DestroyWorkerRequest req) throws Throwable { + workerManager.stopWorker(req.workerId(), true); + } + + /** + * Rebase the task spec time so that it is not earlier than the current time. + * This is only needed for tasks passed in with --exec. Normally, the + * controller rebases the task spec time. + */ + TaskSpec rebaseTaskSpecTime(TaskSpec spec) throws Exception { + ObjectNode node = JsonUtil.JSON_SERDE.valueToTree(spec); + node.set("startMs", new LongNode(Math.max(time.milliseconds(), spec.startMs()))); + return JsonUtil.JSON_SERDE.treeToValue(node, TaskSpec.class); + } + + /** + * Start a task on the agent, and block until it completes. + * + * @param spec The task specifiction. + * @param out The output stream to print to. + * + * @return True if the task run successfully; false otherwise. + */ + boolean exec(TaskSpec spec, PrintStream out) throws Exception { + TaskController controller = null; + try { + controller = spec.newController(EXEC_TASK_ID); + } catch (Exception e) { + out.println("Unable to create the task controller."); + e.printStackTrace(out); + return false; + } + Set nodes = controller.targetNodes(platform.topology()); + if (!nodes.contains(platform.curNode().name())) { + out.println("This task is not configured to run on this node. 
It runs on node(s): " + + Utils.join(nodes, ", ") + ", whereas this node is " + + platform.curNode().name()); + return false; + } + KafkaFuture future = null; + try { + future = workerManager.createWorker(EXEC_WORKER_ID, EXEC_TASK_ID, spec); + } catch (Throwable e) { + out.println("createWorker failed"); + e.printStackTrace(out); + return false; + } + out.println("Waiting for completion of task:" + JsonUtil.toPrettyJsonString(spec)); + String error = future.get(); + if (error == null || error.isEmpty()) { + out.println("Task succeeded with status " + + JsonUtil.toPrettyJsonString(workerManager.workerStates().get(EXEC_WORKER_ID).status())); + return true; + } else { + out.println("Task failed with status " + + JsonUtil.toPrettyJsonString(workerManager.workerStates().get(EXEC_WORKER_ID).status()) + + " and error " + error); + return false; + } + } + + public static void main(String[] args) throws Exception { + ArgumentParser parser = ArgumentParsers + .newArgumentParser("trogdor-agent") + .defaultHelp(true) + .description("The Trogdor fault injection agent"); + parser.addArgument("--agent.config", "-c") + .action(store()) + .required(true) + .type(String.class) + .dest("config") + .metavar("CONFIG") + .help("The configuration file to use."); + parser.addArgument("--node-name", "-n") + .action(store()) + .required(true) + .type(String.class) + .dest("node_name") + .metavar("NODE_NAME") + .help("The name of this node."); + parser.addArgument("--exec", "-e") + .action(store()) + .type(String.class) + .dest("task_spec") + .metavar("TASK_SPEC") + .help("Execute a single task spec and then exit. The argument is the task spec to load when starting up, or a path to it."); + Namespace res = null; + try { + res = parser.parseArgs(args); + } catch (ArgumentParserException e) { + if (args.length == 0) { + parser.printHelp(); + Exit.exit(0); + } else { + parser.handleError(e); + Exit.exit(1); + } + } + String configPath = res.getString("config"); + String nodeName = res.getString("node_name"); + String taskSpec = res.getString("task_spec"); + + Platform platform = Platform.Config.parse(nodeName, configPath); + JsonRestServer restServer = + new JsonRestServer(Node.Util.getTrogdorAgentPort(platform.curNode())); + AgentRestResource resource = new AgentRestResource(); + log.info("Starting agent process."); + final Agent agent = new Agent(platform, Scheduler.SYSTEM, restServer, resource); + restServer.start(resource); + Exit.addShutdownHook("agent-shutdown-hook", () -> { + log.warn("Running agent shutdown hook."); + try { + agent.beginShutdown(); + agent.waitForShutdown(); + } catch (Exception e) { + log.error("Got exception while running agent shutdown hook.", e); + } + }); + if (taskSpec != null) { + TaskSpec spec = null; + try { + spec = JsonUtil.objectFromCommandLineArgument(taskSpec, TaskSpec.class); + } catch (Exception e) { + System.out.println("Unable to parse the supplied task spec."); + e.printStackTrace(); + Exit.exit(1); + } + TaskSpec effectiveSpec = agent.rebaseTaskSpecTime(spec); + Exit.exit(agent.exec(effectiveSpec, System.out) ? 0 : 1); + } + agent.waitForShutdown(); + } +} diff --git a/trogdor/src/main/java/org/apache/kafka/trogdor/agent/AgentClient.java b/trogdor/src/main/java/org/apache/kafka/trogdor/agent/AgentClient.java new file mode 100644 index 0000000..0f47e92 --- /dev/null +++ b/trogdor/src/main/java/org/apache/kafka/trogdor/agent/AgentClient.java @@ -0,0 +1,326 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.trogdor.agent; + +import com.fasterxml.jackson.core.type.TypeReference; +import net.sourceforge.argparse4j.ArgumentParsers; +import net.sourceforge.argparse4j.inf.ArgumentParser; +import net.sourceforge.argparse4j.inf.Namespace; +import net.sourceforge.argparse4j.inf.Subparser; +import net.sourceforge.argparse4j.inf.Subparsers; +import org.apache.kafka.common.utils.Exit; +import org.apache.kafka.trogdor.common.JsonUtil; +import org.apache.kafka.trogdor.common.StringFormatter; +import org.apache.kafka.trogdor.rest.AgentStatusResponse; +import org.apache.kafka.trogdor.rest.CreateWorkerRequest; +import org.apache.kafka.trogdor.rest.DestroyWorkerRequest; +import org.apache.kafka.trogdor.rest.Empty; +import org.apache.kafka.trogdor.rest.JsonRestServer; +import org.apache.kafka.trogdor.rest.JsonRestServer.HttpResponse; +import org.apache.kafka.trogdor.rest.StopWorkerRequest; +import org.apache.kafka.trogdor.rest.WorkerState; +import org.apache.kafka.trogdor.task.TaskSpec; +import org.apache.kafka.trogdor.rest.UptimeResponse; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.ws.rs.core.UriBuilder; + +import java.time.OffsetDateTime; +import java.time.ZoneOffset; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Map; + +import static net.sourceforge.argparse4j.impl.Arguments.store; +import static net.sourceforge.argparse4j.impl.Arguments.storeTrue; +import static org.apache.kafka.trogdor.common.StringFormatter.dateString; +import static org.apache.kafka.trogdor.common.StringFormatter.durationString; + +/** + * A client for the Trogdor agent. + */ +public class AgentClient { + private final Logger log; + + /** + * The maximum number of tries to make. + */ + private final int maxTries; + + /** + * The URL target. 
+     */
+    private final String target;
+
+    public static class Builder {
+        private Logger log = LoggerFactory.getLogger(AgentClient.class);
+        private int maxTries = 1;
+        private String target = null;
+
+        public Builder() {
+        }
+
+        public Builder log(Logger log) {
+            this.log = log;
+            return this;
+        }
+
+        public Builder maxTries(int maxTries) {
+            this.maxTries = maxTries;
+            return this;
+        }
+
+        public Builder target(String target) {
+            this.target = target;
+            return this;
+        }
+
+        public Builder target(String host, int port) {
+            this.target = String.format("%s:%d", host, port);
+            return this;
+        }
+
+        public AgentClient build() {
+            if (target == null) {
+                throw new RuntimeException("You must specify a target.");
+            }
+            return new AgentClient(log, maxTries, target);
+        }
+    }
+
+    private AgentClient(Logger log, int maxTries, String target) {
+        this.log = log;
+        this.maxTries = maxTries;
+        this.target = target;
+    }
+
+    public String target() {
+        return target;
+    }
+
+    public int maxTries() {
+        return maxTries;
+    }
+
+    private String url(String suffix) {
+        return String.format("http://%s%s", target, suffix);
+    }
+
+    public AgentStatusResponse status() throws Exception {
+        HttpResponse<AgentStatusResponse> resp =
+            JsonRestServer.httpRequest(url("/agent/status"), "GET",
+                null, new TypeReference<AgentStatusResponse>() { }, maxTries);
+        return resp.body();
+    }
+
+    public UptimeResponse uptime() throws Exception {
+        HttpResponse<UptimeResponse> resp =
+            JsonRestServer.httpRequest(url("/agent/uptime"), "GET",
+                null, new TypeReference<UptimeResponse>() { }, maxTries);
+        return resp.body();
+    }
+
+    public void createWorker(CreateWorkerRequest request) throws Exception {
+        HttpResponse<Empty> resp =
+            JsonRestServer.httpRequest(
+                url("/agent/worker/create"), "POST",
+                request, new TypeReference<Empty>() { }, maxTries);
+        resp.body();
+    }
+
+    public void stopWorker(StopWorkerRequest request) throws Exception {
+        HttpResponse<Empty> resp =
+            JsonRestServer.httpRequest(url(
+                "/agent/worker/stop"), "PUT",
+                request, new TypeReference<Empty>() { }, maxTries);
+        resp.body();
+    }
+
+    public void destroyWorker(DestroyWorkerRequest request) throws Exception {
+        UriBuilder uriBuilder = UriBuilder.fromPath(url("/agent/worker"));
+        uriBuilder.queryParam("workerId", request.workerId());
+        HttpResponse<Empty> resp =
+            JsonRestServer.httpRequest(uriBuilder.build().toString(), "DELETE",
+                null, new TypeReference<Empty>() { }, maxTries);
+        resp.body();
+    }
+
+    public void invokeShutdown() throws Exception {
+        HttpResponse<Empty> resp =
+            JsonRestServer.httpRequest(url(
+                "/agent/shutdown"), "PUT",
+                null, new TypeReference<Empty>() { }, maxTries);
+        resp.body();
+    }
+
+    private static void addTargetArgument(ArgumentParser parser) {
+        parser.addArgument("--target", "-t")
+            .action(store())
+            .required(true)
+            .type(String.class)
+            .dest("target")
+            .metavar("TARGET")
+            .help("A colon-separated host and port pair. For example, example.com:8888");
+    }
+
+    private static void addJsonArgument(ArgumentParser parser) {
+        parser.addArgument("--json")
+            .action(storeTrue())
+            .dest("json")
+            .metavar("JSON")
+            .help("Show the full response as JSON.");
+    }
+
+    private static void addWorkerIdArgument(ArgumentParser parser, String help) {
+        parser.addArgument("--workerId")
+            .action(storeTrue())
+            .type(Long.class)
+            .dest("workerId")
+            .metavar("WORKER_ID")
+            .help(help);
+    }
+
+    public static void main(String[] args) throws Exception {
+        ArgumentParser rootParser = ArgumentParsers
+            .newArgumentParser("trogdor-agent-client")
+            .defaultHelp(true)
+            .description("The Trogdor agent client.");
+        Subparsers subParsers = rootParser.addSubparsers().
+            dest("command");
+        Subparser uptimeParser = subParsers.addParser("uptime")
+            .help("Get the agent uptime.");
+        addTargetArgument(uptimeParser);
+        addJsonArgument(uptimeParser);
+        Subparser statusParser = subParsers.addParser("status")
+            .help("Get the agent status.");
+        addTargetArgument(statusParser);
+        addJsonArgument(statusParser);
+        Subparser createWorkerParser = subParsers.addParser("createWorker")
+            .help("Create a new worker.");
+        addTargetArgument(createWorkerParser);
+        addWorkerIdArgument(createWorkerParser, "The worker ID to create.");
+        createWorkerParser.addArgument("--taskId")
+            .action(store())
+            .required(true)
+            .type(String.class)
+            .dest("taskId")
+            .metavar("TASK_ID")
+            .help("The task ID to create.");
+        createWorkerParser.addArgument("--spec", "-s")
+            .action(store())
+            .required(true)
+            .type(String.class)
+            .dest("taskSpec")
+            .metavar("TASK_SPEC")
+            .help("The task spec to create, or a path to a file containing the task spec.");
+        Subparser stopWorkerParser = subParsers.addParser("stopWorker")
+            .help("Stop a worker.");
+        addTargetArgument(stopWorkerParser);
+        addWorkerIdArgument(stopWorkerParser, "The worker ID to stop.");
+        Subparser destroyWorkerParser = subParsers.addParser("destroyWorker")
+            .help("Destroy a worker.");
+        addTargetArgument(destroyWorkerParser);
+        addWorkerIdArgument(destroyWorkerParser, "The worker ID to destroy.");
+        Subparser shutdownParser = subParsers.addParser("shutdown")
+            .help("Shut down the agent.");
+        addTargetArgument(shutdownParser);
+
+        Namespace res = rootParser.parseArgsOrFail(args);
+        String target = res.getString("target");
+        AgentClient client = new Builder().
+            maxTries(3).
+            target(target).
+            build();
+        ZoneOffset localOffset = OffsetDateTime.now().getOffset();
+        switch (res.getString("command")) {
+            case "uptime": {
+                UptimeResponse uptime = client.uptime();
+                if (res.getBoolean("json")) {
+                    System.out.println(JsonUtil.toJsonString(uptime));
+                } else {
+                    System.out.printf("Agent is running at %s.%n", target);
+                    System.out.printf("\tStart time: %s%n",
+                        dateString(uptime.serverStartMs(), localOffset));
+                    System.out.printf("\tCurrent server time: %s%n",
+                        dateString(uptime.nowMs(), localOffset));
+                    System.out.printf("\tUptime: %s%n",
+                        durationString(uptime.nowMs() - uptime.serverStartMs()));
+                }
+                break;
+            }
+            case "status": {
+                AgentStatusResponse status = client.status();
+                if (res.getBoolean("json")) {
+                    System.out.println(JsonUtil.toJsonString(status));
+                } else {
+                    System.out.printf("Agent is running at %s.%n", target);
+                    System.out.printf("\tStart time: %s%n",
+                        dateString(status.serverStartMs(), localOffset));
+                    List<List<String>> lines = new ArrayList<>();
+                    List<String> header = new ArrayList<>(
+                        Arrays.asList("WORKER_ID", "TASK_ID", "STATE", "TASK_TYPE"));
+                    lines.add(header);
+                    for (Map.Entry<Long, WorkerState> entry : status.workers().entrySet()) {
+                        List<String> cols = new ArrayList<>();
+                        cols.add(Long.toString(entry.getKey()));
+                        cols.add(entry.getValue().taskId());
+                        cols.add(entry.getValue().getClass().getSimpleName());
+                        cols.add(entry.getValue().spec().getClass().getCanonicalName());
+                        lines.add(cols);
+                    }
+                    System.out.print(StringFormatter.prettyPrintGrid(lines));
+                }
+                break;
+            }
+            case "createWorker": {
+                long workerId = res.getLong("workerId");
+                String taskId = res.getString("taskId");
+                TaskSpec taskSpec = JsonUtil.
+ objectFromCommandLineArgument(res.getString("taskSpec"), TaskSpec.class); + CreateWorkerRequest req = + new CreateWorkerRequest(workerId, taskId, taskSpec); + client.createWorker(req); + System.out.printf("Sent CreateWorkerRequest for worker %d%n.", req.workerId()); + break; + } + case "stopWorker": { + long workerId = res.getLong("workerId"); + client.stopWorker(new StopWorkerRequest(workerId)); + System.out.printf("Sent StopWorkerRequest for worker %d%n.", workerId); + break; + } + case "destroyWorker": { + long workerId = res.getLong("workerId"); + client.destroyWorker(new DestroyWorkerRequest(workerId)); + System.out.printf("Sent DestroyWorkerRequest for worker %d%n.", workerId); + break; + } + case "shutdown": { + client.invokeShutdown(); + System.out.println("Sent ShutdownRequest."); + break; + } + default: { + System.out.println("You must choose an action. Type --help for help."); + Exit.exit(1); + } + } + } +} diff --git a/trogdor/src/main/java/org/apache/kafka/trogdor/agent/AgentRestResource.java b/trogdor/src/main/java/org/apache/kafka/trogdor/agent/AgentRestResource.java new file mode 100644 index 0000000..ec3df8b --- /dev/null +++ b/trogdor/src/main/java/org/apache/kafka/trogdor/agent/AgentRestResource.java @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.trogdor.agent; + +import org.apache.kafka.trogdor.rest.AgentStatusResponse; +import org.apache.kafka.trogdor.rest.CreateWorkerRequest; +import org.apache.kafka.trogdor.rest.DestroyWorkerRequest; +import org.apache.kafka.trogdor.rest.Empty; +import org.apache.kafka.trogdor.rest.StopWorkerRequest; +import org.apache.kafka.trogdor.rest.UptimeResponse; + +import javax.servlet.ServletContext; +import javax.ws.rs.Consumes; +import javax.ws.rs.DELETE; +import javax.ws.rs.DefaultValue; +import javax.ws.rs.GET; +import javax.ws.rs.POST; +import javax.ws.rs.PUT; +import javax.ws.rs.Path; +import javax.ws.rs.Produces; +import javax.ws.rs.QueryParam; +import javax.ws.rs.core.MediaType; +import java.util.concurrent.atomic.AtomicReference; + +/** + * The REST resource for the Agent. This describes the RPCs which the agent can accept. + * + * RPCs should be idempotent. This is important because if the server's response is + * lost, the client will simply retransmit the same request. The server's response must + * be the same the second time around. + * + * We return the empty JSON object {} rather than void for RPCs that have no results. + * This ensures that if we want to add more return results later, we can do so in a + * compatible way. 
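+ *
+ * <p>As a rough sketch (the JSON field name follows the request class and is
+ * illustrative only), a call to an RPC with no results looks like:
+ * <pre>{@code
+ * PUT /agent/worker/stop
+ * {"workerId": 1}
+ *
+ * 200 OK
+ * {}
+ * }</pre>
+ * Because the operation is idempotent, retransmitting the same request after a
+ * lost response yields the same empty response.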
+ */ +@Path("/agent") +@Produces(MediaType.APPLICATION_JSON) +@Consumes(MediaType.APPLICATION_JSON) +public class AgentRestResource { + private final AtomicReference agent = new AtomicReference<>(null); + + @javax.ws.rs.core.Context + private ServletContext context; + + public void setAgent(Agent myAgent) { + agent.set(myAgent); + } + + @GET + @Path("/status") + public AgentStatusResponse getStatus() throws Throwable { + return agent().status(); + } + + @GET + @Path("/uptime") + public UptimeResponse uptime() { + return agent().uptime(); + } + + @POST + @Path("/worker/create") + public Empty createWorker(CreateWorkerRequest req) throws Throwable { + agent().createWorker(req); + return Empty.INSTANCE; + } + + @PUT + @Path("/worker/stop") + public Empty stopWorker(StopWorkerRequest req) throws Throwable { + agent().stopWorker(req); + return Empty.INSTANCE; + } + + @DELETE + @Path("/worker") + public Empty destroyWorker(@DefaultValue("0") @QueryParam("workerId") long workerId) throws Throwable { + agent().destroyWorker(new DestroyWorkerRequest(workerId)); + return Empty.INSTANCE; + } + + @PUT + @Path("/shutdown") + public Empty shutdown() throws Throwable { + agent().beginShutdown(); + return Empty.INSTANCE; + } + + private Agent agent() { + Agent myAgent = agent.get(); + if (myAgent == null) { + throw new RuntimeException("AgentRestResource has not been initialized yet."); + } + return myAgent; + } +} diff --git a/trogdor/src/main/java/org/apache/kafka/trogdor/agent/WorkerManager.java b/trogdor/src/main/java/org/apache/kafka/trogdor/agent/WorkerManager.java new file mode 100644 index 0000000..4510e1b --- /dev/null +++ b/trogdor/src/main/java/org/apache/kafka/trogdor/agent/WorkerManager.java @@ -0,0 +1,697 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.trogdor.agent; + +import org.apache.kafka.common.KafkaException; +import org.apache.kafka.common.KafkaFuture; +import org.apache.kafka.common.internals.KafkaFutureImpl; +import org.apache.kafka.common.utils.Scheduler; +import org.apache.kafka.common.utils.ThreadUtils; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.trogdor.common.Platform; +import org.apache.kafka.trogdor.rest.RequestConflictException; +import org.apache.kafka.trogdor.rest.WorkerDone; +import org.apache.kafka.trogdor.rest.WorkerRunning; +import org.apache.kafka.trogdor.rest.WorkerStarting; +import org.apache.kafka.trogdor.rest.WorkerStopping; +import org.apache.kafka.trogdor.rest.WorkerState; +import org.apache.kafka.trogdor.task.AgentWorkerStatusTracker; +import org.apache.kafka.trogdor.task.TaskSpec; +import org.apache.kafka.trogdor.task.TaskWorker; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.Map; +import java.util.TreeMap; +import java.util.concurrent.Callable; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; + +public final class WorkerManager { + private static final Logger log = LoggerFactory.getLogger(WorkerManager.class); + + /** + * The platform to use. + */ + private final Platform platform; + + /** + * The name of this node. + */ + private final String nodeName; + + /** + * The scheduler to use. + */ + private final Scheduler scheduler; + + /** + * The clock to use. + */ + private final Time time; + + /** + * A map of task IDs to Work objects. + */ + private final Map workers; + + /** + * An ExecutorService used to schedule events in the future. + */ + private final ScheduledExecutorService stateChangeExecutor; + + /** + * An ExecutorService used to clean up TaskWorkers. + */ + private final ExecutorService workerCleanupExecutor; + + /** + * An ExecutorService to help with shutting down. + */ + private final ScheduledExecutorService shutdownExecutor; + + /** + * The shutdown manager. + */ + private final ShutdownManager shutdownManager = new ShutdownManager(); + + /** + * The shutdown manager handles shutting down gracefully. + * + * We can shut down gracefully only when all the references handed out + * by the ShutdownManager has been closed, and the shutdown bit has + * been set. RPC operations hold a reference for the duration of their + * execution, and so do Workers which have not been shut down. + * This prevents us from shutting down in the middle of an RPC, or with + * workers which are still running. 
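+     *
+     * <p>As a minimal sketch of the pattern used by the RPC paths in this class
+     * (see {@code createWorker}, {@code stopWorker} and {@code workerStates}):
+     * <pre>{@code
+     * try (ShutdownManager.Reference ref = shutdownManager.takeReference()) {
+     *     // Shutdown cannot complete while this reference is open.
+     *     return stateChangeExecutor.submit(new GetWorkerStates()).get();
+     * }
+     * }</pre>
+     * Closing the reference decrements the count; once shutdown has been requested
+     * and the count reaches zero, {@code waitForQuiescence()} returns.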
+ */ + static class ShutdownManager { + private boolean shutdown = false; + private long refCount = 0; + + class Reference implements AutoCloseable { + AtomicBoolean closed = new AtomicBoolean(false); + + @Override + public void close() { + if (closed.compareAndSet(false, true)) { + synchronized (ShutdownManager.this) { + refCount--; + if (shutdown && (refCount == 0)) { + ShutdownManager.this.notifyAll(); + } + } + } + } + } + + synchronized Reference takeReference() { + if (shutdown) { + throw new KafkaException("WorkerManager is shut down."); + } + refCount++; + return new Reference(); + } + + synchronized boolean shutdown() { + if (shutdown) { + return false; + } + shutdown = true; + if (refCount == 0) { + this.notifyAll(); + } + return true; + } + + synchronized void waitForQuiescence() throws InterruptedException { + while ((!shutdown) || (refCount > 0)) { + this.wait(); + } + } + } + + WorkerManager(Platform platform, Scheduler scheduler) { + this.platform = platform; + this.nodeName = platform.curNode().name(); + this.scheduler = scheduler; + this.time = scheduler.time(); + this.workers = new HashMap<>(); + this.stateChangeExecutor = Executors.newSingleThreadScheduledExecutor( + ThreadUtils.createThreadFactory("WorkerManagerStateThread", false)); + this.workerCleanupExecutor = Executors.newCachedThreadPool( + ThreadUtils.createThreadFactory("WorkerCleanupThread%d", false)); + this.shutdownExecutor = Executors.newScheduledThreadPool(0, + ThreadUtils.createThreadFactory("WorkerManagerShutdownThread%d", false)); + } + + enum State { + STARTING, + CANCELLING, + RUNNING, + STOPPING, + DONE, + } + + /** + * A worker which is being tracked. + */ + class Worker { + /** + * The worker ID. + */ + private final long workerId; + + /** + * The task ID. + */ + private final String taskId; + + /** + * The task specification. + */ + private final TaskSpec spec; + + /** + * The work which this worker is performing. + */ + private final TaskWorker taskWorker; + + /** + * The worker status. + */ + private final AgentWorkerStatusTracker status = new AgentWorkerStatusTracker(); + + /** + * The time when this task was started. + */ + private final long startedMs; + + /** + * The work state. + */ + private State state = State.STARTING; + + /** + * The time when this task was completed, or -1 if it has not been. + */ + private long doneMs = -1; + + /** + * The worker error. + */ + private String error = ""; + + /** + * If there is a task timeout scheduled, this is a future which can + * be used to cancel it. + */ + private Future timeoutFuture = null; + + /** + * A future which is completed when the task transitions to DONE state. + */ + private KafkaFutureImpl doneFuture = null; + + /** + * A shutdown manager reference which will keep the WorkerManager + * alive for as long as this worker is alive. + */ + private ShutdownManager.Reference reference; + + /** + * Whether we should destroy the records of this worker once it stops. 
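+         * When set (see {@code StopWorker} and {@code CompleteWorker}), the worker's
+         * entry is removed from {@code workers} once it reaches DONE, rather than
+         * being kept around for later status queries.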
+ */ + private boolean mustDestroy = false; + + Worker(long workerId, String taskId, TaskSpec spec, long now) { + this.workerId = workerId; + this.taskId = taskId; + this.spec = spec; + this.taskWorker = spec.newTaskWorker(taskId); + this.startedMs = now; + this.reference = shutdownManager.takeReference(); + } + + long workerId() { + return workerId; + } + + String taskId() { + return taskId; + } + + TaskSpec spec() { + return spec; + } + + WorkerState state() { + switch (state) { + case STARTING: + return new WorkerStarting(taskId, spec); + case RUNNING: + return new WorkerRunning(taskId, spec, startedMs, status.get()); + case CANCELLING: + case STOPPING: + return new WorkerStopping(taskId, spec, startedMs, status.get()); + case DONE: + return new WorkerDone(taskId, spec, startedMs, doneMs, status.get(), error); + } + throw new RuntimeException("unreachable"); + } + + void transitionToRunning() { + state = State.RUNNING; + timeoutFuture = scheduler.schedule(stateChangeExecutor, + new StopWorker(workerId, false), + Math.max(0, spec.endMs() - time.milliseconds())); + } + + Future transitionToStopping() { + state = State.STOPPING; + if (timeoutFuture != null) { + timeoutFuture.cancel(false); + timeoutFuture = null; + } + return workerCleanupExecutor.submit(new HaltWorker(this)); + } + + void transitionToDone() { + state = State.DONE; + doneMs = time.milliseconds(); + if (reference != null) { + reference.close(); + reference = null; + } + doneFuture.complete(error); + } + + @Override + public String toString() { + return String.format("%s_%d", taskId, workerId); + } + } + + public KafkaFuture createWorker(long workerId, String taskId, TaskSpec spec) throws Throwable { + try (ShutdownManager.Reference ref = shutdownManager.takeReference()) { + final Worker worker = stateChangeExecutor. 
+ submit(new CreateWorker(workerId, taskId, spec, time.milliseconds())).get(); + if (worker.doneFuture != null) { + log.info("{}: Ignoring request to create worker {}, because there is already " + + "a worker with that id.", nodeName, workerId); + return worker.doneFuture; + } + worker.doneFuture = new KafkaFutureImpl<>(); + if (worker.spec.endMs() <= time.milliseconds()) { + log.info("{}: Will not run worker {} as it has expired.", nodeName, worker); + stateChangeExecutor.submit(new HandleWorkerHalting(worker, + "worker expired", true)); + return worker.doneFuture; + } + KafkaFutureImpl haltFuture = new KafkaFutureImpl<>(); + haltFuture.thenApply((KafkaFuture.BaseFunction) errorString -> { + if (errorString == null) + errorString = ""; + if (errorString.isEmpty()) { + log.info("{}: Worker {} is halting.", nodeName, worker); + } else { + log.info("{}: Worker {} is halting with error {}", + nodeName, worker, errorString); + } + stateChangeExecutor.submit( + new HandleWorkerHalting(worker, errorString, false)); + return null; + }); + try { + worker.taskWorker.start(platform, worker.status, haltFuture); + } catch (Exception e) { + log.info("{}: Worker {} start() exception", nodeName, worker, e); + stateChangeExecutor.submit(new HandleWorkerHalting(worker, + "worker.start() exception: " + Utils.stackTrace(e), true)); + } + stateChangeExecutor.submit(new FinishCreatingWorker(worker)); + return worker.doneFuture; + } catch (ExecutionException e) { + if (e.getCause() instanceof RequestConflictException) { + log.info("{}: request conflict while creating worker {} for task {} with spec {}.", + nodeName, workerId, taskId, spec); + } else { + log.info("{}: Error creating worker {} for task {} with spec {}", + nodeName, workerId, taskId, spec, e); + } + throw e.getCause(); + } + } + + /** + * Handles a request to create a new worker. Processed by the state change thread. + */ + class CreateWorker implements Callable { + private final long workerId; + private final String taskId; + private final TaskSpec spec; + private final long now; + + CreateWorker(long workerId, String taskId, TaskSpec spec, long now) { + this.workerId = workerId; + this.taskId = taskId; + this.spec = spec; + this.now = now; + } + + @Override + public Worker call() throws Exception { + try { + Worker worker = workers.get(workerId); + if (worker != null) { + if (!worker.taskId().equals(taskId)) { + throw new RequestConflictException("There is already a worker ID " + workerId + + " with a different task ID."); + } else if (!worker.spec().equals(spec)) { + throw new RequestConflictException("There is already a worker ID " + workerId + + " with a different task spec."); + } else { + return worker; + } + } + worker = new Worker(workerId, taskId, spec, now); + workers.put(workerId, worker); + log.info("{}: Created worker {} with spec {}", nodeName, worker, spec); + return worker; + } catch (Exception e) { + log.info("{}: unable to create worker {} for task {}, with spec {}", + nodeName, workerId, taskId, spec, e); + throw e; + } + } + } + + /** + * Finish creating a Worker. Processed by the state change thread. + */ + class FinishCreatingWorker implements Callable { + private final Worker worker; + + FinishCreatingWorker(Worker worker) { + this.worker = worker; + } + + @Override + public Void call() throws Exception { + switch (worker.state) { + case CANCELLING: + log.info("{}: Worker {} was cancelled while it was starting up. 
" + + "Transitioning to STOPPING.", nodeName, worker); + worker.transitionToStopping(); + break; + case STARTING: + log.info("{}: Worker {} is now RUNNING. Scheduled to stop in {} ms.", + nodeName, worker, worker.spec.durationMs()); + worker.transitionToRunning(); + break; + default: + break; + } + return null; + } + } + + /** + * Handles a worker halting. Processed by the state change thread. + */ + class HandleWorkerHalting implements Callable { + private final Worker worker; + private final String failure; + private final boolean startupHalt; + + HandleWorkerHalting(Worker worker, String failure, boolean startupHalt) { + this.worker = worker; + this.failure = failure; + this.startupHalt = startupHalt; + } + + @Override + public Void call() throws Exception { + if (worker.error.isEmpty()) { + worker.error = failure; + } + String verb = (worker.error.isEmpty()) ? "halting" : + "halting with error [" + worker.error + "]"; + switch (worker.state) { + case STARTING: + if (startupHalt) { + log.info("{}: Worker {} {} during startup. Transitioning to DONE.", + nodeName, worker, verb); + worker.transitionToDone(); + } else { + log.info("{}: Worker {} {} during startup. Transitioning to CANCELLING.", + nodeName, worker, verb); + worker.state = State.CANCELLING; + } + break; + case CANCELLING: + log.info("{}: Cancelling worker {} {}. ", + nodeName, worker, verb); + break; + case RUNNING: + log.info("{}: Running worker {} {}. Transitioning to STOPPING.", + nodeName, worker, verb); + worker.transitionToStopping(); + break; + case STOPPING: + log.info("{}: Stopping worker {} {}.", nodeName, worker, verb); + break; + case DONE: + log.info("{}: Can't halt worker {} because it is already DONE.", + nodeName, worker); + break; + } + return null; + } + } + + /** + * Transitions a worker to WorkerDone. Processed by the state change thread. + */ + class CompleteWorker implements Callable { + private final Worker worker; + + private final String failure; + + CompleteWorker(Worker worker, String failure) { + this.worker = worker; + this.failure = failure; + } + + @Override + public Void call() throws Exception { + if (worker.error.isEmpty() && !failure.isEmpty()) { + worker.error = failure; + } + worker.transitionToDone(); + if (worker.mustDestroy) { + log.info("{}: destroying worker {} with error {}", + nodeName, worker, worker.error); + workers.remove(worker.workerId); + } else { + log.info("{}: completed worker {} with error {}", + nodeName, worker, worker.error); + } + return null; + } + } + + public void stopWorker(long workerId, boolean mustDestroy) throws Throwable { + try (ShutdownManager.Reference ref = shutdownManager.takeReference()) { + stateChangeExecutor.submit(new StopWorker(workerId, mustDestroy)).get(); + } catch (ExecutionException e) { + throw e.getCause(); + } + } + + /** + * Stops a worker. Processed by the state change thread. 
+ */ + class StopWorker implements Callable { + private final long workerId; + private final boolean mustDestroy; + + StopWorker(long workerId, boolean mustDestroy) { + this.workerId = workerId; + this.mustDestroy = mustDestroy; + } + + @Override + public Void call() throws Exception { + Worker worker = workers.get(workerId); + if (worker == null) { + log.info("{}: Can't stop worker {} because there is no worker with that ID.", + nodeName, workerId); + return null; + } + if (mustDestroy) { + worker.mustDestroy = true; + } + switch (worker.state) { + case STARTING: + log.info("{}: Cancelling worker {} during its startup process.", + nodeName, worker); + worker.state = State.CANCELLING; + break; + case CANCELLING: + log.info("{}: Can't stop worker {}, because it is already being " + + "cancelled.", nodeName, worker); + break; + case RUNNING: + log.info("{}: Stopping running worker {}.", nodeName, worker); + worker.transitionToStopping(); + break; + case STOPPING: + log.info("{}: Can't stop worker {}, because it is already " + + "stopping.", nodeName, worker); + break; + case DONE: + if (worker.mustDestroy) { + log.info("{}: destroying worker {} with error {}", + nodeName, worker, worker.error); + workers.remove(worker.workerId); + } else { + log.debug("{}: Can't stop worker {}, because it is already done.", + nodeName, worker); + } + break; + } + return null; + } + } + + /** + * Cleans up the resources associated with a worker. Processed by the worker + * cleanup thread pool. + */ + class HaltWorker implements Callable { + private final Worker worker; + + HaltWorker(Worker worker) { + this.worker = worker; + } + + @Override + public Void call() throws Exception { + String failure = ""; + try { + worker.taskWorker.stop(platform); + } catch (Exception exception) { + log.error("{}: worker.stop() exception", nodeName, exception); + failure = exception.getMessage(); + } + stateChangeExecutor.submit(new CompleteWorker(worker, failure)); + return null; + } + } + + public TreeMap workerStates() throws Exception { + try (ShutdownManager.Reference ref = shutdownManager.takeReference()) { + return stateChangeExecutor.submit(new GetWorkerStates()).get(); + } + } + + class GetWorkerStates implements Callable> { + @Override + public TreeMap call() throws Exception { + TreeMap workerMap = new TreeMap<>(); + for (Worker worker : workers.values()) { + workerMap.put(worker.workerId(), worker.state()); + } + return workerMap; + } + } + + public void beginShutdown() throws Exception { + if (shutdownManager.shutdown()) { + shutdownExecutor.submit(new Shutdown()); + } + } + + public void waitForShutdown() throws Exception { + while (!shutdownExecutor.isShutdown()) { + shutdownExecutor.awaitTermination(1, TimeUnit.DAYS); + } + } + + class Shutdown implements Callable { + @Override + public Void call() throws Exception { + log.info("{}: Shutting down WorkerManager.", nodeName); + try { + stateChangeExecutor.submit(new DestroyAllWorkers()).get(); + log.info("{}: Waiting for shutdownManager quiescence...", nodeName); + shutdownManager.waitForQuiescence(); + workerCleanupExecutor.shutdownNow(); + stateChangeExecutor.shutdownNow(); + log.info("{}: Waiting for workerCleanupExecutor to terminate...", nodeName); + workerCleanupExecutor.awaitTermination(1, TimeUnit.DAYS); + log.info("{}: Waiting for stateChangeExecutor to terminate...", nodeName); + stateChangeExecutor.awaitTermination(1, TimeUnit.DAYS); + log.info("{}: Shutting down shutdownExecutor.", nodeName); + shutdownExecutor.shutdown(); + } catch (Exception e) { + 
log.info("{}: Caught exception while shutting down WorkerManager", nodeName, e); + throw e; + } + return null; + } + } + + /** + * Begins the process of destroying all workers. Processed by the state change thread. + */ + class DestroyAllWorkers implements Callable { + @Override + public Void call() throws Exception { + log.info("{}: Destroying all workers.", nodeName); + + // StopWorker may remove elements from the set of worker IDs. That might generate + // a ConcurrentModificationException if we were iterating over the worker ID + // set directly. Therefore, we make a copy of the worker IDs here and iterate + // over that instead. + // + // Note that there is no possible way that more worker IDs can be added while this + // callable is running, because the state change executor is single-threaded. + ArrayList workerIds = new ArrayList<>(workers.keySet()); + + for (long workerId : workerIds) { + try { + new StopWorker(workerId, true).call(); + } catch (Exception e) { + log.error("Failed to stop worker {}", workerId, e); + } + } + return null; + } + } + +} diff --git a/trogdor/src/main/java/org/apache/kafka/trogdor/basic/BasicNode.java b/trogdor/src/main/java/org/apache/kafka/trogdor/basic/BasicNode.java new file mode 100644 index 0000000..232a64d --- /dev/null +++ b/trogdor/src/main/java/org/apache/kafka/trogdor/basic/BasicNode.java @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.trogdor.basic; + +import com.fasterxml.jackson.databind.JsonNode; +import org.apache.kafka.trogdor.common.Node; + +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Iterator; +import java.util.Map; +import java.util.Objects; +import java.util.Set; + +public class BasicNode implements Node { + private final String name; + private final String hostname; + private final Map config; + private final Set tags; + + public BasicNode(String name, String hostname, Map config, + Set tags) { + this.name = name; + this.hostname = hostname; + this.config = config; + this.tags = tags; + } + + public BasicNode(String name, JsonNode root) { + this.name = name; + String hostname = "localhost"; + Set tags = Collections.emptySet(); + Map config = new HashMap<>(); + for (Iterator> iter = root.fields(); + iter.hasNext(); ) { + Map.Entry entry = iter.next(); + String key = entry.getKey(); + JsonNode node = entry.getValue(); + if (key.equals("hostname")) { + hostname = node.asText(); + } else if (key.equals("tags")) { + if (!node.isArray()) { + throw new RuntimeException("Expected the 'tags' field to be an " + + "array of strings."); + } + tags = new HashSet<>(); + for (Iterator tagIter = node.elements(); tagIter.hasNext(); ) { + JsonNode tag = tagIter.next(); + tags.add(tag.asText()); + } + } else { + config.put(key, node.asText()); + } + } + this.hostname = hostname; + this.tags = tags; + this.config = config; + } + + @Override + public String name() { + return name; + } + + @Override + public String hostname() { + return hostname; + } + + @Override + public String getConfig(String key) { + return config.get(key); + } + + @Override + public Set tags() { + return tags; + } + + @Override + public int hashCode() { + return Objects.hash(name, hostname, config, tags); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + BasicNode that = (BasicNode) o; + return Objects.equals(name, that.name) && + Objects.equals(hostname, that.hostname) && + Objects.equals(config, that.config) && + Objects.equals(tags, that.tags); + } +} diff --git a/trogdor/src/main/java/org/apache/kafka/trogdor/basic/BasicPlatform.java b/trogdor/src/main/java/org/apache/kafka/trogdor/basic/BasicPlatform.java new file mode 100644 index 0000000..6922c2e --- /dev/null +++ b/trogdor/src/main/java/org/apache/kafka/trogdor/basic/BasicPlatform.java @@ -0,0 +1,115 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.trogdor.basic; + +import com.fasterxml.jackson.databind.JsonNode; +import org.apache.kafka.common.utils.Scheduler; +import org.apache.kafka.common.utils.Shell; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.trogdor.common.Node; +import org.apache.kafka.trogdor.common.Platform; +import org.apache.kafka.trogdor.common.Topology; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; + +/** + * Defines a cluster topology + */ +public class BasicPlatform implements Platform { + private static final Logger log = LoggerFactory.getLogger(BasicPlatform.class); + + private final Node curNode; + private final BasicTopology topology; + private final Scheduler scheduler; + private final CommandRunner commandRunner; + + public interface CommandRunner { + String run(Node curNode, String[] command) throws IOException; + } + + public static class ShellCommandRunner implements CommandRunner { + @Override + public String run(Node curNode, String[] command) throws IOException { + try { + String result = Shell.execCommand(command); + log.info("RUN: {}. RESULT: [{}]", Utils.join(command, " "), result); + return result; + } catch (RuntimeException | IOException e) { + log.info("RUN: {}. ERROR: [{}]", Utils.join(command, " "), e.getMessage()); + throw e; + } + } + } + + public BasicPlatform(String curNodeName, BasicTopology topology, + Scheduler scheduler, CommandRunner commandRunner) { + this.curNode = topology.node(curNodeName); + if (this.curNode == null) { + throw new RuntimeException(String.format("No node named %s found " + + "in the cluster! Cluster nodes are: %s", curNodeName, + Utils.join(topology.nodes().keySet(), ","))); + } + this.topology = topology; + this.scheduler = scheduler; + this.commandRunner = commandRunner; + } + + public BasicPlatform(String curNodeName, Scheduler scheduler, JsonNode configRoot) { + JsonNode nodes = configRoot.get("nodes"); + if (nodes == null) { + throw new RuntimeException("Expected to find a 'nodes' field " + + "in the root JSON configuration object"); + } + this.topology = new BasicTopology(nodes); + this.scheduler = scheduler; + this.curNode = topology.node(curNodeName); + if (this.curNode == null) { + throw new RuntimeException(String.format("No node named %s found " + + "in the cluster! Cluster nodes are: %s", curNodeName, + Utils.join(topology.nodes().keySet(), ","))); + } + this.commandRunner = new ShellCommandRunner(); + } + + @Override + public String name() { + return "BasicPlatform"; + } + + @Override + public Node curNode() { + return curNode; + } + + @Override + public Topology topology() { + return topology; + } + + @Override + public Scheduler scheduler() { + return scheduler; + } + + @Override + public String runCommand(String[] command) throws IOException { + return commandRunner.run(curNode, command); + } +} diff --git a/trogdor/src/main/java/org/apache/kafka/trogdor/basic/BasicTopology.java b/trogdor/src/main/java/org/apache/kafka/trogdor/basic/BasicTopology.java new file mode 100644 index 0000000..1f29150 --- /dev/null +++ b/trogdor/src/main/java/org/apache/kafka/trogdor/basic/BasicTopology.java @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.trogdor.basic; + +import com.fasterxml.jackson.databind.JsonNode; +import org.apache.kafka.trogdor.common.Node; +import org.apache.kafka.trogdor.common.Topology; + +import java.util.Iterator; +import java.util.NavigableMap; +import java.util.TreeMap; + +public class BasicTopology implements Topology { + private final NavigableMap nodes; + + public BasicTopology(NavigableMap nodes) { + this.nodes = nodes; + } + + public BasicTopology(JsonNode configRoot) { + if (!configRoot.isObject()) { + throw new RuntimeException("Expected the 'nodes' element to be " + + "a JSON object."); + } + nodes = new TreeMap<>(); + for (Iterator iter = configRoot.fieldNames(); iter.hasNext(); ) { + String nodeName = iter.next(); + JsonNode nodeConfig = configRoot.get(nodeName); + BasicNode node = new BasicNode(nodeName, nodeConfig); + nodes.put(nodeName, node); + } + } + + @Override + public Node node(String id) { + return nodes.get(id); + } + + @Override + public NavigableMap nodes() { + return nodes; + } +} diff --git a/trogdor/src/main/java/org/apache/kafka/trogdor/common/JsonUtil.java b/trogdor/src/main/java/org/apache/kafka/trogdor/common/JsonUtil.java new file mode 100644 index 0000000..b3915e4 --- /dev/null +++ b/trogdor/src/main/java/org/apache/kafka/trogdor/common/JsonUtil.java @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.trogdor.common; + +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.core.JsonParser; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.DeserializationFeature; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.SerializationFeature; +import com.fasterxml.jackson.datatype.jdk8.Jdk8Module; + +import java.io.File; + +/** + * Utilities for working with JSON. 
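+ *
+ * <p>As a minimal usage sketch (the file path is only an illustration):
+ * <pre>{@code
+ * // Serialize a request to a compact JSON string.
+ * String json = JsonUtil.toJsonString(new StopWorkerRequest(1));
+ *
+ * // Read an object from either a raw JSON literal or a file path; the argument
+ * // is treated as a literal when its first non-whitespace character is an open brace.
+ * TaskSpec spec = JsonUtil.objectFromCommandLineArgument("./my-spec.json", TaskSpec.class);
+ * }</pre>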
+ */ +public class JsonUtil { + public static final ObjectMapper JSON_SERDE; + + static { + JSON_SERDE = new ObjectMapper(); + JSON_SERDE.configure(SerializationFeature.FAIL_ON_EMPTY_BEANS, false); + JSON_SERDE.configure(DeserializationFeature.ACCEPT_SINGLE_VALUE_AS_ARRAY, true); + JSON_SERDE.configure(JsonParser.Feature.ALLOW_COMMENTS, true); + JSON_SERDE.registerModule(new Jdk8Module()); + JSON_SERDE.setSerializationInclusion(JsonInclude.Include.NON_EMPTY); + } + + public static String toJsonString(Object object) { + try { + return JSON_SERDE.writeValueAsString(object); + } catch (JsonProcessingException e) { + throw new RuntimeException(e); + } + } + + public static String toPrettyJsonString(Object object) { + try { + return JSON_SERDE.writerWithDefaultPrettyPrinter().writeValueAsString(object); + } catch (JsonProcessingException e) { + throw new RuntimeException(e); + } + } + + /** + * Determine if a string is a JSON object literal. + * Object literals must begin with an open brace. + * + * @param input The input string. + * @return True if the string is a JSON literal. + */ + static boolean openBraceComesFirst(String input) { + for (int i = 0; i < input.length(); i++) { + char c = input.charAt(i); + if (!Character.isWhitespace(c)) { + return c == '{'; + } + } + return false; + } + + /** + * Read a JSON object from a command-line argument. This can take the form of a path to + * a file containing the JSON object, or simply the raw JSON object itself. We will assume + * that if the string is a valid JSON object, the latter is true. If you want to specify a + * file name containing an open brace, you can force it to be interpreted as a file name be + * prefixing a ./ or full path. + * + * @param argument The command-line argument. + * @param clazz The class of the object to be read. + * @param The object type. + * @return The object which we read. + */ + public static T objectFromCommandLineArgument(String argument, Class clazz) throws Exception { + if (openBraceComesFirst(argument)) { + return JSON_SERDE.readValue(argument, clazz); + } else { + return JSON_SERDE.readValue(new File(argument), clazz); + } + } +} diff --git a/trogdor/src/main/java/org/apache/kafka/trogdor/common/Node.java b/trogdor/src/main/java/org/apache/kafka/trogdor/common/Node.java new file mode 100644 index 0000000..b0c63d8 --- /dev/null +++ b/trogdor/src/main/java/org/apache/kafka/trogdor/common/Node.java @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.trogdor.common; + +import org.apache.kafka.trogdor.agent.Agent; +import org.apache.kafka.trogdor.coordinator.Coordinator; + +import java.util.Set; + +/** + * Defines a node in a cluster topology + */ +public interface Node { + public static class Util { + public static int getIntConfig(Node node, String key, int defaultVal) { + String val = node.getConfig(key); + if (val == null) { + return defaultVal; + } else { + return Integer.parseInt(val); + } + } + + public static int getTrogdorAgentPort(Node node) { + return getIntConfig(node, Platform.Config.TROGDOR_AGENT_PORT, Agent.DEFAULT_PORT); + } + + public static int getTrogdorCoordinatorPort(Node node) { + return getIntConfig(node, Platform.Config.TROGDOR_COORDINATOR_PORT, Coordinator.DEFAULT_PORT); + } + } + + /** + * Get name for this node. + */ + String name(); + + /** + * Get hostname for this node. + */ + String hostname(); + + /** + * Get the configuration value associated with the key, or null if there + * is none. + */ + String getConfig(String key); + + /** + * Get the tags for this node. + */ + Set tags(); +} diff --git a/trogdor/src/main/java/org/apache/kafka/trogdor/common/Platform.java b/trogdor/src/main/java/org/apache/kafka/trogdor/common/Platform.java new file mode 100644 index 0000000..cb20620 --- /dev/null +++ b/trogdor/src/main/java/org/apache/kafka/trogdor/common/Platform.java @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.trogdor.common; + +import com.fasterxml.jackson.databind.JsonNode; + +import java.io.File; +import java.io.IOException; + +import org.apache.kafka.common.utils.Scheduler; +import org.apache.kafka.common.utils.Utils; + +/** + * Defines a cluster topology + */ +public interface Platform { + class Config { + public static final String TROGDOR_AGENT_PORT = "trogdor.agent.port"; + + public static final String TROGDOR_COORDINATOR_PORT = "trogdor.coordinator.port"; + + public static final String TROGDOR_COORDINATOR_HEARTBEAT_MS = + "trogdor.coordinator.heartbeat.ms"; + + public static final int TROGDOR_COORDINATOR_HEARTBEAT_MS_DEFAULT = 60000; + + public static Platform parse(String curNodeName, String path) throws Exception { + JsonNode root = JsonUtil.JSON_SERDE.readTree(new File(path)); + JsonNode platformNode = root.get("platform"); + if (platformNode == null) { + throw new RuntimeException("Expected to find a 'platform' field " + + "in the root JSON configuration object"); + } + String platformName = platformNode.textValue(); + return Utils.newParameterizedInstance(platformName, + String.class, curNodeName, + Scheduler.class, Scheduler.SYSTEM, + JsonNode.class, root); + } + } + + /** + * Get name for this platform. + */ + String name(); + + /** + * Get the current node. 
+ */ + Node curNode(); + + /** + * Get the cluster topology. + */ + Topology topology(); + + /** + * Get the scheduler to use. + */ + Scheduler scheduler(); + + /** + * Run a command on this local node. + * + * Throws an exception if the command could not be run, or if the + * command returned a non-zero error status. + * + * @param command The command + * + * @return The command output. + */ + String runCommand(String[] command) throws IOException; +} diff --git a/trogdor/src/main/java/org/apache/kafka/trogdor/common/StringExpander.java b/trogdor/src/main/java/org/apache/kafka/trogdor/common/StringExpander.java new file mode 100644 index 0000000..3082a17 --- /dev/null +++ b/trogdor/src/main/java/org/apache/kafka/trogdor/common/StringExpander.java @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.trogdor.common; + +import java.util.HashSet; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +/** + * Utilities for expanding strings that have range expressions in them. + * + * For example, 'foo[1-3]' would be expaneded to foo1, foo2, foo3. + * Strings that have no range expressions will not be expanded. + */ +public class StringExpander { + private final static Pattern NUMERIC_RANGE_PATTERN = + Pattern.compile("(.*)\\[([0-9]*)\\-([0-9]*)\\](.*)"); + + public static HashSet expand(String val) { + HashSet set = new HashSet<>(); + Matcher matcher = NUMERIC_RANGE_PATTERN.matcher(val); + if (!matcher.matches()) { + set.add(val); + return set; + } + String prequel = matcher.group(1); + String rangeStart = matcher.group(2); + String rangeEnd = matcher.group(3); + String epilog = matcher.group(4); + int rangeStartInt = Integer.parseInt(rangeStart); + int rangeEndInt = Integer.parseInt(rangeEnd); + if (rangeEndInt < rangeStartInt) { + throw new RuntimeException("Invalid range: start " + rangeStartInt + + " is higher than end " + rangeEndInt); + } + for (int i = rangeStartInt; i <= rangeEndInt; i++) { + set.add(String.format("%s%d%s", prequel, i, epilog)); + } + return set; + } +} diff --git a/trogdor/src/main/java/org/apache/kafka/trogdor/common/StringFormatter.java b/trogdor/src/main/java/org/apache/kafka/trogdor/common/StringFormatter.java new file mode 100644 index 0000000..2e4a91c --- /dev/null +++ b/trogdor/src/main/java/org/apache/kafka/trogdor/common/StringFormatter.java @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.trogdor.common; + +import java.time.Duration; +import java.time.ZoneOffset; +import java.time.format.DateTimeFormatter; +import java.util.ArrayList; +import java.util.Date; +import java.util.List; + +/** + * Utilities for formatting strings. + */ +public class StringFormatter { + /** + * Pretty-print a date string. + * + * @param timeMs The time since the epoch in milliseconds. + * @param zoneOffset The time zone offset. + * @return The date string in ISO format. + */ + public static String dateString(long timeMs, ZoneOffset zoneOffset) { + return new Date(timeMs).toInstant(). + atOffset(zoneOffset). + format(DateTimeFormatter.ISO_OFFSET_DATE_TIME); + } + + /** + * Pretty-print a duration. + * + * @param periodMs The duration in milliseconds. + * @return A human-readable duration string. + */ + public static String durationString(long periodMs) { + StringBuilder bld = new StringBuilder(); + Duration duration = Duration.ofMillis(periodMs); + long hours = duration.toHours(); + if (hours > 0) { + bld.append(hours).append("h"); + duration = duration.minusHours(hours); + } + long minutes = duration.toMinutes(); + if (minutes > 0) { + bld.append(minutes).append("m"); + duration = duration.minusMinutes(minutes); + } + long seconds = duration.getSeconds(); + if ((seconds != 0) || bld.toString().isEmpty()) { + bld.append(seconds).append("s"); + } + return bld.toString(); + } + + /** + * Formats strings in a grid pattern. + * + * All entries in the same column will have the same width. + * + * @param lines A list of lines. Each line contains a list of columns. + * Each line must contain the same number of columns. + * @return The string. + */ + public static String prettyPrintGrid(List> lines) { + int numColumns = -1; + int rowIndex = 0; + for (List col : lines) { + if (numColumns == -1) { + numColumns = col.size(); + } else if (numColumns != col.size()) { + throw new RuntimeException("Expected " + numColumns + " columns in row " + + rowIndex + ", but got " + col.size()); + } + rowIndex++; + } + List widths = new ArrayList<>(numColumns); + for (int x = 0; x < numColumns; x++) { + int w = 0; + for (List cols : lines) { + w = Math.max(w, cols.get(x).length() + 1); + } + widths.add(w); + } + StringBuilder bld = new StringBuilder(); + for (int y = 0; y < lines.size(); y++) { + List cols = lines.get(y); + for (int x = 0; x < cols.size(); x++) { + String val = cols.get(x); + int minWidth = widths.get(x); + bld.append(val); + for (int i = 0; i < minWidth - val.length(); i++) { + bld.append(" "); + } + } + bld.append(String.format("%n")); + } + return bld.toString(); + } +} diff --git a/trogdor/src/main/java/org/apache/kafka/trogdor/common/Topology.java b/trogdor/src/main/java/org/apache/kafka/trogdor/common/Topology.java new file mode 100644 index 0000000..576932e --- /dev/null +++ b/trogdor/src/main/java/org/apache/kafka/trogdor/common/Topology.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.trogdor.common; + +import java.util.HashSet; +import java.util.Map; +import java.util.NavigableMap; +import java.util.Set; + +/** + * Defines a cluster topology + */ +public interface Topology { + class Util { + /** + * Get the names of agent nodes in the topology. + */ + public static Set agentNodeNames(Topology topology) { + Set set = new HashSet<>(); + for (Map.Entry entry : topology.nodes().entrySet()) { + if (entry.getValue().getConfig(Platform.Config.TROGDOR_AGENT_PORT) != null) { + set.add(entry.getKey()); + } + } + return set; + } + } + + /** + * Get the node with the given name. + */ + Node node(String id); + + /** + * Get a sorted map of node names to nodes. + */ + NavigableMap nodes(); +} diff --git a/trogdor/src/main/java/org/apache/kafka/trogdor/common/WorkerUtils.java b/trogdor/src/main/java/org/apache/kafka/trogdor/common/WorkerUtils.java new file mode 100644 index 0000000..23c0ba4 --- /dev/null +++ b/trogdor/src/main/java/org/apache/kafka/trogdor/common/WorkerUtils.java @@ -0,0 +1,347 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.trogdor.common; + +import org.apache.kafka.clients.admin.Admin; +import org.apache.kafka.clients.admin.AdminClientConfig; +import org.apache.kafka.clients.admin.DescribeTopicsOptions; +import org.apache.kafka.clients.admin.DescribeTopicsResult; +import org.apache.kafka.clients.admin.ListTopicsOptions; +import org.apache.kafka.clients.admin.ListTopicsResult; +import org.apache.kafka.clients.admin.NewTopic; +import org.apache.kafka.clients.admin.TopicDescription; +import org.apache.kafka.clients.admin.TopicListing; +import org.apache.kafka.common.KafkaException; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.TopicPartitionInfo; +import org.apache.kafka.common.errors.NotEnoughReplicasException; +import org.apache.kafka.common.errors.TimeoutException; +import org.apache.kafka.common.errors.TopicExistsException; +import org.apache.kafka.common.errors.UnknownTopicOrPartitionException; +import org.apache.kafka.common.internals.KafkaFutureImpl; +import org.apache.kafka.common.requests.CreateTopicsRequest; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.common.utils.Utils; +import org.slf4j.Logger; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Properties; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Future; +import java.util.regex.Pattern; + +/** + * Utilities for Trogdor TaskWorkers. + */ +public final class WorkerUtils { + /** + * Handle an exception in a TaskWorker. + * + * @param log The logger to use. + * @param what The component that had the exception. + * @param exception The exception. + * @param doneFuture The TaskWorker's doneFuture + * @throws KafkaException A wrapped version of the exception. + */ + public static void abort(Logger log, String what, Throwable exception, + KafkaFutureImpl doneFuture) throws KafkaException { + log.warn("{} caught an exception", what, exception); + if (exception.getMessage() == null || exception.getMessage().isEmpty()) { + doneFuture.complete(exception.getClass().getCanonicalName()); + } else { + doneFuture.complete(exception.getMessage()); + } + throw new KafkaException(exception); + } + + /** + * Convert a rate expressed per second to a rate expressed per the given period. + * + * @param perSec The per-second rate. + * @param periodMs The new period to use. + * @return The rate per period. This will never be less than 1. + */ + public static int perSecToPerPeriod(float perSec, long periodMs) { + float period = ((float) periodMs) / 1000.0f; + float perPeriod = perSec * period; + perPeriod = Math.max(1.0f, perPeriod); + return (int) perPeriod; + } + + /** + * Adds all properties from commonConf and then from clientConf to given 'props' (in + * that order, over-writing properties with the same keys). 
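+     *
+     * <p>As a minimal sketch (the key and values are placeholders):
+     * <pre>{@code
+     * Properties props = new Properties();
+     * WorkerUtils.addConfigsToProperties(props,
+     *     Collections.singletonMap("retries", "5"),     // commonConf
+     *     Collections.singletonMap("retries", "10"));   // clientConf
+     * // props.getProperty("retries") is now "10": clientConf wins on conflicts.
+     * }</pre>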
+ * @param props Properties object that may contain zero or more properties + * @param commonConf Map with common client properties + * @param clientConf Map with client properties + */ + public static void addConfigsToProperties( + Properties props, Map commonConf, Map clientConf) { + for (Map.Entry commonEntry : commonConf.entrySet()) { + props.setProperty(commonEntry.getKey(), commonEntry.getValue()); + } + for (Map.Entry entry : clientConf.entrySet()) { + props.setProperty(entry.getKey(), entry.getValue()); + } + } + + private static final int ADMIN_REQUEST_TIMEOUT = 25000; + private static final int CREATE_TOPICS_CALL_TIMEOUT = 180000; + private static final int MAX_CREATE_TOPICS_BATCH_SIZE = 10; + + //Map>> topics) throws Throwable { + + /** + * Create some Kafka topics. + * + * @param log The logger to use. + * @param bootstrapServers The bootstrap server list. + * @param commonClientConf Common client config + * @param adminClientConf AdminClient config. This config has precedence over fields in + * common client config. + * @param topics Maps topic names to partition assignments. + * @param failOnExisting If true, the method will throw TopicExistsException if one or + * more topics already exist. Otherwise, the existing topics are + * verified for number of partitions. In this case, if number of + * partitions of an existing topic does not match the requested + * number of partitions, the method throws RuntimeException. + */ + public static void createTopics( + Logger log, String bootstrapServers, Map commonClientConf, + Map adminClientConf, + Map topics, boolean failOnExisting) throws Throwable { + // this method wraps the call to createTopics() that takes admin client, so that we can + // unit test the functionality with MockAdminClient. The exception is caught and + // re-thrown so that admin client is closed when the method returns. + try (Admin adminClient + = createAdminClient(bootstrapServers, commonClientConf, adminClientConf)) { + createTopics(log, adminClient, topics, failOnExisting); + } catch (Exception e) { + log.warn("Failed to create or verify topics {}", topics, e); + throw e; + } + } + + /** + * The actual create topics functionality is separated into this method and called from the + * above method to be able to unit test with mock adminClient. + * @throws TopicExistsException if the specified topic already exists. + * @throws UnknownTopicOrPartitionException if topic creation was issued but failed to verify if it was created. + * @throws Throwable if creation of one or more topics fails (except for the cases above). + */ + static void createTopics( + Logger log, Admin adminClient, + Map topics, boolean failOnExisting) throws Throwable { + if (topics.isEmpty()) { + log.warn("Request to create topics has an empty topic list."); + return; + } + + Collection topicsExists = createTopics(log, adminClient, topics.values()); + if (!topicsExists.isEmpty()) { + if (failOnExisting) { + log.warn("Topic(s) {} already exist.", topicsExists); + throw new TopicExistsException("One or more topics already exist."); + } else { + verifyTopics(log, adminClient, topicsExists, topics, 3, 2500); + } + } + } + + /** + * Creates Kafka topics and returns a list of topics that already exist + * @param log The logger to use + * @param adminClient AdminClient + * @param topics List of topics to create + * @return Collection of topics names that already exist. + * @throws Throwable if creation of one or more topics fails (except for topic exists case). 
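+ * <p>Implementation note (based on the constants above): topics are submitted in batches of at
+ * most {@code MAX_CREATE_TOPICS_BATCH_SIZE}, and creations that fail with a timeout or
+ * not-enough-replicas error are retried (a workaround for KAFKA-6368) until
+ * {@code CREATE_TOPICS_CALL_TIMEOUT} has elapsed.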
+ */ + private static Collection createTopics(Logger log, Admin adminClient, + Collection topics) throws Throwable { + long startMs = Time.SYSTEM.milliseconds(); + int tries = 0; + List existingTopics = new ArrayList<>(); + + Map newTopics = new HashMap<>(); + for (NewTopic newTopic : topics) { + newTopics.put(newTopic.name(), newTopic); + } + List topicsToCreate = new ArrayList<>(newTopics.keySet()); + while (true) { + log.info("Attempting to create {} topics (try {})...", topicsToCreate.size(), ++tries); + Map> creations = new HashMap<>(); + while (!topicsToCreate.isEmpty()) { + List newTopicsBatch = new ArrayList<>(); + for (int i = 0; (i < MAX_CREATE_TOPICS_BATCH_SIZE) && + !topicsToCreate.isEmpty(); i++) { + String topicName = topicsToCreate.remove(0); + newTopicsBatch.add(newTopics.get(topicName)); + } + creations.putAll(adminClient.createTopics(newTopicsBatch).values()); + } + // We retry cases where the topic creation failed with a + // timeout. This is a workaround for KAFKA-6368. + for (Map.Entry> entry : creations.entrySet()) { + String topicName = entry.getKey(); + Future future = entry.getValue(); + try { + future.get(); + log.debug("Successfully created {}.", topicName); + } catch (Exception e) { + if ((e.getCause() instanceof TimeoutException) + || (e.getCause() instanceof NotEnoughReplicasException)) { + log.warn("Attempt to create topic `{}` failed: {}", topicName, + e.getCause().getMessage()); + topicsToCreate.add(topicName); + } else if (e.getCause() instanceof TopicExistsException) { + log.info("Topic {} already exists.", topicName); + existingTopics.add(topicName); + } else { + log.warn("Failed to create {}", topicName, e.getCause()); + throw e.getCause(); + } + } + } + if (topicsToCreate.isEmpty()) { + break; + } + if (Time.SYSTEM.milliseconds() > startMs + CREATE_TOPICS_CALL_TIMEOUT) { + String str = "Unable to create topic(s): " + + Utils.join(topicsToCreate, ", ") + "after " + tries + " attempt(s)"; + log.warn(str); + throw new TimeoutException(str); + } + } + return existingTopics; + } + + /** + * Verifies that topics in 'topicsToVerify' list have the same number of partitions as + * described in 'topicsInfo' + * @param log The logger to use + * @param adminClient AdminClient + * @param topicsToVerify List of topics to verify + * @param topicsInfo Map of topic name to topic description, which includes topics in + * 'topicsToVerify' list. + * @param retryCount The number of times to retry the fetching of the topics + * @param retryBackoffMs The amount of time, in milliseconds, to wait in between retries + * @throws UnknownTopicOrPartitionException If at least one topic contained in 'topicsInfo' + * does not exist after retrying. 
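+ * (Each describe attempt uses {@code ADMIN_REQUEST_TIMEOUT} and is repeated up to
+ * {@code retryCount} times, waiting {@code retryBackoffMs} ms between attempts.)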
+ * @throws RuntimeException If one or more topics have different number of partitions than + * described in 'topicsInfo' + */ + static void verifyTopics( + Logger log, Admin adminClient, + Collection topicsToVerify, Map topicsInfo, int retryCount, long retryBackoffMs) throws Throwable { + + Map topicDescriptionMap = topicDescriptions(topicsToVerify, adminClient, + retryCount, retryBackoffMs); + + for (TopicDescription desc: topicDescriptionMap.values()) { + // map will always contain the topic since all topics in 'topicsExists' are in given + // 'topics' map + int partitions = topicsInfo.get(desc.name()).numPartitions(); + if (partitions != CreateTopicsRequest.NO_NUM_PARTITIONS && desc.partitions().size() != partitions) { + String str = "Topic '" + desc.name() + "' exists, but has " + + desc.partitions().size() + " partitions, while requested " + + " number of partitions is " + partitions; + log.warn(str); + throw new RuntimeException(str); + } + } + } + + private static Map topicDescriptions(Collection topicsToVerify, + Admin adminClient, + int retryCount, long retryBackoffMs) + throws ExecutionException, InterruptedException { + UnknownTopicOrPartitionException lastException = null; + for (int i = 0; i < retryCount; i++) { + try { + DescribeTopicsResult topicsResult = adminClient.describeTopics( + topicsToVerify, new DescribeTopicsOptions().timeoutMs(ADMIN_REQUEST_TIMEOUT)); + return topicsResult.allTopicNames().get(); + } catch (ExecutionException exception) { + if (exception.getCause() instanceof UnknownTopicOrPartitionException) { + lastException = (UnknownTopicOrPartitionException) exception.getCause(); + Thread.sleep(retryBackoffMs); + } else { + throw exception; + } + } + } + throw lastException; + } + + /** + * Returns list of existing, not internal, topics/partitions that match given pattern and + * where partitions are in range [startPartition, endPartition] + * @param adminClient AdminClient + * @param topicRegex Topic regular expression to match + * @return List of topic names + * @throws Throwable If failed to get list of existing topics + */ + static Collection getMatchingTopicPartitions( + Admin adminClient, String topicRegex, int startPartition, int endPartition) + throws Throwable { + final Pattern topicNamePattern = Pattern.compile(topicRegex); + + // first get list of matching topics + List matchedTopics = new ArrayList<>(); + ListTopicsResult res = adminClient.listTopics( + new ListTopicsOptions().timeoutMs(ADMIN_REQUEST_TIMEOUT)); + Map topicListingMap = res.namesToListings().get(); + for (Map.Entry topicListingEntry: topicListingMap.entrySet()) { + if (!topicListingEntry.getValue().isInternal() + && topicNamePattern.matcher(topicListingEntry.getKey()).matches()) { + matchedTopics.add(topicListingEntry.getKey()); + } + } + + // create a list of topic/partitions + List out = new ArrayList<>(); + DescribeTopicsResult topicsResult = adminClient.describeTopics( + matchedTopics, new DescribeTopicsOptions().timeoutMs(ADMIN_REQUEST_TIMEOUT)); + Map topicDescriptionMap = topicsResult.allTopicNames().get(); + for (TopicDescription desc: topicDescriptionMap.values()) { + List partitions = desc.partitions(); + for (TopicPartitionInfo info: partitions) { + if ((info.partition() >= startPartition) && (info.partition() <= endPartition)) { + out.add(new TopicPartition(desc.name(), info.partition())); + } + } + } + return out; + } + + private static Admin createAdminClient( + String bootstrapServers, + Map commonClientConf, Map adminClientConf) { + Properties props = new 
Properties(); + props.put(AdminClientConfig.BOOTSTRAP_SERVERS_CONFIG, bootstrapServers); + props.put(AdminClientConfig.REQUEST_TIMEOUT_MS_CONFIG, ADMIN_REQUEST_TIMEOUT); + // first add common client config, and then admin client config to properties, possibly + // over-writing default or common properties. + addConfigsToProperties(props, commonClientConf, adminClientConf); + return Admin.create(props); + } +} diff --git a/trogdor/src/main/java/org/apache/kafka/trogdor/coordinator/Coordinator.java b/trogdor/src/main/java/org/apache/kafka/trogdor/coordinator/Coordinator.java new file mode 100644 index 0000000..47f80e5 --- /dev/null +++ b/trogdor/src/main/java/org/apache/kafka/trogdor/coordinator/Coordinator.java @@ -0,0 +1,185 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.trogdor.coordinator; + +import net.sourceforge.argparse4j.ArgumentParsers; +import net.sourceforge.argparse4j.inf.ArgumentParser; +import net.sourceforge.argparse4j.inf.ArgumentParserException; +import net.sourceforge.argparse4j.inf.Namespace; +import org.apache.kafka.common.utils.Exit; +import org.apache.kafka.common.utils.Scheduler; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.trogdor.common.Node; +import org.apache.kafka.trogdor.common.Platform; +import org.apache.kafka.trogdor.rest.CoordinatorStatusResponse; +import org.apache.kafka.trogdor.rest.CreateTaskRequest; +import org.apache.kafka.trogdor.rest.DestroyTaskRequest; +import org.apache.kafka.trogdor.rest.JsonRestServer; +import org.apache.kafka.trogdor.rest.StopTaskRequest; +import org.apache.kafka.trogdor.rest.TaskRequest; +import org.apache.kafka.trogdor.rest.TaskState; +import org.apache.kafka.trogdor.rest.TasksRequest; +import org.apache.kafka.trogdor.rest.TasksResponse; +import org.apache.kafka.trogdor.rest.UptimeResponse; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.concurrent.ThreadLocalRandom; + +import static net.sourceforge.argparse4j.impl.Arguments.store; + +/** + * The Trogdor coordinator. + * + * The coordinator manages the agent processes in the cluster. + */ +public final class Coordinator { + private static final Logger log = LoggerFactory.getLogger(Coordinator.class); + + public static final int DEFAULT_PORT = 8889; + + /** + * The start time of the Coordinator in milliseconds. + */ + private final long startTimeMs; + + /** + * The task manager. + */ + private final TaskManager taskManager; + + /** + * The REST server. + */ + private final JsonRestServer restServer; + + private final Time time; + + /** + * Create a new Coordinator. + * + * @param platform The platform object to use. + * @param scheduler The scheduler to use for this Coordinator. + * @param restServer The REST server to use. 
+ * @param resource The AgentRestResource to use. + */ + public Coordinator(Platform platform, Scheduler scheduler, JsonRestServer restServer, + CoordinatorRestResource resource, long firstWorkerId) { + this.time = scheduler.time(); + this.startTimeMs = time.milliseconds(); + this.taskManager = new TaskManager(platform, scheduler, firstWorkerId); + this.restServer = restServer; + resource.setCoordinator(this); + } + + public int port() { + return this.restServer.port(); + } + + public CoordinatorStatusResponse status() throws Exception { + return new CoordinatorStatusResponse(startTimeMs); + } + + public UptimeResponse uptime() { + return new UptimeResponse(startTimeMs, time.milliseconds()); + } + + public void createTask(CreateTaskRequest request) throws Throwable { + taskManager.createTask(request.id(), request.spec()); + } + + public void stopTask(StopTaskRequest request) throws Throwable { + taskManager.stopTask(request.id()); + } + + public void destroyTask(DestroyTaskRequest request) throws Throwable { + taskManager.destroyTask(request.id()); + } + + public TasksResponse tasks(TasksRequest request) throws Exception { + return taskManager.tasks(request); + } + + public TaskState task(TaskRequest request) throws Exception { + return taskManager.task(request); + } + + public void beginShutdown(boolean stopAgents) throws Exception { + restServer.beginShutdown(); + taskManager.beginShutdown(stopAgents); + } + + public void waitForShutdown() throws Exception { + restServer.waitForShutdown(); + taskManager.waitForShutdown(); + } + + public static void main(String[] args) throws Exception { + ArgumentParser parser = ArgumentParsers + .newArgumentParser("trogdor-coordinator") + .defaultHelp(true) + .description("The Trogdor fault injection coordinator"); + parser.addArgument("--coordinator.config", "-c") + .action(store()) + .required(true) + .type(String.class) + .dest("config") + .metavar("CONFIG") + .help("The configuration file to use."); + parser.addArgument("--node-name", "-n") + .action(store()) + .required(true) + .type(String.class) + .dest("node_name") + .metavar("NODE_NAME") + .help("The name of this node."); + Namespace res = null; + try { + res = parser.parseArgs(args); + } catch (ArgumentParserException e) { + if (args.length == 0) { + parser.printHelp(); + Exit.exit(0); + } else { + parser.handleError(e); + Exit.exit(1); + } + } + String configPath = res.getString("config"); + String nodeName = res.getString("node_name"); + + Platform platform = Platform.Config.parse(nodeName, configPath); + JsonRestServer restServer = new JsonRestServer( + Node.Util.getTrogdorCoordinatorPort(platform.curNode())); + CoordinatorRestResource resource = new CoordinatorRestResource(); + log.info("Starting coordinator process."); + final Coordinator coordinator = new Coordinator(platform, Scheduler.SYSTEM, + restServer, resource, ThreadLocalRandom.current().nextLong(0, Long.MAX_VALUE / 2)); + restServer.start(resource); + Exit.addShutdownHook("coordinator-shutdown-hook", () -> { + log.warn("Running coordinator shutdown hook."); + try { + coordinator.beginShutdown(false); + coordinator.waitForShutdown(); + } catch (Exception e) { + log.error("Got exception while running coordinator shutdown hook.", e); + } + }); + coordinator.waitForShutdown(); + } +}; diff --git a/trogdor/src/main/java/org/apache/kafka/trogdor/coordinator/CoordinatorClient.java b/trogdor/src/main/java/org/apache/kafka/trogdor/coordinator/CoordinatorClient.java new file mode 100644 index 0000000..078dbbc --- /dev/null +++ 
b/trogdor/src/main/java/org/apache/kafka/trogdor/coordinator/CoordinatorClient.java @@ -0,0 +1,517 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.trogdor.coordinator; + +import com.fasterxml.jackson.core.type.TypeReference; +import net.sourceforge.argparse4j.ArgumentParsers; +import net.sourceforge.argparse4j.inf.ArgumentParser; +import net.sourceforge.argparse4j.inf.MutuallyExclusiveGroup; +import net.sourceforge.argparse4j.inf.Namespace; +import net.sourceforge.argparse4j.inf.Subparser; +import net.sourceforge.argparse4j.inf.Subparsers; +import org.apache.kafka.common.utils.Exit; +import org.apache.kafka.trogdor.common.JsonUtil; +import org.apache.kafka.trogdor.common.StringFormatter; +import org.apache.kafka.trogdor.rest.CoordinatorStatusResponse; +import org.apache.kafka.trogdor.rest.CreateTaskRequest; +import org.apache.kafka.trogdor.rest.DestroyTaskRequest; +import org.apache.kafka.trogdor.rest.Empty; +import org.apache.kafka.trogdor.rest.JsonRestServer; +import org.apache.kafka.trogdor.rest.JsonRestServer.HttpResponse; +import org.apache.kafka.trogdor.rest.StopTaskRequest; +import org.apache.kafka.trogdor.rest.TaskDone; +import org.apache.kafka.trogdor.rest.TaskPending; +import org.apache.kafka.trogdor.rest.TaskRequest; +import org.apache.kafka.trogdor.rest.TaskRunning; +import org.apache.kafka.trogdor.rest.TaskStateType; +import org.apache.kafka.trogdor.rest.TaskStopping; +import org.apache.kafka.trogdor.rest.TasksRequest; +import org.apache.kafka.trogdor.rest.TaskState; +import org.apache.kafka.trogdor.rest.TasksResponse; +import org.apache.kafka.trogdor.task.TaskSpec; +import org.apache.kafka.trogdor.rest.RequestConflictException; +import org.apache.kafka.trogdor.rest.UptimeResponse; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.ws.rs.NotFoundException; +import javax.ws.rs.core.UriBuilder; + +import java.time.OffsetDateTime; +import java.time.ZoneOffset; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.TreeMap; +import java.util.regex.Pattern; +import java.util.regex.PatternSyntaxException; + +import static net.sourceforge.argparse4j.impl.Arguments.append; +import static net.sourceforge.argparse4j.impl.Arguments.store; +import static net.sourceforge.argparse4j.impl.Arguments.storeTrue; +import static org.apache.kafka.trogdor.common.StringFormatter.dateString; +import static org.apache.kafka.trogdor.common.StringFormatter.durationString; + +/** + * A client for the Trogdor coordinator. + */ +public class CoordinatorClient { + private final Logger log; + + /** + * The maximum number of tries to make. + */ + private final int maxTries; + + /** + * The URL target. 
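+ * A colon-separated host and port pair, for example {@code example.com:8889}.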
+ */ + private final String target; + + public static class Builder { + private Logger log = LoggerFactory.getLogger(CoordinatorClient.class); + private int maxTries = 1; + private String target = null; + + public Builder() { + } + + public Builder log(Logger log) { + this.log = log; + return this; + } + + public Builder maxTries(int maxTries) { + this.maxTries = maxTries; + return this; + } + + public Builder target(String target) { + this.target = target; + return this; + } + + public Builder target(String host, int port) { + this.target = String.format("%s:%d", host, port); + return this; + } + + public CoordinatorClient build() { + if (target == null) { + throw new RuntimeException("You must specify a target."); + } + return new CoordinatorClient(log, maxTries, target); + } + } + + private CoordinatorClient(Logger log, int maxTries, String target) { + this.log = log; + this.maxTries = maxTries; + this.target = target; + } + + public int maxTries() { + return maxTries; + } + + private String url(String suffix) { + return String.format("http://%s%s", target, suffix); + } + + public CoordinatorStatusResponse status() throws Exception { + HttpResponse resp = + JsonRestServer.httpRequest(url("/coordinator/status"), "GET", + null, new TypeReference() { }, maxTries); + return resp.body(); + } + + public UptimeResponse uptime() throws Exception { + HttpResponse resp = + JsonRestServer.httpRequest(url("/coordinator/uptime"), "GET", + null, new TypeReference() { }, maxTries); + return resp.body(); + } + + public void createTask(CreateTaskRequest request) throws Exception { + HttpResponse resp = + JsonRestServer.httpRequest(log, url("/coordinator/task/create"), "POST", + request, new TypeReference() { }, maxTries); + resp.body(); + } + + public void stopTask(StopTaskRequest request) throws Exception { + HttpResponse resp = + JsonRestServer.httpRequest(log, url("/coordinator/task/stop"), "PUT", + request, new TypeReference() { }, maxTries); + resp.body(); + } + + public void destroyTask(DestroyTaskRequest request) throws Exception { + UriBuilder uriBuilder = UriBuilder.fromPath(url("/coordinator/tasks")); + uriBuilder.queryParam("taskId", request.id()); + HttpResponse resp = + JsonRestServer.httpRequest(log, uriBuilder.build().toString(), "DELETE", + null, new TypeReference() { }, maxTries); + resp.body(); + } + + public TasksResponse tasks(TasksRequest request) throws Exception { + UriBuilder uriBuilder = UriBuilder.fromPath(url("/coordinator/tasks")); + uriBuilder.queryParam("taskId", request.taskIds().toArray(new Object[0])); + uriBuilder.queryParam("firstStartMs", request.firstStartMs()); + uriBuilder.queryParam("lastStartMs", request.lastStartMs()); + uriBuilder.queryParam("firstEndMs", request.firstEndMs()); + uriBuilder.queryParam("lastEndMs", request.lastEndMs()); + if (request.state().isPresent()) { + uriBuilder.queryParam("state", request.state().get().toString()); + } + HttpResponse resp = + JsonRestServer.httpRequest(log, uriBuilder.build().toString(), "GET", + null, new TypeReference() { }, maxTries); + return resp.body(); + } + + public TaskState task(TaskRequest request) throws Exception { + String uri = UriBuilder.fromPath(url("/coordinator/tasks/{taskId}")).build(request.taskId()).toString(); + HttpResponse resp = JsonRestServer.httpRequest(log, uri, "GET", + null, new TypeReference() { }, maxTries); + return resp.body(); + } + + public void shutdown() throws Exception { + HttpResponse resp = + JsonRestServer.httpRequest(log, url("/coordinator/shutdown"), "PUT", + null, new 
TypeReference() { }, maxTries); + resp.body(); + } + + private static void addTargetArgument(ArgumentParser parser) { + parser.addArgument("--target", "-t") + .action(store()) + .required(true) + .type(String.class) + .dest("target") + .metavar("TARGET") + .help("A colon-separated host and port pair. For example, example.com:8889"); + } + + private static void addJsonArgument(ArgumentParser parser) { + parser.addArgument("--json") + .action(storeTrue()) + .dest("json") + .metavar("JSON") + .help("Show the full response as JSON."); + } + + public static void main(String[] args) throws Exception { + ArgumentParser rootParser = ArgumentParsers + .newArgumentParser("trogdor-coordinator-client") + .description("The Trogdor coordinator client."); + Subparsers subParsers = rootParser.addSubparsers(). + dest("command"); + Subparser uptimeParser = subParsers.addParser("uptime") + .help("Get the coordinator uptime."); + addTargetArgument(uptimeParser); + addJsonArgument(uptimeParser); + Subparser statusParser = subParsers.addParser("status") + .help("Get the coordinator status."); + addTargetArgument(statusParser); + addJsonArgument(statusParser); + Subparser showTaskParser = subParsers.addParser("showTask") + .help("Show a coordinator task."); + addTargetArgument(showTaskParser); + addJsonArgument(showTaskParser); + showTaskParser.addArgument("--id", "-i") + .action(store()) + .required(true) + .type(String.class) + .dest("taskId") + .metavar("TASK_ID") + .help("The task ID to show."); + showTaskParser.addArgument("--verbose", "-v") + .action(storeTrue()) + .dest("verbose") + .metavar("VERBOSE") + .help("Print out everything."); + showTaskParser.addArgument("--show-status", "-S") + .action(storeTrue()) + .dest("showStatus") + .metavar("SHOW_STATUS") + .help("Show the task status."); + Subparser showTasksParser = subParsers.addParser("showTasks") + .help("Show many coordinator tasks. By default, all tasks are shown, but " + + "command-line options can be specified as filters."); + addTargetArgument(showTasksParser); + addJsonArgument(showTasksParser); + MutuallyExclusiveGroup idGroup = showTasksParser.addMutuallyExclusiveGroup(); + idGroup.addArgument("--id", "-i") + .action(append()) + .type(String.class) + .dest("taskIds") + .metavar("TASK_IDS") + .help("Show only this task ID. 
This option may be specified multiple times."); + idGroup.addArgument("--id-pattern") + .action(store()) + .type(String.class) + .dest("taskIdPattern") + .metavar("TASK_ID_PATTERN") + .help("Only display tasks which match the given ID pattern."); + showTasksParser.addArgument("--state", "-s") + .type(TaskStateType.class) + .dest("taskStateType") + .metavar("TASK_STATE_TYPE") + .help("Show only tasks in this state."); + Subparser createTaskParser = subParsers.addParser("createTask") + .help("Create a new task."); + addTargetArgument(createTaskParser); + createTaskParser.addArgument("--id", "-i") + .action(store()) + .required(true) + .type(String.class) + .dest("taskId") + .metavar("TASK_ID") + .help("The task ID to create."); + createTaskParser.addArgument("--spec", "-s") + .action(store()) + .required(true) + .type(String.class) + .dest("taskSpec") + .metavar("TASK_SPEC") + .help("The task spec to create, or a path to a file containing the task spec."); + Subparser stopTaskParser = subParsers.addParser("stopTask") + .help("Stop a task."); + addTargetArgument(stopTaskParser); + stopTaskParser.addArgument("--id", "-i") + .action(store()) + .required(true) + .type(String.class) + .dest("taskId") + .metavar("TASK_ID") + .help("The task ID to create."); + Subparser destroyTaskParser = subParsers.addParser("destroyTask") + .help("Destroy a task."); + addTargetArgument(destroyTaskParser); + destroyTaskParser.addArgument("--id", "-i") + .action(store()) + .required(true) + .type(String.class) + .dest("taskId") + .metavar("TASK_ID") + .help("The task ID to destroy."); + Subparser shutdownParser = subParsers.addParser("shutdown") + .help("Shut down the coordinator."); + addTargetArgument(shutdownParser); + + Namespace res = rootParser.parseArgsOrFail(args); + String target = res.getString("target"); + CoordinatorClient client = new Builder(). + maxTries(3). + target(target). + build(); + ZoneOffset localOffset = OffsetDateTime.now().getOffset(); + switch (res.getString("command")) { + case "uptime": { + UptimeResponse uptime = client.uptime(); + if (res.getBoolean("json")) { + System.out.println(JsonUtil.toJsonString(uptime)); + } else { + System.out.printf("Coordinator is running at %s.%n", target); + System.out.printf("\tStart time: %s%n", + dateString(uptime.serverStartMs(), localOffset)); + System.out.printf("\tCurrent server time: %s%n", + dateString(uptime.nowMs(), localOffset)); + System.out.printf("\tUptime: %s%n", + durationString(uptime.nowMs() - uptime.serverStartMs())); + } + break; + } + case "status": { + CoordinatorStatusResponse response = client.status(); + if (res.getBoolean("json")) { + System.out.println(JsonUtil.toJsonString(response)); + } else { + System.out.printf("Coordinator is running at %s.%n", target); + System.out.printf("\tStart time: %s%n", dateString(response.serverStartMs(), localOffset)); + } + break; + } + case "showTask": { + String taskId = res.getString("taskId"); + TaskRequest req = new TaskRequest(taskId); + TaskState taskState = null; + try { + taskState = client.task(req); + } catch (NotFoundException e) { + System.out.printf("Task %s was not found.%n", taskId); + Exit.exit(1); + } + if (res.getBoolean("json")) { + System.out.println(JsonUtil.toJsonString(taskState)); + } else { + System.out.printf("Task %s of type %s is %s. 
%s%n", taskId, + taskState.spec().getClass().getCanonicalName(), + taskState.stateType(), prettyPrintTaskInfo(taskState, localOffset)); + if (taskState instanceof TaskDone) { + TaskDone taskDone = (TaskDone) taskState; + if ((taskDone.error() != null) && (!taskDone.error().isEmpty())) { + System.out.printf("Error: %s%n", taskDone.error()); + } + } + if (res.getBoolean("verbose")) { + System.out.printf("Spec: %s%n%n", JsonUtil.toPrettyJsonString(taskState.spec())); + } + if (res.getBoolean("verbose") || res.getBoolean("showStatus")) { + System.out.printf("Status: %s%n%n", JsonUtil.toPrettyJsonString(taskState.status())); + } + } + break; + } + case "showTasks": { + TaskStateType taskStateType = res.get("taskStateType"); + List taskIds = new ArrayList<>(); + Pattern taskIdPattern = null; + if (res.getList("taskIds") != null) { + for (Object taskId : res.getList("taskIds")) { + taskIds.add((String) taskId); + } + } else if (res.getString("taskIdPattern") != null) { + try { + taskIdPattern = Pattern.compile(res.getString("taskIdPattern")); + } catch (PatternSyntaxException e) { + System.out.println("Invalid task ID regular expression " + res.getString("taskIdPattern")); + e.printStackTrace(); + Exit.exit(1); + } + } + TasksRequest req = new TasksRequest(taskIds, 0, 0, 0, 0, + Optional.ofNullable(taskStateType)); + TasksResponse response = client.tasks(req); + if (taskIdPattern != null) { + TreeMap filteredTasks = new TreeMap<>(); + for (Map.Entry entry : response.tasks().entrySet()) { + if (taskIdPattern.matcher(entry.getKey()).matches()) { + filteredTasks.put(entry.getKey(), entry.getValue()); + } + } + response = new TasksResponse(filteredTasks); + } + if (res.getBoolean("json")) { + System.out.println(JsonUtil.toJsonString(response)); + } else { + System.out.println(prettyPrintTasksResponse(response, localOffset)); + } + if (response.tasks().isEmpty()) { + Exit.exit(1); + } + break; + } + case "createTask": { + String taskId = res.getString("taskId"); + TaskSpec taskSpec = JsonUtil. + objectFromCommandLineArgument(res.getString("taskSpec"), TaskSpec.class); + CreateTaskRequest req = new CreateTaskRequest(taskId, taskSpec); + try { + client.createTask(req); + System.out.printf("Sent CreateTaskRequest for task %s.%n", req.id()); + } catch (RequestConflictException rce) { + System.out.printf("CreateTaskRequest for task %s got a 409 status code - " + + "a task with the same ID but a different specification already exists.%nException: %s%n", + req.id(), rce.getMessage()); + Exit.exit(1); + } + break; + } + case "stopTask": { + String taskId = res.getString("taskId"); + StopTaskRequest req = new StopTaskRequest(taskId); + client.stopTask(req); + System.out.printf("Sent StopTaskRequest for task %s.%n", taskId); + break; + } + case "destroyTask": { + String taskId = res.getString("taskId"); + DestroyTaskRequest req = new DestroyTaskRequest(taskId); + client.destroyTask(req); + System.out.printf("Sent DestroyTaskRequest for task %s.%n", taskId); + break; + } + case "shutdown": { + client.shutdown(); + System.out.println("Sent ShutdownRequest."); + break; + } + default: { + System.out.println("You must choose an action. 
Type --help for help."); + Exit.exit(1); + } + } + } + + static String prettyPrintTasksResponse(TasksResponse response, ZoneOffset zoneOffset) { + if (response.tasks().isEmpty()) { + return "No matching tasks found."; + } + List> lines = new ArrayList<>(); + List header = new ArrayList<>( + Arrays.asList("ID", "TYPE", "STATE", "INFO")); + lines.add(header); + for (Map.Entry entry : response.tasks().entrySet()) { + String taskId = entry.getKey(); + TaskState taskState = entry.getValue(); + List cols = new ArrayList<>(); + cols.add(taskId); + cols.add(taskState.spec().getClass().getCanonicalName()); + cols.add(taskState.stateType().toString()); + cols.add(prettyPrintTaskInfo(taskState, zoneOffset)); + lines.add(cols); + } + return StringFormatter.prettyPrintGrid(lines); + } + + static String prettyPrintTaskInfo(TaskState taskState, ZoneOffset zoneOffset) { + if (taskState instanceof TaskPending) { + return "Will start at " + dateString(taskState.spec().startMs(), zoneOffset); + } else if (taskState instanceof TaskRunning) { + TaskRunning runState = (TaskRunning) taskState; + return "Started " + dateString(runState.startedMs(), zoneOffset) + + "; will stop after " + durationString(taskState.spec().durationMs()); + } else if (taskState instanceof TaskStopping) { + TaskStopping stoppingState = (TaskStopping) taskState; + return "Started " + dateString(stoppingState.startedMs(), zoneOffset); + } else if (taskState instanceof TaskDone) { + TaskDone doneState = (TaskDone) taskState; + String status = null; + if (doneState.error() == null || doneState.error().isEmpty()) { + if (doneState.cancelled()) { + status = "CANCELLED"; + } else { + status = "FINISHED"; + } + } else { + status = "FAILED"; + } + return String.format("%s at %s after %s", status, + dateString(doneState.doneMs(), zoneOffset), + durationString(doneState.doneMs() - doneState.startedMs())); + } else { + throw new RuntimeException("Unknown task state type " + taskState.stateType()); + } + } +} diff --git a/trogdor/src/main/java/org/apache/kafka/trogdor/coordinator/CoordinatorRestResource.java b/trogdor/src/main/java/org/apache/kafka/trogdor/coordinator/CoordinatorRestResource.java new file mode 100644 index 0000000..337f2b4 --- /dev/null +++ b/trogdor/src/main/java/org/apache/kafka/trogdor/coordinator/CoordinatorRestResource.java @@ -0,0 +1,154 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.trogdor.coordinator; + +import org.apache.kafka.trogdor.rest.CoordinatorShutdownRequest; +import org.apache.kafka.trogdor.rest.CoordinatorStatusResponse; +import org.apache.kafka.trogdor.rest.CreateTaskRequest; +import org.apache.kafka.trogdor.rest.DestroyTaskRequest; +import org.apache.kafka.trogdor.rest.Empty; +import org.apache.kafka.trogdor.rest.StopTaskRequest; +import org.apache.kafka.trogdor.rest.TaskRequest; +import org.apache.kafka.trogdor.rest.TaskState; +import org.apache.kafka.trogdor.rest.TaskStateType; +import org.apache.kafka.trogdor.rest.TasksRequest; +import org.apache.kafka.trogdor.rest.TasksResponse; +import org.apache.kafka.trogdor.rest.UptimeResponse; + +import javax.servlet.ServletContext; +import javax.ws.rs.Consumes; +import javax.ws.rs.DELETE; +import javax.ws.rs.DefaultValue; +import javax.ws.rs.GET; +import javax.ws.rs.NotFoundException; +import javax.ws.rs.POST; +import javax.ws.rs.PUT; +import javax.ws.rs.Path; +import javax.ws.rs.PathParam; +import javax.ws.rs.Produces; +import javax.ws.rs.QueryParam; +import javax.ws.rs.core.MediaType; +import javax.ws.rs.core.Response; +import java.util.List; +import java.util.Optional; +import java.util.concurrent.atomic.AtomicReference; + +/** + * The REST resource for the Coordinator. This describes the RPCs which the coordinator + * can accept. + * + * RPCs should be idempotent. This is important because if the server's response is + * lost, the client will simply retransmit the same request. The server's response must + * be the same the second time around. + * + * We return the empty JSON object {} rather than void for RPCs that have no results. + * This ensures that if we want to add more return results later, we can do so in a + * compatible way. + */ +@Path("/coordinator") +@Produces(MediaType.APPLICATION_JSON) +@Consumes(MediaType.APPLICATION_JSON) +public class CoordinatorRestResource { + private final AtomicReference coordinator = new AtomicReference(); + + @javax.ws.rs.core.Context + private ServletContext context; + + public void setCoordinator(Coordinator myCoordinator) { + coordinator.set(myCoordinator); + } + + @GET + @Path("/status") + public CoordinatorStatusResponse status() throws Throwable { + return coordinator().status(); + } + + @GET + @Path("/uptime") + public UptimeResponse uptime() { + return coordinator().uptime(); + } + + @POST + @Path("/task/create") + public Empty createTask(CreateTaskRequest request) throws Throwable { + coordinator().createTask(request); + return Empty.INSTANCE; + } + + @PUT + @Path("/task/stop") + public Empty stopTask(StopTaskRequest request) throws Throwable { + coordinator().stopTask(request); + return Empty.INSTANCE; + } + + @DELETE + @Path("/tasks") + public Empty destroyTask(@DefaultValue("") @QueryParam("taskId") String taskId) throws Throwable { + coordinator().destroyTask(new DestroyTaskRequest(taskId)); + return Empty.INSTANCE; + } + + @GET + @Path("/tasks/") + public Response tasks(@QueryParam("taskId") List taskId, + @DefaultValue("0") @QueryParam("firstStartMs") long firstStartMs, + @DefaultValue("0") @QueryParam("lastStartMs") long lastStartMs, + @DefaultValue("0") @QueryParam("firstEndMs") long firstEndMs, + @DefaultValue("0") @QueryParam("lastEndMs") long lastEndMs, + @DefaultValue("") @QueryParam("state") String state) throws Throwable { + boolean isEmptyState = state.equals(""); + if (!isEmptyState && !TaskStateType.Constants.VALUES.contains(state)) { + return Response.status(400).entity( + String.format("State %s is invalid. 
Must be one of %s", + state, TaskStateType.Constants.VALUES) + ).build(); + } + + Optional givenState = Optional.ofNullable(isEmptyState ? null : TaskStateType.valueOf(state)); + TasksResponse resp = coordinator().tasks(new TasksRequest(taskId, firstStartMs, lastStartMs, firstEndMs, lastEndMs, givenState)); + + return Response.status(200).entity(resp).build(); + } + + @GET + @Path("/tasks/{taskId}") + public TaskState tasks(@PathParam("taskId") String taskId) throws Throwable { + TaskState response = coordinator().task(new TaskRequest(taskId)); + if (response == null) + throw new NotFoundException(String.format("No task with ID \"%s\" exists.", taskId)); + + return response; + } + + @PUT + @Path("/shutdown") + public Empty beginShutdown(CoordinatorShutdownRequest request) throws Throwable { + coordinator().beginShutdown(request.stopAgents()); + return Empty.INSTANCE; + } + + private Coordinator coordinator() { + Coordinator myCoordinator = coordinator.get(); + if (myCoordinator == null) { + throw new RuntimeException("CoordinatorRestResource has not been initialized yet."); + } + return myCoordinator; + } +} diff --git a/trogdor/src/main/java/org/apache/kafka/trogdor/coordinator/NodeManager.java b/trogdor/src/main/java/org/apache/kafka/trogdor/coordinator/NodeManager.java new file mode 100644 index 0000000..9dc379c --- /dev/null +++ b/trogdor/src/main/java/org/apache/kafka/trogdor/coordinator/NodeManager.java @@ -0,0 +1,387 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * So, when a task comes in, it happens via createTask (the RPC backend). + * This starts a CreateTask on the main state change thread, and waits for it. + * That task checks the main task hash map, and returns back the existing task spec + * if there is something there. If there is nothing there, it creates + * something new, and returns null. + * It also schedules a RunTask some time in the future on the main state change thread. + * We save the future from this in case we need to cancel it later, in a StopTask. + * If we can't create the TaskController for the task, we transition to DONE with an + * appropriate error message. + * + * RunTask actually starts the task which was created earlier. This could + * happen an arbitrary amount of time after task creation (it is based on the + * task spec). RunTask must operate only on PENDING tasks... if the task has been + * stopped, then we have nothing to do here. + * RunTask asks the TaskController for a list of all the names of nodes + * affected by this task. + * If this list contains nodes we don't know about, or zero nodes, we + * transition directly to DONE state with an appropriate error set. + * RunTask schedules CreateWorker Callables on all the affected worker nodes. 
+ * These callables run in the context of the relevant NodeManager. + * + * CreateWorker calls the RPC of the same name for the agent. + * There is some complexity here due to retries. + */ + +package org.apache.kafka.trogdor.coordinator; + +import org.apache.kafka.common.utils.ThreadUtils; +import org.apache.kafka.trogdor.agent.AgentClient; +import org.apache.kafka.trogdor.common.Node; +import org.apache.kafka.trogdor.rest.AgentStatusResponse; +import org.apache.kafka.trogdor.rest.CreateWorkerRequest; +import org.apache.kafka.trogdor.rest.StopWorkerRequest; +import org.apache.kafka.trogdor.rest.WorkerDone; +import org.apache.kafka.trogdor.rest.WorkerReceiving; +import org.apache.kafka.trogdor.rest.WorkerRunning; +import org.apache.kafka.trogdor.rest.WorkerStarting; +import org.apache.kafka.trogdor.rest.WorkerState; +import org.apache.kafka.trogdor.rest.WorkerStopping; +import org.apache.kafka.trogdor.task.TaskSpec; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.net.ConnectException; +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.Callable; +import java.util.concurrent.Executors; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.ScheduledFuture; +import java.util.concurrent.TimeUnit; + +/** + * The NodeManager handles communicating with a specific agent node. + * Each NodeManager has its own ExecutorService which runs in a dedicated thread. + */ +public final class NodeManager { + private static final Logger log = LoggerFactory.getLogger(NodeManager.class); + + /** + * The normal amount of seconds between heartbeats sent to the agent. + */ + private static final long HEARTBEAT_DELAY_MS = 1000L; + + class ManagedWorker { + private final long workerId; + private final String taskId; + private final TaskSpec spec; + private boolean shouldRun; + private WorkerState state; + + ManagedWorker(long workerId, String taskId, TaskSpec spec, + boolean shouldRun, WorkerState state) { + this.workerId = workerId; + this.taskId = taskId; + this.spec = spec; + this.shouldRun = shouldRun; + this.state = state; + } + + void tryCreate() { + try { + client.createWorker(new CreateWorkerRequest(workerId, taskId, spec)); + } catch (Throwable e) { + log.error("{}: error creating worker {}.", node.name(), this, e); + } + } + + void tryStop() { + try { + client.stopWorker(new StopWorkerRequest(workerId)); + } catch (Throwable e) { + log.error("{}: error stopping worker {}.", node.name(), this, e); + } + } + + @Override + public String toString() { + return String.format("%s_%d", taskId, workerId); + } + } + + /** + * The node which we are managing. + */ + private final Node node; + + /** + * The task manager. + */ + private final TaskManager taskManager; + + /** + * A client for the Node's Agent. + */ + private final AgentClient client; + + /** + * Maps task IDs to worker structures. + */ + private final Map workers; + + /** + * An executor service which manages the thread dedicated to this node. + */ + private final ScheduledExecutorService executor; + + /** + * The heartbeat runnable. + */ + private final NodeHeartbeat heartbeat; + + /** + * A future which can be used to cancel the periodic hearbeat task. + */ + private ScheduledFuture heartbeatFuture; + + NodeManager(Node node, TaskManager taskManager) { + this.node = node; + this.taskManager = taskManager; + this.client = new AgentClient.Builder(). + maxTries(1). + target(node.hostname(), Node.Util.getTrogdorAgentPort(node)). 
+ build(); + this.workers = new HashMap<>(); + this.executor = Executors.newSingleThreadScheduledExecutor( + ThreadUtils.createThreadFactory("NodeManager(" + node.name() + ")", + false)); + this.heartbeat = new NodeHeartbeat(); + rescheduleNextHeartbeat(HEARTBEAT_DELAY_MS); + } + + /** + * Reschedule the heartbeat runnable. + * + * @param initialDelayMs The initial delay to use. + */ + void rescheduleNextHeartbeat(long initialDelayMs) { + if (this.heartbeatFuture != null) { + this.heartbeatFuture.cancel(false); + } + this.heartbeatFuture = this.executor.scheduleAtFixedRate(heartbeat, + initialDelayMs, HEARTBEAT_DELAY_MS, TimeUnit.MILLISECONDS); + } + + /** + * The heartbeat runnable. + */ + class NodeHeartbeat implements Runnable { + @Override + public void run() { + rescheduleNextHeartbeat(HEARTBEAT_DELAY_MS); + try { + AgentStatusResponse agentStatus = null; + try { + agentStatus = client.status(); + } catch (ConnectException e) { + log.error("{}: failed to get agent status: ConnectException {}", node.name(), e.getMessage()); + return; + } catch (Exception e) { + log.error("{}: failed to get agent status", node.name(), e); + // TODO: eventually think about putting tasks into a bad state as a result of + // agents going down? + return; + } + if (log.isTraceEnabled()) { + log.trace("{}: got heartbeat status {}", node.name(), agentStatus); + } + handleMissingWorkers(agentStatus); + handlePresentWorkers(agentStatus); + } catch (Throwable e) { + log.error("{}: Unhandled exception in NodeHeartbeatRunnable", node.name(), e); + } + } + + /** + * Identify workers which we think should be running but do not appear in the agent's response. + * We need to send startWorker requests for those + */ + private void handleMissingWorkers(AgentStatusResponse agentStatus) { + for (Map.Entry entry : workers.entrySet()) { + Long workerId = entry.getKey(); + if (!agentStatus.workers().containsKey(workerId)) { + ManagedWorker worker = entry.getValue(); + if (worker.shouldRun) { + worker.tryCreate(); + } + } + } + } + + private void handlePresentWorkers(AgentStatusResponse agentStatus) { + for (Map.Entry entry : agentStatus.workers().entrySet()) { + long workerId = entry.getKey(); + WorkerState state = entry.getValue(); + ManagedWorker worker = workers.get(workerId); + if (worker == null) { + // Identify tasks which are running, but which we don't know about. + // Add these to the NodeManager as tasks that should not be running. + log.warn("{}: scheduling unknown worker with ID {} for stopping.", node.name(), workerId); + workers.put(workerId, new ManagedWorker(workerId, state.taskId(), + state.spec(), false, state)); + } else { + // Handle workers which need to be stopped. + if (state instanceof WorkerStarting || state instanceof WorkerRunning) { + if (!worker.shouldRun) { + worker.tryStop(); + } + } + // Notify the TaskManager if the worker state has changed. + if (worker.state.equals(state)) { + log.debug("{}: worker state is still {}", node.name(), worker.state); + } else { + log.info("{}: worker state changed from {} to {}", node.name(), worker.state, state); + if (state instanceof WorkerDone || state instanceof WorkerStopping) + worker.shouldRun = false; + worker.state = state; + taskManager.updateWorkerState(node.name(), worker.workerId, state); + } + } + } + } + } + + /** + * Create a new worker. + * + * @param workerId The new worker id. + * @param taskId The new task id. + * @param spec The task specification to use with the new worker. 
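+ * <p>The request is handled asynchronously: a {@code CreateWorker} callable is queued on this
+ * node's single-threaded executor, and the agent RPC itself is issued from the heartbeat
+ * runnable once the worker has been registered.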
+ */ + public void createWorker(long workerId, String taskId, TaskSpec spec) { + executor.submit(new CreateWorker(workerId, taskId, spec)); + } + + /** + * Starts a worker. + */ + class CreateWorker implements Callable { + private final long workerId; + private final String taskId; + private final TaskSpec spec; + + CreateWorker(long workerId, String taskId, TaskSpec spec) { + this.workerId = workerId; + this.taskId = taskId; + this.spec = spec; + } + + @Override + public Void call() throws Exception { + ManagedWorker worker = workers.get(workerId); + if (worker != null) { + log.error("{}: there is already a worker {} with ID {}.", + node.name(), worker, workerId); + return null; + } + worker = new ManagedWorker(workerId, taskId, spec, true, new WorkerReceiving(taskId, spec)); + log.info("{}: scheduling worker {} to start.", node.name(), worker); + workers.put(workerId, worker); + rescheduleNextHeartbeat(0); + return null; + } + } + + /** + * Stop a worker. + * + * @param workerId The id of the worker to stop. + */ + public void stopWorker(long workerId) { + executor.submit(new StopWorker(workerId)); + } + + /** + * Stops a worker. + */ + class StopWorker implements Callable { + private final long workerId; + + StopWorker(long workerId) { + this.workerId = workerId; + } + + @Override + public Void call() throws Exception { + ManagedWorker worker = workers.get(workerId); + if (worker == null) { + log.error("{}: unable to locate worker to stop with ID {}.", node.name(), workerId); + return null; + } + if (!worker.shouldRun) { + log.error("{}: Worker {} is already scheduled to stop.", + node.name(), worker); + return null; + } + log.info("{}: scheduling worker {} to stop.", node.name(), worker); + worker.shouldRun = false; + rescheduleNextHeartbeat(0); + return null; + } + } + + /** + * Destroy a worker. + * + * @param workerId The id of the worker to destroy. + */ + public void destroyWorker(long workerId) { + executor.submit(new DestroyWorker(workerId)); + } + + /** + * Destroys a worker. + */ + class DestroyWorker implements Callable { + private final long workerId; + + DestroyWorker(long workerId) { + this.workerId = workerId; + } + + @Override + public Void call() throws Exception { + ManagedWorker worker = workers.remove(workerId); + if (worker == null) { + log.error("{}: unable to locate worker to destroy with ID {}.", node.name(), workerId); + return null; + } + rescheduleNextHeartbeat(0); + return null; + } + } + + public void beginShutdown(boolean stopNode) { + executor.shutdownNow(); + if (stopNode) { + try { + client.invokeShutdown(); + } catch (Exception e) { + log.error("{}: Failed to send shutdown request", node.name(), e); + } + } + } + + public void waitForShutdown() throws InterruptedException { + executor.awaitTermination(1, TimeUnit.DAYS); + } +}; diff --git a/trogdor/src/main/java/org/apache/kafka/trogdor/coordinator/TaskManager.java b/trogdor/src/main/java/org/apache/kafka/trogdor/coordinator/TaskManager.java new file mode 100644 index 0000000..a307e59 --- /dev/null +++ b/trogdor/src/main/java/org/apache/kafka/trogdor/coordinator/TaskManager.java @@ -0,0 +1,707 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.trogdor.coordinator; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.node.JsonNodeFactory; +import com.fasterxml.jackson.databind.node.LongNode; +import com.fasterxml.jackson.databind.node.ObjectNode; +import org.apache.kafka.common.KafkaException; +import org.apache.kafka.common.errors.InvalidRequestException; +import org.apache.kafka.common.utils.Scheduler; +import org.apache.kafka.common.utils.ThreadUtils; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.trogdor.common.JsonUtil; +import org.apache.kafka.trogdor.common.Node; +import org.apache.kafka.trogdor.common.Platform; +import org.apache.kafka.trogdor.rest.RequestConflictException; +import org.apache.kafka.trogdor.rest.TaskDone; +import org.apache.kafka.trogdor.rest.TaskPending; +import org.apache.kafka.trogdor.rest.TaskRequest; +import org.apache.kafka.trogdor.rest.TaskRunning; +import org.apache.kafka.trogdor.rest.TaskState; +import org.apache.kafka.trogdor.rest.TaskStateType; +import org.apache.kafka.trogdor.rest.TaskStopping; +import org.apache.kafka.trogdor.rest.TasksRequest; +import org.apache.kafka.trogdor.rest.TasksResponse; +import org.apache.kafka.trogdor.rest.WorkerDone; +import org.apache.kafka.trogdor.rest.WorkerReceiving; +import org.apache.kafka.trogdor.rest.WorkerState; +import org.apache.kafka.trogdor.task.TaskController; +import org.apache.kafka.trogdor.task.TaskSpec; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.HashMap; +import java.util.Map; +import java.util.Set; +import java.util.TreeMap; +import java.util.TreeSet; +import java.util.concurrent.Callable; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; + +/** + * The TaskManager is responsible for managing tasks inside the Trogdor coordinator. + * + * The task manager has a single thread, managed by the executor. We start, stop, + * and handle state changes to tasks by adding requests to the executor queue. + * Because the executor is single threaded, no locks are needed when accessing + * TaskManager data structures. + * + * The TaskManager maintains a state machine for each task. Tasks begin in the + * PENDING state, waiting for their designated start time to arrive. + * When their time arrives, they transition to the RUNNING state. In this state, + * the NodeManager will start them, and monitor them. + * + * The TaskManager does not handle communication with the agents. This is handled + * by the NodeManagers. There is one NodeManager per node being managed. + * See {org.apache.kafka.trogdor.coordinator.NodeManager} for details. + */ +public final class TaskManager { + private static final Logger log = LoggerFactory.getLogger(TaskManager.class); + + /** + * The platform. 
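+ * Supplies the cluster topology used to resolve the nodes that a task targets.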
+ */ + private final Platform platform; + + /** + * The scheduler to use for this coordinator. + */ + private final Scheduler scheduler; + + /** + * The clock to use for this coordinator. + */ + private final Time time; + + /** + * A map of task IDs to Task objects. + */ + private final Map tasks; + + /** + * The executor used for handling Task state changes. + */ + private final ScheduledExecutorService executor; + + /** + * Maps node names to node managers. + */ + private final Map nodeManagers; + + /** + * The states of all workers. + */ + private final Map workerStates = new HashMap<>(); + + /** + * True if the TaskManager is shut down. + */ + private AtomicBoolean shutdown = new AtomicBoolean(false); + + /** + * The ID to use for the next worker. Only accessed by the state change thread. + */ + private long nextWorkerId; + + TaskManager(Platform platform, Scheduler scheduler, long firstWorkerId) { + this.platform = platform; + this.scheduler = scheduler; + this.time = scheduler.time(); + this.tasks = new HashMap<>(); + this.executor = Executors.newSingleThreadScheduledExecutor( + ThreadUtils.createThreadFactory("TaskManagerStateThread", false)); + this.nodeManagers = new HashMap<>(); + this.nextWorkerId = firstWorkerId; + for (Node node : platform.topology().nodes().values()) { + if (Node.Util.getTrogdorAgentPort(node) > 0) { + this.nodeManagers.put(node.name(), new NodeManager(node, this)); + } + } + log.info("Created TaskManager for agent(s) on: {}", + Utils.join(nodeManagers.keySet(), ", ")); + } + + class ManagedTask { + /** + * The task id. + */ + final private String id; + + /** + * The original task specification as submitted when the task was created. + */ + final private TaskSpec originalSpec; + + /** + * The effective task specification. + * The start time will be adjusted to reflect the time when the task was submitted. + */ + final private TaskSpec spec; + + /** + * The task controller. + */ + final private TaskController controller; + + /** + * The task state. + */ + private TaskStateType state; + + /** + * The time when the task was started, or -1 if the task has not been started. + */ + private long startedMs = -1; + + /** + * The time when the task was finished, or -1 if the task has not been finished. + */ + private long doneMs = -1; + + /** + * True if the task was cancelled by a stop request. + */ + boolean cancelled = false; + + /** + * If there is a task start scheduled, this is a future which can + * be used to cancel it. + */ + private Future startFuture = null; + + /** + * Maps node names to worker IDs. + */ + public TreeMap workerIds = new TreeMap<>(); + + /** + * If this is non-empty, a message describing how this task failed. 
+ */ + private String error = ""; + + ManagedTask(String id, TaskSpec originalSpec, TaskSpec spec, + TaskController controller, TaskStateType state) { + this.id = id; + this.originalSpec = originalSpec; + this.spec = spec; + this.controller = controller; + this.state = state; + } + + void clearStartFuture() { + if (startFuture != null) { + startFuture.cancel(false); + startFuture = null; + } + } + + long startDelayMs(long now) { + if (now > spec.startMs()) { + return 0; + } + return spec.startMs() - now; + } + + TreeSet findNodeNames() { + Set nodeNames = controller.targetNodes(platform.topology()); + TreeSet validNodeNames = new TreeSet<>(); + TreeSet nonExistentNodeNames = new TreeSet<>(); + for (String nodeName : nodeNames) { + if (nodeManagers.containsKey(nodeName)) { + validNodeNames.add(nodeName); + } else { + nonExistentNodeNames.add(nodeName); + } + } + if (!nonExistentNodeNames.isEmpty()) { + throw new KafkaException("Unknown node names: " + + Utils.join(nonExistentNodeNames, ", ")); + } + if (validNodeNames.isEmpty()) { + throw new KafkaException("No node names specified."); + } + return validNodeNames; + } + + void maybeSetError(String newError) { + if (error.isEmpty()) { + error = newError; + } + } + + TaskState taskState() { + switch (state) { + case PENDING: + return new TaskPending(spec); + case RUNNING: + return new TaskRunning(spec, startedMs, getCombinedStatus()); + case STOPPING: + return new TaskStopping(spec, startedMs, getCombinedStatus()); + case DONE: + return new TaskDone(spec, startedMs, doneMs, error, cancelled, getCombinedStatus()); + } + throw new RuntimeException("unreachable"); + } + + private JsonNode getCombinedStatus() { + if (workerIds.size() == 1) { + return workerStates.get(workerIds.values().iterator().next()).status(); + } else { + ObjectNode objectNode = new ObjectNode(JsonNodeFactory.instance); + for (Map.Entry entry : workerIds.entrySet()) { + String nodeName = entry.getKey(); + Long workerId = entry.getValue(); + WorkerState state = workerStates.get(workerId); + JsonNode node = state.status(); + if (node != null) { + objectNode.set(nodeName, node); + } + } + return objectNode; + } + } + + TreeMap activeWorkerIds() { + TreeMap activeWorkerIds = new TreeMap<>(); + for (Map.Entry entry : workerIds.entrySet()) { + WorkerState workerState = workerStates.get(entry.getValue()); + if (!workerState.done()) { + activeWorkerIds.put(entry.getKey(), entry.getValue()); + } + } + return activeWorkerIds; + } + } + + /** + * Create a task. + * + * @param id The ID of the task to create. + * @param spec The specification of the task to create. + * @throws RequestConflictException - if a task with the same ID but different spec exists + */ + public void createTask(final String id, TaskSpec spec) + throws Throwable { + try { + executor.submit(new CreateTask(id, spec)).get(); + } catch (ExecutionException | JsonProcessingException e) { + log.info("createTask(id={}, spec={}) error", id, spec, e); + throw e.getCause(); + } + } + + /** + * Handles a request to create a new task. Processed by the state change thread. 
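+ * The effective spec's startMs is never earlier than the submission time (see the
+ * Math.max below), so a spec whose start time is already in the past runs immediately.
+ *
+ * A minimal usage sketch of submitting such a request through the public createTask
+ * method (the taskManager handle, the node name "node01" and the process name "Kafka"
+ * are illustrative; the fault spec class is the one defined later in this patch):
+ * <pre>
+ *   TaskSpec spec = new ProcessStopFaultSpec(0, 60_000,
+ *       Collections.singletonList("node01"), "Kafka");
+ *   taskManager.createTask("stop-kafka-on-node01", spec);
+ * </pre>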
+ */ + class CreateTask implements Callable { + private final String id; + private final TaskSpec originalSpec; + private final TaskSpec spec; + + CreateTask(String id, TaskSpec spec) throws JsonProcessingException { + this.id = id; + this.originalSpec = spec; + ObjectNode node = JsonUtil.JSON_SERDE.valueToTree(originalSpec); + node.set("startMs", new LongNode(Math.max(time.milliseconds(), originalSpec.startMs()))); + this.spec = JsonUtil.JSON_SERDE.treeToValue(node, TaskSpec.class); + } + + @Override + public Void call() throws Exception { + if (id.isEmpty()) { + throw new InvalidRequestException("Invalid empty ID in createTask request."); + } + ManagedTask task = tasks.get(id); + if (task != null) { + if (!task.originalSpec.equals(originalSpec)) { + throw new RequestConflictException("Task ID " + id + " already " + + "exists, and has a different spec " + task.originalSpec); + } + log.info("Task {} already exists with spec {}", id, originalSpec); + return null; + } + TaskController controller = null; + String failure = null; + try { + controller = spec.newController(id); + } catch (Throwable t) { + failure = "Failed to create TaskController: " + t.getMessage(); + } + if (failure != null) { + log.info("Failed to create a new task {} with spec {}: {}", + id, spec, failure); + task = new ManagedTask(id, originalSpec, spec, null, TaskStateType.DONE); + task.doneMs = time.milliseconds(); + task.maybeSetError(failure); + tasks.put(id, task); + return null; + } + task = new ManagedTask(id, originalSpec, spec, controller, TaskStateType.PENDING); + tasks.put(id, task); + long delayMs = task.startDelayMs(time.milliseconds()); + task.startFuture = scheduler.schedule(executor, new RunTask(task), delayMs); + log.info("Created a new task {} with spec {}, scheduled to start {} ms from now.", + id, spec, delayMs); + return null; + } + } + + /** + * Handles starting a task. Processed by the state change thread. + */ + class RunTask implements Callable { + private final ManagedTask task; + + RunTask(ManagedTask task) { + this.task = task; + } + + @Override + public Void call() throws Exception { + task.clearStartFuture(); + if (task.state != TaskStateType.PENDING) { + log.info("Can't start task {}, because it is already in state {}.", + task.id, task.state); + return null; + } + TreeSet nodeNames; + try { + nodeNames = task.findNodeNames(); + } catch (Exception e) { + log.error("Unable to find nodes for task {}", task.id, e); + task.doneMs = time.milliseconds(); + task.state = TaskStateType.DONE; + task.maybeSetError("Unable to find nodes for task: " + e.getMessage()); + return null; + } + log.info("Running task {} on node(s): {}", task.id, Utils.join(nodeNames, ", ")); + task.state = TaskStateType.RUNNING; + task.startedMs = time.milliseconds(); + for (String workerName : nodeNames) { + long workerId = nextWorkerId++; + task.workerIds.put(workerName, workerId); + workerStates.put(workerId, new WorkerReceiving(task.id, task.spec)); + nodeManagers.get(workerName).createWorker(workerId, task.id, task.spec); + } + return null; + } + } + + /** + * Stop a task. + * + * @param id The ID of the task to stop. + */ + public void stopTask(final String id) throws Throwable { + try { + executor.submit(new CancelTask(id)).get(); + } catch (ExecutionException e) { + log.info("stopTask(id={}) error", id, e); + throw e.getCause(); + } + } + + /** + * Handles cancelling a task. Processed by the state change thread. 
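+ * The effect depends on the task's current state; a summary of the transitions
+ * implemented below:
+ * <pre>
+ *   PENDING  -> DONE      (the scheduled start future is cancelled)
+ *   RUNNING  -> DONE      (if no workers are still active), or
+ *   RUNNING  -> STOPPING  (stopWorker is sent to every active worker)
+ *   STOPPING -> unchanged (already stopping)
+ *   DONE     -> unchanged (already done)
+ * </pre>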
+ */ + class CancelTask implements Callable { + private final String id; + + CancelTask(String id) { + this.id = id; + } + + @Override + public Void call() throws Exception { + if (id.isEmpty()) { + throw new InvalidRequestException("Invalid empty ID in stopTask request."); + } + ManagedTask task = tasks.get(id); + if (task == null) { + log.info("Can't cancel non-existent task {}.", id); + return null; + } + switch (task.state) { + case PENDING: + task.cancelled = true; + task.clearStartFuture(); + task.doneMs = time.milliseconds(); + task.state = TaskStateType.DONE; + log.info("Stopped pending task {}.", id); + break; + case RUNNING: + task.cancelled = true; + TreeMap activeWorkerIds = task.activeWorkerIds(); + if (activeWorkerIds.isEmpty()) { + if (task.error.isEmpty()) { + log.info("Task {} is now complete with no errors.", id); + } else { + log.info("Task {} is now complete with error: {}", id, task.error); + } + task.doneMs = time.milliseconds(); + task.state = TaskStateType.DONE; + } else { + for (Map.Entry entry : activeWorkerIds.entrySet()) { + nodeManagers.get(entry.getKey()).stopWorker(entry.getValue()); + } + log.info("Cancelling task {} with worker(s) {}", + id, Utils.mkString(activeWorkerIds, "", "", " = ", ", ")); + task.state = TaskStateType.STOPPING; + } + break; + case STOPPING: + log.info("Can't cancel task {} because it is already stopping.", id); + break; + case DONE: + log.info("Can't cancel task {} because it is already done.", id); + break; + } + return null; + } + } + + public void destroyTask(String id) throws Throwable { + try { + executor.submit(new DestroyTask(id)).get(); + } catch (ExecutionException e) { + log.info("destroyTask(id={}) error", id, e); + throw e.getCause(); + } + } + + /** + * Handles destroying a task. Processed by the state change thread. + */ + class DestroyTask implements Callable { + private final String id; + + DestroyTask(String id) { + this.id = id; + } + + @Override + public Void call() throws Exception { + if (id.isEmpty()) { + throw new InvalidRequestException("Invalid empty ID in destroyTask request."); + } + ManagedTask task = tasks.remove(id); + if (task == null) { + log.info("Can't destroy task {}: no such task found.", id); + return null; + } + log.info("Destroying task {}.", id); + task.clearStartFuture(); + for (Map.Entry entry : task.workerIds.entrySet()) { + long workerId = entry.getValue(); + workerStates.remove(workerId); + String nodeName = entry.getKey(); + nodeManagers.get(nodeName).destroyWorker(workerId); + } + return null; + } + } + + /** + * Update the state of a particular agent's worker. + * + * @param nodeName The node where the agent is running. + * @param workerId The worker ID. + * @param state The worker state. + */ + public void updateWorkerState(String nodeName, long workerId, WorkerState state) { + executor.submit(new UpdateWorkerState(nodeName, workerId, state)); + } + + /** + * Updates the state of a worker. Process by the state change thread. 
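+ * When a worker moves to a done state for the first time, handleWorkerCompletion records
+ * any worker error on the owning task; once no workers remain active the task becomes
+ * DONE, and if a worker failed while the task was still RUNNING the remaining workers
+ * are stopped. If the update cannot be applied (unknown worker or task), the reporting
+ * worker is asked to stop.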
+ */ + class UpdateWorkerState implements Callable { + private final String nodeName; + private final long workerId; + private final WorkerState nextState; + + UpdateWorkerState(String nodeName, long workerId, WorkerState nextState) { + this.nodeName = nodeName; + this.workerId = workerId; + this.nextState = nextState; + } + + @Override + public Void call() throws Exception { + try { + WorkerState prevState = workerStates.get(workerId); + if (prevState == null) { + throw new RuntimeException("Unable to find workerId " + workerId); + } + ManagedTask task = tasks.get(prevState.taskId()); + if (task == null) { + throw new RuntimeException("Unable to find taskId " + prevState.taskId()); + } + log.debug("Task {}: Updating worker state for {} on {} from {} to {}.", + task.id, workerId, nodeName, prevState, nextState); + workerStates.put(workerId, nextState); + if (nextState.done() && (!prevState.done())) { + handleWorkerCompletion(task, nodeName, (WorkerDone) nextState); + } + } catch (Exception e) { + log.error("Error updating worker state for {} on {}. Stopping worker.", + workerId, nodeName, e); + nodeManagers.get(nodeName).stopWorker(workerId); + } + return null; + } + } + + /** + * Handle a worker being completed. + * + * @param task The task that owns the worker. + * @param nodeName The name of the node on which the worker is running. + * @param state The worker state. + */ + private void handleWorkerCompletion(ManagedTask task, String nodeName, WorkerDone state) { + if (state.error().isEmpty()) { + log.info("{}: Worker {} finished with status '{}'", + nodeName, task.id, JsonUtil.toJsonString(state.status())); + } else { + log.warn("{}: Worker {} finished with error '{}' and status '{}'", + nodeName, task.id, state.error(), JsonUtil.toJsonString(state.status())); + task.maybeSetError(state.error()); + } + TreeMap activeWorkerIds = task.activeWorkerIds(); + if (activeWorkerIds.isEmpty()) { + task.doneMs = time.milliseconds(); + task.state = TaskStateType.DONE; + log.info("{}: Task {} is now complete on {} with error: {}", + nodeName, task.id, Utils.join(task.workerIds.keySet(), ", "), + task.error.isEmpty() ? "(none)" : task.error); + } else if ((task.state == TaskStateType.RUNNING) && (!task.error.isEmpty())) { + log.info("{}: task {} stopped with error {}. Stopping worker(s): {}", + nodeName, task.id, task.error, Utils.mkString(activeWorkerIds, "{", "}", ": ", ", ")); + task.state = TaskStateType.STOPPING; + for (Map.Entry entry : activeWorkerIds.entrySet()) { + nodeManagers.get(entry.getKey()).stopWorker(entry.getValue()); + } + } + } + + /** + * Get information about the tasks being managed. + */ + public TasksResponse tasks(TasksRequest request) throws ExecutionException, InterruptedException { + return executor.submit(new GetTasksResponse(request)).get(); + } + + /** + * Gets information about the tasks being managed. Processed by the state change thread. + */ + class GetTasksResponse implements Callable { + private final TasksRequest request; + + GetTasksResponse(TasksRequest request) { + this.request = request; + } + + @Override + public TasksResponse call() throws Exception { + TreeMap states = new TreeMap<>(); + for (ManagedTask task : tasks.values()) { + if (request.matches(task.id, task.startedMs, task.doneMs, task.state)) { + states.put(task.id, task.taskState()); + } + } + return new TasksResponse(states); + } + } + + /** + * Get information about a single task being managed. 
+ * + * Returns #{@code null} if the task does not exist + */ + public TaskState task(TaskRequest request) throws ExecutionException, InterruptedException { + return executor.submit(new GetTaskState(request)).get(); + } + + /** + * Gets information about the tasks being managed. Processed by the state change thread. + */ + class GetTaskState implements Callable { + private final TaskRequest request; + + GetTaskState(TaskRequest request) { + this.request = request; + } + + @Override + public TaskState call() throws Exception { + ManagedTask task = tasks.get(request.taskId()); + if (task == null) { + return null; + } + + return task.taskState(); + } + } + + /** + * Initiate shutdown, but do not wait for it to complete. + */ + public void beginShutdown(boolean stopAgents) { + if (shutdown.compareAndSet(false, true)) { + executor.submit(new Shutdown(stopAgents)); + } + } + + /** + * Wait for shutdown to complete. May be called prior to beginShutdown. + */ + public void waitForShutdown() throws InterruptedException { + while (!executor.awaitTermination(1, TimeUnit.DAYS)) { } + } + + class Shutdown implements Callable { + private final boolean stopAgents; + + Shutdown(boolean stopAgents) { + this.stopAgents = stopAgents; + } + + @Override + public Void call() throws Exception { + log.info("Shutting down TaskManager{}.", stopAgents ? " and agents" : ""); + for (NodeManager nodeManager : nodeManagers.values()) { + nodeManager.beginShutdown(stopAgents); + } + for (NodeManager nodeManager : nodeManagers.values()) { + nodeManager.waitForShutdown(); + } + executor.shutdown(); + return null; + } + } +}; diff --git a/trogdor/src/main/java/org/apache/kafka/trogdor/fault/DegradedNetworkFaultSpec.java b/trogdor/src/main/java/org/apache/kafka/trogdor/fault/DegradedNetworkFaultSpec.java new file mode 100644 index 0000000..127a4da --- /dev/null +++ b/trogdor/src/main/java/org/apache/kafka/trogdor/fault/DegradedNetworkFaultSpec.java @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.trogdor.fault; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import org.apache.kafka.trogdor.task.TaskController; +import org.apache.kafka.trogdor.task.TaskSpec; +import org.apache.kafka.trogdor.task.TaskWorker; + +import java.util.Collections; +import java.util.Map; + +public class DegradedNetworkFaultSpec extends TaskSpec { + + public static class NodeDegradeSpec { + private final String networkDevice; + private final int latencyMs; + private final int rateLimitKbit; + + public NodeDegradeSpec( + @JsonProperty("networkDevice") String networkDevice, + @JsonProperty("latencyMs") Integer latencyMs, + @JsonProperty("rateLimitKbit") Integer rateLimitKbit) { + this.networkDevice = networkDevice == null ? "" : networkDevice; + this.latencyMs = latencyMs == null ? 0 : latencyMs; + this.rateLimitKbit = rateLimitKbit == null ? 0 : rateLimitKbit; + } + + @JsonProperty("networkDevice") + public String networkDevice() { + return networkDevice; + } + + @JsonProperty("latencyMs") + public int latencyMs() { + return latencyMs; + } + + @JsonProperty("rateLimitKbit") + public int rateLimitKbit() { + return rateLimitKbit; + } + + @Override + public String toString() { + return "NodeDegradeSpec{" + + "networkDevice='" + networkDevice + '\'' + + ", latencyMs=" + latencyMs + + ", rateLimitKbit=" + rateLimitKbit + + '}'; + } + } + + private final Map nodeSpecs; + + @JsonCreator + public DegradedNetworkFaultSpec(@JsonProperty("startMs") long startMs, + @JsonProperty("durationMs") long durationMs, + @JsonProperty("nodeSpecs") Map nodeSpecs) { + super(startMs, durationMs); + this.nodeSpecs = nodeSpecs == null ? Collections.emptyMap() : Collections.unmodifiableMap(nodeSpecs); + } + + @Override + public TaskController newController(String id) { + return topology -> nodeSpecs.keySet(); + } + + @Override + public TaskWorker newTaskWorker(String id) { + return new DegradedNetworkFaultWorker(id, nodeSpecs); + } + + @JsonProperty("nodeSpecs") + public Map nodeSpecs() { + return nodeSpecs; + } +} diff --git a/trogdor/src/main/java/org/apache/kafka/trogdor/fault/DegradedNetworkFaultWorker.java b/trogdor/src/main/java/org/apache/kafka/trogdor/fault/DegradedNetworkFaultWorker.java new file mode 100644 index 0000000..d071d12 --- /dev/null +++ b/trogdor/src/main/java/org/apache/kafka/trogdor/fault/DegradedNetworkFaultWorker.java @@ -0,0 +1,173 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.trogdor.fault; + +import com.fasterxml.jackson.databind.node.TextNode; +import org.apache.kafka.common.internals.KafkaFutureImpl; +import org.apache.kafka.trogdor.common.Node; +import org.apache.kafka.trogdor.common.Platform; +import org.apache.kafka.trogdor.task.TaskWorker; +import org.apache.kafka.trogdor.task.WorkerStatusTracker; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.net.NetworkInterface; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.function.Consumer; +import java.util.stream.Stream; + +/** + * Uses the linux utility
<pre>tc</pre>
                (traffic controller) to degrade performance on a specified network device + */ +public class DegradedNetworkFaultWorker implements TaskWorker { + + private static final Logger log = LoggerFactory.getLogger(DegradedNetworkFaultWorker.class); + + private final String id; + private final Map nodeSpecs; + private WorkerStatusTracker status; + + public DegradedNetworkFaultWorker(String id, Map nodeSpecs) { + this.id = id; + this.nodeSpecs = nodeSpecs; + } + + @Override + public void start(Platform platform, WorkerStatusTracker status, KafkaFutureImpl haltFuture) throws Exception { + log.info("Activating DegradedNetworkFaultWorker {}.", id); + this.status = status; + this.status.update(new TextNode("enabling traffic control " + id)); + Node curNode = platform.curNode(); + DegradedNetworkFaultSpec.NodeDegradeSpec nodeSpec = nodeSpecs.get(curNode.name()); + if (nodeSpec != null) { + for (String device : devicesForSpec(nodeSpec)) { + if (nodeSpec.latencyMs() < 0 || nodeSpec.rateLimitKbit() < 0) { + throw new RuntimeException("Expected non-negative values for latencyMs and rateLimitKbit, but got " + nodeSpec); + } else { + enableTrafficControl(platform, device, nodeSpec.latencyMs(), nodeSpec.rateLimitKbit()); + } + } + } + this.status.update(new TextNode("enabled traffic control " + id)); + } + + @Override + public void stop(Platform platform) throws Exception { + log.info("Deactivating DegradedNetworkFaultWorker {}.", id); + this.status.update(new TextNode("disabling traffic control " + id)); + Node curNode = platform.curNode(); + DegradedNetworkFaultSpec.NodeDegradeSpec nodeSpec = nodeSpecs.get(curNode.name()); + if (nodeSpec != null) { + for (String device : devicesForSpec(nodeSpec)) { + disableTrafficControl(platform, device); + } + } + this.status.update(new TextNode("disabled traffic control " + id)); + } + + private Set devicesForSpec(DegradedNetworkFaultSpec.NodeDegradeSpec nodeSpec) throws Exception { + Set devices = new HashSet<>(); + if (nodeSpec.networkDevice().isEmpty()) { + for (NetworkInterface networkInterface : Collections.list(NetworkInterface.getNetworkInterfaces())) { + if (!networkInterface.isLoopback()) { + devices.add(networkInterface.getName()); + } + } + } else { + devices.add(nodeSpec.networkDevice()); + } + return devices; + } + + /** + * Constructs the appropriate "tc" commands to apply latency and rate limiting, if they are non zero. 
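+ *
+ * For example, with a (hypothetical) device "eth0", delayMs=50 and rateLimitKbps=1000,
+ * the commands assembled below come out roughly as:
+ * <pre>
+ *   sudo tc qdisc add dev eth0 root handle 1:0 netem delay 50ms 7ms distribution paretonormal
+ *   sudo tc qdisc add dev eth0 parent 1:1 handle 10: tbf rate 1000kbit burst 1mbit latency 500ms
+ * </pre>
+ * where the 7ms deviation is max(1, (int) sqrt(delayMs)). If only a rate limit is
+ * requested, a single command attaches the tbf filter directly to the root handler.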
+ */ + private void enableTrafficControl(Platform platform, String networkDevice, int delayMs, int rateLimitKbps) throws IOException { + if (delayMs > 0) { + int deviationMs = Math.max(1, (int) Math.sqrt(delayMs)); + List delay = new ArrayList<>(); + rootHandler(networkDevice, delay::add); + netemDelay(delayMs, deviationMs, delay::add); + platform.runCommand(delay.toArray(new String[0])); + + if (rateLimitKbps > 0) { + List rate = new ArrayList<>(); + childHandler(networkDevice, rate::add); + tbfRate(rateLimitKbps, rate::add); + platform.runCommand(rate.toArray(new String[0])); + } + } else if (rateLimitKbps > 0) { + List rate = new ArrayList<>(); + rootHandler(networkDevice, rate::add); + tbfRate(rateLimitKbps, rate::add); + platform.runCommand(rate.toArray(new String[0])); + } else { + log.warn("Not applying any rate limiting or latency"); + } + } + + /** + * Construct the first part of a "tc" command to define a qdisc root handler for the given network interface + */ + private void rootHandler(String networkDevice, Consumer consumer) { + Stream.of("sudo", "tc", "qdisc", "add", "dev", networkDevice, "root", "handle", "1:0").forEach(consumer); + } + + /** + * Construct the first part of a "tc" command to define a qdisc child handler for the given interface. This can + * only be used if a root handler has been appropriately defined first (as in {@link #rootHandler}). + */ + private void childHandler(String networkDevice, Consumer consumer) { + Stream.of("sudo", "tc", "qdisc", "add", "dev", networkDevice, "parent", "1:1", "handle", "10:").forEach(consumer); + } + + /** + * Construct the second part of a "tc" command that defines a netem (Network Emulator) filter that will apply some + * amount of latency with a small amount of deviation. The distribution of the latency deviation follows a so-called + * Pareto-normal distribution. This is the formal name for the 80/20 rule, which might better represent real-world + * patterns. + */ + private void netemDelay(int delayMs, int deviationMs, Consumer consumer) { + Stream.of("netem", "delay", String.format("%dms", delayMs), String.format("%dms", deviationMs), + "distribution", "paretonormal").forEach(consumer); + } + + /** + * Construct the second part of a "tc" command that defines a tbf (token buffer filter) that will rate limit the + * packets going through a qdisc. + */ + private void tbfRate(int rateLimitKbit, Consumer consumer) { + Stream.of("tbf", "rate", String.format("%dkbit", rateLimitKbit), "burst", "1mbit", "latency", "500ms").forEach(consumer); + } + + /** + * Delete any previously defined qdisc for the given network interface. + * @throws IOException + */ + private void disableTrafficControl(Platform platform, String networkDevice) throws IOException { + platform.runCommand(new String[] { + "sudo", "tc", "qdisc", "del", "dev", networkDevice, "root" + }); + } +} diff --git a/trogdor/src/main/java/org/apache/kafka/trogdor/fault/FilesUnreadableFaultSpec.java b/trogdor/src/main/java/org/apache/kafka/trogdor/fault/FilesUnreadableFaultSpec.java new file mode 100644 index 0000000..cb520c4 --- /dev/null +++ b/trogdor/src/main/java/org/apache/kafka/trogdor/fault/FilesUnreadableFaultSpec.java @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.trogdor.fault; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import org.apache.kafka.trogdor.fault.Kibosh.KiboshFilesUnreadableFaultSpec; +import org.apache.kafka.trogdor.task.TaskController; +import org.apache.kafka.trogdor.task.TaskSpec; +import org.apache.kafka.trogdor.task.TaskWorker; + +import java.util.HashSet; +import java.util.Set; + +/** + * The specification for a fault that makes files unreadable. + */ +public class FilesUnreadableFaultSpec extends TaskSpec { + private final Set nodeNames; + private final String mountPath; + private final String prefix; + private final int errorCode; + + @JsonCreator + public FilesUnreadableFaultSpec(@JsonProperty("startMs") long startMs, + @JsonProperty("durationMs") long durationMs, + @JsonProperty("nodeNames") Set nodeNames, + @JsonProperty("mountPath") String mountPath, + @JsonProperty("prefix") String prefix, + @JsonProperty("errorCode") int errorCode) { + super(startMs, durationMs); + this.nodeNames = nodeNames == null ? new HashSet() : nodeNames; + this.mountPath = mountPath == null ? "" : mountPath; + this.prefix = prefix == null ? "" : prefix; + this.errorCode = errorCode; + } + + @JsonProperty + public Set nodeNames() { + return nodeNames; + } + + @JsonProperty + public String mountPath() { + return mountPath; + } + + @JsonProperty + public String prefix() { + return prefix; + } + + @JsonProperty + public int errorCode() { + return errorCode; + } + + @Override + public TaskController newController(String id) { + return new KiboshFaultController(nodeNames); + } + + @Override + public TaskWorker newTaskWorker(String id) { + return new KiboshFaultWorker(id, + new KiboshFilesUnreadableFaultSpec(prefix, errorCode), mountPath); + } +} diff --git a/trogdor/src/main/java/org/apache/kafka/trogdor/fault/Kibosh.java b/trogdor/src/main/java/org/apache/kafka/trogdor/fault/Kibosh.java new file mode 100644 index 0000000..91dd4a9 --- /dev/null +++ b/trogdor/src/main/java/org/apache/kafka/trogdor/fault/Kibosh.java @@ -0,0 +1,199 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.trogdor.fault; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.annotation.JsonSubTypes; +import com.fasterxml.jackson.annotation.JsonTypeInfo; +import org.apache.kafka.trogdor.common.JsonUtil; + +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Objects; +import java.util.TreeMap; + +public final class Kibosh { + public static final Kibosh INSTANCE = new Kibosh(); + + public final static String KIBOSH_CONTROL = "kibosh_control"; + + @JsonTypeInfo(use = JsonTypeInfo.Id.NAME, + include = JsonTypeInfo.As.PROPERTY, + property = "type") + @JsonSubTypes({ + @JsonSubTypes.Type(value = KiboshFilesUnreadableFaultSpec.class, name = "unreadable"), + }) + public static abstract class KiboshFaultSpec { + @Override + public final boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + return Objects.equals(toString(), o.toString()); + } + + @Override + public final int hashCode() { + return toString().hashCode(); + } + + @Override + public final String toString() { + return JsonUtil.toJsonString(this); + } + } + + public static class KiboshFilesUnreadableFaultSpec extends KiboshFaultSpec { + private final String prefix; + private final int errorCode; + + @JsonCreator + public KiboshFilesUnreadableFaultSpec(@JsonProperty("prefix") String prefix, + @JsonProperty("errorCode") int errorCode) { + this.prefix = prefix; + this.errorCode = errorCode; + } + + @JsonProperty + public String prefix() { + return prefix; + } + + @JsonProperty + public int errorCode() { + return errorCode; + } + } + + private static class KiboshProcess { + private final Path controlPath; + + KiboshProcess(String mountPath) { + this.controlPath = Paths.get(mountPath, KIBOSH_CONTROL); + if (!Files.exists(controlPath)) { + throw new RuntimeException("Can't find file " + controlPath); + } + } + + synchronized void addFault(KiboshFaultSpec toAdd) throws IOException { + KiboshControlFile file = KiboshControlFile.read(controlPath); + List faults = new ArrayList<>(file.faults()); + faults.add(toAdd); + new KiboshControlFile(faults).write(controlPath); + } + + synchronized void removeFault(KiboshFaultSpec toRemove) throws IOException { + KiboshControlFile file = KiboshControlFile.read(controlPath); + List faults = new ArrayList<>(); + boolean foundToRemove = false; + for (KiboshFaultSpec fault : file.faults()) { + if (fault.equals(toRemove)) { + foundToRemove = true; + } else { + faults.add(fault); + } + } + if (!foundToRemove) { + throw new RuntimeException("Failed to find fault " + toRemove + ". "); + } + new KiboshControlFile(faults).write(controlPath); + } + } + + public static class KiboshControlFile { + private final List faults; + + public final static KiboshControlFile EMPTY = + new KiboshControlFile(Collections.emptyList()); + + public static KiboshControlFile read(Path controlPath) throws IOException { + byte[] controlFileBytes = Files.readAllBytes(controlPath); + return JsonUtil.JSON_SERDE.readValue(controlFileBytes, KiboshControlFile.class); + } + + @JsonCreator + public KiboshControlFile(@JsonProperty("faults") List faults) { + this.faults = faults == null ? 
new ArrayList<>() : faults; + } + + @JsonProperty + public List faults() { + return faults; + } + + public void write(Path controlPath) throws IOException { + Files.write(controlPath, JsonUtil.JSON_SERDE.writeValueAsBytes(this)); + } + + @Override + public final boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + return Objects.equals(toString(), o.toString()); + } + + @Override + public final int hashCode() { + return toString().hashCode(); + } + + @Override + public final String toString() { + return JsonUtil.toJsonString(this); + } + } + + private final TreeMap processes = new TreeMap<>(); + + private Kibosh() { + } + + /** + * Get or create a KiboshProcess object to manage the Kibosh process at a given path. + */ + private synchronized KiboshProcess findProcessObject(String mountPath) { + String path = Paths.get(mountPath).normalize().toString(); + KiboshProcess process = processes.get(path); + if (process == null) { + process = new KiboshProcess(mountPath); + processes.put(path, process); + } + return process; + } + + /** + * Add a new Kibosh fault. + */ + void addFault(String mountPath, KiboshFaultSpec spec) throws IOException { + KiboshProcess process = findProcessObject(mountPath); + process.addFault(spec); + } + + /** + * Remove a Kibosh fault. + */ + void removeFault(String mountPath, KiboshFaultSpec spec) throws IOException { + KiboshProcess process = findProcessObject(mountPath); + process.removeFault(spec); + } +} diff --git a/trogdor/src/main/java/org/apache/kafka/trogdor/fault/KiboshFaultController.java b/trogdor/src/main/java/org/apache/kafka/trogdor/fault/KiboshFaultController.java new file mode 100644 index 0000000..140abf1 --- /dev/null +++ b/trogdor/src/main/java/org/apache/kafka/trogdor/fault/KiboshFaultController.java @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.trogdor.fault; + +import org.apache.kafka.trogdor.common.Topology; +import org.apache.kafka.trogdor.task.TaskController; + +import java.util.Set; + +public class KiboshFaultController implements TaskController { + private final Set nodeNames; + + public KiboshFaultController(Set nodeNames) { + this.nodeNames = nodeNames; + } + + @Override + public Set targetNodes(Topology topology) { + return nodeNames; + } +} diff --git a/trogdor/src/main/java/org/apache/kafka/trogdor/fault/KiboshFaultWorker.java b/trogdor/src/main/java/org/apache/kafka/trogdor/fault/KiboshFaultWorker.java new file mode 100644 index 0000000..97934a8 --- /dev/null +++ b/trogdor/src/main/java/org/apache/kafka/trogdor/fault/KiboshFaultWorker.java @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.trogdor.fault; + +import com.fasterxml.jackson.databind.node.TextNode; +import org.apache.kafka.common.internals.KafkaFutureImpl; +import org.apache.kafka.trogdor.common.Platform; +import org.apache.kafka.trogdor.fault.Kibosh.KiboshFaultSpec; +import org.apache.kafka.trogdor.task.TaskWorker; +import org.apache.kafka.trogdor.task.WorkerStatusTracker; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class KiboshFaultWorker implements TaskWorker { + private static final Logger log = LoggerFactory.getLogger(KiboshFaultWorker.class); + + private final String id; + + private final KiboshFaultSpec spec; + + private final String mountPath; + + private WorkerStatusTracker status; + + public KiboshFaultWorker(String id, KiboshFaultSpec spec, String mountPath) { + this.id = id; + this.spec = spec; + this.mountPath = mountPath; + } + + @Override + public void start(Platform platform, WorkerStatusTracker status, + KafkaFutureImpl errorFuture) throws Exception { + log.info("Activating {} {}: {}.", spec.getClass().getSimpleName(), id, spec); + this.status = status; + this.status.update(new TextNode("Adding fault " + id)); + Kibosh.INSTANCE.addFault(mountPath, spec); + this.status.update(new TextNode("Added fault " + id)); + } + + @Override + public void stop(Platform platform) throws Exception { + log.info("Deactivating {} {}: {}.", spec.getClass().getSimpleName(), id, spec); + this.status.update(new TextNode("Removing fault " + id)); + Kibosh.INSTANCE.removeFault(mountPath, spec); + this.status.update(new TextNode("Removed fault " + id)); + } +} diff --git a/trogdor/src/main/java/org/apache/kafka/trogdor/fault/NetworkPartitionFaultController.java b/trogdor/src/main/java/org/apache/kafka/trogdor/fault/NetworkPartitionFaultController.java new file mode 100644 index 0000000..d90534f --- /dev/null +++ b/trogdor/src/main/java/org/apache/kafka/trogdor/fault/NetworkPartitionFaultController.java @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.trogdor.fault; + +import org.apache.kafka.trogdor.common.Topology; +import org.apache.kafka.trogdor.task.TaskController; + +import java.util.HashSet; +import java.util.List; +import java.util.Set; + +public class NetworkPartitionFaultController implements TaskController { + private final List> partitionSets; + + public NetworkPartitionFaultController(List> partitionSets) { + this.partitionSets = partitionSets; + } + + @Override + public Set targetNodes(Topology topology) { + Set targetNodes = new HashSet<>(); + for (Set partitionSet : partitionSets) { + targetNodes.addAll(partitionSet); + } + return targetNodes; + } +} diff --git a/trogdor/src/main/java/org/apache/kafka/trogdor/fault/NetworkPartitionFaultSpec.java b/trogdor/src/main/java/org/apache/kafka/trogdor/fault/NetworkPartitionFaultSpec.java new file mode 100644 index 0000000..c3df792 --- /dev/null +++ b/trogdor/src/main/java/org/apache/kafka/trogdor/fault/NetworkPartitionFaultSpec.java @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.trogdor.fault; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import org.apache.kafka.trogdor.task.TaskController; +import org.apache.kafka.trogdor.task.TaskSpec; +import org.apache.kafka.trogdor.task.TaskWorker; + +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Set; + +/** + * The specification for a fault that creates a network partition. + */ +public class NetworkPartitionFaultSpec extends TaskSpec { + private final List> partitions; + + @JsonCreator + public NetworkPartitionFaultSpec(@JsonProperty("startMs") long startMs, + @JsonProperty("durationMs") long durationMs, + @JsonProperty("partitions") List> partitions) { + super(startMs, durationMs); + this.partitions = partitions == null ? 
new ArrayList>() : partitions; + } + + @JsonProperty + public List> partitions() { + return partitions; + } + + @Override + public TaskController newController(String id) { + return new NetworkPartitionFaultController(partitionSets()); + } + + @Override + public TaskWorker newTaskWorker(String id) { + return new NetworkPartitionFaultWorker(id, partitionSets()); + } + + private List> partitionSets() { + List> partitionSets = new ArrayList<>(); + HashSet prevNodes = new HashSet<>(); + for (List partition : this.partitions()) { + for (String nodeName : partition) { + if (prevNodes.contains(nodeName)) { + throw new RuntimeException("Node " + nodeName + + " appears in more than one partition."); + } + prevNodes.add(nodeName); + partitionSets.add(new HashSet<>(partition)); + } + } + return partitionSets; + } +} diff --git a/trogdor/src/main/java/org/apache/kafka/trogdor/fault/NetworkPartitionFaultWorker.java b/trogdor/src/main/java/org/apache/kafka/trogdor/fault/NetworkPartitionFaultWorker.java new file mode 100644 index 0000000..1b99a93 --- /dev/null +++ b/trogdor/src/main/java/org/apache/kafka/trogdor/fault/NetworkPartitionFaultWorker.java @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.trogdor.fault; + +import com.fasterxml.jackson.databind.node.TextNode; +import org.apache.kafka.common.internals.KafkaFutureImpl; +import org.apache.kafka.trogdor.common.Node; +import org.apache.kafka.trogdor.common.Platform; +import org.apache.kafka.trogdor.common.Topology; +import org.apache.kafka.trogdor.task.TaskWorker; +import org.apache.kafka.trogdor.task.WorkerStatusTracker; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.net.InetAddress; +import java.util.List; +import java.util.Set; +import java.util.TreeSet; + +public class NetworkPartitionFaultWorker implements TaskWorker { + private static final Logger log = LoggerFactory.getLogger(NetworkPartitionFaultWorker.class); + + private final String id; + + private final List> partitionSets; + + private WorkerStatusTracker status; + + public NetworkPartitionFaultWorker(String id, List> partitionSets) { + this.id = id; + this.partitionSets = partitionSets; + } + + @Override + public void start(Platform platform, WorkerStatusTracker status, + KafkaFutureImpl errorFuture) throws Exception { + log.info("Activating NetworkPartitionFault {}.", id); + this.status = status; + this.status.update(new TextNode("creating network partition " + id)); + runIptablesCommands(platform, "-A"); + this.status.update(new TextNode("created network partition " + id)); + } + + @Override + public void stop(Platform platform) throws Exception { + log.info("Deactivating NetworkPartitionFault {}.", id); + this.status.update(new TextNode("removing network partition " + id)); + runIptablesCommands(platform, "-D"); + this.status.update(new TextNode("removed network partition " + id)); + } + + private void runIptablesCommands(Platform platform, String iptablesAction) throws Exception { + Node curNode = platform.curNode(); + Topology topology = platform.topology(); + TreeSet toBlock = new TreeSet<>(); + for (Set partitionSet : partitionSets) { + if (!partitionSet.contains(curNode.name())) { + for (String nodeName : partitionSet) { + toBlock.add(nodeName); + } + } + } + for (String nodeName : toBlock) { + Node node = topology.node(nodeName); + InetAddress addr = InetAddress.getByName(node.hostname()); + platform.runCommand(new String[] { + "sudo", "iptables", iptablesAction, "INPUT", "-p", "tcp", "-s", + addr.getHostAddress(), "-j", "DROP", "-m", "comment", "--comment", nodeName + }); + } + } +} diff --git a/trogdor/src/main/java/org/apache/kafka/trogdor/fault/ProcessStopFaultController.java b/trogdor/src/main/java/org/apache/kafka/trogdor/fault/ProcessStopFaultController.java new file mode 100644 index 0000000..6ec803a --- /dev/null +++ b/trogdor/src/main/java/org/apache/kafka/trogdor/fault/ProcessStopFaultController.java @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.trogdor.fault; + +import org.apache.kafka.trogdor.common.Topology; +import org.apache.kafka.trogdor.task.TaskController; +import java.util.Set; + +public class ProcessStopFaultController implements TaskController { + private final Set nodeNames; + + public ProcessStopFaultController(Set nodeNames) { + this.nodeNames = nodeNames; + } + + @Override + public Set targetNodes(Topology topology) { + return nodeNames; + } +} diff --git a/trogdor/src/main/java/org/apache/kafka/trogdor/fault/ProcessStopFaultSpec.java b/trogdor/src/main/java/org/apache/kafka/trogdor/fault/ProcessStopFaultSpec.java new file mode 100644 index 0000000..cda2fb0 --- /dev/null +++ b/trogdor/src/main/java/org/apache/kafka/trogdor/fault/ProcessStopFaultSpec.java @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.trogdor.fault; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import org.apache.kafka.trogdor.task.TaskController; +import org.apache.kafka.trogdor.task.TaskSpec; +import org.apache.kafka.trogdor.task.TaskWorker; + +import java.util.HashSet; +import java.util.List; +import java.util.Set; + +/** + * The specification for a fault that creates a network partition. + */ +public class ProcessStopFaultSpec extends TaskSpec { + private final Set nodeNames; + private final String javaProcessName; + + @JsonCreator + public ProcessStopFaultSpec(@JsonProperty("startMs") long startMs, + @JsonProperty("durationMs") long durationMs, + @JsonProperty("nodeNames") List nodeNames, + @JsonProperty("javaProcessName") String javaProcessName) { + super(startMs, durationMs); + this.nodeNames = nodeNames == null ? new HashSet() : new HashSet<>(nodeNames); + this.javaProcessName = javaProcessName == null ? 
"" : javaProcessName; + } + + @JsonProperty + public Set nodeNames() { + return nodeNames; + } + + @JsonProperty + public String javaProcessName() { + return javaProcessName; + } + + @Override + public TaskController newController(String id) { + return new ProcessStopFaultController(nodeNames); + } + + @Override + public TaskWorker newTaskWorker(String id) { + return new ProcessStopFaultWorker(id, javaProcessName); + } +} diff --git a/trogdor/src/main/java/org/apache/kafka/trogdor/fault/ProcessStopFaultWorker.java b/trogdor/src/main/java/org/apache/kafka/trogdor/fault/ProcessStopFaultWorker.java new file mode 100644 index 0000000..ef97e7b --- /dev/null +++ b/trogdor/src/main/java/org/apache/kafka/trogdor/fault/ProcessStopFaultWorker.java @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.trogdor.fault; + +import com.fasterxml.jackson.databind.node.TextNode; +import org.apache.kafka.common.internals.KafkaFutureImpl; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.trogdor.common.Platform; +import org.apache.kafka.trogdor.task.TaskWorker; +import org.apache.kafka.trogdor.task.WorkerStatusTracker; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.ArrayList; +import java.util.List; + +public class ProcessStopFaultWorker implements TaskWorker { + private static final Logger log = LoggerFactory.getLogger(ProcessStopFaultWorker.class); + + private final String id; + + private final String javaProcessName; + + private WorkerStatusTracker status; + + public ProcessStopFaultWorker(String id, String javaProcessName) { + this.id = id; + this.javaProcessName = javaProcessName; + } + + @Override + public void start(Platform platform, WorkerStatusTracker status, + KafkaFutureImpl errorFuture) throws Exception { + this.status = status; + log.info("Activating ProcessStopFault {}.", id); + this.status.update(new TextNode("stopping " + javaProcessName)); + sendSignals(platform, "SIGSTOP"); + this.status.update(new TextNode("stopped " + javaProcessName)); + } + + @Override + public void stop(Platform platform) throws Exception { + log.info("Deactivating ProcessStopFault {}.", id); + this.status.update(new TextNode("resuming " + javaProcessName)); + sendSignals(platform, "SIGCONT"); + this.status.update(new TextNode("resumed " + javaProcessName)); + } + + private void sendSignals(Platform platform, String signalName) throws Exception { + String jcmdOutput = platform.runCommand(new String[] {"jcmd"}); + String[] lines = jcmdOutput.split("\n"); + List pids = new ArrayList<>(); + for (String line : lines) { + if (line.contains(javaProcessName)) { + String[] components = line.split(" "); + try { + pids.add(Integer.parseInt(components[0])); + } catch (NumberFormatException e) 
{ + log.error("Failed to parse process ID from line", e); + } + } + } + if (pids.isEmpty()) { + log.error("{}: no processes containing {} found to send {} to.", + id, javaProcessName, signalName); + } else { + log.info("{}: sending {} to {} pid(s) {}", + id, signalName, javaProcessName, Utils.join(pids, ", ")); + for (Integer pid : pids) { + platform.runCommand(new String[] {"kill", "-" + signalName, pid.toString()}); + } + } + } +} diff --git a/trogdor/src/main/java/org/apache/kafka/trogdor/rest/AgentStatusResponse.java b/trogdor/src/main/java/org/apache/kafka/trogdor/rest/AgentStatusResponse.java new file mode 100644 index 0000000..d41a54b --- /dev/null +++ b/trogdor/src/main/java/org/apache/kafka/trogdor/rest/AgentStatusResponse.java @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.trogdor.rest; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.util.TreeMap; + +/** + * A response from the Trogdor agent about the worker states and specifications. + */ +public class AgentStatusResponse extends Message { + private final long serverStartMs; + private final TreeMap workers; + + @JsonCreator + public AgentStatusResponse(@JsonProperty("serverStartMs") long serverStartMs, + @JsonProperty("workers") TreeMap workers) { + this.serverStartMs = serverStartMs; + this.workers = workers == null ? new TreeMap() : workers; + } + + @JsonProperty + public long serverStartMs() { + return serverStartMs; + } + + @JsonProperty + public TreeMap workers() { + return workers; + } +} diff --git a/trogdor/src/main/java/org/apache/kafka/trogdor/rest/CoordinatorShutdownRequest.java b/trogdor/src/main/java/org/apache/kafka/trogdor/rest/CoordinatorShutdownRequest.java new file mode 100644 index 0000000..1aacaaf --- /dev/null +++ b/trogdor/src/main/java/org/apache/kafka/trogdor/rest/CoordinatorShutdownRequest.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.trogdor.rest; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +/** + * A request to the Trogdor coordinator to shut down. + */ +public class CoordinatorShutdownRequest extends Message { + private final boolean stopAgents; + + @JsonCreator + public CoordinatorShutdownRequest(@JsonProperty("stopAgents") boolean stopAgents) { + this.stopAgents = stopAgents; + } + + @JsonProperty + public boolean stopAgents() { + return stopAgents; + } +} diff --git a/trogdor/src/main/java/org/apache/kafka/trogdor/rest/CoordinatorStatusResponse.java b/trogdor/src/main/java/org/apache/kafka/trogdor/rest/CoordinatorStatusResponse.java new file mode 100644 index 0000000..8840d29 --- /dev/null +++ b/trogdor/src/main/java/org/apache/kafka/trogdor/rest/CoordinatorStatusResponse.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.trogdor.rest; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +/** + * A status response from the Trogdor coordinator. + */ +public class CoordinatorStatusResponse extends Message { + private final long serverStartMs; + + @JsonCreator + public CoordinatorStatusResponse(@JsonProperty("serverStartMs") long serverStartMs) { + this.serverStartMs = serverStartMs; + } + + @JsonProperty + public long serverStartMs() { + return serverStartMs; + } +} diff --git a/trogdor/src/main/java/org/apache/kafka/trogdor/rest/CreateTaskRequest.java b/trogdor/src/main/java/org/apache/kafka/trogdor/rest/CreateTaskRequest.java new file mode 100644 index 0000000..d463e36 --- /dev/null +++ b/trogdor/src/main/java/org/apache/kafka/trogdor/rest/CreateTaskRequest.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.trogdor.rest; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import org.apache.kafka.trogdor.task.TaskSpec; + +/** + * A request to the Trogdor coordinator to create a task. + */ +public class CreateTaskRequest extends Message { + private final String id; + private final TaskSpec spec; + + @JsonCreator + public CreateTaskRequest(@JsonProperty("id") String id, + @JsonProperty("spec") TaskSpec spec) { + this.id = id; + this.spec = spec; + } + + @JsonProperty + public String id() { + return id; + } + + @JsonProperty + public TaskSpec spec() { + return spec; + } +} diff --git a/trogdor/src/main/java/org/apache/kafka/trogdor/rest/CreateWorkerRequest.java b/trogdor/src/main/java/org/apache/kafka/trogdor/rest/CreateWorkerRequest.java new file mode 100644 index 0000000..4acc943 --- /dev/null +++ b/trogdor/src/main/java/org/apache/kafka/trogdor/rest/CreateWorkerRequest.java @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.trogdor.rest; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import org.apache.kafka.trogdor.task.TaskSpec; + +/** + * A request to the Trogdor agent to create a worker. + */ +public class CreateWorkerRequest extends Message { + private final long workerId; + private final String taskId; + private final TaskSpec spec; + + @JsonCreator + public CreateWorkerRequest(@JsonProperty("workerId") long workerId, + @JsonProperty("taskId") String taskId, + @JsonProperty("spec") TaskSpec spec) { + this.workerId = workerId; + this.taskId = taskId; + this.spec = spec; + } + + @JsonProperty + public long workerId() { + return workerId; + } + + @JsonProperty + public String taskId() { + return taskId; + } + + @JsonProperty + public TaskSpec spec() { + return spec; + } +} diff --git a/trogdor/src/main/java/org/apache/kafka/trogdor/rest/DestroyTaskRequest.java b/trogdor/src/main/java/org/apache/kafka/trogdor/rest/DestroyTaskRequest.java new file mode 100644 index 0000000..d782d5d --- /dev/null +++ b/trogdor/src/main/java/org/apache/kafka/trogdor/rest/DestroyTaskRequest.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.trogdor.rest; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +/** + * A request to the Trogdor coordinator to delete all memory of a task. + */ +public class DestroyTaskRequest extends Message { + private final String id; + + @JsonCreator + public DestroyTaskRequest(@JsonProperty("id") String id) { + this.id = id; + } + + @JsonProperty + public String id() { + return id; + } +} diff --git a/trogdor/src/main/java/org/apache/kafka/trogdor/rest/DestroyWorkerRequest.java b/trogdor/src/main/java/org/apache/kafka/trogdor/rest/DestroyWorkerRequest.java new file mode 100644 index 0000000..e5a8969 --- /dev/null +++ b/trogdor/src/main/java/org/apache/kafka/trogdor/rest/DestroyWorkerRequest.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.trogdor.rest; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +/** + * A request to the Trogdor agent to delete all memory of a task. + */ +public class DestroyWorkerRequest extends Message { + private final long workerId; + + @JsonCreator + public DestroyWorkerRequest(@JsonProperty("workerId") long workerId) { + this.workerId = workerId; + } + + @JsonProperty + public long workerId() { + return workerId; + } +} diff --git a/trogdor/src/main/java/org/apache/kafka/trogdor/rest/Empty.java b/trogdor/src/main/java/org/apache/kafka/trogdor/rest/Empty.java new file mode 100644 index 0000000..da2fcba --- /dev/null +++ b/trogdor/src/main/java/org/apache/kafka/trogdor/rest/Empty.java @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
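// The request and response types in this package (CoordinatorShutdownRequest, CreateTaskRequest,
// DestroyTaskRequest, and so on) are plain immutable Jackson beans: the @JsonCreator constructor
// drives deserialization and the @JsonProperty accessors drive serialization. As a rough
// illustration (using a plain ObjectMapper here rather than the project's own JSON serde),
// a CoordinatorShutdownRequest round-trips like this:

import com.fasterxml.jackson.databind.ObjectMapper;
import org.apache.kafka.trogdor.rest.CoordinatorShutdownRequest;

public class MessageRoundTripSketch {
    public static void main(String[] args) throws Exception {
        ObjectMapper mapper = new ObjectMapper();
        CoordinatorShutdownRequest req = new CoordinatorShutdownRequest(true);
        String json = mapper.writeValueAsString(req);
        System.out.println(json); // {"stopAgents":true}
        CoordinatorShutdownRequest copy = mapper.readValue(json, CoordinatorShutdownRequest.class);
        System.out.println(copy.stopAgents()); // true
    }
}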
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.trogdor.rest; + +import com.fasterxml.jackson.annotation.JsonCreator; +import org.apache.kafka.trogdor.common.JsonUtil; + +/** + * An empty request or response. + */ +public class Empty { + public static final Empty INSTANCE = new Empty(); + + @JsonCreator + public Empty() { + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + return true; + } + + @Override + public int hashCode() { + return 1; + } + + @Override + public String toString() { + return JsonUtil.toJsonString(this); + } +} diff --git a/trogdor/src/main/java/org/apache/kafka/trogdor/rest/ErrorResponse.java b/trogdor/src/main/java/org/apache/kafka/trogdor/rest/ErrorResponse.java new file mode 100644 index 0000000..08bf6cd --- /dev/null +++ b/trogdor/src/main/java/org/apache/kafka/trogdor/rest/ErrorResponse.java @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.trogdor.rest; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import org.apache.kafka.trogdor.common.JsonUtil; + +import java.util.Objects; + +/** + * An error response. + */ +public class ErrorResponse { + private final int code; + private final String message; + + @JsonCreator + public ErrorResponse(@JsonProperty("code") int code, + @JsonProperty("message") String message) { + this.code = code; + this.message = message; + } + + @JsonProperty + public int code() { + return code; + } + + @JsonProperty + public String message() { + return message; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + ErrorResponse that = (ErrorResponse) o; + return Objects.equals(code, that.code) && + Objects.equals(message, that.message); + } + + @Override + public int hashCode() { + return Objects.hash(code, message); + } + + @Override + public String toString() { + return JsonUtil.toJsonString(this); + } +} diff --git a/trogdor/src/main/java/org/apache/kafka/trogdor/rest/JsonRestServer.java b/trogdor/src/main/java/org/apache/kafka/trogdor/rest/JsonRestServer.java new file mode 100644 index 0000000..e5388f8 --- /dev/null +++ b/trogdor/src/main/java/org/apache/kafka/trogdor/rest/JsonRestServer.java @@ -0,0 +1,294 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.trogdor.rest; + +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.jaxrs.json.JacksonJsonProvider; + +import org.apache.kafka.common.utils.ThreadUtils; +import org.apache.kafka.trogdor.common.JsonUtil; +import org.eclipse.jetty.server.Connector; +import org.eclipse.jetty.server.CustomRequestLog; +import org.eclipse.jetty.server.Handler; +import org.eclipse.jetty.server.Server; +import org.eclipse.jetty.server.ServerConnector; +import org.eclipse.jetty.server.Slf4jRequestLogWriter; +import org.eclipse.jetty.server.handler.DefaultHandler; +import org.eclipse.jetty.server.handler.HandlerCollection; +import org.eclipse.jetty.server.handler.RequestLogHandler; +import org.eclipse.jetty.server.handler.StatisticsHandler; +import org.eclipse.jetty.servlet.ServletContextHandler; +import org.eclipse.jetty.servlet.ServletHolder; +import org.glassfish.jersey.server.ResourceConfig; +import org.glassfish.jersey.servlet.ServletContainer; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.net.HttpURLConnection; +import java.net.URL; +import java.nio.charset.StandardCharsets; +import java.util.concurrent.Callable; +import java.util.concurrent.Executors; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.TimeUnit; + +/** + * Embedded server for the REST API that provides the control plane for Trogdor. + */ +public class JsonRestServer { + private static final Logger log = LoggerFactory.getLogger(JsonRestServer.class); + + private static final long GRACEFUL_SHUTDOWN_TIMEOUT_MS = 100; + + private final ScheduledExecutorService shutdownExecutor; + + private final Server jettyServer; + + private final ServerConnector connector; + + /** + * Create a REST server for this herder using the specified configs. + * + * @param port The port number to use for the REST server, or + * 0 to use a random port. + */ + public JsonRestServer(int port) { + this.shutdownExecutor = Executors.newSingleThreadScheduledExecutor( + ThreadUtils.createThreadFactory("JsonRestServerCleanupExecutor", false)); + this.jettyServer = new Server(); + this.connector = new ServerConnector(jettyServer); + if (port > 0) { + connector.setPort(port); + } + jettyServer.setConnectors(new Connector[]{connector}); + } + + /** + * Start the JsonRestServer. + * + * @param resources The path handling resources to register. + */ + public void start(Object... 
resources) { + log.info("Starting REST server"); + ResourceConfig resourceConfig = new ResourceConfig(); + resourceConfig.register(new JacksonJsonProvider(JsonUtil.JSON_SERDE)); + for (Object resource : resources) { + resourceConfig.register(resource); + log.info("Registered resource {}", resource); + } + resourceConfig.register(RestExceptionMapper.class); + ServletContainer servletContainer = new ServletContainer(resourceConfig); + ServletHolder servletHolder = new ServletHolder(servletContainer); + ServletContextHandler context = new ServletContextHandler(ServletContextHandler.SESSIONS); + context.setContextPath("/"); + context.addServlet(servletHolder, "/*"); + + RequestLogHandler requestLogHandler = new RequestLogHandler(); + Slf4jRequestLogWriter slf4jRequestLogWriter = new Slf4jRequestLogWriter(); + slf4jRequestLogWriter.setLoggerName(JsonRestServer.class.getCanonicalName()); + CustomRequestLog requestLog = new CustomRequestLog(slf4jRequestLogWriter, CustomRequestLog.EXTENDED_NCSA_FORMAT + " %{ms}T"); + requestLogHandler.setRequestLog(requestLog); + + HandlerCollection handlers = new HandlerCollection(); + handlers.setHandlers(new Handler[]{context, new DefaultHandler(), requestLogHandler}); + StatisticsHandler statsHandler = new StatisticsHandler(); + statsHandler.setHandler(handlers); + jettyServer.setHandler(statsHandler); + /* Needed for graceful shutdown as per `setStopTimeout` documentation */ + jettyServer.setStopTimeout(GRACEFUL_SHUTDOWN_TIMEOUT_MS); + jettyServer.setStopAtShutdown(true); + + try { + jettyServer.start(); + } catch (Exception e) { + throw new RuntimeException("Unable to start REST server", e); + } + log.info("REST server listening at " + jettyServer.getURI()); + } + + public int port() { + return connector.getLocalPort(); + } + + /** + * Initiate shutdown, but do not wait for it to complete. + */ + public void beginShutdown() { + if (!shutdownExecutor.isShutdown()) { + shutdownExecutor.submit((Callable) () -> { + try { + log.info("Stopping REST server"); + jettyServer.stop(); + jettyServer.join(); + log.info("REST server stopped"); + } catch (Exception e) { + log.error("Unable to stop REST server", e); + } finally { + jettyServer.destroy(); + } + shutdownExecutor.shutdown(); + return null; + }); + } + } + + /** + * Wait for shutdown to complete. May be called prior to beginShutdown. + */ + public void waitForShutdown() throws InterruptedException { + while (!shutdownExecutor.isShutdown()) { + shutdownExecutor.awaitTermination(1, TimeUnit.DAYS); + } + } + + /** + * Make an HTTP request. + * + * @param logger The logger to use. + * @param url HTTP connection will be established with this url. + * @param method HTTP method ("GET", "POST", "PUT", etc.) + * @param requestBodyData Object to serialize as JSON and send in the request body. + * @param responseFormat Expected format of the response to the HTTP request. + * @param The type of the deserialized response to the HTTP request. + * @return The deserialized response to the HTTP request, or null if no data is expected. + */ + public static HttpResponse httpRequest(Logger logger, String url, String method, + Object requestBodyData, TypeReference responseFormat) throws IOException { + HttpURLConnection connection = null; + try { + String serializedBody = requestBodyData == null ? 
null : + JsonUtil.JSON_SERDE.writeValueAsString(requestBodyData); + logger.debug("Sending {} with input {} to {}", method, serializedBody, url); + connection = (HttpURLConnection) new URL(url).openConnection(); + connection.setRequestMethod(method); + connection.setRequestProperty("User-Agent", "kafka"); + connection.setRequestProperty("Accept", "application/json"); + + // connection.getResponseCode() implicitly calls getInputStream, so always set + // this to true. + connection.setDoInput(true); + + connection.setUseCaches(false); + + if (requestBodyData != null) { + connection.setRequestProperty("Content-Type", "application/json"); + connection.setDoOutput(true); + + OutputStream os = connection.getOutputStream(); + os.write(serializedBody.getBytes(StandardCharsets.UTF_8)); + os.flush(); + os.close(); + } + + int responseCode = connection.getResponseCode(); + if (responseCode == HttpURLConnection.HTTP_NO_CONTENT) { + return new HttpResponse<>(null, new ErrorResponse(responseCode, connection.getResponseMessage())); + } else if ((responseCode >= 200) && (responseCode < 300)) { + InputStream is = connection.getInputStream(); + T result = JsonUtil.JSON_SERDE.readValue(is, responseFormat); + is.close(); + return new HttpResponse<>(result, null); + } else { + // If the response code was not in the 200s, we assume that this is an error + // response. + InputStream es = connection.getErrorStream(); + if (es == null) { + // Handle the case where HttpURLConnection#getErrorStream returns null. + return new HttpResponse<>(null, new ErrorResponse(responseCode, "")); + } + // Try to read the error response JSON. + ErrorResponse error = JsonUtil.JSON_SERDE.readValue(es, ErrorResponse.class); + es.close(); + return new HttpResponse<>(null, error); + } + } finally { + if (connection != null) { + connection.disconnect(); + } + } + } + + /** + * Make an HTTP request with retries. + * + * @param url HTTP connection will be established with this url. + * @param method HTTP method ("GET", "POST", "PUT", etc.) + * @param requestBodyData Object to serialize as JSON and send in the request body. + * @param responseFormat Expected format of the response to the HTTP request. + * @param <T> The type of the deserialized response to the HTTP request. + * @return The deserialized response to the HTTP request, or null if no data is expected. + */ + public static <T> HttpResponse<T> httpRequest(String url, String method, Object requestBodyData, + TypeReference<T> responseFormat, int maxTries) + throws IOException, InterruptedException { + return httpRequest(log, url, method, requestBodyData, responseFormat, maxTries); + } + + /** + * Make an HTTP request with retries. + * + * @param logger The logger to use. + * @param url HTTP connection will be established with this url. + * @param method HTTP method ("GET", "POST", "PUT", etc.) + * @param requestBodyData Object to serialize as JSON and send in the request body. + * @param responseFormat Expected format of the response to the HTTP request. + * @param <T> The type of the deserialized response to the HTTP request. + * @return The deserialized response to the HTTP request, or null if no data is expected. + */ + public static <T> HttpResponse<T> httpRequest(Logger logger, String url, String method, + Object requestBodyData, TypeReference<T> responseFormat, int maxTries) + throws IOException, InterruptedException { + IOException exc = null; + for (int tries = 0; tries < maxTries; tries++) { + if (tries > 0) { + Thread.sleep(tries > 1 ?
10 : 2); + } + try { + return httpRequest(logger, url, method, requestBodyData, responseFormat); + } catch (IOException e) { + logger.info("{} {}: error: {}", method, url, e.getMessage()); + exc = e; + } + } + throw exc; + } + + public static class HttpResponse { + private final T body; + private final ErrorResponse error; + + HttpResponse(T body, ErrorResponse error) { + this.body = body; + this.error = error; + } + + public T body() throws Exception { + if (error != null) { + throw RestExceptionMapper.toException(error.code(), error.message()); + } + return body; + } + + public ErrorResponse error() { + return error; + } + } +} diff --git a/trogdor/src/main/java/org/apache/kafka/trogdor/rest/Message.java b/trogdor/src/main/java/org/apache/kafka/trogdor/rest/Message.java new file mode 100644 index 0000000..c2ee840 --- /dev/null +++ b/trogdor/src/main/java/org/apache/kafka/trogdor/rest/Message.java @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.trogdor.rest; + +import org.apache.kafka.trogdor.common.JsonUtil; + +import java.util.Objects; + +public abstract class Message { + @Override + public final boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + return Objects.equals(toString(), o.toString()); + } + + @Override + public final int hashCode() { + return toString().hashCode(); + } + + @Override + public final String toString() { + return JsonUtil.toJsonString(this); + } +} diff --git a/trogdor/src/main/java/org/apache/kafka/trogdor/rest/RequestConflictException.java b/trogdor/src/main/java/org/apache/kafka/trogdor/rest/RequestConflictException.java new file mode 100644 index 0000000..2701f6a --- /dev/null +++ b/trogdor/src/main/java/org/apache/kafka/trogdor/rest/RequestConflictException.java @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.trogdor.rest; + +/** + * Indicates that a given request got an HTTP error 409: CONFLICT. 
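// Putting the pieces of JsonRestServer together: the embedded server is constructed on a port,
// JAX-RS resources are registered via start(), and shutdown happens asynchronously. A rough usage
// sketch; PingResource is a hypothetical resource used only for illustration, registered the same
// way a real Trogdor resource class would be.

import javax.ws.rs.GET;
import javax.ws.rs.Path;
import javax.ws.rs.Produces;
import javax.ws.rs.core.MediaType;

import org.apache.kafka.trogdor.rest.JsonRestServer;

@Path("/ping")
public class PingResource {
    @GET
    @Produces(MediaType.APPLICATION_JSON)
    public String ping() {
        return "\"pong\"";
    }

    public static void main(String[] args) throws Exception {
        JsonRestServer server = new JsonRestServer(0); // 0 picks a random free port
        server.start(new PingResource());              // register JAX-RS resources
        System.out.println("listening on port " + server.port());
        server.beginShutdown();                        // async: Jetty is stopped on a cleanup thread
        server.waitForShutdown();                      // block until the server has fully stopped
    }
}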
+ */ +public class RequestConflictException extends RuntimeException { + private static final long serialVersionUID = 1L; + + public RequestConflictException(String message) { + super(message); + } + + public RequestConflictException() { + super(); + } +} diff --git a/trogdor/src/main/java/org/apache/kafka/trogdor/rest/RestExceptionMapper.java b/trogdor/src/main/java/org/apache/kafka/trogdor/rest/RestExceptionMapper.java new file mode 100644 index 0000000..57c54ec --- /dev/null +++ b/trogdor/src/main/java/org/apache/kafka/trogdor/rest/RestExceptionMapper.java @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.trogdor.rest; + +import com.fasterxml.jackson.databind.JsonMappingException; +import com.fasterxml.jackson.databind.exc.InvalidTypeIdException; +import org.apache.kafka.common.errors.InvalidRequestException; +import org.apache.kafka.common.errors.SerializationException; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.ws.rs.NotFoundException; +import javax.ws.rs.core.Response; +import javax.ws.rs.ext.ExceptionMapper; + +public class RestExceptionMapper implements ExceptionMapper { + private static final Logger log = LoggerFactory.getLogger(RestExceptionMapper.class); + + @Override + public Response toResponse(Throwable e) { + if (log.isDebugEnabled()) { + log.debug("Uncaught exception in REST call: ", e); + } else if (log.isInfoEnabled()) { + log.info("Uncaught exception in REST call: {}", e.getMessage()); + } + if (e instanceof NotFoundException) { + return buildResponse(Response.Status.NOT_FOUND, e); + } else if (e instanceof InvalidRequestException) { + return buildResponse(Response.Status.BAD_REQUEST, e); + } else if (e instanceof InvalidTypeIdException) { + return buildResponse(Response.Status.NOT_IMPLEMENTED, e); + } else if (e instanceof JsonMappingException) { + return buildResponse(Response.Status.BAD_REQUEST, e); + } else if (e instanceof ClassNotFoundException) { + return buildResponse(Response.Status.NOT_IMPLEMENTED, e); + } else if (e instanceof SerializationException) { + return buildResponse(Response.Status.BAD_REQUEST, e); + } else if (e instanceof RequestConflictException) { + return buildResponse(Response.Status.CONFLICT, e); + } else { + return buildResponse(Response.Status.INTERNAL_SERVER_ERROR, e); + } + } + + public static Exception toException(int code, String msg) throws Exception { + if (code == Response.Status.NOT_FOUND.getStatusCode()) { + throw new NotFoundException(msg); + } else if (code == Response.Status.NOT_IMPLEMENTED.getStatusCode()) { + throw new ClassNotFoundException(msg); + } else if (code == Response.Status.BAD_REQUEST.getStatusCode()) { + throw new InvalidRequestException(msg); + } else if (code == 
Response.Status.CONFLICT.getStatusCode()) { + throw new RequestConflictException(msg); + } else { + throw new RuntimeException(msg); + } + } + + private Response buildResponse(Response.Status code, Throwable e) { + return Response.status(code). + entity(new ErrorResponse(code.getStatusCode(), e.getMessage())). + build(); + } +} diff --git a/trogdor/src/main/java/org/apache/kafka/trogdor/rest/StopTaskRequest.java b/trogdor/src/main/java/org/apache/kafka/trogdor/rest/StopTaskRequest.java new file mode 100644 index 0000000..704a961 --- /dev/null +++ b/trogdor/src/main/java/org/apache/kafka/trogdor/rest/StopTaskRequest.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.trogdor.rest; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +/** + * A request to the Trogdor agent to stop a task. + */ +public class StopTaskRequest extends Message { + private final String id; + + @JsonCreator + public StopTaskRequest(@JsonProperty("id") String id) { + this.id = (id == null) ? "" : id; + } + + @JsonProperty + public String id() { + return id; + } +} diff --git a/trogdor/src/main/java/org/apache/kafka/trogdor/rest/StopWorkerRequest.java b/trogdor/src/main/java/org/apache/kafka/trogdor/rest/StopWorkerRequest.java new file mode 100644 index 0000000..c1dcff3 --- /dev/null +++ b/trogdor/src/main/java/org/apache/kafka/trogdor/rest/StopWorkerRequest.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.trogdor.rest; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +/** + * A request to the Trogdor agent to stop a worker. 
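// The exception mapping above is symmetric: on the server, RestExceptionMapper.toResponse turns a
// thrown exception into an HTTP status plus an ErrorResponse body; on the client, HttpResponse.body()
// feeds that status back through RestExceptionMapper.toException. A small sketch of the client-side
// half (the code and message values are illustrative):

import org.apache.kafka.trogdor.rest.ErrorResponse;
import org.apache.kafka.trogdor.rest.RestExceptionMapper;

public class ErrorMappingSketch {
    public static void main(String[] args) {
        ErrorResponse notFound = new ErrorResponse(404, "No task with id foo");
        try {
            throw RestExceptionMapper.toException(notFound.code(), notFound.message());
        } catch (Exception e) {
            // Prints the mapped exception type (javax.ws.rs.NotFoundException) and its message.
            System.out.println(e.getClass().getName() + ": " + e.getMessage());
        }
    }
}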
+ */ +public class StopWorkerRequest extends Message { + private final long workerId; + + @JsonCreator + public StopWorkerRequest(@JsonProperty("workerId") long workerId) { + this.workerId = workerId; + } + + @JsonProperty + public long workerId() { + return workerId; + } +} diff --git a/trogdor/src/main/java/org/apache/kafka/trogdor/rest/TaskDone.java b/trogdor/src/main/java/org/apache/kafka/trogdor/rest/TaskDone.java new file mode 100644 index 0000000..6e9761b --- /dev/null +++ b/trogdor/src/main/java/org/apache/kafka/trogdor/rest/TaskDone.java @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.trogdor.rest; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.databind.JsonNode; +import org.apache.kafka.trogdor.task.TaskSpec; + +/** + * The state a task is in once it's done. + */ +public class TaskDone extends TaskState { + /** + * The time on the coordinator when the task was started. + */ + private final long startedMs; + + /** + * The time on the coordinator when the task was completed. + */ + private final long doneMs; + + /** + * Empty if the task completed without error; the error message otherwise. + */ + private final String error; + + /** + * True if the task was manually cancelled, rather than terminating itself. + */ + private final boolean cancelled; + + @JsonCreator + public TaskDone(@JsonProperty("spec") TaskSpec spec, + @JsonProperty("startedMs") long startedMs, + @JsonProperty("doneMs") long doneMs, + @JsonProperty("error") String error, + @JsonProperty("cancelled") boolean cancelled, + @JsonProperty("status") JsonNode status) { + super(spec, status); + this.startedMs = startedMs; + this.doneMs = doneMs; + this.error = error; + this.cancelled = cancelled; + } + + @JsonProperty + public long startedMs() { + return startedMs; + } + + @JsonProperty + public long doneMs() { + return doneMs; + } + + @JsonProperty + public String error() { + return error; + } + + @JsonProperty + public boolean cancelled() { + return cancelled; + } + + @Override + public TaskStateType stateType() { + return TaskStateType.DONE; + } +} diff --git a/trogdor/src/main/java/org/apache/kafka/trogdor/rest/TaskPending.java b/trogdor/src/main/java/org/apache/kafka/trogdor/rest/TaskPending.java new file mode 100644 index 0000000..ca1d314 --- /dev/null +++ b/trogdor/src/main/java/org/apache/kafka/trogdor/rest/TaskPending.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.trogdor.rest; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.databind.node.NullNode; +import org.apache.kafka.trogdor.task.TaskSpec; + +/** + * The state for a task which is still pending. + */ +public class TaskPending extends TaskState { + @JsonCreator + public TaskPending(@JsonProperty("spec") TaskSpec spec) { + super(spec, NullNode.instance); + } + + @Override + public TaskStateType stateType() { + return TaskStateType.PENDING; + } +} diff --git a/trogdor/src/main/java/org/apache/kafka/trogdor/rest/TaskRequest.java b/trogdor/src/main/java/org/apache/kafka/trogdor/rest/TaskRequest.java new file mode 100644 index 0000000..e42738f --- /dev/null +++ b/trogdor/src/main/java/org/apache/kafka/trogdor/rest/TaskRequest.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.trogdor.rest; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +/** + * The request to /coordinator/tasks/{taskId} + */ +public class TaskRequest { + private final String taskId; + + @JsonCreator + public TaskRequest(@JsonProperty("taskId") String taskId) { + this.taskId = taskId == null ? "" : taskId; + } + + @JsonProperty + public String taskId() { + return taskId; + } +} diff --git a/trogdor/src/main/java/org/apache/kafka/trogdor/rest/TaskRunning.java b/trogdor/src/main/java/org/apache/kafka/trogdor/rest/TaskRunning.java new file mode 100644 index 0000000..8487bc3 --- /dev/null +++ b/trogdor/src/main/java/org/apache/kafka/trogdor/rest/TaskRunning.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.trogdor.rest; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.databind.JsonNode; +import org.apache.kafka.trogdor.task.TaskSpec; + +/** + * The state for a task which is being run by the agent. + */ +public class TaskRunning extends TaskState { + /** + * The time on the agent when the task was started. + */ + private final long startedMs; + + @JsonCreator + public TaskRunning(@JsonProperty("spec") TaskSpec spec, + @JsonProperty("startedMs") long startedMs, + @JsonProperty("status") JsonNode status) { + super(spec, status); + this.startedMs = startedMs; + } + + @JsonProperty + public long startedMs() { + return startedMs; + } + + @Override + public TaskStateType stateType() { + return TaskStateType.RUNNING; + } +} diff --git a/trogdor/src/main/java/org/apache/kafka/trogdor/rest/TaskState.java b/trogdor/src/main/java/org/apache/kafka/trogdor/rest/TaskState.java new file mode 100644 index 0000000..b47836e --- /dev/null +++ b/trogdor/src/main/java/org/apache/kafka/trogdor/rest/TaskState.java @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.trogdor.rest; + +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.annotation.JsonSubTypes; +import com.fasterxml.jackson.annotation.JsonTypeInfo; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.node.NullNode; +import org.apache.kafka.trogdor.task.TaskSpec; + +/** + * The state which a task is in on the Coordinator. + */ +@JsonTypeInfo(use = JsonTypeInfo.Id.NAME, + include = JsonTypeInfo.As.PROPERTY, + property = "state") +@JsonSubTypes({ + @JsonSubTypes.Type(value = TaskPending.class, name = TaskStateType.Constants.PENDING_VALUE), + @JsonSubTypes.Type(value = TaskRunning.class, name = TaskStateType.Constants.RUNNING_VALUE), + @JsonSubTypes.Type(value = TaskStopping.class, name = TaskStateType.Constants.STOPPING_VALUE), + @JsonSubTypes.Type(value = TaskDone.class, name = TaskStateType.Constants.DONE_VALUE) + }) +public abstract class TaskState extends Message { + private final TaskSpec spec; + + private final JsonNode status; + + public TaskState(TaskSpec spec, JsonNode status) { + this.spec = spec; + this.status = status == null ? 
NullNode.instance : status; + } + + @JsonProperty + public TaskSpec spec() { + return spec; + } + + @JsonProperty + public JsonNode status() { + return status; + } + + public abstract TaskStateType stateType(); +} diff --git a/trogdor/src/main/java/org/apache/kafka/trogdor/rest/TaskStateType.java b/trogdor/src/main/java/org/apache/kafka/trogdor/rest/TaskStateType.java new file mode 100644 index 0000000..c8ade06 --- /dev/null +++ b/trogdor/src/main/java/org/apache/kafka/trogdor/rest/TaskStateType.java @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.trogdor.rest; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +/** + * The types of states a single Task can be in + */ +public enum TaskStateType { + PENDING(Constants.PENDING_VALUE), + RUNNING(Constants.RUNNING_VALUE), + STOPPING(Constants.STOPPING_VALUE), + DONE(Constants.DONE_VALUE); + + TaskStateType(String stateType) {} + + public static class Constants { + static final String PENDING_VALUE = "PENDING"; + static final String RUNNING_VALUE = "RUNNING"; + static final String STOPPING_VALUE = "STOPPING"; + static final String DONE_VALUE = "DONE"; + public static final List VALUES = Collections.unmodifiableList( + Arrays.asList(PENDING_VALUE, RUNNING_VALUE, STOPPING_VALUE, DONE_VALUE)); + } +} diff --git a/trogdor/src/main/java/org/apache/kafka/trogdor/rest/TaskStopping.java b/trogdor/src/main/java/org/apache/kafka/trogdor/rest/TaskStopping.java new file mode 100644 index 0000000..2b2c4c4 --- /dev/null +++ b/trogdor/src/main/java/org/apache/kafka/trogdor/rest/TaskStopping.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.trogdor.rest; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.databind.JsonNode; +import org.apache.kafka.trogdor.task.TaskSpec; + +/** + * The state for a task which is being stopped on the coordinator. 
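// Because TaskState carries @JsonTypeInfo, the concrete subtype travels inside the JSON itself as a
// "state" field whose value is one of the TaskStateType constants. A rough sketch with a plain
// ObjectMapper (the null TaskSpec is purely for illustration, since TaskSpec is defined elsewhere):

import com.fasterxml.jackson.databind.ObjectMapper;
import org.apache.kafka.trogdor.rest.TaskPending;
import org.apache.kafka.trogdor.rest.TaskState;

public class TaskStateJsonSketch {
    public static void main(String[] args) throws Exception {
        ObjectMapper mapper = new ObjectMapper();
        TaskState pending = new TaskPending(null);
        // Prints approximately: {"state":"PENDING","spec":null,"status":null}
        System.out.println(mapper.writeValueAsString(pending));
        // Reading it back selects the subclass from the "state" discriminator.
        TaskState copy = mapper.readValue("{\"state\":\"PENDING\",\"spec\":null}", TaskState.class);
        System.out.println(copy.stateType()); // PENDING
    }
}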
+ */ +public class TaskStopping extends TaskState { + /** + * The time on the agent when the task was received. + */ + private final long startedMs; + + @JsonCreator + public TaskStopping(@JsonProperty("spec") TaskSpec spec, + @JsonProperty("startedMs") long startedMs, + @JsonProperty("status") JsonNode status) { + super(spec, status); + this.startedMs = startedMs; + } + + @JsonProperty + public long startedMs() { + return startedMs; + } + + @Override + public TaskStateType stateType() { + return TaskStateType.STOPPING; + } +} diff --git a/trogdor/src/main/java/org/apache/kafka/trogdor/rest/TasksRequest.java b/trogdor/src/main/java/org/apache/kafka/trogdor/rest/TasksRequest.java new file mode 100644 index 0000000..150a362 --- /dev/null +++ b/trogdor/src/main/java/org/apache/kafka/trogdor/rest/TasksRequest.java @@ -0,0 +1,142 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.trogdor.rest; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.util.Collection; +import java.util.Collections; +import java.util.HashSet; +import java.util.Optional; +import java.util.Set; + +/** + * The request to /coordinator/tasks + */ +public class TasksRequest extends Message { + /** + * The task IDs to list. + * An empty set of task IDs indicates that we should list all task IDs. + */ + private final Set<String> taskIds; + + /** + * If this is non-zero, only tasks with a startMs at or after this time will be listed. + */ + private final long firstStartMs; + + /** + * If this is non-zero, only tasks with a startMs at or before this time will be listed. + */ + private final long lastStartMs; + + /** + * If this is non-zero, only tasks with an endMs at or after this time will be listed. + */ + private final long firstEndMs; + + /** + * If this is non-zero, only tasks with an endMs at or before this time will be listed. + */ + private final long lastEndMs; + + /** + * The desired state of the tasks. + * An empty string will match all states. + */ + private final Optional<TaskStateType> state; + + @JsonCreator + public TasksRequest(@JsonProperty("taskIds") Collection<String> taskIds, + @JsonProperty("firstStartMs") long firstStartMs, + @JsonProperty("lastStartMs") long lastStartMs, + @JsonProperty("firstEndMs") long firstEndMs, + @JsonProperty("lastEndMs") long lastEndMs, + @JsonProperty("state") Optional<TaskStateType> state) { + this.taskIds = Collections.unmodifiableSet((taskIds == null) ? + new HashSet<String>() : new HashSet<>(taskIds)); + this.firstStartMs = Math.max(0, firstStartMs); + this.lastStartMs = Math.max(0, lastStartMs); + this.firstEndMs = Math.max(0, firstEndMs); + this.lastEndMs = Math.max(0, lastEndMs); + this.state = state == null ?
Optional.empty() : state; + } + + @JsonProperty + public Collection<String> taskIds() { + return taskIds; + } + + @JsonProperty + public long firstStartMs() { + return firstStartMs; + } + + @JsonProperty + public long lastStartMs() { + return lastStartMs; + } + + @JsonProperty + public long firstEndMs() { + return firstEndMs; + } + + @JsonProperty + public long lastEndMs() { + return lastEndMs; + } + + @JsonProperty + public Optional<TaskStateType> state() { + return state; + } + + /** + * Determine if this TasksRequest should return a particular task. + * + * @param taskId The task ID. + * @param startMs The task start time, or -1 if the task hasn't started. + * @param endMs The task end time, or -1 if the task hasn't ended. + * @return True if information about the task should be returned. + */ + public boolean matches(String taskId, long startMs, long endMs, TaskStateType state) { + if ((!taskIds.isEmpty()) && (!taskIds.contains(taskId))) { + return false; + } + if ((firstStartMs > 0) && (startMs < firstStartMs)) { + return false; + } + if ((lastStartMs > 0) && ((startMs < 0) || (startMs > lastStartMs))) { + return false; + } + if ((firstEndMs > 0) && (endMs < firstEndMs)) { + return false; + } + if ((lastEndMs > 0) && ((endMs < 0) || (endMs > lastEndMs))) { + return false; + } + + if (this.state.isPresent() && !this.state.get().equals(state)) { + return false; + } + + return true; + } +} diff --git a/trogdor/src/main/java/org/apache/kafka/trogdor/rest/TasksResponse.java b/trogdor/src/main/java/org/apache/kafka/trogdor/rest/TasksResponse.java new file mode 100644 index 0000000..5a3149c --- /dev/null +++ b/trogdor/src/main/java/org/apache/kafka/trogdor/rest/TasksResponse.java @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.trogdor.rest; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.util.Collections; +import java.util.Map; +import java.util.TreeMap; + +/** + * The response to /coordinator/tasks + */ +public class TasksResponse extends Message { + private final Map<String, TaskState> tasks; + + @JsonCreator + public TasksResponse(@JsonProperty("tasks") TreeMap<String, TaskState> tasks) { + this.tasks = Collections.unmodifiableMap((tasks == null) ? + new TreeMap<String, TaskState>() : tasks); + } + + @JsonProperty + public Map<String, TaskState> tasks() { + return tasks; + } +} diff --git a/trogdor/src/main/java/org/apache/kafka/trogdor/rest/UptimeResponse.java b/trogdor/src/main/java/org/apache/kafka/trogdor/rest/UptimeResponse.java new file mode 100644 index 0000000..51393b1 --- /dev/null +++ b/trogdor/src/main/java/org/apache/kafka/trogdor/rest/UptimeResponse.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements.
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.trogdor.rest; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +/** + * A response from the Trogdor Agent/Coordinator about how long it has been running + */ +public class UptimeResponse extends Message { + + private long serverStartMs; + private long nowMs; + + @JsonCreator + public UptimeResponse(@JsonProperty("serverStartMs") long serverStartMs, + @JsonProperty("nowMs") long nowMs) { + this.serverStartMs = serverStartMs; + this.nowMs = nowMs; + } + + @JsonProperty + public long serverStartMs() { + return serverStartMs; + } + + @JsonProperty + public long nowMs() { + return nowMs; + } +} diff --git a/trogdor/src/main/java/org/apache/kafka/trogdor/rest/WorkerDone.java b/trogdor/src/main/java/org/apache/kafka/trogdor/rest/WorkerDone.java new file mode 100644 index 0000000..5f773bb --- /dev/null +++ b/trogdor/src/main/java/org/apache/kafka/trogdor/rest/WorkerDone.java @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.trogdor.rest; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.node.NullNode; +import org.apache.kafka.trogdor.task.TaskSpec; + +/** + * The state a worker is in once it's done. + */ +public class WorkerDone extends WorkerState { + /** + * The time on the agent when the task was started. + */ + private final long startedMs; + + /** + * The time on the agent when the task was completed. + */ + private final long doneMs; + + /** + * The task status. The format will depend on the type of task that is + * being run. + */ + private final JsonNode status; + + /** + * Empty if the task completed without error; the error message otherwise. 
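// TasksRequest.matches() (shown earlier) applies its filters conjunctively: an empty taskIds set
// matches every task, each non-zero time bound constrains startMs/endMs, and an absent state
// matches every state. For example:

import java.util.Collections;
import java.util.Optional;

import org.apache.kafka.trogdor.rest.TaskStateType;
import org.apache.kafka.trogdor.rest.TasksRequest;

public class TasksRequestFilterSketch {
    public static void main(String[] args) {
        // All tasks that started at or after t=1000 and are currently RUNNING.
        TasksRequest req = new TasksRequest(Collections.emptySet(),
            1000, 0, 0, 0, Optional.of(TaskStateType.RUNNING));
        System.out.println(req.matches("foo", 1500, -1, TaskStateType.RUNNING)); // true
        System.out.println(req.matches("foo", 500, -1, TaskStateType.RUNNING));  // false: started too early
        System.out.println(req.matches("foo", 1500, -1, TaskStateType.DONE));    // false: wrong state
    }
}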
+ */ + private final String error; + + @JsonCreator + public WorkerDone(@JsonProperty("taskId") String taskId, + @JsonProperty("spec") TaskSpec spec, + @JsonProperty("startedMs") long startedMs, + @JsonProperty("doneMs") long doneMs, + @JsonProperty("status") JsonNode status, + @JsonProperty("error") String error) { + super(taskId, spec); + this.startedMs = startedMs; + this.doneMs = doneMs; + this.status = status == null ? NullNode.instance : status; + this.error = error == null ? "" : error; + } + + @JsonProperty + @Override + public long startedMs() { + return startedMs; + } + + @JsonProperty + public long doneMs() { + return doneMs; + } + + @JsonProperty + @Override + public JsonNode status() { + return status; + } + + @JsonProperty + public String error() { + return error; + } + + @Override + public boolean done() { + return true; + } +} diff --git a/trogdor/src/main/java/org/apache/kafka/trogdor/rest/WorkerReceiving.java b/trogdor/src/main/java/org/apache/kafka/trogdor/rest/WorkerReceiving.java new file mode 100644 index 0000000..1babcce --- /dev/null +++ b/trogdor/src/main/java/org/apache/kafka/trogdor/rest/WorkerReceiving.java @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.trogdor.rest; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.node.TextNode; +import org.apache.kafka.trogdor.task.TaskSpec; + +/** + * When we're in the process of sending a TaskSpec to the Agent, the Worker is regarded + * as being in WorkerReceiving state. + */ +public final class WorkerReceiving extends WorkerState { + @JsonCreator + public WorkerReceiving(@JsonProperty("taskId") String taskId, + @JsonProperty("spec") TaskSpec spec) { + super(taskId, spec); + } + + @Override + public JsonNode status() { + return new TextNode("receiving"); + } +} diff --git a/trogdor/src/main/java/org/apache/kafka/trogdor/rest/WorkerRunning.java b/trogdor/src/main/java/org/apache/kafka/trogdor/rest/WorkerRunning.java new file mode 100644 index 0000000..15e7752 --- /dev/null +++ b/trogdor/src/main/java/org/apache/kafka/trogdor/rest/WorkerRunning.java @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.trogdor.rest; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.node.NullNode; +import org.apache.kafka.trogdor.task.TaskSpec; + +/** + * The state for a task which is being run by the agent. + */ +public class WorkerRunning extends WorkerState { + /** + * The time on the agent when the task was started. + */ + private final long startedMs; + + /** + * The task status. The format will depend on the type of task that is + * being run. + */ + private final JsonNode status; + + @JsonCreator + public WorkerRunning(@JsonProperty("taskId") String taskId, + @JsonProperty("spec") TaskSpec spec, + @JsonProperty("startedMs") long startedMs, + @JsonProperty("status") JsonNode status) { + super(taskId, spec); + this.startedMs = startedMs; + this.status = status == null ? NullNode.instance : status; + } + + @JsonProperty + @Override + public long startedMs() { + return startedMs; + } + + @JsonProperty + @Override + public JsonNode status() { + return status; + } + + @Override + public boolean running() { + return true; + } +} diff --git a/trogdor/src/main/java/org/apache/kafka/trogdor/rest/WorkerStarting.java b/trogdor/src/main/java/org/apache/kafka/trogdor/rest/WorkerStarting.java new file mode 100644 index 0000000..7a06eac --- /dev/null +++ b/trogdor/src/main/java/org/apache/kafka/trogdor/rest/WorkerStarting.java @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.trogdor.rest; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.node.TextNode; +import org.apache.kafka.trogdor.task.TaskSpec; + +/** + * When we have just started a worker. 
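+ *
+ * As a purely illustrative example, a worker in this state would serialize with the
+ * "STARTING" type name that WorkerState registers for this class, for instance:
+ *
+ *   { "state": "STARTING", "taskId": "some-task", "spec": { ... } }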
+ */ +public final class WorkerStarting extends WorkerState { + @JsonCreator + public WorkerStarting(@JsonProperty("taskId") String taskId, + @JsonProperty("spec") TaskSpec spec) { + super(taskId, spec); + } + + @Override + public JsonNode status() { + return new TextNode("starting"); + } +} diff --git a/trogdor/src/main/java/org/apache/kafka/trogdor/rest/WorkerState.java b/trogdor/src/main/java/org/apache/kafka/trogdor/rest/WorkerState.java new file mode 100644 index 0000000..6480a24 --- /dev/null +++ b/trogdor/src/main/java/org/apache/kafka/trogdor/rest/WorkerState.java @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.trogdor.rest; + +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.annotation.JsonSubTypes; +import com.fasterxml.jackson.annotation.JsonTypeInfo; +import com.fasterxml.jackson.databind.JsonNode; +import org.apache.kafka.common.KafkaException; +import org.apache.kafka.trogdor.task.TaskSpec; + +/** + * The state which a worker is in on the Agent. + */ +@JsonTypeInfo(use = JsonTypeInfo.Id.NAME, + include = JsonTypeInfo.As.PROPERTY, + property = "state") +@JsonSubTypes(value = { + @JsonSubTypes.Type(value = WorkerReceiving.class, name = "RECEIVING"), + @JsonSubTypes.Type(value = WorkerStarting.class, name = "STARTING"), + @JsonSubTypes.Type(value = WorkerRunning.class, name = "RUNNING"), + @JsonSubTypes.Type(value = WorkerStopping.class, name = "STOPPING"), + @JsonSubTypes.Type(value = WorkerDone.class, name = "DONE") + }) +public abstract class WorkerState extends Message { + private final String taskId; + private final TaskSpec spec; + + public WorkerState(String taskId, TaskSpec spec) { + this.taskId = taskId; + this.spec = spec; + } + + @JsonProperty + public String taskId() { + return taskId; + } + + @JsonProperty + public TaskSpec spec() { + return spec; + } + + public boolean stopping() { + return false; + } + + public boolean done() { + return false; + } + + public long startedMs() { + throw new KafkaException("invalid state"); + } + + public abstract JsonNode status(); + + public boolean running() { + return false; + } +} diff --git a/trogdor/src/main/java/org/apache/kafka/trogdor/rest/WorkerStopping.java b/trogdor/src/main/java/org/apache/kafka/trogdor/rest/WorkerStopping.java new file mode 100644 index 0000000..2942e11 --- /dev/null +++ b/trogdor/src/main/java/org/apache/kafka/trogdor/rest/WorkerStopping.java @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.trogdor.rest; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.node.NullNode; +import org.apache.kafka.trogdor.task.TaskSpec; + +/** + * The state for a worker which is being stopped on the agent. + */ +public class WorkerStopping extends WorkerState { + /** + * The time on the agent when the task was received. + */ + private final long startedMs; + + /** + * The task status. The format will depend on the type of task that is + * being run. + */ + private final JsonNode status; + + @JsonCreator + public WorkerStopping(@JsonProperty("taskId") String taskId, + @JsonProperty("spec") TaskSpec spec, + @JsonProperty("startedMs") long startedMs, + @JsonProperty("status") JsonNode status) { + super(taskId, spec); + this.startedMs = startedMs; + this.status = status == null ? NullNode.instance : status; + } + + @JsonProperty + @Override + public long startedMs() { + return startedMs; + } + + @JsonProperty + @Override + public JsonNode status() { + return status; + } + + @Override + public boolean stopping() { + return true; + } + + @Override + public boolean running() { + return true; + } +} diff --git a/trogdor/src/main/java/org/apache/kafka/trogdor/task/AgentWorkerStatusTracker.java b/trogdor/src/main/java/org/apache/kafka/trogdor/task/AgentWorkerStatusTracker.java new file mode 100644 index 0000000..2ad8e4e --- /dev/null +++ b/trogdor/src/main/java/org/apache/kafka/trogdor/task/AgentWorkerStatusTracker.java @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.trogdor.task; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.node.NullNode; + +/** + * Tracks the status of a Trogdor worker. + */ +public class AgentWorkerStatusTracker implements WorkerStatusTracker { + private JsonNode status = NullNode.instance; + + @Override + public void update(JsonNode newStatus) { + JsonNode status = newStatus.deepCopy(); + synchronized (this) { + this.status = status; + } + } + + /** + * Retrieves the status. 
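+     *
+     * Illustrative usage (hypothetical caller code): one thread publishes a status node and
+     * another later reads it back, e.g.
+     *
+     *   tracker.update(new TextNode("active"));
+     *   JsonNode current = tracker.get();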
+     */
+    public synchronized JsonNode get() {
+        return status;
+    }
+}
diff --git a/trogdor/src/main/java/org/apache/kafka/trogdor/task/NoOpTaskController.java b/trogdor/src/main/java/org/apache/kafka/trogdor/task/NoOpTaskController.java
new file mode 100644
index 0000000..b5906c3
--- /dev/null
+++ b/trogdor/src/main/java/org/apache/kafka/trogdor/task/NoOpTaskController.java
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kafka.trogdor.task;
+
+import org.apache.kafka.trogdor.common.Topology;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.Set;
+
+public class NoOpTaskController implements TaskController {
+    private static final Logger log = LoggerFactory.getLogger(NoOpTaskController.class);
+
+    public NoOpTaskController() {
+    }
+
+    @Override
+    public Set<String> targetNodes(Topology topology) {
+        return Topology.Util.agentNodeNames(topology);
+    }
+}
diff --git a/trogdor/src/main/java/org/apache/kafka/trogdor/task/NoOpTaskSpec.java b/trogdor/src/main/java/org/apache/kafka/trogdor/task/NoOpTaskSpec.java
new file mode 100644
index 0000000..63e6023
--- /dev/null
+++ b/trogdor/src/main/java/org/apache/kafka/trogdor/task/NoOpTaskSpec.java
@@ -0,0 +1,44 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kafka.trogdor.task;
+
+import com.fasterxml.jackson.annotation.JsonCreator;
+import com.fasterxml.jackson.annotation.JsonProperty;
+
+/**
+ * The specification for a task that does nothing.
+ *
+ * This task type exists to test trogdor itself.
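+ *
+ * A hypothetical spec document for this task (illustrative values only; the "class" property
+ * comes from the type information declared on TaskSpec) might look like:
+ *
+ *   {
+ *     "class": "org.apache.kafka.trogdor.task.NoOpTaskSpec",
+ *     "startMs": 0,
+ *     "durationMs": 10000
+ *   }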
+ */ +public class NoOpTaskSpec extends TaskSpec { + @JsonCreator + public NoOpTaskSpec(@JsonProperty("startMs") long startMs, + @JsonProperty("durationMs") long durationMs) { + super(startMs, durationMs); + } + + @Override + public TaskController newController(String id) { + return new NoOpTaskController(); + } + + @Override + public TaskWorker newTaskWorker(String id) { + return new NoOpTaskWorker(id); + } +} diff --git a/trogdor/src/main/java/org/apache/kafka/trogdor/task/NoOpTaskWorker.java b/trogdor/src/main/java/org/apache/kafka/trogdor/task/NoOpTaskWorker.java new file mode 100644 index 0000000..77336d8 --- /dev/null +++ b/trogdor/src/main/java/org/apache/kafka/trogdor/task/NoOpTaskWorker.java @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.trogdor.task; + +import com.fasterxml.jackson.databind.node.TextNode; +import org.apache.kafka.common.internals.KafkaFutureImpl; +import org.apache.kafka.trogdor.common.Platform; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class NoOpTaskWorker implements TaskWorker { + private static final Logger log = LoggerFactory.getLogger(NoOpTaskWorker.class); + + private final String id; + + private WorkerStatusTracker status; + + public NoOpTaskWorker(String id) { + this.id = id; + } + + @Override + public void start(Platform platform, WorkerStatusTracker status, + KafkaFutureImpl errorFuture) throws Exception { + log.info("{}: Activating NoOpTask.", id); + this.status = status; + this.status.update(new TextNode("active")); + } + + @Override + public void stop(Platform platform) throws Exception { + log.info("{}: Deactivating NoOpTask.", id); + this.status.update(new TextNode("done")); + } +} diff --git a/trogdor/src/main/java/org/apache/kafka/trogdor/task/TaskController.java b/trogdor/src/main/java/org/apache/kafka/trogdor/task/TaskController.java new file mode 100644 index 0000000..dbd0b09 --- /dev/null +++ b/trogdor/src/main/java/org/apache/kafka/trogdor/task/TaskController.java @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.trogdor.task; + +import org.apache.kafka.trogdor.common.Topology; + +import java.util.Set; + +/** + * Controls a Trogdor task. + */ +public interface TaskController { + /** + * Get the agent nodes which this task is targetting. + * + * @param topology The topology to use. + * + * @return A set of target node names. + */ + Set targetNodes(Topology topology); +} diff --git a/trogdor/src/main/java/org/apache/kafka/trogdor/task/TaskSpec.java b/trogdor/src/main/java/org/apache/kafka/trogdor/task/TaskSpec.java new file mode 100644 index 0000000..acb19f6 --- /dev/null +++ b/trogdor/src/main/java/org/apache/kafka/trogdor/task/TaskSpec.java @@ -0,0 +1,118 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.trogdor.task; + +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.annotation.JsonTypeInfo; +import org.apache.kafka.trogdor.common.JsonUtil; + +import java.util.Collections; +import java.util.Map; +import java.util.Objects; + + +/** + * The specification for a task. This should be immutable and suitable for serializing and sending over the wire. + */ +@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, + include = JsonTypeInfo.As.PROPERTY, + property = "class") +public abstract class TaskSpec { + /** + * The maximum task duration. + * + * We cap the task duration at this value to avoid worrying about 64-bit overflow or floating + * point rounding. (Objects serialized as JSON canonically contain only floating point numbers, + * because JavaScript did not support integers.) + */ + public final static long MAX_TASK_DURATION_MS = 1000000000000000L; + + /** + * When the time should start in milliseconds. + */ + private final long startMs; + + /** + * How long the task should run in milliseconds. + */ + private final long durationMs; + + protected TaskSpec(@JsonProperty("startMs") long startMs, + @JsonProperty("durationMs") long durationMs) { + this.startMs = startMs; + this.durationMs = Math.max(0, Math.min(durationMs, MAX_TASK_DURATION_MS)); + } + + /** + * Get the target start time of this task in ms. + */ + @JsonProperty + public final long startMs() { + return startMs; + } + + /** + * Get the deadline time of this task in ms + */ + public final long endMs() { + return startMs + durationMs; + } + + /** + * Get the duration of this task in ms. + */ + @JsonProperty + public final long durationMs() { + return durationMs; + } + + /** + * Hydrate this task on the coordinator. + * + * @param id The task id. + */ + public abstract TaskController newController(String id); + + /** + * Hydrate this task on the agent. + * + * @param id The worker id. 
+ */ + public abstract TaskWorker newTaskWorker(String id); + + @Override + public final boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + return toString().equals(o.toString()); + } + + @Override + public final int hashCode() { + return Objects.hashCode(toString()); + } + + @Override + public String toString() { + return JsonUtil.toJsonString(this); + } + + protected Map configOrEmptyMap(Map config) { + return (config == null) ? Collections.emptyMap() : config; + } +} diff --git a/trogdor/src/main/java/org/apache/kafka/trogdor/task/TaskWorker.java b/trogdor/src/main/java/org/apache/kafka/trogdor/task/TaskWorker.java new file mode 100644 index 0000000..042568f --- /dev/null +++ b/trogdor/src/main/java/org/apache/kafka/trogdor/task/TaskWorker.java @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.trogdor.task; + +import org.apache.kafka.common.internals.KafkaFutureImpl; +import org.apache.kafka.trogdor.common.Platform; + +/** + * The agent-side interface for implementing tasks. + */ +public interface TaskWorker { + /** + * Starts the TaskWorker. + * + * We do not hold any locks or block the WorkerManager state machine on this call. + * However, createTask requests to the agent call this function directly. + * Therefore, your start() implementation may take a little while, but not too long. + * While you can perform short blocking tasks in this function, it is better to + * start a background thread to do something time-consuming. + * + * If the start() function throws an exception, the Agent will assume that the TaskWorker + * never started. Therefore, stop() will never be invoked. On the other hand, if the + * errorFuture is completed, either by a background task or by the start function itself, + * the Agent will invoke the stop() method to clean up the worker. + * + * + * @param platform The platform to use. + * @param status The current status. The TaskWorker can update + * this at any time to provide an updated status. + * @param haltFuture A future which the worker should complete if it halts. + * If it is completed with an empty string, that means the task + * halted with no error. Otherwise, the string is treated as the error. + * If you start a background thread, you may pass haltFuture + * to that thread. Then, the thread can use this future to indicate + * that the worker should be stopped. + * + * @throws Exception If the TaskWorker failed to start. stop() will not be invoked. + */ + void start(Platform platform, WorkerStatusTracker status, KafkaFutureImpl haltFuture) + throws Exception; + + /** + * Stops the TaskWorker. 
+ * + * A TaskWorker may be stopped because it has run for its assigned duration, or because a + * request arrived instructing the Agent to stop the worker. The TaskWorker will + * also be stopped if errorFuture was completed to indicate that there was an error. + * + * Regardless of why the TaskWorker was stopped, the stop() function should release all + * resources and stop all threads before returning. The stop() function can block for + * as long as it wants. It is run in a background thread which will not block other + * agent operations. All tasks will be stopped when the Agent cleanly shuts down. + * + * @param platform The platform to use. + * + * @throws Exception If there was an error cleaning up the TaskWorker. + * If there is no existing TaskWorker error, the worker will be + * treated as having failed with the given error. + */ + void stop(Platform platform) throws Exception; +} diff --git a/trogdor/src/main/java/org/apache/kafka/trogdor/task/WorkerStatusTracker.java b/trogdor/src/main/java/org/apache/kafka/trogdor/task/WorkerStatusTracker.java new file mode 100644 index 0000000..dfbc7ea --- /dev/null +++ b/trogdor/src/main/java/org/apache/kafka/trogdor/task/WorkerStatusTracker.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.trogdor.task; + +import com.fasterxml.jackson.databind.JsonNode; + +/** + * Tracks the status of a Trogdor worker. + */ +public interface WorkerStatusTracker { + /** + * Updates the status. + * + * @param status The new status. + */ + void update(JsonNode status); +} diff --git a/trogdor/src/main/java/org/apache/kafka/trogdor/workload/ConfigurableProducerSpec.java b/trogdor/src/main/java/org/apache/kafka/trogdor/workload/ConfigurableProducerSpec.java new file mode 100644 index 0000000..5235fc3 --- /dev/null +++ b/trogdor/src/main/java/org/apache/kafka/trogdor/workload/ConfigurableProducerSpec.java @@ -0,0 +1,214 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.trogdor.workload; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import org.apache.kafka.trogdor.task.TaskController; +import org.apache.kafka.trogdor.task.TaskSpec; +import org.apache.kafka.trogdor.task.TaskWorker; +import java.util.Collections; +import java.util.Map; +import java.util.Optional; + +/** + * This is the spec to pass in to be able to run the `ConfigurableProducerWorker` workload. This allows for customized + * and even variable configurations in terms of messages per second, message size, batch size, key size, and even the + * ability to target a specific partition out of a topic. + * + * This has several notable differences from the ProduceBench classes, namely the ability to dynamically control + * flushing and throughput through configurable classes, but also the ability to run against specific partitions within + * a topic directly. This workload can only run against one topic at a time, unlike the ProduceBench workload. + * + * The parameters that differ from ProduceBenchSpec: + * + * `flushGenerator` - Used to instruct the KafkaProducer when to issue flushes. This allows us to simulate + * variable batching since batch flushing is not currently exposed within the KafkaProducer + * class. See the `FlushGenerator` interface for more information. + * + * `throughputGenerator` - Used to throttle the ConfigurableProducerWorker based on a calculated number of messages + * within a window. See the `ThroughputGenerator` interface for more information. + * + * `activeTopic` - This class only supports execution against a single topic at a time. If more than one + * topic is specified, the ConfigurableProducerWorker will throw an error. + * + * `activePartition` - Specify a specific partition number within the activeTopic to run load against, or + * specify `-1` to allow use of all partitions. + * + * Here is an example spec: + * + * { + * "startMs": 1606949497662, + * "durationMs": 3600000, + * "producerNode": "trogdor-agent-0", + * "bootstrapServers": "some.example.kafka.server:9091", + * "flushGenerator": { + * "type": "gaussian", + * "messagesPerFlushAverage": 16, + * "messagesPerFlushDeviation": 4 + * }, + * "throughputGenerator": { + * "type": "gaussian", + * "messagesPerSecondAverage": 500, + * "messagesPerSecondDeviation": 50, + * "windowsUntilRateChange": 100, + * "windowSizeMs": 100 + * }, + * "keyGenerator": { + * "type": "constant", + * "size": 8 + * }, + * "valueGenerator": { + * "type": "gaussianTimestampRandom", + * "messageSizeAverage": 512, + * "messageSizeDeviation": 100, + * "messagesUntilSizeChange": 100 + * }, + * "producerConf": { + * "acks": "all" + * }, + * "commonClientConf": {}, + * "adminClientConf": {}, + * "activeTopic": { + * "topic0": { + * "numPartitions": 100, + * "replicationFactor": 3, + * "configs": { + * "retention.ms": "1800000" + * } + * } + * }, + * "activePartition": 5 + * } + * + * This example spec performed the following: + * + * * Ran on `trogdor-agent-0` for 1 hour starting at 2020-12-02 22:51:37.662 GMT + * * Produced with acks=all to Partition 5 of `topic0` on kafka server `some.example.kafka.server:9091`. + * * The average batch had 16 messages, with a standard deviation of 4 messages. + * * The messages had 8-byte constant keys with an average size of 512 bytes and a standard deviation of 100 bytes. + * * The messages had millisecond timestamps embedded in the first several bytes of the value. 
+ * * The average throughput was 500 messages/second, with a window of 100ms and a deviation of 50 messages/second. + */ + +public class ConfigurableProducerSpec extends TaskSpec { + private final String producerNode; + private final String bootstrapServers; + private final Optional flushGenerator; + private final ThroughputGenerator throughputGenerator; + private final PayloadGenerator keyGenerator; + private final PayloadGenerator valueGenerator; + private final Map producerConf; + private final Map adminClientConf; + private final Map commonClientConf; + private final TopicsSpec activeTopic; + private final int activePartition; + + @JsonCreator + public ConfigurableProducerSpec(@JsonProperty("startMs") long startMs, + @JsonProperty("durationMs") long durationMs, + @JsonProperty("producerNode") String producerNode, + @JsonProperty("bootstrapServers") String bootstrapServers, + @JsonProperty("flushGenerator") Optional flushGenerator, + @JsonProperty("throughputGenerator") ThroughputGenerator throughputGenerator, + @JsonProperty("keyGenerator") PayloadGenerator keyGenerator, + @JsonProperty("valueGenerator") PayloadGenerator valueGenerator, + @JsonProperty("producerConf") Map producerConf, + @JsonProperty("commonClientConf") Map commonClientConf, + @JsonProperty("adminClientConf") Map adminClientConf, + @JsonProperty("activeTopic") TopicsSpec activeTopic, + @JsonProperty("activePartition") int activePartition) { + super(startMs, durationMs); + this.producerNode = (producerNode == null) ? "" : producerNode; + this.bootstrapServers = (bootstrapServers == null) ? "" : bootstrapServers; + this.flushGenerator = flushGenerator; + this.keyGenerator = keyGenerator; + this.valueGenerator = valueGenerator; + this.throughputGenerator = throughputGenerator; + this.producerConf = configOrEmptyMap(producerConf); + this.commonClientConf = configOrEmptyMap(commonClientConf); + this.adminClientConf = configOrEmptyMap(adminClientConf); + this.activeTopic = activeTopic.immutableCopy(); + this.activePartition = activePartition; + } + + @JsonProperty + public String producerNode() { + return producerNode; + } + + @JsonProperty + public String bootstrapServers() { + return bootstrapServers; + } + + @JsonProperty + public Optional flushGenerator() { + return flushGenerator; + } + + @JsonProperty + public PayloadGenerator keyGenerator() { + return keyGenerator; + } + + @JsonProperty + public PayloadGenerator valueGenerator() { + return valueGenerator; + } + + @JsonProperty + public ThroughputGenerator throughputGenerator() { + return throughputGenerator; + } + + @JsonProperty + public Map producerConf() { + return producerConf; + } + + @JsonProperty + public Map commonClientConf() { + return commonClientConf; + } + + @JsonProperty + public Map adminClientConf() { + return adminClientConf; + } + + @JsonProperty + public TopicsSpec activeTopic() { + return activeTopic; + } + + @JsonProperty + public int activePartition() { + return activePartition; + } + + @Override + public TaskController newController(String id) { + return topology -> Collections.singleton(producerNode); + } + + @Override + public TaskWorker newTaskWorker(String id) { + return new ConfigurableProducerWorker(id, this); + } +} diff --git a/trogdor/src/main/java/org/apache/kafka/trogdor/workload/ConfigurableProducerWorker.java b/trogdor/src/main/java/org/apache/kafka/trogdor/workload/ConfigurableProducerWorker.java new file mode 100644 index 0000000..b08ef44 --- /dev/null +++ 
b/trogdor/src/main/java/org/apache/kafka/trogdor/workload/ConfigurableProducerWorker.java @@ -0,0 +1,322 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.trogdor.workload; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.databind.node.TextNode; +import org.apache.kafka.clients.admin.NewTopic; +import org.apache.kafka.clients.producer.Callback; +import org.apache.kafka.clients.producer.KafkaProducer; +import org.apache.kafka.clients.producer.ProducerConfig; +import org.apache.kafka.clients.producer.ProducerRecord; +import org.apache.kafka.clients.producer.RecordMetadata; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.internals.KafkaFutureImpl; +import org.apache.kafka.common.serialization.ByteArraySerializer; +import org.apache.kafka.common.utils.ThreadUtils; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.trogdor.common.JsonUtil; +import org.apache.kafka.trogdor.common.Platform; +import org.apache.kafka.trogdor.common.WorkerUtils; +import org.apache.kafka.trogdor.task.TaskWorker; +import org.apache.kafka.trogdor.task.WorkerStatusTracker; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Properties; +import java.util.concurrent.Callable; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; + +/** + * This workload allows for customized and even variable configurations in terms of messages per second, message size, + * batch size, key size, and even the ability to target a specific partition out of a topic. + * + * See `ConfigurableProducerSpec` for a more detailed description. 
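+ *
+ * As a rough sketch of the per-record flow implemented by SendRecords.sendMessage() below
+ * (simplified, error handling omitted):
+ *
+ *   producer.send(record, callback);                              // asynchronous send; latency is recorded in the callback
+ *   spec.flushGenerator().ifPresent(g -> g.increment(producer));  // optional explicit flush control
+ *   spec.throughputGenerator().throttle();                        // block as needed to hold the target rate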
+ */ + +public class ConfigurableProducerWorker implements TaskWorker { + private static final Logger log = LoggerFactory.getLogger(ConfigurableProducerWorker.class); + + private final String id; + + private final ConfigurableProducerSpec spec; + + private final AtomicBoolean running = new AtomicBoolean(false); + + private ScheduledExecutorService executor; + + private WorkerStatusTracker status; + + private KafkaFutureImpl doneFuture; + + public ConfigurableProducerWorker(String id, ConfigurableProducerSpec spec) { + this.id = id; + this.spec = spec; + } + + @Override + public void start(Platform platform, WorkerStatusTracker status, + KafkaFutureImpl doneFuture) { + if (!running.compareAndSet(false, true)) { + throw new IllegalStateException("ConfigurableProducerWorker is already running."); + } + log.info("{}: Activating ConfigurableProducerWorker with {}", id, spec); + // Create an executor with 2 threads. We need the second thread so + // that the StatusUpdater can run in parallel with SendRecords. + this.executor = Executors.newScheduledThreadPool(2, + ThreadUtils.createThreadFactory("ConfigurableProducerWorkerThread%d", false)); + this.status = status; + this.doneFuture = doneFuture; + executor.submit(new Prepare()); + } + + public class Prepare implements Runnable { + @Override + public void run() { + try { + Map newTopics = new HashMap<>(); + if (spec.activeTopic().materialize().size() != 1) { + throw new RuntimeException("Can only run against 1 topic."); + } + List active = new ArrayList<>(); + for (Map.Entry entry : + spec.activeTopic().materialize().entrySet()) { + String topicName = entry.getKey(); + PartitionsSpec partSpec = entry.getValue(); + newTopics.put(topicName, partSpec.newTopic(topicName)); + for (Integer partitionNumber : partSpec.partitionNumbers()) { + active.add(new TopicPartition(topicName, partitionNumber)); + } + } + status.update(new TextNode("Creating " + newTopics.keySet().size() + " topic(s)")); + WorkerUtils.createTopics(log, spec.bootstrapServers(), spec.commonClientConf(), + spec.adminClientConf(), newTopics, false); + status.update(new TextNode("Created " + newTopics.keySet().size() + " topic(s)")); + executor.submit(new SendRecords(active.get(0).topic(), spec.activePartition())); + } catch (Throwable e) { + WorkerUtils.abort(log, "Prepare", e, doneFuture); + } + } + } + + private static class SendRecordsCallback implements Callback { + private final SendRecords sendRecords; + private final long startMs; + + SendRecordsCallback(SendRecords sendRecords, long startMs) { + this.sendRecords = sendRecords; + this.startMs = startMs; + } + + @Override + public void onCompletion(RecordMetadata metadata, Exception exception) { + long now = Time.SYSTEM.milliseconds(); + long durationMs = now - startMs; + sendRecords.recordDuration(durationMs); + if (exception != null) { + log.error("SendRecordsCallback: error", exception); + } + } + } + + public class SendRecords implements Callable { + private final String activeTopic; + private final int activePartition; + + private final Histogram histogram; + + private final Future statusUpdaterFuture; + + private final KafkaProducer producer; + + private final PayloadIterator keys; + + private final PayloadIterator values; + + private Future sendFuture; + + SendRecords(String topic, int partition) { + this.activeTopic = topic; + this.activePartition = partition; + this.histogram = new Histogram(10000); + + this.statusUpdaterFuture = executor.scheduleWithFixedDelay( + new StatusUpdater(histogram), 30, 30, TimeUnit.SECONDS); 
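+            // The producer below is configured from the spec's common client and producer configs;
+            // keys and values are byte arrays supplied by the spec's key and value payload generators.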
+ + Properties props = new Properties(); + props.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, spec.bootstrapServers()); + WorkerUtils.addConfigsToProperties(props, spec.commonClientConf(), spec.producerConf()); + this.producer = new KafkaProducer<>(props, new ByteArraySerializer(), new ByteArraySerializer()); + this.keys = new PayloadIterator(spec.keyGenerator()); + this.values = new PayloadIterator(spec.valueGenerator()); + } + + @Override + public Void call() throws Exception { + long startTimeMs = Time.SYSTEM.milliseconds(); + try { + try { + long sentMessages = 0; + while (true) { + sendMessage(); + sentMessages++; + } + } catch (Exception e) { + throw e; + } finally { + if (sendFuture != null) { + try { + sendFuture.get(); + } catch (Exception e) { + log.error("Exception on final future", e); + } + } + producer.close(); + } + } catch (Exception e) { + WorkerUtils.abort(log, "SendRecords", e, doneFuture); + } finally { + statusUpdaterFuture.cancel(false); + StatusData statusData = new StatusUpdater(histogram).update(); + long curTimeMs = Time.SYSTEM.milliseconds(); + log.info("Sent {} total record(s) in {} ms. status: {}", + histogram.summarize().numSamples(), curTimeMs - startTimeMs, statusData); + } + doneFuture.complete(""); + return null; + } + + private void sendMessage() throws InterruptedException { + ProducerRecord record; + if (activePartition != -1) { + record = new ProducerRecord<>(activeTopic, activePartition, keys.next(), values.next()); + } else { + record = new ProducerRecord<>(activeTopic, keys.next(), values.next()); + } + sendFuture = producer.send(record, new SendRecordsCallback(this, Time.SYSTEM.milliseconds())); + spec.flushGenerator().ifPresent(flushGenerator -> flushGenerator.increment(producer)); + spec.throughputGenerator().throttle(); + } + + void recordDuration(long durationMs) { + histogram.add(durationMs); + } + } + + public class StatusUpdater implements Runnable { + private final Histogram histogram; + + StatusUpdater(Histogram histogram) { + this.histogram = histogram; + } + + @Override + public void run() { + try { + update(); + } catch (Exception e) { + WorkerUtils.abort(log, "StatusUpdater", e, doneFuture); + } + } + + StatusData update() { + Histogram.Summary summary = histogram.summarize(StatusData.PERCENTILES); + StatusData statusData = new StatusData(summary.numSamples(), summary.average(), + summary.percentiles().get(0).value(), + summary.percentiles().get(1).value(), + summary.percentiles().get(2).value()); + status.update(JsonUtil.JSON_SERDE.valueToTree(statusData)); + return statusData; + } + } + + public static class StatusData { + private final long totalSent; + private final float averageLatencyMs; + private final int p50LatencyMs; + private final int p95LatencyMs; + private final int p99LatencyMs; + + /** + * The percentiles to use when calculating the histogram data. + * These should match up with the p50LatencyMs, p95LatencyMs, etc. fields. 
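+     *
+     * With these percentiles, a status update emitted by StatusUpdater.update() looks roughly
+     * like the following (values are illustrative only):
+     *
+     *   { "totalSent": 50000, "averageLatencyMs": 2.5,
+     *     "p50LatencyMs": 2, "p95LatencyMs": 8, "p99LatencyMs": 15 }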
+ */ + final static float[] PERCENTILES = {0.5f, 0.95f, 0.99f}; + + @JsonCreator + StatusData(@JsonProperty("totalSent") long totalSent, + @JsonProperty("averageLatencyMs") float averageLatencyMs, + @JsonProperty("p50LatencyMs") int p50latencyMs, + @JsonProperty("p95LatencyMs") int p95latencyMs, + @JsonProperty("p99LatencyMs") int p99latencyMs) { + this.totalSent = totalSent; + this.averageLatencyMs = averageLatencyMs; + this.p50LatencyMs = p50latencyMs; + this.p95LatencyMs = p95latencyMs; + this.p99LatencyMs = p99latencyMs; + } + + @JsonProperty + public long totalSent() { + return totalSent; + } + + @JsonProperty + public float averageLatencyMs() { + return averageLatencyMs; + } + + @JsonProperty + public int p50LatencyMs() { + return p50LatencyMs; + } + + @JsonProperty + public int p95LatencyMs() { + return p95LatencyMs; + } + + @JsonProperty + public int p99LatencyMs() { + return p99LatencyMs; + } + } + + @Override + public void stop(Platform platform) throws Exception { + if (!running.compareAndSet(true, false)) { + throw new IllegalStateException("ConfigurableProducerWorker is not running."); + } + log.info("{}: Deactivating ConfigurableProducerWorker.", id); + doneFuture.complete(""); + executor.shutdownNow(); + executor.awaitTermination(1, TimeUnit.DAYS); + this.executor = null; + this.status = null; + this.doneFuture = null; + } +} diff --git a/trogdor/src/main/java/org/apache/kafka/trogdor/workload/ConnectionStressSpec.java b/trogdor/src/main/java/org/apache/kafka/trogdor/workload/ConnectionStressSpec.java new file mode 100644 index 0000000..6141d30 --- /dev/null +++ b/trogdor/src/main/java/org/apache/kafka/trogdor/workload/ConnectionStressSpec.java @@ -0,0 +1,106 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.trogdor.workload; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import org.apache.kafka.trogdor.task.TaskController; +import org.apache.kafka.trogdor.task.TaskSpec; +import org.apache.kafka.trogdor.task.TaskWorker; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.TreeSet; + +/** + * The specification for a task which connects and disconnects many times a + * second to stress the broker. 
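+ *
+ * A hypothetical spec document (illustrative values only; property names follow the JSON
+ * annotations on the constructor below) might look like:
+ *
+ *   {
+ *     "class": "org.apache.kafka.trogdor.workload.ConnectionStressSpec",
+ *     "startMs": 0,
+ *     "durationMs": 3600000,
+ *     "clientNode": ["node0", "node1"],
+ *     "bootstrapServers": "localhost:9092",
+ *     "commonClientConf": {},
+ *     "targetConnectionsPerSec": 100,
+ *     "numThreads": 5,
+ *     "action": "CONNECT"
+ *   }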
+ */ +public class ConnectionStressSpec extends TaskSpec { + private final List clientNodes; + private final String bootstrapServers; + private final Map commonClientConf; + private final int targetConnectionsPerSec; + private final int numThreads; + private final ConnectionStressAction action; + + enum ConnectionStressAction { + CONNECT, + FETCH_METADATA + } + + @JsonCreator + public ConnectionStressSpec(@JsonProperty("startMs") long startMs, + @JsonProperty("durationMs") long durationMs, + @JsonProperty("clientNode") List clientNodes, + @JsonProperty("bootstrapServers") String bootstrapServers, + @JsonProperty("commonClientConf") Map commonClientConf, + @JsonProperty("targetConnectionsPerSec") int targetConnectionsPerSec, + @JsonProperty("numThreads") int numThreads, + @JsonProperty("action") ConnectionStressAction action) { + super(startMs, durationMs); + this.clientNodes = (clientNodes == null) ? Collections.emptyList() : + Collections.unmodifiableList(new ArrayList<>(clientNodes)); + this.bootstrapServers = (bootstrapServers == null) ? "" : bootstrapServers; + this.commonClientConf = configOrEmptyMap(commonClientConf); + this.targetConnectionsPerSec = targetConnectionsPerSec; + this.numThreads = numThreads < 1 ? 1 : numThreads; + this.action = (action == null) ? ConnectionStressAction.CONNECT : action; + } + + @JsonProperty + public List clientNode() { + return clientNodes; + } + + @JsonProperty + public String bootstrapServers() { + return bootstrapServers; + } + + @JsonProperty + public Map commonClientConf() { + return commonClientConf; + } + + @JsonProperty + public int targetConnectionsPerSec() { + return targetConnectionsPerSec; + } + + @JsonProperty + public int numThreads() { + return numThreads; + } + + @JsonProperty + public ConnectionStressAction action() { + return action; + } + + public TaskController newController(String id) { + return topology -> new TreeSet<>(clientNodes); + } + + @Override + public TaskWorker newTaskWorker(String id) { + return new ConnectionStressWorker(id, this); + } +} diff --git a/trogdor/src/main/java/org/apache/kafka/trogdor/workload/ConnectionStressWorker.java b/trogdor/src/main/java/org/apache/kafka/trogdor/workload/ConnectionStressWorker.java new file mode 100644 index 0000000..4af9227 --- /dev/null +++ b/trogdor/src/main/java/org/apache/kafka/trogdor/workload/ConnectionStressWorker.java @@ -0,0 +1,322 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.trogdor.workload; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.databind.JsonNode; +import org.apache.kafka.clients.ApiVersions; +import org.apache.kafka.clients.ClientUtils; +import org.apache.kafka.clients.ManualMetadataUpdater; +import org.apache.kafka.clients.NetworkClient; +import org.apache.kafka.clients.NetworkClientUtils; +import org.apache.kafka.clients.admin.Admin; +import org.apache.kafka.clients.admin.AdminClientConfig; +import org.apache.kafka.clients.producer.ProducerConfig; +import org.apache.kafka.common.Cluster; +import org.apache.kafka.common.Node; +import org.apache.kafka.common.internals.KafkaFutureImpl; +import org.apache.kafka.common.metrics.Metrics; +import org.apache.kafka.common.network.ChannelBuilder; +import org.apache.kafka.common.network.Selector; +import org.apache.kafka.common.utils.LogContext; +import org.apache.kafka.common.utils.ThreadUtils; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.trogdor.common.JsonUtil; +import org.apache.kafka.trogdor.common.Platform; +import org.apache.kafka.trogdor.common.WorkerUtils; +import org.apache.kafka.trogdor.task.TaskWorker; +import org.apache.kafka.trogdor.task.WorkerStatusTracker; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.net.InetSocketAddress; +import java.util.List; +import java.util.Properties; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.ThreadLocalRandom; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; + +public class ConnectionStressWorker implements TaskWorker { + private static final Logger log = LoggerFactory.getLogger(ConnectionStressWorker.class); + private static final Time TIME = Time.SYSTEM; + + private static final int THROTTLE_PERIOD_MS = 100; + + private static final int REPORT_INTERVAL_MS = 5000; + + private final String id; + + private final ConnectionStressSpec spec; + + private final AtomicBoolean running = new AtomicBoolean(false); + + private KafkaFutureImpl doneFuture; + + private WorkerStatusTracker status; + + private long totalConnections; + + private long totalFailedConnections; + + private long startTimeMs; + + private Future statusUpdaterFuture; + + private ExecutorService workerExecutor; + + private ScheduledExecutorService statusUpdaterExecutor; + + public ConnectionStressWorker(String id, ConnectionStressSpec spec) { + this.id = id; + this.spec = spec; + } + + @Override + public void start(Platform platform, WorkerStatusTracker status, + KafkaFutureImpl doneFuture) throws Exception { + if (!running.compareAndSet(false, true)) { + throw new IllegalStateException("ConnectionStressWorker is already running."); + } + log.info("{}: Activating ConnectionStressWorker with {}", id, spec); + this.doneFuture = doneFuture; + this.status = status; + synchronized (ConnectionStressWorker.this) { + this.totalConnections = 0; + this.totalFailedConnections = 0; + this.startTimeMs = TIME.milliseconds(); + } + this.statusUpdaterExecutor = Executors.newScheduledThreadPool(1, + ThreadUtils.createThreadFactory("StatusUpdaterWorkerThread%d", false)); + this.statusUpdaterFuture = this.statusUpdaterExecutor.scheduleAtFixedRate( + new StatusUpdater(), 0, 
REPORT_INTERVAL_MS, TimeUnit.MILLISECONDS); + this.workerExecutor = Executors.newFixedThreadPool(spec.numThreads(), + ThreadUtils.createThreadFactory("ConnectionStressWorkerThread%d", false)); + for (int i = 0; i < spec.numThreads(); i++) { + this.workerExecutor.submit(new ConnectLoop()); + } + } + + private static class ConnectStressThrottle extends Throttle { + ConnectStressThrottle(int maxPerPeriod) { + super(maxPerPeriod, THROTTLE_PERIOD_MS); + } + } + + interface Stressor extends AutoCloseable { + static Stressor fromSpec(ConnectionStressSpec spec) { + switch (spec.action()) { + case CONNECT: + return new ConnectStressor(spec); + case FETCH_METADATA: + return new FetchMetadataStressor(spec); + } + throw new RuntimeException("invalid spec.action " + spec.action()); + } + + boolean tryConnect(); + } + + static class ConnectStressor implements Stressor { + private final AdminClientConfig conf; + private final ManualMetadataUpdater updater; + private final LogContext logContext = new LogContext(); + + ConnectStressor(ConnectionStressSpec spec) { + Properties props = new Properties(); + props.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, spec.bootstrapServers()); + WorkerUtils.addConfigsToProperties(props, spec.commonClientConf(), spec.commonClientConf()); + this.conf = new AdminClientConfig(props); + List addresses = ClientUtils.parseAndValidateAddresses( + conf.getList(AdminClientConfig.BOOTSTRAP_SERVERS_CONFIG), + conf.getString(AdminClientConfig.CLIENT_DNS_LOOKUP_CONFIG)); + this.updater = new ManualMetadataUpdater(Cluster.bootstrap(addresses).nodes()); + } + + @Override + public boolean tryConnect() { + try { + List nodes = updater.fetchNodes(); + Node targetNode = nodes.get(ThreadLocalRandom.current().nextInt(nodes.size())); + // channelBuilder will be closed as part of Selector.close() + ChannelBuilder channelBuilder = ClientUtils.createChannelBuilder(conf, TIME, logContext); + try (Metrics metrics = new Metrics()) { + try (Selector selector = new Selector(conf.getLong(AdminClientConfig.CONNECTIONS_MAX_IDLE_MS_CONFIG), + metrics, TIME, "", channelBuilder, logContext)) { + try (NetworkClient client = new NetworkClient(selector, + updater, + "ConnectionStressWorker", + 1, + 1000, + 1000, + 4096, + 4096, + 1000, + 10 * 1000, + 127 * 1000, + TIME, + false, + new ApiVersions(), + logContext)) { + NetworkClientUtils.awaitReady(client, targetNode, TIME, 500); + } + } + } + return true; + } catch (IOException e) { + return false; + } + } + + @Override + public void close() throws Exception { + Utils.closeQuietly(updater, "ManualMetadataUpdater"); + } + } + + static class FetchMetadataStressor implements Stressor { + private final Properties props; + + FetchMetadataStressor(ConnectionStressSpec spec) { + this.props = new Properties(); + this.props.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, spec.bootstrapServers()); + WorkerUtils.addConfigsToProperties(this.props, spec.commonClientConf(), spec.commonClientConf()); + } + + @Override + public boolean tryConnect() { + try (Admin client = Admin.create(this.props)) { + client.describeCluster().nodes().get(); + } catch (RuntimeException e) { + return false; + } catch (Exception e) { + return false; + } + return true; + } + + @Override + public void close() throws Exception { + } + } + + public class ConnectLoop implements Runnable { + @Override + public void run() { + Stressor stressor = Stressor.fromSpec(spec); + int rate = WorkerUtils.perSecToPerPeriod( + ((float) spec.targetConnectionsPerSec()) / spec.numThreads(), + THROTTLE_PERIOD_MS); + 
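+            // Worked example (illustrative numbers only): targetConnectionsPerSec = 500 with
+            // numThreads = 5 gives 100 connection attempts per second per thread, which the
+            // 100 ms THROTTLE_PERIOD_MS converts to a budget of roughly 10 attempts per period.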
Throttle throttle = new ConnectStressThrottle(rate); + try { + while (!doneFuture.isDone()) { + throttle.increment(); + boolean success = stressor.tryConnect(); + synchronized (ConnectionStressWorker.this) { + totalConnections++; + if (!success) { + totalFailedConnections++; + } + } + } + } catch (Exception e) { + WorkerUtils.abort(log, "ConnectLoop", e, doneFuture); + } finally { + Utils.closeQuietly(stressor, "stressor"); + } + } + } + + private class StatusUpdater implements Runnable { + @Override + public void run() { + try { + long lastTimeMs = Time.SYSTEM.milliseconds(); + JsonNode node = null; + synchronized (ConnectionStressWorker.this) { + node = JsonUtil.JSON_SERDE.valueToTree( + new StatusData(totalConnections, totalFailedConnections, + (totalConnections * 1000.0) / (lastTimeMs - startTimeMs))); + } + status.update(node); + } catch (Exception e) { + WorkerUtils.abort(log, "StatusUpdater", e, doneFuture); + } + } + } + + public static class StatusData { + private final long totalConnections; + private final long totalFailedConnections; + private final double connectsPerSec; + + @JsonCreator + StatusData(@JsonProperty("totalConnections") long totalConnections, + @JsonProperty("totalFailedConnections") long totalFailedConnections, + @JsonProperty("connectsPerSec") double connectsPerSec) { + this.totalConnections = totalConnections; + this.totalFailedConnections = totalFailedConnections; + this.connectsPerSec = connectsPerSec; + } + + @JsonProperty + public long totalConnections() { + return totalConnections; + } + + @JsonProperty + public long totalFailedConnections() { + return totalFailedConnections; + } + + @JsonProperty + public double connectsPerSec() { + return connectsPerSec; + } + } + + @Override + public void stop(Platform platform) throws Exception { + if (!running.compareAndSet(true, false)) { + throw new IllegalStateException("ConnectionStressWorker is not running."); + } + log.info("{}: Deactivating ConnectionStressWorker.", id); + + // Shut down the periodic status updater and perform a final update on the + // statistics. We want to do this first, before deactivating any threads. + // Otherwise, if some threads take a while to terminate, this could lead + // to a misleading rate getting reported. + this.statusUpdaterFuture.cancel(false); + this.statusUpdaterExecutor.shutdown(); + this.statusUpdaterExecutor.awaitTermination(1, TimeUnit.DAYS); + this.statusUpdaterExecutor = null; + new StatusUpdater().run(); + + doneFuture.complete(""); + workerExecutor.shutdownNow(); + workerExecutor.awaitTermination(1, TimeUnit.DAYS); + this.workerExecutor = null; + this.status = null; + } +} diff --git a/trogdor/src/main/java/org/apache/kafka/trogdor/workload/ConstantFlushGenerator.java b/trogdor/src/main/java/org/apache/kafka/trogdor/workload/ConstantFlushGenerator.java new file mode 100644 index 0000000..9d656b2 --- /dev/null +++ b/trogdor/src/main/java/org/apache/kafka/trogdor/workload/ConstantFlushGenerator.java @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.trogdor.workload; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import org.apache.kafka.clients.producer.KafkaProducer; +import org.apache.kafka.common.errors.InterruptException; + +/** + * This generator will flush the producer after a specific number of messages. This is useful to simulate a specific + * number of messages in a batch regardless of the message size, since batch flushing is not exposed in the + * KafkaProducer client code. + * + * WARNING: This does not directly control when KafkaProducer will batch, this only makes best effort. This also + * cannot tell when a KafkaProducer batch is closed. If the KafkaProducer sends a batch before this executes, this + * will continue to execute on its own cadence. To alleviate this, make sure to set `linger.ms` to allow for at least + * `messagesPerFlush` messages to be generated, and make sure to set `batch.size` to allow for all these messages. + * + * Here is an example spec: + * + * { + * "type": "constant", + * "messagesPerFlush": 16 + * } + * + * This example will flush the producer every 16 messages. + */ + +public class ConstantFlushGenerator implements FlushGenerator { + private final int messagesPerFlush; + private int messageTracker = 0; + + @JsonCreator + public ConstantFlushGenerator(@JsonProperty("messagesPerFlush") int messagesPerFlush) { + this.messagesPerFlush = messagesPerFlush; + } + + @JsonProperty + public int messagesPerFlush() { + return messagesPerFlush; + } + + @Override + public synchronized void increment(KafkaProducer producer) { + // Increment the message tracker. + messageTracker += 1; + + // Flush when we reach the desired number of messages. + if (messageTracker >= messagesPerFlush) { + messageTracker = 0; + try { + producer.flush(); + } catch (InterruptException e) { + // Ignore flush interruption exceptions. + } + } + } +} diff --git a/trogdor/src/main/java/org/apache/kafka/trogdor/workload/ConstantPayloadGenerator.java b/trogdor/src/main/java/org/apache/kafka/trogdor/workload/ConstantPayloadGenerator.java new file mode 100644 index 0000000..d0c1c48 --- /dev/null +++ b/trogdor/src/main/java/org/apache/kafka/trogdor/workload/ConstantPayloadGenerator.java @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.trogdor.workload; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +/** + * A PayloadGenerator which always generates a constant payload. + */ +public class ConstantPayloadGenerator implements PayloadGenerator { + private final int size; + private final byte[] value; + + @JsonCreator + public ConstantPayloadGenerator(@JsonProperty("size") int size, + @JsonProperty("value") byte[] value) { + this.size = size; + this.value = (value == null || value.length == 0) ? new byte[size] : value; + } + + @JsonProperty + public int size() { + return size; + } + + @JsonProperty + public byte[] value() { + return value; + } + + @Override + public byte[] generate(long position) { + byte[] next = new byte[size]; + for (int i = 0; i < next.length; i += value.length) { + System.arraycopy(value, 0, next, i, Math.min(next.length - i, value.length)); + } + return next; + } +} diff --git a/trogdor/src/main/java/org/apache/kafka/trogdor/workload/ConstantThroughputGenerator.java b/trogdor/src/main/java/org/apache/kafka/trogdor/workload/ConstantThroughputGenerator.java new file mode 100644 index 0000000..9e5eeb9 --- /dev/null +++ b/trogdor/src/main/java/org/apache/kafka/trogdor/workload/ConstantThroughputGenerator.java @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.trogdor.workload; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import org.apache.kafka.common.utils.Time; + +/** + * This throughput generator configures constant throughput. + * + * The lower the window size, the smoother the traffic will be. Using a 100ms window offers no noticeable spikes in + * traffic while still being long enough to avoid too much overhead. + * + * Here is an example spec: + * + * { + * "type": "constant", + * "messagesPerWindow": 50, + * "windowSizeMs": 100 + * } + * + * This will produce a workload that runs 500 messages per second, with a maximum resolution of 50 messages per 100 + * millisecond. + * + * If `messagesPerWindow` is less than or equal to 0, `throttle` will not throttle at all and will return immediately. + */ + +public class ConstantThroughputGenerator implements ThroughputGenerator { + private final int messagesPerWindow; + private final long windowSizeMs; + + private long nextWindowStarts = 0; + private int messageTracker = 0; + + @JsonCreator + public ConstantThroughputGenerator(@JsonProperty("messagesPerWindow") int messagesPerWindow, + @JsonProperty("windowSizeMs") long windowSizeMs) { + // Calculate the default values. 
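+        // (As the class javadoc notes, a non-positive windowSizeMs falls back to 100 ms. For example,
+        // messagesPerWindow=50 with the 100 ms default works out to 50 * (1000 / 100) = 500 messages
+        // per second.)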
+ if (windowSizeMs <= 0) { + windowSizeMs = 100; + } + this.windowSizeMs = windowSizeMs; + this.messagesPerWindow = messagesPerWindow; + calculateNextWindow(); + } + + @JsonProperty + public long windowSizeMs() { + return windowSizeMs; + } + + @JsonProperty + public int messagesPerWindow() { + return messagesPerWindow; + } + + private void calculateNextWindow() { + // Reset the message count. + messageTracker = 0; + + // Calculate the next window start time. + long now = Time.SYSTEM.milliseconds(); + if (nextWindowStarts > 0) { + while (nextWindowStarts <= now) { + nextWindowStarts += windowSizeMs; + } + } else { + nextWindowStarts = now + windowSizeMs; + } + } + + @Override + public synchronized void throttle() throws InterruptedException { + // Run unthrottled if messagesPerWindow is not positive. + if (messagesPerWindow <= 0) { + return; + } + + // Calculate the next window if we've moved beyond the current one. + if (Time.SYSTEM.milliseconds() >= nextWindowStarts) { + calculateNextWindow(); + } + + // Increment the message tracker. + messageTracker += 1; + + // Compare the tracked message count with the throttle limits. + if (messageTracker >= messagesPerWindow) { + + // Wait the difference in time between now and when the next window starts. + while (nextWindowStarts > Time.SYSTEM.milliseconds()) { + wait(nextWindowStarts - Time.SYSTEM.milliseconds()); + } + } + } +} diff --git a/trogdor/src/main/java/org/apache/kafka/trogdor/workload/ConsumeBenchSpec.java b/trogdor/src/main/java/org/apache/kafka/trogdor/workload/ConsumeBenchSpec.java new file mode 100644 index 0000000..6909d2d --- /dev/null +++ b/trogdor/src/main/java/org/apache/kafka/trogdor/workload/ConsumeBenchSpec.java @@ -0,0 +1,249 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.trogdor.workload; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.config.ConfigException; +import org.apache.kafka.trogdor.common.StringExpander; +import org.apache.kafka.trogdor.task.TaskController; +import org.apache.kafka.trogdor.task.TaskSpec; +import org.apache.kafka.trogdor.task.TaskWorker; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.HashMap; +import java.util.Set; +import java.util.HashSet; +import java.util.Optional; + +/** + * The specification for a benchmark that consumer messages from a set of topic/partitions. + * + * If a consumer group is not given to the specification, a random one will be generated and + * used to track offsets/subscribe to topics. 
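+ * (In ConsumeBenchWorker, the generated group name is "consume-bench-" followed by a random UUID.)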
+ *
+ * This specification represents a topic partition in its "activeTopics" field using the notation
+ * topic_name:partition_number (e.g. "foo:1" represents partition 1 of topic "foo").
+ * Note that a topic name cannot contain more than one colon.
+ *
+ * The "activeTopics" field also supports ranges that get expanded. See #{@link StringExpander}.
+ *
+ * This provides a succinct way to represent multiple partitions of multiple topics.
+ * Example:
+ * Given "activeTopics": ["foo[1-3]:[1-3]"], "foo[1-3]:[1-3]" will get
+ * expanded to [foo1:1, foo1:2, foo1:3, foo2:1, ..., foo3:3].
+ * This represents partitions 1-3 of the three topics foo1, foo2 and foo3.
+ *
+ * If there is at least one topic:partition pair, the consumer will be manually assigned partitions via
+ * #{@link org.apache.kafka.clients.consumer.KafkaConsumer#assign(Collection)}.
+ * Note that in this case the consumer will fetch and assign all partitions of a topic if no partition
+ * is given for it (e.g. ["foo:1", "bar"]).
+ *
+ * If no topic:partition pairs are given, the consumer will subscribe to the topics via
+ * #{@link org.apache.kafka.clients.consumer.KafkaConsumer#subscribe(Collection)}
+ * and will be assigned partitions dynamically by the consumer group.
+ *
+ * This specification supports spawning multiple consumers in a single Trogdor worker agent.
+ * The "threadsPerWorker" field denotes how many consumers should be spawned for this spec.
+ * Note that the "targetMessagesPerSec", "maxMessages" and "activeTopics" fields apply to every consumer individually.
+ *
+ * If a consumer group is not specified, every consumer is assigned a different, random group. When specified, all consumers use the same group.
+ * Since no two consumers in the same group can be assigned the same partition,
+ * explicitly specifying partitions in "activeTopics" when there are multiple "threadsPerWorker"
+ * and a particular "consumerGroup" will result in a #{@link ConfigException}, aborting the task.
+ *
+ * The "recordProcessor" field specifies tasks to run on the records that are consumed. These run
+ * immediately after the messages are polled. See the `RecordProcessor` interface for more information.
+ *
+ * An example JSON representation which will result in a consumer that is part of the consumer group "cg" and
+ * subscribed to topics foo1, foo2, foo3 and bar:
+ * #{@code
+ * {
+ *    "class": "org.apache.kafka.trogdor.workload.ConsumeBenchSpec",
+ *    "durationMs": 10000000,
+ *    "consumerNode": "node0",
+ *    "bootstrapServers": "localhost:9092",
+ *    "maxMessages": 100,
+ *    "consumerGroup": "cg",
+ *    "activeTopics": ["foo[1-3]", "bar"]
+ * }
+ * }
+ */
+public class ConsumeBenchSpec extends TaskSpec {
+
+    private static final String VALID_EXPANDED_TOPIC_NAME_PATTERN = "^[^:]+(:[\\d]+|[^:]*)$";
+    private final String consumerNode;
+    private final String bootstrapServers;
+    private final int targetMessagesPerSec;
+    private final long maxMessages;
+    private final Map<String, String> consumerConf;
+    private final Map<String, String> adminClientConf;
+    private final Map<String, String> commonClientConf;
+    private final List<String> activeTopics;
+    private final String consumerGroup;
+    private final int threadsPerWorker;
+    private final Optional<RecordProcessor> recordProcessor;
+
+    @JsonCreator
+    public ConsumeBenchSpec(@JsonProperty("startMs") long startMs,
+                            @JsonProperty("durationMs") long durationMs,
+                            @JsonProperty("consumerNode") String consumerNode,
+                            @JsonProperty("bootstrapServers") String bootstrapServers,
+                            @JsonProperty("targetMessagesPerSec") int targetMessagesPerSec,
+                            @JsonProperty("maxMessages") long maxMessages,
+                            @JsonProperty("consumerGroup") String consumerGroup,
+                            @JsonProperty("consumerConf") Map<String, String> consumerConf,
+                            @JsonProperty("commonClientConf") Map<String, String> commonClientConf,
+                            @JsonProperty("adminClientConf") Map<String, String> adminClientConf,
+                            @JsonProperty("threadsPerWorker") Integer threadsPerWorker,
+                            @JsonProperty("recordProcessor") Optional<RecordProcessor> recordProcessor,
+                            @JsonProperty("activeTopics") List<String> activeTopics) {
+        super(startMs, durationMs);
+        this.consumerNode = (consumerNode == null) ? "" : consumerNode;
+        this.bootstrapServers = (bootstrapServers == null) ? "" : bootstrapServers;
+        this.targetMessagesPerSec = targetMessagesPerSec;
+        this.maxMessages = maxMessages;
+        this.consumerConf = configOrEmptyMap(consumerConf);
+        this.commonClientConf = configOrEmptyMap(commonClientConf);
+        this.adminClientConf = configOrEmptyMap(adminClientConf);
+        this.activeTopics = activeTopics == null ? new ArrayList<>() : activeTopics;
+        this.consumerGroup = consumerGroup == null ? "" : consumerGroup;
+        this.threadsPerWorker = threadsPerWorker == null ?
1 : threadsPerWorker; + this.recordProcessor = recordProcessor; + } + + @JsonProperty + public String consumerNode() { + return consumerNode; + } + + @JsonProperty + public String consumerGroup() { + return consumerGroup; + } + + @JsonProperty + public String bootstrapServers() { + return bootstrapServers; + } + + @JsonProperty + public int targetMessagesPerSec() { + return targetMessagesPerSec; + } + + @JsonProperty + public long maxMessages() { + return maxMessages; + } + + @JsonProperty + public int threadsPerWorker() { + return threadsPerWorker; + } + + @JsonProperty + public Optional recordProcessor() { + return this.recordProcessor; + } + + @JsonProperty + public Map consumerConf() { + return consumerConf; + } + + @JsonProperty + public Map commonClientConf() { + return commonClientConf; + } + + @JsonProperty + public Map adminClientConf() { + return adminClientConf; + } + + @JsonProperty + public List activeTopics() { + return activeTopics; + } + + @Override + public TaskController newController(String id) { + return topology -> Collections.singleton(consumerNode); + } + + @Override + public TaskWorker newTaskWorker(String id) { + return new ConsumeBenchWorker(id, this); + } + + /** + * Materializes a list of topic names (optionally with ranges) into a map of the topics and their partitions + * + * Example: + * ['foo[1-3]', 'foobar:2', 'bar[1-2]:[1-2]'] => {'foo1': [], 'foo2': [], 'foo3': [], 'foobar': [2], + * 'bar1': [1, 2], 'bar2': [1, 2] } + */ + Map> materializeTopics() { + Map> partitionsByTopics = new HashMap<>(); + + for (String rawTopicName : this.activeTopics) { + Set expandedNames = expandTopicName(rawTopicName); + if (!expandedNames.iterator().next().matches(VALID_EXPANDED_TOPIC_NAME_PATTERN)) + throw new IllegalArgumentException(String.format("Expanded topic name %s is invalid", rawTopicName)); + + for (String topicName : expandedNames) { + TopicPartition partition = null; + if (topicName.contains(":")) { + String[] topicAndPartition = topicName.split(":"); + topicName = topicAndPartition[0]; + partition = new TopicPartition(topicName, Integer.parseInt(topicAndPartition[1])); + } + if (!partitionsByTopics.containsKey(topicName)) { + partitionsByTopics.put(topicName, new ArrayList<>()); + } + if (partition != null) { + partitionsByTopics.get(topicName).add(partition); + } + } + } + + return partitionsByTopics; + } + + /** + * Expands a topic name until there are no more ranges in it + */ + private Set expandTopicName(String topicName) { + Set expandedNames = StringExpander.expand(topicName); + if (expandedNames.size() == 1) { + return expandedNames; + } + + Set newNames = new HashSet<>(); + for (String name : expandedNames) { + newNames.addAll(expandTopicName(name)); + } + return newNames; + } +} diff --git a/trogdor/src/main/java/org/apache/kafka/trogdor/workload/ConsumeBenchWorker.java b/trogdor/src/main/java/org/apache/kafka/trogdor/workload/ConsumeBenchWorker.java new file mode 100644 index 0000000..84ce1d3 --- /dev/null +++ b/trogdor/src/main/java/org/apache/kafka/trogdor/workload/ConsumeBenchWorker.java @@ -0,0 +1,566 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.trogdor.workload; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.databind.JsonNode; +import org.apache.kafka.clients.consumer.ConsumerConfig; +import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.clients.consumer.ConsumerRecords; +import org.apache.kafka.clients.consumer.KafkaConsumer; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.config.ConfigException; +import org.apache.kafka.common.internals.KafkaFutureImpl; +import org.apache.kafka.common.serialization.ByteArrayDeserializer; +import org.apache.kafka.common.utils.ThreadUtils; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.trogdor.common.JsonUtil; +import org.apache.kafka.trogdor.common.Platform; +import org.apache.kafka.trogdor.common.WorkerUtils; +import org.apache.kafka.trogdor.task.WorkerStatusTracker; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.kafka.trogdor.task.TaskWorker; + +import java.time.Duration; +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.UUID; +import java.util.Optional; +import java.util.Properties; +import java.util.HashMap; +import java.util.concurrent.Callable; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.locks.ReentrantLock; +import java.util.stream.Collectors; + + +public class ConsumeBenchWorker implements TaskWorker { + private static final Logger log = LoggerFactory.getLogger(ConsumeBenchWorker.class); + + private static final int THROTTLE_PERIOD_MS = 100; + + private final String id; + private final ConsumeBenchSpec spec; + private final AtomicBoolean running = new AtomicBoolean(false); + private ScheduledExecutorService executor; + private WorkerStatusTracker workerStatus; + private StatusUpdater statusUpdater; + private Future statusUpdaterFuture; + private KafkaFutureImpl doneFuture; + private ThreadSafeConsumer consumer; + public ConsumeBenchWorker(String id, ConsumeBenchSpec spec) { + this.id = id; + this.spec = spec; + } + + @Override + public void start(Platform platform, WorkerStatusTracker status, + KafkaFutureImpl doneFuture) throws Exception { + if (!running.compareAndSet(false, true)) { + throw new IllegalStateException("ConsumeBenchWorker is already running."); + } + log.info("{}: Activating ConsumeBenchWorker with {}", id, spec); + this.statusUpdater = new StatusUpdater(); + this.executor = Executors.newScheduledThreadPool( + spec.threadsPerWorker() + 2, // 1 thread for all the ConsumeStatusUpdater and 1 for the StatusUpdater + ThreadUtils.createThreadFactory("ConsumeBenchWorkerThread%d", false)); + this.statusUpdaterFuture = executor.scheduleAtFixedRate(this.statusUpdater, 1, 1, 
TimeUnit.MINUTES); + this.workerStatus = status; + this.doneFuture = doneFuture; + executor.submit(new Prepare()); + } + + public class Prepare implements Runnable { + @Override + public void run() { + try { + List> consumeTasks = new ArrayList<>(); + for (ConsumeMessages task : consumeTasks()) { + consumeTasks.add(executor.submit(task)); + } + executor.submit(new CloseStatusUpdater(consumeTasks)); + } catch (Throwable e) { + WorkerUtils.abort(log, "Prepare", e, doneFuture); + } + } + + private List consumeTasks() { + List tasks = new ArrayList<>(); + String consumerGroup = consumerGroup(); + int consumerCount = spec.threadsPerWorker(); + Map> partitionsByTopic = spec.materializeTopics(); + boolean toUseGroupPartitionAssignment = partitionsByTopic.values().stream().allMatch(List::isEmpty); + + if (!toUseGroupPartitionAssignment && !toUseRandomConsumeGroup() && consumerCount > 1) + throw new ConfigException("You may not specify an explicit partition assignment when using multiple consumers in the same group." + + "Please leave the consumer group unset, specify topics instead of partitions or use a single consumer."); + + consumer = consumer(consumerGroup, clientId(0)); + if (toUseGroupPartitionAssignment) { + Set topics = partitionsByTopic.keySet(); + tasks.add(new ConsumeMessages(consumer, spec.recordProcessor(), topics)); + + for (int i = 0; i < consumerCount - 1; i++) { + tasks.add(new ConsumeMessages(consumer(consumerGroup(), clientId(i + 1)), spec.recordProcessor(), topics)); + } + } else { + List partitions = populatePartitionsByTopic(consumer.consumer(), partitionsByTopic) + .values().stream().flatMap(List::stream).collect(Collectors.toList()); + tasks.add(new ConsumeMessages(consumer, spec.recordProcessor(), partitions)); + + for (int i = 0; i < consumerCount - 1; i++) { + tasks.add(new ConsumeMessages(consumer(consumerGroup(), clientId(i + 1)), spec.recordProcessor(), partitions)); + } + } + + return tasks; + } + + private String clientId(int idx) { + return String.format("consumer.%s-%d", id, idx); + } + + /** + * Creates a new KafkaConsumer instance + */ + private ThreadSafeConsumer consumer(String consumerGroup, String clientId) { + Properties props = new Properties(); + props.put(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, spec.bootstrapServers()); + props.put(ConsumerConfig.CLIENT_ID_CONFIG, clientId); + props.put(ConsumerConfig.GROUP_ID_CONFIG, consumerGroup); + props.put(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "earliest"); + props.put(ConsumerConfig.MAX_POLL_INTERVAL_MS_CONFIG, 100000); + // these defaults maybe over-written by the user-specified commonClientConf or consumerConf + WorkerUtils.addConfigsToProperties(props, spec.commonClientConf(), spec.consumerConf()); + return new ThreadSafeConsumer(new KafkaConsumer<>(props, new ByteArrayDeserializer(), new ByteArrayDeserializer()), clientId); + } + + private String consumerGroup() { + return toUseRandomConsumeGroup() + ? 
"consume-bench-" + UUID.randomUUID().toString() + : spec.consumerGroup(); + } + + private boolean toUseRandomConsumeGroup() { + return spec.consumerGroup().isEmpty(); + } + + private Map> populatePartitionsByTopic(KafkaConsumer consumer, + Map> materializedTopics) { + // fetch partitions for topics who do not have any listed + for (Map.Entry> entry : materializedTopics.entrySet()) { + String topicName = entry.getKey(); + List partitions = entry.getValue(); + + if (partitions.isEmpty()) { + List fetchedPartitions = consumer.partitionsFor(topicName).stream() + .map(partitionInfo -> new TopicPartition(partitionInfo.topic(), partitionInfo.partition())) + .collect(Collectors.toList()); + partitions.addAll(fetchedPartitions); + } + + materializedTopics.put(topicName, partitions); + } + + return materializedTopics; + } + } + + public class ConsumeMessages implements Callable { + private final Histogram latencyHistogram; + private final Histogram messageSizeHistogram; + private final Future statusUpdaterFuture; + private final Throttle throttle; + private final String clientId; + private final ThreadSafeConsumer consumer; + private final Optional recordProcessor; + + private ConsumeMessages(ThreadSafeConsumer consumer, + Optional recordProcessor) { + this.latencyHistogram = new Histogram(10000); + this.messageSizeHistogram = new Histogram(2 * 1024 * 1024); + this.clientId = consumer.clientId(); + this.statusUpdaterFuture = executor.scheduleAtFixedRate( + new ConsumeStatusUpdater(latencyHistogram, messageSizeHistogram, consumer, recordProcessor), 1, 1, TimeUnit.MINUTES); + int perPeriod; + if (spec.targetMessagesPerSec() <= 0) + perPeriod = Integer.MAX_VALUE; + else + perPeriod = WorkerUtils.perSecToPerPeriod(spec.targetMessagesPerSec(), THROTTLE_PERIOD_MS); + + this.throttle = new Throttle(perPeriod, THROTTLE_PERIOD_MS); + this.consumer = consumer; + this.recordProcessor = recordProcessor; + } + + ConsumeMessages(ThreadSafeConsumer consumer, + Optional recordProcessor, + Set topics) { + this(consumer, recordProcessor); + log.info("Will consume from topics {} via dynamic group assignment.", topics); + this.consumer.subscribe(topics); + } + ConsumeMessages(ThreadSafeConsumer consumer, + Optional recordProcessor, + List partitions) { + this(consumer, recordProcessor); + log.info("Will consume from topic partitions {} via manual assignment.", partitions); + this.consumer.assign(partitions); + } + + @Override + public Void call() throws Exception { + long messagesConsumed = 0; + long bytesConsumed = 0; + long startTimeMs = Time.SYSTEM.milliseconds(); + long startBatchMs = startTimeMs; + long maxMessages = spec.maxMessages(); + try { + while (messagesConsumed < maxMessages) { + ConsumerRecords records = consumer.poll(); + if (records.isEmpty()) { + continue; + } + long endBatchMs = Time.SYSTEM.milliseconds(); + long elapsedBatchMs = endBatchMs - startBatchMs; + + // Do the record batch processing immediately to avoid latency skew. 
+ recordProcessor.ifPresent(processor -> processor.processRecords(records)); + + for (ConsumerRecord record : records) { + messagesConsumed++; + long messageBytes = 0; + if (record.key() != null) { + messageBytes += record.serializedKeySize(); + } + if (record.value() != null) { + messageBytes += record.serializedValueSize(); + } + latencyHistogram.add(elapsedBatchMs); + messageSizeHistogram.add(messageBytes); + bytesConsumed += messageBytes; + if (messagesConsumed >= maxMessages) + break; + + throttle.increment(); + } + startBatchMs = Time.SYSTEM.milliseconds(); + } + } catch (Exception e) { + WorkerUtils.abort(log, "ConsumeRecords", e, doneFuture); + } finally { + statusUpdaterFuture.cancel(false); + StatusData statusData = + new ConsumeStatusUpdater(latencyHistogram, messageSizeHistogram, consumer, spec.recordProcessor()).update(); + long curTimeMs = Time.SYSTEM.milliseconds(); + log.info("{} Consumed total number of messages={}, bytes={} in {} ms. status: {}", + clientId, messagesConsumed, bytesConsumed, curTimeMs - startTimeMs, statusData); + } + consumer.close(); + return null; + } + } + + public class CloseStatusUpdater implements Runnable { + private final List> consumeTasks; + + CloseStatusUpdater(List> consumeTasks) { + this.consumeTasks = consumeTasks; + } + + @Override + public void run() { + while (!consumeTasks.stream().allMatch(Future::isDone)) { + try { + Thread.sleep(60000); + } catch (InterruptedException e) { + log.debug("{} was interrupted. Closing...", this.getClass().getName()); + break; // close the thread + } + } + statusUpdaterFuture.cancel(false); + statusUpdater.update(); + doneFuture.complete(""); + } + } + + class StatusUpdater implements Runnable { + final Map statuses; + + StatusUpdater() { + statuses = new HashMap<>(); + } + + @Override + public void run() { + try { + update(); + } catch (Exception e) { + WorkerUtils.abort(log, "ConsumeStatusUpdater", e, doneFuture); + } + } + + synchronized void update() { + workerStatus.update(JsonUtil.JSON_SERDE.valueToTree(statuses)); + } + + synchronized void updateConsumeStatus(String clientId, StatusData status) { + statuses.put(clientId, JsonUtil.JSON_SERDE.valueToTree(status)); + } + } + + /** + * Runnable class that updates the status of a single consumer + */ + public class ConsumeStatusUpdater implements Runnable { + private final Histogram latencyHistogram; + private final Histogram messageSizeHistogram; + private final ThreadSafeConsumer consumer; + private final Optional recordProcessor; + + ConsumeStatusUpdater(Histogram latencyHistogram, + Histogram messageSizeHistogram, + ThreadSafeConsumer consumer, + Optional recordProcessor) { + this.latencyHistogram = latencyHistogram; + this.messageSizeHistogram = messageSizeHistogram; + this.consumer = consumer; + this.recordProcessor = recordProcessor; + } + + @Override + public void run() { + try { + update(); + } catch (Exception e) { + WorkerUtils.abort(log, "ConsumeStatusUpdater", e, doneFuture); + } + } + + StatusData update() { + Histogram.Summary latSummary = latencyHistogram.summarize(StatusData.PERCENTILES); + Histogram.Summary msgSummary = messageSizeHistogram.summarize(StatusData.PERCENTILES); + + // Parse out the RecordProcessor's status, id specified. 
+ Optional recordProcessorStatus = Optional.empty(); + if (recordProcessor.isPresent()) { + recordProcessorStatus = Optional.of(recordProcessor.get().processorStatus()); + } + + StatusData statusData = new StatusData( + consumer.assignedPartitions(), + latSummary.numSamples(), + (long) (msgSummary.numSamples() * msgSummary.average()), + (long) msgSummary.average(), + latSummary.average(), + latSummary.percentiles().get(0).value(), + latSummary.percentiles().get(1).value(), + latSummary.percentiles().get(2).value(), + recordProcessorStatus); + statusUpdater.updateConsumeStatus(consumer.clientId(), statusData); + log.info("Status={}", JsonUtil.toJsonString(statusData)); + return statusData; + } + } + + public static class StatusData { + private final long totalMessagesReceived; + private final List assignedPartitions; + private final long totalBytesReceived; + private final long averageMessageSizeBytes; + private final float averageLatencyMs; + private final int p50LatencyMs; + private final int p95LatencyMs; + private final int p99LatencyMs; + private final Optional recordProcessorStatus; + + /** + * The percentiles to use when calculating the histogram data. + * These should match up with the p50LatencyMs, p95LatencyMs, etc. fields. + */ + final static float[] PERCENTILES = {0.5f, 0.95f, 0.99f}; + @JsonCreator + StatusData(@JsonProperty("assignedPartitions") List assignedPartitions, + @JsonProperty("totalMessagesReceived") long totalMessagesReceived, + @JsonProperty("totalBytesReceived") long totalBytesReceived, + @JsonProperty("averageMessageSizeBytes") long averageMessageSizeBytes, + @JsonProperty("averageLatencyMs") float averageLatencyMs, + @JsonProperty("p50LatencyMs") int p50latencyMs, + @JsonProperty("p95LatencyMs") int p95latencyMs, + @JsonProperty("p99LatencyMs") int p99latencyMs, + @JsonProperty("recordProcessorStatus") Optional recordProcessorStatus) { + this.assignedPartitions = assignedPartitions; + this.totalMessagesReceived = totalMessagesReceived; + this.totalBytesReceived = totalBytesReceived; + this.averageMessageSizeBytes = averageMessageSizeBytes; + this.averageLatencyMs = averageLatencyMs; + this.p50LatencyMs = p50latencyMs; + this.p95LatencyMs = p95latencyMs; + this.p99LatencyMs = p99latencyMs; + this.recordProcessorStatus = recordProcessorStatus; + } + + @JsonProperty + public List assignedPartitions() { + return assignedPartitions; + } + + @JsonProperty + public long totalMessagesReceived() { + return totalMessagesReceived; + } + + @JsonProperty + public long totalBytesReceived() { + return totalBytesReceived; + } + + @JsonProperty + public long averageMessageSizeBytes() { + return averageMessageSizeBytes; + } + + @JsonProperty + public float averageLatencyMs() { + return averageLatencyMs; + } + + @JsonProperty + public int p50LatencyMs() { + return p50LatencyMs; + } + + @JsonProperty + public int p95LatencyMs() { + return p95LatencyMs; + } + + @JsonProperty + public int p99LatencyMs() { + return p99LatencyMs; + } + + @JsonProperty + public JsonNode recordProcessorStatus() { + return recordProcessorStatus.orElse(null); + } + } + + @Override + public void stop(Platform platform) throws Exception { + if (!running.compareAndSet(true, false)) { + throw new IllegalStateException("ConsumeBenchWorker is not running."); + } + log.info("{}: Deactivating ConsumeBenchWorker.", id); + doneFuture.complete(""); + executor.shutdownNow(); + executor.awaitTermination(1, TimeUnit.DAYS); + consumer.close(); + this.consumer = null; + this.executor = null; + this.statusUpdater = null; + 
this.statusUpdaterFuture = null; + this.workerStatus = null; + this.doneFuture = null; + } + + /** + * A thread-safe KafkaConsumer wrapper + */ + private static class ThreadSafeConsumer { + private final KafkaConsumer consumer; + private final String clientId; + private final ReentrantLock consumerLock; + private boolean closed = false; + + ThreadSafeConsumer(KafkaConsumer consumer, String clientId) { + this.consumer = consumer; + this.clientId = clientId; + this.consumerLock = new ReentrantLock(); + } + + ConsumerRecords poll() { + this.consumerLock.lock(); + try { + return consumer.poll(Duration.ofMillis(50)); + } finally { + this.consumerLock.unlock(); + } + } + + void close() { + if (closed) + return; + this.consumerLock.lock(); + try { + consumer.unsubscribe(); + Utils.closeQuietly(consumer, "consumer"); + closed = true; + } finally { + this.consumerLock.unlock(); + } + } + + void subscribe(Set topics) { + this.consumerLock.lock(); + try { + consumer.subscribe(topics); + } finally { + this.consumerLock.unlock(); + } + } + + void assign(Collection partitions) { + this.consumerLock.lock(); + try { + consumer.assign(partitions); + } finally { + this.consumerLock.unlock(); + } + } + + List assignedPartitions() { + this.consumerLock.lock(); + try { + return consumer.assignment().stream() + .map(TopicPartition::toString).collect(Collectors.toList()); + } finally { + this.consumerLock.unlock(); + } + } + + String clientId() { + return clientId; + } + + KafkaConsumer consumer() { + return consumer; + } + } +} diff --git a/trogdor/src/main/java/org/apache/kafka/trogdor/workload/ExternalCommandSpec.java b/trogdor/src/main/java/org/apache/kafka/trogdor/workload/ExternalCommandSpec.java new file mode 100644 index 0000000..4947aed --- /dev/null +++ b/trogdor/src/main/java/org/apache/kafka/trogdor/workload/ExternalCommandSpec.java @@ -0,0 +1,115 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.trogdor.workload; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.databind.JsonNode; + +import com.fasterxml.jackson.databind.node.NullNode; +import org.apache.kafka.trogdor.task.TaskController; +import org.apache.kafka.trogdor.task.TaskSpec; +import org.apache.kafka.trogdor.task.TaskWorker; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Optional; + +/** + * ExternalCommandSpec describes a task that executes Trogdor tasks with the command. + * + * An example uses the python runner to execute the ProduceBenchSpec task. 
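+ * The command is expected to follow the contract described in ExternalCommandWorker: it receives the
+ * task id and the workload as a single JSON line on its standard input, and may report "status",
+ * "log" and "error" fields as JSON lines on its standard output.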
+ * + * #{@code + * { + * "class": "org.apache.kafka.trogdor.workload.ExternalCommandSpec", + * "command": ["python", "/path/to/trogdor/python/runner"], + * "durationMs": 10000000, + * "producerNode": "node0", + * "workload": { + * "class": "org.apache.kafka.trogdor.workload.ProduceBenchSpec", + * "bootstrapServers": "localhost:9092", + * "targetMessagesPerSec": 10, + * "maxMessages": 100, + * "activeTopics": { + * "foo[1-3]": { + * "numPartitions": 3, + * "replicationFactor": 1 + * } + * }, + * "inactiveTopics": { + * "foo[4-5]": { + * "numPartitions": 3, + * "replicationFactor": 1 + * } + * } + * } + * } + */ +public class ExternalCommandSpec extends TaskSpec { + private final String commandNode; + private final List command; + private final JsonNode workload; + private final Optional shutdownGracePeriodMs; + + @JsonCreator + public ExternalCommandSpec( + @JsonProperty("startMs") long startMs, + @JsonProperty("durationMs") long durationMs, + @JsonProperty("commandNode") String commandNode, + @JsonProperty("command") List command, + @JsonProperty("workload") JsonNode workload, + @JsonProperty("shutdownGracePeriodMs") Optional shutdownGracePeriodMs) { + super(startMs, durationMs); + this.commandNode = (commandNode == null) ? "" : commandNode; + this.command = (command == null) ? Collections.unmodifiableList(new ArrayList()) : command; + this.workload = (workload == null) ? NullNode.instance : workload; + this.shutdownGracePeriodMs = shutdownGracePeriodMs; + } + + @JsonProperty + public String commandNode() { + return commandNode; + } + + @JsonProperty + public List command() { + return command; + } + + @JsonProperty + public JsonNode workload() { + return workload; + } + + @JsonProperty + public Optional shutdownGracePeriodMs() { + return shutdownGracePeriodMs; + } + + @Override + public TaskController newController(String id) { + return topology -> Collections.singleton(commandNode); + } + + @Override + public TaskWorker newTaskWorker(String id) { + return new ExternalCommandWorker(id, this); + } +} diff --git a/trogdor/src/main/java/org/apache/kafka/trogdor/workload/ExternalCommandWorker.java b/trogdor/src/main/java/org/apache/kafka/trogdor/workload/ExternalCommandWorker.java new file mode 100644 index 0000000..acc195e --- /dev/null +++ b/trogdor/src/main/java/org/apache/kafka/trogdor/workload/ExternalCommandWorker.java @@ -0,0 +1,398 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.trogdor.workload; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.node.JsonNodeFactory; +import com.fasterxml.jackson.databind.node.NullNode; +import com.fasterxml.jackson.databind.node.ObjectNode; +import com.fasterxml.jackson.databind.node.TextNode; +import org.apache.kafka.common.internals.KafkaFutureImpl; +import org.apache.kafka.common.utils.ThreadUtils; +import org.apache.kafka.trogdor.common.JsonUtil; +import org.apache.kafka.trogdor.common.Platform; +import org.apache.kafka.trogdor.task.TaskWorker; +import org.apache.kafka.trogdor.task.WorkerStatusTracker; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.BufferedReader; +import java.io.IOException; + +import java.io.InputStreamReader; +import java.io.OutputStreamWriter; +import java.nio.charset.StandardCharsets; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.Optional; + +/** + * ExternalCommandWorker starts an external process to run a Trogdor command. + * + * The worker communicates with the external process over the standard input and output streams. + * + * When the process is first launched, ExternalCommandWorker will send a message on standard + * input describing the task ID and the workload. This message will not contain line breaks. + * It will have this JSON format: + * {"id":, "workload":} + * + * ExternalCommandWorker will log anything that the process writes to stderr, but will take + * no other action with it. + * + * If the process sends a single-line JSON object to stdout, ExternalCommandWorker will parse it. + * The JSON object can contain the following fields: + * - status: If the object contains this field, the status will be set to the given value. + * - error: If the object contains this field, the error will be set to the given value. + * Once an error occurs, we will try to terminate the process. + * - log: If the object contains this field, a log message will be issued with this text. + * + * Note that standard output is buffered by default. The subprocess may wish + * to flush it after writing its status JSON. This will ensure that the status + * is seen in a timely fashion. + * + * If the process sends a non-JSON line to stdout, the worker will log it. + * + * If the process exits, ExternalCommandWorker will finish. If the process exits unsuccessfully, + * this is considered an error. If the worker needs to stop the process, it will start by sending + * a SIGTERM. If this does not have the required effect, it will send a SIGKILL, once the shutdown + * grace period has elapsed. + */ +public class ExternalCommandWorker implements TaskWorker { + private static final Logger log = LoggerFactory.getLogger(ExternalCommandWorker.class); + + private static final int DEFAULT_SHUTDOWN_GRACE_PERIOD_MS = 5000; + + /** + * True only if the worker is running. + */ + private final AtomicBoolean running = new AtomicBoolean(false); + + enum TerminatorAction { + DESTROY, + DESTROY_FORCIBLY, + CLOSE + } + + /** + * A queue used to communicate with the signal sender thread. + */ + private final LinkedBlockingQueue terminatorActionQueue = new LinkedBlockingQueue<>(); + + /** + * The queue of objects to write to the process stdin. 
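+     * An empty Optional is used to tell the stdin writer thread to exit. As an illustration only
+     * (the id and workload below are hypothetical), the start message placed on this queue looks like:
+     * {"id":"external_1", "workload":{"class":"org.apache.kafka.trogdor.workload.ProduceBenchSpec"}}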
+ */ + private final LinkedBlockingQueue> stdinQueue = new LinkedBlockingQueue<>(); + + /** + * The task ID. + */ + private final String id; + + /** + * The command specification. + */ + private final ExternalCommandSpec spec; + + /** + * Tracks the worker status. + */ + private WorkerStatusTracker status; + + /** + * A future which should be completed when this worker is done. + */ + private KafkaFutureImpl doneFuture; + + /** + * The executor service for this worker. + */ + private ExecutorService executor; + + public ExternalCommandWorker(String id, ExternalCommandSpec spec) { + this.id = id; + this.spec = spec; + } + + @Override + public void start(Platform platform, WorkerStatusTracker status, + KafkaFutureImpl doneFuture) throws Exception { + if (!running.compareAndSet(false, true)) { + throw new IllegalStateException("ConsumeBenchWorker is already running."); + } + log.info("{}: Activating ExternalCommandWorker with {}", id, spec); + this.status = status; + this.doneFuture = doneFuture; + this.executor = Executors.newCachedThreadPool( + ThreadUtils.createThreadFactory("ExternalCommandWorkerThread%d", false)); + Process process = null; + try { + process = startProcess(); + } catch (Throwable t) { + log.error("{}: Unable to start process", id, t); + executor.shutdown(); + doneFuture.complete("Unable to start process: " + t.getMessage()); + return; + } + Future stdoutFuture = executor.submit(new StdoutMonitor(process)); + Future stderrFuture = executor.submit(new StderrMonitor(process)); + executor.submit(new StdinWriter(process)); + Future terminatorFuture = executor.submit(new Terminator(process)); + executor.submit(new ExitMonitor(process, stdoutFuture, stderrFuture, terminatorFuture)); + ObjectNode startMessage = new ObjectNode(JsonNodeFactory.instance); + startMessage.set("id", new TextNode(id)); + startMessage.set("workload", spec.workload()); + stdinQueue.add(Optional.of(startMessage)); + } + + private Process startProcess() throws Exception { + if (spec.command().isEmpty()) { + throw new RuntimeException("No command specified"); + } + ProcessBuilder bld = new ProcessBuilder(spec.command()); + Process process = bld.start(); + return process; + } + + private static JsonNode readObject(String line) { + JsonNode resp; + try { + resp = JsonUtil.JSON_SERDE.readTree(line); + } catch (IOException e) { + return NullNode.instance; + } + return resp; + } + + class StdoutMonitor implements Runnable { + private final Process process; + + StdoutMonitor(Process process) { + this.process = process; + } + + @Override + public void run() { + log.trace("{}: starting stdout monitor.", id); + try (BufferedReader br = new BufferedReader( + new InputStreamReader(process.getInputStream(), StandardCharsets.UTF_8))) { + String line; + while (true) { + try { + line = br.readLine(); + if (line == null) { + throw new IOException("EOF"); + } + } catch (IOException e) { + log.info("{}: can't read any more from stdout: {}", id, e.getMessage()); + return; + } + log.trace("{}: read line from stdin: {}", id, line); + JsonNode resp = readObject(line); + if (resp.has("status")) { + log.info("{}: New status: {}", id, resp.get("status").toString()); + status.update(resp.get("status")); + } + if (resp.has("log")) { + log.info("{}: (stdout): {}", id, resp.get("log").asText()); + } + if (resp.has("error")) { + String error = resp.get("error").asText(); + log.error("{}: error: {}", id, error); + doneFuture.complete(error); + } + } + } catch (Throwable e) { + log.info("{}: error reading from stdout.", id, e); + } + } + } 
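+    // Illustrative stdout lines that the monitor above would handle (the field values here are
+    // hypothetical, not taken from a real run):
+    //   {"status": {"messagesSent": 100}}  -> forwarded to the WorkerStatusTracker
+    //   {"log": "benchmark started"}       -> logged at INFO level
+    //   {"error": "out of disk space"}     -> logged and used to complete doneFuture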
+ + class StderrMonitor implements Runnable { + private final Process process; + + StderrMonitor(Process process) { + this.process = process; + } + + @Override + public void run() { + log.trace("{}: starting stderr monitor.", id); + try (BufferedReader br = new BufferedReader( + new InputStreamReader(process.getErrorStream(), StandardCharsets.UTF_8))) { + String line; + while (true) { + try { + line = br.readLine(); + if (line == null) { + throw new IOException("EOF"); + } + } catch (IOException e) { + log.info("{}: can't read any more from stderr: {}", id, e.getMessage()); + return; + } + log.error("{}: (stderr):{}", id, line); + } + } catch (Throwable e) { + log.info("{}: error reading from stderr.", id, e); + } + } + } + + class StdinWriter implements Runnable { + private final Process process; + + StdinWriter(Process process) { + this.process = process; + } + + @Override + public void run() { + OutputStreamWriter stdinWriter = new OutputStreamWriter( + process.getOutputStream(), StandardCharsets.UTF_8); + try { + while (true) { + log.info("{}: stdin writer ready.", id); + Optional node = stdinQueue.take(); + if (!node.isPresent()) { + log.trace("{}: StdinWriter terminating.", id); + return; + } + String inputString = JsonUtil.toJsonString(node.get()); + log.info("{}: writing to stdin: {}", id, inputString); + stdinWriter.write(inputString + "\n"); + stdinWriter.flush(); + } + } catch (IOException e) { + log.info("{}: can't write any more to stdin: {}", id, e.getMessage()); + } catch (Throwable e) { + log.info("{}: error writing to stdin.", id, e); + } finally { + try { + stdinWriter.close(); + } catch (IOException e) { + log.debug("{}: error closing stdinWriter: {}", id, e.getMessage()); + } + } + } + } + + class ExitMonitor implements Runnable { + private final Process process; + private final Future stdoutFuture; + private final Future stderrFuture; + private final Future terminatorFuture; + + ExitMonitor(Process process, Future stdoutFuture, Future stderrFuture, + Future terminatorFuture) { + this.process = process; + this.stdoutFuture = stdoutFuture; + this.stderrFuture = stderrFuture; + this.terminatorFuture = terminatorFuture; + } + + @Override + public void run() { + try { + int exitStatus = process.waitFor(); + log.info("{}: process exited with return code {}", id, exitStatus); + // Wait for the stdout and stderr monitors to exit. It's particularly important + // to wait for the stdout monitor to exit since there may be an error or status + // there that we haven't seen yet. + stdoutFuture.get(); + stderrFuture.get(); + // Try to complete doneFuture with an error status based on the exit code. Note + // that if doneFuture was already completed previously, this will have no effect. + if (exitStatus == 0) { + doneFuture.complete(""); + } else { + doneFuture.complete("exited with return code " + exitStatus); + } + // Tell the StdinWriter thread to exit. + stdinQueue.add(Optional.empty()); + // Tell the shutdown manager thread to exit. + terminatorActionQueue.add(TerminatorAction.CLOSE); + terminatorFuture.get(); + executor.shutdown(); + } catch (Throwable e) { + log.error("{}: ExitMonitor error", id, e); + doneFuture.complete("ExitMonitor error: " + e.getMessage()); + } + } + } + + /** + * The thread which manages terminating the child process. 
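+     * On a normal stop(), a DESTROY action is queued, which closes the process streams and calls
+     * Process#destroy() (typically SIGTERM). If the worker's executor has not shut down once the
+     * shutdown grace period elapses (5000 ms unless the spec overrides it), a DESTROY_FORCIBLY
+     * action (Process#destroyForcibly(), typically SIGKILL) is queued as well.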
+ */ + class Terminator implements Runnable { + private final Process process; + + Terminator(Process process) { + this.process = process; + } + + @Override + public void run() { + try { + while (true) { + switch (terminatorActionQueue.take()) { + case DESTROY: + log.info("{}: destroying process", id); + process.getInputStream().close(); + process.getErrorStream().close(); + process.destroy(); + break; + case DESTROY_FORCIBLY: + log.info("{}: forcibly destroying process", id); + process.getInputStream().close(); + process.getErrorStream().close(); + process.destroyForcibly(); + break; + case CLOSE: + log.trace("{}: closing Terminator thread.", id); + return; + } + } + } catch (Throwable e) { + log.error("{}: Terminator error", id, e); + doneFuture.complete("Terminator error: " + e.getMessage()); + } + } + } + + @Override + public void stop(Platform platform) throws Exception { + if (!running.compareAndSet(true, false)) { + throw new IllegalStateException("ExternalCommandWorker is not running."); + } + log.info("{}: Deactivating ExternalCommandWorker.", id); + terminatorActionQueue.add(TerminatorAction.DESTROY); + int shutdownGracePeriodMs = spec.shutdownGracePeriodMs().isPresent() ? + spec.shutdownGracePeriodMs().get() : DEFAULT_SHUTDOWN_GRACE_PERIOD_MS; + if (!executor.awaitTermination(shutdownGracePeriodMs, TimeUnit.MILLISECONDS)) { + terminatorActionQueue.add(TerminatorAction.DESTROY_FORCIBLY); + executor.awaitTermination(1, TimeUnit.DAYS); + } + this.status = null; + this.doneFuture = null; + this.executor = null; + } +} diff --git a/trogdor/src/main/java/org/apache/kafka/trogdor/workload/FlushGenerator.java b/trogdor/src/main/java/org/apache/kafka/trogdor/workload/FlushGenerator.java new file mode 100644 index 0000000..8c0ae62 --- /dev/null +++ b/trogdor/src/main/java/org/apache/kafka/trogdor/workload/FlushGenerator.java @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.trogdor.workload; + +import com.fasterxml.jackson.annotation.JsonSubTypes; +import com.fasterxml.jackson.annotation.JsonTypeInfo; +import org.apache.kafka.clients.producer.KafkaProducer; + +/** + * This interface is used to facilitate flushing the KafkaProducers on a cadence specified by the user. + * + * Currently there are 3 flushing methods: + * + * * Disabled, by not specifying this parameter. + * * `constant` will use `ConstantFlushGenerator` to keep the number of messages per batch constant. + * * `gaussian` will use `GaussianFlushGenerator` to vary the number of messages per batch on a normal distribution. + * + * Please see the implementation classes for more details. 
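+ *
+ * As a sketch of the expected configuration, the constant variant is described by JSON of the same
+ * shape as the example in ConstantFlushGenerator (the field name under which a workload spec embeds
+ * this object depends on that spec):
+ *
+ * {
+ *   "type": "constant",
+ *   "messagesPerFlush": 16
+ * }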
+ */
+
+@JsonTypeInfo(use = JsonTypeInfo.Id.NAME,
+        include = JsonTypeInfo.As.PROPERTY,
+        property = "type")
+@JsonSubTypes(value = {
+        @JsonSubTypes.Type(value = ConstantFlushGenerator.class, name = "constant"),
+        @JsonSubTypes.Type(value = GaussianFlushGenerator.class, name = "gaussian")
+    })
+public interface FlushGenerator {
+    void increment(KafkaProducer<?, ?> producer);
+}
diff --git a/trogdor/src/main/java/org/apache/kafka/trogdor/workload/GaussianFlushGenerator.java b/trogdor/src/main/java/org/apache/kafka/trogdor/workload/GaussianFlushGenerator.java
new file mode 100644
index 0000000..eb6845e
--- /dev/null
+++ b/trogdor/src/main/java/org/apache/kafka/trogdor/workload/GaussianFlushGenerator.java
@@ -0,0 +1,100 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kafka.trogdor.workload;
+
+import com.fasterxml.jackson.annotation.JsonCreator;
+import com.fasterxml.jackson.annotation.JsonProperty;
+import org.apache.kafka.clients.producer.KafkaProducer;
+import org.apache.kafka.common.errors.InterruptException;
+import java.util.Random;
+
+/**
+ * This generator will flush the producer after a specific number of messages, determined by a gaussian distribution.
+ * This is useful to simulate a specific number of messages in a batch regardless of the message size, since batch
+ * flushing is not exposed in the KafkaProducer.
+ *
+ * WARNING: This does not directly control when the KafkaProducer will batch; it only makes a best effort. It also
+ * cannot tell when a KafkaProducer batch is closed. If the KafkaProducer sends a batch before this executes, this
+ * generator will continue to execute on its own cadence. To alleviate this, set `linger.ms` to allow messages to be
+ * generated up to your upper limit threshold, and set `batch.size` to allow for all of these messages.
+ *
+ * Here is an example spec:
+ *
+ * {
+ *    "type": "gaussian",
+ *    "messagesPerFlushAverage": 16,
+ *    "messagesPerFlushDeviation": 4
+ * }
+ *
+ * This example will flush the producer on average every 16 messages, assuming `linger.ms` and `batch.size` allow for
+ * it. That average changes based on a normal distribution after each flush:
+ *
+ * An average of the flushes will be at 16 messages.
+ * ~68% of the flushes are between 12 and 20 messages.
+ * ~95% of the flushes are between 8 and 24 messages.
+ * ~99% of the flushes are between 4 and 28 messages.
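// The @JsonTypeInfo/@JsonSubTypes annotations on FlushGenerator are what let a task spec pick a
// flush strategy by its "type" name. A rough illustration of that wiring with a plain Jackson
// ObjectMapper (hedged: Trogdor routes this through its own JSON serde, this only demonstrates
// the polymorphic deserialization mechanism):
import com.fasterxml.jackson.databind.ObjectMapper;
import org.apache.kafka.trogdor.workload.FlushGenerator;

public class FlushGeneratorJsonExample {
    public static void main(String[] args) throws Exception {
        String json = "{\"type\": \"gaussian\", "
            + "\"messagesPerFlushAverage\": 16, "
            + "\"messagesPerFlushDeviation\": 4}";
        FlushGenerator generator = new ObjectMapper().readValue(json, FlushGenerator.class);
        System.out.println(generator.getClass().getSimpleName());   // GaussianFlushGenerator
    }
}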
+ */ + +public class GaussianFlushGenerator implements FlushGenerator { + private final int messagesPerFlushAverage; + private final double messagesPerFlushDeviation; + + private final Random random = new Random(); + + private int messageTracker = 0; + private int flushSize = 0; + + @JsonCreator + public GaussianFlushGenerator(@JsonProperty("messagesPerFlushAverage") int messagesPerFlushAverage, + @JsonProperty("messagesPerFlushDeviation") double messagesPerFlushDeviation) { + this.messagesPerFlushAverage = messagesPerFlushAverage; + this.messagesPerFlushDeviation = messagesPerFlushDeviation; + calculateFlushSize(); + } + + @JsonProperty + public int messagesPerFlushAverage() { + return messagesPerFlushAverage; + } + + @JsonProperty + public double messagesPerFlushDeviation() { + return messagesPerFlushDeviation; + } + + private synchronized void calculateFlushSize() { + flushSize = Math.max((int) (random.nextGaussian() * messagesPerFlushDeviation) + messagesPerFlushAverage, 1); + messageTracker = 0; + } + + @Override + public synchronized void increment(KafkaProducer producer) { + // Increment the message tracker. + messageTracker += 1; + + // Compare the tracked message count with the throttle limits. + if (messageTracker >= flushSize) { + try { + producer.flush(); + } catch (InterruptException e) { + // Ignore flush interruption exceptions. + } + calculateFlushSize(); + } + } +} diff --git a/trogdor/src/main/java/org/apache/kafka/trogdor/workload/GaussianThroughputGenerator.java b/trogdor/src/main/java/org/apache/kafka/trogdor/workload/GaussianThroughputGenerator.java new file mode 100644 index 0000000..a77298f --- /dev/null +++ b/trogdor/src/main/java/org/apache/kafka/trogdor/workload/GaussianThroughputGenerator.java @@ -0,0 +1,149 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.trogdor.workload; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import org.apache.kafka.common.utils.Time; +import java.util.Random; + +/* + * This throughput generator configures throughput with a gaussian normal distribution on a per-window basis. You can + * specify how many windows to keep the throughput at the rate before changing. All traffic will follow a gaussian + * distribution centered around `messagesPerWindowAverage` with a deviation of `messagesPerWindowDeviation`. + * + * The lower the window size, the smoother the traffic will be. Using a 100ms window offers no noticeable spikes in + * traffic while still being long enough to avoid too much overhead. 
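// To reason about a spec for this generator it helps to convert the per-window settings into
// per-second terms: average throughput is messagesPerWindowAverage * (1000 / windowSizeMs)
// messages per second, and a new rate is drawn every windowsUntilRateChange * windowSizeMs
// milliseconds. A small illustrative helper for that arithmetic, matching the example spec
// shown below:
public class GaussianThroughputMath {
    static double averageMessagesPerSec(int messagesPerWindowAverage, long windowSizeMs) {
        return messagesPerWindowAverage * (1000.0 / windowSizeMs);
    }

    static long msBetweenRateChanges(int windowsUntilRateChange, long windowSizeMs) {
        return windowsUntilRateChange * windowSizeMs;
    }

    public static void main(String[] args) {
        // 50 messages per 100 ms window => 500 messages/sec, with the rate re-drawn every 10 s.
        System.out.println(averageMessagesPerSec(50, 100));    // 500.0
        System.out.println(msBetweenRateChanges(100, 100));    // 10000
    }
}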
+ * + * Here is an example spec: + * + * { + * "type": "gaussian", + * "messagesPerWindowAverage": 50, + * "messagesPerWindowDeviation": 5, + * "windowsUntilRateChange": 100, + * "windowSizeMs": 100 + * } + * + * This will produce a workload that runs on average 500 messages per second, however that speed will change every 10 + * seconds due to the `windowSizeMs * windowsUntilRateChange` parameters. The throughput will have the following + * normal distribution: + * + * An average of the throughput windows of 500 messages per second. + * ~68% of the throughput windows are between 450 and 550 messages per second. + * ~95% of the throughput windows are between 400 and 600 messages per second. + * ~99% of the throughput windows are between 350 and 650 messages per second. + * + */ + +public class GaussianThroughputGenerator implements ThroughputGenerator { + private final int messagesPerWindowAverage; + private final double messagesPerWindowDeviation; + private final int windowsUntilRateChange; + private final long windowSizeMs; + + private final Random random = new Random(); + + private long nextWindowStarts = 0; + private int messageTracker = 0; + private int windowTracker = 0; + private int throttleMessages = 0; + + @JsonCreator + public GaussianThroughputGenerator(@JsonProperty("messagesPerWindowAverage") int messagesPerWindowAverage, + @JsonProperty("messagesPerWindowDeviation") double messagesPerWindowDeviation, + @JsonProperty("windowsUntilRateChange") int windowsUntilRateChange, + @JsonProperty("windowSizeMs") long windowSizeMs) { + // Calculate the default values. + if (windowSizeMs <= 0) { + windowSizeMs = 100; + } + this.windowSizeMs = windowSizeMs; + this.messagesPerWindowAverage = messagesPerWindowAverage; + this.messagesPerWindowDeviation = messagesPerWindowDeviation; + this.windowsUntilRateChange = windowsUntilRateChange; + + // Calculate the first window. + calculateNextWindow(true); + } + + @JsonProperty + public int messagesPerWindowAverage() { + return messagesPerWindowAverage; + } + + @JsonProperty + public double messagesPerWindowDeviation() { + return messagesPerWindowDeviation; + } + + @JsonProperty + public long windowsUntilRateChange() { + return windowsUntilRateChange; + } + + @JsonProperty + public long windowSizeMs() { + return windowSizeMs; + } + + private synchronized void calculateNextWindow(boolean force) { + // Reset the message count. + messageTracker = 0; + + // Calculate the next window start time. + long now = Time.SYSTEM.milliseconds(); + if (nextWindowStarts > 0) { + while (nextWindowStarts < now) { + nextWindowStarts += windowSizeMs; + } + } else { + nextWindowStarts = now + windowSizeMs; + } + + // Check the windows between rate changes. + if ((windowTracker > windowsUntilRateChange) || force) { + windowTracker = 0; + + // Calculate the number of messages allowed in this window using a normal distribution. + // The formula is: Messages = Gaussian * Deviation + Average + throttleMessages = Math.max((int) (random.nextGaussian() * messagesPerWindowDeviation) + messagesPerWindowAverage, 1); + } + windowTracker += 1; + } + + @Override + public synchronized void throttle() throws InterruptedException { + // Calculate the next window if we've moved beyond the current one. + if (Time.SYSTEM.milliseconds() >= nextWindowStarts) { + calculateNextWindow(false); + } + + // Increment the message tracker. + messageTracker += 1; + + // Compare the tracked message count with the throttle limits. 
+ if (messageTracker >= throttleMessages) { + + // Wait the difference in time between now and when the next window starts. + while (nextWindowStarts > Time.SYSTEM.milliseconds()) { + wait(nextWindowStarts - Time.SYSTEM.milliseconds()); + } + } + } +} diff --git a/trogdor/src/main/java/org/apache/kafka/trogdor/workload/GaussianTimestampConstantPayloadGenerator.java b/trogdor/src/main/java/org/apache/kafka/trogdor/workload/GaussianTimestampConstantPayloadGenerator.java new file mode 100644 index 0000000..8660ed3 --- /dev/null +++ b/trogdor/src/main/java/org/apache/kafka/trogdor/workload/GaussianTimestampConstantPayloadGenerator.java @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.trogdor.workload; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import org.apache.kafka.common.utils.Time; + +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.util.Random; + +/** + * This class behaves identically to TimestampConstantPayloadGenerator, except the message size follows a gaussian + * distribution. + * + * This should be used in conjunction with TimestampRecordProcessor in the Consumer to measure true end-to-end latency + * of a system. + * + * `messageSizeAverage` - The average size in bytes of each message. + * `messageSizeDeviation` - The standard deviation to use when calculating message size. + * `messagesUntilSizeChange` - The number of messages to keep at the same size. + * + * Here is an example spec: + * + * { + * "type": "gaussianTimestampConstant", + * "messageSizeAverage": 512, + * "messageSizeDeviation": 100, + * "messagesUntilSizeChange": 100 + * } + * + * This will generate messages on a gaussian distribution with an average size each 512-bytes. The message sizes will + * have a standard deviation of 100 bytes, and the size will only change every 100 messages. The distribution of + * messages will be as follows: + * + * The average size of the messages are 512 bytes. 
+ * ~68% of the messages are between 412 and 612 bytes + * ~95% of the messages are between 312 and 712 bytes + * ~99% of the messages are between 212 and 812 bytes + */ + +public class GaussianTimestampConstantPayloadGenerator implements PayloadGenerator { + private final int messageSizeAverage; + private final double messageSizeDeviation; + private final int messagesUntilSizeChange; + private final long seed; + + private final Random random = new Random(); + private final ByteBuffer buffer; + + private int messageTracker = 0; + private int messageSize = 0; + + @JsonCreator + public GaussianTimestampConstantPayloadGenerator(@JsonProperty("messageSizeAverage") int messageSizeAverage, + @JsonProperty("messageSizeDeviation") double messageSizeDeviation, + @JsonProperty("messagesUntilSizeChange") int messagesUntilSizeChange, + @JsonProperty("seed") long seed) { + this.messageSizeAverage = messageSizeAverage; + this.messageSizeDeviation = messageSizeDeviation; + this.seed = seed; + this.messagesUntilSizeChange = messagesUntilSizeChange; + buffer = ByteBuffer.allocate(Long.BYTES); + buffer.order(ByteOrder.LITTLE_ENDIAN); + } + + @JsonProperty + public int messageSizeAverage() { + return messageSizeAverage; + } + + @JsonProperty + public double messageSizeDeviation() { + return messageSizeDeviation; + } + + @JsonProperty + public int messagesUntilSizeChange() { + return messagesUntilSizeChange; + } + + @JsonProperty + public long seed() { + return seed; + } + + @Override + public synchronized byte[] generate(long position) { + // Make the random number generator deterministic for unit tests. + random.setSeed(seed + position); + + // Calculate the next message size based on a gaussian distribution. + if ((messageSize == 0) || (messageTracker >= messagesUntilSizeChange)) { + messageTracker = 0; + messageSize = Math.max((int) (random.nextGaussian() * messageSizeDeviation) + messageSizeAverage, Long.BYTES); + } + messageTracker += 1; + + // Generate the byte array before the timestamp generation. + byte[] result = new byte[messageSize]; + + // Do the timestamp generation as the very last task. + buffer.clear(); + buffer.putLong(Time.SYSTEM.milliseconds()); + buffer.rewind(); + System.arraycopy(buffer.array(), 0, result, 0, Long.BYTES); + return result; + } +} diff --git a/trogdor/src/main/java/org/apache/kafka/trogdor/workload/GaussianTimestampRandomPayloadGenerator.java b/trogdor/src/main/java/org/apache/kafka/trogdor/workload/GaussianTimestampRandomPayloadGenerator.java new file mode 100644 index 0000000..48261a4 --- /dev/null +++ b/trogdor/src/main/java/org/apache/kafka/trogdor/workload/GaussianTimestampRandomPayloadGenerator.java @@ -0,0 +1,127 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.trogdor.workload; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import org.apache.kafka.common.utils.Time; + +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.util.Random; + +/** + * This class behaves identically to TimestampRandomPayloadGenerator, except the message size follows a gaussian + * distribution. + * + * This should be used in conjunction with TimestampRecordProcessor in the Consumer to measure true end-to-end latency + * of a system. + * + * `messageSizeAverage` - The average size in bytes of each message. + * `messageSizeDeviation` - The standard deviation to use when calculating message size. + * `messagesUntilSizeChange` - The number of messages to keep at the same size. + * `seed` - Used to initialize Random() to remove some non-determinism. + * + * Here is an example spec: + * + * { + * "type": "gaussianTimestampRandom", + * "messageSizeAverage": 512, + * "messageSizeDeviation": 100, + * "messagesUntilSizeChange": 100 + * } + * + * This will generate messages on a gaussian distribution with an average size each 512-bytes. The message sizes will + * have a standard deviation of 100 bytes, and the size will only change every 100 messages. The distribution of + * messages will be as follows: + * + * The average size of the messages are 512 bytes. + * ~68% of the messages are between 412 and 612 bytes + * ~95% of the messages are between 312 and 712 bytes + * ~99% of the messages are between 212 and 812 bytes + */ + +public class GaussianTimestampRandomPayloadGenerator implements PayloadGenerator { + private final int messageSizeAverage; + private final double messageSizeDeviation; + private final int messagesUntilSizeChange; + private final long seed; + + private final Random random = new Random(); + private final ByteBuffer buffer; + + private int messageTracker = 0; + private int messageSize = 0; + + @JsonCreator + public GaussianTimestampRandomPayloadGenerator(@JsonProperty("messageSizeAverage") int messageSizeAverage, + @JsonProperty("messageSizeDeviation") double messageSizeDeviation, + @JsonProperty("messagesUntilSizeChange") int messagesUntilSizeChange, + @JsonProperty("seed") long seed) { + this.messageSizeAverage = messageSizeAverage; + this.messageSizeDeviation = messageSizeDeviation; + this.seed = seed; + this.messagesUntilSizeChange = messagesUntilSizeChange; + buffer = ByteBuffer.allocate(Long.BYTES); + buffer.order(ByteOrder.LITTLE_ENDIAN); + } + + @JsonProperty + public int messageSizeAverage() { + return messageSizeAverage; + } + + @JsonProperty + public double messageSizeDeviation() { + return messageSizeDeviation; + } + + @JsonProperty + public int messagesUntilSizeChange() { + return messagesUntilSizeChange; + } + + @JsonProperty + public long seed() { + return seed; + } + + @Override + public synchronized byte[] generate(long position) { + // Make the random number generator deterministic for unit tests. + random.setSeed(seed + position); + + // Calculate the next message size based on a gaussian distribution. + if ((messageSize == 0) || (messageTracker >= messagesUntilSizeChange)) { + messageTracker = 0; + messageSize = Math.max((int) (random.nextGaussian() * messageSizeDeviation) + messageSizeAverage, Long.BYTES); + } + messageTracker += 1; + + // Generate out of order to prevent inclusion of random number generation in latency numbers. 
+ byte[] result = new byte[messageSize]; + random.nextBytes(result); + + // Do the timestamp generation as the very last task. + buffer.clear(); + buffer.putLong(Time.SYSTEM.milliseconds()); + buffer.rewind(); + System.arraycopy(buffer.array(), 0, result, 0, Long.BYTES); + return result; + } +} diff --git a/trogdor/src/main/java/org/apache/kafka/trogdor/workload/Histogram.java b/trogdor/src/main/java/org/apache/kafka/trogdor/workload/Histogram.java new file mode 100644 index 0000000..cee2b4a --- /dev/null +++ b/trogdor/src/main/java/org/apache/kafka/trogdor/workload/Histogram.java @@ -0,0 +1,207 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.trogdor.workload; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +/** + * A histogram that can easily find the average, median etc of a large number of samples in a + * restricted domain. + */ +public class Histogram { + private final int[] counts; + + private final Logger log = LoggerFactory.getLogger(Histogram.class); + + public Histogram(int maxValue) { + this.counts = new int[maxValue + 1]; + } + + /** + * Add a new value to the histogram. + * + * Note that the value will be clipped to the maximum value available in the Histogram instance. + * So if the histogram has 100 buckets, inserting 101 will increment the last bucket. + */ + public void add(int value) { + if (value < 0) { + throw new RuntimeException("invalid negative value."); + } + if (value >= counts.length) { + value = counts.length - 1; + } + synchronized (this) { + int curCount = counts[value]; + if (curCount < Integer.MAX_VALUE) { + counts[value] = counts[value] + 1; + } + } + } + + /** + * Add a new value to the histogram. + * + * Note that the value will be clipped to the maximum value available in the Histogram instance. + * This method is provided for convenience, but handles the same numeric range as the method which + * takes an int. + */ + public void add(long value) { + if (value > Integer.MAX_VALUE) { + add(Integer.MAX_VALUE); + } else if (value < Integer.MIN_VALUE) { + add(Integer.MIN_VALUE); + } else { + add((int) value); + } + } + + public static class Summary { + /** + * The total number of samples. + */ + private final long numSamples; + + /** + * The average of all samples. + */ + private final float average; + + /** + * Percentile information. + * + * percentile(fraction=0.99) will have a value which is greater than or equal to 99% + * of the samples. percentile(fraction=0.5) is the median sample. And so forth. 
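// Tying the pieces together (illustrative sketch only): the timestamp payload generators above
// write the send time into the first Long.BYTES bytes of the value, little-endian, so a
// consumer-side processor can recover end-to-end latency and feed it into a Histogram like the
// one defined here. This is not the actual TimestampRecordProcessor, just a sketch of the flow.
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import org.apache.kafka.trogdor.workload.Histogram;

public class LatencyHistogramExample {
    public static void main(String[] args) {
        Histogram histogram = new Histogram(10000);   // buckets cover 0..10000 ms

        // Pretend this value arrived in a Kafka record produced by a timestamp payload generator,
        // written 42 ms before we consumed it.
        byte[] value = new byte[512];
        ByteBuffer.wrap(value).order(ByteOrder.LITTLE_ENDIAN)
            .putLong(System.currentTimeMillis() - 42);

        long sentMs = ByteBuffer.wrap(value).order(ByteOrder.LITTLE_ENDIAN).getLong();
        histogram.add(System.currentTimeMillis() - sentMs);

        Histogram.Summary summary = histogram.summarize(new float[] {0.5f, 0.95f, 0.99f});
        System.out.printf("samples=%d avg=%.1f p99=%d ms%n",
            summary.numSamples(), summary.average(), summary.percentiles().get(2).value());
    }
}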
+ */ + private final List percentiles; + + Summary(long numSamples, float average, List percentiles) { + this.numSamples = numSamples; + this.average = average; + this.percentiles = percentiles; + } + + public long numSamples() { + return numSamples; + } + + public float average() { + return average; + } + + public List percentiles() { + return percentiles; + } + } + + /** + * Information about a percentile. + */ + public static class PercentileSummary { + /** + * The fraction of samples which are less than or equal to the value of this percentile. + */ + private final float fraction; + + /** + * The value of this percentile. + */ + private final int value; + + PercentileSummary(float fraction, int value) { + this.fraction = fraction; + this.value = value; + } + + public float fraction() { + return fraction; + } + + public int value() { + return value; + } + } + + public Summary summarize() { + return summarize(new float[0]); + } + + public Summary summarize(float[] percentiles) { + int[] countsCopy = new int[counts.length]; + synchronized (this) { + System.arraycopy(counts, 0, countsCopy, 0, counts.length); + } + // Verify that the percentiles array is sorted and positive. + float prev = 0f; + for (int i = 0; i < percentiles.length; i++) { + if (percentiles[i] < prev) { + throw new RuntimeException("Invalid percentiles fraction array. Bad element " + + percentiles[i] + ". The array must be sorted and non-negative."); + } + if (percentiles[i] > 1.0f) { + throw new RuntimeException("Invalid percentiles fraction array. Bad element " + + percentiles[i] + ". Elements must be less than or equal to 1."); + } + } + // Find out how many total samples we have, and what the average is. + long numSamples = 0; + float total = 0f; + for (int i = 0; i < countsCopy.length; i++) { + long count = countsCopy[i]; + numSamples = numSamples + count; + total = total + (i * count); + } + float average = (numSamples == 0) ? 0.0f : (total / numSamples); + + List percentileSummaries = + summarizePercentiles(countsCopy, percentiles, numSamples); + return new Summary(numSamples, average, percentileSummaries); + } + + private List summarizePercentiles(int[] countsCopy, float[] percentiles, + long numSamples) { + if (percentiles.length == 0) { + return Collections.emptyList(); + } + List summaries = new ArrayList<>(percentiles.length); + int i = 0, j = 0; + long seen = 0, next = (long) (numSamples * percentiles[0]); + while (true) { + if (i == countsCopy.length - 1) { + for (; j < percentiles.length; j++) { + summaries.add(new PercentileSummary(percentiles[j], i)); + } + return summaries; + } + seen += countsCopy[i]; + while (seen >= next) { + summaries.add(new PercentileSummary(percentiles[j], i)); + j++; + if (j == percentiles.length) { + return summaries; + } + next = (long) (numSamples * percentiles[j]); + } + i++; + } + } +} diff --git a/trogdor/src/main/java/org/apache/kafka/trogdor/workload/NullPayloadGenerator.java b/trogdor/src/main/java/org/apache/kafka/trogdor/workload/NullPayloadGenerator.java new file mode 100644 index 0000000..e9799c0 --- /dev/null +++ b/trogdor/src/main/java/org/apache/kafka/trogdor/workload/NullPayloadGenerator.java @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.trogdor.workload; + +import com.fasterxml.jackson.annotation.JsonCreator; + +/** + * A PayloadGenerator which always generates a null payload. + */ +public class NullPayloadGenerator implements PayloadGenerator { + @JsonCreator + public NullPayloadGenerator() { + } + + @Override + public byte[] generate(long position) { + return null; + } +} diff --git a/trogdor/src/main/java/org/apache/kafka/trogdor/workload/PartitionsSpec.java b/trogdor/src/main/java/org/apache/kafka/trogdor/workload/PartitionsSpec.java new file mode 100644 index 0000000..c1dc7c6 --- /dev/null +++ b/trogdor/src/main/java/org/apache/kafka/trogdor/workload/PartitionsSpec.java @@ -0,0 +1,121 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.trogdor.workload; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import org.apache.kafka.clients.admin.NewTopic; +import org.apache.kafka.trogdor.rest.Message; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; + +/** + * Describes some partitions. + */ +public class PartitionsSpec extends Message { + private final static short DEFAULT_REPLICATION_FACTOR = 3; + private final static short DEFAULT_NUM_PARTITIONS = 1; + + private final int numPartitions; + private final short replicationFactor; + private final Map> partitionAssignments; + private final Map configs; + + @JsonCreator + public PartitionsSpec(@JsonProperty("numPartitions") int numPartitions, + @JsonProperty("replicationFactor") short replicationFactor, + @JsonProperty("partitionAssignments") Map> partitionAssignments, + @JsonProperty("configs") Map configs) { + this.numPartitions = numPartitions; + this.replicationFactor = replicationFactor; + HashMap> partMap = new HashMap<>(); + if (partitionAssignments != null) { + for (Entry> entry : partitionAssignments.entrySet()) { + int partition = entry.getKey() == null ? 
0 : entry.getKey(); + ArrayList assignments = new ArrayList<>(); + if (entry.getValue() != null) { + for (Integer brokerId : entry.getValue()) { + assignments.add(brokerId == null ? Integer.valueOf(0) : brokerId); + } + } + partMap.put(partition, Collections.unmodifiableList(assignments)); + } + } + this.partitionAssignments = Collections.unmodifiableMap(partMap); + if (configs == null) { + this.configs = Collections.emptyMap(); + } else { + this.configs = Collections.unmodifiableMap(new HashMap<>(configs)); + } + } + + @JsonProperty + public int numPartitions() { + return numPartitions; + } + + public List partitionNumbers() { + if (partitionAssignments.isEmpty()) { + ArrayList partitionNumbers = new ArrayList<>(); + int effectiveNumPartitions = numPartitions <= 0 ? DEFAULT_NUM_PARTITIONS : numPartitions; + for (int i = 0; i < effectiveNumPartitions; i++) { + partitionNumbers.add(i); + } + return partitionNumbers; + } else { + return new ArrayList<>(partitionAssignments.keySet()); + } + } + + @JsonProperty + public short replicationFactor() { + return replicationFactor; + } + + @JsonProperty + public Map> partitionAssignments() { + return partitionAssignments; + } + + @JsonProperty + public Map configs() { + return configs; + } + + public NewTopic newTopic(String topicName) { + NewTopic newTopic; + if (partitionAssignments.isEmpty()) { + int effectiveNumPartitions = numPartitions <= 0 ? + DEFAULT_NUM_PARTITIONS : numPartitions; + short effectiveReplicationFactor = replicationFactor <= 0 ? + DEFAULT_REPLICATION_FACTOR : replicationFactor; + newTopic = new NewTopic(topicName, effectiveNumPartitions, effectiveReplicationFactor); + } else { + newTopic = new NewTopic(topicName, partitionAssignments); + } + if (!configs.isEmpty()) { + newTopic.configs(configs); + } + return newTopic; + } +} diff --git a/trogdor/src/main/java/org/apache/kafka/trogdor/workload/PayloadGenerator.java b/trogdor/src/main/java/org/apache/kafka/trogdor/workload/PayloadGenerator.java new file mode 100644 index 0000000..6d7393d --- /dev/null +++ b/trogdor/src/main/java/org/apache/kafka/trogdor/workload/PayloadGenerator.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.trogdor.workload; + +import com.fasterxml.jackson.annotation.JsonSubTypes; +import com.fasterxml.jackson.annotation.JsonTypeInfo; + +/** + * Generates byte arrays based on a position argument. + * + * The array generated at a given position should be the same no matter how many + * times generate() is invoked. PayloadGenerator instances should be immutable + * and thread-safe. 
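// A minimal sketch of a custom PayloadGenerator that honours the contract described above: the
// output depends only on the position argument, so generate(17) returns an identical array every
// time it is called. This class is illustrative and not part of the patch; the real
// implementations are the ones registered in the @JsonSubTypes annotation that follows.
package org.apache.kafka.trogdor.workload;

import java.nio.ByteBuffer;

public class PositionEchoPayloadGenerator implements PayloadGenerator {
    private final int size;

    public PositionEchoPayloadGenerator(int size) {
        this.size = Math.max(size, Long.BYTES);
    }

    @Override
    public byte[] generate(long position) {
        byte[] payload = new byte[size];
        // Embed the position at the front; the remaining bytes stay zeroed, so the result is
        // fully determined by the position and the generator stays stateless and thread-safe.
        ByteBuffer.wrap(payload).putLong(position);
        return payload;
    }
}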
+ */ +@JsonTypeInfo(use = JsonTypeInfo.Id.NAME, + include = JsonTypeInfo.As.PROPERTY, + property = "type") +@JsonSubTypes(value = { + @JsonSubTypes.Type(value = ConstantPayloadGenerator.class, name = "constant"), + @JsonSubTypes.Type(value = SequentialPayloadGenerator.class, name = "sequential"), + @JsonSubTypes.Type(value = UniformRandomPayloadGenerator.class, name = "uniformRandom"), + @JsonSubTypes.Type(value = NullPayloadGenerator.class, name = "null"), + @JsonSubTypes.Type(value = RandomComponentPayloadGenerator.class, name = "randomComponent"), + @JsonSubTypes.Type(value = TimestampRandomPayloadGenerator.class, name = "timestampRandom"), + @JsonSubTypes.Type(value = TimestampConstantPayloadGenerator.class, name = "timestampConstant"), + @JsonSubTypes.Type(value = GaussianTimestampRandomPayloadGenerator.class, name = "gaussianTimestampRandom"), + @JsonSubTypes.Type(value = GaussianTimestampConstantPayloadGenerator.class, name = "gaussianTimestampConstant") + }) +public interface PayloadGenerator { + /** + * Generate a payload. + * + * @param position The position to use to generate the payload + * + * @return A new array object containing the payload. + */ + byte[] generate(long position); +} diff --git a/trogdor/src/main/java/org/apache/kafka/trogdor/workload/PayloadIterator.java b/trogdor/src/main/java/org/apache/kafka/trogdor/workload/PayloadIterator.java new file mode 100644 index 0000000..a5f3bae --- /dev/null +++ b/trogdor/src/main/java/org/apache/kafka/trogdor/workload/PayloadIterator.java @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.trogdor.workload; + +import java.util.Iterator; + +/** + * An iterator which wraps a PayloadGenerator. + */ +public final class PayloadIterator implements Iterator { + private final PayloadGenerator generator; + private long position = 0; + + public PayloadIterator(PayloadGenerator generator) { + this.generator = generator; + } + + @Override + public boolean hasNext() { + return true; + } + + @Override + public synchronized byte[] next() { + return generator.generate(position++); + } + + @Override + public void remove() { + throw new UnsupportedOperationException(); + } + + public synchronized void seek(long position) { + this.position = position; + } + + public synchronized long position() { + return this.position; + } +} diff --git a/trogdor/src/main/java/org/apache/kafka/trogdor/workload/PayloadKeyType.java b/trogdor/src/main/java/org/apache/kafka/trogdor/workload/PayloadKeyType.java new file mode 100644 index 0000000..3ed98cd --- /dev/null +++ b/trogdor/src/main/java/org/apache/kafka/trogdor/workload/PayloadKeyType.java @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.trogdor.workload; + +/** + * Describes a key in producer payload + */ +public enum PayloadKeyType { + // null key + KEY_NULL(0), + // fixed size key containing a long integer representing a message index (i.e., position of + // the payload generator) + KEY_MESSAGE_INDEX(8); + + private final int maxSizeInBytes; + + PayloadKeyType(int maxSizeInBytes) { + this.maxSizeInBytes = maxSizeInBytes; + } + + public int maxSizeInBytes() { + return maxSizeInBytes; + } +} diff --git a/trogdor/src/main/java/org/apache/kafka/trogdor/workload/ProduceBenchSpec.java b/trogdor/src/main/java/org/apache/kafka/trogdor/workload/ProduceBenchSpec.java new file mode 100644 index 0000000..9f6a907 --- /dev/null +++ b/trogdor/src/main/java/org/apache/kafka/trogdor/workload/ProduceBenchSpec.java @@ -0,0 +1,194 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.trogdor.workload; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import org.apache.kafka.trogdor.task.TaskController; +import org.apache.kafka.trogdor.task.TaskSpec; +import org.apache.kafka.trogdor.task.TaskWorker; + +import java.util.Collections; +import java.util.Map; +import java.util.Optional; + +/** + * The specification for a benchmark that produces messages to a set of topics. + * + * To configure a transactional producer, a #{@link TransactionGenerator} must be passed in. + * Said generator works in lockstep with the producer by instructing it what action to take next in regards to a transaction. 
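// The transaction generator drives the producer through begin/commit/abort steps: the worker
// simply executes whatever nextAction() returns (see takeTransactionAction() in
// ProduceBenchWorker below). A rough sketch of a generator that wraps every N sends in one
// transaction. Illustrative only; the real interface also carries Jackson type annotations so
// that it can be named in a JSON spec.
package org.apache.kafka.trogdor.workload;

public class EveryNMessagesTransactionGenerator implements TransactionGenerator {
    private final long messagesPerTransaction;
    private long messagesInTransaction = -1;   // -1 means no transaction is open yet

    public EveryNMessagesTransactionGenerator(long messagesPerTransaction) {
        this.messagesPerTransaction = messagesPerTransaction;
    }

    @Override
    public synchronized TransactionAction nextAction() {
        if (messagesInTransaction < 0) {
            messagesInTransaction = 0;
            return TransactionAction.BEGIN_TRANSACTION;    // open a new transaction
        }
        if (messagesInTransaction >= messagesPerTransaction) {
            messagesInTransaction = -1;
            return TransactionAction.COMMIT_TRANSACTION;   // close it after N sends
        }
        messagesInTransaction++;
        return TransactionAction.NO_OP;                    // let the worker send the next record
    }
}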
+ * + * An example JSON representation which will result in a producer that creates three topics (foo1, foo2, foo3) + * with three partitions each and produces to them: + * #{@code + * { + * "class": "org.apache.kafka.trogdor.workload.ProduceBenchSpec", + * "durationMs": 10000000, + * "producerNode": "node0", + * "bootstrapServers": "localhost:9092", + * "targetMessagesPerSec": 10, + * "maxMessages": 100, + * "activeTopics": { + * "foo[1-3]": { + * "numPartitions": 3, + * "replicationFactor": 1 + * } + * }, + * "inactiveTopics": { + * "foo[4-5]": { + * "numPartitions": 3, + * "replicationFactor": 1 + * } + * } + * } + * } + */ +public class ProduceBenchSpec extends TaskSpec { + private final String producerNode; + private final String bootstrapServers; + private final int targetMessagesPerSec; + private final long maxMessages; + private final PayloadGenerator keyGenerator; + private final PayloadGenerator valueGenerator; + private final Optional transactionGenerator; + private final Map producerConf; + private final Map adminClientConf; + private final Map commonClientConf; + private final TopicsSpec activeTopics; + private final TopicsSpec inactiveTopics; + private final boolean useConfiguredPartitioner; + private final boolean skipFlush; + + @JsonCreator + public ProduceBenchSpec(@JsonProperty("startMs") long startMs, + @JsonProperty("durationMs") long durationMs, + @JsonProperty("producerNode") String producerNode, + @JsonProperty("bootstrapServers") String bootstrapServers, + @JsonProperty("targetMessagesPerSec") int targetMessagesPerSec, + @JsonProperty("maxMessages") long maxMessages, + @JsonProperty("keyGenerator") PayloadGenerator keyGenerator, + @JsonProperty("valueGenerator") PayloadGenerator valueGenerator, + @JsonProperty("transactionGenerator") Optional txGenerator, + @JsonProperty("producerConf") Map producerConf, + @JsonProperty("commonClientConf") Map commonClientConf, + @JsonProperty("adminClientConf") Map adminClientConf, + @JsonProperty("activeTopics") TopicsSpec activeTopics, + @JsonProperty("inactiveTopics") TopicsSpec inactiveTopics, + @JsonProperty("useConfiguredPartitioner") boolean useConfiguredPartitioner, + @JsonProperty("skipFlush") boolean skipFlush) { + super(startMs, durationMs); + this.producerNode = (producerNode == null) ? "" : producerNode; + this.bootstrapServers = (bootstrapServers == null) ? "" : bootstrapServers; + this.targetMessagesPerSec = targetMessagesPerSec; + this.maxMessages = maxMessages; + this.keyGenerator = keyGenerator == null ? + new SequentialPayloadGenerator(4, 0) : keyGenerator; + this.valueGenerator = valueGenerator == null ? + new ConstantPayloadGenerator(512, new byte[0]) : valueGenerator; + this.transactionGenerator = txGenerator == null ? Optional.empty() : txGenerator; + this.producerConf = configOrEmptyMap(producerConf); + this.commonClientConf = configOrEmptyMap(commonClientConf); + this.adminClientConf = configOrEmptyMap(adminClientConf); + this.activeTopics = (activeTopics == null) ? + TopicsSpec.EMPTY : activeTopics.immutableCopy(); + this.inactiveTopics = (inactiveTopics == null) ? 
+ TopicsSpec.EMPTY : inactiveTopics.immutableCopy(); + this.useConfiguredPartitioner = useConfiguredPartitioner; + this.skipFlush = skipFlush; + } + + @JsonProperty + public String producerNode() { + return producerNode; + } + + @JsonProperty + public String bootstrapServers() { + return bootstrapServers; + } + + @JsonProperty + public int targetMessagesPerSec() { + return targetMessagesPerSec; + } + + @JsonProperty + public long maxMessages() { + return maxMessages; + } + + @JsonProperty + public PayloadGenerator keyGenerator() { + return keyGenerator; + } + + @JsonProperty + public PayloadGenerator valueGenerator() { + return valueGenerator; + } + + @JsonProperty + public Optional transactionGenerator() { + return transactionGenerator; + } + + @JsonProperty + public Map producerConf() { + return producerConf; + } + + @JsonProperty + public Map commonClientConf() { + return commonClientConf; + } + + @JsonProperty + public Map adminClientConf() { + return adminClientConf; + } + + @JsonProperty + public TopicsSpec activeTopics() { + return activeTopics; + } + + @JsonProperty + public TopicsSpec inactiveTopics() { + return inactiveTopics; + } + + @JsonProperty + public boolean useConfiguredPartitioner() { + return useConfiguredPartitioner; + } + + @JsonProperty + public boolean skipFlush() { + return skipFlush; + } + + @Override + public TaskController newController(String id) { + return topology -> Collections.singleton(producerNode); + } + + @Override + public TaskWorker newTaskWorker(String id) { + return new ProduceBenchWorker(id, this); + } +} diff --git a/trogdor/src/main/java/org/apache/kafka/trogdor/workload/ProduceBenchWorker.java b/trogdor/src/main/java/org/apache/kafka/trogdor/workload/ProduceBenchWorker.java new file mode 100644 index 0000000..42127fd --- /dev/null +++ b/trogdor/src/main/java/org/apache/kafka/trogdor/workload/ProduceBenchWorker.java @@ -0,0 +1,421 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.trogdor.workload; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.databind.node.TextNode; +import org.apache.kafka.clients.admin.NewTopic; +import org.apache.kafka.clients.producer.Callback; +import org.apache.kafka.clients.producer.KafkaProducer; +import org.apache.kafka.clients.producer.ProducerConfig; +import org.apache.kafka.clients.producer.ProducerRecord; +import org.apache.kafka.clients.producer.RecordMetadata; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.internals.KafkaFutureImpl; +import org.apache.kafka.common.serialization.ByteArraySerializer; +import org.apache.kafka.common.utils.ThreadUtils; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.trogdor.common.JsonUtil; +import org.apache.kafka.trogdor.common.Platform; +import org.apache.kafka.trogdor.common.WorkerUtils; +import org.apache.kafka.trogdor.task.TaskWorker; +import org.apache.kafka.trogdor.task.WorkerStatusTracker; +import org.apache.kafka.trogdor.workload.TransactionGenerator.TransactionAction; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.HashMap; +import java.util.HashSet; +import java.util.Iterator; +import java.util.Map; +import java.util.Optional; +import java.util.Properties; +import java.util.UUID; +import java.util.concurrent.Callable; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicLong; + +public class ProduceBenchWorker implements TaskWorker { + private static final Logger log = LoggerFactory.getLogger(ProduceBenchWorker.class); + + private static final int THROTTLE_PERIOD_MS = 100; + + private final String id; + + private final ProduceBenchSpec spec; + + private final AtomicBoolean running = new AtomicBoolean(false); + + private ScheduledExecutorService executor; + + private WorkerStatusTracker status; + + private KafkaFutureImpl doneFuture; + + public ProduceBenchWorker(String id, ProduceBenchSpec spec) { + this.id = id; + this.spec = spec; + } + + @Override + public void start(Platform platform, WorkerStatusTracker status, + KafkaFutureImpl doneFuture) { + if (!running.compareAndSet(false, true)) { + throw new IllegalStateException("ProducerBenchWorker is already running."); + } + log.info("{}: Activating ProduceBenchWorker with {}", id, spec); + // Create an executor with 2 threads. We need the second thread so + // that the StatusUpdater can run in parallel with SendRecords. 
+ this.executor = Executors.newScheduledThreadPool(2, + ThreadUtils.createThreadFactory("ProduceBenchWorkerThread%d", false)); + this.status = status; + this.doneFuture = doneFuture; + executor.submit(new Prepare()); + } + + public class Prepare implements Runnable { + @Override + public void run() { + try { + Map newTopics = new HashMap<>(); + HashSet active = new HashSet<>(); + for (Map.Entry entry : + spec.activeTopics().materialize().entrySet()) { + String topicName = entry.getKey(); + PartitionsSpec partSpec = entry.getValue(); + newTopics.put(topicName, partSpec.newTopic(topicName)); + for (Integer partitionNumber : partSpec.partitionNumbers()) { + active.add(new TopicPartition(topicName, partitionNumber)); + } + } + if (active.isEmpty()) { + throw new RuntimeException("You must specify at least one active topic."); + } + for (Map.Entry entry : + spec.inactiveTopics().materialize().entrySet()) { + String topicName = entry.getKey(); + PartitionsSpec partSpec = entry.getValue(); + newTopics.put(topicName, partSpec.newTopic(topicName)); + } + status.update(new TextNode("Creating " + newTopics.keySet().size() + " topic(s)")); + WorkerUtils.createTopics(log, spec.bootstrapServers(), spec.commonClientConf(), + spec.adminClientConf(), newTopics, false); + status.update(new TextNode("Created " + newTopics.keySet().size() + " topic(s)")); + executor.submit(new SendRecords(active)); + } catch (Throwable e) { + WorkerUtils.abort(log, "Prepare", e, doneFuture); + } + } + } + + private static class SendRecordsCallback implements Callback { + private final SendRecords sendRecords; + private final long startMs; + + SendRecordsCallback(SendRecords sendRecords, long startMs) { + this.sendRecords = sendRecords; + this.startMs = startMs; + } + + @Override + public void onCompletion(RecordMetadata metadata, Exception exception) { + long now = Time.SYSTEM.milliseconds(); + long durationMs = now - startMs; + sendRecords.recordDuration(durationMs); + if (exception != null) { + log.error("SendRecordsCallback: error", exception); + } + } + } + + /** + * A subclass of Throttle which flushes the Producer right before the throttle injects a delay. + * This avoids including throttling latency in latency measurements. 
+ */ + private static class SendRecordsThrottle extends Throttle { + private final KafkaProducer producer; + + SendRecordsThrottle(int maxPerPeriod, KafkaProducer producer) { + super(maxPerPeriod, THROTTLE_PERIOD_MS); + this.producer = producer; + } + + @Override + protected synchronized void delay(long amount) throws InterruptedException { + long startMs = time().milliseconds(); + producer.flush(); + long endMs = time().milliseconds(); + long delta = endMs - startMs; + super.delay(amount - delta); + } + } + + public class SendRecords implements Callable { + private final HashSet activePartitions; + + private final Histogram histogram; + + private final Future statusUpdaterFuture; + + private final KafkaProducer producer; + + private final PayloadIterator keys; + + private final PayloadIterator values; + + private final Optional transactionGenerator; + + private final Throttle throttle; + + private Iterator partitionsIterator; + private Future sendFuture; + private AtomicLong transactionsCommitted; + private boolean enableTransactions; + + SendRecords(HashSet activePartitions) { + this.activePartitions = activePartitions; + this.partitionsIterator = activePartitions.iterator(); + this.histogram = new Histogram(10000); + + this.transactionGenerator = spec.transactionGenerator(); + this.enableTransactions = this.transactionGenerator.isPresent(); + this.transactionsCommitted = new AtomicLong(); + + int perPeriod = WorkerUtils.perSecToPerPeriod(spec.targetMessagesPerSec(), THROTTLE_PERIOD_MS); + this.statusUpdaterFuture = executor.scheduleWithFixedDelay( + new StatusUpdater(histogram, transactionsCommitted), 30, 30, TimeUnit.SECONDS); + + Properties props = new Properties(); + props.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, spec.bootstrapServers()); + if (enableTransactions) + props.put(ProducerConfig.TRANSACTIONAL_ID_CONFIG, "produce-bench-transaction-id-" + UUID.randomUUID()); + // add common client configs to producer properties, and then user-specified producer configs + WorkerUtils.addConfigsToProperties(props, spec.commonClientConf(), spec.producerConf()); + this.producer = new KafkaProducer<>(props, new ByteArraySerializer(), new ByteArraySerializer()); + this.keys = new PayloadIterator(spec.keyGenerator()); + this.values = new PayloadIterator(spec.valueGenerator()); + if (spec.skipFlush()) { + this.throttle = new Throttle(perPeriod, THROTTLE_PERIOD_MS); + } else { + this.throttle = new SendRecordsThrottle(perPeriod, producer); + } + } + + @Override + public Void call() throws Exception { + long startTimeMs = Time.SYSTEM.milliseconds(); + try { + try { + if (enableTransactions) + producer.initTransactions(); + + long sentMessages = 0; + while (sentMessages < spec.maxMessages()) { + if (enableTransactions) { + boolean tookAction = takeTransactionAction(); + if (tookAction) + continue; + } + sendMessage(); + sentMessages++; + } + if (enableTransactions) + takeTransactionAction(); // give the transactionGenerator a chance to commit if configured evenly + } catch (Exception e) { + if (enableTransactions) + producer.abortTransaction(); + throw e; + } finally { + if (sendFuture != null) { + try { + sendFuture.get(); + } catch (Exception e) { + log.error("Exception on final future", e); + } + } + producer.close(); + } + } catch (Exception e) { + WorkerUtils.abort(log, "SendRecords", e, doneFuture); + } finally { + statusUpdaterFuture.cancel(false); + StatusData statusData = new StatusUpdater(histogram, transactionsCommitted).update(); + long curTimeMs = Time.SYSTEM.milliseconds(); + 
log.info("Sent {} total record(s) in {} ms. status: {}", + histogram.summarize().numSamples(), curTimeMs - startTimeMs, statusData); + } + doneFuture.complete(""); + return null; + } + + private boolean takeTransactionAction() { + boolean tookAction = true; + TransactionAction nextAction = transactionGenerator.get().nextAction(); + switch (nextAction) { + case BEGIN_TRANSACTION: + log.debug("Beginning transaction."); + producer.beginTransaction(); + break; + case COMMIT_TRANSACTION: + log.debug("Committing transaction."); + producer.commitTransaction(); + transactionsCommitted.getAndIncrement(); + break; + case ABORT_TRANSACTION: + log.debug("Aborting transaction."); + producer.abortTransaction(); + break; + case NO_OP: + tookAction = false; + break; + } + return tookAction; + } + + private void sendMessage() throws InterruptedException { + if (!partitionsIterator.hasNext()) + partitionsIterator = activePartitions.iterator(); + + TopicPartition partition = partitionsIterator.next(); + ProducerRecord record; + if (spec.useConfiguredPartitioner()) { + record = new ProducerRecord<>( + partition.topic(), keys.next(), values.next()); + } else { + record = new ProducerRecord<>( + partition.topic(), partition.partition(), keys.next(), values.next()); + } + sendFuture = producer.send(record, + new SendRecordsCallback(this, Time.SYSTEM.milliseconds())); + throttle.increment(); + } + + void recordDuration(long durationMs) { + histogram.add(durationMs); + } + } + + public class StatusUpdater implements Runnable { + private final Histogram histogram; + private final AtomicLong transactionsCommitted; + + StatusUpdater(Histogram histogram, AtomicLong transactionsCommitted) { + this.histogram = histogram; + this.transactionsCommitted = transactionsCommitted; + } + + @Override + public void run() { + try { + update(); + } catch (Exception e) { + WorkerUtils.abort(log, "StatusUpdater", e, doneFuture); + } + } + + StatusData update() { + Histogram.Summary summary = histogram.summarize(StatusData.PERCENTILES); + StatusData statusData = new StatusData(summary.numSamples(), summary.average(), + summary.percentiles().get(0).value(), + summary.percentiles().get(1).value(), + summary.percentiles().get(2).value(), + transactionsCommitted.get()); + status.update(JsonUtil.JSON_SERDE.valueToTree(statusData)); + return statusData; + } + } + + public static class StatusData { + private final long totalSent; + private final float averageLatencyMs; + private final int p50LatencyMs; + private final int p95LatencyMs; + private final int p99LatencyMs; + private final long transactionsCommitted; + + /** + * The percentiles to use when calculating the histogram data. + * These should match up with the p50LatencyMs, p95LatencyMs, etc. fields. 
+ */ + final static float[] PERCENTILES = {0.5f, 0.95f, 0.99f}; + + @JsonCreator + StatusData(@JsonProperty("totalSent") long totalSent, + @JsonProperty("averageLatencyMs") float averageLatencyMs, + @JsonProperty("p50LatencyMs") int p50latencyMs, + @JsonProperty("p95LatencyMs") int p95latencyMs, + @JsonProperty("p99LatencyMs") int p99latencyMs, + @JsonProperty("transactionsCommitted") long transactionsCommitted) { + this.totalSent = totalSent; + this.averageLatencyMs = averageLatencyMs; + this.p50LatencyMs = p50latencyMs; + this.p95LatencyMs = p95latencyMs; + this.p99LatencyMs = p99latencyMs; + this.transactionsCommitted = transactionsCommitted; + } + + @JsonProperty + public long totalSent() { + return totalSent; + } + + @JsonProperty + public long transactionsCommitted() { + return transactionsCommitted; + } + + @JsonProperty + public float averageLatencyMs() { + return averageLatencyMs; + } + + @JsonProperty + public int p50LatencyMs() { + return p50LatencyMs; + } + + @JsonProperty + public int p95LatencyMs() { + return p95LatencyMs; + } + + @JsonProperty + public int p99LatencyMs() { + return p99LatencyMs; + } + } + + @Override + public void stop(Platform platform) throws Exception { + if (!running.compareAndSet(true, false)) { + throw new IllegalStateException("ProduceBenchWorker is not running."); + } + log.info("{}: Deactivating ProduceBenchWorker.", id); + doneFuture.complete(""); + executor.shutdownNow(); + executor.awaitTermination(1, TimeUnit.DAYS); + this.executor = null; + this.status = null; + this.doneFuture = null; + } +} diff --git a/trogdor/src/main/java/org/apache/kafka/trogdor/workload/RandomComponent.java b/trogdor/src/main/java/org/apache/kafka/trogdor/workload/RandomComponent.java new file mode 100644 index 0000000..b5973a8 --- /dev/null +++ b/trogdor/src/main/java/org/apache/kafka/trogdor/workload/RandomComponent.java @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.trogdor.workload; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +/** + * Contains a percent value represented as an integer between 1 and 100 and a PayloadGenerator to specify + * how often that PayloadGenerator should be used. 
+ */ +public class RandomComponent { + private final int percent; + private final PayloadGenerator component; + + + @JsonCreator + public RandomComponent(@JsonProperty("percent") int percent, + @JsonProperty("component") PayloadGenerator component) { + this.percent = percent; + this.component = component; + } + + @JsonProperty + public int percent() { + return percent; + } + + @JsonProperty + public PayloadGenerator component() { + return component; + } +} + diff --git a/trogdor/src/main/java/org/apache/kafka/trogdor/workload/RandomComponentPayloadGenerator.java b/trogdor/src/main/java/org/apache/kafka/trogdor/workload/RandomComponentPayloadGenerator.java new file mode 100644 index 0000000..be50a44 --- /dev/null +++ b/trogdor/src/main/java/org/apache/kafka/trogdor/workload/RandomComponentPayloadGenerator.java @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.trogdor.workload; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.util.ArrayList; +import java.util.List; +import java.util.Random; + + +/** + * A PayloadGenerator which generates pseudo-random payloads based on other PayloadGenerators. + * + * Given a seed and non-null list of RandomComponents, RandomComponentPayloadGenerator + * will use any given generator in its list of components a percentage of the time based on the + * percent field in the RandomComponent. These percent fields must be integers greater than 0 + * and together add up to 100. The payloads generated can be reproduced from run to run. + * + * An example of how to include this generator in a Trogdor taskSpec is shown below. 
+ * #{@code + * "keyGenerator": { + * "type": "randomComponent", + * "seed": 456, + * "components": [ + * { + * "percent": 50, + * "component": { + * "type": "null" + * } + * }, + * { + * "percent": 50, + * "component": { + * "type": "uniformRandom", + * "size": 4, + * "seed": 123, + * "padding": 0 + * } + * } + * ] + * } + * } + */ +public class RandomComponentPayloadGenerator implements PayloadGenerator { + private final long seed; + private final List components; + private final Random random = new Random(); + + @JsonCreator + public RandomComponentPayloadGenerator(@JsonProperty("seed") long seed, + @JsonProperty("components") List components) { + this.seed = seed; + if (components == null || components.isEmpty()) { + throw new IllegalArgumentException("Components must be a specified, non-empty list of RandomComponents."); + } + int sum = 0; + for (RandomComponent component : components) { + if (component.percent() < 1) { + throw new IllegalArgumentException("Percent value must be greater than zero."); + } + sum += component.percent(); + } + if (sum != 100) { + throw new IllegalArgumentException("Components must be a list of RandomComponents such that the percent fields sum to 100"); + } + this.components = new ArrayList<>(components); + } + + @JsonProperty + public long seed() { + return seed; + } + + @JsonProperty + public List components() { + return components; + } + + @Override + public byte[] generate(long position) { + int randPercent; + synchronized (random) { + random.setSeed(seed + position); + randPercent = random.nextInt(100); + } + int curPercent = 0; + RandomComponent com = components.get(0); + for (RandomComponent component : components) { + curPercent += component.percent(); + if (curPercent > randPercent) { + com = component; + break; + } + } + return com.component().generate(position); + } +} diff --git a/trogdor/src/main/java/org/apache/kafka/trogdor/workload/RecordProcessor.java b/trogdor/src/main/java/org/apache/kafka/trogdor/workload/RecordProcessor.java new file mode 100644 index 0000000..500736d --- /dev/null +++ b/trogdor/src/main/java/org/apache/kafka/trogdor/workload/RecordProcessor.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.trogdor.workload; + +import com.fasterxml.jackson.annotation.JsonSubTypes; +import com.fasterxml.jackson.annotation.JsonTypeInfo; +import com.fasterxml.jackson.databind.JsonNode; +import org.apache.kafka.clients.consumer.ConsumerRecords; + +/** + * RecordProcessor allows for acting on data polled from ConsumeBench workloads. 
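Returning to RandomComponentPayloadGenerator above: selection walks the component list and accumulates percent values until the cumulative total exceeds a pseudo-random number in [0, 100) derived from seed + position, which is what makes the choice (and therefore the payload) reproducible per position. A usage sketch under the assumption that UniformRandomPayloadGenerator takes (size, seed, padding), as it is constructed elsewhere in this patch:

import java.util.Arrays;
import java.util.List;

import org.apache.kafka.trogdor.workload.PayloadGenerator;
import org.apache.kafka.trogdor.workload.RandomComponent;
import org.apache.kafka.trogdor.workload.RandomComponentPayloadGenerator;
import org.apache.kafka.trogdor.workload.UniformRandomPayloadGenerator;

public final class RandomComponentExample {
    public static void main(String[] args) {
        // Two components whose percent fields sum to 100, as the constructor requires.
        List<RandomComponent> components = Arrays.asList(
            new RandomComponent(50, new UniformRandomPayloadGenerator(4, 123, 0)),
            new RandomComponent(50, new UniformRandomPayloadGenerator(8, 456, 0)));
        PayloadGenerator generator = new RandomComponentPayloadGenerator(789, components);

        // The seed is combined with the position, so the same position always
        // selects the same component and therefore the same payload.
        byte[] first = generator.generate(0);
        byte[] again = generator.generate(0);
        System.out.println(Arrays.equals(first, again)); // true
        System.out.println(first.length);                // 4 or 8, fixed for position 0
    }
}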
+ */ +@JsonTypeInfo(use = JsonTypeInfo.Id.NAME, + include = JsonTypeInfo.As.PROPERTY, + property = "type") +@JsonSubTypes(value = { + @JsonSubTypes.Type(value = TimestampRecordProcessor.class, name = "timestamp"), +}) +public interface RecordProcessor { + void processRecords(ConsumerRecords consumerRecords); + JsonNode processorStatus(); +} + diff --git a/trogdor/src/main/java/org/apache/kafka/trogdor/workload/RoundTripWorker.java b/trogdor/src/main/java/org/apache/kafka/trogdor/workload/RoundTripWorker.java new file mode 100644 index 0000000..6f448cc --- /dev/null +++ b/trogdor/src/main/java/org/apache/kafka/trogdor/workload/RoundTripWorker.java @@ -0,0 +1,458 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.trogdor.workload; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.databind.node.TextNode; +import org.apache.kafka.clients.admin.NewTopic; +import org.apache.kafka.clients.consumer.ConsumerConfig; +import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.clients.consumer.ConsumerRecords; +import org.apache.kafka.clients.consumer.KafkaConsumer; +import org.apache.kafka.clients.producer.KafkaProducer; +import org.apache.kafka.clients.producer.ProducerConfig; +import org.apache.kafka.clients.producer.ProducerRecord; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.config.ConfigException; +import org.apache.kafka.common.errors.TimeoutException; +import org.apache.kafka.common.errors.WakeupException; +import org.apache.kafka.common.internals.KafkaFutureImpl; +import org.apache.kafka.common.serialization.ByteArrayDeserializer; +import org.apache.kafka.common.serialization.ByteArraySerializer; +import org.apache.kafka.common.utils.ThreadUtils; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.trogdor.common.JsonUtil; +import org.apache.kafka.trogdor.common.Platform; +import org.apache.kafka.trogdor.common.WorkerUtils; +import org.apache.kafka.trogdor.task.TaskWorker; +import org.apache.kafka.trogdor.task.WorkerStatusTracker; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.time.Duration; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Properties; +import java.util.TreeSet; +import java.util.concurrent.Executors; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import 
java.util.concurrent.locks.Condition; +import java.util.concurrent.locks.Lock; +import java.util.concurrent.locks.ReentrantLock; + +public class RoundTripWorker implements TaskWorker { + private static final int THROTTLE_PERIOD_MS = 100; + + private static final int LOG_INTERVAL_MS = 5000; + + private static final int LOG_NUM_MESSAGES = 10; + + private static final Logger log = LoggerFactory.getLogger(RoundTripWorker.class); + + private static final PayloadGenerator KEY_GENERATOR = new SequentialPayloadGenerator(4, 0); + + private ToReceiveTracker toReceiveTracker; + + private final String id; + + private final RoundTripWorkloadSpec spec; + + private final AtomicBoolean running = new AtomicBoolean(false); + + private final Lock lock = new ReentrantLock(); + + private final Condition unackedSendsAreZero = lock.newCondition(); + + private ScheduledExecutorService executor; + + private WorkerStatusTracker status; + + private KafkaFutureImpl doneFuture; + + private KafkaProducer producer; + + private KafkaConsumer consumer; + + private Long unackedSends; + + private ToSendTracker toSendTracker; + + public RoundTripWorker(String id, RoundTripWorkloadSpec spec) { + this.id = id; + this.spec = spec; + } + + @Override + public void start(Platform platform, WorkerStatusTracker status, + KafkaFutureImpl doneFuture) throws Exception { + if (!running.compareAndSet(false, true)) { + throw new IllegalStateException("RoundTripWorker is already running."); + } + log.info("{}: Activating RoundTripWorker.", id); + this.executor = Executors.newScheduledThreadPool(3, + ThreadUtils.createThreadFactory("RoundTripWorker%d", false)); + this.status = status; + this.doneFuture = doneFuture; + this.producer = null; + this.consumer = null; + this.unackedSends = spec.maxMessages(); + executor.submit(new Prepare()); + } + + class Prepare implements Runnable { + @Override + public void run() { + try { + if (spec.targetMessagesPerSec() <= 0) { + throw new ConfigException("Can't have targetMessagesPerSec <= 0."); + } + Map newTopics = new HashMap<>(); + HashSet active = new HashSet<>(); + for (Map.Entry entry : + spec.activeTopics().materialize().entrySet()) { + String topicName = entry.getKey(); + PartitionsSpec partSpec = entry.getValue(); + newTopics.put(topicName, partSpec.newTopic(topicName)); + for (Integer partitionNumber : partSpec.partitionNumbers()) { + active.add(new TopicPartition(topicName, partitionNumber)); + } + } + if (active.isEmpty()) { + throw new RuntimeException("You must specify at least one active topic."); + } + status.update(new TextNode("Creating " + newTopics.keySet().size() + " topic(s)")); + WorkerUtils.createTopics(log, spec.bootstrapServers(), spec.commonClientConf(), + spec.adminClientConf(), newTopics, false); + status.update(new TextNode("Created " + newTopics.keySet().size() + " topic(s)")); + toSendTracker = new ToSendTracker(spec.maxMessages()); + toReceiveTracker = new ToReceiveTracker(); + executor.submit(new ProducerRunnable(active)); + executor.submit(new ConsumerRunnable(active)); + executor.submit(new StatusUpdater()); + executor.scheduleWithFixedDelay( + new StatusUpdater(), 30, 30, TimeUnit.SECONDS); + } catch (Throwable e) { + WorkerUtils.abort(log, "Prepare", e, doneFuture); + } + } + } + + private static class ToSendTrackerResult { + final long index; + final boolean firstSend; + + ToSendTrackerResult(long index, boolean firstSend) { + this.index = index; + this.firstSend = firstSend; + } + } + + private static class ToSendTracker { + private final long maxMessages; + 
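The unackedSends counter and the unackedSendsAreZero condition declared on RoundTripWorker above act as a countdown latch: the producer's send callback decrements the counter under the lock and signals when it reaches zero, and the consumer thread later awaits that signal before completing the task (both sides appear further down in this file). A standalone sketch of the same pattern, not the worker itself:

import java.util.concurrent.locks.Condition;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;

// Illustrative only: the countdown-and-await pattern used for unackedSends.
public final class AckLatchSketch {
    private final Lock lock = new ReentrantLock();
    private final Condition reachedZero = lock.newCondition();
    private long outstanding;

    AckLatchSketch(long outstanding) {
        this.outstanding = outstanding;
    }

    // Called from the producer callback when a send is acknowledged.
    void ack() {
        lock.lock();
        try {
            outstanding -= 1;
            if (outstanding <= 0)
                reachedZero.signalAll();
        } finally {
            lock.unlock();
        }
    }

    // Called by the thread that must not finish until every send is acknowledged.
    void awaitAllAcked() throws InterruptedException {
        lock.lock();
        try {
            while (outstanding > 0)
                reachedZero.await();
        } finally {
            lock.unlock();
        }
    }

    public static void main(String[] args) throws InterruptedException {
        AckLatchSketch latch = new AckLatchSketch(1);
        new Thread(latch::ack).start();   // stands in for the producer callback
        latch.awaitAllAcked();            // stands in for the consumer thread
        System.out.println("all sends acked");
    }
}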
private final List failed = new ArrayList<>(); + private long frontier = 0; + + ToSendTracker(long maxMessages) { + this.maxMessages = maxMessages; + } + + synchronized void addFailed(long index) { + failed.add(index); + } + + synchronized long frontier() { + return frontier; + } + + synchronized ToSendTrackerResult next() { + if (failed.isEmpty()) { + if (frontier >= maxMessages) { + return null; + } else { + return new ToSendTrackerResult(frontier++, true); + } + } else { + return new ToSendTrackerResult(failed.remove(0), false); + } + } + } + + class ProducerRunnable implements Runnable { + private final HashSet partitions; + private final Throttle throttle; + + ProducerRunnable(HashSet partitions) { + this.partitions = partitions; + Properties props = new Properties(); + props.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, spec.bootstrapServers()); + props.put(ProducerConfig.BATCH_SIZE_CONFIG, 16 * 1024); + props.put(ProducerConfig.BUFFER_MEMORY_CONFIG, 4 * 16 * 1024L); + props.put(ProducerConfig.MAX_BLOCK_MS_CONFIG, 1000L); + props.put(ProducerConfig.CLIENT_ID_CONFIG, "producer." + id); + props.put(ProducerConfig.ACKS_CONFIG, "all"); + props.put(ProducerConfig.REQUEST_TIMEOUT_MS_CONFIG, 105000); + // user may over-write the defaults with common client config and producer config + WorkerUtils.addConfigsToProperties(props, spec.commonClientConf(), spec.producerConf()); + producer = new KafkaProducer<>(props, new ByteArraySerializer(), + new ByteArraySerializer()); + int perPeriod = WorkerUtils. + perSecToPerPeriod(spec.targetMessagesPerSec(), THROTTLE_PERIOD_MS); + this.throttle = new Throttle(perPeriod, THROTTLE_PERIOD_MS); + } + + @Override + public void run() { + long messagesSent = 0; + long uniqueMessagesSent = 0; + log.debug("{}: Starting RoundTripWorker#ProducerRunnable.", id); + try { + Iterator iter = partitions.iterator(); + while (true) { + final ToSendTrackerResult result = toSendTracker.next(); + if (result == null) { + break; + } + throttle.increment(); + final long messageIndex = result.index; + if (result.firstSend) { + toReceiveTracker.addPending(messageIndex); + uniqueMessagesSent++; + } + messagesSent++; + if (!iter.hasNext()) { + iter = partitions.iterator(); + } + TopicPartition partition = iter.next(); + // we explicitly specify generator position based on message index + ProducerRecord record = new ProducerRecord<>(partition.topic(), + partition.partition(), KEY_GENERATOR.generate(messageIndex), + spec.valueGenerator().generate(messageIndex)); + producer.send(record, (metadata, exception) -> { + if (exception == null) { + lock.lock(); + try { + unackedSends -= 1; + if (unackedSends <= 0) + unackedSendsAreZero.signalAll(); + } finally { + lock.unlock(); + } + } else { + log.info("{}: Got exception when sending message {}: {}", + id, messageIndex, exception.getMessage()); + toSendTracker.addFailed(messageIndex); + } + }); + } + } catch (Throwable e) { + WorkerUtils.abort(log, "ProducerRunnable", e, doneFuture); + } finally { + lock.lock(); + try { + log.info("{}: ProducerRunnable is exiting. 
messagesSent={}; uniqueMessagesSent={}; " + + "ackedSends={}/{}.", id, messagesSent, uniqueMessagesSent, + spec.maxMessages() - unackedSends, spec.maxMessages()); + } finally { + lock.unlock(); + } + } + } + } + + private class ToReceiveTracker { + private final TreeSet pending = new TreeSet<>(); + + private long totalReceived = 0; + + synchronized void addPending(long messageIndex) { + pending.add(messageIndex); + } + + synchronized boolean removePending(long messageIndex) { + if (pending.remove(messageIndex)) { + totalReceived++; + return true; + } else { + return false; + } + } + + synchronized long totalReceived() { + return totalReceived; + } + + void log() { + long numToReceive; + List list = new ArrayList<>(LOG_NUM_MESSAGES); + synchronized (this) { + numToReceive = pending.size(); + for (Iterator iter = pending.iterator(); + iter.hasNext() && (list.size() < LOG_NUM_MESSAGES); ) { + Long i = iter.next(); + list.add(i); + } + } + log.info("{}: consumer waiting for {} message(s), starting with: {}", + id, numToReceive, Utils.join(list, ", ")); + } + } + + class ConsumerRunnable implements Runnable { + private final Properties props; + + ConsumerRunnable(HashSet partitions) { + this.props = new Properties(); + props.put(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, spec.bootstrapServers()); + props.put(ConsumerConfig.CLIENT_ID_CONFIG, "consumer." + id); + props.put(ConsumerConfig.GROUP_ID_CONFIG, "round-trip-consumer-group-" + id); + props.put(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "earliest"); + props.put(ConsumerConfig.REQUEST_TIMEOUT_MS_CONFIG, 105000); + props.put(ConsumerConfig.MAX_POLL_INTERVAL_MS_CONFIG, 100000); + // user may over-write the defaults with common client config and consumer config + WorkerUtils.addConfigsToProperties(props, spec.commonClientConf(), spec.consumerConf()); + consumer = new KafkaConsumer<>(props, new ByteArrayDeserializer(), + new ByteArrayDeserializer()); + consumer.assign(partitions); + } + + @Override + public void run() { + long uniqueMessagesReceived = 0; + long messagesReceived = 0; + long pollInvoked = 0; + log.debug("{}: Starting RoundTripWorker#ConsumerRunnable.", id); + try { + long lastLogTimeMs = Time.SYSTEM.milliseconds(); + while (true) { + try { + pollInvoked++; + ConsumerRecords records = consumer.poll(Duration.ofMillis(50)); + for (Iterator> iter = records.iterator(); iter.hasNext(); ) { + ConsumerRecord record = iter.next(); + int messageIndex = ByteBuffer.wrap(record.key()).order(ByteOrder.LITTLE_ENDIAN).getInt(); + messagesReceived++; + if (toReceiveTracker.removePending(messageIndex)) { + uniqueMessagesReceived++; + if (uniqueMessagesReceived >= spec.maxMessages()) { + lock.lock(); + try { + log.info("{}: Consumer received the full count of {} unique messages. 
" + + "Waiting for all {} sends to be acked...", id, spec.maxMessages(), unackedSends); + while (unackedSends > 0) + unackedSendsAreZero.await(); + } finally { + lock.unlock(); + } + + log.info("{}: all sends have been acked.", id); + new StatusUpdater().update(); + doneFuture.complete(""); + return; + } + } + } + long curTimeMs = Time.SYSTEM.milliseconds(); + if (curTimeMs > lastLogTimeMs + LOG_INTERVAL_MS) { + toReceiveTracker.log(); + lastLogTimeMs = curTimeMs; + } + } catch (WakeupException e) { + log.debug("{}: Consumer got WakeupException", id, e); + } catch (TimeoutException e) { + log.debug("{}: Consumer got TimeoutException", id, e); + } + } + } catch (Throwable e) { + WorkerUtils.abort(log, "ConsumerRunnable", e, doneFuture); + } finally { + log.info("{}: ConsumerRunnable is exiting. Invoked poll {} time(s). " + + "messagesReceived = {}; uniqueMessagesReceived = {}.", + id, pollInvoked, messagesReceived, uniqueMessagesReceived); + } + } + } + + public class StatusUpdater implements Runnable { + @Override + public void run() { + try { + update(); + } catch (Exception e) { + WorkerUtils.abort(log, "StatusUpdater", e, doneFuture); + } + } + + StatusData update() { + StatusData statusData = + new StatusData(toSendTracker.frontier(), toReceiveTracker.totalReceived()); + status.update(JsonUtil.JSON_SERDE.valueToTree(statusData)); + return statusData; + } + } + + public static class StatusData { + private final long totalUniqueSent; + private final long totalReceived; + + @JsonCreator + public StatusData(@JsonProperty("totalUniqueSent") long totalUniqueSent, + @JsonProperty("totalReceived") long totalReceived) { + this.totalUniqueSent = totalUniqueSent; + this.totalReceived = totalReceived; + } + + @JsonProperty + public long totalUniqueSent() { + return totalUniqueSent; + } + + @JsonProperty + public long totalReceived() { + return totalReceived; + } + } + + @Override + public void stop(Platform platform) throws Exception { + if (!running.compareAndSet(true, false)) { + throw new IllegalStateException("RoundTripWorker is not running."); + } + log.info("{}: Deactivating RoundTripWorker.", id); + doneFuture.complete(""); + executor.shutdownNow(); + executor.awaitTermination(1, TimeUnit.DAYS); + Utils.closeQuietly(consumer, "consumer"); + Utils.closeQuietly(producer, "producer"); + this.consumer = null; + this.producer = null; + this.unackedSends = null; + this.executor = null; + this.doneFuture = null; + log.info("{}: Deactivated RoundTripWorker.", id); + } +} diff --git a/trogdor/src/main/java/org/apache/kafka/trogdor/workload/RoundTripWorkloadSpec.java b/trogdor/src/main/java/org/apache/kafka/trogdor/workload/RoundTripWorkloadSpec.java new file mode 100644 index 0000000..fd30e8e --- /dev/null +++ b/trogdor/src/main/java/org/apache/kafka/trogdor/workload/RoundTripWorkloadSpec.java @@ -0,0 +1,132 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.trogdor.workload; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import org.apache.kafka.trogdor.task.TaskController; +import org.apache.kafka.trogdor.task.TaskSpec; +import org.apache.kafka.trogdor.task.TaskWorker; + +import java.util.Collections; +import java.util.Map; + +/** + * The specification for a workload that sends messages to a broker and then + * reads them back. + */ +public class RoundTripWorkloadSpec extends TaskSpec { + private final String clientNode; + private final String bootstrapServers; + private final int targetMessagesPerSec; + private final PayloadGenerator valueGenerator; + private final TopicsSpec activeTopics; + private final long maxMessages; + private final Map commonClientConf; + private final Map producerConf; + private final Map consumerConf; + private final Map adminClientConf; + + @JsonCreator + public RoundTripWorkloadSpec(@JsonProperty("startMs") long startMs, + @JsonProperty("durationMs") long durationMs, + @JsonProperty("clientNode") String clientNode, + @JsonProperty("bootstrapServers") String bootstrapServers, + @JsonProperty("commonClientConf") Map commonClientConf, + @JsonProperty("adminClientConf") Map adminClientConf, + @JsonProperty("consumerConf") Map consumerConf, + @JsonProperty("producerConf") Map producerConf, + @JsonProperty("targetMessagesPerSec") int targetMessagesPerSec, + @JsonProperty("valueGenerator") PayloadGenerator valueGenerator, + @JsonProperty("activeTopics") TopicsSpec activeTopics, + @JsonProperty("maxMessages") long maxMessages) { + super(startMs, durationMs); + this.clientNode = clientNode == null ? "" : clientNode; + this.bootstrapServers = bootstrapServers == null ? "" : bootstrapServers; + this.targetMessagesPerSec = targetMessagesPerSec; + this.valueGenerator = valueGenerator == null ? + new UniformRandomPayloadGenerator(32, 123, 10) : valueGenerator; + this.activeTopics = activeTopics == null ? 
+ TopicsSpec.EMPTY : activeTopics.immutableCopy(); + this.maxMessages = maxMessages; + this.commonClientConf = configOrEmptyMap(commonClientConf); + this.adminClientConf = configOrEmptyMap(adminClientConf); + this.producerConf = configOrEmptyMap(producerConf); + this.consumerConf = configOrEmptyMap(consumerConf); + } + + @JsonProperty + public String clientNode() { + return clientNode; + } + + @JsonProperty + public String bootstrapServers() { + return bootstrapServers; + } + + @JsonProperty + public int targetMessagesPerSec() { + return targetMessagesPerSec; + } + + @JsonProperty + public TopicsSpec activeTopics() { + return activeTopics; + } + + @JsonProperty + public PayloadGenerator valueGenerator() { + return valueGenerator; + } + + @JsonProperty + public long maxMessages() { + return maxMessages; + } + + @JsonProperty + public Map commonClientConf() { + return commonClientConf; + } + + @JsonProperty + public Map adminClientConf() { + return adminClientConf; + } + + @JsonProperty + public Map producerConf() { + return producerConf; + } + + @JsonProperty + public Map consumerConf() { + return consumerConf; + } + + @Override + public TaskController newController(String id) { + return topology -> Collections.singleton(clientNode); + } + + @Override + public TaskWorker newTaskWorker(String id) { + return new RoundTripWorker(id, this); + } +} diff --git a/trogdor/src/main/java/org/apache/kafka/trogdor/workload/SequentialPayloadGenerator.java b/trogdor/src/main/java/org/apache/kafka/trogdor/workload/SequentialPayloadGenerator.java new file mode 100644 index 0000000..e0b785a --- /dev/null +++ b/trogdor/src/main/java/org/apache/kafka/trogdor/workload/SequentialPayloadGenerator.java @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.trogdor.workload; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.nio.ByteBuffer; +import java.nio.ByteOrder; + +/** + * A PayloadGenerator which generates a sequentially increasing payload. + * + * The generated number will wrap around to 0 after the maximum value is reached. + * Payloads bigger than 8 bytes will always just be padded with zeros after byte 8. + */ +public class SequentialPayloadGenerator implements PayloadGenerator { + private final int size; + private final long startOffset; + private final ByteBuffer buf; + + @JsonCreator + public SequentialPayloadGenerator(@JsonProperty("size") int size, + @JsonProperty("offset") long startOffset) { + this.size = size; + this.startOffset = startOffset; + this.buf = ByteBuffer.allocate(8); + // Little-endian byte order allows us to support arbitrary lengths more easily, + // since the first byte is always the lowest-order byte. 
+ this.buf.order(ByteOrder.LITTLE_ENDIAN); + } + + @JsonProperty + public int size() { + return size; + } + + @JsonProperty + public long startOffset() { + return startOffset; + } + + @Override + public synchronized byte[] generate(long position) { + buf.clear(); + buf.putLong(position + startOffset); + byte[] result = new byte[size]; + System.arraycopy(buf.array(), 0, result, 0, Math.min(buf.array().length, result.length)); + return result; + } +} diff --git a/trogdor/src/main/java/org/apache/kafka/trogdor/workload/SustainedConnectionSpec.java b/trogdor/src/main/java/org/apache/kafka/trogdor/workload/SustainedConnectionSpec.java new file mode 100644 index 0000000..1783a80 --- /dev/null +++ b/trogdor/src/main/java/org/apache/kafka/trogdor/workload/SustainedConnectionSpec.java @@ -0,0 +1,196 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.trogdor.workload; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import org.apache.kafka.trogdor.task.TaskController; +import org.apache.kafka.trogdor.task.TaskSpec; +import org.apache.kafka.trogdor.task.TaskWorker; + +import java.util.Collections; +import java.util.Map; + +/** + * The specification for a benchmark that creates sustained connections. 
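SequentialPayloadGenerator above writes position + startOffset as a little-endian long and then truncates or zero-pads it to the requested size; RoundTripWorker depends on that layout when it decodes the first four bytes of each consumed key back into a message index. A small round-trip sketch:

import java.nio.ByteBuffer;
import java.nio.ByteOrder;

import org.apache.kafka.trogdor.workload.SequentialPayloadGenerator;

public final class SequentialKeyRoundTrip {
    public static void main(String[] args) {
        // 4-byte keys starting at offset 0, matching RoundTripWorker's KEY_GENERATOR.
        SequentialPayloadGenerator keys = new SequentialPayloadGenerator(4, 0);

        byte[] key = keys.generate(42);
        System.out.println(key.length); // 4

        // RoundTripWorker recovers the message index the same way: little-endian int.
        int index = ByteBuffer.wrap(key).order(ByteOrder.LITTLE_ENDIAN).getInt();
        System.out.println(index); // 42
    }
}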
+ * + * An example JSON representation which will result in a test that creates 27 connections (9 of each), refreshes them + * every 10 seconds using 2 threads, running against topic `topic1`, for a duration of 1 hour, and with various other + * options set: + * + * #{@code + * { + * "class": "org.apache.kafka.trogdor.workload.SustainedConnectionSpec", + * "durationMs": 3600000, + * "clientNode": "node0", + * "bootstrapServers": "localhost:9092", + * "commonClientConf": { + * "compression.type": "lz4", + * "auto.offset.reset": "earliest", + * "linger.ms": "100" + * }, + * "keyGenerator": { + * "type": "sequential", + * "size": 4, + * "startOffset": 0 + * }, + * "valueGenerator": { + * "type": "uniformRandom", + * "size": 512, + * "seed": 0, + * "padding": 0 + * }, + * "producerConnectionCount": 9, + * "consumerConnectionCount": 9, + * "metadataConnectionCount": 9, + * "topicName": "test-topic1-1", + * "numThreads": 2, + * "refreshRateMs": 10000 + * } + * } + */ +public class SustainedConnectionSpec extends TaskSpec { + private final String clientNode; + private final String bootstrapServers; + private final Map producerConf; + private final Map consumerConf; + private final Map adminClientConf; + private final Map commonClientConf; + private final PayloadGenerator keyGenerator; + private final PayloadGenerator valueGenerator; + private final int producerConnectionCount; + private final int consumerConnectionCount; + private final int metadataConnectionCount; + private final String topicName; + private final int numThreads; + private final int refreshRateMs; + + @JsonCreator + public SustainedConnectionSpec( + @JsonProperty("startMs") long startMs, + @JsonProperty("durationMs") long durationMs, + @JsonProperty("clientNode") String clientNode, + @JsonProperty("bootstrapServers") String bootstrapServers, + @JsonProperty("producerConf") Map producerConf, + @JsonProperty("consumerConf") Map consumerConf, + @JsonProperty("adminClientConf") Map adminClientConf, + @JsonProperty("commonClientConf") Map commonClientConf, + @JsonProperty("keyGenerator") PayloadGenerator keyGenerator, + @JsonProperty("valueGenerator") PayloadGenerator valueGenerator, + @JsonProperty("producerConnectionCount") int producerConnectionCount, + @JsonProperty("consumerConnectionCount") int consumerConnectionCount, + @JsonProperty("metadataConnectionCount") int metadataConnectionCount, + @JsonProperty("topicName") String topicName, + @JsonProperty("numThreads") int numThreads, + @JsonProperty("refreshRateMs") int refreshRateMs) { + super(startMs, durationMs); + this.clientNode = clientNode == null ? "" : clientNode; + this.bootstrapServers = (bootstrapServers == null) ? "" : bootstrapServers; + this.producerConf = configOrEmptyMap(producerConf); + this.consumerConf = configOrEmptyMap(consumerConf); + this.adminClientConf = configOrEmptyMap(adminClientConf); + this.commonClientConf = configOrEmptyMap(commonClientConf); + this.keyGenerator = keyGenerator; + this.valueGenerator = valueGenerator; + this.producerConnectionCount = producerConnectionCount; + this.consumerConnectionCount = consumerConnectionCount; + this.metadataConnectionCount = metadataConnectionCount; + this.topicName = topicName; + this.numThreads = numThreads < 1 ? 1 : numThreads; + this.refreshRateMs = refreshRateMs < 1 ? 
1 : refreshRateMs; + } + + @JsonProperty + public String clientNode() { + return clientNode; + } + + @JsonProperty + public String bootstrapServers() { + return bootstrapServers; + } + + @JsonProperty + public Map producerConf() { + return producerConf; + } + + @JsonProperty + public Map consumerConf() { + return consumerConf; + } + + @JsonProperty + public Map adminClientConf() { + return adminClientConf; + } + + @JsonProperty + public Map commonClientConf() { + return commonClientConf; + } + + @JsonProperty + public PayloadGenerator keyGenerator() { + return keyGenerator; + } + + @JsonProperty + public PayloadGenerator valueGenerator() { + return valueGenerator; + } + + @JsonProperty + public int producerConnectionCount() { + return producerConnectionCount; + } + + @JsonProperty + public int consumerConnectionCount() { + return consumerConnectionCount; + } + + @JsonProperty + public int metadataConnectionCount() { + return metadataConnectionCount; + } + + @JsonProperty + public String topicName() { + return topicName; + } + + @JsonProperty + public int numThreads() { + return numThreads; + } + + @JsonProperty + public int refreshRateMs() { + return refreshRateMs; + } + + public TaskController newController(String id) { + return topology -> Collections.singleton(clientNode); + } + + @Override + public TaskWorker newTaskWorker(String id) { + return new SustainedConnectionWorker(id, this); + } +} diff --git a/trogdor/src/main/java/org/apache/kafka/trogdor/workload/SustainedConnectionWorker.java b/trogdor/src/main/java/org/apache/kafka/trogdor/workload/SustainedConnectionWorker.java new file mode 100644 index 0000000..8f5faa7 --- /dev/null +++ b/trogdor/src/main/java/org/apache/kafka/trogdor/workload/SustainedConnectionWorker.java @@ -0,0 +1,531 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
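For the SustainedConnectionSpec example above (27 connections refreshed every 10 seconds by 2 threads), the steady-state load on the maintainer pool is easy to bound. A rough, illustrative back-of-the-envelope calculation (plain arithmetic, not Trogdor API; it ignores the backoff sleep and the fact that each connection's next refresh is scheduled from the end of the previous one):

// Back-of-the-envelope sizing for the javadoc example above: 27 connections,
// refreshRateMs = 10000, numThreads = 2.
public final class SustainedConnectionSizing {
    public static void main(String[] args) {
        int connections = 9 + 9 + 9;   // producer + consumer + metadata
        int refreshRateMs = 10_000;
        int numThreads = 2;

        // Each connection wants roughly one refresh per refreshRateMs, so the worker
        // needs about this many refreshes per second in steady state:
        double refreshesPerSec = connections / (refreshRateMs / 1000.0); // 2.7

        // With numThreads maintainers, the average refresh must finish in about
        // numThreads / refreshesPerSec seconds to keep every connection on schedule:
        double budgetPerRefreshMs = numThreads / refreshesPerSec * 1000.0; // ~740 ms

        System.out.printf("refreshes/sec = %.1f, per-refresh budget ~ %.0f ms%n",
            refreshesPerSec, budgetPerRefreshMs);
    }
}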
+ */ + +package org.apache.kafka.trogdor.workload; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.databind.JsonNode; +import org.apache.kafka.clients.admin.Admin; +import org.apache.kafka.clients.consumer.ConsumerConfig; +import org.apache.kafka.clients.consumer.KafkaConsumer; +import org.apache.kafka.clients.producer.KafkaProducer; +import org.apache.kafka.clients.producer.ProducerConfig; +import org.apache.kafka.clients.producer.ProducerRecord; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.internals.KafkaFutureImpl; +import org.apache.kafka.common.serialization.ByteArrayDeserializer; +import org.apache.kafka.common.serialization.ByteArraySerializer; +import org.apache.kafka.common.utils.SystemTime; +import org.apache.kafka.common.utils.ThreadUtils; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.trogdor.common.JsonUtil; +import org.apache.kafka.trogdor.common.Platform; +import org.apache.kafka.trogdor.common.WorkerUtils; +import org.apache.kafka.trogdor.task.TaskWorker; +import org.apache.kafka.trogdor.task.WorkerStatusTracker; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.time.Duration; +import java.util.ArrayList; +import java.util.Collections; +import java.util.Iterator; +import java.util.List; +import java.util.Optional; +import java.util.Properties; +import java.util.Random; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicLong; +import java.util.stream.Collectors; + +public class SustainedConnectionWorker implements TaskWorker { + private static final Logger log = LoggerFactory.getLogger(SustainedConnectionWorker.class); + private static final SystemTime SYSTEM_TIME = new SystemTime(); + + // This is the metadata for the test itself. + private final String id; + private final SustainedConnectionSpec spec; + + // These variables are used to maintain the connections. + private static final int BACKOFF_PERIOD_MS = 10; + private ExecutorService workerExecutor; + private final AtomicBoolean running = new AtomicBoolean(false); + private KafkaFutureImpl doneFuture; + private ArrayList connections; + + // These variables are used when tracking the reported status of the worker. 
+ private static final int REPORT_INTERVAL_MS = 5000; + private WorkerStatusTracker status; + private AtomicLong totalProducerConnections; + private AtomicLong totalProducerFailedConnections; + private AtomicLong totalConsumerConnections; + private AtomicLong totalConsumerFailedConnections; + private AtomicLong totalMetadataConnections; + private AtomicLong totalMetadataFailedConnections; + private AtomicLong totalAbortedThreads; + private Future statusUpdaterFuture; + private ScheduledExecutorService statusUpdaterExecutor; + + public SustainedConnectionWorker(String id, SustainedConnectionSpec spec) { + this.id = id; + this.spec = spec; + } + + @Override + public void start(Platform platform, WorkerStatusTracker status, + KafkaFutureImpl doneFuture) throws Exception { + if (!running.compareAndSet(false, true)) { + throw new IllegalStateException("SustainedConnectionWorker is already running."); + } + log.info("{}: Activating SustainedConnectionWorker with {}", this.id, this.spec); + this.doneFuture = doneFuture; + this.status = status; + this.connections = new ArrayList<>(); + + // Initialize all status reporting metrics to 0. + this.totalProducerConnections = new AtomicLong(0); + this.totalProducerFailedConnections = new AtomicLong(0); + this.totalConsumerConnections = new AtomicLong(0); + this.totalConsumerFailedConnections = new AtomicLong(0); + this.totalMetadataConnections = new AtomicLong(0); + this.totalMetadataFailedConnections = new AtomicLong(0); + this.totalAbortedThreads = new AtomicLong(0); + + // Create the worker classes and add them to the list of items to act on. + for (int i = 0; i < this.spec.producerConnectionCount(); i++) { + this.connections.add(new ProducerSustainedConnection()); + } + for (int i = 0; i < this.spec.consumerConnectionCount(); i++) { + this.connections.add(new ConsumerSustainedConnection()); + } + for (int i = 0; i < this.spec.metadataConnectionCount(); i++) { + this.connections.add(new MetadataSustainedConnection()); + } + + // Create the status reporter thread and schedule it. + this.statusUpdaterExecutor = Executors.newScheduledThreadPool(1, + ThreadUtils.createThreadFactory("StatusUpdaterWorkerThread%d", false)); + this.statusUpdaterFuture = this.statusUpdaterExecutor.scheduleAtFixedRate( + new StatusUpdater(), 0, REPORT_INTERVAL_MS, TimeUnit.MILLISECONDS); + + // Create the maintainer pool, add all the maintainer threads, then start it. 
+ this.workerExecutor = Executors.newFixedThreadPool(spec.numThreads(), + ThreadUtils.createThreadFactory("SustainedConnectionWorkerThread%d", false)); + for (int i = 0; i < this.spec.numThreads(); i++) { + this.workerExecutor.submit(new MaintainLoop()); + } + } + + private interface SustainedConnection extends AutoCloseable { + boolean needsRefresh(long milliseconds); + void refresh(); + void claim(); + } + + private abstract class ClaimableConnection implements SustainedConnection { + + protected long nextUpdate = 0; + protected boolean inUse = false; + protected long refreshRate; + + @Override + public boolean needsRefresh(long milliseconds) { + return !this.inUse && (milliseconds > this.nextUpdate); + } + + @Override + public void claim() { + this.inUse = true; + } + + @Override + public void close() throws Exception { + this.closeQuietly(); + } + + protected void completeRefresh() { + this.nextUpdate = SustainedConnectionWorker.SYSTEM_TIME.milliseconds() + this.refreshRate; + this.inUse = false; + } + + protected abstract void closeQuietly(); + + } + + private class MetadataSustainedConnection extends ClaimableConnection { + + private Admin client; + private final Properties props; + + MetadataSustainedConnection() { + + // These variables are used to maintain the connection itself. + this.client = null; + this.refreshRate = SustainedConnectionWorker.this.spec.refreshRateMs(); + + // This variable is used to maintain the connection properties. + this.props = new Properties(); + this.props.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, SustainedConnectionWorker.this.spec.bootstrapServers()); + WorkerUtils.addConfigsToProperties( + this.props, SustainedConnectionWorker.this.spec.commonClientConf(), SustainedConnectionWorker.this.spec.commonClientConf()); + } + + @Override + public void refresh() { + try { + if (this.client == null) { + // Housekeeping to track the number of opened connections. + SustainedConnectionWorker.this.totalMetadataConnections.incrementAndGet(); + + // Create the admin client connection. + this.client = Admin.create(this.props); + } + + // Fetch some metadata to keep the connection alive. + this.client.describeCluster().nodes().get(); + + } catch (Throwable e) { + // Set the admin client to be recreated on the next cycle. + this.closeQuietly(); + + // Housekeeping to track the number of opened connections and failed connection attempts. + SustainedConnectionWorker.this.totalMetadataConnections.decrementAndGet(); + SustainedConnectionWorker.this.totalMetadataFailedConnections.incrementAndGet(); + SustainedConnectionWorker.log.error("Error while refreshing sustained AdminClient connection", e); + } + + // Schedule this again and set to not in use. + this.completeRefresh(); + } + + @Override + protected void closeQuietly() { + Utils.closeQuietly(this.client, "AdminClient"); + this.client = null; + } + } + + private class ProducerSustainedConnection extends ClaimableConnection { + + private KafkaProducer producer; + private List partitions; + private Iterator partitionsIterator; + private final String topicName; + private final PayloadIterator keys; + private final PayloadIterator values; + private final Properties props; + + ProducerSustainedConnection() { + + // These variables are used to maintain the connection itself. 
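The ClaimableConnection base class above coordinates with MaintainLoop (later in this file) through a claim/refresh/complete handshake: a maintainer thread claims a connection that is due, performs the refresh outside the shared lock, and completeRefresh() pushes nextUpdate forward and releases the claim. A toy sketch of that handshake, with the synchronization simplified to a per-connection lock rather than the worker-level findConnectionToMaintain() lock the real code uses:

// Illustrative only: the claim/refresh/complete handshake, reduced to a toy
// "connection" that just counts refreshes.
public final class ClaimableSketch {
    static final class ToyConnection {
        private long nextUpdateMs = 0;
        private boolean inUse = false;
        private final long refreshRateMs;
        long refreshCount = 0;

        ToyConnection(long refreshRateMs) {
            this.refreshRateMs = refreshRateMs;
        }

        // A maintainer thread only picks up connections that are due and unclaimed.
        synchronized boolean claimIfDue(long nowMs) {
            if (inUse || nowMs < nextUpdateMs)
                return false;
            inUse = true;
            return true;
        }

        // refresh() does the real work (admin call, produce, poll, ...) off the lock,
        // then reschedules the connection and releases the claim.
        void refresh(long nowMs) {
            refreshCount++;
            synchronized (this) {
                nextUpdateMs = nowMs + refreshRateMs;
                inUse = false;
            }
        }
    }

    public static void main(String[] args) {
        ToyConnection connection = new ToyConnection(10_000);
        long now = System.currentTimeMillis();
        if (connection.claimIfDue(now))
            connection.refresh(now);
        System.out.println(connection.refreshCount);    // 1
        System.out.println(connection.claimIfDue(now)); // false: not due again for 10 s
    }
}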
+ this.producer = null; + this.partitions = null; + this.topicName = SustainedConnectionWorker.this.spec.topicName(); + this.partitionsIterator = null; + this.keys = new PayloadIterator(SustainedConnectionWorker.this.spec.keyGenerator()); + this.values = new PayloadIterator(SustainedConnectionWorker.this.spec.valueGenerator()); + this.refreshRate = SustainedConnectionWorker.this.spec.refreshRateMs(); + + // This variable is used to maintain the connection properties. + this.props = new Properties(); + this.props.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, SustainedConnectionWorker.this.spec.bootstrapServers()); + WorkerUtils.addConfigsToProperties( + this.props, SustainedConnectionWorker.this.spec.commonClientConf(), SustainedConnectionWorker.this.spec.producerConf()); + } + + @Override + public void refresh() { + try { + if (this.producer == null) { + // Housekeeping to track the number of opened connections. + SustainedConnectionWorker.this.totalProducerConnections.incrementAndGet(); + + // Create the producer, fetch the specified topic's partitions and randomize them. + this.producer = new KafkaProducer<>(this.props, new ByteArraySerializer(), new ByteArraySerializer()); + this.partitions = this.producer.partitionsFor(this.topicName).stream() + .map(partitionInfo -> new TopicPartition(partitionInfo.topic(), partitionInfo.partition())) + .collect(Collectors.toList()); + Collections.shuffle(this.partitions); + } + + // Create a new iterator over the partitions if the current one doesn't exist or is exhausted. + if (this.partitionsIterator == null || !this.partitionsIterator.hasNext()) { + this.partitionsIterator = this.partitions.iterator(); + } + + // Produce a single record and send it synchronously. + TopicPartition partition = this.partitionsIterator.next(); + ProducerRecord record = new ProducerRecord<>( + partition.topic(), partition.partition(), keys.next(), values.next()); + producer.send(record).get(); + + } catch (Throwable e) { + // Set the producer to be recreated on the next cycle. + this.closeQuietly(); + + // Housekeeping to track the number of opened connections and failed connection attempts. + SustainedConnectionWorker.this.totalProducerConnections.decrementAndGet(); + SustainedConnectionWorker.this.totalProducerFailedConnections.incrementAndGet(); + SustainedConnectionWorker.log.error("Error while refreshing sustained KafkaProducer connection", e); + + } + + // Schedule this again and set to not in use. + this.completeRefresh(); + } + + @Override + protected void closeQuietly() { + Utils.closeQuietly(this.producer, "KafkaProducer"); + this.producer = null; + this.partitions = null; + this.partitionsIterator = null; + } + } + + private class ConsumerSustainedConnection extends ClaimableConnection { + + private KafkaConsumer consumer; + private TopicPartition activePartition; + private final String topicName; + private final Random rand; + private final Properties props; + + ConsumerSustainedConnection() { + // These variables are used to maintain the connection itself. + this.topicName = SustainedConnectionWorker.this.spec.topicName(); + this.consumer = null; + this.activePartition = null; + this.rand = new Random(); + this.refreshRate = SustainedConnectionWorker.this.spec.refreshRateMs(); + + // This variable is used to maintain the connection properties. 
+ this.props = new Properties(); + this.props.put(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, SustainedConnectionWorker.this.spec.bootstrapServers()); + this.props.put(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "latest"); + this.props.put(ConsumerConfig.MAX_POLL_RECORDS_CONFIG, 1); + this.props.put(ConsumerConfig.FETCH_MAX_BYTES_CONFIG, 1024); + WorkerUtils.addConfigsToProperties( + this.props, SustainedConnectionWorker.this.spec.commonClientConf(), SustainedConnectionWorker.this.spec.consumerConf()); + } + + @Override + public void refresh() { + try { + + if (this.consumer == null) { + + // Housekeeping to track the number of opened connections. + SustainedConnectionWorker.this.totalConsumerConnections.incrementAndGet(); + + // Create the consumer and fetch the partitions for the specified topic. + this.consumer = new KafkaConsumer<>(this.props, new ByteArrayDeserializer(), new ByteArrayDeserializer()); + List partitions = this.consumer.partitionsFor(this.topicName).stream() + .map(partitionInfo -> new TopicPartition(partitionInfo.topic(), partitionInfo.partition())) + .collect(Collectors.toList()); + + // Select a random partition and assign it. + this.activePartition = partitions.get(this.rand.nextInt(partitions.size())); + this.consumer.assign(Collections.singletonList(this.activePartition)); + } + + // The behavior when passing in an empty list is to seek to the end of all subscribed partitions. + this.consumer.seekToEnd(Collections.emptyList()); + + // Poll to keep the connection alive, ignoring any records returned. + this.consumer.poll(Duration.ofMillis(50)); + + } catch (Throwable e) { + + // Set the consumer to be recreated on the next cycle. + this.closeQuietly(); + + // Housekeeping to track the number of opened connections and failed connection attempts. + SustainedConnectionWorker.this.totalConsumerConnections.decrementAndGet(); + SustainedConnectionWorker.this.totalConsumerFailedConnections.incrementAndGet(); + SustainedConnectionWorker.log.error("Error while refreshing sustained KafkaConsumer connection", e); + + } + + // Schedule this again and set to not in use. 
+ this.completeRefresh(); + } + + @Override + protected void closeQuietly() { + Utils.closeQuietly(this.consumer, "KafkaConsumer"); + this.consumer = null; + this.activePartition = null; + } + } + + public class MaintainLoop implements Runnable { + @Override + public void run() { + try { + while (!doneFuture.isDone()) { + Optional currentConnection = SustainedConnectionWorker.this.findConnectionToMaintain(); + if (currentConnection.isPresent()) { + currentConnection.get().refresh(); + } else { + SustainedConnectionWorker.SYSTEM_TIME.sleep(SustainedConnectionWorker.BACKOFF_PERIOD_MS); + } + } + } catch (Exception e) { + SustainedConnectionWorker.this.totalAbortedThreads.incrementAndGet(); + SustainedConnectionWorker.log.error("Aborted thread while maintaining sustained connections", e); + } + } + } + + private synchronized Optional findConnectionToMaintain() { + final long milliseconds = SustainedConnectionWorker.SYSTEM_TIME.milliseconds(); + for (SustainedConnection connection : this.connections) { + if (connection.needsRefresh(milliseconds)) { + connection.claim(); + return Optional.of(connection); + } + } + return Optional.empty(); + } + + private class StatusUpdater implements Runnable { + @Override + public void run() { + try { + JsonNode node = JsonUtil.JSON_SERDE.valueToTree( + new StatusData( + SustainedConnectionWorker.this.totalProducerConnections.get(), + SustainedConnectionWorker.this.totalProducerFailedConnections.get(), + SustainedConnectionWorker.this.totalConsumerConnections.get(), + SustainedConnectionWorker.this.totalConsumerFailedConnections.get(), + SustainedConnectionWorker.this.totalMetadataConnections.get(), + SustainedConnectionWorker.this.totalMetadataFailedConnections.get(), + SustainedConnectionWorker.this.totalAbortedThreads.get(), + SustainedConnectionWorker.SYSTEM_TIME.milliseconds())); + status.update(node); + } catch (Exception e) { + SustainedConnectionWorker.log.error("Aborted test while running StatusUpdater", e); + WorkerUtils.abort(log, "StatusUpdater", e, doneFuture); + } + } + } + + public static class StatusData { + private final long totalProducerConnections; + private final long totalProducerFailedConnections; + private final long totalConsumerConnections; + private final long totalConsumerFailedConnections; + private final long totalMetadataConnections; + private final long totalMetadataFailedConnections; + private final long totalAbortedThreads; + private final long updatedMs; + + @JsonCreator + StatusData(@JsonProperty("totalProducerConnections") long totalProducerConnections, + @JsonProperty("totalProducerFailedConnections") long totalProducerFailedConnections, + @JsonProperty("totalConsumerConnections") long totalConsumerConnections, + @JsonProperty("totalConsumerFailedConnections") long totalConsumerFailedConnections, + @JsonProperty("totalMetadataConnections") long totalMetadataConnections, + @JsonProperty("totalMetadataFailedConnections") long totalMetadataFailedConnections, + @JsonProperty("totalAbortedThreads") long totalAbortedThreads, + @JsonProperty("updatedMs") long updatedMs) { + this.totalProducerConnections = totalProducerConnections; + this.totalProducerFailedConnections = totalProducerFailedConnections; + this.totalConsumerConnections = totalConsumerConnections; + this.totalConsumerFailedConnections = totalConsumerFailedConnections; + this.totalMetadataConnections = totalMetadataConnections; + this.totalMetadataFailedConnections = totalMetadataFailedConnections; + this.totalAbortedThreads = totalAbortedThreads; + this.updatedMs = 
updatedMs; + } + + @JsonProperty + public long totalProducerConnections() { + return totalProducerConnections; + } + + @JsonProperty + public long totalProducerFailedConnections() { + return totalProducerFailedConnections; + } + + @JsonProperty + public long totalConsumerConnections() { + return totalConsumerConnections; + } + + @JsonProperty + public long totalConsumerFailedConnections() { + return totalConsumerFailedConnections; + } + + @JsonProperty + public long totalMetadataConnections() { + return totalMetadataConnections; + } + + @JsonProperty + public long totalMetadataFailedConnections() { + return totalMetadataFailedConnections; + } + + @JsonProperty + public long totalAbortedThreads() { + return totalAbortedThreads; + } + + @JsonProperty + public long updatedMs() { + return updatedMs; + } + } + + @Override + public void stop(Platform platform) throws Exception { + if (!running.compareAndSet(true, false)) { + throw new IllegalStateException("SustainedConnectionWorker is not running."); + } + log.info("{}: Deactivating SustainedConnectionWorker.", this.id); + + // Shut down the periodic status updater and perform a final update on the + // statistics. We want to do this first, before deactivating any threads. + // Otherwise, if some threads take a while to terminate, this could lead + // to a misleading rate getting reported. + this.statusUpdaterFuture.cancel(false); + this.statusUpdaterExecutor.shutdown(); + this.statusUpdaterExecutor.awaitTermination(1, TimeUnit.HOURS); + this.statusUpdaterExecutor = null; + new StatusUpdater().run(); + + doneFuture.complete(""); + for (SustainedConnection connection : this.connections) { + connection.close(); + } + workerExecutor.shutdownNow(); + workerExecutor.awaitTermination(1, TimeUnit.HOURS); + this.workerExecutor = null; + this.status = null; + this.connections = null; + } +} diff --git a/trogdor/src/main/java/org/apache/kafka/trogdor/workload/Throttle.java b/trogdor/src/main/java/org/apache/kafka/trogdor/workload/Throttle.java new file mode 100644 index 0000000..6a99c02 --- /dev/null +++ b/trogdor/src/main/java/org/apache/kafka/trogdor/workload/Throttle.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.trogdor.workload; + +import org.apache.kafka.common.utils.Time; + +public class Throttle { + private final int maxPerPeriod; + private final int periodMs; + private int count; + private long prevPeriod; + private long lastTimeMs; + + Throttle(int maxPerPeriod, int periodMs) { + this.maxPerPeriod = maxPerPeriod; + this.periodMs = periodMs; + this.count = maxPerPeriod; + this.prevPeriod = -1; + this.lastTimeMs = 0; + } + + synchronized public boolean increment() throws InterruptedException { + boolean throttled = false; + while (true) { + if (count < maxPerPeriod) { + count++; + return throttled; + } + lastTimeMs = time().milliseconds(); + long curPeriod = lastTimeMs / periodMs; + if (curPeriod <= prevPeriod) { + long nextPeriodMs = (curPeriod + 1) * periodMs; + delay(nextPeriodMs - lastTimeMs); + throttled = true; + } else { + prevPeriod = curPeriod; + count = 0; + } + } + } + + public synchronized long lastTimeMs() { + return lastTimeMs; + } + + protected Time time() { + return Time.SYSTEM; + } + + protected synchronized void delay(long amount) throws InterruptedException { + if (amount > 0) { + wait(amount); + } + } +} diff --git a/trogdor/src/main/java/org/apache/kafka/trogdor/workload/ThroughputGenerator.java b/trogdor/src/main/java/org/apache/kafka/trogdor/workload/ThroughputGenerator.java new file mode 100644 index 0000000..b6be8a9 --- /dev/null +++ b/trogdor/src/main/java/org/apache/kafka/trogdor/workload/ThroughputGenerator.java @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.trogdor.workload; + +import com.fasterxml.jackson.annotation.JsonSubTypes; +import com.fasterxml.jackson.annotation.JsonTypeInfo; + +/** + * This interface is used to facilitate running a configurable number of messages per second by throttling if the + * throughput goes above a certain amount. + * + * Currently there are 2 throughput methods: + * + * * `constant` will use `ConstantThroughputGenerator` to keep the number of messages per second constant. + * * `gaussian` will use `GaussianThroughputGenerator` to vary the number of messages per second on a normal + * distribution. + * + * Please see the implementation classes for more details. 
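+ *
+ * A workload worker would typically call throttle() once before each message it sends; the call
+ * blocks until sending one more message would stay within the configured rate. A minimal usage
+ * sketch (illustrative only; the generator variable, done flag, and sendMessage() are assumed
+ * names, not part of this interface):
+ *
+ *   ThroughputGenerator generator = ...;  // e.g. deserialized from a workload spec
+ *   while (!done) {
+ *       generator.throttle();             // may block to enforce the configured throughput
+ *       sendMessage();                    // hypothetical send step
+ *   }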
+ */ + +@JsonTypeInfo(use = JsonTypeInfo.Id.NAME, + include = JsonTypeInfo.As.PROPERTY, + property = "type") +@JsonSubTypes(value = { + @JsonSubTypes.Type(value = ConstantThroughputGenerator.class, name = "constant"), + @JsonSubTypes.Type(value = GaussianThroughputGenerator.class, name = "gaussian") + }) +public interface ThroughputGenerator { + void throttle() throws InterruptedException; +} diff --git a/trogdor/src/main/java/org/apache/kafka/trogdor/workload/TimeIntervalTransactionsGenerator.java b/trogdor/src/main/java/org/apache/kafka/trogdor/workload/TimeIntervalTransactionsGenerator.java new file mode 100644 index 0000000..8d5f05b --- /dev/null +++ b/trogdor/src/main/java/org/apache/kafka/trogdor/workload/TimeIntervalTransactionsGenerator.java @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.trogdor.workload; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import org.apache.kafka.common.utils.Time; + +/** + * A transactions generator where we commit a transaction every N milliseconds + */ +public class TimeIntervalTransactionsGenerator implements TransactionGenerator { + + private static final long NULL_START_MS = -1; + + private final Time time; + private final int intervalMs; + + private long lastTransactionStartMs = NULL_START_MS; + + @JsonCreator + public TimeIntervalTransactionsGenerator(@JsonProperty("transactionIntervalMs") int intervalMs) { + this(intervalMs, Time.SYSTEM); + } + + TimeIntervalTransactionsGenerator(@JsonProperty("transactionIntervalMs") int intervalMs, + Time time) { + if (intervalMs < 1) { + throw new IllegalArgumentException("Cannot have a negative interval"); + } + this.time = time; + this.intervalMs = intervalMs; + } + + @JsonProperty + public int transactionIntervalMs() { + return intervalMs; + } + + @Override + public synchronized TransactionAction nextAction() { + if (lastTransactionStartMs == NULL_START_MS) { + lastTransactionStartMs = time.milliseconds(); + return TransactionAction.BEGIN_TRANSACTION; + } + if (time.milliseconds() - lastTransactionStartMs >= intervalMs) { + lastTransactionStartMs = NULL_START_MS; + return TransactionAction.COMMIT_TRANSACTION; + } + + return TransactionAction.NO_OP; + } +} diff --git a/trogdor/src/main/java/org/apache/kafka/trogdor/workload/TimestampConstantPayloadGenerator.java b/trogdor/src/main/java/org/apache/kafka/trogdor/workload/TimestampConstantPayloadGenerator.java new file mode 100644 index 0000000..e9c4bc8 --- /dev/null +++ b/trogdor/src/main/java/org/apache/kafka/trogdor/workload/TimestampConstantPayloadGenerator.java @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.trogdor.workload; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import org.apache.kafka.common.utils.Time; + +import java.nio.ByteOrder; +import java.nio.ByteBuffer; + +/** + * A PayloadGenerator which generates a timestamped constant payload. + * + * The timestamp used for this class is in milliseconds since epoch, encoded directly to the first several bytes of the + * payload. + * + * This should be used in conjunction with TimestampRecordProcessor in the Consumer to measure true end-to-end latency + * of a system. + * + * `size` - The size in bytes of each message. + * + * Here is an example spec: + * + * { + * "type": "timestampConstant", + * "size": 512 + * } + * + * This will generate a 512-byte message with the first several bytes encoded with the timestamp. + */ +public class TimestampConstantPayloadGenerator implements PayloadGenerator { + private final int size; + private final ByteBuffer buffer; + + @JsonCreator + public TimestampConstantPayloadGenerator(@JsonProperty("size") int size) { + this.size = size; + if (size < Long.BYTES) { + throw new RuntimeException("The size of the payload must be greater than or equal to " + Long.BYTES + "."); + } + buffer = ByteBuffer.allocate(Long.BYTES); + buffer.order(ByteOrder.LITTLE_ENDIAN); + } + + @JsonProperty + public int size() { + return size; + } + + @Override + public synchronized byte[] generate(long position) { + // Generate the byte array before the timestamp generation. + byte[] result = new byte[size]; + + // Do the timestamp generation as the very last task. + buffer.clear(); + buffer.putLong(Time.SYSTEM.milliseconds()); + buffer.rewind(); + System.arraycopy(buffer.array(), 0, result, 0, Long.BYTES); + return result; + } +} diff --git a/trogdor/src/main/java/org/apache/kafka/trogdor/workload/TimestampRandomPayloadGenerator.java b/trogdor/src/main/java/org/apache/kafka/trogdor/workload/TimestampRandomPayloadGenerator.java new file mode 100644 index 0000000..d2cfdc7 --- /dev/null +++ b/trogdor/src/main/java/org/apache/kafka/trogdor/workload/TimestampRandomPayloadGenerator.java @@ -0,0 +1,102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.trogdor.workload; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import org.apache.kafka.common.utils.Time; + +import java.nio.ByteOrder; +import java.util.Random; +import java.nio.ByteBuffer; + +/** + * A PayloadGenerator which generates a timestamped uniform random payload. + * + * This generator generates pseudo-random payloads that can be reproduced from run to run. + * The guarantees are the same as those of java.util.Random. + * + * The timestamp used for this class is in milliseconds since epoch, encoded directly to the first several bytes of the + * payload. + * + * This should be used in conjunction with TimestampRecordProcessor in the Consumer to measure true end-to-end latency + * of a system. + * + * `size` - The size in bytes of each message. + * `seed` - Used to initialize Random() to remove some non-determinism. + * + * Here is an example spec: + * + * { + * "type": "timestampRandom", + * "size": 512 + * } + * + * This will generate a 512-byte random message with the first several bytes encoded with the timestamp. + */ +public class TimestampRandomPayloadGenerator implements PayloadGenerator { + private final int size; + private final long seed; + + private final byte[] randomBytes; + private final ByteBuffer buffer; + + private final Random random = new Random(); + + @JsonCreator + public TimestampRandomPayloadGenerator(@JsonProperty("size") int size, + @JsonProperty("seed") long seed) { + this.size = size; + this.seed = seed; + if (size < Long.BYTES) { + throw new RuntimeException("The size of the payload must be greater than or equal to " + Long.BYTES + "."); + } + random.setSeed(seed); + this.randomBytes = new byte[size - Long.BYTES]; + buffer = ByteBuffer.allocate(Long.BYTES); + buffer.order(ByteOrder.LITTLE_ENDIAN); + } + + @JsonProperty + public int size() { + return size; + } + + @JsonProperty + public long seed() { + return seed; + } + + @Override + public synchronized byte[] generate(long position) { + // Generate out of order to prevent inclusion of random number generation in latency numbers. + byte[] result = new byte[size]; + if (randomBytes.length > 0) { + random.setSeed(seed + position); + random.nextBytes(randomBytes); + System.arraycopy(randomBytes, 0, result, Long.BYTES, randomBytes.length); + } + + // Do the timestamp generation as the very last task. + buffer.clear(); + buffer.putLong(Time.SYSTEM.milliseconds()); + buffer.rewind(); + System.arraycopy(buffer.array(), 0, result, 0, Long.BYTES); + return result; + } +} diff --git a/trogdor/src/main/java/org/apache/kafka/trogdor/workload/TimestampRecordProcessor.java b/trogdor/src/main/java/org/apache/kafka/trogdor/workload/TimestampRecordProcessor.java new file mode 100644 index 0000000..035d459 --- /dev/null +++ b/trogdor/src/main/java/org/apache/kafka/trogdor/workload/TimestampRecordProcessor.java @@ -0,0 +1,162 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.trogdor.workload; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.databind.JsonNode; +import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.clients.consumer.ConsumerRecords; +import org.apache.kafka.trogdor.common.JsonUtil; +import org.apache.kafka.common.utils.Time; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.nio.ByteBuffer; +import java.nio.ByteOrder; + +/** + * This class will process records containing timestamps and generate a histogram based on the data. It will then be + * present in the status from the `ConsumeBenchWorker` class. This must be used with a timestamped PayloadGenerator + * implementation. + * + * Example spec: + * { + * "type": "timestamp", + * "histogramMaxMs": 10000, + * "histogramMinMs": 0, + * "histogramStepMs": 1 + * } + * + * This will track total E2E latency up to 10 seconds, using 1ms resolution and a timestamp size of 8 bytes. + */ + +public class TimestampRecordProcessor implements RecordProcessor { + private final int histogramMaxMs; + private final int histogramMinMs; + private final int histogramStepMs; + private final ByteBuffer buffer; + private final Histogram histogram; + + private final Logger log = LoggerFactory.getLogger(TimestampRecordProcessor.class); + + final static float[] PERCENTILES = {0.5f, 0.95f, 0.99f}; + + @JsonCreator + public TimestampRecordProcessor(@JsonProperty("histogramMaxMs") int histogramMaxMs, + @JsonProperty("histogramMinMs") int histogramMinMs, + @JsonProperty("histogramStepMs") int histogramStepMs) { + this.histogramMaxMs = histogramMaxMs; + this.histogramMinMs = histogramMinMs; + this.histogramStepMs = histogramStepMs; + this.histogram = new Histogram((histogramMaxMs - histogramMinMs) / histogramStepMs); + buffer = ByteBuffer.allocate(Long.BYTES); + buffer.order(ByteOrder.LITTLE_ENDIAN); + } + + @JsonProperty + public int histogramMaxMs() { + return histogramMaxMs; + } + + @JsonProperty + public int histogramMinMs() { + return histogramMinMs; + } + + @JsonProperty + public int histogramStepMs() { + return histogramStepMs; + } + + private void putHistogram(long latency) { + histogram.add(Long.max(0L, (latency - histogramMinMs) / histogramStepMs)); + } + + @Override + public synchronized void processRecords(ConsumerRecords consumerRecords) { + // Save the current time to prevent skew by processing time. 
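+        // The loop below assumes each record value begins with a Long.BYTES little-endian
+        // timestamp (milliseconds since epoch), as written by TimestampConstantPayloadGenerator
+        // or TimestampRandomPayloadGenerator, and uses it to compute per-record E2E latency.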
+ long curTime = Time.SYSTEM.milliseconds(); + for (ConsumerRecord record : consumerRecords) { + try { + buffer.clear(); + buffer.put(record.value(), 0, Long.BYTES); + buffer.rewind(); + putHistogram(curTime - buffer.getLong()); + } catch (RuntimeException e) { + log.error("Error in processRecords:", e); + } + } + } + + @Override + public JsonNode processorStatus() { + Histogram.Summary summary = histogram.summarize(PERCENTILES); + StatusData statusData = new StatusData( + summary.average() * histogramStepMs + histogramMinMs, + summary.percentiles().get(0).value() * histogramStepMs + histogramMinMs, + summary.percentiles().get(1).value() * histogramStepMs + histogramMinMs, + summary.percentiles().get(2).value() * histogramStepMs + histogramMinMs); + return JsonUtil.JSON_SERDE.valueToTree(statusData); + } + + private static class StatusData { + private final float averageLatencyMs; + private final int p50LatencyMs; + private final int p95LatencyMs; + private final int p99LatencyMs; + + /** + * The percentiles to use when calculating the histogram data. + * These should match up with the p50LatencyMs, p95LatencyMs, etc. fields. + */ + final static float[] PERCENTILES = {0.5f, 0.95f, 0.99f}; + + @JsonCreator + StatusData(@JsonProperty("averageLatencyMs") float averageLatencyMs, + @JsonProperty("p50LatencyMs") int p50latencyMs, + @JsonProperty("p95LatencyMs") int p95latencyMs, + @JsonProperty("p99LatencyMs") int p99latencyMs) { + this.averageLatencyMs = averageLatencyMs; + this.p50LatencyMs = p50latencyMs; + this.p95LatencyMs = p95latencyMs; + this.p99LatencyMs = p99latencyMs; + } + + @JsonProperty + public float averageLatencyMs() { + return averageLatencyMs; + } + + @JsonProperty + public int p50LatencyMs() { + return p50LatencyMs; + } + + @JsonProperty + public int p95LatencyMs() { + return p95LatencyMs; + } + + @JsonProperty + public int p99LatencyMs() { + return p99LatencyMs; + } + } +} diff --git a/trogdor/src/main/java/org/apache/kafka/trogdor/workload/TopicsSpec.java b/trogdor/src/main/java/org/apache/kafka/trogdor/workload/TopicsSpec.java new file mode 100644 index 0000000..dcb8d8a --- /dev/null +++ b/trogdor/src/main/java/org/apache/kafka/trogdor/workload/TopicsSpec.java @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.trogdor.workload; + +import com.fasterxml.jackson.annotation.JsonAnyGetter; +import com.fasterxml.jackson.annotation.JsonAnySetter; +import com.fasterxml.jackson.annotation.JsonCreator; +import org.apache.kafka.trogdor.common.StringExpander; +import org.apache.kafka.trogdor.rest.Message; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +/** + * TopicsSpec maps topic names to descriptions of the partitions in them. 
+ * + * In JSON form, this is serialized as a map whose keys are topic names, + * and whose entries are partition descriptions. + * Keys may also refer to multiple partitions. For example, this specification + * refers to 3 topics foo1, foo2, and foo3: + * + * { + * "foo[1-3]" : { + * "numPartitions": 3 + * "replicationFactor": 3 + * } + * } + */ +public class TopicsSpec extends Message { + public static final TopicsSpec EMPTY = new TopicsSpec().immutableCopy(); + + private final Map map; + + @JsonCreator + public TopicsSpec() { + this.map = new HashMap<>(); + } + + private TopicsSpec(Map map) { + this.map = map; + } + + @JsonAnyGetter + public Map get() { + return map; + } + + @JsonAnySetter + public void set(String name, PartitionsSpec value) { + map.put(name, value); + } + + public TopicsSpec immutableCopy() { + HashMap mapCopy = new HashMap<>(); + mapCopy.putAll(map); + return new TopicsSpec(Collections.unmodifiableMap(mapCopy)); + } + + /** + * Enumerate the partitions inside this TopicsSpec. + * + * @return A map from topic names to PartitionsSpec objects. + */ + public Map materialize() { + HashMap all = new HashMap<>(); + for (Map.Entry entry : map.entrySet()) { + String topicName = entry.getKey(); + PartitionsSpec partitions = entry.getValue(); + for (String expandedTopicName : StringExpander.expand(topicName)) + all.put(expandedTopicName, partitions); + } + return all; + } +} diff --git a/trogdor/src/main/java/org/apache/kafka/trogdor/workload/TransactionGenerator.java b/trogdor/src/main/java/org/apache/kafka/trogdor/workload/TransactionGenerator.java new file mode 100644 index 0000000..b2e8add --- /dev/null +++ b/trogdor/src/main/java/org/apache/kafka/trogdor/workload/TransactionGenerator.java @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.trogdor.workload; + +import com.fasterxml.jackson.annotation.JsonSubTypes; +import com.fasterxml.jackson.annotation.JsonTypeInfo; + +/** + * Generates actions that should be taken by a producer that uses transactions. + */ +@JsonTypeInfo(use = JsonTypeInfo.Id.NAME, + include = JsonTypeInfo.As.PROPERTY, + property = "type") +@JsonSubTypes(value = { + @JsonSubTypes.Type(value = UniformTransactionsGenerator.class, name = "uniform"), + @JsonSubTypes.Type(value = TimeIntervalTransactionsGenerator.class, name = "interval"), +}) +public interface TransactionGenerator { + enum TransactionAction { + BEGIN_TRANSACTION, COMMIT_TRANSACTION, ABORT_TRANSACTION, NO_OP + } + + /** + * Returns the next action that the producer should take in regards to transactions. + * This method should be called every time before a producer sends a message. 
+ * This means that most of the time it should return #{@link TransactionAction#NO_OP} + * to signal the producer that its next step should be to send a message. + */ + TransactionAction nextAction(); +} diff --git a/trogdor/src/main/java/org/apache/kafka/trogdor/workload/UniformRandomPayloadGenerator.java b/trogdor/src/main/java/org/apache/kafka/trogdor/workload/UniformRandomPayloadGenerator.java new file mode 100644 index 0000000..4642dcf --- /dev/null +++ b/trogdor/src/main/java/org/apache/kafka/trogdor/workload/UniformRandomPayloadGenerator.java @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.trogdor.workload; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.util.Random; + +/** + * A PayloadGenerator which generates a uniform random payload. + * + * This generator generates pseudo-random payloads that can be reproduced from run to run. + * The guarantees are the same as those of java.util.Random. + * + * This payload generator also has the option to append padding bytes at the end of the payload. + * The padding bytes are always the same, no matter what the position is. This is useful when + * simulating a partly-compressible stream of user data. 
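+ *
+ * `size` - The size in bytes of each message.
+ * `seed` - Used to initialize Random() to remove some non-determinism.
+ * `padding` - The number of constant padding bytes placed at the end of each message.
+ *
+ * An illustrative spec, assuming this generator is registered under the type name
+ * "uniformRandom" in the PayloadGenerator @JsonSubTypes mapping:
+ *
+ * {
+ *   "type": "uniformRandom",
+ *   "size": 512,
+ *   "seed": 123,
+ *   "padding": 100
+ * }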
+ */ +public class UniformRandomPayloadGenerator implements PayloadGenerator { + private final int size; + private final long seed; + private final int padding; + private final Random random = new Random(); + private final byte[] padBytes; + private final byte[] randomBytes; + + @JsonCreator + public UniformRandomPayloadGenerator(@JsonProperty("size") int size, + @JsonProperty("seed") long seed, + @JsonProperty("padding") int padding) { + this.size = size; + this.seed = seed; + this.padding = padding; + if (padding < 0 || padding > size) { + throw new RuntimeException("Invalid value " + padding + " for " + + "padding: the number of padding bytes must not be smaller than " + + "0 or greater than the total payload size."); + } + this.padBytes = new byte[padding]; + random.setSeed(seed); + random.nextBytes(padBytes); + this.randomBytes = new byte[size - padding]; + } + + @JsonProperty + public int size() { + return size; + } + + @JsonProperty + public long seed() { + return seed; + } + + @JsonProperty + public int padding() { + return padding; + } + + @Override + public synchronized byte[] generate(long position) { + byte[] result = new byte[size]; + if (randomBytes.length > 0) { + random.setSeed(seed + position); + random.nextBytes(randomBytes); + System.arraycopy(randomBytes, 0, result, 0, Math.min(randomBytes.length, result.length)); + } + if (padBytes.length > 0) { + System.arraycopy(padBytes, 0, result, randomBytes.length, result.length - randomBytes.length); + } + return result; + } +} diff --git a/trogdor/src/main/java/org/apache/kafka/trogdor/workload/UniformTransactionsGenerator.java b/trogdor/src/main/java/org/apache/kafka/trogdor/workload/UniformTransactionsGenerator.java new file mode 100644 index 0000000..1fbfbc2 --- /dev/null +++ b/trogdor/src/main/java/org/apache/kafka/trogdor/workload/UniformTransactionsGenerator.java @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.trogdor.workload; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +/** + * A uniform transactions generator where every N records are grouped in a separate transaction + */ +public class UniformTransactionsGenerator implements TransactionGenerator { + + private final int messagesPerTransaction; + private int messagesInTransaction = -1; + + @JsonCreator + public UniformTransactionsGenerator(@JsonProperty("messagesPerTransaction") int messagesPerTransaction) { + if (messagesPerTransaction < 1) + throw new IllegalArgumentException("Cannot have less than one message per transaction."); + + this.messagesPerTransaction = messagesPerTransaction; + } + + @JsonProperty + public int messagesPerTransaction() { + return messagesPerTransaction; + } + + @Override + public synchronized TransactionAction nextAction() { + if (messagesInTransaction == -1) { + messagesInTransaction = 0; + return TransactionAction.BEGIN_TRANSACTION; + } + if (messagesInTransaction == messagesPerTransaction) { + messagesInTransaction = -1; + return TransactionAction.COMMIT_TRANSACTION; + } + + messagesInTransaction += 1; + return TransactionAction.NO_OP; + } +} diff --git a/trogdor/src/test/java/org/apache/kafka/trogdor/agent/AgentTest.java b/trogdor/src/test/java/org/apache/kafka/trogdor/agent/AgentTest.java new file mode 100644 index 0000000..76c39b4 --- /dev/null +++ b/trogdor/src/test/java/org/apache/kafka/trogdor/agent/AgentTest.java @@ -0,0 +1,490 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.trogdor.agent; + +import static java.util.Arrays.asList; +import static java.util.concurrent.TimeUnit.MILLISECONDS; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +import com.fasterxml.jackson.databind.node.TextNode; +import org.apache.kafka.common.utils.MockScheduler; +import org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.common.utils.Scheduler; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.test.TestUtils; +import org.apache.kafka.trogdor.basic.BasicNode; +import org.apache.kafka.trogdor.basic.BasicPlatform; +import org.apache.kafka.trogdor.basic.BasicTopology; +import org.apache.kafka.trogdor.common.ExpectedTasks; +import org.apache.kafka.trogdor.common.ExpectedTasks.ExpectedTaskBuilder; +import org.apache.kafka.trogdor.common.JsonUtil; +import org.apache.kafka.trogdor.common.Node; +import org.apache.kafka.trogdor.common.Platform; +import org.apache.kafka.trogdor.fault.FilesUnreadableFaultSpec; +import org.apache.kafka.trogdor.fault.Kibosh; +import org.apache.kafka.trogdor.fault.Kibosh.KiboshControlFile; +import org.apache.kafka.trogdor.fault.Kibosh.KiboshFilesUnreadableFaultSpec; +import org.apache.kafka.trogdor.rest.AgentStatusResponse; +import org.apache.kafka.trogdor.rest.CreateWorkerRequest; +import org.apache.kafka.trogdor.rest.DestroyWorkerRequest; +import org.apache.kafka.trogdor.rest.JsonRestServer; +import org.apache.kafka.trogdor.rest.RequestConflictException; +import org.apache.kafka.trogdor.rest.StopWorkerRequest; +import org.apache.kafka.trogdor.rest.TaskDone; +import org.apache.kafka.trogdor.rest.UptimeResponse; +import org.apache.kafka.trogdor.rest.WorkerDone; +import org.apache.kafka.trogdor.rest.WorkerRunning; +import org.apache.kafka.trogdor.task.NoOpTaskSpec; +import org.apache.kafka.trogdor.task.SampleTaskSpec; +import org.apache.kafka.trogdor.task.TaskSpec; +import org.junit.jupiter.api.Test; + +import java.io.ByteArrayOutputStream; +import java.io.File; +import java.io.IOException; +import java.io.PrintStream; +import java.nio.charset.StandardCharsets; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.Collections; +import java.util.HashMap; +import java.util.TreeMap; +import org.junit.jupiter.api.Timeout; + +@Timeout(value = 120000, unit = MILLISECONDS) +public class AgentTest { + + private static BasicPlatform createBasicPlatform(Scheduler scheduler) { + TreeMap nodes = new TreeMap<>(); + HashMap config = new HashMap<>(); + config.put(Platform.Config.TROGDOR_AGENT_PORT, Integer.toString(Agent.DEFAULT_PORT)); + nodes.put("node01", new BasicNode("node01", "localhost", + config, Collections.emptySet())); + BasicTopology topology = new BasicTopology(nodes); + return new BasicPlatform("node01", topology, + scheduler, new BasicPlatform.ShellCommandRunner()); + } + + private Agent createAgent(Scheduler scheduler) { + JsonRestServer restServer = new JsonRestServer(0); + AgentRestResource resource = new AgentRestResource(); + restServer.start(resource); + return new Agent(createBasicPlatform(scheduler), scheduler, + restServer, resource); + } + + @Test + public void testAgentStartShutdown() throws Exception { + Agent agent = createAgent(Scheduler.SYSTEM); + agent.beginShutdown(); + agent.waitForShutdown(); + } + + @Test + public void testAgentProgrammaticShutdown() throws Exception { + Agent agent = createAgent(Scheduler.SYSTEM); + 
AgentClient client = new AgentClient.Builder(). + maxTries(10).target("localhost", agent.port()).build(); + client.invokeShutdown(); + agent.waitForShutdown(); + } + + @Test + public void testAgentGetStatus() throws Exception { + Agent agent = createAgent(Scheduler.SYSTEM); + AgentClient client = new AgentClient.Builder(). + maxTries(10).target("localhost", agent.port()).build(); + AgentStatusResponse status = client.status(); + assertEquals(agent.status(), status); + agent.beginShutdown(); + agent.waitForShutdown(); + } + + @Test + public void testCreateExpiredWorkerIsNotScheduled() throws Exception { + long initialTimeMs = 100; + long tickMs = 15; + final boolean[] toSleep = {true}; + MockTime time = new MockTime(tickMs, initialTimeMs, 0) { + /** + * Modify sleep() to call super.sleep() every second call + * in order to avoid the endless loop in the tick() calls to the MockScheduler listener + */ + @Override + public void sleep(long ms) { + toSleep[0] = !toSleep[0]; + if (toSleep[0]) + super.sleep(ms); + } + }; + MockScheduler scheduler = new MockScheduler(time); + Agent agent = createAgent(scheduler); + AgentClient client = new AgentClient.Builder(). + maxTries(10).target("localhost", agent.port()).build(); + AgentStatusResponse status = client.status(); + + assertEquals(Collections.emptyMap(), status.workers()); + new ExpectedTasks().waitFor(client); + + final NoOpTaskSpec fooSpec = new NoOpTaskSpec(10, 10); + client.createWorker(new CreateWorkerRequest(0, "foo", fooSpec)); + long actualStartTimeMs = initialTimeMs + tickMs; + long doneMs = actualStartTimeMs + 2 * tickMs; + new ExpectedTasks().addTask(new ExpectedTaskBuilder("foo"). + workerState(new WorkerDone("foo", fooSpec, actualStartTimeMs, + doneMs, null, "worker expired")). + taskState(new TaskDone(fooSpec, actualStartTimeMs, doneMs, "worker expired", false, null)). + build()). + waitFor(client); + } + + @Test + public void testAgentGetUptime() throws Exception { + MockTime time = new MockTime(0, 111, 0); + MockScheduler scheduler = new MockScheduler(time); + Agent agent = createAgent(scheduler); + AgentClient client = new AgentClient.Builder(). + maxTries(10).target("localhost", agent.port()).build(); + + UptimeResponse uptime = client.uptime(); + assertEquals(agent.uptime(), uptime); + + time.setCurrentTimeMs(150); + assertNotEquals(agent.uptime(), uptime); + agent.beginShutdown(); + agent.waitForShutdown(); + } + + @Test + public void testAgentCreateWorkers() throws Exception { + MockTime time = new MockTime(0, 0, 0); + MockScheduler scheduler = new MockScheduler(time); + Agent agent = createAgent(scheduler); + AgentClient client = new AgentClient.Builder(). + maxTries(10).target("localhost", agent.port()).build(); + AgentStatusResponse status = client.status(); + assertEquals(Collections.emptyMap(), status.workers()); + new ExpectedTasks().waitFor(client); + + final NoOpTaskSpec fooSpec = new NoOpTaskSpec(1000, 600000); + client.createWorker(new CreateWorkerRequest(0, "foo", fooSpec)); + new ExpectedTasks().addTask(new ExpectedTaskBuilder("foo"). + workerState(new WorkerRunning("foo", fooSpec, 0, new TextNode("active"))). + build()). 
+ waitFor(client); + + final NoOpTaskSpec barSpec = new NoOpTaskSpec(2000, 900000); + client.createWorker(new CreateWorkerRequest(1, "bar", barSpec)); + client.createWorker(new CreateWorkerRequest(1, "bar", barSpec)); + + assertThrows(RequestConflictException.class, + () -> client.createWorker(new CreateWorkerRequest(1, "foo", barSpec)), + "Recreating a request with a different taskId is not allowed"); + assertThrows(RequestConflictException.class, + () -> client.createWorker(new CreateWorkerRequest(1, "bar", fooSpec)), + "Recreating a request with a different spec is not allowed"); + + new ExpectedTasks(). + addTask(new ExpectedTaskBuilder("foo"). + workerState(new WorkerRunning("foo", fooSpec, 0, new TextNode("active"))). + build()). + addTask(new ExpectedTaskBuilder("bar"). + workerState(new WorkerRunning("bar", barSpec, 0, new TextNode("active"))). + build()). + waitFor(client); + + final NoOpTaskSpec bazSpec = new NoOpTaskSpec(1, 450000); + client.createWorker(new CreateWorkerRequest(2, "baz", bazSpec)); + new ExpectedTasks(). + addTask(new ExpectedTaskBuilder("foo"). + workerState(new WorkerRunning("foo", fooSpec, 0, new TextNode("active"))). + build()). + addTask(new ExpectedTaskBuilder("bar"). + workerState(new WorkerRunning("bar", barSpec, 0, new TextNode("active"))). + build()). + addTask(new ExpectedTaskBuilder("baz"). + workerState(new WorkerRunning("baz", bazSpec, 0, new TextNode("active"))). + build()). + waitFor(client); + + agent.beginShutdown(); + agent.waitForShutdown(); + } + + @Test + public void testAgentFinishesTasks() throws Exception { + long startTimeMs = 2000; + MockTime time = new MockTime(0, startTimeMs, 0); + MockScheduler scheduler = new MockScheduler(time); + Agent agent = createAgent(scheduler); + AgentClient client = new AgentClient.Builder(). + maxTries(10).target("localhost", agent.port()).build(); + new ExpectedTasks().waitFor(client); + + final NoOpTaskSpec fooSpec = new NoOpTaskSpec(startTimeMs, 2); + long fooSpecStartTimeMs = startTimeMs; + client.createWorker(new CreateWorkerRequest(0, "foo", fooSpec)); + new ExpectedTasks(). + addTask(new ExpectedTaskBuilder("foo"). + workerState(new WorkerRunning("foo", fooSpec, startTimeMs, new TextNode("active"))). + build()). + waitFor(client); + + time.sleep(1); + + long barSpecWorkerId = 1; + long barSpecStartTimeMs = startTimeMs + 1; + final NoOpTaskSpec barSpec = new NoOpTaskSpec(startTimeMs, 900000); + client.createWorker(new CreateWorkerRequest(barSpecWorkerId, "bar", barSpec)); + new ExpectedTasks(). + addTask(new ExpectedTaskBuilder("foo"). + workerState(new WorkerRunning("foo", fooSpec, fooSpecStartTimeMs, new TextNode("active"))). + build()). + addTask(new ExpectedTaskBuilder("bar"). + workerState(new WorkerRunning("bar", barSpec, barSpecStartTimeMs, new TextNode("active"))). + build()). + waitFor(client); + + time.sleep(1); + + // foo task expired + new ExpectedTasks(). + addTask(new ExpectedTaskBuilder("foo"). + workerState(new WorkerDone("foo", fooSpec, fooSpecStartTimeMs, fooSpecStartTimeMs + 2, new TextNode("done"), "")). + build()). + addTask(new ExpectedTaskBuilder("bar"). + workerState(new WorkerRunning("bar", barSpec, barSpecStartTimeMs, new TextNode("active"))). + build()). + waitFor(client); + + time.sleep(5); + client.stopWorker(new StopWorkerRequest(barSpecWorkerId)); + new ExpectedTasks(). + addTask(new ExpectedTaskBuilder("foo"). + workerState(new WorkerDone("foo", fooSpec, fooSpecStartTimeMs, fooSpecStartTimeMs + 2, new TextNode("done"), "")). + build()). 
+ addTask(new ExpectedTaskBuilder("bar"). + workerState(new WorkerDone("bar", barSpec, barSpecStartTimeMs, startTimeMs + 7, new TextNode("done"), "")). + build()). + waitFor(client); + + agent.beginShutdown(); + agent.waitForShutdown(); + } + + @Test + public void testWorkerCompletions() throws Exception { + MockTime time = new MockTime(0, 0, 0); + MockScheduler scheduler = new MockScheduler(time); + Agent agent = createAgent(scheduler); + AgentClient client = new AgentClient.Builder(). + maxTries(10).target("localhost", agent.port()).build(); + new ExpectedTasks().waitFor(client); + + SampleTaskSpec fooSpec = new SampleTaskSpec(0, 900000, + Collections.singletonMap("node01", 1L), ""); + client.createWorker(new CreateWorkerRequest(0, "foo", fooSpec)); + new ExpectedTasks(). + addTask(new ExpectedTaskBuilder("foo"). + workerState(new WorkerRunning("foo", fooSpec, 0, new TextNode("active"))). + build()). + waitFor(client); + + SampleTaskSpec barSpec = new SampleTaskSpec(0, 900000, + Collections.singletonMap("node01", 2L), "baz"); + client.createWorker(new CreateWorkerRequest(1, "bar", barSpec)); + + time.sleep(1); + new ExpectedTasks(). + addTask(new ExpectedTaskBuilder("foo"). + workerState(new WorkerDone("foo", fooSpec, 0, 1, + new TextNode("halted"), "")). + build()). + addTask(new ExpectedTaskBuilder("bar"). + workerState(new WorkerRunning("bar", barSpec, 0, + new TextNode("active"))). + build()). + waitFor(client); + + time.sleep(1); + new ExpectedTasks(). + addTask(new ExpectedTaskBuilder("foo"). + workerState(new WorkerDone("foo", fooSpec, 0, 1, + new TextNode("halted"), "")). + build()). + addTask(new ExpectedTaskBuilder("bar"). + workerState(new WorkerDone("bar", barSpec, 0, 2, + new TextNode("halted"), "baz")). + build()). + waitFor(client); + } + + private static class MockKibosh implements AutoCloseable { + private final File tempDir; + private final Path controlFile; + + MockKibosh() throws IOException { + tempDir = TestUtils.tempDirectory(); + controlFile = Paths.get(tempDir.toPath().toString(), Kibosh.KIBOSH_CONTROL); + KiboshControlFile.EMPTY.write(controlFile); + } + + KiboshControlFile read() throws IOException { + return KiboshControlFile.read(controlFile); + } + + @Override + public void close() throws Exception { + Utils.delete(tempDir); + } + } + + @Test + public void testKiboshFaults() throws Exception { + MockTime time = new MockTime(0, 0, 0); + MockScheduler scheduler = new MockScheduler(time); + Agent agent = createAgent(scheduler); + AgentClient client = new AgentClient.Builder(). + maxTries(10).target("localhost", agent.port()).build(); + new ExpectedTasks().waitFor(client); + + try (MockKibosh mockKibosh = new MockKibosh()) { + assertEquals(KiboshControlFile.EMPTY, mockKibosh.read()); + FilesUnreadableFaultSpec fooSpec = new FilesUnreadableFaultSpec(0, 900000, + Collections.singleton("myAgent"), mockKibosh.tempDir.getPath(), "/foo", 123); + client.createWorker(new CreateWorkerRequest(0, "foo", fooSpec)); + new ExpectedTasks(). + addTask(new ExpectedTaskBuilder("foo"). + workerState(new WorkerRunning("foo", fooSpec, 0, new TextNode("Added fault foo"))). + build()). + waitFor(client); + assertEquals(new KiboshControlFile(Collections.singletonList( + new KiboshFilesUnreadableFaultSpec("/foo", 123))), mockKibosh.read()); + FilesUnreadableFaultSpec barSpec = new FilesUnreadableFaultSpec(0, 900000, + Collections.singleton("myAgent"), mockKibosh.tempDir.getPath(), "/bar", 456); + client.createWorker(new CreateWorkerRequest(1, "bar", barSpec)); + new ExpectedTasks(). 
+ addTask(new ExpectedTaskBuilder("foo"). + workerState(new WorkerRunning("foo", fooSpec, 0, new TextNode("Added fault foo"))).build()). + addTask(new ExpectedTaskBuilder("bar"). + workerState(new WorkerRunning("bar", barSpec, 0, new TextNode("Added fault bar"))).build()). + waitFor(client); + assertEquals(new KiboshControlFile(asList( + new KiboshFilesUnreadableFaultSpec("/foo", 123), + new KiboshFilesUnreadableFaultSpec("/bar", 456)) + ), mockKibosh.read()); + time.sleep(1); + client.stopWorker(new StopWorkerRequest(0)); + new ExpectedTasks(). + addTask(new ExpectedTaskBuilder("foo"). + workerState(new WorkerDone("foo", fooSpec, 0, 1, new TextNode("Removed fault foo"), "")).build()). + addTask(new ExpectedTaskBuilder("bar"). + workerState(new WorkerRunning("bar", barSpec, 0, new TextNode("Added fault bar"))).build()). + waitFor(client); + assertEquals(new KiboshControlFile(Collections.singletonList( + new KiboshFilesUnreadableFaultSpec("/bar", 456))), mockKibosh.read()); + } + } + + @Test + public void testDestroyWorkers() throws Exception { + MockTime time = new MockTime(0, 0, 0); + MockScheduler scheduler = new MockScheduler(time); + Agent agent = createAgent(scheduler); + AgentClient client = new AgentClient.Builder(). + maxTries(10).target("localhost", agent.port()).build(); + new ExpectedTasks().waitFor(client); + + final NoOpTaskSpec fooSpec = new NoOpTaskSpec(0, 5); + client.createWorker(new CreateWorkerRequest(0, "foo", fooSpec)); + new ExpectedTasks(). + addTask(new ExpectedTaskBuilder("foo"). + workerState(new WorkerRunning("foo", fooSpec, 0, new TextNode("active"))). + build()). + waitFor(client); + time.sleep(1); + + client.destroyWorker(new DestroyWorkerRequest(0)); + client.destroyWorker(new DestroyWorkerRequest(0)); + client.destroyWorker(new DestroyWorkerRequest(1)); + new ExpectedTasks().waitFor(client); + time.sleep(1); + + final NoOpTaskSpec fooSpec2 = new NoOpTaskSpec(2, 1); + client.createWorker(new CreateWorkerRequest(1, "foo", fooSpec2)); + new ExpectedTasks(). + addTask(new ExpectedTaskBuilder("foo"). + workerState(new WorkerRunning("foo", fooSpec2, 2, new TextNode("active"))). + build()). + waitFor(client); + + time.sleep(2); + new ExpectedTasks(). + addTask(new ExpectedTaskBuilder("foo"). + workerState(new WorkerDone("foo", fooSpec2, 2, 4, new TextNode("done"), "")). + build()). 
+ waitFor(client); + + time.sleep(1); + client.destroyWorker(new DestroyWorkerRequest(1)); + new ExpectedTasks().waitFor(client); + + agent.beginShutdown(); + agent.waitForShutdown(); + } + + static void testExec(Agent agent, String expected, boolean expectedReturn, TaskSpec spec) throws Exception { + ByteArrayOutputStream b = new ByteArrayOutputStream(); + PrintStream p = new PrintStream(b, true, StandardCharsets.UTF_8.toString()); + boolean actualReturn = agent.exec(spec, p); + assertEquals(expected, b.toString()); + assertEquals(expectedReturn, actualReturn); + } + + @Test + public void testAgentExecWithTimeout() throws Exception { + Agent agent = createAgent(Scheduler.SYSTEM); + NoOpTaskSpec spec = new NoOpTaskSpec(0, 1); + TaskSpec rebasedSpec = agent.rebaseTaskSpecTime(spec); + testExec(agent, + String.format("Waiting for completion of task:%s%n", + JsonUtil.toPrettyJsonString(rebasedSpec)) + + String.format("Task failed with status null and error worker expired%n"), + false, rebasedSpec); + agent.beginShutdown(); + agent.waitForShutdown(); + } + + @Test + public void testAgentExecWithNormalExit() throws Exception { + Agent agent = createAgent(Scheduler.SYSTEM); + SampleTaskSpec spec = new SampleTaskSpec(0, 120000, + Collections.singletonMap("node01", 1L), ""); + TaskSpec rebasedSpec = agent.rebaseTaskSpecTime(spec); + testExec(agent, + String.format("Waiting for completion of task:%s%n", + JsonUtil.toPrettyJsonString(rebasedSpec)) + + String.format("Task succeeded with status \"halted\"%n"), + true, rebasedSpec); + agent.beginShutdown(); + agent.waitForShutdown(); + } + +}; diff --git a/trogdor/src/test/java/org/apache/kafka/trogdor/basic/BasicPlatformTest.java b/trogdor/src/test/java/org/apache/kafka/trogdor/basic/BasicPlatformTest.java new file mode 100644 index 0000000..3ed413d --- /dev/null +++ b/trogdor/src/test/java/org/apache/kafka/trogdor/basic/BasicPlatformTest.java @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.trogdor.basic; + +import static java.util.concurrent.TimeUnit.MILLISECONDS; +import static org.junit.jupiter.api.Assertions.assertEquals; + +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.test.TestUtils; +import org.apache.kafka.trogdor.common.Platform; + +import org.junit.jupiter.api.Test; + +import java.io.File; +import java.io.OutputStreamWriter; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import org.junit.jupiter.api.Timeout; + +@Timeout(value = 120000, unit = MILLISECONDS) +public class BasicPlatformTest { + + @Test + public void testCreateBasicPlatform() throws Exception { + File configFile = TestUtils.tempFile(); + try { + try (OutputStreamWriter writer = new OutputStreamWriter(Files.newOutputStream(configFile.toPath()), + StandardCharsets.UTF_8)) { + writer.write("{\n"); + writer.write(" \"platform\": \"org.apache.kafka.trogdor.basic.BasicPlatform\",\n"); + writer.write(" \"nodes\": {\n"); + writer.write(" \"bob01\": {\n"); + writer.write(" \"hostname\": \"localhost\",\n"); + writer.write(" \"trogdor.agent.port\": 8888\n"); + writer.write(" },\n"); + writer.write(" \"bob02\": {\n"); + writer.write(" \"hostname\": \"localhost\",\n"); + writer.write(" \"trogdor.agent.port\": 8889\n"); + writer.write(" }\n"); + writer.write(" }\n"); + writer.write("}\n"); + } + Platform platform = Platform.Config.parse("bob01", configFile.getPath()); + assertEquals("BasicPlatform", platform.name()); + assertEquals(2, platform.topology().nodes().size()); + assertEquals("bob01, bob02", Utils.join(platform.topology().nodes().keySet(), ", ")); + } finally { + Files.delete(configFile.toPath()); + } + } +}; diff --git a/trogdor/src/test/java/org/apache/kafka/trogdor/common/CapturingCommandRunner.java b/trogdor/src/test/java/org/apache/kafka/trogdor/common/CapturingCommandRunner.java new file mode 100644 index 0000000..2e5a660 --- /dev/null +++ b/trogdor/src/test/java/org/apache/kafka/trogdor/common/CapturingCommandRunner.java @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.kafka.trogdor.common;
+
+import org.apache.kafka.common.utils.Utils;
+import org.apache.kafka.trogdor.basic.BasicPlatform;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+
+public class CapturingCommandRunner implements BasicPlatform.CommandRunner {
+    private static final Logger log = LoggerFactory.getLogger(CapturingCommandRunner.class);
+
+    private final Map<String, List<String>> commands = new HashMap<>();
+
+    private synchronized List<String> getOrCreate(String nodeName) {
+        List<String> lines = commands.get(nodeName);
+        if (lines != null) {
+            return lines;
+        }
+        lines = new LinkedList<>();
+        commands.put(nodeName, lines);
+        return lines;
+    }
+
+    @Override
+    public String run(Node curNode, String[] command) throws IOException {
+        String line = Utils.join(command, " ");
+        synchronized (this) {
+            getOrCreate(curNode.name()).add(line);
+        }
+        log.debug("RAN {}: {}", curNode, Utils.join(command, " "));
+        return "";
+    }
+
+    public synchronized List<String> lines(String nodeName) {
+        return new ArrayList<>(getOrCreate(nodeName));
+    }
+}
diff --git a/trogdor/src/test/java/org/apache/kafka/trogdor/common/ExpectedTasks.java b/trogdor/src/test/java/org/apache/kafka/trogdor/common/ExpectedTasks.java
new file mode 100644
index 0000000..3eb781c
--- /dev/null
+++ b/trogdor/src/test/java/org/apache/kafka/trogdor/common/ExpectedTasks.java
@@ -0,0 +1,206 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +package org.apache.kafka.trogdor.common; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import org.apache.kafka.test.TestUtils; +import org.apache.kafka.trogdor.agent.AgentClient; +import org.apache.kafka.trogdor.coordinator.CoordinatorClient; +import org.apache.kafka.trogdor.rest.AgentStatusResponse; +import org.apache.kafka.trogdor.rest.TaskState; +import org.apache.kafka.trogdor.rest.TasksRequest; +import org.apache.kafka.trogdor.rest.TasksResponse; +import org.apache.kafka.trogdor.rest.WorkerState; +import org.apache.kafka.trogdor.task.TaskSpec; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.HashMap; +import java.util.Map; +import java.util.Optional; +import java.util.TreeMap; + +public class ExpectedTasks { + private static final Logger log = LoggerFactory.getLogger(ExpectedTasks.class); + + private final TreeMap expected = new TreeMap<>(); + + public static class ExpectedTaskBuilder { + private final String id; + private TaskSpec taskSpec = null; + private TaskState taskState = null; + private WorkerState workerState = null; + + public ExpectedTaskBuilder(String id) { + this.id = id; + } + + public ExpectedTaskBuilder taskSpec(TaskSpec taskSpec) { + this.taskSpec = taskSpec; + return this; + } + + public ExpectedTaskBuilder taskState(TaskState taskState) { + this.taskState = taskState; + return this; + } + + public ExpectedTaskBuilder workerState(WorkerState workerState) { + this.workerState = workerState; + return this; + } + + public ExpectedTask build() { + return new ExpectedTask(id, taskSpec, taskState, workerState); + } + } + + public static class ExpectedTask { + private final String id; + private final TaskSpec taskSpec; + private final TaskState taskState; + private final WorkerState workerState; + + @JsonCreator + private ExpectedTask(@JsonProperty("id") String id, + @JsonProperty("taskSpec") TaskSpec taskSpec, + @JsonProperty("taskState") TaskState taskState, + @JsonProperty("workerState") WorkerState workerState) { + this.id = id; + this.taskSpec = taskSpec; + this.taskState = taskState; + this.workerState = workerState; + } + + String compare(TaskState actual) { + if (actual == null) { + return "Did not find task " + id + "\n"; + } + if ((taskSpec != null) && (!actual.spec().equals(taskSpec))) { + return "Invalid spec for task " + id + ": expected " + taskSpec + + ", got " + actual.spec(); + } + if ((taskState != null) && (!actual.equals(taskState))) { + return "Invalid state for task " + id + ": expected " + taskState + + ", got " + actual; + } + return null; + } + + String compare(WorkerState actual) { + if ((workerState != null) && (!workerState.equals(actual))) { + if (actual == null) { + return "Did not find worker " + id + "\n"; + } + return "Invalid state for task " + id + ": expected " + workerState + + ", got " + actual; + } + return null; + } + + @JsonProperty + public String id() { + return id; + } + + @JsonProperty + public TaskSpec taskSpec() { + return taskSpec; + } + + @JsonProperty + public TaskState taskState() { + return taskState; + } + + @JsonProperty + public WorkerState workerState() { + return workerState; + } + } + + public ExpectedTasks addTask(ExpectedTask task) { + expected.put(task.id, task); + return this; + } + + public ExpectedTasks waitFor(final CoordinatorClient client) throws InterruptedException { + TestUtils.waitForCondition(() -> { + TasksResponse tasks = null; + try { + tasks = client.tasks(new TasksRequest(null, 0, 0, 0, 0, 
                    Optional.empty()));
+            } catch (Exception e) {
+                log.info("Unable to get coordinator tasks", e);
+                throw new RuntimeException(e);
+            }
+            StringBuilder errors = new StringBuilder();
+            for (Map.Entry<String, ExpectedTask> entry : expected.entrySet()) {
+                String id = entry.getKey();
+                ExpectedTask task = entry.getValue();
+                String differences = task.compare(tasks.tasks().get(id));
+                if (differences != null) {
+                    errors.append(differences);
+                }
+            }
+            String errorString = errors.toString();
+            if (!errorString.isEmpty()) {
+                log.info("EXPECTED TASKS: {}", JsonUtil.toJsonString(expected));
+                log.info("ACTUAL TASKS  : {}", JsonUtil.toJsonString(tasks.tasks()));
+                log.info(errorString);
+                return false;
+            }
+            return true;
+        }, "Timed out waiting for expected tasks " + JsonUtil.toJsonString(expected));
+        return this;
+    }
+
+    public ExpectedTasks waitFor(final AgentClient client) throws InterruptedException {
+        TestUtils.waitForCondition(() -> {
+            AgentStatusResponse status = null;
+            try {
+                status = client.status();
+            } catch (Exception e) {
+                log.info("Unable to get agent status", e);
+                throw new RuntimeException(e);
+            }
+            StringBuilder errors = new StringBuilder();
+            HashMap<String, WorkerState> taskIdToWorkerState = new HashMap<>();
+            for (WorkerState state : status.workers().values()) {
+                taskIdToWorkerState.put(state.taskId(), state);
+            }
+            for (Map.Entry<String, ExpectedTask> entry : expected.entrySet()) {
+                String id = entry.getKey();
+                ExpectedTask worker = entry.getValue();
+                String differences = worker.compare(taskIdToWorkerState.get(id));
+                if (differences != null) {
+                    errors.append(differences);
+                }
+            }
+            String errorString = errors.toString();
+            if (!errorString.isEmpty()) {
+                log.info("EXPECTED WORKERS: {}", JsonUtil.toJsonString(expected));
+                log.info("ACTUAL WORKERS  : {}", JsonUtil.toJsonString(status.workers()));
+                log.info(errorString);
+                return false;
+            }
+            return true;
+        }, "Timed out waiting for expected workers " + JsonUtil.toJsonString(expected));
+        return this;
+    }
+};
diff --git a/trogdor/src/test/java/org/apache/kafka/trogdor/common/JsonSerializationTest.java b/trogdor/src/test/java/org/apache/kafka/trogdor/common/JsonSerializationTest.java
new file mode 100644
index 0000000..8e53516
--- /dev/null
+++ b/trogdor/src/test/java/org/apache/kafka/trogdor/common/JsonSerializationTest.java
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kafka.trogdor.common;
+
+import static org.junit.jupiter.api.Assertions.assertNotNull;
+
+import org.apache.kafka.trogdor.fault.FilesUnreadableFaultSpec;
+import org.apache.kafka.trogdor.fault.Kibosh;
+import org.apache.kafka.trogdor.fault.NetworkPartitionFaultSpec;
+import org.apache.kafka.trogdor.fault.ProcessStopFaultSpec;
+import org.apache.kafka.trogdor.rest.AgentStatusResponse;
+import org.apache.kafka.trogdor.rest.TasksResponse;
+import org.apache.kafka.trogdor.rest.WorkerDone;
+import org.apache.kafka.trogdor.rest.WorkerRunning;
+import org.apache.kafka.trogdor.rest.WorkerStopping;
+import org.apache.kafka.trogdor.workload.PartitionsSpec;
+import org.apache.kafka.trogdor.workload.ProduceBenchSpec;
+import org.apache.kafka.trogdor.workload.RoundTripWorkloadSpec;
+import org.apache.kafka.trogdor.workload.TopicsSpec;
+
+import java.lang.reflect.Field;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import org.junit.jupiter.api.Test;
+
+public class JsonSerializationTest {
+    @Test
+    public void testDeserializationDoesNotProduceNulls() throws Exception {
+        verify(new FilesUnreadableFaultSpec(0, 0, null,
+            null, null, 0));
+        verify(new Kibosh.KiboshControlFile(null));
+        verify(new NetworkPartitionFaultSpec(0, 0, null));
+        verify(new ProcessStopFaultSpec(0, 0, null, null));
+        verify(new AgentStatusResponse(0, null));
+        verify(new TasksResponse(null));
+        verify(new WorkerDone(null, null, 0, 0, null, null));
+        verify(new WorkerRunning(null, null, 0, null));
+        verify(new WorkerStopping(null, null, 0, null));
+        verify(new ProduceBenchSpec(0, 0, null, null,
+            0, 0, null, null, Optional.empty(), null, null, null, null, null, false, false));
+        verify(new RoundTripWorkloadSpec(0, 0, null, null, null, null, null, null,
+            0, null, null, 0));
+        verify(new TopicsSpec());
+        verify(new PartitionsSpec(0, (short) 0, null, null));
+        Map<Integer, List<Integer>> partitionAssignments = new HashMap<>();
+        partitionAssignments.put(0, Arrays.asList(1, 2, 3));
+        partitionAssignments.put(1, Arrays.asList(1, 2, 3));
+        verify(new PartitionsSpec(0, (short) 0, partitionAssignments, null));
+        verify(new PartitionsSpec(0, (short) 0, null, null));
+    }
+
+    private <T> void verify(T val1) throws Exception {
+        byte[] bytes = JsonUtil.JSON_SERDE.writeValueAsBytes(val1);
+        @SuppressWarnings("unchecked")
+        Class<T> clazz = (Class<T>) val1.getClass();
+        T val2 = JsonUtil.JSON_SERDE.readValue(bytes, clazz);
+        for (Field field : clazz.getDeclaredFields()) {
+            boolean wasAccessible = field.isAccessible();
+            field.setAccessible(true);
+            assertNotNull(field.get(val2), "Field " + field + " was null.");
+            field.setAccessible(wasAccessible);
+        }
+    }
+};
diff --git a/trogdor/src/test/java/org/apache/kafka/trogdor/common/JsonUtilTest.java b/trogdor/src/test/java/org/apache/kafka/trogdor/common/JsonUtilTest.java
new file mode 100644
index 0000000..9e5c44a
--- /dev/null
+++ b/trogdor/src/test/java/org/apache/kafka/trogdor/common/JsonUtilTest.java
@@ -0,0 +1,70 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.trogdor.common; + +import static java.util.concurrent.TimeUnit.MILLISECONDS; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import org.apache.kafka.test.TestUtils; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; + +import java.io.File; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; + +@Timeout(value = 120000, unit = MILLISECONDS) +public class JsonUtilTest { + + @Test + public void testOpenBraceComesFirst() { + assertTrue(JsonUtil.openBraceComesFirst("{}")); + assertTrue(JsonUtil.openBraceComesFirst(" \t{\"foo\":\"bar\"}")); + assertTrue(JsonUtil.openBraceComesFirst(" { \"foo\": \"bar\" }")); + assertFalse(JsonUtil.openBraceComesFirst("/my/file/path")); + assertFalse(JsonUtil.openBraceComesFirst("mypath")); + assertFalse(JsonUtil.openBraceComesFirst(" blah{}")); + } + + static final class Foo { + @JsonProperty + final int bar; + + @JsonCreator + Foo(@JsonProperty("bar") int bar) { + this.bar = bar; + } + } + + @Test + public void testObjectFromCommandLineArgument() throws Exception { + assertEquals(123, JsonUtil.objectFromCommandLineArgument("{\"bar\":123}", Foo.class).bar); + assertEquals(1, JsonUtil.objectFromCommandLineArgument(" {\"bar\": 1} ", Foo.class).bar); + File tempFile = TestUtils.tempFile(); + try { + Files.write(tempFile.toPath(), "{\"bar\": 456}".getBytes(StandardCharsets.UTF_8)); + assertEquals(456, JsonUtil.objectFromCommandLineArgument(tempFile.getAbsolutePath(), Foo.class).bar); + } finally { + Files.delete(tempFile.toPath()); + } + } +}; diff --git a/trogdor/src/test/java/org/apache/kafka/trogdor/common/MiniTrogdorCluster.java b/trogdor/src/test/java/org/apache/kafka/trogdor/common/MiniTrogdorCluster.java new file mode 100644 index 0000000..1e84d60 --- /dev/null +++ b/trogdor/src/test/java/org/apache/kafka/trogdor/common/MiniTrogdorCluster.java @@ -0,0 +1,292 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.trogdor.common; + +import org.apache.kafka.common.utils.Scheduler; +import org.apache.kafka.common.utils.ThreadUtils; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.trogdor.agent.Agent; +import org.apache.kafka.trogdor.agent.AgentClient; +import org.apache.kafka.trogdor.agent.AgentRestResource; +import org.apache.kafka.trogdor.basic.BasicNode; +import org.apache.kafka.trogdor.basic.BasicPlatform; +import org.apache.kafka.trogdor.basic.BasicTopology; +import org.apache.kafka.trogdor.coordinator.Coordinator; + +import org.apache.kafka.trogdor.coordinator.CoordinatorClient; +import org.apache.kafka.trogdor.coordinator.CoordinatorRestResource; +import org.apache.kafka.trogdor.rest.JsonRestServer; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.TreeMap; +import java.util.TreeSet; +import java.util.concurrent.Callable; +import java.util.concurrent.Executors; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; + +/** + * MiniTrogdorCluster sets up a local cluster of Trogdor Agents and Coordinators. + */ +public class MiniTrogdorCluster implements AutoCloseable { + private static final Logger log = LoggerFactory.getLogger(MiniTrogdorCluster.class); + + /** + * The MiniTrogdorCluster#Builder is used to set up a new MiniTrogdorCluster. + */ + public static class Builder { + private final TreeSet agentNames = new TreeSet<>(); + + private String coordinatorName = null; + + private Scheduler scheduler = Scheduler.SYSTEM; + + private BasicPlatform.CommandRunner commandRunner = + new BasicPlatform.ShellCommandRunner(); + + private static class NodeData { + String hostname; + AgentRestResource agentRestResource = null; + JsonRestServer agentRestServer = null; + int agentPort = 0; + + JsonRestServer coordinatorRestServer = null; + int coordinatorPort = 0; + CoordinatorRestResource coordinatorRestResource = null; + + Platform platform = null; + Agent agent = null; + Coordinator coordinator = null; + + BasicNode node = null; + } + + public Builder() { + } + + /** + * Set the timekeeper used by this MiniTrogdorCluster. + */ + public Builder scheduler(Scheduler scheduler) { + this.scheduler = scheduler; + return this; + } + + public Builder commandRunner(BasicPlatform.CommandRunner commandRunner) { + this.commandRunner = commandRunner; + return this; + } + + /** + * Add a new trogdor coordinator node to the cluster. + */ + public Builder addCoordinator(String nodeName) { + if (coordinatorName != null) { + throw new RuntimeException("At most one coordinator is allowed."); + } + coordinatorName = nodeName; + return this; + } + + /** + * Add a new trogdor agent node to the cluster. + */ + public Builder addAgent(String nodeName) { + if (agentNames.contains(nodeName)) { + throw new RuntimeException("There is already an agent on node " + nodeName); + } + agentNames.add(nodeName); + return this; + } + + private NodeData getOrCreate(String nodeName, TreeMap nodes) { + NodeData data = nodes.get(nodeName); + if (data != null) + return data; + data = new NodeData(); + data.hostname = "127.0.0.1"; + nodes.put(nodeName, data); + return data; + } + + /** + * Create the MiniTrogdorCluster. 
+         */
+        public MiniTrogdorCluster build() throws Exception {
+            log.info("Creating MiniTrogdorCluster with agents: {} and coordinator: {}",
+                Utils.join(agentNames, ", "), coordinatorName);
+            TreeMap<String, NodeData> nodes = new TreeMap<>();
+            for (String agentName : agentNames) {
+                NodeData node = getOrCreate(agentName, nodes);
+                node.agentRestResource = new AgentRestResource();
+                node.agentRestServer = new JsonRestServer(0);
+                node.agentRestServer.start(node.agentRestResource);
+                node.agentPort = node.agentRestServer.port();
+            }
+            if (coordinatorName != null) {
+                NodeData node = getOrCreate(coordinatorName, nodes);
+                node.coordinatorRestResource = new CoordinatorRestResource();
+                node.coordinatorRestServer = new JsonRestServer(0);
+                node.coordinatorRestServer.start(node.coordinatorRestResource);
+                node.coordinatorPort = node.coordinatorRestServer.port();
+            }
+            for (Map.Entry<String, NodeData> entry : nodes.entrySet()) {
+                NodeData node = entry.getValue();
+                HashMap<String, String> config = new HashMap<>();
+                if (node.agentPort != 0) {
+                    config.put(Platform.Config.TROGDOR_AGENT_PORT,
+                        Integer.toString(node.agentPort));
+                }
+                if (node.coordinatorPort != 0) {
+                    config.put(Platform.Config.TROGDOR_COORDINATOR_PORT,
+                        Integer.toString(node.coordinatorPort));
+                }
+                node.node = new BasicNode(entry.getKey(), node.hostname, config,
+                    Collections.emptySet());
+            }
+            TreeMap<String, Node> topologyNodes = new TreeMap<>();
+            for (Map.Entry<String, NodeData> entry : nodes.entrySet()) {
+                topologyNodes.put(entry.getKey(), entry.getValue().node);
+            }
+            final BasicTopology topology = new BasicTopology(topologyNodes);
+            ScheduledExecutorService executor = Executors.newScheduledThreadPool(1,
+                ThreadUtils.createThreadFactory("MiniTrogdorClusterStartupThread%d", false));
+            final AtomicReference<Exception> failure = new AtomicReference<>(null);
+            for (final Map.Entry<String, NodeData> entry : nodes.entrySet()) {
+                executor.submit((Callable<Void>) () -> {
+                    String nodeName = entry.getKey();
+                    try {
+                        NodeData node = entry.getValue();
+                        node.platform = new BasicPlatform(nodeName, topology, scheduler, commandRunner);
+                        if (node.agentRestResource != null) {
+                            node.agent = new Agent(node.platform, scheduler, node.agentRestServer,
+                                node.agentRestResource);
+                        }
+                        if (node.coordinatorRestResource != null) {
+                            node.coordinator = new Coordinator(node.platform, scheduler,
+                                node.coordinatorRestServer, node.coordinatorRestResource, 0);
+                        }
+                    } catch (Exception e) {
+                        log.error("Unable to initialize {}", nodeName, e);
+                        failure.compareAndSet(null, e);
+                    }
+                    return null;
+                });
+            }
+            executor.shutdown();
+            executor.awaitTermination(1, TimeUnit.DAYS);
+            Exception failureException = failure.get();
+            if (failureException != null) {
+                throw failureException;
+            }
+
+            TreeMap<String, Agent> agents = new TreeMap<>();
+            Coordinator coordinator = null;
+            for (Map.Entry<String, NodeData> entry : nodes.entrySet()) {
+                NodeData node = entry.getValue();
+                if (node.agent != null) {
+                    agents.put(entry.getKey(), node.agent);
+                }
+                if (node.coordinator != null) {
+                    coordinator = node.coordinator;
+                }
+            }
+            return new MiniTrogdorCluster(scheduler, agents, nodes, coordinator);
+        }
+    }
+
+    private final TreeMap<String, Agent> agents;
+
+    private final TreeMap<String, Builder.NodeData> nodesByAgent;
+
+    private final Coordinator coordinator;
+
+    private final Scheduler scheduler;
+
+    private MiniTrogdorCluster(Scheduler scheduler,
+                               TreeMap<String, Agent> agents,
+                               TreeMap<String, Builder.NodeData> nodesByAgent,
+                               Coordinator coordinator) {
+        this.scheduler = scheduler;
+        this.agents = agents;
+        this.nodesByAgent = nodesByAgent;
+        this.coordinator = coordinator;
+    }
+
+    public TreeMap<String, Agent> agents() {
+        return agents;
+    }
+
+    public Coordinator coordinator() {
+        return coordinator;
+ } + + public CoordinatorClient coordinatorClient() { + if (coordinator == null) { + throw new RuntimeException("No coordinator configured."); + } + return new CoordinatorClient.Builder(). + maxTries(10). + target("localhost", coordinator.port()). + build(); + } + + /** + * Mimic a restart of a Trogdor agent, essentially cleaning out all of its active workers + */ + public void restartAgent(String nodeName) { + if (!agents.containsKey(nodeName)) { + throw new RuntimeException("There is no agent on node " + nodeName); + } + Builder.NodeData node = nodesByAgent.get(nodeName); + agents.put(nodeName, new Agent(node.platform, scheduler, node.agentRestServer, node.agentRestResource)); + } + + public AgentClient agentClient(String nodeName) { + Agent agent = agents.get(nodeName); + if (agent == null) { + throw new RuntimeException("No agent configured on node " + nodeName); + } + return new AgentClient.Builder(). + maxTries(10). + target("localhost", agent.port()). + build(); + } + + @Override + public void close() throws Exception { + log.info("Closing MiniTrogdorCluster."); + if (coordinator != null) { + coordinator.beginShutdown(false); + } + for (Agent agent : agents.values()) { + agent.beginShutdown(); + } + for (Agent agent : agents.values()) { + agent.waitForShutdown(); + } + if (coordinator != null) { + coordinator.waitForShutdown(); + } + } +}; diff --git a/trogdor/src/test/java/org/apache/kafka/trogdor/common/StringExpanderTest.java b/trogdor/src/test/java/org/apache/kafka/trogdor/common/StringExpanderTest.java new file mode 100644 index 0000000..6967652 --- /dev/null +++ b/trogdor/src/test/java/org/apache/kafka/trogdor/common/StringExpanderTest.java @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.kafka.trogdor.common;
+
+import static java.util.concurrent.TimeUnit.MILLISECONDS;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+
+import org.junit.jupiter.api.Test;
+
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashSet;
+import org.junit.jupiter.api.Timeout;
+
+@Timeout(value = 120000, unit = MILLISECONDS)
+public class StringExpanderTest {
+
+    @Test
+    public void testNoExpansionNeeded() {
+        assertEquals(Collections.singleton("foo"), StringExpander.expand("foo"));
+        assertEquals(Collections.singleton("bar"), StringExpander.expand("bar"));
+        assertEquals(Collections.singleton(""), StringExpander.expand(""));
+    }
+
+    @Test
+    public void testExpansions() {
+        HashSet<String> expected1 = new HashSet<>(Arrays.asList(
+            "foo1",
+            "foo2",
+            "foo3"
+        ));
+        assertEquals(expected1, StringExpander.expand("foo[1-3]"));
+
+        HashSet<String> expected2 = new HashSet<>(Arrays.asList(
+            "foo bar baz 0"
+        ));
+        assertEquals(expected2, StringExpander.expand("foo bar baz [0-0]"));
+
+        HashSet<String> expected3 = new HashSet<>(Arrays.asList(
+            "[[ wow50 ]]",
+            "[[ wow51 ]]",
+            "[[ wow52 ]]"
+        ));
+        assertEquals(expected3, StringExpander.expand("[[ wow[50-52] ]]"));
+
+        HashSet<String> expected4 = new HashSet<>(Arrays.asList(
+            "foo1bar",
+            "foo2bar",
+            "foo3bar"
+        ));
+        assertEquals(expected4, StringExpander.expand("foo[1-3]bar"));
+
+        // should expand latest range first
+        HashSet<String> expected5 = new HashSet<>(Arrays.asList(
+            "start[1-3]middle1epilogue",
+            "start[1-3]middle2epilogue",
+            "start[1-3]middle3epilogue"
+        ));
+        assertEquals(expected5, StringExpander.expand("start[1-3]middle[1-3]epilogue"));
+    }
+}
diff --git a/trogdor/src/test/java/org/apache/kafka/trogdor/common/StringFormatterTest.java b/trogdor/src/test/java/org/apache/kafka/trogdor/common/StringFormatterTest.java
new file mode 100644
index 0000000..98add18
--- /dev/null
+++ b/trogdor/src/test/java/org/apache/kafka/trogdor/common/StringFormatterTest.java
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +package org.apache.kafka.trogdor.common; + +import org.junit.jupiter.api.Timeout; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.junit.jupiter.api.Test; + +import java.time.ZoneOffset; +import java.util.Arrays; + +import static java.util.concurrent.TimeUnit.MILLISECONDS; +import static org.apache.kafka.trogdor.common.StringFormatter.durationString; +import static org.apache.kafka.trogdor.common.StringFormatter.dateString; +import static org.junit.jupiter.api.Assertions.assertEquals; + +@Timeout(value = 120000, unit = MILLISECONDS) +public class StringFormatterTest { + private static final Logger log = LoggerFactory.getLogger(StringFormatterTest.class); + + @Test + public void testDateString() { + assertEquals("2019-01-08T20:59:29.85Z", dateString(1546981169850L, ZoneOffset.UTC)); + } + + @Test + public void testDurationString() { + assertEquals("1m", durationString(60000)); + assertEquals("1m1s", durationString(61000)); + assertEquals("1m1s", durationString(61200)); + assertEquals("5s", durationString(5000)); + assertEquals("2h", durationString(7200000)); + assertEquals("2h1s", durationString(7201000)); + assertEquals("2h5m3s", durationString(7503000)); + } + + @Test + public void testPrettyPrintGrid() { + assertEquals(String.format( + "ANIMAL NUMBER INDEX %n" + + "lion 1 12345 %n" + + "manatee 50 1 %n"), + StringFormatter.prettyPrintGrid( + Arrays.asList(Arrays.asList("ANIMAL", "NUMBER", "INDEX"), + Arrays.asList("lion", "1", "12345"), + Arrays.asList("manatee", "50", "1")))); + } +}; diff --git a/trogdor/src/test/java/org/apache/kafka/trogdor/common/TopologyTest.java b/trogdor/src/test/java/org/apache/kafka/trogdor/common/TopologyTest.java new file mode 100644 index 0000000..22eaf32 --- /dev/null +++ b/trogdor/src/test/java/org/apache/kafka/trogdor/common/TopologyTest.java @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.kafka.trogdor.common;
+
+import static java.util.concurrent.TimeUnit.MILLISECONDS;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+import org.apache.kafka.trogdor.agent.Agent;
+import org.apache.kafka.trogdor.basic.BasicNode;
+import org.apache.kafka.trogdor.basic.BasicTopology;
+
+import org.apache.kafka.trogdor.coordinator.Coordinator;
+import org.junit.jupiter.api.Test;
+
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Set;
+import java.util.TreeMap;
+import org.junit.jupiter.api.Timeout;
+
+@Timeout(value = 120000, unit = MILLISECONDS)
+public class TopologyTest {
+
+    @Test
+    public void testAgentNodeNames() {
+        TreeMap<String, Node> nodes = new TreeMap<>();
+        final int numNodes = 5;
+        for (int i = 0; i < numNodes; i++) {
+            HashMap<String, String> conf = new HashMap<>();
+            if (i == 0) {
+                conf.put(Platform.Config.TROGDOR_COORDINATOR_PORT, String.valueOf(Coordinator.DEFAULT_PORT));
+            } else {
+                conf.put(Platform.Config.TROGDOR_AGENT_PORT, String.valueOf(Agent.DEFAULT_PORT));
+            }
+            BasicNode node = new BasicNode(String.format("node%02d", i),
+                String.format("node%d.example.com", i),
+                conf,
+                new HashSet<>());
+            nodes.put(node.name(), node);
+        }
+        Topology topology = new BasicTopology(nodes);
+        Set<String> names = Topology.Util.agentNodeNames(topology);
+        assertEquals(4, names.size());
+        for (int i = 1; i < numNodes - 1; i++) {
+            assertTrue(names.contains(String.format("node%02d", i)));
+        }
+    }
+}
diff --git a/trogdor/src/test/java/org/apache/kafka/trogdor/common/WorkerUtilsTest.java b/trogdor/src/test/java/org/apache/kafka/trogdor/common/WorkerUtilsTest.java
new file mode 100644
index 0000000..a5fbc85
--- /dev/null
+++ b/trogdor/src/test/java/org/apache/kafka/trogdor/common/WorkerUtilsTest.java
@@ -0,0 +1,332 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +package org.apache.kafka.trogdor.common; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +import org.apache.kafka.clients.admin.MockAdminClient; +import org.apache.kafka.clients.admin.NewTopic; +import org.apache.kafka.clients.admin.TopicDescription; +import org.apache.kafka.clients.producer.ProducerConfig; +import org.apache.kafka.common.Node; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.TopicPartitionInfo; +import org.apache.kafka.common.errors.TopicExistsException; +import org.apache.kafka.common.errors.UnknownTopicOrPartitionException; +import org.apache.kafka.common.utils.Utils; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Properties; + +public class WorkerUtilsTest { + + private static final Logger log = LoggerFactory.getLogger(WorkerUtilsTest.class); + + private final Node broker1 = new Node(0, "testHost-1", 1234); + private final Node broker2 = new Node(1, "testHost-2", 1234); + private final Node broker3 = new Node(1, "testHost-3", 1234); + private final List cluster = Arrays.asList(broker1, broker2, broker3); + private final List singleReplica = Collections.singletonList(broker1); + + private static final String TEST_TOPIC = "test-topic-1"; + private static final short TEST_REPLICATION_FACTOR = 1; + private static final int TEST_PARTITIONS = 1; + private static final NewTopic NEW_TEST_TOPIC = + new NewTopic(TEST_TOPIC, TEST_PARTITIONS, TEST_REPLICATION_FACTOR); + + private MockAdminClient adminClient; + + + @BeforeEach + public void setUp() { + adminClient = new MockAdminClient(cluster, broker1); + } + + @Test + public void testCreateOneTopic() throws Throwable { + Map newTopics = Collections.singletonMap(TEST_TOPIC, NEW_TEST_TOPIC); + + WorkerUtils.createTopics(log, adminClient, newTopics, true); + assertEquals(Collections.singleton(TEST_TOPIC), adminClient.listTopics().names().get()); + assertEquals( + new TopicDescription( + TEST_TOPIC, false, + Collections.singletonList( + new TopicPartitionInfo(0, broker1, singleReplica, Collections.emptyList()))), + adminClient.describeTopics( + Collections.singleton(TEST_TOPIC)).topicNameValues().get(TEST_TOPIC).get() + ); + } + + @Test + public void testCreateRetriesOnTimeout() throws Throwable { + adminClient.timeoutNextRequest(1); + + WorkerUtils.createTopics( + log, adminClient, Collections.singletonMap(TEST_TOPIC, NEW_TEST_TOPIC), true); + + assertEquals( + new TopicDescription( + TEST_TOPIC, false, + Collections.singletonList( + new TopicPartitionInfo(0, broker1, singleReplica, Collections.emptyList()))), + adminClient.describeTopics( + Collections.singleton(TEST_TOPIC)).topicNameValues().get(TEST_TOPIC).get() + ); + } + + @Test + public void testCreateZeroTopicsDoesNothing() throws Throwable { + WorkerUtils.createTopics(log, adminClient, Collections.emptyMap(), true); + assertEquals(0, adminClient.listTopics().names().get().size()); + } + + @Test + public void testCreateTopicsFailsIfAtLeastOneTopicExists() throws Throwable { + adminClient.addTopic( + false, + TEST_TOPIC, + Collections.singletonList(new TopicPartitionInfo(0, broker1, singleReplica, Collections.emptyList())), + null); + + 
Map newTopics = new HashMap<>(); + newTopics.put(TEST_TOPIC, NEW_TEST_TOPIC); + newTopics.put("another-topic", + new NewTopic("another-topic", TEST_PARTITIONS, TEST_REPLICATION_FACTOR)); + newTopics.put("one-more-topic", + new NewTopic("one-more-topic", TEST_PARTITIONS, TEST_REPLICATION_FACTOR)); + + assertThrows(TopicExistsException.class, () -> WorkerUtils.createTopics(log, adminClient, newTopics, true)); + } + + @Test + public void testExistingTopicsMustHaveRequestedNumberOfPartitions() throws Throwable { + List tpInfo = new ArrayList<>(); + tpInfo.add(new TopicPartitionInfo(0, broker1, singleReplica, Collections.emptyList())); + tpInfo.add(new TopicPartitionInfo(1, broker2, singleReplica, Collections.emptyList())); + adminClient.addTopic( + false, + TEST_TOPIC, + tpInfo, + null); + + assertThrows(RuntimeException.class, () -> WorkerUtils.createTopics( + log, adminClient, Collections.singletonMap(TEST_TOPIC, NEW_TEST_TOPIC), false)); + } + + @Test + public void testExistingTopicsNotCreated() throws Throwable { + final String existingTopic = "existing-topic"; + List tpInfo = new ArrayList<>(); + tpInfo.add(new TopicPartitionInfo(0, broker1, singleReplica, Collections.emptyList())); + tpInfo.add(new TopicPartitionInfo(1, broker2, singleReplica, Collections.emptyList())); + tpInfo.add(new TopicPartitionInfo(2, broker3, singleReplica, Collections.emptyList())); + adminClient.addTopic( + false, + existingTopic, + tpInfo, + null); + + WorkerUtils.createTopics( + log, adminClient, + Collections.singletonMap( + existingTopic, + new NewTopic(existingTopic, tpInfo.size(), TEST_REPLICATION_FACTOR)), false); + + assertEquals(Collections.singleton(existingTopic), adminClient.listTopics().names().get()); + } + + @Test + public void testCreatesNotExistingTopics() throws Throwable { + // should be no topics before the call + assertEquals(0, adminClient.listTopics().names().get().size()); + + WorkerUtils.createTopics( + log, adminClient, Collections.singletonMap(TEST_TOPIC, NEW_TEST_TOPIC), false); + + assertEquals(Collections.singleton(TEST_TOPIC), adminClient.listTopics().names().get()); + assertEquals( + new TopicDescription( + TEST_TOPIC, false, + Collections.singletonList( + new TopicPartitionInfo(0, broker1, singleReplica, Collections.emptyList()))), + adminClient.describeTopics(Collections.singleton(TEST_TOPIC)).topicNameValues().get(TEST_TOPIC).get() + ); + } + + @Test + public void testCreatesOneTopicVerifiesOneTopic() throws Throwable { + final String existingTopic = "existing-topic"; + List tpInfo = new ArrayList<>(); + tpInfo.add(new TopicPartitionInfo(0, broker1, singleReplica, Collections.emptyList())); + tpInfo.add(new TopicPartitionInfo(1, broker2, singleReplica, Collections.emptyList())); + adminClient.addTopic( + false, + existingTopic, + tpInfo, + null); + + Map topics = new HashMap<>(); + topics.put(existingTopic, + new NewTopic(existingTopic, tpInfo.size(), TEST_REPLICATION_FACTOR)); + topics.put(TEST_TOPIC, NEW_TEST_TOPIC); + + WorkerUtils.createTopics(log, adminClient, topics, false); + + assertEquals(Utils.mkSet(existingTopic, TEST_TOPIC), adminClient.listTopics().names().get()); + } + + @Test + public void testCreateNonExistingTopicsWithZeroTopicsDoesNothing() throws Throwable { + WorkerUtils.createTopics( + log, adminClient, Collections.emptyMap(), false); + assertEquals(0, adminClient.listTopics().names().get().size()); + } + + @Test + public void testAddConfigsToPropertiesAddsAllConfigs() { + Properties props = new Properties(); + 
props.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9092"); + props.put(ProducerConfig.ACKS_CONFIG, "all"); + + Properties resultProps = new Properties(); + resultProps.putAll(props); + resultProps.put(ProducerConfig.CLIENT_ID_CONFIG, "test-client"); + resultProps.put(ProducerConfig.LINGER_MS_CONFIG, "1000"); + + WorkerUtils.addConfigsToProperties( + props, + Collections.singletonMap(ProducerConfig.CLIENT_ID_CONFIG, "test-client"), + Collections.singletonMap(ProducerConfig.LINGER_MS_CONFIG, "1000")); + assertEquals(resultProps, props); + } + + @Test + public void testCommonConfigOverwritesDefaultProps() { + Properties props = new Properties(); + props.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9092"); + props.put(ProducerConfig.ACKS_CONFIG, "all"); + + Properties resultProps = new Properties(); + resultProps.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9092"); + resultProps.put(ProducerConfig.ACKS_CONFIG, "1"); + resultProps.put(ProducerConfig.LINGER_MS_CONFIG, "1000"); + + WorkerUtils.addConfigsToProperties( + props, + Collections.singletonMap(ProducerConfig.ACKS_CONFIG, "1"), + Collections.singletonMap(ProducerConfig.LINGER_MS_CONFIG, "1000")); + assertEquals(resultProps, props); + } + + @Test + public void testClientConfigOverwritesBothDefaultAndCommonConfigs() { + Properties props = new Properties(); + props.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9092"); + props.put(ProducerConfig.ACKS_CONFIG, "all"); + + Properties resultProps = new Properties(); + resultProps.put(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, "localhost:9092"); + resultProps.put(ProducerConfig.ACKS_CONFIG, "0"); + + WorkerUtils.addConfigsToProperties( + props, + Collections.singletonMap(ProducerConfig.ACKS_CONFIG, "1"), + Collections.singletonMap(ProducerConfig.ACKS_CONFIG, "0")); + assertEquals(resultProps, props); + } + + @Test + public void testGetMatchingTopicPartitionsCorrectlyMatchesExactTopicName() throws Throwable { + final String topic1 = "existing-topic"; + final String topic2 = "another-topic"; + makeExistingTopicWithOneReplica(topic1, 10); + makeExistingTopicWithOneReplica(topic2, 20); + + Collection topicPartitions = + WorkerUtils.getMatchingTopicPartitions(adminClient, topic2, 0, 2); + assertEquals( + Utils.mkSet( + new TopicPartition(topic2, 0), new TopicPartition(topic2, 1), + new TopicPartition(topic2, 2) + ), + new HashSet<>(topicPartitions) + ); + } + + @Test + public void testGetMatchingTopicPartitionsCorrectlyMatchesTopics() throws Throwable { + final String topic1 = "test-topic"; + final String topic2 = "another-test-topic"; + final String topic3 = "one-more"; + makeExistingTopicWithOneReplica(topic1, 10); + makeExistingTopicWithOneReplica(topic2, 20); + makeExistingTopicWithOneReplica(topic3, 30); + + Collection topicPartitions = + WorkerUtils.getMatchingTopicPartitions(adminClient, ".*-topic$", 0, 1); + assertEquals( + Utils.mkSet( + new TopicPartition(topic1, 0), new TopicPartition(topic1, 1), + new TopicPartition(topic2, 0), new TopicPartition(topic2, 1) + ), + new HashSet<>(topicPartitions) + ); + } + + private void makeExistingTopicWithOneReplica(String topicName, int numPartitions) { + List tpInfo = new ArrayList<>(); + int brokerIndex = 0; + for (int i = 0; i < numPartitions; ++i) { + Node broker = cluster.get(brokerIndex); + tpInfo.add(new TopicPartitionInfo( + i, broker, singleReplica, Collections.emptyList())); + brokerIndex = (brokerIndex + 1) % cluster.size(); + } + adminClient.addTopic( + false, + topicName, + tpInfo, + null); + } + + 
@Test
+    public void testVerifyTopics() throws Throwable {
+        Map<String, NewTopic> newTopics = Collections.singletonMap(TEST_TOPIC, NEW_TEST_TOPIC);
+        WorkerUtils.createTopics(log, adminClient, newTopics, true);
+        adminClient.setFetchesRemainingUntilVisible(TEST_TOPIC, 2);
+        WorkerUtils.verifyTopics(log, adminClient, Collections.singleton(TEST_TOPIC),
+            Collections.singletonMap(TEST_TOPIC, NEW_TEST_TOPIC), 3, 1);
+        adminClient.setFetchesRemainingUntilVisible(TEST_TOPIC, 100);
+        assertThrows(UnknownTopicOrPartitionException.class, () ->
+            WorkerUtils.verifyTopics(log, adminClient, Collections.singleton(TEST_TOPIC),
+                Collections.singletonMap(TEST_TOPIC, NEW_TEST_TOPIC), 2, 1));
+    }
+}
diff --git a/trogdor/src/test/java/org/apache/kafka/trogdor/coordinator/CoordinatorClientTest.java b/trogdor/src/test/java/org/apache/kafka/trogdor/coordinator/CoordinatorClientTest.java
new file mode 100644
index 0000000..30bdb4f
--- /dev/null
+++ b/trogdor/src/test/java/org/apache/kafka/trogdor/coordinator/CoordinatorClientTest.java
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +package org.apache.kafka.trogdor.coordinator; + +import com.fasterxml.jackson.databind.node.JsonNodeFactory; +import org.apache.kafka.trogdor.rest.TaskDone; +import org.apache.kafka.trogdor.rest.TaskPending; +import org.apache.kafka.trogdor.rest.TaskRunning; +import org.apache.kafka.trogdor.rest.TaskStopping; +import org.apache.kafka.trogdor.task.NoOpTaskSpec; + +import java.time.ZoneOffset; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; + +import static java.util.concurrent.TimeUnit.MILLISECONDS; +import static org.junit.jupiter.api.Assertions.assertEquals; + +@Timeout(value = 120000, unit = MILLISECONDS) +public class CoordinatorClientTest { + + @Test + public void testPrettyPrintTaskInfo() { + assertEquals("Will start at 2019-01-08T07:05:59.85Z", + CoordinatorClient.prettyPrintTaskInfo( + new TaskPending(new NoOpTaskSpec(1546931159850L, 9000)), + ZoneOffset.UTC)); + assertEquals("Started 2009-07-07T01:45:59.85Z; will stop after 9s", + CoordinatorClient.prettyPrintTaskInfo( + new TaskRunning(new NoOpTaskSpec(1146931159850L, 9000), + 1246931159850L, + JsonNodeFactory.instance.objectNode()), ZoneOffset.UTC)); + assertEquals("Started 2009-07-07T01:45:59.85Z", + CoordinatorClient.prettyPrintTaskInfo( + new TaskStopping(new NoOpTaskSpec(1146931159850L, 9000), + 1246931159850L, + JsonNodeFactory.instance.objectNode()), ZoneOffset.UTC)); + assertEquals("FINISHED at 2019-01-08T20:59:29.85Z after 10s", + CoordinatorClient.prettyPrintTaskInfo( + new TaskDone(new NoOpTaskSpec(0, 1000), + 1546981159850L, + 1546981169850L, + "", + false, + JsonNodeFactory.instance.objectNode()), ZoneOffset.UTC)); + assertEquals("CANCELLED at 2019-01-08T20:59:29.85Z after 10s", + CoordinatorClient.prettyPrintTaskInfo( + new TaskDone(new NoOpTaskSpec(0, 1000), + 1546981159850L, + 1546981169850L, + "", + true, + JsonNodeFactory.instance.objectNode()), ZoneOffset.UTC)); + assertEquals("FAILED at 2019-01-08T20:59:29.85Z after 10s", + CoordinatorClient.prettyPrintTaskInfo( + new TaskDone(new NoOpTaskSpec(0, 1000), + 1546981159850L, + 1546981169850L, + "foobar", + true, + JsonNodeFactory.instance.objectNode()), ZoneOffset.UTC)); + } +}; diff --git a/trogdor/src/test/java/org/apache/kafka/trogdor/coordinator/CoordinatorTest.java b/trogdor/src/test/java/org/apache/kafka/trogdor/coordinator/CoordinatorTest.java new file mode 100644 index 0000000..b77b2b0 --- /dev/null +++ b/trogdor/src/test/java/org/apache/kafka/trogdor/coordinator/CoordinatorTest.java @@ -0,0 +1,719 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.trogdor.coordinator; + +import com.fasterxml.jackson.databind.node.JsonNodeFactory; +import com.fasterxml.jackson.databind.node.ObjectNode; +import com.fasterxml.jackson.databind.node.TextNode; +import org.apache.kafka.common.utils.MockScheduler; +import org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.common.utils.Scheduler; +import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.test.TestUtils; +import org.apache.kafka.trogdor.agent.AgentClient; +import org.apache.kafka.trogdor.common.CapturingCommandRunner; +import org.apache.kafka.trogdor.common.ExpectedTasks; +import org.apache.kafka.trogdor.common.ExpectedTasks.ExpectedTaskBuilder; +import org.apache.kafka.trogdor.common.MiniTrogdorCluster; +import org.apache.kafka.trogdor.fault.NetworkPartitionFaultSpec; +import org.apache.kafka.trogdor.rest.CoordinatorStatusResponse; +import org.apache.kafka.trogdor.rest.CreateTaskRequest; +import org.apache.kafka.trogdor.rest.DestroyTaskRequest; +import org.apache.kafka.trogdor.rest.RequestConflictException; +import org.apache.kafka.trogdor.rest.StopTaskRequest; +import org.apache.kafka.trogdor.rest.TaskDone; +import org.apache.kafka.trogdor.rest.TaskPending; +import org.apache.kafka.trogdor.rest.TaskRequest; +import org.apache.kafka.trogdor.rest.TaskRunning; +import org.apache.kafka.trogdor.rest.TaskState; +import org.apache.kafka.trogdor.rest.TaskStateType; +import org.apache.kafka.trogdor.rest.TasksRequest; +import org.apache.kafka.trogdor.rest.TasksResponse; +import org.apache.kafka.trogdor.rest.UptimeResponse; +import org.apache.kafka.trogdor.rest.WorkerDone; +import org.apache.kafka.trogdor.rest.WorkerRunning; +import org.apache.kafka.trogdor.task.NoOpTaskSpec; +import org.apache.kafka.trogdor.task.SampleTaskSpec; +import org.junit.jupiter.api.Tag; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.ws.rs.NotFoundException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Optional; + +import static java.util.concurrent.TimeUnit.MILLISECONDS; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +@Tag("integration") +@Timeout(value = 120000, unit = MILLISECONDS) +public class CoordinatorTest { + + private static final Logger log = LoggerFactory.getLogger(CoordinatorTest.class); + + @Test + public void testCoordinatorStatus() throws Exception { + try (MiniTrogdorCluster cluster = new MiniTrogdorCluster.Builder(). + addCoordinator("node01"). + build()) { + CoordinatorStatusResponse status = cluster.coordinatorClient().status(); + assertEquals(cluster.coordinator().status(), status); + } + } + + @Test + public void testCoordinatorUptime() throws Exception { + MockTime time = new MockTime(0, 200, 0); + Scheduler scheduler = new MockScheduler(time); + try (MiniTrogdorCluster cluster = new MiniTrogdorCluster.Builder(). + addCoordinator("node01"). + scheduler(scheduler). 
+ build()) { + UptimeResponse uptime = cluster.coordinatorClient().uptime(); + assertEquals(cluster.coordinator().uptime(), uptime); + + time.setCurrentTimeMs(250); + assertNotEquals(cluster.coordinator().uptime(), uptime); + } + } + + @Test + public void testCreateTask() throws Exception { + MockTime time = new MockTime(0, 0, 0); + Scheduler scheduler = new MockScheduler(time); + try (MiniTrogdorCluster cluster = new MiniTrogdorCluster.Builder(). + addCoordinator("node01"). + addAgent("node02"). + scheduler(scheduler). + build()) { + new ExpectedTasks().waitFor(cluster.coordinatorClient()); + + NoOpTaskSpec fooSpec = new NoOpTaskSpec(1, 2); + cluster.coordinatorClient().createTask( + new CreateTaskRequest("foo", fooSpec)); + new ExpectedTasks(). + addTask(new ExpectedTaskBuilder("foo"). + taskState(new TaskPending(fooSpec)). + build()). + waitFor(cluster.coordinatorClient()); + + // Re-creating a task with the same arguments is not an error. + cluster.coordinatorClient().createTask( + new CreateTaskRequest("foo", fooSpec)); + + // Re-creating a task with different arguments gives a RequestConflictException. + NoOpTaskSpec barSpec = new NoOpTaskSpec(1000, 2000); + assertThrows(RequestConflictException.class, () -> cluster.coordinatorClient().createTask( + new CreateTaskRequest("foo", barSpec)), + "Recreating task with different task spec is not allowed"); + + time.sleep(2); + new ExpectedTasks(). + addTask(new ExpectedTaskBuilder("foo"). + taskState(new TaskRunning(fooSpec, 2, new TextNode("active"))). + workerState(new WorkerRunning("foo", fooSpec, 2, new TextNode("active"))). + build()). + waitFor(cluster.coordinatorClient()). + waitFor(cluster.agentClient("node02")); + + time.sleep(3); + new ExpectedTasks(). + addTask(new ExpectedTaskBuilder("foo"). + taskState(new TaskDone(fooSpec, 2, 5, "", false, new TextNode("done"))). + build()). + waitFor(cluster.coordinatorClient()); + } + } + + @Test + public void testTaskDistribution() throws Exception { + MockTime time = new MockTime(0, 0, 0); + Scheduler scheduler = new MockScheduler(time); + try (MiniTrogdorCluster cluster = new MiniTrogdorCluster.Builder(). + addCoordinator("node01"). + addAgent("node01"). + addAgent("node02"). + scheduler(scheduler). + build()) { + CoordinatorClient coordinatorClient = cluster.coordinatorClient(); + AgentClient agentClient1 = cluster.agentClient("node01"); + AgentClient agentClient2 = cluster.agentClient("node02"); + + new ExpectedTasks(). + waitFor(coordinatorClient). + waitFor(agentClient1). + waitFor(agentClient2); + + NoOpTaskSpec fooSpec = new NoOpTaskSpec(5, 7); + coordinatorClient.createTask(new CreateTaskRequest("foo", fooSpec)); + new ExpectedTasks(). + addTask(new ExpectedTaskBuilder("foo").taskState( + new TaskPending(fooSpec)).build()). + waitFor(coordinatorClient). + waitFor(agentClient1). + waitFor(agentClient2); + + time.sleep(11); + ObjectNode status1 = new ObjectNode(JsonNodeFactory.instance); + status1.set("node01", new TextNode("active")); + status1.set("node02", new TextNode("active")); + new ExpectedTasks(). + addTask(new ExpectedTaskBuilder("foo"). + taskState(new TaskRunning(fooSpec, 11, status1)). + workerState(new WorkerRunning("foo", fooSpec, 11, new TextNode("active"))). + build()). + waitFor(coordinatorClient). + waitFor(agentClient1). + waitFor(agentClient2); + + time.sleep(7); + ObjectNode status2 = new ObjectNode(JsonNodeFactory.instance); + status2.set("node01", new TextNode("done")); + status2.set("node02", new TextNode("done")); + new ExpectedTasks(). 
+ addTask(new ExpectedTaskBuilder("foo"). + taskState(new TaskDone(fooSpec, 11, 18, + "", false, status2)). + workerState(new WorkerDone("foo", fooSpec, 11, 18, new TextNode("done"), "")). + build()). + waitFor(coordinatorClient). + waitFor(agentClient1). + waitFor(agentClient2); + } + } + + @Test + public void testTaskCancellation() throws Exception { + MockTime time = new MockTime(0, 0, 0); + Scheduler scheduler = new MockScheduler(time); + try (MiniTrogdorCluster cluster = new MiniTrogdorCluster.Builder(). + addCoordinator("node01"). + addAgent("node01"). + addAgent("node02"). + scheduler(scheduler). + build()) { + CoordinatorClient coordinatorClient = cluster.coordinatorClient(); + AgentClient agentClient1 = cluster.agentClient("node01"); + AgentClient agentClient2 = cluster.agentClient("node02"); + + new ExpectedTasks(). + waitFor(coordinatorClient). + waitFor(agentClient1). + waitFor(agentClient2); + + NoOpTaskSpec fooSpec = new NoOpTaskSpec(5, 7); + coordinatorClient.createTask(new CreateTaskRequest("foo", fooSpec)); + new ExpectedTasks(). + addTask(new ExpectedTaskBuilder("foo").taskState(new TaskPending(fooSpec)).build()). + waitFor(coordinatorClient). + waitFor(agentClient1). + waitFor(agentClient2); + + time.sleep(11); + + ObjectNode status1 = new ObjectNode(JsonNodeFactory.instance); + status1.set("node01", new TextNode("active")); + status1.set("node02", new TextNode("active")); + new ExpectedTasks(). + addTask(new ExpectedTaskBuilder("foo"). + taskState(new TaskRunning(fooSpec, 11, status1)). + workerState(new WorkerRunning("foo", fooSpec, 11, new TextNode("active"))). + build()). + waitFor(coordinatorClient). + waitFor(agentClient1). + waitFor(agentClient2); + + ObjectNode status2 = new ObjectNode(JsonNodeFactory.instance); + status2.set("node01", new TextNode("done")); + status2.set("node02", new TextNode("done")); + time.sleep(7); + coordinatorClient.stopTask(new StopTaskRequest("foo")); + new ExpectedTasks(). + addTask(new ExpectedTaskBuilder("foo"). + taskState(new TaskDone(fooSpec, 11, 18, "", + true, status2)). + workerState(new WorkerDone("foo", fooSpec, 11, 18, new TextNode("done"), "")). + build()). + waitFor(coordinatorClient). + waitFor(agentClient1). + waitFor(agentClient2); + + coordinatorClient.destroyTask(new DestroyTaskRequest("foo")); + new ExpectedTasks(). + waitFor(coordinatorClient). + waitFor(agentClient1). + waitFor(agentClient2); + } + } + + @Test + public void testTaskDestruction() throws Exception { + MockTime time = new MockTime(0, 0, 0); + Scheduler scheduler = new MockScheduler(time); + try (MiniTrogdorCluster cluster = new MiniTrogdorCluster.Builder(). + addCoordinator("node01"). + addAgent("node01"). + addAgent("node02"). + scheduler(scheduler). + build()) { + CoordinatorClient coordinatorClient = cluster.coordinatorClient(); + AgentClient agentClient1 = cluster.agentClient("node01"); + AgentClient agentClient2 = cluster.agentClient("node02"); + + new ExpectedTasks(). + waitFor(coordinatorClient). + waitFor(agentClient1). + waitFor(agentClient2); + + NoOpTaskSpec fooSpec = new NoOpTaskSpec(2, 12); + coordinatorClient.destroyTask(new DestroyTaskRequest("foo")); + coordinatorClient.createTask(new CreateTaskRequest("foo", fooSpec)); + NoOpTaskSpec barSpec = new NoOpTaskSpec(20, 20); + coordinatorClient.createTask(new CreateTaskRequest("bar", barSpec)); + coordinatorClient.destroyTask(new DestroyTaskRequest("bar")); + new ExpectedTasks(). + addTask(new ExpectedTaskBuilder("foo").taskState(new TaskPending(fooSpec)).build()). 
+                waitFor(coordinatorClient).
+                waitFor(agentClient1).
+                waitFor(agentClient2);
+            time.sleep(10);
+
+            ObjectNode status1 = new ObjectNode(JsonNodeFactory.instance);
+            status1.set("node01", new TextNode("active"));
+            status1.set("node02", new TextNode("active"));
+            new ExpectedTasks().
+                addTask(new ExpectedTaskBuilder("foo").
+                    taskState(new TaskRunning(fooSpec, 10, status1)).
+                    build()).
+                waitFor(coordinatorClient).
+                waitFor(agentClient1).
+                waitFor(agentClient2);
+
+            coordinatorClient.destroyTask(new DestroyTaskRequest("foo"));
+            new ExpectedTasks().
+                waitFor(coordinatorClient).
+                waitFor(agentClient1).
+                waitFor(agentClient2);
+        }
+    }
+
+    public static class ExpectedLines {
+        List<String> expectedLines = new ArrayList<>();
+
+        public ExpectedLines addLine(String line) {
+            expectedLines.add(line);
+            return this;
+        }
+
+        public ExpectedLines waitFor(final String nodeName,
+                final CapturingCommandRunner runner) throws InterruptedException {
+            TestUtils.waitForCondition(() -> linesMatch(nodeName, runner.lines(nodeName)),
+                "failed to find the expected lines " + this.toString());
+            return this;
+        }
+
+        private boolean linesMatch(final String nodeName, List<String> actualLines) {
+            int matchIdx = 0, i = 0;
+            while (true) {
+                if (matchIdx == expectedLines.size()) {
+                    log.debug("Got expected lines for {}", nodeName);
+                    return true;
+                }
+                if (i == actualLines.size()) {
+                    log.info("Failed to find the expected lines for {}. First " +
+                        "missing line on index {}: {}",
+                        nodeName, matchIdx, expectedLines.get(matchIdx));
+                    return false;
+                }
+                String actualLine = actualLines.get(i++);
+                String expectedLine = expectedLines.get(matchIdx);
+                if (expectedLine.equals(actualLine)) {
+                    matchIdx++;
+                } else {
+                    log.trace("Expected:\n'{}', Got:\n'{}'", expectedLine, actualLine);
+                    matchIdx = 0;
+                }
+            }
+        }
+
+        @Override
+        public String toString() {
+            return Utils.join(expectedLines, ", ");
+        }
+    }
+
+    private static List<List<String>> createPartitionLists(String[][] array) {
+        List<List<String>> list = new ArrayList<>();
+        for (String[] a : array) {
+            list.add(Arrays.asList(a));
+        }
+        return list;
+    }
+
+    @Test
+    public void testNetworkPartitionFault() throws Exception {
+        CapturingCommandRunner runner = new CapturingCommandRunner();
+        MockTime time = new MockTime(0, 0, 0);
+        Scheduler scheduler = new MockScheduler(time);
+        try (MiniTrogdorCluster cluster = new MiniTrogdorCluster.Builder().
+                addCoordinator("node01").
+                addAgent("node01").
+                addAgent("node02").
+                addAgent("node03").
+                commandRunner(runner).
+                scheduler(scheduler).
+                build()) {
+            CoordinatorClient coordinatorClient = cluster.coordinatorClient();
+            NetworkPartitionFaultSpec spec = new NetworkPartitionFaultSpec(0, Long.MAX_VALUE,
+                createPartitionLists(new String[][] {
+                    new String[] {"node01", "node02"},
+                    new String[] {"node03"},
+                }));
+            coordinatorClient.createTask(new CreateTaskRequest("netpart", spec));
+            new ExpectedTasks().
+                addTask(new ExpectedTaskBuilder("netpart").taskSpec(spec).build()).
+                waitFor(coordinatorClient);
+            checkLines("-A", runner);
+        }
+        checkLines("-D", runner);
+    }
+
+    private void checkLines(String prefix, CapturingCommandRunner runner) throws InterruptedException {
+        new ExpectedLines().
+            addLine("sudo iptables " + prefix + " INPUT -p tcp -s 127.0.0.1 -j DROP " +
+                "-m comment --comment node03").
+            waitFor("node01", runner);
+        new ExpectedLines().
+            addLine("sudo iptables " + prefix + " INPUT -p tcp -s 127.0.0.1 -j DROP " +
+                "-m comment --comment node03").
+            waitFor("node02", runner);
+        new ExpectedLines().
+ addLine("sudo iptables " + prefix + " INPUT -p tcp -s 127.0.0.1 -j DROP " + + "-m comment --comment node01"). + addLine("sudo iptables " + prefix + " INPUT -p tcp -s 127.0.0.1 -j DROP " + + "-m comment --comment node02"). + waitFor("node03", runner); + } + + @Test + public void testTasksRequestMatches() throws Exception { + TasksRequest req1 = new TasksRequest(null, 0, 0, 0, 0, Optional.empty()); + assertTrue(req1.matches("foo1", -1, -1, TaskStateType.PENDING)); + assertTrue(req1.matches("bar1", 100, 200, TaskStateType.DONE)); + assertTrue(req1.matches("baz1", 100, -1, TaskStateType.RUNNING)); + + TasksRequest req2 = new TasksRequest(null, 100, 0, 0, 0, Optional.empty()); + assertFalse(req2.matches("foo1", -1, -1, TaskStateType.PENDING)); + assertTrue(req2.matches("bar1", 100, 200, TaskStateType.DONE)); + assertFalse(req2.matches("bar1", 99, 200, TaskStateType.DONE)); + assertFalse(req2.matches("baz1", 99, -1, TaskStateType.RUNNING)); + + TasksRequest req3 = new TasksRequest(null, 200, 900, 200, 900, Optional.empty()); + assertFalse(req3.matches("foo1", -1, -1, TaskStateType.PENDING)); + assertFalse(req3.matches("bar1", 100, 200, TaskStateType.DONE)); + assertFalse(req3.matches("bar1", 200, 1000, TaskStateType.DONE)); + assertTrue(req3.matches("bar1", 200, 700, TaskStateType.DONE)); + assertFalse(req3.matches("baz1", 101, -1, TaskStateType.RUNNING)); + + List taskIds = new ArrayList<>(); + taskIds.add("foo1"); + taskIds.add("bar1"); + taskIds.add("baz1"); + TasksRequest req4 = new TasksRequest(taskIds, 1000, -1, -1, -1, Optional.empty()); + assertFalse(req4.matches("foo1", -1, -1, TaskStateType.PENDING)); + assertTrue(req4.matches("foo1", 1000, -1, TaskStateType.RUNNING)); + assertFalse(req4.matches("foo1", 900, -1, TaskStateType.RUNNING)); + assertFalse(req4.matches("baz2", 2000, -1, TaskStateType.RUNNING)); + assertFalse(req4.matches("baz2", -1, -1, TaskStateType.PENDING)); + + TasksRequest req5 = new TasksRequest(null, 0, 0, 0, 0, Optional.of(TaskStateType.RUNNING)); + assertTrue(req5.matches("foo1", -1, -1, TaskStateType.RUNNING)); + assertFalse(req5.matches("bar1", -1, -1, TaskStateType.DONE)); + assertFalse(req5.matches("baz1", -1, -1, TaskStateType.STOPPING)); + assertFalse(req5.matches("baz1", -1, -1, TaskStateType.PENDING)); + } + + @Test + public void testTasksRequest() throws Exception { + MockTime time = new MockTime(0, 0, 0); + Scheduler scheduler = new MockScheduler(time); + try (MiniTrogdorCluster cluster = new MiniTrogdorCluster.Builder(). + addCoordinator("node01"). + addAgent("node02"). + scheduler(scheduler). + build()) { + CoordinatorClient coordinatorClient = cluster.coordinatorClient(); + new ExpectedTasks().waitFor(coordinatorClient); + + NoOpTaskSpec fooSpec = new NoOpTaskSpec(1, 10); + NoOpTaskSpec barSpec = new NoOpTaskSpec(3, 1); + coordinatorClient.createTask(new CreateTaskRequest("foo", fooSpec)); + coordinatorClient.createTask(new CreateTaskRequest("bar", barSpec)); + new ExpectedTasks(). + addTask(new ExpectedTaskBuilder("foo"). + taskState(new TaskPending(fooSpec)). + build()). + addTask(new ExpectedTaskBuilder("bar"). + taskState(new TaskPending(barSpec)). + build()). 
+ waitFor(coordinatorClient); + + assertEquals(0, coordinatorClient.tasks( + new TasksRequest(null, 10, 0, 10, 0, Optional.empty())).tasks().size()); + TasksResponse resp1 = coordinatorClient.tasks( + new TasksRequest(Arrays.asList("foo", "baz"), 0, 0, 0, 0, Optional.empty())); + assertTrue(resp1.tasks().containsKey("foo")); + assertFalse(resp1.tasks().containsKey("bar")); + assertEquals(1, resp1.tasks().size()); + + time.sleep(2); + new ExpectedTasks(). + addTask(new ExpectedTaskBuilder("foo"). + taskState(new TaskRunning(fooSpec, 2, new TextNode("active"))). + workerState(new WorkerRunning("foo", fooSpec, 2, new TextNode("active"))). + build()). + addTask(new ExpectedTaskBuilder("bar"). + taskState(new TaskPending(barSpec)). + build()). + waitFor(coordinatorClient). + waitFor(cluster.agentClient("node02")); + + TasksResponse resp2 = coordinatorClient.tasks( + new TasksRequest(null, 1, 0, 0, 0, Optional.empty())); + assertTrue(resp2.tasks().containsKey("foo")); + assertFalse(resp2.tasks().containsKey("bar")); + assertEquals(1, resp2.tasks().size()); + + assertEquals(0, coordinatorClient.tasks( + new TasksRequest(null, 3, 0, 0, 0, Optional.empty())).tasks().size()); + } + } + + /** + * If an agent fails in the middle of a task and comes back up when the task is considered expired, + * we want the task to be marked as DONE and not re-sent should a second failure happen. + */ + @Test + public void testAgentFailureAndTaskExpiry() throws Exception { + MockTime time = new MockTime(0, 0, 0); + Scheduler scheduler = new MockScheduler(time); + try (MiniTrogdorCluster cluster = new MiniTrogdorCluster.Builder(). + addCoordinator("node01"). + addAgent("node02"). + scheduler(scheduler). + build()) { + CoordinatorClient coordinatorClient = cluster.coordinatorClient(); + + NoOpTaskSpec fooSpec = new NoOpTaskSpec(1, 500); + coordinatorClient.createTask(new CreateTaskRequest("foo", fooSpec)); + TaskState expectedState = new ExpectedTaskBuilder("foo").taskState(new TaskPending(fooSpec)).build().taskState(); + + TaskState resp = coordinatorClient.task(new TaskRequest("foo")); + assertEquals(expectedState, resp); + + + time.sleep(2); + new ExpectedTasks(). + addTask(new ExpectedTaskBuilder("foo"). + taskState(new TaskRunning(fooSpec, 2, new TextNode("active"))). + workerState(new WorkerRunning("foo", fooSpec, 2, new TextNode("active"))). + build()). + waitFor(coordinatorClient). + waitFor(cluster.agentClient("node02")); + + cluster.restartAgent("node02"); + time.sleep(550); + // coordinator heartbeat sees that the agent is back up, re-schedules the task but the agent expires it + new ExpectedTasks(). + addTask(new ExpectedTaskBuilder("foo"). + taskState(new TaskDone(fooSpec, 2, 552, "worker expired", false, null)). + workerState(new WorkerDone("foo", fooSpec, 552, 552, null, "worker expired")). + build()). + waitFor(coordinatorClient). + waitFor(cluster.agentClient("node02")); + + cluster.restartAgent("node02"); + // coordinator heartbeat sees that the agent is back up but does not re-schedule the task as it is DONE + new ExpectedTasks(). + addTask(new ExpectedTaskBuilder("foo"). + taskState(new TaskDone(fooSpec, 2, 552, "worker expired", false, null)). + // no worker states + build()). + waitFor(coordinatorClient). + waitFor(cluster.agentClient("node02")); + } + } + + @Test + public void testTaskRequestWithOldStartMsGetsUpdated() throws Exception { + MockTime time = new MockTime(0, 0, 0); + Scheduler scheduler = new MockScheduler(time); + try (MiniTrogdorCluster cluster = new MiniTrogdorCluster.Builder(). 
+ addCoordinator("node01"). + addAgent("node02"). + scheduler(scheduler). + build()) { + + NoOpTaskSpec fooSpec = new NoOpTaskSpec(1, 500); + time.sleep(552); + + CoordinatorClient coordinatorClient = cluster.coordinatorClient(); + NoOpTaskSpec updatedSpec = new NoOpTaskSpec(552, 500); + coordinatorClient.createTask(new CreateTaskRequest("fooSpec", fooSpec)); + TaskState expectedState = new ExpectedTaskBuilder("fooSpec").taskState( + new TaskRunning(updatedSpec, 552, new TextNode("receiving")) + ).build().taskState(); + + TaskState resp = coordinatorClient.task(new TaskRequest("fooSpec")); + assertEquals(expectedState, resp); + } + } + + @Test + public void testTaskRequestWithFutureStartMsDoesNotGetRun() throws Exception { + MockTime time = new MockTime(0, 0, 0); + Scheduler scheduler = new MockScheduler(time); + try (MiniTrogdorCluster cluster = new MiniTrogdorCluster.Builder(). + addCoordinator("node01"). + addAgent("node02"). + scheduler(scheduler). + build()) { + + NoOpTaskSpec fooSpec = new NoOpTaskSpec(1000, 500); + time.sleep(999); + + CoordinatorClient coordinatorClient = cluster.coordinatorClient(); + coordinatorClient.createTask(new CreateTaskRequest("fooSpec", fooSpec)); + TaskState expectedState = new ExpectedTaskBuilder("fooSpec").taskState( + new TaskPending(fooSpec) + ).build().taskState(); + + TaskState resp = coordinatorClient.task(new TaskRequest("fooSpec")); + assertEquals(expectedState, resp); + } + } + + @Test + public void testTaskRequest() throws Exception { + MockTime time = new MockTime(0, 0, 0); + Scheduler scheduler = new MockScheduler(time); + try (MiniTrogdorCluster cluster = new MiniTrogdorCluster.Builder(). + addCoordinator("node01"). + addAgent("node02"). + scheduler(scheduler). + build()) { + CoordinatorClient coordinatorClient = cluster.coordinatorClient(); + + NoOpTaskSpec fooSpec = new NoOpTaskSpec(1, 10); + coordinatorClient.createTask(new CreateTaskRequest("foo", fooSpec)); + TaskState expectedState = new ExpectedTaskBuilder("foo").taskState(new TaskPending(fooSpec)).build().taskState(); + + TaskState resp = coordinatorClient.task(new TaskRequest("foo")); + assertEquals(expectedState, resp); + + time.sleep(2); + new ExpectedTasks(). + addTask(new ExpectedTaskBuilder("foo"). + taskState(new TaskRunning(fooSpec, 2, new TextNode("active"))). + workerState(new WorkerRunning("foo", fooSpec, 2, new TextNode("active"))). + build()). + waitFor(coordinatorClient). + waitFor(cluster.agentClient("node02")); + + assertThrows(NotFoundException.class, () -> coordinatorClient.task(new TaskRequest("non-existent-foo"))); + } + } + + @Test + public void testWorkersExitingAtDifferentTimes() throws Exception { + MockTime time = new MockTime(0, 0, 0); + Scheduler scheduler = new MockScheduler(time); + try (MiniTrogdorCluster cluster = new MiniTrogdorCluster.Builder(). + addCoordinator("node01"). + addAgent("node02"). + addAgent("node03"). + scheduler(scheduler). + build()) { + CoordinatorClient coordinatorClient = cluster.coordinatorClient(); + new ExpectedTasks().waitFor(coordinatorClient); + + HashMap nodeToExitMs = new HashMap<>(); + nodeToExitMs.put("node02", 10L); + nodeToExitMs.put("node03", 20L); + SampleTaskSpec fooSpec = + new SampleTaskSpec(2, 100, nodeToExitMs, ""); + coordinatorClient.createTask(new CreateTaskRequest("foo", fooSpec)); + new ExpectedTasks(). + addTask(new ExpectedTaskBuilder("foo"). + taskState(new TaskPending(fooSpec)). + build()). 
+ waitFor(coordinatorClient); + + time.sleep(2); + ObjectNode status1 = new ObjectNode(JsonNodeFactory.instance); + status1.set("node02", new TextNode("active")); + status1.set("node03", new TextNode("active")); + new ExpectedTasks(). + addTask(new ExpectedTaskBuilder("foo"). + taskState(new TaskRunning(fooSpec, 2, status1)). + workerState(new WorkerRunning("foo", fooSpec, 2, new TextNode("active"))). + build()). + waitFor(coordinatorClient). + waitFor(cluster.agentClient("node02")). + waitFor(cluster.agentClient("node03")); + + time.sleep(10); + ObjectNode status2 = new ObjectNode(JsonNodeFactory.instance); + status2.set("node02", new TextNode("halted")); + status2.set("node03", new TextNode("active")); + new ExpectedTasks(). + addTask(new ExpectedTaskBuilder("foo"). + taskState(new TaskRunning(fooSpec, 2, status2)). + workerState(new WorkerRunning("foo", fooSpec, 2, new TextNode("active"))). + build()). + waitFor(coordinatorClient). + waitFor(cluster.agentClient("node03")); + new ExpectedTasks(). + addTask(new ExpectedTaskBuilder("foo"). + taskState(new TaskRunning(fooSpec, 2, status2)). + workerState(new WorkerDone("foo", fooSpec, 2, 12, new TextNode("halted"), "")). + build()). + waitFor(cluster.agentClient("node02")); + + time.sleep(10); + ObjectNode status3 = new ObjectNode(JsonNodeFactory.instance); + status3.set("node02", new TextNode("halted")); + status3.set("node03", new TextNode("halted")); + new ExpectedTasks(). + addTask(new ExpectedTaskBuilder("foo"). + taskState(new TaskDone(fooSpec, 2, 22, "", + false, status3)). + build()). + waitFor(coordinatorClient); + } + } +}; diff --git a/trogdor/src/test/java/org/apache/kafka/trogdor/rest/RestExceptionMapperTest.java b/trogdor/src/test/java/org/apache/kafka/trogdor/rest/RestExceptionMapperTest.java new file mode 100644 index 0000000..e61aec0 --- /dev/null +++ b/trogdor/src/test/java/org/apache/kafka/trogdor/rest/RestExceptionMapperTest.java @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.trogdor.rest; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +import com.fasterxml.jackson.core.JsonParser; +import com.fasterxml.jackson.databind.JavaType; +import com.fasterxml.jackson.databind.JsonMappingException; +import com.fasterxml.jackson.databind.exc.InvalidTypeIdException; +import javax.ws.rs.NotFoundException; +import javax.ws.rs.core.Response; + +import org.apache.kafka.common.errors.InvalidRequestException; +import org.apache.kafka.common.errors.SerializationException; +import org.junit.jupiter.api.Test; + +public class RestExceptionMapperTest { + + @Test + public void testToResponseNotFound() { + RestExceptionMapper mapper = new RestExceptionMapper(); + Response resp = mapper.toResponse(new NotFoundException()); + assertEquals(resp.getStatus(), Response.Status.NOT_FOUND.getStatusCode()); + } + + @Test + public void testToResponseInvalidTypeIdException() { + RestExceptionMapper mapper = new RestExceptionMapper(); + JsonParser parser = null; + JavaType type = null; + Response resp = mapper.toResponse(InvalidTypeIdException.from(parser, "dummy msg", type, "dummy typeId")); + assertEquals(resp.getStatus(), Response.Status.NOT_IMPLEMENTED.getStatusCode()); + } + + @Test + public void testToResponseJsonMappingException() { + RestExceptionMapper mapper = new RestExceptionMapper(); + JsonParser parser = null; + Response resp = mapper.toResponse(JsonMappingException.from(parser, "dummy msg")); + assertEquals(resp.getStatus(), Response.Status.BAD_REQUEST.getStatusCode()); + } + + @Test + public void testToResponseClassNotFoundException() { + RestExceptionMapper mapper = new RestExceptionMapper(); + Response resp = mapper.toResponse(new ClassNotFoundException()); + assertEquals(resp.getStatus(), Response.Status.NOT_IMPLEMENTED.getStatusCode()); + } + + @Test + public void testToResponseSerializationException() { + RestExceptionMapper mapper = new RestExceptionMapper(); + Response resp = mapper.toResponse(new SerializationException()); + assertEquals(resp.getStatus(), Response.Status.BAD_REQUEST.getStatusCode()); + } + + @Test + public void testToResponseInvalidRequestException() { + RestExceptionMapper mapper = new RestExceptionMapper(); + Response resp = mapper.toResponse(new InvalidRequestException("invalid request")); + assertEquals(resp.getStatus(), Response.Status.BAD_REQUEST.getStatusCode()); + } + + @Test + public void testToResponseUnknownException() { + RestExceptionMapper mapper = new RestExceptionMapper(); + Response resp = mapper.toResponse(new Exception("Unkown exception")); + assertEquals(resp.getStatus(), Response.Status.INTERNAL_SERVER_ERROR.getStatusCode()); + } + + @Test + public void testToExceptionNotFoundException() { + assertThrows(NotFoundException.class, + () -> RestExceptionMapper.toException(Response.Status.NOT_FOUND.getStatusCode(), "Not Found")); + } + + @Test + public void testToExceptionClassNotFoundException() { + assertThrows(ClassNotFoundException.class, + () -> RestExceptionMapper.toException(Response.Status.NOT_IMPLEMENTED.getStatusCode(), "Not Implemented")); + } + + @Test + public void testToExceptionSerializationException() { + assertThrows(InvalidRequestException.class, + () -> RestExceptionMapper.toException(Response.Status.BAD_REQUEST.getStatusCode(), "Bad Request")); + } + + @Test + public void testToExceptionRuntimeException() { + assertThrows(RuntimeException.class, () -> RestExceptionMapper.toException(-1, "Unkown status code")); + 
}
+}
diff --git a/trogdor/src/test/java/org/apache/kafka/trogdor/task/SampleTaskController.java b/trogdor/src/test/java/org/apache/kafka/trogdor/task/SampleTaskController.java
new file mode 100644
index 0000000..2640c39
--- /dev/null
+++ b/trogdor/src/test/java/org/apache/kafka/trogdor/task/SampleTaskController.java
@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kafka.trogdor.task;
+
+import org.apache.kafka.trogdor.common.Topology;
+
+import java.util.Set;
+
+public class SampleTaskController implements TaskController {
+    @Override
+    public Set<String> targetNodes(Topology topology) {
+        return Topology.Util.agentNodeNames(topology);
+    }
+};
diff --git a/trogdor/src/test/java/org/apache/kafka/trogdor/task/SampleTaskSpec.java b/trogdor/src/test/java/org/apache/kafka/trogdor/task/SampleTaskSpec.java
new file mode 100644
index 0000000..38a160f
--- /dev/null
+++ b/trogdor/src/test/java/org/apache/kafka/trogdor/task/SampleTaskSpec.java
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.kafka.trogdor.task;
+
+import com.fasterxml.jackson.annotation.JsonCreator;
+import com.fasterxml.jackson.annotation.JsonProperty;
+
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+
+public class SampleTaskSpec extends TaskSpec {
+    private final Map<String, Long> nodeToExitMs;
+    private final String error;
+
+    @JsonCreator
+    public SampleTaskSpec(@JsonProperty("startMs") long startMs,
+                          @JsonProperty("durationMs") long durationMs,
+                          @JsonProperty("nodeToExitMs") Map<String, Long> nodeToExitMs,
+                          @JsonProperty("error") String error) {
+        super(startMs, durationMs);
+        this.nodeToExitMs = nodeToExitMs == null ? new HashMap<String, Long>() :
+            Collections.unmodifiableMap(nodeToExitMs);
+        this.error = error == null ?
"" : error; + } + + @JsonProperty + public Map nodeToExitMs() { + return nodeToExitMs; + } + + @JsonProperty + public String error() { + return error; + } + + @Override + public TaskController newController(String id) { + return new SampleTaskController(); + } + + @Override + public TaskWorker newTaskWorker(String id) { + return new SampleTaskWorker(this); + } +}; diff --git a/trogdor/src/test/java/org/apache/kafka/trogdor/task/SampleTaskWorker.java b/trogdor/src/test/java/org/apache/kafka/trogdor/task/SampleTaskWorker.java new file mode 100644 index 0000000..06339d2 --- /dev/null +++ b/trogdor/src/test/java/org/apache/kafka/trogdor/task/SampleTaskWorker.java @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.trogdor.task; + +import com.fasterxml.jackson.databind.node.TextNode; +import org.apache.kafka.common.internals.KafkaFutureImpl; +import org.apache.kafka.common.utils.ThreadUtils; +import org.apache.kafka.trogdor.common.Platform; + +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.TimeUnit; + +public class SampleTaskWorker implements TaskWorker { + private final SampleTaskSpec spec; + private final ScheduledExecutorService executor; + private Future future; + private WorkerStatusTracker status; + + SampleTaskWorker(SampleTaskSpec spec) { + this.spec = spec; + this.executor = Executors.newSingleThreadScheduledExecutor( + ThreadUtils.createThreadFactory("SampleTaskWorker", false)); + this.future = null; + } + + @Override + public synchronized void start(Platform platform, WorkerStatusTracker status, + final KafkaFutureImpl haltFuture) throws Exception { + if (this.future != null) + return; + this.status = status; + this.status.update(new TextNode("active")); + + Long exitMs = spec.nodeToExitMs().get(platform.curNode().name()); + if (exitMs == null) { + exitMs = Long.MAX_VALUE; + } + this.future = platform.scheduler().schedule(executor, () -> { + haltFuture.complete(spec.error()); + return null; + }, exitMs); + } + + @Override + public void stop(Platform platform) throws Exception { + this.future.cancel(false); + this.executor.shutdown(); + this.executor.awaitTermination(1, TimeUnit.DAYS); + this.status.update(new TextNode("halted")); + } +}; diff --git a/trogdor/src/test/java/org/apache/kafka/trogdor/task/TaskSpecTest.java b/trogdor/src/test/java/org/apache/kafka/trogdor/task/TaskSpecTest.java new file mode 100644 index 0000000..f49c50a --- /dev/null +++ b/trogdor/src/test/java/org/apache/kafka/trogdor/task/TaskSpecTest.java @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.trogdor.task; + +import com.fasterxml.jackson.databind.exc.InvalidTypeIdException; +import org.apache.kafka.trogdor.common.JsonUtil; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; + +import static java.util.concurrent.TimeUnit.MILLISECONDS; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +@Timeout(value = 120000, unit = MILLISECONDS) +public class TaskSpecTest { + + @Test + public void testTaskSpecSerialization() throws Exception { + assertThrows(InvalidTypeIdException.class, () -> + JsonUtil.JSON_SERDE.readValue( + "{\"startMs\":123,\"durationMs\":456,\"exitMs\":1000,\"error\":\"foo\"}", + SampleTaskSpec.class), "Missing type id should cause exception to be thrown"); + String inputJson = "{\"class\":\"org.apache.kafka.trogdor.task.SampleTaskSpec\"," + + "\"startMs\":123,\"durationMs\":456,\"nodeToExitMs\":{\"node01\":1000},\"error\":\"foo\"}"; + SampleTaskSpec spec = JsonUtil.JSON_SERDE.readValue(inputJson, SampleTaskSpec.class); + assertEquals(123, spec.startMs()); + assertEquals(456, spec.durationMs()); + assertEquals(Long.valueOf(1000), spec.nodeToExitMs().get("node01")); + assertEquals("foo", spec.error()); + String outputJson = JsonUtil.toJsonString(spec); + assertEquals(inputJson, outputJson); + } +}; diff --git a/trogdor/src/test/java/org/apache/kafka/trogdor/workload/ConsumeBenchSpecTest.java b/trogdor/src/test/java/org/apache/kafka/trogdor/workload/ConsumeBenchSpecTest.java new file mode 100644 index 0000000..a0e3eb0 --- /dev/null +++ b/trogdor/src/test/java/org/apache/kafka/trogdor/workload/ConsumeBenchSpecTest.java @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.kafka.trogdor.workload;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+
+import org.apache.kafka.common.TopicPartition;
+import org.junit.jupiter.api.Test;
+
+import java.util.Arrays;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import java.util.Map;
+import java.util.HashMap;
+import java.util.Optional;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+
+public class ConsumeBenchSpecTest {
+
+    @Test
+    public void testMaterializeTopicsWithNoPartitions() {
+        Map<String, List<TopicPartition>> materializedTopics = consumeBenchSpec(Arrays.asList("topic[1-3]", "secondTopic")).materializeTopics();
+        Map<String, List<TopicPartition>> expected = new HashMap<>();
+        expected.put("topic1", new ArrayList<>());
+        expected.put("topic2", new ArrayList<>());
+        expected.put("topic3", new ArrayList<>());
+        expected.put("secondTopic", new ArrayList<>());
+
+        assertEquals(expected, materializedTopics);
+    }
+
+    @Test
+    public void testMaterializeTopicsWithSomePartitions() {
+        Map<String, List<TopicPartition>> materializedTopics = consumeBenchSpec(Arrays.asList("topic[1-3]:[1-5]", "secondTopic", "thirdTopic:1")).materializeTopics();
+        Map<String, List<TopicPartition>> expected = new HashMap<>();
+        expected.put("topic1", IntStream.range(1, 6).asLongStream().mapToObj(i -> new TopicPartition("topic1", (int) i)).collect(Collectors.toList()));
+        expected.put("topic2", IntStream.range(1, 6).asLongStream().mapToObj(i -> new TopicPartition("topic2", (int) i)).collect(Collectors.toList()));
+        expected.put("topic3", IntStream.range(1, 6).asLongStream().mapToObj(i -> new TopicPartition("topic3", (int) i)).collect(Collectors.toList()));
+        expected.put("secondTopic", new ArrayList<>());
+        expected.put("thirdTopic", Collections.singletonList(new TopicPartition("thirdTopic", 1)));
+
+        assertEquals(expected, materializedTopics);
+    }
+
+    @Test
+    public void testInvalidTopicNameRaisesExceptionInMaterialize() {
+        for (String invalidName : Arrays.asList("In:valid", "invalid:", ":invalid", "in:valid:1", "invalid:2:2", "invalid::1", "invalid[1-3]:")) {
+            assertThrows(IllegalArgumentException.class, () -> consumeBenchSpec(Collections.singletonList(invalidName)).materializeTopics());
+        }
+    }
+
+    private ConsumeBenchSpec consumeBenchSpec(List<String> activeTopics) {
+        return new ConsumeBenchSpec(0, 0, "node", "localhost",
+            123, 1234, "cg-1",
+            Collections.emptyMap(), Collections.emptyMap(), Collections.emptyMap(), 1,
+            Optional.empty(), activeTopics);
+    }
+}
diff --git a/trogdor/src/test/java/org/apache/kafka/trogdor/workload/ExternalCommandWorkerTest.java b/trogdor/src/test/java/org/apache/kafka/trogdor/workload/ExternalCommandWorkerTest.java
new file mode 100644
index 0000000..1a33328
--- /dev/null
+++ b/trogdor/src/test/java/org/apache/kafka/trogdor/workload/ExternalCommandWorkerTest.java
@@ -0,0 +1,189 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.trogdor.workload; + +import com.fasterxml.jackson.databind.node.IntNode; +import com.fasterxml.jackson.databind.node.JsonNodeFactory; +import com.fasterxml.jackson.databind.node.ObjectNode; +import com.fasterxml.jackson.databind.node.TextNode; +import org.apache.kafka.common.internals.KafkaFutureImpl; +import org.apache.kafka.common.utils.OperatingSystem; +import org.apache.kafka.test.TestUtils; +import org.apache.kafka.trogdor.task.AgentWorkerStatusTracker; +import org.apache.kafka.trogdor.task.WorkerStatusTracker; + +import java.io.File; +import java.io.OutputStream; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.util.Arrays; +import java.util.Optional; +import java.util.concurrent.CompletableFuture; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; + +import static java.util.concurrent.TimeUnit.MILLISECONDS; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +@Timeout(value = 120000, unit = MILLISECONDS) +public class ExternalCommandWorkerTest { + + static class ExternalCommandWorkerBuilder { + private final String id; + private int shutdownGracePeriodMs = 3000000; + private String[] command = new String[0]; + private ObjectNode workload; + + ExternalCommandWorkerBuilder(String id) { + this.id = id; + this.workload = new ObjectNode(JsonNodeFactory.instance); + this.workload.set("foo", new TextNode("value1")); + this.workload.set("bar", new IntNode(123)); + } + + ExternalCommandWorker build() { + ExternalCommandSpec spec = new ExternalCommandSpec(0, + 30000, + "node0", + Arrays.asList(command), + workload, + Optional.of(shutdownGracePeriodMs)); + return new ExternalCommandWorker(id, spec); + } + + ExternalCommandWorkerBuilder command(String... command) { + this.command = command; + return this; + } + + ExternalCommandWorkerBuilder shutdownGracePeriodMs(int shutdownGracePeriodMs) { + this.shutdownGracePeriodMs = shutdownGracePeriodMs; + return this; + } + } + + /** + * Test running a process which exits successfully-- in this case, /bin/true. + */ + @Test + public void testProcessWithNormalExit() throws Exception { + if (OperatingSystem.IS_WINDOWS) return; + ExternalCommandWorker worker = + new ExternalCommandWorkerBuilder("trueTask").command("true").build(); + KafkaFutureImpl doneFuture = new KafkaFutureImpl<>(); + worker.start(null, new AgentWorkerStatusTracker(), doneFuture); + assertEquals("", doneFuture.get()); + worker.stop(null); + } + + /** + * Test running a process which exits unsuccessfully-- in this case, /bin/false. 
+ */ + @Test + public void testProcessWithFailedExit() throws Exception { + if (OperatingSystem.IS_WINDOWS) return; + ExternalCommandWorker worker = + new ExternalCommandWorkerBuilder("falseTask").command("false").build(); + KafkaFutureImpl doneFuture = new KafkaFutureImpl<>(); + worker.start(null, new AgentWorkerStatusTracker(), doneFuture); + assertEquals("exited with return code 1", doneFuture.get()); + worker.stop(null); + } + + /** + * Test attempting to run an executable which doesn't exist. + * We use a path which starts with /dev/null, since that should never be a + * directory in UNIX. + */ + @Test + public void testProcessNotFound() throws Exception { + ExternalCommandWorker worker = + new ExternalCommandWorkerBuilder("notFoundTask"). + command("/dev/null/non/existent/script/path").build(); + KafkaFutureImpl doneFuture = new KafkaFutureImpl<>(); + worker.start(null, new AgentWorkerStatusTracker(), doneFuture); + String errorString = doneFuture.get(); + assertTrue(errorString.startsWith("Unable to start process")); + worker.stop(null); + } + + /** + * Test running a process which times out. We will send it a SIGTERM. + */ + @Test + public void testProcessStop() throws Exception { + if (OperatingSystem.IS_WINDOWS) return; + ExternalCommandWorker worker = + new ExternalCommandWorkerBuilder("testStopTask"). + command("sleep", "3600000").build(); + KafkaFutureImpl doneFuture = new KafkaFutureImpl<>(); + worker.start(null, new AgentWorkerStatusTracker(), doneFuture); + worker.stop(null); + // We don't check the numeric return code, since that will vary based on + // platform. + assertTrue(doneFuture.get().startsWith("exited with return code ")); + } + + /** + * Test running a process which needs to be force-killed. + */ + @Test + public void testProcessForceKillTimeout() throws Exception { + if (OperatingSystem.IS_WINDOWS) return; + File tempFile = null; + try { + tempFile = TestUtils.tempFile(); + try (OutputStream stream = Files.newOutputStream(tempFile.toPath())) { + for (String line : new String[] { + "echo hello world\n", + "# Test that the initial message is sent correctly.\n", + "read -r line\n", + "[[ $line == '{\"id\":\"testForceKillTask\",\"workload\":{\"foo\":\"value1\",\"bar\":123}}' ]] || exit 0\n", + "\n", + "# Ignore SIGTERM signals. This ensures that we test SIGKILL delivery.\n", + "trap 'echo SIGTERM' SIGTERM\n", + "\n", + "# Update the process status. This will also unblock the junit test.\n", + "# It is important that we do this after we disabled SIGTERM, to ensure\n", + "# that we are testing SIGKILL.\n", + "echo '{\"status\": \"green\", \"log\": \"my log message.\"}'\n", + "\n", + "# Wait for the SIGKILL.\n", + "while true; do sleep 0.01; done\n"}) { + stream.write(line.getBytes(StandardCharsets.UTF_8)); + } + } + CompletableFuture statusFuture = new CompletableFuture<>(); + final WorkerStatusTracker statusTracker = status -> statusFuture .complete(status.textValue()); + ExternalCommandWorker worker = new ExternalCommandWorkerBuilder("testForceKillTask"). + shutdownGracePeriodMs(1). + command("bash", tempFile.getAbsolutePath()). 
+ build(); + KafkaFutureImpl doneFuture = new KafkaFutureImpl<>(); + worker.start(null, statusTracker, doneFuture); + assertEquals("green", statusFuture.get()); + worker.stop(null); + assertTrue(doneFuture.get().startsWith("exited with return code ")); + } finally { + if (tempFile != null) { + Files.delete(tempFile.toPath()); + } + } + } +} diff --git a/trogdor/src/test/java/org/apache/kafka/trogdor/workload/HistogramTest.java b/trogdor/src/test/java/org/apache/kafka/trogdor/workload/HistogramTest.java new file mode 100644 index 0000000..47d3774 --- /dev/null +++ b/trogdor/src/test/java/org/apache/kafka/trogdor/workload/HistogramTest.java @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.trogdor.workload; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import org.junit.jupiter.api.Test; + +public class HistogramTest { + private static Histogram createHistogram(int maxValue, int... values) { + Histogram histogram = new Histogram(maxValue); + for (int value : values) { + histogram.add(value); + } + return histogram; + } + + @Test + public void testHistogramAverage() { + Histogram empty = createHistogram(1); + assertEquals(0, (int) empty.summarize(new float[0]).average()); + + Histogram histogram = createHistogram(70, 1, 2, 3, 4, 5, 6, 1); + + assertEquals(3, (int) histogram.summarize(new float[0]).average()); + histogram.add(60); + assertEquals(10, (int) histogram.summarize(new float[0]).average()); + } + + @Test + public void testHistogramSamples() { + Histogram empty = createHistogram(100); + assertEquals(0, empty.summarize(new float[0]).numSamples()); + Histogram histogram = createHistogram(100, 4, 8, 2, 4, 1, 100, 150); + assertEquals(7, histogram.summarize(new float[0]).numSamples()); + histogram.add(60); + assertEquals(8, histogram.summarize(new float[0]).numSamples()); + } + + @Test + public void testHistogramPercentiles() { + Histogram histogram = createHistogram(100, 1, 2, 3, 4, 5, 6, 80, 90); + float[] percentiles = new float[] {0.5f, 0.90f, 0.99f, 1f}; + Histogram.Summary summary = histogram.summarize(percentiles); + assertEquals(8, summary.numSamples()); + assertEquals(4, summary.percentiles().get(0).value()); + assertEquals(80, summary.percentiles().get(1).value()); + assertEquals(80, summary.percentiles().get(2).value()); + assertEquals(90, summary.percentiles().get(3).value()); + histogram.add(30); + histogram.add(30); + histogram.add(30); + + summary = histogram.summarize(new float[] {0.5f}); + assertEquals(11, summary.numSamples()); + assertEquals(5, summary.percentiles().get(0).value()); + + Histogram empty = createHistogram(100); + summary = empty.summarize(new float[] {0.5f}); + assertEquals(0, summary.percentiles().get(0).value()); + + histogram = 
createHistogram(1000); + histogram.add(100); + histogram.add(200); + summary = histogram.summarize(new float[] {0f, 0.5f, 1.0f}); + assertEquals(0, summary.percentiles().get(0).value()); + assertEquals(100, summary.percentiles().get(1).value()); + assertEquals(200, summary.percentiles().get(2).value()); + } +}; + diff --git a/trogdor/src/test/java/org/apache/kafka/trogdor/workload/PayloadGeneratorTest.java b/trogdor/src/test/java/org/apache/kafka/trogdor/workload/PayloadGeneratorTest.java new file mode 100644 index 0000000..12b3499 --- /dev/null +++ b/trogdor/src/test/java/org/apache/kafka/trogdor/workload/PayloadGeneratorTest.java @@ -0,0 +1,238 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.trogdor.workload; + +import static java.util.concurrent.TimeUnit.MILLISECONDS; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; + +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.util.Arrays; +import java.util.ArrayList; +import java.util.List; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; + +@Timeout(value = 120000, unit = MILLISECONDS) +public class PayloadGeneratorTest { + + @Test + public void testConstantPayloadGenerator() { + byte[] alphabet = new byte[26]; + for (int i = 0; i < alphabet.length; i++) { + alphabet[i] = (byte) ('a' + i); + } + byte[] expectedSuperset = new byte[512]; + for (int i = 0; i < expectedSuperset.length; i++) { + expectedSuperset[i] = (byte) ('a' + (i % 26)); + } + for (int i : new int[] {1, 5, 10, 100, 511, 512}) { + ConstantPayloadGenerator generator = new ConstantPayloadGenerator(i, alphabet); + assertArrayContains(expectedSuperset, generator.generate(0)); + assertArrayContains(expectedSuperset, generator.generate(10)); + assertArrayContains(expectedSuperset, generator.generate(100)); + } + } + + private static void assertArrayContains(byte[] expectedSuperset, byte[] actual) { + byte[] expected = new byte[actual.length]; + System.arraycopy(expectedSuperset, 0, expected, 0, expected.length); + assertArrayEquals(expected, actual); + } + + @Test + public void testSequentialPayloadGenerator() { + SequentialPayloadGenerator g4 = new SequentialPayloadGenerator(4, 1); + assertLittleEndianArrayEquals(1, g4.generate(0)); + assertLittleEndianArrayEquals(2, g4.generate(1)); + + SequentialPayloadGenerator g8 = new SequentialPayloadGenerator(8, 0); + assertLittleEndianArrayEquals(0, g8.generate(0)); + assertLittleEndianArrayEquals(1, g8.generate(1)); + assertLittleEndianArrayEquals(123123123123L, g8.generate(123123123123L)); + + 
SequentialPayloadGenerator g2 = new SequentialPayloadGenerator(2, 0); + assertLittleEndianArrayEquals(0, g2.generate(0)); + assertLittleEndianArrayEquals(1, g2.generate(1)); + assertLittleEndianArrayEquals(1, g2.generate(1)); + assertLittleEndianArrayEquals(1, g2.generate(131073)); + } + + private static void assertLittleEndianArrayEquals(long expected, byte[] actual) { + byte[] longActual = new byte[8]; + System.arraycopy(actual, 0, longActual, 0, Math.min(actual.length, longActual.length)); + ByteBuffer buf = ByteBuffer.wrap(longActual).order(ByteOrder.LITTLE_ENDIAN); + assertEquals(expected, buf.getLong()); + } + + @Test + public void testUniformRandomPayloadGenerator() { + PayloadIterator iter = new PayloadIterator( + new UniformRandomPayloadGenerator(1234, 456, 0)); + byte[] prev = iter.next(); + for (int uniques = 0; uniques < 1000; ) { + byte[] cur = iter.next(); + assertEquals(prev.length, cur.length); + if (!Arrays.equals(prev, cur)) { + uniques++; + } + } + testReproducible(new UniformRandomPayloadGenerator(1234, 456, 0)); + testReproducible(new UniformRandomPayloadGenerator(1, 0, 0)); + testReproducible(new UniformRandomPayloadGenerator(10, 6, 5)); + testReproducible(new UniformRandomPayloadGenerator(512, 123, 100)); + } + + private static void testReproducible(PayloadGenerator generator) { + byte[] val = generator.generate(123); + generator.generate(456); + byte[] val2 = generator.generate(123); + if (val == null) { + assertNull(val2); + } else { + assertArrayEquals(val, val2); + } + } + + @Test + public void testUniformRandomPayloadGeneratorPaddingBytes() { + UniformRandomPayloadGenerator generator = + new UniformRandomPayloadGenerator(1000, 456, 100); + byte[] val1 = generator.generate(0); + byte[] val1End = new byte[100]; + System.arraycopy(val1, 900, val1End, 0, 100); + byte[] val2 = generator.generate(100); + byte[] val2End = new byte[100]; + System.arraycopy(val2, 900, val2End, 0, 100); + byte[] val3 = generator.generate(200); + byte[] val3End = new byte[100]; + System.arraycopy(val3, 900, val3End, 0, 100); + assertArrayEquals(val1End, val2End); + assertArrayEquals(val1End, val3End); + } + + @Test + public void testRandomComponentPayloadGenerator() { + NullPayloadGenerator nullGenerator = new NullPayloadGenerator(); + RandomComponent nullConfig = new RandomComponent(50, nullGenerator); + + UniformRandomPayloadGenerator uniformGenerator = + new UniformRandomPayloadGenerator(5, 123, 0); + RandomComponent uniformConfig = new RandomComponent(50, uniformGenerator); + + SequentialPayloadGenerator sequentialGenerator = + new SequentialPayloadGenerator(4, 10); + RandomComponent sequentialConfig = new RandomComponent(75, sequentialGenerator); + + ConstantPayloadGenerator constantGenerator = + new ConstantPayloadGenerator(4, new byte[0]); + RandomComponent constantConfig = new RandomComponent(25, constantGenerator); + + List components1 = new ArrayList<>(Arrays.asList(nullConfig, uniformConfig)); + List components2 = new ArrayList<>(Arrays.asList(sequentialConfig, constantConfig)); + byte[] expected = new byte[4]; + + PayloadIterator iter = new PayloadIterator( + new RandomComponentPayloadGenerator(4, components1)); + int notNull = 0; + int isNull = 0; + while (notNull < 1000 || isNull < 1000) { + byte[] cur = iter.next(); + if (cur == null) { + isNull++; + } else { + notNull++; + } + } + + iter = new PayloadIterator( + new RandomComponentPayloadGenerator(123, components2)); + int isZeroBytes = 0; + int isNotZeroBytes = 0; + while (isZeroBytes < 500 || isNotZeroBytes < 1500) { + 
byte[] cur = iter.next(); + if (Arrays.equals(expected, cur)) { + isZeroBytes++; + } else { + isNotZeroBytes++; + } + } + + RandomComponent uniformConfig2 = new RandomComponent(25, uniformGenerator); + RandomComponent sequentialConfig2 = new RandomComponent(25, sequentialGenerator); + RandomComponent nullConfig2 = new RandomComponent(25, nullGenerator); + + List components3 = new ArrayList<>(Arrays.asList(sequentialConfig2, uniformConfig2, nullConfig)); + List components4 = new ArrayList<>(Arrays.asList(uniformConfig2, sequentialConfig2, constantConfig, nullConfig2)); + + testReproducible(new RandomComponentPayloadGenerator(4, components1)); + testReproducible(new RandomComponentPayloadGenerator(123, components2)); + testReproducible(new RandomComponentPayloadGenerator(50, components3)); + testReproducible(new RandomComponentPayloadGenerator(0, components4)); + } + + @Test + public void testRandomComponentPayloadGeneratorErrors() { + NullPayloadGenerator nullGenerator = new NullPayloadGenerator(); + RandomComponent nullConfig = new RandomComponent(25, nullGenerator); + UniformRandomPayloadGenerator uniformGenerator = + new UniformRandomPayloadGenerator(5, 123, 0); + RandomComponent uniformConfig = new RandomComponent(25, uniformGenerator); + ConstantPayloadGenerator constantGenerator = + new ConstantPayloadGenerator(4, new byte[0]); + RandomComponent constantConfig = new RandomComponent(-25, constantGenerator); + + List components1 = new ArrayList<>(Arrays.asList(nullConfig, uniformConfig)); + List components2 = new ArrayList<>(Arrays.asList( + nullConfig, constantConfig, uniformConfig, nullConfig, uniformConfig, uniformConfig)); + + assertThrows(IllegalArgumentException.class, () -> + new PayloadIterator(new RandomComponentPayloadGenerator(1, new ArrayList<>()))); + assertThrows(IllegalArgumentException.class, () -> + new PayloadIterator(new RandomComponentPayloadGenerator(13, components2))); + assertThrows(IllegalArgumentException.class, () -> + new PayloadIterator(new RandomComponentPayloadGenerator(123, components1))); + } + + @Test + public void testPayloadIterator() { + final int expectedSize = 50; + PayloadIterator iter = new PayloadIterator( + new ConstantPayloadGenerator(expectedSize, new byte[0])); + final byte[] expected = new byte[expectedSize]; + assertEquals(0, iter.position()); + assertArrayEquals(expected, iter.next()); + assertEquals(1, iter.position()); + assertArrayEquals(expected, iter.next()); + assertArrayEquals(expected, iter.next()); + assertEquals(3, iter.position()); + iter.seek(0); + assertEquals(0, iter.position()); + } + + @Test + public void testNullPayloadGenerator() { + NullPayloadGenerator generator = new NullPayloadGenerator(); + assertNull(generator.generate(0)); + assertNull(generator.generate(1)); + assertNull(generator.generate(100)); + } +} diff --git a/trogdor/src/test/java/org/apache/kafka/trogdor/workload/ThrottleTest.java b/trogdor/src/test/java/org/apache/kafka/trogdor/workload/ThrottleTest.java new file mode 100644 index 0000000..4eb5322 --- /dev/null +++ b/trogdor/src/test/java/org/apache/kafka/trogdor/workload/ThrottleTest.java @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.trogdor.workload; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import org.apache.kafka.common.utils.MockTime; +import org.apache.kafka.common.utils.Time; +import org.junit.jupiter.api.Test; + +public class ThrottleTest { + /** + * ThrottleMock is a subclass of Throttle that uses a MockTime object. It calls + * MockTime#sleep instead of Object#wait. + */ + private static class ThrottleMock extends Throttle { + final MockTime time; + + ThrottleMock(MockTime time, int maxPerSec) { + super(maxPerSec, 100); + this.time = time; + } + + @Override + protected Time time() { + return time; + } + + @Override + protected synchronized void delay(long amount) throws InterruptedException { + time.sleep(amount); + } + } + + @Test + public void testThrottle() throws Exception { + MockTime time = new MockTime(0, 0, 0); + ThrottleMock throttle = new ThrottleMock(time, 3); + assertFalse(throttle.increment()); + assertEquals(0, time.milliseconds()); + assertFalse(throttle.increment()); + assertEquals(0, time.milliseconds()); + assertFalse(throttle.increment()); + assertEquals(0, time.milliseconds()); + assertTrue(throttle.increment()); + assertEquals(100, time.milliseconds()); + time.sleep(50); + assertFalse(throttle.increment()); + assertEquals(150, time.milliseconds()); + assertFalse(throttle.increment()); + assertEquals(150, time.milliseconds()); + assertTrue(throttle.increment()); + assertEquals(200, time.milliseconds()); + } +}; + diff --git a/trogdor/src/test/java/org/apache/kafka/trogdor/workload/TimeIntervalTransactionsGeneratorTest.java b/trogdor/src/test/java/org/apache/kafka/trogdor/workload/TimeIntervalTransactionsGeneratorTest.java new file mode 100644 index 0000000..a5fa590 --- /dev/null +++ b/trogdor/src/test/java/org/apache/kafka/trogdor/workload/TimeIntervalTransactionsGeneratorTest.java @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.kafka.trogdor.workload; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import org.apache.kafka.common.utils.MockTime; +import org.junit.jupiter.api.Test; + +public class TimeIntervalTransactionsGeneratorTest { + @Test + public void testCommitsTransactionAfterIntervalPasses() { + MockTime time = new MockTime(); + TimeIntervalTransactionsGenerator generator = new TimeIntervalTransactionsGenerator(100, time); + + assertEquals(100, generator.transactionIntervalMs()); + assertEquals(TransactionGenerator.TransactionAction.BEGIN_TRANSACTION, generator.nextAction()); + assertEquals(TransactionGenerator.TransactionAction.NO_OP, generator.nextAction()); + time.sleep(50); + assertEquals(TransactionGenerator.TransactionAction.NO_OP, generator.nextAction()); + time.sleep(49); + assertEquals(TransactionGenerator.TransactionAction.NO_OP, generator.nextAction()); + time.sleep(1); + assertEquals(TransactionGenerator.TransactionAction.COMMIT_TRANSACTION, generator.nextAction()); + assertEquals(TransactionGenerator.TransactionAction.BEGIN_TRANSACTION, generator.nextAction()); + } +} diff --git a/trogdor/src/test/java/org/apache/kafka/trogdor/workload/TopicsSpecTest.java b/trogdor/src/test/java/org/apache/kafka/trogdor/workload/TopicsSpecTest.java new file mode 100644 index 0000000..93d695e --- /dev/null +++ b/trogdor/src/test/java/org/apache/kafka/trogdor/workload/TopicsSpecTest.java @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.kafka.trogdor.workload;
+
+import static java.util.concurrent.TimeUnit.MILLISECONDS;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import org.apache.kafka.trogdor.common.JsonUtil;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.Timeout;
+
+@Timeout(value = 120000, unit = MILLISECONDS)
+public class TopicsSpecTest {
+
+    private final static TopicsSpec FOO;
+    private final static PartitionsSpec PARTSA;
+    private final static PartitionsSpec PARTSB;
+
+    static {
+        FOO = new TopicsSpec();
+
+        PARTSA = new PartitionsSpec(3, (short) 3, null, null);
+        FOO.set("topicA[0-2]", PARTSA);
+
+        Map<Integer, List<Integer>> assignmentsB = new HashMap<>();
+        assignmentsB.put(0, Arrays.asList(0, 1, 2));
+        assignmentsB.put(1, Arrays.asList(2, 3, 4));
+        PARTSB = new PartitionsSpec(0, (short) 0, assignmentsB, null);
+        FOO.set("topicB", PARTSB);
+    }
+
+    @Test
+    public void testMaterialize() {
+        Map<String, PartitionsSpec> parts = FOO.materialize();
+        assertTrue(parts.containsKey("topicA0"));
+        assertTrue(parts.containsKey("topicA1"));
+        assertTrue(parts.containsKey("topicA2"));
+        assertTrue(parts.containsKey("topicB"));
+        assertEquals(4, parts.keySet().size());
+        assertEquals(PARTSA, parts.get("topicA0"));
+        assertEquals(PARTSA, parts.get("topicA1"));
+        assertEquals(PARTSA, parts.get("topicA2"));
+        assertEquals(PARTSB, parts.get("topicB"));
+    }
+
+    @Test
+    public void testPartitionNumbers() {
+        List<Integer> partsANumbers = PARTSA.partitionNumbers();
+        assertEquals(Integer.valueOf(0), partsANumbers.get(0));
+        assertEquals(Integer.valueOf(1), partsANumbers.get(1));
+        assertEquals(Integer.valueOf(2), partsANumbers.get(2));
+        assertEquals(3, partsANumbers.size());
+
+        List<Integer> partsBNumbers = PARTSB.partitionNumbers();
+        assertEquals(Integer.valueOf(0), partsBNumbers.get(0));
+        assertEquals(Integer.valueOf(1), partsBNumbers.get(1));
+        assertEquals(2, partsBNumbers.size());
+    }
+
+    @Test
+    public void testPartitionsSpec() throws Exception {
+        String text = "{\"numPartitions\": 5, \"configs\": {\"foo\": \"bar\"}}";
+        PartitionsSpec spec = JsonUtil.JSON_SERDE.readValue(text, PartitionsSpec.class);
+        assertEquals(5, spec.numPartitions());
+        assertEquals("bar", spec.configs().get("foo"));
+        assertEquals(1, spec.configs().size());
+    }
+}
diff --git a/trogdor/src/test/resources/log4j.properties b/trogdor/src/test/resources/log4j.properties
new file mode 100644
index 0000000..abeaf1e
--- /dev/null
+++ b/trogdor/src/test/resources/log4j.properties
@@ -0,0 +1,22 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+log4j.rootLogger=TRACE, stdout
+
+log4j.appender.stdout=org.apache.log4j.ConsoleAppender
+log4j.appender.stdout.layout=org.apache.log4j.PatternLayout
+log4j.appender.stdout.layout.ConversionPattern=[%d] %p %m (%c:%L)%n
+
+log4j.logger.org.apache.kafka=TRACE
+log4j.logger.org.eclipse.jetty=INFO
diff --git a/vagrant/README.md b/vagrant/README.md
new file mode 100644
index 0000000..c62bfb0
--- /dev/null
+++ b/vagrant/README.md
@@ -0,0 +1,134 @@
+# Apache Kafka #
+
+Using Vagrant to get up and running.
+
+1) Install VirtualBox [https://www.virtualbox.org/](https://www.virtualbox.org/)
+2) Install Vagrant >= 1.6.4 [https://www.vagrantup.com/](https://www.vagrantup.com/)
+3) Install the Vagrant plugins:
+```
+$ vagrant plugin install vagrant-hostmanager
+
+# Optional; caches and shares package downloads across VMs
+$ vagrant plugin install vagrant-cachier
+```
+
+In the main Kafka folder, do a normal Kafka build:
+
+    $ gradle
+    $ ./gradlew jar
+
+You can override default settings in `Vagrantfile.local`, which is a Ruby file
+that is ignored by git and imported into the Vagrantfile.
+One setting you likely want to enable
+in `Vagrantfile.local` is `enable_dns = true` to put hostnames in the host's
+/etc/hosts file. You probably want this to avoid having to use IP addresses when
+addressing the cluster from outside the VMs, e.g. if you run a client on the
+host. It's disabled by default since it requires `sudo` access, mucks with your
+system state, and breaks with naming conflicts if you try to run multiple
+clusters concurrently.
+
+Now bring up the cluster:
+
+    $ vagrant/vagrant-up.sh
+    $ # If on aws, run: vagrant/vagrant-up.sh --aws
+
+(This essentially runs vagrant up --no-provision && vagrant hostmanager && vagrant provision)
+
+We separate out the steps (bringing up the base VMs, mapping hostnames, and configuring the VMs)
+due to current limitations in ZooKeeper (ZOOKEEPER-1506) that require us to
+collect IPs for all nodes before starting ZooKeeper nodes. Breaking into multiple steps
+also allows us to bring machines up in parallel on AWS.
+
+Once this completes:
+
+* ZooKeeper will be running on 192.168.50.11 (and `zk1` if you used enable_dns)
+* Broker 1 on 192.168.50.51 (and `broker1` if you used enable_dns)
+* Broker 2 on 192.168.50.52 (and `broker2` if you used enable_dns)
+* Broker 3 on 192.168.50.53 (and `broker3` if you used enable_dns)
+
+To log into one of the machines:
+
+    vagrant ssh <machineName>
+
+You can access the brokers by their IP or hostname, e.g.
+
+    # Specify brokers by their hostnames: broker1, broker2, broker3 (or just one of them)
+    bin/kafka-topics.sh --create --bootstrap-server broker1:9092 --replication-factor 3 --partitions 1 --topic sandbox
+    bin/kafka-console-producer.sh --bootstrap-server broker1:9092,broker2:9092,broker3:9092 --topic sandbox
+
+    # Specify brokers by their IP: 192.168.50.51, 192.168.50.52, 192.168.50.53
+    bin/kafka-console-consumer.sh --bootstrap-server 192.168.50.51:9092,192.168.50.52:9092,192.168.50.53:9092 --topic sandbox --from-beginning
+
+If you need to update the running cluster, you can re-run the provisioner (the
+step that installs software and configures services):
+
+    vagrant provision
+
+Note that this doesn't currently ensure a fresh start -- old cluster state will
+still remain intact after everything restarts. This can be useful for updating
+the cluster to your most recent development version.
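+
+For example, a typical edit-and-redeploy cycle might look like the following (illustrative
+only; it assumes the default setup above, in which the VMs run the Kafka build from the
+source tree shared by the host):
+
+    $ ./gradlew jar       # rebuild Kafka on the host
+    $ vagrant provision   # push the new build to the VMs and restart the services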
+
+Finally, you can clean up the cluster by destroying all the VMs:
+
+    vagrant destroy -f
+
+## Configuration ##
+
+You can override some default settings by specifying the values in
+`Vagrantfile.local`. It is interpreted as a Ruby file, although you'll probably
+only ever need to change a few simple configuration variables. Some values you
+might want to override:
+
+* `enable_hostmanager` - true by default; override to false if on AWS to allow parallel cluster bringup.
+* `enable_dns` - Register each VM with a hostname in /etc/hosts on the
+  hosts. Hostnames are always set in the /etc/hosts in the VMs, so this is only
+  necessary if you want to address them conveniently from the host for tasks
+  that aren't provided by Vagrant.
+* `enable_jmx` - Whether to expose JMX ports: 800x for ZooKeeper nodes and 900x for brokers, where `x` is the node's index. For example, the zk1 machine would expose JMX on 8001, zk2 on 8002, etc.
+* `num_workers` - Generic workers that get the code (from this project), but don't start any services (no brokers, no zookeepers, etc). Useful for starting clients. Each worker will have an IP address of `192.168.50.10x` where `x` starts at `1` and increments for each worker.
+* `num_zookeepers` - Size of the ZooKeeper cluster
+* `num_brokers` - Number of broker instances to run
+* `ram_megabytes` - The amount of RAM for each virtual machine, in megabytes; defaults to `1200`
+
+
+
+## Using Other Providers ##
+
+### EC2 ###
+
+Install the `vagrant-aws` plugin to provide EC2 support:
+
+    $ vagrant plugin install vagrant-aws
+
+Next, configure parameters in `Vagrantfile.local`. A few are *required*:
+`enable_hostmanager`, `enable_dns`, `ec2_access_key`, `ec2_secret_key`, `ec2_keypair_name`, `ec2_keypair_file`, and
+`ec2_security_groups`. A couple of important notes:
+
+1. You definitely want to use `enable_dns` if you plan to run clients outside of
+   the cluster (e.g. from your local host). If you don't, you'll need to look up
+   the connection details with `vagrant ssh-config`.
+
+2. You'll have to set up a reasonable security group yourself. You'll need to
+   open ports for ZooKeeper (2888 & 3888 between ZK nodes, 2181 for clients) and
+   Kafka (9092). Beware that opening these ports to all sources (e.g. so you can
+   run producers/consumers locally) will allow anyone to access your Kafka
+   cluster. All other settings have reasonable defaults for setting up an
+   Ubuntu-based cluster, but you may want to customize instance type, region,
+   AMI, etc.
+
+3. `ec2_access_key` and `ec2_secret_key` will use the environment variables
+   `AWS_ACCESS_KEY` and `AWS_SECRET_KEY` respectively if they are set and not
+   overridden in `Vagrantfile.local`.
+
+4. If you're launching into a VPC, you must specify `ec2_subnet_id` (the subnet
+   in which to launch the nodes) and `ec2_security_groups` must be a list of
+   security group IDs instead of names, e.g. `sg-34fd3551` instead of
+   `kafka-test-cluster`.
+
+Now start things up, but specify the aws provider:
+
+    $ vagrant/vagrant-up.sh --aws
+
+Your instances should get tagged with a name including your hostname to make
+them identifiable and make it easier to track instances in the AWS management
+console.
diff --git a/vagrant/aws/aws-access-keys-commands b/vagrant/aws/aws-access-keys-commands
new file mode 100644
index 0000000..7e1d3c1
--- /dev/null
+++ b/vagrant/aws/aws-access-keys-commands
@@ -0,0 +1,25 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+export AWS_IAM_ROLE=$(curl -s http://169.254.169.254/latest/meta-data/iam/info | grep InstanceProfileArn | cut -d '"' -f 4 | cut -d '/' -f 2)
+export AWS_ACCESS_KEY=$(curl -s http://169.254.169.254/latest/meta-data/iam/security-credentials/$AWS_IAM_ROLE | grep AccessKeyId | awk -F\" '{ print $4 }')
+export AWS_SECRET_KEY=$(curl -s http://169.254.169.254/latest/meta-data/iam/security-credentials/$AWS_IAM_ROLE | grep SecretAccessKey | awk -F\" '{ print $4 }')
+export AWS_SESSION_TOKEN=$(curl -s http://169.254.169.254/latest/meta-data/iam/security-credentials/$AWS_IAM_ROLE | grep Token | awk -F\" '{ print $4 }')
+
+if [ -z "$AWS_ACCESS_KEY" ]; then
+  echo "Failed to populate environment variables AWS_ACCESS_KEY, AWS_SECRET_KEY, and AWS_SESSION_TOKEN."
+  echo "AWS_IAM_ROLE is currently $AWS_IAM_ROLE. Double-check that this is correct. If not set, add this command to your .bashrc file:"
+  echo "export AWS_IAM_ROLE=<the IAM role of this instance> # put this into your ~/.bashrc"
+fi
diff --git a/vagrant/aws/aws-example-Vagrantfile.local b/vagrant/aws/aws-example-Vagrantfile.local
new file mode 100644
index 0000000..23187a0
--- /dev/null
+++ b/vagrant/aws/aws-example-Vagrantfile.local
@@ -0,0 +1,29 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Use this template Vagrantfile.local for running system tests on aws
+# To use it, move it to the base kafka directory and rename
+# it to Vagrantfile.local, and adjust variables as needed.
+ec2_instance_type = "m3.xlarge"
+ec2_spot_max_price = "0.266" # On-demand price for instance type
+enable_hostmanager = false
+num_zookeepers = 0
+num_brokers = 0
+num_workers = 9
+ec2_keypair_name = 'kafkatest'
+ec2_keypair_file = '../kafkatest.pem'
+ec2_security_groups = ['kafkatest']
+ec2_region = 'us-west-2'
+ec2_ami = "ami-29ebb519"
diff --git a/vagrant/aws/aws-init.sh b/vagrant/aws/aws-init.sh
new file mode 100755
index 0000000..f36d4e3
--- /dev/null
+++ b/vagrant/aws/aws-init.sh
@@ -0,0 +1,81 @@
+#!/usr/bin/env bash
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. 
See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This script can be used to set up a driver machine on aws from which you will run tests +# or bring up your mini Kafka cluster. + +# Install dependencies +sudo apt-get install -y \ + maven \ + openjdk-8-jdk-headless \ + build-essential \ + ruby-dev \ + zlib1g-dev \ + realpath \ + python-setuptools \ + iperf \ + traceroute + +base_dir=`dirname $0`/../.. + +if [ -z `which vagrant` ]; then + echo "Installing vagrant..." + wget https://releases.hashicorp.com/vagrant/2.1.5/vagrant_2.1.5_x86_64.deb + sudo dpkg -i vagrant_2.1.5_x86_64.deb + rm -f vagrant_2.1.5_x86_64.deb +fi + +# Install necessary vagrant plugins +# Note: Do NOT install vagrant-cachier since it doesn't work on AWS and only +# adds log noise + +# Custom vagrant-aws with spot instance support. See https://github.com/mitchellh/vagrant-aws/issues/32 +wget -nv https://s3-us-west-2.amazonaws.com/confluent-packaging-tools/vagrant-aws-0.7.2.spot.gem -P /tmp +vagrant_plugins="/tmp/vagrant-aws-0.7.2.spot.gem vagrant-hostmanager" +existing=`vagrant plugin list` +for plugin in $vagrant_plugins; do + echo $existing | grep $plugin > /dev/null + if [ $? != 0 ]; then + vagrant plugin install $plugin + fi +done + +# Create Vagrantfile.local as a convenience +if [ ! -e "$base_dir/Vagrantfile.local" ]; then + cp $base_dir/vagrant/aws/aws-example-Vagrantfile.local $base_dir/Vagrantfile.local +fi + +gradle="gradle-2.2.1" +if [ -z `which gradle` ] && [ ! -d $base_dir/$gradle ]; then + if [ ! -e $gradle-bin.zip ]; then + wget https://services.gradle.org/distributions/$gradle-bin.zip + fi + unzip $gradle-bin.zip + rm -rf $gradle-bin.zip + mv $gradle $base_dir/$gradle +fi + +# Ensure aws access keys are in the environment when we use a EC2 driver machine +LOCAL_HOSTNAME=$(hostname -d) +if [[ ${LOCAL_HOSTNAME} =~ .*\.compute\.internal ]]; then + grep "AWS ACCESS KEYS" ~/.bashrc > /dev/null + if [ $? != 0 ]; then + echo "# --- AWS ACCESS KEYS ---" >> ~/.bashrc + echo ". `realpath $base_dir/aws/aws-access-keys-commands`" >> ~/.bashrc + echo "# -----------------------" >> ~/.bashrc + source ~/.bashrc + fi +fi diff --git a/vagrant/base.sh b/vagrant/base.sh new file mode 100755 index 0000000..36b8e46 --- /dev/null +++ b/vagrant/base.sh @@ -0,0 +1,173 @@ +#!/usr/bin/env bash +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +set -ex + +# The version of Kibosh to use for testing. +# If you update this, also update tests/docker/Dockerfile +export KIBOSH_VERSION=8841dd392e6fbf02986e2fb1f1ebf04df344b65a + +path_to_jdk_cache() { + jdk_version=$1 + echo "/tmp/jdk-${jdk_version}.tar.gz" +} + +fetch_jdk_tgz() { + jdk_version=$1 + + path=$(path_to_jdk_cache $jdk_version) + + if [ ! -e $path ]; then + mkdir -p $(dirname $path) + curl --retry 5 -s -L "https://s3-us-west-2.amazonaws.com/kafka-packages/jdk-${jdk_version}.tar.gz" -o $path + fi +} + +JDK_MAJOR="${JDK_MAJOR:-8}" +JDK_FULL="${JDK_FULL:-8u202-linux-x64}" + +if [ -z `which javac` ]; then + apt-get -y update + apt-get install -y software-properties-common python-software-properties binutils java-common + + echo "===> Installing JDK..." + + mkdir -p /opt/jdk + cd /opt/jdk + rm -rf $JDK_MAJOR + mkdir -p $JDK_MAJOR + cd $JDK_MAJOR + fetch_jdk_tgz $JDK_FULL + tar x --strip-components=1 -zf $(path_to_jdk_cache $JDK_FULL) + for bin in /opt/jdk/$JDK_MAJOR/bin/* ; do + name=$(basename $bin) + update-alternatives --install /usr/bin/$name $name $bin 1081 && update-alternatives --set $name $bin + done + echo -e "export JAVA_HOME=/opt/jdk/$JDK_MAJOR\nexport PATH=\$PATH:\$JAVA_HOME/bin" > /etc/profile.d/jdk.sh + echo "JDK installed: $(javac -version 2>&1)" + +fi + +chmod a+rw /opt +if [ -h /opt/kafka-dev ]; then + # reset symlink + rm /opt/kafka-dev +fi +ln -s /vagrant /opt/kafka-dev + + +get_kafka() { + version=$1 + scala_version=$2 + + kafka_dir=/opt/kafka-$version + url=https://s3-us-west-2.amazonaws.com/kafka-packages/kafka_$scala_version-$version.tgz + # the .tgz above does not include the streams test jar hence we need to get it separately + url_streams_test=https://s3-us-west-2.amazonaws.com/kafka-packages/kafka-streams-$version-test.jar + if [ ! -d /opt/kafka-$version ]; then + pushd /tmp + curl --retry 5 -O $url + curl --retry 5 -O $url_streams_test || true + file_tgz=`basename $url` + file_streams_jar=`basename $url_streams_test` || true + tar -xzf $file_tgz + rm -rf $file_tgz + + file=`basename $file_tgz .tgz` + mv $file $kafka_dir + mv $file_streams_jar $kafka_dir/libs || true + popd + fi +} + +# Install Kibosh +apt-get update -y && apt-get install -y git cmake pkg-config libfuse-dev +pushd /opt +rm -rf /opt/kibosh +git clone -q https://github.com/confluentinc/kibosh.git +pushd "/opt/kibosh" +git reset --hard $KIBOSH_VERSION +mkdir "/opt/kibosh/build" +pushd "/opt/kibosh/build" +../configure && make -j 2 +popd +popd +popd + +# Install iperf +apt-get install -y iperf traceroute + +# Test multiple Kafka versions +# We want to use the latest Scala version per Kafka version +# Previously we could not pull in Scala 2.12 builds, because Scala 2.12 requires Java 8 and we were running the system +# tests with Java 7. We have since switched to Java 8, so 2.0.0 and later use Scala 2.12. 
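+
+# Each release below is fetched with get_kafka <version> <scala_version> and its install
+# directory is then made world read/writable. As an illustrative sketch only (the explicit
+# calls below are what actually run), the same pattern could be written as a loop over
+# "<version>:<scala_version>" pairs, e.g.:
+#
+#   for entry in "0.8.2.2:2.11" "2.8.1:2.12"; do
+#       get_kafka "${entry%%:*}" "${entry#*:}"
+#       chmod a+rw "/opt/kafka-${entry%%:*}"
+#   done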
+get_kafka 0.8.2.2 2.11 +chmod a+rw /opt/kafka-0.8.2.2 +get_kafka 0.9.0.1 2.11 +chmod a+rw /opt/kafka-0.9.0.1 +get_kafka 0.10.0.1 2.11 +chmod a+rw /opt/kafka-0.10.0.1 +get_kafka 0.10.1.1 2.11 +chmod a+rw /opt/kafka-0.10.1.1 +get_kafka 0.10.2.2 2.11 +chmod a+rw /opt/kafka-0.10.2.2 +get_kafka 0.11.0.3 2.11 +chmod a+rw /opt/kafka-0.11.0.3 +get_kafka 1.0.2 2.11 +chmod a+rw /opt/kafka-1.0.2 +get_kafka 1.1.1 2.11 +chmod a+rw /opt/kafka-1.1.1 +get_kafka 2.0.1 2.12 +chmod a+rw /opt/kafka-2.0.1 +get_kafka 2.1.1 2.12 +chmod a+rw /opt/kafka-2.1.1 +get_kafka 2.2.2 2.12 +chmod a+rw /opt/kafka-2.2.2 +get_kafka 2.3.1 2.12 +chmod a+rw /opt/kafka-2.3.1 +get_kafka 2.4.1 2.12 +chmod a+rw /opt/kafka-2.4.1 +get_kafka 2.5.1 2.12 +chmod a+rw /opt/kafka-2.5.1 +get_kafka 2.6.2 2.12 +chmod a+rw /opt/kafka-2.6.2 +get_kafka 2.7.1 2.12 +chmod a+rw /opt/kafka-2.7.1 +get_kafka 2.8.1 2.12 +chmod a+rw /opt/kafka-2.8.1 + +# For EC2 nodes, we want to use /mnt, which should have the local disk. On local +# VMs, we can just create it if it doesn't exist and use it like we'd use +# /tmp. Eventually, we'd like to also support more directories, e.g. when EC2 +# instances have multiple local disks. +if [ ! -e /mnt ]; then + mkdir /mnt +fi +chmod a+rwx /mnt + +# Run ntpdate once to sync to ntp servers +# use -u option to avoid port collision in case ntp daemon is already running +ntpdate -u pool.ntp.org +# Install ntp daemon - it will automatically start on boot +apt-get -y install ntp + +# Increase the ulimit +mkdir -p /etc/security/limits.d +echo "* soft nofile 128000" >> /etc/security/limits.d/nofile.conf +echo "* hard nofile 128000" >> /etc/security/limits.d/nofile.conf + +ulimit -Hn 128000 +ulimit -Sn 128000 diff --git a/vagrant/broker.sh b/vagrant/broker.sh new file mode 100755 index 0000000..986f0fa --- /dev/null +++ b/vagrant/broker.sh @@ -0,0 +1,43 @@ +#!/usr/bin/env bash +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# Usage: brokers.sh + +set -e + +BROKER_ID=$1 +PUBLIC_ADDRESS=$2 +PUBLIC_ZOOKEEPER_ADDRESSES=$3 +JMX_PORT=$4 + +kafka_dir=/opt/kafka-dev +cd $kafka_dir + +sed \ + -e 's/broker.id=0/'broker.id=$BROKER_ID'/' \ + -e 's/#advertised.host.name=/'advertised.host.name=$PUBLIC_ADDRESS'/' \ + -e 's/zookeeper.connect=localhost:2181/'zookeeper.connect=$PUBLIC_ZOOKEEPER_ADDRESSES'/' \ + $kafka_dir/config/server.properties > $kafka_dir/config/server-$BROKER_ID.properties + +echo "Killing server" +bin/kafka-server-stop.sh || true +sleep 5 # Because kafka-server-stop.sh doesn't actually wait +echo "Starting server" +if [[ -n $JMX_PORT ]]; then + export JMX_PORT=$JMX_PORT + export KAFKA_JMX_OPTS="-Djava.rmi.server.hostname=$PUBLIC_ADDRESS -Dcom.sun.management.jmxremote -Dcom.sun.management.jmxremote.authenticate=false -Dcom.sun.management.jmxremote.ssl=false " +fi +bin/kafka-server-start.sh $kafka_dir/config/server-$BROKER_ID.properties 1>> /tmp/broker.log 2>> /tmp/broker.log & diff --git a/vagrant/package-base-box.sh b/vagrant/package-base-box.sh new file mode 100755 index 0000000..5ac7f0e --- /dev/null +++ b/vagrant/package-base-box.sh @@ -0,0 +1,75 @@ +#!/usr/bin/env bash +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This script automates the process of creating and packaging +# a new vagrant base_box. For use locally (not aws). + +base_dir=`dirname $0`/.. +cd $base_dir + +backup_vagrantfile=backup_Vagrantfile.local +local_vagrantfile=Vagrantfile.local + +# Restore original Vagrantfile.local, if it exists +function revert_vagrantfile { + rm -f $local_vagrantfile + if [ -e $backup_vagrantfile ]; then + mv $backup_vagrantfile $local_vagrantfile + fi +} + +function clean_up { + echo "Cleaning up..." + vagrant destroy -f + rm -f package.box + revert_vagrantfile +} + +# Name of the new base box +base_box="kafkatest-worker" + +# vagrant VM name +worker_name="worker1" + +echo "Destroying vagrant machines..." +vagrant destroy -f + +echo "Removing $base_box from vagrant..." +vagrant box remove $base_box + +echo "Bringing up a single vagrant machine from scratch..." +if [ -e $local_vagrantfile ]; then + mv $local_vagrantfile $backup_vagrantfile +fi +echo "num_workers = 1" > $local_vagrantfile +echo "num_brokers = 0" >> $local_vagrantfile +echo "num_zookeepers = 0" >> $local_vagrantfile +vagrant up +up_status=$? +if [ $up_status != 0 ]; then + echo "Failed to bring up a template vm, please try running again." + clean_up + exit $up_status +fi + +echo "Packaging $worker_name..." +vagrant package $worker_name + +echo "Adding new base box $base_box to vagrant..." 
+vagrant box add $base_box package.box + +clean_up + diff --git a/vagrant/system-test-Vagrantfile.local b/vagrant/system-test-Vagrantfile.local new file mode 100644 index 0000000..898c02e --- /dev/null +++ b/vagrant/system-test-Vagrantfile.local @@ -0,0 +1,26 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Use this example Vagrantfile.local for running system tests +# To use it, move it to the base kafka directory and rename +# it to Vagrantfile.local +num_zookeepers = 0 +num_brokers = 0 +num_workers = 9 +base_box = "kafkatest-worker" + + +# System tests use hostnames for each worker that need to be defined in /etc/hosts on the host running ducktape +enable_dns = true diff --git a/vagrant/vagrant-up.sh b/vagrant/vagrant-up.sh new file mode 100755 index 0000000..9210a53 --- /dev/null +++ b/vagrant/vagrant-up.sh @@ -0,0 +1,266 @@ +#!/usr/bin/env bash +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +set -o nounset +set -o errexit # exit script if any command exits with nonzero value + +readonly PROG_NAME=$(basename $0) +readonly PROG_DIR=$(dirname $(realpath $0)) +readonly INVOKE_DIR=$(pwd) +readonly ARGS="$@" + +# overrideable defaults +AWS=false +PARALLEL=true +MAX_PARALLEL=5 +DEBUG=false + +readonly USAGE="Usage: $PROG_NAME [-h | --help] [--aws [--no-parallel] [--max-parallel MAX]]" +readonly HELP="$(cat < 0 ]]; do + key="$1" + case $key in + -h | --help) + help + ;; + --aws) + AWS=true + ;; + --no-parallel) + PARALLEL=false + ;; + --max-parallel) + MAX_PARALLEL="$2" + shift + ;; + --debug) + DEBUG=true + ;; + *) + # unknown option + echo "Unknown option $1" + exit 1 + ;; +esac +shift # past argument or value +done + +# Get a list of vagrant machines (in any state) +function read_vagrant_machines { + local ignore_state="ignore" + local reading_state="reading" + local tmp_file="tmp-$RANDOM" + + local state="$ignore_state" + local machines="" + + while read -r line; do + # Lines before the first empty line are ignored + # The first empty line triggers change from ignore state to reading state + # When in reading state, we parse in machine names until we hit the next empty line, + # which signals that we're done parsing + if [[ -z "$line" ]]; then + if [[ "$state" == "$ignore_state" ]]; then + state="$reading_state" + else + # all done + echo "$machines" + return + fi + continue + fi + + # Parse machine name while in reading state + if [[ "$state" == "$reading_state" ]]; then + line=$(echo "$line" | cut -d ' ' -f 1) + if [[ -z "$machines" ]]; then + machines="$line" + else + machines="${machines} ${line}" + fi + fi + done < <(vagrant status) +} + +# Filter "list", returning a list of strings containing pattern as a substring +function filter { + local list="$1" + local pattern="$2" + + local result="" + for item in $list; do + if [[ ! -z "$(echo $item | grep "$pattern")" ]]; then + result="$result $item" + fi + done + echo "$result" +} + +# Given a list of machine names, return only test worker machines +function worker { + local machines="$1" + local workers=$(filter "$machines" "worker") + workers=$(echo "$workers" | xargs) # trim leading/trailing whitespace + echo "$workers" +} + +# Given a list of machine names, return only zookeeper and broker machines +function zk_broker { + local machines="$1" + local zk_broker_list=$(filter "$machines" "zk") + zk_broker_list="$zk_broker_list $(filter "$machines" "broker")" + zk_broker_list=$(echo "$zk_broker_list" | xargs) # trim leading/trailing whitespace + echo "$zk_broker_list" +} + +# Run a vagrant command on batches of machines of size $group_size +# This is annoying but necessary on aws to avoid errors due to AWS request rate +# throttling +# +# Example +# $ vagrant_batch_command "vagrant up" "m1 m2 m3 m4 m5" "2" +# +# This is equivalent to running "vagrant up" on groups of machines of size 2 or less, i.e.: +# $ vagrant up m1 m2 +# $ vagrant up m3 m4 +# $ vagrant up m5 +function vagrant_batch_command { + local vagrant_cmd="$1" + local machines="$2" + local group_size="$3" + + local count=1 + local m_group="" + # Using --provision flag makes this command useable both when bringing up a cluster from scratch, + # and when bringing up a halted cluster. 
Permissions on certain directores set during provisioning + # seem to revert when machines are halted, so --provision ensures permissions are set correctly in all cases + for machine in $machines; do + m_group="$m_group $machine" + + if [[ $(expr $count % $group_size) == 0 ]]; then + # We've reached a full group + # Bring up this part of the cluster + $vagrant_cmd $m_group + m_group="" + fi + ((count++)) + done + + # Take care of any leftover partially complete group + if [[ ! -z "$m_group" ]]; then + $vagrant_cmd $m_group + fi +} + +# We assume vagrant-hostmanager is installed, but may or may not be disabled during vagrant up +# In this fashion, we ensure we run hostmanager after machines are up, and before provisioning. +# This sequence of commands is necessary for example for bringing up a multi-node zookeeper cluster +function bring_up_local { + vagrant up --no-provision + vagrant hostmanager + vagrant provision +} + +function bring_up_aws { + local parallel="$1" + local max_parallel="$2" + local machines="$(read_vagrant_machines)" + case "$3" in + true) + local debug="--debug" + ;; + false) + local debug="" + ;; + esac + zk_broker_machines=$(zk_broker "$machines") + worker_machines=$(worker "$machines") + + if [[ "$parallel" == "true" ]]; then + if [[ ! -z "$zk_broker_machines" ]]; then + # We still have to bring up zookeeper/broker nodes serially + echo "Bringing up zookeeper/broker machines serially" + vagrant up --provider=aws --no-parallel --no-provision $zk_broker_machines $debug + vagrant hostmanager --provider=aws + vagrant provision + fi + + if [[ ! -z "$worker_machines" ]]; then + echo "Bringing up test worker machines in parallel" + # Try to isolate this job in its own /tmp space. See note + # below about vagrant issue + local vagrant_rsync_temp_dir=$(mktemp -d); + TMPDIR=$vagrant_rsync_temp_dir vagrant_batch_command "vagrant up $debug --provider=aws" "$worker_machines" "$max_parallel" + rm -rf $vagrant_rsync_temp_dir + vagrant hostmanager --provider=aws + fi + else + vagrant up --provider=aws --no-parallel --no-provision $debug + vagrant hostmanager --provider=aws + vagrant provision + fi + + # Currently it seems that the AWS provider will always run rsync + # as part of vagrant up. However, + # https://github.com/mitchellh/vagrant/issues/7531 means it is not + # safe to do so. Since the bug doesn't seem to cause any direct + # errors, just missing data on some nodes, follow up with serial + # rsyncing to ensure we're in a clean state. Use custom TMPDIR + # values to ensure we're isolated from any other instances of this + # script that are running/ran recently and may cause different + # instances to sync to the wrong nodes + for worker in $worker_machines; do + local vagrant_rsync_temp_dir=$(mktemp -d); + TMPDIR=$vagrant_rsync_temp_dir vagrant rsync $worker; + rm -rf $vagrant_rsync_temp_dir + done +} + +function main { + if [[ "$AWS" == "true" ]]; then + bring_up_aws "$PARALLEL" "$MAX_PARALLEL" "$DEBUG" + else + bring_up_local + fi +} + +main diff --git a/vagrant/zk.sh b/vagrant/zk.sh new file mode 100755 index 0000000..e8c690a --- /dev/null +++ b/vagrant/zk.sh @@ -0,0 +1,47 @@ +#!/usr/bin/env bash +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Usage: zk.sh + +set -e + +ZKID=$1 +NUM_ZK=$2 +JMX_PORT=$3 + +kafka_dir=/opt/kafka-dev +cd $kafka_dir + +cp $kafka_dir/config/zookeeper.properties $kafka_dir/config/zookeeper-$ZKID.properties +echo "initLimit=5" >> $kafka_dir/config/zookeeper-$ZKID.properties +echo "syncLimit=2" >> $kafka_dir/config/zookeeper-$ZKID.properties +echo "quorumListenOnAllIPs=true" >> $kafka_dir/config/zookeeper-$ZKID.properties +for i in `seq 1 $NUM_ZK`; do + echo "server.${i}=zk${i}:2888:3888" >> $kafka_dir/config/zookeeper-$ZKID.properties +done + +mkdir -p /tmp/zookeeper +echo "$ZKID" > /tmp/zookeeper/myid + +echo "Killing ZooKeeper" +bin/zookeeper-server-stop.sh || true +sleep 5 # Because zookeeper-server-stop.sh doesn't actually wait +echo "Starting ZooKeeper" +if [[ -n $JMX_PORT ]]; then + export JMX_PORT=$JMX_PORT + export KAFKA_JMX_OPTS="-Djava.rmi.server.hostname=zk$ZKID -Dcom.sun.management.jmxremote -Dcom.sun.management.jmxremote.authenticate=false -Dcom.sun.management.jmxremote.ssl=false " +fi +bin/zookeeper-server-start.sh config/zookeeper-$ZKID.properties 1>> /tmp/zk.log 2>> /tmp/zk.log & diff --git a/wrapper.gradle b/wrapper.gradle new file mode 100644 index 0000000..2dfca19 --- /dev/null +++ b/wrapper.gradle @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +// This file contains tasks for the gradle wrapper generation. + +// Ensure the wrapper script is generated based on the version defined in the project +// and not the version installed on the machine running the task. +// Read more about the wrapper here: https://docs.gradle.org/current/userguide/gradle_wrapper.html +wrapper { + gradleVersion = project.gradleVersion + distributionType = Wrapper.DistributionType.ALL +} + +// Custom task to inject support for downloading the gradle wrapper jar if it doesn't exist. +// This allows us to avoid checking in the jar to our repository. +// Additionally adds a license header to the wrapper while editing the file contents. +task bootstrapWrapper() { + // In the doLast block so this runs when the task is called and not during project configuration. + doLast { + def wrapperBasePath = "\$APP_HOME/gradle/wrapper" + def wrapperJarPath = wrapperBasePath + "/gradle-wrapper.jar" + + // Add a trailing zero to the version if needed. + def fullVersion = project.gradleVersion.count(".") == 1 ? 
"${project.gradleVersion}.0" : versions.gradle + // Leverages the wrapper jar checked into the gradle project on github because the jar isn't + // available elsewhere. Using raw.githubusercontent.com instead of github.com because + // github.com servers deprecated TLSv1/TLSv1.1 support some time ago, so older versions + // of curl (built against OpenSSL library that doesn't support TLSv1.2) would fail to + // fetch the jar. + def wrapperBaseUrl = "https://raw.githubusercontent.com/gradle/gradle/v$fullVersion/gradle/wrapper" + def wrapperJarUrl = wrapperBaseUrl + "/gradle-wrapper.jar" + + def bootstrapString = """ + # Loop in case we encounter an error. + for attempt in 1 2 3; do + if [ ! -e "$wrapperJarPath" ]; then + if ! curl -s -S --retry 3 -L -o "$wrapperJarPath" "$wrapperJarUrl"; then + rm -f "$wrapperJarPath" + # Pause for a bit before looping in case the server throttled us. + sleep 5 + continue + fi + fi + done + """.stripIndent() + + def wrapperScript = wrapper.scriptFile + def wrapperLines = wrapperScript.readLines() + wrapperScript.withPrintWriter { out -> + def bootstrapWritten = false + wrapperLines.each { line -> + // Print the wrapper bootstrap before the first usage of the wrapper jar. + if (!bootstrapWritten && line.contains("gradle-wrapper.jar")) { + out.println(bootstrapString) + bootstrapWritten = true + } + out.print(line) + out.println() + } + } + } +} +wrapper.finalizedBy bootstrapWrapper + +// Remove the generated batch file since we don't test building in the Windows environment. +task removeWindowsScript(type: Delete) { + delete "$rootDir/gradlew.bat" +} +wrapper.finalizedBy removeWindowsScript